-rw-r--r--  .bazelrc  60
-rw-r--r--  .github/ISSUE_TEMPLATE/bug_report.md  31
-rw-r--r--  .github/ISSUE_TEMPLATE/config.yml  11
-rw-r--r--  .github/ISSUE_TEMPLATE/feature_request.md  21
-rw-r--r--  .github/labeler.yml  42
-rw-r--r--  .github/pull_request_template.md  5
-rw-r--r--  .github/workflows/build.yml  21
-rw-r--r--  .github/workflows/go.yml  66
-rw-r--r--  .github/workflows/issue_reviver.yml  14
-rw-r--r--  .github/workflows/labeler.yml  12
-rw-r--r--  .github/workflows/stale.yml  20
-rw-r--r--  .gitignore  2
-rw-r--r--  .travis.yml  45
-rw-r--r--  AUTHORS  8
-rw-r--r--  BUILD  94
-rw-r--r--  CODE_OF_CONDUCT.md  91
-rw-r--r--  CONTRIBUTING.md  129
-rw-r--r--  GOVERNANCE.md  113
-rw-r--r--  LICENSE  202
-rw-r--r--  Makefile  254
-rw-r--r--  README.md  122
-rw-r--r--  SECURITY.md  10
-rw-r--r--  WORKSPACE  559
-rw-r--r--  benchmarks/BUILD  29
-rw-r--r--  benchmarks/README.md  186
-rw-r--r--  benchmarks/defs.bzl  14
-rw-r--r--  benchmarks/examples/localhost.yaml  2
-rw-r--r--  benchmarks/harness/BUILD  202
-rw-r--r--  benchmarks/harness/__init__.py  62
-rw-r--r--  benchmarks/harness/benchmark_driver.py  85
-rw-r--r--  benchmarks/harness/container.py  181
-rw-r--r--  benchmarks/harness/machine.py  265
-rw-r--r--  benchmarks/harness/machine_mocks/BUILD  9
-rw-r--r--  benchmarks/harness/machine_mocks/__init__.py  81
-rw-r--r--  benchmarks/harness/machine_producers/BUILD  84
-rw-r--r--  benchmarks/harness/machine_producers/__init__.py  13
-rw-r--r--  benchmarks/harness/machine_producers/gcloud_mock_recorder.py  97
-rw-r--r--  benchmarks/harness/machine_producers/gcloud_producer.py  250
-rw-r--r--  benchmarks/harness/machine_producers/gcloud_producer_test.py  48
-rw-r--r--  benchmarks/harness/machine_producers/machine_producer.py  51
-rw-r--r--  benchmarks/harness/machine_producers/mock_producer.py  52
-rw-r--r--  benchmarks/harness/machine_producers/testdata/get_five.json  211
-rw-r--r--  benchmarks/harness/machine_producers/testdata/get_one.json  145
-rw-r--r--  benchmarks/harness/machine_producers/yaml_producer.py  106
-rw-r--r--  benchmarks/harness/ssh_connection.py  126
-rw-r--r--  benchmarks/harness/tunnel_dispatcher.py  122
-rw-r--r--  benchmarks/requirements.txt  32
-rw-r--r--  benchmarks/run.py  19
-rw-r--r--  benchmarks/runner/BUILD  56
-rw-r--r--  benchmarks/runner/__init__.py  308
-rw-r--r--  benchmarks/runner/commands.py  135
-rw-r--r--  benchmarks/runner/runner_test.py  59
-rw-r--r--  benchmarks/suites/BUILD  130
-rw-r--r--  benchmarks/suites/__init__.py  119
-rw-r--r--  benchmarks/suites/absl.py  37
-rw-r--r--  benchmarks/suites/density.py  121
-rw-r--r--  benchmarks/suites/fio.py  165
-rw-r--r--  benchmarks/suites/helpers.py  57
-rw-r--r--  benchmarks/suites/http.py  138
-rw-r--r--  benchmarks/suites/media.py  42
-rw-r--r--  benchmarks/suites/ml.py  33
-rw-r--r--  benchmarks/suites/network.py  101
-rw-r--r--  benchmarks/suites/redis.py  46
-rw-r--r--  benchmarks/suites/startup.py  110
-rw-r--r--  benchmarks/suites/sysbench.py  119
-rw-r--r--  benchmarks/suites/syscall.py  37
-rw-r--r--  benchmarks/tcp/BUILD  41
-rw-r--r--  benchmarks/tcp/README.md  87
-rw-r--r--  benchmarks/tcp/nsjoin.c  47
-rwxr-xr-x  benchmarks/tcp/tcp_benchmark.sh  392
-rw-r--r--  benchmarks/tcp/tcp_proxy.go  451
-rw-r--r--  benchmarks/workloads/BUILD  35
-rw-r--r--  benchmarks/workloads/__init__.py  14
-rw-r--r--  benchmarks/workloads/ab/BUILD  28
-rw-r--r--  benchmarks/workloads/ab/Dockerfile  15
-rw-r--r--  benchmarks/workloads/ab/__init__.py  88
-rw-r--r--  benchmarks/workloads/ab/ab_test.py  42
-rw-r--r--  benchmarks/workloads/absl/BUILD  28
-rw-r--r--  benchmarks/workloads/absl/Dockerfile  25
-rw-r--r--  benchmarks/workloads/absl/__init__.py  63
-rw-r--r--  benchmarks/workloads/absl/absl_test.py  31
-rw-r--r--  benchmarks/workloads/curl/BUILD  13
-rw-r--r--  benchmarks/workloads/curl/Dockerfile  14
-rw-r--r--  benchmarks/workloads/ffmpeg/BUILD  18
-rw-r--r--  benchmarks/workloads/ffmpeg/Dockerfile  10
-rw-r--r--  benchmarks/workloads/ffmpeg/__init__.py  20
-rw-r--r--  benchmarks/workloads/fio/BUILD  28
-rw-r--r--  benchmarks/workloads/fio/Dockerfile  23
-rw-r--r--  benchmarks/workloads/fio/__init__.py  369
-rw-r--r--  benchmarks/workloads/fio/fio_test.py  44
-rw-r--r--  benchmarks/workloads/httpd/BUILD  14
-rw-r--r--  benchmarks/workloads/httpd/Dockerfile  27
-rw-r--r--  benchmarks/workloads/httpd/apache2-tmpdir.conf  5
-rw-r--r--  benchmarks/workloads/iperf/BUILD  28
-rw-r--r--  benchmarks/workloads/iperf/Dockerfile  14
-rw-r--r--  benchmarks/workloads/iperf/__init__.py  40
-rw-r--r--  benchmarks/workloads/iperf/iperf_test.py  28
-rw-r--r--  benchmarks/workloads/netcat/BUILD  13
-rw-r--r--  benchmarks/workloads/netcat/Dockerfile  14
-rw-r--r--  benchmarks/workloads/nginx/BUILD  13
-rw-r--r--  benchmarks/workloads/nginx/Dockerfile  1
-rw-r--r--  benchmarks/workloads/node/BUILD  15
-rw-r--r--  benchmarks/workloads/node/Dockerfile  2
-rw-r--r--  benchmarks/workloads/node/index.js  28
-rw-r--r--  benchmarks/workloads/node/package.json  19
-rw-r--r--  benchmarks/workloads/node_template/BUILD  17
-rw-r--r--  benchmarks/workloads/node_template/Dockerfile  5
-rw-r--r--  benchmarks/workloads/node_template/index.hbs  8
-rw-r--r--  benchmarks/workloads/node_template/index.js  43
-rw-r--r--  benchmarks/workloads/node_template/package-lock.json  486
-rw-r--r--  benchmarks/workloads/node_template/package.json  19
-rw-r--r--  benchmarks/workloads/redis/BUILD  13
-rw-r--r--  benchmarks/workloads/redis/Dockerfile  1
-rw-r--r--  benchmarks/workloads/redisbenchmark/BUILD  28
-rw-r--r--  benchmarks/workloads/redisbenchmark/Dockerfile  4
-rw-r--r--  benchmarks/workloads/redisbenchmark/__init__.py  85
-rw-r--r--  benchmarks/workloads/redisbenchmark/redisbenchmark_test.py  51
-rw-r--r--  benchmarks/workloads/ruby/BUILD  28
-rw-r--r--  benchmarks/workloads/ruby/Dockerfile  28
-rw-r--r--  benchmarks/workloads/ruby/Gemfile  12
-rw-r--r--  benchmarks/workloads/ruby/Gemfile.lock  71
-rwxr-xr-x  benchmarks/workloads/ruby/config.ru  2
-rwxr-xr-x  benchmarks/workloads/ruby/index.rb  14
-rw-r--r--  benchmarks/workloads/ruby_template/BUILD  18
-rwxr-xr-x  benchmarks/workloads/ruby_template/Dockerfile  38
-rwxr-xr-x  benchmarks/workloads/ruby_template/Gemfile  5
-rw-r--r--  benchmarks/workloads/ruby_template/Gemfile.lock  26
-rwxr-xr-x  benchmarks/workloads/ruby_template/config.ru  2
-rwxr-xr-x  benchmarks/workloads/ruby_template/index.erb  8
-rwxr-xr-x  benchmarks/workloads/ruby_template/main.rb  27
-rw-r--r--  benchmarks/workloads/sleep/BUILD  13
-rw-r--r--  benchmarks/workloads/sleep/Dockerfile  3
-rw-r--r--  benchmarks/workloads/sysbench/BUILD  28
-rw-r--r--  benchmarks/workloads/sysbench/Dockerfile  16
-rw-r--r--  benchmarks/workloads/sysbench/__init__.py  167
-rw-r--r--  benchmarks/workloads/sysbench/sysbench_test.py  34
-rw-r--r--  benchmarks/workloads/syscall/BUILD  29
-rw-r--r--  benchmarks/workloads/syscall/Dockerfile  6
-rw-r--r--  benchmarks/workloads/syscall/__init__.py  29
-rw-r--r--  benchmarks/workloads/syscall/syscall.c  55
-rw-r--r--  benchmarks/workloads/syscall/syscall_test.py  27
-rw-r--r--  benchmarks/workloads/tensorflow/BUILD  18
-rw-r--r--  benchmarks/workloads/tensorflow/Dockerfile  14
-rw-r--r--  benchmarks/workloads/tensorflow/__init__.py  20
-rw-r--r--  benchmarks/workloads/true/BUILD  14
-rw-r--r--  benchmarks/workloads/true/Dockerfile  3
-rw-r--r--  g3doc/BUILD  44
-rw-r--r--  g3doc/Layers.png  bin 0 -> 11044 bytes
-rw-r--r--  g3doc/Layers.svg  1
-rw-r--r--  g3doc/Machine-Virtualization.png  bin 0 -> 13205 bytes
-rw-r--r--  g3doc/Machine-Virtualization.svg  1
-rw-r--r--  g3doc/README.md  168
-rw-r--r--  g3doc/Rule-Based-Execution.png  bin 0 -> 6780 bytes
-rw-r--r--  g3doc/Rule-Based-Execution.svg  1
-rw-r--r--  g3doc/Sentry-Gofer.png  bin 0 -> 9064 bytes
-rw-r--r--  g3doc/Sentry-Gofer.svg  1
-rw-r--r--  g3doc/architecture_guide/BUILD  50
-rw-r--r--  g3doc/architecture_guide/performance.md  277
-rw-r--r--  g3doc/architecture_guide/platforms.md  61
-rw-r--r--  g3doc/architecture_guide/platforms.png  bin 0 -> 21384 bytes
-rw-r--r--  g3doc/architecture_guide/platforms.svg  334
-rw-r--r--  g3doc/architecture_guide/resources.md  144
-rw-r--r--  g3doc/architecture_guide/resources.png  bin 0 -> 16621 bytes
-rw-r--r--  g3doc/architecture_guide/resources.svg  208
-rw-r--r--  g3doc/architecture_guide/security.md  255
-rw-r--r--  g3doc/architecture_guide/security.png  bin 0 -> 16932 bytes
-rw-r--r--  g3doc/architecture_guide/security.svg  153
-rw-r--r--  g3doc/community.md  31
-rw-r--r--  g3doc/logo.png  bin 0 -> 27719 bytes
-rw-r--r--  g3doc/logo.txt  1
-rw-r--r--  g3doc/roadmap.md  49
-rw-r--r--  g3doc/style.md  88
-rw-r--r--  g3doc/user_guide/BUILD  70
-rw-r--r--  g3doc/user_guide/FAQ.md  122
-rw-r--r--  g3doc/user_guide/checkpoint_restore.md  101
-rw-r--r--  g3doc/user_guide/compatibility.md  93
-rw-r--r--  g3doc/user_guide/debugging.md  141
-rw-r--r--  g3doc/user_guide/filesystem.md  60
-rw-r--r--  g3doc/user_guide/install.md  157
-rw-r--r--  g3doc/user_guide/networking.md  85
-rw-r--r--  g3doc/user_guide/platforms.md  95
-rw-r--r--  g3doc/user_guide/quick_start/BUILD  33
-rw-r--r--  g3doc/user_guide/quick_start/docker.md  96
-rw-r--r--  g3doc/user_guide/quick_start/kubernetes.md  36
-rw-r--r--  g3doc/user_guide/quick_start/oci.md  43
-rw-r--r--  g3doc/user_guide/tutorials/BUILD  37
-rw-r--r--  g3doc/user_guide/tutorials/add-node-pool.png  bin 0 -> 70208 bytes
-rw-r--r--  g3doc/user_guide/tutorials/cni.md  174
-rw-r--r--  g3doc/user_guide/tutorials/docker.md  68
-rw-r--r--  g3doc/user_guide/tutorials/kubernetes.md  134
-rw-r--r--  g3doc/user_guide/tutorials/node-pool-button.png  bin 0 -> 13757 bytes
-rw-r--r--  go.mod  20
-rw-r--r--  go.sum  32
-rw-r--r--  images/BUILD  11
-rw-r--r--  images/Makefile  93
-rw-r--r--  images/README.md  61
-rw-r--r--  images/basic/alpine/Dockerfile  1
-rw-r--r--  images/basic/busybox/Dockerfile  1
-rw-r--r--  images/basic/httpd/Dockerfile  1
-rw-r--r--  images/basic/mysql/Dockerfile  1
-rw-r--r--  images/basic/nginx/Dockerfile  1
-rw-r--r--  images/basic/python/Dockerfile  2
-rw-r--r--  images/basic/resolv/Dockerfile  1
-rw-r--r--  images/basic/ruby/Dockerfile  1
-rw-r--r--  images/basic/tomcat/Dockerfile  1
-rw-r--r--  images/basic/ubuntu/Dockerfile  1
-rw-r--r--  images/default/Dockerfile  16
-rw-r--r--  images/hostoverlaytest/Dockerfile  7
-rw-r--r--  images/hostoverlaytest/test.c  88
-rw-r--r--  images/hostoverlaytest/testfile.txt  1
-rw-r--r--  images/iptables/Dockerfile  2
-rw-r--r--  images/jekyll/Dockerfile  13
-rw-r--r--  images/packetdrill/Dockerfile  8
-rw-r--r--  images/packetimpact/Dockerfile  16
-rw-r--r--  images/runtimes/go1.12/Dockerfile  4
-rw-r--r--  images/runtimes/java11/Dockerfile  22
-rw-r--r--  images/runtimes/nodejs12.4.0/Dockerfile  21
-rw-r--r--  images/runtimes/php7.3.6/Dockerfile  19
-rw-r--r--  images/runtimes/python3.7.3/Dockerfile  21
-rw-r--r--  images/tmpfile/Dockerfile  4
-rw-r--r--  pkg/abi/BUILD  13
-rw-r--r--  pkg/abi/abi.go  45
-rw-r--r--  pkg/abi/abi_linux.go  20
-rw-r--r--  pkg/abi/flag.go  85
-rw-r--r--  pkg/abi/linux/BUILD  86
-rw-r--r--  pkg/abi/linux/aio.go  76
-rw-r--r--  pkg/abi/linux/arch_amd64.go  23
-rw-r--r--  pkg/abi/linux/audit.go  23
-rw-r--r--  pkg/abi/linux/bpf.go  34
-rw-r--r--  pkg/abi/linux/capability.go  190
-rw-r--r--  pkg/abi/linux/clone.go  41
-rw-r--r--  pkg/abi/linux/dev.go  66
-rw-r--r--  pkg/abi/linux/elf.go  108
-rw-r--r--  pkg/abi/linux/epoll.go  62
-rw-r--r--  pkg/abi/linux/epoll_amd64.go  29
-rw-r--r--  pkg/abi/linux/epoll_arm64.go  28
-rw-r--r--  pkg/abi/linux/errors.go  172
-rw-r--r--  pkg/abi/linux/eventfd.go  22
-rw-r--r--  pkg/abi/linux/exec.go  18
-rw-r--r--  pkg/abi/linux/fadvise.go  24
-rw-r--r--  pkg/abi/linux/fcntl.go  69
-rw-r--r--  pkg/abi/linux/file.go  384
-rw-r--r--  pkg/abi/linux/file_amd64.go  46
-rw-r--r--  pkg/abi/linux/file_arm64.go  47
-rw-r--r--  pkg/abi/linux/fs.go  103
-rw-r--r--  pkg/abi/linux/futex.go  62
-rw-r--r--  pkg/abi/linux/inotify.go  97
-rw-r--r--  pkg/abi/linux/ioctl.go  100
-rw-r--r--  pkg/abi/linux/ioctl_tun.go  29
-rw-r--r--  pkg/abi/linux/ip.go  161
-rw-r--r--  pkg/abi/linux/ipc.go  53
-rw-r--r--  pkg/abi/linux/limits.go  88
-rw-r--r--  pkg/abi/linux/linux.go  39
-rw-r--r--  pkg/abi/linux/mm.go  130
-rw-r--r--  pkg/abi/linux/netdevice.go  86
-rw-r--r--  pkg/abi/linux/netfilter.go  552
-rw-r--r--  pkg/abi/linux/netfilter_test.go  46
-rw-r--r--  pkg/abi/linux/netlink.go  130
-rw-r--r--  pkg/abi/linux/netlink_route.go  346
-rw-r--r--  pkg/abi/linux/poll.go  42
-rw-r--r--  pkg/abi/linux/prctl.go  164
-rw-r--r--  pkg/abi/linux/ptrace.go  89
-rw-r--r--  pkg/abi/linux/ptrace_amd64.go  52
-rw-r--r--  pkg/abi/linux/ptrace_arm64.go  29
-rw-r--r--  pkg/abi/linux/rseq.go  130
-rw-r--r--  pkg/abi/linux/rusage.go  46
-rw-r--r--  pkg/abi/linux/sched.go  36
-rw-r--r--  pkg/abi/linux/seccomp.go  72
-rw-r--r--  pkg/abi/linux/sem.go  52
-rw-r--r--  pkg/abi/linux/shm.go  86
-rw-r--r--  pkg/abi/linux/signal.go  234
-rw-r--r--  pkg/abi/linux/signalfd.go  45
-rw-r--r--  pkg/abi/linux/socket.go  456
-rw-r--r--  pkg/abi/linux/splice.go  23
-rw-r--r--  pkg/abi/linux/tcp.go  61
-rw-r--r--  pkg/abi/linux/time.go  270
-rw-r--r--  pkg/abi/linux/timer.go  23
-rw-r--r--  pkg/abi/linux/tty.go  344
-rw-r--r--  pkg/abi/linux/uio.go  18
-rw-r--r--  pkg/abi/linux/utsname.go  49
-rw-r--r--  pkg/abi/linux/wait.go  36
-rw-r--r--  pkg/abi/linux/xattr.go  28
-rw-r--r--  pkg/amutex/BUILD  18
-rw-r--r--  pkg/amutex/amutex.go  137
-rw-r--r--  pkg/amutex/amutex_test.go  98
-rw-r--r--  pkg/atomicbitops/BUILD  22
-rw-r--r--  pkg/atomicbitops/atomicbitops.go  47
-rw-r--r--  pkg/atomicbitops/atomicbitops_amd64.s  77
-rw-r--r--  pkg/atomicbitops/atomicbitops_arm64.s  105
-rw-r--r--  pkg/atomicbitops/atomicbitops_noasm.go  105
-rw-r--r--  pkg/atomicbitops/atomicbitops_test.go  198
-rw-r--r--  pkg/binary/BUILD  16
-rw-r--r--  pkg/binary/binary.go  266
-rw-r--r--  pkg/binary/binary_test.go  266
-rw-r--r--  pkg/bits/BUILD  55
-rw-r--r--  pkg/bits/bits.go  16
-rw-r--r--  pkg/bits/bits_template.go  52
-rw-r--r--  pkg/bits/uint64_arch.go  36
-rw-r--r--  pkg/bits/uint64_arch_amd64_asm.s  31
-rw-r--r--  pkg/bits/uint64_arch_arm64_asm.s  33
-rw-r--r--  pkg/bits/uint64_arch_generic.go  55
-rw-r--r--  pkg/bits/uint64_test.go  134
-rw-r--r--  pkg/bpf/BUILD  31
-rw-r--r--  pkg/bpf/bpf.go  129
-rw-r--r--  pkg/bpf/decoder.go  245
-rw-r--r--  pkg/bpf/decoder_test.go  146
-rw-r--r--  pkg/bpf/input_bytes.go  58
-rw-r--r--  pkg/bpf/interpreter.go  412
-rw-r--r--  pkg/bpf/interpreter_test.go  797
-rw-r--r--  pkg/bpf/program_builder.go  191
-rw-r--r--  pkg/bpf/program_builder_test.go  157
-rw-r--r--  pkg/buffer/BUILD  43
-rw-r--r--  pkg/buffer/buffer.go  94
-rw-r--r--  pkg/buffer/safemem.go  133
-rw-r--r--  pkg/buffer/safemem_test.go  170
-rw-r--r--  pkg/buffer/view.go  390
-rw-r--r--  pkg/buffer/view_test.go  467
-rw-r--r--  pkg/buffer/view_unsafe.go  25
-rw-r--r--  pkg/cleanup/BUILD  17
-rw-r--r--  pkg/cleanup/cleanup.go  60
-rw-r--r--  pkg/cleanup/cleanup_test.go  66
-rw-r--r--  pkg/compressio/BUILD  20
-rw-r--r--  pkg/compressio/compressio.go  773
-rw-r--r--  pkg/compressio/compressio_test.go  290
-rw-r--r--  pkg/context/BUILD  13
-rw-r--r--  pkg/context/context.go  137
-rw-r--r--  pkg/control/client/BUILD  15
-rw-r--r--  pkg/control/client/client.go  33
-rw-r--r--  pkg/control/server/BUILD  15
-rw-r--r--  pkg/control/server/server.go  160
-rw-r--r--  pkg/cpuid/BUILD  35
-rw-r--r--  pkg/cpuid/cpu_amd64.s  24
-rw-r--r--  pkg/cpuid/cpuid.go  38
-rw-r--r--  pkg/cpuid/cpuid_arm64.go  483
-rw-r--r--  pkg/cpuid/cpuid_arm64_test.go  55
-rw-r--r--  pkg/cpuid/cpuid_parse_x86_test.go  144
-rw-r--r--  pkg/cpuid/cpuid_x86.go  1111
-rw-r--r--  pkg/cpuid/cpuid_x86_test.go  243
-rw-r--r--  pkg/eventchannel/BUILD  37
-rw-r--r--  pkg/eventchannel/event.go  201
-rw-r--r--  pkg/eventchannel/event.proto  27
-rw-r--r--  pkg/eventchannel/event_test.go  146
-rw-r--r--  pkg/eventchannel/rate.go  54
-rw-r--r--  pkg/fd/BUILD  16
-rw-r--r--  pkg/fd/fd.go  234
-rw-r--r--  pkg/fd/fd_test.go  136
-rw-r--r--  pkg/fdchannel/BUILD  17
-rw-r--r--  pkg/fdchannel/fdchannel_test.go  132
-rw-r--r--  pkg/fdchannel/fdchannel_unsafe.go  146
-rw-r--r--  pkg/fdnotifier/BUILD  17
-rw-r--r--  pkg/fdnotifier/fdnotifier.go  203
-rw-r--r--  pkg/fdnotifier/poll_unsafe.go  82
-rw-r--r--  pkg/flipcall/BUILD  34
-rw-r--r--  pkg/flipcall/ctrl_futex.go  176
-rw-r--r--  pkg/flipcall/flipcall.go  257
-rw-r--r--  pkg/flipcall/flipcall_example_test.go  113
-rw-r--r--  pkg/flipcall/flipcall_test.go  405
-rw-r--r--  pkg/flipcall/flipcall_unsafe.go  87
-rw-r--r--  pkg/flipcall/futex_linux.go  118
-rw-r--r--  pkg/flipcall/io.go  113
-rw-r--r--  pkg/flipcall/packet_window_allocator.go  166
-rw-r--r--  pkg/flipcall/packet_window_mmap.go  25
-rw-r--r--  pkg/fspath/BUILD  26
-rw-r--r--  pkg/fspath/builder.go  112
-rw-r--r--  pkg/fspath/builder_test.go  58
-rw-r--r--  pkg/fspath/fspath.go  187
-rw-r--r--  pkg/fspath/fspath_test.go  134
-rw-r--r--  pkg/gate/BUILD  22
-rw-r--r--  pkg/gate/gate.go  134
-rw-r--r--  pkg/gate/gate_test.go  192
-rw-r--r--  pkg/gohacks/BUILD  12
-rw-r--r--  pkg/gohacks/gohacks_unsafe.go  57
-rw-r--r--  pkg/goid/BUILD  25
-rw-r--r--  pkg/goid/empty_test.go  22
-rw-r--r--  pkg/goid/goid.go  24
-rw-r--r--  pkg/goid/goid_amd64.s  21
-rw-r--r--  pkg/goid/goid_arm64.s  21
-rw-r--r--  pkg/goid/goid_race.go  25
-rw-r--r--  pkg/goid/goid_test.go  74
-rw-r--r--  pkg/goid/goid_unsafe.go  64
-rw-r--r--  pkg/ilist/BUILD  56
-rw-r--r--  pkg/ilist/list.go  227
-rw-r--r--  pkg/ilist/list_test.go  240
-rw-r--r--  pkg/linewriter/BUILD  18
-rw-r--r--  pkg/linewriter/linewriter.go  79
-rw-r--r--  pkg/linewriter/linewriter_test.go  81
-rw-r--r--  pkg/log/BUILD  32
-rw-r--r--  pkg/log/glog.go  85
-rw-r--r--  pkg/log/json.go  76
-rw-r--r--  pkg/log/json_k8s.go  47
-rw-r--r--  pkg/log/json_test.go  64
-rw-r--r--  pkg/log/log.go  378
-rw-r--r--  pkg/log/log_test.go  105
-rw-r--r--  pkg/memutil/BUILD  10
-rw-r--r--  pkg/memutil/memutil_unsafe.go  42
-rw-r--r--  pkg/merkletree/BUILD  16
-rw-r--r--  pkg/merkletree/merkletree.go  135
-rw-r--r--  pkg/merkletree/merkletree_test.go  122
-rw-r--r--  pkg/metric/BUILD  32
-rw-r--r--  pkg/metric/metric.go  250
-rw-r--r--  pkg/metric/metric.proto  76
-rw-r--r--  pkg/metric/metric_test.go  258
-rw-r--r--  pkg/p9/BUILD  52
-rw-r--r--  pkg/p9/buffer.go  263
-rw-r--r--  pkg/p9/buffer_test.go  31
-rw-r--r--  pkg/p9/client.go  575
-rw-r--r--  pkg/p9/client_file.go  686
-rw-r--r--  pkg/p9/client_test.go  109
-rw-r--r--  pkg/p9/file.go  288
-rw-r--r--  pkg/p9/handlers.go  1393
-rw-r--r--  pkg/p9/messages.go  2662
-rw-r--r--  pkg/p9/messages_test.go  483
-rw-r--r--  pkg/p9/p9.go  1171
-rw-r--r--  pkg/p9/p9_test.go  188
-rw-r--r--  pkg/p9/p9test/BUILD  88
-rw-r--r--  pkg/p9/p9test/client_test.go  2242
-rw-r--r--  pkg/p9/p9test/p9test.go  329
-rw-r--r--  pkg/p9/path_tree.go  222
-rw-r--r--  pkg/p9/server.go  694
-rw-r--r--  pkg/p9/transport.go  345
-rw-r--r--  pkg/p9/transport_flipcall.go  243
-rw-r--r--  pkg/p9/transport_test.go  231
-rw-r--r--  pkg/p9/version.go  175
-rw-r--r--  pkg/p9/version_test.go  145
-rw-r--r--  pkg/pool/BUILD  25
-rw-r--r--  pkg/pool/pool.go  66
-rw-r--r--  pkg/pool/pool_test.go  64
-rw-r--r--  pkg/procid/BUILD  34
-rw-r--r--  pkg/procid/procid.go  21
-rw-r--r--  pkg/procid/procid_amd64.s  30
-rw-r--r--  pkg/procid/procid_arm64.s  29
-rw-r--r--  pkg/procid/procid_net_test.go  21
-rw-r--r--  pkg/procid/procid_test.go  86
-rw-r--r--  pkg/rand/BUILD  16
-rw-r--r--  pkg/rand/rand.go  29
-rw-r--r--  pkg/rand/rand_linux.go  77
-rw-r--r--  pkg/refs/BUILD  38
-rw-r--r--  pkg/refs/refcounter.go  469
-rw-r--r--  pkg/refs/refcounter_state.go  35
-rw-r--r--  pkg/refs/refcounter_test.go  173
-rw-r--r--  pkg/safecopy/BUILD  29
-rw-r--r--  pkg/safecopy/LICENSE  27
-rw-r--r--  pkg/safecopy/atomic_amd64.s  136
-rw-r--r--  pkg/safecopy/atomic_arm64.s  126
-rw-r--r--  pkg/safecopy/memclr_amd64.s  147
-rw-r--r--  pkg/safecopy/memclr_arm64.s  74
-rw-r--r--  pkg/safecopy/memcpy_amd64.s  219
-rw-r--r--  pkg/safecopy/memcpy_arm64.s  78
-rw-r--r--  pkg/safecopy/safecopy.go  144
-rw-r--r--  pkg/safecopy/safecopy_test.go  629
-rw-r--r--  pkg/safecopy/safecopy_unsafe.go  361
-rw-r--r--  pkg/safecopy/sighandler_amd64.s  133
-rw-r--r--  pkg/safecopy/sighandler_arm64.s  143
-rw-r--r--  pkg/safemem/BUILD  27
-rw-r--r--  pkg/safemem/block_unsafe.go  279
-rw-r--r--  pkg/safemem/io.go  392
-rw-r--r--  pkg/safemem/io_test.go  199
-rw-r--r--  pkg/safemem/safemem.go  16
-rw-r--r--  pkg/safemem/seq_test.go  217
-rw-r--r--  pkg/safemem/seq_unsafe.go  319
-rw-r--r--  pkg/seccomp/BUILD  50
-rw-r--r--  pkg/seccomp/seccomp.go  404
-rw-r--r--  pkg/seccomp/seccomp_amd64.go  26
-rw-r--r--  pkg/seccomp/seccomp_arm64.go  26
-rw-r--r--  pkg/seccomp/seccomp_rules.go  139
-rw-r--r--  pkg/seccomp/seccomp_test.go  580
-rw-r--r--  pkg/seccomp/seccomp_test_victim.go  117
-rw-r--r--  pkg/seccomp/seccomp_unsafe.go  63
-rw-r--r--  pkg/secio/BUILD  19
-rw-r--r--  pkg/secio/full_reader.go  34
-rw-r--r--  pkg/secio/secio.go  105
-rw-r--r--  pkg/secio/secio_test.go  126
-rw-r--r--  pkg/segment/BUILD  33
-rw-r--r--  pkg/segment/range.go  79
-rw-r--r--  pkg/segment/set.go  1754
-rw-r--r--  pkg/segment/set_state.go  25
-rw-r--r--  pkg/segment/test/BUILD  68
-rw-r--r--  pkg/segment/test/segment_test.go  865
-rw-r--r--  pkg/segment/test/set_functions.go  54
-rw-r--r--  pkg/sentry/BUILD  14
-rw-r--r--  pkg/sentry/arch/BUILD  48
-rw-r--r--  pkg/sentry/arch/aligned.go  31
-rw-r--r--  pkg/sentry/arch/arch.go  366
-rw-r--r--  pkg/sentry/arch/arch_aarch64.go  321
-rw-r--r--  pkg/sentry/arch/arch_amd64.go  328
-rw-r--r--  pkg/sentry/arch/arch_amd64.s  136
-rw-r--r--  pkg/sentry/arch/arch_arm64.go  284
-rw-r--r--  pkg/sentry/arch/arch_state_x86.go  91
-rw-r--r--  pkg/sentry/arch/arch_x86.go  615
-rw-r--r--  pkg/sentry/arch/arch_x86_impl.go  41
-rw-r--r--  pkg/sentry/arch/auxv.go  30
-rw-r--r--  pkg/sentry/arch/registers.proto  92
-rw-r--r--  pkg/sentry/arch/signal.go  253
-rw-r--r--  pkg/sentry/arch/signal_act.go  83
-rw-r--r--  pkg/sentry/arch/signal_amd64.go  291
-rw-r--r--  pkg/sentry/arch/signal_arm64.go  181
-rw-r--r--  pkg/sentry/arch/signal_info.go  66
-rw-r--r--  pkg/sentry/arch/signal_stack.go  68
-rw-r--r--  pkg/sentry/arch/stack.go  249
-rw-r--r--  pkg/sentry/arch/syscalls_amd64.go  59
-rw-r--r--  pkg/sentry/arch/syscalls_arm64.go  81
-rw-r--r--  pkg/sentry/contexttest/BUILD  21
-rw-r--r--  pkg/sentry/contexttest/contexttest.go  188
-rw-r--r--  pkg/sentry/control/BUILD  52
-rw-r--r--  pkg/sentry/control/control.go  17
-rw-r--r--  pkg/sentry/control/logging.go  136
-rw-r--r--  pkg/sentry/control/pprof.go  209
-rw-r--r--  pkg/sentry/control/proc.go  416
-rw-r--r--  pkg/sentry/control/proc_test.go  166
-rw-r--r--  pkg/sentry/control/state.go  73
-rw-r--r--  pkg/sentry/device/BUILD  20
-rw-r--r--  pkg/sentry/device/device.go  269
-rw-r--r--  pkg/sentry/device/device_test.go  59
-rw-r--r--  pkg/sentry/devices/memdev/BUILD  28
-rw-r--r--  pkg/sentry/devices/memdev/full.go  76
-rw-r--r--  pkg/sentry/devices/memdev/memdev.go  59
-rw-r--r--  pkg/sentry/devices/memdev/null.go  77
-rw-r--r--  pkg/sentry/devices/memdev/random.go  93
-rw-r--r--  pkg/sentry/devices/memdev/zero.go  89
-rw-r--r--  pkg/sentry/devices/ttydev/BUILD  16
-rw-r--r--  pkg/sentry/devices/ttydev/ttydev.go  91
-rw-r--r--  pkg/sentry/devices/tundev/BUILD  23
-rw-r--r--  pkg/sentry/devices/tundev/tundev.go  178
-rw-r--r--  pkg/sentry/fdimport/BUILD  19
-rw-r--r--  pkg/sentry/fdimport/fdimport.go  129
-rw-r--r--  pkg/sentry/fs/BUILD  135
-rw-r--r--  pkg/sentry/fs/README.md  229
-rw-r--r--  pkg/sentry/fs/anon/BUILD  20
-rw-r--r--  pkg/sentry/fs/anon/anon.go  42
-rw-r--r--  pkg/sentry/fs/anon/device.go  22
-rw-r--r--  pkg/sentry/fs/attr.go  493
-rw-r--r--  pkg/sentry/fs/context.go  138
-rw-r--r--  pkg/sentry/fs/copy_up.go  436
-rw-r--r--  pkg/sentry/fs/copy_up_test.go  183
-rw-r--r--  pkg/sentry/fs/dentry.go  234
-rw-r--r--  pkg/sentry/fs/dev/BUILD  40
-rw-r--r--  pkg/sentry/fs/dev/dev.go  151
-rw-r--r--  pkg/sentry/fs/dev/device.go  20
-rw-r--r--  pkg/sentry/fs/dev/fs.go  64
-rw-r--r--  pkg/sentry/fs/dev/full.go  81
-rw-r--r--  pkg/sentry/fs/dev/net_tun.go  177
-rw-r--r--  pkg/sentry/fs/dev/null.go  131
-rw-r--r--  pkg/sentry/fs/dev/random.go  79
-rw-r--r--  pkg/sentry/fs/dev/tty.go  67
-rw-r--r--  pkg/sentry/fs/dirent.go  1558
-rw-r--r--  pkg/sentry/fs/dirent_cache.go  174
-rw-r--r--  pkg/sentry/fs/dirent_cache_limiter.go  56
-rw-r--r--  pkg/sentry/fs/dirent_cache_test.go  247
-rw-r--r--  pkg/sentry/fs/dirent_refs_test.go  418
-rw-r--r--  pkg/sentry/fs/dirent_state.go  77
-rw-r--r--  pkg/sentry/fs/fdpipe/BUILD  48
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe.go  168
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe_opener.go  193
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe_opener_test.go  523
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe_state.go  89
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe_test.go  505
-rw-r--r--  pkg/sentry/fs/file.go  593
-rw-r--r--  pkg/sentry/fs/file_operations.go  175
-rw-r--r--  pkg/sentry/fs/file_overlay.go  556
-rw-r--r--  pkg/sentry/fs/file_overlay_test.go  192
-rw-r--r--  pkg/sentry/fs/file_state.go  31
-rw-r--r--  pkg/sentry/fs/filesystems.go  160
-rw-r--r--  pkg/sentry/fs/filetest/BUILD  19
-rw-r--r--  pkg/sentry/fs/filetest/filetest.go  61
-rw-r--r--  pkg/sentry/fs/flags.go  138
-rw-r--r--  pkg/sentry/fs/fs.go  161
-rw-r--r--  pkg/sentry/fs/fsutil/BUILD  118
-rw-r--r--  pkg/sentry/fs/fsutil/README.md  207
-rw-r--r--  pkg/sentry/fs/fsutil/dirty_set.go  237
-rw-r--r--  pkg/sentry/fs/fsutil/dirty_set_test.go  38
-rw-r--r--  pkg/sentry/fs/fsutil/file.go  396
-rw-r--r--  pkg/sentry/fs/fsutil/file_range_set.go  209
-rw-r--r--  pkg/sentry/fs/fsutil/frame_ref_set.go  91
-rw-r--r--  pkg/sentry/fs/fsutil/fsutil.go  24
-rw-r--r--  pkg/sentry/fs/fsutil/host_file_mapper.go  242
-rw-r--r--  pkg/sentry/fs/fsutil/host_file_mapper_state.go  20
-rw-r--r--  pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go  27
-rw-r--r--  pkg/sentry/fs/fsutil/host_mappable.go  214
-rw-r--r--  pkg/sentry/fs/fsutil/inode.go  531
-rw-r--r--  pkg/sentry/fs/fsutil/inode_cached.go  1061
-rw-r--r--  pkg/sentry/fs/fsutil/inode_cached_test.go  389
-rw-r--r--  pkg/sentry/fs/g3doc/.gitignore  1
-rw-r--r--  pkg/sentry/fs/g3doc/fuse.md  263
-rw-r--r--  pkg/sentry/fs/g3doc/inotify.md  122
-rw-r--r--  pkg/sentry/fs/gofer/BUILD  67
-rw-r--r--  pkg/sentry/fs/gofer/attr.go  172
-rw-r--r--  pkg/sentry/fs/gofer/cache_policy.go  186
-rw-r--r--  pkg/sentry/fs/gofer/context_file.go  218
-rw-r--r--  pkg/sentry/fs/gofer/device.go  20
-rw-r--r--  pkg/sentry/fs/gofer/fifo.go  40
-rw-r--r--  pkg/sentry/fs/gofer/file.go  369
-rw-r--r--  pkg/sentry/fs/gofer/file_state.go  44
-rw-r--r--  pkg/sentry/fs/gofer/fs.go  267
-rw-r--r--  pkg/sentry/fs/gofer/gofer_test.go  310
-rw-r--r--  pkg/sentry/fs/gofer/handles.go  140
-rw-r--r--  pkg/sentry/fs/gofer/inode.go  719
-rw-r--r--  pkg/sentry/fs/gofer/inode_state.go  171
-rw-r--r--  pkg/sentry/fs/gofer/path.go  495
-rw-r--r--  pkg/sentry/fs/gofer/session.go  426
-rw-r--r--  pkg/sentry/fs/gofer/session_state.go  113
-rw-r--r--  pkg/sentry/fs/gofer/socket.go  152
-rw-r--r--  pkg/sentry/fs/gofer/util.go  72
-rw-r--r--  pkg/sentry/fs/host/BUILD  82
-rw-r--r--  pkg/sentry/fs/host/control.go  97
-rw-r--r--  pkg/sentry/fs/host/descriptor.go  99
-rw-r--r--  pkg/sentry/fs/host/descriptor_state.go  29
-rw-r--r--  pkg/sentry/fs/host/descriptor_test.go  78
-rw-r--r--  pkg/sentry/fs/host/device.go  25
-rw-r--r--  pkg/sentry/fs/host/file.go  286
-rw-r--r--  pkg/sentry/fs/host/host.go  59
-rw-r--r--  pkg/sentry/fs/host/inode.go  416
-rw-r--r--  pkg/sentry/fs/host/inode_state.go  49
-rw-r--r--  pkg/sentry/fs/host/inode_test.go  45
-rw-r--r--  pkg/sentry/fs/host/ioctl_unsafe.go  60
-rw-r--r--  pkg/sentry/fs/host/socket.go  384
-rw-r--r--  pkg/sentry/fs/host/socket_iovec.go  117
-rw-r--r--  pkg/sentry/fs/host/socket_state.go  42
-rw-r--r--  pkg/sentry/fs/host/socket_test.go  246
-rw-r--r--  pkg/sentry/fs/host/socket_unsafe.go  105
-rw-r--r--  pkg/sentry/fs/host/tty.go  364
-rw-r--r--  pkg/sentry/fs/host/util.go  129
-rw-r--r--  pkg/sentry/fs/host/util_amd64_unsafe.go  41
-rw-r--r--  pkg/sentry/fs/host/util_arm64_unsafe.go  41
-rw-r--r--  pkg/sentry/fs/host/util_unsafe.go  77
-rw-r--r--  pkg/sentry/fs/host/wait_test.go  69
-rw-r--r--  pkg/sentry/fs/inode.go  477
-rw-r--r--  pkg/sentry/fs/inode_inotify.go  170
-rw-r--r--  pkg/sentry/fs/inode_operations.go  326
-rw-r--r--  pkg/sentry/fs/inode_overlay.go  737
-rw-r--r--  pkg/sentry/fs/inode_overlay_test.go  470
-rw-r--r--  pkg/sentry/fs/inotify.go  352
-rw-r--r--  pkg/sentry/fs/inotify_event.go  139
-rw-r--r--  pkg/sentry/fs/inotify_watch.go  135
-rw-r--r--  pkg/sentry/fs/lock/BUILD  58
-rw-r--r--  pkg/sentry/fs/lock/lock.go  453
-rw-r--r--  pkg/sentry/fs/lock/lock_range_test.go  136
-rw-r--r--  pkg/sentry/fs/lock/lock_set_functions.go  63
-rw-r--r--  pkg/sentry/fs/lock/lock_test.go  1060
-rw-r--r--  pkg/sentry/fs/mock.go  176
-rw-r--r--  pkg/sentry/fs/mount.go  285
-rw-r--r--  pkg/sentry/fs/mount_overlay.go  151
-rw-r--r--  pkg/sentry/fs/mount_test.go  272
-rw-r--r--  pkg/sentry/fs/mounts.go  623
-rw-r--r--  pkg/sentry/fs/mounts_test.go  105
-rw-r--r--  pkg/sentry/fs/offset.go  65
-rw-r--r--  pkg/sentry/fs/overlay.go  320
-rw-r--r--  pkg/sentry/fs/path.go  119
-rw-r--r--  pkg/sentry/fs/path_test.go  289
-rw-r--r--  pkg/sentry/fs/proc/BUILD  72
-rw-r--r--  pkg/sentry/fs/proc/README.md  336
-rw-r--r--  pkg/sentry/fs/proc/cgroup.go  45
-rw-r--r--  pkg/sentry/fs/proc/cpuinfo.go  41
-rw-r--r--  pkg/sentry/fs/proc/device/BUILD  10
-rw-r--r--  pkg/sentry/fs/proc/device/device.go  23
-rw-r--r--  pkg/sentry/fs/proc/exec_args.go  207
-rw-r--r--  pkg/sentry/fs/proc/fds.go  283
-rw-r--r--  pkg/sentry/fs/proc/filesystems.go  65
-rw-r--r--  pkg/sentry/fs/proc/fs.go  85
-rw-r--r--  pkg/sentry/fs/proc/inode.go  137
-rw-r--r--  pkg/sentry/fs/proc/loadavg.go  59
-rw-r--r--  pkg/sentry/fs/proc/meminfo.go  93
-rw-r--r--  pkg/sentry/fs/proc/mounts.go  232
-rw-r--r--  pkg/sentry/fs/proc/net.go  841
-rw-r--r--  pkg/sentry/fs/proc/net_test.go  74
-rw-r--r--  pkg/sentry/fs/proc/proc.go  248
-rw-r--r--  pkg/sentry/fs/proc/seqfile/BUILD  35
-rw-r--r--  pkg/sentry/fs/proc/seqfile/seqfile.go  283
-rw-r--r--  pkg/sentry/fs/proc/seqfile/seqfile_test.go  279
-rw-r--r--  pkg/sentry/fs/proc/stat.go  146
-rw-r--r--  pkg/sentry/fs/proc/sys.go  159
-rw-r--r--  pkg/sentry/fs/proc/sys_net.go  372
-rw-r--r--  pkg/sentry/fs/proc/sys_net_state.go  42
-rw-r--r--  pkg/sentry/fs/proc/sys_net_test.go  125
-rw-r--r--  pkg/sentry/fs/proc/task.go  914
-rw-r--r--  pkg/sentry/fs/proc/uid_gid_map.go  183
-rw-r--r--  pkg/sentry/fs/proc/uptime.go  91
-rw-r--r--  pkg/sentry/fs/proc/version.go  82
-rw-r--r--  pkg/sentry/fs/ramfs/BUILD  37
-rw-r--r--  pkg/sentry/fs/ramfs/dir.go  548
-rw-r--r--  pkg/sentry/fs/ramfs/socket.go  85
-rw-r--r--  pkg/sentry/fs/ramfs/symlink.go  106
-rw-r--r--  pkg/sentry/fs/ramfs/tree.go  77
-rw-r--r--  pkg/sentry/fs/ramfs/tree_test.go  80
-rw-r--r--  pkg/sentry/fs/restore.go  78
-rw-r--r--  pkg/sentry/fs/save.go  77
-rw-r--r--  pkg/sentry/fs/seek.go  43
-rw-r--r--  pkg/sentry/fs/splice.go  181
-rw-r--r--  pkg/sentry/fs/sync.go  43
-rw-r--r--  pkg/sentry/fs/sys/BUILD  24
-rw-r--r--  pkg/sentry/fs/sys/device.go  20
-rw-r--r--  pkg/sentry/fs/sys/devices.go  91
-rw-r--r--  pkg/sentry/fs/sys/fs.go  65
-rw-r--r--  pkg/sentry/fs/sys/sys.go  64
-rw-r--r--  pkg/sentry/fs/timerfd/BUILD  19
-rw-r--r--  pkg/sentry/fs/timerfd/timerfd.go  151
-rw-r--r--  pkg/sentry/fs/tmpfs/BUILD  50
-rw-r--r--  pkg/sentry/fs/tmpfs/device.go  20
-rw-r--r--  pkg/sentry/fs/tmpfs/file_regular.go  60
-rw-r--r--  pkg/sentry/fs/tmpfs/file_test.go  72
-rw-r--r--  pkg/sentry/fs/tmpfs/fs.go  155
-rw-r--r--  pkg/sentry/fs/tmpfs/inode_file.go  687
-rw-r--r--  pkg/sentry/fs/tmpfs/tmpfs.go  356
-rw-r--r--  pkg/sentry/fs/tty/BUILD  47
-rw-r--r--  pkg/sentry/fs/tty/dir.go  342
-rw-r--r--  pkg/sentry/fs/tty/fs.go  111
-rw-r--r--  pkg/sentry/fs/tty/line_discipline.go  449
-rw-r--r--  pkg/sentry/fs/tty/master.go  238
-rw-r--r--  pkg/sentry/fs/tty/queue.go  240
-rw-r--r--  pkg/sentry/fs/tty/slave.go  178
-rw-r--r--  pkg/sentry/fs/tty/terminal.go  132
-rw-r--r--  pkg/sentry/fs/tty/tty_test.go  56
-rw-r--r--  pkg/sentry/fs/user/BUILD  40
-rw-r--r--  pkg/sentry/fs/user/path.go  170
-rw-r--r--  pkg/sentry/fs/user/user.go  239
-rw-r--r--  pkg/sentry/fs/user/user_test.go  198
-rw-r--r--  pkg/sentry/fsbridge/BUILD  24
-rw-r--r--  pkg/sentry/fsbridge/bridge.go  54
-rw-r--r--  pkg/sentry/fsbridge/fs.go  181
-rw-r--r--  pkg/sentry/fsbridge/vfs.go  142
-rw-r--r--  pkg/sentry/fsimpl/devpts/BUILD  44
-rw-r--r--  pkg/sentry/fsimpl/devpts/devpts.go  233
-rw-r--r--  pkg/sentry/fsimpl/devpts/devpts_test.go  56
-rw-r--r--  pkg/sentry/fsimpl/devpts/line_discipline.go  445
-rw-r--r--  pkg/sentry/fsimpl/devpts/master.go  237
-rw-r--r--  pkg/sentry/fsimpl/devpts/queue.go  236
-rw-r--r--  pkg/sentry/fsimpl/devpts/slave.go  197
-rw-r--r--  pkg/sentry/fsimpl/devpts/terminal.go  120
-rw-r--r--  pkg/sentry/fsimpl/devtmpfs/BUILD  33
-rw-r--r--  pkg/sentry/fsimpl/devtmpfs/devtmpfs.go  219
-rw-r--r--  pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go  122
-rw-r--r--  pkg/sentry/fsimpl/eventfd/BUILD  33
-rw-r--r--  pkg/sentry/fsimpl/eventfd/eventfd.go  285
-rw-r--r--  pkg/sentry/fsimpl/eventfd/eventfd_test.go  97
-rw-r--r--  pkg/sentry/fsimpl/ext/BUILD  102
-rw-r--r--  pkg/sentry/fsimpl/ext/README.md  117
-rw-r--r--  pkg/sentry/fsimpl/ext/assets/README.md  36
-rw-r--r--  pkg/sentry/fsimpl/ext/assets/bigfile.txt  41
-rw-r--r--  pkg/sentry/fsimpl/ext/assets/file.txt  1
l---------  pkg/sentry/fsimpl/ext/assets/symlink.txt  1
-rw-r--r--  pkg/sentry/fsimpl/ext/assets/tiny.ext2  bin 0 -> 65536 bytes
-rw-r--r--  pkg/sentry/fsimpl/ext/assets/tiny.ext3  bin 0 -> 65536 bytes
-rw-r--r--  pkg/sentry/fsimpl/ext/assets/tiny.ext4  bin 0 -> 65536 bytes
-rw-r--r--  pkg/sentry/fsimpl/ext/benchmark/BUILD  17
-rw-r--r--  pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go  206
-rwxr-xr-x  pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh  72
-rw-r--r--  pkg/sentry/fsimpl/ext/block_map_file.go  201
-rw-r--r--  pkg/sentry/fsimpl/ext/block_map_test.go  156
-rw-r--r--  pkg/sentry/fsimpl/ext/dentry.go  79
-rw-r--r--  pkg/sentry/fsimpl/ext/directory.go  318
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/BUILD  47
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group.go  137
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group_32.go  72
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group_64.go  93
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group_test.go  26
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent.go  72
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent_new.go  61
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent_old.go  49
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent_test.go  26
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/disklayout.go  50
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/extent.go  143
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/extent_test.go  27
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode.go  274
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode_new.go  96
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode_old.go  117
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode_test.go  222
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock.go  471
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_32.go  76
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_64.go  95
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_old.go  105
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_test.go  27
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/test_utils.go  30
-rw-r--r--  pkg/sentry/fsimpl/ext/ext.go  157
-rw-r--r--  pkg/sentry/fsimpl/ext/ext_test.go  921
-rw-r--r--  pkg/sentry/fsimpl/ext/extent_file.go  238
-rw-r--r--  pkg/sentry/fsimpl/ext/extent_test.go  265
-rw-r--r--  pkg/sentry/fsimpl/ext/file_description.go  65
-rw-r--r--  pkg/sentry/fsimpl/ext/filesystem.go  548
-rw-r--r--  pkg/sentry/fsimpl/ext/inode.go  242
-rw-r--r--  pkg/sentry/fsimpl/ext/regular_file.go  162
-rw-r--r--  pkg/sentry/fsimpl/ext/symlink.go  111
-rw-r--r--  pkg/sentry/fsimpl/ext/utils.go  94
-rw-r--r--  pkg/sentry/fsimpl/fuse/BUILD  19
-rw-r--r--  pkg/sentry/fsimpl/fuse/dev.go  100
-rw-r--r--  pkg/sentry/fsimpl/gofer/BUILD  89
-rw-r--r--  pkg/sentry/fsimpl/gofer/directory.go  308
-rw-r--r--  pkg/sentry/fsimpl/gofer/filesystem.go  1504
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go  1550
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer_test.go  63
-rw-r--r--  pkg/sentry/fsimpl/gofer/handle.go  141
-rw-r--r--  pkg/sentry/fsimpl/gofer/host_named_pipe.go  97
-rw-r--r--  pkg/sentry/fsimpl/gofer/p9file.go  233
-rw-r--r--  pkg/sentry/fsimpl/gofer/regular_file.go  892
-rw-r--r--  pkg/sentry/fsimpl/gofer/socket.go  146
-rw-r--r--  pkg/sentry/fsimpl/gofer/special_file.go  245
-rw-r--r--  pkg/sentry/fsimpl/gofer/symlink.go  47
-rw-r--r--  pkg/sentry/fsimpl/gofer/time.go  79
-rw-r--r--  pkg/sentry/fsimpl/host/BUILD  52
-rw-r--r--  pkg/sentry/fsimpl/host/control.go  96
-rw-r--r--  pkg/sentry/fsimpl/host/host.go  766
-rw-r--r--  pkg/sentry/fsimpl/host/ioctl_unsafe.go  56
-rw-r--r--  pkg/sentry/fsimpl/host/mmap.go  132
-rw-r--r--  pkg/sentry/fsimpl/host/socket.go  385
-rw-r--r--  pkg/sentry/fsimpl/host/socket_iovec.go  113
-rw-r--r--  pkg/sentry/fsimpl/host/socket_unsafe.go  101
-rw-r--r--  pkg/sentry/fsimpl/host/tty.go  390
-rw-r--r--  pkg/sentry/fsimpl/host/util.go  56
-rw-r--r--  pkg/sentry/fsimpl/host/util_unsafe.go  34
-rw-r--r--  pkg/sentry/fsimpl/kernfs/BUILD  75
-rw-r--r--  pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go  147
-rw-r--r--  pkg/sentry/fsimpl/kernfs/fd_impl_util.go  252
-rw-r--r--  pkg/sentry/fsimpl/kernfs/filesystem.go  840
-rw-r--r--  pkg/sentry/fsimpl/kernfs/inode_impl_util.go  613
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs.go  456
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs_test.go  330
-rw-r--r--  pkg/sentry/fsimpl/kernfs/symlink.go  66
-rw-r--r--  pkg/sentry/fsimpl/overlay/BUILD  41
-rw-r--r--  pkg/sentry/fsimpl/overlay/copy_up.go  262
-rw-r--r--  pkg/sentry/fsimpl/overlay/directory.go  287
-rw-r--r--  pkg/sentry/fsimpl/overlay/filesystem.go  1364
-rw-r--r--  pkg/sentry/fsimpl/overlay/non_directory.go  266
-rw-r--r--  pkg/sentry/fsimpl/overlay/overlay.go  627
-rw-r--r--  pkg/sentry/fsimpl/pipefs/BUILD  21
-rw-r--r--  pkg/sentry/fsimpl/pipefs/pipefs.go  165
-rw-r--r--  pkg/sentry/fsimpl/proc/BUILD  67
-rw-r--r--  pkg/sentry/fsimpl/proc/filesystem.go  117
-rw-r--r--  pkg/sentry/fsimpl/proc/subtasks.go  182
-rw-r--r--  pkg/sentry/fsimpl/proc/task.go  239
-rw-r--r--  pkg/sentry/fsimpl/proc/task_fds.go  307
-rw-r--r--  pkg/sentry/fsimpl/proc/task_files.go  902
-rw-r--r--  pkg/sentry/fsimpl/proc/task_net.go  810
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks.go  256
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_files.go  384
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_sys.go  209
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_sys_test.go  78
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_test.go  505
-rw-r--r--  pkg/sentry/fsimpl/signalfd/BUILD  20
-rw-r--r--  pkg/sentry/fsimpl/signalfd/signalfd.go  136
-rw-r--r--  pkg/sentry/fsimpl/sockfs/BUILD  18
-rw-r--r--  pkg/sentry/fsimpl/sockfs/sockfs.go  109
-rw-r--r--  pkg/sentry/fsimpl/sys/BUILD  34
-rw-r--r--  pkg/sentry/fsimpl/sys/sys.go  151
-rw-r--r--  pkg/sentry/fsimpl/sys/sys_test.go  89
-rw-r--r--  pkg/sentry/fsimpl/testutil/BUILD  37
-rw-r--r--  pkg/sentry/fsimpl/testutil/kernel.go  176
-rw-r--r--  pkg/sentry/fsimpl/testutil/testutil.go  284
-rw-r--r--  pkg/sentry/fsimpl/timerfd/BUILD  17
-rw-r--r--  pkg/sentry/fsimpl/timerfd/timerfd.go  144
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/BUILD  112
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/benchmark_test.go  486
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/device_file.go  49
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/directory.go  232
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/filesystem.go  859
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/named_pipe.go  38
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/pipe_test.go  238
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/regular_file.go  626
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/regular_file_test.go  349
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/socket_file.go  34
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/stat_test.go  236
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/symlink.go  37
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/tmpfs.go  787
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/tmpfs_test.go  156
-rw-r--r--  pkg/sentry/hostcpu/BUILD  20
-rw-r--r--  pkg/sentry/hostcpu/getcpu_amd64.s  24
-rw-r--r--  pkg/sentry/hostcpu/getcpu_arm64.s  28
-rw-r--r--  pkg/sentry/hostcpu/hostcpu.go  67
-rw-r--r--  pkg/sentry/hostcpu/hostcpu_test.go  52
-rw-r--r--  pkg/sentry/hostfd/BUILD  17
-rw-r--r--  pkg/sentry/hostfd/hostfd.go  84
-rw-r--r--  pkg/sentry/hostfd/hostfd_unsafe.go  85
-rw-r--r--  pkg/sentry/hostmm/BUILD  17
-rw-r--r--  pkg/sentry/hostmm/cgroup.go  111
-rw-r--r--  pkg/sentry/hostmm/hostmm.go  130
-rw-r--r--  pkg/sentry/inet/BUILD  20
-rw-r--r--  pkg/sentry/inet/context.go  35
-rw-r--r--  pkg/sentry/inet/inet.go  191
-rw-r--r--  pkg/sentry/inet/namespace.go  102
-rw-r--r--  pkg/sentry/inet/test_stack.go  118
-rw-r--r--  pkg/sentry/kernel/BUILD  241
-rw-r--r--  pkg/sentry/kernel/README.md  108
-rw-r--r--  pkg/sentry/kernel/abstract_socket_namespace.go  111
-rw-r--r--  pkg/sentry/kernel/aio.go  81
-rw-r--r--  pkg/sentry/kernel/auth/BUILD  69
-rw-r--r--  pkg/sentry/kernel/auth/auth.go  22
-rw-r--r--  pkg/sentry/kernel/auth/capability_set.go  61
-rw-r--r--  pkg/sentry/kernel/auth/context.go  36
-rw-r--r--  pkg/sentry/kernel/auth/credentials.go  262
-rw-r--r--  pkg/sentry/kernel/auth/id.go  121
-rw-r--r--  pkg/sentry/kernel/auth/id_map.go  285
-rw-r--r--  pkg/sentry/kernel/auth/id_map_functions.go  45
-rw-r--r--  pkg/sentry/kernel/auth/user_namespace.go  129
-rw-r--r--  pkg/sentry/kernel/context.go  114
-rw-r--r--  pkg/sentry/kernel/contexttest/BUILD  17
-rw-r--r--  pkg/sentry/kernel/contexttest/contexttest.go  40
-rw-r--r--  pkg/sentry/kernel/epoll/BUILD  51
-rw-r--r--  pkg/sentry/kernel/epoll/epoll.go  462
-rw-r--r--  pkg/sentry/kernel/epoll/epoll_state.go  51
-rw-r--r--  pkg/sentry/kernel/epoll/epoll_test.go  54
-rw-r--r--  pkg/sentry/kernel/eventfd/BUILD  33
-rw-r--r--  pkg/sentry/kernel/eventfd/eventfd.go  285
-rw-r--r--  pkg/sentry/kernel/eventfd/eventfd_test.go  78
-rw-r--r--  pkg/sentry/kernel/fasync/BUILD  18
-rw-r--r--  pkg/sentry/kernel/fasync/fasync.go  188
-rw-r--r--  pkg/sentry/kernel/fd_table.go  638
-rw-r--r--  pkg/sentry/kernel/fd_table_test.go  228
-rw-r--r--  pkg/sentry/kernel/fd_table_unsafe.go  169
-rw-r--r--  pkg/sentry/kernel/fs_context.go  283
-rw-r--r--  pkg/sentry/kernel/futex/BUILD  57
-rw-r--r--  pkg/sentry/kernel/futex/futex.go  795
-rw-r--r--  pkg/sentry/kernel/futex/futex_test.go  530
-rw-r--r--  pkg/sentry/kernel/g3doc/run_states.dot  99
-rw-r--r--  pkg/sentry/kernel/g3doc/run_states.png  bin 0 -> 234152 bytes
-rw-r--r--  pkg/sentry/kernel/ipc_namespace.go  58
-rw-r--r--  pkg/sentry/kernel/kernel.go  1682
-rw-r--r--  pkg/sentry/kernel/kernel_opts.go  20
-rw-r--r--  pkg/sentry/kernel/kernel_state.go  42
-rw-r--r--  pkg/sentry/kernel/memevent/BUILD  24
-rw-r--r--  pkg/sentry/kernel/memevent/memory_events.go  111
-rw-r--r--  pkg/sentry/kernel/memevent/memory_events.proto  29
-rw-r--r--  pkg/sentry/kernel/pending_signals.go  142
-rw-r--r--  pkg/sentry/kernel/pending_signals_state.go  46
-rw-r--r--  pkg/sentry/kernel/pipe/BUILD  54
-rw-r--r--  pkg/sentry/kernel/pipe/device.go  20
-rw-r--r--  pkg/sentry/kernel/pipe/node.go  139
-rw-r--r--  pkg/sentry/kernel/pipe/node_test.go  320
-rw-r--r--  pkg/sentry/kernel/pipe/pipe.go  419
-rw-r--r--  pkg/sentry/kernel/pipe/pipe_test.go  139
-rw-r--r--  pkg/sentry/kernel/pipe/pipe_unsafe.go  35
-rw-r--r--  pkg/sentry/kernel/pipe/pipe_util.go  214
-rw-r--r--  pkg/sentry/kernel/pipe/reader.go  42
-rw-r--r--  pkg/sentry/kernel/pipe/reader_writer.go  67
-rw-r--r--  pkg/sentry/kernel/pipe/vfs.go  468
-rw-r--r--  pkg/sentry/kernel/pipe/writer.go  42
-rw-r--r--  pkg/sentry/kernel/posixtimer.go  308
-rw-r--r--  pkg/sentry/kernel/ptrace.go  1119
-rw-r--r--  pkg/sentry/kernel/ptrace_amd64.go  89
-rw-r--r--  pkg/sentry/kernel/ptrace_arm64.go  27
-rw-r--r--  pkg/sentry/kernel/rseq.go  393
-rw-r--r--  pkg/sentry/kernel/sched/BUILD  19
-rw-r--r--  pkg/sentry/kernel/sched/cpuset.go  105
-rw-r--r--  pkg/sentry/kernel/sched/cpuset_test.go  44
-rw-r--r--  pkg/sentry/kernel/sched/sched.go  16
-rw-r--r--  pkg/sentry/kernel/seccomp.go  217
-rw-r--r--  pkg/sentry/kernel/semaphore/BUILD  49
-rw-r--r--  pkg/sentry/kernel/semaphore/semaphore.go  572
-rw-r--r--  pkg/sentry/kernel/semaphore/semaphore_test.go  172
-rw-r--r--  pkg/sentry/kernel/sessions.go  528
-rw-r--r--  pkg/sentry/kernel/shm/BUILD  29
-rw-r--r--  pkg/sentry/kernel/shm/device.go  20
-rw-r--r--  pkg/sentry/kernel/shm/shm.go  707
-rw-r--r--  pkg/sentry/kernel/signal.go  79
-rw-r--r--  pkg/sentry/kernel/signal_handlers.go  88
-rw-r--r--  pkg/sentry/kernel/signalfd/BUILD  22
-rw-r--r--  pkg/sentry/kernel/signalfd/signalfd.go  139
-rw-r--r--  pkg/sentry/kernel/syscalls.go  364
-rw-r--r--  pkg/sentry/kernel/syscalls_state.go  47
-rw-r--r--  pkg/sentry/kernel/syslog.go  108
-rw-r--r--  pkg/sentry/kernel/table_test.go  110
-rw-r--r--  pkg/sentry/kernel/task.go  886
-rw-r--r--  pkg/sentry/kernel/task_acct.go  196
-rw-r--r--  pkg/sentry/kernel/task_block.go  230
-rw-r--r--  pkg/sentry/kernel/task_clone.go  540
-rw-r--r--  pkg/sentry/kernel/task_context.go  169
-rw-r--r--  pkg/sentry/kernel/task_exec.go  277
-rw-r--r--  pkg/sentry/kernel/task_exit.go  1167
-rw-r--r--  pkg/sentry/kernel/task_futex.go  54
-rw-r--r--  pkg/sentry/kernel/task_identity.go  606
-rw-r--r--  pkg/sentry/kernel/task_log.go  208
-rw-r--r--  pkg/sentry/kernel/task_net.go  44
-rw-r--r--  pkg/sentry/kernel/task_run.go  380
-rw-r--r--  pkg/sentry/kernel/task_sched.go  668
-rw-r--r--  pkg/sentry/kernel/task_signals.go  1139
-rw-r--r--  pkg/sentry/kernel/task_start.go  319
-rw-r--r--  pkg/sentry/kernel/task_stop.go  226
-rw-r--r--  pkg/sentry/kernel/task_syscall.go  469
-rw-r--r--  pkg/sentry/kernel/task_test.go  69
-rw-r--r--  pkg/sentry/kernel/task_usermem.go  301
-rw-r--r--  pkg/sentry/kernel/thread_group.go  531
-rw-r--r--  pkg/sentry/kernel/threads.go  478
-rw-r--r--  pkg/sentry/kernel/time/BUILD  19
-rw-r--r--  pkg/sentry/kernel/time/context.go  44
-rw-r--r--  pkg/sentry/kernel/time/time.go  709
-rw-r--r--  pkg/sentry/kernel/timekeeper.go  325
-rw-r--r--  pkg/sentry/kernel/timekeeper_state.go  41
-rw-r--r--  pkg/sentry/kernel/timekeeper_test.go  156
-rw-r--r--  pkg/sentry/kernel/tty.go  41
-rw-r--r--  pkg/sentry/kernel/uncaught_signal.proto  37
-rw-r--r--  pkg/sentry/kernel/uts_namespace.go  101
-rw-r--r--  pkg/sentry/kernel/vdso.go  148
-rw-r--r--  pkg/sentry/kernel/version.go  33
-rw-r--r--  pkg/sentry/limits/BUILD  27
-rw-r--r--  pkg/sentry/limits/context.go  35
-rw-r--r--  pkg/sentry/limits/limits.go  137
-rw-r--r--  pkg/sentry/limits/limits_test.go  43
-rw-r--r--  pkg/sentry/limits/linux.go  100
-rw-r--r--  pkg/sentry/loader/BUILD  46
-rw-r--r--  pkg/sentry/loader/elf.go  700
-rw-r--r--  pkg/sentry/loader/interpreter.go  108
-rw-r--r--  pkg/sentry/loader/loader.go  315
-rw-r--r--  pkg/sentry/loader/vdso.go  382
-rw-r--r--  pkg/sentry/loader/vdso_state.go  48
-rw-r--r--  pkg/sentry/memmap/BUILD  55
-rw-r--r--  pkg/sentry/memmap/mapping_set.go  253
-rw-r--r--  pkg/sentry/memmap/mapping_set_test.go  260
-rw-r--r--  pkg/sentry/memmap/memmap.go  363
-rw-r--r--  pkg/sentry/mm/BUILD  142
-rw-r--r--  pkg/sentry/mm/README.md  280
-rw-r--r--  pkg/sentry/mm/address_space.go  236
-rw-r--r--  pkg/sentry/mm/aio_context.go  429
-rw-r--r--  pkg/sentry/mm/aio_context_state.go  20
-rw-r--r--  pkg/sentry/mm/debug.go  98
-rw-r--r--  pkg/sentry/mm/io.go  639
-rw-r--r--  pkg/sentry/mm/lifecycle.go  283
-rw-r--r--  pkg/sentry/mm/metadata.go  183
-rw-r--r--  pkg/sentry/mm/mm.go  478
-rw-r--r--  pkg/sentry/mm/mm_test.go  230
-rw-r--r--  pkg/sentry/mm/pma.go  1036
-rw-r--r--  pkg/sentry/mm/procfs.go  329
-rw-r--r--  pkg/sentry/mm/save_restore.go  57
-rw-r--r--  pkg/sentry/mm/shm.go  66
-rw-r--r--  pkg/sentry/mm/special_mappable.go  157
-rw-r--r--  pkg/sentry/mm/syscalls.go  1286
-rw-r--r--  pkg/sentry/mm/vma.go  568
-rw-r--r--  pkg/sentry/pgalloc/BUILD  108
-rw-r--r--  pkg/sentry/pgalloc/context.go  48
-rw-r--r--  pkg/sentry/pgalloc/pgalloc.go  1279
-rw-r--r--  pkg/sentry/pgalloc/pgalloc_test.go  246
-rw-r--r--  pkg/sentry/pgalloc/pgalloc_unsafe.go  40
-rw-r--r--  pkg/sentry/pgalloc/save_restore.go  212
-rw-r--r--  pkg/sentry/platform/BUILD  39
-rw-r--r--  pkg/sentry/platform/context.go  36
-rw-r--r--  pkg/sentry/platform/interrupt/BUILD  19
-rw-r--r--  pkg/sentry/platform/interrupt/interrupt.go  97
-rw-r--r--  pkg/sentry/platform/interrupt/interrupt_test.go  99
-rw-r--r--  pkg/sentry/platform/kvm/BUILD  80
-rw-r--r--  pkg/sentry/platform/kvm/address_space.go  249
-rw-r--r--  pkg/sentry/platform/kvm/bluepill.go  96
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_allocator.go  100
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64.go  129
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64.s  93
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go  87
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.go  124
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.s  89
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go  97
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_fault.go  127
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_unsafe.go  232
-rw-r--r--  pkg/sentry/platform/kvm/context.go  90
-rw-r--r--  pkg/sentry/platform/kvm/filters_amd64.go  33
-rw-r--r--  pkg/sentry/platform/kvm/filters_arm64.go  32
-rw-r--r--  pkg/sentry/platform/kvm/kvm.go  201
-rw-r--r--  pkg/sentry/platform/kvm/kvm_amd64.go  190
-rw-r--r--  pkg/sentry/platform/kvm/kvm_amd64_unsafe.go  77
-rw-r--r--  pkg/sentry/platform/kvm/kvm_arm64.go  67
-rw-r--r--  pkg/sentry/platform/kvm/kvm_arm64_unsafe.go  39
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const.go  87
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const_arm64.go  140
-rw-r--r--  pkg/sentry/platform/kvm/kvm_test.go  533
-rw-r--r--  pkg/sentry/platform/kvm/machine.go  575
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64.go  347
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64_unsafe.go  177
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64.go  195
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64_unsafe.go  282
-rw-r--r--  pkg/sentry/platform/kvm/machine_unsafe.go  145
-rw-r--r--  pkg/sentry/platform/kvm/physical_map.go  214
-rw-r--r--  pkg/sentry/platform/kvm/physical_map_amd64.go  22
-rw-r--r--  pkg/sentry/platform/kvm/physical_map_arm64.go  19
-rw-r--r--  pkg/sentry/platform/kvm/testutil/BUILD  17
-rw-r--r--  pkg/sentry/platform/kvm/testutil/testutil.go  72
-rw-r--r--  pkg/sentry/platform/kvm/testutil/testutil_amd64.go  139
-rw-r--r--  pkg/sentry/platform/kvm/testutil/testutil_amd64.s  98
-rw-r--r--  pkg/sentry/platform/kvm/testutil/testutil_arm64.go  60
-rw-r--r--  pkg/sentry/platform/kvm/testutil/testutil_arm64.s  106
-rw-r--r--  pkg/sentry/platform/kvm/virtual_map.go  113
-rw-r--r--  pkg/sentry/platform/kvm/virtual_map_test.go  93
-rw-r--r--  pkg/sentry/platform/mmap_min_addr.go  60
-rw-r--r--  pkg/sentry/platform/platform.go  398
-rw-r--r--  pkg/sentry/platform/ptrace/BUILD  39
-rw-r--r--  pkg/sentry/platform/ptrace/filters.go  33
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace.go  266
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace_amd64.go  46
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace_arm64.go  29
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go  62
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace_unsafe.go  172
-rw-r--r--  pkg/sentry/platform/ptrace/stub_amd64.s  119
-rw-r--r--  pkg/sentry/platform/ptrace/stub_arm64.s  112
-rw-r--r--  pkg/sentry/platform/ptrace/stub_unsafe.go  98
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess.go  663
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_amd64.go  259
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_arm64.go  174
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_linux.go  259
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go  95
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_unsafe.go  33
-rw-r--r--  pkg/sentry/platform/ring0/BUILD  86
-rw-r--r--  pkg/sentry/platform/ring0/aarch64.go  110
-rw-r--r--  pkg/sentry/platform/ring0/defs.go  109
-rw-r--r--  pkg/sentry/platform/ring0/defs_amd64.go  148
-rw-r--r--  pkg/sentry/platform/ring0/defs_arm64.go  143
-rw-r--r--  pkg/sentry/platform/ring0/entry_amd64.go  128
-rw-r--r--  pkg/sentry/platform/ring0/entry_amd64.s  319
-rw-r--r--  pkg/sentry/platform/ring0/entry_arm64.go  60
-rw-r--r--  pkg/sentry/platform/ring0/entry_arm64.s  782
-rw-r--r--  pkg/sentry/platform/ring0/gen_offsets/BUILD  34
-rw-r--r--  pkg/sentry/platform/ring0/gen_offsets/main.go  24
-rw-r--r--  pkg/sentry/platform/ring0/kernel.go  82
-rw-r--r--  pkg/sentry/platform/ring0/kernel_amd64.go  281
-rw-r--r--  pkg/sentry/platform/ring0/kernel_arm64.go  66
-rw-r--r--  pkg/sentry/platform/ring0/kernel_unsafe.go  41
-rw-r--r--  pkg/sentry/platform/ring0/lib_amd64.go  131
-rw-r--r--  pkg/sentry/platform/ring0/lib_amd64.s  247
-rw-r--r--  pkg/sentry/platform/ring0/lib_arm64.go  52
-rw-r--r--  pkg/sentry/platform/ring0/lib_arm64.s  131
-rw-r--r--  pkg/sentry/platform/ring0/lib_arm64_unsafe.go  108
-rw-r--r--  pkg/sentry/platform/ring0/offsets_amd64.go  93
-rw-r--r--  pkg/sentry/platform/ring0/offsets_arm64.go  127
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/BUILD  115
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/allocator.go  127
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go  53
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables.go  220
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go  212
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go  54
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go  75
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go  57
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go  80
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_test.go  156
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_x86.go  180
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pcids.go  104
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go  32
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s  45
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pcids_x86.go  20
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/walker_amd64.go  307
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/walker_arm64.go  314
-rw-r--r--  pkg/sentry/platform/ring0/ring0.go  16
-rw-r--r--  pkg/sentry/platform/ring0/x86.go  264
-rw-r--r--  pkg/sentry/sighandling/BUILD  13
-rw-r--r--  pkg/sentry/sighandling/sighandling.go  102
-rw-r--r--  pkg/sentry/sighandling/sighandling_unsafe.go  48
-rw-r--r--  pkg/sentry/socket/BUILD  24
-rw-r--r--  pkg/sentry/socket/control/BUILD  29
-rw-r--r--  pkg/sentry/socket/control/control.go  591
-rw-r--r--  pkg/sentry/socket/control/control_vfs2.go  131
-rw-r--r--  pkg/sentry/socket/hostinet/BUILD  45
-rw-r--r--  pkg/sentry/socket/hostinet/device.go  19
-rw-r--r--  pkg/sentry/socket/hostinet/hostinet.go  17
-rw-r--r--  pkg/sentry/socket/hostinet/save_restore.go  20
-rw-r--r--  pkg/sentry/socket/hostinet/socket.go  713
-rw-r--r--  pkg/sentry/socket/hostinet/socket_unsafe.go  139
-rw-r--r--  pkg/sentry/socket/hostinet/socket_vfs2.go  202
-rw-r--r--  pkg/sentry/socket/hostinet/sockopt_impl.go  27
-rw-r--r--  pkg/sentry/socket/hostinet/stack.go  459
-rw-r--r--  pkg/sentry/socket/netfilter/BUILD  29
-rw-r--r--  pkg/sentry/socket/netfilter/extensions.go  95
-rw-r--r--  pkg/sentry/socket/netfilter/netfilter.go  761
-rw-r--r--  pkg/sentry/socket/netfilter/owner_matcher.go  149
-rw-r--r--  pkg/sentry/socket/netfilter/targets.go  35
-rw-r--r--  pkg/sentry/socket/netfilter/tcp_matcher.go  130
-rw-r--r--  pkg/sentry/socket/netfilter/udp_matcher.go  129
-rw-r--r--  pkg/sentry/socket/netlink/BUILD  52
-rw-r--r--  pkg/sentry/socket/netlink/message.go  281
-rw-r--r--  pkg/sentry/socket/netlink/message_test.go  312
-rw-r--r--  pkg/sentry/socket/netlink/port/BUILD  16
-rw-r--r--  pkg/sentry/socket/netlink/port/port.go  117
-rw-r--r--  pkg/sentry/socket/netlink/port/port_test.go  82
-rw-r--r--  pkg/sentry/socket/netlink/provider.go  116
-rw-r--r--  pkg/sentry/socket/netlink/provider_vfs2.go  69
-rw-r--r--  pkg/sentry/socket/netlink/route/BUILD  20
-rw-r--r--  pkg/sentry/socket/netlink/route/protocol.go  498
-rw-r--r--  pkg/sentry/socket/netlink/socket.go  780
-rw-r--r--  pkg/sentry/socket/netlink/socket_vfs2.go  152
-rw-r--r--  pkg/sentry/socket/netlink/uevent/BUILD  16
-rw-r--r--  pkg/sentry/socket/netlink/uevent/protocol.go  60
-rw-r--r--  pkg/sentry/socket/netstack/BUILD  56
-rw-r--r--  pkg/sentry/socket/netstack/device.go  20
-rw-r--r--  pkg/sentry/socket/netstack/netstack.go  3143
-rw-r--r--  pkg/sentry/socket/netstack/netstack_vfs2.go  330
-rw-r--r--  pkg/sentry/socket/netstack/provider.go  199
-rw-r--r--  pkg/sentry/socket/netstack/provider_vfs2.go  141
-rw-r--r--  pkg/sentry/socket/netstack/save_restore.go  27
-rw-r--r--  pkg/sentry/socket/netstack/stack.go  386
-rw-r--r--  pkg/sentry/socket/socket.go  461
-rw-r--r--  pkg/sentry/socket/unix/BUILD  39
-rw-r--r--  pkg/sentry/socket/unix/device.go  20
-rw-r--r--  pkg/sentry/socket/unix/io.go  111
-rw-r--r--  pkg/sentry/socket/unix/transport/BUILD  41
-rw-r--r--  pkg/sentry/socket/unix/transport/connectioned.go  486
-rw-r--r--  pkg/sentry/socket/unix/transport/connectioned_state.go  53
-rw-r--r--  pkg/sentry/socket/unix/transport/connectionless.go  218
-rw-r--r--  pkg/sentry/socket/unix/transport/queue.go  247
-rw-r--r--  pkg/sentry/socket/unix/transport/unix.go  1006
-rw-r--r--  pkg/sentry/socket/unix/unix.go  772
-rw-r--r--  pkg/sentry/socket/unix/unix_vfs2.go  371
-rw-r--r--  pkg/sentry/state/BUILD  23
-rw-r--r--  pkg/sentry/state/state.go  119
-rw-r--r--  pkg/sentry/state/state_metadata.go  45
-rw-r--r--  pkg/sentry/state/state_unsafe.go  34
-rw-r--r--  pkg/sentry/strace/BUILD  45
-rw-r--r--  pkg/sentry/strace/capability.go  176
-rw-r--r--  pkg/sentry/strace/clone.go  113
-rw-r--r--  pkg/sentry/strace/epoll.go  89
-rw-r--r--  pkg/sentry/strace/futex.go  52
-rw-r--r--  pkg/sentry/strace/linux64_amd64.go  384
-rw-r--r--  pkg/sentry/strace/linux64_arm64.go  323
-rw-r--r--  pkg/sentry/strace/open.go  96
-rw-r--r--  pkg/sentry/strace/poll.go  71
-rw-r--r--  pkg/sentry/strace/ptrace.go  62
-rw-r--r--  pkg/sentry/strace/select.go  56
-rw-r--r--  pkg/sentry/strace/signal.go  148
-rw-r--r--  pkg/sentry/strace/socket.go  644
-rw-r--r--  pkg/sentry/strace/strace.go  874
-rw-r--r--  pkg/sentry/strace/strace.proto  49
-rw-r--r--  pkg/sentry/strace/syscalls.go  292
-rw-r--r--  pkg/sentry/syscalls/BUILD  21
-rw-r--r--  pkg/sentry/syscalls/epoll.go  173
-rw-r--r--  pkg/sentry/syscalls/linux/BUILD  103
-rw-r--r--  pkg/sentry/syscalls/linux/error.go  157
-rw-r--r--  pkg/sentry/syscalls/linux/flags.go  55
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go  736
-rw-r--r--  pkg/sentry/syscalls/linux/sigset.go  71
-rw-r--r--  pkg/sentry/syscalls/linux/sys_aio.go  382
-rw-r--r--  pkg/sentry/syscalls/linux/sys_capability.go  149
-rw-r--r--  pkg/sentry/syscalls/linux/sys_clone_amd64.go  35
-rw-r--r--  pkg/sentry/syscalls/linux/sys_clone_arm64.go  35
-rw-r--r--  pkg/sentry/syscalls/linux/sys_epoll.go  147
-rw-r--r--  pkg/sentry/syscalls/linux/sys_eventfd.go  56
-rw-r--r--  pkg/sentry/syscalls/linux/sys_file.go  2238
-rw-r--r--  pkg/sentry/syscalls/linux/sys_futex.go  288
-rw-r--r--  pkg/sentry/syscalls/linux/sys_getdents.go  250
-rw-r--r--  pkg/sentry/syscalls/linux/sys_identity.go  180
-rw-r--r--  pkg/sentry/syscalls/linux/sys_inotify.go  133
-rw-r--r--  pkg/sentry/syscalls/linux/sys_lseek.go  58
-rw-r--r--  pkg/sentry/syscalls/linux/sys_mempolicy.go  312
-rw-r--r--  pkg/sentry/syscalls/linux/sys_mmap.go  332
-rw-r--r--  pkg/sentry/syscalls/linux/sys_mount.go  154
-rw-r--r--  pkg/sentry/syscalls/linux/sys_pipe.go  77
-rw-r--r--  pkg/sentry/syscalls/linux/sys_poll.go  545
-rw-r--r--  pkg/sentry/syscalls/linux/sys_prctl.go  228
-rw-r--r--  pkg/sentry/syscalls/linux/sys_random.go  92
-rw-r--r--  pkg/sentry/syscalls/linux/sys_read.go  394
-rw-r--r--  pkg/sentry/syscalls/linux/sys_rlimit.go  224
-rw-r--r--  pkg/sentry/syscalls/linux/sys_rseq.go  48
-rw-r--r--  pkg/sentry/syscalls/linux/sys_rusage.go  112
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sched.go  99
-rw-r--r--  pkg/sentry/syscalls/linux/sys_seccomp.go  76
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sem.go  241
-rw-r--r--  pkg/sentry/syscalls/linux/sys_shm.go  161
-rw-r--r--  pkg/sentry/syscalls/linux/sys_signal.go  590
-rw-r--r--  pkg/sentry/syscalls/linux/sys_socket.go  1138
-rw-r--r--  pkg/sentry/syscalls/linux/sys_splice.go  337
-rw-r--r--  pkg/sentry/syscalls/linux/sys_stat.go  290
-rw-r--r--  pkg/sentry/syscalls/linux/sys_stat_amd64.go  45
-rw-r--r--  pkg/sentry/syscalls/linux/sys_stat_arm64.go  45
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sync.go  141
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sysinfo.go  48
-rw-r--r--  pkg/sentry/syscalls/linux/sys_syslog.go  61
-rw-r--r--  pkg/sentry/syscalls/linux/sys_thread.go  769
-rw-r--r--  pkg/sentry/syscalls/linux/sys_time.go  342
-rw-r--r--  pkg/sentry/syscalls/linux/sys_timer.go  203
-rw-r--r--  pkg/sentry/syscalls/linux/sys_timerfd.go  121
-rw-r--r--  pkg/sentry/syscalls/linux/sys_tls_amd64.go  52
-rw-r--r--  pkg/sentry/syscalls/linux/sys_tls_arm64.go  28
-rw-r--r--  pkg/sentry/syscalls/linux/sys_utsname.go  95
-rw-r--r--  pkg/sentry/syscalls/linux/sys_write.go  364
-rw-r--r--  pkg/sentry/syscalls/linux/sys_xattr.go  432
-rw-r--r--  pkg/sentry/syscalls/linux/timespec.go  111
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/BUILD  76
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/aio.go  216
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/epoll.go  228
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/eventfd.go  61
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/execve.go  137
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/fd.go  355
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/filesystem.go  384
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/fscontext.go  131
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/getdents.go  161
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/inotify.go  137
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/ioctl.go  107
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/lock.go  64
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/memfd.go  63
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/mmap.go  92
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/mount.go  145
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/path.go  94
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/pipe.go  63
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/poll.go  586
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/read_write.go  641
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/setstat.go  428
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/signal.go  100
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/socket.go  1139
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/splice.go  291
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/stat.go  388
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/stat_amd64.go  46
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/stat_arm64.go  46
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/sync.go  115
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/timerfd.go  127
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/vfs2.go  168
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/xattr.go  356
-rw-r--r--  pkg/sentry/syscalls/syscalls.go  111
-rw-r--r--  pkg/sentry/time/BUILD  50
-rw-r--r--  pkg/sentry/time/LICENSE  27
-rw-r--r--  pkg/sentry/time/arith_arm64.go  70
-rw-r--r--  pkg/sentry/time/calibrated_clock.go  269
-rw-r--r--  pkg/sentry/time/calibrated_clock_test.go  186
-rw-r--r--  pkg/sentry/time/clock_id.go  40
-rw-r--r--  pkg/sentry/time/clocks.go  31
-rw-r--r--  pkg/sentry/time/muldiv_amd64.s  44
-rw-r--r--  pkg/sentry/time/muldiv_arm64.s  47
-rw-r--r--  pkg/sentry/time/parameters.go  239
-rw-r--r--  pkg/sentry/time/parameters_test.go  501
-rw-r--r--  pkg/sentry/time/sampler.go  225
-rw-r--r--  pkg/sentry/time/sampler_test.go  183
-rw-r--r--  pkg/sentry/time/sampler_unsafe.go  56
-rw-r--r--  pkg/sentry/time/tsc_amd64.s  27
-rw-r--r--  pkg/sentry/time/tsc_arm64.s  22
-rw-r--r--  pkg/sentry/unimpl/BUILD  20
-rw-r--r--  pkg/sentry/unimpl/events.go  45
-rw-r--r--  pkg/sentry/unimpl/unimplemented_syscall.proto  27
-rw-r--r--  pkg/sentry/uniqueid/BUILD  13
-rw-r--r--  pkg/sentry/uniqueid/context.go  54
-rw-r--r--  pkg/sentry/usage/BUILD  22
-rw-r--r--  pkg/sentry/usage/cpu.go  46
-rw-r--r--  pkg/sentry/usage/io.go  90
-rw-r--r--  pkg/sentry/usage/memory.go  291
-rw-r--r--  pkg/sentry/usage/memory_unsafe.go  27
-rw-r--r--  pkg/sentry/usage/usage.go  16
-rw-r--r--  pkg/sentry/vfs/BUILD  100
-rw-r--r--  pkg/sentry/vfs/README.md  195
-rw-r--r--  pkg/sentry/vfs/anonfs.go  314
-rw-r--r--  pkg/sentry/vfs/context.go  75
-rw-r--r--  pkg/sentry/vfs/debug.go  22
-rw-r--r--  pkg/sentry/vfs/dentry.go  324
-rw-r--r--  pkg/sentry/vfs/device.go  132
-rw-r--r--  pkg/sentry/vfs/epoll.go  383
-rw-r--r--  pkg/sentry/vfs/file_description.go  837
-rw-r--r--  pkg/sentry/vfs/file_description_impl_util.go  428
-rw-r--r--  pkg/sentry/vfs/file_description_impl_util_test.go  224
-rw-r--r--  pkg/sentry/vfs/filesystem.go  556
-rw-r--r--  pkg/sentry/vfs/filesystem_impl_util.go  43
-rw-r--r--  pkg/sentry/vfs/filesystem_type.go  117
-rw-r--r--  pkg/sentry/vfs/g3doc/inotify.md  210
-rw-r--r--  pkg/sentry/vfs/genericfstree/BUILD  16
-rw-r--r--  pkg/sentry/vfs/genericfstree/genericfstree.go  81
-rw-r--r--  pkg/sentry/vfs/inotify.go  774
-rw-r--r--  pkg/sentry/vfs/lock.go  109
-rw-r--r--  pkg/sentry/vfs/memxattr/BUILD  15
-rw-r--r--  pkg/sentry/vfs/memxattr/xattr.go  102
-rw-r--r--  pkg/sentry/vfs/mount.go  903
-rw-r--r--  pkg/sentry/vfs/mount_test.go  458
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go364
-rw-r--r--pkg/sentry/vfs/options.go235
-rw-r--r--pkg/sentry/vfs/pathname.go195
-rw-r--r--pkg/sentry/vfs/permissions.go280
-rw-r--r--pkg/sentry/vfs/resolving_path.go466
-rw-r--r--pkg/sentry/vfs/vfs.go849
-rw-r--r--pkg/sentry/watchdog/BUILD17
-rw-r--r--pkg/sentry/watchdog/watchdog.go374
-rw-r--r--pkg/sleep/BUILD24
-rw-r--r--pkg/sleep/commit_amd64.s35
-rw-r--r--pkg/sleep/commit_arm64.s38
-rw-r--r--pkg/sleep/commit_asm.go20
-rw-r--r--pkg/sleep/commit_noasm.go33
-rw-r--r--pkg/sleep/empty.s15
-rw-r--r--pkg/sleep/sleep_test.go573
-rw-r--r--pkg/sleep/sleep_unsafe.go400
-rw-r--r--pkg/state/BUILD100
-rw-r--r--pkg/state/README.md158
-rw-r--r--pkg/state/decode.go725
-rw-r--r--pkg/state/decode_unsafe.go27
-rw-r--r--pkg/state/encode.go841
-rw-r--r--pkg/state/encode_unsafe.go33
-rw-r--r--pkg/state/pretty/BUILD13
-rw-r--r--pkg/state/pretty/pretty.go273
-rw-r--r--pkg/state/state.go321
-rw-r--r--pkg/state/state_norace.go19
-rw-r--r--pkg/state/state_race.go19
-rw-r--r--pkg/state/statefile/BUILD22
-rw-r--r--pkg/state/statefile/statefile.go239
-rw-r--r--pkg/state/statefile/statefile_test.go290
-rw-r--r--pkg/state/stats.go145
-rw-r--r--pkg/state/tests/BUILD43
-rw-r--r--pkg/state/tests/array.go35
-rw-r--r--pkg/state/tests/array_test.go134
-rw-r--r--pkg/state/tests/bench.go24
-rw-r--r--pkg/state/tests/bench_test.go153
-rw-r--r--pkg/state/tests/bool_test.go31
-rw-r--r--pkg/state/tests/float_test.go118
-rw-r--r--pkg/state/tests/integer.go163
-rw-r--r--pkg/state/tests/integer_test.go94
-rw-r--r--pkg/state/tests/load.go61
-rw-r--r--pkg/state/tests/load_test.go70
-rw-r--r--pkg/state/tests/map.go28
-rw-r--r--pkg/state/tests/map_test.go90
-rw-r--r--pkg/state/tests/register.go21
-rw-r--r--pkg/state/tests/register_test.go167
-rw-r--r--pkg/state/tests/string_test.go34
-rw-r--r--pkg/state/tests/struct.go65
-rw-r--r--pkg/state/tests/struct_test.go89
-rw-r--r--pkg/state/tests/tests.go215
-rw-r--r--pkg/state/types.go361
-rw-r--r--pkg/state/wire/BUILD12
-rw-r--r--pkg/state/wire/wire.go970
-rw-r--r--pkg/sync/BUILD55
-rw-r--r--pkg/sync/LICENSE27
-rw-r--r--pkg/sync/README.md5
-rw-r--r--pkg/sync/aliases.go36
-rw-r--r--pkg/sync/atomicptr_unsafe.go47
-rw-r--r--pkg/sync/atomicptrtest/BUILD27
-rw-r--r--pkg/sync/atomicptrtest/atomicptr_test.go31
-rw-r--r--pkg/sync/memmove_unsafe.go28
-rw-r--r--pkg/sync/mutex_test.go71
-rw-r--r--pkg/sync/mutex_unsafe.go49
-rw-r--r--pkg/sync/norace_unsafe.go35
-rw-r--r--pkg/sync/race_unsafe.go41
-rw-r--r--pkg/sync/rwmutex_test.go205
-rw-r--r--pkg/sync/rwmutex_unsafe.go198
-rw-r--r--pkg/sync/seqatomic_unsafe.go72
-rw-r--r--pkg/sync/seqatomictest/BUILD31
-rw-r--r--pkg/sync/seqatomictest/seqatomic_test.go132
-rw-r--r--pkg/sync/seqcount.go149
-rw-r--r--pkg/sync/seqcount_test.go153
-rw-r--r--pkg/sync/sync.go7
-rw-r--r--pkg/syncevent/BUILD39
-rw-r--r--pkg/syncevent/broadcaster.go218
-rw-r--r--pkg/syncevent/broadcaster_test.go376
-rw-r--r--pkg/syncevent/receiver.go103
-rw-r--r--pkg/syncevent/source.go59
-rw-r--r--pkg/syncevent/syncevent.go32
-rw-r--r--pkg/syncevent/syncevent_example_test.go108
-rw-r--r--pkg/syncevent/waiter_amd64.s32
-rw-r--r--pkg/syncevent/waiter_arm64.s34
-rw-r--r--pkg/syncevent/waiter_asm_unsafe.go24
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_test.go414
-rw-r--r--pkg/syncevent/waiter_unsafe.go206
-rw-r--r--pkg/syserr/BUILD18
-rw-r--r--pkg/syserr/host_linux.go46
-rw-r--r--pkg/syserr/netstack.go103
-rw-r--r--pkg/syserr/syserr.go293
-rw-r--r--pkg/syserror/BUILD17
-rw-r--r--pkg/syserror/syserror.go159
-rw-r--r--pkg/syserror/syserror_test.go136
-rw-r--r--pkg/tcpip/BUILD32
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD37
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go738
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go716
-rw-r--r--pkg/tcpip/buffer/BUILD19
-rw-r--r--pkg/tcpip/buffer/prependable.go85
-rw-r--r--pkg/tcpip/buffer/view.go256
-rw-r--r--pkg/tcpip/buffer/view_test.go521
-rw-r--r--pkg/tcpip/checker/BUILD16
-rw-r--r--pkg/tcpip/checker/checker.go976
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD18
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins.go80
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins_test.go176
-rw-r--r--pkg/tcpip/header/BUILD69
-rw-r--r--pkg/tcpip/header/arp.go100
-rw-r--r--pkg/tcpip/header/checksum.go249
-rw-r--r--pkg/tcpip/header/checksum_test.go171
-rw-r--r--pkg/tcpip/header/eth.go177
-rw-r--r--pkg/tcpip/header/eth_test.go102
-rw-r--r--pkg/tcpip/header/gue.go73
-rw-r--r--pkg/tcpip/header/icmpv4.go170
-rw-r--r--pkg/tcpip/header/icmpv6.go221
-rw-r--r--pkg/tcpip/header/interfaces.go92
-rw-r--r--pkg/tcpip/header/ipv4.go312
-rw-r--r--pkg/tcpip/header/ipv6.go499
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go551
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go992
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go146
-rw-r--r--pkg/tcpip/header/ipv6_test.go417
-rw-r--r--pkg/tcpip/header/ipversion_test.go67
-rw-r--r--pkg/tcpip/header/ndp_neighbor_advert.go110
-rw-r--r--pkg/tcpip/header/ndp_neighbor_solicit.go52
-rw-r--r--pkg/tcpip/header/ndp_options.go899
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go112
-rw-r--r--pkg/tcpip/header/ndp_router_solicit.go36
-rw-r--r--pkg/tcpip/header/ndp_test.go1521
-rw-r--r--pkg/tcpip/header/ndpoptionidentifier_string.go50
-rw-r--r--pkg/tcpip/header/tcp.go621
-rw-r--r--pkg/tcpip/header/tcp_test.go148
-rw-r--r--pkg/tcpip/header/udp.go120
-rw-r--r--pkg/tcpip/link/channel/BUILD15
-rw-r--r--pkg/tcpip/link/channel/channel.go298
-rw-r--r--pkg/tcpip/link/fdbased/BUILD40
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go657
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go502
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go23
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go199
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go23
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go84
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go317
-rw-r--r--pkg/tcpip/link/loopback/BUILD15
-rw-r--r--pkg/tcpip/link/loopback/loopback.go115
-rw-r--r--pkg/tcpip/link/muxed/BUILD28
-rw-r--r--pkg/tcpip/link/muxed/injectable.go137
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go98
-rw-r--r--pkg/tcpip/link/nested/BUILD31
-rw-r--r--pkg/tcpip/link/nested/nested.go131
-rw-r--r--pkg/tcpip/link/nested/nested_test.go105
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD19
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go209
-rw-r--r--pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go84
-rw-r--r--pkg/tcpip/link/rawfile/BUILD20
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s41
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s42
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go31
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go66
-rw-r--r--pkg/tcpip/link/rawfile/errors.go70
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go192
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD41
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD23
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe.go78
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go518
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go35
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/rx.go93
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/tx.go161
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD27
-rw-r--r--pkg/tcpip/link/sharedmem/queue/queue_test.go517
-rw-r--r--pkg/tcpip/link/sharedmem/queue/rx.go221
-rw-r--r--pkg/tcpip/link/sharedmem/queue/tx.go151
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go159
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go289
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go812
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_unsafe.go25
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go272
-rw-r--r--pkg/tcpip/link/sniffer/BUILD20
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go66
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go394
-rw-r--r--pkg/tcpip/link/tun/BUILD25
-rw-r--r--pkg/tcpip/link/tun/device.go358
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/link/tun/tun_unsafe.go63
-rw-r--r--pkg/tcpip/link/waitable/BUILD30
-rw-r--r--pkg/tcpip/link/waitable/waitable.go149
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go173
-rw-r--r--pkg/tcpip/network/BUILD22
-rw-r--r--pkg/tcpip/network/arp/BUILD32
-rw-r--r--pkg/tcpip/network/arp/arp.go224
-rw-r--r--pkg/tcpip/network/arp/arp_test.go146
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD45
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go144
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go165
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go118
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go105
-rw-r--r--pkg/tcpip/network/hash/BUILD13
-rw-r--r--pkg/tcpip/network/hash/hash.go93
-rw-r--r--pkg/tcpip/network/ip_test.go673
-rw-r--r--pkg/tcpip/network/ipv4/BUILD39
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go167
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go594
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go745
-rw-r--r--pkg/tcpip/network/ipv6/BUILD44
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go549
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go953
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go599
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go1265
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go907
-rw-r--r--pkg/tcpip/ports/BUILD22
-rw-r--r--pkg/tcpip/ports/ports.go554
-rw-r--r--pkg/tcpip/ports/ports_test.go450
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD22
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go225
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD21
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go203
-rw-r--r--pkg/tcpip/seqnum/BUILD9
-rw-r--r--pkg/tcpip/seqnum/seqnum.go62
-rw-r--r--pkg/tcpip/stack/BUILD118
-rw-r--r--pkg/tcpip/stack/conntrack.go331
-rw-r--r--pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go40
-rw-r--r--pkg/tcpip/stack/forwarder.go131
-rw-r--r--pkg/tcpip/stack/forwarder_test.go650
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go41
-rw-r--r--pkg/tcpip/stack/iptables.go367
-rw-r--r--pkg/tcpip/stack/iptables_targets.go164
-rw-r--r--pkg/tcpip/stack/iptables_types.go253
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go295
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go277
-rw-r--r--pkg/tcpip/stack/ndp.go1981
-rw-r--r--pkg/tcpip/stack/ndp_test.go5363
-rw-r--r--pkg/tcpip/stack/nic.go1743
-rw-r--r--pkg/tcpip/stack/nic_test.go318
-rw-r--r--pkg/tcpip/stack/packet_buffer.go115
-rw-r--r--pkg/tcpip/stack/rand.go40
-rw-r--r--pkg/tcpip/stack/registration.go560
-rw-r--r--pkg/tcpip/stack/route.go289
-rw-r--r--pkg/tcpip/stack/stack.go1938
-rw-r--r--pkg/tcpip/stack/stack_global_state.go19
-rw-r--r--pkg/tcpip/stack/stack_options.go106
-rw-r--r--pkg/tcpip/stack/stack_test.go3420
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go686
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go390
-rw-r--r--pkg/tcpip/stack/transport_test.go664
-rw-r--r--pkg/tcpip/tcpip.go1616
-rw-r--r--pkg/tcpip/tcpip_test.go228
-rw-r--r--pkg/tcpip/time.s15
-rw-r--r--pkg/tcpip/time_unsafe.go47
-rw-r--r--pkg/tcpip/timer.go184
-rw-r--r--pkg/tcpip/timer_test.go261
-rw-r--r--pkg/tcpip/transport/icmp/BUILD40
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go831
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go95
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go145
-rw-r--r--pkg/tcpip/transport/packet/BUILD37
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go469
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go72
-rw-r--r--pkg/tcpip/transport/raw/BUILD39
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go729
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go94
-rw-r--r--pkg/tcpip/transport/raw/protocol.go35
-rw-r--r--pkg/tcpip/transport/tcp/BUILD126
-rw-r--r--pkg/tcpip/transport/tcp/accept.go752
-rw-r--r--pkg/tcpip/transport/tcp/connect.go1713
-rw-r--r--pkg/tcpip/transport/tcp/connect_unsafe.go30
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go234
-rw-r--r--pkg/tcpip/transport/tcp/cubic_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go234
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go651
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2888
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go348
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go169
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go541
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go475
-rw-r--r--pkg/tcpip/transport/tcp/rcv_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/reno.go103
-rw-r--r--pkg/tcpip/transport/tcp/sack.go105
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go306
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard_test.go249
-rw-r--r--pkg/tcpip/transport/tcp/segment.go194
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go51
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go85
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go82
-rw-r--r--pkg/tcpip/transport/tcp/snd.go1487
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go60
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go550
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go589
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go7258
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go291
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD26
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go1121
-rw-r--r--pkg/tcpip/transport/tcp/timer.go142
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go47
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD23
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go352
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go511
-rw-r--r--pkg/tcpip/transport/udp/BUILD60
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go1497
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go137
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go96
-rw-r--r--pkg/tcpip/transport/udp/protocol.go231
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go2072
-rw-r--r--pkg/test/criutil/BUILD14
-rw-r--r--pkg/test/criutil/criutil.go317
-rw-r--r--pkg/test/dockerutil/BUILD25
-rw-r--r--pkg/test/dockerutil/container.go501
-rw-r--r--pkg/test/dockerutil/dockerutil.go121
-rw-r--r--pkg/test/dockerutil/exec.go194
-rw-r--r--pkg/test/dockerutil/network.go113
-rw-r--r--pkg/test/testutil/BUILD20
-rw-r--r--pkg/test/testutil/testutil.go536
-rw-r--r--pkg/test/testutil/testutil_runfiles.go75
-rw-r--r--pkg/unet/BUILD26
-rw-r--r--pkg/unet/unet.go569
-rw-r--r--pkg/unet/unet_test.go736
-rw-r--r--pkg/unet/unet_unsafe.go288
-rw-r--r--pkg/urpc/BUILD23
-rw-r--r--pkg/urpc/urpc.go636
-rw-r--r--pkg/urpc/urpc_test.go210
-rw-r--r--pkg/usermem/BUILD55
-rw-r--r--pkg/usermem/README.md31
-rw-r--r--pkg/usermem/access_type.go128
-rw-r--r--pkg/usermem/addr.go125
-rw-r--r--pkg/usermem/addr_range_seq_test.go197
-rw-r--r--pkg/usermem/addr_range_seq_unsafe.go277
-rw-r--r--pkg/usermem/bytes_io.go141
-rw-r--r--pkg/usermem/bytes_io_unsafe.go47
-rw-r--r--pkg/usermem/usermem.go595
-rw-r--r--pkg/usermem/usermem_arm64.go53
-rw-r--r--pkg/usermem/usermem_test.go424
-rw-r--r--pkg/usermem/usermem_x86.go38
-rw-r--r--pkg/waiter/BUILD35
-rw-r--r--pkg/waiter/waiter.go244
-rw-r--r--pkg/waiter/waiter_test.go192
-rw-r--r--runsc/BUILD123
-rw-r--r--runsc/boot/BUILD137
-rw-r--r--runsc/boot/compat.go202
-rw-r--r--runsc/boot/compat_amd64.go100
-rw-r--r--runsc/boot/compat_arm64.go95
-rw-r--r--runsc/boot/compat_test.go90
-rw-r--r--runsc/boot/config.go329
-rw-r--r--runsc/boot/controller.go506
-rw-r--r--runsc/boot/debug.go29
-rw-r--r--runsc/boot/events.go81
-rw-r--r--runsc/boot/filter/BUILD28
-rw-r--r--runsc/boot/filter/config.go559
-rw-r--r--runsc/boot/filter/config_amd64.go31
-rw-r--r--runsc/boot/filter/config_arm64.go21
-rw-r--r--runsc/boot/filter/config_profile.go34
-rw-r--r--runsc/boot/filter/extra_filters.go28
-rw-r--r--runsc/boot/filter/extra_filters_msan.go34
-rw-r--r--runsc/boot/filter/extra_filters_race.go41
-rw-r--r--runsc/boot/filter/filter.go60
-rw-r--r--runsc/boot/fs.go1034
-rw-r--r--runsc/boot/fs_test.go250
-rw-r--r--runsc/boot/limits.go154
-rw-r--r--runsc/boot/loader.go1284
-rw-r--r--runsc/boot/loader_test.go715
-rw-r--r--runsc/boot/network.go341
-rw-r--r--runsc/boot/platforms/BUILD15
-rw-r--r--runsc/boot/platforms/platforms.go30
-rw-r--r--runsc/boot/pprof/BUILD11
-rw-r--r--runsc/boot/pprof/pprof.go20
-rw-r--r--runsc/boot/strace.go40
-rw-r--r--runsc/boot/vfs.go482
-rw-r--r--runsc/cgroup/BUILD27
-rw-r--r--runsc/cgroup/cgroup.go576
-rw-r--r--runsc/cgroup/cgroup_test.go649
-rw-r--r--runsc/cmd/BUILD95
-rw-r--r--runsc/cmd/boot.go290
-rw-r--r--runsc/cmd/capability.go157
-rw-r--r--runsc/cmd/capability_test.go127
-rw-r--r--runsc/cmd/checkpoint.go155
-rw-r--r--runsc/cmd/chroot.go97
-rw-r--r--runsc/cmd/cmd.go98
-rw-r--r--runsc/cmd/create.go115
-rw-r--r--runsc/cmd/debug.go304
-rw-r--r--runsc/cmd/delete.go87
-rw-r--r--runsc/cmd/delete_test.go41
-rw-r--r--runsc/cmd/do.go385
-rw-r--r--runsc/cmd/error.go72
-rw-r--r--runsc/cmd/events.go111
-rw-r--r--runsc/cmd/exec.go481
-rw-r--r--runsc/cmd/exec_test.go154
-rw-r--r--runsc/cmd/gofer.go484
-rw-r--r--runsc/cmd/gofer_test.go164
-rw-r--r--runsc/cmd/help.go120
-rw-r--r--runsc/cmd/install.go210
-rw-r--r--runsc/cmd/kill.go154
-rw-r--r--runsc/cmd/list.go117
-rw-r--r--runsc/cmd/path.go28
-rw-r--r--runsc/cmd/pause.go68
-rw-r--r--runsc/cmd/ps.go86
-rw-r--r--runsc/cmd/restore.go119
-rw-r--r--runsc/cmd/resume.go69
-rw-r--r--runsc/cmd/run.go100
-rw-r--r--runsc/cmd/spec.go206
-rw-r--r--runsc/cmd/start.go65
-rw-r--r--runsc/cmd/state.go76
-rw-r--r--runsc/cmd/statefile.go149
-rw-r--r--runsc/cmd/syscalls.go356
-rw-r--r--runsc/cmd/wait.go127
-rw-r--r--runsc/console/BUILD17
-rw-r--r--runsc/console/console.go63
-rw-r--r--runsc/container/BUILD74
-rw-r--r--runsc/container/console_test.go480
-rw-r--r--runsc/container/container.go1171
-rw-r--r--runsc/container/container_norace_test.go20
-rw-r--r--runsc/container/container_race_test.go20
-rw-r--r--runsc/container/container_test.go2348
-rw-r--r--runsc/container/hook.go111
-rw-r--r--runsc/container/multi_container_test.go1774
-rw-r--r--runsc/container/shared_volume_test.go273
-rw-r--r--runsc/container/state_file.go185
-rw-r--r--runsc/container/status.go60
-rw-r--r--runsc/debian/description1
-rwxr-xr-xrunsc/debian/postinst.sh24
-rw-r--r--runsc/flag/BUILD9
-rw-r--r--runsc/flag/flag.go33
-rw-r--r--runsc/fsgofer/BUILD35
-rw-r--r--runsc/fsgofer/filter/BUILD26
-rw-r--r--runsc/fsgofer/filter/config.go250
-rw-r--r--runsc/fsgofer/filter/config_amd64.go33
-rw-r--r--runsc/fsgofer/filter/config_arm64.go27
-rw-r--r--runsc/fsgofer/filter/extra_filters.go28
-rw-r--r--runsc/fsgofer/filter/extra_filters_msan.go33
-rw-r--r--runsc/fsgofer/filter/extra_filters_race.go42
-rw-r--r--runsc/fsgofer/filter/filter.go38
-rw-r--r--runsc/fsgofer/fsgofer.go1181
-rw-r--r--runsc/fsgofer/fsgofer_amd64_unsafe.go49
-rw-r--r--runsc/fsgofer/fsgofer_arm64_unsafe.go49
-rw-r--r--runsc/fsgofer/fsgofer_test.go692
-rw-r--r--runsc/fsgofer/fsgofer_unsafe.go82
-rw-r--r--runsc/main.go372
-rw-r--r--runsc/sandbox/BUILD37
-rw-r--r--runsc/sandbox/network.go411
-rw-r--r--runsc/sandbox/network_unsafe.go56
-rw-r--r--runsc/sandbox/sandbox.go1228
-rw-r--r--runsc/specutils/BUILD33
-rw-r--r--runsc/specutils/cri.go110
-rw-r--r--runsc/specutils/fs.go155
-rw-r--r--runsc/specutils/namespace.go289
-rw-r--r--runsc/specutils/specutils.go523
-rw-r--r--runsc/specutils/specutils_test.go265
-rw-r--r--runsc/version.go18
-rwxr-xr-xrunsc/version_test.sh36
-rwxr-xr-xscripts/benchmark.sh45
-rwxr-xr-xscripts/common.sh86
-rwxr-xr-xscripts/common_build.sh116
-rwxr-xr-xscripts/dev.sh75
-rwxr-xr-xscripts/do_tests.sh27
-rwxr-xr-xscripts/docker_tests.sh25
-rwxr-xr-xscripts/go.sh45
-rwxr-xr-xscripts/hostnet_tests.sh23
-rwxr-xr-xscripts/iptables_tests.sh26
-rwxr-xr-xscripts/issue_reviver.sh27
-rwxr-xr-xscripts/kvm_tests.sh30
-rwxr-xr-xscripts/make_tests.sh20
-rwxr-xr-xscripts/overlay_tests.sh23
-rwxr-xr-xscripts/packetdrill_tests.sh23
-rwxr-xr-xscripts/packetimpact_tests.sh23
-rwxr-xr-xscripts/root_tests.sh32
-rwxr-xr-xscripts/runtime_tests.sh26
-rwxr-xr-xscripts/simple_tests.sh20
-rwxr-xr-xscripts/swgso_tests.sh23
-rwxr-xr-xscripts/syscall_kvm_tests.sh20
-rwxr-xr-xscripts/syscall_tests.sh20
-rw-r--r--test/BUILD1
-rw-r--r--test/README.md40
-rw-r--r--test/cmd/test_app/BUILD21
-rw-r--r--test/cmd/test_app/fds.go185
-rw-r--r--test/cmd/test_app/test_app.go394
-rw-r--r--test/e2e/BUILD33
-rw-r--r--test/e2e/exec_test.go268
-rw-r--r--test/e2e/integration.go16
-rw-r--r--test/e2e/integration_test.go441
-rw-r--r--test/e2e/regression_test.go47
-rw-r--r--test/image/BUILD33
-rw-r--r--test/image/image.go16
-rw-r--r--test/image/image_test.go312
-rw-r--r--test/image/latin10k.txt33
-rw-r--r--test/image/mysql.sql23
-rw-r--r--test/image/ruby.rb23
-rwxr-xr-xtest/image/ruby.sh20
-rw-r--r--test/iptables/BUILD36
-rw-r--r--test/iptables/README.md54
-rw-r--r--test/iptables/filter_input.go729
-rw-r--r--test/iptables/filter_output.go607
-rw-r--r--test/iptables/iptables.go60
-rw-r--r--test/iptables/iptables_test.go345
-rw-r--r--test/iptables/iptables_util.go201
-rw-r--r--test/iptables/nat.go439
-rw-r--r--test/iptables/runner/BUILD12
-rw-r--r--test/iptables/runner/main.go73
-rw-r--r--test/packetdrill/BUILD38
-rw-r--r--test/packetdrill/accept_ack_drop.pkt27
-rw-r--r--test/packetdrill/defs.bzl87
-rw-r--r--test/packetdrill/fin_wait2_timeout.pkt23
-rw-r--r--test/packetdrill/listen_close_before_handshake_complete.pkt31
-rw-r--r--test/packetdrill/no_rst_to_rst.pkt36
-rwxr-xr-xtest/packetdrill/packetdrill_setup.sh26
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh226
-rw-r--r--test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt9
-rw-r--r--test/packetdrill/sanity_test.pkt7
-rw-r--r--test/packetdrill/tcp_defer_accept.pkt48
-rw-r--r--test/packetdrill/tcp_defer_accept_timeout.pkt48
-rw-r--r--test/packetimpact/README.md702
-rw-r--r--test/packetimpact/dut/BUILD18
-rw-r--r--test/packetimpact/dut/posix_server.cc365
-rw-r--r--test/packetimpact/netdevs/BUILD15
-rw-r--r--test/packetimpact/netdevs/netdevs.go104
-rw-r--r--test/packetimpact/proto/BUILD12
-rw-r--r--test/packetimpact/proto/posix_server.proto230
-rw-r--r--test/packetimpact/runner/BUILD21
-rw-r--r--test/packetimpact/runner/defs.bzl136
-rw-r--r--test/packetimpact/runner/packetimpact_test.go370
-rw-r--r--test/packetimpact/testbench/BUILD46
-rw-r--r--test/packetimpact/testbench/connections.go950
-rw-r--r--test/packetimpact/testbench/dut.go658
-rw-r--r--test/packetimpact/testbench/dut_client.go28
-rw-r--r--test/packetimpact/testbench/layers.go1384
-rw-r--r--test/packetimpact/testbench/layers_test.go618
-rw-r--r--test/packetimpact/testbench/rawsockets.go178
-rw-r--r--test/packetimpact/testbench/testbench.go106
-rw-r--r--test/packetimpact/tests/BUILD264
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go75
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go78
-rw-r--r--test/packetimpact/tests/ipv4_id_uniqueness_test.go122
-rw-r--r--test/packetimpact/tests/ipv6_unknown_options_action_test.go187
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go108
-rw-r--r--test/packetimpact/tests/tcp_cork_mss_test.go84
-rw-r--r--test/packetimpact/tests/tcp_handshake_window_size_test.go66
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go42
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go93
-rw-r--r--test/packetimpact/tests/tcp_paws_mechanism_test.go109
-rw-r--r--test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go132
-rw-r--r--test/packetimpact/tests/tcp_reordering_test.go174
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go84
-rw-r--r--test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go105
-rw-r--r--test/packetimpact/tests/tcp_synrcvd_reset_test.go52
-rw-r--r--test/packetimpact/tests/tcp_synsent_reset_test.go88
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go105
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go73
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go105
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go112
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go98
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go365
-rw-r--r--test/packetimpact/tests/udp_recv_multicast_test.go40
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go90
-rw-r--r--test/perf/BUILD117
-rw-r--r--test/perf/linux/BUILD356
-rw-r--r--test/perf/linux/clock_getres_benchmark.cc39
-rw-r--r--test/perf/linux/clock_gettime_benchmark.cc60
-rw-r--r--test/perf/linux/death_benchmark.cc36
-rw-r--r--test/perf/linux/epoll_benchmark.cc99
-rw-r--r--test/perf/linux/fork_benchmark.cc350
-rw-r--r--test/perf/linux/futex_benchmark.cc198
-rw-r--r--test/perf/linux/getdents_benchmark.cc149
-rw-r--r--test/perf/linux/getpid_benchmark.cc37
-rw-r--r--test/perf/linux/gettid_benchmark.cc38
-rw-r--r--test/perf/linux/mapping_benchmark.cc163
-rw-r--r--test/perf/linux/open_benchmark.cc56
-rw-r--r--test/perf/linux/pipe_benchmark.cc66
-rw-r--r--test/perf/linux/randread_benchmark.cc100
-rw-r--r--test/perf/linux/read_benchmark.cc53
-rw-r--r--test/perf/linux/sched_yield_benchmark.cc37
-rw-r--r--test/perf/linux/send_recv_benchmark.cc372
-rw-r--r--test/perf/linux/seqwrite_benchmark.cc66
-rw-r--r--test/perf/linux/signal_benchmark.cc61
-rw-r--r--test/perf/linux/sleep_benchmark.cc60
-rw-r--r--test/perf/linux/stat_benchmark.cc62
-rw-r--r--test/perf/linux/unlink_benchmark.cc66
-rw-r--r--test/perf/linux/write_benchmark.cc52
-rw-r--r--test/root/BUILD58
-rw-r--r--test/root/cgroup_test.go359
-rw-r--r--test/root/chroot_test.go151
-rw-r--r--test/root/crictl_test.go393
-rw-r--r--test/root/main_test.go49
-rw-r--r--test/root/oom_score_adj_test.go366
-rw-r--r--test/root/root.go21
-rw-r--r--test/root/runsc_test.go151
-rw-r--r--test/runner/BUILD22
-rw-r--r--test/runner/defs.bzl238
-rw-r--r--test/runner/gtest/BUILD9
-rw-r--r--test/runner/gtest/gtest.go168
-rw-r--r--test/runner/runner.go497
-rw-r--r--test/runtimes/BUILD33
-rw-r--r--test/runtimes/defs.bzl79
-rw-r--r--test/runtimes/exclude_go1.12.csv16
-rw-r--r--test/runtimes/exclude_java11.csv126
-rw-r--r--test/runtimes/exclude_nodejs12.4.0.csv47
-rw-r--r--test/runtimes/exclude_php7.3.6.csv29
-rw-r--r--test/runtimes/exclude_python3.7.3.csv27
-rw-r--r--test/runtimes/proctor/BUILD28
-rw-r--r--test/runtimes/proctor/go.go90
-rw-r--r--test/runtimes/proctor/java.go71
-rw-r--r--test/runtimes/proctor/nodejs.go46
-rw-r--r--test/runtimes/proctor/php.go42
-rw-r--r--test/runtimes/proctor/proctor.go163
-rw-r--r--test/runtimes/proctor/proctor_test.go127
-rw-r--r--test/runtimes/proctor/python.go49
-rw-r--r--test/runtimes/runner/BUILD21
-rw-r--r--test/runtimes/runner/exclude_test.go37
-rw-r--r--test/runtimes/runner/main.go197
-rw-r--r--test/syscalls/BUILD1121
-rw-r--r--test/syscalls/README.md107
-rw-r--r--test/syscalls/linux/32bit.cc248
-rw-r--r--test/syscalls/linux/BUILD3933
-rw-r--r--test/syscalls/linux/accept_bind.cc641
-rw-r--r--test/syscalls/linux/accept_bind_stream.cc92
-rw-r--r--test/syscalls/linux/access.cc170
-rw-r--r--test/syscalls/linux/affinity.cc242
-rw-r--r--test/syscalls/linux/aio.cc430
-rw-r--r--test/syscalls/linux/alarm.cc192
-rw-r--r--test/syscalls/linux/arch_prctl.cc48
-rw-r--r--test/syscalls/linux/bad.cc45
-rw-r--r--test/syscalls/linux/base_poll_test.cc65
-rw-r--r--test/syscalls/linux/base_poll_test.h101
-rw-r--r--test/syscalls/linux/bind.cc145
-rw-r--r--test/syscalls/linux/brk.cc31
-rw-r--r--test/syscalls/linux/chdir.cc64
-rw-r--r--test/syscalls/linux/chmod.cc264
-rw-r--r--test/syscalls/linux/chown.cc206
-rw-r--r--test/syscalls/linux/chroot.cc366
-rw-r--r--test/syscalls/linux/clock_getres.cc37
-rw-r--r--test/syscalls/linux/clock_gettime.cc163
-rw-r--r--test/syscalls/linux/clock_nanosleep.cc179
-rw-r--r--test/syscalls/linux/concurrency.cc127
-rw-r--r--test/syscalls/linux/connect_external.cc163
-rw-r--r--test/syscalls/linux/creat.cc68
-rw-r--r--test/syscalls/linux/dev.cc167
-rw-r--r--test/syscalls/linux/dup.cc133
-rw-r--r--test/syscalls/linux/epoll.cc428
-rw-r--r--test/syscalls/linux/eventfd.cc222
-rw-r--r--test/syscalls/linux/exceptions.cc367
-rw-r--r--test/syscalls/linux/exec.cc904
-rw-r--r--test/syscalls/linux/exec.h34
-rw-r--r--test/syscalls/linux/exec_assert_closed_workload.cc45
-rw-r--r--test/syscalls/linux/exec_basic_workload.cc31
-rw-r--r--test/syscalls/linux/exec_binary.cc1646
-rw-r--r--test/syscalls/linux/exec_proc_exe_workload.cc42
-rw-r--r--test/syscalls/linux/exec_state_workload.cc202
-rw-r--r--test/syscalls/linux/exit.cc78
-rwxr-xr-xtest/syscalls/linux/exit_script.sh22
-rw-r--r--test/syscalls/linux/fadvise64.cc72
-rw-r--r--test/syscalls/linux/fallocate.cc186
-rw-r--r--test/syscalls/linux/fault.cc74
-rw-r--r--test/syscalls/linux/fchdir.cc77
-rw-r--r--test/syscalls/linux/fcntl.cc1353
-rw-r--r--test/syscalls/linux/file_base.h100
-rw-r--r--test/syscalls/linux/flock.cc636
-rw-r--r--test/syscalls/linux/fork.cc464
-rw-r--r--test/syscalls/linux/fpsig_fork.cc131
-rw-r--r--test/syscalls/linux/fpsig_nested.cc167
-rw-r--r--test/syscalls/linux/fsync.cc58
-rw-r--r--test/syscalls/linux/futex.cc742
-rw-r--r--test/syscalls/linux/getcpu.cc40
-rw-r--r--test/syscalls/linux/getdents.cc539
-rw-r--r--test/syscalls/linux/getrandom.cc63
-rw-r--r--test/syscalls/linux/getrusage.cc177
-rw-r--r--test/syscalls/linux/inotify.cc2380
-rw-r--r--test/syscalls/linux/ioctl.cc406
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc239
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h135
-rw-r--r--test/syscalls/linux/iptables.cc204
-rw-r--r--test/syscalls/linux/iptables.h198
-rw-r--r--test/syscalls/linux/itimer.cc366
-rw-r--r--test/syscalls/linux/kill.cc383
-rw-r--r--test/syscalls/linux/link.cc305
-rw-r--r--test/syscalls/linux/lseek.cc202
-rw-r--r--test/syscalls/linux/madvise.cc251
-rw-r--r--test/syscalls/linux/memfd.cc557
-rw-r--r--test/syscalls/linux/memory_accounting.cc99
-rw-r--r--test/syscalls/linux/mempolicy.cc289
-rw-r--r--test/syscalls/linux/mincore.cc96
-rw-r--r--test/syscalls/linux/mkdir.cc88
-rw-r--r--test/syscalls/linux/mknod.cc190
-rw-r--r--test/syscalls/linux/mlock.cc332
-rw-r--r--test/syscalls/linux/mmap.cc1676
-rw-r--r--test/syscalls/linux/mount.cc327
-rw-r--r--test/syscalls/linux/mremap.cc492
-rw-r--r--test/syscalls/linux/msync.cc151
-rw-r--r--test/syscalls/linux/munmap.cc53
-rw-r--r--test/syscalls/linux/network_namespace.cc52
-rw-r--r--test/syscalls/linux/open.cc451
-rw-r--r--test/syscalls/linux/open_create.cc155
-rw-r--r--test/syscalls/linux/packet_socket.cc440
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc565
-rw-r--r--test/syscalls/linux/partial_bad_buffer.cc405
-rw-r--r--test/syscalls/linux/pause.cc88
-rw-r--r--test/syscalls/linux/ping_socket.cc91
-rw-r--r--test/syscalls/linux/pipe.cc670
-rw-r--r--test/syscalls/linux/poll.cc294
-rw-r--r--test/syscalls/linux/ppoll.cc155
-rw-r--r--test/syscalls/linux/prctl.cc230
-rw-r--r--test/syscalls/linux/prctl_setuid.cc268
-rw-r--r--test/syscalls/linux/pread64.cc167
-rw-r--r--test/syscalls/linux/preadv.cc95
-rw-r--r--test/syscalls/linux/preadv2.cc280
-rw-r--r--test/syscalls/linux/priority.cc216
-rw-r--r--test/syscalls/linux/priority_execve.cc42
-rw-r--r--test/syscalls/linux/proc.cc2173
-rw-r--r--test/syscalls/linux/proc_net.cc482
-rw-r--r--test/syscalls/linux/proc_net_tcp.cc496
-rw-r--r--test/syscalls/linux/proc_net_udp.cc309
-rw-r--r--test/syscalls/linux/proc_net_unix.cc443
-rw-r--r--test/syscalls/linux/proc_pid_oomscore.cc72
-rw-r--r--test/syscalls/linux/proc_pid_smaps.cc468
-rw-r--r--test/syscalls/linux/proc_pid_uid_gid_map.cc311
-rw-r--r--test/syscalls/linux/pselect.cc190
-rw-r--r--test/syscalls/linux/ptrace.cc1229
-rw-r--r--test/syscalls/linux/pty.cc1627
-rw-r--r--test/syscalls/linux/pty_root.cc78
-rw-r--r--test/syscalls/linux/pwrite64.cc83
-rw-r--r--test/syscalls/linux/pwritev2.cc307
-rw-r--r--test/syscalls/linux/raw_socket.cc819
-rw-r--r--test/syscalls/linux/raw_socket_hdrincl.cc406
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc514
-rw-r--r--test/syscalls/linux/read.cc118
-rw-r--r--test/syscalls/linux/readahead.cc91
-rw-r--r--test/syscalls/linux/readv.cc294
-rw-r--r--test/syscalls/linux/readv_common.cc220
-rw-r--r--test/syscalls/linux/readv_common.h61
-rw-r--r--test/syscalls/linux/readv_socket.cc212
-rw-r--r--test/syscalls/linux/rename.cc394
-rw-r--r--test/syscalls/linux/rlimits.cc75
-rw-r--r--test/syscalls/linux/rseq.cc198
-rw-r--r--test/syscalls/linux/rseq/BUILD61
-rw-r--r--test/syscalls/linux/rseq/critical.h39
-rw-r--r--test/syscalls/linux/rseq/critical_amd64.S66
-rw-r--r--test/syscalls/linux/rseq/critical_arm64.S66
-rw-r--r--test/syscalls/linux/rseq/rseq.cc366
-rw-r--r--test/syscalls/linux/rseq/start_amd64.S45
-rw-r--r--test/syscalls/linux/rseq/start_arm64.S45
-rw-r--r--test/syscalls/linux/rseq/syscalls.h69
-rw-r--r--test/syscalls/linux/rseq/test.h43
-rw-r--r--test/syscalls/linux/rseq/types.h31
-rw-r--r--test/syscalls/linux/rseq/uapi.h51
-rw-r--r--test/syscalls/linux/rtsignal.cc171
-rw-r--r--test/syscalls/linux/sched.cc71
-rw-r--r--test/syscalls/linux/sched_yield.cc33
-rw-r--r--test/syscalls/linux/seccomp.cc425
-rw-r--r--test/syscalls/linux/select.cc168
-rw-r--r--test/syscalls/linux/semaphore.cc491
-rw-r--r--test/syscalls/linux/sendfile.cc587
-rw-r--r--test/syscalls/linux/sendfile_socket.cc231
-rw-r--r--test/syscalls/linux/shm.cc508
-rw-r--r--test/syscalls/linux/sigaction.cc79
-rw-r--r--test/syscalls/linux/sigaltstack.cc268
-rw-r--r--test/syscalls/linux/sigaltstack_check.cc33
-rw-r--r--test/syscalls/linux/sigiret.cc136
-rw-r--r--test/syscalls/linux/signalfd.cc373
-rw-r--r--test/syscalls/linux/sigprocmask.cc269
-rw-r--r--test/syscalls/linux/sigstop.cc151
-rw-r--r--test/syscalls/linux/sigtimedwait.cc323
-rw-r--r--test/syscalls/linux/socket.cc121
-rw-r--r--test/syscalls/linux/socket_abstract.cc49
-rw-r--r--test/syscalls/linux/socket_bind_to_device.cc313
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc401
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc513
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.cc75
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.h67
-rw-r--r--test/syscalls/linux/socket_blocking.cc60
-rw-r--r--test/syscalls/linux/socket_blocking.h29
-rw-r--r--test/syscalls/linux/socket_capability.cc61
-rw-r--r--test/syscalls/linux/socket_filesystem.cc49
-rw-r--r--test/syscalls/linux/socket_generic.cc820
-rw-r--r--test/syscalls/linux/socket_generic.h30
-rw-r--r--test/syscalls/linux/socket_generic_stress.cc83
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc2566
-rw-r--r--test/syscalls/linux/socket_inet_loopback_nogotsan.cc171
-rw-r--r--test/syscalls/linux/socket_ip_loopback_blocking.cc49
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc1054
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.h29
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic_loopback.cc45
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback.cc40
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc45
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc44
-rw-r--r--test/syscalls/linux/socket_ip_tcp_udp_generic.cc77
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc452
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.h29
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback.cc50
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback_blocking.cc39
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc39
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc474
-rw-r--r--test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc66
-rw-r--r--test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h30
-rw-r--r--test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc39
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc2456
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.h29
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc1099
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h46
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc39
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc32
-rw-r--r--test/syscalls/linux/socket_netdevice.cc184
-rw-r--r--test/syscalls/linux/socket_netlink.cc153
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc935
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc162
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h55
-rw-r--r--test/syscalls/linux/socket_netlink_uevent.cc83
-rw-r--r--test/syscalls/linux/socket_netlink_util.cc187
-rw-r--r--test/syscalls/linux/socket_netlink_util.h62
-rw-r--r--test/syscalls/linux/socket_non_blocking.cc62
-rw-r--r--test/syscalls/linux/socket_non_blocking.h29
-rw-r--r--test/syscalls/linux/socket_non_stream.cc337
-rw-r--r--test/syscalls/linux/socket_non_stream.h29
-rw-r--r--test/syscalls/linux/socket_non_stream_blocking.cc85
-rw-r--r--test/syscalls/linux/socket_non_stream_blocking.h30
-rw-r--r--test/syscalls/linux/socket_stream.cc178
-rw-r--r--test/syscalls/linux/socket_stream.h30
-rw-r--r--test/syscalls/linux/socket_stream_blocking.cc163
-rw-r--r--test/syscalls/linux/socket_stream_blocking.h30
-rw-r--r--test/syscalls/linux/socket_stream_nonblock.cc49
-rw-r--r--test/syscalls/linux/socket_stream_nonblock.h30
-rw-r--r--test/syscalls/linux/socket_test_util.cc907
-rw-r--r--test/syscalls/linux/socket_test_util.h518
-rw-r--r--test/syscalls/linux/socket_test_util_impl.cc28
-rw-r--r--test/syscalls/linux/socket_unix.cc274
-rw-r--r--test/syscalls/linux/socket_unix.h29
-rw-r--r--test/syscalls/linux/socket_unix_abstract_nonblock.cc39
-rw-r--r--test/syscalls/linux/socket_unix_blocking_local.cc45
-rw-r--r--test/syscalls/linux/socket_unix_cmsg.cc1501
-rw-r--r--test/syscalls/linux/socket_unix_cmsg.h30
-rw-r--r--test/syscalls/linux/socket_unix_dgram.cc45
-rw-r--r--test/syscalls/linux/socket_unix_dgram.h29
-rw-r--r--test/syscalls/linux/socket_unix_dgram_local.cc58
-rw-r--r--test/syscalls/linux/socket_unix_dgram_non_blocking.cc57
-rw-r--r--test/syscalls/linux/socket_unix_domain.cc39
-rw-r--r--test/syscalls/linux/socket_unix_filesystem_nonblock.cc39
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc256
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.h30
-rw-r--r--test/syscalls/linux/socket_unix_non_stream_blocking_local.cc42
-rw-r--r--test/syscalls/linux/socket_unix_pair.cc44
-rw-r--r--test/syscalls/linux/socket_unix_pair_nonblock.cc39
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket.cc67
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket.h30
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket_local.cc58
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc125
-rw-r--r--test/syscalls/linux/socket_unix_stream_blocking_local.cc40
-rw-r--r--test/syscalls/linux/socket_unix_stream_local.cc48
-rw-r--r--test/syscalls/linux/socket_unix_stream_nonblock_local.cc39
-rw-r--r--test/syscalls/linux/socket_unix_unbound_abstract.cc116
-rw-r--r--test/syscalls/linux/socket_unix_unbound_dgram.cc183
-rw-r--r--test/syscalls/linux/socket_unix_unbound_filesystem.cc84
-rw-r--r--test/syscalls/linux/socket_unix_unbound_seqpacket.cc89
-rw-r--r--test/syscalls/linux/socket_unix_unbound_stream.cc733
-rw-r--r--test/syscalls/linux/splice.cc699
-rw-r--r--test/syscalls/linux/stat.cc720
-rw-r--r--test/syscalls/linux/stat_times.cc303
-rw-r--r--test/syscalls/linux/statfs.cc82
-rw-r--r--test/syscalls/linux/sticky.cc161
-rw-r--r--test/syscalls/linux/symlink.cc402
-rw-r--r--test/syscalls/linux/sync.cc59
-rw-r--r--test/syscalls/linux/sync_file_range.cc112
-rw-r--r--test/syscalls/linux/sysinfo.cc86
-rw-r--r--test/syscalls/linux/syslog.cc51
-rw-r--r--test/syscalls/linux/sysret.cc142
-rw-r--r--test/syscalls/linux/tcp_socket.cc1568
-rw-r--r--test/syscalls/linux/tgkill.cc48
-rw-r--r--test/syscalls/linux/time.cc107
-rw-r--r--test/syscalls/linux/timerfd.cc273
-rw-r--r--test/syscalls/linux/timers.cc662
-rw-r--r--test/syscalls/linux/tkill.cc75
-rw-r--r--test/syscalls/linux/truncate.cc218
-rw-r--r--test/syscalls/linux/tuntap.cc422
-rw-r--r--test/syscalls/linux/tuntap_hostinet.cc38
-rw-r--r--test/syscalls/linux/udp_bind.cc316
-rw-r--r--test/syscalls/linux/udp_socket.cc30
-rw-r--r--test/syscalls/linux/udp_socket_errqueue_test_case.cc57
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc1727
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.h82
-rw-r--r--test/syscalls/linux/uidgid.cc276
-rw-r--r--test/syscalls/linux/uname.cc111
-rw-r--r--test/syscalls/linux/unix_domain_socket_test_util.cc351
-rw-r--r--test/syscalls/linux/unix_domain_socket_test_util.h162
-rw-r--r--test/syscalls/linux/unlink.cc214
-rw-r--r--test/syscalls/linux/unshare.cc50
-rw-r--r--test/syscalls/linux/utimes.cc319
-rw-r--r--test/syscalls/linux/vdso.cc48
-rw-r--r--test/syscalls/linux/vdso_clock_gettime.cc108
-rw-r--r--test/syscalls/linux/vfork.cc195
-rw-r--r--test/syscalls/linux/vsyscall.cc46
-rw-r--r--test/syscalls/linux/wait.cc913
-rw-r--r--test/syscalls/linux/write.cc139
-rw-r--r--test/syscalls/linux/xattr.cc610
-rw-r--r--test/uds/BUILD16
-rw-r--r--test/uds/uds.go228
-rw-r--r--test/util/BUILD358
-rw-r--r--test/util/capability_util.cc81
-rw-r--r--test/util/capability_util.h101
-rw-r--r--test/util/cleanup.h61
-rw-r--r--test/util/epoll_util.cc52
-rw-r--r--test/util/epoll_util.h36
-rw-r--r--test/util/eventfd_util.h43
-rw-r--r--test/util/file_descriptor.h134
-rw-r--r--test/util/fs_util.cc633
-rw-r--r--test/util/fs_util.h210
-rw-r--r--test/util/fs_util_test.cc105
-rw-r--r--test/util/logging.cc97
-rw-r--r--test/util/logging.h73
-rw-r--r--test/util/memory_util.h147
-rw-r--r--test/util/mount_util.h51
-rw-r--r--test/util/multiprocess_util.cc173
-rw-r--r--test/util/multiprocess_util.h132
-rw-r--r--test/util/platform_util.cc48
-rw-r--r--test/util/platform_util.h56
-rw-r--r--test/util/posix_error.cc98
-rw-r--r--test/util/posix_error.h462
-rw-r--r--test/util/posix_error_test.cc46
-rw-r--r--test/util/proc_util.cc107
-rw-r--r--test/util/proc_util.h150
-rw-r--r--test/util/proc_util_test.cc81
-rw-r--r--test/util/pty_util.cc53
-rw-r--r--test/util/pty_util.h33
-rw-r--r--test/util/rlimit_util.cc45
-rw-r--r--test/util/rlimit_util.h32
-rw-r--r--test/util/save_util.cc71
-rw-r--r--test/util/save_util.h52
-rw-r--r--test/util/save_util_linux.cc49
-rw-r--r--test/util/save_util_other.cc27
-rw-r--r--test/util/signal_util.cc104
-rw-r--r--test/util/signal_util.h107
-rw-r--r--test/util/temp_path.cc164
-rw-r--r--test/util/temp_path.h135
-rw-r--r--test/util/temp_umask.h39
-rw-r--r--test/util/test_main.cc20
-rw-r--r--test/util/test_util.cc233
-rw-r--r--test/util/test_util.h784
-rw-r--r--test/util/test_util_impl.cc52
-rw-r--r--test/util/test_util_runfiles.cc50
-rw-r--r--test/util/test_util_test.cc251
-rw-r--r--test/util/thread_util.h93
-rw-r--r--test/util/time_util.cc41
-rw-r--r--test/util/time_util.h29
-rw-r--r--test/util/timer_util.cc27
-rw-r--r--test/util/timer_util.h74
-rw-r--r--test/util/uid_util.cc44
-rw-r--r--test/util/uid_util.h29
-rw-r--r--tools/BUILD1
-rw-r--r--tools/bazel.mk124
-rw-r--r--tools/bazeldefs/BUILD51
-rw-r--r--tools/bazeldefs/defs.bzl182
-rw-r--r--tools/bazeldefs/platforms.bzl9
-rw-r--r--tools/bazeldefs/tags.bzl56
-rw-r--r--tools/bigquery/BUILD10
-rw-r--r--tools/bigquery/bigquery.go121
-rw-r--r--tools/checkescape/BUILD16
-rw-r--r--tools/checkescape/checkescape.go726
-rw-r--r--tools/checkescape/test1/BUILD9
-rw-r--r--tools/checkescape/test1/test1.go195
-rw-r--r--tools/checkescape/test2/BUILD9
-rw-r--r--tools/checkescape/test2/test2.go94
-rw-r--r--tools/checkunsafe/BUILD13
-rw-r--r--tools/checkunsafe/check_unsafe.go56
-rw-r--r--tools/defs.bzl254
-rwxr-xr-xtools/go_branch.sh101
-rw-r--r--tools/go_generics/BUILD38
-rw-r--r--tools/go_generics/defs.bzl139
-rw-r--r--tools/go_generics/generics.go286
-rw-r--r--tools/go_generics/generics_tests/all_stmts/input.go290
-rw-r--r--tools/go_generics/generics_tests/all_stmts/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/all_stmts/output/output.go288
-rw-r--r--tools/go_generics/generics_tests/all_types/input.go43
-rw-r--r--tools/go_generics/generics_tests/all_types/lib/lib.go17
-rw-r--r--tools/go_generics/generics_tests/all_types/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/all_types/output/output.go41
-rw-r--r--tools/go_generics/generics_tests/anon/input.go46
-rw-r--r--tools/go_generics/generics_tests/anon/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/anon/output/output.go42
-rw-r--r--tools/go_generics/generics_tests/consts/input.go26
-rw-r--r--tools/go_generics/generics_tests/consts/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/consts/output/output.go26
-rw-r--r--tools/go_generics/generics_tests/imports/input.go24
-rw-r--r--tools/go_generics/generics_tests/imports/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/imports/output/output.go27
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/input.go37
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/output/output.go29
-rw-r--r--tools/go_generics/generics_tests/simple/input.go45
-rw-r--r--tools/go_generics/generics_tests/simple/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/simple/output/output.go43
-rw-r--r--tools/go_generics/globals/BUILD13
-rw-r--r--tools/go_generics/globals/globals_visitor.go597
-rw-r--r--tools/go_generics/globals/scope.go84
-rwxr-xr-xtools/go_generics/go_generics_unittest.sh70
-rw-r--r--tools/go_generics/go_merge/BUILD9
-rw-r--r--tools/go_generics/go_merge/main.go139
-rw-r--r--tools/go_generics/imports.go150
-rw-r--r--tools/go_generics/remove.go105
-rw-r--r--tools/go_generics/rules_tests/BUILD43
-rw-r--r--tools/go_generics/rules_tests/template.go42
-rw-r--r--tools/go_generics/rules_tests/template_test.go48
-rw-r--r--tools/go_marshal/BUILD19
-rw-r--r--tools/go_marshal/README.md116
-rw-r--r--tools/go_marshal/analysis/BUILD12
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go179
-rw-r--r--tools/go_marshal/defs.bzl65
-rw-r--r--tools/go_marshal/gomarshal/BUILD21
-rw-r--r--tools/go_marshal/gomarshal/generator.go499
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go276
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go146
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go289
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go618
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go233
-rw-r--r--tools/go_marshal/gomarshal/util.go491
-rw-r--r--tools/go_marshal/main.go72
-rw-r--r--tools/go_marshal/marshal/BUILD16
-rw-r--r--tools/go_marshal/marshal/marshal.go187
-rw-r--r--tools/go_marshal/primitive/BUILD18
-rw-r--r--tools/go_marshal/primitive/primitive.go175
-rw-r--r--tools/go_marshal/test/BUILD44
-rw-r--r--tools/go_marshal/test/benchmark_test.go220
-rw-r--r--tools/go_marshal/test/escape/BUILD14
-rw-r--r--tools/go_marshal/test/escape/escape.go95
-rw-r--r--tools/go_marshal/test/external/BUILD11
-rw-r--r--tools/go_marshal/test/external/external.go31
-rw-r--r--tools/go_marshal/test/marshal_test.go515
-rw-r--r--tools/go_marshal/test/test.go176
-rwxr-xr-xtools/go_mod.sh29
-rw-r--r--tools/go_stateify/BUILD10
-rw-r--r--tools/go_stateify/defs.bzl60
-rw-r--r--tools/go_stateify/main.go476
-rw-r--r--tools/installers/BUILD35
-rwxr-xr-xtools/installers/head.sh21
-rwxr-xr-xtools/installers/images.sh24
-rwxr-xr-xtools/installers/master.sh34
-rwxr-xr-xtools/installers/shim.sh24
-rw-r--r--tools/issue_reviver/BUILD12
-rw-r--r--tools/issue_reviver/github/BUILD16
-rw-r--r--tools/issue_reviver/github/github.go164
-rw-r--r--tools/issue_reviver/main.go100
-rw-r--r--tools/issue_reviver/reviver/BUILD18
-rw-r--r--tools/issue_reviver/reviver/reviver.go192
-rw-r--r--tools/issue_reviver/reviver/reviver_test.go88
-rwxr-xr-xtools/make_apt.sh139
-rwxr-xr-xtools/make_release.sh82
-rw-r--r--tools/nogo/BUILD49
-rw-r--r--tools/nogo/README.md31
-rw-r--r--tools/nogo/build.go36
-rw-r--r--tools/nogo/check/BUILD12
-rw-r--r--tools/nogo/check/main.go24
-rw-r--r--tools/nogo/config.go116
-rw-r--r--tools/nogo/data/BUILD10
-rw-r--r--tools/nogo/data/data.go21
-rw-r--r--tools/nogo/defs.bzl172
-rw-r--r--tools/nogo/io_bazel_rules_go-visibility.patch25
-rw-r--r--tools/nogo/matchers.go143
-rw-r--r--tools/nogo/nogo.go316
-rw-r--r--tools/nogo/register.go64
-rwxr-xr-xtools/tag_release.sh82
-rw-r--r--tools/tags/BUILD11
-rw-r--r--tools/tags/tags.go89
-rw-r--r--tools/vm/BUILD57
-rw-r--r--tools/vm/README.md42
-rwxr-xr-xtools/vm/build.sh117
-rw-r--r--tools/vm/defs.bzl201
-rwxr-xr-xtools/vm/execute.sh160
-rw-r--r--tools/vm/test.cc27
-rwxr-xr-xtools/vm/ubuntu1604/10_core.sh43
-rwxr-xr-xtools/vm/ubuntu1604/15_gcloud.sh50
-rwxr-xr-xtools/vm/ubuntu1604/20_bazel.sh38
-rwxr-xr-xtools/vm/ubuntu1604/25_docker.sh65
-rwxr-xr-xtools/vm/ubuntu1604/30_containerd.sh86
-rwxr-xr-xtools/vm/ubuntu1604/40_kokoro.sh72
-rw-r--r--tools/vm/ubuntu1604/BUILD7
-rw-r--r--tools/vm/ubuntu1804/BUILD7
-rwxr-xr-xtools/vm/zone.sh17
-rwxr-xr-xtools/workspace_status.sh18
-rw-r--r--vdso/BUILD81
-rw-r--r--vdso/barrier.h49
-rw-r--r--vdso/check_vdso.py204
-rw-r--r--vdso/compiler.h29
-rw-r--r--vdso/cycle_clock.h51
-rw-r--r--vdso/seqlock.h39
-rw-r--r--vdso/syscalls.h100
-rw-r--r--vdso/vdso.cc155
-rw-r--r--vdso/vdso_amd64.lds102
-rw-r--r--vdso/vdso_arm64.lds99
-rw-r--r--vdso/vdso_time.cc159
-rw-r--r--vdso/vdso_time.h27
-rw-r--r--website/BUILD181
-rw-r--r--website/_config.yml36
-rw-r--r--website/_includes/byline.html18
-rw-r--r--website/_includes/footer-links.html43
-rw-r--r--website/_includes/footer.html72
-rw-r--r--website/_includes/graph.html205
-rw-r--r--website/_includes/header-links.html19
-rw-r--r--website/_includes/header.html30
-rw-r--r--website/_includes/paginator.html10
-rw-r--r--website/_includes/required_linux.html2
-rw-r--r--website/_layouts/base.html9
-rw-r--r--website/_layouts/blog.html17
-rw-r--r--website/_layouts/default.html14
-rw-r--r--website/_layouts/docs.html59
-rw-r--r--website/_layouts/post.html10
-rw-r--r--website/_plugins/svg_mime_type.rb3
-rw-r--r--website/_sass/footer.scss15
-rw-r--r--website/_sass/front.scss17
-rw-r--r--website/_sass/navbar.scss26
-rw-r--r--website/_sass/sidebar.scss61
-rw-r--r--website/_sass/style.scss154
-rw-r--r--website/archive.key29
-rw-r--r--website/assets/favicons/apple-touch-icon-180x180.pngbin0 -> 18820 bytes
-rw-r--r--website/assets/favicons/favicon-16x16.pngbin0 -> 926 bytes
-rw-r--r--website/assets/favicons/favicon-32x32.pngbin0 -> 2308 bytes
-rw-r--r--website/assets/favicons/favicon.icobin0 -> 1150 bytes
-rw-r--r--website/assets/favicons/pwa-192x192.pngbin0 -> 20666 bytes
-rw-r--r--website/assets/favicons/pwa-512x512.pngbin0 -> 24397 bytes
-rw-r--r--website/assets/favicons/tile150x150.pngbin0 -> 18440 bytes
-rw-r--r--website/assets/favicons/tile310x150.pngbin0 -> 21486 bytes
-rw-r--r--website/assets/favicons/tile310x310.pngbin0 -> 25629 bytes
-rw-r--r--website/assets/favicons/tile70x70.pngbin0 -> 11148 bytes
-rw-r--r--website/assets/images/2019-11-18-security-basics-figure1.pngbin0 -> 19088 bytes
-rw-r--r--website/assets/images/2019-11-18-security-basics-figure2.pngbin0 -> 17642 bytes
-rw-r--r--website/assets/images/2019-11-18-security-basics-figure3.pngbin0 -> 16471 bytes
-rw-r--r--website/assets/images/2020-04-02-networking-security-figure1.pngbin0 -> 29775 bytes
-rw-r--r--website/assets/images/background.jpgbin0 -> 1070364 bytes
-rw-r--r--website/assets/logos/Makefile13
-rw-r--r--website/assets/logos/README.md10
-rw-r--r--website/assets/logos/logo_solo_monochrome.pngbin0 -> 10483 bytes
-rw-r--r--website/assets/logos/logo_solo_monochrome.svg73
-rw-r--r--website/assets/logos/logo_solo_on_dark-1024.pngbin0 -> 59374 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark-128.pngbin0 -> 5951 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark-16.pngbin0 -> 701 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark.pngbin0 -> 8387 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark.svg73
-rw-r--r--website/assets/logos/logo_solo_on_dark_full-1024.pngbin0 -> 80121 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full-128.pngbin0 -> 8616 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full-16.pngbin0 -> 900 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full.pngbin0 -> 17055 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full.svg79
-rw-r--r--website/assets/logos/logo_solo_on_white.pngbin0 -> 10572 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white.svg73
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered-1024.pngbin0 -> 95350 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered-128.pngbin0 -> 10231 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered-16.pngbin0 -> 960 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered.pngbin0 -> 15330 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered.svg82
-rw-r--r--website/assets/logos/logo_with_text_monochrome.pngbin0 -> 22220 bytes
-rw-r--r--website/assets/logos/logo_with_text_monochrome.svg116
-rw-r--r--website/assets/logos/logo_with_text_on_dark-1024.pngbin0 -> 30774 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark-128.pngbin0 -> 3129 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark-16.pngbin0 -> 315 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark.pngbin0 -> 17035 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark.svg116
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full-1024.pngbin0 -> 34866 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full-128.pngbin0 -> 3746 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full-16.pngbin0 -> 372 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full.pngbin0 -> 25956 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full.svg120
-rw-r--r--website/assets/logos/logo_with_text_on_white.pngbin0 -> 22363 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_white.svg116
-rw-r--r--website/assets/logos/logo_with_text_on_white_bordered.pngbin0 -> 27719 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_white_bordered.svg122
-rw-r--r--website/assets/logos/powered-gvisor.pngbin0 -> 5193 bytes
-rw-r--r--website/blog/2019-11-18-security-basics.md299
-rw-r--r--website/blog/2020-04-02-networking-security.md183
-rw-r--r--website/blog/BUILD37
-rw-r--r--website/blog/index.html22
-rw-r--r--website/cmd/server/BUILD10
-rw-r--r--website/cmd/server/main.go215
-rw-r--r--website/cmd/syscalldocs/BUILD9
-rw-r--r--website/cmd/syscalldocs/main.go211
-rw-r--r--website/css/main.scss5
-rw-r--r--website/defs.bzl176
-rwxr-xr-xwebsite/import.sh27
-rw-r--r--website/index.md50
-rw-r--r--website/performance/README.md9
-rw-r--r--website/performance/applications.csv13
-rw-r--r--website/performance/density.csv9
-rw-r--r--website/performance/ffmpeg.csv3
-rw-r--r--website/performance/fio-tmpfs.csv9
-rw-r--r--website/performance/fio.csv9
-rw-r--r--website/performance/httpd100k.csv17
-rw-r--r--website/performance/httpd10240k.csv17
-rw-r--r--website/performance/iperf.csv5
-rw-r--r--website/performance/redis.csv35
-rw-r--r--website/performance/startup.csv7
-rw-r--r--website/performance/sysbench-cpu.csv3
-rw-r--r--website/performance/sysbench-memory.csv3
-rw-r--r--website/performance/syscall.csv4
-rw-r--r--website/performance/tensorflow.csv3
2528 files changed, 489885 insertions, 0 deletions
diff --git a/.bazelrc b/.bazelrc
new file mode 100644
index 000000000..4a0671f4a
--- /dev/null
+++ b/.bazelrc
@@ -0,0 +1,60 @@
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build with C++17.
+build --cxxopt=-std=c++17
+
+# Display the current git revision in the info block.
+build --stamp --workspace_status_command tools/workspace_status.sh
+
+# Enable remote execution so actions are performed on the remote systems.
+build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
+build:remote --project_id=gvisor-rbe
+build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance
+# Enable authentication. This will pick up application default credentials by
+# default. You can use --google_credentials=some_file.json to use a service
+# account credential instead.
+build:remote --google_default_credentials=true
+build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
+
+# Add a custom platform and toolchain that builds in a privileged docker
+# container, which is required by our syscall tests.
+build:remote --host_platform=//tools/bazeldefs:rbe_ubuntu1604
+build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default
+build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604
+build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604
+build:remote --crosstool_top=@rbe_default//cc:toolchain
+build:remote --jobs=50
+build:remote --remote_timeout=3600
+# RBE requires a strong hash function, such as SHA256.
+startup --host_jvm_args=-Dbazel.DigestFunction=SHA256
+
+# Set flags for uploading to BES in order to view results in the Bazel Build
+# Results UI.
+build:results --bes_backend="buildeventservice.googleapis.com"
+build:results --bes_timeout=60s
+build:results --tls_enabled
+
+# Output the BES results URL.
+build:results --bes_results_url="https://source.cloud.google.com/results/invocations/"
+
+# Set flags for uploading to BES without Remote Build Execution.
+build:results-local --bes_backend="buildeventservice.googleapis.com"
+build:results-local --bes_timeout=60s
+build:results-local --tls_enabled=true
+build:results-local --auth_enabled=true
+build:results-local --spawn_strategy=local
+build:results-local --remote_cache=remotebuildexecution.googleapis.com
+build:results-local --remote_timeout=3600
+build:results-local --bes_results_url="https://source.cloud.google.com/results/invocations/"
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 000000000..49a1ba697
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,31 @@
+---
+name: Bug report
+about: Create a bug report to help us improve
+title:
+labels:
+ - 'type: bug'
+assignees: ''
+---
+
+**Description**
+
+A clear description of what the bug is. If possible, explicitly indicate the
+expected behavior vs. the observed behavior.
+
+**Steps to reproduce**
+
+If available, please include detailed reproduction steps.
+
+If the bug requires software that is not publicly available, see if it can be
+reproduced with software that is publicly available.
+
+**Environment**
+
+Please include the following details of your environment:
+
+* `runsc -v`
+* `docker version` or `docker info` (if available)
+* `kubectl version` and `kubectl get nodes` (if using Kubernetes)
+* `uname -a`
+* `git describe` (if built from source)
+* `runsc` debug logs (if available)
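+
+For example, debug logs can be collected with flags along these lines (a
+sketch; the flags mirror those used by this repository's Makefile):
+
+```
+runsc --debug --debug-log=/tmp/runsc/ <command>
+```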
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 000000000..772c9a0ac
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,11 @@
+blank_issues_enabled: false
+contact_links:
+ - name: gVisor Documentation (FAQ)
+ url: https://gvisor.dev/docs/user_guide/faq/
+ about: Please see our documentation for common questions and answers.
+ - name: gVisor Documentation (Debugging)
+ url: https://gvisor.dev/docs/user_guide/debugging/
+ about: Please see our documentation for debugging tips.
+ - name: gVisor User Forum
+ url: https://groups.google.com/g/gvisor-users
+ about: Ask and answer general questions here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 000000000..65f60f385
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,21 @@
+---
+name: Feature request
+about: Suggest an idea or improvement
+title: ''
+labels:
+ - 'type: enhancement'
+assignees: ''
+---
+
+**Description**
+
+A clear description of the feature or enhancement.
+
+**Is this feature related to a specific bug?**
+
+If so, please include references to the relevant bugs.
+
+**Do you have a specific solution in mind?**
+
+Please include any details about a solution that you have in mind, including any
+alternatives considered.
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 000000000..b6a17051c
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,42 @@
+"arch: arm":
+ - "**/*_arm64.*"
+ - "**/*_aarch64.*"
+"arch: x86_64":
+ - "**/*_amd64.*"
+ - "**/*_x86.*"
+"area: bazel":
+ - "**/BUILD"
+ - "**/*.bzl"
+"area: docs":
+ - "**/g3doc/**"
+ - "**/README.md"
+"area: filesystem":
+ - "pkg/sentry/fs/**"
+ - "pkg/sentry/vfs/**"
+ - "pkg/sentry/fsimpl/**"
+"area: hostinet":
+ - "pkg/sentry/socket/hostinet/**"
+"area: networking":
+ - "pkg/tcpip/**"
+ - "pkg/sentry/socket/**"
+"area: kernel":
+ - "pkg/sentry/arch/**"
+ - "pkg/sentry/kernel/**"
+ - "pkg/sentry/syscalls/**"
+"area: mm":
+ - "pkg/sentry/mm/**"
+"area: tests":
+ - "**/tests/**"
+ - "**/*_test.go"
+ - "**/test/**"
+"area: tooling":
+ - "tools/**"
+"dependencies":
+ - "WORKSPACE"
+ - "go.mod"
+ - "go.sum"
+"platform: kvm":
+ - "pkg/sentry/platform/kvm/**"
+ - "pkg/sentry/platform/ring0/**"
+"platform: ptrace":
+ - "pkg/sentry/platform/ptrace/**"
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 000000000..264b4e9fa
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,5 @@
+* [ ] Have you followed the guidelines in [CONTRIBUTING.md](../blob/master/CONTRIBUTING.md)?
+* [ ] Have you formatted and linted your code?
+* [ ] Have you added relevant tests?
+* [ ] Have you added appropriate Fixes & Updates references?
+* [ ] If yes to all of the above, please erase these lines!
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 000000000..cf782a580
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,21 @@
+name: "Build"
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ default:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/cache@v1
+ with:
+ path: ~/.cache/bazel
+ key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }}
+ restore-keys: |
+ ${{ runner.os }}-bazel-
+ - run: make
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
new file mode 100644
index 000000000..10c86f5cd
--- /dev/null
+++ b/.github/workflows/go.yml
@@ -0,0 +1,66 @@
+name: "Go"
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ generate:
+ runs-on: ubuntu-latest
+ steps:
+ - run: |
+ jq -nc '{"state": "pending", "context": "go tests"}' | \
+ curl -sL -X POST -d @- \
+ -H "Content-Type: application/json" \
+ -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+ "${{ github.event.pull_request.statuses_url }}"
+ if: github.event_name == 'pull_request'
+ - uses: actions/checkout@v2
+ if: github.event_name == 'push'
+ with:
+ fetch-depth: 0
+ token: '${{ secrets.GO_TOKEN }}'
+ - uses: actions/checkout@v2
+ if: github.event_name == 'pull_request'
+ with:
+ fetch-depth: 0
+ - uses: actions/setup-go@v2
+ with:
+ go-version: 1.14
+ - uses: actions/cache@v1
+ with:
+ path: ~/go/pkg/mod
+ key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
+ restore-keys: |
+ ${{ runner.os }}-go-
+ - uses: actions/cache@v1
+ with:
+ path: ~/.cache/bazel
+ key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }}
+ restore-keys: |
+ ${{ runner.os }}-bazel-
+ - run: make build TARGETS="//:gopath"
+ - run: tools/go_branch.sh
+ - run: git checkout go && git clean -f
+ - run: go build ./...
+ - if: github.event_name == 'push'
+ run: |
+ git remote add upstream "https://github.com/${{ github.repository }}"
+ git push upstream go:go
+ - if: ${{ success() && github.event_name == 'pull_request' }}
+ run: |
+ jq -nc '{"state": "success", "context": "go tests"}' | \
+ curl -sL -X POST -d @- \
+ -H "Content-Type: application/json" \
+ -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+ "${{ github.event.pull_request.statuses_url }}"
+ - if: ${{ failure() && github.event_name == 'pull_request' }}
+ run: |
+ jq -nc '{"state": "failure", "context": "go tests"}' | \
+ curl -sL -X POST -d @- \
+ -H "Content-Type: application/json" \
+ -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+ "${{ github.event.pull_request.statuses_url }}"
diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml
new file mode 100644
index 000000000..5e0254111
--- /dev/null
+++ b/.github/workflows/issue_reviver.yml
@@ -0,0 +1,14 @@
+name: "Issue reviver"
+on:
+ schedule:
+ - cron: '0 0 * * *'
+
+jobs:
+ label:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - run: make run TARGETS="//tools/issue_reviver"
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ GITHUB_REPOSITORY: ${{ github.repository }}
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 000000000..c09f7eb36
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,12 @@
+name: "Labeler"
+on:
+- pull_request
+
+jobs:
+ label:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v2
+ if: github.base_ref == null
+ with:
+ repo-token: "${{ secrets.GITHUB_TOKEN }}"
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
new file mode 100644
index 000000000..0b31fecf5
--- /dev/null
+++ b/.github/workflows/stale.yml
@@ -0,0 +1,20 @@
+name: "Close stale issues"
+on:
+ schedule:
+ - cron: "0 0 * * *"
+
+jobs:
+ stale:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/stale@v3
+ with:
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
+ stale-issue-label: 'stale'
+ stale-pr-label: 'stale'
+ exempt-issue-labels: 'exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question'
+ exempt-pr-labels: 'ready to pull'
+ stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
+ stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
+ days-before-stale: 90
+ days-before-close: 30
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..13babef4d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+# Generated bazel symlinks.
+/bazel-*
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..9d3141f38
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,45 @@
+language: shell
+dist: xenial
+git:
+ clone: false # Clone manually in before_install
+before_install:
+ - set -e -o pipefail
+ - |
+ if [ "${TRAVIS_PULL_REQUEST}" = false ]; then
+ # This is not a PR build, fetch and checkout the commit being tested
+ git clone -q --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}"
+ cd "${TRAVIS_REPO_SLUG}"
+ git fetch origin "${TRAVIS_COMMIT}" --depth 1
+ git checkout -qf "${TRAVIS_COMMIT}"
+ else
+ # This is a PR build, simulate +refs/pull/{num}/merge.
+ # We can do that by fetching +refs/pull/{num}/head and cherry picking it
+ # onto the target branch.
+ git clone -q --branch "${TRAVIS_BRANCH}" --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}"
+ cd "${TRAVIS_REPO_SLUG}"
+ git fetch origin "+refs/pull/${TRAVIS_PULL_REQUEST}/head" --depth 1
+ git config --global user.email "$(git log -1 FETCH_HEAD --pretty="%cE")"
+ git config --global user.name "$(git log -1 FETCH_HEAD --pretty="%aN")"
+ git cherry-pick --strategy=recursive -X theirs --keep-redundant-commits FETCH_HEAD
+ fi
+cache:
+ directories:
+ - /home/travis/.cache/bazel/
+os: linux
+services:
+ - docker
+jobs:
+ include:
+ - os: linux
+ arch: amd64
+ - os: linux
+ arch: arm64
+script:
+ # On arm64, we need to create our own pipes for stderr and stdout,
+ # otherwise we will not be able to open /dev/stderr. This is probably
+ # due to AppArmor rules.
+ - bash -xeo pipefail -c 'uname -a && make smoke-test 2>&1 | cat'
+branches:
+ except:
+ # Skip copybara branches.
+ - /^test\/cl.*$/
diff --git a/AUTHORS b/AUTHORS
new file mode 100644
index 000000000..01ba46567
--- /dev/null
+++ b/AUTHORS
@@ -0,0 +1,8 @@
+# This is the list of gVisor authors for copyright purposes.
+#
+# This does not necessarily list everyone who has contributed code, since in
+# some cases, their employer may be the copyright holder. To see the full list
+# of contributors, see the revision history in source control.
+#
+# Please send a patch if you would like to be included in this list.
+Google LLC
diff --git a/BUILD b/BUILD
new file mode 100644
index 000000000..962d54821
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,94 @@
+load("//tools:defs.bzl", "build_test", "gazelle", "go_path")
+load("//website:defs.bzl", "doc")
+
+package(licenses = ["notice"])
+
+exports_files(["LICENSE"])
+
+doc(
+ name = "contributing",
+ src = "CONTRIBUTING.md",
+ category = "Project",
+ permalink = "/contributing/",
+ visibility = ["//website:__pkg__"],
+ weight = "20",
+)
+
+doc(
+ name = "security",
+ src = "SECURITY.md",
+ category = "Project",
+ permalink = "/security/",
+ visibility = ["//website:__pkg__"],
+ weight = "30",
+)
+
+doc(
+ name = "governance",
+ src = "GOVERNANCE.md",
+ category = "Project",
+ permalink = "/community/governance/",
+ subcategory = "Community",
+ visibility = ["//website:__pkg__"],
+ weight = "91",
+)
+
+doc(
+ name = "code_of_conduct",
+ src = "CODE_OF_CONDUCT.md",
+ category = "Project",
+ permalink = "/community/code_of_conduct/",
+ subcategory = "Community",
+ visibility = ["//website:__pkg__"],
+ weight = "99",
+)
+
+# The sandbox package group is used for sandbox-internal dependencies.
+package_group(
+ name = "sandbox",
+ packages = ["//..."],
+)
+
+# For targets that will not normally build internally, we ensure that they are
+# at least built by a static BUILD test.
+build_test(
+ name = "build_test",
+ targets = [
+ "//test/e2e:integration_test",
+ "//test/image:image_test",
+ "//test/root:root_test",
+ ],
+)
+
+# gopath defines a directory that is structured in a way that is compatible
+# with standard Go tools. Things like godoc, editors and refactor tools should
+# work as expected.
+#
+# The files in this tree are symlinks to the true sources.
+go_path(
+ name = "gopath",
+ mode = "link",
+ deps = [
+ "//runsc",
+
+ # Packages that are not dependencies of //runsc.
+ "//pkg/sentry/kernel/memevent",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/muxed",
+ "//pkg/tcpip/link/sharedmem",
+ "//pkg/tcpip/link/sharedmem/pipe",
+ "//pkg/tcpip/link/sharedmem/queue",
+ "//pkg/tcpip/link/tun",
+ "//pkg/tcpip/link/waitable",
+ "//pkg/tcpip/sample/tun_tcp_connect",
+ "//pkg/tcpip/sample/tun_tcp_echo",
+ "//pkg/tcpip/transport/tcpconntrack",
+ ],
+)
+
+# gazelle is a set of build tools.
+#
+# To update the WORKSPACE from go.mod, use:
+# bazel run //:gazelle -- update-repos -from_file=go.mod
+gazelle(name = "gazelle")
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 000000000..fbf517fe5
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,91 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, gender identity and expression, level of
+experience, education, socio-economic status, nationality, personal appearance,
+race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, or to ban temporarily or permanently any
+contributor for other behaviors that they deem inappropriate, threatening,
+offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when the Project
+Steward has a reasonable belief that an individual's behavior may have a
+negative impact on the project or its community.
+
+## Conflict Resolution
+
+We do not believe that all conflict is bad; healthy debate and disagreement
+often yield positive results. However, it is never okay to be disrespectful or
+to engage in behavior that violates the project’s code of conduct.
+
+If you see someone violating the code of conduct, you are encouraged to address
+the behavior directly with those involved. Many issues can be resolved quickly
+and easily, and this gives people more control over the outcome of their
+dispute. If you are unable to resolve the matter for any reason, or if the
+behavior is threatening or harassing, report it. We are dedicated to providing
+an environment where participants feel welcome and safe.
+
+Reports should be directed to Jaice Singer DuMars, jaice at google dot com, the
+Project Steward for gVisor. It is the Project Steward’s duty to receive and
+address reported violations of the code of conduct. They will then work with a
+committee consisting of representatives from the Open Source Programs Office and
+the Google Open Source Strategy team. If for any reason you are uncomfortable
+reaching out to the Project Steward, please email opensource@google.com.
+
+We will investigate every complaint, but you may not receive a direct response.
+We will use our discretion in determining when and how to follow up on reported
+incidents, which may range from not taking action to permanent expulsion from
+the project and project-sponsored spaces. We will notify the accused of the
+report and provide them an opportunity to discuss it before any action is taken.
+The identity of the reporter will be omitted from the details of the report
+supplied to the accused. In potentially harmful situations, such as ongoing
+harassment or threats to anyone's safety, we may take action without notice.
+
+## Attribution
+
+This Code of Conduct is adapted from the
+[Contributor Covenant, version 1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html).
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..89180eb3f
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,129 @@
+# Contributing
+
+Want to contribute? Great! First, read this page.
+
+### Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+### Using GOPATH
+
+Some editors may require the code to be structured in a `GOPATH` directory tree.
+In this case, you may use the `:gopath` target to generate a directory tree with
+symlinks to the original source files.
+
+```
+bazel build :gopath
+```
+
+You can then set the `GOPATH` in your editor to `bazel-bin/gopath`.
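+
+For example, from the repository root (a minimal sketch; adapt to your editor):
+
+```
+bazel build :gopath
+export GOPATH="$(pwd)/bazel-bin/gopath"
+```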
+
+If you use this mechanism, keep in mind that the generated tree is not the
+canonical source. You will still need to build and test with `bazel`. New files
+will need to be added to the appropriate `BUILD` files, and the `:gopath` target
+will need to be re-run to generate appropriate symlinks in the `GOPATH`
+directory tree.
+
+Dependencies can be added using the standard `go get` tooling. In order to keep
+the `WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`.
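+
+For example, to add and sync a dependency (a sketch; this assumes the wrapper
+forwards its arguments to `go mod` before refreshing the `WORKSPACE`):
+
+```
+go get golang.org/x/sys/unix
+tools/go_mod.sh tidy
+```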
+
+### Coding Guidelines
+
+All code should comply with the [style guide](g3doc/style.md). Note that code
+may be automatically formatted per the guidelines when merged.
+
+As a secure runtime, we need to maintain the safety of all of the code included
+in gVisor. The following rules help mitigate issues.
+
+Definitions for the rules below:
+
+`core`:
+
+* `//pkg/sentry/...`
+* Transitive dependencies in `//pkg/...`, etc.
+
+`runsc`:
+
+* `//runsc/...`
+
+Rules:
+
+* No cgo in `core` or `runsc`. The final binary must be a statically-linked
+ pure Go binary.
+
+* Any files importing "unsafe" must have a name ending in `_unsafe.go`.
+
+* `core` may only depend on the following packages:
+
+ * Itself.
+ * Go standard library.
+ * Except (transitively) package "net" (this will result in a non-cgo
+ binary). Use `//pkg/unet` instead.
+ * `@org_golang_x_sys//unix:go_default_library` (Go import
+ `golang.org/x/sys/unix`).
+ * Generated Go protobuf packages.
+ * `@com_github_golang_protobuf//proto:go_default_library` (Go import
+ `github.com/golang/protobuf/proto`).
+ * `@com_github_golang_protobuf//ptypes:go_default_library` (Go import
+ `github.com/golang/protobuf/ptypes`).
+
+* `runsc` may only depend on the following packages:
+
+ * All packages allowed for `core`.
+ * `@com_github_google_subcommands//:go_default_library` (Go import
+ `github.com/google/subcommands`).
+ * `@com_github_opencontainers_runtime_spec//specs_go:go_default_library`
+ (Go import `github.com/opencontainers/runtime-spec/specs_go`).
+
+### Code reviews
+
+Before sending code out for review, run `bazel test ...` to ensure the tests
+pass.
+
+Code changes are accepted via [pull request][github].
+
+When approved, the change will be submitted by a team member and automatically
+merged into the repository.
+
+### Presubmit checks
+
+Accessing check logs may require membership in the
+[gvisor-dev mailing list][gvisor-dev-list], which is public.
+
+### Bug IDs
+
+Some TODOs and NOTEs sprinkled throughout the code have associated IDs of the
+form `b/1234`. These correspond to bugs in our internal bug tracker. Eventually
+these bugs will be moved to GitHub Issues, but until then they can simply be
+ignored.
+
+### Build and test with Docker
+
+Running `make dev` is a convenient way to build and install `runsc` as a Docker
+runtime. The output of this command will show the runtimes installed.
+
+You may use `make refresh` to refresh the binary after any changes. For example:
+
+```bash
+make dev
+docker run --rm --runtime=my-branch hello-world
+make refresh
+```
+
+### The small print
+
+Contributions made by corporations are covered by a different agreement than the
+one above, the
+[Software Grant and Corporate Contributor License Agreement][gccla].
+
+[gcla]: https://cla.developers.google.com/about/google-individual
+[gccla]: https://cla.developers.google.com/about/google-corporate
+[github]: https://github.com/google/gvisor/compare
+[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev
diff --git a/GOVERNANCE.md b/GOVERNANCE.md
new file mode 100644
index 000000000..40846bc2f
--- /dev/null
+++ b/GOVERNANCE.md
@@ -0,0 +1,113 @@
+# Governance
+
+## Projects
+
+A *project* is the primary unit of collaboration. Each project may have its own
+repository and contribution process.
+
+All projects are covered by the [Code of Conduct](CODE_OF_CONDUCT.md), and
+should include an up-to-date copy in the project repository or a link here.
+
+## Contributors
+
+Anyone can be a *contributor* to a project, provided they have signed relevant
+Contributor License Agreements (CLAs) and follow the project's contribution
+guidelines. Contributions will be reviewed by a maintainer, and must pass all
+applicable tests.
+
+Reviews check for code quality and style, including documentation, and enforce
+other policies. Contributions may be rejected for reasons unrelated to the code
+in question. For example, a change may be too complex to maintain or duplicate
+existing functionality.
+
+Note that contributions are not limited to code alone. Bugs, documentation,
+experience reports or public advocacy are all valuable ways to contribute to a
+project and build trust in the community.
+
+## Maintainers
+
+Each project has one or more *maintainers*. Maintainers set technical direction,
+facilitate contributions and exercise overall stewardship.
+
+Maintainers have write access to the project repository. Maintainers review and
+approve changes. They can also assign issues and add additional reviewers.
+
+Note that some repositories may not allow direct commit access, which is
+reserved for administrators or automated processes. In this case, maintainers
+have approval rights, and a separate process exists for merging a change.
+
+Maintainers are responsible for upholding the code of conduct in interactions
+via project communication channels. If comments or exchanges are in violation,
+they may remove them at their discretion.
+
+### Repositories requiring synchronization
+
+For some projects initiated by Google, the infrastructure which synchronizes and
+merges internal and external changes requires that merges are performed by a
+Google employee. In such cases, Google will initiate a rotation to merge changes
+once they pass tests and are approved by a maintainer. This does not preclude
+non-Google contributors from becoming maintainers, in which case the maintainer
+holds approval rights and the merge is an automated process. In some cases,
+Google-internal tests may fail and have to be fixed: the Google employee will
+work with the submitter to achieve this.
+
+### Becoming a maintainer
+
+The list of maintainers is defined by the list of people with commit access or
+approval authority on a repository, typically via a Gerrit group or a GitHub
+team.
+
+Existing maintainers may elevate a contributor to maintainer status on evidence
+of previous contributions and established trust. This decision is based on lazy
+consensus from existing maintainers. While contributors may ask maintainers to
+make this decision, existing maintainers will also pro-actively identify
+contributors who have demonstrated a sustained track record of technical
+leadership and direct contributions.
+
+## Special Interest Groups (SIGs)
+
+From time to time, a SIG may be formed in order to solve larger, more complex
+problems across one or more projects. There are many avenues for collaboration
+outside a SIG, but a SIG can provide structure for collaboration on a single
+topic.
+
+Each group will be established by a charter, and governed by the Code of
+Conduct. Some resources may be provided to the group, such as mailing lists or
+meeting space, and archives will be public.
+
+## Security disclosure
+
+Projects may maintain security mailing lists for vulnerability reports and
+internal project audits may occasionally reveal security issues. Access to these
+lists and audits will be limited to project *maintainers*; individual
+maintainers should opt to participate in these lists based on need and
+expertise. Once maintainers become aware of a potential security issue, they
+will assess the scope and potential impact. If reported externally, maintainers
+will determine a reasonable embargo period with the reporter.
+
+During the embargo period, the maintainers will prioritize a fix for the
+security issue. They may choose to disclose the issue to additional trusted
+contributors in order to facilitate a fix, subjecting them to the embargo, or
+notify affected users in order to give them an advance opportunity to mitigate
+the issue. The inclusion of specific users in this disclosure is left to the
+discretion of the maintainers and contributors involved, and depends on the
+scale of known project use and exposure.
+
+Once a fix is widely available or the embargo period ends, the maintainers will
+make technical details about the vulnerability and associated fixes available.
+
+## Mailing lists
+
+There are four key mailing lists that span projects.
+
+* [gvisor-users](mailto:gvisor-users@googlegroups.com): general purpose user
+ list.
+* [gvisor-dev](mailto:gvisor-dev@googlegroups.com): general purpose
+ development list.
+* [gvisor-security](mailto:gvisor-security@googlegroups.com): private security
+ list. Access to this list is restricted to maintainers of the core gVisor
+ project, subject to the security disclosure policy described above.
+* [gvisor-syzkaller](mailto:gvisor-syzkaller@googlegroups.com): private
+ syzkaller bug tracking list. Access to this list is not limited to
+ maintainers, but will be granted to those who can credibly contribute to
+ fixes.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 000000000..d64569567
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 000000000..85818ebea
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,254 @@
+#!/usr/bin/make -f
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Described below.
+OPTIONS :=
+STARTUP_OPTIONS :=
+TARGETS := //runsc
+ARGS :=
+
+default: runsc
+.PHONY: default
+
+## usage: make <target>
+## or
+## make <build|test|copy|run|sudo> STARTUP_OPTIONS="..." OPTIONS="..." TARGETS="..." ARGS="..."
+##
+## Basic targets.
+##
+## This Makefile wraps basic build and test targets for ease-of-use. Bazel
+## is run inside a canonical Docker container in order to simplify up-front
+## requirements.
+##
+## There are common arguments that may be passed to targets. These are:
+## STARTUP_OPTIONS - Bazel startup options.
+## OPTIONS - Build or test options.
+## TARGETS - The bazel targets.
+## ARGS - Arguments for run or sudo.
+##
+## Additionally, the copy target expects a DESTINATION to be provided.
+##
+## For example, to build runsc using this Makefile, you can run:
+## make build OPTIONS="" TARGETS="//runsc"
+##
+help: ## Shows all targets and help from the Makefile (this message).
+ @grep --no-filename -E '^([a-z.A-Z_-]+:.*?|)##' $(MAKEFILE_LIST) | \
+ awk 'BEGIN {FS = "(:.*?|)## ?"}; { \
+ if (length($$1) > 0) { \
+ printf " \033[36m%-20s\033[0m %s\n", $$1, $$2; \
+ } else { \
+ printf "%s\n", $$2; \
+ } \
+ }'
+build: ## Builds the given $(TARGETS) with the given $(OPTIONS). E.g. make build TARGETS=runsc
+test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test
+copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp
+run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version
+sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v
+.PHONY: help build test copy run sudo
+
+# Load all bazel wrappers.
+#
+# This file should define the basic "build", "test", "run" and "sudo" rules, in
+# addition to the $(BRANCH_NAME) variable.
+ifneq (,$(wildcard tools/google.mk))
+include tools/google.mk
+else
+include tools/bazel.mk
+endif
+
+##
+## Docker image targets.
+##
+## Images used by the tests must also be built and available locally.
+## The canonical test targets defined below will automatically load
+## relevant images. These can be loaded or built manually via these
+## targets.
+##
+## (*) Note that you may provide an ARCH parameter in order to build
+## and load images from an alternate architecture (using qemu). When
+## bazel is run as a server, this has the effect of running a full
+## cross-architecture chain, and can produce cross-compiled binaries.
+##
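+## For example (a sketch; assumes qemu binfmt support is available):
+##   make load-all-images ARCH=aarch64
+##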
+define images
+$(1)-%: ## Image tool: $(1) a given image (also may use 'all-images').
+ @$(MAKE) -C images $$@
+endef
+rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'.
+$(eval $(call images,rebuild))
+push-...: ## Push the given image. Also may use 'push-all-images'.
+$(eval $(call images,push))
+pull-...: ## Pull the given image. Also may use 'pull-all-images'.
+$(eval $(call images,pull))
+load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'.
+$(eval $(call images,load))
+list-images: ## List all available images.
+ @$(MAKE) -C images $@
+
+##
+## Canonical build and test targets.
+##
+## These targets are used by continuous integration and provide
+## convenient entrypoints for testing changes. If you're adding a
+## new subsystem or workflow, consider adding a new target here.
+##
+runsc: ## Builds the runsc binary.
+ @$(MAKE) build TARGETS="//runsc"
+.PHONY: runsc
+
+smoke-test: ## Runs a simple smoke test after building runsc.
+ @$(MAKE) run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true"
+.PHONY: smoke-test
+
+unit-tests: ## Runs all unit tests in pkg, runsc, and tools.
+ @$(MAKE) test TARGETS="pkg/... runsc/... tools/..."
+.PHONY: unit-tests
+
+tests: ## Runs all local ptrace system call tests.
+ @$(MAKE) test OPTIONS="--test_tag_filters=runsc_ptrace" TARGETS="test/syscalls/..."
+.PHONY: tests
+
+##
+## Website & documentation helpers.
+##
+## The website is built from repository documentation and wrappers, using
+## a locally-defined Docker image (see images/jekyll). The following
+## variables may be set when using website-push:
+## WEBSITE_IMAGE - The name of the container image.
+## WEBSITE_SERVICE - The backend service.
+## WEBSITE_PROJECT - The project id to use.
+## WEBSITE_REGION - The region to deploy to.
+##
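+## For example ('my-project' is an illustrative project id):
+##   make website-deploy WEBSITE_PROJECT=my-project WEBSITE_IMAGE=gcr.io/my-project/website
+##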
+WEBSITE_IMAGE := gcr.io/gvisordev/gvisordev
+WEBSITE_SERVICE := gvisordev
+WEBSITE_PROJECT := gvisordev
+WEBSITE_REGION := us-central1
+
+website-build: load-jekyll ## Build the site image locally.
+ @$(MAKE) run TARGETS="//website:website"
+.PHONY: website-build
+
+website-server: website-build ## Run a local server for development.
+ @docker run -i -p 8080:8080 gvisor.dev/images/website
+.PHONY: website-server
+
+website-push: website-build ## Push a new image and update the service.
+ @docker tag gvisor.dev/images/website $(WEBSITE_IMAGE) && docker push $(WEBSITE_IMAGE)
+.PHONY: website-push
+
+website-deploy: website-push ## Deploy a new version of the website.
+ @gcloud run deploy $(WEBSITE_SERVICE) --platform=managed --region=$(WEBSITE_REGION) --project=$(WEBSITE_PROJECT) --image=$(WEBSITE_IMAGE)
+.PHONY: website-deploy
+
+##
+## Repository builders.
+##
+## This builds a local apt repository. The following variables may be set:
+## RELEASE_ROOT - The repository root (default: "repo" directory).
+## RELEASE_KEY - The repository GPG private key file (default: dummy key is created).
+## RELEASE_NIGHTLY - Set to true if a nightly release (default: false).
+## RELEASE_COMMIT - The commit or Change-Id for the release (needed for tag).
+## RELEASE_NAME - The name of the release in the proper format (needed for tag).
+## RELEASE_NOTES - The file containing release notes (needed for tag).
+##
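+## For example, to build a nightly release under the default root (a sketch;
+## a throwaway signing key is generated when RELEASE_KEY is not provided):
+##   make release RELEASE_NIGHTLY=true
+##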
+RELEASE_ROOT := $(CURDIR)/repo
+RELEASE_KEY := repo.key
+RELEASE_NIGHTLY := false
+RELEASE_COMMIT :=
+RELEASE_NAME :=
+RELEASE_NOTES :=
+
+GPG_TEST_OPTIONS := $(shell if gpg --pinentry-mode loopback --version >/dev/null 2>&1; then echo --pinentry-mode loopback; fi)
+$(RELEASE_KEY):
+ @echo "WARNING: Generating a key for testing ($@); don't use this."
+ T=$$(mktemp /tmp/keyring.XXXXXX); \
+ C=$$(mktemp /tmp/config.XXXXXX); \
+ echo Key-Type: DSA >> $$C && \
+ echo Key-Length: 1024 >> $$C && \
+ echo Name-Real: Test >> $$C && \
+ echo Name-Email: test@example.com >> $$C && \
+ echo Expire-Date: 0 >> $$C && \
+ echo %commit >> $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --keyring $$T --no-tty --gen-key $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --keyring $$T --secret-keyring $$T > $@; \
+ rc=$$?; rm -f $$T $$C; exit $$rc
+
+release: $(RELEASE_KEY) ## Builds a release.
+ @mkdir -p $(RELEASE_ROOT)
+ @T=$$(mktemp -d /tmp/release.XXXXXX); \
+ $(MAKE) copy TARGETS="runsc" DESTINATION=$$T && \
+ $(MAKE) copy TARGETS="runsc:runsc-debian" DESTINATION=$$T && \
+ NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \
+ rc=$$?; rm -rf $$T; exit $$rc
+.PHONY: release
+
+tag: ## Creates and pushes a release tag.
+ @tools/tag_release.sh "$(RELEASE_COMMIT)" "$(RELEASE_NAME)" "$(RELEASE_NOTES)"
+.PHONY: tag
+
+##
+## Development helpers and tooling.
+##
+## These targets facilitate local development by automatically
+## installing and configuring a runtime. Several variables may
+## be used here to tweak the installation:
+## RUNTIME - The name of the installed runtime (default: branch).
+## RUNTIME_DIR - Where the runtime will be installed (default: a temporary directory named after $RUNTIME).
+## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc).
+## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs).
+## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%).
+##
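+## For example ('my-branch' stands in for the current branch name):
+##   make dev
+##   docker run --rm --runtime=my-branch hello-world
+##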
+ifeq (,$(BRANCH_NAME))
+RUNTIME := runsc
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/runsc
+else
+RUNTIME := $(BRANCH_NAME)
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(BRANCH_NAME)
+endif
+RUNTIME_BIN := $(RUNTIME_DIR)/runsc
+RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs
+RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+
+dev: ## Installs a set of local runtimes. Requires sudo.
+ @$(MAKE) refresh ARGS="--net-raw"
+ @$(MAKE) configure RUNTIME="$(RUNTIME)" ARGS="--net-raw"
+ @$(MAKE) configure RUNTIME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets"
+ @$(MAKE) configure RUNTIME="$(RUNTIME)-p" ARGS="--net-raw --profile"
+ @$(MAKE) configure RUNTIME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2"
+ @sudo systemctl restart docker
+.PHONY: dev
+
+refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'test-install' first.
+ @mkdir -p "$(RUNTIME_DIR)"
+ @$(MAKE) copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)" && chmod 0755 "$(RUNTIME_BIN)"
+.PHONY: refresh
+
+test-install: ## Installs the runtime for testing. Requires sudo.
+ @$(MAKE) refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)"
+ @$(MAKE) configure
+ @sudo systemctl restart docker
+.PHONY: test-install
+
+configure: ## Configures a single runtime. Requires sudo. Typically called from dev or test-install.
+ @sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS)
+ @echo "Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)"
+ @echo "Logs are in: $(RUNTIME_LOG_DIR)"
+ @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)"
+.PHONY: configure
+
+test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided.
+ @$(MAKE) test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)"
+.PHONY: test-runtime
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..0e3d96b68
--- /dev/null
+++ b/README.md
@@ -0,0 +1,122 @@
+![gVisor](g3doc/logo.png)
+
+![](https://github.com/google/gvisor/workflows/Build/badge.svg)
+[![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community)
+
+## What is gVisor?
+
+**gVisor** is an application kernel, written in Go, that implements a
+substantial portion of the Linux system surface. It includes an
+[Open Container Initiative (OCI)][oci] runtime called `runsc` that provides an
+isolation boundary between the application and the host kernel. The `runsc`
+runtime integrates with Docker and Kubernetes, making it simple to run sandboxed
+containers.
+
+## Why does gVisor exist?
+
+Containers are not a [**sandbox**][sandbox]. While containers have
+revolutionized how we develop, package, and deploy applications, using them to
+run untrusted or potentially malicious code without additional isolation is not
+a good idea. While using a single, shared kernel allows for efficiency and
+performance gains, it also means that container escape is possible with a single
+vulnerability.
+
+gVisor is an application kernel for containers. It limits the host kernel
+surface accessible to the application while still giving the application access
+to all the features it expects. Unlike most kernels, gVisor does not assume or
+require a fixed set of physical resources; instead, it leverages existing host
+kernel functionality and runs as a normal process. In other words, gVisor
+implements Linux by way of Linux.
+
+gVisor should not be confused with technologies and tools to harden containers
+against external threats, provide additional integrity checks, or limit the
+scope of access for a service. One should always be careful about what data is
+made available to a container.
+
+## Documentation
+
+User documentation and technical architecture, including quick start guides, can
+be found at [gvisor.dev][gvisor-dev].
+
+## Installing from source
+
+gVisor builds on x86_64 and ARM64. Other architectures may become available in
+the future.
+
+For the purposes of these instructions, [bazel][bazel] and other build
+dependencies are wrapped in a build container. It is possible to use
+[bazel][bazel] directly, or type `make help` for standard targets.
+
+### Requirements
+
+Make sure the following dependencies are installed:
+
+* Linux 4.14.77+ ([older linux][old-linux])
+* [Docker version 17.09.0 or greater][docker]
+
+### Building
+
+Build and install the `runsc` binary:
+
+```
+make runsc
+sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin
+```
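+
+To verify the installed binary (a quick sanity check):
+
+```
+runsc -version
+```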
+
+### Testing
+
+To run standard test suites, you can use:
+
+```
+make unit-tests
+make tests
+```
+
+To run specific tests, you can specify the target:
+
+```
+make test TARGETS="//runsc:version_test"
+```
+
+### Using `go get`
+
+This project uses [bazel][bazel] to build and manage dependencies. A synthetic
+`go` branch is maintained that is compatible with standard `go` tooling for
+convenience.
+
+For example, to build `runsc` directly from this branch:
+
+```
+echo "module runsc" > go.mod
+GO111MODULE=on go get gvisor.dev/gvisor/runsc@go
+CGO_ENABLED=0 GO111MODULE=on go install gvisor.dev/gvisor/runsc
+```
+
+Note that this branch is supported in a best-effort capacity, and direct
+development on this branch is not supported. Development should occur on the
+`master` branch, which is then reflected into the `go` branch.
+
+## Community & Governance
+
+See [GOVERNANCE.md](GOVERNANCE.md) for project governance information.
+
+The [gvisor-users mailing list][gvisor-users-list] and
+[gvisor-dev mailing list][gvisor-dev-list] are good starting points for
+questions and discussion.
+
+## Security Policy
+
+See [SECURITY.md](SECURITY.md).
+
+## Contributing
+
+See [CONTRIBUTING.md](CONTRIBUTING.md).
+
+[bazel]: https://bazel.build
+[docker]: https://www.docker.com
+[gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users
+[gvisor-dev]: https://gvisor.dev
+[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev
+[oci]: https://www.opencontainers.org
+[old-linux]: https://gvisor.dev/docs/user_guide/networking/#gso
+[sandbox]: https://en.wikipedia.org/wiki/Sandbox_(computer_security)
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 000000000..a96843895
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,10 @@
+# Security and Vulnerability Reporting
+
+Sensitive security-related questions, comments, and reports should be sent to
+the [gvisor-security mailing list][gvisor-security-list]. You should receive a
+prompt response, typically within 48 hours.
+
+Policies for security list access, vulnerability embargo, and vulnerability
+disclosure are outlined in the [governance policy](GOVERNANCE.md).
+
+[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 000000000..417ec6100
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,559 @@
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+
+# Bazel/starlark utilities.
+http_archive(
+ name = "bazel_skylib",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
+ ],
+ sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
+)
+
+load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
+
+bazel_skylib_workspace()
+
+# Load go bazel rules and gazelle.
+#
+# Note that this repository actually patches some other Go repositories as
+# they are loaded, in order to limit visibility. We hack this process by
+# patching the patch used by the Go rules, turning the trick against itself.
+http_archive(
+ name = "io_bazel_rules_go",
+ patch_args = ["-p1"],
+ patches = [
+ "//tools/nogo:io_bazel_rules_go-visibility.patch",
+ ],
+ sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "bazel_gazelle",
+ sha256 = "d8c45ee70ec39a57e7a05e5027c32b1576cc7f16d9dd37135b0eddde45cf1b10",
+ urls = [
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
+ ],
+)
+
+load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
+
+go_rules_dependencies()
+
+go_register_toolchains(go_version = "1.14.2")
+
+load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
+
+gazelle_dependencies()
+
+# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories"
+# block below once 1876 is fixed.
+#
+# The com_google_protobuf repository below would trigger downloading an older
+# version of org_golang_x_sys. If this repository statement were placed after
+# com_google_protobuf, it would not download the newer version of
+# org_golang_x_sys as expected.
+go_repository(
+ name = "org_golang_x_sys",
+ importpath = "golang.org/x/sys",
+ sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=",
+ version = "v0.0.0-20200302150141-5c8b2ff67527",
+)
+
+# Load C++ rules.
+http_archive(
+ name = "rules_cc",
+ sha256 = "67412176974bfce3f4cf8bdaff39784a72ed709fc58def599d1f68710b58d68b",
+ strip_prefix = "rules_cc-b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip",
+ "https://github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip",
+ ],
+)
+
+# Load protobuf dependencies.
+http_archive(
+ name = "rules_proto",
+ sha256 = "602e7161d9195e50246177e7c55b2f39950a9cf7366f74ed5f22fd45750cd208",
+ strip_prefix = "rules_proto-97d8af4dc474595af3900dd85cb3a29ad28cc313",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
+ "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
+ ],
+)
+
+load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
+
+rules_proto_dependencies()
+
+rules_proto_toolchains()
+
+# Load python dependencies.
+git_repository(
+ name = "rules_python",
+ commit = "abc4869e02fe9b3866942e89f07b7341f830e805",
+ remote = "https://github.com/bazelbuild/rules_python.git",
+ shallow_since = "1583341286 -0500",
+)
+
+load("@rules_python//python:pip.bzl", "pip_import")
+
+pip_import(
+ name = "pydeps",
+ python_interpreter = "python3",
+ requirements = "//benchmarks:requirements.txt",
+)
+
+load("@pydeps//:requirements.bzl", "pip_install")
+
+pip_install()
+
+# Load bazel_toolchain to support Remote Build Execution.
+# See releases at https://releases.bazel.build/bazel-toolchains.html
+http_archive(
+ name = "bazel_toolchains",
+ sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e",
+ strip_prefix = "bazel-toolchains-3.0.1",
+ urls = [
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
+ ],
+)
+
+# Creates a default toolchain config for RBE.
+load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
+
+rbe_autoconfig(name = "rbe_default")
+
+http_archive(
+ name = "rules_pkg",
+ sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c",
+ url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz",
+)
+
+load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
+
+rules_pkg_dependencies()
+
+# Load C++ grpc rules.
+http_archive(
+ name = "com_github_grpc_grpc",
+ sha256 = "2fcb7f1ab160d6fd3aaade64520be3e5446fc4c6fa7ba6581afdc4e26094bd81",
+ strip_prefix = "grpc-1.26.0",
+ urls = [
+ "https://github.com/grpc/grpc/archive/v1.26.0.tar.gz",
+ ],
+)
+
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+
+grpc_deps()
+
+load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
+
+grpc_extra_deps()
+
+# External repositories, in sorted order.
+go_repository(
+ name = "com_github_cenkalti_backoff",
+ importpath = "github.com/cenkalti/backoff",
+ sum = "h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=",
+ version = "v0.0.0-20190506075156-2146c9339422",
+)
+
+go_repository(
+ name = "com_github_gofrs_flock",
+ importpath = "github.com/gofrs/flock",
+ sum = "h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=",
+ version = "v0.6.1-0.20180915234121-886344bea079",
+)
+
+go_repository(
+ name = "com_github_golang_mock",
+ importpath = "github.com/golang/mock",
+ sum = "h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=",
+ version = "v1.3.1",
+)
+
+go_repository(
+ name = "com_github_google_go-cmp",
+ importpath = "github.com/google/go-cmp",
+ sum = "h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=",
+ version = "v0.2.0",
+)
+
+go_repository(
+ name = "com_github_google_subcommands",
+ importpath = "github.com/google/subcommands",
+ sum = "h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=",
+ version = "v0.0.0-20190508160503-636abe8753b8",
+)
+
+go_repository(
+ name = "com_github_google_uuid",
+ importpath = "github.com/google/uuid",
+ sum = "h1:rXQlD9GXkjA/PQZhmEaF/8Pj/sJfdZJK7GJG0gkS8I0=",
+ version = "v0.0.0-20171129191014-dec09d789f3d",
+)
+
+go_repository(
+ name = "com_github_kr_pretty",
+ importpath = "github.com/kr/pretty",
+ sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=",
+ version = "v0.2.0",
+)
+
+go_repository(
+ name = "com_github_kr_pty",
+ importpath = "github.com/kr/pty",
+ sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=",
+ version = "v1.1.1",
+)
+
+go_repository(
+ name = "com_github_kr_text",
+ importpath = "github.com/kr/text",
+ sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_mohae_deepcopy",
+ commit = "c48cc78d482608239f6c4c92a4abd87eb8761c90",
+ importpath = "github.com/mohae/deepcopy",
+)
+
+go_repository(
+ name = "com_github_opencontainers_runtime-spec",
+ importpath = "github.com/opencontainers/runtime-spec",
+ sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=",
+ version = "v0.1.2-0.20171211145439-b2d941ef6a78",
+)
+
+go_repository(
+ name = "com_github_syndtr_gocapability",
+ importpath = "github.com/syndtr/gocapability",
+ sum = "h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=",
+ version = "v0.0.0-20180916011248-d98352740cb2",
+)
+
+go_repository(
+ name = "com_github_vishvananda_netlink",
+ importpath = "github.com/vishvananda/netlink",
+ sum = "h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=",
+ version = "v1.0.1-0.20190318003149-adb577d4a45e",
+)
+
+go_repository(
+ name = "com_github_vishvananda_netns",
+ importpath = "github.com/vishvananda/netns",
+ sum = "h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=",
+ version = "v0.0.0-20171111001504-be1fbeda1936",
+)
+
+go_repository(
+ name = "org_golang_google_grpc",
+ build_file_proto_mode = "disable",
+ importpath = "google.golang.org/grpc",
+ sum = "h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=",
+ version = "v1.27.1",
+)
+
+go_repository(
+ name = "in_gopkg_check_v1",
+ importpath = "gopkg.in/check.v1",
+ sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=",
+ version = "v1.0.0-20190902080502-41f04d3bba15",
+)
+
+go_repository(
+ name = "org_golang_x_crypto",
+ importpath = "golang.org/x/crypto",
+ sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=",
+ version = "v0.0.0-20190308221718-c2843e01d9a2",
+)
+
+go_repository(
+ name = "org_golang_x_mod",
+ importpath = "golang.org/x/mod",
+ sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=",
+ version = "v0.2.1-0.20200224194123-e5e73c1b9c72",
+)
+
+go_repository(
+ name = "org_golang_x_net",
+ importpath = "golang.org/x/net",
+ sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=",
+ version = "v0.0.0-20190620200207-3b0461eec859",
+)
+
+go_repository(
+ name = "org_golang_x_sync",
+ importpath = "golang.org/x/sync",
+ sum = "h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=",
+ version = "v0.0.0-20190423024810-112230192c58",
+)
+
+go_repository(
+ name = "org_golang_x_text",
+ importpath = "golang.org/x/text",
+ sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=",
+ version = "v0.3.0",
+)
+
+go_repository(
+ name = "org_golang_x_time",
+ importpath = "golang.org/x/time",
+ sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=",
+ version = "v0.0.0-20191024005414-555d28b269f0",
+)
+
+go_repository(
+ name = "org_golang_x_tools",
+ importpath = "golang.org/x/tools",
+ sum = "h1:Uglradbb4KfUWaYasZhlsDsGRwHHvRsHoNAEONef0W8=",
+ version = "v0.0.0-20200131233409-575de47986ce",
+)
+
+go_repository(
+ name = "org_golang_x_xerrors",
+ importpath = "golang.org/x/xerrors",
+ sum = "h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=",
+ version = "v0.0.0-20190717185122-a985d3407aa7",
+)
+
+go_repository(
+ name = "com_github_google_btree",
+ importpath = "github.com/google/btree",
+ sum = "h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_golang_protobuf",
+ importpath = "github.com/golang/protobuf",
+ sum = "h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=",
+ version = "v1.3.1",
+)
+
+go_repository(
+ name = "com_github_google_go-github",
+ importpath = "github.com/google/go-github",
+ sum = "h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=",
+ version = "v17.0.0",
+)
+
+go_repository(
+ name = "org_golang_x_oauth2",
+ importpath = "golang.org/x/oauth2",
+ sum = "h1:pE8b58s1HRDMi8RDc79m0HISf9D4TzseP40cEA6IGfs=",
+ version = "v0.0.0-20191202225959-858c2ad4c8b6",
+)
+
+go_repository(
+ name = "com_github_google_go-querystring",
+ importpath = "github.com/google/go-querystring",
+ sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_google_cloud_go_bigquery",
+ importpath = "cloud.google.com/go/bigquery",
+ sum = "h1:K2NyuHRuv15ku6eUpe0DQk5ZykPMnSOnvuVf6IHcjaE=",
+ version = "v1.5.0",
+)
+
+# Docker API dependencies.
+go_repository(
+ name = "com_github_docker_docker",
+ importpath = "github.com/docker/docker",
+ sum = "h1:iWPIG7pWIsCwT6ZtHnTUpoVMnete7O/pzd9HFE3+tn8=",
+ version = "v17.12.0-ce-rc1.0.20200618181300-9dc6525e6118+incompatible",
+)
+
+go_repository(
+ name = "com_github_docker_go_connections",
+ importpath = "github.com/docker/go-connections",
+ sum = "h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ=",
+ version = "v0.4.0",
+)
+
+go_repository(
+ name = "com_github_pkg_errors",
+ importpath = "github.com/pkg/errors",
+ sum = "h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=",
+ version = "v0.9.1",
+)
+
+go_repository(
+ name = "com_github_docker_go_units",
+ importpath = "github.com/docker/go-units",
+ sum = "h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=",
+ version = "v0.4.0",
+)
+
+go_repository(
+ name = "com_github_opencontainers_go_digest",
+ importpath = "github.com/opencontainers/go-digest",
+ sum = "h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_docker_distribution",
+ importpath = "github.com/docker/distribution",
+ sum = "h1:a5mlkVzth6W5A4fOsS3D2EO5BUmsJpcB+cRlLU7cSug=",
+ version = "v2.7.1+incompatible",
+)
+
+go_repository(
+ name = "com_github_davecgh_go_spew",
+ importpath = "github.com/davecgh/go-spew",
+ sum = "h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=",
+ version = "v1.1.1",
+)
+
+go_repository(
+ name = "com_github_konsorten_go_windows_terminal_sequences",
+ importpath = "github.com/konsorten/go-windows-terminal-sequences",
+ sum = "h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=",
+ version = "v2.7.1+incompatible",
+)
+
+go_repository(
+ name = "com_github_pmezard_go_difflib",
+ importpath = "github.com/pmezard/go-difflib",
+ sum = "h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_sirupsen_logrus",
+ importpath = "github.com/sirupsen/logrus",
+ sum = "h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=",
+ version = "v1.6.0",
+)
+
+go_repository(
+ name = "com_github_stretchr_testify",
+ importpath = "github.com/stretchr/testify",
+ sum = "h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=",
+ version = "v1.2.2",
+)
+
+go_repository(
+ name = "com_github_opencontainers_image_spec",
+ importpath = "github.com/opencontainers/image-spec",
+ sum = "h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=",
+ version = "v1.0.1",
+)
+
+go_repository(
+ name = "com_github_containerd_containerd",
+ importpath = "github.com/containerd/containerd",
+ sum = "h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=",
+ version = "v1.3.4",
+)
+
+go_repository(
+ name = "com_github_microsoft_go_winio",
+ importpath = "github.com/Microsoft/go-winio",
+ sum = "h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU=",
+ version = "v0.4.14",
+)
+
+go_repository(
+ name = "com_github_stretchr_objx",
+ importpath = "github.com/stretchr/objx",
+ sum = "h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=",
+ version = "v0.1.1",
+)
+
+go_repository(
+ name = "org_golang_google_api",
+ importpath = "google.golang.org/api",
+ sum = "h1:jz2KixHX7EcCPiQrySzPdnYT7DbINAypCqKZ1Z7GM40=",
+ version = "v0.20.0",
+)
+
+go_repository(
+ name = "org_uber_go_atomic",
+ importpath = "go.uber.org/atomic",
+ sum = "h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=",
+ version = "v1.6.0",
+)
+
+go_repository(
+ name = "org_uber_go_multierr",
+ importpath = "go.uber.org/multierr",
+ sum = "h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=",
+ version = "v1.5.0",
+)
+
+# BigQuery Dependencies for Benchmarks
+go_repository(
+ name = "com_google_cloud_go",
+ importpath = "cloud.google.com/go",
+ sum = "h1:eoz/lYxKSL4CNAiaUJ0ZfD1J3bfMYbU5B3rwM1C1EIU=",
+ version = "v0.55.0",
+)
+
+go_repository(
+ name = "com_github_googleapis_gax_go_v2",
+ importpath = "github.com/googleapis/gax-go/v2",
+ sum = "h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=",
+ version = "v2.0.5",
+)
+
+go_repository(
+ name = "io_opencensus_go",
+ importpath = "go.opencensus.io",
+ sum = "h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8=",
+ version = "v0.22.3",
+)
+
+go_repository(
+ name = "com_github_golang_groupcache",
+ importpath = "github.com/golang/groupcache",
+ sum = "h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=",
+ version = "v0.0.0-20200121045136-8c9f03a8e57e",
+)
+
+# System Call test dependencies.
+http_archive(
+ name = "com_google_absl",
+ sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93",
+ strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e",
+ urls = [
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "com_google_googletest",
+ sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
+ strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "com_google_benchmark",
+ sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a",
+ strip_prefix = "benchmark-1.5.0",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ "https://github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ ],
+)
+
diff --git a/benchmarks/BUILD b/benchmarks/BUILD
new file mode 100644
index 000000000..389351210
--- /dev/null
+++ b/benchmarks/BUILD
@@ -0,0 +1,29 @@
+package(licenses = ["notice"])
+
+config_setting(
+ name = "gcloud_rule",
+ values = {
+ "define": "gcloud=off",
+ },
+)
+
+py_binary(
+ name = "benchmarks",
+ testonly = 1,
+ srcs = ["run.py"],
+ data = select({
+ ":gcloud_rule": [],
+ "//conditions:default": [
+ "//tools/vm:ubuntu1604",
+ "//tools/vm:zone",
+ ],
+ }),
+ main = "run.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ tags = [
+ "local",
+ "manual",
+ ],
+ deps = ["//benchmarks/runner"],
+)
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 000000000..814bcb220
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,186 @@
+# Benchmark tools
+
+These scripts are tools for collecting performance data for Docker-based tests.
+
+## Setup
+
+The scripts assume the following:
+
+* There are two sets of machines: one where the scripts will be run
+ (controller) and one or more machines on which docker containers will be run
+ (environment).
+* The controller machine must have bazel installed along with this source
+  code. You should be able to run a command like `bazel run //benchmarks --
+  list`.
+* Environment machines must have docker and the required runtimes installed.
+ More specifically, you should be able to run a command like: `docker run
+ --runtime=$RUNTIME your/image`.
+* The controller has an SSH private key which can be used to log in to
+  environment machines and run docker commands without using `sudo`. This is
+  not required if running locally via the `run-local` command.
+* The docker daemon on each of your environment machines is listening on
+ `unix:///var/run/docker.sock` (docker's default).
+
+For configuring the environment manually, consult the
+[dockerd documentation][dockerd].
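+
+For example, registering `runsc` as a dockerd runtime typically looks like the
+following sketch (the `runsc` path is an assumption; adjust it to your
+installation):
+
+```bash
+sudo tee /etc/docker/daemon.json <<'EOF'
+{
+  "runtimes": {
+    "runsc": {
+      "path": "/usr/local/bin/runsc"
+    }
+  }
+}
+EOF
+sudo systemctl restart docker
+```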
+
+## Running benchmarks
+
+### Locally
+
+The tool is built to use Google Cloud Platform by default, but it also
+supports running benchmarks locally. To run locally, run the following from
+the benchmarks directory:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- run-local startup
+
+...
+method,metric,result
+startup.empty,startup_time_ms,652.5772
+startup.node,startup_time_ms,1654.4042000000002
+startup.ruby,startup_time_ms,1429.835
+```
+
+The above command runs the startup benchmark locally, which consists of three
+benchmarks (empty, node, and ruby), against the default runtime, runc. Running
+against another installed runtime, such as runsc, is as simple as:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- run-local startup --runtime=runsc
+```
+
+Help is available for the tool and for each command:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- --help
+bazel run --define gcloud=off //benchmarks -- run-local --help
+```
+
+To list available benchmarks, use the `list` command:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- list
+
+...
+Benchmark: sysbench.cpu
+Metrics: events_per_second
+ Run sysbench CPU test. Additional arguments can be provided for sysbench.
+
+ :param max_prime: The maximum prime number to search.
+```
+
+You can select benchmarks by name or by regex, for example:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- run-local startup.node
+...
+metric,result
+startup_time_ms,1671.7178000000001
+
+```
+
+or
+
+```bash
+bazel run --define gcloud=off //benchmarks -- run-local s
+...
+method,metric,result
+startup.empty,startup_time_ms,1792.8292
+startup.node,startup_time_ms,3113.5274
+startup.ruby,startup_time_ms,3025.2424
+sysbench.cpu,cpu_events_per_second,12661.47
+sysbench.memory,memory_ops_per_second,7228268.44
+sysbench.mutex,mutex_time,17.4835
+sysbench.mutex,mutex_latency,3496.7
+sysbench.mutex,mutex_deviation,0.04
+syscall.syscall,syscall_time_ns,2065.0
+```
+
+You can run parameterized benchmarks, for example to run with different
+runtimes:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu
+```
+
+Or with different parameters:
+
+```bash
+bazel run --define gcloud=off //benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu
+```
+
+### On Google Compute Engine (GCE)
+
+Benchmarks may be run on GCE in an automated way. The default project configured
+for `gcloud` will be used.
+
+An additional parameter `installers` may be provided to ensure that the latest
+runtime is installed from the workspace. See the files in `tools/installers` for
+supported install targets.
+
+```bash
+bazel run //benchmarks -- run-gcp --installers=head --runtime=runsc sysbench.cpu
+```
+
+When running on GCE, the scripts generate a per-run SSH key, which is added to
+your project. The key is set to expire in GCE after 60 minutes and is stored in
+a temporary directory on the local machine running the scripts.
+
+## Writing benchmarks
+
+To write new benchmarks, you should familiarize yourself with the structure of
+the repository. There are three key components: the harness, the workloads,
+and the benchmarks themselves.
+
+## Harness
+
+The harness makes use of the [docker py SDK][docker-py]. It is advisable that
+you familiarize yourself with that API when making changes, specifically:
+
+* clients
+* containers
+* images
+
+In general, benchmarks need only interact with the `Machine` objects provided to
+the benchmark function, which are the machines defined in the environment. These
+objects allow the benchmark to define the relationships between different
+containers, and parse the output.
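+
+A minimal sketch of that interaction for a hypothetical two-machine benchmark
+(the workload name and command are illustrative, not part of this repository):
+
+```python
+def http(server, client, **kwargs):
+  image = server.pull("nginx")  # Build and tag the workload on the server.
+  with server.container(image, port=80, **kwargs).detach() as web:
+    host, port = web.address()
+    # Drive load from the client machine and return its output for parsing.
+    stdout, _ = client.run("ab -n 100 http://{host}:{port}/".format(
+        host=host, port=port))
+    return stdout
+```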
+
+## Workloads
+
+The harness requires workloads to run. These are all available in the
+`workloads` directory.
+
+In general, a workload consists of a Dockerfile to build it (while these are
+not hermetic, in general they should be as fixed and isolated as possible),
+some parsers for output if required, parser tests, and sample data. Provided
+the test is named after the workload package and contains a function named
+`sample`, that function will be used to automatically mock workload output
+when the `--mock` flag is provided to the main tool, as shown below.
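+
+A sketch of a hypothetical workload module following that convention (all
+names here are illustrative, not from this repository):
+
+```python
+# workloads/myworkload/__init__.py
+SAMPLE_DATA = "requests_per_second: 1234.5"
+
+
+def sample(**env) -> str:
+  """Returns sample output used to mock this workload under --mock."""
+  return SAMPLE_DATA
+
+
+def requests_per_second(output: str) -> float:
+  """Parses requests per second from workload output."""
+  return float(output.split(":")[1].strip())
+```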
+
+## Benchmarks
+
+Benchmarks define the tests themselves. All benchmarks have the following
+function signature:
+
+```python
+def my_func(output) -> float:
+ return float(output)
+
+@benchmark(metrics=my_func, machines=1)
+def my_benchmark(machine: machine.Machine, arg: str):
+ return "3.4432"
+```
+
+Each benchmark takes a variable number of positional arguments as
+`harness.Machine` objects and some set of keyword arguments. It is recommended
+that you accept arbitrary keyword arguments and pass them through when
+constructing the container under test.
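+
+For example, a benchmark can forward unrecognized keyword arguments directly
+to the container (a sketch under the same assumptions as the example above):
+
+```python
+@benchmark(metrics=my_func, machines=1)
+def my_benchmark(machine, **kwargs):
+  image = machine.pull("myworkload")
+  # Extra keyword arguments become container options.
+  return machine.container(image, **kwargs).run()
+```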
+
+To write a new benchmark, open a module in the `suites` directory and use the
+above signature. You should add a descriptive docstring that describes what
+your benchmark is and any test-centric arguments.
+
+[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/
+[docker-py]: https://docker-py.readthedocs.io/en/stable/
diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl
new file mode 100644
index 000000000..56d28223e
--- /dev/null
+++ b/benchmarks/defs.bzl
@@ -0,0 +1,14 @@
+"""Provides attributes common to many workload tests."""
+
+load("//tools:defs.bzl", "py_requirement")
+
+test_deps = [
+ py_requirement("attrs", direct = False),
+ py_requirement("atomicwrites", direct = False),
+ py_requirement("more-itertools", direct = False),
+ py_requirement("pathlib2", direct = False),
+ py_requirement("pluggy", direct = False),
+ py_requirement("py", direct = False),
+ py_requirement("pytest"),
+ py_requirement("six", direct = False),
+]
diff --git a/benchmarks/examples/localhost.yaml b/benchmarks/examples/localhost.yaml
new file mode 100644
index 000000000..f70fe0fb7
--- /dev/null
+++ b/benchmarks/examples/localhost.yaml
@@ -0,0 +1,2 @@
+client: localhost
+server: localhost
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
new file mode 100644
index 000000000..48c548d59
--- /dev/null
+++ b/benchmarks/harness/BUILD
@@ -0,0 +1,202 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "installers",
+ srcs = [
+ "//tools/installers:head",
+ "//tools/installers:master",
+ "//tools/installers:runsc",
+ ],
+ mode = "0755",
+)
+
+filegroup(
+ name = "files",
+ srcs = [
+ ":installers",
+ ],
+)
+
+py_library(
+ name = "harness",
+ srcs = ["__init__.py"],
+ data = [
+ ":files",
+ ],
+)
+
+py_library(
+ name = "benchmark_driver",
+ srcs = ["benchmark_driver.py"],
+ deps = [
+ "//benchmarks/harness/machine_mocks",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ "//benchmarks/suites",
+ ],
+)
+
+py_library(
+ name = "container",
+ srcs = ["container.py"],
+ deps = [
+ "//benchmarks/workloads",
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
+ ],
+)
+
+py_library(
+ name = "machine",
+ srcs = ["machine.py"],
+ deps = [
+ "//benchmarks/harness",
+ "//benchmarks/harness:container",
+ "//benchmarks/harness:ssh_connection",
+ "//benchmarks/harness:tunnel_dispatcher",
+ "//benchmarks/harness/machine_mocks",
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "six",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
+ ],
+)
+
+py_library(
+ name = "ssh_connection",
+ srcs = ["ssh_connection.py"],
+ deps = [
+ "//benchmarks/harness",
+ py_requirement(
+ "bcrypt",
+ direct = False,
+ ),
+ py_requirement("cffi"),
+ py_requirement("paramiko"),
+ py_requirement(
+ "cryptography",
+ direct = False,
+ ),
+ ],
+)
+
+py_library(
+ name = "tunnel_dispatcher",
+ srcs = ["tunnel_dispatcher.py"],
+ deps = [
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement("pexpect"),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
+ ],
+)
diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py
new file mode 100644
index 000000000..15aa2a69a
--- /dev/null
+++ b/benchmarks/harness/__init__.py
@@ -0,0 +1,62 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Core benchmark utilities."""
+
+import getpass
+import os
+import subprocess
+import tempfile
+
+# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a
+# format string that accepts a single string parameter.
+LOCAL_WORKLOADS_PATH = os.path.dirname(__file__) + "/../workloads/{}/tar.tar"
+
+# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the
+# remote host. This is a format string that accepts a single string parameter.
+REMOTE_WORKLOADS_PATH = "workloads/{}"
+
+# INSTALLER_ARCHIVE is the tarball of installer files that needs to be copied.
+INSTALLER_ARCHIVE = os.readlink(os.path.join(
+ os.path.dirname(__file__), "installers.tar"))
+
+# SSH_KEY_DIR holds SSH_PRIVATE_KEY for this run. bm-tools paramiko requires
+# keys generated with the '-t rsa -m PEM' options from ssh-keygen. This is
+# abstracted away from the user.
+SSH_KEY_DIR = tempfile.TemporaryDirectory()
+SSH_PRIVATE_KEY = "key"
+
+# DEFAULT_USER is the default user running this script.
+DEFAULT_USER = getpass.getuser()
+
+# DEFAULT_USER_HOME is the home directory of the user running the script.
+DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else ""
+
+# REMOTE_INSTALLERS_PATH is the default remote directory into which "installer"
+# targets are unpacked and run.
+REMOTE_INSTALLERS_PATH = "installers"
+
+
+def make_key():
+ """Wraps a valid ssh key in a temporary directory."""
+ path = os.path.join(SSH_KEY_DIR.name, SSH_PRIVATE_KEY)
+ if not os.path.exists(path):
+ cmd = "ssh-keygen -t rsa -m PEM -b 4096 -f {key} -q -N".format(
+ key=path).split(" ")
+ cmd.append("")
+ subprocess.run(cmd, check=True)
+ return path
+
+
+def delete_key():
+ """Deletes temporary directory containing private key."""
+ SSH_KEY_DIR.cleanup()
diff --git a/benchmarks/harness/benchmark_driver.py b/benchmarks/harness/benchmark_driver.py
new file mode 100644
index 000000000..9abc21b54
--- /dev/null
+++ b/benchmarks/harness/benchmark_driver.py
@@ -0,0 +1,85 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Main driver for benchmarks."""
+
+import copy
+import statistics
+import threading
+import types
+
+from benchmarks import suites
+from benchmarks.harness.machine_producers import machine_producer
+
+
+# pylint: disable=too-many-instance-attributes
+class BenchmarkDriver:
+ """Allocates machines and invokes a benchmark method."""
+
+ def __init__(self,
+ producer: machine_producer.MachineProducer,
+ method: types.FunctionType,
+ runs: int = 1,
+ **kwargs):
+
+ self._producer = producer
+ self._method = method
+ self._kwargs = copy.deepcopy(kwargs)
+ self._threads = []
+ self.lock = threading.RLock()
+ self._runs = runs
+ self._metric_results = {}
+
+ def start(self):
+ """Starts a benchmark thread."""
+ for _ in range(self._runs):
+ thread = threading.Thread(target=self._run_method)
+ thread.start()
+ self._threads.append(thread)
+
+  def join(self):
+    """Joins all benchmark threads."""
+    for thread in self._threads:
+      thread.join()
+
+ def _run_method(self):
+ """Runs all benchmarks."""
+ machines = self._producer.get_machines(
+ suites.benchmark_machines(self._method))
+ try:
+ result = self._method(*machines, **self._kwargs)
+ for name, res in result:
+ with self.lock:
+ if name in self._metric_results:
+ self._metric_results[name].append(res)
+ else:
+ self._metric_results[name] = [res]
+ finally:
+ # Always release.
+ self._producer.release_machines(machines)
+
+ def median(self):
+ """Returns the median result, after join is finished."""
+ for key, value in self._metric_results.items():
+ yield key, [statistics.median(value)]
+
+ def all(self):
+ """Returns all results."""
+ for key, value in self._metric_results.items():
+ yield key, value
+
+ def meanstd(self):
+ """Returns all results."""
+ for key, value in self._metric_results.items():
+ mean = statistics.mean(value)
+ yield key, [mean, statistics.stdev(value, xbar=mean)]
diff --git a/benchmarks/harness/container.py b/benchmarks/harness/container.py
new file mode 100644
index 000000000..585436e20
--- /dev/null
+++ b/benchmarks/harness/container.py
@@ -0,0 +1,181 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Container definitions."""
+
+import contextlib
+import logging
+import pydoc
+import types
+from typing import Tuple
+
+import docker
+import docker.errors
+
+from benchmarks import workloads
+
+
+class Container:
+ """Abstract container.
+
+ Must be a context manager.
+
+ Usage:
+
+ with Container(client, image, ...):
+ ...
+ """
+
+ def run(self, **env) -> str:
+ """Run the container synchronously."""
+ raise NotImplementedError
+
+ def detach(self, **env):
+ """Run the container asynchronously."""
+ raise NotImplementedError
+
+ def address(self) -> Tuple[str, int]:
+ """Return the bound address for the container."""
+ raise NotImplementedError
+
+ def get_names(self) -> types.GeneratorType:
+ """Return names of all containers."""
+ raise NotImplementedError
+
+
+# pylint: disable=too-many-instance-attributes
+class DockerContainer(Container):
+ """Class that handles creating a docker container."""
+
+ # pylint: disable=too-many-arguments
+ def __init__(self,
+ client: docker.DockerClient,
+ host: str,
+ image: str,
+ count: int = 1,
+ runtime: str = "runc",
+ port: int = 0,
+ **kwargs):
+ """Trys to setup "count" containers.
+
+ Args:
+ client: A docker client from dockerpy.
+ host: The host address the image is running on.
+ image: The name of the image to run.
+ count: The number of containers to setup.
+ runtime: The container runtime to use.
+ port: The port to reserve.
+ **kwargs: Additional container options.
+ """
+ assert count >= 1
+ assert port == 0 or count == 1
+ self._client = client
+ self._host = host
+ self._containers = []
+ self._count = count
+ self._image = image
+ self._runtime = runtime
+ self._port = port
+ self._kwargs = kwargs
+ if port != 0:
+ self._ports = {"%d/tcp" % port: None}
+ else:
+ self._ports = {}
+
+ @contextlib.contextmanager
+ def detach(self, **env):
+ env = ["%s=%s" % (key, value) for (key, value) in env.items()]
+ # Start all containers.
+ for _ in range(self._count):
+ try:
+ # Start the container in a detached mode.
+ container = self._client.containers.run(
+ self._image,
+ detach=True,
+ remove=True,
+ runtime=self._runtime,
+ ports=self._ports,
+ environment=env,
+ **self._kwargs)
+ logging.info("Started detached container %s -> %s", self._image,
+ container.attrs["Id"])
+ self._containers.append(container)
+ except Exception as exc:
+ self._clean_containers()
+ raise exc
+ try:
+ # Wait for all containers to be up.
+ for container in self._containers:
+ while not container.attrs["State"]["Running"]:
+ container = self._client.containers.get(container.attrs["Id"])
+ yield self
+ finally:
+ self._clean_containers()
+
+ def address(self) -> Tuple[str, int]:
+ assert self._count == 1
+ assert self._port != 0
+ container = self._client.containers.get(self._containers[0].attrs["Id"])
+ port = container.attrs["NetworkSettings"]["Ports"][
+ "%d/tcp" % self._port][0]["HostPort"]
+    # HostPort is a string in the docker attrs; convert it to match the
+    # annotation.
+    return (self._host, int(port))
+
+ def get_names(self) -> types.GeneratorType:
+ for container in self._containers:
+ yield container.name
+
+ def run(self, **env) -> str:
+ env = ["%s=%s" % (key, value) for (key, value) in env.items()]
+ return self._client.containers.run(
+ self._image,
+ runtime=self._runtime,
+ ports=self._ports,
+ remove=True,
+ environment=env,
+ **self._kwargs).decode("utf-8")
+
+ def _clean_containers(self):
+ """Kills all containers."""
+ for container in self._containers:
+ try:
+ container.kill()
+ except docker.errors.NotFound:
+ pass
+
+
+class MockContainer(Container):
+ """Mock of Container."""
+
+ def __init__(self, workload: str):
+ self._workload = workload
+
+ def __enter__(self):
+ return self
+
+ def run(self, **env):
+    # Look up sample data, if any exists for the workload module. We use a
+    # well-defined module lookup (via pydoc.locate) and a well-defined sample
+    # function.
+ mod = pydoc.locate(workloads.__name__ + "." + self._workload)
+ if hasattr(mod, "sample"):
+ return mod.sample(**env)
+ return "" # No output.
+
+ def address(self) -> Tuple[str, int]:
+ return ("example.com", 80)
+
+ def get_names(self) -> types.GeneratorType:
+ yield "mock"
+
+ @contextlib.contextmanager
+ def detach(self, **env):
+ yield self
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
new file mode 100644
index 000000000..5bdc4aa85
--- /dev/null
+++ b/benchmarks/harness/machine.py
@@ -0,0 +1,265 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Machine abstraction passed to benchmarks to run docker containers.
+
+Abstraction for interacting with test machines. Machines are produced by
+machine producers and represent a local or remote machine. Benchmark methods
+in /benchmarks/suites are passed the required number of machines in order to
+run the benchmark. Machines contain methods to run commands via bash, possibly
+over ssh. Machines also hold a connection to the docker UNIX socket to run
+containers.
+
+ Typical usage example:
+
+ machine = Machine()
+ machine.run(cmd)
+ machine.pull(path)
+ container = machine.container()
+"""
+
+import logging
+import os
+import re
+import subprocess
+import time
+from typing import List, Tuple
+
+import docker
+
+from benchmarks import harness
+from benchmarks.harness import container
+from benchmarks.harness import machine_mocks
+from benchmarks.harness import ssh_connection
+from benchmarks.harness import tunnel_dispatcher
+
+log = logging.getLogger(__name__)
+
+
+class Machine(object):
+ """The machine object is the primary object for benchmarks.
+
+ Machine objects are passed to each metric function call and benchmarks use
+ machines to access real connections to those machines.
+
+ Attributes:
+ _name: Name as a string
+ """
+ _name = ""
+
+ def run(self, cmd: str) -> Tuple[str, str]:
+ """Convenience method for running a bash command on a machine object.
+
+    Some machines may point to the local machine and thus do not have ssh
+    connections. Run executes a command either locally or over ssh and returns
+    the output stdout and stderr as strings.
+
+ Args:
+ cmd: The command to run as a string.
+
+ Returns:
+ The command output.
+ """
+ raise NotImplementedError
+
+ def read(self, path: str) -> str:
+ """Reads the contents of some file.
+
+ This will be mocked.
+
+ Args:
+ path: The path to the file to be read.
+
+ Returns:
+ The file contents.
+ """
+ raise NotImplementedError
+
+ def pull(self, workload: str) -> str:
+ """Send the given workload to the machine, build and tag it.
+
+ All images must be defined by the workloads directory.
+
+ Args:
+ workload: The workload name.
+
+ Returns:
+ The workload tag.
+ """
+ raise NotImplementedError
+
+ def container(self, image: str, **kwargs) -> container.Container:
+ """Returns a container object.
+
+ Args:
+ image: The pulled image tag.
+ **kwargs: Additional container options.
+
+ Returns:
+ :return: a container.Container object.
+ """
+ raise NotImplementedError
+
+ def sleep(self, amount: float):
+ """Sleeps the given amount of time."""
+ time.sleep(amount)
+
+ def __str__(self):
+ return self._name
+
+
+class MockMachine(Machine):
+ """A mocked machine."""
+ _name = "mock"
+
+ def run(self, cmd: str) -> Tuple[str, str]:
+ return "", ""
+
+ def read(self, path: str) -> str:
+ return machine_mocks.Readfile(path)
+
+ def pull(self, workload: str) -> str:
+ return workload # Workload is the tag.
+
+ def container(self, image: str, **kwargs) -> container.Container:
+ return container.MockContainer(image)
+
+ def sleep(self, amount: float):
+ pass
+
+
+def get_address(machine: Machine) -> str:
+ """Return a machine's default address."""
+ default_route, _ = machine.run("ip route get 8.8.8.8")
+ return re.search(" src ([0-9.]+) ", default_route).group(1)
+
+
+class LocalMachine(Machine):
+ """The local machine.
+
+ Attributes:
+ _name: Name as a string
+    _docker_client: a pythonic connection to the local dockerd unix socket.
+ See: https://github.com/docker/docker-py
+ """
+
+ def __init__(self, name):
+ self._name = name
+ self._docker_client = docker.from_env()
+
+ def run(self, cmd: str) -> Tuple[str, str]:
+ process = subprocess.Popen(
+ cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = process.communicate()
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
+
+  def read(self, path: str) -> str:
+    # Read the exact path locally, closing the file when done.
+    with open(path, "r") as f:
+      return f.read()
+
+ def pull(self, workload: str) -> str:
+ # Run the docker build command locally.
+ logging.info("Building %s@%s locally...", workload, self._name)
+ with open(harness.LOCAL_WORKLOADS_PATH.format(workload),
+ "rb") as dockerfile:
+ self._docker_client.images.build(
+ fileobj=dockerfile, tag=workload, custom_context=True)
+ return workload # Workload is the tag.
+
+ def container(self, image: str, **kwargs) -> container.Container:
+ # Return a local docker container directly.
+ return container.DockerContainer(self._docker_client, get_address(self),
+ image, **kwargs)
+
+ def sleep(self, amount: float):
+ time.sleep(amount)
+
+
+class RemoteMachine(Machine):
+ """Remote machine accessible via an SSH connection.
+
+ Attributes:
+ _name: Name as a string
+ _ssh_connection: a paramiko backed ssh connection which can be used to run
+ commands on this machine
+ _tunnel: a python wrapper around a port forwarded ssh connection between a
+ local unix socket and the remote machine's dockerd unix socket.
+ _docker_client: a pythonic wrapper backed by the _tunnel. Allows sending
+ docker commands: see https://github.com/docker/docker-py
+ """
+
+ def __init__(self, name, **kwargs):
+ self._name = name
+ self._ssh_connection = ssh_connection.SSHConnection(name, **kwargs)
+ self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs)
+ self._tunnel.connect()
+ self._docker_client = self._tunnel.get_docker_client()
+ self._has_installers = False
+
+ def run(self, cmd: str) -> Tuple[str, str]:
+ return self._ssh_connection.run(cmd)
+
+ def read(self, path: str) -> str:
+ # Just cat remotely.
+ stdout, stderr = self._ssh_connection.run("cat '{}'".format(path))
+ return stdout + stderr
+
+ def install(self,
+ installer: str,
+ results: List[bool] = None,
+ index: int = -1):
+ """Method unique to RemoteMachine to handle installation of installers.
+
+ Handles installers, which install things that may change between runs (e.g.
+    runsc). Usually called from gcloud_producer, which expects this method to
+    store results.
+
+ Args:
+ installer: the installer target to run.
+ results: Passed by the caller of where to store success.
+ index: Index for this method to store the result in the passed results
+ list.
+ """
+    # This generates a tarball of the full installer root (which will
+    # generally be the full bazel root directory) and sends it over.
+ if not self._has_installers:
+ archive = self._ssh_connection.send_installers()
+ self.run("tar -xvf {archive} -C {dir}".format(
+ archive=archive, dir=harness.REMOTE_INSTALLERS_PATH))
+ self._has_installers = True
+
+ # Execute the remote installer.
+ self.run("sudo {dir}/{file}".format(
+ dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+
+ if results:
+ results[index] = True
+
+ def pull(self, workload: str) -> str:
+ # Push to the remote machine and build.
+ logging.info("Building %s@%s remotely...", workload, self._name)
+ remote_path = self._ssh_connection.send_workload(workload)
+ remote_dir = os.path.dirname(remote_path)
+ # Workloads are all tarballs.
+ self.run("tar -xvf {remote_path} -C {remote_dir}".format(
+ remote_path=remote_path, remote_dir=remote_dir))
+ self.run("docker build --tag={} {}".format(workload, remote_dir))
+ return workload # Workload is the tag.
+
+ def container(self, image: str, **kwargs) -> container.Container:
+ # Return a remote docker container.
+ return container.DockerContainer(self._docker_client, get_address(self),
+ image, **kwargs)
+
+ def sleep(self, amount: float):
+ time.sleep(amount)
diff --git a/benchmarks/harness/machine_mocks/BUILD b/benchmarks/harness/machine_mocks/BUILD
new file mode 100644
index 000000000..c8ec4bc79
--- /dev/null
+++ b/benchmarks/harness/machine_mocks/BUILD
@@ -0,0 +1,9 @@
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "machine_mocks",
+ srcs = ["__init__.py"],
+)
diff --git a/benchmarks/harness/machine_mocks/__init__.py b/benchmarks/harness/machine_mocks/__init__.py
new file mode 100644
index 000000000..00f0085d7
--- /dev/null
+++ b/benchmarks/harness/machine_mocks/__init__.py
@@ -0,0 +1,81 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Machine mock files."""
+
+MEMINFO = """\
+MemTotal: 7652344 kB
+MemFree: 7174724 kB
+MemAvailable: 7152008 kB
+Buffers: 7544 kB
+Cached: 178856 kB
+SwapCached: 0 kB
+Active: 270928 kB
+Inactive: 68436 kB
+Active(anon): 153124 kB
+Inactive(anon): 880 kB
+Active(file): 117804 kB
+Inactive(file): 67556 kB
+Unevictable: 0 kB
+Mlocked: 0 kB
+SwapTotal: 0 kB
+SwapFree: 0 kB
+Dirty: 900 kB
+Writeback: 0 kB
+AnonPages: 153000 kB
+Mapped: 129120 kB
+Shmem: 1044 kB
+Slab: 60864 kB
+SReclaimable: 22792 kB
+SUnreclaim: 38072 kB
+KernelStack: 2672 kB
+PageTables: 5756 kB
+NFS_Unstable: 0 kB
+Bounce: 0 kB
+WritebackTmp: 0 kB
+CommitLimit: 3826172 kB
+Committed_AS: 663836 kB
+VmallocTotal: 34359738367 kB
+VmallocUsed: 0 kB
+VmallocChunk: 0 kB
+HardwareCorrupted: 0 kB
+AnonHugePages: 0 kB
+ShmemHugePages: 0 kB
+ShmemPmdMapped: 0 kB
+CmaTotal: 0 kB
+CmaFree: 0 kB
+HugePages_Total: 0
+HugePages_Free: 0
+HugePages_Rsvd: 0
+HugePages_Surp: 0
+Hugepagesize: 2048 kB
+DirectMap4k: 94196 kB
+DirectMap2M: 4624384 kB
+DirectMap1G: 3145728 kB
+"""
+
+CONTENTS = {
+ "/proc/meminfo": MEMINFO,
+}
+
+
+def Readfile(path: str) -> str:
+ """Reads a mock file.
+
+ Args:
+ path: The target path.
+
+ Returns:
+ Mocked file contents or None.
+ """
+ return CONTENTS.get(path, None)
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
new file mode 100644
index 000000000..81f19bd08
--- /dev/null
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -0,0 +1,84 @@
+load("//tools:defs.bzl", "py_library", "py_requirement")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "harness",
+ srcs = ["__init__.py"],
+)
+
+py_library(
+ name = "machine_producer",
+ srcs = ["machine_producer.py"],
+)
+
+py_library(
+ name = "mock_producer",
+ srcs = ["mock_producer.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_producer",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ ],
+)
+
+py_library(
+ name = "yaml_producer",
+ srcs = ["yaml_producer.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ py_requirement(
+ "PyYAML",
+ direct = False,
+ ),
+ ],
+)
+
+py_library(
+ name = "gcloud_mock_recorder",
+ srcs = ["gcloud_mock_recorder.py"],
+)
+
+py_library(
+ name = "gcloud_producer",
+ srcs = ["gcloud_producer.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_mock_recorder",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ ],
+)
+
+filegroup(
+ name = "test_data",
+ srcs = [
+ "testdata/get_five.json",
+ "testdata/get_one.json",
+ ],
+)
+
+py_library(
+ name = "gcloud_producer_test_lib",
+ srcs = ["gcloud_producer_test.py"],
+ deps = [
+ "//benchmarks/harness/machine_producers:machine_producer",
+ "//benchmarks/harness/machine_producers:mock_producer",
+ ],
+)
+
+py_test(
+ name = "gcloud_producer_test",
+ srcs = [":gcloud_producer_test_lib"],
+ data = [
+ ":test_data",
+ ],
+ python_version = "PY3",
+ tags = [
+ "local",
+ "manual",
+ ],
+)
diff --git a/benchmarks/harness/machine_producers/__init__.py b/benchmarks/harness/machine_producers/__init__.py
new file mode 100644
index 000000000..634ef4843
--- /dev/null
+++ b/benchmarks/harness/machine_producers/__init__.py
@@ -0,0 +1,13 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/benchmarks/harness/machine_producers/gcloud_mock_recorder.py b/benchmarks/harness/machine_producers/gcloud_mock_recorder.py
new file mode 100644
index 000000000..fd9837a37
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_mock_recorder.py
@@ -0,0 +1,97 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A recorder and replay for testing the GCloudProducer.
+
+MockPrinter and MockReader handle printing and reading mock data for the
+purposes of testing. MockPrinter is passed to GCloudProducer objects. The user
+can then run scenarios and record them for playback in tests later.
+
+MockReader is passed to MockGcloudProducer objects and handles reading the
+previously recorded mock data.
+
+It is left to the user to check if data printed is properly redacted for their
+own use. The intended usecase for this class is data coming from gcloud
+commands, which will contain public IPs and other instance data.
+
+The data format is json and printed/read from the ./test_data directory. The
+data is the output of subprocess.CompletedProcess objects in json format.
+
+ Typical usage example:
+
+ recorder = MockPrinter()
+ producer = GCloudProducer(args, recorder)
+ machines = producer.get_machines(1)
+ with open("my_file.json") as fd:
+ recorder.write_out(fd)
+
+ reader = MockReader(filename)
+ producer = MockGcloudProducer(args, mock)
+ machines = producer.get_machines(1)
+ assert len(machines) == 1
+"""
+
+import io
+import json
+import subprocess
+
+
+class MockPrinter(object):
+ """Handles printing Mock data for MockGcloudProducer.
+
+ Attributes:
+ _records: list of json object records for printing
+ """
+
+ def __init__(self):
+ self._records = []
+
+ def record(self, entry: subprocess.CompletedProcess):
+ """Records data and strips out ip addresses."""
+
+ record = {
+ "args": entry.args,
+ "stdout": entry.stdout.decode("utf-8"),
+ "returncode": str(entry.returncode)
+ }
+ self._records.append(record)
+
+ def write_out(self, fd: io.FileIO):
+ """Prints out the data into the given filepath."""
+ fd.write(json.dumps(self._records, indent=4))
+
+
+class MockReader(object):
+ """Handles reading Mock data for MockGcloudProducer.
+
+ Attributes:
+ _records: List[json] records read from the passed in file.
+ """
+
+ def __init__(self, filepath: str):
+ with open(filepath, "rb") as file:
+ self._records = json.loads(file.read())
+ self._i = 0
+
+  def pop(self, args) -> subprocess.CompletedProcess:
+    """Returns the next record as a CompletedProcess."""
+    if self._i < len(self._records):
+      record = self._records[self._i]
+      self._i += 1
+      stdout = record["stdout"].encode("ascii")
+      returncode = int(record["returncode"])
+      return subprocess.CompletedProcess(
+          args=args, returncode=returncode, stdout=stdout, stderr=b"")
+    raise StopIteration()
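
The recorded entries round-trip through a small JSON schema (args, stdout,
returncode). A minimal sketch of one record/replay cycle, assuming the classes
above (the file path and the hand-built CompletedProcess are hypothetical):

    import subprocess

    from benchmarks.harness.machine_producers import gcloud_mock_recorder

    # Record one fake gcloud call; in practice the entry comes from
    # subprocess.run inside GCloudProducer._run_command.
    printer = gcloud_mock_recorder.MockPrinter()
    printer.record(
        subprocess.CompletedProcess(
            args=["gcloud", "compute", "instances", "list"],
            returncode=0,
            stdout=b"[]"))
    with open("/tmp/example.json", "w") as fd:  # hypothetical path
      printer.write_out(fd)

    # Replay: pop() rebuilds a CompletedProcess from the stored record.
    reader = gcloud_mock_recorder.MockReader("/tmp/example.json")
    replayed = reader.pop(["gcloud", "compute", "instances", "list"])
    assert replayed.returncode == 0
    assert replayed.stdout == b"[]"
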
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
new file mode 100644
index 000000000..44d72f575
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -0,0 +1,250 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A machine producer which produces machine objects using `gcloud`.
+
+Machine producers produce valid harness.Machine objects which are backed by
+real machines. This producer produces those machines on the given user's GCP
+account using the `gcloud` tool.
+
+GCloudProducer creates instances on the given GCP account named like:
+`machine-XXXXXXX-XXXX-XXXX-XXXXXXXXXXXX` in a randomized fashion such that name
+collisions with user instances shouldn't happen.
+
+ Typical usage example:
+
+ producer = GCloudProducer(args)
+ machines = producer.get_machines(NUM_MACHINES)
+ # run stuff on machines with machines[i].run(CMD)
+    producer.release_machines(machines)
+"""
+import datetime
+import json
+import subprocess
+import threading
+from typing import Any, Dict, List, Optional
+import uuid
+
+from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import machine_producer
+
+
+class GCloudProducer(machine_producer.MachineProducer):
+ """Implementation of MachineProducer backed by GCP.
+
+ Produces Machine objects backed by GCP instances.
+
+ Attributes:
+ image: image name as a string.
+    zone: string of a valid GCP zone.
+    machine_type: GCP machine type to create (e.g. n1-standard-4).
+    installers: list of installers to run post-boot.
+    ssh_key_file: path to a valid ssh private key. See README on valid ssh keys.
+    ssh_user: string of user name for the ssh key.
+    ssh_password: string of password for the ssh key.
+    internal: if true, use internal IPs of instances. Used if bm-tools is
+      running on a GCP vm when a firewall is set for external IPs.
+    mock: a mock printer which will print mock data if required. Mock data is
+      recorded output from subprocess calls (returncode, stdout, args).
+    condition: mutex for this class around machine creation and deletion.
+ """
+
+ def __init__(self,
+ image: str,
+ zone: str,
+ machine_type: str,
+ installers: List[str],
+ ssh_key_file: str,
+ ssh_user: str,
+ ssh_password: str,
+ internal: bool,
+ mock: gcloud_mock_recorder.MockPrinter = None):
+ self.image = image
+ self.zone = zone
+ self.machine_type = machine_type
+ self.installers = installers
+ self.ssh_key_file = ssh_key_file
+ self.ssh_user = ssh_user
+ self.ssh_password = ssh_password
+ self.internal = internal
+ self.mock = mock
+ self.condition = threading.Condition()
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns requested number of machines backed by GCP instances."""
+ if num_machines <= 0:
+ raise ValueError(
+ "Cannot ask for {num} machines!".format(num=num_machines))
+ with self.condition:
+ names = self._get_unique_names(num_machines)
+ instances = self._build_instances(names)
+ self._add_ssh_key_to_instances(names)
+ machines = self._machines_from_instances(instances)
+
+ # Install all bits in lock-step.
+ #
+    # This will perform parallel installations for however many machines we
+ # have, but it's easy to track errors because if installing (a, b, c), we
+ # won't install "c" until "b" is installed on all machines.
+ for installer in self.installers:
+ threads = [None] * len(machines)
+ results = [False] * len(machines)
+ for i in range(len(machines)):
+ threads[i] = threading.Thread(
+ target=machines[i].install, args=(installer, results, i))
+ threads[i].start()
+ for thread in threads:
+ thread.join()
+ for result in results:
+ if not result:
+          raise RuntimeError(
+              "Installers failed on at least one machine!")
+
+ # Add this user to each machine's docker group.
+ for m in machines:
+ m.run("sudo setfacl -m user:$USER:rw /var/run/docker.sock")
+
+ return machines
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ """Releases the requested number of machines, deleting the instances."""
+ if not machine_list:
+ return
+ cmd = "gcloud compute instances delete --quiet".split(" ")
+ names = [str(m) for m in machine_list]
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ self._run_command(cmd, detach=True)
+
+ def _machines_from_instances(
+ self, instances: List[Dict[str, Any]]) -> List[machine.Machine]:
+ """Creates Machine Objects from json data describing created instances."""
+ machines = []
+ for instance in instances:
+ name = instance["name"]
+ external = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"]
+ internal = instance["networkInterfaces"][0]["networkIP"]
+ kwargs = {
+ "hostname": internal if self.internal else external,
+ "key_path": self.ssh_key_file,
+ "username": self.ssh_user,
+ "key_password": self.ssh_password
+ }
+ machines.append(machine.RemoteMachine(name=name, **kwargs))
+ return machines
+
+  def _get_unique_names(self, num_names: int) -> List[str]:
+    """Returns num_names unique machine names based on random UUIDs."""
+    return ["machine-" + str(uuid.uuid4()) for _ in range(num_names)]
+
+ def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]:
+ """Creates instances using gcloud command.
+
+ Runs the command `gcloud compute instances create` and returns json data
+ on created instances on success. Creates len(names) instances, one for each
+ name.
+
+ Args:
+ names: list of names of instances to create.
+
+ Returns:
+ List of json data describing created machines.
+ """
+ if not names:
+ raise ValueError(
+ "_build_instances cannot create instances without names.")
+ cmd = "gcloud compute instances create".split(" ")
+ cmd.extend(names)
+ cmd.append("--image=" + self.image)
+ cmd.append("--zone=" + self.zone)
+ cmd.append("--machine-type=" + self.machine_type)
+ res = self._run_command(cmd)
+ data = res.stdout
+ data = str(data, "utf-8") if isinstance(data, (bytes, bytearray)) else data
+ return json.loads(data)
+
+ def _add_ssh_key_to_instances(self, names: List[str]) -> None:
+ """Adds ssh key to instances by calling gcloud ssh command.
+
+    Runs the command `gcloud compute ssh instance_name` on the list of
+    instances by name and tries to ssh into each one.
+
+ Args:
+ names: list of machine names to which to add the ssh-key
+ self.ssh_key_file.
+
+    Raises:
+      TimeoutError: when attempts to ssh into an instance keep failing (e.g.
+        returncode 255, connection closed by remote host) for the full 5
+        minute timeout.
+ """
+ for name in names:
+ cmd = "gcloud compute ssh {user}@{name}".format(
+ user=self.ssh_user, name=name).split(" ")
+ if self.internal:
+ cmd.append("--internal-ip")
+ cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--command=uname")
+ timeout = datetime.timedelta(seconds=5 * 60)
+ start = datetime.datetime.now()
+      while True:
+        try:
+          self._run_command(cmd)
+          break
+        except subprocess.CalledProcessError:
+          if datetime.datetime.now() > timeout + start:
+            raise TimeoutError(
+                "Could not SSH into instance after 5 min: {name}".format(
+                    name=name))
+
+  def _run_command(
+      self,
+      cmd: List[str],
+      detach: bool = False) -> Optional[subprocess.CompletedProcess]:
+ """Runs command as a subprocess.
+
+ Runs command as subprocess and returns the result.
+ If this has a mock recorder, use the record method to record the subprocess
+ call.
+
+ Args:
+ cmd: command to be run as a list of strings.
+ detach: if True, run the child process and don't wait for it to return.
+
+ Returns:
+ Completed process object to be parsed by caller or None if detach=True.
+
+ Raises:
+ CalledProcessError: if subprocess.run returns an error.
+ """
+ cmd = cmd + ["--format=json"]
+ if detach:
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ out, _ = p.communicate()
+ self.mock.record(
+ subprocess.CompletedProcess(
+ returncode=p.returncode, stdout=out, args=p.args))
+ return
+
+ res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ self.mock.record(res)
+ if res.returncode != 0:
+ raise subprocess.CalledProcessError(
+ cmd=" ".join(res.args),
+ output=res.stdout,
+ stderr=res.stderr,
+ returncode=res.returncode)
+ return res
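
get_machines installs each installer in lock-step: one thread per machine per
installer, all joined before the next installer starts, so a failure is
attributable to a single phase. A stripped-down sketch of that pattern,
independent of GCP (the phase names and install function are hypothetical):

    import threading

    def install(phase: str, results: list, index: int):
      # Hypothetical stand-in for machine.install(); real work goes here.
      results[index] = True

    num_machines = 3
    for phase in ["deps", "runtime", "workloads"]:  # hypothetical installers
      results = [False] * num_machines
      threads = [
          threading.Thread(target=install, args=(phase, results, i))
          for i in range(num_machines)
      ]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()
      if not all(results):
        raise RuntimeError("phase %s failed on at least one machine" % phase)
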
diff --git a/benchmarks/harness/machine_producers/gcloud_producer_test.py b/benchmarks/harness/machine_producers/gcloud_producer_test.py
new file mode 100644
index 000000000..c8adb2bdc
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer_test.py
@@ -0,0 +1,48 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests GCloudProducer using mock data.
+
+GCloudProducer produces machines using 'get_machines' and 'release_machines'
+methods. The tests check recorded data (jsonified subprocess.CompletedProcess
+objects) of the producer producing one and five machines.
+"""
+import os
+import types
+
+from benchmarks.harness.machine_producers import machine_producer
+from benchmarks.harness.machine_producers import mock_producer
+
+TEST_DIR = os.path.dirname(__file__)
+
+
+def run_get_release(producer: machine_producer.MachineProducer,
+ num_machines: int,
+ validator: types.FunctionType = None):
+ machines = producer.get_machines(num_machines)
+ assert len(machines) == num_machines
+ if validator:
+ validator(machines=machines, cmd="uname -a", workload=None)
+ producer.release_machines(machines)
+
+
+def test_run_one():
+ mock = mock_producer.MockReader(TEST_DIR + "get_one.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 1)
+
+
+def test_run_five():
+ mock = mock_producer.MockReader(TEST_DIR + "get_five.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 5)
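
run_get_release also accepts an optional validator, called with machines, cmd,
and workload keyword arguments before release. A sketch of one (the checks are
hypothetical; what run() returns for a MockMachine depends on that class):

    def check_machines(machines, cmd, workload):
      """Hypothetical validator: runs cmd on each machine."""
      del workload  # Unused by this validator.
      for m in machines:
        stdout, _ = m.run(cmd)
        assert stdout is not None

    # e.g. run_get_release(producer, 1, validator=check_machines)
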
diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py
new file mode 100644
index 000000000..f5591c026
--- /dev/null
+++ b/benchmarks/harness/machine_producers/machine_producer.py
@@ -0,0 +1,51 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Abstract types."""
+
+import threading
+from typing import List
+
+from benchmarks.harness import machine
+
+
+class MachineProducer:
+ """Abstract Machine producer."""
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns the requested number of machines."""
+ raise NotImplementedError
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ """Releases the given set of machines."""
+ raise NotImplementedError
+
+
+class LocalMachineProducer(MachineProducer):
+ """Produces Local Machines."""
+
+ def __init__(self, limit: int):
+ self.limit_sem = threading.Semaphore(value=limit)
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns the request number of MockMachines."""
+
+ self.limit_sem.acquire()
+ return [machine.LocalMachine("local") for _ in range(num_machines)]
+
+  def release_machines(self, machine_list: List[machine.LocalMachine]):
+ """No-op."""
+ if not machine_list:
+ raise ValueError("Cannot release an empty list!")
+ self.limit_sem.release()
+ machine_list.clear()
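
Any new producer only needs to implement the two methods of MachineProducer.
A minimal sketch of a custom producer backed by a fixed, pre-existing pool
(the class name and pool semantics are hypothetical):

    import threading
    from typing import List

    from benchmarks.harness import machine
    from benchmarks.harness.machine_producers import machine_producer

    class PoolMachineProducer(machine_producer.MachineProducer):
      """Hypothetical producer handing out machines from a fixed pool."""

      def __init__(self, pool: List[machine.Machine]):
        self._pool = pool
        self._lock = threading.Lock()

      def get_machines(self, num_machines: int) -> List[machine.Machine]:
        with self._lock:
          if num_machines > len(self._pool):
            raise ValueError("pool has too few machines available")
          return [self._pool.pop() for _ in range(num_machines)]

      def release_machines(self, machine_list: List[machine.Machine]):
        with self._lock:
          self._pool.extend(machine_list)
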
diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py
new file mode 100644
index 000000000..37e9cb4b7
--- /dev/null
+++ b/benchmarks/harness/machine_producers/mock_producer.py
@@ -0,0 +1,52 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Producers of mocks."""
+
+from typing import List, Any
+
+from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import gcloud_producer
+from benchmarks.harness.machine_producers import machine_producer
+
+
+class MockMachineProducer(machine_producer.MachineProducer):
+ """Produces MockMachine objects."""
+
+ def get_machines(self, num_machines: int) -> List[machine.MockMachine]:
+ """Returns the request number of MockMachines."""
+ return [machine.MockMachine() for i in range(num_machines)]
+
+ def release_machines(self, machine_list: List[machine.MockMachine]):
+ """No-op."""
+ return
+
+
+class MockGCloudProducer(gcloud_producer.GCloudProducer):
+ """Mocks GCloudProducer for testing purposes."""
+
+ def __init__(self, mock: gcloud_mock_recorder.MockReader, **kwargs):
+    # GCloudProducer requires concrete constructor values; the mock never
+    # actually uses them.
+    gcloud_producer.GCloudProducer.__init__(
+        self,
+        image="mock",
+        zone="mock",
+        machine_type="mock",
+        installers=[],
+        ssh_key_file="mock",
+        ssh_user="mock",
+        ssh_password="mock",
+        internal=False,
+        **kwargs)
+ self.mock = mock
+
+ def _validate_ssh_file(self):
+ pass
+
+  def _run_command(self, cmd, detach=False):
+    del detach  # The mock executes nothing, so there is nothing to detach.
+    return self.mock.pop(cmd)
+
+ def _machines_from_instances(
+ self, instances: List[Any]) -> List[machine.MockMachine]:
+ return [machine.MockMachine() for _ in instances]
diff --git a/benchmarks/harness/machine_producers/testdata/get_five.json b/benchmarks/harness/machine_producers/testdata/get_five.json
new file mode 100644
index 000000000..32bad1b06
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_five.json
@@ -0,0 +1,211 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--project=project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/machine_producers/testdata/get_one.json b/benchmarks/harness/machine_producers/testdata/get_one.json
new file mode 100644
index 000000000..c359c19c8
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_one.json
@@ -0,0 +1,145 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--project=linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/machine_producers/yaml_producer.py b/benchmarks/harness/machine_producers/yaml_producer.py
new file mode 100644
index 000000000..5d334e480
--- /dev/null
+++ b/benchmarks/harness/machine_producers/yaml_producer.py
@@ -0,0 +1,106 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Producers based on yaml files."""
+
+import os
+import threading
+from typing import Dict
+from typing import List
+
+import yaml
+
+from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import machine_producer
+
+
+class YamlMachineProducer(machine_producer.MachineProducer):
+ """Loads machines from a yaml file."""
+
+ def __init__(self, path: str):
+ self.machines = build_machines(path)
+ self.max_machines = len(self.machines)
+ self.machine_condition = threading.Condition()
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ if num_machines > self.max_machines:
+ raise ValueError(
+ "Insufficient Ammount of Machines. {ask} asked for and have {max_num} max."
+ .format(ask=num_machines, max_num=self.max_machines))
+
+ with self.machine_condition:
+ while not self._enough_machines(num_machines):
+ self.machine_condition.wait(timeout=1)
+ return [self.machines.pop(0) for _ in range(num_machines)]
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ with self.machine_condition:
+ while machine_list:
+ next_machine = machine_list.pop()
+ self.machines.append(next_machine)
+ self.machine_condition.notify()
+
+ def _enough_machines(self, ask: int):
+ return ask <= len(self.machines)
+
+
+def build_machines(path: str, num_machines: int = -1) -> List[machine.Machine]:
+ """Builds machine objects defined by the yaml file "path".
+
+ Args:
+ path: The path to a yaml file which defines machines.
+ num_machines: Optional limit on how many machine objects to build.
+
+ Returns:
+ Machine objects in a list.
+
+ If num_machines is set, len(machines) <= num_machines.
+ """
+ data = parse_yaml(path)
+ machines = []
+ for key, value in data.items():
+ if len(machines) == num_machines:
+ return machines
+ if isinstance(value, dict):
+ machines.append(machine.RemoteMachine(key, **value))
+ else:
+ machines.append(machine.LocalMachine(key))
+ return machines
+
+
+def parse_yaml(path: str) -> Dict[str, Dict[str, str]]:
+ """Parse the yaml file pointed by path.
+
+ Args:
+ path: The path to yaml file.
+
+ Returns:
+ The contents of the yaml file as a dictionary.
+ """
+ data = get_file_contents(path)
+ return yaml.load(data, Loader=yaml.Loader)
+
+
+def get_file_contents(path: str) -> str:
+ """Dumps the file contents to a string and returns them.
+
+ Args:
+ path: The path to dump.
+
+ Returns:
+ The file contents as a string.
+ """
+ if not os.path.isabs(path):
+ path = os.path.abspath(path)
+ with open(path) as input_file:
+ return input_file.read()
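
build_machines maps each top-level YAML key to a machine: a key with a mapping
value becomes a RemoteMachine (the mapping is passed through as keyword
arguments), and a key with no value becomes a LocalMachine. A sketch of the
expected layout, parsed inline (hostnames and paths are hypothetical):

    import yaml

    CONFIG = """
    local:
    remote1:
      hostname: 10.0.0.1
      username: user
      key_path: /path/to/key
      key_password: ""
    """

    data = yaml.load(CONFIG, Loader=yaml.Loader)
    for key, value in data.items():
      kind = "remote" if isinstance(value, dict) else "local"
      print(key, "->", kind)  # local -> local, remote1 -> remote
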
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
new file mode 100644
index 000000000..b8c8e42d4
--- /dev/null
+++ b/benchmarks/harness/ssh_connection.py
@@ -0,0 +1,126 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SSHConnection handles the details of SSH connections."""
+
+import logging
+import os
+import warnings
+from typing import Tuple
+
+import paramiko
+
+from benchmarks import harness
+
+# Get rid of paramiko Cryptography Warnings.
+warnings.filterwarnings(action="ignore", module=".*paramiko.*")
+
+log = logging.getLogger(__name__)
+
+
+def send_one_file(client: paramiko.SSHClient, path: str,
+ remote_dir: str) -> str:
+ """Sends a single file via an SSH client.
+
+ Args:
+ client: The existing SSH client.
+ path: The local path.
+ remote_dir: The remote directory.
+
+ Returns:
+    The remote path as a string.
+ """
+  filename = os.path.basename(path)
+ if remote_dir != ".":
+ client.exec_command("mkdir -p " + remote_dir)
+ with client.open_sftp() as ftp_client:
+ ftp_client.put(path, os.path.join(remote_dir, filename))
+ return os.path.join(remote_dir, filename)
+
+
+class SSHConnection:
+ """SSH connection to a remote machine."""
+
+ def __init__(self, name: str, hostname: str, key_path: str, username: str,
+ **kwargs):
+ """Sets up a paramiko ssh connection to the given hostname."""
+ self._name = name # Unused.
+ self._hostname = hostname
+ self._username = username
+ self._key_path = key_path # RSA Key path
+ self._kwargs = kwargs
+ # SSHConnection wraps paramiko. paramiko supports RSA, ECDSA, and Ed25519
+    # keys, and we've chosen to only support and require RSA keys. paramiko
+    # expects RSA private keys in PEM format, i.e. files that begin with
+    # '-----BEGIN RSA PRIVATE KEY-----'.
+ # https://stackoverflow.com/questions/53600581/ssh-key-generated-by-ssh-keygen-is-not-recognized-by-paramiko
+ self.rsa_key = self._rsa()
+ self.run("true") # Validate.
+
+ def _client(self) -> paramiko.SSHClient:
+ """Returns a connected SSH client."""
+ client = paramiko.SSHClient()
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ client.connect(
+ hostname=self._hostname,
+ port=22,
+ username=self._username,
+ pkey=self.rsa_key,
+ allow_agent=False,
+ look_for_keys=False)
+ return client
+
+ def _rsa(self):
+ if "key_password" in self._kwargs:
+ password = self._kwargs["key_password"]
+ else:
+ password = None
+ rsa = paramiko.RSAKey.from_private_key_file(self._key_path, password)
+ return rsa
+
+  def run(self, cmd: str) -> Tuple[str, str]:
+ """Runs a command via ssh.
+
+ Args:
+ cmd: The shell command to run.
+
+ Returns:
+ The contents of stdout and stderr.
+ """
+ with self._client() as client:
+ log.info("running command: %s", cmd)
+ _, stdout, stderr = client.exec_command(command=cmd)
+ log.info("returned status: %d", stdout.channel.recv_exit_status())
+ stdout = stdout.read().decode("utf-8")
+ stderr = stderr.read().decode("utf-8")
+ log.info("stdout: %s", stdout)
+ log.info("stderr: %s", stderr)
+ return stdout, stderr
+
+ def send_workload(self, name: str) -> str:
+ """Sends a workload tarball to the remote machine.
+
+ Args:
+ name: The workload name.
+
+ Returns:
+ The remote path.
+ """
+ with self._client() as client:
+ return send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
+ harness.REMOTE_WORKLOADS_PATH.format(name))
+
+ def send_installers(self) -> str:
+ with self._client() as client:
+ return send_one_file(
+ client,
+ path=harness.INSTALLER_ARCHIVE,
+ remote_dir=harness.REMOTE_INSTALLERS_PATH)
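
A minimal sketch of using SSHConnection directly (the host, user, and key path
are hypothetical); note the constructor validates connectivity by running
`true`:

    from benchmarks.harness import ssh_connection

    conn = ssh_connection.SSHConnection(
        name="example",  # stored but unused by the class
        hostname="203.0.113.10",  # hypothetical host
        key_path="/path/to/rsa_key",  # hypothetical RSA private key
        username="user")
    stdout, stderr = conn.run("uname -a")
    print(stdout)

Each run() call opens a fresh paramiko client, so there is no explicit close
step afterwards.
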
diff --git a/benchmarks/harness/tunnel_dispatcher.py b/benchmarks/harness/tunnel_dispatcher.py
new file mode 100644
index 000000000..c56fd022a
--- /dev/null
+++ b/benchmarks/harness/tunnel_dispatcher.py
@@ -0,0 +1,122 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tunnel handles setting up connections to remote machines.
+
+Tunnel dispatcher is a wrapper around the connection between a local UNIX
+socket and a remote UNIX socket via SSH with port forwarding. This is done to
+initialize the pythonic dockerpy client to run containers on the remote host by
+connecting to /var/run/docker.sock (where Docker is listening). Tunnel
+dispatcher sets up the local UNIX socket and calls the `ssh` command as a
+subprocess, and holds a reference to that subprocess. It manages clean-up on
+exit as best it can by killing the ssh subprocess and deleting the local UNIX
+socket, which is stored in /tmp for easy cleanup on most systems if this fails.
+
+ Typical usage example:
+
+ t = Tunnel(name, **kwargs)
+ t.connect()
+    client = t.get_docker_client()
+ client.containers.run("ubuntu", "echo hello world")
+
+"""
+
+import os
+import tempfile
+import time
+
+import docker
+import pexpect
+
+SSH_TUNNEL_COMMAND = """ssh
+ -o GlobalKnownHostsFile=/dev/null
+ -o UserKnownHostsFile=/dev/null
+ -o StrictHostKeyChecking=no
+ -o IdentitiesOnly=yes
+ -nNT -L {filename}:/var/run/docker.sock
+ -i {key_path}
+ {username}@{hostname}"""
+
+
+class Tunnel(object):
+ """The tunnel object represents the tunnel via ssh.
+
+ This connects a local unix domain socket with a remote socket.
+
+ Attributes:
+ _filename: a temporary name of the UNIX socket prefixed by the name
+ argument.
+ _hostname: the IP or resolvable hostname of the remote host.
+ _username: the username of the ssh_key used to run ssh.
+ _key_path: path to a valid key.
+ _key_password: optional password to the ssh key in _key_path
+ _process: holds reference to the ssh subprocess created.
+
+  Raises:
+    ConnectionError: If the SSH tunnel cannot be established.
+ """
+
+ def __init__(self,
+ name: str,
+ hostname: str,
+ username: str,
+ key_path: str,
+ key_password: str = "",
+ **kwargs):
+ self._filename = tempfile.NamedTemporaryFile(prefix=name).name
+ self._hostname = hostname
+ self._username = username
+ self._key_path = key_path
+ self._key_password = key_password
+ self._kwargs = kwargs
+ self._process = None
+
+ def connect(self):
+ """Connects the SSH tunnel and stores the subprocess reference in _process."""
+ cmd = SSH_TUNNEL_COMMAND.format(
+ filename=self._filename,
+ key_path=self._key_path,
+ username=self._username,
+ hostname=self._hostname)
+ self._process = pexpect.spawn(cmd, timeout=10)
+
+ # If given a password, assume we'll be asked for it.
+ if self._key_password:
+ self._process.expect(["Enter passphrase for key .*: "])
+ self._process.sendline(self._key_password)
+
+ while True:
+ # Wait for the tunnel to appear.
+ if self._process.exitstatus is not None:
+ raise ConnectionError("Error in setting up ssh tunnel")
+ if os.path.exists(self._filename):
+ return
+ time.sleep(0.1)
+
+ def path(self):
+ """Return the socket file."""
+ return self._filename
+
+ def get_docker_client(self):
+ """Returns a docker client for this Tunnel."""
+ return docker.DockerClient(base_url="unix:/" + self._filename)
+
+ def __del__(self):
+ """Closes the ssh connection process and deletes the socket file."""
+ if self._process:
+ self._process.close()
+ if os.path.exists(self._filename):
+ os.remove(self._filename)
diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt
new file mode 100644
index 000000000..577eb1a2e
--- /dev/null
+++ b/benchmarks/requirements.txt
@@ -0,0 +1,32 @@
+asn1crypto==1.2.0
+atomicwrites==1.3.0
+attrs==19.3.0
+bcrypt==3.1.7
+certifi==2019.9.11
+cffi==1.13.2
+chardet==3.0.4
+Click==7.0
+cryptography==2.8
+docker==3.7.0
+docker-pycreds==0.4.0
+idna==2.8
+importlib-metadata==0.23
+more-itertools==7.2.0
+packaging==19.2
+paramiko==2.6.0
+pathlib2==2.3.5
+pexpect==4.7.0
+pluggy==0.9.0
+ptyprocess==0.6.0
+py==1.8.0
+pycparser==2.19
+PyNaCl==1.3.0
+pyparsing==2.4.5
+pytest==4.3.0
+PyYAML==5.1.2
+requests==2.22.0
+six==1.13.0
+urllib3==1.25.7
+wcwidth==0.1.7
+websocket-client==0.56.0
+zipp==0.6.0
diff --git a/benchmarks/run.py b/benchmarks/run.py
new file mode 100644
index 000000000..a22eb8641
--- /dev/null
+++ b/benchmarks/run.py
@@ -0,0 +1,19 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Benchmark runner."""
+
+from benchmarks import runner
+
+if __name__ == "__main__":
+ runner.runner()
diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD
new file mode 100644
index 000000000..471debfdf
--- /dev/null
+++ b/benchmarks/runner/BUILD
@@ -0,0 +1,56 @@
+load("//tools:defs.bzl", "py_library", "py_requirement", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(licenses = ["notice"])
+
+py_library(
+ name = "runner",
+ srcs = ["__init__.py"],
+ data = [
+ "//benchmarks/workloads:files",
+ ],
+ visibility = ["//benchmarks:__pkg__"],
+ deps = [
+ ":commands",
+ "//benchmarks/harness:benchmark_driver",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ "//benchmarks/harness/machine_producers:mock_producer",
+ "//benchmarks/harness/machine_producers:yaml_producer",
+ "//benchmarks/suites",
+ "//benchmarks/suites:absl",
+ "//benchmarks/suites:density",
+ "//benchmarks/suites:fio",
+ "//benchmarks/suites:helpers",
+ "//benchmarks/suites:http",
+ "//benchmarks/suites:media",
+ "//benchmarks/suites:ml",
+ "//benchmarks/suites:network",
+ "//benchmarks/suites:redis",
+ "//benchmarks/suites:startup",
+ "//benchmarks/suites:sysbench",
+ "//benchmarks/suites:syscall",
+ py_requirement("click"),
+ ],
+)
+
+py_library(
+ name = "commands",
+ srcs = ["commands.py"],
+ deps = [
+ py_requirement("click"),
+ ],
+)
+
+py_test(
+ name = "runner_test",
+ srcs = ["runner_test.py"],
+ python_version = "PY3",
+ tags = [
+ "local",
+ "manual",
+ ],
+ deps = test_deps + [
+ ":runner",
+ py_requirement("click"),
+ ],
+)
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
new file mode 100644
index 000000000..fc59cf505
--- /dev/null
+++ b/benchmarks/runner/__init__.py
@@ -0,0 +1,308 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""High-level benchmark utility."""
+
+import copy
+import csv
+import logging
+import pkgutil
+import pydoc
+import re
+import subprocess
+import sys
+import types
+from typing import List
+from typing import Tuple
+
+import click
+
+from benchmarks import harness
+from benchmarks import suites
+from benchmarks.harness import benchmark_driver
+from benchmarks.harness.machine_producers import gcloud_producer
+from benchmarks.harness.machine_producers import machine_producer
+from benchmarks.harness.machine_producers import mock_producer
+from benchmarks.harness.machine_producers import yaml_producer
+from benchmarks.runner import commands
+
+
+@click.group()
+@click.option(
+ "--verbose/--no-verbose", default=False, help="Enable verbose logging.")
+@click.option("--debug/--no-debug", default=False, help="Enable debug logging.")
+def runner(verbose: bool = False, debug: bool = False):
+ """Run distributed benchmarks.
+
+ See the run and list commands for details.
+
+ Args:
+ verbose: Enable verbose logging.
+    debug: Enable debug logging (supersedes verbose).
+ """
+ if debug:
+ logging.basicConfig(level=logging.DEBUG)
+ elif verbose:
+ logging.basicConfig(level=logging.INFO)
+
+
+def find_benchmarks(
+ regex: str) -> List[Tuple[str, types.ModuleType, types.FunctionType]]:
+ """Finds all available benchmarks.
+
+ Args:
+ regex: A regular expression to match.
+
+ Returns:
+ A (short_name, module, function) tuple for each match.
+ """
+ pkgs = pkgutil.walk_packages(suites.__path__, suites.__name__ + ".")
+ found = []
+ for _, name, _ in pkgs:
+ mod = pydoc.locate(name)
+ funcs = [
+ getattr(mod, x)
+ for x in dir(mod)
+ if suites.is_benchmark(getattr(mod, x))
+ ]
+ for func in funcs:
+ # Use the short_name with the benchmarks. prefix stripped.
+ prefix_len = len(suites.__name__ + ".")
+ short_name = mod.__name__[prefix_len:] + "." + func.__name__
+ # Add to the list if a pattern is provided.
+ if re.compile(regex).match(short_name):
+ found.append((short_name, mod, func))
+ return found
+
+
+@runner.command("list")
+@click.argument("method", nargs=-1)
+def list_all(method):
+ """Lists available benchmarks."""
+ if not method:
+ method = ".*"
+ else:
+ method = "(" + ",".join(method) + ")"
+ for (short_name, _, func) in find_benchmarks(method):
+ print("Benchmark %s:" % short_name)
+ metrics = suites.benchmark_metrics(func)
+ if func.__doc__:
+ print(" " + func.__doc__.lstrip().rstrip())
+ if metrics:
+ print("\n Metrics:")
+ for metric in metrics:
+ print("\t{name}: {doc}".format(name=metric[0], doc=metric[1]))
+ print("\n")
+
+
+@runner.command("run-local", commands.LocalCommand)
+@click.pass_context
+def run_local(ctx, limit: float, **kwargs):
+ """Runs benchmarks locally."""
+ run(ctx, machine_producer.LocalMachineProducer(limit=limit), **kwargs)
+
+
+@runner.command("run-mock", commands.RunCommand)
+@click.pass_context
+def run_mock(ctx, **kwargs):
+ """Runs benchmarks on Mock machines. Used for testing."""
+ run(ctx, mock_producer.MockMachineProducer(), **kwargs)
+
+
+@runner.command("run-gcp", commands.GCPCommand)
+@click.pass_context
+def run_gcp(ctx, image_file: str, zone_file: str, internal: bool,
+ machine_type: str, installers: List[str], **kwargs):
+ """Runs all benchmarks on GCP instances."""
+
+ # Resolve all files.
+  image = subprocess.check_output([image_file]).rstrip().decode("utf-8")
+  zone = subprocess.check_output([zone_file]).rstrip().decode("utf-8")
+ key_file = harness.make_key()
+
+ producer = gcloud_producer.GCloudProducer(
+ image,
+ zone,
+ machine_type,
+ installers,
+ ssh_key_file=key_file,
+ ssh_user=harness.DEFAULT_USER,
+ ssh_password="",
+ internal=internal)
+
+ try:
+ run(ctx, producer, **kwargs)
+ finally:
+ harness.delete_key()
+
+
+def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int,
+ runtime: List[str], metric: List[str], stat: str, **kwargs):
+ """Runs arbitrary benchmarks.
+
+ All unknown command line flags are passed through to the underlying benchmark
+ method. Flags may be specified multiple times, in which case it is considered
+ a "dimension" for the test, and a comma-separated table will be emitted
+ instead of a single result.
+
+ See the output of list to see available metrics for any given benchmark
+ method. The method parameter is a regular expression that will match against
+ available benchmarks. If multiple benchmarks match, then that is considered a
+ distinct "dimension" for the test.
+
+ All benchmarks are run in parallel where possible, but have exclusive
+ ownership over the individual machines.
+
+ Every benchmark method will be run the times indicated by --runs.
+
+ Args:
+ ctx: Click context.
+ producer: A Machine Producer from which to get Machines.
+ method: A regular expression for methods to be run.
+ runs: Number of runs.
+ runtime: A list of runtimes to test.
+ metric: A list of metrics to extract.
+ stat: The class of statistics to extract.
+ **kwargs: Dimensions to test.
+ """
+ # First, calculate additional arguments.
+ #
+ # This essentially calculates any arguments that appear multiple times, and
+ # moves those to the "dimensions" dictionary, which maps to lists. These
+ # dimensions are then iterated over to generate the relevant csv output.
+ dimensions = {}
+
+ if stat not in ["median", "all", "meanstd"]:
+ raise ValueError("Illegal value for --result, see help.")
+
+ def squish(key: str, value: str):
+ """Collapse an argument into kwargs or dimensions."""
+ if key in dimensions:
+ # Extend an existing dimension.
+ dimensions[key].append(value)
+ elif key in kwargs:
+ # Create a new dimension.
+ dimensions[key] = [kwargs[key], value]
+ del kwargs[key]
+ else:
+ # A single value.
+ kwargs[key] = value
+
+ for item in ctx.args:
+ if "=" in method:
+ # This must be the method. The method is simply set to the first
+ # non-matching argument, which we're also parsing here.
+ item, method = method, item
+ if "=" not in item:
+ logging.error("illegal argument: %s", item)
+ sys.exit(1)
+ (key, value) = item.lstrip("-").split("=", 1)
+ squish(key, value)
+
+ # Convert runtime and metric to dimensions.
+ #
+ # They exist only in the arguments above for documentation purposes.
+ # Essentially here we are treating them like anything else. Note however,
+ # that an empty set here will result in a dimension. This is important for
+ # metrics, where an empty set actually means all metrics.
+ def fold(key: str, value, allow_flatten=False):
+ """Collapse a list value into kwargs or dimensions."""
+ if len(value) == 1 and allow_flatten:
+ kwargs[key] = value[0]
+ else:
+ dimensions[key] = value
+
+ fold("runtime", runtime, allow_flatten=True)
+ fold("metric", metric)
+
+ # Lookup the methods.
+ #
+ # We match the method parameter to a regular expression. This allows you to
+ # do things like `run --mock .*` for a broad test. Note that we track the
+ # short_names in the dimensions here, and look up again in the recursion.
+ methods = {
+ short_name: func for (short_name, _, func) in find_benchmarks(method)
+ }
+ if not methods:
+ # Must match at least one method.
+ logging.error("no matching benchmarks for %s: try list.", method)
+ sys.exit(1)
+ fold("method", list(methods.keys()), allow_flatten=True)
+
+ # Spin up the drivers.
+ #
+ # We ensure that metric is the last entry, because we have special behavior.
+ # They actually run the test once and the benchmark is a generator that
+ # produces all viable metrics.
+ dimension_keys = list(dimensions.keys())
+ if "metric" in dimension_keys:
+ dimension_keys.remove("metric")
+ dimension_keys.append("metric")
+ drivers = []
+
+ def _start(keywords, finished, left):
+ """Runs a test across dimensions recursively."""
+ # Resolve the method fully, it starts as a string.
+ if "method" in keywords and isinstance(keywords["method"], str):
+ keywords["method"] = methods[keywords["method"]]
+ # Is this a non-recursive case?
+ if not left:
+ driver = benchmark_driver.BenchmarkDriver(producer, runs=runs, **keywords)
+ driver.start()
+ drivers.append((finished, driver))
+ else:
+ # Recurse on the next dimension.
+ current, left = left[0], left[1:]
+ keywords = copy.deepcopy(keywords)
+ if current == "metric":
+ # We use a generator, popped below. Note that metric is
+ # guaranteed to be the last element here, and we will provide
+ # the value for 'done' below when generating the csv.
+ keywords[current] = dimensions[current]
+ _start(keywords, finished, left)
+ else:
+ # Generate manually.
+ for value in dimensions[current]:
+ keywords[current] = value
+ _start(keywords, finished + [value], left)
+
+ # Start all the drivers, recursively.
+ _start(kwargs, [], dimension_keys)
+
+ # Finish all tests, write results.
+ output = csv.writer(sys.stdout)
+ output.writerow(dimension_keys + ["result"])
+ for (done, driver) in drivers:
+ driver.join()
+ for (metric_name, result) in getattr(driver, stat)():
+ output.writerow([ # Collapse the method name.
+ hasattr(x, "__name__") and x.__name__ or x for x in done
+ ] + [metric_name] + result)
+
+
+@runner.command()
+@click.argument("env")
+@click.option(
+ "--cmd", default="uname -a", help="command to run on all found machines")
+@click.option(
+ "--workload", default="true", help="workload to run all found machines")
+def validate(env, cmd, workload):
+ """Validates an environment described by yaml file."""
+ producer = yaml_producer.YamlMachineProducer(env)
+ for machine in producer.machines:
+ print("Machine %s:" % machine)
+ stdout, _ = machine.run(cmd)
+ print(" Output of '%s': %s" % (cmd, stdout.lstrip().rstrip()))
+ image = machine.pull(workload)
+ stdout = machine.container(image).run()
+ print(" Container %s: %s" % (workload, stdout.lstrip().rstrip()))
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
new file mode 100644
index 000000000..9a391eb01
--- /dev/null
+++ b/benchmarks/runner/commands.py
@@ -0,0 +1,135 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module with the guts of `click` commands.
+
+Overrides of click.core.Command. This is done so that flags are inherited
+between similar commands (the run-* commands). The classes below are meant to
+be used in click decorators like so:
+
+@runner.command("run-mock", RunCommand)
+def run_mock(**kwargs):
+ # mock implementation
+
+"""
+import os
+
+import click
+
+
+class RunCommand(click.core.Command):
+ """Base Run Command with flags.
+
+ Attributes:
+ method: regex of which suite to choose (e.g. sysbench would run
+      sysbench.cpu, sysbench.memory, and sysbench.mutex). See the list command for
+ details.
+ metric: metric(s) to extract. See list command for details.
+ runtime: the runtime(s) on which to run.
+ runs: the number of runs to do of each method.
+    stat: how to aggregate results in the case of multiple runs (e.g. median).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ method = click.core.Argument(("method",))
+
+ metric = click.core.Option(("--metric",),
+ help="The metric to extract.",
+ multiple=True)
+
+ runtime = click.core.Option(("--runtime",),
+ default=["runc"],
+ help="The runtime to use.",
+ multiple=True)
+ runs = click.core.Option(("--runs",),
+ default=1,
+ help="The number of times to run each benchmark.")
+ stat = click.core.Option(
+ ("--stat",),
+ default="median",
+ help="How to aggregate the data from all runs."
+ "\nmedian - returns the median of all runs (default)"
+ "\nall - returns all results comma separated"
+ "\nmeanstd - returns result as mean,std")
+ self.params.extend([method, runtime, runs, stat, metric])
+ self.ignore_unknown_options = True
+ self.allow_extra_args = True
+
+
+class LocalCommand(RunCommand):
+ """LocalCommand inherits all flags from RunCommand.
+
+ Attributes:
+    limit: limits how many benchmarks may run concurrently on the local
+      machine. e.g. "startup" requires one machine, so a limit of two allows
+      two startup jobs at a time. Defaults to 1.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.params.append(
+ click.core.Option(
+ ("--limit",),
+ default=1,
+ help="Limit of number of benchmarks that can run at a given time."))
+
+
+class GCPCommand(RunCommand):
+ """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method.
+
+ Attributes:
+    image_file: path to a binary that prints the GCP image name to use.
+    zone_file: path to a binary that prints the GCP zone (e.g. us-west1-b).
+    internal: use the instances' internal IPs (--internal/--no-internal).
+    installers: named installers to run post-create.
+    machine_type: type of machine to create (e.g. n1-standard-4).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ image_file = click.core.Option(
+ ("--image_file",),
+ help="The binary that emits the GCP image.",
+ default=os.path.join(
+ os.path.dirname(__file__), "../../tools/vm/ubuntu1604"),
+ )
+ zone_file = click.core.Option(
+ ("--zone_file",),
+ help="The binary that emits the GCP zone.",
+ default=os.path.join(os.path.dirname(__file__), "../../tools/vm/zone"),
+ )
+ internal = click.core.Option(
+ ("--internal/--no-internal",),
+ help="""Use instance internal IPs. Used if bm-tools runner is running on
+ GCP instance with firewall rules blocking external IPs.""",
+ default=False,
+ )
+ installers = click.core.Option(
+ ("--installers",),
+ help="The set of installers to use.",
+ multiple=True,
+ )
+ machine_type = click.core.Option(
+ ("--machine_type",),
+ help="Type to make all machines.",
+ default="n1-standard-4",
+ )
+ self.params.extend([
+ image_file,
+ zone_file,
+ internal,
+ machine_type,
+ installers,
+ ])
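
Wiring a new subcommand only requires passing one of these classes to
@runner.command; the inherited flags then arrive in **kwargs. A sketch (the
command name and group below are hypothetical stand-ins for
benchmarks.runner):

    import click

    from benchmarks.runner import commands

    @click.group()
    def runner():
      """Stand-in for the real benchmarks.runner.runner group."""

    @runner.command("run-print", commands.RunCommand)
    @click.pass_context
    def run_print(ctx, **kwargs):
      """Hypothetical command: prints the parsed flags instead of running."""
      click.echo(str(kwargs))
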
diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py
new file mode 100644
index 000000000..7818d631a
--- /dev/null
+++ b/benchmarks/runner/runner_test.py
@@ -0,0 +1,59 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Top-level tests."""
+
+import os
+import subprocess
+import sys
+
+from click import testing
+import pytest
+
+from benchmarks import runner
+
+
+def _get_locale():
+ output = subprocess.check_output(["locale", "-a"])
+ locales = output.split()
+ if b"en_US.utf8" in locales:
+ return "en_US.UTF-8"
+ else:
+ return "C.UTF-8"
+
+
+def _set_locale():
+ locale = _get_locale()
+ if os.getenv("LANG") != locale:
+ os.environ["LANG"] = locale
+ os.environ["LC_ALL"] = locale
+ os.execv("/proc/self/exe", ["python"] + sys.argv)
+
+
+def test_list():
+ cli_runner = testing.CliRunner()
+ result = cli_runner.invoke(runner.runner, ["list"])
+ print(result.output)
+ assert result.exit_code == 0
+
+
+def test_run():
+ cli_runner = testing.CliRunner()
+ result = cli_runner.invoke(runner.runner, ["run-mock", "."])
+ print(result.output)
+ assert result.exit_code == 0
+
+
+if __name__ == "__main__":
+ _set_locale()
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/suites/BUILD b/benchmarks/suites/BUILD
new file mode 100644
index 000000000..04fc23261
--- /dev/null
+++ b/benchmarks/suites/BUILD
@@ -0,0 +1,130 @@
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "suites",
+ srcs = ["__init__.py"],
+)
+
+py_library(
+ name = "absl",
+ srcs = ["absl.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/workloads/absl",
+ ],
+)
+
+py_library(
+ name = "density",
+ srcs = ["density.py"],
+ deps = [
+ "//benchmarks/harness:container",
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/suites:helpers",
+ ],
+)
+
+py_library(
+ name = "fio",
+ srcs = ["fio.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/suites:helpers",
+ "//benchmarks/workloads/fio",
+ ],
+)
+
+py_library(
+ name = "helpers",
+ srcs = ["helpers.py"],
+ deps = ["//benchmarks/harness:machine"],
+)
+
+py_library(
+ name = "http",
+ srcs = ["http.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/workloads/ab",
+ ],
+)
+
+py_library(
+ name = "media",
+ srcs = ["media.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/suites:helpers",
+ "//benchmarks/workloads/ffmpeg",
+ ],
+)
+
+py_library(
+ name = "ml",
+ srcs = ["ml.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/suites:startup",
+ "//benchmarks/workloads/tensorflow",
+ ],
+)
+
+py_library(
+ name = "network",
+ srcs = ["network.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/suites:helpers",
+ "//benchmarks/workloads/iperf",
+ ],
+)
+
+py_library(
+ name = "redis",
+ srcs = ["redis.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/workloads/redisbenchmark",
+ ],
+)
+
+py_library(
+ name = "startup",
+ srcs = ["startup.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/suites:helpers",
+ ],
+)
+
+py_library(
+ name = "sysbench",
+ srcs = ["sysbench.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/workloads/sysbench",
+ ],
+)
+
+py_library(
+ name = "syscall",
+ srcs = ["syscall.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/suites",
+ "//benchmarks/workloads/syscall",
+ ],
+)
diff --git a/benchmarks/suites/__init__.py b/benchmarks/suites/__init__.py
new file mode 100644
index 000000000..360736cc3
--- /dev/null
+++ b/benchmarks/suites/__init__.py
@@ -0,0 +1,119 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Core benchmark annotations."""
+
+import functools
+import inspect
+import types
+from typing import List
+from typing import Tuple
+
+BENCHMARK_METRICS = '__benchmark_metrics__'
+BENCHMARK_MACHINES = '__benchmark_machines__'
+
+
+def is_benchmark(func: types.FunctionType) -> bool:
+ """Returns true if the given function is a benchmark."""
+ return isinstance(func, types.FunctionType) and \
+ hasattr(func, BENCHMARK_METRICS) and \
+ hasattr(func, BENCHMARK_MACHINES)
+
+
+def benchmark_metrics(func: types.FunctionType) -> List[Tuple[str, str]]:
+ """Returns the list of available metrics."""
+ return [(metric.__name__, metric.__doc__)
+ for metric in getattr(func, BENCHMARK_METRICS)]
+
+
+def benchmark_machines(func: types.FunctionType) -> int:
+ """Returns the number of machines required."""
+ return getattr(func, BENCHMARK_MACHINES)
+
+
+# pylint: disable=unused-argument
+def default(value, **kwargs):
+ """Returns the passed value."""
+ return value
+
+
+def benchmark(metrics: List[types.FunctionType] = None,
+ machines: int = 1) -> types.FunctionType:
+ """Define a benchmark function with metrics.
+
+ Args:
+ metrics: A list of metric functions.
+ machines: The number of machines required.
+
+ Returns:
+ A generator function that accepts the given number of machines and
+ yields (metric_name, metric_value) pairs when iterated.
+ """
+ if not metrics:
+ # The default passes through.
+ metrics = [default]
+
+ def decorator(func: types.FunctionType) -> types.FunctionType:
+ """Decorator function."""
+ # Every benchmark should accept at least two parameters:
+ # runtime: The runtime to use for the benchmark (str, required).
+ # metric: The metrics to use, if not the default (list, optional).
+ @functools.wraps(func)
+ def wrapper(*args, runtime: str, metric: list = None, **kwargs):
+ """Wrapper function."""
+ # First -- ensure that we marshall all types appropriately. In
+ # general, we will call this with only strings. These strings will
+ # need to be converted to their underlying types/classes.
+ sig = inspect.signature(func)
+ for param in sig.parameters.values():
+ if param.annotation != inspect.Parameter.empty and \
+ param.name in kwargs and not isinstance(kwargs[param.name], param.annotation):
+ try:
+ # Marshall to the appropriate type.
+ kwargs[param.name] = param.annotation(kwargs[param.name])
+ except Exception as exc:
+ raise ValueError(
+ 'illegal type for %s(%s=%s): %s' %
+ (func.__name__, param.name, kwargs[param.name], exc))
+ elif param.default != inspect.Parameter.empty and \
+ param.name not in kwargs:
+ # Ensure that we have the value set, because it will
+ # be passed to the metric function for evaluation.
+ kwargs[param.name] = param.default
+
+ # Next, figure out how to apply a metric. We do this prior to
+ # running the underlying function to prevent having to wait a few
+ # minutes for a result just to see some error.
+ if not metric:
+ # Return all metrics in the iterator.
+ result = func(*args, runtime=runtime, **kwargs)
+ for metric_func in metrics:
+ yield (metric_func.__name__, metric_func(result, **kwargs))
+ else:
+ result = None
+ for single_metric in metric:
+ for metric_func in metrics:
+ # If this metric function matches the requested
+ # name, apply it to the result.
+ if metric_func.__name__ == single_metric:
+ if not result:
+ # Lazy evaluation: only if metric matches.
+ result = func(*args, runtime=runtime, **kwargs)
+ yield single_metric, metric_func(result, **kwargs)
+
+ # Set metadata on the benchmark (used above).
+ setattr(wrapper, BENCHMARK_METRICS, metrics)
+ setattr(wrapper, BENCHMARK_MACHINES, machines)
+ return wrapper
+
+ return decorator
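+
+
+# Illustrative usage sketch (assumes a machine object "m"; not part of the
+# suite itself):
+#
+#   @benchmark(metrics=[default], machines=1)
+#   def noop(target, **kwargs) -> str:
+#       return "42"
+#
+#   for name, value in noop(m, runtime="runc"):
+#       print(name, value)  # -> default 42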
diff --git a/benchmarks/suites/absl.py b/benchmarks/suites/absl.py
new file mode 100644
index 000000000..5d9b57a09
--- /dev/null
+++ b/benchmarks/suites/absl.py
@@ -0,0 +1,37 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""absl build benchmark."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.workloads import absl
+
+
+@suites.benchmark(metrics=[absl.elapsed_time], machines=1)
+def build(target: machine.Machine, **kwargs) -> str:
+ """Runs the absl workload and report the absl build time.
+
+ Runs the 'bazel build //absl/...' in a clean bazel directory and
+ monitors time elapsed.
+
+ Args:
+ target: A machine object.
+ **kwargs: Additional container options.
+
+ Returns:
+ Container output.
+ """
+ image = target.pull("absl")
+ return target.container(image, **kwargs).run()
diff --git a/benchmarks/suites/density.py b/benchmarks/suites/density.py
new file mode 100644
index 000000000..89d29fb26
--- /dev/null
+++ b/benchmarks/suites/density.py
@@ -0,0 +1,121 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Density tests."""
+
+import re
+import types
+
+from benchmarks import suites
+from benchmarks.harness import container
+from benchmarks.harness import machine
+from benchmarks.suites import helpers
+
+
+# pylint: disable=unused-argument
+def memory_usage(value, **kwargs):
+ """Returns the passed value."""
+ return value
+
+
+def density(target: machine.Machine,
+ workload: str,
+ count: int = 50,
+ wait: float = 0,
+ load_func: types.FunctionType = None,
+ **kwargs):
+ """Calculate the average memory usage per container.
+
+ Args:
+ target: A machine object.
+ workload: The workload to run.
+ count: The number of containers to start.
+ wait: The time to wait after starting.
+ load_func: Callback that is called after count images have been started on
+ the given machine.
+ **kwargs: Additional container options.
+
+ Returns:
+ The average memory usage in bytes per container.
+ """
+ count = int(count)
+
+ # Drop all caches.
+ helpers.drop_caches(target)
+ before = target.read("/proc/meminfo")
+
+ # Load the workload.
+ image = target.pull(workload)
+
+ with target.container(
+ image=image, count=count, **kwargs).detach() as containers:
+ # Call the optional load function callback if given.
+ if load_func:
+ load_func(target, containers)
+ # Wait 'wait' time before taking a measurement.
+ target.sleep(wait)
+
+ # Drop caches again.
+ helpers.drop_caches(target)
+ after = target.read("/proc/meminfo")
+
+ # Calculate the memory used.
+ available_re = re.compile(r"MemAvailable:\s*(\d+)\skB\n")
+ before_available = available_re.findall(before)
+ after_available = available_re.findall(after)
+ return 1024 * float(int(before_available[0]) -
+ int(after_available[0])) / float(count)
+
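+# Illustrative call (assumes a machine object "m"; not part of the suite):
+#
+#   density(m, workload="node", count=50, wait=3.0)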
+
+def load_redis(target: machine.Machine, containers: container.Container):
+ """Use redis-benchmark "LPUSH" to load each container with 1G of data.
+
+ Args:
+ target: A machine object.
+ containers: A set of containers.
+ """
+ target.pull("redisbenchmark")
+ for name in containers.get_names():
+ flags = "-d 10000 -t LPUSH"
+ target.container(
+ "redisbenchmark", links={
+ name: name
+ }).run(
+ host=name, flags=flags)
+
+
+@suites.benchmark(metrics=[memory_usage], machines=1)
+def empty(target: machine.Machine, **kwargs) -> float:
+ """Run trivial containers in a density test."""
+ return density(target, workload="sleep", wait=1.0, **kwargs)
+
+
+@suites.benchmark(metrics=[memory_usage], machines=1)
+def node(target: machine.Machine, **kwargs) -> float:
+ """Run node containers in a density test."""
+ return density(target, workload="node", wait=3.0, **kwargs)
+
+
+@suites.benchmark(metrics=[memory_usage], machines=1)
+def ruby(target: machine.Machine, **kwargs) -> float:
+ """Run ruby containers in a density test."""
+ return density(target, workload="ruby", wait=3.0, **kwargs)
+
+
+@suites.benchmark(metrics=[memory_usage], machines=1)
+def redis(target: machine.Machine, **kwargs) -> float:
+ """Run redis containers in a density test."""
+ if "count" not in kwargs:
+ kwargs["count"] = 5
+ return density(
+ target, workload="redis", wait=3.0, load_func=load_redis, **kwargs)
diff --git a/benchmarks/suites/fio.py b/benchmarks/suites/fio.py
new file mode 100644
index 000000000..2171790c5
--- /dev/null
+++ b/benchmarks/suites/fio.py
@@ -0,0 +1,165 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""File I/O tests."""
+
+import os
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.suites import helpers
+from benchmarks.workloads import fio
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-locals
+def run_fio(target: machine.Machine,
+ test: str,
+ ioengine: str = "sync",
+ size: int = 1024 * 1024 * 1024,
+ iodepth: int = 4,
+ blocksize: int = 1024 * 1024,
+ time: int = -1,
+ mount_dir: str = "",
+ filename: str = "file.dat",
+ tmpfs: bool = False,
+ ramp_time: int = 0,
+ **kwargs) -> str:
+ """FIO benchmarks.
+
+ For more on fio see:
+ https://media.readthedocs.org/pdf/fio/latest/fio.pdf
+
+ Args:
+ target: A machine object.
+ test: The test to run (read, write, randread, randwrite, etc.)
+ ioengine: The engine for I/O.
+ size: The size of the generated file in bytes (if an integer) or 5g, 16k,
+ etc.
+ iodepth: The I/O for certain engines.
+ blocksize: The blocksize for reads and writes in bytes (if an integer) or
+ 4k, etc.
+ time: If test is time based, how long to run in seconds.
+ mount_dir: The absolute path on the host to mount a bind mount.
+ filename: The name of the file to create inside the container. For a path
+ of /dir/dir/file, the script sets up a volume like 'docker run -v
+ mount_dir:/dir/dir fio' and fio will create (and delete) the file
+ /dir/dir/file. If tmpfs is set, /dir/dir will be a tmpfs.
+ tmpfs: If true, mount on tmpfs.
+ ramp_time: The time to run before recording statistics.
+ **kwargs: Additional container options.
+
+ Returns:
+ The output of fio as a string.
+ """
+ # Pull the image before dropping caches.
+ image = target.pull("fio")
+
+ if not mount_dir:
+ stdout, _ = target.run("pwd")
+ mount_dir = stdout.rstrip()
+
+ # Setup the volumes.
+ volumes = {mount_dir: {"bind": "/disk", "mode": "rw"}} if not tmpfs else None
+ tmpfs = {"/disk": ""} if tmpfs else None
+
+ # Construct a file in the volume.
+ filepath = os.path.join("/disk", filename)
+
+ # If we are running a read test, use fio to write a file and then flush the
+ # file data from memory.
+ if "read" in test:
+ target.container(
+ image, volumes=volumes, tmpfs=tmpfs, **kwargs).run(
+ test="write",
+ ioengine="sync",
+ size=size,
+ iodepth=iodepth,
+ blocksize=blocksize,
+ path=filepath)
+ helpers.drop_caches(target)
+
+ # Run the test.
+ time_str = "--time_base --runtime={time}".format(
+ time=time) if int(time) > 0 else ""
+ res = target.container(
+ image, volumes=volumes, tmpfs=tmpfs, **kwargs).run(
+ test=test,
+ ioengine=ioengine,
+ size=size,
+ iodepth=iodepth,
+ blocksize=blocksize,
+ time=time_str,
+ path=filepath,
+ ramp_time=ramp_time)
+
+ target.run(
+ "rm {path}".format(path=os.path.join(mount_dir.rstrip(), filename)))
+
+ return res
+
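+# Illustrative call (assumes a machine object "m"; per the docstring, "size"
+# and "blocksize" may also be strings like "5g" or "4k"):
+#
+#   run_fio(m, test="randwrite", size=1 << 30, blocksize=4096, time=30)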
+
+@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1)
+def read(*args, **kwargs):
+ """Read test.
+
+ Args:
+ *args: None.
+ **kwargs: Additional container options.
+
+ Returns:
+ The output of fio.
+ """
+ return run_fio(*args, test="read", **kwargs)
+
+
+@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1)
+def randread(*args, **kwargs):
+ """Random read test.
+
+ Args:
+ *args: None.
+ **kwargs: Additional container options.
+
+ Returns:
+ The output of fio.
+ """
+ return run_fio(*args, test="randread", **kwargs)
+
+
+@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1)
+def write(*args, **kwargs):
+ """Write test.
+
+ Args:
+ *args: None.
+ **kwargs: Additional container options.
+
+ Returns:
+ The output of fio.
+ """
+ return run_fio(*args, test="write", **kwargs)
+
+
+@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1)
+def randwrite(*args, **kwargs):
+ """Random write test.
+
+ Args:
+ *args: None.
+ **kwargs: Additional container options.
+
+ Returns:
+ The output of fio.
+ """
+ return run_fio(*args, test="randwrite", **kwargs)
diff --git a/benchmarks/suites/helpers.py b/benchmarks/suites/helpers.py
new file mode 100644
index 000000000..b3c7360ab
--- /dev/null
+++ b/benchmarks/suites/helpers.py
@@ -0,0 +1,57 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Benchmark helpers."""
+
+import datetime
+from benchmarks.harness import machine
+
+
+class Timer:
+ """Helper to time runtime of some call.
+
+ Usage:
+
+ with Timer as t:
+ # do something.
+ t.get_time_in_seconds()
+ """
+
+ def __init__(self):
+ self._start = datetime.datetime.now()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def start(self):
+ """Starts the timer."""
+ self._start = datetime.datetime.now()
+
+ def elapsed(self) -> float:
+ """Returns the elapsed time in seconds."""
+ return (datetime.datetime.now() - self._start).total_seconds()
+
+ def __exit__(self, exception_type, exception_value, exception_traceback):
+ pass
+
+
+def drop_caches(target: machine.Machine):
+ """Drops caches on the machine.
+
+ Args:
+ target: A machine object.
+ """
+ target.run("sudo sync")
+ target.run("sudo sysctl vm.drop_caches=3")
+ target.run("sudo sysctl vm.drop_caches=3")
diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py
new file mode 100644
index 000000000..6efea938c
--- /dev/null
+++ b/benchmarks/suites/http.py
@@ -0,0 +1,138 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""HTTP benchmarks."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.workloads import ab
+
+
+# pylint: disable=too-many-arguments
+def http(server: machine.Machine,
+ client: machine.Machine,
+ workload: str,
+ requests: int = 5000,
+ connections: int = 10,
+ port: int = 80,
+ path: str = "notfound",
+ **kwargs) -> str:
+ """Run apachebench (ab) against an http server.
+
+ Args:
+ server: A machine object.
+ client: A machine object.
+ workload: The http-serving workload.
+ requests: Number of requests to send to the server. Default is 5000.
+ connections: Number of concurrent connections to use. Default is 10.
+ port: The port to access in benchmarking.
+ path: File to download, generally workload-specific.
+ **kwargs: Additional container options.
+
+ Returns:
+ The full apachebench output.
+ """
+ # Pull the client & server.
+ apachebench = client.pull("ab")
+ netcat = client.pull("netcat")
+ image = server.pull(workload)
+
+ with server.container(image, port=port, **kwargs).detach() as container:
+ (host, port) = container.address()
+ # Wait for the server to come up.
+ client.container(netcat).run(host=host, port=port)
+ # Run the benchmark, no arguments.
+ return client.container(apachebench).run(
+ host=host,
+ port=port,
+ requests=requests,
+ connections=connections,
+ path=path)
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-locals
+def http_app(server: machine.Machine,
+ client: machine.Machine,
+ workload: str,
+ requests: int = 5000,
+ connections: int = 10,
+ port: int = 80,
+ path: str = "notfound",
+ **kwargs) -> str:
+ """Run apachebench (ab) against an http application.
+
+ Args:
+ server: A machine object.
+ client: A machine object.
+ workload: The http-serving workload.
+ requests: Number of requests to send to the server. Default is 5000.
+ connections: Number of concurrent connections to use. Default is 10.
+ port: The port to use for benchmarking.
+ path: File to download, generally workload-specific.
+ **kwargs: Additional container options.
+
+ Returns:
+ The full apachebench output.
+ """
+ # Pull the client & server.
+ apachebench = client.pull("ab")
+ netcat = client.pull("netcat")
+ server_netcat = server.pull("netcat")
+ redis = server.pull("redis")
+ image = server.pull(workload)
+ redis_port = 6379
+ redis_name = "{workload}_redis_server".format(workload=workload)
+
+ with server.container(redis, name=redis_name).detach():
+ server.container(server_netcat, links={redis_name: redis_name})\
+ .run(host=redis_name, port=redis_port)
+ with server.container(image, port=port, links={redis_name: redis_name}, **kwargs)\
+ .detach(host=redis_name) as container:
+ (host, port) = container.address()
+ # Wait for the server to come up.
+ client.container(netcat).run(host=host, port=port)
+ # Run the benchmark, no arguments.
+ return client.container(apachebench).run(
+ host=host,
+ port=port,
+ requests=requests,
+ connections=connections,
+ path=path)
+
+
+@suites.benchmark(metrics=[ab.transfer_rate, ab.latency], machines=2)
+def httpd(*args, **kwargs) -> str:
+ """Apache2 benchmark."""
+ return http(*args, workload="httpd", port=80, **kwargs)
+
+
+@suites.benchmark(
+ metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2)
+def nginx(*args, **kwargs) -> str:
+ """Nginx benchmark."""
+ return http(*args, workload="nginx", port=80, **kwargs)
+
+
+@suites.benchmark(
+ metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2)
+def node(*args, **kwargs) -> str:
+ """Node benchmark."""
+ return http_app(*args, workload="node_template", path="", port=8080, **kwargs)
+
+
+@suites.benchmark(
+ metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2)
+def ruby(*args, **kwargs) -> str:
+ """Ruby benchmark."""
+ return http_app(*args, workload="ruby_template", path="", port=9292, **kwargs)
diff --git a/benchmarks/suites/media.py b/benchmarks/suites/media.py
new file mode 100644
index 000000000..9cbffdaa1
--- /dev/null
+++ b/benchmarks/suites/media.py
@@ -0,0 +1,42 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Media processing benchmarks."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.suites import helpers
+from benchmarks.workloads import ffmpeg
+
+
+@suites.benchmark(metrics=[ffmpeg.run_time], machines=1)
+def transcode(target: machine.Machine, **kwargs) -> float:
+ """Runs a video transcoding workload and times it.
+
+ Args:
+ target: A machine object.
+ **kwargs: Additional container options.
+
+ Returns:
+ Total workload runtime.
+ """
+ # Load before timing.
+ image = target.pull("ffmpeg")
+
+ # Drop caches.
+ helpers.drop_caches(target)
+
+ # Time startup + transcoding.
+ with helpers.Timer() as timer:
+ target.container(image, **kwargs).run()
+ return timer.elapsed()
diff --git a/benchmarks/suites/ml.py b/benchmarks/suites/ml.py
new file mode 100644
index 000000000..a394d1f69
--- /dev/null
+++ b/benchmarks/suites/ml.py
@@ -0,0 +1,33 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Machine Learning tests."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.suites import startup
+from benchmarks.workloads import tensorflow
+
+
+@suites.benchmark(metrics=[tensorflow.run_time], machines=1)
+def train(target: machine.Machine, **kwargs):
+ """Run the tensorflow benchmark and return the runtime in seconds of workload.
+
+ Args:
+ target: A machine object.
+ **kwargs: Additional container options.
+
+ Returns:
+ The total runtime.
+ """
+ return startup.startup(target, workload="tensorflow", count=1, **kwargs)
diff --git a/benchmarks/suites/network.py b/benchmarks/suites/network.py
new file mode 100644
index 000000000..f973cf3f1
--- /dev/null
+++ b/benchmarks/suites/network.py
@@ -0,0 +1,101 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Network microbenchmarks."""
+
+from typing import Dict
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.suites import helpers
+from benchmarks.workloads import iperf
+
+
+def run_iperf(client: machine.Machine,
+ server: machine.Machine,
+ client_kwargs: Dict[str, str] = None,
+ server_kwargs: Dict[str, str] = None) -> str:
+ """Measure iperf performance.
+
+ Args:
+ client: A machine object.
+ server: A machine object.
+ client_kwargs: Additional client container options.
+ server_kwargs: Additional server container options.
+
+ Returns:
+ The output of iperf.
+ """
+ if not client_kwargs:
+ client_kwargs = dict()
+ if not server_kwargs:
+ server_kwargs = dict()
+
+ # Pull images.
+ netcat = client.pull("netcat")
+ iperf_client_image = client.pull("iperf")
+ iperf_server_image = server.pull("iperf")
+
+ # Set this due to a bug in the kernel that resets connections.
+ client.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1")
+ server.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1")
+
+ with server.container(
+ iperf_server_image, port=5001, **server_kwargs).detach() as iperf_server:
+ (host, port) = iperf_server.address()
+ # Wait until the service is available.
+ client.container(netcat).run(host=host, port=port)
+ # Run a warm-up run.
+ client.container(
+ iperf_client_image, stderr=True, **client_kwargs).run(
+ host=host, port=port)
+ # Run the client with relevant arguments.
+ res = client.container(iperf_client_image, stderr=True, **client_kwargs)\
+ .run(host=host, port=port)
+ helpers.drop_caches(client)
+ return res
+
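+# Illustrative call (assumes machine objects "c" and "s"; not part of the
+# suite itself):
+#
+#   run_iperf(c, s, client_kwargs={"network_mode": "host"})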
+
+@suites.benchmark(metrics=[iperf.bandwidth], machines=2)
+def upload(client: machine.Machine, server: machine.Machine, **kwargs) -> str:
+ """Measure upload performance.
+
+ Args:
+ client: A machine object.
+ server: A machine object.
+ **kwargs: Client container options.
+
+ Returns:
+ The output of iperf.
+ """
+ if kwargs["runtime"] == "runc":
+ kwargs["network_mode"] = "host"
+ return run_iperf(client, server, client_kwargs=kwargs)
+
+
+@suites.benchmark(metrics=[iperf.bandwidth], machines=2)
+def download(client: machine.Machine, server: machine.Machine, **kwargs) -> str:
+ """Measure download performance.
+
+ Args:
+ client: A machine object.
+ server: A machine object.
+ **kwargs: Server container options.
+
+ Returns:
+ The output of iperf.
+ """
+
+ client_kwargs = {"network_mode": "host"}
+ return run_iperf(
+ client, server, client_kwargs=client_kwargs, server_kwargs=kwargs)
diff --git a/benchmarks/suites/redis.py b/benchmarks/suites/redis.py
new file mode 100644
index 000000000..b84dd073d
--- /dev/null
+++ b/benchmarks/suites/redis.py
@@ -0,0 +1,46 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Redis benchmarks."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.workloads import redisbenchmark
+
+
+@suites.benchmark(metrics=list(redisbenchmark.METRICS.values()), machines=2)
+def redis(server: machine.Machine,
+ client: machine.Machine,
+ flags: str = "",
+ **kwargs) -> str:
+ """Run redis-benchmark on client pointing at server machine.
+
+ Args:
+ server: A machine object.
+ client: A machine object.
+ flags: Flags to pass redis-benchmark.
+ **kwargs: Additional container options.
+
+ Returns:
+ Output from redis-benchmark.
+ """
+ redis_server = server.pull("redis")
+ redis_client = client.pull("redisbenchmark")
+ netcat = client.pull("netcat")
+ with server.container(
+ redis_server, port=6379, **kwargs).detach() as container:
+ (host, port) = container.address()
+ # Wait for the container to be up.
+ client.container(netcat).run(host=host, port=port)
+ # Run all redis benchmarks.
+ return client.container(redis_client).run(host=host, port=port, flags=flags)
diff --git a/benchmarks/suites/startup.py b/benchmarks/suites/startup.py
new file mode 100644
index 000000000..a1b6c5753
--- /dev/null
+++ b/benchmarks/suites/startup.py
@@ -0,0 +1,110 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Start-up benchmarks."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.suites import helpers
+
+
+# pylint: disable=unused-argument
+def startup_time_ms(value, **kwargs):
+ """Returns average startup time per container in milliseconds.
+
+ Args:
+ value: The floating point time in seconds.
+ **kwargs: Ignored.
+
+ Returns:
+ The time given in milliseconds.
+ """
+ return value * 1000
+
+
+def startup(target: machine.Machine,
+ workload: str,
+ count: int = 5,
+ port: int = 0,
+ **kwargs):
+ """Time the startup of some workload.
+
+ Args:
+ target: A machine object.
+ workload: The workload to run.
+ count: Number of containers to start.
+ port: The port to check for liveness, if provided.
+ **kwargs: Additional container options.
+
+ Returns:
+ The mean start-up time in seconds.
+ """
+ # Load before timing.
+ image = target.pull(workload)
+ netcat = target.pull("netcat")
+ count = int(count)
+ port = int(port)
+
+ with helpers.Timer() as timer:
+ for _ in range(count):
+ if not port:
+ # Run the container synchronously.
+ target.container(image, **kwargs).run()
+ else:
+ # Run a detached container until the server is available.
+ with target.container(image, port=port, **kwargs).detach() as server:
+ (server_host, server_port) = server.address()
+ target.container(netcat).run(host=server_host, port=server_port)
+ return timer.elapsed() / float(count)
+
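+# Illustrative call (assumes a machine object "m"; not part of the suite):
+#
+#   startup(m, workload="node", count=5, port=8080)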
+
+@suites.benchmark(metrics=[startup_time_ms], machines=1)
+def empty(target: machine.Machine, **kwargs) -> float:
+ """Time the startup of a trivial container.
+
+ Args:
+ target: A machine object.
+ **kwargs: Additional startup options.
+
+ Returns:
+ The time to run the container.
+ """
+ return startup(target, workload="true", **kwargs)
+
+
+@suites.benchmark(metrics=[startup_time_ms], machines=1)
+def node(target: machine.Machine, **kwargs) -> float:
+ """Time the startup of the node container.
+
+ Args:
+ target: A machine object.
+ **kwargs: Additional startup options.
+
+ Returns:
+ The time to run the container.
+ """
+ return startup(target, workload="node", port=8080, **kwargs)
+
+
+@suites.benchmark(metrics=[startup_time_ms], machines=1)
+def ruby(target: machine.Machine, **kwargs) -> float:
+ """Time the startup of the ruby container.
+
+ Args:
+ target: A machine object.
+ **kwargs: Additional startup options.
+
+ Returns:
+ The time to run the container.
+ """
+ return startup(target, workload="ruby", port=3000, **kwargs)
diff --git a/benchmarks/suites/sysbench.py b/benchmarks/suites/sysbench.py
new file mode 100644
index 000000000..2a6e2126c
--- /dev/null
+++ b/benchmarks/suites/sysbench.py
@@ -0,0 +1,119 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Sysbench-based benchmarks."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.workloads import sysbench
+
+
+def run_sysbench(target: machine.Machine,
+ test: str = "cpu",
+ threads: int = 8,
+ time: int = 5,
+ options: str = "",
+ **kwargs) -> str:
+ """Run sysbench container with arguments.
+
+ Args:
+ target: A machine object.
+ test: Relevant sysbench test to run (e.g. cpu, memory).
+ threads: The number of threads to use for tests.
+ time: The time to run tests.
+ options: Additional sysbench options.
+ **kwargs: Additional container options.
+
+ Returns:
+ The output of the command as a string.
+ """
+ image = target.pull("sysbench")
+ return target.container(image, **kwargs).run(
+ test=test, threads=threads, time=time, options=options)
+
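+# Illustrative call (assumes a machine object "m"; not part of the suite):
+#
+#   run_sysbench(m, test="cpu", threads=4, time=10)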
+
+@suites.benchmark(metrics=[sysbench.cpu_events_per_second], machines=1)
+def cpu(target: machine.Machine, max_prime: int = 5000, **kwargs) -> str:
+ """Run sysbench CPU test.
+
+ Additional arguments can be provided for sysbench.
+
+ Args:
+ target: A machine object.
+ max_prime: The upper bound for the prime number computation.
+ **kwargs:
+ - threads: The number of threads to use for tests.
+ - time: The time to run tests.
+ - options: Additional sysbench options. See sysbench tool:
+ https://github.com/akopytov/sysbench
+
+ Returns:
+ Sysbench output.
+ """
+ options = kwargs.pop("options", "")
+ options += " --cpu-max-prime={}".format(max_prime)
+ return run_sysbench(target, test="cpu", options=options, **kwargs)
+
+
+@suites.benchmark(metrics=[sysbench.memory_ops_per_second], machines=1)
+def memory(target: machine.Machine, **kwargs) -> str:
+ """Run sysbench memory test.
+
+ Additional arguments can be provided for sysbench.
+
+ Args:
+ target: A machine object.
+ **kwargs:
+ - threads: The number of threads to use for tests.
+ - time: The time to run tests.
+ - options: Additional sysbench options. See sysbench tool:
+ https://github.com/akopytov/sysbench
+
+ Returns:
+ Sysbench output.
+ """
+ return run_sysbench(target, test="memory", **kwargs)
+
+
+@suites.benchmark(
+ metrics=[
+ sysbench.mutex_time, sysbench.mutex_latency, sysbench.mutex_deviation
+ ],
+ machines=1)
+def mutex(target: machine.Machine,
+ locks: int = 4,
+ count: int = 10000000,
+ threads: int = 8,
+ **kwargs) -> str:
+ """Run sysbench mutex test.
+
+ Additional arguments can be provided for sysbench.
+
+ Args:
+ target: A machine object.
+ locks: The number of mutexes (sysbench --mutex-num).
+ count: The number of mutex acquisitions to perform (sysbench --mutex-locks).
+ threads: The number of threads to use for tests.
+ **kwargs:
+ - time: The time to run tests.
+ - options: Additional sysbench options. See sysbench tool:
+ https://github.com/akopytov/sysbench
+
+ Returns:
+ Sysbench output.
+ """
+ options = kwargs.pop("options", "")
+ options += " --mutex-loops=1 --mutex-locks={} --mutex-num={}".format(
+ count, locks)
+ return run_sysbench(
+ target, test="mutex", options=options, threads=threads, **kwargs)
diff --git a/benchmarks/suites/syscall.py b/benchmarks/suites/syscall.py
new file mode 100644
index 000000000..fa7665b00
--- /dev/null
+++ b/benchmarks/suites/syscall.py
@@ -0,0 +1,37 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Syscall microbenchmark."""
+
+from benchmarks import suites
+from benchmarks.harness import machine
+from benchmarks.workloads.syscall import syscall_time_ns
+
+
+@suites.benchmark(metrics=[syscall_time_ns], machines=1)
+def syscall(target: machine.Machine, count: int = 1000000, **kwargs) -> str:
+ """Runs the syscall workload and report the syscall time.
+
+ Runs the syscall 'SYS_gettimeofday(0,0)' 'count' times and monitors time
+ elapsed based on the runtime's MONOTONIC clock.
+
+ Args:
+ target: A machine object.
+ count: The number of syscalls to execute.
+ **kwargs: Additional container options.
+
+ Returns:
+ Container output.
+ """
+ image = target.pull("syscall")
+ return target.container(image, **kwargs).run(count=count)
diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD
new file mode 100644
index 000000000..6dde7d9e6
--- /dev/null
+++ b/benchmarks/tcp/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "cc_binary", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "tcp_proxy",
+ srcs = ["tcp_proxy.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/qdisc/fifo",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# nsjoin is a trivial replacement for nsenter. This is used because nsenter is
+# not available on all systems where this benchmark is run (and we aim to
+# minimize external dependencies).
+
+cc_binary(
+ name = "nsjoin",
+ srcs = ["nsjoin.c"],
+ visibility = ["//:sandbox"],
+)
+
+sh_binary(
+ name = "tcp_benchmark",
+ srcs = ["tcp_benchmark.sh"],
+ data = [
+ ":nsjoin",
+ ":tcp_proxy",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/benchmarks/tcp/README.md b/benchmarks/tcp/README.md
new file mode 100644
index 000000000..38e6e69f0
--- /dev/null
+++ b/benchmarks/tcp/README.md
@@ -0,0 +1,87 @@
+# TCP Benchmarks
+
+This directory contains a standardized TCP benchmark. This helps to evaluate the
+performance of netstack and native networking stacks under various conditions.
+
+## `tcp_benchmark`
+
+This benchmark allows TCP throughput testing under various conditions. The setup
+consists of an iperf client, a client proxy, a server proxy and an iperf server.
+The client proxy and server proxy abstract the network mechanism used to
+communicate between the iperf client and server.
+
+The setup looks like the following:
+
+```
+ +--------------+ (native) +--------------+
+ | iperf client |[lo @ 10.0.0.1]------>| client proxy |
+ +--------------+ +--------------+
+ [client.0 @ 10.0.0.2]
+ (netstack) | | (native)
+ +------+-----+
+ |
+ [br0]
+ |
+ Network emulation applied ---> [wan.0:wan.1]
+ |
+ [br1]
+ |
+ +------+-----+
+ (netstack) | | (native)
+ [server.0 @ 10.0.0.3]
+ +--------------+ +--------------+
+ | iperf server |<------[lo @ 10.0.0.4]| server proxy |
+ +--------------+ (native) +--------------+
+```
+
+Different configurations can be run using different arguments. For example:
+
+* Native test under normal internet conditions: `tcp_benchmark`
+* Native test under ideal conditions: `tcp_benchmark --ideal`
+* Netstack client under ideal conditions: `tcp_benchmark --client --ideal`
+* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss
+ 5`
+
+Use `tcp_benchmark --help` for full arguments.
+
+This tool may be used to easily generate data for graphing. For example, to
+generate a CSV for various latencies, you might do:
+
+```
+rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv
+latencies=$(seq 0 5 50;
+ seq 60 10 100;
+ seq 125 25 250;
+ seq 300 50 500)
+for latency in $latencies; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --client --ideal --latency $latency)
+ echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv
+done
+for latency in $latencies; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --ideal --latency $latency)
+ echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv
+done
+```
+
+Similarly, to generate a CSV for various levels of packet loss, the following
+would be appropriate:
+
+```
+rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv
+losses=$(seq 0 0.1 1.0;
+ seq 1.2 0.2 2.0;
+ seq 2.5 0.5 5.0;
+ seq 6.0 1.0 10.0)
+for loss in $losses; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss)
+ echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv
+done
+for loss in $losses; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss)
+ echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv
+done
+```
diff --git a/benchmarks/tcp/nsjoin.c b/benchmarks/tcp/nsjoin.c
new file mode 100644
index 000000000..524b4d549
--- /dev/null
+++ b/benchmarks/tcp/nsjoin.c
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sched.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+int main(int argc, char** argv) {
+ if (argc <= 2) {
+ fprintf(stderr, "error: must provide a namespace file.\n");
+ fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]);
+ return 1;
+ }
+
+ int fd = open(argv[1], O_RDONLY);
+ if (fd < 0) {
+ fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno));
+ return 1;
+ }
+ if (setns(fd, 0) < 0) {
+ fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno));
+ return 1;
+ }
+
+ execvp(argv[2], &argv[2]);
+ return 1;
+}
diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh
new file mode 100755
index 000000000..ef04b4ace
--- /dev/null
+++ b/benchmarks/tcp/tcp_benchmark.sh
@@ -0,0 +1,392 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TCP benchmark; see README.md for documentation.
+
+# Fixed parameters.
+iperf_port=45201 # Not likely to be privileged.
+proxy_port=44000 # Ditto.
+client_addr=10.0.0.1
+client_proxy_addr=10.0.0.2
+server_proxy_addr=10.0.0.3
+server_addr=10.0.0.4
+mask=8
+
+# Defaults; this provides a reasonable approximation of a decent internet link.
+# Parameters can be varied independently from this set to see response to
+# various changes in the kind of link available.
+client=false
+server=false
+verbose=false
+gso=0
+swgso=false
+mtu=1280 # 1280 is a reasonable lowest-common-denominator.
+latency=10 # 10ms approximates a fast, dedicated connection.
+latency_variation=1 # +/- 1ms is a relatively low amount of jitter.
+loss=0.1 # 0.1% loss is non-zero, but not extremely high.
+duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses.
+duration=30 # 30s is enough time to get consistent results (experimentally).
+helper_dir=$(dirname $0)
+netstack_opts=
+disable_linux_gso=
+num_client_threads=1
+
+# Check for netem support.
+lsmod_output=$(lsmod | grep sch_netem)
+if [ "$?" != "0" ]; then
+ echo "warning: sch_netem may not be installed." >&2
+fi
+
+while [ $# -gt 0 ]; do
+ case "$1" in
+ --client)
+ client=true
+ ;;
+ --client_tcp_probe_file)
+ shift
+ netstack_opts="${netstack_opts} -client_tcp_probe_file=$1"
+ ;;
+ --server)
+ server=true
+ ;;
+ --verbose)
+ verbose=true
+ ;;
+ --gso)
+ shift
+ gso=$1
+ ;;
+ --swgso)
+ swgso=true
+ ;;
+ --server_tcp_probe_file)
+ shift
+ netstack_opts="${netstack_opts} -server_tcp_probe_file=$1"
+ ;;
+ --ideal)
+ mtu=1500 # Standard ethernet.
+ latency=0 # No latency.
+ latency_variation=0 # No jitter.
+ loss=0 # No loss.
+ duplicate=0 # No duplicates.
+ ;;
+ --mtu)
+ shift
+ [ "$#" -le 0 ] && echo "no mtu provided" && exit 1
+ mtu=$1
+ ;;
+ --sack)
+ netstack_opts="${netstack_opts} -sack"
+ ;;
+ --cubic)
+ netstack_opts="${netstack_opts} -cubic"
+ ;;
+ --moderate-recv-buf)
+ netstack_opts="${netstack_opts} -moderate_recv_buf"
+ ;;
+ --duration)
+ shift
+ [ "$#" -le 0 ] && echo "no duration provided" && exit 1
+ duration=$1
+ ;;
+ --latency)
+ shift
+ [ "$#" -le 0 ] && echo "no latency provided" && exit 1
+ latency=$1
+ ;;
+ --latency-variation)
+ shift
+ [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1
+ latency_variation=$1
+ ;;
+ --loss)
+ shift
+ [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1
+ loss=$1
+ ;;
+ --duplicate)
+ shift
+ [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1
+ duplicate=$1
+ ;;
+ --cpuprofile)
+ shift
+ netstack_opts="${netstack_opts} -cpuprofile=$1"
+ ;;
+ --memprofile)
+ shift
+ netstack_opts="${netstack_opts} -memprofile=$1"
+ ;;
+ --disable-linux-gso)
+ disable_linux_gso=1
+ ;;
+ --num-client-threads)
+ shift
+ num_client_threads=$1
+ ;;
+ --helpers)
+ shift
+ [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1
+ helper_dir=$1
+ ;;
+ *)
+ echo "usage: $0 [options]"
+ echo "options:"
+ echo " --help show this message"
+ echo " --verbose verbose output"
+ echo " --client use netstack as the client"
+ echo " --ideal reset all network emulation"
+ echo " --server use netstack as the server"
+ echo " --mtu set the mtu (bytes)"
+ echo " --sack enable SACK support"
+ echo " --moderate-recv-buf enable TCP receive buffer auto-tuning"
+ echo " --cubic enable CUBIC congestion control for Netstack"
+ echo " --duration set the test duration (s)"
+ echo " --latency set the latency (ms)"
+ echo " --latency-variation set the latency variation"
+ echo " --loss set the loss probability (%)"
+ echo " --duplicate set the duplicate probability (%)"
+ echo " --helpers set the helper directory"
+ echo " --num-client-threads number of parallel client threads to run"
+ echo " --disable-linux-gso disable segmentation offload in the Linux network stack"
+ echo ""
+ echo "The output will of the script will be:"
+ echo " <throughput> <client-cpu-usage> <server-cpu-usage>"
+ exit 1
+ esac
+ shift
+done
+
+if [ ${verbose} == "true" ]; then
+ set -x
+fi
+
+# Latency needs to be halved, since it's applied on both ways.
+half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}')
+half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}')
+half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}')
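+# (e.g. --latency 10 adds 5.00ms on each wan device, ~10ms round trip.)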
+helper_dir=${helper_dir#$(pwd)/} # Use relative paths.
+proxy_binary=${helper_dir}/tcp_proxy
+nsjoin_binary=${helper_dir}/nsjoin
+
+if [ ! -e ${proxy_binary} ]; then
+ echo "Could not locate ${proxy_binary}, please make sure you've built the binary"
+ exit 1
+fi
+
+if [ ! -e ${nsjoin_binary} ]; then
+ echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary"
+ exit 1
+fi
+
+if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then
+ # As long as there's some jitter, then we use the paretonormal distribution.
+ # This will preserve the minimum RTT, but add a realistic amount of jitter to
+ # the connection and cause re-ordering, etc. The regular pareto distribution
+ # appears to add an unreasonable level of delay (we want only small spikes).
+ distribution="distribution paretonormal"
+else
+ distribution=""
+fi
+
+# Client proxy that will listen on the client's iperf target and forward
+# traffic using the host networking stack.
+client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}"
+if ${client}; then
+ # Client proxy that will listen on the client's iperf target
+ # and forward traffic using netstack.
+ client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\
+ -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\
+ -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}"
+fi
+
+# Server proxy that will listen on the proxy port and forward to the server's
+# iperf server using the host networking stack.
+server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}"
+if ${server}; then
+ # Server proxy that will listen on the proxy port and forward to the server's
+ # iperf server using netstack.
+ server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\
+ -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\
+ -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}"
+fi
+
+# Specify loss and duplicate parameters only if they are non-zero
+loss_opt=""
+if [ "$(echo $half_loss | bc -q)" != "0" ]; then
+ loss_opt="loss random ${half_loss}%"
+fi
+duplicate_opt=""
+if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then
+ duplicate_opt="duplicate ${half_duplicate}%"
+fi
+
+exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF
+set -e -m
+
+if [ ${verbose} == "true" ]; then
+ set -x
+fi
+
+mount -t tmpfs netstack-bench /tmp
+
+# We may have reset the path in the unshare if the shell loaded some public
+# profiles. Ensure that tools are discoverable via the parent's PATH.
+export PATH=${PATH}
+
+# Add client, server interfaces.
+ip link add client.0 type veth peer name client.1
+ip link add server.0 type veth peer name server.1
+
+# Add network emulation devices.
+ip link add wan.0 type veth peer name wan.1
+ip link set wan.0 up
+ip link set wan.1 up
+
+# Enroll on the bridge.
+ip link add name br0 type bridge
+ip link add name br1 type bridge
+ip link set client.1 master br0
+ip link set server.1 master br1
+ip link set wan.0 master br0
+ip link set wan.1 master br1
+ip link set br0 up
+ip link set br1 up
+
+# Set the MTU appropriately.
+ip link set client.0 mtu ${mtu}
+ip link set server.0 mtu ${mtu}
+ip link set wan.0 mtu ${mtu}
+ip link set wan.1 mtu ${mtu}
+
+# Add appropriate latency, loss and duplication.
+#
+# This is added in at the point of bridge connection.
+for device in wan.0 wan.1; do
+ # NOTE: We don't support a loss correlation as testing has shown that it
+ # actually doesn't work. The man page actually has a small comment about this
+ # "It is also possible to add a correlation, but this option is now deprecated
+ # due to the noticed bad behavior." For more information see netem(8).
+ tc qdisc add dev \$device root netem \\
+ delay ${half_latency}ms ${latency_variation}ms ${distribution} \\
+ ${loss_opt} ${duplicate_opt}
+done
+
+# Start a client proxy.
+touch /tmp/client.netns
+unshare -n mount --bind /proc/self/ns/net /tmp/client.netns
+
+# Move the endpoint into the namespace.
+while ip link | grep client.0 > /dev/null; do
+ ip link set dev client.0 netns /tmp/client.netns
+done
+
+if ! ${client}; then
+ # Only add the address to NIC if netstack is not in use. Otherwise the host
+ # will also process the inbound SYN and send a RST back.
+ ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0
+fi
+
+# Start a server proxy.
+touch /tmp/server.netns
+unshare -n mount --bind /proc/self/ns/net /tmp/server.netns
+# Move the endpoint into the namespace.
+while ip link | grep server.0 > /dev/null; do
+ ip link set dev server.0 netns /tmp/server.netns
+done
+if ! ${server}; then
+ # Only add the address to NIC if netstack is not in use. Otherwise the host
+ # will also process the inbound SYN and send a RST back.
+ ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0
+fi
+
+# Add client and server addresses, and bring everything up.
+${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0
+${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0
+if [ "${disable_linux_gso}" == "1" ]; then
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off
+fi
+${nsjoin_binary} /tmp/client.netns ip link set client.0 up
+${nsjoin_binary} /tmp/client.netns ip link set lo up
+${nsjoin_binary} /tmp/server.netns ip link set server.0 up
+${nsjoin_binary} /tmp/server.netns ip link set lo up
+ip link set dev client.1 up
+ip link set dev server.1 up
+
+${nsjoin_binary} /tmp/client.netns ${client_args} &
+client_pid=\$!
+${nsjoin_binary} /tmp/server.netns ${server_args} &
+server_pid=\$!
+
+# Start the iperf server.
+${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 &
+iperf_pid=\$!
+
+# Show traffic information.
+if ! ${client} && ! ${server}; then
+ ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true
+fi
+
+results_file=\$(mktemp)
+function cleanup {
+ rm -f \$results_file
+ kill -TERM \$client_pid
+ kill -TERM \$server_pid
+ wait \$client_pid
+ wait \$server_pid
+ kill -9 \$iperf_pid 2>/dev/null
+}
+
+# Allow failure from this point.
+set +e
+trap cleanup EXIT
+
+# Run the benchmark, recording the results file.
+while ${nsjoin_binary} /tmp/client.netns iperf \\
+ -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\
+ | tee \$results_file \\
+ | grep "connect failed" >/dev/null; do
+ sleep 0.1 # Wait for all services.
+done
+
+# Unlink all relevant devices from the bridge. This is because when the bridge
+# is deleted, the kernel may hang. It appears that this problem is fixed in
+# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a.
+ip link set client.1 nomaster
+ip link set server.1 nomaster
+ip link set wan.0 nomaster
+ip link set wan.1 nomaster
+
+# Emit raw results.
+cat \$results_file >&2
+
+# Emit a useful result (final throughput).
+mbits=\$(grep Mbits/sec \$results_file \\
+ | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p')
+client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\
+ | awk '{print (\$14+\$15);}')
+server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\
+ | awk '{print (\$14+\$15);}')
+ticks_per_sec=\$(getconf CLK_TCK)
+client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration})
+server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration})
+echo \$mbits \$client_cpu_load \$server_cpu_load
+EOF
diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go
new file mode 100644
index 000000000..4b7ca7a14
--- /dev/null
+++ b/benchmarks/tcp/tcp_proxy.go
@@ -0,0 +1,451 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary tcp_proxy is a simple TCP proxy.
+package main
+
+import (
+ "encoding/gob"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "os/signal"
+ "regexp"
+ "runtime"
+ "runtime/pprof"
+ "strconv"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+var (
+ port = flag.Int("port", 0, "bind port (all addresses)")
+ forward = flag.String("forward", "", "forwarding target")
+ client = flag.Bool("client", false, "use netstack for dial (client side)")
+ server = flag.Bool("server", false, "use netstack for listen (server side)")
+
+ // Netstack-specific options.
+ mtu = flag.Int("mtu", 1280, "mtu for network stack")
+ addr = flag.String("addr", "", "address for tap-based netstack")
+ mask = flag.Int("mask", 8, "mask size for address")
+ iface = flag.String("iface", "", "network interface name to bind for netstack")
+ sack = flag.Bool("sack", false, "enable SACK support for netstack")
+ moderateRecvBuf = flag.Bool("moderate_recv_buf", false, "enable TCP Receive Buffer Auto-tuning")
+ cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack")
+ gso = flag.Int("gso", 0, "GSO maximum size")
+ swgso = flag.Bool("swgso", false, "software-level GSO")
+ clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.")
+ serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.")
+ cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.")
+ memprofile = flag.String("memprofile", "", "write memory profile to the specified file.")
+)
+
+type impl interface {
+ dial(address string) (net.Conn, error)
+ listen(port int) (net.Listener, error)
+ printStats()
+}
+
+type netImpl struct{}
+
+func (netImpl) dial(address string) (net.Conn, error) {
+ return net.Dial("tcp", address)
+}
+
+func (netImpl) listen(port int) (net.Listener, error) {
+ return net.Listen("tcp", fmt.Sprintf(":%d", port))
+}
+
+func (netImpl) printStats() {
+}
+
+const (
+ nicID = 1 // Fixed.
+ bufSize = 4 << 20 // 4MB.
+)
+
+type netstackImpl struct {
+ s *stack.Stack
+ addr tcpip.Address
+ mode string
+}
+
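+// setupNetwork opens numChannels AF_PACKET sockets bound to the named
+// interface; the returned file descriptors back the fdbased link endpoint
+// created below.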
+func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) {
+ // Get all interfaces in the namespace.
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil, fmt.Errorf("querying interfaces: %v", err)
+ }
+
+ for _, iface := range ifaces {
+ if iface.Name != ifaceName {
+ continue
+ }
+ // Create the socket.
+ const protocol = 0x0300 // htons(ETH_P_ALL)
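+ // The AF_PACKET protocol field is in network byte order, hence the
+ // byte-swapped constant.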
+ fds := make([]int, numChannels)
+ for i := range fds {
+ fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create raw socket: %v", err)
+ }
+
+ // Bind to the appropriate device.
+ ll := syscall.SockaddrLinklayer{
+ Protocol: protocol,
+ Ifindex: iface.Index,
+ Pkttype: syscall.PACKET_HOST,
+ }
+ if err := syscall.Bind(fd, &ll); err != nil {
+ return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
+ }
+
+ // Raw sockets have a very small default SO_RCVBUF (256KB); raise it
+ // to 4MB to reduce packet drops.
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err)
+ }
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err)
+ }
+
+ if !*swgso && *gso != 0 {
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ }
+ }
+ fds[i] = fd
+ }
+ return fds, nil
+ }
+ return nil, fmt.Errorf("failed to find interface: %v", ifaceName)
+}
+
+func newNetstackImpl(mode string) (impl, error) {
+ fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1))
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse details.
+ parsedAddr := tcpip.Address(net.ParseIP(*addr).To4())
+ parsedDest := tcpip.Address("") // Filled in below.
+ parsedMask := tcpip.AddressMask("") // Filled in below.
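+ // Only /8, /16 and /24 masks are supported; the destination subnet is
+ // derived by zeroing the host bits of the parsed address.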
+ switch *mask {
+ case 8:
+ parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0})
+ parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0})
+ case 16:
+ parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0})
+ parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0})
+ case 24:
+ parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0})
+ parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0})
+ default:
+ // This is just laziness; we don't expect a different mask.
+ return nil, fmt.Errorf("mask %d not supported", mask)
+ }
+
+ // Create a new network stack.
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}
+ s := stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ })
+
+ // Generate a new mac for the eth device.
+ mac := make(net.HardwareAddr, 6)
+ rand.Read(mac) // Fill with random data.
+ mac[0] &^= 0x1 // Clear multicast bit.
+ mac[0] |= 0x2 // Set local assignment bit (IEEE802).
+ ep, err := fdbased.New(&fdbased.Options{
+ FDs: fds,
+ MTU: uint32(*mtu),
+ EthernetHeader: true,
+ Address: tcpip.LinkAddress(mac),
+ // Enable checksum generation, as we need to generate valid
+ // checksums for the veth device to deliver our packets to the
+ // peer, but disable checksum verification, as veth devices
+ // perform GRO and the Linux host kernel may not regenerate
+ // valid checksums after GRO.
+ TXChecksumOffload: false,
+ RXChecksumOffload: true,
+ PacketDispatchMode: fdbased.RecvMMsg,
+ GSOMaxSize: uint32(*gso),
+ SoftwareGSOEnabled: *swgso,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create FD endpoint: %v", err)
+ }
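+ // Wrap the endpoint in a FIFO qdisc (one dispatcher per GOMAXPROCS,
+ // queue length 1000) before attaching it as the NIC.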
+ if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil {
+ return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil {
+ return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err)
+ }
+
+ subnet, err := tcpip.NewSubnet(parsedDest, parsedMask)
+ if err != nil {
+ return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err)
+ }
+ // Add a route for the configured subnet; it is the only route we support.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: nicID,
+ },
+ })
+
+ // Set protocol options.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err)
+ }
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Set Congestion Control to cubic if requested.
+ if *cubic {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err)
+ }
+ }
+
+ return netstackImpl{
+ s: s,
+ addr: parsedAddr,
+ mode: mode,
+ }, nil
+}
+
+func (n netstackImpl) dial(address string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+ if host == "" {
+ // A host must be provided for the dial.
+ return nil, fmt.Errorf("no host provided")
+ }
+ portNumber, err := strconv.Atoi(port)
+ if err != nil {
+ return nil, err
+ }
+ addr := tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(net.ParseIP(host).To4()),
+ Port: uint16(portNumber),
+ }
+ conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+func (n netstackImpl) listen(port int) (net.Listener, error) {
+ addr := tcpip.FullAddress{
+ NIC: nicID,
+ Port: uint16(port),
+ }
+ listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return listener, nil
+}
+
+var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`)
+
+func (n netstackImpl) printStats() {
+ // Don't show zero fields.
+ stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "")
+ log.Printf("netstack %s Stats: %+v\n", n.mode, stats)
+}
+
+// installProbe installs a TCP Probe function that will dump endpoint
+// state to the specified file. It also returns a close func() that
+// can be used to close the probeFile.
+func (n netstackImpl) installProbe(probeFileName string) (close func()) {
+ // Install Probe to dump out end point state.
+ probeFile, err := os.Create(probeFileName)
+ if err != nil {
+ log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err)
+ }
+ probeEncoder := gob.NewEncoder(probeFile)
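+ // Each probe invocation appends one gob-encoded TCPEndpointState record;
+ // consumers can replay the file with a matching gob.Decoder.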
+ // Install a TCP Probe.
+ n.s.AddTCPProbe(func(state stack.TCPEndpointState) {
+ probeEncoder.Encode(state)
+ })
+ return func() { probeFile.Close() }
+}
+
+func main() {
+ flag.Parse()
+ if *port == 0 {
+ log.Fatalf("no port provided")
+ }
+ if *forward == "" {
+ log.Fatalf("no forward provided")
+ }
+ // Seed the random number generator to ensure that the client and server
+ // stacks are given MAC addresses that don't collide.
+ rand.Seed(time.Now().UTC().UnixNano())
+
+ if *cpuprofile != "" {
+ f, err := os.Create(*cpuprofile)
+ if err != nil {
+ log.Fatal("could not create CPU profile: ", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ log.Print("error closing CPU profile: ", err)
+ }
+ }()
+ if err := pprof.StartCPUProfile(f); err != nil {
+ log.Fatal("could not start CPU profile: ", err)
+ }
+ defer pprof.StopCPUProfile()
+ }
+
+ var (
+ in impl
+ out impl
+ err error
+ )
+ if *server {
+ in, err = newNetstackImpl("server")
+ if err != nil {
+ log.Fatalf("netstack error: %v", err)
+ }
+ if *serverTCPProbeFile != "" {
+ defer in.(netstackImpl).installProbe(*serverTCPProbeFile)()
+ }
+ } else {
+ in = netImpl{}
+ }
+ if *client {
+ out, err = newNetstackImpl("client")
+ if err != nil {
+ log.Fatalf("netstack error: %v", err)
+ }
+ if *clientTCPProbeFile != "" {
+ defer out.(netstackImpl).installProbe(*clientTCPProbeFile)()
+ }
+ } else {
+ out = netImpl{}
+ }
+
+ // Dial the forward target before binding, retrying until it is reachable.
+ var next net.Conn
+ for {
+ next, err = out.dial(*forward)
+ if err == nil {
+ break
+ }
+ time.Sleep(50 * time.Millisecond)
+ log.Printf("connect failed retrying: %v", err)
+ }
+
+ // Bind once to the server socket.
+ listener, err := in.listen(*port)
+ if err != nil {
+ // Should not happen; everything must be bound by the time
+ // this proxy is started.
+ log.Fatalf("unable to listen: %v", err)
+ }
+ log.Printf("client=%v, server=%v, ready.", *client, *server)
+
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, syscall.SIGTERM)
+ go func() {
+ <-sigs
+ if *cpuprofile != "" {
+ pprof.StopCPUProfile()
+ }
+ if *memprofile != "" {
+ f, err := os.Create(*memprofile)
+ if err != nil {
+ log.Fatal("could not create memory profile: ", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ log.Print("error closing memory profile: ", err)
+ }
+ }()
+ runtime.GC() // get up-to-date statistics
+ if err := pprof.WriteHeapProfile(f); err != nil {
+ log.Fatalf("Unable to write heap profile: %v", err)
+ }
+ }
+ os.Exit(0)
+ }()
+
+ for {
+ // Forward all connections.
+ inConn, err := listener.Accept()
+ if err != nil {
+ // This should not happen; we are listening
+ // successfully. Exhausted all available FDs?
+ log.Fatalf("accept error: %v", err)
+ }
+ log.Printf("incoming connection established.")
+
+ // Copy both ways.
+ go io.Copy(inConn, next)
+ go io.Copy(next, inConn)
+
+ // Print stats every second (one reporting goroutine is started per accepted connection).
+ go func() {
+ t := time.NewTicker(time.Second)
+ defer t.Stop()
+ for {
+ <-t.C
+ in.printStats()
+ out.printStats()
+ }
+ }()
+
+ for {
+ // Dial again.
+ next, err = out.dial(*forward)
+ if err == nil {
+ break
+ }
+ }
+ }
+}
diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD
new file mode 100644
index 000000000..ccb86af5b
--- /dev/null
+++ b/benchmarks/workloads/BUILD
@@ -0,0 +1,35 @@
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "workloads",
+ srcs = ["__init__.py"],
+)
+
+filegroup(
+ name = "files",
+ srcs = [
+ "//benchmarks/workloads/ab:tar",
+ "//benchmarks/workloads/absl:tar",
+ "//benchmarks/workloads/curl:tar",
+ "//benchmarks/workloads/ffmpeg:tar",
+ "//benchmarks/workloads/fio:tar",
+ "//benchmarks/workloads/httpd:tar",
+ "//benchmarks/workloads/iperf:tar",
+ "//benchmarks/workloads/netcat:tar",
+ "//benchmarks/workloads/nginx:tar",
+ "//benchmarks/workloads/node:tar",
+ "//benchmarks/workloads/node_template:tar",
+ "//benchmarks/workloads/redis:tar",
+ "//benchmarks/workloads/redisbenchmark:tar",
+ "//benchmarks/workloads/ruby:tar",
+ "//benchmarks/workloads/ruby_template:tar",
+ "//benchmarks/workloads/sleep:tar",
+ "//benchmarks/workloads/sysbench:tar",
+ "//benchmarks/workloads/syscall:tar",
+ "//benchmarks/workloads/tensorflow:tar",
+ "//benchmarks/workloads/true:tar",
+ ],
+)
diff --git a/benchmarks/workloads/__init__.py b/benchmarks/workloads/__init__.py
new file mode 100644
index 000000000..e12651e76
--- /dev/null
+++ b/benchmarks/workloads/__init__.py
@@ -0,0 +1,14 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Workloads, parsers and test data."""
diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD
new file mode 100644
index 000000000..945ac7026
--- /dev/null
+++ b/benchmarks/workloads/ab/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "ab",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "ab_test",
+ srcs = ["ab_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":ab",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/ab/Dockerfile b/benchmarks/workloads/ab/Dockerfile
new file mode 100644
index 000000000..0d0b6e2eb
--- /dev/null
+++ b/benchmarks/workloads/ab/Dockerfile
@@ -0,0 +1,15 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ apache2-utils \
+ && rm -rf /var/lib/apt/lists/*
+
+# Parameterized workload.
+ENV requests 5000
+ENV connections 10
+ENV host localhost
+ENV port 8080
+ENV path notfound
+CMD ["sh", "-c", "ab -n ${requests} -c ${connections} http://${host}:${port}/${path}"]
diff --git a/benchmarks/workloads/ab/__init__.py b/benchmarks/workloads/ab/__init__.py
new file mode 100644
index 000000000..eedf8e083
--- /dev/null
+++ b/benchmarks/workloads/ab/__init__.py
@@ -0,0 +1,88 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Apachebench tool."""
+
+import re
+
+SAMPLE_DATA = """This is ApacheBench, Version 2.3 <$Revision: 1826891 $>
+Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/
+Licensed to The Apache Software Foundation, http://www.apache.org/
+
+Benchmarking 10.10.10.10 (be patient).....done
+
+
+Server Software: Apache/2.4.38
+Server Hostname: 10.10.10.10
+Server Port: 80
+
+Document Path: /latin10k.txt
+Document Length: 210 bytes
+
+Concurrency Level: 1
+Time taken for tests: 0.180 seconds
+Complete requests: 100
+Failed requests: 0
+Non-2xx responses: 100
+Total transferred: 38800 bytes
+HTML transferred: 21000 bytes
+Requests per second: 556.44 [#/sec] (mean)
+Time per request: 1.797 [ms] (mean)
+Time per request: 1.797 [ms] (mean, across all concurrent requests)
+Transfer rate: 210.84 [Kbytes/sec] received
+
+Connection Times (ms)
+ min mean[+/-sd] median max
+Connect: 0 0 0.2 0 2
+Processing: 1 2 1.0 1 8
+Waiting: 1 1 1.0 1 7
+Total: 1 2 1.2 1 10
+
+Percentage of the requests served within a certain time (ms)
+ 50% 1
+ 66% 2
+ 75% 2
+ 80% 2
+ 90% 2
+ 95% 3
+ 98% 7
+ 99% 10
+ 100% 10 (longest request)"""
+
+
+# pylint: disable=unused-argument
+def sample(**kwargs) -> str:
+ return SAMPLE_DATA
+
+
+# pylint: disable=unused-argument
+def transfer_rate(data: str, **kwargs) -> float:
+ """Mean transfer rate in Kbytes/sec."""
+ regex = r"Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received"
+ return float(re.compile(regex).search(data).group(1))
+
+
+# pylint: disable=unused-argument
+def latency(data: str, **kwargs) -> float:
+ """Mean latency in milliseconds."""
+ regex = r"Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s"
+ res = re.compile(regex).search(data)
+ return float(res.group(1))
+
+
+# pylint: disable=unused-argument
+def requests_per_second(data: str, **kwargs) -> float:
+ """Requests per second."""
+ regex = r"Requests per second:\s+(\d+\.?\d+?)\s+"
+ res = re.compile(regex).search(data)
+ return float(res.group(1))
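+
+
+# For reference, parsing the bundled sample yields transfer_rate() == 210.84,
+# latency() == 2.0 and requests_per_second() == 556.44 (see ab_test.py).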
diff --git a/benchmarks/workloads/ab/ab_test.py b/benchmarks/workloads/ab/ab_test.py
new file mode 100644
index 000000000..4afac2996
--- /dev/null
+++ b/benchmarks/workloads/ab/ab_test.py
@@ -0,0 +1,42 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Parser test."""
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import ab
+
+
+def test_transfer_rate_parser():
+ """Test transfer rate parser."""
+ res = ab.transfer_rate(ab.sample())
+ assert res == 210.84
+
+
+def test_latency_parser():
+ """Test latency parser."""
+ res = ab.latency(ab.sample())
+ assert res == 2
+
+
+def test_requests_per_second():
+ """Test requests per second parser."""
+ res = ab.requests_per_second(ab.sample())
+ assert res == 556.44
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD
new file mode 100644
index 000000000..bb1a308bf
--- /dev/null
+++ b/benchmarks/workloads/absl/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "absl",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "absl_test",
+ srcs = ["absl_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":absl",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/absl/Dockerfile b/benchmarks/workloads/absl/Dockerfile
new file mode 100644
index 000000000..f29cfa156
--- /dev/null
+++ b/benchmarks/workloads/absl/Dockerfile
@@ -0,0 +1,25 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ wget \
+ git \
+ pkg-config \
+ zip \
+ g++ \
+ zlib1g-dev \
+ unzip \
+ python3 \
+ && rm -rf /var/lib/apt/lists/*
+RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh
+RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh
+RUN ./bazel-0.27.0-installer-linux-x86_64.sh
+
+RUN mkdir abseil-cpp && cd abseil-cpp \
+ && git init && git remote add origin https://github.com/abseil/abseil-cpp.git \
+ && git fetch --depth 1 origin 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f && git checkout FETCH_HEAD
+WORKDIR abseil-cpp
+RUN bazel clean
+ENV path "absl/base/..."
+CMD bazel build ${path} 2>&1
diff --git a/benchmarks/workloads/absl/__init__.py b/benchmarks/workloads/absl/__init__.py
new file mode 100644
index 000000000..b40e3f915
--- /dev/null
+++ b/benchmarks/workloads/absl/__init__.py
@@ -0,0 +1,63 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ABSL build benchmark."""
+
+import re
+
+SAMPLE_BAZEL_OUTPUT = """Extracting Bazel installation...
+Starting local Bazel server and connecting to it...
+Loading:
+Loading: 0 packages loaded
+Loading: 0 packages loaded
+ currently loading: absl/algorithm ... (11 packages)
+Analyzing: 241 targets (16 packages loaded, 0 targets configured)
+Analyzing: 241 targets (21 packages loaded, 617 targets configured)
+Analyzing: 241 targets (27 packages loaded, 687 targets configured)
+Analyzing: 241 targets (32 packages loaded, 1105 targets configured)
+Analyzing: 241 targets (32 packages loaded, 1294 targets configured)
+Analyzing: 241 targets (35 packages loaded, 1575 targets configured)
+Analyzing: 241 targets (35 packages loaded, 1575 targets configured)
+Analyzing: 241 targets (36 packages loaded, 1603 targets configured)
+Analyzing: 241 targets (36 packages loaded, 1603 targets configured)
+INFO: Analyzed 241 targets (37 packages loaded, 1864 targets configured).
+INFO: Found 241 targets...
+[0 / 5] [Prepa] BazelWorkspaceStatusAction stable-status.txt
+[16 / 50] [Analy] Compiling absl/base/dynamic_annotations.cc ... (20 actions, 10 running)
+[60 / 77] Compiling external/com_google_googletest/googletest/src/gtest.cc; 5s processwrapper-sandbox ... (12 actions, 11 running)
+[158 / 174] Compiling absl/container/internal/raw_hash_set_test.cc; 2s processwrapper-sandbox ... (12 actions, 11 running)
+[278 / 302] Compiling absl/container/internal/raw_hash_set_test.cc; 6s processwrapper-sandbox ... (12 actions, 11 running)
+[384 / 406] Compiling absl/container/internal/raw_hash_set_test.cc; 10s processwrapper-sandbox ... (12 actions, 11 running)
+[581 / 604] Compiling absl/container/flat_hash_set_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running)
+[722 / 745] Compiling absl/container/node_hash_set_test.cc; 9s processwrapper-sandbox ... (12 actions, 11 running)
+[846 / 867] Compiling absl/hash/hash_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running)
+INFO: From Compiling absl/debugging/symbolize_test.cc:
+/tmp/cclCVipU.s: Assembler messages:
+/tmp/cclCVipU.s:1662: Warning: ignoring changed section attributes for .text
+[999 / 1,022] Compiling absl/hash/hash_test.cc; 19s processwrapper-sandbox ... (12 actions, 11 running)
+[1,082 / 1,084] Compiling absl/container/flat_hash_map_test.cc; 7s processwrapper-sandbox
+INFO: Elapsed time: 81.861s, Critical Path: 23.81s
+INFO: 515 processes: 515 processwrapper-sandbox.
+INFO: Build completed successfully, 1084 total actions
+INFO: Build completed successfully, 1084 total actions"""
+
+
+def sample():
+ return SAMPLE_BAZEL_OUTPUT
+
+
+# pylint: disable=unused-argument
+def elapsed_time(data: str, **kwargs) -> float:
+ """Returns the elapsed time for running an absl build."""
+ return float(re.compile(r"Elapsed time: (\d*\.?\d*)s").search(data).group(1))
diff --git a/benchmarks/workloads/absl/absl_test.py b/benchmarks/workloads/absl/absl_test.py
new file mode 100644
index 000000000..41f216999
--- /dev/null
+++ b/benchmarks/workloads/absl/absl_test.py
@@ -0,0 +1,31 @@
+# python3
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ABSL build test."""
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import absl
+
+
+def test_elapsed_time():
+ """Test elapsed_time."""
+ res = absl.elapsed_time(absl.sample())
+ assert res == 81.861
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD
new file mode 100644
index 000000000..a70873065
--- /dev/null
+++ b/benchmarks/workloads/curl/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/curl/Dockerfile b/benchmarks/workloads/curl/Dockerfile
new file mode 100644
index 000000000..336cb088a
--- /dev/null
+++ b/benchmarks/workloads/curl/Dockerfile
@@ -0,0 +1,14 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ curl \
+ && rm -rf /var/lib/apt/lists/*
+
+# Accept a host and port parameter.
+ENV host localhost
+ENV port 8080
+
+# Spin until we make a successful request.
+CMD ["sh", "-c", "while ! curl -v -i http://$host:$port; do true; done"]
diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD
new file mode 100644
index 000000000..7c41ba631
--- /dev/null
+++ b/benchmarks/workloads/ffmpeg/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "ffmpeg",
+ srcs = ["__init__.py"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/ffmpeg/Dockerfile b/benchmarks/workloads/ffmpeg/Dockerfile
new file mode 100644
index 000000000..f2f530d7c
--- /dev/null
+++ b/benchmarks/workloads/ffmpeg/Dockerfile
@@ -0,0 +1,10 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ ffmpeg \
+ && rm -rf /var/lib/apt/lists/*
+WORKDIR /media
+ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4
+CMD ["ffmpeg", "-i", "video.mp4", "-c:v", "libx264", "-preset", "veryslow", "output.mp4"]
diff --git a/benchmarks/workloads/ffmpeg/__init__.py b/benchmarks/workloads/ffmpeg/__init__.py
new file mode 100644
index 000000000..7578a443b
--- /dev/null
+++ b/benchmarks/workloads/ffmpeg/__init__.py
@@ -0,0 +1,20 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Simple ffmpeg workload."""
+
+
+# pylint: disable=unused-argument
+def run_time(value, **kwargs):
+ """Returns the startup and runtime of the ffmpeg workload in seconds."""
+ return value
diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD
new file mode 100644
index 000000000..24d909c53
--- /dev/null
+++ b/benchmarks/workloads/fio/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "fio",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "fio_test",
+ srcs = ["fio_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":fio",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/fio/Dockerfile b/benchmarks/workloads/fio/Dockerfile
new file mode 100644
index 000000000..b3cf864eb
--- /dev/null
+++ b/benchmarks/workloads/fio/Dockerfile
@@ -0,0 +1,23 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ fio \
+ && rm -rf /var/lib/apt/lists/*
+
+# Parameterized test.
+ENV test write
+ENV ioengine sync
+ENV size 5000000
+ENV iodepth 4
+ENV blocksize "1m"
+ENV time ""
+ENV path "/disk/file.dat"
+ENV ramp_time 0
+
+CMD ["sh", "-c", "fio --output-format=json --name=test --ramp_time=${ramp_time} --ioengine=${ioengine} --size=${size} \
+--filename=${path} --iodepth=${iodepth} --bs=${blocksize} --rw=${test} ${time}"]
+
+
+
diff --git a/benchmarks/workloads/fio/__init__.py b/benchmarks/workloads/fio/__init__.py
new file mode 100644
index 000000000..52711e956
--- /dev/null
+++ b/benchmarks/workloads/fio/__init__.py
@@ -0,0 +1,369 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""FIO benchmark tool."""
+
+import json
+
+SAMPLE_DATA = """
+{
+ "fio version" : "fio-3.1",
+ "timestamp" : 1554837456,
+ "timestamp_ms" : 1554837456621,
+ "time" : "Tue Apr 9 19:17:36 2019",
+ "jobs" : [
+ {
+ "jobname" : "test",
+ "groupid" : 0,
+ "error" : 0,
+ "eta" : 2147483647,
+ "elapsed" : 1,
+ "job options" : {
+ "name" : "test",
+ "ioengine" : "sync",
+ "size" : "1073741824",
+ "filename" : "/disk/file.dat",
+ "iodepth" : "4",
+ "bs" : "4096",
+ "rw" : "write"
+ },
+ "read" : {
+ "io_bytes" : 0,
+ "io_kbytes" : 0,
+ "bw" : 0,
+ "iops" : 0.000000,
+ "runtime" : 0,
+ "total_ios" : 0,
+ "short_ios" : 0,
+ "drop_ios" : 0,
+ "slat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000
+ },
+ "clat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000,
+ "percentile" : {
+ "1.000000" : 0,
+ "5.000000" : 0,
+ "10.000000" : 0,
+ "20.000000" : 0,
+ "30.000000" : 0,
+ "40.000000" : 0,
+ "50.000000" : 0,
+ "60.000000" : 0,
+ "70.000000" : 0,
+ "80.000000" : 0,
+ "90.000000" : 0,
+ "95.000000" : 0,
+ "99.000000" : 0,
+ "99.500000" : 0,
+ "99.900000" : 0,
+ "99.950000" : 0,
+ "99.990000" : 0,
+ "0.00" : 0,
+ "0.00" : 0,
+ "0.00" : 0
+ }
+ },
+ "lat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000
+ },
+ "bw_min" : 0,
+ "bw_max" : 0,
+ "bw_agg" : 0.000000,
+ "bw_mean" : 0.000000,
+ "bw_dev" : 0.000000,
+ "bw_samples" : 0,
+ "iops_min" : 0,
+ "iops_max" : 0,
+ "iops_mean" : 0.000000,
+ "iops_stddev" : 0.000000,
+ "iops_samples" : 0
+ },
+ "write" : {
+ "io_bytes" : 1073741824,
+ "io_kbytes" : 1048576,
+ "bw" : 1753471,
+ "iops" : 438367.892977,
+ "runtime" : 598,
+ "total_ios" : 262144,
+ "short_ios" : 0,
+ "drop_ios" : 0,
+ "slat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000
+ },
+ "clat_ns" : {
+ "min" : 1693,
+ "max" : 754733,
+ "mean" : 2076.404373,
+ "stddev" : 1724.195529,
+ "percentile" : {
+ "1.000000" : 1736,
+ "5.000000" : 1752,
+ "10.000000" : 1768,
+ "20.000000" : 1784,
+ "30.000000" : 1800,
+ "40.000000" : 1800,
+ "50.000000" : 1816,
+ "60.000000" : 1816,
+ "70.000000" : 1848,
+ "80.000000" : 1928,
+ "90.000000" : 2512,
+ "95.000000" : 2992,
+ "99.000000" : 6176,
+ "99.500000" : 6304,
+ "99.900000" : 11328,
+ "99.950000" : 15168,
+ "99.990000" : 17792,
+ "0.00" : 0,
+ "0.00" : 0,
+ "0.00" : 0
+ }
+ },
+ "lat_ns" : {
+ "min" : 1731,
+ "max" : 754770,
+ "mean" : 2117.878979,
+ "stddev" : 1730.290512
+ },
+ "bw_min" : 1731120,
+ "bw_max" : 1731120,
+ "bw_agg" : 98.725328,
+ "bw_mean" : 1731120.000000,
+ "bw_dev" : 0.000000,
+ "bw_samples" : 1,
+ "iops_min" : 432780,
+ "iops_max" : 432780,
+ "iops_mean" : 432780.000000,
+ "iops_stddev" : 0.000000,
+ "iops_samples" : 1
+ },
+ "trim" : {
+ "io_bytes" : 0,
+ "io_kbytes" : 0,
+ "bw" : 0,
+ "iops" : 0.000000,
+ "runtime" : 0,
+ "total_ios" : 0,
+ "short_ios" : 0,
+ "drop_ios" : 0,
+ "slat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000
+ },
+ "clat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000,
+ "percentile" : {
+ "1.000000" : 0,
+ "5.000000" : 0,
+ "10.000000" : 0,
+ "20.000000" : 0,
+ "30.000000" : 0,
+ "40.000000" : 0,
+ "50.000000" : 0,
+ "60.000000" : 0,
+ "70.000000" : 0,
+ "80.000000" : 0,
+ "90.000000" : 0,
+ "95.000000" : 0,
+ "99.000000" : 0,
+ "99.500000" : 0,
+ "99.900000" : 0,
+ "99.950000" : 0,
+ "99.990000" : 0,
+ "0.00" : 0,
+ "0.00" : 0,
+ "0.00" : 0
+ }
+ },
+ "lat_ns" : {
+ "min" : 0,
+ "max" : 0,
+ "mean" : 0.000000,
+ "stddev" : 0.000000
+ },
+ "bw_min" : 0,
+ "bw_max" : 0,
+ "bw_agg" : 0.000000,
+ "bw_mean" : 0.000000,
+ "bw_dev" : 0.000000,
+ "bw_samples" : 0,
+ "iops_min" : 0,
+ "iops_max" : 0,
+ "iops_mean" : 0.000000,
+ "iops_stddev" : 0.000000,
+ "iops_samples" : 0
+ },
+ "usr_cpu" : 17.922948,
+ "sys_cpu" : 81.574539,
+ "ctx" : 3,
+ "majf" : 0,
+ "minf" : 10,
+ "iodepth_level" : {
+ "1" : 100.000000,
+ "2" : 0.000000,
+ "4" : 0.000000,
+ "8" : 0.000000,
+ "16" : 0.000000,
+ "32" : 0.000000,
+ ">=64" : 0.000000
+ },
+ "latency_ns" : {
+ "2" : 0.000000,
+ "4" : 0.000000,
+ "10" : 0.000000,
+ "20" : 0.000000,
+ "50" : 0.000000,
+ "100" : 0.000000,
+ "250" : 0.000000,
+ "500" : 0.000000,
+ "750" : 0.000000,
+ "1000" : 0.000000
+ },
+ "latency_us" : {
+ "2" : 82.737350,
+ "4" : 12.605286,
+ "10" : 4.543686,
+ "20" : 0.107956,
+ "50" : 0.010000,
+ "100" : 0.000000,
+ "250" : 0.000000,
+ "500" : 0.000000,
+ "750" : 0.000000,
+ "1000" : 0.010000
+ },
+ "latency_ms" : {
+ "2" : 0.000000,
+ "4" : 0.000000,
+ "10" : 0.000000,
+ "20" : 0.000000,
+ "50" : 0.000000,
+ "100" : 0.000000,
+ "250" : 0.000000,
+ "500" : 0.000000,
+ "750" : 0.000000,
+ "1000" : 0.000000,
+ "2000" : 0.000000,
+ ">=2000" : 0.000000
+ },
+ "latency_depth" : 4,
+ "latency_target" : 0,
+ "latency_percentile" : 100.000000,
+ "latency_window" : 0
+ }
+ ],
+ "disk_util" : [
+ {
+ "name" : "dm-1",
+ "read_ios" : 0,
+ "write_ios" : 3,
+ "read_merges" : 0,
+ "write_merges" : 0,
+ "read_ticks" : 0,
+ "write_ticks" : 0,
+ "in_queue" : 0,
+ "util" : 0.000000,
+ "aggr_read_ios" : 0,
+ "aggr_write_ios" : 3,
+ "aggr_read_merges" : 0,
+ "aggr_write_merge" : 0,
+ "aggr_read_ticks" : 0,
+ "aggr_write_ticks" : 0,
+ "aggr_in_queue" : 0,
+ "aggr_util" : 0.000000
+ },
+ {
+ "name" : "dm-0",
+ "read_ios" : 0,
+ "write_ios" : 3,
+ "read_merges" : 0,
+ "write_merges" : 0,
+ "read_ticks" : 0,
+ "write_ticks" : 0,
+ "in_queue" : 0,
+ "util" : 0.000000,
+ "aggr_read_ios" : 0,
+ "aggr_write_ios" : 3,
+ "aggr_read_merges" : 0,
+ "aggr_write_merge" : 0,
+ "aggr_read_ticks" : 0,
+ "aggr_write_ticks" : 2,
+ "aggr_in_queue" : 0,
+ "aggr_util" : 0.000000
+ },
+ {
+ "name" : "nvme0n1",
+ "read_ios" : 0,
+ "write_ios" : 3,
+ "read_merges" : 0,
+ "write_merges" : 0,
+ "read_ticks" : 0,
+ "write_ticks" : 2,
+ "in_queue" : 0,
+ "util" : 0.000000
+ }
+ ]
+}
+"""
+
+
+# pylint: disable=unused-argument
+def sample(**kwargs) -> str:
+ return SAMPLE_DATA
+
+
+# pylint: disable=unused-argument
+def read_bandwidth(data: str, **kwargs) -> int:
+ """File I/O bandwidth."""
+ return json.loads(data)["jobs"][0]["read"]["bw"] * 1024
+
+
+# pylint: disable=unused-argument
+def write_bandwidth(data: str, **kwargs) -> int:
+ """File I/O bandwidth."""
+ return json.loads(data)["jobs"][0]["write"]["bw"] * 1024
+
+
+# pylint: disable=unused-argument
+def read_io_ops(data: str, **kwargs) -> float:
+ """File I/O operations per second."""
+ return float(json.loads(data)["jobs"][0]["read"]["iops"])
+
+
+# pylint: disable=unused-argument
+def write_io_ops(data: str, **kwargs) -> float:
+ """File I/O operations per second."""
+ return float(json.loads(data)["jobs"][0]["write"]["iops"])
+
+
+# Change function names so we just print "bandwidth" and "io_ops".
+read_bandwidth.__name__ = "bandwidth"
+write_bandwidth.__name__ = "bandwidth"
+read_io_ops.__name__ = "io_ops"
+write_io_ops.__name__ = "io_ops"
diff --git a/benchmarks/workloads/fio/fio_test.py b/benchmarks/workloads/fio/fio_test.py
new file mode 100644
index 000000000..04a6eeb7e
--- /dev/null
+++ b/benchmarks/workloads/fio/fio_test.py
@@ -0,0 +1,44 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Parser tests."""
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import fio
+
+
+def test_read_io_ops():
+ """Test read ops parser."""
+ assert fio.read_io_ops(fio.sample()) == 0.0
+
+
+def test_write_io_ops():
+ """Test write ops parser."""
+ assert fio.write_io_ops(fio.sample()) == 438367.892977
+
+
+def test_read_bandwidth():
+ """Test read bandwidth parser."""
+ assert fio.read_bandwidth(fio.sample()) == 0.0
+
+
+def test_write_bandwidth():
+ """Test write bandwidth parser."""
+ assert fio.write_bandwidth(fio.sample()) == 1753471 * 1024
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD
new file mode 100644
index 000000000..83450d190
--- /dev/null
+++ b/benchmarks/workloads/httpd/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "apache2-tmpdir.conf",
+ ],
+)
diff --git a/benchmarks/workloads/httpd/Dockerfile b/benchmarks/workloads/httpd/Dockerfile
new file mode 100644
index 000000000..52a550678
--- /dev/null
+++ b/benchmarks/workloads/httpd/Dockerfile
@@ -0,0 +1,27 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ apache2 \
+ && rm -rf /var/lib/apt/lists/*
+
+# Generate a bunch of relevant files.
+RUN mkdir -p /local && \
+ for size in 1 10 100 1000 1024 10240; do \
+ dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \
+ done
+
+# Rewrite DocumentRoot to point to /tmp/html instead of the default path.
+RUN sed -i 's/DocumentRoot.*\/var\/www\/html$/DocumentRoot \/tmp\/html/' /etc/apache2/sites-enabled/000-default.conf
+COPY ./apache2-tmpdir.conf /etc/apache2/sites-enabled/apache2-tmpdir.conf
+
+# Standard settings.
+ENV APACHE_RUN_DIR /tmp
+ENV APACHE_RUN_USER nobody
+ENV APACHE_RUN_GROUP nogroup
+ENV APACHE_LOG_DIR /tmp
+ENV APACHE_PID_FILE /tmp/apache.pid
+
+# Copy on start-up; serve everything from /tmp (including the configuration).
+CMD ["sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && apache2 -X"]
diff --git a/benchmarks/workloads/httpd/apache2-tmpdir.conf b/benchmarks/workloads/httpd/apache2-tmpdir.conf
new file mode 100644
index 000000000..e33f8d9bb
--- /dev/null
+++ b/benchmarks/workloads/httpd/apache2-tmpdir.conf
@@ -0,0 +1,5 @@
+<Directory /tmp/html/>
+ Options Indexes FollowSymLinks
+ AllowOverride None
+ Require all granted
+</Directory> \ No newline at end of file
diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD
new file mode 100644
index 000000000..91b953718
--- /dev/null
+++ b/benchmarks/workloads/iperf/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "iperf",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "iperf_test",
+ srcs = ["iperf_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":iperf",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/iperf/Dockerfile b/benchmarks/workloads/iperf/Dockerfile
new file mode 100644
index 000000000..9704c506c
--- /dev/null
+++ b/benchmarks/workloads/iperf/Dockerfile
@@ -0,0 +1,14 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ iperf \
+ && rm -rf /var/lib/apt/lists/*
+
+# Accept a host parameter.
+ENV host ""
+ENV port 5001
+
+# Start as client if the host is provided.
+CMD ["sh", "-c", "test -z \"${host}\" && iperf -s || iperf -f K --realtime -c ${host} -p ${port}"]
diff --git a/benchmarks/workloads/iperf/__init__.py b/benchmarks/workloads/iperf/__init__.py
new file mode 100644
index 000000000..3817a7ade
--- /dev/null
+++ b/benchmarks/workloads/iperf/__init__.py
@@ -0,0 +1,40 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""iperf."""
+
+import re
+
+SAMPLE_DATA = """
+------------------------------------------------------------
+Client connecting to 10.138.15.215, TCP port 32779
+TCP window size: 45.0 KByte (default)
+------------------------------------------------------------
+[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779
+[ ID] Interval Transfer Bandwidth
+[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec
+
+"""
+
+
+# pylint: disable=unused-argument
+def sample(**kwargs) -> str:
+ return SAMPLE_DATA
+
+
+# pylint: disable=unused-argument
+def bandwidth(data: str, **kwargs) -> float:
+ """Calculate the bandwidth."""
+ regex = r"\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec"
+ res = re.compile(regex).search(data)
+ return float(res.group(1)) * 1000
diff --git a/benchmarks/workloads/iperf/iperf_test.py b/benchmarks/workloads/iperf/iperf_test.py
new file mode 100644
index 000000000..6959b7e8a
--- /dev/null
+++ b/benchmarks/workloads/iperf/iperf_test.py
@@ -0,0 +1,28 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for iperf."""
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import iperf
+
+
+def test_bandwidth():
+ assert iperf.bandwidth(iperf.sample()) == 45900 * 1000
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD
new file mode 100644
index 000000000..a70873065
--- /dev/null
+++ b/benchmarks/workloads/netcat/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/netcat/Dockerfile b/benchmarks/workloads/netcat/Dockerfile
new file mode 100644
index 000000000..d8548d89a
--- /dev/null
+++ b/benchmarks/workloads/netcat/Dockerfile
@@ -0,0 +1,14 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ netcat \
+ && rm -rf /var/lib/apt/lists/*
+
+# Accept a host and port parameter.
+ENV host localhost
+ENV port 8080
+
+# Spin until we make a successful request.
+CMD ["sh", "-c", "while ! nc -zv $host $port; do true; done"]
diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD
new file mode 100644
index 000000000..a70873065
--- /dev/null
+++ b/benchmarks/workloads/nginx/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/nginx/Dockerfile b/benchmarks/workloads/nginx/Dockerfile
new file mode 100644
index 000000000..b64eb52ae
--- /dev/null
+++ b/benchmarks/workloads/nginx/Dockerfile
@@ -0,0 +1 @@
+FROM nginx:1.15.10
diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD
new file mode 100644
index 000000000..bfcf78cf9
--- /dev/null
+++ b/benchmarks/workloads/node/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "index.js",
+ "package.json",
+ ],
+)
diff --git a/benchmarks/workloads/node/Dockerfile b/benchmarks/workloads/node/Dockerfile
new file mode 100644
index 000000000..139a38bf5
--- /dev/null
+++ b/benchmarks/workloads/node/Dockerfile
@@ -0,0 +1,2 @@
+FROM node:onbuild
+CMD ["node", "index.js"]
diff --git a/benchmarks/workloads/node/index.js b/benchmarks/workloads/node/index.js
new file mode 100644
index 000000000..584158462
--- /dev/null
+++ b/benchmarks/workloads/node/index.js
@@ -0,0 +1,28 @@
+'use strict';
+
+var start = new Date().getTime();
+
+// Load dependencies to simulate an average nodejs app.
+var req_0 = require('async');
+var req_1 = require('bluebird');
+var req_2 = require('firebase');
+var req_3 = require('firebase-admin');
+var req_4 = require('@google-cloud/container');
+var req_5 = require('@google-cloud/logging');
+var req_6 = require('@google-cloud/monitoring');
+var req_7 = require('@google-cloud/spanner');
+var req_8 = require('lodash');
+var req_9 = require('mailgun-js');
+var req_10 = require('request');
+var express = require('express');
+var app = express();
+
+var loaded = new Date().getTime() - start;
+app.get('/', function(req, res) {
+ res.send('Hello World!<br>Loaded in ' + loaded + 'ms');
+});
+
+console.log('Loaded in ' + loaded + ' ms');
+app.listen(8080, function() {
+ console.log('Listening on port 8080...');
+});
diff --git a/benchmarks/workloads/node/package.json b/benchmarks/workloads/node/package.json
new file mode 100644
index 000000000..c00b9b3cb
--- /dev/null
+++ b/benchmarks/workloads/node/package.json
@@ -0,0 +1,19 @@
+{
+ "name": "node",
+ "version": "1.0.0",
+ "main": "index.js",
+ "dependencies": {
+ "@google-cloud/container": "^0.3.0",
+ "@google-cloud/logging": "^4.2.0",
+ "@google-cloud/monitoring": "^0.6.0",
+ "@google-cloud/spanner": "^2.2.1",
+ "async": "^2.6.1",
+ "bluebird": "^3.5.3",
+ "express": "^4.16.4",
+ "firebase": "^5.7.2",
+ "firebase-admin": "^6.4.0",
+ "lodash": "^4.17.11",
+ "mailgun-js": "^0.22.0",
+ "request": "^2.88.0"
+ }
+}
diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD
new file mode 100644
index 000000000..e142f082a
--- /dev/null
+++ b/benchmarks/workloads/node_template/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "index.hbs",
+ "index.js",
+ "package.json",
+ "package-lock.json",
+ ],
+)
diff --git a/benchmarks/workloads/node_template/Dockerfile b/benchmarks/workloads/node_template/Dockerfile
new file mode 100644
index 000000000..7eb065d54
--- /dev/null
+++ b/benchmarks/workloads/node_template/Dockerfile
@@ -0,0 +1,5 @@
+FROM node:onbuild
+
+ENV host "127.0.0.1"
+
+CMD ["sh", "-c", "node index.js ${host}"]
diff --git a/benchmarks/workloads/node_template/index.hbs b/benchmarks/workloads/node_template/index.hbs
new file mode 100644
index 000000000..03feceb75
--- /dev/null
+++ b/benchmarks/workloads/node_template/index.hbs
@@ -0,0 +1,8 @@
+<!DOCTYPE html>
+<html>
+<body>
+ {{#each text}}
+ <p>{{this}}</p>
+ {{/each}}
+</body>
+</html>
diff --git a/benchmarks/workloads/node_template/index.js b/benchmarks/workloads/node_template/index.js
new file mode 100644
index 000000000..04a27f356
--- /dev/null
+++ b/benchmarks/workloads/node_template/index.js
@@ -0,0 +1,43 @@
+const app = require('express')();
+const path = require('path');
+const redis = require('redis');
+const srs = require('secure-random-string');
+
+// The hostname is the first argument.
+const host_name = process.argv[2];
+
+var client = redis.createClient({host: host_name, detect_buffers: true});
+
+app.set('views', __dirname);
+app.set('view engine', 'hbs');
+
+app.get('/', (req, res) => {
+ var tmp = [];
+ /* Pull four random keys from the redis server. */
+ for (let i = 0; i < 4; i++) {
+ client.get(Math.floor(Math.random() * (100)), function(err, reply) {
+ tmp.push(reply.toString());
+ });
+ }
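+ // Note: client.get is asynchronous, so the render below may fire before
+ // all replies have arrived; this workload exercises load rather than
+ // page correctness.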
+
+ res.render('index', {text: tmp});
+});
+
+/**
+ * Securely generate a random string.
+ * @param {number} len
+ * @return {string}
+ */
+function randomBody(len) {
+ return srs({alphanumeric: true, length: len});
+}
+
+/** Mutates one hundred keys randomly. */
+function generateText() {
+ for (let i = 0; i < 100; i++) {
+ client.set(i, randomBody(1024));
+ }
+}
+
+generateText();
+app.listen(8080);
diff --git a/benchmarks/workloads/node_template/package-lock.json b/benchmarks/workloads/node_template/package-lock.json
new file mode 100644
index 000000000..580e68aa5
--- /dev/null
+++ b/benchmarks/workloads/node_template/package-lock.json
@@ -0,0 +1,486 @@
+{
+ "name": "nodedum",
+ "version": "1.0.0",
+ "lockfileVersion": 1,
+ "requires": true,
+ "dependencies": {
+ "accepts": {
+ "version": "1.3.5",
+ "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz",
+ "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=",
+ "requires": {
+ "mime-types": "~2.1.18",
+ "negotiator": "0.6.1"
+ }
+ },
+ "array-flatten": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz",
+ "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI="
+ },
+ "async": {
+ "version": "2.6.2",
+ "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz",
+ "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==",
+ "requires": {
+ "lodash": "^4.17.11"
+ }
+ },
+ "body-parser": {
+ "version": "1.18.3",
+ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz",
+ "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=",
+ "requires": {
+ "bytes": "3.0.0",
+ "content-type": "~1.0.4",
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "http-errors": "~1.6.3",
+ "iconv-lite": "0.4.23",
+ "on-finished": "~2.3.0",
+ "qs": "6.5.2",
+ "raw-body": "2.3.3",
+ "type-is": "~1.6.16"
+ }
+ },
+ "bytes": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz",
+ "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg="
+ },
+ "commander": {
+ "version": "2.20.0",
+ "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz",
+ "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==",
+ "optional": true
+ },
+ "content-disposition": {
+ "version": "0.5.2",
+ "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz",
+ "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ="
+ },
+ "content-type": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz",
+ "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA=="
+ },
+ "cookie": {
+ "version": "0.3.1",
+ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz",
+ "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s="
+ },
+ "cookie-signature": {
+ "version": "1.0.6",
+ "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
+ "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw="
+ },
+ "debug": {
+ "version": "2.6.9",
+ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz",
+ "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==",
+ "requires": {
+ "ms": "2.0.0"
+ }
+ },
+ "depd": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz",
+ "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak="
+ },
+ "destroy": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz",
+ "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA="
+ },
+ "double-ended-queue": {
+ "version": "2.1.0-0",
+ "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz",
+ "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw="
+ },
+ "ee-first": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz",
+ "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0="
+ },
+ "encodeurl": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz",
+ "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k="
+ },
+ "escape-html": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz",
+ "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg="
+ },
+ "etag": {
+ "version": "1.8.1",
+ "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz",
+ "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc="
+ },
+ "express": {
+ "version": "4.16.4",
+ "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz",
+ "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==",
+ "requires": {
+ "accepts": "~1.3.5",
+ "array-flatten": "1.1.1",
+ "body-parser": "1.18.3",
+ "content-disposition": "0.5.2",
+ "content-type": "~1.0.4",
+ "cookie": "0.3.1",
+ "cookie-signature": "1.0.6",
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "etag": "~1.8.1",
+ "finalhandler": "1.1.1",
+ "fresh": "0.5.2",
+ "merge-descriptors": "1.0.1",
+ "methods": "~1.1.2",
+ "on-finished": "~2.3.0",
+ "parseurl": "~1.3.2",
+ "path-to-regexp": "0.1.7",
+ "proxy-addr": "~2.0.4",
+ "qs": "6.5.2",
+ "range-parser": "~1.2.0",
+ "safe-buffer": "5.1.2",
+ "send": "0.16.2",
+ "serve-static": "1.13.2",
+ "setprototypeof": "1.1.0",
+ "statuses": "~1.4.0",
+ "type-is": "~1.6.16",
+ "utils-merge": "1.0.1",
+ "vary": "~1.1.2"
+ }
+ },
+ "finalhandler": {
+ "version": "1.1.1",
+ "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz",
+ "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==",
+ "requires": {
+ "debug": "2.6.9",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "on-finished": "~2.3.0",
+ "parseurl": "~1.3.2",
+ "statuses": "~1.4.0",
+ "unpipe": "~1.0.0"
+ }
+ },
+ "foreachasync": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz",
+ "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY="
+ },
+ "forwarded": {
+ "version": "0.1.2",
+ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz",
+ "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ="
+ },
+ "fresh": {
+ "version": "0.5.2",
+ "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz",
+ "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac="
+ },
+ "handlebars": {
+ "version": "4.0.14",
+ "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz",
+ "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==",
+ "requires": {
+ "async": "^2.5.0",
+ "optimist": "^0.6.1",
+ "source-map": "^0.6.1",
+ "uglify-js": "^3.1.4"
+ }
+ },
+ "hbs": {
+ "version": "4.0.4",
+ "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz",
+ "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==",
+ "requires": {
+ "handlebars": "4.0.14",
+ "walk": "2.3.9"
+ }
+ },
+ "http-errors": {
+ "version": "1.6.3",
+ "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz",
+ "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=",
+ "requires": {
+ "depd": "~1.1.2",
+ "inherits": "2.0.3",
+ "setprototypeof": "1.1.0",
+ "statuses": ">= 1.4.0 < 2"
+ }
+ },
+ "iconv-lite": {
+ "version": "0.4.23",
+ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz",
+ "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==",
+ "requires": {
+ "safer-buffer": ">= 2.1.2 < 3"
+ }
+ },
+ "inherits": {
+ "version": "2.0.3",
+ "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz",
+ "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4="
+ },
+ "ipaddr.js": {
+ "version": "1.8.0",
+ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz",
+ "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4="
+ },
+ "lodash": {
+ "version": "4.17.15",
+ "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz",
+ "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A=="
+ },
+ "media-typer": {
+ "version": "0.3.0",
+ "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz",
+ "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g="
+ },
+ "merge-descriptors": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz",
+ "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E="
+ },
+ "methods": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz",
+ "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4="
+ },
+ "mime": {
+ "version": "1.4.1",
+ "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz",
+ "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ=="
+ },
+ "mime-db": {
+ "version": "1.37.0",
+ "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz",
+ "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg=="
+ },
+ "mime-types": {
+ "version": "2.1.21",
+ "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz",
+ "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==",
+ "requires": {
+ "mime-db": "~1.37.0"
+ }
+ },
+ "minimist": {
+ "version": "0.0.10",
+ "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz",
+ "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8="
+ },
+ "ms": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz",
+ "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g="
+ },
+ "negotiator": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz",
+ "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk="
+ },
+ "on-finished": {
+ "version": "2.3.0",
+ "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz",
+ "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=",
+ "requires": {
+ "ee-first": "1.1.1"
+ }
+ },
+ "optimist": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz",
+ "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=",
+ "requires": {
+ "minimist": "~0.0.1",
+ "wordwrap": "~0.0.2"
+ }
+ },
+ "parseurl": {
+ "version": "1.3.2",
+ "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz",
+ "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M="
+ },
+ "path-to-regexp": {
+ "version": "0.1.7",
+ "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz",
+ "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w="
+ },
+ "proxy-addr": {
+ "version": "2.0.4",
+ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz",
+ "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==",
+ "requires": {
+ "forwarded": "~0.1.2",
+ "ipaddr.js": "1.8.0"
+ }
+ },
+ "qs": {
+ "version": "6.5.2",
+ "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz",
+ "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA=="
+ },
+ "range-parser": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz",
+ "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4="
+ },
+ "raw-body": {
+ "version": "2.3.3",
+ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz",
+ "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==",
+ "requires": {
+ "bytes": "3.0.0",
+ "http-errors": "1.6.3",
+ "iconv-lite": "0.4.23",
+ "unpipe": "1.0.0"
+ }
+ },
+ "redis": {
+ "version": "2.8.0",
+ "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz",
+ "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==",
+ "requires": {
+ "double-ended-queue": "^2.1.0-0",
+ "redis-commands": "^1.2.0",
+ "redis-parser": "^2.6.0"
+ },
+ "dependencies": {
+ "redis-commands": {
+ "version": "1.4.0",
+ "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz",
+ "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw=="
+ },
+ "redis-parser": {
+ "version": "2.6.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz",
+ "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs="
+ }
+ }
+ },
+ "redis-commands": {
+ "version": "1.5.0",
+ "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz",
+ "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg=="
+ },
+ "redis-parser": {
+ "version": "2.6.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz",
+ "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs="
+ },
+ "safe-buffer": {
+ "version": "5.1.2",
+ "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz",
+ "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g=="
+ },
+ "safer-buffer": {
+ "version": "2.1.2",
+ "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
+ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
+ },
+ "secure-random-string": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz",
+ "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg=="
+ },
+ "send": {
+ "version": "0.16.2",
+ "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz",
+ "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==",
+ "requires": {
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "destroy": "~1.0.4",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "etag": "~1.8.1",
+ "fresh": "0.5.2",
+ "http-errors": "~1.6.2",
+ "mime": "1.4.1",
+ "ms": "2.0.0",
+ "on-finished": "~2.3.0",
+ "range-parser": "~1.2.0",
+ "statuses": "~1.4.0"
+ }
+ },
+ "serve-static": {
+ "version": "1.13.2",
+ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz",
+ "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==",
+ "requires": {
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "parseurl": "~1.3.2",
+ "send": "0.16.2"
+ }
+ },
+ "setprototypeof": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz",
+ "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ=="
+ },
+ "source-map": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
+ "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g=="
+ },
+ "statuses": {
+ "version": "1.4.0",
+ "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz",
+ "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew=="
+ },
+ "type-is": {
+ "version": "1.6.16",
+ "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz",
+ "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==",
+ "requires": {
+ "media-typer": "0.3.0",
+ "mime-types": "~2.1.18"
+ }
+ },
+ "uglify-js": {
+ "version": "3.5.9",
+ "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz",
+ "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==",
+ "optional": true,
+ "requires": {
+ "commander": "~2.20.0",
+ "source-map": "~0.6.1"
+ }
+ },
+ "unpipe": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz",
+ "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw="
+ },
+ "utils-merge": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz",
+ "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM="
+ },
+ "vary": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz",
+ "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw="
+ },
+ "walk": {
+ "version": "2.3.9",
+ "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz",
+ "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=",
+ "requires": {
+ "foreachasync": "^3.0.0"
+ }
+ },
+ "wordwrap": {
+ "version": "0.0.3",
+ "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz",
+ "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc="
+ }
+ }
+}
diff --git a/benchmarks/workloads/node_template/package.json b/benchmarks/workloads/node_template/package.json
new file mode 100644
index 000000000..7dcadd523
--- /dev/null
+++ b/benchmarks/workloads/node_template/package.json
@@ -0,0 +1,19 @@
+{
+ "name": "nodedum",
+ "version": "1.0.0",
+ "description": "",
+ "main": "index.js",
+ "scripts": {
+ "test": "echo \"Error: no test specified\" && exit 1"
+ },
+ "author": "",
+ "license": "ISC",
+ "dependencies": {
+ "express": "^4.16.4",
+ "hbs": "^4.0.4",
+ "redis": "^2.8.0",
+ "redis-commands": "^1.2.0",
+ "redis-parser": "^2.6.0",
+ "secure-random-string": "^1.1.0"
+ }
+}
diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD
new file mode 100644
index 000000000..a70873065
--- /dev/null
+++ b/benchmarks/workloads/redis/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/redis/Dockerfile b/benchmarks/workloads/redis/Dockerfile
new file mode 100644
index 000000000..0f17249af
--- /dev/null
+++ b/benchmarks/workloads/redis/Dockerfile
@@ -0,0 +1 @@
+FROM redis:5.0.4
diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD
new file mode 100644
index 000000000..147cfedd2
--- /dev/null
+++ b/benchmarks/workloads/redisbenchmark/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "redisbenchmark",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "redisbenchmark_test",
+ srcs = ["redisbenchmark_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":redisbenchmark",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/redisbenchmark/Dockerfile b/benchmarks/workloads/redisbenchmark/Dockerfile
new file mode 100644
index 000000000..f94f6442e
--- /dev/null
+++ b/benchmarks/workloads/redisbenchmark/Dockerfile
@@ -0,0 +1,4 @@
+FROM redis:5.0.4
+ENV host localhost
+ENV port 6379
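+# host, port (and optional flags) are overridden via the environment at run
+# time; an illustrative example: docker run -e host=10.0.0.2 -e flags="-n 100000" <image>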
+CMD ["sh", "-c", "redis-benchmark --csv -h ${host} -p ${port} ${flags}"]
diff --git a/benchmarks/workloads/redisbenchmark/__init__.py b/benchmarks/workloads/redisbenchmark/__init__.py
new file mode 100644
index 000000000..229cef5fa
--- /dev/null
+++ b/benchmarks/workloads/redisbenchmark/__init__.py
@@ -0,0 +1,85 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Redis-benchmark tool."""
+
+import re
+
+OPERATIONS = [
+ "PING_INLINE",
+ "PING_BULK",
+ "SET",
+ "GET",
+ "INCR",
+ "LPUSH",
+ "RPUSH",
+ "LPOP",
+ "RPOP",
+ "SADD",
+ "HSET",
+ "SPOP",
+ "LRANGE_100",
+ "LRANGE_300",
+ "LRANGE_500",
+ "LRANGE_600",
+ "MSET",
+]
+
+METRICS = dict()
+
+SAMPLE_DATA = """
+"PING_INLINE","48661.80"
+"PING_BULK","50301.81"
+"SET","48923.68"
+"GET","49382.71"
+"INCR","49975.02"
+"LPUSH","49875.31"
+"RPUSH","50276.52"
+"LPOP","50327.12"
+"RPOP","50556.12"
+"SADD","49504.95"
+"HSET","49504.95"
+"SPOP","50025.02"
+"LPUSH (needed to benchmark LRANGE)","48875.86"
+"LRANGE_100 (first 100 elements)","33955.86"
+"LRANGE_300 (first 300 elements)","16550.81"
+"LRANGE_500 (first 450 elements)","13653.74"
+"LRANGE_600 (first 600 elements)","11219.57"
+"MSET (10 keys)","44682.75"
+"""
+
+
+# pylint: disable=unused-argument
+def sample(**kwargs) -> str:
+ return SAMPLE_DATA
+
+
+# Bind a metric for each operation noted above.
+for op in OPERATIONS:
+
+ def bind(metric):
+ """Bind op to a new scope."""
+
+ # pylint: disable=unused-argument
+ def parse(data: str, **kwargs) -> float:
+ """Operation throughput in requests/sec."""
+ regex = r"\"" + metric + r"( .*)?\",\"(\d*.\d*)"
+ res = re.compile(regex).search(data)
+ if res:
+ return float(res.group(2))
+ return 0.0
+
+ parse.__name__ = metric
+ return parse
+
+ METRICS[op] = bind(op)
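+
+
+# Illustrative usage sketch (not part of the original change): each METRICS
+# entry parses raw redis-benchmark CSV output for one operation.
+if __name__ == "__main__":
+  print(METRICS["GET"](sample()))  # 49382.71 for SAMPLE_DATA above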
diff --git a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py
new file mode 100644
index 000000000..419ced059
--- /dev/null
+++ b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py
@@ -0,0 +1,51 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Parser test."""
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import redisbenchmark
+
+RESULTS = {
+ "PING_INLINE": 48661.80,
+ "PING_BULK": 50301.81,
+ "SET": 48923.68,
+ "GET": 49382.71,
+ "INCR": 49975.02,
+ "LPUSH": 49875.31,
+ "RPUSH": 50276.52,
+ "LPOP": 50327.12,
+ "RPOP": 50556.12,
+ "SADD": 49504.95,
+ "HSET": 49504.95,
+ "SPOP": 50025.02,
+ "LRANGE_100": 33955.86,
+ "LRANGE_300": 16550.81,
+ "LRANGE_500": 13653.74,
+ "LRANGE_600": 11219.57,
+ "MSET": 44682.75
+}
+
+
+def test_metrics():
+ """Test all metrics."""
+ for (metric, func) in redisbenchmark.METRICS.items():
+ res = func(redisbenchmark.sample())
+ assert float(res) == RESULTS[metric]
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD
new file mode 100644
index 000000000..a3be4fe92
--- /dev/null
+++ b/benchmarks/workloads/ruby/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+filegroup(
+ name = "files",
+ srcs = [
+ "Dockerfile",
+ "Gemfile",
+ "Gemfile.lock",
+ "config.ru",
+ "index.rb",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "Gemfile",
+ "Gemfile.lock",
+ "config.ru",
+ "index.rb",
+ ],
+)
diff --git a/benchmarks/workloads/ruby/Dockerfile b/benchmarks/workloads/ruby/Dockerfile
new file mode 100644
index 000000000..a9a7a7086
--- /dev/null
+++ b/benchmarks/workloads/ruby/Dockerfile
@@ -0,0 +1,28 @@
+# example based on https://github.com/errm/fib
+
+FROM ruby:2.5
+
+RUN apt-get update -qq && apt-get install -y build-essential libpq-dev nodejs libsodium-dev
+
+# Set an environment variable pointing at where the app is installed inside the image
+ENV RAILS_ROOT /var/www/app_name
+RUN mkdir -p $RAILS_ROOT
+
+# Set working directory
+WORKDIR $RAILS_ROOT
+
+# Setting env up
+ENV RAILS_ENV='production'
+ENV RACK_ENV='production'
+
+# Adding gems
+COPY Gemfile Gemfile
+COPY Gemfile.lock Gemfile.lock
+RUN bundle install --jobs 20 --retry 5 --without development test
+
+# Adding project files
+COPY . .
+
+EXPOSE $PORT
+STOPSIGNAL SIGINT
+CMD ["bundle", "exec", "puma", "config.ru"]
diff --git a/benchmarks/workloads/ruby/Gemfile b/benchmarks/workloads/ruby/Gemfile
new file mode 100644
index 000000000..8f1bdad6e
--- /dev/null
+++ b/benchmarks/workloads/ruby/Gemfile
@@ -0,0 +1,12 @@
+source "https://rubygems.org"
+# load a bunch of dependencies to take up memory
+gem "sinatra"
+gem "puma"
+gem "redis"
+gem 'rake'
+gem 'squid', '~> 1.4'
+gem 'cassandra-driver'
+gem 'ruby-fann'
+gem 'rbnacl'
+gem 'bcrypt'
+gem "activemerchant" \ No newline at end of file
diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock
new file mode 100644
index 000000000..ea9f0ea85
--- /dev/null
+++ b/benchmarks/workloads/ruby/Gemfile.lock
@@ -0,0 +1,71 @@
+GEM
+ remote: https://rubygems.org/
+ specs:
+ activemerchant (1.105.0)
+ activesupport (>= 4.2)
+ builder (>= 2.1.2, < 4.0.0)
+ i18n (>= 0.6.9)
+ nokogiri (~> 1.4)
+ activesupport (5.2.3)
+ concurrent-ruby (~> 1.0, >= 1.0.2)
+ i18n (>= 0.7, < 2)
+ minitest (~> 5.1)
+ tzinfo (~> 1.1)
+ bcrypt (3.1.13)
+ builder (3.2.4)
+ cassandra-driver (3.2.3)
+ ione (~> 1.2)
+ concurrent-ruby (1.1.5)
+ ffi (1.12.2)
+ i18n (1.6.0)
+ concurrent-ruby (~> 1.0)
+ ione (1.2.4)
+ mini_portile2 (2.4.0)
+ minitest (5.11.3)
+ mustermann (1.0.3)
+ nokogiri (1.10.8)
+ mini_portile2 (~> 2.4.0)
+ pdf-core (0.7.0)
+ prawn (2.2.2)
+ pdf-core (~> 0.7.0)
+ ttfunk (~> 1.5)
+ puma (3.12.4)
+ rack (2.2.2)
+ rack-protection (2.0.5)
+ rack
+ rake (12.3.3)
+ rbnacl (7.1.1)
+ ffi
+ redis (4.1.1)
+ ruby-fann (1.2.6)
+ sinatra (2.0.5)
+ mustermann (~> 1.0)
+ rack (~> 2.0)
+ rack-protection (= 2.0.5)
+ tilt (~> 2.0)
+ squid (1.4.1)
+ activesupport (>= 4.0)
+ prawn (~> 2.2)
+ thread_safe (0.3.6)
+ tilt (2.0.9)
+ ttfunk (1.5.1)
+ tzinfo (1.2.5)
+ thread_safe (~> 0.1)
+
+PLATFORMS
+ ruby
+
+DEPENDENCIES
+ activemerchant
+ bcrypt
+ cassandra-driver
+ puma
+ rake
+ rbnacl
+ redis
+ ruby-fann
+ sinatra
+ squid (~> 1.4)
+
+BUNDLED WITH
+ 1.17.1
diff --git a/benchmarks/workloads/ruby/config.ru b/benchmarks/workloads/ruby/config.ru
new file mode 100755
index 000000000..fbd5acc82
--- /dev/null
+++ b/benchmarks/workloads/ruby/config.ru
@@ -0,0 +1,2 @@
+require './index'
+run Sinatra::Application \ No newline at end of file
diff --git a/benchmarks/workloads/ruby/index.rb b/benchmarks/workloads/ruby/index.rb
new file mode 100755
index 000000000..5fa85af93
--- /dev/null
+++ b/benchmarks/workloads/ruby/index.rb
@@ -0,0 +1,14 @@
+require "sinatra"
+require "puma"
+require "redis"
+require "rake"
+require "squid"
+require "cassandra"
+require "ruby-fann"
+require "rbnacl"
+require "bcrypt"
+require "activemerchant"
+
+get "/" do
+ "Hello World!"
+end \ No newline at end of file
diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD
new file mode 100644
index 000000000..72ed9403d
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "Gemfile",
+ "Gemfile.lock",
+ "config.ru",
+ "index.erb",
+ "main.rb",
+ ],
+)
diff --git a/benchmarks/workloads/ruby_template/Dockerfile b/benchmarks/workloads/ruby_template/Dockerfile
new file mode 100755
index 000000000..a06d68bf4
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/Dockerfile
@@ -0,0 +1,38 @@
+# example based on https://github.com/errm/fib
+
+FROM alpine:3.9 as build
+
+COPY Gemfile Gemfile.lock ./
+
+RUN apk add --no-cache ruby ruby-dev ruby-bundler ruby-json build-base bash \
+ && bundle install --frozen -j4 -r3 --no-cache --without development \
+ && apk del --no-cache ruby-bundler \
+ && rm -rf /usr/lib/ruby/gems/*/cache
+
+FROM alpine:3.9 as prod
+
+COPY --from=build /usr/lib/ruby/gems /usr/lib/ruby/gems
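+# Generate executable wrappers for the copied gems so that binaries such as
+# /usr/bin/puma exist in this stage.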
+RUN apk add --no-cache ruby ruby-json ruby-etc redis apache2-utils \
+ && ruby -e "Gem::Specification.map.each do |spec| \
+ Gem::Installer.for_spec( \
+ spec, \
+ wrappers: true, \
+ force: true, \
+ install_dir: spec.base_dir, \
+ build_args: spec.build_args, \
+ ).generate_bin \
+ end"
+
+WORKDIR /app
+COPY . /app/.
+
+ENV PORT=9292 \
+ WEB_CONCURRENCY=20 \
+ WEB_MAX_THREADS=20 \
+ RACK_ENV=production
+
+ENV host localhost
+EXPOSE $PORT
+USER nobody
+STOPSIGNAL SIGINT
+CMD ["sh", "-c", "/usr/bin/puma", "${host}"]
diff --git a/benchmarks/workloads/ruby_template/Gemfile b/benchmarks/workloads/ruby_template/Gemfile
new file mode 100755
index 000000000..ac521b32c
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/Gemfile
@@ -0,0 +1,5 @@
+source "https://rubygems.org"
+
+gem "sinatra"
+gem "puma"
+gem "redis" \ No newline at end of file
diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock
new file mode 100644
index 000000000..eeb3c7bbe
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/Gemfile.lock
@@ -0,0 +1,26 @@
+GEM
+ remote: https://rubygems.org/
+ specs:
+ mustermann (1.0.3)
+ puma (3.12.6)
+ rack (2.0.6)
+ rack-protection (2.0.5)
+ rack
+ redis (4.1.0)
+ sinatra (2.0.5)
+ mustermann (~> 1.0)
+ rack (~> 2.0)
+ rack-protection (= 2.0.5)
+ tilt (~> 2.0)
+ tilt (2.0.9)
+
+PLATFORMS
+ ruby
+
+DEPENDENCIES
+ puma
+ redis
+ sinatra
+
+BUNDLED WITH
+ 1.17.1 \ No newline at end of file
diff --git a/benchmarks/workloads/ruby_template/config.ru b/benchmarks/workloads/ruby_template/config.ru
new file mode 100755
index 000000000..b2d135cc0
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/config.ru
@@ -0,0 +1,2 @@
+require './main'
+run Sinatra::Application \ No newline at end of file
diff --git a/benchmarks/workloads/ruby_template/index.erb b/benchmarks/workloads/ruby_template/index.erb
new file mode 100755
index 000000000..7f7300e80
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/index.erb
@@ -0,0 +1,8 @@
+<!DOCTYPE html>
+<html>
+<body>
+ <% text.each do |t| %>
+ <p><%= t %></p>
+ <% end %>
+</body>
+</html>
diff --git a/benchmarks/workloads/ruby_template/main.rb b/benchmarks/workloads/ruby_template/main.rb
new file mode 100755
index 000000000..35c239377
--- /dev/null
+++ b/benchmarks/workloads/ruby_template/main.rb
@@ -0,0 +1,27 @@
+require "sinatra"
+require "securerandom"
+require "redis"
+
+redis_host = ENV["host"]
+$redis = Redis.new(host: redis_host)
+
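+# Seed Redis with 100 random 1 KiB alphanumeric strings under keys 0..99.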
+def generateText
+ for i in 0..99
+ $redis.set(i, randomBody(1024))
+ end
+end
+
+def randomBody(length)
+ return SecureRandom.alphanumeric(length)
+end
+
+generateText
+template = ERB.new(File.read('./index.erb'))
+
+get "/" do
+ texts = Array.new
+ for i in 0..4
+ texts.push($redis.get(rand(0..99)))
+ end
+ template.result_with_hash(text: texts)
+end \ No newline at end of file
diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD
new file mode 100644
index 000000000..a70873065
--- /dev/null
+++ b/benchmarks/workloads/sleep/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/sleep/Dockerfile b/benchmarks/workloads/sleep/Dockerfile
new file mode 100644
index 000000000..24c72e07a
--- /dev/null
+++ b/benchmarks/workloads/sleep/Dockerfile
@@ -0,0 +1,3 @@
+FROM alpine:latest
+
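+# 315360000 seconds is ten years: sleep effectively forever.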
+CMD ["sleep", "315360000"]
diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD
new file mode 100644
index 000000000..ab2556064
--- /dev/null
+++ b/benchmarks/workloads/sysbench/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "sysbench",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "sysbench_test",
+ srcs = ["sysbench_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":sysbench",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/sysbench/Dockerfile b/benchmarks/workloads/sysbench/Dockerfile
new file mode 100644
index 000000000..8225e0e14
--- /dev/null
+++ b/benchmarks/workloads/sysbench/Dockerfile
@@ -0,0 +1,16 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ sysbench \
+ && rm -rf /var/lib/apt/lists/*
+
+# Parameterize the tests.
+ENV test cpu
+ENV threads 1
+ENV options ""
+
+# run sysbench once as a warm-up and take the second result
+CMD ["sh", "-c", "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null && \
+sysbench --threads=${threads} ${options} ${test} run"]
diff --git a/benchmarks/workloads/sysbench/__init__.py b/benchmarks/workloads/sysbench/__init__.py
new file mode 100644
index 000000000..de357b4db
--- /dev/null
+++ b/benchmarks/workloads/sysbench/__init__.py
@@ -0,0 +1,167 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Sysbench."""
+
+import re
+
+STD_REGEX = r"events per second:\s*(\d*.?\d*)\n"
+MEM_REGEX = r"Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)"
+ALT_REGEX = r"execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)"
+AVG_REGEX = r"avg:[^\n^\d]*(\d*\.?\d*)"
+
+SAMPLE_CPU_DATA = """
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+Running the test with following options:
+Number of threads: 8
+Initializing random number generator from current time
+
+
+Prime numbers limit: 10000
+
+Initializing worker threads...
+
+Threads started!
+
+CPU speed:
+ events per second: 9093.38
+
+General statistics:
+ total time: 10.0007s
+ total number of events: 90949
+
+Latency (ms):
+ min: 0.64
+ avg: 0.88
+ max: 24.65
+ 95th percentile: 1.55
+ sum: 79936.91
+
+Threads fairness:
+ events (avg/stddev): 11368.6250/831.38
+ execution time (avg/stddev): 9.9921/0.01
+"""
+
+SAMPLE_MEMORY_DATA = """
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+Running the test with following options:
+Number of threads: 8
+Initializing random number generator from current time
+
+
+Running memory speed test with the following options:
+ block size: 1KiB
+ total size: 102400MiB
+ operation: write
+ scope: global
+
+Initializing worker threads...
+
+Threads started!
+
+Total operations: 47999046 (9597428.64 per second)
+
+46874.07 MiB transferred (9372.49 MiB/sec)
+
+
+General statistics:
+ total time: 5.0001s
+ total number of events: 47999046
+
+Latency (ms):
+ min: 0.00
+ avg: 0.00
+ max: 0.21
+ 95th percentile: 0.00
+ sum: 33165.91
+
+Threads fairness:
+ events (avg/stddev): 5999880.7500/111242.52
+ execution time (avg/stddev): 4.1457/0.09
+"""
+
+SAMPLE_MUTEX_DATA = """
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+Running the test with following options:
+Number of threads: 8
+Initializing random number generator from current time
+
+
+Initializing worker threads...
+
+Threads started!
+
+
+General statistics:
+ total time: 3.7869s
+ total number of events: 8
+
+Latency (ms):
+ min: 3688.56
+ avg: 3754.03
+ max: 3780.94
+ 95th percentile: 3773.42
+ sum: 30032.28
+
+Threads fairness:
+ events (avg/stddev): 1.0000/0.00
+ execution time (avg/stddev): 3.7540/0.03
+"""
+
+
+# pylint: disable=unused-argument
+def sample(test, **kwargs):
+ switch = {
+ "cpu": SAMPLE_CPU_DATA,
+ "memory": SAMPLE_MEMORY_DATA,
+ "mutex": SAMPLE_MUTEX_DATA,
+ "randwr": SAMPLE_CPU_DATA
+ }
+ return switch[test]
+
+
+# pylint: disable=unused-argument
+def cpu_events_per_second(data: str, **kwargs) -> float:
+ """Returns events per second."""
+ return float(re.compile(STD_REGEX).search(data).group(1))
+
+
+# pylint: disable=unused-argument
+def memory_ops_per_second(data: str, **kwargs) -> float:
+ """Returns memory operations per second."""
+ return float(re.compile(MEM_REGEX).search(data).group(1))
+
+
+# pylint: disable=unused-argument
+def mutex_time(data: str, count: int, locks: int, threads: int,
+ **kwargs) -> float:
+ """Returns normalized mutex time (lower is better)."""
+ value = float(re.compile(ALT_REGEX).search(data).group(1))
+ contention = float(threads) / float(locks)
+ scale = contention * float(count) / 100000000.0
+ return value / scale
+
+
+# pylint: disable=unused-argument
+def mutex_deviation(data: str, **kwargs) -> float:
+ """Returns deviation for threads."""
+ return float(re.compile(ALT_REGEX).search(data).group(2))
+
+
+# pylint: disable=unused-argument
+def mutex_latency(data: str, **kwargs) -> float:
+ """Returns average mutex latency."""
+ return float(re.compile(AVG_REGEX).search(data).group(1))
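+
+
+# Illustrative usage sketch (not part of the original change): parse the
+# bundled samples; for mutex_time the scale is (threads/locks) * (count/1e8),
+# so count=1e8 with locks == threads gives a scale of 1.0.
+if __name__ == "__main__":
+  print(cpu_events_per_second(sample("cpu")))  # 9093.38
+  print(memory_ops_per_second(sample("memory")))  # 9597428.64
+  print(mutex_time(sample("mutex"), count=100000000, locks=1, threads=1))  # 3.754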
diff --git a/benchmarks/workloads/sysbench/sysbench_test.py b/benchmarks/workloads/sysbench/sysbench_test.py
new file mode 100644
index 000000000..3fb541fd2
--- /dev/null
+++ b/benchmarks/workloads/sysbench/sysbench_test.py
@@ -0,0 +1,34 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Parser test."""
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import sysbench
+
+
+def test_sysbench_parser():
+ """Test the basic parser."""
+ assert sysbench.cpu_events_per_second(sysbench.sample("cpu")) == 9093.38
+ assert sysbench.memory_ops_per_second(sysbench.sample("memory")) == 9597428.64
+ assert sysbench.mutex_time(sysbench.sample("mutex"), 1, 1,
+ 100000000.0) == 3.754
+ assert sysbench.mutex_deviation(sysbench.sample("mutex")) == 0.03
+ assert sysbench.mutex_latency(sysbench.sample("mutex")) == 3754.03
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD
new file mode 100644
index 000000000..f8c43bca1
--- /dev/null
+++ b/benchmarks/workloads/syscall/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "syscall",
+ srcs = ["__init__.py"],
+)
+
+py_test(
+ name = "syscall_test",
+ srcs = ["syscall_test.py"],
+ python_version = "PY3",
+ deps = test_deps + [
+ ":syscall",
+ ],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "syscall.c",
+ ],
+)
diff --git a/benchmarks/workloads/syscall/Dockerfile b/benchmarks/workloads/syscall/Dockerfile
new file mode 100644
index 000000000..a2088d953
--- /dev/null
+++ b/benchmarks/workloads/syscall/Dockerfile
@@ -0,0 +1,6 @@
+FROM gcc:latest
+COPY . /usr/src/syscall
+WORKDIR /usr/src/syscall
+RUN gcc -O2 -o syscall syscall.c
+ENV count 1000000
+CMD ["sh", "-c", "./syscall ${count}"]
diff --git a/benchmarks/workloads/syscall/__init__.py b/benchmarks/workloads/syscall/__init__.py
new file mode 100644
index 000000000..dc9028faa
--- /dev/null
+++ b/benchmarks/workloads/syscall/__init__.py
@@ -0,0 +1,29 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Simple syscall test."""
+
+import re
+
+SAMPLE_DATA = "Called getpid syscall 1000000 times: 1117 ms, 500 ns each."
+
+
+# pylint: disable=unused-argument
+def sample(**kwargs) -> str:
+ return SAMPLE_DATA
+
+
+# pylint: disable=unused-argument
+def syscall_time_ns(data: str, **kwargs) -> int:
+  """Returns the average time per syscall in nanoseconds."""
+  return int(re.compile(r"(\d+)\sns each\.").search(data).group(1))
diff --git a/benchmarks/workloads/syscall/syscall.c b/benchmarks/workloads/syscall/syscall.c
new file mode 100644
index 000000000..ded030397
--- /dev/null
+++ b/benchmarks/workloads/syscall/syscall.c
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#define _GNU_SOURCE
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <time.h>
+#include <unistd.h>
+
+// Short program that makes a getpid syscall a number of times and outputs the
+// time difference measured with the MONOTONIC clock.
+int main(int argc, char** argv) {
+ struct timespec start, stop;
+ long result;
+
+ if (argc < 2) {
+ printf("Usage:./syscall NUM_TIMES_TO_CALL");
+ return 1;
+ }
+
+ if (clock_gettime(CLOCK_MONOTONIC, &start)) return 1;
+
+  long loops = atol(argv[1]);
+ for (long i = 0; i < loops; i++) {
+    syscall(SYS_getpid);
+ }
+
+ if (clock_gettime(CLOCK_MONOTONIC, &stop)) return 1;
+
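+  // Convert the MONOTONIC delta to milliseconds, borrowing a second when the
+  // nanosecond difference is negative.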
+ if ((stop.tv_nsec - start.tv_nsec) < 0) {
+ result = (stop.tv_sec - start.tv_sec - 1) * 1000;
+ result += (stop.tv_nsec - start.tv_nsec + 1000000000) / (1000 * 1000);
+ } else {
+ result = (stop.tv_sec - start.tv_sec) * 1000;
+ result += (stop.tv_nsec - start.tv_nsec) / (1000 * 1000);
+ }
+
+ printf("Called getpid syscall %d times: %lu ms, %lu ns each.\n", loops,
+ result, result * 1000000 / loops);
+
+ return 0;
+}
diff --git a/benchmarks/workloads/syscall/syscall_test.py b/benchmarks/workloads/syscall/syscall_test.py
new file mode 100644
index 000000000..72f027de1
--- /dev/null
+++ b/benchmarks/workloads/syscall/syscall_test.py
@@ -0,0 +1,27 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+import pytest
+
+from benchmarks.workloads import syscall
+
+
+def test_syscall_time_ns():
+ assert syscall.syscall_time_ns(syscall.sample()) == 500
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__]))
diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD
new file mode 100644
index 000000000..a7b7742f4
--- /dev/null
+++ b/benchmarks/workloads/tensorflow/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+py_library(
+ name = "tensorflow",
+ srcs = ["__init__.py"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+)
diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/benchmarks/workloads/tensorflow/Dockerfile
new file mode 100644
index 000000000..b5763e8ae
--- /dev/null
+++ b/benchmarks/workloads/tensorflow/Dockerfile
@@ -0,0 +1,14 @@
+FROM tensorflow/tensorflow:1.13.2
+
+RUN apt-get update \
+ && apt-get install -y git
+RUN git clone --depth 1 https://github.com/aymericdamien/TensorFlow-Examples.git
+RUN python -m pip install -U pip setuptools
+RUN python -m pip install matplotlib
+
+WORKDIR /TensorFlow-Examples/examples
+
+ENV PYTHONPATH="$PYTHONPATH:/TensorFlow-Examples/examples"
+
+ENV workload "3_NeuralNetworks/convolutional_network.py"
+CMD python ${workload}
diff --git a/benchmarks/workloads/tensorflow/__init__.py b/benchmarks/workloads/tensorflow/__init__.py
new file mode 100644
index 000000000..b5ec213f8
--- /dev/null
+++ b/benchmarks/workloads/tensorflow/__init__.py
@@ -0,0 +1,20 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A Tensorflow example."""
+
+
+# pylint: disable=unused-argument
+def run_time(value, **kwargs):
+ """Returns the startup and runtime of the Tensorflow workload in seconds."""
+ return value
diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD
new file mode 100644
index 000000000..eba23d325
--- /dev/null
+++ b/benchmarks/workloads/true/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(
+ default_visibility = ["//benchmarks:__subpackages__"],
+ licenses = ["notice"],
+)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ ],
+ extension = "tar",
+)
diff --git a/benchmarks/workloads/true/Dockerfile b/benchmarks/workloads/true/Dockerfile
new file mode 100644
index 000000000..2e97c921e
--- /dev/null
+++ b/benchmarks/workloads/true/Dockerfile
@@ -0,0 +1,3 @@
+FROM alpine:latest
+
+CMD ["true"]
diff --git a/g3doc/BUILD b/g3doc/BUILD
new file mode 100644
index 000000000..c315d38be
--- /dev/null
+++ b/g3doc/BUILD
@@ -0,0 +1,44 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "index",
+ src = "README.md",
+ category = "Project",
+ data = glob([
+ "*.png",
+ "*.svg",
+ ]),
+ permalink = "/docs/",
+ weight = "0",
+)
+
+doc(
+ name = "roadmap",
+ src = "roadmap.md",
+ category = "Project",
+ permalink = "/roadmap/",
+ weight = "10",
+)
+
+doc(
+ name = "community",
+ src = "community.md",
+ category = "Project",
+ permalink = "/community/",
+ subcategory = "Community",
+ weight = "95",
+)
+
+doc(
+ name = "style",
+ src = "style.md",
+ category = "Project",
+ permalink = "/community/style_guide/",
+ subcategory = "Community",
+ weight = "10",
+)
diff --git a/g3doc/Layers.png b/g3doc/Layers.png
new file mode 100644
index 000000000..308c6c451
--- /dev/null
+++ b/g3doc/Layers.png
Binary files differ
diff --git a/g3doc/Layers.svg b/g3doc/Layers.svg
new file mode 100644
index 000000000..0a366f841
--- /dev/null
+++ b/g3doc/Layers.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 371.8346456692913 255.01574803149606" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l371.83466 0l0 255.01575l-371.83466 0l0 -255.01575z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l371.83466 0l0 255.01575l-371.83466 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m78.206116 37.98824l5.125 -13.359373l1.90625 0l5.46875 13.359373l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703123q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921873zm9.849823 9.1875l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.891342 8.484375l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.844467 4.78125l0 -13.359373l1.640625 0l0 13.359373l-1.640625 0zm4.191696 -11.468748l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781231 0.515625 -2.749998q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.812498q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5624981q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 
-1.421875q-0.625 -0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.187498q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578123l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671873q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -9.999998l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm3.5354462 -4.84375q0 -2.687498 1.484375 -3.968748q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609373q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.7968731 -0.8125 -2.718748q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765623zm9.297592 4.84375l0 -9.671873l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.5937481l0 5.953125l-1.640625 0l0 -5.890625q0 -0.9999981 -0.203125 -1.4843731q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515623l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.454068 73.068245l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.454068 73.068245l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m172.45407 73.068245l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m172.45407 73.068245l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m73.43044 56.702377l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m87.06437 74.471375l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 1.859375 
0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.0312424 0 1.5781174 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.7031174 
-0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321053 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.4062653 0 -2.2812653 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.6562653 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.4531403 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.4062653 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#d9d2e9" d="m36.454067 87.40656l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#8e7cc3" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 87.40656l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" 
d="m98.35295 119.54864l1.59375 0.234375q0.109375 0.75 0.5625 1.078125q0.609375 0.453125 1.671875 0.453125q1.140625 0 1.75 -0.453125q0.625 -0.453125 0.84375 -1.265625q0.125 -0.5 0.109375 -2.109375q-1.0625 1.265625 -2.671875 1.265625q-2.0 0 -3.09375 -1.4375q-1.09375 -1.4375 -1.09375 -3.453125q0 -1.390625 0.5 -2.5625q0.515625 -1.171875 1.453125 -1.796875q0.953125 -0.640625 2.25 -0.640625q1.703125 0 2.8125 1.375l0 -1.15625l1.515625 0l0 8.359375q0 2.265625 -0.46875 3.203125q-0.453125 0.9375 -1.453125 1.484375q-0.984375 0.546875 -2.453125 0.546875q-1.71875 0 -2.796875 -0.78125q-1.0625 -0.765625 -1.03125 -2.34375zm1.359375 -5.8125q0 1.90625 0.75 2.78125q0.765625 0.875 1.90625 0.875q1.125 0 1.890625 -0.859375q0.765625 -0.875 0.765625 -2.734375q0 -1.78125 -0.796875 -2.671875q-0.78125 -0.90625 -1.890625 -0.90625q-1.09375 0 -1.859375 0.890625q-0.765625 0.875 -0.765625 2.625zm13.344467 5.015625l-5.171875 -13.359375l1.921875 0l3.46875 9.703125q0.421875 1.171875 0.703125 2.1875q0.3125 -1.09375 0.71875 -2.1875l3.609375 -9.703125l1.796875 0l-5.234375 13.359375l-1.8125 0zm8.427948 -11.46875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm3.4885712 -2.890625l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm9.375 -1.953125q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.281967 4.84375l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m3.6351707 152.91733l48.850395 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m3.6351707 152.91733l48.850395 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m195.25722 152.91733l47.338577 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" 
stroke-linecap="butt" d="m195.25722 152.91733l47.338577 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m52.485565 136.88583l142.77165 0l0 32.06299l-142.77165 0z" fill-rule="evenodd"/><path fill="#000000" d="m65.21821 157.71732l0 -9.546875l1.265625 0l0 8.421875l4.703125 0l0 1.125l-5.96875 0zm7.3343506 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm2.945465 0l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm11.118057 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm5.507965 -1.046875l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm11.006226 4.125l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm9.865463 1.390625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 
-1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm7.0859375 4.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8124924 0 1.2031174 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1874924 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8124924 0 1.4218674 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0624924 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.2499924 0.328125 1.7343674 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.4531174 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.695305 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 
-0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321045 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.40625 0 -2.28125 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.65625 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.453125 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.40625 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m36.454067 167.00784l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 167.00784l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m76.63558 198.35303l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 
-2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.6249924 -6.625l2.390625 0l-5.5937424 5.421875l5.8437424 7.9375l-2.328125 0l-4.7656174 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943565 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 
-3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.454068 233.43303l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.454068 233.43303l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m172.45407 233.43303l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m172.45407 233.43303l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m73.43044 217.06717l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m96.04542 237.89867l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 
0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.0937424 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.4062424 5.3125l-1.21875 0zm12.859535 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59376526 0.21875 -1.2812653 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.4218903 -0.171875 2.0937653 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.3437653 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.89064026 0 1.4375153 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.9218903 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.2031403 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.53514884533539 0.0 0.0 4.53514884533539 0.0 0.0)" spreadMethod="pad" x1="8.21347768339151" y1="37.02644733653771" x2="8.213461293294644" y2="41.56159618184348"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m37.249344 167.92108l173.29134 0l0 20.566925l-173.29134 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m272.4455 182.06865l129.5748 -74.83464l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m287.51392 188.73558l1.1823425 -0.82717896q0.51071167 0.6974335 1.1166077 0.9970703q0.5980835 0.28611755 1.4464111 0.1931305q0.84054565 -0.10652161 1.6794434 -0.5910187q0.75772095 -0.4376068 1.2010193 -0.9823456q0.44906616 -0.5660858 0.50097656 -1.1013031q0.057678223 -0.55656433 -0.20785522 
-1.0166931q-0.27334595 -0.47366333 -0.7392273 -0.6557007q-0.47366333 -0.1955719 -1.2366333 -0.079711914q-0.478302 0.07775879 -2.032318 0.54222107q-1.5618286 0.45092773 -2.2805786 0.48712158q-0.9222717 0.027420044 -1.5864563 -0.31072998q-0.6719971 -0.35168457 -1.0703125 -1.0418701q-0.4295349 -0.74432373 -0.38497925 -1.6361694q0.05029297 -0.9131775 0.6668701 -1.7203827q0.63012695 -0.81500244 1.6313782 -1.3932648q1.1095276 -0.64079285 2.1592712 -0.7598877q1.0419617 -0.1326294 1.8867493 0.29968262q0.8583679 0.4244995 1.3987732 1.2671814l-1.2036743 0.82147217q-0.64712524 -0.87127686 -1.5022583 -1.0089111q-0.8629761 -0.15116882 -2.013092 0.51304626q-1.1906738 0.68766785 -1.4819641 1.4333191q-0.2913208 0.745636 0.06793213 1.3681641q0.30459595 0.52778625 0.8865051 0.6608429q0.5740967 0.119522095 2.3815613 -0.43717957q1.8210144 -0.5645294 2.5725403 -0.6376953q1.0924377 -0.107666016 1.857605 0.28042603q0.77090454 0.366745 1.2316895 1.1652069q0.4529724 0.78492737 0.39904785 1.7543335q-0.040405273 0.9616089 -0.6663208 1.8463745q-0.62594604 0.8847656 -1.6813354 1.4942932q-1.3530579 0.78144836 -2.486084 0.9125519q-1.1408386 0.11756897 -2.1214905 -0.3625946q-0.96713257 -0.48797607 -1.5721436 -1.4738007zm13.40155 -4.9431458l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396332q-0.6629944 0.3829193 -1.1454773 0.39089966q-0.4902649 -0.0055389404 -0.8343506 -0.25790405q-0.3518982 -0.26589966 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.9473114l0.8930054 -0.5157623l-1.0308838 -1.786377l0.79599 -1.434082l1.4526367 2.5171661l1.2177734 -0.7033081l0.5466919 0.94732666l-1.2177734 0.70329285l2.4210815 4.195282q0.30456543 0.5278015 0.4446106 0.645401q0.15356445 0.109802246 0.35705566 0.11857605q0.19570923 -0.004760742 0.4663086 -0.16105652q0.20297241 -0.11721802 0.49645996 -0.35888672zm1.8165283 0.41241455l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.0178833 -1.0001068 0.19332886 -1.4468842q0.21121216 -0.44676208 0.64419556 -0.6968231q0.6088562 -0.35165405 1.471283 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765381 0.6577606q-0.17843628 0.40979004 -0.06384277 0.92100525q0.17193604 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm2.0899658 -6.006668q-1.1480408 -1.9893646 -0.5930481 -3.5910034q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210815q0.04135132 1.0407104 -0.49923706 1.948349q-0.5348511 0.88630676 -1.4819946 1.4333191q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621765zm1.2583313 -0.72673035q0.79663086 1.3803711 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18502808q0.9065552 -0.5235596 1.1036072 -1.5575867q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17720032q-0.90652466 0.5235748 -1.117096 1.5654144q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494751 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.522583 -4.371216q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078125 -0.8442688 -0.3063202q-0.47677612 -0.01335144 -0.9638672 
0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.264862 3.9246216l-1.2177429 0.70329285zm7.819275 -3.7220154l1.2922058 -0.511734q0.3878479 0.51579285 0.8666992 0.5640259q0.65527344 0.072631836 1.4400635 -0.38059998q0.8388977 -0.48449707 1.1036682 -1.0885162q0.26480103 -0.60401917 0.07571411 -1.306778q-0.1161499 -0.42010498 -0.81121826 -1.6245575q-0.25161743 1.408371 -1.4423218 2.096054q-1.4883423 0.85957336 -2.9171448 0.25932312q-1.428833 -0.60025024 -2.2957153 -2.1024323q-0.5935669 -1.0285187 -0.72805786 -2.1056366q-0.12097168 -1.0849152 0.30926514 -1.9649353q0.4437561 -0.8878174 1.3908997 -1.4348297q1.2718811 -0.7345581 2.690796 -0.182724l-0.4998474 -0.8661194l1.1230469 -0.6485901l3.5847168 6.2117157q0.9684448 1.6781158 1.0362854 2.5771942q0.067840576 0.8990936 -0.4420166 1.7348785q-0.5098877 0.8357849 -1.5923462 1.4609375q-1.2854004 0.7423706 -2.4195251 0.62150574q-1.1206055 -0.12869263 -1.7651672 -1.3081818zm-1.4765625 -4.9031525q0.81222534 1.4074402 1.7418518 1.7366486q0.94314575 0.32138062 1.7820435 -0.16311646q0.8388977 -0.4844818 1.0401001 -1.4487457q0.19342041 -0.97779846 -0.60317993 -2.3581848q-0.76538086 -1.3262482 -1.7298889 -1.6533508q-0.97229004 -0.3406372 -1.7976685 0.1360321q-0.8118286 0.46887207 -0.9974365 1.4602051q-0.1855774 0.991333 0.56417847 2.290512z" fill-rule="nonzero"/><path fill="#000000" d="m294.23132 199.68793l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80441284 1.3939056l-1.2177429 0.7033081zm4.920227 8.525894l-4.147064 -7.1861115l1.2177734 -0.7033081l4.1470337 7.1861115l-1.2177429 0.7033081zm1.3493347 -3.6482391l1.0948792 -0.88494873q0.51641846 0.6760864 1.2029724 0.80285645q0.6922302 0.10542297 1.5176086 -0.3712616q0.8388672 -0.4844818 1.0495605 -1.057251q0.20285034 -0.58628845 -0.062683105 -1.0464172q-0.23431396 -0.4059906 -0.7266846 -0.4464264q-0.35079956 -0.013916016 -1.4790955 0.3129425q-1.5347595 0.43530273 -2.2030334 0.49645996q-0.66256714 0.03982544 -1.183075 -0.23695374q-0.5069885 -0.28459167 -0.8115845 -0.8123779q-0.28115845 -0.48719788 -0.2989807 -1.018219q-0.017791748 -0.5310211 0.2048645 -1.0204926q0.15917969 -0.3806305 0.56817627 -0.797287q0.4147339 -0.43800354 0.9694824 -0.75839233q0.852417 -0.49230957 1.6289368 -0.61598206q0.7765198 -0.123687744 1.3162842 0.123931885q0.5455017 0.22625732 1.0733948 0.85964966l-1.0969849 0.85006714q-0.4013672 -0.5079651 -0.9733887 -0.59262085q-0.5720215 -0.0846405 -1.2756042 0.3217163q-0.8388977 0.4844818 -1.0402222 0.9796753q-0.19558716 0.47383118 0.015289307 0.8392334q0.14056396 0.24359131 0.39874268 0.34710693q0.2581787 0.103500366 0.64749146 0.05909729q0.22845459 -0.041732788 1.2620544 -0.31388855q1.4806519 -0.40403748 2.127594 -0.470932q0.6390991 -0.08041382 1.1653442 0.17501831q0.5397949 0.24760437 0.89904785 0.87013245q0.35144043 0.60899353 0.29852295 1.3613129q-0.047210693 0.73095703 -0.5519409 1.4194183q-0.49118042 0.68063354 -1.3300476 1.1651306q-1.407196 0.81269836 -2.4736633 0.6527405q-1.0664673 -0.15994263 -1.933258 -1.193039zm6.1190186 -5.4646606q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47283936 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.85957336 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79663086 1.3803864 0.83795166 2.4210968q0.04135132 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902832 1.7267303q1.0071716 0.33854675 1.9136963 -0.18501282q0.9065552 -0.5235748 
1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.90652466 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.957733 -0.3738098l-5.7246704 -9.919815l1.2177734 -0.70329285l5.72464 9.9198l-1.2177429 0.7033081zm7.2713623 -5.390396q-0.34069824 0.972641 -0.8225403 1.5756989q-0.48962402 0.5895386 -1.2067566 1.0036926q-1.1906738 0.6876831 -2.162445 0.47302246q-0.9717407 -0.21464539 -1.4872131 -1.1078339q-0.30456543 -0.52778625 -0.3109436 -1.1015167q-0.014190674 -0.58724976 0.21627808 -1.0631866q0.23620605 -0.49728394 0.64520264 -0.9139252q0.31063843 -0.3057251 0.980896 -0.8010864q1.3733215 -1.02771 1.9363098 -1.6776581q-0.14837646 -0.25712585 -0.18740845 -0.32478333q-0.42956543 -0.74432373 -0.93963623 -0.84669495q-0.7156677 -0.14602661 -1.6357422 0.38536072q-0.8524475 0.49230957 -1.0922546 1.0458069q-0.23410034 0.5321655 0.013824463 1.3994293l-1.2843933 0.52526855q-0.2749939 -0.85162354 -0.18301392 -1.5362854q0.10549927 -0.6924591 0.66851807 -1.3424072q0.5552063 -0.66348267 1.4752808 -1.1948547q0.92007446 -0.5313873 1.6133118 -0.6430664q0.7067566 -0.11949158 1.1648254 0.04902649q0.45803833 0.16850281 0.8552551 0.60672q0.24728394 0.27218628 0.71588135 1.0841675l0.9371643 1.6239929q0.9840393 1.7051697 1.3094177 2.1126862q0.33892822 0.39971924 0.81103516 0.68640137l-1.2718506 0.7345581q-0.40811157 -0.26953125 -0.7590027 -0.75253296zm-1.6645203 -2.6654663q-0.5067749 0.65356445 -1.7234497 1.6088562q-0.69522095 0.5458679 -0.9283142 0.8609314q-0.23312378 0.31506348 -0.25280762 0.6873169q-0.013977051 0.3508911 0.16564941 0.66215515q0.28115845 0.48719788 0.83392334 0.6010132q0.5662842 0.10598755 1.269867 -0.30036926q0.70358276 -0.40634155 1.072998 -1.0166473q0.37512207 -0.63165283 0.3197937 -1.3214569q-0.031341553 -0.5232086 -0.49990845 -1.3352051l-0.25775146 -0.44659424zm7.223419 -0.8156891l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25790405q-0.3518982 -0.26591492 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.94732666l0.8930054 -0.5157471l-1.0308838 -1.786377l0.7960205 -1.434082l1.4526367 2.5171661l1.2177429 -0.7033081l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.52778625 0.4446106 0.645401q0.15356445 0.10978699 0.35708618 0.11856079q0.19567871 -0.004760742 0.46627808 -0.16104126q0.20297241 -0.11721802 0.49645996 -0.35890198zm-3.0901794 -8.121277l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80444336 1.3939209l-1.2177734 0.70329285zm4.920227 8.525894l-4.1470337 -7.1861115l1.2177429 -0.70329285l4.147064 7.186096l-1.2177734 0.7033081zm0.54074097 -5.111908q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210968q0.0413208 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959412 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.9065552 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 
-0.38945007l-4.1470337 -7.1861115l1.0959473 -0.6329651l0.5857544 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494446 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.5226135 -4.371216q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3539734 -0.30078125 -0.8442383 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821808q-0.2735901 0.8075714 0.52301025 2.1879578l2.264862 3.9246216l-1.2177429 0.70329285z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m282.76047 199.94267c-17.003845 0 -26.795105 -5.566925 -34.007706 -11.133865c-7.2126007 -5.566925 -11.846542 -11.13385 -23.6931 -11.13385" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.76047 199.94267c-17.003876 0 -26.795105 -5.5669403 -34.007706 -11.133865c-3.6062927 -2.7834625 -6.567932 -5.566925 -10.10881 -7.6545258c-0.4426117 -0.2609558 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.1181488 -0.34950256 -0.17605591l-0.13806152 -0.066833496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m237.48381 176.93082l-9.563843 1.350235l8.194244 5.1131744z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m272.4455 118.06866l129.5748 -74.83465l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m290.03357 127.53869l-5.72464 -9.9198l1.3124695 -0.75800323l5.72464 9.919807l-1.3124695 0.7579956zm3.470581 -2.0044022l-4.1470337 -7.186104l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012146q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.2648926 3.924614l-1.2177734 0.70329285zm12.360168 -7.1384735l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001221 -1.8564148 0.44062805q-0.9904785 -0.05949402 -1.8883972 -0.67765045q-0.88442993 -0.62597656 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 -2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816315 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898376 -2.1887817q-1.0465393 -1.813446 -0.69085693 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 
2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm10.854523 3.1318436l-5.740265 -9.946869l1.1094971 -0.6407852l0.5388794 0.9337845q0.07220459 -0.78147125 0.4031067 -1.333458q0.3444214 -0.5597992 1.0480042 -0.9661484q0.92007446 -0.5313797 1.8970032 -0.46406555q0.9769592 0.06730652 1.8285828 0.7302551q0.8573303 0.64160156 1.4508972 1.6701202q0.6404114 1.1097183 0.74212646 2.2238083q0.11526489 1.1062698 -0.3691101 2.01754q-0.4708252 0.90345 -1.3097229 1.3879471q-0.6088867 0.35164642 -1.2443542 0.37583923q-0.62197876 0.01637268 -1.159668 -0.19635773l2.022766 3.5050888l-1.2177429 0.7033005zm-2.551239 -6.952957q0.80441284 1.3939209 1.7418518 1.7366486q0.95095825 0.3349228 1.7492676 -0.12612915q0.8118286 -0.46886444 0.9953308 -1.495079q0.18353271 -1.026207 -0.6443176 -2.4607239q-0.79663086 -1.3803787 -1.7554016 -1.728836q-0.96658325 -0.36198425 -1.7513428 0.09125519q-0.77124023 0.4454193 -0.958374 1.5278625q-0.1736145 1.0746231 0.62298584 2.4550018zm12.239746 -5.408928l1.3442688 -0.57788086q0.34274292 1.2816391 -0.11764526 2.359497q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.91428375 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.1887894q-1.0465088 -1.8134384 -0.6908264 -3.3540878q0.35565186 -1.5406494 1.8440247 -2.400238q1.4342346 -0.828331 2.9108887 -0.36397552q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10826874 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518402 1.7173462 1.4440689q0.96810913 0.27087402 1.861145 -0.24488068q0.6765137 -0.39071655 0.9470215 -1.0160828q0.2705078 -0.62537384 0.095947266 -1.5530472zm-5.1317444 0.32941437l4.0050354 -2.3130722q-0.60443115 -0.85983276 -1.2410278 -1.0876236q-0.98791504 -0.3677063 -1.9079895 0.1636734q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.848114zm9.2612915 0.3710785l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22263336 0.9021301 0.687912q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.2648926 3.924614l-1.2177734 0.7033005zm12.360168 -7.138481l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001221 -1.8564148 0.44062805q-0.9904785 -0.05949402 -1.8883972 -0.67765045q-0.88442993 -0.62596893 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 
-2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816391 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.4430008 -2.4898682 -2.1887817q-1.0465088 -1.813446 -0.6908264 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm9.261322 0.3710785l-4.147064 -7.186104l1.0959778 -0.6329727l0.5857544 1.0149918q0.105285645 -1.6306229 1.6071777 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45761108q0.67401123 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253387l-1.2177429 0.7033005l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442688 -0.3063202q-0.4767456 -0.01335907 -0.96383667 0.2679596q-0.78479004 0.45323944 -1.0640869 1.2821732q-0.2736206 0.80758667 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005zm9.725006 -7.078125l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396408q-0.6629944 0.38291168 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25791168q-0.3518982 -0.26589966 -0.9844971 -1.3620834l-2.382019 -4.127617l-0.8930054 0.5157547l-0.5466919 -0.94731903l0.8930054 -0.5157547l-1.0308838 -1.786377l0.79599 -1.4340897l1.4526672 2.5171661l1.2177429 -0.7033005l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.5277939 0.4446106 0.645401q0.15356445 0.10979462 0.35708618 0.11856842q0.19567871 -0.004760742 0.46627808 -0.16104889q0.20297241 -0.11721802 0.49645996 -0.35889435z" fill-rule="nonzero"/><path fill="#000000" d="m299.15155 144.21382l-5.72464 -9.919815l1.2177429 -0.70329285l3.2645264 5.6568604l1.1950684 -4.587631l1.5830688 -0.9142914l-1.2081604 4.252365l5.625824 2.7774506l-1.5018921 0.8674011l-4.4921265 -2.3134918l-0.38955688 1.3256378l1.6478882 2.8554993l-1.2177429 0.7033081zm10.503723 -9.151794l1.3442383 -0.57788086q0.34274292 1.2816315 -0.11764526 2.359497q-0.44683838 1.0700378 -1.6916504 1.788971q-1.5830688 0.9142761 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.188797q-1.0465393 -1.8134308 -0.69085693 -3.3540802q0.35568237 -1.5406494 1.8440247 -2.400238q1.4342346 -0.8283386 2.9109192 -0.36398315q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10827637 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518402 1.7173157 1.4440765q0.96813965 0.27087402 1.861145 -0.2448883q0.6765137 -0.39071655 0.947052 -1.0160828q0.2705078 -0.6253662 0.095947266 -1.5530396zm-5.1317444 0.32940674l4.0050354 -2.3130646q-0.60446167 -0.85983276 -1.2410278 -1.0876312q-0.98791504 -0.3677063 -1.90802 0.16368103q-0.8388672 0.48449707 -1.0926819 1.3889008q-0.2480774 0.8830719 0.23669434 1.848114zm9.247772 0.378891l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.017913818 -1.0001068 0.19329834 -1.4468765q0.21124268 -0.4467697 
0.64419556 -0.69683075q0.6088867 -0.35165405 1.4713135 -0.32646942l0.22875977 1.3655014q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765686 0.6577606q-0.17840576 0.40979004 -0.063812256 0.92100525q0.17190552 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm4.6274414 -2.6725311l-4.147064 -7.186104l1.0959778 -0.6329727l0.5857544 1.0149841q0.105285645 -1.6306152 1.6071777 -2.4980164q0.6494751 -0.37509918 1.3234558 -0.45761108q0.67401123 -0.0825119 1.1632996 0.14012909q0.4892578 0.22264099 0.9020996 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442688 -0.3063202q-0.4767456 -0.01335907 -0.96383667 0.2679596q-0.78479004 0.45323944 -1.0640869 1.2821732q-0.2736206 0.80758667 0.52301025 2.1879654l2.264862 3.9246216l-1.2177429 0.70329285zm11.281708 -9.60112l1.3442688 -0.57788086q0.34274292 1.2816391 -0.11764526 2.359497q-0.4468689 1.0700378 -1.6916809 1.788971q-1.5830688 0.91428375 -3.0654602 0.47127533q-1.4823914 -0.4430008 -2.4898682 -2.1887817q-1.0465088 -1.8134384 -0.6908264 -3.3540878q0.35565186 -1.5406494 1.8440247 -2.400238q1.4342346 -0.828331 2.9108887 -0.36397552q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75491333 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.095947266 -1.5530396zm-5.1317444 0.32941437l4.0050354 -2.3130722q-0.60443115 -0.85983276 -1.2410278 -1.0876236q-0.98791504 -0.3677063 -1.9079895 0.1636734q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.848114zm9.234253 0.3867035l-5.7246704 -9.919807l1.2177734 -0.7033005l5.72464 9.919807l-1.2177429 0.7033005z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m282.76047 135.94267c-17.003845 0 -26.795105 -5.566925 -34.007706 -11.133858c-7.2126007 -5.5669327 -11.846542 -11.133858 -23.6931 -11.133858" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.76047 135.94267c-17.003876 0 -26.795105 -5.5669403 -34.007706 -11.133865c-3.6062927 -2.7834625 -6.567932 -5.566925 -10.10881 -7.6545258c-0.4426117 -0.26094818 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.11816406 -0.34950256 -0.17607117l-0.13806152 -0.06682587" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m237.48381 112.93081l-9.563843 1.350235l8.194244 5.113182z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/g3doc/Machine-Virtualization.png b/g3doc/Machine-Virtualization.png
new file mode 100644
index 000000000..1ba2ed6b2
--- /dev/null
+++ b/g3doc/Machine-Virtualization.png
Binary files differ
diff --git a/g3doc/Machine-Virtualization.svg b/g3doc/Machine-Virtualization.svg
new file mode 100644
index 000000000..5352da07b
--- /dev/null
+++ b/g3doc/Machine-Virtualization.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 387.7034120734908 336.4225721784777" xmlns="http://www.w3.org/2000/svg"> [… SVG path data elided: one-line Machine-Virtualization diagram — colored/stroked boxes and label text rendered as glyph-outline paths …]
0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000717 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 161.83176l57.574802 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 161.83176l57.574802 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m193.71391 161.83176l60.7874 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m193.71391 161.83176l60.7874 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m70.02887 145.80026l123.68504 0l0 32.06299l-123.68504 0z" fill-rule="evenodd"/><path fill="#000000" d="m87.09864 166.63176l-3.6875 -9.546875l1.359375 0l2.484375 6.9375q0.296875 0.828125 0.5 1.5625q0.21875 -0.78125 0.515625 -1.5625l2.578125 -6.9375l1.28125 0l-3.734375 9.546875l-1.296875 0zm6.0303802 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm2.92984 0l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm7.0164948 -1.046875l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 
0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.6717377 1.046875l0 -1.015625q-0.8125 1.171875 -2.1875 1.171875q-0.609375 0 -1.140625 -0.234375q-0.53125 -0.234375 -0.796875 -0.578125q-0.25 -0.359375 -0.359375 -0.875q-0.0625 -0.34375 -0.0625 -1.09375l0 -4.28125l1.171875 0l0 3.828125q0 0.921875 0.0625 1.234375q0.109375 0.46875 0.46875 0.734375q0.359375 0.25 0.890625 0.25q0.515625 0 0.984375 -0.265625q0.46875 -0.265625 0.65625 -0.734375q0.1875 -0.46875 0.1875 -1.34375l0 -3.703125l1.171875 0l0 6.90625l-1.046875 0zm7.3968506 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm6.6797028 0l0 -9.546875l1.171875 0l0 3.421875q0.828125 -0.9375 2.0781174 -0.9375q0.765625 0 1.328125 0.296875q0.5625 0.296875 0.8125 0.84375q0.25 0.53125 0.25 1.546875l0 4.375l-1.171875 0l0 -4.375q0 -0.890625 -0.390625 -1.28125q-0.375 -0.40625 -1.078125 -0.40625q-0.515625 0 -0.9843674 0.28125q-0.453125 0.265625 -0.65625 0.734375q-0.1875 0.453125 -0.1875 1.265625l0 3.78125l-1.171875 0zm11.928093 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 
1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.09375 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.40625 5.3125l-1.21875 0zm12.859528 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m44.454067 175.40657l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path stroke="#93c47d" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m44.454067 175.40657l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m115.3588 206.75175l-5.171875 
-13.359375l1.921875 0l3.46875 9.703125q0.421875 1.171875 0.703125 2.1875q0.3125 -1.09375 0.71875 -2.1875l3.609375 -9.703125l1.796875 0l-5.234375 13.359375l-1.8125 0zm8.584198 0l0 -13.359375l2.65625 0l3.1562424 9.453125q0.4375 1.328125 0.640625 1.984375q0.234375 -0.734375 0.703125 -2.140625l3.203125 -9.296875l2.375 0l0 13.359375l-1.703125 0l0 -11.171875l-3.875 11.171875l-1.59375 0l-3.8593674 -11.375l0 11.375l-1.703125 0zm15.540794 0l0 -13.359375l2.65625 0l3.15625 9.453125q0.4375 1.328125 0.640625 1.984375q0.234375 -0.734375 0.703125 -2.140625l3.203125 -9.296875l2.375 0l0 13.359375l-1.703125 0l0 -11.171875l-3.875 11.171875l-1.59375 0l-3.859375 -11.375l0 11.375l-1.703125 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 239.1764l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 239.1764l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m180.45407 239.1764l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m180.45407 239.1764l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m81.43044 222.81055l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m95.06437 240.57954l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 1.859375 0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 
-0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.521843 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.32106 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.4062653 0 -2.2812653 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.6562653 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.4531403 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.4062653 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 
0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m44.454067 252.7512l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m44.454067 252.7512l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m84.63558 284.0964l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 
0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.913475 1.46875l0 -13.359375l1.78125 0l0 6.625l6.625 -6.625l2.390625 0l-5.59375 5.421875l5.84375 7.9375l-2.328125 0l-4.765625 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943573 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 319.17642l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 319.17642l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m180.45407 319.17642l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m180.45407 319.17642l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m81.43044 302.81055l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m104.04542 323.64203l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 
-0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.938362 0l0 -0.875q-0.65625 1.03125 -1.9374924 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.6249924 0 1.1093674 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.7031174 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.3281174 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.3437424 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.912468 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.09375 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.40625 5.3125l-1.21875 0zm12.859543 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59376526 0.21875 -1.2812653 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.4218903 -0.171875 2.0937653 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.3437653 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.89064026 0 1.4375153 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.9218903 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.2031403 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 
-0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.545553100086654 0.0 0.0 4.545553100086654 0.0 0.0)" spreadMethod="pad" x1="9.954639806354566" y1="38.70166210013951" x2="9.95462288989064" y2="43.24721520019468"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m45.249344 175.92108l173.29134 0l0 20.661423l-173.29134 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m280.4455 190.06865l129.5748 -74.83464l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m295.51392 196.73558l1.1823425 -0.82717896q0.51071167 0.6974335 1.1166077 0.9970703q0.5980835 0.28611755 1.4464111 0.1931305q0.84054565 -0.10652161 1.6794434 -0.5910187q0.75772095 -0.4376068 1.2010193 -0.9823456q0.44906616 -0.5660858 0.50097656 -1.1013031q0.057678223 -0.55656433 -0.20785522 -1.0166931q-0.27334595 -0.47366333 -0.7392273 -0.6557007q-0.47366333 -0.1955719 -1.2366333 -0.079711914q-0.478302 0.07775879 -2.032318 0.54222107q-1.5618286 0.45092773 -2.2805786 0.48712158q-0.9222717 0.027420044 -1.5864563 -0.31072998q-0.6719971 -0.35168457 -1.0703125 -1.0418701q-0.4295349 -0.74432373 -0.38497925 -1.6361694q0.05029297 -0.9131775 0.6668701 -1.7203827q0.63012695 -0.81500244 1.6313782 -1.3932648q1.1095276 -0.64079285 2.1592712 -0.7598877q1.0419617 -0.1326294 1.8867493 0.29968262q0.8583679 0.4244995 1.3987732 1.2671814l-1.2036743 0.82147217q-0.64712524 -0.87127686 -1.5022583 -1.0089111q-0.8629761 -0.15116882 -2.013092 0.51304626q-1.1906738 0.68766785 -1.4819641 1.4333191q-0.2913208 0.745636 0.06793213 1.3681641q0.30459595 0.52778625 0.8865051 0.6608429q0.5740967 0.119522095 2.3815613 -0.43717957q1.8210144 -0.5645294 2.5725403 -0.6376953q1.0924377 -0.107666016 1.857605 0.28042603q0.77090454 0.366745 1.2316895 1.1652069q0.4529724 0.78492737 0.39904785 1.7543335q-0.040405273 0.9616089 -0.6663208 1.8463745q-0.62594604 0.8847656 -1.6813354 1.4942932q-1.3530579 0.78144836 -2.486084 0.9125519q-1.1408386 0.11756897 -2.1214905 -0.3625946q-0.96713257 -0.48797607 -1.5721436 -1.4738007zm13.40155 -4.9431458l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396332q-0.6629944 0.3829193 -1.1454773 0.39089966q-0.4902649 -0.0055389404 -0.8343506 -0.25790405q-0.3518982 -0.26589966 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.9473114l0.8930054 
-0.5157623l-1.0308838 -1.786377l0.79599 -1.434082l1.4526367 2.5171661l1.2177734 -0.7033081l0.5466919 0.94732666l-1.2177734 0.70329285l2.4210815 4.195282q0.30456543 0.5278015 0.4446106 0.645401q0.15356445 0.109802246 0.35705566 0.11857605q0.19570923 -0.004760742 0.4663086 -0.16105652q0.20297241 -0.11721802 0.49645996 -0.35888672zm1.8165283 0.41241455l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.0178833 -1.0001068 0.19332886 -1.4468842q0.21121216 -0.44676208 0.64419556 -0.6968231q0.6088562 -0.35165405 1.471283 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765381 0.6577606q-0.17843628 0.40979004 -0.06384277 0.92100525q0.17193604 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm2.0899658 -6.006668q-1.1480408 -1.9893646 -0.5930481 -3.5910034q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210815q0.04135132 1.0407104 -0.49923706 1.948349q-0.5348511 0.88630676 -1.4819946 1.4333191q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621765zm1.2583313 -0.72673035q0.79663086 1.3803711 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18502808q0.9065552 -0.5235596 1.1036072 -1.5575867q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17720032q-0.90652466 0.5235748 -1.117096 1.5654144q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494751 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.522583 -4.371216q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078125 -0.8442688 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.264862 3.9246216l-1.2177429 0.70329285zm7.819275 -3.7220154l1.2922058 -0.511734q0.3878479 0.51579285 0.8666992 0.5640259q0.65527344 0.072631836 1.4400635 -0.38059998q0.8388977 -0.48449707 1.1036682 -1.0885162q0.26480103 -0.60401917 0.07571411 -1.306778q-0.1161499 -0.42010498 -0.81121826 -1.6245575q-0.25161743 1.408371 -1.4423218 2.096054q-1.4883423 0.85957336 -2.9171448 0.25932312q-1.428833 -0.60025024 -2.2957153 -2.1024323q-0.5935669 -1.0285187 -0.72805786 -2.1056366q-0.12097168 -1.0849152 0.30926514 -1.9649353q0.4437561 -0.8878174 1.3908997 -1.4348297q1.2718811 -0.7345581 2.690796 -0.182724l-0.4998474 -0.8661194l1.1230469 -0.6485901l3.5847168 6.2117157q0.9684448 1.6781158 1.0362854 2.5771942q0.067840576 0.8990936 -0.4420166 1.7348785q-0.5098877 0.8357849 -1.5923462 1.4609375q-1.2854004 0.7423706 -2.4195251 0.62150574q-1.1206055 -0.12869263 -1.7651672 -1.3081818zm-1.4765625 -4.9031525q0.81222534 1.4074402 1.7418518 1.7366486q0.94314575 0.32138062 1.7820435 -0.16311646q0.8388977 -0.4844818 1.0401001 -1.4487457q0.19342041 -0.97779846 -0.60317993 -2.3581848q-0.76538086 -1.3262482 -1.7298889 -1.6533508q-0.97229004 -0.3406372 -1.7976685 0.1360321q-0.8118286 0.46887207 -0.9974365 1.4602051q-0.1855774 0.991333 0.56417847 2.290512z" fill-rule="nonzero"/><path fill="#000000" d="m302.23132 207.68793l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80441284 1.3939056l-1.2177429 0.7033081zm4.920227 
8.525894l-4.147064 -7.1861115l1.2177734 -0.7033081l4.1470337 7.1861115l-1.2177429 0.7033081zm1.3493347 -3.6482391l1.0948792 -0.88494873q0.51641846 0.6760864 1.2029724 0.80285645q0.6922302 0.10542297 1.5176086 -0.3712616q0.8388672 -0.4844818 1.0495605 -1.057251q0.20285034 -0.58628845 -0.062683105 -1.0464172q-0.23431396 -0.4059906 -0.7266846 -0.4464264q-0.35079956 -0.013916016 -1.4790955 0.3129425q-1.5347595 0.43530273 -2.2030334 0.49645996q-0.66256714 0.03982544 -1.183075 -0.23695374q-0.5069885 -0.28459167 -0.8115845 -0.8123779q-0.28115845 -0.48719788 -0.2989807 -1.018219q-0.017791748 -0.5310211 0.2048645 -1.0204926q0.15917969 -0.3806305 0.56817627 -0.797287q0.4147339 -0.43800354 0.9694824 -0.75839233q0.852417 -0.49230957 1.6289368 -0.61598206q0.7765198 -0.123687744 1.3162842 0.123931885q0.5455017 0.22625732 1.0733948 0.85964966l-1.0969849 0.85006714q-0.4013672 -0.5079651 -0.9733887 -0.59262085q-0.5720215 -0.0846405 -1.2756042 0.3217163q-0.8388977 0.4844818 -1.0402222 0.9796753q-0.19558716 0.47383118 0.015289307 0.8392334q0.14056396 0.24359131 0.39874268 0.34710693q0.2581787 0.103500366 0.64749146 0.05909729q0.22845459 -0.041732788 1.2620544 -0.31388855q1.4806519 -0.40403748 2.127594 -0.470932q0.6390991 -0.08041382 1.1653442 0.17501831q0.5397949 0.24760437 0.89904785 0.87013245q0.35144043 0.60899353 0.29852295 1.3613129q-0.047210693 0.73095703 -0.5519409 1.4194183q-0.49118042 0.68063354 -1.3300476 1.1651306q-1.407196 0.81269836 -2.4736633 0.6527405q-1.0664673 -0.15994263 -1.933258 -1.193039zm6.1190186 -5.4646606q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47283936 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.85957336 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79663086 1.3803864 0.83795166 2.4210968q0.04135132 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902832 1.7267303q1.0071716 0.33854675 1.9136963 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.90652466 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.957733 -0.3738098l-5.7246704 -9.919815l1.2177734 -0.70329285l5.72464 9.9198l-1.2177429 0.7033081zm7.2713623 -5.390396q-0.34069824 0.972641 -0.8225403 1.5756989q-0.48962402 0.5895386 -1.2067566 1.0036926q-1.1906738 0.6876831 -2.162445 0.47302246q-0.9717407 -0.21464539 -1.4872131 -1.1078339q-0.30456543 -0.52778625 -0.3109436 -1.1015167q-0.014190674 -0.58724976 0.21627808 -1.0631866q0.23620605 -0.49728394 0.64520264 -0.9139252q0.31063843 -0.3057251 0.980896 -0.8010864q1.3733215 -1.02771 1.9363098 -1.6776581q-0.14837646 -0.25712585 -0.18740845 -0.32478333q-0.42956543 -0.74432373 -0.93963623 -0.84669495q-0.7156677 -0.14602661 -1.6357422 0.38536072q-0.8524475 0.49230957 -1.0922546 1.0458069q-0.23410034 0.5321655 0.013824463 1.3994293l-1.2843933 0.52526855q-0.2749939 -0.85162354 -0.18301392 -1.5362854q0.10549927 -0.6924591 0.66851807 -1.3424072q0.5552063 -0.66348267 1.4752808 -1.1948547q0.92007446 -0.5313873 1.6133118 -0.6430664q0.7067566 -0.11949158 1.1648254 0.04902649q0.45803833 0.16850281 0.8552551 0.60672q0.24728394 0.27218628 0.71588135 1.0841675l0.9371643 1.6239929q0.9840393 1.7051697 1.3094177 2.1126862q0.33892822 0.39971924 0.81103516 0.68640137l-1.2718506 0.7345581q-0.40811157 -0.26953125 -0.7590027 -0.75253296zm-1.6645203 
-2.6654663q-0.5067749 0.65356445 -1.7234497 1.6088562q-0.69522095 0.5458679 -0.9283142 0.8609314q-0.23312378 0.31506348 -0.25280762 0.6873169q-0.013977051 0.3508911 0.16564941 0.66215515q0.28115845 0.48719788 0.83392334 0.6010132q0.5662842 0.10598755 1.269867 -0.30036926q0.70358276 -0.40634155 1.072998 -1.0166473q0.37512207 -0.63165283 0.3197937 -1.3214569q-0.031341553 -0.5232086 -0.49990845 -1.3352051l-0.25775146 -0.44659424zm7.223419 -0.8156891l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25790405q-0.3518982 -0.26591492 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.94732666l0.8930054 -0.5157471l-1.0308838 -1.786377l0.7960205 -1.434082l1.4526367 2.5171661l1.2177429 -0.7033081l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.52778625 0.4446106 0.645401q0.15356445 0.10978699 0.35708618 0.11856079q0.19567871 -0.004760742 0.46627808 -0.16104126q0.20297241 -0.11721802 0.49645996 -0.35890198zm-3.0901794 -8.121277l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80444336 1.3939209l-1.2177734 0.70329285zm4.920227 8.525894l-4.1470337 -7.1861115l1.2177429 -0.70329285l4.147064 7.186096l-1.2177734 0.7033081zm0.54074097 -5.111908q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210968q0.0413208 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959412 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.9065552 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959473 -0.6329651l0.5857544 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494446 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.5226135 -4.371216q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3539734 -0.30078125 -0.8442383 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821808q-0.2735901 0.8075714 0.52301025 2.1879578l2.264862 3.9246216l-1.2177429 0.70329285z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m290.76047 207.94267c-17.003845 0 -26.795105 -5.566925 -34.00769 -11.133865c-7.212616 -5.566925 -11.846558 -11.13385 -23.693115 -11.13385" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m290.76047 207.94267c-17.003876 0 -26.795105 -5.5669403 -34.00769 -11.133865c-3.606308 -2.7834625 -6.5679474 -5.566925 -10.108826 -7.6545258c-0.4426117 -0.2609558 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.1181488 -0.34950256 -0.17605591l-0.13806152 -0.066833496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m245.48381 184.93082l-9.563843 1.350235l8.194244 5.1131744z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m280.4455 126.06866l129.5748 
-74.83465l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m298.03357 135.5387l-5.72464 -9.919807l1.3124695 -0.75800323l5.72464 9.9198l-1.3124695 0.75801086zm3.470581 -2.0044098l-4.1470337 -7.186104l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012146q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.1879654l2.2648926 3.9246216l-1.2177734 0.70329285zm12.360168 -7.1384735l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001297 -1.8564148 0.44062042q-0.9904785 -0.05949402 -1.8883972 -0.6776428q-0.88442993 -0.62597656 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 -2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816315 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898376 -2.1887817q-1.0465393 -1.813446 -0.69085693 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm10.854523 3.1318436l-5.740265 -9.946869l1.1094971 -0.6407852l0.5388794 0.9337845q0.07220459 -0.78147125 0.4031067 -1.333458q0.3444214 -0.5597992 1.0480042 -0.9661484q0.92007446 -0.5313797 1.8970032 -0.46406555q0.9769592 0.06730652 1.8285828 0.7302551q0.8573303 0.64160156 1.4508972 1.6701202q0.6404114 1.1097183 0.74212646 2.2238083q0.11526489 1.1062698 -0.3691101 2.01754q-0.4708252 0.90345 -1.3097229 1.3879471q-0.6088867 0.35164642 -1.2443542 0.37583923q-0.62197876 0.01637268 -1.159668 -0.19635773l2.022766 3.5050888l-1.2177429 0.7033005zm-2.551239 -6.952957q0.80441284 1.3939209 1.7418518 1.7366486q0.95095825 0.3349228 1.7492676 -0.12612915q0.8118286 -0.46886444 0.9953308 -1.495079q0.18353271 -1.026207 -0.6443176 -2.4607239q-0.79663086 -1.3803787 -1.7554016 -1.728836q-0.96658325 -0.36198425 -1.7513428 0.09125519q-0.77124023 0.4454193 -0.958374 1.5278625q-0.1736145 1.0746231 0.62298584 2.4550018zm12.239746 -5.408928l1.3442688 -0.57788086q0.34274292 1.2816391 -0.11764526 2.359497q-0.4468689 1.0700455 
… (remainder of the SVG vector path data elided) </g></svg>
\ No newline at end of file
diff --git a/g3doc/README.md b/g3doc/README.md
new file mode 100644
index 000000000..7956fe739
--- /dev/null
+++ b/g3doc/README.md
@@ -0,0 +1,168 @@
+# What is gVisor?
+
+gVisor is an application kernel, written in Go, that implements a substantial
+portion of the [Linux system call interface][linux]. It provides an additional
+layer of isolation between running applications and the host operating system.
+
+gVisor includes an [Open Container Initiative (OCI)][oci] runtime called `runsc`
+that makes it easy to work with existing container tooling. The `runsc` runtime
+integrates with Docker and Kubernetes, making it simple to run sandboxed
+containers.
+
+gVisor can be used with Docker, Kubernetes, or directly using `runsc`. Use the
+links below to see detailed instructions for each of them:
+
+* [Docker](./user_guide/quick_start/docker.md): The quickest and easiest way
+ to get started.
+* [Kubernetes](./user_guide/quick_start/kubernetes.md): Isolate Pods in your
+ K8s cluster with gVisor.
+* [OCI Quick Start](./user_guide/quick_start/oci.md): Expert mode. Customize
+ gVisor for your environment.
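+
+For example, once `runsc` is installed, the Docker path looks roughly like the
+sketch below. It assumes `runsc` lives at `/usr/local/bin/runsc` and that no
+other custom runtimes are configured; the quick start guide above covers the
+details.
+
+```shell
+# Register runsc as a Docker runtime. This overwrites an existing
+# /etc/docker/daemon.json, so merge by hand if you already have one.
+sudo tee /etc/docker/daemon.json <<'EOF'
+{
+  "runtimes": {
+    "runsc": {
+      "path": "/usr/local/bin/runsc"
+    }
+  }
+}
+EOF
+sudo systemctl restart docker
+
+# Run a container inside a gVisor sandbox.
+docker run --rm --runtime=runsc hello-world
+```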
+
+## What does gVisor do?
+
+gVisor provides a virtualized environment in order to sandbox containers. The
+system interfaces normally implemented by the host kernel are moved into a
+distinct, per-sandbox application kernel in order to minimize the risk of a
+container escape exploit. However, gVisor does not introduce large fixed
+overheads, and it still retains a process-like model with respect to resource
+utilization.
+
+## How is this different?
+
+Two other approaches are commonly taken to provide stronger isolation than
+native containers.
+
+**Machine-level virtualization**, such as [KVM][kvm] and [Xen][xen], exposes
+virtualized hardware to a guest kernel via a Virtual Machine Monitor (VMM). This
+virtualized hardware is generally enlightened (paravirtualized) and additional
+mechanisms can be used to improve the visibility between the guest and host
+(e.g. balloon drivers, paravirtualized spinlocks). Running containers in
+distinct virtual machines can provide great isolation, compatibility and
+performance (though nested virtualization may bring challenges in this area),
+but for containers it often requires additional proxies and agents, and may
+require a larger resource footprint and slower start-up times.
+
+![Machine-level virtualization](Machine-Virtualization.png "Machine-level virtualization")
+
+**Rule-based execution**, such as [seccomp][seccomp], [SELinux][selinux] and
+[AppArmor][apparmor], allows the specification of a fine-grained security policy
+for an application or container. These schemes typically rely on hooks
+implemented inside the host kernel to enforce the rules. If the surface can be
+made small enough, then this is an excellent way to sandbox applications and
+maintain native performance. However, in practice it can be extremely difficult
+(if not impossible) to reliably define a policy for arbitrary, previously
+unknown applications, making this approach challenging to apply universally.
+
+![Rule-based execution](Rule-Based-Execution.png "Rule-based execution")
+
+Rule-based execution is often combined with additional layers for
+defense-in-depth.
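+
+To make the policy-definition difficulty concrete, the hypothetical seccomp
+profile below denies every system call except a four-entry allowlist. Almost
+any real workload needs far more than this to even start, and enumerating the
+full set for an arbitrary, previously unknown application is exactly the hard
+part.
+
+```shell
+# Hypothetical Docker seccomp profile: deny everything except a tiny
+# allowlist. Expect most containers to fail under it; a usable profile
+# must cover every system call the runtime and application make.
+cat > allowlist.json <<'EOF'
+{
+  "defaultAction": "SCMP_ACT_ERRNO",
+  "syscalls": [
+    {
+      "names": ["read", "write", "exit", "exit_group"],
+      "action": "SCMP_ACT_ALLOW"
+    }
+  ]
+}
+EOF
+docker run --rm --security-opt seccomp=allowlist.json alpine echo hello
+```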
+
+**gVisor** provides a third isolation mechanism, distinct from those above.
+
+gVisor intercepts application system calls and acts as the guest kernel, without
+the need for translation through virtualized hardware. gVisor may be thought of
+as either a merged guest kernel and VMM, or as seccomp on steroids. This
+architecture allows it to provide a flexible resource footprint (i.e. one based
+on threads and memory mappings, not fixed guest physical resources) while also
+lowering the fixed costs of virtualization. However, this comes at the price of
+reduced application compatibility and higher per-system call overhead.
+
+![gVisor](Layers.png "gVisor")
+
+On top of this, gVisor employs rule-based execution to provide defense-in-depth
+(details below).
+
+gVisor's approach is similar to [User Mode Linux (UML)][uml], although UML
+virtualizes hardware internally and thus provides a fixed resource footprint.
+
+Each of the above approaches may excel in distinct scenarios. For example,
+machine-level virtualization will face challenges achieving high density, while
+gVisor may provide poor performance for system-call-heavy workloads.
+
+## Why Go?
+
+gVisor is written in [Go][golang] in order to avoid security pitfalls that can
+plague kernels. With Go, there are strong types, built-in bounds checks, no
+uninitialized variables, no use-after-free, no stack overflow, and a built-in
+race detector. However, the use of Go has its challenges, and the runtime often
+introduces performance overhead.
+
+## What are the different components?
+
+A gVisor sandbox consists of multiple processes. These processes collectively
+comprise an environment in which one or more containers can be run.
+
+Each sandbox has its own isolated instance of:
+
+* The **Sentry**, which is a kernel that runs the containers and intercepts
+ and responds to system calls made by the application.
+
+Each container running in the sandbox has its own isolated instance of:
+
+* A **Gofer**, which provides file system access to the containers.
+
+![gVisor architecture diagram](Sentry-Gofer.png "gVisor architecture diagram")
+
+## What is runsc?
+
+The entrypoint to running a sandboxed container is the `runsc` executable.
+`runsc` implements the [Open Container Initiative (OCI)][oci] runtime
+specification, which is used by Docker and Kubernetes. This means that
+OCI-compatible _filesystem bundles_ can be run by `runsc`. A filesystem bundle
+consists of a `config.json` file containing the container configuration and a
+root filesystem for the container. Please see the [OCI runtime spec][runtime-spec]
+for more information on filesystem bundles. `runsc` implements multiple commands
+that perform various functions such as starting, stopping, listing, and querying
+the status of containers.
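+
+For example, a bundle can be built from any existing Docker image and run
+directly with `runsc`. The sketch below assumes `runsc` is on your `PATH` and
+uses the `alpine` image, whose root filesystem provides the `sh` that the
+template `config.json` written by `runsc spec` runs by default; the OCI quick
+start above covers the details.
+
+```shell
+# Build an OCI filesystem bundle from a Docker image.
+mkdir -p bundle/rootfs && cd bundle
+docker export $(docker create alpine) | tar -xf - -C rootfs
+
+# Write a template config.json into the bundle, then start the container.
+runsc spec
+sudo runsc run alpine-demo
+```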
+
+### Sentry
+
+<a name="sentry"></a> <!-- For deep linking. -->
+
+The Sentry is the largest component of gVisor. It can be thought of as an
+application kernel. The Sentry implements all the kernel functionality needed by
+the application, including: system calls, signal delivery, memory management and
+page faulting logic, the threading model, and more.
+
+When the application makes a system call, the
+[Platform](./architecture_guide/platforms.md) redirects the call to the Sentry,
+which will do the necessary work to service it. It is important to note that the
+Sentry does not pass system calls through to the host kernel. As a userspace
+application, the Sentry will make some host system calls to support its
+operation, but it does not allow the application to directly control the system
+calls it makes. For example, the Sentry is not able to open files directly; file
+system operations that extend beyond the sandbox (not internal `/proc` files,
+pipes, etc.) are sent to the Gofer, described below.
+
+### Gofer
+
+<a name="gofer"></a> <!-- For deep linking. -->
+
+The Gofer is a standard host process which is started with each container and
+communicates with the Sentry via the [9P protocol][9p] over a socket or shared
+memory channel. The Sentry process is started in a restricted seccomp container
+without access to file system resources. The Gofer mediates all access to
+these resources, providing an additional level of isolation.
+
+### Application
+
+The application is a normal Linux binary provided to gVisor in an OCI runtime
+bundle. gVisor aims to provide an environment equivalent to Linux v4.4, so
+applications should be able to run unmodified. However, gVisor does not
+presently implement every system call, `/proc` file, or `/sys` file so some
+incompatibilities may occur. See [Compatibility](./user_guide/compatibility.md)
+for more information.
+
+[9p]: https://en.wikipedia.org/wiki/9P_(protocol)
+[apparmor]: https://wiki.ubuntu.com/AppArmor
+[golang]: https://golang.org
+[kvm]: https://www.linux-kvm.org
+[linux]: https://en.wikipedia.org/wiki/Linux_kernel_interfaces
+[oci]: https://www.opencontainers.org
+[runtime-spec]: https://github.com/opencontainers/runtime-spec
+[seccomp]: https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt
+[selinux]: https://selinuxproject.org
+[uml]: http://user-mode-linux.sourceforge.net/
+[xen]: https://www.xenproject.org
diff --git a/g3doc/Rule-Based-Execution.png b/g3doc/Rule-Based-Execution.png
new file mode 100644
index 000000000..b42654a90
--- /dev/null
+++ b/g3doc/Rule-Based-Execution.png
Binary files differ
diff --git a/g3doc/Rule-Based-Execution.svg b/g3doc/Rule-Based-Execution.svg
new file mode 100644
index 000000000..bd6717043
--- /dev/null
+++ b/g3doc/Rule-Based-Execution.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 355.03674540682414 172.5564304461942" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l355.03674 0l0 172.55643l-355.03674 0l0 -172.55643z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l355.03674 0l0 172.55643l-355.03674 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m78.206116 37.98824l5.125 -13.359373l1.90625 0l5.46875 13.359373l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703123q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921873zm9.849823 9.1875l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.891342 8.484375l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.844467 4.78125l0 -13.359373l1.640625 0l0 13.359373l-1.640625 0zm4.191696 -11.468748l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781231 0.515625 -2.749998q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.812498q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5624981q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 
-1.421875q-0.625 -0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.187498q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578123l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671873q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -9.999998l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm3.5354462 -4.84375q0 -2.687498 1.484375 -3.968748q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609373q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.7968731 -0.8125 -2.718748q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765623zm9.297592 4.84375l0 -9.671873l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.5937481l0 5.953125l-1.640625 0l0 -5.890625q0 -0.9999981 -0.203125 -1.4843731q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515623l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m3.6351707 71.39028l48.850395 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m3.6351707 71.39028l48.850395 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m195.25722 71.39028l47.338577 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m195.25722 71.39028l47.338577 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m52.485565 55.358784l142.77165 0l0 32.062992l-142.77165 0z" fill-rule="evenodd"/><path fill="#000000" d="m65.21821 76.19028l0 -9.546875l1.265625 0l0 8.421875l4.703125 0l0 1.125l-5.96875 0zm7.3343506 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm2.945465 0l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 
-0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm11.118057 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm5.507965 -1.046875l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm11.006226 4.125l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm9.865463 1.390625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm7.0859375 4.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 
1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8124924 0 1.2031174 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1874924 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8124924 0 1.4218674 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0624924 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.2499924 0.328125 1.7343674 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.4531174 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.695305 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321045 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.40625 0 -2.28125 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.65625 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.453125 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.40625 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 
0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m36.454067 85.4808l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 85.4808l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m76.63558 116.82599l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 
-2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.6249924 -6.625l2.390625 0l-5.5937424 5.421875l5.8437424 7.9375l-2.328125 0l-4.7656174 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943565 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.454068 
151.90599l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.454068 151.90599l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m172.45407 151.90599l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m172.45407 151.90599l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m73.43044 135.54013l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m96.04542 156.37163l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.0937424 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.4062424 5.3125l-1.21875 0zm12.859535 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59376526 0.21875 -1.2812653 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.4218903 -0.171875 2.0937653 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.3437653 
-0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.89064026 0 1.4375153 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.9218903 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.2031403 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.54555197232122 0.0 0.0 4.54555197232122 0.0 0.0)" spreadMethod="pad" x1="8.189483259998303" y1="18.80511284496466" x2="8.189466907412452" y2="23.35066481725647"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m37.225723 85.48025l173.29134 0l0 20.661415l-173.29134 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m272.4455 100.54161l129.5748 -74.83464l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m287.51392 107.20854l1.1823425 -0.8271866q0.51071167 0.6974335 1.1166077 0.9970856q0.5980835 0.28610992 1.4464111 0.19311523q0.84054565 -0.10652161 1.6794434 -0.5910187q0.75772095 -0.4376068 1.2010193 -0.98233795q0.44906616 -0.5660858 0.50097656 -1.1013031q0.057678223 -0.55656433 -0.20785522 -1.0166931q-0.27334595 -0.47366333 -0.7392273 -0.6557007q-0.47366333 -0.1955719 -1.2366333 -0.079704285q-0.478302 0.07775116 -2.032318 0.54221344q-1.5618286 0.45092773 -2.2805786 0.48712158q-0.9222717 0.027420044 -1.5864563 -0.31072998q-0.6719971 -0.35168457 -1.0703125 -1.0418777q-0.4295349 -0.74432373 -0.38497925 -1.6361618q0.05029297 -0.9131851 0.6668701 -1.7203751q0.63012695 -0.8150101 1.6313782 -1.39328q1.1095276 -0.6407852 2.1592712 -0.75988007q1.0419617 -0.13263702 1.8867493 0.29968262q0.8583679 0.4244995 1.3987732 1.2671738l-1.2036743 0.8214798q-0.64712524 -0.87127686 -1.5022583 -1.0089188q-0.8629761 -0.15116882 -2.013092 0.5130539q-1.1906738 0.6876755 -1.4819641 1.4333115q-0.2913208 0.745636 0.06793213 1.3681641q0.30459595 0.5277939 0.8865051 0.6608505q0.5740967 
0.119522095 2.3815613 -0.43717957q1.8210144 -0.5645218 2.5725403 -0.6376953q1.0924377 -0.107666016 1.857605 0.28042603q0.77090454 0.366745 1.2316895 1.1652069q0.4529724 0.78491974 0.39904785 1.7543335q-0.040405273 0.96160126 -0.6663208 1.8463745q-0.62594604 0.8847656 -1.6813354 1.4942932q-1.3530579 0.78144073 -2.486084 0.9125519q-1.1408386 0.11756897 -2.1214905 -0.3625946q-0.96713257 -0.48797607 -1.5721436 -1.4738007zm13.40155 -4.9431534l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454773 0.39089966q-0.4902649 -0.00554657 -0.8343506 -0.25791168q-0.3518982 -0.2659073 -0.9844971 -1.3620911l-2.382019 -4.127617l-0.8930054 0.5157547l-0.5466919 -0.94731903l0.8930054 -0.5157547l-1.0308838 -1.786377l0.79599 -1.4340897l1.4526367 2.5171661l1.2177734 -0.70329285l0.5466919 0.94731903l-1.2177734 0.7033005l2.4210815 4.1952744q0.30456543 0.5277939 0.4446106 0.645401q0.15356445 0.10979462 0.35705566 0.11856842q0.19570923 -0.004760742 0.4663086 -0.16104889q0.20297241 -0.11721802 0.49645996 -0.35889435zm1.8165283 0.41242218l-4.147064 -7.186104l1.0959778 -0.6329727l0.6247864 1.0826492q-0.0178833 -1.0000992 0.19332886 -1.4468689q0.21121216 -0.44677734 0.64419556 -0.6968384q0.6088562 -0.35164642 1.471283 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862885q-0.39239502 0.2266159 -0.5765381 0.6577606q-0.17843628 0.40979004 -0.06384277 0.9209976q0.17193604 0.76680756 0.61709595 1.5382004l2.1711426 3.7622147l-1.2177429 0.7033005zm2.0899658 -6.0066605q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47280884 -1.3376465 1.7988281 -2.1034622q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384125 2.4934998 2.132553q0.79660034 1.3803787 0.83795166 2.4210892q0.04135132 1.0407028 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.4351883 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803787 1.7902527 1.726738q1.0072021 0.33853912 1.9137268 -0.18502045q0.9065552 -0.5235672 1.1036072 -1.5575943q0.21057129 -1.0418396 -0.60946655 -2.462822q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635162 -1.9002075 0.17721558q-0.90652466 0.5235672 -1.117096 1.5654068q-0.2048645 1.0204926 0.59173584 2.4008713zm8.984772 -0.38944244l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.4253464l-1.2177429 0.7033005l-2.522583 -4.371216q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078125 -0.8442688 -0.30632782q-0.47677612 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.45323944 -1.0640869 1.2821732q-0.2735901 0.80757904 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005zm7.819275 -3.7220154l1.2922058 -0.511734q0.3878479 0.5157852 0.8666992 0.56401825q0.65527344 0.072639465 1.4400635 -0.38059235q0.8388977 -0.48449707 1.1036682 -1.0885162q0.26480103 -0.60401917 0.07571411 -1.3067856q-0.1161499 -0.42009735 -0.81121826 -1.6245499q-0.25161743 1.408371 -1.4423218 2.0960464q-1.4883423 0.859581 -2.9171448 0.25933075q-1.428833 -0.60025024 -2.2957153 -2.1024323q-0.5935669 -1.0285187 -0.72805786 -2.1056366q-0.12097168 -1.0849228 0.30926514 -1.9649353q0.4437561 -0.887825 1.3908997 -1.4348297q1.2718811 -0.7345581 2.690796 -0.18271637l-0.4998474 -0.866127l1.1230469 -0.6485977l3.5847168 6.2117233q0.9684448 1.6781082 1.0362854 
2.5771942q0.067840576 0.899086 -0.4420166 1.7348709q-0.5098877 0.83579254 -1.5923462 1.4609451q-1.2854004 0.7423706 -2.4195251 0.6214981q-1.1206055 -0.12869263 -1.7651672 -1.3081741zm-1.4765625 -4.90316q0.81222534 1.4074478 1.7418518 1.7366486q0.94314575 0.32138824 1.7820435 -0.16310883q0.8388977 -0.48448944 1.0401001 -1.4487534q0.19342041 -0.97779846 -0.60317993 -2.3581848q-0.76538086 -1.3262482 -1.7298889 -1.6533508q-0.97229004 -0.3406372 -1.7976685 0.13603973q-0.8118286 0.46886444 -0.9974365 1.4601974q-0.1855774 0.991333 0.56417847 2.290512z" fill-rule="nonzero"/><path fill="#000000" d="m294.23132 118.16088l-0.80441284 -1.3939133l1.2177429 -0.7033005l0.80441284 1.3939209l-1.2177429 0.70329285zm4.920227 8.525894l-4.147064 -7.1861115l1.2177734 -0.70329285l4.1470337 7.186104l-1.2177429 0.7033005zm1.3493347 -3.6482391l1.0948792 -0.88494873q0.51641846 0.67609406 1.2029724 0.8028641q0.6922302 0.10542297 1.5176086 -0.37125397q0.8388672 -0.48449707 1.0495605 -1.0572586q0.20285034 -0.5862961 -0.062683105 -1.0464249q-0.23431396 -0.4059906 -0.7266846 -0.44641113q-0.35079956 -0.013923645 -1.4790955 0.31292725q-1.5347595 0.43530273 -2.2030334 0.4964676q-0.66256714 0.03981781 -1.183075 -0.23695374q-0.5069885 -0.28459167 -0.8115845 -0.8123779q-0.28115845 -0.48719788 -0.2989807 -1.0182266q-0.017791748 -0.5310211 0.2048645 -1.0204926q0.15917969 -0.3806305 0.56817627 -0.79727936q0.4147339 -0.43800354 0.9694824 -0.75839233q0.852417 -0.49230957 1.6289368 -0.6159897q0.7765198 -0.123680115 1.3162842 0.123931885q0.5455017 0.22625732 1.0733948 0.8596573l-1.0969849 0.85006714q-0.4013672 -0.5079727 -0.9733887 -0.59262085q-0.5720215 -0.0846405 -1.2756042 0.32170868q-0.8388977 0.48449707 -1.0402222 0.9796829q-0.19558716 0.47383118 0.015289307 0.8392334q0.14056396 0.24359131 0.39874268 0.3470993q0.2581787 0.103507996 0.64749146 0.05910492q0.22845459 -0.041732788 1.2620544 -0.31388855q1.4806519 -0.4040451 2.127594 -0.470932q0.6390991 -0.08041382 1.1653442 0.17501068q0.5397949 0.247612 0.89904785 0.87013245q0.35144043 0.60899353 0.29852295 1.3613129q-0.047210693 0.73096466 -0.5519409 1.4194183q-0.49118042 0.68063354 -1.3300476 1.1651306q-1.407196 0.81269836 -2.4736633 0.65275574q-1.0664673 -0.15995026 -1.933258 -1.1930542zm6.1190186 -5.4646606q-1.1480408 -1.9893723 -0.5930481 -3.591011q0.47283936 -1.3376541 1.7988281 -2.1034698q1.4883423 -0.859581 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.1325607q0.79663086 1.3803787 0.83795166 2.4210815q0.04135132 1.0407028 -0.49923706 1.948349q-0.5348511 0.88629913 -1.4819946 1.4333115q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.262169zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902832 1.7267456q1.0071716 0.33853912 1.9136963 -0.18502808q0.9065552 -0.5235672 1.1036072 -1.5575943q0.21057129 -1.0418396 -0.60946655 -2.462822q-0.76538086 -1.3262482 -1.7725525 -1.6647873q-0.99365234 -0.34635925 -1.9002075 0.17720795q-0.90652466 0.5235672 -1.117096 1.5654068q-0.2048645 1.0204926 0.59173584 2.4008713zm8.957733 -0.3738098l-5.7246704 -9.919807l1.2177734 -0.7033005l5.72464 9.919807l-1.2177429 0.7033005zm7.2713623 -5.390396q-0.34069824 0.9726486 -0.8225403 1.5757141q-0.48962402 0.5895233 -1.2067566 1.003685q-1.1906738 0.6876755 -2.162445 0.47302246q-0.9717407 -0.21464539 -1.4872131 -1.1078339q-0.30456543 -0.5277939 -0.3109436 -1.1015167q-0.014190674 -0.58724976 0.21627808 -1.0631866q0.23620605 -0.4972763 0.64520264 -0.9139328q0.31063843 -0.30571747 0.980896 -0.8010864q1.3733215 -1.0277023 1.9363098 -1.6776505q-0.14837646 -0.25712585 -0.18740845 
-0.32479095q-0.42956543 -0.74432373 -0.93963623 -0.84669495q-0.7156677 -0.14601898 -1.6357422 0.38536072q-0.8524475 0.49230957 -1.0922546 1.0458145q-0.23410034 0.5321655 0.013824463 1.3994217l-1.2843933 0.5252762q-0.2749939 -0.85163116 -0.18301392 -1.5362854q0.10549927 -0.6924591 0.66851807 -1.3424072q0.5552063 -0.66348267 1.4752808 -1.1948624q0.92007446 -0.5313797 1.6133118 -0.6430588q0.7067566 -0.11948395 1.1648254 0.04901886q0.45803833 0.16851044 0.8552551 0.60672q0.24728394 0.27218628 0.71588135 1.0841827l0.9371643 1.6239777q0.9840393 1.7051773 1.3094177 2.1127014q0.33892822 0.39970398 0.81103516 0.6863861l-1.2718506 0.7345581q-0.40811157 -0.26953125 -0.7590027 -0.75253296zm-1.6645203 -2.6654587q-0.5067749 0.65356445 -1.7234497 1.6088486q-0.69522095 0.5458679 -0.9283142 0.8609314q-0.23312378 0.31506348 -0.25280762 0.6873169q-0.013977051 0.35090637 0.16564941 0.6621628q0.28115845 0.48719788 0.83392334 0.60100555q0.5662842 0.10598755 1.269867 -0.30036163q0.70358276 -0.40634918 1.072998 -1.016655q0.37512207 -0.63165283 0.3197937 -1.3214569q-0.031341553 -0.5232086 -0.49990845 -1.3351974l-0.25775146 -0.44659424zm7.223419 -0.8156891l0.8006897 0.98106384q-0.45169067 0.40522003 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25791168q-0.3518982 -0.26589966 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.94731903l0.8930054 -0.5157471l-1.0308838 -1.786377l0.7960205 -1.4340897l1.4526367 2.5171661l1.2177429 -0.7033005l0.5466919 0.94731903l-1.2177429 0.7033005l2.421051 4.195282q0.30459595 0.5277939 0.4446106 0.645401q0.15356445 0.10978699 0.35708618 0.11856079q0.19567871 -0.004760742 0.46627808 -0.16104889q0.20297241 -0.11721039 0.49645996 -0.35889435zm-3.0901794 -8.121277l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80444336 1.3939133l-1.2177734 0.7033005zm4.920227 8.525887l-4.1470337 -7.186104l1.2177429 -0.7033005l4.147064 7.186104l-1.2177734 0.7033005zm0.54074097 -5.111908q-1.1480408 -1.9893799 -0.5930481 -3.591011q0.47280884 -1.3376541 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384125 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210892q0.0413208 1.0407028 -0.49923706 1.948349q-0.5348511 0.88629913 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44574738q-1.4959412 -0.43519592 -2.5502625 -2.262169zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902527 1.726738q1.0072021 0.33854675 1.9137268 -0.18502045q0.9065552 -0.5235672 1.1036072 -1.5575943q0.21057129 -1.0418396 -0.60946655 -2.462822q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635162 -1.9002075 0.17721558q-0.9065552 0.5235672 -1.117096 1.5654068q-0.2048645 1.0204926 0.59173584 2.4008713zm8.984772 -0.38944244l-4.1470337 -7.1861115l1.0959473 -0.6329651l0.5857544 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494446 -0.37509918 1.3234558 -0.45761108q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26287842 0.29925537 0.74710083 1.1383133l2.553833 4.4253387l-1.2177429 0.7033005l-2.5226135 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442383 -0.3063202q-0.47677612 -0.01335907 -0.9638672 0.2679596q-0.7847595 0.45323944 -1.0640869 1.2821732q-0.2735901 0.80757904 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m282.76047 118.41563c-17.003845 0 -26.795105 -5.566925 -34.007706 -11.133858c-7.2126007 -5.566925 -11.846542 -11.133858 -23.6931 
-11.133858" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.76047 118.415634c-17.003876 0 -26.795105 -5.5669327 -34.007706 -11.133858c-3.6062927 -2.7834702 -6.567932 -5.5669327 -10.10881 -7.6545334c-0.4426117 -0.26094818 -0.8942871 -0.51101685 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.11816406 -0.34950256 -0.17607117l-0.13806152 -0.06682587" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m237.48381 95.40378l-9.563843 1.350235l8.194244 5.1131744z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/g3doc/Sentry-Gofer.png b/g3doc/Sentry-Gofer.png
new file mode 100644
index 000000000..ca2c27ef7
--- /dev/null
+++ b/g3doc/Sentry-Gofer.png
Binary files differ
diff --git a/g3doc/Sentry-Gofer.svg b/g3doc/Sentry-Gofer.svg
new file mode 100644
index 000000000..5c10750d2
--- /dev/null
+++ b/g3doc/Sentry-Gofer.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 358.8556430446194 249.67191601049868" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l358.85565 0l0 249.67192l-358.85565 0l0 -249.67192z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l358.85565 0l0 249.67192l-358.85565 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m36.454067 6.6430445l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 6.6430445l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path fill="#000000" d="m48.00139 37.98824l5.125 -13.359373l1.90625 0l5.46875 13.359373l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703123q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921873zm9.849823 9.1875l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.891342 8.484375l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.844467 4.78125l0 -13.359373l1.640625 0l0 13.359373l-1.640625 0zm4.191696 -11.468748l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781231 0.515625 -2.749998q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.812498q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5624981q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 -1.421875q-0.625 
-0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.187498q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578123l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671873q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -9.999998l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm3.5354462 -4.84375q0 -2.687498 1.484375 -3.968748q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609373q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.7968731 -0.8125 -2.718748q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765623zm9.297592 4.84375l0 -9.671873l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.5937481l0 5.953125l-1.640625 0l0 -5.890625q0 -0.9999981 -0.203125 -1.4843731q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515623l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.4540663 73.055115l40.47244 0.25196838" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.4540663 73.055115l40.47244 0.25196838" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m144.0 74.0l35.27559 -0.94488525" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m144.0 74.0l35.27559 -0.94488525" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m44.926506 57.27559l97.48032 0l0 32.062996l-97.48032 0z" fill-rule="evenodd"/><path fill="#000000" d="m56.859642 75.044586l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 
1.859375 0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 
-0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321045 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.40625 0 -2.28125 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.65625 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.453125 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.40625 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125076 0 1.2031326 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875076 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125076 0 1.4218826 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625076 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.2500076 0.328125 1.7343826 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.4531326 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#d9d2e9" d="m36.454067 87.40682l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path stroke="#8e7cc3" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 87.40682l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path 
fill="#000000" d="m67.55086 114.45515l1.65625 -0.140625q0.125 1.0 0.546875 1.640625q0.4375 0.640625 1.34375 1.046875q0.921875 0.390625 2.0625 0.390625q1.0 0 1.78125 -0.296875q0.78125 -0.296875 1.15625 -0.8125q0.375 -0.53125 0.375 -1.15625q0 -0.625 -0.375 -1.09375q-0.359375 -0.46875 -1.1875 -0.796875q-0.546875 -0.203125 -2.390625 -0.640625q-1.828125 -0.453125 -2.5625 -0.84375q-0.96875 -0.5 -1.4375 -1.234375q-0.46875 -0.75 -0.46875 -1.671875q0 -1.0 0.578125 -1.875q0.578125 -0.890625 1.671875 -1.34375q1.109375 -0.453125 2.453125 -0.453125q1.484375 0 2.609375 0.484375q1.140625 0.46875 1.75 1.40625q0.609375 0.921875 0.65625 2.09375l-1.6875 0.125q-0.140625 -1.265625 -0.9375 -1.90625q-0.78125 -0.65625 -2.3125 -0.65625q-1.609375 0 -2.34375 0.59375q-0.734375 0.59375 -0.734375 1.421875q0 0.71875 0.53125 1.171875q0.5 0.46875 2.65625 0.96875q2.15625 0.484375 2.953125 0.84375q1.171875 0.53125 1.71875 1.359375q0.5625 0.828125 0.5625 1.90625q0 1.0625 -0.609375 2.015625q-0.609375 0.9375 -1.75 1.46875q-1.140625 0.515625 -2.578125 0.515625q-1.8125 0 -3.046875 -0.53125q-1.21875 -0.53125 -1.921875 -1.59375q-0.6875 -1.0625 -0.71875 -2.40625zm19.459198 1.1875l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.141342 5.765625l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm13.953842 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.5895538 1.46875l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.150177 3.71875l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m6.8477693 152.91733l48.85039 0" fill-rule="evenodd"/><path stroke="#ff0000" 
stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m6.8477693 152.91733l48.85039 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m145.08398 152.91733l33.95276 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m145.08398 152.91733l33.95276 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m55.698162 136.88583l89.38582 0l0 32.06299l-89.38582 0z" fill-rule="evenodd"/><path fill="#000000" d="m81.92428 150.43732l0 -8.59375l1.140625 0l0 7.578125l4.234375 0l0 1.015625l-5.375 0zm6.595703 -7.375l0 -1.21875l1.0625 0l0 1.21875l-1.0625 0zm0 7.375l0 -6.21875l1.0625 0l0 6.21875l-1.0625 0zm2.6660156 0l0 -6.21875l0.9375 0l0 0.875q0.296875 -0.46875 0.78125 -0.734375q0.484375 -0.28125 1.109375 -0.28125q0.6875 0 1.125 0.28125q0.453125 0.28125 0.625 0.796875q0.75 -1.078125 1.921875 -1.078125q0.9375 0 1.421875 0.515625q0.5 0.5 0.5 1.578125l0 4.265625l-1.046875 0l0 -3.921875q0 -0.625 -0.109375 -0.90625q-0.09375 -0.28125 -0.359375 -0.453125q-0.265625 -0.171875 -0.640625 -0.171875q-0.65625 0 -1.09375 0.4375q-0.421875 0.4375 -0.421875 1.40625l0 3.609375l-1.0625 0l0 -4.046875q0 -0.703125 -0.265625 -1.046875q-0.25 -0.359375 -0.828125 -0.359375q-0.453125 0 -0.828125 0.234375q-0.375 0.234375 -0.546875 0.6875q-0.171875 0.453125 -0.171875 1.296875l0 3.234375l-1.046875 0zm9.996094 -7.375l0 -1.21875l1.0625 0l0 1.21875l-1.0625 0zm0 7.375l0 -6.21875l1.0625 0l0 6.21875l-1.0625 0zm4.9628906 -0.9375l0.15625 0.921875q-0.453125 0.09375 -0.796875 0.09375q-0.578125 0 -0.890625 -0.171875q-0.3125 -0.1875 -0.453125 -0.484375q-0.125 -0.296875 -0.125 -1.25l0 -3.578125l-0.765625 0l0 -0.8125l0.765625 0l0 -1.546875l1.046875 -0.625l0 2.171875l1.0625 0l0 0.8125l-1.0625 0l0 3.640625q0 0.453125 0.046875 0.578125q0.0625 0.125 0.1875 0.203125q0.125 0.078125 0.359375 0.078125q0.1875 0 0.46875 -0.03125zm5.2871094 -1.0625l1.09375 0.125q-0.25 0.953125 -0.953125 1.484375q-0.703125 0.53125 -1.78125 0.53125q-1.359375 0 -2.171875 -0.84375q-0.796875 -0.84375 -0.796875 -2.359375q0 -1.5625 0.8125 -2.421875q0.8125 -0.875 2.09375 -0.875q1.25 0 2.03125 0.84375q0.796875 0.84375 0.796875 2.390625q0 0.09375 0 0.28125l-4.640625 0q0.0625 1.03125 0.578125 1.578125q0.515625 0.53125 1.296875 0.53125q0.578125 0 0.984375 -0.296875q0.421875 -0.3125 0.65625 -0.96875zm-3.453125 -1.703125l3.46875 0q-0.0625 -0.796875 -0.390625 -1.1875q-0.515625 -0.609375 -1.3125 -0.609375q-0.734375 0 -1.234375 0.484375q-0.484375 0.484375 -0.53125 1.3125zm9.908203 3.703125l0 -0.78125q-0.59375 0.921875 -1.734375 0.921875q-0.75 0 -1.375 -0.40625q-0.625 -0.421875 -0.96875 -1.15625q-0.34375 -0.734375 -0.34375 -1.6875q0 -0.921875 0.3125 -1.6875q0.3125 -0.765625 0.9375 -1.15625q0.625 -0.40625 1.390625 -0.40625q0.5625 0 1.0 0.234375q0.4375 0.234375 0.71875 0.609375l0 -3.078125l1.046875 0l0 8.59375l-0.984375 0zm-3.328125 -3.109375q0 1.203125 0.5 1.796875q0.5 0.578125 1.1875 0.578125q0.6875 0 1.171875 -0.5625q0.484375 -0.5625 0.484375 -1.71875q0 -1.28125 -0.5 -1.875q-0.484375 -0.59375 -1.203125 -0.59375q-0.703125 0 -1.171875 0.578125q-0.46875 0.5625 -0.46875 1.796875z" fill-rule="nonzero"/><path fill="#000000" d="m68.0942 162.57794l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 
-0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5zm6.375 4.25l-0.125 -0.984375q0.34375 0.09375 0.609375 0.09375q0.34375 0 0.546875 -0.125q0.21875 -0.109375 0.359375 -0.3125q0.09375 -0.171875 0.328125 -0.796875q0.015625 -0.078125 0.09375 -0.25l-2.375 -6.234375l1.140625 0l1.296875 3.59375q0.25 0.6875 0.453125 1.453125q0.1875 -0.734375 0.4375 -1.421875l1.328125 -3.625l1.046875 0l-2.359375 6.328125q-0.390625 1.015625 -0.59375 1.40625q-0.28125 0.53125 -0.65625 0.765625q-0.359375 0.25 -0.859375 0.25q-0.296875 0 -0.671875 -0.140625zm5.625 -4.25l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5zm8.71875 0.921875l0.15625 0.921875q-0.453125 0.09375 -0.796875 0.09375q-0.578125 0 -0.890625 -0.171875q-0.3125 -0.1875 -0.453125 -0.484375q-0.125 -0.296875 -0.125 -1.25l0 -3.578125l-0.765625 0l0 -0.8125l0.765625 0l0 -1.546875l1.046875 -0.625l0 2.171875l1.0625 0l0 0.8125l-1.0625 0l0 3.640625q0 0.453125 0.046875 0.578125q0.0625 0.125 0.1875 0.203125q0.125 0.078125 0.359375 0.078125q0.1875 0 0.46875 -0.03125zm5.2871094 -1.0625l1.09375 0.125q-0.25 0.953125 -0.953125 1.484375q-0.703125 0.53125 -1.78125 0.53125q-1.359375 0 -2.171875 -0.84375q-0.796875 -0.84375 -0.796875 -2.359375q0 -1.5625 0.8125 -2.421875q0.8125 -0.875 2.09375 -0.875q1.25 0 2.03125 0.84375q0.796875 0.84375 0.796875 2.390625q0 0.09375 0 0.28125l-4.640625 0q0.0625 1.03125 0.578125 1.578125q0.515625 0.53125 1.296875 0.53125q0.578125 0 0.984375 -0.296875q0.421875 -0.3125 0.65625 -0.96875zm-3.453125 -1.703125l3.46875 0q-0.0625 -0.796875 -0.390625 -1.1875q-0.515625 -0.609375 -1.3125 -0.609375q-0.734375 0 -1.234375 0.484375q-0.484375 0.484375 -0.53125 1.3125zm5.876953 3.703125l0 -6.21875l0.9375 0l0 0.875q0.296875 -0.46875 0.78125 -0.734375q0.484375 -0.28125 1.109375 -0.28125q0.6875 0 1.125 0.28125q0.453125 0.28125 0.625 0.796875q0.75 -1.078125 1.921875 -1.078125q0.9375 0 1.421875 0.515625q0.5 0.5 0.5 1.578125l0 4.265625l-1.046875 0l0 -3.921875q0 -0.625 -0.109375 -0.90625q-0.09375 -0.28125 -0.359375 -0.453125q-0.265625 
-0.171875 -0.640625 -0.171875q-0.65625 0 -1.09375 0.4375q-0.421875 0.4375 -0.421875 1.40625l0 3.609375l-1.0625 0l0 -4.046875q0 -0.703125 -0.265625 -1.046875q-0.25 -0.359375 -0.828125 -0.359375q-0.453125 0 -0.828125 0.234375q-0.375 0.234375 -0.546875 0.6875q-0.171875 0.453125 -0.171875 1.296875l0 3.234375l-1.046875 0zm17.392578 -2.28125l1.03125 0.140625q-0.171875 1.0625 -0.875 1.671875q-0.703125 0.609375 -1.71875 0.609375q-1.28125 0 -2.0625 -0.828125q-0.765625 -0.84375 -0.765625 -2.40625q0 -1.0 0.328125 -1.75q0.34375 -0.765625 1.015625 -1.140625q0.6875 -0.375 1.5 -0.375q1.0 0 1.640625 0.515625q0.65625 0.5 0.84375 1.453125l-1.03125 0.15625q-0.140625 -0.625 -0.515625 -0.9375q-0.375 -0.328125 -0.90625 -0.328125q-0.796875 0 -1.296875 0.578125q-0.5 0.5625 -0.5 1.796875q0 1.265625 0.484375 1.828125q0.484375 0.5625 1.25 0.5625q0.625 0 1.03125 -0.375q0.421875 -0.375 0.546875 -1.171875zm6.0000076 1.515625q-0.5937576 0.5 -1.1406326 0.703125q-0.53125 0.203125 -1.15625 0.203125q-1.03125 0 -1.578125 -0.5q-0.546875 -0.5 -0.546875 -1.28125q0 -0.453125 0.203125 -0.828125q0.203125 -0.390625 0.546875 -0.609375q0.34375 -0.234375 0.765625 -0.34375q0.296875 -0.09375 0.9375 -0.171875q1.265625 -0.140625 1.8750076 -0.359375q0 -0.21875 0 -0.265625q0 -0.65625 -0.29688263 -0.921875q-0.40625 -0.34375 -1.203125 -0.34375q-0.734375 0 -1.09375 0.265625q-0.359375 0.25 -0.53125 0.90625l-1.03125 -0.140625q0.140625 -0.65625 0.46875 -1.0625q0.328125 -0.40625 0.9375 -0.625q0.609375 -0.21875 1.40625 -0.21875q0.796875 0 1.2968826 0.1875q0.5 0.1875 0.734375 0.46875q0.234375 0.28125 0.328125 0.71875q0.046875 0.265625 0.046875 0.96875l0 1.40625q0 1.46875 0.0625 1.859375q0.078125 0.390625 0.28125 0.75l-1.109375 0q-0.15625 -0.328125 -0.203125 -0.765625zm-0.09375 -2.359375q-0.5781326 0.234375 -1.7187576 0.40625q-0.65625 0.09375 -0.921875 0.21875q-0.265625 0.109375 -0.421875 0.328125q-0.140625 0.21875 -0.140625 0.5q0 0.421875 0.3125 0.703125q0.328125 0.28125 0.9375 0.28125q0.609375 0 1.078125 -0.265625q0.484375 -0.265625 0.703125 -0.734375q0.17188263 -0.359375 0.17188263 -1.046875l0 -0.390625zm2.6738281 3.125l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.6660156 0l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.2753906 -1.859375l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m89.902885 171.04015l174.83463 0l0 48.850388l-174.83463 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m89.902885 171.04015l174.83463 0l0 48.850388l-174.83463 0z" fill-rule="evenodd"/><path fill="#000000" d="m130.0844 
202.38535l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.625 -6.625l2.390625 0l-5.59375 5.421875l5.84375 7.9375l-2.328125 0l-4.765625 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943573 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 
-0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m9.1482525 232.87665l117.79527 -0.34646606" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m9.1482525 232.87665l117.79527 -0.34646606" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.43172 232.53018l120.34645 0.34646606" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.43172 232.53018l120.34645 0.34646606" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m126.94353 216.49869l106.48819 0l0 32.06299l-106.48819 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.36165 237.33018l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 
0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.09375 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.40625 5.3125l-1.21875 0zm12.859528 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.500288203456436 0.0 0.0 4.500288203456436 0.0 0.0)" spreadMethod="pad" x1="20.153856515233464" y1="38.20913288577608" x2="20.153840567727556" y2="42.70942108920426"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m90.698166 171.95273l173.29134 0l0 20.251968l-173.29134 0z" fill-rule="evenodd"/><path fill="#d9d2e9" d="m203.76447 87.804726l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path stroke="#8e7cc3" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m203.76447 87.804726l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path fill="#000000" 
d="m245.33514 113.91555l0 -1.578125l5.65625 0l0 4.953125q-1.296875 1.046875 -2.6875 1.578125q-1.375 0.515625 -2.84375 0.515625q-1.96875 0 -3.578125 -0.84375q-1.609375 -0.84375 -2.421875 -2.4375q-0.8125 -1.59375 -0.8125 -3.5625q0 -1.953125 0.8125 -3.640625q0.8125 -1.6875 2.34375 -2.5q1.53125 -0.828125 3.515625 -0.828125q1.453125 0 2.625 0.46875q1.171875 0.46875 1.828125 1.3125q0.671875 0.828125 1.015625 2.171875l-1.59375 0.4375q-0.296875 -1.015625 -0.75 -1.59375q-0.4375 -0.59375 -1.265625 -0.9375q-0.828125 -0.34375 -1.84375 -0.34375q-1.203125 0 -2.09375 0.375q-0.890625 0.359375 -1.4375 0.96875q-0.53125 0.59375 -0.828125 1.3125q-0.515625 1.234375 -0.515625 2.6875q0 1.78125 0.609375 2.984375q0.625 1.203125 1.796875 1.796875q1.171875 0.578125 2.5 0.578125q1.140625 0 2.234375 -0.4375q1.09375 -0.453125 1.65625 -0.953125l0 -2.484375l-3.921875 0zm7.448929 0.390625q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.0468597 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.0312347 0 -3.2812347 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.81248474 0.921875 2.0468597 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.0468597 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.688217 4.84375l0 -8.40625l-1.453125 0l0 -1.265625l1.453125 0l0 -1.03125q0 -0.96875 0.171875 -1.453125q0.234375 -0.640625 0.828125 -1.03125q0.59375 -0.390625 1.671875 -0.390625q0.6875 0 1.53125 0.15625l-0.25 1.4375q-0.5 -0.09375 -0.953125 -0.09375q-0.75 0 -1.0625 0.328125q-0.3125 0.3125 -0.3125 1.1875l0 0.890625l1.890625 0l0 1.265625l-1.890625 0l0 8.40625l-1.625 0zm11.417664 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125732 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m179.05511 152.91733l37.984253 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m179.05511 152.91733l37.984253 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m306.4252 152.91733l47.338593 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m306.4252 152.91733l47.338593 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m217.03937 136.88583l89.38583 0l0 32.06299l-89.38583 0z" fill-rule="evenodd"/><path fill="#000000" d="m226.58624 154.67169l1.0625 -0.09375q0.078125 0.65625 0.359375 1.0625q0.28125 0.40625 0.859375 0.671875q0.59375 0.25 1.328125 0.25q0.640625 0 1.140625 -0.1875q0.5 
-0.203125 0.734375 -0.53125q0.25 -0.34375 0.25 -0.734375q0 -0.40625 -0.234375 -0.703125q-0.234375 -0.3125 -0.765625 -0.515625q-0.359375 -0.140625 -1.546875 -0.421875q-1.171875 -0.28125 -1.640625 -0.53125q-0.625 -0.328125 -0.921875 -0.796875q-0.296875 -0.484375 -0.296875 -1.078125q0 -0.640625 0.359375 -1.203125q0.375 -0.578125 1.078125 -0.859375q0.71875 -0.296875 1.578125 -0.296875q0.953125 0 1.6875 0.3125q0.734375 0.296875 1.125 0.90625q0.390625 0.59375 0.421875 1.34375l-1.09375 0.078125q-0.09375 -0.8125 -0.609375 -1.21875q-0.5 -0.421875 -1.484375 -0.421875q-1.03125 0 -1.5 0.375q-0.46875 0.375 -0.46875 0.90625q0 0.46875 0.328125 0.765625q0.328125 0.296875 1.703125 0.609375q1.390625 0.3125 1.90625 0.546875q0.75 0.359375 1.109375 0.890625q0.359375 0.515625 0.359375 1.21875q0 0.6875 -0.390625 1.296875q-0.390625 0.59375 -1.125 0.9375q-0.734375 0.328125 -1.65625 0.328125q-1.171875 0 -1.96875 -0.328125q-0.78125 -0.34375 -1.234375 -1.03125q-0.4375 -0.6875 -0.453125 -1.546875zm8.207031 5.15625l-0.125 -0.984375q0.34375 0.09375 0.609375 0.09375q0.34375 0 0.546875 -0.125q0.21875 -0.109375 0.359375 -0.3125q0.09375 -0.171875 0.328125 -0.796875q0.015625 -0.078125 0.09375 -0.25l-2.375 -6.234375l1.140625 0l1.296875 3.59375q0.25 0.6875 0.453125 1.453125q0.1875 -0.734375 0.4375 -1.421875l1.328125 -3.625l1.046875 0l-2.359375 6.328125q-0.390625 1.015625 -0.59375 1.40625q-0.28125 0.53125 -0.65625 0.765625q-0.359375 0.25 -0.859375 0.25q-0.296875 0 -0.671875 -0.140625zm5.625 -4.25l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5zm8.71875 0.921875l0.15625 0.921875q-0.453125 0.09375 -0.796875 0.09375q-0.578125 0 -0.890625 -0.171875q-0.3125 -0.1875 -0.453125 -0.484375q-0.125 -0.296875 -0.125 -1.25l0 -3.578125l-0.765625 0l0 -0.8125l0.765625 0l0 -1.546875l1.046875 -0.625l0 2.171875l1.0625 0l0 0.8125l-1.0625 0l0 3.640625q0 0.453125 0.046875 0.578125q0.0625 0.125 0.1875 0.203125q0.125 0.078125 0.359375 0.078125q0.1875 0 0.46875 -0.03125zm5.2871094 -1.0625l1.09375 0.125q-0.25 0.953125 -0.953125 1.484375q-0.703125 0.53125 -1.78125 0.53125q-1.359375 0 -2.171875 -0.84375q-0.796875 -0.84375 -0.796875 -2.359375q0 -1.5625 0.8125 -2.421875q0.8125 -0.875 2.09375 -0.875q1.25 0 2.03125 0.84375q0.796875 0.84375 0.796875 2.390625q0 0.09375 0 0.28125l-4.640625 0q0.0625 1.03125 0.578125 1.578125q0.515625 0.53125 1.296875 0.53125q0.578125 0 0.984375 -0.296875q0.421875 -0.3125 0.65625 -0.96875zm-3.453125 -1.703125l3.46875 0q-0.0625 -0.796875 -0.390625 -1.1875q-0.515625 -0.609375 -1.3125 -0.609375q-0.734375 0 -1.234375 0.484375q-0.484375 0.484375 -0.53125 1.3125zm5.876953 3.703125l0 
-6.21875l0.9375 0l0 0.875q0.296875 -0.46875 0.78125 -0.734375q0.484375 -0.28125 1.109375 -0.28125q0.6875 0 1.125 0.28125q0.453125 0.28125 0.625 0.796875q0.75 -1.078125 1.921875 -1.078125q0.9375 0 1.421875 0.515625q0.5 0.5 0.5 1.578125l0 4.265625l-1.046875 0l0 -3.921875q0 -0.625 -0.109375 -0.90625q-0.09375 -0.28125 -0.359375 -0.453125q-0.265625 -0.171875 -0.640625 -0.171875q-0.65625 0 -1.09375 0.4375q-0.421875 0.4375 -0.421875 1.40625l0 3.609375l-1.0625 0l0 -4.046875q0 -0.703125 -0.265625 -1.046875q-0.25 -0.359375 -0.828125 -0.359375q-0.453125 0 -0.828125 0.234375q-0.375 0.234375 -0.546875 0.6875q-0.171875 0.453125 -0.171875 1.296875l0 3.234375l-1.046875 0zm17.392578 -2.28125l1.03125 0.140625q-0.171875 1.0625 -0.875 1.671875q-0.703125 0.609375 -1.71875 0.609375q-1.28125 0 -2.0625 -0.828125q-0.765625 -0.84375 -0.765625 -2.40625q0 -1.0 0.328125 -1.75q0.34375 -0.765625 1.015625 -1.140625q0.6875 -0.375 1.5 -0.375q1.0 0 1.640625 0.515625q0.65625 0.5 0.84375 1.453125l-1.03125 0.15625q-0.140625 -0.625 -0.515625 -0.9375q-0.375 -0.328125 -0.90625 -0.328125q-0.796875 0 -1.296875 0.578125q-0.5 0.5625 -0.5 1.796875q0 1.265625 0.484375 1.828125q0.484375 0.5625 1.25 0.5625q0.625 0 1.03125 -0.375q0.421875 -0.375 0.546875 -1.171875zm6.0 1.515625q-0.59375 0.5 -1.140625 0.703125q-0.53125 0.203125 -1.15625 0.203125q-1.03125 0 -1.578125 -0.5q-0.546875 -0.5 -0.546875 -1.28125q0 -0.453125 0.203125 -0.828125q0.203125 -0.390625 0.546875 -0.609375q0.34375 -0.234375 0.765625 -0.34375q0.296875 -0.09375 0.9375 -0.171875q1.265625 -0.140625 1.875 -0.359375q0 -0.21875 0 -0.265625q0 -0.65625 -0.296875 -0.921875q-0.40625 -0.34375 -1.203125 -0.34375q-0.734375 0 -1.09375 0.265625q-0.359375 0.25 -0.53125 0.90625l-1.03125 -0.140625q0.140625 -0.65625 0.46875 -1.0625q0.328125 -0.40625 0.9375 -0.625q0.609375 -0.21875 1.40625 -0.21875q0.796875 0 1.296875 0.1875q0.5 0.1875 0.734375 0.46875q0.234375 0.28125 0.328125 0.71875q0.046875 0.265625 0.046875 0.96875l0 1.40625q0 1.46875 0.0625 1.859375q0.078125 0.390625 0.28125 0.75l-1.109375 0q-0.15625 -0.328125 -0.203125 -0.765625zm-0.09375 -2.359375q-0.578125 0.234375 -1.71875 0.40625q-0.65625 0.09375 -0.921875 0.21875q-0.265625 0.109375 -0.421875 0.328125q-0.140625 0.21875 -0.140625 0.5q0 0.421875 0.3125 0.703125q0.328125 0.28125 0.9375 0.28125q0.609375 0 1.078125 -0.265625q0.484375 -0.265625 0.703125 -0.734375q0.171875 -0.359375 0.171875 -1.046875l0 -0.390625zm2.6738281 3.125l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.6660156 0l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.2753906 -1.859375l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5z" 
fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m150.87927 111.83202l52.88188 0.40944672" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m154.30624 111.85855l46.027924 0.35638428" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m154.30624 111.85856l1.133255 -1.1158447l-3.0983734 1.1006241l3.0809631 1.1484756z" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m200.33417 112.214935l-1.133255 1.1158447l3.0983887 -1.1006317l-3.0809784 -1.148468z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m159.04854 85.32021l37.417328 0l0 32.06299l-37.417328 0z" fill-rule="evenodd"/><path fill="#000000" d="m168.78291 104.91708l1.125 -0.109375q0.140625 0.796875 0.546875 1.15625q0.40625 0.359375 1.03125 0.359375q0.53125 0 0.9375 -0.25q0.421875 -0.25 0.671875 -0.65625q0.265625 -0.421875 0.4375 -1.125q0.171875 -0.703125 0.171875 -1.421875q0 -0.078125 0 -0.234375q-0.359375 0.546875 -0.96875 0.90625q-0.59375 0.34375 -1.3125 0.34375q-1.1875 0 -2.015625 -0.859375q-0.8125 -0.859375 -0.8125 -2.265625q0 -1.453125 0.859375 -2.328125q0.859375 -0.890625 2.140625 -0.890625q0.9375 0 1.703125 0.5q0.78125 0.5 1.171875 1.4375q0.40625 0.921875 0.40625 2.671875q0 1.828125 -0.40625 2.921875q-0.390625 1.078125 -1.171875 1.640625q-0.78125 0.5625 -1.84375 0.5625q-1.109375 0 -1.828125 -0.609375q-0.703125 -0.625 -0.84375 -1.75zm4.796875 -4.21875q0 -1.0 -0.546875 -1.59375q-0.53125 -0.59375 -1.28125 -0.59375q-0.78125 0 -1.375 0.640625q-0.578125 0.640625 -0.578125 1.65625q0 0.90625 0.546875 1.484375q0.5625 0.5625 1.359375 0.5625q0.828125 0 1.34375 -0.5625q0.53125 -0.578125 0.53125 -1.59375zm2.9124756 6.421875l0 -9.546875l3.59375 0q0.953125 0 1.453125 0.09375q0.703125 0.125 1.171875 0.453125q0.484375 0.328125 0.765625 0.921875q0.296875 0.59375 0.296875 1.296875q0 1.21875 -0.78125 2.0625q-0.765625 0.84375 -2.796875 0.84375l-2.4375 0l0 3.875l-1.265625 0zm1.265625 -5.0l2.453125 0q1.234375 0 1.75 -0.453125q0.515625 -0.46875 0.515625 -1.28125q0 -0.609375 -0.3125 -1.03125q-0.296875 -0.421875 -0.796875 -0.5625q-0.3125 -0.09375 -1.171875 -0.09375l-2.4375 0l0 3.421875z" fill-rule="nonzero"/></g></svg> \ No newline at end of file
diff --git a/g3doc/architecture_guide/BUILD b/g3doc/architecture_guide/BUILD
new file mode 100644
index 000000000..404f627a4
--- /dev/null
+++ b/g3doc/architecture_guide/BUILD
@@ -0,0 +1,50 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "platforms",
+ src = "platforms.md",
+ category = "Architecture Guide",
+ data = [
+ "platforms.png",
+ "platforms.svg",
+ ],
+ permalink = "/docs/architecture_guide/platforms/",
+ weight = "40",
+)
+
+doc(
+ name = "resources",
+ src = "resources.md",
+ category = "Architecture Guide",
+ data = [
+ "resources.png",
+ "resources.svg",
+ ],
+ permalink = "/docs/architecture_guide/resources/",
+ weight = "30",
+)
+
+doc(
+ name = "security",
+ src = "security.md",
+ category = "Architecture Guide",
+ data = [
+ "security.png",
+ "security.svg",
+ ],
+ permalink = "/docs/architecture_guide/security/",
+ weight = "10",
+)
+
+doc(
+ name = "performance",
+ src = "performance.md",
+ category = "Architecture Guide",
+ permalink = "/docs/architecture_guide/performance/",
+ weight = "20",
+)
diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md
new file mode 100644
index 000000000..39dbb0045
--- /dev/null
+++ b/g3doc/architecture_guide/performance.md
@@ -0,0 +1,277 @@
+# Performance Guide
+
+[TOC]
+
+gVisor is designed to provide a secure, virtualized environment while preserving
+key benefits of containerization, such as small fixed overheads and a dynamic
+resource footprint. For containerized infrastructure, this can provide a
+turn-key solution for sandboxing untrusted workloads: there are no changes to
+the fundamental resource model.
+
+gVisor imposes runtime costs over native containers. These costs come in two
+forms: additional cycles and memory usage, which may manifest as increased
+latency, reduced throughput or density, or not at all. In general, these costs
+come from two different sources.
+
+First, the existence of the [Sentry](../README.md#sentry) means that additional
+memory will be required, and application system calls must traverse additional
+layers of software. The design emphasizes
+[security](/docs/architecture_guide/security/) and therefore we chose to use a
+language for the Sentry that provides benefits in this domain but may not yet
+offer the raw performance of other choices. Costs imposed by these design
+choices are **structural costs**.
+
+Second, as gVisor is an independent implementation of the system call surface,
+many of the subsystems or specific calls are not as optimized as more mature
+implementations. A good example here is the network stack, which is continuing
+to evolve but does not support all the advanced recovery mechanisms offered by
+other stacks and is less CPU efficient. This is an **implementation cost** and
+is distinct from **structural costs**. Improvements here are ongoing and driven
+by the workloads that matter to gVisor users and contributors.
+
+This page provides a guide for understanding baseline performance, and calls
+out distinct **structural costs** and **implementation costs**, highlighting
+where improvements are possible and where they are not.
+
+While we include a variety of workloads here, it’s worth emphasizing that gVisor
+may not be an appropriate solution for every workload, for reasons other than
+performance. For example, a sandbox may provide minimal benefit for a trusted
+database, since _user data would already be inside the sandbox_ and there is no
+need for an attacker to break out in the first place.
+
+## Methodology
+
+All data below was generated using the [benchmark tools][benchmark-tools]
+repository, and the machines under test are uniform [Google Compute Engine][gce]
+Virtual Machines (VMs) with the following specifications:
+
+ Machine type: n1-standard-4 (broadwell)
+ Image: Debian GNU/Linux 9 (stretch) 4.19.0-0
+ BootDisk: 2048GB SSD persistent disk
+
+Throughout this document, `runsc` is used to indicate the runtime provided by
+gVisor. When relevant, we use the name `runsc-platform` to describe a specific
+[platform choice](/docs/architecture_guide/platforms/).
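+
+For instance, a platform-specific runtime can be registered with Docker under
+such a name. The following is a minimal sketch, assuming `runsc` is installed
+at `/usr/local/bin/runsc`; paths and runtime names here are illustrative:
+
+    # Illustrative /etc/docker/daemon.json entries; restart the Docker
+    # daemon after editing.
+    cat > /etc/docker/daemon.json <<'EOF'
+    {
+      "runtimes": {
+        "runsc-ptrace": {
+          "path": "/usr/local/bin/runsc",
+          "runtimeArgs": ["--platform=ptrace"]
+        },
+        "runsc-kvm": {
+          "path": "/usr/local/bin/runsc",
+          "runtimeArgs": ["--platform=kvm"]
+        }
+      }
+    }
+    EOF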
+
+**Except where specified, all tests below are conducted with the `ptrace`
+platform. The `ptrace` platform works everywhere and does not require hardware
+virtualization or kernel modifications but suffers from the highest structural
+costs by far. This platform is used to provide a clear understanding of the
+performance model, but in no way represents an ideal scenario. In the future,
+this guide will be extended to bare metal environments and include additional
+platforms.**
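+
+Each figure's title below shows the command used to generate it. A typical
+invocation using the [benchmark tools][benchmark-tools] looks like the
+following; this is a sketch, and the repository should be consulted for exact
+usage:
+
+    # Compare runtimes on a given suite (commands mirror the figure titles).
+    perf.py sysbench.cpu --runtime=runc --runtime=runsc
+    perf.py syscall --runtime=runc --runtime=runsc-ptrace --runtime=runsc-kvm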
+
+## Memory access
+
+gVisor does not introduce any additional costs with respect to raw memory
+accesses. Page faults and other Operating System (OS) mechanisms are translated
+through the Sentry, but once mappings are installed and available to the
+application, there is no additional overhead.
+
+{% include graph.html id="sysbench-memory"
+url="/performance/sysbench-memory.csv" title="perf.py sysbench.memory
+--runtime=runc --runtime=runsc" %}
+
+The above figure demonstrates the memory transfer rate as measured by
+`sysbench`.
+
+## Memory usage
+
+The Sentry provides an additional layer of indirection, and it requires memory
+in order to store state associated with the application. This memory generally
+consists of a fixed component, plus an amount that varies with the usage of
+operating system resources (e.g. how many sockets or files are opened).
+
+For many use cases, fixed memory overheads are a primary concern. This may be
+because sandboxed containers handle a low volume of requests, and it is
+therefore important to achieve high densities for efficiency.
+
+{% include graph.html id="density" url="/performance/density.csv" title="perf.py
+density --runtime=runc --runtime=runsc" log="true" y_min="100000" %}
+
+The above figure demonstrates these costs based on three sample applications.
+This test is the result of running many instances of a container (50, or 5 in
+the case of redis), calculating available memory on the host before and
+afterwards, and dividing the difference by the number of containers. This
+technique is used for measuring memory usage rather than the `usage_in_bytes`
+value of the container cgroup, because we found that some container runtimes,
+other than `runc` and `runsc`, do not use an individual container cgroup.
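+
+A minimal sketch of this technique, using a hypothetical fleet of `sleep`
+containers and an illustrative settling time, looks like the following:
+
+    # Snapshot available memory, launch N containers, snapshot again.
+    before=$(awk '/MemAvailable/ {print $2}' /proc/meminfo)
+    for i in $(seq 1 50); do
+      docker run -d --runtime=runsc alpine sleep 1000
+    done
+    sleep 30  # allow usage to stabilize (illustrative)
+    after=$(awk '/MemAvailable/ {print $2}' /proc/meminfo)
+    echo "$(( (before - after) / 50 )) kB per container"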
+
+The first application is an instance of `sleep`: a trivial application that does
+nothing. The second application is a synthetic `node` application which imports
+a number of modules and listens for requests. The third application is a similar
+synthetic `ruby` application which does the same. Finally, we include an
+instance of `redis` storing approximately 1GB of data. In all cases, the sandbox
+itself is responsible for a small, mostly fixed amount of memory overhead.
+
+## CPU performance
+
+gVisor does not perform emulation or otherwise interfere with the raw execution
+of CPU instructions by the application. Therefore, there is no runtime cost
+imposed for CPU operations.
+
+{% include graph.html id="sysbench-cpu" url="/performance/sysbench-cpu.csv"
+title="perf.py sysbench.cpu --runtime=runc --runtime=runsc" %}
+
+The above figure demonstrates the `sysbench` measurement of CPU events per
+second. Events per second is based on a CPU-bound loop that calculates all prime
+numbers in a specified range. We note that `runsc` does not impose a performance
+penalty, as the code is executing natively in both cases.
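+
+The equivalent standalone `sysbench` invocation is a prime-calculation loop
+along these lines; the range below is illustrative:
+
+    # CPU events per second: count primes up to the given bound.
+    sysbench cpu --cpu-max-prime=10000 run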
+
+This has important consequences for classes of workloads that are often
+CPU-bound, such as data processing or machine learning. In these cases, `runsc`
+will similarly impose minimal runtime overhead.
+
+{% include graph.html id="tensorflow" url="/performance/tensorflow.csv"
+title="perf.py tensorflow --runtime=runc --runtime=runsc" %}
+
+For example, the above figure shows a sample TensorFlow workload, the
+[convolutional neural network example][cnn]. The time indicated includes the
+full start-up and run time for the workload, which trains a model.
+
+## System calls
+
+Some **structural costs** of gVisor are heavily influenced by the
+[platform choice](/docs/architecture_guide/platforms/), which implements system
+call interception. Today, gVisor supports a variety of platforms. These
+platforms present distinct performance, compatibility and security trade-offs.
+For example, the KVM platform has low-overhead system call interception but
+runs poorly with nested virtualization.
+
+{% include graph.html id="syscall" url="/performance/syscall.csv" title="perf.py
+syscall --runtime=runc --runtime=runsc-ptrace --runtime=runsc-kvm" y_min="100"
+log="true" %}
+
+The above figure demonstrates the time required for a raw system call on various
+platforms. The test is implemented by a custom binary which performs a large
+number of system calls and calculates the average time required.
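+
+The custom binary is not reproduced here, but a rough proxy is easy to run:
+`dd` with a one-byte block size issues a `read` and a `write` per byte, so
+wall-clock time approximates per-syscall cost in each runtime. This is a
+sketch, not the actual test:
+
+    # Roughly two million system calls; compare against --runtime=runc.
+    time docker run --rm --runtime=runsc alpine \
+      dd if=/dev/zero of=/dev/null bs=1 count=1000000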
+
+This cost will principally impact applications that are system call bound, which
+tend to be high-performance data stores and static network services. In general,
+the impact of system call interception will be lower the more work an
+application does.
+
+{% include graph.html id="redis" url="/performance/redis.csv" title="perf.py
+redis --runtime=runc --runtime=runsc" %}
+
+For example, `redis` is an application that performs relatively little work in
+userspace: in general it reads from a connected socket, reads or modifies some
+data, and writes a result back to the socket. The above figure shows the results
+of running a [comprehensive set of benchmarks][redis-benchmark]. We can see that
+small operations impose a large overhead, while larger operations, such as
+`LRANGE`, where more work is done in the application, have a smaller relative
+overhead.
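+
+These results can be approximated by running the standard `redis-benchmark`
+tool against a sandboxed server; the commands below are illustrative, and the
+[benchmarks page][redis-benchmark] describes the full methodology:
+
+    # Sandboxed server, native client.
+    docker run -d --runtime=runsc -p 6379:6379 redis
+    redis-benchmark -t set,get,lrange -n 100000 -q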
+
+Some of these costs above are **structural costs**, and `redis` is likely to
+remain a challenging performance scenario. However, optimizing the
+[platform](/docs/architecture_guide/platforms/) will also have a dramatic
+impact.
+
+## Start-up time
+
+For many use cases, the ability to spin up containers quickly and efficiently
+is important. A sandbox may be short-lived and perform minimal user work (e.g.
+a function invocation).
+
+{% include graph.html id="startup" url="/performance/startup.csv" title="perf.py
+startup --runtime=runc --runtime=runsc" %}
+
+The above figure shows the total time required to start a container through
+[Docker][docker]. This benchmark uses three different applications. First, an
+Alpine Linux container that executes `true`. Second, a `node` application that
+loads a number of modules and binds an HTTP server. The time is measured by a
+successful request to the bound port. Finally, a `ruby` application that
+similarly loads a number of modules and binds an HTTP server.
+
+> Note: most of the time overhead above is associated with Docker itself. This is
+> evident with the empty `runc` benchmark. To avoid these costs with `runsc`,
+> you may also consider using `runsc do` mode or invoking the
+> [OCI runtime](../user_guide/quick_start/oci.md) directly.
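+
+A simple way to observe both cases is sketched below; the image and commands
+are illustrative, and `runsc do` may require root:
+
+    # End-to-end Docker start-up versus direct sandbox invocation.
+    time docker run --rm --runtime=runsc alpine true
+    sudo runsc do true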
+
+## Network
+
+Networking is mostly bound by **implementation costs**, and gVisor's network
+stack is improving quickly.
+
+While raw throughput is typically not an important metric in practice for
+common sandbox use cases, `iperf` is nevertheless a common microbenchmark used
+to measure it.
+
+{% include graph.html id="iperf" url="/performance/iperf.csv" title="perf.py
+iperf --runtime=runc --runtime=runsc" %}
+
+The above figure shows the result of an `iperf` test between two instances. For
+the upload case, the specified runtime is used for the `iperf` client, and in
+the download case, the specified runtime is the server. A native runtime is
+always used for the other endpoint in the test.
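+
+The upload case can be reproduced manually along these lines, with the server
+running natively and the client sandboxed; the image name and address are
+illustrative placeholders:
+
+    # Native endpoint.
+    iperf -s
+    # Sandboxed endpoint, using any image with iperf installed.
+    docker run --rm --runtime=runsc <iperf-image> iperf -c <server-ip>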
+
+{% include graph.html id="applications" metric="requests_per_second"
+url="/performance/applications.csv" title="perf.py http.(node|ruby)
+--connections=25 --runtime=runc --runtime=runsc" %}
+
+The above figure shows the result of simple `node` and `ruby` web services that
+render a template upon receiving a request. Because these synthetic benchmarks
+do minimal work per request, much like the `redis` case, they suffer from high
+overheads. In practice, the more work an application does, the smaller the
+impact of **structural costs** becomes.
+
+## File system
+
+Some aspects of file system performance are also reflective of **implementation
+costs**, and an area where gVisor's implementation is improving quickly.
+
+In terms of raw disk I/O, gVisor does not introduce significant fundamental
+overhead. For general file operations, gVisor introduces a small fixed overhead
+for data that transitions across the sandbox boundary. This manifests as
+**structural costs** in some cases, since these operations must be routed
+through the [Gofer](../README.md#gofer) as a result of our
+[Security Model](/docs/architecture_guide/security/), but in most cases it is
+dominated by **implementation costs**, due to an internal
+[Virtual File System][vfs] (VFS) implementation that needs improvement.
+
+{% include graph.html id="fio-bw" url="/performance/fio.csv" title="perf.py fio
+--engine=sync --runtime=runc --runtime=runsc" log="true" %}
+
+The above figure demonstrates the results of `fio` for reads and writes to and
+from the disk. In this case, the disk quickly becomes the bottleneck and
+dominates other costs.
+
+{% include graph.html id="fio-tmpfs-bw" url="/performance/fio-tmpfs.csv"
+title="perf.py fio --engine=sync --runtime=runc --tmpfs=True --runtime=runsc"
+log="true" %}
+
+The above figure shows the raw I/O performance of using a `tmpfs` mount which is
+sandbox-internal in the case of `runsc`. Generally these operations are
+similarly bound by the cost of copying data around in memory, and we don't see
+the cost of VFS operations.
+
+{% include graph.html id="httpd100k" metric="transfer_rate"
+url="/performance/httpd100k.csv" title="perf.py http.httpd --connections=1
+--connections=5 --connections=10 --connections=25 --runtime=runc
+--runtime=runsc" %}
+
+For example, the high costs of VFS operations can manifest in benchmarks that
+execute many such operations in the hot path for serving requests. The above
+figure shows the result of using gVisor to serve small pieces of static content
+with predictably poor results. This workload represents `apache` serving a
+single file sized 100k from the container image to a client running
+[ApacheBench][ab] with varying levels of concurrency. The high overhead comes
+principally from the VFS implementation that needs improvement, with several
+internal serialization points (since all requests are reading the same file).
+Note that some of the network stack performance issues also impact this
+benchmark.
+
+{% include graph.html id="ffmpeg" url="/performance/ffmpeg.csv" title="perf.py
+media.ffmpeg --runtime=runc --runtime=runsc" %}
+
+For benchmarks that are bound by raw disk I/O and a mix of compute, file system
+operations are less of an issue. The above figure shows the total time required
+for an `ffmpeg` container to start, load and transcode a 27MB input video.
+
+[ab]: https://en.wikipedia.org/wiki/ApacheBench
+[benchmark-tools]: https://github.com/google/gvisor/tree/master/benchmarks
+[gce]: https://cloud.google.com/compute/
+[cnn]: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/convolutional_network.py
+[docker]: https://docker.io
+[redis-benchmark]: https://redis.io/topics/benchmarks
+[vfs]: https://en.wikipedia.org/wiki/Virtual_file_system
diff --git a/g3doc/architecture_guide/platforms.md b/g3doc/architecture_guide/platforms.md
new file mode 100644
index 000000000..d112c9a28
--- /dev/null
+++ b/g3doc/architecture_guide/platforms.md
@@ -0,0 +1,61 @@
+# Platform Guide
+
+[TOC]
+
+gVisor requires a platform to implement interception of syscalls, basic context
+switching, and memory mapping functionality. Internally, gVisor uses an
+abstraction sensibly called [Platform][platform]. A simplified version of this
+interface looks like:
+
+```golang
+type Platform interface {
+	NewAddressSpace() (AddressSpace, error)
+	NewContext() Context
+}
+
+type Context interface {
+	Switch(as AddressSpace, ac arch.Context) (..., error)
+}
+
+type AddressSpace interface {
+	MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, ...) error
+	Unmap(addr usermem.Addr, length uint64)
+}
+```
+
+There are a number of different ways to implement this interface that come with
+various trade-offs, generally around performance and hardware requirements.
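+
+As a concrete (but hypothetical) sketch, a caller might drive this simplified
+interface roughly as follows; the names and the elided trap handling below are
+illustrative, not actual Sentry code:
+
+```golang
+// runApp shows how a Platform might be driven in a control loop.
+func runApp(p Platform, ac arch.Context) error {
+	as, err := p.NewAddressSpace()
+	if err != nil {
+		return err
+	}
+	ctx := p.NewContext()
+	for {
+		// Resume application code until it traps back out, e.g. on a
+		// system call or page fault (results elided, as above).
+		if _, err := ctx.Switch(as, ac); err != nil {
+			return err
+		}
+		// Handle the trap: emulate the system call, or map the
+		// faulting region with as.MapFile, then resume the loop.
+	}
+}
+```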
+
+## Implementations
+
+The choice of platform depends on the context in which `runsc` is executing. In
+general, when `runsc` runs inside a virtual machine, the choice may be limited
+to platforms that do not require hardware virtualization support (since the
+hardware is already in use):
+
+![Platforms](platforms.png "Platform examples.")
+
+### ptrace
+
+The ptrace platform uses [PTRACE_SYSEMU][ptrace] to execute user code without
+allowing it to execute host system calls. This platform can run anywhere that
+`ptrace` works (even VMs without nested virtualization), which is ubiquitous.
+
+Unfortunately, the ptrace platform has high context switch overhead, so system
+call-heavy applications may pay a [performance penalty](./performance.md).
+
+### KVM
+
+The KVM platform uses the kernel's [KVM][kvm] functionality to allow the Sentry
+to act as both guest OS and VMM. The KVM platform can run on bare-metal or in a
+VM with nested virtualization enabled. While there is no virtualized hardware
+layer -- the sandbox retains a process model -- gVisor leverages virtualization
+extensions available on modern processors in order to improve isolation and
+performance of address space switches.
+
+## Changing Platforms
+
+See [Changing Platforms](../user_guide/platforms.md).
+
+[kvm]: https://www.kernel.org/doc/Documentation/virtual/kvm/api.txt
+[platform]: https://cs.opensource.google/gvisor/gvisor/+/release-20190304.1:pkg/sentry/platform/platform.go;l=33
+[ptrace]: http://man7.org/linux/man-pages/man2/ptrace.2.html
diff --git a/g3doc/architecture_guide/platforms.png b/g3doc/architecture_guide/platforms.png
new file mode 100644
index 000000000..005d56feb
--- /dev/null
+++ b/g3doc/architecture_guide/platforms.png
Binary files differ
diff --git a/g3doc/architecture_guide/platforms.svg b/g3doc/architecture_guide/platforms.svg
new file mode 100644
index 000000000..b0bac9ba7
--- /dev/null
+++ b/g3doc/architecture_guide/platforms.svg
@@ -0,0 +1,334 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="142.67763mm"
+ height="67.063133mm"
+ viewBox="0 0 142.67763 67.063134"
+ version="1.1"
+ id="svg8"
+ inkscape:export-filename="/home/ascannell/resources.png"
+ inkscape:export-xdpi="53.50127"
+ inkscape:export-ydpi="53.50127"
+ inkscape:version="0.92.4 (5da689c313, 2019-01-14)"
+ sodipodi:docname="platforms.svg">
+ <defs
+ id="defs2" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.98994949"
+ inkscape:cx="86.443612"
+ inkscape:cy="102.88104"
+ inkscape:document-units="mm"
+ inkscape:current-layer="layer1"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:window-width="1920"
+ inkscape:window-height="1005"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1" />
+ <metadata
+ id="metadata5">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ transform="translate(-36.081387,-98.953278)">
+ <rect
+ id="rect10"
+ width="33.408691"
+ height="33.408691"
+ x="36.081387"
+ y="120.06757"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ style="fill:#b3b3b3;stroke-width:0.23881446"
+ id="rect16"
+ width="142.45465"
+ height="10.423517"
+ x="36.08139"
+ y="155.5929" />
+ <rect
+ id="rect10-7"
+ width="30.52453"
+ height="18.976137"
+ x="37.416695"
+ y="121.65508"
+ style="fill:#ff8080;stroke-width:0.19060372" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="41.03727"
+ y="148.58765"
+ id="text65"><tspan
+ sodipodi:role="line"
+ id="tspan63"
+ x="41.03727"
+ y="148.58765"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="45.473087"
+ y="132.50232"
+ id="text123"><tspan
+ sodipodi:role="line"
+ id="tspan121"
+ x="45.473087"
+ y="132.50232"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:6.43922186px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.16098055"
+ x="97.768547"
+ y="163.15665"
+ id="text163"><tspan
+ sodipodi:role="line"
+ id="tspan161"
+ x="97.768547"
+ y="163.15665"
+ style="stroke-width:0.16098055">host</tspan></text>
+ <rect
+ style="fill:#e9afdd;stroke-width:0.39185274"
+ id="rect16-7"
+ width="72.9646"
+ height="54.79026"
+ x="105.79441"
+ y="98.953278" />
+ <rect
+ id="rect10-5"
+ width="33.408691"
+ height="33.408691"
+ x="108.24348"
+ y="100.53072"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ id="rect10-7-6"
+ width="30.52453"
+ height="20.045216"
+ x="109.57877"
+ y="102.11823"
+ style="fill:#ff8080;stroke-width:0.19589928" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="112.86765"
+ y="129.01863"
+ id="text65-2"><tspan
+ sodipodi:role="line"
+ id="tspan63-9"
+ x="112.86765"
+ y="129.01863"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="117.63519"
+ y="114.02371"
+ id="text123-1"><tspan
+ sodipodi:role="line"
+ id="tspan121-2"
+ x="117.63519"
+ y="114.02371"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <rect
+ id="rect10-7-7"
+ width="11.815663"
+ height="8.0126781"
+ x="54.538059"
+ y="143.27702"
+ style="fill:#aaccff;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:4.35074377px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.10876859"
+ x="55.931114"
+ y="148.90578"
+ id="text144"><tspan
+ sodipodi:role="line"
+ id="tspan142"
+ x="55.931114"
+ y="148.90578"
+ style="stroke-width:0.10876859">KVM</tspan></text>
+ <rect
+ id="rect10-6"
+ width="33.408691"
+ height="33.408691"
+ x="71.044685"
+ y="119.73112"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ id="rect10-7-0"
+ width="30.52453"
+ height="18.976137"
+ x="72.37999"
+ y="121.31865"
+ style="fill:#ff8080;stroke-width:0.19060372" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="76.000565"
+ y="148.25128"
+ id="text65-6"><tspan
+ sodipodi:role="line"
+ id="tspan63-2"
+ x="76.000565"
+ y="148.25128"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="80.436386"
+ y="132.16595"
+ id="text123-6"><tspan
+ sodipodi:role="line"
+ id="tspan121-1"
+ x="80.436386"
+ y="132.16595"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <rect
+ id="rect10-7-7-8"
+ width="11.815664"
+ height="8.0126781"
+ x="89.501358"
+ y="142.94067"
+ style="fill:#ffeeaa;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.39456654px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08486416"
+ x="89.92292"
+ y="147.89806"
+ id="text144-7"><tspan
+ sodipodi:role="line"
+ id="tspan142-9"
+ x="89.92292"
+ y="147.89806"
+ style="stroke-width:0.08486416">ptrace</tspan></text>
+ <rect
+ id="rect10-7-7-8-3"
+ width="11.815665"
+ height="8.0126781"
+ x="127.08897"
+ y="123.97878"
+ style="fill:#ffeeaa;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.39456654px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08486416"
+ x="127.51052"
+ y="128.9362"
+ id="text144-7-7"><tspan
+ sodipodi:role="line"
+ id="tspan142-9-5"
+ x="127.51052"
+ y="128.9362"
+ style="stroke-width:0.08486416">ptrace</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:5.45061255px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.13626531"
+ x="138.49318"
+ y="152.11841"
+ id="text229"><tspan
+ sodipodi:role="line"
+ id="tspan227"
+ x="138.49318"
+ y="152.11841"
+ style="stroke-width:0.13626531">VM</tspan></text>
+ <rect
+ style="fill:#b3b3b3;stroke-width:0.16518368"
+ id="rect16-9"
+ width="68.15374"
+ height="10.423517"
+ x="108.24348"
+ y="134.99774" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:6.17854786px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.15446369"
+ x="132.91473"
+ y="142.07658"
+ id="text248"><tspan
+ sodipodi:role="line"
+ id="tspan246"
+ x="132.91473"
+ y="142.07658"
+ style="stroke-width:0.15446369">guest</tspan></text>
+ <rect
+ id="rect10-5-2"
+ width="33.408691"
+ height="33.408691"
+ x="143.32402"
+ y="100.35877"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ id="rect10-7-6-2"
+ width="30.52453"
+ height="20.045216"
+ x="144.65933"
+ y="101.94627"
+ style="fill:#ff8080;stroke-width:0.19589929" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="147.94815"
+ y="128.84665"
+ id="text65-2-8"><tspan
+ sodipodi:role="line"
+ id="tspan63-9-9"
+ x="147.94815"
+ y="128.84665"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="152.71565"
+ y="113.85176"
+ id="text123-1-7"><tspan
+ sodipodi:role="line"
+ id="tspan121-2-3"
+ x="152.71565"
+ y="113.85176"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <rect
+ id="rect10-7-7-8-3-6"
+ width="11.815666"
+ height="8.0126781"
+ x="162.16933"
+ y="123.80682"
+ style="fill:#ffeeaa;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.39456654px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08486416"
+ x="162.59088"
+ y="128.76421"
+ id="text144-7-7-1"><tspan
+ sodipodi:role="line"
+ id="tspan142-9-5-2"
+ x="162.59088"
+ y="128.76421"
+ style="stroke-width:0.08486416">ptrace</tspan></text>
+ </g>
+</svg>
diff --git a/g3doc/architecture_guide/resources.md b/g3doc/architecture_guide/resources.md
new file mode 100644
index 000000000..1dec37bd1
--- /dev/null
+++ b/g3doc/architecture_guide/resources.md
@@ -0,0 +1,144 @@
+# Resource Model
+
+[TOC]
+
+The resource model for gVisor does not assume a fixed number of threads of
+execution (i.e. vCPUs) or amount of physical memory. Where possible, decisions
+about underlying physical resources are delegated to the host system, where
+optimizations can be made with global information. This delegation allows the
+sandbox to be highly dynamic in terms of resource usage: spanning a large number
+of cores and large amount of memory when busy, and yielding those resources back
+to the host when not.
+
+In other words, the shape of the sandbox should closely track the shape of the
+sandboxed process:
+
+![Resource model](resources.png "Workloads of different shapes.")
+
+## Processes
+
+Much like a Virtual Machine (VM), a gVisor sandbox appears as an opaque process
+on the system. Processes within the sandbox do not manifest as processes on the
+host system, and process-level interactions within the sandbox require entering
+the sandbox (e.g. via a [Docker exec][exec]).
+
+## Networking
+
+The sandbox attaches a network endpoint to the system, but runs its own network
+stack. All network resources, other than packets in flight on the host, exist
+only inside the sandbox, bound by relevant resource limits.
+
+You can interact with network endpoints exposed by the sandbox, just as you
+would any other container, but network introspection similarly requires entering
+the sandbox.
+
+## Files
+
+Files in the sandbox may be backed by different implementations. For host-native
+files (where a file descriptor is available), the Gofer may return a file
+descriptor to the Sentry via [SCM_RIGHTS][scmrights][^1].
+
+These files may be read from and written to through standard system calls, and
+also mapped into the associated application's address space. This allows the
+same host memory to be shared across multiple sandboxes, although this mechanism
+does not preclude the use of side-channels (see [Security Model](./security.md)).
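+
+For illustration, receiving a donated file descriptor over a connected UNIX
+socket looks roughly like the following standalone sketch. This shows only the
+generic [SCM_RIGHTS][scmrights] mechanism; the actual Sentry/Gofer protocol
+carries additional state:
+
+```golang
+package fdexample
+
+import "golang.org/x/sys/unix"
+
+// recvFD receives a single donated file descriptor over a connected
+// UNIX socket using SCM_RIGHTS.
+func recvFD(sock int) (int, error) {
+	buf := make([]byte, 1)
+	oob := make([]byte, unix.CmsgSpace(4)) // Room for one 32-bit fd.
+	_, oobn, _, _, err := unix.Recvmsg(sock, buf, oob, 0)
+	if err != nil {
+		return -1, err
+	}
+	msgs, err := unix.ParseSocketControlMessage(oob[:oobn])
+	if err != nil {
+		return -1, err
+	}
+	fds, err := unix.ParseUnixRights(&msgs[0])
+	if err != nil {
+		return -1, err
+	}
+	return fds[0], nil
+}
+```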
+
+Note that some file systems exist only within the context of the sandbox. For
+example, in many cases a `tmpfs` mount will be available at `/tmp` or
+`/dev/shm`, which allocates memory directly from the sandbox memory file (see
+below). Ultimately, these will be accounted against relevant limits in a similar
+way as the host native case.
+
+## Threads
+
+The Sentry models individual task threads with [goroutines][goroutine]. As a
+result, each task thread is a lightweight [green thread][greenthread], and may
+not correspond to an underlying host thread.
+
+However, application execution is modelled as a blocking system call with the
+Sentry. This means that additional host threads may be created, *depending on
+the number of active application threads*. In practice, a busy application will
+converge on the number of active threads, and the host will be able to make
+scheduling decisions about all application threads.
+
+## Time
+
+Time in the sandbox is provided by the Sentry, through its own [vDSO][vdso] and
+time-keeping implementation. This is distinct from the host time, and no state
+is shared with the host, although the time will be initialized with the host
+clock.
+
+The Sentry runs timers to note the passage of time, much like a kernel running
+on hardware (though the timers are software timers, in this case). These timers
+provide updates to the vDSO, the time returned through system calls, and the
+time recorded for usage or limit tracking (e.g. [RLIMIT_CPU][rlimit]).
+
+When all application threads are idle, the Sentry disables timers until an event
+occurs that wakes either the Sentry or an application thread, similar to a
+[tickless kernel][tickless]. This allows the Sentry to achieve near zero CPU
+usage for idle applications.
+
+## Memory
+
+The Sentry implements its own memory management, including demand-paging and a
+Sentry internal page cache for files that cannot be used natively. A single
+[memfd][memfd] backs all application memory.
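+
+The underlying host mechanism can be illustrated with a small standalone
+program; this is a sketch of the host APIs involved, not Sentry code:
+
+```golang
+package main
+
+import (
+	"fmt"
+
+	"golang.org/x/sys/unix"
+)
+
+func main() {
+	// Create an anonymous, host-managed memory file. The name is only
+	// a debugging label.
+	fd, err := unix.MemfdCreate("app-memory", unix.MFD_CLOEXEC)
+	if err != nil {
+		panic(err)
+	}
+	defer unix.Close(fd)
+
+	const size = 1 << 20 // 1 MiB
+	if err := unix.Ftruncate(fd, size); err != nil {
+		panic(err)
+	}
+
+	// Map a region of the memory file, as the Sentry does (at much
+	// larger scale) for application memory.
+	mem, err := unix.Mmap(fd, 0, size, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
+	if err != nil {
+		panic(err)
+	}
+	defer unix.Munmap(mem)
+
+	mem[0] = 42
+	fmt.Println("first byte:", mem[0])
+}
+```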
+
+### Address spaces
+
+The creation of address spaces is platform-specific. For some platforms,
+additional "stub" processes may be created on the host in order to support
+additional address spaces. These stubs are subject to various limits applied at
+the sandbox level (e.g. PID limits).
+
+### Physical memory
+
+The host is able to manage physical memory using regular means (e.g. tracking
+working sets, reclaiming and swapping under pressure). The Sentry lazily
+populates host mappings for applications, and allows the host to demand-page
+those regions, which is critical for the functioning of those mechanisms.
+
+In order to avoid excessive overhead, the Sentry does not demand-page individual
+pages. Instead, it selects appropriate regions based on heuristics. There is a
+trade-off here: the Sentry is unable to trivially determine which pages are
+active and which are not. Even if pages were individually faulted, the host may
+select pages to be reclaimed or swapped without the Sentry's knowledge.
+
+Therefore, memory usage statistics within the sandbox (e.g. via `proc`) are
+approximations. The Sentry maintains an internal breakdown of memory usage, and
+can collect accurate information, but only through a relatively expensive API
+call. In any case, it would likely be considered unwise to share precise
+information about how the host is managing memory with the sandbox.
+
+Finally, when an application marks a region of memory as no longer needed, for
+example via a call to [madvise][madvise], the Sentry *releases this memory back
+to the host*. There can be performance penalties for this, since it may be
+cheaper in many cases to retain the memory and use it to satisfy some other
+request. However, releasing it immediately to the host allows the host to more
+effectively multiplex resources and apply an efficient global policy.
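+
+On the host side, one plausible way to decommit a region of a memory file is to
+punch a hole in it, as in the hedged sketch below; the Sentry's actual release
+path may differ:
+
+```golang
+package memexample
+
+import "golang.org/x/sys/unix"
+
+// decommit releases a region of a memory file back to the host by
+// punching a hole; subsequent reads see zeroes and pages are allocated
+// again on demand.
+func decommit(memfd int, offset, length int64) error {
+	return unix.Fallocate(memfd,
+		unix.FALLOC_FL_PUNCH_HOLE|unix.FALLOC_FL_KEEP_SIZE,
+		offset, length)
+}
+```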
+
+## Limits
+
+All Sentry threads and Sentry memory are subject to a container cgroup. However,
+application usage will not appear as anonymous memory usage, and will instead be
+accounted to the `memfd`. All anonymous memory will correspond to Sentry usage,
+and host memory charged to the container will work as standard.
+
+The cgroups can be monitored for standard signals: pressure indicators,
+threshold notifiers, etc. and can also be adjusted dynamically. Note that the
+Sentry itself may listen for pressure signals in its containing cgroup, in order
+to purge internal caches.
+
+[goroutine]: https://tour.golang.org/concurrency/1
+[greenthread]: https://en.wikipedia.org/wiki/Green_threads
+[scheduler]: https://morsmachine.dk/go-scheduler
+[vdso]: https://en.wikipedia.org/wiki/VDSO
+[rlimit]: http://man7.org/linux/man-pages/man2/getrlimit.2.html
+[tickless]: https://en.wikipedia.org/wiki/Tickless_kernel
+[memfd]: http://man7.org/linux/man-pages/man2/memfd_create.2.html
+[scmrights]: http://man7.org/linux/man-pages/man7/unix.7.html
+[madvise]: http://man7.org/linux/man-pages/man2/madvise.2.html
+[exec]: https://docs.docker.com/engine/reference/commandline/exec/
+[^1]: Unless host networking is enabled, the Sentry is not able to create or
+ open host file descriptors itself; it can only receive them in this way
+ from the Gofer.
diff --git a/g3doc/architecture_guide/resources.png b/g3doc/architecture_guide/resources.png
new file mode 100644
index 000000000..f715008ec
--- /dev/null
+++ b/g3doc/architecture_guide/resources.png
Binary files differ
diff --git a/g3doc/architecture_guide/resources.svg b/g3doc/architecture_guide/resources.svg
new file mode 100644
index 000000000..fd7805d90
--- /dev/null
+++ b/g3doc/architecture_guide/resources.svg
@@ -0,0 +1,208 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="108.24417mm"
+ height="47.513165mm"
+ viewBox="0 0 108.24417 47.513165"
+ version="1.1"
+ id="svg8"
+ inkscape:export-filename="/home/ascannell/resources.png"
+ inkscape:export-xdpi="53.50127"
+ inkscape:export-ydpi="53.50127"
+ inkscape:version="0.92.4 (5da689c313, 2019-01-14)"
+ sodipodi:docname="resources.svg">
+ <defs
+ id="defs2" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.98994949"
+ inkscape:cx="16.897058"
+ inkscape:cy="41.261746"
+ inkscape:document-units="mm"
+ inkscape:current-layer="layer1"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:window-width="1920"
+ inkscape:window-height="1005"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1" />
+ <metadata
+ id="metadata5">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ transform="translate(-36.081387,-118.50325)">
+ <rect
+ id="rect10"
+ width="33.408691"
+ height="33.408691"
+ x="36.081387"
+ y="120.06757"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <circle
+ style="fill:#44aa00;stroke-width:0.21849461"
+ id="path12"
+ cx="87.958534"
+ cy="136.63828"
+ r="17.105247" />
+ <path
+ sodipodi:type="star"
+ style="fill:#44aa00;stroke-width:0.26458332"
+ id="path14"
+ sodipodi:sides="3"
+ sodipodi:cx="124.13387"
+ sodipodi:cy="141.81859"
+ sodipodi:r1="23.31534"
+ sodipodi:r2="11.65767"
+ sodipodi:arg1="0.52359878"
+ sodipodi:arg2="1.5707963"
+ inkscape:flatsided="false"
+ inkscape:rounded="0"
+ inkscape:randomized="0"
+ d="m 144.32555,153.47626 -20.19168,0 -20.19167,0 10.09583,-17.48651 10.09584,-17.4865 10.09584,17.4865 z"
+ inkscape:transform-center-x="1.8384776e-06"
+ inkscape:transform-center-y="-5.8288369" />
+ <rect
+ style="fill:#b3b3b3;stroke-width:0.20817307"
+ id="rect16"
+ width="108.24416"
+ height="10.423517"
+ x="36.08139"
+ y="155.5929" />
+ <path
+ sodipodi:type="star"
+ style="fill:#ff8080;stroke-width:0.20018946"
+ id="path14-3"
+ sodipodi:sides="3"
+ sodipodi:cx="124.13387"
+ sodipodi:cy="139.31911"
+ sodipodi:r1="17.640888"
+ sodipodi:r2="8.8204451"
+ sodipodi:arg1="0.52359878"
+ sodipodi:arg2="1.5707963"
+ inkscape:flatsided="false"
+ inkscape:rounded="0"
+ inkscape:randomized="0"
+ d="m 139.41133,148.13955 -15.27746,0 -15.27745,0 7.63872,-13.23067 7.63873,-13.23066 7.63873,13.23066 z"
+ inkscape:transform-center-x="3.9117172e-06"
+ inkscape:transform-center-y="-4.4102243" />
+ <circle
+ style="fill:#ff8080;stroke-width:0.18094084"
+ id="path12-6"
+ cx="87.93705"
+ cy="134.75125"
+ r="14.165282" />
+ <rect
+ id="rect10-7"
+ width="30.52453"
+ height="25.657875"
+ x="37.416695"
+ y="121.65508"
+ style="fill:#ff8080;stroke-width:0.22163473" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="47.387276"
+ y="151.7626"
+ id="text65"><tspan
+ sodipodi:role="line"
+ id="tspan63"
+ x="47.387276"
+ y="151.7626"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="82.156319"
+ y="151.71547"
+ id="text65-5"><tspan
+ sodipodi:role="line"
+ id="tspan63-3"
+ x="82.156319"
+ y="151.71547"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="118.66879"
+ y="151.71547"
+ id="text65-5-5"><tspan
+ sodipodi:role="line"
+ id="tspan63-3-6"
+ x="118.66879"
+ y="151.71547"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="45.473087"
+ y="136.20644"
+ id="text123"><tspan
+ sodipodi:role="line"
+ id="tspan121"
+ x="45.473087"
+ y="136.20644"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="80.153076"
+ y="136.00925"
+ id="text123-1"><tspan
+ sodipodi:role="line"
+ id="tspan121-2"
+ x="80.153076"
+ y="136.00925"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="116.50173"
+ y="138.68195"
+ id="text123-1-7"><tspan
+ sodipodi:role="line"
+ id="tspan121-2-0"
+ x="116.50173"
+ y="138.68195"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:6.43922186px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.16098055"
+ x="81.893562"
+ y="163.15665"
+ id="text163"><tspan
+ sodipodi:role="line"
+ id="tspan161"
+ x="81.893562"
+ y="163.15665"
+ style="stroke-width:0.16098055">host</tspan></text>
+ </g>
+</svg>
diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md
new file mode 100644
index 000000000..b99b86332
--- /dev/null
+++ b/g3doc/architecture_guide/security.md
@@ -0,0 +1,255 @@
+# Security Model
+
+[TOC]
+
+gVisor was created in order to provide additional defense against the
+exploitation of kernel bugs by untrusted userspace code. In order to understand
+how gVisor achieves this goal, it is first necessary to understand the basic
+threat model.
+
+## Threats: The Anatomy of an Exploit
+
+An exploit takes advantage of a software or hardware bug in order to escalate
+privileges, gain access to privileged data, or disrupt services. All of the
+possible interactions that a malicious application can have with the rest of the
+system (attack vectors) define the attack surface. We categorize these attack
+vectors into several common classes.
+
+### System API
+
+An operating system or hypervisor exposes an abstract System API in the form of
+system calls and traps. This API may be documented and stable, as with Linux, or
+it may be abstracted behind a library, as with Windows (i.e. win32.dll or
+ntdll.dll). The System API includes all standard interfaces that application
+code uses to interact with the system. This includes high-level abstractions
+that are derived from low-level system calls, such as system files, sockets and
+namespaces.
+
+Although the System API is exposed to applications by design, bugs and race
+conditions within the kernel or hypervisor may occasionally be exploitable via
+the API. This is common in part due to the fact that most kernels and
+hypervisors are written in [C][clang], which is well-suited to interfacing with
+hardware but often prone to security issues. In order to exploit these issues, a
+typical attack might involve some combination of the following:
+
+1. Opening or creating some combination of files, sockets or other descriptors.
+1. Passing crafted, malicious arguments, structures or packets.
+1. Racing with multiple threads in order to hit specific code paths.
+
+For example, for the [Dirty Cow][dirtycow] privilege escalation bug, an
+application would open a specific file in `/proc` or use a specific `ptrace`
+system call, and use multiple threads in order to trigger a race condition when
+touching a fresh page of memory. The attacker then gains control over a page of
+memory belonging to the system. With additional privileges or access to
+privileged data in the kernel, an attacker will often be able to employ
+additional techniques to gain full access to the rest of the system.
+
+While bugs in the implementation of the System API are readily fixed, they are
+also the most common form of exploit. The exposure created by this class of
+exploit is what gVisor aims to minimize and control, described in detail below.
+
+### System ABI
+
+Hardware and software exploits occasionally exist in execution paths that are
+not part of an intended System API. In this case, exploits may be found as part
+of implicit actions the hardware or privileged system code takes in response to
+certain events, such as traps or interrupts. For example, the recent
+[POPSS][popss] flaw required only native code execution (no specific system call
+or file access). In that case, the Xen hypervisor was similarly vulnerable,
+highlighting that hypervisors are not immune to this vector.
+
+### Side Channels
+
+Hardware side channels may be exploitable by any code running on a system:
+native, sandboxed, or virtualized. However, many host-level mitigations against
+hardware side channels are still effective with a sandbox. For example, kernels
+built with retpoline protect against some speculative execution attacks
+(Spectre) and frame poisoning may protect against L1 terminal fault (L1TF)
+attacks. Hypervisors may introduce additional complications in this regard, as
+there is no mitigation against an application in a normally functioning Virtual
+Machine (VM) exploiting the L1TF vulnerability for another VM on the sibling
+hyperthread.
+
+### Other Vectors
+
+The above categories in no way represent an exhaustive list of exploits, as we
+focus only on running untrusted code from within the operating system or
+hypervisor. We do not consider other ways that a more generic adversary may
+interact with a system, such as inserting a portable storage device with a
+malicious filesystem image, using a combination of crafted keyboard or touch
+inputs, or saturating a network device with ill-formed packets.
+
+Furthermore, high-level systems may contain exploitable components. An attacker
+need not escalate privileges within a container if there’s an exploitable
+network-accessible service on the host or some other API path. *A sandbox is not
+a substitute for a secure architecture*.
+
+## Goals: Limiting Exposure
+
+![Threat model](security.png "Threat model.")
+
+gVisor’s primary design goal is to minimize the System API attack vector through
+multiple layers of defense, while still providing a process model. There are two
+primary security principles that inform this design. First, the application’s
+direct interactions with the host System API are intercepted by the Sentry,
+which implements the System API instead. Second, the System API accessible to
+the Sentry itself is minimized to a safer, restricted set. The first principle
+minimizes the possibility of direct exploitation of the host System API by
+applications, and the second principle minimizes indirect exploitability, which
+is the exploitation by an exploited or buggy Sentry (e.g. chaining an exploit).
+
+The first principle is similar to the security basis for a Virtual Machine (VM).
+With a VM, an application’s interactions with the host are replaced by
+interactions with a guest operating system and a set of virtualized hardware
+devices. These hardware devices are then implemented via the host System API by
+a Virtual Machine Monitor (VMM). The Sentry similarly prevents direct
+interactions by providing its own implementation of the System API that the
+application must interact with. Applications are not able to directly craft
+specific arguments or flags for the host System API, or interact directly with
+host primitives.
+
+For both the Sentry and a VMM, it’s worth noting that while direct interactions
+are not possible, indirect interactions are still possible. For example, a read
+on a host-backed file in the Sentry may ultimately result in a host read system
+call (made by the Sentry, not by passing through arguments from the
+application), similar to how a read on a block device in a VM may result in the
+VMM issuing a corresponding host read system call from a backing file.
+
+An important distinction from a VM is that the Sentry implements a System API
+based directly on host System API primitives instead of relying on virtualized
+hardware and a guest operating system. This selects a distinct set of
+trade-offs, largely in the performance, efficiency and compatibility domains.
+Since transitions in and out of the sandbox are relatively expensive, a guest
+operating system will typically take ownership of resources. For example, in the
+above case, the guest operating system may read the block device data in a local
+page cache, to avoid subsequent reads. This may lead to better performance but
+lower efficiency, since memory may be wasted or duplicated. The Sentry opts
+instead to defer to the host for many operations during runtime, for improved
+efficiency but lower performance in some use cases.
+
+### What can a sandbox do?
+
+An application in a gVisor sandbox is permitted to do most things a standard
+container can do: for example, applications can read and write files mapped
+within the container, make network connections, etc. As described above,
+gVisor's primary goal is to limit exposure to bugs and exploits while still
+allowing most applications to run. Even so, gVisor will limit some operations
+that might be permitted with a standard container. Even with appropriate
+capabilities, a user in a gVisor sandbox will only be able to manipulate
+virtualized system resources (e.g. the system time, kernel settings or
+filesystem attributes) and not underlying host system resources.
+
+While the sandbox virtualizes many operations for the application, we limit the
+sandbox's own interactions with the host to the following high-level operations:
+
+1. Communicate with a Gofer process via a connected socket. The sandbox may
+ receive new file descriptors from the Gofer process, corresponding to opened
+ files. These files can then be read from and written to by the sandbox.
+1. Make a minimal set of host system calls. The calls do not include the
+ creation of new sockets (unless host networking mode is enabled) or opening
+ files. The calls include duplication and closing of file descriptors,
+ synchronization, timers and signal management.
+1. Read and write packets to a virtual ethernet device. This is not required if
+ host networking is enabled (or networking is disabled).
+
+### System ABI, Side Channels and Other Vectors
+
+gVisor relies on the host operating system and the platform for defense against
+hardware-based attacks. Given the nature of these vulnerabilities, there is
+little defense that gVisor can provide (there’s no guarantee that additional
+hardware measures, such as virtualization, memory encryption, etc. would
+actually decrease the attack surface). Note that this is true even when using
+hardware virtualization for acceleration, as the host kernel or hypervisor is
+ultimately responsible for defending against attacks from within malicious
+guests.
+
+gVisor similarly relies on the host resource mechanisms (cgroups) for defense
+against resource exhaustion and denial of service attacks. Network policy
+controls should be applied at the container level to ensure appropriate network
+policy enforcement. Note that the sandbox itself is not capable of altering or
+configuring these mechanisms, and the sandbox itself should make an attacker
+less likely to exploit or override these controls through other means.
+
+## Principles: Defense-in-Depth
+
+For gVisor development, there are several engineering principles that are
+employed in order to ensure that the system meets its design goals.
+
+1. No system call is passed through directly to the host. Every supported call
+ has an independent implementation in the Sentry that is unlikely to suffer
+ from identical vulnerabilities that may appear in the host. This has the
+ consequence that all kernel features used by applications require an
+ implementation within the Sentry.
+1. Only common, universal functionality is implemented. Some filesystems,
+ network devices or modules may expose specialized functionality to user
+ space applications via mechanisms such as extended attributes, raw sockets
+ or ioctls. Since the Sentry is responsible for implementing the full system
+ call surface, we do not implement or pass through these specialized APIs.
+1. The host surface exposed to the Sentry is minimized. While the system call
+ surface is not trivial, it is explicitly enumerated and controlled. The
+ Sentry is not permitted to open new files, create new sockets or do many
+ other interesting things on the host.
+
+Additionally, we have practical restrictions that are imposed on the project to
+minimize the risk of Sentry exploitability. For example:
+
+1. Unsafe code is carefully controlled. All unsafe code is isolated in files
+ that end with "unsafe.go", in order to facilitate validation and auditing.
+ No file without the unsafe suffix may import the unsafe package.
+1. No CGo is allowed. The Sentry must be a pure Go binary.
+1. External imports are not generally allowed within the core packages. Only
+ limited external imports are used within the setup code. The code available
+ inside the Sentry is carefully controlled, to ensure that the above rules
+ are effective.
+
+Finally, we recognize that security is a process, and that vigilance is
+critical. Beyond our security disclosure process, the Sentry is fuzzed
+continuously to identify potential bugs and races proactively, and production
+crashes are recorded and triaged to similarly identify material issues.
+
+## FAQ
+
+### Is this more or less secure than a Virtual Machine?
+
+The security of a VM depends to a large extent on what is exposed from the host
+kernel and userspace support code. For example, device emulation code in the
+host kernel (e.g. APIC) or optimizations (e.g. vhost) can be more complex than a
+simple system call, and exploits carry the same risks. Similarly, the userspace
+support code is frequently unsandboxed, and exploits, while rare, may allow
+unfettered access to the system.
+
+Some platforms leverage the same virtualization hardware as VMs in order to
+provide better system call interception performance. However, gVisor does not
+implement any device emulation, and instead opts to use a sandboxed host System
+API directly. Both approaches significantly reduce the original attack surface.
+Ultimately, since gVisor is capable of using the same hardware mechanism, one
+should not assume that the mere use of virtualization hardware makes a system
+more or less secure, just as it would be a mistake to make the claim that the
+use of a unibody alone makes a car safe.
+
+### Does this stop hardware side channels?
+
+In general, gVisor does not provide protection against hardware side channels,
+although it may make exploits that rely on direct access to the host System API
+more difficult to use. To minimize exposure, you should follow relevant guidance
+from vendors and keep your host kernel and firmware up-to-date.
+
+### Is this just a ptrace sandbox?
+
+No: the term “ptrace sandbox” generally refers to software that uses the Linux
+ptrace facility to inspect and authorize system calls made by applications,
+enforcing a specific policy. These commonly suffer from two issues. First,
+vulnerable system calls may be authorized by the sandbox, as the application
+still has direct access to some System API. Second, it’s impossible to avoid
+time-of-check, time-of-use race conditions without disabling multi-threading.
+
+In gVisor, the platforms that use ptrace operate differently. The stubs that are
+traced are never allowed to continue execution into the host kernel and complete
+a call directly. Instead, all system calls are interpreted and handled by the
+Sentry itself, which reflects the resulting register state back into the tracee
+continuing execution in userspace. This is very similar to the mechanism used by
+User-Mode Linux (UML).
+
+[dirtycow]: https://en.wikipedia.org/wiki/Dirty_COW
+[clang]: https://en.wikipedia.org/wiki/C_(programming_language)
+[popss]: https://nvd.nist.gov/vuln/detail/CVE-2018-8897
diff --git a/g3doc/architecture_guide/security.png b/g3doc/architecture_guide/security.png
new file mode 100644
index 000000000..c29befbf6
--- /dev/null
+++ b/g3doc/architecture_guide/security.png
Binary files differ
diff --git a/g3doc/architecture_guide/security.svg b/g3doc/architecture_guide/security.svg
new file mode 100644
index 000000000..0575e2dec
--- /dev/null
+++ b/g3doc/architecture_guide/security.svg
@@ -0,0 +1,153 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="92.963379mm"
+ height="107.18885mm"
+ viewBox="0 0 92.963379 107.18885"
+ version="1.1"
+ id="svg8"
+ inkscape:version="0.92.4 (5da689c313, 2019-01-14)"
+ sodipodi:docname="defense.svg">
+ <defs
+ id="defs2" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.98994949"
+ inkscape:cx="-242.99254"
+ inkscape:cy="136.90181"
+ inkscape:document-units="mm"
+ inkscape:current-layer="layer4"
+ showgrid="false"
+ inkscape:object-nodes="true"
+ inkscape:window-width="1920"
+ inkscape:window-height="1005"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0" />
+ <metadata
+ id="metadata5">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:groupmode="layer"
+ id="layer2"
+ inkscape:label="Layer 2"
+ transform="translate(-61.112559,-78.160466)">
+ <g
+ id="g4644"
+ style="fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ transform="matrix(1,0,0,-1,2.138671,277.94235)">
+ <path
+ transform="scale(0.26458333)"
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:3.77952766;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ d="M 398.57227,351.84766 224.7832,452.18359 398.57227,552.51953 572.35938,452.18359 Z"
+ id="path4638" />
+ <path
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:3.77952766;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ d="M 572.35938,452.18359 398.57227,552.51953 V 753.19141 L 572.35938,652.85547 Z"
+ transform="scale(0.26458333)"
+ id="path4640" />
+ <path
+ id="path4642"
+ d="m 59.473888,119.64024 45.981172,26.54722 v 53.09443 L 59.473888,172.73467 Z"
+ style="opacity:1;fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ inkscape:connector-curvature="0" />
+ </g>
+ </g>
+ <g
+ inkscape:groupmode="layer"
+ id="layer3"
+ inkscape:label="Layer 3"
+ transform="translate(-61.112559,-78.160466)">
+ <g
+ id="g4554"
+ transform="matrix(-0.39771468,0.69855937,-0.69855937,-0.39771468,366.58103,126.65261)">
+ <g
+ id="g4662"
+ transform="translate(59.46839,130.66062)">
+ <path
+ inkscape:connector-curvature="0"
+ id="path4548"
+ transform="scale(0.26458333)"
+ d="M 398.57227,351.84766 224.7832,452.18359 398.57227,552.51953 572.35938,452.18359 Z"
+ style="opacity:1;fill:#0066ff;fill-opacity:0.34509804;stroke:#00a5ff;stroke-width:4.70182848;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" />
+ <path
+ inkscape:connector-curvature="0"
+ id="path4550"
+ transform="scale(0.26458333)"
+ d="M 572.35938,452.18359 398.57227,552.51953 V 753.19141 L 572.35938,652.85547 Z"
+ style="opacity:1;fill:#0044aa;fill-opacity:0.34509804;stroke:#00a5ff;stroke-width:4.29276943;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" />
+ <path
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:#5599ff;fill-opacity:0.34509804;stroke:#00a5ff;stroke-width:1.24402535;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ d="m 59.473888,119.64024 45.981172,26.54722 v 53.09443 L 59.473888,172.73467 Z"
+ id="path4552" />
+ </g>
+ </g>
+ </g>
+ <g
+ inkscape:groupmode="layer"
+ id="layer4"
+ inkscape:label="Layer 4"
+ transform="translate(-61.112559,-78.160466)">
+ <path
+ style="fill:#e000ae;fill-opacity:1;stroke-width:0.12476727"
+ d="m 84.610811,107.36071 v 2.55773 2.55772 h 2.49535 2.49534 v -2.55772 -2.55773 h -2.49534 z m 40.674129,0 v 2.55773 2.55772 h 2.49535 2.49534 v -2.55772 -2.55773 h -2.49534 z m -35.558669,5.11545 v 2.55773 2.55773 h 2.49535 2.49534 v -2.55773 -2.55773 h -2.49534 z m 4.99069,5.11546 v 2.55773 2.55773 h -2.49534 -2.49535 v 2.49534 2.49535 h -2.55773 -2.55773 v 2.55773 2.55773 h -2.55773 -2.55773 v 10.16853 10.16853 h 2.55773 2.55773 v -7.67562 -7.67587 l 2.52654,0.0339 2.52654,0.0336 0.0327,5.08427 0.0327,5.08426 h 2.49388 2.49388 v 2.55919 2.5592 l 5.08427,-0.0327 5.084269,-0.0326 v -2.49534 -2.49535 l -5.084269,-0.0324 -5.08427,-0.0327 v -2.55626 -2.55651 h 12.726269 12.72626 v 2.55651 2.55626 l -5.05868,0.0327 -5.05893,0.0324 v 2.49535 2.49534 l 5.05893,0.0326 5.05868,0.0327 v -2.55919 -2.55919 h 2.49388 2.49413 l 0.0324,-5.08426 0.0327,-5.08427 2.52653,-0.0336 2.52654,-0.0339 v 7.67586 7.67563 h 2.55773 2.55773 v -10.16854 -10.16853 h -2.55773 -2.55773 v -2.55773 -2.55773 h -2.55773 -2.55773 v -2.49535 -2.49534 h -2.49535 -2.49534 v -2.55773 -2.55773 h -2.55773 -2.55773 v 2.55773 2.55773 h -7.6108 -7.610809 v -2.55773 -2.55773 h -2.55774 z m 25.452519,0 h 2.49535 2.49535 v -2.55773 -2.55773 h -2.49535 -2.49535 v 2.55773 z m -25.452519,10.10615 h 5.11546 5.115459 v 2.55773 2.55773 h -5.115459 -5.11546 v -2.55773 z m 15.221609,0 h 5.11546 5.11545 v 2.55773 2.55773 h -5.11545 -5.11546 v -2.55773 z"
+ id="path4732"
+ inkscape:connector-curvature="0" />
+ </g>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ style="display:inline"
+ transform="translate(-61.112559,-78.160466)">
+ <g
+ transform="translate(-131.49557,42.495842)"
+ style="fill:#007200;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ id="g4628">
+ <path
+ id="path4529"
+ d="m 239.09034,36.164616 -45.98169,26.547215 45.98169,26.547217 45.98117,-26.547217 z"
+ style="opacity:1;fill:#4aba19;fill-opacity:0.34509804;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ inkscape:connector-curvature="0" />
+ <path
+ id="path4531"
+ d="m 285.07151,62.711828 -45.98117,26.54722 v 53.094432 l 45.98117,-26.54722 z"
+ style="opacity:1;fill:#007900;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ inkscape:connector-curvature="0" />
+ <path
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:#003d00;fill-opacity:0.34509804;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ d="m 193.10865,62.711831 45.98117,26.54722 v 53.094429 l -45.98117,-26.54722 z"
+ id="path4541" />
+ </g>
+ </g>
+</svg>
diff --git a/g3doc/community.md b/g3doc/community.md
new file mode 100644
index 000000000..76f4d87c3
--- /dev/null
+++ b/g3doc/community.md
@@ -0,0 +1,31 @@
+# Participation
+
+To contribute code, please read the [contributing guide](../CONTRIBUTING.md).
+
+Please note that the [Code of Conduct](../CODE_OF_CONDUCT.md) applies to
+community forums as well as technical participation.
+
+## Communication channels
+
+The project maintains two mailing lists:
+
+* [gvisor-users][gvisor-users] for announcements and general discussion.
+* [gvisor-dev][gvisor-dev] for development and contribution.
+
+We also have a [chat room hosted on Gitter][gitter-chat].
+
+We'd love to hear from you!
+
+## Community meetings
+
+The community calendar shows upcoming public meetings and opportunities to
+collaborate or discuss the project. Meetings are planned and announced ahead of
+time via the [gvisor-users][gvisor-users] mailing list.
+
+These meetings are public: anyone can join.
+
+<iframe src="https://calendar.google.com/calendar/b/1/embed?showTitle=0&amp;height=600&amp;wkst=1&amp;bgcolor=%23FFFFFF&amp;src=bd6f4k210u3ukmlj9b8vl053fk%40group.calendar.google.com&amp;color=%23AB8B00&amp;ctz=America%2FLos_Angeles" style="border-width:0" width="600" height="400" frameborder="0" scrolling="no"></iframe>
+
+[gitter-chat]: https://gitter.im/gvisor/community
+[gvisor-dev]: https://groups.google.com/forum/#!forum/gvisor-dev
+[gvisor-users]: https://groups.google.com/forum/#!forum/gvisor-users
diff --git a/g3doc/logo.png b/g3doc/logo.png
new file mode 100644
index 000000000..bd1a1e4b7
--- /dev/null
+++ b/g3doc/logo.png
Binary files differ
diff --git a/g3doc/logo.txt b/g3doc/logo.txt
new file mode 100644
index 000000000..92f9cad5f
--- /dev/null
+++ b/g3doc/logo.txt
@@ -0,0 +1 @@
+The gVisor logo files are licensed under CC BY-SA 4.0 (Creative Commons Attribution-ShareAlike 4.0 International).
diff --git a/g3doc/roadmap.md b/g3doc/roadmap.md
new file mode 100644
index 000000000..06ea25a8b
--- /dev/null
+++ b/g3doc/roadmap.md
@@ -0,0 +1,49 @@
+# Roadmap
+
+gVisor [GitHub Issues][issues] serve as the source-of-truth for most work in
+flight. Specific performance and compatibility issues are generally tracked
+there. [GitHub Milestones][milestones] may be used to track larger features that
+span many issues. However, labels are also used to aggregate cross-cutting
+feature work.
+
+## Core Improvements
+
+Most gVisor work is focused on four areas.
+
+* [Performance][performance]: overall sandbox performance, including platform
+ performance, is a critical area for investment. This includes: network
+ performance (throughput and latency), file system performance (metadata and
+ data I/O), application switch and fault costs, etc. The goal of gVisor is to
+ provide sandboxing without a material performance or efficiency impact on
+ all but the most performance-sensitive applications.
+
+* [Compatibility][compatibility]: supporting a wide range of applications
+ requires supporting a large system API, including special system files (e.g.
+ proc, sys, dev, etc.). The goal of gVisor is to support the broad set of
+ applications that depend on a generic Linux API, rather than a specific
+ kernel version.
+
+* [Infrastructure & tooling][infrastructure]: the above goals require
+ aggressive testing and coverage, and well-established processes. This
+ includes adding appropriate system call coverage, end-to-end suites and
+ runtime tests.
+
+* [Integration][integration]: Container infrastructure is evolving rapidly and
+ becoming more complex, and gVisor must continuously implement relevant and
+ popular features to ensure that integration points remain robust and
+ feature-complete while preserving security guarantees.
+
+## Releases
+
+Releases are available on [GitHub][releases].
+
+As a convenience, binary packages are also published. Instructions for their use
+are available via the [Installation instructions](./user_guide/install.md).
+
+[issues]: https://github.com/google/gvisor/issues
+[milestones]: https://github.com/google/gvisor/milestones
+[releases]: https://github.com/google/gvisor/releases
+[performance]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+performance%22
+[integration]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+integration%22
+[compatibility]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+compatibility%22
+[infrastructure]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+tooling%22
diff --git a/g3doc/style.md b/g3doc/style.md
new file mode 100644
index 000000000..d10549fe9
--- /dev/null
+++ b/g3doc/style.md
@@ -0,0 +1,88 @@
+# Provisional style guide
+
+> These guidelines are new and may change. This note will be removed when
+> consensus is reached.
+
+Not all existing code will comply with this style guide, but new code should.
+Further, it is a goal to eventually update all existing code to be in
+compliance.
+
+## All code
+
+### Early exit
+
+Unless it substantially increases the line count or complexity, all code should
+use early exits from loops and functions where possible.
+
+## Go specific
+
+All Go code should comply with the [Go Code Review Comments][gostyle] and
+[Effective Go][effective_go] guides, as well as the additional guidelines
+described below.
+
+### Mutexes
+
+#### Naming
+
+Mutexes should be named `mu` or `xxxMu`. As a general rule, mutexes should not
+be exported. Instead, export methods which use the mutexes to avoid leaky
+abstractions.
+
+#### Location
+
+Mutexes should be sibling fields to the fields that they protect. Mutexes
+should not be declared as global variables; instead, use a struct (anonymous is
+ok, but naming conventions still apply).
+
+Mutexes should be ordered before the fields that they protect.
+
+#### Comments
+
+Mutexes should have a comment on their declaration explaining any ordering
+requirements (or pointing to where this information can be found), if
+applicable. There is no need for a comment explaining which fields are
+protected.
+
+Each field or variable protected by a mutex should state as such in a comment on
+the field or variable declaration.
+
+### Unused returns
+
+Unused returns should be explicitly ignored with underscores. If there is a
+function which is commonly used without using its return(s), a wrapper function
+should be declared which explicitly ignores the returns. That said, in many
+cases, it may make sense for the wrapper to check the returns.
+
+### Formatting verbs
+
+Built-in types should use their associated verbs (e.g. `%d` for integral
+types), but other types should use a `%v` variant, even if they implement
+`fmt.Stringer`. The built-in `error` type should use `%w` when formatted with
+`fmt.Errorf`, but only then.
+
+### Wrapping
+
+Comments should be wrapped at 80 columns with a 2 space tab size.
+
+Code does not need to be wrapped, but if wrapping would make it more readable,
+it should be wrapped with each subcomponent of the thing being wrapped on its
+own line. For example, if a struct is split between lines, each field should be
+on its own line.
+
+#### Example
+
+```go
+_ = exec.Cmd{
+ Path: "/foo/bar",
+ Args: []string{"-baz"},
+}
+```
+
+## C++ specific
+
+C++ code should conform to the [Google C++ Style Guide][cppstyle] and the
+guidelines described for tests.
+
+[cppstyle]: https://google.github.io/styleguide/cppguide.html
+[gostyle]: https://github.com/golang/go/wiki/CodeReviewComments
+[effective_go]: https://golang.org/doc/effective_go.html
diff --git a/g3doc/user_guide/BUILD b/g3doc/user_guide/BUILD
new file mode 100644
index 000000000..b69aee12c
--- /dev/null
+++ b/g3doc/user_guide/BUILD
@@ -0,0 +1,70 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "compatibility",
+ src = "compatibility.md",
+ category = "Compatibility",
+ permalink = "/docs/user_guide/compatibility/",
+ weight = "0",
+)
+
+doc(
+ name = "checkpoint_restore",
+ src = "checkpoint_restore.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/checkpoint_restore/",
+ weight = "60",
+)
+
+doc(
+ name = "debugging",
+ src = "debugging.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/debugging/",
+ weight = "70",
+)
+
+doc(
+ name = "FAQ",
+ src = "FAQ.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/faq/",
+ weight = "90",
+)
+
+doc(
+ name = "filesystem",
+ src = "filesystem.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/filesystem/",
+ weight = "40",
+)
+
+doc(
+ name = "networking",
+ src = "networking.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/networking/",
+ weight = "50",
+)
+
+doc(
+ name = "install",
+ src = "install.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/install/",
+ weight = "10",
+)
+
+doc(
+ name = "platforms",
+ src = "platforms.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/platforms/",
+ weight = "30",
+)
diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md
new file mode 100644
index 000000000..89df65e99
--- /dev/null
+++ b/g3doc/user_guide/FAQ.md
@@ -0,0 +1,122 @@
+# FAQ
+
+[TOC]
+
+### What operating systems are supported? {#supported-os}
+
+Today, gVisor requires Linux.
+
+### What CPU architectures are supported? {#supported-cpus}
+
+gVisor currently supports [x86_64/AMD64](https://en.wikipedia.org/wiki/X86-64)
+compatible processors. Preliminary support is also available for
+[ARM64](https://en.wikipedia.org/wiki/ARM_architecture#AArch64).
+
+### Do I need to modify my Linux application to use gVisor? {#modify-app}
+
+No. gVisor is capable of running unmodified Linux binaries.
+
+### What binary formats does gVisor support? {#supported-binaries}
+
+gVisor supports Linux
+[ELF](https://en.wikipedia.org/wiki/Executable_and_Linkable_Format) binaries.
+
+Binaries run in gVisor should be built for the
+[AMD64](https://en.wikipedia.org/wiki/X86-64) or
+[AArch64](https://en.wikipedia.org/wiki/ARM_architecture#AArch64) CPU
+architectures.
+
+### Can I run Docker images using gVisor? {#docker-images}
+
+Yes. Please see the [Docker Quick Start][docker].
+
+### Can I run Kubernetes pods using gVisor? {#k8s-pods}
+
+Yes. Please see the [Kubernetes Quick Start][k8s].
+
+### What's the security model? {#security-model}
+
+See the [Security Model][security-model].
+
+## Troubleshooting
+
+### My container runs fine with `runc` but fails with `runsc` {#app-compatibility}
+
+If you’re having problems running a container with `runsc` it’s most likely due
+to a compatibility issue or a missing feature in gVisor. See
+[Debugging][debugging].
+
+### When I run my container, docker fails with: `open /run/containerd/.../<containerid>/log.json: no such file or directory` {#memfd-create}
+
+You are using an older version of Linux which doesn't support `memfd_create`.
+
+This is tracked in [bug #268](https://gvisor.dev/issue/268).
+
+### When I run my container, docker fails with: `flag provided but not defined: -console` {#old-docker}
+
+You're using an old version of Docker. See [Docker Quick Start][docker].
+
+### I can’t see a file copied with: `docker cp` {#fs-cache}
+
+For performance reasons, gVisor caches directory contents, and therefore it may
+not realize a new file was copied to a given directory. To invalidate the cache
+and force a refresh, create a file under the directory in question and list the
+contents again.
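+
+For example, assuming a running container named `my-container` and a file
+copied into `/tmp` (both names are hypothetical), a refresh might look like
+this:
+
+```bash
+# Copy a file into the running container; it may not be visible yet.
+docker cp ./data.txt my-container:/tmp/
+
+# Create a file under the same directory to invalidate the cache, then list
+# the contents again.
+docker exec my-container sh -c 'touch /tmp/.refresh && ls /tmp'
+```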
+
+As a workaround, a shared root filesystem can be enabled. See
+[Filesystem][filesystem].
+
+This bug is tracked in [bug #4](https://gvisor.dev/issue/4).
+
+Note that `kubectl cp` works because it does the copy by exec'ing inside the
+sandbox, and thus gVisor's internal cache is made aware of the new files and
+directories.
+
+### I'm getting an error like: `panic: unable to attach: operation not permitted` or `fork/exec /proc/self/exe: invalid argument: unknown` {#runsc-perms}
+
+Make sure that the permissions and owner are correct on the `runsc` binary.
+
+```bash
+sudo chown root:root /usr/local/bin/runsc
+sudo chmod 0755 /usr/local/bin/runsc
+```
+
+### I'm getting an error like `mount submount "/etc/hostname": creating mount with source ".../hostname": input/output error: unknown.` {#memlock}
+
+There is a bug in Linux kernel versions 5.1 to 5.3.15, 5.4.2, and 5.5. Upgrade
+to a newer kernel or add the following to
+`/lib/systemd/system/containerd.service` as a workaround.
+
+```
+LimitMEMLOCK=infinity
+```
+
+And run `systemctl daemon-reload && systemctl restart containerd` to restart
+containerd.
+
+See [issue #1765](https://gvisor.dev/issue/1765) for more details.
+
+### My container cannot resolve another container's name when using Docker user defined bridge {#docker-bridge}
+
+This is normally indicated by errors like `bad address 'container-name'` when
+trying to communicate to another container in the same network.
+
+Docker user defined bridge uses an embedded DNS server bound to the loopback
+interface on address 127.0.0.10. This requires access to the host network in
+order to communicate to the DNS server. runsc network is isolated from the host
+and cannot access the DNS server on the host network without breaking the
+sandbox isolation. There are a few different workarounds you can try:
+
+* Use the default bridge network with `--link` to connect containers. The
+ default bridge doesn't use embedded DNS. See the sketch after this list.
+* Use [`--network=host`][host-net] option in runsc, however beware that it
+ will use the host network stack and is less secure.
+* Use IPs instead of container names.
+* Use [Kubernetes][k8s]. Container name lookup works fine in Kubernetes.
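+
+As a sketch of the first workaround (image and container names here are
+hypothetical), the default bridge with `--link` could be used like this:
+
+```bash
+# Run the backend container on the default bridge network.
+docker run --runtime=runsc -d --name backend my-backend-image
+
+# Link it into the client container; the name "backend" now resolves without
+# the embedded DNS server.
+docker run --runtime=runsc --rm --link backend:backend alpine ping -c 1 backend
+```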
+
+[security-model]: /docs/architecture_guide/security/
+[host-net]: /docs/user_guide/networking/#network-passthrough
+[debugging]: /docs/user_guide/debugging/
+[filesystem]: /docs/user_guide/filesystem/
+[docker]: /docs/user_guide/quick_start/docker/
+[k8s]: /docs/user_guide/quick_start/kubernetes/
diff --git a/g3doc/user_guide/checkpoint_restore.md b/g3doc/user_guide/checkpoint_restore.md
new file mode 100644
index 000000000..0ab0911b0
--- /dev/null
+++ b/g3doc/user_guide/checkpoint_restore.md
@@ -0,0 +1,101 @@
+# Checkpoint/Restore
+
+[TOC]
+
+gVisor has the ability to checkpoint a process, save its current state in a
+state file, and restore into a new container using the state file.
+
+## How to use checkpoint/restore
+
+Checkpoint/restore functionality is currently available via raw `runsc`
+commands. To use the checkpoint command, first run a container.
+
+```bash
+runsc run <container id>
+```
+
+To checkpoint the container, the `--image-path` flag must be provided. This is
+the directory path within which the checkpoint state-file will be created. The
+file will be called `checkpoint.img` and necessary directories will be created
+if they do not yet exist.
+
+> Note: Two checkpoints cannot be saved to the same directory; every image-path
+> provided must be unique.
+
+```bash
+runsc checkpoint --image-path=<path> <container id>
+```
+
+There is also an optional `--leave-running` flag that allows the container to
+continue to run after the checkpoint has been made. (By default, containers stop
+their processes after committing a checkpoint.)
+
+> Note: All top-level `runsc` flags needed when calling `run` must be provided
+> to `checkpoint` if `--leave-running` is used.
+
+> Note: `--leave-running` works by triggering an immediate restore, so the
+> container, although it keeps its given container id, may have a different
+> process id.
+
+```bash
+runsc checkpoint --image-path=<path> --leave-running <container id>
+```
+
+To restore, provide the image path to the `checkpoint.img` file created during
+the checkpoint. Because containers stop by default after checkpointing, restore
+needs to happen in a new container (restore is a command which parallels start).
+
+```bash
+runsc create <container id>
+
+runsc restore --image-path=<path> <container id>
+```
+
+## How to use checkpoint/restore in Docker:
+
+Currently checkpoint/restore through `runsc` is not entirely compatible with
+Docker, although there has been progress made from both gVisor and Docker to
+enable compatibility. Here, we document the ideal workflow.
+
+Run a container:
+
+```bash
+docker run [options] --runtime=runsc <image>
+```
+
+Checkpoint a container:
+
+```bash
+docker checkpoint create <container> <checkpoint_name>
+```
+
+Create a new container into which to restore:
+
+```bash
+docker create [options] --runtime=runsc <image>
+```
+
+Restore a container:
+
+```bash
+docker start --checkpoint --checkpoint-dir=<directory> <container>
+```
+
+### Issues Preventing Compatibility with Docker
+
+- **[Moby #37360][leave-running]:** Docker version 18.03.0-ce and earlier
+ hangs when checkpointing and does not create the checkpoint. This is caused
+ by an improper implementation of the `--leave-running` flag and is fixed in
+ newer releases. To use this feature with an affected version, install a
+ custom version of docker-ce from the moby repository.
+- **Docker does not support restoration into new containers:** Docker
+ currently expects the container which created the checkpoint to be the same
+ container used to restore, which is not possible in runsc. When Docker
+ supports container migration and therefore restoration into new containers,
+ this will be the flow.
+- **[Moby #37344][checkpoint-dir]:** Docker does not currently support the
+ `--checkpoint-dir` flag but this will be required when restoring from a
+ checkpoint made in another container.
+
+[leave-running]: https://github.com/moby/moby/pull/37360
+[checkpoint-dir]: https://github.com/moby/moby/issues/37344
diff --git a/g3doc/user_guide/compatibility.md b/g3doc/user_guide/compatibility.md
new file mode 100644
index 000000000..9d3e3680f
--- /dev/null
+++ b/g3doc/user_guide/compatibility.md
@@ -0,0 +1,93 @@
+# Applications
+
+[TOC]
+
+gVisor implements a large portion of the Linux surface, and while we strive to
+make it broadly compatible, there are (and always will be) unimplemented
+features and bugs. The only real way to know if it will work is to try. If you
+find a container that doesn’t work and there is no known issue, please
+[file a bug][bug] indicating the full command you used to run the image. You can
+view open issues related to compatibility [here][issues].
+
+If you're able to provide the [debug logs](../debugging/), the problem is
+likely to be fixed much faster.
+
+## What works?
+
+The following applications/images have been tested:
+
+* elasticsearch
+* golang
+* httpd
+* java8
+* jenkins
+* mariadb
+* memcached
+* mongo
+* mysql
+* nginx
+* node
+* php
+* postgres
+* prometheus
+* python
+* redis
+* registry
+* tomcat
+* wordpress
+
+## Utilities
+
+Most common utilities work. Note that:
+
+* Some tools, such as `tcpdump` and old versions of `ping`, require explicitly
+ enabling raw sockets via the unsafe `--net-raw` runsc flag. A sketch of how
+ to enable this is shown after this list.
+* Different Docker images can behave differently. For example, Alpine Linux
+ and Ubuntu have different `ip` binaries.
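+
+As a sketch of the first point, a separate debugging runtime that enables raw
+sockets can be registered with the `runsc install` helper described in the
+Docker quick start (the runtime name here is arbitrary):
+
+```bash
+# Register a Docker runtime that passes the unsafe --net-raw flag.
+sudo runsc install --runtime runsc-raw -- --net-raw
+sudo systemctl restart docker
+
+# Tools that need raw sockets can then be run with that runtime.
+docker run --runtime=runsc-raw --rm alpine ping -c 1 8.8.8.8
+```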
+
+Specific tools include:
+
+<!-- mdformat off(don't wrap the table) -->
+
+| Tool | Status |
+|:--------:|:-----------------------------------------:|
+| apt-get | Working. |
+| bundle | Working. |
+| cat | Working. |
+| curl | Working. |
+| dd | Working. |
+| df | Working. |
+| dig | Working. |
+| drill | Working. |
+| env | Working. |
+| find | Working. |
+| gdb | Working. |
+| gosu | Working. |
+| grep | Working (unless stdin is a pipe and stdout is /dev/null). |
+| ifconfig | Works partially, like ip. Full support [in progress](https://gvisor.dev/issue/578). |
+| ip | Some subcommands work (e.g. addr, route). Full support [in progress](https://gvisor.dev/issue/578). |
+| less | Working. |
+| ls | Working. |
+| lsof | Working. |
+| mount | Works in readonly mode. gVisor doesn't currently support creating new mounts at runtime. |
+| nc | Working. |
+| nmap | Not working. |
+| netstat | [In progress](https://gvisor.dev/issue/2112). |
+| nslookup | Working. |
+| ping | Working. |
+| ps | Working. |
+| route | Working. |
+| ss | [In progress](https://gvisor.dev/issue/2114). |
+| sshd | Partially working. Job control [in progress](https://gvisor.dev/issue/154). |
+| strace | Working. |
+| tar | Working. |
+| tcpdump | [In progress](https://gvisor.dev/issue/173). |
+| top | Working. |
+| uptime | Working. |
+| vim | Working. |
+| wget | Working. |
+
+<!-- mdformat on -->
+
+[bug]: https://github.com/google/gvisor/issues/new?title=Compatibility%20Issue:
+[issues]: https://github.com/google/gvisor/issues?q=is%3Aissue+is%3Aopen+label%3A%22area%3A+compatibility%22
diff --git a/g3doc/user_guide/debugging.md b/g3doc/user_guide/debugging.md
new file mode 100644
index 000000000..54fdce34f
--- /dev/null
+++ b/g3doc/user_guide/debugging.md
@@ -0,0 +1,141 @@
+# Debugging
+
+[TOC]
+
+To enable debug and system call logging, add the `runtimeArgs` below to your
+[Docker](../quick_start/docker/) configuration (`/etc/docker/daemon.json`):
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--debug-log=/tmp/runsc/",
+ "--debug",
+ "--strace"
+ ]
+ }
+ }
+}
+```
+
+> Note: the last `/` in `--debug-log` is needed to interpret it as a directory.
+> Then each `runsc` command executed will create a separate log file. Otherwise,
+> log messages from all commands will be appended to the same file.
+
+You may also want to pass `--log-packets` to troubleshoot network problems. Then
+restart the Docker daemon:
+
+```bash
+sudo systemctl restart docker
+```
+
+Run your container again, and inspect the files under `/tmp/runsc`. The log file
+ending with `.boot` will contain the strace logs from your application, which
+can be useful for identifying missing or broken system calls in gVisor. If you
+are having problems starting the container, the log file ending with `.create`
+may have the reason for the failure.
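+
+For example, a quick way to look through the logs (the exact file names depend
+on the command executed and a timestamp):
+
+```bash
+# List the per-command log files.
+ls /tmp/runsc/
+
+# Scan the boot logs, which include the strace output, for failures.
+grep -i error /tmp/runsc/*.boot
+```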
+
+## Stack traces
+
+The command `runsc debug --stacks` collects stack traces while the sandbox is
+running, which can be useful to troubleshoot issues or just to learn more about
+gVisor. It connects to the sandbox process, collects a stack dump, and writes it
+to the console. For example:
+
+```bash
+docker run --runtime=runsc --rm -d alpine sh -c "while true; do echo running; sleep 1; done"
+63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+
+sudo runsc --root /var/run/docker/runtime-runsc/moby debug --stacks 63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+```
+
+> Note: The `--root` value is provided by Docker and is normally set to
+> `/var/run/docker/runtime-[runtime-name]/moby`. If in doubt, the value of
+> `--root` is logged in the `runsc` logs.
+
+## Debugger
+
+You can debug gVisor like any other Golang program. If you're running with
+Docker, you'll need to find the sandbox PID and attach the debugger as root.
+Here is an example:
+
+```bash
+# Get a runsc with debug symbols (download nightly or build with symbols).
+bazel build -c dbg //runsc:runsc
+
+# Start the container you want to debug.
+docker run --runtime=runsc --rm --name=test -d alpine sleep 1000
+
+# Find the sandbox PID.
+docker inspect test | grep Pid | head -n 1
+
+# Attach your favorite debugger.
+sudo dlv attach <PID>
+
+# Set a breakpoint and resume.
+break mm.MemoryManager.MMap
+continue
+```
+
+## Profiling
+
+`runsc` integrates with Go profiling tools and gives you easy commands to
+profile CPU and heap usage. First you need to enable `--profile` in the command
+line options before starting the container:
+
+```json
+{
+ "runtimes": {
+ "runsc-prof": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--profile"
+ ]
+ }
+ }
+}
+```
+
+> Note: Enabling profiling loosens the seccomp protection added to the sandbox,
+> and should not be run in production under normal circumstances.
+
+Then restart docker to refresh the runtime options. While the container is
+running, execute `runsc debug` to collect profile information and save it to a
+file. Here are the options available:
+
+* **--profile-heap:** Generates a heap profile in the specified file.
+* **--profile-cpu:** Enables the CPU profiler, waits for `--duration` seconds,
+ and generates a CPU profile in the specified file.
+
+For example:
+
+```bash
+docker run --runtime=runsc-prof --rm -d alpine sh -c "while true; do echo running; sleep 1; done"
+63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+
+sudo runsc --root /var/run/docker/runtime-runsc-prof/moby debug --profile-heap=/tmp/heap.prof 63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+sudo runsc --root /var/run/docker/runtime-runsc-prof/moby debug --profile-cpu=/tmp/cpu.prof --duration=30s 63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+```
+
+The resulting files can be opened using `go tool pprof` or [pprof][]. The
+examples below create an image file (`.svg`) from the heap profile and write
+the top CPU-consuming functions to the console:
+
+```bash
+go tool pprof -svg /usr/local/bin/runsc /tmp/heap.prof
+go tool pprof -top /usr/local/bin/runsc /tmp/cpu.prof
+```
+
+[pprof]: https://github.com/google/pprof/blob/master/doc/README.md
+
+### Docker Proxy
+
+When forwarding a port to the container, Docker will likely route traffic
+through the [docker-proxy][]. This proxy may make profiling noisy, so it can be
+helpful to bypass it. Do so by sending traffic directly to the container IP and
+port. For example, if the `docker0` IP is `192.168.9.1`, the container IP is
+likely a subsequent IP, such as `192.168.9.2`.
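+
+As a sketch (container name and port are hypothetical), the container IP can be
+queried from Docker and targeted directly:
+
+```bash
+# Find the container's IP address.
+IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' my-container)
+
+# Send traffic directly to the container, bypassing docker-proxy.
+curl "http://${IP}:8080/"
+```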
+
+[docker-proxy]: https://windsock.io/the-docker-proxy/
diff --git a/g3doc/user_guide/filesystem.md b/g3doc/user_guide/filesystem.md
new file mode 100644
index 000000000..cd00762dd
--- /dev/null
+++ b/g3doc/user_guide/filesystem.md
@@ -0,0 +1,60 @@
+# Filesystem
+
+[TOC]
+
+gVisor accesses the filesystem through a file proxy, called the Gofer. The
+Gofer runs as a separate process that is isolated from the sandbox. Gofer
+instances communicate with their respective sentry using the 9P protocol. For
+another explanation see [What is gVisor?](../README.md).
+
+## Sandbox overlay
+
+To isolate the host filesystem from the sandbox, you can set a writable tmpfs
+overlay on top of the entire filesystem. All modifications are made to the
+overlay, keeping the host filesystem unmodified.
+
+> Note: All created and modified files are stored in memory inside the sandbox.
+
+To use the tmpfs overlay, add the following `runtimeArgs` to your Docker
+configuration (`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--overlay"
+ ]
+ }
+ }
+}
+```
+
+## Shared root filesystem
+
+The root filesystem is where the image is extracted and is not generally
+modified from outside the sandbox. This allows for some optimizations, like
+skipping checks to determine if a directory has changed since the last time it
+was cached, at the cost of missing updates that happen behind the sandbox's
+back. If you need to `docker cp` files inside the root filesystem, you may want
+to enable shared mode. Just be aware that file system access will be slower due
+to the extra checks that are required.
+
+> Note: External mounts are always shared.
+
+To make the root filesystem shared, add the following `runtimeArgs` to your
+Docker configuration (`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--file-access=shared"
+ ]
+ }
+ }
+}
+```
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
new file mode 100644
index 000000000..9afdd264d
--- /dev/null
+++ b/g3doc/user_guide/install.md
@@ -0,0 +1,157 @@
+# Installation
+
+[TOC]
+
+> Note: gVisor supports only x86\_64 and requires Linux 4.14.77+
+> ([older Linux](./networking.md#gso)).
+
+## Versions
+
+The `runsc` binaries and repositories are available in multiple versions and
+release channels. You should pick the version you'd like to install. For
+experimentation, the nightly release is recommended. For production use, the
+latest release is recommended.
+
+After selecting an appropriate release channel from the options below, proceed
+to the preferred installation mechanism: manual or from an `apt` repository.
+
+### HEAD
+
+Binaries are available for every commit on the `master` branch at the
+following URL:
+
+`https://storage.googleapis.com/gvisor/releases/master/latest/runsc`
+
+Checksums for the release binary are at:
+
+`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512`
+
+For `apt` installation, use `master` as the `${DIST}` below.
+
+### Nightly
+
+Nightly releases are built most nights from the master branch, and are available
+at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc`
+
+Checksums for the release binary are at:
+
+`https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc.sha512`
+
+Specific nightly releases can be found at:
+
+`https://storage.googleapis.com/gvisor/releases/nightly/${yyyy-mm-dd}/runsc`
+
+Note that a release may not be available for every day.
+
+For `apt` installation, use `nightly` as the `${DIST}` below.
+
+### Latest release
+
+The latest official release is available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/release/latest`
+
+For `apt` installation, use `release` as the `${DIST}` below.
+
+### Specific release
+
+A given release is available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}`
+
+See the [releases][releases] page for information about specific releases.
+
+For `apt` installation of a specific release, which may include point updates,
+use the date of the release, e.g. `${yyyymmdd}`, as the `${DIST}` below.
+
+> Note: only newer releases may be available as `apt` repositories.
+
+### Point release
+
+A given point release is available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}`
+
+Note that `apt` installation of a specific point release is not supported.
+
+## Install from an `apt` repository
+
+First, appropriate dependencies must be installed to allow `apt` to install
+packages via https:
+
+```bash
+sudo apt-get update && \
+sudo apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common
+```
+
+Next, the key used to sign archives should be added to your `apt` keychain:
+
+```bash
+curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
+```
+
+Based on the release type, you will need to substitute `${DIST}` below, using
+one of:
+
+* `master`: For HEAD.
+* `nightly`: For nightly releases.
+* `release`: For the latest release.
+* `${yyyymmdd}`: For a specific release (see above).
+
+The repository for the release you wish to install should be added:
+
+```bash
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases ${DIST} main"
+```
+
+For example, to install the latest official release, you can use:
+
+```bash
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+```
+
+Now the runsc package can be installed:
+
+```bash
+sudo apt-get update && sudo apt-get install -y runsc
+```
+
+If you have Docker installed, it will be automatically configured.
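+
+As a quick sanity check (not an official verification step), you can confirm
+that the runtime was registered:
+
+```bash
+# The runsc runtime should appear in the Docker daemon configuration...
+grep runsc /etc/docker/daemon.json
+
+# ...and in the list of runtimes known to the daemon.
+docker info | grep -i runtimes
+```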
+
+## Install directly
+
+The binary URLs provided above can be used to install directly. For example, the
+latest nightly binary can be downloaded, validated, and placed in an appropriate
+location by running:
+
+```bash
+(
+ set -e
+ URL=https://storage.googleapis.com/gvisor/releases/nightly/latest
+ wget ${URL}/runsc
+ wget ${URL}/runsc.sha512
+ sha512sum -c runsc.sha512
+ rm -f runsc.sha512
+ sudo mv runsc /usr/local/bin
+ sudo chown root:root /usr/local/bin/runsc
+ sudo chmod 0755 /usr/local/bin/runsc
+)
+```
+
+**It is important to copy this binary to a location that is accessible to all
+users, and ensure it is executable by all users**, since `runsc` executes itself
+as user `nobody` to avoid unnecessary privileges. The `/usr/local/bin` directory
+is a good place to put the `runsc` binary.
+
+After installation, try out `runsc` by following the
+[Docker Quick Start](./quick_start/docker.md) or
+[OCI Quick Start](./quick_start/oci.md).
+
+[releases]: https://github.com/google/gvisor/releases
diff --git a/g3doc/user_guide/networking.md b/g3doc/user_guide/networking.md
new file mode 100644
index 000000000..4aa394c91
--- /dev/null
+++ b/g3doc/user_guide/networking.md
@@ -0,0 +1,85 @@
+# Networking
+
+[TOC]
+
+gVisor implements its own network stack called [netstack][netstack]. All aspects
+of the network stack are handled inside the Sentry — including TCP connection
+state, control messages, and packet assembly — keeping it isolated from the host
+network stack. Data link layer packets are written directly to the virtual
+device inside the network namespace set up by Docker or Kubernetes.
+
+The IP address and routes configured for the device are transferred inside the
+sandbox. The loopback device runs exclusively inside the sandbox and does not
+use the host. You can inspect them by running:
+
+```bash
+docker run --rm --runtime=runsc alpine ip addr
+```
+
+## Network passthrough
+
+For high-performance networking applications, you may choose to disable the user
+space network stack and instead use the host network stack, including the
+loopback. Note that this mode decreases the isolation to the host.
+
+Add the following `runtimeArgs` to your Docker configuration
+(`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--network=host"
+ ]
+ }
+ }
+}
+```
+
+## Disabling external networking
+
+To completely isolate the host and network from the sandbox, external networking
+can be disabled. The sandbox will still contain a loopback provided by netstack.
+
+Add the following `runtimeArgs` to your Docker configuration
+(`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--network=none"
+ ]
+ }
+ }
+}
+```
+
+### Disable GSO {#gso}
+
+If your Linux kernel is older than 4.14.17, you can disable Generic
+Segmentation Offload (GSO) and run with any kernel newer than 3.17. Add the
+`--gso=false` flag to your Docker runtime configuration
+(`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+> Note: Network performance, especially for large payloads, will be greatly
+> reduced.
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--gso=false"
+ ]
+ }
+ }
+}
+```
+
+[netstack]: https://github.com/google/netstack
diff --git a/g3doc/user_guide/platforms.md b/g3doc/user_guide/platforms.md
new file mode 100644
index 000000000..752025881
--- /dev/null
+++ b/g3doc/user_guide/platforms.md
@@ -0,0 +1,95 @@
+# Changing Platforms
+
+[TOC]
+
+This guide describes how to change the
+[platform](../architecture_guide/platforms.md) used by `runsc`.
+
+## Prerequisites
+
+If you intend to run the KVM platform, you will also need to have KVM installed
+on your system. If you are running a Debian-based system like Debian or Ubuntu,
+you can usually do this by ensuring the module is loaded, and permissions are
+appropriately set on the `/dev/kvm` device.
+
+If you have an Intel CPU:
+
+```bash
+sudo modprobe kvm-intel && sudo chmod a+rw /dev/kvm
+```
+
+If you have an AMD CPU:
+
+```bash
+sudo modprobe kvm-amd && sudo chmod a+rw /dev/kvm
+```
+
+If you are using a virtual machine, you will need to make sure that nested
+virtualization is configured. Here are links to documents on how to set up
+nested virtualization in several popular environments:
+
+* Google Cloud: [Enabling Nested Virtualization for VM Instances][nested-gcp]
+* Microsoft Azure:
+ [How to enable nested virtualization in an Azure VM][nested-azure]
+* VirtualBox: [Nested Virtualization][nested-virtualbox]
+* KVM: [Nested Guests][nested-kvm]
+
+***Note: nested virtualization will have poor performance and is historically a
+cause of security issues (e.g.
+[CVE-2018-12904](https://nvd.nist.gov/vuln/detail/CVE-2018-12904)). It is not
+recommended for production.***
+
+## Configuring Docker
+
+The platform is selected by the `--platform` command line flag passed to
+`runsc`. By default, the ptrace platform is selected. For example, to select the
+KVM platform, modify your Docker configuration (`/etc/docker/daemon.json`) to
+pass the `--platform` argument:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--platform=kvm"
+ ]
+ }
+ }
+}
+```
+
+You must restart the Docker daemon after making changes to this file. Typically
+this is done via `systemd`:
+
+```bash
+sudo systemctl restart docker
+```
+
+Note that you may configure multiple runtimes using different platforms. For
+example, the following configuration defines one runtime for ptrace and one for
+the KVM platform:
+
+```json
+{
+ "runtimes": {
+ "runsc-ptrace": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--platform=ptrace"
+ ]
+ },
+ "runsc-kvm": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--platform=kvm"
+ ]
+ }
+ }
+}
+```
+
+[nested-azure]: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/nested-virtualization
+[nested-gcp]: https://cloud.google.com/compute/docs/instances/enable-nested-virtualization-vm-instances
+[nested-virtualbox]: https://www.virtualbox.org/manual/UserManual.html#nested-virt
+[nested-kvm]: https://www.linux-kvm.org/page/Nested_Guests
diff --git a/g3doc/user_guide/quick_start/BUILD b/g3doc/user_guide/quick_start/BUILD
new file mode 100644
index 000000000..63f17f9cb
--- /dev/null
+++ b/g3doc/user_guide/quick_start/BUILD
@@ -0,0 +1,33 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "docker",
+ src = "docker.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/quick_start/docker/",
+ subcategory = "Quick Start",
+ weight = "11",
+)
+
+doc(
+ name = "oci",
+ src = "oci.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/quick_start/oci/",
+ subcategory = "Quick Start",
+ weight = "12",
+)
+
+doc(
+ name = "kubernetes",
+ src = "kubernetes.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/quick_start/kubernetes/",
+ subcategory = "Quick Start",
+ weight = "13",
+)
diff --git a/g3doc/user_guide/quick_start/docker.md b/g3doc/user_guide/quick_start/docker.md
new file mode 100644
index 000000000..6ad594ecc
--- /dev/null
+++ b/g3doc/user_guide/quick_start/docker.md
@@ -0,0 +1,96 @@
+# Docker Quick Start
+
+> Note: This guide requires Docker version 17.09.0 or greater. Refer to the
+> [Docker documentation][docker] for how to install it.
+
+This guide will help you quickly get started running Docker containers using
+gVisor.
+
+First, follow the [Installation guide][install].
+
+If you use the `apt` repository or the `automated` install, then you can skip
+the next section and proceed straight to running a container.
+
+## Configuring Docker
+
+First you will need to configure Docker to use `runsc` by adding a runtime entry
+to your Docker configuration (e.g. `/etc/docker/daemon.json`). The easiest way
+to do this is via the `runsc install` command. This will install a Docker
+runtime named "runsc" by default.
+
+```bash
+sudo runsc install
+```
+
+You may also wish to install a runtime entry for debugging. The `runsc install`
+command can accept options that will be passed to the runtime when it is invoked
+by Docker.
+
+```bash
+sudo runsc install --runtime runsc-debug -- \
+ --debug \
+ --debug-log=/tmp/runsc-debug.log \
+ --strace \
+ --log-packets
+```
+
+You must restart the Docker daemon after installing the runtime. Typically this
+is done via `systemd`:
+
+```bash
+sudo systemctl restart docker
+```
+
+## Running a container
+
+Now run your container using the `runsc` runtime:
+
+```bash
+docker run --runtime=runsc --rm hello-world
+```
+
+You can also run a terminal to explore the container.
+
+```bash
+docker run --runtime=runsc --rm -it ubuntu /bin/bash
+```
+
+Many Docker options are compatible with gVisor; try them out. Here is an
+example:
+
+```bash
+docker run --runtime=runsc --rm --link backend:database -v ~/bin:/tools:ro -p 8080:80 --cpus=0.5 -it busybox telnet towel.blinkenlights.nl
+```
+
+## Verify the runtime
+
+You can verify that you are running in gVisor using the `dmesg` command.
+
+```text
+$ docker run --runtime=runsc -it ubuntu dmesg
+[ 0.000000] Starting gVisor...
+[ 0.354495] Daemonizing children...
+[ 0.564053] Constructing home...
+[ 0.976710] Preparing for the zombie uprising...
+[ 1.299083] Creating process schedule...
+[ 1.479987] Committing treasure map to memory...
+[ 1.704109] Searching for socket adapter...
+[ 1.748935] Generating random numbers by fair dice roll...
+[ 2.059747] Digging up root...
+[ 2.259327] Checking naughty and nice process list...
+[ 2.610538] Rewriting operating system in Javascript...
+[ 2.613217] Ready!
+```
+
+Note that this is easily replicated by an attacker, so applications should
+never use `dmesg` to verify the runtime in a security-sensitive context.
+
+Next, look at the different options available for gVisor: [platform][platforms],
+[network][networking], [filesystem][filesystem].
+
+[docker]: https://docs.docker.com/install/
+[storage-driver]: https://docs.docker.com/engine/reference/commandline/dockerd/#daemon-storage-driver
+[install]: /docs/user_guide/install/
+[filesystem]: /docs/user_guide/filesystem/
+[networking]: /docs/user_guide/networking/
+[platforms]: /docs/user_guide/platforms/
diff --git a/g3doc/user_guide/quick_start/kubernetes.md b/g3doc/user_guide/quick_start/kubernetes.md
new file mode 100644
index 000000000..f875d8002
--- /dev/null
+++ b/g3doc/user_guide/quick_start/kubernetes.md
@@ -0,0 +1,36 @@
+# Kubernetes Quick Start
+
+gVisor can be used to run Kubernetes pods and has several integration points
+with Kubernetes.
+
+## Using Minikube
+
+gVisor can run sandboxed containers in a Kubernetes cluster with Minikube. After
+the gVisor addon is enabled, pods with `io.kubernetes.cri.untrusted-workload`
+set to true will execute with `runsc`. Follow [these instructions][minikube] to
+enable the gVisor addon.
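+
+In short, assuming a containerd-based Minikube cluster (see the linked
+instructions for the full, authoritative steps), enabling the addon looks like
+this:
+
+```bash
+minikube start --container-runtime=containerd
+minikube addons enable gvisor
+```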
+
+## Using Containerd
+
+You can also set up Kubernetes nodes to run pods in gVisor using the
+[containerd][containerd] CRI runtime and the `gvisor-containerd-shim`. You can
+use either the `io.kubernetes.cri.untrusted-workload` annotation or
+[RuntimeClass][runtimeclass] to run Pods with `runsc`. You can find instructions
+[here][gvisor-containerd-shim].
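+
+For example, assuming the cluster provides a `gvisor` RuntimeClass, a pod can
+request it like this (pod and image names are illustrative):
+
+```bash
+kubectl apply -f - <<EOF
+apiVersion: v1
+kind: Pod
+metadata:
+  name: nginx-gvisor
+spec:
+  runtimeClassName: gvisor
+  containers:
+  - name: nginx
+    image: nginx
+EOF
+```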
+
+## Using GKE Sandbox
+
+[GKE Sandbox][gke-sandbox] is available in [Google Kubernetes Engine][gke]. You
+just need to deploy a node pool with gVisor enabled in your cluster, and it will
+run pods annotated with `runtimeClassName: gvisor` inside a gVisor sandbox for
+you. [Here][wordpress-quick] is a quick example showing how to deploy a
+WordPress site. You can view the full documentation [here][gke-sandbox-docs].
+
+[containerd]: https://containerd.io/
+[minikube]: https://github.com/kubernetes/minikube/blob/master/deploy/addons/gvisor/README.md
+[gke]: https://cloud.google.com/kubernetes-engine/
+[gke-sandbox]: https://cloud.google.com/kubernetes-engine/sandbox/
+[gke-sandbox-docs]: https://cloud.google.com/kubernetes-engine/docs/how-to/sandbox-pods
+[gvisor-containerd-shim]: https://github.com/google/gvisor-containerd-shim
+[runtimeclass]: https://kubernetes.io/docs/concepts/containers/runtime-class/
+[wordpress-quick]: /docs/tutorials/kubernetes/
diff --git a/g3doc/user_guide/quick_start/oci.md b/g3doc/user_guide/quick_start/oci.md
new file mode 100644
index 000000000..e7768946b
--- /dev/null
+++ b/g3doc/user_guide/quick_start/oci.md
@@ -0,0 +1,43 @@
+# OCI Quick Start
+
+This guide will quickly get you started running your first gVisor sandbox
+container using the runtime directly with the default platform.
+
+First, follow the [Installation guide][install].
+
+## Run an OCI compatible container
+
+Now we will create an [OCI][oci] container bundle to run our container. First we
+will create a root directory for our bundle.
+
+```bash
+mkdir bundle
+cd bundle
+```
+
+Create a root file system for the container. We will use the Docker
+`hello-world` image as the basis for our container.
+
+```bash
+mkdir rootfs
+docker export $(docker create hello-world) | tar -xf - -C rootfs
+```
+
+Next, create a specification file called `config.json` that contains our
+container specification. We tell the container to run the `/hello` program.
+
+```bash
+runsc spec -- /hello
+```
+
+Finally, run the container.
+
+```bash
+sudo runsc run hello
+```
+
+Next try [using CNI to set up networking](../../../tutorials/cni/) or
+[running gVisor using Docker](../docker/).
+
+[oci]: https://opencontainers.org/
+[install]: /docs/user_guide/install
diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD
new file mode 100644
index 000000000..caae98623
--- /dev/null
+++ b/g3doc/user_guide/tutorials/BUILD
@@ -0,0 +1,37 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "docker",
+ src = "docker.md",
+ category = "User Guide",
+ permalink = "/docs/tutorials/docker/",
+ subcategory = "Tutorials",
+ weight = "21",
+)
+
+doc(
+ name = "cni",
+ src = "cni.md",
+ category = "User Guide",
+ permalink = "/docs/tutorials/cni/",
+ subcategory = "Tutorials",
+ weight = "22",
+)
+
+doc(
+ name = "kubernetes",
+ src = "kubernetes.md",
+ category = "User Guide",
+ data = [
+ "add-node-pool.png",
+ "node-pool-button.png",
+ ],
+ permalink = "/docs/tutorials/kubernetes/",
+ subcategory = "Tutorials",
+ weight = "33",
+)
diff --git a/g3doc/user_guide/tutorials/add-node-pool.png b/g3doc/user_guide/tutorials/add-node-pool.png
new file mode 100644
index 000000000..e4560359b
--- /dev/null
+++ b/g3doc/user_guide/tutorials/add-node-pool.png
Binary files differ
diff --git a/g3doc/user_guide/tutorials/cni.md b/g3doc/user_guide/tutorials/cni.md
new file mode 100644
index 000000000..ce2fd09a8
--- /dev/null
+++ b/g3doc/user_guide/tutorials/cni.md
@@ -0,0 +1,174 @@
+# Using CNI
+
+This tutorial will show you how to set up networking for a gVisor sandbox using
+the
+[Container Networking Interface (CNI)](https://github.com/containernetworking/cni).
+
+## Install CNI Plugins
+
+First you will need to install the CNI plugins. CNI plugins are used to set up a
+network namespace that `runsc` can use with the sandbox.
+
+Start by creating the directories for CNI plugin binaries:
+
+```
+sudo mkdir -p /opt/cni/bin
+```
+
+Download the CNI plugins:
+
+```
+wget https://github.com/containernetworking/plugins/releases/download/v0.8.3/cni-plugins-linux-amd64-v0.8.3.tgz
+```
+
+Next, unpack the plugins into the CNI binary directory:
+
+```
+sudo tar -xvf cni-plugins-linux-amd64-v0.8.3.tgz -C /opt/cni/bin/
+```
+
+## Configure CNI Plugins
+
+This section will show you how to configure CNI plugins. This tutorial will use
+the "bridge" and "loopback" plugins which will create the necessary bridge and
+loopback devices in our network namespace. However, you should be able to use
+any CNI compatible plugin to set up networking for gVisor sandboxes.
+
+The bridge plugin configuration specifies the IP address subnet range for IP
+addresses that will be assigned to sandboxes as well as the network routing
+configuration. This tutorial will assign IP addresses from the `10.22.0.0/16`
+range and allow all outbound traffic, however you can modify this configuration
+to suit your use case.
+
+Create the bridge and loopback plugin configurations:
+
+```
+sudo mkdir -p /etc/cni/net.d
+
+sudo sh -c 'cat > /etc/cni/net.d/10-bridge.conf << EOF
+{
+ "cniVersion": "0.4.0",
+ "name": "mynet",
+ "type": "bridge",
+ "bridge": "cni0",
+ "isGateway": true,
+ "ipMasq": true,
+ "ipam": {
+ "type": "host-local",
+ "subnet": "10.22.0.0/16",
+ "routes": [
+ { "dst": "0.0.0.0/0" }
+ ]
+ }
+}
+EOF'
+
+sudo sh -c 'cat > /etc/cni/net.d/99-loopback.conf << EOF
+{
+ "cniVersion": "0.4.0",
+ "name": "lo",
+ "type": "loopback"
+}
+EOF'
+```
+
+## Create a Network Namespace
+
+For each gVisor sandbox you will create a network namespace and configure it
+using CNI. First, create a random network namespace name and then create the
+namespace.
+
+The network namespace path will then be `/var/run/netns/${CNI_CONTAINERID}`.
+
+```
+export CNI_PATH=/opt/cni/bin
+export CNI_CONTAINERID=$(printf '%x%x%x%x' $RANDOM $RANDOM $RANDOM $RANDOM)
+export CNI_COMMAND=ADD
+export CNI_NETNS=/var/run/netns/${CNI_CONTAINERID}
+
+sudo ip netns add ${CNI_CONTAINERID}
+```
+
+Next, run the bridge and loopback plugins to apply the configuration that was
+created earlier to the namespace. Each plugin outputs some JSON indicating the
+results of executing the plugin. For example, the bridge plugin's response
+includes the IP address assigned to the ethernet device created in the network
+namespace. Take note of the IP address for use later.
+
+```
+export CNI_IFNAME="eth0"
+sudo -E /opt/cni/bin/bridge < /etc/cni/net.d/10-bridge.conf
+export CNI_IFNAME="lo"
+sudo -E /opt/cni/bin/loopback < /etc/cni/net.d/99-loopback.conf
+```
+
+Get the IP address assigned to our sandbox:
+
+```
+POD_IP=$(sudo ip netns exec ${CNI_CONTAINERID} ip -4 addr show eth0 | grep -oP '(?<=inet\s)\d+(\.\d+){3}')
+```
+
+## Create the OCI Bundle
+
+Now that our network namespace is created and configured, we can create the OCI
+bundle for our container. As part of the bundle's `config.json` we will specify
+that the container use the network namespace that we created.
+
+The container will run a simple python webserver that we will be able to
+connect to using the IP address assigned to it by the bridge CNI plugin.
+
+Create the bundle and root filesystem directories:
+
+```
+sudo mkdir -p bundle
+cd bundle
+sudo mkdir rootfs
+sudo docker export $(docker create python) | sudo tar --same-owner -pxf - -C rootfs
+sudo mkdir -p rootfs/var/www/html
+sudo sh -c 'echo "Hello World!" > rootfs/var/www/html/index.html'
+```
+
+Next create the `config.json` specifying the network namespace.
+
+```
+sudo /usr/local/bin/runsc spec \
+ --cwd /var/www/html \
+ --netns /var/run/netns/${CNI_CONTAINERID} \
+ -- python -m http.server
+```
+
+## Run the Container
+
+Now we can run and connect to the webserver. Run the container in gVisor. Use
+the same ID used for the network namespace to be consistent:
+
+```
+sudo runsc run -detach ${CNI_CONTAINERID}
+```
+
+Connect to the server via the sandbox's IP address:
+
+```
+curl http://${POD_IP}:8000/
+```
+
+You should see the server returning `Hello World!`.
+
+## Cleanup
+
+After you are finished running the container, you can clean up the network
+namespace.
+
+```
+sudo runsc kill ${CNI_CONTAINERID}
+sudo runsc delete ${CNI_CONTAINERID}
+
+export CNI_COMMAND=DEL
+
+export CNI_IFNAME="lo"
+sudo -E /opt/cni/bin/loopback < /etc/cni/net.d/99-loopback.conf
+export CNI_IFNAME="eth0"
+sudo -E /opt/cni/bin/bridge < /etc/cni/net.d/10-bridge.conf
+
+sudo ip netns delete ${CNI_CONTAINERID}
+```
diff --git a/g3doc/user_guide/tutorials/docker.md b/g3doc/user_guide/tutorials/docker.md
new file mode 100644
index 000000000..705560038
--- /dev/null
+++ b/g3doc/user_guide/tutorials/docker.md
@@ -0,0 +1,68 @@
+# WordPress with Docker
+
+This page shows you how to deploy a sample [WordPress][wordpress] site using
+[Docker][docker].
+
+### Before you begin
+
+[Follow these instructions][docker-install] to install runsc with Docker. This
+document assumes that the runtime name chosen is `runsc`.
+
+### Running WordPress
+
+Now, let's deploy a WordPress site using Docker. A WordPress site requires two
+containers: a web server in the frontend and a MySQL database in the backend.
+
+First, let's define a few environment variables that are shared between both
+containers:
+
+```bash
+export MYSQL_PASSWORD=${YOUR_SECRET_PASSWORD_HERE?}
+export MYSQL_DB=wordpress
+export MYSQL_USER=wordpress
+```
+
+Next, let's start the database container running MySQL and wait until the
+database is initialized:
+
+```bash
+docker run --runtime=runsc --name mysql -d \
+ -e MYSQL_RANDOM_ROOT_PASSWORD=1 \
+ -e MYSQL_PASSWORD="${MYSQL_PASSWORD}" \
+ -e MYSQL_DATABASE="${MYSQL_DB}" \
+ -e MYSQL_USER="${MYSQL_USER}" \
+ mysql:5.7
+
+# Wait until this message appears in the log.
+docker logs mysql |& grep 'port: 3306 MySQL Community Server (GPL)'
+```
+
+Once the database is running, you can start the WordPress frontend. We use the
+`--link` option to connect the frontend to the database, and expose WordPress
+on port 8080 of the localhost.
+
+```bash
+docker run --runtime=runsc --name wordpress -d \
+ --link mysql:mysql \
+ -p 8080:80 \
+ -e WORDPRESS_DB_HOST=mysql \
+ -e WORDPRESS_DB_USER="${MYSQL_USER}" \
+ -e WORDPRESS_DB_PASSWORD="${MYSQL_PASSWORD}" \
+ -e WORDPRESS_DB_NAME="${MYSQL_DB}" \
+ -e WORDPRESS_TABLE_PREFIX=wp_ \
+ wordpress
+```
+
+Now, you can access the WordPress website by pointing your favorite browser to
+<http://localhost:8080>.
+
+Congratulations! You have just deployed a WordPress site using Docker.
+
+### What's next
+
+[Learn how to deploy WordPress with Kubernetes][wordpress-k8s].
+
+[docker]: https://www.docker.com/
+[docker-install]: /docs/user_guide/quick_start/docker/
+[wordpress]: https://wordpress.com/
+[wordpress-k8s]: /docs/tutorials/kubernetes/
diff --git a/g3doc/user_guide/tutorials/kubernetes.md b/g3doc/user_guide/tutorials/kubernetes.md
new file mode 100644
index 000000000..d2a94b1b7
--- /dev/null
+++ b/g3doc/user_guide/tutorials/kubernetes.md
@@ -0,0 +1,134 @@
+# WordPress with Kubernetes
+
+This page shows you how to deploy a sample [WordPress][wordpress] site using
+[GKE Sandbox][gke-sandbox].
+
+### Before you begin
+
+Take the following steps to enable the Kubernetes Engine API:
+
+1. Visit the [Kubernetes Engine page][project-selector] in the Google Cloud
+ Platform Console.
+1. Create or select a project.
+
+### Creating a node pool with gVisor enabled
+
+Create a node pool inside your cluster with the option `--sandbox type=gvisor`
+added to the command, as shown below:
+
+```bash
+gcloud beta container node-pools create sandbox-pool --cluster=${CLUSTER_NAME} --image-type=cos_containerd --sandbox type=gvisor
+```
+
+If you prefer to use the console, select your cluster and select the **ADD NODE
+POOL** button:
+
+![+ ADD NODE POOL](./node-pool-button.png)
+
+Then set the **Image type** to **Containerd** and select the **Enable sandbox
+with gVisor** option. Select other options as you like:
+
+![+ NODE POOL](./add-node-pool.png)
+
+### Check that gVisor is enabled
+
+The gvisor RuntimeClass is instantiated during node creation. You can check for
+the existence of the gvisor RuntimeClass using the following command:
+
+```bash
+kubectl get runtimeclasses
+```
+
+### WordPress deployment
+
+Now, let's deploy a WordPress site using GKE Sandbox. A WordPress site requires
+two pods: a web server in the frontend and a MySQL database in the backend.
+Both applications use PersistentVolumes to store the site data. In addition,
+they use a secret store to share the MySQL password between them.
+
+First, let's download the deployment configuration files to add the runtime
+class annotation to them:
+
+```bash
+curl -LO https://k8s.io/examples/application/wordpress/wordpress-deployment.yaml
+curl -LO https://k8s.io/examples/application/wordpress/mysql-deployment.yaml
+```
+
+Add a **spec.template.spec.runtimeClassName** set to **gvisor** to both files,
+as shown below:
+
+**wordpress-deployment.yaml:**
+
+```yaml
+apiVersion: v1
+kind: Service
+metadata:
+  name: wordpress
+  labels:
+    app: wordpress
+spec:
+  ports:
+    - port: 80
+  selector:
+    app: wordpress
+    tier: frontend
+  type: LoadBalancer
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+  name: wp-pv-claim
+  labels:
+    app: wordpress
+spec:
+  accessModes:
+    - ReadWriteOnce
+  resources:
+    requests:
+      storage: 20Gi
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: wordpress
+  labels:
+    app: wordpress
+spec:
+  selector:
+    matchLabels:
+      app: wordpress
+      tier: frontend
+  strategy:
+    type: Recreate
+  template:
+    metadata:
+      labels:
+        app: wordpress
+        tier: frontend
+    spec:
+      runtimeClassName: gvisor # ADD THIS LINE
+      containers:
+        - image: wordpress:4.8-apache
+          name: wordpress
+          env:
+            - name: WORDPRESS_DB_HOST
+              value: wordpress-mysql
+            - name: WORDPRESS_DB_PASSWORD
+              valueFrom:
+                secretKeyRef:
+                  name: mysql-pass
+                  key: password
+          ports:
+            - containerPort: 80
+              name: wordpress
+          volumeMounts:
+            - name: wordpress-persistent-storage
+              mountPath: /var/www/html
+      volumes:
+        - name: wordpress-persistent-storage
+          persistentVolumeClaim:
+            claimName: wp-pv-claim
+```
+
+**mysql-deployment.yaml:**
+
+```yaml
+apiVersion: v1
+kind: Service
+metadata:
+  name: wordpress-mysql
+  labels:
+    app: wordpress
+spec:
+  ports:
+    - port: 3306
+  selector:
+    app: wordpress
+    tier: mysql
+  clusterIP: None
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+  name: mysql-pv-claim
+  labels:
+    app: wordpress
+spec:
+  accessModes:
+    - ReadWriteOnce
+  resources:
+    requests:
+      storage: 20Gi
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: wordpress-mysql
+  labels:
+    app: wordpress
+spec:
+  selector:
+    matchLabels:
+      app: wordpress
+      tier: mysql
+  strategy:
+    type: Recreate
+  template:
+    metadata:
+      labels:
+        app: wordpress
+        tier: mysql
+    spec:
+      runtimeClassName: gvisor # ADD THIS LINE
+      containers:
+        - image: mysql:5.6
+          name: mysql
+          env:
+            - name: MYSQL_ROOT_PASSWORD
+              valueFrom:
+                secretKeyRef:
+                  name: mysql-pass
+                  key: password
+          ports:
+            - containerPort: 3306
+              name: mysql
+          volumeMounts:
+            - name: mysql-persistent-storage
+              mountPath: /var/lib/mysql
+      volumes:
+        - name: mysql-persistent-storage
+          persistentVolumeClaim:
+            claimName: mysql-pv-claim
+```
+
+Note that apart from `runtimeClassName: gvisor`, nothing else about the
+Deployment has changed.
+
+You are now ready to deploy the entire application. Just create a secret to
+store MySQL's password and *apply* both deployments:
+
+```bash
+kubectl create secret generic mysql-pass --from-literal=password=${YOUR_SECRET_PASSWORD_HERE?}
+kubectl apply -f mysql-deployment.yaml
+kubectl apply -f wordpress-deployment.yaml
+```
+
+Wait for the deployments to be ready and an external IP to be assigned to the
+WordPress service:
+
+```bash
+watch kubectl get service wordpress
+```
+
+Now, copy the service `EXTERNAL-IP` from above to your favorite browser to view
+and configure your new WordPress site.
+
+Congratulations! You have just deployed a WordPress site using GKE Sandbox.
+
+### What's next
+
+To learn more about GKE Sandbox and how to run your deployment securely, take a
+look at the [documentation][gke-sandbox-docs].
+
+[gke-sandbox-docs]: https://cloud.google.com/kubernetes-engine/docs/how-to/sandbox-pods
+[gke-sandbox]: https://cloud.google.com/kubernetes-engine/sandbox/
+[project-selector]: https://console.cloud.google.com/projectselector/kubernetes
+[wordpress]: https://wordpress.com/
diff --git a/g3doc/user_guide/tutorials/node-pool-button.png b/g3doc/user_guide/tutorials/node-pool-button.png
new file mode 100644
index 000000000..bee0c11dc
--- /dev/null
+++ b/g3doc/user_guide/tutorials/node-pool-button.png
Binary files differ
diff --git a/go.mod b/go.mod
new file mode 100644
index 000000000..434fa713f
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,20 @@
+module gvisor.dev/gvisor
+
+go 1.14
+
+require (
+ github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422
+ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079
+ github.com/golang/protobuf v1.3.1
+ github.com/google/btree v1.0.0
+ github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
+ github.com/kr/pretty v0.2.0 // indirect
+ github.com/kr/pty v1.1.1
+ github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
+ github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2
+ github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
+ github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // indirect
+ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
+ golang.org/x/time v0.0.0-20191024005414-555d28b269f0
+ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 000000000..c44a17c71
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,32 @@
+github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=
+github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
+github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=
+github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
+github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
+github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
+github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
+github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=
+github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
+github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
+github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=
+github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=
+github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
+github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=
+github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
+github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=
+github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/images/BUILD b/images/BUILD
new file mode 100644
index 000000000..a50f388e9
--- /dev/null
+++ b/images/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"])
+
+# The images filegroup is definitely not a hermetic target, and requires Make
+# to do anything meaningful with it. However, it will be slurped up and used by
+# the tools/installers/images.sh installer, which will ensure that all required
+# images are available locally when running vm_tests.
+filegroup(
+ name = "images",
+ srcs = glob(["**"]),
+ visibility = ["//tools/installers:__pkg__"],
+)
diff --git a/images/Makefile b/images/Makefile
new file mode 100644
index 000000000..1485607bd
--- /dev/null
+++ b/images/Makefile
@@ -0,0 +1,93 @@
+#!/usr/bin/make -f
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ARCH is the architecture used for the build. This may be overridden at the
+# command line in order to perform a cross-build (in a limited capacity).
+ARCH := $(shell uname -m)
+
+# Note that the image prefixes used here must match the image mangling in
+# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all
+# tests are using locally-defined images (that are consistent and idempotent).
+REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit
+LOCAL_IMAGE_PREFIX ?= gvisor.dev/images
+ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -exec dirname {} \;)))
+ifneq ($(ARCH),$(shell uname -m))
+DOCKER_PLATFORM_ARGS := --platform=$(ARCH)
+else
+DOCKER_PLATFORM_ARGS :=
+endif
+
+list-all-images:
+ @for image in $(ALL_IMAGES); do echo $${image}; done
+.PHONY: list-all-images
+
+%-all-images:
+ @$(MAKE) $(patsubst %,$*-%,$(ALL_IMAGES))
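+
+# For example, if ALL_IMAGES contained basic_alpine and basic_busybox, then
+# "make load-all-images" would expand to a single sub-make invocation of
+# "load-basic_alpine load-basic_busybox" (an illustrative subset; the real
+# list comes from the find above).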
+
+# tag is a function that returns the tag name, given an image.
+#
+# The tag constructed is used to memoize the image generated (see README.md).
+# This scheme enables aggressive caching in a central repository, while
+# ensuring that images will always be sourced from the local files if there
+# are changes.
+path = $(subst _,/,$(1))
+tag = $(shell find $(call path,$(1)) -type f -print | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16)
+remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH):$(call tag,$(1))
+local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
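+
+# As an illustration, for the directory images/basic/alpine on an x86_64 host,
+# remote_image expands to roughly
+# "gcr.io/gvisor-presubmit/basic/alpine_x86_64:<16-hex-char content hash>".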
+
+# rebuild builds the image locally. Only the "remote" tag will be applied. Note
+# we need to explicitly repull the base layer in order to ensure that the
+# architecture is correct. Note that we use the term "rebuild" here to avoid
+# conflicting with the bazel "build" terminology, which is used elsewhere.
+rebuild-%: register-cross
+ FROM=$(shell grep FROM $(call path,$*)/Dockerfile | cut -d' ' -f2-) && \
+ docker pull $(DOCKER_PLATFORM_ARGS) $$FROM
+ T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \
+ docker build $(DOCKER_PLATFORM_ARGS) -t $(call remote_image,$*) $$T && \
+ rm -rf $$T
+
+# pull will check the "remote" image and pull if necessary. If the remote image
+# must be pulled, then it will tag with the latest local target. Note that pull
+# may fail if the remote image is not available.
+pull-%:
+ docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*)
+
+# load will either pull the "remote" or build it locally. This is the preferred
+# entrypoint, as it should never fail. The local tag should always be set after
+# this returns (either by the pull or the build).
+load-%:
+ docker inspect $(call remote_image,$*) >/dev/null 2>&1 || $(MAKE) pull-$* || $(MAKE) rebuild-$*
+ docker tag $(call remote_image,$*) $(call local_image,$*)
+
+# push pushes the remote image, after either pulling (to validate that the tag
+# already exists) or building manually.
+push-%: load-%
+ docker push $(call remote_image,$*)
+
+# register-cross registers the necessary qemu binaries for cross-compilation.
+# This may be used by any target that may execute containers that are not the
+# native format.
+register-cross:
+ifneq ($(ARCH),$(shell uname -m))
+ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*))
+ docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes
+else
+ @true # Already registered.
+endif
+else
+ @true # No cross required.
+endif
+.PHONY: register-cross
diff --git a/images/README.md b/images/README.md
new file mode 100644
index 000000000..d2efb5db4
--- /dev/null
+++ b/images/README.md
@@ -0,0 +1,61 @@
+# Container Images
+
+This directory contains all images used by tests.
+
+Note that all these images must be pushed to the testing project hosted on
+[Google Container Registry][gcr]. This happens automatically as part of
+continuous integration, and it speeds up loading because images will not need
+to be built from scratch for each test run.
+
+Image tooling is accessible via `make`, specifically via `images/Makefile`.
+
+## Why make?
+
+Make is used because it can bootstrap the `default` image, which contains
+`bazel` and all other parts of the toolchain.
+
+## Listing images
+
+To list all images, use `make list-all-images` from the top-level directory.
+
+## Loading and referencing images
+
+To load a specific image, use `make load-<image>` from the top-level directory.
+This will ensure that an image `gvisor.dev/images/<image>:latest` is available.
+
+Images should always be referred to via the `gvisor.dev/images` canonical path.
+This tag exists only locally, but serves to decouple tests from the underlying
+image infrastructure.
+
+The continuous integration system can either take fine-grained dependencies on
+single images via individual `load` targets, or pull all images via a single
+`load-all-images` invocation.
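+
+For example, a test needing the alpine image could be set up like this (the
+image name is just an illustration of the `<image>` convention):
+
+```bash
+make load-basic_alpine
+docker run --rm gvisor.dev/images/basic/alpine echo hello
+```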
+
+## Adding new images
+
+To add a new image, create a new directory under `images` containing a
+Dockerfile and any other files that the image requires. You may choose to add to
+an existing subdirectory if applicable, or create a new one.
+
+All images will be tagged and memoized using a hash of the directory contents.
+As a result, every image should be made completely reproducible if possible.
+This means using fixed tags and fixed versions whenever feasible.
+
+Note that images should also be made architecture-independent if possible. The
+build scripts will handle loading the appropriate architecture onto the
+machine and tagging it with the single canonical tag.
+
+Add a `load-<image>` dependency in the Makefile if the image is required for a
+particular set of tests. This target will pull the tag from the image repository
+if available.
+
+## Building and pushing images
+
+All images can be built manually by running `rebuild-<image>` and pushed using
+`push-<image>`. You can also use `rebuild-all-images` and `push-all-images`.
+Note that pushing will require appropriate permissions in the project.
+
+The continuous integration system can either take fine-grained dependencies on
+individual `push` targets, or ensure all images are up-to-date with a single
+`push-all-images` invocation.
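+
+As a sketch of that flow, rebuilding and pushing a single image might look
+like this (assuming the `<image>` naming convention above):
+
+```bash
+make rebuild-basic_alpine
+make push-basic_alpine
+```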
diff --git a/images/basic/alpine/Dockerfile b/images/basic/alpine/Dockerfile
new file mode 100644
index 000000000..12b26040a
--- /dev/null
+++ b/images/basic/alpine/Dockerfile
@@ -0,0 +1 @@
+FROM alpine:3.11.5
diff --git a/images/basic/busybox/Dockerfile b/images/basic/busybox/Dockerfile
new file mode 100644
index 000000000..79b3f683a
--- /dev/null
+++ b/images/basic/busybox/Dockerfile
@@ -0,0 +1 @@
+FROM busybox:1.31.1
diff --git a/images/basic/httpd/Dockerfile b/images/basic/httpd/Dockerfile
new file mode 100644
index 000000000..83bc0ed88
--- /dev/null
+++ b/images/basic/httpd/Dockerfile
@@ -0,0 +1 @@
+FROM httpd:2.4.43
diff --git a/images/basic/mysql/Dockerfile b/images/basic/mysql/Dockerfile
new file mode 100644
index 000000000..95da9c48d
--- /dev/null
+++ b/images/basic/mysql/Dockerfile
@@ -0,0 +1 @@
+FROM mysql:8.0.19
diff --git a/images/basic/nginx/Dockerfile b/images/basic/nginx/Dockerfile
new file mode 100644
index 000000000..af2e62526
--- /dev/null
+++ b/images/basic/nginx/Dockerfile
@@ -0,0 +1 @@
+FROM nginx:1.17.9
diff --git a/images/basic/python/Dockerfile b/images/basic/python/Dockerfile
new file mode 100644
index 000000000..acf07cca9
--- /dev/null
+++ b/images/basic/python/Dockerfile
@@ -0,0 +1,2 @@
+FROM python:3
+ENTRYPOINT ["python", "-m", "http.server", "8080"]
diff --git a/images/basic/resolv/Dockerfile b/images/basic/resolv/Dockerfile
new file mode 100644
index 000000000..13665bdaf
--- /dev/null
+++ b/images/basic/resolv/Dockerfile
@@ -0,0 +1 @@
+FROM k8s.gcr.io/busybox:latest
diff --git a/images/basic/ruby/Dockerfile b/images/basic/ruby/Dockerfile
new file mode 100644
index 000000000..d290418fb
--- /dev/null
+++ b/images/basic/ruby/Dockerfile
@@ -0,0 +1 @@
+FROM ruby:2.7.1
diff --git a/images/basic/tomcat/Dockerfile b/images/basic/tomcat/Dockerfile
new file mode 100644
index 000000000..c7db39a36
--- /dev/null
+++ b/images/basic/tomcat/Dockerfile
@@ -0,0 +1 @@
+FROM tomcat:8.0
diff --git a/images/basic/ubuntu/Dockerfile b/images/basic/ubuntu/Dockerfile
new file mode 100644
index 000000000..331b71343
--- /dev/null
+++ b/images/basic/ubuntu/Dockerfile
@@ -0,0 +1 @@
+FROM ubuntu:trusty
diff --git a/images/default/Dockerfile b/images/default/Dockerfile
new file mode 100644
index 000000000..397082b02
--- /dev/null
+++ b/images/default/Dockerfile
@@ -0,0 +1,16 @@
+FROM fedora:31
+# Install bazel.
+RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
+RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch
+RUN pip install pycparser
+RUN dnf install -y bazel3
+# Install gcloud.
+RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \
+ tar zxvf - google-cloud-sdk && \
+ google-cloud-sdk/install.sh && \
+ ln -s /google-cloud-sdk/bin/gcloud /usr/bin/gcloud
+# Install Docker client for the website build.
+RUN dnf config-manager --add-repo https://download.docker.com/linux/fedora/docker-ce.repo
+RUN dnf install -y docker-ce-cli
+WORKDIR /workspace
+ENTRYPOINT ["/usr/bin/bazel"]
diff --git a/images/hostoverlaytest/Dockerfile b/images/hostoverlaytest/Dockerfile
new file mode 100644
index 000000000..d83439e9c
--- /dev/null
+++ b/images/hostoverlaytest/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY . .
+
+RUN apt-get update && apt-get install -y gcc
+RUN gcc -O2 -o test test.c
diff --git a/images/hostoverlaytest/test.c b/images/hostoverlaytest/test.c
new file mode 100644
index 000000000..088f90746
--- /dev/null
+++ b/images/hostoverlaytest/test.c
@@ -0,0 +1,88 @@
+#include <err.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+int main(int argc, char** argv) {
+ const char kTestFilePath[] = "testfile.txt";
+ const char kOldFileData[] = "old data\n";
+ const char kNewFileData[] = "new data\n";
+ const size_t kPageSize = sysconf(_SC_PAGE_SIZE);
+
+ // Open a file that already exists in a host overlayfs lower layer.
+ const int fd_rdonly = open(kTestFilePath, O_RDONLY);
+ if (fd_rdonly < 0) {
+ err(1, "open(%s, O_RDONLY)", kTestFilePath);
+ }
+
+ // Check that the file's initial contents are what we expect when read via
+ // syscall.
+ char oldbuf[sizeof(kOldFileData)] = {};
+ ssize_t n = pread(fd_rdonly, oldbuf, sizeof(oldbuf), 0);
+ if (n < 0) {
+ err(1, "initial pread");
+ }
+ if (n != strlen(kOldFileData)) {
+ errx(1, "short initial pread (%ld/%lu bytes)", n, strlen(kOldFileData));
+ }
+ if (strcmp(oldbuf, kOldFileData) != 0) {
+ errx(1, "initial pread returned wrong data: %s", oldbuf);
+ }
+
+ // Check that the file's initial contents are what we expect when read via
+ // memory mapping.
+ void* page = mmap(NULL, kPageSize, PROT_READ, MAP_SHARED, fd_rdonly, 0);
+ if (page == MAP_FAILED) {
+ err(1, "mmap");
+ }
+ if (strcmp(page, kOldFileData) != 0) {
+ errx(1, "mapping contains wrong initial data: %s", (const char*)page);
+ }
+
+ // Open the same file writably, causing host overlayfs to copy it up, and
+ // replace its contents.
+ const int fd_rdwr = open(kTestFilePath, O_RDWR);
+ if (fd_rdwr < 0) {
+ err(1, "open(%s, O_RDWR)", kTestFilePath);
+ }
+ n = write(fd_rdwr, kNewFileData, strlen(kNewFileData));
+ if (n < 0) {
+ err(1, "write");
+ }
+ if (n != strlen(kNewFileData)) {
+ errx(1, "short write (%ld/%lu bytes)", n, strlen(kNewFileData));
+ }
+ if (ftruncate(fd_rdwr, strlen(kNewFileData)) < 0) {
+ err(1, "truncate");
+ }
+
+ int failed = 0;
+
+ // Check that syscalls on the old FD return updated contents. (Before Linux
+ // 4.18, this requires that runsc use a post-copy-up FD to service the read.)
+ char newbuf[sizeof(kNewFileData)] = {};
+ n = pread(fd_rdonly, newbuf, sizeof(newbuf), 0);
+ if (n < 0) {
+ err(1, "final pread");
+ }
+ if (n != strlen(kNewFileData)) {
+ warnx("short final pread (%ld/%lu bytes)", n, strlen(kNewFileData));
+ failed = 1;
+ } else if (strcmp(newbuf, kNewFileData) != 0) {
+ warnx("final pread returned wrong data: %s", newbuf);
+ failed = 1;
+ }
+
+ // Check that the memory mapping of the old FD has been updated. (Linux
+ // overlayfs does not do this, so regardless of kernel version this requires
+ // that runsc replace existing memory mappings with mappings of a
+ // post-copy-up FD.)
+ if (strcmp(page, kNewFileData) != 0) {
+ warnx("mapping contains wrong final data: %s", (const char*)page);
+ failed = 1;
+ }
+
+ return failed;
+}
diff --git a/images/hostoverlaytest/testfile.txt b/images/hostoverlaytest/testfile.txt
new file mode 100644
index 000000000..e4188c841
--- /dev/null
+++ b/images/hostoverlaytest/testfile.txt
@@ -0,0 +1 @@
+old data
diff --git a/images/iptables/Dockerfile b/images/iptables/Dockerfile
new file mode 100644
index 000000000..efd91cb80
--- /dev/null
+++ b/images/iptables/Dockerfile
@@ -0,0 +1,2 @@
+FROM ubuntu
+RUN apt update && apt install -y iptables
diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile
new file mode 100644
index 000000000..4860dd750
--- /dev/null
+++ b/images/jekyll/Dockerfile
@@ -0,0 +1,13 @@
+FROM jekyll/jekyll:4.0.0
+USER root
+RUN gem install \
+ html-proofer:3.10.2 \
+ nokogiri:1.10.1 \
+ jekyll-autoprefixer:1.0.2 \
+ jekyll-inline-svg:1.1.4 \
+ jekyll-paginate:1.1.0 \
+ kramdown-parser-gfm:1.1.0 \
+ jekyll-relative-links:0.6.1 \
+ jekyll-feed:0.13.0 \
+ jekyll-sitemap:1.4.0
+CMD ["/usr/gem/gems/jekyll-4.0.0/exe/jekyll", "build", "-t", "-s", "/input", "-d", "/output"]
diff --git a/images/packetdrill/Dockerfile b/images/packetdrill/Dockerfile
new file mode 100644
index 000000000..01296dbaf
--- /dev/null
+++ b/images/packetdrill/Dockerfile
@@ -0,0 +1,8 @@
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \
+ netcat tcpdump jq tar bison flex make
+RUN hash -r
+RUN git clone --depth 1 --branch packetdrill-v2.0 \
+ https://github.com/google/packetdrill.git
+RUN cd packetdrill/gtests/net/packetdrill && ./configure && make
+CMD /bin/bash
diff --git a/images/packetimpact/Dockerfile b/images/packetimpact/Dockerfile
new file mode 100644
index 000000000..87aa99ef2
--- /dev/null
+++ b/images/packetimpact/Dockerfile
@@ -0,0 +1,16 @@
+FROM ubuntu:bionic
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ # iptables to disable OS native packet processing.
+ iptables \
+ # nc to check that the posix_server is running.
+ netcat \
+ # tcpdump to log brief packet sniffing.
+ tcpdump \
+ # ip link show to display MAC addresses.
+ iproute2 \
+ # tshark to log verbose packet sniffing.
+ tshark \
+ # killall for cleanup.
+ psmisc
+RUN hash -r
+CMD /bin/bash
diff --git a/images/runtimes/go1.12/Dockerfile b/images/runtimes/go1.12/Dockerfile
new file mode 100644
index 000000000..cb2944062
--- /dev/null
+++ b/images/runtimes/go1.12/Dockerfile
@@ -0,0 +1,4 @@
+# Go is easy, since we already have everything we need to compile the proctor
+# binary and run the tests in the golang Docker image.
+FROM golang:1.12
+RUN ["go", "tool", "dist", "test", "-compile-only"]
diff --git a/images/runtimes/java11/Dockerfile b/images/runtimes/java11/Dockerfile
new file mode 100644
index 000000000..03bc8aaf1
--- /dev/null
+++ b/images/runtimes/java11/Dockerfile
@@ -0,0 +1,22 @@
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ autoconf \
+ build-essential \
+ curl \
+ make \
+ openjdk-11-jdk \
+ unzip \
+ zip
+
+# Download the JDK test library.
+WORKDIR /root
+RUN set -ex \
+ && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \
+ && tar -xzf /tmp/jdktests.tar.gz \
+ && mv jdk11-76072a077ee1/test test \
+ && rm -f /tmp/jdktests.tar.gz
+
+# Install jtreg and add to PATH.
+RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
+RUN tar -xzf jtreg.tar.gz
+ENV PATH="/root/jtreg/bin:$PATH"
diff --git a/images/runtimes/nodejs12.4.0/Dockerfile b/images/runtimes/nodejs12.4.0/Dockerfile
new file mode 100644
index 000000000..d17924b62
--- /dev/null
+++ b/images/runtimes/nodejs12.4.0/Dockerfile
@@ -0,0 +1,21 @@
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ curl \
+ dumb-init \
+ g++ \
+ make \
+ python
+
+WORKDIR /root
+ARG VERSION=v12.4.0
+RUN curl -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz
+RUN tar -zxf node-${VERSION}.tar.gz
+
+WORKDIR /root/node-${VERSION}
+RUN ./configure
+RUN make
+RUN make test-build
+
+# Including dumb-init emulates the Linux "init" process, preventing the failure
+# of tests involving worker processes.
+ENTRYPOINT ["/usr/bin/dumb-init"]
diff --git a/images/runtimes/php7.3.6/Dockerfile b/images/runtimes/php7.3.6/Dockerfile
new file mode 100644
index 000000000..e5f67f79c
--- /dev/null
+++ b/images/runtimes/php7.3.6/Dockerfile
@@ -0,0 +1,19 @@
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ autoconf \
+ automake \
+ bison \
+ build-essential \
+ curl \
+ libtool \
+ libxml2-dev \
+ re2c
+
+WORKDIR /root
+ARG VERSION=7.3.6
+RUN curl -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz
+RUN tar -zxf php-${VERSION}.tar.gz
+
+WORKDIR /root/php-${VERSION}
+RUN ./configure
+RUN make
diff --git a/images/runtimes/python3.7.3/Dockerfile b/images/runtimes/python3.7.3/Dockerfile
new file mode 100644
index 000000000..4d1e1e221
--- /dev/null
+++ b/images/runtimes/python3.7.3/Dockerfile
@@ -0,0 +1,21 @@
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ curl \
+ gcc \
+ libbz2-dev \
+ libffi-dev \
+ liblzma-dev \
+ libreadline-dev \
+ libssl-dev \
+ make \
+ zlib1g-dev
+
+# Use flags -LJO to follow the html redirect and download .tar.gz.
+WORKDIR /root
+ARG VERSION=3.7.3
+RUN curl -LJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz
+RUN tar -zxf cpython-${VERSION}.tar.gz
+
+WORKDIR /root/cpython-${VERSION}
+RUN ./configure --with-pydebug
+RUN make -s -j2
diff --git a/images/tmpfile/Dockerfile b/images/tmpfile/Dockerfile
new file mode 100644
index 000000000..e3816c8cb
--- /dev/null
+++ b/images/tmpfile/Dockerfile
@@ -0,0 +1,4 @@
+# Create file under /tmp to ensure files inside '/tmp' are not overridden.
+FROM alpine:3.11.5
+RUN mkdir -p /tmp/foo \
+ && echo 123 > /tmp/foo/file.txt
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD
new file mode 100644
index 000000000..839f822eb
--- /dev/null
+++ b/pkg/abi/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "abi",
+ srcs = [
+ "abi.go",
+ "abi_linux.go",
+ "flag.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/abi/abi.go b/pkg/abi/abi.go
new file mode 100644
index 000000000..e6be93c3a
--- /dev/null
+++ b/pkg/abi/abi.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package abi describes the interface between a kernel and userspace.
+package abi
+
+import (
+ "fmt"
+)
+
+// OS describes the target operating system for an ABI.
+//
+// Note that OS is architecture-independent. The details of the OS ABI will
+// vary between architectures.
+type OS int
+
+const (
+ // Linux is the Linux ABI.
+ Linux OS = iota
+)
+
+// String implements fmt.Stringer.
+func (o OS) String() string {
+ switch o {
+ case Linux:
+ return "linux"
+ default:
+ return fmt.Sprintf("OS(%d)", o)
+ }
+}
+
+// ABI is an interface that defines OS-specific interactions.
+type ABI interface {
+}
diff --git a/pkg/abi/abi_linux.go b/pkg/abi/abi_linux.go
new file mode 100644
index 000000000..3059479bd
--- /dev/null
+++ b/pkg/abi/abi_linux.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package abi
+
+// Host specifies the host ABI.
+const Host = Linux
diff --git a/pkg/abi/flag.go b/pkg/abi/flag.go
new file mode 100644
index 000000000..dcdd66d4e
--- /dev/null
+++ b/pkg/abi/flag.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package abi
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+)
+
+// A FlagSet is a slice of bit-flags and their names.
+type FlagSet []struct {
+ Flag uint64
+ Name string
+}
+
+// Parse returns a pretty version of val, using the flag names for known flags.
+// Unknown flags remain numeric.
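+//
+// For example (an illustrative sketch; these open(2) flag values are assumed
+// here, not defined by this package), given
+//
+//	s := FlagSet{
+//		{Flag: 0x1, Name: "O_WRONLY"},
+//		{Flag: 0x40, Name: "O_CREAT"},
+//	}
+//
+// s.Parse(0x41) returns "O_WRONLY|O_CREAT", and s.Parse(0x8041) returns
+// "O_WRONLY|O_CREAT|0x8000" since 0x8000 is not in the set.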
+func (s FlagSet) Parse(val uint64) string {
+ var flags []string
+
+ for _, f := range s {
+ if val&f.Flag == f.Flag {
+ flags = append(flags, f.Name)
+ val &^= f.Flag
+ }
+ }
+
+ if val != 0 {
+ flags = append(flags, "0x"+strconv.FormatUint(val, 16))
+ }
+
+ if len(flags) == 0 {
+ // Prefer 0 to an empty string.
+ return "0x0"
+ }
+
+ return strings.Join(flags, "|")
+}
+
+// ValueSet is a map of syscall values to their names. Parse will use the name,
+// or the value if unknown.
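+//
+// For example (illustrative values), ValueSet{0: "SEEK_SET", 1: "SEEK_CUR"}
+// parses 0 to "SEEK_SET", 1 to "SEEK_CUR", and 9 to "0x9".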
+type ValueSet map[uint64]string
+
+// Parse returns the name of the value associated with `val`. Unknown values
+// are converted to hex.
+func (s ValueSet) Parse(val uint64) string {
+ if v, ok := s[val]; ok {
+ return v
+ }
+ return fmt.Sprintf("%#x", val)
+}
+
+// ParseDecimal returns the name of the value associated with `val`. Unknown
+// values are converted to decimal.
+func (s ValueSet) ParseDecimal(val uint64) string {
+ if v, ok := s[val]; ok {
+ return v
+ }
+ return fmt.Sprintf("%d", val)
+}
+
+// ParseName returns the flag value associated with 'name'. Returns false
+// if no value is found.
+func (s ValueSet) ParseName(name string) (uint64, bool) {
+ for k, v := range s {
+ if v == name {
+ return k, true
+ }
+ }
+ return math.MaxUint64, false
+}
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
new file mode 100644
index 000000000..2b789c4ec
--- /dev/null
+++ b/pkg/abi/linux/BUILD
@@ -0,0 +1,86 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+# Package linux contains the constants and types needed to interface with a
+# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
+# when the host OS may not be Linux.
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "linux",
+ srcs = [
+ "aio.go",
+ "arch_amd64.go",
+ "audit.go",
+ "bpf.go",
+ "capability.go",
+ "clone.go",
+ "dev.go",
+ "elf.go",
+ "epoll.go",
+ "epoll_amd64.go",
+ "epoll_arm64.go",
+ "errors.go",
+ "eventfd.go",
+ "exec.go",
+ "fadvise.go",
+ "fcntl.go",
+ "file.go",
+ "file_amd64.go",
+ "file_arm64.go",
+ "fs.go",
+ "futex.go",
+ "inotify.go",
+ "ioctl.go",
+ "ioctl_tun.go",
+ "ip.go",
+ "ipc.go",
+ "limits.go",
+ "linux.go",
+ "mm.go",
+ "netdevice.go",
+ "netfilter.go",
+ "netlink.go",
+ "netlink_route.go",
+ "poll.go",
+ "prctl.go",
+ "ptrace.go",
+ "ptrace_amd64.go",
+ "ptrace_arm64.go",
+ "rseq.go",
+ "rusage.go",
+ "sched.go",
+ "seccomp.go",
+ "sem.go",
+ "shm.go",
+ "signal.go",
+ "signalfd.go",
+ "socket.go",
+ "splice.go",
+ "tcp.go",
+ "time.go",
+ "timer.go",
+ "tty.go",
+ "uio.go",
+ "utsname.go",
+ "wait.go",
+ "xattr.go",
+ ],
+ marshal = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi",
+ "//pkg/binary",
+ "//pkg/bits",
+ ],
+)
+
+go_test(
+ name = "linux_test",
+ size = "small",
+ srcs = ["netfilter_test.go"],
+ library = ":linux",
+ deps = [
+ "//pkg/binary",
+ ],
+)
diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go
new file mode 100644
index 000000000..86ee3f8b5
--- /dev/null
+++ b/pkg/abi/linux/aio.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "encoding/binary"
+
+// AIORingSize is sizeof(struct aio_ring).
+const AIORingSize = 32
+
+// I/O commands.
+const (
+ IOCB_CMD_PREAD = 0
+ IOCB_CMD_PWRITE = 1
+ IOCB_CMD_FSYNC = 2
+ IOCB_CMD_FDSYNC = 3
+ // 4 was the experimental IOCB_CMD_PREADX.
+ IOCB_CMD_POLL = 5
+ IOCB_CMD_NOOP = 6
+ IOCB_CMD_PREADV = 7
+ IOCB_CMD_PWRITEV = 8
+)
+
+// I/O flags.
+const (
+ IOCB_FLAG_RESFD = 1
+ IOCB_FLAG_IOPRIO = 2
+)
+
+// IOCallback describes an I/O request.
+//
+// The priority field is currently ignored in the implementation below. Also
+// note that the IOCB_FLAG_RESFD feature is not supported.
+type IOCallback struct {
+ Data uint64
+ Key uint32
+ _ uint32
+
+ OpCode uint16
+ ReqPrio int16
+ FD int32
+
+ Buf uint64
+ Bytes uint64
+ Offset int64
+
+ Reserved2 uint64
+ Flags uint32
+
+ // eventfd to signal if IOCB_FLAG_RESFD is set in flags.
+ ResFD int32
+}
+
+// IOEvent describes an I/O result.
+//
+// +stateify savable
+type IOEvent struct {
+ Data uint64
+ Obj uint64
+ Result int64
+ Result2 int64
+}
+
+// IOEventSize is the size of an encoded IOEvent.
+var IOEventSize = binary.Size(IOEvent{})
diff --git a/pkg/abi/linux/arch_amd64.go b/pkg/abi/linux/arch_amd64.go
new file mode 100644
index 000000000..0be31e755
--- /dev/null
+++ b/pkg/abi/linux/arch_amd64.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// Start and end addresses of the vsyscall page.
+const (
+ VSyscallStartAddr uint64 = 0xffffffffff600000
+ VSyscallEndAddr uint64 = 0xffffffffff601000
+)
diff --git a/pkg/abi/linux/audit.go b/pkg/abi/linux/audit.go
new file mode 100644
index 000000000..6cca69af9
--- /dev/null
+++ b/pkg/abi/linux/audit.go
@@ -0,0 +1,23 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Audit numbers identify different system call APIs, from <uapi/linux/audit.h>.
+const (
+ // AUDIT_ARCH_X86_64 identifies AMD64.
+ AUDIT_ARCH_X86_64 = 0xc000003e
+ // AUDIT_ARCH_AARCH64 identifies ARM64.
+ AUDIT_ARCH_AARCH64 = 0xc00000b7
+)
diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go
new file mode 100644
index 000000000..aa3d3ce70
--- /dev/null
+++ b/pkg/abi/linux/bpf.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// BPFInstruction is a raw BPF virtual machine instruction.
+//
+// +stateify savable
+type BPFInstruction struct {
+ // OpCode is the operation to execute.
+ OpCode uint16
+
+ // JumpIfTrue is the number of instructions to skip if OpCode is a
+ // conditional instruction and the condition is true.
+ JumpIfTrue uint8
+
+ // JumpIfFalse is the number of instructions to skip if OpCode is a
+ // conditional instruction and the condition is false.
+ JumpIfFalse uint8
+
+ // K is a constant parameter. The meaning depends on the value of OpCode.
+ K uint32
+}
diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go
new file mode 100644
index 000000000..965f74663
--- /dev/null
+++ b/pkg/abi/linux/capability.go
@@ -0,0 +1,190 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// A Capability represents the ability to perform a privileged operation.
+type Capability int
+
+// Capabilities defined by Linux. Taken from the kernel's
+// include/uapi/linux/capability.h. See capabilities(7) or that file for more
+// detailed capability descriptions.
+const (
+ CAP_CHOWN = Capability(0)
+ CAP_DAC_OVERRIDE = Capability(1)
+ CAP_DAC_READ_SEARCH = Capability(2)
+ CAP_FOWNER = Capability(3)
+ CAP_FSETID = Capability(4)
+ CAP_KILL = Capability(5)
+ CAP_SETGID = Capability(6)
+ CAP_SETUID = Capability(7)
+ CAP_SETPCAP = Capability(8)
+ CAP_LINUX_IMMUTABLE = Capability(9)
+ CAP_NET_BIND_SERVICE = Capability(10)
+ CAP_NET_BROADCAST = Capability(11)
+ CAP_NET_ADMIN = Capability(12)
+ CAP_NET_RAW = Capability(13)
+ CAP_IPC_LOCK = Capability(14)
+ CAP_IPC_OWNER = Capability(15)
+ CAP_SYS_MODULE = Capability(16)
+ CAP_SYS_RAWIO = Capability(17)
+ CAP_SYS_CHROOT = Capability(18)
+ CAP_SYS_PTRACE = Capability(19)
+ CAP_SYS_PACCT = Capability(20)
+ CAP_SYS_ADMIN = Capability(21)
+ CAP_SYS_BOOT = Capability(22)
+ CAP_SYS_NICE = Capability(23)
+ CAP_SYS_RESOURCE = Capability(24)
+ CAP_SYS_TIME = Capability(25)
+ CAP_SYS_TTY_CONFIG = Capability(26)
+ CAP_MKNOD = Capability(27)
+ CAP_LEASE = Capability(28)
+ CAP_AUDIT_WRITE = Capability(29)
+ CAP_AUDIT_CONTROL = Capability(30)
+ CAP_SETFCAP = Capability(31)
+ CAP_MAC_OVERRIDE = Capability(32)
+ CAP_MAC_ADMIN = Capability(33)
+ CAP_SYSLOG = Capability(34)
+ CAP_WAKE_ALARM = Capability(35)
+ CAP_BLOCK_SUSPEND = Capability(36)
+ CAP_AUDIT_READ = Capability(37)
+
+ // CAP_LAST_CAP is the highest-numbered capability.
+ // Search for "CAP_LAST_CAP" to find other places that need to change.
+ CAP_LAST_CAP = CAP_AUDIT_READ
+)
+
+// Ok returns true if cp is a supported capability.
+func (cp Capability) Ok() bool {
+ return cp >= 0 && cp <= CAP_LAST_CAP
+}
+
+// String returns the capability name.
+func (cp Capability) String() string {
+ switch cp {
+ case CAP_CHOWN:
+ return "CAP_CHOWN"
+ case CAP_DAC_OVERRIDE:
+ return "CAP_DAC_OVERRIDE"
+ case CAP_DAC_READ_SEARCH:
+ return "CAP_DAC_READ_SEARCH"
+ case CAP_FOWNER:
+ return "CAP_FOWNER"
+ case CAP_FSETID:
+ return "CAP_FSETID"
+ case CAP_KILL:
+ return "CAP_KILL"
+ case CAP_SETGID:
+ return "CAP_SETGID"
+ case CAP_SETUID:
+ return "CAP_SETUID"
+ case CAP_SETPCAP:
+ return "CAP_SETPCAP"
+ case CAP_LINUX_IMMUTABLE:
+ return "CAP_LINUX_IMMUTABLE"
+ case CAP_NET_BIND_SERVICE:
+ return "CAP_NET_BIND_SERVICE"
+ case CAP_NET_BROADCAST:
+ return "CAP_NET_BROADCAST"
+ case CAP_NET_ADMIN:
+ return "CAP_NET_ADMIN"
+ case CAP_NET_RAW:
+ return "CAP_NET_RAW"
+ case CAP_IPC_LOCK:
+ return "CAP_IPC_LOCK"
+ case CAP_IPC_OWNER:
+ return "CAP_IPC_OWNER"
+ case CAP_SYS_MODULE:
+ return "CAP_SYS_MODULE"
+ case CAP_SYS_RAWIO:
+ return "CAP_SYS_RAWIO"
+ case CAP_SYS_CHROOT:
+ return "CAP_SYS_CHROOT"
+ case CAP_SYS_PTRACE:
+ return "CAP_SYS_PTRACE"
+ case CAP_SYS_PACCT:
+ return "CAP_SYS_PACCT"
+ case CAP_SYS_ADMIN:
+ return "CAP_SYS_ADMIN"
+ case CAP_SYS_BOOT:
+ return "CAP_SYS_BOOT"
+ case CAP_SYS_NICE:
+ return "CAP_SYS_NICE"
+ case CAP_SYS_RESOURCE:
+ return "CAP_SYS_RESOURCE"
+ case CAP_SYS_TIME:
+ return "CAP_SYS_TIME"
+ case CAP_SYS_TTY_CONFIG:
+ return "CAP_SYS_TTY_CONFIG"
+ case CAP_MKNOD:
+ return "CAP_MKNOD"
+ case CAP_LEASE:
+ return "CAP_LEASE"
+ case CAP_AUDIT_WRITE:
+ return "CAP_AUDIT_WRITE"
+ case CAP_AUDIT_CONTROL:
+ return "CAP_AUDIT_CONTROL"
+ case CAP_SETFCAP:
+ return "CAP_SETFCAP"
+ case CAP_MAC_OVERRIDE:
+ return "CAP_MAC_OVERRIDE"
+ case CAP_MAC_ADMIN:
+ return "CAP_MAC_ADMIN"
+ case CAP_SYSLOG:
+ return "CAP_SYSLOG"
+ case CAP_WAKE_ALARM:
+ return "CAP_WAKE_ALARM"
+ case CAP_BLOCK_SUSPEND:
+ return "CAP_BLOCK_SUSPEND"
+ case CAP_AUDIT_READ:
+ return "CAP_AUDIT_READ"
+ default:
+ return "UNKNOWN"
+ }
+}
+
+// Version numbers used by the capget/capset syscalls, defined in Linux's
+// include/uapi/linux/capability.h.
+const (
+ // LINUX_CAPABILITY_VERSION_1 causes the data pointer to be
+ // interpreted as a pointer to a single cap_user_data_t. Since capability
+ // sets are 64 bits and the "capability sets" in cap_user_data_t are 32
+ // bits only, this causes the upper 32 bits to be implicitly 0.
+ LINUX_CAPABILITY_VERSION_1 = 0x19980330
+
+ // LINUX_CAPABILITY_VERSION_2 and LINUX_CAPABILITY_VERSION_3 cause the
+ // data pointer to be interpreted as a pointer to an array of 2
+ // cap_user_data_t, using the second to store the 32 MSB of each capability
+ // set. Versions 2 and 3 are identical, but Linux printk's a warning on use
+ // of version 2 due to a userspace API defect.
+ LINUX_CAPABILITY_VERSION_2 = 0x20071026
+ LINUX_CAPABILITY_VERSION_3 = 0x20080522
+
+ // HighestCapabilityVersion is the highest supported
+ // LINUX_CAPABILITY_VERSION_* version.
+ HighestCapabilityVersion = LINUX_CAPABILITY_VERSION_3
+)
+
+// CapUserHeader is equivalent to Linux's cap_user_header_t.
+type CapUserHeader struct {
+ Version uint32
+ Pid int32
+}
+
+// CapUserData is equivalent to Linux's cap_user_data_t.
+type CapUserData struct {
+ Effective uint32
+ Permitted uint32
+ Inheritable uint32
+}
diff --git a/pkg/abi/linux/clone.go b/pkg/abi/linux/clone.go
new file mode 100644
index 000000000..c2cbfca5e
--- /dev/null
+++ b/pkg/abi/linux/clone.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Clone constants per clone(2).
+const (
+ CLONE_VM = 0x100
+ CLONE_FS = 0x200
+ CLONE_FILES = 0x400
+ CLONE_SIGHAND = 0x800
+ CLONE_PARENT = 0x8000
+ CLONE_PTRACE = 0x2000
+ CLONE_VFORK = 0x4000
+ CLONE_THREAD = 0x10000
+ CLONE_NEWNS = 0x20000
+ CLONE_SYSVSEM = 0x40000
+ CLONE_SETTLS = 0x80000
+ CLONE_PARENT_SETTID = 0x100000
+ CLONE_CHILD_CLEARTID = 0x200000
+ CLONE_DETACHED = 0x400000
+ CLONE_UNTRACED = 0x800000
+ CLONE_CHILD_SETTID = 0x1000000
+ CLONE_NEWUTS = 0x4000000
+ CLONE_NEWIPC = 0x8000000
+ CLONE_NEWUSER = 0x10000000
+ CLONE_NEWPID = 0x20000000
+ CLONE_NEWNET = 0x40000000
+ CLONE_IO = 0x80000000
+)
diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go
new file mode 100644
index 000000000..192e2093b
--- /dev/null
+++ b/pkg/abi/linux/dev.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// MakeDeviceID encodes a major and minor device number into a single device ID.
+//
+// Format (see linux/kdev_t.h:new_encode_dev):
+//
+// Bits 7:0 - minor bits 7:0
+// Bits 19:8 - major bits 11:0
+// Bits 31:20 - minor bits 19:8
+func MakeDeviceID(major uint16, minor uint32) uint32 {
+ return (minor & 0xff) | ((uint32(major) & 0xfff) << 8) | ((minor >> 8) << 20)
+}
+
+// DecodeDeviceID decodes a device ID into major and minor device numbers.
+func DecodeDeviceID(rdev uint32) (uint16, uint32) {
+ major := uint16((rdev >> 8) & 0xfff)
+ minor := (rdev & 0xff) | ((rdev >> 20) << 8)
+ return major, minor
+}
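+
+// As a worked example (for illustration only): MakeDeviceID(136, 3) packs to
+// 0x8803, and DecodeDeviceID(0x8803) recovers (136, 3).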
+
+// Character device IDs.
+//
+// See Documentations/devices.txt and uapi/linux/major.h.
+const (
+ // UNNAMED_MAJOR is the major device number for "unnamed" devices, whose
+ // minor numbers are dynamically allocated by the kernel.
+ UNNAMED_MAJOR = 0
+
+ // MEM_MAJOR is the major device number for "memory" character devices.
+ MEM_MAJOR = 1
+
+ // TTYAUX_MAJOR is the major device number for alternate TTY devices.
+ TTYAUX_MAJOR = 5
+
+ // MISC_MAJOR is the major device number for non-serial mice, misc feature
+ // devices.
+ MISC_MAJOR = 10
+
+ // UNIX98_PTY_MASTER_MAJOR is the initial major device number for
+ // Unix98 PTY masters.
+ UNIX98_PTY_MASTER_MAJOR = 128
+
+ // UNIX98_PTY_SLAVE_MAJOR is the initial major device number for
+ // Unix98 PTY slaves.
+ UNIX98_PTY_SLAVE_MAJOR = 136
+)
+
+// Minor device numbers for TTYAUX_MAJOR.
+const (
+ // PTMX_MINOR is the minor device number for /dev/ptmx.
+ PTMX_MINOR = 2
+)
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
new file mode 100644
index 000000000..7c9a02f20
--- /dev/null
+++ b/pkg/abi/linux/elf.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Linux auxiliary vector entry types.
+const (
+ // AT_NULL is the end of the auxiliary vector.
+ AT_NULL = 0
+
+ // AT_IGNORE should be ignored.
+ AT_IGNORE = 1
+
+ // AT_EXECFD is the file descriptor of the program.
+ AT_EXECFD = 2
+
+ // AT_PHDR points to the program headers.
+ AT_PHDR = 3
+
+ // AT_PHENT is the size of a program header entry.
+ AT_PHENT = 4
+
+ // AT_PHNUM is the number of program headers.
+ AT_PHNUM = 5
+
+ // AT_PAGESZ is the system page size.
+ AT_PAGESZ = 6
+
+ // AT_BASE is the base address of the interpreter.
+ AT_BASE = 7
+
+ // AT_FLAGS are flags.
+ AT_FLAGS = 8
+
+ // AT_ENTRY is the program entry point.
+ AT_ENTRY = 9
+
+ // AT_NOTELF indicates that the program is not an ELF binary.
+ AT_NOTELF = 10
+
+ // AT_UID is the real UID.
+ AT_UID = 11
+
+ // AT_EUID is the effective UID.
+ AT_EUID = 12
+
+ // AT_GID is the real GID.
+ AT_GID = 13
+
+ // AT_EGID is the effective GID.
+ AT_EGID = 14
+
+ // AT_PLATFORM is a string identifying the CPU.
+ AT_PLATFORM = 15
+
+ // AT_HWCAP are arch-dependent CPU capabilities.
+ AT_HWCAP = 16
+
+ // AT_CLKTCK is the frequency used by times(2).
+ AT_CLKTCK = 17
+
+ // AT_SECURE indicates secure mode.
+ AT_SECURE = 23
+
+ // AT_BASE_PLATFORM is a string identifying the "real" platform. It may
+ // differ from AT_PLATFORM.
+ AT_BASE_PLATFORM = 24
+
+ // AT_RANDOM points to 16-bytes of random data.
+ AT_RANDOM = 25
+
+ // AT_HWCAP2 is an extension of AT_HWCAP.
+ AT_HWCAP2 = 26
+
+ // AT_EXECFN is the path used to execute the program.
+ AT_EXECFN = 31
+
+ // AT_SYSINFO_EHDR is the address of the VDSO.
+ AT_SYSINFO_EHDR = 33
+)
+
+// ELF ET_CORE and ptrace GETREGSET/SETREGSET register set types.
+//
+// See include/uapi/linux/elf.h.
+const (
+ // NT_PRSTATUS is for general purpose registers.
+ NT_PRSTATUS = 0x1
+
+ // NT_PRFPREG is for floating point registers.
+ NT_PRFPREG = 0x2
+
+ // NT_X86_XSTATE is for x86 extended state using xsave.
+ NT_X86_XSTATE = 0x202
+
+ // NT_ARM_TLS is for ARM TLS register.
+ NT_ARM_TLS = 0x401
+)
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
new file mode 100644
index 000000000..1121a1a92
--- /dev/null
+++ b/pkg/abi/linux/epoll.go
@@ -0,0 +1,62 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
+// Event masks.
+const (
+ EPOLLIN = 0x1
+ EPOLLPRI = 0x2
+ EPOLLOUT = 0x4
+ EPOLLERR = 0x8
+ EPOLLHUP = 0x10
+ EPOLLRDNORM = 0x40
+ EPOLLRDBAND = 0x80
+ EPOLLWRNORM = 0x100
+ EPOLLWRBAND = 0x200
+ EPOLLMSG = 0x400
+ EPOLLRDHUP = 0x2000
+)
+
+// Per-file descriptor flags.
+const (
+ EPOLLEXCLUSIVE = 1 << 28
+ EPOLLWAKEUP = 1 << 29
+ EPOLLONESHOT = 1 << 30
+ EPOLLET = 1 << 31
+
+ // EP_PRIVATE_BITS is fs/eventpoll.c:EP_PRIVATE_BITS, the set of all bits
+ // in an epoll event mask that correspond to flags rather than I/O events.
+ EP_PRIVATE_BITS = EPOLLEXCLUSIVE | EPOLLWAKEUP | EPOLLONESHOT | EPOLLET
+)
+
+// Operation flags.
+const (
+ EPOLL_CLOEXEC = 0x80000
+ EPOLL_NONBLOCK = 0x800
+)
+
+// Control operations.
+const (
+ EPOLL_CTL_ADD = 0x1
+ EPOLL_CTL_DEL = 0x2
+ EPOLL_CTL_MOD = 0x3
+)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go
new file mode 100644
index 000000000..7e74b1143
--- /dev/null
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal slice:EpollEventSlice
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event::data a __u64. We represent it as
+ // [2]int32 because, on amd64, Linux also makes struct epoll_event
+ // __attribute__((packed)), such that there is no padding between Events
+ // and Data.
+ Data [2]int32
+}
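+
+// Since the struct is packed on amd64, the kernel's 64-bit data field lands in
+// Data[0] (low 32 bits) and Data[1] (high 32 bits). A sketch of reassembly
+// (an assumed helper, not part of this file):
+//
+//	func (e *EpollEvent) data() uint64 {
+//		return uint64(uint32(e.Data[0])) | uint64(uint32(e.Data[1]))<<32
+//	}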
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
new file mode 100644
index 000000000..a35939cc9
--- /dev/null
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal slice:EpollEventSlice
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event::data a __u64, necessitating 4 bytes of
+ // padding here.
+ _ int32
+ Data [2]int32
+}
diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errors.go
new file mode 100644
index 000000000..93f85a864
--- /dev/null
+++ b/pkg/abi/linux/errors.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Errno represents a Linux errno value.
+type Errno struct {
+ number int
+ name string
+}
+
+// Number returns the errno number.
+func (e *Errno) Number() int {
+ return e.number
+}
+
+// String implements fmt.Stringer.String.
+func (e *Errno) String() string {
+ return e.name
+}
+
+// Errno values from include/uapi/asm-generic/errno-base.h.
+var (
+ EPERM = &Errno{1, "operation not permitted"}
+ ENOENT = &Errno{2, "no such file or directory"}
+ ESRCH = &Errno{3, "no such process"}
+ EINTR = &Errno{4, "interrupted system call"}
+ EIO = &Errno{5, "I/O error"}
+ ENXIO = &Errno{6, "no such device or address"}
+ E2BIG = &Errno{7, "argument list too long"}
+ ENOEXEC = &Errno{8, "exec format error"}
+ EBADF = &Errno{9, "bad file number"}
+ ECHILD = &Errno{10, "no child processes"}
+ EAGAIN = &Errno{11, "try again"}
+ ENOMEM = &Errno{12, "out of memory"}
+ EACCES = &Errno{13, "permission denied"}
+ EFAULT = &Errno{14, "bad address"}
+ ENOTBLK = &Errno{15, "block device required"}
+ EBUSY = &Errno{16, "device or resource busy"}
+ EEXIST = &Errno{17, "file exists"}
+ EXDEV = &Errno{18, "cross-device link"}
+ ENODEV = &Errno{19, "no such device"}
+ ENOTDIR = &Errno{20, "not a directory"}
+ EISDIR = &Errno{21, "is a directory"}
+ EINVAL = &Errno{22, "invalid argument"}
+ ENFILE = &Errno{23, "file table overflow"}
+ EMFILE = &Errno{24, "too many open files"}
+ ENOTTY = &Errno{25, "not a typewriter"}
+ ETXTBSY = &Errno{26, "text file busy"}
+ EFBIG = &Errno{27, "file too large"}
+ ENOSPC = &Errno{28, "no space left on device"}
+ ESPIPE = &Errno{29, "illegal seek"}
+ EROFS = &Errno{30, "read-only file system"}
+ EMLINK = &Errno{31, "too many links"}
+ EPIPE = &Errno{32, "broken pipe"}
+ EDOM = &Errno{33, "math argument out of domain of func"}
+ ERANGE = &Errno{34, "math result not representable"}
+)
+
+// Errno values from include/uapi/asm-generic/errno.h.
+var (
+ EDEADLK = &Errno{35, "resource deadlock would occur"}
+ ENAMETOOLONG = &Errno{36, "file name too long"}
+ ENOLCK = &Errno{37, "no record locks available"}
+ ENOSYS = &Errno{38, "invalid system call number"}
+ ENOTEMPTY = &Errno{39, "directory not empty"}
+ ELOOP = &Errno{40, "too many symbolic links encountered"}
+ EWOULDBLOCK = &Errno{EAGAIN.number, "operation would block"}
+ ENOMSG = &Errno{42, "no message of desired type"}
+ EIDRM = &Errno{43, "identifier removed"}
+ ECHRNG = &Errno{44, "channel number out of range"}
+ EL2NSYNC = &Errno{45, "level 2 not synchronized"}
+ EL3HLT = &Errno{46, "level 3 halted"}
+ EL3RST = &Errno{47, "level 3 reset"}
+ ELNRNG = &Errno{48, "link number out of range"}
+ EUNATCH = &Errno{49, "protocol driver not attached"}
+ ENOCSI = &Errno{50, "no CSI structure available"}
+ EL2HLT = &Errno{51, "level 2 halted"}
+ EBADE = &Errno{52, "invalid exchange"}
+ EBADR = &Errno{53, "invalid request descriptor"}
+ EXFULL = &Errno{54, "exchange full"}
+ ENOANO = &Errno{55, "no anode"}
+ EBADRQC = &Errno{56, "invalid request code"}
+ EBADSLT = &Errno{57, "invalid slot"}
+ EDEADLOCK = EDEADLK
+ EBFONT = &Errno{59, "bad font file format"}
+ ENOSTR = &Errno{60, "device not a stream"}
+ ENODATA = &Errno{61, "no data available"}
+ ETIME = &Errno{62, "timer expired"}
+ ENOSR = &Errno{63, "out of streams resources"}
+ ENONET = &Errno{64, "machine is not on the network"}
+ ENOPKG = &Errno{65, "package not installed"}
+ EREMOTE = &Errno{66, "object is remote"}
+ ENOLINK = &Errno{67, "link has been severed"}
+ EADV = &Errno{68, "advertise error"}
+ ESRMNT = &Errno{69, "srmount error"}
+ ECOMM = &Errno{70, "communication error on send"}
+ EPROTO = &Errno{71, "protocol error"}
+ EMULTIHOP = &Errno{72, "multihop attempted"}
+ EDOTDOT = &Errno{73, "RFS specific error"}
+ EBADMSG = &Errno{74, "not a data message"}
+ EOVERFLOW = &Errno{75, "value too large for defined data type"}
+ ENOTUNIQ = &Errno{76, "name not unique on network"}
+ EBADFD = &Errno{77, "file descriptor in bad state"}
+ EREMCHG = &Errno{78, "remote address changed"}
+ ELIBACC = &Errno{79, "can not access a needed shared library"}
+ ELIBBAD = &Errno{80, "accessing a corrupted shared library"}
+ ELIBSCN = &Errno{81, ".lib section in a.out corrupted"}
+ ELIBMAX = &Errno{82, "attempting to link in too many shared libraries"}
+ ELIBEXEC = &Errno{83, "cannot exec a shared library directly"}
+ EILSEQ = &Errno{84, "illegal byte sequence"}
+ ERESTART = &Errno{85, "interrupted system call should be restarted"}
+ ESTRPIPE = &Errno{86, "streams pipe error"}
+ EUSERS = &Errno{87, "too many users"}
+ ENOTSOCK = &Errno{88, "socket operation on non-socket"}
+ EDESTADDRREQ = &Errno{89, "destination address required"}
+ EMSGSIZE = &Errno{90, "message too long"}
+ EPROTOTYPE = &Errno{91, "protocol wrong type for socket"}
+ ENOPROTOOPT = &Errno{92, "protocol not available"}
+ EPROTONOSUPPORT = &Errno{93, "protocol not supported"}
+ ESOCKTNOSUPPORT = &Errno{94, "socket type not supported"}
+ EOPNOTSUPP = &Errno{95, "operation not supported on transport endpoint"}
+ EPFNOSUPPORT = &Errno{96, "protocol family not supported"}
+ EAFNOSUPPORT = &Errno{97, "address family not supported by protocol"}
+ EADDRINUSE = &Errno{98, "address already in use"}
+ EADDRNOTAVAIL = &Errno{99, "cannot assign requested address"}
+ ENETDOWN = &Errno{100, "network is down"}
+ ENETUNREACH = &Errno{101, "network is unreachable"}
+ ENETRESET = &Errno{102, "network dropped connection because of reset"}
+ ECONNABORTED = &Errno{103, "software caused connection abort"}
+ ECONNRESET = &Errno{104, "connection reset by peer"}
+ ENOBUFS = &Errno{105, "no buffer space available"}
+ EISCONN = &Errno{106, "transport endpoint is already connected"}
+ ENOTCONN = &Errno{107, "transport endpoint is not connected"}
+ ESHUTDOWN = &Errno{108, "cannot send after transport endpoint shutdown"}
+ ETOOMANYREFS = &Errno{109, "too many references: cannot splice"}
+ ETIMEDOUT = &Errno{110, "connection timed out"}
+ ECONNREFUSED = &Errno{111, "connection refused"}
+ EHOSTDOWN = &Errno{112, "host is down"}
+ EHOSTUNREACH = &Errno{113, "no route to host"}
+ EALREADY = &Errno{114, "operation already in progress"}
+ EINPROGRESS = &Errno{115, "operation now in progress"}
+ ESTALE = &Errno{116, "stale file handle"}
+ EUCLEAN = &Errno{117, "structure needs cleaning"}
+ ENOTNAM = &Errno{118, "not a XENIX named type file"}
+ ENAVAIL = &Errno{119, "no XENIX semaphores available"}
+ EISNAM = &Errno{120, "is a named type file"}
+ EREMOTEIO = &Errno{121, "remote I/O error"}
+ EDQUOT = &Errno{122, "quota exceeded"}
+ ENOMEDIUM = &Errno{123, "no medium found"}
+ EMEDIUMTYPE = &Errno{124, "wrong medium type"}
+	ECANCELED       = &Errno{125, "operation canceled"}
+ ENOKEY = &Errno{126, "required key not available"}
+ EKEYEXPIRED = &Errno{127, "key has expired"}
+ EKEYREVOKED = &Errno{128, "key has been revoked"}
+ EKEYREJECTED = &Errno{129, "key was rejected by service"}
+ EOWNERDEAD = &Errno{130, "owner died"}
+ ENOTRECOVERABLE = &Errno{131, "state not recoverable"}
+ ERFKILL = &Errno{132, "operation not possible due to RF-kill"}
+ EHWPOISON = &Errno{133, "memory page has hardware error"}
+)
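
For illustration, a minimal self-contained sketch of the pattern above (not gVisor's actual Errno definition, which appears earlier in this diff): an errno pairs a Linux number with a message, and an alias like EWOULDBLOCK reuses EAGAIN's number while keeping its own text.

package main

import "fmt"

// Errno is a stand-in for the type used above: a Linux errno number plus a
// human-readable message.
type Errno struct {
	number  int
	message string
}

// Error makes *Errno satisfy the error interface.
func (e *Errno) Error() string { return fmt.Sprintf("errno %d: %s", e.number, e.message) }

func main() {
	eagain := &Errno{11, "try again"}
	ewouldblock := &Errno{eagain.number, "operation would block"} // same number, distinct text
	fmt.Println(ewouldblock)                                     // errno 11: operation would block
}
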
diff --git a/pkg/abi/linux/eventfd.go b/pkg/abi/linux/eventfd.go
new file mode 100644
index 000000000..9c479fc8f
--- /dev/null
+++ b/pkg/abi/linux/eventfd.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for eventfd2(2).
+const (
+ EFD_SEMAPHORE = 0x1
+ EFD_CLOEXEC = O_CLOEXEC
+ EFD_NONBLOCK = O_NONBLOCK
+)
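
A hedged usage sketch: the same flag values are exposed by golang.org/x/sys/unix, so an eventfd with semaphore semantics, close-on-exec, and non-blocking reads can be created like this on a Linux host.

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	// EFD_SEMAPHORE | EFD_CLOEXEC | EFD_NONBLOCK, matching the constants above.
	fd, err := unix.Eventfd(0, unix.EFD_SEMAPHORE|unix.EFD_CLOEXEC|unix.EFD_NONBLOCK)
	if err != nil {
		fmt.Println("eventfd:", err)
		return
	}
	defer unix.Close(fd)
	fmt.Println("eventfd created:", fd)
}
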
diff --git a/pkg/abi/linux/exec.go b/pkg/abi/linux/exec.go
new file mode 100644
index 000000000..579d46c41
--- /dev/null
+++ b/pkg/abi/linux/exec.go
@@ -0,0 +1,18 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// TASK_COMM_LEN is the task command name length.
+const TASK_COMM_LEN = 16
diff --git a/pkg/abi/linux/fadvise.go b/pkg/abi/linux/fadvise.go
new file mode 100644
index 000000000..b06ff9964
--- /dev/null
+++ b/pkg/abi/linux/fadvise.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ POSIX_FADV_NORMAL = 0
+ POSIX_FADV_RANDOM = 1
+ POSIX_FADV_SEQUENTIAL = 2
+ POSIX_FADV_WILLNEED = 3
+ POSIX_FADV_DONTNEED = 4
+ POSIX_FADV_NOREUSE = 5
+)
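
A short usage sketch (assuming a Linux host and golang.org/x/sys/unix, whose FADV_* constants mirror the values above): hint that a file will be read sequentially. The path is just an example.

package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.Open("/etc/hostname") // any readable file
	if err != nil {
		fmt.Println(err)
		return
	}
	defer f.Close()
	// offset=0, length=0 means "the whole file".
	if err := unix.Fadvise(int(f.Fd()), 0, 0, unix.FADV_SEQUENTIAL); err != nil {
		fmt.Println("fadvise:", err)
	}
}
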
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
new file mode 100644
index 000000000..9242e80a5
--- /dev/null
+++ b/pkg/abi/linux/fcntl.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Commands from linux/fcntl.h.
+const (
+ F_DUPFD = 0
+ F_GETFD = 1
+ F_SETFD = 2
+ F_GETFL = 3
+ F_SETFL = 4
+ F_SETLK = 6
+ F_SETLKW = 7
+ F_SETOWN = 8
+ F_GETOWN = 9
+ F_SETOWN_EX = 15
+ F_GETOWN_EX = 16
+ F_DUPFD_CLOEXEC = 1024 + 6
+ F_SETPIPE_SZ = 1024 + 7
+ F_GETPIPE_SZ = 1024 + 8
+)
+
+// Commands for F_SETLK.
+const (
+ F_RDLCK = 0
+ F_WRLCK = 1
+ F_UNLCK = 2
+)
+
+// Flags for fcntl.
+const (
+ FD_CLOEXEC = 00000001
+)
+
+// Flock is the lock structure for F_SETLK.
+type Flock struct {
+ Type int16
+ Whence int16
+ _ [4]byte
+ Start int64
+ Len int64
+ Pid int32
+ _ [4]byte
+}
+
+// Owner types for F_SETOWN_EX and F_GETOWN_EX.
+const (
+ F_OWNER_TID = 0
+ F_OWNER_PID = 1
+ F_OWNER_PGRP = 2
+)
+
+// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX.
+type FOwnerEx struct {
+ Type int32
+ PID int32
+}
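
A usage sketch for the Flock layout above, via golang.org/x/sys/unix (whose Flock_t mirrors this struct): take a non-blocking write lock on the first 100 bytes of a file. The path is a placeholder.

package main

import (
	"fmt"
	"io"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.OpenFile("/tmp/lock-demo", os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer f.Close()

	lk := unix.Flock_t{
		Type:   unix.F_WRLCK, // F_WRLCK = 1, as above
		Whence: int16(io.SeekStart),
		Start:  0,
		Len:    100,
	}
	// F_SETLK fails immediately (rather than blocking like F_SETLKW) if the
	// range is already locked by another process.
	if err := unix.FcntlFlock(f.Fd(), unix.F_SETLK, &lk); err != nil {
		fmt.Println("lock failed:", err)
		return
	}
	fmt.Println("locked bytes [0, 100)")
}
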
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
new file mode 100644
index 000000000..e11ca2d62
--- /dev/null
+++ b/pkg/abi/linux/file.go
@@ -0,0 +1,384 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
+// Constants for open(2).
+const (
+ O_ACCMODE = 000000003
+ O_RDONLY = 000000000
+ O_WRONLY = 000000001
+ O_RDWR = 000000002
+ O_CREAT = 000000100
+ O_EXCL = 000000200
+ O_NOCTTY = 000000400
+ O_TRUNC = 000001000
+ O_APPEND = 000002000
+ O_NONBLOCK = 000004000
+ O_DSYNC = 000010000
+ O_ASYNC = 000020000
+ O_NOATIME = 001000000
+ O_CLOEXEC = 002000000
+ O_SYNC = 004000000 // __O_SYNC in Linux
+ O_PATH = 010000000
+ O_TMPFILE = 020000000 // __O_TMPFILE in Linux
+)
+
+// Constants for fstatat(2).
+const (
+ AT_SYMLINK_NOFOLLOW = 0x100
+)
+
+// Constants for mount(2).
+const (
+ MS_RDONLY = 0x1
+ MS_NOSUID = 0x2
+ MS_NODEV = 0x4
+ MS_NOEXEC = 0x8
+ MS_SYNCHRONOUS = 0x10
+ MS_REMOUNT = 0x20
+ MS_MANDLOCK = 0x40
+ MS_DIRSYNC = 0x80
+ MS_NOATIME = 0x400
+ MS_NODIRATIME = 0x800
+ MS_BIND = 0x1000
+ MS_MOVE = 0x2000
+ MS_REC = 0x4000
+
+ MS_POSIXACL = 0x10000
+ MS_UNBINDABLE = 0x20000
+ MS_PRIVATE = 0x40000
+ MS_SLAVE = 0x80000
+ MS_SHARED = 0x100000
+ MS_RELATIME = 0x200000
+ MS_KERNMOUNT = 0x400000
+ MS_I_VERSION = 0x800000
+ MS_STRICTATIME = 0x1000000
+
+ MS_MGC_VAL = 0xC0ED0000
+ MS_MGC_MSK = 0xffff0000
+)
+
+// Constants for umount2(2).
+const (
+ MNT_FORCE = 0x1
+ MNT_DETACH = 0x2
+ MNT_EXPIRE = 0x4
+ UMOUNT_NOFOLLOW = 0x8
+)
+
+// Constants for unlinkat(2).
+const (
+ AT_REMOVEDIR = 0x200
+)
+
+// Constants for linkat(2) and fchownat(2).
+const (
+ AT_SYMLINK_FOLLOW = 0x400
+ AT_EMPTY_PATH = 0x1000
+)
+
+// Constants for all file-related ...at(2) syscalls.
+const (
+ AT_FDCWD = -100
+)
+
+// Special values for the ns field in utimensat(2).
+const (
+ UTIME_NOW = ((1 << 30) - 1)
+ UTIME_OMIT = ((1 << 30) - 2)
+)
+
+// MaxSymlinkTraversals is the maximum number of links that will be followed by
+// the kernel to resolve a symlink.
+const MaxSymlinkTraversals = 40
+
+// Constants for flock(2).
+const (
+ LOCK_SH = 1 // shared lock
+ LOCK_EX = 2 // exclusive lock
+ LOCK_NB = 4 // or'd with one of the above to prevent blocking
+ LOCK_UN = 8 // remove lock
+)
+
+// Values for mode_t.
+const (
+ S_IFMT = 0170000
+ S_IFSOCK = 0140000
+ S_IFLNK = 0120000
+ S_IFREG = 0100000
+ S_IFBLK = 060000
+ S_IFDIR = 040000
+ S_IFCHR = 020000
+ S_IFIFO = 010000
+
+ FileTypeMask = S_IFMT
+ ModeSocket = S_IFSOCK
+ ModeSymlink = S_IFLNK
+ ModeRegular = S_IFREG
+ ModeBlockDevice = S_IFBLK
+ ModeDirectory = S_IFDIR
+ ModeCharacterDevice = S_IFCHR
+ ModeNamedPipe = S_IFIFO
+
+ S_ISUID = 04000
+ S_ISGID = 02000
+ S_ISVTX = 01000
+
+ ModeSetUID = S_ISUID
+ ModeSetGID = S_ISGID
+ ModeSticky = S_ISVTX
+
+ ModeUserAll = 0700
+ ModeUserRead = 0400
+ ModeUserWrite = 0200
+ ModeUserExec = 0100
+ ModeGroupAll = 0070
+ ModeGroupRead = 0040
+ ModeGroupWrite = 0020
+ ModeGroupExec = 0010
+ ModeOtherAll = 0007
+ ModeOtherRead = 0004
+ ModeOtherWrite = 0002
+ ModeOtherExec = 0001
+ PermissionsMask = 0777
+)
+
+// Values for linux_dirent64.d_type.
+const (
+ DT_UNKNOWN = 0
+ DT_FIFO = 1
+ DT_CHR = 2
+ DT_DIR = 4
+ DT_BLK = 6
+ DT_REG = 8
+ DT_LNK = 10
+ DT_SOCK = 12
+ DT_WHT = 14
+)
+
+// DirentType are the friendly strings for linux_dirent64.d_type.
+var DirentType = abi.ValueSet{
+ DT_UNKNOWN: "DT_UNKNOWN",
+ DT_FIFO: "DT_FIFO",
+ DT_CHR: "DT_CHR",
+ DT_DIR: "DT_DIR",
+ DT_BLK: "DT_BLK",
+ DT_REG: "DT_REG",
+ DT_LNK: "DT_LNK",
+ DT_SOCK: "DT_SOCK",
+ DT_WHT: "DT_WHT",
+}
+
+// Values for preadv2/pwritev2.
+const (
+ // NOTE(b/120162627): gVisor does not implement the RWF_HIPRI feature, but
+ // the flag is accepted as a valid flag argument for preadv2/pwritev2 and
+ // silently ignored.
+ RWF_HIPRI = 0x00000001
+ RWF_DSYNC = 0x00000002
+ RWF_SYNC = 0x00000004
+ RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC
+)
+
+// SizeOfStat is the size of a Stat struct.
+var SizeOfStat = binary.Size(Stat{})
+
+// Flags for statx.
+const (
+ AT_STATX_SYNC_TYPE = 0x6000
+ AT_STATX_SYNC_AS_STAT = 0x0000
+ AT_STATX_FORCE_SYNC = 0x2000
+ AT_STATX_DONT_SYNC = 0x4000
+)
+
+// Mask values for statx.
+const (
+ STATX_TYPE = 0x00000001
+ STATX_MODE = 0x00000002
+ STATX_NLINK = 0x00000004
+ STATX_UID = 0x00000008
+ STATX_GID = 0x00000010
+ STATX_ATIME = 0x00000020
+ STATX_MTIME = 0x00000040
+ STATX_CTIME = 0x00000080
+ STATX_INO = 0x00000100
+ STATX_SIZE = 0x00000200
+ STATX_BLOCKS = 0x00000400
+ STATX_BASIC_STATS = 0x000007ff
+ STATX_BTIME = 0x00000800
+ STATX_ALL = 0x00000fff
+ STATX__RESERVED = 0x80000000
+)
+
+// Bitmasks for Statx.Attributes and Statx.AttributesMask, from
+// include/uapi/linux/stat.h.
+const (
+ STATX_ATTR_COMPRESSED = 0x00000004
+ STATX_ATTR_IMMUTABLE = 0x00000010
+ STATX_ATTR_APPEND = 0x00000020
+ STATX_ATTR_NODUMP = 0x00000040
+ STATX_ATTR_ENCRYPTED = 0x00000800
+ STATX_ATTR_AUTOMOUNT = 0x00001000
+)
+
+// Statx represents struct statx.
+//
+// +marshal
+type Statx struct {
+ Mask uint32
+ Blksize uint32
+ Attributes uint64
+ Nlink uint32
+ UID uint32
+ GID uint32
+ Mode uint16
+ _ uint16
+ Ino uint64
+ Size uint64
+ Blocks uint64
+ AttributesMask uint64
+ Atime StatxTimestamp
+ Btime StatxTimestamp
+ Ctime StatxTimestamp
+ Mtime StatxTimestamp
+ RdevMajor uint32
+ RdevMinor uint32
+ DevMajor uint32
+ DevMinor uint32
+}
+
+// SizeOfStatx is the size of a Statx struct.
+var SizeOfStatx = binary.Size(Statx{})
+
+// FileMode represents a mode_t.
+type FileMode uint16
+
+// Permissions returns just the permission bits.
+func (m FileMode) Permissions() FileMode {
+ return m & PermissionsMask
+}
+
+// FileType returns just the file type bits.
+func (m FileMode) FileType() FileMode {
+ return m & FileTypeMask
+}
+
+// ExtraBits returns everything but the file type and permission bits.
+func (m FileMode) ExtraBits() FileMode {
+ return m &^ (PermissionsMask | FileTypeMask)
+}
+
+// IsDir returns true if file type represents a directory.
+func (m FileMode) IsDir() bool {
+ return m.FileType() == S_IFDIR
+}
+
+// String returns a string representation of m.
+func (m FileMode) String() string {
+ var s []string
+ if ft := m.FileType(); ft != 0 {
+ s = append(s, fileType.Parse(uint64(ft)))
+ }
+ if eb := m.ExtraBits(); eb != 0 {
+ s = append(s, modeExtraBits.Parse(uint64(eb)))
+ }
+ s = append(s, fmt.Sprintf("0o%o", m.Permissions()))
+ return strings.Join(s, "|")
+}
+
+// DirentType maps file types to dirent types appropriate for (struct
+// dirent)::d_type.
+func (m FileMode) DirentType() uint8 {
+ switch m.FileType() {
+ case ModeSocket:
+ return DT_SOCK
+ case ModeSymlink:
+ return DT_LNK
+ case ModeRegular:
+ return DT_REG
+ case ModeBlockDevice:
+ return DT_BLK
+ case ModeDirectory:
+ return DT_DIR
+ case ModeCharacterDevice:
+ return DT_CHR
+ case ModeNamedPipe:
+ return DT_FIFO
+ default:
+ return DT_UNKNOWN
+ }
+}
+
+var modeExtraBits = abi.FlagSet{
+ {
+ Flag: ModeSetUID,
+ Name: "S_ISUID",
+ },
+ {
+ Flag: ModeSetGID,
+ Name: "S_ISGID",
+ },
+ {
+ Flag: ModeSticky,
+ Name: "S_ISVTX",
+ },
+}
+
+var fileType = abi.ValueSet{
+ ModeSocket: "S_IFSOCK",
+	ModeSymlink:         "S_IFLNK",
+ ModeRegular: "S_IFREG",
+ ModeBlockDevice: "S_IFBLK",
+ ModeDirectory: "S_IFDIR",
+ ModeCharacterDevice: "S_IFCHR",
+ ModeNamedPipe: "S_IFIFO",
+}
+
+// Constants for memfd_create(2). Source: include/uapi/linux/memfd.h
+const (
+ MFD_CLOEXEC = 0x0001
+ MFD_ALLOW_SEALING = 0x0002
+)
+
+// Constants related to file seals. Source: include/uapi/{asm-generic,linux}/fcntl.h
+const (
+ F_LINUX_SPECIFIC_BASE = 1024
+ F_ADD_SEALS = F_LINUX_SPECIFIC_BASE + 9
+ F_GET_SEALS = F_LINUX_SPECIFIC_BASE + 10
+
+ F_SEAL_SEAL = 0x0001 // Prevent further seals from being set.
+ F_SEAL_SHRINK = 0x0002 // Prevent file from shrinking.
+ F_SEAL_GROW = 0x0004 // Prevent file from growing.
+ F_SEAL_WRITE = 0x0008 // Prevent writes.
+)
+
+// Constants related to fallocate(2). Source: include/uapi/linux/falloc.h
+const (
+ FALLOC_FL_KEEP_SIZE = 0x01
+ FALLOC_FL_PUNCH_HOLE = 0x02
+ FALLOC_FL_NO_HIDE_STALE = 0x04
+ FALLOC_FL_COLLAPSE_RANGE = 0x08
+ FALLOC_FL_ZERO_RANGE = 0x10
+ FALLOC_FL_INSERT_RANGE = 0x20
+ FALLOC_FL_UNSHARE_RANGE = 0x40
+)
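
A minimal self-contained sketch of the FileMode decomposition implemented above: mask the type bits with S_IFMT and the permission bits with 0777 (constants re-declared locally so the example runs standalone).

package main

import "fmt"

const (
	sIFMT  = 0170000 // S_IFMT, the file type mask
	sIFREG = 0100000 // S_IFREG, a regular file
)

func main() {
	mode := uint16(0100644) // regular file, rw-r--r--
	fmt.Printf("type=%#o perms=%#o regular=%v\n",
		mode&sIFMT, mode&0777, mode&sIFMT == sIFREG)
}
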
diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go
new file mode 100644
index 000000000..6b72364ea
--- /dev/null
+++ b/pkg/abi/linux/file_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// Constants for open(2).
+const (
+ O_DIRECT = 000040000
+ O_LARGEFILE = 000100000
+ O_DIRECTORY = 000200000
+ O_NOFOLLOW = 000400000
+)
+
+// Stat represents struct stat.
+//
+// +marshal
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}
diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go
new file mode 100644
index 000000000..6492c9038
--- /dev/null
+++ b/pkg/abi/linux/file_arm64.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// Constants for open(2).
+const (
+ O_DIRECTORY = 000040000
+ O_NOFOLLOW = 000100000
+ O_DIRECT = 000200000
+ O_LARGEFILE = 000400000
+)
+
+// Stat represents struct stat.
+//
+// +marshal
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Mode uint32
+ Nlink uint32
+ UID uint32
+ GID uint32
+ Rdev uint64
+ _ uint64
+ Size int64
+ Blksize int32
+ _ int32
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [2]int32
+}
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
new file mode 100644
index 000000000..158d2db5b
--- /dev/null
+++ b/pkg/abi/linux/fs.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Filesystem types used in statfs(2).
+//
+// See linux/magic.h.
+const (
+ ANON_INODE_FS_MAGIC = 0x09041934
+ DEVPTS_SUPER_MAGIC = 0x00001cd1
+ EXT_SUPER_MAGIC = 0xef53
+ OVERLAYFS_SUPER_MAGIC = 0x794c7630
+ PIPEFS_MAGIC = 0x50495045
+ PROC_SUPER_MAGIC = 0x9fa0
+ RAMFS_MAGIC = 0x09041934
+ SOCKFS_MAGIC = 0x534F434B
+ SYSFS_MAGIC = 0x62656572
+ TMPFS_MAGIC = 0x01021994
+ V9FS_MAGIC = 0x01021997
+)
+
+// Filesystem path limits, from uapi/linux/limits.h.
+const (
+ NAME_MAX = 255
+ PATH_MAX = 4096
+)
+
+// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
+type Statfs struct {
+ // Type is one of the filesystem magic values, defined above.
+ Type uint64
+
+ // BlockSize is the data block size.
+ BlockSize int64
+
+ // Blocks is the number of data blocks in use.
+ Blocks uint64
+
+ // BlocksFree is the number of free blocks.
+ BlocksFree uint64
+
+ // BlocksAvailable is the number of blocks free for use by
+ // unprivileged users.
+ BlocksAvailable uint64
+
+ // Files is the number of used file nodes on the filesystem.
+ Files uint64
+
+	// FilesFree is the number of free file nodes on the filesystem.
+ FilesFree uint64
+
+ // FSID is the filesystem ID.
+ FSID [2]int32
+
+ // NameLength is the maximum file name length.
+ NameLength uint64
+
+ // FragmentSize is equivalent to BlockSize.
+ FragmentSize int64
+
+ // Flags is the set of filesystem mount flags.
+ Flags uint64
+
+ // Spare is unused.
+ Spare [4]uint64
+}
+
+// Whence argument to lseek(2), from include/uapi/linux/fs.h.
+const (
+ SEEK_SET = 0
+ SEEK_CUR = 1
+ SEEK_END = 2
+ SEEK_DATA = 3
+ SEEK_HOLE = 4
+)
+
+// Sync_file_range flags, from include/uapi/linux/fs.h
+const (
+ SYNC_FILE_RANGE_WAIT_BEFORE = 1
+ SYNC_FILE_RANGE_WRITE = 2
+ SYNC_FILE_RANGE_WAIT_AFTER = 4
+)
+
+// Flag argument to renameat2(2), from include/uapi/linux/fs.h.
+const (
+ RENAME_NOREPLACE = (1 << 0) // Don't overwrite target.
+ RENAME_EXCHANGE = (1 << 1) // Exchange src and dst.
+ RENAME_WHITEOUT = (1 << 2) // Whiteout src.
+)
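
A usage sketch: golang.org/x/sys/unix exposes the same statfs(2) fields (as Statfs_t), so the magic numbers and sizes above can be inspected for a live mount. "/tmp" is just an example path.

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	var st unix.Statfs_t
	if err := unix.Statfs("/tmp", &st); err != nil {
		fmt.Println(err)
		return
	}
	// On a tmpfs mount, Type would equal TMPFS_MAGIC (0x01021994) from above.
	fmt.Printf("type=%#x bsize=%d bfree=%d\n", st.Type, st.Bsize, st.Bfree)
}
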
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
new file mode 100644
index 000000000..08bfde3b5
--- /dev/null
+++ b/pkg/abi/linux/futex.go
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// From <linux/futex.h> and <sys/time.h>.
+// Flags are used in syscall futex(2).
+const (
+ FUTEX_WAIT = 0
+ FUTEX_WAKE = 1
+ FUTEX_FD = 2
+ FUTEX_REQUEUE = 3
+ FUTEX_CMP_REQUEUE = 4
+ FUTEX_WAKE_OP = 5
+ FUTEX_LOCK_PI = 6
+ FUTEX_UNLOCK_PI = 7
+ FUTEX_TRYLOCK_PI = 8
+ FUTEX_WAIT_BITSET = 9
+ FUTEX_WAKE_BITSET = 10
+ FUTEX_WAIT_REQUEUE_PI = 11
+ FUTEX_CMP_REQUEUE_PI = 12
+
+ FUTEX_PRIVATE_FLAG = 128
+ FUTEX_CLOCK_REALTIME = 256
+)
+
+// These flags are from <linux/futex.h> and are used in FUTEX_WAKE_OP
+// to define the operations.
+const (
+ FUTEX_OP_SET = 0
+ FUTEX_OP_ADD = 1
+ FUTEX_OP_OR = 2
+ FUTEX_OP_ANDN = 3
+ FUTEX_OP_XOR = 4
+ FUTEX_OP_OPARG_SHIFT = 8
+ FUTEX_OP_CMP_EQ = 0
+ FUTEX_OP_CMP_NE = 1
+ FUTEX_OP_CMP_LT = 2
+ FUTEX_OP_CMP_LE = 3
+ FUTEX_OP_CMP_GT = 4
+ FUTEX_OP_CMP_GE = 5
+)
+
+// FUTEX_TID_MASK is the TID portion of a PI futex word.
+const FUTEX_TID_MASK = 0x3fffffff
+
+// Constants used for priority-inheritance futexes.
+const (
+ FUTEX_WAITERS = 0x80000000
+ FUTEX_OWNER_DIED = 0x40000000
+)
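
A sketch of how the FUTEX_OP_* constants combine into a FUTEX_WAKE_OP operation word. This mirrors the FUTEX_OP macro in linux/futex.h (op<<28 | cmp<<24 | oparg<<12 | cmparg); the helper name is illustrative only.

package main

import "fmt"

const (
	futexOpAdd   = 1 // FUTEX_OP_ADD
	futexOpCmpGt = 4 // FUTEX_OP_CMP_GT
)

// futexOp packs an op, its argument, a comparison, and its argument into the
// 32-bit word FUTEX_WAKE_OP expects.
func futexOp(op, oparg, cmp, cmparg uint32) uint32 {
	return (op&0xf)<<28 | (cmp&0xf)<<24 | (oparg&0xfff)<<12 | cmparg&0xfff
}

func main() {
	// "Add 1 to the second futex word; wake waiters if its old value was > 0."
	fmt.Printf("%#08x\n", futexOp(futexOpAdd, 1, futexOpCmpGt, 0))
}
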
diff --git a/pkg/abi/linux/inotify.go b/pkg/abi/linux/inotify.go
new file mode 100644
index 000000000..2d08194ba
--- /dev/null
+++ b/pkg/abi/linux/inotify.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Inotify events observable by userspace. These directly correspond to
+// filesystem operations, and at most one of them may be present per inotify
+// event read from an inotify fd.
+const (
+ // IN_ACCESS indicates a file was accessed.
+ IN_ACCESS = 0x00000001
+ // IN_MODIFY indicates a file was modified.
+ IN_MODIFY = 0x00000002
+ // IN_ATTRIB indicates a watch target's metadata changed.
+ IN_ATTRIB = 0x00000004
+ // IN_CLOSE_WRITE indicates a writable file was closed.
+ IN_CLOSE_WRITE = 0x00000008
+ // IN_CLOSE_NOWRITE indicates a non-writable file was closed.
+ IN_CLOSE_NOWRITE = 0x00000010
+ // IN_OPEN indicates a file was opened.
+ IN_OPEN = 0x00000020
+ // IN_MOVED_FROM indicates a file was moved from X.
+ IN_MOVED_FROM = 0x00000040
+ // IN_MOVED_TO indicates a file was moved to Y.
+ IN_MOVED_TO = 0x00000080
+ // IN_CREATE indicates a file was created in a watched directory.
+ IN_CREATE = 0x00000100
+ // IN_DELETE indicates a file was deleted in a watched directory.
+ IN_DELETE = 0x00000200
+ // IN_DELETE_SELF indicates a watch target itself was deleted.
+ IN_DELETE_SELF = 0x00000400
+ // IN_MOVE_SELF indicates a watch target itself was moved.
+ IN_MOVE_SELF = 0x00000800
+ // IN_ALL_EVENTS is a mask for all observable userspace events.
+ IN_ALL_EVENTS = 0x00000fff
+)
+
+// Inotify control events. These may be present in their own events, or ORed
+// with other observable events.
+const (
+ // IN_UNMOUNT indicates the backing filesystem was unmounted.
+ IN_UNMOUNT = 0x00002000
+	// IN_Q_OVERFLOW indicates the event queue overflowed.
+ IN_Q_OVERFLOW = 0x00004000
+ // IN_IGNORED indicates a watch was removed, either implicitly or through
+ // inotify_rm_watch(2).
+ IN_IGNORED = 0x00008000
+ // IN_ISDIR indicates the subject of an event was a directory.
+ IN_ISDIR = 0x40000000
+)
+
+// Feature flags for inotify_add_watch(2).
+const (
+ // IN_ONLYDIR indicates that a path should be watched only if it's a
+ // directory.
+ IN_ONLYDIR = 0x01000000
+ // IN_DONT_FOLLOW indicates that the watch path shouldn't be resolved if
+ // it's a symlink.
+ IN_DONT_FOLLOW = 0x02000000
+ // IN_EXCL_UNLINK indicates events to this watch from unlinked objects
+ // should be filtered out.
+ IN_EXCL_UNLINK = 0x04000000
+ // IN_MASK_ADD indicates the provided mask should be ORed into any existing
+ // watch on the provided path.
+ IN_MASK_ADD = 0x20000000
+ // IN_ONESHOT indicates the watch should be removed after one event.
+ IN_ONESHOT = 0x80000000
+)
+
+// Feature flags for inotify_init1(2).
+const (
+ // IN_CLOEXEC is an alias for O_CLOEXEC. It indicates that the inotify
+ // fd should be closed on exec(2) and friends.
+ IN_CLOEXEC = 0x00080000
+	// IN_NONBLOCK is an alias for O_NONBLOCK. It indicates I/O syscalls on the
+ // inotify fd should not block.
+ IN_NONBLOCK = 0x00000800
+)
+
+// ALL_INOTIFY_BITS contains all the bits for all possible inotify events. It's
+// defined in the Linux source at "include/linux/inotify.h".
+const ALL_INOTIFY_BITS = IN_ACCESS | IN_MODIFY | IN_ATTRIB | IN_CLOSE_WRITE |
+ IN_CLOSE_NOWRITE | IN_OPEN | IN_MOVED_FROM | IN_MOVED_TO | IN_CREATE |
+ IN_DELETE | IN_DELETE_SELF | IN_MOVE_SELF | IN_UNMOUNT | IN_Q_OVERFLOW |
+ IN_IGNORED | IN_ONLYDIR | IN_DONT_FOLLOW | IN_EXCL_UNLINK | IN_MASK_ADD |
+ IN_ISDIR | IN_ONESHOT
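
A small self-contained sketch of decoding an event mask read from an inotify fd, using the bit values above (re-declared locally so the example runs standalone).

package main

import "fmt"

const (
	inCreate = 0x00000100 // IN_CREATE
	inIsDir  = 0x40000000 // IN_ISDIR
)

func main() {
	mask := uint32(inCreate | inIsDir) // e.g. a mkdir in a watched directory
	if mask&inCreate != 0 {
		kind := "file"
		if mask&inIsDir != 0 {
			kind = "directory"
		}
		fmt.Println("created a", kind)
	}
}
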
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
new file mode 100644
index 000000000..2062e6a4b
--- /dev/null
+++ b/pkg/abi/linux/ioctl.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ioctl(2) requests provided by asm-generic/ioctls.h
+//
+// These are roughly ordered by request number (low byte).
+const (
+ TCGETS = 0x00005401
+ TCSETS = 0x00005402
+ TCSETSW = 0x00005403
+ TCSETSF = 0x00005404
+ TCSBRK = 0x00005409
+ TIOCEXCL = 0x0000540c
+ TIOCNXCL = 0x0000540d
+ TIOCSCTTY = 0x0000540e
+ TIOCGPGRP = 0x0000540f
+ TIOCSPGRP = 0x00005410
+ TIOCOUTQ = 0x00005411
+ TIOCSTI = 0x00005412
+ TIOCGWINSZ = 0x00005413
+ TIOCSWINSZ = 0x00005414
+ TIOCMGET = 0x00005415
+ TIOCMBIS = 0x00005416
+ TIOCMBIC = 0x00005417
+ TIOCMSET = 0x00005418
+ TIOCINQ = 0x0000541b
+ FIONREAD = TIOCINQ
+ FIONBIO = 0x00005421
+ TIOCSETD = 0x00005423
+ TIOCNOTTY = 0x00005422
+ TIOCGETD = 0x00005424
+ TCSBRKP = 0x00005425
+ TIOCSBRK = 0x00005427
+ TIOCCBRK = 0x00005428
+ TIOCGSID = 0x00005429
+ TIOCGPTN = 0x80045430
+ TIOCSPTLCK = 0x40045431
+ TIOCGDEV = 0x80045432
+ TIOCVHANGUP = 0x00005437
+ TCFLSH = 0x0000540b
+ TIOCCONS = 0x0000541d
+ TIOCSSERIAL = 0x0000541f
+ TIOCGEXCL = 0x80045440
+ TIOCGPTPEER = 0x80045441
+ TIOCGICOUNT = 0x0000545d
+ FIONCLEX = 0x00005450
+ FIOCLEX = 0x00005451
+ FIOASYNC = 0x00005452
+ FIOSETOWN = 0x00008901
+ SIOCSPGRP = 0x00008902
+ FIOGETOWN = 0x00008903
+ SIOCGPGRP = 0x00008904
+)
+
+// ioctl(2) requests provided by uapi/linux/sockios.h
+const (
+ SIOCGIFMEM = 0x891f
+ SIOCGIFPFLAGS = 0x8935
+ SIOCGMIIPHY = 0x8947
+ SIOCGMIIREG = 0x8948
+)
+
+// ioctl(2) directions. Used to calculate request numbers.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of the _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
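
A quick check that IOC reproduces a request number listed above: TIOCGPTN is _IOR('T', 0x30, unsigned int), i.e. dir=_IOC_READ, type='T', nr=0x30, size=4. The helper below re-implements IOC with the same shift constants.

package main

import "fmt"

const iocRead = 2 // _IOC_READ

func ioc(dir, typ, nr, size uint32) uint32 {
	// _IOC_NRSHIFT=0, _IOC_TYPESHIFT=8, _IOC_SIZESHIFT=16, _IOC_DIRSHIFT=30.
	return dir<<30 | size<<16 | typ<<8 | nr
}

func main() {
	fmt.Printf("%#08x\n", ioc(iocRead, 'T', 0x30, 4)) // 0x80045430 == TIOCGPTN
}
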
diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go
new file mode 100644
index 000000000..c59c9c136
--- /dev/null
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
+)
+
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/ip.go b/pkg/abi/linux/ip.go
new file mode 100644
index 000000000..ef6d1093e
--- /dev/null
+++ b/pkg/abi/linux/ip.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// IP protocols
+const (
+ IPPROTO_IP = 0
+ IPPROTO_ICMP = 1
+ IPPROTO_IGMP = 2
+ IPPROTO_IPIP = 4
+ IPPROTO_TCP = 6
+ IPPROTO_EGP = 8
+ IPPROTO_PUP = 12
+ IPPROTO_UDP = 17
+ IPPROTO_IDP = 22
+ IPPROTO_TP = 29
+ IPPROTO_DCCP = 33
+ IPPROTO_IPV6 = 41
+ IPPROTO_RSVP = 46
+ IPPROTO_GRE = 47
+ IPPROTO_ESP = 50
+ IPPROTO_AH = 51
+ IPPROTO_MTP = 92
+ IPPROTO_BEETPH = 94
+ IPPROTO_ENCAP = 98
+ IPPROTO_PIM = 103
+ IPPROTO_COMP = 108
+ IPPROTO_SCTP = 132
+ IPPROTO_UDPLITE = 136
+ IPPROTO_MPLS = 137
+ IPPROTO_RAW = 255
+)
+
+// Socket options from uapi/linux/in.h
+const (
+ IP_TOS = 1
+ IP_TTL = 2
+ IP_HDRINCL = 3
+ IP_OPTIONS = 4
+ IP_ROUTER_ALERT = 5
+ IP_RECVOPTS = 6
+ IP_RETOPTS = 7
+ IP_PKTINFO = 8
+ IP_PKTOPTIONS = 9
+ IP_MTU_DISCOVER = 10
+ IP_RECVERR = 11
+ IP_RECVTTL = 12
+ IP_RECVTOS = 13
+ IP_MTU = 14
+ IP_FREEBIND = 15
+ IP_IPSEC_POLICY = 16
+ IP_XFRM_POLICY = 17
+ IP_PASSSEC = 18
+ IP_TRANSPARENT = 19
+ IP_ORIGDSTADDR = 20
+ IP_RECVORIGDSTADDR = IP_ORIGDSTADDR
+ IP_MINTTL = 21
+ IP_NODEFRAG = 22
+ IP_CHECKSUM = 23
+ IP_BIND_ADDRESS_NO_PORT = 24
+ IP_RECVFRAGSIZE = 25
+ IP_MULTICAST_IF = 32
+ IP_MULTICAST_TTL = 33
+ IP_MULTICAST_LOOP = 34
+ IP_ADD_MEMBERSHIP = 35
+ IP_DROP_MEMBERSHIP = 36
+ IP_UNBLOCK_SOURCE = 37
+ IP_BLOCK_SOURCE = 38
+ IP_ADD_SOURCE_MEMBERSHIP = 39
+ IP_DROP_SOURCE_MEMBERSHIP = 40
+ IP_MSFILTER = 41
+ MCAST_JOIN_GROUP = 42
+ MCAST_BLOCK_SOURCE = 43
+ MCAST_UNBLOCK_SOURCE = 44
+ MCAST_LEAVE_GROUP = 45
+ MCAST_JOIN_SOURCE_GROUP = 46
+ MCAST_LEAVE_SOURCE_GROUP = 47
+ MCAST_MSFILTER = 48
+ IP_MULTICAST_ALL = 49
+ IP_UNICAST_IF = 50
+)
+
+// IP_MTU_DISCOVER values from uapi/linux/in.h
+const (
+ IP_PMTUDISC_DONT = 0
+ IP_PMTUDISC_WANT = 1
+ IP_PMTUDISC_DO = 2
+ IP_PMTUDISC_PROBE = 3
+ IP_PMTUDISC_INTERFACE = 4
+ IP_PMTUDISC_OMIT = 5
+)
+
+// Socket options from uapi/linux/in6.h
+const (
+ IPV6_ADDRFORM = 1
+ IPV6_2292PKTINFO = 2
+ IPV6_2292HOPOPTS = 3
+ IPV6_2292DSTOPTS = 4
+ IPV6_2292RTHDR = 5
+ IPV6_2292PKTOPTIONS = 6
+ IPV6_CHECKSUM = 7
+ IPV6_2292HOPLIMIT = 8
+ IPV6_NEXTHOP = 9
+ IPV6_FLOWINFO = 11
+ IPV6_UNICAST_HOPS = 16
+ IPV6_MULTICAST_IF = 17
+ IPV6_MULTICAST_HOPS = 18
+ IPV6_MULTICAST_LOOP = 19
+ IPV6_ADD_MEMBERSHIP = 20
+ IPV6_DROP_MEMBERSHIP = 21
+ IPV6_ROUTER_ALERT = 22
+ IPV6_MTU_DISCOVER = 23
+ IPV6_MTU = 24
+ IPV6_RECVERR = 25
+ IPV6_V6ONLY = 26
+ IPV6_JOIN_ANYCAST = 27
+ IPV6_LEAVE_ANYCAST = 28
+ IPV6_MULTICAST_ALL = 29
+ IPV6_FLOWLABEL_MGR = 32
+ IPV6_FLOWINFO_SEND = 33
+ IPV6_IPSEC_POLICY = 34
+ IPV6_XFRM_POLICY = 35
+ IPV6_HDRINCL = 36
+ IPV6_RECVPKTINFO = 49
+ IPV6_PKTINFO = 50
+ IPV6_RECVHOPLIMIT = 51
+ IPV6_HOPLIMIT = 52
+ IPV6_RECVHOPOPTS = 53
+ IPV6_HOPOPTS = 54
+ IPV6_RTHDRDSTOPTS = 55
+ IPV6_RECVRTHDR = 56
+ IPV6_RTHDR = 57
+ IPV6_RECVDSTOPTS = 58
+ IPV6_DSTOPTS = 59
+ IPV6_RECVPATHMTU = 60
+ IPV6_PATHMTU = 61
+ IPV6_DONTFRAG = 62
+ IPV6_RECVTCLASS = 66
+ IPV6_TCLASS = 67
+ IPV6_AUTOFLOWLABEL = 70
+ IPV6_ADDR_PREFERENCES = 72
+ IPV6_MINHOPCOUNT = 73
+ IPV6_ORIGDSTADDR = 74
+ IPV6_RECVORIGDSTADDR = IPV6_ORIGDSTADDR
+ IPV6_TRANSPARENT = 75
+ IPV6_UNICAST_IF = 76
+ IPV6_RECVFRAGSIZE = 77
+ IPV6_FREEBIND = 78
+)
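
A hedged usage sketch tying these option numbers to a call: set and read back IP_TTL on a UDP socket via golang.org/x/sys/unix (whose constants mirror the values above).

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer unix.Close(fd)

	if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL, 64); err != nil {
		fmt.Println(err)
		return
	}
	ttl, err := unix.GetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("ttl =", ttl)
}
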
diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go
new file mode 100644
index 000000000..22acd2d43
--- /dev/null
+++ b/pkg/abi/linux/ipc.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Control commands used with semctl, shmctl, and msgctl. Source:
+// include/uapi/linux/ipc.h.
+const (
+ IPC_RMID = 0
+ IPC_SET = 1
+ IPC_STAT = 2
+ IPC_INFO = 3
+)
+
+// Resource get request flags. Source: include/uapi/linux/ipc.h
+const (
+ IPC_CREAT = 00001000
+ IPC_EXCL = 00002000
+ IPC_NOWAIT = 00004000
+)
+
+const IPC_PRIVATE = 0
+
+// In Linux, amd64 does not enable CONFIG_ARCH_WANT_IPC_PARSE_VERSION, so SysV
+// IPC unconditionally uses the "new" 64-bit structures that are needed for
+// features like 32-bit UIDs.
+
+// IPCPerm is equivalent to struct ipc64_perm.
+type IPCPerm struct {
+ Key uint32
+ UID uint32
+ GID uint32
+ CUID uint32
+ CGID uint32
+ Mode uint16
+ _ uint16
+ Seq uint16
+ _ uint16
+ _ uint32
+ unused1 uint64
+ unused2 uint64
+}
diff --git a/pkg/abi/linux/limits.go b/pkg/abi/linux/limits.go
new file mode 100644
index 000000000..c74dfcd53
--- /dev/null
+++ b/pkg/abi/linux/limits.go
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Resources for getrlimit(2)/setrlimit(2)/prlimit(2).
+const (
+ RLIMIT_CPU = 0
+ RLIMIT_FSIZE = 1
+ RLIMIT_DATA = 2
+ RLIMIT_STACK = 3
+ RLIMIT_CORE = 4
+ RLIMIT_RSS = 5
+ RLIMIT_NPROC = 6
+ RLIMIT_NOFILE = 7
+ RLIMIT_MEMLOCK = 8
+ RLIMIT_AS = 9
+ RLIMIT_LOCKS = 10
+ RLIMIT_SIGPENDING = 11
+ RLIMIT_MSGQUEUE = 12
+ RLIMIT_NICE = 13
+ RLIMIT_RTPRIO = 14
+ RLIMIT_RTTIME = 15
+)
+
+// RLimit corresponds to Linux's struct rlimit.
+type RLimit struct {
+ // Cur specifies the soft limit.
+ Cur uint64
+ // Max specifies the hard limit.
+ Max uint64
+}
+
+const (
+ // RLimInfinity is RLIM_INFINITY on Linux.
+ RLimInfinity = ^uint64(0)
+
+ // DefaultStackSoftLimit is called _STK_LIM in Linux.
+ DefaultStackSoftLimit = 8 * 1024 * 1024
+
+ // DefaultNprocLimit is defined in kernel/fork.c:set_max_threads, and
+ // called MAX_THREADS / 2 in Linux.
+ DefaultNprocLimit = FUTEX_TID_MASK / 2
+
+ // DefaultNofileSoftLimit is called INR_OPEN_CUR in Linux.
+ DefaultNofileSoftLimit = 1024
+
+ // DefaultNofileHardLimit is called INR_OPEN_MAX in Linux.
+ DefaultNofileHardLimit = 4096
+
+ // DefaultMemlockLimit is called MLOCK_LIMIT in Linux.
+ DefaultMemlockLimit = 64 * 1024
+
+ // DefaultMsgqueueLimit is called MQ_BYTES_MAX in Linux.
+ DefaultMsgqueueLimit = 819200
+)
+
+// InitRLimits is a map of initial rlimits set by Linux in
+// include/asm-generic/resource.h.
+var InitRLimits = map[int]RLimit{
+ RLIMIT_CPU: {RLimInfinity, RLimInfinity},
+ RLIMIT_FSIZE: {RLimInfinity, RLimInfinity},
+ RLIMIT_DATA: {RLimInfinity, RLimInfinity},
+ RLIMIT_STACK: {DefaultStackSoftLimit, RLimInfinity},
+ RLIMIT_CORE: {0, RLimInfinity},
+ RLIMIT_RSS: {RLimInfinity, RLimInfinity},
+ RLIMIT_NPROC: {DefaultNprocLimit, DefaultNprocLimit},
+ RLIMIT_NOFILE: {DefaultNofileSoftLimit, DefaultNofileHardLimit},
+ RLIMIT_MEMLOCK: {DefaultMemlockLimit, DefaultMemlockLimit},
+ RLIMIT_AS: {RLimInfinity, RLimInfinity},
+ RLIMIT_LOCKS: {RLimInfinity, RLimInfinity},
+ RLIMIT_SIGPENDING: {0, 0},
+ RLIMIT_MSGQUEUE: {DefaultMsgqueueLimit, DefaultMsgqueueLimit},
+ RLIMIT_NICE: {0, 0},
+ RLIMIT_RTPRIO: {0, 0},
+ RLIMIT_RTTIME: {RLimInfinity, RLimInfinity},
+}
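
A usage sketch: reading the current RLIMIT_NOFILE soft and hard limits via golang.org/x/sys/unix, whose Rlimit type mirrors the RLimit struct above.

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	var rl unix.Rlimit
	if err := unix.Getrlimit(unix.RLIMIT_NOFILE, &rl); err != nil {
		fmt.Println(err)
		return
	}
	// With the defaults above this would print soft=1024 hard=4096.
	fmt.Printf("nofile soft=%d hard=%d\n", rl.Cur, rl.Max)
}
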
diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go
new file mode 100644
index 000000000..281acdbde
--- /dev/null
+++ b/pkg/abi/linux/linux.go
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linux contains the constants and types needed to interface with a Linux kernel.
+package linux
+
+// NumSoftIRQ is the number of software IRQs, exposed via /proc/stat.
+//
+// Defined in linux/interrupt.h.
+const NumSoftIRQ = 10
+
+// Sysinfo is the structure provided by sysinfo on Linux versions > 2.3.48.
+type Sysinfo struct {
+ Uptime int64
+ Loads [3]uint64
+ TotalRAM uint64
+ FreeRAM uint64
+ SharedRAM uint64
+ BufferRAM uint64
+ TotalSwap uint64
+ FreeSwap uint64
+ Procs uint16
+ _ [6]byte // Pad Procs to 64bits.
+ TotalHigh uint64
+ FreeHigh uint64
+ Unit uint32
+	// The _f field in the glibc version of Sysinfo has size 0 on AMD64.
+}
diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go
new file mode 100644
index 000000000..07cc1895e
--- /dev/null
+++ b/pkg/abi/linux/mm.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Protections for mmap(2).
+const (
+ PROT_NONE = 0
+ PROT_READ = 1 << 0
+ PROT_WRITE = 1 << 1
+ PROT_EXEC = 1 << 2
+ PROT_SEM = 1 << 3
+ PROT_GROWSDOWN = 1 << 24
+ PROT_GROWSUP = 1 << 25
+)
+
+// Flags for mmap(2).
+const (
+ MAP_SHARED = 1 << 0
+ MAP_PRIVATE = 1 << 1
+ MAP_FIXED = 1 << 4
+ MAP_ANONYMOUS = 1 << 5
+ MAP_32BIT = 1 << 6 // arch/x86/include/uapi/asm/mman.h
+ MAP_GROWSDOWN = 1 << 8
+ MAP_DENYWRITE = 1 << 11
+ MAP_EXECUTABLE = 1 << 12
+ MAP_LOCKED = 1 << 13
+ MAP_NORESERVE = 1 << 14
+ MAP_POPULATE = 1 << 15
+ MAP_NONBLOCK = 1 << 16
+ MAP_STACK = 1 << 17
+ MAP_HUGETLB = 1 << 18
+)
+
+// Flags for mremap(2).
+const (
+ MREMAP_MAYMOVE = 1 << 0
+ MREMAP_FIXED = 1 << 1
+)
+
+// Flags for mlock2(2).
+const (
+ MLOCK_ONFAULT = 0x01
+)
+
+// Flags for mlockall(2).
+const (
+ MCL_CURRENT = 1
+ MCL_FUTURE = 2
+ MCL_ONFAULT = 4
+)
+
+// Advice for madvise(2).
+const (
+ MADV_NORMAL = 0
+ MADV_RANDOM = 1
+ MADV_SEQUENTIAL = 2
+ MADV_WILLNEED = 3
+ MADV_DONTNEED = 4
+ MADV_REMOVE = 9
+ MADV_DONTFORK = 10
+ MADV_DOFORK = 11
+ MADV_MERGEABLE = 12
+ MADV_UNMERGEABLE = 13
+ MADV_HUGEPAGE = 14
+ MADV_NOHUGEPAGE = 15
+ MADV_DONTDUMP = 16
+ MADV_DODUMP = 17
+ MADV_HWPOISON = 100
+ MADV_SOFT_OFFLINE = 101
+ MADV_NOMAJFAULT = 200
+ MADV_DONTCHGME = 201
+)
+
+// Flags for msync(2).
+const (
+ MS_ASYNC = 1 << 0
+ MS_INVALIDATE = 1 << 1
+ MS_SYNC = 1 << 2
+)
+
+// NumaPolicy is the NUMA memory policy for a memory range. See numa(7).
+//
+// +marshal
+type NumaPolicy int32
+
+// Policies for get_mempolicy(2)/set_mempolicy(2).
+const (
+ MPOL_DEFAULT NumaPolicy = 0
+ MPOL_PREFERRED NumaPolicy = 1
+ MPOL_BIND NumaPolicy = 2
+ MPOL_INTERLEAVE NumaPolicy = 3
+ MPOL_LOCAL NumaPolicy = 4
+ MPOL_MAX NumaPolicy = 5
+)
+
+// Flags for get_mempolicy(2).
+const (
+ MPOL_F_NODE = 1 << 0
+ MPOL_F_ADDR = 1 << 1
+ MPOL_F_MEMS_ALLOWED = 1 << 2
+)
+
+// Flags for set_mempolicy(2).
+const (
+ MPOL_F_RELATIVE_NODES = 1 << 14
+ MPOL_F_STATIC_NODES = 1 << 15
+
+ MPOL_MODE_FLAGS = (MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES)
+)
+
+// Flags for mbind(2).
+const (
+ MPOL_MF_STRICT = 1 << 0
+ MPOL_MF_MOVE = 1 << 1
+ MPOL_MF_MOVE_ALL = 1 << 2
+
+ MPOL_MF_VALID = MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL
+)
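
A usage sketch mapping the PROT_*/MAP_* values above onto a call: an anonymous, private, read-write mapping via golang.org/x/sys/unix.

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	b, err := unix.Mmap(-1, 0, 4096,
		unix.PROT_READ|unix.PROT_WRITE,      // PROT_READ|PROT_WRITE
		unix.MAP_ANONYMOUS|unix.MAP_PRIVATE) // not backed by a file
	if err != nil {
		fmt.Println(err)
		return
	}
	defer unix.Munmap(b)
	b[0] = 42
	fmt.Println("mapped", len(b), "bytes; first byte =", b[0])
}
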
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
new file mode 100644
index 000000000..7866352b4
--- /dev/null
+++ b/pkg/abi/linux/netdevice.go
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "gvisor.dev/gvisor/pkg/binary"
+
+const (
+ // IFNAMSIZ is the size of the name field for IFReq.
+ IFNAMSIZ = 16
+)
+
+// IFReq is an interface request.
+type IFReq struct {
+ // IFName is an encoded name, normally null-terminated. This should be
+ // accessed via the Name and SetName functions.
+ IFName [IFNAMSIZ]byte
+
+ // Data is the union of the following structures:
+ //
+ // struct sockaddr ifr_addr;
+ // struct sockaddr ifr_dstaddr;
+ // struct sockaddr ifr_broadaddr;
+ // struct sockaddr ifr_netmask;
+ // struct sockaddr ifr_hwaddr;
+ // short ifr_flags;
+ // int ifr_ifindex;
+ // int ifr_metric;
+ // int ifr_mtu;
+ // struct ifmap ifr_map;
+ // char ifr_slave[IFNAMSIZ];
+ // char ifr_newname[IFNAMSIZ];
+ // char *ifr_data;
+ Data [24]byte
+}
+
+// Name returns the name.
+func (ifr *IFReq) Name() string {
+ for c := 0; c < len(ifr.IFName); c++ {
+ if ifr.IFName[c] == 0 {
+ return string(ifr.IFName[:c])
+ }
+ }
+ return string(ifr.IFName[:])
+}
+
+// SetName sets the name.
+func (ifr *IFReq) SetName(name string) {
+ n := copy(ifr.IFName[:], []byte(name))
+ for i := n; i < len(ifr.IFName); i++ {
+ ifr.IFName[i] = 0
+ }
+}
+
+// SizeOfIFReq is the binary size of an IFReq struct (40 bytes).
+var SizeOfIFReq = binary.Size(IFReq{})
+
+// IFMap contains interface hardware parameters.
+type IFMap struct {
+ MemStart uint64
+ MemEnd uint64
+ BaseAddr int16
+ IRQ byte
+ DMA byte
+ Port byte
+ _ [3]byte // Pad to sizeof(struct ifmap).
+}
+
+// IFConf is used to return a list of interfaces and their addresses. See
+// netdevice(7) and struct ifconf for more detail on its use.
+type IFConf struct {
+ Len int32
+ _ [4]byte // Pad to sizeof(struct ifconf).
+ Ptr uint64
+}
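
A self-contained sketch of the IFName encoding that Name and SetName implement above: the name lives NUL-padded in a fixed IFNAMSIZ-byte array, and reads stop at the first NUL.

package main

import "fmt"

const ifNameSize = 16 // IFNAMSIZ

type ifReq struct{ name [ifNameSize]byte }

func (r *ifReq) SetName(s string) {
	n := copy(r.name[:], s)
	for i := n; i < len(r.name); i++ {
		r.name[i] = 0 // NUL-pad the remainder
	}
}

func (r *ifReq) Name() string {
	for i, c := range r.name {
		if c == 0 {
			return string(r.name[:i])
		}
	}
	return string(r.name[:])
}

func main() {
	var r ifReq
	r.SetName("eth0")
	fmt.Println(r.Name()) // eth0
}
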
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
new file mode 100644
index 000000000..46d8b0b42
--- /dev/null
+++ b/pkg/abi/linux/netfilter.go
@@ -0,0 +1,552 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// This file contains structures required to support netfilter, specifically
+// the iptables tool.
+
+// Hooks into the network stack. These correspond to values in
+// include/uapi/linux/netfilter.h.
+const (
+ NF_INET_PRE_ROUTING = 0
+ NF_INET_LOCAL_IN = 1
+ NF_INET_FORWARD = 2
+ NF_INET_LOCAL_OUT = 3
+ NF_INET_POST_ROUTING = 4
+ NF_INET_NUMHOOKS = 5
+)
+
+// Verdicts that can be returned by targets. These correspond to values in
+// include/uapi/linux/netfilter.h
+const (
+ NF_DROP = 0
+ NF_ACCEPT = 1
+ NF_STOLEN = 2
+ NF_QUEUE = 3
+ NF_REPEAT = 4
+ NF_STOP = 5
+ NF_MAX_VERDICT = NF_STOP
+ // NF_RETURN is defined in include/uapi/linux/netfilter/x_tables.h.
+ NF_RETURN = -NF_REPEAT - 1
+)
+
+// VerdictStrings maps int verdicts to the strings they represent. It is used
+// for debugging.
+var VerdictStrings = map[int32]string{
+ -NF_DROP - 1: "DROP",
+ -NF_ACCEPT - 1: "ACCEPT",
+ -NF_QUEUE - 1: "QUEUE",
+ NF_RETURN: "RETURN",
+}
+
+// Socket options. These correspond to values in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+const (
+ IPT_BASE_CTL = 64
+ IPT_SO_SET_REPLACE = IPT_BASE_CTL
+ IPT_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1
+ IPT_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS
+
+ IPT_SO_GET_INFO = IPT_BASE_CTL
+ IPT_SO_GET_ENTRIES = IPT_BASE_CTL + 1
+ IPT_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 2
+ IPT_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 3
+ IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET
+)
+
+// Name lengths. These correspond to values in
+// include/uapi/linux/netfilter/x_tables.h.
+const (
+ XT_FUNCTION_MAXNAMELEN = 30
+ XT_EXTENSION_MAXNAMELEN = 29
+ XT_TABLE_MAXNAMELEN = 32
+)
+
+// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTEntry struct {
+ // IP is used to filter packets based on the IP header.
+ IP IPTIP
+
+ // NFCache relates to kernel-internal caching and isn't used by
+ // userspace.
+ NFCache uint32
+
+ // TargetOffset is the byte offset from the beginning of this IPTEntry
+ // to the start of the entry's target.
+ TargetOffset uint16
+
+ // NextOffset is the byte offset from the beginning of this IPTEntry to
+ // the start of the next entry. It is thus also the size of the entry.
+ NextOffset uint16
+
+ // Comeback is a return pointer. It is not used by userspace.
+ Comeback uint32
+
+ // Counters holds the packet and byte counts for this rule.
+ Counters XTCounters
+
+ // Elems holds the data for all this rule's matches followed by the
+ // target. It is variable length -- users have to iterate over any
+ // matches and use TargetOffset and NextOffset to make sense of the
+ // data.
+ //
+ // Elems is omitted here because it would cause IPTEntry to be an extra
+ // byte larger (see http://www.catb.org/esr/structure-packing/).
+ //
+ // Elems [0]byte
+}
+
+// SizeOfIPTEntry is the size of an IPTEntry.
+const SizeOfIPTEntry = 112
+
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
+// struct is marshaled via the binary package to write an IPTEntry to userspace.
+type KernelIPTEntry struct {
+ IPTEntry
+
+ // Elems holds the data for all this rule's matches followed by the
+ // target. It is variable length -- users have to iterate over any
+ // matches and use TargetOffset and NextOffset to make sense of the
+ // data.
+ Elems []byte
+}
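
A hedged sketch (hypothetical helper, not gVisor code) of the iteration the comments above describe: walk a serialized entry table by advancing NextOffset bytes per entry, with TargetOffset locating the target inside each entry. The field offsets assume the layout above (an 84-byte IPTIP followed by a 4-byte NFCache).

package main

import (
	"encoding/binary"
	"fmt"
)

const (
	targetOffsetOff = 88 // offset of TargetOffset within a serialized entry
	nextOffsetOff   = 90 // offset of NextOffset within a serialized entry
)

func walkEntries(table []byte) {
	for off := 0; off+nextOffsetOff+2 <= len(table); {
		entry := table[off:]
		targetOff := binary.LittleEndian.Uint16(entry[targetOffsetOff:])
		nextOff := binary.LittleEndian.Uint16(entry[nextOffsetOff:])
		fmt.Printf("entry at %d: target at +%d, size %d\n", off, targetOff, nextOff)
		if nextOff == 0 {
			return // malformed entry; avoid looping forever
		}
		off += int(nextOff)
	}
}

func main() {
	// One fake 112-byte entry (SizeOfIPTEntry) with no matches, so the
	// target starts at the end and NextOffset equals the entry size.
	entry := make([]byte, 112)
	binary.LittleEndian.PutUint16(entry[targetOffsetOff:], 112)
	binary.LittleEndian.PutUint16(entry[nextOffsetOff:], 112)
	walkEntries(entry)
}
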
+
+// IPTIP contains information for matching a packet's IP header.
+// It corresponds to struct ipt_ip in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTIP struct {
+ // Src is the source IP address.
+ Src InetAddr
+
+ // Dst is the destination IP address.
+ Dst InetAddr
+
+ // SrcMask is the source IP mask.
+ SrcMask InetAddr
+
+ // DstMask is the destination IP mask.
+ DstMask InetAddr
+
+ // InputInterface is the input network interface.
+ InputInterface [IFNAMSIZ]byte
+
+ // OutputInterface is the output network interface.
+ OutputInterface [IFNAMSIZ]byte
+
+ // InputInterfaceMask is the input interface mask.
+ InputInterfaceMask [IFNAMSIZ]byte
+
+	// OutputInterfaceMask is the output interface mask.
+ OutputInterfaceMask [IFNAMSIZ]byte
+
+ // Protocol is the transport protocol.
+ Protocol uint16
+
+ // Flags define matching behavior for the IP header.
+ Flags uint8
+
+ // InverseFlags invert the meaning of fields in struct IPTIP. See the
+ // IPT_INV_* flags.
+ InverseFlags uint8
+}
+
+// Flags in IPTIP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+const (
+ // Invert the meaning of InputInterface.
+ IPT_INV_VIA_IN = 0x01
+ // Invert the meaning of OutputInterface.
+ IPT_INV_VIA_OUT = 0x02
+ // Unclear what this is, as no references to it exist in the kernel.
+ IPT_INV_TOS = 0x04
+ // Invert the meaning of Src.
+ IPT_INV_SRCIP = 0x08
+ // Invert the meaning of Dst.
+ IPT_INV_DSTIP = 0x10
+ // Invert the meaning of the IPT_F_FRAG flag.
+ IPT_INV_FRAG = 0x20
+ // Invert the meaning of the Protocol field.
+ IPT_INV_PROTO = 0x40
+ // Enable all flags.
+ IPT_INV_MASK = 0x7F
+)
+
+// SizeOfIPTIP is the size of an IPTIP.
+const SizeOfIPTIP = 84
+
+// XTCounters holds packet and byte counts for a rule. It corresponds to struct
+// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+type XTCounters struct {
+ // Pcnt is the packet count.
+ Pcnt uint64
+
+ // Bcnt is the byte count.
+ Bcnt uint64
+}
+
+// SizeOfXTCounters is the size of an XTCounters.
+const SizeOfXTCounters = 16
+
+// XTEntryMatch holds a match for a rule. For example, a user using the
+// addrtype iptables match extension would put the data for that match into an
+// XTEntryMatch. iptables-extensions(8) has a list of possible matches.
+//
+// XTEntryMatch corresponds to struct xt_entry_match in
+// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
+// exposing different data to the user and kernel, but this struct holds only
+// the user data.
+type XTEntryMatch struct {
+ MatchSize uint16
+ Name ExtensionName
+ Revision uint8
+ // Data is omitted here because it would cause XTEntryMatch to be an
+ // extra byte larger (see http://www.catb.org/esr/structure-packing/).
+ // Data [0]byte
+}
+
+// SizeOfXTEntryMatch is the size of an XTEntryMatch.
+const SizeOfXTEntryMatch = 32
+
+// KernelXTEntryMatch is identical to XTEntryMatch, but contains
+// a variable-length Data field.
+type KernelXTEntryMatch struct {
+ XTEntryMatch
+ Data []byte
+}
+
+// XTEntryTarget holds a target for a rule. For example, it can specify that
+// packets matching the rule should DROP, ACCEPT, or use an extension target.
+// iptables-extension(8) has a list of possible targets.
+//
+// XTEntryTarget corresponds to struct xt_entry_target in
+// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
+// exposing different data to the user and kernel, but this struct holds only
+// the user data.
+type XTEntryTarget struct {
+ TargetSize uint16
+ Name ExtensionName
+ Revision uint8
+ // Data is omitted here because it would cause XTEntryTarget to be an
+ // extra byte larger (see http://www.catb.org/esr/structure-packing/).
+ // Data [0]byte
+}
+
+// SizeOfXTEntryTarget is the size of an XTEntryTarget.
+const SizeOfXTEntryTarget = 32
+
+// XTStandardTarget is a built-in target, one of ACCEPT, DROP, QUEUE, or
+// RETURN, or a jump to another rule. It corresponds to struct
+// xt_standard_target in
+// include/uapi/linux/netfilter/x_tables.h.
+type XTStandardTarget struct {
+ Target XTEntryTarget
+ // A positive verdict indicates a jump, and is the offset from the
+ // start of the table to jump to. A negative value means one of the
+ // other built-in targets.
+ Verdict int32
+ _ [4]byte
+}
+
+// SizeOfXTStandardTarget is the size of an XTStandardTarget.
+const SizeOfXTStandardTarget = 40
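
A small sketch of decoding Verdict per the comment above: non-negative values are jump offsets, while a negative value v encodes the built-in verdict -v-1 (so -1 is NF_DROP=0 and -2 is NF_ACCEPT=1, matching the VerdictStrings keys earlier in this file).

package main

import "fmt"

func decodeVerdict(v int32) string {
	if v >= 0 {
		return fmt.Sprintf("jump to offset %d", v)
	}
	switch -v - 1 {
	case 0: // NF_DROP
		return "DROP"
	case 1: // NF_ACCEPT
		return "ACCEPT"
	case 3: // NF_QUEUE
		return "QUEUE"
	default:
		return fmt.Sprintf("builtin verdict %d", -v-1)
	}
}

func main() {
	fmt.Println(decodeVerdict(-2))  // ACCEPT
	fmt.Println(decodeVerdict(112)) // jump to offset 112
}
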
+
+// XTErrorTarget triggers an error when reached. It is also used to mark the
+// beginning of user-defined chains by putting the name of the chain in
+// ErrorName. It corresponds to struct xt_error_target in
+// include/uapi/linux/netfilter/x_tables.h.
+type XTErrorTarget struct {
+ Target XTEntryTarget
+ Name ErrorName
+ _ [2]byte
+}
+
+// SizeOfXTErrorTarget is the size of an XTErrorTarget.
+const SizeOfXTErrorTarget = 64
+
+// Flag values for NfNATIPV4Range. The values indicate whether to map the
+// protocol-specific part (ports) or the IPs. They correspond to values in
+// include/uapi/linux/netfilter/nf_nat.h.
+const (
+ NF_NAT_RANGE_MAP_IPS = 1 << 0
+ NF_NAT_RANGE_PROTO_SPECIFIED = 1 << 1
+ NF_NAT_RANGE_PROTO_RANDOM = 1 << 2
+ NF_NAT_RANGE_PERSISTENT = 1 << 3
+ NF_NAT_RANGE_PROTO_RANDOM_FULLY = 1 << 4
+ NF_NAT_RANGE_PROTO_RANDOM_ALL = (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+ NF_NAT_RANGE_MASK = (NF_NAT_RANGE_MAP_IPS |
+ NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PROTO_RANDOM |
+ NF_NAT_RANGE_PERSISTENT | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+)
+
+// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range
+// in include/uapi/linux/netfilter/nf_nat.h. The fields are in
+// network byte order.
+type NfNATIPV4Range struct {
+ Flags uint32
+ MinIP [4]byte
+ MaxIP [4]byte
+ MinPort uint16
+ MaxPort uint16
+}
+
+// NfNATIPV4MultiRangeCompat corresponds to struct
+// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h.
+type NfNATIPV4MultiRangeCompat struct {
+ RangeSize uint32
+ RangeIPV4 NfNATIPV4Range
+}
+
+// XTRedirectTarget triggers a redirect when reached.
+// The trailing 4 bytes of padding make the struct 8-byte aligned.
+type XTRedirectTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
+const SizeOfXTRedirectTarget = 56
+
+// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
+// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTGetinfo struct {
+ Name TableName
+ ValidHooks uint32
+ HookEntry [NF_INET_NUMHOOKS]uint32
+ Underflow [NF_INET_NUMHOOKS]uint32
+ NumEntries uint32
+ Size uint32
+}
+
+// SizeOfIPTGetinfo is the size of an IPTGetinfo.
+const SizeOfIPTGetinfo = 84
+
+// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
+// corresponds to struct ipt_get_entries in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTGetEntries struct {
+ Name TableName
+ Size uint32
+ _ [4]byte
+ // Entrytable is omitted here because it would cause IPTGetEntries to
+ // be an extra byte longer (see
+ // http://www.catb.org/esr/structure-packing/).
+ // Entrytable [0]IPTEntry
+}
+
+// SizeOfIPTGetEntries is the size of an IPTGetEntries.
+const SizeOfIPTGetEntries = 40
+
+// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
+// Entrytable field. This struct is marshaled via the binary package to write
+// a KernelIPTGetEntries to userspace.
+type KernelIPTGetEntries struct {
+ IPTGetEntries
+ Entrytable []KernelIPTEntry
+}
+
+// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
+// corresponds to struct ipt_replace in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTReplace struct {
+ Name TableName
+ ValidHooks uint32
+ NumEntries uint32
+ Size uint32
+ HookEntry [NF_INET_NUMHOOKS]uint32
+ Underflow [NF_INET_NUMHOOKS]uint32
+ NumCounters uint32
+ Counters uint64 // This is really a *XTCounters.
+ // Entries is omitted here because it would cause IPTReplace to be an
+ // extra byte longer (see http://www.catb.org/esr/structure-packing/).
+ // Entries [0]IPTEntry
+}
+
+// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
+type KernelIPTReplace struct {
+ IPTReplace
+ Entries [0]IPTEntry
+}
+
+// SizeOfIPTReplace is the size of an IPTReplace.
+const SizeOfIPTReplace = 96
+
+// ExtensionName holds the name of a netfilter extension.
+type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ExtensionName) String() string {
+ return goString(en[:])
+}
+
+// TableName holds the name of a netfilter table.
+type TableName [XT_TABLE_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (tn TableName) String() string {
+ return goString(tn[:])
+}
+
+// ErrorName holds the name of a netfilter error. These can also hold
+// user-defined chains.
+type ErrorName [XT_FUNCTION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ErrorName) String() string {
+ return goString(en[:])
+}
+
+func goString(cstring []byte) string {
+ for i, c := range cstring {
+ if c == 0 {
+ return string(cstring[:i])
+ }
+ }
+ return string(cstring)
+}
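+
+// The inverse of goString is useful when constructing these fixed-size
+// names. A minimal sketch (the helper is ours, not part of the ABI): copy
+// truncates at the array length, and the remaining bytes stay zero, which
+// provides the NUL terminator for short names.
+func tableNameOf(name string) TableName {
+ var tn TableName
+ copy(tn[:], name)
+ return tn
+}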
+
+// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTTCP struct {
+ // SourcePortStart specifies the inclusive start of the range of source
+ // ports to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd specifies the inclusive end of the range of source ports
+ // to which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart specifies the start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd specifies the end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // Option specifies that a particular TCP option must be set.
+ Option uint8
+
+ // FlagMask masks TCP flags when comparing to the FlagCompare byte. It
+ // selects which flags are relevant to the matcher.
+ FlagMask uint8
+
+ // FlagCompare, in combination with FlagMask, is used to match only packets
+ // that have certain flags set.
+ FlagCompare uint8
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // XT_TCP_INV_* flags.
+ InverseFlags uint8
+}
+
+// SizeOfXTTCP is the size of an XTTCP.
+const SizeOfXTTCP = 12
+
+// Flags in XTTCP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_TCP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_TCP_INV_DSTPT = 0x02
+ // Invert the meaning of FlagCompare.
+ XT_TCP_INV_FLAGS = 0x04
+ // Invert the meaning of Option.
+ XT_TCP_INV_OPTION = 0x08
+ // XT_TCP_INV_MASK is the mask of all inversion flags.
+ XT_TCP_INV_MASK = 0x0F
+)
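+
+// A minimal sketch of how FlagMask, FlagCompare, and XT_TCP_INV_FLAGS
+// combine (the helper is ours; it mirrors the kernel's TCP match logic
+// under that assumption): the masked packet flags must equal FlagCompare,
+// unless the flag comparison is inverted.
+func tcpFlagsMatch(m XTTCP, packetFlags uint8) bool {
+ matched := packetFlags&m.FlagMask == m.FlagCompare
+ inverted := m.InverseFlags&XT_TCP_INV_FLAGS != 0
+ return matched != inverted
+}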
+
+// XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTUDP struct {
+ // SourcePortStart is the inclusive start of the range of source ports
+ // to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd is the inclusive end of the range of source ports to
+ // which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart is the inclusive start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd is the inclusive end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // XT_UDP_INV_* flags.
+ InverseFlags uint8
+
+ _ uint8
+}
+
+// SizeOfXTUDP is the size of an XTUDP.
+const SizeOfXTUDP = 10
+
+// Flags in XTUDP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_UDP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_UDP_INV_DSTPT = 0x02
+ // XT_UDP_INV_MASK is the mask of all inversion flags.
+ XT_UDP_INV_MASK = 0x03
+)
+
+// IPTOwnerInfo holds data for matching packets by owner. It corresponds to
+// struct ipt_owner_info in libxt_owner.c of the iptables binary.
+type IPTOwnerInfo struct {
+ // UID is the user ID of the process that created the packet.
+ UID uint32
+
+ // GID is the group ID of the process that created the packet.
+ GID uint32
+
+ // PID is the ID of the process that created the packet.
+ PID uint32
+
+ // SID is the session ID of the process that created the packet.
+ SID uint32
+
+ // Comm is the command name of the process that created the packet.
+ Comm [16]byte
+
+ // Match is used to match UID/GID of the socket. See the
+ // XT_OWNER_* flags below.
+ Match uint8
+
+ // Invert flips the meaning of the Match field.
+ Invert uint8
+}
+
+// SizeOfIPTOwnerInfo is the size of an IPTOwnerInfo.
+const SizeOfIPTOwnerInfo = 34
+
+// Flags in IPTOwnerInfo.Match. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_owner.h.
+const (
+ // Match the UID of the packet.
+ XT_OWNER_UID = 1 << 0
+ // Match the GID of the packet.
+ XT_OWNER_GID = 1 << 1
+ // Match if the socket exists for the packet. Forwarded
+ // packets do not have an associated socket.
+ XT_OWNER_SOCKET = 1 << 2
+)
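+
+// An illustrative sketch of UID matching with IPTOwnerInfo (the helper is
+// ours; the Match/Invert bit semantics follow xt_owner as an assumption):
+// the UID comparison applies only when XT_OWNER_UID is set in Match, and
+// the same bit in Invert flips the result.
+func ownerUIDMatches(info IPTOwnerInfo, uid uint32) bool {
+ if info.Match&XT_OWNER_UID == 0 {
+ // UID matching was not requested.
+ return true
+ }
+ matched := info.UID == uid
+ if info.Invert&XT_OWNER_UID != 0 {
+ matched = !matched
+ }
+ return matched
+}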
diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go
new file mode 100644
index 000000000..565dd550e
--- /dev/null
+++ b/pkg/abi/linux/netfilter_test.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
+func TestSizes(t *testing.T) {
+ testCases := []struct {
+ typ interface{}
+ defined uintptr
+ }{
+ {IPTEntry{}, SizeOfIPTEntry},
+ {IPTGetEntries{}, SizeOfIPTGetEntries},
+ {IPTGetinfo{}, SizeOfIPTGetinfo},
+ {IPTIP{}, SizeOfIPTIP},
+ {IPTOwnerInfo{}, SizeOfIPTOwnerInfo},
+ {IPTReplace{}, SizeOfIPTReplace},
+ {XTCounters{}, SizeOfXTCounters},
+ {XTEntryMatch{}, SizeOfXTEntryMatch},
+ {XTEntryTarget{}, SizeOfXTEntryTarget},
+ {XTErrorTarget{}, SizeOfXTErrorTarget},
+ {XTStandardTarget{}, SizeOfXTStandardTarget},
+ }
+
+ for _, tc := range testCases {
+ if calculated := binary.Size(tc.typ); calculated != tc.defined {
+ t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated)
+ }
+ }
+}
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
new file mode 100644
index 000000000..0ba086c76
--- /dev/null
+++ b/pkg/abi/linux/netlink.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Netlink protocols, from uapi/linux/netlink.h.
+const (
+ NETLINK_ROUTE = 0
+ NETLINK_UNUSED = 1
+ NETLINK_USERSOCK = 2
+ NETLINK_FIREWALL = 3
+ NETLINK_SOCK_DIAG = 4
+ NETLINK_NFLOG = 5
+ NETLINK_XFRM = 6
+ NETLINK_SELINUX = 7
+ NETLINK_ISCSI = 8
+ NETLINK_AUDIT = 9
+ NETLINK_FIB_LOOKUP = 10
+ NETLINK_CONNECTOR = 11
+ NETLINK_NETFILTER = 12
+ NETLINK_IP6_FW = 13
+ NETLINK_DNRTMSG = 14
+ NETLINK_KOBJECT_UEVENT = 15
+ NETLINK_GENERIC = 16
+ NETLINK_SCSITRANSPORT = 18
+ NETLINK_ECRYPTFS = 19
+ NETLINK_RDMA = 20
+ NETLINK_CRYPTO = 21
+)
+
+// SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h.
+type SockAddrNetlink struct {
+ Family uint16
+ _ uint16
+ PortID uint32
+ Groups uint32
+}
+
+// SockAddrNetlinkSize is the size of SockAddrNetlink.
+const SockAddrNetlinkSize = 12
+
+// NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h.
+type NetlinkMessageHeader struct {
+ Length uint32
+ Type uint16
+ Flags uint16
+ Seq uint32
+ PortID uint32
+}
+
+// NetlinkMessageHeaderSize is the size of NetlinkMessageHeader.
+const NetlinkMessageHeaderSize = 16
+
+// Netlink message header flags, from uapi/linux/netlink.h.
+const (
+ NLM_F_REQUEST = 0x1
+ NLM_F_MULTI = 0x2
+ NLM_F_ACK = 0x4
+ NLM_F_ECHO = 0x8
+ NLM_F_DUMP_INTR = 0x10
+
+ // Modifiers to GET requests.
+ NLM_F_ROOT = 0x100
+ NLM_F_MATCH = 0x200
+ NLM_F_ATOMIC = 0x400
+ NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
+
+ // Modifiers to NEW requests; these deliberately reuse the GET modifier
+ // values and are disambiguated by the message type.
+ NLM_F_REPLACE = 0x100
+ NLM_F_EXCL = 0x200
+ NLM_F_CREATE = 0x400
+ NLM_F_APPEND = 0x800
+)
+
+// Standard netlink message types, from uapi/linux/netlink.h.
+const (
+ NLMSG_NOOP = 0x1
+ NLMSG_ERROR = 0x2
+ NLMSG_DONE = 0x3
+ NLMSG_OVERRUN = 0x4
+
+ // NLMSG_MIN_TYPE is the first value for protocol-level types.
+ NLMSG_MIN_TYPE = 0x10
+)
+
+// NLMSG_ALIGNTO is the alignment of netlink messages, from
+// uapi/linux/netlink.h.
+const NLMSG_ALIGNTO = 4
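+
+// Message lengths are rounded up to a multiple of NLMSG_ALIGNTO. A sketch
+// of the standard NLMSG_ALIGN computation (the helper name is ours):
+func nlMsgAlign(n uint32) uint32 {
+ return (n + NLMSG_ALIGNTO - 1) &^ (NLMSG_ALIGNTO - 1)
+}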
+
+// NetlinkAttrHeader is the header of a netlink attribute, followed by payload.
+//
+// This is struct nlattr, from uapi/linux/netlink.h.
+type NetlinkAttrHeader struct {
+ Length uint16
+ Type uint16
+}
+
+// NetlinkAttrHeaderSize is the size of NetlinkAttrHeader.
+const NetlinkAttrHeaderSize = 4
+
+// NLA_ALIGNTO is the alignment of netlink attributes, from
+// uapi/linux/netlink.h.
+const NLA_ALIGNTO = 4
+
+// Socket options, from uapi/linux/netlink.h.
+const (
+ NETLINK_ADD_MEMBERSHIP = 1
+ NETLINK_DROP_MEMBERSHIP = 2
+ NETLINK_PKTINFO = 3
+ NETLINK_BROADCAST_ERROR = 4
+ NETLINK_NO_ENOBUFS = 5
+ NETLINK_LISTEN_ALL_NSID = 8
+ NETLINK_LIST_MEMBERSHIPS = 9
+ NETLINK_CAP_ACK = 10
+ NETLINK_EXT_ACK = 11
+ NETLINK_DUMP_STRICT_CHK = 12
+)
+
+// NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h.
+type NetlinkErrorMessage struct {
+ Error int32
+ Header NetlinkMessageHeader
+}
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
new file mode 100644
index 000000000..40bec566c
--- /dev/null
+++ b/pkg/abi/linux/netlink_route.go
@@ -0,0 +1,346 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Netlink message types for NETLINK_ROUTE sockets, from uapi/linux/rtnetlink.h.
+const (
+ RTM_NEWLINK = 16
+ RTM_DELLINK = 17
+ RTM_GETLINK = 18
+ RTM_SETLINK = 19
+
+ RTM_NEWADDR = 20
+ RTM_DELADDR = 21
+ RTM_GETADDR = 22
+
+ RTM_NEWROUTE = 24
+ RTM_DELROUTE = 25
+ RTM_GETROUTE = 26
+
+ RTM_NEWNEIGH = 28
+ RTM_DELNEIGH = 29
+ RTM_GETNEIGH = 30
+
+ RTM_NEWRULE = 32
+ RTM_DELRULE = 33
+ RTM_GETRULE = 34
+
+ RTM_NEWQDISC = 36
+ RTM_DELQDISC = 37
+ RTM_GETQDISC = 38
+
+ RTM_NEWTCLASS = 40
+ RTM_DELTCLASS = 41
+ RTM_GETTCLASS = 42
+
+ RTM_NEWTFILTER = 44
+ RTM_DELTFILTER = 45
+ RTM_GETTFILTER = 46
+
+ RTM_NEWACTION = 48
+ RTM_DELACTION = 49
+ RTM_GETACTION = 50
+
+ RTM_NEWPREFIX = 52
+
+ RTM_GETMULTICAST = 58
+
+ RTM_GETANYCAST = 62
+
+ RTM_NEWNEIGHTBL = 64
+ RTM_GETNEIGHTBL = 66
+ RTM_SETNEIGHTBL = 67
+
+ RTM_NEWNDUSEROPT = 68
+
+ RTM_NEWADDRLABEL = 72
+ RTM_DELADDRLABEL = 73
+ RTM_GETADDRLABEL = 74
+
+ RTM_GETDCB = 78
+ RTM_SETDCB = 79
+
+ RTM_NEWNETCONF = 80
+ RTM_GETNETCONF = 82
+
+ RTM_NEWMDB = 84
+ RTM_DELMDB = 85
+ RTM_GETMDB = 86
+
+ RTM_NEWNSID = 88
+ RTM_DELNSID = 89
+ RTM_GETNSID = 90
+)
+
+// InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h.
+type InterfaceInfoMessage struct {
+ Family uint8
+ _ uint8
+ Type uint16
+ Index int32
+ Flags uint32
+ Change uint32
+}
+
+// Interface flags, from uapi/linux/if.h.
+const (
+ IFF_UP = 1 << 0
+ IFF_BROADCAST = 1 << 1
+ IFF_DEBUG = 1 << 2
+ IFF_LOOPBACK = 1 << 3
+ IFF_POINTOPOINT = 1 << 4
+ IFF_NOTRAILERS = 1 << 5
+ IFF_RUNNING = 1 << 6
+ IFF_NOARP = 1 << 7
+ IFF_PROMISC = 1 << 8
+ IFF_ALLMULTI = 1 << 9
+ IFF_MASTER = 1 << 10
+ IFF_SLAVE = 1 << 11
+ IFF_MULTICAST = 1 << 12
+ IFF_PORTSEL = 1 << 13
+ IFF_AUTOMEDIA = 1 << 14
+ IFF_DYNAMIC = 1 << 15
+ IFF_LOWER_UP = 1 << 16
+ IFF_DORMANT = 1 << 17
+ IFF_ECHO = 1 << 18
+)
+
+// Interface link attributes, from uapi/linux/if_link.h.
+const (
+ IFLA_UNSPEC = 0
+ IFLA_ADDRESS = 1
+ IFLA_BROADCAST = 2
+ IFLA_IFNAME = 3
+ IFLA_MTU = 4
+ IFLA_LINK = 5
+ IFLA_QDISC = 6
+ IFLA_STATS = 7
+ IFLA_COST = 8
+ IFLA_PRIORITY = 9
+ IFLA_MASTER = 10
+ IFLA_WIRELESS = 11
+ IFLA_PROTINFO = 12
+ IFLA_TXQLEN = 13
+ IFLA_MAP = 14
+ IFLA_WEIGHT = 15
+ IFLA_OPERSTATE = 16
+ IFLA_LINKMODE = 17
+ IFLA_LINKINFO = 18
+ IFLA_NET_NS_PID = 19
+ IFLA_IFALIAS = 20
+ IFLA_NUM_VF = 21
+ IFLA_VFINFO_LIST = 22
+ IFLA_STATS64 = 23
+ IFLA_VF_PORTS = 24
+ IFLA_PORT_SELF = 25
+ IFLA_AF_SPEC = 26
+ IFLA_GROUP = 27
+ IFLA_NET_NS_FD = 28
+ IFLA_EXT_MASK = 29
+ IFLA_PROMISCUITY = 30
+ IFLA_NUM_TX_QUEUES = 31
+ IFLA_NUM_RX_QUEUES = 32
+ IFLA_CARRIER = 33
+ IFLA_PHYS_PORT_ID = 34
+ IFLA_CARRIER_CHANGES = 35
+ IFLA_PHYS_SWITCH_ID = 36
+ IFLA_LINK_NETNSID = 37
+ IFLA_PHYS_PORT_NAME = 38
+ IFLA_PROTO_DOWN = 39
+ IFLA_GSO_MAX_SEGS = 40
+ IFLA_GSO_MAX_SIZE = 41
+)
+
+// InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h.
+type InterfaceAddrMessage struct {
+ Family uint8
+ PrefixLen uint8
+ Flags uint8
+ Scope uint8
+ Index uint32
+}
+
+// Interface attributes, from uapi/linux/if_addr.h.
+const (
+ IFA_UNSPEC = 0
+ IFA_ADDRESS = 1
+ IFA_LOCAL = 2
+ IFA_LABEL = 3
+ IFA_BROADCAST = 4
+ IFA_ANYCAST = 5
+ IFA_CACHEINFO = 6
+ IFA_MULTICAST = 7
+ IFA_FLAGS = 8
+)
+
+// Device types, from uapi/linux/if_arp.h.
+const (
+ ARPHRD_LOOPBACK = 772
+)
+
+// RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h.
+type RouteMessage struct {
+ Family uint8
+ DstLen uint8
+ SrcLen uint8
+ TOS uint8
+
+ Table uint8
+ Protocol uint8
+ Scope uint8
+ Type uint8
+
+ Flags uint32
+}
+
+// SizeOfRouteMessage is the size of RouteMessage.
+const SizeOfRouteMessage = 12
+
+// Route types, from uapi/linux/rtnetlink.h.
+const (
+ // RTN_UNSPEC represents an unspecified route type.
+ RTN_UNSPEC = 0
+
+ // RTN_UNICAST represents a unicast route.
+ RTN_UNICAST = 1
+
+ // RTN_LOCAL represents a route that is accepted locally.
+ RTN_LOCAL = 2
+
+ // RTN_BROADCAST represents a broadcast route (traffic is accepted
+ // locally as broadcast and sent as broadcast).
+ RTN_BROADCAST = 3
+
+ // RTN_ANYCAST represents an anycast route (traffic is accepted locally
+ // as broadcast but sent as unicast).
+ RTN_ANYCAST = 4
+
+ // RTN_MULTICAST represents a multicast route.
+ RTN_MULTICAST = 5
+
+ // RTN_BLACKHOLE represents a route where all traffic is dropped.
+ RTN_BLACKHOLE = 6
+
+ // RTN_UNREACHABLE represents a route where the destination is unreachable.
+ RTN_UNREACHABLE = 7
+
+ RTN_PROHIBIT = 8
+ RTN_THROW = 9
+ RTN_NAT = 10
+ RTN_XRESOLVE = 11
+)
+
+// Route protocols/origins, from uapi/linux/rtnetlink.h.
+const (
+ RTPROT_UNSPEC = 0
+ RTPROT_REDIRECT = 1
+ RTPROT_KERNEL = 2
+ RTPROT_BOOT = 3
+ RTPROT_STATIC = 4
+ RTPROT_GATED = 8
+ RTPROT_RA = 9
+ RTPROT_MRT = 10
+ RTPROT_ZEBRA = 11
+ RTPROT_BIRD = 12
+ RTPROT_DNROUTED = 13
+ RTPROT_XORP = 14
+ RTPROT_NTK = 15
+ RTPROT_DHCP = 16
+ RTPROT_MROUTED = 17
+ RTPROT_BABEL = 42
+ RTPROT_BGP = 186
+ RTPROT_ISIS = 187
+ RTPROT_OSPF = 188
+ RTPROT_RIP = 189
+ RTPROT_EIGRP = 192
+)
+
+// Route scopes, from uapi/linux/rtnetlink.h.
+const (
+ RT_SCOPE_UNIVERSE = 0
+ RT_SCOPE_SITE = 200
+ RT_SCOPE_LINK = 253
+ RT_SCOPE_HOST = 254
+ RT_SCOPE_NOWHERE = 255
+)
+
+// Route flags, from uapi/linux/rtnetlink.h.
+const (
+ RTM_F_NOTIFY = 0x100
+ RTM_F_CLONED = 0x200
+ RTM_F_EQUALIZE = 0x400
+ RTM_F_PREFIX = 0x800
+ RTM_F_LOOKUP_TABLE = 0x1000
+ RTM_F_FIB_MATCH = 0x2000
+)
+
+// Route tables, from uapi/linux/rtnetlink.h.
+const (
+ RT_TABLE_UNSPEC = 0
+ RT_TABLE_COMPAT = 252
+ RT_TABLE_DEFAULT = 253
+ RT_TABLE_MAIN = 254
+ RT_TABLE_LOCAL = 255
+)
+
+// Route attributes, from uapi/linux/rtnetlink.h.
+const (
+ RTA_UNSPEC = 0
+ RTA_DST = 1
+ RTA_SRC = 2
+ RTA_IIF = 3
+ RTA_OIF = 4
+ RTA_GATEWAY = 5
+ RTA_PRIORITY = 6
+ RTA_PREFSRC = 7
+ RTA_METRICS = 8
+ RTA_MULTIPATH = 9
+ RTA_PROTOINFO = 10
+ RTA_FLOW = 11
+ RTA_CACHEINFO = 12
+ RTA_SESSION = 13
+ RTA_MP_ALGO = 14
+ RTA_TABLE = 15
+ RTA_MARK = 16
+ RTA_MFC_STATS = 17
+ RTA_VIA = 18
+ RTA_NEWDST = 19
+ RTA_PREF = 20
+ RTA_ENCAP_TYPE = 21
+ RTA_ENCAP = 22
+ RTA_EXPIRES = 23
+ RTA_PAD = 24
+ RTA_UID = 25
+ RTA_TTL_PROPAGATE = 26
+ RTA_IP_PROTO = 27
+ RTA_SPORT = 28
+ RTA_DPORT = 29
+)
+
+// Route flags, from include/uapi/linux/route.h.
+const (
+ RTF_GATEWAY = 0x2
+ RTF_UP = 0x1
+)
+
+// RtAttr is the header of optional additional route information, as a
+// netlink attribute. From include/uapi/linux/rtnetlink.h.
+type RtAttr struct {
+ Len uint16
+ Type uint16
+}
+
+// SizeOfRtAttr is the size of RtAttr.
+const SizeOfRtAttr = 4
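+
+// A hedged sketch of walking a buffer of RtAttr-framed attributes (the
+// helper is ours; it assumes netlink's native byte order on a little-endian
+// host and the usual 4-byte attribute alignment):
+func forEachRtAttr(buf []byte, f func(typ uint16, value []byte)) {
+ for len(buf) >= SizeOfRtAttr {
+ length := int(buf[0]) | int(buf[1])<<8
+ typ := uint16(buf[2]) | uint16(buf[3])<<8
+ if length < SizeOfRtAttr || length > len(buf) {
+ // Malformed attribute; stop.
+ return
+ }
+ f(typ, buf[SizeOfRtAttr:length])
+ // Advance to the next attribute, rounding the length up to
+ // 4-byte alignment.
+ next := (length + 3) &^ 3
+ if next > len(buf) {
+ return
+ }
+ buf = buf[next:]
+ }
+}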
diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go
new file mode 100644
index 000000000..c04d26e4c
--- /dev/null
+++ b/pkg/abi/linux/poll.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h.
+type PollFD struct {
+ FD int32
+ Events int16
+ REvents int16
+}
+
+// Poll event flags, used by poll(2)/ppoll(2) and/or
+// epoll_ctl(2)/epoll_wait(2), from uapi/asm-generic/poll.h.
+const (
+ POLLIN = 0x0001
+ POLLPRI = 0x0002
+ POLLOUT = 0x0004
+ POLLERR = 0x0008
+ POLLHUP = 0x0010
+ POLLNVAL = 0x0020
+ POLLRDNORM = 0x0040
+ POLLRDBAND = 0x0080
+ POLLWRNORM = 0x0100
+ POLLWRBAND = 0x0200
+ POLLMSG = 0x0400
+ POLLREMOVE = 0x1000
+ POLLRDHUP = 0x2000
+ POLLFREE = 0x4000
+ POLL_BUSY_LOOP = 0x8000
+)
diff --git a/pkg/abi/linux/prctl.go b/pkg/abi/linux/prctl.go
new file mode 100644
index 000000000..391cfaa1c
--- /dev/null
+++ b/pkg/abi/linux/prctl.go
@@ -0,0 +1,164 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// PR_* flags, from <linux/prctl.h> for prctl(2).
+const (
+ // PR_SET_PDEATHSIG sets the process' death signal.
+ PR_SET_PDEATHSIG = 1
+
+ // PR_GET_PDEATHSIG gets the process' death signal.
+ PR_GET_PDEATHSIG = 2
+
+ // PR_GET_DUMPABLE gets the process' dumpable flag.
+ PR_GET_DUMPABLE = 3
+
+ // PR_SET_DUMPABLE sets the process' dumpable flag.
+ PR_SET_DUMPABLE = 4
+
+ // PR_GET_KEEPCAPS gets the value of the keep capabilities flag.
+ PR_GET_KEEPCAPS = 7
+
+ // PR_SET_KEEPCAPS sets the value of the keep capabilities flag.
+ PR_SET_KEEPCAPS = 8
+
+ // PR_GET_TIMING gets the process' timing method.
+ PR_GET_TIMING = 13
+
+ // PR_SET_TIMING sets the process' timing method.
+ PR_SET_TIMING = 14
+
+ // PR_SET_NAME sets the process' name.
+ PR_SET_NAME = 15
+
+ // PR_GET_NAME gets the process' name.
+ PR_GET_NAME = 16
+
+ // PR_GET_SECCOMP gets a process' seccomp mode.
+ PR_GET_SECCOMP = 21
+
+ // PR_SET_SECCOMP sets a process' seccomp mode.
+ PR_SET_SECCOMP = 22
+
+ // PR_CAPBSET_READ gets the capability bounding set.
+ PR_CAPBSET_READ = 23
+
+ // PR_CAPBSET_DROP sets the capability bounding set.
+ PR_CAPBSET_DROP = 24
+
+ // PR_GET_TSC gets the value of the flag determining whether the
+ // timestamp counter can be read.
+ PR_GET_TSC = 25
+
+ // PR_SET_TSC sets the value of the flag determining whether the
+ // timestamp counter can be read.
+ PR_SET_TSC = 26
+
+ // PR_SET_TIMERSLACK sets the process' time slack.
+ PR_SET_TIMERSLACK = 29
+
+ // PR_GET_TIMERSLACK gets the process' time slack.
+ PR_GET_TIMERSLACK = 30
+
+ // PR_TASK_PERF_EVENTS_DISABLE disables all performance counters
+ // attached to the calling process.
+ PR_TASK_PERF_EVENTS_DISABLE = 31
+
+ // PR_TASK_PERF_EVENTS_ENABLE enables all performance counters attached
+ // to the calling process.
+ PR_TASK_PERF_EVENTS_ENABLE = 32
+
+ // PR_MCE_KILL sets the machine check memory corruption kill policy for
+ // the calling thread.
+ PR_MCE_KILL = 33
+
+ // PR_MCE_KILL_GET gets the machine check memory corruption kill policy
+ // for the calling thread.
+ PR_MCE_KILL_GET = 34
+
+ // PR_SET_MM modifies certain kernel memory map descriptor fields of
+ // the calling process. See prctl(2) for more information.
+ PR_SET_MM = 35
+
+ PR_SET_MM_START_CODE = 1
+ PR_SET_MM_END_CODE = 2
+ PR_SET_MM_START_DATA = 3
+ PR_SET_MM_END_DATA = 4
+ PR_SET_MM_START_STACK = 5
+ PR_SET_MM_START_BRK = 6
+ PR_SET_MM_BRK = 7
+ PR_SET_MM_ARG_START = 8
+ PR_SET_MM_ARG_END = 9
+ PR_SET_MM_ENV_START = 10
+ PR_SET_MM_ENV_END = 11
+ PR_SET_MM_AUXV = 12
+ // PR_SET_MM_EXE_FILE supersedes the /proc/pid/exe symbolic link with a
+ // new one pointing to a new executable file identified by the file
+ // descriptor provided in arg3 argument. See prctl(2) for more
+ // information.
+ PR_SET_MM_EXE_FILE = 13
+ PR_SET_MM_MAP = 14
+ PR_SET_MM_MAP_SIZE = 15
+
+ // PR_SET_CHILD_SUBREAPER sets the "child subreaper" attribute of the
+ // calling process.
+ PR_SET_CHILD_SUBREAPER = 36
+
+ // PR_GET_CHILD_SUBREAPER gets the "child subreaper" attribute of the
+ // calling process.
+ PR_GET_CHILD_SUBREAPER = 37
+
+ // PR_SET_NO_NEW_PRIVS sets the calling thread's no_new_privs bit.
+ PR_SET_NO_NEW_PRIVS = 38
+
+ // PR_GET_NO_NEW_PRIVS gets the calling thread's no_new_privs bit.
+ PR_GET_NO_NEW_PRIVS = 39
+
+ // PR_GET_TID_ADDRESS retrieves the clear_child_tid address.
+ PR_GET_TID_ADDRESS = 40
+
+ // PR_SET_THP_DISABLE sets the state of the "THP disable" flag for the
+ // calling thread.
+ PR_SET_THP_DISABLE = 41
+
+ // PR_GET_THP_DISABLE gets the state of the "THP disable" flag for the
+ // calling thread.
+ PR_GET_THP_DISABLE = 42
+
+ // PR_MPX_ENABLE_MANAGEMENT enables kernel management of Memory
+ // Protection eXtensions (MPX) bounds tables.
+ PR_MPX_ENABLE_MANAGEMENT = 43
+
+ // PR_MPX_DISABLE_MANAGEMENT disables kernel management of Memory
+ // Protection eXtensions (MPX) bounds tables.
+ PR_MPX_DISABLE_MANAGEMENT = 44
+)
+
+// Flags for arch_prctl(2), from <asm/prctl.h>.
+const (
+ ARCH_SET_GS = 0x1001
+ ARCH_SET_FS = 0x1002
+ ARCH_GET_FS = 0x1003
+ ARCH_GET_GS = 0x1004
+ ARCH_SET_CPUID = 0x1012
+)
+
+// Flags for prctl(PR_SET_DUMPABLE), defined in include/linux/sched/coredump.h.
+const (
+ SUID_DUMP_DISABLE = 0
+ SUID_DUMP_USER = 1
+ SUID_DUMP_ROOT = 2
+)
diff --git a/pkg/abi/linux/ptrace.go b/pkg/abi/linux/ptrace.go
new file mode 100644
index 000000000..23e605ab2
--- /dev/null
+++ b/pkg/abi/linux/ptrace.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ptrace commands from include/uapi/linux/ptrace.h.
+const (
+ PTRACE_TRACEME = 0
+ PTRACE_PEEKTEXT = 1
+ PTRACE_PEEKDATA = 2
+ PTRACE_PEEKUSR = 3
+ PTRACE_POKETEXT = 4
+ PTRACE_POKEDATA = 5
+ PTRACE_POKEUSR = 6
+ PTRACE_CONT = 7
+ PTRACE_KILL = 8
+ PTRACE_SINGLESTEP = 9
+ PTRACE_ATTACH = 16
+ PTRACE_DETACH = 17
+ PTRACE_SYSCALL = 24
+ PTRACE_SETOPTIONS = 0x4200
+ PTRACE_GETEVENTMSG = 0x4201
+ PTRACE_GETSIGINFO = 0x4202
+ PTRACE_SETSIGINFO = 0x4203
+ PTRACE_GETREGSET = 0x4204
+ PTRACE_SETREGSET = 0x4205
+ PTRACE_SEIZE = 0x4206
+ PTRACE_INTERRUPT = 0x4207
+ PTRACE_LISTEN = 0x4208
+ PTRACE_PEEKSIGINFO = 0x4209
+ PTRACE_GETSIGMASK = 0x420a
+ PTRACE_SETSIGMASK = 0x420b
+ PTRACE_SECCOMP_GET_FILTER = 0x420c
+ PTRACE_SECCOMP_GET_METADATA = 0x420d
+)
+
+// ptrace commands from arch/x86/include/uapi/asm/ptrace-abi.h.
+const (
+ PTRACE_GETREGS = 12
+ PTRACE_SETREGS = 13
+ PTRACE_GETFPREGS = 14
+ PTRACE_SETFPREGS = 15
+ PTRACE_GETFPXREGS = 18
+ PTRACE_SETFPXREGS = 19
+ PTRACE_OLDSETOPTIONS = 21
+ PTRACE_GET_THREAD_AREA = 25
+ PTRACE_SET_THREAD_AREA = 26
+ PTRACE_ARCH_PRCTL = 30
+ PTRACE_SYSEMU = 31
+ PTRACE_SYSEMU_SINGLESTEP = 32
+ PTRACE_SINGLEBLOCK = 33
+)
+
+// ptrace event codes from include/uapi/linux/ptrace.h.
+const (
+ PTRACE_EVENT_FORK = 1
+ PTRACE_EVENT_VFORK = 2
+ PTRACE_EVENT_CLONE = 3
+ PTRACE_EVENT_EXEC = 4
+ PTRACE_EVENT_VFORK_DONE = 5
+ PTRACE_EVENT_EXIT = 6
+ PTRACE_EVENT_SECCOMP = 7
+ PTRACE_EVENT_STOP = 128
+)
+
+// PTRACE_SETOPTIONS options from include/uapi/linux/ptrace.h.
+const (
+ PTRACE_O_TRACESYSGOOD = 1
+ PTRACE_O_TRACEFORK = 1 << PTRACE_EVENT_FORK
+ PTRACE_O_TRACEVFORK = 1 << PTRACE_EVENT_VFORK
+ PTRACE_O_TRACECLONE = 1 << PTRACE_EVENT_CLONE
+ PTRACE_O_TRACEEXEC = 1 << PTRACE_EVENT_EXEC
+ PTRACE_O_TRACEVFORKDONE = 1 << PTRACE_EVENT_VFORK_DONE
+ PTRACE_O_TRACEEXIT = 1 << PTRACE_EVENT_EXIT
+ PTRACE_O_TRACESECCOMP = 1 << PTRACE_EVENT_SECCOMP
+ PTRACE_O_EXITKILL = 1 << 20
+ PTRACE_O_SUSPEND_SECCOMP = 1 << 21
+)
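+
+// The PTRACE_O_TRACE* options are shifted by their event codes, so the
+// event that caused a ptrace stop can be recovered from the wait status. A
+// hedged sketch (the helper is ours; it assumes Linux's status>>16 event
+// encoding):
+func ptraceEventFromStatus(status uint32) uint32 {
+ return status >> 16
+}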
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
new file mode 100644
index 000000000..ed3881e27
--- /dev/null
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// PtraceRegs is the set of CPU registers exposed by ptrace. Source:
+// syscall.PtraceRegs.
+//
+// +marshal
+// +stateify savable
+type PtraceRegs struct {
+ R15 uint64
+ R14 uint64
+ R13 uint64
+ R12 uint64
+ Rbp uint64
+ Rbx uint64
+ R11 uint64
+ R10 uint64
+ R9 uint64
+ R8 uint64
+ Rax uint64
+ Rcx uint64
+ Rdx uint64
+ Rsi uint64
+ Rdi uint64
+ Orig_rax uint64
+ Rip uint64
+ Cs uint64
+ Eflags uint64
+ Rsp uint64
+ Ss uint64
+ Fs_base uint64
+ Gs_base uint64
+ Ds uint64
+ Es uint64
+ Fs uint64
+ Gs uint64
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
new file mode 100644
index 000000000..6147738b3
--- /dev/null
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// PtraceRegs is the set of CPU registers exposed by ptrace. Source:
+// syscall.PtraceRegs.
+//
+// +marshal
+// +stateify savable
+type PtraceRegs struct {
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+}
diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go
new file mode 100644
index 000000000..76253ba30
--- /dev/null
+++ b/pkg/abi/linux/rseq.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags passed to rseq(2).
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_FLAG_UNREGISTER unregisters the current thread.
+ RSEQ_FLAG_UNREGISTER = 1 << 0
+)
+
+// Critical section flags used in RSeqCriticalSection.Flags and RSeq.Flags.
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT inhibits restart on preemption.
+ RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT = 1 << 0
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL inhibits restart on signal
+ // delivery.
+ RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL = 1 << 1
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE inhibits restart on CPU
+ // migration.
+ RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE = 1 << 2
+)
+
+// RSeqCriticalSection describes a restartable sequences critical section. It
+// is equivalent to struct rseq_cs, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+//
+// +marshal
+type RSeqCriticalSection struct {
+ // Version is the version of this structure. Version 0 is defined here.
+ Version uint32
+
+ // Flags are the critical section flags, defined above.
+ Flags uint32
+
+ // Start is the start address of the critical section.
+ Start uint64
+
+ // PostCommitOffset is the offset from Start of the first instruction
+ // outside of the critical section.
+ PostCommitOffset uint64
+
+ // Abort is the abort address. It must be outside the critical section,
+ // and the 4 bytes prior must match the abort signature.
+ Abort uint64
+}
+
+const (
+ // SizeOfRSeqCriticalSection is the size of RSeqCriticalSection.
+ SizeOfRSeqCriticalSection = 32
+
+ // SizeOfRSeqSignature is the size of the signature immediately
+ // preceding RSeqCriticalSection.Abort.
+ SizeOfRSeqSignature = 4
+)
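+
+// A sketch of the containment check applied on restart (the helper is
+// ours): an instruction pointer lies in the critical section iff it is in
+// [Start, Start+PostCommitOffset). The unsigned subtraction wraps when
+// ip < cs.Start, so the single comparison suffices.
+func inCriticalSection(cs *RSeqCriticalSection, ip uint64) bool {
+ return ip-cs.Start < cs.PostCommitOffset
+}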
+
+// Special values for RSeq.CPUID, defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CPU_ID_UNINITIALIZED indicates that this thread has not
+ // performed rseq initialization.
+ RSEQ_CPU_ID_UNINITIALIZED = ^uint32(0) // -1
+
+ // RSEQ_CPU_ID_REGISTRATION_FAILED indicates that rseq initialization
+ // failed.
+ RSEQ_CPU_ID_REGISTRATION_FAILED = ^uint32(1) // -2
+)
+
+// RSeq is the thread-local restartable sequences config/status. It
+// is equivalent to struct rseq, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+type RSeq struct {
+ // CPUIDStart contains the current CPU ID if rseq is initialized.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUIDStart uint32
+
+ // CPUID contains the current CPU ID or one of the CPU ID special
+ // values defined above.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUID uint32
+
+ // RSeqCriticalSection is a pointer to the current RSeqCriticalSection
+ // block, or NULL. It is reset to NULL by the kernel on restart or
+ // non-restarting preempt/signal.
+ //
+ // This field should only be written by the thread which registered
+ // this structure, and must be written atomically.
+ RSeqCriticalSection uint64
+
+ // Flags are the critical section flags that apply to all critical
+ // sections on this thread, defined above.
+ Flags uint32
+}
+
+const (
+ // SizeOfRSeq is the size of RSeq.
+ //
+ // Note that RSeq is naively 24 bytes. However, it has 32-byte
+ // alignment, which in C increases sizeof to 32. That is the size that
+ // the Linux kernel uses.
+ SizeOfRSeq = 32
+
+ // AlignOfRSeq is the standard alignment of RSeq.
+ AlignOfRSeq = 32
+
+ // OffsetOfRSeqCriticalSection is the offset of RSeqCriticalSection in RSeq.
+ OffsetOfRSeqCriticalSection = 8
+)
diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go
new file mode 100644
index 000000000..d8302dc85
--- /dev/null
+++ b/pkg/abi/linux/rusage.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags that may be used with wait4(2) and getrusage(2).
+const (
+ // wait4(2) uses this to aggregate RUSAGE_SELF and RUSAGE_CHILDREN.
+ RUSAGE_BOTH = -0x2
+
+ // getrusage(2) flags.
+ RUSAGE_CHILDREN = -0x1
+ RUSAGE_SELF = 0x0
+ RUSAGE_THREAD = 0x1
+)
+
+// Rusage represents the Linux struct rusage.
+type Rusage struct {
+ UTime Timeval
+ STime Timeval
+ MaxRSS int64
+ IXRSS int64
+ IDRSS int64
+ ISRSS int64
+ MinFlt int64
+ MajFlt int64
+ NSwap int64
+ InBlock int64
+ OuBlock int64
+ MsgSnd int64
+ MsgRcv int64
+ NSignals int64
+ NVCSw int64
+ NIvCSw int64
+}
diff --git a/pkg/abi/linux/sched.go b/pkg/abi/linux/sched.go
new file mode 100644
index 000000000..70e820823
--- /dev/null
+++ b/pkg/abi/linux/sched.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Scheduling policies, exposed by sched_getscheduler(2)/sched_setscheduler(2).
+const (
+ SCHED_NORMAL = 0
+ SCHED_FIFO = 1
+ SCHED_RR = 2
+ SCHED_BATCH = 3
+ SCHED_IDLE = 5
+ SCHED_DEADLINE = 6
+ SCHED_MICROQ = 16
+
+ // SCHED_RESET_ON_FORK is a flag that indicates that the process is
+ // reverted to SCHED_NORMAL on fork.
+ SCHED_RESET_ON_FORK = 0x40000000
+)
+
+const (
+ PRIO_PGRP = 0x1
+ PRIO_PROCESS = 0x0
+ PRIO_USER = 0x2
+)
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
new file mode 100644
index 000000000..d0607e256
--- /dev/null
+++ b/pkg/abi/linux/seccomp.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "fmt"
+
+// Seccomp constants taken from <linux/seccomp.h>.
+const (
+ SECCOMP_MODE_NONE = 0
+ SECCOMP_MODE_FILTER = 2
+
+ SECCOMP_RET_ACTION_FULL = 0xffff0000
+ SECCOMP_RET_ACTION = 0x7fff0000
+ SECCOMP_RET_DATA = 0x0000ffff
+
+ SECCOMP_SET_MODE_FILTER = 1
+ SECCOMP_FILTER_FLAG_TSYNC = 1
+ SECCOMP_GET_ACTION_AVAIL = 2
+)
+
+// BPFAction is a return value of a seccomp BPF filter. The high 16 bits
+// select the action and the low 16 bits carry optional data.
+type BPFAction uint32
+
+// Seccomp filter return actions, from <linux/seccomp.h>.
+const (
+ SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000
+ SECCOMP_RET_KILL_THREAD = 0x00000000
+ SECCOMP_RET_TRAP = 0x00030000
+ SECCOMP_RET_ERRNO = 0x00050000
+ SECCOMP_RET_TRACE = 0x7ff00000
+ SECCOMP_RET_ALLOW = 0x7fff0000
+)
+
+// String implements fmt.Stringer.
+func (a BPFAction) String() string {
+ switch a & SECCOMP_RET_ACTION_FULL {
+ case SECCOMP_RET_KILL_PROCESS:
+ return "kill process"
+ case SECCOMP_RET_KILL_THREAD:
+ return "kill thread"
+ case SECCOMP_RET_TRAP:
+ return fmt.Sprintf("trap (%d)", a.Data())
+ case SECCOMP_RET_ERRNO:
+ return fmt.Sprintf("errno (%d)", a.Data())
+ case SECCOMP_RET_TRACE:
+ return fmt.Sprintf("trace (%d)", a.Data())
+ case SECCOMP_RET_ALLOW:
+ return "allow"
+ }
+ return fmt.Sprintf("invalid action: %#x", a)
+}
+
+// Data returns the SECCOMP_RET_DATA portion of the action.
+func (a BPFAction) Data() uint16 {
+ return uint16(a & SECCOMP_RET_DATA)
+}
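+
+// Composing an action is the inverse of Data. A minimal sketch (the helper
+// is ours): combine an action with its 16-bit data, e.g. the errno value
+// for SECCOMP_RET_ERRNO.
+func errnoAction(errno uint16) BPFAction {
+ return SECCOMP_RET_ERRNO | BPFAction(errno)
+}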
+
+// SockFprog is sock_fprog taken from <linux/filter.h>.
+type SockFprog struct {
+ Len uint16
+ pad [6]byte
+ Filter *BPFInstruction
+}
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
new file mode 100644
index 000000000..de422c519
--- /dev/null
+++ b/pkg/abi/linux/sem.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// semctl Command Definitions. Source: include/uapi/linux/sem.h
+const (
+ GETPID = 11
+ GETVAL = 12
+ GETALL = 13
+ GETNCNT = 14
+ GETZCNT = 15
+ SETVAL = 16
+ SETALL = 17
+)
+
+// ipcs ctl cmds. Source: include/uapi/linux/sem.h
+const (
+ SEM_STAT = 18
+ SEM_INFO = 19
+ SEM_STAT_ANY = 20
+)
+
+const SEM_UNDO = 0x1000
+
+// SemidDS is equivalent to struct semid64_ds.
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ SemCTime TimeT
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
+
+// Sembuf is equivalent to struct sembuf.
+type Sembuf struct {
+ SemNum uint16
+ SemOp int16
+ SemFlg int16
+}
diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go
new file mode 100644
index 000000000..e45aadb10
--- /dev/null
+++ b/pkg/abi/linux/shm.go
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "math"
+
+// shmat(2) flags. Source: include/uapi/linux/shm.h
+const (
+ SHM_RDONLY = 010000 // Read-only access.
+ SHM_RND = 020000 // Round attach address to SHMLBA boundary.
+ SHM_REMAP = 040000 // Take-over region on attach.
+ SHM_EXEC = 0100000 // Execution access.
+)
+
+// IPCPerm.Mode upper byte flags. Source: include/linux/shm.h
+const (
+ SHM_DEST = 01000 // Segment will be destroyed on last detach.
+ SHM_LOCKED = 02000 // Segment will not be swapped.
+ SHM_HUGETLB = 04000 // Segment will use huge TLB pages.
+ SHM_NORESERVE = 010000 // Don't check for reservations.
+)
+
+// Additional Linux-only flags for shmctl(2). Source: include/uapi/linux/shm.h
+const (
+ SHM_LOCK = 11
+ SHM_UNLOCK = 12
+ SHM_STAT = 13
+ SHM_INFO = 14
+)
+
+// SHM defaults as specified by linux. Source: include/uapi/linux/shm.h
+const (
+ SHMMIN = 1
+ SHMMNI = 4096
+ SHMMAX = math.MaxUint64 - 1<<24
+ SHMALL = math.MaxUint64 - 1<<24
+ SHMSEG = 4096
+)
+
+// ShmidDS is equivalent to struct shmid64_ds. Source:
+// include/uapi/asm-generic/shmbuf.h
+type ShmidDS struct {
+ ShmPerm IPCPerm
+ ShmSegsz uint64
+ ShmAtime TimeT
+ ShmDtime TimeT
+ ShmCtime TimeT
+ ShmCpid int32
+ ShmLpid int32
+ ShmNattach uint64
+
+ Unused4 uint64
+ Unused5 uint64
+}
+
+// ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h
+type ShmParams struct {
+ ShmMax uint64
+ ShmMin uint64
+ ShmMni uint64
+ ShmSeg uint64
+ ShmAll uint64
+}
+
+// ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h
+type ShmInfo struct {
+ UsedIDs int32 // Number of currently existing segments.
+ _ [4]byte
+ ShmTot uint64 // Total number of shared memory pages.
+ ShmRss uint64 // Number of resident shared memory pages.
+ ShmSwp uint64 // Number of swapped shared memory pages.
+ SwapAttempts uint64 // Unused since Linux 2.4.
+ SwapSuccesses uint64 // Unused since Linux 2.4.
+}
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
new file mode 100644
index 000000000..1c330e763
--- /dev/null
+++ b/pkg/abi/linux/signal.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/bits"
+)
+
+const (
+ // SignalMaximum is the highest valid signal number.
+ SignalMaximum = 64
+
+ // FirstStdSignal is the lowest standard signal number.
+ FirstStdSignal = 1
+
+ // LastStdSignal is the highest standard signal number.
+ LastStdSignal = 31
+
+ // FirstRTSignal is the lowest real-time signal number.
+ //
+ // 32 (SIGCANCEL) and 33 (SIGSETXID) are used internally by glibc.
+ FirstRTSignal = 32
+
+ // LastRTSignal is the highest real-time signal number.
+ LastRTSignal = 64
+
+ // NumStdSignals is the number of standard signals.
+ NumStdSignals = LastStdSignal - FirstStdSignal + 1
+
+ // NumRTSignals is the number of realtime signals.
+ NumRTSignals = LastRTSignal - FirstRTSignal + 1
+)
+
+// Signal is a signal number.
+type Signal int
+
+// IsValid returns true if s is a valid standard or realtime signal. (0 is not
+// considered valid; interfaces special-casing signal number 0 should check for
+// 0 first before asserting validity.)
+func (s Signal) IsValid() bool {
+ return s > 0 && s <= SignalMaximum
+}
+
+// IsStandard returns true if s is a standard signal.
+//
+// Preconditions: s.IsValid().
+func (s Signal) IsStandard() bool {
+ return s <= LastStdSignal
+}
+
+// IsRealtime returns true if s is a realtime signal.
+//
+// Preconditions: s.IsValid().
+func (s Signal) IsRealtime() bool {
+ return s >= FirstRTSignal
+}
+
+// Index returns the index for signal s into arrays of both standard and
+// realtime signals (e.g. signal masks).
+//
+// Preconditions: s.IsValid().
+func (s Signal) Index() int {
+ return int(s - 1)
+}
+
+// Signals.
+const (
+ SIGABRT = Signal(6)
+ SIGALRM = Signal(14)
+ SIGBUS = Signal(7)
+ SIGCHLD = Signal(17)
+ SIGCLD = Signal(17)
+ SIGCONT = Signal(18)
+ SIGFPE = Signal(8)
+ SIGHUP = Signal(1)
+ SIGILL = Signal(4)
+ SIGINT = Signal(2)
+ SIGIO = Signal(29)
+ SIGIOT = Signal(6)
+ SIGKILL = Signal(9)
+ SIGPIPE = Signal(13)
+ SIGPOLL = Signal(29)
+ SIGPROF = Signal(27)
+ SIGPWR = Signal(30)
+ SIGQUIT = Signal(3)
+ SIGSEGV = Signal(11)
+ SIGSTKFLT = Signal(16)
+ SIGSTOP = Signal(19)
+ SIGSYS = Signal(31)
+ SIGTERM = Signal(15)
+ SIGTRAP = Signal(5)
+ SIGTSTP = Signal(20)
+ SIGTTIN = Signal(21)
+ SIGTTOU = Signal(22)
+ SIGUNUSED = Signal(31)
+ SIGURG = Signal(23)
+ SIGUSR1 = Signal(10)
+ SIGUSR2 = Signal(12)
+ SIGVTALRM = Signal(26)
+ SIGWINCH = Signal(28)
+ SIGXCPU = Signal(24)
+ SIGXFSZ = Signal(25)
+)
+
+// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
+type SignalSet uint64
+
+// SignalSetSize is the size in bytes of a SignalSet.
+const SignalSetSize = 8
+
+// MakeSignalSet returns SignalSet with the bit corresponding to each of the
+// given signals set.
+func MakeSignalSet(sigs ...Signal) SignalSet {
+ indices := make([]int, len(sigs))
+ for i, sig := range sigs {
+ indices[i] = sig.Index()
+ }
+ return SignalSet(bits.Mask64(indices...))
+}
+
+// SignalSetOf returns a SignalSet with a single signal set.
+func SignalSetOf(sig Signal) SignalSet {
+ return SignalSet(bits.MaskOf64(sig.Index()))
+}
+
+// ForEachSignal invokes f for each signal set in the given mask.
+func ForEachSignal(mask SignalSet, f func(sig Signal)) {
+ bits.ForEachSetBit64(uint64(mask), func(i int) {
+ f(Signal(i + 1))
+ })
+}
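+
+// A short usage sketch of the set helpers (the function and its comments
+// are illustrative only): build a mask from specific signals, then
+// enumerate it back.
+func exampleSignalSet() []Signal {
+ var got []Signal
+ ForEachSignal(MakeSignalSet(SIGINT, SIGTERM), func(sig Signal) {
+ // Visits SIGINT (2), then SIGTERM (15).
+ got = append(got, sig)
+ })
+ return got
+}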
+
+// 'how' values for rt_sigprocmask(2).
+const (
+ // SIG_BLOCK blocks the signals in the set.
+ SIG_BLOCK = 0
+
+ // SIG_UNBLOCK unblocks the signals in the set.
+ SIG_UNBLOCK = 1
+
+ // SIG_SETMASK sets the signal mask to the provided set.
+ SIG_SETMASK = 2
+)
+
+// Signal actions for rt_sigaction(2), from uapi/asm-generic/signal-defs.h.
+const (
+ // SIG_DFL performs the default action.
+ SIG_DFL = 0
+
+ // SIG_IGN ignores the signal.
+ SIG_IGN = 1
+)
+
+// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h
+const (
+ SA_NOCLDSTOP = 0x00000001
+ SA_NOCLDWAIT = 0x00000002
+ SA_SIGINFO = 0x00000004
+ SA_RESTORER = 0x04000000
+ SA_ONSTACK = 0x08000000
+ SA_RESTART = 0x10000000
+ SA_NODEFER = 0x40000000
+ SA_RESETHAND = 0x80000000
+ SA_NOMASK = SA_NODEFER
+ SA_ONESHOT = SA_RESETHAND
+)
+
+// Signal info types.
+const (
+ SI_MASK = 0xffff0000
+ SI_KILL = 0 << 16
+ SI_TIMER = 1 << 16
+ SI_POLL = 2 << 16
+ SI_FAULT = 3 << 16
+ SI_CHLD = 4 << 16
+ SI_RT = 5 << 16
+ SI_MESGQ = 6 << 16
+ SI_SYS = 7 << 16
+)
+
+// SIGPOLL si_codes.
+const (
+ // POLL_IN indicates that data input is available.
+ POLL_IN = SI_POLL | 1
+
+ // POLL_OUT indicates that output buffers are available.
+ POLL_OUT = SI_POLL | 2
+
+ // POLL_MSG indicates that an input message is available.
+ POLL_MSG = SI_POLL | 3
+
+ // POLL_ERR indicates that there was an I/O error.
+ POLL_ERR = SI_POLL | 4
+
+ // POLL_PRI indicates that high-priority input is available.
+ POLL_PRI = SI_POLL | 5
+
+ // POLL_HUP indicates that a device was disconnected.
+ POLL_HUP = SI_POLL | 6
+)
+
+// Sigevent represents struct sigevent.
+type Sigevent struct {
+ Value uint64 // union sigval {int, void*}
+ Signo int32
+ Notify int32
+
+ // struct sigevent here contains the 48-byte union _sigev_un. However,
+ // only the _tid member is significant to the kernel.
+ Tid int32
+ UnRemainder [44]byte
+}
+
+// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify.
+const (
+ SIGEV_SIGNAL = 0
+ SIGEV_NONE = 1
+ SIGEV_THREAD = 2
+ SIGEV_THREAD_ID = 4
+)
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
new file mode 100644
index 000000000..85fad9956
--- /dev/null
+++ b/pkg/abi/linux/signalfd.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // SFD_NONBLOCK is a signalfd(2) flag.
+ SFD_NONBLOCK = 00004000
+
+ // SFD_CLOEXEC is a signalfd(2) flag.
+ SFD_CLOEXEC = 02000000
+)
+
+// SignalfdSiginfo is the siginfo encoding for signalfds.
+type SignalfdSiginfo struct {
+ Signo uint32
+ Errno int32
+ Code int32
+ PID uint32
+ UID uint32
+ FD int32
+ TID uint32
+ Band uint32
+ Overrun uint32
+ TrapNo uint32
+ Status int32
+ Int int32
+ Ptr uint64
+ UTime uint64
+ STime uint64
+ Addr uint64
+ AddrLSB uint16
+ _ [48]uint8
+}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
new file mode 100644
index 000000000..4a14ef691
--- /dev/null
+++ b/pkg/abi/linux/socket.go
@@ -0,0 +1,456 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "gvisor.dev/gvisor/pkg/binary"
+
+// Address families, from linux/socket.h.
+const (
+ AF_UNSPEC = 0
+ AF_UNIX = 1
+ AF_INET = 2
+ AF_AX25 = 3
+ AF_IPX = 4
+ AF_APPLETALK = 5
+ AF_NETROM = 6
+ AF_BRIDGE = 7
+ AF_ATMPVC = 8
+ AF_X25 = 9
+ AF_INET6 = 10
+ AF_ROSE = 11
+ AF_DECnet = 12
+ AF_NETBEUI = 13
+ AF_SECURITY = 14
+ AF_KEY = 15
+ AF_NETLINK = 16
+ AF_PACKET = 17
+ AF_ASH = 18
+ AF_ECONET = 19
+ AF_ATMSVC = 20
+ AF_RDS = 21
+ AF_SNA = 22
+ AF_IRDA = 23
+ AF_PPPOX = 24
+ AF_WANPIPE = 25
+ AF_LLC = 26
+ AF_IB = 27
+ AF_MPLS = 28
+ AF_CAN = 29
+ AF_TIPC = 30
+ AF_BLUETOOTH = 31
+ AF_IUCV = 32
+ AF_RXRPC = 33
+ AF_ISDN = 34
+ AF_PHONET = 35
+ AF_IEEE802154 = 36
+ AF_CAIF = 37
+ AF_ALG = 38
+ AF_NFC = 39
+ AF_VSOCK = 40
+)
+
+// sendmsg(2)/recvmsg(2) flags, from linux/socket.h.
+const (
+ MSG_OOB = 0x1
+ MSG_PEEK = 0x2
+ MSG_DONTROUTE = 0x4
+ MSG_TRYHARD = 0x4
+ MSG_CTRUNC = 0x8
+ MSG_PROBE = 0x10
+ MSG_TRUNC = 0x20
+ MSG_DONTWAIT = 0x40
+ MSG_EOR = 0x80
+ MSG_WAITALL = 0x100
+ MSG_FIN = 0x200
+ MSG_EOF = MSG_FIN
+ MSG_SYN = 0x400
+ MSG_CONFIRM = 0x800
+ MSG_RST = 0x1000
+ MSG_ERRQUEUE = 0x2000
+ MSG_NOSIGNAL = 0x4000
+ MSG_MORE = 0x8000
+ MSG_WAITFORONE = 0x10000
+ MSG_SENDPAGE_NOTLAST = 0x20000
+ MSG_REINJECT = 0x8000000
+ MSG_ZEROCOPY = 0x4000000
+ MSG_FASTOPEN = 0x20000000
+ MSG_CMSG_CLOEXEC = 0x40000000
+)
+
+// Set/get socket option levels, from socket.h.
+const (
+ SOL_IP = 0
+ SOL_SOCKET = 1
+ SOL_TCP = 6
+ SOL_UDP = 17
+ SOL_IPV6 = 41
+ SOL_ICMPV6 = 58
+ SOL_RAW = 255
+ SOL_PACKET = 263
+ SOL_NETLINK = 270
+)
+
+// A SockType is a type (as opposed to family) of sockets. These are enumerated
+// below as SOCK_* constants.
+type SockType int
+
+// Socket types, from linux/net.h.
+const (
+ SOCK_STREAM SockType = 1
+ SOCK_DGRAM = 2
+ SOCK_RAW = 3
+ SOCK_RDM = 4
+ SOCK_SEQPACKET = 5
+ SOCK_DCCP = 6
+ SOCK_PACKET = 10
+)
+
+// SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are
+// flags. From linux/net.h.
+const SOCK_TYPE_MASK = 0xf
+
+// socket(2)/socketpair(2)/accept4(2) flags, from linux/net.h.
+const (
+ SOCK_CLOEXEC = O_CLOEXEC
+ SOCK_NONBLOCK = O_NONBLOCK
+)
+
+// shutdown(2) how commands, from <linux/net.h>.
+const (
+ SHUT_RD = 0
+ SHUT_WR = 1
+ SHUT_RDWR = 2
+)
+
+// Socket options from socket.h.
+const (
+ SO_DEBUG = 1
+ SO_REUSEADDR = 2
+ SO_TYPE = 3
+ SO_ERROR = 4
+ SO_DONTROUTE = 5
+ SO_BROADCAST = 6
+ SO_SNDBUF = 7
+ SO_RCVBUF = 8
+ SO_KEEPALIVE = 9
+ SO_OOBINLINE = 10
+ SO_NO_CHECK = 11
+ SO_PRIORITY = 12
+ SO_LINGER = 13
+ SO_BSDCOMPAT = 14
+ SO_REUSEPORT = 15
+ SO_PASSCRED = 16
+ SO_PEERCRED = 17
+ SO_RCVLOWAT = 18
+ SO_SNDLOWAT = 19
+ SO_RCVTIMEO = 20
+ SO_SNDTIMEO = 21
+ SO_BINDTODEVICE = 25
+ SO_ATTACH_FILTER = 26
+ SO_DETACH_FILTER = 27
+ SO_GET_FILTER = SO_ATTACH_FILTER
+ SO_PEERNAME = 28
+ SO_TIMESTAMP = 29
+ SO_ACCEPTCONN = 30
+ SO_PEERSEC = 31
+ SO_SNDBUFFORCE = 32
+ SO_RCVBUFFORCE = 33
+ SO_PASSSEC = 34
+ SO_TIMESTAMPNS = 35
+ SO_MARK = 36
+ SO_TIMESTAMPING = 37
+ SO_PROTOCOL = 38
+ SO_DOMAIN = 39
+ SO_RXQ_OVFL = 40
+ SO_WIFI_STATUS = 41
+ SO_PEEK_OFF = 42
+ SO_NOFCS = 43
+ SO_LOCK_FILTER = 44
+ SO_SELECT_ERR_QUEUE = 45
+ SO_BUSY_POLL = 46
+ SO_MAX_PACING_RATE = 47
+ SO_BPF_EXTENSIONS = 48
+ SO_INCOMING_CPU = 49
+ SO_ATTACH_BPF = 50
+ SO_ATTACH_REUSEPORT_CBPF = 51
+ SO_ATTACH_REUSEPORT_EBPF = 52
+ SO_CNX_ADVICE = 53
+ SO_MEMINFO = 55
+ SO_INCOMING_NAPI_ID = 56
+ SO_COOKIE = 57
+ SO_PEERGROUPS = 59
+ SO_ZEROCOPY = 60
+ SO_TXTIME = 61
+)
+
+// enum socket_state, from uapi/linux/net.h.
+const (
+ SS_FREE = 0 // Not allocated.
+ SS_UNCONNECTED = 1 // Unconnected to any socket.
+ SS_CONNECTING = 2 // In process of connecting.
+ SS_CONNECTED = 3 // Connected to socket.
+ SS_DISCONNECTING = 4 // In process of disconnecting.
+)
+
+// TCP protocol states, from include/net/tcp_states.h.
+const (
+ TCP_ESTABLISHED uint32 = iota + 1
+ TCP_SYN_SENT
+ TCP_SYN_RECV
+ TCP_FIN_WAIT1
+ TCP_FIN_WAIT2
+ TCP_TIME_WAIT
+ TCP_CLOSE
+ TCP_CLOSE_WAIT
+ TCP_LAST_ACK
+ TCP_LISTEN
+ TCP_CLOSING
+ TCP_NEW_SYN_RECV
+)
+
+// SockAddrMax is the maximum size of a struct sockaddr, from
+// uapi/linux/socket.h.
+const SockAddrMax = 128
+
+// InetAddr is struct in_addr, from uapi/linux/in.h.
+type InetAddr [4]byte
+
+// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
+type SockAddrInet struct {
+ Family uint16
+ Port uint16
+ Addr InetAddr
+ Zero [8]uint8 // pad to sizeof(struct sockaddr).
+}
+
+// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
+type InetMulticastRequest struct {
+ MulticastAddr InetAddr
+ InterfaceAddr InetAddr
+}
+
+// InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h.
+type InetMulticastRequestWithNIC struct {
+ InetMulticastRequest
+ InterfaceIndex int32
+}
+
+// SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h.
+type SockAddrInet6 struct {
+ Family uint16
+ Port uint16
+ Flowinfo uint32
+ Addr [16]byte
+ Scope_id uint32
+}
+
+// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h.
+type SockAddrLink struct {
+ Family uint16
+ Protocol uint16
+ InterfaceIndex int32
+ ARPHardwareType uint16
+ PacketType byte
+ HardwareAddrLen byte
+ HardwareAddr [8]byte
+}
+
+// UnixPathMax is the maximum length of the path in an AF_UNIX socket.
+//
+// From uapi/linux/un.h.
+const UnixPathMax = 108
+
+// SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h.
+type SockAddrUnix struct {
+ Family uint16
+ Path [UnixPathMax]int8
+}
+
+// SockAddr represents a union of valid socket address types. This is logically
+// equivalent to struct sockaddr. SockAddr ensures that a well-defined set of
+// types can be used as socket addresses.
+type SockAddr interface {
+ // implementsSockAddr exists purely to allow a type to indicate that it
+ // implements this interface. This method is a no-op and shouldn't be called.
+ implementsSockAddr()
+}
+
+func (s *SockAddrInet) implementsSockAddr() {}
+func (s *SockAddrInet6) implementsSockAddr() {}
+func (s *SockAddrLink) implementsSockAddr() {}
+func (s *SockAddrUnix) implementsSockAddr() {}
+func (s *SockAddrNetlink) implementsSockAddr() {}
+
+// Linger is struct linger, from include/linux/socket.h.
+type Linger struct {
+ OnOff int32
+ Linger int32
+}
+
+// SizeOfLinger is the binary size of a Linger struct.
+const SizeOfLinger = 8
+
+// TCPInfo is a collection of TCP statistics.
+//
+// From uapi/linux/tcp.h. Newer versions of Linux continue to add new fields to
+// the end of this struct or within existing unused space, so its size grows
+// over time. The current iteration is based on linux v4.17. New versions are
+// always backwards compatible.
+type TCPInfo struct {
+ State uint8
+ CaState uint8
+ Retransmits uint8
+ Probes uint8
+ Backoff uint8
+ Options uint8
+ // WindowScale is the combination of snd_wscale (first 4 bits) and rcv_wscale (second 4 bits).
+ WindowScale uint8
+ // DeliveryRateAppLimited is a boolean and only the first bit is meaningful.
+ DeliveryRateAppLimited uint8
+
+ RTO uint32
+ ATO uint32
+ SndMss uint32
+ RcvMss uint32
+
+ Unacked uint32
+ Sacked uint32
+ Lost uint32
+ Retrans uint32
+ Fackets uint32
+
+ // Times.
+ LastDataSent uint32
+ LastAckSent uint32
+ LastDataRecv uint32
+ LastAckRecv uint32
+
+ // Metrics.
+ PMTU uint32
+ RcvSsthresh uint32
+ RTT uint32
+ RTTVar uint32
+ SndSsthresh uint32
+ SndCwnd uint32
+ Advmss uint32
+ Reordering uint32
+
+ RcvRTT uint32
+ RcvSpace uint32
+
+ TotalRetrans uint32
+
+ PacingRate uint64
+ MaxPacingRate uint64
+ // BytesAcked is RFC4898 tcpEStatsAppHCThruOctetsAcked.
+ BytesAcked uint64
+ // BytesReceived is RFC4898 tcpEStatsAppHCThruOctetsReceived.
+ BytesReceived uint64
+ // SegsOut is RFC4898 tcpEStatsPerfSegsOut.
+ SegsOut uint32
+ // SegsIn is RFC4898 tcpEStatsPerfSegsIn.
+ SegsIn uint32
+
+ NotSentBytes uint32
+ MinRTT uint32
+ // DataSegsIn is RFC4898 tcpEStatsDataSegsIn.
+ DataSegsIn uint32
+ // DataSegsOut is RFC4898 tcpEStatsDataSegsOut.
+ DataSegsOut uint32
+
+ DeliveryRate uint64
+
+ // BusyTime is the time in microseconds busy sending data.
+ BusyTime uint64
+ // RwndLimited is the time in microseconds limited by receive window.
+ RwndLimited uint64
+ // SndBufLimited is the time in microseconds limited by send buffer.
+ SndBufLimited uint64
+}
+
+// SizeOfTCPInfo is the binary size of a TCPInfo struct.
+var SizeOfTCPInfo = int(binary.Size(TCPInfo{}))
+
+// Control message types, from linux/socket.h.
+const (
+ SCM_CREDENTIALS = 0x2
+ SCM_RIGHTS = 0x1
+)
+
+// A ControlMessageHeader is the header for a socket control message.
+//
+// ControlMessageHeader represents struct cmsghdr from linux/socket.h.
+type ControlMessageHeader struct {
+ Length uint64
+ Level int32
+ Type int32
+}
+
+// SizeOfControlMessageHeader is the binary size of a ControlMessageHeader
+// struct.
+var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
+
+// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
+//
+// ControlMessageCredentials represents struct ucred from linux/socket.h.
+type ControlMessageCredentials struct {
+ PID int32
+ UID uint32
+ GID uint32
+}
+
+// A ControlMessageIPPacketInfo is an IP_PKTINFO socket control message.
+//
+// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+type ControlMessageIPPacketInfo struct {
+ NIC int32
+ LocalAddr InetAddr
+ DestinationAddr InetAddr
+}
+
+// SizeOfControlMessageCredentials is the binary size of a
+// ControlMessageCredentials struct.
+var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
+
+// A ControlMessageRights is an SCM_RIGHTS socket control message.
+type ControlMessageRights []int32
+
+// SizeOfControlMessageRight is the size of a single element in
+// ControlMessageRights.
+const SizeOfControlMessageRight = 4
+
+// SizeOfControlMessageInq is the size of a TCP_INQ control message.
+const SizeOfControlMessageInq = 4
+
+// SizeOfControlMessageTOS is the size of an IP_TOS control message.
+const SizeOfControlMessageTOS = 1
+
+// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
+const SizeOfControlMessageTClass = 4
+
+// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO
+// control message.
+const SizeOfControlMessageIPPacketInfo = 12
+
+// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
+// From net/scm.h.
+const SCM_MAX_FD = 253
+
+// SO_ACCEPTCON is defined as __SO_ACCEPTCON in
+// include/uapi/linux/net.h, which represents a listening socket
+// state. Note that this is distinct from SO_ACCEPTCONN, which is a
+// socket option for querying whether a socket is in a listening
+// state.
+const SO_ACCEPTCON = 1 << 16
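
These mirror structures are plain fixed-layout values. As a rough usage sketch (assuming AF_INET = 2 from the address-family constants defined earlier in this package, and noting that struct sockaddr_in keeps its port in network byte order), filling in a SockAddrInet for 127.0.0.1:8080 might look like:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
)

// htons converts a port to network (big-endian) byte order, which is the
// layout struct sockaddr_in expects.
func htons(v uint16) uint16 { return v<<8 | v>>8 }

func main() {
	addr := linux.SockAddrInet{
		Family: 2, // AF_INET, assumed defined earlier in this package.
		Port:   htons(8080),
		Addr:   linux.InetAddr{127, 0, 0, 1},
	}
	fmt.Printf("%+v\n", addr)
}
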
diff --git a/pkg/abi/linux/splice.go b/pkg/abi/linux/splice.go
new file mode 100644
index 000000000..650eb87e8
--- /dev/null
+++ b/pkg/abi/linux/splice.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for splice(2), sendfile(2) and tee(2).
+const (
+ SPLICE_F_MOVE = 1 << iota
+ SPLICE_F_NONBLOCK
+ SPLICE_F_MORE
+ SPLICE_F_GIFT
+)
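
Because the flags are built with 1 << iota, they expand to the same values as Linux's defines (SPLICE_F_MOVE = 0x1, SPLICE_F_NONBLOCK = 0x2, SPLICE_F_MORE = 0x4, SPLICE_F_GIFT = 0x8), and callers combine them with bitwise OR. A quick check:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
)

func main() {
	fmt.Println(linux.SPLICE_F_MOVE, linux.SPLICE_F_NONBLOCK,
		linux.SPLICE_F_MORE, linux.SPLICE_F_GIFT) // 1 2 4 8
}
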
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
new file mode 100644
index 000000000..2a8d4708b
--- /dev/null
+++ b/pkg/abi/linux/tcp.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Socket options from uapi/linux/tcp.h.
+const (
+ TCP_NODELAY = 1
+ TCP_MAXSEG = 2
+ TCP_CORK = 3
+ TCP_KEEPIDLE = 4
+ TCP_KEEPINTVL = 5
+ TCP_KEEPCNT = 6
+ TCP_SYNCNT = 7
+ TCP_LINGER2 = 8
+ TCP_DEFER_ACCEPT = 9
+ TCP_WINDOW_CLAMP = 10
+ TCP_INFO = 11
+ TCP_QUICKACK = 12
+ TCP_CONGESTION = 13
+ TCP_MD5SIG = 14
+ TCP_THIN_LINEAR_TIMEOUTS = 16
+ TCP_THIN_DUPACK = 17
+ TCP_USER_TIMEOUT = 18
+ TCP_REPAIR = 19
+ TCP_REPAIR_QUEUE = 20
+ TCP_QUEUE_SEQ = 21
+ TCP_REPAIR_OPTIONS = 22
+ TCP_FASTOPEN = 23
+ TCP_TIMESTAMP = 24
+ TCP_NOTSENT_LOWAT = 25
+ TCP_CC_INFO = 26
+ TCP_SAVE_SYN = 27
+ TCP_SAVED_SYN = 28
+ TCP_REPAIR_WINDOW = 29
+ TCP_FASTOPEN_CONNECT = 30
+ TCP_ULP = 31
+ TCP_MD5SIG_EXT = 32
+ TCP_FASTOPEN_KEY = 33
+ TCP_FASTOPEN_NO_COOKIE = 34
+ TCP_ZEROCOPY_RECEIVE = 35
+ TCP_INQ = 36
+)
+
+// Socket constants from include/net/tcp.h.
+const (
+ MAX_TCP_KEEPIDLE = 32767
+ MAX_TCP_KEEPINTVL = 32767
+ MAX_TCP_KEEPCNT = 127
+)
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
new file mode 100644
index 000000000..e6860ed49
--- /dev/null
+++ b/pkg/abi/linux/time.go
@@ -0,0 +1,270 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "math"
+ "time"
+)
+
+const (
+ // ClockTick is the length of time represented by a single clock tick, as
+ // used by times(2) and /proc/[pid]/stat.
+ ClockTick = time.Second / CLOCKS_PER_SEC
+
+ // CLOCKS_PER_SEC is the number of ClockTicks per second.
+ //
+ // Linux defines this to be 100 on most architectures, irrespective of
+ // CONFIG_HZ. Userspace obtains the value through sysconf(_SC_CLK_TCK),
+ // which uses the AT_CLKTCK entry in the auxiliary vector if one is
+ // provided, and assumes 100 otherwise (glibc:
+ // sysdeps/posix/sysconf.c:__sysconf() =>
+ // sysdeps/unix/sysv/linux/getclktck.c, elf/dl-support.c:_dl_aux_init()).
+ //
+ // Not to be confused with POSIX CLOCKS_PER_SEC, as used by clock(3); "XSI
+ // requires that [POSIX] CLOCKS_PER_SEC equals 1000000 independent of the
+ // actual resolution" - clock(3).
+ CLOCKS_PER_SEC = 100
+)
+
+// CPU clock types for use with clock_gettime(2) et al.
+//
+// The 29 most significant bits of a 32-bit clock ID are either a PID or an FD.
+//
+// Bits 1 and 0 give the type: PROF=0, VIRT=1, SCHED=2, or FD=3.
+//
+// Bit 2 indicates whether a cpu clock refers to a thread or a process.
+const (
+ CPUCLOCK_PROF = 0
+ CPUCLOCK_VIRT = 1
+ CPUCLOCK_SCHED = 2
+ CPUCLOCK_MAX = 3
+ CLOCKFD = CPUCLOCK_MAX
+
+ CPUCLOCK_CLOCK_MASK = 3
+ CPUCLOCK_PERTHREAD_MASK = 4
+)
+
+// Clock identifiers for use with clock_gettime(2), clock_getres(2),
+// clock_nanosleep(2).
+const (
+ CLOCK_REALTIME = 0
+ CLOCK_MONOTONIC = 1
+ CLOCK_PROCESS_CPUTIME_ID = 2
+ CLOCK_THREAD_CPUTIME_ID = 3
+ CLOCK_MONOTONIC_RAW = 4
+ CLOCK_REALTIME_COARSE = 5
+ CLOCK_MONOTONIC_COARSE = 6
+ CLOCK_BOOTTIME = 7
+ CLOCK_REALTIME_ALARM = 8
+ CLOCK_BOOTTIME_ALARM = 9
+)
+
+// Flags for clock_nanosleep(2).
+const (
+ TIMER_ABSTIME = 1
+)
+
+// Flags for timerfd syscalls (timerfd_create(2), timerfd_settime(2)).
+const (
+ // TFD_CLOEXEC is a timerfd_create flag.
+ TFD_CLOEXEC = O_CLOEXEC
+
+ // TFD_NONBLOCK is a timerfd_create flag.
+ TFD_NONBLOCK = O_NONBLOCK
+
+ // TFD_TIMER_ABSTIME is a timerfd_settime flag.
+ TFD_TIMER_ABSTIME = 1
+)
+
+// maxSecInDuration is the maximum number of seconds whose nanosecond
+// representation still fits in an int64.
+const maxSecInDuration = math.MaxInt64 / int64(time.Second)
+
+// TimeT represents time_t in <time.h>. It represents time in seconds.
+type TimeT int64
+
+// NsecToTimeT translates nanoseconds to TimeT (seconds).
+func NsecToTimeT(nsec int64) TimeT {
+ return TimeT(nsec / 1e9)
+}
+
+// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Unix returns the second and nanosecond.
+func (ts Timespec) Unix() (sec int64, nsec int64) {
+ return int64(ts.Sec), int64(ts.Nsec)
+}
+
+// ToTime returns the Go time.Time representation.
+func (ts Timespec) ToTime() time.Time {
+ return time.Unix(ts.Sec, ts.Nsec)
+}
+
+// ToNsec returns the nanosecond representation.
+func (ts Timespec) ToNsec() int64 {
+ return int64(ts.Sec)*1e9 + int64(ts.Nsec)
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (ts Timespec) ToNsecCapped() int64 {
+ if ts.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return ts.ToNsec()
+}
+
+// ToDuration returns the safe nanosecond representation as time.Duration.
+func (ts Timespec) ToDuration() time.Duration {
+ return time.Duration(ts.ToNsecCapped())
+}
+
+// Valid returns whether the timespec contains valid values.
+func (ts Timespec) Valid() bool {
+ return !(ts.Sec < 0 || ts.Nsec < 0 || ts.Nsec >= int64(time.Second))
+}
+
+// NsecToTimespec translates nanoseconds to Timespec.
+func NsecToTimespec(nsec int64) (ts Timespec) {
+ ts.Sec = nsec / 1e9
+ ts.Nsec = nsec % 1e9
+ return
+}
+
+// DurationToTimespec translates time.Duration to Timespec.
+func DurationToTimespec(dur time.Duration) Timespec {
+ return NsecToTimespec(dur.Nanoseconds())
+}
+
+// SizeOfTimeval is the size of a Timeval struct in bytes.
+const SizeOfTimeval = 16
+
+// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
+type Timeval struct {
+ Sec int64
+ Usec int64
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (tv Timeval) ToNsecCapped() int64 {
+ if tv.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return int64(tv.Sec)*1e9 + int64(tv.Usec)*1e3
+}
+
+// ToDuration returns the safe nanosecond representation as a time.Duration.
+func (tv Timeval) ToDuration() time.Duration {
+ return time.Duration(tv.ToNsecCapped())
+}
+
+// ToTime returns the Go time.Time representation.
+func (tv Timeval) ToTime() time.Time {
+ return time.Unix(tv.Sec, tv.Usec*1e3)
+}
+
+// NsecToTimeval translates nanoseconds to Timeval.
+func NsecToTimeval(nsec int64) (tv Timeval) {
+ nsec += 999 // round up to microsecond
+ tv.Sec = nsec / 1e9
+ tv.Usec = nsec % 1e9 / 1e3
+ return
+}
+
+// DurationToTimeval translates time.Duration to Timeval.
+func DurationToTimeval(dur time.Duration) Timeval {
+ return NsecToTimeval(dur.Nanoseconds())
+}
+
+// Itimerspec represents struct itimerspec in <time.h>.
+type Itimerspec struct {
+ Interval Timespec
+ Value Timespec
+}
+
+// ItimerVal mimics the following struct in <sys/time.h>
+// struct itimerval {
+// struct timeval it_interval; /* next value */
+// struct timeval it_value; /* current value */
+// };
+type ItimerVal struct {
+ Interval Timeval
+ Value Timeval
+}
+
+// ClockT represents type clock_t.
+type ClockT int64
+
+// ClockTFromDuration converts time.Duration to clock_t.
+func ClockTFromDuration(d time.Duration) ClockT {
+ return ClockT(d / ClockTick)
+}
+
+// Tms represents struct tms, used by times(2).
+type Tms struct {
+ UTime ClockT
+ STime ClockT
+ CUTime ClockT
+ CSTime ClockT
+}
+
+// TimerID represents type timer_t, which identifies a POSIX per-process
+// interval timer.
+type TimerID int32
+
+// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
+type StatxTimestamp struct {
+ Sec int64
+ Nsec uint32
+ _ int32
+}
+
+// ToNsec returns the nanosecond representation.
+func (sxts StatxTimestamp) ToNsec() int64 {
+ return int64(sxts.Sec)*1e9 + int64(sxts.Nsec)
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (sxts StatxTimestamp) ToNsecCapped() int64 {
+ if sxts.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return sxts.ToNsec()
+}
+
+// NsecToStatxTimestamp translates nanoseconds to StatxTimestamp.
+func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
+ return StatxTimestamp{
+ Sec: nsec / 1e9,
+ Nsec: uint32(nsec % 1e9),
+ }
+}
+
+// Utime represents struct utimbuf used by utime(2).
+//
+// +marshal
+type Utime struct {
+ Actime int64
+ Modtime int64
+}
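
A short sketch of the conversion helpers above, showing the tick length, the microsecond round-up in NsecToTimeval, and a Timespec round trip (the values are arbitrary):

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/abi/linux"
)

func main() {
	// ClockTick is time.Second / CLOCKS_PER_SEC.
	fmt.Println(linux.ClockTick) // 10ms

	// NsecToTimeval rounds up to the next microsecond: 1500ns -> 2us.
	tv := linux.NsecToTimeval(1500)
	fmt.Println(tv.Sec, tv.Usec) // 0 2

	// Timespec round trip through the capped duration helpers.
	ts := linux.DurationToTimespec(2500 * time.Millisecond)
	fmt.Println(ts.Sec, ts.Nsec, ts.ToDuration()) // 2 500000000 2.5s
}
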
diff --git a/pkg/abi/linux/timer.go b/pkg/abi/linux/timer.go
new file mode 100644
index 000000000..e32d09e10
--- /dev/null
+++ b/pkg/abi/linux/timer.go
@@ -0,0 +1,23 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// itimer types for getitimer(2) and setitimer(2), from
+// include/uapi/linux/time.h.
+const (
+ ITIMER_REAL = 0
+ ITIMER_VIRTUAL = 1
+ ITIMER_PROF = 2
+)
diff --git a/pkg/abi/linux/tty.go b/pkg/abi/linux/tty.go
new file mode 100644
index 000000000..8ac02aee8
--- /dev/null
+++ b/pkg/abi/linux/tty.go
@@ -0,0 +1,344 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // NumControlCharacters is the number of control characters in Termios.
+ NumControlCharacters = 19
+ // disabledChar is used to indicate that a control character is
+ // disabled.
+ disabledChar = 0
+)
+
+// Winsize is struct winsize, defined in uapi/asm-generic/termios.h.
+type Winsize struct {
+ Row uint16
+ Col uint16
+ Xpixel uint16
+ Ypixel uint16
+}
+
+// Termios is struct termios, defined in uapi/asm-generic/termbits.h.
+type Termios struct {
+ InputFlags uint32
+ OutputFlags uint32
+ ControlFlags uint32
+ LocalFlags uint32
+ LineDiscipline uint8
+ ControlCharacters [NumControlCharacters]uint8
+}
+
+// KernelTermios is struct ktermios/struct termios2, defined in
+// uapi/asm-generic/termbits.h.
+//
+// +stateify savable
+type KernelTermios struct {
+ InputFlags uint32
+ OutputFlags uint32
+ ControlFlags uint32
+ LocalFlags uint32
+ LineDiscipline uint8
+ ControlCharacters [NumControlCharacters]uint8
+ InputSpeed uint32
+ OutputSpeed uint32
+}
+
+// IEnabled returns whether flag is enabled in termios input flags.
+func (t *KernelTermios) IEnabled(flag uint32) bool {
+ return t.InputFlags&flag == flag
+}
+
+// OEnabled returns whether flag is enabled in termios output flags.
+func (t *KernelTermios) OEnabled(flag uint32) bool {
+ return t.OutputFlags&flag == flag
+}
+
+// CEnabled returns whether flag is enabled in termios control flags.
+func (t *KernelTermios) CEnabled(flag uint32) bool {
+ return t.ControlFlags&flag == flag
+}
+
+// LEnabled returns whether flag is enabled in termios local flags.
+func (t *KernelTermios) LEnabled(flag uint32) bool {
+ return t.LocalFlags&flag == flag
+}
+
+// ToTermios copies fields that are shared with Termios into a new Termios
+// struct.
+func (t *KernelTermios) ToTermios() Termios {
+ return Termios{
+ InputFlags: t.InputFlags,
+ OutputFlags: t.OutputFlags,
+ ControlFlags: t.ControlFlags,
+ LocalFlags: t.LocalFlags,
+ LineDiscipline: t.LineDiscipline,
+ ControlCharacters: t.ControlCharacters,
+ }
+}
+
+// FromTermios copies fields that are shared with Termios into this
+// KernelTermios struct.
+func (t *KernelTermios) FromTermios(term Termios) {
+ t.InputFlags = term.InputFlags
+ t.OutputFlags = term.OutputFlags
+ t.ControlFlags = term.ControlFlags
+ t.LocalFlags = term.LocalFlags
+ t.LineDiscipline = term.LineDiscipline
+ t.ControlCharacters = term.ControlCharacters
+}
+
+// IsTerminating returns whether cBytes is a line terminating character.
+func (t *KernelTermios) IsTerminating(cBytes []byte) bool {
+ // All terminating characters are 1 byte.
+ if len(cBytes) != 1 {
+ return false
+ }
+ c := cBytes[0]
+
+ // Is this the user-set EOF character?
+ if t.IsEOF(c) {
+ return true
+ }
+
+ switch c {
+ case disabledChar:
+ return false
+ case '\n', t.ControlCharacters[VEOL]:
+ return true
+ case t.ControlCharacters[VEOL2]:
+ return t.LEnabled(IEXTEN)
+ }
+ return false
+}
+
+// IsEOF returns whether c is the EOF character.
+func (t *KernelTermios) IsEOF(c byte) bool {
+ return c == t.ControlCharacters[VEOF] && t.ControlCharacters[VEOF] != disabledChar
+}
+
+// Input flags.
+const (
+ IGNBRK = 0000001
+ BRKINT = 0000002
+ IGNPAR = 0000004
+ PARMRK = 0000010
+ INPCK = 0000020
+ ISTRIP = 0000040
+ INLCR = 0000100
+ IGNCR = 0000200
+ ICRNL = 0000400
+ IUCLC = 0001000
+ IXON = 0002000
+ IXANY = 0004000
+ IXOFF = 0010000
+ IMAXBEL = 0020000
+ IUTF8 = 0040000
+)
+
+// Output flags.
+const (
+ OPOST = 0000001
+ OLCUC = 0000002
+ ONLCR = 0000004
+ OCRNL = 0000010
+ ONOCR = 0000020
+ ONLRET = 0000040
+ OFILL = 0000100
+ OFDEL = 0000200
+ NLDLY = 0000400
+ NL0 = 0000000
+ NL1 = 0000400
+ CRDLY = 0003000
+ CR0 = 0000000
+ CR1 = 0001000
+ CR2 = 0002000
+ CR3 = 0003000
+ TABDLY = 0014000
+ TAB0 = 0000000
+ TAB1 = 0004000
+ TAB2 = 0010000
+ TAB3 = 0014000
+ XTABS = 0014000
+ BSDLY = 0020000
+ BS0 = 0000000
+ BS1 = 0020000
+ VTDLY = 0040000
+ VT0 = 0000000
+ VT1 = 0040000
+ FFDLY = 0100000
+ FF0 = 0000000
+ FF1 = 0100000
+)
+
+// Control flags.
+const (
+ CBAUD = 0010017
+ B0 = 0000000
+ B50 = 0000001
+ B75 = 0000002
+ B110 = 0000003
+ B134 = 0000004
+ B150 = 0000005
+ B200 = 0000006
+ B300 = 0000007
+ B600 = 0000010
+ B1200 = 0000011
+ B1800 = 0000012
+ B2400 = 0000013
+ B4800 = 0000014
+ B9600 = 0000015
+ B19200 = 0000016
+ B38400 = 0000017
+ EXTA = B19200
+ EXTB = B38400
+ CSIZE = 0000060
+ CS5 = 0000000
+ CS6 = 0000020
+ CS7 = 0000040
+ CS8 = 0000060
+ CSTOPB = 0000100
+ CREAD = 0000200
+ PARENB = 0000400
+ PARODD = 0001000
+ HUPCL = 0002000
+ CLOCAL = 0004000
+ CBAUDEX = 0010000
+ BOTHER = 0010000
+ B57600 = 0010001
+ B115200 = 0010002
+ B230400 = 0010003
+ B460800 = 0010004
+ B500000 = 0010005
+ B576000 = 0010006
+ B921600 = 0010007
+ B1000000 = 0010010
+ B1152000 = 0010011
+ B1500000 = 0010012
+ B2000000 = 0010013
+ B2500000 = 0010014
+ B3000000 = 0010015
+ B3500000 = 0010016
+ B4000000 = 0010017
+ CIBAUD = 002003600000
+ CMSPAR = 010000000000
+ CRTSCTS = 020000000000
+
+ // IBSHIFT is the shift from CBAUD to CIBAUD.
+ IBSHIFT = 16
+)
+
+// Local flags.
+const (
+ ISIG = 0000001
+ ICANON = 0000002
+ XCASE = 0000004
+ ECHO = 0000010
+ ECHOE = 0000020
+ ECHOK = 0000040
+ ECHONL = 0000100
+ NOFLSH = 0000200
+ TOSTOP = 0000400
+ ECHOCTL = 0001000
+ ECHOPRT = 0002000
+ ECHOKE = 0004000
+ FLUSHO = 0010000
+ PENDIN = 0040000
+ IEXTEN = 0100000
+ EXTPROC = 0200000
+)
+
+// Control Character indices.
+const (
+ VINTR = 0
+ VQUIT = 1
+ VERASE = 2
+ VKILL = 3
+ VEOF = 4
+ VTIME = 5
+ VMIN = 6
+ VSWTC = 7
+ VSTART = 8
+ VSTOP = 9
+ VSUSP = 10
+ VEOL = 11
+ VREPRINT = 12
+ VDISCARD = 13
+ VWERASE = 14
+ VLNEXT = 15
+ VEOL2 = 16
+)
+
+// ControlCharacter returns the termios-style control character for the passed
+// character.
+//
+// e.g., for Ctrl-C, i.e., ^C, call ControlCharacter('C').
+//
+// Standard control characters are ASCII bytes 0 through 31.
+func ControlCharacter(c byte) uint8 {
+ // A is 1, B is 2, etc.
+ return uint8(c - 'A' + 1)
+}
+
+// DefaultControlCharacters is the default set of Termios control characters.
+var DefaultControlCharacters = [NumControlCharacters]uint8{
+ ControlCharacter('C'), // VINTR = ^C
+ ControlCharacter('\\'), // VQUIT = ^\
+ '\x7f', // VERASE = DEL
+ ControlCharacter('U'), // VKILL = ^U
+ ControlCharacter('D'), // VEOF = ^D
+ 0, // VTIME
+ 1, // VMIN
+ 0, // VSWTC
+ ControlCharacter('Q'), // VSTART = ^Q
+ ControlCharacter('S'), // VSTOP = ^S
+ ControlCharacter('Z'), // VSUSP = ^Z
+ 0, // VEOL
+ ControlCharacter('R'), // VREPRINT = ^R
+ ControlCharacter('O'), // VDISCARD = ^O
+ ControlCharacter('W'), // VWERASE = ^W
+ ControlCharacter('V'), // VLNEXT = ^V
+ 0, // VEOL2
+}
+
+// MasterTermios is the terminal configuration of the master end of a Unix98
+// pseudoterminal.
+var MasterTermios = KernelTermios{
+ ControlFlags: B38400 | CS8 | CREAD,
+ ControlCharacters: DefaultControlCharacters,
+ InputSpeed: 38400,
+ OutputSpeed: 38400,
+}
+
+// DefaultSlaveTermios is the default terminal configuration of the slave end
+// of a Unix98 pseudoterminal.
+var DefaultSlaveTermios = KernelTermios{
+ InputFlags: ICRNL | IXON,
+ OutputFlags: OPOST | ONLCR,
+ ControlFlags: B38400 | CS8 | CREAD,
+ LocalFlags: ISIG | ICANON | ECHO | ECHOE | ECHOK | ECHOCTL | ECHOKE | IEXTEN,
+ ControlCharacters: DefaultControlCharacters,
+ InputSpeed: 38400,
+ OutputSpeed: 38400,
+}
+
+// WindowSize corresponds to struct winsize defined in
+// include/uapi/asm-generic/termios.h.
+//
+// +stateify savable
+type WindowSize struct {
+ Rows uint16
+ Cols uint16
+ _ [4]byte // Padding for 2 unused shorts.
+}
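
To make the ControlCharacter encoding concrete: control characters occupy ASCII 0 through 31, and ^A maps to 1, so ^C is 'C' - 'A' + 1 = 3 (ETX). A minimal check against the defaults above:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
)

func main() {
	fmt.Println(linux.ControlCharacter('C')) // 3, i.e. ^C (the VINTR default)

	// The default VEOF entry is ^D (ASCII EOT).
	fmt.Println(linux.DefaultControlCharacters[linux.VEOF]) // 4
}
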
diff --git a/pkg/abi/linux/uio.go b/pkg/abi/linux/uio.go
new file mode 100644
index 000000000..1fd1e9802
--- /dev/null
+++ b/pkg/abi/linux/uio.go
@@ -0,0 +1,18 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// UIO_MAXIOV is the maximum number of struct iovec entries in a struct iovec array.
+const UIO_MAXIOV = 1024
diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go
new file mode 100644
index 000000000..60f220a67
--- /dev/null
+++ b/pkg/abi/linux/utsname.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // UTSLen is the maximum length of strings contained in fields of
+ // UtsName.
+ UTSLen = 64
+)
+
+// UtsName represents struct utsname, the struct returned by uname(2).
+type UtsName struct {
+ Sysname [UTSLen + 1]byte
+ Nodename [UTSLen + 1]byte
+ Release [UTSLen + 1]byte
+ Version [UTSLen + 1]byte
+ Machine [UTSLen + 1]byte
+ Domainname [UTSLen + 1]byte
+}
+
+// utsNameString converts a UtsName entry to a string without NULs.
+func utsNameString(s [UTSLen + 1]byte) string {
+ // The NUL bytes will remain even in a cast to string. We must
+ // explicitly strip them.
+ return string(bytes.TrimRight(s[:], "\x00"))
+}
+
+func (u UtsName) String() string {
+ return fmt.Sprintf("{Sysname: %s, Nodename: %s, Release: %s, Version: %s, Machine: %s, Domainname: %s}",
+ utsNameString(u.Sysname), utsNameString(u.Nodename), utsNameString(u.Release),
+ utsNameString(u.Version), utsNameString(u.Machine), utsNameString(u.Domainname))
+}
diff --git a/pkg/abi/linux/wait.go b/pkg/abi/linux/wait.go
new file mode 100644
index 000000000..4bdc280d1
--- /dev/null
+++ b/pkg/abi/linux/wait.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Options for waitpid(2), wait4(2), and/or waitid(2), from
+// include/uapi/linux/wait.h.
+const (
+ WNOHANG = 0x00000001
+ WUNTRACED = 0x00000002
+ WSTOPPED = WUNTRACED
+ WEXITED = 0x00000004
+ WCONTINUED = 0x00000008
+ WNOWAIT = 0x01000000
+ WNOTHREAD = 0x20000000
+ WALL = 0x40000000
+ WCLONE = 0x80000000
+)
+
+// ID types for waitid(2), from include/uapi/linux/wait.h.
+const (
+ P_ALL = 0x0
+ P_PID = 0x1
+ P_PGID = 0x2
+)
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
new file mode 100644
index 000000000..99180b208
--- /dev/null
+++ b/pkg/abi/linux/xattr.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for extended attributes.
+const (
+ XATTR_NAME_MAX = 255
+ XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
+
+ XATTR_CREATE = 1
+ XATTR_REPLACE = 2
+
+ XATTR_USER_PREFIX = "user."
+ XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX)
+)
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
new file mode 100644
index 000000000..ffc918846
--- /dev/null
+++ b/pkg/amutex/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "amutex",
+ srcs = ["amutex.go"],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/syserror"],
+)
+
+go_test(
+ name = "amutex_test",
+ size = "small",
+ srcs = ["amutex_test.go"],
+ library = ":amutex",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
new file mode 100644
index 000000000..a078a31db
--- /dev/null
+++ b/pkg/amutex/amutex.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package amutex provides the implementation of an abortable mutex. It allows
+// the Lock() function to be canceled while it waits to acquire the mutex.
+package amutex
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sleeper must be implemented by users of the abortable mutex to allow for
+// cancellation of waits.
+type Sleeper interface {
+ // SleepStart is called by the AbortableMutex.Lock() function when the
+ // mutex is contended and the goroutine is about to sleep.
+ //
+ // A channel may be returned; the sleep is canceled if the channel
+ // becomes readable. If no cancellation is desired, nil can be returned.
+ SleepStart() <-chan struct{}
+
+ // SleepFinish is called by AbortableMutex.Lock() once a contended mutex
+ // is acquired or the wait is aborted.
+ SleepFinish(success bool)
+
+ // Interrupted returns true if the wait is aborted.
+ Interrupted() bool
+}
+
+// NoopSleeper is a stateless no-op implementation of Sleeper for anonymous
+// embedding in other types that do not support cancellation.
+type NoopSleeper struct{}
+
+// SleepStart implements Sleeper.SleepStart.
+func (NoopSleeper) SleepStart() <-chan struct{} {
+ return nil
+}
+
+// SleepFinish implements Sleeper.SleepFinish.
+func (NoopSleeper) SleepFinish(success bool) {}
+
+// Interrupted implements Sleeper.Interrupted.
+func (NoopSleeper) Interrupted() bool { return false }
+
+// Block blocks until either receiving from ch succeeds (in which case it
+// returns nil) or sleeper is interrupted (in which case it returns
+// syserror.ErrInterrupted).
+func Block(sleeper Sleeper, ch <-chan struct{}) error {
+ cancel := sleeper.SleepStart()
+ select {
+ case <-ch:
+ sleeper.SleepFinish(true)
+ return nil
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ return syserror.ErrInterrupted
+ }
+}
+
+// AbortableMutex is an abortable mutex. It allows Lock() to be aborted while it
+// waits to acquire the mutex.
+type AbortableMutex struct {
+ v int32
+ ch chan struct{}
+}
+
+// Init initializes the abortable mutex.
+func (m *AbortableMutex) Init() {
+ m.v = 1
+ m.ch = make(chan struct{}, 1)
+}
+
+// Lock attempts to acquire the mutex, returning true on success. If the
+// channel returned by the Sleeper's SleepStart becomes readable while Lock
+// waits, the wait is aborted and false is returned instead.
+func (m *AbortableMutex) Lock(s Sleeper) bool {
+ // Uncontended case.
+ if atomic.AddInt32(&m.v, -1) == 0 {
+ return true
+ }
+
+ var c <-chan struct{}
+ if s != nil {
+ c = s.SleepStart()
+ }
+
+ for {
+ // Try to acquire the mutex again, at the same time making sure
+ // that m.v is negative, which indicates to the owner of the
+ // lock that it is contended and will force it to try to wake
+ // someone up when it releases the mutex.
+ if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
+ if s != nil {
+ s.SleepFinish(true)
+ }
+ return true
+ }
+
+ // Wait for the owner to wake us up before trying again, or for
+ // the wait to be aborted by the provided channel.
+ select {
+ case <-m.ch:
+ case <-c:
+ // s must be non-nil, otherwise c would be nil and we'd
+ // never reach this path.
+ s.SleepFinish(false)
+ return false
+ }
+ }
+}
+
+// Unlock releases the mutex.
+func (m *AbortableMutex) Unlock() {
+ if atomic.SwapInt32(&m.v, 1) == 0 {
+ // There were no pending waiters.
+ return
+ }
+
+ // Wake some waiter up.
+ select {
+ case m.ch <- struct{}{}:
+ default:
+ }
+}
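
A minimal sketch of Block with the package's NoopSleeper: SleepStart returns nil, and receiving from a nil channel blocks forever, so the select inside Block can only complete via ch and the error here is always nil.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/amutex"
)

func main() {
	ch := make(chan struct{}, 1)
	ch <- struct{}{} // the awaited event is immediately available

	if err := amutex.Block(amutex.NoopSleeper{}, ch); err != nil {
		fmt.Println("interrupted:", err)
		return
	}
	fmt.Println("received")
}
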
diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go
new file mode 100644
index 000000000..8a3952f2a
--- /dev/null
+++ b/pkg/amutex/amutex_test.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package amutex
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+type sleeper struct {
+ ch chan struct{}
+}
+
+func (s *sleeper) SleepStart() <-chan struct{} {
+ return s.ch
+}
+
+func (*sleeper) SleepFinish(bool) {
+}
+
+func (s *sleeper) Interrupted() bool {
+ return len(s.ch) != 0
+}
+
+func TestMutualExclusion(t *testing.T) {
+ var m AbortableMutex
+ m.Init()
+
+ // Test mutual exclusion by running "gr" goroutines concurrently, and
+ // have each one increment a counter "iters" times within the critical
+ // section established by the mutex.
+ //
+ // If, at the end, the counter is not gr * iters, then we know that
+ // goroutines ran concurrently within the critical section.
+ //
+ // If one of the goroutines doesn't complete, it's likely a bug that
+ // causes it to wait forever.
+ const gr = 1000
+ const iters = 100000
+ v := 0
+ var wg sync.WaitGroup
+ for i := 0; i < gr; i++ {
+ wg.Add(1)
+ go func() {
+ for j := 0; j < iters; j++ {
+ m.Lock(nil)
+ v++
+ m.Unlock()
+ }
+ wg.Done()
+ }()
+ }
+
+ wg.Wait()
+
+ if v != gr*iters {
+ t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
+ }
+}
+
+func TestAbortWait(t *testing.T) {
+ var s sleeper
+ var m AbortableMutex
+ m.Init()
+
+ // Lock the mutex.
+ m.Lock(&s)
+
+ // Lock again, but this time cancel after 500ms.
+ s.ch = make(chan struct{}, 1)
+ go func() {
+ time.Sleep(500 * time.Millisecond)
+ s.ch <- struct{}{}
+ }()
+ if v := m.Lock(&s); v {
+ t.Fatalf("Lock succeeded when it should have failed")
+ }
+
+ // Lock again, but cancel right away.
+ s.ch <- struct{}{}
+ if v := m.Lock(&s); v {
+ t.Fatalf("Lock succeeded when it should have failed")
+ }
+}
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
new file mode 100644
index 000000000..1a30f6967
--- /dev/null
+++ b/pkg/atomicbitops/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "atomicbitops",
+ srcs = [
+ "atomicbitops.go",
+ "atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
+ "atomicbitops_noasm.go",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "atomicbitops_test",
+ size = "small",
+ srcs = ["atomicbitops_test.go"],
+ library = ":atomicbitops",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/atomicbitops/atomicbitops.go b/pkg/atomicbitops/atomicbitops.go
new file mode 100644
index 000000000..1be081719
--- /dev/null
+++ b/pkg/atomicbitops/atomicbitops.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 arm64
+
+// Package atomicbitops provides extensions to the sync/atomic package.
+//
+// All read-modify-write operations implemented by this package have
+// acquire-release memory ordering (like sync/atomic).
+package atomicbitops
+
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
+func AndUint32(addr *uint32, val uint32)
+
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
+func OrUint32(addr *uint32, val uint32)
+
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
+func XorUint32(addr *uint32, val uint32)
+
+// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
+
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
+func AndUint64(addr *uint64, val uint64)
+
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
+func OrUint64(addr *uint64, val uint64)
+
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
+func XorUint64(addr *uint64, val uint64)
+
+// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
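
A small usage sketch of these operations on a flags word (the flag bits are hypothetical, chosen only for illustration):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/atomicbitops"
)

const (
	flagReady  uint32 = 1 << 0 // hypothetical flag bits
	flagClosed uint32 = 1 << 1
)

func main() {
	var flags uint32

	atomicbitops.OrUint32(&flags, flagReady)   // set flagReady
	atomicbitops.AndUint32(&flags, ^flagReady) // clear flagReady
	atomicbitops.XorUint32(&flags, flagClosed) // toggle flagClosed

	// Unlike sync/atomic, CompareAndSwapUint32 returns the previous value
	// rather than a bool, so a failed swap reports what was actually there.
	prev := atomicbitops.CompareAndSwapUint32(&flags, flagClosed, 0)
	fmt.Println(flags, prev) // 0 2
}
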
diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
new file mode 100644
index 000000000..54c887ee5
--- /dev/null
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+#include "textflag.h"
+
+TEXT ·AndUint32(SB),$0-12
+ MOVQ addr+0(FP), BP
+ MOVL val+8(FP), AX
+ LOCK
+ ANDL AX, 0(BP)
+ RET
+
+TEXT ·OrUint32(SB),$0-12
+ MOVQ addr+0(FP), BP
+ MOVL val+8(FP), AX
+ LOCK
+ ORL AX, 0(BP)
+ RET
+
+TEXT ·XorUint32(SB),$0-12
+ MOVQ addr+0(FP), BP
+ MOVL val+8(FP), AX
+ LOCK
+ XORL AX, 0(BP)
+ RET
+
+TEXT ·CompareAndSwapUint32(SB),$0-20
+ MOVQ addr+0(FP), DI
+ MOVL old+8(FP), AX
+ MOVL new+12(FP), DX
+ LOCK
+ CMPXCHGL DX, 0(DI)
+ MOVL AX, ret+16(FP)
+ RET
+
+TEXT ·AndUint64(SB),$0-16
+ MOVQ addr+0(FP), BP
+ MOVQ val+8(FP), AX
+ LOCK
+ ANDQ AX, 0(BP)
+ RET
+
+TEXT ·OrUint64(SB),$0-16
+ MOVQ addr+0(FP), BP
+ MOVQ val+8(FP), AX
+ LOCK
+ ORQ AX, 0(BP)
+ RET
+
+TEXT ·XorUint64(SB),$0-16
+ MOVQ addr+0(FP), BP
+ MOVQ val+8(FP), AX
+ LOCK
+ XORQ AX, 0(BP)
+ RET
+
+TEXT ·CompareAndSwapUint64(SB),$0-32
+ MOVQ addr+0(FP), DI
+ MOVQ old+8(FP), AX
+ MOVQ new+16(FP), DX
+ LOCK
+ CMPXCHGQ DX, 0(DI)
+ MOVQ AX, ret+24(FP)
+ RET
diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
new file mode 100644
index 000000000..5c780851b
--- /dev/null
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -0,0 +1,105 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+#include "textflag.h"
+
+TEXT ·AndUint32(SB),$0-12
+ MOVD ptr+0(FP), R0
+ MOVW val+8(FP), R1
+again:
+ LDAXRW (R0), R2
+ ANDW R1, R2
+ STLXRW R2, (R0), R3
+ CBNZ R3, again
+ RET
+
+TEXT ·OrUint32(SB),$0-12
+ MOVD ptr+0(FP), R0
+ MOVW val+8(FP), R1
+again:
+ LDAXRW (R0), R2
+ ORRW R1, R2
+ STLXRW R2, (R0), R3
+ CBNZ R3, again
+ RET
+
+TEXT ·XorUint32(SB),$0-12
+ MOVD ptr+0(FP), R0
+ MOVW val+8(FP), R1
+again:
+ LDAXRW (R0), R2
+ EORW R1, R2
+ STLXRW R2, (R0), R3
+ CBNZ R3, again
+ RET
+
+TEXT ·CompareAndSwapUint32(SB),$0-20
+ MOVD addr+0(FP), R0
+ MOVW old+8(FP), R1
+ MOVW new+12(FP), R2
+again:
+ LDAXRW (R0), R3
+ CMPW R1, R3
+ BNE done
+ STLXRW R2, (R0), R4
+ CBNZ R4, again
+done:
+ MOVW R3, prev+16(FP)
+ RET
+
+TEXT ·AndUint64(SB),$0-16
+ MOVD ptr+0(FP), R0
+ MOVD val+8(FP), R1
+again:
+ LDAXR (R0), R2
+ AND R1, R2
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ RET
+
+TEXT ·OrUint64(SB),$0-16
+ MOVD ptr+0(FP), R0
+ MOVD val+8(FP), R1
+again:
+ LDAXR (R0), R2
+ ORR R1, R2
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ RET
+
+TEXT ·XorUint64(SB),$0-16
+ MOVD ptr+0(FP), R0
+ MOVD val+8(FP), R1
+again:
+ LDAXR (R0), R2
+ EOR R1, R2
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ RET
+
+TEXT ·CompareAndSwapUint64(SB),$0-32
+ MOVD addr+0(FP), R0
+ MOVD old+8(FP), R1
+ MOVD new+16(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE done
+ STLXR R2, (R0), R4
+ CBNZ R4, again
+done:
+ MOVD R3, prev+24(FP)
+ RET
diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go
new file mode 100644
index 000000000..3b2898256
--- /dev/null
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !amd64,!arm64
+
+package atomicbitops
+
+import (
+ "sync/atomic"
+)
+
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
+func AndUint32(addr *uint32, val uint32) {
+ for {
+ o := atomic.LoadUint32(addr)
+ n := o & val
+ if atomic.CompareAndSwapUint32(addr, o, n) {
+ break
+ }
+ }
+}
+
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
+func OrUint32(addr *uint32, val uint32) {
+ for {
+ o := atomic.LoadUint32(addr)
+ n := o | val
+ if atomic.CompareAndSwapUint32(addr, o, n) {
+ break
+ }
+ }
+}
+
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
+func XorUint32(addr *uint32, val uint32) {
+ for {
+ o := atomic.LoadUint32(addr)
+ n := o ^ val
+ if atomic.CompareAndSwapUint32(addr, o, n) {
+ break
+ }
+ }
+}
+
+// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
+ for {
+ prev = atomic.LoadUint32(addr)
+ if prev != old {
+ return
+ }
+ if atomic.CompareAndSwapUint32(addr, old, new) {
+ return
+ }
+ }
+}
+
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
+func AndUint64(addr *uint64, val uint64) {
+ for {
+ o := atomic.LoadUint64(addr)
+ n := o & val
+ if atomic.CompareAndSwapUint64(addr, o, n) {
+ break
+ }
+ }
+}
+
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
+func OrUint64(addr *uint64, val uint64) {
+ for {
+ o := atomic.LoadUint64(addr)
+ n := o | val
+ if atomic.CompareAndSwapUint64(addr, o, n) {
+ break
+ }
+ }
+}
+
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
+func XorUint64(addr *uint64, val uint64) {
+ for {
+ o := atomic.LoadUint64(addr)
+ n := o ^ val
+ if atomic.CompareAndSwapUint64(addr, o, n) {
+ break
+ }
+ }
+}
+
+// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
+ for {
+ prev = atomic.LoadUint64(addr)
+ if prev != old {
+ return
+ }
+ if atomic.CompareAndSwapUint64(addr, old, new) {
+ return
+ }
+ }
+}
diff --git a/pkg/atomicbitops/atomicbitops_test.go b/pkg/atomicbitops/atomicbitops_test.go
new file mode 100644
index 000000000..73af71bb4
--- /dev/null
+++ b/pkg/atomicbitops/atomicbitops_test.go
@@ -0,0 +1,198 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package atomicbitops
+
+import (
+ "runtime"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+const iterations = 100
+
+func detectRaces32(val, target uint32, fn func(*uint32, uint32)) bool {
+ runtime.GOMAXPROCS(100)
+ for n := 0; n < iterations; n++ {
+ x := val
+ var wg sync.WaitGroup
+ for i := uint32(0); i < 32; i++ {
+ wg.Add(1)
+ go func(a *uint32, i uint32) {
+ defer wg.Done()
+ fn(a, uint32(1<<i))
+ }(&x, i)
+ }
+ wg.Wait()
+ if x != target {
+ return true
+ }
+ }
+ return false
+}
+
+func detectRaces64(val, target uint64, fn func(*uint64, uint64)) bool {
+ runtime.GOMAXPROCS(100)
+ for n := 0; n < iterations; n++ {
+ x := val
+ var wg sync.WaitGroup
+ for i := uint64(0); i < 64; i++ {
+ wg.Add(1)
+ go func(a *uint64, i uint64) {
+ defer wg.Done()
+ fn(a, uint64(1<<i))
+ }(&x, i)
+ }
+ wg.Wait()
+ if x != target {
+ return true
+ }
+ }
+ return false
+}
+
+func TestOrUint32(t *testing.T) {
+ if detectRaces32(0x0, 0xffffffff, OrUint32) {
+ t.Error("Data race detected!")
+ }
+}
+
+func TestAndUint32(t *testing.T) {
+ if detectRaces32(0xf0f0f0f0, 0x00000000, AndUint32) {
+ t.Error("Data race detected!")
+ }
+}
+
+func TestXorUint32(t *testing.T) {
+ if detectRaces32(0xf0f0f0f0, 0x0f0f0f0f, XorUint32) {
+ t.Error("Data race detected!")
+ }
+}
+
+func TestOrUint64(t *testing.T) {
+ if detectRaces64(0x0, 0xffffffffffffffff, OrUint64) {
+ t.Error("Data race detected!")
+ }
+}
+
+func TestAndUint64(t *testing.T) {
+ if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0, AndUint64) {
+ t.Error("Data race detected!")
+ }
+}
+
+func TestXorUint64(t *testing.T) {
+ if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0f0f0f0f0f0f0f0f, XorUint64) {
+ t.Error("Data race detected!")
+ }
+}
+
+func TestCompareAndSwapUint32(t *testing.T) {
+ tests := []struct {
+ name string
+ prev uint32
+ old uint32
+ new uint32
+ next uint32
+ }{
+ {
+ name: "Successful compare-and-swap with prev == new",
+ prev: 10,
+ old: 10,
+ new: 10,
+ next: 10,
+ },
+ {
+ name: "Successful compare-and-swap with prev != new",
+ prev: 20,
+ old: 20,
+ new: 22,
+ next: 22,
+ },
+ {
+ name: "Failed compare-and-swap with prev == new",
+ prev: 31,
+ old: 30,
+ new: 31,
+ next: 31,
+ },
+ {
+ name: "Failed compare-and-swap with prev != new",
+ prev: 41,
+ old: 40,
+ new: 42,
+ next: 41,
+ },
+ }
+ for _, test := range tests {
+ val := test.prev
+ prev := CompareAndSwapUint32(&val, test.old, test.new)
+ if got, want := prev, test.prev; got != want {
+ t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want)
+ }
+ if got, want := val, test.next; got != want {
+ t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want)
+ }
+ }
+}
+
+func TestCompareAndSwapUint64(t *testing.T) {
+ tests := []struct {
+ name string
+ prev uint64
+ old uint64
+ new uint64
+ next uint64
+ }{
+ {
+ name: "Successful compare-and-swap with prev == new",
+ prev: 0x100000000,
+ old: 0x100000000,
+ new: 0x100000000,
+ next: 0x100000000,
+ },
+ {
+ name: "Successful compare-and-swap with prev != new",
+ prev: 0x200000000,
+ old: 0x200000000,
+ new: 0x200000002,
+ next: 0x200000002,
+ },
+ {
+ name: "Failed compare-and-swap with prev == new",
+ prev: 0x300000001,
+ old: 0x300000000,
+ new: 0x300000001,
+ next: 0x300000001,
+ },
+ {
+ name: "Failed compare-and-swap with prev != new",
+ prev: 0x400000001,
+ old: 0x400000000,
+ new: 0x400000002,
+ next: 0x400000001,
+ },
+ }
+ for _, test := range tests {
+ val := test.prev
+ prev := CompareAndSwapUint64(&val, test.old, test.new)
+ if got, want := prev, test.prev; got != want {
+ t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want)
+ }
+ if got, want := val, test.next; got != want {
+ t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want)
+ }
+ }
+}
diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD
new file mode 100644
index 000000000..7ca2fda90
--- /dev/null
+++ b/pkg/binary/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "binary",
+ srcs = ["binary.go"],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "binary_test",
+ size = "small",
+ srcs = ["binary_test.go"],
+ library = ":binary",
+)
diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go
new file mode 100644
index 000000000..25065aef9
--- /dev/null
+++ b/pkg/binary/binary.go
@@ -0,0 +1,266 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package binary translates between select fixed-sized types and a binary
+// representation.
+package binary
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "reflect"
+)
+
+// LittleEndian is the same as encoding/binary.LittleEndian.
+//
+// It is included here as a convenience.
+var LittleEndian = binary.LittleEndian
+
+// BigEndian is the same as encoding/binary.BigEndian.
+//
+// It is included here as a convenience.
+var BigEndian = binary.BigEndian
+
+// AppendUint16 appends the binary representation of a uint16 to buf.
+func AppendUint16(buf []byte, order binary.ByteOrder, num uint16) []byte {
+ buf = append(buf, make([]byte, 2)...)
+ order.PutUint16(buf[len(buf)-2:], num)
+ return buf
+}
+
+// AppendUint32 appends the binary representation of a uint32 to buf.
+func AppendUint32(buf []byte, order binary.ByteOrder, num uint32) []byte {
+ buf = append(buf, make([]byte, 4)...)
+ order.PutUint32(buf[len(buf)-4:], num)
+ return buf
+}
+
+// AppendUint64 appends the binary representation of a uint64 to buf.
+func AppendUint64(buf []byte, order binary.ByteOrder, num uint64) []byte {
+ buf = append(buf, make([]byte, 8)...)
+ order.PutUint64(buf[len(buf)-8:], num)
+ return buf
+}
+
+// Marshal appends a binary representation of data to buf.
+//
+// data must only contain fixed-length signed and unsigned ints, arrays,
+// slices, structs and compositions of said types. data may be a pointer,
+// but cannot contain pointers.
+func Marshal(buf []byte, order binary.ByteOrder, data interface{}) []byte {
+ return marshal(buf, order, reflect.Indirect(reflect.ValueOf(data)))
+}
+
+func marshal(buf []byte, order binary.ByteOrder, data reflect.Value) []byte {
+ switch data.Kind() {
+ case reflect.Int8:
+ buf = append(buf, byte(int8(data.Int())))
+ case reflect.Int16:
+ buf = AppendUint16(buf, order, uint16(int16(data.Int())))
+ case reflect.Int32:
+ buf = AppendUint32(buf, order, uint32(int32(data.Int())))
+ case reflect.Int64:
+ buf = AppendUint64(buf, order, uint64(data.Int()))
+
+ case reflect.Uint8:
+ buf = append(buf, byte(data.Uint()))
+ case reflect.Uint16:
+ buf = AppendUint16(buf, order, uint16(data.Uint()))
+ case reflect.Uint32:
+ buf = AppendUint32(buf, order, uint32(data.Uint()))
+ case reflect.Uint64:
+ buf = AppendUint64(buf, order, data.Uint())
+
+ case reflect.Array, reflect.Slice:
+ for i, l := 0, data.Len(); i < l; i++ {
+ buf = marshal(buf, order, data.Index(i))
+ }
+
+ case reflect.Struct:
+ for i, l := 0, data.NumField(); i < l; i++ {
+ buf = marshal(buf, order, data.Field(i))
+ }
+
+ default:
+ panic("invalid type: " + data.Type().String())
+ }
+ return buf
+}
+
+// Unmarshal unpacks buf into data.
+//
+// data must be a slice or a pointer and buf must have a length of exactly
+// Size(data). data must only contain fixed-length signed and unsigned ints,
+// arrays, slices, structs and compositions of said types.
+func Unmarshal(buf []byte, order binary.ByteOrder, data interface{}) {
+ value := reflect.ValueOf(data)
+ switch value.Kind() {
+ case reflect.Ptr:
+ value = value.Elem()
+ case reflect.Slice:
+ default:
+ panic("invalid type: " + value.Type().String())
+ }
+ buf = unmarshal(buf, order, value)
+ if len(buf) != 0 {
+ panic(fmt.Sprintf("buffer too long by %d bytes", len(buf)))
+ }
+}
+
+func unmarshal(buf []byte, order binary.ByteOrder, data reflect.Value) []byte {
+ switch data.Kind() {
+ case reflect.Int8:
+ data.SetInt(int64(int8(buf[0])))
+ buf = buf[1:]
+ case reflect.Int16:
+ data.SetInt(int64(int16(order.Uint16(buf))))
+ buf = buf[2:]
+ case reflect.Int32:
+ data.SetInt(int64(int32(order.Uint32(buf))))
+ buf = buf[4:]
+ case reflect.Int64:
+ data.SetInt(int64(order.Uint64(buf)))
+ buf = buf[8:]
+
+ case reflect.Uint8:
+ data.SetUint(uint64(buf[0]))
+ buf = buf[1:]
+ case reflect.Uint16:
+ data.SetUint(uint64(order.Uint16(buf)))
+ buf = buf[2:]
+ case reflect.Uint32:
+ data.SetUint(uint64(order.Uint32(buf)))
+ buf = buf[4:]
+ case reflect.Uint64:
+ data.SetUint(order.Uint64(buf))
+ buf = buf[8:]
+
+ case reflect.Array, reflect.Slice:
+ for i, l := 0, data.Len(); i < l; i++ {
+ buf = unmarshal(buf, order, data.Index(i))
+ }
+
+ case reflect.Struct:
+ for i, l := 0, data.NumField(); i < l; i++ {
+ if field := data.Field(i); field.CanSet() {
+ buf = unmarshal(buf, order, field)
+ } else {
+ buf = buf[sizeof(field):]
+ }
+ }
+
+ default:
+ panic("invalid type: " + data.Type().String())
+ }
+ return buf
+}
+
+// Size calculates the buffer size needed by Marshal or Unmarshal.
+//
+// Size only supports the types supported by Marshal.
+func Size(v interface{}) uintptr {
+ return sizeof(reflect.Indirect(reflect.ValueOf(v)))
+}
+
+func sizeof(data reflect.Value) uintptr {
+ switch data.Kind() {
+ case reflect.Int8, reflect.Uint8:
+ return 1
+ case reflect.Int16, reflect.Uint16:
+ return 2
+ case reflect.Int32, reflect.Uint32:
+ return 4
+ case reflect.Int64, reflect.Uint64:
+ return 8
+
+ case reflect.Array, reflect.Slice:
+ var size uintptr
+ for i, l := 0, data.Len(); i < l; i++ {
+ size += sizeof(data.Index(i))
+ }
+ return size
+
+ case reflect.Struct:
+ var size uintptr
+ for i, l := 0, data.NumField(); i < l; i++ {
+ size += sizeof(data.Field(i))
+ }
+ return size
+
+ default:
+ panic("invalid type: " + data.Type().String())
+ }
+}
+
+// ReadUint16 reads a uint16 from r.
+func ReadUint16(r io.Reader, order binary.ByteOrder) (uint16, error) {
+ buf := make([]byte, 2)
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ return order.Uint16(buf), nil
+}
+
+// ReadUint32 reads a uint32 from r.
+func ReadUint32(r io.Reader, order binary.ByteOrder) (uint32, error) {
+ buf := make([]byte, 4)
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ return order.Uint32(buf), nil
+}
+
+// ReadUint64 reads a uint64 from r.
+func ReadUint64(r io.Reader, order binary.ByteOrder) (uint64, error) {
+ buf := make([]byte, 8)
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ return order.Uint64(buf), nil
+}
+
+// WriteUint16 writes a uint16 to w.
+func WriteUint16(w io.Writer, order binary.ByteOrder, num uint16) error {
+ buf := make([]byte, 2)
+ order.PutUint16(buf, num)
+ _, err := w.Write(buf)
+ return err
+}
+
+// WriteUint32 writes a uint32 to w.
+func WriteUint32(w io.Writer, order binary.ByteOrder, num uint32) error {
+ buf := make([]byte, 4)
+ order.PutUint32(buf, num)
+ _, err := w.Write(buf)
+ return err
+}
+
+// WriteUint64 writes a uint64 to w.
+func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error {
+ buf := make([]byte, 8)
+ order.PutUint64(buf, num)
+ _, err := w.Write(buf)
+ return err
+}
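
A sketch of a common pattern built on these helpers: reading a length-prefixed frame from a stream. The frame format is hypothetical, and binary here refers to the standard encoding/binary package:

    func readFrame(r io.Reader) ([]byte, error) {
        // 4-byte big-endian length prefix, then the payload.
        n, err := ReadUint32(r, binary.BigEndian)
        if err != nil {
            return nil, err
        }
        frame := make([]byte, n)
        if _, err := io.ReadFull(r, frame); err != nil {
            return nil, err
        }
        return frame, nil
    }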
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
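
Worked examples of the mask arithmetic above (align must be a power of two for the clear-low-bits trick to hold):

    AlignUp(10, 8)   // == 16: (10 + 7) &^ 7
    AlignUp(16, 8)   // == 16: already aligned, unchanged
    AlignDown(10, 8) // == 8:  10 &^ 7 clears the low three bits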
diff --git a/pkg/binary/binary_test.go b/pkg/binary/binary_test.go
new file mode 100644
index 000000000..4d609a438
--- /dev/null
+++ b/pkg/binary/binary_test.go
@@ -0,0 +1,266 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package binary
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func newInt32(i int32) *int32 {
+ return &i
+}
+
+func TestSize(t *testing.T) {
+ if got, want := Size(uint32(10)), uintptr(4); got != want {
+ t.Errorf("Got = %d, want = %d", got, want)
+ }
+}
+
+func TestPanic(t *testing.T) {
+ tests := []struct {
+ name string
+ f func([]byte, binary.ByteOrder, interface{})
+ data interface{}
+ want string
+ }{
+ {"Unmarshal int", Unmarshal, 5, "invalid type: int"},
+ {"Unmarshal []int", Unmarshal, []int{5}, "invalid type: int"},
+ {"Marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, 5, "invalid type: int"},
+ {"Marshal int[]", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, []int{5}, "invalid type: int"},
+ {"Unmarshal short buffer", Unmarshal, newInt32(5), "runtime error: index out of range"},
+ {"Unmarshal long buffer", func(_ []byte, bo binary.ByteOrder, d interface{}) { Unmarshal(make([]byte, 50), bo, d) }, newInt32(5), "buffer too long by 46 bytes"},
+ {"marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { marshal(nil, bo, reflect.ValueOf(d)) }, 5, "invalid type: int"},
+ {"Size int", func(_ []byte, _ binary.ByteOrder, d interface{}) { Size(d) }, 5, "invalid type: int"},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ defer func() {
+ r := recover()
+ if got := fmt.Sprint(r); !strings.HasPrefix(got, test.want) {
+ t.Errorf("Got recover() = %q, want prefix = %q", got, test.want)
+ }
+ }()
+
+ test.f(nil, LittleEndian, test.data)
+ })
+ }
+}
+
+type inner struct {
+ Field int32
+}
+
+type outer struct {
+ Int8 int8
+ Int16 int16
+ Int32 int32
+ Int64 int64
+ Uint8 uint8
+ Uint16 uint16
+ Uint32 uint32
+ Uint64 uint64
+
+ Slice []int32
+ Array [5]int32
+ Struct inner
+}
+
+func TestMarshalUnmarshal(t *testing.T) {
+ want := outer{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ []int32{9, 10, 11},
+ [5]int32{12, 13, 14, 15, 16},
+ inner{17},
+ }
+ buf := Marshal(nil, LittleEndian, want)
+ got := outer{Slice: []int32{0, 0, 0}}
+ Unmarshal(buf, LittleEndian, &got)
+ if !reflect.DeepEqual(&got, &want) {
+ t.Errorf("Got = %#v, want = %#v", got, want)
+ }
+}
+
+type outerBenchmark struct {
+ Int8 int8
+ Int16 int16
+ Int32 int32
+ Int64 int64
+ Uint8 uint8
+ Uint16 uint16
+ Uint32 uint32
+ Uint64 uint64
+
+ Array [5]int32
+ Struct inner
+}
+
+func BenchmarkMarshalUnmarshal(b *testing.B) {
+ b.ReportAllocs()
+
+ in := outerBenchmark{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ [5]int32{9, 10, 11, 12, 13},
+ inner{14},
+ }
+ buf := make([]byte, Size(&in))
+ out := outerBenchmark{}
+
+ for i := 0; i < b.N; i++ {
+ buf := Marshal(buf[:0], LittleEndian, &in)
+ Unmarshal(buf, LittleEndian, &out)
+ }
+}
+
+func BenchmarkReadWrite(b *testing.B) {
+ b.ReportAllocs()
+
+ in := outerBenchmark{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ [5]int32{9, 10, 11, 12, 13},
+ inner{14},
+ }
+ buf := bytes.NewBuffer(make([]byte, binary.Size(&in)))
+ out := outerBenchmark{}
+
+ for i := 0; i < b.N; i++ {
+ buf.Reset()
+ if err := binary.Write(buf, LittleEndian, &in); err != nil {
+ b.Error("Write:", err)
+ }
+ if err := binary.Read(buf, LittleEndian, &out); err != nil {
+ b.Error("Read:", err)
+ }
+ }
+}
+
+type outerPadding struct {
+ _ int8
+ _ int16
+ _ int32
+ _ int64
+ _ uint8
+ _ uint16
+ _ uint32
+ _ uint64
+
+ _ []int32
+ _ [5]int32
+ _ inner
+}
+
+func TestMarshalUnmarshalPadding(t *testing.T) {
+ var want outerPadding
+ buf := Marshal(nil, LittleEndian, want)
+ var got outerPadding
+ Unmarshal(buf, LittleEndian, &got)
+ if !reflect.DeepEqual(&got, &want) {
+ t.Errorf("Got = %#v, want = %#v", got, want)
+ }
+}
+
+// Numbers with bits set in every byte, chosen so that the big- and
+// little-endian encodings are distinguishable.
+const (
+ want16 = 64<<8 | 128
+ want32 = 16<<24 | 32<<16 | want16
+ want64 = 1<<56 | 2<<48 | 4<<40 | 8<<32 | want32
+)
+
+func TestReadWriteUint16(t *testing.T) {
+ const want = uint16(want16)
+ var buf bytes.Buffer
+ if err := WriteUint16(&buf, LittleEndian, want); err != nil {
+ t.Error("WriteUint16:", err)
+ }
+ got, err := ReadUint16(&buf, LittleEndian)
+ if err != nil {
+ t.Error("ReadUint16:", err)
+ }
+ if got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+func TestReadWriteUint32(t *testing.T) {
+ const want = uint32(want32)
+ var buf bytes.Buffer
+ if err := WriteUint32(&buf, LittleEndian, want); err != nil {
+ t.Error("WriteUint32:", err)
+ }
+ got, err := ReadUint32(&buf, LittleEndian)
+ if err != nil {
+ t.Error("ReadUint32:", err)
+ }
+ if got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+func TestReadWriteUint64(t *testing.T) {
+ const want = uint64(want64)
+ var buf bytes.Buffer
+ if err := WriteUint64(&buf, LittleEndian, want); err != nil {
+ t.Error("WriteUint64:", err)
+ }
+ got, err := ReadUint64(&buf, LittleEndian)
+ if err != nil {
+ t.Error("ReadUint64:", err)
+ }
+ if got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+type readWriter struct {
+ err error
+}
+
+func (rw *readWriter) Write([]byte) (int, error) {
+ return 0, rw.err
+}
+
+func (rw *readWriter) Read([]byte) (int, error) {
+ return 0, rw.err
+}
+
+func TestReadWriteError(t *testing.T) {
+ tests := []struct {
+ name string
+ f func(rw io.ReadWriter) error
+ }{
+ {"WriteUint16", func(rw io.ReadWriter) error { return WriteUint16(rw, LittleEndian, 0) }},
+ {"ReadUint16", func(rw io.ReadWriter) error { _, err := ReadUint16(rw, LittleEndian); return err }},
+ {"WriteUint32", func(rw io.ReadWriter) error { return WriteUint32(rw, LittleEndian, 0) }},
+ {"ReadUint32", func(rw io.ReadWriter) error { _, err := ReadUint32(rw, LittleEndian); return err }},
+ {"WriteUint64", func(rw io.ReadWriter) error { return WriteUint64(rw, LittleEndian, 0) }},
+ {"ReadUint64", func(rw io.ReadWriter) error { _, err := ReadUint64(rw, LittleEndian); return err }},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ want := errors.New("want")
+ if got := test.f(&readWriter{want}); got != want {
+ t.Errorf("got = %v, want = %v", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
new file mode 100644
index 000000000..63f4670d7
--- /dev/null
+++ b/pkg/bits/BUILD
@@ -0,0 +1,55 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bits",
+ srcs = [
+ "bits.go",
+ "bits32.go",
+ "bits64.go",
+ "uint64_arch.go",
+ "uint64_arch_amd64_asm.s",
+ "uint64_arch_arm64_asm.s",
+ "uint64_arch_generic.go",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+go_template(
+ name = "bits_template",
+ srcs = ["bits_template.go"],
+ types = [
+ "T",
+ ],
+)
+
+go_template_instance(
+ name = "bits64",
+ out = "bits64.go",
+ package = "bits",
+ suffix = "64",
+ template = ":bits_template",
+ types = {
+ "T": "uint64",
+ },
+)
+
+go_template_instance(
+ name = "bits32",
+ out = "bits32.go",
+ package = "bits",
+ suffix = "32",
+ template = ":bits_template",
+ types = {
+ "T": "uint32",
+ },
+)
+
+go_test(
+ name = "bits_test",
+ size = "small",
+ srcs = ["uint64_test.go"],
+ library = ":bits",
+)
diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go
new file mode 100644
index 000000000..a26433ad6
--- /dev/null
+++ b/pkg/bits/bits.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bits includes all bit-related types and operations.
+package bits
diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go
new file mode 100644
index 000000000..998645388
--- /dev/null
+++ b/pkg/bits/bits_template.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bits
+
+// Non-atomic bit operations on a template type T.
+
+// T is a required type parameter that must be an integral type.
+type T uint64
+
+// IsOn returns true if *all* bits set in 'bits' are set in 'mask'.
+func IsOn(mask, bits T) bool {
+ return mask&bits == bits
+}
+
+// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'.
+func IsAnyOn(mask, bits T) bool {
+ return mask&bits != 0
+}
+
+// Mask returns a T with all of the given bits set.
+func Mask(is ...int) T {
+ ret := T(0)
+ for _, i := range is {
+ ret |= MaskOf(i)
+ }
+ return ret
+}
+
+// MaskOf is like Mask, but sets only a single bit (more efficiently).
+func MaskOf(i int) T {
+ return T(1) << T(i)
+}
+
+// IsPowerOfTwo returns true if v is a power of 2.
+func IsPowerOfTwo(v T) bool {
+ if v == 0 {
+ return false
+ }
+ return v&(v-1) == 0
+}
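
A usage sketch against the generated 64-bit instantiation (Mask64, MaskOf64, and friends, per the go_template_instance rules in the BUILD file above):

    m := bits.Mask64(1, 63)               // bits 1 and 63 set
    bits.IsOn64(m, bits.MaskOf64(63))     // true: bit 63 is set in m
    bits.IsAnyOn64(m, bits.Mask64(0, 1))  // true: bit 1 overlaps
    bits.IsPowerOfTwo64(bits.MaskOf64(7)) // true: exactly one bit set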
diff --git a/pkg/bits/uint64_arch.go b/pkg/bits/uint64_arch.go
new file mode 100644
index 000000000..9f23eff77
--- /dev/null
+++ b/pkg/bits/uint64_arch.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 arm64
+
+package bits
+
+// TrailingZeros64 returns the number of bits before the least significant 1
+// bit in x; in other words, it returns the index of the least significant 1
+// bit in x. If x is 0, TrailingZeros64 returns 64.
+func TrailingZeros64(x uint64) int
+
+// MostSignificantOne64 returns the index of the most significant 1 bit in
+// x. If x is 0, MostSignificantOne64 returns 64.
+func MostSignificantOne64(x uint64) int
+
+// ForEachSetBit64 calls f once for each set bit in x, with argument i equal to
+// the set bit's index.
+func ForEachSetBit64(x uint64, f func(i int)) {
+ for x != 0 {
+ i := TrailingZeros64(x)
+ f(i)
+ x &^= MaskOf64(i)
+ }
+}
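
For instance, a sketch collecting set-bit indices in ascending order:

    var idx []int
    bits.ForEachSetBit64(0x29, func(i int) { // 0x29 == 0b101001
        idx = append(idx, i)
    })
    // idx == []int{0, 3, 5}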
diff --git a/pkg/bits/uint64_arch_amd64_asm.s b/pkg/bits/uint64_arch_amd64_asm.s
new file mode 100644
index 000000000..8ff364181
--- /dev/null
+++ b/pkg/bits/uint64_arch_amd64_asm.s
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+TEXT ·TrailingZeros64(SB),$0-16
+ BSFQ x+0(FP), AX
+ JNZ end
+ MOVQ $64, AX
+end:
+ MOVQ AX, ret+8(FP)
+ RET
+
+TEXT ·MostSignificantOne64(SB),$0-16
+ BSRQ x+0(FP), AX
+ JNZ end
+ MOVQ $64, AX
+end:
+ MOVQ AX, ret+8(FP)
+ RET
diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s
new file mode 100644
index 000000000..814ba562d
--- /dev/null
+++ b/pkg/bits/uint64_arch_arm64_asm.s
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+TEXT ·TrailingZeros64(SB),$0-16
+ MOVD x+0(FP), R0
+ RBIT R0, R0
+ CLZ R0, R0 // return 64 if x == 0
+ MOVD R0, ret+8(FP)
+ RET
+
+TEXT ·MostSignificantOne64(SB),$0-16
+ MOVD x+0(FP), R0
+ CLZ R0, R0 // return 64 if x == 0
+ MOVD $63, R1
+ SUBS R0, R1, R0 // ret = 63 - CLZ
+ BPL end
+ MOVD $64, R0 // x == 0
+end:
+ MOVD R0, ret+8(FP)
+ RET
diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go
new file mode 100644
index 000000000..9dd2098d1
--- /dev/null
+++ b/pkg/bits/uint64_arch_generic.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !amd64,!arm64
+
+package bits
+
+// TrailingZeros64 returns the number of bits before the least significant 1
+// bit in x; in other words, it returns the index of the least significant 1
+// bit in x. If x is 0, TrailingZeros64 returns 64.
+func TrailingZeros64(x uint64) int {
+ if x == 0 {
+ return 64
+ }
+ i := 0
+ for ; x&1 == 0; i++ {
+ x >>= 1
+ }
+ return i
+}
+
+// MostSignificantOne64 returns the index of the most significant 1 bit in
+// x. If x is 0, MostSignificantOne64 returns 64.
+func MostSignificantOne64(x uint64) int {
+ if x == 0 {
+ return 64
+ }
+ i := 63
+ for ; x&(1<<63) == 0; i-- {
+ x <<= 1
+ }
+ return i
+}
+
+// ForEachSetBit64 calls f once for each set bit in x, with argument i equal to
+// the set bit's index.
+func ForEachSetBit64(x uint64, f func(i int)) {
+ for i := 0; x != 0; i++ {
+ if x&1 != 0 {
+ f(i)
+ }
+ x >>= 1
+ }
+}
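
These portable fallbacks agree with the standard library's semantics: math/bits.TrailingZeros64(0) is also 64, and for nonzero x the most significant set bit is Len64(x)-1. A sketch of the equivalence, with stdbits aliasing math/bits:

    func mostSignificantOne64(x uint64) int {
        if x == 0 {
            return 64
        }
        return stdbits.Len64(x) - 1 // == 63 - stdbits.LeadingZeros64(x)
    }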
diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go
new file mode 100644
index 000000000..193d1ebcd
--- /dev/null
+++ b/pkg/bits/uint64_test.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bits
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestTrailingZeros64(t *testing.T) {
+ for i := 0; i <= 64; i++ {
+ n := uint64(1) << uint(i)
+ if got, want := TrailingZeros64(n), i; got != want {
+ t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want)
+ }
+ }
+
+ for i := 0; i < 64; i++ {
+ n := ^uint64(0) << uint(i)
+ if got, want := TrailingZeros64(n), i; got != want {
+ t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want)
+ }
+ }
+
+ for i := 0; i < 64; i++ {
+ n := ^uint64(0) >> uint(i)
+ if got, want := TrailingZeros64(n), 0; got != want {
+ t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want)
+ }
+ }
+}
+
+func TestMostSignificantOne64(t *testing.T) {
+ for i := 0; i <= 64; i++ {
+ n := uint64(1) << uint(i)
+ if got, want := MostSignificantOne64(n), i; got != want {
+ t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want)
+ }
+ }
+
+ for i := 0; i < 64; i++ {
+ n := ^uint64(0) >> uint(i)
+ if got, want := MostSignificantOne64(n), 63-i; got != want {
+ t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want)
+ }
+ }
+
+ for i := 0; i < 64; i++ {
+ n := ^uint64(0) << uint(i)
+ if got, want := MostSignificantOne64(n), 63; got != want {
+ t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want)
+ }
+ }
+}
+
+func TestForEachSetBit64(t *testing.T) {
+ for _, want := range [][]int{
+ {},
+ {0},
+ {1},
+ {63},
+ {0, 1},
+ {1, 3, 5},
+ {0, 63},
+ } {
+ n := Mask64(want...)
+ // "Slice values are deeply equal when ... they are both nil or both
+ // non-nil ..."
+ got := make([]int, 0)
+ ForEachSetBit64(n, func(i int) {
+ got = append(got, i)
+ })
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("ForEachSetBit64(%#x): iterated bits %v, wanted %v", n, got, want)
+ }
+ }
+}
+
+func TestIsOn(t *testing.T) {
+ type spec struct {
+ mask uint64
+ bits uint64
+ any bool
+ all bool
+ }
+ for _, s := range []spec{
+ {Mask64(0), Mask64(0), true, true},
+ {Mask64(63), Mask64(63), true, true},
+ {Mask64(0), Mask64(1), false, false},
+ {Mask64(0), Mask64(0, 1), true, false},
+
+ {Mask64(1, 63), Mask64(1), true, true},
+ {Mask64(1, 63), Mask64(1, 63), true, true},
+ {Mask64(1, 63), Mask64(0, 1, 63), true, false},
+ {Mask64(1, 63), Mask64(0, 62), false, false},
+ } {
+ if ok := IsAnyOn64(s.mask, s.bits); ok != s.any {
+ t.Errorf("IsAnyOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.any)
+ }
+ if ok := IsOn64(s.mask, s.bits); ok != s.all {
+ t.Errorf("IsOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.all)
+ }
+ }
+}
+
+func TestIsPowerOfTwo(t *testing.T) {
+ for _, tc := range []struct {
+ v uint64
+ want bool
+ }{
+ {v: 0, want: false},
+ {v: 1, want: true},
+ {v: 2, want: true},
+ {v: 3, want: false},
+ {v: 4, want: true},
+ {v: 5, want: false},
+ } {
+ if got := IsPowerOfTwo64(tc.v); got != tc.want {
+ t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want)
+ }
+ }
+}
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
new file mode 100644
index 000000000..2a6977f85
--- /dev/null
+++ b/pkg/bpf/BUILD
@@ -0,0 +1,31 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bpf",
+ srcs = [
+ "bpf.go",
+ "decoder.go",
+ "input_bytes.go",
+ "interpreter.go",
+ "program_builder.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/abi/linux"],
+)
+
+go_test(
+ name = "bpf_test",
+ size = "small",
+ srcs = [
+ "decoder_test.go",
+ "interpreter_test.go",
+ "program_builder_test.go",
+ ],
+ library = ":bpf",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ ],
+)
diff --git a/pkg/bpf/bpf.go b/pkg/bpf/bpf.go
new file mode 100644
index 000000000..b8b8ad372
--- /dev/null
+++ b/pkg/bpf/bpf.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bpf provides tools for working with Berkeley Packet Filter (BPF)
+// programs. More information on BPF can be found at
+// https://www.freebsd.org/cgi/man.cgi?bpf(4)
+package bpf
+
+import "gvisor.dev/gvisor/pkg/abi/linux"
+
+const (
+ // MaxInstructions is the maximum number of instructions in a BPF program,
+ // and is equal to Linux's BPF_MAXINSNS.
+ MaxInstructions = 4096
+
+ // ScratchMemRegisters is the number of M registers in a BPF virtual machine,
+ // and is equal to Linux's BPF_MEMWORDS.
+ ScratchMemRegisters = 16
+)
+
+// Parts of a linux.BPFInstruction.OpCode. Compare to the Linux kernel's
+// include/uapi/linux/filter.h.
+//
+// In the comments below:
+//
+// - A, X, and M[] are BPF virtual machine registers.
+//
+// - K refers to the instruction field linux.BPFInstruction.K.
+//
+// - Bits are counted from the LSB position.
+const (
+ // Instruction class, stored in bits 0-2.
+ Ld = 0x00 // load into A
+ Ldx = 0x01 // load into X
+ St = 0x02 // store from A
+ Stx = 0x03 // store from X
+ Alu = 0x04 // arithmetic
+ Jmp = 0x05 // jump
+ Ret = 0x06 // return
+ Misc = 0x07
+ instructionClassMask = 0x07
+
+ // Size of a load, stored in bits 3-4.
+ W = 0x00 // 32 bits
+ H = 0x08 // 16 bits
+ B = 0x10 // 8 bits
+ loadSizeMask = 0x18
+
+ // Source operand for a load, stored in bits 5-7.
+ // Address mode numbers in the comments come from Linux's
+ // Documentation/networking/filter.txt.
+ Imm = 0x00 // immediate value K (mode 4)
+ Abs = 0x20 // data in input at byte offset K (mode 1)
+ Ind = 0x40 // data in input at byte offset X+K (mode 2)
+ Mem = 0x60 // M[K] (mode 3)
+ Len = 0x80 // length of the input in bytes ("BPF extension len")
+ Msh = 0xa0 // 4 * lower nibble of input at byte offset K (mode 5)
+ loadModeMask = 0xe0
+
+ // Source operands for arithmetic, jump, and return instructions.
+ // Arithmetic and jump instructions can use K or X as source operands.
+ // Return instructions can use K or A as source operands.
+ K = 0x00 // still mode 4
+ X = 0x08 // mode 0
+ A = 0x10 // mode 9
+ srcAluJmpMask = 0x08
+ srcRetMask = 0x18
+
+ // Arithmetic instructions, stored in bits 4-7.
+ Add = 0x00
+ Sub = 0x10 // A - src
+ Mul = 0x20
+ Div = 0x30 // A / src
+ Or = 0x40
+ And = 0x50
+ Lsh = 0x60 // A << src
+ Rsh = 0x70 // A >> src
+ Neg = 0x80 // -A (src ignored)
+ Mod = 0x90 // A % src
+ Xor = 0xa0
+ aluMask = 0xf0
+
+ // Jump instructions, stored in bits 4-7.
+ Ja = 0x00 // unconditional (uses K for jump offset)
+ Jeq = 0x10 // if A == src
+ Jgt = 0x20 // if A > src
+ Jge = 0x30 // if A >= src
+ Jset = 0x40 // if (A & src) != 0
+ jmpMask = 0xf0
+
+ // Miscellaneous instructions, stored in bits 3-7.
+ Tax = 0x00 // A = X
+ Txa = 0x80 // X = A
+ miscMask = 0xf8
+
+ // Masks for bits that should be zero.
+ unusedBitsMask = 0xff00 // all valid instructions use only bits 0-7
+ storeUnusedBitsMask = 0xf8 // stores only use instruction class
+ retUnusedBitsMask = 0xe0 // returns only use instruction class and source operand
+)
+
+// Stmt returns a linux.BPFInstruction representing a BPF non-jump instruction.
+func Stmt(code uint16, k uint32) linux.BPFInstruction {
+ return linux.BPFInstruction{
+ OpCode: code,
+ K: k,
+ }
+}
+
+// Jump returns a linux.BPFInstruction representing a BPF jump instruction.
+func Jump(code uint16, k uint32, jt, jf uint8) linux.BPFInstruction {
+ return linux.BPFInstruction{
+ OpCode: code,
+ JumpIfTrue: jt,
+ JumpIfFalse: jf,
+ K: k,
+ }
+}
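
A sketch composing these helpers into a two-instruction program that returns the input length, using the opcode constants defined above:

    prog := []linux.BPFInstruction{
        Stmt(Ld|Len|W, 0), // A = len(input)
        Stmt(Ret|A, 0),    // return A
    }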
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
new file mode 100644
index 000000000..c8ee0c3b1
--- /dev/null
+++ b/pkg/bpf/decoder.go
@@ -0,0 +1,245 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// DecodeProgram translates a slice of BPF instructions into a human-readable
+// text format.
+func DecodeProgram(program []linux.BPFInstruction) (string, error) {
+ var ret bytes.Buffer
+ for line, s := range program {
+ ret.WriteString(fmt.Sprintf("%v: ", line))
+ if err := decode(s, line, &ret); err != nil {
+ return "", err
+ }
+ ret.WriteString("\n")
+ }
+ return ret.String(), nil
+}
+
+// Decode translates a single BPF instruction into text format.
+func Decode(inst linux.BPFInstruction) (string, error) {
+ var ret bytes.Buffer
+ err := decode(inst, -1, &ret)
+ return ret.String(), err
+}
+
+func decode(inst linux.BPFInstruction, line int, w *bytes.Buffer) error {
+ var err error
+ switch inst.OpCode & instructionClassMask {
+ case Ld:
+ err = decodeLd(inst, w)
+ case Ldx:
+ err = decodeLdx(inst, w)
+ case St:
+ w.WriteString(fmt.Sprintf("M[%v] <- A", inst.K))
+ case Stx:
+ w.WriteString(fmt.Sprintf("M[%v] <- X", inst.K))
+ case Alu:
+ err = decodeAlu(inst, w)
+ case Jmp:
+ err = decodeJmp(inst, line, w)
+ case Ret:
+ err = decodeRet(inst, w)
+ case Misc:
+ err = decodeMisc(inst, w)
+ default:
+ return fmt.Errorf("invalid BPF instruction: %v", inst)
+ }
+ return err
+}
+
+// A <- P[k:4]
+func decodeLd(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ w.WriteString("A <- ")
+
+ switch inst.OpCode & loadModeMask {
+ case Imm:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case Abs:
+ w.WriteString(fmt.Sprintf("P[%v:", inst.K))
+ if err := decodeLdSize(inst, w); err != nil {
+ return err
+ }
+ w.WriteString("]")
+ case Ind:
+ w.WriteString(fmt.Sprintf("P[X+%v:", inst.K))
+ if err := decodeLdSize(inst, w); err != nil {
+ return err
+ }
+ w.WriteString("]")
+ case Mem:
+ w.WriteString(fmt.Sprintf("M[%v]", inst.K))
+ case Len:
+ w.WriteString("len")
+ default:
+ return fmt.Errorf("invalid BPF LD instruction: %v", inst)
+ }
+ return nil
+}
+
+func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ switch inst.OpCode & loadSizeMask {
+ case W:
+ w.WriteString("4")
+ case H:
+ w.WriteString("2")
+ case B:
+ w.WriteString("1")
+ default:
+ return fmt.Errorf("Invalid BPF LD size: %v", inst)
+ }
+ return nil
+}
+
+// X <- P[k:4]
+func decodeLdx(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ w.WriteString("X <- ")
+
+ switch inst.OpCode & loadModeMask {
+ case Imm:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case Mem:
+ w.WriteString(fmt.Sprintf("M[%v]", inst.K))
+ case Len:
+ w.WriteString("len")
+ case Msh:
+ w.WriteString(fmt.Sprintf("4*(P[%v:1]&0xf)", inst.K))
+ default:
+ return fmt.Errorf("invalid BPF LDX instruction: %v", inst)
+ }
+ return nil
+}
+
+// A <- A + k
+func decodeAlu(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ code := inst.OpCode & aluMask
+ if code == Neg {
+ w.WriteString("A <- -A")
+ return nil
+ }
+
+ w.WriteString("A <- A ")
+ switch code {
+ case Add:
+ w.WriteString("+ ")
+ case Sub:
+ w.WriteString("- ")
+ case Mul:
+ w.WriteString("* ")
+ case Div:
+ w.WriteString("/ ")
+ case Or:
+ w.WriteString("| ")
+ case And:
+ w.WriteString("& ")
+ case Lsh:
+ w.WriteString("<< ")
+ case Rsh:
+ w.WriteString(">> ")
+ case Mod:
+ w.WriteString("% ")
+ case Xor:
+ w.WriteString("^ ")
+ default:
+ return fmt.Errorf("invalid BPF ALU instruction: %v", inst)
+ }
+ return decodeSource(inst, w)
+}
+
+func decodeSource(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ switch inst.OpCode & srcAluJmpMask {
+ case K:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case X:
+ w.WriteString("X")
+ default:
+ return fmt.Errorf("invalid BPF ALU/JMP source instruction: %v", inst)
+ }
+ return nil
+}
+
+// pc += (A > k) ? jt : jf
+func decodeJmp(inst linux.BPFInstruction, line int, w *bytes.Buffer) error {
+ code := inst.OpCode & jmpMask
+
+ w.WriteString("pc += ")
+ if code == Ja {
+ w.WriteString(printJmpTarget(inst.K, line))
+ } else {
+ w.WriteString("(A ")
+ switch code {
+ case Jeq:
+ w.WriteString("== ")
+ case Jgt:
+ w.WriteString("> ")
+ case Jge:
+ w.WriteString(">= ")
+ case Jset:
+ w.WriteString("& ")
+ default:
+ return fmt.Errorf("invalid BPF ALU instruction: %v", inst)
+ }
+ if err := decodeSource(inst, w); err != nil {
+ return err
+ }
+ w.WriteString(
+ fmt.Sprintf(") ? %s : %s",
+ printJmpTarget(uint32(inst.JumpIfTrue), line),
+ printJmpTarget(uint32(inst.JumpIfFalse), line)))
+ }
+ return nil
+}
+
+func printJmpTarget(target uint32, line int) string {
+ if line == -1 {
+ return fmt.Sprintf("%v", target)
+ }
+ return fmt.Sprintf("%v [%v]", target, int(target)+line+1)
+}
+
+// ret k
+func decodeRet(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ w.WriteString("ret ")
+
+ code := inst.OpCode & srcRetMask
+ switch code {
+ case K:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case A:
+ w.WriteString("A")
+ default:
+ return fmt.Errorf("invalid BPF RET source instruction: %v", inst)
+ }
+ return nil
+}
+
+func decodeMisc(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ code := inst.OpCode & miscMask
+ switch code {
+ case Tax:
+ w.WriteString("X <- A")
+ case Txa:
+ w.WriteString("A <- X")
+ default:
+ return fmt.Errorf("invalid BPF ALU/JMP source instruction: %v", inst)
+ }
+ return nil
+}
diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go
new file mode 100644
index 000000000..6a023f0c0
--- /dev/null
+++ b/pkg/bpf/decoder_test.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func TestDecode(t *testing.T) {
+ for _, test := range []struct {
+ filter linux.BPFInstruction
+ expected string
+ fail bool
+ }{
+ {filter: Stmt(Ld+Imm, 10), expected: "A <- 10"},
+ {filter: Stmt(Ld+Abs+W, 10), expected: "A <- P[10:4]"},
+ {filter: Stmt(Ld+Ind+H, 10), expected: "A <- P[X+10:2]"},
+ {filter: Stmt(Ld+Ind+B, 10), expected: "A <- P[X+10:1]"},
+ {filter: Stmt(Ld+Mem, 10), expected: "A <- M[10]"},
+ {filter: Stmt(Ld+Len, 0), expected: "A <- len"},
+ {filter: Stmt(Ldx+Imm, 10), expected: "X <- 10"},
+ {filter: Stmt(Ldx+Mem, 10), expected: "X <- M[10]"},
+ {filter: Stmt(Ldx+Len, 0), expected: "X <- len"},
+ {filter: Stmt(Ldx+Msh, 10), expected: "X <- 4*(P[10:1]&0xf)"},
+ {filter: Stmt(St, 10), expected: "M[10] <- A"},
+ {filter: Stmt(Stx, 10), expected: "M[10] <- X"},
+ {filter: Stmt(Alu+Add+K, 10), expected: "A <- A + 10"},
+ {filter: Stmt(Alu+Sub+K, 10), expected: "A <- A - 10"},
+ {filter: Stmt(Alu+Mul+K, 10), expected: "A <- A * 10"},
+ {filter: Stmt(Alu+Div+K, 10), expected: "A <- A / 10"},
+ {filter: Stmt(Alu+Or+K, 10), expected: "A <- A | 10"},
+ {filter: Stmt(Alu+And+K, 10), expected: "A <- A & 10"},
+ {filter: Stmt(Alu+Lsh+K, 10), expected: "A <- A << 10"},
+ {filter: Stmt(Alu+Rsh+K, 10), expected: "A <- A >> 10"},
+ {filter: Stmt(Alu+Mod+K, 10), expected: "A <- A % 10"},
+ {filter: Stmt(Alu+Xor+K, 10), expected: "A <- A ^ 10"},
+ {filter: Stmt(Alu+Add+X, 0), expected: "A <- A + X"},
+ {filter: Stmt(Alu+Sub+X, 0), expected: "A <- A - X"},
+ {filter: Stmt(Alu+Mul+X, 0), expected: "A <- A * X"},
+ {filter: Stmt(Alu+Div+X, 0), expected: "A <- A / X"},
+ {filter: Stmt(Alu+Or+X, 0), expected: "A <- A | X"},
+ {filter: Stmt(Alu+And+X, 0), expected: "A <- A & X"},
+ {filter: Stmt(Alu+Lsh+X, 0), expected: "A <- A << X"},
+ {filter: Stmt(Alu+Rsh+X, 0), expected: "A <- A >> X"},
+ {filter: Stmt(Alu+Mod+X, 0), expected: "A <- A % X"},
+ {filter: Stmt(Alu+Xor+X, 0), expected: "A <- A ^ X"},
+ {filter: Stmt(Alu+Neg, 0), expected: "A <- -A"},
+ {filter: Stmt(Jmp+Ja, 10), expected: "pc += 10"},
+ {filter: Jump(Jmp+Jeq+K, 10, 2, 5), expected: "pc += (A == 10) ? 2 : 5"},
+ {filter: Jump(Jmp+Jgt+K, 10, 2, 5), expected: "pc += (A > 10) ? 2 : 5"},
+ {filter: Jump(Jmp+Jge+K, 10, 2, 5), expected: "pc += (A >= 10) ? 2 : 5"},
+ {filter: Jump(Jmp+Jset+K, 10, 2, 5), expected: "pc += (A & 10) ? 2 : 5"},
+ {filter: Jump(Jmp+Jeq+X, 0, 2, 5), expected: "pc += (A == X) ? 2 : 5"},
+ {filter: Jump(Jmp+Jgt+X, 0, 2, 5), expected: "pc += (A > X) ? 2 : 5"},
+ {filter: Jump(Jmp+Jge+X, 0, 2, 5), expected: "pc += (A >= X) ? 2 : 5"},
+ {filter: Jump(Jmp+Jset+X, 0, 2, 5), expected: "pc += (A & X) ? 2 : 5"},
+ {filter: Stmt(Ret+K, 10), expected: "ret 10"},
+ {filter: Stmt(Ret+A, 0), expected: "ret A"},
+ {filter: Stmt(Misc+Tax, 0), expected: "X <- A"},
+ {filter: Stmt(Misc+Txa, 0), expected: "A <- X"},
+ {filter: Stmt(Ld+Ind+Msh, 0), fail: true},
+ } {
+ got, err := Decode(test.filter)
+ if test.fail {
+ if err == nil {
+ t.Errorf("Decode(%v) failed, expected: 'error', got: %q", test.filter, got)
+ continue
+ }
+ } else {
+ if err != nil {
+ t.Errorf("Decode(%v) failed for test %q, error: %q", test.filter, test.expected, err)
+ continue
+ }
+ if got != test.expected {
+ t.Errorf("Decode(%v) failed, expected: %q, got: %q", test.filter, test.expected, got)
+ continue
+ }
+ }
+ }
+}
+
+func TestDecodeProgram(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ program []linux.BPFInstruction
+ expected string
+ fail bool
+ }{
+ {name: "basic with jump indexes",
+ program: []linux.BPFInstruction{
+ Stmt(Ld+Abs+W, 10),
+ Stmt(Ldx+Mem, 10),
+ Stmt(St, 10),
+ Stmt(Stx, 10),
+ Stmt(Alu+Add+K, 10),
+ Stmt(Jmp+Ja, 10),
+ Jump(Jmp+Jeq+K, 10, 2, 5),
+ Jump(Jmp+Jset+X, 0, 0, 5),
+ Stmt(Misc+Tax, 0),
+ },
+ expected: "0: A <- P[10:4]\n" +
+ "1: X <- M[10]\n" +
+ "2: M[10] <- A\n" +
+ "3: M[10] <- X\n" +
+ "4: A <- A + 10\n" +
+ "5: pc += 10 [16]\n" +
+ "6: pc += (A == 10) ? 2 [9] : 5 [12]\n" +
+ "7: pc += (A & X) ? 0 [8] : 5 [13]\n" +
+ "8: X <- A\n",
+ },
+ {name: "invalid instruction",
+ program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)},
+ fail: true},
+ } {
+ got, err := DecodeProgram(test.program)
+ if test.fail {
+ if err == nil {
+ t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got)
+ continue
+ }
+ } else {
+ if err != nil {
+ t.Errorf("%s: Decode failed: %v", test.name, err)
+ continue
+ }
+ if got != test.expected {
+ t.Errorf("%s: Decode(...) failed, expected: %q, got: %q", test.name, test.expected, got)
+ continue
+ }
+ }
+ }
+}
diff --git a/pkg/bpf/input_bytes.go b/pkg/bpf/input_bytes.go
new file mode 100644
index 000000000..86b216cfc
--- /dev/null
+++ b/pkg/bpf/input_bytes.go
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "encoding/binary"
+)
+
+// InputBytes implements the Input interface by providing access to a byte
+// slice. Unaligned loads are supported.
+type InputBytes struct {
+ // Data is the data accessed through the Input interface.
+ Data []byte
+
+ // Order is the byte order the data is accessed with.
+ Order binary.ByteOrder
+}
+
+// Load32 implements Input.Load32.
+func (i InputBytes) Load32(off uint32) (uint32, bool) {
+ if uint64(off)+4 > uint64(len(i.Data)) {
+ return 0, false
+ }
+ return i.Order.Uint32(i.Data[int(off):]), true
+}
+
+// Load16 implements Input.Load16.
+func (i InputBytes) Load16(off uint32) (uint16, bool) {
+ if uint64(off)+2 > uint64(len(i.Data)) {
+ return 0, false
+ }
+ return i.Order.Uint16(i.Data[int(off):]), true
+}
+
+// Load8 implements Input.Load8.
+func (i InputBytes) Load8(off uint32) (uint8, bool) {
+ if uint64(off)+1 > uint64(len(i.Data)) {
+ return 0, false
+ }
+ return i.Data[int(off)], true
+}
+
+// Length implements Input.Length.
+func (i InputBytes) Length() uint32 {
+ return uint32(len(i.Data))
+}
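
A usage sketch of bounds-checked loads over a byte slice (binary here is the standard encoding/binary):

    in := InputBytes{Data: []byte{0x00, 0x11, 0x22, 0x33, 0x44}, Order: binary.BigEndian}
    v, ok := in.Load32(1) // v == 0x11223344, ok == true
    _, ok = in.Load32(2)  // ok == false: a 4-byte read at offset 2 runs past the end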
diff --git a/pkg/bpf/interpreter.go b/pkg/bpf/interpreter.go
new file mode 100644
index 000000000..ed27abb9b
--- /dev/null
+++ b/pkg/bpf/interpreter.go
@@ -0,0 +1,412 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// Possible values for ProgramError.Code.
+const (
+ // DivisionByZero indicates that a program contains, or executed, a
+ // division or modulo by zero.
+ DivisionByZero = iota
+
+ // InvalidEndOfProgram indicates that the last instruction of a program is
+ // not a return.
+ InvalidEndOfProgram
+
+ // InvalidInstructionCount indicates that a program has zero instructions
+ // or more than MaxInstructions instructions.
+ InvalidInstructionCount
+
+ // InvalidJumpTarget indicates that a program contains a jump whose target
+ // is outside of the program's bounds.
+ InvalidJumpTarget
+
+ // InvalidLoad indicates that a program executed an invalid load of input
+ // data.
+ InvalidLoad
+
+ // InvalidOpcode indicates that a program contains an instruction with an
+ // invalid opcode.
+ InvalidOpcode
+
+ // InvalidRegister indicates that a program contains a load from, or store
+ // to, a non-existent M register (index >= ScratchMemRegisters).
+ InvalidRegister
+)
+
+// Error is an error encountered while compiling or executing a BPF program.
+type Error struct {
+ // Code indicates the kind of error that occurred.
+ Code int
+
+ // PC is the program counter (index into the list of instructions) at which
+ // the error occurred.
+ PC int
+}
+
+func (e Error) codeString() string {
+ switch e.Code {
+ case DivisionByZero:
+ return "division by zero"
+ case InvalidEndOfProgram:
+ return "last instruction must be a return"
+ case InvalidInstructionCount:
+ return "invalid number of instructions"
+ case InvalidJumpTarget:
+ return "jump target out of bounds"
+ case InvalidLoad:
+ return "load out of bounds or violates input alignment requirements"
+ case InvalidOpcode:
+ return "invalid instruction opcode"
+ case InvalidRegister:
+ return "invalid M register"
+ default:
+ return "unknown error"
+ }
+}
+
+// Error implements error.Error.
+func (e Error) Error() string {
+ return fmt.Sprintf("at l%d: %s", e.PC, e.codeString())
+}
+
+// Program is a BPF program that has been validated for consistency.
+//
+// +stateify savable
+type Program struct {
+ instructions []linux.BPFInstruction
+}
+
+// Length returns the number of instructions in the program.
+func (p Program) Length() int {
+ return len(p.instructions)
+}
+
+// Compile performs validation on a sequence of BPF instructions before
+// wrapping them in a Program.
+func Compile(insns []linux.BPFInstruction) (Program, error) {
+ if len(insns) == 0 || len(insns) > MaxInstructions {
+ return Program{}, Error{InvalidInstructionCount, len(insns)}
+ }
+
+ // The last instruction must be a return.
+ if last := insns[len(insns)-1]; last.OpCode != (Ret|K) && last.OpCode != (Ret|A) {
+ return Program{}, Error{InvalidEndOfProgram, len(insns) - 1}
+ }
+
+ // Validate each instruction. Note that we skip a validation Linux does:
+ // Linux additionally verifies that every load from an M register is
+ // preceded, in every path, by a store to the same M register, in order to
+ // avoid having to clear M between programs
+ // (net/core/filter.c:check_load_and_stores). We always start with a zeroed
+ // M array.
+ for pc, i := range insns {
+ if i.OpCode&unusedBitsMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ switch i.OpCode & instructionClassMask {
+ case Ld:
+ mode := i.OpCode & loadModeMask
+ switch i.OpCode & loadSizeMask {
+ case W:
+ if mode != Imm && mode != Abs && mode != Ind && mode != Mem && mode != Len {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if mode == Mem && i.K >= ScratchMemRegisters {
+ return Program{}, Error{InvalidRegister, pc}
+ }
+ case H, B:
+ if mode != Abs && mode != Ind {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Ldx:
+ mode := i.OpCode & loadModeMask
+ switch i.OpCode & loadSizeMask {
+ case W:
+ if mode != Imm && mode != Mem && mode != Len {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if mode == Mem && i.K >= ScratchMemRegisters {
+ return Program{}, Error{InvalidRegister, pc}
+ }
+ case B:
+ if mode != Msh {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case St, Stx:
+ if i.OpCode&storeUnusedBitsMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if i.K >= ScratchMemRegisters {
+ return Program{}, Error{InvalidRegister, pc}
+ }
+ case Alu:
+ switch i.OpCode & aluMask {
+ case Add, Sub, Mul, Or, And, Lsh, Rsh, Xor:
+ break
+ case Div, Mod:
+ if src := i.OpCode & srcAluJmpMask; src == K && i.K == 0 {
+ return Program{}, Error{DivisionByZero, pc}
+ }
+ case Neg:
+ // Negation doesn't take a source operand.
+ if i.OpCode&srcAluJmpMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Jmp:
+ switch i.OpCode & jmpMask {
+ case Ja:
+ // Unconditional jump doesn't take a source operand.
+ if i.OpCode&srcAluJmpMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ // Do the comparison in 64 bits to avoid the possibility of
+ // overflow from a very large i.K.
+ if uint64(pc)+uint64(i.K)+1 >= uint64(len(insns)) {
+ return Program{}, Error{InvalidJumpTarget, pc}
+ }
+ case Jeq, Jgt, Jge, Jset:
+ // jt and jf are uint8s, so there's no threat of overflow.
+ if pc+int(i.JumpIfTrue)+1 >= len(insns) {
+ return Program{}, Error{InvalidJumpTarget, pc}
+ }
+ if pc+int(i.JumpIfFalse)+1 >= len(insns) {
+ return Program{}, Error{InvalidJumpTarget, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Ret:
+ if i.OpCode&retUnusedBitsMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if src := i.OpCode & srcRetMask; src != K && src != A {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Misc:
+ if misc := i.OpCode & miscMask; misc != Tax && misc != Txa {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ }
+ }
+
+ return Program{insns}, nil
+}
+
+// Input represents a source of input data for a BPF program. (BPF
+// documentation sometimes refers to the input data as the "packet" due to its
+// origins as a packet processing DSL.)
+//
+// For all of Input's Load methods:
+//
+// - The second (bool) return value is true if the load succeeded and false
+// otherwise.
+//
+// - Inputs should not assume that the loaded range falls within the input
+// data's length. Inputs should return false if the load falls outside of the
+// input data.
+//
+// - Inputs should not assume that the offset is correctly aligned. Inputs may
+// choose to service or reject loads to unaligned addresses.
+type Input interface {
+ // Load32 reads 32 bits from the input starting at the given byte offset.
+ Load32(off uint32) (uint32, bool)
+
+ // Load16 reads 16 bits from the input starting at the given byte offset.
+ Load16(off uint32) (uint16, bool)
+
+ // Load8 reads 8 bits from the input starting at the given byte offset.
+ Load8(off uint32) (uint8, bool)
+
+ // Length returns the length of the input in bytes.
+ Length() uint32
+}
+
+// machine represents the state of a BPF virtual machine.
+type machine struct {
+ A uint32
+ X uint32
+ M [ScratchMemRegisters]uint32
+}
+
+func conditionalJumpOffset(insn linux.BPFInstruction, cond bool) int {
+ if cond {
+ return int(insn.JumpIfTrue)
+ }
+ return int(insn.JumpIfFalse)
+}
+
+// Exec executes a BPF program over the given input and returns its return
+// value.
+func Exec(p Program, in Input) (uint32, error) {
+ var m machine
+ var pc int
+ for ; pc < len(p.instructions); pc++ {
+ i := p.instructions[pc]
+ switch i.OpCode {
+ case Ld | Imm | W:
+ m.A = i.K
+ case Ld | Abs | W:
+ val, ok := in.Load32(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = val
+ case Ld | Abs | H:
+ val, ok := in.Load16(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Abs | B:
+ val, ok := in.Load8(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Ind | W:
+ val, ok := in.Load32(m.X + i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = val
+ case Ld | Ind | H:
+ val, ok := in.Load16(m.X + i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Ind | B:
+ val, ok := in.Load8(m.X + i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Mem | W:
+ m.A = m.M[int(i.K)]
+ case Ld | Len | W:
+ m.A = in.Length()
+ case Ldx | Imm | W:
+ m.X = i.K
+ case Ldx | Mem | W:
+ m.X = m.M[int(i.K)]
+ case Ldx | Len | W:
+ m.X = in.Length()
+ case Ldx | Msh | B:
+ val, ok := in.Load8(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.X = 4 * uint32(val&0xf)
+ case St:
+ m.M[int(i.K)] = m.A
+ case Stx:
+ m.M[int(i.K)] = m.X
+ case Alu | Add | K:
+ m.A += i.K
+ case Alu | Add | X:
+ m.A += m.X
+ case Alu | Sub | K:
+ m.A -= i.K
+ case Alu | Sub | X:
+ m.A -= m.X
+ case Alu | Mul | K:
+ m.A *= i.K
+ case Alu | Mul | X:
+ m.A *= m.X
+ case Alu | Div | K:
+ // K != 0 already checked by Compile.
+ m.A /= i.K
+ case Alu | Div | X:
+ if m.X == 0 {
+ return 0, Error{DivisionByZero, pc}
+ }
+ m.A /= m.X
+ case Alu | Or | K:
+ m.A |= i.K
+ case Alu | Or | X:
+ m.A |= m.X
+ case Alu | And | K:
+ m.A &= i.K
+ case Alu | And | X:
+ m.A &= m.X
+ case Alu | Lsh | K:
+ m.A <<= i.K
+ case Alu | Lsh | X:
+ m.A <<= m.X
+ case Alu | Rsh | K:
+ m.A >>= i.K
+ case Alu | Rsh | X:
+ m.A >>= m.X
+ case Alu | Neg:
+ m.A = uint32(-int32(m.A))
+ case Alu | Mod | K:
+ // K != 0 already checked by Compile.
+ m.A %= i.K
+ case Alu | Mod | X:
+ if m.X == 0 {
+ return 0, Error{DivisionByZero, pc}
+ }
+ m.A %= m.X
+ case Alu | Xor | K:
+ m.A ^= i.K
+ case Alu | Xor | X:
+ m.A ^= m.X
+ case Jmp | Ja:
+ pc += int(i.K)
+ case Jmp | Jeq | K:
+ pc += conditionalJumpOffset(i, m.A == i.K)
+ case Jmp | Jeq | X:
+ pc += conditionalJumpOffset(i, m.A == m.X)
+ case Jmp | Jgt | K:
+ pc += conditionalJumpOffset(i, m.A > i.K)
+ case Jmp | Jgt | X:
+ pc += conditionalJumpOffset(i, m.A > m.X)
+ case Jmp | Jge | K:
+ pc += conditionalJumpOffset(i, m.A >= i.K)
+ case Jmp | Jge | X:
+ pc += conditionalJumpOffset(i, m.A >= m.X)
+ case Jmp | Jset | K:
+ pc += conditionalJumpOffset(i, (m.A&i.K) != 0)
+ case Jmp | Jset | X:
+ pc += conditionalJumpOffset(i, (m.A&m.X) != 0)
+ case Ret | K:
+ return i.K, nil
+ case Ret | A:
+ return m.A, nil
+ case Misc | Tax:
+ m.A = m.X
+ case Misc | Txa:
+ m.X = m.A
+ default:
+ return 0, Error{InvalidOpcode, pc}
+ }
+ }
+ return 0, Error{InvalidEndOfProgram, pc}
+}
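
An end-to-end sketch tying Compile and Exec together: a filter that accepts inputs whose first byte is 0x7f (binary here is the standard encoding/binary):

    insns := []linux.BPFInstruction{
        Stmt(Ld|Abs|B, 0),           // A = input[0]
        Jump(Jmp|Jeq|K, 0x7f, 0, 1), // if A == 0x7f fall through, else skip one
        Stmt(Ret|K, 1),              // accept
        Stmt(Ret|K, 0),              // reject
    }
    p, err := Compile(insns)
    if err != nil {
        panic(err)
    }
    ret, _ := Exec(p, InputBytes{Data: []byte{0x7f}, Order: binary.BigEndian})
    // ret == 1; an empty input would instead fail the Ld|Abs|B load with InvalidLoad.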
diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go
new file mode 100644
index 000000000..c85d786b9
--- /dev/null
+++ b/pkg/bpf/interpreter_test.go
@@ -0,0 +1,797 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
+func TestCompilationErrors(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // insns is the BPF instructions to be compiled.
+ insns []linux.BPFInstruction
+
+ // expectedErr is the expected compilation error.
+ expectedErr error
+ }{
+ {
+ desc: "Instructions must not be nil",
+ expectedErr: Error{InvalidInstructionCount, 0},
+ },
+ {
+ desc: "Instructions must not be empty",
+ insns: []linux.BPFInstruction{},
+ expectedErr: Error{InvalidInstructionCount, 0},
+ },
+ {
+ desc: "A program must end with a return",
+ insns: make([]linux.BPFInstruction, MaxInstructions),
+ expectedErr: Error{InvalidEndOfProgram, MaxInstructions - 1},
+ },
+ {
+ desc: "A program must have MaxInstructions or fewer instructions",
+ insns: append(make([]linux.BPFInstruction, MaxInstructions), Stmt(Ret|K, 0)),
+ expectedErr: Error{InvalidInstructionCount, MaxInstructions + 1},
+ },
+ {
+ desc: "A load from an invalid M register is a compilation error",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Mem|W, ScratchMemRegisters), // A = M[16]
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{InvalidRegister, 0},
+ },
+ {
+ desc: "A store to an invalid M register is a compilation error",
+ insns: []linux.BPFInstruction{
+ Stmt(St, ScratchMemRegisters), // M[16] = A
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{InvalidRegister, 0},
+ },
+ {
+ desc: "Division by literal zero is a compilation error",
+ insns: []linux.BPFInstruction{
+ Stmt(Alu|Div|K, 0), // A /= 0
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{DivisionByZero, 0},
+ },
+ {
+ desc: "An unconditional jump outside of the program is a compilation error",
+ insns: []linux.BPFInstruction{
+ Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{InvalidJumpTarget, 0},
+ },
+ {
+ desc: "A conditional jump outside of the program in the true case is a compilation error",
+ insns: []linux.BPFInstruction{
+ Jump(Jmp|Jeq|K, 0, 1, 0), // if (A == K) jmp nextpc+1
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{InvalidJumpTarget, 0},
+ },
+ {
+ desc: "A conditional jump outside of the program in the false case is a compilation error",
+ insns: []linux.BPFInstruction{
+ Jump(Jmp|Jeq|K, 0, 0, 1), // if (A != K) jmp nextpc+1
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{InvalidJumpTarget, 0},
+ },
+ } {
+ _, err := Compile(test.insns)
+ if err != test.expectedErr {
+ t.Errorf("%s: expected error %q, got error %q", test.desc, test.expectedErr, err)
+ }
+ }
+}
+
+func TestExecErrors(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // insns is the BPF instructions to be executed.
+ insns []linux.BPFInstruction
+
+ // expectedErr is the expected execution error.
+ expectedErr error
+ }{
+ {
+ desc: "An out-of-bounds load of input data is an execution error",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Abs|B, 0), // A = input[0]
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{InvalidLoad, 0},
+ },
+ {
+ desc: "Division by zero at runtime is an execution error",
+ insns: []linux.BPFInstruction{
+ Stmt(Alu|Div|X, 0), // A /= X
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{DivisionByZero, 0},
+ },
+ {
+ desc: "Modulo zero at runtime is an execution error",
+ insns: []linux.BPFInstruction{
+ Stmt(Alu|Mod|X, 0), // A %= X
+ Stmt(Ret|K, 0), // return 0
+ },
+ expectedErr: Error{DivisionByZero, 0},
+ },
+ } {
+ p, err := Compile(test.insns)
+ if err != nil {
+ t.Errorf("%s: unexpected compilation error: %v", test.desc, err)
+ continue
+ }
+ ret, err := Exec(p, InputBytes{nil, binary.BigEndian})
+ if err != test.expectedErr {
+ t.Errorf("%s: expected execution error %q, got (%d, %v)", test.desc, test.expectedErr, ret, err)
+ }
+ }
+}
+
+func TestValidInstructions(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // insns is the BPF instructions to be compiled.
+ insns []linux.BPFInstruction
+
+ // input is the input data. Note that input will be read as big-endian.
+ input []byte
+
+ // expectedRet is the expected return value of the BPF program.
+ expectedRet uint32
+ }{
+ {
+ desc: "Return of immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ret|K, 42), // return 42
+ },
+ expectedRet: 42,
+ },
+ {
+ desc: "Load of immediate into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 42), // A = 42
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 42,
+ },
+ {
+ desc: "Load of immediate into X and copying of X into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Imm|W, 42), // X = 42
+ Stmt(Misc|Tax, 0), // A = X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 42,
+ },
+ {
+ desc: "Copying of A into X and back",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 42), // A = 42
+ Stmt(Misc|Txa, 0), // X = A
+ Stmt(Ld|Imm|W, 0), // A = 0
+ Stmt(Misc|Tax, 0), // A = X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 42,
+ },
+ {
+ desc: "Load of 32-bit input by absolute offset into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Abs|W, 1), // A = input[1..4]
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0x00, 0x11, 0x22, 0x33, 0x44},
+ expectedRet: 0x11223344,
+ },
+ {
+ desc: "Load of 16-bit input by absolute offset into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Abs|H, 1), // A = input[1..2]
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0x00, 0x11, 0x22},
+ expectedRet: 0x1122,
+ },
+ {
+ desc: "Load of 8-bit input by absolute offset into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Abs|B, 1), // A = input[1]
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0x00, 0x11},
+ expectedRet: 0x11,
+ },
+ {
+ desc: "Load of 32-bit input by relative offset into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Imm|W, 1), // X = 1
+ Stmt(Ld|Ind|W, 1), // A = input[X+1..X+4]
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
+ expectedRet: 0x22334455,
+ },
+ {
+ desc: "Load of 16-bit input by relative offset into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Imm|W, 1), // X = 1
+ Stmt(Ld|Ind|H, 1), // A = input[X+1..X+2]
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0x00, 0x11, 0x22, 0x33},
+ expectedRet: 0x2233,
+ },
+ {
+ desc: "Load of 8-bit input by relative offset into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Imm|W, 1), // X = 1
+ Stmt(Ld|Ind|B, 1), // A = input[X+1]
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0x00, 0x11, 0x22},
+ expectedRet: 0x22,
+ },
+ {
+ desc: "Load/store between A and scratch memory",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 42), // A = 42
+ Stmt(St, 2), // M[2] = A
+ Stmt(Ld|Imm|W, 0), // A = 0
+ Stmt(Ld|Mem|W, 2), // A = M[2]
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 42,
+ },
+ {
+ desc: "Load/store between X and scratch memory",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Imm|W, 42), // X = 42
+ Stmt(Stx, 3), // M[3] = X
+ Stmt(Ldx|Imm|W, 0), // X = 0
+ Stmt(Ldx|Mem|W, 3), // X = M[3]
+ Stmt(Misc|Tax, 0), // A = X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 42,
+ },
+ {
+ desc: "Load of input length into A",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Len|W, 0), // A = len(input)
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{1, 2, 3},
+ expectedRet: 3,
+ },
+ {
+ desc: "Load of input length into X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Len|W, 0), // X = len(input)
+ Stmt(Misc|Tax, 0), // A = X
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{1, 2, 3},
+ expectedRet: 3,
+ },
+ {
+ desc: "Load of MSH (?) into X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ldx|Msh|B, 0), // X = 4*(input[0]&0xf)
+ Stmt(Misc|Tax, 0), // A = X
+ Stmt(Ret|A, 0), // return A
+ },
+ input: []byte{0xf1},
+ expectedRet: 4,
+ },
+ {
+ desc: "Addition of immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Stmt(Alu|Add|K, 20), // A += 20
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 30,
+ },
+ {
+ desc: "Addition of X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Stmt(Ldx|Imm|W, 20), // X = 20
+ Stmt(Alu|Add|X, 0), // A += X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 30,
+ },
+ {
+ desc: "Subtraction of immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 30), // A = 30
+ Stmt(Alu|Sub|K, 20), // A -= 20
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 10,
+ },
+ {
+ desc: "Subtraction of X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 30), // A = 30
+ Stmt(Ldx|Imm|W, 20), // X = 20
+ Stmt(Alu|Sub|X, 0), // A -= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 10,
+ },
+ {
+ desc: "Multiplication of immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 2), // A = 2
+ Stmt(Alu|Mul|K, 3), // A *= 3
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 6,
+ },
+ {
+ desc: "Multiplication of X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 2), // A = 2
+ Stmt(Ldx|Imm|W, 3), // X = 3
+ Stmt(Alu|Mul|X, 0), // A *= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 6,
+ },
+ {
+ desc: "Division by immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 6), // A = 6
+ Stmt(Alu|Div|K, 3), // A /= 3
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Division by X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 6), // A = 6
+ Stmt(Ldx|Imm|W, 3), // X = 3
+ Stmt(Alu|Div|X, 0), // A /= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Modulo immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 17), // A = 17
+ Stmt(Alu|Mod|K, 7), // A %= 7
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 3,
+ },
+ {
+ desc: "Modulo X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 17), // A = 17
+ Stmt(Ldx|Imm|W, 7), // X = 7
+ Stmt(Alu|Mod|X, 0), // A %= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 3,
+ },
+ {
+ desc: "Arithmetic negation",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 1), // A = 1
+ Stmt(Alu|Neg, 0), // A = -A
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0xffffffff,
+ },
+ {
+ desc: "Bitwise OR with immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55
+ Stmt(Alu|Or|K, 0xff0055aa), // A |= 0xff0055aa
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0xff00ffff,
+ },
+ {
+ desc: "Bitwise OR with X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55
+ Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa
+ Stmt(Alu|Or|X, 0), // A |= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0xff00ffff,
+ },
+ {
+ desc: "Bitwise AND with immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55
+ Stmt(Alu|And|K, 0xff0055aa), // A &= 0xff0055aa
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0xff000000,
+ },
+ {
+ desc: "Bitwise AND with X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55
+ Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa
+ Stmt(Alu|And|X, 0), // A &= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0xff000000,
+ },
+ {
+ desc: "Bitwise XOR with immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55
+ Stmt(Alu|Xor|K, 0xff0055aa), // A ^= 0xff0055aa
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0x0000ffff,
+ },
+ {
+ desc: "Bitwise XOR with X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55
+ Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa
+ Stmt(Alu|Xor|X, 0), // A ^= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 0x0000ffff,
+ },
+ {
+ desc: "Left shift by immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 1), // A = 1
+ Stmt(Alu|Lsh|K, 5), // A <<= 5
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 32,
+ },
+ {
+ desc: "Left shift by X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 1), // A = 1
+ Stmt(Ldx|Imm|W, 5), // X = 5
+ Stmt(Alu|Lsh|X, 0), // A <<= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 32,
+ },
+ {
+ desc: "Right shift by immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff
+ Stmt(Alu|Rsh|K, 31), // A >>= 31
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Right shift by X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff
+ Stmt(Ldx|Imm|W, 31), // X = 31
+ Stmt(Alu|Rsh|X, 0), // A >>= X
+ Stmt(Ret|A, 0), // return A
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Unconditional jump",
+ insns: []linux.BPFInstruction{
+ Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A == immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 42), // A = 42
+ Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A != immediate",
+ insns: []linux.BPFInstruction{
+ Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A == X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 42), // A = 42
+ Stmt(Ldx|Imm|W, 42), // X = 42
+ Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A != X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 42), // A = 42
+ Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A > immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Jump(Jmp|Jgt|K, 9, 1, 2), // if (A > 9) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A <= immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Jump(Jmp|Jgt|K, 10, 1, 2), // if (A > 10) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A > X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Stmt(Ldx|Imm|W, 9), // X = 9
+ Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A <= X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Stmt(Ldx|Imm|W, 10), // X = 10
+ Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A >= immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Jump(Jmp|Jge|K, 10, 1, 2), // if (A >= 10) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A < immediate",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Jump(Jmp|Jge|K, 11, 1, 2), // if (A >= 11) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A >= X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Stmt(Ldx|Imm|W, 10), // X = 10
+ Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A < X",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 10), // A = 10
+ Stmt(Ldx|Imm|W, 11), // X = 11
+ Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A & immediate != 0",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff), // A = 0xff
+ Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A & immediate == 0",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xfe), // A = 0xfe
+ Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ {
+ desc: "Jump when A & X != 0",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xff), // A = 0xff
+ Stmt(Ldx|Imm|W, 0x101), // X = 0x101
+ Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 1,
+ },
+ {
+ desc: "Jump when A & X == 0",
+ insns: []linux.BPFInstruction{
+ Stmt(Ld|Imm|W, 0xfe), // A = 0xfe
+ Stmt(Ldx|Imm|W, 0x101), // X = 0x101
+ Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2
+ Stmt(Ret|K, 0), // return 0
+ Stmt(Ret|K, 1), // return 1
+ Stmt(Ret|K, 2), // return 2
+ },
+ expectedRet: 2,
+ },
+ } {
+ p, err := Compile(test.insns)
+ if err != nil {
+ t.Errorf("%s: unexpected compilation error: %v", test.desc, err)
+ continue
+ }
+ ret, err := Exec(p, InputBytes{test.input, binary.BigEndian})
+ if err != nil {
+ t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err)
+ continue
+ }
+ if ret != test.expectedRet {
+ t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret)
+ }
+ }
+}
+
+func TestSimpleFilter(t *testing.T) {
+ // Seccomp filter example given in Linux's
+ // Documentation/networking/filter.txt, translated to bytecode using the
+ // Linux kernel tree's tools/net/bpf_asm.
+ filter := []linux.BPFInstruction{
+ {0x20, 0, 0, 0x00000004}, // ld [4] /* offsetof(struct seccomp_data, arch) */
+ {0x15, 0, 11, 0xc000003e}, // jne #0xc000003e, bad /* AUDIT_ARCH_X86_64 */
+ {0x20, 0, 0, 0000000000}, // ld [0] /* offsetof(struct seccomp_data, nr) */
+ {0x15, 10, 0, 0x0000000f}, // jeq #15, good /* __NR_rt_sigreturn */
+ {0x15, 9, 0, 0x000000e7}, // jeq #231, good /* __NR_exit_group */
+ {0x15, 8, 0, 0x0000003c}, // jeq #60, good /* __NR_exit */
+ {0x15, 7, 0, 0000000000}, // jeq #0, good /* __NR_read */
+ {0x15, 6, 0, 0x00000001}, // jeq #1, good /* __NR_write */
+ {0x15, 5, 0, 0x00000005}, // jeq #5, good /* __NR_fstat */
+ {0x15, 4, 0, 0x00000009}, // jeq #9, good /* __NR_mmap */
+ {0x15, 3, 0, 0x0000000e}, // jeq #14, good /* __NR_rt_sigprocmask */
+ {0x15, 2, 0, 0x0000000d}, // jeq #13, good /* __NR_rt_sigaction */
+ {0x15, 1, 0, 0x00000023}, // jeq #35, good /* __NR_nanosleep */
+ {0x06, 0, 0, 0000000000}, // bad: ret #0 /* SECCOMP_RET_KILL */
+ {0x06, 0, 0, 0x7fff0000}, // good: ret #0x7fff0000 /* SECCOMP_RET_ALLOW */
+ }
+ p, err := Compile(filter)
+ if err != nil {
+ t.Fatalf("Unexpected compilation error: %v", err)
+ }
+
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // seccompData is the input data.
+ seccompData
+
+ // expectedRet is the expected return value of the BPF program.
+ expectedRet uint32
+ }{
+ {
+ desc: "Invalid arch is rejected",
+ seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */},
+ expectedRet: 0,
+ },
+ {
+ desc: "Disallowed syscall is rejected",
+ seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e},
+ expectedRet: 0,
+ },
+ {
+ desc: "Allowed syscall is indeed allowed",
+ seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e},
+ expectedRet: 0x7fff0000,
+ },
+ } {
+ ret, err := Exec(p, test.seccompData.asInput())
+ if err != nil {
+ t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err)
+ continue
+ }
+ if ret != test.expectedRet {
+ t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret)
+ }
+ }
+}
+
+// seccompData is equivalent to struct seccomp_data.
+type seccompData struct {
+ nr uint32
+ arch uint32
+ instructionPointer uint64
+ args [6]uint64
+}
+
+// asInput converts a seccompData to a bpf.Input.
+func (d *seccompData) asInput() Input {
+ return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+}
diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go
new file mode 100644
index 000000000..7992044d0
--- /dev/null
+++ b/pkg/bpf/program_builder.go
@@ -0,0 +1,191 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+const (
+ labelTarget = math.MaxUint8
+ labelDirectTarget = math.MaxUint32
+)
+
+// ProgramBuilder assists with building a BPF program with jump
+// labels that are resolved to their proper offsets.
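+//
+// A minimal usage sketch (illustrative only; the label name, values, and
+// error handling below are assumptions, not part of this change):
+//
+//	b := NewProgramBuilder()
+//	b.AddJumpTrueLabel(Jmp|Jeq|K, 42, "good", 0)
+//	b.AddStmt(Ret|K, 0) // Fall-through: return 0.
+//	if err := b.AddLabel("good"); err != nil {
+//		// Handle the error.
+//	}
+//	b.AddStmt(Ret|K, 1) // Jump target: return 1.
+//	insns, err := b.Instructions()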
+type ProgramBuilder struct {
+ // Maps label names to label objects.
+ labels map[string]*label
+
+ // Array of BPF instructions that makes up the program.
+ instructions []linux.BPFInstruction
+}
+
+// NewProgramBuilder creates a new ProgramBuilder instance.
+func NewProgramBuilder() *ProgramBuilder {
+ return &ProgramBuilder{labels: map[string]*label{}}
+}
+
+// label contains information to resolve a label to an offset.
+type label struct {
+ // List of locations that reference the label in the program.
+ sources []source
+
+	// Program line where the label is located.
+ target int
+}
+
+type jmpType int
+
+const (
+ jDirect jmpType = iota
+ jTrue
+ jFalse
+)
+
+// source contains information about a single reference to a label.
+type source struct {
+ // Program line where the label reference is present.
+ line int
+
+	// jt indicates whether the label reference is the direct target of an
+	// unconditional jump, the 'jump if true' target, or the 'jump if false'
+	// target of a conditional jump.
+	jt jmpType
+}
+
+// AddStmt adds a new statement to the program.
+func (b *ProgramBuilder) AddStmt(code uint16, k uint32) {
+ b.instructions = append(b.instructions, Stmt(code, k))
+}
+
+// AddJump adds a new jump to the program.
+func (b *ProgramBuilder) AddJump(code uint16, k uint32, jt, jf uint8) {
+ b.instructions = append(b.instructions, Jump(code, k, jt, jf))
+}
+
+// AddDirectJumpLabel adds a new unconditional jump to the program whose target is a label.
+func (b *ProgramBuilder) AddDirectJumpLabel(labelName string) {
+ b.addLabelSource(labelName, jDirect)
+ b.AddJump(Jmp|Ja, labelDirectTarget, 0, 0)
+}
+
+// AddJumpTrueLabel adds a new jump to the program where 'jump if true' is a label.
+func (b *ProgramBuilder) AddJumpTrueLabel(code uint16, k uint32, jtLabel string, jf uint8) {
+ b.addLabelSource(jtLabel, jTrue)
+ b.AddJump(code, k, labelTarget, jf)
+}
+
+// AddJumpFalseLabel adds a new jump to the program where 'jump if false' is a label.
+func (b *ProgramBuilder) AddJumpFalseLabel(code uint16, k uint32, jt uint8, jfLabel string) {
+ b.addLabelSource(jfLabel, jFalse)
+ b.AddJump(code, k, jt, labelTarget)
+}
+
+// AddJumpLabels adds a new jump to the program where both jump targets are labels.
+func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel string) {
+ b.addLabelSource(jtLabel, jTrue)
+ b.addLabelSource(jfLabel, jFalse)
+ b.AddJump(code, k, labelTarget, labelTarget)
+}
+
+// AddLabel sets the given label name at the current location. The next instruction is executed
+// when any code jumps to this label. More than one label can be added to the same location.
+func (b *ProgramBuilder) AddLabel(name string) error {
+ l, ok := b.labels[name]
+ if !ok {
+		// This is done to catch backward-jump cases, but it's not strictly
+		// wrong to have unused labels.
+		return fmt.Errorf("adding a label that hasn't been used is not allowed: %v", name)
+ }
+ if l.target != -1 {
+ return fmt.Errorf("label %q target already set: %v", name, l.target)
+ }
+ l.target = len(b.instructions)
+ return nil
+}
+
+// Instructions returns an array of BPF instructions representing the program
+// with all labels resolved. It returns an error if label resolution fails due
+// to an invalid program.
+//
+// N.B. Partial results will be returned in the error case, which is useful for debugging.
+func (b *ProgramBuilder) Instructions() ([]linux.BPFInstruction, error) {
+ if err := b.resolveLabels(); err != nil {
+ return b.instructions, err
+ }
+ return b.instructions, nil
+}
+
+func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) {
+ l, ok := b.labels[labelName]
+ if !ok {
+ l = &label{sources: make([]source, 0), target: -1}
+ b.labels[labelName] = l
+ }
+ l.sources = append(l.sources, source{line: len(b.instructions), jt: t})
+}
+
+func (b *ProgramBuilder) resolveLabels() error {
+ for key, v := range b.labels {
+ if v.target == -1 {
+ return fmt.Errorf("label target not set: %v", key)
+ }
+ if v.target >= len(b.instructions) {
+ return fmt.Errorf("target is beyond end of ProgramBuilder")
+ }
+ for _, s := range v.sources {
+ // Finds jump instruction that references the label.
+ inst := b.instructions[s.line]
+ if s.line >= v.target {
+ return fmt.Errorf("cannot jump backwards")
+ }
+ // Calculates the jump offset from current line.
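+			// For example, a jump emitted at line 3 that targets a
+			// label placed at line 5 gets offset 5 - 3 - 1 = 1,
+			// i.e. skip the single instruction between them.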
+ offset := v.target - s.line - 1
+ // Sets offset into jump instruction.
+ switch s.jt {
+ case jDirect:
+ if offset > labelDirectTarget {
+ return fmt.Errorf("jump offset to label '%v' is too large: %v, inst: %v, lineno: %v", key, offset, inst, s.line)
+ }
+ if inst.K != labelDirectTarget {
+ return fmt.Errorf("jump target is not a label")
+ }
+ inst.K = uint32(offset)
+ case jTrue:
+ if offset > labelTarget {
+ return fmt.Errorf("jump offset to label '%v' is too large: %v, inst: %v, lineno: %v", key, offset, inst, s.line)
+ }
+ if inst.JumpIfTrue != labelTarget {
+ return fmt.Errorf("jump target is not a label")
+ }
+ inst.JumpIfTrue = uint8(offset)
+ case jFalse:
+ if offset > labelTarget {
+ return fmt.Errorf("jump offset to label '%v' is too large: %v, inst: %v, lineno: %v", key, offset, inst, s.line)
+ }
+ if inst.JumpIfFalse != labelTarget {
+ return fmt.Errorf("jump target is not a label")
+ }
+ inst.JumpIfFalse = uint8(offset)
+ }
+
+ b.instructions[s.line] = inst
+ }
+ }
+ b.labels = map[string]*label{}
+ return nil
+}
diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go
new file mode 100644
index 000000000..92ca5f4c3
--- /dev/null
+++ b/pkg/bpf/program_builder_test.go
@@ -0,0 +1,157 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error {
+ instructions, err := p.Instructions()
+ if err != nil {
+ return fmt.Errorf("Instructions() failed: %v", err)
+ }
+ got, err := DecodeProgram(instructions)
+ if err != nil {
+ return fmt.Errorf("DecodeProgram('instructions') failed: %v", err)
+ }
+ expectedDecoded, err := DecodeProgram(expected)
+ if err != nil {
+ return fmt.Errorf("DecodeProgram('expected') failed: %v", err)
+ }
+ if got != expectedDecoded {
+ return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got)
+ }
+ return nil
+}
+
+func TestProgramBuilderSimple(t *testing.T) {
+ p := NewProgramBuilder()
+ p.AddStmt(Ld+Abs+W, 10)
+ p.AddJump(Jmp+Ja, 10, 0, 0)
+
+ expected := []linux.BPFInstruction{
+ Stmt(Ld+Abs+W, 10),
+ Jump(Jmp+Ja, 10, 0, 0),
+ }
+
+ if err := validate(p, expected); err != nil {
+ t.Errorf("Validate() failed: %v", err)
+ }
+}
+
+func TestProgramBuilderLabels(t *testing.T) {
+ p := NewProgramBuilder()
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 11, "label_1", 0)
+ p.AddJumpFalseLabel(Jmp+Jeq+K, 12, 0, "label_2")
+ p.AddJumpLabels(Jmp+Jeq+K, 13, "label_3", "label_4")
+ if err := p.AddLabel("label_1"); err != nil {
+ t.Errorf("AddLabel(label_1) failed: %v", err)
+ }
+ p.AddStmt(Ld+Abs+W, 1)
+ if err := p.AddLabel("label_3"); err != nil {
+ t.Errorf("AddLabel(label_3) failed: %v", err)
+ }
+ p.AddJumpLabels(Jmp+Jeq+K, 14, "label_4", "label_5")
+ if err := p.AddLabel("label_2"); err != nil {
+ t.Errorf("AddLabel(label_2) failed: %v", err)
+ }
+ p.AddJumpLabels(Jmp+Jeq+K, 15, "label_4", "label_6")
+ if err := p.AddLabel("label_4"); err != nil {
+ t.Errorf("AddLabel(label_4) failed: %v", err)
+ }
+ p.AddStmt(Ld+Abs+W, 4)
+ if err := p.AddLabel("label_5"); err != nil {
+ t.Errorf("AddLabel(label_5) failed: %v", err)
+ }
+ if err := p.AddLabel("label_6"); err != nil {
+ t.Errorf("AddLabel(label_6) failed: %v", err)
+ }
+ p.AddStmt(Ld+Abs+W, 5)
+
+ expected := []linux.BPFInstruction{
+ Jump(Jmp+Jeq+K, 11, 2, 0),
+ Jump(Jmp+Jeq+K, 12, 0, 3),
+ Jump(Jmp+Jeq+K, 13, 1, 3),
+ Stmt(Ld+Abs+W, 1),
+ Jump(Jmp+Jeq+K, 14, 1, 2),
+ Jump(Jmp+Jeq+K, 15, 0, 1),
+ Stmt(Ld+Abs+W, 4),
+ Stmt(Ld+Abs+W, 5),
+ }
+ if err := validate(p, expected); err != nil {
+ t.Errorf("Validate() failed: %v", err)
+ }
+	// Call validate() (and thus p.Instructions()) a second time to make sure
+	// that Instructions can be called multiple times without corrupting the
+	// program.
+ if err := validate(p, expected); err != nil {
+ t.Errorf("Validate() failed: %v", err)
+ }
+}
+
+func TestProgramBuilderMissingErrorTarget(t *testing.T) {
+ p := NewProgramBuilder()
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0)
+ if _, err := p.Instructions(); err == nil {
+ t.Errorf("Instructions() should have failed")
+ }
+}
+
+func TestProgramBuilderLabelWithNoInstruction(t *testing.T) {
+ p := NewProgramBuilder()
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0)
+ if err := p.AddLabel("label_1"); err != nil {
+ t.Errorf("AddLabel(label_1) failed: %v", err)
+ }
+ if _, err := p.Instructions(); err == nil {
+ t.Errorf("Instructions() should have failed")
+ }
+}
+
+func TestProgramBuilderUnusedLabel(t *testing.T) {
+ p := NewProgramBuilder()
+ if err := p.AddLabel("unused"); err == nil {
+ t.Errorf("AddLabel(unused) should have failed")
+ }
+}
+
+func TestProgramBuilderLabelAddedTwice(t *testing.T) {
+ p := NewProgramBuilder()
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0)
+ if err := p.AddLabel("label_1"); err != nil {
+ t.Errorf("AddLabel(label_1) failed: %v", err)
+ }
+ p.AddStmt(Ld+Abs+W, 0)
+ if err := p.AddLabel("label_1"); err == nil {
+ t.Errorf("AddLabel(label_1) failed: %v", err)
+ }
+}
+
+func TestProgramBuilderJumpBackwards(t *testing.T) {
+ p := NewProgramBuilder()
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0)
+ if err := p.AddLabel("label_1"); err != nil {
+ t.Errorf("AddLabel(label_1) failed: %v", err)
+ }
+ p.AddStmt(Ld+Abs+W, 0)
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0)
+ if _, err := p.Instructions(); err == nil {
+ t.Errorf("Instructions() should have failed")
+ }
+}
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
new file mode 100644
index 000000000..dcd086298
--- /dev/null
+++ b/pkg/buffer/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "buffer_list",
+ out = "buffer_list.go",
+ package = "buffer",
+ prefix = "buffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*buffer",
+ "Linker": "*buffer",
+ },
+)
+
+go_library(
+ name = "buffer",
+ srcs = [
+ "buffer.go",
+ "buffer_list.go",
+ "safemem.go",
+ "view.go",
+ "view_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/safemem",
+ ],
+)
+
+go_test(
+ name = "buffer_test",
+ size = "small",
+ srcs = [
+ "safemem_test.go",
+ "view_test.go",
+ ],
+ library = ":buffer",
+ deps = ["//pkg/safemem"],
+)
diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
new file mode 100644
index 000000000..c6d089fd9
--- /dev/null
+++ b/pkg/buffer/buffer.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+//
+// A view is a flexible buffer, backed by a pool, that natively supports the
+// safecopy operations and can grow via either prepend or append, as well as
+// shrink.
+package buffer
+
+import (
+ "sync"
+)
+
+const bufferSize = 8144 // See below.
+
+// buffer encapsulates a queueable byte buffer.
+//
+// Note that the total size is slightly less than two pages. This is done
+// intentionally to ensure that the buffer object aligns with runtime
+// internals. We have no hard size or alignment requirements. This roughly
+// two-page size minimizes internal fragmentation while keeping each chunk
+// large enough to limit excessive segmentation.
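+//
+// As a rough illustration (assuming 8-byte ints and pointers, which is
+// platform-dependent): 8144 bytes of data plus the two int indices and the
+// two list-entry pointers is 8144 + 16 + 16 = 8176 bytes, just under two
+// 4KiB pages.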
+//
+// +stateify savable
+type buffer struct {
+ data [bufferSize]byte
+ read int
+ write int
+ bufferEntry
+}
+
+// Reset resets internal data.
+//
+// This must be called before returning the buffer to the pool.
+func (b *buffer) Reset() {
+ b.read = 0
+ b.write = 0
+}
+
+// Full indicates the buffer is full.
+//
+// This means there is no capacity left to write.
+func (b *buffer) Full() bool {
+ return b.write == len(b.data)
+}
+
+// ReadSize returns the number of bytes available for reading.
+func (b *buffer) ReadSize() int {
+ return b.write - b.read
+}
+
+// ReadMove advances the read index by the given amount.
+func (b *buffer) ReadMove(n int) {
+ b.read += n
+}
+
+// ReadSlice returns the read slice for this buffer.
+func (b *buffer) ReadSlice() []byte {
+ return b.data[b.read:b.write]
+}
+
+// WriteSize returns the number of bytes available for writing.
+func (b *buffer) WriteSize() int {
+ return len(b.data) - b.write
+}
+
+// WriteMove advances the write index by the given amount.
+func (b *buffer) WriteMove(n int) {
+ b.write += n
+}
+
+// WriteSlice returns the write slice for this buffer.
+func (b *buffer) WriteSlice() []byte {
+ return b.data[b.write:]
+}
+
+// bufferPool is a pool for buffers.
+var bufferPool = sync.Pool{
+ New: func() interface{} {
+ return new(buffer)
+ },
+}
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
new file mode 100644
index 000000000..b789e56e9
--- /dev/null
+++ b/pkg/buffer/safemem.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// WriteBlock returns this buffer as a write Block.
+func (b *buffer) WriteBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.WriteSlice())
+}
+
+// ReadBlock returns this buffer as a read Block.
+func (b *buffer) ReadBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.ReadSlice())
+}
+
+// WriteFromSafememReader writes up to count bytes from r to v and advances the
+// write index by the number of bytes written. It calls r.ReadToBlocks() at
+// most once.
+func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) {
+ if count == 0 {
+ return 0, nil
+ }
+
+ var (
+ dst safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ // Need at least one buffer.
+ firstBuf := v.data.Back()
+ if firstBuf == nil {
+ firstBuf = bufferPool.Get().(*buffer)
+ v.data.PushBack(firstBuf)
+ }
+
+	// Does the last buffer have sufficient capacity alone?
+ if l := uint64(firstBuf.WriteSize()); l >= count {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count))
+ } else {
+ // Append blocks until sufficient.
+ count -= l
+ blocks = append(blocks, firstBuf.WriteBlock())
+ for count > 0 {
+ emptyBuf := bufferPool.Get().(*buffer)
+ v.data.PushBack(emptyBuf)
+ block := emptyBuf.WriteBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
+ }
+ dst = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform I/O.
+ n, err := r.ReadToBlocks(dst)
+ v.size += int64(n)
+
+ // Update all indices.
+ for left := n; left > 0; firstBuf = firstBuf.Next() {
+ if l := firstBuf.WriteSize(); left >= uint64(l) {
+ firstBuf.WriteMove(l) // Whole block.
+ left -= uint64(l)
+ } else {
+ firstBuf.WriteMove(int(left)) // Partial block.
+ left = 0
+ }
+ }
+
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the
+// write index by the number of bytes written.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes())
+}
+
+// ReadToSafememWriter reads up to count bytes from v to w. It does not advance
+// the read index. It calls w.WriteFromBlocks() at most once.
+func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) {
+ if count == 0 {
+ return 0, nil
+ }
+
+ var (
+ src safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ firstBuf := v.data.Front()
+ if firstBuf == nil {
+ return 0, nil // No EOF.
+ }
+
+ // Is all the data in a single block?
+ if l := uint64(firstBuf.ReadSize()); l >= count {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count))
+ } else {
+ // Build a list of all the buffers.
+ count -= l
+ blocks = append(blocks, firstBuf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() {
+ block := buf.ReadBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
+ }
+ src = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform I/O. As documented, we don't advance the read index.
+ return w.WriteFromBlocks(src)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance
+// the read index, so it is only safe to call if the caller guarantees that
+// ReadToBlocks will be called at most once.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes())
+}
diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go
new file mode 100644
index 000000000..47f357e0c
--- /dev/null
+++ b/pkg/buffer/safemem_test.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+func TestSafemem(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ readLen int
+ op func(*View)
+ }{
+ // Basic coverage.
+ {
+ name: "short",
+ input: "010",
+ output: "010",
+ },
+ {
+ name: "long",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize) + "0",
+ },
+ {
+ name: "short-read",
+ input: "0",
+ readLen: 100, // > size.
+ output: "0",
+ },
+ {
+ name: "zero-read",
+ input: "0",
+ output: "",
+ },
+ {
+ name: "read-empty",
+ input: "",
+ readLen: 1, // > size.
+ output: "",
+ },
+
+ // Ensure offsets work.
+ {
+ name: "offsets-short",
+ input: "012",
+ output: "2",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "offsets-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "10",
+ op: func(v *View) {
+ v.TrimFront(bufferSize)
+ },
+ },
+
+ // Ensure truncation works.
+ {
+ name: "truncate-short",
+ input: "012",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ {
+ name: "truncate-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(v *View) {
+ v.Truncate(bufferSize + 1)
+ },
+ },
+ {
+ name: "truncate-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(v *View) {
+ v.Truncate(bufferSize)
+ },
+ },
+ {
+ name: "truncate-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Construct the new view.
+ var view View
+ bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input)))
+ n, err := view.WriteFromBlocks(bs)
+ if err != nil {
+ t.Errorf("expected err nil, got %v", err)
+ }
+ if n != uint64(len(tc.input)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.input), n)
+ }
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(&view)
+ }
+
+ // Read and validate.
+ readLen := tc.readLen
+ if readLen == 0 {
+ readLen = len(tc.output) // Default.
+ }
+ out := make([]byte, readLen)
+ bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out))
+ n, err = view.ReadToBlocks(bs)
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if n != uint64(len(tc.output)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.output), n)
+ }
+
+ // Ensure the contents are correct.
+ if !bytes.Equal(out[:n], []byte(tc.output[:n])) {
+ t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out))
+ }
+ })
+ }
+}
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
new file mode 100644
index 000000000..e6901eadb
--- /dev/null
+++ b/pkg/buffer/view.go
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "fmt"
+ "io"
+)
+
+// View is a non-linear buffer.
+//
+// All methods are thread compatible.
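+//
+// A minimal usage sketch (illustrative only):
+//
+//	var v View
+//	v.Append([]byte("world"))
+//	v.Prepend([]byte("hello "))
+//	data := v.Flatten() // data == []byte("hello world")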
+//
+// +stateify savable
+type View struct {
+ data bufferList
+ size int64
+}
+
+// TrimFront removes the first count bytes from the buffer.
+func (v *View) TrimFront(count int64) {
+ if count >= v.size {
+ v.advanceRead(v.size)
+ } else {
+ v.advanceRead(count)
+ }
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (v *View) ReadAt(p []byte, offset int64) (int, error) {
+ var (
+ skipped int64
+ done int64
+ )
+ for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() {
+ needToSkip := int(offset - skipped)
+ if sz := buf.ReadSize(); sz <= needToSkip {
+ skipped += int64(sz)
+ continue
+ }
+
+ // Actually read data.
+ n := copy(p[done:], buf.ReadSlice()[needToSkip:])
+ skipped += int64(needToSkip)
+ done += int64(n)
+ }
+ if int(done) < len(p) || offset+done == v.size {
+ return int(done), io.EOF
+ }
+ return int(done), nil
+}
+
+// advanceRead advances the view's read index.
+//
+// Precondition: there must be sufficient bytes in the buffer.
+func (v *View) advanceRead(count int64) {
+ for buf := v.data.Front(); buf != nil && count > 0; {
+ sz := int64(buf.ReadSize())
+ if sz > count {
+ // There is still data for reading.
+ buf.ReadMove(int(count))
+ v.size -= count
+ count = 0
+ break
+ }
+
+ // Consume the whole buffer.
+ oldBuf := buf
+ buf = buf.Next() // Iterate.
+ v.data.Remove(oldBuf)
+ oldBuf.Reset()
+ bufferPool.Put(oldBuf)
+
+ // Update counts.
+ count -= sz
+ v.size -= sz
+ }
+ if count > 0 {
+ panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count))
+ }
+}
+
+// Truncate truncates the view to the given bytes.
+//
+// This will not grow the view, only shrink it. If a length is passed that is
+// greater than the current size of the view, then nothing will happen.
+//
+// Precondition: length must be >= 0.
+func (v *View) Truncate(length int64) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ if length >= v.size {
+ return // Nothing to do.
+ }
+ for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() {
+ sz := int64(buf.ReadSize())
+ if after := v.size - sz; after < length {
+ // Truncate the buffer locally.
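+			// For example, if the view holds 10 bytes, this last
+			// buffer holds 4 of them (so after == 6), and length
+			// == 8, then left == 2 and the buffer keeps only its
+			// first two unread bytes.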
+ left := (length - after)
+ buf.write = buf.read + int(left)
+ v.size = length
+ break
+ }
+
+ // Drop the buffer completely; see above.
+ v.data.Remove(buf)
+ buf.Reset()
+ bufferPool.Put(buf)
+ v.size -= sz
+ }
+}
+
+// Grow grows the view to at least length bytes, appending buffers as needed.
+// If zero is true, the appended bytes will be zeroed. If zero is false,
+// zeroing them is the caller's responsibility.
+//
+// Precondition: length must be >= 0.
+func (v *View) Grow(length int64, zero bool) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ for v.size < length {
+ buf := v.data.Back()
+
+ // Is there some space in the last buffer?
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Write up to length bytes.
+ sz := buf.WriteSize()
+ if int64(sz) > length-v.size {
+ sz = int(length - v.size)
+ }
+
+ // Zero the written section; note that this pattern is
+ // specifically recognized and optimized by the compiler.
+ if zero {
+ for i := buf.write; i < buf.write+sz; i++ {
+ buf.data[i] = 0
+ }
+ }
+
+ // Advance the index.
+ buf.WriteMove(sz)
+ v.size += int64(sz)
+ }
+}
+
+// Prepend prepends the given data.
+func (v *View) Prepend(data []byte) {
+ // Is there any space in the first buffer?
+ if buf := v.data.Front(); buf != nil && buf.read > 0 {
+ // Fill up before the first write.
+ avail := buf.read
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read -= n
+ }
+
+ for len(data) > 0 {
+		// Push a fresh buffer to the front for the remaining data.
+ buf := bufferPool.Get().(*buffer)
+ v.data.PushFront(buf)
+
+ // The buffer is empty; copy last chunk.
+ avail := len(buf.data)
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+
+ // We have to put the data at the end of the current
+ // buffer in order to ensure that the next prepend will
+ // correctly fill up the beginning of this buffer.
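+		//
+		// For example (using a hypothetical 8-byte buffer for
+		// brevity), prepending "abc" to a fresh buffer copies it to
+		// data[5:8] and sets read=5, write=8, so a later Prepend can
+		// backfill data[0:5].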
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read = len(buf.data) - n
+ buf.write = len(buf.data)
+ }
+}
+
+// Append appends the given data.
+func (v *View) Append(data []byte) {
+ for done := 0; done < len(data); {
+ buf := v.data.Back()
+
+ // Ensure there's a buffer with space.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Copy in to the given buffer.
+ n := copy(buf.WriteSlice(), data[done:])
+ done += n
+ buf.WriteMove(n)
+ v.size += int64(n)
+ }
+}
+
+// Flatten returns a flattened copy of this data.
+//
+// This method should not be used in any performance-sensitive paths. It may
+// allocate a fresh byte slice sufficiently large to contain all the data in
+// the buffer. This is principally for debugging.
+//
+// N.B. The data may still belong to this view: if there is a single buffer
+// present, its read slice is returned directly. The result is for temporary
+// use only, and a reference to the returned slice should not be held.
+func (v *View) Flatten() []byte {
+ if buf := v.data.Front(); buf == nil {
+ return nil // No data at all.
+ } else if buf.Next() == nil {
+ return buf.ReadSlice() // Only one buffer.
+ }
+ data := make([]byte, 0, v.size) // Need to flatten.
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ // Copy to the allocated slice.
+ data = append(data, buf.ReadSlice()...)
+ }
+ return data
+}
+
+// Size indicates the total amount of data available in this view.
+func (v *View) Size() int64 {
+ return v.size
+}
+
+// Copy makes a strict copy of this view.
+func (v *View) Copy() (other View) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ other.Append(buf.ReadSlice())
+ }
+ return
+}
+
+// Apply applies the given function across all valid data.
+func (v *View) Apply(fn func([]byte)) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ fn(buf.ReadSlice())
+ }
+}
+
+// Merge merges the provided View with this one.
+//
+// The other view will be appended to v, and other will be empty after this
+// operation completes.
+func (v *View) Merge(other *View) {
+ // Copy over all buffers.
+ for buf := other.data.Front(); buf != nil; buf = other.data.Front() {
+ other.data.Remove(buf)
+ v.data.PushBack(buf)
+ }
+
+ // Adjust sizes.
+ v.size += other.size
+ other.size = 0
+}
+
+// WriteFromReader writes to the buffer from an io.Reader.
+//
+// A minimum read size equal to unsafe.Sizeof(uintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ for done < count {
+ buf := v.data.Back()
+
+ // Ensure we have an empty buffer.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Is this less than the minimum batch?
+ if buf.WriteSize() < minBatch && (count-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = r.Read(tmp)
+ v.Append(tmp[:n])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the read, if necessary.
+ sz := buf.WriteSize()
+ if left := count - done; int64(sz) > left {
+ sz = int(left)
+ }
+
+ // Pass the relevant portion of the buffer.
+ n, err = r.Read(buf.WriteSlice()[:sz])
+ buf.WriteMove(n)
+ done += int64(n)
+ v.size += int64(n)
+ if err == io.EOF {
+ err = nil // Short write allowed.
+ break
+ } else if err != nil {
+ break
+ }
+ }
+ return done, err
+}
+
+// ReadToWriter reads from the buffer into an io.Writer.
+//
+// N.B. This does not consume the bytes read. TrimFront should
+// be called appropriately after this call in order to do so.
+//
+// A minimum write size equal to unsafe.Sizeof(uintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ offset := 0 // Spill-over for batching.
+ for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() {
+ // Has this been consumed? Skip it.
+ sz := buf.ReadSize()
+ if sz <= offset {
+ offset -= sz
+ continue
+ }
+ sz -= offset
+
+ // Is this less than the minimum batch?
+ left := count - done
+ if sz < minBatch && left >= int64(minBatch) && (v.size-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = v.ReadAt(tmp, done)
+ w.Write(tmp[:n])
+ done += int64(n)
+ offset = n - sz // Reset below.
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the write if necessary.
+ if int64(sz) >= left {
+ sz = int(left)
+ }
+
+ // Perform the actual write.
+ n, err = w.Write(buf.ReadSlice()[offset : offset+sz])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+
+ // Reset spill-over.
+ offset = 0
+ }
+ return done, err
+}
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
new file mode 100644
index 000000000..3db1bc6ee
--- /dev/null
+++ b/pkg/buffer/view_test.go
@@ -0,0 +1,467 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "testing"
+)
+
+func fillAppend(v *View, data []byte) {
+ v.Append(data)
+}
+
+func fillAppendEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ v.Append(data)
+ v.TrimFront(bufferSize - 1)
+}
+
+func fillWriteFromReader(v *View, data []byte) {
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+}
+
+func fillWriteFromReaderEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+ v.TrimFront(bufferSize - 1)
+}
+
+var fillFuncs = map[string]func(*View, []byte){
+ "append": fillAppend,
+ "appendEnd": fillAppendEnd,
+ "writeFromReader": fillWriteFromReader,
+ "writeFromReaderEnd": fillWriteFromReaderEnd,
+}
+
+func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) {
+ t.Helper()
+ d := make([]byte, n)
+ n, err := v.ReadAt(d, offset)
+ if n != len(wantStr) {
+ t.Errorf("got %d, want %d", n, len(wantStr))
+ }
+ if err != wantErr {
+ t.Errorf("got err %v, want %v", err, wantErr)
+ }
+ if !bytes.Equal(d[:n], []byte(wantStr)) {
+ t.Errorf("got %q, want %q", string(d[:n]), wantStr)
+ }
+}
+
+func TestView(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ op func(*testing.T, *View)
+ }{
+ // Preconditions.
+ {
+ name: "truncate-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Truncate(-1) did not panic")
+ }
+ }()
+ v.Truncate(-1)
+ },
+ },
+ {
+ name: "grow-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Grow(-1) did not panic")
+ }
+ }()
+ v.Grow(-1, false)
+ },
+ },
+ {
+ name: "advance-check",
+ input: "hello",
+ output: "", // Consumed.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("advanceRead(Size()+1) did not panic")
+ }
+ }()
+ v.advanceRead(v.Size() + 1)
+ },
+ },
+
+ // Prepend.
+ {
+ name: "prepend",
+ input: "world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("hello "))
+ },
+ },
+ {
+ name: "prepend-backfill-full",
+ input: "hello world",
+ output: "jello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("j"))
+ },
+ },
+ {
+ name: "prepend-backfill-under",
+ input: "hello world",
+ output: "hola world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(5)
+ v.Prepend([]byte("hola"))
+ },
+ },
+ {
+ name: "prepend-backfill-over",
+ input: "hello world",
+ output: "smello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("sm"))
+ },
+ },
+ {
+ name: "prepend-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Append and write.
+ {
+ name: "append",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(" world"))
+ },
+ },
+ {
+ name: "append-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3),
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Truncate.
+ {
+ name: "truncate",
+ input: "hello world",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+ {
+ name: "truncate-noop",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(v.Size() + 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.Truncate(bufferSize*2 - 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers-to-one",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "11111",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+
+ // TrimFront.
+ {
+ name: "trim",
+ input: "hello world",
+ output: "world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(6)
+ },
+ },
+ {
+ name: "trim-too-large",
+ input: "hello world",
+ output: "",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(v.Size() + 1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers-to-one-buffer",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "1",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(bufferSize*2 - 1)
+ },
+ },
+
+ // Grow.
+ {
+ name: "grow",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Grow(1, true)
+ },
+ },
+ {
+ name: "grow-from-zero",
+ output: strings.Repeat("\x00", 1024),
+ op: func(t *testing.T, v *View) {
+ v.Grow(1024, true)
+ },
+ },
+ {
+ name: "grow-from-non-zero",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Grow(bufferSize*2, true)
+ },
+ },
+
+ // Copy.
+ {
+ name: "copy",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte("hello")
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+ {
+ name: "copy-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte(strings.Repeat("1", bufferSize+1))
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+
+ // Merge.
+ {
+ name: "merge",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(" world"))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+ {
+ name: "merge-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(strings.Repeat("0", bufferSize+1)))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+
+ // ReadAt.
+ {
+ name: "readat",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) },
+ },
+ {
+ name: "readat-long",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) },
+ },
+ {
+ name: "readat-short",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) },
+ },
+ {
+ name: "readat-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) },
+ },
+ {
+ name: "readat-long-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) },
+ },
+ {
+ name: "readat-short-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) },
+ },
+ {
+ name: "readat-skip-all",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) },
+ },
+ {
+ name: "readat-second-buffer",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) },
+ },
+ {
+ name: "readat-second-buffer-end",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) },
+ },
+ }
+
+ for _, tc := range testCases {
+ for fillName, fn := range fillFuncs {
+ t.Run(fillName+"/"+tc.name, func(t *testing.T) {
+ // Construct & fill the view.
+ var view View
+ fn(&view, []byte(tc.input))
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(t, &view)
+ }
+
+ // Flatten and validate.
+ out := view.Flatten()
+ if !bytes.Equal([]byte(tc.output), out) {
+ t.Errorf("expected %q, got %q", tc.output, string(out))
+ }
+
+ // Ensure the size is correct.
+ if len(out) != int(view.Size()) {
+ t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size())
+ }
+
+ // Calculate contents via apply.
+ var appliedOut []byte
+ view.Apply(func(b []byte) {
+ appliedOut = append(appliedOut, b...)
+ })
+ if len(appliedOut) != len(out) {
+ t.Errorf("expected %d, got %d", len(out), len(appliedOut))
+ }
+ if !bytes.Equal(appliedOut, out) {
+ t.Errorf("expected %v, got %v", out, appliedOut)
+ }
+
+ // Calculate contents via ReadToWriter.
+ var b bytes.Buffer
+ n, err := view.ReadToWriter(&b, int64(len(out)))
+ if n != int64(len(out)) {
+ t.Errorf("expected %d, got %d", len(out), n)
+ }
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if !bytes.Equal(b.Bytes(), out) {
+ t.Errorf("expected %v, got %v", out, b.Bytes())
+ }
+ })
+ }
+ }
+}
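
Taken together, the table above doubles as documentation for the View API. As a minimal, editorial usage sketch (assuming only the methods exercised by the tests and that the package is importable as gvisor.dev/gvisor/pkg/buffer):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/buffer"
)

func main() {
	var v buffer.View
	v.Append([]byte("hello "))
	v.Append([]byte("world"))
	v.TrimFront(6) // Drop "hello ".

	fmt.Println(string(v.Flatten())) // "world"
	fmt.Println(v.Size())            // 5
}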
diff --git a/pkg/buffer/view_unsafe.go b/pkg/buffer/view_unsafe.go
new file mode 100644
index 000000000..d1ef39b26
--- /dev/null
+++ b/pkg/buffer/view_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "unsafe"
+)
+
+// minBatch is the smallest Read or Write operation that the
+// WriteFromReader and ReadToWriter functions will use.
+//
+// This is defined as the size of a native pointer.
+const minBatch = int(unsafe.Sizeof(uintptr(0)))
diff --git a/pkg/cleanup/BUILD b/pkg/cleanup/BUILD
new file mode 100644
index 000000000..5c34b9872
--- /dev/null
+++ b/pkg/cleanup/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cleanup",
+ srcs = ["cleanup.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ ],
+)
+
+go_test(
+ name = "cleanup_test",
+ srcs = ["cleanup_test.go"],
+ library = ":cleanup",
+)
diff --git a/pkg/cleanup/cleanup.go b/pkg/cleanup/cleanup.go
new file mode 100644
index 000000000..14a05f076
--- /dev/null
+++ b/pkg/cleanup/cleanup.go
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cleanup provides utilities to clean up resources on defers.
+package cleanup
+
+// Cleanup allows defers to be aborted when cleanup needs to happen
+// conditionally. Usage:
+// cu := cleanup.Make(func() { f.Close() })
+// defer cu.Clean() // failure before Release is called will close the file.
+// ...
+// cu.Add(func() { f2.Close() }) // Adds another cleanup function
+// ...
+// cu.Release() // on success, aborts closing the file.
+// return f
+type Cleanup struct {
+ cleaners []func()
+}
+
+// Make creates a new Cleanup object.
+func Make(f func()) Cleanup {
+ return Cleanup{cleaners: []func(){f}}
+}
+
+// Add adds a new function to be called on Clean().
+func (c *Cleanup) Add(f func()) {
+ c.cleaners = append(c.cleaners, f)
+}
+
+// Clean calls all cleanup functions in reverse order.
+func (c *Cleanup) Clean() {
+ clean(c.cleaners)
+ c.cleaners = nil
+}
+
+// Release releases the cleanup from its duties, i.e. cleanup functions are not
+// called after this point. Returns a function that calls all registered
+// functions in case the caller has use for them.
+func (c *Cleanup) Release() func() {
+ old := c.cleaners
+ c.cleaners = nil
+ return func() { clean(old) }
+}
+
+func clean(cleaners []func()) {
+ for i := len(cleaners) - 1; i >= 0; i-- {
+ cleaners[i]()
+ }
+}
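
The conditional-cleanup pattern in the type comment generalizes to multi-resource constructors. A short editorial sketch (openBoth and the paths are hypothetical):

package example

import (
	"os"

	"gvisor.dev/gvisor/pkg/cleanup"
)

// openBoth opens two files and returns both only on success; on any error
// path, the cleanup closes whatever was already opened.
func openBoth(pathA, pathB string) (*os.File, *os.File, error) {
	fa, err := os.Open(pathA)
	if err != nil {
		return nil, nil, err
	}
	cu := cleanup.Make(func() { fa.Close() })
	defer cu.Clean() // Closes fa unless Release is called below.

	fb, err := os.Open(pathB)
	if err != nil {
		return nil, nil, err // cu.Clean() closes fa.
	}

	cu.Release() // Success: ownership of fa and fb passes to the caller.
	return fa, fb, nil
}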
diff --git a/pkg/cleanup/cleanup_test.go b/pkg/cleanup/cleanup_test.go
new file mode 100644
index 000000000..ab3d9ed95
--- /dev/null
+++ b/pkg/cleanup/cleanup_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cleanup
+
+import "testing"
+
+func testCleanupHelper(clean, cleanAdd *bool, release bool) func() {
+ cu := Make(func() {
+ *clean = true
+ })
+ cu.Add(func() {
+ *cleanAdd = true
+ })
+ defer cu.Clean()
+ if release {
+ return cu.Release()
+ }
+ return nil
+}
+
+func TestCleanup(t *testing.T) {
+ clean := false
+ cleanAdd := false
+ testCleanupHelper(&clean, &cleanAdd, false)
+ if !clean {
+ t.Fatalf("cleanup function was not called.")
+ }
+ if !cleanAdd {
+ t.Fatalf("added cleanup function was not called.")
+ }
+}
+
+func TestRelease(t *testing.T) {
+ clean := false
+ cleanAdd := false
+ cleaner := testCleanupHelper(&clean, &cleanAdd, true)
+
+ // Check that clean was not called after release.
+ if clean {
+ t.Fatalf("cleanup function was called.")
+ }
+ if cleanAdd {
+ t.Fatalf("added cleanup function was called.")
+ }
+
+ // Call the cleaner function and check that both cleanup functions are called.
+ cleaner()
+ if !clean {
+ t.Fatalf("cleanup function was not called.")
+ }
+ if !cleanAdd {
+ t.Fatalf("added cleanup function was not called.")
+ }
+}
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
new file mode 100644
index 000000000..1f75319a7
--- /dev/null
+++ b/pkg/compressio/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "compressio",
+ srcs = ["compressio.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "compressio_test",
+ size = "medium",
+ srcs = ["compressio_test.go"],
+ library = ":compressio",
+)
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
new file mode 100644
index 000000000..b094c5662
--- /dev/null
+++ b/pkg/compressio/compressio.go
@@ -0,0 +1,773 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// | chunk size (4-bytes) |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | compressed data |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | ...... |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order
+//
+// compressed data
+// compressed data size
+// previous hash
+//
+// so the stream integrity cannot be compromised by switching and mixing
+// compressed chunks.
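+//
+// Concretely, the chaining can be sketched as (|| denotes concatenation):
+//
+//   hash[0]   = HMAC-SHA256(key, chunk size)
+//   hash[n+1] = HMAC-SHA256(key, compressed data || compressed data size || hash[n])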
+package compressio
+
+import (
+ "bytes"
+ "compress/flate"
+ "crypto/hmac"
+ "crypto/sha256"
+ "errors"
+ "hash"
+ "io"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
+
+var chunkPool = sync.Pool{
+ New: func() interface{} {
+ return new(chunk)
+ },
+}
+
+// chunk is a unit of work.
+type chunk struct {
+ // compressed is compressed data.
+ //
+ // This will always be returned to the bufPool directly when work has
+ // finished (in schedule) and therefore must be allocated.
+ compressed *bytes.Buffer
+
+ // uncompressed is the uncompressed data.
+ //
+ // This is not returned to the bufPool automatically, since it may
+ // correspond to an inline slice (provided directly to Read or Write).
+ uncompressed *bytes.Buffer
+
+ // The current hash object. Only used in compress mode.
+ h hash.Hash
+
+ // The hash from previous chunks. Only used in uncompress mode.
+ lastSum []byte
+
+ // The expected hash after current chunk. Only used in uncompress mode.
+ sum []byte
+}
+
+// newChunk allocates a new chunk object (or pulls one from the pool). Buffers
+// will be allocated if nil is provided for compressed or uncompressed.
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+ c := chunkPool.Get().(*chunk)
+ c.lastSum = lastSum
+ c.sum = sum
+ if compressed != nil {
+ c.compressed = compressed
+ } else {
+ c.compressed = bufPool.Get().(*bytes.Buffer)
+ }
+ if uncompressed != nil {
+ c.uncompressed = uncompressed
+ } else {
+ c.uncompressed = bufPool.Get().(*bytes.Buffer)
+ }
+ return c
+}
+
+// result is the result of some work; it includes the original chunk.
+type result struct {
+ *chunk
+ err error
+}
+
+// worker is a compression/decompression worker.
+//
+// The associated worker goroutine reads in uncompressed buffers from input and
+// writes compressed buffers to its output. Alternatively, the worker reads
+// compressed buffers from input and writes uncompressed buffers to its output.
+//
+// The goroutine will exit when input is closed, and the goroutine will close
+// output.
+type worker struct {
+ hashPool *hashPool
+ input chan *chunk
+ output chan result
+}
+
+// work is the main work routine; see worker.
+func (w *worker) work(compress bool, level int) {
+ defer close(w.output)
+
+ var h hash.Hash
+
+ for c := range w.input {
+ if h == nil && w.hashPool != nil {
+ h = w.hashPool.getHash()
+ }
+ if compress {
+ mw := io.Writer(c.compressed)
+ if h != nil {
+ mw = io.MultiWriter(mw, h)
+ }
+
+ // Encode this slice.
+ fw, err := flate.NewWriter(mw, level)
+ if err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+
+ // Encode the input.
+ if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+ if err := fw.Close(); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+
+ // Write the hash, if enabled.
+ if h != nil {
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ c.h = h
+ h = nil
+ }
+ } else {
+ // Check the hash of the compressed contents.
+ if h != nil {
+ h.Write(c.compressed.Bytes())
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+ sum := h.Sum(nil)
+ h.Reset()
+ if !hmac.Equal(c.sum, sum) {
+ w.output <- result{c, ErrHashMismatch}
+ continue
+ }
+ }
+
+ // Decode this slice.
+ fr := flate.NewReader(c.compressed)
+
+ // Decode the input.
+ if _, err := io.Copy(c.uncompressed, fr); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+ }
+
+ // Send the output.
+ w.output <- result{c, nil}
+ }
+}
+
+type hashPool struct {
+ // mu protects the hash list.
+ mu sync.Mutex
+
+ // key is the key used to create hash objects.
+ key []byte
+
+ // hashes is the hash object free list. Note that this cannot be
+ // globally shared across readers or writers, as it is key-specific.
+ hashes []hash.Hash
+}
+
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *hashPool) getHash() hash.Hash {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if len(p.hashes) == 0 {
+ return hmac.New(sha256.New, p.key)
+ }
+
+ h := p.hashes[len(p.hashes)-1]
+ p.hashes = p.hashes[:len(p.hashes)-1]
+ return h
+}
+
+func (p *hashPool) putHash(h hash.Hash) {
+ h.Reset()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.hashes = append(p.hashes, h)
+}
+
+// pool is common functionality for reader/writers.
+type pool struct {
+ // workers are the compression/decompression workers.
+ workers []worker
+
+ // chunkSize is the chunk size. This is the first four bytes in the
+ // stream and is shared across both the reader and writer.
+ chunkSize uint32
+
+ // mu protects below; it is generally the responsibility of users to
+ // acquire this mutex before calling any methods on the pool.
+ mu sync.Mutex
+
+ // nextInput is the next worker for input (scheduling).
+ nextInput int
+
+ // nextOutput is the next worker for output (result).
+ nextOutput int
+
+ // buf is the currently active buffer; its exact semantics depend on
+ // whether this is a reader or a writer.
+ buf *bytes.Buffer
+
+ // lastSum records the hash of the last chunk processed.
+ lastSum []byte
+
+ // hashPool is the hash object pool. It cannot be embedded into pool
+ // itself as worker refers to it and that would stop pool from being
+ // GCed.
+ hashPool *hashPool
+}
+
+// init initializes the worker pool.
+//
+// This should only be called once.
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+ if key != nil {
+ p.hashPool = &hashPool{key: key}
+ }
+ p.workers = make([]worker, workers)
+ for i := 0; i < len(p.workers); i++ {
+ p.workers[i] = worker{
+ hashPool: p.hashPool,
+ input: make(chan *chunk, 1),
+ output: make(chan result, 1),
+ }
+ go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
+ }
+ runtime.SetFinalizer(p, (*pool).stop)
+}
+
+// stop stops all workers.
+func (p *pool) stop() {
+ for i := 0; i < len(p.workers); i++ {
+ close(p.workers[i].input)
+ }
+ p.workers = nil
+ p.hashPool = nil
+}
+
+// handleResult calls the callback.
+func handleResult(r result, callback func(*chunk) error) error {
+ defer func() {
+ r.chunk.compressed.Reset()
+ bufPool.Put(r.chunk.compressed)
+ chunkPool.Put(r.chunk)
+ }()
+ if r.err != nil {
+ return r.err
+ }
+ return callback(r.chunk)
+}
+
+// schedule schedules the given buffers.
+//
+// If c is non-nil, then it will return as soon as the chunk is scheduled. If c
+// is nil, then it will return only when no more work is left to do.
+//
+// If no callback function is provided, then the output channel will be
+// ignored. You must be sure that the input is schedulable in this case.
+func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
+ for {
+ var (
+ inputChan chan *chunk
+ outputChan chan result
+ )
+ if c != nil && len(p.workers) != 0 {
+ inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input
+ }
+ if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 {
+ outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output
+ }
+ if inputChan == nil && outputChan == nil {
+ return nil
+ }
+
+ select {
+ case inputChan <- c:
+ p.nextInput++
+ return nil
+ case r := <-outputChan:
+ p.nextOutput++
+ if err := handleResult(r, callback); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+// Reader is a compressed reader.
+type Reader struct {
+ pool
+
+ // in is the source.
+ in io.Reader
+}
+
+var _ io.Reader = (*Reader)(nil)
+
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (*Reader, error) {
+ r := &Reader{
+ in: in,
+ }
+
+ // Use double buffering for read.
+ r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
+ var err error
+ if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
+ return nil, err
+ }
+
+ if r.hashPool != nil {
+ h := r.hashPool.getHash()
+ binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+ r.lastSum = h.Sum(nil)
+ r.hashPool.putHash(h)
+ sum := make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(r.lastSum, sum) {
+ return nil, ErrHashMismatch
+ }
+ }
+
+ return r, nil
+}
+
+// errNewBuffer is returned when a new buffer is completed.
+var errNewBuffer = errors.New("buffer ready")
+
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
+// ReadByte implements wire.Reader.ReadByte.
+func (r *Reader) ReadByte() (byte, error) {
+ var p [1]byte
+ n, err := r.Read(p[:])
+ if n != 1 {
+ return p[0], err
+ }
+ // Suppress EOF.
+ return p[0], nil
+}
+
+// Read implements io.Reader.Read.
+func (r *Reader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // Total bytes completed; this is declared up front because it must be
+ // adjustable by the callback below.
+ done := 0
+
+ // Total bytes pending in the asynchronous workers for buffers. This is
+ // used to process the proper regions of the input as inline buffers.
+ var (
+ pendingPre = r.nextInput - r.nextOutput
+ pendingInline = 0
+ )
+
+ // Define our callback for completed work.
+ callback := func(c *chunk) error {
+ // Check for an inline buffer.
+ if pendingPre == 0 && pendingInline > 0 {
+ pendingInline--
+ done += c.uncompressed.Len()
+ return nil
+ }
+
+ // Copy the resulting buffer to our intermediate one, and
+ // return errNewBuffer to ensure that we aren't called a second
+ // time. This error code is handled specially below.
+ //
+ // c.uncompressed will be freed and returned to the pool when it is done.
+ if pendingPre > 0 {
+ pendingPre--
+ }
+ r.buf = c.uncompressed
+ return errNewBuffer
+ }
+
+ for done < len(p) {
+ // Do we have buffered data available?
+ if r.buf != nil {
+ n, err := r.buf.Read(p[done:])
+ done += n
+ if err == io.EOF {
+ // This is the uncompressed buffer, it can be
+ // returned to the pool at this point.
+ r.buf.Reset()
+ bufPool.Put(r.buf)
+ r.buf = nil
+ } else if err != nil {
+ // Should never happen.
+ defer r.stop()
+ return done, err
+ }
+ continue
+ }
+
+ // Read the length of the next chunk and reset the
+ // reader. The length is used to limit the reader.
+ //
+ // See writer.flush.
+ l, err := binary.ReadUint32(r.in, binary.BigEndian)
+ if err != nil {
+ // This is generally okay as long as there
+ // are still buffers outstanding. We actually
+ // just wait for completion of those buffers here
+ // and continue our loop.
+ if err := r.schedule(nil, callback); err == nil {
+ // We've actually finished all buffers; this is
+ // the normal EOF exit path.
+ defer r.stop()
+ return done, io.EOF
+ } else if err == errNewBuffer {
+ // A new buffer is now available.
+ continue
+ } else {
+ // Some other error occurred; we cannot
+ // process any further.
+ defer r.stop()
+ return done, err
+ }
+ }
+
+ // Read this chunk and schedule decompression.
+ compressed := bufPool.Get().(*bytes.Buffer)
+ if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
+ // Some other error occurred; see above.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+
+ var sum []byte
+ if r.hashPool != nil {
+ sum = make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+ }
+
+ // Are we doing inline decoding?
+ //
+ // Note that we need to check the length here against
+ // bytes.MinRead, since the bytes library will choose to grow
+ // the slice if the available capacity is not at least
+ // bytes.MinRead. This limits inline decoding to chunkSizes
+ // that are at least bytes.MinRead (which is not unreasonable).
+ var c *chunk
+ start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
+ if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
+ c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
+ pendingInline++
+ } else {
+ c = newChunk(r.lastSum, sum, compressed, nil)
+ }
+ r.lastSum = sum
+ if err := r.schedule(c, callback); err == errNewBuffer {
+ // A new buffer was completed while we were reading.
+ // That's great, but we need to force schedule the
+ // current buffer so that it does not get lost.
+ //
+ // It is safe to pass nil as an output function here,
+ // because we know that we just freed up a slot above.
+ r.schedule(c, nil)
+ } else if err != nil {
+ // Some other error occurred; see above.
+ defer r.stop()
+ return done, err
+ }
+ }
+
+ // Make sure that everything has been decoded successfully; otherwise,
+ // parts of p may not actually have completed.
+ for pendingInline > 0 {
+ if err := r.schedule(nil, func(c *chunk) error {
+ if err := callback(c); err != nil {
+ return err
+ }
+ // A nil return from callback means that an inline
+ // buffer has completed. The callback will have already
+ // decremented pendingInline, so we just return an
+ // error to check the top of the loop again.
+ return errNewBuffer
+ }); err != errNewBuffer {
+ // Some other error occurred; see above.
+ return done, err
+ }
+ }
+
+ // Need to return done here, since it may have been adjusted by the
+ // callback to compensate for partial reads on some inline buffer.
+ return done, nil
+}
+
+// Writer is a compressed writer.
+type Writer struct {
+ pool
+
+ // out is the underlying writer.
+ out io.Writer
+
+ // closed indicates whether the file has been closed.
+ closed bool
+}
+
+var _ io.Writer = (*Writer)(nil)
+
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
+//
+// The recommended chunkSize is on the order of 1M. Extra memory may be
+// buffered (in the form of read-ahead, or buffered writes), and is limited to
+// O(chunkSize * [1+GOMAXPROCS]).
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
+ w := &Writer{
+ pool: pool{
+ chunkSize: chunkSize,
+ buf: bufPool.Get().(*bytes.Buffer),
+ },
+ out: out,
+ }
+ w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
+ if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
+ return nil, err
+ }
+
+ if w.hashPool != nil {
+ h := w.hashPool.getHash()
+ binary.WriteUint32(h, binary.BigEndian, chunkSize)
+ w.lastSum = h.Sum(nil)
+ w.hashPool.putHash(h)
+ if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+ return nil, err
+ }
+ }
+
+ return w, nil
+}
+
+// flush writes a single buffer.
+func (w *Writer) flush(c *chunk) error {
+ // Prefix each chunk with a length; this allows the reader to safely
+ // limit reads while buffering.
+ l := uint32(c.compressed.Len())
+ if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil {
+ return err
+ }
+
+ // Write out to the stream.
+ if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+ return err
+ }
+
+ if w.hashPool != nil {
+ io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+ sum := c.h.Sum(nil)
+ w.hashPool.putHash(c.h)
+ c.h = nil
+ if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+ return err
+ }
+ w.lastSum = sum
+ }
+
+ return nil
+}
+
+// WriteByte implements wire.Writer.WriteByte.
+//
+// Note that this implementation is necessary on the object itself, as an
+// interface-based dispatch cannot tell whether the array backing the slice
+// escapes; therefore, all bytes written will generate an escape.
+func (w *Writer) WriteByte(b byte) error {
+ var p [1]byte
+ p[0] = b
+ n, err := w.Write(p[:])
+ if n != 1 {
+ return err
+ }
+ return nil
+}
+
+// Write implements io.Writer.Write.
+func (w *Writer) Write(p []byte) (int, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we close already?
+ if w.closed {
+ return 0, io.ErrUnexpectedEOF
+ }
+
+ // See above; we need to track in the same way.
+ var (
+ pendingPre = w.nextInput - w.nextOutput
+ pendingInline = 0
+ )
+ callback := func(c *chunk) error {
+ if pendingPre == 0 && pendingInline > 0 {
+ pendingInline--
+ return w.flush(c)
+ }
+ if pendingPre > 0 {
+ pendingPre--
+ }
+ err := w.flush(c)
+ c.uncompressed.Reset()
+ bufPool.Put(c.uncompressed)
+ return err
+ }
+
+ for done := 0; done < len(p); {
+ // Construct an inline buffer if we're doing an inline
+ // encoding; see above regarding the bytes.MinRead constraint.
+ if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead {
+ bufPool.Put(w.buf) // Return to the pool; never scheduled.
+ w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)])
+ done += int(w.chunkSize)
+ pendingInline++
+ }
+
+ // Do we need to flush w.buf? Note that this case should be hit
+ // immediately following the inline case above.
+ left := int(w.chunkSize) - w.buf.Len()
+ if left == 0 {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
+ return done, err
+ }
+ // Reset the buffer, since this has now been scheduled
+ // for compression. Note that this may be trampled
+ // immediately by the bufPool.Put(w.buf) above if the
+ // next buffer happens to be inline, but that's okay.
+ w.buf = bufPool.Get().(*bytes.Buffer)
+ continue
+ }
+
+ // Read from p into w.buf.
+ toWrite := len(p) - done
+ if toWrite > left {
+ toWrite = left
+ }
+ n, err := w.buf.Write(p[done : done+toWrite])
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ // Make sure that everything has been flushed; we can't return until
+ // all the contents from p have been used.
+ for pendingInline > 0 {
+ if err := w.schedule(nil, func(c *chunk) error {
+ if err := callback(c); err != nil {
+ return err
+ }
+ // The flush was successful, return errNewBuffer here
+ // to break from the loop and check the condition
+ // again.
+ return errNewBuffer
+ }); err != errNewBuffer {
+ return len(p), err
+ }
+ }
+
+ return len(p), nil
+}
+
+// Close implements io.Closer.Close.
+func (w *Writer) Close() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we already close? After the call to Close, we always mark as
+ // closed, regardless of whether the flush is successful.
+ if w.closed {
+ return io.ErrUnexpectedEOF
+ }
+ w.closed = true
+ defer w.stop()
+
+ // Schedule any remaining partial buffer; we pass w.flush directly here
+ // because the final buffer is guaranteed to not be an inline buffer.
+ if w.buf.Len() > 0 {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
+ return err
+ }
+ }
+
+ // Flush all scheduled buffers; see above.
+ if err := w.schedule(nil, w.flush); err != nil {
+ return err
+ }
+
+ // Close the underlying writer (if necessary).
+ if closer, ok := w.out.(io.Closer); ok {
+ return closer.Close()
+ }
+ return nil
+}
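
An end-to-end editorial sketch of the Writer/Reader pair (the key and input are arbitrary; flate.BestSpeed and the ~1M chunk size follow the recommendations above and the tests below):

package main

import (
	"bytes"
	"compress/flate"
	"fmt"
	"io"

	"gvisor.dev/gvisor/pkg/compressio"
)

func main() {
	key := []byte("01234567890123456789012345678901") // Arbitrary 32-byte HMAC key.
	var stream bytes.Buffer

	// Compress with ~1M chunks, per the NewWriter comment.
	w, err := compressio.NewWriter(&stream, key, 1024*1024, flate.BestSpeed)
	if err != nil {
		panic(err)
	}
	if _, err := w.Write([]byte("hello world")); err != nil {
		panic(err)
	}
	if err := w.Close(); err != nil {
		panic(err)
	}

	// Decompress; the same key is required to verify the chunk hashes.
	r, err := compressio.NewReader(&stream, key)
	if err != nil {
		panic(err)
	}
	var out bytes.Buffer
	if _, err := io.Copy(&out, r); err != nil {
		panic(err)
	}
	fmt.Println(out.String()) // "hello world"
}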
diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go
new file mode 100644
index 000000000..86dc47e44
--- /dev/null
+++ b/pkg/compressio/compressio_test.go
@@ -0,0 +1,290 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compressio
+
+import (
+ "bytes"
+ "compress/flate"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "math/rand"
+ "runtime"
+ "testing"
+ "time"
+)
+
+type harness interface {
+ Errorf(format string, v ...interface{})
+ Fatalf(format string, v ...interface{})
+ Logf(format string, v ...interface{})
+}
+
+func initTest(t harness, size int) []byte {
+ // Set GOMAXPROCS to the number of CPUs.
+ runtime.GOMAXPROCS(runtime.NumCPU())
+
+ // Construct synthetic data. We do this by encoding random data with
+ // base64. This gives a high level of entropy, but still quite a bit of
+ // structure, to give reasonable compression ratios (~75%).
+ var buf bytes.Buffer
+ bufW := base64.NewEncoder(base64.RawStdEncoding, &buf)
+ bufR := rand.New(rand.NewSource(0))
+ if _, err := io.CopyN(bufW, bufR, int64(size)); err != nil {
+ t.Fatalf("unable to seed random data: %v", err)
+ }
+ return buf.Bytes()
+}
+
+type testOpts struct {
+ Name string
+ Data []byte
+ NewWriter func(*bytes.Buffer) (io.Writer, error)
+ NewReader func(*bytes.Buffer) (io.Reader, error)
+ PreCompress func()
+ PostCompress func()
+ PreDecompress func()
+ PostDecompress func()
+ CompressIters int
+ DecompressIters int
+ CorruptData bool
+}
+
+func doTest(t harness, opts testOpts) {
+ // Compress.
+ var compressed bytes.Buffer
+ compressionStartTime := time.Now()
+ if opts.PreCompress != nil {
+ opts.PreCompress()
+ }
+ if opts.CompressIters <= 0 {
+ opts.CompressIters = 1
+ }
+ for i := 0; i < opts.CompressIters; i++ {
+ compressed.Reset()
+ w, err := opts.NewWriter(&compressed)
+ if err != nil {
+ t.Errorf("%s: NewWriter got err %v, expected nil", opts.Name, err)
+ }
+ if _, err := io.Copy(w, bytes.NewBuffer(opts.Data)); err != nil {
+ t.Errorf("%s: compress got err %v, expected nil", opts.Name, err)
+ return
+ }
+ closer, ok := w.(io.Closer)
+ if ok {
+ if err := closer.Close(); err != nil {
+ t.Errorf("%s: got err %v, expected nil", opts.Name, err)
+ return
+ }
+ }
+ }
+ if opts.PostCompress != nil {
+ opts.PostCompress()
+ }
+ compressionTime := time.Since(compressionStartTime)
+ compressionRatio := float32(compressed.Len()) / float32(len(opts.Data))
+
+ // Decompress.
+ var decompressed bytes.Buffer
+ decompressionStartTime := time.Now()
+ if opts.PreDecompress != nil {
+ opts.PreDecompress()
+ }
+ if opts.DecompressIters <= 0 {
+ opts.DecompressIters = 1
+ }
+ if opts.CorruptData {
+ b := compressed.Bytes()
+ b[rand.Intn(len(b))]++
+ }
+ for i := 0; i < opts.DecompressIters; i++ {
+ decompressed.Reset()
+ r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes()))
+ if err != nil {
+ if opts.CorruptData {
+ continue
+ }
+ t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err)
+ return
+ }
+ if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData {
+ t.Errorf("%s: decompress got err %v unexpectly", opts.Name, err)
+ return
+ }
+ }
+ if opts.PostDecompress != nil {
+ opts.PostDecompress()
+ }
+ decompressionTime := time.Since(decompressionStartTime)
+
+ if opts.CorruptData {
+ return
+ }
+
+ // Verify.
+ if decompressed.Len() != len(opts.Data) {
+ t.Errorf("%s: got %d bytes, expected %d", opts.Name, decompressed.Len(), len(opts.Data))
+ }
+ if !bytes.Equal(opts.Data, decompressed.Bytes()) {
+ t.Errorf("%s: got mismatch, expected match", opts.Name)
+ if len(opts.Data) < 32 { // Don't flood the logs.
+ t.Errorf("got %v, expected %v", decompressed.Bytes(), opts.Data)
+ }
+ }
+
+ t.Logf("%s: compression time %v, ratio %2.2f, decompression time %v",
+ opts.Name, compressionTime, compressionRatio, decompressionTime)
+}
+
+var hashKey = []byte("01234567890123456789012345678901")
+
+func TestCompress(t *testing.T) {
+ rand.Seed(time.Now().Unix())
+
+ var (
+ data = initTest(t, 10*1024*1024)
+ data0 = data[:0]
+ data1 = data[:1]
+ data2 = data[:11]
+ data3 = data[:16]
+ data4 = data[:]
+ )
+
+ for _, data := range [][]byte{data0, data1, data2, data3, data4} {
+ for _, blockSize := range []uint32{1, 4, 1024, 4 * 1024, 16 * 1024} {
+ // Skip annoying tests; they just take too long.
+ if blockSize <= 16 && len(data) > 16 {
+ continue
+ }
+
+ for _, key := range [][]byte{nil, hashKey} {
+ for _, corruptData := range []bool{false, true} {
+ if key == nil && corruptData {
+ // No need to test corrupt data
+ // case when not doing hashing.
+ continue
+ }
+ // Do the compress test.
+ doTest(t, testOpts{
+ Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData),
+ Data: data,
+ NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
+ return NewWriter(b, key, blockSize, flate.BestSpeed)
+ },
+ NewReader: func(b *bytes.Buffer) (io.Reader, error) {
+ return NewReader(b, key)
+ },
+ CorruptData: corruptData,
+ })
+ }
+ }
+ }
+
+ // Do the vanilla test.
+ doTest(t, testOpts{
+ Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)),
+ Data: data,
+ NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
+ return flate.NewWriter(b, flate.BestSpeed)
+ },
+ NewReader: func(b *bytes.Buffer) (io.Reader, error) {
+ return flate.NewReader(b), nil
+ },
+ })
+ }
+}
+
+const (
+ benchDataSize = 600 * 1024 * 1024
+)
+
+func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) {
+ b.StopTimer()
+ b.SetBytes(benchDataSize)
+ data := initTest(b, benchDataSize)
+ compIters := b.N
+ decompIters := b.N
+ if compress {
+ decompIters = 0
+ } else {
+ compIters = 0
+ }
+ key := hashKey
+ if !hash {
+ key = nil
+ }
+ doTest(b, testOpts{
+ Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize),
+ Data: data,
+ PreCompress: b.StartTimer,
+ PostCompress: b.StopTimer,
+ NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
+ return NewWriter(b, key, blockSize, flate.BestSpeed)
+ },
+ NewReader: func(b *bytes.Buffer) (io.Reader, error) {
+ return NewReader(b, key)
+ },
+ CompressIters: compIters,
+ DecompressIters: decompIters,
+ })
+}
+
+func BenchmarkCompressNoHash64K(b *testing.B) {
+ benchmark(b, true, false, 64*1024)
+}
+
+func BenchmarkCompressHash64K(b *testing.B) {
+ benchmark(b, true, true, 64*1024)
+}
+
+func BenchmarkDecompressNoHash64K(b *testing.B) {
+ benchmark(b, false, false, 64*1024)
+}
+
+func BenchmarkDecompressHash64K(b *testing.B) {
+ benchmark(b, false, true, 64*1024)
+}
+
+func BenchmarkCompressNoHash1M(b *testing.B) {
+ benchmark(b, true, false, 1024*1024)
+}
+
+func BenchmarkCompressHash1M(b *testing.B) {
+ benchmark(b, true, true, 1024*1024)
+}
+
+func BenchmarkDecompressNoHash1M(b *testing.B) {
+ benchmark(b, false, false, 1024*1024)
+}
+
+func BenchmarkDecompressHash1M(b *testing.B) {
+ benchmark(b, false, true, 1024*1024)
+}
+
+func BenchmarkCompressNoHash16M(b *testing.B) {
+ benchmark(b, true, false, 16*1024*1024)
+}
+
+func BenchmarkCompressHash16M(b *testing.B) {
+ benchmark(b, true, true, 16*1024*1024)
+}
+
+func BenchmarkDecompressNoHash16M(b *testing.B) {
+ benchmark(b, false, false, 16*1024*1024)
+}
+
+func BenchmarkDecompressHash16M(b *testing.B) {
+ benchmark(b, false, true, 16*1024*1024)
+}
diff --git a/pkg/context/BUILD b/pkg/context/BUILD
new file mode 100644
index 000000000..239f31149
--- /dev/null
+++ b/pkg/context/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "context",
+ srcs = ["context.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/amutex",
+ "//pkg/log",
+ ],
+)
diff --git a/pkg/context/context.go b/pkg/context/context.go
new file mode 100644
index 000000000..5319b6d8d
--- /dev/null
+++ b/pkg/context/context.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package context defines an internal context type.
+//
+// The given Context conforms to the standard Go context, but mandates
+// additional methods that are specific to the kernel internals. Note, however,
+// that the Context described by this package carries additional constraints
+// regarding concurrent access and retaining beyond the scope of a call.
+//
+// See the Context type for complete details.
+package context
+
+import (
+ "context"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+type contextID int
+
+// Globally accessible values from a context. These keys are defined in the
+// context package to resolve dependency cycles by not requiring the caller to
+// import the packages usually required to get this information.
+const (
+ // CtxThreadGroupID is the current thread group ID when a context represents
+ // a task context. The value is represented as an int32.
+ CtxThreadGroupID contextID = iota
+)
+
+// ThreadGroupIDFromContext returns the current thread group ID when ctx
+// represents a task context.
+func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
+ if tgid := ctx.Value(CtxThreadGroupID); tgid != nil {
+ return tgid.(int32), true
+ }
+ return 0, false
+}
+
+// A Context represents a thread of execution (hereafter "goroutine" to reflect
+// Go idiosyncrasy). It carries state associated with the goroutine across API
+// boundaries.
+//
+// While Context exists for essentially the same reasons as Go's standard
+// context.Context, the standard type represents the state of an operation
+// rather than that of a goroutine. This is a critical distinction:
+//
+// - Unlike context.Context, which "may be passed to functions running in
+// different goroutines", it is *not safe* to use the same Context in multiple
+// concurrent goroutines.
+//
+// - It is *not safe* to retain a Context passed to a function beyond the scope
+// of that function call.
+//
+// In both cases, values extracted from the Context should be used instead.
+type Context interface {
+ log.Logger
+ amutex.Sleeper
+ context.Context
+
+ // UninterruptibleSleepStart indicates the beginning of an uninterruptible
+ // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
+ // is true and the Context represents a Task, the Task's AddressSpace is
+ // deactivated.
+ UninterruptibleSleepStart(deactivate bool)
+
+ // UninterruptibleSleepFinish indicates the end of an uninterruptible sleep
+ // state that was begun by a previous call to UninterruptibleSleepStart. If
+ // activate is true and the Context represents a Task, the Task's
+ // AddressSpace is activated. Normally activate is the same value as the
+ // deactivate parameter passed to UninterruptibleSleepStart.
+ UninterruptibleSleepFinish(activate bool)
+}
+
+// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep
+// methods for anonymous embedding in other types that do not implement sleeps.
+type NoopSleeper struct {
+ amutex.NoopSleeper
+}
+
+// UninterruptibleSleepStart does nothing.
+func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+
+// UninterruptibleSleepFinish does nothing.
+func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+
+// Deadline returns zero values, meaning no deadline.
+func (NoopSleeper) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done returns nil.
+func (NoopSleeper) Done() <-chan struct{} {
+ return nil
+}
+
+// Err returns nil.
+func (NoopSleeper) Err() error {
+ return nil
+}
+
+// logContext implements basic logging.
+type logContext struct {
+ log.Logger
+ NoopSleeper
+}
+
+// Value implements Context.Value.
+func (logContext) Value(key interface{}) interface{} {
+ return nil
+}
+
+// bgContext is the context returned by context.Background.
+var bgContext = &logContext{Logger: log.Log()}
+
+// Background returns an empty context using the default logger.
+// Generally, one should use the Task as the context when available, or avoid
+// having to use a context in places where a Task is unavailable.
+//
+// Using a Background context for tests is fine, as long as no values are
+// needed from the context in the tested code paths.
+func Background() Context {
+ return bgContext
+}
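
A brief editorial sketch of the intended use; Background is appropriate here exactly because no values are consulted:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/context"
)

func main() {
	ctx := context.Background()

	// Background carries no values, so this takes the "not ok" branch.
	if tgid, ok := context.ThreadGroupIDFromContext(ctx); ok {
		fmt.Println("tgid:", tgid)
	} else {
		fmt.Println("not a task context")
	}
}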
diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD
new file mode 100644
index 000000000..1b9e10ee7
--- /dev/null
+++ b/pkg/control/client/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "client",
+ srcs = [
+ "client.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/unet",
+ "//pkg/urpc",
+ ],
+)
diff --git a/pkg/control/client/client.go b/pkg/control/client/client.go
new file mode 100644
index 000000000..41807cd45
--- /dev/null
+++ b/pkg/control/client/client.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package client provides a basic control client interface.
+package client
+
+import (
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// ConnectTo attempts to connect to the sandbox with the given address.
+func ConnectTo(addr string) (*urpc.Client, error) {
+ // Connect to the server.
+ conn, err := unet.Connect(addr, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // Wrap in our stream codec.
+ return urpc.NewClient(conn), nil
+}
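
Client usage is a dial plus the urpc call itself. An editorial sketch (the socket name and method are hypothetical, and it assumes urpc.Client's net/rpc-style Call(method, args, reply)):

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/control/client"
)

func main() {
	// A leading NUL denotes an abstract unix socket; the name is made up.
	conn, err := client.ConnectTo("\x00example-controls")
	if err != nil {
		log.Fatalf("connect: %v", err)
	}
	defer conn.Close()

	// Invoke a method on an object registered server-side; "Example.Ping"
	// is a placeholder (see the server package below).
	if err := conn.Call("Example.Ping", nil, nil); err != nil {
		log.Fatalf("call: %v", err)
	}
}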
diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD
new file mode 100644
index 000000000..002d2ef44
--- /dev/null
+++ b/pkg/control/server/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "server",
+ srcs = ["server.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/unet",
+ "//pkg/urpc",
+ ],
+)
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
new file mode 100644
index 000000000..41abe1f2d
--- /dev/null
+++ b/pkg/control/server/server.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Package server provides a basic control server interface.
+
+Note that no objects are registered by default. Users must provide their own
+implementations of the control interface.
+*/
+package server
+
+import (
+ "os"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// curUID is the unix user ID of the user that the control server is running as.
+var curUID = os.Getuid()
+
+// Server is a basic control server.
+type Server struct {
+ // socket is our bound socket.
+ socket *unet.ServerSocket
+
+ // server is our rpc server.
+ server *urpc.Server
+
+ // wg waits for the accept loop to terminate.
+ wg sync.WaitGroup
+}
+
+// New returns a new bound control server.
+func New(socket *unet.ServerSocket) *Server {
+ return &Server{
+ socket: socket,
+ server: urpc.NewServer(),
+ }
+}
+
+// FD returns the file descriptor that the server is running on.
+func (s *Server) FD() int {
+ return s.socket.FD()
+}
+
+// Wait waits for the main server goroutine to exit. This should be
+// called after a call to Serve.
+func (s *Server) Wait() {
+ s.wg.Wait()
+}
+
+// Stop stops the server. Note that this function should only be called once
+// and the server should not be used afterwards.
+func (s *Server) Stop() {
+ s.socket.Close()
+ s.wg.Wait()
+
+ // This will cause existing clients to be terminated safely.
+ s.server.Stop()
+}
+
+// StartServing starts listening for connections and spawns the main service
+// goroutine for handling incoming control requests. StartServing does not
+// block; to wait for the control server to exit, call Wait.
+func (s *Server) StartServing() error {
+ // Actually start listening.
+ if err := s.socket.Listen(); err != nil {
+ return err
+ }
+
+ s.wg.Add(1)
+ go func() { // S/R-SAFE: does not impact state directly.
+ s.serve()
+ s.wg.Done()
+ }()
+
+ return nil
+}
+
+// serve is the body of the main service goroutine. It handles incoming control
+// connections and dispatches requests to registered objects.
+func (s *Server) serve() {
+ for {
+ // Accept clients.
+ conn, err := s.socket.Accept()
+ if err != nil {
+ return
+ }
+
+ ucred, err := conn.GetPeerCred()
+ if err != nil {
+ log.Warningf("Control couldn't get credentials: %s", err.Error())
+ conn.Close()
+ continue
+ }
+
+ // Only allow this user and root.
+ if int(ucred.Uid) != curUID && ucred.Uid != 0 {
+ // Authentication failed.
+ log.Warningf("Control auth failure: other UID = %d, current UID = %d", ucred.Uid, curUID)
+ conn.Close()
+ continue
+ }
+
+ // Handle the connection non-blockingly.
+ s.server.StartHandling(conn)
+ }
+}
+
+// Register registers a specific control interface with the server.
+func (s *Server) Register(obj interface{}) {
+ s.server.Register(obj)
+}
+
+// CreateFromFD creates a new control bound to the given 'fd'. It has no
+// registered interfaces and will not start serving until StartServing is
+// called.
+func CreateFromFD(fd int) (*Server, error) {
+ socket, err := unet.NewServerSocket(fd)
+ if err != nil {
+ return nil, err
+ }
+ return New(socket), nil
+}
+
+// Create creates a new control server with an abstract unix socket
+// with the given address, which must be unique and a valid
+// abstract socket name.
+func Create(addr string) (*Server, error) {
+ socket, err := unet.Bind(addr, false)
+ if err != nil {
+ return nil, err
+ }
+ return New(socket), nil
+}
+
+// CreateSocket creates a socket that can be used with the control server,
+// but doesn't start the control server. 'addr' must be a valid and unique
+// abstract socket name. Returns the socket's FD, or -1 in case of error.
+func CreateSocket(addr string) (int, error) {
+ socket, err := unet.Bind(addr, false)
+ if err != nil {
+ return -1, err
+ }
+ return socket.Release()
+}
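
The matching server side, as an editorial sketch (the Example object and socket name are hypothetical; it assumes urpc dispatches exported methods of registered objects using a func(args, reply) error convention):

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/control/server"
)

// Example is a hypothetical control object; its exported methods become
// callable as "Example.<Method>".
type Example struct{}

// Ping is a no-op control method.
func (*Example) Ping(args *struct{}, reply *struct{}) error { return nil }

func main() {
	s, err := server.Create("\x00example-controls") // Made-up abstract name.
	if err != nil {
		log.Fatalf("create: %v", err)
	}
	s.Register(&Example{})
	if err := s.StartServing(); err != nil {
		log.Fatalf("serve: %v", err)
	}
	s.Wait() // Block until the server stops.
}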
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
new file mode 100644
index 000000000..d6cb1a549
--- /dev/null
+++ b/pkg/cpuid/BUILD
@@ -0,0 +1,35 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cpuid",
+ srcs = [
+ "cpu_amd64.s",
+ "cpuid.go",
+ "cpuid_arm64.go",
+ "cpuid_x86.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/log"],
+)
+
+go_test(
+ name = "cpuid_test",
+ size = "small",
+ srcs = [
+ "cpuid_arm64_test.go",
+ "cpuid_x86_test.go",
+ ],
+ library = ":cpuid",
+)
+
+go_test(
+ name = "cpuid_parse_test",
+ size = "small",
+ srcs = [
+ "cpuid_parse_x86_test.go",
+ ],
+ library = ":cpuid",
+ tags = ["manual"],
+)
diff --git a/pkg/cpuid/cpu_amd64.s b/pkg/cpuid/cpu_amd64.s
new file mode 100644
index 000000000..ac80d3c8a
--- /dev/null
+++ b/pkg/cpuid/cpu_amd64.s
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// func HostID(rax, rcx uint32) (ret0, ret1, ret2, ret3 uint32)
+TEXT ·HostID(SB),$0-48
+ MOVL ax+0(FP), AX
+ MOVL cx+4(FP), CX
+ CPUID
+ MOVL AX, ret0+8(FP)
+ MOVL BX, ret1+12(FP)
+ MOVL CX, ret2+16(FP)
+ MOVL DX, ret3+20(FP)
+ RET
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
new file mode 100644
index 000000000..f7f9dbf86
--- /dev/null
+++ b/pkg/cpuid/cpuid.go
@@ -0,0 +1,38 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cpuid provides basic functionality for creating and adjusting CPU
+// feature sets.
+//
+// To use FeatureSets, one should start with an existing FeatureSet (either a
+// known platform, or HostFeatureSet()) and then add, remove, and test for
+// features as desired.
+//
+// For example: on x86, test for hardware extended state saving, and if
+// we don't have it, don't expose AVX, which cannot be saved with fxsave.
+//
+// if !HostFeatureSet().HasFeature(X86FeatureXSAVE) {
+// exposedFeatures.Remove(X86FeatureAVX)
+// }
+package cpuid
+
+// Feature is a unique identifier for a particular cpu feature. We just use an
+// int as a feature number on x86 and arm64.
+//
+// On x86, features are numbered according to "blocks". Each block is 32 bits, and
+// feature bits from the same source (cpuid leaf/level) are in the same block.
+//
+// On arm64, features are numbered according to the ELF HWCAP definition.
+// arch/arm64/include/uapi/asm/hwcap.h
+type Feature int
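
A small editorial sketch of the queries the package comment describes (HostFeatureSet and HasFeature appear in the example above; FeatureFromString comes from the per-arch files below):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/cpuid"
)

func main() {
	fs := cpuid.HostFeatureSet()

	// Resolve a /proc/cpuinfo-style flag name; "aes" is one of the arm64
	// names listed in cpuid_arm64.go.
	if f, ok := cpuid.FeatureFromString("aes"); ok {
		fmt.Println("host has aes:", fs.HasFeature(f))
	}
}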
diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go
new file mode 100644
index 000000000..ac7bb6774
--- /dev/null
+++ b/pkg/cpuid/cpuid_arm64.go
@@ -0,0 +1,483 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package cpuid
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// ARM64 doesn't have a 'cpuid' equivalent, which means it has no architected
+// discovery mechanism for hardware features available to userspace code at EL0.
+// The kernel exposes the presence of these features to userspace through a set
+// of flag (HWCAP/HWCAP2) bits, exposed in the auxiliary vector.
+// Ref Documentation/arm64/elf_hwcaps.rst for more info.
+//
+// Currently, only the HWCAP bits are supported.
+
+const (
+ // ARM64FeatureFP indicates support for single and double precision
+ // floating point types.
+ ARM64FeatureFP Feature = iota
+
+ // ARM64FeatureASIMD indicates support for Advanced SIMD with single
+ // and double precision floating point arithmetic.
+ ARM64FeatureASIMD
+
+ // ARM64FeatureEVTSTRM indicates support for the generic timer
+ // configured to generate events at a frequency of approximately
+ // 100KHz.
+ ARM64FeatureEVTSTRM
+
+ // ARM64FeatureAES indicates support for AES instructions
+ // (AESE/AESD/AESMC/AESIMC).
+ ARM64FeatureAES
+
+ // ARM64FeaturePMULL indicates support for polynomial multiply long
+ // instructions (PMULL/PMULL2).
+ ARM64FeaturePMULL
+
+ // ARM64FeatureSHA1 indicates support for SHA1 instructions
+ // (SHA1C/SHA1P/SHA1M etc).
+ ARM64FeatureSHA1
+
+ // ARM64FeatureSHA2 indicates support for SHA2 instructions
+ // (SHA256H/SHA256H2/SHA256SU0 etc).
+ ARM64FeatureSHA2
+
+ // ARM64FeatureCRC32 indicates support for CRC32 instructions
+ // (CRC32B/CRC32H/CRC32W etc).
+ ARM64FeatureCRC32
+
+ // ARM64FeatureATOMICS indicates support for atomic instructions
+ // (LDADD/LDCLR/LDEOR/LDSET etc).
+ ARM64FeatureATOMICS
+
+ // ARM64FeatureFPHP indicates support for half precision floating point
+ // arithmetic.
+ ARM64FeatureFPHP
+
+ // ARM64FeatureASIMDHP indicates support for ASIMD with half precision
+ // floating point arithmetic.
+ ARM64FeatureASIMDHP
+
+ // ARM64FeatureCPUID indicates that EL0 access to certain ID registers
+ // is available.
+ ARM64FeatureCPUID
+
+ // ARM64FeatureASIMDRDM indicates support for SQRDMLAH and SQRDMLSH
+ // instructions.
+ ARM64FeatureASIMDRDM
+
+ // ARM64FeatureJSCVT indicates support for the FJCVTZS instruction.
+ ARM64FeatureJSCVT
+
+ // ARM64FeatureFCMA indicates support for the FCMLA and FCADD
+ // instructions.
+ ARM64FeatureFCMA
+
+ // ARM64FeatureLRCPC indicates support for the LDAPRB/LDAPRH/LDAPR
+ // instructions.
+ ARM64FeatureLRCPC
+
+ // ARM64FeatureDCPOP indicates support for the DC CVAP instruction.
+ ARM64FeatureDCPOP
+
+ // ARM64FeatureSHA3 indicates support for SHA3 instructions
+ // (EOR3/RAX1/XAR/BCAX).
+ ARM64FeatureSHA3
+
+ // ARM64FeatureSM3 indicates support for SM3 instructions
+ // (SM3SS1/SM3TT1A/SM3TT1B).
+ ARM64FeatureSM3
+
+ // ARM64FeatureSM4 indicates support for SM4 instructions
+ // (SM4E/SM4EKEY).
+ ARM64FeatureSM4
+
+ // ARM64FeatureASIMDDP indicates support for dot product instructions
+ // (UDOT/SDOT).
+ ARM64FeatureASIMDDP
+
+ // ARM64FeatureSHA512 indicates support for SHA512 instructions
+ // (SHA512H/SHA512H2/SHA512SU0).
+ ARM64FeatureSHA512
+
+ // ARM64FeatureSVE indicates support for the Scalable Vector Extension.
+ ARM64FeatureSVE
+
+ // ARM64FeatureASIMDFHM indicates support for FMLAL and FMLSL
+ // instructions.
+ ARM64FeatureASIMDFHM
+)
+
+// ELF auxiliary vector tags
+const (
+ _AT_NULL = 0 // End of vector
+ _AT_HWCAP = 16 // hardware capability bit vector
+ _AT_HWCAP2 = 26 // hardware capability bit vector 2
+)
+
+// hwCap should not be changed after it is initialized.
+var hwCap uint
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/arm64/kernel/cpuinfo.c.
+var arm64FeatureStrings = map[Feature]string{
+ ARM64FeatureFP: "fp",
+ ARM64FeatureASIMD: "asimd",
+ ARM64FeatureEVTSTRM: "evtstrm",
+ ARM64FeatureAES: "aes",
+ ARM64FeaturePMULL: "pmull",
+ ARM64FeatureSHA1: "sha1",
+ ARM64FeatureSHA2: "sha2",
+ ARM64FeatureCRC32: "crc32",
+ ARM64FeatureATOMICS: "atomics",
+ ARM64FeatureFPHP: "fphp",
+ ARM64FeatureASIMDHP: "asimdhp",
+ ARM64FeatureCPUID: "cpuid",
+ ARM64FeatureASIMDRDM: "asimdrdm",
+ ARM64FeatureJSCVT: "jscvt",
+ ARM64FeatureFCMA: "fcma",
+ ARM64FeatureLRCPC: "lrcpc",
+ ARM64FeatureDCPOP: "dcpop",
+ ARM64FeatureSHA3: "sha3",
+ ARM64FeatureSM3: "sm3",
+ ARM64FeatureSM4: "sm4",
+ ARM64FeatureASIMDDP: "asimddp",
+ ARM64FeatureSHA512: "sha512",
+ ARM64FeatureSVE: "sve",
+ ARM64FeatureASIMDFHM: "asimdfhm",
+}
+
+var (
+ cpuFreqMHz float64
+ cpuImplHex uint64
+ cpuArchDec uint64
+ cpuVarHex uint64
+ cpuPartHex uint64
+ cpuRevDec uint64
+)
+
+// arm64FeaturesFromString includes features from arm64FeatureStrings.
+var arm64FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string, and a bool indicating whether the feature was found.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := arm64FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(); s != "" {
+ return s
+ }
+
+ return fmt.Sprintf("<cpuflag %d>", f)
+}
+
+func (f Feature) flagString() string {
+ if s, ok := arm64FeatureStrings[f]; ok {
+ return s
+ }
+
+ return ""
+}
+
+// FeatureSet is a set of Features for a CPU.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // CPUImplementer is part of the processor signature.
+ CPUImplementer uint8
+
+ // CPUArchitecture is part of the processor signature.
+ CPUArchitecture uint8
+
+ // CPUVariant is part of the processor signature.
+ CPUVariant uint8
+
+ // CPUPartnum is part of the processor signature.
+ CPUPartnum uint16
+
+ // CPURevision is part of the processor signature.
+ CPURevision uint8
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
+// It is a no-op on arm64.
+func (fs *FeatureSet) CheckHostCompatible() error {
+ return nil
+}
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point (NEON) registers, and other CPU state that's not
+// associated with the normal task context.
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ // ARMv8 provides 32 128-bit NEON registers.
+ //
+ // Ref arch/arm64/include/uapi/asm/ptrace.h
+ // struct user_fpsimd_state {
+ // __uint128_t vregs[32];
+ // __u32 fpsr;
+ // __u32 fpcr;
+ // __u32 __reserved[2];
+ // };
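+ //
+ // Total: 32*16 (vregs) + 4 (fpsr) + 4 (fpcr) + 2*4 (reserved) = 528
+ // bytes, aligned to a 16-byte boundary.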
+ return 528, 16
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// UseXsave returns true if 'fs' supports the "xsave" instruction.
+//
+// Irrelevant on arm64.
+func (fs *FeatureSet) UseXsave() bool {
+ return false
+}
+
+// FlagsString prints out the supported CPU flags as they appear in the
+// "Features" field of /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString() string {
+ var s []string
+ for f := range arm64FeatureStrings {
+ if fs.Set[f] {
+ if fstr := f.flagString(); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// WriteCPUInfoTo generates the /proc/cpuinfo section for one CPU. This is a
+// minimal /proc/cpuinfo, and the BogoMIPS field is simply made up.
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "BogoMIPS\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "Features\t\t: %s\n", fs.FlagsString())
+ fmt.Fprintf(b, "CPU implementer\t: 0x%x\n", cpuImplHex)
+ fmt.Fprintf(b, "CPU architecture\t: %d\n", cpuArchDec)
+ fmt.Fprintf(b, "CPU variant\t: 0x%x\n", cpuVarHex)
+ fmt.Fprintf(b, "CPU part\t: 0x%x\n", cpuPartHex)
+ fmt.Fprintf(b, "CPU revision\t: %d\n", cpuRevDec)
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
+}
+
+// HostFeatureSet uses hwCap to get host values and construct a feature set
+// that matches that of the host machine.
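+//
+// For example, probing a single host capability:
+//
+// hasASIMD := HostFeatureSet().HasFeature(ARM64FeatureASIMD)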
+func HostFeatureSet() *FeatureSet {
+ s := make(map[Feature]bool)
+
+ for f := range arm64FeatureStrings {
+ if hwCap&(1<<f) != 0 {
+ s[f] = true
+ }
+ }
+
+ return &FeatureSet{
+ Set: s,
+ CPUImplementer: uint8(cpuImplHex),
+ CPUArchitecture: uint8(cpuArchDec),
+ CPUVariant: uint8(cpuVarHex),
+ CPUPartnum: uint16(cpuPartHex),
+ CPURevision: uint8(cpuRevDec),
+ }
+}
+
+// initCPUInfo reads the BogoMIPS value and the processor signature fields from
+// the host /proc/cpuinfo. It must run before syscall filter installation.
+// These values are used to create the fake /proc/cpuinfo from a FeatureSet.
+func initCPUInfo() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0. The standalone VDSO bails out in the same way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ switch {
+ case strings.Contains(line, "BogoMIPS"):
+ {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed BogoMIPS")
+ break
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse BogoMIPS value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ }
+ }
+ case strings.Contains(line, "CPU implementer"):
+ {
+ splitImpl := strings.Split(line, ":")
+ if len(splitImpl) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU implementer")
+ break
+ }
+
+ // If there was a problem, leave cpuImplHex as 0.
+ var err error
+ cpuImplHex, err = strconv.ParseUint(strings.TrimSpace(splitImpl[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU implementer value %v: %v", splitImpl[1], err)
+ cpuImplHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU architecture"):
+ {
+ splitArch := strings.Split(line, ":")
+ if len(splitArch) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU architecture")
+ break
+ }
+
+ // If there was a problem, leave cpuArchDec as 0.
+ var err error
+ cpuArchDec, err = strconv.ParseUint(strings.TrimSpace(splitArch[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU architecture value %v: %v", splitArch[1], err)
+ cpuArchDec = 0
+ }
+ }
+ case strings.Contains(line, "CPU variant"):
+ {
+ splitVar := strings.Split(line, ":")
+ if len(splitVar) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU variant")
+ break
+ }
+
+ // If there was a problem, leave cpuVarHex as 0.
+ var err error
+ cpuVarHex, err = strconv.ParseUint(strings.TrimSpace(splitVar[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU variant value %v: %v", splitVar[1], err)
+ cpuVarHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU part"):
+ {
+ splitPart := strings.Split(line, ":")
+ if len(splitPart) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU part")
+ break
+ }
+
+ // If there was a problem, leave cpuPartHex as 0.
+ var err error
+ cpuPartHex, err = strconv.ParseUint(strings.TrimSpace(splitPart[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU part value %v: %v", splitPart[1], err)
+ cpuPartHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU revision"):
+ {
+ splitRev := strings.Split(line, ":")
+ if len(splitRev) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU revision")
+ break
+ }
+
+ // If there was a problem, leave cpuRevDec as 0.
+ var err error
+ cpuRevDec, err = strconv.ParseUint(strings.TrimSpace(splitRev[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU revision value %v: %v", splitRev[1], err)
+ cpuRevDec = 0
+ }
+ }
+ }
+ }
+}
+
+// The auxiliary vector of a process on the Linux system can be read from
+// /proc/self/auxv. Tags and values are stored as pairs of 8-byte values on
+// 64-bit systems, shown in decimal by od below.
+//
+// $ od -t d8 /proc/self/auxv
+// 0000000 33 140734615224320
+// 0000020 16 3219913727
+// 0000040 6 4096
+// 0000060 17 100
+// 0000100 3 94665627353152
+// 0000120 4 56
+// 0000140 5 9
+// 0000160 7 140425502162944
+// 0000200 8 0
+// 0000220 9 94665627365760
+// 0000240 11 1000
+// 0000260 12 1000
+// 0000300 13 1000
+// 0000320 14 1000
+// 0000340 23 0
+// 0000360 25 140734614619513
+// 0000400 26 0
+// 0000420 31 140734614626284
+// 0000440 15 140734614619529
+// 0000460 0 0
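+//
+// In this sample, tag 16 (_AT_HWCAP) carries the HWCAP bit vector
+// (3219913727 == 0xbfebfbff).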
+func initHwCap() {
+ auxv, err := ioutil.ReadFile("/proc/self/auxv")
+ if err != nil {
+ log.Warningf("Could not read /proc/self/auxv: %v", err)
+ return
+ }
+
+ l := len(auxv) / 16
+ for i := 0; i < l; i++ {
+ tag := binary.LittleEndian.Uint64(auxv[i*16:])
+ val := binary.LittleEndian.Uint64(auxv[(i*16 + 8):])
+ if tag == _AT_HWCAP {
+ hwCap = uint(val)
+ break
+ }
+ }
+}
+
+func initFeaturesFromString() {
+ for f, s := range arm64FeatureStrings {
+ arm64FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ initCPUInfo()
+ initHwCap()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_arm64_test.go b/pkg/cpuid/cpuid_arm64_test.go
new file mode 100644
index 000000000..a34f67779
--- /dev/null
+++ b/pkg/cpuid/cpuid_arm64_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package cpuid
+
+import (
+ "testing"
+)
+
+var justFP = &FeatureSet{
+ Set: map[Feature]bool{
+ ARM64FeatureFP: true,
+ }}
+
+func TestHostFeatureSet(t *testing.T) {
+ hostFeatures := HostFeatureSet()
+ if len(hostFeatures.Set) == 0 {
+ t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
+ }
+}
+
+func TestHasFeature(t *testing.T) {
+ if !justFP.HasFeature(ARM64FeatureFP) {
+ t.Errorf("HasFeature failed, %v should contain %v", justFP, ARM64FeatureFP)
+ }
+
+ if justFP.HasFeature(ARM64FeatureSM3) {
+ t.Errorf("HasFeature failed, %v should not contain %v", justFP, ARM64FeatureSM3)
+ }
+}
+
+func TestFeatureFromString(t *testing.T) {
+ f, ok := FeatureFromString("asimd")
+ if f != ARM64FeatureASIMD || !ok {
+ t.Errorf("got %v want asimd", f)
+ }
+
+ f, ok = FeatureFromString("bad")
+ if ok {
+ t.Errorf("got %v want nothing", f)
+ }
+}
diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go
new file mode 100644
index 000000000..c9bd40e1b
--- /dev/null
+++ b/pkg/cpuid/cpuid_parse_x86_test.go
@@ -0,0 +1,144 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package cpuid
+
+import (
+ "fmt"
+ "io/ioutil"
+ "regexp"
+ "strconv"
+ "strings"
+ "syscall"
+ "testing"
+)
+
+func kernelVersion() (int, int, error) {
+ var u syscall.Utsname
+ if err := syscall.Uname(&u); err != nil {
+ return 0, 0, err
+ }
+
+ var r string
+ for _, b := range u.Release {
+ if b == 0 {
+ break
+ }
+ r += string(b)
+ }
+
+ s := strings.Split(r, ".")
+ if len(s) < 2 {
+ return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", r)
+ }
+
+ major, err := strconv.Atoi(s[0])
+ if err != nil {
+ return 0, 0, fmt.Errorf("error parsing major version %q in %q: %v", s[0], r, err)
+ }
+
+ minor, err := strconv.Atoi(s[1])
+ if err != nil {
+ return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %v", s[1], r, err)
+ }
+
+ return major, minor, nil
+}
+
+// TestHostFeatureFlags tests that all features detected by HostFeatureSet are
+// on the host.
+//
+// It does *not* verify that all features reported by the host are detected by
+// HostFeatureSet.
+//
+// i.e., test that HostFeatureSet is a subset of the host features.
+func TestHostFeatureFlags(t *testing.T) {
+ cpuinfoBytes, _ := ioutil.ReadFile("/proc/cpuinfo")
+ cpuinfo := string(cpuinfoBytes)
+ t.Logf("Host cpu info:\n%s", cpuinfo)
+
+ major, minor, err := kernelVersion()
+ if err != nil {
+ t.Fatalf("Unable to parse kernel version: %v", err)
+ }
+
+ re := regexp.MustCompile(`(?m)^flags\s+: (.*)$`)
+ m := re.FindStringSubmatch(cpuinfo)
+ if len(m) != 2 {
+ t.Fatalf("Unable to extract flags from %q", cpuinfo)
+ }
+
+ cpuinfoFlags := make(map[string]struct{})
+ for _, f := range strings.Split(m[1], " ") {
+ cpuinfoFlags[f] = struct{}{}
+ }
+
+ fs := HostFeatureSet()
+
+ // All features have a string and appear in host cpuinfo.
+ for f := range fs.Set {
+ name := f.flagString(false)
+ if name == "" {
+ t.Errorf("Non-parsable feature: %v", f)
+ }
+
+ // Special cases that are not consistently visible. We don't mind if
+ // they are exposed on earlier kernel versions than those listed.
+ switch {
+ // Block 0.
+ case f == X86FeatureSDBG && (major < 4 || major == 4 && minor < 3):
+ // SDBG only exposed in
+ // b1c599b8ff80ea79b9f8277a3f9f36a7b0cfedce (4.3).
+ continue
+ // Block 2.
+ case f == X86FeatureRDT && (major < 4 || major == 4 && minor < 10):
+ // RDT only exposed in
+ // 4ab1586488cb56ed8728e54c4157cc38646874d9 (4.10).
+ continue
+ // Block 3.
+ case f == X86FeatureAVX512VBMI && (major < 4 || major == 4 && minor < 10):
+ // AVX512VBMI only exposed in
+ // a8d9df5a509a232a959e4ef2e281f7ecd77810d6 (4.10).
+ continue
+ case f == X86FeatureUMIP && (major < 4 || major == 4 && minor < 15):
+ // UMIP only exposed in
+ // 3522c2a6a4f341058b8291326a945e2a2d2aaf55 (4.15).
+ continue
+ case f == X86FeaturePKU && (major < 4 || major == 4 && minor < 9):
+ // PKU only exposed in
+ // dfb4a70f20c5b3880da56ee4c9484bdb4e8f1e65 (4.9).
+ continue
+ // Block 4.
+ case f == X86FeatureXSAVES && (major < 4 || major == 4 && minor < 8):
+ // XSAVES only exposed in
+ // b8be15d588060a03569ac85dc4a0247460988f5b (4.8).
+ continue
+ // Block 5.
+ case f == X86FeaturePERFCTR_LLC && (major < 4 || major == 4 && minor < 14):
+ // PERFCTR_LLC renamed in
+ // 910448bbed066ab1082b510eef1ae61bb792d854 (4.14).
+ continue
+ }
+
+ hidden := f.flagString(true) == ""
+ _, ok := cpuinfoFlags[name]
+ if hidden && ok {
+ t.Errorf("Unexpectedly hidden flag: %v", f)
+ } else if !hidden && !ok {
+ t.Errorf("Non-native flag: %v", f)
+ }
+ }
+}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
new file mode 100644
index 000000000..17a89c00d
--- /dev/null
+++ b/pkg/cpuid/cpuid_x86.go
@@ -0,0 +1,1111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package cpuid
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// Common references for CPUID leaves and bits:
+//
+// Intel:
+// * Intel SDM Volume 2, Chapter 3.2 "CPUID" (more up-to-date)
+// * Intel Application Note 485 (more detailed)
+//
+// AMD:
+// * AMD64 APM Volume 3, Appendix 3 "Obtaining Processor Information ..."
+
+// block is a collection of 32 Feature bits.
+type block int
+
+const blockSize = 32
+
+// Feature bits are numbered according to "blocks". Each block is 32 bits, and
+// feature bits from the same source (cpuid leaf/level) are in the same block.
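+//
+// For example, featureID(2, 7) is X86FeatureSMEP: bit 7 of the block 2
+// feature word (ebx for eax=7, ecx=0).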
+func featureID(b block, bit int) Feature {
+ return Feature(32*int(b) + bit)
+}
+
+// Block 0 constants are all of the "basic" feature bits returned by a cpuid in
+// ecx with eax=1.
+const (
+ X86FeatureSSE3 Feature = iota
+ X86FeaturePCLMULDQ
+ X86FeatureDTES64
+ X86FeatureMONITOR
+ X86FeatureDSCPL
+ X86FeatureVMX
+ X86FeatureSMX
+ X86FeatureEST
+ X86FeatureTM2
+ X86FeatureSSSE3 // Not a typo, "supplemental" SSE3.
+ X86FeatureCNXTID
+ X86FeatureSDBG
+ X86FeatureFMA
+ X86FeatureCX16
+ X86FeatureXTPR
+ X86FeaturePDCM
+ _ // ecx bit 16 is reserved.
+ X86FeaturePCID
+ X86FeatureDCA
+ X86FeatureSSE4_1
+ X86FeatureSSE4_2
+ X86FeatureX2APIC
+ X86FeatureMOVBE
+ X86FeaturePOPCNT
+ X86FeatureTSCD
+ X86FeatureAES
+ X86FeatureXSAVE
+ X86FeatureOSXSAVE
+ X86FeatureAVX
+ X86FeatureF16C
+ X86FeatureRDRAND
+ _ // ecx bit 31 is reserved.
+)
+
+// Block 1 constants are all of the "basic" feature bits returned by a cpuid in
+// edx with eax=1.
+const (
+ X86FeatureFPU Feature = 32 + iota
+ X86FeatureVME
+ X86FeatureDE
+ X86FeaturePSE
+ X86FeatureTSC
+ X86FeatureMSR
+ X86FeaturePAE
+ X86FeatureMCE
+ X86FeatureCX8
+ X86FeatureAPIC
+ _ // edx bit 10 is reserved.
+ X86FeatureSEP
+ X86FeatureMTRR
+ X86FeaturePGE
+ X86FeatureMCA
+ X86FeatureCMOV
+ X86FeaturePAT
+ X86FeaturePSE36
+ X86FeaturePSN
+ X86FeatureCLFSH
+ _ // edx bit 20 is reserved.
+ X86FeatureDS
+ X86FeatureACPI
+ X86FeatureMMX
+ X86FeatureFXSR
+ X86FeatureSSE
+ X86FeatureSSE2
+ X86FeatureSS
+ X86FeatureHTT
+ X86FeatureTM
+ X86FeatureIA64
+ X86FeaturePBE
+)
+
+// Block 2 bits are the "structured extended" features returned in ebx for
+// eax=7, ecx=0.
+const (
+ X86FeatureFSGSBase Feature = 2*32 + iota
+ X86FeatureTSC_ADJUST
+ _ // ebx bit 2 is reserved.
+ X86FeatureBMI1
+ X86FeatureHLE
+ X86FeatureAVX2
+ X86FeatureFDP_EXCPTN_ONLY
+ X86FeatureSMEP
+ X86FeatureBMI2
+ X86FeatureERMS
+ X86FeatureINVPCID
+ X86FeatureRTM
+ X86FeatureCQM
+ X86FeatureFPCSDS
+ X86FeatureMPX
+ X86FeatureRDT
+ X86FeatureAVX512F
+ X86FeatureAVX512DQ
+ X86FeatureRDSEED
+ X86FeatureADX
+ X86FeatureSMAP
+ X86FeatureAVX512IFMA
+ X86FeaturePCOMMIT
+ X86FeatureCLFLUSHOPT
+ X86FeatureCLWB
+ X86FeatureIPT // Intel processor trace.
+ X86FeatureAVX512PF
+ X86FeatureAVX512ER
+ X86FeatureAVX512CD
+ X86FeatureSHA
+ X86FeatureAVX512BW
+ X86FeatureAVX512VL
+)
+
+// Block 3 bits are the "extended" features returned in ecx for eax=7, ecx=0.
+const (
+ X86FeaturePREFETCHWT1 Feature = 3*32 + iota
+ X86FeatureAVX512VBMI
+ X86FeatureUMIP
+ X86FeaturePKU
+ X86FeatureOSPKE
+ X86FeatureWAITPKG
+ X86FeatureAVX512_VBMI2
+ _ // ecx bit 7 is reserved
+ X86FeatureGFNI
+ X86FeatureVAES
+ X86FeatureVPCLMULQDQ
+ X86FeatureAVX512_VNNI
+ X86FeatureAVX512_BITALG
+ X86FeatureTME
+ X86FeatureAVX512_VPOPCNTDQ
+ _ // ecx bit 15 is reserved
+ X86FeatureLA57
+ // ecx bits 17-21 are reserved
+ _
+ _
+ _
+ _
+ _
+ X86FeatureRDPID
+ // ecx bits 23-24 are reserved
+ _
+ _
+ X86FeatureCLDEMOTE
+ _ // ecx bit 26 is reserved
+ X86FeatureMOVDIRI
+ X86FeatureMOVDIR64B
+)
+
+// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
+// The CPUID leaf is available only if 'X86FeatureXSAVE' is present.
+const (
+ X86FeatureXSAVEOPT Feature = 4*32 + iota
+ X86FeatureXSAVEC
+ X86FeatureXGETBV1
+ X86FeatureXSAVES
+ // EAX[31:4] are reserved.
+)
+
+// Block 5 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):ECX.
+const (
+ X86FeatureLAHF64 Feature = 5*32 + iota
+ X86FeatureCMP_LEGACY
+ X86FeatureSVM
+ X86FeatureEXTAPIC
+ X86FeatureCR8_LEGACY
+ X86FeatureLZCNT
+ X86FeatureSSE4A
+ X86FeatureMISALIGNSSE
+ X86FeaturePREFETCHW
+ X86FeatureOSVW
+ X86FeatureIBS
+ X86FeatureXOP
+ X86FeatureSKINIT
+ X86FeatureWDT
+ _ // ecx bit 14 is reserved.
+ X86FeatureLWP
+ X86FeatureFMA4
+ X86FeatureTCE
+ _ // ecx bit 18 is reserved.
+ _ // ecx bit 19 is reserved.
+ _ // ecx bit 20 is reserved.
+ X86FeatureTBM
+ X86FeatureTOPOLOGY
+ X86FeaturePERFCTR_CORE
+ X86FeaturePERFCTR_NB
+ _ // ecx bit 25 is reserved.
+ X86FeatureBPEXT
+ X86FeaturePERFCTR_TSC
+ X86FeaturePERFCTR_LLC
+ X86FeatureMWAITX
+ // TODO(b/152776797): Some CPUs set this but it is not documented anywhere.
+ X86FeatureBlock5Bit30
+ _ // ecx bit 31 is reserved.
+)
+
+// Block 6 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):EDX.
+//
+// These are sparse, and so the bit positions are assigned manually.
+const (
+ // On AMD, EDX[24:23] | EDX[17:12] | EDX[9:0] are duplicate features
+ // also defined in block 1 (in identical bit positions). Those features
+ // are not listed here.
+ block6DuplicateMask = 0x183f3ff
+
+ X86FeatureSYSCALL Feature = 6*32 + 11
+ X86FeatureNX Feature = 6*32 + 20
+ X86FeatureMMXEXT Feature = 6*32 + 22
+ X86FeatureFXSR_OPT Feature = 6*32 + 25
+ X86FeatureGBPAGES Feature = 6*32 + 26
+ X86FeatureRDTSCP Feature = 6*32 + 27
+ X86FeatureLM Feature = 6*32 + 29
+ X86Feature3DNOWEXT Feature = 6*32 + 30
+ X86Feature3DNOW Feature = 6*32 + 31
+)
+
+// linuxBlockOrder defines the order in which linux organizes the feature
+// blocks. Linux also tracks feature bits in 32-bit blocks, but in an order
+// which doesn't match well here, so for the /proc/cpuinfo generation we simply
+// re-map the blocks to Linux's ordering and then go through the bits in each
+// block.
+var linuxBlockOrder = []block{1, 6, 0, 5, 2, 4, 3}
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/x86/kernel/cpu/capflags.c.
+var x86FeatureStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureSSE3: "pni",
+ X86FeaturePCLMULDQ: "pclmulqdq",
+ X86FeatureDTES64: "dtes64",
+ X86FeatureMONITOR: "monitor",
+ X86FeatureDSCPL: "ds_cpl",
+ X86FeatureVMX: "vmx",
+ X86FeatureSMX: "smx",
+ X86FeatureEST: "est",
+ X86FeatureTM2: "tm2",
+ X86FeatureSSSE3: "ssse3",
+ X86FeatureCNXTID: "cid",
+ X86FeatureSDBG: "sdbg",
+ X86FeatureFMA: "fma",
+ X86FeatureCX16: "cx16",
+ X86FeatureXTPR: "xtpr",
+ X86FeaturePDCM: "pdcm",
+ X86FeaturePCID: "pcid",
+ X86FeatureDCA: "dca",
+ X86FeatureSSE4_1: "sse4_1",
+ X86FeatureSSE4_2: "sse4_2",
+ X86FeatureX2APIC: "x2apic",
+ X86FeatureMOVBE: "movbe",
+ X86FeaturePOPCNT: "popcnt",
+ X86FeatureTSCD: "tsc_deadline_timer",
+ X86FeatureAES: "aes",
+ X86FeatureXSAVE: "xsave",
+ X86FeatureAVX: "avx",
+ X86FeatureF16C: "f16c",
+ X86FeatureRDRAND: "rdrand",
+
+ // Block 1.
+ X86FeatureFPU: "fpu",
+ X86FeatureVME: "vme",
+ X86FeatureDE: "de",
+ X86FeaturePSE: "pse",
+ X86FeatureTSC: "tsc",
+ X86FeatureMSR: "msr",
+ X86FeaturePAE: "pae",
+ X86FeatureMCE: "mce",
+ X86FeatureCX8: "cx8",
+ X86FeatureAPIC: "apic",
+ X86FeatureSEP: "sep",
+ X86FeatureMTRR: "mtrr",
+ X86FeaturePGE: "pge",
+ X86FeatureMCA: "mca",
+ X86FeatureCMOV: "cmov",
+ X86FeaturePAT: "pat",
+ X86FeaturePSE36: "pse36",
+ X86FeaturePSN: "pn",
+ X86FeatureCLFSH: "clflush",
+ X86FeatureDS: "dts",
+ X86FeatureACPI: "acpi",
+ X86FeatureMMX: "mmx",
+ X86FeatureFXSR: "fxsr",
+ X86FeatureSSE: "sse",
+ X86FeatureSSE2: "sse2",
+ X86FeatureSS: "ss",
+ X86FeatureHTT: "ht",
+ X86FeatureTM: "tm",
+ X86FeatureIA64: "ia64",
+ X86FeaturePBE: "pbe",
+
+ // Block 2.
+ X86FeatureFSGSBase: "fsgsbase",
+ X86FeatureTSC_ADJUST: "tsc_adjust",
+ X86FeatureBMI1: "bmi1",
+ X86FeatureHLE: "hle",
+ X86FeatureAVX2: "avx2",
+ X86FeatureSMEP: "smep",
+ X86FeatureBMI2: "bmi2",
+ X86FeatureERMS: "erms",
+ X86FeatureINVPCID: "invpcid",
+ X86FeatureRTM: "rtm",
+ X86FeatureCQM: "cqm",
+ X86FeatureMPX: "mpx",
+ X86FeatureRDT: "rdt_a",
+ X86FeatureAVX512F: "avx512f",
+ X86FeatureAVX512DQ: "avx512dq",
+ X86FeatureRDSEED: "rdseed",
+ X86FeatureADX: "adx",
+ X86FeatureSMAP: "smap",
+ X86FeatureCLWB: "clwb",
+ X86FeatureAVX512PF: "avx512pf",
+ X86FeatureAVX512ER: "avx512er",
+ X86FeatureAVX512CD: "avx512cd",
+ X86FeatureSHA: "sha_ni",
+ X86FeatureAVX512BW: "avx512bw",
+ X86FeatureAVX512VL: "avx512vl",
+
+ // Block 3.
+ X86FeatureAVX512VBMI: "avx512vbmi",
+ X86FeatureUMIP: "umip",
+ X86FeaturePKU: "pku",
+ X86FeatureOSPKE: "ospke",
+ X86FeatureWAITPKG: "waitpkg",
+ X86FeatureAVX512_VBMI2: "avx512_vbmi2",
+ X86FeatureGFNI: "gfni",
+ X86FeatureVAES: "vaes",
+ X86FeatureVPCLMULQDQ: "vpclmulqdq",
+ X86FeatureAVX512_VNNI: "avx512_vnni",
+ X86FeatureAVX512_BITALG: "avx512_bitalg",
+ X86FeatureTME: "tme",
+ X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq",
+ X86FeatureLA57: "la57",
+ X86FeatureRDPID: "rdpid",
+ X86FeatureCLDEMOTE: "cldemote",
+ X86FeatureMOVDIRI: "movdiri",
+ X86FeatureMOVDIR64B: "movdir64b",
+
+ // Block 4.
+ X86FeatureXSAVEOPT: "xsaveopt",
+ X86FeatureXSAVEC: "xsavec",
+ X86FeatureXGETBV1: "xgetbv1",
+ X86FeatureXSAVES: "xsaves",
+
+ // Block 5.
+ X86FeatureLAHF64: "lahf_lm", // LAHF/SAHF in long mode
+ X86FeatureCMP_LEGACY: "cmp_legacy",
+ X86FeatureSVM: "svm",
+ X86FeatureEXTAPIC: "extapic",
+ X86FeatureCR8_LEGACY: "cr8_legacy",
+ X86FeatureLZCNT: "abm", // Advanced bit manipulation
+ X86FeatureSSE4A: "sse4a",
+ X86FeatureMISALIGNSSE: "misalignsse",
+ X86FeaturePREFETCHW: "3dnowprefetch",
+ X86FeatureOSVW: "osvw",
+ X86FeatureIBS: "ibs",
+ X86FeatureXOP: "xop",
+ X86FeatureSKINIT: "skinit",
+ X86FeatureWDT: "wdt",
+ X86FeatureLWP: "lwp",
+ X86FeatureFMA4: "fma4",
+ X86FeatureTCE: "tce",
+ X86FeatureTBM: "tbm",
+ X86FeatureTOPOLOGY: "topoext",
+ X86FeaturePERFCTR_CORE: "perfctr_core",
+ X86FeaturePERFCTR_NB: "perfctr_nb",
+ X86FeatureBPEXT: "bpext",
+ X86FeaturePERFCTR_TSC: "ptsc",
+ X86FeaturePERFCTR_LLC: "perfctr_llc",
+ X86FeatureMWAITX: "mwaitx",
+
+ // Block 6.
+ X86FeatureSYSCALL: "syscall",
+ X86FeatureNX: "nx",
+ X86FeatureMMXEXT: "mmxext",
+ X86FeatureFXSR_OPT: "fxsr_opt",
+ X86FeatureGBPAGES: "pdpe1gb",
+ X86FeatureRDTSCP: "rdtscp",
+ X86FeatureLM: "lm",
+ X86Feature3DNOWEXT: "3dnowext",
+ X86Feature3DNOW: "3dnow",
+}
+
+// These flags are parse-only: they can be used for setting/unsetting the
+// flags, but will not be printed in /proc/cpuinfo.
+var x86FeatureParseOnlyStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureOSXSAVE: "osxsave",
+
+ // Block 2.
+ X86FeatureFDP_EXCPTN_ONLY: "fdp_excptn_only",
+ X86FeatureFPCSDS: "fpcsds",
+ X86FeatureIPT: "pt",
+ X86FeatureCLFLUSHOPT: "clfushopt",
+
+ // Block 3.
+ X86FeaturePREFETCHWT1: "prefetchwt1",
+
+ // Block 5.
+ X86FeatureBlock5Bit30: "block5_bit30",
+}
+
+// intelCacheDescriptor describes a cache or TLB on the system. Descriptors are
+// returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
+// Just a way to wrap cpuid function numbers.
+type cpuidFunction uint32
+
+// The constants below are the lower or "standard" cpuid functions, ordered as
+// defined by the hardware.
+const (
+ vendorID cpuidFunction = iota // Returns vendor ID and largest standard function.
+ featureInfo // Returns basic feature bits and processor signature.
+ intelCacheDescriptors // Returns list of cache descriptors. Intel only.
+ intelSerialNumber // Returns processor serial number (obsolete on new hardware). Intel only.
+ intelDeterministicCacheParams // Returns deterministic cache information. Intel only.
+ monitorMwaitParams // Returns information about monitor/mwait instructions.
+ powerParams // Returns information about power management and thermal sensors.
+ extendedFeatureInfo // Returns extended feature bits.
+ _ // Function 0x8 is reserved.
+ intelDCAParams // Returns direct cache access information. Intel only.
+ intelPMCInfo // Returns information about performance monitoring features. Intel only.
+ intelX2APICInfo // Returns core/logical processor topology. Intel only.
+ _ // Function 0xc is reserved.
+ xSaveInfo // Returns information about extended state management.
+)
+
+// The "extended" functions start at 0x80000000.
+const (
+ extendedFunctionInfo cpuidFunction = 0x80000000 + iota // Returns highest available extended function in eax.
+ extendedFeatures // Returns some extended feature bits in edx and ecx.
+)
+
+// These are the extended floating point state features. They are used to
+// enumerate floating point features in XCR0, XSTATE_BV, etc.
+const (
+ XSAVEFeatureX87 = 1 << 0
+ XSAVEFeatureSSE = 1 << 1
+ XSAVEFeatureAVX = 1 << 2
+ XSAVEFeatureBNDREGS = 1 << 3
+ XSAVEFeatureBNDCSR = 1 << 4
+ XSAVEFeatureAVX512op = 1 << 5
+ XSAVEFeatureAVX512zmm0 = 1 << 6
+ XSAVEFeatureAVX512zmm16 = 1 << 7
+ XSAVEFeaturePKRU = 1 << 9
+)
+
+var cpuFreqMHz float64
+
+// x86FeaturesFromString includes features from x86FeatureStrings and
+// x86FeatureParseOnlyStrings.
+var x86FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string, and a bool indicating whether the feature was found.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := x86FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(false); s != "" {
+ return s
+ }
+
+ block := int(f) / 32
+ bit := int(f) % 32
+ return fmt.Sprintf("<cpuflag %d; block %d bit %d>", f, block, bit)
+}
+
+func (f Feature) flagString(cpuinfoOnly bool) string {
+ if s, ok := x86FeatureStrings[f]; ok {
+ return s
+ }
+ if !cpuinfoOnly {
+ return x86FeatureParseOnlyStrings[f]
+ }
+ return ""
+}
+
+// FeatureSet is a set of Features for a CPU.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // VendorID is the 12-char string returned in ebx:edx:ecx for eax=0.
+ VendorID string
+
+ // ExtendedFamily is part of the processor signature.
+ ExtendedFamily uint8
+
+ // ExtendedModel is part of the processor signature.
+ ExtendedModel uint8
+
+ // ProcessorType is part of the processor signature.
+ ProcessorType uint8
+
+ // Family is part of the processor signature.
+ Family uint8
+
+ // Model is part of the processor signature.
+ Model uint8
+
+ // SteppingID is part of the processor signature.
+ SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
+}
+
+// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
+// equivalent to the "flags" field in /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
+ var s []string
+ for _, b := range linuxBlockOrder {
+ for i := 0; i < blockSize; i++ {
+ if f := featureID(b, i); fs.Set[f] {
+ if fstr := f.flagString(cpuinfoOnly); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// WriteCPUInfoTo generates the /proc/cpuinfo section for one CPU. This is a
+// minimal /proc/cpuinfo, missing some fields, like "microcode", that are not
+// always printed in Linux. The bogomips field is simply made up.
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "vendor_id\t: %s\n", fs.VendorID)
+ fmt.Fprintf(b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
+ fmt.Fprintf(b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
+ fmt.Fprintf(b, "model name\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "stepping\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
+ fmt.Fprintln(b, "fpu\t\t: yes")
+ fmt.Fprintln(b, "fpu_exception\t: yes")
+ fmt.Fprintf(b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
+ fmt.Fprintln(b, "wp\t\t: yes")
+ fmt.Fprintf(b, "flags\t\t: %s\n", fs.FlagsString(true))
+ fmt.Fprintf(b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "cache_alignment\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
+ fmt.Fprintln(b, "power management:") // This is always here, but can be blank.
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
+}
+
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
+// AMD returns true if fs describes an AMD CPU.
+func (fs *FeatureSet) AMD() bool {
+ return fs.VendorID == amdVendorID
+}
+
+// Intel returns true if fs describes an Intel CPU.
+func (fs *FeatureSet) Intel() bool {
+ return fs.VendorID == intelVendorID
+}
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
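+//
+// For example, a FeatureSet restored from a checkpoint can be validated
+// against the current host before use:
+//
+// if err := fs.CheckHostCompatible(); err != nil {
+// return fmt.Errorf("saved CPU state is incompatible: %v", err)
+// }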
+func (fs *FeatureSet) CheckHostCompatible() error {
+ hfs := HostFeatureSet()
+
+ if diff := fs.Subtract(hfs); diff != nil {
+ return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
+ }
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
+ return nil
+}
+
+// Helper to convert 3 regs into 12-byte vendor ID.
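+// For example, bx=0x756e6547 ("Genu"), dx=0x49656e69 ("ineI"), and
+// cx=0x6c65746e ("ntel") produce "GenuineIntel".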
+func vendorIDFromRegs(bx, cx, dx uint32) string {
+ buf := make([]byte, 0, 12)
+ for i := uint(0); i < 4; i++ {
+ b := byte(bx >> (i * 8))
+ buf = append(buf, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(dx >> (i * 8))
+ buf = append(buf, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(cx >> (i * 8))
+ buf = append(buf, b)
+ }
+ return string(buf)
+}
+
+var maxXsaveSize = func() uint32 {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ //
+ // If xSaveInfo isn't supported, cpuid will not fault but will
+ // return bogus values.
+ _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0)
+ return maxXsaveSize
+}()
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point registers, and other cpu state that's not
+// associated with the normal task context.
+//
+// Note: We can save some space here with an optimization where we use a
+// smaller chunk of memory depending on features that are actually enabled.
+// Currently we just use the largest possible size for simplicity (which is
+// about 2.5K worst case, with avx512).
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ if fs.UseXsave() {
+ return uint(maxXsaveSize), 64
+ }
+
+ // If we don't support xsave, we fall back to fxsave, which requires
+ // 512 bytes aligned to 16 bytes.
+ return 512, 16
+}
+
+// ValidXCR0Mask returns the bits that may be set to 1 in control register
+// XCR0.
+func (fs *FeatureSet) ValidXCR0Mask() uint64 {
+ if !fs.UseXsave() {
+ return 0
+ }
+ eax, _, _, edx := HostID(uint32(xSaveInfo), 0)
+ return uint64(edx)<<32 | uint64(eax)
+}
+
+// vendorIDRegs returns the 3 register values used to construct the 12-byte
+// vendor ID string for eax=0.
+func (fs *FeatureSet) vendorIDRegs() (bx, dx, cx uint32) {
+ for i := uint(0); i < 4; i++ {
+ bx |= uint32(fs.VendorID[i]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ dx |= uint32(fs.VendorID[i+4]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ cx |= uint32(fs.VendorID[i+8]) << (i * 8)
+ }
+ return
+}
+
+// signature returns the signature dword that's returned in eax when eax=1.
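+//
+// For example, ExtendedModel=1, Family=0x6, Model=0xa, and zero for the
+// remaining fields encode to 0x000106a0.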
+func (fs *FeatureSet) signature() uint32 {
+ var s uint32
+ s |= uint32(fs.SteppingID & 0xf)
+ s |= uint32(fs.Model&0xf) << 4
+ s |= uint32(fs.Family&0xf) << 8
+ s |= uint32(fs.ProcessorType&0x3) << 12
+ s |= uint32(fs.ExtendedModel&0xf) << 16
+ s |= uint32(fs.ExtendedFamily&0xff) << 20
+ return s
+}
+
+// Helper to deconstruct signature dword.
+func signatureSplit(v uint32) (ef, em, pt, f, m, sid uint8) {
+ sid = uint8(v & 0xf)
+ m = uint8(v>>4) & 0xf
+ f = uint8(v>>8) & 0xf
+ pt = uint8(v>>12) & 0x3
+ em = uint8(v>>16) & 0xf
+ ef = uint8(v >> 20)
+ return
+}
+
+// Helper to convert blockwise feature bit masks into a set of features. Masks
+// must be provided in order for each block, without skipping them. If a block
+// does not matter for this feature set, 0 is specified.
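+//
+// For example, setFromBlockMasks(0, 1) yields a set containing only
+// X86FeatureFPU (block 1, bit 0).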
+func setFromBlockMasks(blocks ...uint32) map[Feature]bool {
+ s := make(map[Feature]bool)
+ for b, blockMask := range blocks {
+ for i := 0; i < blockSize; i++ {
+ if blockMask&1 != 0 {
+ s[featureID(block(b), i)] = true
+ }
+ blockMask >>= 1
+ }
+ }
+ return s
+}
+
+// blockMask returns the 32-bit mask associated with a block of features.
+func (fs *FeatureSet) blockMask(b block) uint32 {
+ var mask uint32
+ for i := 0; i < blockSize; i++ {
+ if fs.Set[featureID(b, i)] {
+ mask |= 1 << uint(i)
+ }
+ }
+ return mask
+}
+
+// Remove removes a Feature from a FeatureSet. It ignores features
+// that are not in the FeatureSet.
+func (fs *FeatureSet) Remove(feature Feature) {
+ delete(fs.Set, feature)
+}
+
+// Add adds a Feature to a FeatureSet. It ignores duplicate features.
+func (fs *FeatureSet) Add(feature Feature) {
+ fs.Set[feature] = true
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// Subtract returns the features present in fs that are not present in other.
+// If all features in fs are present in other, Subtract returns nil.
+func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
+ for f := range fs.Set {
+ if !other.Set[f] {
+ if diff == nil {
+ diff = make(map[Feature]bool)
+ }
+ diff[f] = true
+ }
+ }
+
+ return
+}
+
+// EmulateID emulates a cpuid instruction based on the feature set.
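+//
+// For example, fs.EmulateID(1, 0) returns fs's block 0 mask in cx and its
+// block 1 mask in dx, mirroring what a cpuid with eax=1 would report on a
+// CPU with this feature set.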
+func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
+ switch cpuidFunction(origAx) {
+ case vendorID:
+ ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
+ bx, dx, cx = fs.vendorIDRegs()
+ case featureInfo:
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
+ cx = fs.blockMask(block(0))
+ dx = fs.blockMask(block(1))
+ ax = fs.signature()
+ case intelCacheDescriptors:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // "The least-significant byte in register EAX (register AL)
+ // will always return 01H. Software should ignore this value
+ // and not interpret it as an informational descriptor." - SDM
+ //
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
+ case xSaveInfo:
+ if !fs.UseXsave() {
+ return 0, 0, 0, 0
+ }
+ return HostID(uint32(xSaveInfo), origCx)
+ case extendedFeatureInfo:
+ if origCx != 0 {
+ break // Only leaf 0 is supported.
+ }
+ bx = fs.blockMask(block(2))
+ cx = fs.blockMask(block(3))
+ case extendedFunctionInfo:
+ // We only support showing the extended features.
+ ax = uint32(extendedFeatures)
+ cx = 0
+ case extendedFeatures:
+ cx = fs.blockMask(block(5))
+ dx = fs.blockMask(block(6))
+ if fs.AMD() {
+ // AMD duplicates some block 1 features in block 6.
+ dx |= fs.blockMask(block(1)) & block6DuplicateMask
+ }
+ }
+
+ return
+}
+
+// UseXsave returns true if 'fs' supports the "xsave" instruction, i.e. both
+// XSAVE and OSXSAVE are present.
+func (fs *FeatureSet) UseXsave() bool {
+ return fs.HasFeature(X86FeatureXSAVE) && fs.HasFeature(X86FeatureOSXSAVE)
+}
+
+// UseXsaveopt returns true if 'fs' supports the "xsaveopt" instruction.
+func (fs *FeatureSet) UseXsaveopt() bool {
+ return fs.UseXsave() && fs.HasFeature(X86FeatureXSAVEOPT)
+}
+
+// HostID executes a native CPUID instruction.
+func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
+
+// HostFeatureSet uses cpuid to get host values and construct a feature set
+// that matches that of the host machine. Note that there are several places
+// where there appear to be some unnecessary assignments between register names
+// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
+// where the different feature blocks come from, to make the code easier to
+// inspect and read.
+func HostFeatureSet() *FeatureSet {
+ // eax=0 gets max supported feature and vendor ID.
+ _, bx, cx, dx := HostID(0, 0)
+ vendorID := vendorIDFromRegs(bx, cx, dx)
+
+ // eax=1 gets basic features in ecx:edx.
+ ax, bx, cx, dx := HostID(1, 0)
+ featureBlock0 := cx
+ featureBlock1 := dx
+ ef, em, pt, f, m, sid := signatureSplit(ax)
+ cacheLine := 8 * (bx >> 8) & 0xff
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
+
+ // eax=7, ecx=0 gets extended features in ecx:ebx.
+ _, bx, cx, _ = HostID(7, 0)
+ featureBlock2 := bx
+ featureBlock3 := cx
+
+ // Leaf 0xd is supported only if CPUID.1:ECX.XSAVE[bit 26] is set.
+ var featureBlock4 uint32
+ if (featureBlock0 & (1 << 26)) != 0 {
+ featureBlock4, _, _, _ = HostID(uint32(xSaveInfo), 1)
+ }
+
+ // eax=0x80000000 gets supported extended levels. We use this to
+ // determine if there are any non-zero block 5 or block 6 bits to find.
+ var featureBlock5, featureBlock6 uint32
+ if ax, _, _, _ := HostID(uint32(extendedFunctionInfo), 0); ax >= uint32(extendedFeatures) {
+ // eax=0x80000001 gets AMD added feature bits.
+ _, _, cx, dx = HostID(uint32(extendedFeatures), 0)
+ featureBlock5 = cx
+ // Ignore features duplicated from block 1 on AMD. These bits
+ // are reserved on Intel.
+ featureBlock6 = dx &^ block6DuplicateMask
+ }
+
+ set := setFromBlockMasks(featureBlock0, featureBlock1, featureBlock2, featureBlock3, featureBlock4, featureBlock5, featureBlock6)
+ return &FeatureSet{
+ Set: set,
+ VendorID: vendorID,
+ ExtendedFamily: ef,
+ ExtendedModel: em,
+ ProcessorType: pt,
+ Family: f,
+ Model: m,
+ SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
+ }
+}
+
+// initCPUFreq reads the cpu MHz value from the host /proc/cpuinfo. It must run
+// before syscall filter installation. This value is used to create the fake
+// /proc/cpuinfo from a FeatureSet.
+func initCPUFreq() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0. The standalone VDSO bails out in the same way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo. On machines with
+ // frequency scaling enabled, this will only get the current value
+ // which will likely be inaccurate. This is fine on machines with
+ // frequency scaling disabled.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ if strings.Contains(line, "cpu MHz") {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
+ return
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ return
+ }
+ return
+ }
+ }
+ log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
+}
+
+func initFeaturesFromString() {
+ for f, s := range x86FeatureStrings {
+ x86FeaturesFromString[s] = f
+ }
+ for f, s := range x86FeatureParseOnlyStrings {
+ x86FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ initCPUFreq()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_x86_test.go b/pkg/cpuid/cpuid_x86_test.go
new file mode 100644
index 000000000..bacf345c8
--- /dev/null
+++ b/pkg/cpuid/cpuid_x86_test.go
@@ -0,0 +1,243 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package cpuid
+
+import (
+ "testing"
+)
+
+// These are the default values of various FeatureSet fields.
+const (
+ defaultVendorID = "GenuineIntel"
+
+ // These processor signature defaults are derived from the values
+ // listed in Intel Application Note 485 for i7/Xeon processors.
+ defaultExtFamily uint8 = 0
+ defaultExtModel uint8 = 1
+ defaultType uint8 = 0
+ defaultFamily uint8 = 0x06
+ defaultModel uint8 = 0x0a
+ defaultSteppingID uint8 = 0
+)
+
+// newEmptyFeatureSet creates a new FeatureSet with a sensible default model and no features.
+func newEmptyFeatureSet() *FeatureSet {
+ return &FeatureSet{
+ Set: make(map[Feature]bool),
+ VendorID: defaultVendorID,
+ ExtendedFamily: defaultExtFamily,
+ ExtendedModel: defaultExtModel,
+ ProcessorType: defaultType,
+ Family: defaultFamily,
+ Model: defaultModel,
+ SteppingID: defaultSteppingID,
+ }
+}
+
+var justFPU = &FeatureSet{
+ Set: map[Feature]bool{
+ X86FeatureFPU: true,
+ }}
+
+var justFPUandPAE = &FeatureSet{
+ Set: map[Feature]bool{
+ X86FeatureFPU: true,
+ X86FeaturePAE: true,
+ }}
+
+func TestSubtract(t *testing.T) {
+ if diff := justFPU.Subtract(justFPUandPAE); diff != nil {
+ t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff)
+ }
+
+ if justFPUandPAE.Subtract(justFPU) == nil {
+ t.Errorf("Got %v is a subset of %v, want diff to be non-nil", justFPUandPAE, justFPU)
+ }
+}
+
+// TODO(b/73346484): Run this test on a very old platform, and make sure more
+// bits are enabled than just FPU and PAE. This test currently may not detect
+// if HostFeatureSet gives back junk bits.
+func TestHostFeatureSet(t *testing.T) {
+ hostFeatures := HostFeatureSet()
+ if justFPUandPAE.Subtract(hostFeatures) != nil {
+ t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
+ }
+}
+
+func TestHasFeature(t *testing.T) {
+ if !justFPU.HasFeature(X86FeatureFPU) {
+ t.Errorf("HasFeature failed, %v should contain %v", justFPU, X86FeatureFPU)
+ }
+
+ if justFPU.HasFeature(X86FeatureAVX) {
+ t.Errorf("HasFeature failed, %v should not contain %v", justFPU, X86FeatureAVX)
+ }
+}
+
+// Note: these tests are aware of and abuse internal details of FeatureSets.
+// Users of FeatureSets should not depend on this.
+func TestAdd(t *testing.T) {
+ // Test a basic insertion into the FeatureSet.
+ testFeatures := newEmptyFeatureSet()
+ testFeatures.Add(X86FeatureCLFSH)
+ if len(testFeatures.Set) != 1 {
+ t.Errorf("Got length %v want 1", len(testFeatures.Set))
+ }
+
+ if !testFeatures.HasFeature(X86FeatureCLFSH) {
+ t.Errorf("Add failed, got %v want set with %v", testFeatures, X86FeatureCLFSH)
+ }
+
+ // Test that duplicates are ignored.
+ testFeatures.Add(X86FeatureCLFSH)
+ if len(testFeatures.Set) != 1 {
+ t.Errorf("Got length %v, want 1", len(testFeatures.Set))
+ }
+}
+
+func TestRemove(t *testing.T) {
+ // Try removing the last feature.
+ testFeatures := newEmptyFeatureSet()
+ testFeatures.Add(X86FeatureFPU)
+ testFeatures.Add(X86FeaturePAE)
+ testFeatures.Remove(X86FeaturePAE)
+ if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 || testFeatures.HasFeature(X86FeaturePAE) {
+ t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU)
+ }
+
+ // Try removing a feature not in the set.
+ testFeatures.Remove(X86FeatureRDRAND)
+ if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 {
+ t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU)
+ }
+}
+
+func TestFeatureFromString(t *testing.T) {
+ f, ok := FeatureFromString("avx")
+ if f != X86FeatureAVX || !ok {
+ t.Errorf("got %v want avx", f)
+ }
+
+ f, ok = FeatureFromString("bad")
+ if ok {
+ t.Errorf("got %v want nothing", f)
+ }
+}
+
+// This tests function 0 (eax=0), which returns the vendor ID and highest cpuid
+// function reported to be available.
+func TestEmulateIDVendorAndLength(t *testing.T) {
+ testFeatures := newEmptyFeatureSet()
+
+ ax, bx, cx, dx := testFeatures.EmulateID(0, 0)
+ wantEax := uint32(0xd) // Highest supported cpuid function.
+
+ // These magical constants are the characters of "GenuineIntel".
+ // See Intel AN485 for a reference on why they are laid out like this.
+ wantEbx := uint32(0x756e6547)
+ wantEcx := uint32(0x6c65746e)
+ wantEdx := uint32(0x49656e69)
+ if wantEax != ax {
+ t.Errorf("highest function failed, got %x want %x", ax, wantEax)
+ }
+
+ if wantEbx != bx || wantEcx != cx || wantEdx != dx {
+ t.Errorf("vendor string emulation failed, bx:cx:dx, got %x:%x:%x want %x:%x:%x", bx, cx, dx, wantEbx, wantEcx, wantEdx)
+ }
+}
+
+func TestEmulateIDBasicFeatures(t *testing.T) {
+ // Make a minimal test feature set.
+ testFeatures := newEmptyFeatureSet()
+ testFeatures.Add(X86FeatureCLFSH)
+ testFeatures.Add(X86FeatureAVX)
+ testFeatures.CacheLine = 64
+
+ ax, bx, cx, dx := testFeatures.EmulateID(1, 0)
+ ECXAVXBit := uint32(1 << uint(X86FeatureAVX))
+ EDXCLFlushBit := uint32(1 << uint(X86FeatureCLFSH-32)) // We adjust by 32 since it's in block 1.
+
+ if EDXCLFlushBit&dx == 0 || dx&^EDXCLFlushBit != 0 {
+ t.Errorf("EmulateID failed, got feature bits %x want %x", dx, testFeatures.blockMask(1))
+ }
+
+ if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 {
+ t.Errorf("EmulateID failed, got feature bits %x want %x", cx, testFeatures.blockMask(0))
+ }
+
+ // Default signature bits, based on values for i7/Xeon.
+ // See Intel AN485 for information on stepping/model bits.
+ defaultSignature := uint32(0x000106a0)
+ if defaultSignature != ax {
+ t.Errorf("EmulateID stepping emulation failed, got %x want %x", ax, defaultSignature)
+ }
+
+ clflushSizeInfo := uint32(8 << 8)
+ if clflushSizeInfo != bx {
+ t.Errorf("EmulateID bx emulation failed, got %x want %x", bx, clflushSizeInfo)
+ }
+}
+
+func TestEmulateIDExtendedFeatures(t *testing.T) {
+ // Make a minimal test feature set, one bit in each extended feature word.
+ testFeatures := newEmptyFeatureSet()
+ testFeatures.Add(X86FeatureSMEP)
+ testFeatures.Add(X86FeatureAVX512VBMI)
+
+ ax, bx, cx, dx := testFeatures.EmulateID(7, 0)
+ EBXSMEPBit := uint32(1 << uint(X86FeatureSMEP-2*32)) // Adjust by 2*32 since SMEP is a block 2 feature.
+ ECXAVXBit := uint32(1 << uint(X86FeatureAVX512VBMI-3*32)) // We adjust by 3*32 since it's a block 3 feature.
+
+ // Test that the desired bit is set and no other bits are set.
+ if EBXSMEPBit&bx == 0 || bx&^EBXSMEPBit != 0 {
+ t.Errorf("extended feature emulation failed, got feature bits %x want %x", bx, testFeatures.blockMask(2))
+ }
+
+ if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 {
+ t.Errorf("extended feature emulation failed, got feature bits %x want %x", cx, testFeatures.blockMask(3))
+ }
+
+ if ax != 0 || dx != 0 {
+ t.Errorf("extended feature emulation failed, ax:dx, got %x:%x want 0:0", ax, dx)
+ }
+
+ // Check that no subleaves other than 0 do anything.
+ ax, bx, cx, dx = testFeatures.EmulateID(7, 1)
+ if ax != 0 || bx != 0 || cx != 0 || dx != 0 {
+ t.Errorf("extended feature emulation failed, got %x:%x:%x:%x want 0:0", ax, bx, cx, dx)
+ }
+}
+
+// Checks that the expected extended features are available via cpuid functions
+// 0x80000000 and up.
+func TestEmulateIDExtended(t *testing.T) {
+ testFeatures := newEmptyFeatureSet()
+ testFeatures.Add(X86FeatureSYSCALL)
+ EDXSYSCALLBit := uint32(1 << uint(X86FeatureSYSCALL-6*32)) // Adjust by 6*32 since SYSCALL is a block 6 feature.
+
+ ax, bx, cx, dx := testFeatures.EmulateID(0x80000000, 0)
+ if ax != 0x80000001 || bx != 0 || cx != 0 || dx != 0 {
+ t.Errorf("EmulateID extended emulation failed, ax:bx:cx:dx, got %x:%x:%x:%x want 0x80000001:0:0:0", ax, bx, cx, dx)
+ }
+
+ _, _, _, dx = testFeatures.EmulateID(0x80000001, 0)
+ if EDXSYSCALLBit&dx == 0 || dx&^EDXSYSCALLBit != 0 {
+ t.Errorf("extended feature emulation failed, got feature bits %x want %x", dx, testFeatures.blockMask(6))
+ }
+}
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
new file mode 100644
index 000000000..bee28b68d
--- /dev/null
+++ b/pkg/eventchannel/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "eventchannel",
+ srcs = [
+ "event.go",
+ "rate.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ ":eventchannel_go_proto",
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ "@com_github_golang_protobuf//ptypes:go_default_library_gen",
+ "@org_golang_x_time//rate:go_default_library",
+ ],
+)
+
+proto_library(
+ name = "eventchannel",
+ srcs = ["event.proto"],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "eventchannel_test",
+ srcs = ["event_test.go"],
+ library = ":eventchannel",
+ deps = [
+ "//pkg/sync",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
new file mode 100644
index 000000000..9a29c58bd
--- /dev/null
+++ b/pkg/eventchannel/event.go
@@ -0,0 +1,201 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventchannel contains functionality for sending any protobuf message
+// on a socketpair.
+//
+// The wire format is a uvarint length followed by a binary protobuf.Any
+// message.
+package eventchannel
+
+import (
+ "encoding/binary"
+ "fmt"
+ "syscall"
+
+ "github.com/golang/protobuf/proto"
+ "github.com/golang/protobuf/ptypes"
+ pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Emitter emits a proto message.
+type Emitter interface {
+ // Emit writes a single eventchannel message to an emitter. Emit should
+ // return hangup = true to indicate an emitter has "hung up" and no further
+ // messages should be directed to it.
+ Emit(msg proto.Message) (hangup bool, err error)
+
+ // Close closes this emitter. Emit cannot be used after Close is called.
+ Close() error
+}
+
+// DefaultEmitter is the default emitter. Calls to Emit and AddEmitter are sent
+// to this Emitter.
+var DefaultEmitter = &multiEmitter{}
+
+// Emit is a helper method that calls DefaultEmitter.Emit.
+func Emit(msg proto.Message) error {
+ _, err := DefaultEmitter.Emit(msg)
+ return err
+}
+
+// AddEmitter is a helper method that calls DefaultEmitter.AddEmitter.
+func AddEmitter(e Emitter) {
+ DefaultEmitter.AddEmitter(e)
+}
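+
+// As a sketch of typical wiring (fd is assumed to be one end of a socketpair
+// owned by the caller; error handling elided):
+//
+//	e, err := SocketEmitter(fd)
+//	if err != nil { ... }
+//	AddEmitter(e)
+//	_ = Emit(&pb.DebugEvent{Name: "example"})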
+
+// multiEmitter is an Emitter that forwards messages to multiple Emitters.
+type multiEmitter struct {
+ // mu protects emitters.
+ mu sync.Mutex
+ // emitters is initialized lazily in AddEmitter.
+ emitters map[Emitter]struct{}
+}
+
+// Emit emits a message using all added emitters.
+func (me *multiEmitter) Emit(msg proto.Message) (bool, error) {
+ me.mu.Lock()
+ defer me.mu.Unlock()
+
+ var err error
+ for e := range me.emitters {
+ hangup, eerr := e.Emit(msg)
+ if eerr != nil {
+ if err == nil {
+ err = fmt.Errorf("error emitting %v: on %v: %v", msg, e, eerr)
+ } else {
+ err = fmt.Errorf("%v; on %v: %v", err, e, eerr)
+ }
+
+ // Log as well, since most callers ignore the error.
+ log.Warningf("Error emitting %v on %v: %v", msg, e, eerr)
+ }
+ if hangup {
+ log.Infof("Hangup on eventchannel emitter %v.", e)
+ delete(me.emitters, e)
+ }
+ }
+
+ return false, err
+}
+
+// AddEmitter adds a new emitter.
+func (me *multiEmitter) AddEmitter(e Emitter) {
+ me.mu.Lock()
+ defer me.mu.Unlock()
+ if me.emitters == nil {
+ me.emitters = make(map[Emitter]struct{})
+ }
+ me.emitters[e] = struct{}{}
+}
+
+// Close closes all emitters. If any Close call errors, it returns the first
+// one encountered.
+func (me *multiEmitter) Close() error {
+ me.mu.Lock()
+ defer me.mu.Unlock()
+ var err error
+ for e := range me.emitters {
+ if eerr := e.Close(); err == nil && eerr != nil {
+ err = eerr
+ }
+ delete(me.emitters, e)
+ }
+ return err
+}
+
+func marshal(msg proto.Message) ([]byte, error) {
+ anypb, err := ptypes.MarshalAny(msg)
+ if err != nil {
+ return nil, err
+ }
+
+ // Wire format is uvarint message length followed by binary proto.
+ bufMsg, err := proto.Marshal(anypb)
+ if err != nil {
+ return nil, err
+ }
+ p := make([]byte, binary.MaxVarintLen64)
+ n := binary.PutUvarint(p, uint64(len(bufMsg)))
+ return append(p[:n], bufMsg...), nil
+}
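+
+// For reference, a peer reading from the other end of the socketpair could
+// decode one message along these lines (a minimal sketch only; the bufio
+// reader and the error handling are illustrative assumptions):
+//
+//	func readOne(r *bufio.Reader) (*any.Any, error) {
+//		length, err := binary.ReadUvarint(r)
+//		if err != nil {
+//			return nil, err
+//		}
+//		buf := make([]byte, length)
+//		if _, err := io.ReadFull(r, buf); err != nil {
+//			return nil, err
+//		}
+//		msg := &any.Any{}
+//		if err := proto.Unmarshal(buf, msg); err != nil {
+//			return nil, err
+//		}
+//		return msg, nil
+//	}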
+
+// socketEmitter emits proto messages on a socket.
+type socketEmitter struct {
+ socket *unet.Socket
+}
+
+// SocketEmitter creates a new event channel based on the given fd.
+//
+// SocketEmitter takes ownership of fd.
+func SocketEmitter(fd int) (Emitter, error) {
+ s, err := unet.NewSocket(fd)
+ if err != nil {
+ return nil, err
+ }
+
+ return &socketEmitter{
+ socket: s,
+ }, nil
+}
+
+// Emit implements Emitter.Emit.
+func (s *socketEmitter) Emit(msg proto.Message) (bool, error) {
+ p, err := marshal(msg)
+ if err != nil {
+ return false, err
+ }
+ for done := 0; done < len(p); {
+ n, err := s.socket.Write(p[done:])
+ if err != nil {
+ return (err == syscall.EPIPE), err
+ }
+ done += n
+ }
+ return false, nil
+}
+
+// Close implements Emitter.Close.
+func (s *socketEmitter) Close() error {
+ return s.socket.Close()
+}
+
+// debugEmitter wraps an emitter to emit stringified event messages. This is
+// useful for debugging -- when the messages are intended for humans.
+type debugEmitter struct {
+ inner Emitter
+}
+
+// DebugEmitterFrom creates a new event channel emitter by wrapping an existing
+// raw emitter.
+func DebugEmitterFrom(inner Emitter) Emitter {
+ return &debugEmitter{
+ inner: inner,
+ }
+}
+
+func (d *debugEmitter) Emit(msg proto.Message) (bool, error) {
+ ev := &pb.DebugEvent{
+ Name: proto.MessageName(msg),
+ Text: proto.MarshalTextString(msg),
+ }
+ return d.inner.Emit(ev)
+}
+
+func (d *debugEmitter) Close() error {
+ return d.inner.Close()
+}
diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto
new file mode 100644
index 000000000..34468f072
--- /dev/null
+++ b/pkg/eventchannel/event.proto
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+// A debug event encapsulates any other event protobuf in text format. This is
+// useful because clients reading events emitted this way do not need to link
+// the event protobufs to display them in a human-readable format.
+message DebugEvent {
+ // Name of the inner message.
+ string name = 1;
+ // Text representation of the inner message content.
+ string text = 2;
+}
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
new file mode 100644
index 000000000..43750360b
--- /dev/null
+++ b/pkg/eventchannel/event_test.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// testEmitter is an emitter that can be used in tests. It records all events
+// emitted, and whether it has been closed.
+type testEmitter struct {
+ // mu protects all fields below.
+ mu sync.Mutex
+
+ // events contains all emitted events.
+ events []proto.Message
+
+ // closed records whether Close() was called.
+ closed bool
+}
+
+// Emit implements Emitter.Emit.
+func (te *testEmitter) Emit(msg proto.Message) (bool, error) {
+ te.mu.Lock()
+ defer te.mu.Unlock()
+ te.events = append(te.events, msg)
+ return false, nil
+}
+
+// Close implements Emitter.Close.
+func (te *testEmitter) Close() error {
+ te.mu.Lock()
+ defer te.mu.Unlock()
+ if te.closed {
+ return fmt.Errorf("closed called twice")
+ }
+ te.closed = true
+ return nil
+}
+
+// testMessage implements proto.Message for testing.
+type testMessage struct {
+ proto.Message
+
+ // name is the name of the message, used by tests to compare messages.
+ name string
+}
+
+func TestMultiEmitter(t *testing.T) {
+ // Create three testEmitters, tied together in a multiEmitter.
+ me := &multiEmitter{}
+ var emitters []*testEmitter
+ for i := 0; i < 3; i++ {
+ te := &testEmitter{}
+ emitters = append(emitters, te)
+ me.AddEmitter(te)
+ }
+
+ // Emit three messages to multiEmitter.
+ names := []string{"foo", "bar", "baz"}
+ for _, name := range names {
+ m := testMessage{name: name}
+ if _, err := me.Emit(m); err != nil {
+ t.Fatalf("me.Emit(%v) failed: %v", m, err)
+ }
+ }
+
+ // All three emitters should have all three events.
+ for _, te := range emitters {
+ if got, want := len(te.events), len(names); got != want {
+ t.Fatalf("emitter got %d events, want %d", got, want)
+ }
+ for i, name := range names {
+ if got := te.events[i].(testMessage).name; got != name {
+ t.Errorf("emitter got message with name %q, want %q", got, name)
+ }
+ }
+ }
+
+ // Close multiEmitter.
+ if err := me.Close(); err != nil {
+ t.Fatalf("me.Close() failed: %v", err)
+ }
+
+ // All testEmitters should be closed.
+ for _, te := range emitters {
+ if !te.closed {
+ t.Errorf("te.closed got false, want true")
+ }
+ }
+}
+
+func TestRateLimitedEmitter(t *testing.T) {
+ // Create a RateLimitedEmitter that wraps a testEmitter.
+ te := &testEmitter{}
+ max := float64(5) // events per second
+ burst := 10 // events
+ rle := RateLimitedEmitterFrom(te, max, burst)
+
+ // Send 50 messages in one shot.
+ for i := 0; i < 50; i++ {
+ if _, err := rle.Emit(testMessage{}); err != nil {
+ t.Fatalf("rle.Emit failed: %v", err)
+ }
+ }
+
+ // We should have received only 10 messages.
+ if got, want := len(te.events), 10; got != want {
+ t.Errorf("got %d events, want %d", got, want)
+ }
+
+ // Sleep for a second and then send another 50.
+ time.Sleep(1 * time.Second)
+ for i := 0; i < 50; i++ {
+ if _, err := rle.Emit(testMessage{}); err != nil {
+ t.Fatalf("rle.Emit failed: %v", err)
+ }
+ }
+
+ // We should have at least 5 more messages, plus maybe a few more if the
+ // test ran slowly.
+ got, wantAtLeast, wantAtMost := len(te.events), 15, 20
+ if got < wantAtLeast {
+ t.Errorf("got %d events, want at least %d", got, wantAtLeast)
+ }
+ if got > wantAtMost {
+ t.Errorf("got %d events, want at most %d", got, wantAtMost)
+ }
+}
diff --git a/pkg/eventchannel/rate.go b/pkg/eventchannel/rate.go
new file mode 100644
index 000000000..179226c92
--- /dev/null
+++ b/pkg/eventchannel/rate.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+ "github.com/golang/protobuf/proto"
+ "golang.org/x/time/rate"
+)
+
+// rateLimitedEmitter wraps an emitter and limits events to the given limits.
+// Events that would exceed the limit are discarded.
+type rateLimitedEmitter struct {
+ inner Emitter
+ limiter *rate.Limiter
+}
+
+// RateLimitedEmitterFrom creates a new event channel emitter that wraps the
+// existing emitter and enforces rate limits. The limits are imposed via a
+// token bucket, with `maxRate` events per second, with burst size of `burst`
+// events. See the golang.org/x/time/rate package and
+// https://en.wikipedia.org/wiki/Token_bucket for more information about token
+// buckets generally.
+func RateLimitedEmitterFrom(inner Emitter, maxRate float64, burst int) Emitter {
+ return &rateLimitedEmitter{
+ inner: inner,
+ limiter: rate.NewLimiter(rate.Limit(maxRate), burst),
+ }
+}
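+
+// For example, to cap emission at 5 events per second with bursts of up to
+// 10 events (a usage sketch; inner is any existing Emitter):
+//
+//	limited := RateLimitedEmitterFrom(inner, 5.0, 10)
+//	AddEmitter(limited)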
+
+// Emit implements Emitter.Emit.
+func (rle *rateLimitedEmitter) Emit(msg proto.Message) (bool, error) {
+ if !rle.limiter.Allow() {
+ // Drop event.
+ return false, nil
+ }
+ return rle.inner.Emit(msg)
+}
+
+// Close implements Emitter.Close.
+func (rle *rateLimitedEmitter) Close() error {
+ return rle.inner.Close()
+}
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
new file mode 100644
index 000000000..872361546
--- /dev/null
+++ b/pkg/fd/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fd",
+ srcs = ["fd.go"],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "fd_test",
+ size = "small",
+ srcs = ["fd_test.go"],
+ library = ":fd",
+)
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
new file mode 100644
index 000000000..83bcfe220
--- /dev/null
+++ b/pkg/fd/fd.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fd provides types for working with file descriptors.
+package fd
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+)
+
+// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
+// does not take ownership of fd.
+type ReadWriter struct {
+ // fd is accessed atomically so FD.Close/Release can swap it.
+ fd int64
+}
+
+var _ io.ReadWriter = (*ReadWriter)(nil)
+var _ io.ReaderAt = (*ReadWriter)(nil)
+var _ io.WriterAt = (*ReadWriter)(nil)
+
+// NewReadWriter creates a ReadWriter for fd.
+func NewReadWriter(fd int) *ReadWriter {
+ return &ReadWriter{int64(fd)}
+}
+
+func fixCount(n int, err error) (int, error) {
+ if n < 0 {
+ n = 0
+ }
+ return n, err
+}
+
+// Read implements io.Reader.
+func (r *ReadWriter) Read(b []byte) (int, error) {
+ c, err := fixCount(syscall.Read(int(atomic.LoadInt64(&r.fd)), b))
+ if c == 0 && len(b) > 0 && err == nil {
+ return 0, io.EOF
+ }
+ return c, err
+}
+
+// ReadAt implements io.ReaderAt.
+//
+// ReadAt always returns a non-nil error when c < len(b).
+func (r *ReadWriter) ReadAt(b []byte, off int64) (c int, err error) {
+ for len(b) > 0 {
+ var m int
+ m, err = fixCount(syscall.Pread(int(atomic.LoadInt64(&r.fd)), b, off))
+ if m == 0 && err == nil {
+ return c, io.EOF
+ }
+ if err != nil {
+ return c, err
+ }
+ c += m
+ b = b[m:]
+ off += int64(m)
+ }
+ return
+}
+
+// Write implements io.Writer.
+func (r *ReadWriter) Write(b []byte) (int, error) {
+ var err error
+ var n, remaining int
+ for remaining = len(b); remaining > 0; {
+ woff := len(b) - remaining
+ n, err = syscall.Write(int(atomic.LoadInt64(&r.fd)), b[woff:])
+
+ if n > 0 {
+ // syscall.Write wrote some bytes. This is the common case.
+ remaining -= n
+ } else {
+ if err == nil {
+ // syscall.Write did not write anything nor did it return an error.
+ //
+ // There is no way to guarantee that a subsequent syscall.Write will
+ // make forward progress so just panic.
+ panic(fmt.Sprintf("syscall.Write returned %d with no error", n))
+ }
+
+ if err != syscall.EINTR {
+ // If the write failed for anything other than a signal, bail out.
+ break
+ }
+ }
+ }
+
+ return len(b) - remaining, err
+}
+
+// WriteAt implements io.WriterAt.
+func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) {
+ for len(b) > 0 {
+ var m int
+ m, err = fixCount(syscall.Pwrite(int(atomic.LoadInt64(&r.fd)), b, off))
+ if err != nil {
+ break
+ }
+ c += m
+ b = b[m:]
+ off += int64(m)
+ }
+ return
+}
+
+// FD owns a host file descriptor.
+//
+// It is similar to os.File, with a few important distinctions:
+//
+// FD provides a Release() method which relinquishes ownership. Like os.File,
+// FD adds a finalizer to close the backing FD. However, the finalizer cannot
+// be removed from os.File, forever pinning the lifetime of an FD to its
+// os.File.
+//
+// FD supports both blocking and non-blocking operation. os.File only
+// supports blocking operation.
+type FD struct {
+ ReadWriter
+}
+
+// New creates a new FD.
+//
+// New takes ownership of fd.
+func New(fd int) *FD {
+ if fd < 0 {
+ return &FD{ReadWriter{-1}}
+ }
+ f := &FD{ReadWriter{int64(fd)}}
+ runtime.SetFinalizer(f, (*FD).Close)
+ return f
+}
+
+// NewFromFile creates a new FD from an os.File.
+//
+// NewFromFile does not transfer ownership of the file descriptor (it will be
+// duplicated, so both the os.File and FD will eventually need to be closed
+// and some (but not all) changes made to the FD will be applied to the
+// os.File as well).
+//
+// The returned FD is always blocking (Go 1.9+).
+func NewFromFile(file *os.File) (*FD, error) {
+ fd, err := syscall.Dup(int(file.Fd()))
+ // Technically, the runtime may call the finalizer on file as soon as
+ // Fd() returns.
+ runtime.KeepAlive(file)
+ if err != nil {
+ return &FD{ReadWriter{-1}}, err
+ }
+ return New(fd), nil
+}
+
+// Open is equivalent to open(2).
+func Open(path string, openmode int, perm uint32) (*FD, error) {
+ f, err := syscall.Open(path, openmode|syscall.O_LARGEFILE, perm)
+ if err != nil {
+ return nil, err
+ }
+ return New(f), nil
+}
+
+// OpenAt is equivalent to openat(2).
+func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
+ f, err := syscall.Openat(dir.FD(), path, flags, mode)
+ if err != nil {
+ return nil, err
+ }
+ return New(f), nil
+}
+
+// Close closes the file descriptor contained in the FD.
+//
+// Close is safe to call multiple times, but will return an error after the
+// first call.
+//
+// Concurrently calling Close and any other method is undefined.
+func (f *FD) Close() error {
+ runtime.SetFinalizer(f, nil)
+ return syscall.Close(int(atomic.SwapInt64(&f.fd, -1)))
+}
+
+// Release relinquishes ownership of the contained file descriptor.
+//
+// Concurrently calling Release and any other method is undefined.
+func (f *FD) Release() int {
+ runtime.SetFinalizer(f, nil)
+ return int(atomic.SwapInt64(&f.fd, -1))
+}
+
+// FD returns the file descriptor owned by FD. FD retains ownership.
+func (f *FD) FD() int {
+ return int(atomic.LoadInt64(&f.fd))
+}
+
+// File converts the FD to an os.File.
+//
+// FD does not transfer ownership of the file descriptor (it will be
+// duplicated, so both the FD and os.File will eventually need to be closed
+// and some (but not all) changes made to the os.File will be applied to the
+// FD as well).
+//
+// This operation is somewhat expensive, so care should be taken to minimize
+// its use.
+func (f *FD) File() (*os.File, error) {
+ fd, err := syscall.Dup(int(atomic.LoadInt64(&f.fd)))
+ if err != nil {
+ return nil, err
+ }
+ return os.NewFile(uintptr(fd), ""), nil
+}
+
+// ReleaseToFile returns an os.File that takes ownership of the FD.
+//
+// name is passed to os.NewFile.
+func (f *FD) ReleaseToFile(name string) *os.File {
+ return os.NewFile(uintptr(f.Release()), name)
+}
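+
+// A typical ownership round-trip, as a sketch (the path and flags are
+// illustrative only; error handling elided):
+//
+//	f, err := Open("/tmp/example", syscall.O_RDONLY, 0)
+//	if err != nil { ... }
+//	file := f.ReleaseToFile("example") // f no longer owns the descriptor.
+//	defer file.Close()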
diff --git a/pkg/fd/fd_test.go b/pkg/fd/fd_test.go
new file mode 100644
index 000000000..5fb0ad47d
--- /dev/null
+++ b/pkg/fd/fd_test.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fd
+
+import (
+ "math"
+ "os"
+ "syscall"
+ "testing"
+)
+
+func TestSetNegOne(t *testing.T) {
+ type entry struct {
+ name string
+ file *FD
+ fn func() error
+ }
+ var tests []entry
+
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("syscall.Socket:", err)
+ }
+ f1 := New(fd)
+ tests = append(tests, entry{
+ "Release",
+ f1,
+ func() error {
+ return syscall.Close(f1.Release())
+ },
+ })
+
+ fd, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("syscall.Socket:", err)
+ }
+ f2 := New(fd)
+ tests = append(tests, entry{
+ "Close",
+ f2,
+ f2.Close,
+ })
+
+ for _, test := range tests {
+ if err := test.fn(); err != nil {
+ t.Errorf("%s: %v", test.name, err)
+ continue
+ }
+ if fd := test.file.FD(); fd != -1 {
+ t.Errorf("%s: got FD() = %d, want = -1", test.name, fd)
+ }
+ }
+}
+
+func TestStartsNegOne(t *testing.T) {
+ type entry struct {
+ name string
+ file *FD
+ }
+
+ tests := []entry{
+ {"-1", New(-1)},
+ {"-2", New(-2)},
+ {"MinInt32", New(math.MinInt32)},
+ {"MinInt64", New(math.MinInt64)},
+ }
+
+ for _, test := range tests {
+ if fd := test.file.FD(); fd != -1 {
+ t.Errorf("%s: got FD() = %d, want = -1", test.name, fd)
+ }
+ }
+}
+
+func TestFileDotFile(t *testing.T) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("syscall.Socket:", err)
+ }
+
+ f := New(fd)
+ of, err := f.File()
+ if err != nil {
+ t.Fatalf("File got err %v want nil", err)
+ }
+
+ if ofd, nfd := int(of.Fd()), f.FD(); ofd == nfd || ofd == -1 {
+ // Try not to double close the FD.
+ f.Release()
+
+ t.Fatalf("got %#v.File().Fd() = %d, want new FD", f, ofd)
+ }
+
+ f.Close()
+ of.Close()
+}
+
+func TestFileDotFileError(t *testing.T) {
+ f := &FD{ReadWriter{-2}}
+
+ if of, err := f.File(); err == nil {
+ t.Errorf("File %v got nil err want non-nil", of)
+ of.Close()
+ }
+}
+
+func TestNewFromFile(t *testing.T) {
+ f, err := NewFromFile(os.Stdin)
+ if err != nil {
+ t.Fatalf("NewFromFile got err %v want nil", err)
+ }
+ if nfd, ofd := f.FD(), int(os.Stdin.Fd()); nfd == -1 || nfd == ofd {
+ t.Errorf("got FD() = %d, want = new FD (old FD was %d)", nfd, ofd)
+ }
+ f.Close()
+}
+
+func TestNewFromFileError(t *testing.T) {
+ f, err := NewFromFile(nil)
+ if err == nil {
+ t.Errorf("NewFromFile got %v with nil err want non-nil", f)
+ f.Close()
+ }
+}
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
new file mode 100644
index 000000000..d9104ef02
--- /dev/null
+++ b/pkg/fdchannel/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "fdchannel",
+ srcs = ["fdchannel_unsafe.go"],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "fdchannel_test",
+ size = "small",
+ srcs = ["fdchannel_test.go"],
+ library = ":fdchannel",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go
new file mode 100644
index 000000000..7a8a63a59
--- /dev/null
+++ b/pkg/fdchannel/fdchannel_test.go
@@ -0,0 +1,132 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdchannel
+
+import (
+ "io/ioutil"
+ "os"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestSendRecvFD(t *testing.T) {
+ sendFile, err := ioutil.TempFile("", "fdchannel_test_")
+ if err != nil {
+ t.Fatalf("failed to create temporary file: %v", err)
+ }
+ defer sendFile.Close()
+
+ chanFDs, err := NewConnectedSockets()
+ if err != nil {
+ t.Fatalf("failed to create fdchannel sockets: %v", err)
+ }
+ sendEP := NewEndpoint(chanFDs[0])
+ defer sendEP.Destroy()
+ recvEP := NewEndpoint(chanFDs[1])
+ defer recvEP.Destroy()
+
+ recvFD, err := recvEP.RecvFDNonblock()
+ if err != syscall.EAGAIN && err != syscall.EWOULDBLOCK {
+ t.Errorf("RecvFDNonblock before SendFD: got (%d, %v), wanted (<unspecified>, EAGAIN or EWOULDBLOCK", recvFD, err)
+ }
+
+ if err := sendEP.SendFD(int(sendFile.Fd())); err != nil {
+ t.Fatalf("SendFD failed: %v", err)
+ }
+ recvFD, err = recvEP.RecvFD()
+ if err != nil {
+ t.Fatalf("RecvFD failed: %v", err)
+ }
+ recvFile := os.NewFile(uintptr(recvFD), "received file")
+ defer recvFile.Close()
+
+ sendInfo, err := sendFile.Stat()
+ if err != nil {
+ t.Fatalf("failed to stat sent file: %v", err)
+ }
+ sendInfoSys := sendInfo.Sys()
+ sendStat, ok := sendInfoSys.(*syscall.Stat_t)
+ if !ok {
+ t.Fatalf("sent file's FileInfo is backed by unknown type %T", sendInfoSys)
+ }
+
+ recvInfo, err := recvFile.Stat()
+ if err != nil {
+ t.Fatalf("failed to stat received file: %v", err)
+ }
+ recvInfoSys := recvInfo.Sys()
+ recvStat, ok := recvInfoSys.(*syscall.Stat_t)
+ if !ok {
+ t.Fatalf("received file's FileInfo is backed by unknown type %T", recvInfoSys)
+ }
+
+ if sendStat.Dev != recvStat.Dev || sendStat.Ino != recvStat.Ino {
+ t.Errorf("sent file (dev=%d, ino=%d) does not match received file (dev=%d, ino=%d)", sendStat.Dev, sendStat.Ino, recvStat.Dev, recvStat.Ino)
+ }
+}
+
+func TestShutdownThenRecvFD(t *testing.T) {
+ sendFile, err := ioutil.TempFile("", "fdchannel_test_")
+ if err != nil {
+ t.Fatalf("failed to create temporary file: %v", err)
+ }
+ defer sendFile.Close()
+
+ chanFDs, err := NewConnectedSockets()
+ if err != nil {
+ t.Fatalf("failed to create fdchannel sockets: %v", err)
+ }
+ sendEP := NewEndpoint(chanFDs[0])
+ defer sendEP.Destroy()
+ recvEP := NewEndpoint(chanFDs[1])
+ defer recvEP.Destroy()
+
+ recvEP.Shutdown()
+ if _, err := recvEP.RecvFD(); err == nil {
+ t.Error("RecvFD succeeded unexpectedly")
+ }
+}
+
+func TestRecvFDThenShutdown(t *testing.T) {
+ sendFile, err := ioutil.TempFile("", "fdchannel_test_")
+ if err != nil {
+ t.Fatalf("failed to create temporary file: %v", err)
+ }
+ defer sendFile.Close()
+
+ chanFDs, err := NewConnectedSockets()
+ if err != nil {
+ t.Fatalf("failed to create fdchannel sockets: %v", err)
+ }
+ sendEP := NewEndpoint(chanFDs[0])
+ defer sendEP.Destroy()
+ recvEP := NewEndpoint(chanFDs[1])
+ defer recvEP.Destroy()
+
+ var receiverWG sync.WaitGroup
+ receiverWG.Add(1)
+ go func() {
+ defer receiverWG.Done()
+ if _, err := recvEP.RecvFD(); err == nil {
+ t.Error("RecvFD succeeded unexpectedly")
+ }
+ }()
+ defer receiverWG.Wait()
+ time.Sleep(time.Second) // to ensure recvEP.RecvFD() has blocked
+ recvEP.Shutdown()
+}
diff --git a/pkg/fdchannel/fdchannel_unsafe.go b/pkg/fdchannel/fdchannel_unsafe.go
new file mode 100644
index 000000000..367235be5
--- /dev/null
+++ b/pkg/fdchannel/fdchannel_unsafe.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
+
+// Package fdchannel implements passing file descriptors between processes over
+// Unix domain sockets.
+package fdchannel
+
+import (
+ "fmt"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+)
+
+// int32 is the real type of a file descriptor.
+const sizeofInt32 = int(unsafe.Sizeof(int32(0)))
+
+// NewConnectedSockets returns a pair of file descriptors, owned by the caller,
+// representing connected sockets that may be passed to separate calls to
+// NewEndpoint to create connected Endpoints.
+func NewConnectedSockets() ([2]int, error) {
+ return syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC, 0)
+}
+
+// Endpoint sends file descriptors to, and receives them from, another
+// connected Endpoint.
+//
+// Endpoint is not copyable or movable by value.
+type Endpoint struct {
+ sockfd int32 // accessed using atomic memory operations
+ msghdr syscall.Msghdr
+ cmsg *syscall.Cmsghdr // followed by sizeofInt32 bytes of data
+}
+
+// Init must be called on zero-value Endpoints before first use. sockfd must be
+// a blocking AF_UNIX SOCK_SEQPACKET socket.
+func (ep *Endpoint) Init(sockfd int) {
+ // "Datagram sockets in various domains (e.g., the UNIX and Internet
+ // domains) permit zero-length datagrams." - recv(2). Experimentally,
+ // sendmsg+recvmsg for a zero-length datagram is slightly faster than
+ // sendmsg+recvmsg for a single byte over a stream socket.
+ cmsgSlice := make([]byte, syscall.CmsgSpace(sizeofInt32))
+ cmsgReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&cmsgSlice))
+ ep.sockfd = int32(sockfd)
+ ep.msghdr.Control = (*byte)((unsafe.Pointer)(cmsgReflect.Data))
+ ep.cmsg = (*syscall.Cmsghdr)((unsafe.Pointer)(cmsgReflect.Data))
+ // ep.msghdr.Controllen and ep.cmsg.* are mutated by recvmsg(2), so they're
+ // set before calling sendmsg/recvmsg.
+}
+
+// NewEndpoint is a convenience function that returns an initialized Endpoint
+// allocated on the heap.
+func NewEndpoint(sockfd int) *Endpoint {
+ ep := &Endpoint{}
+ ep.Init(sockfd)
+ return ep
+}
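+
+// Typical usage, as a sketch (error handling elided; fd is any open file
+// descriptor the sender owns):
+//
+//	fds, _ := NewConnectedSockets()
+//	sendEP, recvEP := NewEndpoint(fds[0]), NewEndpoint(fds[1])
+//	_ = sendEP.SendFD(fd)
+//	newFD, _ := recvEP.RecvFD() // newFD refers to the same open file as fd.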
+
+// Destroy releases resources owned by ep. No other Endpoint methods may be
+// called after Destroy.
+func (ep *Endpoint) Destroy() {
+ // These need not use sync/atomic since there must not be any concurrent
+ // calls to Endpoint methods.
+ if ep.sockfd >= 0 {
+ syscall.Close(int(ep.sockfd))
+ ep.sockfd = -1
+ }
+}
+
+// Shutdown causes concurrent and future calls to ep.SendFD(), ep.RecvFD(), and
+// ep.RecvFDNonblock(), as well as the same calls in the connected Endpoint, to
+// unblock and return errors. It does not wait for concurrent calls to return.
+//
+// Shutdown is the only Endpoint method that may be called concurrently with
+// other methods.
+func (ep *Endpoint) Shutdown() {
+ if sockfd := int(atomic.SwapInt32(&ep.sockfd, -1)); sockfd >= 0 {
+ syscall.Shutdown(sockfd, syscall.SHUT_RDWR)
+ syscall.Close(sockfd)
+ }
+}
+
+// SendFD sends the open file description represented by the given file
+// descriptor to the connected Endpoint.
+func (ep *Endpoint) SendFD(fd int) error {
+ cmsgLen := syscall.CmsgLen(sizeofInt32)
+ ep.cmsg.Level = syscall.SOL_SOCKET
+ ep.cmsg.Type = syscall.SCM_RIGHTS
+ ep.cmsg.SetLen(cmsgLen)
+ *ep.cmsgData() = int32(fd)
+ ep.msghdr.SetControllen(cmsgLen)
+ _, _, e := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(atomic.LoadInt32(&ep.sockfd)), uintptr((unsafe.Pointer)(&ep.msghdr)), 0)
+ if e != 0 {
+ return e
+ }
+ return nil
+}
+
+// RecvFD receives an open file description from the connected Endpoint and
+// returns a file descriptor representing it, owned by the caller.
+func (ep *Endpoint) RecvFD() (int, error) {
+ return ep.recvFD(0)
+}
+
+// RecvFDNonblock receives an open file description from the connected Endpoint
+// and returns a file descriptor representing it, owned by the caller. If there
+// are no pending receivable open file descriptions, RecvFDNonblock returns
+// (<unspecified>, EAGAIN or EWOULDBLOCK).
+func (ep *Endpoint) RecvFDNonblock() (int, error) {
+ return ep.recvFD(syscall.MSG_DONTWAIT)
+}
+
+func (ep *Endpoint) recvFD(flags uintptr) (int, error) {
+ cmsgLen := syscall.CmsgLen(sizeofInt32)
+ ep.msghdr.SetControllen(cmsgLen)
+ _, _, e := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(atomic.LoadInt32(&ep.sockfd)), uintptr((unsafe.Pointer)(&ep.msghdr)), flags|syscall.MSG_TRUNC)
+ if e != 0 {
+ return -1, e
+ }
+ if int(ep.msghdr.Controllen) != cmsgLen {
+ return -1, fmt.Errorf("received control message has incorrect length: got %d, wanted %d", ep.msghdr.Controllen, cmsgLen)
+ }
+ if ep.cmsg.Level != syscall.SOL_SOCKET || ep.cmsg.Type != syscall.SCM_RIGHTS {
+ return -1, fmt.Errorf("received control message has incorrect (level, type): got (%v, %v), wanted (%v, %v)", ep.cmsg.Level, ep.cmsg.Type, syscall.SOL_SOCKET, syscall.SCM_RIGHTS)
+ }
+ return int(*ep.cmsgData()), nil
+}
+
+func (ep *Endpoint) cmsgData() *int32 {
+ // syscall.CmsgLen(0) == syscall.cmsgAlignOf(syscall.SizeofCmsghdr)
+ return (*int32)((unsafe.Pointer)(uintptr((unsafe.Pointer)(ep.cmsg)) + uintptr(syscall.CmsgLen(0))))
+}
diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD
new file mode 100644
index 000000000..235dcc490
--- /dev/null
+++ b/pkg/fdnotifier/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fdnotifier",
+ srcs = [
+ "fdnotifier.go",
+ "poll_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
new file mode 100644
index 000000000..a6b63c982
--- /dev/null
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package fdnotifier contains an adapter that translates IO events (e.g., a
+// file became readable/writable) from native FDs to the notifications in the
+// waiter package. It uses epoll in edge-triggered mode to receive notifications
+// for registered FDs.
+package fdnotifier
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type fdInfo struct {
+ queue *waiter.Queue
+ waiting bool
+}
+
+// notifier holds all the state necessary to issue notifications when IO events
+// occur in the observed FDs.
+type notifier struct {
+ // epFD is the epoll file descriptor used to register for io
+ // notifications.
+ epFD int
+
+ // mu protects fdMap.
+ mu sync.Mutex
+
+ // fdMap maps file descriptors to their notification queues and waiting
+ // status.
+ fdMap map[int32]*fdInfo
+}
+
+// newNotifier creates a new notifier object.
+func newNotifier() (*notifier, error) {
+ epfd, err := syscall.EpollCreate1(0)
+ if err != nil {
+ return nil, err
+ }
+
+ w := &notifier{
+ epFD: epfd,
+ fdMap: make(map[int32]*fdInfo),
+ }
+
+ go w.waitAndNotify() // S/R-SAFE: no waiter exists during save / load.
+
+ return w, nil
+}
+
+// waitFD waits on mask for fd. The fdMap mutex must be held.
+func (n *notifier) waitFD(fd int32, fi *fdInfo, mask waiter.EventMask) error {
+ if !fi.waiting && mask == 0 {
+ return nil
+ }
+
+ e := syscall.EpollEvent{
+ Events: mask.ToLinux() | unix.EPOLLET,
+ Fd: fd,
+ }
+
+ switch {
+ case !fi.waiting && mask != 0:
+ if err := syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_ADD, int(fd), &e); err != nil {
+ return err
+ }
+ fi.waiting = true
+ case fi.waiting && mask == 0:
+ syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_DEL, int(fd), nil)
+ fi.waiting = false
+ case fi.waiting && mask != 0:
+ if err := syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_MOD, int(fd), &e); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// addFD adds an FD to the list of FDs observed by n.
+func (n *notifier) addFD(fd int32, queue *waiter.Queue) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Panic if we're already notifying on this FD.
+ if _, ok := n.fdMap[fd]; ok {
+ panic(fmt.Sprintf("File descriptor %v added twice", fd))
+ }
+
+ // We have nothing to wait for at the moment. Just add it to the map.
+ n.fdMap[fd] = &fdInfo{queue: queue}
+}
+
+// updateFD updates the set of events the fd needs to be notified on.
+func (n *notifier) updateFD(fd int32) error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if fi, ok := n.fdMap[fd]; ok {
+ return n.waitFD(fd, fi, fi.queue.Events())
+ }
+
+ return nil
+}
+
+// removeFD removes an FD from the list of FDs observed by n.
+func (n *notifier) removeFD(fd int32) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Remove from the epoll object, then from the map.
+ n.waitFD(fd, n.fdMap[fd], 0)
+ delete(n.fdMap, fd)
+}
+
+// hasFD returns true if the fd is in the list of observed FDs.
+func (n *notifier) hasFD(fd int32) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ _, ok := n.fdMap[fd]
+ return ok
+}
+
+// waitAndNotify runs in its own goroutine and loops, waiting for io event
+// notifications from the epoll object. Once notifications arrive, they are
+// dispatched to the registered queue.
+func (n *notifier) waitAndNotify() error {
+ e := make([]syscall.EpollEvent, 100)
+ for {
+ v, err := epollWait(n.epFD, e, -1)
+ if err == syscall.EINTR {
+ continue
+ }
+
+ if err != nil {
+ return err
+ }
+
+ n.mu.Lock()
+ for i := 0; i < v; i++ {
+ if fi, ok := n.fdMap[e[i].Fd]; ok {
+ fi.queue.Notify(waiter.EventMaskFromLinux(e[i].Events))
+ }
+ }
+ n.mu.Unlock()
+ }
+}
+
+var shared struct {
+ notifier *notifier
+ once sync.Once
+ initErr error
+}
+
+// AddFD adds an FD to the list of observed FDs.
+func AddFD(fd int32, queue *waiter.Queue) error {
+ shared.once.Do(func() {
+ shared.notifier, shared.initErr = newNotifier()
+ })
+
+ if shared.initErr != nil {
+ return shared.initErr
+ }
+
+ shared.notifier.addFD(fd, queue)
+ return nil
+}
+
+// UpdateFD updates the set of events the fd needs to be notified on.
+func UpdateFD(fd int32) error {
+ return shared.notifier.updateFD(fd)
+}
+
+// RemoveFD removes an FD from the list of observed FDs.
+func RemoveFD(fd int32) {
+ shared.notifier.removeFD(fd)
+}
+
+// HasFD returns true if the FD is in the list of observed FDs.
+//
+// This should only be used by tests to assert that FDs are correctly registered.
+func HasFD(fd int32) bool {
+ return shared.notifier.hasFD(fd)
+}
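+
+// Putting the package-level API together, a sketch of typical use (the fd
+// and queue here are illustrative; error handling elided):
+//
+//	var q waiter.Queue
+//	if err := AddFD(fd, &q); err != nil { ... }
+//	// After waiters register interest in events on q:
+//	_ = UpdateFD(fd)
+//	// When the fd is closed or no longer observed:
+//	RemoveFD(fd)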
diff --git a/pkg/fdnotifier/poll_unsafe.go b/pkg/fdnotifier/poll_unsafe.go
new file mode 100644
index 000000000..4225b04dd
--- /dev/null
+++ b/pkg/fdnotifier/poll_unsafe.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdnotifier
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// NonBlockingPoll polls the given FD in non-blocking fashion. It is used just
+// to query the FD's current state.
+func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask {
+ e := struct {
+ fd int32
+ events int16
+ revents int16
+ }{
+ fd: fd,
+ events: int16(mask.ToLinux()),
+ }
+
+ ts := syscall.Timespec{
+ Sec: 0,
+ Nsec: 0,
+ }
+
+ for {
+ n, _, err := syscall.RawSyscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(&e)), 1,
+ uintptr(unsafe.Pointer(&ts)), 0, 0, 0)
+ // Interrupted by signal, try again.
+ if err == syscall.EINTR {
+ continue
+ }
+ // If an error occurs, we'll conservatively say the FD is ready for
+ // whatever is being checked.
+ if err != 0 {
+ return mask
+ }
+
+ // If no FDs were returned, it wasn't ready for anything.
+ if n == 0 {
+ return 0
+ }
+
+ // Otherwise we got the ready events in the revents field.
+ return waiter.EventMaskFromLinux(uint32(e.revents))
+ }
+}
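+
+// For example, to check whether fd is currently readable (a sketch):
+//
+//	if NonBlockingPoll(fd, waiter.EventIn)&waiter.EventIn != 0 {
+//		// fd has data available to read.
+//	}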
+
+// epollWait performs a blocking wait on epfd.
+//
+// Preconditions:
+// * len(events) > 0
+func epollWait(epfd int, events []syscall.EpollEvent, msec int) (int, error) {
+ if len(events) == 0 {
+ panic("Empty events passed to EpollWait")
+ }
+
+ // We actually use epoll_pwait with NULL sigmask instead of epoll_wait
+ // since that is what the Go >= 1.11 runtime prefers.
+ r, _, e := syscall.Syscall6(syscall.SYS_EPOLL_PWAIT, uintptr(epfd), uintptr(unsafe.Pointer(&events[0])), uintptr(len(events)), uintptr(msec), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return int(r), nil
+}
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
new file mode 100644
index 000000000..aa8e4e1f3
--- /dev/null
+++ b/pkg/flipcall/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "flipcall",
+ srcs = [
+ "ctrl_futex.go",
+ "flipcall.go",
+ "flipcall_unsafe.go",
+ "futex_linux.go",
+ "io.go",
+ "packet_window_allocator.go",
+ "packet_window_mmap.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/log",
+ "//pkg/memutil",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "flipcall_test",
+ size = "small",
+ srcs = [
+ "flipcall_example_test.go",
+ "flipcall_test.go",
+ ],
+ library = ":flipcall",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
new file mode 100644
index 000000000..e7c3a3a0b
--- /dev/null
+++ b/pkg/flipcall/ctrl_futex.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+type endpointControlImpl struct {
+ state int32
+}
+
+// Bits in endpointControlImpl.state.
+const (
+ epsBlocked = 1 << iota
+ epsShutdown
+)
+
+func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error {
+ if len(opts) != 0 {
+ return fmt.Errorf("unknown EndpointOption: %T", opts[0])
+ }
+ return nil
+}
+
+type ctrlHandshakeRequest struct{}
+
+type ctrlHandshakeResponse struct{}
+
+func (ep *Endpoint) ctrlConnect() error {
+ if err := ep.enterFutexWait(); err != nil {
+ return err
+ }
+ _, err := ep.futexConnect(&ctrlHandshakeRequest{})
+ ep.exitFutexWait()
+ return err
+}
+
+func (ep *Endpoint) ctrlWaitFirst() error {
+ if err := ep.enterFutexWait(); err != nil {
+ return err
+ }
+ defer ep.exitFutexWait()
+
+ // Wait for the handshake request.
+ if err := ep.futexSwitchFromPeer(); err != nil {
+ return err
+ }
+
+ // Read the handshake request.
+ reqLen := atomic.LoadUint32(ep.dataLen())
+ if reqLen > ep.dataCap {
+ return fmt.Errorf("invalid handshake request length %d (maximum %d)", reqLen, ep.dataCap)
+ }
+ var req ctrlHandshakeRequest
+ if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil {
+ return fmt.Errorf("error reading handshake request: %v", err)
+ }
+
+ // Write the handshake response.
+ w := ep.NewWriter()
+ if err := json.NewEncoder(w).Encode(ctrlHandshakeResponse{}); err != nil {
+ return fmt.Errorf("error writing handshake response: %v", err)
+ }
+ *ep.dataLen() = w.Len()
+
+ // Return control to the client.
+ raceBecomeInactive()
+ if err := ep.futexSwitchToPeer(); err != nil {
+ return err
+ }
+
+ // Wait for the first non-handshake message.
+ return ep.futexSwitchFromPeer()
+}
+
+func (ep *Endpoint) ctrlRoundTrip() error {
+ if err := ep.futexSwitchToPeer(); err != nil {
+ return err
+ }
+ if err := ep.enterFutexWait(); err != nil {
+ return err
+ }
+ err := ep.futexSwitchFromPeer()
+ ep.exitFutexWait()
+ return err
+}
+
+func (ep *Endpoint) ctrlWakeLast() error {
+ return ep.futexSwitchToPeer()
+}
+
+func (ep *Endpoint) enterFutexWait() error {
+ switch eps := atomic.AddInt32(&ep.ctrl.state, epsBlocked); eps {
+ case epsBlocked:
+ return nil
+ case epsBlocked | epsShutdown:
+ atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+ return ShutdownError{}
+ default:
+ // Most likely due to ep.enterFutexWait() being called concurrently
+ // from multiple goroutines.
+ panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state before flipcall.Endpoint.enterFutexWait(): %v", eps-epsBlocked))
+ }
+}
+
+func (ep *Endpoint) exitFutexWait() {
+ switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps {
+ case 0:
+ return
+ case epsShutdown:
+ // ep.ctrlShutdown() was called while we were blocked, so we are
+ // responsible for indicating connection shutdown.
+ ep.shutdownConn()
+ default:
+ panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked))
+ }
+}
+
+func (ep *Endpoint) ctrlShutdown() {
+ // Set epsShutdown to ensure that future calls to ep.enterFutexWait() fail.
+ if atomic.AddInt32(&ep.ctrl.state, epsShutdown)&epsBlocked != 0 {
+ // Wake the blocked thread. This must loop because it's possible that
+ // FUTEX_WAKE occurs after the waiter sets epsBlocked, but before it
+ // blocks in FUTEX_WAIT.
+ for {
+ // Wake MaxInt32 threads to prevent a broken or malicious peer from
+ // swallowing our wakeup by FUTEX_WAITing from multiple threads.
+ if err := ep.futexWakeConnState(math.MaxInt32); err != nil {
+ log.Warningf("failed to FUTEX_WAKE Endpoints: %v", err)
+ break
+ }
+ yieldThread()
+ if atomic.LoadInt32(&ep.ctrl.state)&epsBlocked == 0 {
+ break
+ }
+ }
+ } else {
+ // There is no blocked thread, so we are responsible for indicating
+ // connection shutdown.
+ ep.shutdownConn()
+ }
+}
+
+func (ep *Endpoint) shutdownConn() {
+ switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs {
+ case ep.activeState:
+ if err := ep.futexWakeConnState(1); err != nil {
+ log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err)
+ }
+ case ep.inactiveState:
+ // The peer is currently active and will detect shutdown when it tries
+ // to update the connection state.
+ case csShutdown:
+ // The peer also called Endpoint.Shutdown().
+ default:
+ log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs)
+ }
+}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
new file mode 100644
index 000000000..ec742c091
--- /dev/null
+++ b/pkg/flipcall/flipcall.go
@@ -0,0 +1,257 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package flipcall implements a protocol providing Fast Local Interprocess
+// Procedure Calls between mutually-distrusting processes.
+package flipcall
+
+import (
+ "fmt"
+ "math"
+ "sync/atomic"
+ "syscall"
+)
+
+// An Endpoint provides the ability to synchronously transfer data and control
+// to a connected peer Endpoint, which may be in another process.
+//
+// Since the Endpoint control transfer model is synchronous, at any given time
+// one Endpoint "has control" (designated the active Endpoint), and the other
+// is "waiting for control" (designated the inactive Endpoint). Users of the
+// flipcall package designate one Endpoint as the client, which is initially
+// active, and the other as the server, which is initially inactive. See
+// flipcall_example_test.go for usage.
+type Endpoint struct {
+ // packet is a pointer to the beginning of the packet window. (Since this
+ // is a raw OS memory mapping and not a Go object, it does not need to be
+ // represented as an unsafe.Pointer.) packet is immutable.
+ packet uintptr
+
+ // dataCap is the size of the datagram part of the packet window in bytes.
+ // dataCap is immutable.
+ dataCap uint32
+
+ // activeState is csClientActive if this is a client Endpoint and
+ // csServerActive if this is a server Endpoint.
+ activeState uint32
+
+ // inactiveState is csServerActive if this is a client Endpoint and
+ // csClientActive if this is a server Endpoint.
+ inactiveState uint32
+
+ // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
+ // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
+ // accessed using atomic memory operations.
+ shutdown uint32
+
+ ctrl endpointControlImpl
+}
+
+// EndpointSide indicates which side of a connection an Endpoint belongs to.
+type EndpointSide int
+
+const (
+ // ClientSide indicates that an Endpoint is a client (initially-active;
+ // first method call should be Connect).
+ ClientSide EndpointSide = iota
+
+ // ServerSide indicates that an Endpoint is a server (initially-inactive;
+ // first method call should be RecvFirst).
+ ServerSide
+)
+
+// Init must be called on zero-value Endpoints before first use. If it
+// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
+//
+// pwd represents the packet window used to exchange data with the peer
+// Endpoint. FD may differ between Endpoints if they are in different
+// processes, but must represent the same file. The packet window must
+// initially be filled with zero bytes.
+func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+ switch side {
+ case ClientSide:
+ ep.activeState = csClientActive
+ ep.inactiveState = csServerActive
+ case ServerSide:
+ ep.activeState = csServerActive
+ ep.inactiveState = csClientActive
+ default:
+ return fmt.Errorf("invalid EndpointSide: %v", side)
+ }
+ if pwd.Length < pageSize {
+ return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
+ }
+ if pwd.Length > math.MaxUint32 {
+ return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
+ }
+ m, e := packetWindowMmap(pwd)
+ if e != 0 {
+ return fmt.Errorf("failed to mmap packet window: %v", e)
+ }
+ ep.packet = m
+ ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
+ if err := ep.ctrlInit(opts...); err != nil {
+ ep.unmapPacket()
+ return err
+ }
+ return nil
+}
+
+// NewEndpoint is a convenience function that returns an initialized Endpoint
+// allocated on the heap.
+func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
+ var ep Endpoint
+ if err := ep.Init(side, pwd, opts...); err != nil {
+ return nil, err
+ }
+ return &ep, nil
+}
+
+// An EndpointOption configures an Endpoint.
+type EndpointOption interface {
+ isEndpointOption()
+}
+
+// Destroy releases resources owned by ep. No other Endpoint methods may be
+// called after Destroy.
+func (ep *Endpoint) Destroy() {
+ ep.unmapPacket()
+}
+
+func (ep *Endpoint) unmapPacket() {
+ syscall.RawSyscall(syscall.SYS_MUNMAP, ep.packet, uintptr(ep.dataCap)+PacketHeaderBytes, 0)
+ ep.packet = 0
+}
+
+// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
+// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
+// Endpoint, to unblock and return ShutdownErrors. It does not wait for
+// concurrent calls to return. Successive calls to Shutdown have no effect.
+//
+// Shutdown is the only Endpoint method that may be called concurrently with
+// other methods on the same Endpoint.
+func (ep *Endpoint) Shutdown() {
+ if atomic.SwapUint32(&ep.shutdown, 1) != 0 {
+ // ep.Shutdown() has previously been called.
+ return
+ }
+ ep.ctrlShutdown()
+}
+
+// isShutdownLocally returns true if ep.Shutdown() has been called.
+func (ep *Endpoint) isShutdownLocally() bool {
+ return atomic.LoadUint32(&ep.shutdown) != 0
+}
+
+// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown()
+// has been called.
+type ShutdownError struct{}
+
+// Error implements error.Error.
+func (ShutdownError) Error() string {
+ return "flipcall connection shutdown"
+}
+
+// DataCap returns the maximum datagram size supported by ep. Equivalently,
+// DataCap returns len(ep.Data()).
+func (ep *Endpoint) DataCap() uint32 {
+ return ep.dataCap
+}
+
+// Connection state.
+const (
+ // The client is, by definition, initially active, so this must be 0.
+ csClientActive = 0
+ csServerActive = 1
+ csShutdown = 2
+)
+
+// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst().
+//
+// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(),
+// ep.SendRecv(), and ep.SendLast() have never been called.
+func (ep *Endpoint) Connect() error {
+ err := ep.ctrlConnect()
+ if err == nil {
+ raceBecomeActive()
+ }
+ return err
+}
+
+// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
+// returns the datagram length specified by that call.
+//
+// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and
+// ep.SendLast() have never been called.
+func (ep *Endpoint) RecvFirst() (uint32, error) {
+ if err := ep.ctrlWaitFirst(); err != nil {
+ return 0, err
+ }
+ raceBecomeActive()
+ recvDataLen := atomic.LoadUint32(ep.dataLen())
+ if recvDataLen > ep.dataCap {
+ return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
+ }
+ return recvDataLen, nil
+}
+
+// SendRecv transfers control to the peer Endpoint, causing its call to
+// Endpoint.SendRecv() or Endpoint.RecvFirst() to return with the given
+// datagram length, then blocks until the peer Endpoint calls
+// Endpoint.SendRecv() or Endpoint.SendLast().
+//
+// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or
+// ep.RecvFirst() has returned an error. ep.SendLast() has never been called.
+// If ep is a client Endpoint, ep.Connect() has previously been called and
+// returned nil.
+func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
+ if dataLen > ep.dataCap {
+ panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
+ }
+ // This store can safely be non-atomic: Under correct operation we should
+ // be the only thread writing ep.dataLen(), and ep.ctrlRoundTrip() will
+ // synchronize with the receiver. We will not read from ep.dataLen() until
+ // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then
+ // they can only shoot themselves in the foot.
+ *ep.dataLen() = dataLen
+ raceBecomeInactive()
+ if err := ep.ctrlRoundTrip(); err != nil {
+ return 0, err
+ }
+ raceBecomeActive()
+ recvDataLen := atomic.LoadUint32(ep.dataLen())
+ if recvDataLen > ep.dataCap {
+ return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
+ }
+ return recvDataLen, nil
+}
+
+// SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or
+// Endpoint.RecvFirst() to return with the given datagram length.
+//
+// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or
+// ep.RecvFirst() has returned an error. ep.SendLast() has never been called.
+// If ep is a client Endpoint, ep.Connect() has previously been called and
+// returned nil.
+func (ep *Endpoint) SendLast(dataLen uint32) error {
+ if dataLen > ep.dataCap {
+ panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
+ }
+ *ep.dataLen() = dataLen
+ raceBecomeInactive()
+ if err := ep.ctrlWakeLast(); err != nil {
+ return err
+ }
+ return nil
+}
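
The Destroy and Shutdown contracts above imply a fixed teardown order: unblock any goroutine still inside a call, wait for it to return, and only then release the mapping. A minimal sketch, assuming a server goroutine tracked by a WaitGroup as in the tests below (the helper name is illustrative, not part of this change):

    // teardown unblocks the server goroutine, waits for it to observe the
    // resulting ShutdownError and return, then releases ep's resources.
    func teardown(serverEP *Endpoint, serverRun *sync.WaitGroup) {
        serverEP.Shutdown() // concurrent RecvFirst/SendRecv now fail
        serverRun.Wait()    // no Endpoint methods may still be in flight...
        serverEP.Destroy()  // ...before Destroy is called
    }
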
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
new file mode 100644
index 000000000..2e28a149a
--- /dev/null
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -0,0 +1,113 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func Example() {
+ const (
+ reqPrefix = "request "
+ respPrefix = "response "
+ count = 3
+ maxMessageLen = len(respPrefix) + 1 // 1 digit
+ )
+
+ pwa, err := NewPacketWindowAllocator()
+ if err != nil {
+ panic(err)
+ }
+ defer pwa.Destroy()
+ pwd, err := pwa.Allocate(PacketWindowLengthForDataCap(uint32(maxMessageLen)))
+ if err != nil {
+ panic(err)
+ }
+ var clientEP Endpoint
+ if err := clientEP.Init(ClientSide, pwd); err != nil {
+ panic(err)
+ }
+ defer clientEP.Destroy()
+ var serverEP Endpoint
+ if err := serverEP.Init(ServerSide, pwd); err != nil {
+ panic(err)
+ }
+ defer serverEP.Destroy()
+
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ i := 0
+ var buf bytes.Buffer
+ // wait for first request
+ n, err := serverEP.RecvFirst()
+ if err != nil {
+ return
+ }
+ for {
+ // read request
+ buf.Reset()
+ buf.Write(serverEP.Data()[:n])
+ fmt.Println(buf.String())
+ // write response
+ buf.Reset()
+ fmt.Fprintf(&buf, "%s%d", respPrefix, i)
+ copy(serverEP.Data(), buf.Bytes())
+ // send response and wait for next request
+ n, err = serverEP.SendRecv(uint32(buf.Len()))
+ if err != nil {
+ return
+ }
+ i++
+ }
+ }()
+ defer func() {
+ serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+
+ // establish connection as client
+ if err := clientEP.Connect(); err != nil {
+ panic(err)
+ }
+ var buf bytes.Buffer
+ for i := 0; i < count; i++ {
+ // write request
+ buf.Reset()
+ fmt.Fprintf(&buf, "%s%d", reqPrefix, i)
+ copy(clientEP.Data(), buf.Bytes())
+ // send request and wait for response
+ n, err := clientEP.SendRecv(uint32(buf.Len()))
+ if err != nil {
+ panic(err)
+ }
+ // read response
+ buf.Reset()
+ buf.Write(clientEP.Data()[:n])
+ fmt.Println(buf.String())
+ }
+
+ // Output:
+ // request 0
+ // response 0
+ // request 1
+ // response 1
+ // request 2
+ // response 2
+}
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
new file mode 100644
index 000000000..33fd55a44
--- /dev/null
+++ b/pkg/flipcall/flipcall_test.go
@@ -0,0 +1,405 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var testPacketWindowSize = pageSize
+
+type testConnection struct {
+ pwa PacketWindowAllocator
+ clientEP Endpoint
+ serverEP Endpoint
+}
+
+func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection {
+ c := &testConnection{}
+ if err := c.pwa.Init(); err != nil {
+ tb.Fatalf("failed to create PacketWindowAllocator: %v", err)
+ }
+ pwd, err := c.pwa.Allocate(testPacketWindowSize)
+ if err != nil {
+ c.pwa.Destroy()
+ tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
+ }
+ if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil {
+ c.pwa.Destroy()
+ tb.Fatalf("failed to create client Endpoint: %v", err)
+ }
+ if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil {
+ c.pwa.Destroy()
+ c.clientEP.Destroy()
+ tb.Fatalf("failed to create server Endpoint: %v", err)
+ }
+ return c
+}
+
+func newTestConnection(tb testing.TB) *testConnection {
+ return newTestConnectionWithOptions(tb, nil, nil)
+}
+
+func (c *testConnection) destroy() {
+ c.pwa.Destroy()
+ c.clientEP.Destroy()
+ c.serverEP.Destroy()
+}
+
+func testSendRecv(t *testing.T, c *testConnection) {
+ // This shared variable is used to confirm that synchronization between
+ // flipcall endpoints is visible to the Go race detector.
+ state := 0
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ t.Logf("server Endpoint waiting for packet 1")
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
+ }
+ state++
+ if state != 2 {
+ t.Errorf("shared state counter: got %d, wanted 2", state)
+ }
+ t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
+ if _, err := c.serverEP.SendRecv(0); err != nil {
+ t.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
+ }
+ state++
+ if state != 4 {
+ t.Errorf("shared state counter: got %d, wanted 4", state)
+ }
+ t.Logf("server Endpoint got packet 3")
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+
+ t.Logf("client Endpoint establishing connection")
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ state++
+ if state != 1 {
+ t.Errorf("shared state counter: got %d, wanted 1", state)
+ }
+ t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
+ if _, err := c.clientEP.SendRecv(0); err != nil {
+ t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
+ }
+ state++
+ if state != 3 {
+ t.Errorf("shared state counter: got %d, wanted 3", state)
+ }
+ t.Logf("client Endpoint got packet 2, sending packet 3")
+ if err := c.clientEP.SendLast(0); err != nil {
+ t.Fatalf("client Endpoint.SendLast() failed: %v", err)
+ }
+ t.Logf("waiting for server goroutine to complete")
+ serverRun.Wait()
+}
+
+func TestSendRecv(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testSendRecv(t, c)
+}
+
+func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ if err := c.clientEP.Connect(); err == nil {
+ t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, false)
+}
+
+func TestShutdownBeforeConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, true)
+}
+
+func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var clientRun sync.WaitGroup
+ clientRun.Add(1)
+ go func() {
+ defer clientRun.Done()
+ if err := c.clientEP.Connect(); err == nil {
+ t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+ }
+ }()
+ time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ clientRun.Wait()
+}
+
+func TestShutdownDuringConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, false)
+}
+
+func TestShutdownDuringConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, true)
+}
+
+func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeRecvFirstLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeRecvFirst(t, c, false)
+}
+
+func TestShutdownBeforeRecvFirstRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeRecvFirst(t, c, true)
+}
+
+func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ }
+ }()
+ time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ serverRun.Wait()
+}
+
+func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstBeforeConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstBeforeConnect(t, c, true)
+}
+
+func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ serverRun.Wait()
+}
+
+func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, true)
+}
+
+func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ }
+ // At this point, the client must be blocked in c.clientEP.SendRecv().
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ if _, err := c.clientEP.SendRecv(0); err == nil {
+ t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownDuringClientSendRecvLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringClientSendRecv(t, c, false)
+}
+
+func TestShutdownDuringClientSendRecvRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringClientSendRecv(t, c, true)
+}
+
+func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
+ }
+ if _, err := c.serverEP.SendRecv(0); err == nil {
+ t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ if _, err := c.clientEP.SendRecv(0); err != nil {
+ t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
+ }
+ time.Sleep(time.Second) // to allow serverEP.SendRecv() to block
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ serverRun.Wait()
+}
+
+func TestShutdownDuringServerSendRecvLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringServerSendRecv(t, c, false)
+}
+
+func TestShutdownDuringServerSendRecvRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringServerSendRecv(t, c, true)
+}
+
+func benchmarkSendRecv(b *testing.B, c *testConnection) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if b.N == 0 {
+ return
+ }
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ b.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
+ }
+ for i := 1; i < b.N; i++ {
+ if _, err := c.serverEP.SendRecv(0); err != nil {
+ b.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
+ }
+ }
+ if err := c.serverEP.SendLast(0); err != nil {
+ b.Errorf("server Endpoint.SendLast() failed: %v", err)
+ }
+ }()
+ defer func() {
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+
+ if err := c.clientEP.Connect(); err != nil {
+ b.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if _, err := c.clientEP.SendRecv(0); err != nil {
+ b.Fatalf("client Endpoint.SendRecv() failed: %v", err)
+ }
+ }
+ b.StopTimer()
+}
+
+func BenchmarkSendRecv(b *testing.B) {
+ c := newTestConnection(b)
+ defer c.destroy()
+ benchmarkSendRecv(b, c)
+}
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
new file mode 100644
index 000000000..ac974b232
--- /dev/null
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -0,0 +1,87 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "reflect"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Packets consist of a 16-byte header followed by an arbitrarily-sized
+// datagram. The header consists of:
+//
+// - A 4-byte native-endian connection state.
+//
+// - A 4-byte native-endian datagram length in bytes.
+//
+// - 8 reserved bytes.
+const (
+ // PacketHeaderBytes is the size of a flipcall packet header in bytes. The
+ // maximum datagram size supported by a flipcall connection is equal to the
+ // length of the packet window minus PacketHeaderBytes.
+ //
+ // PacketHeaderBytes is exported to support its use in constant
+ // expressions. Non-constant expressions may prefer to use
+ // PacketWindowLengthForDataCap().
+ PacketHeaderBytes = 16
+)
+
+func (ep *Endpoint) connState() *uint32 {
+ return (*uint32)((unsafe.Pointer)(ep.packet))
+}
+
+func (ep *Endpoint) dataLen() *uint32 {
+ return (*uint32)((unsafe.Pointer)(ep.packet + 4))
+}
+
+// Data returns the datagram part of ep's packet window as a byte slice.
+//
+// Note that the packet window is shared with the potentially-untrusted peer
+// Endpoint, which may concurrently mutate the contents of the packet window.
+// Thus:
+//
+// - Readers must not assume that two reads of the same byte in Data() will
+// return the same result. In other words, readers should read any given byte
+// in Data() at most once.
+//
+// - Writers must not assume that they will read back the same data that they
+// have written. In other words, writers should avoid reading from Data() at
+// all.
+func (ep *Endpoint) Data() []byte {
+ var bs []byte
+ bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs))
+ bsReflect.Data = ep.packet + PacketHeaderBytes
+ bsReflect.Len = int(ep.dataCap)
+ bsReflect.Cap = int(ep.dataCap)
+ return bs
+}
+
+// ioSync is a dummy variable used to indicate synchronization to the Go race
+// detector. Compare syscall.ioSync.
+var ioSync int64
+
+func raceBecomeActive() {
+ if sync.RaceEnabled {
+ sync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ }
+}
+
+func raceBecomeInactive() {
+ if sync.RaceEnabled {
+ sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ }
+}
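
The safest way to honor the read-at-most-once rule above is to copy the datagram out of the shared window before validating or parsing it. A sketch (the helper is hypothetical, written as if in package flipcall):

    // recvCopy snapshots the n datagram bytes into private memory so that a
    // peer concurrently rewriting the window cannot race later parsing.
    func recvCopy(ep *Endpoint, n uint32) []byte {
        buf := make([]byte, n)
        copy(buf, ep.Data()[:n]) // each shared byte is read exactly once
        return buf
    }
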
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
new file mode 100644
index 000000000..168c1ccff
--- /dev/null
+++ b/pkg/flipcall/futex_linux.go
@@ -0,0 +1,118 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package flipcall
+
+import (
+ "encoding/json"
+ "fmt"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeResponse, error) {
+ var resp ctrlHandshakeResponse
+
+ // Write the handshake request.
+ w := ep.NewWriter()
+ if err := json.NewEncoder(w).Encode(req); err != nil {
+ return resp, fmt.Errorf("error writing handshake request: %v", err)
+ }
+ *ep.dataLen() = w.Len()
+
+ // Exchange control with the server.
+ if err := ep.futexSwitchToPeer(); err != nil {
+ return resp, err
+ }
+ if err := ep.futexSwitchFromPeer(); err != nil {
+ return resp, err
+ }
+
+ // Read the handshake response.
+ respLen := atomic.LoadUint32(ep.dataLen())
+ if respLen > ep.dataCap {
+ return resp, fmt.Errorf("invalid handshake response length %d (maximum %d)", respLen, ep.dataCap)
+ }
+ if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil {
+ return resp, fmt.Errorf("error reading handshake response: %v", err)
+ }
+
+ return resp, nil
+}
+
+func (ep *Endpoint) futexSwitchToPeer() error {
+ // Update connection state to indicate that the peer should be active.
+ if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
+ switch cs := atomic.LoadUint32(ep.connState()); cs {
+ case csShutdown:
+ return ShutdownError{}
+ default:
+ return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
+ }
+ }
+
+ // Wake the peer's Endpoint.futexSwitchFromPeer().
+ if err := ep.futexWakeConnState(1); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err)
+ }
+ return nil
+}
+
+func (ep *Endpoint) futexSwitchFromPeer() error {
+ for {
+ switch cs := atomic.LoadUint32(ep.connState()); cs {
+ case ep.activeState:
+ return nil
+ case ep.inactiveState:
+ if ep.isShutdownLocally() {
+ return ShutdownError{}
+ }
+ if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
+ }
+ continue
+ case csShutdown:
+ return ShutdownError{}
+ default:
+ return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
+ }
+ }
+}
+
+func (ep *Endpoint) futexWakeConnState(numThreads int32) error {
+ if _, _, e := syscall.RawSyscall(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAKE, uintptr(numThreads)); e != 0 {
+ return e
+ }
+ return nil
+}
+
+func (ep *Endpoint) futexWaitConnState(curState uint32) error {
+ _, _, e := syscall.Syscall6(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAIT, uintptr(curState), 0, 0, 0)
+ if e != 0 && e != syscall.EAGAIN && e != syscall.EINTR {
+ return e
+ }
+ return nil
+}
+
+func yieldThread() {
+ syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0)
+ // The thread we're trying to yield to may be waiting for a Go runtime P.
+ // runtime.Gosched() will hand off ours if necessary.
+ runtime.Gosched()
+}
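
The futexSwitchToPeer/futexSwitchFromPeer pair above is a store-then-wake on one side and a check-then-wait loop on the other. A freestanding, Linux-only sketch of the same handoff on a process-local word, with the futex opcodes inlined instead of taken from pkg/abi/linux:

    package main

    import (
        "fmt"
        "sync/atomic"
        "syscall"
        "unsafe"
    )

    const (
        futexWait = 0 // linux.FUTEX_WAIT
        futexWake = 1 // linux.FUTEX_WAKE
    )

    var word uint32 // stands in for the connection state word

    func main() {
        go func() {
            // Peer side: publish the new state, then wake one waiter.
            atomic.StoreUint32(&word, 1)
            syscall.Syscall(syscall.SYS_FUTEX, uintptr(unsafe.Pointer(&word)), futexWake, 1)
        }()
        // Local side: sleep while the word holds the old state; spurious
        // wakeups (EAGAIN, EINTR) just re-enter the loop.
        for atomic.LoadUint32(&word) == 0 {
            syscall.Syscall6(syscall.SYS_FUTEX, uintptr(unsafe.Pointer(&word)), futexWait, 0, 0, 0, 0)
        }
        fmt.Println("control handed off")
    }
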
diff --git a/pkg/flipcall/io.go b/pkg/flipcall/io.go
new file mode 100644
index 000000000..85e40b932
--- /dev/null
+++ b/pkg/flipcall/io.go
@@ -0,0 +1,113 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "fmt"
+ "io"
+)
+
+// DatagramReader implements io.Reader by reading a datagram from an Endpoint's
+// packet window. Its use is optional; users that can use Endpoint.Data() more
+// efficiently are advised to do so.
+type DatagramReader struct {
+ ep *Endpoint
+ off uint32
+ end uint32
+}
+
+// Init must be called on zero-value DatagramReaders before first use.
+//
+// Preconditions: dataLen is 0, or was returned by a previous call to
+// ep.RecvFirst() or ep.SendRecv().
+func (r *DatagramReader) Init(ep *Endpoint, dataLen uint32) {
+ r.ep = ep
+ r.Reset(dataLen)
+}
+
+// Reset causes r to begin reading a new datagram of the given length from the
+// associated Endpoint.
+//
+// Preconditions: dataLen is 0, or was returned by a previous call to the
+// associated Endpoint's RecvFirst() or SendRecv() methods.
+func (r *DatagramReader) Reset(dataLen uint32) {
+ if dataLen > r.ep.dataCap {
+ panic(fmt.Sprintf("invalid dataLen (%d) > ep.dataCap (%d)", dataLen, r.ep.dataCap))
+ }
+ r.off = 0
+ r.end = dataLen
+}
+
+// NewReader is a convenience function that returns an initialized
+// DatagramReader allocated on the heap.
+//
+// Preconditions: dataLen was returned by a previous call to ep.RecvFirst() or
+// ep.SendRecv().
+func (ep *Endpoint) NewReader(dataLen uint32) *DatagramReader {
+ r := &DatagramReader{}
+ r.Init(ep, dataLen)
+ return r
+}
+
+// Read implements io.Reader.Read.
+func (r *DatagramReader) Read(dst []byte) (int, error) {
+ n := copy(dst, r.ep.Data()[r.off:r.end])
+ r.off += uint32(n)
+ if r.off == r.end {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// DatagramWriter implements io.Writer by writing a datagram to an Endpoint's
+// packet window. Its use is optional; users that can use Endpoint.Data() more
+// efficiently are advised to do so.
+type DatagramWriter struct {
+ ep *Endpoint
+ off uint32
+}
+
+// Init must be called on zero-value DatagramWriters before first use.
+func (w *DatagramWriter) Init(ep *Endpoint) {
+ w.ep = ep
+}
+
+// Reset causes w to begin writing a new datagram to the associated Endpoint.
+func (w *DatagramWriter) Reset() {
+ w.off = 0
+}
+
+// NewWriter is a convenience function that returns an initialized
+// DatagramWriter allocated on the heap.
+func (ep *Endpoint) NewWriter() *DatagramWriter {
+ w := &DatagramWriter{}
+ w.Init(ep)
+ return w
+}
+
+// Write implements io.Writer.Write.
+func (w *DatagramWriter) Write(src []byte) (int, error) {
+ n := copy(w.ep.Data()[w.off:w.ep.dataCap], src)
+ w.off += uint32(n)
+ if n != len(src) {
+ return n, fmt.Errorf("datagram would exceed maximum size of %d bytes", w.ep.dataCap)
+ }
+ return n, nil
+}
+
+// Len returns the length of the written datagram.
+func (w *DatagramWriter) Len() uint32 {
+ return w.off
+}
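
The handshake in futex_linux.go pairs these types with encoding/json; the same pattern works for any application payload once a connection is established. A client-side sketch (Request and Response are placeholder types, not part of this package):

    // exchange encodes req into the packet window, performs one round trip,
    // and decodes the peer's reply.
    func exchange(ep *Endpoint, req *Request) (*Response, error) {
        w := ep.NewWriter()
        if err := json.NewEncoder(w).Encode(req); err != nil {
            return nil, err
        }
        n, err := ep.SendRecv(w.Len())
        if err != nil {
            return nil, err
        }
        var resp Response
        if err := json.NewDecoder(ep.NewReader(n)).Decode(&resp); err != nil {
            return nil, err
        }
        return &resp, nil
    }
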
diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go
new file mode 100644
index 000000000..af9cc3d21
--- /dev/null
+++ b/pkg/flipcall/packet_window_allocator.go
@@ -0,0 +1,166 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "fmt"
+ "math/bits"
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/memutil"
+)
+
+var (
+ pageSize = os.Getpagesize()
+ pageMask = pageSize - 1
+)
+
+func init() {
+ if bits.OnesCount(uint(pageSize)) != 1 {
+ // This is depended on by roundUpToPage().
+ panic(fmt.Sprintf("system page size (%d) is not a power of 2", pageSize))
+ }
+ if uintptr(pageSize) < PacketHeaderBytes {
+ // This is required since Endpoint.Init() imposes a minimum packet
+ // window size of 1 page.
+ panic(fmt.Sprintf("system page size (%d) is less than packet header size (%d)", pageSize, PacketHeaderBytes))
+ }
+}
+
+// PacketWindowDescriptor represents a packet window, a range of pages in a
+// shared memory file that is used to exchange packets between partner
+// Endpoints.
+type PacketWindowDescriptor struct {
+ // FD is the file descriptor representing the shared memory file.
+ FD int
+
+ // Offset is the offset into the shared memory file at which the packet
+ // window begins.
+ Offset int64
+
+ // Length is the size of the packet window in bytes.
+ Length int
+}
+
+// PacketWindowLengthForDataCap returns the minimum packet window size required
+// to accommodate datagrams of the given size in bytes.
+func PacketWindowLengthForDataCap(dataCap uint32) int {
+ return roundUpToPage(int(dataCap) + int(PacketHeaderBytes))
+}
+
+func roundUpToPage(x int) int {
+ return (x + pageMask) &^ pageMask
+}
+
+// A PacketWindowAllocator owns a shared memory file, and allocates packet
+// windows from it.
+type PacketWindowAllocator struct {
+ fd int
+ nextAlloc int64
+ fileSize int64
+}
+
+// Init must be called on zero-value PacketWindowAllocators before first use.
+// If it succeeds, Destroy() must be called once the PacketWindowAllocator is
+// no longer in use.
+func (pwa *PacketWindowAllocator) Init() error {
+ fd, err := memutil.CreateMemFD("flipcall_packet_windows", linux.MFD_CLOEXEC|linux.MFD_ALLOW_SEALING)
+ if err != nil {
+ return fmt.Errorf("failed to create memfd: %v", err)
+ }
+ // Apply F_SEAL_SHRINK to prevent either party from causing SIGBUS in the
+ // other by truncating the file, and F_SEAL_SEAL to prevent either party
+ // from applying F_SEAL_GROW or F_SEAL_WRITE.
+ if _, _, e := syscall.RawSyscall(syscall.SYS_FCNTL, uintptr(fd), linux.F_ADD_SEALS, linux.F_SEAL_SHRINK|linux.F_SEAL_SEAL); e != 0 {
+ syscall.Close(fd)
+ return fmt.Errorf("failed to apply memfd seals: %v", e)
+ }
+ pwa.fd = fd
+ return nil
+}
+
+// NewPacketWindowAllocator is a convenience function that returns an
+// initialized PacketWindowAllocator allocated on the heap.
+func NewPacketWindowAllocator() (*PacketWindowAllocator, error) {
+ var pwa PacketWindowAllocator
+ if err := pwa.Init(); err != nil {
+ return nil, err
+ }
+ return &pwa, nil
+}
+
+// Destroy releases resources owned by pwa. This invalidates file descriptors
+// previously returned by pwa.FD() and pwa.Allocate().
+func (pwa *PacketWindowAllocator) Destroy() {
+ syscall.Close(pwa.fd)
+}
+
+// FD returns the file descriptor of the shared memory file backing pwa.
+func (pwa *PacketWindowAllocator) FD() int {
+ return pwa.fd
+}
+
+// Allocate allocates a new packet window of at least the given size and
+// returns a PacketWindowDescriptor representing it.
+//
+// Preconditions: size > 0.
+func (pwa *PacketWindowAllocator) Allocate(size int) (PacketWindowDescriptor, error) {
+ if size <= 0 {
+ return PacketWindowDescriptor{}, fmt.Errorf("invalid size: %d", size)
+ }
+ // Page-align size to ensure that pwa.nextAlloc remains page-aligned.
+ size = roundUpToPage(size)
+ if size <= 0 {
+ return PacketWindowDescriptor{}, fmt.Errorf("size %d overflows after rounding up to page size", size)
+ }
+ end := pwa.nextAlloc + int64(size) // overflow checked by ensureFileSize
+ if err := pwa.ensureFileSize(end); err != nil {
+ return PacketWindowDescriptor{}, err
+ }
+ start := pwa.nextAlloc
+ pwa.nextAlloc = end
+ return PacketWindowDescriptor{
+ FD: pwa.FD(),
+ Offset: start,
+ Length: size,
+ }, nil
+}
+
+func (pwa *PacketWindowAllocator) ensureFileSize(min int64) error {
+ if min <= 0 {
+ return fmt.Errorf("file size would overflow")
+ }
+ if pwa.fileSize >= min {
+ return nil
+ }
+ newSize := 2 * pwa.fileSize
+ if newSize == 0 {
+ newSize = int64(pageSize)
+ }
+ for newSize < min {
+ newNewSize := newSize * 2
+ if newNewSize <= 0 {
+ return fmt.Errorf("file size would overflow")
+ }
+ newSize = newNewSize
+ }
+ if err := syscall.Ftruncate(pwa.FD(), newSize); err != nil {
+ return fmt.Errorf("ftruncate failed: %v", err)
+ }
+ pwa.fileSize = newSize
+ return nil
+}
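
Since Allocate only rounds sizes up and bumps nextAlloc, successive windows come from the same memfd, page-aligned and contiguous. A sketch (function name illustrative, written as if in package flipcall):

    func demoAllocate() {
        pwa, err := NewPacketWindowAllocator()
        if err != nil {
            panic(err)
        }
        defer pwa.Destroy()
        first, err := pwa.Allocate(PacketWindowLengthForDataCap(1024))
        if err != nil {
            panic(err)
        }
        second, err := pwa.Allocate(1) // rounded up to one page
        if err != nil {
            panic(err)
        }
        fmt.Println(first.FD == second.FD)                             // true: same memfd
        fmt.Println(second.Offset == first.Offset+int64(first.Length)) // true: contiguous
    }
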
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap.go
new file mode 100644
index 000000000..869183b11
--- /dev/null
+++ b/pkg/flipcall/packet_window_mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "syscall"
+)
+
+// packetWindowMmap returns a memory mapping of the given packet window,
+// suitable for sharing outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
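
packetWindowMmap reports failure as a raw syscall.Errno rather than an error. A sketch of how an in-package caller might consume it (ep and pwd are assumed to be in scope; Endpoint.Init plays this role in the full change):

    m, e := packetWindowMmap(pwd)
    if e != 0 {
        return fmt.Errorf("failed to mmap packet window: %v", e)
    }
    ep.packet = m // the window now occupies [m, m+pwd.Length)
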
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
new file mode 100644
index 000000000..67dd1e225
--- /dev/null
+++ b/pkg/fspath/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+go_library(
+ name = "fspath",
+ srcs = [
+ "builder.go",
+ "fspath.go",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ ],
+)
+
+go_test(
+ name = "fspath_test",
+ size = "small",
+ srcs = [
+ "builder_test.go",
+ "fspath_test.go",
+ ],
+ library = ":fspath",
+)
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
new file mode 100644
index 000000000..6318d3874
--- /dev/null
+++ b/pkg/fspath/builder.go
@@ -0,0 +1,112 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fspath
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Builder is similar to strings.Builder, but is used to produce pathnames
+// given path components in reverse order (from leaf to root). This is useful
+// in the common case where a filesystem is represented by a tree of named
+// nodes, and the path to a given node must be produced by walking upward from
+// that node to a given root.
+type Builder struct {
+ buf []byte
+ start int
+ needSep bool
+}
+
+// Reset resets the Builder to be empty.
+func (b *Builder) Reset() {
+ b.start = len(b.buf)
+ b.needSep = false
+}
+
+// Len returns the number of accumulated bytes.
+func (b *Builder) Len() int {
+ return len(b.buf) - b.start
+}
+
+func (b *Builder) needToGrow(n int) bool {
+ return b.start < n
+}
+
+func (b *Builder) grow(n int) {
+ newLen := b.Len() + n
+ var newCap int
+ if len(b.buf) == 0 {
+ newCap = 64 // arbitrary
+ } else {
+ newCap = 2 * len(b.buf)
+ }
+ for newCap < newLen {
+ newCap *= 2
+ if newCap == 0 {
+ panic(fmt.Sprintf("required length (%d) causes buffer size to overflow", newLen))
+ }
+ }
+ newBuf := make([]byte, newCap)
+ copy(newBuf[newCap-b.Len():], b.buf[b.start:])
+ b.start += newCap - len(b.buf)
+ b.buf = newBuf
+}
+
+// PrependComponent prepends the given path component to b's buffer. A path
+// separator is automatically inserted if appropriate.
+func (b *Builder) PrependComponent(pc string) {
+ if b.needSep {
+ b.PrependByte('/')
+ }
+ b.PrependString(pc)
+ b.needSep = true
+}
+
+// PrependString prepends the given string to b's buffer.
+func (b *Builder) PrependString(str string) {
+ if b.needToGrow(len(str)) {
+ b.grow(len(str))
+ }
+ b.start -= len(str)
+ copy(b.buf[b.start:], str)
+}
+
+// PrependByte prepends the given byte to b's buffer.
+func (b *Builder) PrependByte(c byte) {
+ if b.needToGrow(1) {
+ b.grow(1)
+ }
+ b.start--
+ b.buf[b.start] = c
+}
+
+// AppendString appends the given string to b's buffer.
+func (b *Builder) AppendString(str string) {
+ if b.needToGrow(len(str)) {
+ b.grow(len(str))
+ }
+ oldStart := b.start
+ b.start -= len(str)
+ copy(b.buf[b.start:], b.buf[oldStart:])
+ copy(b.buf[len(b.buf)-len(str):], str)
+}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
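
The leaf-to-root use case named in the type comment looks like the following sketch, written as if in package fspath; the node type is a hypothetical tree node, not part of this package:

    type node struct {
        name   string
        parent *node // nil for the root
    }

    // pathFromRoot walks upward from n, prepending one component per level,
    // then prepends the leading separator.
    func pathFromRoot(n *node) string {
        var b Builder
        for ; n.parent != nil; n = n.parent {
            b.PrependComponent(n.name)
        }
        b.PrependByte('/')
        return b.String()
    }
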
diff --git a/pkg/fspath/builder_test.go b/pkg/fspath/builder_test.go
new file mode 100644
index 000000000..22f890273
--- /dev/null
+++ b/pkg/fspath/builder_test.go
@@ -0,0 +1,58 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fspath
+
+import (
+ "testing"
+)
+
+func TestBuilder(t *testing.T) {
+ type testCase struct {
+ pcs []string // path components in reverse order
+ after string
+ want string
+ }
+ tests := []testCase{
+ {
+ // Empty case.
+ },
+ {
+ pcs: []string{"foo"},
+ want: "foo",
+ },
+ {
+ pcs: []string{"foo", "bar", "baz"},
+ want: "baz/bar/foo",
+ },
+ {
+ pcs: []string{"foo", "bar"},
+ after: " (deleted)",
+ want: "bar/foo (deleted)",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.want, func(t *testing.T) {
+ var b Builder
+ for _, pc := range test.pcs {
+ b.PrependComponent(pc)
+ }
+ b.AppendString(test.after)
+ if got := b.String(); got != test.want {
+ t.Errorf("got %q, wanted %q", got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
new file mode 100644
index 000000000..4c983d5fd
--- /dev/null
+++ b/pkg/fspath/fspath.go
@@ -0,0 +1,187 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fspath provides efficient tools for working with file paths in
+// Linux-compatible filesystem implementations.
+package fspath
+
+import (
+ "strings"
+)
+
+const pathSep = '/'
+
+// Parse parses a pathname as described by path_resolution(7), except that
+// empty pathnames will be parsed successfully to a Path for which
+// Path.Absolute == Path.Dir == Path.HasComponents() == false. (This is
+// necessary to support AT_EMPTY_PATH.)
+func Parse(pathname string) Path {
+ if len(pathname) == 0 {
+ return Path{}
+ }
+ // Skip leading path separators.
+ i := 0
+ for pathname[i] == pathSep {
+ i++
+ if i == len(pathname) {
+ // pathname consists entirely of path separators.
+ return Path{
+ Absolute: true,
+ Dir: true,
+ }
+ }
+ }
+ // Skip trailing path separators. This is required by Iterator.Next. This
+ // loop is guaranteed to terminate with j >= 0 because otherwise the
+ // pathname would consist entirely of path separators, so we would have
+ // returned above.
+ j := len(pathname) - 1
+ for pathname[j] == pathSep {
+ j--
+ }
+ // Find the end of the first path component.
+ firstEnd := i + 1
+ for firstEnd != len(pathname) && pathname[firstEnd] != pathSep {
+ firstEnd++
+ }
+ return Path{
+ Begin: Iterator{
+ partialPathname: pathname[i : j+1],
+ end: firstEnd - i,
+ },
+ Absolute: i != 0,
+ Dir: j != len(pathname)-1,
+ }
+}
+
+// Path contains the information contained in a pathname string.
+//
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
+type Path struct {
+ // Begin is an iterator to the first path component in the relative part of
+ // the path.
+ //
+ // Path doesn't store information about path components after the first
+ // since this would require allocation.
+ Begin Iterator
+
+ // If true, the path is absolute, such that lookup should begin at the
+ // filesystem root. If false, the path is relative, such that where lookup
+ // begins is unspecified.
+ Absolute bool
+
+ // If true, the pathname contains trailing path separators, so the last
+ // path component must exist and resolve to a directory.
+ Dir bool
+}
+
+// String returns a pathname string equivalent to p. Note that the returned
+// string is not necessarily equal to the string p was parsed from; in
+// particular, redundant path separators will not be present.
+func (p Path) String() string {
+ var b strings.Builder
+ if p.Absolute {
+ b.WriteByte(pathSep)
+ }
+ sep := false
+ for pit := p.Begin; pit.Ok(); pit = pit.Next() {
+ if sep {
+ b.WriteByte(pathSep)
+ }
+ b.WriteString(pit.String())
+ sep = true
+ }
+ // Don't return "//" for Parse("/").
+ if p.Dir && p.Begin.Ok() {
+ b.WriteByte(pathSep)
+ }
+ return b.String()
+}
+
+// HasComponents returns true if p contains a non-zero number of path
+// components.
+func (p Path) HasComponents() bool {
+ return p.Begin.Ok()
+}
+
+// An Iterator represents either a path component in a Path or a terminal
+// iterator indicating that the end of the path has been reached.
+//
+// Iterator is immutable and copyable by value. The zero value of Iterator is
+// valid, and represents a terminal iterator.
+type Iterator struct {
+ // partialPathname is a substring of the original pathname beginning at the
+ // start of the represented path component and ending immediately after the
+ // end of the last path component in the pathname. If partialPathname is
+// empty, the Iterator is terminal.
+ //
+ // See TestParseIteratorPartialPathnames in fspath_test.go for a worked
+ // example.
+ partialPathname string
+
+ // end is the offset into partialPathname of the first byte after the end
+ // of the represented path component.
+ end int
+}
+
+// Ok returns true if it is not terminal.
+func (it Iterator) Ok() bool {
+ return len(it.partialPathname) != 0
+}
+
+// String returns the path component represented by it.
+//
+// Preconditions: it.Ok().
+func (it Iterator) String() string {
+ return it.partialPathname[:it.end]
+}
+
+// Next returns an iterator to the path component after it. If it is the last
+// component in the path, Next returns a terminal iterator.
+//
+// Preconditions: it.Ok().
+func (it Iterator) Next() Iterator {
+ if it.end == len(it.partialPathname) {
+ // End of the path.
+ return Iterator{}
+ }
+ // Skip path separators. Since Parse trims trailing path separators, if we
+ // aren't at the end of the path, there is definitely another path
+ // component.
+ i := it.end + 1
+ for it.partialPathname[i] == pathSep {
+ i++
+ }
+ nextPartialPathname := it.partialPathname[i:]
+ // Find the end of this path component.
+ nextEnd := 1
+ for nextEnd < len(nextPartialPathname) && nextPartialPathname[nextEnd] != pathSep {
+ nextEnd++
+ }
+ return Iterator{
+ partialPathname: nextPartialPathname,
+ end: nextEnd,
+ }
+}
+
+// NextOk is equivalent to it.Next().Ok(), but is faster.
+//
+// Preconditions: it.Ok().
+func (it Iterator) NextOk() bool {
+ return it.end != len(it.partialPathname)
+}
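
Parse and Iterator are intended to drive component-by-component lookup without allocating. A sketch against a toy map-based tree (dir and lookup are hypothetical; assumes package fspath scope and an fmt import):

    type dir struct {
        children map[string]*dir
    }

    func lookup(root, cwd *dir, pathname string) (*dir, error) {
        p := Parse(pathname)
        d := cwd
        if p.Absolute {
            d = root
        }
        for it := p.Begin; it.Ok(); it = it.Next() {
            next, ok := d.children[it.String()]
            if !ok {
                return nil, fmt.Errorf("no such component %q", it.String())
            }
            d = next
        }
        return d, nil
    }
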
diff --git a/pkg/fspath/fspath_test.go b/pkg/fspath/fspath_test.go
new file mode 100644
index 000000000..d5e9a549a
--- /dev/null
+++ b/pkg/fspath/fspath_test.go
@@ -0,0 +1,134 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fspath
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestParseIteratorPartialPathnames(t *testing.T) {
+ path := Parse("/foo//bar///baz////")
+ // Parse strips leading slashes, and records their presence as
+ // Path.Absolute.
+ if !path.Absolute {
+ t.Errorf("Path.Absolute: got false, wanted true")
+ }
+ // Parse strips trailing slashes, and records their presence as Path.Dir.
+ if !path.Dir {
+ t.Errorf("Path.Dir: got false, wanted true")
+ }
+ // The first Iterator.partialPathname is the input pathname, with leading
+ // and trailing slashes stripped.
+ it := path.Begin
+ if want := "foo//bar///baz"; it.partialPathname != want {
+ t.Errorf("first Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want)
+ }
+ // Successive Iterator.partialPathnames remove the leading path component
+ // and following slashes, until we run out of path components and get a
+ // terminal Iterator.
+ it = it.Next()
+ if want := "bar///baz"; it.partialPathname != want {
+ t.Errorf("second Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want)
+ }
+ it = it.Next()
+ if want := "baz"; it.partialPathname != want {
+ t.Errorf("third Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want)
+ }
+ it = it.Next()
+ if want := ""; it.partialPathname != want {
+ t.Errorf("fourth Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want)
+ }
+ if it.Ok() {
+ t.Errorf("fourth Iterator.Ok(): got true, wanted false")
+ }
+}
+
+func TestParse(t *testing.T) {
+ type testCase struct {
+ pathname string
+ relpath []string
+ abs bool
+ dir bool
+ }
+ tests := []testCase{
+ {
+ pathname: "",
+ relpath: []string{},
+ abs: false,
+ dir: false,
+ },
+ {
+ pathname: "/",
+ relpath: []string{},
+ abs: true,
+ dir: true,
+ },
+ {
+ pathname: "//",
+ relpath: []string{},
+ abs: true,
+ dir: true,
+ },
+ }
+ for _, sep := range []string{"/", "//"} {
+ for _, abs := range []bool{false, true} {
+ for _, dir := range []bool{false, true} {
+ for _, pcs := range [][]string{
+ // single path component
+ {"foo"},
+ // multiple path components, including non-UTF-8
+ {".", "foo", "..", "\xe6", "bar"},
+ } {
+ prefix := ""
+ if abs {
+ prefix = sep
+ }
+ suffix := ""
+ if dir {
+ suffix = sep
+ }
+ tests = append(tests, testCase{
+ pathname: prefix + strings.Join(pcs, sep) + suffix,
+ relpath: pcs,
+ abs: abs,
+ dir: dir,
+ })
+ }
+ }
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.pathname, func(t *testing.T) {
+ p := Parse(test.pathname)
+ t.Logf("pathname %q => path %q", test.pathname, p)
+ if p.Absolute != test.abs {
+ t.Errorf("path absoluteness: got %v, wanted %v", p.Absolute, test.abs)
+ }
+ if p.Dir != test.dir {
+ t.Errorf("path must resolve to a directory: got %v, wanted %v", p.Dir, test.dir)
+ }
+ pcs := []string{}
+ for pit := p.Begin; pit.Ok(); pit = pit.Next() {
+ pcs = append(pcs, pit.String())
+ }
+ if !reflect.DeepEqual(pcs, test.relpath) {
+ t.Errorf("relative path: got %v, wanted %v", pcs, test.relpath)
+ }
+ })
+ }
+}
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
new file mode 100644
index 000000000..dd3141143
--- /dev/null
+++ b/pkg/gate/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gate",
+ srcs = [
+ "gate.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "gate_test",
+ srcs = [
+ "gate_test.go",
+ ],
+ deps = [
+ ":gate",
+ "//pkg/sync",
+ ],
+)
diff --git a/pkg/gate/gate.go b/pkg/gate/gate.go
new file mode 100644
index 000000000..bda6aae09
--- /dev/null
+++ b/pkg/gate/gate.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gate provides a usage Gate synchronization primitive.
+package gate
+
+import (
+ "sync/atomic"
+)
+
+const (
+ // gateClosed is the bit set in the gate's user count to indicate that
+ // it has been closed. It is the MSB of the 32-bit field; the other 31
+ // bits carry the actual count.
+ gateClosed = 0x80000000
+)
+
+// Gate is a synchronization primitive that allows concurrent goroutines to
+// "enter" it as long as it hasn't been closed yet. Once it's been closed,
+// goroutines cannot enter it anymore, but are allowed to leave, and the closer
+// will be informed when all goroutines have left.
+//
+// Many goroutines are allowed to enter the gate concurrently, but only one is
+// allowed to close it.
+//
+// This is similar to a r/w critical section, except that goroutines "entering"
+// never block: they either enter immediately or fail to enter. The closer will
+// block waiting for all goroutines currently inside the gate to leave.
+//
+// Gate is implemented efficiently. On x86, only one interlocked
+// operation is performed on enter, and one on leave.
+//
+// This is useful, for example, in cases when a goroutine is trying to clean up
+// an object for which multiple goroutines have pointers. In such a case, users
+// would be required to enter and leave the gates, and the cleaner would wait
+// until all users are gone (and no new ones are allowed) before proceeding.
+//
+// Users:
+//
+// if !g.Enter() {
+// // Gate is closed, we can't use the object.
+// return
+// }
+//
+// // Do something with object.
+// [...]
+//
+// g.Leave()
+//
+// Closer:
+//
+// // Prevent new users from using the object, and wait for the existing
+// // ones to complete.
+// g.Close()
+//
+// // Clean up the object.
+// [...]
+//
+type Gate struct {
+ userCount uint32
+ done chan struct{}
+}
+
+// Enter tries to enter the gate. It will succeed if it hasn't been closed yet,
+// in which case the caller must eventually call Leave().
+//
+// This function is thread-safe.
+func (g *Gate) Enter() bool {
+ if g == nil {
+ return false
+ }
+
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&gateClosed != 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v+1) {
+ return true
+ }
+ }
+}
+
+// Leave leaves the gate. This must only be called after a successful call to
+// Enter(). If the gate has been closed and this is the last one inside the
+// gate, it will notify the closer that the gate is done.
+//
+// This function is thread-safe.
+func (g *Gate) Leave() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed == 0 {
+ panic("leaving a gate with zero usage count")
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v-1) {
+ if v == gateClosed+1 {
+ close(g.done)
+ }
+ return
+ }
+ }
+}
+
+// Close closes the gate for entering, and waits until all goroutines that are
+// currently inside the gate leave before returning.
+//
+// Only one goroutine can call this function.
+func (g *Gate) Close() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed != 0 && g.done == nil {
+ g.done = make(chan struct{})
+ }
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v|gateClosed) {
+ if v&^gateClosed != 0 {
+ <-g.done
+ }
+ return
+ }
+ }
+}
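
The Users/Closer pattern from the type comment, made concrete for the object-cleanup case it describes (the resource type and method names are illustrative):

    type resource struct {
        g    gate.Gate
        data []byte
    }

    func (r *resource) use() bool {
        if !r.g.Enter() {
            return false // destroy() has already begun
        }
        defer r.g.Leave()
        _ = r.data // r.data is safe to use until Leave
        return true
    }

    func (r *resource) destroy() {
        r.g.Close() // refuse new users, wait for current ones to Leave
        r.data = nil
    }
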
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
new file mode 100644
index 000000000..316015e06
--- /dev/null
+++ b/pkg/gate/gate_test.go
@@ -0,0 +1,192 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gate_test
+
+import (
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/gate"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestBasicEnter(t *testing.T) {
+ var g gate.Gate
+
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+
+ g.Leave()
+
+ g.Close()
+
+ if g.Enter() {
+ t.Fatalf("Allowed to enter when it should fail")
+ }
+}
+
+func enterFunc(t *testing.T, g *gate.Gate, enter, leave, reenter chan struct{}, done1, done2, done3 *sync.WaitGroup) {
+ // Wait until instructed to enter.
+ <-enter
+ if !g.Enter() {
+ t.Errorf("Failed to enter when it should be allowed")
+ }
+
+ done1.Done()
+
+ // Wait until instructed to leave.
+ <-leave
+ g.Leave()
+
+ done2.Done()
+
+ // Wait until instructed to reenter.
+ <-reenter
+ if g.Enter() {
+ t.Errorf("Allowed to enter when it should fail")
+ }
+ done3.Done()
+}
+
+func TestConcurrentEnter(t *testing.T) {
+ var g gate.Gate
+ var done1, done2, done3 sync.WaitGroup
+
+ // Create 1000 worker goroutines.
+ enter := make(chan struct{})
+ leave := make(chan struct{})
+ reenter := make(chan struct{})
+ done1.Add(1000)
+ done2.Add(1000)
+ done3.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go enterFunc(t, &g, enter, leave, reenter, &done1, &done2, &done3)
+ }
+
+ // Tell them all to enter, then leave.
+ close(enter)
+ done1.Wait()
+
+ close(leave)
+ done2.Wait()
+
+ // Close the gate, then have the workers try to enter again.
+ g.Close()
+ close(reenter)
+ done3.Wait()
+}
+
+func closeFunc(g *gate.Gate, done chan struct{}) {
+ g.Close()
+ close(done)
+}
+
+func TestCloseWaits(t *testing.T) {
+ var g gate.Gate
+
+ // Enter 10 times.
+ for i := 0; i < 10; i++ {
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+ }
+
+ // Launch closer. Check that it doesn't complete.
+ done := make(chan struct{})
+ go closeFunc(&g, done)
+
+ for i := 0; i < 10; i++ {
+ select {
+ case <-done:
+ t.Fatalf("Close function completed too soon")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ g.Leave()
+ }
+
+ // Now the closer must complete.
+ <-done
+}
+
+func TestMultipleSerialCloses(t *testing.T) {
+ var g gate.Gate
+
+ // Enter 10 times.
+ for i := 0; i < 10; i++ {
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+ }
+
+ // Launch closer. Check that it doesn't complete.
+ done := make(chan struct{})
+ go closeFunc(&g, done)
+
+ for i := 0; i < 10; i++ {
+ select {
+ case <-done:
+ t.Fatalf("Close function completed too soon")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ g.Leave()
+ }
+
+ // Now the closer must complete.
+ <-done
+
+ // Close again should not block.
+ done = make(chan struct{})
+ go closeFunc(&g, done)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatalf("Second Close is blocking")
+ }
+}
+
+func worker(g *gate.Gate, done *sync.WaitGroup) {
+ for {
+ if !g.Enter() {
+ break
+ }
+ // Go before 1.14 doesn't preempt busy loops.
+ runtime.Gosched()
+ g.Leave()
+ }
+ done.Done()
+}
+
+func TestConcurrentAll(t *testing.T) {
+ var g gate.Gate
+ var done sync.WaitGroup
+
+ // Launch 1000 goroutines to concurrently enter/leave.
+ done.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go worker(&g, &done)
+ }
+
+ // Wait for the goroutines to do some work, then close the gate.
+ time.Sleep(2 * time.Second)
+ g.Close()
+
+ // Wait for all of them to complete.
+ done.Wait()
+}
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..35683fe98
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ stateify = False,
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
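
As a hedged illustration of the aliasing contract above (none of this is part of the package itself), a caller can round-trip between string and []byte without copying, provided the bytes are never mutated:

    package main

    import (
    	"fmt"

    	"gvisor.dev/gvisor/pkg/gohacks"
    )

    func main() {
    	s := "immutable data"
    	// bs aliases the memory backing s; mutating it is undefined behavior.
    	bs := gohacks.ImmutableBytesFromString(s)
    	// Convert back without a heap-allocated copy.
    	fmt.Println(gohacks.StringFromImmutableBytes(bs))
    }
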
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
new file mode 100644
index 000000000..7a82631c5
--- /dev/null
+++ b/pkg/goid/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "goid",
+ srcs = [
+ "goid.go",
+ "goid_amd64.s",
+ "goid_arm64.s",
+ "goid_race.go",
+ "goid_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "goid_test",
+ size = "small",
+ srcs = [
+ "empty_test.go",
+ "goid_test.go",
+ ],
+ library = ":goid",
+)
diff --git a/pkg/goid/empty_test.go b/pkg/goid/empty_test.go
new file mode 100644
index 000000000..c0a4b17ab
--- /dev/null
+++ b/pkg/goid/empty_test.go
@@ -0,0 +1,22 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package goid
+
+import "testing"
+
+// TestNothing exists to make the build system happy.
+func TestNothing(t *testing.T) {}
diff --git a/pkg/goid/goid.go b/pkg/goid/goid.go
new file mode 100644
index 000000000..39df30031
--- /dev/null
+++ b/pkg/goid/goid.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ panic("unimplemented for non-race builds")
+}
diff --git a/pkg/goid/goid_amd64.s b/pkg/goid/goid_amd64.s
new file mode 100644
index 000000000..d9f5cd2a3
--- /dev/null
+++ b/pkg/goid/goid_amd64.s
@@ -0,0 +1,21 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVQ (TLS), R14
+ MOVQ R14, ret+0(FP)
+ RET
diff --git a/pkg/goid/goid_arm64.s b/pkg/goid/goid_arm64.s
new file mode 100644
index 000000000..a7465b75d
--- /dev/null
+++ b/pkg/goid/goid_arm64.s
@@ -0,0 +1,21 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVD g, R0 // g
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/goid/goid_race.go b/pkg/goid/goid_race.go
new file mode 100644
index 000000000..1766beaee
--- /dev/null
+++ b/pkg/goid/goid_race.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Only available in race/gotsan builds.
+// +build race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ return goid()
+}
diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go
new file mode 100644
index 000000000..31970ce79
--- /dev/null
+++ b/pkg/goid/goid_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package goid
+
+import (
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestInitialGoID(t *testing.T) {
+ const max = 10000
+ if id := goid(); id < 0 || id > max {
+ t.Errorf("got goid = %d, want 0 < goid <= %d", id, max)
+ }
+}
+
+// TestGoIDSequence verifies that goid returns values which could plausibly be
+// goroutine IDs. If this test breaks or becomes flaky, the structs in
+// goid_unsafe.go may need to be updated.
+func TestGoIDSequence(t *testing.T) {
+ // Goroutine IDs are cached by each P.
+ runtime.GOMAXPROCS(1)
+
+ // Fill any holes in lower range.
+ for i := 0; i < 50; i++ {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ }
+
+ id := goid()
+ for i := 0; i < 100; i++ {
+ var (
+ newID int64
+ wg sync.WaitGroup
+ )
+ wg.Add(1)
+ go func() {
+ newID = goid()
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ if max := id + 100; newID <= id || newID > max {
+ t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id)
+ }
+ id = newID
+ }
+}
diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go
new file mode 100644
index 000000000..ded8004dd
--- /dev/null
+++ b/pkg/goid/goid_unsafe.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package goid
+
+// Structs from Go runtime. These may change in the future and require
+// updating. These structs are currently the same on both AMD64 and ARM64,
+// but may diverge in the future.
+
+type stack struct {
+ lo uintptr
+ hi uintptr
+}
+
+type gobuf struct {
+ sp uintptr
+ pc uintptr
+ g uintptr
+ ctxt uintptr
+ ret uint64
+ lr uintptr
+ bp uintptr
+}
+
+type g struct {
+ stack stack
+ stackguard0 uintptr
+ stackguard1 uintptr
+
+ _panic uintptr
+ _defer uintptr
+ m uintptr
+ sched gobuf
+ syscallsp uintptr
+ syscallpc uintptr
+ stktopsp uintptr
+ param uintptr
+ atomicstatus uint32
+ stackLock uint32
+ goid int64
+
+ // More fields...
+ //
+	// We only use goid; the fields before it are listed only to
+	// calculate the correct offset.
+}
+
+func getg() *g
+
+// goid returns the ID of the current goroutine.
+func goid() int64 {
+ return getg().goid
+}
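
Because the g struct above mirrors the runtime's layout only up to the goid field, getg().goid reads the ID at the correct offset. A minimal (hypothetical) caller of the exported API, keeping in mind that Get is only implemented under -race and the stub in goid.go panics otherwise:

    package main

    import (
    	"fmt"

    	"gvisor.dev/gvisor/pkg/goid"
    )

    func main() {
    	// Build and run with -race; the non-race stub of Get panics.
    	fmt.Printf("running on goroutine %d\n", goid.Get())
    }
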
diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD
new file mode 100644
index 000000000..3f6eb07df
--- /dev/null
+++ b/pkg/ilist/BUILD
@@ -0,0 +1,56 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ilist",
+ srcs = [
+ "interface_list.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_template_instance(
+ name = "interface_list",
+ out = "interface_list.go",
+ package = "ilist",
+ template = ":generic_list",
+ types = {},
+)
+
+# This list is used for benchmarking.
+go_template_instance(
+ name = "test_list",
+ out = "test_list.go",
+ package = "ilist",
+ prefix = "direct",
+ template = ":generic_list",
+ types = {
+ "Element": "*direct",
+ "Linker": "*direct",
+ },
+)
+
+go_test(
+ name = "list_test",
+ size = "small",
+ srcs = [
+ "list_test.go",
+ "test_list.go",
+ ],
+ library = ":ilist",
+)
+
+go_template(
+ name = "generic_list",
+ srcs = [
+ "list.go",
+ ],
+ opt_types = [
+ "Element",
+ "ElementMapper",
+ "Linker",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
new file mode 100644
index 000000000..f4a4c33d3
--- /dev/null
+++ b/pkg/ilist/list.go
@@ -0,0 +1,227 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ilist provides the implementation of intrusive linked lists.
+package ilist
+
+// Linker is the interface that objects must implement if they want to be added
+// to and/or removed from List objects.
+//
+// N.B. When substituted in a template instantiation, Linker doesn't need to
+// be an interface, and in most cases won't be.
+type Linker interface {
+ Next() Element
+ Prev() Element
+ SetNext(Element)
+ SetPrev(Element)
+}
+
+// Element is the item that is used at the API level.
+//
+// N.B. Like Linker, this is unlikely to be an interface in most cases.
+type Element interface {
+ Linker
+}
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type ElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (ElementMapper) linkerFor(elem Element) Linker { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type List struct {
+ head Element
+ tail Element
+}
+
+// Reset resets list l to the empty state.
+func (l *List) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *List) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *List) Front() Element {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *List) Back() Element {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+func (l *List) Len() (count int) {
+ for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *List) PushFront(e Element) {
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ ElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *List) PushBack(e Element) {
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ ElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *List) PushBackList(m *List) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ ElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ ElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *List) InsertAfter(b, e Element) {
+ bLinker := ElementMapper{}.linkerFor(b)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ ElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *List) InsertBefore(a, e Element) {
+ aLinker := ElementMapper{}.linkerFor(a)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ ElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *List) Remove(e Element) {
+ linker := ElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ ElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ ElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type Entry struct {
+ next Element
+ prev Element
+}
+
+// Next returns the entry that follows e in the list.
+func (e *Entry) Next() Element {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *Entry) Prev() Element {
+ return e.prev
+}
+
+// SetNext assigns 'elem' as the entry that follows e in the list.
+func (e *Entry) SetNext(elem Element) {
+ e.next = elem
+}
+
+// SetPrev assigns 'elem' as the entry that precedes e in the list.
+func (e *Entry) SetPrev(elem Element) {
+ e.prev = elem
+}
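
A short sketch of the embedding pattern Entry enables; the node type here is illustrative, not part of the package:

    package main

    import (
    	"fmt"

    	"gvisor.dev/gvisor/pkg/ilist"
    )

    // node embeds ilist.Entry, so *node implements Linker with no extra
    // code and can be stored in a List without per-insertion allocations.
    type node struct {
    	ilist.Entry
    	value int
    }

    func main() {
    	var l ilist.List
    	for i := 0; i < 3; i++ {
    		l.PushBack(&node{value: i})
    	}
    	for e := l.Front(); e != nil; e = e.Next() {
    		fmt.Println(e.(*node).value)
    	}
    }
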
diff --git a/pkg/ilist/list_test.go b/pkg/ilist/list_test.go
new file mode 100644
index 000000000..3f9abfb56
--- /dev/null
+++ b/pkg/ilist/list_test.go
@@ -0,0 +1,240 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ilist
+
+import (
+ "testing"
+)
+
+type testEntry struct {
+ Entry
+ value int
+}
+
+type direct struct {
+ directEntry
+ value int
+}
+
+func verifyEquality(t *testing.T, entries []testEntry, l *List) {
+ t.Helper()
+
+ i := 0
+ for it := l.Front(); it != nil; it = it.Next() {
+ e := it.(*testEntry)
+ if e != &entries[i] {
+ t.Errorf("Wrong entry at index %d", i)
+ return
+ }
+ i++
+ }
+
+ if i != len(entries) {
+ t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i)
+ return
+ }
+
+ i = 0
+ for it := l.Back(); it != nil; it = it.Prev() {
+ e := it.(*testEntry)
+ if e != &entries[len(entries)-1-i] {
+ t.Errorf("Wrong entry at index %d", i)
+ return
+ }
+ i++
+ }
+
+ if i != len(entries) {
+ t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i)
+ return
+ }
+}
+
+func TestZeroEmpty(t *testing.T) {
+ var l List
+ if l.Front() != nil {
+ t.Error("Front is non-nil")
+ }
+ if l.Back() != nil {
+ t.Error("Back is non-nil")
+ }
+}
+
+func TestPushBack(t *testing.T) {
+ var l List
+
+ // Test single entry insertion.
+ var entry testEntry
+ l.PushBack(&entry)
+
+ e := l.Front().(*testEntry)
+ if e != &entry {
+ t.Error("Wrong entry returned")
+ }
+
+ // Test inserting 100 entries.
+ l.Reset()
+ var entries [100]testEntry
+ for i := range entries {
+ l.PushBack(&entries[i])
+ }
+
+ verifyEquality(t, entries[:], &l)
+}
+
+func TestPushFront(t *testing.T) {
+ var l List
+
+ // Test single entry insertion.
+ var entry testEntry
+ l.PushFront(&entry)
+
+ e := l.Front().(*testEntry)
+ if e != &entry {
+ t.Error("Wrong entry returned")
+ }
+
+ // Test inserting 100 entries.
+ l.Reset()
+ var entries [100]testEntry
+ for i := range entries {
+ l.PushFront(&entries[len(entries)-1-i])
+ }
+
+ verifyEquality(t, entries[:], &l)
+}
+
+func TestRemove(t *testing.T) {
+ // Remove entry from single-element list.
+ var l List
+ var entry testEntry
+ l.PushBack(&entry)
+ l.Remove(&entry)
+ if l.Front() != nil {
+		t.Error("List is not empty")
+ }
+
+ var entries [100]testEntry
+
+ // Remove single element from lists of lengths 2 to 101.
+ for n := 1; n <= len(entries); n++ {
+ for extra := 0; extra <= n; extra++ {
+ l.Reset()
+ for i := 0; i < n; i++ {
+ if extra == i {
+ l.PushBack(&entry)
+ }
+ l.PushBack(&entries[i])
+ }
+ if extra == n {
+ l.PushBack(&entry)
+ }
+
+ l.Remove(&entry)
+ verifyEquality(t, entries[:n], &l)
+ }
+ }
+}
+
+func TestReset(t *testing.T) {
+ var l List
+
+ // Resetting list of one element.
+ l.PushBack(&testEntry{})
+ if l.Front() == nil {
+ t.Error("List is empty")
+ }
+
+ l.Reset()
+ if l.Front() != nil {
+ t.Error("List is not empty")
+ }
+
+ // Resetting list of 10 elements.
+ for i := 0; i < 10; i++ {
+ l.PushBack(&testEntry{})
+ }
+
+ if l.Front() == nil {
+ t.Error("List is empty")
+ }
+
+ l.Reset()
+ if l.Front() != nil {
+ t.Error("List is not empty")
+ }
+
+ // Resetting empty list.
+ l.Reset()
+ if l.Front() != nil {
+ t.Error("List is not empty")
+ }
+}
+
+func BenchmarkIterateForward(b *testing.B) {
+ var l List
+ for i := 0; i < 1000000; i++ {
+ l.PushBack(&testEntry{value: i})
+ }
+
+ for i := b.N; i > 0; i-- {
+ tmp := 0
+ for e := l.Front(); e != nil; e = e.Next() {
+ tmp += e.(*testEntry).value
+ }
+ }
+}
+
+func BenchmarkIterateBackward(b *testing.B) {
+ var l List
+ for i := 0; i < 1000000; i++ {
+ l.PushBack(&testEntry{value: i})
+ }
+
+ for i := b.N; i > 0; i-- {
+ tmp := 0
+ for e := l.Back(); e != nil; e = e.Prev() {
+ tmp += e.(*testEntry).value
+ }
+ }
+}
+
+func BenchmarkDirectIterateForward(b *testing.B) {
+ var l directList
+ for i := 0; i < 1000000; i++ {
+ l.PushBack(&direct{value: i})
+ }
+
+ for i := b.N; i > 0; i-- {
+ tmp := 0
+ for e := l.Front(); e != nil; e = e.Next() {
+ tmp += e.value
+ }
+ }
+}
+
+func BenchmarkDirectIterateBackward(b *testing.B) {
+ var l directList
+ for i := 0; i < 1000000; i++ {
+ l.PushBack(&direct{value: i})
+ }
+
+ for i := b.N; i > 0; i-- {
+ tmp := 0
+ for e := l.Back(); e != nil; e = e.Prev() {
+ tmp += e.value
+ }
+ }
+}
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
new file mode 100644
index 000000000..f84d03700
--- /dev/null
+++ b/pkg/linewriter/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "linewriter",
+ srcs = ["linewriter.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
+)
+
+go_test(
+ name = "linewriter_test",
+ srcs = ["linewriter_test.go"],
+ library = ":linewriter",
+)
diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go
new file mode 100644
index 000000000..a1b1285d4
--- /dev/null
+++ b/pkg/linewriter/linewriter.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linewriter provides an io.Writer which calls an emitter on each line.
+package linewriter
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Writer is an io.Writer which buffers input, flushing
+// individual lines through an emitter function.
+type Writer struct {
+	// Mutex protects buf.
+ sync.Mutex
+
+ // buf holds the data we haven't emitted yet.
+ buf bytes.Buffer
+
+ // emit is used to flush individual lines.
+ emit func(p []byte)
+}
+
+// NewWriter creates a Writer which emits using emitter.
+// The emitter must not retain p. It may change after emitter returns.
+func NewWriter(emitter func(p []byte)) *Writer {
+ return &Writer{emit: emitter}
+}
+
+// Write implements io.Writer.Write.
+// It calls emit on each line of input, not including the newline.
+// Write may be called concurrently.
+func (w *Writer) Write(p []byte) (int, error) {
+ w.Lock()
+ defer w.Unlock()
+
+ total := 0
+ for len(p) > 0 {
+ emit := true
+ i := bytes.IndexByte(p, '\n')
+ if i < 0 {
+ // No newline, we will buffer everything.
+ i = len(p)
+ emit = false
+ }
+
+ n, err := w.buf.Write(p[:i])
+ if err != nil {
+ return total, err
+ }
+ total += n
+
+ p = p[i:]
+
+ if emit {
+ // Skip the newline, but still count it.
+ p = p[1:]
+ total++
+
+ w.emit(w.buf.Bytes())
+ w.buf.Reset()
+ }
+ }
+
+ return total, nil
+}
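
A usage sketch: the emitter receives one complete line per call, without the trailing newline, and must copy p if it retains it (the closure below only formats, so no copy is needed):

    package main

    import (
    	"fmt"

    	"gvisor.dev/gvisor/pkg/linewriter"
    )

    func main() {
    	w := linewriter.NewWriter(func(p []byte) {
    		fmt.Printf("line: %q\n", p) // %q formats a copy; p is not retained.
    	})
    	// "partial" has no newline yet, so it stays buffered.
    	w.Write([]byte("first\nsecond\npartial"))
    }
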
diff --git a/pkg/linewriter/linewriter_test.go b/pkg/linewriter/linewriter_test.go
new file mode 100644
index 000000000..96dc7e6e0
--- /dev/null
+++ b/pkg/linewriter/linewriter_test.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linewriter
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestWriter(t *testing.T) {
+ testCases := []struct {
+ input []string
+ want []string
+ }{
+ {
+ input: []string{"1\n", "2\n"},
+ want: []string{"1", "2"},
+ },
+ {
+ input: []string{"1\n", "\n", "2\n"},
+ want: []string{"1", "", "2"},
+ },
+ {
+ input: []string{"1\n2\n", "3\n"},
+ want: []string{"1", "2", "3"},
+ },
+ {
+ input: []string{"1", "2\n"},
+ want: []string{"12"},
+ },
+ {
+ // Data with no newline yet is omitted.
+ input: []string{"1\n", "2\n", "3"},
+ want: []string{"1", "2"},
+ },
+ }
+
+ for _, c := range testCases {
+ var lines [][]byte
+
+ w := NewWriter(func(p []byte) {
+ // We must not retain p, so we must make a copy.
+ b := make([]byte, len(p))
+ copy(b, p)
+
+ lines = append(lines, b)
+ })
+
+ for _, in := range c.input {
+ n, err := w.Write([]byte(in))
+ if err != nil {
+ t.Errorf("Write(%q) err got %v want nil (case %+v)", in, err, c)
+ }
+ if n != len(in) {
+ t.Errorf("Write(%q) b got %d want %d (case %+v)", in, n, len(in), c)
+ }
+ }
+
+ if len(lines) != len(c.want) {
+ t.Errorf("len(lines) got %d want %d (case %+v)", len(lines), len(c.want), c)
+ }
+
+ for i := range lines {
+ if !bytes.Equal(lines[i], []byte(c.want[i])) {
+ t.Errorf("item %d got %q want %q (case %+v)", i, lines[i], c.want[i], c)
+ }
+ }
+ }
+}
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
new file mode 100644
index 000000000..3ed6aba5c
--- /dev/null
+++ b/pkg/log/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "log",
+ srcs = [
+ "glog.go",
+ "json.go",
+ "json_k8s.go",
+ "log.go",
+ ],
+ marshal = False,
+ stateify = False,
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/linewriter",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "log_test",
+ size = "small",
+ srcs = [
+ "json_test.go",
+ "log_test.go",
+ ],
+ library = ":log",
+)
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
new file mode 100644
index 000000000..f57c4427b
--- /dev/null
+++ b/pkg/log/glog.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+ "strings"
+ "time"
+)
+
+// GoogleEmitter is a wrapper that emits logs in a format compatible with
+// package github.com/golang/glog.
+type GoogleEmitter struct {
+ *Writer
+}
+
+// pid is used for the threadid component of the header.
+var pid = os.Getpid()
+
+// Emit emits the message, google-style.
+//
+// Log lines have this form:
+// Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg...
+//
+// where the fields are defined as follows:
+// L A single character, representing the log level (eg 'I' for INFO)
+// mm The month (zero padded; ie May is '05')
+// dd The day (zero padded)
+// hh:mm:ss.uuuuuu Time in hours, minutes and fractional seconds
+// threadid The space-padded thread ID as returned by GetTID()
+// file The file name
+// line The line number
+// msg The user-supplied message
+//
+func (g GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
+ // Log level.
+ prefix := byte('?')
+ switch level {
+ case Debug:
+ prefix = byte('D')
+ case Info:
+ prefix = byte('I')
+ case Warning:
+ prefix = byte('W')
+ }
+
+ // Timestamp.
+ _, month, day := timestamp.Date()
+ hour, minute, second := timestamp.Clock()
+ microsecond := int(timestamp.Nanosecond() / 1000)
+
+ // 0 = this frame.
+ _, file, line, ok := runtime.Caller(depth + 1)
+ if ok {
+ // Trim any directory path from the file.
+ slash := strings.LastIndexByte(file, byte('/'))
+ if slash >= 0 {
+ file = file[slash+1:]
+ }
+ } else {
+ // We don't have a filename.
+ file = "???"
+ line = 0
+ }
+
+ // Generate the message.
+ message := fmt.Sprintf(format, args...)
+
+ // Emit the formatted result.
+ fmt.Fprintf(g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message)
+}
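
For instance, given the format string above, an Info message logged from main.go:42 by pid 1234 would be emitted roughly as:

    I0602 15:04:05.123456    1234 main.go:42] hello world
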
diff --git a/pkg/log/json.go b/pkg/log/json.go
new file mode 100644
index 000000000..bdf9d691e
--- /dev/null
+++ b/pkg/log/json.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+type jsonLog struct {
+ Msg string `json:"msg"`
+ Level Level `json:"level"`
+ Time time.Time `json:"time"`
+}
+
+// MarshalJSON implements json.Marshaler.MarshalJSON.
+func (lv Level) MarshalJSON() ([]byte, error) {
+ switch lv {
+ case Warning:
+ return []byte(`"warning"`), nil
+ case Info:
+ return []byte(`"info"`), nil
+ case Debug:
+ return []byte(`"debug"`), nil
+ default:
+ return nil, fmt.Errorf("unknown level %v", lv)
+ }
+}
+
+// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal
+// from both string names and integers.
+func (lv *Level) UnmarshalJSON(b []byte) error {
+ switch s := string(b); s {
+ case "0", `"warning"`:
+ *lv = Warning
+ case "1", `"info"`:
+ *lv = Info
+ case "2", `"debug"`:
+ *lv = Debug
+ default:
+ return fmt.Errorf("unknown level %q", s)
+ }
+ return nil
+}
+
+// JSONEmitter logs messages in json format.
+type JSONEmitter struct {
+ *Writer
+}
+
+// Emit implements Emitter.Emit.
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
+ j := jsonLog{
+ Msg: fmt.Sprintf(format, v...),
+ Level: level,
+ Time: timestamp,
+ }
+ b, err := json.Marshal(j)
+ if err != nil {
+ panic(err)
+ }
+ e.Writer.Write(b)
+}
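
Each message emitted through JSONEmitter is therefore a single JSON object, marshaled from jsonLog, roughly of the form:

    {"msg":"hello world","level":"info","time":"2020-06-02T15:04:05Z"}
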
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
new file mode 100644
index 000000000..5883e95e1
--- /dev/null
+++ b/pkg/log/json_k8s.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+type k8sJSONLog struct {
+ Log string `json:"log"`
+ Level Level `json:"level"`
+ Time time.Time `json:"time"`
+}
+
+// K8sJSONEmitter logs messages in json format that is compatible with
+// Kubernetes fluent configuration.
+type K8sJSONEmitter struct {
+ *Writer
+}
+
+// Emit implements Emitter.Emit.
+func (e K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
+ j := k8sJSONLog{
+ Log: fmt.Sprintf(format, v...),
+ Level: level,
+ Time: timestamp,
+ }
+ b, err := json.Marshal(j)
+ if err != nil {
+ panic(err)
+ }
+ e.Writer.Write(b)
+}
diff --git a/pkg/log/json_test.go b/pkg/log/json_test.go
new file mode 100644
index 000000000..f25224fe1
--- /dev/null
+++ b/pkg/log/json_test.go
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+// Tests that Level can marshal/unmarshal properly.
+func TestLevelMarshal(t *testing.T) {
+ lvs := []Level{Warning, Info, Debug}
+ for _, lv := range lvs {
+ bs, err := lv.MarshalJSON()
+ if err != nil {
+ t.Errorf("error marshaling %v: %v", lv, err)
+ }
+ var lv2 Level
+ if err := lv2.UnmarshalJSON(bs); err != nil {
+ t.Errorf("error unmarshaling %v: %v", bs, err)
+ }
+ if lv != lv2 {
+ t.Errorf("marshal/unmarshal level got %v wanted %v", lv2, lv)
+ }
+ }
+}
+
+// Test that integers can be properly unmarshaled.
+func TestUnmarshalFromInt(t *testing.T) {
+ tcs := []struct {
+ i int
+ want Level
+ }{
+ {0, Warning},
+ {1, Info},
+ {2, Debug},
+ }
+
+ for _, tc := range tcs {
+ j, err := json.Marshal(tc.i)
+ if err != nil {
+ t.Errorf("error marshaling %v: %v", tc.i, err)
+ }
+ var lv Level
+ if err := lv.UnmarshalJSON(j); err != nil {
+ t.Errorf("error unmarshaling %v: %v", j, err)
+ }
+ if lv != tc.want {
+ t.Errorf("marshal/unmarshal %v got %v want %v", tc.i, lv, tc.want)
+ }
+ }
+}
diff --git a/pkg/log/log.go b/pkg/log/log.go
new file mode 100644
index 000000000..37e0605ad
--- /dev/null
+++ b/pkg/log/log.go
@@ -0,0 +1,378 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package log implements a library for logging.
+//
+// This is separate from the standard logging package because logging may be a
+// high-impact activity, and therefore we wanted to provide as much flexibility
+// as possible in the underlying implementation.
+//
+// Note that logging should still be considered high-impact, and should not be
+// done in the hot path. If necessary, logging statements should be protected
+// with guards regarding the logging level. For example,
+//
+// if log.IsLogging(log.Debug) {
+// log.Debugf(...)
+// }
+//
+// This is because the log.Debugf(...) statement alone will generate a
+// significant amount of garbage and churn in many cases, even if no log
+// message is ultimately emitted.
+package log
+
+import (
+ "fmt"
+ "io"
+ stdlog "log"
+ "os"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/linewriter"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Level is the log level.
+type Level uint32
+
+// The following levels are fixed, and can never be changed. Since some control
+// RPCs allow for changing the level as an integer, it is only possible to add
+// additional levels, and the existing ones cannot be removed.
+const (
+ // Warning indicates that output should always be emitted.
+ Warning Level = iota
+
+ // Info indicates that output should normally be emitted.
+ Info
+
+ // Debug indicates that output should not normally be emitted.
+ Debug
+)
+
+func (l Level) String() string {
+ switch l {
+ case Warning:
+ return "Warning"
+ case Info:
+ return "Info"
+ case Debug:
+ return "Debug"
+ default:
+ return fmt.Sprintf("Invalid level: %d", l)
+ }
+}
+
+// Emitter is the final destination for logs.
+type Emitter interface {
+ // Emit emits the given log statement. This allows for control over the
+ // timestamp used for logging.
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
+}
+
+// Writer writes the output to the given writer.
+type Writer struct {
+ // Next is where output is written.
+ Next io.Writer
+
+ // mu protects fields below.
+ mu sync.Mutex
+
+	// errors counts failures to write log messages so they can be
+	// reported when the writer starts working again. It must be accessed
+	// using atomics to keep the race detector happy because it's read
+	// outside the mutex.
+ errors int32
+}
+
+// Write writes out the given bytes, handling non-blocking sockets.
+func (l *Writer) Write(data []byte) (int, error) {
+ n := 0
+
+ for n < len(data) {
+ w, err := l.Next.Write(data[n:])
+ n += w
+
+ // Is it a non-blocking socket?
+ if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == syscall.EAGAIN {
+ runtime.Gosched()
+ continue
+ }
+
+ // Some other error?
+ if err != nil {
+ l.mu.Lock()
+ atomic.AddInt32(&l.errors, 1)
+ l.mu.Unlock()
+ return n, err
+ }
+ }
+
+ // Do we need to end with a '\n'?
+ if len(data) == 0 || data[len(data)-1] != '\n' {
+ l.Write([]byte{'\n'})
+ }
+
+ // Dirty read in case there were errors (rare).
+ if atomic.LoadInt32(&l.errors) > 0 {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ // Recheck condition under lock.
+ if e := atomic.LoadInt32(&l.errors); e > 0 {
+ msg := fmt.Sprintf("\n*** Dropped %d log messages ***\n", e)
+ if _, err := l.Next.Write([]byte(msg)); err == nil {
+ atomic.StoreInt32(&l.errors, 0)
+ }
+ }
+ }
+
+ return n, nil
+}
+
+// Emit emits the message.
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
+ fmt.Fprintf(l, format, args...)
+}
+
+// MultiEmitter is an emitter that emits to multiple Emitters.
+type MultiEmitter []Emitter
+
+// Emit emits to all emitters.
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
+ for _, e := range *m {
+ e.Emit(1+depth, level, timestamp, format, v...)
+ }
+}
+
+// TestLogger is implemented by testing.T and testing.B.
+type TestLogger interface {
+ Logf(format string, v ...interface{})
+}
+
+// TestEmitter may be used for wrapping tests.
+type TestEmitter struct {
+ TestLogger
+}
+
+// Emit emits to the TestLogger.
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
+ t.Logf(format, v...)
+}
+
+// Logger is a high-level logging interface. It is, in fact, not used within
+// the log package. Rather, it is provided for others to implement contextual
+// loggers that may append additional information to log statements.
+// BasicLogger satisfies this interface, and may be passed around as a Logger.
+type Logger interface {
+ // Debugf logs a debug statement.
+ Debugf(format string, v ...interface{})
+
+ // Infof logs at an info level.
+ Infof(format string, v ...interface{})
+
+ // Warningf logs at a warning level.
+ Warningf(format string, v ...interface{})
+
+ // IsLogging returns true iff this level is being logged. This may be
+ // used to short-circuit expensive operations for debugging calls.
+ IsLogging(level Level) bool
+}
+
+// BasicLogger is the default implementation of Logger.
+type BasicLogger struct {
+ Level
+ Emitter
+}
+
+// Debugf implements logger.Debugf.
+func (l *BasicLogger) Debugf(format string, v ...interface{}) {
+ l.DebugfAtDepth(1, format, v...)
+}
+
+// Infof implements logger.Infof.
+func (l *BasicLogger) Infof(format string, v ...interface{}) {
+ l.InfofAtDepth(1, format, v...)
+}
+
+// Warningf implements logger.Warningf.
+func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Warning) {
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
+ }
+}
+
+// IsLogging implements logger.IsLogging.
+func (l *BasicLogger) IsLogging(level Level) bool {
+ return atomic.LoadUint32((*uint32)(&l.Level)) >= uint32(level)
+}
+
+// SetLevel sets the logging level.
+func (l *BasicLogger) SetLevel(level Level) {
+ atomic.StoreUint32((*uint32)(&l.Level), uint32(level))
+}
+
+// logMu protects Log below. We use atomic operations to read the value, but
+// updates require logMu to ensure consistency.
+var logMu sync.Mutex
+
+// log is the default logger.
+var log atomic.Value
+
+// Log retrieves the global logger.
+func Log() *BasicLogger {
+ return log.Load().(*BasicLogger)
+}
+
+// SetTarget sets the log target.
+//
+// This is not thread safe and shouldn't be called concurrently with any
+// logging calls.
+func SetTarget(target Emitter) {
+ logMu.Lock()
+ defer logMu.Unlock()
+ oldLog := Log()
+ log.Store(&BasicLogger{Level: oldLog.Level, Emitter: target})
+}
+
+// SetLevel sets the log level.
+func SetLevel(newLevel Level) {
+ Log().SetLevel(newLevel)
+}
+
+// Debugf logs to the global logger.
+func Debugf(format string, v ...interface{}) {
+ Log().DebugfAtDepth(1, format, v...)
+}
+
+// Infof logs to the global logger.
+func Infof(format string, v ...interface{}) {
+ Log().InfofAtDepth(1, format, v...)
+}
+
+// Warningf logs to the global logger.
+func Warningf(format string, v ...interface{}) {
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
+}
+
+// defaultStackSize is the default buffer size to allocate for stack traces.
+const defaultStackSize = 1 << 16 // 64KB
+
+// maxStackSize is the maximum buffer size to allocate for stack traces.
+const maxStackSize = 1 << 26 // 64MB
+
+// Stacks returns goroutine stacks, like panic.
+func Stacks(all bool) []byte {
+ var trace []byte
+ for s := defaultStackSize; s <= maxStackSize; s *= 4 {
+ trace = make([]byte, s)
+ nbytes := runtime.Stack(trace, all)
+ if nbytes == s {
+ continue
+ }
+ return trace[:nbytes]
+ }
+ trace = append(trace, []byte("\n\n...<too large, truncated>")...)
+ return trace
+}
+
+// Traceback logs the given message and dumps a stacktrace of the current
+// goroutine.
+//
+// This will print a traceback, tb, as Warningf(format+":\n%s", v..., tb).
+func Traceback(format string, v ...interface{}) {
+ v = append(v, Stacks(false))
+ Warningf(format+":\n%s", v...)
+}
+
+// TracebackAll logs the given message and dumps a stacktrace of all goroutines.
+//
+// This will print a traceback, tb, as Warningf(format+":\n%s", v..., tb).
+func TracebackAll(format string, v ...interface{}) {
+ v = append(v, Stacks(true))
+ Warningf(format+":\n%s", v...)
+}
+
+// IsLogging returns whether the global logger is logging.
+func IsLogging(level Level) bool {
+ return Log().IsLogging(level)
+}
+
+// CopyStandardLogTo redirects the stdlib log package global output to the global
+// logger for the specified level.
+func CopyStandardLogTo(l Level) error {
+ var f func(string, ...interface{})
+
+ switch l {
+ case Debug:
+ f = Debugf
+ case Info:
+ f = Infof
+ case Warning:
+ f = Warningf
+ default:
+ return fmt.Errorf("Unknown log level %v", l)
+ }
+
+ stdlog.SetOutput(linewriter.NewWriter(func(p []byte) {
+ // We must not retain p, but log formatting is not required to
+ // be synchronous (though the in-package implementations are),
+ // so we must make a copy.
+ b := make([]byte, len(p))
+ copy(b, p)
+
+ f("%s", b)
+ }))
+
+ return nil
+}
+
+func init() {
+ // Store the initial value for the log.
+ log.Store(&BasicLogger{Level: Info, Emitter: GoogleEmitter{&Writer{Next: os.Stderr}}})
+}
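
Tying the pieces together, a hedged sketch of typical use of the global logger, including the level guard the package comment recommends (expensiveDump is a hypothetical helper, not part of the package):

    package main

    import "gvisor.dev/gvisor/pkg/log"

    func expensiveDump() string { return "..." } // hypothetical helper

    func main() {
    	log.SetLevel(log.Debug)
    	// Guard costly argument construction so it is skipped entirely
    	// when the debug level is disabled.
    	if log.IsLogging(log.Debug) {
    		log.Debugf("state: %s", expensiveDump())
    	}
    	log.Warningf("always emitted")
    }
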
diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go
new file mode 100644
index 000000000..9ff18559b
--- /dev/null
+++ b/pkg/log/log_test.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+)
+
+type testWriter struct {
+ lines []string
+ fail bool
+ limit int
+}
+
+func (w *testWriter) Write(bytes []byte) (int, error) {
+ if w.fail {
+ return 0, fmt.Errorf("simulated failure")
+ }
+ if w.limit > 0 && len(w.lines) >= w.limit {
+ return len(bytes), nil
+ }
+ w.lines = append(w.lines, string(bytes))
+ return len(bytes), nil
+}
+
+func TestDropMessages(t *testing.T) {
+ tw := &testWriter{}
+ w := Writer{Next: tw}
+ if _, err := w.Write([]byte("line 1\n")); err != nil {
+ t.Fatalf("Write failed, err: %v", err)
+ }
+
+ tw.fail = true
+ if _, err := w.Write([]byte("error\n")); err == nil {
+ t.Fatalf("Write should have failed")
+ }
+ if _, err := w.Write([]byte("error\n")); err == nil {
+ t.Fatalf("Write should have failed")
+ }
+
+ tw.fail = false
+ if _, err := w.Write([]byte("line 2\n")); err != nil {
+ t.Fatalf("Write failed, err: %v", err)
+ }
+
+	expected := []string{
+		"line 1\n",
+		"line 2\n",
+		"\n*** Dropped 2 log messages ***\n",
+	}
+ if len(tw.lines) != len(expected) {
+ t.Fatalf("Writer should have logged %d lines, got: %v, expected: %v", len(expected), tw.lines, expected)
+ }
+ for i, l := range tw.lines {
+		if l != expected[i] {
+ t.Fatalf("line %d doesn't match, got: %v, expected: %v", i, l, expected[i])
+ }
+ }
+}
+
+func TestCaller(t *testing.T) {
+ tw := &testWriter{}
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
+ bl := &BasicLogger{
+ Emitter: e,
+ Level: Debug,
+ }
+ bl.Debugf("testing...\n") // Just for file + line.
+ if len(tw.lines) != 1 {
+ t.Errorf("expected 1 line, got %d", len(tw.lines))
+ }
+ if !strings.Contains(tw.lines[0], "log_test.go") {
+ t.Errorf("expected log_test.go, got %q", tw.lines[0])
+ }
+}
+
+func BenchmarkGoogleLogging(b *testing.B) {
+ tw := &testWriter{
+ limit: 1, // Only record one message.
+ }
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
+ bl := &BasicLogger{
+ Emitter: e,
+ Level: Debug,
+ }
+ for i := 0; i < b.N; i++ {
+ bl.Debugf("hello %d, %d, %d", 1, 2, 3)
+ }
+}
diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD
new file mode 100644
index 000000000..9d07d98b4
--- /dev/null
+++ b/pkg/memutil/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memutil",
+ srcs = ["memutil_unsafe.go"],
+ visibility = ["//visibility:public"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/memutil/memutil_unsafe.go b/pkg/memutil/memutil_unsafe.go
new file mode 100644
index 000000000..979d942a9
--- /dev/null
+++ b/pkg/memutil/memutil_unsafe.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package memutil provides a wrapper for the memfd_create() system call.
+package memutil
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// CreateMemFD creates a memfd file and returns the fd.
+func CreateMemFD(name string, flags int) (int, error) {
+ p, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return -1, err
+ }
+ fd, _, e := syscall.Syscall(unix.SYS_MEMFD_CREATE, uintptr(unsafe.Pointer(p)), uintptr(flags), 0)
+ if e != 0 {
+ if e == syscall.ENOSYS {
+ return -1, fmt.Errorf("memfd_create(2) is not implemented. Check that you have Linux 3.17 or higher")
+ }
+ return -1, e
+ }
+ return int(fd), nil
+}
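
A minimal sketch of calling the wrapper; the flags value 0 is illustrative (callers would typically pass MFD_* constants from golang.org/x/sys/unix):

    package main

    import (
    	"fmt"
    	"os"

    	"gvisor.dev/gvisor/pkg/memutil"
    )

    func main() {
    	fd, err := memutil.CreateMemFD("example", 0)
    	if err != nil {
    		panic(err)
    	}
    	// Wrap the raw fd so it participates in the normal file lifecycle.
    	f := os.NewFile(uintptr(fd), "example")
    	defer f.Close()
    	fmt.Fprintln(f, "backed by anonymous memory")
    }
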
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
new file mode 100644
index 000000000..5b0e4143a
--- /dev/null
+++ b/pkg/merkletree/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "merkletree",
+ srcs = ["merkletree.go"],
+ deps = ["//pkg/usermem"],
+)
+
+go_test(
+ name = "merkletree_test",
+ srcs = ["merkletree_test.go"],
+ library = ":merkletree",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
new file mode 100644
index 000000000..906f67943
--- /dev/null
+++ b/pkg/merkletree/merkletree.go
@@ -0,0 +1,135 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package merkletree implements Merkle tree generating and verification.
+package merkletree
+
+import (
+ "crypto/sha256"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // sha256DigestSize specifies the digest size of a SHA256 hash.
+ sha256DigestSize = 32
+)
+
+// Size defines the scale of a Merkle tree.
+type Size struct {
+ // blockSize is the size of a data block to be hashed.
+ blockSize int64
+ // digestSize is the size of a generated hash.
+ digestSize int64
+ // hashesPerBlock is the number of hashes in a block. For example, if
+ // blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128
+ // hashesPerBlock. Therefore 128 hashes in a lower level will be put into a
+ // block and generate a single hash in an upper level.
+ hashesPerBlock int64
+ // levelStart is the start block index of each level. The number of levels in
+ // the tree is the length of the slice. The leafs (level 0) are hashes of
+ // blocks in the input data. The levels above are hashes of lower level
+ // hashes. The highest level is the root hash.
+ levelStart []int64
+}
+
+// MakeSize initializes and returns a new Size object describing the structure
+// of a tree. dataSize specifies the size of the data to be hashed, in bytes.
+func MakeSize(dataSize int64) Size {
+ size := Size{
+ blockSize: usermem.PageSize,
+		// TODO(b/156980949): Allow configuring other hash methods (SHA384/SHA512).
+ digestSize: sha256DigestSize,
+ hashesPerBlock: usermem.PageSize / sha256DigestSize,
+ }
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+ level := int64(0)
+ offset := int64(0)
+
+	// Calculate the number of levels in the Merkle tree and the beginning offset
+ // of each level. Level 0 is the level directly above the data blocks, while
+ // level NumLevels - 1 is the root.
+ for numBlocks > 1 {
+ size.levelStart = append(size.levelStart, offset)
+ // Round numBlocks up to fill up a block.
+ numBlocks += (size.hashesPerBlock - numBlocks%size.hashesPerBlock) % size.hashesPerBlock
+ offset += numBlocks / size.hashesPerBlock
+ numBlocks = numBlocks / size.hashesPerBlock
+ level++
+ }
+ size.levelStart = append(size.levelStart, offset)
+ return size
+}
+
+// Generate constructs a Merkle tree for the contents of data. The output is
+// written to treeWriter. The treeReader should be able to read the tree after
+// it has been written. That is, treeWriter and treeReader should point to the
+// same underlying data but have separate cursors.
+func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) {
+ size := MakeSize(dataSize)
+
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+
+ var root []byte
+ for level := 0; level < len(size.levelStart); level++ {
+ for i := int64(0); i < numBlocks; i++ {
+ buf := make([]byte, size.blockSize)
+ var (
+ n int
+ err error
+ )
+ if level == 0 {
+ // Read data block from the target file since level 0 is directly above
+ // the raw data block.
+ n, err = data.Read(buf)
+ } else {
+ // Read data block from the tree file since levels higher than 0 are
+ // hashing the lower level hashes.
+ n, err = treeReader.Read(buf)
+ }
+
+			// err is populated when the number of bytes read is smaller than
+			// the buffer size. This can happen when reading the last block;
+			// break once there is nothing left to read. Since buf is
+			// zero-initialized, a short final block is implicitly zero-padded.
+ if n == 0 && err == io.EOF {
+ break
+ } else if err != nil && err != io.EOF {
+ return nil, err
+ }
+ // Hash the bytes in buf.
+ digest := sha256.Sum256(buf)
+
+ if level == len(size.levelStart)-1 {
+ root = digest[:]
+ }
+
+ // Write the generated hash to the end of the tree file.
+ if _, err = treeWriter.Write(digest[:]); err != nil {
+ return nil, err
+ }
+ }
+		// If the generated digests do not fill a whole block, zero-pad the
+		// remainder of the last block. This is not needed for the root.
+ if level != len(size.levelStart)-1 && numBlocks%size.hashesPerBlock != 0 {
+ zeroBuf := make([]byte, size.blockSize-(numBlocks%size.hashesPerBlock)*size.digestSize)
+ if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ return nil, err
+ }
+ }
+ numBlocks = (numBlocks + size.hashesPerBlock - 1) / size.hashesPerBlock
+ }
+ return root, nil
+}
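+
+// Example use (an illustrative sketch, not part of this change; assumes
+// imports "bytes" and "fmt"): build a tree over an in-memory buffer and
+// print its root hash. Generate writes every level to treeWriter and reads
+// completed lower levels back through treeReader, so one bytes.Buffer can
+// serve as both, as the tests below do.
+//
+//	var tree bytes.Buffer
+//	data := bytes.NewReader(make([]byte, 3*usermem.PageSize))
+//	root, err := Generate(data, 3*usermem.PageSize, &tree, &tree)
+//	if err != nil {
+//		return err
+//	}
+//	fmt.Printf("root hash: %x\n", root)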
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
new file mode 100644
index 000000000..7344db0b6
--- /dev/null
+++ b/pkg/merkletree/merkletree_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package merkletree
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSize(t *testing.T) {
+ testCases := []struct {
+ dataSize int64
+ expectedLevelStart []int64
+ }{
+ {
+ dataSize: 100,
+ expectedLevelStart: []int64{0},
+ },
+ {
+ dataSize: 1000000,
+ expectedLevelStart: []int64{0, 2, 3},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ expectedLevelStart: []int64{0, 32, 33},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ s := MakeSize(tc.dataSize)
+ if s.blockSize != int64(usermem.PageSize) {
+ t.Errorf("got blockSize %d, want %d", s.blockSize, usermem.PageSize)
+ }
+ if s.digestSize != sha256DigestSize {
+ t.Errorf("got digestSize %d, want %d", s.digestSize, sha256DigestSize)
+ }
+ if len(s.levelStart) != len(tc.expectedLevelStart) {
+ t.Errorf("got levels %d, want %d", len(s.levelStart), len(tc.expectedLevelStart))
+ }
+ for i := 0; i < len(s.levelStart) && i < len(tc.expectedLevelStart); i++ {
+ if s.levelStart[i] != tc.expectedLevelStart[i] {
+ t.Errorf("got levelStart[%d] %d, want %d", i, s.levelStart[i], tc.expectedLevelStart[i])
+ }
+ }
+ })
+ }
+}
+
+func TestGenerate(t *testing.T) {
+ // The input data has size dataSize. It starts with the data in startWith,
+ // and all other bytes are zeroes.
+ testCases := []struct {
+ dataSize int
+ startWith []byte
+ expectedRoot []byte
+ }{
+ {
+ dataSize: usermem.PageSize,
+ startWith: nil,
+ expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ },
+ {
+ dataSize: 128*usermem.PageSize + 1,
+ startWith: nil,
+ expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'a'},
+ expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'1'},
+ expectedRoot: []byte{74, 35, 103, 179, 176, 149, 254, 112, 42, 65, 104, 66, 119, 56, 133, 124, 228, 15, 65, 161, 150, 0, 117, 174, 242, 34, 115, 115, 218, 37, 3, 105},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ var (
+ data bytes.Buffer
+ tree bytes.Buffer
+ )
+
+ startSize := len(tc.startWith)
+ _, err := data.Write(tc.startWith)
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+ _, err = data.Write(make([]byte, tc.dataSize-startSize))
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+
+ root, err := Generate(&data, int64(tc.dataSize), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ if !bytes.Equal(root, tc.expectedRoot) {
+ t.Errorf("Unexpected root")
+ }
+ })
+ }
+}
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
new file mode 100644
index 000000000..58305009d
--- /dev/null
+++ b/pkg/metric/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "metric",
+ srcs = ["metric.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ ":metric_go_proto",
+ "//pkg/eventchannel",
+ "//pkg/log",
+ "//pkg/sync",
+ ],
+)
+
+proto_library(
+ name = "metric",
+ srcs = ["metric.proto"],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "metric_test",
+ srcs = ["metric_test.go"],
+ library = ":metric",
+ deps = [
+ ":metric_go_proto",
+ "//pkg/eventchannel",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
new file mode 100644
index 000000000..64aa365ce
--- /dev/null
+++ b/pkg/metric/metric.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package metric provides primitives for collecting metrics.
+package metric
+
+import (
+ "errors"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/log"
+ pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var (
+ // ErrNameInUse indicates that another metric is already defined for
+ // the given name.
+ ErrNameInUse = errors.New("metric name already in use")
+
+ // ErrInitializationDone indicates that the caller tried to create a
+ // new metric after initialization.
+ ErrInitializationDone = errors.New("metric cannot be created after initialization is complete")
+)
+
+// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
+// monitored.
+//
+// Metrics are not saved across save/restore and thus reset to zero on restore.
+//
+// TODO(b/67298427): Support metric fields.
+type Uint64Metric struct {
+ // value is the actual value of the metric. It must be accessed atomically.
+ value uint64
+}
+
+var (
+ // initialized indicates that all metrics are registered. allMetrics is
+ // immutable once initialized is true.
+ initialized bool
+
+ // allMetrics are the registered metrics.
+ allMetrics = makeMetricSet()
+)
+
+// Initialize sends a metric registration event over the event channel.
+//
+// Precondition:
+// * All metrics are registered.
+// * Initialize/Disable has not been called.
+func Initialize() {
+ if initialized {
+ panic("Initialize/Disable called more than once")
+ }
+ initialized = true
+
+ m := pb.MetricRegistration{}
+ for _, v := range allMetrics.m {
+ m.Metrics = append(m.Metrics, v.metadata)
+ }
+ eventchannel.Emit(&m)
+}
+
+// Disable sends an empty metric registration event over the event channel,
+// disabling metric collection.
+//
+// Precondition:
+// * All metrics are registered.
+// * Initialize/Disable has not been called.
+func Disable() {
+ if initialized {
+ panic("Initialize/Disable called more than once")
+ }
+ initialized = true
+
+ m := pb.MetricRegistration{}
+ if err := eventchannel.Emit(&m); err != nil {
+ panic("unable to emit metric disable event: " + err.Error())
+ }
+}
+
+type customUint64Metric struct {
+ // metadata describes the metric. It is immutable.
+ metadata *pb.MetricMetadata
+
+ // value returns the current value of the metric.
+ value func() uint64
+}
+
+// RegisterCustomUint64Metric registers a metric with the given name.
+//
+// RegisterCustomUint64Metric must only be called at init and will return an
+// error if called after Initialize.
+//
+// Preconditions:
+// * name must be globally unique.
+// * Initialize/Disable have not been called.
+func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error {
+ if initialized {
+ return ErrInitializationDone
+ }
+
+ if _, ok := allMetrics.m[name]; ok {
+ return ErrNameInUse
+ }
+
+ allMetrics.m[name] = customUint64Metric{
+ metadata: &pb.MetricMetadata{
+ Name: name,
+ Description: description,
+ Cumulative: cumulative,
+ Sync: sync,
+ Type: pb.MetricMetadata_TYPE_UINT64,
+ Units: units,
+ },
+ value: value,
+ }
+ return nil
+}
+
+// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics
+// if it returns an error.
+func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) {
+ if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil {
+ panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
+ }
+}
+
+// NewUint64Metric creates and registers a new cumulative metric with the given name.
+//
+// Metrics must be statically defined (i.e., at init).
+func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) {
+ var m Uint64Metric
+ return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value)
+}
+
+// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an error.
+func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric {
+ m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description)
+ if err != nil {
+ panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ }
+ return m
+}
+
+// MustCreateNewUint64NanosecondsMetric calls NewUint64Metric and panics if it returns an error.
+func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric {
+ m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description)
+ if err != nil {
+ panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ }
+ return m
+}
+
+// Value returns the current value of the metric.
+func (m *Uint64Metric) Value() uint64 {
+ return atomic.LoadUint64(&m.value)
+}
+
+// Increment increments the metric by 1.
+func (m *Uint64Metric) Increment() {
+ atomic.AddUint64(&m.value, 1)
+}
+
+// IncrementBy increments the metric by v.
+func (m *Uint64Metric) IncrementBy(v uint64) {
+ atomic.AddUint64(&m.value, v)
+}
+
+// metricSet holds named metrics.
+type metricSet struct {
+ m map[string]customUint64Metric
+}
+
+// makeMetricSet returns a new metricSet.
+func makeMetricSet() metricSet {
+ return metricSet{
+ m: make(map[string]customUint64Metric),
+ }
+}
+
+// Values returns a snapshot of all values in m.
+func (m *metricSet) Values() metricValues {
+ vals := make(metricValues)
+ for k, v := range m.m {
+ vals[k] = v.value()
+ }
+ return vals
+}
+
+// metricValues contains a copy of the values of all metrics.
+type metricValues map[string]uint64
+
+var (
+ // emitMu protects metricsAtLastEmit and ensures that all emitted
+ // metrics are strongly ordered (older metrics are never emitted after
+ // newer metrics).
+ emitMu sync.Mutex
+
+ // metricsAtLastEmit contains the state of the metrics at the last emit event.
+ metricsAtLastEmit metricValues
+)
+
+// EmitMetricUpdate emits a MetricUpdate over the event channel.
+//
+// Only metrics that have changed since the last call are emitted.
+//
+// EmitMetricUpdate is thread-safe.
+//
+// Preconditions:
+// * Initialize has been called.
+func EmitMetricUpdate() {
+ emitMu.Lock()
+ defer emitMu.Unlock()
+
+ snapshot := allMetrics.Values()
+
+ m := pb.MetricUpdate{}
+ for k, v := range snapshot {
+ // On the first call metricsAtLastEmit will be empty. Include
+ // all metrics then.
+ if prev, ok := metricsAtLastEmit[k]; !ok || prev != v {
+ m.Metrics = append(m.Metrics, &pb.MetricValue{
+ Name: k,
+ Value: &pb.MetricValue_Uint64Value{v},
+ })
+ }
+ }
+
+ metricsAtLastEmit = snapshot
+ if len(m.Metrics) == 0 {
+ return
+ }
+
+ log.Debugf("Emitting metrics: %v", &m)
+ eventchannel.Emit(&m)
+}
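+
+// Example use (an illustrative sketch, not part of this change; assumes the
+// "time" package): metrics are registered during init, sealed by Initialize,
+// and emitted periodically. The metric name "/example/count" and the emit
+// interval are hypothetical.
+//
+//	var exampleCount = MustCreateNewUint64Metric(
+//		"/example/count", false /* sync */, "Number of example events.")
+//
+//	func emitLoop() {
+//		Initialize() // No metrics may be registered after this point.
+//		for {
+//			exampleCount.Increment()
+//			EmitMetricUpdate() // Emits only the metrics that changed.
+//			time.Sleep(time.Second)
+//		}
+//	}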
diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto
new file mode 100644
index 000000000..3cc89047d
--- /dev/null
+++ b/pkg/metric/metric.proto
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+// MetricMetadata contains all of the metadata describing a single metric.
+message MetricMetadata {
+ // name is the unique name of the metric, usually in a "directory" format
+ // (e.g., /foo/count).
+ string name = 1;
+
+ // description is a human-readable description of the metric.
+ string description = 2;
+
+ // cumulative indicates that this metric is never decremented.
+ bool cumulative = 3;
+
+ // sync indicates that values from the final metric event should be
+ // synchronized to the backing monitoring system at exit.
+ //
+ // If sync is false, values are only sent to the monitoring system
+ // periodically. There is no guarantee that values will ever be received by
+ // the monitoring system.
+ bool sync = 4;
+
+ enum Type { TYPE_UINT64 = 0; }
+
+ // type is the type of the metric value.
+ Type type = 5;
+
+ enum Units {
+ UNITS_NONE = 0;
+ UNITS_NANOSECONDS = 1;
+ }
+
+ // units is the units of the metric value.
+ Units units = 6;
+}
+
+// MetricRegistration contains the metadata for all metrics that will be in
+// future MetricUpdates.
+message MetricRegistration {
+ repeated MetricMetadata metrics = 1;
+}
+
+// MetricValue is the value of a metric at a single point in time.
+message MetricValue {
+ // name is the unique name of the metric, as in MetricMetadata.
+ string name = 1;
+
+ // value is the value of the metric at a single point in time. The field set
+ // depends on the type of the metric.
+ oneof value {
+ uint64 uint64_value = 2;
+ }
+}
+
+// MetricUpdate contains new values for multiple distinct metrics.
+//
+// Metrics whose values have not changed are not included.
+message MetricUpdate {
+ repeated MetricValue metrics = 1;
+}
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
new file mode 100644
index 000000000..c425ea532
--- /dev/null
+++ b/pkg/metric/metric_test.go
@@ -0,0 +1,258 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package metric
+
+import (
+ "testing"
+
+ "github.com/golang/protobuf/proto"
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto"
+)
+
+// sliceEmitter implements eventchannel.Emitter by appending all messages to a
+// slice.
+type sliceEmitter []proto.Message
+
+// Emit implements eventchannel.Emitter.Emit.
+func (s *sliceEmitter) Emit(msg proto.Message) (bool, error) {
+ *s = append(*s, msg)
+ return false, nil
+}
+
+// Close implements eventchannel.Emitter.Close.
+func (s *sliceEmitter) Close() error {
+ return nil
+}
+
+// Reset clears all events in s.
+func (s *sliceEmitter) Reset() {
+ *s = nil
+}
+
+// emitter is the eventchannel.Emitter used for all tests. Package eventchannel
+// doesn't allow removing Emitters, so we must use one global emitter for all
+// test cases.
+var emitter sliceEmitter
+
+func init() {
+ eventchannel.AddEmitter(&emitter)
+}
+
+// reset clears all global state in the metric package.
+func reset() {
+ initialized = false
+ allMetrics = makeMetricSet()
+ emitter.Reset()
+}
+
+const (
+ fooDescription = "Foo!"
+ barDescription = "Bar Baz"
+)
+
+func TestInitialize(t *testing.T) {
+ defer reset()
+
+ _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NANOSECONDS, barDescription)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ Initialize()
+
+ if len(emitter) != 1 {
+ t.Fatalf("Initialize emitted %d events want 1", len(emitter))
+ }
+
+ mr, ok := emitter[0].(*pb.MetricRegistration)
+ if !ok {
+ t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0])
+ }
+
+ if len(mr.Metrics) != 2 {
+ t.Errorf("MetricRegistration got %d metrics want 2", len(mr.Metrics))
+ }
+
+ foundFoo := false
+ foundBar := false
+ for _, m := range mr.Metrics {
+ if m.Type != pb.MetricMetadata_TYPE_UINT64 {
+ t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64)
+ }
+ if !m.Cumulative {
+ t.Errorf("Metadata %+v Cumulative got false want true", m)
+ }
+
+ switch m.Name {
+ case "/foo":
+ foundFoo = true
+ if m.Description != fooDescription {
+ t.Errorf("/foo %+v Description got %q want %q", m, m.Description, fooDescription)
+ }
+ if m.Sync {
+ t.Errorf("/foo %+v Sync got true want false", m)
+ }
+ if m.Units != pb.MetricMetadata_UNITS_NONE {
+ t.Errorf("/foo %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NONE)
+ }
+ case "/bar":
+ foundBar = true
+ if m.Description != barDescription {
+ t.Errorf("/bar %+v Description got %q want %q", m, m.Description, barDescription)
+ }
+ if !m.Sync {
+ t.Errorf("/bar %+v Sync got true want false", m)
+ }
+ if m.Units != pb.MetricMetadata_UNITS_NANOSECONDS {
+ t.Errorf("/bar %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NANOSECONDS)
+ }
+ }
+ }
+
+ if !foundFoo {
+ t.Errorf("/foo not found: %+v", emitter)
+ }
+ if !foundBar {
+ t.Errorf("/bar not found: %+v", emitter)
+ }
+}
+
+func TestDisable(t *testing.T) {
+ defer reset()
+
+ _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ Disable()
+
+ if len(emitter) != 1 {
+ t.Fatalf("Initialize emitted %d events want 1", len(emitter))
+ }
+
+ mr, ok := emitter[0].(*pb.MetricRegistration)
+ if !ok {
+ t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0])
+ }
+
+ if len(mr.Metrics) != 0 {
+ t.Errorf("MetricRegistration got %d metrics want 0", len(mr.Metrics))
+ }
+}
+
+func TestEmitMetricUpdate(t *testing.T) {
+ defer reset()
+
+ foo, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ Initialize()
+
+ // Don't care about the registration metrics.
+ emitter.Reset()
+ EmitMetricUpdate()
+
+ if len(emitter) != 1 {
+ t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter))
+ }
+
+ update, ok := emitter[0].(*pb.MetricUpdate)
+ if !ok {
+ t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0])
+ }
+
+ if len(update.Metrics) != 2 {
+ t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics))
+ }
+
+ // Both are included for their initial values.
+ foundFoo := false
+ foundBar := false
+ for _, m := range update.Metrics {
+ switch m.Name {
+ case "/foo":
+ foundFoo = true
+ case "/bar":
+ foundBar = true
+ }
+ uv, ok := m.Value.(*pb.MetricValue_Uint64Value)
+ if !ok {
+ t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value)
+ continue
+ }
+ if uv.Uint64Value != 0 {
+ t.Errorf("%v: Value got %v want 0", m, uv.Uint64Value)
+ }
+ }
+
+ if !foundFoo {
+ t.Errorf("/foo not found: %+v", emitter)
+ }
+ if !foundBar {
+ t.Errorf("/bar not found: %+v", emitter)
+ }
+
+ // Increment foo. Only it is included in the next update.
+ foo.Increment()
+
+ emitter.Reset()
+ EmitMetricUpdate()
+
+ if len(emitter) != 1 {
+ t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter))
+ }
+
+ update, ok = emitter[0].(*pb.MetricUpdate)
+ if !ok {
+ t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0])
+ }
+
+ if len(update.Metrics) != 1 {
+ t.Errorf("MetricUpdate got %d metrics want 1", len(update.Metrics))
+ }
+
+ m := update.Metrics[0]
+
+ if m.Name != "/foo" {
+ t.Errorf("Metric %+v name got %q want '/foo'", m, m.Name)
+ }
+
+ uv, ok := m.Value.(*pb.MetricValue_Uint64Value)
+ if !ok {
+ t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value)
+ }
+ if uv.Uint64Value != 1 {
+ t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value)
+ }
+}
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
new file mode 100644
index 000000000..8904afad9
--- /dev/null
+++ b/pkg/p9/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "p9",
+ srcs = [
+ "buffer.go",
+ "client.go",
+ "client_file.go",
+ "file.go",
+ "handlers.go",
+ "messages.go",
+ "p9.go",
+ "path_tree.go",
+ "server.go",
+ "transport.go",
+ "transport_flipcall.go",
+ "version.go",
+ ],
+ deps = [
+ "//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
+ "//pkg/log",
+ "//pkg/pool",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "p9_test",
+ size = "small",
+ srcs = [
+ "buffer_test.go",
+ "client_test.go",
+ "messages_test.go",
+ "p9_test.go",
+ "transport_test.go",
+ "version_test.go",
+ ],
+ library = ":p9",
+ deps = [
+ "//pkg/fd",
+ "//pkg/unet",
+ ],
+)
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
new file mode 100644
index 000000000..6a4951821
--- /dev/null
+++ b/pkg/p9/buffer.go
@@ -0,0 +1,263 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "encoding/binary"
+)
+
+// encoder is used for messages and 9P primitives.
+type encoder interface {
+ // decode decodes from the given buffer. decode may be called more than once
+ // to reuse the instance. It must clear any previous state.
+ //
+ // This may not fail; exhaustion is recorded in the buffer.
+ decode(b *buffer)
+
+ // encode encodes to the given buffer.
+ //
+ // This may not fail.
+ encode(b *buffer)
+}
+
+// order is the byte order used for encoding.
+var order = binary.LittleEndian
+
+// buffer is a slice that is consumed.
+//
+// This is passed to the encoder methods.
+type buffer struct {
+ // data is the underlying data. This may grow during encode.
+ data []byte
+
+ // overflow indicates whether an overflow has occurred.
+ overflow bool
+}
+
+// append appends n bytes to the buffer and returns a slice pointing to the
+// newly appended bytes.
+func (b *buffer) append(n int) []byte {
+ b.data = append(b.data, make([]byte, n)...)
+ return b.data[len(b.data)-n:]
+}
+
+// consume consumes n bytes from the buffer.
+func (b *buffer) consume(n int) ([]byte, bool) {
+ if !b.has(n) {
+ b.markOverrun()
+ return nil, false
+ }
+ rval := b.data[:n]
+ b.data = b.data[n:]
+ return rval, true
+}
+
+// has returns true if n bytes are available.
+func (b *buffer) has(n int) bool {
+ return len(b.data) >= n
+}
+
+// markOverrun immediately marks this buffer as overrun.
+//
+// This is used by ReadString, since some invalid data implies the rest of the
+// buffer is no longer valid either.
+func (b *buffer) markOverrun() {
+ b.overflow = true
+}
+
+// isOverrun returns true if this buffer has run past the end.
+func (b *buffer) isOverrun() bool {
+ return b.overflow
+}
+
+// Read8 reads a byte from the buffer.
+func (b *buffer) Read8() uint8 {
+ v, ok := b.consume(1)
+ if !ok {
+ return 0
+ }
+ return uint8(v[0])
+}
+
+// Read16 reads a 16-bit value from the buffer.
+func (b *buffer) Read16() uint16 {
+ v, ok := b.consume(2)
+ if !ok {
+ return 0
+ }
+ return order.Uint16(v)
+}
+
+// Read32 reads a 32-bit value from the buffer.
+func (b *buffer) Read32() uint32 {
+ v, ok := b.consume(4)
+ if !ok {
+ return 0
+ }
+ return order.Uint32(v)
+}
+
+// Read64 reads a 64-bit value from the buffer.
+func (b *buffer) Read64() uint64 {
+ v, ok := b.consume(8)
+ if !ok {
+ return 0
+ }
+ return order.Uint64(v)
+}
+
+// ReadQIDType reads a QIDType value.
+func (b *buffer) ReadQIDType() QIDType {
+ return QIDType(b.Read8())
+}
+
+// ReadTag reads a Tag value.
+func (b *buffer) ReadTag() Tag {
+ return Tag(b.Read16())
+}
+
+// ReadFID reads a FID value.
+func (b *buffer) ReadFID() FID {
+ return FID(b.Read32())
+}
+
+// ReadUID reads a UID value.
+func (b *buffer) ReadUID() UID {
+ return UID(b.Read32())
+}
+
+// ReadGID reads a GID value.
+func (b *buffer) ReadGID() GID {
+ return GID(b.Read32())
+}
+
+// ReadPermissions reads a file mode value and applies the mask for permissions.
+func (b *buffer) ReadPermissions() FileMode {
+ return b.ReadFileMode() & permissionsMask
+}
+
+// ReadFileMode reads a file mode value.
+func (b *buffer) ReadFileMode() FileMode {
+ return FileMode(b.Read32())
+}
+
+// ReadOpenFlags reads an OpenFlags.
+func (b *buffer) ReadOpenFlags() OpenFlags {
+ return OpenFlags(b.Read32())
+}
+
+// ReadConnectFlags reads a ConnectFlags.
+func (b *buffer) ReadConnectFlags() ConnectFlags {
+ return ConnectFlags(b.Read32())
+}
+
+// ReadMsgType reads a MsgType.
+func (b *buffer) ReadMsgType() MsgType {
+ return MsgType(b.Read8())
+}
+
+// ReadString deserializes a string.
+func (b *buffer) ReadString() string {
+ l := b.Read16()
+ if !b.has(int(l)) {
+ // Mark the buffer as corrupted.
+ b.markOverrun()
+ return ""
+ }
+
+ bs := make([]byte, l)
+ for i := 0; i < int(l); i++ {
+ bs[i] = byte(b.Read8())
+ }
+ return string(bs)
+}
+
+// Write8 writes a byte to the buffer.
+func (b *buffer) Write8(v uint8) {
+ b.append(1)[0] = byte(v)
+}
+
+// Write16 writes a 16-bit value to the buffer.
+func (b *buffer) Write16(v uint16) {
+ order.PutUint16(b.append(2), v)
+}
+
+// Write32 writes a 32-bit value to the buffer.
+func (b *buffer) Write32(v uint32) {
+ order.PutUint32(b.append(4), v)
+}
+
+// Write64 writes a 64-bit value to the buffer.
+func (b *buffer) Write64(v uint64) {
+ order.PutUint64(b.append(8), v)
+}
+
+// WriteQIDType writes a QIDType value.
+func (b *buffer) WriteQIDType(qidType QIDType) {
+ b.Write8(uint8(qidType))
+}
+
+// WriteTag writes a Tag value.
+func (b *buffer) WriteTag(tag Tag) {
+ b.Write16(uint16(tag))
+}
+
+// WriteFID writes a FID value.
+func (b *buffer) WriteFID(fid FID) {
+ b.Write32(uint32(fid))
+}
+
+// WriteUID writes a UID value.
+func (b *buffer) WriteUID(uid UID) {
+ b.Write32(uint32(uid))
+}
+
+// WriteGID writes a GID value.
+func (b *buffer) WriteGID(gid GID) {
+ b.Write32(uint32(gid))
+}
+
+// WritePermissions applies a permissions mask and writes the FileMode.
+func (b *buffer) WritePermissions(perm FileMode) {
+ b.WriteFileMode(perm & permissionsMask)
+}
+
+// WriteFileMode writes a FileMode.
+func (b *buffer) WriteFileMode(mode FileMode) {
+ b.Write32(uint32(mode))
+}
+
+// WriteOpenFlags writes an OpenFlags.
+func (b *buffer) WriteOpenFlags(flags OpenFlags) {
+ b.Write32(uint32(flags))
+}
+
+// WriteConnectFlags writes a ConnectFlags.
+func (b *buffer) WriteConnectFlags(flags ConnectFlags) {
+ b.Write32(uint32(flags))
+}
+
+// WriteMsgType writes a MsgType.
+func (b *buffer) WriteMsgType(t MsgType) {
+ b.Write8(uint8(t))
+}
+
+// WriteString serializes the given string.
+func (b *buffer) WriteString(s string) {
+ b.Write16(uint16(len(s)))
+ for i := 0; i < len(s); i++ {
+ b.Write8(byte(s[i]))
+ }
+}
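+
+// Example (an illustrative sketch, not part of this change): a type
+// implementing encoder round-trips through a buffer. Writes grow b.data,
+// while reads consume it from the front; on exhaustion the reads return
+// zero values and mark the buffer overrun instead of failing.
+//
+//	type point struct {
+//		x, y uint32
+//	}
+//
+//	func (p *point) encode(b *buffer) { b.Write32(p.x); b.Write32(p.y) }
+//	func (p *point) decode(b *buffer) { p.x = b.Read32(); p.y = b.Read32() }
+//
+//	func roundTrip(in point) (point, bool) {
+//		var b buffer
+//		in.encode(&b)
+//		var out point
+//		out.decode(&b)
+//		return out, !b.isOverrun()
+//	}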
diff --git a/pkg/p9/buffer_test.go b/pkg/p9/buffer_test.go
new file mode 100644
index 000000000..a9c75f86b
--- /dev/null
+++ b/pkg/p9/buffer_test.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "testing"
+)
+
+func TestBufferOverrun(t *testing.T) {
+ buf := &buffer{
+ // This length header (little-endian 0x1600) indicates that a large
+ // string should follow, but the buffer holds only the two header
+ // bytes. Reading a string should cause an overrun.
+ data: []byte{0x0, 0x16},
+ }
+ if s := buf.ReadString(); s != "" {
+ t.Errorf("overrun read got %s, want empty", s)
+ }
+}
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
new file mode 100644
index 000000000..71e944c30
--- /dev/null
+++ b/pkg/p9/client.go
@@ -0,0 +1,575 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// ErrOutOfTags indicates no tags are available.
+var ErrOutOfTags = errors.New("out of tags -- messages lost?")
+
+// ErrOutOfFIDs indicates no more FIDs are available.
+var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?")
+
+// ErrUnexpectedTag indicates a response with an unexpected tag was received.
+var ErrUnexpectedTag = errors.New("unexpected tag in response")
+
+// ErrVersionsExhausted indicates that all versions to negotiate have been exhausted.
+var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate")
+
+// ErrBadVersionString indicates that the version string is malformed or unsupported.
+var ErrBadVersionString = errors.New("bad version string")
+
+// ErrBadResponse indicates the response didn't match the request.
+type ErrBadResponse struct {
+ Got MsgType
+ Want MsgType
+}
+
+// Error returns a highly descriptive error.
+func (e *ErrBadResponse) Error() string {
+ return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want)
+}
+
+// response is the asynchronous return from recv.
+//
+// This is used in the pending map below.
+type response struct {
+ r message
+ done chan error
+}
+
+var responsePool = sync.Pool{
+ New: func() interface{} {
+ return &response{
+ done: make(chan error, 1),
+ }
+ },
+}
+
+// Client is at least a 9P2000.L client.
+type Client struct {
+ // socket is the connected socket.
+ socket *unet.Socket
+
+ // tagPool is the collection of available tags.
+ tagPool pool.Pool
+
+ // fidPool is the collection of available fids.
+ fidPool pool.Pool
+
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
+ // pending is the set of pending messages.
+ pending map[Tag]*response
+ pendingMu sync.Mutex
+
+ // sendMu is the lock for sending a request.
+ sendMu sync.Mutex
+
+ // recvr is essentially a mutex for calling recv.
+ //
+ // Whoever writes to this channel is permitted to call recv. When
+ // finished calling recv, this channel should be emptied.
+ recvr chan bool
+}
+
+// NewClient creates a new client. It performs a Tversion exchange with
+// the server to assert that messageSize is ok to use.
+//
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
+func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
+ // Need at least one byte of payload.
+ if messageSize <= msgRegistry.largestFixedSize {
+ return nil, &ErrMessageTooLarge{
+ size: messageSize,
+ msize: msgRegistry.largestFixedSize,
+ }
+ }
+
+ // Compute a payload size and round it down to a multiple of 512
+ // (the normal block size) if it's larger than a single block.
+ payloadSize := messageSize - msgRegistry.largestFixedSize
+ if payloadSize > 512 && payloadSize%512 != 0 {
+ payloadSize -= (payloadSize % 512)
+ }
+ c := &Client{
+ socket: socket,
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
+ pending: make(map[Tag]*response),
+ recvr: make(chan bool, 1),
+ messageSize: messageSize,
+ payloadSize: payloadSize,
+ }
+ // Agree upon a version.
+ requested, ok := parseVersion(version)
+ if !ok {
+ return nil, ErrBadVersionString
+ }
+ for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
+ rversion := Rversion{}
+ _, err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
+
+ // The server told us to try again with a lower version.
+ if err == syscall.EAGAIN {
+ if requested == lowestSupportedVersion {
+ return nil, ErrVersionsExhausted
+ }
+ requested--
+ continue
+ }
+
+ // We requested an impossible version or our other parameters were bogus.
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the version.
+ version, ok := parseVersion(rversion.Version)
+ if !ok {
+ // The server gave us a bad version. We return a generically worrisome error.
+ log.Warningf("server returned bad version string %q", rversion.Version)
+ return nil, ErrBadVersionString
+ }
+ c.version = version
+ break
+ }
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fall back.
+ c.sendRecv = c.sendRecvLegacySyscallErr
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacySyscallErr
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
+ return c, nil
+}
+
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ {
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if _, err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if _, err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
+// handleOne handles a single incoming message.
+//
+// This should only be called with the token from recvr. Note that the received
+// tag will automatically be cleared from pending.
+func (c *Client) handleOne() {
+ tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) {
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ c.pendingMu.Unlock()
+
+ // Not expecting this message?
+ if resp == nil {
+ log.Warningf("client received unexpected tag %v, ignoring", tag)
+ return nil, ErrUnexpectedTag
+ }
+
+ // Is it an error? We specifically allow this to
+ // go through, and then we deserialize below.
+ if t == MsgRlerror {
+ return &Rlerror{}, nil
+ }
+
+ // Does it match expectations?
+ if t != resp.r.Type() {
+ return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()}
+ }
+
+ // Return the response.
+ return resp.r, nil
+ })
+
+ if err != nil {
+ // No tag was extracted (probably a socket error).
+ //
+ // Likely catastrophic. Notify all waiters and clear pending.
+ c.pendingMu.Lock()
+ for _, resp := range c.pending {
+ resp.done <- err
+ }
+ c.pending = make(map[Tag]*response)
+ c.pendingMu.Unlock()
+ } else {
+ // Process the tag.
+ //
+ // We know that it is contained in the map because our lookup function
+ // above must have succeeded (found the tag) to return nil err.
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ delete(c.pending, tag)
+ c.pendingMu.Unlock()
+ resp.r = r
+ resp.done <- err
+ }
+}
+
+// waitAndRecv co-ordinates with other receivers to handle responses.
+func (c *Client) waitAndRecv(done chan error) error {
+ for {
+ select {
+ case err := <-done:
+ return err
+ case c.recvr <- true:
+ select {
+ case err := <-done:
+ // It's possible that we got the token, despite
+ // done also being available. Check for that.
+ <-c.recvr
+ return err
+ default:
+ // Handle receiving one tag.
+ c.handleOne()
+
+ // Return the token.
+ <-c.recvr
+ }
+ }
+ }
+}
+
+// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
+// non-syscall errors to EIO.
+func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
+ received, err := c.sendRecvLegacy(t, r)
+ if !received {
+ log.Warningf("p9.Client.sendRecvChannel: %v", err)
+ return syscall.EIO
+ }
+ return err
+}
+
+// sendRecvLegacy performs a roundtrip message exchange.
+//
+// sendRecvLegacy returns true if a message was received. This allows us to
+// differentiate between failed receives and successful receives where the
+// response was an error message.
+//
+// This is called by internal functions.
+func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
+ tag, ok := c.tagPool.Get()
+ if !ok {
+ return false, ErrOutOfTags
+ }
+ defer c.tagPool.Put(tag)
+
+ // Indicate we're expecting a response.
+ //
+ // Note that the tag will be cleared from pending
+ // automatically (see handleOne for details).
+ resp := responsePool.Get().(*response)
+ defer responsePool.Put(resp)
+ resp.r = r
+ c.pendingMu.Lock()
+ c.pending[Tag(tag)] = resp
+ c.pendingMu.Unlock()
+
+ // Send the request over the wire.
+ c.sendMu.Lock()
+ err := send(c.socket, Tag(tag), t)
+ c.sendMu.Unlock()
+ if err != nil {
+ return false, err
+ }
+
+ // Co-ordinate with other receivers.
+ if err := c.waitAndRecv(resp.done); err != nil {
+ return false, err
+ }
+
+ // Is it an error message?
+ //
+ // For convenience, we transform these directly
+ // into errors. Handlers need not handle this case.
+ if rlerr, ok := resp.r.(*Rlerror); ok {
+ return true, syscall.Errno(rlerr.Error)
+ }
+
+ // At this point, we know it matches.
+ //
+ // Per recv call above, we will only allow a type
+ // match (and give our r) or an instance of Rlerror.
+ return true, nil
+}
+
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacySyscallErr(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ // Map all transport errors to EIO, but ensure that the real error
+ // is logged.
+ log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err)
+ return syscall.EIO
+ }
+ }
+
+ // Send the request and receive the server's response.
+ rsz, err := ch.send(t)
+ if err != nil {
+ // See above.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err)
+ return syscall.EIO
+ }
+
+ // Parse the server's response.
+ resp, retErr := ch.recv(r, rsz)
+ if resp == nil {
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
+ retErr = syscall.EIO
+ }
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return retErr
+}
+
+// Version returns the negotiated 9P2000.L.Google version number.
+func (c *Client) Version() uint32 {
+ return c.version
+}
+
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
+}
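+
+// Example use (an illustrative sketch, not part of this change): connect to
+// a server over a unix socket, negotiate a version, and attach to the root.
+// The socket path "/tmp/p9.sock" and the 1 MB message size are hypothetical.
+//
+//	sock, err := unet.Connect("/tmp/p9.sock", false)
+//	if err != nil {
+//		return err
+//	}
+//	client, err := NewClient(sock, 1024*1024 /* messageSize */, HighestVersionString())
+//	if err != nil {
+//		return err
+//	}
+//	defer client.Close()
+//	root, err := client.Attach("/")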
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
new file mode 100644
index 000000000..2ee07b664
--- /dev/null
+++ b/pkg/p9/client_file.go
@@ -0,0 +1,686 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// Attach attaches to a server.
+//
+// Note that authentication is not currently supported.
+func (c *Client) Attach(name string) (File, error) {
+ fid, ok := c.fidPool.Get()
+ if !ok {
+ return nil, ErrOutOfFIDs
+ }
+
+ rattach := Rattach{}
+ if err := c.sendRecv(&Tattach{FID: FID(fid), Auth: Tauth{AttachName: name, AuthenticationFID: NoFID, UID: NoUID}}, &rattach); err != nil {
+ c.fidPool.Put(fid)
+ return nil, err
+ }
+
+ return c.newFile(FID(fid)), nil
+}
+
+// newFile returns a new client file.
+func (c *Client) newFile(fid FID) *clientFile {
+ return &clientFile{
+ client: c,
+ fid: fid,
+ }
+}
+
+// clientFile is provided to clients.
+//
+// This proxies all of the interfaces found in file.go.
+type clientFile struct {
+ // client is the originating client.
+ client *Client
+
+ // fid is the FID for this file.
+ fid FID
+
+ // closed indicates whether this file has been closed.
+ closed uint32
+}
+
+// Walk implements File.Walk.
+func (c *clientFile) Walk(names []string) ([]QID, File, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, syscall.EBADF
+ }
+
+ fid, ok := c.client.fidPool.Get()
+ if !ok {
+ return nil, nil, ErrOutOfFIDs
+ }
+
+ rwalk := Rwalk{}
+ if err := c.client.sendRecv(&Twalk{FID: c.fid, NewFID: FID(fid), Names: names}, &rwalk); err != nil {
+ c.client.fidPool.Put(fid)
+ return nil, nil, err
+ }
+
+ // Return a new client file.
+ return rwalk.QIDs, c.client.newFile(FID(fid)), nil
+}
+
+// WalkGetAttr implements File.WalkGetAttr.
+func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, Attr, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, AttrMask{}, Attr{}, syscall.EBADF
+ }
+
+ if !versionSupportsTwalkgetattr(c.client.version) {
+ qids, file, err := c.Walk(components)
+ if err != nil {
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ _, valid, attr, err := file.GetAttr(AttrMaskAll())
+ if err != nil {
+ file.Close()
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ return qids, file, valid, attr, nil
+ }
+
+ fid, ok := c.client.fidPool.Get()
+ if !ok {
+ return nil, nil, AttrMask{}, Attr{}, ErrOutOfFIDs
+ }
+
+ rwalkgetattr := Rwalkgetattr{}
+ if err := c.client.sendRecv(&Twalkgetattr{FID: c.fid, NewFID: FID(fid), Names: components}, &rwalkgetattr); err != nil {
+ c.client.fidPool.Put(fid)
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+
+ // Return a new client file.
+ return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil
+}
+
+// StatFS implements File.StatFS.
+func (c *clientFile) StatFS() (FSStat, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return FSStat{}, syscall.EBADF
+ }
+
+ rstatfs := Rstatfs{}
+ if err := c.client.sendRecv(&Tstatfs{FID: c.fid}, &rstatfs); err != nil {
+ return FSStat{}, err
+ }
+
+ return rstatfs.FSStat, nil
+}
+
+// FSync implements File.FSync.
+func (c *clientFile) FSync() error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tfsync{FID: c.fid}, &Rfsync{})
+}
+
+// GetAttr implements File.GetAttr.
+func (c *clientFile) GetAttr(req AttrMask) (QID, AttrMask, Attr, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, AttrMask{}, Attr{}, syscall.EBADF
+ }
+
+ rgetattr := Rgetattr{}
+ if err := c.client.sendRecv(&Tgetattr{FID: c.fid, AttrMask: req}, &rgetattr); err != nil {
+ return QID{}, AttrMask{}, Attr{}, err
+ }
+
+ return rgetattr.QID, rgetattr.Valid, rgetattr.Attr, nil
+}
+
+// SetAttr implements File.SetAttr.
+func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{})
+}
+
+// GetXattr implements File.GetXattr.
+func (c *clientFile) GetXattr(name string, size uint64) (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return "", syscall.EOPNOTSUPP
+ }
+
+ rgetxattr := Rgetxattr{}
+ if err := c.client.sendRecv(&Tgetxattr{FID: c.fid, Name: name, Size: size}, &rgetxattr); err != nil {
+ return "", err
+ }
+
+ return rgetxattr.Value, nil
+}
+
+// SetXattr implements File.SetXattr.
+func (c *clientFile) SetXattr(name, value string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{})
+}
+
+// ListXattr implements File.ListXattr.
+func (c *clientFile) ListXattr(size uint64) (map[string]struct{}, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return nil, syscall.EOPNOTSUPP
+ }
+
+ rlistxattr := Rlistxattr{}
+ if err := c.client.sendRecv(&Tlistxattr{FID: c.fid, Size: size}, &rlistxattr); err != nil {
+ return nil, err
+ }
+
+ xattrs := make(map[string]struct{}, len(rlistxattr.Xattrs))
+ for _, x := range rlistxattr.Xattrs {
+ xattrs[x] = struct{}{}
+ }
+ return xattrs, nil
+}
+
+// RemoveXattr implements File.RemoveXattr.
+func (c *clientFile) RemoveXattr(name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tremovexattr{FID: c.fid, Name: name}, &Rremovexattr{})
+}
+
+// Allocate implements File.Allocate.
+func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsTallocate(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tallocate{FID: c.fid, Mode: mode, Offset: offset, Length: length}, &Rallocate{})
+}
+
+// Remove implements File.Remove.
+//
+// N.B. This method is no longer part of the file interface and should be
+// considered deprecated.
+func (c *clientFile) Remove() error {
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+
+ // Send the remove message.
+ if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil {
+ return err
+ }
+
+ // "It is correct to consider remove to be a clunk with the side effect
+ // of removing the file if permissions allow."
+ // https://swtch.com/plan9port/man/man9/remove.html
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
+// Close implements File.Close.
+func (c *clientFile) Close() error {
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+
+ // Send the close message.
+ if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil {
+ // If an error occurred, we toss away the FID. This isn't ideal,
+ // but I'm not sure what else makes sense in this context.
+ log.Warningf("Tclunk failed, losing FID %v: %v", c.fid, err)
+ return err
+ }
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
+// Open implements File.Open.
+func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, QID{}, 0, syscall.EBADF
+ }
+
+ rlopen := Rlopen{}
+ if err := c.client.sendRecv(&Tlopen{FID: c.fid, Flags: flags}, &rlopen); err != nil {
+ return nil, QID{}, 0, err
+ }
+
+ return rlopen.File, rlopen.QID, rlopen.IoUnit, nil
+}
+
+// Connect implements File.Connect.
+func (c *clientFile) Connect(flags ConnectFlags) (*fd.FD, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+
+ if !VersionSupportsConnect(c.client.version) {
+ return nil, syscall.ECONNREFUSED
+ }
+
+ rlconnect := Rlconnect{}
+ if err := c.client.sendRecv(&Tlconnect{FID: c.fid, Flags: flags}, &rlconnect); err != nil {
+ return nil, err
+ }
+
+ return rlconnect.File, nil
+}
+
+// chunk applies fn to p in chunkSize-sized chunks until fn returns a partial result, p is
+// exhausted, or an error is encountered (which may be io.EOF).
+func chunk(chunkSize uint32, fn func([]byte, uint64) (int, error), p []byte, offset uint64) (int, error) {
+ // Some p9.Clients depend on executing fn on zero-byte buffers. Handle this
+ // as a special case (normally it is fine to short-circuit and return (0, nil)).
+ if len(p) == 0 {
+ return fn(p, offset)
+ }
+
+ // total is the cumulative bytes processed.
+ var total int
+ for {
+ var n int
+ var err error
+
+ // We're done, don't bother trying to do anything more.
+ if total == len(p) {
+ return total, nil
+ }
+
+ // Apply fn to a chunkSize-sized (or less) chunk of p.
+ if len(p) < total+int(chunkSize) {
+ n, err = fn(p[total:], offset)
+ } else {
+ n, err = fn(p[total:total+int(chunkSize)], offset)
+ }
+ total += n
+ offset += uint64(n)
+
+ // Return whatever we have processed if we encounter an error. This error
+ // could be io.EOF.
+ if err != nil {
+ return total, err
+ }
+
+ // Did we get a partial result? If so, return it immediately.
+ if n < int(chunkSize) {
+ return total, nil
+ }
+
+ // If we received more bytes than we ever requested, this is a problem.
+ if total > len(p) {
+ panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", total, len(p)))
+ }
+ }
+}
+
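+// For example (illustrative): with a payloadSize of 65536, a 150000-byte
+// ReadAt is issued as chunks of 65536, 65536, and 18928 bytes. A partial
+// result from fn before the end (for instance at EOF) stops the loop and
+// returns the bytes processed so far.
+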
+// ReadAt proxies File.ReadAt.
+func (c *clientFile) ReadAt(p []byte, offset uint64) (int, error) {
+ return chunk(c.client.payloadSize, c.readAt, p, offset)
+}
+
+func (c *clientFile) readAt(p []byte, offset uint64) (int, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return 0, syscall.EBADF
+ }
+
+ rread := Rread{Data: p}
+ if err := c.client.sendRecv(&Tread{FID: c.fid, Offset: offset, Count: uint32(len(p))}, &rread); err != nil {
+ return 0, err
+ }
+
+ // The message may have been truncated, or for some reason a new buffer
+ // was allocated. This isn't the common path, but we make sure that if
+ // the payload has changed we copy it. See transport.go for more information.
+ if len(p) > 0 && len(rread.Data) > 0 && &rread.Data[0] != &p[0] {
+ copy(p, rread.Data)
+ }
+
+ // io.EOF is not an error that a p9 server can return. Instead, synthesize
+ // io.EOF using POSIX semantics: zero bytes returned for a read into a
+ // non-empty buffer means end-of-file.
+ if len(rread.Data) == 0 && len(p) > 0 {
+ return 0, io.EOF
+ }
+
+ return len(rread.Data), nil
+}
+
+// WriteAt proxies File.WriteAt.
+func (c *clientFile) WriteAt(p []byte, offset uint64) (int, error) {
+ return chunk(c.client.payloadSize, c.writeAt, p, offset)
+}
+
+func (c *clientFile) writeAt(p []byte, offset uint64) (int, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return 0, syscall.EBADF
+ }
+
+ rwrite := Rwrite{}
+ if err := c.client.sendRecv(&Twrite{FID: c.fid, Offset: offset, Data: p}, &rwrite); err != nil {
+ return 0, err
+ }
+
+ return int(rwrite.Count), nil
+}
+
+// ReadWriterFile wraps a File and implements io.ReadWriter, io.ReaderAt, and io.WriterAt.
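+//
+// A minimal usage sketch (illustrative; assumes f is an already-opened File,
+// and the bytes and io imports):
+//
+//	rw := &ReadWriterFile{File: f, Offset: 0}
+//	var buf bytes.Buffer
+//	if _, err := io.Copy(&buf, rw); err != nil {
+//		// Handle the read error.
+//	}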
+type ReadWriterFile struct {
+ File File
+ Offset uint64
+}
+
+// Read implements part of the io.ReadWriter interface.
+func (r *ReadWriterFile) Read(p []byte) (int, error) {
+ n, err := r.File.ReadAt(p, r.Offset)
+ r.Offset += uint64(n)
+ if err != nil {
+ return n, err
+ }
+ if n == 0 && len(p) > 0 {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// ReadAt implements the io.ReaderAt interface.
+func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) {
+ n, err := r.File.ReadAt(p, uint64(offset))
+ if err != nil {
+ return 0, err
+ }
+ if n == 0 && len(p) > 0 {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// Write implements part of the io.ReadWriter interface.
+func (r *ReadWriterFile) Write(p []byte) (int, error) {
+ n, err := r.File.WriteAt(p, r.Offset)
+ r.Offset += uint64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < len(p) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
+// WriteAt implements the io.WriterAt interface.
+func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) {
+ n, err := r.File.WriteAt(p, uint64(offset))
+ if err != nil {
+ return n, err
+ }
+ if n < len(p) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
+// Rename implements File.Rename.
+func (c *clientFile) Rename(dir File, name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ clientDir, ok := dir.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Trename{FID: c.fid, Directory: clientDir.fid, Name: name}, &Rrename{})
+}
+
+// Create implements File.Create.
+func (c *clientFile) Create(name string, openFlags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, QID{}, 0, syscall.EBADF
+ }
+
+ msg := Tlcreate{
+ FID: c.fid,
+ Name: name,
+ OpenFlags: openFlags,
+ Permissions: permissions,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rucreate := Rucreate{}
+ if err := c.client.sendRecv(&Tucreate{Tlcreate: msg, UID: uid}, &rucreate); err != nil {
+ return nil, nil, QID{}, 0, err
+ }
+ return rucreate.File, c, rucreate.QID, rucreate.IoUnit, nil
+ }
+
+ rlcreate := Rlcreate{}
+ if err := c.client.sendRecv(&msg, &rlcreate); err != nil {
+ return nil, nil, QID{}, 0, err
+ }
+
+ return rlcreate.File, c, rlcreate.QID, rlcreate.IoUnit, nil
+}
+
+// Mkdir implements File.Mkdir.
+func (c *clientFile) Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tmkdir{
+ Directory: c.fid,
+ Name: name,
+ Permissions: permissions,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rumkdir := Rumkdir{}
+ if err := c.client.sendRecv(&Tumkdir{Tmkdir: msg, UID: uid}, &rumkdir); err != nil {
+ return QID{}, err
+ }
+ return rumkdir.QID, nil
+ }
+
+ rmkdir := Rmkdir{}
+ if err := c.client.sendRecv(&msg, &rmkdir); err != nil {
+ return QID{}, err
+ }
+
+ return rmkdir.QID, nil
+}
+
+// Symlink implements File.Symlink.
+func (c *clientFile) Symlink(oldname string, newname string, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tsymlink{
+ Directory: c.fid,
+ Name: newname,
+ Target: oldname,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rusymlink := Rusymlink{}
+ if err := c.client.sendRecv(&Tusymlink{Tsymlink: msg, UID: uid}, &rusymlink); err != nil {
+ return QID{}, err
+ }
+ return rusymlink.QID, nil
+ }
+
+ rsymlink := Rsymlink{}
+ if err := c.client.sendRecv(&msg, &rsymlink); err != nil {
+ return QID{}, err
+ }
+
+ return rsymlink.QID, nil
+}
+
+// Link implements File.Link.
+func (c *clientFile) Link(target File, newname string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ targetFile, ok := target.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tlink{Directory: c.fid, Name: newname, Target: targetFile.fid}, &Rlink{})
+}
+
+// Mknod implements File.Mknod.
+func (c *clientFile) Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tmknod{
+ Directory: c.fid,
+ Name: name,
+ Mode: mode,
+ Major: major,
+ Minor: minor,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rumknod := Rumknod{}
+ if err := c.client.sendRecv(&Tumknod{Tmknod: msg, UID: uid}, &rumknod); err != nil {
+ return QID{}, err
+ }
+ return rumknod.QID, nil
+ }
+
+ rmknod := Rmknod{}
+ if err := c.client.sendRecv(&msg, &rmknod); err != nil {
+ return QID{}, err
+ }
+
+ return rmknod.QID, nil
+}
+
+// RenameAt implements File.RenameAt.
+func (c *clientFile) RenameAt(oldname string, newdir File, newname string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ clientNewDir, ok := newdir.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Trenameat{OldDirectory: c.fid, OldName: oldname, NewDirectory: clientNewDir.fid, NewName: newname}, &Rrenameat{})
+}
+
+// UnlinkAt implements File.UnlinkAt.
+func (c *clientFile) UnlinkAt(name string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tunlinkat{Directory: c.fid, Name: name, Flags: flags}, &Runlinkat{})
+}
+
+// Readdir implements File.Readdir.
+func (c *clientFile) Readdir(offset uint64, count uint32) ([]Dirent, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+
+ rreaddir := Rreaddir{}
+ if err := c.client.sendRecv(&Treaddir{Directory: c.fid, Offset: offset, Count: count}, &rreaddir); err != nil {
+ return nil, err
+ }
+
+ return rreaddir.Entries, nil
+}
+
+// Readlink implements File.Readlink.
+func (c *clientFile) Readlink() (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+
+ rreadlink := Rreadlink{}
+ if err := c.client.sendRecv(&Treadlink{FID: c.fid}, &rreadlink); err != nil {
+ return "", err
+ }
+
+ return rreadlink.Target, nil
+}
+
+// Flush implements File.Flush.
+func (c *clientFile) Flush() error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ if !VersionSupportsTflushf(c.client.version) {
+ return nil
+ }
+
+ return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{})
+}
+
+// Renamed implements File.Renamed.
+func (c *clientFile) Renamed(newDir File, newName string) {}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
new file mode 100644
index 000000000..c757583e0
--- /dev/null
+++ b/pkg/p9/client_test.go
@@ -0,0 +1,109 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// TestVersion tests the version negotiation.
+func TestVersion(t *testing.T) {
+ // First, create a new server and connection.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // Create a new server and client.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // NewClient does a Tversion exchange, so this is our test for success.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ t.Fatalf("got %v, expected nil", err)
+ }
+
+ // Check a bogus version string.
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
+ t.Errorf("got %v expected %v", err, syscall.EINVAL)
+ }
+
+ // Check a bogus version number.
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
+ t.Errorf("got %v expected %v", err, syscall.EINVAL)
+ }
+
+ // Check a too high version number.
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
+ t.Errorf("got %v expected %v", err, syscall.EAGAIN)
+ }
+
+ // Check an invalid MSize.
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion), MSize: 0}, &Rversion{}); err != syscall.EINVAL {
+ t.Errorf("got %v expected %v", err, syscall.EINVAL)
+ }
+}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error {
+ return func(t message, r message) error {
+ _, err := c.sendRecvLegacy(t, r)
+ return err
+ }
+ })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
new file mode 100644
index 000000000..cab35896f
--- /dev/null
+++ b/pkg/p9/file.go
@@ -0,0 +1,288 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+)
+
+// Attacher is provided by the server.
+type Attacher interface {
+ // Attach returns a new File.
+ //
+ // The client-side attach will be translated into a series of walks from
+ // the file returned by this Attach call.
+ Attach() (File, error)
+}
+
+// File is a set of operations corresponding to a single node.
+//
+// Note that on the server side, the server logic places constraints on
+// concurrent operations to make things easier. This may reduce the need for
+// complex, error-prone locking and logic in the backend. These are documented
+// for each method.
+//
+// There are three different types of guarantees provided:
+//
+// none: There is no concurrency guarantee. The method may be invoked
+// concurrently with any other method on any other file.
+//
+// read: The method is guaranteed to be exclusive of any write or global
+// operation that is mutating the state of the directory tree starting at this
+// node. For example, this means creating new files, symlinks, or directories,
+// or renaming a directory entry (or renaming into this target), but the method
+// may be called concurrently with other read methods.
+//
+// write: The method is guaranteed to be exclusive of any read, write or global
+// operation that is mutating the state of the directory tree starting at this
+// node, as described in read above. There may, however, be other write
+// operations executing concurrently on other components of the directory tree.
+//
+// global: The method is guaranteed to be exclusive of any read, write or
+// global operation.
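+//
+// For example (illustrative): two ReadAt calls on the same file may execute
+// concurrently under the read guarantee, while a RenameAt anywhere in the
+// tree holds the global guarantee and is exclusive of both.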
+type File interface {
+ // Walk walks to the path components given in names.
+ //
+ // Walk returns QIDs in the same order that the names were passed in.
+ //
+ // An empty list of arguments should return a copy of the current file.
+ //
+ // On the server, Walk has a read concurrency guarantee.
+ Walk(names []string) ([]QID, File, error)
+
+ // WalkGetAttr walks to the next file and returns its maximal set of
+ // attributes.
+ //
+ // Server-side p9.Files may return syscall.ENOSYS to indicate that Walk
+ // and GetAttr should be used separately to satisfy this request.
+ //
+ // On the server, WalkGetAttr has a read concurrency guarantee.
+ WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error)
+
+ // StatFS returns information about the file system associated with
+ // this file.
+ //
+ // On the server, StatFS has no concurrency guarantee.
+ StatFS() (FSStat, error)
+
+ // GetAttr returns attributes of this node.
+ //
+ // On the server, GetAttr has a read concurrency guarantee.
+ GetAttr(req AttrMask) (QID, AttrMask, Attr, error)
+
+ // SetAttr sets attributes on this node.
+ //
+ // On the server, SetAttr has a write concurrency guarantee.
+ SetAttr(valid SetAttrMask, attr SetAttr) error
+
+ // GetXattr returns extended attributes of this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // On the server, GetXattr has a read concurrency guarantee.
+ GetXattr(name string, size uint64) (string, error)
+
+ // SetXattr sets extended attributes on this node.
+ //
+ // On the server, SetXattr has a write concurrency guarantee.
+ SetXattr(name, value string, flags uint32) error
+
+ // ListXattr lists the names of the extended attributes on this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // On the server, ListXattr has a read concurrency guarantee.
+ ListXattr(size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes extended attributes on this node.
+ //
+ // On the server, RemoveXattr has a write concurrency guarantee.
+ RemoveXattr(name string) error
+
+ // Allocate allows the caller to directly manipulate the allocated disk space
+ // for the file. See fallocate(2) for more details.
+ Allocate(mode AllocateMode, offset, length uint64) error
+
+ // Close is called when all references are dropped on the server side,
+ // and Close should be called by the client to drop all references.
+ //
+ // For server-side implementations of Close, the error is ignored.
+ //
+ // Close must be called even when Open has not been called.
+ //
+ // On the server, Close has no concurrency guarantee.
+ Close() error
+
+ // Open must be called prior to using Read, Write or Readdir. Once Open
+ // is called, some operations, such as Walk, will no longer work.
+ //
+ // On the client, Open should be called only once. The returned *fd.FD
+ // is optional, and may be nil.
+ //
+ // On the server, Open has a read concurrency guarantee. If an *fd.FD
+ // is provided, ownership now belongs to the caller. Open is guaranteed
+ // to be called only once.
+ //
+ // N.B. The server must resolve any lazy paths when open is called.
+ // After this point, read and write may be called on files with no
+ // deletion check, so resolving in the data path is not viable.
+ Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
+
+ // Read reads from this file. Open must be called first.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, ReadAt has a read concurrency guarantee. See Open for
+ // additional requirements regarding lazy path resolution.
+ ReadAt(p []byte, offset uint64) (int, error)
+
+ // Write writes to this file. Open must be called first.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, WriteAt has a read concurrency guarantee. See Open
+ // for additional requirements regarding lazy path resolution.
+ WriteAt(p []byte, offset uint64) (int, error)
+
+ // FSync syncs this node. Open must be called first.
+ //
+ // On the server, FSync has a read concurrency guarantee.
+ FSync() error
+
+ // Create creates a new regular file and opens it according to the
+ // flags given. This file is already Open.
+ //
+ // N.B. On the client, the returned file is a reference to the current
+ // file, which now represents the created file. This is not the case on
+ // the server. These semantics are very subtle and can easily lead to
+ // bugs, but are a consequence of the 9P create operation.
+ //
+ // See p9.File.Open for a description of *fd.FD.
+ //
+ // On the server, Create has a write concurrency guarantee.
+ Create(name string, flags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error)
+
+ // Mkdir creates a subdirectory.
+ //
+ // On the server, Mkdir has a write concurrency guarantee.
+ Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error)
+
+ // Symlink makes a new symbolic link.
+ //
+ // On the server, Symlink has a write concurrency guarantee.
+ Symlink(oldName string, newName string, uid UID, gid GID) (QID, error)
+
+ // Link makes a new hard link.
+ //
+ // On the server, Link has a write concurrency guarantee.
+ Link(target File, newName string) error
+
+ // Mknod makes a new device node.
+ //
+ // On the server, Mknod has a write concurrency guarantee.
+ Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error)
+
+ // Rename renames the file.
+ //
+ // Rename will never be called on the server, and RenameAt will always
+ // be used instead.
+ Rename(newDir File, newName string) error
+
+ // RenameAt renames a given file to a new name in a potentially new
+ // directory.
+ //
+ // oldName must be a name relative to this file, which must be a
+ // directory. newName is a name relative to newDir.
+ //
+ // On the server, RenameAt has a global concurrency guarantee.
+ RenameAt(oldName string, newDir File, newName string) error
+
+ // UnlinkAt unlinks the given named file.
+ //
+ // name must be a file relative to this directory.
+ //
+ // Flags are implementation-specific (e.g. O_DIRECTORY), but are
+ // generally Linux unlinkat(2) flags.
+ //
+ // On the server, UnlinkAt has a write concurrency guarantee.
+ UnlinkAt(name string, flags uint32) error
+
+ // Readdir reads directory entries.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, Readdir has a read concurrency guarantee.
+ Readdir(offset uint64, count uint32) ([]Dirent, error)
+
+ // Readlink reads the link target.
+ //
+ // On the server, Readlink has a read concurrency guarantee.
+ Readlink() (string, error)
+
+ // Flush is called prior to Close.
+ //
+ // Whereas Close drops all references to the file, Flush cleans up the
+ // file state. Behavior is implementation-specific.
+ //
+ // Flush is not related to flush(9p). Flush is an extension to 9P2000.L,
+ // see version.go.
+ //
+ // On the server, Flush has a read concurrency guarantee.
+ Flush() error
+
+ // Connect establishes a new host-socket backed connection with a
+ // socket. A File does not need to be opened before it can be connected,
+ // and it may be connected multiple times, each connection producing a
+ // unique *fd.FD. In addition, the lifetime of the *fd.FD is
+ // independent from the lifetime of the p9.File and must be managed by
+ // the caller.
+ //
+ // The returned FD must be non-blocking.
+ //
+ // Flags indicates the requested type of socket.
+ //
+ // On the server, Connect has a read concurrency guarantee.
+ Connect(flags ConnectFlags) (*fd.FD, error)
+
+ // Renamed is called when this node is renamed.
+ //
+ // This may not fail. The file will hold a reference to its parent
+ // within the p9 package, and is therefore safe to use for the lifetime
+ // of this File (until Close is called).
+ //
+ // This method should not be called by clients, who should use the
+ // relevant Rename methods. (Although the method will be a no-op.)
+ //
+ // On the server, Renamed has a global concurrency guarantee.
+ Renamed(newDir File, newName string)
+}
+
+// DefaultWalkGetAttr implements File.WalkGetAttr to return ENOSYS for server-side Files.
+type DefaultWalkGetAttr struct{}
+
+// WalkGetAttr implements File.WalkGetAttr.
+func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) {
+ return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS
+}
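+
+// The following sketch (illustrative only, not part of this package's API)
+// shows the shape of a minimal Attacher whose File embeds DefaultWalkGetAttr,
+// so that WalkGetAttr falls back to separate Walk and GetAttr calls:
+//
+//	type rootAttacher struct {
+//		root File // The server's root File implementation.
+//	}
+//
+//	func (a *rootAttacher) Attach() (File, error) {
+//		return a.root, nil
+//	}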
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
new file mode 100644
index 000000000..1db5797dd
--- /dev/null
+++ b/pkg/p9/handlers.go
@@ -0,0 +1,1393 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// ExtractErrno extracts a syscall.Errno from an error, best effort.
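+//
+// For example (illustrative): a *os.PathError wrapping syscall.ENOENT
+// unwraps to syscall.ENOENT, while an unrecognized error maps to
+// syscall.EIO with a warning logged.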
+func ExtractErrno(err error) syscall.Errno {
+ switch err {
+ case os.ErrNotExist:
+ return syscall.ENOENT
+ case os.ErrExist:
+ return syscall.EEXIST
+ case os.ErrPermission:
+ return syscall.EACCES
+ case os.ErrInvalid:
+ return syscall.EINVAL
+ }
+
+ // Attempt to unwrap.
+ switch e := err.(type) {
+ case syscall.Errno:
+ return e
+ case *os.PathError:
+ return ExtractErrno(e.Err)
+ case *os.SyscallError:
+ return ExtractErrno(e.Err)
+ case *os.LinkError:
+ return ExtractErrno(e.Err)
+ }
+
+ // Default case.
+ log.Warningf("unknown error: %v", err)
+ return syscall.EIO
+}
+
+// newErr returns a new error message from an error.
+func newErr(err error) *Rlerror {
+ return &Rlerror{Error: uint32(ExtractErrno(err))}
+}
+
+// handler is implemented for server-handled messages.
+//
+// See server.go for call information.
+type handler interface {
+ // Handle handles the given message.
+ //
+ // This may modify the server state. The handle function must return a
+ // message which will be sent back to the client. It may be useful to
+ // use newErr to automatically extract an error message.
+ handle(cs *connState) message
+}
+
+// handle implements handler.handle.
+func (t *Tversion) handle(cs *connState) message {
+ if t.MSize == 0 {
+ return newErr(syscall.EINVAL)
+ }
+ if t.MSize > maximumLength {
+ return newErr(syscall.EINVAL)
+ }
+ atomic.StoreUint32(&cs.messageSize, t.MSize)
+ requested, ok := parseVersion(t.Version)
+ if !ok {
+ return newErr(syscall.EINVAL)
+ }
+ // The server cannot support newer versions that it doesn't know about. In this
+ // case we return EAGAIN to tell the client to try again with a lower version.
+ if requested > highestSupportedVersion {
+ return newErr(syscall.EAGAIN)
+ }
+ // From Tversion(9P): "The server may respond with the client’s version
+ // string, or a version string identifying an earlier defined protocol version".
+ atomic.StoreUint32(&cs.version, requested)
+ return &Rversion{
+ MSize: t.MSize,
+ Version: t.Version,
+ }
+}
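+
+// For example (illustrative): a client requesting a version newer than the
+// server's highestSupportedVersion receives EAGAIN and should retry the
+// handshake with a lower version number.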
+
+// handle implements handler.handle.
+func (t *Tflush) handle(cs *connState) message {
+ cs.WaitTag(t.OldTag)
+ return &Rflush{}
+}
+
+// checkSafeName returns nil if the name is safe to use and an error otherwise.
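+//
+// For example (illustrative): "file.txt" is safe; "", ".", "..", and
+// "a/b" are all rejected with EINVAL.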
+func checkSafeName(name string) error {
+ if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." {
+ return nil
+ }
+ return syscall.EINVAL
+}
+
+// handle implements handler.handle.
+func (t *Tclunk) handle(cs *connState) message {
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ return &Rclunk{}
+}
+
+// handle implements handler.handle.
+func (t *Tremove) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Frustratingly, because we can't be guaranteed that a rename is not
+ // occurring simultaneously with this removal, we need to acquire the
+ // global rename lock for this kind of remove operation to ensure that
+ // ref.parent does not change out from underneath us.
+ //
+ // This is why Tremove is a bad idea, and clients should generally use
+ // Tunlinkat. All p9 clients will use Tunlinkat.
+ err := ref.safelyGlobal(func() error {
+ // Is this a root? Can't remove that.
+ if ref.isRoot() {
+ return syscall.EINVAL
+ }
+
+ // N.B. this remove operation is permitted, even if the file is open.
+ // See also rename below for reasoning.
+
+ // Is this file already deleted?
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Retrieve the file's proper name.
+ name := ref.parent.pathNode.nameFor(ref)
+
+ // Attempt the removal.
+ if err := ref.parent.file.UnlinkAt(name, 0); err != nil {
+ return err
+ }
+
+ // Mark all relevant fids as deleted. We don't need to lock any
+ // individual nodes because we already hold the global lock.
+ ref.parent.markChildDeleted(name)
+ return nil
+ })
+
+ // "The remove request asks the file server both to remove the file
+ // represented by fid and to clunk the fid, even if the remove fails."
+ //
+ // "It is correct to consider remove to be a clunk with the side effect
+ // of removing the file if permissions allow."
+ // https://swtch.com/plan9port/man/man9/remove.html
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ if err != nil {
+ return newErr(err)
+ }
+
+ return &Rremove{}
+}
+
+// handle implements handler.handle.
+//
+// We don't support authentication, so this just returns ENOSYS.
+func (t *Tauth) handle(cs *connState) message {
+ return newErr(syscall.ENOSYS)
+}
+
+// handle implements handler.handle.
+func (t *Tattach) handle(cs *connState) message {
+ // Ensure no authentication FID is provided.
+ if t.Auth.AuthenticationFID != NoFID {
+ return newErr(syscall.EINVAL)
+ }
+
+ // Treat the attach name as an absolute path.
+ if path.IsAbs(t.Auth.AttachName) {
+ // Trim off the leading / if the path is absolute. We always
+ // treat attach paths as absolute and call attach with the root
+ // argument on the server file for clarity.
+ t.Auth.AttachName = t.Auth.AttachName[1:]
+ }
+
+ // Do the attach on the root.
+ sf, err := cs.server.attacher.Attach()
+ if err != nil {
+ return newErr(err)
+ }
+ qid, valid, attr, err := sf.GetAttr(AttrMaskAll())
+ if err != nil {
+ sf.Close() // Drop file.
+ return newErr(err)
+ }
+ if !valid.Mode {
+ sf.Close() // Drop file.
+ return newErr(syscall.EINVAL)
+ }
+
+ // Build a transient reference.
+ root := &fidRef{
+ server: cs.server,
+ parent: nil,
+ file: sf,
+ refs: 1,
+ mode: attr.Mode.FileType(),
+ pathNode: cs.server.pathTree,
+ }
+ defer root.DecRef()
+
+ // Attach the root?
+ if len(t.Auth.AttachName) == 0 {
+ cs.InsertFID(t.FID, root)
+ return &Rattach{QID: qid}
+ }
+
+ // We want the same traversal checks to apply on attach, so always
+ // attach at the root and use the regular walk paths.
+ names := strings.Split(t.Auth.AttachName, "/")
+ _, newRef, _, _, err := doWalk(cs, root, names, false)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Insert the FID.
+ cs.InsertFID(t.FID, newRef)
+ return &Rattach{QID: qid}
+}
+
+// CanOpen returns whether a file with the given mode can be opened, read, and written to.
+//
+// This includes everything except symlinks and sockets.
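+//
+// For example (illustrative): CanOpen(ModeRegular) and CanOpen(ModeDirectory)
+// return true, while CanOpen(ModeSymlink) and CanOpen(ModeSocket) return false.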
+func CanOpen(mode FileMode) bool {
+ return mode.IsRegular() || mode.IsDir() || mode.IsNamedPipe() || mode.IsBlockDevice() || mode.IsCharacterDevice()
+}
+
+// handle implements handler.handle.
+func (t *Tlopen) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ ref.openedMu.Lock()
+ defer ref.openedMu.Unlock()
+
+ // Has it been opened already?
+ if ref.opened || !CanOpen(ref.mode) {
+ return newErr(syscall.EINVAL)
+ }
+
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return newErr(syscall.EISDIR)
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return newErr(syscall.EISDIR)
+ }
+ }
+
+ var (
+ qid QID
+ ioUnit uint32
+ osFile *fd.FD
+ )
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been deleted already?
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ // Mark file as opened and set open mode.
+ ref.opened = true
+ ref.openFlags = t.Flags
+
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
+}
+
+func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var (
+ osFile *fd.FD
+ nsf File
+ qid QID
+ ioUnit uint32
+ newRef *fidRef
+ )
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow creation from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the create.
+ osFile, nsf, qid, ioUnit, err = ref.file.Create(t.Name, t.OpenFlags, t.Permissions, uid, t.GID)
+ if err != nil {
+ return err
+ }
+
+ newRef = &fidRef{
+ server: cs.server,
+ parent: ref,
+ file: nsf,
+ opened: true,
+ openFlags: t.OpenFlags,
+ mode: ModeRegular,
+ pathNode: ref.pathNode.pathNodeFor(t.Name),
+ }
+ ref.pathNode.addChild(newRef, t.Name)
+ ref.IncRef() // Acquire parent reference.
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
+ // Replace the FID reference.
+ cs.InsertFID(t.FID, newRef)
+
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
+}
+
+// handle implements handler.handle.
+func (t *Tlcreate) handle(cs *connState) message {
+ rlcreate, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rlcreate
+}
+
+// handle implements handler.handle.
+func (t *Tsymlink) handle(cs *connState) message {
+ rsymlink, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rsymlink
+}
+
+func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow symlinks from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the symlink.
+ qid, err = ref.file.Symlink(t.Target, t.Name, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rsymlink{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tlink) handle(cs *connState) message {
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ refTarget, ok := cs.LookupFID(t.Target)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow creating links from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the link.
+ return ref.file.Link(refTarget.file, t.Name)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rlink{}
+}
+
+// handle implements handler.handle.
+func (t *Trenameat) handle(cs *connState) message {
+ if err := checkSafeName(t.OldName); err != nil {
+ return newErr(err)
+ }
+ if err := checkSafeName(t.NewName); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.OldDirectory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ refTarget, ok := cs.LookupFID(t.NewDirectory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ // Perform the rename holding the global lock.
+ if err := ref.safelyGlobal(func() (err error) {
+ // Don't allow renaming across deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() || refTarget.isDeleted() || !refTarget.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Is this the same file? If yes, short-circuit and return success.
+ if ref.pathNode == refTarget.pathNode && t.OldName == t.NewName {
+ return nil
+ }
+
+ // Attempt the actual rename.
+ if err := ref.file.RenameAt(t.OldName, refTarget.file, t.NewName); err != nil {
+ return err
+ }
+
+ // Update the path tree.
+ ref.renameChildTo(t.OldName, refTarget, t.NewName)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rrenameat{}
+}
+
+// handle implements handler.handle.
+func (t *Tunlinkat) handle(cs *connState) message {
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow deletion from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Before we do the unlink itself, we need to ensure that there
+ // are no operations in flight on associated path node. The
+ // child's path node lock must be held to ensure that the
+ // unlinkat marking the child deleted below is atomic with
+ // respect to any other read or write operations.
+ //
+ // This is one case where we have a lock ordering issue, but
+ // since we always acquire deeper in the hierarchy, we know
+ // that we are free of lock cycles.
+ childPathNode := ref.pathNode.pathNodeFor(t.Name)
+ childPathNode.opMu.Lock()
+ defer childPathNode.opMu.Unlock()
+
+ // Do the unlink.
+ err = ref.file.UnlinkAt(t.Name, t.Flags)
+ if err != nil {
+ return err
+ }
+
+ // Mark the path as deleted.
+ ref.markChildDeleted(t.Name)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Runlinkat{}
+}
+
+// handle implements handler.handle.
+func (t *Trename) handle(cs *connState) message {
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ refTarget, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ if err := ref.safelyGlobal(func() (err error) {
+ // Don't allow a root rename.
+ if ref.isRoot() {
+ return syscall.EINVAL
+ }
+
+ // Don't allow renaming deleted entries, or renaming into deleted or non-directory targets.
+ if ref.isDeleted() || refTarget.isDeleted() || !refTarget.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // If the parent is deleted, but we are not, something is seriously
+ // wrong. It's fine to die at this point with an assertion failure.
+ if ref.parent.isDeleted() {
+ panic(fmt.Sprintf("parent %+v deleted, child %+v is not", ref.parent, ref))
+ }
+
+ // N.B. The rename operation is allowed to proceed on open files. It
+ // does impact the state of its parent, but this is merely a sanity
+ // check in any case, and the operation is safe. There may be other
+ // files corresponding to the same path that are renamed anyways.
+
+ // Check for the exact same file and short-circuit.
+ oldName := ref.parent.pathNode.nameFor(ref)
+ if ref.parent.pathNode == refTarget.pathNode && oldName == t.Name {
+ return nil
+ }
+
+ // Call the rename method on the parent.
+ if err := ref.parent.file.RenameAt(oldName, refTarget.file, t.Name); err != nil {
+ return err
+ }
+
+ // Update the path tree.
+ ref.parent.renameChildTo(oldName, refTarget, t.Name)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rrename{}
+}
+
+// handle implements handler.handle.
+func (t *Treadlink) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var target string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow readlink on deleted files. There is no need to
+ // check if this file is opened because symlinks cannot be
+ // opened.
+ if ref.isDeleted() || !ref.mode.IsSymlink() {
+ return syscall.EINVAL
+ }
+
+ // Do the read.
+ target, err = ref.file.Readlink()
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rreadlink{target}
+}
+
+// handle implements handler.handle.
+func (t *Tread) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Constrain the size of the read buffer.
+ if int(t.Count) > int(maximumLength) {
+ return newErr(syscall.ENOBUFS)
+ }
+
+ var (
+ data = make([]byte, t.Count)
+ n int
+ )
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be read? Check permissions.
+ if openFlags&OpenFlagsModeMask == WriteOnly {
+ return syscall.EPERM
+ }
+
+ n, err = ref.file.ReadAt(data, t.Offset)
+ return err
+ }); err != nil && err != io.EOF {
+ return newErr(err)
+ }
+
+ return &Rread{Data: data[:n]}
+}
+
+// handle implements handler.handle.
+func (t *Twrite) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var n int
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be written? Check permissions.
+ if openFlags&OpenFlagsModeMask == ReadOnly {
+ return syscall.EPERM
+ }
+
+ n, err = ref.file.WriteAt(t.Data, t.Offset)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rwrite{Count: uint32(n)}
+}
+
+// handle implements handler.handle.
+func (t *Tmknod) handle(cs *connState) message {
+ rmknod, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rmknod
+}
+
+func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow mknod on deleted files.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the mknod.
+ qid, err = ref.file.Mknod(t.Name, t.Mode, t.Major, t.Minor, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rmknod{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tmkdir) handle(cs *connState) message {
+ rmkdir, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rmkdir
+}
+
+func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow mkdir on deleted files.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the mkdir.
+ qid, err = ref.file.Mkdir(t.Name, t.Permissions, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rmkdir{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tgetattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We allow getattr on deleted files. Depending on the backing
+ // implementation, it's possible that races exist that might allow
+ // fetching attributes of other files. But we generally need to allow
+ // refreshing attributes, and this is at most a minor leak.
+
+ var (
+ qid QID
+ valid AttrMask
+ attr Attr
+ )
+ if err := ref.safelyRead(func() (err error) {
+ qid, valid, attr, err = ref.file.GetAttr(t.AttrMask)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rgetattr{QID: qid, Valid: valid, Attr: attr}
+}
+
+// handle implements handler.handle.
+func (t *Tsetattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // We don't allow setattr on files that have been deleted.
+ // This might be technically incorrect, as it's possible that
+ // there were multiple links and you can still change the
+ // corresponding inode information.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Set the attributes.
+ return ref.file.SetAttr(t.Valid, t.SetAttr)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rsetattr{}
+}
+
+// handle implements handler.handle.
+func (t *Tallocate) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be written? Check permissions.
+ if openFlags&OpenFlagsModeMask == ReadOnly {
+ return syscall.EBADF
+ }
+
+ // We don't allow allocate on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ return ref.file.Allocate(t.Mode, t.Offset, t.Length)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rallocate{}
+}
+
+// handle implements handler.handle.
+func (t *Txattrwalk) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We don't support extended attributes.
+ return newErr(syscall.ENODATA)
+}
+
+// handle implements handler.handle.
+func (t *Txattrcreate) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We don't support extended attributes.
+ return newErr(syscall.ENOSYS)
+}
+
+// handle implements handler.handle.
+func (t *Tgetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var val string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow getxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ val, err = ref.file.GetXattr(t.Name, t.Size)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rgetxattr{Value: val}
+}
+
+// handle implements handler.handle.
+func (t *Tsetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Don't allow setxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.SetXattr(t.Name, t.Value, t.Flags)
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rsetxattr{}
+}
+
+// handle implements handler.handle.
+func (t *Tlistxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var xattrs map[string]struct{}
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow listxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ xattrs, err = ref.file.ListXattr(t.Size)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ xattrList := make([]string, 0, len(xattrs))
+ for x := range xattrs {
+ xattrList = append(xattrList, x)
+ }
+ return &Rlistxattr{Xattrs: xattrList}
+}
+
+// handle implements handler.handle.
+func (t *Tremovexattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Don't allow removexattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.RemoveXattr(t.Name)
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rremovexattr{}
+}
+
+// handle implements handler.handle.
+func (t *Treaddir) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var entries []Dirent
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow reading deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); !opened {
+ return syscall.EINVAL
+ }
+
+ // Read the entries.
+ entries, err = ref.file.Readdir(t.Offset, t.Count)
+ if err != nil && err != io.EOF {
+ return err
+ }
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rreaddir{Count: t.Count, Entries: entries}
+}
+
+// handle implements handler.handle.
+func (t *Tfsync) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); !opened {
+ return syscall.EINVAL
+ }
+
+ // Perform the sync.
+ return ref.file.FSync()
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rfsync{}
+}
+
+// handle implements handler.handle.
+func (t *Tstatfs) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ st, err := ref.file.StatFS()
+ if err != nil {
+ return newErr(err)
+ }
+
+ return &Rstatfs{st}
+}
+
+// handle implements handler.handle.
+func (t *Tflushf) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyRead(ref.file.Flush); err != nil {
+ return newErr(err)
+ }
+
+ return &Rflushf{}
+}
+
+// walkOne walks zero or one path elements.
+//
+// The slice passed as qids is appended to and returned.
+func walkOne(qids []QID, from File, names []string, getattr bool) ([]QID, File, AttrMask, Attr, error) {
+ if len(names) > 1 {
+ // We require exactly zero or one elements.
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+ var (
+ localQIDs []QID
+ sf File
+ valid AttrMask
+ attr Attr
+ err error
+ )
+ switch {
+ case getattr:
+ localQIDs, sf, valid, attr, err = from.WalkGetAttr(names)
+ // Can't put fallthrough in the if, because Go only allows it as the
+ // final statement of a case.
+ if err != syscall.ENOSYS {
+ break
+ }
+ fallthrough
+ default:
+ localQIDs, sf, err = from.Walk(names)
+ if err != nil {
+ // No way to walk this element.
+ break
+ }
+ if getattr {
+ _, valid, attr, err = sf.GetAttr(AttrMaskAll())
+ if err != nil {
+ // Don't leak the file.
+ sf.Close()
+ }
+ }
+ }
+ if err != nil {
+ // Error walking, don't return anything.
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ if len(localQIDs) != 1 {
+ // Expected a single QID.
+ sf.Close()
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+ return append(qids, localQIDs...), sf, valid, attr, nil
+}
+
+// doWalk walks from a given fidRef.
+//
+// This enforces that all intermediate nodes are walkable (directories). The
+// fidRef returned (newRef) has a reference associated with it that is now
+// owned by the caller and must be handled appropriately.
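+//
+// For example (illustrative): walking names = ["a", "b", "c"] yields up to
+// three QIDs and a fidRef for "c"; every intermediate component must be a
+// directory, and a failure partway through returns an error with no
+// reference retained.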
+func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QID, newRef *fidRef, valid AttrMask, attr Attr, err error) {
+ // Check the names.
+ for _, name := range names {
+ err = checkSafeName(name)
+ if err != nil {
+ return
+ }
+ }
+
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); opened {
+ err = syscall.EBUSY
+ return
+ }
+
+ // Is this an empty list? Handle specially. We don't actually need to
+ // validate anything since this is always permitted.
+ if len(names) == 0 {
+ var sf File // Temporary.
+ if err := ref.maybeParent().safelyRead(func() (err error) {
+ // Clone the single element.
+ qids, sf, valid, attr, err = walkOne(nil, ref.file, nil, getattr)
+ if err != nil {
+ return err
+ }
+
+ newRef = &fidRef{
+ server: cs.server,
+ parent: ref.parent,
+ file: sf,
+ mode: ref.mode,
+ pathNode: ref.pathNode,
+
+ // For the clone case, the cloned fid must
+ // preserve the deleted property of the
+ // original FID.
+ deleted: ref.deleted,
+ }
+ if !ref.isRoot() {
+ if !newRef.isDeleted() {
+ // Add the clone under the original's name; only non-root, non-deleted nodes are tracked in the path tree.
+ ref.parent.pathNode.addChild(newRef, ref.parent.pathNode.nameFor(ref))
+ }
+ ref.parent.IncRef() // Acquire parent reference.
+ }
+ // doWalk returns a reference.
+ newRef.IncRef()
+ return nil
+ }); err != nil {
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ // Do not return the new QID.
+ return nil, newRef, valid, attr, nil
+ }
+
+ // Do the walk, one element at a time.
+ walkRef := ref
+ walkRef.IncRef()
+ for i := 0; i < len(names); i++ {
+ // We won't allow walking past symlinks; stop here if this isn't
+ // a proper directory and we have additional paths to walk.
+ if !walkRef.mode.IsDir() {
+ walkRef.DecRef() // Drop walk reference; no lock required.
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+
+ var sf File // Temporary.
+ if err := walkRef.safelyRead(func() (err error) {
+ // Pass getattr = true to walkOne since we need the file type for
+ // newRef.
+ qids, sf, valid, attr, err = walkOne(qids, walkRef.file, names[i:i+1], true)
+ if err != nil {
+ return err
+ }
+
+ // Note that we don't need to acquire a lock on any of
+ // these individual instances. That's because they are
+ // not actually addressable via a FID. They are
+ // anonymous. They exist in the tree for tracking
+ // purposes.
+ newRef := &fidRef{
+ server: cs.server,
+ parent: walkRef,
+ file: sf,
+ mode: attr.Mode.FileType(),
+ pathNode: walkRef.pathNode.pathNodeFor(names[i]),
+ }
+ walkRef.pathNode.addChild(newRef, names[i])
+ // We allow our walk reference to become the new parent
+ // reference here and so we don't IncRef. Instead, just
+ // set walkRef to the newRef above and acquire a new
+ // walk reference.
+ walkRef = newRef
+ walkRef.IncRef()
+ return nil
+ }); err != nil {
+ walkRef.DecRef() // Drop the old walkRef.
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ }
+
+ // Success.
+ return qids, walkRef, valid, attr, nil
+}
+
+// handle implements handler.handle.
+func (t *Twalk) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Do the walk.
+ qids, newRef, _, _, err := doWalk(cs, ref, t.Names, false)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Install the new FID.
+ cs.InsertFID(t.NewFID, newRef)
+ return &Rwalk{QIDs: qids}
+}
+
+// handle implements handler.handle.
+func (t *Twalkgetattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Do the walk.
+ qids, newRef, valid, attr, err := doWalk(cs, ref, t.Names, true)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Install the new FID.
+ cs.InsertFID(t.NewFID, newRef)
+ return &Rwalkgetattr{QIDs: qids, Valid: valid, Attr: attr}
+}
+
+// handle implements handler.handle.
+func (t *Tucreate) handle(cs *connState) message {
+ rlcreate, err := t.Tlcreate.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rucreate{*rlcreate}
+}
+
+// handle implements handler.handle.
+func (t *Tumkdir) handle(cs *connState) message {
+ rmkdir, err := t.Tmkdir.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rumkdir{*rmkdir}
+}
+
+// handle implements handler.handle.
+func (t *Tusymlink) handle(cs *connState) message {
+ rsymlink, err := t.Tsymlink.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rusymlink{*rsymlink}
+}
+
+// handle implements handler.handle.
+func (t *Tumknod) handle(cs *connState) message {
+ rmknod, err := t.Tmknod.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rumknod{*rmknod}
+}
+
+// handle implements handler.handle.
+func (t *Tlconnect) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var osFile *fd.FD
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow connecting to deleted files.
+ if ref.isDeleted() || !ref.mode.IsSocket() {
+ return syscall.EINVAL
+ }
+
+ // Do the connect.
+ osFile, err = ref.file.Connect(t.Flags)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending a descriptor is
+ // a destructive operation in sendRecvLegacy (and in the newer channel
+ // send operations). The same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
+}
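+
+// The Dup calls above hand the caller descriptors it owns outright: since
+// sending a file payload is destructive (the descriptor is closed once
+// transmitted), donating the channel's only copy would break any later
+// Tchannel request.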
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
new file mode 100644
index 000000000..57b89ad7d
--- /dev/null
+++ b/pkg/p9/messages.go
@@ -0,0 +1,2662 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/fd"
+)
+
+// ErrInvalidMsgType is returned when an unsupported message type is found.
+type ErrInvalidMsgType struct {
+ MsgType
+}
+
+// Error returns a useful string.
+func (e *ErrInvalidMsgType) Error() string {
+ return fmt.Sprintf("invalid message type: %d", e.MsgType)
+}
+
+// message is a generic 9P message.
+type message interface {
+ encoder
+ fmt.Stringer
+
+ // Type returns the message type number.
+ Type() MsgType
+}
+
+// payloader is a special message which may include an inline payload.
+type payloader interface {
+ // FixedSize returns the size of the fixed portion of this message.
+ FixedSize() uint32
+
+ // Payload returns the payload for sending.
+ Payload() []byte
+
+ // SetPayload sets the payload on a decoded message.
+ //
+ // Its length will be the total message size minus FixedSize. This
+ // should be validated during decode, which is called after SetPayload.
+ SetPayload([]byte)
+}
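+
+// To make the sizing concrete (illustrative, assuming the standard 9P
+// framing of a 7-byte size/type/tag header): a payloader carried in a
+// message of total size M has a payload of M - 7 - FixedSize() bytes. An
+// Rread in a 4103-byte message, with its 4-byte fixed count field, thus
+// carries 4103 - 7 - 4 = 4092 bytes of data.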
+
+// filer is a message capable of passing a file.
+type filer interface {
+ // FilePayload returns the file payload.
+ FilePayload() *fd.FD
+
+ // SetFilePayload sets the file payload.
+ SetFilePayload(*fd.FD)
+}
+
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
+// Tversion is a version request.
+type Tversion struct {
+ // MSize is the message size to use.
+ MSize uint32
+
+ // Version is the version string.
+ //
+ // For this implementation, this must be 9P2000.L.
+ Version string
+}
+
+// decode implements encoder.decode.
+func (t *Tversion) decode(b *buffer) {
+ t.MSize = b.Read32()
+ t.Version = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tversion) encode(b *buffer) {
+ b.Write32(t.MSize)
+ b.WriteString(t.Version)
+}
+
+// Type implements message.Type.
+func (*Tversion) Type() MsgType {
+ return MsgTversion
+}
+
+// String implements fmt.Stringer.
+func (t *Tversion) String() string {
+ return fmt.Sprintf("Tversion{MSize: %d, Version: %s}", t.MSize, t.Version)
+}
+
+// Rversion is a version response.
+type Rversion struct {
+ // MSize is the negotiated size.
+ MSize uint32
+
+ // Version is the negotiated version.
+ Version string
+}
+
+// decode implements encoder.decode.
+func (r *Rversion) decode(b *buffer) {
+ r.MSize = b.Read32()
+ r.Version = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rversion) encode(b *buffer) {
+ b.Write32(r.MSize)
+ b.WriteString(r.Version)
+}
+
+// Type implements message.Type.
+func (*Rversion) Type() MsgType {
+ return MsgRversion
+}
+
+// String implements fmt.Stringer.
+func (r *Rversion) String() string {
+ return fmt.Sprintf("Rversion{MSize: %d, Version: %s}", r.MSize, r.Version)
+}
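+
+// A typical negotiation, for illustration: the client proposes
+// Tversion{MSize: 1 << 20, Version: "9P2000.L"} and the server answers with
+// an MSize no larger than the client's and the highest version string it is
+// willing to speak; both sides then frame all subsequent messages within
+// the negotiated MSize.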
+
+// Tflush is a flush request.
+type Tflush struct {
+ // OldTag is the tag to wait on.
+ OldTag Tag
+}
+
+// decode implements encoder.decode.
+func (t *Tflush) decode(b *buffer) {
+ t.OldTag = b.ReadTag()
+}
+
+// encode implements encoder.encode.
+func (t *Tflush) encode(b *buffer) {
+ b.WriteTag(t.OldTag)
+}
+
+// Type implements message.Type.
+func (*Tflush) Type() MsgType {
+ return MsgTflush
+}
+
+// String implements fmt.Stringer.
+func (t *Tflush) String() string {
+ return fmt.Sprintf("Tflush{OldTag: %d}", t.OldTag)
+}
+
+// Rflush is a flush response.
+type Rflush struct {
+}
+
+// decode implements encoder.decode.
+func (*Rflush) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rflush) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rflush) Type() MsgType {
+ return MsgRflush
+}
+
+// String implements fmt.Stringer.
+func (r *Rflush) String() string {
+ return "RFlush{}"
+}
+
+// Twalk is a walk request.
+type Twalk struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // NewFID is the resulting FID.
+ NewFID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// decode implements encoder.decode.
+func (t *Twalk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (t *Twalk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Twalk) Type() MsgType {
+ return MsgTwalk
+}
+
+// String implements fmt.Stringer.
+func (t *Twalk) String() string {
+ return fmt.Sprintf("Twalk{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names)
+}
+
+// Rwalk is a walk response.
+type Rwalk struct {
+ // QIDs are the set of QIDs returned.
+ QIDs []QID
+}
+
+// decode implements encoder.decode.
+func (r *Rwalk) decode(b *buffer) {
+ n := b.Read16()
+ r.QIDs = r.QIDs[:0]
+ for i := 0; i < int(n); i++ {
+ var q QID
+ q.decode(b)
+ r.QIDs = append(r.QIDs, q)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rwalk) encode(b *buffer) {
+ b.Write16(uint16(len(r.QIDs)))
+ for _, q := range r.QIDs {
+ q.encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rwalk) Type() MsgType {
+ return MsgRwalk
+}
+
+// String implements fmt.Stringer.
+func (r *Rwalk) String() string {
+ return fmt.Sprintf("Rwalk{QIDs: %v}", r.QIDs)
+}
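+
+// For illustration: a Twalk with Names ["usr", "bin"] that fully succeeds
+// returns an Rwalk with exactly one QID per name walked, while a Twalk with
+// zero names is the conventional way to clone an existing FID.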
+
+// Tclunk is a close request.
+type Tclunk struct {
+ // FID is the FID to be closed.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tclunk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tclunk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tclunk) Type() MsgType {
+ return MsgTclunk
+}
+
+// String implements fmt.Stringer.
+func (t *Tclunk) String() string {
+ return fmt.Sprintf("Tclunk{FID: %d}", t.FID)
+}
+
+// Rclunk is a close response.
+type Rclunk struct {
+}
+
+// decode implements encoder.decode.
+func (*Rclunk) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rclunk) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rclunk) Type() MsgType {
+ return MsgRclunk
+}
+
+// String implements fmt.Stringer.
+func (r *Rclunk) String() string {
+ return "Rclunk{}"
+}
+
+// Tremove is a remove request.
+//
+// This will eventually be replaced by Tunlinkat.
+type Tremove struct {
+ // FID is the FID to be removed.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tremove) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tremove) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tremove) Type() MsgType {
+ return MsgTremove
+}
+
+// String implements fmt.Stringer.
+func (t *Tremove) String() string {
+ return fmt.Sprintf("Tremove{FID: %d}", t.FID)
+}
+
+// Rremove is a remove response.
+type Rremove struct {
+}
+
+// decode implements encoder.decode.
+func (*Rremove) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rremove) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremove) Type() MsgType {
+ return MsgRremove
+}
+
+// String implements fmt.Stringer.
+func (r *Rremove) String() string {
+ return "Rremove{}"
+}
+
+// Rlerror is an error response.
+//
+// Note that this replaces the error code used in 9p.
+type Rlerror struct {
+ Error uint32
+}
+
+// decode implements encoder.decode.
+func (r *Rlerror) decode(b *buffer) {
+ r.Error = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (r *Rlerror) encode(b *buffer) {
+ b.Write32(r.Error)
+}
+
+// Type implements message.Type.
+func (*Rlerror) Type() MsgType {
+ return MsgRlerror
+}
+
+// String implements fmt.Stringer.
+func (r *Rlerror) String() string {
+ return fmt.Sprintf("Rlerror{Error: %d}", r.Error)
+}
+
+// Tauth is an authentication request.
+type Tauth struct {
+ // AuthenticationFID is the FID to which to attach the authentication result.
+ AuthenticationFID FID
+
+ // UserName is the user to attach.
+ UserName string
+
+ // AttachName is the attach name.
+ AttachName string
+
+ // UID is the numeric identifier for UserName.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tauth) decode(b *buffer) {
+ t.AuthenticationFID = b.ReadFID()
+ t.UserName = b.ReadString()
+ t.AttachName = b.ReadString()
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tauth) encode(b *buffer) {
+ b.WriteFID(t.AuthenticationFID)
+ b.WriteString(t.UserName)
+ b.WriteString(t.AttachName)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (*Tauth) Type() MsgType {
+ return MsgTauth
+}
+
+// String implements fmt.Stringer.
+func (t *Tauth) String() string {
+ return fmt.Sprintf("Tauth{AuthFID: %d, UserName: %s, AttachName: %s, UID: %d", t.AuthenticationFID, t.UserName, t.AttachName, t.UID)
+}
+
+// Rauth is an authentication response.
+//
+// encode and decode are inherited directly from QID.
+type Rauth struct {
+ QID
+}
+
+// Type implements message.Type.
+func (*Rauth) Type() MsgType {
+ return MsgRauth
+}
+
+// String implements fmt.Stringer.
+func (r *Rauth) String() string {
+ return fmt.Sprintf("Rauth{QID: %s}", r.QID)
+}
+
+// Tattach is an attach request.
+type Tattach struct {
+ // FID is the FID to be attached.
+ FID FID
+
+ // Auth is the embedded authentication request.
+ //
+ // See client.Attach for information regarding authentication.
+ Auth Tauth
+}
+
+// decode implements encoder.decode.
+func (t *Tattach) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Auth.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tattach) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Auth.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tattach) Type() MsgType {
+ return MsgTattach
+}
+
+// String implements fmt.Stringer.
+func (t *Tattach) String() string {
+ return fmt.Sprintf("Tattach{FID: %d, AuthFID: %d, UserName: %s, AttachName: %s, UID: %d}", t.FID, t.Auth.AuthenticationFID, t.Auth.UserName, t.Auth.AttachName, t.Auth.UID)
+}
+
+// Rattach is an attach response.
+type Rattach struct {
+ QID
+}
+
+// Type implements message.Type.
+func (*Rattach) Type() MsgType {
+ return MsgRattach
+}
+
+// String implements fmt.Stringer.
+func (r *Rattach) String() string {
+ return fmt.Sprintf("Rattach{QID: %s}", r.QID)
+}
+
+// Tlopen is an open request.
+type Tlopen struct {
+ // FID is the FID to be opened.
+ FID FID
+
+ // Flags are the open flags.
+ Flags OpenFlags
+}
+
+// decode implements encoder.decode.
+func (t *Tlopen) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Flags = b.ReadOpenFlags()
+}
+
+// encode implements encoder.encode.
+func (t *Tlopen) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteOpenFlags(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tlopen) Type() MsgType {
+ return MsgTlopen
+}
+
+// String implements fmt.Stringer.
+func (t *Tlopen) String() string {
+ return fmt.Sprintf("Tlopen{FID: %d, Flags: %v}", t.FID, t.Flags)
+}
+
+// Rlopen is an open response.
+type Rlopen struct {
+ // QID is the file's QID.
+ QID QID
+
+ // IoUnit is the recommended I/O unit.
+ IoUnit uint32
+
+ filePayload
+}
+
+// decode implements encoder.decode.
+func (r *Rlopen) decode(b *buffer) {
+ r.QID.decode(b)
+ r.IoUnit = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (r *Rlopen) encode(b *buffer) {
+ r.QID.encode(b)
+ b.Write32(r.IoUnit)
+}
+
+// Type implements message.Type.
+func (*Rlopen) Type() MsgType {
+ return MsgRlopen
+}
+
+// String implements fmt.Stringer.
+func (r *Rlopen) String() string {
+ return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
+}
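+
+// The embedded filePayload lets the server donate a host file descriptor
+// alongside the open response; on this package's Unix-socket transport the
+// descriptor travels out-of-band as a control message rather than in the
+// encoded buffer.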
+
+// Tlcreate is a create request.
+type Tlcreate struct {
+ // FID is the parent FID.
+ //
+ // This becomes the new file.
+ FID FID
+
+ // Name is the file name to create.
+ Name string
+
+ // OpenFlags is the open mode (O_RDWR, etc.).
+ //
+ // Note that flags like O_TRUNC are ignored, as is O_EXCL. All
+ // create operations are exclusive.
+ OpenFlags OpenFlags
+
+ // Permissions is the set of permission bits.
+ Permissions FileMode
+
+ // GID is the group ID to use for creating the file.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tlcreate) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.OpenFlags = b.ReadOpenFlags()
+ t.Permissions = b.ReadPermissions()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tlcreate) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteOpenFlags(t.OpenFlags)
+ b.WritePermissions(t.Permissions)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tlcreate) Type() MsgType {
+ return MsgTlcreate
+}
+
+// String implements fmt.Stringer.
+func (t *Tlcreate) String() string {
+ return fmt.Sprintf("Tlcreate{FID: %d, Name: %s, OpenFlags: %s, Permissions: 0o%o, GID: %d}", t.FID, t.Name, t.OpenFlags, t.Permissions, t.GID)
+}
+
+// Rlcreate is a create response.
+//
+// The encode, decode, etc. methods are inherited from Rlopen.
+type Rlcreate struct {
+ Rlopen
+}
+
+// Type implements message.Type.
+func (*Rlcreate) Type() MsgType {
+ return MsgRlcreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rlcreate) String() string {
+ return fmt.Sprintf("Rlcreate{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
+}
+
+// Tsymlink is a symlink request.
+type Tsymlink struct {
+ // Directory is the directory FID.
+ Directory FID
+
+ // Name is the name of the new symlink in the directory.
+ Name string
+
+ // Target is the symlink target.
+ Target string
+
+ // GID is the owning group.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tsymlink) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Target = b.ReadString()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tsymlink) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WriteString(t.Target)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tsymlink) Type() MsgType {
+ return MsgTsymlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tsymlink) String() string {
+ return fmt.Sprintf("Tsymlink{DirectoryFID: %d, Name: %s, Target: %s, GID: %d}", t.Directory, t.Name, t.Target, t.GID)
+}
+
+// Rsymlink is a symlink response.
+type Rsymlink struct {
+ // QID is the new symlink's QID.
+ QID QID
+}
+
+// decode implements encoder.decode.
+func (r *Rsymlink) decode(b *buffer) {
+ r.QID.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rsymlink) encode(b *buffer) {
+ r.QID.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rsymlink) Type() MsgType {
+ return MsgRsymlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rsymlink) String() string {
+ return fmt.Sprintf("Rsymlink{QID: %s}", r.QID)
+}
+
+// Tlink is a link request.
+type Tlink struct {
+ // Directory is the directory to contain the link.
+ Directory FID
+
+ // Target is the FID of the link target.
+ Target FID
+
+ // Name is the name of the new link.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Tlink) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Target = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tlink) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteFID(t.Target)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tlink) Type() MsgType {
+ return MsgTlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tlink) String() string {
+ return fmt.Sprintf("Tlink{DirectoryFID: %d, TargetFID: %d, Name: %s}", t.Directory, t.Target, t.Name)
+}
+
+// Rlink is a link response.
+type Rlink struct {
+}
+
+// Type implements message.Type.
+func (*Rlink) Type() MsgType {
+ return MsgRlink
+}
+
+// decode implements encoder.decode.
+func (*Rlink) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rlink) encode(*buffer) {
+}
+
+// String implements fmt.Stringer.
+func (r *Rlink) String() string {
+ return "Rlink{}"
+}
+
+// Trenameat is a rename request.
+type Trenameat struct {
+ // OldDirectory is the source directory.
+ OldDirectory FID
+
+ // OldName is the source file name.
+ OldName string
+
+ // NewDirectory is the target directory.
+ NewDirectory FID
+
+ // NewName is the new file name.
+ NewName string
+}
+
+// decode implements encoder.decode.
+func (t *Trenameat) decode(b *buffer) {
+ t.OldDirectory = b.ReadFID()
+ t.OldName = b.ReadString()
+ t.NewDirectory = b.ReadFID()
+ t.NewName = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Trenameat) encode(b *buffer) {
+ b.WriteFID(t.OldDirectory)
+ b.WriteString(t.OldName)
+ b.WriteFID(t.NewDirectory)
+ b.WriteString(t.NewName)
+}
+
+// Type implements message.Type.
+func (*Trenameat) Type() MsgType {
+ return MsgTrenameat
+}
+
+// String implements fmt.Stringer.
+func (t *Trenameat) String() string {
+ return fmt.Sprintf("TrenameAt{OldDirectoryFID: %d, OldName: %s, NewDirectoryFID: %d, NewName: %s}", t.OldDirectory, t.OldName, t.NewDirectory, t.NewName)
+}
+
+// Rrenameat is a rename response.
+type Rrenameat struct {
+}
+
+// decode implements encoder.decode.
+func (*Rrenameat) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rrenameat) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rrenameat) Type() MsgType {
+ return MsgRrenameat
+}
+
+// String implements fmt.Stringer.
+func (r *Rrenameat) String() string {
+ return "Rrenameat{}"
+}
+
+// Tunlinkat is an unlink request.
+type Tunlinkat struct {
+ // Directory is the originating directory.
+ Directory FID
+
+ // Name is the name of the entry to unlink.
+ Name string
+
+ // Flags are extra flags (e.g. O_DIRECTORY). These are not interpreted by p9.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tunlinkat) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tunlinkat) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tunlinkat) Type() MsgType {
+ return MsgTunlinkat
+}
+
+// String implements fmt.Stringer.
+func (t *Tunlinkat) String() string {
+ return fmt.Sprintf("Tunlinkat{DirectoryFID: %d, Name: %s, Flags: 0x%X}", t.Directory, t.Name, t.Flags)
+}
+
+// Runlinkat is an unlink response.
+type Runlinkat struct {
+}
+
+// decode implements encoder.decode.
+func (*Runlinkat) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Runlinkat) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Runlinkat) Type() MsgType {
+ return MsgRunlinkat
+}
+
+// String implements fmt.Stringer.
+func (r *Runlinkat) String() string {
+ return "Runlinkat{}"
+}
+
+// Trename is a rename request.
+//
+// Note that this generally isn't used anymore, and ideally all rename calls
+// should use Trenameat below.
+type Trename struct {
+ // FID is the FID to rename.
+ FID FID
+
+ // Directory is the target directory.
+ Directory FID
+
+ // Name is the new file name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Trename) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Trename) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Trename) Type() MsgType {
+ return MsgTrename
+}
+
+// String implements fmt.Stringer.
+func (t *Trename) String() string {
+ return fmt.Sprintf("Trename{FID: %d, DirectoryFID: %d, Name: %s}", t.FID, t.Directory, t.Name)
+}
+
+// Rrename is a rename response.
+type Rrename struct {
+}
+
+// decode implements encoder.decode.
+func (*Rrename) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rrename) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rrename) Type() MsgType {
+ return MsgRrename
+}
+
+// String implements fmt.Stringer.
+func (r *Rrename) String() string {
+ return "Rrename{}"
+}
+
+// Treadlink is a readlink request.
+type Treadlink struct {
+ // FID is the symlink.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Treadlink) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Treadlink) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Treadlink) Type() MsgType {
+ return MsgTreadlink
+}
+
+// String implements fmt.Stringer.
+func (t *Treadlink) String() string {
+ return fmt.Sprintf("Treadlink{FID: %d}", t.FID)
+}
+
+// Rreadlink is a readlink response.
+type Rreadlink struct {
+ // Target is the symlink target.
+ Target string
+}
+
+// decode implements encoder.decode.
+func (r *Rreadlink) decode(b *buffer) {
+ r.Target = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rreadlink) encode(b *buffer) {
+ b.WriteString(r.Target)
+}
+
+// Type implements message.Type.
+func (*Rreadlink) Type() MsgType {
+ return MsgRreadlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rreadlink) String() string {
+ return fmt.Sprintf("Rreadlink{Target: %s}", r.Target)
+}
+
+// Tread is a read request.
+type Tread struct {
+ // FID is the FID to read.
+ FID FID
+
+ // Offset indicates the file offset.
+ Offset uint64
+
+ // Count indicates the number of bytes to read.
+ Count uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tread) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Offset = b.Read64()
+ t.Count = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tread) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Offset)
+ b.Write32(t.Count)
+}
+
+// Type implements message.Type.
+func (*Tread) Type() MsgType {
+ return MsgTread
+}
+
+// String implements fmt.Stringer.
+func (t *Tread) String() string {
+ return fmt.Sprintf("Tread{FID: %d, Offset: %d, Count: %d}", t.FID, t.Offset, t.Count)
+}
+
+// Rread is the response for a Tread.
+type Rread struct {
+ // Data is the resulting data.
+ Data []byte
+}
+
+// decode implements encoder.decode.
+//
+// Data is automatically decoded via Payload.
+func (r *Rread) decode(b *buffer) {
+ count := b.Read32()
+ if count != uint32(len(r.Data)) {
+ b.markOverrun()
+ }
+}
+
+// encode implements encoder.encode.
+//
+// Data is automatically encoded via Payload.
+func (r *Rread) encode(b *buffer) {
+ b.Write32(uint32(len(r.Data)))
+}
+
+// Type implements message.Type.
+func (*Rread) Type() MsgType {
+ return MsgRread
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Rread) FixedSize() uint32 {
+ return 4
+}
+
+// Payload implements payloader.Payload.
+func (r *Rread) Payload() []byte {
+ return r.Data
+}
+
+// SetPayload implements payloader.SetPayload.
+func (r *Rread) SetPayload(p []byte) {
+ r.Data = p
+}
+
+// String implements fmt.Stringer.
+func (r *Rread) String() string {
+ return fmt.Sprintf("Rread{len(Data): %d}", len(r.Data))
+}
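+
+// An illustrative round trip: encoding Rread{Data: []byte("hello")} writes
+// only the 4-byte count into the buffer, the 5 data bytes travel separately
+// via Payload(), and the receiving side attaches them with SetPayload
+// before decode verifies the count.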
+
+// Twrite is a write request.
+type Twrite struct {
+ // FID is the FID to write.
+ FID FID
+
+ // Offset indicates the file offset.
+ Offset uint64
+
+ // Data is the data to be written.
+ Data []byte
+}
+
+// decode implements encoder.decode.
+func (t *Twrite) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Offset = b.Read64()
+ count := b.Read32()
+ if count != uint32(len(t.Data)) {
+ b.markOverrun()
+ }
+}
+
+// encode implements encoder.encode.
+//
+// This uses the buffer payload to avoid a copy.
+func (t *Twrite) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Offset)
+ b.Write32(uint32(len(t.Data)))
+}
+
+// Type implements message.Type.
+func (*Twrite) Type() MsgType {
+ return MsgTwrite
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Twrite) FixedSize() uint32 {
+ return 16
+}
+
+// Payload implements payloader.Payload.
+func (t *Twrite) Payload() []byte {
+ return t.Data
+}
+
+// SetPayload implements payloader.SetPayload.
+func (t *Twrite) SetPayload(p []byte) {
+ t.Data = p
+}
+
+// String implements fmt.Stringer.
+func (t *Twrite) String() string {
+ return fmt.Sprintf("Twrite{FID: %v, Offset %d, len(Data): %d}", t.FID, t.Offset, len(t.Data))
+}
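+
+// As with Rread, the data never passes through the buffer: encode emits
+// only the fixed FID, offset, and count portion (hence the FixedSize of 16
+// bytes: 4 + 8 + 4), and the transport appends Payload() on the wire.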
+
+// Rwrite is the response for a Twrite.
+type Rwrite struct {
+ // Count indicates the number of bytes successfully written.
+ Count uint32
+}
+
+// decode implements encoder.decode.
+func (r *Rwrite) decode(b *buffer) {
+ r.Count = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (r *Rwrite) encode(b *buffer) {
+ b.Write32(r.Count)
+}
+
+// Type implements message.Type.
+func (*Rwrite) Type() MsgType {
+ return MsgRwrite
+}
+
+// String implements fmt.Stringer.
+func (r *Rwrite) String() string {
+ return fmt.Sprintf("Rwrite{Count: %d}", r.Count)
+}
+
+// Tmknod is a mknod request.
+type Tmknod struct {
+ // Directory is the parent directory.
+ Directory FID
+
+ // Name is the device name.
+ Name string
+
+ // Mode is the device mode and permissions.
+ Mode FileMode
+
+ // Major is the device major number.
+ Major uint32
+
+ // Minor is the device minor number.
+ Minor uint32
+
+ // GID is the device GID.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tmknod) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Mode = b.ReadFileMode()
+ t.Major = b.Read32()
+ t.Minor = b.Read32()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tmknod) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WriteFileMode(t.Mode)
+ b.Write32(t.Major)
+ b.Write32(t.Minor)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tmknod) Type() MsgType {
+ return MsgTmknod
+}
+
+// String implements fmt.Stringer.
+func (t *Tmknod) String() string {
+ return fmt.Sprintf("Tmknod{DirectoryFID: %d, Name: %s, Mode: 0o%o, Major: %d, Minor: %d, GID: %d}", t.Directory, t.Name, t.Mode, t.Major, t.Minor, t.GID)
+}
+
+// Rmknod is a mknod response.
+type Rmknod struct {
+ // QID is the resulting QID.
+ QID QID
+}
+
+// decode implements encoder.decode.
+func (r *Rmknod) decode(b *buffer) {
+ r.QID.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rmknod) encode(b *buffer) {
+ r.QID.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rmknod) Type() MsgType {
+ return MsgRmknod
+}
+
+// String implements fmt.Stringer.
+func (r *Rmknod) String() string {
+ return fmt.Sprintf("Rmknod{QID: %s}", r.QID)
+}
+
+// Tmkdir is a mkdir request.
+type Tmkdir struct {
+ // Directory is the parent directory.
+ Directory FID
+
+ // Name is the new directory name.
+ Name string
+
+ // Permissions is the set of permission bits.
+ Permissions FileMode
+
+ // GID is the owning group.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tmkdir) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Permissions = b.ReadPermissions()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tmkdir) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WritePermissions(t.Permissions)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tmkdir) Type() MsgType {
+ return MsgTmkdir
+}
+
+// String implements fmt.Stringer.
+func (t *Tmkdir) String() string {
+ return fmt.Sprintf("Tmkdir{DirectoryFID: %d, Name: %s, Permissions: 0o%o, GID: %d}", t.Directory, t.Name, t.Permissions, t.GID)
+}
+
+// Rmkdir is a mkdir response.
+type Rmkdir struct {
+ // QID is the resulting QID.
+ QID QID
+}
+
+// decode implements encoder.decode.
+func (r *Rmkdir) decode(b *buffer) {
+ r.QID.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rmkdir) encode(b *buffer) {
+ r.QID.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rmkdir) Type() MsgType {
+ return MsgRmkdir
+}
+
+// String implements fmt.Stringer.
+func (r *Rmkdir) String() string {
+ return fmt.Sprintf("Rmkdir{QID: %s}", r.QID)
+}
+
+// Tgetattr is a getattr request.
+type Tgetattr struct {
+ // FID is the FID to get attributes for.
+ FID FID
+
+ // AttrMask is the set of attributes to get.
+ AttrMask AttrMask
+}
+
+// decode implements encoder.decode.
+func (t *Tgetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.AttrMask.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tgetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.AttrMask.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tgetattr) Type() MsgType {
+ return MsgTgetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetattr) String() string {
+ return fmt.Sprintf("Tgetattr{FID: %d, AttrMask: %s}", t.FID, t.AttrMask)
+}
+
+// Rgetattr is a getattr response.
+type Rgetattr struct {
+ // Valid indicates which fields are valid.
+ Valid AttrMask
+
+ // QID is the QID for this file.
+ QID
+
+ // Attr is the set of attributes.
+ Attr Attr
+}
+
+// decode implements encoder.decode.
+func (r *Rgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.QID.decode(b)
+ r.Attr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.QID.encode(b)
+ r.Attr.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rgetattr) Type() MsgType {
+ return MsgRgetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetattr) String() string {
+ return fmt.Sprintf("Rgetattr{Valid: %v, QID: %s, Attr: %s}", r.Valid, r.QID, r.Attr)
+}
+
+// Tsetattr is a setattr request.
+type Tsetattr struct {
+ // FID is the FID to change.
+ FID FID
+
+ // Valid is the set of bits which will be used.
+ Valid SetAttrMask
+
+ // SetAttr is the set request.
+ SetAttr SetAttr
+}
+
+// decode implements encoder.decode.
+func (t *Tsetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Valid.decode(b)
+ t.SetAttr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tsetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Valid.encode(b)
+ t.SetAttr.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tsetattr) Type() MsgType {
+ return MsgTsetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetattr) String() string {
+ return fmt.Sprintf("Tsetattr{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr)
+}
+
+// Rsetattr is a setattr response.
+type Rsetattr struct {
+}
+
+// decode implements encoder.decode.
+func (*Rsetattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rsetattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetattr) Type() MsgType {
+ return MsgRsetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetattr) String() string {
+ return "Rsetattr{}"
+}
+
+// Tallocate is an allocate request. This is an extension to the 9P protocol,
+// not present in the 9P2000.L standard.
+type Tallocate struct {
+ FID FID
+ Mode AllocateMode
+ Offset uint64
+ Length uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tallocate) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Mode.decode(b)
+ t.Offset = b.Read64()
+ t.Length = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tallocate) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Mode.encode(b)
+ b.Write64(t.Offset)
+ b.Write64(t.Length)
+}
+
+// Type implements message.Type.
+func (*Tallocate) Type() MsgType {
+ return MsgTallocate
+}
+
+// String implements fmt.Stringer.
+func (t *Tallocate) String() string {
+ return fmt.Sprintf("Tallocate{FID: %d, Offset: %d, Length: %d}", t.FID, t.Offset, t.Length)
+}
+
+// Rallocate is an allocate response.
+type Rallocate struct {
+}
+
+// decode implements encoder.decode.
+func (*Rallocate) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rallocate) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rallocate) Type() MsgType {
+ return MsgRallocate
+}
+
+// String implements fmt.Stringer.
+func (r *Rallocate) String() string {
+ return "Rallocate{}"
+}
+
+// Tlistxattr is a listxattr request.
+type Tlistxattr struct {
+ // FID refers to the file on which to list xattrs.
+ FID FID
+
+ // Size is the buffer size for the xattr list.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tlistxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tlistxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tlistxattr) Type() MsgType {
+ return MsgTlistxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tlistxattr) String() string {
+ return fmt.Sprintf("Tlistxattr{FID: %d, Size: %d}", t.FID, t.Size)
+}
+
+// Rlistxattr is a listxattr response.
+type Rlistxattr struct {
+ // Xattrs is a list of extended attribute names.
+ Xattrs []string
+}
+
+// decode implements encoder.decode.
+func (r *Rlistxattr) decode(b *buffer) {
+ n := b.Read16()
+ r.Xattrs = r.Xattrs[:0]
+ for i := 0; i < int(n); i++ {
+ r.Xattrs = append(r.Xattrs, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rlistxattr) encode(b *buffer) {
+ b.Write16(uint16(len(r.Xattrs)))
+ for _, x := range r.Xattrs {
+ b.WriteString(x)
+ }
+}
+
+// Type implements message.Type.
+func (*Rlistxattr) Type() MsgType {
+ return MsgRlistxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rlistxattr) String() string {
+ return fmt.Sprintf("Rlistxattr{Xattrs: %v}", r.Xattrs)
+}
+
+// Txattrwalk walks extended attributes.
+type Txattrwalk struct {
+ // FID is the FID to check for attributes.
+ FID FID
+
+ // NewFID is the new FID associated with the attributes.
+ NewFID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Txattrwalk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Txattrwalk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Txattrwalk) Type() MsgType {
+ return MsgTxattrwalk
+}
+
+// String implements fmt.Stringer.
+func (t *Txattrwalk) String() string {
+ return fmt.Sprintf("Txattrwalk{FID: %d, NewFID: %d, Name: %s}", t.FID, t.NewFID, t.Name)
+}
+
+// Rxattrwalk is a xattrwalk response.
+type Rxattrwalk struct {
+ // Size is the size of the extended attribute.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (r *Rxattrwalk) decode(b *buffer) {
+ r.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (r *Rxattrwalk) encode(b *buffer) {
+ b.Write64(r.Size)
+}
+
+// Type implements message.Type.
+func (*Rxattrwalk) Type() MsgType {
+ return MsgRxattrwalk
+}
+
+// String implements fmt.Stringer.
+func (r *Rxattrwalk) String() string {
+ return fmt.Sprintf("Rxattrwalk{Size: %d}", r.Size)
+}
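+
+// Per the 9P2000.L xattr scheme, the walked-to NewFID names the attribute
+// itself rather than a file: its value (Size bytes of it) is subsequently
+// read with ordinary Tread requests against NewFID.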
+
+// Txattrcreate prepares to set extended attributes.
+type Txattrcreate struct {
+ // FID is an input/output parameter. It identifies the file on which
+ // extended attributes will be set; after a successful Rxattrcreate it
+ // is used to write the extended attribute value.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // AttrSize is the size of the attribute value. When the FID is
+ // clunked, it must match the number of bytes written to the FID.
+ AttrSize uint64
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Txattrcreate) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.AttrSize = b.Read64()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Txattrcreate) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.AttrSize)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Txattrcreate) Type() MsgType {
+ return MsgTxattrcreate
+}
+
+// String implements fmt.Stringer.
+func (t *Txattrcreate) String() string {
+ return fmt.Sprintf("Txattrcreate{FID: %d, Name: %s, AttrSize: %d, Flags: %d}", t.FID, t.Name, t.AttrSize, t.Flags)
+}
+
+// Rxattrcreate is a xattrcreate response.
+type Rxattrcreate struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rxattrcreate) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rxattrcreate) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rxattrcreate) Type() MsgType {
+ return MsgRxattrcreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rxattrcreate) String() string {
+ return "Rxattrcreate{}"
+}
+
+// Tgetxattr is a getxattr request.
+type Tgetxattr struct {
+ // FID refers to the file for which to get xattrs.
+ FID FID
+
+ // Name is the xattr to get.
+ Name string
+
+ // Size is the buffer size for the xattr to get.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tgetxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tgetxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tgetxattr) Type() MsgType {
+ return MsgTgetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetxattr) String() string {
+ return fmt.Sprintf("Tgetxattr{FID: %d, Name: %s, Size: %d}", t.FID, t.Name, t.Size)
+}
+
+// Rgetxattr is a getxattr response.
+type Rgetxattr struct {
+ // Value is the extended attribute value.
+ Value string
+}
+
+// decode implements encoder.decode.
+func (r *Rgetxattr) decode(b *buffer) {
+ r.Value = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rgetxattr) encode(b *buffer) {
+ b.WriteString(r.Value)
+}
+
+// Type implements message.Type.
+func (*Rgetxattr) Type() MsgType {
+ return MsgRgetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetxattr) String() string {
+ return fmt.Sprintf("Rgetxattr{Value: %s}", r.Value)
+}
+
+// Tsetxattr sets extended attributes.
+type Tsetxattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Value is the attribute value.
+ Value string
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tsetxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Value = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tsetxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteString(t.Value)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tsetxattr) Type() MsgType {
+ return MsgTsetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetxattr) String() string {
+ return fmt.Sprintf("Tsetxattr{FID: %d, Name: %s, Value: %s, Flags: %d}", t.FID, t.Name, t.Value, t.Flags)
+}
+
+// Rsetxattr is a setxattr response.
+type Rsetxattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rsetxattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rsetxattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetxattr) Type() MsgType {
+ return MsgRsetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetxattr) String() string {
+ return "Rsetxattr{}"
+}
+
+// Tremovexattr is a removexattr request.
+type Tremovexattr struct {
+ // FID refers to the file from which to remove the xattr.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Tremovexattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tremovexattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tremovexattr) Type() MsgType {
+ return MsgTremovexattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tremovexattr) String() string {
+ return fmt.Sprintf("Tremovexattr{FID: %d, Name: %s}", t.FID, t.Name)
+}
+
+// Rremovexattr is a removexattr response.
+type Rremovexattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rremovexattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rremovexattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremovexattr) Type() MsgType {
+ return MsgRremovexattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rremovexattr) String() string {
+ return "Rremovexattr{}"
+}
+
+// Treaddir is a readdir request.
+type Treaddir struct {
+ // Directory is the directory FID to read.
+ Directory FID
+
+ // Offset is the offset to read at.
+ Offset uint64
+
+ // Count is the number of bytes to read.
+ Count uint32
+}
+
+// decode implements encoder.decode.
+func (t *Treaddir) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Offset = b.Read64()
+ t.Count = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Treaddir) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.Write64(t.Offset)
+ b.Write32(t.Count)
+}
+
+// Type implements message.Type.
+func (*Treaddir) Type() MsgType {
+ return MsgTreaddir
+}
+
+// String implements fmt.Stringer.
+func (t *Treaddir) String() string {
+ return fmt.Sprintf("Treaddir{DirectoryFID: %d, Offset: %d, Count: %d}", t.Directory, t.Offset, t.Count)
+}
+
+// Rreaddir is a readdir response.
+type Rreaddir struct {
+ // Count is the byte limit.
+ //
+ // This should always be set from the Treaddir request.
+ Count uint32
+
+ // Entries are the resulting entries.
+ //
+ // This may be constructed in decode.
+ Entries []Dirent
+
+ // payload is the encoded payload.
+ //
+ // This is constructed by encode.
+ payload []byte
+}
+
+// decode implements encoder.decode.
+func (r *Rreaddir) decode(b *buffer) {
+ r.Count = b.Read32()
+ entriesBuf := buffer{data: r.payload}
+ r.Entries = r.Entries[:0]
+ for {
+ var d Dirent
+ d.decode(&entriesBuf)
+ if entriesBuf.isOverrun() {
+ // Couldn't decode a complete entry.
+ break
+ }
+ r.Entries = append(r.Entries, d)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rreaddir) encode(b *buffer) {
+ entriesBuf := buffer{}
+ payloadSize := 0
+ for _, d := range r.Entries {
+ d.encode(&entriesBuf)
+ if len(entriesBuf.data) > int(r.Count) {
+ break
+ }
+ payloadSize = len(entriesBuf.data)
+ }
+ r.Count = uint32(payloadSize)
+ r.payload = entriesBuf.data[:payloadSize]
+ b.Write32(r.Count)
+}
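+
+// For illustration: with Count = 256 and entries that each encode to 100
+// bytes, the loop above keeps two entries (200 bytes), stops once a third
+// would exceed the limit, and rewrites Count to the 200 bytes actually
+// packed.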
+
+// Type implements message.Type.
+func (*Rreaddir) Type() MsgType {
+ return MsgRreaddir
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Rreaddir) FixedSize() uint32 {
+ return 4
+}
+
+// Payload implements payloader.Payload.
+func (r *Rreaddir) Payload() []byte {
+ return r.payload
+}
+
+// SetPayload implements payloader.SetPayload.
+func (r *Rreaddir) SetPayload(p []byte) {
+ r.payload = p
+}
+
+// String implements fmt.Stringer.
+func (r *Rreaddir) String() string {
+ return fmt.Sprintf("Rreaddir{Count: %d, Entries: %s}", r.Count, r.Entries)
+}
+
+// Tfsync is an fsync request.
+type Tfsync struct {
+ // FID is the fid to sync.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tfsync) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tfsync) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tfsync) Type() MsgType {
+ return MsgTfsync
+}
+
+// String implements fmt.Stringer.
+func (t *Tfsync) String() string {
+ return fmt.Sprintf("Tfsync{FID: %d}", t.FID)
+}
+
+// Rfsync is an fsync response.
+type Rfsync struct {
+}
+
+// decode implements encoder.decode.
+func (*Rfsync) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rfsync) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rfsync) Type() MsgType {
+ return MsgRfsync
+}
+
+// String implements fmt.Stringer.
+func (r *Rfsync) String() string {
+ return "Rfsync{}"
+}
+
+// Tstatfs is a stat request.
+type Tstatfs struct {
+ // FID is the root.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tstatfs) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tstatfs) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tstatfs) Type() MsgType {
+ return MsgTstatfs
+}
+
+// String implements fmt.Stringer.
+func (t *Tstatfs) String() string {
+ return fmt.Sprintf("Tstatfs{FID: %d}", t.FID)
+}
+
+// Rstatfs is the response for a Tstatfs.
+type Rstatfs struct {
+ // FSStat is the stat result.
+ FSStat FSStat
+}
+
+// decode implements encoder.decode.
+func (r *Rstatfs) decode(b *buffer) {
+ r.FSStat.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rstatfs) encode(b *buffer) {
+ r.FSStat.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rstatfs) Type() MsgType {
+ return MsgRstatfs
+}
+
+// String implements fmt.Stringer.
+func (r *Rstatfs) String() string {
+ return fmt.Sprintf("Rstatfs{FSStat: %v}", r.FSStat)
+}
+
+// Tflushf is a flush file request, not to be confused with Tflush.
+type Tflushf struct {
+ // FID is the FID to be flushed.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tflushf) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tflushf) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tflushf) Type() MsgType {
+ return MsgTflushf
+}
+
+// String implements fmt.Stringer.
+func (t *Tflushf) String() string {
+ return fmt.Sprintf("Tflushf{FID: %d}", t.FID)
+}
+
+// Rflushf is a flush file response.
+type Rflushf struct {
+}
+
+// decode implements encoder.decode.
+func (*Rflushf) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rflushf) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rflushf) Type() MsgType {
+ return MsgRflushf
+}
+
+// String implements fmt.Stringer.
+func (*Rflushf) String() string {
+ return "Rflushf{}"
+}
+
+// Twalkgetattr is a walk request.
+type Twalkgetattr struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // NewFID is the resulting FID.
+ NewFID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// decode implements encoder.decode.
+func (t *Twalkgetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (t *Twalkgetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Twalkgetattr) Type() MsgType {
+ return MsgTwalkgetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Twalkgetattr) String() string {
+ return fmt.Sprintf("Twalkgetattr{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names)
+}
+
+// Rwalkgetattr is a walk response.
+type Rwalkgetattr struct {
+ // Valid indicates which fields are valid in the Attr below.
+ Valid AttrMask
+
+ // Attr is the set of attributes for the last QID (the file walked to).
+ Attr Attr
+
+ // QIDs are the set of QIDs returned.
+ QIDs []QID
+}
+
+// decode implements encoder.decode.
+func (r *Rwalkgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.Attr.decode(b)
+ n := b.Read16()
+ r.QIDs = r.QIDs[:0]
+ for i := 0; i < int(n); i++ {
+ var q QID
+ q.decode(b)
+ r.QIDs = append(r.QIDs, q)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rwalkgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.Attr.encode(b)
+ b.Write16(uint16(len(r.QIDs)))
+ for _, q := range r.QIDs {
+ q.encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rwalkgetattr) Type() MsgType {
+ return MsgRwalkgetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rwalkgetattr) String() string {
+ return fmt.Sprintf("Rwalkgetattr{Valid: %s, Attr: %s, QIDs: %v}", r.Valid, r.Attr, r.QIDs)
+}
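+
+// Twalkgetattr and Rwalkgetattr are not part of base 9P2000.L: they are an
+// extension, negotiated through the version string, that fuses a walk with
+// a getattr on the final component to save a round trip.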
+
+// Tucreate is a Tlcreate message that includes a UID.
+type Tucreate struct {
+ Tlcreate
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tucreate) decode(b *buffer) {
+ t.Tlcreate.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tucreate) encode(b *buffer) {
+ t.Tlcreate.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tucreate) Type() MsgType {
+ return MsgTucreate
+}
+
+// String implements fmt.Stringer.
+func (t *Tucreate) String() string {
+ return fmt.Sprintf("Tucreate{Tlcreate: %v, UID: %d}", &t.Tlcreate, t.UID)
+}
+
+// Rucreate is a file creation response.
+type Rucreate struct {
+ Rlcreate
+}
+
+// Type implements message.Type.
+func (*Rucreate) Type() MsgType {
+ return MsgRucreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rucreate) String() string {
+ return fmt.Sprintf("Rucreate{%v}", &r.Rlcreate)
+}
+
+// Tumkdir is a Tmkdir message that includes a UID.
+type Tumkdir struct {
+ Tmkdir
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tumkdir) decode(b *buffer) {
+ t.Tmkdir.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tumkdir) encode(b *buffer) {
+ t.Tmkdir.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tumkdir) Type() MsgType {
+ return MsgTumkdir
+}
+
+// String implements fmt.Stringer.
+func (t *Tumkdir) String() string {
+ return fmt.Sprintf("Tumkdir{Tmkdir: %v, UID: %d}", &t.Tmkdir, t.UID)
+}
+
+// Rumkdir is a umkdir response.
+type Rumkdir struct {
+ Rmkdir
+}
+
+// Type implements message.Type.
+func (*Rumkdir) Type() MsgType {
+ return MsgRumkdir
+}
+
+// String implements fmt.Stringer.
+func (r *Rumkdir) String() string {
+ return fmt.Sprintf("Rumkdir{%v}", &r.Rmkdir)
+}
+
+// Tumknod is a Tmknod message that includes a UID.
+type Tumknod struct {
+ Tmknod
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tumknod) decode(b *buffer) {
+ t.Tmknod.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tumknod) encode(b *buffer) {
+ t.Tmknod.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tumknod) Type() MsgType {
+ return MsgTumknod
+}
+
+// String implements fmt.Stringer.
+func (t *Tumknod) String() string {
+ return fmt.Sprintf("Tumknod{Tmknod: %v, UID: %d}", &t.Tmknod, t.UID)
+}
+
+// Rumknod is a umknod response.
+type Rumknod struct {
+ Rmknod
+}
+
+// Type implements message.Type.
+func (*Rumknod) Type() MsgType {
+ return MsgRumknod
+}
+
+// String implements fmt.Stringer.
+func (r *Rumknod) String() string {
+ return fmt.Sprintf("Rumknod{%v}", &r.Rmknod)
+}
+
+// Tusymlink is a Tsymlink message that includes a UID.
+type Tusymlink struct {
+ Tsymlink
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tusymlink) decode(b *buffer) {
+ t.Tsymlink.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tusymlink) encode(b *buffer) {
+ t.Tsymlink.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tusymlink) Type() MsgType {
+ return MsgTusymlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tusymlink) String() string {
+ return fmt.Sprintf("Tusymlink{Tsymlink: %v, UID: %d}", &t.Tsymlink, t.UID)
+}
+
+// Rusymlink is a usymlink response.
+type Rusymlink struct {
+ Rsymlink
+}
+
+// Type implements message.Type.
+func (*Rusymlink) Type() MsgType {
+ return MsgRusymlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rusymlink) String() string {
+ return fmt.Sprintf("Rusymlink{%v}", &r.Rsymlink)
+}
+
+// Tlconnect is a connect request.
+type Tlconnect struct {
+ // FID is the FID to be connected.
+ FID FID
+
+ // Flags are the connect flags.
+ Flags ConnectFlags
+}
+
+// decode implements encoder.decode.
+func (t *Tlconnect) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Flags = b.ReadConnectFlags()
+}
+
+// encode implements encoder.encode.
+func (t *Tlconnect) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteConnectFlags(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tlconnect) Type() MsgType {
+ return MsgTlconnect
+}
+
+// String implements fmt.Stringer.
+func (t *Tlconnect) String() string {
+ return fmt.Sprintf("Tlconnect{FID: %d, Flags: %v}", t.FID, t.Flags)
+}
+
+// Rlconnect is a connect response.
+type Rlconnect struct {
+ filePayload
+}
+
+// decode implements encoder.decode.
+func (r *Rlconnect) decode(*buffer) {}
+
+// encode implements encoder.encode.
+func (r *Rlconnect) encode(*buffer) {}
+
+// Type implements message.Type.
+func (*Rlconnect) Type() MsgType {
+ return MsgRlconnect
+}
+
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+}
+
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tchannel) decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tchannel) encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
+}
+
+// String implements fmt.Stringer.
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+type Rchannel struct {
+ Offset uint64
+ Length uint64
+ filePayload
+}
+
+// decode implements encoder.decode.
+func (r *Rchannel) decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (r *Rchannel) encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
+}
+
+const maxCacheSize = 3
+
+// msgFactory is used to reduce allocations by caching messages for reuse.
+type msgFactory struct {
+ create func() message
+ cache chan message
+}
+
+// msgRegistry indexes all message factories by type.
+var msgRegistry registry
+
+type registry struct {
+ factories [math.MaxUint8]msgFactory
+
+ // largestFixedSize is computed so that given some message size M, you can
+ // compute the maximum payload size (e.g. for Twrite, Rread) as
+ // M - largestFixedSize. You could do this individually on a per-message
+ // basis, but it's easier to compute a single maximum safe payload.
+ largestFixedSize uint32
+}
+
+// get returns a new message by type.
+//
+// An error is returned in the case of an unknown message.
+//
+// This takes, and ignores, a message tag so that it may be used directly as a
+// lookupTagAndType function for recv (by design).
+func (r *registry) get(_ Tag, t MsgType) (message, error) {
+ entry := &r.factories[t]
+ if entry.create == nil {
+ return nil, &ErrInvalidMsgType{t}
+ }
+
+ select {
+ case msg := <-entry.cache:
+ return msg, nil
+ default:
+ return entry.create(), nil
+ }
+}
+
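+// put returns a message to the factory cache if there is room, clearing any
+// payload or file payload first so a recycled message cannot leak state
+// into its next use.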
+func (r *registry) put(msg message) {
+ if p, ok := msg.(payloader); ok {
+ p.SetPayload(nil)
+ }
+ if f, ok := msg.(filer); ok {
+ f.SetFilePayload(nil)
+ }
+
+ entry := &r.factories[msg.Type()]
+ select {
+ case entry.cache <- msg:
+ default:
+ }
+}
+
+// register registers the given message type.
+//
+// This panics on failure and should only be called from init.
+func (r *registry) register(t MsgType, fn func() message) {
+ if int(t) >= len(r.factories) {
+ panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories)))
+ }
+ if r.factories[t].create != nil {
+ panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn()))
+ }
+ r.factories[t] = msgFactory{
+ create: fn,
+ cache: make(chan message, maxCacheSize),
+ }
+
+ if size := calculateSize(fn()); size > r.largestFixedSize {
+ r.largestFixedSize = size
+ }
+}
+
+func calculateSize(m message) uint32 {
+ if p, ok := m.(payloader); ok {
+ return p.FixedSize()
+ }
+ var dataBuf buffer
+ m.encode(&dataBuf)
+ return uint32(len(dataBuf.data))
+}
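
Taken together, get and put form a bounded per-type free list: get prefers a cached message and falls back to the factory, while put scrubs payloads before recycling. largestFixedSize also means a connection negotiating message size M can safely carry M - largestFixedSize payload bytes. A minimal sketch of the intended lifecycle, written as a hypothetical package-internal helper (recycleExample is illustrative only, not part of this change):

```go
package p9

// recycleExample sketches the get/decode/put cycle a transport loop would
// perform; illustrative only.
func recycleExample(b *buffer) error {
	msg, err := msgRegistry.get(0, MsgRlerror) // reuses a cached *Rlerror if available
	if err != nil {
		return err // message type was never registered
	}
	msg.decode(b) // fill the message from the wire buffer
	// ... dispatch msg to a handler ...
	msgRegistry.put(msg) // scrub payloads and recycle, keeping at most maxCacheSize
	return nil
}
```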
+
+func init() {
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} })
+ msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} })
+ msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} })
+ msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} })
+ msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} })
+ msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} })
+ msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} })
+ msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} })
+ msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} })
+ msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} })
+ msgRegistry.register(MsgTrename, func() message { return &Trename{} })
+ msgRegistry.register(MsgRrename, func() message { return &Rrename{} })
+ msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} })
+ msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} })
+ msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} })
+ msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
+ msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
+ msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTlistxattr, func() message { return &Tlistxattr{} })
+ msgRegistry.register(MsgRlistxattr, func() message { return &Rlistxattr{} })
+ msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
+ msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
+ msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
+ msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTgetxattr, func() message { return &Tgetxattr{} })
+ msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} })
+ msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} })
+ msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} })
+ msgRegistry.register(MsgTremovexattr, func() message { return &Tremovexattr{} })
+ msgRegistry.register(MsgRremovexattr, func() message { return &Rremovexattr{} })
+ msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
+ msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
+ msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
+ msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} })
+ msgRegistry.register(MsgTlink, func() message { return &Tlink{} })
+ msgRegistry.register(MsgRlink, func() message { return &Rlink{} })
+ msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} })
+ msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} })
+ msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} })
+ msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} })
+ msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} })
+ msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} })
+ msgRegistry.register(MsgTversion, func() message { return &Tversion{} })
+ msgRegistry.register(MsgRversion, func() message { return &Rversion{} })
+ msgRegistry.register(MsgTauth, func() message { return &Tauth{} })
+ msgRegistry.register(MsgRauth, func() message { return &Rauth{} })
+ msgRegistry.register(MsgTattach, func() message { return &Tattach{} })
+ msgRegistry.register(MsgRattach, func() message { return &Rattach{} })
+ msgRegistry.register(MsgTflush, func() message { return &Tflush{} })
+ msgRegistry.register(MsgRflush, func() message { return &Rflush{} })
+ msgRegistry.register(MsgTwalk, func() message { return &Twalk{} })
+ msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} })
+ msgRegistry.register(MsgTread, func() message { return &Tread{} })
+ msgRegistry.register(MsgRread, func() message { return &Rread{} })
+ msgRegistry.register(MsgTwrite, func() message { return &Twrite{} })
+ msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} })
+ msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} })
+ msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} })
+ msgRegistry.register(MsgTremove, func() message { return &Tremove{} })
+ msgRegistry.register(MsgRremove, func() message { return &Rremove{} })
+ msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} })
+ msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} })
+ msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
+ msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
+ msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} })
+ msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} })
+ msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} })
+ msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} })
+ msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} })
+ msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} })
+ msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} })
+ msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} })
+ msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} })
+ msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
+ msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
+ msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
+}
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
new file mode 100644
index 000000000..7facc9f5e
--- /dev/null
+++ b/pkg/p9/messages_test.go
@@ -0,0 +1,483 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestEncodeDecode(t *testing.T) {
+ objs := []encoder{
+ &QID{
+ Type: 1,
+ Version: 2,
+ Path: 3,
+ },
+ &FSStat{
+ Type: 1,
+ BlockSize: 2,
+ Blocks: 3,
+ BlocksFree: 4,
+ BlocksAvailable: 5,
+ Files: 6,
+ FilesFree: 7,
+ FSID: 8,
+ NameLength: 9,
+ },
+ &AttrMask{
+ Mode: true,
+ NLink: true,
+ UID: true,
+ GID: true,
+ RDev: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ INo: true,
+ Size: true,
+ Blocks: true,
+ BTime: true,
+ Gen: true,
+ DataVersion: true,
+ },
+ &Attr{
+ Mode: Exec,
+ UID: 2,
+ GID: 3,
+ NLink: 4,
+ RDev: 5,
+ Size: 6,
+ BlockSize: 7,
+ Blocks: 8,
+ ATimeSeconds: 9,
+ ATimeNanoSeconds: 10,
+ MTimeSeconds: 11,
+ MTimeNanoSeconds: 12,
+ CTimeSeconds: 13,
+ CTimeNanoSeconds: 14,
+ BTimeSeconds: 15,
+ BTimeNanoSeconds: 16,
+ Gen: 17,
+ DataVersion: 18,
+ },
+ &SetAttrMask{
+ Permissions: true,
+ UID: true,
+ GID: true,
+ Size: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ ATimeNotSystemTime: true,
+ MTimeNotSystemTime: true,
+ },
+ &SetAttr{
+ Permissions: 1,
+ UID: 2,
+ GID: 3,
+ Size: 4,
+ ATimeSeconds: 5,
+ ATimeNanoSeconds: 6,
+ MTimeSeconds: 7,
+ MTimeNanoSeconds: 8,
+ },
+ &Dirent{
+ QID: QID{Type: 1},
+ Offset: 2,
+ Type: 3,
+ Name: "a",
+ },
+ &Rlerror{
+ Error: 1,
+ },
+ &Tstatfs{
+ FID: 1,
+ },
+ &Rstatfs{
+ FSStat: FSStat{Type: 1},
+ },
+ &Tlopen{
+ FID: 1,
+ Flags: WriteOnly,
+ },
+ &Rlopen{
+ QID: QID{Type: 1},
+ IoUnit: 2,
+ },
+ &Tlconnect{
+ FID: 1,
+ },
+ &Rlconnect{},
+ &Tlcreate{
+ FID: 1,
+ Name: "a",
+ OpenFlags: 2,
+ Permissions: 3,
+ GID: 4,
+ },
+ &Rlcreate{
+ Rlopen{QID: QID{Type: 1}},
+ },
+ &Tsymlink{
+ Directory: 1,
+ Name: "a",
+ Target: "b",
+ GID: 2,
+ },
+ &Rsymlink{
+ QID: QID{Type: 1},
+ },
+ &Tmknod{
+ Directory: 1,
+ Name: "a",
+ Mode: 2,
+ Major: 3,
+ Minor: 4,
+ GID: 5,
+ },
+ &Rmknod{
+ QID: QID{Type: 1},
+ },
+ &Trename{
+ FID: 1,
+ Directory: 2,
+ Name: "a",
+ },
+ &Rrename{},
+ &Treadlink{
+ FID: 1,
+ },
+ &Rreadlink{
+ Target: "a",
+ },
+ &Tgetattr{
+ FID: 1,
+ AttrMask: AttrMask{Mode: true},
+ },
+ &Rgetattr{
+ Valid: AttrMask{Mode: true},
+ QID: QID{Type: 1},
+ Attr: Attr{Mode: Write},
+ },
+ &Tsetattr{
+ FID: 1,
+ Valid: SetAttrMask{Permissions: true},
+ SetAttr: SetAttr{Permissions: Write},
+ },
+ &Rsetattr{},
+ &Txattrwalk{
+ FID: 1,
+ NewFID: 2,
+ Name: "a",
+ },
+ &Rxattrwalk{
+ Size: 1,
+ },
+ &Txattrcreate{
+ FID: 1,
+ Name: "a",
+ AttrSize: 2,
+ Flags: 3,
+ },
+ &Rxattrcreate{},
+ &Tgetxattr{
+ FID: 1,
+ Name: "abc",
+ Size: 2,
+ },
+ &Rgetxattr{
+ Value: "xyz",
+ },
+ &Tsetxattr{
+ FID: 1,
+ Name: "abc",
+ Value: "xyz",
+ Flags: 2,
+ },
+ &Rsetxattr{},
+ &Treaddir{
+ Directory: 1,
+ Offset: 2,
+ Count: 3,
+ },
+ &Rreaddir{
+ // Count must be sufficient to encode a dirent.
+ Count: 0x1a,
+ Entries: []Dirent{{QID: QID{Type: 2}}},
+ },
+ &Tfsync{
+ FID: 1,
+ },
+ &Rfsync{},
+ &Tlink{
+ Directory: 1,
+ Target: 2,
+ Name: "a",
+ },
+ &Rlink{},
+ &Tmkdir{
+ Directory: 1,
+ Name: "a",
+ Permissions: 2,
+ GID: 3,
+ },
+ &Rmkdir{
+ QID: QID{Type: 1},
+ },
+ &Trenameat{
+ OldDirectory: 1,
+ OldName: "a",
+ NewDirectory: 2,
+ NewName: "b",
+ },
+ &Rrenameat{},
+ &Tunlinkat{
+ Directory: 1,
+ Name: "a",
+ Flags: 2,
+ },
+ &Runlinkat{},
+ &Tversion{
+ MSize: 1,
+ Version: "a",
+ },
+ &Rversion{
+ MSize: 1,
+ Version: "a",
+ },
+ &Tauth{
+ AuthenticationFID: 1,
+ UserName: "a",
+ AttachName: "b",
+ UID: 2,
+ },
+ &Rauth{
+ QID: QID{Type: 1},
+ },
+ &Tattach{
+ FID: 1,
+ Auth: Tauth{AuthenticationFID: 2},
+ },
+ &Rattach{
+ QID: QID{Type: 1},
+ },
+ &Tflush{
+ OldTag: 1,
+ },
+ &Rflush{},
+ &Twalk{
+ FID: 1,
+ NewFID: 2,
+ Names: []string{"a"},
+ },
+ &Rwalk{
+ QIDs: []QID{{Type: 1}},
+ },
+ &Tread{
+ FID: 1,
+ Offset: 2,
+ Count: 3,
+ },
+ &Rread{
+ Data: []byte{'a'},
+ },
+ &Twrite{
+ FID: 1,
+ Offset: 2,
+ Data: []byte{'a'},
+ },
+ &Rwrite{
+ Count: 1,
+ },
+ &Tclunk{
+ FID: 1,
+ },
+ &Rclunk{},
+ &Tremove{
+ FID: 1,
+ },
+ &Rremove{},
+ &Tflushf{
+ FID: 1,
+ },
+ &Rflushf{},
+ &Twalkgetattr{
+ FID: 1,
+ NewFID: 2,
+ Names: []string{"a"},
+ },
+ &Rwalkgetattr{
+ QIDs: []QID{{Type: 1}},
+ Valid: AttrMask{Mode: true},
+ Attr: Attr{Mode: Write},
+ },
+ &Tucreate{
+ Tlcreate: Tlcreate{
+ FID: 1,
+ Name: "a",
+ OpenFlags: 2,
+ Permissions: 3,
+ GID: 4,
+ },
+ UID: 5,
+ },
+ &Rucreate{
+ Rlcreate{Rlopen{QID: QID{Type: 1}}},
+ },
+ &Tumkdir{
+ Tmkdir: Tmkdir{
+ Directory: 1,
+ Name: "a",
+ Permissions: 2,
+ GID: 3,
+ },
+ UID: 4,
+ },
+ &Rumkdir{
+ Rmkdir{QID: QID{Type: 1}},
+ },
+ &Tusymlink{
+ Tsymlink: Tsymlink{
+ Directory: 1,
+ Name: "a",
+ Target: "b",
+ GID: 2,
+ },
+ UID: 3,
+ },
+ &Rusymlink{
+ Rsymlink{QID: QID{Type: 1}},
+ },
+ &Tumknod{
+ Tmknod: Tmknod{
+ Directory: 1,
+ Name: "a",
+ Mode: 2,
+ Major: 3,
+ Minor: 4,
+ GID: 5,
+ },
+ UID: 6,
+ },
+ &Rumknod{
+ Rmknod{QID: QID{Type: 1}},
+ },
+ }
+
+ for _, enc := range objs {
+ // Encode the original.
+ data := make([]byte, initialBufferLength)
+ buf := buffer{data: data[:0]}
+ enc.encode(&buf)
+
+ // Create a new object, same as the first.
+ enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder)
+ buf2 := buffer{data: buf.data}
+
+ // Payloads are not written by encode, so copy them over directly.
+ if pl, ok := enc.(payloader); ok {
+ enc2.(payloader).SetPayload(pl.Payload())
+ }
+
+ // And any file payloads (directly).
+ if fl, ok := enc.(filer); ok {
+ enc2.(filer).SetFilePayload(fl.FilePayload())
+ }
+
+ // Make sure it was okay.
+ enc2.decode(&buf2)
+ if buf2.isOverrun() {
+ t.Errorf("object %#v->%#v got overrun on decode", enc, enc2)
+ continue
+ }
+
+ // Check that they are equal.
+ if !reflect.DeepEqual(enc, enc2) {
+ t.Errorf("object %#v and %#v differ", enc, enc2)
+ continue
+ }
+ }
+}
+
+func TestMessageStrings(t *testing.T) {
+ for typ := range msgRegistry.factories {
+ entry := &msgRegistry.factories[typ]
+ if entry.create != nil {
+ name := fmt.Sprintf("%+v", typ)
+ t.Run(name, func(t *testing.T) {
+ defer func() { // Ensure no panic.
+ if r := recover(); r != nil {
+ t.Errorf("printing %s failed: %v", name, r)
+ }
+ }()
+ m := entry.create()
+ _ = fmt.Sprintf("%v", m)
+ err := ErrInvalidMsgType{MsgType(typ)}
+ _ = err.Error()
+ })
+ }
+ }
+}
+
+func TestRegisterDuplicate(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ // We expect a panic.
+ t.FailNow()
+ }
+ }()
+
+ // Register a duplicate.
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+}
+
+func TestMsgCache(t *testing.T) {
+ // Cache starts empty.
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Message can be created with an empty cache.
+ msg, err := msgRegistry.get(0, MsgRlerror)
+ if err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that message is added to the cache when returned.
+ msgRegistry.put(msg)
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that returned message is reused.
+ if got, err := msgRegistry.get(0, MsgRlerror); err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ } else if msg != got {
+ t.Errorf("Message not reused, got: %d, want: %d", got, msg)
+ }
+
+ // Check that cache doesn't grow beyond max size.
+ for i := 0; i < maxCacheSize+1; i++ {
+ msgRegistry.put(&Rlerror{})
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
new file mode 100644
index 000000000..122c457d2
--- /dev/null
+++ b/pkg/p9/p9.go
@@ -0,0 +1,1171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package p9 is a 9P2000.L implementation.
+package p9
+
+import (
+ "fmt"
+ "math"
+ "os"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+// OpenFlags is the mode passed to Open and Create operations.
+//
+// These correspond to bits sent over the wire.
+type OpenFlags uint32
+
+const (
+ // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode.
+ ReadOnly OpenFlags = 0
+
+ // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode.
+ WriteOnly OpenFlags = 1
+
+ // ReadWrite is a Tlopen flag indicating read-write mode.
+ ReadWrite OpenFlags = 2
+
+ // OpenFlagsModeMask is a mask of valid OpenFlags mode bits.
+ OpenFlagsModeMask OpenFlags = 3
+
+ // OpenTruncate is a Tlopen flag indicating that the opened file should be
+ // truncated.
+ OpenTruncate OpenFlags = 01000
+)
+
+// ConnectFlags is the mode passed to Connect operations.
+//
+// These correspond to bits sent over the wire.
+type ConnectFlags uint32
+
+const (
+ // StreamSocket is a Tlconnect flag indicating SOCK_STREAM mode.
+ StreamSocket ConnectFlags = 0
+
+ // DgramSocket is a Tlconnect flag indicating SOCK_DGRAM mode.
+ DgramSocket ConnectFlags = 1
+
+ // SeqpacketSocket is a Tlconnect flag indicating SOCK_SEQPACKET mode.
+ SeqpacketSocket ConnectFlags = 2
+
+ // AnonymousSocket is a Tlconnect flag indicating that the mode does not
+ // matter and that the requester will accept any socket type.
+ AnonymousSocket ConnectFlags = 3
+)
+
+// OSFlags converts a p9.OpenFlags to an int compatible with open(2).
+func (o OpenFlags) OSFlags() int {
+ // "flags contains Linux open(2) flags bits" - 9P2000.L
+ return int(o)
+}
+
+// String implements fmt.Stringer.
+func (o OpenFlags) String() string {
+ var buf strings.Builder
+ switch mode := o & OpenFlagsModeMask; mode {
+ case ReadOnly:
+ buf.WriteString("ReadOnly")
+ case WriteOnly:
+ buf.WriteString("WriteOnly")
+ case ReadWrite:
+ buf.WriteString("ReadWrite")
+ default:
+ fmt.Fprintf(&buf, "%#o", mode)
+ }
+ otherFlags := o &^ OpenFlagsModeMask
+ if otherFlags&OpenTruncate != 0 {
+ buf.WriteString("|OpenTruncate")
+ otherFlags &^= OpenTruncate
+ }
+ if otherFlags != 0 {
+ fmt.Fprintf(&buf, "|%#o", otherFlags)
+ }
+ return buf.String()
+}
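
Because the access mode occupies only the low two bits, callers mask with OpenFlagsModeMask before comparing, and OSFlags passes the word through unchanged as Linux open(2) bits. A hypothetical standalone sketch (not part of this change; assumes a Linux host, where OpenTruncate coincides with O_TRUNC):

```go
package main

import (
	"fmt"
	"os"

	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	flags := p9.ReadWrite | p9.OpenTruncate
	fmt.Println(flags & p9.OpenFlagsModeMask)    // ReadWrite
	fmt.Println(flags.OSFlags()&os.O_TRUNC != 0) // true on Linux
}
```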
+
+// Tag is a message tag.
+type Tag uint16
+
+// FID is a file identifier.
+type FID uint64
+
+// FileMode are flags corresponding to file modes.
+//
+// These correspond to bits sent over the wire.
+// These also correspond to mode_t bits.
+type FileMode uint32
+
+const (
+ // FileModeMask is a mask of all the file mode bits of FileMode.
+ FileModeMask FileMode = 0170000
+
+ // ModeSocket is an (unused) mode bit for a socket.
+ ModeSocket FileMode = 0140000
+
+ // ModeSymlink is a mode bit for a symlink.
+ ModeSymlink FileMode = 0120000
+
+ // ModeRegular is a mode bit for regular files.
+ ModeRegular FileMode = 0100000
+
+ // ModeBlockDevice is a mode bit for block devices.
+ ModeBlockDevice FileMode = 060000
+
+ // ModeDirectory is a mode bit for directories.
+ ModeDirectory FileMode = 040000
+
+ // ModeCharacterDevice is a mode bit for a character device.
+ ModeCharacterDevice FileMode = 020000
+
+ // ModeNamedPipe is a mode bit for a named pipe.
+ ModeNamedPipe FileMode = 010000
+
+ // Read is a mode bit indicating read permission.
+ Read FileMode = 04
+
+ // Write is a mode bit indicating write permission.
+ Write FileMode = 02
+
+ // Exec is a mode bit indicating exec permission.
+ Exec FileMode = 01
+
+ // AllPermissions is a mask with rwx bits set for user, group and others.
+ AllPermissions FileMode = 0777
+
+ // Sticky is a mode bit indicating sticky directories.
+ Sticky FileMode = 01000
+
+ // permissionsMask is the mask to apply to FileModes for permissions. It
+ // includes rwx bits for user, group and others, and sticky bit.
+ permissionsMask FileMode = 01777
+)
+
+// QIDType is the most significant byte of the FileMode word, to be used as the
+// Type field of p9.QID.
+func (m FileMode) QIDType() QIDType {
+ switch {
+ case m.IsDir():
+ return TypeDir
+ case m.IsSocket(), m.IsNamedPipe(), m.IsCharacterDevice():
+ // Best approximation.
+ return TypeAppendOnly
+ case m.IsSymlink():
+ return TypeSymlink
+ default:
+ return TypeRegular
+ }
+}
+
+// FileType returns the file mode without the permission bits.
+func (m FileMode) FileType() FileMode {
+ return m & FileModeMask
+}
+
+// Permissions returns just the permission bits of the mode.
+func (m FileMode) Permissions() FileMode {
+ return m & permissionsMask
+}
+
+// Writable returns the mode with write bits added.
+func (m FileMode) Writable() FileMode {
+ return m | 0222
+}
+
+// IsReadable returns true if m represents a file that can be read.
+func (m FileMode) IsReadable() bool {
+ return m&0444 != 0
+}
+
+// IsWritable returns true if m represents a file that can be written to.
+func (m FileMode) IsWritable() bool {
+ return m&0222 != 0
+}
+
+// IsExecutable returns true if m represents a file that can be executed.
+func (m FileMode) IsExecutable() bool {
+ return m&0111 != 0
+}
+
+// IsRegular returns true if m is a regular file.
+func (m FileMode) IsRegular() bool {
+ return m&FileModeMask == ModeRegular
+}
+
+// IsDir returns true if m represents a directory.
+func (m FileMode) IsDir() bool {
+ return m&FileModeMask == ModeDirectory
+}
+
+// IsNamedPipe returns true if m represents a named pipe.
+func (m FileMode) IsNamedPipe() bool {
+ return m&FileModeMask == ModeNamedPipe
+}
+
+// IsCharacterDevice returns true if m represents a character device.
+func (m FileMode) IsCharacterDevice() bool {
+ return m&FileModeMask == ModeCharacterDevice
+}
+
+// IsBlockDevice returns true if m represents a block device.
+func (m FileMode) IsBlockDevice() bool {
+ return m&FileModeMask == ModeBlockDevice
+}
+
+// IsSocket returns true if m represents a socket.
+func (m FileMode) IsSocket() bool {
+ return m&FileModeMask == ModeSocket
+}
+
+// IsSymlink returns true if m represents a symlink.
+func (m FileMode) IsSymlink() bool {
+ return m&FileModeMask == ModeSymlink
+}
+
+// ModeFromOS returns a FileMode from an os.FileMode.
+func ModeFromOS(mode os.FileMode) FileMode {
+ m := FileMode(mode.Perm())
+ switch {
+ case mode.IsDir():
+ m |= ModeDirectory
+ case mode&os.ModeSymlink != 0:
+ m |= ModeSymlink
+ case mode&os.ModeSocket != 0:
+ m |= ModeSocket
+ case mode&os.ModeNamedPipe != 0:
+ m |= ModeNamedPipe
+ case mode&os.ModeCharDevice != 0:
+ m |= ModeCharacterDevice
+ case mode&os.ModeDevice != 0:
+ m |= ModeBlockDevice
+ default:
+ m |= ModeRegular
+ }
+ return m
+}
+
+// OSMode converts a p9.FileMode to an os.FileMode.
+func (m FileMode) OSMode() os.FileMode {
+ var osMode os.FileMode
+ osMode |= os.FileMode(m.Permissions())
+ switch {
+ case m.IsDir():
+ osMode |= os.ModeDir
+ case m.IsSymlink():
+ osMode |= os.ModeSymlink
+ case m.IsSocket():
+ osMode |= os.ModeSocket
+ case m.IsNamedPipe():
+ osMode |= os.ModeNamedPipe
+ case m.IsCharacterDevice():
+ osMode |= os.ModeCharDevice | os.ModeDevice
+ case m.IsBlockDevice():
+ osMode |= os.ModeDevice
+ }
+ return osMode
+}
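
These helpers split a FileMode into orthogonal halves: FileType masks with FileModeMask, Permissions masks with permissionsMask, and ModeFromOS/OSMode bridge to the standard library. A hypothetical sketch (not part of this change):

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	m := p9.ModeRegular | 0644
	fmt.Println(m.IsRegular())           // true
	fmt.Printf("%#o\n", m.FileType())    // 0100000
	fmt.Printf("%#o\n", m.Permissions()) // 0644
	fmt.Println(m.OSMode())              // -rw-r--r--
}
```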
+
+// UID represents a user ID.
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+const (
+ // NoTag is a sentinel used to indicate no valid tag.
+ NoTag Tag = math.MaxUint16
+
+ // NoFID is a sentinel used to indicate no valid FID.
+ NoFID FID = math.MaxUint32
+
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MsgType is a type identifier.
+type MsgType uint8
+
+// MsgType declarations.
+const (
+ MsgTlerror MsgType = 6
+ MsgRlerror = 7
+ MsgTstatfs = 8
+ MsgRstatfs = 9
+ MsgTlopen = 12
+ MsgRlopen = 13
+ MsgTlcreate = 14
+ MsgRlcreate = 15
+ MsgTsymlink = 16
+ MsgRsymlink = 17
+ MsgTmknod = 18
+ MsgRmknod = 19
+ MsgTrename = 20
+ MsgRrename = 21
+ MsgTreadlink = 22
+ MsgRreadlink = 23
+ MsgTgetattr = 24
+ MsgRgetattr = 25
+ MsgTsetattr = 26
+ MsgRsetattr = 27
+ MsgTlistxattr = 28
+ MsgRlistxattr = 29
+ MsgTxattrwalk = 30
+ MsgRxattrwalk = 31
+ MsgTxattrcreate = 32
+ MsgRxattrcreate = 33
+ MsgTgetxattr = 34
+ MsgRgetxattr = 35
+ MsgTsetxattr = 36
+ MsgRsetxattr = 37
+ MsgTremovexattr = 38
+ MsgRremovexattr = 39
+ MsgTreaddir = 40
+ MsgRreaddir = 41
+ MsgTfsync = 50
+ MsgRfsync = 51
+ MsgTlink = 70
+ MsgRlink = 71
+ MsgTmkdir = 72
+ MsgRmkdir = 73
+ MsgTrenameat = 74
+ MsgRrenameat = 75
+ MsgTunlinkat = 76
+ MsgRunlinkat = 77
+ MsgTversion = 100
+ MsgRversion = 101
+ MsgTauth = 102
+ MsgRauth = 103
+ MsgTattach = 104
+ MsgRattach = 105
+ MsgTflush = 108
+ MsgRflush = 109
+ MsgTwalk = 110
+ MsgRwalk = 111
+ MsgTread = 116
+ MsgRread = 117
+ MsgTwrite = 118
+ MsgRwrite = 119
+ MsgTclunk = 120
+ MsgRclunk = 121
+ MsgTremove = 122
+ MsgRremove = 123
+ MsgTflushf = 124
+ MsgRflushf = 125
+ MsgTwalkgetattr = 126
+ MsgRwalkgetattr = 127
+ MsgTucreate = 128
+ MsgRucreate = 129
+ MsgTumkdir = 130
+ MsgRumkdir = 131
+ MsgTumknod = 132
+ MsgRumknod = 133
+ MsgTusymlink = 134
+ MsgRusymlink = 135
+ MsgTlconnect = 136
+ MsgRlconnect = 137
+ MsgTallocate = 138
+ MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
+)
+
+// QIDType represents the file type for QIDs.
+//
+// QIDType corresponds to the high 8 bits of a Plan 9 file mode.
+type QIDType uint8
+
+const (
+ // TypeDir represents a directory type.
+ TypeDir QIDType = 0x80
+
+ // TypeAppendOnly represents an append only file.
+ TypeAppendOnly QIDType = 0x40
+
+ // TypeExclusive represents an exclusive-use file.
+ TypeExclusive QIDType = 0x20
+
+ // TypeMount represents a mounted channel.
+ TypeMount QIDType = 0x10
+
+ // TypeAuth represents an authentication file.
+ TypeAuth QIDType = 0x08
+
+ // TypeTemporary represents a temporary file.
+ TypeTemporary QIDType = 0x04
+
+ // TypeSymlink represents a symlink.
+ TypeSymlink QIDType = 0x02
+
+ // TypeLink represents a hard link.
+ TypeLink QIDType = 0x01
+
+ // TypeRegular represents a regular file.
+ TypeRegular QIDType = 0x00
+)
+
+// QID is a unique file identifier.
+//
+// This may be embedded in other requests and responses.
+type QID struct {
+ // Type is the highest order byte of the file mode.
+ Type QIDType
+
+ // Version is an arbitrary server version number.
+ Version uint32
+
+ // Path is a unique server identifier for this path (e.g. inode).
+ Path uint64
+}
+
+// String implements fmt.Stringer.
+func (q QID) String() string {
+ return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path)
+}
+
+// decode implements encoder.decode.
+func (q *QID) decode(b *buffer) {
+ q.Type = b.ReadQIDType()
+ q.Version = b.Read32()
+ q.Path = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (q *QID) encode(b *buffer) {
+ b.WriteQIDType(q.Type)
+ b.Write32(q.Version)
+ b.Write64(q.Path)
+}
+
+// QIDGenerator is a simple generator for QIDs that atomically increments Path
+// values.
+type QIDGenerator struct {
+ // uids is an ever increasing value that can be atomically incremented
+ // to provide unique Path values for QIDs.
+ uids uint64
+}
+
+// Get returns a new 9P unique ID with a unique Path given a QID type.
+//
+// While the 9P spec allows Version to be incremented every time the file is
+// modified, we currently do not use the Version member for anything. Hence,
+// it is set to 0.
+func (q *QIDGenerator) Get(t QIDType) QID {
+ return QID{
+ Type: t,
+ Version: 0,
+ Path: atomic.AddUint64(&q.uids, 1),
+ }
+}
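
Because Get relies on atomic.AddUint64, a single generator may be shared by concurrent handlers; Path values start at 1 and never repeat for the generator's lifetime. A hypothetical sketch (not part of this change):

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	var gen p9.QIDGenerator         // zero value is ready to use
	dir := gen.Get(p9.TypeDir)      // QID{Type: 128, Version: 0, Path: 1}
	file := gen.Get(p9.TypeRegular) // QID{Type: 0, Version: 0, Path: 2}
	fmt.Println(dir, file)
}
```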
+
+// FSStat is used by statfs.
+type FSStat struct {
+ // Type is the filesystem type.
+ Type uint32
+
+ // BlockSize is the blocksize.
+ BlockSize uint32
+
+ // Blocks is the number of blocks.
+ Blocks uint64
+
+ // BlocksFree is the number of free blocks.
+ BlocksFree uint64
+
+ // BlocksAvailable is the number of blocks available to unprivileged users.
+ BlocksAvailable uint64
+
+ // Files is the number of files available.
+ Files uint64
+
+ // FilesFree is the number of free file nodes.
+ FilesFree uint64
+
+ // FSID is the filesystem ID.
+ FSID uint64
+
+ // NameLength is the maximum name length.
+ NameLength uint32
+}
+
+// decode implements encoder.decode.
+func (f *FSStat) decode(b *buffer) {
+ f.Type = b.Read32()
+ f.BlockSize = b.Read32()
+ f.Blocks = b.Read64()
+ f.BlocksFree = b.Read64()
+ f.BlocksAvailable = b.Read64()
+ f.Files = b.Read64()
+ f.FilesFree = b.Read64()
+ f.FSID = b.Read64()
+ f.NameLength = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (f *FSStat) encode(b *buffer) {
+ b.Write32(f.Type)
+ b.Write32(f.BlockSize)
+ b.Write64(f.Blocks)
+ b.Write64(f.BlocksFree)
+ b.Write64(f.BlocksAvailable)
+ b.Write64(f.Files)
+ b.Write64(f.FilesFree)
+ b.Write64(f.FSID)
+ b.Write32(f.NameLength)
+}
+
+// AttrMask is a mask of attributes for getattr.
+type AttrMask struct {
+ Mode bool
+ NLink bool
+ UID bool
+ GID bool
+ RDev bool
+ ATime bool
+ MTime bool
+ CTime bool
+ INo bool
+ Size bool
+ Blocks bool
+ BTime bool
+ Gen bool
+ DataVersion bool
+}
+
+// Contains returns true if a contains all of the attributes masked in b.
+func (a AttrMask) Contains(b AttrMask) bool {
+ if b.Mode && !a.Mode {
+ return false
+ }
+ if b.NLink && !a.NLink {
+ return false
+ }
+ if b.UID && !a.UID {
+ return false
+ }
+ if b.GID && !a.GID {
+ return false
+ }
+ if b.RDev && !a.RDev {
+ return false
+ }
+ if b.ATime && !a.ATime {
+ return false
+ }
+ if b.MTime && !a.MTime {
+ return false
+ }
+ if b.CTime && !a.CTime {
+ return false
+ }
+ if b.INo && !a.INo {
+ return false
+ }
+ if b.Size && !a.Size {
+ return false
+ }
+ if b.Blocks && !a.Blocks {
+ return false
+ }
+ if b.BTime && !a.BTime {
+ return false
+ }
+ if b.Gen && !a.Gen {
+ return false
+ }
+ if b.DataVersion && !a.DataVersion {
+ return false
+ }
+ return true
+}
+
+// Empty returns true if no fields are masked.
+func (a AttrMask) Empty() bool {
+ return !a.Mode && !a.NLink && !a.UID && !a.GID && !a.RDev && !a.ATime && !a.MTime && !a.CTime && !a.INo && !a.Size && !a.Blocks && !a.BTime && !a.Gen && !a.DataVersion
+}
+
+// AttrMaskAll returns an AttrMask with all fields masked.
+func AttrMaskAll() AttrMask {
+ return AttrMask{
+ Mode: true,
+ NLink: true,
+ UID: true,
+ GID: true,
+ RDev: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ INo: true,
+ Size: true,
+ Blocks: true,
+ BTime: true,
+ Gen: true,
+ DataVersion: true,
+ }
+}
+
+// String implements fmt.Stringer.
+func (a AttrMask) String() string {
+ var masks []string
+ if a.Mode {
+ masks = append(masks, "Mode")
+ }
+ if a.NLink {
+ masks = append(masks, "NLink")
+ }
+ if a.UID {
+ masks = append(masks, "UID")
+ }
+ if a.GID {
+ masks = append(masks, "GID")
+ }
+ if a.RDev {
+ masks = append(masks, "RDev")
+ }
+ if a.ATime {
+ masks = append(masks, "ATime")
+ }
+ if a.MTime {
+ masks = append(masks, "MTime")
+ }
+ if a.CTime {
+ masks = append(masks, "CTime")
+ }
+ if a.INo {
+ masks = append(masks, "INo")
+ }
+ if a.Size {
+ masks = append(masks, "Size")
+ }
+ if a.Blocks {
+ masks = append(masks, "Blocks")
+ }
+ if a.BTime {
+ masks = append(masks, "BTime")
+ }
+ if a.Gen {
+ masks = append(masks, "Gen")
+ }
+ if a.DataVersion {
+ masks = append(masks, "DataVersion")
+ }
+ return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " "))
+}
+
+// decode implements encoder.decode.
+func (a *AttrMask) decode(b *buffer) {
+ mask := b.Read64()
+ a.Mode = mask&0x00000001 != 0
+ a.NLink = mask&0x00000002 != 0
+ a.UID = mask&0x00000004 != 0
+ a.GID = mask&0x00000008 != 0
+ a.RDev = mask&0x00000010 != 0
+ a.ATime = mask&0x00000020 != 0
+ a.MTime = mask&0x00000040 != 0
+ a.CTime = mask&0x00000080 != 0
+ a.INo = mask&0x00000100 != 0
+ a.Size = mask&0x00000200 != 0
+ a.Blocks = mask&0x00000400 != 0
+ a.BTime = mask&0x00000800 != 0
+ a.Gen = mask&0x00001000 != 0
+ a.DataVersion = mask&0x00002000 != 0
+}
+
+// encode implements encoder.encode.
+func (a *AttrMask) encode(b *buffer) {
+ var mask uint64
+ if a.Mode {
+ mask |= 0x00000001
+ }
+ if a.NLink {
+ mask |= 0x00000002
+ }
+ if a.UID {
+ mask |= 0x00000004
+ }
+ if a.GID {
+ mask |= 0x00000008
+ }
+ if a.RDev {
+ mask |= 0x00000010
+ }
+ if a.ATime {
+ mask |= 0x00000020
+ }
+ if a.MTime {
+ mask |= 0x00000040
+ }
+ if a.CTime {
+ mask |= 0x00000080
+ }
+ if a.INo {
+ mask |= 0x00000100
+ }
+ if a.Size {
+ mask |= 0x00000200
+ }
+ if a.Blocks {
+ mask |= 0x00000400
+ }
+ if a.BTime {
+ mask |= 0x00000800
+ }
+ if a.Gen {
+ mask |= 0x00001000
+ }
+ if a.DataVersion {
+ mask |= 0x00002000
+ }
+ b.Write64(mask)
+}
+
+// Attr is a set of attributes for getattr.
+type Attr struct {
+ Mode FileMode
+ UID UID
+ GID GID
+ NLink uint64
+ RDev uint64
+ Size uint64
+ BlockSize uint64
+ Blocks uint64
+ ATimeSeconds uint64
+ ATimeNanoSeconds uint64
+ MTimeSeconds uint64
+ MTimeNanoSeconds uint64
+ CTimeSeconds uint64
+ CTimeNanoSeconds uint64
+ BTimeSeconds uint64
+ BTimeNanoSeconds uint64
+ Gen uint64
+ DataVersion uint64
+}
+
+// String implements fmt.Stringer.
+func (a Attr) String() string {
+ return fmt.Sprintf("Attr{Mode: 0o%o, UID: %d, GID: %d, NLink: %d, RDev: %d, Size: %d, BlockSize: %d, Blocks: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}, CTime: {Sec: %d, NanoSec: %d}, BTime: {Sec: %d, NanoSec: %d}, Gen: %d, DataVersion: %d}",
+ a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion)
+}
+
+// encode implements encoder.encode.
+func (a *Attr) encode(b *buffer) {
+ b.WriteFileMode(a.Mode)
+ b.WriteUID(a.UID)
+ b.WriteGID(a.GID)
+ b.Write64(a.NLink)
+ b.Write64(a.RDev)
+ b.Write64(a.Size)
+ b.Write64(a.BlockSize)
+ b.Write64(a.Blocks)
+ b.Write64(a.ATimeSeconds)
+ b.Write64(a.ATimeNanoSeconds)
+ b.Write64(a.MTimeSeconds)
+ b.Write64(a.MTimeNanoSeconds)
+ b.Write64(a.CTimeSeconds)
+ b.Write64(a.CTimeNanoSeconds)
+ b.Write64(a.BTimeSeconds)
+ b.Write64(a.BTimeNanoSeconds)
+ b.Write64(a.Gen)
+ b.Write64(a.DataVersion)
+}
+
+// decode implements encoder.decode.
+func (a *Attr) decode(b *buffer) {
+ a.Mode = b.ReadFileMode()
+ a.UID = b.ReadUID()
+ a.GID = b.ReadGID()
+ a.NLink = b.Read64()
+ a.RDev = b.Read64()
+ a.Size = b.Read64()
+ a.BlockSize = b.Read64()
+ a.Blocks = b.Read64()
+ a.ATimeSeconds = b.Read64()
+ a.ATimeNanoSeconds = b.Read64()
+ a.MTimeSeconds = b.Read64()
+ a.MTimeNanoSeconds = b.Read64()
+ a.CTimeSeconds = b.Read64()
+ a.CTimeNanoSeconds = b.Read64()
+ a.BTimeSeconds = b.Read64()
+ a.BTimeNanoSeconds = b.Read64()
+ a.Gen = b.Read64()
+ a.DataVersion = b.Read64()
+}
+
+// StatToAttr converts a Linux syscall stat structure to an Attr.
+func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) {
+ attr := Attr{
+ UID: NoUID,
+ GID: NoGID,
+ }
+ if req.Mode {
+ // p9.FileMode corresponds to Linux mode_t.
+ attr.Mode = FileMode(s.Mode)
+ }
+ if req.NLink {
+ attr.NLink = uint64(s.Nlink)
+ }
+ if req.UID {
+ attr.UID = UID(s.Uid)
+ }
+ if req.GID {
+ attr.GID = GID(s.Gid)
+ }
+ if req.RDev {
+ attr.RDev = s.Dev
+ }
+ if req.ATime {
+ attr.ATimeSeconds = uint64(s.Atim.Sec)
+ attr.ATimeNanoSeconds = uint64(s.Atim.Nsec)
+ }
+ if req.MTime {
+ attr.MTimeSeconds = uint64(s.Mtim.Sec)
+ attr.MTimeNanoSeconds = uint64(s.Mtim.Nsec)
+ }
+ if req.CTime {
+ attr.CTimeSeconds = uint64(s.Ctim.Sec)
+ attr.CTimeNanoSeconds = uint64(s.Ctim.Nsec)
+ }
+ if req.Size {
+ attr.Size = uint64(s.Size)
+ }
+ if req.Blocks {
+ attr.BlockSize = uint64(s.Blksize)
+ attr.Blocks = uint64(s.Blocks)
+ }
+
+ // Reuse req as the returned valid mask, clearing the fields that
+ // stat(2) cannot provide.
+ req.BTime = false
+ req.Gen = false
+ req.DataVersion = false
+
+ return attr, req
+}
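
Note that StatToAttr hands back a narrowed mask, so servers should report the returned AttrMask as the valid set rather than echoing the requested one. A hypothetical Linux-only sketch (not part of this change; the path is arbitrary):

```go
package main

import (
	"fmt"
	"syscall"

	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	var s syscall.Stat_t
	if err := syscall.Stat("/etc/hostname", &s); err != nil {
		panic(err)
	}
	// valid drops BTime, Gen and DataVersion even though all were requested.
	attr, valid := p9.StatToAttr(&s, p9.AttrMaskAll())
	fmt.Printf("%v mode=%#o\n", valid, attr.Mode)
}
```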
+
+// SetAttrMask specifies a valid mask for setattr.
+type SetAttrMask struct {
+ Permissions bool
+ UID bool
+ GID bool
+ Size bool
+ ATime bool
+ MTime bool
+ CTime bool
+ ATimeNotSystemTime bool
+ MTimeNotSystemTime bool
+}
+
+// IsSubsetOf returns whether s is a subset of m.
+func (s SetAttrMask) IsSubsetOf(m SetAttrMask) bool {
+ sb := s.bitmask()
+ sm := m.bitmask()
+ return sm|sb == sm
+}
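
IsSubsetOf reduces both masks to their wire bitmasks and checks that OR-ing s into m adds no new bits. A hypothetical sketch (not part of this change):

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	s := p9.SetAttrMask{UID: true}
	m := p9.SetAttrMask{UID: true, GID: true}
	fmt.Println(s.IsSubsetOf(m)) // true: every bit of s is already set in m
	fmt.Println(m.IsSubsetOf(s)) // false: GID is missing from s
}
```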
+
+// String implements fmt.Stringer.
+func (s SetAttrMask) String() string {
+ var masks []string
+ if s.Permissions {
+ masks = append(masks, "Permissions")
+ }
+ if s.UID {
+ masks = append(masks, "UID")
+ }
+ if s.GID {
+ masks = append(masks, "GID")
+ }
+ if s.Size {
+ masks = append(masks, "Size")
+ }
+ if s.ATime {
+ masks = append(masks, "ATime")
+ }
+ if s.MTime {
+ masks = append(masks, "MTime")
+ }
+ if s.CTime {
+ masks = append(masks, "CTime")
+ }
+ if s.ATimeNotSystemTime {
+ masks = append(masks, "ATimeNotSystemTime")
+ }
+ if s.MTimeNotSystemTime {
+ masks = append(masks, "MTimeNotSystemTime")
+ }
+ return fmt.Sprintf("SetAttrMask{with: %s}", strings.Join(masks, " "))
+}
+
+// Empty returns true if no fields are masked.
+func (s SetAttrMask) Empty() bool {
+ return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime
+}
+
+// decode implements encoder.decode.
+func (s *SetAttrMask) decode(b *buffer) {
+ mask := b.Read32()
+ s.Permissions = mask&0x00000001 != 0
+ s.UID = mask&0x00000002 != 0
+ s.GID = mask&0x00000004 != 0
+ s.Size = mask&0x00000008 != 0
+ s.ATime = mask&0x00000010 != 0
+ s.MTime = mask&0x00000020 != 0
+ s.CTime = mask&0x00000040 != 0
+ s.ATimeNotSystemTime = mask&0x00000080 != 0
+ s.MTimeNotSystemTime = mask&0x00000100 != 0
+}
+
+func (s SetAttrMask) bitmask() uint32 {
+ var mask uint32
+ if s.Permissions {
+ mask |= 0x00000001
+ }
+ if s.UID {
+ mask |= 0x00000002
+ }
+ if s.GID {
+ mask |= 0x00000004
+ }
+ if s.Size {
+ mask |= 0x00000008
+ }
+ if s.ATime {
+ mask |= 0x00000010
+ }
+ if s.MTime {
+ mask |= 0x00000020
+ }
+ if s.CTime {
+ mask |= 0x00000040
+ }
+ if s.ATimeNotSystemTime {
+ mask |= 0x00000080
+ }
+ if s.MTimeNotSystemTime {
+ mask |= 0x00000100
+ }
+ return mask
+}
+
+// encode implements encoder.encode.
+func (s *SetAttrMask) encode(b *buffer) {
+ b.Write32(s.bitmask())
+}
+
+// SetAttr specifies a set of attributes for a setattr.
+type SetAttr struct {
+ Permissions FileMode
+ UID UID
+ GID GID
+ Size uint64
+ ATimeSeconds uint64
+ ATimeNanoSeconds uint64
+ MTimeSeconds uint64
+ MTimeNanoSeconds uint64
+}
+
+// String implements fmt.Stringer.
+func (s SetAttr) String() string {
+ return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds)
+}
+
+// decode implements encoder.decode.
+func (s *SetAttr) decode(b *buffer) {
+ s.Permissions = b.ReadPermissions()
+ s.UID = b.ReadUID()
+ s.GID = b.ReadGID()
+ s.Size = b.Read64()
+ s.ATimeSeconds = b.Read64()
+ s.ATimeNanoSeconds = b.Read64()
+ s.MTimeSeconds = b.Read64()
+ s.MTimeNanoSeconds = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (s *SetAttr) encode(b *buffer) {
+ b.WritePermissions(s.Permissions)
+ b.WriteUID(s.UID)
+ b.WriteGID(s.GID)
+ b.Write64(s.Size)
+ b.Write64(s.ATimeSeconds)
+ b.Write64(s.ATimeNanoSeconds)
+ b.Write64(s.MTimeSeconds)
+ b.Write64(s.MTimeNanoSeconds)
+}
+
+// Apply applies this to the given Attr.
+func (a *Attr) Apply(mask SetAttrMask, attr SetAttr) {
+ if mask.Permissions {
+ a.Mode = a.Mode&^permissionsMask | (attr.Permissions & permissionsMask)
+ }
+ if mask.UID {
+ a.UID = attr.UID
+ }
+ if mask.GID {
+ a.GID = attr.GID
+ }
+ if mask.Size {
+ a.Size = attr.Size
+ }
+ if mask.ATime {
+ a.ATimeSeconds = attr.ATimeSeconds
+ a.ATimeNanoSeconds = attr.ATimeNanoSeconds
+ }
+ if mask.MTime {
+ a.MTimeSeconds = attr.MTimeSeconds
+ a.MTimeNanoSeconds = attr.MTimeNanoSeconds
+ }
+}
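
Apply merges only the masked fields into the receiver, which is how a server would fold a Tsetattr into cached attributes. A hypothetical sketch (not part of this change):

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	cached := p9.Attr{Mode: p9.ModeRegular | 0755, Size: 1}
	cached.Apply(
		p9.SetAttrMask{Permissions: true, Size: true},
		p9.SetAttr{Permissions: 0600, Size: 4096},
	)
	fmt.Printf("%#o %d\n", cached.Mode.Permissions(), cached.Size) // 0600 4096
}
```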
+
+// Dirent is used for readdir.
+type Dirent struct {
+ // QID is the entry QID.
+ QID QID
+
+ // Offset is the offset in the directory.
+ //
+ // This will be communicated back to the original caller.
+ Offset uint64
+
+ // Type is the 9P type.
+ Type QIDType
+
+ // Name is the name of the entry (i.e. basename).
+ Name string
+}
+
+// String implements fmt.Stringer.
+func (d Dirent) String() string {
+ return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name)
+}
+
+// decode implements encoder.decode.
+func (d *Dirent) decode(b *buffer) {
+ d.QID.decode(b)
+ d.Offset = b.Read64()
+ d.Type = b.ReadQIDType()
+ d.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (d *Dirent) encode(b *buffer) {
+ d.QID.encode(b)
+ b.Write64(d.Offset)
+ b.WriteQIDType(d.Type)
+ b.WriteString(d.Name)
+}
+
+// AllocateMode are possible modes to p9.File.Allocate().
+type AllocateMode struct {
+ KeepSize bool
+ PunchHole bool
+ NoHideStale bool
+ CollapseRange bool
+ ZeroRange bool
+ InsertRange bool
+ Unshare bool
+}
+
+// ToAllocateMode returns an AllocateMode from a fallocate(2) mode.
+func ToAllocateMode(mode uint64) AllocateMode {
+ return AllocateMode{
+ KeepSize: mode&unix.FALLOC_FL_KEEP_SIZE != 0,
+ PunchHole: mode&unix.FALLOC_FL_PUNCH_HOLE != 0,
+ NoHideStale: mode&unix.FALLOC_FL_NO_HIDE_STALE != 0,
+ CollapseRange: mode&unix.FALLOC_FL_COLLAPSE_RANGE != 0,
+ ZeroRange: mode&unix.FALLOC_FL_ZERO_RANGE != 0,
+ InsertRange: mode&unix.FALLOC_FL_INSERT_RANGE != 0,
+ Unshare: mode&unix.FALLOC_FL_UNSHARE_RANGE != 0,
+ }
+}
+
+// ToLinux converts to a value compatible with fallocate(2)'s mode.
+func (a *AllocateMode) ToLinux() uint32 {
+ rv := uint32(0)
+ if a.KeepSize {
+ rv |= unix.FALLOC_FL_KEEP_SIZE
+ }
+ if a.PunchHole {
+ rv |= unix.FALLOC_FL_PUNCH_HOLE
+ }
+ if a.NoHideStale {
+ rv |= unix.FALLOC_FL_NO_HIDE_STALE
+ }
+ if a.CollapseRange {
+ rv |= unix.FALLOC_FL_COLLAPSE_RANGE
+ }
+ if a.ZeroRange {
+ rv |= unix.FALLOC_FL_ZERO_RANGE
+ }
+ if a.InsertRange {
+ rv |= unix.FALLOC_FL_INSERT_RANGE
+ }
+ if a.Unshare {
+ rv |= unix.FALLOC_FL_UNSHARE_RANGE
+ }
+ return rv
+}
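
ToAllocateMode and ToLinux are inverses over the FALLOC_FL_* bits, keeping the wire encoding in lockstep with fallocate(2). A hypothetical sketch (not part of this change):

```go
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/p9"
)

func main() {
	in := uint64(unix.FALLOC_FL_PUNCH_HOLE | unix.FALLOC_FL_KEEP_SIZE)
	mode := p9.ToAllocateMode(in)
	fmt.Println(mode.PunchHole, mode.KeepSize) // true true
	fmt.Println(uint64(mode.ToLinux()) == in)  // true: lossless round trip
}
```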
+
+// decode implements encoder.decode.
+func (a *AllocateMode) decode(b *buffer) {
+ mask := b.Read32()
+ a.KeepSize = mask&0x01 != 0
+ a.PunchHole = mask&0x02 != 0
+ a.NoHideStale = mask&0x04 != 0
+ a.CollapseRange = mask&0x08 != 0
+ a.ZeroRange = mask&0x10 != 0
+ a.InsertRange = mask&0x20 != 0
+ a.Unshare = mask&0x40 != 0
+}
+
+// encode implements encoder.encode.
+func (a *AllocateMode) encode(b *buffer) {
+ mask := uint32(0)
+ if a.KeepSize {
+ mask |= 0x01
+ }
+ if a.PunchHole {
+ mask |= 0x02
+ }
+ if a.NoHideStale {
+ mask |= 0x04
+ }
+ if a.CollapseRange {
+ mask |= 0x08
+ }
+ if a.ZeroRange {
+ mask |= 0x10
+ }
+ if a.InsertRange {
+ mask |= 0x20
+ }
+ if a.Unshare {
+ mask |= 0x40
+ }
+ b.Write32(mask)
+}
diff --git a/pkg/p9/p9_test.go b/pkg/p9/p9_test.go
new file mode 100644
index 000000000..8dda6cc64
--- /dev/null
+++ b/pkg/p9/p9_test.go
@@ -0,0 +1,188 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "os"
+ "testing"
+)
+
+func TestFileModeHelpers(t *testing.T) {
+ fns := map[FileMode]struct {
+ // name identifies the file mode.
+ name string
+
+ // function is the function that should return true given the
+ // right FileMode.
+ function func(m FileMode) bool
+ }{
+ ModeRegular: {
+ name: "regular",
+ function: FileMode.IsRegular,
+ },
+ ModeDirectory: {
+ name: "directory",
+ function: FileMode.IsDir,
+ },
+ ModeNamedPipe: {
+ name: "named pipe",
+ function: FileMode.IsNamedPipe,
+ },
+ ModeCharacterDevice: {
+ name: "character device",
+ function: FileMode.IsCharacterDevice,
+ },
+ ModeBlockDevice: {
+ name: "block device",
+ function: FileMode.IsBlockDevice,
+ },
+ ModeSymlink: {
+ name: "symlink",
+ function: FileMode.IsSymlink,
+ },
+ ModeSocket: {
+ name: "socket",
+ function: FileMode.IsSocket,
+ },
+ }
+ for mode, info := range fns {
+ // Make sure the mode doesn't identify as anything but itself.
+ for testMode, testfns := range fns {
+ if mode != testMode && testfns.function(mode) {
+ t.Errorf("Mode %s returned true when asked if it was mode %s", info.name, testfns.name)
+ }
+ }
+
+ // Make sure mode identifies as itself.
+ if !info.function(mode) {
+ t.Errorf("Mode %s returned false when asked if it was itself", info.name)
+ }
+ }
+}
+
+func TestFileModeToQID(t *testing.T) {
+ for _, test := range []struct {
+ // name identifies the test.
+ name string
+
+ // mode is the FileMode we start out with.
+ mode FileMode
+
+ // want is the corresponding QIDType we expect.
+ want QIDType
+ }{
+ {
+ name: "Directories are of type directory",
+ mode: ModeDirectory,
+ want: TypeDir,
+ },
+ {
+ name: "Sockets are append-only files",
+ mode: ModeSocket,
+ want: TypeAppendOnly,
+ },
+ {
+ name: "Named pipes are append-only files",
+ mode: ModeNamedPipe,
+ want: TypeAppendOnly,
+ },
+ {
+ name: "Character devices are append-only files",
+ mode: ModeCharacterDevice,
+ want: TypeAppendOnly,
+ },
+ {
+ name: "Symlinks are of type symlink",
+ mode: ModeSymlink,
+ want: TypeSymlink,
+ },
+ {
+ name: "Regular files are of type regular",
+ mode: ModeRegular,
+ want: TypeRegular,
+ },
+ {
+ name: "Block devices are regular files",
+ mode: ModeBlockDevice,
+ want: TypeRegular,
+ },
+ } {
+ if qidType := test.mode.QIDType(); qidType != test.want {
+ t.Errorf("ModeToQID test %s failed: got %o, wanted %o", test.name, qidType, test.want)
+ }
+ }
+}
+
+func TestP9ModeConverters(t *testing.T) {
+ for _, m := range []FileMode{
+ ModeRegular,
+ ModeDirectory,
+ ModeCharacterDevice,
+ ModeBlockDevice,
+ ModeSocket,
+ ModeSymlink,
+ ModeNamedPipe,
+ } {
+ if mb := ModeFromOS(m.OSMode()); mb != m {
+ t.Errorf("Converting %o to OS.FileMode gives %o and is converted back as %o", m, m.OSMode(), mb)
+ }
+ }
+}
+
+func TestOSModeConverters(t *testing.T) {
+ // Modes that can be converted back and forth.
+ for _, m := range []os.FileMode{
+ 0, // Regular file.
+ os.ModeDir,
+ os.ModeCharDevice | os.ModeDevice,
+ os.ModeDevice,
+ os.ModeSocket,
+ os.ModeSymlink,
+ os.ModeNamedPipe,
+ } {
+ if mb := ModeFromOS(m).OSMode(); mb != m {
+ t.Errorf("Converting %o to p9.FileMode gives %o and is converted back as %o", m, ModeFromOS(m), mb)
+ }
+ }
+
+ // Modes that will be converted to a regular file since p9 cannot
+ // express these.
+ for _, m := range []os.FileMode{
+ os.ModeAppend,
+ os.ModeExclusive,
+ os.ModeTemporary,
+ } {
+ if p9Mode := ModeFromOS(m); p9Mode != ModeRegular {
+ t.Errorf("Converting %o to p9.FileMode should have given ModeRegular, but yielded %o", m, p9Mode)
+ }
+ }
+}
+
+func TestAttrMaskContains(t *testing.T) {
+ req := AttrMask{Mode: true, Size: true}
+ have := AttrMask{}
+ if have.Contains(req) {
+ t.Fatalf("AttrMask %v should not be a superset of %v", have, req)
+ }
+ have.Mode = true
+ if have.Contains(req) {
+ t.Fatalf("AttrMask %v should not be a superset of %v", have, req)
+ }
+ have.Size = true
+ have.MTime = true
+ if !have.Contains(req) {
+ t.Fatalf("AttrMask %v should be a superset of %v", have, req)
+ }
+}
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
new file mode 100644
index 000000000..7ca67cb19
--- /dev/null
+++ b/pkg/p9/p9test/BUILD
@@ -0,0 +1,88 @@
+load("//tools:defs.bzl", "go_binary", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+alias(
+ name = "mockgen",
+ actual = "@com_github_golang_mock//mockgen:mockgen",
+)
+
+MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9"
+
+# mockgen_reflect is a source file that contains mock generation code that
+# imports the p9 package and generates a specification via reflection. The
+# usual generation path must be split into two distinct parts because the full
+# source tree is not available to all build targets. Only declared dependencies
+# are available (and even then, not the Go source files).
+genrule(
+ name = "mockgen_reflect",
+ testonly = 1,
+ outs = ["mockgen_reflect.go"],
+ cmd = (
+ "$(location :mockgen) " +
+ "-package p9test " +
+ "-prog_only " + MOCK_SRC_PACKAGE + " " +
+ "Attacher,File > $@"
+ ),
+ tools = [":mockgen"],
+)
+
+# mockgen_exec is the binary that includes the above reflection generator.
+# Running this binary will emit an encoded version of the p9 Attacher and File
+# structures. This is consumed by the mocks genrule, below.
+go_binary(
+ name = "mockgen_exec",
+ testonly = 1,
+ srcs = ["mockgen_reflect.go"],
+ deps = [
+ "//pkg/p9",
+ "@com_github_golang_mock//mockgen/model:go_default_library",
+ ],
+)
+
+# mocks consumes the encoded output above, and generates the full source for a
+# set of mocks. These are included directly in the p9test library.
+genrule(
+ name = "mocks",
+ testonly = 1,
+ outs = ["mocks.go"],
+ cmd = (
+ "$(location :mockgen) " +
+ "-package p9test " +
+ "-exec_only $(location :mockgen_exec) " + MOCK_SRC_PACKAGE + " File > $@"
+ ),
+ tools = [
+ ":mockgen",
+ ":mockgen_exec",
+ ],
+)
+
+go_library(
+ name = "p9test",
+ srcs = [
+ "mocks.go",
+ "p9test.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@com_github_golang_mock//gomock:go_default_library",
+ ],
+)
+
+go_test(
+ name = "client_test",
+ size = "medium",
+ srcs = ["client_test.go"],
+ library = ":p9test",
+ deps = [
+ "//pkg/fd",
+ "//pkg/p9",
+ "//pkg/sync",
+ "@com_github_golang_mock//gomock:go_default_library",
+ ],
+)
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
new file mode 100644
index 000000000..6e7bb3db2
--- /dev/null
+++ b/pkg/p9/p9test/client_test.go
@@ -0,0 +1,2242 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "math/rand"
+ "os"
+ "reflect"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/golang/mock/gomock"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestPanic(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root.
+ d := h.NewDirectory(nil)(nil)
+ defer d.Close() // Needed manually.
+ h.Attacher.EXPECT().Attach().Return(d, nil).Do(func() {
+ // Panic here, and ensure that we get back EFAULT.
+ panic("handler")
+ })
+
+ // Attach to the client.
+ if _, err := c.Attach("/"); err != syscall.EFAULT {
+ t.Fatalf("got attach err %v, want EFAULT", err)
+ }
+}
+
+func TestAttachNoLeak(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root.
+ d := h.NewDirectory(nil)(nil)
+ h.Attacher.EXPECT().Attach().Return(d, nil).Times(1)
+
+ // Attach to the client.
+ f, err := c.Attach("/")
+ if err != nil {
+ t.Fatalf("got attach err %v, want nil", err)
+ }
+
+ // Don't close the file. This should be closed automatically when the
+ // client disconnects. The mock asserts that everything is closed
+ // exactly once. This statement just removes the unused variable error.
+ _ = f
+}
+
+func TestBadAttach(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Return an error on attach.
+ h.Attacher.EXPECT().Attach().Return(nil, syscall.EINVAL).Times(1)
+
+ // Attach to the client.
+ if _, err := c.Attach("/"); err != syscall.EINVAL {
+ t.Fatalf("got attach err %v, want syscall.EINVAL", err)
+ }
+}
+
+func TestWalkAttach(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root.
+ d := h.NewDirectory(map[string]Generator{
+ "a": h.NewDirectory(map[string]Generator{
+ "b": h.NewFile(),
+ }),
+ })(nil)
+ h.Attacher.EXPECT().Attach().Return(d, nil).Times(1)
+
+ // Attach to the client as a non-root, and ensure that the walk above
+ // occurs as expected. We should get back b, and all references should
+ // be dropped when the file is closed.
+ f, err := c.Attach("/a/b")
+ if err != nil {
+ t.Fatalf("got attach err %v, want nil", err)
+ }
+ defer f.Close()
+
+ // Check that's a regular file.
+ if _, _, attr, err := f.GetAttr(p9.AttrMaskAll()); err != nil {
+ t.Errorf("got err %v, want nil", err)
+ } else if !attr.Mode.IsRegular() {
+ t.Errorf("got mode %v, want regular file", err)
+ }
+}
+
+// newTypeMap returns a new type map dictionary.
+func newTypeMap(h *Harness) map[string]Generator {
+ return map[string]Generator{
+ "directory": h.NewDirectory(map[string]Generator{}),
+ "file": h.NewFile(),
+ "symlink": h.NewSymlink(),
+ "block-device": h.NewBlockDevice(),
+ "character-device": h.NewCharacterDevice(),
+ "named-pipe": h.NewNamedPipe(),
+ "socket": h.NewSocket(),
+ }
+}
+
+// newRoot returns a new root filesystem.
+//
+// This is set up in a deterministic way for testing most operations.
+//
+// The represented file system looks like:
+// - file
+// - symlink
+// - directory
+// ...
+// + one
+// - file
+// - symlink
+// - directory
+// ...
+// + two
+// - file
+// - symlink
+// - directory
+// ...
+// + three
+// - file
+// - symlink
+// - directory
+// ...
+func newRoot(h *Harness, c *p9.Client) (*Mock, p9.File) {
+ root := newTypeMap(h)
+ one := newTypeMap(h)
+ two := newTypeMap(h)
+ three := newTypeMap(h)
+ one["two"] = h.NewDirectory(two) // Will be nested in one.
+ root["one"] = h.NewDirectory(one) // Top level.
+ root["three"] = h.NewDirectory(three) // Alternate top-level.
+
+ // Create a new root.
+ rootBackend := h.NewDirectory(root)(nil)
+ h.Attacher.EXPECT().Attach().Return(rootBackend, nil)
+
+ // Attach to the client.
+ r, err := c.Attach("/")
+ if err != nil {
+ h.t.Fatalf("got attach err %v, want nil", err)
+ }
+
+ return rootBackend, r
+}
+
+func allInvalidNames(from string) []string {
+ return []string{
+ from + "/other",
+ from + "/..",
+ from + "/.",
+ from + "/",
+ "other/" + from,
+ "/" + from,
+ "./" + from,
+ "../" + from,
+ ".",
+ "..",
+ "/",
+ "",
+ }
+}
+
+func TestWalkInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Run relevant tests.
+ for name := range newTypeMap(h) {
+ // These are all the various ways that one might attempt to
+ // construct compound paths. They should all be rejected: any
+ // name that contains a / is not allowed, as are the singular
+ // names "." and "..".
+ if _, _, err := root.Walk([]string{".", name}); err != syscall.EINVAL {
+ t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := root.Walk([]string{"..", name}); err != syscall.EINVAL {
+ t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := root.Walk([]string{name, "."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s . wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := root.Walk([]string{name, ".."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s .. wanted EINVAL, got %v", name, err)
+ }
+ for _, invalidName := range allInvalidNames(name) {
+ if _, _, err := root.Walk([]string{invalidName}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s wanted EINVAL, got %v", invalidName, err)
+ }
+ }
+ wantErr := syscall.EINVAL
+ if name == "directory" {
+ // We can attempt a walk through a directory. However,
+ // we should never see a file named "other", so we
+ // expect this to return ENOENT.
+ wantErr = syscall.ENOENT
+ }
+ if _, _, err := root.Walk([]string{name, "other"}); err != wantErr {
+ t.Errorf("Walk through %s/other wanted %v, got %v", name, wantErr, err)
+ }
+
+ // Do a successful walk.
+ _, f, err := root.Walk([]string{name})
+ if err != nil {
+ t.Errorf("Walk to %s wanted nil, got %v", name, err)
+ }
+ defer f.Close()
+ local := h.Pop(f)
+
+ // Check that the file matches.
+ _, localMask, localAttr, localErr := local.GetAttr(p9.AttrMaskAll())
+ if _, mask, attr, err := f.GetAttr(p9.AttrMaskAll()); mask != localMask || attr != localAttr || err != localErr {
+ t.Errorf("GetAttr got (%v, %v, %v), wanted (%v, %v, %v)",
+ mask, attr, err, localMask, localAttr, localErr)
+ }
+
+ // Ensure we can't walk backwards.
+ if _, _, err := f.Walk([]string{"."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s/. wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := f.Walk([]string{".."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s/.. wanted EINVAL, got %v", name, err)
+ }
+ }
+}
+
+// fileGenerator is a function to generate files via walk or create.
+//
+// Examples are:
+// - walkHelper
+// - walkAndOpenHelper
+// - createHelper
+type fileGenerator func(*Harness, string, p9.File) (*Mock, *Mock, p9.File)
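+
+// All three generators share this shape, so tests can be parameterized over
+// them. A sketch (gen and parent are illustrative names):
+//
+//   parentBackend, fileBackend, f := gen(h, "file", parent)
+//   defer f.Close()
+//
+// Here fileBackend is the server-side Mock and f is the client-side file.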
+
+// walkHelper walks to the given file.
+//
+// The backends of the parent and walked file are returned, as well as the
+// walked client file.
+func walkHelper(h *Harness, name string, dir p9.File) (parentBackend *Mock, walkedBackend *Mock, walked p9.File) {
+ _, parent, err := dir.Walk(nil)
+ if err != nil {
+ h.t.Fatalf("Walk(nil) got err %v, want nil", err)
+ }
+ defer parent.Close()
+ parentBackend = h.Pop(parent)
+
+ _, walked, err = parent.Walk([]string{name})
+ if err != nil {
+ h.t.Fatalf("Walk(%s) got err %v, want nil", name, err)
+ }
+ walkedBackend = h.Pop(walked)
+
+ return parentBackend, walkedBackend, walked
+}
+
+// walkAndOpenHelper additionally opens the walked file, if possible.
+func walkAndOpenHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) {
+ parentBackend, walkedBackend, walked := walkHelper(h, name, dir)
+ if p9.CanOpen(walkedBackend.Attr.Mode) {
+ // Open for all file types that we can. We stick to a read-only
+ // open here because directories may not be opened otherwise.
+ walkedBackend.EXPECT().Open(p9.ReadOnly).Times(1)
+ if _, _, _, err := walked.Open(p9.ReadOnly); err != nil {
+ h.t.Errorf("got open err %v, want nil", err)
+ }
+ } else {
+ // ... or assert an error for others.
+ if _, _, _, err := walked.Open(p9.ReadOnly); err != syscall.EINVAL {
+ h.t.Errorf("got open err %v, want EINVAL", err)
+ }
+ }
+ return parentBackend, walkedBackend, walked
+}
+
+// createHelper creates the given file and returns the parent directory,
+// created file and client file, which must be closed when done.
+func createHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) {
+ // Clone the directory first, since Create replaces the existing file.
+ // We change the type after calling create.
+ _, dirThenFile, err := dir.Walk(nil)
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+
+ // Create a new server-side file. On the server side, a new file is
+ // returned from a create call. The client will reuse the same file,
+ // but we still expect the normal chain of closes. This complicates
+ // things a bit because the "parent" will always chain to the cloned
+ // dir above.
+ dirBackend := h.Pop(dirThenFile) // New backend directory.
+ newFile := h.NewFile()(dirBackend) // New file with backend parent.
+ dirBackend.EXPECT().Create(name, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, newFile, newFile.QID, uint32(0), nil)
+
+ // Create via the client.
+ _, dirThenFile, _, _, err = dirThenFile.Create(name, p9.ReadOnly, 0, 0, 0)
+ if err != nil {
+ h.t.Fatalf("got create err %v, want nil", err)
+ }
+
+ // Ensure subsequent walks succeed.
+ dirBackend.AddChild(name, h.NewFile())
+ return dirBackend, newFile, dirThenFile
+}
+
+// deprecatedRemover allows us to access the deprecated Remove operation within
+// the p9.File client object.
+type deprecatedRemover interface {
+ Remove() error
+}
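+
+// A minimal usage sketch: the client file is asserted to this interface
+// before invoking the legacy call, exactly as done below:
+//
+//   if err := f.(deprecatedRemover).Remove(); err != nil {
+//       // Note that a successful Remove also closes f.
+//   }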
+
+// checkDeleted asserts that relevant methods fail for an unlinked file.
+//
+// This function will close the file at the end.
+func checkDeleted(h *Harness, file p9.File) {
+ defer file.Close() // See doc.
+
+ if _, _, _, err := file.Open(p9.ReadOnly); err != syscall.EINVAL {
+ h.t.Errorf("open while deleted, got %v, want EINVAL", err)
+ }
+ if _, _, _, _, err := file.Create("created", p9.ReadOnly, 0, 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("create while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Symlink("old", "new", 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("symlink while deleted, got %v, want EINVAL", err)
+ }
+ // N.B. This link is technically invalid, but if a call to link is
+ // actually made in the backend then the mock will panic.
+ if err := file.Link(file, "new"); err != syscall.EINVAL {
+ h.t.Errorf("link while deleted, got %v, want EINVAL", err)
+ }
+ if err := file.RenameAt("src", file, "dst"); err != syscall.EINVAL {
+ h.t.Errorf("renameAt while deleted, got %v, want EINVAL", err)
+ }
+ if err := file.UnlinkAt("file", 0); err != syscall.EINVAL {
+ h.t.Errorf("unlinkAt while deleted, got %v, want EINVAL", err)
+ }
+ if err := file.Rename(file, "dst"); err != syscall.EINVAL {
+ h.t.Errorf("rename while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Readlink(); err != syscall.EINVAL {
+ h.t.Errorf("readlink while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Mkdir("dir", p9.ModeDirectory, 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("mkdir while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Mknod("dir", p9.ModeDirectory, 0, 0, 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("mknod while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Readdir(0, 1); err != syscall.EINVAL {
+ h.t.Errorf("readdir while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL {
+ h.t.Errorf("connect while deleted, got %v, want EINVAL", err)
+ }
+
+ // The remove method is technically deprecated, but we want to ensure
+ // that it still checks for deleted appropriately. We must first clone
+ // the file because remove is equivalent to close.
+ _, newFile, err := file.Walk(nil)
+ if err == syscall.EBUSY {
+ // We can't walk from here because this reference is already
+ // open. That's fine: TestUnlink also exercises unopened
+ // references, so just skip the remove operation here.
+ return
+ } else if err != nil {
+ h.t.Fatalf("clone failed, got %v, want nil", err)
+ }
+ if err := newFile.(deprecatedRemover).Remove(); err != syscall.EINVAL {
+ h.t.Errorf("remove while deleted, got %v, want EINVAL", err)
+ }
+}
+
+// deleter is a function to remove a file.
+type deleter func(parent p9.File, name string) error
+
+// unlinkAt is a deleter.
+func unlinkAt(parent p9.File, name string) error {
+ // Call unlink. Note that a filesystem may normally impose additional
+ // constraints on unlinkat success, such as ensuring that a directory is
+ // empty, requiring AT_REMOVEDIR in flags to remove a directory, etc.
+ // None of that is required internally (entire trees can be marked
+ // deleted when this operation succeeds), so the mock will succeed.
+ return parent.UnlinkAt(name, 0)
+}
+
+// remove is a deleter.
+func remove(parent p9.File, name string) error {
+ // See notes above re: remove.
+ _, newFile, err := parent.Walk([]string{name})
+ if err != nil {
+ // Should not be expected.
+ return err
+ }
+
+ // Do the actual remove.
+ if err := newFile.(deprecatedRemover).Remove(); err != nil {
+ return err
+ }
+
+ // Ensure that the remove closed the file.
+ if err := newFile.(deprecatedRemover).Remove(); err != syscall.EBADF {
+ return syscall.EBADF // Propagate this code.
+ }
+
+ return nil
+}
+
+// unlinkHelper unlinks the noted path, and ensures that all relevant
+// operations on references to that path, acquired via multiple walks,
+// start failing.
+func unlinkHelper(h *Harness, root p9.File, targetNames []string, targetGen fileGenerator, deleteFn deleter) {
+ // name is the file to be unlinked.
+ name := targetNames[len(targetNames)-1]
+
+ // Walk to the directory containing the target.
+ _, parent, err := root.Walk(targetNames[:len(targetNames)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer parent.Close()
+ parentBackend := h.Pop(parent)
+
+ // Walk to or generate the target file.
+ _, _, target := targetGen(h, name, parent)
+ defer checkDeleted(h, target)
+
+ // Walk to a second reference.
+ _, second, err := parent.Walk([]string{name})
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer checkDeleted(h, second)
+
+ // Walk to a third reference, from the start.
+ _, third, err := root.Walk(targetNames)
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer checkDeleted(h, third)
+
+ // This will be translated in the backend to an unlinkat.
+ parentBackend.EXPECT().UnlinkAt(name, uint32(0)).Return(nil)
+
+ // Actually perform the deletion.
+ if err := deleteFn(parent, name); err != nil {
+ h.t.Fatalf("got delete err %v, want nil", err)
+ }
+}
+
+func unlinkTest(t *testing.T, targetNames []string, targetGen fileGenerator) {
+ t.Run(fmt.Sprintf("unlinkAt(%s)", strings.Join(targetNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ unlinkHelper(h, root, targetNames, targetGen, unlinkAt)
+ })
+ t.Run(fmt.Sprintf("remove(%s)", strings.Join(targetNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ unlinkHelper(h, root, targetNames, targetGen, remove)
+ })
+}
+
+func TestUnlink(t *testing.T) {
+ // Unlink all files.
+ for name := range newTypeMap(nil) {
+ unlinkTest(t, []string{name}, walkHelper)
+ unlinkTest(t, []string{name}, walkAndOpenHelper)
+ unlinkTest(t, []string{"one", name}, walkHelper)
+ unlinkTest(t, []string{"one", name}, walkAndOpenHelper)
+ unlinkTest(t, []string{"one", "two", name}, walkHelper)
+ unlinkTest(t, []string{"one", "two", name}, walkAndOpenHelper)
+ }
+
+ // Unlink a directory.
+ unlinkTest(t, []string{"one"}, walkHelper)
+ unlinkTest(t, []string{"one"}, walkAndOpenHelper)
+ unlinkTest(t, []string{"one", "two"}, walkHelper)
+ unlinkTest(t, []string{"one", "two"}, walkAndOpenHelper)
+
+ // Unlink created files.
+ unlinkTest(t, []string{"created"}, createHelper)
+ unlinkTest(t, []string{"one", "created"}, createHelper)
+ unlinkTest(t, []string{"one", "two", "created"}, createHelper)
+}
+
+func TestUnlinkAtInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.UnlinkAt(invalidName, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+// expectRenamed asserts an ordered sequence of Renamed calls, where the
+// entries in elements form the source path and the first element therein
+// changes to dstName, parented at dstParent.
+func expectRenamed(file *Mock, elements []string, dstParent *Mock, dstName string) *gomock.Call {
+ if len(elements) > 0 {
+ // Recurse to the parent, if necessary.
+ call := expectRenamed(file.parent, elements[:len(elements)-1], dstParent, dstName)
+
+ // Recursive case: this element is unchanged, but should have
+ // its Renamed hook called after the parent.
+ return file.EXPECT().Renamed(file.parent, elements[len(elements)-1]).Do(func(p p9.File, _ string) {
+ file.parent = p.(*Mock)
+ }).After(call)
+ }
+
+ // Base case: this is the changed element.
+ return file.EXPECT().Renamed(dstParent, dstName).Do(func(p p9.File, name string) {
+ file.parent = p.(*Mock)
+ })
+}
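+
+// For example (a sketch; backend names are illustrative): if /one is
+// renamed to newParent/"renamed" while we hold a backend chain for the
+// walk /one/file, then
+//
+//   expectRenamed(fileBackend, []string{"file"}, newParentBackend, "renamed")
+//
+// expects, in order, oneBackend.Renamed(newParentBackend, "renamed") as
+// the base case, followed by fileBackend.Renamed(oneBackend, "file") for
+// the unchanged suffix element.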
+
+// renamer is a rename function.
+type renamer func(h *Harness, srcParent, dstParent p9.File, origName, newName string, selfRename bool) error
+
+// renameAt is a renamer.
+func renameAt(_ *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error {
+ return srcParent.RenameAt(srcName, dstParent, dstName)
+}
+
+// rename is a renamer.
+func rename(h *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error {
+ _, f, err := srcParent.Walk([]string{srcName})
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ if !selfRename {
+ backend := h.Pop(f)
+ backend.EXPECT().Renamed(gomock.Any(), dstName).Do(func(p p9.File, name string) {
+ backend.parent = p.(*Mock) // Required for close ordering.
+ })
+ }
+ return f.Rename(dstParent, dstName)
+}
+
+// renameHelper executes a rename, and asserts that all relevant elements
+// receive expected notifications. If overwriting a file, this includes
+// ensuring that the target has been appropriately marked as unlinked.
+func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string, target fileGenerator, renameFn renamer) {
+ // Walk to the directory containing the target.
+ srcQID, targetParent, err := root.Walk(srcNames[:len(srcNames)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer targetParent.Close()
+ targetParentBackend := h.Pop(targetParent)
+
+ // Walk to or generate the target file.
+ _, targetBackend, src := target(h, srcNames[len(srcNames)-1], targetParent)
+ defer src.Close()
+
+ // Walk to a second reference.
+ _, second, err := targetParent.Walk([]string{srcNames[len(srcNames)-1]})
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer second.Close()
+ secondBackend := h.Pop(second)
+
+ // Walk to a third reference, from the start.
+ _, third, err := root.Walk(srcNames)
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer third.Close()
+ thirdBackend := h.Pop(third)
+
+ // Find the common suffix to identify the rename parent.
+ var (
+ renameDestPath []string
+ renameSrcPath []string
+ selfRename bool
+ )
+ for i := 1; i <= len(srcNames) && i <= len(dstNames); i++ {
+ if srcNames[len(srcNames)-i] != dstNames[len(dstNames)-i] {
+ // Take the full prefix of dstNames up until this
+ // point, including the first mismatched name. The
+ // first mismatch must be the renamed entry.
+ renameDestPath = dstNames[:len(dstNames)-i+1]
+ renameSrcPath = srcNames[:len(srcNames)-i+1]
+
+ // Does the renameDestPath fully contain the
+ // renameSrcPath here? If yes, then this is a mismatch.
+ // We can't rename the src to some subpath of itself.
+ if len(renameDestPath) > len(renameSrcPath) &&
+ reflect.DeepEqual(renameDestPath[:len(renameSrcPath)], renameSrcPath) {
+ renameDestPath = nil
+ renameSrcPath = nil
+ continue
+ }
+ break
+ }
+ }
+ if len(renameSrcPath) == 0 || len(renameDestPath) == 0 {
+ // This must be a rename to self, or a tricky look-alike. This
+ // happens iff we fail to find a suitable divergence in the two
+ // paths. It's a true self move if the path length is the same.
+ renameDestPath = dstNames
+ renameSrcPath = srcNames
+ selfRename = len(srcNames) == len(dstNames)
+ }
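+
+ // As a worked example (names illustrative): for srcNames
+ // {"one", "two", "file"} and dstNames {"one", "renamed", "file"},
+ // the loop above diverges at i=2, yielding renameSrcPath
+ // {"one", "two"} and renameDestPath {"one", "renamed"}: the
+ // directory "two" is the renamed entry, while "file" merely
+ // receives a Renamed hook.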
+
+ // Walk to the source parent.
+ _, srcParent, err := root.Walk(renameSrcPath[:len(renameSrcPath)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer srcParent.Close()
+ srcParentBackend := h.Pop(srcParent)
+
+ // Walk to the destination parent.
+ _, dstParent, err := root.Walk(renameDestPath[:len(renameDestPath)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer dstParent.Close()
+ dstParentBackend := h.Pop(dstParent)
+
+ // expectedErr is the result of the rename operation.
+ var expectedErr error
+
+ // Walk to the target file, if one exists.
+ dstQID, dst, err := root.Walk(renameDestPath)
+ if err == nil {
+ if !selfRename && srcQID[0].Type == dstQID[0].Type {
+ // If there is a destination file, and it is of the
+ // same type as the source file, then we expect the
+ // rename to succeed. We expect the destination file to
+ // be deleted, so we run a deletion test on it in this
+ // case.
+ defer checkDeleted(h, dst)
+ } else {
+ if !selfRename {
+ // If the type differs from the
+ // destination, then we expect the rename
+ // to fail, and we ensure that this error
+ // is returned.
+ expectedErr = syscall.EINVAL
+ } else {
+ // This is the file being renamed to itself.
+ // This is technically allowed and a no-op, but
+ // all the triggers will fire.
+ }
+ dst.Close()
+ }
+ }
+ dstName := renameDestPath[len(renameDestPath)-1] // Renamed element.
+ srcName := renameSrcPath[len(renameSrcPath)-1] // Renamed element.
+ if expectedErr == nil && !selfRename {
+ // Expect all to be renamed appropriately. Note that if this is
+ // the final file being renamed, then we expect its Renamed hook
+ // to be called with the new parent. If not, then we expect the
+ // Renamed hook to be called, but the parent will remain
+ // unchanged.
+ elements := srcNames[len(renameSrcPath):]
+ expectRenamed(targetBackend, elements, dstParentBackend, dstName)
+ expectRenamed(secondBackend, elements, dstParentBackend, dstName)
+ expectRenamed(thirdBackend, elements, dstParentBackend, dstName)
+
+ // The target parent also holds an active reference, and may be
+ // moved directly or indirectly.
+ if len(elements) > 1 {
+ expectRenamed(targetParentBackend, elements[:len(elements)-1], dstParentBackend, dstName)
+ }
+ }
+
+ // Expect the rename if it's not the same file. Note that like unlink,
+ // renames are always translated to the at variant in the backend.
+ if !selfRename {
+ srcParentBackend.EXPECT().RenameAt(srcName, dstParentBackend, dstName).Return(expectedErr)
+ }
+
+ // Perform the actual rename; everything has been lined up.
+ if err := renameFn(h, srcParent, dstParent, srcName, dstName, selfRename); err != expectedErr {
+ h.t.Fatalf("got rename err %v, want %v", err, expectedErr)
+ }
+}
+
+func renameTest(t *testing.T, srcNames []string, dstNames []string, target fileGenerator) {
+ t.Run(fmt.Sprintf("renameAt(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ renameHelper(h, root, srcNames, dstNames, target, renameAt)
+ })
+ t.Run(fmt.Sprintf("rename(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ renameHelper(h, root, srcNames, dstNames, target, rename)
+ })
+}
+
+func TestRename(t *testing.T) {
+ // In-directory rename, simple case.
+ for name := range newTypeMap(nil) {
+ // Within the root.
+ renameTest(t, []string{name}, []string{"renamed"}, walkHelper)
+ renameTest(t, []string{name}, []string{"renamed"}, walkAndOpenHelper)
+
+ // Within a subdirectory.
+ renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkAndOpenHelper)
+ }
+
+ // ... with created files.
+ renameTest(t, []string{"created"}, []string{"renamed"}, createHelper)
+ renameTest(t, []string{"one", "created"}, []string{"one", "renamed"}, createHelper)
+
+ // Across directories.
+ for name := range newTypeMap(nil) {
+ // Down one level.
+ renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkAndOpenHelper)
+
+ // Up one level.
+ renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkAndOpenHelper)
+
+ // Across at the same level.
+ renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkAndOpenHelper)
+ }
+
+ // ... with created files.
+ renameTest(t, []string{"one", "created"}, []string{"one", "two", "renamed"}, createHelper)
+ renameTest(t, []string{"one", "two", "created"}, []string{"one", "renamed"}, createHelper)
+ renameTest(t, []string{"one", "created"}, []string{"three", "renamed"}, createHelper)
+
+ // Renaming parents.
+ for name := range newTypeMap(nil) {
+ // Rename a parent.
+ renameTest(t, []string{"one", name}, []string{"renamed", name}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"renamed", name}, walkAndOpenHelper)
+
+ // Rename a super parent.
+ renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkHelper)
+ renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkAndOpenHelper)
+ }
+
+ // ... with created files.
+ renameTest(t, []string{"one", "created"}, []string{"renamed", "created"}, createHelper)
+ renameTest(t, []string{"one", "two", "created"}, []string{"renamed", "created"}, createHelper)
+
+ // Over existing files, including itself.
+ for name := range newTypeMap(nil) {
+ for other := range newTypeMap(nil) {
+ // Overwrite the noted file (may be itself).
+ renameTest(t, []string{"one", name}, []string{"one", other}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", other}, walkAndOpenHelper)
+
+ // Overwrite other files in another directory.
+ renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkAndOpenHelper)
+ }
+
+ // Overwrite by moving the parent.
+ renameTest(t, []string{"three", name}, []string{"one", name}, walkHelper)
+ renameTest(t, []string{"three", name}, []string{"one", name}, walkAndOpenHelper)
+
+ // Create over the types.
+ renameTest(t, []string{"one", "created"}, []string{"one", name}, createHelper)
+ renameTest(t, []string{"one", "created"}, []string{"one", "two", name}, createHelper)
+ renameTest(t, []string{"three", "created"}, []string{"one", name}, createHelper)
+ }
+}
+
+func TestRenameInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.Rename(root, invalidName); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestRenameAtInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.RenameAt(invalidName, root, "okay"); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ if err := root.RenameAt("okay", root, invalidName); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+// TestRenameSecondOrder tests that indirect rename targets continue to
+// receive Renamed calls after a rename of their renamed parent. i.e.,
+//
+// 1. Create /one/file
+// 2. Create /directory
+// 3. Rename /one -> /directory/one
+// 4. Rename /directory -> /three/foo
+// 5. The file from (1) should still receive Renamed.
+//
+// This is a regression test for b/135219260.
+func TestRenameSecondOrder(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ rootBackend, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to /one.
+ _, oneBackend, oneFile := walkHelper(h, "one", root)
+ defer oneFile.Close()
+
+ // Walk to and generate /one/file.
+ //
+ // walkHelper re-walks to oneFile, so we need the second backend,
+ // which will also receive Renamed calls.
+ oneSecondBackend, fileBackend, fileFile := walkHelper(h, "file", oneFile)
+ defer fileFile.Close()
+
+ // Walk to and generate /directory.
+ _, directoryBackend, directoryFile := walkHelper(h, "directory", root)
+ defer directoryFile.Close()
+
+ // Rename /one to /directory/one.
+ rootBackend.EXPECT().RenameAt("one", directoryBackend, "one").Return(nil)
+ expectRenamed(oneBackend, []string{}, directoryBackend, "one")
+ expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one")
+ expectRenamed(fileBackend, []string{}, oneBackend, "file")
+ if err := renameAt(h, root, directoryFile, "one", "one", false); err != nil {
+ h.t.Fatalf("got rename err %v, want nil", err)
+ }
+
+ // Walk to /three.
+ _, threeBackend, threeFile := walkHelper(h, "three", root)
+ defer threeFile.Close()
+
+ // Rename /directory to /three/foo.
+ rootBackend.EXPECT().RenameAt("directory", threeBackend, "foo").Return(nil)
+ expectRenamed(directoryBackend, []string{}, threeBackend, "foo")
+ expectRenamed(oneBackend, []string{}, directoryBackend, "one")
+ expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one")
+ expectRenamed(fileBackend, []string{}, oneBackend, "file")
+ if err := renameAt(h, root, threeFile, "directory", "foo", false); err != nil {
+ h.t.Fatalf("got rename err %v, want nil", err)
+ }
+}
+
+func TestReadlink(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, f, err := root.Walk([]string{name})
+ if err != nil {
+ t.Fatalf("walk failed: got %v, wanted nil", err)
+ }
+ defer f.Close()
+ backend := h.Pop(f)
+
+ const symlinkTarget = "symlink-target"
+
+ if backend.Attr.Mode.IsSymlink() {
+ // This should only go through on symlinks.
+ backend.EXPECT().Readlink().Return(symlinkTarget, nil)
+ }
+
+ // Attempt a Readlink operation.
+ target, err := f.Readlink()
+ if err != nil && err != syscall.EINVAL {
+ t.Errorf("readlink got %v, wanted EINVAL", err)
+ } else if err == nil && target != symlinkTarget {
+ t.Errorf("readlink got %v, wanted %v", target, symlinkTarget)
+ }
+ })
+ }
+}
+
+// fdTest is a wrapper around operations that may send file descriptors. This
+// asserts that the file descriptors are working as intended.
+func fdTest(t *testing.T, sendFn func(*fd.FD) *fd.FD) {
+ // Create a pipe that we can read from.
+ r, w, err := os.Pipe()
+ if err != nil {
+ t.Fatalf("unable to create pipe: %v", err)
+ }
+ defer r.Close()
+ defer w.Close()
+
+ // Attempt to send the write end.
+ wFD, err := fd.NewFromFile(w)
+ if err != nil {
+ t.Fatalf("unable to convert file: %v", err)
+ }
+ defer wFD.Close() // This is a copy.
+
+ // Send wFD and receive newFD.
+ newFD := sendFn(wFD)
+ defer newFD.Close()
+
+ // Attempt to write.
+ const message = "hello"
+ if _, err := newFD.Write([]byte(message)); err != nil {
+ t.Fatalf("write got %v, wanted nil", err)
+ }
+
+ // Should see the message on our end.
+ buffer := []byte(message)
+ if _, err := io.ReadFull(r, buffer); err != nil {
+ t.Fatalf("read got %v, wanted nil", err)
+ }
+ if string(buffer) != message {
+ t.Errorf("got message %v, wanted %v", string(buffer), message)
+ }
+}
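+
+// A typical invocation mirrors TestOpen below (a sketch; backend and f are
+// illustrative names): the callback wires the write end into a mock
+// expectation and returns whatever the client receives:
+//
+//   fdTest(t, func(send *fd.FD) *fd.FD {
+//       backend.EXPECT().Open(p9.ReadOnly).Return(send, p9.QID{}, uint32(0), nil)
+//       recv, _, _, err := f.Open(p9.ReadOnly)
+//       if err != nil {
+//           t.Fatalf("open got %v, wanted nil", err)
+//       }
+//       return recv
+//   })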
+
+func TestConnect(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Catch all the non-socket cases.
+ if !backend.Attr.Mode.IsSocket() {
+ // This has been set up to fail if Connect is called.
+ if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL {
+ t.Errorf("connect got %v, wanted EINVAL", err)
+ }
+ return
+ }
+
+ // Ensure the fd exchange works.
+ fdTest(t, func(send *fd.FD) *fd.FD {
+ backend.EXPECT().Connect(p9.ConnectFlags(0)).Return(send, nil)
+ recv, err := f.Connect(p9.ConnectFlags(0))
+ if err != nil {
+ t.Fatalf("connect got %v, wanted nil", err)
+ }
+ return recv
+ })
+ })
+ }
+}
+
+func TestReaddir(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Catch all the non-directory cases.
+ if !backend.Attr.Mode.IsDir() {
+ // This has also been set up to fail if Readdir is called.
+ if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
+ t.Errorf("readdir got %v, wanted EINVAL", err)
+ }
+ return
+ }
+
+ // Readdir should fail until the directory is opened, and
+ // directories may only be opened read-only.
+ if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
+ t.Errorf("readdir got %v, wanted EINVAL", err)
+ }
+ if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR {
+ t.Errorf("open got %v, wanted EISDIR", err)
+ }
+ if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR {
+ t.Errorf("open got %v, wanted EISDIR", err)
+ }
+ backend.EXPECT().Open(p9.ReadOnly).Times(1)
+ if _, _, _, err := f.Open(p9.ReadOnly); err != nil {
+ t.Errorf("open got %v, wanted nil", err)
+ }
+
+ // Readdir should now succeed.
+ backend.EXPECT().Readdir(uint64(0), uint32(1)).Times(1)
+ if _, err := f.Readdir(0, 1); err != nil {
+ t.Errorf("readdir got %v, wanted nil", err)
+ }
+ })
+ }
+}
+
+func TestOpen(t *testing.T) {
+ type openTest struct {
+ name string
+ flags p9.OpenFlags
+ err error
+ match func(p9.FileMode) bool
+ }
+
+ cases := []openTest{
+ {
+ name: "not-openable-read-only",
+ flags: p9.ReadOnly,
+ err: syscall.EINVAL,
+ match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
+ },
+ {
+ name: "not-openable-write-only",
+ flags: p9.WriteOnly,
+ err: syscall.EINVAL,
+ match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
+ },
+ {
+ name: "not-openable-read-write",
+ flags: p9.ReadWrite,
+ err: syscall.EINVAL,
+ match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
+ },
+ {
+ name: "directory-read-only",
+ flags: p9.ReadOnly,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "directory-read-write",
+ flags: p9.ReadWrite,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "directory-write-only",
+ flags: p9.WriteOnly,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "read-only",
+ flags: p9.ReadOnly,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
+ },
+ {
+ name: "write-only",
+ flags: p9.WriteOnly,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "read-write",
+ flags: p9.ReadWrite,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "directory-read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "write-only-truncate",
+ flags: p9.WriteOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "read-write-truncate",
+ flags: p9.ReadWrite | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ }
+
+ // Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
+ // - only works on Regular, NamedPipe, BlockDevice, CharacterDevice
+ // - returning a file works as expected
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Does this match the case?
+ if !tc.match(backend.Attr.Mode) {
+ t.SkipNow()
+ }
+
+ // Ensure open-required operations fail.
+ if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EINVAL {
+ t.Errorf("readAt got %v, wanted EINVAL", err)
+ }
+ if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EINVAL {
+ t.Errorf("writeAt got %v, wanted EINVAL", err)
+ }
+ if err := f.FSync(); err != syscall.EINVAL {
+ t.Errorf("fsync got %v, wanted EINVAL", err)
+ }
+ if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
+ t.Errorf("readdir got %v, wanted EINVAL", err)
+ }
+
+ // Attempt the given open.
+ if tc.err != nil {
+ // We expect an error, just test and return.
+ if _, _, _, err := f.Open(tc.flags); err != tc.err {
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
+ }
+ return
+ }
+
+ // Run an FD test, since we expect success.
+ fdTest(t, func(send *fd.FD) *fd.FD {
+ backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1)
+ recv, _, _, err := f.Open(tc.flags)
+ if err != tc.err {
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
+ }
+ return recv
+ })
+
+ // If the open was successful, attempt another one.
+ if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL {
+ t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err)
+ }
+
+ // Ensure that all illegal operations fail.
+ if _, _, err := f.Walk(nil); err != syscall.EINVAL && err != syscall.EBUSY {
+ t.Errorf("walk got %v, wanted EINVAL or EBUSY", err)
+ }
+ if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EINVAL && err != syscall.EBUSY {
+ t.Errorf("walkgetattr got %v, wanted EINVAL or EBUSY", err)
+ }
+ })
+ }
+ }
+}
+
+func TestClose(t *testing.T) {
+ type closeTest struct {
+ name string
+ closeFn func(backend *Mock, f p9.File)
+ }
+
+ cases := []closeTest{
+ {
+ name: "close",
+ closeFn: func(_ *Mock, f p9.File) {
+ f.Close()
+ },
+ },
+ {
+ name: "remove",
+ closeFn: func(backend *Mock, f p9.File) {
+ // Allow the rename call in the parent, automatically translated.
+ backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1)
+ f.(deprecatedRemover).Remove()
+ },
+ },
+ }
+
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s(%s)", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+
+ // Close via the prescribed method.
+ tc.closeFn(backend, f)
+
+ // Everything should fail with EBADF.
+ if _, _, err := f.Walk(nil); err != syscall.EBADF {
+ t.Errorf("walk got %v, wanted EBADF", err)
+ }
+ if _, err := f.StatFS(); err != syscall.EBADF {
+ t.Errorf("statfs got %v, wanted EBADF", err)
+ }
+ if _, _, _, err := f.GetAttr(p9.AttrMaskAll()); err != syscall.EBADF {
+ t.Errorf("getattr got %v, wanted EBADF", err)
+ }
+ if err := f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}); err != syscall.EBADF {
+ t.Errorf("setattrk got %v, wanted EBADF", err)
+ }
+ if err := f.Rename(root, "new-name"); err != syscall.EBADF {
+ t.Errorf("rename got %v, wanted EBADF", err)
+ }
+ if err := f.Close(); err != syscall.EBADF {
+ t.Errorf("close got %v, wanted EBADF", err)
+ }
+ if _, _, _, err := f.Open(p9.ReadOnly); err != syscall.EBADF {
+ t.Errorf("open got %v, wanted EBADF", err)
+ }
+ if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EBADF {
+ t.Errorf("readAt got %v, wanted EBADF", err)
+ }
+ if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EBADF {
+ t.Errorf("writeAt got %v, wanted EBADF", err)
+ }
+ if err := f.FSync(); err != syscall.EBADF {
+ t.Errorf("fsync got %v, wanted EBADF", err)
+ }
+ if _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 0, 0); err != syscall.EBADF {
+ t.Errorf("create got %v, wanted EBADF", err)
+ }
+ if _, err := f.Mkdir("new-directory", 0, 0, 0); err != syscall.EBADF {
+ t.Errorf("mkdir got %v, wanted EBADF", err)
+ }
+ if _, err := f.Symlink("old-name", "new-name", 0, 0); err != syscall.EBADF {
+ t.Errorf("symlink got %v, wanted EBADF", err)
+ }
+ if err := f.Link(root, "new-name"); err != syscall.EBADF {
+ t.Errorf("link got %v, wanted EBADF", err)
+ }
+ if _, err := f.Mknod("new-block-device", 0, 0, 0, 0, 0); err != syscall.EBADF {
+ t.Errorf("mknod got %v, wanted EBADF", err)
+ }
+ if err := f.RenameAt("old-name", root, "new-name"); err != syscall.EBADF {
+ t.Errorf("renameAt got %v, wanted EBADF", err)
+ }
+ if err := f.UnlinkAt("name", 0); err != syscall.EBADF {
+ t.Errorf("unlinkAt got %v, wanted EBADF", err)
+ }
+ if _, err := f.Readdir(0, 1); err != syscall.EBADF {
+ t.Errorf("readdir got %v, wanted EBADF", err)
+ }
+ if _, err := f.Readlink(); err != syscall.EBADF {
+ t.Errorf("readlink got %v, wanted EBADF", err)
+ }
+ if err := f.Flush(); err != syscall.EBADF {
+ t.Errorf("flush got %v, wanted EBADF", err)
+ }
+ if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EBADF {
+ t.Errorf("walkgetattr got %v, wanted EBADF", err)
+ }
+ if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EBADF {
+ t.Errorf("connect got %v, wanted EBADF", err)
+ }
+ })
+ }
+ }
+}
+
+// onlyWorksOnOpenThings is a helper test method for operations that should
+// only work on files that have been explicitly opened.
+func onlyWorksOnOpenThings(h *Harness, t *testing.T, name string, root p9.File, mode p9.OpenFlags, expectedErr error, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) {
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Does it work before opening?
+ if err := fn(backend, f, false); err != syscall.EINVAL {
+ t.Errorf("operation got %v, wanted EINVAL", err)
+ }
+
+ // Is this openable?
+ if !p9.CanOpen(backend.Attr.Mode) {
+ return // Nothing to do.
+ }
+
+ // If this is a directory, we can't handle writing.
+ if backend.Attr.Mode.IsDir() && (mode == p9.ReadWrite || mode == p9.WriteOnly) {
+ return // Skip.
+ }
+
+ // Open the file.
+ backend.EXPECT().Open(mode)
+ if _, _, _, err := f.Open(mode); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+
+ // Attempt the operation.
+ if err := fn(backend, f, expectedErr == nil); err != expectedErr {
+ t.Fatalf("operation got %v, wanted %v", err, expectedErr)
+ }
+}
+
+func TestRead(t *testing.T) {
+ type readTest struct {
+ name string
+ mode p9.OpenFlags
+ err error
+ }
+
+ cases := []readTest{
+ {
+ name: "read-only",
+ mode: p9.ReadOnly,
+ err: nil,
+ },
+ {
+ name: "read-write",
+ mode: p9.ReadWrite,
+ err: nil,
+ },
+ {
+ name: "write-only",
+ mode: p9.WriteOnly,
+ err: syscall.EPERM,
+ },
+ }
+
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const message = "hello"
+
+ onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if !shouldSucceed {
+ _, err := f.ReadAt([]byte(message), 0)
+ return err
+ }
+
+ // Prepare for the call to readAt in the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ copy(p, message)
+ }).Return(len(message), nil)
+
+ // Make the client call.
+ p := make([]byte, 2*len(message)) // Double size.
+ n, err := f.ReadAt(p, 0)
+
+ // Sanity check result.
+ if err != nil {
+ return err
+ }
+ if n != len(message) {
+ t.Fatalf("message length incorrect, got %d, want %d", n, len(message))
+ }
+ if !bytes.Equal(p[:n], []byte(message)) {
+ t.Fatalf("message incorrect, got %v, want %v", p, []byte(message))
+ }
+ return nil // Success.
+ })
+ })
+ }
+ }
+}
+
+func TestWrite(t *testing.T) {
+ type writeTest struct {
+ name string
+ mode p9.OpenFlags
+ err error
+ }
+
+ cases := []writeTest{
+ {
+ name: "read-only",
+ mode: p9.ReadOnly,
+ err: syscall.EPERM,
+ },
+ {
+ name: "read-write",
+ mode: p9.ReadWrite,
+ err: nil,
+ },
+ {
+ name: "write-only",
+ mode: p9.WriteOnly,
+ err: nil,
+ },
+ }
+
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const message = "hello"
+
+ onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if !shouldSucceed {
+ _, err := f.WriteAt([]byte(message), 0)
+ return err
+ }
+
+ // Prepare for the call to writeAt in the backend.
+ var output []byte // Saved by Do below.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ output = p
+ }).Return(len(message), nil)
+
+ // Make the client call.
+ n, err := f.WriteAt([]byte(message), 0)
+
+ // Sanity check result.
+ if err != nil {
+ return err
+ }
+ if n != len(message) {
+ t.Fatalf("message length incorrect, got %d, want %d", n, len(message))
+ }
+ if !bytes.Equal(output, []byte(message)) {
+ t.Fatalf("message incorrect, got %v, want %v", output, []byte(message))
+ }
+ return nil // Success.
+ })
+ })
+ }
+ }
+}
+
+func TestFSync(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ for _, mode := range []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} {
+ t.Run(fmt.Sprintf("%s-%s", mode, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnOpenThings(h, t, name, root, mode, nil, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().FSync().Times(1)
+ }
+ return f.FSync()
+ })
+ })
+ }
+ }
+}
+
+func TestFlush(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ backend.EXPECT().Flush()
+ f.Flush()
+ })
+ }
+}
+
+// onlyWorksOnDirectories is a helper test method for operations that should
+// only work on unopened directories, such as create, mkdir and symlink.
+func onlyWorksOnDirectories(h *Harness, t *testing.T, name string, root p9.File, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) {
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Only directories support these operations.
+ if !backend.Attr.Mode.IsDir() {
+ if err := fn(backend, f, false); err != syscall.EINVAL {
+ t.Errorf("operation got %v, wanted EINVAL", err)
+ }
+ return // Nothing else to do.
+ }
+
+ // Should succeed.
+ if err := fn(backend, f, true); err != nil {
+ t.Fatalf("operation got %v, wanted nil", err)
+ }
+
+ // Open the directory.
+ backend.EXPECT().Open(p9.ReadOnly).Times(1)
+ if _, _, _, err := f.Open(p9.ReadOnly); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+
+ // Should not work again.
+ if err := fn(backend, f, false); err != syscall.EINVAL {
+ t.Fatalf("operation got %v, wanted EINVAL", err)
+ }
+}
+
+func TestCreate(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if !shouldSucceed {
+ _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 1, 2)
+ return err
+ }
+
+ // If the create is going to succeed, then we
+ // need to create a new backend file, and we
+ // clone to ensure that we don't close the
+ // original.
+ _, newF, err := f.Walk(nil)
+ if err != nil {
+ t.Fatalf("clone got %v, wanted nil", err)
+ }
+ defer newF.Close()
+ newBackend := h.Pop(newF)
+
+ // Run a regular FD test to validate that path.
+ fdTest(t, func(send *fd.FD) *fd.FD {
+ // Return the send FD on success.
+ newFile := h.NewFile()(backend) // New file with the parent backend.
+ newBackend.EXPECT().Create("new-file", p9.ReadWrite, p9.FileMode(0), p9.UID(1), p9.GID(2)).Return(send, newFile, p9.QID{}, uint32(0), nil)
+
+ // Receive the fd back.
+ recv, _, _, _, err := newF.Create("new-file", p9.ReadWrite, 0, 1, 2)
+ if err != nil {
+ t.Fatalf("create got %v, wanted nil", err)
+ }
+ return recv
+ })
+
+ // Any failure above will surface through the
+ // normal test flow, so we can assume success here.
+ return nil
+ })
+ })
+ }
+}
+
+func TestCreateInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if _, _, _, _, err := root.Create(invalidName, p9.ReadWrite, 0, 0, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestMkdir(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Mkdir("new-directory", p9.FileMode(0), p9.UID(1), p9.GID(2))
+ }
+ _, err := f.Mkdir("new-directory", 0, 1, 2)
+ return err
+ })
+ })
+ }
+}
+
+func TestMkdirInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if _, err := root.Mkdir(invalidName, 0, 0, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestSymlink(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Symlink("old-name", "new-name", p9.UID(1), p9.GID(2))
+ }
+ _, err := f.Symlink("old-name", "new-name", 1, 2)
+ return err
+ })
+ })
+ }
+}
+
+func TestSymlinkInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ // We need only test invalid names for the new name;
+ // the target can be an arbitrary string and does not
+ // need to be validated.
+ if _, err := root.Symlink("old-name", invalidName, 0, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestLink(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Link(gomock.Any(), "new-link")
+ }
+ return f.Link(f, "new-link")
+ })
+ })
+ }
+}
+
+func TestLinkInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.Link(root, invalidName); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestMknod(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Mknod("new-block-device", p9.FileMode(0), uint32(1), uint32(2), p9.UID(3), p9.GID(4)).Times(1)
+ }
+ _, err := f.Mknod("new-block-device", 0, 1, 2, 3, 4)
+ return err
+ })
+ })
+ }
+}
+
+// concurrentFn is a specification of a concurrent operation. This is used to
+// drive the concurrency tests below.
+type concurrentFn struct {
+ name string
+ match func(p9.FileMode) bool
+ op func(h *Harness, backend *Mock, f p9.File, callback func())
+}
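+
+// A minimal concurrentFn sketch (mirroring the "flush" entry below): op
+// must arrange for callback() to run while the relevant server-side lock
+// is held, so that concurrentTest can observe whether two such callbacks
+// overlap:
+//
+//   fn := concurrentFn{
+//       name:  "flush",
+//       match: func(p9.FileMode) bool { return true },
+//       op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+//           backend.EXPECT().Flush().Do(func() { callback() })
+//           f.Flush()
+//       },
+//   }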
+
+func concurrentTest(t *testing.T, name string, fn1, fn2 concurrentFn, sameDir, expectedOkay bool) {
+ var (
+ names1 []string
+ names2 []string
+ )
+ if sameDir {
+ // Use the same file, under the directory "one".
+ names1, names2 = []string{"one", name}, []string{"one", name}
+ } else {
+ // For different directories, just use siblings.
+ names1, names2 = []string{"one", name}, []string{"three", name}
+ }
+
+ t.Run(fmt.Sprintf("%s(%v)+%s(%v)", fn1.name, names1, fn2.name, names2), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to both files as given.
+ _, f1, err := root.Walk(names1)
+ if err != nil {
+ t.Fatalf("error walking, got %v, want nil", err)
+ }
+ defer f1.Close()
+ b1 := h.Pop(f1)
+ _, f2, err := root.Walk(names2)
+ if err != nil {
+ t.Fatalf("error walking, got %v, want nil", err)
+ }
+ defer f2.Close()
+ b2 := h.Pop(f2)
+
+ // Are these a good match for the current test case?
+ if !fn1.match(b1.Attr.Mode) {
+ t.SkipNow()
+ }
+ if !fn2.match(b2.Attr.Mode) {
+ t.SkipNow()
+ }
+
+ // Construct our "concurrency creator".
+ in1 := make(chan struct{}, 1)
+ in2 := make(chan struct{}, 1)
+ var top sync.WaitGroup
+ var fns sync.WaitGroup
+ defer top.Wait()
+ top.Add(2) // Accounting for below.
+ defer fns.Done()
+ fns.Add(1) // See line above; released before top.Wait.
+ go func() {
+ defer top.Done()
+ fn1.op(h, b1, f1, func() {
+ in1 <- struct{}{}
+ fns.Wait()
+ })
+ }()
+ go func() {
+ defer top.Done()
+ fn2.op(h, b2, f2, func() {
+ in2 <- struct{}{}
+ fns.Wait()
+ })
+ }()
+
+ // Compute a reasonable timeout. If we expect the operations to
+ // serialize, give them 10 milliseconds before we conclude that the
+ // second one blocked; after all, there will be a lot of these tests.
+ // If we expect them to proceed, give them a full minute, since the
+ // machine could be slow.
+ timeout := 10 * time.Millisecond
+ if expectedOkay {
+ timeout = 1 * time.Minute
+ }
+
+ // Read the first channel.
+ var second chan struct{}
+ select {
+ case <-in1:
+ second = in2
+ case <-in2:
+ second = in1
+ }
+
+ // Catch concurrency.
+ select {
+ case <-second:
+ // Both finished successfully. Is this good? Depends on
+ // the expected result.
+ if !expectedOkay {
+ t.Errorf("%q and %q proceeded concurrently!", fn1.name, fn2.name)
+ }
+ case <-time.After(timeout):
+ // Great, things did not proceed concurrently. Is that what we
+ // expected?
+ if expectedOkay {
+ t.Errorf("%q and %q hung concurrently!", fn1.name, fn2.name)
+ }
+ }
+ })
+}
+
+func randomFileName() string {
+ return fmt.Sprintf("%x", rand.Int63())
+}
+
+func TestConcurrency(t *testing.T) {
+ readExclusive := []concurrentFn{
+ {
+ // N.B. We can't explicitly check WalkGetAttr behavior,
+ // but we rely on the fact that the internal code paths
+ // are the same.
+ name: "walk",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // See the documentation of WalkCallback.
+ // Because walk is actually implemented by the
+ // mock, we need a special place for this
+ // callback.
+ //
+ // Note that a clone actually locks the parent
+ // node. So we walk from this node to test
+ // concurrent operations appropriately.
+ backend.WalkCallback = func() error {
+ callback()
+ return nil
+ }
+ f.Walk([]string{randomFileName()}) // Won't exist.
+ },
+ },
+ {
+ name: "fsync",
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Open(gomock.Any())
+ backend.EXPECT().FSync().Do(func() {
+ callback()
+ })
+ f.Open(p9.ReadOnly) // Required.
+ f.FSync()
+ },
+ },
+ {
+ name: "readdir",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Open(gomock.Any())
+ backend.EXPECT().Readdir(gomock.Any(), gomock.Any()).Do(func(uint64, uint32) {
+ callback()
+ })
+ f.Open(p9.ReadOnly) // Required.
+ f.Readdir(0, 1)
+ },
+ },
+ {
+ name: "readlink",
+ match: func(mode p9.FileMode) bool { return mode.IsSymlink() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Readlink().Do(func() {
+ callback()
+ })
+ f.Readlink()
+ },
+ },
+ {
+ name: "connect",
+ match: func(mode p9.FileMode) bool { return mode.IsSocket() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Connect(gomock.Any()).Do(func(p9.ConnectFlags) {
+ callback()
+ })
+ f.Connect(0)
+ },
+ },
+ {
+ name: "open",
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Open(gomock.Any()).Do(func(p9.OpenFlags) {
+ callback()
+ })
+ f.Open(p9.ReadOnly)
+ },
+ },
+ {
+ name: "flush",
+ match: func(mode p9.FileMode) bool { return true },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Flush().Do(func() {
+ callback()
+ })
+ f.Flush()
+ },
+ },
+ }
+ writeExclusive := []concurrentFn{
+ {
+ // N.B. We can't really check getattr. But it is an
+ // extremely low-risk function, so this check is
+ // likely paranoid anyway.
+ name: "setattr",
+ match: func(mode p9.FileMode) bool { return true },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().SetAttr(gomock.Any(), gomock.Any()).Do(func(p9.SetAttrMask, p9.SetAttr) {
+ callback()
+ })
+ f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{})
+ },
+ },
+ {
+ name: "unlinkAt",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) {
+ callback()
+ })
+ f.UnlinkAt(randomFileName(), 0)
+ },
+ },
+ {
+ name: "mknod",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Mknod(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, uint32, uint32, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Mknod(randomFileName(), 0, 0, 0, 0, 0)
+ },
+ },
+ {
+ name: "link",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Link(gomock.Any(), gomock.Any()).Do(func(p9.File, string) {
+ callback()
+ })
+ f.Link(f, randomFileName())
+ },
+ },
+ {
+ name: "symlink",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Symlink(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, string, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Symlink(randomFileName(), randomFileName(), 0, 0)
+ },
+ },
+ {
+ name: "mkdir",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Mkdir(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Mkdir(randomFileName(), 0, 0, 0)
+ },
+ },
+ {
+ name: "create",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // Return an error for the creation operation, as this is the simplest.
+ backend.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, p9.QID{}, uint32(0), syscall.EINVAL).Do(func(string, p9.OpenFlags, p9.FileMode, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Create(randomFileName(), p9.ReadOnly, 0, 0, 0)
+ },
+ },
+ }
+ globalExclusive := []concurrentFn{
+ {
+ name: "remove",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // Remove operates on a locked parent. So we
+ // add a child, walk to it and call remove.
+ // Note that because this operation can operate
+ // concurrently with itself, we need to
+ // generate a random file name.
+ randomFile := randomFileName()
+ backend.AddChild(randomFile, h.NewFile())
+ defer backend.RemoveChild(randomFile)
+ _, file, err := f.Walk([]string{randomFile})
+ if err != nil {
+ h.t.Fatalf("walk got %v, want nil", err)
+ }
+
+ // Remove is automatically translated to the parent.
+ backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) {
+ callback()
+ })
+
+ // Remove is also a close.
+ file.(deprecatedRemover).Remove()
+ },
+ },
+ {
+ name: "rename",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // Similarly to remove, because we need to
+ // operate on a child, we allow a walk.
+ randomFile := randomFileName()
+ backend.AddChild(randomFile, h.NewFile())
+ defer backend.RemoveChild(randomFile)
+ _, file, err := f.Walk([]string{randomFile})
+ if err != nil {
+ h.t.Fatalf("walk got %v, want nil", err)
+ }
+ defer file.Close()
+ fileBackend := h.Pop(file)
+
+ // Rename is automatically translated to the parent.
+ backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) {
+ callback()
+ })
+
+ // Attempt the rename.
+ fileBackend.EXPECT().Renamed(gomock.Any(), gomock.Any())
+ file.Rename(f, randomFileName())
+ },
+ },
+ {
+ name: "renameAt",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) {
+ callback()
+ })
+
+ // Attempt the rename. There are no active fids
+ // with this name, so we don't need to expect
+ // Renamed hooks on anything.
+ f.RenameAt(randomFileName(), f, randomFileName())
+ },
+ },
+ }
+
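+ // The three loops below cover the full compatibility matrix. In
+ // summary (same-directory / cross-directory parallelism expected):
+ //
+ //  read vs read:          yes / yes
+ //  read+write vs write:   no / yes
+ //  anything vs global:    no / no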
+ for _, fn1 := range readExclusive {
+ for _, fn2 := range readExclusive {
+ for name := range newTypeMap(nil) {
+ // Everything should be able to proceed in parallel.
+ concurrentTest(t, name, fn1, fn2, true, true)
+ concurrentTest(t, name, fn1, fn2, false, true)
+ }
+ }
+ }
+
+ for _, fn1 := range append(readExclusive, writeExclusive...) {
+ for _, fn2 := range writeExclusive {
+ for name := range newTypeMap(nil) {
+ // Only cross-directory functions should proceed in parallel.
+ concurrentTest(t, name, fn1, fn2, true, false)
+ concurrentTest(t, name, fn1, fn2, false, true)
+ }
+ }
+ }
+
+ for _, fn1 := range append(append(readExclusive, writeExclusive...), globalExclusive...) {
+ for _, fn2 := range globalExclusive {
+ for name := range newTypeMap(nil) {
+ // Nothing should be able to run in parallel.
+ concurrentTest(t, name, fn1, fn2, true, false)
+ concurrentTest(t, name, fn1, fn2, false, false)
+ }
+ }
+ }
+}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
new file mode 100644
index 000000000..dd8b01b6d
--- /dev/null
+++ b/pkg/p9/p9test/p9test.go
@@ -0,0 +1,329 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package p9test provides standard mocks for p9.
+package p9test
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Harness is a test harness backed by a mock attacher.
+type Harness struct {
+ t *testing.T
+ mockCtrl *gomock.Controller
+ Attacher *MockAttacher
+ wg sync.WaitGroup
+ clientSocket *unet.Socket
+ mu sync.Mutex
+ created []*Mock
+}
+
+// globalPath is a QID.Path generator.
+var globalPath uint64
+
+// MakePath returns a globally unique path.
+func MakePath() uint64 {
+ return atomic.AddUint64(&globalPath, 1)
+}
+
+// Generator is a function that generates a new file.
+type Generator func(parent *Mock) *Mock
+
+// Mock is a common mock element.
+type Mock struct {
+ p9.DefaultWalkGetAttr
+ *MockFile
+ parent *Mock
+ closed bool
+ harness *Harness
+ QID p9.QID
+ Attr p9.Attr
+ children map[string]Generator
+
+ // WalkCallback is a special function that will be called from within
+ // the walk context. This is needed for the concurrent tests within
+ // this package.
+ WalkCallback func() error
+}
+
+// globalMu protects the children maps in all mocks. Note that this is not a
+// particularly elegant solution, but because the test has walks from the root
+// through to final nodes, we must share maps below, and it's easiest to simply
+// protect against concurrent access globally.
+var globalMu sync.RWMutex
+
+// AddChild adds a new child to the Mock.
+func (m *Mock) AddChild(name string, generator Generator) {
+ globalMu.Lock()
+ defer globalMu.Unlock()
+ m.children[name] = generator
+}
+
+// RemoveChild removes the child with the given name.
+func (m *Mock) RemoveChild(name string) {
+ globalMu.Lock()
+ defer globalMu.Unlock()
+ delete(m.children, name)
+}
+
+// Matches implements gomock.Matcher.Matches.
+func (m *Mock) Matches(x interface{}) bool {
+ if om, ok := x.(*Mock); ok {
+ return m.QID.Path == om.QID.Path
+ }
+ return false
+}
+
+// String implements gomock.Matcher.String.
+func (m *Mock) String() string {
+ return fmt.Sprintf("Mock{Mode: 0x%x, QID.Path: %d}", m.Attr.Mode, m.QID.Path)
+}
+
+// GetAttr returns the current attributes.
+func (m *Mock) GetAttr(mask p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ return m.QID, p9.AttrMaskAll(), m.Attr, nil
+}
+
+// Walk supports clone and walking in directories.
+func (m *Mock) Walk(names []string) ([]p9.QID, p9.File, error) {
+ if m.WalkCallback != nil {
+ if err := m.WalkCallback(); err != nil {
+ return nil, nil, err
+ }
+ }
+ if len(names) == 0 {
+ // Clone the file appropriately.
+ nm := m.harness.NewMock(m.parent, m.QID.Path, m.Attr)
+ nm.children = m.children // Inherit children.
+ return []p9.QID{nm.QID}, nm, nil
+ } else if len(names) != 1 {
+ m.harness.t.Fail() // Should not happen.
+ return nil, nil, syscall.EINVAL
+ }
+
+ if m.Attr.Mode.IsDir() {
+ globalMu.RLock()
+ defer globalMu.RUnlock()
+ if fn, ok := m.children[names[0]]; ok {
+ // Generate the child.
+ nm := fn(m)
+ return []p9.QID{nm.QID}, nm, nil
+ }
+ // No child found.
+ return nil, nil, syscall.ENOENT
+ }
+
+ // Call the underlying mock.
+ return m.MockFile.Walk(names)
+}
+
+// WalkGetAttr calls the default implementation; this is a client-side optimization.
+func (m *Mock) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) {
+ return m.DefaultWalkGetAttr.WalkGetAttr(names)
+}
+
+// Pop pops off the most recently created Mock and asserts that this mock
+// represents the same file passed in. If nil is passed in, no check is
+// performed.
+//
+// Precondition: there must be at least one Mock or this will panic.
+func (h *Harness) Pop(clientFile p9.File) *Mock {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if clientFile == nil {
+ // If no clientFile is provided, then we always return the last
+ // created file. The caller can safely use this as long as
+ // there is no concurrency.
+ m := h.created[len(h.created)-1]
+ h.created = h.created[:len(h.created)-1]
+ return m
+ }
+
+ qid, _, _, err := clientFile.GetAttr(p9.AttrMaskAll())
+ if err != nil {
+ // We do not expect this to happen.
+ panic(fmt.Sprintf("err during Pop: %v", err))
+ }
+
+ // Find the relevant file in our created list. We must scan the list
+ // from back to front to ensure that we favor the most recently
+ // generated file.
+ for i := len(h.created) - 1; i >= 0; i-- {
+ m := h.created[i]
+ if qid.Path == m.QID.Path {
+ // Copy and truncate.
+ copy(h.created[i:], h.created[i+1:])
+ h.created = h.created[:len(h.created)-1]
+ return m
+ }
+ }
+
+ // Unable to find relevant file.
+ panic(fmt.Sprintf("unable to locate file with QID %+v", qid.Path))
+}
+
+// NewMock returns a new base file.
+func (h *Harness) NewMock(parent *Mock, path uint64, attr p9.Attr) *Mock {
+ m := &Mock{
+ MockFile: NewMockFile(h.mockCtrl),
+ parent: parent,
+ harness: h,
+ QID: p9.QID{
+ Type: p9.QIDType((attr.Mode & p9.FileModeMask) >> 12),
+ Path: path,
+ },
+ Attr: attr,
+ }
+
+ // Always ensure the parent's Close is called after this mock's Close.
+ // Note that this can't be done via a straightforward After call,
+ // because the parent might change after initial creation, so we check
+ // the invariant at close time.
+ m.EXPECT().Close().Return(nil).Times(1).Do(func() {
+ if m.parent != nil && m.parent.closed {
+ h.t.FailNow()
+ }
+ // Note that this should not be racy, as this operation should
+ // be protected by the Times(1) above first.
+ m.closed = true
+ })
+
+ // Remember what was created.
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.created = append(h.created, m)
+
+ return m
+}
+
+// NewFile returns a new file mock.
+//
+// Note that ReadAt and WriteAt must be mocked separately.
+func (h *Harness) NewFile() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeRegular})
+ }
+}
+
+// NewDirectory returns a new mock directory.
+//
+// Note that Mkdir, Link, Mknod, RenameAt, UnlinkAt and Readdir must be mocked
+// separately. Walk is provided and children may be manipulated via AddChild
+// and RemoveChild. After calling Walk remotely, one can use Pop to find the
+// corresponding backend mock on the server side.
+func (h *Harness) NewDirectory(contents map[string]Generator) Generator {
+ return func(parent *Mock) *Mock {
+ m := h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeDirectory})
+ m.children = contents // Save contents.
+ return m
+ }
+}
+
+// NewSymlink returns a new mock symlink.
+//
+// Note that Readlink must be mocked separately.
+func (h *Harness) NewSymlink() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSymlink})
+ }
+}
+
+// NewBlockDevice returns a new mock block device.
+func (h *Harness) NewBlockDevice() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeBlockDevice})
+ }
+}
+
+// NewCharacterDevice returns a new mock character device.
+func (h *Harness) NewCharacterDevice() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeCharacterDevice})
+ }
+}
+
+// NewNamedPipe returns a new mock named pipe.
+func (h *Harness) NewNamedPipe() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeNamedPipe})
+ }
+}
+
+// NewSocket returns a new mock socket.
+func (h *Harness) NewSocket() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSocket})
+ }
+}
+
+// Finish completes all checks and shuts down the server.
+func (h *Harness) Finish() {
+ h.clientSocket.Shutdown()
+ h.wg.Wait()
+ h.mockCtrl.Finish()
+}
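+
+// Note the ordering above: shutting down the client socket causes the
+// server goroutine's Handle call to return, wg.Wait ensures that goroutine
+// has fully exited, and only then are the gomock expectations verified.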
+
+// NewHarness creates and returns a new test server.
+//
+// It should always be used as:
+//
+// h, c := NewHarness(t)
+// defer h.Finish()
+//
+func NewHarness(t *testing.T) (*Harness, *p9.Client) {
+ // Create the mock.
+ mockCtrl := gomock.NewController(t)
+ h := &Harness{
+ t: t,
+ mockCtrl: mockCtrl,
+ Attacher: NewMockAttacher(mockCtrl),
+ }
+
+ // Make socket pair.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v wanted nil", err)
+ }
+
+ // Start the server, synchronized on exit.
+ server := p9.NewServer(h.Attacher)
+ h.wg.Add(1)
+ go func() {
+ defer h.wg.Done()
+ server.Handle(serverSocket)
+ }()
+
+ // Create the client.
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
+ if err != nil {
+ serverSocket.Close()
+ clientSocket.Close()
+ t.Fatalf("new client got %v, expected nil", err)
+ return nil, nil // Never hit.
+ }
+
+ // Capture the client socket.
+ h.clientSocket = clientSocket
+ return h, client
+}
diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go
new file mode 100644
index 000000000..72ef53313
--- /dev/null
+++ b/pkg/p9/path_tree.go
@@ -0,0 +1,222 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// pathNode is a single node in a path traversal.
+//
+// These are shared by all fidRefs that point to the same path.
+//
+// Lock ordering:
+// opMu
+// childMu
+//
+// Two different pathNodes may only be locked if Server.renameMu is held for
+// write, in which case they can be acquired in any order.
+type pathNode struct {
+ // opMu synchronizes high-level, semantic operations, such as the
+ // simultaneous creation and deletion of a file.
+ //
+ // opMu does not directly protect any fields in pathNode.
+ opMu sync.RWMutex
+
+ // childMu protects the fields below.
+ childMu sync.RWMutex
+
+ // childNodes maps child path component names to their pathNode.
+ childNodes map[string]*pathNode
+
+ // childRefs maps child path component names to all of their
+ // references.
+ childRefs map[string]map[*fidRef]struct{}
+
+ // childRefNames maps child references back to their path component
+ // name.
+ childRefNames map[*fidRef]string
+}
+
+func newPathNode() *pathNode {
+ return &pathNode{
+ childNodes: make(map[string]*pathNode),
+ childRefs: make(map[string]map[*fidRef]struct{}),
+ childRefNames: make(map[*fidRef]string),
+ }
+}
+
+// forEachChildRef calls fn for each child reference.
+func (p *pathNode) forEachChildRef(fn func(ref *fidRef, name string)) {
+ p.childMu.RLock()
+ defer p.childMu.RUnlock()
+
+ for name, m := range p.childRefs {
+ for ref := range m {
+ fn(ref, name)
+ }
+ }
+}
+
+// forEachChildNode calls fn for each child pathNode.
+func (p *pathNode) forEachChildNode(fn func(pn *pathNode)) {
+ p.childMu.RLock()
+ defer p.childMu.RUnlock()
+
+ for _, pn := range p.childNodes {
+ fn(pn)
+ }
+}
+
+// pathNodeFor returns the path node for the given name, or a new one.
+func (p *pathNode) pathNodeFor(name string) *pathNode {
+ p.childMu.RLock()
+ // Fast path, node already exists.
+ if pn, ok := p.childNodes[name]; ok {
+ p.childMu.RUnlock()
+ return pn
+ }
+ p.childMu.RUnlock()
+
+ // Slow path, create a new pathNode for shared use.
+ p.childMu.Lock()
+
+ // Re-check after re-lock.
+ if pn, ok := p.childNodes[name]; ok {
+ p.childMu.Unlock()
+ return pn
+ }
+
+ pn := newPathNode()
+ p.childNodes[name] = pn
+ p.childMu.Unlock()
+ return pn
+}
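+
+// The double-checked locking in pathNodeFor guarantees a single shared node
+// per name: concurrent callers may race to the write lock, but the re-check
+// ensures that only one pathNode is ever installed. For example:
+//
+//  a := p.pathNodeFor("child") // May create the node.
+//  b := p.pathNodeFor("child") // Always finds the same node.
+//  // a == b holds, regardless of interleaving.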
+
+// nameFor returns the name for the given fidRef.
+//
+// Precondition: addChild is called for ref before nameFor.
+func (p *pathNode) nameFor(ref *fidRef) string {
+ p.childMu.RLock()
+ n, ok := p.childRefNames[ref]
+ p.childMu.RUnlock()
+
+ if !ok {
+ // This should not happen, don't proceed.
+ panic(fmt.Sprintf("expected name for %+v, none found", ref))
+ }
+
+ return n
+}
+
+// addChildLocked adds a child reference to p.
+//
+// Precondition: As addChild, plus childMu is locked for write.
+func (p *pathNode) addChildLocked(ref *fidRef, name string) {
+ if n, ok := p.childRefNames[ref]; ok {
+ // This should not happen, don't proceed.
+ panic(fmt.Sprintf("unexpected fidRef %+v with path %q, wanted %q", ref, n, name))
+ }
+
+ p.childRefNames[ref] = name
+
+ m, ok := p.childRefs[name]
+ if !ok {
+ m = make(map[*fidRef]struct{})
+ p.childRefs[name] = m
+ }
+
+ m[ref] = struct{}{}
+}
+
+// addChild adds a child reference to p.
+//
+// Precondition: ref must not already be a child of p.
+func (p *pathNode) addChild(ref *fidRef, name string) {
+ p.childMu.Lock()
+ p.addChildLocked(ref, name)
+ p.childMu.Unlock()
+}
+
+// removeChild removes the given child.
+//
+// This applies only to an individual fidRef, which is not required to exist.
+func (p *pathNode) removeChild(ref *fidRef) {
+ p.childMu.Lock()
+
+ // This ref may not exist anymore. This can occur, e.g., in unlink,
+ // where a removeWithName removes the ref, and then a DecRef on the ref
+ // attempts to remove again.
+ if name, ok := p.childRefNames[ref]; ok {
+ m, ok := p.childRefs[name]
+ if !ok {
+ // This should not happen, don't proceed.
+ p.childMu.Unlock()
+ panic(fmt.Sprintf("name %s missing from childfidRefs", name))
+ }
+
+ delete(m, ref)
+ if len(m) == 0 {
+ delete(p.childRefs, name)
+ }
+ }
+
+ delete(p.childRefNames, ref)
+
+ p.childMu.Unlock()
+}
+
+// addPathNodeFor adds an existing pathNode as the node for name.
+//
+// Precondition: a pathNode for name must not already exist.
+func (p *pathNode) addPathNodeFor(name string, pn *pathNode) {
+ p.childMu.Lock()
+
+ if opn, ok := p.childNodes[name]; ok {
+ p.childMu.Unlock()
+ panic(fmt.Sprintf("unexpected pathNode %+v with path %q", opn, name))
+ }
+
+ p.childNodes[name] = pn
+ p.childMu.Unlock()
+}
+
+// removeWithName removes all references with the given name.
+//
+// The provided function is executed after reference removal. The only method
+// it may (transitively) call on this pathNode is addChildLocked.
+//
+// If a child pathNode for name exists, it is removed from this pathNode and
+// returned by this function. Any operations on the removed tree must use this
+// value.
+func (p *pathNode) removeWithName(name string, fn func(ref *fidRef)) *pathNode {
+ p.childMu.Lock()
+ defer p.childMu.Unlock()
+
+ if m, ok := p.childRefs[name]; ok {
+ for ref := range m {
+ delete(m, ref)
+ delete(p.childRefNames, ref)
+ fn(ref)
+ }
+ }
+
+ // Return the original path node, if it exists.
+ origPathNode := p.childNodes[name]
+ delete(p.childNodes, name)
+ return origPathNode
+}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
new file mode 100644
index 000000000..60cf94fa1
--- /dev/null
+++ b/pkg/p9/server.go
@@ -0,0 +1,694 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "io"
+ "runtime/debug"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Server is a 9p2000.L server.
+type Server struct {
+ // attacher provides the attach function.
+ attacher Attacher
+
+ // pathTree is the full set of paths opened on this server.
+ //
+ // These may be across different connections, but rename operations
+ // must be serialized globally for safety. There is a single pathTree
+ // for the entire server, not one per connection.
+ pathTree *pathNode
+
+ // renameMu is a global lock protecting rename operations. With this
+ // lock, we can be certain that any given rename operation can safely
+ // acquire two path nodes in any order, as all other concurrent
+ // operations acquire at most a single node.
+ renameMu sync.RWMutex
+}
+
+// NewServer returns a new server.
+func NewServer(attacher Attacher) *Server {
+ return &Server{
+ attacher: attacher,
+ pathTree: newPathNode(),
+ }
+}
+
+// connState is the state for a single connection.
+type connState struct {
+ // server is the backing server.
+ server *Server
+
+ // sendMu is the send lock.
+ sendMu sync.Mutex
+
+ // conn is the connection.
+ conn *unet.Socket
+
+ // fids is the set of active FIDs.
+ //
+ // This is used to find FIDs for files.
+ fidMu sync.Mutex
+ fids map[FID]*fidRef
+
+ // tags is the set of active tags.
+ //
+ // The associated channel is closed when
+ // the tag has finished processing.
+ tagMu sync.Mutex
+ tags map[Tag]chan struct{}
+
+ // messageSize is the maximum message size. The server does not
+ // do automatic splitting of messages.
+ messageSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // -- below relates to the legacy handler --
+
+ // recvOkay indicates that a receive may start.
+ recvOkay chan bool
+
+ // recvDone is signalled when a message is received.
+ recvDone chan error
+
+ // sendDone is signalled when a send is finished.
+ sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
+}
+
+// fidRef wraps a node and tracks references.
+type fidRef struct {
+ // server is the associated server.
+ server *Server
+
+ // file is the associated File.
+ file File
+
+ // refs is an active reference count.
+ //
+ // The node above will be closed only when refs reaches zero.
+ refs int64
+
+ // openedMu protects opened and openFlags.
+ openedMu sync.Mutex
+
+ // opened indicates whether this has been opened already.
+ //
+ // This is updated in handlers.go.
+ opened bool
+
+ // mode is the fidRef's mode from the walk. Only the type bits are
+ // valid, the permissions may change. This is used to sanity check
+ // operations on this element, and prevent walks across
+ // non-directories.
+ mode FileMode
+
+ // openFlags is the mode used in the open.
+ //
+ // This is updated in handlers.go.
+ openFlags OpenFlags
+
+ // pathNode is the current pathNode for this FID.
+ pathNode *pathNode
+
+ // parent is the parent fidRef. We hold on to a parent reference to
+ // ensure that hooks, such as Renamed, can be executed safely by the
+ // server code.
+ //
+ // Note that parent cannot be changed without holding both the global
+ // rename lock and a writable lock on the associated pathNode for this
+ // fidRef. Holding either of these locks is sufficient to examine
+ // parent safely.
+ //
+ // The parent will be nil for root fidRefs, and non-nil otherwise. The
+ // method maybeParent can be used to return a cyclic reference, and
+ // isRoot should be used to check for the root rather than examining
+ // parent directly.
+ parent *fidRef
+
+ // deleted indicates that the backing file has been deleted. We stop
+ // many operations at the API level if they are incompatible with a
+ // file that has already been unlinked.
+ deleted uint32
+}
+
+// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously.
+func (f *fidRef) OpenFlags() (OpenFlags, bool) {
+ f.openedMu.Lock()
+ defer f.openedMu.Unlock()
+ return f.openFlags, f.opened
+}
+
+// IncRef increases the references on a fid.
+func (f *fidRef) IncRef() {
+ atomic.AddInt64(&f.refs, 1)
+}
+
+// DecRef should be called when you're finished with a fid.
+func (f *fidRef) DecRef() {
+ if atomic.AddInt64(&f.refs, -1) == 0 {
+ f.file.Close()
+
+ // Drop the parent reference.
+ //
+ // Since this fidRef is guaranteed to be non-discoverable when
+ // the references reach zero, we don't need to worry about
+ // clearing the parent.
+ if f.parent != nil {
+ // If we've been previously deleted, removing this
+ // ref is a no-op. That's expected.
+ f.parent.pathNode.removeChild(f)
+ f.parent.DecRef()
+ }
+ }
+}
+
+// isDeleted returns true if this fidRef has been deleted.
+func (f *fidRef) isDeleted() bool {
+ return atomic.LoadUint32(&f.deleted) != 0
+}
+
+// isRoot indicates whether this is a root fid.
+func (f *fidRef) isRoot() bool {
+ return f.parent == nil
+}
+
+// maybeParent returns a cyclic reference for roots, and the parent otherwise.
+func (f *fidRef) maybeParent() *fidRef {
+ if f.parent != nil {
+ return f.parent
+ }
+ return f // Root has itself.
+}
+
+// notifyDelete marks all fidRefs as deleted.
+//
+// Precondition: this must be called via safelyWrite or safelyGlobal.
+func notifyDelete(pn *pathNode) {
+ // Call on all local references.
+ pn.forEachChildRef(func(ref *fidRef, _ string) {
+ atomic.StoreUint32(&ref.deleted, 1)
+ })
+
+ // Call on all subtrees.
+ pn.forEachChildNode(func(pn *pathNode) {
+ notifyDelete(pn)
+ })
+}
+
+// markChildDeleted marks all children below the given name as deleted.
+//
+// Precondition: this must be called via safelyWrite or safelyGlobal.
+func (f *fidRef) markChildDeleted(name string) {
+ origPathNode := f.pathNode.removeWithName(name, func(ref *fidRef) {
+ atomic.StoreUint32(&ref.deleted, 1)
+ })
+
+ if origPathNode != nil {
+ // Mark all children as deleted.
+ notifyDelete(origPathNode)
+ }
+}
+
+// notifyNameChange calls the relevant Renamed method on all nodes in the path,
+// recursively. Note that this applies only for subtrees, as these
+// notifications do not apply to the actual file whose name has changed.
+//
+// Precondition: this must be called via safelyGlobal.
+func notifyNameChange(pn *pathNode) {
+ // Call on all local references.
+ pn.forEachChildRef(func(ref *fidRef, name string) {
+ ref.file.Renamed(ref.parent.file, name)
+ })
+
+ // Call on all subtrees.
+ pn.forEachChildNode(func(pn *pathNode) {
+ notifyNameChange(pn)
+ })
+}
+
+// renameChildTo renames the given child to the target.
+//
+// Precondition: this must be called via safelyGlobal.
+func (f *fidRef) renameChildTo(oldName string, target *fidRef, newName string) {
+ target.markChildDeleted(newName)
+ origPathNode := f.pathNode.removeWithName(oldName, func(ref *fidRef) {
+ // N.B. DecRef can take f.pathNode's parent's childMu. This is
+ // allowed because renameMu is held for write via safelyGlobal.
+ ref.parent.DecRef() // Drop original reference.
+ ref.parent = target // Change parent.
+ ref.parent.IncRef() // Acquire new one.
+ if f.pathNode == target.pathNode {
+ target.pathNode.addChildLocked(ref, newName)
+ } else {
+ target.pathNode.addChild(ref, newName)
+ }
+ ref.file.Renamed(target.file, newName)
+ })
+
+ if origPathNode != nil {
+ // Replace the previous (now deleted) path node.
+ target.pathNode.addPathNodeFor(newName, origPathNode)
+ // Call Renamed on all children.
+ notifyNameChange(origPathNode)
+ }
+}
+
+// safelyRead executes the given operation with the local path node locked
+// for reading. This implies that paths will not change during the operation.
+func (f *fidRef) safelyRead(fn func() error) (err error) {
+ f.server.renameMu.RLock()
+ defer f.server.renameMu.RUnlock()
+ f.pathNode.opMu.RLock()
+ defer f.pathNode.opMu.RUnlock()
+ return fn()
+}
+
+// safelyWrite executes the given operation with the local path node locked in
+// a writable fashion. This implies some paths may change.
+func (f *fidRef) safelyWrite(fn func() error) (err error) {
+ f.server.renameMu.RLock()
+ defer f.server.renameMu.RUnlock()
+ f.pathNode.opMu.Lock()
+ defer f.pathNode.opMu.Unlock()
+ return fn()
+}
+
+// safelyGlobal executes the given operation with the global path lock held.
+func (f *fidRef) safelyGlobal(fn func() error) (err error) {
+ f.server.renameMu.Lock()
+ defer f.server.renameMu.Unlock()
+ return fn()
+}
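+
+// A hedged sketch of intended use (the helper name below is illustrative,
+// not part of this file): a read-only operation wraps its work in
+// safelyRead, a same-directory mutation in safelyWrite, and a
+// cross-directory rename in safelyGlobal, the only helper under which two
+// pathNodes may be locked at once:
+//
+//  err := ref.safelyWrite(func() error {
+//      // ref.pathNode.opMu is held for write here, so names under
+//      // ref cannot be concurrently created or deleted.
+//      return createChild(ref) // Hypothetical helper.
+//  })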
+
+// LookupFID finds the given FID.
+//
+// You should call fid.DecRef when you are finished using the fid.
+func (cs *connState) LookupFID(fid FID) (*fidRef, bool) {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ fidRef, ok := cs.fids[fid]
+ if ok {
+ fidRef.IncRef()
+ return fidRef, true
+ }
+ return nil, false
+}
+
+// InsertFID installs the given FID.
+//
+// This fid starts with a reference count of one. If a FID exists in
+// the slot already it is closed, per the specification.
+func (cs *connState) InsertFID(fid FID, newRef *fidRef) {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ origRef, ok := cs.fids[fid]
+ if ok {
+ defer origRef.DecRef()
+ }
+ newRef.IncRef()
+ cs.fids[fid] = newRef
+}
+
+// DeleteFID removes the given FID.
+//
+// This simply removes it from the map and drops a reference.
+func (cs *connState) DeleteFID(fid FID) bool {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ fidRef, ok := cs.fids[fid]
+ if !ok {
+ return false
+ }
+ delete(cs.fids, fid)
+ fidRef.DecRef()
+ return true
+}
+
+// StartTag starts handling the tag.
+//
+// False is returned if this tag is already active.
+func (cs *connState) StartTag(t Tag) bool {
+ cs.tagMu.Lock()
+ defer cs.tagMu.Unlock()
+ _, ok := cs.tags[t]
+ if ok {
+ return false
+ }
+ cs.tags[t] = make(chan struct{})
+ return true
+}
+
+// ClearTag finishes handling a tag.
+func (cs *connState) ClearTag(t Tag) {
+ cs.tagMu.Lock()
+ defer cs.tagMu.Unlock()
+ ch, ok := cs.tags[t]
+ if !ok {
+ // Should never happen.
+ panic("unused tag cleared")
+ }
+ delete(cs.tags, t)
+
+ // Notify.
+ close(ch)
+}
+
+// WaitTag waits for a tag to finish.
+func (cs *connState) WaitTag(t Tag) {
+ cs.tagMu.Lock()
+ ch, ok := cs.tags[t]
+ cs.tagMu.Unlock()
+ if !ok {
+ return
+ }
+
+ // Wait for close.
+ <-ch
+}
+
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ if err := res.service(cs); err != nil {
+ // Don't log flipcall.ShutdownErrors, which we expect to be
+ // returned during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("p9.channel.service: %v", err)
+ }
+ }
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel looks up the channel with given id.
+//
+// The function returns nil if no such channel is available.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+ if id >= uint32(len(cs.channels)) {
+ return nil
+ }
+ return cs.channels[id]
+}
+
+// handle handles a single message.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ err := recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %v\n%s", err, debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
+// handleRequest handles a single request.
+//
+// The recvDone channel is signaled when recv is done (with an error if
+// necessary). The sendDone channel is signaled with the result of the send.
+func (cs *connState) handleRequest() {
+ messageSize := atomic.LoadUint32(&cs.messageSize)
+ if messageSize == 0 {
+ // Default or not yet negotiated.
+ messageSize = maximumLength
+ }
+
+ // Receive a message.
+ tag, m, err := recv(cs.conn, messageSize, msgRegistry.get)
+ if errSocket, ok := err.(ErrSocket); ok {
+ // Connection problem; stop serving.
+ cs.recvDone <- errSocket.error
+ return
+ }
+
+ // Signal receive is done.
+ cs.recvDone <- nil
+
+ // Deal with other errors.
+ if err != nil && err != io.EOF {
+ // If it's not a connection error, but some other protocol error,
+ // we can send a response immediately.
+ cs.sendMu.Lock()
+ err := send(cs.conn, tag, newErr(err))
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
+ return
+ }
+
+ // Try to start the tag.
+ if !cs.StartTag(tag) {
+ // Nothing we can do at this point; client is bogus.
+ log.Debugf("no valid tag [%05d]", tag)
+ cs.sendDone <- ErrNoValidMessage
+ return
+ }
+
+ // Handle the message.
+ r := cs.handle(m)
+
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
+
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
+
+ // Return the message to the cache.
+ msgRegistry.put(m)
+}
+
+func (cs *connState) handleRequests() {
+ for range cs.recvOkay {
+ cs.handleRequest()
+ }
+}
+
+func (cs *connState) stop() {
+ // Close all channels.
+ close(cs.recvOkay)
+ close(cs.recvDone)
+ close(cs.sendDone)
+
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
+ // Drop final reference in the FID table. Note this should
+ // always close the file, since we've ensured that there are no
+ // handlers still running via the pending-handler wait in service.
+ fidRef.DecRef()
+ }
+
+ // Ensure the connection is closed.
+ cs.conn.Close()
+}
+
+// service services requests concurrently.
+func (cs *connState) service() error {
+ // Pending is the number of handlers that have finished receiving but
+ // not finished processing requests. These must be properly waited on
+ // below. See the next comment for an explanation of the loop.
+ pending := 0
+
+ // Start the first request handler.
+ go cs.handleRequests() // S/R-SAFE: Irrelevant.
+ cs.recvOkay <- true
+
+ // We loop and make sure there's always one goroutine waiting for a new
+ // request. We process all the data for a single request in one
+ // goroutine however, to ensure the best turnaround time possible.
+ for {
+ select {
+ case err := <-cs.recvDone:
+ if err != nil {
+ // Wait for pending handlers.
+ for i := 0; i < pending; i++ {
+ <-cs.sendDone
+ }
+ return nil
+ }
+
+ // This handler is now pending.
+ pending++
+
+ // Kick the next receiver, or start a new handler
+ // if no receiver is currently waiting.
+ select {
+ case cs.recvOkay <- true:
+ default:
+ go cs.handleRequests() // S/R-SAFE: Irrelevant.
+ cs.recvOkay <- true
+ }
+
+ case <-cs.sendDone:
+ // This handler is finished.
+ pending--
+
+ // Error sending a response? Nothing can be done.
+ //
+ // We don't terminate on a send error though, since
+ // we still have a pending receive. The error would
+ // have been logged above, we just ignore it here.
+ }
+ }
+}
+
+// Handle handles a single connection.
+func (s *Server) Handle(conn *unet.Socket) error {
+ cs := &connState{
+ server: s,
+ conn: conn,
+ fids: make(map[FID]*fidRef),
+ tags: make(map[Tag]chan struct{}),
+ recvOkay: make(chan bool),
+ recvDone: make(chan error, 10),
+ sendDone: make(chan error, 10),
+ }
+ defer cs.stop()
+ return cs.service()
+}
+
+// Serve handles requests from the bound socket.
+//
+// The passed serverSocket _must_ be created in packet mode.
+func (s *Server) Serve(serverSocket *unet.ServerSocket) error {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ for {
+ conn, err := serverSocket.Accept()
+ if err != nil {
+ // Something went wrong.
+ //
+ // Socket closed?
+ return err
+ }
+
+ wg.Add(1)
+ go func(conn *unet.Socket) { // S/R-SAFE: Irrelevant.
+ s.Handle(conn)
+ wg.Done()
+ }(conn)
+ }
+}
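+
+// A minimal usage sketch (hedged: myAttacher is a hypothetical Attacher
+// implementation, and the bind call assumes unet exposes a packet-mode
+// listening socket constructor):
+//
+//  serverSocket, err := unet.BindAndListen("/tmp/p9.sock", true)
+//  if err != nil {
+//      return err
+//  }
+//  return NewServer(myAttacher).Serve(serverSocket)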
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
new file mode 100644
index 000000000..7cec0e86d
--- /dev/null
+++ b/pkg/p9/transport.go
@@ -0,0 +1,345 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// ErrSocket is returned in cases of a socket issue.
+//
+// This may be treated differently than other errors.
+type ErrSocket struct {
+ // error is the socket error.
+ error
+}
+
+// ErrMessageTooLarge indicates the size was larger than reasonable.
+type ErrMessageTooLarge struct {
+ size uint32
+ msize uint32
+}
+
+// Error returns a sensible error.
+func (e *ErrMessageTooLarge) Error() string {
+ return fmt.Sprintf("message too large for fixed buffer: size is %d, limit is %d", e.size, e.msize)
+}
+
+// ErrNoValidMessage indicates no valid message could be decoded.
+var ErrNoValidMessage = errors.New("buffer contained no valid message")
+
+const (
+ // headerLength is the number of bytes required for a header.
+ headerLength uint32 = 7
+
+ // maximumLength is the largest possible message.
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
+
+ // initialBufferLength is the initial data buffer we allocate.
+ initialBufferLength uint32 = 64
+)
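+
+// A frame on the wire is therefore laid out as follows (field sizes follow
+// from the Write32/WriteMsgType/WriteTag calls in send below):
+//
+//  size[4] type[1] tag[2] data[size-7]
+//
+// where size includes the 7-byte header itself.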
+
+var dataPool = sync.Pool{
+ New: func() interface{} {
+ // These buffers are used for decoding without a payload.
+ return make([]byte, initialBufferLength)
+ },
+}
+
+// send sends the given message over the socket.
+func send(s *unet.Socket, tag Tag, m message) error {
+ data := dataPool.Get().([]byte)
+ dataBuf := buffer{data: data[:0]}
+
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
+ }
+
+ // Encode the message. The buffer will grow automatically.
+ m.encode(&dataBuf)
+
+ // Get our vectors to send.
+ var hdr [headerLength]byte
+ vecs := make([][]byte, 0, 3)
+ vecs = append(vecs, hdr[:])
+ if len(dataBuf.data) > 0 {
+ vecs = append(vecs, dataBuf.data)
+ }
+ totalLength := headerLength + uint32(len(dataBuf.data))
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ if len(p) > 0 {
+ vecs = append(vecs, p)
+ totalLength += uint32(len(p))
+ }
+ }
+
+ // Construct the header.
+ headerBuf := buffer{data: hdr[:0]}
+ headerBuf.Write32(totalLength)
+ headerBuf.WriteMsgType(m.Type())
+ headerBuf.WriteTag(tag)
+
+ // Pack any files if necessary.
+ w := s.Writer(true)
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ defer f.Close()
+ // Pack the file into the message.
+ w.PackFDs(f.FD())
+ }
+ }
+
+ for n := 0; n < int(totalLength); {
+ cur, err := w.WriteVec(vecs)
+ if err != nil {
+ return ErrSocket{err}
+ }
+ n += cur
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(vecs[0]) <= cur-consumed {
+ consumed += len(vecs[0])
+ vecs = vecs[1:]
+ } else {
+ vecs[0] = vecs[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && n < int(totalLength) {
+ // Don't resend any control message.
+ w.UnpackFDs()
+ }
+ }
+
+ // All set.
+ dataPool.Put(dataBuf.data)
+ return nil
+}
+
+// lookupTagAndType looks up an existing message or creates a new one.
+//
+// This is called by recv after decoding the header. Any error returned will be
+// propagated back to the caller. You may use messageByType directly as a
+// lookupTagAndType function (by design).
+type lookupTagAndType func(tag Tag, t MsgType) (message, error)
+
+// recv decodes a message from the socket.
+//
+// This is done in two parts, and is thus not safe for multiple callers.
+//
+// On a socket error, the special error type ErrSocket is returned.
+//
+// The tag value NoTag will always be returned if err is non-nil.
+func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) {
+ // Read a header.
+ //
+ // Since the send above is atomic, we must always receive control
+ // messages along with the header. This means we need to be careful
+ // about closing FDs during errors to prevent leaks.
+ var hdr [headerLength]byte
+ r := s.Reader(true)
+ r.EnableFDs(1)
+
+ n, err := r.ReadVec([][]byte{hdr[:]})
+ if err != nil && (n == 0 || err != io.EOF) {
+ r.CloseFDs()
+ return NoTag, nil, ErrSocket{err}
+ }
+
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ return NoTag, nil, ErrSocket{err}
+ }
+ defer func() {
+ // Close anything left open. The case where
+ // fds are caught and used is handled below,
+ // and the fds variable will be set to nil.
+ for _, fd := range fds {
+ syscall.Close(fd)
+ }
+ }()
+ r.EnableFDs(0)
+
+ // Continue reading if the header is short.
+ for n < int(headerLength) {
+ cur, err := r.ReadVec([][]byte{hdr[n:]})
+ if err != nil && (cur == 0 || err != io.EOF) {
+ return NoTag, nil, ErrSocket{err}
+ }
+ n += cur
+ }
+
+ // Decode the header.
+ headerBuf := buffer{data: hdr[:]}
+ size := headerBuf.Read32()
+ t := headerBuf.ReadMsgType()
+ tag := headerBuf.ReadTag()
+ if size < headerLength {
+ // The message is too small.
+ //
+ // See above: the stream is probably corrupted.
+ return NoTag, nil, ErrSocket{ErrNoValidMessage}
+ }
+ if size > maximumLength || size > msize {
+ // The message is too big.
+ return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, msize}}
+ }
+ remaining := size - headerLength
+
+ // Find our message to decode.
+ m, err := lookup(tag, t)
+ if err != nil {
+ // Throw away the contents of this message.
+ if remaining > 0 {
+ io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
+ }
+ return tag, nil, err
+ }
+
+ // Not yet initialized.
+ var dataBuf buffer
+
+ // Read the rest of the payload.
+ //
+ // This requires some special care to ensure that the vectors all line
+ // up the way they should. We do this to minimize copying data around.
+ var vecs [][]byte
+ if payloader, ok := m.(payloader); ok {
+ fixedSize := payloader.FixedSize()
+
+ // Do we need more than there is?
+ if fixedSize > remaining {
+ // This is not a valid message.
+ if remaining > 0 {
+ io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
+ }
+ return NoTag, nil, ErrNoValidMessage
+ }
+
+ if fixedSize != 0 {
+ // Pull a data buffer from the pool.
+ data := dataPool.Get().([]byte)
+ if int(fixedSize) > len(data) {
+ // Create a larger data buffer, ensuring
+ // sufficient capacity for the message.
+ data = make([]byte, fixedSize)
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ } else {
+ // Limit the data buffer, and make sure it
+ // gets filled before the payload buffer.
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data[:fixedSize]}
+ vecs = append(vecs, data[:fixedSize])
+ }
+ }
+
+ // Include the payload.
+ p := payloader.Payload()
+ if p == nil || len(p) != int(remaining-fixedSize) {
+ p = make([]byte, remaining-fixedSize)
+ payloader.SetPayload(p)
+ }
+ if len(p) > 0 {
+ vecs = append(vecs, p)
+ }
+ } else if remaining != 0 {
+ // Pull a data buffer from the pool.
+ data := dataPool.Get().([]byte)
+ if int(remaining) > len(data) {
+ // Create a larger data buffer.
+ data = make([]byte, remaining)
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ } else {
+ // Limit the data buffer.
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data[:remaining]}
+ vecs = append(vecs, data[:remaining])
+ }
+ }
+
+ if len(vecs) > 0 {
+ // Read the rest of the message.
+ //
+ // No need to handle a control message.
+ r := s.Reader(true)
+ for n := 0; n < int(remaining); {
+ cur, err := r.ReadVec(vecs)
+ if err != nil && (cur == 0 || err != io.EOF) {
+ return NoTag, nil, ErrSocket{err}
+ }
+ n += cur
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(vecs[0]) <= cur-consumed {
+ consumed += len(vecs[0])
+ vecs = vecs[1:]
+ } else {
+ vecs[0] = vecs[0][cur-consumed:]
+ break
+ }
+ }
+ }
+ }
+
+ // Decode the message data.
+ m.decode(&dataBuf)
+ if dataBuf.isOverrun() {
+ // No need to drain the socket.
+ return NoTag, nil, ErrNoValidMessage
+ }
+
+ // Save the file, if any came out.
+ if filer, ok := m.(filer); ok && len(fds) > 0 {
+ // Set the file object.
+ filer.SetFilePayload(fd.New(fds[0]))
+
+ // Close the rest. We support only one.
+ for i := 1; i < len(fds); i++ {
+ syscall.Close(fds[i])
+ }
+
+ // Don't close in the defer.
+ fds = nil
+ }
+
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
+ }
+
+ // All set.
+ return tag, m, nil
+}
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..38038abdf
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,243 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2, and a
+// maximum of 4, and use the number of available CPUs in between. Note that we
+// don't want the number of channels to be too large, because each accounts for
+// channelSize bytes of memory, which can be significant.
+var channelsPerClient = func() int {
+ n := runtime.NumCPU()
+ if n < 2 {
+ return 2
+ }
+ if n > 4 {
+ return 4
+ }
+ return n
+}()
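+
+// Worked examples of the clamp above: one CPU still yields 2 channels, three
+// CPUs yield 3, and any machine with four or more CPUs is capped at 4.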
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ desc flipcall.PacketWindowDescriptor
+ data flipcall.Endpoint
+ fds fdchannel.Endpoint
+ buf buffer
+
+ // -- client only --
+ connected bool
+ active bool
+
+ // -- server only --
+ client *fd.FD
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high level, depending on
+// whether it is the client or the server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Note that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, err
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ return ch.data.SendRecv(ssz)
+}
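+
+// The inline frame produced by send above is therefore:
+//
+//  type[1] fdflag[1] message[...] payload[...]
+//
+// with the total length carried by the flipcall packet itself rather than by
+// a 9P size header.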
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform it into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return r, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
new file mode 100644
index 000000000..3668fcad7
--- /dev/null
+++ b/pkg/p9/transport_test.go
@@ -0,0 +1,231 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "io/ioutil"
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ MsgTypeBadEncode = iota + 252
+ MsgTypeBadDecode
+ MsgTypeUnregistered
+)
+
+func TestSendRecv(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &Tlopen{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
+ if err != nil {
+ t.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ t.Fatalf("got tag %v expected 1", tag)
+ }
+ if _, ok := m.(*Tlopen); !ok {
+ t.Fatalf("got message %v expected *Tlopen", m)
+ }
+}
+
+// badDecode overruns on decode.
+type badDecode struct{}
+
+func (*badDecode) decode(b *buffer) { b.markOverrun() }
+func (*badDecode) encode(b *buffer) {}
+func (*badDecode) Type() MsgType { return MsgTypeBadDecode }
+func (*badDecode) String() string { return "badDecode{}" }
+
+func TestRecvOverrun(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &badDecode{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil {
+ t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err)
+ }
+}
+
+// unregistered is not registered on decode.
+type unregistered struct{}
+
+func (*unregistered) decode(b *buffer) {}
+func (*unregistered) encode(b *buffer) {}
+func (*unregistered) Type() MsgType { return MsgTypeUnregistered }
+func (*unregistered) String() string { return "unregistered{}" }
+
+func TestRecvInvalidType(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &unregistered{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
+ if _, ok := err.(*ErrInvalidMsgType); !ok {
+ t.Fatalf("recv got err %v expected ErrInvalidMsgType", err)
+ }
+}
+
+func TestSendRecvWithFile(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ // Create a tempfile.
+ osf, err := ioutil.TempFile("", "p9")
+ if err != nil {
+ t.Fatalf("tempfile got err %v expected nil", err)
+ }
+ os.Remove(osf.Name())
+ f, err := fd.NewFromFile(osf)
+ osf.Close()
+ if err != nil {
+ t.Fatalf("unable to create file: %v", err)
+ }
+
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ // Receive the message along with its file payload.
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
+ if err != nil {
+ t.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ t.Fatalf("got tag %v expected 1", tag)
+ }
+ rlopen, ok := m.(*Rlopen)
+ if !ok {
+ t.Fatalf("got m %v expected *Rlopen", m)
+ }
+ if rlopen.File == nil {
+ t.Fatalf("got nil file expected non-nil")
+ }
+}
+
+func TestRecvClosed(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ client.Close()
+
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
+ if err == nil {
+ t.Fatalf("got err nil expected non-nil")
+ }
+ if _, ok := err.(ErrSocket); !ok {
+ t.Fatalf("got err %v expected ErrSocket", err)
+ }
+}
+
+func TestSendClosed(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ server.Close()
+ defer client.Close()
+
+ err = send(client, Tag(1), &Tlopen{})
+ if err == nil {
+ t.Fatalf("send got err nil expected non-nil")
+ }
+ if _, ok := err.(ErrSocket); !ok {
+ t.Fatalf("got err %v expected ErrSocket", err)
+ }
+}
+
+func BenchmarkSendRecv(b *testing.B) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ // Exchange Rflush messages since these contain no data and therefore incur
+ // no additional marshaling overhead.
+ go func() {
+ for i := 0; i < b.N; i++ {
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
+ if err != nil {
+ b.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ b.Fatalf("got tag %v expected 1", tag)
+ }
+ if _, ok := m.(*Rflush); !ok {
+ b.Fatalf("got message %T expected *Rflush", m)
+ }
+ if err := send(server, Tag(2), &Rflush{}); err != nil {
+ b.Fatalf("send got err %v expected nil", err)
+ }
+ }
+ }()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := send(client, Tag(1), &Rflush{}); err != nil {
+ b.Fatalf("send got err %v expected nil", err)
+ }
+ tag, m, err := recv(client, maximumLength, msgRegistry.get)
+ if err != nil {
+ b.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(2) {
+ b.Fatalf("got tag %v expected 2", tag)
+ }
+ if _, ok := m.(*Rflush); !ok {
+ b.Fatalf("got message %v expected *Rflush", m)
+ }
+ }
+}
+
+func init() {
+ msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} })
+}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
new file mode 100644
index 000000000..09cde9f5a
--- /dev/null
+++ b/pkg/p9/version.go
@@ -0,0 +1,175 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+const (
+ // highestSupportedVersion is the highest supported version X in a
+ // version string of the format 9P2000.L.Google.X.
+ //
+ // Clients are expected to start by requesting this version number and
+ // to decrement it until a Tversion request succeeds.
+ highestSupportedVersion uint32 = 11
+
+ // lowestSupportedVersion is the lowest supported version X in a
+ // version string of the format 9P2000.L.Google.X.
+ //
+ // Clients are free to send a Tversion request at a version below this
+ // value but are expected to encounter an Rlerror in response.
+ lowestSupportedVersion uint32 = 0
+
+ // baseVersion is the base version of 9P that this package must always
+ // support. It is equivalent to 9P2000.L.Google.0.
+ baseVersion = "9P2000.L"
+)
+
+// HighestVersionString returns the highest possible version string that a client
+// may request or a server may support.
+func HighestVersionString() string {
+ return versionString(highestSupportedVersion)
+}
+
+// parseVersion parses a Tversion version string into a numeric version number
+// if the version string is supported by p9. Otherwise returns (0, false).
+//
+// From Tversion(9P): "Version strings are defined such that, if the client string
+// contains one or more period characters, the initial substring up to but not
+// including any single period in the version string defines a version of the protocol."
+//
+// p9 intentionally diverges from this and always requires that the version string
+// start with 9P2000.L to express that it is always compatible with 9P2000.L. The
+// only supported version extensions are of the format 9P2000.L.Google.X, where X
+// is an ever-increasing version counter.
+//
+// Version 9P2000.L.Google.0 implies 9P2000.L.
+//
+// New versions must always be a strict superset of 9P2000.L. A version increase must
+// define a predicate representing the feature extension introduced by that version. The
+// predicate must be commented and should take the format:
+//
+// // VersionSupportsX returns true if version v supports X and must be checked when ...
+// func VersionSupportsX(v uint32) bool {
+// ...
+// }
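+//
+// Illustrative results (these match cases exercised in version_test.go):
+//
+//	parseVersion("9P2000.L")          == (0, true)
+//	parseVersion("9P2000.L.Google.1") == (1, true)
+//	parseVersion("9P2000")            == (0, false)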
+func parseVersion(str string) (uint32, bool) {
+ // Special case the base version which lacks the ".Google.X" suffix. This
+ // version always means version 0.
+ if str == baseVersion {
+ return 0, true
+ }
+ substr := strings.Split(str, ".")
+ if len(substr) != 4 {
+ return 0, false
+ }
+ if substr[0] != "9P2000" || substr[1] != "L" || substr[2] != "Google" || len(substr[3]) == 0 {
+ return 0, false
+ }
+ version, err := strconv.ParseUint(substr[3], 10, 32)
+ if err != nil {
+ return 0, false
+ }
+ return uint32(version), true
+}
+
+// versionString formats a p9 version number into a Tversion version string.
+func versionString(version uint32) string {
+ // Special case the base version so that clients expecting this string
+ // instead of the 9P2000.L.Google.0 equivalent get it. This is important
+ // for backwards compatibility with legacy servers that check for exactly
+ // the baseVersion and allow nothing else.
+ if version == 0 {
+ return baseVersion
+ }
+ return fmt.Sprintf("9P2000.L.Google.%d", version)
+}
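+
+// For example (per version_test.go), versionString(0) yields "9P2000.L" rather
+// than "9P2000.L.Google.0", while versionString(1) yields "9P2000.L.Google.1".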
+
+// VersionSupportsTflushf returns true if version v supports the Tflushf message.
+// This predicate must be checked by clients before attempting to make a Tflushf
+// request. If this predicate returns false, then clients may safely no-op.
+func VersionSupportsTflushf(v uint32) bool {
+ return v >= 1
+}
+
+// versionSupportsTwalkgetattr returns true if version v supports the
+// Twalkgetattr message. This predicate must be checked by clients before
+// attempting to make a Twalkgetattr request.
+func versionSupportsTwalkgetattr(v uint32) bool {
+ return v >= 2
+}
+
+// versionSupportsTucreation returns true if version v supports the Tucreation
+// messages (Tucreate, Tusymlink, Tumkdir, Tumknod). This predicate must be
+// checked by clients before attempting to make a Tucreation request.
+// If Tucreation messages are not supported, their non-UID supporting
+// counterparts (Tlcreate, Tsymlink, Tmkdir, Tmknod) should be used.
+func versionSupportsTucreation(v uint32) bool {
+ return v >= 3
+}
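+
+// Client-side usage sketch (illustrative; v is the negotiated version):
+//
+//	if versionSupportsTucreation(v) {
+//		// send Tucreate
+//	} else {
+//		// fall back to Tlcreate
+//	}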
+
+// VersionSupportsConnect returns true if version v supports the Tlconnect
+// message. This predicate must be checked by clients
+// before attempting to make a Tlconnect request. If Tlconnect messages are not
+// supported, Tlopen should be used.
+func VersionSupportsConnect(v uint32) bool {
+ return v >= 4
+}
+
+// VersionSupportsAnonymous returns true if version v supports Tlconnect
+// with the AnonymousSocket mode. This predicate must be checked by clients
+// before attempting to use the AnonymousSocket Tlconnect mode.
+func VersionSupportsAnonymous(v uint32) bool {
+ return v >= 5
+}
+
+// VersionSupportsMultiUser returns true if version v supports multi-user fake
+// directory permissions and ID values.
+func VersionSupportsMultiUser(v uint32) bool {
+ return v >= 6
+}
+
+// versionSupportsTallocate returns true if version v supports Allocate().
+func versionSupportsTallocate(v uint32) bool {
+ return v >= 7
+}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}
+
+// VersionSupportsOpenTruncateFlag returns true if version v supports
+// passing the OpenTruncate flag to Tlopen.
+func VersionSupportsOpenTruncateFlag(v uint32) bool {
+ return v >= 9
+}
+
+// versionSupportsGetSetXattr returns true if version v supports
+// the Tgetxattr and Tsetxattr messages.
+func versionSupportsGetSetXattr(v uint32) bool {
+ return v >= 10
+}
+
+// versionSupportsListRemoveXattr returns true if version v supports
+// the Tlistxattr and Tremovexattr messages.
+func versionSupportsListRemoveXattr(v uint32) bool {
+ return v >= 11
+}
diff --git a/pkg/p9/version_test.go b/pkg/p9/version_test.go
new file mode 100644
index 000000000..291e8580e
--- /dev/null
+++ b/pkg/p9/version_test.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "testing"
+)
+
+func TestVersionNumberEquivalent(t *testing.T) {
+ for i := uint32(0); i < 1024; i++ {
+ str := versionString(i)
+ version, ok := parseVersion(str)
+ if !ok {
+ t.Errorf("#%d: parseVersion(%q) failed, want success", i, str)
+ continue
+ }
+ if i != version {
+ t.Errorf("#%d: got version %d, want %d", i, i, version)
+ }
+ }
+}
+
+func TestVersionStringEquivalent(t *testing.T) {
+ // There is one case where the version is intentionally not equivalent:
+ // 9P2000.L.Google.0. versionString must always return the more generic
+ // 9P2000.L for legacy servers that check for exactly that string. See
+ // net/9p/client.c.
+ str := "9P2000.L.Google.0"
+ version, ok := parseVersion(str)
+ if !ok {
+ t.Errorf("parseVersion(%q) failed, want success", str)
+ }
+ if got := versionString(version); got != "9P2000.L" {
+ t.Errorf("versionString(%d) got %q, want %q", version, got, "9P2000.L")
+ }
+
+ for _, test := range []struct {
+ versionString string
+ }{
+ {
+ versionString: "9P2000.L",
+ },
+ {
+ versionString: "9P2000.L.Google.1",
+ },
+ {
+ versionString: "9P2000.L.Google.347823894",
+ },
+ } {
+ version, ok := parseVersion(test.versionString)
+ if !ok {
+ t.Errorf("parseVersion(%q) failed, want success", test.versionString)
+ continue
+ }
+ if got := versionString(version); got != test.versionString {
+ t.Errorf("versionString(%d) got %q, want %q", version, got, test.versionString)
+ }
+ }
+}
+
+func TestParseVersion(t *testing.T) {
+ for _, test := range []struct {
+ versionString string
+ expectSuccess bool
+ expectedVersion uint32
+ }{
+ {
+ versionString: "9P",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P.L",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P200.L",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L.Google.-1",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L.Google.",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L.Google.3546343826724305832",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2001.L",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L",
+ expectSuccess: true,
+ expectedVersion: 0,
+ },
+ {
+ versionString: "9P2000.L.Google.0",
+ expectSuccess: true,
+ expectedVersion: 0,
+ },
+ {
+ versionString: "9P2000.L.Google.1",
+ expectSuccess: true,
+ expectedVersion: 1,
+ },
+ } {
+ version, ok := parseVersion(test.versionString)
+ if ok != test.expectSuccess {
+ t.Errorf("parseVersion(%q) got (_, %v), want (_, %v)", test.versionString, ok, test.expectSuccess)
+ continue
+ }
+ if !test.expectSuccess {
+ continue
+ }
+ if version != test.expectedVersion {
+ t.Errorf("parseVersion(%q) got (%d, _), want (%d, _)", test.versionString, version, test.expectedVersion)
+ }
+ }
+}
+
+func BenchmarkParseVersion(b *testing.B) {
+ for n := 0; n < b.N; n++ {
+ parseVersion("9P2000.L.Google.1")
+ }
+}
diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD
new file mode 100644
index 000000000..7b1c6b75b
--- /dev/null
+++ b/pkg/pool/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "pool",
+ srcs = [
+ "pool.go",
+ ],
+ deps = [
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "pool_test",
+ size = "small",
+ srcs = [
+ "pool_test.go",
+ ],
+ library = ":pool",
+)
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
new file mode 100644
index 000000000..a1b2e0cfe
--- /dev/null
+++ b/pkg/pool/pool.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pool
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Pool is a simple allocator.
+type Pool struct {
+ mu sync.Mutex
+
+ // cache is the set of returned values.
+ cache []uint64
+
+ // Start is the starting value (if needed).
+ Start uint64
+
+ // max is the current maximum issued.
+ max uint64
+
+ // Limit is the upper limit.
+ Limit uint64
+}
+
+// Get gets a value from the pool.
+func (p *Pool) Get() (uint64, bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Anything cached?
+ if len(p.cache) > 0 {
+ v := p.cache[len(p.cache)-1]
+ p.cache = p.cache[:len(p.cache)-1]
+ return v, true
+ }
+
+ // Over the limit?
+ if p.Start == p.Limit {
+ return 0, false
+ }
+
+ // Generate a new value.
+ v := p.Start
+ p.Start++
+ return v, true
+}
+
+// Put returns a value to the pool.
+func (p *Pool) Put(v uint64) {
+ p.mu.Lock()
+ p.cache = append(p.cache, v)
+ p.mu.Unlock()
+}
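+
+// A minimal usage sketch, mirroring the behavior exercised in pool_test.go:
+//
+//	p := pool.Pool{Start: 1, Limit: 3}
+//	v1, _ := p.Get() // v1 == 1
+//	v2, _ := p.Get() // v2 == 2
+//	p.Put(v1)        // v1 is cached for reuse
+//	v3, _ := p.Get() // v3 == 1, recycled from the cache
+//	_, ok := p.Get() // ok == false: Start has reached Limit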
diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go
new file mode 100644
index 000000000..d928439c1
--- /dev/null
+++ b/pkg/pool/pool_test.go
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pool
+
+import (
+ "testing"
+)
+
+func TestPoolUnique(t *testing.T) {
+ p := Pool{Start: 1, Limit: 3}
+ got := make(map[uint64]bool)
+
+ for {
+ n, ok := p.Get()
+ if !ok {
+ break
+ }
+
+ // Check unique.
+ if _, ok := got[n]; ok {
+ t.Errorf("pool spit out %v multiple times", n)
+ }
+
+ // Record.
+ got[n] = true
+ }
+}
+
+func TestExhausted(t *testing.T) {
+ p := Pool{Start: 1, Limit: 500}
+ for i := 0; i < 499; i++ {
+ _, ok := p.Get()
+ if !ok {
+ t.Fatalf("pool exhausted before 499 items")
+ }
+ }
+
+ _, ok := p.Get()
+ if ok {
+ t.Errorf("pool not exhausted when it should be")
+ }
+}
+
+func TestPoolRecycle(t *testing.T) {
+ p := Pool{Start: 1, Limit: 500}
+ n1, _ := p.Get()
+ p.Put(n1)
+ n2, _ := p.Get()
+ if n1 != n2 {
+ t.Errorf("pool not recycling items")
+ }
+}
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
new file mode 100644
index 000000000..aa3e3ac0b
--- /dev/null
+++ b/pkg/procid/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "procid",
+ srcs = [
+ "procid.go",
+ "procid_amd64.s",
+ "procid_arm64.s",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "procid_test",
+ size = "small",
+ srcs = [
+ "procid_test.go",
+ ],
+ library = ":procid",
+ deps = ["//pkg/sync"],
+)
+
+go_test(
+ name = "procid_net_test",
+ size = "small",
+ srcs = [
+ "procid_net_test.go",
+ "procid_test.go",
+ ],
+ library = ":procid",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/procid/procid.go b/pkg/procid/procid.go
new file mode 100644
index 000000000..78b92422c
--- /dev/null
+++ b/pkg/procid/procid.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package procid provides a way to get the current system thread identifier.
+package procid
+
+// Current returns the current system thread identifier.
+//
+// Precondition: This should only be called with the runtime OS thread locked.
+func Current() uint64
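+
+// A minimal usage sketch; per the precondition above, the calling goroutine
+// must be locked to its OS thread. On Linux the result matches
+// syscall.Gettid() (see procid_test.go):
+//
+//	runtime.LockOSThread()
+//	defer runtime.UnlockOSThread()
+//	id := procid.Current()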
diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
new file mode 100644
index 000000000..7c622e5d7
--- /dev/null
+++ b/pkg/procid/procid_amd64.s
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+// +build go1.8
+// +build !go1.16
+
+#include "textflag.h"
+
+TEXT ·Current(SB),NOSPLIT,$0-8
+ // The offset specified here is the m_procid offset for Go1.8+.
+ // Changes to this offset should be caught by the tests, and major
+ // version changes require an explicit tag change above.
+ MOVQ TLS, AX
+ MOVQ 0(AX)(TLS*1), AX
+ MOVQ 48(AX), AX // g_m (may change in future versions)
+ MOVQ 72(AX), AX // m_procid (may change in future versions)
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
new file mode 100644
index 000000000..48ebb5fd1
--- /dev/null
+++ b/pkg/procid/procid_arm64.s
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+// +build go1.8
+// +build !go1.16
+
+#include "textflag.h"
+
+TEXT ·Current(SB),NOSPLIT,$0-8
+ // The offset specified here is the m_procid offset for Go1.8+.
+ // Changes to this offset should be caught by the tests, and major
+ // version changes require an explicit tag change above.
+ MOVD g, R0 // g
+ MOVD 48(R0), R0 // g_m (may change in future versions)
+ MOVD 72(R0), R0 // m_procid (may change in future versions)
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/procid/procid_net_test.go b/pkg/procid/procid_net_test.go
new file mode 100644
index 000000000..b628e2285
--- /dev/null
+++ b/pkg/procid/procid_net_test.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package procid
+
+// This file exists only to force inclusion of the "net" package, which makes
+// the test binary a cgo binary.
+import (
+ _ "net"
+)
diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go
new file mode 100644
index 000000000..9ec08c3d6
--- /dev/null
+++ b/pkg/procid/procid_test.go
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package procid
+
+import (
+ "os"
+ "runtime"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// runOnMain is used to send functions to run on the main (initial) thread.
+var runOnMain = make(chan func(), 10)
+
+func checkProcid(t *testing.T, start *sync.WaitGroup, done *sync.WaitGroup) {
+ defer done.Done()
+
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ start.Done()
+ start.Wait()
+
+ procID := Current()
+ tid := syscall.Gettid()
+
+ if procID != uint64(tid) {
+ t.Logf("Bad procid: expected %v, got %v", tid, procID)
+ t.Fail()
+ }
+}
+
+func TestProcidInitialized(t *testing.T) {
+ var start sync.WaitGroup
+ var done sync.WaitGroup
+
+ count := 100
+ start.Add(count + 1)
+ done.Add(count + 1)
+
+ // Run the check on the main thread.
+ //
+ // When cgo is not included, the only case when procid isn't initialized
+ // is in the main (initial) thread, so we have to test this case
+ // specifically.
+ runOnMain <- func() {
+ checkProcid(t, &start, &done)
+ }
+
+ // Run the check on a number of different threads.
+ for i := 0; i < count; i++ {
+ go checkProcid(t, &start, &done)
+ }
+
+ done.Wait()
+}
+
+func TestMain(m *testing.M) {
+ // Make sure we remain at the main (initial) thread.
+ runtime.LockOSThread()
+
+ // Start running tests in a different goroutine.
+ go func() {
+ os.Exit(m.Run())
+ }()
+
+ // Execute any functions that have been sent for execution in the main
+ // thread.
+ for f := range runOnMain {
+ f()
+ }
+}
diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD
new file mode 100644
index 000000000..80b8ceb02
--- /dev/null
+++ b/pkg/rand/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "rand",
+ srcs = [
+ "rand.go",
+ "rand_linux.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/rand/rand.go b/pkg/rand/rand.go
new file mode 100644
index 000000000..a2714784d
--- /dev/null
+++ b/pkg/rand/rand.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux
+
+// Package rand implements a cryptographically secure pseudorandom number
+// generator.
+package rand
+
+import "crypto/rand"
+
+// Reader is the default reader.
+var Reader = rand.Reader
+
+// Read implements io.Reader.Read.
+func Read(b []byte) (int, error) {
+ return rand.Read(b)
+}
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
new file mode 100644
index 000000000..fa6a21026
--- /dev/null
+++ b/pkg/rand/rand_linux.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package rand implements a cryptographically secure pseudorandom number
+// generator.
+package rand
+
+import (
+ "bufio"
+ "crypto/rand"
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// reader implements an io.Reader that returns pseudorandom bytes.
+type reader struct {
+ once sync.Once
+ useGetrandom bool
+}
+
+// Read implements io.Reader.Read.
+func (r *reader) Read(p []byte) (int, error) {
+ r.once.Do(func() {
+ _, err := unix.Getrandom(p, 0)
+ if err != unix.ENOSYS {
+ r.useGetrandom = true
+ }
+ })
+
+ if r.useGetrandom {
+ return unix.Getrandom(p, 0)
+ }
+ return rand.Read(p)
+}
+
+// bufferedReader implements a threadsafe buffered io.Reader.
+type bufferedReader struct {
+ mu sync.Mutex
+ r *bufio.Reader
+}
+
+// Read implements io.Reader.Read.
+func (b *bufferedReader) Read(p []byte) (int, error) {
+ b.mu.Lock()
+ n, err := b.r.Read(p)
+ b.mu.Unlock()
+ return n, err
+}
+
+// Reader is the default reader.
+var Reader io.Reader = &bufferedReader{r: bufio.NewReader(&reader{})}
+
+// Read reads from the default reader.
+func Read(b []byte) (int, error) {
+ return io.ReadFull(Reader, b)
+}
+
+// Init can be called to make sure /dev/urandom is pre-opened on kernels that
+// do not support getrandom(2).
+func Init() error {
+ p := make([]byte, 1)
+ _, err := Read(p)
+ return err
+}
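+
+// A minimal usage sketch (error handling choices are illustrative):
+//
+//	if err := rand.Init(); err != nil {
+//		panic(err) // no usable entropy source
+//	}
+//	key := make([]byte, 32)
+//	if _, err := rand.Read(key); err != nil {
+//		panic(err)
+//	}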
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
new file mode 100644
index 000000000..74affc887
--- /dev/null
+++ b/pkg/refs/BUILD
@@ -0,0 +1,38 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "weak_ref_list",
+ out = "weak_ref_list.go",
+ package = "refs",
+ prefix = "weakRef",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*WeakRef",
+ "Linker": "*WeakRef",
+ },
+)
+
+go_library(
+ name = "refs",
+ srcs = [
+ "refcounter.go",
+ "refcounter_state.go",
+ "weak_ref_list.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "refs_test",
+ size = "small",
+ srcs = ["refcounter_test.go"],
+ library = ":refs",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
new file mode 100644
index 000000000..c45ba8200
--- /dev/null
+++ b/pkg/refs/refcounter.go
@@ -0,0 +1,469 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package refs defines an interface for reference counted objects. It
+// also provides a drop-in implementation called AtomicRefCount.
+package refs
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "runtime"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// RefCounter is the interface to be implemented by objects that are reference
+// counted.
+type RefCounter interface {
+ // IncRef increments the reference counter on the object.
+ IncRef()
+
+ // DecRef decrements the reference counter on the object.
+ //
+ // Note that AtomicRefCounter.DecRef() does not support destructors.
+ // If a type has a destructor, it must implement its own DecRef()
+ // method and call AtomicRefCounter.DecRefWithDestructor(destructor).
+ DecRef()
+
+ // TryIncRef attempts to increase the reference counter on the object,
+ // but may fail if all references have already been dropped. This
+ // should be used only in special circumstances, such as WeakRefs.
+ TryIncRef() bool
+
+ // addWeakRef adds the given weak reference. Note that you should have a
+ // reference to the object when calling this method.
+ addWeakRef(*WeakRef)
+
+ // dropWeakRef drops the given weak reference. Note that you should have
+ // a reference to the object when calling this method.
+ dropWeakRef(*WeakRef)
+}
+
+// A WeakRefUser is notified when the last non-weak reference is dropped.
+type WeakRefUser interface {
+ // WeakRefGone is called when the last non-weak reference is dropped.
+ WeakRefGone()
+}
+
+// WeakRef is a weak reference.
+//
+// +stateify savable
+type WeakRef struct {
+ weakRefEntry `state:"nosave"`
+
+ // obj is an atomic value that points to the refCounter.
+ obj atomic.Value `state:".(savedReference)"`
+
+ // user is notified when the weak ref is zapped by the object getting
+ // destroyed.
+ user WeakRefUser
+}
+
+// weakRefPool is a pool of weak references to avoid allocations on the hot path.
+var weakRefPool = sync.Pool{
+ New: func() interface{} {
+ return &WeakRef{}
+ },
+}
+
+// NewWeakRef acquires a weak reference for the given object.
+//
+// An optional user will be notified when the last non-weak reference is
+// dropped.
+//
+// Note that you must hold a reference to the object prior to getting a weak
+// reference. (But you may drop the non-weak reference after that.)
+func NewWeakRef(rc RefCounter, u WeakRefUser) *WeakRef {
+ w := weakRefPool.Get().(*WeakRef)
+ w.init(rc, u)
+ return w
+}
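+
+// A minimal usage sketch, mirroring refcounter_test.go (the caller must
+// already hold a reference on rc):
+//
+//	w := NewWeakRef(rc, nil)
+//	if obj := w.Get(); obj != nil {
+//		// ... use obj ...
+//		obj.DecRef()
+//	}
+//	w.Drop() // always Drop when finished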
+
+// get attempts to get a normal reference to the underlying object and returns
+// (object, true) on success. If this weak reference has already been zapped
+// (the object has been destroyed), then (nil, false) is returned. If the
+// object still exists but is being destroyed, then (nil, true) is returned.
+func (w *WeakRef) get() (RefCounter, bool) {
+ rc := w.obj.Load().(RefCounter)
+ if v := reflect.ValueOf(rc); v == reflect.Zero(v.Type()) {
+ // This pointer has already been zapped by zap() below. We do
+ // this to ensure that the GC can collect the underlying
+ // RefCounter objects and they don't hog resources.
+ return nil, false
+ }
+ if !rc.TryIncRef() {
+ return nil, true
+ }
+ return rc, true
+}
+
+// Get attempts to get a normal reference to the underlying object, and returns
+// the object. If this fails (the object no longer exists), then nil will be
+// returned instead.
+func (w *WeakRef) Get() RefCounter {
+ rc, _ := w.get()
+ return rc
+}
+
+// Drop drops this weak reference. You should always call drop when you are
+// finished with the weak reference. You may not use this object after calling
+// drop.
+func (w *WeakRef) Drop() {
+ rc, ok := w.get()
+ if !ok {
+ // We've been zapped already. When the refcounter has called
+ // zap, we're guaranteed it's not holding references.
+ weakRefPool.Put(w)
+ return
+ }
+ if rc == nil {
+ // The object is in the process of being destroyed. We can't
+ // remove this from the object's list, nor can we return this
+ // object to the pool. It'll just be garbage collected. This is
+ // a rare edge case, so it's not a big deal.
+ return
+ }
+
+ // At this point, we have a reference on the object. So destruction
+ // of the object (and zapping this weak reference) can't race here.
+ rc.dropWeakRef(w)
+
+ // And now we aren't on the object's list of weak references, so it
+ // won't zap us if this DecRef causes the reference count to drop to
+ // zero.
+ rc.DecRef()
+
+ // Return to the pool.
+ weakRefPool.Put(w)
+}
+
+// init initializes this weak reference.
+func (w *WeakRef) init(rc RefCounter, u WeakRefUser) {
+ // Reset the contents of the weak reference.
+ // This is important because we are resetting the atomic value type.
+ // Otherwise, we could panic here if obj is a different type from what
+ // it was the last time this WeakRef was used.
+ *w = WeakRef{}
+ w.user = u
+ w.obj.Store(rc)
+
+ // In the load path, we may already have a nil value. So we need to
+ // check whether or not that is the case before calling addWeakRef.
+ if v := reflect.ValueOf(rc); v != reflect.Zero(v.Type()) {
+ rc.addWeakRef(w)
+ }
+}
+
+// zap zaps this weak reference.
+func (w *WeakRef) zap() {
+ // We need to be careful about types here.
+ // So reflect is involved. But it's not that bad.
+ rc := w.obj.Load()
+ typ := reflect.TypeOf(rc)
+ w.obj.Store(reflect.Zero(typ).Interface())
+}
+
+// AtomicRefCount keeps a reference count using atomic operations and calls the
+// destructor when the count reaches zero.
+//
+// N.B. To allow the zero-object to be initialized, the count is offset by
+// 1, that is, when refCount is n, there are really n+1 references.
+//
+// +stateify savable
+type AtomicRefCount struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a
+ // CompareAndSwap loop. See IncRef, DecRef and TryIncRef for details of
+ // how these fields are used.
+ refCount int64
+
+ // name is the name of the type which owns this ref count.
+ //
+ // name is immutable after EnableLeakCheck is called.
+ name string
+
+ // stack optionally records the caller of EnableLeakCheck.
+ //
+ // stack is immutable after EnableLeakCheck is called.
+ stack []uintptr
+
+ // mu protects the list below.
+ mu sync.Mutex `state:"nosave"`
+
+ // weakRefs is our collection of weak references.
+ weakRefs weakRefList `state:"nosave"`
+}
+
+// LeakMode configures the leak checker.
+type LeakMode uint32
+
+const (
+ // UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
+ UninitializedLeakChecking LeakMode = iota
+
+ // NoLeakChecking indicates that no effort should be made to check for
+ // leaks.
+ NoLeakChecking
+
+ // LeaksLogWarning indicates that a warning should be logged when leaks
+ // are found.
+ LeaksLogWarning
+
+ // LeaksLogTraces indicates that a trace collected during allocation
+ // should be logged when leaks are found.
+ LeaksLogTraces
+)
+
+// leakMode stores the current mode for the reference leak checker.
+//
+// Values must be one of the LeakMode values.
+//
+// leakMode must be accessed atomically.
+var leakMode uint32
+
+// SetLeakMode configures the reference leak checker.
+func SetLeakMode(mode LeakMode) {
+ atomic.StoreUint32(&leakMode, uint32(mode))
+}
+
+const maxStackFrames = 40
+
+type fileLine struct {
+ file string
+ line int
+}
+
+// A stackKey is a representation of a stack frame for use as a map key.
+//
+// The fileLine type is used because PC values seem to vary across collections,
+// even for the same call stack.
+type stackKey [maxStackFrames]fileLine
+
+var stackCache = struct {
+ sync.Mutex
+ entries map[stackKey][]uintptr
+}{entries: map[stackKey][]uintptr{}}
+
+func makeStackKey(pcs []uintptr) stackKey {
+ frames := runtime.CallersFrames(pcs)
+ var key stackKey
+ keySlice := key[:0]
+ for {
+ frame, more := frames.Next()
+ keySlice = append(keySlice, fileLine{frame.File, frame.Line})
+
+ if !more || len(keySlice) == len(key) {
+ break
+ }
+ }
+ return key
+}
+
+func recordStack() []uintptr {
+ pcs := make([]uintptr, maxStackFrames)
+ n := runtime.Callers(1, pcs)
+ if n == 0 {
+ // No pcs available. Stop now.
+ //
+ // This can happen if the first argument to runtime.Callers
+ // is large.
+ return nil
+ }
+ pcs = pcs[:n]
+ key := makeStackKey(pcs)
+ stackCache.Lock()
+ v, ok := stackCache.entries[key]
+ if !ok {
+ // Reallocate to prevent pcs from escaping.
+ v = append([]uintptr(nil), pcs...)
+ stackCache.entries[key] = v
+ }
+ stackCache.Unlock()
+ return v
+}
+
+func formatStack(pcs []uintptr) string {
+ frames := runtime.CallersFrames(pcs)
+ var trace bytes.Buffer
+ for {
+ frame, more := frames.Next()
+ fmt.Fprintf(&trace, "%s:%d: %s\n", frame.File, frame.Line, frame.Function)
+
+ if !more {
+ break
+ }
+ }
+ return trace.String()
+}
+
+func (r *AtomicRefCount) finalize() {
+ var note string
+ switch LeakMode(atomic.LoadUint32(&leakMode)) {
+ case NoLeakChecking:
+ return
+ case UninitializedLeakChecking:
+ note = "(Leak checker uninitialized): "
+ }
+ if n := r.ReadRefs(); n != 0 {
+ msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n)
+ if len(r.stack) != 0 {
+ msg += ":\nCaller:\n" + formatStack(r.stack)
+ } else {
+ msg += " (enable trace logging to debug)"
+ }
+ log.Warningf(msg)
+ }
+}
+
+// EnableLeakCheck checks for reference leaks when the AtomicRefCount gets
+// garbage collected.
+//
+// This function adds a finalizer to the AtomicRefCount, so the AtomicRefCount
+// must be at the beginning of its parent.
+//
+// name is a friendly name that will be listed as the owner of the
+// AtomicRefCount in logs. It should be the name of the parent type, including
+// package.
+func (r *AtomicRefCount) EnableLeakCheck(name string) {
+ if name == "" {
+ panic("invalid name")
+ }
+ switch LeakMode(atomic.LoadUint32(&leakMode)) {
+ case NoLeakChecking:
+ return
+ case LeaksLogTraces:
+ r.stack = recordStack()
+ }
+ r.name = name
+ runtime.SetFinalizer(r, (*AtomicRefCount).finalize)
+}
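+
+// A typical setup sketch (the embedding type and names are illustrative):
+//
+//	type Object struct {
+//		refs.AtomicRefCount // must be at the beginning of its parent
+//	}
+//
+//	func NewObject() *Object {
+//		o := &Object{}
+//		o.EnableLeakCheck("mypkg.Object")
+//		return o
+//	}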
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *AtomicRefCount) ReadRefs() int64 {
+ // Account for the internal -1 offset on refcounts.
+ return atomic.LoadInt64(&r.refCount) + 1
+}
+
+// IncRef increments this object's reference count. While the count is kept
+// greater than zero, the destructor doesn't get called.
+//
+// The sanity check here is limited to real references, since if they have
+// dropped beneath zero then the object should have been destroyed.
+//
+//go:nosplit
+func (r *AtomicRefCount) IncRef() {
+ if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
+ panic("Incrementing non-positive ref count")
+ }
+}
+
+// TryIncRef attempts to increment the reference count, *unless the count has
+// already reached zero*. If false is returned, then the object has already
+// been destroyed, and the weak reference is no longer valid. If true is
+// returned, then a valid reference is now held on the object.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to
+// distinguish other TryIncRef calls from genuine references held.
+//
+//go:nosplit
+func (r *AtomicRefCount) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ v := atomic.AddInt64(&r.refCount, speculativeRef)
+ if int32(v) < 0 {
+ // This object has already been freed.
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ // Turn into a real reference.
+ atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ return true
+}
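+
+// To illustrate the arithmetic above (recall the internal -1 offset): with one
+// real reference, refCount is 0. TryIncRef briefly makes it 1<<32 (one
+// speculative plus one real); int32(v) is 0, so -speculativeRef+1 converts the
+// speculative reference into a real one, leaving refCount at 1. If the count
+// had already dropped to zero (stored as -1), v's low 32 bits are -1, so
+// int32(v) < 0 and the speculative reference is rolled back.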
+
+// addWeakRef adds the given weak reference.
+func (r *AtomicRefCount) addWeakRef(w *WeakRef) {
+ r.mu.Lock()
+ r.weakRefs.PushBack(w)
+ r.mu.Unlock()
+}
+
+// dropWeakRef drops the given weak reference.
+func (r *AtomicRefCount) dropWeakRef(w *WeakRef) {
+ r.mu.Lock()
+ r.weakRefs.Remove(w)
+ r.mu.Unlock()
+}
+
+// DecRefWithDestructor decrements the object's reference count. If the
+// resulting count is negative and the destructor is not nil, then the
+// destructor will be called.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+//go:nosplit
+func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
+ switch v := atomic.AddInt64(&r.refCount, -1); {
+ case v < -1:
+ panic("Decrementing non-positive ref count")
+
+ case v == -1:
+ // Zap weak references. Note that at this point, all weak
+ // references are already invalid. That is, TryIncRef() will
+ // return false due to the reference count check.
+ r.mu.Lock()
+ for !r.weakRefs.Empty() {
+ w := r.weakRefs.Front()
+ // Capture the callback because w cannot be touched
+ // after it's zapped -- the owner is free to reuse it
+ // after that.
+ user := w.user
+ r.weakRefs.Remove(w)
+ w.zap()
+
+ if user != nil {
+ r.mu.Unlock()
+ user.WeakRefGone()
+ r.mu.Lock()
+ }
+ }
+ r.mu.Unlock()
+
+ // Call the destructor.
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
+
+// DecRef decrements this object's reference count.
+//
+//go:nosplit
+func (r *AtomicRefCount) DecRef() {
+ r.DecRefWithDestructor(nil)
+}
diff --git a/pkg/refs/refcounter_state.go b/pkg/refs/refcounter_state.go
new file mode 100644
index 000000000..7c99fd2b5
--- /dev/null
+++ b/pkg/refs/refcounter_state.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package refs
+
+// +stateify savable
+type savedReference struct {
+ obj interface{}
+}
+
+func (w *WeakRef) saveObj() savedReference {
+ // We load the object directly, because it is typed. This will be
+ // serialized and loaded as a typed value.
+ return savedReference{w.obj.Load()}
+}
+
+func (w *WeakRef) loadObj(v savedReference) {
+ // See note above. This will be serialized and loaded typed. So we're okay
+ // as long as refs aren't changing during save and load (which they should
+ // not be).
+ //
+ // w.user is loaded before loadObj is called.
+ w.init(v.obj.(RefCounter), w.user)
+}
diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go
new file mode 100644
index 000000000..1ab4a4440
--- /dev/null
+++ b/pkg/refs/refcounter_test.go
@@ -0,0 +1,173 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package refs
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+type testCounter struct {
+ AtomicRefCount
+
+ // mu protects the boolean below.
+ mu sync.Mutex
+
+ // destroyed indicates whether this was destroyed.
+ destroyed bool
+}
+
+func (t *testCounter) DecRef() {
+ t.AtomicRefCount.DecRefWithDestructor(t.destroy)
+}
+
+func (t *testCounter) destroy() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.destroyed = true
+}
+
+func (t *testCounter) IsDestroyed() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.destroyed
+}
+
+func newTestCounter() *testCounter {
+ return &testCounter{destroyed: false}
+}
+
+func TestOneRef(t *testing.T) {
+ tc := newTestCounter()
+ tc.DecRef()
+
+ if !tc.IsDestroyed() {
+ t.Errorf("object should have been destroyed")
+ }
+}
+
+func TestTwoRefs(t *testing.T) {
+ tc := newTestCounter()
+ tc.IncRef()
+ tc.DecRef()
+ tc.DecRef()
+
+ if !tc.IsDestroyed() {
+ t.Errorf("object should have been destroyed")
+ }
+}
+
+func TestMultiRefs(t *testing.T) {
+ tc := newTestCounter()
+ tc.IncRef()
+ tc.DecRef()
+
+ tc.IncRef()
+ tc.DecRef()
+
+ tc.DecRef()
+
+ if !tc.IsDestroyed() {
+ t.Errorf("object should have been destroyed")
+ }
+}
+
+func TestWeakRef(t *testing.T) {
+ tc := newTestCounter()
+ w := NewWeakRef(tc, nil)
+
+ // Try resolving.
+ if x := w.Get(); x == nil {
+ t.Errorf("weak reference didn't resolve: expected %v, got nil", tc)
+ } else {
+ x.DecRef()
+ }
+
+ // Try resolving again.
+ if x := w.Get(); x == nil {
+ t.Errorf("weak reference didn't resolve: expected %v, got nil", tc)
+ } else {
+ x.DecRef()
+ }
+
+ // Shouldn't be destroyed yet. (Can't continue if this fails.)
+ if tc.IsDestroyed() {
+ t.Fatalf("original object destroyed earlier than expected")
+ }
+
+ // Drop the original reference.
+ tc.DecRef()
+
+ // Assert destroyed.
+ if !tc.IsDestroyed() {
+ t.Errorf("original object not destroyed as expected")
+ }
+
+ // Shouldn't be anything.
+ if x := w.Get(); x != nil {
+ t.Errorf("weak reference resolved: expected nil, got %v", x)
+ }
+}
+
+func TestWeakRefDrop(t *testing.T) {
+ tc := newTestCounter()
+ w := NewWeakRef(tc, nil)
+ w.Drop()
+
+ // Just assert the list is empty.
+ if !tc.weakRefs.Empty() {
+ t.Errorf("weak reference not dropped")
+ }
+
+ // Drop the original reference.
+ tc.DecRef()
+}
+
+type testWeakRefUser struct {
+ weakRefGone func()
+}
+
+func (u *testWeakRefUser) WeakRefGone() {
+ u.weakRefGone()
+}
+
+func TestCallback(t *testing.T) {
+ called := false
+ tc := newTestCounter()
+ var w *WeakRef
+ w = NewWeakRef(tc, &testWeakRefUser{func() {
+ called = true
+
+ // Check that the weak ref has been zapped.
+ rc := w.obj.Load().(RefCounter)
+ if v := reflect.ValueOf(rc); v != reflect.Zero(v.Type()) {
+ t.Fatalf("Callback called with non-nil ptr")
+ }
+
+ // Check that we're not holding the mutex by acquiring and
+ // releasing it.
+ tc.mu.Lock()
+ tc.mu.Unlock()
+ }})
+
+ // Drop the original reference, this must trigger the callback.
+ tc.DecRef()
+
+ if !called {
+ t.Fatalf("Callback not called")
+ }
+}
diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD
new file mode 100644
index 000000000..426ef30c9
--- /dev/null
+++ b/pkg/safecopy/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "safecopy",
+ srcs = [
+ "atomic_amd64.s",
+ "atomic_arm64.s",
+ "memclr_amd64.s",
+ "memclr_arm64.s",
+ "memcpy_amd64.s",
+ "memcpy_arm64.s",
+ "safecopy.go",
+ "safecopy_unsafe.go",
+ "sighandler_amd64.s",
+ "sighandler_arm64.s",
+ ],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/syserror"],
+)
+
+go_test(
+ name = "safecopy_test",
+ srcs = [
+ "safecopy_test.go",
+ ],
+ library = ":safecopy",
+)
diff --git a/pkg/safecopy/LICENSE b/pkg/safecopy/LICENSE
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/pkg/safecopy/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s
new file mode 100644
index 000000000..a0cd78f33
--- /dev/null
+++ b/pkg/safecopy/atomic_amd64.s
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// handleSwapUint32Fault returns the value stored in DI. Control is transferred
+// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as swapUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVL DI, sig+20(FP)
+ RET
+
+// swapUint32 atomically stores new into *addr and returns (the previous *addr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+TEXT ·swapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint32Fault will store a different value in this address.
+ MOVL $0, sig+20(FP)
+
+ MOVQ addr+0(FP), DI
+ MOVL new+8(FP), AX
+ XCHGL AX, 0(DI)
+ MOVL AX, old+16(FP)
+ RET
+
+// handleSwapUint64Fault returns the value stored in DI. Control is transferred
+// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as swapUint64 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28
+ MOVL DI, sig+24(FP)
+ RET
+
+// swapUint64 atomically stores new into *addr and returns (the previous *addr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: addr must be aligned to an 8-byte boundary.
+//
+//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+TEXT ·swapUint64(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint64Fault will store a different value in this address.
+ MOVL $0, sig+24(FP)
+
+ MOVQ addr+0(FP), DI
+ MOVQ new+8(FP), AX
+ XCHGQ AX, 0(DI)
+ MOVQ AX, old+16(FP)
+ RET
+
+// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is
+// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS,
+// with the signal number stored in DI.
+//
+// It must have the same frame configuration as compareAndSwapUint32 so that it
+// can undo any potential call frame set up by the assembler.
+TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVL DI, sig+20(FP)
+ RET
+
+// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is
+// received during the operation, the value of prev is unspecified, and sig is
+// the number of the signal that was received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion, this is
+ // the value the caller will see; if a signal is received,
+ // handleCompareAndSwapUint32Fault will store a different value in this
+ // address.
+ MOVL $0, sig+20(FP)
+
+ MOVQ addr+0(FP), DI
+ MOVL old+8(FP), AX
+ MOVL new+12(FP), DX
+ LOCK
+ CMPXCHGL DX, 0(DI)
+ MOVL AX, prev+16(FP)
+ RET
+
+// handleLoadUint32Fault returns the value stored in DI. Control is transferred
+// to it when loadUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as loadUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16
+ MOVL DI, sig+12(FP)
+ RET
+
+// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS
+// signal is received, the value returned is unspecified, and sig is the number
+// of the signal that was received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+TEXT ·loadUint32(SB), NOSPLIT, $0-16
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleLoadUint32Fault will store a different value in this address.
+ MOVL $0, sig+12(FP)
+
+ MOVQ addr+0(FP), AX
+ MOVL (AX), BX
+ MOVL BX, val+8(FP)
+ RET
diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s
new file mode 100644
index 000000000..d58ed71f7
--- /dev/null
+++ b/pkg/safecopy/atomic_arm64.s
@@ -0,0 +1,126 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleSwapUint32Fault returns the value stored in R1. Control is transferred
+// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in R1.
+//
+// It must have the same frame configuration as swapUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVW R1, sig+20(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from Go source runtime/internal/atomic.Xchg.
+//
+//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+TEXT ·swapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint32Fault will store a different value in this address.
+ MOVW $0, sig+20(FP)
+again:
+ MOVD addr+0(FP), R0
+ MOVW new+8(FP), R1
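+	// Load-acquire exclusive the old value, then store-release exclusive the
+	// new one; R3 is nonzero if the exclusive monitor was lost, in which case
+	// we retry.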
+ LDAXRW (R0), R2
+ STLXRW R1, (R0), R3
+ CBNZ R3, again
+ MOVW R2, old+16(FP)
+ RET
+
+// handleSwapUint64Fault returns the value stored in R1. Control is transferred
+// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in R1.
+//
+// It must have the same frame configuration as swapUint64 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28
+ MOVW R1, sig+24(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from Go source runtime/internal/atomic.Xchg64.
+//
+//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+TEXT ·swapUint64(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint64Fault will store a different value in this address.
+ MOVW $0, sig+24(FP)
+again:
+ MOVD addr+0(FP), R0
+ MOVD new+8(FP), R1
+ LDAXR (R0), R2
+ STLXR R1, (R0), R3
+ CBNZ R3, again
+ MOVD R2, old+16(FP)
+ RET
+
+// handleCompareAndSwapUint32Fault returns the value stored in R1. Control is
+// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS,
+// with the signal number stored in R1.
+//
+// It must have the same frame configuration as compareAndSwapUint32 so that it
+// can undo any potential call frame set up by the assembler.
+TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVW R1, sig+20(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from Go source runtime/internal/atomic.Cas.
+//
+//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion, this is
+ // the value the caller will see; if a signal is received,
+ // handleCompareAndSwapUint32Fault will store a different value in this
+ // address.
+ MOVW $0, sig+20(FP)
+
+ MOVD addr+0(FP), R0
+ MOVW old+8(FP), R1
+ MOVW new+12(FP), R2
+again:
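+	// Exclusively load the current value and attempt the store only if it
+	// matches old; retry if the exclusive store fails.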
+ LDAXRW (R0), R3
+ CMPW R1, R3
+ BNE done
+ STLXRW R2, (R0), R4
+ CBNZ R4, again
+done:
+ MOVW R3, prev+16(FP)
+ RET
+
+// handleLoadUint32Fault returns the value stored in R1. Control is transferred
+// to it when loadUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in R1.
+//
+// It must have the same frame configuration as loadUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16
+ MOVW R1, sig+12(FP)
+ RET
+
+// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS
+// signal is received, the value returned is unspecified, and sig is the number
+// of the signal that was received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+TEXT ·loadUint32(SB), NOSPLIT, $0-16
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleLoadUint32Fault will store a different value in this address.
+ MOVW $0, sig+12(FP)
+
+ MOVD addr+0(FP), R0
+ LDARW (R0), R1
+ MOVW R1, val+8(FP)
+ RET
diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s
new file mode 100644
index 000000000..64cf32f05
--- /dev/null
+++ b/pkg/safecopy/memclr_amd64.s
@@ -0,0 +1,147 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleMemclrFault returns (the value stored in AX, the value stored in DI).
+// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in AX and the signal number stored in DI.
+//
+// It must have the same frame configuration as memclr so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemclrFault(SB), NOSPLIT, $0-28
+ MOVQ AX, addr+16(FP)
+ MOVL DI, sig+24(FP)
+ RET
+
+// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
+// signal is received during the write, it returns the address that caused the
+// fault and the number of the signal that was received. Otherwise, it returns
+// an unspecified address and a signal number of 0.
+//
+// Data is written in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully written.
+//
+// The code is derived from runtime.memclrNoHeapPointers.
+//
+// func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
+TEXT ·memclr(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemclrFault will store a different value in this address.
+ MOVL $0, sig+24(FP)
+
+ MOVQ ptr+0(FP), DI
+ MOVQ n+8(FP), BX
+ XORQ AX, AX
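+	// AX supplies the zero for the scalar stores below; X0 is zeroed with
+	// PXOR only on the paths that use SSE stores.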
+
+	// MOVOU seems to always be faster than REP STOSQ.
+tail:
+ TESTQ BX, BX
+ JEQ _0
+ CMPQ BX, $2
+ JBE _1or2
+ CMPQ BX, $4
+ JBE _3or4
+ CMPQ BX, $8
+ JB _5through7
+ JE _8
+ CMPQ BX, $16
+ JBE _9through16
+ PXOR X0, X0
+ CMPQ BX, $32
+ JBE _17through32
+ CMPQ BX, $64
+ JBE _33through64
+ CMPQ BX, $128
+ JBE _65through128
+ CMPQ BX, $256
+ JBE _129through256
+ // TODO: use branch table and BSR to make this just a single dispatch
+ // TODO: for really big clears, use MOVNTDQ, even without AVX2.
+
+loop:
+ MOVOU X0, 0(DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, 32(DI)
+ MOVOU X0, 48(DI)
+ MOVOU X0, 64(DI)
+ MOVOU X0, 80(DI)
+ MOVOU X0, 96(DI)
+ MOVOU X0, 112(DI)
+ MOVOU X0, 128(DI)
+ MOVOU X0, 144(DI)
+ MOVOU X0, 160(DI)
+ MOVOU X0, 176(DI)
+ MOVOU X0, 192(DI)
+ MOVOU X0, 208(DI)
+ MOVOU X0, 224(DI)
+ MOVOU X0, 240(DI)
+ SUBQ $256, BX
+ ADDQ $256, DI
+ CMPQ BX, $256
+ JAE loop
+ JMP tail
+
+_1or2:
+ MOVB AX, (DI)
+ MOVB AX, -1(DI)(BX*1)
+ RET
+_0:
+ RET
+_3or4:
+ MOVW AX, (DI)
+ MOVW AX, -2(DI)(BX*1)
+ RET
+_5through7:
+ MOVL AX, (DI)
+ MOVL AX, -4(DI)(BX*1)
+ RET
+_8:
+ // We need a separate case for 8 to make sure we clear pointers atomically.
+ MOVQ AX, (DI)
+ RET
+_9through16:
+ MOVQ AX, (DI)
+ MOVQ AX, -8(DI)(BX*1)
+ RET
+_17through32:
+ MOVOU X0, (DI)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
+_33through64:
+ MOVOU X0, (DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, -32(DI)(BX*1)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
+_65through128:
+ MOVOU X0, (DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, 32(DI)
+ MOVOU X0, 48(DI)
+ MOVOU X0, -64(DI)(BX*1)
+ MOVOU X0, -48(DI)(BX*1)
+ MOVOU X0, -32(DI)(BX*1)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
+_129through256:
+ MOVOU X0, (DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, 32(DI)
+ MOVOU X0, 48(DI)
+ MOVOU X0, 64(DI)
+ MOVOU X0, 80(DI)
+ MOVOU X0, 96(DI)
+ MOVOU X0, 112(DI)
+ MOVOU X0, -128(DI)(BX*1)
+ MOVOU X0, -112(DI)(BX*1)
+ MOVOU X0, -96(DI)(BX*1)
+ MOVOU X0, -80(DI)(BX*1)
+ MOVOU X0, -64(DI)(BX*1)
+ MOVOU X0, -48(DI)(BX*1)
+ MOVOU X0, -32(DI)(BX*1)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s
new file mode 100644
index 000000000..7361b9067
--- /dev/null
+++ b/pkg/safecopy/memclr_arm64.s
@@ -0,0 +1,74 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleMemclrFault returns (the value stored in R0, the value stored in R1).
+// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in R0 and the signal number stored in R1.
+//
+// It must have the same frame configuration as memclr so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemclrFault(SB), NOSPLIT, $0-28
+ MOVD R0, addr+16(FP)
+ MOVW R1, sig+24(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from runtime.memclrNoHeapPointers.
+//
+// func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
+TEXT ·memclr(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemclrFault will store a different value in this address.
+ MOVW $0, sig+24(FP)
+ MOVD ptr+0(FP), R0
+ MOVD n+8(FP), R1
+
+ // If size is less than 16 bytes, use tail_zero to zero what remains
+ CMP $16, R1
+ BLT tail_zero
+ // Get buffer offset into 16 byte aligned address for better performance
+ ANDS $15, R0, ZR
+ BNE unaligned_to_16
+aligned_to_16:
+ LSR $4, R1, R2
+zero_by_16:
+ STP.P (ZR, ZR), 16(R0) // Store pair with post index.
+ SUBS $1, R2, R2
+ BNE zero_by_16
+ ANDS $15, R1, R1
+ BEQ end
+
+ // Zero buffer with size=R1 < 16
+tail_zero:
+ TBZ $3, R1, tail_zero_4
+ MOVD.P ZR, 8(R0)
+tail_zero_4:
+ TBZ $2, R1, tail_zero_2
+ MOVW.P ZR, 4(R0)
+tail_zero_2:
+ TBZ $1, R1, tail_zero_1
+ MOVH.P ZR, 2(R0)
+tail_zero_1:
+ TBZ $0, R1, end
+ MOVB ZR, (R0)
+end:
+ RET
+
+unaligned_to_16:
+ MOVD R0, R2
+head_loop:
+ MOVBU.P ZR, 1(R0)
+ ANDS $15, R0, ZR
+ BNE head_loop
+ // Adjust length for what remains
+ SUB R2, R0, R3
+ SUB R3, R1
+ // If size is less than 16 bytes, use tail_zero to zero what remains
+ CMP $16, R1
+ BLT tail_zero
+ B aligned_to_16
diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s
new file mode 100644
index 000000000..00b46c18f
--- /dev/null
+++ b/pkg/safecopy/memcpy_amd64.s
@@ -0,0 +1,219 @@
+// Copyright © 1994-1999 Lucent Technologies Inc. All rights reserved.
+// Revisions Copyright © 2000-2007 Vita Nuova Holdings Limited (www.vitanuova.com). All rights reserved.
+// Portions Copyright 2009 The Go Authors. All rights reserved.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#include "textflag.h"
+
+// handleMemcpyFault returns (the value stored in AX, the value stored in DI).
+// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in AX and the signal number stored in DI.
+//
+// It must have the same frame configuration as memcpy so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemcpyFault(SB), NOSPLIT, $0-36
+ MOVQ AX, addr+24(FP)
+ MOVL DI, sig+32(FP)
+ RET
+
+// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received
+// during the copy, it returns the address that caused the fault and the number
+// of the signal that was received. Otherwise, it returns an unspecified address
+// and a signal number of 0.
+//
+// Data is copied in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully copied.
+//
+// The code is derived from the forward copying part of runtime.memmove.
+//
+// func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
+TEXT ·memcpy(SB), NOSPLIT, $0-36
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemcpyFault will store a different value in this address.
+ MOVL $0, sig+32(FP)
+
+ MOVQ to+0(FP), DI
+ MOVQ from+8(FP), SI
+ MOVQ n+16(FP), BX
+
+tail:
+ // BSR+branch table make almost all memmove/memclr benchmarks worse. Not
+ // worth doing.
+ TESTQ BX, BX
+ JEQ move_0
+ CMPQ BX, $2
+ JBE move_1or2
+ CMPQ BX, $4
+ JBE move_3or4
+ CMPQ BX, $8
+ JB move_5through7
+ JE move_8
+ CMPQ BX, $16
+ JBE move_9through16
+ CMPQ BX, $32
+ JBE move_17through32
+ CMPQ BX, $64
+ JBE move_33through64
+ CMPQ BX, $128
+ JBE move_65through128
+ CMPQ BX, $256
+ JBE move_129through256
+
+move_257plus:
+ SUBQ $256, BX
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU 128(SI), X8
+ MOVOU X8, 128(DI)
+ MOVOU 144(SI), X9
+ MOVOU X9, 144(DI)
+ MOVOU 160(SI), X10
+ MOVOU X10, 160(DI)
+ MOVOU 176(SI), X11
+ MOVOU X11, 176(DI)
+ MOVOU 192(SI), X12
+ MOVOU X12, 192(DI)
+ MOVOU 208(SI), X13
+ MOVOU X13, 208(DI)
+ MOVOU 224(SI), X14
+ MOVOU X14, 224(DI)
+ MOVOU 240(SI), X15
+ MOVOU X15, 240(DI)
+ CMPQ BX, $256
+ LEAQ 256(SI), SI
+ LEAQ 256(DI), DI
+ JGE move_257plus
+ JMP tail
+
+move_1or2:
+ MOVB (SI), AX
+ MOVB AX, (DI)
+ MOVB -1(SI)(BX*1), CX
+ MOVB CX, -1(DI)(BX*1)
+ RET
+move_0:
+ RET
+move_3or4:
+ MOVW (SI), AX
+ MOVW AX, (DI)
+ MOVW -2(SI)(BX*1), CX
+ MOVW CX, -2(DI)(BX*1)
+ RET
+move_5through7:
+ MOVL (SI), AX
+ MOVL AX, (DI)
+ MOVL -4(SI)(BX*1), CX
+ MOVL CX, -4(DI)(BX*1)
+ RET
+move_8:
+ // We need a separate case for 8 to make sure we write pointers atomically.
+ MOVQ (SI), AX
+ MOVQ AX, (DI)
+ RET
+move_9through16:
+ MOVQ (SI), AX
+ MOVQ AX, (DI)
+ MOVQ -8(SI)(BX*1), CX
+ MOVQ CX, -8(DI)(BX*1)
+ RET
+move_17through32:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU -16(SI)(BX*1), X1
+ MOVOU X1, -16(DI)(BX*1)
+ RET
+move_33through64:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU -32(SI)(BX*1), X2
+ MOVOU X2, -32(DI)(BX*1)
+ MOVOU -16(SI)(BX*1), X3
+ MOVOU X3, -16(DI)(BX*1)
+ RET
+move_65through128:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU -64(SI)(BX*1), X4
+ MOVOU X4, -64(DI)(BX*1)
+ MOVOU -48(SI)(BX*1), X5
+ MOVOU X5, -48(DI)(BX*1)
+ MOVOU -32(SI)(BX*1), X6
+ MOVOU X6, -32(DI)(BX*1)
+ MOVOU -16(SI)(BX*1), X7
+ MOVOU X7, -16(DI)(BX*1)
+ RET
+move_129through256:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU -128(SI)(BX*1), X8
+ MOVOU X8, -128(DI)(BX*1)
+ MOVOU -112(SI)(BX*1), X9
+ MOVOU X9, -112(DI)(BX*1)
+ MOVOU -96(SI)(BX*1), X10
+ MOVOU X10, -96(DI)(BX*1)
+ MOVOU -80(SI)(BX*1), X11
+ MOVOU X11, -80(DI)(BX*1)
+ MOVOU -64(SI)(BX*1), X12
+ MOVOU X12, -64(DI)(BX*1)
+ MOVOU -48(SI)(BX*1), X13
+ MOVOU X13, -48(DI)(BX*1)
+ MOVOU -32(SI)(BX*1), X14
+ MOVOU X14, -32(DI)(BX*1)
+ MOVOU -16(SI)(BX*1), X15
+ MOVOU X15, -16(DI)(BX*1)
+ RET
diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s
new file mode 100644
index 000000000..e7e541565
--- /dev/null
+++ b/pkg/safecopy/memcpy_arm64.s
@@ -0,0 +1,78 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleMemcpyFault returns (the value stored in R0, the value stored in R1).
+// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in R0 and the signal number stored in R1.
+//
+// It must have the same frame configuration as memcpy so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemcpyFault(SB), NOSPLIT, $0-36
+ MOVD R0, addr+24(FP)
+ MOVW R1, sig+32(FP)
+ RET
+
+// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received
+// during the copy, it returns the address that caused the fault and the number
+// of the signal that was received. Otherwise, it returns an unspecified address
+// and a signal number of 0.
+//
+// Data is copied in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully copied.
+//
+// The code is derived from the Go source runtime.memmove.
+//
+// func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
+TEXT ·memcpy(SB), NOSPLIT, $-8-36
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemcpyFault will store a different value in this address.
+ MOVW $0, sig+32(FP)
+
+ MOVD to+0(FP), R3
+ MOVD from+8(FP), R4
+ MOVD n+16(FP), R5
+ CMP $0, R5
+ BNE check
+ RET
+
+check:
+ AND $~7, R5, R7 // R7 is N&~7.
+ SUB R7, R5, R6 // R6 is N&7.
+
+ // Copying forward proceeds by copying R7/8 words then copying R6 bytes.
+ // R3 and R4 are advanced as we copy.
+
+ // (There may be implementations of armv8 where copying by bytes until
+ // at least one of source or dest is word aligned is a worthwhile
+	// optimization, but on the one tested so far (xgene) it did not
+	// make a significant difference.)
+
+ CMP $0, R7 // Do we need to do any word-by-word copying?
+ BEQ noforwardlarge
+ ADD R3, R7, R9 // R9 points just past where we copy by word.
+
+forwardlargeloop:
+ MOVD.P 8(R4), R8 // R8 is just a scratch register.
+ MOVD.P R8, 8(R3)
+ CMP R3, R9
+ BNE forwardlargeloop
+
+noforwardlarge:
+ CMP $0, R6 // Do we need to do any byte-by-byte copying?
+ BNE forwardtail
+ RET
+
+forwardtail:
+ ADD R3, R6, R9 // R9 points just past the destination memory.
+
+forwardtailloop:
+ MOVBU.P 1(R4), R8
+ MOVBU.P R8, 1(R3)
+ CMP R3, R9
+ BNE forwardtailloop
+ RET
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
new file mode 100644
index 000000000..2fb7e5809
--- /dev/null
+++ b/pkg/safecopy/safecopy.go
@@ -0,0 +1,144 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package safecopy provides an efficient implementation of functions to access
+// memory that may result in SIGSEGV or SIGBUS being sent to the accessor.
+package safecopy
+
+import (
+ "fmt"
+ "reflect"
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SegvError is returned when a safecopy function receives SIGSEGV.
+type SegvError struct {
+ // Addr is the address at which the SIGSEGV occurred.
+ Addr uintptr
+}
+
+// Error implements error.Error.
+func (e SegvError) Error() string {
+ return fmt.Sprintf("SIGSEGV at %#x", e.Addr)
+}
+
+// BusError is returned when a safecopy function receives SIGBUS.
+type BusError struct {
+ // Addr is the address at which the SIGBUS occurred.
+ Addr uintptr
+}
+
+// Error implements error.Error.
+func (e BusError) Error() string {
+ return fmt.Sprintf("SIGBUS at %#x", e.Addr)
+}
+
+// AlignmentError is returned when a safecopy function is passed an address
+// that does not meet alignment requirements.
+type AlignmentError struct {
+ // Addr is the invalid address.
+ Addr uintptr
+
+ // Alignment is the required alignment.
+ Alignment uintptr
+}
+
+// Error implements error.Error.
+func (e AlignmentError) Error() string {
+ return fmt.Sprintf("address %#x is not aligned to a %d-byte boundary", e.Addr, e.Alignment)
+}
+
+var (
+ // The begin and end addresses below are for the functions that are
+ // checked by the signal handler.
+ memcpyBegin uintptr
+ memcpyEnd uintptr
+ memclrBegin uintptr
+ memclrEnd uintptr
+ swapUint32Begin uintptr
+ swapUint32End uintptr
+ swapUint64Begin uintptr
+ swapUint64End uintptr
+ compareAndSwapUint32Begin uintptr
+ compareAndSwapUint32End uintptr
+ loadUint32Begin uintptr
+ loadUint32End uintptr
+
+ // savedSigSegVHandler is a pointer to the SIGSEGV handler that was
+ // configured before we replaced it with our own. We still call into it
+ // when we get a SIGSEGV that is not interesting to us.
+ savedSigSegVHandler uintptr
+
+	// Same as above, but for SIGBUS signals.
+ savedSigBusHandler uintptr
+)
+
+// signalHandler is our replacement signal handler for SIGSEGV and SIGBUS
+// signals.
+func signalHandler()
+
+// FindEndAddress returns the end address (one byte beyond the last) of the
+// function that contains the specified address (begin).
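+//
+// It works by walking forward one byte at a time until runtime.FuncForPC
+// reports a different function; this linear scan is cheap here because it is
+// only done once per function, at initialization.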
+func FindEndAddress(begin uintptr) uintptr {
+ f := runtime.FuncForPC(begin)
+ if f != nil {
+ for p := begin; ; p++ {
+ g := runtime.FuncForPC(p)
+ if f != g {
+ return p
+ }
+ }
+ }
+ return begin
+}
+
+// initializeAddresses initializes the addresses used by the signal handler.
+func initializeAddresses() {
+ // The following functions are written in assembly language, so they won't
+ // be inlined by the existing compiler/linker. Tests will fail if this
+ // assumption is violated.
+ memcpyBegin = reflect.ValueOf(memcpy).Pointer()
+ memcpyEnd = FindEndAddress(memcpyBegin)
+ memclrBegin = reflect.ValueOf(memclr).Pointer()
+ memclrEnd = FindEndAddress(memclrBegin)
+ swapUint32Begin = reflect.ValueOf(swapUint32).Pointer()
+ swapUint32End = FindEndAddress(swapUint32Begin)
+ swapUint64Begin = reflect.ValueOf(swapUint64).Pointer()
+ swapUint64End = FindEndAddress(swapUint64Begin)
+ compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer()
+ compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin)
+ loadUint32Begin = reflect.ValueOf(loadUint32).Pointer()
+ loadUint32End = FindEndAddress(loadUint32Begin)
+}
+
+func init() {
+ initializeAddresses()
+ if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
+ }
+ if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
+ }
+ syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) {
+ switch e.(type) {
+ case SegvError, BusError, AlignmentError:
+ return syscall.EFAULT, true
+ default:
+ return 0, false
+ }
+ })
+}
diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go
new file mode 100644
index 000000000..7f7f69d61
--- /dev/null
+++ b/pkg/safecopy/safecopy_test.go
@@ -0,0 +1,629 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safecopy
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "runtime/debug"
+ "syscall"
+ "testing"
+ "unsafe"
+)
+
+// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular
+// dependency.
+const pageSize = 4096
+
+func initRandom(b []byte) {
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+}
+
+func randBuf(size int) []byte {
+ b := make([]byte, size)
+ initRandom(b)
+ return b
+}
+
+func TestCopyInSuccess(t *testing.T) {
+ // Test that CopyIn does not return an error when all pages are accessible.
+ const bufLen = 8192
+ a := randBuf(bufLen)
+ b := make([]byte, bufLen)
+
+ n, err := CopyIn(b, unsafe.Pointer(&a[0]))
+ if n != bufLen {
+ t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if !bytes.Equal(a, b) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", a, b)
+ }
+}
+
+func TestCopyOutSuccess(t *testing.T) {
+ // Test that CopyOut does not return an error when all pages are
+ // accessible.
+ const bufLen = 8192
+ a := randBuf(bufLen)
+ b := make([]byte, bufLen)
+
+ n, err := CopyOut(unsafe.Pointer(&b[0]), a)
+ if n != bufLen {
+ t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if !bytes.Equal(a, b) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", a, b)
+ }
+}
+
+func TestCopySuccess(t *testing.T) {
+ // Test that Copy does not return an error when all pages are accessible.
+ const bufLen = 8192
+ a := randBuf(bufLen)
+ b := make([]byte, bufLen)
+
+ n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen)
+ if n != bufLen {
+ t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if !bytes.Equal(a, b) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", a, b)
+ }
+}
+
+func TestZeroOutSuccess(t *testing.T) {
+ // Test that ZeroOut does not return an error when all pages are
+ // accessible.
+ const bufLen = 8192
+ a := make([]byte, bufLen)
+ b := randBuf(bufLen)
+
+ n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen)
+ if n != bufLen {
+ t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if !bytes.Equal(a, b) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", a, b)
+ }
+}
+
+func TestSwapUint32Success(t *testing.T) {
+ // Test that SwapUint32 does not return an error when the page is
+ // accessible.
+ before := uint32(rand.Int31())
+ after := uint32(rand.Int31())
+ val := before
+
+ old, err := SwapUint32(unsafe.Pointer(&val), after)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if old != before {
+ t.Errorf("Unexpected old value: got %v, want %v", old, before)
+ }
+ if val != after {
+ t.Errorf("Unexpected new value: got %v, want %v", val, after)
+ }
+}
+
+func TestSwapUint32AlignmentError(t *testing.T) {
+ // Test that SwapUint32 returns an AlignmentError when passed an unaligned
+ // address.
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := SwapUint32(ptr, 1); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+}
+
+func TestSwapUint64Success(t *testing.T) {
+ // Test that SwapUint64 does not return an error when the page is
+ // accessible.
+ before := uint64(rand.Int63())
+ after := uint64(rand.Int63())
+ // "The first word in ... an allocated struct or slice can be relied upon
+ // to be 64-bit aligned." - sync/atomic docs
+ data := new(struct{ val uint64 })
+ data.val = before
+
+ old, err := SwapUint64(unsafe.Pointer(&data.val), after)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if old != before {
+ t.Errorf("Unexpected old value: got %v, want %v", old, before)
+ }
+ if data.val != after {
+ t.Errorf("Unexpected new value: got %v, want %v", data.val, after)
+ }
+}
+
+func TestSwapUint64AlignmentError(t *testing.T) {
+ // Test that SwapUint64 returns an AlignmentError when passed an unaligned
+ // address.
+ data := make([]byte, 16) // 2 * sizeof(uint64).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 8; offset != 0 {
+ alignedIndex = 8 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 8}
+ if _, err := SwapUint64(ptr, 1); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+}
+
+func TestCompareAndSwapUint32Success(t *testing.T) {
+ // Test that CompareAndSwapUint32 does not return an error when the page is
+ // accessible.
+ before := uint32(rand.Int31())
+ after := uint32(rand.Int31())
+ val := before
+
+ old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if old != before {
+ t.Errorf("Unexpected old value: got %v, want %v", old, before)
+ }
+ if val != after {
+ t.Errorf("Unexpected new value: got %v, want %v", val, after)
+ }
+}
+
+func TestCompareAndSwapUint32AlignmentError(t *testing.T) {
+ // Test that CompareAndSwapUint32 returns an AlignmentError when passed an
+ // unaligned address.
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := CompareAndSwapUint32(ptr, 0, 1); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+}
+
+// withSegvErrorTestMapping calls fn with a two-page mapping. The first page
+// contains random data, and the second page generates SIGSEGV when accessed.
+func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) {
+ mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+ }
+ defer syscall.Munmap(mapping)
+ if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil {
+ t.Fatalf("Mprotect failed: %v", err)
+ }
+ initRandom(mapping[:pageSize])
+
+ fn(mapping)
+}
+
+// withBusErrorTestMapping calls fn with a two-page mapping. The first page
+// contains random data, and the second page generates SIGBUS when accessed.
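+// (The file is truncated to a single page but mapped across two, so touching
+// the second page raises SIGBUS.)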
+func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) {
+ f, err := ioutil.TempFile("", "sigbus_test")
+ if err != nil {
+ t.Fatalf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ if err := f.Truncate(pageSize); err != nil {
+ t.Fatalf("Truncate failed: %v", err)
+ }
+ mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+ }
+ defer syscall.Munmap(mapping)
+ initRandom(mapping[:pageSize])
+
+ fn(mapping)
+}
+
+func TestCopyInSegvError(t *testing.T) {
+ // Test that CopyIn returns a SegvError when reaching a page that signals
+ // SIGSEGV.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ dst := randBuf(pageSize)
+ n, err := CopyIn(dst, src)
+ if n != bytesBeforeFault {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopyInBusError(t *testing.T) {
+ // Test that CopyIn returns a BusError when reaching a page that signals
+ // SIGBUS.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ dst := randBuf(pageSize)
+ n, err := CopyIn(dst, src)
+ if n != bytesBeforeFault {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopyOutSegvError(t *testing.T) {
+ // Test that CopyOut returns a SegvError when reaching a page that signals
+ // SIGSEGV.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ src := randBuf(pageSize)
+ n, err := CopyOut(dst, src)
+ if n != bytesBeforeFault {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopyOutBusError(t *testing.T) {
+ // Test that CopyOut returns a BusError when reaching a page that signals
+ // SIGBUS.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ src := randBuf(pageSize)
+ n, err := CopyOut(dst, src)
+ if n != bytesBeforeFault {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopySourceSegvError(t *testing.T) {
+ // Test that Copy returns a SegvError when copying from a page that signals
+ // SIGSEGV.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ dst := randBuf(pageSize)
+ n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopySourceBusError(t *testing.T) {
+ // Test that Copy returns a BusError when copying from a page that signals
+ // SIGBUS.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ dst := randBuf(pageSize)
+ n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopyDestinationSegvError(t *testing.T) {
+ // Test that Copy returns a SegvError when copying to a page that signals
+ // SIGSEGV.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ src := randBuf(pageSize)
+ n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopyDestinationBusError(t *testing.T) {
+ // Test that Copy returns a BusError when copying to a page that signals
+ // SIGBUS.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ src := randBuf(pageSize)
+ n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestZeroOutSegvError(t *testing.T) {
+ // Test that ZeroOut returns a SegvError when reaching a page that signals
+ // SIGSEGV.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ n, err := ZeroOut(dst, pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) {
+ t.Errorf("Non-zero bytes in written part of mapping: %v", got)
+ }
+ })
+ })
+ }
+}
+
+func TestZeroOutBusError(t *testing.T) {
+ // Test that ZeroOut returns a BusError when reaching a page that signals
+ // SIGBUS.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
+ n, err := ZeroOut(dst, pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) {
+ t.Errorf("Non-zero bytes in written part of mapping: %v", got)
+ }
+ })
+ })
+ }
+}
+
+func TestSwapUint32SegvError(t *testing.T) {
+ // Test that SwapUint32 returns a SegvError when reaching a page that
+ // signals SIGSEGV.
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ _, err := SwapUint32(unsafe.Pointer(secondPage), 1)
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestSwapUint32BusError(t *testing.T) {
+ // Test that SwapUint32 returns a BusError when reaching a page that
+ // signals SIGBUS.
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ _, err := SwapUint32(unsafe.Pointer(secondPage), 1)
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestSwapUint64SegvError(t *testing.T) {
+ // Test that SwapUint64 returns a SegvError when reaching a page that
+ // signals SIGSEGV.
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ _, err := SwapUint64(unsafe.Pointer(secondPage), 1)
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestSwapUint64BusError(t *testing.T) {
+ // Test that SwapUint64 returns a BusError when reaching a page that
+ // signals SIGBUS.
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ _, err := SwapUint64(unsafe.Pointer(secondPage), 1)
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestCompareAndSwapUint32SegvError(t *testing.T) {
+ // Test that CompareAndSwapUint32 returns a SegvError when reaching a page
+ // that signals SIGSEGV.
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestCompareAndSwapUint32BusError(t *testing.T) {
+ // Test that CompareAndSwapUint32 returns a BusError when reaching a page
+ // that signals SIGBUS.
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func testCopy(dst, src []byte) (panicked bool) {
+ defer func() {
+ if r := recover(); r != nil {
+ panicked = true
+ }
+ }()
+ debug.SetPanicOnFault(true)
+ copy(dst, src)
+ return
+}
+
+func TestSegVOnMemmove(t *testing.T) {
+ // Test that SIGSEGVs received by runtime.memmove when *not* doing
+	// CopyIn or CopyOut work get propagated to the runtime.
+ const bufLen = pageSize
+ a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+	}
+ defer syscall.Munmap(a)
+ b := randBuf(bufLen)
+
+ if !testCopy(b, a) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+
+ if !testCopy(a, b) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+}
+
+func TestSigbusOnMemmove(t *testing.T) {
+	// Test that a SIGBUS received by runtime.memmove when *not* doing
+ // CopyIn or CopyOut work gets propagated to the runtime.
+ const bufLen = pageSize
+ f, err := ioutil.TempFile("", "sigbus_test")
+ if err != nil {
+ t.Fatalf("TempFile failed: %v", err)
+ }
+ os.Remove(f.Name())
+ defer f.Close()
+
+ a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+	}
+ defer syscall.Munmap(a)
+ b := randBuf(bufLen)
+
+ if !testCopy(b, a) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+
+ if !testCopy(a, b) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+}
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
new file mode 100644
index 000000000..41dd567f3
--- /dev/null
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -0,0 +1,361 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safecopy
+
+import (
+ "fmt"
+ "runtime"
+ "syscall"
+ "unsafe"
+)
+
+// maxRegisterSize is the maximum register size used in memcpy and memclr. It
+// is used to decide by how much to rewind the copy (for memcpy) or zeroing
+// (for memclr) before proceeding.
+const maxRegisterSize = 16
+
+// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received
+// during the copy, it returns the address that caused the fault and the number
+// of the signal that was received. Otherwise, it returns an unspecified address
+// and a signal number of 0.
+//
+// Data is copied in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully copied.
+//
+//go:noescape
+func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
+
+// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
+// signal is received during the write, it returns the address that caused the
+// fault and the number of the signal that was received. Otherwise, it returns
+// an unspecified address and a signal number of 0.
+//
+// Data is written in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully written.
+//
+//go:noescape
+func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
+
+// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+
+// swapUint64 atomically stores new into *ptr and returns (the previous *ptr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: ptr must be aligned to an 8-byte boundary.
+//
+//go:noescape
+func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+
+// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is
+// received during the operation, the value of prev is unspecified, and sig is
+// the number of the signal that was received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+
+// loadUint32 atomically loads *ptr and returns (the loaded value, 0). If a
+// SIGSEGV or SIGBUS signal is received during the load, the value of val is
+// unspecified, and sig is the number of the signal that was received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+
+// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
+// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
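+//
+// A minimal usage sketch (addr is a hypothetical user address; error handling
+// is elided):
+//
+//	var buf [8]byte
+//	n, err := CopyIn(buf[:], unsafe.Pointer(addr))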
+func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
+ n, err := copyIn(dst, uintptr(src))
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyIn is the underlying definition for CopyIn.
+func copyIn(dst []byte, src uintptr) (int, error) {
+ toCopy := uintptr(len(dst))
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(uintptr(unsafe.Pointer(&dst[0])), src, toCopy)
+ if sig == 0 {
+ return len(dst), nil
+ }
+
+ if fault < src || fault >= src+toCopy {
+ panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, fault, src, src+toCopy))
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
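+	// For example, with maxRegisterSize = 16, a fault at src+100 means that
+	// bytes [0, 84) are known good, so we retry the copy from offset 84.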
+ var done int
+ if fault-src > maxRegisterSize {
+ done = int(fault - src - maxRegisterSize)
+ }
+ n, err := copyIn(dst[done:int(fault-src)], src+uintptr(done))
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// CopyOut copies len(src) bytes from src to dst. It returns the number of
+// bytes copied and an error if SIGSEGV or SIGBUS is received while writing to
+// dst.
+func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
+ n, err := copyOut(uintptr(dst), src)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// copyOut is the underlying definition for CopyOut.
+func copyOut(dst uintptr, src []byte) (int, error) {
+ toCopy := uintptr(len(src))
+ if toCopy == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(dst, uintptr(unsafe.Pointer(&src[0])), toCopy)
+ if sig == 0 {
+ return len(src), nil
+ }
+
+ if fault < dst || fault >= dst+toCopy {
+ panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toCopy))
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ var done int
+ if fault-dst > maxRegisterSize {
+ done = int(fault - dst - maxRegisterSize)
+ }
+ n, err := copyOut(dst+uintptr(done), src[done:int(fault-dst)])
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// Copy copies toCopy bytes from src to dst. It returns the number of bytes
+// copied and an error if SIGSEGV or SIGBUS is received while reading from src
+// or writing to dst.
+//
+// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
+// the resulting contents of dst are unspecified.
+func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
+ n, err := copyN(uintptr(dst), uintptr(src), toCopy)
+ runtime.KeepAlive(dst)
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyN is the underlying definition for Copy.
+func copyN(dst, src uintptr, toCopy uintptr) (uintptr, error) {
+ if toCopy == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(dst, src, toCopy)
+ if sig == 0 {
+ return toCopy, nil
+ }
+
+ // Did the fault occur while reading from src or writing to dst?
+ faultAfterSrc := ^uintptr(0)
+ if fault >= src {
+ faultAfterSrc = fault - src
+ }
+ faultAfterDst := ^uintptr(0)
+ if fault >= dst {
+ faultAfterDst = fault - dst
+ }
+ if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
+ panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, fault, src, src+toCopy, dst, dst+toCopy))
+ }
+ faultedAfter := faultAfterSrc
+ if faultedAfter > faultAfterDst {
+ faultedAfter = faultAfterDst
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ var done uintptr
+ if faultedAfter > maxRegisterSize {
+ done = faultedAfter - maxRegisterSize
+ }
+ n, err := copyN(dst+done, src+done, faultedAfter-done)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
+// written and an error if SIGSEGV or SIGBUS is received while writing to dst.
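+//
+// A minimal usage sketch (addr is a hypothetical user address):
+//
+//	n, err := ZeroOut(unsafe.Pointer(addr), 4096)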
+func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
+ n, err := zeroOut(uintptr(dst), toZero)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// zeroOut is the underlying definition for ZeroOut.
+func zeroOut(dst uintptr, toZero uintptr) (uintptr, error) {
+ if toZero == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memclr(dst, toZero)
+ if sig == 0 {
+ return toZero, nil
+ }
+
+ if fault < dst || fault >= dst+toZero {
+ panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toZero))
+ }
+
+ // memclr might have ended the write up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to write up to the fault.
+ var done uintptr
+ if fault-dst > maxRegisterSize {
+ done = fault - dst - maxRegisterSize
+ }
+ n, err := zeroOut(dst+done, fault-dst-done)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns
+// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
+// not aligned to a 4-byte boundary.
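+//
+// A minimal usage sketch (addr is a hypothetical user address):
+//
+//	old, err := SwapUint32(unsafe.Pointer(addr), 1)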
+func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ old, sig := swapUint32(ptr, new)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
+}
+
+// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
+// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
+// not aligned to an 8-byte boundary.
+func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
+ if addr := uintptr(ptr); addr&7 != 0 {
+ return 0, AlignmentError{addr, 8}
+ }
+ old, sig := swapUint64(ptr, new)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
+}
+
+// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
+// except that it returns an error if SIGSEGV or SIGBUS is received while
+// accessing ptr, or if ptr is not aligned to a 4-byte boundary.
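+//
+// A minimal usage sketch (addr is a hypothetical user address):
+//
+//	prev, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1)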
+func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ prev, sig := compareAndSwapUint32(ptr, old, new)
+ return prev, errorFromFaultSignal(uintptr(ptr), sig)
+}
+
+// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory.
+// It returns an error if SIGSEGV or SIGBUS is received while reading from
+// ptr, or if ptr is not aligned to a 4-byte boundary.
+func LoadUint32(ptr unsafe.Pointer) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ val, sig := loadUint32(ptr)
+ return val, errorFromFaultSignal(uintptr(ptr), sig)
+}
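+
+// Because these helpers check alignment before touching memory, a caller can
+// distinguish a misaligned pointer from a genuine fault by the error's type.
+// A short sketch, with ptr a hypothetical unsafe.Pointer into user memory:
+//
+//	if _, err := SwapUint32(ptr, 1); err != nil {
+//		switch err.(type) {
+//		case AlignmentError:
+//			// ptr was not 4-byte aligned; no access was attempted.
+//		case SegvError, BusError:
+//			// The access faulted.
+//		}
+//	}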
+
+func errorFromFaultSignal(addr uintptr, sig int32) error {
+ switch sig {
+ case 0:
+ return nil
+ case int32(syscall.SIGSEGV):
+ return SegvError{addr}
+ case int32(syscall.SIGBUS):
+ return BusError{addr}
+ default:
+ panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
+ }
+}
+
+// ReplaceSignalHandler replaces the existing signal handler for the provided
+// signal with the one that handles faults in safecopy-protected functions.
+//
+// It stores the value of the previously set handler in previous.
+//
+// This function will be called on initialization in order to install safecopy
+// handlers for the appropriate signals. These handlers fall back to the
+// previously installed handler; if this function is used externally, the same
+// courtesy is expected.
+func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error {
+ var sa struct {
+ handler uintptr
+ flags uint64
+ restorer uintptr
+ mask uint64
+ }
+ const maskLen = 8
+
+ // Get the existing signal handler information, and save the current
+ // handler. Once we replace it, we will use this pointer to fall back to
+ // it when we receive other signals.
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ // Fail if there isn't a previous handler.
+ if sa.handler == 0 {
+ return fmt.Errorf("previous handler for signal %x isn't set", sig)
+ }
+
+ *previous = sa.handler
+
+ // Install our own handler.
+ sa.handler = handler
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
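+
+// A plausible initialization sketch (reflect is assumed imported), assuming
+// signalHandler is the assembly handler defined in sighandler_*.s and that
+// savedSigSegVHandler and savedSigBusHandler are the saved-handler words it
+// falls back to:
+//
+//	handler := reflect.ValueOf(signalHandler).Pointer()
+//	if err := ReplaceSignalHandler(syscall.SIGSEGV, handler, &savedSigSegVHandler); err != nil {
+//		panic(fmt.Sprintf("installing SIGSEGV handler: %v", err))
+//	}
+//	if err := ReplaceSignalHandler(syscall.SIGBUS, handler, &savedSigBusHandler); err != nil {
+//		panic(fmt.Sprintf("installing SIGBUS handler: %v", err))
+//	}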
diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s
new file mode 100644
index 000000000..475ae48e9
--- /dev/null
+++ b/pkg/safecopy/sighandler_amd64.s
@@ -0,0 +1,133 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// The signals handled by signalHandler.
+#define SIGBUS 7
+#define SIGSEGV 11
+
+// Offsets to the registers in context->uc_mcontext.gregs[].
+#define REG_RDI 0x68
+#define REG_RAX 0x90
+#define REG_IP 0xa8
+
+// Offsets to the si_code and si_addr fields of siginfo.
+#define SI_CODE 0x08
+#define SI_ADDR 0x10
+
+// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must
+// not be set up as a handler to any other signals.
+//
+// If the instruction causing the signal is within a safecopy-protected
+// function, the signal is handled such that execution resumes in the
+// appropriate fault handling stub with AX containing the faulting address and
+// DI containing the signal number. Otherwise control is transferred to the
+// previously configured signal handler (savedSigSegVHandler or
+// savedSigBusHandler).
+//
+// This function cannot be written in Go because it runs whenever a signal is
+// received by the thread (preempting whatever was running), which includes
+// when the garbage collector has stopped or isn't expecting any interactions
+// (like barriers).
+//
+// The arguments are the following:
+// DI - The signal number.
+// SI - Pointer to siginfo_t structure.
+// DX - Pointer to ucontext structure.
+TEXT ·signalHandler(SB),NOSPLIT,$0
+	// Check if the signal is from the kernel; si_code > 0 means a kernel
+	// signal.
+ MOVQ $0x0, CX
+ CMPL CX, SI_CODE(SI)
+ JGE original_handler
+
+ // Check if RIP is within the area we care about.
+ MOVQ REG_IP(DX), CX
+ CMPQ CX, ·memcpyBegin(SB)
+ JB not_memcpy
+ CMPQ CX, ·memcpyEnd(SB)
+ JAE not_memcpy
+
+ // Modify the context such that execution will resume in the fault
+ // handler.
+ LEAQ handleMemcpyFault(SB), CX
+ JMP handle_fault
+
+not_memcpy:
+ CMPQ CX, ·memclrBegin(SB)
+ JB not_memclr
+ CMPQ CX, ·memclrEnd(SB)
+ JAE not_memclr
+
+ LEAQ handleMemclrFault(SB), CX
+ JMP handle_fault
+
+not_memclr:
+ CMPQ CX, ·swapUint32Begin(SB)
+ JB not_swapuint32
+ CMPQ CX, ·swapUint32End(SB)
+ JAE not_swapuint32
+
+ LEAQ handleSwapUint32Fault(SB), CX
+ JMP handle_fault
+
+not_swapuint32:
+ CMPQ CX, ·swapUint64Begin(SB)
+ JB not_swapuint64
+ CMPQ CX, ·swapUint64End(SB)
+ JAE not_swapuint64
+
+ LEAQ handleSwapUint64Fault(SB), CX
+ JMP handle_fault
+
+not_swapuint64:
+ CMPQ CX, ·compareAndSwapUint32Begin(SB)
+ JB not_casuint32
+ CMPQ CX, ·compareAndSwapUint32End(SB)
+ JAE not_casuint32
+
+ LEAQ handleCompareAndSwapUint32Fault(SB), CX
+ JMP handle_fault
+
+not_casuint32:
+ CMPQ CX, ·loadUint32Begin(SB)
+ JB not_loaduint32
+ CMPQ CX, ·loadUint32End(SB)
+ JAE not_loaduint32
+
+ LEAQ handleLoadUint32Fault(SB), CX
+ JMP handle_fault
+
+not_loaduint32:
+original_handler:
+ // Jump to the previous signal handler, which is likely the golang one.
+ XORQ CX, CX
+ MOVQ ·savedSigBusHandler(SB), AX
+ CMPL DI, $SIGSEGV
+ CMOVQEQ ·savedSigSegVHandler(SB), AX
+ JMP AX
+
+handle_fault:
+ // Entered with the address of the fault handler in RCX; store it in
+ // RIP.
+ MOVQ CX, REG_IP(DX)
+
+ // Store the faulting address in RAX.
+ MOVQ SI_ADDR(SI), CX
+ MOVQ CX, REG_RAX(DX)
+
+ // Store the signal number in EDI.
+ MOVL DI, REG_RDI(DX)
+
+ RET
diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s
new file mode 100644
index 000000000..53e4ac2c1
--- /dev/null
+++ b/pkg/safecopy/sighandler_arm64.s
@@ -0,0 +1,143 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// The signals handled by signalHandler.
+#define SIGBUS 7
+#define SIGSEGV 11
+
+// Offsets to the registers in context->uc_mcontext.gregs[].
+#define REG_R0 0xB8
+#define REG_R1 0xC0
+#define REG_PC 0x1B8
+
+// Offsets to the si_code and si_addr fields of siginfo.
+#define SI_CODE 0x08
+#define SI_ADDR 0x10
+
+// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must
+// not be set up as a handler to any other signals.
+//
+// If the instruction causing the signal is within a safecopy-protected
+// function, the signal is handled such that execution resumes in the
+// appropriate fault handling stub with R0 containing the faulting address and
+// R1 containing the signal number. Otherwise control is transferred to the
+// previously configured signal handler (savedSigSegVHandler or
+// savedSigBusHandler).
+//
+// This function cannot be written in Go because it runs whenever a signal is
+// received by the thread (preempting whatever was running), which includes
+// when the garbage collector has stopped or isn't expecting any interactions
+// (like barriers).
+//
+// The arguments are the following:
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+TEXT ·signalHandler(SB),NOSPLIT,$0
+	// Check if the signal is from the kernel; si_code > 0 means a kernel signal.
+ MOVD SI_CODE(R1), R7
+ CMPW $0x0, R7
+ BLE original_handler
+
+ // Check if PC is within the area we care about.
+ MOVD REG_PC(R2), R7
+ MOVD ·memcpyBegin(SB), R8
+ CMP R8, R7
+ BLO not_memcpy
+ MOVD ·memcpyEnd(SB), R8
+ CMP R8, R7
+ BHS not_memcpy
+
+ // Modify the context such that execution will resume in the fault handler.
+ MOVD $handleMemcpyFault(SB), R7
+ B handle_fault
+
+not_memcpy:
+ MOVD ·memclrBegin(SB), R8
+ CMP R8, R7
+ BLO not_memclr
+ MOVD ·memclrEnd(SB), R8
+ CMP R8, R7
+ BHS not_memclr
+
+ MOVD $handleMemclrFault(SB), R7
+ B handle_fault
+
+not_memclr:
+ MOVD ·swapUint32Begin(SB), R8
+ CMP R8, R7
+ BLO not_swapuint32
+ MOVD ·swapUint32End(SB), R8
+ CMP R8, R7
+ BHS not_swapuint32
+
+ MOVD $handleSwapUint32Fault(SB), R7
+ B handle_fault
+
+not_swapuint32:
+ MOVD ·swapUint64Begin(SB), R8
+ CMP R8, R7
+ BLO not_swapuint64
+ MOVD ·swapUint64End(SB), R8
+ CMP R8, R7
+ BHS not_swapuint64
+
+ MOVD $handleSwapUint64Fault(SB), R7
+ B handle_fault
+
+not_swapuint64:
+ MOVD ·compareAndSwapUint32Begin(SB), R8
+ CMP R8, R7
+ BLO not_casuint32
+ MOVD ·compareAndSwapUint32End(SB), R8
+ CMP R8, R7
+ BHS not_casuint32
+
+ MOVD $handleCompareAndSwapUint32Fault(SB), R7
+ B handle_fault
+
+not_casuint32:
+ MOVD ·loadUint32Begin(SB), R8
+ CMP R8, R7
+ BLO not_loaduint32
+ MOVD ·loadUint32End(SB), R8
+ CMP R8, R7
+ BHS not_loaduint32
+
+ MOVD $handleLoadUint32Fault(SB), R7
+ B handle_fault
+
+not_loaduint32:
+original_handler:
+ // Jump to the previous signal handler, which is likely the golang one.
+ MOVD ·savedSigBusHandler(SB), R7
+ MOVD ·savedSigSegVHandler(SB), R8
+ CMPW $SIGSEGV, R0
+ CSEL EQ, R8, R7, R7
+ B (R7)
+
+handle_fault:
+ // Entered with the address of the fault handler in R7; store it in PC.
+ MOVD R7, REG_PC(R2)
+
+ // Store the faulting address in R0.
+ MOVD SI_ADDR(R1), R7
+ MOVD R7, REG_R0(R2)
+
+ // Store the signal number in R1.
+ MOVW R0, REG_R1(R2)
+
+ RET
diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD
new file mode 100644
index 000000000..ce30382ab
--- /dev/null
+++ b/pkg/safemem/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "safemem",
+ srcs = [
+ "block_unsafe.go",
+ "io.go",
+ "safemem.go",
+ "seq_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/safecopy",
+ ],
+)
+
+go_test(
+ name = "safemem_test",
+ size = "small",
+ srcs = [
+ "io_test.go",
+ "seq_test.go",
+ ],
+ library = ":safemem",
+)
diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
new file mode 100644
index 000000000..e7fd30743
--- /dev/null
+++ b/pkg/safemem/block_unsafe.go
@@ -0,0 +1,279 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "fmt"
+ "reflect"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+)
+
+// A Block is a range of contiguous bytes, similar to []byte but with the
+// following differences:
+//
+// - The memory represented by a Block may require the use of safecopy to
+// access.
+//
+// - Block does not carry a capacity and cannot be expanded.
+//
+// Blocks are immutable and may be copied by value. The zero value of Block
+// represents an empty range, analogous to a nil []byte.
+type Block struct {
+ // [start, start+length) is the represented memory.
+ //
+ // start is an unsafe.Pointer to ensure that Block prevents the represented
+ // memory from being garbage-collected.
+ start unsafe.Pointer
+ length int
+
+ // needSafecopy is true if accessing the represented memory requires the
+ // use of safecopy.
+ needSafecopy bool
+}
+
+// BlockFromSafeSlice returns a Block equivalent to slice, which is safe to
+// access without safecopy.
+func BlockFromSafeSlice(slice []byte) Block {
+ return blockFromSlice(slice, false)
+}
+
+// BlockFromUnsafeSlice returns a Block equivalent to bs, which is not safe to
+// access without safecopy.
+func BlockFromUnsafeSlice(slice []byte) Block {
+ return blockFromSlice(slice, true)
+}
+
+func blockFromSlice(slice []byte, needSafecopy bool) Block {
+ if len(slice) == 0 {
+ return Block{}
+ }
+ return Block{
+ start: unsafe.Pointer(&slice[0]),
+ length: len(slice),
+ needSafecopy: needSafecopy,
+ }
+}
+
+// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is
+// safe to access without safecopy.
+//
+// Preconditions: ptr+len does not overflow.
+func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block {
+ return blockFromPointer(ptr, len, false)
+}
+
+// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which
+// is not safe to access without safecopy.
+//
+// Preconditions: ptr+len does not overflow.
+func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block {
+ return blockFromPointer(ptr, len, true)
+}
+
+func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block {
+ if uptr := uintptr(ptr); uptr+uintptr(len) < uptr {
+ panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len))
+ }
+ return Block{
+ start: ptr,
+ length: len,
+ needSafecopy: needSafecopy,
+ }
+}
+
+// DropFirst returns a Block equivalent to b, but with the first n bytes
+// omitted. It is analogous to the [n:] operation on a slice, except that if n
+// > b.Len(), DropFirst returns an empty Block instead of panicking.
+//
+// Preconditions: n >= 0.
+func (b Block) DropFirst(n int) Block {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return b.DropFirst64(uint64(n))
+}
+
+// DropFirst64 is equivalent to DropFirst but takes a uint64.
+func (b Block) DropFirst64(n uint64) Block {
+ if n >= uint64(b.length) {
+ return Block{}
+ }
+ return Block{
+ start: unsafe.Pointer(uintptr(b.start) + uintptr(n)),
+ length: b.length - int(n),
+ needSafecopy: b.needSafecopy,
+ }
+}
+
+// TakeFirst returns a Block equivalent to the first n bytes of b. It is
+// analogous to the [:n] operation on a slice, except that if n > b.Len(),
+// TakeFirst returns a copy of b instead of panicking.
+//
+// Preconditions: n >= 0.
+func (b Block) TakeFirst(n int) Block {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return b.TakeFirst64(uint64(n))
+}
+
+// TakeFirst64 is equivalent to TakeFirst but takes a uint64.
+func (b Block) TakeFirst64(n uint64) Block {
+ if n == 0 {
+ return Block{}
+ }
+ if n >= uint64(b.length) {
+ return b
+ }
+ return Block{
+ start: b.start,
+ length: int(n),
+ needSafecopy: b.needSafecopy,
+ }
+}
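+
+// DropFirst and TakeFirst compose like re-slicing. A small sketch over
+// ordinary Go memory, so no safecopy is involved:
+//
+//	b := BlockFromSafeSlice(make([]byte, 16))
+//	window := b.DropFirst(4).TakeFirst(8) // bytes [4, 12) of the slice
+//	fmt.Println(window.Len())             // 8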
+
+// ToSlice returns a []byte equivalent to b.
+func (b Block) ToSlice() []byte {
+ var bs []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ hdr.Data = uintptr(b.start)
+ hdr.Len = b.length
+ hdr.Cap = b.length
+ return bs
+}
+
+// Addr returns b's start address as a uintptr. It returns uintptr instead of
+// unsafe.Pointer so that code using safemem cannot obtain unsafe.Pointers
+// without importing the unsafe package explicitly.
+//
+// Note that a uintptr is not recognized as a pointer by the garbage collector,
+// such that if there are no uses of b after a call to b.Addr() and the address
+// is to Go-managed memory, the returned uintptr does not prevent garbage
+// collection of the pointee.
+func (b Block) Addr() uintptr {
+ return uintptr(b.start)
+}
+
+// Len returns b's length in bytes.
+func (b Block) Len() int {
+ return b.length
+}
+
+// NeedSafecopy returns true if accessing b.ToSlice() requires the use of safecopy.
+func (b Block) NeedSafecopy() bool {
+ return b.needSafecopy
+}
+
+// String implements fmt.Stringer.String.
+func (b Block) String() string {
+ if uintptr(b.start) == 0 && b.length == 0 {
+ return "<nil>"
+ }
+ var suffix string
+ if b.needSafecopy {
+ suffix = "*"
+ }
+ return fmt.Sprintf("[%#x-%#x)%s", uintptr(b.start), uintptr(b.start)+uintptr(b.length), suffix)
+}
+
+// Copy copies src.Len() or dst.Len() bytes, whichever is less, from src
+// to dst and returns the number of bytes copied.
+//
+// If src and dst overlap, the data stored in dst is unspecified.
+func Copy(dst, src Block) (int, error) {
+ if !dst.needSafecopy && !src.needSafecopy {
+ return copy(dst.ToSlice(), src.ToSlice()), nil
+ }
+
+ n := dst.length
+ if n > src.length {
+ n = src.length
+ }
+ if n == 0 {
+ return 0, nil
+ }
+
+ switch {
+ case dst.needSafecopy && !src.needSafecopy:
+ return safecopy.CopyOut(dst.start, src.TakeFirst(n).ToSlice())
+ case !dst.needSafecopy && src.needSafecopy:
+ return safecopy.CopyIn(dst.TakeFirst(n).ToSlice(), src.start)
+ case dst.needSafecopy && src.needSafecopy:
+ n64, err := safecopy.Copy(dst.start, src.start, uintptr(n))
+ return int(n64), err
+ default:
+ panic("unreachable")
+ }
+}
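+
+// A hypothetical round trip through Copy with two safe Blocks; Copy stops at
+// the shorter of the two:
+//
+//	src := BlockFromSafeSlice([]byte("hello"))
+//	dst := BlockFromSafeSlice(make([]byte, 8))
+//	n, err := Copy(dst, src)
+//	fmt.Println(n, err) // 5 <nil>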
+
+// Zero sets all bytes in dst to 0 and returns the number of bytes zeroed.
+func Zero(dst Block) (int, error) {
+ if !dst.needSafecopy {
+ bs := dst.ToSlice()
+ for i := range bs {
+ bs[i] = 0
+ }
+ return len(bs), nil
+ }
+
+ n64, err := safecopy.ZeroOut(dst.start, uintptr(dst.length))
+ return int(n64), err
+}
+
+// Safecopy atomics are no slower than non-safecopy atomics, so use the former
+// even when !b.needSafecopy to get consistent alignment checking.
+
+// SwapUint32 invokes safecopy.SwapUint32 on the first 4 bytes of b.
+//
+// Preconditions: b.Len() >= 4.
+func SwapUint32(b Block, new uint32) (uint32, error) {
+ if b.length < 4 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.SwapUint32(b.start, new)
+}
+
+// SwapUint64 invokes safecopy.SwapUint64 on the first 8 bytes of b.
+//
+// Preconditions: b.Len() >= 8.
+func SwapUint64(b Block, new uint64) (uint64, error) {
+ if b.length < 8 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.SwapUint64(b.start, new)
+}
+
+// CompareAndSwapUint32 invokes safecopy.CompareAndSwapUint32 on the first 4
+// bytes of b.
+//
+// Preconditions: b.Len() >= 4.
+func CompareAndSwapUint32(b Block, old, new uint32) (uint32, error) {
+ if b.length < 4 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.CompareAndSwapUint32(b.start, old, new)
+}
+
+// LoadUint32 invokes safecopy.LoadUint32 on the first 4 bytes of b.
+//
+// Preconditions: b.Len() >= 4.
+func LoadUint32(b Block) (uint32, error) {
+ if b.length < 4 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.LoadUint32(b.start)
+}
diff --git a/pkg/safemem/io.go b/pkg/safemem/io.go
new file mode 100644
index 000000000..f039a5c34
--- /dev/null
+++ b/pkg/safemem/io.go
@@ -0,0 +1,392 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "errors"
+ "io"
+ "math"
+)
+
+// ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write
+// beyond the end of the BlockSeq.
+var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq")
+
+// Reader represents a streaming byte source like io.Reader.
+type Reader interface {
+ // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the
+ // number of bytes read. It may return a partial read without an error
+ // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a
+ // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil);
+ // note that this differs from io.Reader.Read (in particular, io.EOF should
+ // not be returned if ReadToBlocks successfully reads dsts.NumBytes()
+ // bytes.)
+ ReadToBlocks(dsts BlockSeq) (uint64, error)
+}
+
+// Writer represents a streaming byte sink like io.Writer.
+type Writer interface {
+ // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
+ // the number of bytes written. It may return a partial write without an
+ // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
+	// return a full write with an error (i.e. (srcs.NumBytes(), err) where
+	// err != nil).
+ WriteFromBlocks(srcs BlockSeq) (uint64, error)
+}
+
+// ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes()
+// bytes have been read or ReadToBlocks returns an error.
+func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() {
+ n, err := r.ReadToBlocks(dsts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ dsts = dsts.DropFirst64(n)
+ }
+ return done, nil
+}
+
+// WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until
+// srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error.
+func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) {
+ var done uint64
+ for !srcs.IsEmpty() {
+ n, err := w.WriteFromBlocks(srcs)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ srcs = srcs.DropFirst64(n)
+ }
+ return done, nil
+}
+
+// BlockSeqReader implements Reader by reading from a BlockSeq.
+type BlockSeqReader struct {
+ Blocks BlockSeq
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ n, err := CopySeq(dsts, r.Blocks)
+ r.Blocks = r.Blocks.DropFirst64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < dsts.NumBytes() {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// BlockSeqWriter implements Writer by writing to a BlockSeq.
+type BlockSeqWriter struct {
+ Blocks BlockSeq
+}
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ n, err := CopySeq(w.Blocks, srcs)
+ w.Blocks = w.Blocks.DropFirst64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < srcs.NumBytes() {
+ return n, ErrEndOfBlockSeq
+ }
+ return n, nil
+}
+
+// ReaderFunc implements Reader for a function with the semantics of
+// Reader.ReadToBlocks.
+type ReaderFunc func(dsts BlockSeq) (uint64, error)
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ return f(dsts)
+}
+
+// WriterFunc implements Writer for a function with the semantics of
+// Writer.WriteFromBlocks.
+type WriterFunc func(srcs BlockSeq) (uint64, error)
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ return f(srcs)
+}
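+
+// ReaderFunc makes it cheap to adapt a closure as a Reader. A sketch of a
+// reader that zero-fills its destinations using ZeroSeq (dsts is a
+// hypothetical BlockSeq):
+//
+//	zeroes := ReaderFunc(func(dsts BlockSeq) (uint64, error) {
+//		return ZeroSeq(dsts)
+//	})
+//	n, err := zeroes.ReadToBlocks(dsts)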
+
+// ToIOReader implements io.Reader for a (safemem.)Reader.
+//
+// ToIOReader will return a successful partial read iff Reader.ReadToBlocks does
+// so.
+type ToIOReader struct {
+ Reader Reader
+}
+
+// Read implements io.Reader.Read.
+func (r ToIOReader) Read(dst []byte) (int, error) {
+ n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst)))
+ return int(n), err
+}
+
+// ToIOWriter implements io.Writer for a (safemem.)Writer.
+type ToIOWriter struct {
+ Writer Writer
+}
+
+// Write implements io.Writer.Write.
+func (w ToIOWriter) Write(src []byte) (int, error) {
+ // io.Writer does not permit partial writes.
+ n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src)))
+ return int(n), err
+}
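+
+// Because ToIOWriter satisfies io.Writer, a safemem.Writer can be handed to
+// any io helper. A sketch, assuming w is some Writer and strings is imported:
+//
+//	n, err := io.Copy(ToIOWriter{Writer: w}, strings.NewReader("data"))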
+
+// FromIOReader implements Reader for an io.Reader by repeatedly invoking
+// io.Reader.Read until it returns an error or partial read. This is not
+// thread-safe.
+//
+// FromIOReader will return a successful partial read iff Reader.Read does so.
+type FromIOReader struct {
+ Reader io.Reader
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ var buf []byte
+ var done uint64
+ for !dsts.IsEmpty() {
+ dst := dsts.Head()
+ var n int
+ var err error
+ n, buf, err = r.readToBlock(dst, buf)
+ done += uint64(n)
+ if n != dst.Len() {
+ return done, err
+ }
+ dsts = dsts.Tail()
+ if err != nil {
+ if dsts.IsEmpty() && err == io.EOF {
+ return done, nil
+ }
+ return done, err
+ }
+ }
+ return done, nil
+}
+
+func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) {
+ // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require
+ // safecopy.
+ if !dst.NeedSafecopy() {
+ n, err := r.Reader.Read(dst.ToSlice())
+ return n, buf, err
+ }
+ if len(buf) < dst.Len() {
+ buf = make([]byte, dst.Len())
+ }
+ rn, rerr := r.Reader.Read(buf[:dst.Len()])
+ wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
+ if wberr != nil {
+ return wbn, buf, wberr
+ }
+ return wbn, buf, rerr
+}
+
+// FromIOReaderAt implements Reader for an io.ReaderAt. It does not repeatedly
+// invoke io.ReaderAt.ReadAt, because ReadAt is stricter than Read: a partial
+// read always indicates an error. This is not thread-safe.
+type FromIOReaderAt struct {
+ ReaderAt io.ReaderAt
+ Offset int64
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ var buf []byte
+ var done uint64
+ for !dsts.IsEmpty() {
+ dst := dsts.Head()
+ var n int
+ var err error
+ n, buf, err = r.readToBlock(dst, buf)
+ done += uint64(n)
+ if n != dst.Len() {
+ return done, err
+ }
+ dsts = dsts.Tail()
+ if err != nil {
+ if dsts.IsEmpty() && err == io.EOF {
+ return done, nil
+ }
+ return done, err
+ }
+ }
+ return done, nil
+}
+
+// readToBlock takes a pointer receiver so that r.Offset advances across the
+// blocks of a single ReadToBlocks call.
+func (r *FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) {
+ // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require
+ // safecopy.
+ if !dst.NeedSafecopy() {
+ n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset)
+ r.Offset += int64(n)
+ return n, buf, err
+ }
+ if len(buf) < dst.Len() {
+ buf = make([]byte, dst.Len())
+ }
+ rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset)
+ r.Offset += int64(rn)
+ wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
+ if wberr != nil {
+ return wbn, buf, wberr
+ }
+ return wbn, buf, rerr
+}
+
+// FromIOWriter implements Writer for an io.Writer by repeatedly invoking
+// io.Writer.Write until it returns an error or partial write.
+//
+// FromIOWriter will tolerate implementations of io.Writer.Write that return
+// partial writes with a nil error in contravention of io.Writer's
+// requirements, since Writer is permitted to do so. FromIOWriter will return a
+// successful partial write iff Writer.Write does so.
+type FromIOWriter struct {
+ Writer io.Writer
+}
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ var buf []byte
+ var done uint64
+ for !srcs.IsEmpty() {
+ src := srcs.Head()
+ var n int
+ var err error
+ n, buf, err = w.writeFromBlock(src, buf)
+ done += uint64(n)
+ if n != src.Len() || err != nil {
+ return done, err
+ }
+ srcs = srcs.Tail()
+ }
+ return done, nil
+}
+
+func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) {
+ // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require
+ // safecopy.
+ if !src.NeedSafecopy() {
+ n, err := w.Writer.Write(src.ToSlice())
+ return n, buf, err
+ }
+ if len(buf) < src.Len() {
+ buf = make([]byte, src.Len())
+ }
+ bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src)
+ wn, werr := w.Writer.Write(buf[:bufn])
+ if werr != nil {
+ return wn, buf, werr
+ }
+ return wn, buf, buferr
+}
+
+// FromVecReaderFunc implements Reader for a function that reads data into a
+// [][]byte and returns the number of bytes read as an int64.
+type FromVecReaderFunc struct {
+ ReadVec func(dsts [][]byte) (int64, error)
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+//
+// ReadToBlocks calls r.ReadVec at most once.
+func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ // Ensure that we don't pass a [][]byte with a total length > MaxInt64.
+ dsts = dsts.TakeFirst64(uint64(math.MaxInt64))
+ dstSlices := make([][]byte, 0, dsts.NumBlocks())
+ // Buffer Blocks that require safecopy.
+ for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() {
+ dst := tmp.Head()
+ if dst.NeedSafecopy() {
+ dstSlices = append(dstSlices, make([]byte, dst.Len()))
+ } else {
+ dstSlices = append(dstSlices, dst.ToSlice())
+ }
+ }
+ rn, rerr := r.ReadVec(dstSlices)
+ dsts = dsts.TakeFirst64(uint64(rn))
+ var done uint64
+ var i int
+ for !dsts.IsEmpty() {
+ dst := dsts.Head()
+ if dst.NeedSafecopy() {
+ n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i]))
+ done += uint64(n)
+ if err != nil {
+ return done, err
+ }
+ } else {
+ done += uint64(dst.Len())
+ }
+ dsts = dsts.Tail()
+ i++
+ }
+ return done, rerr
+}
+
+// FromVecWriterFunc implements Writer for a function that writes data from a
+// [][]byte and returns the number of bytes written.
+type FromVecWriterFunc struct {
+ WriteVec func(srcs [][]byte) (int64, error)
+}
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+//
+// WriteFromBlocks calls w.WriteVec at most once.
+func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ // Ensure that we don't pass a [][]byte with a total length > MaxInt64.
+ srcs = srcs.TakeFirst64(uint64(math.MaxInt64))
+ srcSlices := make([][]byte, 0, srcs.NumBlocks())
+ // Buffer Blocks that require safecopy.
+ var buferr error
+ for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() {
+ src := tmp.Head()
+ if src.NeedSafecopy() {
+ slice := make([]byte, src.Len())
+ n, err := Copy(BlockFromSafeSlice(slice), src)
+ srcSlices = append(srcSlices, slice[:n])
+ if err != nil {
+ buferr = err
+ break
+ }
+ } else {
+ srcSlices = append(srcSlices, src.ToSlice())
+ }
+ }
+ n, err := w.WriteVec(srcSlices)
+ if err != nil {
+ return uint64(n), err
+ }
+ return uint64(n), buferr
+}
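+
+// A sketch of adapting a vectorized write function, for example one backed by
+// writev(2) on a host fd; the stub below merely counts bytes (blocks is a
+// hypothetical BlockSeq):
+//
+//	w := FromVecWriterFunc{
+//		WriteVec: func(srcs [][]byte) (int64, error) {
+//			var total int64
+//			for _, s := range srcs {
+//				total += int64(len(s))
+//			}
+//			return total, nil
+//		},
+//	}
+//	n, err := w.WriteFromBlocks(blocks)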
diff --git a/pkg/safemem/io_test.go b/pkg/safemem/io_test.go
new file mode 100644
index 000000000..629741bee
--- /dev/null
+++ b/pkg/safemem/io_test.go
@@ -0,0 +1,199 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+func makeBlocks(slices ...[]byte) []Block {
+ blocks := make([]Block, 0, len(slices))
+ for _, s := range slices {
+ blocks = append(blocks, BlockFromSafeSlice(s))
+ }
+ return blocks
+}
+
+func TestFromIOReaderFullRead(t *testing.T) {
+ r := FromIOReader{bytes.NewBufferString("foobar")}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+type eofHidingReader struct {
+ Reader io.Reader
+}
+
+func (r eofHidingReader) Read(dst []byte) (int, error) {
+ n, err := r.Reader.Read(dst)
+ if err == io.EOF {
+ return n, nil
+ }
+ return n, err
+}
+
+func TestFromIOReaderPartialRead(t *testing.T) {
+ r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ // FromIOReader should stop after the eofHidingReader returns (1, nil)
+ // for a 3-byte read.
+ if wantN := uint64(4); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+type singleByteReader struct {
+ Reader io.Reader
+}
+
+func (r singleByteReader) Read(dst []byte) (int, error) {
+ if len(dst) == 0 {
+ return r.Reader.Read(dst)
+ }
+ return r.Reader.Read(dst[:1])
+}
+
+func TestSingleByteReader(t *testing.T) {
+ r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ // FromIOReader should stop after the singleByteReader returns (1, nil)
+ // for a 3-byte read.
+ if wantN := uint64(1); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+func TestReadFullToBlocks(t *testing.T) {
+ r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
+ // ReadFullToBlocks should call into FromIOReader => singleByteReader
+ // repeatedly until dsts is exhausted.
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+func TestFromIOWriterFullWrite(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{&dst}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+type limitedWriter struct {
+ Writer io.Writer
+ Done int
+ Limit int
+}
+
+func (w *limitedWriter) Write(src []byte) (int, error) {
+ count := len(src)
+ if count > (w.Limit - w.Done) {
+ count = w.Limit - w.Done
+ }
+ n, err := w.Writer.Write(src[:count])
+ w.Done += n
+ return n, err
+}
+
+func TestFromIOWriterPartialWrite(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ // FromIOWriter should stop after the limitedWriter returns (1, nil) for a
+ // 3-byte write.
+ if wantN := uint64(4); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+type singleByteWriter struct {
+ Writer io.Writer
+}
+
+func (w singleByteWriter) Write(src []byte) (int, error) {
+ if len(src) == 0 {
+ return w.Writer.Write(src)
+ }
+ return w.Writer.Write(src[:1])
+}
+
+func TestSingleByteWriter(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{singleByteWriter{&dst}}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ // FromIOWriter should stop after the singleByteWriter returns (1, nil)
+ // for a 3-byte write.
+ if wantN := uint64(1); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+func TestWriteFullToBlocks(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{singleByteWriter{&dst}}
+ n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
+	// WriteFullFromBlocks should call into FromIOWriter => singleByteWriter
+ // repeatedly until srcs is exhausted.
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
diff --git a/pkg/safemem/safemem.go b/pkg/safemem/safemem.go
new file mode 100644
index 000000000..3e70d33a2
--- /dev/null
+++ b/pkg/safemem/safemem.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package safemem provides the Block and BlockSeq types.
+package safemem
diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go
new file mode 100644
index 000000000..de34005e9
--- /dev/null
+++ b/pkg/safemem/seq_test.go
@@ -0,0 +1,217 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+)
+
+func TestBlockSeqOfEmptyBlock(t *testing.T) {
+ bs := BlockSeqOf(Block{})
+ if !bs.IsEmpty() {
+ t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs)
+ }
+}
+
+func TestBlockSeqOfNonemptyBlock(t *testing.T) {
+ b := BlockFromSafeSlice(make([]byte, 1))
+ bs := BlockSeqOf(b)
+ if bs.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs)
+ }
+ if head := bs.Head(); head != b {
+ t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b)
+ }
+ if tail := bs.Tail(); !tail.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail)
+ }
+}
+
+type blockSeqTest struct {
+ desc string
+
+ pieces []string
+ haveOffset bool
+ offset uint64
+ haveLimit bool
+ limit uint64
+
+ want string
+}
+
+func (t blockSeqTest) NonEmptyByteSlices() [][]byte {
+ // t is a value, so we can mutate it freely.
+ slices := make([][]byte, 0, len(t.pieces))
+ for _, str := range t.pieces {
+ if t.haveOffset {
+ strOff := t.offset
+ if strOff > uint64(len(str)) {
+ strOff = uint64(len(str))
+ }
+ str = str[strOff:]
+ t.offset -= strOff
+ }
+ if t.haveLimit {
+ strLim := t.limit
+ if strLim > uint64(len(str)) {
+ strLim = uint64(len(str))
+ }
+ str = str[:strLim]
+ t.limit -= strLim
+ }
+ if len(str) != 0 {
+ slices = append(slices, []byte(str))
+ }
+ }
+ return slices
+}
+
+func (t blockSeqTest) BlockSeq() BlockSeq {
+ blocks := make([]Block, 0, len(t.pieces))
+ for _, str := range t.pieces {
+ blocks = append(blocks, BlockFromSafeSlice([]byte(str)))
+ }
+ bs := BlockSeqFromSlice(blocks)
+ if t.haveOffset {
+ bs = bs.DropFirst64(t.offset)
+ }
+ if t.haveLimit {
+ bs = bs.TakeFirst64(t.limit)
+ }
+ return bs
+}
+
+var blockSeqTests = []blockSeqTest{
+ {
+ desc: "Empty sequence",
+ },
+ {
+ desc: "Sequence of length 1",
+ pieces: []string{"foobar"},
+ want: "foobar",
+ },
+ {
+ desc: "Sequence of length 2",
+ pieces: []string{"foo", "bar"},
+ want: "foobar",
+ },
+ {
+ desc: "Empty Blocks",
+ pieces: []string{"", "foo", "", "", "bar", ""},
+ want: "foobar",
+ },
+ {
+ desc: "Sequence with non-zero offset",
+ pieces: []string{"foo", "bar"},
+ haveOffset: true,
+ offset: 2,
+ want: "obar",
+ },
+ {
+ desc: "Sequence with non-maximal limit",
+ pieces: []string{"foo", "bar"},
+ haveLimit: true,
+ limit: 5,
+ want: "fooba",
+ },
+ {
+ desc: "Sequence with offset and limit",
+ pieces: []string{"foo", "bar"},
+ haveOffset: true,
+ offset: 2,
+ haveLimit: true,
+ limit: 3,
+ want: "oba",
+ },
+}
+
+func TestBlockSeqNumBytes(t *testing.T) {
+ for _, test := range blockSeqTests {
+ t.Run(test.desc, func(t *testing.T) {
+ if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want {
+ t.Errorf("NumBytes: got %d, wanted %d", got, want)
+ }
+ })
+ }
+}
+
+func TestBlockSeqIterBlocks(t *testing.T) {
+ // Tests BlockSeq iteration using Head/Tail.
+ for _, test := range blockSeqTests {
+ t.Run(test.desc, func(t *testing.T) {
+ srcs := test.BlockSeq()
+ // "Note that a non-nil empty slice and a nil slice ... are not
+ // deeply equal." - reflect
+ slices := make([][]byte, 0, 0)
+ for !srcs.IsEmpty() {
+ src := srcs.Head()
+ slices = append(slices, src.ToSlice())
+ nextSrcs := srcs.Tail()
+ if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want {
+ t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want)
+ }
+ srcs = nextSrcs
+ }
+ if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) {
+ t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices)
+ }
+ })
+ }
+}
+
+func TestBlockSeqIterBytes(t *testing.T) {
+ // Tests BlockSeq iteration using Head/DropFirst.
+ for _, test := range blockSeqTests {
+ t.Run(test.desc, func(t *testing.T) {
+ srcs := test.BlockSeq()
+ var dst bytes.Buffer
+ for !srcs.IsEmpty() {
+ src := srcs.Head()
+ var b [1]byte
+ n, err := Copy(BlockFromSafeSlice(b[:]), src)
+ if n != 1 || err != nil {
+ t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err)
+ }
+ dst.WriteByte(b[0])
+ nextSrcs := srcs.DropFirst(1)
+ if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want {
+ t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want)
+ }
+ srcs = nextSrcs
+ }
+ if got := string(dst.Bytes()); got != test.want {
+ t.Errorf("Copied string: got %q, wanted %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestBlockSeqDropBeyondLimit(t *testing.T) {
+ blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))}
+ bs := BlockSeqFromSlice(blocks)
+ if got, want := bs.NumBytes(), uint64(4); got != want {
+ t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want)
+ }
+ bs = bs.TakeFirst(1)
+ if got, want := bs.NumBytes(), uint64(1); got != want {
+ t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want)
+ }
+ bs = bs.DropFirst(2)
+ if got, want := bs.NumBytes(), uint64(0); got != want {
+ t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want)
+ }
+}
diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go
new file mode 100644
index 000000000..f5f0574f8
--- /dev/null
+++ b/pkg/safemem/seq_unsafe.go
@@ -0,0 +1,319 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "syscall"
+ "unsafe"
+)
+
+// A BlockSeq represents a sequence of Blocks, each of which has non-zero
+// length.
+//
+// BlockSeqs are immutable and may be copied by value. The zero value of
+// BlockSeq represents an empty sequence.
+type BlockSeq struct {
+ // If length is 0, then the BlockSeq is empty. Invariants: data == 0;
+ // offset == 0; limit == 0.
+ //
+ // If length is -1, then the BlockSeq represents the single Block{data,
+ // limit, false}. Invariants: offset == 0; limit > 0; limit does not
+ // overflow the range of an int.
+ //
+ // If length is -2, then the BlockSeq represents the single Block{data,
+ // limit, true}. Invariants: offset == 0; limit > 0; limit does not
+ // overflow the range of an int.
+ //
+ // Otherwise, length >= 2, and the BlockSeq represents the `length` Blocks
+ // in the array of Blocks starting at address `data`, starting at `offset`
+ // bytes into the first Block and limited to the following `limit` bytes.
+ // Invariants: data != 0; offset < len(data[0]); limit > 0; offset+limit <=
+ // the combined length of all Blocks in the array; the first Block in the
+ // array has non-zero length.
+ //
+ // length is never 1; sequences consisting of a single Block are always
+ // stored inline (with length < 0).
+ data unsafe.Pointer
+ length int
+ offset int
+ limit uint64
+}
+
+// BlockSeqOf returns a BlockSeq representing the single Block b.
+func BlockSeqOf(b Block) BlockSeq {
+ if b.length == 0 {
+ return BlockSeq{}
+ }
+ bs := BlockSeq{
+ data: b.start,
+ length: -1,
+ limit: uint64(b.length),
+ }
+ if b.needSafecopy {
+ bs.length = -2
+ }
+ return bs
+}
+
+// BlockSeqFromSlice returns a BlockSeq representing all Blocks in slice.
+// If slice contains Blocks with zero length, BlockSeq will skip them during
+// iteration.
+//
+// Whether the returned BlockSeq shares memory with slice is unspecified;
+// clients should avoid mutating slices passed to BlockSeqFromSlice.
+//
+// Preconditions: The combined length of all Blocks in slice <= math.MaxUint64.
+func BlockSeqFromSlice(slice []Block) BlockSeq {
+ slice = skipEmpty(slice)
+ var limit uint64
+ for _, b := range slice {
+ sum := limit + uint64(b.Len())
+ if sum < limit {
+ panic("BlockSeq length overflows uint64")
+ }
+ limit = sum
+ }
+ return blockSeqFromSliceLimited(slice, limit)
+}
+
+// Preconditions: The combined length of all Blocks in slice <= limit. If
+// len(slice) != 0, the first Block in slice has non-zero length, and limit >
+// 0.
+func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq {
+ switch len(slice) {
+ case 0:
+ return BlockSeq{}
+ case 1:
+ return BlockSeqOf(slice[0].TakeFirst64(limit))
+ default:
+ return BlockSeq{
+ data: unsafe.Pointer(&slice[0]),
+ length: len(slice),
+ limit: limit,
+ }
+ }
+}
+
+func skipEmpty(slice []Block) []Block {
+ for i, b := range slice {
+ if b.Len() != 0 {
+ return slice[i:]
+ }
+ }
+ return nil
+}
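+
+// A small construction-and-windowing sketch, mirroring Block's slicing
+// semantics:
+//
+//	bs := BlockSeqFromSlice([]Block{
+//		BlockFromSafeSlice([]byte("foo")),
+//		BlockFromSafeSlice([]byte("bar")),
+//	})
+//	window := bs.DropFirst(2).TakeFirst(3) // covers "oba"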
+
+// IsEmpty returns true if bs contains no Blocks.
+//
+// Invariants: bs.IsEmpty() == (bs.NumBlocks() == 0) == (bs.NumBytes() == 0).
+// (Of these, prefer to use bs.IsEmpty().)
+func (bs BlockSeq) IsEmpty() bool {
+ return bs.length == 0
+}
+
+// NumBlocks returns the number of Blocks in bs.
+func (bs BlockSeq) NumBlocks() int {
+ // In general, we have to count: if bs represents a windowed slice then the
+ // slice may contain Blocks with zero length, and bs.length may be larger
+ // than the actual number of Blocks due to bs.limit.
+ var n int
+ for !bs.IsEmpty() {
+ n++
+ bs = bs.Tail()
+ }
+ return n
+}
+
+// NumBytes returns the sum of Block.Len() for all Blocks in bs.
+func (bs BlockSeq) NumBytes() uint64 {
+ return bs.limit
+}
+
+// Head returns the first Block in bs.
+//
+// Preconditions: !bs.IsEmpty().
+func (bs BlockSeq) Head() Block {
+ if bs.length == 0 {
+ panic("empty BlockSeq")
+ }
+ if bs.length < 0 {
+ return bs.internalBlock()
+ }
+ return (*Block)(bs.data).DropFirst(bs.offset).TakeFirst64(bs.limit)
+}
+
+// Preconditions: bs.length < 0.
+func (bs BlockSeq) internalBlock() Block {
+ return Block{
+ start: bs.data,
+ length: int(bs.limit),
+ needSafecopy: bs.length == -2,
+ }
+}
+
+// Tail returns a BlockSeq consisting of all Blocks in bs after the first.
+//
+// Preconditions: !bs.IsEmpty().
+func (bs BlockSeq) Tail() BlockSeq {
+ if bs.length == 0 {
+ panic("empty BlockSeq")
+ }
+ if bs.length < 0 {
+ return BlockSeq{}
+ }
+ head := (*Block)(bs.data).DropFirst(bs.offset)
+ headLen := uint64(head.Len())
+ if headLen >= bs.limit {
+ // The head Block exhausts the limit, so the tail is empty.
+ return BlockSeq{}
+ }
+ var extSlice []Block
+ extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice))
+ extSliceHdr.Data = uintptr(bs.data)
+ extSliceHdr.Len = bs.length
+ extSliceHdr.Cap = bs.length
+ tailSlice := skipEmpty(extSlice[1:])
+ tailLimit := bs.limit - headLen
+ return blockSeqFromSliceLimited(tailSlice, tailLimit)
+}
+
+// DropFirst returns a BlockSeq equivalent to bs, but with the first n bytes
+// omitted. If n > bs.NumBytes(), DropFirst returns an empty BlockSeq.
+//
+// Preconditions: n >= 0.
+func (bs BlockSeq) DropFirst(n int) BlockSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return bs.DropFirst64(uint64(n))
+}
+
+// DropFirst64 is equivalent to DropFirst but takes a uint64.
+func (bs BlockSeq) DropFirst64(n uint64) BlockSeq {
+ if n >= bs.limit {
+ return BlockSeq{}
+ }
+ for {
+ // Calling bs.Head() here is surprisingly expensive, so inline getting
+ // the head's length.
+ var headLen uint64
+ if bs.length < 0 {
+ headLen = bs.limit
+ } else {
+ headLen = uint64((*Block)(bs.data).Len() - bs.offset)
+ }
+ if n < headLen {
+ // Dropping ends partway through the head Block.
+ if bs.length < 0 {
+ return BlockSeqOf(bs.internalBlock().DropFirst64(n))
+ }
+ bs.offset += int(n)
+ bs.limit -= n
+ return bs
+ }
+ n -= headLen
+ bs = bs.Tail()
+ }
+}
+
+// TakeFirst returns a BlockSeq equivalent to the first n bytes of bs. If n >
+// bs.NumBytes(), TakeFirst returns a BlockSeq equivalent to bs.
+//
+// Preconditions: n >= 0.
+func (bs BlockSeq) TakeFirst(n int) BlockSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return bs.TakeFirst64(uint64(n))
+}
+
+// TakeFirst64 is equivalent to TakeFirst but takes a uint64.
+func (bs BlockSeq) TakeFirst64(n uint64) BlockSeq {
+ if n == 0 {
+ return BlockSeq{}
+ }
+ if bs.limit > n {
+ bs.limit = n
+ }
+ return bs
+}
+
+// String implements fmt.Stringer.String.
+func (bs BlockSeq) String() string {
+ var buf bytes.Buffer
+ buf.WriteByte('[')
+ var sep string
+ for !bs.IsEmpty() {
+ buf.WriteString(sep)
+ sep = " "
+ buf.WriteString(bs.Head().String())
+ bs = bs.Tail()
+ }
+ buf.WriteByte(']')
+ return buf.String()
+}
+
+// CopySeq copies srcs.NumBytes() or dsts.NumBytes() bytes, whichever is less,
+// from srcs to dsts and returns the number of bytes copied.
+//
+// If srcs and dsts overlap, the data stored in dsts is unspecified.
+func CopySeq(dsts, srcs BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() && !srcs.IsEmpty() {
+ dst := dsts.Head()
+ src := srcs.Head()
+ n, err := Copy(dst, src)
+ done += uint64(n)
+ if err != nil {
+ return done, err
+ }
+ dsts = dsts.DropFirst(n)
+ srcs = srcs.DropFirst(n)
+ }
+ return done, nil
+}
+
+// ZeroSeq sets all bytes in dsts to 0 and returns the number of bytes zeroed.
+func ZeroSeq(dsts BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() {
+ n, err := Zero(dsts.Head())
+ done += uint64(n)
+ if err != nil {
+ return done, err
+ }
+ dsts = dsts.DropFirst(n)
+ }
+ return done, nil
+}
+
+// IovecsFromBlockSeq returns a []syscall.Iovec representing bs.
+func IovecsFromBlockSeq(bs BlockSeq) []syscall.Iovec {
+ iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
+ for ; !bs.IsEmpty(); bs = bs.Tail() {
+ b := bs.Head()
+ iovs = append(iovs, syscall.Iovec{
+ Base: &b.ToSlice()[0],
+ Len: uint64(b.Len()),
+ })
+ // We don't need to care about b.NeedSafecopy(), because the host
+ // kernel will handle such address ranges just fine (by returning
+ // EFAULT).
+ }
+ return iovs
+}
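+
+// A hedged sketch of the intended use: gather-write a non-empty BlockSeq to a
+// hypothetical host file descriptor fd with writev(2):
+//
+//	iovs := IovecsFromBlockSeq(srcs)
+//	n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(fd),
+//		uintptr(unsafe.Pointer(&iovs[0])), uintptr(len(iovs)))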
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
new file mode 100644
index 000000000..c5fca2ba3
--- /dev/null
+++ b/pkg/seccomp/BUILD
@@ -0,0 +1,50 @@
+load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "victim",
+ testonly = 1,
+ srcs = ["seccomp_test_victim.go"],
+ deps = [":seccomp"],
+)
+
+go_embed_data(
+ name = "victim_data",
+ testonly = 1,
+ src = "victim",
+ package = "seccomp",
+ var = "victimData",
+)
+
+go_library(
+ name = "seccomp",
+ srcs = [
+ "seccomp.go",
+ "seccomp_amd64.go",
+ "seccomp_arm64.go",
+ "seccomp_rules.go",
+ "seccomp_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/bpf",
+ "//pkg/log",
+ ],
+)
+
+go_test(
+ name = "seccomp_test",
+ size = "small",
+ srcs = [
+ "seccomp_test.go",
+ ":victim_data",
+ ],
+ library = ":seccomp",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/bpf",
+ ],
+)
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
new file mode 100644
index 000000000..55fd6967e
--- /dev/null
+++ b/pkg/seccomp/seccomp.go
@@ -0,0 +1,404 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seccomp provides basic seccomp filters for x86_64 (little endian).
+package seccomp
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+const (
+ // skipOneInst is the offset to take for skipping one instruction.
+ skipOneInst = 1
+
+ // defaultLabel is the label for the default action.
+ defaultLabel = "default_action"
+)
+
+// Install generates BPF code based on the set of syscalls provided. It only
+// allows syscalls that conform to the specification. Syscalls that violate the
+// specification will trigger RET_KILL_PROCESS, except for the cases below.
+//
+// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the
+// following cases:
+// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the
+// offending thread and often keeps the sentry hanging.
+// 2. Debug: RET_TRAP generates a panic followed by a stack trace, which is
+// much easier to debug than RET_KILL_PROCESS, which can't be caught.
+//
+// Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored,
+// making it possible for the process to continue running after a violation.
+// However, it will leave a SECCOMP audit event trail behind. In any case, the
+// syscall is still blocked from executing.
+func Install(rules SyscallRules) error {
+ defaultAction, err := defaultAction()
+ if err != nil {
+ return err
+ }
+
+ // Uncomment to get stack trace when there is a violation.
+ // defaultAction = linux.BPFAction(linux.SECCOMP_RET_TRAP)
+
+ log.Infof("Installing seccomp filters for %d syscalls (action=%v)", len(rules), defaultAction)
+
+ instrs, err := BuildProgram([]RuleSet{
+ RuleSet{
+ Rules: rules,
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ }, defaultAction)
+ if log.IsLogging(log.Debug) {
+ programStr, errDecode := bpf.DecodeProgram(instrs)
+ if errDecode != nil {
+ programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr)
+ }
+ log.Debugf("Seccomp program dump:\n%s", programStr)
+ }
+ if err != nil {
+ return err
+ }
+
+ // Perform the actual installation.
+ if errno := SetFilter(instrs); errno != 0 {
+ return fmt.Errorf("Failed to set filter: %v", errno)
+ }
+
+ log.Infof("Seccomp filters installed.")
+ return nil
+}
+
+func defaultAction() (linux.BPFAction, error) {
+ available, err := isKillProcessAvailable()
+ if err != nil {
+ return 0, err
+ }
+ if available {
+ return linux.SECCOMP_RET_KILL_PROCESS, nil
+ }
+ return linux.SECCOMP_RET_TRAP, nil
+}
+
+// RuleSet is a set of rules and an associated action.
+type RuleSet struct {
+ Rules SyscallRules
+ Action linux.BPFAction
+
+ // Vsyscall indicates that a check is made for a function being called
+ // from kernel mappings. This is where the vsyscall page is located
+ // (and typically emulated), so this RuleSet will not match any
+ // syscalls not dispatched from the vsyscall page.
+ Vsyscall bool
+}
+
+// SyscallName gives names to system calls. It is used purely for debugging purposes.
+//
+// An alternate namer can be provided to the package at initialization time.
+var SyscallName = func(sysno uintptr) string {
+ return fmt.Sprintf("syscall_%d", sysno)
+}
+
+// BuildProgram builds a BPF program from the given list of RuleSets and their
+// associated actions. The single generated program covers all provided
+// RuleSets.
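+//
+// A minimal sketch (mirroring how Install above invokes it):
+//
+// instrs, err := BuildProgram([]RuleSet{
+//     {Rules: rules, Action: linux.SECCOMP_RET_ALLOW},
+// }, linux.SECCOMP_RET_TRAP)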
+func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) {
+ program := bpf.NewProgramBuilder()
+
+ // Be paranoid and check that syscall is done in the expected architecture.
+ //
+ // A = seccomp_data.arch
+ // if (A != AUDIT_ARCH) goto defaultAction.
+ program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArch)
+ // defaultLabel is at the bottom of the program. The program may
+ // exceed 255 instructions, the limit of a conditional jump, so a
+ // direct jump is used to reach it.
+ program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0)
+ program.AddDirectJumpLabel(defaultLabel)
+ if err := buildIndex(rules, program); err != nil {
+ return nil, err
+ }
+
+ // Exhausted: return defaultAction.
+ if err := program.AddLabel(defaultLabel); err != nil {
+ return nil, err
+ }
+ program.AddStmt(bpf.Ret|bpf.K, uint32(defaultAction))
+
+ return program.Instructions()
+}
+
+// buildIndex builds a BST to quickly search through all syscalls.
+func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error {
+ // Build a list of all application system calls, across all given rule
+ // sets. We have a simple BST, but may dispatch individual matchers
+ // with different actions. The matchers are evaluated linearly.
+ requiredSyscalls := make(map[uintptr]struct{})
+ for _, rs := range rules {
+ for sysno := range rs.Rules {
+ requiredSyscalls[sysno] = struct{}{}
+ }
+ }
+ syscalls := make([]uintptr, 0, len(requiredSyscalls))
+ for sysno := range requiredSyscalls {
+ syscalls = append(syscalls, sysno)
+ }
+ sort.Slice(syscalls, func(i, j int) bool { return syscalls[i] < syscalls[j] })
+ for _, sysno := range syscalls {
+ for _, rs := range rules {
+ // Print only if there is a corresponding set of rules.
+ if _, ok := rs.Rules[sysno]; ok {
+ log.Debugf("syscall filter %v: %s => 0x%x", SyscallName(sysno), rs.Rules[sysno], rs.Action)
+ }
+ }
+ }
+
+ root := createBST(syscalls)
+ root.root = true
+
+ // Load syscall number into A and run through BST.
+ //
+ // A = seccomp_data.nr
+ program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetNR)
+ return root.traverse(buildBSTProgram, rules, program)
+}
+
+// createBST converts a sorted syscall slice into a balanced BST.
+// Panics if syscalls is empty.
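+//
+// For example (sketch): createBST([]uintptr{1, 3, 5}) returns the node for 3,
+// with the node for 1 as its left child and the node for 5 as its right.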
+func createBST(syscalls []uintptr) *node {
+ i := len(syscalls) / 2
+ parent := node{value: syscalls[i]}
+ if i > 0 {
+ parent.left = createBST(syscalls[:i])
+ }
+ if i+1 < len(syscalls) {
+ parent.right = createBST(syscalls[i+1:])
+ }
+ return &parent
+}
+
+func vsyscallViolationLabel(ruleSetIdx int, sysno uintptr) string {
+ return fmt.Sprintf("vsyscallViolation_%v_%v", ruleSetIdx, sysno)
+}
+
+func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string {
+ return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx)
+}
+
+func ruleLabel(ruleSetIdx int, sysno uintptr, idx int, name string) string {
+ return fmt.Sprintf("rule_%v_%v_%v_%v", ruleSetIdx, sysno, idx, name)
+}
+
+func checkArgsLabel(sysno uintptr) string {
+ return fmt.Sprintf("checkArgs_%v", sysno)
+}
+
+// addSyscallArgsCheck adds argument checks for a single system call. It does
+// not insert a jump to the default action at the end; it is the caller's
+// responsibility to insert an appropriate jump after calling this function.
+func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAction, ruleSetIdx int, sysno uintptr) error {
+ for ruleidx, rule := range rules {
+ labelled := false
+ for i, arg := range rule {
+ if arg != nil {
+ switch a := arg.(type) {
+ case AllowAny:
+ case AllowValue:
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
+ high, low := uint32(a>>32), uint32(a)
+ // assert arg_low == low
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // assert arg_high == high
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ labelled = true
+ case GreaterThan:
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
+ labelGood := fmt.Sprintf("gt%v", i)
+ high, low := uint32(a>>32), uint32(a)
+ // if arg_high < high, then arg cannot be greater: violation.
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // if arg_high > high, then arg is greater: match.
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ // arg_high == high, so compare low words: violation unless arg_low > low.
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ default:
+ return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a))
+ }
+ }
+ }
+
+ // Matched, emit the given action.
+ p.AddStmt(bpf.Ret|bpf.K, uint32(action))
+
+ // Label the end of the rule if necessary. This is added for
+ // the jumps above when the argument check fails.
+ if labelled {
+ if err := p.AddLabel(ruleViolationLabel(ruleSetIdx, sysno, ruleidx)); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// buildBSTProgram converts a binary tree rooted at 'root' into BPF code. The outline of the code
+// is as follows:
+//
+// // SYS_PIPE(22), root
+// (A == 22) ? goto argument check : continue
+// (A > 22) ? goto index_35 : goto index_9
+//
+// index_9: // SYS_MMAP(9), leaf
+// (A == 9) ? goto argument check : goto defaultLabel
+//
+// index_35: // SYS_NANOSLEEP(35), single child
+// (A == 35) ? goto argument check : continue
+// (A > 35) ? goto index_50 : goto defaultLabel
+//
+// index_50: // SYS_LISTEN(50), leaf
+// (A == 50) ? goto argument check : goto defaultLabel
+//
+func buildBSTProgram(n *node, rules []RuleSet, program *bpf.ProgramBuilder) error {
+ // Root node is never referenced by label, skip it.
+ if !n.root {
+ if err := program.AddLabel(n.label()); err != nil {
+ return err
+ }
+ }
+
+ sysno := n.value
+ program.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, uint32(sysno), checkArgsLabel(sysno), 0)
+ if n.left == nil && n.right == nil {
+ // Leaf nodes don't require extra check.
+ program.AddDirectJumpLabel(defaultLabel)
+ } else {
+ // Non-leaf node. Check which way to branch otherwise. Direct jumps are
+ // used in case the offset exceeds the limit of a conditional jump (255).
+ program.AddJump(bpf.Jmp|bpf.Jgt|bpf.K, uint32(sysno), 0, skipOneInst)
+ program.AddDirectJumpLabel(n.right.label())
+ program.AddDirectJumpLabel(n.left.label())
+ }
+
+ if err := program.AddLabel(checkArgsLabel(sysno)); err != nil {
+ return err
+ }
+
+ emitted := false
+ for ruleSetIdx, rs := range rules {
+ if _, ok := rs.Rules[sysno]; ok {
+ // If there are no rules, then this will always match.
+ // Remember we've done this so that we can emit a
+ // sensible error. We can't catch all overlaps, but we
+ // can catch this one at least.
+ if emitted {
+ return fmt.Errorf("unreachable action for %v: 0x%x (rule set %d)", SyscallName(sysno), rs.Action, ruleSetIdx)
+ }
+
+ // Emit a vsyscall check if this rule requires a
+ // Vsyscall match. This rule ensures that the top bit
+ // is set in the instruction pointer, which is where
+ // the vsyscall page will be mapped.
+ if rs.Vsyscall {
+ program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetIPHigh)
+ program.AddJumpFalseLabel(bpf.Jmp|bpf.Jset|bpf.K, 0x80000000, 0, vsyscallViolationLabel(ruleSetIdx, sysno))
+ }
+
+ // Emit matchers.
+ if len(rs.Rules[sysno]) == 0 {
+ // This is a blanket action.
+ program.AddStmt(bpf.Ret|bpf.K, uint32(rs.Action))
+ emitted = true
+ } else {
+ // Add an argument check for these particular
+ // arguments. This will continue execution and
+ // check the next rule set. We need to ensure
+ // that at the very end, we insert a direct
+ // jump label for the unmatched case.
+ if err := addSyscallArgsCheck(program, rs.Rules[sysno], rs.Action, ruleSetIdx, sysno); err != nil {
+ return err
+ }
+ }
+
+ // If there was a Vsyscall check for this rule, then we
+ // need to add an appropriate label for the jump above.
+ if rs.Vsyscall {
+ if err := program.AddLabel(vsyscallViolationLabel(ruleSetIdx, sysno)); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ // Not matched? We only need to insert a jump to the default label if
+ // no default action has been emitted for this call.
+ if !emitted {
+ program.AddDirectJumpLabel(defaultLabel)
+ }
+
+ return nil
+}
+
+// node represents a tree node.
+type node struct {
+ value uintptr
+ left *node
+ right *node
+ root bool
+}
+
+// label returns the label corresponding to this node.
+//
+// If n is nil, then the defaultLabel is returned.
+func (n *node) label() string {
+ if n == nil {
+ return defaultLabel
+ }
+ return fmt.Sprintf("index_%v", n.value)
+}
+
+type traverseFunc func(*node, []RuleSet, *bpf.ProgramBuilder) error
+
+func (n *node) traverse(fn traverseFunc, rules []RuleSet, p *bpf.ProgramBuilder) error {
+ if n == nil {
+ return nil
+ }
+ if err := fn(n, rules, p); err != nil {
+ return err
+ }
+ if err := n.left.traverse(fn, rules, p); err != nil {
+ return err
+ }
+ return n.right.traverse(fn, rules, p)
+}
diff --git a/pkg/seccomp/seccomp_amd64.go b/pkg/seccomp/seccomp_amd64.go
new file mode 100644
index 000000000..00bf332c1
--- /dev/null
+++ b/pkg/seccomp/seccomp_amd64.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package seccomp
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+const (
+ LINUX_AUDIT_ARCH = linux.AUDIT_ARCH_X86_64
+ SYS_SECCOMP = 317
+)
diff --git a/pkg/seccomp/seccomp_arm64.go b/pkg/seccomp/seccomp_arm64.go
new file mode 100644
index 000000000..b62133f21
--- /dev/null
+++ b/pkg/seccomp/seccomp_arm64.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package seccomp
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+const (
+ LINUX_AUDIT_ARCH = linux.AUDIT_ARCH_AARCH64
+ SYS_SECCOMP = 277
+)
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
new file mode 100644
index 000000000..a52dc1b4e
--- /dev/null
+++ b/pkg/seccomp/seccomp_rules.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccomp
+
+import "fmt"
+
+// The offsets are based on the following struct in include/linux/seccomp.h.
+// struct seccomp_data {
+// int nr;
+// __u32 arch;
+// __u64 instruction_pointer;
+// __u64 args[6];
+// };
+const (
+ seccompDataOffsetNR = 0
+ seccompDataOffsetArch = 4
+ seccompDataOffsetIPLow = 8
+ seccompDataOffsetIPHigh = 12
+ seccompDataOffsetArgs = 16
+)
+
+func seccompDataOffsetArgLow(i int) uint32 {
+ return uint32(seccompDataOffsetArgs + i*8)
+}
+
+func seccompDataOffsetArgHigh(i int) uint32 {
+ return seccompDataOffsetArgLow(i) + 4
+}
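+
+// For example, arg1 (i == 1) occupies bytes 24-31 of seccomp_data:
+// seccompDataOffsetArgLow(1) == 24 and seccompDataOffsetArgHigh(1) == 28.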
+
+// AllowAny is a marker to indicate that any value will be accepted.
+type AllowAny struct{}
+
+func (a AllowAny) String() (s string) {
+ return "*"
+}
+
+// AllowValue specifies a value that needs to be strictly matched.
+type AllowValue uintptr
+
+// GreaterThan specifies a lower bound: the checked value must be strictly
+// greater than this value.
+type GreaterThan uintptr
+
+func (a AllowValue) String() (s string) {
+ return fmt.Sprintf("%#x ", uintptr(a))
+}
+
+// Rule stores the allowed syscall arguments.
+//
+// For example:
+// rule := Rule {
+// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
+// }
+type Rule [7]interface{} // 6 arguments + RIP
+
+// RuleIP indicates what rules in the Rule array have to be applied to
+// instruction pointer.
+const RuleIP = 6
+
+func (r Rule) String() (s string) {
+ if len(r) == 0 {
+ return
+ }
+ s += "( "
+ for _, arg := range r {
+ if arg != nil {
+ s += fmt.Sprintf("%v ", arg)
+ }
+ }
+ s += ")"
+ return
+}
+
+// SyscallRules stores a map of OR'ed argument rules indexed by the syscall number.
+// If a syscall's rule list is empty, any arguments are allowed.
+//
+// For example:
+// rules := SyscallRules{
+// syscall.SYS_FUTEX: []Rule{
+// {
+// AllowAny{},
+// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+// }, // OR
+// {
+// AllowAny{},
+// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+// },
+// },
+// syscall.SYS_GETPID: []Rule{},
+// }
+type SyscallRules map[uintptr][]Rule
+
+// NewSyscallRules returns a new SyscallRules.
+func NewSyscallRules() SyscallRules {
+ return make(map[uintptr][]Rule)
+}
+
+// AddRule adds the given rule. It creates a new entry for a new syscall;
+// otherwise, it appends to the existing rules.
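+//
+// For example (sketch; the syscall number and rule are illustrative):
+//
+// sr := NewSyscallRules()
+// sr.AddRule(1, Rule{AllowAny{}, AllowValue(0xf)})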
+func (sr SyscallRules) AddRule(sysno uintptr, r Rule) {
+ if cur, ok := sr[sysno]; ok {
+ // An empty rule list means allow all. Honor it when more rules are added.
+ if len(cur) == 0 {
+ sr[sysno] = append(sr[sysno], Rule{})
+ }
+ sr[sysno] = append(sr[sysno], r)
+ } else {
+ sr[sysno] = []Rule{r}
+ }
+}
+
+// Merge merges the given SyscallRules.
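+//
+// An empty rule list means "allow all"; Merge preserves this by materializing
+// an explicit empty Rule on both sides. For example, merging {1: {}} into
+// {1: {}} yields {1: {{}, {}}} (see TestMerge).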
+func (sr SyscallRules) Merge(rules SyscallRules) {
+ for sysno, rs := range rules {
+ if cur, ok := sr[sysno]; ok {
+ // An empty rule list means allow all. Honor it when more rules are added.
+ if len(cur) == 0 {
+ sr[sysno] = append(sr[sysno], Rule{})
+ }
+ if len(rs) == 0 {
+ rs = []Rule{{}}
+ }
+ sr[sysno] = append(sr[sysno], rs...)
+ } else {
+ sr[sysno] = rs
+ }
+ }
+}
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
new file mode 100644
index 000000000..88766f33b
--- /dev/null
+++ b/pkg/seccomp/seccomp_test.go
@@ -0,0 +1,580 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccomp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math"
+ "math/rand"
+ "os"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bpf"
+)
+
+type seccompData struct {
+ nr uint32
+ arch uint32
+ instructionPointer uint64
+ args [6]uint64
+}
+
+// newVictim makes a victim binary.
+func newVictim() (string, error) {
+ f, err := ioutil.TempFile("", "victim")
+ if err != nil {
+ return "", err
+ }
+ defer f.Close()
+ path := f.Name()
+ if _, err := io.Copy(f, bytes.NewBuffer(victimData)); err != nil {
+ os.Remove(path)
+ return "", err
+ }
+ if err := os.Chmod(path, 0755); err != nil {
+ os.Remove(path)
+ return "", err
+ }
+ return path, nil
+}
+
+// asInput converts a seccompData to a bpf.Input.
+func (d *seccompData) asInput() bpf.Input {
+ return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+}
+
+func TestBasic(t *testing.T) {
+ type spec struct {
+ // desc is the test's description.
+ desc string
+
+ // data is the input data.
+ data seccompData
+
+ // want is the expected return value of the BPF program.
+ want linux.BPFAction
+ }
+
+ for _, test := range []struct {
+ ruleSets []RuleSet
+ defaultAction linux.BPFAction
+ specs []spec
+ }{
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{1: {}},
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "Single syscall allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Single syscall disallowed",
+ data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ AllowValue(0x1),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ Rules: SyscallRules{
+ 1: {},
+ 2: {},
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "Multiple rulesets allowed (1a)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x1}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Multiple rulesets allowed (1b)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Multiple rulesets allowed (2)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Multiple rulesets allowed (2)",
+ data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_KILL_THREAD,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: {},
+ 3: {},
+ 5: {},
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "Multiple syscalls allowed (1)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Multiple syscalls allowed (3)",
+ data: seccompData{nr: 3, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Multiple syscalls allowed (5)",
+ data: seccompData{nr: 5, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Multiple syscalls disallowed (0)",
+ data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Multiple syscalls disallowed (2)",
+ data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Multiple syscalls disallowed (4)",
+ data: seccompData{nr: 4, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Multiple syscalls disallowed (6)",
+ data: seccompData{nr: 6, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Multiple syscalls disallowed (100)",
+ data: seccompData{nr: 100, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: {},
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "Wrong architecture",
+ data: seccompData{nr: 1, arch: 123},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: {},
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "Syscall disallowed, action trap",
+ data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ AllowAny{},
+ AllowValue(0xf),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "Syscall argument allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xf}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Syscall argument disallowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xe}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ AllowValue(0xf),
+ },
+ {
+ AllowValue(0xe),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "Syscall argument allowed, two rules",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "Syscall argument allowed, two rules",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xe}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ AllowValue(0),
+ AllowValue(math.MaxUint64 - 1),
+ AllowValue(math.MaxUint32),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "64bit syscall argument allowed",
+ data: seccompData{
+ nr: 1,
+ arch: linux.AUDIT_ARCH_X86_64,
+ args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "64bit syscall argument disallowed",
+ data: seccompData{
+ nr: 1,
+ arch: linux.AUDIT_ARCH_X86_64,
+ args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "64bit syscall argument disallowed",
+ data: seccompData{
+ nr: 1,
+ arch: linux.AUDIT_ARCH_X86_64,
+ args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ GreaterThan(0xf),
+ GreaterThan(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "GreaterThan: Syscall argument allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "GreaterThan: Syscall argument disallowed (equal)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Syscall argument disallowed (smaller)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument disallowed (equal)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument disallowed (smaller)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ RuleIP: AllowValue(0x7aabbccdd),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "IP: Syscall instruction pointer allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "IP: Syscall instruction pointer disallowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x711223344},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ } {
+ instrs, err := BuildProgram(test.ruleSets, test.defaultAction)
+ if err != nil {
+ t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err)
+ continue
+ }
+ p, err := bpf.Compile(instrs)
+ if err != nil {
+ t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err)
+ continue
+ }
+ for _, spec := range test.specs {
+ got, err := bpf.Exec(p, spec.data.asInput())
+ if err != nil {
+ t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err)
+ continue
+ }
+ if got != uint32(spec.want) {
+ t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want)
+ }
+ }
+ }
+}
+
+// TestRandom tests that randomly generated rules are encoded correctly.
+func TestRandom(t *testing.T) {
+ rand.Seed(time.Now().UnixNano())
+ size := rand.Intn(50) + 1
+ syscallRules := make(map[uintptr][]Rule)
+ for len(syscallRules) < size {
+ n := uintptr(rand.Intn(200))
+ if _, ok := syscallRules[n]; !ok {
+ syscallRules[n] = []Rule{}
+ }
+ }
+
+ t.Logf("Testing filters: %v", syscallRules)
+ instrs, err := BuildProgram([]RuleSet{
+ RuleSet{
+ Rules: syscallRules,
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ }, linux.SECCOMP_RET_TRAP)
+ if err != nil {
+ t.Fatalf("buildProgram() got error: %v", err)
+ }
+ p, err := bpf.Compile(instrs)
+ if err != nil {
+ t.Fatalf("bpf.Compile() got error: %v", err)
+ }
+ for i := uint32(0); i < 200; i++ {
+ data := seccompData{nr: i, arch: linux.AUDIT_ARCH_X86_64}
+ got, err := bpf.Exec(p, data.asInput())
+ if err != nil {
+ t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i)
+ continue
+ }
+ want := linux.SECCOMP_RET_TRAP
+ if _, ok := syscallRules[uintptr(i)]; ok {
+ want = linux.SECCOMP_RET_ALLOW
+ }
+ if got != uint32(want) {
+ t.Errorf("bpf.Exec() = %d, want: %d, for syscall %d", got, want, i)
+ }
+ }
+}
+
+// TestRealDeal checks that a process dies when it trips over the filter and
+// that it doesn't die when the filter is not triggered.
+func TestRealDeal(t *testing.T) {
+ for _, test := range []struct {
+ die bool
+ want string
+ }{
+ {die: true, want: "bad system call"},
+ {die: false, want: "Syscall was allowed!!!"},
+ } {
+ victim, err := newVictim()
+ if err != nil {
+ t.Fatalf("unable to get victim: %v", err)
+ }
+ defer os.Remove(victim)
+ dieFlag := fmt.Sprintf("-die=%v", test.die)
+ cmd := exec.Command(victim, dieFlag)
+
+ out, err := cmd.CombinedOutput()
+ if test.die {
+ if err == nil {
+ t.Errorf("victim was not killed as expected, output: %s", out)
+ continue
+ }
+ // Depending on kernel version, either RET_TRAP or RET_KILL_PROCESS is
+ // used. RET_TRAP dumps reason for exit in output, while RET_KILL_PROCESS
+ // returns SIGSYS as exit status.
+ if !strings.Contains(string(out), test.want) &&
+ !strings.Contains(err.Error(), test.want) {
+ t.Errorf("Victim error is wrong, got: %v, err: %v, want: %v", string(out), err, test.want)
+ continue
+ }
+ } else {
+ if err != nil {
+ t.Errorf("victim failed to execute, err: %v", err)
+ continue
+ }
+ if !strings.Contains(string(out), test.want) {
+ t.Errorf("Victim output is wrong, got: %v, want: %v", string(out), test.want)
+ continue
+ }
+ }
+ }
+}
+
+// TestMerge ensures that empty rules are not erased when rules are merged.
+func TestMerge(t *testing.T) {
+ for _, tst := range []struct {
+ name string
+ main []Rule
+ merge []Rule
+ want []Rule
+ }{
+ {
+ name: "empty both",
+ main: nil,
+ merge: nil,
+ want: []Rule{{}, {}},
+ },
+ {
+ name: "empty main",
+ main: nil,
+ merge: []Rule{{}},
+ want: []Rule{{}, {}},
+ },
+ {
+ name: "empty merge",
+ main: []Rule{{}},
+ merge: nil,
+ want: []Rule{{}, {}},
+ },
+ } {
+ t.Run(tst.name, func(t *testing.T) {
+ mainRules := SyscallRules{1: tst.main}
+ mergeRules := SyscallRules{1: tst.merge}
+ mainRules.Merge(mergeRules)
+ if got, want := len(mainRules[1]), len(tst.want); got != want {
+ t.Errorf("wrong length, got: %d, want: %d", got, want)
+ }
+ for i, r := range mainRules[1] {
+ if r != tst.want[i] {
+ t.Errorf("result, got: %v, want: %v", r, tst.want[i])
+ }
+ }
+ })
+ }
+}
+
+// TestAddRule ensures that empty rules are not erased when rules are added.
+func TestAddRule(t *testing.T) {
+ rules := SyscallRules{1: {}}
+ rules.AddRule(1, Rule{})
+ if got, want := len(rules[1]), 2; got != want {
+ t.Errorf("len(rules[1]), got: %d, want: %d", got, want)
+ }
+}
diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go
new file mode 100644
index 000000000..da6b9eaaf
--- /dev/null
+++ b/pkg/seccomp/seccomp_test_victim.go
@@ -0,0 +1,117 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Test binary used to test that seccomp filters are properly constructed and
+// indeed kill the process on violation.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+func main() {
+ dieFlag := flag.Bool("die", false, "trips over the filter if true")
+ flag.Parse()
+
+ syscalls := seccomp.SyscallRules{
+ syscall.SYS_ACCEPT: {},
+ syscall.SYS_ARCH_PRCTL: {},
+ syscall.SYS_BIND: {},
+ syscall.SYS_BRK: {},
+ syscall.SYS_CLOCK_GETTIME: {},
+ syscall.SYS_CLONE: {},
+ syscall.SYS_CLOSE: {},
+ syscall.SYS_DUP: {},
+ syscall.SYS_DUP3: {},
+ syscall.SYS_EPOLL_CREATE1: {},
+ syscall.SYS_EPOLL_CTL: {},
+ syscall.SYS_EPOLL_WAIT: {},
+ syscall.SYS_EPOLL_PWAIT: {},
+ syscall.SYS_EXIT: {},
+ syscall.SYS_EXIT_GROUP: {},
+ syscall.SYS_FALLOCATE: {},
+ syscall.SYS_FCHMOD: {},
+ syscall.SYS_FCNTL: {},
+ syscall.SYS_FSTAT: {},
+ syscall.SYS_FSYNC: {},
+ syscall.SYS_FTRUNCATE: {},
+ syscall.SYS_FUTEX: {},
+ syscall.SYS_GETDENTS64: {},
+ syscall.SYS_GETPEERNAME: {},
+ syscall.SYS_GETPID: {},
+ syscall.SYS_GETSOCKNAME: {},
+ syscall.SYS_GETSOCKOPT: {},
+ syscall.SYS_GETTID: {},
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_LISTEN: {},
+ syscall.SYS_LSEEK: {},
+ syscall.SYS_MADVISE: {},
+ syscall.SYS_MINCORE: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_MPROTECT: {},
+ syscall.SYS_MUNLOCK: {},
+ syscall.SYS_MUNMAP: {},
+ syscall.SYS_NANOSLEEP: {},
+ syscall.SYS_NEWFSTATAT: {},
+ syscall.SYS_OPEN: {},
+ syscall.SYS_PPOLL: {},
+ syscall.SYS_PREAD64: {},
+ syscall.SYS_PSELECT6: {},
+ syscall.SYS_PWRITE64: {},
+ syscall.SYS_READ: {},
+ syscall.SYS_READLINKAT: {},
+ syscall.SYS_READV: {},
+ syscall.SYS_RECVMSG: {},
+ syscall.SYS_RENAMEAT: {},
+ syscall.SYS_RESTART_SYSCALL: {},
+ syscall.SYS_RT_SIGACTION: {},
+ syscall.SYS_RT_SIGPROCMASK: {},
+ syscall.SYS_RT_SIGRETURN: {},
+ syscall.SYS_SCHED_YIELD: {},
+ syscall.SYS_SENDMSG: {},
+ syscall.SYS_SETITIMER: {},
+ syscall.SYS_SET_ROBUST_LIST: {},
+ syscall.SYS_SETSOCKOPT: {},
+ syscall.SYS_SHUTDOWN: {},
+ syscall.SYS_SIGALTSTACK: {},
+ syscall.SYS_SOCKET: {},
+ syscall.SYS_SYNC_FILE_RANGE: {},
+ syscall.SYS_TGKILL: {},
+ syscall.SYS_UTIMENSAT: {},
+ syscall.SYS_WRITE: {},
+ syscall.SYS_WRITEV: {},
+ }
+ die := *dieFlag
+ if !die {
+ syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{
+ {
+ seccomp.AllowValue(10),
+ },
+ }
+ }
+
+ if err := seccomp.Install(syscalls); err != nil {
+ fmt.Printf("Failed to install seccomp: %v", err)
+ os.Exit(1)
+ }
+ fmt.Printf("Filters installed\n")
+
+ syscall.RawSyscall(syscall.SYS_OPENAT, 10, 0, 0)
+ fmt.Printf("Syscall was allowed!!!\n")
+}
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
new file mode 100644
index 000000000..f7e986589
--- /dev/null
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -0,0 +1,63 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccomp
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// SetFilter installs the given BPF program.
+//
+// This is safe to call from an afterFork context.
+//
+//go:nosplit
+func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
+ // PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 {
+ return errno
+ }
+
+ sockProg := linux.SockFprog{
+ Len: uint16(len(instrs)),
+ Filter: (*linux.BPFInstruction)(unsafe.Pointer(&instrs[0])),
+ }
+ return seccomp(linux.SECCOMP_SET_MODE_FILTER, linux.SECCOMP_FILTER_FLAG_TSYNC, unsafe.Pointer(&sockProg))
+}
+
+func isKillProcessAvailable() (bool, error) {
+ action := uint32(linux.SECCOMP_RET_KILL_PROCESS)
+ if errno := seccomp(linux.SECCOMP_GET_ACTION_AVAIL, 0, unsafe.Pointer(&action)); errno != 0 {
+ // EINVAL: SECCOMP_GET_ACTION_AVAIL not in this kernel yet.
+ // EOPNOTSUPP: SECCOMP_RET_KILL_PROCESS not supported.
+ if errno == syscall.EINVAL || errno == syscall.EOPNOTSUPP {
+ return false, nil
+ }
+ return false, errno
+ }
+ return true, nil
+}
+
+// seccomp calls seccomp(2). This is safe to call from an afterFork context.
+//
+//go:nosplit
+func seccomp(op, flags uint32, ptr unsafe.Pointer) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(SYS_SECCOMP, uintptr(op), uintptr(flags), uintptr(ptr)); errno != 0 {
+ return errno
+ }
+ return 0
+}
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
new file mode 100644
index 000000000..60f63c7a6
--- /dev/null
+++ b/pkg/secio/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "secio",
+ srcs = [
+ "full_reader.go",
+ "secio.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+)
+
+go_test(
+ name = "secio_test",
+ size = "small",
+ srcs = ["secio_test.go"],
+ library = ":secio",
+)
diff --git a/pkg/secio/full_reader.go b/pkg/secio/full_reader.go
new file mode 100644
index 000000000..aed2564bd
--- /dev/null
+++ b/pkg/secio/full_reader.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package secio
+
+import (
+ "io"
+)
+
+// FullReader adapts an io.Reader to never return partial reads with a nil
+// error.
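+//
+// For example, a read that drains the underlying reader mid-buffer returns
+// (n, io.EOF) rather than (n, nil) followed by (0, io.EOF).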
+type FullReader struct {
+ Reader io.Reader
+}
+
+// Read implements io.Reader.Read.
+func (r FullReader) Read(dst []byte) (int, error) {
+ n, err := io.ReadFull(r.Reader, dst)
+ if err == io.ErrUnexpectedEOF {
+ return n, io.EOF
+ }
+ return n, err
+}
diff --git a/pkg/secio/secio.go b/pkg/secio/secio.go
new file mode 100644
index 000000000..b43226035
--- /dev/null
+++ b/pkg/secio/secio.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package secio provides support for sectioned I/O.
+package secio
+
+import (
+ "errors"
+ "io"
+)
+
+// ErrReachedLimit is returned when SectionReader.Read or SectionWriter.Write
+// reaches its limit.
+var ErrReachedLimit = errors.New("reached limit")
+
+// SectionReader implements io.Reader on a section of an underlying io.ReaderAt.
+// It is similar to io.SectionReader, but:
+//
+// - Reading beyond the limit returns ErrReachedLimit, not io.EOF.
+//
+// - Limit overflow is handled correctly.
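+//
+// For example (a sketch mirroring secio_test.go, where buf is an io.ReaderAt
+// over "foobarbaz"): reading a 3-byte section at offset 3 yields "bar"
+// followed by ErrReachedLimit:
+//
+// r := NewSectionReader(buf, 3, 3)
+// data, err := ioutil.ReadAll(r) // data == []byte("bar"), err == ErrReachedLimit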
+type SectionReader struct {
+ r io.ReaderAt
+ off int64
+ limit int64
+}
+
+// Read implements io.Reader.Read.
+func (r *SectionReader) Read(dst []byte) (int, error) {
+ if r.limit >= 0 {
+ if max := r.limit - r.off; max < int64(len(dst)) {
+ dst = dst[:max]
+ }
+ }
+ n, err := r.r.ReadAt(dst, r.off)
+ r.off += int64(n)
+ if err == nil && r.off == r.limit {
+ err = ErrReachedLimit
+ }
+ return n, err
+}
+
+// NewOffsetReader returns an io.Reader that reads from r starting at offset
+// off.
+func NewOffsetReader(r io.ReaderAt, off int64) *SectionReader {
+ return &SectionReader{r, off, -1}
+}
+
+// NewSectionReader returns an io.Reader that reads from r starting at offset
+// off and stops with ErrReachedLimit after n bytes.
+func NewSectionReader(r io.ReaderAt, off int64, n int64) *SectionReader {
+ // If off + n overflows, it will be < 0 such that no limit applies, but
+ // this is the correct behavior as long as r prohibits reading at offsets
+ // beyond MaxInt64.
+ return &SectionReader{r, off, off + n}
+}
+
+// SectionWriter implements io.Writer on a section of an underlying
+// io.WriterAt. Writing beyond the limit returns ErrReachedLimit.
+type SectionWriter struct {
+ w io.WriterAt
+ off int64
+ limit int64
+}
+
+// Write implements io.Writer.Write.
+func (w *SectionWriter) Write(src []byte) (int, error) {
+ if w.limit >= 0 {
+ if max := w.limit - w.off; max < int64(len(src)) {
+ src = src[:max]
+ }
+ }
+ n, err := w.w.WriteAt(src, w.off)
+ w.off += int64(n)
+ if err == nil && w.off == w.limit {
+ err = ErrReachedLimit
+ }
+ return n, err
+}
+
+// NewOffsetWriter returns an io.Writer that writes to w starting at offset
+// off.
+func NewOffsetWriter(w io.WriterAt, off int64) *SectionWriter {
+ return &SectionWriter{w, off, -1}
+}
+
+// NewSectionWriter returns an io.Writer that writes to w starting at offset
+// off and stops with ErrReachedLimit after n bytes.
+func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter {
+ // If off + n overflows, it will be < 0 such that no limit applies, but
+ // this is the correct behavior as long as w prohibits writing at offsets
+ // beyond MaxInt64.
+ return &SectionWriter{w, off, off + n}
+}
diff --git a/pkg/secio/secio_test.go b/pkg/secio/secio_test.go
new file mode 100644
index 000000000..d1d905187
--- /dev/null
+++ b/pkg/secio/secio_test.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package secio
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "io/ioutil"
+ "math"
+ "testing"
+)
+
+var errEndOfBuffer = errors.New("write beyond end of buffer")
+
+// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt.
+// Reads beyond the end of the buffer return io.EOF. Writes beyond the end of
+// the buffer return errEndOfBuffer.
+type buffer struct {
+ Bytes []byte
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (b *buffer) ReadAt(dst []byte, off int64) (int, error) {
+ if off >= int64(len(b.Bytes)) {
+ return 0, io.EOF
+ }
+ n := copy(dst, b.Bytes[off:])
+ if n < len(dst) {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// WriteAt implements io.WriterAt.WriteAt.
+func (b *buffer) WriteAt(src []byte, off int64) (int, error) {
+ if off >= int64(len(b.Bytes)) {
+ return 0, errEndOfBuffer
+ }
+ n := copy(b.Bytes[off:], src)
+ if n < len(src) {
+ return n, errEndOfBuffer
+ }
+ return n, nil
+}
+
+func newBufferString(s string) *buffer {
+ return &buffer{[]byte(s)}
+}
+
+func TestOffsetReader(t *testing.T) {
+ buf := newBufferString("foobar")
+ r := NewOffsetReader(buf, 3)
+ dst, err := ioutil.ReadAll(r)
+ if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil {
+ t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want)
+ }
+}
+
+func TestSectionReader(t *testing.T) {
+ buf := newBufferString("foobarbaz")
+ r := NewSectionReader(buf, 3, 3)
+ dst, err := ioutil.ReadAll(r)
+ if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr {
+ t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr)
+ }
+}
+
+func TestSectionReaderLimitOverflow(t *testing.T) {
+ // SectionReader behaves like OffsetReader when limit overflows int64.
+ buf := newBufferString("foobar")
+ r := NewSectionReader(buf, 3, math.MaxInt64)
+ dst, err := ioutil.ReadAll(r)
+ if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil {
+ t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want)
+ }
+}
+
+func TestOffsetWriter(t *testing.T) {
+ buf := newBufferString("ABCDEF")
+ w := NewOffsetWriter(buf, 3)
+ n, err := w.Write([]byte("foobar"))
+ if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr {
+ t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) {
+ t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestSectionWriter(t *testing.T) {
+ buf := newBufferString("ABCDEFGHI")
+ w := NewSectionWriter(buf, 3, 3)
+ n, err := w.Write([]byte("foobar"))
+ if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr {
+ t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) {
+ t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestSectionWriterLimitOverflow(t *testing.T) {
+ // SectionWriter behaves like OffsetWriter when limit overflows int64.
+ buf := newBufferString("ABCDEF")
+ w := NewSectionWriter(buf, 3, math.MaxInt64)
+ n, err := w.Write([]byte("foobar"))
+ if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr {
+ t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) {
+ t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
+ }
+}
diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD
new file mode 100644
index 000000000..f57ccc170
--- /dev/null
+++ b/pkg/segment/BUILD
@@ -0,0 +1,33 @@
+load("//tools/go_generics:defs.bzl", "go_template")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+go_template(
+ name = "generic_range",
+ srcs = ["range.go"],
+ types = [
+ "T",
+ ],
+)
+
+go_template(
+ name = "generic_set",
+ srcs = [
+ "set.go",
+ "set_state.go",
+ ],
+ opt_consts = [
+ "minDegree",
+ # trackGaps must either be 0 or 1.
+ "trackGaps",
+ ],
+ types = [
+ "Key",
+ "Range",
+ "Value",
+ "Functions",
+ ],
+)
diff --git a/pkg/segment/range.go b/pkg/segment/range.go
new file mode 100644
index 000000000..4d4aeffef
--- /dev/null
+++ b/pkg/segment/range.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package segment
+
+// T is a required type parameter that must be an integral type.
+type T uint64
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type Range struct {
+ // Start is the inclusive start of the range.
+ Start T
+
+ // End is the exclusive end of the range.
+ End T
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r Range) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r Range) Length() T {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r Range) Contains(x T) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r Range) Overlaps(r2 Range) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r Range) IsSupersetOf(r2 Range) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
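+//
+// For example: Range{Start: 0, End: 10}.Intersect(Range{Start: 5, End: 15})
+// returns Range{Start: 5, End: 10}.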
+func (r Range) Intersect(r2 Range) Range {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
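+//
+// For example: Range{Start: 0, End: 10}.CanSplitAt(5) is true, while
+// CanSplitAt(0) and CanSplitAt(10) are false.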
+func (r Range) CanSplitAt(x T) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/segment/set.go b/pkg/segment/set.go
new file mode 100644
index 000000000..1a17ad9cb
--- /dev/null
+++ b/pkg/segment/set.go
@@ -0,0 +1,1754 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package segment provides tools for working with collections of segments. A
+// segment is a key-value mapping, where the key is a non-empty contiguous
+// range of values of type Key, and the value is a single value of type Value.
+//
+// Clients using this package must use the go_template_instance rule in
+// tools/go_generics/defs.bzl to create an instantiation of this
+// template package, providing types to use in place of Key, Range, Value, and
+// Functions. See pkg/segment/test/BUILD for a usage example.
+package segment
+
+import (
+ "bytes"
+ "fmt"
+)
+
+// Key is a required type parameter that must be an integral type.
+type Key uint64
+
+// Range is a required type parameter equivalent to Range<Key>.
+type Range interface{}
+
+// Value is a required type parameter.
+type Value interface{}
+
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const trackGaps = 0
+
+var _ = uint8(trackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type dynamicGap [trackGaps]Key
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Get() Key {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Set(v Key) {
+ d[:][0] = v
+}
+
+// Functions is a required type parameter that must be a struct implementing
+// the methods defined by Functions.
+type Functions interface {
+ // MinKey returns the minimum allowed key.
+ MinKey() Key
+
+ // MaxKey returns the maximum allowed key + 1.
+ MaxKey() Key
+
+ // ClearValue deinitializes the given value. (For example, if Value is a
+ // pointer or interface type, ClearValue should set it to nil.)
+ ClearValue(*Value)
+
+ // Merge attempts to merge the values corresponding to two consecutive
+ // segments. If successful, Merge returns (merged value, true). Otherwise,
+ // it returns (unspecified, false).
+ //
+ // Preconditions: r1.End == r2.Start.
+ //
+ // Postconditions: If merging succeeds, val1 and val2 are invalidated.
+ Merge(r1 Range, val1 Value, r2 Range, val2 Value) (Value, bool)
+
+ // Split splits a segment's value at a key within its range, such that the
+ // first returned value corresponds to the range [r.Start, split) and the
+ // second returned value corresponds to the range [split, r.End).
+ //
+ // Preconditions: r.Start < split < r.End.
+ //
+ // Postconditions: The original value val is invalidated.
+ Split(r Range, val Value, split Key) (Value, Value)
+}
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ minDegree = 3
+
+ maxDegree = 2 * minDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are neither safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type Set struct {
+ root node `state:".(*SegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *Set) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *Set) IsEmptyRange(r Range) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *Set) Span() Key {
+ var sz Key
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *Set) SpanRange(r Range) Key {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz Key
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *Set) FirstSegment() Iterator {
+ if s.root.nrSegments == 0 {
+ return Iterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *Set) LastSegment() Iterator {
+ if s.root.nrSegments == 0 {
+ return Iterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *Set) FirstGap() GapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return GapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *Set) LastGap() GapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return GapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *Set) Find(key Key) (Iterator, GapIterator) {
+ n := &s.root
+ for {
+ // Binary search invariant: the correct value of i lies within [lower,
+ // upper].
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return Iterator{n, i}, GapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return Iterator{}, GapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *Set) FindSegment(key Key) Iterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *Set) LowerBoundSegment(min Key) Iterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *Set) UpperBoundSegment(max Key) Iterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *Set) FindGap(key Key) GapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *Set) LowerBoundGap(min Key) GapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *Set) UpperBoundGap(max Key) GapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *Set) Add(r Range, val Value) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *Set) AddWithoutMerging(r Range, val Value) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
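+
+// Example (illustrative; assumes an int instantiation whose Functions.Merge
+// permits merging): Add coalesces adjacent segments where possible, while
+// AddWithoutMerging always keeps the new segment distinct:
+//
+//	var s Set
+//	s.Add(Range{0, 5}, val)
+//	s.Add(Range{5, 10}, val)                // may yield one segment [0, 10)
+//	s.AddWithoutMerging(Range{10, 15}, val) // always a separate segment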
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to an InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
+ return next
+ }
+ }
+ // InsertWithoutMergingUnchecked will maintain maxGap if necessary.
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
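+
+// Example (illustrative; assumes an int instantiation): since Insert
+// invalidates all existing iterators, including gap, callers keep only the
+// returned iterator:
+//
+//	if gap := s.FindGap(r.Start); gap.Ok() && gap.Range().IsSupersetOf(r) {
+//		seg := s.Insert(gap, r, val)
+//		_ = seg // gap is now invalid; use seg instead
+//	}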
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
+ return Iterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *Set) Remove(seg Iterator) GapIterator {
+ // We only want to remove directly from a leaf node.
+ if seg.node.hasChildren {
+ // Since seg.node has children, the removed segment must have a
+ // predecessor (at the end of the rightmost leaf of its left child
+ // subtree). Move the contents of that predecessor into the removed
+ // segment's position, and remove that predecessor instead. (We choose
+ // to steal the predecessor rather than the successor because removing
+ // from the end of a leaf node doesn't involve any copying unless
+ // merging is required.)
+ victim := seg.PrevSegment()
+ // This must be unchecked since until victim is removed, seg and victim
+ // overlap.
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ // Need to update the nextAdjacentNode's maxGap because the gap in between
+ // must have been modified by updating seg.Range() to victim.Range().
+ // seg.NextSegment() must exist since the last segment can't be in a
+ // non-leaf node.
+ nextAdjacentNode := seg.NextSegment().node
+ if trackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ if trackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
+ return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *Set) RemoveAll() {
+ s.root = node{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *Set) RemoveRange(r Range) GapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
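+
+// Example (illustrative; assumes an int instantiation): RemoveRange splits
+// segments straddling the boundary before removing, so neighboring segments
+// are truncated rather than dropped:
+//
+//	var s Set
+//	s.Add(Range{0, 10}, val)
+//	s.RemoveRange(Range{3, 7}) // leaves segments [0, 3) and [7, 10)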
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *Set) Merge(first, second Iterator) Iterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *Set) MergeUnchecked(first, second Iterator) Iterator {
+ if first.End() == second.Start() {
+ if mval, ok := (Functions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+ // N.B. This must be unchecked because until s.Remove(second), first
+ // overlaps second.
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ // Remove will handle the maxGap update if necessary.
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return Iterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *Set) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
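+
+// Example (illustrative; assumes an int instantiation; normalize is a
+// hypothetical helper): a common pattern is to mutate values in place and
+// then re-merge whatever became mergeable:
+//
+//	for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+//		*seg.ValuePtr() = normalize(seg.Value()) // normalize is hypothetical
+//	}
+//	s.MergeAll() // coalesce neighbors whose values now merge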
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *Set) MergeRange(r Range) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *Set) MergeAdjacent(r Range) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *Set) SplitUnchecked(seg Iterator, split Key) (Iterator, Iterator) {
+ val1, val2 := (Functions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), Range{split, end2}, val2)
+ // seg may now be invalid due to the Insert.
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *Set) SplitAt(split Key) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *Set) Isolate(seg Iterator, r Range) Iterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
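+
+// Example (illustrative; assumes an int instantiation): Isolate is the usual
+// prelude to mutating only the portion of a segment that lies inside r:
+//
+//	if seg := s.FindSegment(r.Start); seg.Ok() {
+//		seg = s.Isolate(seg, r)  // seg now lies entirely within r
+//		*seg.ValuePtr() = newVal // newVal is hypothetical
+//	}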
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. fn must not invalidate the Iterator passed to it.
+func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return GapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+ // This implies that the last segment extended all the
+ // way to the maximum value, since the gap was empty.
+ return GapIterator{}
+ }
+ }
+}
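+
+// Example (illustrative; assumes an int instantiation): the returned gap
+// tells the caller whether fn was applied across all of r:
+//
+//	gap := s.ApplyContiguous(r, func(seg Iterator) {
+//		*seg.ValuePtr() += 1 // hypothetical per-segment update
+//	})
+//	if gap.Ok() {
+//		// r was not fully covered; gap is the first hole encountered.
+//	}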
+
+// +stateify savable
+type node struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *node
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+	// The longest gap within this node. If the node is a leaf, it is simply
+	// the maximum gap among the (nrSegments+1) gaps formed by its nrSegments
+	// keys, including the 0th and nrSegments-th gaps, which may be shared
+	// with ancestor nodes. If the node is a non-leaf node, it is the maximum
+	// of all children's maxGap values.
+ maxGap dynamicGap
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [maxDegree - 1]Range
+ values [maxDegree - 1]Value
+ children [maxDegree]*node
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *node) firstSegment() Iterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return Iterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *node) lastSegment() Iterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return Iterator{n, n.nrSegments - 1}
+}
+
+func (n *node) prevSibling() *node {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *node) nextSibling() *node {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
+ if n.nrSegments < maxDegree-1 {
+ return gap
+ }
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.parent == nil {
+ // n is root. Move all segments before and after n's median segment
+ // into new child nodes adjacent to the median segment, which is now
+ // the only segment in root.
+ left := &node{
+ nrSegments: minDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &node{
+ nrSegments: minDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:minDegree-1], n.keys[:minDegree-1])
+ copy(left.values[:minDegree-1], n.values[:minDegree-1])
+ copy(right.keys[:minDegree-1], n.keys[minDegree:])
+ copy(right.values[:minDegree-1], n.values[minDegree:])
+ n.keys[0], n.values[0] = n.keys[minDegree-1], n.values[minDegree-1]
+ zeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:minDegree], n.children[:minDegree])
+ copy(right.children[:minDegree], n.children[minDegree:])
+ zeroNodeSlice(n.children[2:])
+ for i := 0; i < minDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+		// In this case, n's maxGap won't be violated, as it's still the root,
+		// but the left and right children should be updated locally since
+		// they are newly split from n.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < minDegree {
+ return GapIterator{left, gap.index}
+ }
+ return GapIterator{right, gap.index - minDegree}
+ }
+ // n is non-root. Move n's median segment into its parent node (which can't
+ // be full because we've already invoked n.parent.rebalanceBeforeInsert)
+ // and move all segments after n's median into a new sibling node (the
+ // median segment's right child subtree).
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[minDegree-1], n.values[minDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &node{
+ nrSegments: minDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:minDegree-1], n.keys[minDegree:])
+ copy(sibling.values[:minDegree-1], n.values[minDegree:])
+ zeroValueSlice(n.values[minDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:minDegree], n.children[minDegree:])
+ zeroNodeSlice(n.children[minDegree:])
+ for i := 0; i < minDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = minDegree - 1
+	// The maxGap of n's parent is not violated, because its own segments are
+	// unchanged. The maxGaps of n and its new sibling must be updated
+	// locally, as they are two new nodes split from the old n.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+ // gap.node can't be n.parent because gaps are always in leaf nodes.
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < minDegree {
+ return gap
+ }
+ return GapIterator{sibling, gap.index - minDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
+ for {
+ if n.nrSegments >= minDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+ // Root is allowed to be deficient.
+ return gap
+ }
+ // There's one other thing we can do before resorting to unsplitting.
+ // If either sibling node has at least minDegree segments, rotate that
+ // sibling's closest segment through the segment in the parent that
+ // separates us. That is, given:
+ //
+ // ... D ...
+ // / \
+ // ... B C] [E ...
+ //
+ // where the node containing E is deficient, end up with:
+ //
+ // ... C ...
+ // / \
+ // ... B] [D E ...
+ //
+ // As in Set.Remove, prefer rotating from the end of the sibling to the
+		// left: by precondition, n has fewer segments (to memcpy) than
+ // the sibling does.
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= minDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+			// The parent's maxGap does not need to be updated, as its
+			// contents are unmodified. n's and the sibling's maxGaps must be
+			// recomputed locally because keys have shifted between them.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return GapIterator{n, 0}
+ }
+ if gap.node == n {
+ return GapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= minDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+			// The parent's maxGap does not need to be updated, as its
+			// contents are unmodified. n's and the sibling's maxGaps must be
+			// recomputed locally because keys have shifted between them.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return GapIterator{n, n.nrSegments}
+ }
+ return GapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+ // Otherwise, we must unsplit.
+ p := n.parent
+ if p.nrSegments == 1 {
+ // Merge all segments in both n and its sibling back into n.parent.
+ // This is the reverse of the root splitting case in
+ // node.rebalanceBeforeInsert. (Because we require minDegree >= 3,
+ // only root can have 1 segment in this path, so this reduces the
+ // height of the tree by 1, without violating the constraint that
+ // all leaf nodes remain at the same depth.)
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ // No need to update maxGap of p as its content is not changed.
+ if gap.node == left {
+ return GapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return GapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *node
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+ // Fix up gap first since we need the old left.nrSegments, which
+ // merging will change.
+ if gap.node == right {
+ gap = GapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ Functions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+		// Update left's maxGap locally; there is no need to update p or
+		// right, because p's contents are unchanged and right is no longer
+		// reachable.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
+ // This process robs p of one segment, so recurse into rebalancing p.
+ n = p
+ }
+}
+
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// further updates are necessary.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *node) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+ // If new max equals the old maxGap, no update is needed.
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+ // Grow ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+ // p and its ancestors already contain an equal or larger gap.
+ break
+ }
+ // Only if new maxGap is larger than parent's
+ // old maxGap, propagate this update to parent.
+ p.maxGap.Set(max)
+ }
+ return
+ }
+ // Shrink ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+ // p and its ancestors still contain a larger gap.
+ break
+ }
+		// If the new max is smaller than the old maxGap, and this gap used
+		// to be the maxGap of its parent, iterate over the parent's children
+		// and calculate the parent's new maxGap. (The parent may well have
+		// two children with the old maxGap, so we need to check anyway.)
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+ // p and its ancestors still contain a gap of at least equal size.
+ break
+ }
+ // If p's new maxGap differs from the old one, propagate this update.
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates the maxGap of the calling node only, without
+// propagating to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *node) updateMaxGapLocal() {
+ if !n.hasChildren {
+ // Leaf node iterates its gaps.
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+ // Non-leaf node iterates its children.
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates over the gaps within a leaf node and returns
+// the maximum.
+//
+// Preconditions: n must be a leaf node.
+func (n *node) calculateMaxGapLeaf() Key {
+ max := GapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (GapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates over the children's maxGaps within an
+// internal node n and returns the maximum.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *node) calculateMaxGapInternal() Key {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap with length at least
+// minSize in the subtree rooted by n. If no such gap exists, it returns a
+// terminal gap iterator.
+func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap with length at least
+// minSize in the subtree rooted by n. If no such gap exists, it returns a
+// terminal gap iterator.
+func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// An Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type Iterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *node
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg Iterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg Iterator) Range() Range {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg Iterator) Start() Key {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg Iterator) End() Key {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg Iterator) SetRangeUnchecked(r Range) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg Iterator) SetRange(r Range) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg Iterator) SetStartUnchecked(start Key) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg Iterator) SetStart(start Key) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg Iterator) SetEndUnchecked(end Key) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg Iterator) SetEnd(end Key) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg Iterator) Value() Value {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg Iterator) ValuePtr() *Value {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg Iterator) SetValue(val Value) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg Iterator) PrevSegment() Iterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return Iterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return Iterator{}
+ }
+ return segmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg Iterator) NextSegment() Iterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return Iterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return Iterator{}
+ }
+ return segmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg Iterator) PrevGap() GapIterator {
+ if seg.node.hasChildren {
+ // Note that this isn't recursive because the last segment in a subtree
+ // must be in a leaf node.
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return GapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg Iterator) NextGap() GapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return GapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg Iterator) PrevNonEmpty() (Iterator, GapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return Iterator{}, gap
+ }
+ return gap.PrevSegment(), GapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg Iterator) NextNonEmpty() (Iterator, GapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return Iterator{}, gap
+ }
+ return gap.NextSegment(), GapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type GapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *node
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap GapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap GapIterator) Range() Range {
+ return Range{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap GapIterator) Start() Key {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return Functions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap GapIterator) End() Key {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return Functions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap GapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap GapIterator) PrevSegment() Iterator {
+ return segmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap GapIterator) NextSegment() Iterator {
+ return segmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap GapIterator) PrevGap() GapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return GapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap GapIterator) NextGap() GapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return GapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// NextLargeEnoughGap returns the iterated gap's first subsequent gap with
+// length at least minSize. If no such gap exists, it returns a terminal gap
+// iterator (the search does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+		// If gap is the trailing gap of a non-leaf node,
+		// translate it to the equivalent gap at the leaf level.
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
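+
+// Example (illustrative; assumes a gap-tracking instantiation with
+// trackGaps == 1, like the test package's gapSet): maxGap allows a
+// first-fit search to skip subtrees containing no large enough gap:
+//
+//	gap := s.FirstGap()
+//	if gap.Range().Length() < minSize {
+//		gap = gap.NextLargeEnoughGap(minSize) // excludes gap itself
+//	}
+//	if gap.Ok() {
+//		// gap is the first gap of length >= minSize.
+//	}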
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to perform the actual recursion.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator {
+	// Crawl up the tree if the current node contains no large enough gap,
+	// or if the current gap is the trailing one at the leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate subsequent gaps.
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+		// If gap is the trailing gap of a non-leaf node, crawl up to the
+		// parent again and recurse.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first preceding gap with
+// length at least minSize. If no such gap exists, it returns a terminal gap
+// iterator (the search does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+		// If gap is the first gap of a non-leaf node,
+		// translate it to the equivalent gap at the leaf level.
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to perform the actual recursion.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator {
+	// Crawl up the tree if the current node contains no large enough gap,
+	// or if the current gap is the first one at the leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate previous gaps.
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+		// If gap is the first gap of a non-leaf node, crawl up to the
+		// parent again and recurse.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func segmentBeforePosition(n *node, i int) Iterator {
+ for i == 0 {
+ if n.parent == nil {
+ return Iterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return Iterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func segmentAfterPosition(n *node, i int) Iterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return Iterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return Iterator{n, i}
+}
+
+func zeroValueSlice(slice []Value) {
+ // TODO(jamieliu): check if Go is actually smart enough to optimize a
+ // ClearValue that assigns nil to a memset here.
+ for i := range slice {
+ Functions{}.ClearValue(&slice[i])
+ }
+}
+
+func zeroNodeSlice(slice []*node) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *Set) String() string {
+ return s.root.String()
+}
+
+// String stringifies a node (and all of its children) for debugging.
+func (n *node) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+			child.writeDebugString(buf, cprefix)
+ }
+ buf.WriteString(prefix)
+ if n.hasChildren {
+ if trackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type SegmentDataSlices struct {
+ Start []Key
+ End []Key
+ Values []Value
+}
+
+// ExportSortedSlices returns a copy of all segments in the given set, in
+// ascending key order.
+func (s *Set) ExportSortedSlices() *SegmentDataSlices {
+ var sds SegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlices initializes the given set from the given slices.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := Range{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
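+
+// Example (illustrative; assumes an int instantiation): ExportSortedSlices
+// and ImportSortedSlices form the save/restore round trip used by
+// set_state.go:
+//
+//	sds := s.ExportSortedSlices()
+//	var restored Set
+//	if err := restored.ImportSortedSlices(sds); err != nil {
+//		// sds was malformed (unsorted or overlapping segments).
+//	}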
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error {
+ havePrev := false
+ prev := Key(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Like segmentTestCheck, this should only be used for testing.
+func (s *Set) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
diff --git a/pkg/segment/set_state.go b/pkg/segment/set_state.go
new file mode 100644
index 000000000..76de92591
--- /dev/null
+++ b/pkg/segment/set_state.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package segment
+
+func (s *Set) saveRoot() *SegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *Set) loadRoot(sds *SegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
new file mode 100644
index 000000000..131bf09b9
--- /dev/null
+++ b/pkg/segment/test/BUILD
@@ -0,0 +1,68 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "int_range",
+ out = "int_range.go",
+ package = "segment",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "int",
+ },
+)
+
+go_template_instance(
+ name = "int_set",
+ out = "int_set.go",
+ package = "segment",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "int",
+ "Range": "Range",
+ "Value": "int",
+ "Functions": "setFunctions",
+ },
+)
+
+go_template_instance(
+ name = "gap_set",
+ out = "gap_set.go",
+ consts = {
+ "trackGaps": "1",
+ },
+ package = "segment",
+ prefix = "gap",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "int",
+ "Range": "Range",
+ "Value": "int",
+ "Functions": "gapSetFunctions",
+ },
+)
+
+go_library(
+ name = "segment",
+ testonly = 1,
+ srcs = [
+ "gap_set.go",
+ "int_range.go",
+ "int_set.go",
+ "set_functions.go",
+ ],
+ deps = [
+ "//pkg/state",
+ ],
+)
+
+go_test(
+ name = "segment_test",
+ size = "small",
+ srcs = ["segment_test.go"],
+ library = ":segment",
+)
diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go
new file mode 100644
index 000000000..85fa19096
--- /dev/null
+++ b/pkg/segment/test/segment_test.go
@@ -0,0 +1,865 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package segment
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+)
+
+const (
+ // testSize is the baseline number of elements inserted into sets under
+ // test, and is chosen to be large enough to ensure interesting amounts of
+ // tree rebalancing.
+ //
+ // Note that because checkSet is called between each insertion/removal in
+ // some tests that use it, tests may be quadratic in testSize.
+ testSize = 8000
+
+ // valueOffset is the difference between the value and start of test
+ // segments.
+ valueOffset = 100000
+
+ // intervalLength is the interval used by random gap tests.
+ intervalLength = 10
+)
+
+func shuffle(xs []int) {
+ rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] })
+}
+
+func randIntervalPermutation(size int) []int {
+ p := make([]int, size)
+ for i := range p {
+ p[i] = intervalLength * i
+ }
+ shuffle(p)
+ return p
+}
+
+// validate can be passed to segmentTestCheck.
+func validate(nr int, r Range, v int) error {
+ if got, want := v, r.Start+valueOffset; got != want {
+ return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want)
+ }
+ return nil
+}
+
+// checkSetMaxGap returns an error if the maxGap of any node in s is not
+// well maintained.
+func checkSetMaxGap(s *gapSet) error {
+ n := s.root
+ return checkNodeMaxGap(&n)
+}
+
+// checkNodeMaxGap returns an error if any maxGap in the subtree rooted by n
+// is not well maintained.
+func checkNodeMaxGap(n *gapnode) error {
+ var max int
+ if !n.hasChildren {
+ max = n.calculateMaxGapLeaf()
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ child := n.children[i]
+ if err := checkNodeMaxGap(child); err != nil {
+ return err
+ }
+ if temp := child.maxGap.Get(); i == 0 || temp > max {
+ max = temp
+ }
+ }
+ }
+ if max != n.maxGap.Get() {
+ return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap)
+ }
+ return nil
+}
+
+func TestAddRandom(t *testing.T) {
+ var s Set
+ order := rand.Perm(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestRemoveRandom(t *testing.T) {
+ var s Set
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapAddRandom(t *testing.T) {
+ var s gapSet
+ order := rand.Perm(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithRandomInterval(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ nrInsertions := 1 // all inserted intervals are adjacent, so they merge into a single segment
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order)
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapRemoveRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapRemoveHalfRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := randIntervalPermutation(testSize)
+ order = order[:testSize/2]
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ var nrRemovals int
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestNextLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestPrevLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ end := s.LastSegment().End()
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestAddSequentialAdjacent(t *testing.T) {
+ var s Set
+ var nrInsertions int
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ }
+
+ first := s.FirstSegment()
+ gotSeg, gotGap := first.PrevNonEmpty()
+ if wantGap := s.FirstGap(); gotSeg.Ok() || gotGap != wantGap {
+ t.Errorf("FirstSegment().PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap)
+ }
+ gotSeg, gotGap = first.NextNonEmpty()
+ if wantSeg := first.NextSegment(); gotSeg != wantSeg || gotGap.Ok() {
+ t.Errorf("FirstSegment().NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg)
+ }
+
+ last := s.LastSegment()
+ gotSeg, gotGap = last.PrevNonEmpty()
+ if wantSeg := last.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() {
+ t.Errorf("LastSegment().PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg)
+ }
+ gotSeg, gotGap = last.NextNonEmpty()
+ if wantGap := s.LastGap(); gotSeg.Ok() || gotGap != wantGap {
+ t.Errorf("LastSegment().NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap)
+ }
+
+ for seg := first.NextSegment(); seg != last; seg = seg.NextSegment() {
+ gotSeg, gotGap = seg.PrevNonEmpty()
+ if wantSeg := seg.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() {
+ t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg)
+ }
+ gotSeg, gotGap = seg.NextNonEmpty()
+ if wantSeg := seg.NextSegment(); gotSeg != wantSeg || gotGap.Ok() {
+ t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg)
+ }
+ }
+}
+
+func TestAddSequentialNonAdjacent(t *testing.T) {
+ var s Set
+ var nrInsertions int
+ for i := 0; i < testSize; i++ {
+ // The range here differs from TestAddSequentialAdjacent so that
+ // consecutive segments are not adjacent.
+ if !s.AddWithoutMerging(Range{2 * i, 2*i + 1}, 2*i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ }
+
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ gotSeg, gotGap := seg.PrevNonEmpty()
+ if wantGap := seg.PrevGap(); gotSeg.Ok() || gotGap != wantGap {
+ t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap)
+ }
+ gotSeg, gotGap = seg.NextNonEmpty()
+ if wantGap := seg.NextGap(); gotSeg.Ok() || gotGap != wantGap {
+ t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap)
+ }
+ }
+}
+
+func TestMergeSplit(t *testing.T) {
+ tests := []struct {
+ name string
+ initial []Range
+ split bool
+ splitAddr int
+ final []Range
+ }{
+ {
+ name: "Add merges after existing segment",
+ initial: []Range{{1000, 1100}, {1100, 1200}},
+ final: []Range{{1000, 1200}},
+ },
+ {
+ name: "Add merges before existing segment",
+ initial: []Range{{1100, 1200}, {1000, 1100}},
+ final: []Range{{1000, 1200}},
+ },
+ {
+ name: "Add merges between existing segments",
+ initial: []Range{{1000, 1100}, {1200, 1300}, {1100, 1200}},
+ final: []Range{{1000, 1300}},
+ },
+ {
+ name: "SplitAt does nothing at a free address",
+ initial: []Range{{100, 200}},
+ split: true,
+ splitAddr: 300,
+ final: []Range{{100, 200}},
+ },
+ {
+ name: "SplitAt does nothing at the beginning of a segment",
+ initial: []Range{{100, 200}},
+ split: true,
+ splitAddr: 100,
+ final: []Range{{100, 200}},
+ },
+ {
+ name: "SplitAt does nothing at the end of a segment",
+ initial: []Range{{100, 200}},
+ split: true,
+ splitAddr: 200,
+ final: []Range{{100, 200}},
+ },
+ {
+ name: "SplitAt splits in the middle of a segment",
+ initial: []Range{{100, 200}},
+ split: true,
+ splitAddr: 150,
+ final: []Range{{100, 150}, {150, 200}},
+ },
+ }
+Tests:
+ for _, test := range tests {
+ var s Set
+ for _, r := range test.initial {
+ if !s.Add(r, 0) {
+ t.Errorf("%s: Add(%v) failed; set contents:\n%v", test.name, r, &s)
+ continue Tests
+ }
+ }
+ if test.split {
+ s.SplitAt(test.splitAddr)
+ }
+ var i int
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if i >= len(test.final) {
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
+ continue Tests
+ }
+ if got, want := seg.Range(), test.final[i]; got != want {
+ t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s)
+ continue Tests
+ }
+ i++
+ }
+ if i < len(test.final) {
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s)
+ }
+ }
+}
+
+func TestIsolate(t *testing.T) {
+ tests := []struct {
+ name string
+ initial Range
+ bounds Range
+ final []Range
+ }{
+ {
+ name: "Isolate does not split a segment that falls inside bounds",
+ initial: Range{100, 200},
+ bounds: Range{100, 200},
+ final: []Range{{100, 200}},
+ },
+ {
+ name: "Isolate splits at beginning of segment",
+ initial: Range{50, 200},
+ bounds: Range{100, 200},
+ final: []Range{{50, 100}, {100, 200}},
+ },
+ {
+ name: "Isolate splits at end of segment",
+ initial: Range{100, 250},
+ bounds: Range{100, 200},
+ final: []Range{{100, 200}, {200, 250}},
+ },
+ {
+ name: "Isolate splits at beginning and end of segment",
+ initial: Range{50, 250},
+ bounds: Range{100, 200},
+ final: []Range{{50, 100}, {100, 200}, {200, 250}},
+ },
+ }
+Tests:
+ for _, test := range tests {
+ var s Set
+ seg := s.Insert(s.FirstGap(), test.initial, 0)
+ seg = s.Isolate(seg, test.bounds)
+ if !test.bounds.IsSupersetOf(seg.Range()) {
+ t.Errorf("%s: Isolated segment %v lies outside bounds %v; set contents:\n%v", test.name, seg.Range(), test.bounds, &s)
+ }
+ var i int
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if i >= len(test.final) {
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
+ continue Tests
+ }
+ if got, want := seg.Range(), test.final[i]; got != want {
+ t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s)
+ continue Tests
+ }
+ i++
+ }
+ if i < len(test.final) {
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s)
+ }
+ }
+}
+
+func benchmarkAddSequential(b *testing.B, size int) {
+ for n := 0; n < b.N; n++ {
+ var s Set
+ for i := 0; i < size; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ }
+}
+
+func benchmarkAddRandom(b *testing.B, size int) {
+ order := rand.Perm(size)
+
+ b.ResetTimer()
+ for n := 0; n < b.N; n++ {
+ var s Set
+ for _, i := range order {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ }
+}
+
+func benchmarkFindSequential(b *testing.B, size int) {
+ var s Set
+ for i := 0; i < size; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+
+ b.ResetTimer()
+ for n := 0; n < b.N; n++ {
+ for i := 0; i < size; i++ {
+ if seg := s.FindSegment(i); !seg.Ok() {
+ b.Fatalf("Failed to find segment %d", i)
+ }
+ }
+ }
+}
+
+func benchmarkFindRandom(b *testing.B, size int) {
+ var s Set
+ for i := 0; i < size; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := rand.Perm(size)
+
+ b.ResetTimer()
+ for n := 0; n < b.N; n++ {
+ for _, i := range order {
+ if si := s.FindSegment(i); !si.Ok() {
+ b.Fatalf("Failed to find segment %d", i)
+ }
+ }
+ }
+}
+
+func benchmarkIteration(b *testing.B, size int) {
+ var s Set
+ for i := 0; i < size; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+
+ b.ResetTimer()
+ var count uint64
+ for n := 0; n < b.N; n++ {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ count++
+ }
+ }
+ if got, want := count, uint64(size)*uint64(b.N); got != want {
+ b.Fatalf("Iterated wrong number of segments: got %d, wanted %d", got, want)
+ }
+}
+
+func benchmarkAddFindRemoveSequential(b *testing.B, size int) {
+ for n := 0; n < b.N; n++ {
+ var s Set
+ for i := 0; i < size; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ for i := 0; i < size; i++ {
+ seg := s.FindSegment(i)
+ if !seg.Ok() {
+ b.Fatalf("Failed to find segment %d", i)
+ }
+ s.Remove(seg)
+ }
+ if !s.IsEmpty() {
+ b.Fatalf("Set not empty after all removals:\n%v", &s)
+ }
+ }
+}
+
+func benchmarkAddFindRemoveRandom(b *testing.B, size int) {
+ order := rand.Perm(size)
+
+ b.ResetTimer()
+ for n := 0; n < b.N; n++ {
+ var s Set
+ for _, i := range order {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i) {
+ b.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ for _, i := range order {
+ seg := s.FindSegment(i)
+ if !seg.Ok() {
+ b.Fatalf("Failed to find segment %d", i)
+ }
+ s.Remove(seg)
+ }
+ if !s.IsEmpty() {
+ b.Fatalf("Set not empty after all removals:\n%v", &s)
+ }
+ }
+}
+
+// Although we don't generally expect our segment sets to get this big, the
+// larger sizes are useful for emulating the effect of cache pressure.
+var testSizes = []struct {
+ desc string
+ size int
+}{
+ {"64", 1 << 6},
+ {"256", 1 << 8},
+ {"1K", 1 << 10},
+ {"4K", 1 << 12},
+ {"16K", 1 << 14},
+ {"64K", 1 << 16},
+}
+
+func BenchmarkAddSequential(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkAddSequential(b, test.size)
+ })
+ }
+}
+
+func BenchmarkAddRandom(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkAddRandom(b, test.size)
+ })
+ }
+}
+
+func BenchmarkFindSequential(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkFindSequential(b, test.size)
+ })
+ }
+}
+
+func BenchmarkFindRandom(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkFindRandom(b, test.size)
+ })
+ }
+}
+
+func BenchmarkIteration(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkIteration(b, test.size)
+ })
+ }
+}
+
+func BenchmarkAddFindRemoveSequential(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkAddFindRemoveSequential(b, test.size)
+ })
+ }
+}
+
+func BenchmarkAddFindRemoveRandom(b *testing.B) {
+ for _, test := range testSizes {
+ b.Run(test.desc, func(b *testing.B) {
+ benchmarkAddFindRemoveRandom(b, test.size)
+ })
+ }
+}
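
The two LargeEnoughGap tests above validate the maxGap-accelerated search by
cross-checking it against a plain scan over every gap. The following
self-contained Go program is a minimal sketch of that linear-scan reference
(the names rng, gaps, and nextLargeEnough are illustrative and not part of
this patch); an accelerated iterator is correct exactly when it reports the
same gap sequence while skipping subtrees whose maxGap is below the threshold.

    package main

    import "fmt"

    // rng is a half-open range [start, end), mirroring the test's Range.
    type rng struct{ start, end int }

    // gaps returns the free ranges between consecutive sorted, disjoint segments.
    func gaps(segs []rng) []rng {
        var out []rng
        for i := 0; i+1 < len(segs); i++ {
            if segs[i].end < segs[i+1].start {
                out = append(out, rng{segs[i].end, segs[i+1].start})
            }
        }
        return out
    }

    // nextLargeEnough is the reference semantics: the first gap with length
    // >= min, found by visiting every gap in order.
    func nextLargeEnough(gs []rng, min int) (rng, bool) {
        for _, g := range gs {
            if g.end-g.start >= min {
                return g, true
            }
        }
        return rng{}, false
    }

    func main() {
        segs := []rng{{0, 5}, {8, 10}, {17, 20}, {21, 40}}
        g, ok := nextLargeEnough(gaps(segs), 7)
        fmt.Println(g, ok) // {10 17} true
    }
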
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
new file mode 100644
index 000000000..7cd895cc7
--- /dev/null
+++ b/pkg/segment/test/set_functions.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package segment
+
+type setFunctions struct{}
+
+// MinKey returns the minimum key for the set.
+func (s setFunctions) MinKey() int {
+ return -s.MaxKey() - 1
+}
+
+// MaxKey returns the maximum key for the set.
+func (setFunctions) MaxKey() int {
+ return int(^uint(0) >> 1)
+}
+
+func (setFunctions) ClearValue(*int) {}
+
+func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) {
+ return val1, true
+}
+
+func (setFunctions) Split(_ Range, val int, _ int) (int, int) {
+ return val, val
+}
+
+type gapSetFunctions struct {
+ setFunctions
+}
+
+// MinKey is adjusted to ensure that no addition overflows in test cases;
+// e.g. a gap with range {MinInt32, 2} would overflow in Range().Length().
+//
+// Normally keys should be unsigned to avoid these issues.
+func (s gapSetFunctions) MinKey() int {
+ return s.setFunctions.MinKey() / 2
+}
+
+// MaxKey returns the maximum key for the set.
+func (s gapSetFunctions) MaxKey() int {
+ return s.setFunctions.MaxKey() / 2
+}
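
The halved key space in gapSetFunctions above exists to keep gap-length
arithmetic representable. A standalone demonstration of the overflow it
avoids (assumes a 64-bit int; not part of this patch):

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        // With the full int key space, a gap reaching down to MinKey wraps
        // around: end - start produces a negative "length".
        start, end := math.MinInt, 2
        fmt.Println(end - start) // -9223372036854775806

        // Halving the key space, as gapSetFunctions.MinKey does, keeps the
        // subtraction representable.
        start = math.MinInt / 2
        fmt.Println(end - start) // 4611686018427387906
    }
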
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
new file mode 100644
index 000000000..e759dc36f
--- /dev/null
+++ b/pkg/sentry/BUILD
@@ -0,0 +1,14 @@
+package(licenses = ["notice"])
+
+# The "internal" package_group should be used as much as possible by packages
+# that should remain Sentry-internal (i.e. not be exposed directly to command
+# line tooling or APIs).
+package_group(
+ name = "internal",
+ packages = [
+ "//pkg/sentry/...",
+ "//runsc/...",
+ # Code generated by go_marshal relies on go_marshal libraries.
+ "//tools/go_marshal/...",
+ ],
+)
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
new file mode 100644
index 000000000..901e0f320
--- /dev/null
+++ b/pkg/sentry/arch/BUILD
@@ -0,0 +1,48 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "arch",
+ srcs = [
+ "aligned.go",
+ "arch.go",
+ "arch_aarch64.go",
+ "arch_amd64.go",
+ "arch_amd64.s",
+ "arch_arm64.go",
+ "arch_state_x86.go",
+ "arch_x86.go",
+ "arch_x86_impl.go",
+ "auxv.go",
+ "signal.go",
+ "signal_act.go",
+ "signal_amd64.go",
+ "signal_arm64.go",
+ "signal_info.go",
+ "signal_stack.go",
+ "stack.go",
+ "syscalls_amd64.go",
+ "syscalls_arm64.go",
+ ],
+ marshal = True,
+ visibility = ["//:sandbox"],
+ deps = [
+ ":registers_go_proto",
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/cpuid",
+ "//pkg/log",
+ "//pkg/sentry/limits",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ ],
+)
+
+proto_library(
+ name = "registers",
+ srcs = ["registers.proto"],
+ visibility = ["//visibility:public"],
+)
diff --git a/pkg/sentry/arch/aligned.go b/pkg/sentry/arch/aligned.go
new file mode 100644
index 000000000..df01a903d
--- /dev/null
+++ b/pkg/sentry/arch/aligned.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "reflect"
+)
+
+// alignedBytes returns a slice of size bytes, aligned in memory to the given
+// alignment. This is used because we require certain structures to be aligned
+// in a specific way (for example, the X86 floating point data).
+func alignedBytes(size, alignment uint) []byte {
+ data := make([]byte, size+alignment-1)
+ offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
+ if offset == 0 {
+ return data[:size:size]
+ }
+ return data[alignment-offset:][:size:size]
+}
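
The over-allocate-and-slice trick in alignedBytes is easy to verify in
isolation. A minimal standalone check (not part of this patch; the helper is
re-declared here purely for illustration):

    package main

    import (
        "fmt"
        "reflect"
    )

    // alignedBytes over-allocates by alignment-1 bytes, then slices forward
    // to the first aligned address, capping capacity at size.
    func alignedBytes(size, alignment uint) []byte {
        data := make([]byte, size+alignment-1)
        offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
        if offset == 0 {
            return data[:size:size]
        }
        return data[alignment-offset:][:size:size]
    }

    func main() {
        b := alignedBytes(512, 16) // e.g. a 16-byte-aligned x86 FXSAVE area
        addr := reflect.ValueOf(b).Index(0).Addr().Pointer()
        fmt.Println(len(b), addr%16) // 512 0
    }
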
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
new file mode 100644
index 000000000..a903d031c
--- /dev/null
+++ b/pkg/sentry/arch/arch.go
@@ -0,0 +1,366 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package arch provides abstractions around architecture-dependent details,
+// such as syscall calling conventions, native types, etc.
+package arch
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Arch describes an architecture.
+type Arch int
+
+const (
+ // AMD64 is the x86-64 architecture.
+ AMD64 Arch = iota
+ // ARM64 is the aarch64 architecture.
+ ARM64
+)
+
+// String implements fmt.Stringer.
+func (a Arch) String() string {
+ switch a {
+ case AMD64:
+ return "amd64"
+ case ARM64:
+ return "arm64"
+ default:
+ return fmt.Sprintf("Arch(%d)", a)
+ }
+}
+
+// FloatingPointData is a generic type, and will always be passed as a pointer.
+// We rely on the individual arch implementations to meet all the necessary
+// requirements. For example, on x86 the region must be 16-byte aligned and 512
+// bytes in size.
+type FloatingPointData byte
+
+// Context provides architecture-dependent information for a specific thread.
+//
+// NOTE(b/34169503): Currently we use uintptr here to refer to a generic native
+// register value. While this will work for the foreseeable future, it isn't
+// strictly correct. We may want to create some abstraction that makes this
+// more clear or enables us to store values of arbitrary widths. This is
+// particularly true for RegisterMap().
+type Context interface {
+ // Arch returns the architecture for this Context.
+ Arch() Arch
+
+ // Native converts a generic type to a native value.
+ //
+ // Because the architecture is not specified here, we may be dealing
+ // with return values of varying sizes (for example ARCH_GETFS). This
+ // is a simple utility function to convert to the native size in these
+ // cases, and then we can CopyOut.
+ Native(val uintptr) interface{}
+
+ // Value converts a native type back to a generic value.
+ // Once a value has been converted to native via the above call -- it
+ // can be converted back here.
+ Value(val interface{}) uintptr
+
+ // Width returns the number of bytes for a native value.
+ Width() uint
+
+ // Fork creates a clone of the context.
+ Fork() Context
+
+ // SyscallNo returns the syscall number.
+ SyscallNo() uintptr
+
+ // SyscallSaveOrig saves the original register value.
+ SyscallSaveOrig()
+
+ // SyscallArgs returns the syscall arguments in an array.
+ SyscallArgs() SyscallArguments
+
+ // Return returns the return value for a system call.
+ Return() uintptr
+
+ // SetReturn sets the return value for a system call.
+ SetReturn(value uintptr)
+
+ // RestartSyscall reverses over the current syscall instruction, such that
+ // when the application resumes execution the syscall will be re-attempted.
+ RestartSyscall()
+
+ // RestartSyscallWithRestartBlock reverses over the current syscall
+ // instruction and overwrites the current syscall number with that of
+ // restart_syscall(2). This causes the application to restart the current
+ // syscall with a custom function when execution resumes.
+ RestartSyscallWithRestartBlock()
+
+ // IP returns the current instruction pointer.
+ IP() uintptr
+
+ // SetIP sets the current instruction pointer.
+ SetIP(value uintptr)
+
+ // Stack returns the current stack pointer.
+ Stack() uintptr
+
+ // SetStack sets the current stack pointer.
+ SetStack(value uintptr)
+
+ // TLS returns the current TLS pointer.
+ TLS() uintptr
+
+ // SetTLS sets the current TLS pointer. Returns false if value is invalid.
+ SetTLS(value uintptr) bool
+
+ // SetOldRSeqInterruptedIP sets the register that contains the old IP
+ // when an "old rseq" restartable sequence is interrupted.
+ SetOldRSeqInterruptedIP(value uintptr)
+
+ // StateData returns a pointer to underlying architecture state.
+ StateData() *State
+
+ // RegisterMap returns a map of all registers.
+ RegisterMap() (map[string]uintptr, error)
+
+ // NewSignalAct returns a new object that is equivalent to struct sigaction
+ // in the guest architecture.
+ NewSignalAct() NativeSignalAct
+
+ // NewSignalStack returns a new object that is equivalent to stack_t in the
+ // guest architecture.
+ NewSignalStack() NativeSignalStack
+
+ // SignalSetup modifies the context in preparation for handling the
+ // given signal.
+ //
+ // st is the stack where the signal handler frame should be
+ // constructed.
+ //
+ // act is the SignalAct that specifies how this signal is being
+ // handled.
+ //
+ // info is the SignalInfo of the signal being delivered.
+ //
+ // alt is the alternate signal stack (even if the alternate signal
+ // stack is not going to be used).
+ //
+ // sigset is the signal mask before entering the signal handler.
+ SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error
+
+ // SignalRestore restores context after returning from a signal
+ // handler.
+ //
+ // st is the current thread stack.
+ //
+ // rt is true if SignalRestore is being entered from rt_sigreturn and
+ // false if SignalRestore is being entered from sigreturn.
+ // SignalRestore returns the thread's new signal mask.
+ SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error)
+
+ // CPUIDEmulate emulates a CPUID instruction according to current register state.
+ CPUIDEmulate(l log.Logger)
+
+ // SingleStep returns true if single stepping is enabled.
+ SingleStep() bool
+
+ // SetSingleStep enables single stepping.
+ SetSingleStep()
+
+ // ClearSingleStep disables single stepping.
+ ClearSingleStep()
+
+ // FloatingPointData will be passed to underlying save routines.
+ FloatingPointData() *FloatingPointData
+
+ // NewMmapLayout returns a layout for a new MM, where MinAddr for the
+ // returned layout must be no lower than min, and MaxAddr for the returned
+ // layout must be no higher than max. Repeated calls to NewMmapLayout may
+ // return different layouts.
+ NewMmapLayout(min, max usermem.Addr, limits *limits.LimitSet) (MmapLayout, error)
+
+ // PIELoadAddress returns a preferred load address for a
+ // position-independent executable within l.
+ PIELoadAddress(l MmapLayout) usermem.Addr
+
+ // FeatureSet returns the FeatureSet in use in this context.
+ FeatureSet() *cpuid.FeatureSet
+
+ // Hack around our package dependencies being too broken to support the
+ // equivalent of arch_ptrace():
+
+ // PtracePeekUser implements ptrace(PTRACE_PEEKUSR).
+ PtracePeekUser(addr uintptr) (interface{}, error)
+
+ // PtracePokeUser implements ptrace(PTRACE_POKEUSR).
+ PtracePokeUser(addr, data uintptr) error
+
+ // PtraceGetRegs implements ptrace(PTRACE_GETREGS) by writing the
+ // general-purpose registers represented by this Context to dst and
+ // returning the number of bytes written.
+ PtraceGetRegs(dst io.Writer) (int, error)
+
+ // PtraceSetRegs implements ptrace(PTRACE_SETREGS) by reading
+ // general-purpose registers from src into this Context and returning the
+ // number of bytes read.
+ PtraceSetRegs(src io.Reader) (int, error)
+
+ // PtraceGetFPRegs implements ptrace(PTRACE_GETFPREGS) by writing the
+ // floating-point registers represented by this Context to addr in dst and
+ // returning the number of bytes written.
+ PtraceGetFPRegs(dst io.Writer) (int, error)
+
+ // PtraceSetFPRegs implements ptrace(PTRACE_SETFPREGS) by reading
+ // floating-point registers from src into this Context and returning the
+ // number of bytes read.
+ PtraceSetFPRegs(src io.Reader) (int, error)
+
+ // PtraceGetRegSet implements ptrace(PTRACE_GETREGSET) by writing the
+ // register set given by architecture-defined value regset from this
+ // Context to dst and returning the number of bytes written, which must be
+ // less than or equal to maxlen.
+ PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error)
+
+ // PtraceSetRegSet implements ptrace(PTRACE_SETREGSET) by reading the
+ // register set given by architecture-defined value regset from src and
+ // returning the number of bytes read, which must be less than or equal to
+ // maxlen.
+ PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error)
+
+ // FullRestore returns 'true' if all CPU registers must be restored
+ // when switching to the untrusted application. Typically a task enters
+ // and leaves the kernel via a system call. Platform.Switch() may
+ // optimize for this by not saving/restoring all registers if allowed
+ // by the ABI. For example, the amd64 ABI specifies that syscall clobbers
+ // %rcx and %r11. If FullRestore returns true then these optimizations
+ // must be disabled and all registers restored.
+ FullRestore() bool
+}
+
+// MmapDirection is a search direction for mmaps.
+type MmapDirection int
+
+const (
+ // MmapBottomUp instructs mmap to prefer lower addresses.
+ MmapBottomUp MmapDirection = iota
+
+ // MmapTopDown instructs mmap to prefer higher addresses.
+ MmapTopDown
+)
+
+// MmapLayout defines the layout of the user address space for a particular
+// MemoryManager.
+//
+// Note that "highest address" below is always exclusive.
+//
+// +stateify savable
+type MmapLayout struct {
+ // MinAddr is the lowest mappable address.
+ MinAddr usermem.Addr
+
+ // MaxAddr is the highest mappable address.
+ MaxAddr usermem.Addr
+
+ // BottomUpBase is the lowest address that may be returned for a
+ // MmapBottomUp mmap.
+ BottomUpBase usermem.Addr
+
+ // TopDownBase is the highest address that may be returned for a
+ // MmapTopDown mmap.
+ TopDownBase usermem.Addr
+
+ // DefaultDirection is the direction for most non-fixed mmaps in this
+ // layout.
+ DefaultDirection MmapDirection
+
+ // MaxStackRand is the maximum randomization to apply to stack
+ // allocations to maintain a proper gap between the stack and
+ // TopDownBase.
+ MaxStackRand uint64
+}
+
+// Valid returns true if this layout is valid.
+func (m *MmapLayout) Valid() bool {
+ if m.MinAddr > m.MaxAddr {
+ return false
+ }
+ if m.BottomUpBase < m.MinAddr {
+ return false
+ }
+ if m.BottomUpBase > m.MaxAddr {
+ return false
+ }
+ if m.TopDownBase < m.MinAddr {
+ return false
+ }
+ if m.TopDownBase > m.MaxAddr {
+ return false
+ }
+ return true
+}
+
+// SyscallArgument is an argument supplied to a syscall implementation. The
+// methods used to access the arguments are named after the ***C type name*** and
+// they convert to the closest Go type available. For example, Int() refers to a
+// 32-bit signed integer argument represented in Go as an int32.
+//
+// Using the accessor methods guarantees that the conversion between types is
+// correct, taking into account size and signedness (i.e., zero-extension vs
+// signed-extension).
+type SyscallArgument struct {
+ // Prefer to use accessor methods instead of 'Value' directly.
+ Value uintptr
+}
+
+// SyscallArguments represents the set of arguments passed to a syscall.
+type SyscallArguments [6]SyscallArgument
+
+// Pointer returns the usermem.Addr representation of a pointer argument.
+func (a SyscallArgument) Pointer() usermem.Addr {
+ return usermem.Addr(a.Value)
+}
+
+// Int returns the int32 representation of a 32-bit signed integer argument.
+func (a SyscallArgument) Int() int32 {
+ return int32(a.Value)
+}
+
+// Uint returns the uint32 representation of a 32-bit unsigned integer argument.
+func (a SyscallArgument) Uint() uint32 {
+ return uint32(a.Value)
+}
+
+// Int64 returns the int64 representation of a 64-bit signed integer argument.
+func (a SyscallArgument) Int64() int64 {
+ return int64(a.Value)
+}
+
+// Uint64 returns the uint64 representation of a 64-bit unsigned integer argument.
+func (a SyscallArgument) Uint64() uint64 {
+ return uint64(a.Value)
+}
+
+// SizeT returns the uint representation of a size_t argument.
+func (a SyscallArgument) SizeT() uint {
+ return uint(a.Value)
+}
+
+// ModeT returns the uint representation of a mode_t argument.
+func (a SyscallArgument) ModeT() uint {
+ return uint(uint16(a.Value))
+}
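
The accessor methods above matter because a raw 64-bit register value must be
truncated and re-extended to the width and signedness of the corresponding C
type rather than reinterpreted wholesale. A standalone sketch of that behavior
(the type is re-declared here for illustration only):

    package main

    import "fmt"

    type SyscallArgument struct{ Value uintptr }

    func (a SyscallArgument) Int() int32   { return int32(a.Value) }         // truncate, signed
    func (a SyscallArgument) Uint() uint32 { return uint32(a.Value) }        // truncate, unsigned
    func (a SyscallArgument) ModeT() uint  { return uint(uint16(a.Value)) }  // low 16 bits

    func main() {
        // A register holding all ones is -1 as a C int, 2^32-1 as a C
        // unsigned int, and 0xffff as a mode_t.
        a := SyscallArgument{Value: ^uintptr(0)}
        fmt.Println(a.Int(), a.Uint(), a.ModeT()) // -1 4294967295 65535
    }
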
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
new file mode 100644
index 000000000..daba8b172
--- /dev/null
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -0,0 +1,321 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Registers represents the CPU registers for this architecture.
+type Registers = linux.PtraceRegs
+
+const (
+ // SyscallWidth is the width of instructions.
+ SyscallWidth = 4
+
+ // fpsimdMagic is the magic number which is used in fpsimd_context.
+ fpsimdMagic = 0x46508001
+
+ // fpsimdContextSize is the size of fpsimd_context.
+ fpsimdContextSize = 0x210
+)
+
+// ARMTrapFlag is the mask for the trap flag.
+const ARMTrapFlag = uint64(1) << 21
+
+// aarch64FPState is aarch64 floating point state.
+type aarch64FPState []byte
+
+// initAarch64FPState sets up initial state.
+//
+// Related code in Linux kernel: fpsimd_flush_thread().
+// FPCR = FPCR_RM_RN (0x0 << 22).
+//
+// Currently, aarch64FPState is only a 0x210-byte buffer for fpstate.
+// The fpsimd header is not used by sentry/ptrace/kvm.
+//
+func initAarch64FPState(data aarch64FPState) {
+}
+
+func newAarch64FPStateSlice() []byte {
+ return alignedBytes(4096, 16)[:fpsimdContextSize]
+}
+
+// newAarch64FPState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by the host, even if the app won't use much of it due to a
+// restricted FeatureSet.
+func newAarch64FPState() aarch64FPState {
+ f := aarch64FPState(newAarch64FPStateSlice())
+ initAarch64FPState(f)
+ return f
+}
+
+// fork creates and returns an identical copy of the aarch64 floating point state.
+func (f aarch64FPState) fork() aarch64FPState {
+ n := aarch64FPState(newAarch64FPStateSlice())
+ copy(n, f)
+ return n
+}
+
+// FloatingPointData returns the raw data pointer.
+func (f aarch64FPState) FloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&f[0])
+}
+
+// NewFloatingPointData returns a new floating point data blob.
+//
+// This is primarily for use in tests.
+func NewFloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&(newAarch64FPState()[0]))
+}
+
+// State contains the common architecture bits for aarch64 (the build tag of this
+// file ensures it's only built on aarch64).
+type State struct {
+ // The system registers.
+ Regs Registers
+
+ // Our floating point state.
+ aarch64FPState `state:"wait"`
+
+ // TLS pointer
+ TPValue uint64
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+
+ // OrigR0 stores the value of register R0.
+ OrigR0 uint64
+}
+
+// Proto returns a protobuf representation of the system registers in State.
+func (s State) Proto() *rpb.Registers {
+ regs := &rpb.ARM64Registers{
+ R0: s.Regs.Regs[0],
+ R1: s.Regs.Regs[1],
+ R2: s.Regs.Regs[2],
+ R3: s.Regs.Regs[3],
+ R4: s.Regs.Regs[4],
+ R5: s.Regs.Regs[5],
+ R6: s.Regs.Regs[6],
+ R7: s.Regs.Regs[7],
+ R8: s.Regs.Regs[8],
+ R9: s.Regs.Regs[9],
+ R10: s.Regs.Regs[10],
+ R11: s.Regs.Regs[11],
+ R12: s.Regs.Regs[12],
+ R13: s.Regs.Regs[13],
+ R14: s.Regs.Regs[14],
+ R15: s.Regs.Regs[15],
+ R16: s.Regs.Regs[16],
+ R17: s.Regs.Regs[17],
+ R18: s.Regs.Regs[18],
+ R19: s.Regs.Regs[19],
+ R20: s.Regs.Regs[20],
+ R21: s.Regs.Regs[21],
+ R22: s.Regs.Regs[22],
+ R23: s.Regs.Regs[23],
+ R24: s.Regs.Regs[24],
+ R25: s.Regs.Regs[25],
+ R26: s.Regs.Regs[26],
+ R27: s.Regs.Regs[27],
+ R28: s.Regs.Regs[28],
+ R29: s.Regs.Regs[29],
+ R30: s.Regs.Regs[30],
+ Sp: s.Regs.Sp,
+ Pc: s.Regs.Pc,
+ Pstate: s.Regs.Pstate,
+ }
+ return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
+}
+
+// Fork creates and returns an identical copy of the state.
+func (s *State) Fork() State {
+ return State{
+ Regs: s.Regs,
+ aarch64FPState: s.aarch64FPState.fork(),
+ TPValue: s.TPValue,
+ FeatureSet: s.FeatureSet,
+ OrigR0: s.OrigR0,
+ }
+}
+
+// StateData implements Context.StateData.
+func (s *State) StateData() *State {
+ return s
+}
+
+// CPUIDEmulate emulates a cpuid instruction.
+func (s *State) CPUIDEmulate(l log.Logger) {
+ // TODO(gvisor.dev/issue/1255): cpuid is not supported.
+}
+
+// SingleStep implements Context.SingleStep.
+func (s *State) SingleStep() bool {
+ return false
+}
+
+// SetSingleStep enables single stepping.
+func (s *State) SetSingleStep() {
+ // Set the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// ClearSingleStep disables single stepping.
+func (s *State) ClearSingleStep() {
+ // Clear the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// RegisterMap returns a map of all registers.
+func (s *State) RegisterMap() (map[string]uintptr, error) {
+ return map[string]uintptr{
+ "R0": uintptr(s.Regs.Regs[0]),
+ "R1": uintptr(s.Regs.Regs[1]),
+ "R2": uintptr(s.Regs.Regs[2]),
+ "R3": uintptr(s.Regs.Regs[3]),
+ "R4": uintptr(s.Regs.Regs[4]),
+ "R5": uintptr(s.Regs.Regs[5]),
+ "R6": uintptr(s.Regs.Regs[6]),
+ "R7": uintptr(s.Regs.Regs[7]),
+ "R8": uintptr(s.Regs.Regs[8]),
+ "R9": uintptr(s.Regs.Regs[9]),
+ "R10": uintptr(s.Regs.Regs[10]),
+ "R11": uintptr(s.Regs.Regs[11]),
+ "R12": uintptr(s.Regs.Regs[12]),
+ "R13": uintptr(s.Regs.Regs[13]),
+ "R14": uintptr(s.Regs.Regs[14]),
+ "R15": uintptr(s.Regs.Regs[15]),
+ "R16": uintptr(s.Regs.Regs[16]),
+ "R17": uintptr(s.Regs.Regs[17]),
+ "R18": uintptr(s.Regs.Regs[18]),
+ "R19": uintptr(s.Regs.Regs[19]),
+ "R20": uintptr(s.Regs.Regs[20]),
+ "R21": uintptr(s.Regs.Regs[21]),
+ "R22": uintptr(s.Regs.Regs[22]),
+ "R23": uintptr(s.Regs.Regs[23]),
+ "R24": uintptr(s.Regs.Regs[24]),
+ "R25": uintptr(s.Regs.Regs[25]),
+ "R26": uintptr(s.Regs.Regs[26]),
+ "R27": uintptr(s.Regs.Regs[27]),
+ "R28": uintptr(s.Regs.Regs[28]),
+ "R29": uintptr(s.Regs.Regs[29]),
+ "R30": uintptr(s.Regs.Regs[30]),
+ "Sp": uintptr(s.Regs.Sp),
+ "Pc": uintptr(s.Regs.Pc),
+ "Pstate": uintptr(s.Regs.Pstate),
+ }, nil
+}
+
+// PtraceGetRegs implements Context.PtraceGetRegs.
+func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
+ regs := s.ptraceGetRegs()
+ n, err := regs.WriteTo(dst)
+ return int(n), err
+}
+
+func (s *State) ptraceGetRegs() Registers {
+ return s.Regs
+}
+
+var registersSize = (*Registers)(nil).SizeBytes()
+
+// PtraceSetRegs implements Context.PtraceSetRegs.
+func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
+ var regs Registers
+ buf := make([]byte, registersSize)
+ if _, err := io.ReadFull(src, buf); err != nil {
+ return 0, err
+ }
+ regs.UnmarshalUnsafe(buf)
+ s.Regs = regs
+ return registersSize, nil
+}
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// Register sets defined in include/uapi/linux/elf.h.
+const (
+ _NT_PRSTATUS = 1
+ _NT_PRFPREG = 2
+ _NT_ARM_TLS = 0x401
+)
+
+// PtraceGetRegSet implements Context.PtraceGetRegSet.
+func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < registersSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetRegs(dst)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// PtraceSetRegSet implements Context.PtraceSetRegSet.
+func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < registersSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetRegs(src)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// FullRestore indicates whether a full restore is required.
+func (s *State) FullRestore() bool {
+ return false
+}
+
+// New returns a new architecture context.
+func New(arch Arch, fs *cpuid.FeatureSet) Context {
+ switch arch {
+ case ARM64:
+ return &context64{
+ State{
+ aarch64FPState: newAarch64FPState(),
+ FeatureSet: fs,
+ },
+ []aarch64FPState(nil),
+ }
+ }
+ panic(fmt.Sprintf("unknown architecture %v", arch))
+}
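
PtraceGetRegSet and PtraceSetRegSet above dispatch on an architecture-defined
regset constant and reject undersized buffers with EFAULT before serializing
anything. A minimal standalone analogue of that pattern (all names and the
two-field register struct are illustrative, not part of this patch):

    package main

    import (
        "bytes"
        "encoding/binary"
        "errors"
        "fmt"
        "io"
    )

    const _NT_PRSTATUS = 1

    var (
        errFault = errors.New("EFAULT")
        errInval = errors.New("EINVAL")
    )

    // regs stands in for the architecture's general-purpose register file.
    type regs struct{ Sp, Pc uint64 }

    func (r *regs) sizeBytes() int { return 16 }

    func getRegSet(r *regs, regset uintptr, dst io.Writer, maxlen int) (int, error) {
        switch regset {
        case _NT_PRSTATUS:
            // Mirror the maxlen guard above: fail before writing anything.
            if maxlen < r.sizeBytes() {
                return 0, errFault
            }
            if err := binary.Write(dst, binary.LittleEndian, r); err != nil {
                return 0, err
            }
            return r.sizeBytes(), nil
        default:
            return 0, errInval
        }
    }

    func main() {
        var buf bytes.Buffer
        n, err := getRegSet(&regs{Sp: 0x1000, Pc: 0x4000}, _NT_PRSTATUS, &buf, 64)
        fmt.Println(n, err, buf.Len()) // 16 <nil> 16
    }
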
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
new file mode 100644
index 000000000..3b3a0a272
--- /dev/null
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -0,0 +1,328 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package arch
+
+import (
+ "bytes"
+ "fmt"
+ "math/rand"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Host specifies the host architecture.
+const Host = AMD64
+
+// These constants come directly from Linux.
+const (
+ // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
+ // for a 64-bit process.
+ maxAddr64 usermem.Addr = (1 << 47) - usermem.PageSize
+
+ // maxStackRand64 is the maximum randomization to apply to the stack.
+ // It is defined by arch/x86/mm/mmap.c:stack_maxrandom_size in Linux.
+ maxStackRand64 = 16 << 30 // 16 GB
+
+ // maxMmapRand64 is the maximum randomization to apply to the mmap
+ // layout. It is defined by arch/x86/mm/mmap.c:arch_mmap_rnd in Linux.
+ maxMmapRand64 = (1 << 28) * usermem.PageSize
+
+ // minGap64 is the minimum gap to leave at the top of the address space
+ // for the stack. It is defined by arch/x86/mm/mmap.c:MIN_GAP in Linux.
+ minGap64 = (128 << 20) + maxStackRand64
+
+ // preferredPIELoadAddr is the standard Linux position-independent
+ // executable base load address. It is ELF_ET_DYN_BASE in Linux.
+ //
+ // The Platform {Min,Max}UserAddress() may preclude loading at this
+ // address. See other preferredFoo comments below.
+ preferredPIELoadAddr usermem.Addr = maxAddr64 / 3 * 2
+)
+
+// These constants are selected as heuristics to help make the Platform's
+// potentially limited address space conform as closely to Linux as possible.
+const (
+ // Select a preferred minimum TopDownBase address.
+ //
+ // Some applications (TSAN and other *SANs) are very particular about
+ // the way the Linux mmap allocator lays out the address space.
+ //
+ // TSAN in particular expects top down allocations to be made in the
+ // range [0x7e8000000000, 0x800000000000).
+ //
+ // The minimum TopDownBase on Linux would be:
+ // 0x800000000000 - minGap64 - maxMmapRand64 = 0x7efbf8000000.
+ //
+ // (minGap64 because TSAN uses a small RLIMIT_STACK.)
+ //
+ // 0x7e8000000000 is selected arbitrarily by TSAN to leave room for
+ // allocations below TopDownBase.
+ //
+ // N.B. ASAN and MSAN are more forgiving; ASAN allows allocations all
+ // the way down to 0x10007fff8000, and MSAN down to 0x700000000000.
+ //
+ // Of course, there is no hard minimum to allocation; an allocator can
+ // search all the way from TopDownBase to Min. However, TSAN declared
+ // their range "good enough".
+ //
+ // We would like to pick a TopDownBase such that it is unlikely that an
+ // allocator will select an address below TSAN's minimum. We achieve
+ // this by trying to leave a sizable gap below TopDownBase.
+ //
+ // This is all "preferred" because the layout min/max address may not
+ // allow us to select such a TopDownBase, in which case we have to fall
+ // back to a layout that TSAN may not be happy with.
+ preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+
+ // minMmapRand64 is the smallest we are willing to make the
+ // randomization to stay above preferredTopDownBaseMin.
+ minMmapRand64 = (1 << 26) * usermem.PageSize
+)
+
+// context64 represents an AMD64 context.
+//
+// +stateify savable
+type context64 struct {
+ State
+ sigFPState []x86FPState // fpstate to be restored on sigreturn.
+}
+
+// Arch implements Context.Arch.
+func (c *context64) Arch() Arch {
+ return AMD64
+}
+
+func (c *context64) copySigFPState() []x86FPState {
+ var sigfps []x86FPState
+ for _, s := range c.sigFPState {
+ sigfps = append(sigfps, s.fork())
+ }
+ return sigfps
+}
+
+// Fork returns an exact copy of this context.
+func (c *context64) Fork() Context {
+ return &context64{
+ State: c.State.Fork(),
+ sigFPState: c.copySigFPState(),
+ }
+}
+
+// Return returns the current syscall return value.
+func (c *context64) Return() uintptr {
+ return uintptr(c.Regs.Rax)
+}
+
+// SetReturn sets the syscall return value.
+func (c *context64) SetReturn(value uintptr) {
+ c.Regs.Rax = uint64(value)
+}
+
+// IP returns the current instruction pointer.
+func (c *context64) IP() uintptr {
+ return uintptr(c.Regs.Rip)
+}
+
+// SetIP sets the current instruction pointer.
+func (c *context64) SetIP(value uintptr) {
+ c.Regs.Rip = uint64(value)
+}
+
+// Stack returns the current stack pointer.
+func (c *context64) Stack() uintptr {
+ return uintptr(c.Regs.Rsp)
+}
+
+// SetStack sets the current stack pointer.
+func (c *context64) SetStack(value uintptr) {
+ c.Regs.Rsp = uint64(value)
+}
+
+// TLS returns the current TLS pointer.
+func (c *context64) TLS() uintptr {
+ return uintptr(c.Regs.Fs_base)
+}
+
+// SetTLS sets the current TLS pointer. Returns false if value is invalid.
+func (c *context64) SetTLS(value uintptr) bool {
+ if !isValidSegmentBase(uint64(value)) {
+ return false
+ }
+
+ c.Regs.Fs = 0
+ c.Regs.Fs_base = uint64(value)
+ return true
+}
+
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
+ c.Regs.R10 = uint64(value)
+}
+
+// Native returns the native type for the given val.
+func (c *context64) Native(val uintptr) interface{} {
+ v := uint64(val)
+ return &v
+}
+
+// Value returns the generic val for the given native type.
+func (c *context64) Value(val interface{}) uintptr {
+ return uintptr(*val.(*uint64))
+}
+
+// Width returns the byte width of this architecture.
+func (c *context64) Width() uint {
+ return 8
+}
+
+// FeatureSet returns the FeatureSet in use.
+func (c *context64) FeatureSet() *cpuid.FeatureSet {
+ return c.State.FeatureSet
+}
+
+// mmapRand returns a random adjustment for randomizing an mmap layout.
+func mmapRand(max uint64) usermem.Addr {
+ return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+}
+
+// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
+func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+ min, ok := min.RoundUp()
+ if !ok {
+ return MmapLayout{}, syscall.EINVAL
+ }
+ if max > maxAddr64 {
+ max = maxAddr64
+ }
+ max = max.RoundDown()
+
+ if min > max {
+ return MmapLayout{}, syscall.EINVAL
+ }
+
+ stackSize := r.Get(limits.Stack)
+
+ // MAX_GAP in Linux.
+ maxGap := (max / 6) * 5
+ gap := usermem.Addr(stackSize.Cur)
+ if gap < minGap64 {
+ gap = minGap64
+ }
+ if gap > maxGap {
+ gap = maxGap
+ }
+ defaultDir := MmapTopDown
+ if stackSize.Cur == limits.Infinity {
+ defaultDir = MmapBottomUp
+ }
+
+ topDownMin := max - gap - maxMmapRand64
+ maxRand := usermem.Addr(maxMmapRand64)
+ if topDownMin < preferredTopDownBaseMin {
+ // Try to keep TopDownBase above preferredTopDownBaseMin by
+ // shrinking maxRand.
+ maxAdjust := maxRand - minMmapRand64
+ needAdjust := preferredTopDownBaseMin - topDownMin
+ if needAdjust <= maxAdjust {
+ maxRand -= needAdjust
+ }
+ }
+
+ rnd := mmapRand(uint64(maxRand))
+ l := MmapLayout{
+ MinAddr: min,
+ MaxAddr: max,
+ // TASK_UNMAPPED_BASE in Linux.
+ BottomUpBase: (max/3 + rnd).RoundDown(),
+ TopDownBase: (max - gap - rnd).RoundDown(),
+ DefaultDirection: defaultDir,
+ // We may have reduced the maximum randomization to keep
+ // TopDownBase above preferredTopDownBaseMin while maintaining
+ // our stack gap. Stack allocations must use that max
+ // randomization to avoid eating into the gap.
+ MaxStackRand: uint64(maxRand),
+ }
+
+ // Final sanity check on the layout.
+ if !l.Valid() {
+ panic(fmt.Sprintf("Invalid MmapLayout: %+v", l))
+ }
+
+ return l, nil
+}
+
+// PIELoadAddress implements Context.PIELoadAddress.
+func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+ base := preferredPIELoadAddr
+ max, ok := base.AddLength(maxMmapRand64)
+ if !ok {
+ panic(fmt.Sprintf("preferredPIELoadAddr %#x too large", base))
+ }
+
+ if max > l.MaxAddr {
+ // preferredPIELoadAddr won't fit; fall back to the standard
+ // Linux behavior of 2/3 of TopDownBase. TSAN won't like this.
+ //
+ // Don't bother trying to shrink the randomization for now.
+ base = l.TopDownBase / 3 * 2
+ }
+
+ return base + mmapRand(maxMmapRand64)
+}
+
+// userStructSize is the size in bytes of Linux's struct user on amd64.
+const userStructSize = 928
+
+// PtracePeekUser implements Context.PtracePeekUser.
+func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+ if addr&7 != 0 || addr >= userStructSize {
+ return nil, syscall.EIO
+ }
+ // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
+ // u_debugreg, returning 0 or silently no-oping for other fields
+ // respectively.
+ if addr < uintptr(registersSize) {
+ regs := c.ptraceGetRegs()
+ buf := make([]byte, regs.SizeBytes())
+ regs.MarshalUnsafe(buf)
+ return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
+ }
+ // Note: x86 debug registers are missing.
+ return c.Native(0), nil
+}
+
+// PtracePokeUser implements Context.PtracePokeUser.
+func (c *context64) PtracePokeUser(addr, data uintptr) error {
+ if addr&7 != 0 || addr >= userStructSize {
+ return syscall.EIO
+ }
+ if addr < uintptr(registersSize) {
+ regs := c.ptraceGetRegs()
+ buf := make([]byte, regs.SizeBytes())
+ regs.MarshalUnsafe(buf)
+ usermem.ByteOrder.PutUint64(buf[addr:], uint64(data))
+ _, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
+ return err
+ }
+ // Note: x86 debug registers are missing.
+ return nil
+}
diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/arch_amd64.s
new file mode 100644
index 000000000..6c10336e7
--- /dev/null
+++ b/pkg/sentry/arch/arch_amd64.s
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// MXCSR_DEFAULT is the reset value of MXCSR (Intel SDM Vol. 2, Ch. 3.2
+// "LDMXCSR")
+#define MXCSR_DEFAULT 0x1f80
+
+// MXCSR_OFFSET is the offset in bytes of the MXCSR field from the start of the
+// FXSAVE/XSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE Area")
+#define MXCSR_OFFSET 24
+
+// initX86FPState initializes floating point state.
+//
+// func initX86FPState(data *FloatingPointData, useXsave bool)
+//
+// We need to clear out and initialize an empty fp state area since the sentry,
+// or any previous loader, may have left sensitive information in the floating
+// point registers.
+//
+// Preconditions: data is zeroed.
+TEXT ·initX86FPState(SB), $24-16
+ // Save MXCSR (callee-save)
+ STMXCSR mxcsr-8(SP)
+
+ // Save x87 CW (callee-save)
+ FSTCW cw-16(SP)
+
+ MOVQ fpState+0(FP), DI
+
+ // Do we use xsave?
+ MOVBQZX useXsave+8(FP), AX
+ TESTQ AX, AX
+ JZ no_xsave
+
+ // Use XRSTOR to clear all FP state to an initial state.
+ //
+ // The fpState XSAVE area is zeroed on function entry, meaning
+ // XSTATE_BV is zero.
+ //
+ // "If RFBM[i] = 1 and bit i is clear in the XSTATE_BV field in the
+ // XSAVE header, XRSTOR initializes state component i."
+ //
+ // Initialization is defined in SDM Vol 1, Chapter 13.3. It puts all
+ // the registers in a reasonable initial state, except MXCSR:
+ //
+ // "The MXCSR register is part of state component 1, SSE state (see
+ // Section 13.5.2). However, the standard form of XRSTOR loads the
+ // MXCSR register from memory whenever the RFBM[1] (SSE) or RFBM[2]
+ // (AVX) is set, regardless of the values of XSTATE_BV[1] and
+ // XSTATE_BV[2]."
+
+ // Set MXCSR to the default value.
+ MOVL $MXCSR_DEFAULT, MXCSR_OFFSET(DI)
+
+ // Initialize registers with XRSTOR.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI)
+
+ // Now that all the state has been reset, write it back out to the
+ // XSAVE area.
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27 // XSAVE64 0(DI)
+
+ JMP out
+
+no_xsave:
+ // Clear out existing X values.
+ PXOR X0, X0
+ MOVO X0, X1
+ MOVO X0, X2
+ MOVO X0, X3
+ MOVO X0, X4
+ MOVO X0, X5
+ MOVO X0, X6
+ MOVO X0, X7
+ MOVO X0, X8
+ MOVO X0, X9
+ MOVO X0, X10
+ MOVO X0, X11
+ MOVO X0, X12
+ MOVO X0, X13
+ MOVO X0, X14
+ MOVO X0, X15
+
+ // Zero out %rax and store into MMX registers. MMX registers are
+ // an alias of 8x64 bits of the 8x80 bits used for the original
+ // x87 registers. Storing zero into them will reset the FPU registers
+ // to bits [63:0] = 0, [79:64] = all ones. But the contents aren't too
+ // important, just the fact that we have reset them to a known value.
+ XORQ AX, AX
+ MOVQ AX, M0
+ MOVQ AX, M1
+ MOVQ AX, M2
+ MOVQ AX, M3
+ MOVQ AX, M4
+ MOVQ AX, M5
+ MOVQ AX, M6
+ MOVQ AX, M7
+
+ // The Go assembler doesn't support FNINIT, so we use BYTE.
+ // This will:
+ // - Reset FPU control word to 0x037f
+ // - Clear FPU status word
+ // - Reset FPU tag word to 0xffff
+ // - Clear FPU data pointer
+ // - Clear FPU instruction pointer
+ BYTE $0xDB; BYTE $0xE3; // FNINIT
+
+ // Reset MXCSR.
+ MOVL $MXCSR_DEFAULT, tmpmxcsr-24(SP)
+ LDMXCSR tmpmxcsr-24(SP)
+
+ // Save the floating point state with fxsave.
+ FXSAVE64 0(DI)
+
+out:
+ // Restore MXCSR.
+ LDMXCSR mxcsr-8(SP)
+
+ // Restore x87 CW.
+ FLDCW cw-16(SP)
+
+ RET
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
new file mode 100644
index 000000000..ada7ac7b8
--- /dev/null
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -0,0 +1,284 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "fmt"
+ "math/rand"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Host specifies the host architecture.
+const Host = ARM64
+
+// These constants come directly from Linux.
+const (
+ // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
+ // for a 64-bit process.
+ maxAddr64 usermem.Addr = (1 << 48)
+
+ // maxStackRand64 is the maximum randomization to apply to the stack.
+ // It is defined by arch/arm64/mm/mmap.c:(STACK_RND_MASK << PAGE_SHIFT) in Linux.
+ maxStackRand64 = 0x3ffff << 12 // ~1 GB with 4K pages.
+
+ // maxMmapRand64 is the maximum randomization to apply to the mmap
+ // layout. It is defined by arch/arm64/mm/mmap.c:arch_mmap_rnd in Linux.
+ maxMmapRand64 = (1 << 33) * usermem.PageSize
+
+ // minGap64 is the minimum gap to leave at the top of the address space
+ // for the stack. It is defined by arch/arm64/mm/mmap.c:MIN_GAP in Linux.
+ minGap64 = (128 << 20) + maxStackRand64
+
+ // preferredPIELoadAddr is the standard Linux position-independent
+ // executable base load address. It is ELF_ET_DYN_BASE in Linux.
+ //
+ // The Platform {Min,Max}UserAddress() may preclude loading at this
+ // address. See other preferredFoo comments below.
+ preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
+)
+
+var (
+ // CPUIDInstruction doesn't exist on ARM64.
+ CPUIDInstruction = []byte{}
+)
+
+// These constants are selected as heuristics to help make the Platform's
+// potentially limited address space conform as closely to Linux as possible.
+const (
+ preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+
+ // minMmapRand64 is the smallest we are willing to make the
+ // randomization to stay above preferredTopDownBaseMin.
+ minMmapRand64 = (1 << 18) * usermem.PageSize
+)
+
+// context64 represents an ARM64 context.
+type context64 struct {
+ State
+ sigFPState []aarch64FPState // fpstate to be restored on sigreturn.
+}
+
+// Arch implements Context.Arch.
+func (c *context64) Arch() Arch {
+ return ARM64
+}
+
+func (c *context64) copySigFPState() []aarch64FPState {
+ var sigfps []aarch64FPState
+ for _, s := range c.sigFPState {
+ sigfps = append(sigfps, s.fork())
+ }
+ return sigfps
+}
+
+// Fork returns an exact copy of this context.
+func (c *context64) Fork() Context {
+ return &context64{
+ State: c.State.Fork(),
+ sigFPState: c.copySigFPState(),
+ }
+}
+
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary registers.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+
+// Return returns the current syscall return value.
+func (c *context64) Return() uintptr {
+ return uintptr(c.Regs.Regs[0])
+}
+
+// SetReturn sets the syscall return value.
+func (c *context64) SetReturn(value uintptr) {
+ c.Regs.Regs[0] = uint64(value)
+}
+
+// IP returns the current instruction pointer.
+func (c *context64) IP() uintptr {
+ return uintptr(c.Regs.Pc)
+}
+
+// SetIP sets the current instruction pointer.
+func (c *context64) SetIP(value uintptr) {
+ c.Regs.Pc = uint64(value)
+}
+
+// Stack returns the current stack pointer.
+func (c *context64) Stack() uintptr {
+ return uintptr(c.Regs.Sp)
+}
+
+// SetStack sets the current stack pointer.
+func (c *context64) SetStack(value uintptr) {
+ c.Regs.Sp = uint64(value)
+}
+
+// TLS returns the current TLS pointer.
+func (c *context64) TLS() uintptr {
+ return uintptr(c.TPValue)
+}
+
+// SetTLS sets the current TLS pointer. Returns false if value is invalid.
+func (c *context64) SetTLS(value uintptr) bool {
+ if value >= uintptr(maxAddr64) {
+ return false
+ }
+
+ c.TPValue = uint64(value)
+ return true
+}
+
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
+ c.Regs.Regs[3] = uint64(value)
+}
+
+// Native returns the native type for the given val.
+func (c *context64) Native(val uintptr) interface{} {
+ v := uint64(val)
+ return &v
+}
+
+// Value returns the generic val for the given native type.
+func (c *context64) Value(val interface{}) uintptr {
+ return uintptr(*val.(*uint64))
+}
+
+// Width returns the byte width of this architecture.
+func (c *context64) Width() uint {
+ return 8
+}
+
+// FeatureSet returns the FeatureSet in use.
+func (c *context64) FeatureSet() *cpuid.FeatureSet {
+ return c.State.FeatureSet
+}
+
+// mmapRand returns a random adjustment for randomizing an mmap layout.
+func mmapRand(max uint64) usermem.Addr {
+ return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+}
+
+// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
+func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+ min, ok := min.RoundUp()
+ if !ok {
+ return MmapLayout{}, syscall.EINVAL
+ }
+ if max > maxAddr64 {
+ max = maxAddr64
+ }
+ max = max.RoundDown()
+
+ if min > max {
+ return MmapLayout{}, syscall.EINVAL
+ }
+
+ stackSize := r.Get(limits.Stack)
+
+ // MAX_GAP in Linux.
+ maxGap := (max / 6) * 5
+ gap := usermem.Addr(stackSize.Cur)
+ if gap < minGap64 {
+ gap = minGap64
+ }
+ if gap > maxGap {
+ gap = maxGap
+ }
+ defaultDir := MmapTopDown
+ if stackSize.Cur == limits.Infinity {
+ defaultDir = MmapBottomUp
+ }
+
+ topDownMin := max - gap - maxMmapRand64
+ maxRand := usermem.Addr(maxMmapRand64)
+ if topDownMin < preferredTopDownBaseMin {
+ // Try to keep TopDownBase above preferredTopDownBaseMin by
+ // shrinking maxRand.
+ maxAdjust := maxRand - minMmapRand64
+ needAdjust := preferredTopDownBaseMin - topDownMin
+ if needAdjust <= maxAdjust {
+ maxRand -= needAdjust
+ }
+ }
+
+ rnd := mmapRand(uint64(maxRand))
+ l := MmapLayout{
+ MinAddr: min,
+ MaxAddr: max,
+ // TASK_UNMAPPED_BASE in Linux.
+ BottomUpBase: (max/3 + rnd).RoundDown(),
+ TopDownBase: (max - gap - rnd).RoundDown(),
+ DefaultDirection: defaultDir,
+ // We may have reduced the maximum randomization to keep
+ // TopDownBase above preferredTopDownBaseMin while maintaining
+ // our stack gap. Stack allocations must use that max
+ // randomization to avoid eating into the gap.
+ MaxStackRand: uint64(maxRand),
+ }
+
+ // Final sanity check on the layout.
+ if !l.Valid() {
+ panic(fmt.Sprintf("Invalid MmapLayout: %+v", l))
+ }
+
+ return l, nil
+}
+
+// PIELoadAddress implements Context.PIELoadAddress.
+func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+ base := preferredPIELoadAddr
+ max, ok := base.AddLength(maxMmapRand64)
+ if !ok {
+ panic(fmt.Sprintf("preferredPIELoadAddr %#x too large", base))
+ }
+
+ if max > l.MaxAddr {
+ // preferredPIELoadAddr won't fit; fall back to the standard
+ // Linux behavior of 2/3 of TopDownBase. TSAN won't like this.
+ //
+ // Don't bother trying to shrink the randomization for now.
+ base = l.TopDownBase / 3 * 2
+ }
+
+ return base + mmapRand(maxMmapRand64)
+}
+
+// PtracePeekUser implements Context.PtracePeekUser.
+func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+ // TODO(gvisor.dev/issue/1239): Full ptrace support for Arm64.
+ return c.Native(0), nil
+}
+
+// PtracePokeUser implements Context.PtracePokeUser.
+func (c *context64) PtracePokeUser(addr, data uintptr) error {
+ // TODO(gvisor.dev/issue/1239): Full ptrace support for Arm64.
+ return nil
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
new file mode 100644
index 000000000..19ce99d25
--- /dev/null
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -0,0 +1,91 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 386
+
+package arch
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ErrFloatingPoint indicates a failed restore due to unusable floating point
+// state.
+type ErrFloatingPoint struct {
+ // supported is the supported floating point state.
+ supported uint64
+
+ // saved is the saved floating point state.
+ saved uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrFloatingPoint) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
+}
+
+// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
+// and SSE state, so this is the equivalent XSTATE_BV value.
+const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
+
+// afterLoadFPState is invoked by afterLoad.
+func (s *State) afterLoadFPState() {
+ old := s.x86FPState
+
+ // Recreate the slice. This is done to ensure that it is aligned
+ // appropriately in memory, and large enough to accommodate any new
+ // state that may be saved by the new CPU. Even if extraneous new state
+ // is saved, the state we care about is guaranteed to be a subset of
+ // new state. Later optimizations can use less space when using a
+ // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has
+ // more info.
+ s.x86FPState = newX86FPState()
+
+ // x86FPState always contains all the FP state supported by the host.
+ // We may have come from a newer machine that supports additional state
+ // which we cannot restore.
+ //
+ // The x86 FP state areas are backwards compatible, so we can simply
+ // truncate the additional floating point state.
+ //
+ // Applications should not depend on the truncated state because it
+ // should relate only to features that were not exposed in the app
+ // FeatureSet. However, because we do not *prevent* them from using
+ // this state, we must verify here that there is no in-use state
+ // (according to XSTATE_BV) which we do not support.
+ if len(s.x86FPState) < len(old) {
+ // What do we support?
+ supportedBV := fxsaveBV
+ if fs := cpuid.HostFeatureSet(); fs.UseXsave() {
+ supportedBV = fs.ValidXCR0Mask()
+ }
+
+ // What was in use?
+ savedBV := fxsaveBV
+ if len(old) >= xstateBVOffset+8 {
+ savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
+ }
+
+ // Supported features must be a superset of saved features.
+ if savedBV&^supportedBV != 0 {
+ panic(ErrFloatingPoint{supported: supportedBV, saved: savedBV})
+ }
+ }
+
+ // Copy to the new, aligned location.
+ copy(s.x86FPState, old)
+}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
new file mode 100644
index 000000000..dc458b37f
--- /dev/null
+++ b/pkg/sentry/arch/arch_x86.go
@@ -0,0 +1,615 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 386
+
+package arch
+
+import (
+ "fmt"
+ "io"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Registers represents the CPU registers for this architecture.
+type Registers = linux.PtraceRegs
+
+// System-related constants for x86.
+const (
+ // SyscallWidth is the width of syscall, sysenter, and int 80 instructions.
+ SyscallWidth = 2
+)
+
+// EFLAGS register bits.
+const (
+ // eflagsCF is the mask for the carry flag.
+ eflagsCF = uint64(1) << 0
+ // eflagsPF is the mask for the parity flag.
+ eflagsPF = uint64(1) << 2
+ // eflagsAF is the mask for the auxiliary carry flag.
+ eflagsAF = uint64(1) << 4
+ // eflagsZF is the mask for the zero flag.
+ eflagsZF = uint64(1) << 6
+ // eflagsSF is the mask for the sign flag.
+ eflagsSF = uint64(1) << 7
+ // eflagsTF is the mask for the trap flag.
+ eflagsTF = uint64(1) << 8
+ // eflagsIF is the mask for the interrupt flag.
+ eflagsIF = uint64(1) << 9
+ // eflagsDF is the mask for the direction flag.
+ eflagsDF = uint64(1) << 10
+ // eflagsOF is the mask for the overflow flag.
+ eflagsOF = uint64(1) << 11
+ // eflagsIOPL is the mask for the I/O privilege level.
+ eflagsIOPL = uint64(3) << 12
+ // eflagsNT is the mask for the nested task bit.
+ eflagsNT = uint64(1) << 14
+ // eflagsRF is the mask for the resume flag.
+ eflagsRF = uint64(1) << 16
+ // eflagsVM is the mask for the virtual mode bit.
+ eflagsVM = uint64(1) << 17
+ // eflagsAC is the mask for the alignment check / access control bit.
+ eflagsAC = uint64(1) << 18
+ // eflagsVIF is the mask for the virtual interrupt flag.
+ eflagsVIF = uint64(1) << 19
+ // eflagsVIP is the mask for the virtual interrupt pending bit.
+ eflagsVIP = uint64(1) << 20
+ // eflagsID is the mask for the CPUID detection bit.
+ eflagsID = uint64(1) << 21
+
+ // eflagsPtraceMutable is the mask for the set of EFLAGS that may be
+ // changed by ptrace(PTRACE_SETREGS). eflagsPtraceMutable is analogous to
+ // Linux's FLAG_MASK.
+ eflagsPtraceMutable = eflagsCF | eflagsPF | eflagsAF | eflagsZF | eflagsSF | eflagsTF | eflagsDF | eflagsOF | eflagsRF | eflagsAC | eflagsNT
+
+ // eflagsRestorable is the mask for the set of EFLAGS that may be changed by
+ // SignalReturn. eflagsRestorable is analogous to Linux's FIX_EFLAGS.
+ eflagsRestorable = eflagsAC | eflagsOF | eflagsDF | eflagsTF | eflagsSF | eflagsZF | eflagsAF | eflagsPF | eflagsCF | eflagsRF
+)
+
+// Segment selectors. See arch/x86/include/asm/segment.h.
+const (
+ userCS = 0x33 // guest ring 3 code selector
+ user32CS = 0x23 // guest ring 3 32 bit code selector
+ userDS = 0x2b // guest ring 3 data selector
+
+ _FS_TLS_SEL = 0x63 // Linux FS thread-local storage selector
+ _GS_TLS_SEL = 0x6b // Linux GS thread-local storage selector
+)
+
+var (
+ // TrapInstruction is the x86 trap instruction.
+ TrapInstruction = [1]byte{0xcc}
+
+ // CPUIDInstruction is the x86 CPUID instruction.
+ CPUIDInstruction = [2]byte{0xf, 0xa2}
+
+ // X86TrapFlag is an exported const for use by other packages.
+ X86TrapFlag uint64 = (1 << 8)
+)
+
+// x86FPState is x86 floating point state.
+type x86FPState []byte
+
+// initX86FPState (defined in asm files) sets up initial state.
+func initX86FPState(data *FloatingPointData, useXsave bool)
+
+func newX86FPStateSlice() []byte {
+ size, align := cpuid.HostFeatureSet().ExtendedStateSize()
+ capacity := size
+ // Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size; otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
+ if capacity < 4096 {
+ capacity = 4096
+ }
+ return alignedBytes(capacity, align)[:size]
+}
+
+// newX86FPState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by the host, even if the app won't use much of it due to a
+// restricted FeatureSet. Since the app may still be able to see state not
+// advertised by CPUID, we must ensure it does not contain any sentry state.
+func newX86FPState() x86FPState {
+ f := x86FPState(newX86FPStateSlice())
+ initX86FPState(f.FloatingPointData(), cpuid.HostFeatureSet().UseXsave())
+ return f
+}
+
+// fork creates and returns an identical copy of the x86 floating point state.
+func (f x86FPState) fork() x86FPState {
+ n := x86FPState(newX86FPStateSlice())
+ copy(n, f)
+ return n
+}
+
+// FloatingPointData returns the raw data pointer.
+func (f x86FPState) FloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&f[0])
+}
+
+// NewFloatingPointData returns a new floating point data blob.
+//
+// This is primarily for use in tests.
+func NewFloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&(newX86FPState()[0]))
+}
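+
+// A minimal usage sketch (illustrative; the test name is hypothetical):
+//
+//    func TestConsumesFPData(t *testing.T) {
+//        data := NewFloatingPointData()
+//        // Hand data to code under test that expects a pointer to an
+//        // FXSAVE/XSAVE area.
+//        _ = data
+//    }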
+
+// Proto returns a protobuf representation of the system registers in State.
+func (s State) Proto() *rpb.Registers {
+ regs := &rpb.AMD64Registers{
+ Rax: s.Regs.Rax,
+ Rbx: s.Regs.Rbx,
+ Rcx: s.Regs.Rcx,
+ Rdx: s.Regs.Rdx,
+ Rsi: s.Regs.Rsi,
+ Rdi: s.Regs.Rdi,
+ Rsp: s.Regs.Rsp,
+ Rbp: s.Regs.Rbp,
+ R8: s.Regs.R8,
+ R9: s.Regs.R9,
+ R10: s.Regs.R10,
+ R11: s.Regs.R11,
+ R12: s.Regs.R12,
+ R13: s.Regs.R13,
+ R14: s.Regs.R14,
+ R15: s.Regs.R15,
+ Rip: s.Regs.Rip,
+ Rflags: s.Regs.Eflags,
+ OrigRax: s.Regs.Orig_rax,
+ Cs: s.Regs.Cs,
+ Ds: s.Regs.Ds,
+ Es: s.Regs.Es,
+ Fs: s.Regs.Fs,
+ Gs: s.Regs.Gs,
+ Ss: s.Regs.Ss,
+ FsBase: s.Regs.Fs_base,
+ GsBase: s.Regs.Gs_base,
+ }
+ return &rpb.Registers{Arch: &rpb.Registers_Amd64{Amd64: regs}}
+}
+
+// Fork creates and returns an identical copy of the state.
+func (s *State) Fork() State {
+ return State{
+ Regs: s.Regs,
+ x86FPState: s.x86FPState.fork(),
+ FeatureSet: s.FeatureSet,
+ }
+}
+
+// StateData implements Context.StateData.
+func (s *State) StateData() *State {
+ return s
+}
+
+// CPUIDEmulate emulates a cpuid instruction.
+func (s *State) CPUIDEmulate(l log.Logger) {
+ argax := uint32(s.Regs.Rax)
+ argcx := uint32(s.Regs.Rcx)
+ ax, bx, cx, dx := s.FeatureSet.EmulateID(argax, argcx)
+ s.Regs.Rax = uint64(ax)
+ s.Regs.Rbx = uint64(bx)
+ s.Regs.Rcx = uint64(cx)
+ s.Regs.Rdx = uint64(dx)
+ l.Debugf("CPUID(%x,%x): %x %x %x %x", argax, argcx, ax, bx, cx, dx)
+}
+
+// SingleStep implements Context.SingleStep.
+func (s *State) SingleStep() bool {
+ return s.Regs.Eflags&X86TrapFlag != 0
+}
+
+// SetSingleStep enables single stepping.
+func (s *State) SetSingleStep() {
+ // Set the trap flag.
+ s.Regs.Eflags |= X86TrapFlag
+}
+
+// ClearSingleStep disables single stepping.
+func (s *State) ClearSingleStep() {
+ // Clear the trap flag.
+ s.Regs.Eflags &= ^X86TrapFlag
+}
+
+// RegisterMap returns a map of all registers.
+func (s *State) RegisterMap() (map[string]uintptr, error) {
+ return map[string]uintptr{
+ "R15": uintptr(s.Regs.R15),
+ "R14": uintptr(s.Regs.R14),
+ "R13": uintptr(s.Regs.R13),
+ "R12": uintptr(s.Regs.R12),
+ "Rbp": uintptr(s.Regs.Rbp),
+ "Rbx": uintptr(s.Regs.Rbx),
+ "R11": uintptr(s.Regs.R11),
+ "R10": uintptr(s.Regs.R10),
+ "R9": uintptr(s.Regs.R9),
+ "R8": uintptr(s.Regs.R8),
+ "Rax": uintptr(s.Regs.Rax),
+ "Rcx": uintptr(s.Regs.Rcx),
+ "Rdx": uintptr(s.Regs.Rdx),
+ "Rsi": uintptr(s.Regs.Rsi),
+ "Rdi": uintptr(s.Regs.Rdi),
+ "Orig_rax": uintptr(s.Regs.Orig_rax),
+ "Rip": uintptr(s.Regs.Rip),
+ "Cs": uintptr(s.Regs.Cs),
+ "Eflags": uintptr(s.Regs.Eflags),
+ "Rsp": uintptr(s.Regs.Rsp),
+ "Ss": uintptr(s.Regs.Ss),
+ "Fs_base": uintptr(s.Regs.Fs_base),
+ "Gs_base": uintptr(s.Regs.Gs_base),
+ "Ds": uintptr(s.Regs.Ds),
+ "Es": uintptr(s.Regs.Es),
+ "Fs": uintptr(s.Regs.Fs),
+ "Gs": uintptr(s.Regs.Gs),
+ }, nil
+}
+
+// PtraceGetRegs implements Context.PtraceGetRegs.
+func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
+ regs := s.ptraceGetRegs()
+ n, err := regs.WriteTo(dst)
+ return int(n), err
+}
+
+func (s *State) ptraceGetRegs() Registers {
+ regs := s.Regs
+ // These may not be initialized.
+ if regs.Cs == 0 || regs.Ss == 0 || regs.Eflags == 0 {
+ regs.Eflags = eflagsIF
+ regs.Cs = userCS
+ regs.Ss = userDS
+ }
+ // As an optimization, Linux <4.7 implements 32-bit fs_base/gs_base
+ // addresses using reserved descriptors in the GDT instead of the MSRs,
+ // with selector values FS_TLS_SEL and GS_TLS_SEL respectively. These
+ // values are actually visible in struct user_regs_struct::fs/gs;
+ // arch/x86/kernel/ptrace.c:getreg() doesn't attempt to sanitize struct
+ // thread_struct::fsindex/gsindex.
+ //
+ // We always use fs == gs == 0 when fs_base/gs_base is in use, for
+ // simplicity.
+ //
+ // Luckily, Linux <4.7 silently ignores setting fs/gs to 0 via
+ // arch/x86/kernel/ptrace.c:set_segment_reg() when fs_base/gs_base is a
+ // 32-bit value and fsindex/gsindex indicates that this optimization is
+ // in use, as well as the reverse case of setting fs/gs to
+ // FS/GS_TLS_SEL when fs_base/gs_base is a 64-bit value. (We do the
+ // same in PtraceSetRegs.)
+ //
+ // TODO(gvisor.dev/issue/168): Remove this fixup since newer Linux
+ // doesn't have this behavior anymore.
+ if regs.Fs == 0 && regs.Fs_base <= 0xffffffff {
+ regs.Fs = _FS_TLS_SEL
+ }
+ if regs.Gs == 0 && regs.Gs_base <= 0xffffffff {
+ regs.Gs = _GS_TLS_SEL
+ }
+ return regs
+}
+
+var registersSize = (*Registers)(nil).SizeBytes()
+
+// PtraceSetRegs implements Context.PtraceSetRegs.
+func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
+ var regs Registers
+ buf := make([]byte, registersSize)
+ if _, err := io.ReadFull(src, buf); err != nil {
+ return 0, err
+ }
+ regs.UnmarshalUnsafe(buf)
+ // Truncate segment registers to 16 bits.
+ regs.Cs = uint64(uint16(regs.Cs))
+ regs.Ds = uint64(uint16(regs.Ds))
+ regs.Es = uint64(uint16(regs.Es))
+ regs.Fs = uint64(uint16(regs.Fs))
+ regs.Gs = uint64(uint16(regs.Gs))
+ regs.Ss = uint64(uint16(regs.Ss))
+ // In Linux this validation is via arch/x86/kernel/ptrace.c:putreg().
+ if !isUserSegmentSelector(regs.Cs) {
+ return 0, syscall.EIO
+ }
+ if regs.Ds != 0 && !isUserSegmentSelector(regs.Ds) {
+ return 0, syscall.EIO
+ }
+ if regs.Es != 0 && !isUserSegmentSelector(regs.Es) {
+ return 0, syscall.EIO
+ }
+ if regs.Fs != 0 && !isUserSegmentSelector(regs.Fs) {
+ return 0, syscall.EIO
+ }
+ if regs.Gs != 0 && !isUserSegmentSelector(regs.Gs) {
+ return 0, syscall.EIO
+ }
+ if !isUserSegmentSelector(regs.Ss) {
+ return 0, syscall.EIO
+ }
+ if !isValidSegmentBase(regs.Fs_base) {
+ return 0, syscall.EIO
+ }
+ if !isValidSegmentBase(regs.Gs_base) {
+ return 0, syscall.EIO
+ }
+ // CS and SS are validated, but changes to them are otherwise silently
+ // ignored on amd64.
+ regs.Cs = s.Regs.Cs
+ regs.Ss = s.Regs.Ss
+ // fs_base/gs_base changes reset fs/gs via do_arch_prctl() on Linux.
+ if regs.Fs_base != s.Regs.Fs_base {
+ regs.Fs = 0
+ }
+ if regs.Gs_base != s.Regs.Gs_base {
+ regs.Gs = 0
+ }
+ // Ignore "stale" TLS segment selectors for FS and GS. See comment in
+ // ptraceGetRegs.
+ if regs.Fs == _FS_TLS_SEL && regs.Fs_base != 0 {
+ regs.Fs = 0
+ }
+ if regs.Gs == _GS_TLS_SEL && regs.Gs_base != 0 {
+ regs.Gs = 0
+ }
+ regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
+ s.Regs = regs
+ return registersSize, nil
+}
+
+// isUserSegmentSelector returns true if the given segment selector specifies a
+// privilege level of 3 (USER_RPL).
+func isUserSegmentSelector(reg uint64) bool {
+ return reg&3 == 3
+}
+
+// isValidSegmentBase returns true if the given segment base specifies a
+// canonical user address.
+func isValidSegmentBase(reg uint64) bool {
+ return reg < uint64(maxAddr64)
+}
+
+// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type
+// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently,
+// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area.
+const ptraceFPRegsSize = 512
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
+ return dst.Write(s.x86FPState[:ptraceFPRegsSize])
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
+ var f [ptraceFPRegsSize]byte
+ n, err := io.ReadFull(src, f[:])
+ if err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(x86FPState(f[:]))
+ // N.B. this only copies the beginning of the FP state, which
+ // corresponds to the FXSAVE area.
+ copy(s.x86FPState, f[:])
+ return n, nil
+}
+
+const (
+ // mxcsrOffset is the offset in bytes of the MXCSR field from the start of
+ // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE
+ // Area")
+ mxcsrOffset = 24
+
+ // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the
+ // start of the FXSAVE area.
+ mxcsrMaskOffset = 28
+)
+
+var (
+ mxcsrMask uint32
+ initMXCSRMask sync.Once
+)
+
+// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR
+// generates a general-protection fault (#GP) in response to an attempt to set
+// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
+// 10.5.1.2 "SSE State")
+func sanitizeMXCSR(f x86FPState) {
+ mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
+ initMXCSRMask.Do(func() {
+ temp := x86FPState(alignedBytes(uint(ptraceFPRegsSize), 16))
+ initX86FPState(temp.FloatingPointData(), false /* useXsave */)
+ mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
+ if mxcsrMask == 0 {
+ // "If the value of the MXCSR_MASK field is 00000000H, then the
+ // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
+ // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR
+ // Register"
+ mxcsrMask = 0xffbf
+ }
+ })
+ mxcsr &= mxcsrMask
+ usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
+}
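+
+// Illustrative effect, assuming the fallback mask of 0xffbf (the
+// host-derived MXCSR_MASK may differ): an incoming MXCSR of 0xffffffff is
+// coerced so that a later FXRSTOR/XRSTOR cannot #GP on reserved bits:
+//
+//    usermem.ByteOrder.PutUint32(f[mxcsrOffset:], 0xffffffff)
+//    sanitizeMXCSR(f)
+//    // usermem.ByteOrder.Uint32(f[mxcsrOffset:]) == 0xffbf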
+
+const (
+ // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal
+ // to the size of the XSAVE legacy area (512 bytes) plus the size of the
+ // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's
+ // X86_XSTATE_SSE_SIZE.
+ minXstateBytes = 512 + 64
+
+ // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD
+ // field in Linux's struct user_xstateregs, which is the type manipulated
+ // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently,
+ // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET.
+ userXstateXCR0Offset = 464
+
+ // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86
+ // XSAVE area.
+ xstateBVOffset = 512
+
+ // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the
+ // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is
+ // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE
+ // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header".
+ // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP
+ // exceptions resulting from invalid values; we aren't. Linux also never
+ // uses the compacted format when doing XSAVE and doesn't even define the
+ // compaction extensions to XSAVE as a CPU feature, so for simplicity we
+ // assume no one is using them.
+ xsaveHeaderZeroedOffset = 512 + 8
+ xsaveHeaderZeroedBytes = 64 - 8
+)
+
+func (s *State) ptraceGetXstateRegs(dst io.Writer, maxlen int) (int, error) {
+ // N.B. s.x86FPState may contain more state than the application
+ // expects. We only copy the subset that would be in their XSAVE area.
+ ess, _ := s.FeatureSet.ExtendedStateSize()
+ f := make([]byte, ess)
+ copy(f, s.x86FPState)
+ // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are
+ // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE
+ // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
+ // mask. GDB relies on this: see
+ // gdb/x86-linux-nat.c:x86_linux_read_description().
+ usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], s.FeatureSet.ValidXCR0Mask())
+ if len(f) > maxlen {
+ f = f[:maxlen]
+ }
+ return dst.Write(f)
+}
+
+func (s *State) ptraceSetXstateRegs(src io.Reader, maxlen int) (int, error) {
+ // Allow users to pass an xstate register set smaller than ours (they can
+ // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes.
+ // Also allow users to pass a register set larger than ours; anything after
+ // their ExtendedStateSize will be ignored. (I think Linux technically
+ // permits setting a register set smaller than minXstateBytes, but it has
+ // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().)
+ if maxlen < minXstateBytes {
+ return 0, syscall.EFAULT
+ }
+ ess, _ := s.FeatureSet.ExtendedStateSize()
+ if maxlen > int(ess) {
+ maxlen = int(ess)
+ }
+ f := make([]byte, maxlen)
+ if _, err := io.ReadFull(src, f); err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(x86FPState(f))
+ // Users can't enable *more* XCR0 bits than what we, and the CPU, support.
+ xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
+ xstateBV &= s.FeatureSet.ValidXCR0Mask()
+ usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
+ // Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
+ reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
+ for i := range reserved {
+ reserved[i] = 0
+ }
+ return copy(s.x86FPState, f), nil
+}
+
+// Register sets defined in include/uapi/linux/elf.h.
+const (
+ _NT_PRSTATUS = 1
+ _NT_PRFPREG = 2
+ _NT_X86_XSTATE = 0x202
+)
+
+// PtraceGetRegSet implements Context.PtraceGetRegSet.
+func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < registersSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetRegs(dst)
+ case _NT_PRFPREG:
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetFPRegs(dst)
+ case _NT_X86_XSTATE:
+ return s.ptraceGetXstateRegs(dst, maxlen)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// PtraceSetRegSet implements Context.PtraceSetRegSet.
+func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < registersSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetRegs(src)
+ case _NT_PRFPREG:
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetFPRegs(src)
+ case _NT_X86_XSTATE:
+ return s.ptraceSetXstateRegs(src, maxlen)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// FullRestore indicates whether a full restore is required.
+func (s *State) FullRestore() bool {
+ // A fast system call return is possible only if
+ //
+ // * RCX matches the instruction pointer.
+ // * R11 matches our flags value.
+ // * Usermode does not expect to set either the resume flag or the
+ // virtual mode flags (unlikely).
+ // * CS and SS are set to the standard selectors.
+ //
+ // That is, SYSRET results in the correct final state.
+ fastRestore := s.Regs.Rcx == s.Regs.Rip &&
+ s.Regs.Eflags == s.Regs.R11 &&
+ (s.Regs.Eflags&eflagsRF == 0) &&
+ (s.Regs.Eflags&eflagsVM == 0) &&
+ s.Regs.Cs == userCS &&
+ s.Regs.Ss == userDS
+ return !fastRestore
+}
+
+// New returns a new architecture context.
+func New(arch Arch, fs *cpuid.FeatureSet) Context {
+ switch arch {
+ case AMD64:
+ return &context64{
+ State{
+ x86FPState: newX86FPState(),
+ FeatureSet: fs,
+ },
+ []x86FPState(nil),
+ }
+ }
+ panic(fmt.Sprintf("unknown architecture %v", arch))
+}
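+
+// Usage sketch (illustrative): building an amd64 context from the host
+// feature set and querying a couple of its properties:
+//
+//    ctx := New(AMD64, cpuid.HostFeatureSet())
+//    _ = ctx.Width() // 8 bytes on amd64.
+//    _ = ctx.Arch()  // AMD64.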
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
new file mode 100644
index 000000000..0c73fcbfb
--- /dev/null
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -0,0 +1,41 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 386
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/cpuid"
+)
+
+// State contains the common architecture bits for X86 (the build tag of this
+// file ensures it's only built on x86).
+//
+// +stateify savable
+type State struct {
+ // The system registers.
+ Regs Registers
+
+ // Our floating point state.
+ x86FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+}
+
+// afterLoad is invoked by stateify.
+func (s *State) afterLoad() {
+ s.afterLoadFPState()
+}
diff --git a/pkg/sentry/arch/auxv.go b/pkg/sentry/arch/auxv.go
new file mode 100644
index 000000000..2b4c8f3fc
--- /dev/null
+++ b/pkg/sentry/arch/auxv.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// An AuxEntry represents an entry in an ELF auxiliary vector.
+//
+// +stateify savable
+type AuxEntry struct {
+ Key uint64
+ Value usermem.Addr
+}
+
+// An Auxv represents an ELF auxiliary vector.
+type Auxv []AuxEntry
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
new file mode 100644
index 000000000..60c027aab
--- /dev/null
+++ b/pkg/sentry/arch/registers.proto
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+message AMD64Registers {
+ uint64 rax = 1;
+ uint64 rbx = 2;
+ uint64 rcx = 3;
+ uint64 rdx = 4;
+ uint64 rsi = 5;
+ uint64 rdi = 6;
+ uint64 rsp = 7;
+ uint64 rbp = 8;
+
+ uint64 r8 = 9;
+ uint64 r9 = 10;
+ uint64 r10 = 11;
+ uint64 r11 = 12;
+ uint64 r12 = 13;
+ uint64 r13 = 14;
+ uint64 r14 = 15;
+ uint64 r15 = 16;
+
+ uint64 rip = 17;
+ uint64 rflags = 18;
+ uint64 orig_rax = 19;
+ uint64 cs = 20;
+ uint64 ds = 21;
+ uint64 es = 22;
+ uint64 fs = 23;
+ uint64 gs = 24;
+ uint64 ss = 25;
+ uint64 fs_base = 26;
+ uint64 gs_base = 27;
+}
+
+message ARM64Registers {
+ uint64 r0 = 1;
+ uint64 r1 = 2;
+ uint64 r2 = 3;
+ uint64 r3 = 4;
+ uint64 r4 = 5;
+ uint64 r5 = 6;
+ uint64 r6 = 7;
+ uint64 r7 = 8;
+ uint64 r8 = 9;
+ uint64 r9 = 10;
+ uint64 r10 = 11;
+ uint64 r11 = 12;
+ uint64 r12 = 13;
+ uint64 r13 = 14;
+ uint64 r14 = 15;
+ uint64 r15 = 16;
+ uint64 r16 = 17;
+ uint64 r17 = 18;
+ uint64 r18 = 19;
+ uint64 r19 = 20;
+ uint64 r20 = 21;
+ uint64 r21 = 22;
+ uint64 r22 = 23;
+ uint64 r23 = 24;
+ uint64 r24 = 25;
+ uint64 r25 = 26;
+ uint64 r26 = 27;
+ uint64 r27 = 28;
+ uint64 r28 = 29;
+ uint64 r29 = 30;
+ uint64 r30 = 31;
+ uint64 sp = 32;
+ uint64 pc = 33;
+ uint64 pstate = 34;
+}
+
+message Registers {
+ oneof arch {
+ AMD64Registers amd64 = 1;
+ ARM64Registers arm64 = 2;
+ }
+}
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
new file mode 100644
index 000000000..c9fb55d00
--- /dev/null
+++ b/pkg/sentry/arch/signal.go
@@ -0,0 +1,253 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SignalAct represents the action that should be taken when a signal is
+// delivered, and is equivalent to struct sigaction.
+//
+// +marshal
+// +stateify savable
+type SignalAct struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64 // Only used on amd64.
+ Mask linux.SignalSet
+}
+
+// SerializeFrom implements NativeSignalAct.SerializeFrom.
+func (s *SignalAct) SerializeFrom(other *SignalAct) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalAct.DeserializeTo.
+func (s *SignalAct) DeserializeTo(other *SignalAct) {
+ *other = *s
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t.
+//
+// +marshal
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// SerializeFrom implements NativeSignalStack.SerializeFrom.
+func (s *SignalStack) SerializeFrom(other *SignalStack) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalStack.DeserializeTo.
+func (s *SignalStack) DeserializeTo(other *SignalStack) {
+ *other = *s
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo in the Linux kernel (linux/include/uapi/asm-generic/siginfo.h).
+//
+// +marshal
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// Pid returns the si_pid field.
+func (s *SignalInfo) Pid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPid mutates the si_pid field.
+func (s *SignalInfo) SetPid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Uid returns the si_uid field.
+func (s *SignalInfo) Uid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUid mutates the si_uid field.
+func (s *SignalInfo) SetUid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() linux.TimerID {
+ return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val linux.TimerID) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
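+
+// Because _sifields is a union, the accessors above alias each other. A
+// small illustration (not part of the original change):
+//
+//    var si SignalInfo
+//    si.SetPid(42)
+//    _ = si.TimerID() // Also 42: si_pid and si_timerid share Fields[0:4].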
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
new file mode 100644
index 000000000..32173aa20
--- /dev/null
+++ b/pkg/sentry/arch/signal_act.go
@@ -0,0 +1,83 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import "gvisor.dev/gvisor/tools/go_marshal/marshal"
+
+// Special values for SignalAct.Handler.
+const (
+ // SignalActDefault is SIG_DFL and specifies that the default behavior for
+ // a signal should be taken.
+ SignalActDefault = 0
+
+ // SignalActIgnore is SIG_IGN and specifies that a signal should be
+ // ignored.
+ SignalActIgnore = 1
+)
+
+// Available signal flags.
+const (
+ SignalFlagNoCldStop = 0x00000001
+ SignalFlagNoCldWait = 0x00000002
+ SignalFlagSigInfo = 0x00000004
+ SignalFlagRestorer = 0x04000000
+ SignalFlagOnStack = 0x08000000
+ SignalFlagRestart = 0x10000000
+ SignalFlagInterrupt = 0x20000000
+ SignalFlagNoDefer = 0x40000000
+ SignalFlagResetHandler = 0x80000000
+)
+
+// IsSigInfo returns true iff this handler expects siginfo.
+func (s SignalAct) IsSigInfo() bool {
+ return s.Flags&SignalFlagSigInfo != 0
+}
+
+// IsNoDefer returns true iff this SignalAct has the NoDefer flag set.
+func (s SignalAct) IsNoDefer() bool {
+ return s.Flags&SignalFlagNoDefer != 0
+}
+
+// IsRestart returns true iff this SignalAct has the Restart flag set.
+func (s SignalAct) IsRestart() bool {
+ return s.Flags&SignalFlagRestart != 0
+}
+
+// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set.
+func (s SignalAct) IsResetHandler() bool {
+ return s.Flags&SignalFlagResetHandler != 0
+}
+
+// IsOnStack returns true iff this SignalAct has the OnStack flag set.
+func (s SignalAct) IsOnStack() bool {
+ return s.Flags&SignalFlagOnStack != 0
+}
+
+// HasRestorer returns true iff this SignalAct has the Restorer flag set.
+func (s SignalAct) HasRestorer() bool {
+ return s.Flags&SignalFlagRestorer != 0
+}
+
+// NativeSignalAct is a type that is equivalent to struct sigaction in the
+// guest architecture.
+type NativeSignalAct interface {
+ marshal.Marshallable
+
+ // SerializeFrom copies the data in the host SignalAct s into this object.
+ SerializeFrom(s *SignalAct)
+
+ // DeserializeTo copies the data in this object into the host SignalAct s.
+ DeserializeTo(s *SignalAct)
+}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
new file mode 100644
index 000000000..6fb756f0e
--- /dev/null
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -0,0 +1,291 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package arch
+
+import (
+ "encoding/binary"
+ "math"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SignalContext64 is equivalent to struct sigcontext, the type passed as the
+// second argument to signal handlers set by signal(2).
+type SignalContext64 struct {
+ R8 uint64
+ R9 uint64
+ R10 uint64
+ R11 uint64
+ R12 uint64
+ R13 uint64
+ R14 uint64
+ R15 uint64
+ Rdi uint64
+ Rsi uint64
+ Rbp uint64
+ Rbx uint64
+ Rdx uint64
+ Rax uint64
+ Rcx uint64
+ Rsp uint64
+ Rip uint64
+ Eflags uint64
+ Cs uint16
+ Gs uint16 // always 0 on amd64.
+ Fs uint16 // always 0 on amd64.
+ Ss uint16 // only restored if _UC_STRICT_RESTORE_SS (unsupported).
+ Err uint64
+ Trapno uint64
+ Oldmask linux.SignalSet
+ Cr2 uint64
+ // Pointer to a struct _fpstate. See b/33003106#comment8.
+ Fpstate uint64
+ Reserved [8]uint64
+}
+
+// Flags for UContext64.Flags.
+const (
+ _UC_FP_XSTATE = 1
+ _UC_SIGCONTEXT_SS = 2
+ _UC_STRICT_RESTORE_SS = 4
+)
+
+// UContext64 is equivalent to ucontext_t on 64-bit x86.
+type UContext64 struct {
+ Flags uint64
+ Link uint64
+ Stack SignalStack
+ MContext SignalContext64
+ Sigset linux.SignalSet
+}
+
+// NewSignalAct implements Context.NewSignalAct.
+func (c *context64) NewSignalAct() NativeSignalAct {
+ return &SignalAct{}
+}
+
+// NewSignalStack implements Context.NewSignalStack.
+func (c *context64) NewSignalStack() NativeSignalStack {
+ return &SignalStack{}
+}
+
+// From Linux's 'arch/x86/include/uapi/asm/sigcontext.h', the following is the
+// size of the magic cookie at the end of the xsave frame.
+//
+// NOTE(b/33003106#comment11): Currently we don't actually populate the fpstate
+// on the signal stack.
+const _FP_XSTATE_MAGIC2_SIZE = 4
+
+func (c *context64) fpuFrameSize() (size int, useXsave bool) {
+ size = len(c.x86FPState)
+ if size > 512 {
+ // Make room for the magic cookie at the end of the xsave frame.
+ size += _FP_XSTATE_MAGIC2_SIZE
+ useXsave = true
+ }
+ return size, useXsave
+}
+
+// SignalSetup implements Context.SignalSetup. (Compare to Linux's
+// arch/x86/kernel/signal.c:__setup_rt_frame().)
+func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+ sp := st.Bottom
+
+ // "The 128-byte area beyond the location pointed to by %rsp is considered
+ // to be reserved and shall not be modified by signal or interrupt
+ // handlers. ... leaf functions may use this area for their entire stack
+ // frame, rather than adjusting the stack pointer in the prologue and
+ // epilogue." - AMD64 ABI
+ //
+ // (But this doesn't apply if we're starting at the top of the signal
+ // stack, in which case there is no following stack frame.)
+ if !(alt.IsEnabled() && sp == alt.Top()) {
+ sp -= 128
+ }
+
+ // Allocate space for floating point state on the stack.
+ //
+ // This isn't strictly necessary because we don't actually populate
+ // the fpstate. However we do store the floating point state of the
+ // interrupted thread inside the sentry. Simply accounting for this
+ // space on the user stack naturally caps the amount of memory the
+ // sentry will allocate for this purpose.
+ fpSize, _ := c.fpuFrameSize()
+ sp = (sp - usermem.Addr(fpSize)) & ^usermem.Addr(63)
+
+ // Construct the UContext64 now since we need its size.
+ uc := &UContext64{
+ // No _UC_FP_XSTATE: see Fpstate above.
+ // No _UC_STRICT_RESTORE_SS: we don't allow SS changes.
+ Flags: _UC_SIGCONTEXT_SS,
+ Stack: *alt,
+ MContext: SignalContext64{
+ R8: c.Regs.R8,
+ R9: c.Regs.R9,
+ R10: c.Regs.R10,
+ R11: c.Regs.R11,
+ R12: c.Regs.R12,
+ R13: c.Regs.R13,
+ R14: c.Regs.R14,
+ R15: c.Regs.R15,
+ Rdi: c.Regs.Rdi,
+ Rsi: c.Regs.Rsi,
+ Rbp: c.Regs.Rbp,
+ Rbx: c.Regs.Rbx,
+ Rdx: c.Regs.Rdx,
+ Rax: c.Regs.Rax,
+ Rcx: c.Regs.Rcx,
+ Rsp: c.Regs.Rsp,
+ Rip: c.Regs.Rip,
+ Eflags: c.Regs.Eflags,
+ Cs: uint16(c.Regs.Cs),
+ Ss: uint16(c.Regs.Ss),
+ Oldmask: sigset,
+ },
+ Sigset: sigset,
+ }
+
+ // TODO(gvisor.dev/issue/159): Set SignalContext64.Err, Trapno, and Cr2
+ // based on the fault that caused the signal. For now, leave Err and
+ // Trapno unset and assume CR2 == info.Addr() for SIGSEGVs and
+ // SIGBUSes.
+ if linux.Signal(info.Signo) == linux.SIGSEGV || linux.Signal(info.Signo) == linux.SIGBUS {
+ uc.MContext.Cr2 = info.Addr()
+ }
+
+ // "... the value (%rsp+8) is always a multiple of 16 (...) when
+ // control is transferred to the function entry point." - AMD64 ABI
+ ucSize := binary.Size(uc)
+ if ucSize < 0 {
+ // This can only happen if we've screwed up the definition of
+ // UContext64.
+ panic("can't get size of UContext64")
+ }
+ // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
+ frameSize := int(st.Arch.Width()) + ucSize + 128
+ frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
+ sp = frameBottom + usermem.Addr(frameSize)
+ st.Bottom = sp
+
+ // Prior to proceeding, figure out if the frame will exhaust the range
+ // for the signal stack. This is not allowed, and should immediately
+ // force signal delivery (reverting to the default handler).
+ if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ return syscall.EFAULT
+ }
+
+ // Adjust the code.
+ info.FixSignalCodeForUser()
+
+ // Set up the stack frame.
+ infoAddr, err := st.Push(info)
+ if err != nil {
+ return err
+ }
+ ucAddr, err := st.Push(uc)
+ if err != nil {
+ return err
+ }
+ if act.HasRestorer() {
+ // Push the restorer return address.
+ // Note that this doesn't need to be popped.
+ if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil {
+ return err
+ }
+ } else {
+ // amd64 requires a restorer.
+ return syscall.EFAULT
+ }
+
+ // Set up registers.
+ c.Regs.Rip = act.Handler
+ c.Regs.Rsp = uint64(st.Bottom)
+ c.Regs.Rdi = uint64(info.Signo)
+ c.Regs.Rsi = uint64(infoAddr)
+ c.Regs.Rdx = uint64(ucAddr)
+ c.Regs.Rax = 0
+ c.Regs.Ds = userDS
+ c.Regs.Es = userDS
+ c.Regs.Cs = userCS
+ c.Regs.Ss = userDS
+
+ // Save the thread's floating point state.
+ c.sigFPState = append(c.sigFPState, c.x86FPState)
+
+ // Signal handler gets a clean floating point state.
+ c.x86FPState = newX86FPState()
+
+ return nil
+}
+
+// SignalRestore implements Context.SignalRestore. (Compare to Linux's
+// arch/x86/kernel/signal.c:sys_rt_sigreturn().)
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+ // Copy out the stack frame.
+ var uc UContext64
+ if _, err := st.Pop(&uc); err != nil {
+ return 0, SignalStack{}, err
+ }
+ var info SignalInfo
+ if _, err := st.Pop(&info); err != nil {
+ return 0, SignalStack{}, err
+ }
+
+ // Restore registers.
+ c.Regs.R8 = uc.MContext.R8
+ c.Regs.R9 = uc.MContext.R9
+ c.Regs.R10 = uc.MContext.R10
+ c.Regs.R11 = uc.MContext.R11
+ c.Regs.R12 = uc.MContext.R12
+ c.Regs.R13 = uc.MContext.R13
+ c.Regs.R14 = uc.MContext.R14
+ c.Regs.R15 = uc.MContext.R15
+ c.Regs.Rdi = uc.MContext.Rdi
+ c.Regs.Rsi = uc.MContext.Rsi
+ c.Regs.Rbp = uc.MContext.Rbp
+ c.Regs.Rbx = uc.MContext.Rbx
+ c.Regs.Rdx = uc.MContext.Rdx
+ c.Regs.Rax = uc.MContext.Rax
+ c.Regs.Rcx = uc.MContext.Rcx
+ c.Regs.Rsp = uc.MContext.Rsp
+ c.Regs.Rip = uc.MContext.Rip
+ c.Regs.Eflags = (c.Regs.Eflags & ^eflagsRestorable) | (uc.MContext.Eflags & eflagsRestorable)
+ c.Regs.Cs = uint64(uc.MContext.Cs) | 3
+ // N.B. _UC_STRICT_RESTORE_SS not supported.
+ c.Regs.Orig_rax = math.MaxUint64
+
+ // Restore floating point state.
+ l := len(c.sigFPState)
+ if l > 0 {
+ c.x86FPState = c.sigFPState[l-1]
+ // NOTE(cl/133042258): State save requires that any slice
+ // elements in [len:cap] be the zero value.
+ c.sigFPState[l-1] = nil
+ c.sigFPState = c.sigFPState[0 : l-1]
+ } else {
+ // This might happen if sigreturn(2) calls are unbalanced with
+ // respect to signal handler entries. This is not expected so
+ // don't bother to do anything fancy with the floating point
+ // state.
+ log.Infof("sigreturn unable to restore application fpstate")
+ }
+
+ return uc.Sigset, uc.Stack, nil
+}
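
The frame arithmetic above is worth making concrete. A minimal, standalone sketch (the frame size is illustrative; the 8-byte restorer slot and the 16-byte rule come from the code and the AMD64 ABI quote above) showing why frameBottom is computed as ((sp - frameSize) &^ 15) - 8: once the siginfo, ucontext, and restorer pushes consume exactly frameSize bytes, the handler starts with (rsp + 8) % 16 == 0, just as at a normal call site.

package main

import "fmt"

func main() {
	sp := uint64(0x7fffffffe000)       // hypothetical pre-signal stack pointer
	frameSize := uint64(8 + 304 + 128) // restorer + ucontext + siginfo (illustrative sizes)

	// Mirrors: frameBottom := (sp-frameSize) &^ 15 - 8; sp = frameBottom + frameSize.
	frameBottom := (sp-frameSize)&^15 - 8
	handlerRsp := frameBottom // the three pushes consume exactly frameSize bytes

	fmt.Printf("rsp %% 16 = %d\n", handlerRsp%16)         // 8
	fmt.Printf("(rsp+8) %% 16 = %d\n", (handlerRsp+8)%16) // 0
}
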
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
new file mode 100644
index 000000000..642c79dda
--- /dev/null
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SignalContext64 is equivalent to struct sigcontext, the type passed as the
+// second argument to signal handlers set by signal(2).
+type SignalContext64 struct {
+ FaultAddr uint64
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+ _pad [8]byte // __attribute__((__aligned__(16)))
+ Fpsimd64 FpsimdContext // size = 528
+ Reserved [3568]uint8
+}
+
+type aarch64Ctx struct {
+ Magic uint32
+ Size uint32
+}
+
+// FpsimdContext is equivalent to struct fpsimd_context on arm64
+// (arch/arm64/include/uapi/asm/sigcontext.h).
+type FpsimdContext struct {
+ Head aarch64Ctx
+ Fpsr uint32
+ Fpcr uint32
+ Vregs [64]uint64 // actually [32]uint128
+}
+
+// UContext64 is equivalent to ucontext on arm64 (arch/arm64/include/uapi/asm/ucontext.h).
+type UContext64 struct {
+ Flags uint64
+ Link uint64
+ Stack SignalStack
+ Sigset linux.SignalSet
+ // glibc uses a 1024-bit sigset_t.
+ _pad [(1024 - 64) / 8]byte
+ // sigcontext must be 16-byte aligned.
+ _pad2 [8]byte
+ // MContext comes last to allow for future expansion.
+ MContext SignalContext64
+}
+
+// NewSignalAct implements Context.NewSignalAct.
+func (c *context64) NewSignalAct() NativeSignalAct {
+ return &SignalAct{}
+}
+
+// NewSignalStack implements Context.NewSignalStack.
+func (c *context64) NewSignalStack() NativeSignalStack {
+ return &SignalStack{}
+}
+
+// SignalSetup implements Context.SignalSetup.
+func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+ sp := st.Bottom
+
+ if !(alt.IsEnabled() && sp == alt.Top()) {
+ sp -= 128
+ }
+
+ // Construct the UContext64 now since we need its size.
+ uc := &UContext64{
+ Flags: 0,
+ Stack: *alt,
+ MContext: SignalContext64{
+ Regs: c.Regs.Regs,
+ Sp: c.Regs.Sp,
+ Pc: c.Regs.Pc,
+ Pstate: c.Regs.Pstate,
+ },
+ Sigset: sigset,
+ }
+
+ ucSize := binary.Size(uc)
+ if ucSize < 0 {
+ panic("can't get size of UContext64")
+ }
+
+ // frameSize = ucSize + sizeof(siginfo).
+ // sizeof(siginfo) == 128.
+ // R30 stores the restorer address.
+ frameSize := ucSize + 128
+ frameBottom := (sp - usermem.Addr(frameSize)) & ^usermem.Addr(15)
+ sp = frameBottom + usermem.Addr(frameSize)
+ st.Bottom = sp
+
+ // Prior to proceeding, figure out if the frame will exhaust the range
+ // for the signal stack. This is not allowed, and should immediately
+ // force signal delivery (reverting to the default handler).
+ if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ return syscall.EFAULT
+ }
+
+ // Adjust the code.
+ info.FixSignalCodeForUser()
+
+ // Set up the stack frame.
+ infoAddr, err := st.Push(info)
+ if err != nil {
+ return err
+ }
+ ucAddr, err := st.Push(uc)
+ if err != nil {
+ return err
+ }
+
+ // Set up registers.
+ c.Regs.Sp = uint64(st.Bottom)
+ c.Regs.Pc = act.Handler
+ c.Regs.Regs[0] = uint64(info.Signo)
+ c.Regs.Regs[1] = uint64(infoAddr)
+ c.Regs.Regs[2] = uint64(ucAddr)
+ c.Regs.Regs[30] = uint64(act.Restorer)
+
+ // Save the thread's floating point state.
+ c.sigFPState = append(c.sigFPState, c.aarch64FPState)
+ // Signal handler gets a clean floating point state.
+ c.aarch64FPState = newAarch64FPState()
+ return nil
+}
+
+// SignalRestore implements Context.SignalRestore.
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+ // Copy out the stack frame.
+ var uc UContext64
+ if _, err := st.Pop(&uc); err != nil {
+ return 0, SignalStack{}, err
+ }
+ var info SignalInfo
+ if _, err := st.Pop(&info); err != nil {
+ return 0, SignalStack{}, err
+ }
+
+ // Restore registers.
+ c.Regs.Regs = uc.MContext.Regs
+ c.Regs.Pc = uc.MContext.Pc
+ c.Regs.Sp = uc.MContext.Sp
+ c.Regs.Pstate = uc.MContext.Pstate
+
+ // Restore floating point state.
+ l := len(c.sigFPState)
+ if l > 0 {
+ c.aarch64FPState = c.sigFPState[l-1]
+ // NOTE(cl/133042258): State save requires that any slice
+ // elements in [len:cap] be the zero value.
+ c.sigFPState[l-1] = nil
+ c.sigFPState = c.sigFPState[0 : l-1]
+ } else {
+ // This might happen if sigreturn(2) calls are unbalanced with
+ // respect to signal handler entries. This is not expected so
+ // don't bother to do anything fancy with the floating point
+ // state.
+ log.Warningf("sigreturn unable to restore application fpstate")
+ }
+
+ return uc.Sigset, uc.Stack, nil
+}
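
The two padding comments in UContext64 can be checked mechanically. A standalone sketch using local mirrors of the structs above (all field layouts are assumptions copied from this diff, not the package's own definitions; signalStack is assumed to be the 24-byte stack_t):

package main

import (
	"fmt"
	"unsafe"
)

type signalStack struct {
	Addr  uint64
	Flags uint32
	_     uint32
	Size  uint64
}

type aarch64Ctx struct {
	Magic uint32
	Size  uint32
}

type fpsimdContext struct {
	Head  aarch64Ctx
	Fpsr  uint32
	Fpcr  uint32
	Vregs [64]uint64
}

type signalContext64 struct {
	FaultAddr uint64
	Regs      [31]uint64
	Sp        uint64
	Pc        uint64
	Pstate    uint64
	_pad      [8]byte
	Fpsimd64  fpsimdContext
	Reserved  [3568]uint8
}

type uContext64 struct {
	Flags    uint64
	Link     uint64
	Stack    signalStack
	Sigset   uint64 // linux.SignalSet is a uint64
	_pad     [(1024 - 64) / 8]byte
	_pad2    [8]byte
	MContext signalContext64
}

func main() {
	var uc uContext64
	// The Sigset field plus _pad give the glibc-compatible 1024-bit region.
	fmt.Println("sigset bits:", (int(unsafe.Sizeof(uc.Sigset))+len(uc._pad))*8) // 1024
	// _pad2 brings the embedded sigcontext to a 16-byte boundary.
	off := unsafe.Offsetof(uc.MContext)
	fmt.Println("MContext offset:", off, "16-byte aligned:", off%16 == 0) // 176 true
	fmt.Println("fpsimd size:", unsafe.Sizeof(fpsimdContext{}))           // 528, as the comment says
}
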
diff --git a/pkg/sentry/arch/signal_info.go b/pkg/sentry/arch/signal_info.go
new file mode 100644
index 000000000..f93ee8b46
--- /dev/null
+++ b/pkg/sentry/arch/signal_info.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+// Possible values for SignalInfo.Code. These values originate from the Linux
+// kernel's include/uapi/asm-generic/siginfo.h.
+const (
+ // SignalInfoUser (properly SI_USER) indicates that a signal was sent from
+ // a kill() or raise() syscall.
+ SignalInfoUser = 0
+
+ // SignalInfoKernel (properly SI_KERNEL) indicates that the signal was sent
+ // by the kernel.
+ SignalInfoKernel = 0x80
+
+ // SignalInfoTimer (properly SI_TIMER) indicates that the signal was sent
+ // by an expired timer.
+ SignalInfoTimer = -2
+
+ // SignalInfoTkill (properly SI_TKILL) indicates that the signal was sent
+ // from a tkill() or tgkill() syscall.
+ SignalInfoTkill = -6
+
+ // CLD_* codes are only meaningful for SIGCHLD.
+
+ // CLD_EXITED indicates that a task exited.
+ CLD_EXITED = 1
+
+ // CLD_KILLED indicates that a task was killed by a signal.
+ CLD_KILLED = 2
+
+ // CLD_DUMPED indicates that a task was killed by a signal and then dumped
+ // core.
+ CLD_DUMPED = 3
+
+ // CLD_TRAPPED indicates that a task was stopped by ptrace.
+ CLD_TRAPPED = 4
+
+ // CLD_STOPPED indicates that a thread group completed a group stop.
+ CLD_STOPPED = 5
+
+ // CLD_CONTINUED indicates that a group-stopped thread group was continued.
+ CLD_CONTINUED = 6
+
+ // SYS_* codes are only meaningful for SIGSYS.
+
+ // SYS_SECCOMP indicates that a signal originates from seccomp.
+ SYS_SECCOMP = 1
+
+ // TRAP_* codes are only meaningful for SIGTRAP.
+
+ // TRAP_BRKPT indicates a breakpoint trap.
+ TRAP_BRKPT = 1
+)
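
Since these codes are plain integers whose meaning depends on the signal, a decoder has to switch on them per signal. A small sketch for the SIGCHLD family (constants copied from the block above; the helper is hypothetical, not part of the package):

package main

import "fmt"

const (
	CLD_EXITED    = 1
	CLD_KILLED    = 2
	CLD_DUMPED    = 3
	CLD_TRAPPED   = 4
	CLD_STOPPED   = 5
	CLD_CONTINUED = 6
)

// describeChldCode is a hypothetical helper illustrating how a SIGCHLD
// SignalInfo.Code would be decoded.
func describeChldCode(code int32) string {
	switch code {
	case CLD_EXITED:
		return "child exited"
	case CLD_KILLED:
		return "child killed by signal"
	case CLD_DUMPED:
		return "child killed by signal, dumped core"
	case CLD_TRAPPED:
		return "child stopped by ptrace"
	case CLD_STOPPED:
		return "thread group completed a group stop"
	case CLD_CONTINUED:
		return "group-stopped thread group continued"
	default:
		return "unknown SIGCHLD code"
	}
}

func main() {
	fmt.Println(describeChldCode(CLD_STOPPED))
}
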
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
new file mode 100644
index 000000000..0fa738a1d
--- /dev/null
+++ b/pkg/sentry/arch/signal_stack.go
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64 arm64
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+const (
+ // SignalStackFlagOnStack may be set on return from getaltstack to
+ // indicate that the thread is currently on the alt stack.
+ SignalStackFlagOnStack = 1
+
+ // SignalStackFlagDisable is a flag to indicate the stack is disabled.
+ SignalStackFlagDisable = 2
+)
+
+// IsEnabled returns true iff this signal stack is marked as enabled.
+func (s SignalStack) IsEnabled() bool {
+ return s.Flags&SignalStackFlagDisable == 0
+}
+
+// Top returns the stack's top address.
+func (s SignalStack) Top() usermem.Addr {
+ return usermem.Addr(s.Addr + s.Size)
+}
+
+// SetOnStack marks this signal stack as in use.
+//
+// Note that there is no corresponding ClearOnStack, and that this should only
+// be called on copies that are serialized to userspace.
+func (s *SignalStack) SetOnStack() {
+ s.Flags |= SignalStackFlagOnStack
+}
+
+// Contains checks if the stack pointer is within this stack.
+func (s *SignalStack) Contains(sp usermem.Addr) bool {
+ return usermem.Addr(s.Addr) < sp && sp <= usermem.Addr(s.Addr+s.Size)
+}
+
+// NativeSignalStack is a type that is equivalent to stack_t in the guest
+// architecture.
+type NativeSignalStack interface {
+ marshal.Marshallable
+
+ // SerializeFrom copies the data in the host SignalStack s into this
+ // object.
+ SerializeFrom(s *SignalStack)
+
+ // DeserializeTo copies the data in this object into the host SignalStack
+ // s.
+ DeserializeTo(s *SignalStack)
+}
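
The Contains check above is half-open on the low end: on a full-descending stack, an sp equal to Top() is still on the stack (an empty frame), while an sp equal to Addr has already run past it. A standalone sketch with a local mirror of the type (the field layout is an assumption):

package main

import "fmt"

type signalStack struct {
	Addr, Size uint64
}

func (s signalStack) Top() uint64 { return s.Addr + s.Size }

// Contains mirrors SignalStack.Contains above.
func (s signalStack) Contains(sp uint64) bool {
	return s.Addr < sp && sp <= s.Addr+s.Size
}

func main() {
	alt := signalStack{Addr: 0x1000, Size: 0x2000}
	fmt.Println(alt.Contains(alt.Top()))    // true: empty frame at the top
	fmt.Println(alt.Contains(alt.Addr))     // false: the stack is exhausted
	fmt.Println(alt.Contains(alt.Addr + 1)) // true: one byte still available
}
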
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
new file mode 100644
index 000000000..1108fa0bd
--- /dev/null
+++ b/pkg/sentry/arch/stack.go
@@ -0,0 +1,249 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stack is a simple wrapper around a usermem.IO and an address.
+type Stack struct {
+ // Our arch info.
+ // We use this for automatic Native conversion of usermem.Addrs during
+ // Push() and Pop().
+ Arch Context
+
+ // The interface used to actually copy user memory.
+ IO usermem.IO
+
+ // Our current stack bottom.
+ Bottom usermem.Addr
+}
+
+// Push pushes the given values on to the stack.
+//
+// (This method supports Addrs and treats them as native types.)
+func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
+ for _, v := range vals {
+
+ // We convert some types to well-known serializable quantities.
+ var norm interface{}
+
+ // For array types, we will automatically add an appropriate
+ // terminal value. This is done simply to make the interface
+ // easier to use.
+ var term interface{}
+
+ switch v.(type) {
+ case string:
+ norm = []byte(v.(string))
+ term = byte(0)
+ case []int8, []uint8:
+ norm = v
+ term = byte(0)
+ case []int16, []uint16:
+ norm = v
+ term = uint16(0)
+ case []int32, []uint32:
+ norm = v
+ term = uint32(0)
+ case []int64, []uint64:
+ norm = v
+ term = uint64(0)
+ case []usermem.Addr:
+ // Special case: simply push recursively.
+ _, err := s.Push(s.Arch.Native(uintptr(0)))
+ if err != nil {
+ return 0, err
+ }
+ varr := v.([]usermem.Addr)
+ for i := len(varr) - 1; i >= 0; i-- {
+ _, err := s.Push(varr[i])
+ if err != nil {
+ return 0, err
+ }
+ }
+ continue
+ case usermem.Addr:
+ norm = s.Arch.Native(uintptr(v.(usermem.Addr)))
+ default:
+ norm = v
+ }
+
+ if term != nil {
+ _, err := s.Push(term)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ c := binary.Size(norm)
+ if c < 0 {
+ return 0, fmt.Errorf("bad binary.Size for %T", v)
+ }
+ n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
+ if err != nil || c != n {
+ return 0, err
+ }
+
+ s.Bottom -= usermem.Addr(n)
+ }
+
+ return s.Bottom, nil
+}
+
+// Pop pops the given values off the stack.
+//
+// (This method supports Addrs and treats them as native types.)
+func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
+ for _, v := range vals {
+
+ vaddr, isVaddr := v.(*usermem.Addr)
+
+ var n int
+ var err error
+ if isVaddr {
+ value := s.Arch.Native(uintptr(0))
+ n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
+ *vaddr = usermem.Addr(s.Arch.Value(value))
+ } else {
+ n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ s.Bottom += usermem.Addr(n)
+ }
+
+ return s.Bottom, nil
+}
+
+// Align aligns the stack bottom down to a multiple of offset.
+func (s *Stack) Align(offset int) {
+ if s.Bottom%usermem.Addr(offset) != 0 {
+ s.Bottom -= (s.Bottom % usermem.Addr(offset))
+ }
+}
+
+// StackLayout describes the location of the arguments and environment on the
+// stack.
+type StackLayout struct {
+ // ArgvStart is the beginning of the argument vector.
+ ArgvStart usermem.Addr
+
+ // ArgvEnd is the end of the argument vector.
+ ArgvEnd usermem.Addr
+
+ // EnvvStart is the beginning of the environment vector.
+ EnvvStart usermem.Addr
+
+ // EnvvEnd is the end of the environment vector.
+ EnvvEnd usermem.Addr
+}
+
+// Load pushes the given args, env and aux vector to the stack using the
+// well-known format for a new executable. It returns the start and end
+// of the argument and environment vectors.
+func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) {
+ l := StackLayout{}
+
+ // Make sure we start with a 16-byte alignment.
+ s.Align(16)
+
+ // Push the environment vector so that the end of the argument vector is
+ // adjacent to the beginning of the environment vector. While the System V
+ // ABI for x86_64 does not specify an ordering for the Information Block
+ // (the block holding the arg, env, and aux vectors), features like
+ // setproctitle(3) naturally expect these segments to be in this order. See
+ // https://www.uclibc.org/docs/psABI-x86_64.pdf, page 29.
+ l.EnvvEnd = s.Bottom
+ envAddrs := make([]usermem.Addr, len(env))
+ for i := len(env) - 1; i >= 0; i-- {
+ addr, err := s.Push(env[i])
+ if err != nil {
+ return StackLayout{}, err
+ }
+ envAddrs[i] = addr
+ }
+ l.EnvvStart = s.Bottom
+
+ // Push our strings.
+ l.ArgvEnd = s.Bottom
+ argAddrs := make([]usermem.Addr, len(args))
+ for i := len(args) - 1; i >= 0; i-- {
+ addr, err := s.Push(args[i])
+ if err != nil {
+ return StackLayout{}, err
+ }
+ argAddrs[i] = addr
+ }
+ l.ArgvStart = s.Bottom
+
+ // We need to align the arguments appropriately.
+ //
+ // We must finish on a 16-byte alignment, but we'll play it
+ // conservatively and finish at 32 bytes. It would be nice to be able
+ // to call Align here, but unfortunately we need to align the stack
+ // with all the variable-sized arrays already pushed. So we just do
+ // the calculation ourselves.
+ argvSize := s.Arch.Width() * uint(len(args)+1)
+ envvSize := s.Arch.Width() * uint(len(env)+1)
+ auxvSize := s.Arch.Width() * 2 * uint(len(aux)+1)
+ total := usermem.Addr(argvSize) + usermem.Addr(envvSize) + usermem.Addr(auxvSize) + usermem.Addr(s.Arch.Width())
+ expectedBottom := s.Bottom - total
+ if expectedBottom%32 != 0 {
+ s.Bottom -= expectedBottom % 32
+ }
+
+ // Push our auxvec.
+ // NOTE: We need an extra zero here per spec.
+ // The Push function will automatically terminate
+ // strings and arrays with a single null value.
+ auxv := make([]usermem.Addr, 0, len(aux))
+ for _, a := range aux {
+ auxv = append(auxv, usermem.Addr(a.Key), a.Value)
+ }
+ auxv = append(auxv, usermem.Addr(0))
+ _, err := s.Push(auxv)
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ // Push environment.
+ _, err = s.Push(envAddrs)
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ // Push args.
+ _, err = s.Push(argAddrs)
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ // Push arg count.
+ _, err = s.Push(usermem.Addr(len(args)))
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ return l, nil
+}
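
To make the Push semantics concrete, here is a toy in-memory re-implementation (assuming a 64-bit little-endian guest, and ignoring the alignment logic that Load performs): strings get a NUL terminator, pointer arrays get a terminating zero word, and everything grows down from Bottom.

package main

import (
	"encoding/binary"
	"fmt"
)

type toyStack struct {
	mem    []byte
	bottom uint64 // current stack bottom; indexes into mem
}

func (s *toyStack) pushBytes(b []byte) uint64 {
	s.bottom -= uint64(len(b))
	copy(s.mem[s.bottom:], b)
	return s.bottom
}

// pushString mirrors the string case: terminator first, then the bytes.
func (s *toyStack) pushString(str string) uint64 {
	s.pushBytes([]byte{0})
	return s.pushBytes([]byte(str))
}

func (s *toyStack) pushWord(w uint64) uint64 {
	var b [8]byte
	binary.LittleEndian.PutUint64(b[:], w)
	return s.pushBytes(b[:])
}

// pushAddrs mirrors the []usermem.Addr case: a terminating zero word is
// pushed first, then the elements in reverse order.
func (s *toyStack) pushAddrs(addrs []uint64) uint64 {
	s.pushWord(0)
	for i := len(addrs) - 1; i >= 0; i-- {
		s.pushWord(addrs[i])
	}
	return s.bottom
}

func main() {
	s := &toyStack{mem: make([]byte, 256), bottom: 256}
	a0 := s.pushString("sh")                // string data near the top
	argv := s.pushAddrs([]uint64{a0})       // argv pointers + NULL terminator
	argc := s.pushWord(1)                   // arg count at the very bottom
	fmt.Printf("string %#x argv %#x argc %#x\n", a0, argv, argc)
}
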
diff --git a/pkg/sentry/arch/syscalls_amd64.go b/pkg/sentry/arch/syscalls_amd64.go
new file mode 100644
index 000000000..3859f41ee
--- /dev/null
+++ b/pkg/sentry/arch/syscalls_amd64.go
@@ -0,0 +1,59 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package arch
+
+const restartSyscallNr = uintptr(219)
+
+// SyscallSaveOrig saves the value of the register that is clobbered in the
+// syscall handler (doSyscall()).
+//
+// No-op on x86.
+func (c *context64) SyscallSaveOrig() {
+}
+
+// SyscallNo returns the syscall number according to the 64-bit convention.
+func (c *context64) SyscallNo() uintptr {
+ return uintptr(c.Regs.Orig_rax)
+}
+
+// SyscallArgs provides syscall arguments according to the 64-bit convention.
+//
+// Due to the way addresses are mapped for the sentry, this binary *must* be
+// built in 64-bit mode. So we can just assume the syscall numbers that come
+// back match the expected host system call numbers.
+func (c *context64) SyscallArgs() SyscallArguments {
+ return SyscallArguments{
+ SyscallArgument{Value: uintptr(c.Regs.Rdi)},
+ SyscallArgument{Value: uintptr(c.Regs.Rsi)},
+ SyscallArgument{Value: uintptr(c.Regs.Rdx)},
+ SyscallArgument{Value: uintptr(c.Regs.R10)},
+ SyscallArgument{Value: uintptr(c.Regs.R8)},
+ SyscallArgument{Value: uintptr(c.Regs.R9)},
+ }
+}
+
+// RestartSyscall implements Context.RestartSyscall.
+func (c *context64) RestartSyscall() {
+ c.Regs.Rip -= SyscallWidth
+ c.Regs.Rax = c.Regs.Orig_rax
+}
+
+// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
+func (c *context64) RestartSyscallWithRestartBlock() {
+ c.Regs.Rip -= SyscallWidth
+ c.Regs.Rax = uint64(restartSyscallNr)
+}
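
A minimal sketch of the restart mechanics above (SyscallWidth == 2, the length of the x86-64 SYSCALL instruction, is an assumption taken from elsewhere in this package): rewinding RIP re-executes SYSCALL, and RAX, which the return path clobbered, is reloaded from Orig_rax.

package main

import "fmt"

const syscallWidth = 2 // assumed length of the SYSCALL instruction (0F 05)

type regs struct{ Rip, Rax, OrigRax uint64 }

// restart mirrors RestartSyscall above.
func restart(r *regs) {
	r.Rip -= syscallWidth // back up onto the SYSCALL instruction
	r.Rax = r.OrigRax     // undo the clobbered return value
}

func main() {
	// Hypothetical interrupted read(2): Rax holds a -ERESTART* value,
	// Orig_rax still holds the syscall number (0 for read).
	r := regs{Rip: 0x401002, Rax: ^uint64(0) - 511, OrigRax: 0}
	restart(&r)
	fmt.Printf("rip=%#x rax=%d\n", r.Rip, r.Rax) // rip=0x401000 rax=0
}
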
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
new file mode 100644
index 000000000..95dfd1e90
--- /dev/null
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+const restartSyscallNr = uintptr(128)
+
+// SyscallSaveOrig saves the value of register R0, which is clobbered in the
+// syscall handler (doSyscall()).
+//
+// In Linux, on entry to the syscall handler (el0_svc_common()), the value of
+// R0 is saved to pt_regs.orig_x0 in kernel code. But orig_x0 is not
+// accessible to the userspace application, so we have to do the same
+// operation in the sentry code to save the R0 value into the app context.
+func (c *context64) SyscallSaveOrig() {
+ c.OrigR0 = c.Regs.Regs[0]
+}
+
+// SyscallNo returns the syscall number according to the 64-bit convention.
+func (c *context64) SyscallNo() uintptr {
+ return uintptr(c.Regs.Regs[8])
+}
+
+// SyscallArgs provides syscall arguments according to the 64-bit convention.
+//
+// Due to the way addresses are mapped for the sentry, this binary *must* be
+// built in 64-bit mode. So we can just assume the syscall numbers that come
+// back match the expected host system call numbers.
+// General-purpose register usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary registers.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+func (c *context64) SyscallArgs() SyscallArguments {
+ return SyscallArguments{
+ SyscallArgument{Value: uintptr(c.OrigR0)},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[1])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[2])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[3])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[4])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[5])},
+ }
+}
+
+// RestartSyscall implements Context.RestartSyscall.
+// It prepares for system call restart; OrigR0 will be restored to R0.
+// See the Linux code for reference:
+// arch/arm64/kernel/signal.c:do_signal().
+func (c *context64) RestartSyscall() {
+ c.Regs.Pc -= SyscallWidth
+ // R0 is backed up into OrigR0 when entering doSyscall(); here we
+ // restore it. See the Linux code for reference:
+ // arch/arm64/kernel/syscall.c:el0_svc_common().
+ c.Regs.Regs[0] = uint64(c.OrigR0)
+}
+
+// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
+func (c *context64) RestartSyscallWithRestartBlock() {
+ c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[0] = uint64(c.OrigR0)
+ c.Regs.Regs[8] = uint64(restartSyscallNr)
+}
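
The OrigR0 round trip described above, as a standalone sketch (SyscallWidth == 4, the size of the AArch64 SVC instruction, is an assumption): x0 is saved on syscall entry, clobbered by the return value, and restored when the syscall is rewound.

package main

import "fmt"

const syscallWidth = 4 // assumed size of the AArch64 SVC instruction

type ctx struct {
	pc, origR0 uint64
	regs       [31]uint64
}

func (c *ctx) syscallEnter() { c.origR0 = c.regs[0] } // mirrors SyscallSaveOrig

func (c *ctx) setReturn(v uint64) { c.regs[0] = v } // return value clobbers x0

// restart mirrors RestartSyscall above.
func (c *ctx) restart() {
	c.pc -= syscallWidth
	c.regs[0] = c.origR0
}

func main() {
	c := &ctx{pc: 0x400004}
	c.regs[0] = 0xbeef // first syscall argument
	c.syscallEnter()
	c.setReturn(^uint64(0) - 511) // hypothetical -ERESTART* value
	c.restart()
	fmt.Printf("pc=%#x x0=%#x\n", c.pc, c.regs[0]) // pc=0x400000 x0=0xbeef
}
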
diff --git a/pkg/sentry/contexttest/BUILD b/pkg/sentry/contexttest/BUILD
new file mode 100644
index 000000000..6f4c86684
--- /dev/null
+++ b/pkg/sentry/contexttest/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "contexttest",
+ testonly = 1,
+ srcs = ["contexttest.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/memutil",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/ptrace",
+ "//pkg/sentry/uniqueid",
+ ],
+)
diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go
new file mode 100644
index 000000000..8e5658c7a
--- /dev/null
+++ b/pkg/sentry/contexttest/contexttest.go
@@ -0,0 +1,188 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package contexttest builds a test context.Context.
+package contexttest
+
+import (
+ "os"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+)
+
+// Context returns a Context that may be used in tests. Uses ptrace as the
+// platform.Platform.
+//
+// Note that some filesystems may require a minimal kernel for testing, which
+// this test context does not provide. For such tests, see kernel/contexttest.
+func Context(tb testing.TB) context.Context {
+ const memfileName = "contexttest-memory"
+ memfd, err := memutil.CreateMemFD(memfileName, 0)
+ if err != nil {
+ tb.Fatalf("error creating application memory file: %v", err)
+ }
+ memfile := os.NewFile(uintptr(memfd), memfileName)
+ mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
+ if err != nil {
+ memfile.Close()
+ tb.Fatalf("error creating pgalloc.MemoryFile: %v", err)
+ }
+ p, err := ptrace.New()
+ if err != nil {
+ tb.Fatal(err)
+ }
+ // Test usage of context.Background is fine.
+ return &TestContext{
+ Context: context.Background(),
+ l: limits.NewLimitSet(),
+ mf: mf,
+ platform: p,
+ creds: auth.NewAnonymousCredentials(),
+ otherValues: make(map[interface{}]interface{}),
+ }
+}
+
+// TestContext represents a context with minimal functionality suitable for
+// running tests.
+type TestContext struct {
+ context.Context
+ l *limits.LimitSet
+ mf *pgalloc.MemoryFile
+ platform platform.Platform
+ creds *auth.Credentials
+ otherValues map[interface{}]interface{}
+}
+
+// globalUniqueID tracks incremental unique identifiers for tests.
+var globalUniqueID uint64
+
+// globalUniqueIDProvider implements unix.UniqueIDProvider.
+type globalUniqueIDProvider struct{}
+
+// UniqueID implements unix.UniqueIDProvider.UniqueID.
+func (*globalUniqueIDProvider) UniqueID() uint64 {
+ return atomic.AddUint64(&globalUniqueID, 1)
+}
+
+// lastInotifyCookie is a monotonically increasing counter for generating unique
+// inotify cookies. Must be accessed using atomic ops.
+var lastInotifyCookie uint32
+
+// hostClock implements ktime.Clock.
+type hostClock struct {
+ ktime.WallRateClock
+ ktime.NoClockEvents
+}
+
+// Now implements ktime.Clock.Now.
+func (*hostClock) Now() ktime.Time {
+ return ktime.FromNanoseconds(time.Now().UnixNano())
+}
+
+// RegisterValue registers additional values with this test context. Useful for
+// providing values from external packages that contexttest can't depend on.
+func (t *TestContext) RegisterValue(key, value interface{}) {
+ t.otherValues[key] = value
+}
+
+// Value implements context.Context.
+func (t *TestContext) Value(key interface{}) interface{} {
+ switch key {
+ case auth.CtxCredentials:
+ return t.creds
+ case limits.CtxLimits:
+ return t.l
+ case pgalloc.CtxMemoryFile:
+ return t.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return t
+ case platform.CtxPlatform:
+ return t.platform
+ case uniqueid.CtxGlobalUniqueID:
+ return (*globalUniqueIDProvider).UniqueID(nil)
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return &globalUniqueIDProvider{}
+ case uniqueid.CtxInotifyCookie:
+ return atomic.AddUint32(&lastInotifyCookie, 1)
+ case ktime.CtxRealtimeClock:
+ return &hostClock{}
+ default:
+ if val, ok := t.otherValues[key]; ok {
+ return val
+ }
+ return t.Context.Value(key)
+ }
+}
+
+// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile.
+func (t *TestContext) MemoryFile() *pgalloc.MemoryFile {
+ return t.mf
+}
+
+// RootContext returns a Context that may be used in tests that need root
+// credentials. Uses ptrace as the platform.Platform.
+func RootContext(tb testing.TB) context.Context {
+ return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace()))
+}
+
+// WithCreds returns a copy of ctx carrying creds.
+func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context {
+ return &authContext{ctx, creds}
+}
+
+type authContext struct {
+ context.Context
+ creds *auth.Credentials
+}
+
+// Value implements context.Context.
+func (ac *authContext) Value(key interface{}) interface{} {
+ switch key {
+ case auth.CtxCredentials:
+ return ac.creds
+ default:
+ return ac.Context.Value(key)
+ }
+}
+
+// WithLimitSet returns a copy of ctx carrying l.
+func WithLimitSet(ctx context.Context, l *limits.LimitSet) context.Context {
+ return limitContext{ctx, l}
+}
+
+type limitContext struct {
+ context.Context
+ l *limits.LimitSet
+}
+
+// Value implements context.Context.
+func (lc limitContext) Value(key interface{}) interface{} {
+ switch key {
+ case limits.CtxLimits:
+ return lc.l
+ default:
+ return lc.Context.Value(key)
+ }
+}
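
A hypothetical test sketching the intended usage of this package (the key type and assertions are illustrative; Context requires a working ptrace platform at test time):

package example_test

import (
	"testing"

	"gvisor.dev/gvisor/pkg/sentry/contexttest"
)

type testKey struct{} // hypothetical package-local context key

func TestWithContext(t *testing.T) {
	ctx := contexttest.Context(t)

	// RegisterValue lets the test provide values from packages that
	// contexttest itself cannot depend on.
	ctx.(*contexttest.TestContext).RegisterValue(testKey{}, "hello")

	if got := ctx.Value(testKey{}); got != "hello" {
		t.Errorf("ctx.Value(testKey{}) = %v, want %q", got, "hello")
	}
}
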
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
new file mode 100644
index 000000000..2c5d14be5
--- /dev/null
+++ b/pkg/sentry/control/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "control",
+ srcs = [
+ "control.go",
+ "logging.go",
+ "pprof.go",
+ "proc.go",
+ "state.go",
+ ],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/sentry/fdimport",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/host",
+ "//pkg/sentry/fs/user",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/state",
+ "//pkg/sentry/strace",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sentry/watchdog",
+ "//pkg/sync",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/urpc",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "control_test",
+ size = "small",
+ srcs = ["proc_test.go"],
+ library = ":control",
+ deps = [
+ "//pkg/log",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/usage",
+ ],
+)
diff --git a/pkg/sentry/control/control.go b/pkg/sentry/control/control.go
new file mode 100644
index 000000000..6060b9b4f
--- /dev/null
+++ b/pkg/sentry/control/control.go
@@ -0,0 +1,17 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control contains types that expose control server methods, and can
+// be used to configure and interact with a running sandbox process.
+package control
diff --git a/pkg/sentry/control/logging.go b/pkg/sentry/control/logging.go
new file mode 100644
index 000000000..8a500a515
--- /dev/null
+++ b/pkg/sentry/control/logging.go
@@ -0,0 +1,136 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/strace"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+)
+
+// LoggingArgs are the arguments to use for changing the logging
+// level and strace list.
+type LoggingArgs struct {
+ // SetLevel is a flag used to indicate that we should update
+ // the logging level. We should be able to change the strace
+ // list without affecting the logging level and vice versa.
+ SetLevel bool
+
+ // Level is the log level that will be set if SetLevel is true.
+ Level log.Level
+
+ // SetLogPackets indicates that we should update the log packets flag.
+ SetLogPackets bool
+
+ // LogPackets is the actual value to set for LogPackets.
+ // SetLogPackets must be enabled to indicate that we're changing
+ // the value.
+ LogPackets bool
+
+ // SetStrace is a flag used to indicate that strace related
+ // arguments were passed in.
+ SetStrace bool
+
+ // EnableStrace is a flag from the CLI that specifies whether to
+ // enable strace at all. If this flag is false, then a completely
+ // pristine copy of the syscall table will be swapped in. This
+ // approach is used to remain consistent with an empty strace
+ // whitelist, which means "trace all system calls".
+ EnableStrace bool
+
+ // StraceWhitelist is the whitelist of syscalls to trace to the log. If
+ // this and StraceEventWhitelist are empty, all system calls are traced.
+ StraceWhitelist []string
+
+ // SetEventStrace is a flag used to indicate that event strace
+ // related arguments were passed in.
+ SetEventStrace bool
+
+ // StraceEventWhitelist is the whitelist of syscalls to trace
+ // to event log.
+ StraceEventWhitelist []string
+}
+
+// Logging provides functions related to logging.
+type Logging struct{}
+
+// Change changes the log level and strace arguments. Although this function's
+// signature includes an error, it never actually returns an error; the error
+// is required by the URPC interface. Additionally, it may look odd that this
+// is the only method attached to an empty struct, but that too is part of how
+// URPC dispatches.
+func (l *Logging) Change(args *LoggingArgs, code *int) error {
+ if args.SetLevel {
+ // Logging uses an atomic for the level so this is thread safe.
+ log.SetLevel(args.Level)
+ }
+
+ if args.SetLogPackets {
+ if args.LogPackets {
+ atomic.StoreUint32(&sniffer.LogPackets, 1)
+ } else {
+ atomic.StoreUint32(&sniffer.LogPackets, 0)
+ }
+ log.Infof("LogPackets set to: %v", atomic.LoadUint32(&sniffer.LogPackets))
+ }
+
+ if args.SetStrace {
+ if err := l.configureStrace(args); err != nil {
+ return fmt.Errorf("error configuring strace: %v", err)
+ }
+ }
+
+ if args.SetEventStrace {
+ if err := l.configureEventStrace(args); err != nil {
+ return fmt.Errorf("error configuring event strace: %v", err)
+ }
+ }
+
+ return nil
+}
+
+func (l *Logging) configureStrace(args *LoggingArgs) error {
+ if args.EnableStrace {
+ // Install the whitelist specified.
+ if len(args.StraceWhitelist) > 0 {
+ if err := strace.Enable(args.StraceWhitelist, strace.SinkTypeLog); err != nil {
+ return err
+ }
+ } else {
+ // For convenience, if strace is enabled but whitelist
+ // is empty, enable everything to log.
+ strace.EnableAll(strace.SinkTypeLog)
+ }
+ } else {
+ // Uninstall all strace functions.
+ strace.Disable(strace.SinkTypeLog)
+ }
+ return nil
+}
+
+func (l *Logging) configureEventStrace(args *LoggingArgs) error {
+ if len(args.StraceEventWhitelist) > 0 {
+ if err := strace.Enable(args.StraceEventWhitelist, strace.SinkTypeEvent); err != nil {
+ return err
+ }
+ } else {
+ strace.Disable(strace.SinkTypeEvent)
+ }
+ return nil
+}
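
A sketch of how a caller might populate LoggingArgs (the urpc transport is omitted; calling Change directly as below works in-process): only the fields whose Set* flag is true take effect, so one request can change the log level and strace list without touching packet logging.

package main

import (
	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/sentry/control"
)

func main() {
	args := control.LoggingArgs{
		// Raise the log level...
		SetLevel: true,
		Level:    log.Debug,

		// ...and trace only these syscalls to the log sink, leaving
		// packet logging alone (SetLogPackets stays false).
		SetStrace:       true,
		EnableStrace:    true,
		StraceWhitelist: []string{"openat", "read", "write"},
	}

	var code int
	l := &control.Logging{}
	if err := l.Change(&args, &code); err != nil {
		panic(err) // Change never actually returns an error
	}
}
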
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
new file mode 100644
index 000000000..663e51989
--- /dev/null
+++ b/pkg/sentry/control/pprof.go
@@ -0,0 +1,209 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "errors"
+ "runtime"
+ "runtime/pprof"
+ "runtime/trace"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+var errNoOutput = errors.New("no output writer provided")
+
+// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call.
+type ProfileOpts struct {
+ // File is the filesystem path for the profile.
+ File string `json:"path"`
+
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+}
+
+// Profile includes profile-related RPC stubs. It provides a way to
+// control the built-in pprof facility in sentry via sentryctl.
+//
+// The following options to sentryctl are added:
+//
+// - collect CPU profile on-demand.
+// sentryctl -pid <pid> pprof-cpu-start
+// sentryctl -pid <pid> pprof-cpu-stop
+//
+// - dump out the stack trace of current go routines.
+// sentryctl -pid <pid> pprof-goroutine
+type Profile struct {
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // cpuFile is the current CPU profile output file.
+ cpuFile *fd.FD
+
+ // traceFile is the current execution trace output file.
+ traceFile *fd.FD
+
+ // Kernel is the kernel under profile.
+ Kernel *kernel.Kernel
+}
+
+// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
+// file.
+func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+
+ output, err := fd.NewFromFile(o.FilePayload.Files[0])
+ if err != nil {
+ return err
+ }
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Returns an error if profiling is already started.
+ if err := pprof.StartCPUProfile(output); err != nil {
+ output.Close()
+ return err
+ }
+
+ p.cpuFile = output
+ return nil
+}
+
+// StopCPUProfile is an RPC stub which stops CPU profiling and flushes out the
+// profile data. It takes no arguments.
+func (p *Profile) StopCPUProfile(_, _ *struct{}) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.cpuFile == nil {
+ return errors.New("CPU profiling not started")
+ }
+
+ pprof.StopCPUProfile()
+ p.cpuFile.Close()
+ p.cpuFile = nil
+ return nil
+}
+
+// HeapProfile generates a heap profile for the sentry.
+func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ runtime.GC() // Get up-to-date statistics.
+ if err := pprof.WriteHeapProfile(output); err != nil {
+ return err
+ }
+ return nil
+}
+
+// GoroutineProfile is an RPC stub which dumps out the stack trace for all
+// running goroutines.
+func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil {
+ return err
+ }
+ return nil
+}
+
+// BlockProfile is an RPC stub which dumps out the stack trace that led to
+// blocking on synchronization primitives.
+func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("block").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
+// MutexProfile is an RPC stub which dumps out the stack trace of holders of
+// contended mutexes.
+func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
+// StartTrace is an RPC stub which starts collection of an execution trace.
+func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+
+ output, err := fd.NewFromFile(o.FilePayload.Files[0])
+ if err != nil {
+ return err
+ }
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Returns an error if profiling is already started.
+ if err := trace.Start(output); err != nil {
+ output.Close()
+ return err
+ }
+
+ // Ensure all trace contexts are registered.
+ p.Kernel.RebuildTraceContexts()
+
+ p.traceFile = output
+ return nil
+}
+
+// StopTrace is an RPC stub which stops collection of an ongoing execution
+// trace and flushes the trace data. It takes no argument.
+func (p *Profile) StopTrace(_, _ *struct{}) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.traceFile == nil {
+ return errors.New("Execution tracing not started")
+ }
+
+ // As in StartTrace, if tasks have not ended their traces, we will lose
+ // information. Thus we need to rebuild the task trace contexts in order
+ // to have complete information. This does not lose information if
+ // multiple traces overlap.
+ p.Kernel.RebuildTraceContexts()
+
+ trace.Stop()
+ p.traceFile.Close()
+ p.traceFile = nil
+ return nil
+}
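
A sketch of driving Profile directly in-process (normally these stubs are reached over urpc via sentryctl, as listed above; Kernel is left unset here since the CPU-profile path does not use it, and the output path is hypothetical):

package main

import (
	"os"

	"gvisor.dev/gvisor/pkg/sentry/control"
	"gvisor.dev/gvisor/pkg/urpc"
)

func main() {
	out, err := os.Create("/tmp/sentry.pprof")
	if err != nil {
		panic(err)
	}

	p := &control.Profile{}
	opts := control.ProfileOpts{
		File:        "/tmp/sentry.pprof",
		FilePayload: urpc.FilePayload{Files: []*os.File{out}},
	}

	// StartCPUProfile dups the donated FD and holds it until Stop;
	// it errors if profiling is already started.
	if err := p.StartCPUProfile(&opts, nil); err != nil {
		panic(err)
	}
	// ... the workload of interest runs here ...
	if err := p.StopCPUProfile(nil, nil); err != nil {
		panic(err)
	}
}
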
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
new file mode 100644
index 000000000..1bae7cfaf
--- /dev/null
+++ b/pkg/sentry/control/proc.go
@@ -0,0 +1,416 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "sort"
+ "strings"
+ "text/tabwriter"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fdimport"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// Proc includes task-related functions.
+//
+// At the moment, this is limited to exec support.
+type Proc struct {
+ Kernel *kernel.Kernel
+}
+
+// ExecArgs is the set of arguments to exec.
+type ExecArgs struct {
+ // Filename is the filename to load.
+ //
+ // If this is provided as "", then the file will be guessed via Argv[0].
+ Filename string `json:"filename"`
+
+ // Argv is a list of arguments.
+ Argv []string `json:"argv"`
+
+ // Envv is a list of environment variables.
+ Envv []string `json:"envv"`
+
+ // MountNamespace is the mount namespace to execute the new process in.
+ // A reference on MountNamespace must be held for the lifetime of the
+ // ExecArgs. If MountNamespace is nil, it will default to the init
+ // process's MountNamespace.
+ MountNamespace *fs.MountNamespace
+
+ // MountNamespaceVFS2 is the mount namespace to execute the new process in.
+ // A reference on MountNamespace must be held for the lifetime of the
+ // ExecArgs. If MountNamespace is nil, it will default to the init
+ // process's MountNamespace.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
+ // WorkingDirectory defines the working directory for the new process.
+ WorkingDirectory string `json:"wd"`
+
+ // KUID is the UID to run with in the root user namespace. Defaults to
+ // root if not set explicitly.
+ KUID auth.KUID
+
+ // KGID is the GID to run with in the root user namespace. Defaults to
+ // the root group if not set explicitly.
+ KGID auth.KGID
+
+ // ExtraKGIDs is the list of additional groups to which the user belongs.
+ ExtraKGIDs []auth.KGID
+
+ // Capabilities is the list of capabilities to give to the process.
+ Capabilities *auth.TaskCapabilities
+
+ // StdioIsPty indicates that FDs 0, 1, and 2 are connected to a host pty FD.
+ StdioIsPty bool
+
+ // FilePayload determines the files to give to the new process.
+ urpc.FilePayload
+
+ // ContainerID is the container for the process being executed.
+ ContainerID string
+
+ // PIDNamespace is the pid namespace for the process being executed.
+ PIDNamespace *kernel.PIDNamespace
+}
+
+// String prints the arguments as a string.
+func (args ExecArgs) String() string {
+ if len(args.Argv) == 0 {
+ return args.Filename
+ }
+ a := make([]string, len(args.Argv))
+ copy(a, args.Argv)
+ if args.Filename != "" {
+ a[0] = args.Filename
+ }
+ return strings.Join(a, " ")
+}
+
+// Exec runs a new task.
+func (proc *Proc) Exec(args *ExecArgs, waitStatus *uint32) error {
+ newTG, _, _, _, err := proc.execAsync(args)
+ if err != nil {
+ return err
+ }
+
+ // Wait for completion.
+ newTG.WaitExited()
+ *waitStatus = newTG.ExitStatus().Status()
+ return nil
+}
+
+// ExecAsync runs a new task, but doesn't wait for it to finish. It is defined
+// as a function rather than a method to avoid exposing execAsync as an RPC.
+func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ return proc.execAsync(args)
+}
+
+// execAsync runs a new task, but doesn't wait for it to finish. It returns the
+// newly created thread group and its PID. If the stdio FDs are TTYs, then a
+// TTYFileOperations that wraps the TTY is also returned.
+func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ // Import file descriptors.
+ fdTable := proc.Kernel.NewFDTable()
+ defer fdTable.DecRef()
+
+ creds := auth.NewUserCredentials(
+ args.KUID,
+ args.KGID,
+ args.ExtraKGIDs,
+ args.Capabilities,
+ proc.Kernel.RootUserNamespace())
+
+ initArgs := kernel.CreateProcessArgs{
+ Filename: args.Filename,
+ Argv: args.Argv,
+ Envv: args.Envv,
+ WorkingDirectory: args.WorkingDirectory,
+ MountNamespace: args.MountNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
+ Credentials: creds,
+ FDTable: fdTable,
+ Umask: 0022,
+ Limits: limits.NewLimitSet(),
+ MaxSymlinkTraversals: linux.MaxSymlinkTraversals,
+ UTSNamespace: proc.Kernel.RootUTSNamespace(),
+ IPCNamespace: proc.Kernel.RootIPCNamespace(),
+ AbstractSocketNamespace: proc.Kernel.RootAbstractSocketNamespace(),
+ ContainerID: args.ContainerID,
+ PIDNamespace: args.PIDNamespace,
+ }
+ if initArgs.MountNamespace != nil {
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespace.IncRef()
+ }
+ if initArgs.MountNamespaceVFS2 != nil {
+ // initArgs must hold a reference on MountNamespaceVFS2, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespaceVFS2.IncRef()
+ }
+ ctx := initArgs.NewContext(proc.Kernel)
+
+ if kernel.VFS2Enabled {
+ // Get the full path to the filename from the PATH env variable.
+ if initArgs.MountNamespaceVFS2 == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ //
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ }
+ } else {
+ if initArgs.MountNamespace == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
+
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespace.IncRef()
+ }
+ }
+ resolved, err := user.ResolveExecutablePath(ctx, &initArgs)
+ if err != nil {
+ return nil, 0, nil, nil, err
+ }
+ initArgs.Filename = resolved
+
+ fds := make([]int, len(args.FilePayload.Files))
+ for i, file := range args.FilePayload.Files {
+ if kernel.VFS2Enabled {
+ // Need to dup to remove ownership from os.File.
+ dup, err := unix.Dup(int(file.Fd()))
+ if err != nil {
+ return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err)
+ }
+ fds[i] = dup
+ } else {
+ // VFS1 dups the file on import.
+ fds[i] = int(file.Fd())
+ }
+ }
+ ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, args.StdioIsPty, fds)
+ if err != nil {
+ if kernel.VFS2Enabled {
+ for _, fd := range fds {
+ unix.Close(fd)
+ }
+ }
+ return nil, 0, nil, nil, err
+ }
+
+ tg, tid, err := proc.Kernel.CreateProcess(initArgs)
+ if err != nil {
+ return nil, 0, nil, nil, err
+ }
+
+ // Set the foreground process group on the TTY before starting the process.
+ switch {
+ case ttyFile != nil:
+ ttyFile.InitForegroundProcessGroup(tg.ProcessGroup())
+ case ttyFileVFS2 != nil:
+ ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup())
+ }
+
+ // Start the newly created process.
+ proc.Kernel.StartProcess(tg)
+
+ return tg, tid, ttyFile, ttyFileVFS2, nil
+}
+
+// PsArgs is the set of arguments to ps.
+type PsArgs struct {
+ // JSON will force calls to Ps to return the result as a JSON payload.
+ JSON bool
+}
+
+// Ps provides a process listing for the running kernel.
+func (proc *Proc) Ps(args *PsArgs, out *string) error {
+ var p []*Process
+ if e := Processes(proc.Kernel, "", &p); e != nil {
+ return e
+ }
+ if !args.JSON {
+ *out = ProcessListToTable(p)
+ } else {
+ s, e := ProcessListToJSON(p)
+ if e != nil {
+ return e
+ }
+ *out = s
+ }
+ return nil
+}
+
+// Process contains information about a single process in a Sandbox.
+type Process struct {
+ UID auth.KUID `json:"uid"`
+ PID kernel.ThreadID `json:"pid"`
+ // Parent PID
+ PPID kernel.ThreadID `json:"ppid"`
+ Threads []kernel.ThreadID `json:"threads"`
+ // Processor utilization
+ C int32 `json:"c"`
+ // TTY name of the process. Will be of the form "pts/N" if there is a
+ // TTY, or "?" if there is not.
+ TTY string `json:"tty"`
+ // Start time
+ STime string `json:"stime"`
+ // CPU time
+ Time string `json:"time"`
+ // Executable shortname (e.g. "sh" for /bin/sh)
+ Cmd string `json:"cmd"`
+}
+
+// ProcessListToTable prints a table with the following format:
+// UID PID PPID C TTY STIME TIME CMD
+// 0 1 0 0 pts/4 14:04 505262ns tail
+func ProcessListToTable(pl []*Process) string {
+ var buf bytes.Buffer
+ tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0)
+ fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD")
+ for _, d := range pl {
+ fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s",
+ d.UID,
+ d.PID,
+ d.PPID,
+ d.C,
+ d.TTY,
+ d.STime,
+ d.Time,
+ d.Cmd)
+ }
+ tw.Flush()
+ return buf.String()
+}
+
+// ProcessListToJSON will return the JSON representation of ps.
+func ProcessListToJSON(pl []*Process) (string, error) {
+ b, err := json.MarshalIndent(pl, "", " ")
+ if err != nil {
+ return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err)
+ }
+ return string(b), nil
+}
+
+// PrintPIDsJSON prints a JSON object containing only the PIDs in pl. This
+// behavior is the same as runc's.
+func PrintPIDsJSON(pl []*Process) (string, error) {
+ pids := make([]kernel.ThreadID, 0, len(pl))
+ for _, d := range pl {
+ pids = append(pids, d.PID)
+ }
+ b, err := json.Marshal(pids)
+ if err != nil {
+ return "", fmt.Errorf("couldn't marshal PIDs %v: %v", pids, err)
+ }
+ return string(b), nil
+}
+
+// Processes retrieves information about processes running in the sandbox with
+// the given container id. All processes are returned if 'containerID' is empty.
+func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
+ ts := k.TaskSet()
+ now := k.RealtimeClock().Now()
+ for _, tg := range ts.Root.ThreadGroups() {
+ pidns := tg.PIDNamespace()
+ pid := pidns.IDOfThreadGroup(tg)
+
+ // If tg has already been reaped, ignore it.
+ if pid == 0 {
+ continue
+ }
+ if containerID != "" && containerID != tg.Leader().ContainerID() {
+ continue
+ }
+
+ ppid := kernel.ThreadID(0)
+ if p := tg.Leader().Parent(); p != nil {
+ ppid = pidns.IDOfThreadGroup(p.ThreadGroup())
+ }
+ threads := tg.MemberIDs(pidns)
+ *out = append(*out, &Process{
+ UID: tg.Leader().Credentials().EffectiveKUID,
+ PID: pid,
+ PPID: ppid,
+ Threads: threads,
+ STime: formatStartTime(now, tg.Leader().StartTime()),
+ C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
+ Time: tg.CPUStats().SysTime.String(),
+ Cmd: tg.Leader().Name(),
+ TTY: ttyName(tg.TTY()),
+ })
+ }
+ sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID })
+ return nil
+}
+
+// formatStartTime formats startTime depending on the current time:
+// - If startTime was today, HH:MM is used.
+// - If startTime was not today but was this year, MonDD is used (e.g. Jan02)
+// - If startTime was not this year, the year is used.
+func formatStartTime(now, startTime ktime.Time) string {
+ nowS, nowNs := now.Unix()
+ n := time.Unix(nowS, nowNs)
+ startTimeS, startTimeNs := startTime.Unix()
+ st := time.Unix(startTimeS, startTimeNs)
+ format := "15:04"
+ if st.YearDay() != n.YearDay() {
+ format = "Jan02"
+ }
+ if st.Year() != n.Year() {
+ format = "2006"
+ }
+ return st.Format(format)
+}
+
+func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 {
+ // Note: In procps, there is an option to include child CPU stats. As
+ // it is disabled by default, we do not include them.
+ total := stats.UserTime + stats.SysTime
+ lifetime := now.Sub(startTime)
+ if lifetime <= 0 {
+ return 0
+ }
+ percentCPU := total * 100 / lifetime
+ // Cap at 99% since procps does the same.
+ if percentCPU > 99 {
+ percentCPU = 99
+ }
+ return int32(percentCPU)
+}
+
+func ttyName(tty *kernel.TTY) string {
+ if tty == nil {
+ return "?"
+ }
+ return fmt.Sprintf("pts/%d", tty.Index)
+}
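Taken together, these helpers render the same process list either as a human-readable table or as runc-style JSON. A minimal sketch of a caller (illustrative only: in-tree this package has `//pkg/sentry:internal` visibility, and the field values below are made up):

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sentry/control"
)

func main() {
	pl := []*control.Process{
		{PID: 1, PPID: 0, C: 0, TTY: "?", STime: "14:04", Time: "10ms", Cmd: "init"},
		{PID: 42, PPID: 1, C: 1, TTY: "pts/4", STime: "14:04", Time: "505262ns", Cmd: "tail"},
	}

	// Human-readable table with aligned columns.
	fmt.Println(control.ProcessListToTable(pl))

	// PID-only JSON array, matching runc's behavior: prints `[1,42]`.
	pids, err := control.PrintPIDsJSON(pl)
	if err != nil {
		panic(err)
	}
	fmt.Println(pids)
}
```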
diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go
new file mode 100644
index 000000000..0a88459b2
--- /dev/null
+++ b/pkg/sentry/control/proc_test.go
@@ -0,0 +1,166 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/log"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+func init() {
+ log.SetLevel(log.Debug)
+}
+
+// Tests that ProcessListToTable prints the process list with the correct format.
+func TestProcessListTable(t *testing.T) {
+ testCases := []struct {
+ pl []*Process
+ expected string
+ }{
+ {
+ pl: []*Process{},
+ expected: "UID PID PPID C TTY STIME TIME CMD",
+ },
+ {
+ pl: []*Process{
+ {
+ UID: 0,
+ PID: 0,
+ PPID: 0,
+ C: 0,
+ TTY: "?",
+ STime: "0",
+ Time: "0",
+ Cmd: "zero",
+ },
+ {
+ UID: 1,
+ PID: 1,
+ PPID: 1,
+ C: 1,
+ TTY: "pts/4",
+ STime: "1",
+ Time: "1",
+ Cmd: "one",
+ },
+ },
+ expected: `UID       PID       PPID      C         TTY       STIME     TIME      CMD
+0         0         0         0         ?         0         0         zero
+1         1         1         1         pts/4     1         1         one`,
+ },
+ }
+
+ for _, tc := range testCases {
+ output := ProcessListToTable(tc.pl)
+
+ if tc.expected != output {
+ t.Errorf("PrintTable(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected)
+ }
+ }
+}
+
+func TestProcessListJSON(t *testing.T) {
+ testCases := []struct {
+ pl []*Process
+ expected string
+ }{
+ {
+ pl: []*Process{},
+ expected: "[]",
+ },
+ {
+ pl: []*Process{
+ {
+ UID: 0,
+ PID: 0,
+ PPID: 0,
+ C: 0,
+ STime: "0",
+ Time: "0",
+ Cmd: "zero",
+ },
+ {
+ UID: 1,
+ PID: 1,
+ PPID: 1,
+ C: 1,
+ STime: "1",
+ Time: "1",
+ Cmd: "one",
+ },
+ },
+ expected: "[0,1]",
+ },
+ }
+
+ for _, tc := range testCases {
+ output, err := PrintPIDsJSON(tc.pl)
+ if err != nil {
+ t.Errorf("failed to generate JSON: %v", err)
+ }
+
+ if tc.expected != output {
+ t.Errorf("PrintJSON(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected)
+ }
+ }
+}
+
+func TestPercentCPU(t *testing.T) {
+ testCases := []struct {
+ stats usage.CPUStats
+ startTime ktime.Time
+ now ktime.Time
+ expected int32
+ }{
+ {
+ // Verify that 100% use is capped at 99.
+ stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9},
+ startTime: ktime.FromNanoseconds(7e9),
+ now: ktime.FromNanoseconds(9e9),
+ expected: 99,
+ },
+ {
+ // Verify that if usage > lifetime, we get at most 99%
+ // usage.
+ stats: usage.CPUStats{UserTime: 2e9, SysTime: 2e9},
+ startTime: ktime.FromNanoseconds(7e9),
+ now: ktime.FromNanoseconds(9e9),
+ expected: 99,
+ },
+ {
+ // Verify that 50% usage is reported correctly.
+ stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9},
+ startTime: ktime.FromNanoseconds(12e9),
+ now: ktime.FromNanoseconds(16e9),
+ expected: 50,
+ },
+ {
+ // Verify that 0% usage is reported correctly.
+ stats: usage.CPUStats{UserTime: 0, SysTime: 0},
+ startTime: ktime.FromNanoseconds(12e9),
+ now: ktime.FromNanoseconds(14e9),
+ expected: 0,
+ },
+ }
+
+ for _, tc := range testCases {
+ if pcpu := percentCPU(tc.stats, tc.startTime, tc.now); pcpu != tc.expected {
+ t.Errorf("percentCPU(%v, %v, %v): got %d, want %d", tc.stats, tc.startTime, tc.now, pcpu, tc.expected)
+ }
+ }
+}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
new file mode 100644
index 000000000..41feeffe3
--- /dev/null
+++ b/pkg/sentry/control/state.go
@@ -0,0 +1,73 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "errors"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/state"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// ErrInvalidFiles is returned when the urpc call to Save does not include an
+// appropriate file payload (e.g. there is no output file!).
+var ErrInvalidFiles = errors.New("exactly one file must be provided")
+
+// State includes state-related functions.
+type State struct {
+ Kernel *kernel.Kernel
+ Watchdog *watchdog.Watchdog
+}
+
+// SaveOpts contains options for the Save RPC call.
+type SaveOpts struct {
+ // Key is used for state integrity check.
+ Key []byte `json:"key"`
+
+ // Metadata is the set of metadata to prepend to the state file.
+ Metadata map[string]string `json:"metadata"`
+
+ // FilePayload contains the destination for the state.
+ urpc.FilePayload
+}
+
+// Save saves the running system.
+func (s *State) Save(o *SaveOpts, _ *struct{}) error {
+ // Create an output stream.
+ if len(o.FilePayload.Files) != 1 {
+ return ErrInvalidFiles
+ }
+ defer o.FilePayload.Files[0].Close()
+
+ // Save to the first provided stream.
+ saveOpts := state.SaveOpts{
+ Destination: o.FilePayload.Files[0],
+ Key: o.Key,
+ Metadata: o.Metadata,
+ Callback: func(err error) {
+ if err == nil {
+ log.Infof("Save succeeded: exiting...")
+ } else {
+ log.Warningf("Save failed: exiting...")
+ s.Kernel.SetSaveError(err)
+ }
+ s.Kernel.Kill(kernel.ExitStatus{})
+ },
+ }
+ return saveOpts.Save(s.Kernel, s.Watchdog)
+}
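For context, a hedged sketch of how a control client might drive this RPC. The connection setup, the `client` variable, and the registered object name `"State"` are assumptions not shown in this file; the destination file travels as a urpc file payload:

```go
// checkpoint asks the sentry to save itself to the given path. It assumes
// `client` is an already-connected *urpc.Client and that a State object was
// registered with the controller under the name "State".
func checkpoint(client *urpc.Client, path string) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	opts := control.SaveOpts{
		Metadata: map[string]string{"checkpoint": "demo"}, // illustrative only
	}
	opts.FilePayload.Files = []*os.File{f}

	// Exactly one file must be attached, or Save returns ErrInvalidFiles.
	return client.Call("State.Save", &opts, nil)
}
```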
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
new file mode 100644
index 000000000..e403cbd8b
--- /dev/null
+++ b/pkg/sentry/device/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "device",
+ srcs = ["device.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "device_test",
+ size = "small",
+ srcs = ["device_test.go"],
+ library = ":device",
+)
diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go
new file mode 100644
index 000000000..f45b2bd2b
--- /dev/null
+++ b/pkg/sentry/device/device.go
@@ -0,0 +1,269 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package device defines reserved virtual kernel devices and structures
+// for managing them.
+package device
+
+import (
+ "bytes"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Registry tracks all simple devices and related state on the system for
+// save/restore.
+//
+// The set of devices across save/restore must remain consistent. That is, no
+// devices may be created or removed on restore relative to the saved
+// system. Practically, this means do not create new devices specifically as
+// part of restore.
+//
+// +stateify savable
+type Registry struct {
+ // lastAnonDeviceMinor is the last minor device number used for an anonymous
+ // device. Must be accessed atomically.
+ lastAnonDeviceMinor uint64
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ devices map[ID]*Device
+}
+
+// SimpleDevices is the system-wide simple device registry. This is
+// saved/restored by kernel.Kernel, but defined here to allow access without
+// depending on the kernel package. See kernel.Kernel.deviceRegistry.
+var SimpleDevices = newRegistry()
+
+func newRegistry() *Registry {
+ return &Registry{
+ devices: make(map[ID]*Device),
+ }
+}
+
+// newAnonID assigns a major and minor number to an anonymous device ID.
+func (r *Registry) newAnonID() ID {
+ return ID{
+ // Anon devices always have a major number of 0.
+ Major: 0,
+ // Use the next minor number.
+ Minor: atomic.AddUint64(&r.lastAnonDeviceMinor, 1),
+ }
+}
+
+// newAnonDevice allocates a new anonymous device with a unique minor device
+// number, and registers it with r.
+func (r *Registry) newAnonDevice() *Device {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ d := &Device{
+ ID: r.newAnonID(),
+ }
+ r.devices[d.ID] = d
+ return d
+}
+
+// LoadFrom initializes the internal state of all devices in r from other. The
+// set of devices in both registries must match. Devices may not be created or
+// destroyed across save/restore.
+func (r *Registry) LoadFrom(other *Registry) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ other.mu.Lock()
+ defer other.mu.Unlock()
+ if len(r.devices) != len(other.devices) {
+ panic(fmt.Sprintf("Devices were added or removed when restoring the registry:\nnew:\n%+v\nold:\n%+v", r.devices, other.devices))
+ }
+ for id, otherD := range other.devices {
+ ourD, ok := r.devices[id]
+ if !ok {
+ panic(fmt.Sprintf("Device %+v could not be restored as it wasn't defined in the new registry", otherD))
+ }
+ ourD.loadFrom(otherD)
+ }
+ atomic.StoreUint64(&r.lastAnonDeviceMinor, atomic.LoadUint64(&other.lastAnonDeviceMinor))
+}
+
+// ID identifies a device.
+//
+// +stateify savable
+type ID struct {
+ Major uint64
+ Minor uint64
+}
+
+// DeviceID formats a major and minor device number into a standard device number.
+func (i *ID) DeviceID() uint64 {
+ return uint64(linux.MakeDeviceID(uint16(i.Major), uint32(i.Minor)))
+}
+
+// NewAnonDevice creates a new anonymous device. Packages that require an anonymous
+// device should initialize the device in a global variable in a file called device.go:
+//
+// var myDevice = device.NewAnonDevice()
+func NewAnonDevice() *Device {
+ return SimpleDevices.newAnonDevice()
+}
+
+// NewAnonMultiDevice creates a new multi-keyed anonymous device. Packages that require
+// a multi-key anonymous device should initialize the device in a global variable in a
+// file called device.go:
+//
+// var myDevice = device.NewAnonMultiDevice()
+func NewAnonMultiDevice() *MultiDevice {
+ return &MultiDevice{
+ ID: SimpleDevices.newAnonID(),
+ }
+}
+
+// Device is a simple virtual kernel device.
+//
+// +stateify savable
+type Device struct {
+ ID
+
+ // last is the last generated inode.
+ last uint64
+}
+
+// loadFrom initializes d from other. The IDs of both devices must match.
+func (d *Device) loadFrom(other *Device) {
+ if d.ID != other.ID {
+ panic(fmt.Sprintf("Attempting to initialize a device %+v from %+v, but device IDs don't match", d, other))
+ }
+ atomic.StoreUint64(&d.last, atomic.LoadUint64(&other.last))
+}
+
+// NextIno generates a new inode number.
+func (d *Device) NextIno() uint64 {
+ return atomic.AddUint64(&d.last, 1)
+}
+
+// MultiDeviceKey provides a hashable key for a MultiDevice. The key consists
+// of a raw device and inode for a resource, which must consistently identify
+// the unique resource. It may optionally include a secondary device if
+// appropriate.
+//
+// Note that using the path is not enough, because filesystems may rename a file
+// to a different backing resource, at which point the path points to a different
+// entity. Using only the inode is also not enough because the inode is assumed
+// to be unique only within the device on which the resource exists.
+type MultiDeviceKey struct {
+ Device uint64
+ SecondaryDevice string
+ Inode uint64
+}
+
+// String stringifies the key.
+func (m MultiDeviceKey) String() string {
+ return fmt.Sprintf("key{device: %d, sdevice: %s, inode: %d}", m.Device, m.SecondaryDevice, m.Inode)
+}
+
+// MultiDevice allows for remapping resources that come from a variety of raw
+// devices into a single device. The device ID should be one of the static
+// Device IDs above and cannot be reused.
+type MultiDevice struct {
+ ID
+
+ mu sync.Mutex
+ last uint64
+ cache map[MultiDeviceKey]uint64
+ rcache map[uint64]MultiDeviceKey
+}
+
+// String stringifies MultiDevice.
+func (m *MultiDevice) String() string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ buf := bytes.NewBuffer(nil)
+ buf.WriteString("cache{")
+ for k, v := range m.cache {
+ buf.WriteString(fmt.Sprintf("%s -> %d, ", k, v))
+ }
+ buf.WriteString("}")
+ return buf.String()
+}
+
+// Map maps a raw device and inode into the inode space of MultiDevice,
+// returning a virtualized inode. Raw devices and inodes can be reused;
+// in this case, the same virtual inode will be returned.
+func (m *MultiDevice) Map(key MultiDeviceKey) uint64 {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.cache == nil {
+ m.cache = make(map[MultiDeviceKey]uint64)
+ m.rcache = make(map[uint64]MultiDeviceKey)
+ }
+
+ id, ok := m.cache[key]
+ if ok {
+ return id
+ }
+ // Step over reserved entries that may have been loaded.
+ idx := m.last + 1
+ for {
+ if _, ok := m.rcache[idx]; !ok {
+ break
+ }
+ idx++
+ }
+ // We found a non-reserved entry, use it.
+ m.last = idx
+ m.cache[key] = m.last
+ m.rcache[m.last] = key
+ return m.last
+}
+
+// Load loads a raw device and inode into MultiDevice inode mappings
+// with value as the virtual inode.
+//
+// By design, inodes start from 1 and continue until max uint64. This means
+// that the zero value, which is often the uninitialized value, can be rejected
+// as invalid.
+func (m *MultiDevice) Load(key MultiDeviceKey, value uint64) bool {
+ // Reject the uninitialized value; see comment above.
+ if value == 0 {
+ return false
+ }
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.cache == nil {
+ m.cache = make(map[MultiDeviceKey]uint64)
+ m.rcache = make(map[uint64]MultiDeviceKey)
+ }
+
+ if val, exists := m.cache[key]; exists && val != value {
+ return false
+ }
+ if k, exists := m.rcache[value]; exists && k != key {
+ // Should never happen.
+ panic("MultiDevice's caches are inconsistent")
+ }
+
+ // Cache value at key.
+ m.cache[key] = value
+
+ // Prevent value from being used by new inode mappings.
+ m.rcache[value] = key
+
+ return true
+}
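A short sketch of the intended usage patterns, using only the APIs defined above (the package name, key values, and helper functions here are hypothetical):

```go
package myfs // hypothetical filesystem package

import "gvisor.dev/gvisor/pkg/sentry/device"

// Per the package convention, the anonymous device lives at package scope
// in a file called device.go.
var myDevice = device.NewAnonDevice()

// newInodeNumber hands out a fresh inode number on myDevice.
func newInodeNumber() uint64 {
	return myDevice.NextIno()
}

// A multi-keyed device remaps (raw device, inode) pairs from several
// backing filesystems into one virtual inode space.
var myMultiDevice = device.NewAnonMultiDevice()

// virtualInode returns a stable virtual inode for the given raw identity;
// repeated calls with the same key return the same value.
func virtualInode(rawDev, rawIno uint64) uint64 {
	return myMultiDevice.Map(device.MultiDeviceKey{
		Device: rawDev,
		Inode:  rawIno,
	})
}
```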
diff --git a/pkg/sentry/device/device_test.go b/pkg/sentry/device/device_test.go
new file mode 100644
index 000000000..e3f51ce4f
--- /dev/null
+++ b/pkg/sentry/device/device_test.go
@@ -0,0 +1,59 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package device
+
+import (
+ "testing"
+)
+
+func TestMultiDevice(t *testing.T) {
+ device := &MultiDevice{}
+
+ // Check that Load fails to install virtual inodes that are
+ // uninitialized.
+ if device.Load(MultiDeviceKey{}, 0) {
+ t.Fatalf("got load of invalid virtual inode 0, want unsuccessful")
+ }
+
+ inode := device.Map(MultiDeviceKey{})
+
+ // Assert that the same raw device and inode map to
+ // a consistent virtual inode.
+ if i := device.Map(MultiDeviceKey{}); i != inode {
+ t.Fatalf("got inode %d, want %d in %s", i, inode, device)
+ }
+
+ // Assert that a new inode or new device does not conflict.
+ if i := device.Map(MultiDeviceKey{Device: 0, Inode: 1}); i == inode {
+ t.Fatalf("got reused inode %d, want new distinct inode in %s", i, device)
+ }
+ last := device.Map(MultiDeviceKey{Device: 1, Inode: 0})
+ if last == inode {
+ t.Fatalf("got reused inode %d, want new distinct inode in %s", last, device)
+ }
+
+ // Virtual is the virtual inode we want to load.
+ virtual := last + 1
+
+ // Assert that we can load a virtual inode at a new place.
+ if !device.Load(MultiDeviceKey{Device: 0, Inode: 2}, virtual) {
+ t.Fatalf("got load of virtual inode %d failed, want success in %s", virtual, device)
+ }
+
+ // Assert that the next inode skips over the loaded one.
+ if i := device.Map(MultiDeviceKey{Device: 0, Inode: 3}); i != virtual+1 {
+ t.Fatalf("got inode %d, want %d in %s", i, virtual+1, device)
+ }
+}
diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD
new file mode 100644
index 000000000..abe58f818
--- /dev/null
+++ b/pkg/sentry/devices/memdev/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "memdev",
+ srcs = [
+ "full.go",
+ "memdev.go",
+ "null.go",
+ "random.go",
+ "zero.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/safemem",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
new file mode 100644
index 000000000..af66fe4dc
--- /dev/null
+++ b/pkg/sentry/devices/memdev/full.go
@@ -0,0 +1,76 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const fullDevMinor = 7
+
+// fullDevice implements vfs.Device for /dev/full.
+type fullDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &fullFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// fullFD implements vfs.FileDescriptionImpl for /dev/full.
+type fullFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *fullFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *fullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSPC
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSPC
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
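The guest-observable behavior matches Linux's /dev/full: reads succeed and return zeroes, while every write fails with ENOSPC. An illustrative snippet as seen from inside the sandbox:

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	f, err := os.OpenFile("/dev/full", os.O_RDWR, 0)
	if err != nil {
		panic(err)
	}
	defer f.Close()

	buf := make([]byte, 8)
	n, _ := f.Read(buf)           // n == 8, buf is all zeroes
	_, werr := f.Write([]byte{1}) // fails with ENOSPC
	fmt.Println(n, werr)
}
```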
diff --git a/pkg/sentry/devices/memdev/memdev.go b/pkg/sentry/devices/memdev/memdev.go
new file mode 100644
index 000000000..5759900c4
--- /dev/null
+++ b/pkg/sentry/devices/memdev/memdev.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memdev implements "mem" character devices, as implemented in Linux
+// by drivers/char/mem.c and drivers/char/random.c.
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ for minor, dev := range map[uint32]vfs.Device{
+ nullDevMinor: nullDevice{},
+ zeroDevMinor: zeroDevice{},
+ fullDevMinor: fullDevice{},
+ randomDevMinor: randomDevice{},
+ urandomDevMinor: randomDevice{},
+ } {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MEM_MAJOR, minor, dev, &vfs.RegisterDeviceOptions{
+ GroupName: "mem",
+ }); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ for minor, name := range map[uint32]string{
+ nullDevMinor: "null",
+ zeroDevMinor: "zero",
+ fullDevMinor: "full",
+ randomDevMinor: "random",
+ urandomDevMinor: "urandom",
+ } {
+ if err := dev.CreateDeviceFile(ctx, name, vfs.CharDevice, linux.MEM_MAJOR, minor, 0666 /* mode */); err != nil {
+ return err
+ }
+ }
+ return nil
+}
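A hedged sketch of how a loader might wire this package up during boot; the surrounding creation of the VFS object and the devtmpfs accessor is assumed to happen elsewhere:

```go
import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/devices/memdev"
	"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// registerMemDevices makes the mem devices openable by device number and
// creates their /dev nodes. vfsObj and dev are assumed to have been set up
// by the caller during boot.
func registerMemDevices(ctx context.Context, vfsObj *vfs.VirtualFilesystem, dev *devtmpfs.Accessor) error {
	if err := memdev.Register(vfsObj); err != nil {
		return err
	}
	return memdev.CreateDevtmpfsFiles(ctx, dev)
}
```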
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
new file mode 100644
index 000000000..92d3d71be
--- /dev/null
+++ b/pkg/sentry/devices/memdev/null.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const nullDevMinor = 3
+
+// nullDevice implements vfs.Device for /dev/null.
+type nullDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &nullFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// nullFD implements vfs.FileDescriptionImpl for /dev/null.
+type nullFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *nullFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *nullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, io.EOF
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *nullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, io.EOF
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *nullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *nullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *nullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
new file mode 100644
index 000000000..6b81da5ef
--- /dev/null
+++ b/pkg/sentry/devices/memdev/random.go
@@ -0,0 +1,93 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ randomDevMinor = 8
+ urandomDevMinor = 9
+)
+
+// randomDevice implements vfs.Device for /dev/random and /dev/urandom.
+type randomDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &randomFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// randomFD implements vfs.FileDescriptionImpl for /dev/random.
+type randomFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // off is the "file offset". off is accessed using atomic memory
+ // operations.
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *randomFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *randomFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *randomFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+ atomic.AddInt64(&fd.off, n)
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *randomFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // In Linux, this mixes the written bytes into the entropy pool; we just
+ // throw them away.
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *randomFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ atomic.AddInt64(&fd.off, src.NumBytes())
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *randomFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Linux: drivers/char/random.c:random_fops.llseek == urandom_fops.llseek
+ // == noop_llseek
+ return atomic.LoadInt64(&fd.off), nil
+}
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
new file mode 100644
index 000000000..c6f15054d
--- /dev/null
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const zeroDevMinor = 5
+
+// zeroDevice implements vfs.Device for /dev/zero.
+type zeroDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &zeroFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// zeroFD implements vfs.FileDescriptionImpl for /dev/zero.
+type zeroFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *zeroFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *zeroFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *zeroFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *zeroFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *zeroFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return err
+ }
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ return nil
+}
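ConfigureMMap is what makes mappings of /dev/zero behave as zero-filled memory backed by a shared anonymous mappable. From a guest application's point of view (an illustrative sketch using golang.org/x/sys/unix):

```go
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.Open("/dev/zero")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// The mapping is zero-filled; with MAP_SHARED it is backed by a shared
	// anonymous mappable rather than the device file itself.
	b, err := unix.Mmap(int(f.Fd()), 0, 4096, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
	if err != nil {
		panic(err)
	}
	defer unix.Munmap(b)

	b[0] = 1
	fmt.Println(b[0], b[1]) // 1 0
}
```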
diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD
new file mode 100644
index 000000000..12e49b58a
--- /dev/null
+++ b/pkg/sentry/devices/ttydev/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "ttydev",
+ srcs = ["ttydev.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go
new file mode 100644
index 000000000..fbb7fd92c
--- /dev/null
+++ b/pkg/sentry/devices/ttydev/ttydev.go
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ttydev implements devices for /dev/tty and (eventually)
+// /dev/console.
+//
+// TODO(b/159623826): Support /dev/console.
+package ttydev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // See drivers/tty/tty_io.c:tty_init().
+ ttyDevMinor = 0
+ consoleDevMinor = 1
+)
+
+// ttyDevice implements vfs.Device for /dev/tty.
+type ttyDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &ttyFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// ttyFD implements vfs.FileDescriptionImpl for /dev/tty.
+type ttyFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *ttyFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, nil
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ return vfsObj.RegisterDevice(vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, ttyDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "tty",
+ })
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ return dev.CreateDeviceFile(ctx, "tty", vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, 0666 /* mode */)
+}
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
new file mode 100644
index 000000000..71c59287c
--- /dev/null
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "tundev",
+ srcs = ["tundev.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip/link/tun",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
new file mode 100644
index 000000000..dfbd069af
--- /dev/null
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -0,0 +1,178 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tundev implements the /dev/net/tun device.
+package tundev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// tunDevice implements vfs.Device for /dev/net/tun.
+type tunDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &tunFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun.
+type tunFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ device tun.Device
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fd.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in drivers/net/tun.c.
+ flags := fd.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *tunFD) Release() {
+ fd.device.Release()
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *tunFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.Read(ctx, dst, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *tunFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ data, err := fd.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *tunFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.Write(ctx, src, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *tunFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fd.device.Write(data)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *tunFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fd.device.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *tunFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *tunFD) EventUnregister(e *waiter.Entry) {
+ fd.device.EventUnregister(e)
+}
+
+// isNetTunSupported returns whether the /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ return vfsObj.RegisterDevice(vfs.CharDevice, netTunDevMajor, netTunDevMinor, tunDevice{}, &vfs.RegisterDeviceOptions{})
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ return dev.CreateDeviceFile(ctx, "net/tun", vfs.CharDevice, netTunDevMajor, netTunDevMinor, 0666 /* mode */)
+}
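For reference, this is roughly how a guest application exercises the TUNSETIFF path implemented above (a sketch: the interface name and flags are illustrative, and the ifReq type mirrors Linux's 40-byte struct ifreq rather than coming from this package):

```go
package main

import (
	"unsafe"

	"golang.org/x/sys/unix"
)

// ifReq mirrors Linux's struct ifreq for TUNSETIFF: a 16-byte interface
// name followed by 16-bit flags, padded to 40 bytes.
type ifReq struct {
	Name  [16]byte
	Flags uint16
	_     [22]byte
}

func main() {
	fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	var req ifReq
	copy(req.Name[:], "tun0")
	req.Flags = unix.IFF_TUN | unix.IFF_NO_PI // requires CAP_NET_ADMIN
	if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, uintptr(unsafe.Pointer(&req))); errno != 0 {
		panic(errno)
	}
	// The fd can now be read and written to receive and inject packets.
}
```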
diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD
new file mode 100644
index 000000000..5e41ceb4e
--- /dev/null
+++ b/pkg/sentry/fdimport/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fdimport",
+ srcs = [
+ "fdimport.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/host",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/vfs",
+ ],
+)
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
new file mode 100644
index 000000000..a4199f9e9
--- /dev/null
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdimport
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Import imports a slice of FDs into the given FDTable. If console is true,
+// it sets up a TTY for the first 3 FDs in the slice, representing stdin,
+// stdout, and stderr. Upon success, Import takes ownership of all FDs.
+func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ if kernel.VFS2Enabled {
+ ttyFile, err := importVFS2(ctx, fdTable, console, fds)
+ return nil, ttyFile, err
+ }
+ ttyFile, err := importFS(ctx, fdTable, console, fds)
+ return ttyFile, nil, err
+}
+
+func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, error) {
+ var ttyFile *fs.File
+ for appFD, hostFD := range fds {
+ var appFile *fs.File
+
+ if console && appFD < 3 {
+ // Import the file as a host TTY file.
+ if ttyFile == nil {
+ var err error
+ appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef()
+
+ // Remember this in the TTY file, as we will
+ // use it for the other stdio FDs.
+ ttyFile = appFile
+ } else {
+ // Re-use the existing TTY file, as all three
+ // stdio FDs must point to the same fs.File in
+ // order to share TTY state, specifically the
+ // foreground process group id.
+ appFile = ttyFile
+ }
+ } else {
+ // Import the file as a regular host file.
+ var err error
+ appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef()
+ }
+
+ // Add the file to the FD map.
+ if err := fdTable.NewFDAt(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
+ return nil, err
+ }
+ }
+
+ if ttyFile == nil {
+ return nil, nil
+ }
+ return ttyFile.FileOperations.(*host.TTYFileOperations), nil
+}
+
+func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) {
+ k := kernel.KernelFromContext(ctx)
+
+ var ttyFile *vfs.FileDescription
+ for appFD, hostFD := range stdioFDs {
+ var appFile *vfs.FileDescription
+
+ if console && appFD < 3 {
+ // Import the file as a host TTY file.
+ if ttyFile == nil {
+ var err error
+ appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, true /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef()
+
+ // Remember this in the TTY file, as we will use it for the other stdio
+ // FDs.
+ ttyFile = appFile
+ } else {
+ // Re-use the existing TTY file, as all three stdio FDs must point to
+ // the same vfs.FileDescription in order to share TTY state,
+ // specifically the foreground process group id.
+ appFile = ttyFile
+ }
+ } else {
+ var err error
+ appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef()
+ }
+
+ if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
+ return nil, err
+ }
+ }
+
+ if ttyFile == nil {
+ return nil, nil
+ }
+ return ttyFile.Impl().(*hostvfs2.TTYFileDescription), nil
+}
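A minimal sketch of a caller, such as a loader, importing the three stdio host FDs as a console; `ctx` and `fdTable` are assumed to have been created elsewhere, and `importStdio` is a hypothetical helper:

```go
package boot // hypothetical caller

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/fdimport"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
)

// importStdio imports host FDs 0-2 into fdTable as the application's stdio,
// configuring them as a console TTY. At most one of the returned TTY values
// is non-nil, depending on whether VFS2 is enabled.
func importStdio(ctx context.Context, fdTable *kernel.FDTable) error {
	ttyOps, ttyVFS2, err := fdimport.Import(ctx, fdTable, true /* console */, []int{0, 1, 2})
	if err != nil {
		return err
	}
	_ = ttyOps  // VFS1: *host.TTYFileOperations
	_ = ttyVFS2 // VFS2: *hostvfs2.TTYFileDescription
	return nil
}
```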
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
new file mode 100644
index 000000000..ea85ab33c
--- /dev/null
+++ b/pkg/sentry/fs/BUILD
@@ -0,0 +1,135 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fs",
+ srcs = [
+ "attr.go",
+ "context.go",
+ "copy_up.go",
+ "dentry.go",
+ "dirent.go",
+ "dirent_cache.go",
+ "dirent_cache_limiter.go",
+ "dirent_list.go",
+ "dirent_state.go",
+ "event_list.go",
+ "file.go",
+ "file_operations.go",
+ "file_overlay.go",
+ "file_state.go",
+ "filesystems.go",
+ "flags.go",
+ "fs.go",
+ "inode.go",
+ "inode_inotify.go",
+ "inode_operations.go",
+ "inode_overlay.go",
+ "inotify.go",
+ "inotify_event.go",
+ "inotify_watch.go",
+ "mock.go",
+ "mount.go",
+ "mount_overlay.go",
+ "mounts.go",
+ "offset.go",
+ "overlay.go",
+ "path.go",
+ "restore.go",
+ "save.go",
+ "seek.go",
+ "splice.go",
+ "sync.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/p9",
+ "//pkg/refs",
+ "//pkg/secio",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/usage",
+ "//pkg/state",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_template_instance(
+ name = "dirent_list",
+ out = "dirent_list.go",
+ package = "fs",
+ prefix = "dirent",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Linker": "*Dirent",
+ "Element": "*Dirent",
+ },
+)
+
+go_template_instance(
+ name = "event_list",
+ out = "event_list.go",
+ package = "fs",
+ prefix = "event",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Linker": "*Event",
+ "Element": "*Event",
+ },
+)
+
+go_test(
+ name = "fs_x_test",
+ size = "small",
+ srcs = [
+ "copy_up_test.go",
+ "file_overlay_test.go",
+ "inode_overlay_test.go",
+ "mounts_test.go",
+ ],
+ deps = [
+ ":fs",
+ "//pkg/context",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "fs_test",
+ size = "small",
+ srcs = [
+ "dirent_cache_test.go",
+ "dirent_refs_test.go",
+ "mount_test.go",
+ "path_test.go",
+ ],
+ library = ":fs",
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ ],
+)
diff --git a/pkg/sentry/fs/README.md b/pkg/sentry/fs/README.md
new file mode 100644
index 000000000..db4a1b730
--- /dev/null
+++ b/pkg/sentry/fs/README.md
@@ -0,0 +1,229 @@
+This package provides an implementation of the Linux virtual filesystem.
+
+[TOC]
+
+## Overview
+
+- An `fs.Dirent` caches an `fs.Inode` in memory at a path in the VFS, giving
+ the `fs.Inode` a relative position with respect to other `fs.Inode`s.
+
+- If an `fs.Dirent` is referenced by two file descriptors, then those file
+ descriptors are coherent with each other: they depend on the same
+ `fs.Inode`.
+
+- A mount point is an `fs.Dirent` for which `fs.Dirent.mounted` is true. It
+ exposes the root of a mounted filesystem.
+
+- The `fs.Inode` produced by a registered filesystem on mount(2) owns an
+ `fs.MountedFilesystem` from which other `fs.Inode`s will be looked up. For a
+ remote filesystem, the `fs.MountedFilesystem` owns the connection to that
+ remote filesystem.
+
+- In general:
+
+```
+fs.Inode <------------------------------
+| |
+| |
+produced by |
+exactly one |
+| responsible for the
+| virtual identity of
+v |
+fs.MountedFilesystem -------------------
+```
+
+Glossary:
+
+- VFS: virtual filesystem.
+
+- inode: a virtual file object holding a cached view of a file on a backing
+ filesystem (includes metadata and page caches).
+
+- superblock: the virtual state of a mounted filesystem (e.g. the virtual
+ inode number set).
+
+- mount namespace: a view of the mounts under a root (during path traversal,
+ the VFS makes visible/follows the mount point that is in the current task's
+ mount namespace).
+
+## Save and restore
+
+An application's hard dependencies on filesystem state can be broken down into
+two categories:
+
+- The state necessary to execute a traversal on or view the *virtual*
+ filesystem hierarchy, regardless of what files an application has open.
+
+- The state necessary to represent open files.
+
+The first is always necessary to save and restore. An application may never have
+any open file descriptors, but across save and restore it should see a coherent
+view of any mount namespace. NOTE(b/63601033): Currently only one "initial"
+mount namespace is supported.
+
+The second is so that system calls across save and restore are coherent with
+each other (e.g. so that unintended re-reads or overwrites do not occur).
+
+Specifically this state is:
+
+- An `fs.MountManager` containing mount points.
+
+- A `kernel.FDTable` containing pointers to open files.
+
+Anything else managed by the VFS that can be easily loaded into memory from a
+filesystem is synced back to those filesystems and is not saved. Examples are
+pages in page caches used for optimizations (i.e. readahead and writeback), and
+directory entries used to accelerate path lookups.
+
+### Mount points
+
+Saving and restoring a mount point means saving and restoring:
+
+- The root of the mounted filesystem.
+
+- Mount flags, which control how the VFS interacts with the mounted
+ filesystem.
+
+- Any relevant metadata about the mounted filesystem.
+
+- All `fs.Inode`s referenced by the application that reside under the mount
+ point.
+
+`fs.MountedFilesystem` is metadata about a filesystem that is mounted. It is
+referenced by every `fs.Inode` loaded into memory under the mount point
+including the `fs.Inode` of the mount point itself. The `fs.MountedFilesystem`
+maps file objects on the filesystem to a virtualized `fs.Inode` number and vice
+versa.
+
+To restore all `fs.Inode`s under a given mount point, each `fs.Inode` leverages
+its dependency on an `fs.MountedFilesystem`. Since the `fs.MountedFilesystem`
+knows how an `fs.Inode` maps to a file object on a backing filesystem, this
+mapping can be trivially consulted by each `fs.Inode` when the `fs.Inode` is
+restored.
+
+In detail, a mount point is saved in two steps:
+
+- First, after the kernel is paused but before state.Save, we walk all mount
+ namespaces and install a mapping from `fs.Inode` numbers to file paths
+ relative to the root of the mounted filesystem in each
+ `fs.MountedFilesystem`. This is subsequently called the set of `fs.Inode`
+ mappings.
+
+- Second, during state.Save, each `fs.MountedFilesystem` decides whether to
+ save the set of `fs.Inode` mappings. In-memory filesystems, like tmpfs, have
+ no need to save a set of `fs.Inode` mappings, since the `fs.Inode`s can be
+ entirely encoded in the state file. Each `fs.MountedFilesystem` also optionally
+ saves the device name from when the filesystem was originally mounted. Each
+ `fs.Inode` saves its virtual identifier and a reference to a
+ `fs.MountedFilesystem`.
+
+A mount point is restored in two steps:
+
+- First, before state.Load, all mount configurations are stored in a global
+ `fs.RestoreEnvironment`. This tells us what mount points the user wants to
+ restore and how to re-establish pointers to backing filesystems.
+
+- Second, during state.Load, each `fs.MountedFilesystem` optionally searches
+ for a mount in the `fs.RestoreEnvironment` that matches its saved device
+ name. The `fs.MountedFilesystem` then reestablishes a pointer to the root of
+ the mounted filesystem. For example, the mount specification provides the
+ network connection for a mounted remote filesystem client to communicate
+ with its remote file server. The `fs.MountedFilesystem` also trivially loads
+ its set of `fs.Inode` mappings. When an `fs.Inode` is encountered, the
+ `fs.Inode` loads its virtual identifier and its reference to a
+ `fs.MountedFilesystem`. It uses the `fs.MountedFilesystem` to obtain the
+ root of the mounted filesystem and the `fs.Inode` mappings to obtain the
+ relative file path to its data. With these, the `fs.Inode` re-establishes a
+ pointer to its file object.
+
+A mount point can trivially restore its `fs.Inode`s in parallel since
+`fs.Inode`s have a restore dependency on their `fs.MountedFilesystem` and not on
+each other.
+
+### Open files
+
+An `fs.File` references the following filesystem objects:
+
+```go
+fs.File -> fs.Dirent -> fs.Inode -> fs.MountedFilesystem
+```
+
+The `fs.Inode` is restored using its `fs.MountedFilesystem`. The
+[Mount points](#mount-points) section above describes how this happens in
+detail. The `fs.Dirent` restores its pointer to an `fs.Inode`, pointers to
+parent and children `fs.Dirents`, and the basename of the file.
+
+Otherwise an `fs.File` restores flags, an offset, and a unique identifier (only
+used internally).
+
+It may use the `fs.Inode`, which it indirectly holds a reference on through the
+`fs.Dirent`, to reestablish an open file handle on the backing filesystem (e.g.
+to continue reading and writing).
+
+## Overlay
+
+The overlay implementation in the fs package takes Linux overlayfs as a frame of
+reference but corrects for several POSIX consistency errors.
+
+In Linux overlayfs, the `struct inode` used for reading and writing to the same
+file may be different. This is because the `struct inode` is dissociated from
+the process of copying up the file from the lower to the upper directory. Since
+flock(2) and fcntl(2) locks, inotify(7) watches, page caches, and a file's
+identity are all stored directly or indirectly off the `struct inode`, these
+properties of the `struct inode` may be stale after the first modification. This
+can lead to file locking bugs, missed inotify events, and inconsistent data in
+shared memory mappings of files, to name a few problems.
+
+The fs package maintains a single `fs.Inode` to represent a directory entry in
+an overlay and defines operations on this `fs.Inode` which synchronize with the
+copy up process. This achieves several things:
+
++ File locks, inotify watches, and the identity of the file need not be copied
+ at all.
+
++ Memory mappings of files coordinate with the copy up process so that if a
+ file in the lower directory is memory mapped, all references to it are
+ invalidated, forcing the application to re-fault on memory mappings of the
+ file under the upper directory.
+
+The `fs.Inode` holds metadata about files in the upper and/or lower directories
+via an `fs.overlayEntry`. The `fs.overlayEntry` implements the `fs.Mappable`
+interface. It multiplexes between upper and lower directory memory mappings and
+stores a copy of memory references so they can be transferred to the upper
+directory `fs.Mappable` when the file is copied up.
+
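+In sketch form, with simplified fields (the actual `fs.overlayEntry` tracks
+more state and finer-grained locks), the multiplexing reduces to:
+
+```go
+// inodeLocked returns the Inode that currently backs memory mappings:
+// the upper Inode once the file has been copied up, and the lower Inode
+// before that. Callers must hold the copy up lock so that the answer
+// cannot change mid-operation.
+func (o *overlayEntry) inodeLocked() *fs.Inode {
+	if o.upper != nil {
+		return o.upper
+	}
+	return o.lower
+}
+```
+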
+The lower filesystem in an overlay may contain another (nested) overlay, but the
+upper filesystem may not contain another overlay. In other words, nested
+overlays form a tree structure that only allows branching in the lower
+filesystem.
+
+Caching decisions in the overlay are delegated to the upper filesystem, meaning
+that the Keep and Revalidate methods on the overlay return the same values as
+the upper filesystem. A small wrinkle is that the lower filesystem is not
+allowed to return `true` from Revalidate, as the overlay cannot reload inodes
+from the lower filesystem. A lower filesystem that does return `true` from
+Revalidate will trigger a panic.
+
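+A sketch of that delegation, with a hypothetical receiver and signature:
+
+```go
+// Revalidate defers to the upper filesystem and enforces that the lower
+// filesystem never requests revalidation, since the overlay cannot
+// reload inodes from the lower filesystem.
+func (o *overlayMountSource) Revalidate(ctx context.Context, name string, parent, child *fs.Inode) bool {
+	if o.lower.Revalidate(ctx, name, parent, child) {
+		panic("an overlay cannot revalidate file objects from the lower filesystem")
+	}
+	return o.upper.Revalidate(ctx, name, parent, child)
+}
+```
+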
+The `fs.Inode` also holds a reference to a `fs.MountedFilesystem` that
+normalizes across the mounted filesystem state of the upper and lower
+directories.
+
+When a file is copied from the lower to the upper directory, attempts to
+interact with the file block until the copy completes. All copying synchronizes
+with rename(2).
+
+## Future Work
+
+### Overlay
+
+When a file is copied from a lower directory to an upper directory, several
+locks are taken: the global renameMu and the copyMu of the `fs.Inode` being
+copied. This blocks operations on the file, including fault handling of memory
+mappings. Performance could be improved by copying files into a temporary
+directory that resides on the same filesystem as the upper directory and doing
+an atomic rename, holding locks only during the rename operation.
+
+Additionally, files are copied up synchronously. For large files, this causes
+noticeable latency. Performance could be improved by pipelining copies at
+non-overlapping file offsets.
diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD
new file mode 100644
index 000000000..aedcecfa1
--- /dev/null
+++ b/pkg/sentry/fs/anon/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "anon",
+ srcs = [
+ "anon.go",
+ "device.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/anon/anon.go b/pkg/sentry/fs/anon/anon.go
new file mode 100644
index 000000000..5c421f5fb
--- /dev/null
+++ b/pkg/sentry/fs/anon/anon.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package anon implements an anonymous inode, useful for implementing
+// inodes for pseudo filesystems.
+package anon
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// NewInode constructs an anonymous Inode that is not associated
+// with any real filesystem. Some types depend on completely pseudo
+// "anon" inodes (eventfds, epollfds, etc).
+func NewInode(ctx context.Context) *fs.Inode {
+ iops := &fsutil.SimpleFileInode{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }, linux.ANON_INODE_FS_MAGIC),
+ }
+ return fs.NewInode(ctx, iops, fs.NewPseudoMountSource(ctx), fs.StableAttr{
+ Type: fs.Anonymous,
+ DeviceID: PseudoDevice.DeviceID(),
+ InodeID: PseudoDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ })
+}
diff --git a/pkg/sentry/fs/anon/device.go b/pkg/sentry/fs/anon/device.go
new file mode 100644
index 000000000..d9ac14956
--- /dev/null
+++ b/pkg/sentry/fs/anon/device.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package anon
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/device"
+)
+
+// PseudoDevice is the device on which all anonymous inodes reside.
+var PseudoDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go
new file mode 100644
index 000000000..f60bd423d
--- /dev/null
+++ b/pkg/sentry/fs/attr.go
@@ -0,0 +1,493 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// InodeType enumerates types of Inodes.
+type InodeType int
+
+const (
+ // RegularFile is a regular file.
+ RegularFile InodeType = iota
+
+ // SpecialFile is a file that doesn't support SeekEnd. It is used for
+ // things like proc files.
+ SpecialFile
+
+ // Directory is a directory.
+ Directory
+
+ // SpecialDirectory is a directory that *does* support SeekEnd. It's
+ // the opposite of the SpecialFile scenario above. It similarly
+ // supports proc files.
+ SpecialDirectory
+
+ // Symlink is a symbolic link.
+ Symlink
+
+ // Pipe is a pipe (named or regular).
+ Pipe
+
+ // Socket is a socket.
+ Socket
+
+ // CharacterDevice is a character device.
+ CharacterDevice
+
+ // BlockDevice is a block device.
+ BlockDevice
+
+ // Anonymous is an anonymous type when none of the above apply.
+ // Epoll fds and event-driven fds fit this category.
+ Anonymous
+)
+
+// String returns a human-readable representation of the InodeType.
+func (n InodeType) String() string {
+ switch n {
+ case RegularFile, SpecialFile:
+ return "file"
+ case Directory, SpecialDirectory:
+ return "directory"
+ case Symlink:
+ return "symlink"
+ case Pipe:
+ return "pipe"
+ case Socket:
+ return "socket"
+ case CharacterDevice:
+ return "character-device"
+ case BlockDevice:
+ return "block-device"
+ case Anonymous:
+ return "anonymous"
+ default:
+ return "unknown"
+ }
+}
+
+// LinuxType returns the linux file type for this inode type.
+func (n InodeType) LinuxType() uint32 {
+ switch n {
+ case RegularFile, SpecialFile:
+ return linux.ModeRegular
+ case Directory, SpecialDirectory:
+ return linux.ModeDirectory
+ case Symlink:
+ return linux.ModeSymlink
+ case Pipe:
+ return linux.ModeNamedPipe
+ case CharacterDevice:
+ return linux.ModeCharacterDevice
+ case BlockDevice:
+ return linux.ModeBlockDevice
+ case Socket:
+ return linux.ModeSocket
+ default:
+ return 0
+ }
+}
+
+// ToDirentType converts an InodeType to a linux dirent type field.
+func ToDirentType(nodeType InodeType) uint8 {
+ switch nodeType {
+ case RegularFile, SpecialFile:
+ return linux.DT_REG
+ case Symlink:
+ return linux.DT_LNK
+ case Directory, SpecialDirectory:
+ return linux.DT_DIR
+ case Pipe:
+ return linux.DT_FIFO
+ case CharacterDevice:
+ return linux.DT_CHR
+ case BlockDevice:
+ return linux.DT_BLK
+ case Socket:
+ return linux.DT_SOCK
+ default:
+ return linux.DT_UNKNOWN
+ }
+}
+
+// ToInodeType converts a linux file type to InodeType.
+func ToInodeType(linuxFileType linux.FileMode) InodeType {
+ switch linuxFileType {
+ case linux.ModeRegular:
+ return RegularFile
+ case linux.ModeDirectory:
+ return Directory
+ case linux.ModeSymlink:
+ return Symlink
+ case linux.ModeNamedPipe:
+ return Pipe
+ case linux.ModeCharacterDevice:
+ return CharacterDevice
+ case linux.ModeBlockDevice:
+ return BlockDevice
+ case linux.ModeSocket:
+ return Socket
+ default:
+ panic(fmt.Sprintf("unknown file mode: %d", linuxFileType))
+ }
+}
+
+// StableAttr contains Inode attributes that will be stable throughout the
+// lifetime of the Inode.
+//
+// +stateify savable
+type StableAttr struct {
+ // Type is the InodeType of an InodeOperations.
+ Type InodeType
+
+ // DeviceID is the device on which an InodeOperations resides.
+ DeviceID uint64
+
+ // InodeID uniquely identifies InodeOperations on its device.
+ InodeID uint64
+
+ // BlockSize is the block size of data backing this InodeOperations.
+ BlockSize int64
+
+ // DeviceFileMajor is the major device number of this Node, if it is a
+ // device file.
+ DeviceFileMajor uint16
+
+ // DeviceFileMinor is the minor device number of this Node, if it is a
+ // device file.
+ DeviceFileMinor uint32
+}
+
+// IsRegular returns true if StableAttr.Type matches a regular file.
+func IsRegular(s StableAttr) bool {
+ return s.Type == RegularFile
+}
+
+// IsFile returns true if StableAttr.Type matches any type of file.
+func IsFile(s StableAttr) bool {
+ return s.Type == RegularFile || s.Type == SpecialFile
+}
+
+// IsDir returns true if StableAttr.Type matches any type of directory.
+func IsDir(s StableAttr) bool {
+ return s.Type == Directory || s.Type == SpecialDirectory
+}
+
+// IsSymlink returns true if StableAttr.Type matches a symlink.
+func IsSymlink(s StableAttr) bool {
+ return s.Type == Symlink
+}
+
+// IsPipe returns true if StableAttr.Type matches any type of pipe.
+func IsPipe(s StableAttr) bool {
+ return s.Type == Pipe
+}
+
+// IsAnonymous returns true if StableAttr.Type is Anonymous.
+func IsAnonymous(s StableAttr) bool {
+ return s.Type == Anonymous
+}
+
+// IsSocket returns true if StableAttr.Type matches any type of socket.
+func IsSocket(s StableAttr) bool {
+ return s.Type == Socket
+}
+
+// IsCharDevice returns true if StableAttr.Type matches a character device.
+func IsCharDevice(s StableAttr) bool {
+ return s.Type == CharacterDevice
+}
+
+// UnstableAttr contains Inode attributes that may change over the lifetime
+// of the Inode.
+//
+// +stateify savable
+type UnstableAttr struct {
+ // Size is the file size in bytes.
+ Size int64
+
+ // Usage is the actual data usage in bytes.
+ Usage int64
+
+ // Perms is the protection (read/write/execute for user/group/other).
+ Perms FilePermissions
+
+ // Owner describes the ownership of this file.
+ Owner FileOwner
+
+ // AccessTime is the time of last access.
+ AccessTime ktime.Time
+
+ // ModificationTime is the time of last modification.
+ ModificationTime ktime.Time
+
+ // StatusChangeTime is the time of last attribute modification.
+ StatusChangeTime ktime.Time
+
+ // Links is the number of hard links.
+ Links uint64
+}
+
+// SetOwner sets the owner and group if they are valid.
+//
+// This method is NOT thread-safe. Callers must prevent concurrent calls.
+func (ua *UnstableAttr) SetOwner(ctx context.Context, owner FileOwner) {
+ if owner.UID.Ok() {
+ ua.Owner.UID = owner.UID
+ }
+ if owner.GID.Ok() {
+ ua.Owner.GID = owner.GID
+ }
+ ua.StatusChangeTime = ktime.NowFromContext(ctx)
+}
+
+// SetPermissions sets the permissions.
+//
+// This method is NOT thread-safe. Callers must prevent concurrent calls.
+func (ua *UnstableAttr) SetPermissions(ctx context.Context, p FilePermissions) {
+ ua.Perms = p
+ ua.StatusChangeTime = ktime.NowFromContext(ctx)
+}
+
+// SetTimestamps sets the timestamps according to the TimeSpec.
+//
+// This method is NOT thread-safe. Callers must prevent concurrent calls.
+func (ua *UnstableAttr) SetTimestamps(ctx context.Context, ts TimeSpec) {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return
+ }
+
+ now := ktime.NowFromContext(ctx)
+ if !ts.ATimeOmit {
+ if ts.ATimeSetSystemTime {
+ ua.AccessTime = now
+ } else {
+ ua.AccessTime = ts.ATime
+ }
+ }
+ if !ts.MTimeOmit {
+ if ts.MTimeSetSystemTime {
+ ua.ModificationTime = now
+ } else {
+ ua.ModificationTime = ts.MTime
+ }
+ }
+ ua.StatusChangeTime = now
+}
+
+// WithCurrentTime returns u with AccessTime == ModificationTime == current time.
+func WithCurrentTime(ctx context.Context, u UnstableAttr) UnstableAttr {
+ t := ktime.NowFromContext(ctx)
+ u.AccessTime = t
+ u.ModificationTime = t
+ u.StatusChangeTime = t
+ return u
+}
+
+// AttrMask contains fields to mask StableAttr and UnstableAttr.
+//
+// +stateify savable
+type AttrMask struct {
+ Type bool
+ DeviceID bool
+ InodeID bool
+ BlockSize bool
+ Size bool
+ Usage bool
+ Perms bool
+ UID bool
+ GID bool
+ AccessTime bool
+ ModificationTime bool
+ StatusChangeTime bool
+ Links bool
+}
+
+// Empty returns true if all fields in AttrMask are false.
+func (a AttrMask) Empty() bool {
+ return a == AttrMask{}
+}
+
+// PermMask are file access permissions.
+//
+// +stateify savable
+type PermMask struct {
+ // Read indicates reading is permitted.
+ Read bool
+
+ // Write indicates writing is permitted.
+ Write bool
+
+ // Execute indicates execution is permitted.
+ Execute bool
+}
+
+// OnlyRead returns true when only the read bit is set.
+func (p PermMask) OnlyRead() bool {
+ return p.Read && !p.Write && !p.Execute
+}
+
+// String implements the fmt.Stringer interface for PermMask.
+func (p PermMask) String() string {
+ return fmt.Sprintf("PermMask{Read: %v, Write: %v, Execute: %v}", p.Read, p.Write, p.Execute)
+}
+
+// Mode returns the system mode (syscall.S_IXOTH, etc.) for these permissions
+// in the "other" bits.
+func (p PermMask) Mode() (mode os.FileMode) {
+ if p.Read {
+ mode |= syscall.S_IROTH
+ }
+ if p.Write {
+ mode |= syscall.S_IWOTH
+ }
+ if p.Execute {
+ mode |= syscall.S_IXOTH
+ }
+ return
+}
+
+// SupersetOf returns true iff the permissions in p are a superset of the
+// permissions in other.
+func (p PermMask) SupersetOf(other PermMask) bool {
+ if !p.Read && other.Read {
+ return false
+ }
+ if !p.Write && other.Write {
+ return false
+ }
+ if !p.Execute && other.Execute {
+ return false
+ }
+ return true
+}
+
+// FilePermissions represents the permissions of a file, with
+// Read/Write/Execute bits for user, group, and other.
+//
+// +stateify savable
+type FilePermissions struct {
+ User PermMask
+ Group PermMask
+ Other PermMask
+
+ // Sticky, if set on directories, restricts renaming and deletion of
+ // files in those directories to the directory owner, file owner, or
+ // CAP_FOWNER. The sticky bit is ignored when set on other files.
+ Sticky bool
+
+ // SetUID executables can call UID-setting syscalls without CAP_SETUID.
+ SetUID bool
+
+ // SetGID executables can call GID-setting syscalls without CAP_SETGID.
+ SetGID bool
+}
+
+// PermsFromMode takes the Other permissions (last 3 bits) of a FileMode and
+// returns a set of PermMask.
+func PermsFromMode(mode linux.FileMode) (perms PermMask) {
+ perms.Read = mode&linux.ModeOtherRead != 0
+ perms.Write = mode&linux.ModeOtherWrite != 0
+ perms.Execute = mode&linux.ModeOtherExec != 0
+ return
+}
+
+// FilePermsFromP9 converts a p9.FileMode to a FilePermissions struct.
+func FilePermsFromP9(mode p9.FileMode) FilePermissions {
+ return FilePermsFromMode(linux.FileMode(mode))
+}
+
+// FilePermsFromMode converts a system file mode to a FilePermissions struct.
+func FilePermsFromMode(mode linux.FileMode) (fp FilePermissions) {
+ perm := mode.Permissions()
+ fp.Other = PermsFromMode(perm)
+ fp.Group = PermsFromMode(perm >> 3)
+ fp.User = PermsFromMode(perm >> 6)
+ fp.Sticky = mode&linux.ModeSticky == linux.ModeSticky
+ fp.SetUID = mode&linux.ModeSetUID == linux.ModeSetUID
+ fp.SetGID = mode&linux.ModeSetGID == linux.ModeSetGID
+ return
+}
+
+// LinuxMode returns the linux mode_t representation of these permissions.
+func (f FilePermissions) LinuxMode() linux.FileMode {
+ m := linux.FileMode(f.User.Mode()<<6 | f.Group.Mode()<<3 | f.Other.Mode())
+ if f.SetUID {
+ m |= linux.ModeSetUID
+ }
+ if f.SetGID {
+ m |= linux.ModeSetGID
+ }
+ if f.Sticky {
+ m |= linux.ModeSticky
+ }
+ return m
+}
+
+// OSMode returns the Go runtime's OS-independent os.FileMode representation of
+// these permissions.
+func (f FilePermissions) OSMode() os.FileMode {
+ m := os.FileMode(f.User.Mode()<<6 | f.Group.Mode()<<3 | f.Other.Mode())
+ if f.SetUID {
+ m |= os.ModeSetuid
+ }
+ if f.SetGID {
+ m |= os.ModeSetgid
+ }
+ if f.Sticky {
+ m |= os.ModeSticky
+ }
+ return m
+}
+
+// AnyExecute returns true if any of U/G/O have the execute bit set.
+func (f FilePermissions) AnyExecute() bool {
+ return f.User.Execute || f.Group.Execute || f.Other.Execute
+}
+
+// AnyWrite returns true if any of U/G/O have the write bit set.
+func (f FilePermissions) AnyWrite() bool {
+ return f.User.Write || f.Group.Write || f.Other.Write
+}
+
+// AnyRead returns true if any of U/G/O have the read bit set.
+func (f FilePermissions) AnyRead() bool {
+ return f.User.Read || f.Group.Read || f.Other.Read
+}
+
+// FileOwner represents ownership of a file.
+//
+// +stateify savable
+type FileOwner struct {
+ UID auth.KUID
+ GID auth.KGID
+}
+
+// RootOwner corresponds to KUID/KGID 0/0.
+var RootOwner = FileOwner{
+ UID: auth.RootKUID,
+ GID: auth.RootKGID,
+}
diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go
new file mode 100644
index 000000000..0fbd60056
--- /dev/null
+++ b/pkg/sentry/fs/context.go
@@ -0,0 +1,138 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// contextID is the fs package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxRoot is a Context.Value key for a Dirent.
+ CtxRoot contextID = iota
+
+ // CtxDirentCacheLimiter is a Context.Value key for DirentCacheLimiter.
+ CtxDirentCacheLimiter
+)
+
+// ContextCanAccessFile determines whether the file represented by `inode` can be
+// accessed in the requested way (for reading, writing, or execution) using the
+// caller's credentials and user namespace, as does Linux's fs/namei.c:generic_permission.
+func ContextCanAccessFile(ctx context.Context, inode *Inode, reqPerms PermMask) bool {
+ creds := auth.CredentialsFromContext(ctx)
+ uattr, err := inode.UnstableAttr(ctx)
+ if err != nil {
+ return false
+ }
+
+ p := uattr.Perms.Other
+ // Are we owner or in group?
+ if uattr.Owner.UID == creds.EffectiveKUID {
+ p = uattr.Perms.User
+ } else if creds.InGroup(uattr.Owner.GID) {
+ p = uattr.Perms.Group
+ }
+
+ // Do not allow programs to be executed if MS_NOEXEC is set.
+ if IsFile(inode.StableAttr) && reqPerms.Execute && inode.MountSource.Flags.NoExec {
+ return false
+ }
+
+ // Are permissions satisfied without capability checks?
+ if p.SupersetOf(reqPerms) {
+ return true
+ }
+
+ if IsDir(inode.StableAttr) {
+ // CAP_DAC_OVERRIDE can override any perms on directories.
+ if inode.CheckCapability(ctx, linux.CAP_DAC_OVERRIDE) {
+ return true
+ }
+
+ // CAP_DAC_READ_SEARCH can normally only override Read perms,
+ // but for directories it can also override execution.
+ if !reqPerms.Write && inode.CheckCapability(ctx, linux.CAP_DAC_READ_SEARCH) {
+ return true
+ }
+ }
+
+ // CAP_DAC_OVERRIDE can always override Read/Write.
+ // Can override executable only when at least one execute bit is set.
+ if !reqPerms.Execute || uattr.Perms.AnyExecute() {
+ if inode.CheckCapability(ctx, linux.CAP_DAC_OVERRIDE) {
+ return true
+ }
+ }
+
+ // Read perms can be overridden by CAP_DAC_READ_SEARCH.
+ if reqPerms.OnlyRead() && inode.CheckCapability(ctx, linux.CAP_DAC_READ_SEARCH) {
+ return true
+ }
+ return false
+}
+
+// FileOwnerFromContext returns a FileOwner using the effective user and group
+// IDs used by ctx.
+func FileOwnerFromContext(ctx context.Context) FileOwner {
+ creds := auth.CredentialsFromContext(ctx)
+ return FileOwner{creds.EffectiveKUID, creds.EffectiveKGID}
+}
+
+// RootFromContext returns the root of the virtual filesystem observed by ctx,
+// or nil if ctx is not associated with a virtual filesystem. If
+// RootFromContext returns a non-nil fs.Dirent, a reference is taken on it.
+func RootFromContext(ctx context.Context) *Dirent {
+ if v := ctx.Value(CtxRoot); v != nil {
+ return v.(*Dirent)
+ }
+ return nil
+}
+
+// DirentCacheLimiterFromContext returns the DirentCacheLimiter used by ctx, or
+// nil if ctx does not have a dirent cache limiter.
+func DirentCacheLimiterFromContext(ctx context.Context) *DirentCacheLimiter {
+ if v := ctx.Value(CtxDirentCacheLimiter); v != nil {
+ return v.(*DirentCacheLimiter)
+ }
+ return nil
+}
+
+type rootContext struct {
+ context.Context
+ root *Dirent
+}
+
+// WithRoot returns a copy of ctx with the given root.
+func WithRoot(ctx context.Context, root *Dirent) context.Context {
+ return &rootContext{
+ Context: ctx,
+ root: root,
+ }
+}
+
+// Value implements Context.Value.
+func (rc rootContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxRoot:
+ rc.root.IncRef()
+ return rc.root
+ default:
+ return rc.Context.Value(key)
+ }
+}
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
new file mode 100644
index 000000000..ab1424c95
--- /dev/null
+++ b/pkg/sentry/fs/copy_up.go
@@ -0,0 +1,436 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// copyUp copies a file in an overlay from a lower filesystem to an
+// upper filesystem so that the file can be modified in the upper
+// filesystem. Copying a file involves several steps:
+//
+// - All parent directories of the file are created in the upper
+// filesystem if they don't exist there. For instance:
+//
+// upper /dir0
+// lower /dir0/dir1/file
+//
+// copyUp of /dir0/dir1/file creates /dir0/dir1 in order to create
+// /dir0/dir1/file.
+//
+// - The file content is copied from the lower file to the upper
+// file. For symlinks this is the symlink target. For directories,
+// upper directory entries are merged with lower directory entries
+// so there is no need to copy any entries.
+//
+// - A subset of file attributes of the lower file are set on the
+// upper file. These are the file owner, the file timestamps,
+// and all non-overlay extended attributes. copyUp will fail if
+// the upper filesystem does not support the setting of these
+// attributes.
+//
+// The file's permissions are set when the file is created and its
+// size will be brought up to date when its contents are copied.
+// Notably, no attempt is made to bring link count up to date because
+// hard links are currently not preserved across overlay filesystems.
+//
+// - Memory mappings of the lower file are invalidated and memory
+// references are transferred to the upper file. From this point on,
+// memory mappings of the file will be backed by content in the upper
+// filesystem.
+//
+// Synchronization:
+//
+// copyUp synchronizes with rename(2) using renameMu to ensure that
+// parentage does not change while a file is being copied. In the context
+// of rename(2), copyUpLockedForRename should be used to avoid deadlock on
+// renameMu.
+//
+// The following operations synchronize with copyUp using copyMu:
+//
+// - InodeOperations, i.e. to ensure that looking up a directory takes
+// into account new upper filesystem directories created by copy up,
+// which subsequently can be modified.
+//
+// - FileOperations, i.e. to ensure that reading from a file does not
+// continue using a stale, lower filesystem handle when the file is
+// written to.
+//
+// Lock ordering: Dirent.mu -> Inode.overlay.copyMu -> Inode.mu.
+//
+// Caveats:
+//
+// If any step in copying up a file fails, copyUp cleans the upper
+// filesystem of any partially up-to-date file. If this cleanup fails,
+// the overlay may be in an unacceptable, inconsistent state, so copyUp
+// panics. If copyUp fails because any step (above) fails, a generic
+// error is returned.
+//
+// copyUp currently makes no attempt to optimize copying up file content.
+// For large files, this means that copyUp blocks until the entire file
+// is copied synchronously.
+func copyUp(ctx context.Context, d *Dirent) error {
+ renameMu.RLock()
+ defer renameMu.RUnlock()
+ return copyUpLockedForRename(ctx, d)
+}
+
+// copyUpLockedForRename is the same as copyUp except that it does not lock
+// renameMu.
+//
+// It copies each component of d that does not yet exist in the upper
+// filesystem. If d already exists in the upper filesystem, it is a no-op.
+//
+// Any error returned indicates a failure to copy all of d. This may
+// leave the upper filesystem filled with any number of parent directories
+// but the upper filesystem will never be in an inconsistent state.
+//
+// Preconditions:
+// - d.Inode.overlay is non-nil.
+func copyUpLockedForRename(ctx context.Context, d *Dirent) error {
+ for {
+ // Did we race with another copy up or does there
+ // already exist something in the upper filesystem
+ // for d?
+ d.Inode.overlay.copyMu.RLock()
+ if d.Inode.overlay.upper != nil {
+ d.Inode.overlay.copyMu.RUnlock()
+ // Done, d is in the upper filesystem.
+ return nil
+ }
+ d.Inode.overlay.copyMu.RUnlock()
+
+ // Find the next component to copy up. We will work our way
+ // down to the last component of d and finally copy it.
+ next := findNextCopyUp(ctx, d)
+
+ // Attempt to copy.
+ if err := doCopyUp(ctx, next); err != nil {
+ return err
+ }
+ }
+}
+
+// findNextCopyUp finds the next component of d from root that does not
+// yet exist in the upper filesystem. The parent of this component is
+// also returned, which is the root of the overlay in the worst case.
+func findNextCopyUp(ctx context.Context, d *Dirent) *Dirent {
+ next := d
+ for parent := next.parent; ; /* checked in-loop */ /* updated in-loop */ {
+ // Does this parent have a non-nil upper Inode?
+ parent.Inode.overlay.copyMu.RLock()
+ if parent.Inode.overlay.upper != nil {
+ parent.Inode.overlay.copyMu.RUnlock()
+ // Note that since we found an upper, it is stable.
+ return next
+ }
+ parent.Inode.overlay.copyMu.RUnlock()
+
+ // Continue searching for a parent with a non-nil
+ // upper Inode.
+ next = parent
+ parent = next.parent
+ }
+}
+
+func doCopyUp(ctx context.Context, d *Dirent) error {
+ // Fail fast on Inode types we won't be able to copy up anyways. These
+ // Inodes may block in GetFile while holding copyMu for reading. If we
+ // then try to take copyMu for writing here, we'd deadlock.
+ t := d.Inode.overlay.lower.StableAttr.Type
+ if t != RegularFile && t != Directory && t != Symlink {
+ return syserror.EINVAL
+ }
+
+ // Wait to get exclusive access to the upper Inode.
+ d.Inode.overlay.copyMu.Lock()
+ defer d.Inode.overlay.copyMu.Unlock()
+ if d.Inode.overlay.upper != nil {
+ // We raced with another doCopyUp, no problem.
+ return nil
+ }
+
+ // Perform the copy.
+ return copyUpLocked(ctx, d.parent, d)
+}
+
+// copyUpLocked creates a copy of next in the upper filesystem of parent.
+//
+// copyUpLocked must be called with d.Inode.overlay.copyMu locked.
+//
+// Returns a generic error on failure.
+//
+// Preconditions:
+// - parent.Inode.overlay.upper must be non-nil.
+// - next.Inode.overlay.copyMu must be locked writable.
+// - next.Inode.overlay.lower must be non-nil.
+// - next.Inode.overlay.lower.StableAttr.Type must be RegularFile, Directory,
+// or Symlink.
+// - upper filesystem must support setting file ownership and timestamps.
+func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
+ // Extract the attributes of the file we wish to copy.
+ attrs, err := next.Inode.overlay.lower.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("copy up failed to get lower attributes: %v", err)
+ return syserror.EIO
+ }
+
+ var childUpperInode *Inode
+ parentUpper := parent.Inode.overlay.upper
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ // Create the file in the upper filesystem and get an Inode for it.
+ switch next.Inode.StableAttr.Type {
+ case RegularFile:
+ childFile, err := parentUpper.Create(ctx, root, next.name, FileFlags{Read: true, Write: true}, attrs.Perms)
+ if err != nil {
+ log.Warningf("copy up failed to create file: %v", err)
+ return syserror.EIO
+ }
+ defer childFile.DecRef()
+ childUpperInode = childFile.Dirent.Inode
+
+ case Directory:
+ if err := parentUpper.CreateDirectory(ctx, root, next.name, attrs.Perms); err != nil {
+ log.Warningf("copy up failed to create directory: %v", err)
+ return syserror.EIO
+ }
+ childUpper, err := parentUpper.Lookup(ctx, next.name)
+ if err != nil {
+ werr := fmt.Errorf("copy up failed to lookup directory: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
+ return syserror.EIO
+ }
+ defer childUpper.DecRef()
+ childUpperInode = childUpper.Inode
+
+ case Symlink:
+ childLower := next.Inode.overlay.lower
+ link, err := childLower.Readlink(ctx)
+ if err != nil {
+ log.Warningf("copy up failed to read symlink value: %v", err)
+ return syserror.EIO
+ }
+ if err := parentUpper.CreateLink(ctx, root, link, next.name); err != nil {
+ log.Warningf("copy up failed to create symlink: %v", err)
+ return syserror.EIO
+ }
+ childUpper, err := parentUpper.Lookup(ctx, next.name)
+ if err != nil {
+ werr := fmt.Errorf("copy up failed to lookup symlink: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
+ return syserror.EIO
+ }
+ defer childUpper.DecRef()
+ childUpperInode = childUpper.Inode
+
+ default:
+ panic(fmt.Sprintf("copy up of invalid type %v on %+v", next.Inode.StableAttr.Type, next))
+ }
+
+ // Bring file attributes up to date. This does not include size, which will be
+ // brought up to date with copyContentsLocked.
+ if err := copyAttributesLocked(ctx, childUpperInode, next.Inode.overlay.lower); err != nil {
+ werr := fmt.Errorf("copy up failed to copy up attributes: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
+ return syserror.EIO
+ }
+
+ // Copy the entire file.
+ if err := copyContentsLocked(ctx, childUpperInode, next.Inode.overlay.lower, attrs.Size); err != nil {
+ werr := fmt.Errorf("copy up failed to copy up contents: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
+ return syserror.EIO
+ }
+
+ lowerMappable := next.Inode.overlay.lower.Mappable()
+ upperMappable := childUpperInode.Mappable()
+ if lowerMappable != nil && upperMappable == nil {
+ werr := fmt.Errorf("copy up failed: cannot ensure memory mapping coherence")
+ cleanupUpper(ctx, parentUpper, next.name, werr)
+ return syserror.EIO
+ }
+
+ // Propagate memory mappings to the upper Inode.
+ next.Inode.overlay.mapsMu.Lock()
+ defer next.Inode.overlay.mapsMu.Unlock()
+ if upperMappable != nil {
+ // Remember which mappings we added so we can remove them on failure.
+ allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange)
+ for seg := next.Inode.overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ added := make(memmap.MappingsOfRange)
+ for m := range seg.Value() {
+ if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil {
+ for m := range added {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ for mr, mappings := range allAdded {
+ for m := range mappings {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable)
+ }
+ }
+ return err
+ }
+ added[m] = struct{}{}
+ }
+ allAdded[seg.Range()] = added
+ }
+ }
+
+ // Take a reference on the upper Inode (transferred to
+ // next.Inode.overlay.upper) and make new translations use it.
+ next.Inode.overlay.dataMu.Lock()
+ childUpperInode.IncRef()
+ next.Inode.overlay.upper = childUpperInode
+ next.Inode.overlay.dataMu.Unlock()
+
+ // Invalidate existing translations through the lower Inode.
+ next.Inode.overlay.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Remove existing memory mappings from the lower Inode.
+ if lowerMappable != nil {
+ for seg := next.Inode.overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for m := range seg.Value() {
+ lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ }
+ }
+
+ return nil
+}
+
+// cleanupUpper is called when copy-up fails. It logs the copy-up error and
+// attempts to remove name from parent. If that fails, then it panics.
+func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr error) {
+ log.Warningf(copyUpErr.Error())
+ if err := parent.InodeOperations.Remove(ctx, parent, name); err != nil {
+ // Unfortunately we don't have much choice. We shouldn't
+ // willingly give the caller access to a nonsense filesystem.
+ panic(fmt.Sprintf("overlay filesystem is in an inconsistent state: copyUp got error: %v; then cleanup failed to remove %q from upper filesystem: %v.", copyUpErr, name, err))
+ }
+}
+
+// copyUpBuffers is a buffer pool for copying file content. The buffer
+// size is the same as that used by io.Copy.
+var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }}
+
+// copyContentsLocked copies the contents of lower to upper. It panics if
+// less than size bytes can be copied.
+func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size int64) error {
+ // We don't support copying up for anything other than regular files.
+ if lower.StableAttr.Type != RegularFile {
+ return nil
+ }
+
+ // Get a handle to the upper filesystem, which we will write to.
+ upperFile, err := overlayFile(ctx, upper, FileFlags{Write: true})
+ if err != nil {
+ return err
+ }
+ defer upperFile.DecRef()
+
+ // Get a handle to the lower filesystem, which we will read from.
+ lowerFile, err := overlayFile(ctx, lower, FileFlags{Read: true})
+ if err != nil {
+ return err
+ }
+ defer lowerFile.DecRef()
+
+ // Use a buffer pool to minimize allocations.
+ buf := copyUpBuffers.Get().([]byte)
+ defer copyUpBuffers.Put(buf)
+
+ // Transfer the contents.
+ //
+ // One might be able to optimize this by doing parallel reads, parallel writes and reads, larger
+ // buffers, etc. But we really don't know anything about the underlying implementation, so these
+ // optimizations could be self-defeating. So we leave this as simple as possible.
+ var offset int64
+ for {
+ nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset)
+ if err != nil && err != io.EOF {
+ return err
+ }
+ if nr == 0 {
+ if offset != size {
+ // Same as in cleanupUpper, we cannot live
+ // with ourselves if we do anything less.
+ panic(fmt.Sprintf("filesystem is in an inconsistent state: wrote only %d bytes of %d sized file", offset, size))
+ }
+ return nil
+ }
+ nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset)
+ if err != nil {
+ return err
+ }
+ offset += nw
+ }
+}
+
+// copyAttributesLocked copies a subset of lower's attributes to upper,
+// specifically owner, timestamps (except the status change time), and
+// extended attributes. Notably, no attempt is made to copy link count.
+// Size and permissions are set on upper when the file content is copied
+// and when the file is created respectively.
+func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error {
+ // Extract attributes from the lower filesystem.
+ lowerAttr, err := lower.UnstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+ lowerXattr, err := lower.ListXattr(ctx, linux.XATTR_SIZE_MAX)
+ if err != nil && err != syserror.EOPNOTSUPP {
+ return err
+ }
+
+ // Set the attributes on the upper filesystem.
+ if err := upper.InodeOperations.SetOwner(ctx, upper, lowerAttr.Owner); err != nil {
+ return err
+ }
+ if err := upper.InodeOperations.SetTimestamps(ctx, upper, TimeSpec{
+ ATime: lowerAttr.AccessTime,
+ MTime: lowerAttr.ModificationTime,
+ }); err != nil {
+ return err
+ }
+ for name := range lowerXattr {
+ // Don't copy-up attributes that configure an overlay in the
+ // lower.
+ if isXattrOverlay(name) {
+ continue
+ }
+ value, err := lower.GetXattr(ctx, name, linux.XATTR_SIZE_MAX)
+ if err != nil {
+ return err
+ }
+ if err := upper.InodeOperations.SetXattr(ctx, upper, name, value, 0 /* flags */); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
new file mode 100644
index 000000000..91792d9fe
--- /dev/null
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -0,0 +1,183 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs_test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // origFileSize is the original file size. This many bytes should be
+ // copied up before the test file is modified.
+ origFileSize = 4096
+
+ // truncateFileSize is the size to truncate all test files.
+ truncateFileSize = 10
+)
+
+// TestConcurrentCopyUp is a copy up stress test for an overlay.
+//
+// It creates a 64-level deep directory tree in the lower filesystem and
+// populates the last subdirectory with 64 files containing random content:
+//
+// /lower
+// /subdir0/.../subdir63/
+// /file0
+// ...
+// /file63
+//
+// The files are truncated concurrently by 4 goroutines per file.
+// These goroutines contend with copying up all 64 parent subdirectories
+// as well as the final file content.
+//
+// At the end of the test, we assert that the files respect the new truncated
+// size and contain the content we expect.
+func TestConcurrentCopyUp(t *testing.T) {
+ ctx := contexttest.Context(t)
+ files := makeOverlayTestFiles(t)
+
+ var wg sync.WaitGroup
+ for _, file := range files {
+ for i := 0; i < 4; i++ {
+ wg.Add(1)
+ go func(o *overlayTestFile) {
+ if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil {
+ // Must not call t.Fatalf from a non-test goroutine; Errorf is safe.
+ t.Errorf("failed to copy up: %v", err)
+ }
+ wg.Done()
+ }(file)
+ }
+ }
+ wg.Wait()
+
+ for _, file := range files {
+ got := make([]byte, origFileSize)
+ n, err := file.File.Readv(ctx, usermem.BytesIOSequence(got))
+ if int(n) != truncateFileSize {
+ t.Fatalf("read %d bytes from file, want %d", n, truncateFileSize)
+ }
+ if err != nil && err != io.EOF {
+ t.Fatalf("read got error %v, want nil", err)
+ }
+ if !bytes.Equal(got[:n], file.content[:truncateFileSize]) {
+ t.Fatalf("file content is %v, want %v", got[:n], file.content[:truncateFileSize])
+ }
+ }
+}
+
+type overlayTestFile struct {
+ File *fs.File
+ name string
+ content []byte
+}
+
+func makeOverlayTestFiles(t *testing.T) []*overlayTestFile {
+ ctx := contexttest.Context(t)
+
+ // Create a lower tmpfs mount.
+ fsys, _ := fs.FindFilesystem("tmpfs")
+ lower, err := fsys.Mount(contexttest.Context(t), "", fs.MountSourceFlags{}, "", nil)
+ if err != nil {
+ t.Fatalf("failed to mount tmpfs: %v", err)
+ }
+ lowerRoot := fs.NewDirent(ctx, lower, "")
+
+ // Make a deep set of subdirectories that everyone shares.
+ next := lowerRoot
+ for i := 0; i < 64; i++ {
+ name := fmt.Sprintf("subdir%d", i)
+ err := next.CreateDirectory(ctx, lowerRoot, name, fs.FilePermsFromMode(0777))
+ if err != nil {
+ t.Fatalf("failed to create dir %q: %v", name, err)
+ }
+ next, err = next.Walk(ctx, lowerRoot, name)
+ if err != nil {
+ t.Fatalf("failed to walk to %q: %v", name, err)
+ }
+ }
+
+ // Make a bunch of files in the last directory.
+ var files []*overlayTestFile
+ for i := 0; i < 64; i++ {
+ name := fmt.Sprintf("file%d", i)
+ f, err := next.Create(ctx, next, name, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
+ if err != nil {
+ t.Fatalf("failed to create file %q: %v", name, err)
+ }
+ defer f.DecRef()
+
+ relname, _ := f.Dirent.FullName(lowerRoot)
+
+ o := &overlayTestFile{
+ name: relname,
+ content: make([]byte, origFileSize),
+ }
+
+ if _, err := rand.Read(o.content); err != nil {
+ t.Fatalf("failed to read from /dev/urandom: %v", err)
+ }
+
+ if _, err := f.Writev(ctx, usermem.BytesIOSequence(o.content)); err != nil {
+ t.Fatalf("failed to write content to file %q: %v", name, err)
+ }
+
+ files = append(files, o)
+ }
+
+ // Create an empty upper tmpfs mount which we will copy up into.
+ upper, err := fsys.Mount(ctx, "", fs.MountSourceFlags{}, "", nil)
+ if err != nil {
+ t.Fatalf("failed to mount tmpfs: %v", err)
+ }
+
+ // Construct an overlay root.
+ overlay, err := fs.NewOverlayRoot(ctx, upper, lower, fs.MountSourceFlags{})
+ if err != nil {
+ t.Fatalf("failed to construct overlay root: %v", err)
+ }
+
+ // Create a MountNamespace to traverse the file system.
+ mns, err := fs.NewMountNamespace(ctx, overlay)
+ if err != nil {
+ t.Fatalf("failed to construct mount manager: %v", err)
+ }
+
+ // Walk to all of the files in the overlay, open them readable.
+ for _, f := range files {
+ maxTraversals := uint(0)
+ d, err := mns.FindInode(ctx, mns.Root(), mns.Root(), f.name, &maxTraversals)
+ if err != nil {
+ t.Fatalf("failed to find %q: %v", f.name, err)
+ }
+ defer d.DecRef()
+
+ f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ t.Fatalf("failed to open file %q readable: %v", f.name, err)
+ }
+ }
+
+ return files
+}
diff --git a/pkg/sentry/fs/dentry.go b/pkg/sentry/fs/dentry.go
new file mode 100644
index 000000000..6b2699f15
--- /dev/null
+++ b/pkg/sentry/fs/dentry.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/sentry/device"
+)
+
+// DentAttr is the metadata of a directory entry. It is a subset of StableAttr.
+//
+// +stateify savable
+type DentAttr struct {
+ // Type is the InodeType of an Inode.
+ Type InodeType
+
+ // InodeID uniquely identifies an Inode on a device.
+ InodeID uint64
+}
+
+// GenericDentAttr returns a generic DentAttr where:
+//
+// Type == nt
+// InodeID == the inode id of a new inode on device.
+func GenericDentAttr(nt InodeType, device *device.Device) DentAttr {
+ return DentAttr{
+ Type: nt,
+ InodeID: device.NextIno(),
+ }
+}
+
+// DentrySerializer serializes a directory entry.
+type DentrySerializer interface {
+ // CopyOut serializes a directory entry based on its name and attributes.
+ CopyOut(name string, attributes DentAttr) error
+
+ // Written returns the number of bytes written.
+ Written() int
+}
+
+// CollectEntriesSerializer copies DentAttrs to Entries. The order in
+// which entries are encountered is preserved in Order.
+type CollectEntriesSerializer struct {
+ Entries map[string]DentAttr
+ Order []string
+}
+
+// CopyOut implements DentrySerializer.CopyOut.
+func (c *CollectEntriesSerializer) CopyOut(name string, attr DentAttr) error {
+ if c.Entries == nil {
+ c.Entries = make(map[string]DentAttr)
+ }
+ c.Entries[name] = attr
+ c.Order = append(c.Order, name)
+ return nil
+}
+
+// Written implements DentrySerializer.Written.
+func (c *CollectEntriesSerializer) Written() int {
+ return len(c.Entries)
+}
+
+// DirCtx is used in FileOperations.IterateDir to emit directory entries. It is
+// not thread-safe.
+type DirCtx struct {
+ // Serializer is used to serialize the node attributes.
+ Serializer DentrySerializer
+
+ // attrs maps emitted entry names to their DentAttrs.
+ attrs map[string]DentAttr
+
+ // DirCursor is the directory cursor.
+ DirCursor *string
+}
+
+// DirEmit is called for each directory entry.
+func (c *DirCtx) DirEmit(name string, attr DentAttr) error {
+ if c.Serializer != nil {
+ if err := c.Serializer.CopyOut(name, attr); err != nil {
+ return err
+ }
+ }
+ if c.attrs == nil {
+ c.attrs = make(map[string]DentAttr)
+ }
+ c.attrs[name] = attr
+ return nil
+}
+
+// DentAttrs returns a map of DentAttrs corresponding to the emitted directory
+// entries.
+func (c *DirCtx) DentAttrs() map[string]DentAttr {
+ if c.attrs == nil {
+ c.attrs = make(map[string]DentAttr)
+ }
+ return c.attrs
+}
+
+// GenericReaddir serializes DentAttrs based on a SortedDentryMap that must
+// contain _all_ up-to-date DentAttrs under a directory. If ctx.DirCursor is
+// not nil, it is updated to the name of the last DentAttr that was
+// successfully serialized.
+//
+// Returns the number of entries serialized.
+func GenericReaddir(ctx *DirCtx, s *SortedDentryMap) (int, error) {
+ // Retrieve the next directory entries.
+ var names []string
+ var entries map[string]DentAttr
+ if ctx.DirCursor != nil {
+ names, entries = s.GetNext(*ctx.DirCursor)
+ } else {
+ names, entries = s.GetAll()
+ }
+
+ // Try to serialize each entry.
+ var serialized int
+ for _, name := range names {
+ // Skip "" per POSIX. Skip "." and ".." which will be added by Dirent.Readdir.
+ if name == "" || name == "." || name == ".." {
+ continue
+ }
+
+ // Emit the directory entry.
+ if err := ctx.DirEmit(name, entries[name]); err != nil {
+ // Return potentially a partial serialized count.
+ return serialized, err
+ }
+
+ // We successfully serialized this entry.
+ serialized++
+
+ // Update the cursor with the name of the entry last serialized.
+ if ctx.DirCursor != nil {
+ *ctx.DirCursor = name
+ }
+ }
+
+ // Everything was serialized.
+ return serialized, nil
+}
+
+// SortedDentryMap is a sorted map of names and fs.DentAttr entries.
+//
+// +stateify savable
+type SortedDentryMap struct {
+ // names is always kept in sorted-order.
+ names []string
+
+ // entries maps names to fs.DentAttrs.
+ entries map[string]DentAttr
+}
+
+// NewSortedDentryMap returns a SortedDentryMap that maintains entries in name-sorted order.
+func NewSortedDentryMap(entries map[string]DentAttr) *SortedDentryMap {
+ s := &SortedDentryMap{
+ names: make([]string, 0, len(entries)),
+ entries: entries,
+ }
+ // Don't allow s.entries to be nil, because nil maps aren't Saveable.
+ if s.entries == nil {
+ s.entries = make(map[string]DentAttr)
+ }
+
+ // Collect names from entries and sort them.
+ for name := range s.entries {
+ s.names = append(s.names, name)
+ }
+ sort.Strings(s.names)
+ return s
+}
+
+// GetAll returns all names and entries in s. Callers should not modify the
+// returned values.
+func (s *SortedDentryMap) GetAll() ([]string, map[string]DentAttr) {
+ return s.names, s.entries
+}
+
+// GetNext returns names after cursor in s and all entries.
+func (s *SortedDentryMap) GetNext(cursor string) ([]string, map[string]DentAttr) {
+ i := sort.SearchStrings(s.names, cursor)
+ if i == len(s.names) {
+ return nil, s.entries
+ }
+
+ // Return everything strictly after the cursor.
+ if s.names[i] == cursor {
+ i++
+ }
+ return s.names[i:], s.entries
+}
+
+// Add adds an entry with the given name to the map, preserving sort order. If
+// name already exists in the map, its entry will be overwritten.
+func (s *SortedDentryMap) Add(name string, entry DentAttr) {
+ if _, ok := s.entries[name]; !ok {
+ // Map does not yet contain an entry with this name. We must
+ // insert it in s.names at the appropriate spot.
+ i := sort.SearchStrings(s.names, name)
+ s.names = append(s.names, "")
+ copy(s.names[i+1:], s.names[i:])
+ s.names[i] = name
+ }
+ s.entries[name] = entry
+}
+
+// Remove removes an entry with the given name from the map, preserving sort order.
+func (s *SortedDentryMap) Remove(name string) {
+ if _, ok := s.entries[name]; !ok {
+ return
+ }
+ i := sort.SearchStrings(s.names, name)
+ copy(s.names[i:], s.names[i+1:])
+ s.names = s.names[:len(s.names)-1]
+ delete(s.entries, name)
+}
+
+// Contains reports whether the map contains an entry with the given name.
+func (s *SortedDentryMap) Contains(name string) bool {
+ _, ok := s.entries[name]
+ return ok
+}
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
new file mode 100644
index 000000000..9379a4d7b
--- /dev/null
+++ b/pkg/sentry/fs/dev/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "dev",
+ srcs = [
+ "dev.go",
+ "device.go",
+ "fs.go",
+ "full.go",
+ "net_tun.go",
+ "null.go",
+ "random.go",
+ "tty.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/syserror",
+ "//pkg/tcpip/link/tun",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
new file mode 100644
index 000000000..acbd401a0
--- /dev/null
+++ b/pkg/sentry/fs/dev/dev.go
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dev provides a filesystem with simple devices.
+package dev
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Memory device numbers are from Linux's drivers/char/mem.c
+const (
+ // Mem device major.
+ memDevMajor uint16 = 1
+
+ // Mem device minors.
+ nullDevMinor uint32 = 3
+ zeroDevMinor uint32 = 5
+ fullDevMinor uint32 = 7
+ randomDevMinor uint32 = 8
+ urandomDevMinor uint32 = 9
+)
+
+// TTY major device number comes from include/uapi/linux/major.h.
+const (
+ ttyDevMinor = 0
+ ttyDevMajor = 5
+)
+
+func newCharacterDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSource, major uint16, minor uint32) *fs.Inode {
+ return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.CharacterDevice,
+ DeviceFileMajor: major,
+ DeviceFileMinor: minor,
+ })
+}
+
+func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSource, minor uint32) *fs.Inode {
+ return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.CharacterDevice,
+ DeviceFileMajor: memDevMajor,
+ DeviceFileMinor: minor,
+ })
+}
+
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewSymlink(ctx, fs.RootOwner, target)
+ return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Symlink,
+ })
+}
+
+// New returns the root node of a device filesystem.
+func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "fd": newSymlink(ctx, "/proc/self/fd", msrc),
+ "stdin": newSymlink(ctx, "/proc/self/fd/0", msrc),
+ "stdout": newSymlink(ctx, "/proc/self/fd/1", msrc),
+ "stderr": newSymlink(ctx, "/proc/self/fd/2", msrc),
+
+ "null": newMemDevice(ctx, newNullDevice(ctx, fs.RootOwner, 0666), msrc, nullDevMinor),
+ "zero": newMemDevice(ctx, newZeroDevice(ctx, fs.RootOwner, 0666), msrc, zeroDevMinor),
+ "full": newMemDevice(ctx, newFullDevice(ctx, fs.RootOwner, 0666), msrc, fullDevMinor),
+
+ // This is not as good as /dev/random in Linux because the Go
+ // runtime uses sys_random and /dev/urandom internally.
+ // According to 'man 4 random', this will be sufficient unless
+ // the application uses it to generate long-lived GPG/SSL/SSH
+ // keys.
+ "random": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, randomDevMinor),
+ "urandom": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, urandomDevMinor),
+
+ "shm": tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc),
+
+ // A devpts is typically mounted at /dev/pts to provide
+ // pseudoterminal support. Place an empty directory there for
+ // the devpts to be mounted over.
+ "pts": newDirectory(ctx, nil, msrc),
+ // Similarly, applications expect a ptmx device at /dev/ptmx
+ // connected to the terminals provided by /dev/pts/. Rather
+ // than creating a device directly (which requires a hairy
+ // lookup on open to determine if a devpts exists), just create
+ // a symlink to the ptmx provided by devpts. (The Linux devpts
+ // documentation recommends this.)
+ //
+ // If no devpts is mounted, this will simply be a dangling
+ // symlink, which is fine.
+ "ptmx": newSymlink(ctx, "pts/ptmx", msrc),
+
+ "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+ }
+
+ if isNetTunSupported(inet.StackFromContext(ctx)) {
+ contents["net"] = newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc)
+ }
+
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+// readZeros implements fs.FileOperations.Read with infinite null bytes.
+type readZeros struct{}
+
+// Read implements fs.FileOperations.Read.
+func (*readZeros) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ return dst.ZeroOut(ctx, math.MaxInt64)
+}
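+
+// As an illustrative sketch of the resulting behavior, a read from a device
+// backed by readZeros always fills the destination buffer completely:
+//
+//	f, _ := os.Open("/dev/zero")
+//	buf := make([]byte, 8)
+//	n, _ := f.Read(buf) // n == 8; buf contains only zero bytes.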
diff --git a/pkg/sentry/fs/dev/device.go b/pkg/sentry/fs/dev/device.go
new file mode 100644
index 000000000..a0493474e
--- /dev/null
+++ b/pkg/sentry/fs/dev/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// devDevice is the pseudo-filesystem device.
+var devDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/dev/fs.go b/pkg/sentry/fs/dev/fs.go
new file mode 100644
index 000000000..5e518fb63
--- /dev/null
+++ b/pkg/sentry/fs/dev/fs.go
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// filesystem is a devtmpfs.
+//
+// +stateify savable
+type filesystem struct{}
+
+var _ fs.Filesystem = (*filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+// Name matches drivers/base/devtmpfs.c:dev_fs_type.name.
+const FilesystemName = "devtmpfs"
+
+// Name is the name of the file system.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, devtmpfs does the same thing.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a devtmpfs root that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // devtmpfs backed by ramfs ignores bad options, matching
+ // fs/ramfs/inode.c:ramfs_parse_options.
+ // -> we should consider parsing the mode option and honoring it here.
+ return New(ctx, fs.NewNonCachingMountSource(ctx, f, flags)), nil
+}
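+
+// Because AllowUserMount above returns true, a sandboxed application with
+// sufficient capabilities can mount this filesystem itself. A minimal sketch
+// (assuming the golang.org/x/sys/unix package), equivalent to
+// `mount -t devtmpfs none /dev`:
+//
+//	err := unix.Mount("none", "/dev", "devtmpfs", 0, "")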
diff --git a/pkg/sentry/fs/dev/full.go b/pkg/sentry/fs/dev/full.go
new file mode 100644
index 000000000..deb9c6ad8
--- /dev/null
+++ b/pkg/sentry/fs/dev/full.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fullDevice is used to implement /dev/full.
+//
+// +stateify savable
+type fullDevice struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*fullDevice)(nil)
+
+func newFullDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *fullDevice {
+ f := &fullDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+ return f
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (f *fullDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &fullFileOperations{}), nil
+}
+
+// +stateify savable
+type fullFileOperations struct {
+ waiter.AlwaysReady `state:"nosave"`
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ readZeros `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*fullFileOperations)(nil)
+
+// Write implements FileOperations.Write.
+func (*fullFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.ENOSPC
+}
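+
+// The resulting device matches Linux's /dev/full: reads succeed with zeros
+// (via the embedded readZeros) while every write fails. A hedged sketch of
+// the application-visible behavior:
+//
+//	f, _ := os.OpenFile("/dev/full", os.O_WRONLY, 0)
+//	_, err := f.Write([]byte("x")) // err wraps ENOSPC.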
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..dc7ad075a
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,177 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release() {
+ fops.device.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately)
+ // when there is no sk_filter. See __tun_chr_ioctl() in drivers/net/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
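+
+// For illustration, the TUNSETIFF path above is what a guest application
+// exercises when configuring a TUN interface. One possible shape, as a
+// hedged sketch using a raw ioctl (assuming golang.org/x/sys/unix, unsafe,
+// and CAP_NET_ADMIN):
+//
+//	fd, _ := unix.Open("/dev/net/tun", unix.O_RDWR, 0)
+//	var ifr struct {
+//		name  [unix.IFNAMSIZ]byte
+//		flags uint16
+//		_     [22]byte // Pad to the size of struct ifreq (40 bytes).
+//	}
+//	copy(ifr.name[:], "tun0")
+//	ifr.flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
+//	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd),
+//		uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&ifr)))
+//	// errno == 0 on success.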
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
+
+// isNetTunSupported returns whether the /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
diff --git a/pkg/sentry/fs/dev/null.go b/pkg/sentry/fs/dev/null.go
new file mode 100644
index 000000000..aec33d0d9
--- /dev/null
+++ b/pkg/sentry/fs/dev/null.go
@@ -0,0 +1,131 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type nullDevice struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*nullDevice)(nil)
+
+func newNullDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *nullDevice {
+ n := &nullDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+ return n
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (n *nullDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+
+ return fs.NewFile(ctx, dirent, flags, &nullFileOperations{}), nil
+}
+
+// +stateify savable
+type nullFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRead `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*nullFileOperations)(nil)
+
+// +stateify savable
+type zeroDevice struct {
+ nullDevice
+}
+
+var _ fs.InodeOperations = (*zeroDevice)(nil)
+
+func newZeroDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *zeroDevice {
+ zd := &zeroDevice{
+ nullDevice: nullDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ },
+ }
+ return zd
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (zd *zeroDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ flags.NonSeekable = true
+
+ return fs.NewFile(ctx, dirent, flags, &zeroFileOperations{}), nil
+}
+
+// +stateify savable
+type zeroFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+ readZeros `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*zeroFileOperations)(nil)
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (*zeroFileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return err
+ }
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ return nil
+}
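+
+// This mirrors Linux, where a MAP_SHARED mapping of /dev/zero yields shared
+// anonymous memory. A hedged application-side sketch (assuming the
+// golang.org/x/sys/unix package):
+//
+//	f, _ := os.Open("/dev/zero")
+//	mem, _ := unix.Mmap(int(f.Fd()), 0, 4096,
+//		unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
+//	// mem is zero-filled and backed by a shared anonymous mappable.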
diff --git a/pkg/sentry/fs/dev/random.go b/pkg/sentry/fs/dev/random.go
new file mode 100644
index 000000000..2a9bbeb18
--- /dev/null
+++ b/pkg/sentry/fs/dev/random.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type randomDevice struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*randomDevice)(nil)
+
+func newRandomDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *randomDevice {
+ r := &randomDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+ return r
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (*randomDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &randomFileOperations{}), nil
+}
+
+// +stateify savable
+type randomFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*randomFileOperations)(nil)
+
+// Read implements fs.FileOperations.Read.
+func (*randomFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+}
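+
+// Note that unlike Linux's /dev/random, reads here never block: the file is
+// waiter.AlwaysReady and bytes are served from the sentry's CSPRNG. A sketch:
+//
+//	f, _ := os.Open("/dev/random")
+//	buf := make([]byte, 32)
+//	n, _ := f.Read(buf) // Returns promptly; typically n == 32.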
diff --git a/pkg/sentry/fs/dev/tty.go b/pkg/sentry/fs/dev/tty.go
new file mode 100644
index 000000000..760ca563d
--- /dev/null
+++ b/pkg/sentry/fs/dev/tty.go
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type ttyInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotOpenable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*ttyInodeOperations)(nil)
+
+func newTTYDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *ttyInodeOperations {
+ return &ttyInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// +stateify savable
+type ttyFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNoopRead `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*ttyFileOperations)(nil)
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
new file mode 100644
index 000000000..65be12175
--- /dev/null
+++ b/pkg/sentry/fs/dirent.go
@@ -0,0 +1,1558 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "path"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type globalDirentMap struct {
+ mu sync.Mutex
+ dirents map[*Dirent]struct{}
+}
+
+func (g *globalDirentMap) add(d *Dirent) {
+ g.mu.Lock()
+ g.dirents[d] = struct{}{}
+ g.mu.Unlock()
+}
+
+func (g *globalDirentMap) remove(d *Dirent) {
+ g.mu.Lock()
+ delete(g.dirents, d)
+ g.mu.Unlock()
+}
+
+// allDirents keeps track of all Dirents that need to be considered in
+// Save/Restore for inode mappings.
+//
+// Because inodes do not hold paths, but inodes for external file systems map
+// to an external path, every user-visible Dirent is stored in this map and
+// iterated through upon save to keep inode ID -> restore path mappings.
+var allDirents = globalDirentMap{
+ dirents: map[*Dirent]struct{}{},
+}
+
+// renameMu protects the parent of *all* Dirents. (See explanation in
+// lockForRename.)
+//
+// See fs.go for lock ordering.
+var renameMu sync.RWMutex
+
+// Dirent holds an Inode in memory.
+//
+// A Dirent may be negative or positive:
+//
+// A negative Dirent contains a nil Inode and indicates that a path does not exist. This
+// is a convention taken from the Linux dcache, see fs/dcache.c. A negative Dirent remains
+// cached until a create operation replaces it with a positive Dirent. A negative Dirent
+// always has one reference owned by its parent and takes _no_ reference on its parent. This
+// ensures that its parent can be unhashed regardless of negative children.
+//
+// A positive Dirent contains a non-nil Inode. It remains cached for as long as there remain
+// references to it. A positive Dirent always takes a reference on its parent.
+//
+// A Dirent may be a root Dirent (parent is nil) or be parented (non-nil parent).
+//
+// Dirents currently do not attempt to free entries that lack application references under
+// memory pressure.
+//
+// +stateify savable
+type Dirent struct {
+ // AtomicRefCount is our reference count.
+ refs.AtomicRefCount
+
+ // userVisible indicates whether the Dirent is visible to the user or
+ // not. Only user-visible Dirents should save inode mappings in
+ // save/restore, as only they hold the real path to the underlying
+ // inode.
+ //
+ // See newDirent and Dirent.afterLoad.
+ userVisible bool
+
+ // Inode is the underlying file object.
+ //
+ // Inode is exported currently to assist in implementing overlay Inodes (where a
+ // Inode.InodeOperations.Lookup may need to merge the Inode contained in a positive Dirent with
+ // another Inode). This is normally done before the Dirent is parented (there are
+ // no external references to it).
+ //
+ // Other objects in the VFS may take a reference to this Inode but only while holding
+ // a reference to this Dirent.
+ Inode *Inode
+
+ // name is the name (i.e. basename) of this entry.
+ //
+ // N.B. name is protected by parent.mu, not this node's mu!
+ name string
+
+ // parent is the parent directory.
+ //
+ // We hold a hard reference to the parent.
+ //
+ // parent is protected by renameMu.
+ parent *Dirent
+
+ // deleted may be set atomically when removed.
+ deleted int32
+
+ // mounted is true if Dirent is a mount point, similar to include/linux/dcache.h:DCACHE_MOUNTED.
+ mounted bool
+
+ // direntEntry identifies this Dirent as an element in a DirentCache. DirentCaches
+ // and their contents are not saved.
+ direntEntry `state:"nosave"`
+
+ // dirMu is a read-write mutex that protects caching decisions made by directory operations.
+ // Lock ordering: dirMu must be taken before mu (see below). Details:
+ //
+ // dirMu does not participate in Rename; instead mu and renameMu are used, see lockForRename.
+ //
+ // Creation and Removal operations must be synchronized with Walk to prevent stale negative
+ // caching. Note that this requirement is not specific to a _Dirent_ doing negative caching.
+ // The following race exists at any level of the VFS:
+ //
+ // For an object D that represents a directory, containing a cache of non-existent paths,
+ // protected by D.cacheMu:
+ //
+ //   T1:                          T2:
+ //                                D.lookup(name)
+ //                                --> ENOENT
+ //   D.create(name)
+ //   --> success
+ //   D.cacheMu.Lock
+ //   delete(D.cache, name)
+ //   D.cacheMu.Unlock
+ //                                D.cacheMu.Lock
+ //                                D.cache[name] = true
+ //                                D.cacheMu.Unlock
+ //
+ //   D.lookup(name)
+ //     D.cacheMu.Lock
+ //     if D.cache[name] {
+ //       --> ENOENT (wrong)
+ //     }
+ //     D.cacheMu.Unlock
+ //
+ // Correct:
+ //
+ //   T1:                          T2:
+ //                                D.cacheMu.Lock
+ //                                D.lookup(name)
+ //                                --> ENOENT
+ //                                D.cache[name] = true
+ //                                D.cacheMu.Unlock
+ //   D.cacheMu.Lock
+ //   D.create(name)
+ //   --> success
+ //   delete(D.cache, name)
+ //   D.cacheMu.Unlock
+ //
+ //   D.cacheMu.Lock
+ //   D.lookup(name)
+ //   --> EXISTS (right)
+ //   D.cacheMu.Unlock
+ //
+ // Note that the above "correct" solution causes too much lock contention: all lookups are
+ // synchronized with each other. This is a problem because lookups are involved in any VFS
+ // path operation.
+ //
+ // A Dirent diverges from the single D.cacheMu and instead uses two locks: dirMu to protect
+ // concurrent creation/removal/lookup caching, and mu to protect the Dirent's children map
+ // in general.
+ //
+ // This allows concurrent Walks to be executed in order to pipeline lookups. For instance,
+ // for a hot directory /a/b, threads T1, T2, T3 will only block on each other when updating
+ // the children map of /a/b, once their individual lookups complete.
+ //
+ // T1: T2: T3:
+ // stat(/a/b/c) stat(/a/b/d) stat(/a/b/e)
+ dirMu sync.RWMutex `state:"nosave"`
+
+ // mu protects the below fields. Lock ordering: mu must be taken after dirMu.
+ mu sync.Mutex `state:"nosave"`
+
+ // children are cached via weak references.
+ children map[string]*refs.WeakRef `state:".(map[string]*Dirent)"`
+}
+
+// NewDirent returns a new root Dirent, taking the caller's reference on inode. The caller
+// holds the only reference to the Dirent. Parents may call hashChild to parent this Dirent.
+func NewDirent(ctx context.Context, inode *Inode, name string) *Dirent {
+ d := newDirent(inode, name)
+ allDirents.add(d)
+ d.userVisible = true
+ return d
+}
+
+// NewTransientDirent creates a transient Dirent that shouldn't actually be
+// visible to users.
+//
+// An Inode is required.
+func NewTransientDirent(inode *Inode) *Dirent {
+ if inode == nil {
+ panic("an inode is required")
+ }
+ return newDirent(inode, "transient")
+}
+
+func newDirent(inode *Inode, name string) *Dirent {
+ // The Dirent needs to maintain one reference to MountSource.
+ if inode != nil {
+ inode.MountSource.IncDirentRefs()
+ }
+ d := Dirent{
+ Inode: inode,
+ name: name,
+ children: make(map[string]*refs.WeakRef),
+ }
+ d.EnableLeakCheck("fs.Dirent")
+ return &d
+}
+
+// NewNegativeDirent returns a new root negative Dirent. Otherwise same as NewDirent.
+func NewNegativeDirent(name string) *Dirent {
+ return newDirent(nil, name)
+}
+
+// IsRoot returns true if d is a root Dirent.
+func (d *Dirent) IsRoot() bool {
+ return d.parent == nil
+}
+
+// IsNegative returns true if d represents a path that does not exist.
+func (d *Dirent) IsNegative() bool {
+ return d.Inode == nil
+}
+
+// hashChild will hash child into the children list of its new parent d.
+//
+// Returns (*WeakRef, true) if hashing child caused a Dirent to be unhashed. The caller must
+// validate the returned unhashed weak reference. Common cases:
+//
+// * Remove: hashing a negative Dirent unhashes a positive Dirent (unimplemented).
+// * Create: hashing a positive Dirent unhashes a negative Dirent.
+// * Lookup: hashing any Dirent should not unhash any other Dirent.
+//
+// Preconditions:
+// * d.mu must be held.
+// * child must be a root Dirent.
+func (d *Dirent) hashChild(child *Dirent) (*refs.WeakRef, bool) {
+ if !child.IsRoot() {
+ panic("hashChild must be a root Dirent")
+ }
+
+ // Assign parentage.
+ child.parent = d
+
+ // Avoid letting negative Dirents take a reference on their parent; these Dirents
+ // don't have a role outside of the Dirent cache and should not keep their parent
+ // indefinitely pinned.
+ if !child.IsNegative() {
+ // Positive dirents must take a reference on their parent.
+ d.IncRef()
+ }
+
+ return d.hashChildParentSet(child)
+}
+
+// hashChildParentSet will rehash child into the children list of its parent d.
+//
+// Assumes that child.parent = d already.
+func (d *Dirent) hashChildParentSet(child *Dirent) (*refs.WeakRef, bool) {
+ if child.parent != d {
+ panic("hashChildParentSet assumes the child already belongs to the parent")
+ }
+
+ // Save any replaced child so our caller can validate it.
+ old, ok := d.children[child.name]
+
+ // Hash the child.
+ d.children[child.name] = refs.NewWeakRef(child, nil)
+
+ // Return any replaced child.
+ return old, ok
+}
+
+// SyncAll iterates through mount points under d and writes back their buffered
+// modifications to filesystems.
+func (d *Dirent) SyncAll(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // For negative Dirents there is nothing to sync. By definition these are
+ // leaves (there is nothing left to traverse).
+ if d.IsNegative() {
+ return
+ }
+
+ // There is nothing to sync for a read-only filesystem.
+ if !d.Inode.MountSource.Flags.ReadOnly {
+ // NOTE(b/34856369): This should be a mount traversal, not a Dirent
+ // traversal, because some Inodes that need to be synced may no longer
+ // be reachable by name (after sys_unlink).
+ //
+ // Write out metadata, dirty page cached pages, and sync disk/remote
+ // caches.
+ d.Inode.WriteOut(ctx)
+ }
+
+ // Continue iterating through other mounted filesystems.
+ for _, w := range d.children {
+ if child := w.Get(); child != nil {
+ child.(*Dirent).SyncAll(ctx)
+ child.DecRef()
+ }
+ }
+}
+
+// BaseName returns the base name of the dirent.
+func (d *Dirent) BaseName() string {
+ p := d.parent
+ if p == nil {
+ return d.name
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return d.name
+}
+
+// FullName returns the fully-qualified name and a boolean value representing
+// whether this Dirent was a descendant of root.
+// If the root argument is nil it is assumed to be the root of the Dirent tree.
+func (d *Dirent) FullName(root *Dirent) (string, bool) {
+ renameMu.RLock()
+ defer renameMu.RUnlock()
+ return d.fullName(root)
+}
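+
+// As a worked example: for a Dirent reached from root via "a" then "b",
+// FullName(root) returns ("/a/b", true); if root is not an ancestor, the
+// returned boolean is false and the name is built from this Dirent's own
+// tree root instead.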
+
+// fullName returns the fully-qualified name and a boolean value representing
+// if the root node was reachable from this Dirent.
+func (d *Dirent) fullName(root *Dirent) (string, bool) {
+ if d == root {
+ return "/", true
+ }
+
+ if d.IsRoot() {
+ if root != nil {
+ // We reached the top of the Dirent tree but did not encounter
+ // the given root. Return false for reachable so the caller
+ // can handle this situation accordingly.
+ return d.name, false
+ }
+ return d.name, true
+ }
+
+ // Traverse up to parent.
+ d.parent.mu.Lock()
+ name := d.name
+ d.parent.mu.Unlock()
+ parentName, reachable := d.parent.fullName(root)
+ s := path.Join(parentName, name)
+ if atomic.LoadInt32(&d.deleted) != 0 {
+ return s + " (deleted)", reachable
+ }
+ return s, reachable
+}
+
+// MountRoot finds and returns the mount-root for a given dirent.
+func (d *Dirent) MountRoot() *Dirent {
+ renameMu.RLock()
+ defer renameMu.RUnlock()
+
+ mountRoot := d
+ for !mountRoot.mounted && mountRoot.parent != nil {
+ mountRoot = mountRoot.parent
+ }
+ mountRoot.IncRef()
+ return mountRoot
+}
+
+// descendantOf returns true if the receiver dirent is equal to, or a
+// descendant of, the argument dirent.
+//
+// d.mu must be held.
+func (d *Dirent) descendantOf(p *Dirent) bool {
+ if d == p {
+ return true
+ }
+ if d.IsRoot() {
+ return false
+ }
+ return d.parent.descendantOf(p)
+}
+
+// walk walks to path name starting at the dirent, and will not traverse above
+// the root Dirent.
+//
+// If walkMayUnlock is true then walk can unlock d.mu to execute a slow
+// Inode.Lookup, otherwise walk will keep d.mu locked.
+//
+// Preconditions:
+// - renameMu must be held for reading.
+// - d.mu must be held.
+// - name must not contain "/"s.
+func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnlock bool) (*Dirent, error) {
+ if !IsDir(d.Inode.StableAttr) {
+ return nil, syscall.ENOTDIR
+ }
+
+ if name == "" || name == "." {
+ d.IncRef()
+ return d, nil
+ } else if name == ".." {
+ // Respect the chroot. Note that in Linux there is no check to enforce
+ // that d is a descendant of root.
+ if d == root {
+ d.IncRef()
+ return d, nil
+ }
+ // Are we already at the root? Then ".." is ".".
+ if d.IsRoot() {
+ d.IncRef()
+ return d, nil
+ }
+ d.parent.IncRef()
+ return d.parent, nil
+ }
+
+ if w, ok := d.children[name]; ok {
+ // Try to resolve the weak reference to a hard reference.
+ if child := w.Get(); child != nil {
+ cd := child.(*Dirent)
+
+ // Is this a negative Dirent?
+ if cd.IsNegative() {
+ // Don't leak a reference; this doesn't matter as much for negative Dirents,
+ // which don't hold a hard reference on their parent (their parent holds a
+ // hard reference on them, and they contain virtually no state). But this is
+ // good house-keeping.
+ child.DecRef()
+ return nil, syscall.ENOENT
+ }
+
+ // Do we need to revalidate this child?
+ //
+ // We never allow the file system to revalidate mounts; that could cause them
+ // to unexpectedly drop out before umount.
+ if cd.mounted || !cd.Inode.MountSource.Revalidate(ctx, name, d.Inode, cd.Inode) {
+ // Good to go. This is the fast-path.
+ return cd, nil
+ }
+
+ // If we're revalidating a child, we must ensure all inotify watches release
+ // their pins on the child. Inotify doesn't properly support filesystems that
+ // revalidate dirents (since watches are lost on revalidation), but if we fail
+ // to unpin the watches child will never be GCed.
+ cd.Inode.Watches.Unpin(cd)
+
+ // This child needs to be revalidated, fallthrough to unhash it. Make sure
+ // to not leak a reference from Get().
+ //
+ // Note that previous lookups may still have a reference to this stale child;
+ // this can't be helped, but we can ensure that *new* lookups are up-to-date.
+ child.DecRef()
+ }
+
+ // Either our weak reference expired or we need to revalidate it. Unhash child first, we're
+ // about to replace it.
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
+ // expensive, if possible release the lock and re-acquire it.
+ if walkMayUnlock {
+ d.mu.Unlock()
+ }
+ c, err := d.Inode.Lookup(ctx, name)
+ if walkMayUnlock {
+ d.mu.Lock()
+ }
+ // No dice.
+ if err != nil {
+ return nil, err
+ }
+
+ // Sanity check c: its name must be consistent.
+ if c.name != name {
+ panic(fmt.Sprintf("lookup from %q to %q returned unexpected name %q", d.name, name, c.name))
+ }
+
+ // Now that we have the lock again, check if we raced.
+ if w, ok := d.children[name]; ok {
+ // Someone else looked up or created a child at name before us.
+ if child := w.Get(); child != nil {
+ cd := child.(*Dirent)
+
+ // There are active references to the existing child, prefer it to the one we
+ // retrieved from Lookup. Likely the Lookup happened very close to the insertion
+ // of child, so considering one stale over the other is fairly arbitrary.
+ c.DecRef()
+
+ // The child that was installed could be negative.
+ if cd.IsNegative() {
+ // If so, don't leak a reference and short circuit.
+ child.DecRef()
+ return nil, syscall.ENOENT
+ }
+
+ // We make the judgement call that if c raced with cd they are close enough to have
+ // the same staleness, so we don't attempt to revalidate cd. In Linux revalidations
+ // can continue indefinitely (see fs/namei.c, retry_estale); we try to avoid this.
+ return cd, nil
+ }
+
+ // Weak reference expired. We went through a full cycle of create/destroy in the time
+ // we did the Inode.Lookup. Fully drop the weak reference and fall back to using the child
+ // we looked up.
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Give the looked up child a parent. We cannot kick out entries, since we just checked above
+ // that there is nothing at name in d's children list.
+ if _, kicked := d.hashChild(c); kicked {
+ // Yell loudly.
+ panic(fmt.Sprintf("hashed child %q over existing child", c.name))
+ }
+
+ // Is this a negative Dirent?
+ if c.IsNegative() {
+ // Don't drop a reference on the negative Dirent, it was just installed and this is the
+ // only reference we'll ever get. d owns the reference.
+ return nil, syscall.ENOENT
+ }
+
+ // Return the positive Dirent.
+ return c, nil
+}
+
+// Walk walks to a new dirent, and will not walk higher than the given root
+// Dirent, which must not be nil.
+func (d *Dirent) Walk(ctx context.Context, root *Dirent, name string) (*Dirent, error) {
+ if root == nil {
+ panic("Dirent.Walk: root must not be nil")
+ }
+
+ // We could use lockDirectory here, but this is a hot path and we want
+ // to avoid defer.
+ renameMu.RLock()
+ d.dirMu.RLock()
+ d.mu.Lock()
+
+ child, err := d.walk(ctx, root, name, true /* may unlock */)
+
+ d.mu.Unlock()
+ d.dirMu.RUnlock()
+ renameMu.RUnlock()
+
+ return child, err
+}
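+
+// The caller owns a reference on the returned Dirent; a typical usage sketch:
+//
+//	child, err := d.Walk(ctx, root, "etc")
+//	if err != nil {
+//		return err
+//	}
+//	defer child.DecRef()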
+
+// exists returns true if name exists in relation to d.
+//
+// Preconditions:
+// - renameMu must be held for reading.
+// - d.mu must be held.
+// - name must not contain "/"s.
+func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool {
+ child, err := d.walk(ctx, root, name, false /* may unlock */)
+ if err != nil {
+ // Child may not exist.
+ return false
+ }
+ // Child exists.
+ child.DecRef()
+ return true
+}
+
+// lockDirectory should be called for any operation that changes d's
+// children (creating or removing them).
+func (d *Dirent) lockDirectory() func() {
+ renameMu.RLock()
+ d.dirMu.Lock()
+ d.mu.Lock()
+ return func() {
+ d.mu.Unlock()
+ d.dirMu.Unlock()
+ renameMu.RUnlock()
+ }
+}
+
+// Create creates a new regular file in this directory.
+func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags FileFlags, perms FilePermissions) (*File, error) {
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Does something already exist?
+ if d.exists(ctx, root, name) {
+ return nil, syscall.EEXIST
+ }
+
+ // Try the create. We need to trust the file system to return EEXIST (or something
+ // that will translate to EEXIST) if name already exists.
+ file, err := d.Inode.Create(ctx, d, name, flags, perms)
+ if err != nil {
+ return nil, err
+ }
+ child := file.Dirent
+
+ d.finishCreate(child, name)
+
+ // Return the reference and the new file. When the last reference to
+ // the file is dropped, file.Dirent may no longer be cached.
+ return file, nil
+}
+
+// finishCreate validates the created file, adds it as a child of this dirent,
+// and notifies any watchers.
+func (d *Dirent) finishCreate(child *Dirent, name string) {
+ // Sanity check child: its name must be consistent.
+ if child.name != name {
+ panic(fmt.Sprintf("create from %q to %q returned unexpected name %q", d.name, name, child.name))
+ }
+
+ // File systems cannot return a negative Dirent on Create, that makes no sense.
+ if child.IsNegative() {
+ panic(fmt.Sprintf("create from %q to %q returned negative Dirent", d.name, name))
+ }
+
+ // Hash the child into its parent. We can only kick out a Dirent if it is negative
+ // (we are replacing something that does not exist with something that now does).
+ if w, kicked := d.hashChild(child); kicked {
+ if old := w.Get(); old != nil {
+ if !old.(*Dirent).IsNegative() {
+ panic(fmt.Sprintf("hashed child %q over a positive child", child.name))
+ }
+ // Don't leak a reference.
+ old.DecRef()
+
+ // Drop d's reference.
+ old.DecRef()
+ }
+
+ // Finally drop the useless weak reference on the floor.
+ w.Drop()
+ }
+
+ d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
+
+ // Allow the file system to take extra references on child.
+ child.maybeExtendReference()
+}
+
+// genericCreate executes create if name does not exist. Any negative Dirent cached at
+// name is removed before create is executed.
+func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, create func() error) error {
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Does something already exist?
+ if d.exists(ctx, root, name) {
+ return syscall.EEXIST
+ }
+
+ // Remove any negative Dirent. We've already asserted above with d.exists
+ // that the only thing remaining here can be a negative Dirent.
+ if w, ok := d.children[name]; ok {
+ // Same as Create.
+ if old := w.Get(); old != nil {
+ if !old.(*Dirent).IsNegative() {
+ panic(fmt.Sprintf("hashed over a positive child %q", old.(*Dirent).name))
+ }
+ // Don't leak a reference.
+ old.DecRef()
+
+ // Drop d's reference.
+ old.DecRef()
+ }
+
+ // Unhash the negative Dirent, name needs to exist now.
+ delete(d.children, name)
+
+ // Finally drop the useless weak reference on the floor.
+ w.Drop()
+ }
+
+ // Execute the create operation.
+ return create()
+}
+
+// CreateLink creates a new link in this directory.
+func (d *Dirent) CreateLink(ctx context.Context, root *Dirent, oldname, newname string) error {
+ return d.genericCreate(ctx, root, newname, func() error {
+ if err := d.Inode.CreateLink(ctx, d, oldname, newname); err != nil {
+ return err
+ }
+ d.Inode.Watches.Notify(newname, linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// CreateHardLink creates a new hard link in this directory.
+func (d *Dirent) CreateHardLink(ctx context.Context, root *Dirent, target *Dirent, name string) error {
+ // Make sure that target does not span filesystems.
+ if d.Inode.MountSource != target.Inode.MountSource {
+ return syscall.EXDEV
+ }
+
+ // Directories are never linkable. See fs/namei.c:vfs_link.
+ if IsDir(target.Inode.StableAttr) {
+ return syscall.EPERM
+ }
+
+ return d.genericCreate(ctx, root, name, func() error {
+ if err := d.Inode.CreateHardLink(ctx, d, target, name); err != nil {
+ return err
+ }
+ target.Inode.Watches.Notify("", linux.IN_ATTRIB, 0) // Link count change.
+ d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// CreateDirectory creates a new directory under this dirent.
+func (d *Dirent) CreateDirectory(ctx context.Context, root *Dirent, name string, perms FilePermissions) error {
+ return d.genericCreate(ctx, root, name, func() error {
+ if err := d.Inode.CreateDirectory(ctx, d, name, perms); err != nil {
+ return err
+ }
+ d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// Bind creates a socket file at name in this directory, bound to the given endpoint data.
+func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data transport.BoundEndpoint, perms FilePermissions) (*Dirent, error) {
+ var childDir *Dirent
+ err := d.genericCreate(ctx, root, name, func() error {
+ var e error
+ childDir, e = d.Inode.Bind(ctx, d, name, data, perms)
+ if e != nil {
+ return e
+ }
+ d.finishCreate(childDir, name)
+ return nil
+ })
+ if err == syscall.EEXIST {
+ return nil, syscall.EADDRINUSE
+ }
+ if err != nil {
+ return nil, err
+ }
+ return childDir, err
+}
+
+// CreateFifo creates a new named pipe under this dirent.
+func (d *Dirent) CreateFifo(ctx context.Context, root *Dirent, name string, perms FilePermissions) error {
+ return d.genericCreate(ctx, root, name, func() error {
+ if err := d.Inode.CreateFifo(ctx, d, name, perms); err != nil {
+ return err
+ }
+ d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// GetDotAttrs returns the DentAttrs corresponding to "." and ".." directories.
+func (d *Dirent) GetDotAttrs(root *Dirent) (DentAttr, DentAttr) {
+ // Get '.'.
+ sattr := d.Inode.StableAttr
+ dot := DentAttr{
+ Type: sattr.Type,
+ InodeID: sattr.InodeID,
+ }
+
+ // Hold d.mu while we call d.descendantOf.
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Get '..'.
+ if !d.IsRoot() && d.descendantOf(root) {
+ // Dirent is a descendant of the root. Get its parent's attrs.
+ psattr := d.parent.Inode.StableAttr
+ dotdot := DentAttr{
+ Type: psattr.Type,
+ InodeID: psattr.InodeID,
+ }
+ return dot, dotdot
+ }
+ // Dirent is either root or not a descendant of the root. ".." is the
+ // same as ".".
+ return dot, dot
+}
+
+// DirIterator is an open directory containing directory entries that can be read.
+type DirIterator interface {
+ // IterateDir emits directory entries by calling dirCtx.EmitDir, beginning
+ // with the entry at offset and returning the next directory offset.
+ //
+ // Entries for "." and ".." must *not* be included.
+ //
+ // If the offset returned is the same as the argument offset, then
+ // nothing has been serialized. This is equivalent to reaching EOF.
+ // In this case serializer.Written() should return 0.
+ //
+ // The order of entries to emit must be consistent between Readdir
+ // calls, and must start with the given offset.
+ //
+ // The caller must ensure that this operation is permitted.
+ IterateDir(ctx context.Context, d *Dirent, dirCtx *DirCtx, offset int) (int, error)
+}
+
+// DirentReaddir serializes the directory entries of d including "." and "..".
+//
+// Arguments:
+//
+// * d: the Dirent of the directory being read; required to provide "." and "..".
+// * it: the directory iterator, which represents an open directory handle.
+// * root: fs root; if d is equal to the root, then '..' will refer to d.
+// * ctx: context provided to file systems in order to select and serialize entries.
+// * offset: the current directory offset.
+//
+// Returns the offset of the *next* element which was not serialized.
+func DirentReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent, dirCtx *DirCtx, offset int64) (int64, error) {
+ offset, err := direntReaddir(ctx, d, it, root, dirCtx, offset)
+ // Serializing any directory entries at all means success.
+ if dirCtx.Serializer.Written() > 0 {
+ return offset, nil
+ }
+ return offset, err
+}
+
+func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent, dirCtx *DirCtx, offset int64) (int64, error) {
+ if root == nil {
+ panic("Dirent.Readdir: root must not be nil")
+ }
+ if dirCtx.Serializer == nil {
+ panic("Dirent.Readdir: serializer must not be nil")
+ }
+
+ // Check that this is actually a directory before emitting anything.
+ // Once we have written entries for "." and "..", future errors from
+ // IterateDir will be hidden.
+ if !IsDir(d.Inode.StableAttr) {
+ return 0, syserror.ENOTDIR
+ }
+
+ // This is a special case for lseek(fd, 0, SEEK_END).
+ // See SeekWithDirCursor for more details.
+ if offset == FileMaxOffset {
+ return offset, nil
+ }
+
+ // Collect attrs for "." and "..".
+ dot, dotdot := d.GetDotAttrs(root)
+
+ // Emit "." and ".." if the offset is low enough.
+ if offset == 0 {
+ // Serialize ".".
+ if err := dirCtx.DirEmit(".", dot); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ if offset == 1 {
+ // Serialize "..".
+ if err := dirCtx.DirEmit("..", dotdot); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+
+ // it.IterateDir should be passed an offset that does not include the
+ // initial dot elements. We will add them back later.
+ offset -= 2
+ newOffset, err := it.IterateDir(ctx, d, dirCtx, int(offset))
+ if int64(newOffset) < offset {
+ panic(fmt.Sprintf("node.Readdir returned offset %v less than input offset %v", newOffset, offset))
+ }
+ // Add the initial nodes back to the offset count.
+ newOffset += 2
+ return int64(newOffset), err
+}
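+
+// As a worked example of the offset accounting: for a directory containing
+// entries "x" and "y", the serialized stream is
+//
+//	offset 0: "."   (emitted above)
+//	offset 1: ".."  (emitted above)
+//	offset 2: "x"   (IterateDir offset 0)
+//	offset 3: "y"   (IterateDir offset 1)
+//
+// so a call with offset == 2 passes 0 to IterateDir and, once both remaining
+// entries are emitted, returns 4.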
+
+// flush flushes all weak references recursively, and removes any cached
+// references to children.
+//
+// Preconditions: d.mu must be held.
+func (d *Dirent) flush() {
+ expired := make(map[string]*refs.WeakRef)
+ for n, w := range d.children {
+ // Call flush recursively on each child before removing our
+ // reference on it, and removing the cache's reference.
+ if child := w.Get(); child != nil {
+ cd := child.(*Dirent)
+
+ if !cd.IsNegative() {
+ // Flush the child.
+ cd.mu.Lock()
+ cd.flush()
+ cd.mu.Unlock()
+
+ // Allow the file system to drop extra references on child.
+ cd.dropExtendedReference()
+ }
+
+ // Don't leak a reference.
+ child.DecRef()
+ }
+ // Check if the child dirent is closed, and mark it as expired if it is.
+ // We must call w.Get() again here, since the child could have been
+ // destroyed by the flush() and dropExtendedReference() calls in the
+ // above if-block.
+ if child := w.Get(); child != nil {
+ child.DecRef()
+ } else {
+ expired[n] = w
+ }
+ }
+
+ // Remove expired entries.
+ for n, w := range expired {
+ delete(d.children, n)
+ w.Drop()
+ }
+}
+
+// isMountPoint returns true if the dirent is a mount point or the root.
+func (d *Dirent) isMountPoint() bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.isMountPointLocked()
+}
+
+func (d *Dirent) isMountPointLocked() bool {
+ return d.mounted || d.parent == nil
+}
+
+// mount mounts a new dirent with the given inode over d.
+//
+// Precondition: must be called with mm.withMountLocked held on `d`.
+func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err error) {
+ // Did we race with deletion?
+ if atomic.LoadInt32(&d.deleted) != 0 {
+ return nil, syserror.ENOENT
+ }
+
+ // Refuse to mount a symlink.
+ //
+ // See Linux equivalent in fs/namespace.c:do_add_mount.
+ if IsSymlink(inode.StableAttr) {
+ return nil, syserror.EINVAL
+ }
+
+ // Dirent that'll replace d.
+ //
+ // Note that NewDirent returns with one reference taken; the reference
+ // is donated to the caller as the mount reference.
+ replacement := NewDirent(ctx, inode, d.name)
+ replacement.mounted = true
+
+ weakRef, ok := d.parent.hashChild(replacement)
+ if !ok {
+ panic("mount must mount over an existing dirent")
+ }
+ weakRef.Drop()
+
+ // Note that even though `d` is now hidden, it still holds a reference
+ // to its parent.
+ return replacement, nil
+}
+
+// unmount unmounts `d` and replaces it with the last Dirent that was in its
+// place, supplied by the MountNamespace as `replacement`.
+//
+// Precondition: must be called with mm.withMountLocked held on `d`.
+func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
+ // Did we race with deletion?
+ if atomic.LoadInt32(&d.deleted) != 0 {
+ return syserror.ENOENT
+ }
+
+ // Remount our former child in its place.
+ //
+ // As replacement used to be our child, it must already have the right
+ // parent.
+ weakRef, ok := d.parent.hashChildParentSet(replacement)
+ if !ok {
+ panic("mount must mount over an existing dirent")
+ }
+ weakRef.Drop()
+
+ // d is not reachable anymore, and hence not mounted anymore.
+ d.mounted = false
+
+ // Drop mount reference.
+ d.DecRef()
+ return nil
+}
+
+// Remove removes the given file or symlink. The root dirent is used to
+// resolve name, and must not be nil.
+func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error {
+ // Check the root.
+ if root == nil {
+ panic("Dirent.Remove: root must not be nil")
+ }
+
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Try to walk to the node.
+ child, err := d.walk(ctx, root, name, false /* may unlock */)
+ if err != nil {
+ // Child does not exist.
+ return err
+ }
+ defer child.DecRef()
+
+ // Remove cannot remove directories.
+ if IsDir(child.Inode.StableAttr) {
+ return syscall.EISDIR
+ } else if dirPath {
+ return syscall.ENOTDIR
+ }
+
+ // Remove cannot remove a mount point.
+ if child.isMountPoint() {
+ return syscall.EBUSY
+ }
+
+ // Try to remove name on the file system.
+ if err := d.Inode.Remove(ctx, d, child); err != nil {
+ return err
+ }
+
+ // Link count changed, this only applies to non-directory nodes.
+ child.Inode.Watches.Notify("", linux.IN_ATTRIB, 0)
+
+ // Mark name as deleted and remove from children.
+ atomic.StoreInt32(&child.deleted, 1)
+ if w, ok := d.children[name]; ok {
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Allow the file system to drop extra references on child.
+ child.dropExtendedReference()
+
+ // Finally, let inotify know the child is being unlinked. Drop any extra
+ // refs from inotify to this child dirent. This doesn't necessarily mean the
+ // watches on the underlying inode will be destroyed, since the underlying
+ // inode may have other links. If this was the last link, the events for the
+ // watch removal will be queued by the inode destructor.
+ child.Inode.Watches.MarkUnlinked()
+ child.Inode.Watches.Unpin(child)
+ d.Inode.Watches.Notify(name, linux.IN_DELETE, 0)
+
+ return nil
+}
+
+// RemoveDirectory removes the given directory. The root dirent is used to
+// resolve name, and must not be nil.
+func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) error {
+ // Check the root.
+ if root == nil {
+ panic("Dirent.Remove: root must not be nil")
+ }
+
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Check for dots.
+ if name == "." {
+ // Rejected as the last component by rmdir(2).
+ return syscall.EINVAL
+ }
+ if name == ".." {
+ // If d was found, then its parent is not empty.
+ return syscall.ENOTEMPTY
+ }
+
+ // Try to walk to the node.
+ child, err := d.walk(ctx, root, name, false /* may unlock */)
+ if err != nil {
+ // Child does not exist.
+ return err
+ }
+ defer child.DecRef()
+
+ // RemoveDirectory can only remove directories.
+ if !IsDir(child.Inode.StableAttr) {
+ return syscall.ENOTDIR
+ }
+
+ // Remove cannot remove a mount point.
+ if child.isMountPoint() {
+ return syscall.EBUSY
+ }
+
+ // Try to remove name on the file system.
+ if err := d.Inode.Remove(ctx, d, child); err != nil {
+ return err
+ }
+
+ // Mark name as deleted and remove from children.
+ atomic.StoreInt32(&child.deleted, 1)
+ if w, ok := d.children[name]; ok {
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Allow the file system to drop extra references on child.
+ child.dropExtendedReference()
+
+ // Finally, let inotify know the child is being unlinked. Drop any extra
+ // refs from inotify to this child dirent.
+ child.Inode.Watches.MarkUnlinked()
+ child.Inode.Watches.Unpin(child)
+ d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_DELETE, 0)
+
+ return nil
+}
+
+// destroy closes this node and all children.
+func (d *Dirent) destroy() {
+ if d.IsNegative() {
+ // Nothing to tear down and no parent reference to drop, since a negative
+ // Dirent takes no reference on its parent and has no Inode or children.
+ return
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Drop all weak references.
+ for _, w := range d.children {
+ if c := w.Get(); c != nil {
+ if c.(*Dirent).IsNegative() {
+ // The parent holds both weak and strong refs in the case of
+ // negative dirents.
+ c.DecRef()
+ }
+ // Drop the reference we just acquired in WeakRef.Get.
+ c.DecRef()
+ }
+ w.Drop()
+ }
+ d.children = nil
+
+ allDirents.remove(d)
+
+ // Drop our reference to the Inode.
+ d.Inode.DecRef()
+
+ // Allow the Dirent to be GC'ed after this point, since the Inode may still
+ // be referenced after the Dirent is destroyed (for instance by filesystem
+ // internal caches or hard links).
+ d.Inode = nil
+
+ // Drop the reference we hold on our parent, if we took one. renameMu
+ // doesn't need to be held because d can't be reparented once no
+ // references to it remain.
+ if d.parent != nil {
+ d.parent.DecRef()
+ }
+}
+
+// IncRef increases the Dirent's refcount as well as its mount's refcount.
+//
+// IncRef implements RefCounter.IncRef.
+func (d *Dirent) IncRef() {
+ if d.Inode != nil {
+ d.Inode.MountSource.IncDirentRefs()
+ }
+ d.AtomicRefCount.IncRef()
+}
+
+// TryIncRef implements RefCounter.TryIncRef.
+func (d *Dirent) TryIncRef() bool {
+ ok := d.AtomicRefCount.TryIncRef()
+ if ok && d.Inode != nil {
+ d.Inode.MountSource.IncDirentRefs()
+ }
+ return ok
+}
+
+// DecRef decreases the Dirent's refcount and drops its reference on its mount.
+//
+// DecRef implements RefCounter.DecRef with destructor d.destroy.
+func (d *Dirent) DecRef() {
+ if d.Inode != nil {
+ // Keep mount around, since DecRef may destroy d.Inode.
+ msrc := d.Inode.MountSource
+ d.DecRefWithDestructor(d.destroy)
+ msrc.DecDirentRefs()
+ } else {
+ d.DecRefWithDestructor(d.destroy)
+ }
+}
+
+// InotifyEvent notifies all watches on the inode for this dirent and its parent
+// of potential events. The events may not actually propagate up to the user,
+// depending on the event masks. InotifyEvent automatically provides the name of
+// the current dirent as the subject of the event as required, and adds the
+// IN_ISDIR flag for dirents that refer to directories.
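+//
+// For example (an illustrative sketch, not a call site from this change), a
+// read path could report an access event on the file with:
+//
+//	d.InotifyEvent(linux.IN_ACCESS, 0)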
+func (d *Dirent) InotifyEvent(events, cookie uint32) {
+ // N.B. We don't defer the unlocks because InotifyEvent is in the hot
+ // path of all IO operations, and the defers cost too much for small IO
+ // operations.
+ renameMu.RLock()
+
+ if IsDir(d.Inode.StableAttr) {
+ events |= linux.IN_ISDIR
+ }
+
+ // The ordering below is important, Linux always notifies the parent first.
+ if d.parent != nil {
+ // name is immediately stale w.r.t. renames (renameMu doesn't
+ // protect against renames in the same directory). Holding
+ // d.parent.mu around Notify() wouldn't matter since Notify
+ // doesn't provide a synchronous mechanism for reading the name
+ // anyway.
+ d.parent.mu.Lock()
+ name := d.name
+ d.parent.mu.Unlock()
+ d.parent.Inode.Watches.Notify(name, events, cookie)
+ }
+ d.Inode.Watches.Notify("", events, cookie)
+
+ renameMu.RUnlock()
+}
+
+// maybeExtendReference caches a reference on this Dirent if
+// MountSourceOperations.Keep returns true.
+func (d *Dirent) maybeExtendReference() {
+ if msrc := d.Inode.MountSource; msrc.Keep(d) {
+ msrc.fscache.Add(d)
+ }
+}
+
+// dropExtendedReference drops any cached reference held by the
+// MountSource on the dirent.
+func (d *Dirent) dropExtendedReference() {
+ d.Inode.MountSource.fscache.Remove(d)
+}
+
+// lockForRename takes locks on oldParent and newParent as required by Rename
+// and returns a function that will unlock the locks taken. The returned
+// function must be called even if a non-nil error is returned.
+func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName string) (func(), error) {
+ renameMu.Lock()
+ if oldParent == newParent {
+ oldParent.mu.Lock()
+ return func() {
+ oldParent.mu.Unlock()
+ renameMu.Unlock()
+ }, nil
+ }
+
+ // Renaming between directories is a bit subtle:
+ //
+ // - A concurrent cross-directory Rename may try to lock in the opposite
+ // order; take renameMu to prevent this from happening.
+ //
+ // - If either directory is an ancestor of the other, then a concurrent
+ // Remove may lock the descendant (in DecRef -> closeAll) while holding a
+ // lock on the ancestor; to avoid this, ensure we take locks in the same
+ // ancestor-to-descendant order. (Holding renameMu prevents this
+ // relationship from changing.)
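+ //
+ // For example (illustrative): if oldParent is /a and newParent is /a/b,
+ // then /a.mu is taken before /a/b.mu; the same ancestor-first order is
+ // used when the roles are swapped.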
+
+ // First check if newParent is a descendant of oldParent.
+ child := newParent
+ for p := newParent.parent; p != nil; p = p.parent {
+ if p == oldParent {
+ oldParent.mu.Lock()
+ newParent.mu.Lock()
+ var err error
+ if child.name == oldName {
+ // newParent is not just a descendant of oldParent, but
+ // more specifically of oldParent/oldName. That is, we're
+ // trying to rename something into a subdirectory of
+ // itself.
+ err = syscall.EINVAL
+ }
+ return func() {
+ newParent.mu.Unlock()
+ oldParent.mu.Unlock()
+ renameMu.Unlock()
+ }, err
+ }
+ child = p
+ }
+
+ // Otherwise, either oldParent is a descendant of newParent or the two
+ // have no relationship; in either case we can do this:
+ newParent.mu.Lock()
+ oldParent.mu.Lock()
+ return func() {
+ oldParent.mu.Unlock()
+ newParent.mu.Unlock()
+ renameMu.Unlock()
+ }, nil
+}
+
+func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error {
+ uattr, err := d.Inode.UnstableAttr(ctx)
+ if err != nil {
+ return syserror.EPERM
+ }
+ if !uattr.Perms.Sticky {
+ return nil
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ if uattr.Owner.UID == creds.EffectiveKUID {
+ return nil
+ }
+
+ vuattr, err := victim.Inode.UnstableAttr(ctx)
+ if err != nil {
+ return syserror.EPERM
+ }
+ if vuattr.Owner.UID == creds.EffectiveKUID {
+ return nil
+ }
+ if victim.Inode.CheckCapability(ctx, linux.CAP_FOWNER) {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// MayDelete determines whether `name`, a child of `d`, can be deleted or
+// renamed by `ctx`.
+//
+// Compare Linux kernel fs/namei.c:may_delete.
+func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error {
+ if err := d.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ victim, err := d.walk(ctx, root, name, true /* may unlock */)
+ if err != nil {
+ return err
+ }
+ defer victim.DecRef()
+
+ return d.mayDelete(ctx, victim)
+}
+
+// mayDelete determines whether `victim`, a child of `dir`, can be deleted or
+// renamed by `ctx`.
+//
+// Preconditions: `dir` is writable and executable by `ctx`.
+func (d *Dirent) mayDelete(ctx context.Context, victim *Dirent) error {
+ if err := d.checkSticky(ctx, victim); err != nil {
+ return err
+ }
+
+ if victim.IsRoot() {
+ return syserror.EBUSY
+ }
+
+ return nil
+}
+
+// Rename atomically converts the child of oldParent named oldName to a
+// child of newParent named newName.
+func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string, newParent *Dirent, newName string) error {
+ if root == nil {
+ panic("Rename: root must not be nil")
+ }
+ if oldParent == newParent && oldName == newName {
+ return nil
+ }
+
+ // Acquire global renameMu lock, and mu locks on oldParent/newParent.
+ unlock, err := lockForRename(oldParent, oldName, newParent, newName)
+ defer unlock()
+ if err != nil {
+ return err
+ }
+
+ // Do we have general permission to remove from oldParent and
+ // create/replace in newParent?
+ if err := oldParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ if err := newParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // renamed is the dirent that will be renamed to something else.
+ renamed, err := oldParent.walk(ctx, root, oldName, false /* may unlock */)
+ if err != nil {
+ return err
+ }
+ defer renamed.DecRef()
+
+ // Check that the renamed dirent is deletable.
+ if err := oldParent.mayDelete(ctx, renamed); err != nil {
+ return err
+ }
+
+ // Check that the renamed dirent is not a mount point.
+ if renamed.isMountPointLocked() {
+ return syscall.EBUSY
+ }
+
+ // Source should not be an ancestor of the target.
+ if newParent.descendantOf(renamed) {
+ return syscall.EINVAL
+ }
+
+ // Per rename(2): "... EACCES: ... or oldpath is a directory and does not
+ // allow write permission (needed to update the .. entry)."
+ if IsDir(renamed.Inode.StableAttr) {
+ if err := renamed.Inode.CheckPermission(ctx, PermMask{Write: true}); err != nil {
+ return err
+ }
+ }
+
+ // replaced is the dirent that is being overwritten by rename.
+ replaced, err := newParent.walk(ctx, root, newName, false /* may unlock */)
+ if err != nil {
+ if err != syserror.ENOENT {
+ return err
+ }
+
+ // newName doesn't exist; simply create it below.
+ replaced = nil
+ } else {
+ // Check constraints on the dirent being replaced.
+
+ // NOTE(b/111808347): We don't want to keep replaced alive
+ // across the Rename, so must call DecRef manually (no defer).
+
+ // Check that we can delete replaced.
+ if err := newParent.mayDelete(ctx, replaced); err != nil {
+ replaced.DecRef()
+ return err
+ }
+
+ // Target should not be an ancestor of source.
+ if oldParent.descendantOf(replaced) {
+ replaced.DecRef()
+
+ // Note that Linux returns EINVAL if the source is an
+ // ancestor of target, but ENOTEMPTY if the target is
+ // an ancestor of source (unless RENAME_EXCHANGE flag
+ // is present). See fs/namei.c:renameat2.
+ return syscall.ENOTEMPTY
+ }
+
+ // Check that replaced is not a mount point.
+ if replaced.isMountPointLocked() {
+ replaced.DecRef()
+ return syscall.EBUSY
+ }
+
+ // Require that a directory is replaced by a directory.
+ oldIsDir := IsDir(renamed.Inode.StableAttr)
+ newIsDir := IsDir(replaced.Inode.StableAttr)
+ if !newIsDir && oldIsDir {
+ replaced.DecRef()
+ return syscall.ENOTDIR
+ }
+ if !oldIsDir && newIsDir {
+ replaced.DecRef()
+ return syscall.EISDIR
+ }
+
+ // Allow the file system to drop extra references on replaced.
+ replaced.dropExtendedReference()
+
+ // NOTE(b/31798319,b/31867149,b/31867671): Keeping a dirent
+ // open across renames is currently broken for multiple
+ // reasons, so we flush all references on the replaced node and
+ // its children.
+ replaced.Inode.Watches.Unpin(replaced)
+ replaced.mu.Lock()
+ replaced.flush()
+ replaced.mu.Unlock()
+
+ // Done with replaced.
+ replaced.DecRef()
+ }
+
+ if err := renamed.Inode.Rename(ctx, oldParent, renamed, newParent, newName, replaced != nil); err != nil {
+ return err
+ }
+
+ renamed.name = newName
+ renamed.parent = newParent
+ if oldParent != newParent {
+ // Reparent the reference held by renamed.parent. oldParent.DecRef
+ // can't destroy oldParent (and try to retake its lock) because
+ // Rename's caller must be holding a reference.
+ newParent.IncRef()
+ oldParent.DecRef()
+ }
+ if w, ok := newParent.children[newName]; ok {
+ w.Drop()
+ delete(newParent.children, newName)
+ }
+ if w, ok := oldParent.children[oldName]; ok {
+ w.Drop()
+ delete(oldParent.children, oldName)
+ }
+
+ // Add a weak reference from the new parent. This ensures that the child
+ // can still be found from the new parent if a prior hard reference is
+ // held on renamed.
+ //
+ // This is required for file lock correctness because file locks are
+ // per-Dirent; without maintaining a cached child (via a weak reference)
+ // for renamed, multiple Dirents could correspond to the same resource
+ // (the renamed Dirent would be unreachable from its parent, so a fresh
+ // lookup would create a new Dirent).
+ newParent.children[newName] = refs.NewWeakRef(renamed, nil)
+
+ // Queue inotify events for the rename.
+ var ev uint32
+ if IsDir(renamed.Inode.StableAttr) {
+ ev |= linux.IN_ISDIR
+ }
+
+ cookie := uniqueid.InotifyCookie(ctx)
+ oldParent.Inode.Watches.Notify(oldName, ev|linux.IN_MOVED_FROM, cookie)
+ newParent.Inode.Watches.Notify(newName, ev|linux.IN_MOVED_TO, cookie)
+ // Somewhat surprisingly, self move events do not have a cookie.
+ renamed.Inode.Watches.Notify("", linux.IN_MOVE_SELF, 0)
+
+ // Allow the file system to drop extra references on renamed.
+ renamed.dropExtendedReference()
+
+ // Same as replaced.flush above.
+ renamed.mu.Lock()
+ renamed.flush()
+ renamed.mu.Unlock()
+
+ return nil
+}
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
new file mode 100644
index 000000000..33de32c69
--- /dev/null
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -0,0 +1,174 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// DirentCache is an LRU cache of Dirents. The Dirent's refCount is
+// incremented when it is added to the cache, and decremented when it is
+// removed.
+//
+// A nil DirentCache corresponds to a cache with size 0. All methods can be
+// called, but nothing is actually cached.
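+//
+// A minimal usage sketch (illustrative only; d is a *Dirent obtained
+// elsewhere):
+//
+//	c := NewDirentCache(1024)
+//	c.Add(d)    // takes a reference on d
+//	c.Remove(d) // drops that reference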
+//
+// +stateify savable
+type DirentCache struct {
+ // Maximum size of the cache. This must be saved manually, to handle the case
+ // when cache is nil.
+ maxSize uint64
+
+ // limit restricts the number of entries in the cache among multiple caches.
+ // It may be nil if there is no global limit for this cache.
+ limit *DirentCacheLimiter
+
+ // mu protects currentSize and direntList.
+ mu sync.Mutex `state:"nosave"`
+
+ // currentSize is the number of elements in the cache. It must be zero (i.e.
+ // the cache must be empty) on Save.
+ currentSize uint64 `state:"zerovalue"`
+
+ // list is a direntList, an ilist of Dirents. New Dirents are added
+ // to the front of the list. Old Dirents are removed from the back of
+ // the list. It must be zerovalue (i.e. the cache must be empty) on Save.
+ list direntList `state:"zerovalue"`
+}
+
+// NewDirentCache returns a new DirentCache with the given maxSize.
+func NewDirentCache(maxSize uint64) *DirentCache {
+ return &DirentCache{
+ maxSize: maxSize,
+ }
+}
+
+// Add adds the element to the cache and increments the refCount. If the
+// argument is already in the cache, it is moved to the front. An element is
+// removed from the back if the cache is over capacity.
+func (c *DirentCache) Add(d *Dirent) {
+ if c == nil || c.maxSize == 0 {
+ return
+ }
+
+ c.mu.Lock()
+ if c.contains(d) {
+ // d is already in cache. Bump it to the front.
+ // currentSize and refCount are unaffected.
+ c.list.Remove(d)
+ c.list.PushFront(d)
+ c.mu.Unlock()
+ return
+ }
+
+ // First check against the global limit.
+ for c.limit != nil && !c.limit.tryInc() {
+ if c.currentSize == 0 {
+ // If the global limit is reached, but there is nothing more to drop from
+ // this cache, there is not much else to do.
+ c.mu.Unlock()
+ return
+ }
+ c.remove(c.list.Back())
+ }
+
+ // d is not in cache. Add it and take a reference.
+ c.list.PushFront(d)
+ d.IncRef()
+ c.currentSize++
+
+ c.maybeShrink()
+
+ c.mu.Unlock()
+}
+
+func (c *DirentCache) remove(d *Dirent) {
+ if !c.contains(d) {
+ panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d))
+ }
+ c.list.Remove(d)
+ d.DecRef()
+ c.currentSize--
+ if c.limit != nil {
+ c.limit.dec()
+ }
+}
+
+// Remove removes the element from the cache and decrements its refCount. It
+// also sets the previous and next elements to nil, which allows us to
+// determine if a given element is in the cache.
+func (c *DirentCache) Remove(d *Dirent) {
+ if c == nil || c.maxSize == 0 {
+ return
+ }
+ c.mu.Lock()
+ if !c.contains(d) {
+ c.mu.Unlock()
+ return
+ }
+ c.remove(d)
+ c.mu.Unlock()
+}
+
+// Size returns the number of elements in the cache.
+func (c *DirentCache) Size() uint64 {
+ if c == nil {
+ return 0
+ }
+ c.mu.Lock()
+ size := c.currentSize
+ c.mu.Unlock()
+ return size
+}
+
+func (c *DirentCache) contains(d *Dirent) bool {
+ // If d has a Prev or Next element, then it is in the cache.
+ if d.Prev() != nil || d.Next() != nil {
+ return true
+ }
+ // Otherwise, d is in the cache if it is the only element (and thus the
+ // first element).
+ return c.list.Front() == d
+}
+
+// Invalidate removes all Dirents from the cache, calling DecRef on each.
+func (c *DirentCache) Invalidate() {
+ if c == nil {
+ return
+ }
+ c.mu.Lock()
+ for c.list.Front() != nil {
+ c.remove(c.list.Front())
+ }
+ c.mu.Unlock()
+}
+
+// setMaxSize sets cache max size. If current size is larger than max size, the
+// cache shrinks to accommodate the new max.
+func (c *DirentCache) setMaxSize(max uint64) {
+ c.mu.Lock()
+ c.maxSize = max
+ c.maybeShrink()
+ c.mu.Unlock()
+}
+
+// maybeShrink removes the oldest elements until the cache is under the size limit.
+func (c *DirentCache) maybeShrink() {
+ for c.maxSize > 0 && c.currentSize > c.maxSize {
+ c.remove(c.list.Back())
+ }
+}
diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go
new file mode 100644
index 000000000..525ee25f9
--- /dev/null
+++ b/pkg/sentry/fs/dirent_cache_limiter.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// DirentCacheLimiter acts as a global limit for all dirent caches in the
+// process.
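+//
+// A sketch of how a limiter is shared across caches (illustrative; the
+// limit field is only assignable within this package):
+//
+//	limit := NewDirentCacheLimiter(1000)
+//	c1 := NewDirentCache(500)
+//	c1.limit = limit
+//	c2 := NewDirentCache(500)
+//	c2.limit = limit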
+//
+// +stateify savable
+type DirentCacheLimiter struct {
+ mu sync.Mutex `state:"nosave"`
+ max uint64
+ count uint64 `state:"zerovalue"`
+}
+
+// NewDirentCacheLimiter creates a new DirentCacheLimiter.
+func NewDirentCacheLimiter(max uint64) *DirentCacheLimiter {
+ return &DirentCacheLimiter{max: max}
+}
+
+func (d *DirentCacheLimiter) tryInc() bool {
+ d.mu.Lock()
+ if d.count >= d.max {
+ d.mu.Unlock()
+ return false
+ }
+ d.count++
+ d.mu.Unlock()
+ return true
+}
+
+func (d *DirentCacheLimiter) dec() {
+ d.mu.Lock()
+ if d.count == 0 {
+ panic(fmt.Sprintf("underflowing DirentCacheLimiter count: %+v", d))
+ }
+ d.count--
+ d.mu.Unlock()
+}
diff --git a/pkg/sentry/fs/dirent_cache_test.go b/pkg/sentry/fs/dirent_cache_test.go
new file mode 100644
index 000000000..395c879f5
--- /dev/null
+++ b/pkg/sentry/fs/dirent_cache_test.go
@@ -0,0 +1,247 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "testing"
+)
+
+func TestDirentCache(t *testing.T) {
+ const maxSize = 5
+
+ c := NewDirentCache(maxSize)
+
+ // Size starts at 0.
+ if got, want := c.Size(), uint64(0); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // Create a Dirent d.
+ d := NewNegativeDirent("")
+
+ // c does not contain d.
+ if got, want := c.contains(d), false; got != want {
+ t.Errorf("c.contains(d) got %v want %v", got, want)
+ }
+
+ // Add d to the cache.
+ c.Add(d)
+
+ // Size is now 1.
+ if got, want := c.Size(), uint64(1); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // c contains d.
+ if got, want := c.contains(d), true; got != want {
+ t.Errorf("c.contains(d) got %v want %v", got, want)
+ }
+
+ // Add maxSize-1 more elements. d should be oldest element.
+ for i := 0; i < maxSize-1; i++ {
+ c.Add(NewNegativeDirent(""))
+ }
+
+ // Size is maxSize.
+ if got, want := c.Size(), uint64(maxSize); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // c contains d.
+ if got, want := c.contains(d), true; got != want {
+ t.Errorf("c.contains(d) got %v want %v", got, want)
+ }
+
+ // "Bump" d to the front by re-adding it.
+ c.Add(d)
+
+ // Size is maxSize.
+ if got, want := c.Size(), uint64(maxSize); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // c contains d.
+ if got, want := c.contains(d), true; got != want {
+ t.Errorf("c.contains(d) got %v want %v", got, want)
+ }
+
+ // Add maxSize-1 more elements. d should again be oldest element.
+ for i := 0; i < maxSize-1; i++ {
+ c.Add(NewNegativeDirent(""))
+ }
+
+ // Size is maxSize.
+ if got, want := c.Size(), uint64(maxSize); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // c contains d.
+ if got, want := c.contains(d), true; got != want {
+ t.Errorf("c.contains(d) got %v want %v", got, want)
+ }
+
+ // Add one more element, which will bump d from the cache.
+ c.Add(NewNegativeDirent(""))
+
+ // Size is maxSize.
+ if got, want := c.Size(), uint64(maxSize); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // c does not contain d.
+ if got, want := c.contains(d), false; got != want {
+ t.Errorf("c.contains(d) got %v want %v", got, want)
+ }
+
+ // Invalidating causes size to be 0 and list to be empty.
+ c.Invalidate()
+ if got, want := c.Size(), uint64(0); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+ if got, want := c.list.Empty(), true; got != want {
+ t.Errorf("c.list.Empty() got %v, want %v", got, want)
+ }
+
+ // Fill cache with maxSize dirents.
+ for i := 0; i < maxSize; i++ {
+ c.Add(NewNegativeDirent(""))
+ }
+}
+
+func TestDirentCacheLimiter(t *testing.T) {
+ const (
+ globalMaxSize = 5
+ maxSize = 3
+ )
+
+ limit := NewDirentCacheLimiter(globalMaxSize)
+ c1 := NewDirentCache(maxSize)
+ c1.limit = limit
+ c2 := NewDirentCache(maxSize)
+ c2.limit = limit
+
+ // Create a Dirent d.
+ d := NewNegativeDirent("")
+
+ // Add d to the cache.
+ c1.Add(d)
+ if got, want := c1.Size(), uint64(1); got != want {
+ t.Errorf("c1.Size() got %v, want %v", got, want)
+ }
+
+ // Add maxSize-1 more elements. d should be oldest element.
+ for i := 0; i < maxSize-1; i++ {
+ c1.Add(NewNegativeDirent(""))
+ }
+ if got, want := c1.Size(), uint64(maxSize); got != want {
+ t.Errorf("c1.Size() got %v, want %v", got, want)
+ }
+
+ // Check that d is still there.
+ if got, want := c1.contains(d), true; got != want {
+ t.Errorf("c1.contains(d) got %v want %v", got, want)
+ }
+
+ // Fill up the other cache, it will start dropping old entries from the cache
+ // when the global limit is reached.
+ for i := 0; i < maxSize; i++ {
+ c2.Add(NewNegativeDirent(""))
+ }
+
+ // Check that c2 holds only what remains of the global max.
+ if got, want := c2.Size(), globalMaxSize-maxSize; int(got) != want {
+ t.Errorf("c2.Size() got %v, want %v", got, want)
+ }
+
+ // Check that d was not dropped.
+ if got, want := c1.contains(d), true; got != want {
+ t.Errorf("c1.contains(d) got %v want %v", got, want)
+ }
+
+ // Add an entry that will eventually be dropped; this is checked later.
+ drop := NewNegativeDirent("")
+ c1.Add(drop)
+
+ // Check that d is bumped to front even when global limit is reached.
+ c1.Add(d)
+ if got, want := c1.contains(d), true; got != want {
+ t.Errorf("c1.contains(d) got %v want %v", got, want)
+ }
+
+ // Add 2 more elements and check that:
+ // - d is still in the list: verifies that d was bumped.
+ // - d2/d3 are in the list: when the global limit is reached, the oldest
+ // entries are dropped, not the newest.
+ // - drop is not in the list: indeed, older elements are dropped.
+ d2 := NewNegativeDirent("")
+ c1.Add(d2)
+ d3 := NewNegativeDirent("")
+ c1.Add(d3)
+ if got, want := c1.contains(d), true; got != want {
+ t.Errorf("c1.contains(d) got %v want %v", got, want)
+ }
+ if got, want := c1.contains(d2), true; got != want {
+ t.Errorf("c1.contains(d2) got %v want %v", got, want)
+ }
+ if got, want := c1.contains(d3), true; got != want {
+ t.Errorf("c1.contains(d3) got %v want %v", got, want)
+ }
+ if got, want := c1.contains(drop), false; got != want {
+ t.Errorf("c1.contains(drop) got %v want %v", got, want)
+ }
+
+ // Drop all entries from one cache. The other will be allowed to grow.
+ c1.Invalidate()
+ c2.Add(NewNegativeDirent(""))
+ if got, want := c2.Size(), uint64(maxSize); got != want {
+ t.Errorf("c2.Size() got %v, want %v", got, want)
+ }
+}
+
+// TestNilDirentCache tests that a nil cache supports all cache operations, but
+// treats them as noop.
+func TestNilDirentCache(t *testing.T) {
+ // Create a nil cache.
+ var c *DirentCache
+
+ // Size is zero.
+ if got, want := c.Size(), uint64(0); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // Call Add.
+ c.Add(NewNegativeDirent(""))
+
+ // Size is zero.
+ if got, want := c.Size(), uint64(0); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // Call Remove.
+ c.Remove(NewNegativeDirent(""))
+
+ // Size is zero.
+ if got, want := c.Size(), uint64(0); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+
+ // Call Invalidate.
+ c.Invalidate()
+
+ // Size is zero.
+ if got, want := c.Size(), uint64(0); got != want {
+ t.Errorf("c.Size() got %v, want %v", got, want)
+ }
+}
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
new file mode 100644
index 000000000..98d69c6f2
--- /dev/null
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -0,0 +1,418 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+)
+
+func newMockDirInode(ctx context.Context, cache *DirentCache) *Inode {
+ return NewMockInode(ctx, NewMockMountSource(cache), StableAttr{Type: Directory})
+}
+
+func TestWalkPositive(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ ctx := contexttest.Context(t)
+ root := NewDirent(ctx, newMockDirInode(ctx, nil), "root")
+
+ if got := root.ReadRefs(); got != 1 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 1)
+ }
+
+ name := "d"
+ d, err := root.walk(ctx, root, name, false)
+ if err != nil {
+ t.Fatalf("root.walk(root, %q) got %v, want nil", name, err)
+ }
+
+ if got := root.ReadRefs(); got != 2 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 2)
+ }
+
+ if got := d.ReadRefs(); got != 1 {
+ t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1)
+ }
+
+ d.DecRef()
+
+ if got := root.ReadRefs(); got != 1 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 1)
+ }
+
+ if got := d.ReadRefs(); got != 0 {
+ t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0)
+ }
+
+ root.flush()
+
+ if got := len(root.children); got != 0 {
+ t.Fatalf("root has %d children, want %d", got, 0)
+ }
+}
+
+func TestWalkNegative(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ ctx := contexttest.Context(t)
+ root := NewDirent(ctx, NewEmptyDir(ctx, nil), "root")
+ mn := root.Inode.InodeOperations.(*mockInodeOperationsLookupNegative)
+
+ if got := root.ReadRefs(); got != 1 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 1)
+ }
+
+ name := "d"
+ for i := 0; i < 100; i++ {
+ _, err := root.walk(ctx, root, name, false)
+ if err != syscall.ENOENT {
+ t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, syscall.ENOENT)
+ }
+ }
+
+ if got := root.ReadRefs(); got != 1 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 1)
+ }
+
+ if got := len(root.children); got != 1 {
+ t.Fatalf("root has %d children, want %d", got, 1)
+ }
+
+ w, ok := root.children[name]
+ if !ok {
+ t.Fatalf("root wants child at %q", name)
+ }
+
+ child := w.Get()
+ if child == nil {
+ t.Fatalf("root wants to resolve weak reference")
+ }
+
+ if !child.(*Dirent).IsNegative() {
+ t.Fatalf("root found positive child at %q, want negative", name)
+ }
+
+ if got := child.(*Dirent).ReadRefs(); got != 2 {
+ t.Fatalf("child has a ref count of %d, want %d", got, 2)
+ }
+
+ child.DecRef()
+
+ if got := child.(*Dirent).ReadRefs(); got != 1 {
+ t.Fatalf("child has a ref count of %d, want %d", got, 1)
+ }
+
+ if got := len(root.children); got != 1 {
+ t.Fatalf("root has %d children, want %d", got, 1)
+ }
+
+ root.DecRef()
+
+ if got := root.ReadRefs(); got != 0 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 0)
+ }
+
+ AsyncBarrier()
+
+ if got := mn.releaseCalled; got != true {
+ t.Fatalf("root.Close was called %v, want true", got)
+ }
+}
+
+type mockInodeOperationsLookupNegative struct {
+ *MockInodeOperations
+ releaseCalled bool
+}
+
+func NewEmptyDir(ctx context.Context, cache *DirentCache) *Inode {
+ m := NewMockMountSource(cache)
+ return NewInode(ctx, &mockInodeOperationsLookupNegative{
+ MockInodeOperations: NewMockInodeOperations(ctx),
+ }, m, StableAttr{Type: Directory})
+}
+
+func (m *mockInodeOperationsLookupNegative) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) {
+ return NewNegativeDirent(p), nil
+}
+
+func (m *mockInodeOperationsLookupNegative) Release(context.Context) {
+ m.releaseCalled = true
+}
+
+func TestHashNegativeToPositive(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ ctx := contexttest.Context(t)
+ root := NewDirent(ctx, NewEmptyDir(ctx, nil), "root")
+
+ name := "d"
+ _, err := root.walk(ctx, root, name, false)
+ if err != syscall.ENOENT {
+ t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, syscall.ENOENT)
+ }
+
+ if got := root.exists(ctx, root, name); got != false {
+ t.Fatalf("got %q exists, want does not exist", name)
+ }
+
+ f, err := root.Create(ctx, root, name, FileFlags{}, FilePermissions{})
+ if err != nil {
+ t.Fatalf("root.Create(%q, _), got error %v, want nil", name, err)
+ }
+ d := f.Dirent
+
+ if d.IsNegative() {
+ t.Fatalf("got negative Dirent, want positive")
+ }
+
+ if got := d.ReadRefs(); got != 1 {
+ t.Fatalf("child %q has a ref count of %d, want %d", name, got, 1)
+ }
+
+ if got := root.ReadRefs(); got != 2 {
+ t.Fatalf("root has a ref count of %d, want %d", got, 2)
+ }
+
+ if got := len(root.children); got != 1 {
+ t.Fatalf("got %d children, want %d", got, 1)
+ }
+
+ w, ok := root.children[name]
+ if !ok {
+ t.Fatalf("failed to find weak reference to %q", name)
+ }
+
+ child := w.Get()
+ if child == nil {
+ t.Fatalf("want to resolve weak reference")
+ }
+
+ if child.(*Dirent) != d {
+ t.Fatalf("got foreign child")
+ }
+}
+
+func TestRevalidate(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // Whether to make negative Dirents.
+ makeNegative bool
+ }{
+ {
+ desc: "Revalidate negative Dirent",
+ makeNegative: true,
+ },
+ {
+ desc: "Revalidate positive Dirent",
+ makeNegative: false,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ ctx := contexttest.Context(t)
+ root := NewDirent(ctx, NewMockInodeRevalidate(ctx, test.makeNegative), "root")
+
+ name := "d"
+ d1, err := root.walk(ctx, root, name, false)
+ if !test.makeNegative && err != nil {
+ t.Fatalf("root.walk(root, %q) got %v, want nil", name, err)
+ }
+ d2, err := root.walk(ctx, root, name, false)
+ if !test.makeNegative && err != nil {
+ t.Fatalf("root.walk(root, %q) got %v, want nil", name, err)
+ }
+ if !test.makeNegative && d1 == d2 {
+ t.Fatalf("revalidating walk got same *Dirent, want different")
+ }
+ if got := len(root.children); got != 1 {
+ t.Errorf("revalidating walk got %d children, want %d", got, 1)
+ }
+ })
+ }
+}
+
+type MockInodeOperationsRevalidate struct {
+ *MockInodeOperations
+ makeNegative bool
+}
+
+func NewMockInodeRevalidate(ctx context.Context, makeNegative bool) *Inode {
+ mn := NewMockInodeOperations(ctx)
+ m := NewMockMountSource(nil)
+ m.MountSourceOperations.(*MockMountSourceOps).revalidate = true
+ return NewInode(ctx, &MockInodeOperationsRevalidate{MockInodeOperations: mn, makeNegative: makeNegative}, m, StableAttr{Type: Directory})
+}
+
+func (m *MockInodeOperationsRevalidate) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) {
+ if !m.makeNegative {
+ return m.MockInodeOperations.Lookup(ctx, dir, p)
+ }
+ return NewNegativeDirent(p), nil
+}
+
+func TestCreateExtraRefs(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ ctx := contexttest.Context(t)
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // root is the Dirent to create from.
+ root *Dirent
+
+ // expected references on walked Dirent.
+ refs int64
+ }{
+ {
+ desc: "Create caching",
+ root: NewDirent(ctx, NewEmptyDir(ctx, NewDirentCache(1)), "root"),
+ refs: 2,
+ },
+ {
+ desc: "Create not caching",
+ root: NewDirent(ctx, NewEmptyDir(ctx, nil), "root"),
+ refs: 1,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ name := "d"
+ f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{})
+ if err != nil {
+ t.Fatalf("root.Create(root, %q) failed: %v", name, err)
+ }
+ d := f.Dirent
+
+ if got := d.ReadRefs(); got != test.refs {
+ t.Errorf("dirent has a ref count of %d, want %d", got, test.refs)
+ }
+ })
+ }
+}
+
+func TestRemoveExtraRefs(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ ctx := contexttest.Context(t)
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // root is the Dirent to make and remove from.
+ root *Dirent
+ }{
+ {
+ desc: "Remove caching",
+ root: NewDirent(ctx, NewEmptyDir(ctx, NewDirentCache(1)), "root"),
+ },
+ {
+ desc: "Remove not caching",
+ root: NewDirent(ctx, NewEmptyDir(ctx, nil), "root"),
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ name := "d"
+ f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{})
+ if err != nil {
+ t.Fatalf("root.Create(%q, _) failed: %v", name, err)
+ }
+ d := f.Dirent
+
+ if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil {
+ t.Fatalf("root.Remove(root, %q) failed: %v", name, err)
+ }
+
+ if got := d.ReadRefs(); got != 1 {
+ t.Fatalf("dirent has a ref count of %d, want %d", got, 1)
+ }
+
+ d.DecRef()
+
+ test.root.flush()
+
+ if got := len(test.root.children); got != 0 {
+ t.Errorf("root has %d children, want %d", got, 0)
+ }
+ })
+ }
+}
+
+func TestRenameExtraRefs(t *testing.T) {
+ // refs == 0 -> one reference.
+ // refs == -1 -> has been destroyed.
+
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // cache of extra Dirent references, may be nil.
+ cache *DirentCache
+ }{
+ {
+ desc: "Rename no caching",
+ cache: nil,
+ },
+ {
+ desc: "Rename caching",
+ cache: NewDirentCache(5),
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ dirAttr := StableAttr{Type: Directory}
+
+ oldParent := NewDirent(ctx, NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "old_parent")
+ newParent := NewDirent(ctx, NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "new_parent")
+
+ renamed, err := oldParent.Walk(ctx, oldParent, "old_child")
+ if err != nil {
+ t.Fatalf("Walk(oldParent, %q) got error %v, want nil", "old_child", err)
+ }
+ replaced, err := newParent.Walk(ctx, oldParent, "new_child")
+ if err != nil {
+ t.Fatalf("Walk(newParent, %q) got error %v, want nil", "new_child", err)
+ }
+
+ if err := Rename(contexttest.RootContext(t), oldParent /*root */, oldParent, "old_child", newParent, "new_child"); err != nil {
+ t.Fatalf("Rename got error %v, want nil", err)
+ }
+
+ oldParent.flush()
+ newParent.flush()
+
+ // Expect to have only active references.
+ if got := renamed.ReadRefs(); got != 1 {
+ t.Errorf("renamed has ref count %d, want only active references %d", got, 1)
+ }
+ if got := replaced.ReadRefs(); got != 1 {
+ t.Errorf("replaced has ref count %d, want only active references %d", got, 1)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fs/dirent_state.go b/pkg/sentry/fs/dirent_state.go
new file mode 100644
index 000000000..f623d6c0e
--- /dev/null
+++ b/pkg/sentry/fs/dirent_state.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refs"
+)
+
+// beforeSave is invoked by stateify.
+func (d *Dirent) beforeSave() {
+ // Refuse to save if the file is on a non-virtual file system and has
+ // already been deleted (but still has open fds, which is why the Dirent
+ // is still accessible). We know that the restore re-opening of the file
+ // will always fail. This condition will last until all the open fds and
+ // this Dirent are closed and released.
+ //
+ // Such "dangling" open files on virtual file systems (e.g., tmpfs) is
+ // OK to save as their restore does not require re-opening the files.
+ //
+ // Note that this is rejection rather than failure---it would be
+ // perfectly OK to save---we are simply disallowing it here to prevent
+ // generating non-restorable state dumps. As the program continues its
+ // execution, it may become allowed to save again.
+ if !d.Inode.IsVirtual() && atomic.LoadInt32(&d.deleted) != 0 {
+ n, _ := d.FullName(nil /* root */)
+ panic(ErrSaveRejection{fmt.Errorf("deleted file %q still has open fds", n)})
+ }
+}
+
+// saveChildren is invoked by stateify.
+func (d *Dirent) saveChildren() map[string]*Dirent {
+ c := make(map[string]*Dirent)
+ for name, w := range d.children {
+ if rc := w.Get(); rc != nil {
+ // Drop the reference obtained in w.Get().
+ rc.DecRef()
+
+ cd := rc.(*Dirent)
+ if cd.IsNegative() {
+ // Don't bother saving negative Dirents.
+ continue
+ }
+ c[name] = cd
+ }
+ }
+ return c
+}
+
+// loadChildren is invoked by stateify.
+func (d *Dirent) loadChildren(children map[string]*Dirent) {
+ d.children = make(map[string]*refs.WeakRef)
+ for name, c := range children {
+ d.children[name] = refs.NewWeakRef(c, nil)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *Dirent) afterLoad() {
+ if d.userVisible {
+ allDirents.add(d)
+ }
+}
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
new file mode 100644
index 000000000..1d09e983c
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -0,0 +1,48 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fdpipe",
+ srcs = [
+ "pipe.go",
+ "pipe_opener.go",
+ "pipe_state.go",
+ ],
+ imports = ["gvisor.dev/gvisor/pkg/sentry/fs"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/secio",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "fdpipe_test",
+ size = "small",
+ srcs = [
+ "pipe_opener_test.go",
+ "pipe_test.go",
+ ],
+ library = ":fdpipe",
+ deps = [
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@com_github_google_uuid//:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
new file mode 100644
index 000000000..9fce177ad
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -0,0 +1,168 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fdpipe implements common namedpipe opening and accessing logic.
+package fdpipe
+
+import (
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/secio"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// pipeOperations are the fs.FileOperations of a host pipe.
+//
+// +stateify savable
+type pipeOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.Queue `state:"nosave"`
+
+ // flags are the flags used to open the pipe.
+ flags fs.FileFlags `state:".(fs.FileFlags)"`
+
+ // opener is how the pipe was opened.
+ opener NonBlockingOpener `state:"wait"`
+
+ // file represents the host pipe.
+ file *fd.FD `state:"nosave"`
+
+ // mu protects readAheadBuffer access below.
+ mu sync.Mutex `state:"nosave"`
+
+ // readAheadBuffer contains read bytes that have not yet been read
+ // by the application but need to be buffered for save-restore for correct
+ // opening semantics. The readAheadBuffer will only be non-empty when the
+ // pipe is first opened and will be drained by subsequent reads on the pipe.
+ readAheadBuffer []byte
+}
+
+// newPipeOperations returns an implementation of fs.FileOperations for a pipe.
+func newPipeOperations(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags, file *fd.FD, readAheadBuffer []byte) (*pipeOperations, error) {
+ pipeOps := &pipeOperations{
+ flags: flags,
+ opener: opener,
+ file: file,
+ readAheadBuffer: readAheadBuffer,
+ }
+ if err := pipeOps.init(); err != nil {
+ return nil, err
+ }
+ return pipeOps, nil
+}
+
+// init initializes p.file.
+func (p *pipeOperations) init() error {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(p.file.FD(), &s); err != nil {
+ log.Warningf("pipe: cannot stat fd %d: %v", p.file.FD(), err)
+ return syscall.EINVAL
+ }
+ if (s.Mode & syscall.S_IFMT) != syscall.S_IFIFO {
+ log.Warningf("pipe: cannot load fd %d as pipe, file type: %o", p.file.FD(), s.Mode)
+ return syscall.EINVAL
+ }
+ if err := syscall.SetNonblock(p.file.FD(), true); err != nil {
+ return err
+ }
+ return fdnotifier.AddFD(int32(p.file.FD()), &p.Queue)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (p *pipeOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ p.Queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(p.file.FD()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (p *pipeOperations) EventUnregister(e *waiter.Entry) {
+ p.Queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(p.file.FD()))
+}
+
+// Readiness returns a mask of ready events for stream.
+func (p *pipeOperations) Readiness(mask waiter.EventMask) (eventMask waiter.EventMask) {
+ return fdnotifier.NonBlockingPoll(int32(p.file.FD()), mask)
+}
+
+// Release implements fs.FileOperations.Release.
+func (p *pipeOperations) Release() {
+ fdnotifier.RemoveFD(int32(p.file.FD()))
+ p.file.Close()
+ p.file = nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (p *pipeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ // First drain the read ahead buffer, if it contains anything.
+ var bufN int
+ var bufErr error
+ p.mu.Lock()
+ if len(p.readAheadBuffer) > 0 {
+ bufN, bufErr = dst.CopyOut(ctx, p.readAheadBuffer)
+ p.readAheadBuffer = p.readAheadBuffer[bufN:]
+ dst = dst.DropFirst(bufN)
+ }
+ p.mu.Unlock()
+ if dst.NumBytes() == 0 || bufErr != nil {
+ return int64(bufN), bufErr
+ }
+
+ // Pipes expect full reads.
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{secio.FullReader{p.file}})
+ total := int64(bufN) + n
+ if err != nil && isBlockError(err) {
+ return total, syserror.ErrWouldBlock
+ }
+ return total, err
+}
+
+// Write implements fs.FileOperations.Write.
+func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ n, err := src.CopyInTo(ctx, safemem.FromIOWriter{p.file})
+ if err != nil && isBlockError(err) {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
+// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
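+//
+// For example (illustrative), an *os.PathError wrapping EAGAIN counts as a
+// block error:
+//
+//	isBlockError(&os.PathError{Op: "read", Err: syscall.EAGAIN}) // true
+//	isBlockError(io.EOF)                                         // false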
+func isBlockError(err error) bool {
+ if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK {
+ return true
+ }
+ if pe, ok := err.(*os.PathError); ok {
+ return isBlockError(pe.Err)
+ }
+ return false
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener.go b/pkg/sentry/fs/fdpipe/pipe_opener.go
new file mode 100644
index 000000000..0c3595998
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe_opener.go
@@ -0,0 +1,193 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdpipe
+
+import (
+ "io"
+ "os"
+ "syscall"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// NonBlockingOpener is a generic host file opener used to retry opening host
+// pipes if necessary.
+type NonBlockingOpener interface {
+ // NonBlockingOpen tries to open a host pipe in a non-blocking way,
+ // and otherwise returns an error. Implementations should be idempotent.
+ NonBlockingOpen(context.Context, fs.PermMask) (*fd.FD, error)
+}
+
+// Open blocks until a host pipe can be opened or the action was cancelled.
+// On success, returns fs.FileOperations wrapping the opened host pipe.
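+//
+// A caller holding a NonBlockingOpener might use it like this (an
+// illustrative sketch, not a call site from this change):
+//
+//	ops, err := Open(ctx, opener, fs.FileFlags{Read: true})
+//	if err != nil {
+//		return err // e.g. syserror.ErrInterrupted if canceled
+//	}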
+func Open(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (fs.FileOperations, error) {
+ p := &pipeOpenState{}
+ canceled := false
+ for {
+ if file, err := p.TryOpen(ctx, opener, flags); err != syserror.ErrWouldBlock {
+ return file, err
+ }
+
+ // Honor the cancellation request if open still blocks.
+ if canceled {
+ // If we were canceled but we have a handle to a host
+ // file, we need to close it.
+ if p.hostFile != nil {
+ p.hostFile.Close()
+ }
+ return nil, syserror.ErrInterrupted
+ }
+
+ cancel := ctx.SleepStart()
+ select {
+ case <-cancel:
+ // The cancellation request received here really says
+ // "cancel from now on (or ASAP)". Any environmental
+ // changes happened before receiving it, that might have
+ // caused open to not block anymore, should still be
+ // respected. So we cannot just return here. We have to
+ // give open another try below first.
+ canceled = true
+ ctx.SleepFinish(false)
+ case <-time.After(100 * time.Millisecond):
+ // If we would block, then delay retrying for a bit, since there
+ // is no way to know when the pipe would be ready to be
+ // re-opened. This is identical to sending an event notification
+ // to stop blocking in Task.Block, given that this routine will
+ // stop retrying if a cancellation is received.
+ ctx.SleepFinish(true)
+ }
+ }
+}
+
+// pipeOpenState holds state needed to open a blocking named pipe read only, for instance the
+// file that has been opened but doesn't yet have a corresponding writer.
+type pipeOpenState struct {
+ // hostFile is the read only named pipe which lacks a corresponding writer.
+ hostFile *fd.FD
+}
+
+// unwrapError is needed to match against ENXIO primarily.
+func unwrapError(err error) error {
+ if pe, ok := err.(*os.PathError); ok {
+ return pe.Err
+ }
+ return err
+}
+
+// TryOpen uses a NonBlockingOpener to try to open a host pipe, respecting the fs.FileFlags.
+func (p *pipeOpenState) TryOpen(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (*pipeOperations, error) {
+ switch {
+ // Reject invalid configurations so they don't accidentally succeed below.
+ case !flags.Read && !flags.Write:
+ return nil, syscall.EINVAL
+
+ // Handle opening RDWR or with O_NONBLOCK: will never block, so try only once.
+ case (flags.Read && flags.Write) || flags.NonBlocking:
+ f, err := opener.NonBlockingOpen(ctx, fs.PermMask{Read: flags.Read, Write: flags.Write})
+ if err != nil {
+ return nil, err
+ }
+ return newPipeOperations(ctx, opener, flags, f, nil)
+
+ // Handle opening O_WRONLY blocking: convert ENXIO to syserror.ErrWouldBlock.
+ // See TryOpenWriteOnly for more details.
+ case flags.Write:
+ return p.TryOpenWriteOnly(ctx, opener)
+
+ default:
+ // Handle opening O_RDONLY blocking: convert EOF from read to syserror.ErrWouldBlock.
+ // See TryOpenReadOnly for more details.
+ return p.TryOpenReadOnly(ctx, opener)
+ }
+}
+
+// TryOpenReadOnly tries to open a host pipe read only but only returns a fs.File when
+// there is a coordinating writer. Call TryOpenReadOnly repeatedly on the same pipeOpenState
+// until syserror.ErrWouldBlock is no longer returned.
+//
+// How it works:
+//
+// Opening a pipe read only will return no error, but each non-zero Read will return EOF
+// until a writer becomes available, then EWOULDBLOCK. This is the only state change
+// available to us. We keep a read ahead buffer in case we read bytes instead of getting
+// EWOULDBLOCK, to be read from on the first read request to this fs.File.
+func (p *pipeOpenState) TryOpenReadOnly(ctx context.Context, opener NonBlockingOpener) (*pipeOperations, error) {
+ // Waiting for a blocking read only open involves reading from the host pipe until
+ // bytes or other writers are available, so instead of retrying opening the pipe,
+ // it's necessary to retry reading from the pipe. To do this we need to keep around
+ // the read only pipe we opened, until success or an irrecoverable read error (at
+ // which point it must be closed).
+ if p.hostFile == nil {
+ var err error
+ p.hostFile, err = opener.NonBlockingOpen(ctx, fs.PermMask{Read: true})
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Try to read from the pipe to see if writers are around.
+ tryReadBuffer := make([]byte, 1)
+ n, rerr := p.hostFile.Read(tryReadBuffer)
+
+ // No bytes were read.
+ if n == 0 {
+ // EOF means that we're not ready yet.
+ if rerr == nil || rerr == io.EOF {
+ return nil, syserror.ErrWouldBlock
+ }
+ // Any error that is not EWOULDBLOCK also means we're not
+ // ready yet, and probably never will be ready. In this
+ // case we need to close the host pipe we opened.
+ if unwrapError(rerr) != syscall.EWOULDBLOCK {
+ p.hostFile.Close()
+ return nil, rerr
+ }
+ }
+
+ // If any bytes were read, no matter the corresponding error, we need
+ // to keep them around so they can be read by the application.
+ var readAheadBuffer []byte
+ if n > 0 {
+ readAheadBuffer = tryReadBuffer
+ }
+
+ // Successfully opened read only blocking pipe with either bytes available
+ // to read and/or a writer available.
+ return newPipeOperations(ctx, opener, fs.FileFlags{Read: true}, p.hostFile, readAheadBuffer)
+}
+
+// TryOpenWriteOnly tries to open a host pipe write only but only returns a fs.File when
+// there is a coordinating reader. Call TryOpenWriteOnly repeatedly on the same pipeOpenState
+// until syserror.ErrWouldBlock is no longer returned.
+//
+// How it works:
+//
+// Opening a pipe write only will return ENXIO until readers are available. Converts the ENXIO
+// to a syserror.ErrWouldBlock, to tell callers to retry.
+func (*pipeOpenState) TryOpenWriteOnly(ctx context.Context, opener NonBlockingOpener) (*pipeOperations, error) {
+ hostFile, err := opener.NonBlockingOpen(ctx, fs.PermMask{Write: true})
+ if unwrapError(err) == syscall.ENXIO {
+ return nil, syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return nil, err
+ }
+ return newPipeOperations(ctx, opener, fs.FileFlags{Write: true}, hostFile, nil)
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
new file mode 100644
index 000000000..e556da48a
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
@@ -0,0 +1,523 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdpipe
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type hostOpener struct {
+ name string
+}
+
+func (h *hostOpener) NonBlockingOpen(_ context.Context, p fs.PermMask) (*fd.FD, error) {
+ var flags int
+ switch {
+ case p.Read && p.Write:
+ flags = syscall.O_RDWR
+ case p.Write:
+ flags = syscall.O_WRONLY
+ case p.Read:
+ flags = syscall.O_RDONLY
+ default:
+ return nil, syscall.EINVAL
+ }
+ f, err := syscall.Open(h.name, flags|syscall.O_NONBLOCK, 0666)
+ if err != nil {
+ return nil, err
+ }
+ return fd.New(f), nil
+}
+
+func pipename() string {
+ return path.Join(os.TempDir(), fmt.Sprintf("test-named-pipe-%s", uuid.New()))
+}
+
+func mkpipe(name string) error {
+ return syscall.Mknod(name, syscall.S_IFIFO|0666, 0)
+}
+
+func TestTryOpen(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // makePipe is true if the test case should create the pipe.
+ makePipe bool
+
+ // flags are the fs.FileFlags used to open the pipe.
+ flags fs.FileFlags
+
+ // expectFile is true if a fs.File is expected.
+ expectFile bool
+
+ // err is the expected error.
+ err error
+ }{
+ {
+ desc: "FileFlags lacking Read and Write are invalid",
+ makePipe: false,
+ flags: fs.FileFlags{}, /* bogus */
+ expectFile: false,
+ err: syscall.EINVAL,
+ },
+ {
+ desc: "NonBlocking Read only error returns immediately",
+ makePipe: false, /* causes the error */
+ flags: fs.FileFlags{Read: true, NonBlocking: true},
+ expectFile: false,
+ err: syscall.ENOENT,
+ },
+ {
+ desc: "NonBlocking Read only success returns immediately",
+ makePipe: true,
+ flags: fs.FileFlags{Read: true, NonBlocking: true},
+ expectFile: true,
+ err: nil,
+ },
+ {
+ desc: "NonBlocking Write only error returns immediately",
+ makePipe: false, /* causes the error */
+ flags: fs.FileFlags{Write: true, NonBlocking: true},
+ expectFile: false,
+ err: syscall.ENOENT,
+ },
+ {
+ desc: "NonBlocking Write only no reader error returns immediately",
+ makePipe: true,
+ flags: fs.FileFlags{Write: true, NonBlocking: true},
+ expectFile: false,
+ err: syscall.ENXIO,
+ },
+ {
+ desc: "ReadWrite error returns immediately",
+ makePipe: false, /* causes the error */
+ flags: fs.FileFlags{Read: true, Write: true},
+ expectFile: false,
+ err: syscall.ENOENT,
+ },
+ {
+ desc: "ReadWrite returns immediately",
+ makePipe: true,
+ flags: fs.FileFlags{Read: true, Write: true},
+ expectFile: true,
+ err: nil,
+ },
+ {
+ desc: "Blocking Write only returns open error",
+ makePipe: false, /* causes the error */
+ flags: fs.FileFlags{Write: true},
+ expectFile: false,
+ err: syscall.ENOENT, /* from the missing pipe */
+ },
+ {
+ desc: "Blocking Read only returns open error",
+ makePipe: false, /* causes the error */
+ flags: fs.FileFlags{Read: true},
+ expectFile: false,
+ err: syscall.ENOENT,
+ },
+ {
+ desc: "Blocking Write only returns with syserror.ErrWouldBlock",
+ makePipe: true,
+ flags: fs.FileFlags{Write: true},
+ expectFile: false,
+ err: syserror.ErrWouldBlock,
+ },
+ {
+ desc: "Blocking Read only returns with syserror.ErrWouldBlock",
+ makePipe: true,
+ flags: fs.FileFlags{Read: true},
+ expectFile: false,
+ err: syserror.ErrWouldBlock,
+ },
+ } {
+ name := pipename()
+ if test.makePipe {
+ // Create the pipe. We do this per-test case to keep tests independent.
+ if err := mkpipe(name); err != nil {
+ t.Errorf("%s: failed to make host pipe: %v", test.desc, err)
+ continue
+ }
+ defer syscall.Unlink(name)
+ }
+
+ // Use a host opener to keep things simple.
+ opener := &hostOpener{name: name}
+
+ pipeOpenState := &pipeOpenState{}
+ ctx := contexttest.Context(t)
+ pipeOps, err := pipeOpenState.TryOpen(ctx, opener, test.flags)
+ if unwrapError(err) != test.err {
+ t.Errorf("%s: got error %v, want %v", test.desc, err, test.err)
+ if pipeOps != nil {
+ // Clean up the state of the pipe, and remove the fd from the
+ // fdnotifier. Sadly this is needed to maintain the correctness
+ // of other tests because the fdnotifier is global.
+ pipeOps.Release()
+ }
+ continue
+ }
+ if (pipeOps != nil) != test.expectFile {
+ t.Errorf("%s: got non-nil file %v, want %v", test.desc, pipeOps != nil, test.expectFile)
+ }
+ if pipeOps != nil {
+ // Same as above.
+ pipeOps.Release()
+ }
+ }
+}
+
+func TestPipeOpenUnblocksEventually(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // partnerIsReader is true if the goroutine opening the same pipe as the test case
+ // should open the pipe read only. Otherwise write only. This also means that the
+ // test case will open the pipe in the opposite way.
+ partnerIsReader bool
+
+ // partnerIsBlocking is true if the goroutine opening the same pipe as the test case
+ // should do so without the O_NONBLOCK flag; otherwise it opens the pipe with O_NONBLOCK
+ // repeatedly until ENXIO is no longer returned.
+ partnerIsBlocking bool
+ }{
+ {
+ desc: "Blocking Read with blocking writer partner opens eventually",
+ partnerIsReader: false,
+ partnerIsBlocking: true,
+ },
+ {
+ desc: "Blocking Write with blocking reader partner opens eventually",
+ partnerIsReader: true,
+ partnerIsBlocking: true,
+ },
+ {
+ desc: "Blocking Read with non-blocking writer partner opens eventually",
+ partnerIsReader: false,
+ partnerIsBlocking: false,
+ },
+ {
+ desc: "Blocking Write with non-blocking reader partner opens eventually",
+ partnerIsReader: true,
+ partnerIsBlocking: false,
+ },
+ } {
+ // Create the pipe. We do this per-test case to keep tests independent.
+ name := pipename()
+ if err := mkpipe(name); err != nil {
+ t.Errorf("%s: failed to make host pipe: %v", test.desc, err)
+ continue
+ }
+ defer syscall.Unlink(name)
+
+ // Spawn the partner.
+ type fderr struct {
+ fd int
+ err error
+ }
+ errch := make(chan fderr, 1)
+ go func() {
+ var flags int
+ if test.partnerIsReader {
+ flags = syscall.O_RDONLY
+ } else {
+ flags = syscall.O_WRONLY
+ }
+ if test.partnerIsBlocking {
+ fd, err := syscall.Open(name, flags, 0666)
+ errch <- fderr{fd: fd, err: err}
+ } else {
+ var fd int
+ err := error(syscall.ENXIO)
+ for err == syscall.ENXIO {
+ fd, err = syscall.Open(name, flags|syscall.O_NONBLOCK, 0666)
+ time.Sleep(1 * time.Second)
+ }
+ errch <- fderr{fd: fd, err: err}
+ }
+ }()
+
+ // Set up file flags for either a read only or write only open.
+ flags := fs.FileFlags{
+ Read: !test.partnerIsReader,
+ Write: test.partnerIsReader,
+ }
+
+ // Open the pipe in a blocking way, which should succeed eventually.
+ opener := &hostOpener{name: name}
+ ctx := contexttest.Context(t)
+ pipeOps, err := Open(ctx, opener, flags)
+ if pipeOps != nil {
+ // Same as TestTryOpen.
+ pipeOps.Release()
+ }
+
+ // Check that the partner opened the file successfully.
+ e := <-errch
+ if e.err != nil {
+ t.Errorf("%s: partner got error %v, wanted nil", test.desc, e.err)
+ continue
+ }
+ // If so, then close the partner fd to avoid leaking an fd.
+ syscall.Close(e.fd)
+
+ // Check that our blocking open was successful.
+ if err != nil {
+ t.Errorf("%s: blocking open got error %v, wanted nil", test.desc, err)
+ continue
+ }
+ if pipeOps == nil {
+ t.Errorf("%s: blocking open got nil file, wanted non-nil", test.desc)
+ continue
+ }
+ }
+}
+
+func TestCopiedReadAheadBuffer(t *testing.T) {
+ // Create the pipe.
+ name := pipename()
+ if err := mkpipe(name); err != nil {
+ t.Fatalf("failed to make host pipe: %v", err)
+ }
+ defer syscall.Unlink(name)
+
+ // We're taking advantage of the fact that pipes opened read only always return
+ // success, but internally they are not deemed "opened" until we're sure that
+ // another writer comes along. This means we can open the same pipe write only
+ // with no problems and write to it, given that opener.Open already tried to open
+ // the pipe RDONLY and succeeded, which we know happened if TryOpen returns
+ // syserror.ErrWouldBlock.
+ //
+ // This simulates the open(RDONLY) <-> open(WRONLY)+write race we care about, but
+ // does not cause our test to be racy (which would be terrible).
+ opener := &hostOpener{name: name}
+ pipeOpenState := &pipeOpenState{}
+ ctx := contexttest.Context(t)
+ pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true})
+ if pipeOps != nil {
+ pipeOps.Release()
+ t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY)
+ }
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("open(%s, %o) got error %v, want %v", name, syscall.O_RDONLY, err, syserror.ErrWouldBlock)
+ }
+
+ // Then open the same pipe write only and write some bytes to it. The next
+ // time we try to open the pipe read only again via the pipeOpenState, we should
+ // succeed and buffer some of the bytes written.
+ fd, err := syscall.Open(name, syscall.O_WRONLY, 0666)
+ if err != nil {
+ t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_WRONLY, err)
+ }
+ defer syscall.Close(fd)
+
+ data := []byte("hello")
+ if n, err := syscall.Write(fd, data); n != len(data) || err != nil {
+ t.Fatalf("write(%v) got (%d, %v), want (%d, nil)", data, n, err, len(data))
+ }
+
+ // Try the read again, knowing that it should succeed this time.
+ pipeOps, err = pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true})
+ if pipeOps == nil {
+ t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY)
+ }
+ defer pipeOps.Release()
+
+ if err != nil {
+ t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err)
+ }
+
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
+ Type: fs.Pipe,
+ })
+ file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, pipeOps)
+
+ // Check that the file we opened points to a pipe with a non-empty read ahead buffer.
+ bufsize := len(pipeOps.readAheadBuffer)
+ if bufsize != 1 {
+ t.Fatalf("read ahead buffer got %d bytes, want %d", bufsize, 1)
+ }
+
+ // Now for the final test, try to read everything in, expecting to get back all of
+ // the bytes that were written at once. Note that in the wild there is no atomic
+ // read size so expecting to get all bytes from a single writer when there are
+ // multiple readers is a bad expectation.
+ buf := make([]byte, len(data))
+ ioseq := usermem.BytesIOSequence(buf)
+ n, err := pipeOps.Read(ctx, file, ioseq, 0)
+ if err != nil {
+ t.Fatalf("read request got error %v, want nil", err)
+ }
+ if n != int64(len(data)) {
+ t.Fatalf("read request got %d bytes, want %d", n, len(data))
+ }
+ if !bytes.Equal(buf, data) {
+ t.Errorf("read request got bytes [%v], want [%v]", buf, data)
+ }
+}
+
+func TestPipeHangup(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // flags control how we open our end of the pipe and must be read
+ // only or write only. They also dictate how a coordinating partner
+ // fd is opened, which is their inverse (read only -> write only, etc).
+ flags fs.FileFlags
+
+ // hangupSelf if true causes the test case to close our end of the pipe
+ // and causes hangup errors to be asserted on our coordinating partner's
+ // fd. If hangupSelf is false, then our partner's fd is closed and the
+ // hangup errors are expected on our end of the pipe.
+ hangupSelf bool
+ }{
+ {
+ desc: "Read only gets hangup error",
+ flags: fs.FileFlags{Read: true},
+ },
+ {
+ desc: "Write only gets hangup error",
+ flags: fs.FileFlags{Write: true},
+ },
+ {
+ desc: "Read only generates hangup error",
+ flags: fs.FileFlags{Read: true},
+ hangupSelf: true,
+ },
+ {
+ desc: "Write only generates hangup error",
+ flags: fs.FileFlags{Write: true},
+ hangupSelf: true,
+ },
+ } {
+ if test.flags.Read == test.flags.Write {
+ t.Errorf("%s: test requires a single reader or writer", test.desc)
+ continue
+ }
+
+ // Create the pipe. We do this per-test case to keep tests independent.
+ name := pipename()
+ if err := mkpipe(name); err != nil {
+ t.Errorf("%s: failed to make host pipe: %v", test.desc, err)
+ continue
+ }
+ defer syscall.Unlink(name)
+
+ // Fire off a partner routine which tries to open the same pipe blocking,
+ // which will synchronize with us. The channel allows us to get back the
+ // fd once we expect this partner routine to succeed, so we can manifest
+ // hangup events more directly.
+ fdchan := make(chan int, 1)
+ go func() {
+ // Be explicit about the flags to protect the test from
+ // misconfiguration.
+ var flags int
+ if test.flags.Read {
+ flags = syscall.O_WRONLY
+ } else {
+ flags = syscall.O_RDONLY
+ }
+ fd, err := syscall.Open(name, flags, 0666)
+ if err != nil {
+ t.Logf("Open(%q, %o, 0666) partner failed: %v", name, flags, err)
+ }
+ fdchan <- fd
+ }()
+
+ // Open our end in a blocking way to ensure that we coordinate.
+ opener := &hostOpener{name: name}
+ ctx := contexttest.Context(t)
+ pipeOps, err := Open(ctx, opener, test.flags)
+ if err != nil {
+ t.Errorf("%s: Open got error %v, want nil", test.desc, err)
+ continue
+ }
+ // Don't defer file.DecRef here because that causes the hangup we're
+ // trying to test for.
+
+ // Expect the partner routine to have coordinated with us and get back
+ // its open fd.
+ f := <-fdchan
+ if f < 0 {
+ t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f)
+ pipeOps.Release()
+ continue
+ }
+
+ if test.hangupSelf {
+ // Hangup self and assert that our partner got the expected hangup
+ // error.
+ pipeOps.Release()
+
+ if test.flags.Read {
+ // Partner is writer.
+ assertWriterHungup(t, test.desc, fd.NewReadWriter(f))
+ } else {
+ // Partner is reader.
+ assertReaderHungup(t, test.desc, fd.NewReadWriter(f))
+ }
+ } else {
+ // Hangup our partner and expect us to get the hangup error.
+ syscall.Close(f)
+ defer pipeOps.Release()
+
+ if test.flags.Read {
+ assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file)
+ } else {
+ assertWriterHungup(t, test.desc, pipeOps.(*pipeOperations).file)
+ }
+ }
+ }
+}
+
+func assertReaderHungup(t *testing.T, desc string, reader io.Reader) bool {
+ // Drain the pipe completely; it may still have data in it, but expect EOF eventually.
+ var err error
+ for err == nil {
+ _, err = reader.Read(make([]byte, 10))
+ }
+ if err != io.EOF {
+ t.Errorf("%s: read from self after hangup got error %v, want %v", desc, err, io.EOF)
+ return false
+ }
+ return true
+}
+
+func assertWriterHungup(t *testing.T, desc string, writer io.Writer) bool {
+ if _, err := writer.Write([]byte("hello")); unwrapError(err) != syscall.EPIPE {
+ t.Errorf("%s: write to self after hangup got error %v, want %v", desc, err, syscall.EPIPE)
+ return false
+ }
+ return true
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go
new file mode 100644
index 000000000..af8230a7d
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe_state.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdpipe
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// beforeSave is invoked by stateify.
+func (p *pipeOperations) beforeSave() {
+ if p.flags.Read {
+ data, err := ioutil.ReadAll(p.file)
+ if err != nil && !isBlockError(err) {
+ panic(fmt.Sprintf("failed to read from pipe: %v", err))
+ }
+ p.readAheadBuffer = append(p.readAheadBuffer, data...)
+ } else if p.flags.Write {
+ file, err := p.opener.NonBlockingOpen(context.Background(), fs.PermMask{Write: true})
+ if err != nil {
+ panic(fs.ErrSaveRejection{fmt.Errorf("write-only pipe end cannot be re-opened as %v: %v", p, err)})
+ }
+ file.Close()
+ }
+}
+
+// saveFlags is invoked by stateify.
+func (p *pipeOperations) saveFlags() fs.FileFlags {
+ return p.flags
+}
+
+// readPipeOperationsLoading is used to ensure that write-only pipe fds are
+// opened after read/write and read-only pipe fds, to avoid ENXIO when
+// multiple pipe fds refer to different ends of the same pipe.
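+//
+// As a sketch of the intended restore ordering for two fds referring to the
+// same pipe (one read end, one write end):
+//
+//	loadFlags(read end):  readPipeOperationsLoading.Add(1)
+//	afterLoad(write end): readPipeOperationsLoading.Wait() blocks
+//	afterLoad(read end):  reopens the read end, then Done() unblocks the
+//	                      write end's open, avoiding ENXIO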
+var readPipeOperationsLoading sync.WaitGroup
+
+// loadFlags is invoked by stateify.
+func (p *pipeOperations) loadFlags(flags fs.FileFlags) {
+ // This is a hack to ensure that readPipeOperationsLoading includes all
+ // readable pipe fds before any asynchronous calls to
+ // readPipeOperationsLoading.Wait().
+ if flags.Read {
+ readPipeOperationsLoading.Add(1)
+ }
+ p.flags = flags
+}
+
+// afterLoad is invoked by stateify.
+func (p *pipeOperations) afterLoad() {
+ load := func() error {
+ if !p.flags.Read {
+ readPipeOperationsLoading.Wait()
+ } else {
+ defer readPipeOperationsLoading.Done()
+ }
+ var err error
+ p.file, err = p.opener.NonBlockingOpen(context.Background(), fs.PermMask{
+ Read: p.flags.Read,
+ Write: p.flags.Write,
+ })
+ if err != nil {
+ return fmt.Errorf("unable to open pipe %v: %v", p, err)
+ }
+ if err := p.init(); err != nil {
+ return fmt.Errorf("unable to initialize pipe %v: %v", p, err)
+ }
+ return nil
+ }
+
+ // Do background opening of pipe ends. Note for write-only pipe ends we
+ // have to do it asynchronously to avoid blocking the restore.
+ fs.Async(fs.CatchError(load))
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
new file mode 100644
index 000000000..a0082ecca
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -0,0 +1,505 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdpipe
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func singlePipeFD() (int, error) {
+ fds := make([]int, 2)
+ if err := syscall.Pipe(fds); err != nil {
+ return -1, err
+ }
+ syscall.Close(fds[1])
+ return fds[0], nil
+}
+
+func singleDirFD() (int, error) {
+ return syscall.Open(os.TempDir(), syscall.O_RDONLY, 0666)
+}
+
+func mockPipeDirent(t *testing.T) *fs.Dirent {
+ ctx := contexttest.Context(t)
+ node := fs.NewMockInodeOperations(ctx)
+ node.UAttr = fs.UnstableAttr{
+ Perms: fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ },
+ }
+ inode := fs.NewInode(ctx, node, fs.NewMockMountSource(nil), fs.StableAttr{
+ Type: fs.Pipe,
+ BlockSize: usermem.PageSize,
+ })
+ return fs.NewDirent(ctx, inode, "")
+}
+
+func TestNewPipe(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // getfd generates the fd to pass to newPipeOperations.
+ getfd func() (int, error)
+
+ // flags are the fs.FileFlags passed to newPipeOperations.
+ flags fs.FileFlags
+
+ // readAheadBuffer is the buffer passed to newPipeOperations.
+ readAheadBuffer []byte
+
+ // err is the expected error.
+ err error
+ }{
+ {
+ desc: "Cannot make new pipe from bad fd",
+ getfd: func() (int, error) { return -1, nil },
+ err: syscall.EINVAL,
+ },
+ {
+ desc: "Cannot make new pipe from non-pipe fd",
+ getfd: singleDirFD,
+ err: syscall.EINVAL,
+ },
+ {
+ desc: "Can make new pipe from pipe fd",
+ getfd: singlePipeFD,
+ flags: fs.FileFlags{Read: true},
+ readAheadBuffer: []byte("hello"),
+ },
+ } {
+ gfd, err := test.getfd()
+ if err != nil {
+ t.Errorf("%s: getfd got (%d, %v), want (fd, nil)", test.desc, gfd, err)
+ continue
+ }
+ f := fd.New(gfd)
+
+ p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer)
+ if p != nil {
+ // This is necessary to remove the fd from the global fd notifier.
+ defer p.Release()
+ } else {
+ // If there is no p to DecRef on, because newPipeOperations failed, then the
+ // file still needs to be closed.
+ defer f.Close()
+ }
+
+ if err != test.err {
+ t.Errorf("%s: got error %v, want %v", test.desc, err, test.err)
+ continue
+ }
+ // Check the state of the pipe given that it was successfully opened.
+ if err == nil {
+ if p == nil {
+ t.Errorf("%s: got nil pipe and nil error, want (pipe, nil)", test.desc)
+ continue
+ }
+ if flags := p.flags; test.flags != flags {
+ t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags)
+ continue
+ }
+ if len(test.readAheadBuffer) != len(p.readAheadBuffer) {
+ t.Errorf("%s: got read ahead buffer length %d, want %d", test.desc, len(p.readAheadBuffer), len(test.readAheadBuffer))
+ continue
+ }
+ fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(p.file.FD()), syscall.F_GETFL, 0)
+ if errno != 0 {
+ t.Errorf("%s: failed to get file flags for fd %d, got %v, want 0", test.desc, p.file.FD(), errno)
+ continue
+ }
+ if fileFlags&syscall.O_NONBLOCK == 0 {
+ t.Errorf("%s: pipe is blocking, expected non-blocking", test.desc)
+ continue
+ }
+ if !fdnotifier.HasFD(int32(f.FD())) {
+ t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD())
+ }
+ }
+ }
+}
+
+func TestPipeDestruction(t *testing.T) {
+ fds := make([]int, 2)
+ if err := syscall.Pipe(fds); err != nil {
+ t.Fatalf("failed to create pipes: got %v, want nil", err)
+ }
+ f := fd.New(fds[0])
+
+ // We don't care about the other end, just use the read end.
+ syscall.Close(fds[1])
+
+ // Test the read end, but it doesn't really matter which.
+ p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil)
+ if err != nil {
+ f.Close()
+ t.Fatalf("newPipeOperations got error %v, want nil", err)
+ }
+ // Drop our only reference, which should trigger the destructor.
+ p.Release()
+
+ if fdnotifier.HasFD(int32(fds[0])) {
+ t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0])
+ }
+ if p.file != nil {
+ t.Errorf("after DecRef got file, want nil")
+ }
+}
+
+type Seek struct{}
+
+type ReadDir struct{}
+
+type Writev struct {
+ Src usermem.IOSequence
+}
+
+type Readv struct {
+ Dst usermem.IOSequence
+}
+
+type Fsync struct{}
+
+func TestPipeRequest(t *testing.T) {
+ for _, test := range []struct {
+ // desc is the test's description.
+ desc string
+
+ // context is the request to execute.
+ context interface{}
+
+ // flags determines whether to use the read or write end
+ // of the pipe, for this test it can only be Read or Write.
+ flags fs.FileFlags
+
+ // keepOpenPartner if false closes the other end of the pipe,
+ // otherwise this is delayed until the end of the test.
+ keepOpenPartner bool
+
+ // err is the expected error.
+ err error
+ }{
+ {
+ desc: "ReadDir on pipe returns ENOTDIR",
+ context: &ReadDir{},
+ err: syscall.ENOTDIR,
+ },
+ {
+ desc: "Fsync on pipe returns EINVAL",
+ context: &Fsync{},
+ err: syscall.EINVAL,
+ },
+ {
+ desc: "Seek on pipe returns ESPIPE",
+ context: &Seek{},
+ err: syscall.ESPIPE,
+ },
+ {
+ desc: "Readv on pipe from empty buffer returns nil",
+ context: &Readv{Dst: usermem.BytesIOSequence(nil)},
+ flags: fs.FileFlags{Read: true},
+ },
+ {
+ desc: "Readv on pipe from non-empty buffer and closed partner returns EOF",
+ context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))},
+ flags: fs.FileFlags{Read: true},
+ err: io.EOF,
+ },
+ {
+ desc: "Readv on pipe from non-empty buffer and open partner returns EWOULDBLOCK",
+ context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))},
+ flags: fs.FileFlags{Read: true},
+ keepOpenPartner: true,
+ err: syserror.ErrWouldBlock,
+ },
+ {
+ desc: "Writev on pipe from empty buffer returns nil",
+ context: &Writev{Src: usermem.BytesIOSequence(nil)},
+ flags: fs.FileFlags{Write: true},
+ },
+ {
+ desc: "Writev on pipe from non-empty buffer and closed partner returns EPIPE",
+ context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))},
+ flags: fs.FileFlags{Write: true},
+ err: syscall.EPIPE,
+ },
+ {
+ desc: "Writev on pipe from non-empty buffer and open partner succeeds",
+ context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))},
+ flags: fs.FileFlags{Write: true},
+ keepOpenPartner: true,
+ },
+ } {
+ if test.flags.Read && test.flags.Write {
+ panic("both read and write not supported for this test")
+ }
+
+ fds := make([]int, 2)
+ if err := syscall.Pipe(fds); err != nil {
+ t.Errorf("%s: failed to create pipes: got %v, want nil", test.desc, err)
+ continue
+ }
+
+ // Configure the fd and partner fd based on the file flags.
+ testFd, partnerFd := fds[0], fds[1]
+ if test.flags.Write {
+ testFd, partnerFd = fds[1], fds[0]
+ }
+
+ // Configure closing the fds.
+ if test.keepOpenPartner {
+ defer syscall.Close(partnerFd)
+ } else {
+ syscall.Close(partnerFd)
+ }
+
+ // Create the pipe.
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, test.flags, fd.New(testFd), nil)
+ if err != nil {
+ t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err)
+ }
+ defer p.Release()
+
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
+ file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p)
+
+ // Issue request via the appropriate function.
+ switch c := test.context.(type) {
+ case *Seek:
+ _, err = p.Seek(ctx, file, 0, 0)
+ case *ReadDir:
+ _, err = p.Readdir(ctx, file, nil)
+ case *Readv:
+ _, err = p.Read(ctx, file, c.Dst, 0)
+ case *Writev:
+ _, err = p.Write(ctx, file, c.Src, 0)
+ case *Fsync:
+ err = p.Fsync(ctx, file, 0, fs.FileMaxOffset, fs.SyncAll)
+ default:
+ t.Errorf("%s: unknown request type %T", test.desc, test.context)
+ }
+
+ if unwrapError(err) != test.err {
+ t.Errorf("%s: got error %v, want %v", test.desc, err, test.err)
+ }
+ }
+}
+
+func TestPipeReadAheadBuffer(t *testing.T) {
+ fds := make([]int, 2)
+ if err := syscall.Pipe(fds); err != nil {
+ t.Fatalf("failed to create pipes: got %v, want nil", err)
+ }
+ rfile := fd.New(fds[0])
+
+ // Eventually close the write end, which is not wrapped in a pipe object.
+ defer syscall.Close(fds[1])
+
+ // Write some bytes to this end.
+ data := []byte("world")
+ if n, err := syscall.Write(fds[1], data); n != len(data) || err != nil {
+ rfile.Close()
+ t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data))
+ }
+
+ buffered := []byte("hello ")
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, buffered)
+ if err != nil {
+ rfile.Close()
+ t.Fatalf("newPipeOperations got error %v, want nil", err)
+ }
+ defer p.Release()
+
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
+ Type: fs.Pipe,
+ })
+ file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p)
+
+ // In total we expect to read data + buffered.
+ total := append(buffered, data...)
+
+ buf := make([]byte, len(total))
+ iov := usermem.BytesIOSequence(buf)
+ n, err := p.Read(contexttest.Context(t), file, iov, 0)
+ if err != nil {
+ t.Fatalf("read request got error %v, want nil", err)
+ }
+ if n != int64(len(total)) {
+ t.Fatalf("read request got %d bytes, want %d", n, len(total))
+ }
+ if !bytes.Equal(buf, total) {
+ t.Errorf("read request got bytes [%v], want [%v]", buf, total)
+ }
+}
+
+// TestPipeReadsAccumulate matters for pipes in general: reads can return
+// EWOULDBLOCK partway through, and blocking callers must keep reading until
+// they have accumulated all of the data (and report the total as such).
+func TestPipeReadsAccumulate(t *testing.T) {
+ fds := make([]int, 2)
+ if err := syscall.Pipe(fds); err != nil {
+ t.Fatalf("failed to create pipes: got %v, want nil", err)
+ }
+ rfile := fd.New(fds[0])
+
+ // Eventually close the write end, it doesn't depend on a pipe object.
+ defer syscall.Close(fds[1])
+
+ // Get a new read only pipe reference.
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, nil)
+ if err != nil {
+ rfile.Close()
+ t.Fatalf("newPipeOperations got error %v, want nil", err)
+ }
+ // Don't forget to remove the fd from the fd notifier. Otherwise other tests will
+ // likely be borked, because it's global :(
+ defer p.Release()
+
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
+ Type: fs.Pipe,
+ })
+ file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p)
+
+ // Write some bytes to the pipe.
+ data := []byte("some message")
+ if n, err := syscall.Write(fds[1], data); n != len(data) || err != nil {
+ t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data))
+ }
+
+ // Construct a segment vec that is a bit more than we have written so we
+ // trigger an EWOULDBLOCK.
+ wantBytes := len(data) + 1
+ readBuffer := make([]byte, wantBytes)
+ iov := usermem.BytesIOSequence(readBuffer)
+ n, err := p.Read(ctx, file, iov, 0)
+ total := n
+ iov = iov.DropFirst64(n)
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("Readv got error %v, want %v", err, syserror.ErrWouldBlock)
+ }
+
+ // Write a few more bytes to allow us to read more/accumulate.
+ extra := []byte("extra")
+ if n, err := syscall.Write(fds[1], extra); n != len(extra) || err != nil {
+ t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(extra))
+ }
+
+ // This time, using the same request, we should not block.
+ n, err = p.Read(ctx, file, iov, 0)
+ total += n
+ if err != nil {
+ t.Fatalf("Readv got error %v, want nil", err)
+ }
+
+ // Assert that the result we got back is cumulative.
+ if total != int64(wantBytes) {
+ t.Fatalf("Readv sequence got %d bytes, want %d", total, wantBytes)
+ }
+
+ if want := append(data, extra[0]); !bytes.Equal(readBuffer, want) {
+ t.Errorf("Readv sequence got %v, want %v", readBuffer, want)
+ }
+}
+
+// Same as TestPipeReadsAccumulate, but for writes.
+func TestPipeWritesAccumulate(t *testing.T) {
+ fds := make([]int, 2)
+ if err := syscall.Pipe(fds); err != nil {
+ t.Fatalf("failed to create pipes: got %v, want nil", err)
+ }
+ wfile := fd.New(fds[1])
+
+ // Eventually close the read end, it doesn't depend on a pipe object.
+ defer syscall.Close(fds[0])
+
+ // Get a new write only pipe reference.
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, fs.FileFlags{Write: true}, wfile, nil)
+ if err != nil {
+ wfile.Close()
+ t.Fatalf("newPipeOperations got error %v, want nil", err)
+ }
+ // Don't forget to remove the fd from the fd notifier. Otherwise other tests
+ // will likely be borked, because it's global :(
+ defer p.Release()
+
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
+ Type: fs.Pipe,
+ })
+ file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p)
+
+ pipeSize, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(wfile.FD()), syscall.F_GETPIPE_SZ, 0)
+ if errno != 0 {
+ t.Fatalf("fcntl(F_GETPIPE_SZ) failed: %v", errno)
+ }
+ t.Logf("Pipe buffer size: %d", pipeSize)
+
+ // Construct a segment vec that is larger than the pipe size to trigger an
+ // EWOULDBLOCK.
+ wantBytes := int(pipeSize) * 2
+ writeBuffer := make([]byte, wantBytes)
+ for i := 0; i < wantBytes; i++ {
+ writeBuffer[i] = 'a'
+ }
+ iov := usermem.BytesIOSequence(writeBuffer)
+ n, err := p.Write(ctx, file, iov, 0)
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("Writev got error %v, want %v", err, syserror.ErrWouldBlock)
+ }
+ if n != int64(pipeSize) {
+ t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize)
+ }
+ total := n
+ iov = iov.DropFirst64(n)
+
+ // Read the entire pipe buf size to make space for the second half.
+ readBuffer := make([]byte, n)
+ if n, err := syscall.Read(fds[0], readBuffer); n != len(readBuffer) || err != nil {
+ t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(readBuffer))
+ }
+ if !bytes.Equal(readBuffer, writeBuffer[:len(readBuffer)]) {
+ t.Fatalf("wrong data read from pipe, got: %v, want: %v", readBuffer, writeBuffer)
+ }
+
+ // This time we should not block.
+ n, err = p.Write(ctx, file, iov, 0)
+ if err != nil {
+ t.Fatalf("Writev got error %v, want nil", err)
+ }
+ if n != int64(pipeSize) {
+ t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize)
+ }
+ total += n
+
+ // Assert that the result we got back is cumulative.
+ if total != int64(wantBytes) {
+ t.Fatalf("Writev sequence got %d bytes, want %d", total, wantBytes)
+ }
+}
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
new file mode 100644
index 000000000..ca41520b4
--- /dev/null
+++ b/pkg/sentry/fs/file.go
@@ -0,0 +1,593 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "math"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ // RecordWaitTime controls writing metrics for filesystem reads.
+ // Enabling this comes at a small CPU cost due to performing two
+ // monotonic clock reads per read call.
+ //
+ // Note that this is only performed in the direct read path, and may
+ // not be consistently applied for other forms of reads, such as
+ // splice.
+ RecordWaitTime = false
+
+ reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
+ readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
+)
+
+// IncrementWait increments the given wait time metric, if enabled.
+func IncrementWait(m *metric.Uint64Metric, start time.Time) {
+ if !RecordWaitTime {
+ return
+ }
+ m.IncrementBy(uint64(time.Since(start)))
+}
+
+// FileMaxOffset is the maximum possible file offset.
+const FileMaxOffset = math.MaxInt64
+
+// File is an open file handle. It is thread-safe.
+//
+// File provides stronger synchronization guarantees than Linux. Linux
+// synchronizes lseek(2), read(2), and write(2) with respect to the file
+// offset for regular files and only for those interfaces. See
+// fs/read_write.c:fdget_pos, fs/read_write.c:fdput_pos and FMODE_ATOMIC_POS.
+//
+// In contrast, File synchronizes any operation that could take a long time
+// under a single abortable mutex which also synchronizes lseek(2), read(2),
+// and write(2).
+//
+// FIXME(b/38451980): Split synchronization from cancellation.
+//
+// +stateify savable
+type File struct {
+ refs.AtomicRefCount
+
+ // UniqueID is the globally unique identifier of the File.
+ UniqueID uint64
+
+ // Dirent is the Dirent backing this File. This encodes the name
+ // of the File via Dirent.FullName() as well as its identity via the
+ // Dirent's Inode. The Dirent is non-nil.
+ //
+ // A File holds a reference to this Dirent. Using the returned Dirent is
+ // only safe as long as a reference on the File is held. The association
+ // between a File and a Dirent is immutable.
+ //
+ // Files that are not parented in a filesystem return a root Dirent
+ // that holds a reference to their Inode.
+ //
+ // The name of the Dirent may reflect parentage if the Dirent is not a
+ // root Dirent or the identity of the File on a pseudo filesystem (pipefs,
+ // sockfs, etc).
+ //
+ // Multiple Files may hold a reference to the same Dirent. This is the
+ // common case for Files that are parented and maintain consistency with
+ // other files via the Dirent cache.
+ Dirent *Dirent
+
+ // flagsMu protects flags and async below.
+ flagsMu sync.Mutex `state:"nosave"`
+
+ // flags are the File's flags. Setting or getting flags is fully atomic
+ // and is not protected by mu (below).
+ flags FileFlags
+
+ // async handles O_ASYNC notifications.
+ async FileAsync
+
+ // saving indicates that this file is in the process of being saved.
+ saving bool `state:"nosave"`
+
+ // mu is dual-purpose: first, to make read(2) and write(2) thread-safe
+ // in conformity with POSIX, and second, to cancel operations before they
+ // begin in response to interruptions (i.e. signals).
+ mu amutex.AbortableMutex `state:"nosave"`
+
+ // FileOperations implements file system specific behavior for this File.
+ FileOperations FileOperations `state:"wait"`
+
+ // offset is the File's offset. Updating offset is protected by mu but
+ // can be read atomically via File.Offset() outside of mu.
+ offset int64
+}
+
+// NewFile returns a File. It takes a reference on the Dirent and owns the
+// lifetime of the FileOperations. Files that do not support reading and
+// writing at an arbitrary offset should set flags.Pread and flags.Pwrite
+// to false respectively.
+func NewFile(ctx context.Context, dirent *Dirent, flags FileFlags, fops FileOperations) *File {
+ dirent.IncRef()
+ f := File{
+ UniqueID: uniqueid.GlobalFromContext(ctx),
+ Dirent: dirent,
+ FileOperations: fops,
+ flags: flags,
+ }
+ f.mu.Init()
+ f.EnableLeakCheck("fs.File")
+ return &f
+}
+
+// DecRef destroys the File when it is no longer referenced.
+func (f *File) DecRef() {
+ f.DecRefWithDestructor(func() {
+ // Drop BSD style locks.
+ lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
+ f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng)
+
+ // Release resources held by the FileOperations.
+ f.FileOperations.Release()
+
+ // Release a reference on the Dirent.
+ f.Dirent.DecRef()
+
+ // Only unregister if we are currently registered. There is nothing
+ // to register if f.async is nil (this happens when async mode is
+ // enabled without setting an owner). Also, we unregister during
+ // save.
+ f.flagsMu.Lock()
+ if !f.saving && f.flags.Async && f.async != nil {
+ f.async.Unregister(f)
+ }
+ f.async = nil
+ f.flagsMu.Unlock()
+ })
+}
+
+// Flags atomically loads the File's flags.
+func (f *File) Flags() FileFlags {
+ f.flagsMu.Lock()
+ flags := f.flags
+ f.flagsMu.Unlock()
+ return flags
+}
+
+// SetFlags atomically changes the File's flags to the values contained
+// in newFlags. See SettableFileFlags for values that can be set.
+func (f *File) SetFlags(newFlags SettableFileFlags) {
+ f.flagsMu.Lock()
+ f.flags.Direct = newFlags.Direct
+ f.flags.NonBlocking = newFlags.NonBlocking
+ f.flags.Append = newFlags.Append
+ if f.async != nil {
+ if newFlags.Async && !f.flags.Async {
+ f.async.Register(f)
+ }
+ if !newFlags.Async && f.flags.Async {
+ f.async.Unregister(f)
+ }
+ }
+ f.flags.Async = newFlags.Async
+ f.flagsMu.Unlock()
+}
+
+// Offset atomically loads the File's offset.
+func (f *File) Offset() int64 {
+ return atomic.LoadInt64(&f.offset)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (f *File) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return f.FileOperations.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (f *File) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ f.FileOperations.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (f *File) EventUnregister(e *waiter.Entry) {
+ f.FileOperations.EventUnregister(e)
+}
+
+// Seek calls f.FileOperations.Seek with f as the File, updating the file
+// offset to the value returned by f.FileOperations.Seek if the operation
+// is successful.
+//
+// Returns syserror.ErrInterrupted if seeking was interrupted.
+func (f *File) Seek(ctx context.Context, whence SeekWhence, offset int64) (int64, error) {
+ if !f.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ newOffset, err := f.FileOperations.Seek(ctx, f, whence, offset)
+ if err == nil {
+ atomic.StoreInt64(&f.offset, newOffset)
+ }
+ return newOffset, err
+}
+
+// Readdir reads the directory entries of this File and writes them out
+// to the DentrySerializer until entries can no longer be written. If even
+// a single directory entry is written then Readdir returns a nil error
+// and the directory offset is advanced.
+//
+// Readdir unconditionally updates the access time on the File's Inode,
+// see fs/readdir.c:iterate_dir.
+//
+// Returns syserror.ErrInterrupted if reading was interrupted.
+func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ offset, err := f.FileOperations.Readdir(ctx, f, serializer)
+ atomic.StoreInt64(&f.offset, offset)
+ return err
+}
+
+// Readv calls f.FileOperations.Read with f as the File, advancing the file
+// offset if f.FileOperations.Read returns bytes read > 0.
+//
+// Returns syserror.ErrInterrupted if reading was interrupted.
+func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ var start time.Time
+ if RecordWaitTime {
+ start = time.Now()
+ }
+ if !f.mu.Lock(ctx) {
+ IncrementWait(readWait, start)
+ return 0, syserror.ErrInterrupted
+ }
+
+ reads.Increment()
+ n, err := f.FileOperations.Read(ctx, f, dst, f.offset)
+ if n > 0 && !f.flags.NonSeekable {
+ atomic.AddInt64(&f.offset, n)
+ }
+ f.mu.Unlock()
+ IncrementWait(readWait, start)
+ return n, err
+}
+
+// Preadv calls f.FileOperations.Read with f as the File. It does not
+// advance the file offset. If !f.Flags().Pread, Preadv should not be
+// called.
+//
+// Otherwise same as Readv.
+func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var start time.Time
+ if RecordWaitTime {
+ start = time.Now()
+ }
+ if !f.mu.Lock(ctx) {
+ IncrementWait(readWait, start)
+ return 0, syserror.ErrInterrupted
+ }
+
+ reads.Increment()
+ n, err := f.FileOperations.Read(ctx, f, dst, offset)
+ f.mu.Unlock()
+ IncrementWait(readWait, start)
+ return n, err
+}
+
+// Writev calls f.FileOperations.Write with f as the File, advancing the
+// file offset if f.FileOperations.Write returns bytes written > 0.
+//
+// Writev positions the write offset at EOF if f.Flags().Append. This is
+// unavoidably racy for network file systems. Writev also truncates src
+// to avoid overrunning the current file size limit if necessary.
+//
+// Returns syserror.ErrInterrupted if writing was interrupted.
+func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ if !f.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
+ // Handle append mode.
+ if f.Flags().Append {
+ if err := f.offsetForAppend(ctx, &f.offset); err != nil {
+ unlockAppendMu()
+ f.mu.Unlock()
+ return 0, err
+ }
+ }
+
+ // Enforce file limits.
+ limit, ok := f.checkLimit(ctx, f.offset)
+ switch {
+ case ok && limit == 0:
+ unlockAppendMu()
+ f.mu.Unlock()
+ return 0, syserror.ErrExceedsFileSizeLimit
+ case ok:
+ src = src.TakeFirst64(limit)
+ }
+
+ // We must hold the lock during the write.
+ n, err := f.FileOperations.Write(ctx, f, src, f.offset)
+ if n >= 0 && !f.flags.NonSeekable {
+ atomic.StoreInt64(&f.offset, f.offset+n)
+ }
+ unlockAppendMu()
+ f.mu.Unlock()
+ return n, err
+}
+
+// Pwritev calls f.FileOperations.Write with f as the File. It does not
+// advance the file offset. If !f.Flags().Pwrite, Pwritev should not be
+// called.
+//
+// Otherwise same as Writev.
+func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // "POSIX requires that opening a file with the O_APPEND flag should
+ // have no effect on the location at which pwrite() writes data.
+ // However, on Linux, if a file is opened with O_APPEND, pwrite()
+ // appends data to the end of the file, regardless of the value of
+ // offset."
+ unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
+ defer unlockAppendMu()
+ if f.Flags().Append {
+ if err := f.offsetForAppend(ctx, &offset); err != nil {
+ return 0, err
+ }
+ }
+
+ // Enforce file limits.
+ limit, ok := f.checkLimit(ctx, offset)
+ switch {
+ case ok && limit == 0:
+ return 0, syserror.ErrExceedsFileSizeLimit
+ case ok:
+ src = src.TakeFirst64(limit)
+ }
+
+ return f.FileOperations.Write(ctx, f, src, offset)
+}
+
+// offsetForAppend atomically sets the given offset to the end of the file.
+//
+// Precondition: the file.Dirent.Inode.appendMu mutex should be held for
+// writing.
+func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
+ uattr, err := f.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ // This is an odd error, we treat it as evidence that
+ // something is terribly wrong with the filesystem.
+ return syserror.EIO
+ }
+
+ // Update the offset.
+ atomic.StoreInt64(offset, uattr.Size)
+
+ return nil
+}
+
+// checkLimit checks the offset at which the write will be performed. The
+// returned boolean indicates whether the write must be limited, and the
+// returned integer gives the new maximum write length.
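+//
+// For example (a worked illustration, assuming a FileSize rlimit of 100
+// bytes on a regular file): a write at offset 80 returns (20, true), and a
+// write at offset 100 returns (0, true), meaning the write must fail with
+// ErrExceedsFileSizeLimit.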
+func (f *File) checkLimit(ctx context.Context, offset int64) (int64, bool) {
+ if IsRegular(f.Dirent.Inode.StableAttr) {
+ // Enforce size limits.
+ fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur
+ if fileSizeLimit <= math.MaxInt64 {
+ if offset >= int64(fileSizeLimit) {
+ return 0, true
+ }
+ return int64(fileSizeLimit) - offset, true
+ }
+ }
+
+ return 0, false
+}
+
+// Fsync calls f.FileOperations.Fsync with f as the File.
+//
+// Returns syserror.ErrInterrupted if syncing was interrupted.
+func (f *File) Fsync(ctx context.Context, start int64, end int64, syncType SyncType) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.Fsync(ctx, f, start, end, syncType)
+}
+
+// Flush calls f.FileOperations.Flush with f as the File.
+//
+// Returns syserror.ErrInterrupted if syncing was interrupted.
+func (f *File) Flush(ctx context.Context) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.Flush(ctx, f)
+}
+
+// ConfigureMMap calls f.FileOperations.ConfigureMMap with f as the File.
+//
+// Returns syserror.ErrInterrupted if interrupted.
+func (f *File) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.ConfigureMMap(ctx, f, opts)
+}
+
+// UnstableAttr calls f.FileOperations.UnstableAttr with f as the File.
+//
+// Returns syserror.ErrInterrupted if interrupted.
+func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
+ if !f.mu.Lock(ctx) {
+ return UnstableAttr{}, syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.UnstableAttr(ctx, f)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (f *File) MappedName(ctx context.Context) string {
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ name, _ := f.Dirent.FullName(root)
+ return name
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (f *File) DeviceID() uint64 {
+ return f.Dirent.Inode.StableAttr.DeviceID
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (f *File) InodeID() uint64 {
+ return f.Dirent.Inode.StableAttr.InodeID
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (f *File) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return f.Fsync(ctx, int64(mr.Start), int64(mr.End-1), SyncData)
+}
+
+// A FileAsync sends signals to its owner when w is ready for IO.
+type FileAsync interface {
+ Register(w waiter.Waitable)
+ Unregister(w waiter.Waitable)
+}
+
+// Async gets the stored FileAsync or creates a new one with the supplied
+// function. If the supplied function is nil, no FileAsync is created and the
+// current value is returned.
+func (f *File) Async(newAsync func() FileAsync) FileAsync {
+ f.flagsMu.Lock()
+ defer f.flagsMu.Unlock()
+ if f.async == nil && newAsync != nil {
+ f.async = newAsync()
+ if f.flags.Async {
+ f.async.Register(f)
+ }
+ }
+ return f.async
+}
+
+// lockedReader implements io.Reader and io.ReaderAt.
+//
+// Note this reads the underlying file using the file operations directly. It
+// is the responsibility of the caller to ensure that locks are appropriately
+// held and offsets updated if required. This should be used only by internal
+// functions that perform those operations and checks themselves.
+type lockedReader struct {
+ // Ctx is the context for the file reader.
+ Ctx context.Context
+
+ // File is the file to read from.
+ File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Read, not ReadAt.
+ Offset int64
+}
+
+// Read implements io.Reader.Read.
+func (r *lockedReader) Read(buf []byte) (int, error) {
+ if r.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset)
+ r.Offset += n
+ return int(n), err
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (r *lockedReader) ReadAt(buf []byte, offset int64) (int, error) {
+ if r.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), offset)
+ return int(n), err
+}
+
+// lockedWriter implements io.Writer and io.WriterAt.
+//
+// The same constraints as lockedReader apply; see above.
+type lockedWriter struct {
+ // Ctx is the context for the file writer.
+ Ctx context.Context
+
+ // File is the file to write to.
+ File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Write, not WriteAt.
+ Offset int64
+}
+
+// Write implements io.Writer.Write.
+func (w *lockedWriter) Write(buf []byte) (int, error) {
+ if w.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := w.WriteAt(buf, w.Offset)
+ w.Offset += int64(n)
+ return int(n), err
+}
+
+// WriteAt implements io.WriterAt.WriteAt.
+func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
+ var (
+ written int
+ err error
+ )
+ // The io.Writer contract requires that Write writes all available
+ // bytes and does not return short writes. This causes errors with
+ // io.Copy, since our own Write interface does not have this same
+ // contract. Enforce that here.
+ for written < len(buf) {
+ if w.Ctx.Interrupted() {
+ return written, syserror.ErrInterrupted
+ }
+ var n int64
+ n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
+ if n > 0 {
+ written += int(n)
+ }
+ if err != nil {
+ break
+ }
+ }
+ return written, err
+}
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
new file mode 100644
index 000000000..beba0f771
--- /dev/null
+++ b/pkg/sentry/fs/file_operations.go
@@ -0,0 +1,175 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SpliceOpts define how a splice works.
+type SpliceOpts struct {
+ // Length is the length of the splice operation.
+ Length int64
+
+ // SrcOffset indicates whether the existing source file offset should
+ // be used. If this is false, then the SrcStart value below is used.
+ //
+ // When passed to a FileOperations object this should always be true,
+ // as the offset will be provided by a layer above, unless the object
+ // in question is a pipe or socket; this value can be relied upon as
+ // such an indicator.
+ SrcOffset bool
+
+ // SrcStart is the start of the source file. This is used only if
+ // SrcOffset is false.
+ SrcStart int64
+
+ // Dup indicates that the contents should not be consumed from the
+ // source (e.g. in the case of a socket or a pipe), but duplicated.
+ Dup bool
+
+ // DstOffset indicates that the destination file offset should be used.
+ //
+ // See SrcOffset for additional information.
+ DstOffset bool
+
+ // DstStart is the start of the destination file. This is used only if
+ // DstOffset is false.
+ DstStart int64
+}
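+
+// As a usage illustration (hypothetical values, not taken from a caller in
+// this package), a duplicating 4 KiB splice that uses the current offset on
+// both ends would be configured as:
+//
+//	opts := SpliceOpts{
+//		Length:    4096,
+//		SrcOffset: true,
+//		DstOffset: true,
+//		Dup:       true,
+//	}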
+
+// FileOperations are operations on a File that diverge per file system.
+//
+// Operations that take a *File may use only the following interfaces:
+//
+// - File.UniqueID: Operations may only read this value.
+// - File.Dirent: Operations must not take or drop a reference.
+// - File.Offset(): This value is guaranteed to not change for the
+// duration of the operation.
+// - File.Flags(): This value may change during the operation.
+type FileOperations interface {
+ // Release releases resources held by FileOperations.
+ Release()
+
+ // Waitable defines how this File can be waited on for read and
+ // write readiness.
+ waiter.Waitable
+
+ // Seek seeks to offset based on SeekWhence. Returns the new
+ // offset or no change in the offset and an error.
+ Seek(ctx context.Context, file *File, whence SeekWhence, offset int64) (int64, error)
+
+ // Readdir reads the directory entries of file and serializes them
+ // using serializer.
+ //
+ // Returns the new directory offset or no change in the offset and
+ // an error. The offset returned must not be less than file.Offset().
+ //
+ // Serialization of directory entries must not happen asynchronously.
+ Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error)
+
+ // Read reads from file into dst at offset and returns the number
+ // of bytes read, which must be greater than or equal to 0. File
+ // systems that do not support reading at an offset (i.e. pipefs,
+ // sockfs) may ignore the offset. These file systems are expected
+ // to construct Files with !FileFlags.Pread.
+ //
+ // Read may return a nil error and only partially fill dst (at or
+ // before EOF). If the file represents a symlink, Read reads the target
+ // value of the symlink.
+ //
+ // Read does not check permissions nor flags.
+ //
+ // Read must not be called if !FileFlags.Read.
+ Read(ctx context.Context, file *File, dst usermem.IOSequence, offset int64) (int64, error)
+
+ // WriteTo is a variant of Read that takes another file as a
+ // destination. For a splice (copy or move from one file to another),
+ // first a WriteTo on the source is attempted, followed by a ReadFrom
+ // on the destination, followed by a buffered copy with standard Read
+ // and Write operations.
+ //
+ // If dup is set, the data should be duplicated into the destination
+ // and retained.
+ //
+ // The same preconditions as Read apply.
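+ //
+ // A hypothetical caller-side fallback implementing the order described
+ // above (the names here are illustrative, not part of this interface):
+ //
+ //   n, err := srcOps.WriteTo(ctx, srcFile, dstWriter, count, dup)
+ //   if err == syserror.ENOSYS {
+ //           n, err = dstOps.ReadFrom(ctx, dstFile, srcReader, count)
+ //   }
+ //   if err == syserror.ENOSYS {
+ //           // Fall back to a buffered copy using Read and Write.
+ //   }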
+ WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error)
+
+ // Write writes src to file at offset and returns the number of bytes
+ // written which must be greater than or equal to 0. Like Read, file
+ // systems that do not support writing at an offset (i.e. pipefs, sockfs)
+ // may ignore the offset. These file systems are expected to construct
+ // Files with !FileFlags.Pwrite.
+ //
+ // If only part of src could be written, Write must return an error
+ // indicating why (e.g. syserror.ErrWouldBlock).
+ //
+ // Write does not check permissions nor flags.
+ //
+ // Write must not be called if !FileFlags.Write.
+ Write(ctx context.Context, file *File, src usermem.IOSequence, offset int64) (int64, error)
+
+ // ReadFrom is a variant of Write that takes another file as a
+ // source. See WriteTo for details regarding how this is called.
+ //
+ // The same preconditions as Write apply; FileFlags.Write must be set.
+ ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error)
+
+ // Fsync writes buffered modifications of file and/or flushes in-flight
+ // operations to backing storage based on syncType. The range to sync is
+ // [start, end]. The end is inclusive so that the last byte of a maximally
+ // sized file can be synced.
+ Fsync(ctx context.Context, file *File, start, end int64, syncType SyncType) error
+
+ // Flush this file's buffers/state (on close(2)).
+ Flush(ctx context.Context, file *File) error
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file. Most
+ // implementations can either embed fsutil.FileNoMMap (if they don't support
+ // memory mapping) or call fsutil.GenericConfigureMMap with the appropriate
+ // memmap.Mappable.
+ ConfigureMMap(ctx context.Context, file *File, opts *memmap.MMapOpts) error
+
+ // UnstableAttr returns the "unstable" attributes of the inode represented
+ // by the file. Most implementations can embed
+ // fsutil.FileUseInodeUnstableAttr, which delegates to
+ // InodeOperations.UnstableAttr.
+ UnstableAttr(ctx context.Context, file *File) (UnstableAttr, error)
+
+ // Ioctl implements the ioctl(2) linux syscall.
+ //
+ // io provides access to the virtual memory space to which pointers in args
+ // refer.
+ //
+ // Preconditions: The AddressSpace (if any) that io refers to is activated.
+ Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error)
+}
+
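Most implementations do not hand-write all of these methods; they embed the no-op and default helpers from pkg/sentry/fs/fsutil and pkg/waiter (the filetest package later in this diff does exactly that) and override only what they need. A minimal read-only sketch under that assumption (staticFile, its data field, and the EPERM write are illustrative, not part of this change):

```go
package myfs // hypothetical package for illustration

import (
	"io"

	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/fs"
	"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
	"gvisor.dev/gvisor/pkg/syserror"
	"gvisor.dev/gvisor/pkg/usermem"
	"gvisor.dev/gvisor/pkg/waiter"
)

// staticFile serves a fixed byte slice; all other FileOperations
// methods are satisfied by the fsutil/waiter embeddings.
type staticFile struct {
	fsutil.FileNoopRelease          `state:"nosave"`
	fsutil.FilePipeSeek             `state:"nosave"`
	fsutil.FileNotDirReaddir        `state:"nosave"`
	fsutil.FileNoFsync              `state:"nosave"`
	fsutil.FileNoopFlush            `state:"nosave"`
	fsutil.FileNoMMap               `state:"nosave"`
	fsutil.FileNoIoctl              `state:"nosave"`
	fsutil.FileNoSplice             `state:"nosave"`
	fsutil.FileUseInodeUnstableAttr `state:"nosave"`
	waiter.AlwaysReady              `state:"nosave"`

	data []byte
}

// Read copies out bytes starting at offset, returning io.EOF at the end.
func (f *staticFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
	if offset >= int64(len(f.data)) {
		return 0, io.EOF
	}
	n, err := dst.CopyOut(ctx, f.data[offset:])
	return int64(n), err
}

// Write rejects all writes; such a file should be opened with
// !FileFlags.Write so this is never reached in practice.
func (*staticFile) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
	return 0, syserror.EPERM
}
```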
+// FifoSizer is an interface for setting and getting the size of a pipe.
+type FifoSizer interface {
+ // FifoSize returns the pipe capacity in bytes.
+ FifoSize(ctx context.Context, file *File) (int64, error)
+
+ // SetFifoSize sets the new pipe capacity in bytes.
+ //
+ // The new size is returned (which may be capped).
+ SetFifoSize(size int64) (int64, error)
+}
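Callers reach this interface via a type assertion on a file's FileOperations, the same pattern overlayFileOperations.FifoSize uses later in this diff. A sketch (pipeCapacity and its package are hypothetical):

```go
package example // hypothetical

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/fs"
	"gvisor.dev/gvisor/pkg/syserror"
)

// pipeCapacity returns file's capacity if it is backed by a pipe,
// mirroring the type-assertion pattern used by the overlay.
func pipeCapacity(ctx context.Context, file *fs.File) (int64, error) {
	sz, ok := file.FileOperations.(fs.FifoSizer)
	if !ok {
		return 0, syserror.EINVAL // not a pipe-backed file
	}
	return sz.FifoSize(ctx, file)
}
```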
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
new file mode 100644
index 000000000..dcc1df38f
--- /dev/null
+++ b/pkg/sentry/fs/file_overlay.go
@@ -0,0 +1,556 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// overlayFile gets a handle to a file from the upper or lower filesystem
+// in an overlay. The caller is responsible for calling File.DecRef on
+// the returned file.
+func overlayFile(ctx context.Context, inode *Inode, flags FileFlags) (*File, error) {
+ // Do a song and dance to eventually get to:
+ //
+ // File -> single reference
+ // Dirent -> single reference
+ // Inode -> multiple references
+ //
+ // So that File.DecRef() -> File.destroy -> Dirent.DecRef -> Dirent.destroy,
+ // and both the transitory File and Dirent can be GC'ed but the Inode
+ // remains.
+
+ // Take another reference on the Inode.
+ inode.IncRef()
+
+ // Start with a single reference on the Dirent. It inherits the reference
+ // we just took on the Inode above.
+ dirent := NewTransientDirent(inode)
+
+ // Get a File. This will take another reference on the Dirent.
+ f, err := inode.GetFile(ctx, dirent, flags)
+
+ // Drop the extra reference on the Dirent. Now there's only one reference
+ // on the dirent, either owned by f (if non-nil), or the Dirent is about
+ // to be destroyed (if GetFile failed).
+ dirent.DecRef()
+
+ return f, err
+}
+
+// overlayFileOperations implements FileOperations for a file in an overlay.
+//
+// +stateify savable
+type overlayFileOperations struct {
+ // upperMu protects upper below. In contrast lower is stable.
+ upperMu sync.Mutex `state:"nosave"`
+
+ // We can't share Files in upper and lower filesystems between all Files
+ // in an overlay because some file systems expect to get distinct handles
+ // that are not consistent with each other on open(2).
+ //
+ // So we lazily acquire an upper File when the overlayEntry acquires an
+ // upper Inode (it might have one from the start). This synchronizes with
+ // copy up.
+ //
+ // If upper is non-nil and this is not a directory, then lower is ignored.
+ //
+ // For directories, upper and lower are ignored because it is always
+ // necessary to acquire new directory handles so that the directory cursors
+ // of the upper and lower Files are not exhausted.
+ upper *File
+ lower *File
+
+ // dirCursor is a directory cursor for a directory in an overlay. It is
+ // protected by File.mu of the owning file, which is held during
+ // Readdir and Seek calls.
+ dirCursor string
+}
+
+// Release implements FileOperations.Release.
+func (f *overlayFileOperations) Release() {
+ if f.upper != nil {
+ f.upper.DecRef()
+ }
+ if f.lower != nil {
+ f.lower.DecRef()
+ }
+}
+
+// EventRegister implements FileOperations.EventRegister.
+func (f *overlayFileOperations) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+ if f.upper != nil {
+ f.upper.EventRegister(we, mask)
+ return
+ }
+ f.lower.EventRegister(we, mask)
+}
+
+// EventUnregister implements FileOperations.EventUnregister.
+func (f *overlayFileOperations) EventUnregister(we *waiter.Entry) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+ if f.upper != nil {
+ f.upper.EventUnregister(we)
+ return
+ }
+ f.lower.EventUnregister(we)
+}
+
+// Readiness implements FileOperations.Readiness.
+func (f *overlayFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+ if f.upper != nil {
+ return f.upper.Readiness(mask)
+ }
+ return f.lower.Readiness(mask)
+}
+
+// Seek implements FileOperations.Seek.
+func (f *overlayFileOperations) Seek(ctx context.Context, file *File, whence SeekWhence, offset int64) (int64, error) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+
+ var seekDir bool
+ var n int64
+ if f.upper != nil {
+ var err error
+ if n, err = f.upper.FileOperations.Seek(ctx, file, whence, offset); err != nil {
+ return n, err
+ }
+ seekDir = IsDir(f.upper.Dirent.Inode.StableAttr)
+ } else {
+ var err error
+ if n, err = f.lower.FileOperations.Seek(ctx, file, whence, offset); err != nil {
+ return n, err
+ }
+ seekDir = IsDir(f.lower.Dirent.Inode.StableAttr)
+ }
+
+ // If this was a seek on a directory, we must update the cursor.
+ if seekDir && whence == SeekSet && offset == 0 {
+ // Currently only seeking to 0 on a directory is supported.
+ // FIXME(b/33075855): Lift directory seeking limitations.
+ f.dirCursor = ""
+ }
+ return n, nil
+}
+
+// Readdir implements FileOperations.Readdir.
+func (f *overlayFileOperations) Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error) {
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ dirCtx := &DirCtx{
+ Serializer: serializer,
+ DirCursor: &f.dirCursor,
+ }
+ return DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+}
+
+// IterateDir implements DirIterator.IterateDir.
+func (f *overlayFileOperations) IterateDir(ctx context.Context, d *Dirent, dirCtx *DirCtx, offset int) (int, error) {
+ o := d.Inode.overlay
+
+ if !d.Inode.MountSource.CacheReaddir() {
+ // Can't use the dirCache. Simply read the entries.
+ entries, err := readdirEntries(ctx, o)
+ if err != nil {
+ return offset, err
+ }
+ n, err := GenericReaddir(dirCtx, entries)
+ return offset + n, err
+ }
+
+ // Otherwise, use or create cached entries.
+
+ o.dirCacheMu.RLock()
+ if o.dirCache != nil {
+ n, err := GenericReaddir(dirCtx, o.dirCache)
+ o.dirCacheMu.RUnlock()
+ return offset + n, err
+ }
+ o.dirCacheMu.RUnlock()
+
+ // readdirEntries holds o.copyUpMu to ensure that copy-up does not
+ // occur while calculating the readdir results.
+ //
+ // However, it is possible for a copy-up to occur after the call to
+ // readdirEntries, but before setting o.dirCache. This is OK, since
+ // copy-up does not change the children in a way that would affect the
+ // children returned in dirCache. Copy-up only moves files/directories
+ // between layers in the overlay.
+ //
+ // We must hold dirCacheMu around both readdirEntries and setting
+ // o.dirCache to synchronize with dirCache invalidations done by
+ // Create, Remove, Rename.
+ o.dirCacheMu.Lock()
+
+ // We expect dirCache to be nil (we just checked above), but there is a
+ // chance that a racing call managed to just set it, in which case we
+ // can use that new value.
+ if o.dirCache == nil {
+ dirCache, err := readdirEntries(ctx, o)
+ if err != nil {
+ o.dirCacheMu.Unlock()
+ return offset, err
+ }
+ o.dirCache = dirCache
+ }
+
+ o.dirCacheMu.DowngradeLock()
+ n, err := GenericReaddir(dirCtx, o.dirCache)
+ o.dirCacheMu.RUnlock()
+
+ return offset + n, err
+}
+
+// onTop performs the given operation on the top-most available layer.
+func (f *overlayFileOperations) onTop(ctx context.Context, file *File, fn func(*File, FileOperations) error) error {
+ file.Dirent.Inode.overlay.copyMu.RLock()
+ defer file.Dirent.Inode.overlay.copyMu.RUnlock()
+
+ // Only lower layer is available.
+ if file.Dirent.Inode.overlay.upper == nil {
+ return fn(f.lower, f.lower.FileOperations)
+ }
+
+ f.upperMu.Lock()
+ if f.upper == nil {
+ upper, err := overlayFile(ctx, file.Dirent.Inode.overlay.upper, file.Flags())
+ if err != nil {
+ // Something very wrong; return a generic filesystem
+ // error to avoid propagating internals.
+ f.upperMu.Unlock()
+ return syserror.EIO
+ }
+
+ // Save upper file.
+ f.upper = upper
+ }
+ f.upperMu.Unlock()
+
+ return fn(f.upper, f.upper.FileOperations)
+}
+
+// Read implements FileOperations.Read.
+func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst usermem.IOSequence, offset int64) (n int64, err error) {
+ err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
+ n, err = ops.Read(ctx, file, dst, offset)
+ return err // Will overwrite itself.
+ })
+ return
+}
+
+// WriteTo implements FileOperations.WriteTo.
+func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) {
+ err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
+ n, err = ops.WriteTo(ctx, file, dst, count, dup)
+ return err // Will overwrite itself.
+ })
+ return
+}
+
+// Write implements FileOperations.Write.
+func (f *overlayFileOperations) Write(ctx context.Context, file *File, src usermem.IOSequence, offset int64) (int64, error) {
+ // f.upper must be non-nil. See inode_overlay.go:overlayGetFile, where the
+ // file is copied up and opened in the upper filesystem if FileFlags.Write.
+ // Write cannot be called if !FileFlags.Write, see FileOperations.Write.
+ return f.upper.FileOperations.Write(ctx, f.upper, src, offset)
+}
+
+// ReadFrom implements FileOperations.ReadFrom.
+func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) {
+ // See above; f.upper must be non-nil.
+ return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count)
+}
+
+// Fsync implements FileOperations.Fsync.
+func (f *overlayFileOperations) Fsync(ctx context.Context, file *File, start, end int64, syncType SyncType) (err error) {
+ f.upperMu.Lock()
+ if f.upper != nil {
+ err = f.upper.FileOperations.Fsync(ctx, f.upper, start, end, syncType)
+ }
+ f.upperMu.Unlock()
+ if err == nil && f.lower != nil {
+ // N.B. Fsync on the lower filesystem can cause writes of file
+ // attributes (i.e. access time) despite the fact that we must
+ // treat the lower filesystem as read-only.
+ //
+ // This matches the semantics of fsync(2) in Linux overlayfs.
+ err = f.lower.FileOperations.Fsync(ctx, f.lower, start, end, syncType)
+ }
+ return err
+}
+
+// Flush implements FileOperations.Flush.
+func (f *overlayFileOperations) Flush(ctx context.Context, file *File) (err error) {
+ // Flush whatever handles we have.
+ f.upperMu.Lock()
+ if f.upper != nil {
+ err = f.upper.FileOperations.Flush(ctx, f.upper)
+ }
+ f.upperMu.Unlock()
+ if err == nil && f.lower != nil {
+ err = f.lower.FileOperations.Flush(ctx, f.lower)
+ }
+ return err
+}
+
+// ConfigureMMap implements FileOperations.ConfigureMMap.
+func (*overlayFileOperations) ConfigureMMap(ctx context.Context, file *File, opts *memmap.MMapOpts) error {
+ o := file.Dirent.Inode.overlay
+
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ // If there is no lower inode, the overlay will never need to do a
+ // copy-up, and thus will never need to invalidate any mappings. We can
+ // call ConfigureMMap directly on the upper file.
+ if o.lower == nil {
+ f := file.FileOperations.(*overlayFileOperations)
+ if err := f.upper.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+
+ // ConfigureMMap will set the MappableIdentity to the upper
+ // file and take a reference on it, but we must also hold a
+ // reference to the overlay file during the lifetime of the
+ // Mappable. If we do not do this, the overlay file can be
+ // Released before the upper file is Released, and we will be
+ // unable to traverse to the upper file during Save, thus
+ // preventing us from saving a proper inode mapping for the
+ // file.
+ file.IncRef()
+ id := overlayMappingIdentity{
+ id: opts.MappingIdentity,
+ overlayFile: file,
+ }
+ id.EnableLeakCheck("fs.overlayMappingIdentity")
+
+ // Swap out the old MappingIdentity for the wrapped one.
+ opts.MappingIdentity = &id
+ return nil
+ }
+
+ if !o.isMappableLocked() {
+ return syserror.ENODEV
+ }
+
+ // FIXME(jamieliu): This is a copy/paste of fsutil.GenericConfigureMMap,
+ // which we can't use because the overlay implementation is in package fs,
+ // so depending on fs/fsutil would create a circular dependency. Move
+ // overlay to fs/overlay.
+ opts.Mappable = o
+ opts.MappingIdentity = file
+ file.IncRef()
+ return nil
+}
+
+// UnstableAttr implements fs.FileOperations.UnstableAttr.
+func (f *overlayFileOperations) UnstableAttr(ctx context.Context, file *File) (UnstableAttr, error) {
+ // Hot path. Avoid defers.
+ f.upperMu.Lock()
+ if f.upper != nil {
+ attr, err := f.upper.UnstableAttr(ctx)
+ f.upperMu.Unlock()
+ return attr, err
+ }
+ f.upperMu.Unlock()
+
+ // It's possible that copy-up has occurred, but we haven't opened an
+ // upper file yet. If this is the case, just use the upper inode's
+ // UnstableAttr rather than opening a file.
+ o := file.Dirent.Inode.overlay
+ o.copyMu.RLock()
+ if o.upper != nil {
+ attr, err := o.upper.UnstableAttr(ctx)
+ o.copyMu.RUnlock()
+ return attr, err
+ }
+ o.copyMu.RUnlock()
+
+ return f.lower.UnstableAttr(ctx)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (f *overlayFileOperations) Ioctl(ctx context.Context, overlayFile *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+
+ if f.upper == nil {
+ // It's possible that ioctl changes the file. Since we don't know all
+ // possible ioctls, only allow them to propagate to the upper. Triggering a
+ // copy up on any ioctl would be too drastic. In the future, it can have a
+ // list of ioctls that are safe to send to lower and a list that triggers a
+ // copy up.
+ return 0, syserror.ENOTTY
+ }
+ return f.upper.FileOperations.Ioctl(ctx, f.upper, io, args)
+}
+
+// FifoSize implements FifoSizer.FifoSize.
+func (f *overlayFileOperations) FifoSize(ctx context.Context, overlayFile *File) (rv int64, err error) {
+ err = f.onTop(ctx, overlayFile, func(file *File, ops FileOperations) error {
+ sz, ok := ops.(FifoSizer)
+ if !ok {
+ return syserror.EINVAL
+ }
+ rv, err = sz.FifoSize(ctx, file)
+ return err
+ })
+ return
+}
+
+// SetFifoSize implements FifoSizer.SetFifoSize.
+func (f *overlayFileOperations) SetFifoSize(size int64) (rv int64, err error) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+
+ if f.upper == nil {
+ // Named pipes cannot be copied up and changes to the lower are prohibited.
+ return 0, syserror.EINVAL
+ }
+ sz, ok := f.upper.FileOperations.(FifoSizer)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ return sz.SetFifoSize(size)
+}
+
+// readdirEntries returns a sorted map of directory entries from the
+// upper and/or lower filesystem.
+func readdirEntries(ctx context.Context, o *overlayEntry) (*SortedDentryMap, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ // Assert that there is at least one upper or lower entry.
+ if o.upper == nil && o.lower == nil {
+ panic("invalid overlayEntry, needs at least one Inode")
+ }
+ entries := make(map[string]DentAttr)
+
+ // Try the upper filesystem first.
+ if o.upper != nil {
+ var err error
+ entries, err = readdirOne(ctx, NewTransientDirent(o.upper))
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Try the lower filesystem next.
+ if o.lower != nil {
+ lowerEntries, err := readdirOne(ctx, NewTransientDirent(o.lower))
+ if err != nil {
+ return nil, err
+ }
+ for name, entry := range lowerEntries {
+ // Skip this name if it is a negative entry in the
+ // upper or there exists a whiteout for it.
+ if o.upper != nil {
+ if overlayHasWhiteout(ctx, o.upper, name) {
+ continue
+ }
+ }
+ // Prefer the entries from the upper filesystem
+ // when names overlap.
+ if _, ok := entries[name]; !ok {
+ entries[name] = entry
+ }
+ }
+ }
+
+ // Sort and return the entries.
+ return NewSortedDentryMap(entries), nil
+}
+
+// readdirOne reads all of the directory entries from d.
+func readdirOne(ctx context.Context, d *Dirent) (map[string]DentAttr, error) {
+ dir, err := d.Inode.GetFile(ctx, d, FileFlags{Read: true})
+ if err != nil {
+ return nil, err
+ }
+ defer dir.DecRef()
+
+ // Use a stub serializer to read the entries into memory.
+ stubSerializer := &CollectEntriesSerializer{}
+ if err := dir.Readdir(ctx, stubSerializer); err != nil {
+ return nil, err
+ }
+ // The "." and ".." entries are from the overlay Inode's Dirent, not the stub.
+ delete(stubSerializer.Entries, ".")
+ delete(stubSerializer.Entries, "..")
+ return stubSerializer.Entries, nil
+}
+
+// overlayMappingIdentity wraps a MappingIdentity, and also holds a reference
+// on a file during its lifetime.
+//
+// +stateify savable
+type overlayMappingIdentity struct {
+ refs.AtomicRefCount
+ id memmap.MappingIdentity
+ overlayFile *File
+}
+
+// DecRef implements AtomicRefCount.DecRef.
+func (omi *overlayMappingIdentity) DecRef() {
+ omi.AtomicRefCount.DecRefWithDestructor(func() {
+ omi.overlayFile.DecRef()
+ omi.id.DecRef()
+ })
+}
+
+// DeviceID implements MappingIdentity.DeviceID using the device id from the
+// overlayFile.
+func (omi *overlayMappingIdentity) DeviceID() uint64 {
+ return omi.overlayFile.Dirent.Inode.StableAttr.DeviceID
+}
+
+// InodeID implements MappingIdentity.InodeID using the inode id from the
+// overlayFile.
+func (omi *overlayMappingIdentity) InodeID() uint64 {
+ return omi.overlayFile.Dirent.Inode.StableAttr.InodeID
+}
+
+// MappedName implements MappingIdentity.MappedName.
+func (omi *overlayMappingIdentity) MappedName(ctx context.Context) string {
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ name, _ := omi.overlayFile.Dirent.FullName(root)
+ return name
+}
+
+// Msync implements MappingIdentity.Msync.
+func (omi *overlayMappingIdentity) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return omi.id.Msync(ctx, mr)
+}
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
new file mode 100644
index 000000000..1971cc680
--- /dev/null
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -0,0 +1,192 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs_test
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+)
+
+func TestReaddir(t *testing.T) {
+ ctx := contexttest.Context(t)
+ ctx = &rootContext{
+ Context: ctx,
+ root: fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root"),
+ }
+ for _, test := range []struct {
+ // Test description.
+ desc string
+
+ // Lookup parameters.
+ dir *fs.Inode
+
+ // Want from lookup.
+ err error
+ names []string
+ }{
+ {
+ desc: "no upper, lower has entries",
+ dir: fs.NewTestOverlayDir(ctx,
+ nil, /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "a"},
+ {name: "b"},
+ }, nil), /* lower */
+ false /* revalidate */),
+ names: []string{".", "..", "a", "b"},
+ },
+ {
+ desc: "upper has entries, no lower",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "a"},
+ {name: "b"},
+ }, nil), /* upper */
+ nil, /* lower */
+ false /* revalidate */),
+ names: []string{".", "..", "a", "b"},
+ },
+ {
+ desc: "upper and lower, entries combine",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "a"},
+ }, nil), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "b"},
+ }, nil), /* lower */
+ false /* revalidate */),
+ names: []string{".", "..", "a", "b"},
+ },
+ {
+ desc: "upper and lower, entries combine, none are masked",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "a"},
+ }, []string{"b"}), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "c"},
+ }, nil), /* lower */
+ false /* revalidate */),
+ names: []string{".", "..", "a", "c"},
+ },
+ {
+ desc: "upper and lower, entries combine, upper masks some of lower",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "a"},
+ }, []string{"b"}), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {name: "b"}, /* will be masked */
+ {name: "c"},
+ }, nil), /* lower */
+ false /* revalidate */),
+ names: []string{".", "..", "a", "c"},
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ openDir, err := test.dir.GetFile(ctx, fs.NewDirent(ctx, test.dir, "stub"), fs.FileFlags{Read: true})
+ if err != nil {
+ t.Fatalf("GetFile got error %v, want nil", err)
+ }
+ stubSerializer := &fs.CollectEntriesSerializer{}
+ err = openDir.Readdir(ctx, stubSerializer)
+ if err != test.err {
+ t.Fatalf("Readdir got error %v, want nil", err)
+ }
+ if err != nil {
+ return
+ }
+ if !reflect.DeepEqual(stubSerializer.Order, test.names) {
+ t.Errorf("Readdir got names %v, want %v", stubSerializer.Order, test.names)
+ }
+ })
+ }
+}
+
+func TestReaddirRevalidation(t *testing.T) {
+ ctx := contexttest.Context(t)
+ ctx = &rootContext{
+ Context: ctx,
+ root: fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root"),
+ }
+
+ // Create an overlay with two directories, each with one file.
+ upper := newTestRamfsDir(ctx, []dirContent{{name: "a"}}, nil)
+ lower := newTestRamfsDir(ctx, []dirContent{{name: "b"}}, nil)
+ overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */)
+
+ // Get a handle to the dirent in the upper filesystem so that we can
+ // modify it without going through the dirent.
+ upperDir := upper.InodeOperations.(*dir).InodeOperations.(*ramfs.Dir)
+
+ // Check that overlay returns the files from both upper and lower.
+ openDir, err := overlay.GetFile(ctx, fs.NewDirent(ctx, overlay, "stub"), fs.FileFlags{Read: true})
+ if err != nil {
+ t.Fatalf("GetFile got error %v, want nil", err)
+ }
+ ser := &fs.CollectEntriesSerializer{}
+ if err := openDir.Readdir(ctx, ser); err != nil {
+ t.Fatalf("Readdir got error %v, want nil", err)
+ }
+ got, want := ser.Order, []string{".", "..", "a", "b"}
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Readdir got names %v, want %v", got, want)
+ }
+
+ // Remove "a" from the upper and add "c".
+ if err := upperDir.Remove(ctx, upper, "a"); err != nil {
+ t.Fatalf("error removing child: %v", err)
+ }
+ upperDir.AddChild(ctx, "c", fs.NewInode(ctx, fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermissions{}, 0),
+ upper.MountSource, fs.StableAttr{Type: fs.RegularFile}))
+
+ // Seek to beginning of the directory and do the readdir again.
+ if _, err := openDir.Seek(ctx, fs.SeekSet, 0); err != nil {
+ t.Fatalf("error seeking to beginning of dir: %v", err)
+ }
+ ser = &fs.CollectEntriesSerializer{}
+ if err := openDir.Readdir(ctx, ser); err != nil {
+ t.Fatalf("Readdir got error %v, want nil", err)
+ }
+
+ // Readdir should return the updated children.
+ got, want = ser.Order, []string{".", "..", "b", "c"}
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Readdir got names %v, want %v", got, want)
+ }
+}
+
+type rootContext struct {
+ context.Context
+ root *fs.Dirent
+}
+
+// Value implements context.Context.
+func (r *rootContext) Value(key interface{}) interface{} {
+ switch key {
+ case fs.CtxRoot:
+ r.root.IncRef()
+ return r.root
+ default:
+ return r.Context.Value(key)
+ }
+}
diff --git a/pkg/sentry/fs/file_state.go b/pkg/sentry/fs/file_state.go
new file mode 100644
index 000000000..523182d59
--- /dev/null
+++ b/pkg/sentry/fs/file_state.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+// beforeSave is invoked by stateify.
+func (f *File) beforeSave() {
+ f.saving = true
+ if f.flags.Async && f.async != nil {
+ f.async.Unregister(f)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (f *File) afterLoad() {
+ f.mu.Init()
+ if f.flags.Async && f.async != nil {
+ f.async.Register(f)
+ }
+}
diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go
new file mode 100644
index 000000000..d41f30bbb
--- /dev/null
+++ b/pkg/sentry/fs/filesystems.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags.
+type FilesystemFlags int
+
+const (
+ // FilesystemRequiresDev indicates that the file system requires a device name
+ // on mount. It is used to construct the output of /proc/filesystems.
+ FilesystemRequiresDev FilesystemFlags = 1
+
+ // Currently other flags are not used, but can be pulled in from
+ // include/linux/fs.h:file_system_type as needed.
+)
+
+// Filesystem is a mountable file system.
+type Filesystem interface {
+ // Name is the unique identifier of the file system. It corresponds to the
+ // filesystemtype argument of sys_mount and will appear in the output of
+ // /proc/filesystems.
+ Name() string
+
+ // Flags indicate common properties of the file system.
+ Flags() FilesystemFlags
+
+ // Mount generates a mountable Inode backed by device and configured
+ // using file system independent flags and file system dependent
+ // data options.
+ //
+ // Mount may return arbitrary errors. They do not need syserr translations.
+ Mount(ctx context.Context, device string, flags MountSourceFlags, data string, dataObj interface{}) (*Inode, error)
+
+ // AllowUserMount determines whether mount(2) is allowed to mount a
+ // file system of this type.
+ AllowUserMount() bool
+
+ // AllowUserList determines whether this filesystem is listed in
+ // /proc/filesystems.
+ AllowUserList() bool
+}
+
+// filesystems is the global set of registered file systems. It does not need
+// to be saved. Packages registering and unregistering file systems must do so
+// before calling save/restore methods.
+var filesystems = struct {
+ // mu protects registered below.
+ mu sync.Mutex
+
+ // registered is a set of registered Filesystems.
+ registered map[string]Filesystem
+}{
+ registered: make(map[string]Filesystem),
+}
+
+// RegisterFilesystem registers a new file system that is visible to mount and
+// the /proc/filesystems list. Packages implementing Filesystem should call
+// RegisterFilesystem in init().
+func RegisterFilesystem(f Filesystem) {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ if _, ok := filesystems.registered[f.Name()]; ok {
+ panic(fmt.Sprintf("filesystem already registered at %q", f.Name()))
+ }
+ filesystems.registered[f.Name()] = f
+}
+
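For instance, a filesystem package would typically register itself like this (a sketch; myfs, its name, and the stubbed Mount are placeholders):

```go
package myfs // hypothetical

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/fs"
	"gvisor.dev/gvisor/pkg/syserror"
)

// filesystem implements fs.Filesystem.
type filesystem struct{}

func (*filesystem) Name() string              { return "myfs" }
func (*filesystem) Flags() fs.FilesystemFlags { return 0 }
func (*filesystem) AllowUserMount() bool      { return true }
func (*filesystem) AllowUserList() bool       { return true }

// Mount would construct the root Inode; elided in this sketch.
func (*filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, dataObj interface{}) (*fs.Inode, error) {
	return nil, syserror.ENOSYS
}

func init() {
	fs.RegisterFilesystem(&filesystem{})
}
```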
+// FindFilesystem returns a Filesystem registered at name or (nil, false) if name
+// is not a file system type that can be found in /proc/filesystems.
+func FindFilesystem(name string) (Filesystem, bool) {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ f, ok := filesystems.registered[name]
+ return f, ok
+}
+
+// GetFilesystems returns the set of registered filesystems in a consistent order.
+func GetFilesystems() []Filesystem {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ var ss []Filesystem
+ for _, s := range filesystems.registered {
+ ss = append(ss, s)
+ }
+ sort.Slice(ss, func(i, j int) bool { return ss[i].Name() < ss[j].Name() })
+ return ss
+}
+
+// MountSourceFlags represents all mount option flags as a struct.
+//
+// +stateify savable
+type MountSourceFlags struct {
+ // ReadOnly corresponds to mount(2)'s "MS_RDONLY" and indicates that
+ // the filesystem should be mounted read-only.
+ ReadOnly bool
+
+ // NoAtime corresponds to mount(2)'s "MS_NOATIME" and indicates that
+ // the filesystem should not update access time in-place.
+ NoAtime bool
+
+ // ForcePageCache causes all filesystem I/O operations to use the page
+ // cache, even when the platform supports direct mapped I/O. This
+ // doesn't correspond to any Linux mount options.
+ ForcePageCache bool
+
+ // NoExec corresponds to mount(2)'s "MS_NOEXEC" and indicates that
+ // binaries from this file system can't be executed.
+ NoExec bool
+}
+
+// GenericMountSourceOptions splits a string containing comma-separated tokens of the
+// format 'key=value' or 'key' into a map of keys and values. For example:
+//
+// data = "key0=value0,key1,key2=value2" -> map{'key0':'value0','key1':'','key2':'value2'}
+//
+// If data contains duplicate keys, then the last token wins.
+func GenericMountSourceOptions(data string) map[string]string {
+ options := make(map[string]string)
+ if len(data) == 0 {
+ // Don't return a nil map, callers might not be expecting that.
+ return options
+ }
+
+ // Parse options and skip empty ones.
+ for _, opt := range strings.Split(data, ",") {
+ if len(opt) > 0 {
+ res := strings.SplitN(opt, "=", 2)
+ if len(res) == 2 {
+ options[res[0]] = res[1]
+ } else {
+ options[opt] = ""
+ }
+ }
+ }
+ return options
+}
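For example (a fragment; the values are illustrative):

```go
opts := fs.GenericMountSourceOptions("uid=1000,ro,mode=0755")
// opts == map[string]string{"uid": "1000", "ro": "", "mode": "0755"}
```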
diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD
new file mode 100644
index 000000000..a8000e010
--- /dev/null
+++ b/pkg/sentry/fs/filetest/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "filetest",
+ testonly = 1,
+ srcs = ["filetest.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go
new file mode 100644
index 000000000..8049538f2
--- /dev/null
+++ b/pkg/sentry/fs/filetest/filetest.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package filetest provides a test implementation of an fs.File.
+package filetest
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TestFileOperations is an implementation of the fs.FileOperations interface.
+// It provides all required methods.
+type TestFileOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+// NewTestFile creates and initializes a new test file.
+func NewTestFile(tb testing.TB) *fs.File {
+ ctx := contexttest.Context(tb)
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "test")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{}, &TestFileOperations{})
+}
+
+// Read just fails the request.
+func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, fmt.Errorf("Readv not implemented")
+}
+
+// Write just fails the request.
+func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, fmt.Errorf("Writev not implemented")
+}
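A test might use the package roughly as follows (a sketch; it relies only on NewTestFile above and on Read returning an error):

```go
package filetest_test // hypothetical

import (
	"testing"

	"gvisor.dev/gvisor/pkg/sentry/contexttest"
	"gvisor.dev/gvisor/pkg/sentry/fs/filetest"
	"gvisor.dev/gvisor/pkg/usermem"
)

// TestReadFails checks that the stub file rejects reads, since its
// Read implementation always returns an error.
func TestReadFails(t *testing.T) {
	f := filetest.NewTestFile(t)
	defer f.DecRef()

	ctx := contexttest.Context(t)
	var buf [8]byte
	if _, err := f.Readv(ctx, usermem.BytesIOSequence(buf[:])); err == nil {
		t.Fatal("Readv succeeded, want error")
	}
}
```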
diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go
new file mode 100644
index 000000000..4338ae1fa
--- /dev/null
+++ b/pkg/sentry/fs/flags.go
@@ -0,0 +1,138 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// FileFlags encodes file flags.
+//
+// +stateify savable
+type FileFlags struct {
+ // Direct indicates that I/O should be done directly.
+ Direct bool
+
+ // NonBlocking indicates that I/O should not block.
+ NonBlocking bool
+
+ // DSync indicates that each write will flush data and metadata required to
+ // read the file's contents.
+ DSync bool
+
+ // Sync indicates that each write will flush data and all file metadata.
+ Sync bool
+
+ // Append indicates this file is append only.
+ Append bool
+
+ // Read indicates this file is readable.
+ Read bool
+
+ // Write indicates this file is writeable.
+ Write bool
+
+ // Pread indicates this file is readable at an arbitrary offset.
+ Pread bool
+
+ // Pwrite indicates this file is writable at an arbitrary offset.
+ Pwrite bool
+
+ // Directory indicates that this file must be a directory.
+ Directory bool
+
+ // Async indicates that this file sends signals on IO events.
+ Async bool
+
+ // LargeFile indicates that this file should be opened even if it has
+ // size greater than linux's off_t. When running in 64-bit mode,
+ // Linux sets this flag for all files. Since gVisor is only compatible
+ // with 64-bit Linux, it also sets this flag for all files.
+ LargeFile bool
+
+ // NonSeekable indicates that file.offset isn't used.
+ NonSeekable bool
+
+ // Truncate indicates that the file should be truncated before it is opened.
+ // This is only applicable if the file is regular.
+ Truncate bool
+}
+
+// SettableFileFlags is a subset of FileFlags above that can be changed
+// via fcntl(2) using the F_SETFL command.
+type SettableFileFlags struct {
+ // Direct indicates that I/O should be done directly.
+ Direct bool
+
+ // NonBlocking indicates that I/O should not block.
+ NonBlocking bool
+
+ // Append indicates this file is append only.
+ Append bool
+
+ // Async indicates that this file sends signals on IO events.
+ Async bool
+}
+
+// Settable returns the subset of f that are settable.
+func (f FileFlags) Settable() SettableFileFlags {
+ return SettableFileFlags{
+ Direct: f.Direct,
+ NonBlocking: f.NonBlocking,
+ Append: f.Append,
+ Async: f.Async,
+ }
+}
+
+// ToLinux converts a FileFlags object to a Linux representation.
+func (f FileFlags) ToLinux() (mask uint) {
+ if f.Direct {
+ mask |= linux.O_DIRECT
+ }
+ if f.NonBlocking {
+ mask |= linux.O_NONBLOCK
+ }
+ if f.DSync {
+ mask |= linux.O_DSYNC
+ }
+ if f.Sync {
+ mask |= linux.O_SYNC
+ }
+ if f.Append {
+ mask |= linux.O_APPEND
+ }
+ if f.Directory {
+ mask |= linux.O_DIRECTORY
+ }
+ if f.Async {
+ mask |= linux.O_ASYNC
+ }
+ if f.LargeFile {
+ mask |= linux.O_LARGEFILE
+ }
+ if f.Truncate {
+ mask |= linux.O_TRUNC
+ }
+
+ switch {
+ case f.Read && f.Write:
+ mask |= linux.O_RDWR
+ case f.Write:
+ mask |= linux.O_WRONLY
+ case f.Read:
+ mask |= linux.O_RDONLY
+ }
+ return
+}
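For example (a fragment):

```go
flags := fs.FileFlags{Read: true, Write: true, Append: true}
mask := flags.ToLinux()
// mask == linux.O_RDWR|linux.O_APPEND
```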
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
new file mode 100644
index 000000000..d2dbff268
--- /dev/null
+++ b/pkg/sentry/fs/fs.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fs implements a virtual filesystem layer.
+//
+// Specific filesystem implementations must implement the InodeOperations
+// interface (inode.go).
+//
+// The MountNamespace (mounts.go) is used to create a collection of mounts in
+// a filesystem rooted at a given Inode.
+//
+// MountSources (mount.go) form a tree, with each mount holding pointers to its
+// parent and children.
+//
+// Dirents (dirents.go) wrap Inodes in a caching layer.
+//
+// When multiple locks are to be held at the same time, they should be acquired
+// in the following order.
+//
+// Either:
+// File.mu
+// Locks in FileOperations implementations
+// goto Dirent-Locks
+//
+// Or:
+// MountNamespace.mu
+// goto Dirent-Locks
+//
+// Dirent-Locks:
+// renameMu
+// Dirent.dirMu
+// Dirent.mu
+// DirentCache.mu
+// Inode.Watches.mu (see `Inotify` for other lock ordering)
+// MountSource.mu
+// Inode.appendMu
+// Locks in InodeOperations implementations or overlayEntry
+//
+// If multiple Dirent or MountSource locks must be taken, locks in the parent must be
+// taken before locks in their children.
+//
+// If locks must be taken on multiple unrelated Dirents, renameMu must be taken
+// first. See lockForRename.
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var (
+ // workMu is used to synchronize pending asynchronous work. Async work
+ // runs with the lock held for reading. AsyncBarrier will take the lock
+ // for writing, thus ensuring that all Async work completes before
+ // AsyncBarrier returns.
+ workMu sync.RWMutex
+
+ // asyncError is used to store up to one asynchronous execution error.
+ asyncError = make(chan error, 1)
+)
+
+// AsyncBarrier waits for all outstanding asynchronous work to complete.
+func AsyncBarrier() {
+ workMu.Lock()
+ workMu.Unlock()
+}
+
+// Async executes a function asynchronously.
+//
+// Async must not be called recursively.
+func Async(f func()) {
+ workMu.RLock()
+ go func() { // S/R-SAFE: AsyncBarrier must be called.
+ defer workMu.RUnlock() // Ensure RUnlock in case of panic.
+ f()
+ }()
+}
+
+// AsyncWithContext is just like Async, except that it calls the asynchronous
+// function with the given context as argument. This function exists to avoid
+// needing to allocate an extra function on the heap in a hot path.
+func AsyncWithContext(ctx context.Context, f func(context.Context)) {
+ workMu.RLock()
+ go func() { // S/R-SAFE: AsyncBarrier must be called.
+ defer workMu.RUnlock() // Ensure RUnlock in case of panic.
+ f(ctx)
+ }()
+}
+
+// AsyncErrorBarrier waits for all outstanding asynchronous work to complete, or
+// the first async error to arrive. Other unfinished async executions will
+// continue in the background. Other past and future async errors are ignored.
+func AsyncErrorBarrier() error {
+ wait := make(chan struct{}, 1)
+ go func() { // S/R-SAFE: Does not touch persistent state.
+ AsyncBarrier()
+ wait <- struct{}{}
+ }()
+ select {
+ case <-wait:
+ select {
+ case err := <-asyncError:
+ return err
+ default:
+ return nil
+ }
+ case err := <-asyncError:
+ return err
+ }
+}
+
+// CatchError tries to capture the potential async error returned by the
+// function. At most one async error will be captured globally so excessive
+// errors will be dropped.
+func CatchError(f func() error) func() {
+ return func() {
+ if err := f(); err != nil {
+ select {
+ case asyncError <- err:
+ default:
+ log.Warningf("excessive async error dropped: %v", err)
+ }
+ }
+ }
+}
+
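Taken together, the intended pattern is to wrap the fallible body in CatchError, schedule it with Async, and collect any error at a barrier (a sketch; flushSomething is a placeholder):

```go
fs.Async(fs.CatchError(func() error {
	return flushSomething() // hypothetical work that may fail
}))

// Later, e.g. during save:
if err := fs.AsyncErrorBarrier(); err != nil {
	return err
}
```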
+// ErrSaveRejection indicates a failed save due to unsupported file system
+// state, such as a dangling open fd.
+type ErrSaveRejection struct {
+ // Err is the wrapped error.
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported file system state: " + e.Err.Error()
+}
+
+// ErrCorruption indicates a failed restore due to corruption in the external
+// file system state.
+type ErrCorruption struct {
+ // Err is the wrapped error.
+ Err error
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrCorruption) Error() string {
+ return "restore failed due to external file system state in corruption: " + e.Err.Error()
+}
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
new file mode 100644
index 000000000..789369220
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -0,0 +1,118 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "dirty_set_impl",
+ out = "dirty_set_impl.go",
+ imports = {
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "fsutil",
+ prefix = "Dirty",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "memmap.MappableRange",
+ "Value": "DirtyInfo",
+ "Functions": "dirtySetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "frame_ref_set_impl",
+ out = "frame_ref_set_impl.go",
+ imports = {
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "fsutil",
+ prefix = "FrameRef",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "platform.FileRange",
+ "Value": "uint64",
+ "Functions": "FrameRefSetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "file_range_set_impl",
+ out = "file_range_set_impl.go",
+ imports = {
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "fsutil",
+ prefix = "FileRange",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "memmap.MappableRange",
+ "Value": "uint64",
+ "Functions": "FileRangeSetFunctions",
+ },
+)
+
+go_library(
+ name = "fsutil",
+ srcs = [
+ "dirty_set.go",
+ "dirty_set_impl.go",
+ "file.go",
+ "file_range_set.go",
+ "file_range_set_impl.go",
+ "frame_ref_set.go",
+ "frame_ref_set_impl.go",
+ "fsutil.go",
+ "host_file_mapper.go",
+ "host_file_mapper_state.go",
+ "host_file_mapper_unsafe.go",
+ "host_mappable.go",
+ "inode.go",
+ "inode_cached.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/state",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "fsutil_test",
+ size = "small",
+ srcs = [
+ "dirty_set_test.go",
+ "inode_cached_test.go",
+ ],
+ library = ":fsutil",
+ deps = [
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/fsutil/README.md b/pkg/sentry/fs/fsutil/README.md
new file mode 100644
index 000000000..8be367334
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/README.md
@@ -0,0 +1,207 @@
+This package provides utilities for implementing virtual filesystem objects.
+
+[TOC]
+
+## Page cache
+
+`CachingInodeOperations` implements a page cache for files that cannot use the
+host page cache. Normally these are files that store their data in a remote
+filesystem. This also applies to files that are accessed on a platform that does
+not support directly memory mapping host file descriptors (e.g. the ptrace
+platform).
+
+A `CachingInodeOperations` buffers regions of a single file into memory. It is
+owned by an `fs.Inode`, the in-memory representation of a file (all open file
+descriptors are backed by an `fs.Inode`). The `fs.Inode` provides operations for
+reading memory into a `CachingInodeOperations`, to represent the contents of
+the file in-memory, and for writing memory out, to relieve memory pressure on
+the kernel and to synchronize in-memory changes to filesystems.
+
+A `CachingInodeOperations` enables readable and/or writable memory access to
+file content. Files can be mapped shared or private, see mmap(2). When a file is
+mapped shared, changes to the file via write(2) and truncate(2) are reflected in
+the shared memory region. Conversely, when the shared memory region is modified,
+changes to the file are visible via read(2). Multiple shared mappings of the
+same file are coherent with each other. This is consistent with Linux.
+
+When a file is mapped private, updates to the mapped memory are not visible to
+other memory mappings. Updates to the mapped memory are also not reflected in
+the file content as seen by read(2). If the file is changed after a private
+mapping is created, for instance by write(2), the change to the file may or may
+not be reflected in the private mapping. This is consistent with Linux.
+
+A `CachingInodeOperations` keeps track of ranges of memory that were modified
+(or "dirtied"). When the file is explicitly synced via fsync(2), only the dirty
+ranges are written out to the filesystem. Any error returned indicates a failure
+to write all dirty memory of a `CachingInodeOperations` to the filesystem. In
+this case the filesystem may be in an inconsistent state. The same operation can
+be performed on the shared memory itself using msync(2). If neither fsync(2) nor
+msync(2) is performed, then the dirty memory is written out in accordance with
+the `CachingInodeOperations` eviction strategy (see below) and there is no
+guarantee that memory will be written out successfully in full.
+
+### Memory allocation and eviction
+
+A `CachingInodeOperations` implements the following allocation and eviction
+strategy:
+
+- Memory is allocated and brought up to date with the contents of a file when
+ a region of mapped memory is accessed (or "faulted on").
+
+- Dirty memory is written out to filesystems when an fsync(2) or msync(2)
+ operation is performed on a memory mapped file, for all memory mapped files
+ when saved, and/or when there are no longer any memory mappings of a range
+ of a file, see munmap(2). As the latter implies, in the absence of a panic
+ or SIGKILL, dirty memory is written out for all memory mapped files when an
+ application exits.
+
+- Memory is freed when there are no longer any memory mappings of a range of a
+ file (e.g. when an application exits). This behavior is consistent with
+ Linux for shared memory that has been locked via mlock(2).
+
+Notably, memory is not allocated for read(2) or write(2) operations. This means
+that reads and writes to the file are only accelerated by a
+`CachingInodeOperations` if the file being read or written has been memory
+mapped *and* if the shared memory has been accessed at the region being read or
+written. This diverges from Linux which buffers memory into a page cache on
+read(2) proactively (i.e. readahead) and delays writing it out to filesystems on
+write(2) (i.e. writeback). The absence of these optimizations is not visible to
+applications beyond less-than-optimal performance when repeatedly reading and/or
+writing to the same region of a file. See [Future Work](#future-work) for plans to
+implement these optimizations.
+
+Additionally, memory held by `CachingInodeOperations` objects is currently
+unbounded in size. A `CachingInodeOperations` does not write out dirty memory
+and free it under system memory pressure. This can cause pathological memory
+usage.
+
+When memory is written back, a `CachingInodeOperations` may write regions of
+shared memory that were never modified. This is due to the strategy of
+minimizing page faults (see below) and handling only a subset of memory write
+faults. In the absence of an application or sentry crash, it is guaranteed that
+if a region of shared memory was written to, it is written back to a filesystem.
+
+### Life of a shared memory mapping
+
+A file is memory mapped via mmap(2). For example, if `A` is an address, an
+application may execute:
+
+```
+mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0);
+```
+
+This creates a shared mapping of fd that reflects 4k of the contents of fd
+starting at offset 0, accessible at address `A`. This in turn creates a virtual
+memory area ("vma") which indicates that [`A`, `A`+0x1000) is now a valid
+address range for this application to access.
+
+At this point, memory has not been allocated in the file's
+`CachingInodeOperations`. It is also the case that the address range [`A`,
+`A`+0x1000) has not been mapped on the host on behalf of the application. If the
+application then tries to modify 8 bytes of the shared memory:
+
+```
+char buffer[] = "aaaaaaaa";
+memcpy(A, buffer, 8);
+```
+
+The host then sends a `SIGSEGV` to the sentry because the address range [`A`,
+`A`+8) is not mapped on the host. The `SIGSEGV` indicates that the memory was
+accessed writable. The sentry looks up the vma associated with [`A`, `A`+8),
+finds the file that was mapped and its `CachingInodeOperations`. It then calls
+`CachingInodeOperations.Translate` which allocates memory to back [`A`, `A`+8).
+It may choose to allocate more memory (i.e. do "readahead") to minimize
+subsequent faults.
+
+Memory that is allocated comes from a host tmpfs file (see
+`pgalloc.MemoryFile`). The host tmpfs file memory is brought up to date with the
+contents of the mapped file on its filesystem. The region of the host tmpfs file
+that reflects the mapped file is then mapped into the host address space of the
+application so that subsequent memory accesses do not repeatedly generate a
+`SIGSEGV`.
+
+The range that was allocated, including any extra memory allocation to minimize
+faults, is marked dirty due to the write fault. This overcounts dirty memory if
+the extra memory allocated is never modified.
+
+To make the scenario more interesting, imagine that this application spawns
+another process and maps the same file in the exact same way:
+
+```
+mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0);
+```
+
+Imagine that this process then tries to modify the file again but with only 4
+bytes:
+
+```
+char buffer[] = "bbbb";
+memcpy(A, buffer, 4);
+```
+
+Since the first process has already mapped and accessed the same region of the
+file writable, `CachingInodeOperations.Translate` is called but returns the
+memory that has already been allocated rather than allocating new memory. The
+address range [`A`, `A`+0x1000) reflects the same cached view of the file as the
+first process sees. For example, reading 8 bytes from the file from either
+process via read(2) starting at offset 0 returns a consistent "bbbbaaaa".
+
+When this process no longer needs the shared memory, it may do:
+
+```
+munmap(A, 0x1000);
+```
+
+At this point, the modified memory cached by the `CachingInodeOperations` is not
+written back to the file because it is still in use by the first process that
+mapped it. When the first process also does:
+
+```
+munmap(A, 0x1000);
+```
+
+Then the last memory mapping of the file at the range [0, 0x1000) is gone. The
+file's `CachingInodeOperations` then starts writing back memory marked dirty to
+the file on its filesystem. Once writing completes, regardless of whether it was
+successful, the `CachingInodeOperations` frees the memory cached at the range
+[0, 0x1000).
+
+Subsequent read(2) or write(2) operations on the file go directly to the
+filesystem, since its `CachingInodeOperations` no longer holds any memory for
+that range of the file.
+
+## Future Work
+
+### Page cache
+
+The sentry does not yet implement the readahead and writeback optimizations for
+read(2) and write(2) respectively. To do so, on read(2) and/or write(2) the
+sentry must ensure that memory is allocated in a page cache to read or write
+into. However, the sentry cannot boundlessly allocate memory. If it did, the
+host would eventually OOM-kill the sentry+application process. This means that
+the sentry must implement a page cache memory allocation strategy that is
+bounded by a global user or container imposed limit. When this limit is
+approached, the sentry must decide from which page cache memory should be freed
+so that it can allocate more memory. If it makes a poor decision, the sentry may
+end up freeing and re-allocating memory to back regions of files that are
+frequently used, nullifying the optimization (and in some cases causing worse
+performance due to the overhead of memory allocation and general management).
+This is a form of "cache thrashing".
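+
+To make the shape of such a strategy concrete, below is a minimal sketch of a
+byte-bounded cache that evicts its least recently used range when the limit is
+reached. It is illustrative only: the names and the LRU policy are
+hypothetical, not the sentry's eventual design.
+
+```
+package main
+
+import (
+	"container/list"
+	"fmt"
+)
+
+// boundedCache is a hypothetical page cache bounded by limit bytes.
+type boundedCache struct {
+	limit, used uint64
+	lru         *list.List               // front = most recently used
+	entries     map[uint64]*list.Element // keyed by file offset
+}
+
+type entry struct {
+	offset, length uint64
+}
+
+func newBoundedCache(limit uint64) *boundedCache {
+	return &boundedCache{
+		limit:   limit,
+		lru:     list.New(),
+		entries: make(map[uint64]*list.Element),
+	}
+}
+
+// allocate caches [offset, offset+length), evicting old ranges as needed.
+func (c *boundedCache) allocate(offset, length uint64) {
+	for c.used+length > c.limit && c.lru.Len() > 0 {
+		// Evict the least recently used range. Evicting a hot range here
+		// is exactly the "cache thrashing" described above.
+		victim := c.lru.Remove(c.lru.Back()).(entry)
+		delete(c.entries, victim.offset)
+		c.used -= victim.length
+	}
+	c.entries[offset] = c.lru.PushFront(entry{offset, length})
+	c.used += length
+}
+
+// touch records an access so the range is considered recently used.
+func (c *boundedCache) touch(offset uint64) {
+	if e, ok := c.entries[offset]; ok {
+		c.lru.MoveToFront(e)
+	}
+}
+
+func main() {
+	c := newBoundedCache(2 << 20) // 2MB limit
+	c.allocate(0, 1<<20)
+	c.allocate(1<<20, 1<<20)
+	c.touch(0)               // keep the first range hot
+	c.allocate(2<<20, 1<<20) // evicts the cold middle range
+	fmt.Printf("cached bytes: %d\n", c.used)
+}
+```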
+
+In Linux, much research has been done to select and implement a lightweight but
+optimal page cache eviction algorithm. Linux makes use of hardware page bits to
+keep track of whether memory has been accessed. The sentry does not have direct
+access to hardware. Implementing a similarly lightweight and effective page
+cache eviction algorithm in the sentry will require either introducing a
+kernel interface to obtain these page bits or finding a suitable proxy for
+access events.
+
+In Linux, readahead happens by default but is not always ideal. For instance,
+for files that are not read sequentially, it is better to read only the
+requested regions of the file than to optimistically cache some number of
+bytes ahead of the read (up to 2MB in Linux) that will never be accessed.
+Linux implements the fadvise64(2) system call for applications to specify
+that a range of a file will not be accessed sequentially. The advice value
+POSIX_FADV_RANDOM turns off the readahead optimization for the given range
+in the given file. However, fadvise64(2) is rarely used by applications, so
+Linux implements
+a readahead backoff strategy if reads are not sequential. To ensure that
+application performance is not degraded, the sentry must implement a similar
+backoff strategy.
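+
+A minimal sketch of such a backoff follows. It is illustrative only: the
+window sizes and doubling policy are assumptions, not the sentry's eventual
+design.
+
+```
+package main
+
+import "fmt"
+
+// readahead tracks per-file readahead state: sequential reads grow the
+// window up to a cap, while a random read collapses it so that cached
+// bytes are not wasted.
+type readahead struct {
+	nextOffset uint64 // offset one past the last read
+	window     uint64 // bytes to read ahead on the next sequential read
+}
+
+const (
+	minWindow = uint64(4096)            // one page
+	maxWindow = uint64(2 * 1024 * 1024) // mirrors Linux's 2MB ceiling
+)
+
+// advise returns how many bytes to read ahead for a read at offset.
+func (r *readahead) advise(offset, length uint64) uint64 {
+	if offset == r.nextOffset {
+		// Sequential read: grow the window.
+		switch {
+		case r.window == 0:
+			r.window = minWindow
+		case r.window < maxWindow:
+			r.window *= 2
+		}
+	} else {
+		// Random read: back off entirely.
+		r.window = 0
+	}
+	r.nextOffset = offset + length
+	return r.window
+}
+
+func main() {
+	var r readahead
+	fmt.Println(r.advise(0, 4096))     // 4096: sequential, start small
+	fmt.Println(r.advise(4096, 4096))  // 8192: still sequential, double
+	fmt.Println(r.advise(1<<20, 4096)) // 0: random access, back off
+}
+```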
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
new file mode 100644
index 000000000..c6cd45087
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -0,0 +1,237 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// DirtySet maps offsets into a memmap.Mappable to DirtyInfo. It is used to
+// implement Mappables that cache data from another source.
+//
+// type DirtySet <generated by go_generics>
+
+// DirtyInfo is the value type of DirtySet, and represents information about a
+// Mappable offset that is dirty (the cached data for that offset is newer than
+// its source).
+//
+// +stateify savable
+type DirtyInfo struct {
+ // Keep is true if the represented offset is concurrently writable, such
+ // that writing the data for that offset back to the source does not
+ // guarantee that the offset is clean (since it may be concurrently
+ // rewritten after the writeback).
+ Keep bool
+}
+
+// dirtySetFunctions implements segment.Functions for DirtySet.
+type dirtySetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (dirtySetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (dirtySetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (dirtySetFunctions) ClearValue(val *DirtyInfo) {
+}
+
+// Merge implements segment.Functions.Merge.
+func (dirtySetFunctions) Merge(_ memmap.MappableRange, val1 DirtyInfo, _ memmap.MappableRange, val2 DirtyInfo) (DirtyInfo, bool) {
+ if val1 != val2 {
+ return DirtyInfo{}, false
+ }
+ return val1, true
+}
+
+// Split implements segment.Functions.Split.
+func (dirtySetFunctions) Split(_ memmap.MappableRange, val DirtyInfo, _ uint64) (DirtyInfo, DirtyInfo) {
+ return val, val
+}
+
+// MarkClean marks all offsets in mr as not dirty, except for those to which
+// KeepDirty has been applied.
+func (ds *DirtySet) MarkClean(mr memmap.MappableRange) {
+ seg := ds.LowerBoundSegment(mr.Start)
+ for seg.Ok() && seg.Start() < mr.End {
+ if seg.Value().Keep {
+ seg = seg.NextSegment()
+ continue
+ }
+ seg = ds.Isolate(seg, mr)
+ seg = ds.Remove(seg).NextSegment()
+ }
+}
+
+// KeepClean marks all offsets in mr as not dirty, even those that were
+// previously kept dirty by KeepDirty.
+func (ds *DirtySet) KeepClean(mr memmap.MappableRange) {
+ ds.RemoveRange(mr)
+}
+
+// MarkDirty marks all offsets in mr as dirty.
+func (ds *DirtySet) MarkDirty(mr memmap.MappableRange) {
+ ds.setDirty(mr, false)
+}
+
+// KeepDirty marks all offsets in mr as dirty and prevents them from being
+// marked as clean by MarkClean.
+func (ds *DirtySet) KeepDirty(mr memmap.MappableRange) {
+ ds.setDirty(mr, true)
+}
+
+func (ds *DirtySet) setDirty(mr memmap.MappableRange, keep bool) {
+ var changedAny bool
+ defer func() {
+ if changedAny {
+ // Merge segments split by Isolate to reduce cost of iteration.
+ ds.MergeRange(mr)
+ }
+ }()
+ seg, gap := ds.Find(mr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < mr.End:
+ if keep && !seg.Value().Keep {
+ changedAny = true
+ seg = ds.Isolate(seg, mr)
+ seg.ValuePtr().Keep = true
+ }
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok() && gap.Start() < mr.End:
+ changedAny = true
+ seg = ds.Insert(gap, gap.Range().Intersect(mr), DirtyInfo{keep})
+ seg, gap = seg.NextNonEmpty()
+
+ default:
+ return
+ }
+ }
+}
+
+// AllowClean allows MarkClean to mark offsets in mr as not dirty, ending the
+// effect of a previous call to KeepDirty. (It does not itself mark those
+// offsets as not dirty.)
+func (ds *DirtySet) AllowClean(mr memmap.MappableRange) {
+ var changedAny bool
+ defer func() {
+ if changedAny {
+ // Merge segments split by Isolate to reduce cost of iteration.
+ ds.MergeRange(mr)
+ }
+ }()
+ for seg := ds.LowerBoundSegment(mr.Start); seg.Ok() && seg.Start() < mr.End; seg = seg.NextSegment() {
+ if seg.Value().Keep {
+ changedAny = true
+ seg = ds.Isolate(seg, mr)
+ seg.ValuePtr().Keep = false
+ }
+ }
+}
+
+// SyncDirty passes pages in the range mr that are stored in cache and
+// identified as dirty to writeAt, updating dirty to reflect successful writes.
+// If writeAt returns a successful partial write, SyncDirty will call it
+// repeatedly until all bytes have been written. max is the true size of the
+// cached object; offsets beyond max will not be passed to writeAt, even if
+// they are marked dirty.
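+//
+// For example (an illustrative sketch; the cache, dirty set, and backing
+// file here are hypothetical fields of a caching inode implementation):
+//
+//	err := SyncDirty(ctx, mr, &c.cache, &c.dirty, size, mf,
+//		c.backingFile.WriteFromBlocksAt)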
+func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ var changedDirty bool
+ defer func() {
+ if changedDirty {
+ // Merge segments split by Isolate to reduce cost of iteration.
+ dirty.MergeRange(mr)
+ }
+ }()
+ dseg := dirty.LowerBoundSegment(mr.Start)
+ for dseg.Ok() && dseg.Start() < mr.End {
+ var dr memmap.MappableRange
+ if dseg.Value().Keep {
+ dr = dseg.Range().Intersect(mr)
+ } else {
+ changedDirty = true
+ dseg = dirty.Isolate(dseg, mr)
+ dr = dseg.Range()
+ }
+ if err := syncDirtyRange(ctx, dr, cache, max, mem, writeAt); err != nil {
+ return err
+ }
+ if dseg.Value().Keep {
+ dseg = dseg.NextSegment()
+ } else {
+ dseg = dirty.Remove(dseg).NextSegment()
+ }
+ }
+ return nil
+}
+
+// SyncDirtyAll passes all pages stored in cache identified as dirty to
+// writeAt, updating dirty to reflect successful writes. If writeAt returns a
+// successful partial write, SyncDirtyAll will call it repeatedly until all
+// bytes have been written. max is the true size of the cached object; offsets
+// beyond max will not be passed to writeAt, even if they are marked dirty.
+func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ dseg := dirty.FirstSegment()
+ for dseg.Ok() {
+ if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil {
+ return err
+ }
+ if dseg.Value().Keep {
+ dseg = dseg.NextSegment()
+ } else {
+ dseg = dirty.Remove(dseg).NextSegment()
+ }
+ }
+ return nil
+}
+
+// Preconditions: mr must be page-aligned.
+func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() {
+ wbr := cseg.Range().Intersect(mr)
+ if max < wbr.Start {
+ break
+ }
+ ims, err := mem.MapInternal(cseg.FileRangeOf(wbr), usermem.Read)
+ if err != nil {
+ return err
+ }
+ if max < wbr.End {
+ ims = ims.TakeFirst64(max - wbr.Start)
+ }
+ offset := wbr.Start
+ for !ims.IsEmpty() {
+ n, err := writeAt(ctx, ims, offset)
+ if err != nil {
+ return err
+ }
+ offset += n
+ ims = ims.DropFirst64(n)
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go
new file mode 100644
index 000000000..e3579c23c
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/dirty_set_test.go
@@ -0,0 +1,38 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestDirtySet(t *testing.T) {
+ var set DirtySet
+ set.MarkDirty(memmap.MappableRange{0, 2 * usermem.PageSize})
+ set.KeepDirty(memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize})
+ set.MarkClean(memmap.MappableRange{0, 2 * usermem.PageSize})
+ want := &DirtySegmentDataSlices{
+ Start: []uint64{usermem.PageSize},
+ End: []uint64{2 * usermem.PageSize},
+ Values: []DirtyInfo{{Keep: true}},
+ }
+ if got := set.ExportSortedSlices(); !reflect.DeepEqual(got, want) {
+ t.Errorf("set:\n\tgot %v,\n\twant %v", got, want)
+ }
+}
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
new file mode 100644
index 000000000..08695391c
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -0,0 +1,396 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// FileNoopRelease implements fs.FileOperations.Release for files that have no
+// resources to release.
+type FileNoopRelease struct{}
+
+// Release is a no-op.
+func (FileNoopRelease) Release() {}
+
+// SeekWithDirCursor is used to implement fs.FileOperations.Seek. If dirCursor
+// is not nil and the seek was on a directory, the cursor will be updated.
+//
+// Currently only seeking to 0 on a directory is supported.
+//
+// FIXME(b/33075855): Lift directory seeking limitations.
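+//
+// A FileOperations implementation that maintains a directory cursor would
+// typically delegate to it (an illustrative sketch; dirFileOperations is
+// hypothetical):
+//
+//	func (f *dirFileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+//		return SeekWithDirCursor(ctx, file, whence, offset, &f.dirCursor)
+//	}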
+func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64, dirCursor *string) (int64, error) {
+ inode := file.Dirent.Inode
+ current := file.Offset()
+
+ // Does the Inode represent a non-seekable type?
+ if fs.IsPipe(inode.StableAttr) || fs.IsSocket(inode.StableAttr) {
+ return current, syserror.ESPIPE
+ }
+
+ // Does the Inode represent a character device?
+ if fs.IsCharDevice(inode.StableAttr) {
+ // Ignore seek requests.
+ //
+ // FIXME(b/34716638): This preserves existing
+ // behavior but is not universally correct.
+ return 0, nil
+ }
+
+ // Otherwise compute the new offset.
+ switch whence {
+ case fs.SeekSet:
+ switch inode.StableAttr.Type {
+ case fs.RegularFile, fs.SpecialFile, fs.BlockDevice:
+ if offset < 0 {
+ return current, syserror.EINVAL
+ }
+ return offset, nil
+ case fs.Directory, fs.SpecialDirectory:
+ if offset != 0 {
+ return current, syserror.EINVAL
+ }
+ // SEEK_SET to 0 moves the directory "cursor" to the beginning.
+ if dirCursor != nil {
+ *dirCursor = ""
+ }
+ return 0, nil
+ default:
+ return current, syserror.EINVAL
+ }
+ case fs.SeekCurrent:
+ switch inode.StableAttr.Type {
+ case fs.RegularFile, fs.SpecialFile, fs.BlockDevice:
+ if current+offset < 0 {
+ return current, syserror.EINVAL
+ }
+ return current + offset, nil
+ case fs.Directory, fs.SpecialDirectory:
+ if offset != 0 {
+ return current, syserror.EINVAL
+ }
+ return current, nil
+ default:
+ return current, syserror.EINVAL
+ }
+ case fs.SeekEnd:
+ switch inode.StableAttr.Type {
+ case fs.RegularFile, fs.BlockDevice:
+ // Allow the file to determine the end.
+ uattr, err := inode.UnstableAttr(ctx)
+ if err != nil {
+ return current, err
+ }
+ sz := uattr.Size
+ if sz+offset < 0 {
+ return current, syserror.EINVAL
+ }
+ return sz + offset, nil
+ // FIXME(b/34778850): This is not universally correct.
+ // Remove SpecialDirectory.
+ case fs.SpecialDirectory:
+ if offset != 0 {
+ return current, syserror.EINVAL
+ }
+ // SEEK_END to 0 moves the directory "cursor" to the end.
+ //
+ // FIXME(b/35442290): This ensures that after the seek,
+ // reading on the directory will get EOF. But it is not
+ // correct in general because the directory can grow in
+ // size; attempting to read those new entries will be
+ // futile (EOF will always be the result).
+ return fs.FileMaxOffset, nil
+ default:
+ return current, syserror.EINVAL
+ }
+ }
+
+ // Not a valid seek request.
+ return current, syserror.EINVAL
+}
+
+// FileGenericSeek implements fs.FileOperations.Seek for files that use a
+// generic seek implementation.
+type FileGenericSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FileGenericSeek) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return SeekWithDirCursor(ctx, file, whence, offset, nil)
+}
+
+// FileZeroSeek implements fs.FileOperations.Seek for files that maintain a
+// constant zero-value offset and require a no-op Seek.
+type FileZeroSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FileZeroSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
+ return 0, nil
+}
+
+// FileNoSeek implements fs.FileOperations.Seek to return EINVAL.
+type FileNoSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FileNoSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// FilePipeSeek implements fs.FileOperations.Seek and can be used for files
+// that behave like pipes (seeking is not supported).
+type FilePipeSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FilePipeSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// FileNotDirReaddir implements fs.FileOperations.Readdir for non-directories.
+type FileNotDirReaddir struct{}
+
+// Readdir implements fs.FileOperations.FileNotDirReaddir.
+func (FileNotDirReaddir) Readdir(context.Context, *fs.File, fs.DentrySerializer) (int64, error) {
+ return 0, syserror.ENOTDIR
+}
+
+// FileNoFsync implements fs.FileOperations.Fsync for files that don't support
+// syncing.
+type FileNoFsync struct{}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (FileNoFsync) Fsync(context.Context, *fs.File, int64, int64, fs.SyncType) error {
+ return syserror.EINVAL
+}
+
+// FileNoopFsync implements fs.FileOperations.Fsync for files that don't need
+// to be synced.
+type FileNoopFsync struct{}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (FileNoopFsync) Fsync(context.Context, *fs.File, int64, int64, fs.SyncType) error {
+ return nil
+}
+
+// FileNoopFlush implements fs.FileOperations.Flush as a no-op.
+type FileNoopFlush struct{}
+
+// Flush implements fs.FileOperations.Flush.
+func (FileNoopFlush) Flush(context.Context, *fs.File) error {
+ return nil
+}
+
+// FileNoMMap implements fs.FileOperations.Mappable for files that cannot
+// be memory mapped.
+type FileNoMMap struct{}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (FileNoMMap) ConfigureMMap(context.Context, *fs.File, *memmap.MMapOpts) error {
+ return syserror.ENODEV
+}
+
+// GenericConfigureMMap implements fs.FileOperations.ConfigureMMap for most
+// filesystems that support memory mapping.
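+//
+// A FileOperations implementation backed by a memmap.Mappable would
+// typically delegate to it (an illustrative sketch; fileOperations and its
+// mappable field are hypothetical):
+//
+//	func (f *fileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+//		return GenericConfigureMMap(file, f.mappable, opts)
+//	}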
+func GenericConfigureMMap(file *fs.File, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ opts.Mappable = m
+ opts.MappingIdentity = file
+ file.IncRef()
+ return nil
+}
+
+// FileNoIoctl implements fs.FileOperations.Ioctl for files that don't
+// implement the ioctl syscall.
+type FileNoIoctl struct{}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
+ return 0, syserror.ENOTTY
+}
+
+// FileNoSplice implements fs.FileOperations.ReadFrom and
+// fs.FileOperations.WriteTo for files that don't support splice.
+type FileNoSplice struct{}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// DirFileOperations implements most of fs.FileOperations for directories,
+// except for Readdir and UnstableAttr, which the embedding type must implement.
+type DirFileOperations struct {
+ waiter.AlwaysReady
+ FileGenericSeek
+ FileNoIoctl
+ FileNoMMap
+ FileNoopFlush
+ FileNoopFsync
+ FileNoopRelease
+ FileNoSplice
+}
+
+// Read implements fs.FileOperations.Read.
+func (*DirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Write implements fs.FileOperations.Write.
+func (*DirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// StaticDirFileOperations implements fs.FileOperations for directories with
+// static children.
+//
+// +stateify savable
+type StaticDirFileOperations struct {
+ DirFileOperations `state:"nosave"`
+ FileUseInodeUnstableAttr `state:"nosave"`
+
+ // dentryMap is a SortedDentryMap used to implement Readdir.
+ dentryMap *fs.SortedDentryMap
+
+ // dirCursor contains the name of the last directory entry that was
+ // serialized.
+ dirCursor string
+}
+
+// NewStaticDirFileOperations returns a new StaticDirFileOperations that will
+// iterate the given dentry map.
+func NewStaticDirFileOperations(dentries *fs.SortedDentryMap) *StaticDirFileOperations {
+ return &StaticDirFileOperations{
+ dentryMap: dentries,
+ }
+}
+
+// IterateDir implements DirIterator.IterateDir.
+func (sdfo *StaticDirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) {
+ n, err := fs.GenericReaddir(dirCtx, sdfo.dentryMap)
+ return offset + n, err
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (sdfo *StaticDirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &sdfo.dirCursor,
+ }
+ return fs.DirentReaddir(ctx, file.Dirent, sdfo, root, dirCtx, file.Offset())
+}
+
+// NoReadWriteFile is a file that does not support reading or writing.
+//
+// +stateify savable
+type NoReadWriteFile struct {
+ waiter.AlwaysReady `state:"nosave"`
+ FileGenericSeek `state:"nosave"`
+ FileNoIoctl `state:"nosave"`
+ FileNoMMap `state:"nosave"`
+ FileNoopFsync `state:"nosave"`
+ FileNoopFlush `state:"nosave"`
+ FileNoopRelease `state:"nosave"`
+ FileNoRead `state:"nosave"`
+ FileNoWrite `state:"nosave"`
+ FileNotDirReaddir `state:"nosave"`
+ FileUseInodeUnstableAttr `state:"nosave"`
+ FileNoSplice `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*NoReadWriteFile)(nil)
+
+// FileStaticContentReader is a helper to implement fs.FileOperations.Read with
+// static content.
+//
+// +stateify savable
+type FileStaticContentReader struct {
+ // content is immutable.
+ content []byte
+}
+
+// NewFileStaticContentReader initializes a FileStaticContentReader with the
+// given content.
+func NewFileStaticContentReader(b []byte) FileStaticContentReader {
+ return FileStaticContentReader{
+ content: b,
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (scr *FileStaticContentReader) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset >= int64(len(scr.content)) {
+ return 0, nil
+ }
+ n, err := dst.CopyOut(ctx, scr.content[offset:])
+ return int64(n), err
+}
+
+// FileNoopWrite implements fs.FileOperations.Write as a noop.
+type FileNoopWrite struct{}
+
+// Write implements fs.FileOperations.Write.
+func (FileNoopWrite) Write(_ context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// FileNoRead implements fs.FileOperations.Read to return EINVAL.
+type FileNoRead struct{}
+
+// Read implements fs.FileOperations.Read.
+func (FileNoRead) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// FileNoWrite implements fs.FileOperations.Write to return EINVAL.
+type FileNoWrite struct{}
+
+// Write implements fs.FileOperations.Write.
+func (FileNoWrite) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// FileNoopRead implements fs.FileOperations.Read as a noop.
+type FileNoopRead struct{}
+
+// Read implements fs.FileOperations.Read.
+func (FileNoopRead) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, nil
+}
+
+// FileUseInodeUnstableAttr implements fs.FileOperations.UnstableAttr by calling
+// InodeOperations.UnstableAttr.
+type FileUseInodeUnstableAttr struct{}
+
+// UnstableAttr implements fs.FileOperations.UnstableAttr.
+func (FileUseInodeUnstableAttr) UnstableAttr(ctx context.Context, file *fs.File) (fs.UnstableAttr, error) {
+ return file.Dirent.Inode.UnstableAttr(ctx)
+}
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
new file mode 100644
index 000000000..5643cdac9
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -0,0 +1,209 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
+// platform.File. It is used to implement Mappables that store data in
+// sparsely-allocated memory.
+//
+// type FileRangeSet <generated by go_generics>
+
+// FileRangeSetFunctions implements segment.Functions for FileRangeSet.
+type FileRangeSetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (FileRangeSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (FileRangeSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (FileRangeSetFunctions) ClearValue(_ *uint64) {
+}
+
+// Merge implements segment.Functions.Merge.
+func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
+ if frstart1+mr1.Length() != frstart2 {
+ return 0, false
+ }
+ return frstart1, true
+}
+
+// Split implements segment.Functions.Split.
+func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
+ return frstart, frstart + (split - mr.Start)
+}
+
+// FileRange returns the FileRange mapped by seg.
+func (seg FileRangeIterator) FileRange() platform.FileRange {
+ return seg.FileRangeOf(seg.Range())
+}
+
+// FileRangeOf returns the FileRange mapped by mr.
+//
+// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
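+//
+// For example (illustrative): if seg spans MappableRange [0x0, 0x3000) with
+// value (starting file offset) 0x10000, then FileRangeOf([0x1000, 0x2000))
+// returns FileRange [0x11000, 0x12000).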
+func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
+ frstart := seg.Value() + (mr.Start - seg.Start())
+ return platform.FileRange{frstart, frstart + mr.Length()}
+}
+
+// Fill attempts to ensure that all memmap.Mappable offsets in required are
+// mapped to a platform.File offset, by allocating from mf with the given
+// memory usage kind and invoking readAt to store data into memory. (If readAt
+// returns a successful partial read, Fill will call it repeatedly until all
+// bytes have been read.) EOF is handled consistently with the requirements of
+// mmap(2): bytes after EOF on the same page are zeroed; pages after EOF are
+// invalid.
+//
+// Fill may read offsets outside of required, but will never read offsets
+// outside of optional. It returns a non-nil error if any error occurs, even
+// if the error only affects offsets in optional, but not in required.
+//
+// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
+// required and optional must be page-aligned.
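+//
+// A caching implementation would typically call Fill from Translate using
+// its backing file's read method (an illustrative sketch; the receiver c and
+// its fields are hypothetical):
+//
+//	err := c.cache.Fill(ctx, required, optional, mf, usage.PageCache,
+//		c.backingFile.ReadToBlocksAt)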
+func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ gap := frs.LowerBoundGap(required.Start)
+ for gap.Ok() && gap.Start() < required.End {
+ if gap.Range().Length() == 0 {
+ gap = gap.NextGap()
+ continue
+ }
+ gr := gap.Range().Intersect(optional)
+
+ // Read data into the gap.
+ fr, err := mf.AllocateAndFill(gr.Length(), kind, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() {
+ n, err := readAt(ctx, dsts, gr.Start+done)
+ done += n
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ if err == io.EOF {
+ // MemoryFile.AllocateAndFill truncates down to a page
+ // boundary, but FileRangeSet.Fill is supposed to
+ // zero-fill to the end of the page in this case.
+ donepgaddr, ok := usermem.Addr(done).RoundUp()
+ if donepg := uint64(donepgaddr); ok && donepg != done {
+ dsts = dsts.DropFirst64(donepg - done)
+ done = donepg
+ if dsts.IsEmpty() {
+ return done, nil
+ }
+ }
+ }
+ return done, err
+ }
+ }
+ return done, nil
+ }))
+
+ // Store anything we managed to read into the cache.
+ if done := fr.Length(); done != 0 {
+ gr.End = gr.Start + done
+ gap = frs.Insert(gap, gr, fr.Start).NextGap()
+ }
+
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Drop removes segments for memmap.Mappable offsets in mr, freeing the
+// corresponding platform.FileRanges.
+//
+// Preconditions: mr must be page-aligned.
+func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
+ seg := frs.LowerBoundSegment(mr.Start)
+ for seg.Ok() && seg.Start() < mr.End {
+ seg = frs.Isolate(seg, mr)
+ mf.DecRef(seg.FileRange())
+ seg = frs.Remove(seg).NextSegment()
+ }
+}
+
+// DropAll removes all segments in mr, freeing the corresponding
+// platform.FileRanges.
+func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
+ for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ mf.DecRef(seg.FileRange())
+ }
+ frs.RemoveAll()
+}
+
+// Truncate updates frs to reflect Mappable truncation to the given length:
+// bytes after the new EOF on the same page are zeroed, and pages after the new
+// EOF are freed.
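+//
+// For example (illustrative, with a 0x1000-byte page size): Truncate(0x1800,
+// mf) frees cached pages at and above offset 0x2000 and zeroes any cached
+// bytes in [0x1800, 0x2000).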
+func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) {
+ pgendaddr, ok := usermem.Addr(end).RoundUp()
+ if ok {
+ pgend := uint64(pgendaddr)
+
+ // Free truncated pages.
+ frs.SplitAt(pgend)
+ seg := frs.LowerBoundSegment(pgend)
+ for seg.Ok() {
+ mf.DecRef(seg.FileRange())
+ seg = frs.Remove(seg).NextSegment()
+ }
+
+ if end == pgend {
+ return
+ }
+ }
+
+ // Here we know end < end.RoundUp(). If the new EOF lands in the
+ // middle of a page that we have, zero out its contents beyond the new
+ // length.
+ seg := frs.FindSegment(end)
+ if seg.Ok() {
+ fr := seg.FileRange()
+ fr.Start += end - seg.Start()
+ ims, err := mf.MapInternal(fr, usermem.Write)
+ if err != nil {
+ // There's no good recourse from here. This means
+ // that we can't keep cached memory consistent with
+ // the new end of file. The caller may have already
+ // updated the file size on their backing file system.
+ //
+ // We don't want to risk blindly continuing onward,
+ // so in the extremely rare cases this does happen,
+ // we abandon ship.
+ panic(fmt.Sprintf("Failed to map %v: %v", fr, err))
+ }
+ if _, err := safemem.ZeroSeq(ims); err != nil {
+ panic(fmt.Sprintf("Zeroing %v failed: %v", fr, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
new file mode 100644
index 000000000..dd6f5aba6
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -0,0 +1,91 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// FrameRefSetFunctions implements segment.Functions for FrameRefSet.
+type FrameRefSetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (FrameRefSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (FrameRefSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (FrameRefSetFunctions) ClearValue(val *uint64) {
+}
+
+// Merge implements segment.Functions.Merge.
+func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+ if val1 != val2 {
+ return 0, false
+ }
+ return val1, true
+}
+
+// Split implements segment.Functions.Split.
+func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+ return val, val
+}
+
+// IncRefAndAccount adds a reference on the range fr. All newly inserted segments
+// are accounted as host page cache memory mappings.
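+//
+// For example (illustrative): after IncRefAndAccount([0x0, 0x2000)) and then
+// IncRefAndAccount([0x1000, 0x3000)), the three pages have reference counts
+// 1, 2, and 1 respectively, and 0x3000 bytes are accounted as usage.Mapped.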
+func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
+ seg, gap := refs.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = refs.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ newRange := gap.Range().Intersect(fr)
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
+ seg, gap = refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
+ default:
+ refs.MergeAdjacent(fr)
+ return
+ }
+ }
+}
+
+// DecRefAndAccount removes a reference on the range fr and untracks segments
+// that are removed from memory accounting.
+func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) {
+ seg := refs.FindSegment(fr.Start)
+
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = refs.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
+ seg = refs.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ refs.MergeAdjacent(fr)
+}
diff --git a/pkg/sentry/fs/fsutil/fsutil.go b/pkg/sentry/fs/fsutil/fsutil.go
new file mode 100644
index 000000000..c9587b1d9
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/fsutil.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsutil provides utilities for implementing fs.InodeOperations
+// and fs.FileOperations:
+//
+// - For embeddable utilities, see inode.go and file.go.
+//
+// - For fs.Inodes that require a page cache to be memory mapped, see
+// inode_cache.go.
+//
+// - For anon fs.Inodes, see anon.go.
+package fsutil
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
new file mode 100644
index 000000000..e82afd112
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -0,0 +1,242 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// HostFileMapper caches mappings of an arbitrary host file descriptor. It is
+// used by implementations of memmap.Mappable that represent a host file
+// descriptor.
+//
+// +stateify savable
+type HostFileMapper struct {
+ // HostFileMapper conceptually breaks the file into pieces called chunks, of
+ // size and alignment chunkSize, and caches mappings of the file on a chunk
+ // granularity.
+
+ refsMu sync.Mutex `state:"nosave"`
+
+ // refs maps chunk start offsets to the sum of reference counts for all
+ // pages in that chunk. refs is protected by refsMu.
+ refs map[uint64]int32
+
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings maps chunk start offsets to mappings of those chunks,
+ // obtained by calling syscall.Mmap. mappings is protected by
+ // mapsMu.
+ mappings map[uint64]mapping `state:"nosave"`
+}
+
+const (
+ chunkShift = usermem.HugePageShift
+ chunkSize = 1 << chunkShift
+ chunkMask = chunkSize - 1
+)
+
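+// pagesInChunk returns the number of pages of mr that fall within the chunk
+// starting at chunkStart. With usermem.HugePageShift == 21 (as on amd64),
+// each chunk is 2MB and covers 512 4KB pages.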
+func pagesInChunk(mr memmap.MappableRange, chunkStart uint64) int32 {
+ return int32(mr.Intersect(memmap.MappableRange{chunkStart, chunkStart + chunkSize}).Length() / usermem.PageSize)
+}
+
+type mapping struct {
+ addr uintptr
+ writable bool
+}
+
+// Init must be called on zero-value HostFileMappers before first use.
+func (f *HostFileMapper) Init() {
+ f.refs = make(map[uint64]int32)
+ f.mappings = make(map[uint64]mapping)
+}
+
+// NewHostFileMapper returns an initialized HostFileMapper allocated on the
+// heap with no references or cached mappings.
+func NewHostFileMapper() *HostFileMapper {
+ f := &HostFileMapper{}
+ f.Init()
+ return f
+}
+
+// IncRefOn increments the reference count on all offsets in mr.
+//
+// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned.
+func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
+ f.refsMu.Lock()
+ defer f.refsMu.Unlock()
+ for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ refs := f.refs[chunkStart]
+ pgs := pagesInChunk(mr, chunkStart)
+ if refs+pgs < refs {
+ // Would overflow.
+ panic(fmt.Sprintf("HostFileMapper.IncRefOn(%v): adding %d page references to chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
+ }
+ f.refs[chunkStart] = refs + pgs
+ }
+}
+
+// DecRefOn decrements the reference count on all offsets in mr.
+//
+// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned.
+func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
+ f.refsMu.Lock()
+ defer f.refsMu.Unlock()
+ for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ refs := f.refs[chunkStart]
+ pgs := pagesInChunk(mr, chunkStart)
+ switch {
+ case refs > pgs:
+ f.refs[chunkStart] = refs - pgs
+ case refs == pgs:
+ f.mapsMu.Lock()
+ delete(f.refs, chunkStart)
+ if m, ok := f.mappings[chunkStart]; ok {
+ f.unmapAndRemoveLocked(chunkStart, m)
+ }
+ f.mapsMu.Unlock()
+ case refs < pgs:
+ panic(fmt.Sprintf("HostFileMapper.DecRefOn(%v): removing %d page references from chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
+ }
+ }
+}
+
+// MapInternal returns a mapping of offsets in fr from fd. The returned
+// safemem.BlockSeq is valid as long as at least one reference is held on all
+// offsets in fr or until the next call to UnmapAll.
+//
+// Preconditions: The caller must hold a reference on all offsets in fr.
+func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
+ chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+ if chunks == 1 {
+ // Avoid an unnecessary slice allocation.
+ var seq safemem.BlockSeq
+ err := f.forEachMappingBlockLocked(fr, fd, write, func(b safemem.Block) {
+ seq = safemem.BlockSeqOf(b)
+ })
+ return seq, err
+ }
+ blocks := make([]safemem.Block, 0, chunks)
+ err := f.forEachMappingBlockLocked(fr, fd, write, func(b safemem.Block) {
+ blocks = append(blocks, b)
+ })
+ return safemem.BlockSeqFromSlice(blocks), err
+}
+
+// Preconditions: f.mapsMu must be locked.
+func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error {
+ prot := syscall.PROT_READ
+ if write {
+ prot |= syscall.PROT_WRITE
+ }
+ for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
+ m, ok := f.mappings[chunkStart]
+ if !ok {
+ addr, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ 0,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ m = mapping{addr, write}
+ f.mappings[chunkStart] = m
+ } else if write && !m.writable {
+ addr, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ m.addr,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED|syscall.MAP_FIXED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ m = mapping{addr, write}
+ f.mappings[chunkStart] = m
+ }
+ var startOff uint64
+ if chunkStart < fr.Start {
+ startOff = fr.Start - chunkStart
+ }
+ endOff := uint64(chunkSize)
+ if chunkStart+chunkSize > fr.End {
+ endOff = fr.End - chunkStart
+ }
+ fn(f.unsafeBlockFromChunkMapping(m.addr).TakeFirst64(endOff).DropFirst64(startOff))
+ }
+ return nil
+}
+
+// UnmapAll unmaps all cached mappings. Callers are responsible for
+// synchronization with mappings returned by previous calls to MapInternal.
+func (f *HostFileMapper) UnmapAll() {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+ for chunkStart, m := range f.mappings {
+ f.unmapAndRemoveLocked(chunkStart, m)
+ }
+}
+
+// Preconditions: f.mapsMu must be locked. f.mappings[chunkStart] == m.
+func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) {
+ if _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m.addr, chunkSize, 0); errno != 0 {
+ // This leaks address space and is unexpected, but is otherwise
+ // harmless, so complain but don't panic.
+ log.Warningf("HostFileMapper: failed to unmap mapping %#x for chunk %#x: %v", m.addr, chunkStart, errno)
+ }
+ delete(f.mappings, chunkStart)
+}
+
+// RegenerateMappings must be called when the file description mapped by f
+// changes, to replace existing mappings of the previous file description.
+func (f *HostFileMapper) RegenerateMappings(fd int) error {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ for chunkStart, m := range f.mappings {
+ prot := syscall.PROT_READ
+ if m.writable {
+ prot |= syscall.PROT_WRITE
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ m.addr,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED|syscall.MAP_FIXED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper_state.go b/pkg/sentry/fs/fsutil/host_file_mapper_state.go
new file mode 100644
index 000000000..576d2a3df
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_file_mapper_state.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+// afterLoad is invoked by stateify.
+func (f *HostFileMapper) afterLoad() {
+ f.mappings = make(map[uint64]mapping)
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
new file mode 100644
index 000000000..2d4778d64
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+func (*HostFileMapper) unsafeBlockFromChunkMapping(addr uintptr) safemem.Block {
+ // We don't control the host file's length, so touching its mappings may
+ // raise SIGBUS. Thus accesses to it must use safecopy.
+ return safemem.BlockFromUnsafePointer((unsafe.Pointer)(addr), chunkSize)
+}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
new file mode 100644
index 000000000..78fec553e
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -0,0 +1,214 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// HostMappable implements memmap.Mappable and platform.File over a
+// CachedFileObject.
+//
+// Lock order (compare the lock order model in mm/mm.go):
+// truncateMu ("fs locks")
+// mu ("memmap.Mappable locks not taken by Translate")
+// ("platform.File locks")
+// backingFile ("CachedFileObject locks")
+//
+// +stateify savable
+type HostMappable struct {
+ hostFileMapper *HostFileMapper
+
+ backingFile CachedFileObject
+
+ mu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the cached file object into
+ // memmap.MappingSpaces so it can be invalidated upon save. Protected by mu.
+ mappings memmap.MappingSet
+
+ // truncateMu protects writes and truncations. See Truncate() for details.
+ truncateMu sync.RWMutex `state:"nosave"`
+}
+
+// NewHostMappable creates a new mappable that maps directly to a host FD.
+func NewHostMappable(backingFile CachedFileObject) *HostMappable {
+ return &HostMappable{
+ hostFileMapper: NewHostFileMapper(),
+ backingFile: backingFile,
+ }
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ // Hot path. Avoid defers.
+ h.mu.Lock()
+ mapped := h.mappings.AddMapping(ms, ar, offset, writable)
+ for _, r := range mapped {
+ h.hostFileMapper.IncRefOn(r)
+ }
+ h.mu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ // Hot path. Avoid defers.
+ h.mu.Lock()
+ unmapped := h.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ h.hostFileMapper.DecRefOn(r)
+ }
+ h.mu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (h *HostMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return h.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ return []memmap.Translation{
+ {
+ Source: optional,
+ File: h,
+ Offset: optional.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (h *HostMappable) InvalidateUnsavable(_ context.Context) error {
+ h.mu.Lock()
+ h.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ h.mu.Unlock()
+ return nil
+}
+
+// NotifyChangeFD must be called after the file description represented by
+// CachedFileObject.FD() changes.
+func (h *HostMappable) NotifyChangeFD() error {
+ // Update existing sentry mappings to refer to the new file description.
+ if err := h.hostFileMapper.RegenerateMappings(h.backingFile.FD()); err != nil {
+ return err
+ }
+
+ // Shoot down existing application mappings of the old file description;
+ // they will be remapped with the new file description on demand.
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ h.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ return nil
+}
+
+// MapInternal implements platform.File.MapInternal.
+func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
+}
+
+// FD implements platform.File.FD.
+func (h *HostMappable) FD() int {
+ return h.backingFile.FD()
+}
+
+// IncRef implements platform.File.IncRef.
+func (h *HostMappable) IncRef(fr platform.FileRange) {
+ mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
+ h.hostFileMapper.IncRefOn(mr)
+}
+
+// DecRef implements platform.File.DecRef.
+func (h *HostMappable) DecRef(fr platform.FileRange) {
+ mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
+ h.hostFileMapper.DecRefOn(mr)
+}
+
+// Truncate truncates the file, invalidating any mappings that extend beyond
+// the new size.
+//
+// Truncation and writes are synchronized to prevent races where writes make the
+// file grow between truncation and invalidation below:
+// T1: Calls SetMaskedAttributes and stalls
+// T2: Appends to file causing it to grow
+// T2: Writes to mapped pages and COW happens
+// T1: Continues and wrongly invalidates the page mapped in the step above.
+func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
+ h.truncateMu.Lock()
+ defer h.truncateMu.Unlock()
+
+ mask := fs.AttrMask{Size: true}
+ attr := fs.UnstableAttr{Size: newSize}
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
+ return err
+ }
+
+ // Invalidate COW mappings that may exist beyond the new size in case the file
+ // is being shrunk. Other mappings don't need to be invalidated because
+ // translate will just return identical mappings after invalidation anyway,
+ // and SIGBUS will be raised and handled when the mappings are touched.
+ //
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+// mm/memory.c:unmap_mapping_range(even_cows=1).
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ mr := memmap.MappableRange{
+ Start: fs.OffsetPageEnd(newSize),
+ End: fs.OffsetPageEnd(math.MaxInt64),
+ }
+ h.mappings.Invalidate(mr, memmap.InvalidateOpts{InvalidatePrivate: true})
+
+ return nil
+}
+
+// Allocate reserves space in the backing file.
+func (h *HostMappable) Allocate(ctx context.Context, offset int64, length int64) error {
+ h.truncateMu.RLock()
+ err := h.backingFile.Allocate(ctx, offset, length)
+ h.truncateMu.RUnlock()
+ return err
+}
+
+// Write writes to the file backing this mappable.
+func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ h.truncateMu.RLock()
+ n, err := src.CopyInTo(ctx, &writer{ctx: ctx, hostMappable: h, off: offset})
+ h.truncateMu.RUnlock()
+ return n, err
+}
+
+type writer struct {
+ ctx context.Context
+ hostMappable *HostMappable
+ off int64
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (w *writer) WriteFromBlocks(src safemem.BlockSeq) (uint64, error) {
+ n, err := w.hostMappable.backingFile.WriteFromBlocksAt(w.ctx, src, uint64(w.off))
+ w.off += int64(n)
+ return n, err
+}
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
new file mode 100644
index 000000000..1922ff08c
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -0,0 +1,531 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SimpleFileInode is a simple implementation of InodeOperations.
+//
+// +stateify savable
+type SimpleFileInode struct {
+ InodeGenericChecker `state:"nosave"`
+ InodeNoExtendedAttributes `state:"nosave"`
+ InodeNoopRelease `state:"nosave"`
+ InodeNoopWriteOut `state:"nosave"`
+ InodeNotAllocatable `state:"nosave"`
+ InodeNotDirectory `state:"nosave"`
+ InodeNotMappable `state:"nosave"`
+ InodeNotOpenable `state:"nosave"`
+ InodeNotSocket `state:"nosave"`
+ InodeNotSymlink `state:"nosave"`
+ InodeNotTruncatable `state:"nosave"`
+ InodeNotVirtual `state:"nosave"`
+
+ InodeSimpleAttributes
+}
+
+// NewSimpleFileInode returns a new SimpleFileInode.
+func NewSimpleFileInode(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, typ uint64) *SimpleFileInode {
+ return &SimpleFileInode{
+ InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, owner, perms, typ),
+ }
+}
+
+// NoReadWriteFileInode is an implementation of InodeOperations that supports
+// opening files that are not readable or writeable.
+//
+// +stateify savable
+type NoReadWriteFileInode struct {
+ InodeGenericChecker `state:"nosave"`
+ InodeNoExtendedAttributes `state:"nosave"`
+ InodeNoopRelease `state:"nosave"`
+ InodeNoopWriteOut `state:"nosave"`
+ InodeNotAllocatable `state:"nosave"`
+ InodeNotDirectory `state:"nosave"`
+ InodeNotMappable `state:"nosave"`
+ InodeNotSocket `state:"nosave"`
+ InodeNotSymlink `state:"nosave"`
+ InodeNotTruncatable `state:"nosave"`
+ InodeNotVirtual `state:"nosave"`
+
+ InodeSimpleAttributes
+}
+
+// NewNoReadWriteFileInode returns a new NoReadWriteFileInode.
+func NewNoReadWriteFileInode(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, typ uint64) *NoReadWriteFileInode {
+ return &NoReadWriteFileInode{
+ InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, owner, perms, typ),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (*NoReadWriteFileInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &NoReadWriteFile{}), nil
+}
+
+// InodeSimpleAttributes implements methods for updating in-memory unstable
+// attributes.
+//
+// +stateify savable
+type InodeSimpleAttributes struct {
+ // fsType is the immutable filesystem type that will be returned by
+ // StatFS.
+ fsType uint64
+
+ // mu protects unstable.
+ mu sync.RWMutex `state:"nosave"`
+ unstable fs.UnstableAttr
+}
+
+// NewInodeSimpleAttributes returns a new InodeSimpleAttributes with the given
+// owner and permissions, and all timestamps set to the current time.
+func NewInodeSimpleAttributes(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, typ uint64) InodeSimpleAttributes {
+ return NewInodeSimpleAttributesWithUnstable(fs.WithCurrentTime(ctx, fs.UnstableAttr{
+ Owner: owner,
+ Perms: perms,
+ }), typ)
+}
+
+// NewInodeSimpleAttributesWithUnstable returns a new InodeSimpleAttributes
+// with the given unstable attributes.
+func NewInodeSimpleAttributesWithUnstable(uattr fs.UnstableAttr, typ uint64) InodeSimpleAttributes {
+ return InodeSimpleAttributes{
+ fsType: typ,
+ unstable: uattr,
+ }
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *InodeSimpleAttributes) UnstableAttr(ctx context.Context, _ *fs.Inode) (fs.UnstableAttr, error) {
+ i.mu.RLock()
+ u := i.unstable
+ i.mu.RUnlock()
+ return u, nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (i *InodeSimpleAttributes) SetPermissions(ctx context.Context, _ *fs.Inode, p fs.FilePermissions) bool {
+ i.mu.Lock()
+ i.unstable.SetPermissions(ctx, p)
+ i.mu.Unlock()
+ return true
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (i *InodeSimpleAttributes) SetOwner(ctx context.Context, _ *fs.Inode, owner fs.FileOwner) error {
+ i.mu.Lock()
+ i.unstable.SetOwner(ctx, owner)
+ i.mu.Unlock()
+ return nil
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (i *InodeSimpleAttributes) SetTimestamps(ctx context.Context, _ *fs.Inode, ts fs.TimeSpec) error {
+ i.mu.Lock()
+ i.unstable.SetTimestamps(ctx, ts)
+ i.mu.Unlock()
+ return nil
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (i *InodeSimpleAttributes) AddLink() {
+ i.mu.Lock()
+ i.unstable.Links++
+ i.mu.Unlock()
+}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (i *InodeSimpleAttributes) DropLink() {
+ i.mu.Lock()
+ i.unstable.Links--
+ i.mu.Unlock()
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (i *InodeSimpleAttributes) StatFS(context.Context) (fs.Info, error) {
+ if i.fsType == 0 {
+ return fs.Info{}, syserror.ENOSYS
+ }
+ return fs.Info{Type: i.fsType}, nil
+}
+
+// NotifyAccess updates the access time.
+func (i *InodeSimpleAttributes) NotifyAccess(ctx context.Context) {
+ i.mu.Lock()
+ i.unstable.AccessTime = ktime.NowFromContext(ctx)
+ i.mu.Unlock()
+}
+
+// NotifyModification updates the modification time.
+func (i *InodeSimpleAttributes) NotifyModification(ctx context.Context) {
+ i.mu.Lock()
+ i.unstable.ModificationTime = ktime.NowFromContext(ctx)
+ i.mu.Unlock()
+}
+
+// NotifyStatusChange updates the status change time.
+func (i *InodeSimpleAttributes) NotifyStatusChange(ctx context.Context) {
+ i.mu.Lock()
+ i.unstable.StatusChangeTime = ktime.NowFromContext(ctx)
+ i.mu.Unlock()
+}
+
+// NotifyModificationAndStatusChange updates the modification and status change
+// times.
+func (i *InodeSimpleAttributes) NotifyModificationAndStatusChange(ctx context.Context) {
+ i.mu.Lock()
+ now := ktime.NowFromContext(ctx)
+ i.unstable.ModificationTime = now
+ i.unstable.StatusChangeTime = now
+ i.mu.Unlock()
+}
+
+// InodeSimpleExtendedAttributes implements
+// fs.InodeOperations.{Get,Set,List}Xattr.
+//
+// +stateify savable
+type InodeSimpleExtendedAttributes struct {
+ // mu protects xattrs.
+ mu sync.RWMutex `state:"nosave"`
+ xattrs map[string]string
+}
+
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *InodeSimpleExtendedAttributes) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
+ i.mu.RLock()
+ value, ok := i.xattrs[name]
+ i.mu.RUnlock()
+ if !ok {
+ return "", syserror.ENOATTR
+ }
+ return value, nil
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *InodeSimpleExtendedAttributes) SetXattr(_ context.Context, _ *fs.Inode, name, value string, flags uint32) error {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if i.xattrs == nil {
+ if flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+ i.xattrs = make(map[string]string)
+ }
+
+ _, ok := i.xattrs[name]
+ if ok && flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
+ i.xattrs[name] = value
+ return nil
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, uint64) (map[string]struct{}, error) {
+ i.mu.RLock()
+ names := make(map[string]struct{}, len(i.xattrs))
+ for name := range i.xattrs {
+ names[name] = struct{}{}
+ }
+ i.mu.RUnlock()
+ return names, nil
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if _, ok := i.xattrs[name]; ok {
+ delete(i.xattrs, name)
+ return nil
+ }
+ return syserror.ENOATTR
+}
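+
+// Editorial sketch (not part of this change): SetXattr above follows the
+// setxattr(2) flag semantics, e.g.:
+//
+//	var x InodeSimpleExtendedAttributes
+//	_ = x.SetXattr(ctx, nil, "user.k", "v1", 0)                  // creates
+//	_ = x.SetXattr(ctx, nil, "user.k", "v2", linux.XATTR_CREATE) // EEXIST
+//	_ = x.SetXattr(ctx, nil, "user.m", "v", linux.XATTR_REPLACE) // ENODATA
+//	v, _ := x.GetXattr(ctx, nil, "user.k", 0)                    // "v1"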
+
+// staticFile is a file with static contents. It is returned by
+// InodeStaticFileGetter.GetFile.
+//
+// +stateify savable
+type staticFile struct {
+ FileGenericSeek `state:"nosave"`
+ FileNoIoctl `state:"nosave"`
+ FileNoMMap `state:"nosave"`
+ FileNoSplice `state:"nosave"`
+ FileNoopFsync `state:"nosave"`
+ FileNoopFlush `state:"nosave"`
+ FileNoopRelease `state:"nosave"`
+ FileNoopWrite `state:"nosave"`
+ FileNotDirReaddir `state:"nosave"`
+ FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ FileStaticContentReader
+}
+
+// InodeNoStatFS implements StatFS by returning ENOSYS.
+type InodeNoStatFS struct{}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (InodeNoStatFS) StatFS(context.Context) (fs.Info, error) {
+ return fs.Info{}, syserror.ENOSYS
+}
+
+// InodeStaticFileGetter implements GetFile for a file with static contents.
+//
+// +stateify savable
+type InodeStaticFileGetter struct {
+ Contents []byte
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *InodeStaticFileGetter) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &staticFile{
+ FileStaticContentReader: NewFileStaticContentReader(i.Contents),
+ }), nil
+}
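+
+// Editorial sketch (not part of this change): InodeStaticFileGetter serves
+// fixed bytes; every open wraps Contents in a staticFile via GetFile above.
+//
+//	getter := &InodeStaticFileGetter{Contents: []byte("hello\n")}
+//	f, err := getter.GetFile(ctx, dirent, fs.FileFlags{Read: true})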
+
+// InodeNotMappable returns a nil memmap.Mappable.
+type InodeNotMappable struct{}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (InodeNotMappable) Mappable(*fs.Inode) memmap.Mappable {
+ return nil
+}
+
+// InodeNoopWriteOut is a no-op implementation of fs.InodeOperations.WriteOut.
+type InodeNoopWriteOut struct{}
+
+// WriteOut is a no-op.
+func (InodeNoopWriteOut) WriteOut(context.Context, *fs.Inode) error {
+ return nil
+}
+
+// InodeNotDirectory can be used by Inodes that are not directories.
+type InodeNotDirectory struct{}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (InodeNotDirectory) Lookup(context.Context, *fs.Inode, string) (*fs.Dirent, error) {
+ return nil, syserror.ENOTDIR
+}
+
+// Create implements fs.InodeOperations.Create.
+func (InodeNotDirectory) Create(context.Context, *fs.Inode, string, fs.FileFlags, fs.FilePermissions) (*fs.File, error) {
+ return nil, syserror.ENOTDIR
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (InodeNotDirectory) CreateLink(context.Context, *fs.Inode, string, string) error {
+ return syserror.ENOTDIR
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+func (InodeNotDirectory) CreateHardLink(context.Context, *fs.Inode, *fs.Inode, string) error {
+ return syserror.ENOTDIR
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (InodeNotDirectory) CreateDirectory(context.Context, *fs.Inode, string, fs.FilePermissions) error {
+ return syserror.ENOTDIR
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (InodeNotDirectory) Bind(context.Context, *fs.Inode, string, transport.BoundEndpoint, fs.FilePermissions) (*fs.Dirent, error) {
+ return nil, syserror.ENOTDIR
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (InodeNotDirectory) CreateFifo(context.Context, *fs.Inode, string, fs.FilePermissions) error {
+ return syserror.ENOTDIR
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (InodeNotDirectory) Remove(context.Context, *fs.Inode, string) error {
+ return syserror.ENOTDIR
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (InodeNotDirectory) RemoveDirectory(context.Context, *fs.Inode, string) error {
+ return syserror.ENOTDIR
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (InodeNotDirectory) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error {
+ return syserror.EINVAL
+}
+
+// InodeNotSocket can be used by Inodes that are not sockets.
+type InodeNotSocket struct{}
+
+// BoundEndpoint implements fs.InodeOperations.BoundEndpoint.
+func (InodeNotSocket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint {
+ return nil
+}
+
+// InodeNotTruncatable can be used by Inodes that cannot be truncated.
+type InodeNotTruncatable struct{}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (InodeNotTruncatable) Truncate(context.Context, *fs.Inode, int64) error {
+ return syserror.EINVAL
+}
+
+// InodeIsDirTruncate implements fs.InodeOperations.Truncate for directories.
+type InodeIsDirTruncate struct{}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (InodeIsDirTruncate) Truncate(context.Context, *fs.Inode, int64) error {
+ return syserror.EISDIR
+}
+
+// InodeNoopTruncate implements fs.InodeOperations.Truncate as a noop.
+type InodeNoopTruncate struct{}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (InodeNoopTruncate) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// InodeNotRenameable can be used by Inodes that cannot be renamed.
+type InodeNotRenameable struct{}
+
+// Rename implements fs.InodeOperations.Rename.
+func (InodeNotRenameable) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error {
+ return syserror.EINVAL
+}
+
+// InodeNotOpenable can be used by Inodes that cannot be opened.
+type InodeNotOpenable struct{}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (InodeNotOpenable) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error) {
+ return nil, syserror.EIO
+}
+
+// InodeNotVirtual can be used by Inodes that are not virtual.
+type InodeNotVirtual struct{}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (InodeNotVirtual) IsVirtual() bool {
+ return false
+}
+
+// InodeVirtual can be used by Inodes that are virtual.
+type InodeVirtual struct{}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (InodeVirtual) IsVirtual() bool {
+ return true
+}
+
+// InodeNotSymlink can be used by Inodes that are not symlinks.
+type InodeNotSymlink struct{}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (InodeNotSymlink) Readlink(context.Context, *fs.Inode) (string, error) {
+ return "", syserror.ENOLINK
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (InodeNotSymlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ return nil, syserror.ENOLINK
+}
+
+// InodeNoExtendedAttributes can be used by Inodes that do not support
+// extended attributes.
+type InodeNoExtendedAttributes struct{}
+
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (InodeNoExtendedAttributes) GetXattr(context.Context, *fs.Inode, string, uint64) (string, error) {
+ return "", syserror.EOPNOTSUPP
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (InodeNoExtendedAttributes) SetXattr(context.Context, *fs.Inode, string, string, uint32) error {
+ return syserror.EOPNOTSUPP
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (InodeNoExtendedAttributes) ListXattr(context.Context, *fs.Inode, uint64) (map[string]struct{}, error) {
+ return nil, syserror.EOPNOTSUPP
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (InodeNoExtendedAttributes) RemoveXattr(context.Context, *fs.Inode, string) error {
+ return syserror.EOPNOTSUPP
+}
+
+// InodeNoopRelease implements fs.InodeOperations.Release as a noop.
+type InodeNoopRelease struct{}
+
+// Release implements fs.InodeOperations.Release.
+func (InodeNoopRelease) Release(context.Context) {}
+
+// InodeGenericChecker implements fs.InodeOperations.Check with a generic
+// implementation.
+type InodeGenericChecker struct{}
+
+// Check implements fs.InodeOperations.Check.
+func (InodeGenericChecker) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// InodeDenyWriteChecker implements fs.InodeOperations.Check which denies all
+// write operations.
+type InodeDenyWriteChecker struct{}
+
+// Check implements fs.InodeOperations.Check.
+func (InodeDenyWriteChecker) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ if p.Write {
+ return false
+ }
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
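+
+// Editorial sketch (not part of this change): swapping InodeDenyWriteChecker
+// for InodeGenericChecker in a mixin struct like SimpleFileInode above yields
+// an inode that passes normal access checks except for any write request:
+//
+//	type readOnlyFileInode struct {
+//		InodeDenyWriteChecker `state:"nosave"`
+//		// ...remaining mixins as in SimpleFileInode...
+//		InodeSimpleAttributes
+//	}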
+
+// InodeNotAllocatable can be used by Inodes that do not support Allocate().
+type InodeNotAllocatable struct{}
+
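+// Allocate implements fs.InodeOperations.Allocate.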
+func (InodeNotAllocatable) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return syserror.EOPNOTSUPP
+}
+
+// InodeNoopAllocate implements fs.InodeOperations.Allocate as a noop.
+type InodeNoopAllocate struct{}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (InodeNoopAllocate) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return nil
+}
+
+// InodeIsDirAllocate implements fs.InodeOperations.Allocate for directories.
+type InodeIsDirAllocate struct{}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (InodeIsDirAllocate) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return syserror.EISDIR
+}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
new file mode 100644
index 000000000..800c8b4e1
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -0,0 +1,1061 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Lock order (compare the lock order model in mm/mm.go):
+//
+// CachingInodeOperations.attrMu ("fs locks")
+// CachingInodeOperations.mapsMu ("memmap.Mappable locks not taken by Translate")
+// CachingInodeOperations.dataMu ("memmap.Mappable locks taken by Translate")
+// CachedFileObject locks
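+//
+// For instance, Release below acquires mapsMu and then dataMu, never the
+// reverse.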
+
+// CachingInodeOperations caches the metadata and content of a CachedFileObject.
+// It implements a subset of InodeOperations. As a utility it can be used to
+// implement the full set of InodeOperations. Generally it should not be
+// embedded to avoid unexpected inherited behavior.
+//
+// CachingInodeOperations implements Mappable for the CachedFileObject:
+//
+// - If CachedFileObject.FD returns a value >= 0 then the file descriptor
+// will be memory mapped on the host.
+//
+// - Otherwise, the contents of CachedFileObject are buffered into memory
+// managed by the CachingInodeOperations.
+//
+// Implementations of FileOperations for a CachedFileObject must read and
+// write through CachingInodeOperations using Read and Write respectively.
+//
+// Implementations of InodeOperations.WriteOut must call Sync to write out
+// in-memory modifications of data and metadata to the CachedFileObject.
+//
+// +stateify savable
+type CachingInodeOperations struct {
+ // backingFile is a handle to a cached file object.
+ backingFile CachedFileObject
+
+ // mfp is used to allocate memory that caches backingFile's contents.
+ mfp pgalloc.MemoryFileProvider
+
+ // opts contains options. opts is immutable.
+ opts CachingInodeOperationsOptions
+
+ attrMu sync.Mutex `state:"nosave"`
+
+ // attr is unstable cached metadata.
+ //
+ // attr is protected by attrMu. attr.Size is protected by both attrMu and
+ // dataMu; reading it requires locking either mutex, while mutating it
+ // requires locking both.
+ attr fs.UnstableAttr
+
+ // dirtyAttr is metadata that was updated in-place but hasn't yet
+ // been successfully written out.
+ //
+ // dirtyAttr is protected by attrMu.
+ dirtyAttr fs.AttrMask
+
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the cached file object into
+ // memmap.MappingSpaces.
+ //
+ // mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ dataMu sync.RWMutex `state:"nosave"`
+
+ // cache maps offsets into the cached file to offsets into
+ // mfp.MemoryFile() that store the file's data.
+ //
+ // cache is protected by dataMu.
+ cache FileRangeSet
+
+ // dirty tracks dirty segments in cache.
+ //
+ // dirty is protected by dataMu.
+ dirty DirtySet
+
+ // hostFileMapper caches internal mappings of backingFile.FD().
+ hostFileMapper *HostFileMapper
+
+ // refs tracks active references to data in the cache.
+ //
+ // refs is protected by dataMu.
+ refs FrameRefSet
+}
+
+// CachingInodeOperationsOptions configures a CachingInodeOperations.
+//
+// +stateify savable
+type CachingInodeOperationsOptions struct {
+ // If ForcePageCache is true, use the sentry page cache even if a host file
+ // descriptor is available.
+ ForcePageCache bool
+
+ // If LimitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host file descriptor mappings returned by
+ // CachingInodeOperations.Translate().
+ LimitHostFDTranslation bool
+}
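+
+// Editorial sketch (not part of this change): construction mirrors the tests
+// later in this diff; ForcePageCache keeps all I/O in the sentry page cache
+// even when a host FD is available.
+//
+//	iops := NewCachingInodeOperations(ctx, backingFile, uattr,
+//		CachingInodeOperationsOptions{ForcePageCache: true})
+//	defer iops.Release()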
+
+// CachedFileObject is a file that may require caching.
+type CachedFileObject interface {
+ // ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts,
+ // starting at offset, and returns the number of bytes read. ReadToBlocksAt
+ // may return a partial read without an error.
+ ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)
+
+ // WriteFromBlocksAt writes up to srcs.NumBytes() bytes from srcs to the
+ // file, starting at offset, and returns the number of bytes written.
+ // WriteFromBlocksAt may return a partial write without an error.
+ WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
+
+ // SetMaskedAttributes sets the attributes in attr that are true in
+ // mask on the backing file. If the mask contains only ATime or MTime
+ // and the CachedFileObject has an FD to the file, then this operation
+ // is a noop unless forceSetTimestamps is true. This avoids an extra
+ // RPC to the gofer in the open-read/write-close case, when the
+ // timestamps on the file will be updated by the host kernel for us.
+ //
+ // SetMaskedAttributes may be called at any point, regardless of whether
+ // the file was opened.
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error
+
+ // Allocate allows the caller to reserve disk space for the inode.
+ // It's equivalent to fallocate(2) with 'mode=0'.
+ Allocate(ctx context.Context, offset int64, length int64) error
+
+ // Sync instructs the remote filesystem to sync the file to stable storage.
+ Sync(ctx context.Context) error
+
+ // FD returns a host file descriptor. If it is possible for
+ // CachingInodeOperations.AddMapping to have ever been called with writable
+ // = true, the FD must have been opened O_RDWR; otherwise, it may have been
+ // opened O_RDONLY or O_RDWR. (mmap unconditionally requires that mapped
+ // files are readable.) If no host file descriptor is available, FD returns
+ // a negative number.
+ //
+ // For any given CachedFileObject, if FD() ever succeeds (returns a
+ // non-negative number), it must always succeed.
+ //
+ // FD is called iff the file has been memory mapped. This implies that
+ // the file was opened (see fs.InodeOperations.GetFile).
+ FD() int
+}
+
+// NewCachingInodeOperations returns a new CachingInodeOperations backed by
+// a CachedFileObject and its initial unstable attributes.
+func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
+ }
+ return &CachingInodeOperations{
+ backingFile: backingFile,
+ mfp: mfp,
+ opts: opts,
+ attr: uattr,
+ hostFileMapper: NewHostFileMapper(),
+ }
+}
+
+// Release implements fs.InodeOperations.Release.
+func (c *CachingInodeOperations) Release() {
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+
+ // Something has gone terribly wrong if we're releasing an inode that is
+ // still memory-mapped.
+ if !c.mappings.IsEmpty() {
+ panic(fmt.Sprintf("Releasing CachingInodeOperations with mappings:\n%s", &c.mappings))
+ }
+
+ // Drop any cached pages that are still awaiting MemoryFile eviction. (This
+ // means that MemoryFile no longer needs to evict them.)
+ mf := c.mfp.MemoryFile()
+ mf.MarkAllUnevictable(c)
+ if err := SyncDirtyAll(context.Background(), &c.cache, &c.dirty, uint64(c.attr.Size), mf, c.backingFile.WriteFromBlocksAt); err != nil {
+ panic(fmt.Sprintf("Failed to writeback cached data: %v", err))
+ }
+ c.cache.DropAll(mf)
+ c.dirty.RemoveAll()
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (c *CachingInodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ c.attrMu.Lock()
+ attr := c.attr
+ c.attrMu.Unlock()
+ return attr, nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.Inode, perms fs.FilePermissions) bool {
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ now := ktime.NowFromContext(ctx)
+ masked := fs.AttrMask{Perms: true}
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil {
+ return false
+ }
+ c.attr.Perms = perms
+ c.touchStatusChangeTimeLocked(now)
+ return true
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, owner fs.FileOwner) error {
+ if !owner.UID.Ok() && !owner.GID.Ok() {
+ return nil
+ }
+
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ now := ktime.NowFromContext(ctx)
+ masked := fs.AttrMask{
+ UID: owner.UID.Ok(),
+ GID: owner.GID.Ok(),
+ }
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil {
+ return err
+ }
+ if owner.UID.Ok() {
+ c.attr.Owner.UID = owner.UID
+ }
+ if owner.GID.Ok() {
+ c.attr.Owner.GID = owner.GID
+ }
+ c.touchStatusChangeTimeLocked(now)
+ return nil
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.Inode, ts fs.TimeSpec) error {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return nil
+ }
+
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ // Replace requests to use the "system time" with the current time to
+ // ensure that cached timestamps remain consistent with the remote
+ // filesystem.
+ now := ktime.NowFromContext(ctx)
+ if ts.ATimeSetSystemTime {
+ ts.ATime = now
+ }
+ if ts.MTimeSetSystemTime {
+ ts.MTime = now
+ }
+ masked := fs.AttrMask{
+ AccessTime: !ts.ATimeOmit,
+ ModificationTime: !ts.MTimeOmit,
+ }
+ // Call SetMaskedAttributes with forceSetTimestamps = true to make sure
+ // the timestamp is updated.
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil {
+ return err
+ }
+ if !ts.ATimeOmit {
+ c.attr.AccessTime = ts.ATime
+ }
+ if !ts.MTimeOmit {
+ c.attr.ModificationTime = ts.MTime
+ }
+ c.touchStatusChangeTimeLocked(now)
+ return nil
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, size int64) error {
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ // c.attr.Size is protected by both c.attrMu and c.dataMu.
+ c.dataMu.Lock()
+ now := ktime.NowFromContext(ctx)
+ masked := fs.AttrMask{Size: true}
+ attr := fs.UnstableAttr{Size: size}
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
+ c.dataMu.Unlock()
+ return err
+ }
+ oldSize := c.attr.Size
+ c.attr.Size = size
+ c.touchModificationAndStatusChangeTimeLocked(now)
+
+ // We drop c.dataMu here so that we can lock c.mapsMu and invalidate
+ // mappings below. This allows concurrent calls to Read/Translate/etc.
+ // These functions synchronize with an in-progress Truncate by refusing to
+ // use cache contents beyond the new c.attr.Size. (We are still holding
+ // c.attrMu, so we can't race with Truncate/Write.)
+ c.dataMu.Unlock()
+
+ // Nothing left to do unless shrinking the file.
+ if size >= oldSize {
+ return nil
+ }
+
+ oldpgend := fs.OffsetPageEnd(oldSize)
+ newpgend := fs.OffsetPageEnd(size)
+
+ // Invalidate past translations of truncated pages.
+ if newpgend != oldpgend {
+ c.mapsMu.Lock()
+ c.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ c.mapsMu.Unlock()
+ }
+
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them from the cache. Since truncated pages have been
+ // removed from the backing file, they should be dropped without being
+ // written back.
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+ c.cache.Truncate(uint64(size), c.mfp.MemoryFile())
+ c.dirty.KeepClean(memmap.MappableRange{uint64(size), oldpgend})
+
+ return nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (c *CachingInodeOperations) Allocate(ctx context.Context, offset, length int64) error {
+ newSize := offset + length
+
+ // c.attr.Size is protected by both c.attrMu and c.dataMu.
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+
+ if newSize <= c.attr.Size {
+ return nil
+ }
+
+ now := ktime.NowFromContext(ctx)
+ if err := c.backingFile.Allocate(ctx, offset, length); err != nil {
+ return err
+ }
+
+ c.attr.Size = newSize
+ c.touchModificationAndStatusChangeTimeLocked(now)
+ return nil
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ c.attrMu.Lock()
+
+ // Write dirty pages back.
+ c.dataMu.Lock()
+ err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), c.mfp.MemoryFile(), c.backingFile.WriteFromBlocksAt)
+ c.dataMu.Unlock()
+ if err != nil {
+ c.attrMu.Unlock()
+ return err
+ }
+
+ // SyncDirtyAll above would have grown the backing file if needed. On
+ // shrinks, the backing file is called directly, so the size never needs to
+ // be updated.
+ c.dirtyAttr.Size = false
+
+ // Write out cached attributes.
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
+ c.attrMu.Unlock()
+ return err
+ }
+ c.dirtyAttr = fs.AttrMask{}
+
+ c.attrMu.Unlock()
+
+ // Fsync the remote file.
+ return c.backingFile.Sync(ctx)
+}
+
+// IncLinks increases the link count and updates cached modification time.
+func (c *CachingInodeOperations) IncLinks(ctx context.Context) {
+ c.attrMu.Lock()
+ c.attr.Links++
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// DecLinks decreases the link count and updates cached modification time.
+func (c *CachingInodeOperations) DecLinks(ctx context.Context) {
+ c.attrMu.Lock()
+ c.attr.Links--
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// TouchAccessTime updates the cached access time in-place to the
+// current time. It does not update status change time in-place. See
+// mm/filemap.c:do_generic_file_read -> include/linux/fs.h:file_accessed.
+func (c *CachingInodeOperations) TouchAccessTime(ctx context.Context, inode *fs.Inode) {
+ if inode.MountSource.Flags.NoAtime {
+ return
+ }
+
+ c.attrMu.Lock()
+ c.touchAccessTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// touchAccessTimeLocked updates the cached access time in-place to the current
+// time.
+//
+// Preconditions: c.attrMu is locked for writing.
+func (c *CachingInodeOperations) touchAccessTimeLocked(now time.Time) {
+ c.attr.AccessTime = now
+ c.dirtyAttr.AccessTime = true
+}
+
+// TouchModificationAndStatusChangeTime updates the cached modification and
+// status change times in-place to the current time.
+func (c *CachingInodeOperations) TouchModificationAndStatusChangeTime(ctx context.Context) {
+ c.attrMu.Lock()
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// touchModificationAndStatusChangeTimeLocked updates the cached modification
+// and status change times in-place to the current time.
+//
+// Preconditions: c.attrMu is locked for writing.
+func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now time.Time) {
+ c.attr.ModificationTime = now
+ c.dirtyAttr.ModificationTime = true
+ c.attr.StatusChangeTime = now
+ c.dirtyAttr.StatusChangeTime = true
+}
+
+// TouchStatusChangeTime updates the cached status change time in-place to the
+// current time.
+func (c *CachingInodeOperations) TouchStatusChangeTime(ctx context.Context) {
+ c.attrMu.Lock()
+ c.touchStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// touchStatusChangeTimeLocked updates the cached status change time
+// in-place to the current time.
+//
+// Preconditions: c.attrMu is locked for writing.
+func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now time.Time) {
+ c.attr.StatusChangeTime = now
+ c.dirtyAttr.StatusChangeTime = true
+}
+
+// UpdateUnstable updates the cached unstable attributes. Only non-dirty
+// attributes are updated.
+func (c *CachingInodeOperations) UpdateUnstable(attr fs.UnstableAttr) {
+ // All attributes are protected by attrMu.
+ c.attrMu.Lock()
+
+ if !c.dirtyAttr.Usage {
+ c.attr.Usage = attr.Usage
+ }
+ if !c.dirtyAttr.Perms {
+ c.attr.Perms = attr.Perms
+ }
+ if !c.dirtyAttr.UID {
+ c.attr.Owner.UID = attr.Owner.UID
+ }
+ if !c.dirtyAttr.GID {
+ c.attr.Owner.GID = attr.Owner.GID
+ }
+ if !c.dirtyAttr.AccessTime {
+ c.attr.AccessTime = attr.AccessTime
+ }
+ if !c.dirtyAttr.ModificationTime {
+ c.attr.ModificationTime = attr.ModificationTime
+ }
+ if !c.dirtyAttr.StatusChangeTime {
+ c.attr.StatusChangeTime = attr.StatusChangeTime
+ }
+ if !c.dirtyAttr.Links {
+ c.attr.Links = attr.Links
+ }
+
+ // Size requires holding attrMu and dataMu.
+ c.dataMu.Lock()
+ if !c.dirtyAttr.Size {
+ c.attr.Size = attr.Size
+ }
+ c.dataMu.Unlock()
+
+ c.attrMu.Unlock()
+}
+
+// Read reads from frames and otherwise directly from the backing file
+// into dst starting at offset until dst is full, EOF is reached, or an
+// error is encountered.
+//
+// Read may partially fill dst and return a nil error.
+func (c *CachingInodeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Have we reached EOF? We check for this again in
+ // inodeReadWriter.ReadToBlocks to avoid holding c.attrMu (which would
+ // serialize reads) or c.dataMu (which would violate lock ordering), but
+ // check here first (before calling into MM) since reading at EOF is
+ // common: getting a return value of 0 from a read syscall is the only way
+ // to detect EOF.
+ //
+ // TODO(jamieliu): Separate out c.attr.Size and use atomics instead of
+ // c.dataMu.
+ c.dataMu.RLock()
+ size := c.attr.Size
+ c.dataMu.RUnlock()
+ if offset >= size {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOutFrom(ctx, &inodeReadWriter{ctx, c, offset})
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ c.TouchAccessTime(ctx, file.Dirent.Inode)
+ return n, err
+}
+
+// Write writes to frames and otherwise directly to the backing file
+// from src starting at offset and until src is empty or an error is
+// encountered.
+//
+// If Write partially fills src, a non-nil error is returned.
+func (c *CachingInodeOperations) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // Hot path. Avoid defers.
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ c.attrMu.Lock()
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() => file_update_time().
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ n, err := src.CopyInTo(ctx, &inodeReadWriter{ctx, c, offset})
+ c.attrMu.Unlock()
+ return n, err
+}
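+
+// Editorial sketch (not part of this change): both paths above take a
+// usermem.IOSequence, so plain byte slices are wrapped with
+// usermem.BytesIOSequence, as the tests later in this diff do.
+//
+//	rbuf := make([]byte, 4096)
+//	n, err := iops.Read(ctx, file, usermem.BytesIOSequence(rbuf), 0)
+//	...
+//	n, err = iops.Write(ctx, usermem.BytesIOSequence([]byte("data")), 0)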
+
+type inodeReadWriter struct {
+ ctx context.Context
+ c *CachingInodeOperations
+ offset int64
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ mem := rw.c.mfp.MemoryFile()
+ fillCache := !rw.c.useHostPageCache() && mem.ShouldCacheEvictable()
+
+ // Hot path. Avoid defers.
+ var unlock func()
+ if fillCache {
+ rw.c.dataMu.Lock()
+ unlock = rw.c.dataMu.Unlock
+ } else {
+ rw.c.dataMu.RLock()
+ unlock = rw.c.dataMu.RUnlock
+ }
+
+ // Compute the range to read.
+ if rw.offset >= rw.c.attr.Size {
+ unlock()
+ return 0, io.EOF
+ }
+ end := fs.ReadEndOffset(rw.offset, int64(dsts.NumBytes()), rw.c.attr.Size)
+ if end == rw.offset { // dsts.NumBytes() == 0?
+ unlock()
+ return 0, nil
+ }
+
+ var done uint64
+ seg, gap := rw.c.cache.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ unlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ gapMR := gap.Range().Intersect(mr)
+ if fillCache {
+ // Read into the cache, then re-enter the loop to read from the
+ // cache.
+ reqMR := memmap.MappableRange{
+ Start: uint64(usermem.Addr(gapMR.Start).RoundDown()),
+ End: fs.OffsetPageEnd(int64(gapMR.End)),
+ }
+ optMR := gap.Range()
+ err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt)
+ mem.MarkEvictable(rw.c, pgalloc.EvictableRange{optMR.Start, optMR.End})
+ seg, gap = rw.c.cache.Find(uint64(rw.offset))
+ if !seg.Ok() {
+ unlock()
+ return done, err
+ }
+ // err might have occurred in part of gap.Range() outside
+ // gapMR. Forget about it for now; if the error matters and
+ // persists, we'll run into it again in a later iteration of
+ // this loop.
+ } else {
+ // Read directly from the backing file.
+ dst := dsts.TakeFirst64(gapMR.Length())
+ n, err := rw.c.backingFile.ReadToBlocksAt(rw.ctx, dst, gapMR.Start)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ // Partial reads are fine. But we must stop reading.
+ if n != dst.NumBytes() || err != nil {
+ unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), FileRangeGapIterator{}
+ }
+
+ default:
+ break
+ }
+ }
+ unlock()
+ return done, nil
+}
+
+// maybeGrowFile grows the file's size if data has been written past the old
+// size.
+//
+// Preconditions: rw.c.attrMu and rw.c.dataMu must be locked.
+func (rw *inodeReadWriter) maybeGrowFile() {
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.offset > rw.c.attr.Size {
+ rw.c.attr.Size = rw.offset
+ rw.c.dirtyAttr.Size = true
+ }
+ if rw.offset > rw.c.attr.Usage {
+ // This is incorrect if CachingInodeOperations is caching a sparse
+ // file. (In Linux, keeping inode::i_blocks up to date is the
+ // filesystem's responsibility.)
+ rw.c.attr.Usage = rw.offset
+ rw.c.dirtyAttr.Usage = true
+ }
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: rw.c.attrMu must be locked.
+func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ // Hot path. Avoid defers.
+ rw.c.dataMu.Lock()
+
+ // Compute the range to write.
+ end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
+ if end == rw.offset { // srcs.NumBytes() == 0?
+ rw.c.dataMu.Unlock()
+ return 0, nil
+ }
+
+ mf := rw.c.mfp.MemoryFile()
+ var done uint64
+ seg, gap := rw.c.cache.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok() && seg.Start() < mr.End:
+ // Get internal mappings from the cache.
+ segMR := seg.Range().Intersect(mr)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ if err != nil {
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, err
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.offset += int64(n)
+ srcs = srcs.DropFirst64(n)
+ rw.c.dirty.MarkDirty(segMR)
+ if err != nil {
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok() && gap.Start() < mr.End:
+ // Write directly to the backing file. At present, we never fill
+ // the cache when writing, since doing so can convert small writes
+ // into inefficient read-modify-write cycles, and we have no
+ // mechanism for detecting or avoiding this.
+ gapmr := gap.Range().Intersect(mr)
+ src := srcs.TakeFirst64(gapmr.Length())
+ n, err := rw.c.backingFile.WriteFromBlocksAt(rw.ctx, src, gapmr.Start)
+ done += n
+ rw.offset += int64(n)
+ srcs = srcs.DropFirst64(n)
+ // Partial writes are fine. But we must stop writing.
+ if n != src.NumBytes() || err != nil {
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), FileRangeGapIterator{}
+
+ default:
+ break
+ }
+ }
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, nil
+}
+
+// useHostPageCache returns true if c uses c.backingFile.FD() for all file I/O
+// and memory mappings, and false if c.cache may contain data cached from
+// c.backingFile.
+func (c *CachingInodeOperations) useHostPageCache() bool {
+ return !c.opts.ForcePageCache && c.backingFile.FD() >= 0
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ // Hot path. Avoid defers.
+ c.mapsMu.Lock()
+ mapped := c.mappings.AddMapping(ms, ar, offset, writable)
+ // Do this unconditionally since whether we have c.backingFile.FD() >= 0
+ // can change across save/restore.
+ for _, r := range mapped {
+ c.hostFileMapper.IncRefOn(r)
+ }
+ if !c.useHostPageCache() {
+ // c.Evict() will refuse to evict memory-mapped pages, so tell the
+ // MemoryFile to not bother trying.
+ mf := c.mfp.MemoryFile()
+ for _, r := range mapped {
+ mf.MarkUnevictable(c, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ }
+ c.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ // Hot path. Avoid defers.
+ c.mapsMu.Lock()
+ unmapped := c.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ c.hostFileMapper.DecRefOn(r)
+ }
+ if c.useHostPageCache() {
+ c.mapsMu.Unlock()
+ return
+ }
+
+ // Pages that are no longer referenced by any application memory mappings
+ // are now considered unused; allow MemoryFile to evict them when
+ // necessary.
+ mf := c.mfp.MemoryFile()
+ c.dataMu.Lock()
+ for _, r := range unmapped {
+ // Since these pages are no longer mapped, they are no longer
+ // concurrently dirtyable by a writable memory mapping.
+ c.dirty.AllowClean(r)
+ mf.MarkEvictable(c, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ c.dataMu.Unlock()
+ c.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return c.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ // Hot path. Avoid defer.
+ if c.useHostPageCache() {
+ mr := optional
+ if c.opts.LimitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: c,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+ }
+
+ c.dataMu.Lock()
+
+ // Constrain translations to c.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(c.attr.Size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ c.dataMu.Unlock()
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := c.mfp.MemoryFile()
+ cerr := c.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, c.backingFile.ReadToBlocksAt)
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := c.cache.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ // TODO(jamieliu): Make Translations writable even if writability is
+ // not required if already kept-dirty by another writable translation.
+ perms := usermem.AccessType{
+ Read: true,
+ Execute: true,
+ }
+ if at.Write {
+ // From this point forward, this memory can be dirtied through the
+ // mapping at any time.
+ c.dirty.KeepDirty(segMR)
+ perms.Write = true
+ }
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: perms,
+ })
+ translatedEnd = segMR.End
+ }
+
+ c.dataMu.Unlock()
+
+ // Don't return the error returned by c.cache.Fill if it occurred outside
+ // of required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange {
+ const maxReadahead = 64 << 10 // 64 KB, chosen arbitrarily
+ if required.Length() >= maxReadahead {
+ return required
+ }
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.Start = required.Start
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.End = optional.Start + maxReadahead
+ return optional
+}
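+
+// Editorial example (not part of this change): a worked trace of the clamping
+// above with maxReadahead = 64 KB (0x10000 bytes):
+//
+//	required = [0x10000, 0x11000)  // 4 KB, below maxReadahead
+//	optional = [0x00000, 0x100000) // 1 MB, above maxReadahead
+//	// optional.Start is rebased to required.Start, then the length is
+//	// clamped to maxReadahead, yielding [0x10000, 0x20000).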
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error {
+ // Whether we have a host fd (and consequently what platform.File is
+ // mapped) can change across save/restore, so invalidate all translations
+ // unconditionally.
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+ c.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Sync the cache's contents so that if we have a host fd after restore,
+ // the remote file's contents are coherent.
+ mf := c.mfp.MemoryFile()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+ if err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), mf, c.backingFile.WriteFromBlocksAt); err != nil {
+ return err
+ }
+
+ // Discard the cache so that it's not stored in saved state. This is safe
+ // because per InvalidateUnsavable invariants, no new translations can have
+ // been returned after we invalidated all existing translations above.
+ c.cache.DropAll(mf)
+ c.dirty.RemoveAll()
+
+ return nil
+}
+
+// NotifyChangeFD must be called after the file description represented by
+// CachedFileObject.FD() changes.
+func (c *CachingInodeOperations) NotifyChangeFD() error {
+ // Update existing sentry mappings to refer to the new file description.
+ if err := c.hostFileMapper.RegenerateMappings(c.backingFile.FD()); err != nil {
+ return err
+ }
+
+ // Shoot down existing application mappings of the old file description;
+ // they will be remapped with the new file description on demand.
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+
+ c.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ return nil
+}
+
+// Evict implements pgalloc.EvictableMemoryUser.Evict.
+func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.EvictableRange) {
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+
+ mr := memmap.MappableRange{er.Start, er.End}
+ mf := c.mfp.MemoryFile()
+ // Only allow pages that are no longer memory-mapped to be evicted.
+ for mgap := c.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() {
+ mgapMR := mgap.Range().Intersect(mr)
+ if mgapMR.Length() == 0 {
+ continue
+ }
+ if err := SyncDirty(ctx, mgapMR, &c.cache, &c.dirty, uint64(c.attr.Size), mf, c.backingFile.WriteFromBlocksAt); err != nil {
+ log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err)
+ }
+ c.cache.Drop(mgapMR, mf)
+ c.dirty.KeepClean(mgapMR)
+ }
+}
+
+// IncRef implements platform.File.IncRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the platform.File
+// during translation.
+func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
+ // Hot path. Avoid defers.
+ c.dataMu.Lock()
+ seg, gap := c.refs.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = c.refs.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ newRange := gap.Range().Intersect(fr)
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
+ seg, gap = c.refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
+ default:
+ c.refs.MergeAdjacent(fr)
+ c.dataMu.Unlock()
+ return
+ }
+ }
+}
+
+// DecRef implements platform.File.DecRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the platform.File
+// during translation.
+func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
+ // Hot path. Avoid defers.
+ c.dataMu.Lock()
+ seg := c.refs.FindSegment(fr.Start)
+
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = c.refs.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
+ seg = c.refs.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ c.refs.MergeAdjacent(fr)
+ c.dataMu.Unlock()
+}
+
+// MapInternal implements platform.File.MapInternal. This is used when we
+// directly map an underlying host fd and CachingInodeOperations is used as the
+// platform.File during translation.
+func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
+}
+
+// FD implements platform.File.FD. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the platform.File
+// during translation.
+func (c *CachingInodeOperations) FD() int {
+ return c.backingFile.FD()
+}
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
new file mode 100644
index 000000000..1547584c5
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -0,0 +1,389 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "bytes"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type noopBackingFile struct{}
+
+func (noopBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ return dsts.NumBytes(), nil
+}
+
+func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ return srcs.NumBytes(), nil
+}
+
+func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
+ return nil
+}
+
+func (noopBackingFile) Sync(context.Context) error {
+ return nil
+}
+
+func (noopBackingFile) FD() int {
+ return -1
+}
+
+func (noopBackingFile) Allocate(ctx context.Context, offset int64, length int64) error {
+ return nil
+}
+
+func TestSetPermissions(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
+ Perms: fs.FilePermsFromMode(0444),
+ })
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
+ defer iops.Release()
+
+ perms := fs.FilePermsFromMode(0777)
+ if !iops.SetPermissions(ctx, nil, perms) {
+ t.Fatalf("SetPermissions failed, want success")
+ }
+
+ // Did permissions change?
+ if iops.attr.Perms != perms {
+ t.Fatalf("got perms +%v, want +%v", iops.attr.Perms, perms)
+ }
+
+ // Did status change time change?
+ if !iops.dirtyAttr.StatusChangeTime {
+ t.Fatalf("got status change time not dirty, want dirty")
+ }
+ if iops.attr.StatusChangeTime.Equal(uattr.StatusChangeTime) {
+ t.Fatalf("got status change time unchanged")
+ }
+}
+
+func TestSetTimestamps(t *testing.T) {
+ ctx := contexttest.Context(t)
+ for _, test := range []struct {
+ desc string
+ ts fs.TimeSpec
+ wantChanged fs.AttrMask
+ }{
+ {
+ desc: "noop",
+ ts: fs.TimeSpec{
+ ATimeOmit: true,
+ MTimeOmit: true,
+ },
+ wantChanged: fs.AttrMask{},
+ },
+ {
+ desc: "access time only",
+ ts: fs.TimeSpec{
+ ATime: ktime.NowFromContext(ctx),
+ MTimeOmit: true,
+ },
+ wantChanged: fs.AttrMask{
+ AccessTime: true,
+ },
+ },
+ {
+ desc: "modification time only",
+ ts: fs.TimeSpec{
+ ATimeOmit: true,
+ MTime: ktime.NowFromContext(ctx),
+ },
+ wantChanged: fs.AttrMask{
+ ModificationTime: true,
+ },
+ },
+ {
+ desc: "access and modification time",
+ ts: fs.TimeSpec{
+ ATime: ktime.NowFromContext(ctx),
+ MTime: ktime.NowFromContext(ctx),
+ },
+ wantChanged: fs.AttrMask{
+ AccessTime: true,
+ ModificationTime: true,
+ },
+ },
+ {
+ desc: "system time access and modification time",
+ ts: fs.TimeSpec{
+ ATimeSetSystemTime: true,
+ MTimeSetSystemTime: true,
+ },
+ wantChanged: fs.AttrMask{
+ AccessTime: true,
+ ModificationTime: true,
+ },
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ epoch := ktime.ZeroTime
+ uattr := fs.UnstableAttr{
+ AccessTime: epoch,
+ ModificationTime: epoch,
+ StatusChangeTime: epoch,
+ }
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
+ defer iops.Release()
+
+ if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil {
+ t.Fatalf("SetTimestamps got error %v, want nil", err)
+ }
+ if test.wantChanged.AccessTime {
+ if !iops.attr.AccessTime.After(uattr.AccessTime) {
+ t.Fatalf("diritied access time did not advance, want %v > %v", iops.attr.AccessTime, uattr.AccessTime)
+ }
+ if !iops.dirtyAttr.StatusChangeTime {
+ t.Fatalf("dirty access time requires dirty status change time")
+ }
+ if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) {
+ t.Fatalf("dirtied status change time did not advance")
+ }
+ }
+ if test.wantChanged.ModificationTime {
+ if !iops.attr.ModificationTime.After(uattr.ModificationTime) {
+ t.Fatalf("diritied modification time did not advance")
+ }
+ if !iops.dirtyAttr.StatusChangeTime {
+ t.Fatalf("dirty modification time requires dirty status change time")
+ }
+ if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) {
+ t.Fatalf("dirtied status change time did not advance")
+ }
+ }
+ })
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ uattr := fs.UnstableAttr{
+ Size: 0,
+ }
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
+ defer iops.Release()
+
+ if err := iops.Truncate(ctx, nil, uattr.Size); err != nil {
+ t.Fatalf("Truncate got error %v, want nil", err)
+ }
+ var size int64 = 4096
+ if err := iops.Truncate(ctx, nil, size); err != nil {
+ t.Fatalf("Truncate got error %v, want nil", err)
+ }
+ if iops.attr.Size != size {
+ t.Fatalf("Truncate got %d, want %d", iops.attr.Size, size)
+ }
+ if !iops.dirtyAttr.ModificationTime || !iops.dirtyAttr.StatusChangeTime {
+ t.Fatalf("Truncate did not dirty modification and status change time")
+ }
+ if !iops.attr.ModificationTime.After(uattr.ModificationTime) {
+ t.Fatalf("dirtied modification time did not change")
+ }
+ if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) {
+ t.Fatalf("dirtied status change time did not change")
+ }
+}
+
+type sliceBackingFile struct {
+ data []byte
+}
+
+func newSliceBackingFile(data []byte) *sliceBackingFile {
+ return &sliceBackingFile{data}
+}
+
+func (f *sliceBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ r := safemem.BlockSeqReader{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)}
+ return r.ReadToBlocks(dsts)
+}
+
+func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ w := safemem.BlockSeqWriter{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)}
+ return w.WriteFromBlocks(srcs)
+}
+
+func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
+ return nil
+}
+
+func (*sliceBackingFile) Sync(context.Context) error {
+ return nil
+}
+
+func (*sliceBackingFile) FD() int {
+ return -1
+}
+
+func (f *sliceBackingFile) Allocate(ctx context.Context, offset int64, length int64) error {
+ return syserror.EOPNOTSUPP
+}
+
+type noopMappingSpace struct{}
+
+// Invalidate implements memmap.MappingSpace.Invalidate.
+func (noopMappingSpace) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
+}
+
+func anonInode(ctx context.Context) *fs.Inode {
+ return fs.NewInode(ctx, &SimpleFileInode{
+ InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }, 0),
+ }, fs.NewPseudoMountSource(ctx), fs.StableAttr{
+ Type: fs.Anonymous,
+ BlockSize: usermem.PageSize,
+ })
+}
+
+func pagesOf(bs ...byte) []byte {
+ buf := make([]byte, 0, len(bs)*usermem.PageSize)
+ for _, b := range bs {
+ buf = append(buf, bytes.Repeat([]byte{b}, usermem.PageSize)...)
+ }
+ return buf
+}
+
+func TestRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ // Construct a 3-page file.
+ buf := pagesOf('a', 'b', 'c')
+ file := fs.NewFile(ctx, fs.NewDirent(ctx, anonInode(ctx), "anon"), fs.FileFlags{}, nil)
+ uattr := fs.UnstableAttr{
+ Size: int64(len(buf)),
+ }
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
+ defer iops.Release()
+
+ // Expect the cache to be initially empty.
+ if cached := iops.cache.Span(); cached != 0 {
+ t.Errorf("Span got %d, want 0", cached)
+ }
+
+ // Create a memory mapping of the second page (as CachingInodeOperations
+ // expects to only cache mapped pages), then call Translate to force it to
+ // be cached.
+ var ms noopMappingSpace
+ ar := usermem.AddrRange{usermem.PageSize, 2 * usermem.PageSize}
+ if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil {
+ t.Fatalf("AddMapping got %v, want nil", err)
+ }
+ mr := memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize}
+ if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil {
+ t.Fatalf("Translate got %v, want nil", err)
+ }
+ if cached := iops.cache.Span(); cached != usermem.PageSize {
+ t.Errorf("SpanRange got %d, want %d", cached, usermem.PageSize)
+ }
+
+ // Try to read 4 pages. The first and third pages should be read directly
+ // from the "file", the second page should be read from the cache, and only
+ // 3 pages (the size of the file) should be readable.
+ rbuf := make([]byte, 4*usermem.PageSize)
+ dst := usermem.BytesIOSequence(rbuf)
+ n, err := iops.Read(ctx, file, dst, 0)
+ if n != 3*usermem.PageSize || (err != nil && err != io.EOF) {
+ t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*usermem.PageSize)
+ }
+ rbuf = rbuf[:3*usermem.PageSize]
+
+ // Did we get the bytes we expect?
+ if !bytes.Equal(rbuf, buf) {
+ t.Errorf("Read back bytes %v, want %v", rbuf, buf)
+ }
+
+ // Delete the memory mapping before iops.Release(). The cached page will
+ // either be evicted by ctx's pgalloc.MemoryFile, or dropped by
+ // iops.Release().
+ iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true)
+}
+
+func TestWrite(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ // Construct a 4-page file.
+ buf := pagesOf('a', 'b', 'c', 'd')
+ orig := append([]byte(nil), buf...)
+ inode := anonInode(ctx)
+ uattr := fs.UnstableAttr{
+ Size: int64(len(buf)),
+ }
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
+ defer iops.Release()
+
+ // Expect the cache to be initially empty.
+ if cached := iops.cache.Span(); cached != 0 {
+ t.Errorf("Span got %d, want 0", cached)
+ }
+
+ // Create a memory mapping of the second and third pages (as
+ // CachingInodeOperations expects to only cache mapped pages), then call
+ // Translate to force them to be cached.
+ var ms noopMappingSpace
+ ar := usermem.AddrRange{usermem.PageSize, 3 * usermem.PageSize}
+ if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil {
+ t.Fatalf("AddMapping got %v, want nil", err)
+ }
+ defer iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true)
+ mr := memmap.MappableRange{usermem.PageSize, 3 * usermem.PageSize}
+ if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil {
+ t.Fatalf("Translate got %v, want nil", err)
+ }
+ if cached := iops.cache.Span(); cached != 2*usermem.PageSize {
+ t.Errorf("SpanRange got %d, want %d", cached, 2*usermem.PageSize)
+ }
+
+ // Write to the first 2 pages.
+ wbuf := pagesOf('e', 'f')
+ src := usermem.BytesIOSequence(wbuf)
+ n, err := iops.Write(ctx, src, 0)
+ if n != 2*usermem.PageSize || err != nil {
+ t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*usermem.PageSize)
+ }
+
+ // The first page should have been written directly, since it was not cached.
+ want := append([]byte(nil), orig...)
+ copy(want, pagesOf('e'))
+ if !bytes.Equal(buf, want) {
+ t.Errorf("File contents are %v, want %v", buf, want)
+ }
+
+ // Sync back to the "backing file".
+ if err := iops.WriteOut(ctx, inode); err != nil {
+ t.Errorf("Sync got %v, want nil", err)
+ }
+
+ // Now the second page should have been written as well.
+ copy(want[usermem.PageSize:], pagesOf('f'))
+ if !bytes.Equal(buf, want) {
+ t.Errorf("File contents are %v, want %v", buf, want)
+ }
+}
diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore
new file mode 100644
index 000000000..2d19fc766
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/.gitignore
@@ -0,0 +1 @@
+*.html
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
new file mode 100644
index 000000000..2ca84dd74
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -0,0 +1,263 @@
+# Foreword
+
+This document describes an ongoing project to support FUSE filesystems within
+the sentry. This is intended to become the final documentation for this
+subsystem, and is therefore written in the past tense. However, FUSE support is
+currently incomplete, and the document will be updated as things progress.
+
+# FUSE: Filesystem in Userspace
+
+The sentry supports dispatching filesystem operations to a FUSE server, allowing
+FUSE filesystems to be used with a sandbox.
+
+## Overview
+
+FUSE has two main components:
+
+1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards
+ filesystem operations (usually initiated by syscalls) to the server.
+
+2. A server, which is a userspace daemon that implements the actual filesystem.
+
+The sentry implements the client component, which allows a server daemon
+running within the sandbox to provide a filesystem to other processes in the
+sandbox.
+
+A FUSE filesystem is initialized with `mount(2)`, typically with the help of a
+utility like `fusermount(1)`. Various mount options exist for establishing
+ownership and access permissions on the filesystem, but the most important mount
+option is a file descriptor used to establish communication between the client
+and server.
+
+The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation,
+the client and server use the FUSE protocol described in `fuse(4)` to service
+filesystem operations. See the "Protocol" section below for more information
+about this protocol. The core of the sentry support for FUSE is the client-side
+implementation of this protocol.
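+
+As a concrete illustration of this handshake, the standalone sketch below opens
+`/dev/fuse` and mounts a FUSE filesystem by passing the FD through the `fd`
+mount option. This is example code, not sentry code; `mountFUSE` is a
+hypothetical helper, and the `rootmode`, `user_id`, and `group_id` options
+shown alongside `fd` are the ones the Linux client expects.
+
+```go
+package main
+
+import (
+	"fmt"
+
+	"golang.org/x/sys/unix"
+)
+
+// mountFUSE opens the FUSE device and mounts a FUSE filesystem at target.
+// The returned FD is the communication channel the server daemon uses to
+// service filesystem requests.
+func mountFUSE(target string) (int, error) {
+	devFD, err := unix.Open("/dev/fuse", unix.O_RDWR, 0)
+	if err != nil {
+		return -1, err
+	}
+	// rootmode=40000 marks the filesystem root as a directory (S_IFDIR).
+	data := fmt.Sprintf("fd=%d,rootmode=40000,user_id=0,group_id=0", devFD)
+	if err := unix.Mount("myfuse", target, "fuse", 0, data); err != nil {
+		unix.Close(devFD)
+		return -1, err
+	}
+	return devFD, nil
+}
+```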
+
+## FUSE in the Sentry
+
+The sentry's FUSE client targets VFS2 and has the following components:
+
+- An implementation of `/dev/fuse`.
+
+- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting
+ VFS2, one point of contention may be the lack of inodes in VFS2. We can
+ tentatively implement a kernfs-based filesystem to bridge the gap in APIs.
+  The kernfs base functionality can serve the role of the Linux inode cache,
+  and the filesystem can map VFS2 syscalls to kernfs inode operations; see
+ the `kernfs.Inode` interface.
+
+The FUSE protocol lends itself well to marshaling with `go_marshal`. The various
+request and response packets can be defined in the ABI package and converted to
+and from the wire format using `go_marshal`.
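+
+For instance, the request header might be declared in the ABI package roughly
+as follows. This is a hypothetical sketch: the type name and the `+marshal`
+directive usage are assumptions, not the final gVisor definitions.
+
+```go
+// FUSEHeaderIn mirrors struct fuse_in_header from the FUSE wire format.
+//
+// +marshal
+type FUSEHeaderIn struct {
+	Len     uint32 // Length of the request, including this header.
+	Opcode  uint32 // Requested operation.
+	Unique  uint64 // Unique identifier for this request.
+	NodeID  uint64 // ID of the filesystem object being operated on.
+	UID     uint32 // UID of the requesting process.
+	GID     uint32 // GID of the requesting process.
+	PID     uint32 // PID of the requesting process.
+	Padding uint32
+}
+```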
+
+### Design Goals
+
+- While filesystem performance is always important, the sentry's FUSE support
+ is primarily concerned with compatibility, with performance as a secondary
+ concern.
+
+- Avoiding deadlocks from a hung server daemon.
+
+- Consider the potential for denial of service from a malicious server daemon.
+ Protecting itself from userspace is already a design goal for the sentry,
+ but needs additional consideration for FUSE. Normally, an operating system
+ doesn't rely on userspace to make progress with filesystem operations. Since
+ this changes with FUSE, it opens up the possibility of creating a chain of
+ dependencies controlled by userspace, which could affect an entire sandbox.
+ For example: a FUSE op can block a syscall, which could be holding a
+ subsystem lock, which can then block another task goroutine.
+
+### Milestones
+
+Below are some broad goals to aim for while implementing FUSE in the sentry.
+Many FUSE ops can be grouped into broad categories of functionality, and most
+ops can be implemented in parallel.
+
+#### Minimal client that can mount a trivial FUSE filesystem
+
+- Implement `/dev/fuse` - a character device used to establish an FD for
+ communication between the sentry and the server daemon.
+
+- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
+
+#### Read-only mount with basic file operations
+
+- Implement the majority of file, directory and file descriptor FUSE ops. For
+ this milestone, we can skip uncommon or complex operations like mmap, mknod,
+ file locking, poll, and extended attributes. We can stub these out along
+  with any ops that modify the filesystem. The exact list of required ops is
+ to be determined, but the goal is to mount a real filesystem as read-only,
+ and be able to read contents from the filesystem in the sentry.
+
+#### Full read-write support
+
+- Implement the remaining FUSE ops and decide if we can omit rarely used
+ operations like ioctl.
+
+# Appendix
+
+## FUSE Protocol
+
+The FUSE protocol is a request-response protocol. All requests are initiated by
+the client. The wire-format for the protocol is raw C structs serialized to
+memory.
+
+All FUSE requests begin with the following request header:
+
+```c
+struct fuse_in_header {
+ uint32_t len; // Length of the request, including this header.
+ uint32_t opcode; // Requested operation.
+ uint64_t unique; // A unique identifier for this request.
+ uint64_t nodeid; // ID of the filesystem object being operated on.
+ uint32_t uid; // UID of the requesting process.
+ uint32_t gid; // GID of the requesting process.
+ uint32_t pid; // PID of the requesting process.
+ uint32_t padding;
+};
+```
+
+The request is then followed by a payload specific to the `opcode`.
+
+All responses begin with this response header:
+
+```c
+struct fuse_out_header {
+ uint32_t len; // Length of the response, including this header.
+ int32_t error; // Status of the request, 0 if success.
+ uint64_t unique; // The unique identifier from the corresponding request.
+};
+```
+
+The response payload also depends on the request `opcode`. If `error != 0`, the
+response payload must be empty.
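+
+To make the framing concrete, here is a minimal, standalone sketch of the
+server side of one exchange: read a request from the FUSE device FD, decode the
+header, and reply with `ENOSYS` for an unsupported opcode. This is illustrative
+only; buffer handling is simplified and a little-endian host is assumed, since
+the wire format is raw in-memory structs.
+
+```go
+package main
+
+import (
+	"encoding/binary"
+
+	"golang.org/x/sys/unix"
+)
+
+// replyENOSYS reads one request from devFD and answers it with -ENOSYS.
+func replyENOSYS(devFD int) error {
+	buf := make([]byte, 1<<20) // Large enough for any request.
+	if _, err := unix.Read(devFD, buf); err != nil {
+		return err
+	}
+	unique := binary.LittleEndian.Uint64(buf[8:16]) // fuse_in_header.unique
+
+	// fuse_out_header: len=16 (header only), error=-ENOSYS, echoed unique.
+	var out [16]byte
+	binary.LittleEndian.PutUint32(out[0:4], 16)
+	binary.LittleEndian.PutUint32(out[4:8], uint32(-int32(unix.ENOSYS)))
+	binary.LittleEndian.PutUint64(out[8:16], unique)
+	_, err := unix.Write(devFD, out[:])
+	return err
+}
+```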
+
+### Operations
+
+The following is a list of all FUSE operations used in `fuse_in_header.opcode`
+as of Linux v4.4, and a brief description of their purpose. These are defined in
+`uapi/linux/fuse.h`. Many of these have a corresponding request and response
+payload struct; `fuse(4)` has details for some of these. We also note how these
+operations map to the sentry virtual filesystem.
+
+#### FUSE meta-operations
+
+These operations are specific to FUSE and don't have a corresponding action in a
+generic filesystem.
+
+- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the
+  first message sent by the client after mount. This is used for version and
+  feature negotiation (a sketch of the negotiation payload appears after this
+  list). This is related to `mount(2)`.
+- `FUSE_DESTROY`: Tears down a FUSE filesystem, related to `umount(2)`.
+- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the
+ `fuse_in_header.unique` value provided in the corresponding request header.
+ The client can send at most one of these per request, and will enter an
+ uninterruptible wait for a reply. The server is expected to reply promptly.
+- `FUSE_FORGET`: A hint to the server that it should evict the indicated
+  node from any caches. This is wired up to `(struct
+  super_operations).evict_inode` in Linux, which is in turn hooked into the
+  inode cache shrinker, typically triggered by system memory pressure.
+- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`.
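+
+As referenced in the `FUSE_INIT` entry above, a hypothetical sketch of the
+negotiation payload (`struct fuse_init_in`) is shown below; the Go names are
+assumptions modeled on `uapi/linux/fuse.h`, not the final gVisor definitions.
+The server replies with a corresponding `fuse_init_out` carrying the agreed
+version and feature set.
+
+```go
+// FUSEInitIn mirrors struct fuse_init_in, the FUSE_INIT request payload.
+//
+// +marshal
+type FUSEInitIn struct {
+	Major        uint32 // Major protocol version supported by the client.
+	Minor        uint32 // Minor protocol version supported by the client.
+	MaxReadahead uint32 // Maximum readahead size the client will use.
+	Flags        uint32 // Feature flags the client supports.
+}
+```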
+
+#### Filesystem Syscalls
+
+These FUSE ops map directly to an equivalent filesystem syscall, or family of
+syscalls. The relevant syscalls have a similar name to the operation, unless
+otherwise noted.
+
+Node creation:
+
+- `FUSE_MKNOD`
+- `FUSE_MKDIR`
+- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which
+ atomically creates and opens a node.
+
+Node attributes and extended attributes:
+
+- `FUSE_GETATTR`
+- `FUSE_SETATTR`
+- `FUSE_SETXATTR`
+- `FUSE_GETXATTR`
+- `FUSE_LISTXATTR`
+- `FUSE_REMOVEXATTR`
+
+Node link manipulation:
+
+- `FUSE_READLINK`
+- `FUSE_LINK`
+- `FUSE_SYMLINK`
+- `FUSE_UNLINK`
+
+Directory operations:
+
+- `FUSE_RMDIR`
+- `FUSE_RENAME`
+- `FUSE_RENAME2`
+- `FUSE_OPENDIR`: `open(2)` for directories.
+- `FUSE_RELEASEDIR`: `close(2)` for directories.
+- `FUSE_READDIR`
+- `FUSE_READDIRPLUS`
+- `FUSE_FSYNCDIR`: `fsync(2)` for directories.
+- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is
+ reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path
+ component to a node. However the returned identifier is opaque to the
+ client. The server must remember this mapping, as this is how the client
+ will reference the node in the future.
+
+File operations:
+
+- `FUSE_OPEN`: `open(2)` for files.
+- `FUSE_RELEASE`: `close(2)` for files.
+- `FUSE_FSYNC`
+- `FUSE_FALLOCATE`
+- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`.
+- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`.
+
+File locking:
+
+- `FUSE_GETLK`
+- `FUSE_SETLK`
+- `FUSE_SETLKW`
+
+File descriptor operations:
+
+- `FUSE_IOCTL`
+- `FUSE_POLL`
+- `FUSE_LSEEK`
+- `FUSE_COPY_FILE_RANGE`
+
+Filesystem operations:
+
+- `FUSE_STATFS`
+
+#### Permissions
+
+- `FUSE_ACCESS` is used to check if a node is accessible, as part of many
+ syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the
+ sentry.
+
+#### I/O Operations
+
+These ops are used to read and write file pages. They back I/O syscalls like
+`read(2)` and `write(2)`, as well as memory-mapped I/O via `mmap(2)`.
+
+- `FUSE_READ`
+- `FUSE_WRITE`
+
+#### Miscellaneous
+
+- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is
+ closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)`
+ syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the
+ sentry.
+- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed.
+- `FUSE_NOTIFY_REPLY`: Reply to a `FUSE_NOTIFY_RETRIEVE` notification, carrying
+  the data the server asked the client to retrieve from its caches. Probably
+  not needed initially.
+
+# References
+
+- [fuse(4) Linux manual page](https://www.man7.org/linux/man-pages/man4/fuse.4.html)
+- [Linux kernel FUSE documentation](https://www.kernel.org/doc/html/latest/filesystems/fuse.html)
+- [The reference implementation of the Linux FUSE (Filesystem in Userspace)
+ interface](https://github.com/libfuse/libfuse)
+- [The kernel interface of FUSE](https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/fuse.h)
diff --git a/pkg/sentry/fs/g3doc/inotify.md b/pkg/sentry/fs/g3doc/inotify.md
new file mode 100644
index 000000000..85063d4e6
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/inotify.md
@@ -0,0 +1,122 @@
+# Inotify
+
+Inotify implements the like-named filesystem event notification system for the
+sentry; see `inotify(7)`.
+
+## Architecture
+
+For the most part, the sentry implementation of inotify mirrors the Linux
+architecture. Inotify instances (i.e. the fd returned by inotify_init(2)) are
+backed by a pseudo-filesystem. Events are generated from various places in the
+sentry, including the [syscall layer][syscall_dir], the [vfs layer][dirent] and
+the [process fd table][fd_table]. Watches are stored in inodes and generated
+events are queued to the inotify instance owning the watches for delivery to the
+user.
+
+## Objects
+
+Here is a brief description of the existing and new objects involved in the
+sentry inotify mechanism, and how they interact:
+
+### [`fs.Inotify`][inotify]
+
+- An inotify instance, created by inotify_init(2)/inotify_init1(2).
+- The inotify fd is backed by a `fs.Dirent` and supports filesystem syscalls
+  to read events.
+- Has multiple `fs.Watch`es, with at most one watch per target inode, per
+ inotify instance.
+- Has an instance `id` which is globally unique. This is *not* the fd number
+ for this instance, since the fd can be duped. This `id` is not externally
+ visible.
+
+### [`fs.Watch`][watch]
+
+- An inotify watch, created/deleted by
+ inotify_add_watch(2)/inotify_rm_watch(2).
+- Owned by an `fs.Inotify` instance, each watch keeps a pointer to the
+ `owner`.
+- Associated with a single `fs.Inode`, which is the watch `target`. While the
+ watch is active, it indirectly pins `target` to memory. See the "Reference
+ Model" section for a detailed explanation.
+- Filesystem operations on `target` generate `fs.Event`s.
+
+### [`fs.Event`][event]
+
+- A simple struct encapsulating all the fields for an inotify event.
+- Generated by `fs.Watch`es and forwarded to the watches' `owner`s.
+- Serialized to the user during read(2) syscalls on the associated
+ `fs.Inotify`'s fd.
+
+### [`fs.Dirent`][dirent]
+
+- Many inotify events are generated inside dirent methods. Events are
+ generated in the dirent methods rather than `fs.Inode` methods because some
+ events carry the name of the subject node, and node names are generally
+ unavailable in an `fs.Inode`.
+- Dirents do not directly contain state for any watches. Instead, they forward
+ notifications to the underlying `fs.Inode`.
+
+### [`fs.Inode`][inode]
+
+- Interacts with inotify through `fs.Watch`es.
+- Inodes contain a map of all active `fs.Watch`es on them.
+- An `fs.Inotify` instance can have at most one `fs.Watch` per inode.
+ `fs.Watch`es on an inode are indexed by their `owner`'s `id`.
+- All inotify logic is encapsulated in the [`Watches`][inode_watches] struct
+ in an inode. Logically, `Watches` is the set of inotify watches on the
+ inode.
+
+## Reference Model
+
+The sentry inotify implementation has a complex reference model. An inotify
+watch observes a single inode. For efficient lookup, the state for a watch is
+stored directly on the target inode. This state needs to persist for the
+lifetime of the watch. Unlike usual filesystem metadata, the watch state has no
+"on-disk" representation, so it cannot be reconstructed by the filesystem if
+the inode is flushed from memory. This effectively means we need to keep any
+inodes with active watches pinned to memory.
+
+We can't just hold an extra ref on the inode to pin it to memory because some
+filesystems (such as gofer-based filesystems) don't have persistent inodes. In
+such a filesystem, if we just pin the inode, nothing prevents the enclosing
+dirent from being GCed. Once the dirent is GCed, the pinned inode is
+unreachable -- these filesystems generate a new inode by re-reading the node
+state on the next walk. Incidentally, hardlinks also don't work on these
+filesystems for this reason.
+
+To prevent the above scenario, when a new watch is added on an inode, we *pin*
+the dirent we used to reach the inode. Note that due to hardlinks, this dirent
+may not be the only dirent pointing to the inode. Attempting to set an inotify
+watch via multiple hardlinks to the same file results in the same watch being
+returned for both links. However, for each new dirent we use to reach the same
+inode, we add a new pin. We need a new pin for each new dirent used to reach the
+inode because we have no guarantees about the deletion order of the different
+links to the inode.
+
+## Lock Ordering
+
+There are 4 locks related to the inotify implementation:
+
+- `Inotify.mu`: the inotify instance lock.
+- `Inotify.evMu`: the inotify event queue lock.
+- `Watch.mu`: the watch lock, used to protect pins.
+- `fs.Watches.mu`: the inode watch set lock, used to protect the collection of
+ watches on the inode.
+
+The correct lock ordering for inotify code is:
+
+`Inotify.mu` -> `fs.Watches.mu` -> `Watch.mu` -> `Inotify.evMu`.
+
+We need a distinct lock for the event queue because by the time a goroutine
+attempts to queue a new event, it is already holding `fs.Watches.mu`. If we used
+`Inotify.mu` to also protect the event queue, this would violate the above lock
+ordering.
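+
+As a minimal illustration, the sketch below uses stand-in types with plain
+mutexes (not the actual sentry structures) to show an event generator taking
+the queue lock last, consistent with the order above:
+
+```go
+package main
+
+import "sync"
+
+// Stand-in types illustrating the documented lock order:
+// Inotify.mu -> fs.Watches.mu -> Watch.mu -> Inotify.evMu.
+type inotify struct {
+	mu   sync.Mutex // instance lock
+	evMu sync.Mutex // event queue lock
+}
+
+type watches struct {
+	mu sync.Mutex // inode watch set lock
+}
+
+// notify queues an event on owner while holding the watch set lock.
+func (w *watches) notify(owner *inotify) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	// evMu is ordered after fs.Watches.mu, so taking it here is safe.
+	// Taking owner.mu instead would invert the documented order and risk
+	// deadlock against a goroutine holding Inotify.mu while acquiring
+	// fs.Watches.mu.
+	owner.evMu.Lock()
+	defer owner.evMu.Unlock()
+	// ... append the event to the owner's queue ...
+}
+```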
+
+[dirent]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/dirent.go
+[event]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_event.go
+[fd_table]: https://github.com/google/gvisor/blob/master/pkg/sentry/kernel/fd_table.go
+[inode]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode.go
+[inode_watches]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode_inotify.go
+[inotify]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify.go
+[syscall_dir]: https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/
+[watch]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_watch.go
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
new file mode 100644
index 000000000..fea135eea
--- /dev/null
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -0,0 +1,67 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gofer",
+ srcs = [
+ "attr.go",
+ "cache_policy.go",
+ "context_file.go",
+ "device.go",
+ "fifo.go",
+ "file.go",
+ "file_state.go",
+ "fs.go",
+ "handles.go",
+ "inode.go",
+ "inode_state.go",
+ "path.go",
+ "session.go",
+ "session_state.go",
+ "socket.go",
+ "util.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/p9",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/secio",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fdpipe",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/host",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/unet",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "gofer_test",
+ size = "small",
+ srcs = ["gofer_test.go"],
+ library = ":gofer",
+ deps = [
+ "//pkg/context",
+ "//pkg/p9",
+ "//pkg/p9/p9test",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ ],
+)
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
new file mode 100644
index 000000000..d481baf77
--- /dev/null
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// getattr returns the 9p attributes of the p9.File. On success, Mode, Size, and RDev
+// are guaranteed to be masked as valid.
+func getattr(ctx context.Context, file contextFile) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ // Retrieve attributes over the wire.
+ qid, valid, attr, err := file.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ return qid, valid, attr, err
+ }
+
+ // Require mode, size, and raw device id.
+ if !valid.Mode || !valid.Size || !valid.RDev {
+ return qid, valid, attr, syscall.EIO
+ }
+
+ return qid, valid, attr, nil
+}
+
+func unstable(ctx context.Context, valid p9.AttrMask, pattr p9.Attr, mounter fs.FileOwner, client *p9.Client) fs.UnstableAttr {
+ return fs.UnstableAttr{
+ Size: int64(pattr.Size),
+ Usage: int64(pattr.Size),
+ Perms: perms(valid, pattr, client),
+ Owner: owner(mounter, valid, pattr),
+ AccessTime: atime(ctx, valid, pattr),
+ ModificationTime: mtime(ctx, valid, pattr),
+ StatusChangeTime: ctime(ctx, valid, pattr),
+ Links: links(valid, pattr),
+ }
+}
+
+func perms(valid p9.AttrMask, pattr p9.Attr, client *p9.Client) fs.FilePermissions {
+ if pattr.Mode.IsDir() && !p9.VersionSupportsMultiUser(client.Version()) {
+		// If user and group permission bits are not supplied, use
+		// "other" bits to supplement them.
+		//
+		// Older gofers' fake directories only have "other" permission bits,
+ // but will often be accessed via user or group permissions.
+ if pattr.Mode&0770 == 0 {
+ other := pattr.Mode & 07
+ pattr.Mode = pattr.Mode | other<<3 | other<<6
+ }
+ }
+ return fs.FilePermsFromP9(pattr.Mode)
+}
+
+func owner(mounter fs.FileOwner, valid p9.AttrMask, pattr p9.Attr) fs.FileOwner {
+ // Unless the file returned its UID and GID, it belongs to the mounting
+ // task's EUID/EGID.
+ owner := mounter
+ if valid.UID {
+ if pattr.UID.Ok() {
+ owner.UID = auth.KUID(pattr.UID)
+ } else {
+ owner.UID = auth.KUID(auth.OverflowUID)
+ }
+ }
+ if valid.GID {
+ if pattr.GID.Ok() {
+ owner.GID = auth.KGID(pattr.GID)
+ } else {
+ owner.GID = auth.KGID(auth.OverflowGID)
+ }
+ }
+ return owner
+}
+
+// bsize returns a block size from 9p attributes.
+func bsize(pattr p9.Attr) int64 {
+ if pattr.BlockSize > 0 {
+ return int64(pattr.BlockSize)
+ }
+ // Some files, particularly those that are not on a local file system,
+ // may have no clue of their block size. Better not to report something
+ // misleading or buggy and have a safe default.
+ return usermem.PageSize
+}
+
+// ntype returns an fs.InodeType from 9p attributes.
+func ntype(pattr p9.Attr) fs.InodeType {
+ switch {
+ case pattr.Mode.IsNamedPipe():
+ return fs.Pipe
+ case pattr.Mode.IsDir():
+ return fs.Directory
+ case pattr.Mode.IsSymlink():
+ return fs.Symlink
+ case pattr.Mode.IsCharacterDevice():
+ return fs.CharacterDevice
+ case pattr.Mode.IsBlockDevice():
+ return fs.BlockDevice
+ case pattr.Mode.IsSocket():
+ return fs.Socket
+ case pattr.Mode.IsRegular():
+ fallthrough
+ default:
+ return fs.RegularFile
+ }
+}
+
+// ctime returns a change time from 9p attributes.
+func ctime(ctx context.Context, valid p9.AttrMask, pattr p9.Attr) ktime.Time {
+ if valid.CTime {
+ return ktime.FromUnix(int64(pattr.CTimeSeconds), int64(pattr.CTimeNanoSeconds))
+ }
+ // Approximate ctime with mtime if ctime isn't available.
+ return mtime(ctx, valid, pattr)
+}
+
+// atime returns an access time from 9p attributes.
+func atime(ctx context.Context, valid p9.AttrMask, pattr p9.Attr) ktime.Time {
+ if valid.ATime {
+ return ktime.FromUnix(int64(pattr.ATimeSeconds), int64(pattr.ATimeNanoSeconds))
+ }
+ return ktime.NowFromContext(ctx)
+}
+
+// mtime returns a modification time from 9p attributes.
+func mtime(ctx context.Context, valid p9.AttrMask, pattr p9.Attr) ktime.Time {
+ if valid.MTime {
+ return ktime.FromUnix(int64(pattr.MTimeSeconds), int64(pattr.MTimeNanoSeconds))
+ }
+ return ktime.NowFromContext(ctx)
+}
+
+// links returns a hard link count from 9p attributes.
+func links(valid p9.AttrMask, pattr p9.Attr) uint64 {
+ // For gofer file systems that support link count (such as a local file gofer),
+ // we return the link count reported by the underlying file system.
+ if valid.NLink {
+ return pattr.NLink
+ }
+
+ // This node is likely backed by a file system that doesn't support links.
+ //
+ // We could readdir() and count children directories to provide an accurate
+ // link count. However this may be expensive since the gofer may be backed by remote
+ // storage. Instead, simply return 2 links for directories and 1 for everything else
+ // since no one relies on an accurate link count for gofer-based file systems.
+ switch ntype(pattr) {
+ case fs.Directory:
+ return 2
+ default:
+ return 1
+ }
+}
diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go
new file mode 100644
index 000000000..07a564e92
--- /dev/null
+++ b/pkg/sentry/fs/gofer/cache_policy.go
@@ -0,0 +1,186 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// cachePolicy is a 9p cache policy. It has methods that determine what to
+// cache (if anything) for a given inode.
+type cachePolicy int
+
+const (
+ // Cache nothing.
+ cacheNone cachePolicy = iota
+
+ // Use virtual file system cache for everything.
+ cacheAll
+
+ // Use virtual file system cache for everything, but send writes to the
+ // fs agent immediately.
+ cacheAllWritethrough
+
+ // Use the (host) page cache for reads/writes, but don't cache anything
+ // else. This allows the sandbox filesystem to stay in sync with any
+ // changes to the remote filesystem.
+ //
+ // This policy should *only* be used with remote filesystems that
+ // donate their host FDs to the sandbox and thus use the host page
+ // cache, otherwise the dirent state will be inconsistent.
+ cacheRemoteRevalidating
+)
+
+// String returns the string name of the cache policy.
+func (cp cachePolicy) String() string {
+ switch cp {
+ case cacheNone:
+ return "cacheNone"
+ case cacheAll:
+ return "cacheAll"
+ case cacheAllWritethrough:
+ return "cacheAllWritethrough"
+ case cacheRemoteRevalidating:
+ return "cacheRemoteRevalidating"
+ default:
+ return "unknown"
+ }
+}
+
+func parseCachePolicy(policy string) (cachePolicy, error) {
+ switch policy {
+ case "fscache":
+ return cacheAll, nil
+ case "none":
+ return cacheNone, nil
+ case "fscache_writethrough":
+ return cacheAllWritethrough, nil
+ case "remote_revalidating":
+ return cacheRemoteRevalidating, nil
+ }
+ return cacheNone, fmt.Errorf("unsupported cache mode: %s", policy)
+}
+
+// cacheUAttrs determines whether unstable attributes should be cached for the
+// given inode.
+func (cp cachePolicy) cacheUAttrs(inode *fs.Inode) bool {
+ if !fs.IsFile(inode.StableAttr) && !fs.IsDir(inode.StableAttr) {
+ return false
+ }
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
+
+// cacheReaddir determines whether readdir results should be cached.
+func (cp cachePolicy) cacheReaddir() bool {
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
+
+// useCachingInodeOps determines whether the page cache should be used for the
+// given inode. If the remote filesystem donates host FDs to the sentry, then
+// the host kernel's page cache will be used, otherwise we will use a
+// sentry-internal page cache.
+func (cp cachePolicy) useCachingInodeOps(inode *fs.Inode) bool {
+ // Do cached IO for regular files only. Some "character devices" expect
+ // no caching.
+ if !fs.IsFile(inode.StableAttr) {
+ return false
+ }
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
+
+// writeThrough indicates whether writes to the file should be synced to the
+// gofer immediately.
+func (cp cachePolicy) writeThrough(inode *fs.Inode) bool {
+ return cp == cacheNone || cp == cacheAllWritethrough
+}
+
+// revalidate revalidates the child Inode if the cache policy allows it.
+//
+// Depending on the cache policy, revalidate will walk from the parent to the
+// child inode, and if any unstable attributes have changed, will update the
+// cached attributes on the child inode. If the walk fails, or the returned
+// inode id is different from the one being revalidated, then the entire Dirent
+// must be reloaded.
+func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child *fs.Inode) bool {
+ if cp == cacheAll || cp == cacheAllWritethrough {
+ return false
+ }
+
+ if cp == cacheNone {
+ return true
+ }
+
+ childIops, ok := child.InodeOperations.(*inodeOperations)
+ if !ok {
+ if _, ok := child.InodeOperations.(*fifo); ok {
+ return false
+ }
+ panic(fmt.Sprintf("revalidating inode operations of unknown type %T", child.InodeOperations))
+ }
+ parentIops, ok := parent.InodeOperations.(*inodeOperations)
+ if !ok {
+ panic(fmt.Sprintf("revalidating inode operations with parent of unknown type %T", parent.InodeOperations))
+ }
+
+ // Walk from parent to child again.
+ //
+ // TODO(b/112031682): If we have a directory FD in the parent
+ // inodeOperations, then we can use fstatat(2) to get the inode
+ // attributes instead of making this RPC.
+ qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name})
+ if err != nil {
+ // Can't look up the name. Trigger reload.
+ return true
+ }
+ f.close(ctx)
+
+	// If the Path has changed, then we are not looking at the same file.
+ // We must reload.
+ if qids[0].Path != childIops.fileState.key.Inode {
+ return true
+ }
+
+ // If we are not caching unstable attrs, then there is nothing to
+ // update on this inode.
+ if !cp.cacheUAttrs(child) {
+ return false
+ }
+
+ // Update the inode's cached unstable attrs.
+ s := childIops.session()
+ childIops.cachingInodeOps.UpdateUnstable(unstable(ctx, mask, attr, s.mounter, s.client))
+
+ return false
+}
+
+// keep indicates that dirents should be kept pinned in the dirent tree even if
+// there are no application references on the file.
+func (cp cachePolicy) keep(d *fs.Dirent) bool {
+ if cp == cacheNone {
+ return false
+ }
+ sattr := d.Inode.StableAttr
+ // NOTE(b/31979197): Only cache files, directories, and symlinks.
+ return fs.IsFile(sattr) || fs.IsDir(sattr) || fs.IsSymlink(sattr)
+}
+
+// cacheNegativeDirents indicates that negative dirents should be held in the
+// dirent tree.
+func (cp cachePolicy) cacheNegativeDirents() bool {
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go
new file mode 100644
index 000000000..125907d70
--- /dev/null
+++ b/pkg/sentry/fs/gofer/context_file.go
@@ -0,0 +1,218 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/p9"
+)
+
+// contextFile is a wrapper around p9.File that notifies the context that
+// it's about to sleep before calling the Gofer over P9.
+type contextFile struct {
+ file p9.File
+}
+
+func (c *contextFile) walk(ctx context.Context, names []string) ([]p9.QID, contextFile, error) {
+ ctx.UninterruptibleSleepStart(false)
+
+ q, f, err := c.file.Walk(names)
+ if err != nil {
+ ctx.UninterruptibleSleepFinish(false)
+ return nil, contextFile{}, err
+ }
+ ctx.UninterruptibleSleepFinish(false)
+ return q, contextFile{file: f}, nil
+}
+
+func (c *contextFile) statFS(ctx context.Context) (p9.FSStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ s, err := c.file.StatFS()
+ ctx.UninterruptibleSleepFinish(false)
+ return s, err
+}
+
+func (c *contextFile) getAttr(ctx context.Context, req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, m, a, err := c.file.GetAttr(req)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, m, a, err
+}
+
+func (c *contextFile) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.SetAttr(valid, attr)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := c.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (c *contextFile) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := c.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
+func (c *contextFile) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Allocate(mode, offset, length)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) rename(ctx context.Context, directory contextFile, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Rename(directory.file, name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) close(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Close()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) open(ctx context.Context, mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ f, q, u, err := c.file.Open(mode)
+ ctx.UninterruptibleSleepFinish(false)
+ return f, q, u, err
+}
+
+func (c *contextFile) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := c.file.ReadAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (c *contextFile) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := c.file.WriteAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (c *contextFile) fsync(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.FSync()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) create(ctx context.Context, name string, flags p9.OpenFlags, permissions p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fd, _, _, _, err := c.file.Create(name, flags, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return fd, err
+}
+
+func (c *contextFile) mkdir(ctx context.Context, name string, permissions p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, err := c.file.Mkdir(name, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, err
+}
+
+func (c *contextFile) symlink(ctx context.Context, oldName string, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, err := c.file.Symlink(oldName, newName, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, err
+}
+
+func (c *contextFile) link(ctx context.Context, target *contextFile, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Link(target.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) mknod(ctx context.Context, name string, permissions p9.FileMode, major uint32, minor uint32, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, err := c.file.Mknod(name, permissions, major, minor, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, err
+}
+
+func (c *contextFile) unlinkAt(ctx context.Context, name string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.UnlinkAt(name, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) readdir(ctx context.Context, offset uint64, count uint32) ([]p9.Dirent, error) {
+ ctx.UninterruptibleSleepStart(false)
+ d, err := c.file.Readdir(offset, count)
+ ctx.UninterruptibleSleepFinish(false)
+ return d, err
+}
+
+func (c *contextFile) readlink(ctx context.Context) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ s, err := c.file.Readlink()
+ ctx.UninterruptibleSleepFinish(false)
+ return s, err
+}
+
+func (c *contextFile) flush(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Flush()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) walkGetAttr(ctx context.Context, names []string) ([]p9.QID, contextFile, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, f, m, a, err := c.file.WalkGetAttr(names)
+ if err != nil {
+ ctx.UninterruptibleSleepFinish(false)
+ return nil, contextFile{}, p9.AttrMask{}, p9.Attr{}, err
+ }
+ ctx.UninterruptibleSleepFinish(false)
+ return q, contextFile{file: f}, m, a, nil
+}
+
+func (c *contextFile) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ f, err := c.file.Connect(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return f, err
+}
diff --git a/pkg/sentry/fs/gofer/device.go b/pkg/sentry/fs/gofer/device.go
new file mode 100644
index 000000000..cbd3c5da2
--- /dev/null
+++ b/pkg/sentry/fs/gofer/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// goferDevice is the gofer virtual device.
+var goferDevice = device.NewAnonMultiDevice()
diff --git a/pkg/sentry/fs/gofer/fifo.go b/pkg/sentry/fs/gofer/fifo.go
new file mode 100644
index 000000000..456557058
--- /dev/null
+++ b/pkg/sentry/fs/gofer/fifo.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// +stateify savable
+type fifo struct {
+ fs.InodeOperations
+ fileIops *inodeOperations
+}
+
+var _ fs.InodeOperations = (*fifo)(nil)
+
+// Rename implements fs.InodeOperations. It forwards the call to the underlying
+// file inode to handle the file rename. Note that file key remains the same
+// after the rename to keep the endpoint mapping.
+func (i *fifo) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return i.fileIops.Rename(ctx, inode, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS implements fs.InodeOperations.
+func (i *fifo) StatFS(ctx context.Context) (fs.Info, error) {
+ return i.fileIops.StatFS(ctx)
+}
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
new file mode 100644
index 000000000..b2fcab127
--- /dev/null
+++ b/pkg/sentry/fs/gofer/file.go
@@ -0,0 +1,369 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "syscall"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ opensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a writable+executable file was opened from a gofer.")
+ opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.")
+ opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.")
+ reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
+ readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
+ readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
+ readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
+)
+
+// fileOperations implements fs.FileOperations for a remote file system.
+//
+// +stateify savable
+type fileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+	fsutil.FileNoSplice `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // inodeOperations is the inodeOperations backing the file. It is protected
+ // by a reference held by File.Dirent.Inode which is stable until
+ // FileOperations.Release is called.
+ inodeOperations *inodeOperations `state:"wait"`
+
+ // dirCursor is the directory cursor.
+ dirCursor string
+
+ // handles are the opened remote file system handles, which may
+ // be shared with other files.
+ handles *handles `state:"nosave"`
+
+ // flags are the flags used to open handles.
+ flags fs.FileFlags `state:"wait"`
+}
+
+// fileOperations implements fs.FileOperations.
+var _ fs.FileOperations = (*fileOperations)(nil)
+
+// NewFile returns a file. NewFile is not appropriate with host pipes and sockets.
+//
+// The `name` argument is only used to log a warning if we are returning a
+// writeable+executable file. (A metric counter is incremented in this case as
+// well.) Note that we cannot call d.BaseName() directly in this function,
+// because that would lead to a lock order violation, since this is called in
+// d.Create which holds d.mu, while d.BaseName() takes d.parent.mu, and the two
+// locks must be taken in the opposite order.
+func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileFlags, i *inodeOperations, handles *handles) *fs.File {
+ // Remote file systems enforce readability/writability at an offset,
+ // see fs/9p/vfs_inode.c:v9fs_vfs_atomic_open -> fs/open.c:finish_open.
+ flags.Pread = true
+ flags.Pwrite = true
+
+ if fs.IsFile(dirent.Inode.StableAttr) {
+ // If cache policy is "remote revalidating", then we must
+ // ensure that we have a host FD. Otherwise, the
+ // sentry-internal page cache will be used, and we can end up
+ // in an inconsistent state if the remote file changes.
+ cp := dirent.Inode.InodeOperations.(*inodeOperations).session().cachePolicy
+ if cp == cacheRemoteRevalidating && handles.Host == nil {
+ panic(fmt.Sprintf("remote-revalidating cache policy requires gofer to donate host FD, but file %q did not have host FD", name))
+ }
+ }
+
+ f := &fileOperations{
+ inodeOperations: i,
+ handles: handles,
+ flags: flags,
+ }
+ if flags.Write {
+ if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil {
+ opensWX.Increment()
+ log.Warningf("Opened a writable executable: %q", name)
+ }
+ }
+ if handles.Host != nil {
+ opensHost.Increment()
+ } else {
+ opens9P.Increment()
+ }
+ return fs.NewFile(ctx, dirent, flags, f)
+}
+
+// Release implements fs.FileOperations.Release.
+func (f *fileOperations) Release() {
+ f.handles.DecRef()
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &f.dirCursor,
+ }
+ n, err := fs.DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+ if f.inodeOperations.session().cachePolicy.cacheUAttrs(file.Dirent.Inode) {
+ f.inodeOperations.cachingInodeOps.TouchAccessTime(ctx, file.Dirent.Inode)
+ }
+ return n, err
+}
+
+// IterateDir implements fs.DirIterator.IterateDir.
+func (f *fileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) {
+ f.inodeOperations.readdirMu.Lock()
+ defer f.inodeOperations.readdirMu.Unlock()
+
+ // Fetch directory entries if needed.
+ if !f.inodeOperations.session().cachePolicy.cacheReaddir() || f.inodeOperations.readdirCache == nil {
+ entries, err := f.readdirAll(ctx)
+ if err != nil {
+ return offset, err
+ }
+
+ // Cache the readdir result.
+ f.inodeOperations.readdirCache = fs.NewSortedDentryMap(entries)
+ }
+
+ // Serialize the entries.
+ n, err := fs.GenericReaddir(dirCtx, f.inodeOperations.readdirCache)
+ return offset + n, err
+}
+
+// readdirAll fetches fs.DentAttrs for all of f's directory entries.
+func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr, error) {
+ entries := make(map[string]fs.DentAttr)
+ var readOffset uint64
+ for {
+ // We choose some arbitrary high number of directory entries (64k) and call
+ // Readdir until we've exhausted them all.
+ dirents, err := f.handles.File.readdir(ctx, readOffset, 64*1024)
+ if err != nil {
+ return nil, err
+ }
+ if len(dirents) == 0 {
+ // We're done, we reached EOF.
+ break
+ }
+
+		// The last dirent contains the offset into the next set of dirents. The gofer
+		// returns the offset as an index into directory entries, not as a byte offset,
+		// because converting a byte offset to an index into directory entries is a
+		// huge pain. But everything is fine if we're consistent.
+ readOffset = dirents[len(dirents)-1].Offset
+
+ for _, dirent := range dirents {
+ if dirent.Name == "." || dirent.Name == ".." {
+ // These must not be included in Readdir results.
+ continue
+ }
+
+ // Find a best approximation of the type.
+ var nt fs.InodeType
+ switch dirent.Type {
+ case p9.TypeDir:
+ nt = fs.Directory
+ case p9.TypeSymlink:
+ nt = fs.Symlink
+ default:
+ nt = fs.RegularFile
+ }
+
+ // Install the DentAttr.
+ entries[dirent.Name] = fs.DentAttr{
+ Type: nt,
+ // Construct the key to find the virtual inode.
+ // Directory entries reside on the same Device
+ // and SecondaryDevice as their parent.
+ InodeID: goferDevice.Map(device.MultiDeviceKey{
+ Device: f.inodeOperations.fileState.key.Device,
+ SecondaryDevice: f.inodeOperations.fileState.key.SecondaryDevice,
+ Inode: dirent.QID.Path,
+ }),
+ }
+ }
+ }
+
+ return entries, nil
+}
+
+// maybeSync will call FSync on the file if either the cache policy or file
+// flags require it.
+func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error {
+ if n == 0 {
+ // Nothing to sync.
+ return nil
+ }
+
+ if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
+ // Call WriteOut directly, as some "writethrough" filesystems
+ // do not support sync.
+ return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
+ }
+
+ flags := file.Flags()
+ var syncType fs.SyncType
+ switch {
+ case flags.Direct || flags.Sync:
+ syncType = fs.SyncAll
+ case flags.DSync:
+ syncType = fs.SyncData
+ default:
+ // No need to sync.
+ return nil
+ }
+
+ return f.Fsync(ctx, file, offset, offset+n, syncType)
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if fs.IsDir(file.Dirent.Inode.StableAttr) {
+ // Not all remote file systems enforce this so this client does.
+ return 0, syserror.EISDIR
+ }
+
+ var (
+ n int64
+ err error
+ )
+ // The write is handled in different ways depending on the cache policy
+ // and availability of a host-mappable FD.
+ if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
+ } else if f.inodeOperations.fileState.hostMappable != nil {
+ n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
+ } else {
+ n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+ }
+
+ // We may need to sync the written bytes.
+ if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil {
+ // Sync failed. Report 0 bytes written, since none of them are
+ // guaranteed to have been synced.
+ return 0, syncErr
+ }
+
+ return n, err
+}
+
+// incrementReadCounters increments the read counters for the read starting at the given time. We
+// use this function rather than using a defer in Read() to avoid the performance hit of defer.
+func (f *fileOperations) incrementReadCounters(start time.Time) {
+ if f.handles.Host != nil {
+ readsHost.Increment()
+ fs.IncrementWait(readWaitHost, start)
+ } else {
+ reads9P.Increment()
+ fs.IncrementWait(readWait9P, start)
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ var start time.Time
+ if fs.RecordWaitTime {
+ start = time.Now()
+ }
+ if fs.IsDir(file.Dirent.Inode.StableAttr) {
+ // Not all remote file systems enforce this so this client does.
+ f.incrementReadCounters(start)
+ return 0, syserror.EISDIR
+ }
+
+ if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ n, err := f.inodeOperations.cachingInodeOps.Read(ctx, file, dst, offset)
+ f.incrementReadCounters(start)
+ return n, err
+ }
+ n, err := dst.CopyOutFrom(ctx, f.handles.readWriterAt(ctx, offset))
+ f.incrementReadCounters(start)
+ return n, err
+}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start, end int64, syncType fs.SyncType) error {
+ switch syncType {
+ case fs.SyncAll, fs.SyncData:
+ if err := file.Dirent.Inode.WriteOut(ctx); err != nil {
+ return err
+ }
+ fallthrough
+ case fs.SyncBackingStorage:
+ // Sync remote caches.
+ if f.handles.Host != nil {
+ // Sync the host fd directly.
+ return syscall.Fsync(f.handles.Host.FD())
+ }
+ // Otherwise sync on the p9.File handle.
+ return f.handles.File.fsync(ctx)
+ }
+ panic("invalid sync type")
+}
+
+// Flush implements fs.FileOperations.Flush.
+func (f *fileOperations) Flush(ctx context.Context, file *fs.File) error {
+ // If this file is not opened writable then there is nothing to flush.
+ // We do this because some p9 server implementations of Flush are
+ // over-zealous.
+ //
+ // FIXME(edahlgren): weaken these implementations and remove this check.
+ if !file.Flags().Write {
+ return nil
+ }
+ // Execute the flush.
+ return f.handles.File.flush(ctx)
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (f *fileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ return f.inodeOperations.configureMMap(file, opts)
+}
+
+// UnstableAttr implements fs.FileOperations.UnstableAttr.
+func (f *fileOperations) UnstableAttr(ctx context.Context, file *fs.File) (fs.UnstableAttr, error) {
+ s := f.inodeOperations.session()
+ if s.cachePolicy.cacheUAttrs(file.Dirent.Inode) {
+ return f.inodeOperations.cachingInodeOps.UnstableAttr(ctx, file.Dirent.Inode)
+ }
+ // Use f.handles.File, which represents 9P fids that have been opened,
+ // instead of inodeFileState.file, which represents 9P fids that have not.
+ // This may be significantly more efficient in some implementations.
+ _, valid, pattr, err := getattr(ctx, f.handles.File)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ return unstable(ctx, valid, pattr, s.mounter, s.client), nil
+}
+
+// Seek implements fs.FileOperations.Seek.
+func (f *fileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return fsutil.SeekWithDirCursor(ctx, file, whence, offset, &f.dirCursor)
+}
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
new file mode 100644
index 000000000..edd6576aa
--- /dev/null
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// afterLoad is invoked by stateify.
+func (f *fileOperations) afterLoad() {
+ load := func() error {
+ f.inodeOperations.fileState.waitForLoad()
+
+ // Manually load the open handles.
+ var err error
+
+ // The file may have been opened with Truncate, but we don't
+ // want to re-open it with Truncate or we will lose data.
+ flags := f.flags
+ flags.Truncate = false
+
+ f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps)
+ if err != nil {
+ return fmt.Errorf("failed to re-open handle: %v", err)
+ }
+ return nil
+ }
+ fs.Async(fs.CatchError(load))
+}
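The flag sanitization in afterLoad matters because replaying O_TRUNC on restore would destroy data written after the original open. A minimal sketch of the idea, with fileFlags standing in for the relevant subset of fs.FileFlags:

package main

import "fmt"

// fileFlags mirrors the subset of fs.FileFlags used here.
type fileFlags struct {
	Read, Write, Truncate bool
}

// restoreOpenFlags returns a copy of f safe for re-opening a handle after
// restore: the truncation already happened at the first open and must not
// be replayed.
func restoreOpenFlags(f fileFlags) fileFlags {
	f.Truncate = false
	return f
}

func main() {
	open := fileFlags{Read: true, Write: true, Truncate: true}
	fmt.Printf("first open: %+v\n", open)
	fmt.Printf("re-open:    %+v\n", restoreOpenFlags(open))
}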
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
new file mode 100644
index 000000000..8ae2d78d7
--- /dev/null
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -0,0 +1,267 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gofer implements a remote 9p filesystem.
+package gofer
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// The following are options defined by the Linux 9p client that we support,
+// see Documentation/filesystems/9p.txt.
+const (
+ // The transport method.
+ transportKey = "trans"
+
+ // The file tree to access when the file server
+ // is exporting several file systems. Stands for "attach name".
+ anameKey = "aname"
+
+ // The caching policy.
+ cacheKey = "cache"
+
+ // The file descriptor for reading with trans=fd.
+ readFDKey = "rfdno"
+
+ // The file descriptor for writing with trans=fd.
+ writeFDKey = "wfdno"
+
+ // The number of bytes to use for a 9p packet payload.
+ msizeKey = "msize"
+
+ // The 9p protocol version.
+ versionKey = "version"
+
+ // If set to true allows the creation of unix domain sockets inside the
+ // sandbox using files backed by the gofer. If set to false, unix sockets
+ // cannot be bound to gofer files without an overlay on top.
+ privateUnixSocketKey = "privateunixsocket"
+
+ // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
+ // true.
+ limitHostFDTranslationKey = "limit_host_fd_translation"
+
+	// overlayfsStaleRead, if present, closes the cached read-only file after
+	// the first write. This works around a limitation of Linux overlayfs.
+ overlayfsStaleRead = "overlayfs_stale_read"
+)
+
+// defaultAname is the default attach name.
+const defaultAname = "/"
+
+// defaultMSize is the message size used for chunking large read and write requests.
+// This has been tested to give good enough performance up to 64M.
+const defaultMSize = 1024 * 1024 // 1M
+
+// defaultVersion is the default 9p protocol version. It will be negotiated
+// downwards with the file server if needed.
+var defaultVersion = p9.HighestVersionString()
+
+// Number of names of non-children to cache, preventing unneeded walks. 64 is
+// plenty for nodejs, which seems to stat about 4 children on every require().
+const nonChildrenCacheSize = 64
+
+var (
+ // ErrNoTransport is returned when there is no 'trans' option.
+ ErrNoTransport = errors.New("missing required option: 'trans='")
+
+ // ErrFileNoReadFD is returned when there is no 'rfdno' option.
+ ErrFileNoReadFD = errors.New("missing required option: 'rfdno='")
+
+ // ErrFileNoWriteFD is returned when there is no 'wfdno' option.
+ ErrFileNoWriteFD = errors.New("missing required option: 'wfdno='")
+)
+
+// filesystem is a 9p client.
+//
+// +stateify savable
+type filesystem struct{}
+
+var _ fs.Filesystem = (*filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+// The name matches fs/9p/vfs_super.c:v9fs_fs_type.name.
+const FilesystemName = "9p"
+
+// Name is the name of the filesystem.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// The 9p Linux client returns FS_RENAME_DOES_D_MOVE, see fs/9p/vfs_super.c.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns an attached 9p client that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // Parse and validate the mount options.
+ o, err := options(data)
+ if err != nil {
+ return nil, err
+ }
+
+ // Construct the 9p root to mount. We intentionally diverge from Linux in that
+ // the first Tversion and Tattach requests are done lazily.
+ return Root(ctx, device, f, flags, o)
+}
+
+// opts are parsed 9p mount options.
+type opts struct {
+ fd int
+ aname string
+ policy cachePolicy
+ msize uint32
+ version string
+ privateunixsocket bool
+ limitHostFDTranslation bool
+ overlayfsStaleRead bool
+}
+
+// options parses mount(2) data into structured options.
+func options(data string) (opts, error) {
+ var o opts
+
+ // Parse generic comma-separated key=value options, this file system expects them.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Check for the required 'trans=fd' option.
+ trans, ok := options[transportKey]
+ if !ok {
+ return o, ErrNoTransport
+ }
+ if trans != "fd" {
+ return o, fmt.Errorf("unsupported transport: 'trans=%s'", trans)
+ }
+ delete(options, transportKey)
+
+ // Check for the required 'rfdno=' option.
+ srfd, ok := options[readFDKey]
+ if !ok {
+ return o, ErrFileNoReadFD
+ }
+ delete(options, readFDKey)
+
+ // Check for the required 'wfdno=' option.
+ swfd, ok := options[writeFDKey]
+ if !ok {
+ return o, ErrFileNoWriteFD
+ }
+ delete(options, writeFDKey)
+
+ // Parse the read fd.
+ rfd, err := strconv.Atoi(srfd)
+ if err != nil {
+ return o, fmt.Errorf("invalid fd for 'rfdno=%s': %v", srfd, err)
+ }
+
+ // Parse the write fd.
+ wfd, err := strconv.Atoi(swfd)
+ if err != nil {
+ return o, fmt.Errorf("invalid fd for 'wfdno=%s': %v", swfd, err)
+ }
+
+ // Require that the read and write fd are the same.
+ if rfd != wfd {
+ return o, fmt.Errorf("fd in 'rfdno=%d' and 'wfdno=%d' must match", rfd, wfd)
+ }
+ o.fd = rfd
+
+ // Parse the attach name.
+ o.aname = defaultAname
+ if an, ok := options[anameKey]; ok {
+ o.aname = an
+ delete(options, anameKey)
+ }
+
+ // Parse the cache policy. Reject unsupported policies.
+ o.policy = cacheAll
+ if policy, ok := options[cacheKey]; ok {
+ cp, err := parseCachePolicy(policy)
+ if err != nil {
+ return o, err
+ }
+ o.policy = cp
+ delete(options, cacheKey)
+ }
+
+ // Parse the message size. Reject malformed options.
+ o.msize = uint32(defaultMSize)
+ if m, ok := options[msizeKey]; ok {
+ i, err := strconv.ParseUint(m, 10, 32)
+ if err != nil {
+ return o, fmt.Errorf("invalid message size for 'msize=%s': %v", m, err)
+ }
+ o.msize = uint32(i)
+ delete(options, msizeKey)
+ }
+
+ // Parse the protocol version.
+ o.version = defaultVersion
+ if v, ok := options[versionKey]; ok {
+ o.version = v
+ delete(options, versionKey)
+ }
+
+ // Parse the unix socket policy. Reject non-booleans.
+ if v, ok := options[privateUnixSocketKey]; ok {
+ b, err := strconv.ParseBool(v)
+ if err != nil {
+ return o, fmt.Errorf("invalid boolean value for '%s=%s': %v", privateUnixSocketKey, v, err)
+ }
+ o.privateunixsocket = b
+ delete(options, privateUnixSocketKey)
+ }
+
+ if _, ok := options[limitHostFDTranslationKey]; ok {
+ o.limitHostFDTranslation = true
+ delete(options, limitHostFDTranslationKey)
+ }
+
+ if _, ok := options[overlayfsStaleRead]; ok {
+ o.overlayfsStaleRead = true
+ delete(options, overlayfsStaleRead)
+ }
+
+ // Fail to attach if the caller wanted us to do something that we
+ // don't support.
+ if len(options) > 0 {
+ return o, fmt.Errorf("unsupported mount options: %v", options)
+ }
+
+ return o, nil
+}
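For reference, a sketch of the comma-separated key=value splitting that fs.GenericMountSourceOptions performs on mount(2) data, applied to a typical 9p option string accepted by options() above. The parseMountData helper is invented for illustration and ignores quoting and escaping:

package main

import (
	"fmt"
	"strings"
)

// parseMountData splits mount data like "k1=v1,k2,k3=v3" into a map,
// treating bare keys (e.g. limit_host_fd_translation) as empty values.
func parseMountData(data string) map[string]string {
	opts := make(map[string]string)
	for _, kv := range strings.Split(data, ",") {
		if kv == "" {
			continue
		}
		parts := strings.SplitN(kv, "=", 2)
		if len(parts) == 2 {
			opts[parts[0]] = parts[1]
		} else {
			opts[parts[0]] = ""
		}
	}
	return opts
}

func main() {
	data := "trans=fd,rfdno=5,wfdno=5,cache=remote_revalidating,msize=1048576"
	fmt.Println(parseMountData(data))
}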
diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go
new file mode 100644
index 000000000..2df2fe889
--- /dev/null
+++ b/pkg/sentry/fs/gofer/gofer_test.go
@@ -0,0 +1,310 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/p9/p9test"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// rootTest runs a test with a p9 mock and an fs.InodeOperations created from
+// the attached root directory. The root file will be closed and client
+// disconnected, but additional files must be closed manually.
+func rootTest(t *testing.T, name string, cp cachePolicy, fn func(context.Context, *p9test.Harness, *p9test.Mock, *fs.Inode)) {
+ t.Run(name, func(t *testing.T) {
+ h, c := p9test.NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root. Note that we pass an empty, but non-nil
+ // map here. This allows tests to extend the root children
+ // dynamically.
+ root := h.NewDirectory(map[string]p9test.Generator{})(nil)
+
+ // Return this as the root.
+ h.Attacher.EXPECT().Attach().Return(root, nil).Times(1)
+
+ // ... and open via the client.
+ rootFile, err := c.Attach("/")
+ if err != nil {
+ t.Fatalf("unable to attach: %v", err)
+ }
+ defer rootFile.Close()
+
+		// Wrap in a session.
+ s := &session{
+ mounter: fs.RootOwner,
+ cachePolicy: cp,
+ client: c,
+ }
+
+		// ... and an Inode, with only the mode being explicitly valid for now.
+ ctx := contexttest.Context(t)
+ sattr, rootInodeOperations := newInodeOperations(ctx, s, contextFile{
+ file: rootFile,
+ }, root.QID, p9.AttrMaskAll(), root.Attr)
+ m := fs.NewMountSource(ctx, s, &filesystem{}, fs.MountSourceFlags{})
+ rootInode := fs.NewInode(ctx, rootInodeOperations, m, sattr)
+
+ // Ensure that the cache is fully invalidated, so that any
+ // close actions actually take place before the full harness is
+ // torn down.
+ defer func() {
+ m.FlushDirentRefs()
+
+ // Wait for all resources to be released, otherwise the
+ // operations may fail after we close the rootFile.
+ fs.AsyncBarrier()
+ }()
+
+ // Execute the test.
+ fn(ctx, h, root, rootInode)
+ })
+}
+
+func TestLookup(t *testing.T) {
+ type lookupTest struct {
+ // Name of the test.
+ name string
+
+ // Expected return value.
+ want error
+ }
+
+ tests := []lookupTest{
+ {
+ name: "mock Walk passes (function succeeds)",
+ want: nil,
+ },
+ {
+ name: "mock Walk fails (function fails)",
+ want: syscall.ENOENT,
+ },
+ }
+
+ const file = "file" // The walked target file.
+
+ for _, test := range tests {
+ rootTest(t, test.name, cacheNone, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) {
+ // Setup the appropriate result.
+ rootFile.WalkCallback = func() error {
+ return test.want
+ }
+ if test.want == nil {
+				// Set the contents of the root. We expect a
+				// normal file generator for the child. This is
+				// overridden by setting WalkCallback on the mock.
+ rootFile.AddChild(file, h.NewFile())
+ }
+
+ // Call function.
+ dirent, err := rootInode.Lookup(ctx, file)
+
+ // Unwrap the InodeOperations.
+ var newInodeOperations fs.InodeOperations
+ if dirent != nil {
+ if dirent.IsNegative() {
+ err = syscall.ENOENT
+ } else {
+ newInodeOperations = dirent.Inode.InodeOperations
+ }
+ }
+
+ // Check return values.
+ if err != test.want {
+ t.Errorf("Lookup got err %v, want %v", err, test.want)
+ }
+ if err == nil && newInodeOperations == nil {
+ t.Errorf("Lookup got non-nil err and non-nil node, wanted at least one non-nil")
+ }
+ })
+ }
+}
+
+func TestRevalidation(t *testing.T) {
+ type revalidationTest struct {
+ cachePolicy cachePolicy
+
+ // Whether dirent should be reloaded before any modifications.
+ preModificationWantReload bool
+
+ // Whether dirent should be reloaded after updating an unstable
+ // attribute on the remote fs.
+ postModificationWantReload bool
+
+ // Whether dirent unstable attributes should be updated after
+ // updating an attribute on the remote fs.
+ postModificationWantUpdatedAttrs bool
+
+ // Whether dirent should be reloaded after the remote has
+ // removed the file.
+ postRemovalWantReload bool
+ }
+
+ tests := []revalidationTest{
+ {
+ // Policy cacheNone causes Revalidate to always return
+ // true.
+ cachePolicy: cacheNone,
+ preModificationWantReload: true,
+ postModificationWantReload: true,
+ postModificationWantUpdatedAttrs: true,
+ postRemovalWantReload: true,
+ },
+ {
+ // Policy cacheAll causes Revalidate to always return
+ // false.
+ cachePolicy: cacheAll,
+ preModificationWantReload: false,
+ postModificationWantReload: false,
+ postModificationWantUpdatedAttrs: false,
+ postRemovalWantReload: false,
+ },
+ {
+ // Policy cacheAllWritethrough causes Revalidate to
+ // always return false.
+ cachePolicy: cacheAllWritethrough,
+ preModificationWantReload: false,
+ postModificationWantReload: false,
+ postModificationWantUpdatedAttrs: false,
+ postRemovalWantReload: false,
+ },
+ {
+			// Policy cacheRemoteRevalidating causes Revalidate
+			// to update cached unstable attrs, and to return
+			// true only when the remote inode itself has been
+			// removed or replaced.
+ cachePolicy: cacheRemoteRevalidating,
+ preModificationWantReload: false,
+ postModificationWantReload: false,
+ postModificationWantUpdatedAttrs: true,
+ postRemovalWantReload: true,
+ },
+ }
+
+ const file = "file" // The file walked below.
+
+ for _, test := range tests {
+ name := fmt.Sprintf("cachepolicy=%s", test.cachePolicy)
+ rootTest(t, name, test.cachePolicy, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) {
+ // Wrap in a dirent object.
+ rootDir := fs.NewDirent(ctx, rootInode, "root")
+
+			// Create a mock file as a child of the root. We save
+			// the mock when it is generated, so that when the time
+			// changes, we can update the original entry.
+ var origMocks []*p9test.Mock
+ rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock {
+				// Generate a regular file that has a consistent
+				// path number. The path might be used by
+				// validation, so we don't change it.
+ m := h.NewMock(parent, 0, p9.Attr{
+ Mode: p9.ModeRegular,
+ })
+ origMocks = append(origMocks, m)
+ return m
+ })
+
+ // Do the walk.
+ dirent, err := rootDir.Walk(ctx, rootDir, file)
+ if err != nil {
+ t.Fatalf("Lookup failed: %v", err)
+ }
+
+			// We must release the dirent, or the test will fail
+			// with a reference leak. This is tracked by p9test.
+ defer dirent.DecRef()
+
+ // Walk again. Depending on the cache policy, we may
+ // get a new dirent.
+ newDirent, err := rootDir.Walk(ctx, rootDir, file)
+ if err != nil {
+ t.Fatalf("Lookup failed: %v", err)
+ }
+ if test.preModificationWantReload && dirent == newDirent {
+ t.Errorf("Lookup with cachePolicy=%s got old dirent %+v, wanted a new dirent", test.cachePolicy, dirent)
+ }
+ if !test.preModificationWantReload && dirent != newDirent {
+ t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent)
+ }
+ newDirent.DecRef() // See above.
+
+ // Modify the underlying mocked file's modification
+ // time for the next walk that occurs.
+ nowSeconds := time.Now().Unix()
+ rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock {
+ // Ensure that the path is the same as above,
+ // but we change only the modification time of
+ // the file.
+ return h.NewMock(parent, 0, p9.Attr{
+ Mode: p9.ModeRegular,
+ MTimeSeconds: uint64(nowSeconds),
+ })
+ })
+
+ // We also modify the original time, so that GetAttr
+ // behaves as expected for the caching case.
+ for _, m := range origMocks {
+ m.Attr.MTimeSeconds = uint64(nowSeconds)
+ }
+
+ // Walk again. Depending on the cache policy, we may
+ // get a new dirent.
+ newDirent, err = rootDir.Walk(ctx, rootDir, file)
+ if err != nil {
+ t.Fatalf("Lookup failed: %v", err)
+ }
+ if test.postModificationWantReload && dirent == newDirent {
+ t.Errorf("Lookup with cachePolicy=%s got old dirent, wanted a new dirent", test.cachePolicy)
+ }
+ if !test.postModificationWantReload && dirent != newDirent {
+ t.Errorf("Lookup with cachePolicy=%s got new dirent, wanted old dirent", test.cachePolicy)
+ }
+ uattrs, err := newDirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ t.Fatalf("Error getting unstable attrs: %v", err)
+ }
+ gotModTimeSeconds := uattrs.ModificationTime.Seconds()
+ if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds {
+ t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds)
+ }
+ newDirent.DecRef() // See above.
+
+ // Remove the file from the remote fs, subsequent walks
+ // should now fail to find anything.
+ rootFile.RemoveChild(file)
+
+ // Walk again. Depending on the cache policy, we may
+ // get ENOENT.
+ newDirent, err = rootDir.Walk(ctx, rootDir, file)
+ if test.postRemovalWantReload && err == nil {
+ t.Errorf("Lookup with cachePolicy=%s got nil error, wanted ENOENT", test.cachePolicy)
+ }
+ if !test.postRemovalWantReload && (err != nil || dirent != newDirent) {
+ t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err)
+ }
+ if err == nil {
+ newDirent.DecRef() // See above.
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
new file mode 100644
index 000000000..fc14249be
--- /dev/null
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -0,0 +1,140 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/secio"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// handles are the open handles of a gofer file. They are reference counted to
+// support open-handle sharing between files on read-only filesystems.
+//
+// If Host != nil then it will be used exclusively over File.
+type handles struct {
+ refs.AtomicRefCount
+
+ // File is a p9.File handle. Must not be nil.
+ File contextFile
+
+ // Host is an *fd.FD handle. May be nil.
+ Host *fd.FD
+
+ // isHostBorrowed tells whether 'Host' is owned or borrowed. If owned, it's
+ // closed on destruction, otherwise it's released.
+ isHostBorrowed bool
+}
+
+// DecRef drops a reference on handles.
+func (h *handles) DecRef() {
+ h.DecRefWithDestructor(func() {
+ if h.Host != nil {
+ if h.isHostBorrowed {
+ h.Host.Release()
+ } else {
+ if err := h.Host.Close(); err != nil {
+ log.Warningf("error closing host file: %v", err)
+ }
+ }
+ }
+ if err := h.File.close(context.Background()); err != nil {
+ log.Warningf("error closing p9 file: %v", err)
+ }
+ })
+}
+
+func newHandles(ctx context.Context, client *p9.Client, file contextFile, flags fs.FileFlags) (*handles, error) {
+ _, newFile, err := file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ var p9flags p9.OpenFlags
+ switch {
+ case flags.Read && flags.Write:
+ p9flags = p9.ReadWrite
+ case flags.Read && !flags.Write:
+ p9flags = p9.ReadOnly
+ case !flags.Read && flags.Write:
+ p9flags = p9.WriteOnly
+ default:
+ panic("impossible fs.FileFlags")
+ }
+ if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(client.Version()) {
+ p9flags |= p9.OpenTruncate
+ }
+
+ hostFile, _, _, err := newFile.open(ctx, p9flags)
+ if err != nil {
+ newFile.close(ctx)
+ return nil, err
+ }
+ h := handles{
+ File: newFile,
+ Host: hostFile,
+ }
+ h.EnableLeakCheck("gofer.handles")
+ return &h, nil
+}
+
+type handleReadWriter struct {
+ ctx context.Context
+ h *handles
+ off int64
+}
+
+func (h *handles) readWriterAt(ctx context.Context, offset int64) *handleReadWriter {
+ return &handleReadWriter{ctx, h, offset}
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *handleReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ var r io.Reader
+ if rw.h.Host != nil {
+ r = secio.NewOffsetReader(rw.h.Host, rw.off)
+ } else {
+ r = &p9.ReadWriterFile{File: rw.h.File.file, Offset: uint64(rw.off)}
+ }
+
+ rw.ctx.UninterruptibleSleepStart(false)
+ defer rw.ctx.UninterruptibleSleepFinish(false)
+ n, err := safemem.FromIOReader{r}.ReadToBlocks(dsts)
+ rw.off += int64(n)
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *handleReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ var w io.Writer
+ if rw.h.Host != nil {
+ w = secio.NewOffsetWriter(rw.h.Host, rw.off)
+ } else {
+ w = &p9.ReadWriterFile{File: rw.h.File.file, Offset: uint64(rw.off)}
+ }
+
+ rw.ctx.UninterruptibleSleepStart(false)
+ defer rw.ctx.UninterruptibleSleepFinish(false)
+ n, err := safemem.FromIOWriter{w}.WriteFromBlocks(srcs)
+ rw.off += int64(n)
+ return n, err
+}
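handleReadWriter is an instance of the stateful-offset adapter pattern: each call reads or writes at the saved offset and then advances it, which is what secio.NewOffsetReader/NewOffsetWriter provide for host FDs. A self-contained sketch against a plain io.ReaderAt:

package main

import (
	"fmt"
	"io"
	"strings"
)

// offsetReader adapts an io.ReaderAt into an io.Reader that starts at off
// and advances as data is consumed, like handleReadWriter does.
type offsetReader struct {
	r   io.ReaderAt
	off int64
}

func (o *offsetReader) Read(p []byte) (int, error) {
	n, err := o.r.ReadAt(p, o.off)
	o.off += int64(n)
	return n, err
}

func main() {
	src := strings.NewReader("hello, gofer")
	r := &offsetReader{r: src, off: 7}
	out, _ := io.ReadAll(r)
	fmt.Println(string(out)) // "gofer"
}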
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
new file mode 100644
index 000000000..51d7368a1
--- /dev/null
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -0,0 +1,719 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "errors"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// inodeOperations implements fs.InodeOperations.
+//
+// +stateify savable
+type inodeOperations struct {
+ fsutil.InodeNotVirtual `state:"nosave"`
+
+ // fileState implements fs.CachedFileObject. It exists
+ // to break a circular load dependency between inodeOperations
+ // and cachingInodeOps (below).
+ fileState *inodeFileState `state:"wait"`
+
+ // cachingInodeOps implement memmap.Mappable for inodeOperations.
+ cachingInodeOps *fsutil.CachingInodeOperations
+
+ // readdirMu protects readdirCache and concurrent Readdirs.
+ readdirMu sync.Mutex `state:"nosave"`
+
+ // readdirCache is a cache of readdir results in the form of
+ // a fs.SortedDentryMap.
+ //
+ // Starts out as nil, and is initialized under readdirMu lazily;
+ // invalidating the cache means setting it to nil.
+ readdirCache *fs.SortedDentryMap `state:"nosave"`
+}
+
+// inodeFileState implements fs.CachedFileObject and otherwise fully
+// encapsulates state that needs to be manually loaded on restore for
+// this file object.
+//
+// This unfortunate structure exists because fs.CachingInodeOperations
+// defines afterLoad and therefore cannot be lazily loaded (to break a
+// circular load dependency between it and inodeOperations). Even with
+// lazy loading, this approach defines the dependencies between objects
+// and the expected load behavior more concretely.
+//
+// +stateify savable
+type inodeFileState struct {
+ // s is common file system state for Gofers.
+ s *session `state:"wait"`
+
+ // MultiDeviceKey consists of:
+ //
+ // * Device: file system device from a specific gofer.
+ // * SecondaryDevice: unique identifier of the attach point.
+	// * Inode: the inode of this resource, unique per Device.
+ //
+ // These fields combined enable consistent hashing of virtual inodes
+ // on goferDevice.
+ key device.MultiDeviceKey `state:"nosave"`
+
+ // file is the p9 file that contains a single unopened fid.
+ file contextFile `state:"nosave"`
+
+ // sattr caches the stable attributes.
+ sattr fs.StableAttr `state:"wait"`
+
+ // handlesMu protects the below fields.
+ handlesMu sync.RWMutex `state:"nosave"`
+
+ // If readHandles is non-nil, it holds handles that are either read-only or
+ // read/write. If writeHandles is non-nil, it holds write-only handles if
+ // writeHandlesRW is false, and read/write handles if writeHandlesRW is
+ // true.
+ //
+ // Once readHandles becomes non-nil, it can't be changed until
+ // inodeFileState.Release()*, because of a defect in the
+ // fsutil.CachedFileObject interface: there's no way for the caller of
+ // fsutil.CachedFileObject.FD() to keep the returned FD open, so if we
+ // racily replace readHandles after inodeFileState.FD() has returned
+ // readHandles.Host.FD(), fsutil.CachingInodeOperations may use a closed
+ // FD. writeHandles can be changed if writeHandlesRW is false, since
+ // inodeFileState.FD() can't return a write-only FD, but can't be changed
+ // if writeHandlesRW is true for the same reason.
+ //
+ // * There is one notable exception in recreateReadHandles(), where it dup's
+ // the FD and invalidates the page cache.
+ readHandles *handles `state:"nosave"`
+ writeHandles *handles `state:"nosave"`
+ writeHandlesRW bool `state:"nosave"`
+
+ // loading is acquired when the inodeFileState begins an asynchronous
+ // load. It releases when the load is complete. Callers that require all
+ // state to be available should call waitForLoad() to ensure that.
+ loading sync.Mutex `state:".(struct{})"`
+
+ // savedUAttr is only allocated during S/R. It points to the save-time
+ // unstable attributes and is used to validate restore-time ones.
+ //
+ // Note that these unstable attributes are only used to detect cross-S/R
+ // external file system metadata changes. They may differ from the
+ // cached unstable attributes in cachingInodeOps, as that might differ
+ // from the external file system attributes if there had been WriteOut
+ // failures. S/R is transparent to Sentry and the latter will continue
+ // using its cached values after restore.
+ savedUAttr *fs.UnstableAttr
+
+ // hostMappable is created when using 'cacheRemoteRevalidating' to map pages
+ // directly from host.
+ hostMappable *fsutil.HostMappable
+}
+
+// Release releases file handles.
+func (i *inodeFileState) Release(ctx context.Context) {
+ i.file.close(ctx)
+ if i.readHandles != nil {
+ i.readHandles.DecRef()
+ }
+ if i.writeHandles != nil {
+ i.writeHandles.DecRef()
+ }
+}
+
+func (i *inodeFileState) canShareHandles() bool {
+ // Only share handles for regular files, since for other file types,
+ // distinct handles may have special semantics even if they represent the
+ // same file. Disable handle sharing for cache policy cacheNone, since this
+ // is legacy behavior.
+ return fs.IsFile(i.sattr) && i.s.cachePolicy != cacheNone
+}
+
+// Preconditions: i.handlesMu must be locked for writing.
+func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles) {
+ if flags.Read && i.readHandles == nil {
+ h.IncRef()
+ i.readHandles = h
+ }
+ if flags.Write {
+ if i.writeHandles == nil {
+ h.IncRef()
+ i.writeHandles = h
+ i.writeHandlesRW = flags.Read
+ } else if !i.writeHandlesRW && flags.Read {
+ // Upgrade i.writeHandles.
+ i.writeHandles.DecRef()
+ h.IncRef()
+ i.writeHandles = h
+ i.writeHandlesRW = flags.Read
+ }
+ }
+}
+
+// getHandles returns a set of handles for a new file using i opened with the
+// given flags.
+func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) {
+ if !i.canShareHandles() {
+ return newHandles(ctx, i.s.client, i.file, flags)
+ }
+
+ i.handlesMu.Lock()
+ h, invalidate, err := i.getHandlesLocked(ctx, flags)
+ i.handlesMu.Unlock()
+
+ if invalidate {
+ cache.NotifyChangeFD()
+ if i.hostMappable != nil {
+ i.hostMappable.NotifyChangeFD()
+ }
+ }
+
+ return h, err
+}
+
+// getHandlesLocked returns a pointer to cached handles and a boolean indicating
+// whether the previously open read handle was recreated. Host mappings must be
+// invalidated if so.
+func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) {
+ // Check if we are able to use cached handles.
+ if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(i.s.client.Version()) {
+ // If we are truncating (and the gofer supports it), then we
+ // always need a new handle. Don't return one from the cache.
+ } else if flags.Write {
+ if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) {
+ // File is opened for writing, and we have cached write
+ // handles that we can use.
+ i.writeHandles.IncRef()
+ return i.writeHandles, false, nil
+ }
+ } else if i.readHandles != nil {
+ // File is opened for reading and we have cached handles.
+ i.readHandles.IncRef()
+ return i.readHandles, false, nil
+ }
+
+ // Get new handles and cache them for future sharing.
+ h, err := newHandles(ctx, i.s.client, i.file, flags)
+ if err != nil {
+ return nil, false, err
+ }
+
+ // Read handles invalidation is needed if:
+ // - Mount option 'overlayfs_stale_read' is set
+ // - Read handle is open: nothing to invalidate otherwise
+	// - Write handle is not open: the file was not open for write and is being
+	//   opened for write now (this will trigger copy-up in overlayfs).
+ invalidate := false
+ if i.s.overlayfsStaleRead && i.readHandles != nil && i.writeHandles == nil && flags.Write {
+ if err := i.recreateReadHandles(ctx, h, flags); err != nil {
+ return nil, false, err
+ }
+ invalidate = true
+ }
+ i.setSharedHandlesLocked(flags, h)
+ return h, invalidate, nil
+}
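The cache-reuse rules above reduce to a small decision function. A simplified sketch with invented names; the real code additionally requires gofer support for OpenTruncate before forcing a fresh handle, and handles the overlayfs invalidation path:

package main

import "fmt"

// cacheState is a simplified view of inodeFileState's handle cache.
type cacheState struct {
	haveRead  bool // readHandles != nil
	haveWrite bool // writeHandles != nil
	writeIsRW bool // writeHandlesRW
}

// useCached reports whether an open with the given flags may reuse a
// cached handle, mirroring getHandlesLocked above.
func useCached(s cacheState, read, write, truncate bool) bool {
	if truncate {
		return false // truncating opens need a fresh handle
	}
	if write {
		return s.haveWrite && (s.writeIsRW || !read)
	}
	return s.haveRead
}

func main() {
	s := cacheState{haveRead: true, haveWrite: true, writeIsRW: false}
	fmt.Println(useCached(s, false, true, false)) // true: write-only reuse
	fmt.Println(useCached(s, true, true, false))  // false: needs RW upgrade
	fmt.Println(useCached(s, true, false, false)) // true: read reuse
}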
+
+func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handles, flags fs.FileFlags) error {
+ h := writer
+ if !flags.Read {
+ // Writer can't be used for read, must create a new handle.
+ var err error
+ h, err = newHandles(ctx, i.s.client, i.file, fs.FileFlags{Read: true})
+ if err != nil {
+ return err
+ }
+ defer h.DecRef()
+ }
+
+ if i.readHandles.Host == nil {
+ // If current readHandles doesn't have a host FD, it can simply be replaced.
+ i.readHandles.DecRef()
+
+ h.IncRef()
+ i.readHandles = h
+ return nil
+ }
+
+ if h.Host == nil {
+ // Current read handle has a host FD and can't be replaced with one that
+ // doesn't, because it breaks fsutil.CachedFileObject.FD() contract.
+ log.Warningf("Read handle can't be invalidated, reads may return stale data")
+ return nil
+ }
+
+	// Due to a defect in the fsutil.CachedFileObject interface,
+	// readHandles.Host.FD() may be used outside locks, making it impossible to
+	// reliably close it. To work around this, we dup the new FD into the old
+	// one, so operations on the old FD will see the new data. Then, we make the
+	// new handle take ownership of the old FD and mark the old readHandles to
+	// not close the FD when done.
+ if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), syscall.O_CLOEXEC); err != nil {
+ return err
+ }
+
+ h.Host.Close()
+ h.Host = fd.New(i.readHandles.Host.FD())
+ i.readHandles.isHostBorrowed = true
+ i.readHandles.DecRef()
+
+ h.IncRef()
+ i.readHandles = h
+ return nil
+}
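The dup trick above can be demonstrated in isolation: duplicating a new FD over an old FD number makes every existing holder of that number see the new file. A Linux-only sketch (syscall.Dup3 and syscall.Pread are Linux-specific; file contents are illustrative):

package main

import (
	"fmt"
	"os"
	"syscall"
)

func main() {
	oldF, err := os.CreateTemp("", "old")
	if err != nil {
		panic(err)
	}
	newF, err := os.CreateTemp("", "new")
	if err != nil {
		panic(err)
	}
	newF.WriteString("fresh contents")

	// Dup the new FD over the old FD number: anything still reading
	// through the old number now sees the new file.
	if err := syscall.Dup3(int(newF.Fd()), int(oldF.Fd()), syscall.O_CLOEXEC); err != nil {
		panic(err)
	}
	buf := make([]byte, 64)
	n, _ := syscall.Pread(int(oldF.Fd()), buf, 0)
	fmt.Printf("%q\n", buf[:n]) // "fresh contents"
}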
+
+// ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt.
+func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ i.handlesMu.RLock()
+ n, err := i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts)
+ i.handlesMu.RUnlock()
+ return n, err
+}
+
+// WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt.
+func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ i.handlesMu.RLock()
+ n, err := i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs)
+ i.handlesMu.RUnlock()
+ return n, err
+}
+
+// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error {
+ if i.skipSetAttr(mask, forceSetTimestamps) {
+ return nil
+ }
+ as, ans := attr.AccessTime.Unix()
+ ms, mns := attr.ModificationTime.Unix()
+ // An update of status change time is implied by mask.AccessTime
+ // or mask.ModificationTime. Updating status change time to a
+ // time earlier than the system time is not possible.
+ return i.file.setAttr(
+ ctx,
+ p9.SetAttrMask{
+ Permissions: mask.Perms,
+ Size: mask.Size,
+ UID: mask.UID,
+ GID: mask.GID,
+ ATime: mask.AccessTime,
+ ATimeNotSystemTime: true,
+ MTime: mask.ModificationTime,
+ MTimeNotSystemTime: true,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(attr.Perms.LinuxMode()),
+ UID: p9.UID(attr.Owner.UID),
+ GID: p9.GID(attr.Owner.GID),
+ Size: uint64(attr.Size),
+ ATimeSeconds: uint64(as),
+ ATimeNanoSeconds: uint64(ans),
+ MTimeSeconds: uint64(ms),
+ MTimeNanoSeconds: uint64(mns),
+ })
+}
+
+// skipSetAttr checks if attribute change can be skipped. It can be skipped
+// when:
+// - Mask is empty
+// - Mask contains only attributes that cannot be set in the gofer
+// - forceSetTimestamps is false and mask contains only atime and/or mtime
+// and host FD exists
+//
+// Updates to atime and mtime can be skipped because cached value will be
+// "close enough" to host value, given that operation went directly to host FD.
+// Skipping atime updates is particularly important to reduce the number of
+// operations sent to the Gofer for readonly files.
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool {
+ // First remove attributes that cannot be updated.
+ cpy := mask
+ cpy.Type = false
+ cpy.DeviceID = false
+ cpy.InodeID = false
+ cpy.BlockSize = false
+ cpy.Usage = false
+ cpy.Links = false
+ if cpy.Empty() {
+ return true
+ }
+
+ // Then check if more than just atime and mtime is being set.
+ cpy.AccessTime = false
+ cpy.ModificationTime = false
+ if !cpy.Empty() {
+ return false
+ }
+
+ // If forceSetTimestamps was passed, then we cannot skip.
+ if forceSetTimestamps {
+ return false
+ }
+
+ // Skip if we have a host FD.
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ return (i.readHandles != nil && i.readHandles.Host != nil) ||
+ (i.writeHandles != nil && i.writeHandles.Host != nil)
+}
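A standalone, simplified version of the skip decision above, useful for checking the three bullet cases; the mask type is trimmed to the fields that matter here:

package main

import "fmt"

type attrMask struct {
	Perms, Size, UID, GID, AccessTime, ModificationTime bool
}

// skippable mirrors skipSetAttr: empty masks and timestamp-only masks
// (when not forced and a host FD exists) can be skipped.
func skippable(m attrMask, forceSetTimestamps, haveHostFD bool) bool {
	if m == (attrMask{}) {
		return true
	}
	if m.Perms || m.Size || m.UID || m.GID {
		return false
	}
	if forceSetTimestamps {
		return false
	}
	return haveHostFD
}

func main() {
	fmt.Println(skippable(attrMask{}, false, false))                // true: empty
	fmt.Println(skippable(attrMask{AccessTime: true}, false, true)) // true: atime + host FD
	fmt.Println(skippable(attrMask{AccessTime: true}, true, true))  // false: forced
	fmt.Println(skippable(attrMask{Size: true}, false, true))       // false: real change
}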
+
+// Sync implements fsutil.CachedFileObject.Sync.
+func (i *inodeFileState) Sync(ctx context.Context) error {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ if i.writeHandles == nil {
+ return nil
+ }
+ return i.writeHandles.File.fsync(ctx)
+}
+
+// FD implements fsutil.CachedFileObject.FD.
+func (i *inodeFileState) FD() int {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ if i.writeHandlesRW && i.writeHandles != nil && i.writeHandles.Host != nil {
+ return int(i.writeHandles.Host.FD())
+ }
+ if i.readHandles != nil && i.readHandles.Host != nil {
+ return int(i.readHandles.Host.FD())
+ }
+ return -1
+}
+
+// waitForLoad makes sure any restore-issued loading is done.
+func (i *inodeFileState) waitForLoad() {
+	// This is not a no-op. The loading mutex is held upon restore until
+ // all loading actions are done.
+ i.loading.Lock()
+ i.loading.Unlock()
+}
+
+func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, error) {
+ _, valid, pattr, err := getattr(ctx, i.file)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ return unstable(ctx, valid, pattr, i.s.mounter, i.s.client), nil
+}
+
+func (i *inodeFileState) Allocate(ctx context.Context, offset, length int64) error {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+
+ // No options are supported for now.
+ mode := p9.AllocateMode{}
+ return i.writeHandles.File.allocate(ctx, mode, uint64(offset), uint64(length))
+}
+
+// session extracts the gofer's session from the MountSource.
+func (i *inodeOperations) session() *session {
+ return i.fileState.s
+}
+
+// Release implements fs.InodeOperations.Release.
+func (i *inodeOperations) Release(ctx context.Context) {
+ i.cachingInodeOps.Release()
+
+ // Releasing the fileState may make RPCs to the gofer. There is
+ // no need to wait for those to return, so we can do this
+ // asynchronously.
+ //
+ // We use AsyncWithContext to avoid needing to allocate an extra
+ // anonymous function on the heap.
+ fs.AsyncWithContext(ctx, i.fileState.Release)
+}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (i *inodeOperations) Mappable(inode *fs.Inode) memmap.Mappable {
+ if i.session().cachePolicy.useCachingInodeOps(inode) {
+ return i.cachingInodeOps
+ }
+	// This check is necessary because Mappable returns an interface type: a
+	// nil *fsutil.HostMappable would otherwise become a non-nil interface value.
+ if i.fileState.hostMappable != nil {
+ return i.fileState.hostMappable
+ }
+ return nil
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *inodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.UnstableAttr(ctx, inode)
+ }
+ return i.fileState.unstableAttr(ctx)
+}
+
+// Check implements fs.InodeOperations.Check.
+func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ switch d.Inode.StableAttr.Type {
+ case fs.Socket:
+ return i.getFileSocket(ctx, d, flags)
+ case fs.Pipe:
+ return i.getFilePipe(ctx, d, flags)
+ default:
+ return i.getFileDefault(ctx, d, flags)
+ }
+}
+
+func (i *inodeOperations) getFileSocket(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ f, err := i.fileState.file.connect(ctx, p9.AnonymousSocket)
+ if err != nil {
+ return nil, syscall.EIO
+ }
+ fsf, err := host.NewSocketWithDirent(ctx, d, f, flags)
+ if err != nil {
+ f.Close()
+ return nil, err
+ }
+ return fsf, nil
+}
+
+func (i *inodeOperations) getFilePipe(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ // Try to open as a host pipe; if that doesn't work, handle it normally.
+ pipeOps, err := fdpipe.Open(ctx, i, flags)
+ if err == errNotHostFile {
+ return i.getFileDefault(ctx, d, flags)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return fs.NewFile(ctx, d, flags, pipeOps), nil
+}
+
+// errNotHostFile indicates that the file is not a host file.
+var errNotHostFile = errors.New("not a host file")
+
+// NonBlockingOpen implements fdpipe.NonBlockingOpener for opening host named pipes.
+func (i *inodeOperations) NonBlockingOpen(ctx context.Context, p fs.PermMask) (*fd.FD, error) {
+ i.fileState.waitForLoad()
+
+ // Get a cloned fid which we will open.
+ _, newFile, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ log.Warningf("Open Walk failed: %v", err)
+ return nil, err
+ }
+ defer newFile.close(ctx)
+
+ flags, err := openFlagsFromPerms(p)
+ if err != nil {
+ log.Warningf("Open flags %s parsing failed: %v", p, err)
+ return nil, err
+ }
+ hostFile, _, _, err := newFile.open(ctx, flags)
+ // If the host file returned is nil and the error is nil,
+ // then this was never a host file to begin with, and should
+ // be treated like a remote file.
+ if hostFile == nil && err == nil {
+ return nil, errNotHostFile
+ }
+ return hostFile, err
+}
+
+func (i *inodeOperations) getFileDefault(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ h, err := i.fileState.getHandles(ctx, flags, i.cachingInodeOps)
+ if err != nil {
+ return nil, err
+ }
+ return NewFile(ctx, d, d.BaseName(), flags, i, h), nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (i *inodeOperations) SetPermissions(ctx context.Context, inode *fs.Inode, p fs.FilePermissions) bool {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.SetPermissions(ctx, inode, p)
+ }
+
+ mask := p9.SetAttrMask{Permissions: true}
+ pattr := p9.SetAttr{Permissions: p9.FileMode(p.LinuxMode())}
+ // Execute the chmod.
+ return i.fileState.file.setAttr(ctx, mask, pattr) == nil
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (i *inodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, owner fs.FileOwner) error {
+ // Save the roundtrip.
+ if !owner.UID.Ok() && !owner.GID.Ok() {
+ return nil
+ }
+
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.SetOwner(ctx, inode, owner)
+ }
+
+ var mask p9.SetAttrMask
+ var attr p9.SetAttr
+ if owner.UID.Ok() {
+ mask.UID = true
+ attr.UID = p9.UID(owner.UID)
+ }
+ if owner.GID.Ok() {
+ mask.GID = true
+ attr.GID = p9.GID(owner.GID)
+ }
+ return i.fileState.file.setAttr(ctx, mask, attr)
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (i *inodeOperations) SetTimestamps(ctx context.Context, inode *fs.Inode, ts fs.TimeSpec) error {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.SetTimestamps(ctx, inode, ts)
+ }
+
+ return utimes(ctx, i.fileState.file, ts)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length int64) error {
+ // This can only be called for files anyway.
+ if i.session().cachePolicy.useCachingInodeOps(inode) {
+ return i.cachingInodeOps.Truncate(ctx, inode, length)
+ }
+ if i.session().cachePolicy == cacheRemoteRevalidating {
+ return i.fileState.hostMappable.Truncate(ctx, length)
+ }
+
+ return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
+}
+
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *inodeOperations) GetXattr(ctx context.Context, _ *fs.Inode, name string, size uint64) (string, error) {
+ return i.fileState.file.getXattr(ctx, name, size)
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *inodeOperations) SetXattr(ctx context.Context, _ *fs.Inode, name string, value string, flags uint32) error {
+ return i.fileState.file.setXattr(ctx, name, value, flags)
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *inodeOperations) ListXattr(ctx context.Context, _ *fs.Inode, size uint64) (map[string]struct{}, error) {
+ return i.fileState.file.listXattr(ctx, size)
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (i *inodeOperations) RemoveXattr(ctx context.Context, _ *fs.Inode, name string) error {
+ return i.fileState.file.removeXattr(ctx, name)
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
+ // This can only be called for files anyway.
+ if i.session().cachePolicy.useCachingInodeOps(inode) {
+ return i.cachingInodeOps.Allocate(ctx, offset, length)
+ }
+ if i.session().cachePolicy == cacheRemoteRevalidating {
+ return i.fileState.hostMappable.Allocate(ctx, offset, length)
+ }
+
+ // No options are supported for now.
+ mode := p9.AllocateMode{}
+ return i.fileState.file.allocate(ctx, mode, uint64(offset), uint64(length))
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if inode.MountSource.Flags.ReadOnly || !i.session().cachePolicy.cacheUAttrs(inode) {
+ return nil
+ }
+
+ return i.cachingInodeOps.WriteOut(ctx, inode)
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (i *inodeOperations) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if !fs.IsSymlink(inode.StableAttr) {
+ return "", syscall.ENOLINK
+ }
+ return i.fileState.file.readlink(ctx)
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (i *inodeOperations) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ if !fs.IsSymlink(i.fileState.sattr) {
+ return nil, syserror.ENOLINK
+ }
+ return nil, fs.ErrResolveViaReadlink
+}
+
+// StatFS makes a StatFS request.
+func (i *inodeOperations) StatFS(ctx context.Context) (fs.Info, error) {
+ fsstat, err := i.fileState.file.statFS(ctx)
+ if err != nil {
+ return fs.Info{}, err
+ }
+
+ info := fs.Info{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Testing is important, so instead of defining
+ // something completely random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ TotalBlocks: fsstat.Blocks,
+ FreeBlocks: fsstat.BlocksFree,
+ TotalFiles: fsstat.Files,
+ FreeFiles: fsstat.FilesFree,
+ }
+
+ // If blocks available is non-zero, prefer that.
+ if fsstat.BlocksAvailable != 0 {
+ info.FreeBlocks = fsstat.BlocksAvailable
+ }
+
+ return info, nil
+}
+
+func (i *inodeOperations) configureMMap(file *fs.File, opts *memmap.MMapOpts) error {
+ if i.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ return fsutil.GenericConfigureMMap(file, i.cachingInodeOps, opts)
+ }
+ if i.fileState.hostMappable != nil {
+ return fsutil.GenericConfigureMMap(file, i.fileState.hostMappable, opts)
+ }
+ return syserror.ENODEV
+}
+
+func init() {
+ syserror.AddErrorUnwrapper(func(err error) (syscall.Errno, bool) {
+ if _, ok := err.(p9.ErrSocket); ok {
+ // Treat as an I/O error.
+ return syscall.EIO, true
+ }
+ return 0, false
+ })
+}
+
+// AddLink implements InodeOperations.AddLink, but is currently a noop.
+func (*inodeOperations) AddLink() {}
+
+// DropLink implements InodeOperations.DropLink, but is currently a noop.
+func (*inodeOperations) DropLink() {}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
new file mode 100644
index 000000000..a3402e343
--- /dev/null
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "errors"
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// Some fs implementations may not support atime, ctime, or mtime in getattr.
+// The unstable() logic would try to use clock time for them. However, we do not
+// want to use such a time during S/R, as it would cause the restore timestamp
+// check to fail. Hence a dummy stable-time clock is needed.
+//
+// Note that application-visible UnstableAttrs either come from CachingInodeOps
+// (in which case they are saved), or they are requested from the gofer on each
+// stat (for non-caching), so the dummy time only affects the modification
+// timestamp check.
+type dummyClock struct {
+ time.Clock
+}
+
+// Now returns a stable dummy time.
+func (d *dummyClock) Now() time.Time {
+ return time.Time{}
+}
+
+type dummyClockContext struct {
+ context.Context
+}
+
+// Value implements context.Context
+func (d *dummyClockContext) Value(key interface{}) interface{} {
+ switch key {
+ case time.CtxRealtimeClock:
+ return &dummyClock{}
+ default:
+ return d.Context.Value(key)
+ }
+}
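dummyClockContext uses a generic wrapper-context pattern: embed the parent context and intercept exactly one key in Value, delegating the rest. The same pattern standalone, with invented keys and values:

package main

import (
	"context"
	"fmt"
)

type ctxKey int

const (
	clockKey ctxKey = iota
	otherKey
)

// fixedClockContext overrides clockKey and delegates everything else,
// like dummyClockContext above.
type fixedClockContext struct {
	context.Context
}

func (c fixedClockContext) Value(key interface{}) interface{} {
	if key == clockKey {
		return "stable dummy time"
	}
	return c.Context.Value(key)
}

func main() {
	base := context.WithValue(context.Background(), otherKey, 42)
	ctx := fixedClockContext{base}
	fmt.Println(ctx.Value(clockKey)) // stable dummy time
	fmt.Println(ctx.Value(otherKey)) // 42
}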
+
+// beforeSave is invoked by stateify.
+func (i *inodeFileState) beforeSave() {
+ if _, ok := i.s.inodeMappings[i.sattr.InodeID]; !ok {
+ panic(fmt.Sprintf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings)))
+ }
+ if i.sattr.Type == fs.RegularFile {
+ uattr, err := i.unstableAttr(&dummyClockContext{context.Background()})
+ if err != nil {
+			panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable attributes of %s: %v", i.s.inodeMappings[i.sattr.InodeID], err)})
+ }
+ i.savedUAttr = &uattr
+ }
+}
+
+// saveLoading is invoked by stateify.
+func (i *inodeFileState) saveLoading() struct{} {
+ return struct{}{}
+}
+
+// splitAbsolutePath splits the path on slashes ignoring the leading slash.
+func splitAbsolutePath(path string) []string {
+ if len(path) == 0 {
+ panic("There is no path!")
+ }
+ if path != filepath.Clean(path) {
+ panic(fmt.Sprintf("path %q is not clean", path))
+ }
+	// This case returns {} rather than {""}.
+ if path == "/" {
+ return []string{}
+ }
+ if path[0] != '/' {
+ panic(fmt.Sprintf("path %q is not absolute", path))
+ }
+
+ s := strings.Split(path, "/")
+
+	// Since path is absolute, the first component of s
+	// is an empty string. We must remove that.
+ return s[1:]
+}
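Concretely, the splitting behaves as follows (logic reproduced standalone, minus the cleanliness and absoluteness panics):

package main

import (
	"fmt"
	"strings"
)

// split mirrors splitAbsolutePath for already-clean absolute paths.
func split(path string) []string {
	if path == "/" {
		return []string{}
	}
	// The leading "/" yields an empty first component; drop it.
	return strings.Split(path, "/")[1:]
}

func main() {
	fmt.Printf("%q\n", split("/a/b/c")) // ["a" "b" "c"]
	fmt.Printf("%q\n", split("/"))      // []
}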
+
+// loadLoading is invoked by stateify.
+func (i *inodeFileState) loadLoading(_ struct{}) {
+ i.loading.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (i *inodeFileState) afterLoad() {
+ load := func() (err error) {
+ // See comment on i.loading().
+ defer func() {
+ if err == nil {
+ i.loading.Unlock()
+ }
+ }()
+
+ // Manually restore the p9.File.
+ name, ok := i.s.inodeMappings[i.sattr.InodeID]
+ if !ok {
+ // This should be impossible, see assertion in
+ // beforeSave.
+ return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings))
+ }
+ ctx := &dummyClockContext{context.Background()}
+
+ _, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name))
+ if err != nil {
+ return fs.ErrCorruption{fmt.Errorf("failed to walk to %q: %v", name, err)}
+ }
+
+ // Remap the saved inode number into the gofer device using the
+ // actual device and actual inode that exists in our new
+ // environment.
+ qid, mask, attrs, err := i.file.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ return fs.ErrCorruption{fmt.Errorf("failed to get file attributes of %s: %v", name, err)}
+ }
+ if !mask.RDev {
+ return fs.ErrCorruption{fmt.Errorf("file %s lacks device", name)}
+ }
+ i.key = device.MultiDeviceKey{
+ Device: attrs.RDev,
+ SecondaryDevice: i.s.connID,
+ Inode: qid.Path,
+ }
+ if !goferDevice.Load(i.key, i.sattr.InodeID) {
+ return fs.ErrCorruption{fmt.Errorf("gofer device %s -> %d conflict in gofer device mappings: %s", i.key, i.sattr.InodeID, goferDevice)}
+ }
+
+ if i.sattr.Type == fs.RegularFile {
+ env, ok := fs.CurrentRestoreEnvironment()
+ if !ok {
+ return errors.New("missing restore environment")
+ }
+ uattr := unstable(ctx, mask, attrs, i.s.mounter, i.s.client)
+ if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size {
+ return fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.s.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)}
+ }
+ if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime {
+ return fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.s.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)}
+ }
+ i.savedUAttr = nil
+ }
+
+ return nil
+ }
+
+ fs.Async(fs.CatchError(load))
+}
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
new file mode 100644
index 000000000..cf9800100
--- /dev/null
+++ b/pkg/sentry/fs/gofer/path.go
@@ -0,0 +1,495 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
+// encoding of strings, which uses 2 bytes for the length prefix.
+const maxFilenameLen = (1 << 16) - 1
+
+func changeType(mode p9.FileMode, newType p9.FileMode) p9.FileMode {
+ if newType&^p9.FileModeMask != 0 {
+ panic(fmt.Sprintf("newType contained more bits than just file mode: %x", newType))
+ }
+ clear := mode &^ p9.FileModeMask
+ return clear | newType
+}
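changeType clears the type bits of a mode and ORs in the new ones, leaving permission bits intact. A standalone sketch using conventional st_mode bit values (assumed here to match p9's FileMode constants):

package main

import "fmt"

const (
	fileModeMask = 0170000 // type bits of an st_mode-style mode
	modeRegular  = 0100000
	modeSocket   = 0140000
)

// changeType swaps the file-type bits while preserving permissions.
func changeType(mode, newType uint32) uint32 {
	return (mode &^ fileModeMask) | newType
}

func main() {
	mode := uint32(modeRegular | 0644)
	fmt.Printf("%#o\n", changeType(mode, modeSocket)) // 0140644
}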
+
+// Lookup loads an Inode at name into a Dirent based on the session's cache
+// policy.
+func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ cp := i.session().cachePolicy
+ if cp.cacheReaddir() {
+ // Check to see if we have readdirCache that indicates the
+ // child does not exist. Avoid holding readdirMu longer than
+ // we need to.
+ i.readdirMu.Lock()
+ if i.readdirCache != nil && !i.readdirCache.Contains(name) {
+ // No such child.
+ i.readdirMu.Unlock()
+ if cp.cacheNegativeDirents() {
+ return fs.NewNegativeDirent(name), nil
+ }
+ return nil, syserror.ENOENT
+ }
+ i.readdirMu.Unlock()
+ }
+
+ // Get a p9.File for name.
+ qids, newFile, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name})
+ if err != nil {
+ if err == syserror.ENOENT {
+ if cp.cacheNegativeDirents() {
+ // Return a negative Dirent. It will stay cached until something
+ // is created over it.
+ return fs.NewNegativeDirent(name), nil
+ }
+ return nil, syserror.ENOENT
+ }
+ return nil, err
+ }
+
+ if i.session().overrides != nil {
+		// Check if the file belongs to an internal named pipe. There is no need
+		// to check for sockets because that is done in newInodeOperations below.
+ deviceKey := device.MultiDeviceKey{
+ Device: p9attr.RDev,
+ SecondaryDevice: i.session().connID,
+ Inode: qids[0].Path,
+ }
+ unlock := i.session().overrides.lock()
+ if pipeInode := i.session().overrides.getPipe(deviceKey); pipeInode != nil {
+ unlock()
+ pipeInode.IncRef()
+ return fs.NewDirent(ctx, pipeInode, name), nil
+ }
+ unlock()
+ }
+
+ // Construct the Inode operations.
+ sattr, node := newInodeOperations(ctx, i.fileState.s, newFile, qids[0], mask, p9attr)
+
+ // Construct a positive Dirent.
+ return fs.NewDirent(ctx, fs.NewInode(ctx, node, dir.MountSource, sattr), name), nil
+}
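The negative-dirent behavior in Lookup amounts to a lookup cache that also remembers misses, so repeated lookups of a missing name avoid another round trip. A minimal sketch of that idea with invented types:

package main

import "fmt"

// lookupCache remembers both hits and misses; a present-but-nil entry is a
// cached "no such file", playing the role of a negative Dirent.
type lookupCache struct {
	entries map[string]*string // nil value == negative entry
}

func (c *lookupCache) lookup(name string, fetch func(string) (*string, bool)) (*string, bool) {
	if v, ok := c.entries[name]; ok {
		return v, v != nil // cached hit or cached ENOENT
	}
	v, ok := fetch(name)
	if !ok {
		c.entries[name] = nil // cache the miss
		return nil, false
	}
	c.entries[name] = v
	return v, true
}

func main() {
	c := &lookupCache{entries: map[string]*string{}}
	miss := func(string) (*string, bool) { return nil, false } // remote says ENOENT
	_, ok := c.lookup("missing", miss)
	fmt.Println(ok) // false; the miss is now cached
	_, ok = c.lookup("missing", func(string) (*string, bool) { panic("no RPC expected") })
	fmt.Println(ok) // false, served from the negative cache
}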
+
+// Creates a new Inode at name and returns its File based on the session's cache policy.
+//
+// Ownership is currently ignored.
+func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+	// Create replaces the directory fid with the newly created/opened
+	// file, so clone this directory first to keep it from changing out
+	// from under this node.
+ _, newFile, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Map the FileFlags to p9 OpenFlags.
+ var openFlags p9.OpenFlags
+ switch {
+ case flags.Read && flags.Write:
+ openFlags = p9.ReadWrite
+ case flags.Read:
+ openFlags = p9.ReadOnly
+ case flags.Write:
+ openFlags = p9.WriteOnly
+ default:
+ panic(fmt.Sprintf("Create called with unknown or unset open flags: %v", flags))
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ hostFile, err := newFile.create(ctx, name, openFlags, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
+ if err != nil {
+ // Could not create the file.
+ newFile.close(ctx)
+ return nil, err
+ }
+
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+
+ // Get an unopened p9.File for the file we created so that it can be cloned
+ // and re-opened multiple times after creation, while also getting its
+ // attributes. Both are required for inodeOperations.
+ qids, unopened, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name})
+ if err != nil {
+ newFile.close(ctx)
+ if hostFile != nil {
+ hostFile.Close()
+ }
+ return nil, err
+ }
+ if len(qids) != 1 {
+ log.Warningf("WalkGetAttr(%s) succeeded, but returned %d QIDs (%v), wanted 1", name, len(qids), qids)
+ newFile.close(ctx)
+ if hostFile != nil {
+ hostFile.Close()
+ }
+ unopened.close(ctx)
+ return nil, syserror.EIO
+ }
+ qid := qids[0]
+
+ // Construct the InodeOperations.
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, p9attr)
+
+ // Construct the positive Dirent.
+ d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
+ defer d.DecRef()
+
+ // Construct the new file, caching the handles if allowed.
+ h := handles{
+ File: newFile,
+ Host: hostFile,
+ }
+ h.EnableLeakCheck("gofer.handles")
+ if iops.fileState.canShareHandles() {
+ iops.fileState.handlesMu.Lock()
+ iops.fileState.setSharedHandlesLocked(flags, &h)
+ iops.fileState.handlesMu.Unlock()
+ }
+ return NewFile(ctx, d, name, flags, iops, &h), nil
+}
+
+// CreateLink creates a symlink at newname that points to oldname.
+func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
+ if len(newname) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ if _, err := i.fileState.file.symlink(ctx, oldname, newname, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ return err
+ }
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+ return nil
+}
+
+// CreateHardLink implements InodeOperations.CreateHardLink.
+func (i *inodeOperations) CreateHardLink(ctx context.Context, inode *fs.Inode, target *fs.Inode, newName string) error {
+ if len(newName) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ targetOpts, ok := target.InodeOperations.(*inodeOperations)
+ if !ok {
+ return syserror.EXDEV
+ }
+
+ if err := i.fileState.file.link(ctx, &targetOpts.fileState.file, newName); err != nil {
+ return err
+ }
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ // Increase link count.
+ targetOpts.cachingInodeOps.IncLinks(ctx)
+ }
+ i.touchModificationAndStatusChangeTime(ctx, inode)
+ return nil
+}
+
+// CreateDirectory creates a directory named s under this node.
+func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s string, perm fs.FilePermissions) error {
+ if len(s) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ return err
+ }
+ if i.session().cachePolicy.cacheUAttrs(dir) {
+ // Increase link count.
+ //
+ // N.B. This will update the modification time.
+ i.cachingInodeOps.IncLinks(ctx)
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Invalidate readdir cache.
+ i.markDirectoryDirty()
+ }
+ return nil
+}
+
+// Bind implements InodeOperations.Bind.
+func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ if i.session().overrides == nil {
+ return nil, syserror.EOPNOTSUPP
+ }
+
+ // Stabilize the override map while creation is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
+
+ sattr, iops, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeSocket)
+ if err != nil {
+ return nil, err
+ }
+
+ // Construct the positive Dirent.
+ childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
+ i.session().overrides.addBoundEndpoint(iops.fileState.key, childDir, ep)
+ return childDir, nil
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
+
+ // N.B. FIFOs use major/minor numbers 0.
+ if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ if i.session().overrides == nil || err != syserror.EPERM {
+ return err
+ }
+		// If the gofer doesn't support mknod, check if we can create an internal fifo.
+ return i.createInternalFifo(ctx, dir, name, owner, perm)
+ }
+
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+ return nil
+}
+
+func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, name string, owner fs.FileOwner, perm fs.FilePermissions) error {
+ if i.session().overrides == nil {
+ return syserror.EPERM
+ }
+
+ // Stabilize the override map while creation is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
+
+ sattr, fileOps, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeNamedPipe)
+ if err != nil {
+ return err
+ }
+
+ // First create a pipe.
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+
+ // Wrap the fileOps with our Fifo.
+ iops := &fifo{
+ InodeOperations: pipe.NewInodeOperations(ctx, perm, p),
+ fileIops: fileOps,
+ }
+ inode := fs.NewInode(ctx, iops, dir.MountSource, sattr)
+
+ // Construct the positive Dirent.
+ childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
+ i.session().overrides.addPipe(fileOps.fileState.key, childDir, inode)
+ return nil
+}
+
+// Caller must hold the session's overrides lock.
+func (i *inodeOperations) createEndpointFile(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions, fileType p9.FileMode) (fs.StableAttr, *inodeOperations, error) {
+ _, dirClone, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+ // We're not going to use dirClone after return.
+ defer dirClone.close(ctx)
+
+	// Create a regular file in the gofer and then mark it as a socket by
+	// adding this inode key to the 'overrides' map.
+ owner := fs.FileOwnerFromContext(ctx)
+ hostFile, err := dirClone.create(ctx, name, p9.ReadWrite, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+ // We're not going to use this file.
+ hostFile.Close()
+
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+
+	// Get the attributes of the file to create the inode key.
+ qid, mask, attr, err := getattr(ctx, dirClone)
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+
+ // Get an unopened p9.File for the file we created so that it can be
+ // cloned and re-opened multiple times after creation.
+ _, unopened, err := i.fileState.file.walk(ctx, []string{name})
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+
+ // Construct new inode with file type overridden.
+ attr.Mode = changeType(attr.Mode, fileType)
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, attr)
+ return sattr, iops, nil
+}
+
+// Remove implements InodeOperations.Remove.
+func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ var key *device.MultiDeviceKey
+ if i.session().overrides != nil {
+		// Find out if the file being deleted is a socket or pipe that needs
+		// to be removed from the endpoint map.
+ if d, err := i.Lookup(ctx, dir, name); err == nil {
+ defer d.DecRef()
+
+ if fs.IsSocket(d.Inode.StableAttr) || fs.IsPipe(d.Inode.StableAttr) {
+ switch iops := d.Inode.InodeOperations.(type) {
+ case *inodeOperations:
+ key = &iops.fileState.key
+ case *fifo:
+ key = &iops.fileIops.fileState.key
+ }
+
+ // Stabilize the override map while deletion is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
+ }
+ }
+ }
+
+ if err := i.fileState.file.unlinkAt(ctx, name, 0); err != nil {
+ return err
+ }
+ if key != nil {
+ i.session().overrides.remove(*key)
+ }
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+
+ return nil
+}
+
+// RemoveDirectory implements InodeOperations.RemoveDirectory.
+func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ // 0x200 = AT_REMOVEDIR.
+ if err := i.fileState.file.unlinkAt(ctx, name, 0x200); err != nil {
+ return err
+ }
+ if i.session().cachePolicy.cacheUAttrs(dir) {
+		// Decrease link count and update atime.
+ i.cachingInodeOps.DecLinks(ctx)
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Invalidate readdir cache.
+ i.markDirectoryDirty()
+ }
+ return nil
+}
+
+// Rename renames this node.
+func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ if len(newName) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ // Don't allow renames across different mounts.
+ if newParent.MountSource != oldParent.MountSource {
+ return syserror.EXDEV
+ }
+
+ // Unwrap the new parent to a *inodeOperations.
+ newParentInodeOperations := newParent.InodeOperations.(*inodeOperations)
+
+ // Unwrap the old parent to a *inodeOperations.
+ oldParentInodeOperations := oldParent.InodeOperations.(*inodeOperations)
+
+ // Do the rename.
+ if err := i.fileState.file.rename(ctx, newParentInodeOperations.fileState.file, newName); err != nil {
+ return err
+ }
+
+ // Is the renamed entity a directory? Fix link counts.
+ if fs.IsDir(i.fileState.sattr) {
+ // Update cached state.
+ if i.session().cachePolicy.cacheUAttrs(oldParent) {
+ oldParentInodeOperations.cachingInodeOps.DecLinks(ctx)
+ }
+ if i.session().cachePolicy.cacheUAttrs(newParent) {
+ // Only IncLinks if there is a new addition to
+			// newParent. If this is a replacement, then the total
+ // count remains the same.
+ if !replacement {
+ newParentInodeOperations.cachingInodeOps.IncLinks(ctx)
+ }
+ }
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Mark old directory dirty.
+ oldParentInodeOperations.markDirectoryDirty()
+ if oldParent != newParent {
+ // Mark new directory dirty.
+ newParentInodeOperations.markDirectoryDirty()
+ }
+ }
+
+ // Rename always updates ctime.
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ i.cachingInodeOps.TouchStatusChangeTime(ctx)
+ }
+ return nil
+}
+
+func (i *inodeOperations) touchModificationAndStatusChangeTime(ctx context.Context, inode *fs.Inode) {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ i.cachingInodeOps.TouchModificationAndStatusChangeTime(ctx)
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Invalidate readdir cache.
+ i.markDirectoryDirty()
+ }
+}
+
+// markDirectoryDirty marks any cached data dirty for this directory. This is
+// necessary to ensure that this node does not retain stale state throughout
+// its lifetime across multiple open directory handles.
+//
+// Currently this means invalidating any readdir caches.
+func (i *inodeOperations) markDirectoryDirty() {
+ i.readdirMu.Lock()
+ defer i.readdirMu.Unlock()
+ i.readdirCache = nil
+}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
new file mode 100644
index 000000000..b5efc86f2
--- /dev/null
+++ b/pkg/sentry/fs/gofer/session.go
@@ -0,0 +1,426 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// DefaultDirentCacheSize is the default dirent cache size for 9P mounts. It can
+// be adjusted independently from the other dirent caches.
+var DefaultDirentCacheSize uint64 = fs.DefaultDirentCacheSize
+
+// +stateify savable
+type overrideInfo struct {
+ dirent *fs.Dirent
+
+ // endpoint is set when dirent points to a socket. inode must not be set.
+ endpoint transport.BoundEndpoint
+
+ // inode is set when dirent points to a pipe. endpoint must not be set.
+ inode *fs.Inode
+}
+
+func (l *overrideInfo) inodeType() fs.InodeType {
+ switch {
+ case l.endpoint != nil:
+ return fs.Socket
+ case l.inode != nil:
+ return fs.Pipe
+ }
+ panic("endpoint or node must be set")
+}
+
+// +stateify savable
+type overrideMaps struct {
+	// mu protects the keyMap and the pathMap below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets/pipes.
+ // It is not stored during save because the inode ID may change upon restore.
+ keyMap map[device.MultiDeviceKey]*overrideInfo `state:"nosave"`
+
+ // pathMap links the sockets/pipes to their paths.
+ // It is filled before saving from the direntMap and is stored upon save.
+ // Upon restore, this map is used to re-populate the keyMap.
+ pathMap map[*overrideInfo]string
+}
+
+// addBoundEndpoint adds the bound endpoint to the map.
+// A reference is taken on the dirent argument.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *overrideMaps) addBoundEndpoint(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) {
+ d.IncRef()
+ e.keyMap[key] = &overrideInfo{dirent: d, endpoint: ep}
+}
+
+// addPipe adds the pipe inode to the map.
+// A reference is taken on the dirent argument.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *overrideMaps) addPipe(key device.MultiDeviceKey, d *fs.Dirent, inode *fs.Inode) {
+ d.IncRef()
+ e.keyMap[key] = &overrideInfo{dirent: d, inode: inode}
+}
+
+// remove deletes the key from the maps.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *overrideMaps) remove(key device.MultiDeviceKey) {
+ endpoint := e.keyMap[key]
+ delete(e.keyMap, key)
+ endpoint.dirent.DecRef()
+}
+
+// lock blocks other addition and removal operations from happening while
+// the backing file is being created or deleted. Returns a function that unlocks
+// the endpoint map.
+func (e *overrideMaps) lock() func() {
+ e.mu.Lock()
+ return func() { e.mu.Unlock() }
+}
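+
+// A typical caller pairs lock with a deferred unlock so that the maps stay
+// stable for the whole create/delete critical section (illustrative sketch):
+//
+//	unlock := s.overrides.lock()
+//	defer unlock()
+//	// ... read or mutate keyMap/pathMap ...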
+
+// getBoundEndpoint returns the bound endpoint mapped to the given key.
+//
+// Precondition: maps must have been locked.
+func (e *overrideMaps) getBoundEndpoint(key device.MultiDeviceKey) transport.BoundEndpoint {
+ if v := e.keyMap[key]; v != nil {
+ return v.endpoint
+ }
+ return nil
+}
+
+// getPipe returns the pipe inode mapped to the given key.
+//
+// Precondition: maps must have been locked.
+func (e *overrideMaps) getPipe(key device.MultiDeviceKey) *fs.Inode {
+ if v := e.keyMap[key]; v != nil {
+ return v.inode
+ }
+ return nil
+}
+
+// getType returns the inode type if there is a corresponding endpoint for the
+// given key. Returns false otherwise.
+func (e *overrideMaps) getType(key device.MultiDeviceKey) (fs.InodeType, bool) {
+ e.mu.Lock()
+ v := e.keyMap[key]
+ e.mu.Unlock()
+
+ if v != nil {
+ return v.inodeType(), true
+ }
+ return 0, false
+}
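+
+// For example (illustrative): after Bind registers a socket override, calling
+// getType with that socket's MultiDeviceKey returns (fs.Socket, true); a key
+// with no registered override returns (0, false).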
+
+// session holds state for each 9p session established during sys_mount.
+//
+// +stateify savable
+type session struct {
+ refs.AtomicRefCount
+
+ // msize is the value of the msize mount option, see fs/gofer/fs.go.
+ msize uint32 `state:"wait"`
+
+ // version is the value of the version mount option, see fs/gofer/fs.go.
+ version string `state:"wait"`
+
+ // cachePolicy is the cache policy.
+ cachePolicy cachePolicy `state:"wait"`
+
+ // aname is the value of the aname mount option, see fs/gofer/fs.go.
+ aname string `state:"wait"`
+
+ // The client associated with this session. This will be initialized lazily.
+ client *p9.Client `state:"nosave"`
+
+ // The p9.File pointing to attachName via the client. This will be initialized
+ // lazily.
+ attach contextFile `state:"nosave"`
+
+ // Flags provided to the mount.
+ superBlockFlags fs.MountSourceFlags `state:"wait"`
+
+ // limitHostFDTranslation is the value used for
+ // CachingInodeOperationsOptions.LimitHostFDTranslation for all
+ // CachingInodeOperations created by the session.
+ limitHostFDTranslation bool
+
+	// overlayfsStaleRead, when set, causes the read-only handle to be
+	// invalidated after the file is opened for writing.
+ overlayfsStaleRead bool
+
+ // connID is a unique identifier for the session connection.
+ connID string `state:"wait"`
+
+ // inodeMappings contains mappings of fs.Inodes associated with this session
+ // to paths relative to the attach point, where inodeMappings is keyed by
+ // Inode.StableAttr.InodeID.
+ inodeMappings map[uint64]string `state:"wait"`
+
+ // mounter is the EUID/EGID that mounted this file system.
+ mounter fs.FileOwner `state:"wait"`
+
+	// overrides is used to map inodes that represent socket/pipe files to
+	// their corresponding endpoint/iops. These files are created as regular
+	// files in the gofer, and their presence in this map indicates that they
+	// should indeed be socket/pipe files. This allows unix domain sockets and
+	// named pipes to be used with paths that belong to a gofer.
+	//
+	// There are a few possible races if someone stats the file while another
+	// task deletes it concurrently, in which case the file will not be
+	// reported as a socket file.
+ overrides *overrideMaps `state:"wait"`
+}
+
+// Destroy tears down the session.
+func (s *session) Destroy() {
+ s.client.Close()
+}
+
+// Revalidate implements MountSourceOperations.Revalidate.
+func (s *session) Revalidate(ctx context.Context, name string, parent, child *fs.Inode) bool {
+ return s.cachePolicy.revalidate(ctx, name, parent, child)
+}
+
+// Keep implements MountSourceOperations.Keep.
+func (s *session) Keep(d *fs.Dirent) bool {
+ return s.cachePolicy.keep(d)
+}
+
+// CacheReaddir implements MountSourceOperations.CacheReaddir.
+func (s *session) CacheReaddir() bool {
+ return s.cachePolicy.cacheReaddir()
+}
+
+// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings.
+func (s *session) ResetInodeMappings() {
+ s.inodeMappings = make(map[uint64]string)
+}
+
+// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping.
+func (s *session) SaveInodeMapping(inode *fs.Inode, path string) {
+ // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
+ // because overlay copyUp may have changed them out from under us.
+ // So much for "immutable".
+ switch iops := inode.InodeOperations.(type) {
+ case *inodeOperations:
+ s.inodeMappings[iops.fileState.sattr.InodeID] = path
+ case *fifo:
+ s.inodeMappings[iops.fileIops.fileState.sattr.InodeID] = path
+ default:
+ panic(fmt.Sprintf("Invalid type: %T", iops))
+ }
+}
+
+// newInodeOperations creates a new 9p fs.InodeOperations backed by a p9.File
+// and attributes (p9.QID, p9.AttrMask, p9.Attr).
+//
+// The overrides (endpoints) lock must not be held if the new file is a
+// regular file, since getType acquires it in that case.
+func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p9.QID, valid p9.AttrMask, attr p9.Attr) (fs.StableAttr, *inodeOperations) {
+ deviceKey := device.MultiDeviceKey{
+ Device: attr.RDev,
+ SecondaryDevice: s.connID,
+ Inode: qid.Path,
+ }
+
+ sattr := fs.StableAttr{
+ Type: ntype(attr),
+ DeviceID: goferDevice.DeviceID(),
+ InodeID: goferDevice.Map(deviceKey),
+ BlockSize: bsize(attr),
+ }
+
+ if s.overrides != nil && sattr.Type == fs.RegularFile {
+ // If overrides are allowed on this filesystem, check if this file is
+ // supposed to be of a different type, e.g. socket.
+ if t, ok := s.overrides.getType(deviceKey); ok {
+ sattr.Type = t
+ }
+ }
+
+ fileState := &inodeFileState{
+ s: s,
+ file: file,
+ sattr: sattr,
+ key: deviceKey,
+ }
+ if s.cachePolicy == cacheRemoteRevalidating && fs.IsFile(sattr) {
+ fileState.hostMappable = fsutil.NewHostMappable(fileState)
+ }
+
+ uattr := unstable(ctx, valid, attr, s.mounter, s.client)
+ return sattr, &inodeOperations{
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: s.superBlockFlags.ForcePageCache,
+ LimitHostFDTranslation: s.limitHostFDTranslation,
+ }),
+ }
+}
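+
+// For example (illustrative): a unix socket created via Bind exists on the
+// gofer as a regular file, so a later lookup of that path reports
+// fs.RegularFile from the gofer's attributes; the overrides lookup above then
+// rewrites sattr.Type to fs.Socket so the sentry presents it as a socket.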
+
+// Root returns the root of a 9p mount. The mount is bound to a 9p server
+// over a connection created from o.fd. The remaining configuration
+// parameters are:
+//
+// * dev: connection id
+// * filesystem: the filesystem backing the mount
+// * superBlockFlags: the mount flags describing general mount options
+// * opts: parsed 9p mount options
+func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockFlags fs.MountSourceFlags, o opts) (*fs.Inode, error) {
+ // The mounting EUID/EGID will be cached by this file system. This will
+ // be used to assign ownership to files that the Gofer owns.
+ mounter := fs.FileOwnerFromContext(ctx)
+
+ conn, err := unet.NewSocket(o.fd)
+ if err != nil {
+ return nil, err
+ }
+
+ // Construct the session.
+ s := session{
+ connID: dev,
+ msize: o.msize,
+ version: o.version,
+ cachePolicy: o.policy,
+ aname: o.aname,
+ superBlockFlags: superBlockFlags,
+ limitHostFDTranslation: o.limitHostFDTranslation,
+ overlayfsStaleRead: o.overlayfsStaleRead,
+ mounter: mounter,
+ }
+ s.EnableLeakCheck("gofer.session")
+
+ if o.privateunixsocket {
+ s.overrides = newOverrideMaps()
+ }
+
+ // Construct the MountSource with the session and superBlockFlags.
+ m := fs.NewMountSource(ctx, &s, filesystem, superBlockFlags)
+
+ // Given that gofer files can consume host FDs, restrict the number
+ // of files that can be held by the cache.
+ m.SetDirentCacheMaxSize(DefaultDirentCacheSize)
+ m.SetDirentCacheLimiter(fs.DirentCacheLimiterFromContext(ctx))
+
+ // Send the Tversion request.
+ s.client, err = p9.NewClient(conn, s.msize, s.version)
+ if err != nil {
+ // Drop our reference on the session, it needs to be torn down.
+ s.DecRef()
+ return nil, err
+ }
+
+ // Notify that we're about to call the Gofer and block.
+ ctx.UninterruptibleSleepStart(false)
+ // Send the Tattach request.
+ s.attach.file, err = s.client.Attach(s.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ // Same as above.
+ s.DecRef()
+ return nil, err
+ }
+
+ qid, valid, attr, err := s.attach.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ s.attach.close(ctx)
+ // Same as above, but after we execute the Close request.
+ s.DecRef()
+ return nil, err
+ }
+
+ sattr, iops := newInodeOperations(ctx, &s, s.attach, qid, valid, attr)
+ return fs.NewInode(ctx, iops, m, sattr), nil
+}
+
+// newOverrideMaps creates a new overrideMaps.
+func newOverrideMaps() *overrideMaps {
+ return &overrideMaps{
+ keyMap: make(map[device.MultiDeviceKey]*overrideInfo),
+ pathMap: make(map[*overrideInfo]string),
+ }
+}
+
+// fillKeyMap re-populates the keyMap upon restore from the saved pathMap.
+func (s *session) fillKeyMap(ctx context.Context) error {
+ unlock := s.overrides.lock()
+ defer unlock()
+
+ for ep, dirPath := range s.overrides.pathMap {
+ _, file, err := s.attach.walk(ctx, splitAbsolutePath(dirPath))
+ if err != nil {
+ return fmt.Errorf("error filling endpointmaps, failed to walk to %q: %v", dirPath, err)
+ }
+
+ qid, _, attr, err := file.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ return fmt.Errorf("failed to get file attributes of %s: %v", dirPath, err)
+ }
+
+ key := device.MultiDeviceKey{
+ Device: attr.RDev,
+ SecondaryDevice: s.connID,
+ Inode: qid.Path,
+ }
+
+ s.overrides.keyMap[key] = ep
+ }
+ return nil
+}
+
+// fillPathMap populates paths for overrides from the dirents in keyMap
+// before save.
+func (s *session) fillPathMap() error {
+ unlock := s.overrides.lock()
+ defer unlock()
+
+ for _, endpoint := range s.overrides.keyMap {
+ mountRoot := endpoint.dirent.MountRoot()
+ defer mountRoot.DecRef()
+ dirPath, _ := endpoint.dirent.FullName(mountRoot)
+ if dirPath == "" {
+ return fmt.Errorf("error getting path from dirent")
+ }
+ s.overrides.pathMap[endpoint] = dirPath
+ }
+ return nil
+}
+
+// restoreEndpointMaps recreates and fills the key and path maps.
+func (s *session) restoreEndpointMaps(ctx context.Context) error {
+	// When restoring, we only need to re-create the keyMap because the path
+	// map was stored through the save.
+ s.overrides.keyMap = make(map[device.MultiDeviceKey]*overrideInfo)
+ if err := s.fillKeyMap(ctx); err != nil {
+ return fmt.Errorf("failed to insert sockets into endpoint map: %v", err)
+ }
+
+	// Re-create the pathMap because it can no longer be trusted, as socket
+	// paths can change while the process continues to run. The empty pathMap
+	// will be re-filled upon the next save.
+ s.overrides.pathMap = make(map[*overrideInfo]string)
+ return nil
+}
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
new file mode 100644
index 000000000..2d398b753
--- /dev/null
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -0,0 +1,113 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// beforeSave is invoked by stateify.
+func (s *session) beforeSave() {
+ if s.overrides != nil {
+ if err := s.fillPathMap(); err != nil {
+ panic("failed to save paths to override map before saving" + err.Error())
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (s *session) afterLoad() {
+ // The restore environment contains the 9p connection of this mount.
+ fsys := filesystem{}
+ env, ok := fs.CurrentRestoreEnvironment()
+ if !ok {
+ panic("failed to find restore environment")
+ }
+ mounts, ok := env.MountSources[fsys.Name()]
+ if !ok {
+ panic("failed to find mounts for filesystem type " + fsys.Name())
+ }
+ var args fs.MountArgs
+ var found bool
+ for _, mount := range mounts {
+ if mount.Dev == s.connID {
+ args = mount
+ found = true
+ }
+ }
+ if !found {
+ panic(fmt.Sprintf("no connection for connection id %q", s.connID))
+ }
+
+ // Validate the mount flags and options.
+ opts, err := options(args.DataString)
+ if err != nil {
+ panic("failed to parse mount options: " + err.Error())
+ }
+ if opts.msize != s.msize {
+ panic(fmt.Sprintf("new message size %v, want %v", opts.msize, s.msize))
+ }
+ if opts.version != s.version {
+ panic(fmt.Sprintf("new version %v, want %v", opts.version, s.version))
+ }
+ if opts.policy != s.cachePolicy {
+ panic(fmt.Sprintf("new cache policy %v, want %v", opts.policy, s.cachePolicy))
+ }
+ if opts.aname != s.aname {
+ panic(fmt.Sprintf("new attach name %v, want %v", opts.aname, s.aname))
+ }
+
+	// Check that overrideMaps exist if and only if private uds sockets are
+	// enabled (only the pathMap will actually have been saved).
+ if opts.privateunixsocket != (s.overrides != nil) {
+ panic(fmt.Sprintf("new privateunixsocket option %v, want %v", opts.privateunixsocket, s.overrides != nil))
+ }
+ if args.Flags != s.superBlockFlags {
+ panic(fmt.Sprintf("new mount flags %v, want %v", args.Flags, s.superBlockFlags))
+ }
+
+ // Manually restore the connection.
+ conn, err := unet.NewSocket(opts.fd)
+ if err != nil {
+ panic(fmt.Sprintf("failed to create Socket for FD %d: %v", opts.fd, err))
+ }
+
+ // Manually restore the client.
+ s.client, err = p9.NewClient(conn, s.msize, s.version)
+ if err != nil {
+ panic(fmt.Sprintf("failed to connect client to server: %v", err))
+ }
+
+ // Manually restore the attach point.
+ s.attach.file, err = s.client.Attach(s.aname)
+ if err != nil {
+ panic(fmt.Sprintf("failed to attach to aname: %v", err))
+ }
+
+ // If private unix sockets are enabled, create and fill the session's endpoint
+ // maps.
+ if opts.privateunixsocket {
+ ctx := &dummyClockContext{context.Background()}
+
+ if err = s.restoreEndpointMaps(ctx); err != nil {
+ panic("failed to restore endpoint maps: " + err.Error())
+ }
+ }
+}
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
new file mode 100644
index 000000000..40f2c1cad
--- /dev/null
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -0,0 +1,152 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// BoundEndpoint returns a gofer-backed transport.BoundEndpoint.
+func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.BoundEndpoint {
+ if !fs.IsSocket(i.fileState.sattr) {
+ return nil
+ }
+
+ if i.session().overrides != nil {
+ unlock := i.session().overrides.lock()
+ defer unlock()
+ ep := i.session().overrides.getBoundEndpoint(i.fileState.key)
+ if ep != nil {
+ return ep
+ }
+
+		// Not found in the overrides map; it may be a gofer-backed unix socket...
+ }
+
+ inode.IncRef()
+ return &endpoint{inode, i.fileState.file.file, path}
+}
+
+// LINT.IfChange
+
+// endpoint is a Gofer-backed transport.BoundEndpoint.
+//
+// An endpoint's lifetime is the time between when InodeOperations.BoundEndpoint()
+// is called and either BoundEndpoint.BidirectionalConnect or
+// BoundEndpoint.UnidirectionalConnect is called.
+type endpoint struct {
+ // inode is the filesystem inode which produced this endpoint.
+ inode *fs.Inode
+
+ // file is the p9 file that contains a single unopened fid.
+ file p9.File
+
+ // path is the sentry path where this endpoint is bound.
+ path string
+}
+
+func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
+ switch t {
+ case linux.SOCK_STREAM:
+ return p9.StreamSocket, true
+ case linux.SOCK_SEQPACKET:
+ return p9.SeqpacketSocket, true
+ case linux.SOCK_DGRAM:
+ return p9.DgramSocket, true
+ }
+ return 0, false
+}
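+
+// For example (illustrative): sockTypeToP9(linux.SOCK_STREAM) yields
+// (p9.StreamSocket, true), while a type with no 9P equivalent, such as
+// linux.SOCK_RAW, yields (0, false) and BidirectionalConnect refuses the
+// connection.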
+
+// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
+func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
+ cf, ok := sockTypeToP9(ce.Type())
+ if !ok {
+ return syserr.ErrConnectionRefused
+ }
+
+ // No lock ordering required as only the ConnectingEndpoint has a mutex.
+ ce.Lock()
+
+ // Check connecting state.
+ if ce.Connected() {
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ hostFile, err := e.file.Connect(cf)
+ if err != nil {
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewConnectedEndpoint(ctx, hostFile, ce.WaiterQueue(), e.path)
+ if serr != nil {
+ ce.Unlock()
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, cf, serr)
+ return serr
+ }
+
+ returnConnect(c, c)
+ ce.Unlock()
+ c.Init()
+
+ return nil
+}
+
+// UnidirectionalConnect implements
+// transport.BoundEndpoint.UnidirectionalConnect.
+func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
+ hostFile, err := e.file.Connect(p9.DgramSocket)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewConnectedEndpoint(ctx, hostFile, &waiter.Queue{}, e.path)
+ if serr != nil {
+ log.Warningf("Gofer returned invalid host socket for UnidirectionalConnect; file %+v: %v", e.file, serr)
+ return nil, serr
+ }
+ c.Init()
+
+ // We don't need the receiver.
+ c.CloseRecv()
+ c.Release()
+
+ return c, nil
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *endpoint) Release() {
+ e.inode.DecRef()
+}
+
+// Passcred implements transport.BoundEndpoint.Passcred.
+func (e *endpoint) Passcred() bool {
+ return false
+}
+
+// LINT.ThenChange(../../fsimpl/gofer/socket.go)
diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go
new file mode 100644
index 000000000..47a6c69bf
--- /dev/null
+++ b/pkg/sentry/fs/gofer/util.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return nil
+ }
+
+ // Replace requests to use the "system time" with the current time to
+ // ensure that timestamps remain consistent with the remote
+ // filesystem.
+ now := ktime.NowFromContext(ctx)
+ if ts.ATimeSetSystemTime {
+ ts.ATime = now
+ }
+ if ts.MTimeSetSystemTime {
+ ts.MTime = now
+ }
+ mask := p9.SetAttrMask{
+ ATime: !ts.ATimeOmit,
+ ATimeNotSystemTime: true,
+ MTime: !ts.MTimeOmit,
+ MTimeNotSystemTime: true,
+ }
+ as, ans := ts.ATime.Unix()
+ ms, mns := ts.MTime.Unix()
+ attr := p9.SetAttr{
+ ATimeSeconds: uint64(as),
+ ATimeNanoSeconds: uint64(ans),
+ MTimeSeconds: uint64(ms),
+ MTimeNanoSeconds: uint64(mns),
+ }
+ // 9p2000.L SetAttr: "If a time bit is set without the corresponding SET bit,
+ // the current system time on the server is used instead of the value sent
+ // in the request."
+ return file.setAttr(ctx, mask, attr)
+}
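+
+// For example (illustrative): a request that omits atime but sets mtime to
+// the current task time sends mask {ATimeNotSystemTime: true, MTime: true,
+// MTimeNotSystemTime: true} along with mtime seconds/nanoseconds taken from
+// the sentry clock, so the remote server never substitutes its own time.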
+
+func openFlagsFromPerms(p fs.PermMask) (p9.OpenFlags, error) {
+ switch {
+ case p.Read && p.Write:
+ return p9.ReadWrite, nil
+ case p.Write:
+ return p9.WriteOnly, nil
+ case p.Read:
+ return p9.ReadOnly, nil
+ default:
+ return 0, syscall.EINVAL
+ }
+}
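+
+// For example (illustrative): openFlagsFromPerms(fs.PermMask{Read: true})
+// returns (p9.ReadOnly, nil), while an empty PermMask returns EINVAL because
+// 9P has no open mode meaning "no access".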
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
new file mode 100644
index 000000000..aabce6cc9
--- /dev/null
+++ b/pkg/sentry/fs/host/BUILD
@@ -0,0 +1,82 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "host",
+ srcs = [
+ "control.go",
+ "descriptor.go",
+ "descriptor_state.go",
+ "device.go",
+ "file.go",
+ "host.go",
+ "inode.go",
+ "inode_state.go",
+ "ioctl_unsafe.go",
+ "socket.go",
+ "socket_iovec.go",
+ "socket_state.go",
+ "socket_unsafe.go",
+ "tty.go",
+ "util.go",
+ "util_amd64_unsafe.go",
+ "util_arm64_unsafe.go",
+ "util_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/secio",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/unet",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "host_test",
+ size = "small",
+ srcs = [
+ "descriptor_test.go",
+ "inode_test.go",
+ "socket_test.go",
+ "wait_test.go",
+ ],
+ library = ":host",
+ deps = [
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
new file mode 100644
index 000000000..39299b7e4
--- /dev/null
+++ b/pkg/sentry/fs/host/control.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// LINT.IfChange
+
+type scmRights struct {
+ fds []int
+}
+
+func newSCMRights(fds []int) control.SCMRights {
+ return &scmRights{fds}
+}
+
+// Files implements control.SCMRights.Files.
+func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) {
+ n := max
+ var trunc bool
+ if l := len(c.fds); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+
+ rf := control.RightsFiles(fdsToFiles(ctx, c.fds[:n]))
+
+ // Only consume converted FDs (fdsToFiles may convert fewer than n FDs).
+ c.fds = c.fds[len(rf):]
+ return rf, trunc
+}
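+
+// For example (illustrative): with max=2 and three queued FDs, Files returns
+// up to two fs.Files and trunc=true; any remaining FDs stay queued for the
+// next call.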
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (c *scmRights) Clone() transport.RightsControlMessage {
+ // Host rights never need to be cloned.
+ return nil
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (c *scmRights) Release() {
+ for _, fd := range c.fds {
+ syscall.Close(fd)
+ }
+ c.fds = nil
+}
+
+// fdsToFiles wraps host FDs in fs.Files. If an error is encountered, only
+// files created before the error will be returned. This is what Linux does.
+func fdsToFiles(ctx context.Context, fds []int) []*fs.File {
+ files := make([]*fs.File, 0, len(fds))
+ for _, fd := range fds {
+ // Get flags. We do it here because they may be modified
+ // by subsequent functions.
+ fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0)
+ if errno != 0 {
+ ctx.Warningf("Error retrieving host FD flags: %v", error(errno))
+ break
+ }
+
+ // Create the file backed by hostFD.
+ file, err := NewFile(ctx, fd)
+ if err != nil {
+ ctx.Warningf("Error creating file from host FD: %v", err)
+ break
+ }
+
+ // Set known flags.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: fileFlags&syscall.O_NONBLOCK != 0,
+ })
+
+ files = append(files, file)
+ }
+ return files
+}
+
+// LINT.ThenChange(../../fsimpl/host/control.go)
diff --git a/pkg/sentry/fs/host/descriptor.go b/pkg/sentry/fs/host/descriptor.go
new file mode 100644
index 000000000..cfdce6a74
--- /dev/null
+++ b/pkg/sentry/fs/host/descriptor.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// descriptor wraps a host fd.
+//
+// +stateify savable
+type descriptor struct {
+ // If origFD >= 0, it is the host fd that this file was originally created
+ // from, which must be available at time of restore. The FD can be closed
+	// after the descriptor is created.
+ origFD int
+
+ // wouldBlock is true if value (below) points to a file that can
+ // return EWOULDBLOCK for operations that would block.
+ wouldBlock bool
+
+ // value is the wrapped host fd. It is never saved or restored
+ // directly.
+ value int `state:"nosave"`
+}
+
+// newDescriptor returns a wrapped host file descriptor. On success, if
+// wouldBlock is true, the descriptor is registered for event notifications
+// with queue.
+func newDescriptor(fd int, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
+ ownedFD := fd
+ origFD := -1
+ if saveable {
+ var err error
+ ownedFD, err = syscall.Dup(fd)
+ if err != nil {
+ return nil, err
+ }
+ origFD = fd
+ }
+ if wouldBlock {
+ if err := syscall.SetNonblock(ownedFD, true); err != nil {
+ return nil, err
+ }
+ if err := fdnotifier.AddFD(int32(ownedFD), queue); err != nil {
+ return nil, err
+ }
+ }
+ return &descriptor{
+ origFD: origFD,
+ wouldBlock: wouldBlock,
+ value: ownedFD,
+ }, nil
+}
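+
+// Illustrative usage (sketch; pipeFD is a placeholder for a donated host
+// pipe FD, which is not saveable but can block):
+//
+//	queue := &waiter.Queue{}
+//	d, err := newDescriptor(pipeFD, false /* saveable */, true /* wouldBlock */, queue)
+//	if err != nil {
+//		return err
+//	}
+//	defer d.Release()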
+
+// initAfterLoad initializes the value of the descriptor after Load.
+func (d *descriptor) initAfterLoad(id uint64, queue *waiter.Queue) error {
+ var err error
+ d.value, err = syscall.Dup(d.origFD)
+ if err != nil {
+ return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
+ }
+ if d.wouldBlock {
+ if err := syscall.SetNonblock(d.value, true); err != nil {
+ return err
+ }
+ if err := fdnotifier.AddFD(int32(d.value), queue); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Release releases all resources held by descriptor.
+func (d *descriptor) Release() {
+ if d.wouldBlock {
+ fdnotifier.RemoveFD(int32(d.value))
+ }
+ if err := syscall.Close(d.value); err != nil {
+ log.Warningf("error closing fd %d: %v", d.value, err)
+ }
+ d.value = -1
+}
diff --git a/pkg/sentry/fs/host/descriptor_state.go b/pkg/sentry/fs/host/descriptor_state.go
new file mode 100644
index 000000000..e880582ab
--- /dev/null
+++ b/pkg/sentry/fs/host/descriptor_state.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+// beforeSave is invoked by stateify.
+func (d *descriptor) beforeSave() {
+ if d.origFD < 0 {
+ panic("donated file descriptor cannot be saved")
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *descriptor) afterLoad() {
+ // value must be manually restored by the descriptor's parent using
+ // initAfterLoad.
+ d.value = -1
+}
diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go
new file mode 100644
index 000000000..d8e4605b6
--- /dev/null
+++ b/pkg/sentry/fs/host/descriptor_test.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "io/ioutil"
+ "path/filepath"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestDescriptorRelease(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ saveable bool
+ wouldBlock bool
+ }{
+ {name: "all false"},
+ {name: "saveable", saveable: true},
+ {name: "wouldBlock", wouldBlock: true},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir("", "descriptor_test")
+ if err != nil {
+ t.Fatal("ioutil.TempDir() failed:", err)
+ }
+
+ fd, err := syscall.Open(filepath.Join(dir, "file"), syscall.O_RDWR|syscall.O_CREAT, 0666)
+ if err != nil {
+ t.Fatal("failed to open temp file:", err)
+ }
+
+			// FD ownership is transferred to the descriptor.
+ queue := &waiter.Queue{}
+ d, err := newDescriptor(fd, tc.saveable, tc.wouldBlock, queue)
+ if err != nil {
+ syscall.Close(fd)
+ t.Fatalf("newDescriptor(%d, %t, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
+ }
+ if tc.saveable {
+ if d.origFD < 0 {
+ t.Errorf("saveable descriptor must preserve origFD, desc: %+v", d)
+ }
+ }
+ if tc.wouldBlock {
+ if !fdnotifier.HasFD(int32(d.value)) {
+ t.Errorf("FD not registered with notifier, desc: %+v", d)
+ }
+ }
+
+ oldVal := d.value
+ d.Release()
+ if d.value != -1 {
+ t.Errorf("d.value want: -1, got: %d", d.value)
+ }
+ if tc.wouldBlock {
+ if fdnotifier.HasFD(int32(oldVal)) {
+ t.Errorf("FD not unregistered with notifier, desc: %+v", d)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fs/host/device.go b/pkg/sentry/fs/host/device.go
new file mode 100644
index 000000000..484f0b58b
--- /dev/null
+++ b/pkg/sentry/fs/host/device.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/device"
+)
+
+// hostFileDevice is the host file virtual device.
+var hostFileDevice = device.NewAnonMultiDevice()
+
+// hostPipeDevice is the host pipe virtual device.
+var hostPipeDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
new file mode 100644
index 000000000..3e48b8b2c
--- /dev/null
+++ b/pkg/sentry/fs/host/file.go
@@ -0,0 +1,286 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/secio"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileOperations implements fs.FileOperations for a host file descriptor.
+//
+// +stateify savable
+type fileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+	fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // iops are the Inode operations for this file.
+ iops *inodeOperations `state:"wait"`
+
+	// dirinfo is a scratch buffer for reading directory entries.
+ dirinfo *dirInfo `state:"nosave"`
+
+ // dirCursor is the directory cursor.
+ dirCursor string
+}
+
+// fileOperations implements fs.FileOperations.
+var _ fs.FileOperations = (*fileOperations)(nil)
+
+// NewFile creates a new File backed by the provided host file descriptor. If
+// NewFile succeeds, ownership of the FD is transferred to the returned File.
+//
+// The returned File cannot be saved, since there is no guarantee that the same
+// FD will exist or represent the same file at time of restore. If such a
+// guarantee does exist, use ImportFile instead.
+func NewFile(ctx context.Context, fd int) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, false, false)
+}
+
+// ImportFile creates a new File backed by the provided host file descriptor.
+// Unlike NewFile, the file descriptor used by the File is duped from FD to
+// ensure that later changes to FD are not reflected by the fs.File.
+//
+// If the returned file is saved, it will be restored by re-importing the FD
+// originally passed to ImportFile. It is the restorer's responsibility to
+// ensure that the FD represents the same file.
+func ImportFile(ctx context.Context, fd int, isTTY bool) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, true, isTTY)
+}
+
+// newFileFromDonatedFD returns an fs.File from a donated FD. saveable
+// indicates whether the returned file may be saved and restored.
+func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool) (*fs.File, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(donated, &s); err != nil {
+ return nil, err
+ }
+ flags, err := fileFlagsFromDonatedFD(donated)
+ if err != nil {
+ return nil, err
+ }
+ switch s.Mode & syscall.S_IFMT {
+ case syscall.S_IFSOCK:
+ if isTTY {
+ return nil, fmt.Errorf("cannot import host socket as TTY")
+ }
+
+ s, err := newSocket(ctx, donated, saveable)
+ if err != nil {
+ return nil, err
+ }
+ s.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags.NonBlocking,
+ })
+ return s, nil
+ default:
+ msrc := fs.NewNonCachingMountSource(ctx, &filesystem{}, fs.MountSourceFlags{})
+ inode, err := newInode(ctx, msrc, donated, saveable)
+ if err != nil {
+ return nil, err
+ }
+ iops := inode.InodeOperations.(*inodeOperations)
+
+ name := fmt.Sprintf("host:[%d]", inode.StableAttr.InodeID)
+ dirent := fs.NewDirent(ctx, inode, name)
+ defer dirent.DecRef()
+
+ if isTTY {
+ return newTTYFile(ctx, dirent, flags, iops), nil
+ }
+
+ return newFile(ctx, dirent, flags, iops), nil
+ }
+}
+
+func fileFlagsFromDonatedFD(donated int) (fs.FileFlags, error) {
+ flags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(donated), syscall.F_GETFL, 0)
+ if errno != 0 {
+ log.Warningf("Failed to get file flags for donated FD %d (errno=%d)", donated, errno)
+ return fs.FileFlags{}, syscall.EIO
+ }
+ accmode := flags & syscall.O_ACCMODE
+ return fs.FileFlags{
+ Direct: flags&syscall.O_DIRECT != 0,
+ NonBlocking: flags&syscall.O_NONBLOCK != 0,
+ Sync: flags&syscall.O_SYNC != 0,
+ Append: flags&syscall.O_APPEND != 0,
+ Read: accmode == syscall.O_RDONLY || accmode == syscall.O_RDWR,
+ Write: accmode == syscall.O_WRONLY || accmode == syscall.O_RDWR,
+ }, nil
+}
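+
+// For example (illustrative): an FD opened with O_RDWR|O_NONBLOCK|O_APPEND
+// maps to fs.FileFlags{Read: true, Write: true, NonBlocking: true,
+// Append: true}.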
+
+// newFile returns a new fs.File.
+func newFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
+ if !iops.ReturnsWouldBlock() {
+ // Allow reading/writing at an arbitrary offset for files
+ // that support it.
+ flags.Pread = true
+ flags.Pwrite = true
+ }
+ return fs.NewFile(ctx, dirent, flags, &fileOperations{iops: iops})
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (f *fileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ f.iops.fileState.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(f.iops.fileState.FD()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (f *fileOperations) EventUnregister(e *waiter.Entry) {
+ f.iops.fileState.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(f.iops.fileState.FD()))
+}
+
+// Readiness uses the poll() syscall to check the status of the underlying FD.
+func (f *fileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(f.iops.fileState.FD()), mask)
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &f.dirCursor,
+ }
+ return fs.DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+}
+
+// IterateDir implements fs.DirIterator.IterateDir.
+func (f *fileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) {
+ if f.dirinfo == nil {
+ f.dirinfo = new(dirInfo)
+ f.dirinfo.buf = make([]byte, usermem.PageSize)
+ }
+ entries, err := f.iops.readdirAll(f.dirinfo)
+ if err != nil {
+ return offset, err
+ }
+ count, err := fs.GenericReaddir(dirCtx, fs.NewSortedDentryMap(entries))
+ return offset + count, err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ // Would this file block?
+ if f.iops.ReturnsWouldBlock() {
+ // These files can't be memory mapped, assert this. This also
+ // means that writes do not need to synchronize with memory
+ // mappings nor metadata cached by this file's fs.Inode.
+ if canMap(file.Dirent.Inode) {
+ panic("files that can return EWOULDBLOCK cannot be memory mapped")
+ }
+ // Ignore the offset, these files don't support writing at
+ // an arbitrary offset.
+ writer := fd.NewReadWriter(f.iops.fileState.FD())
+ n, err := src.CopyInTo(ctx, safemem.FromIOWriter{writer})
+ if isBlockError(err) {
+ err = syserror.ErrWouldBlock
+ }
+ return n, err
+ }
+ if !file.Dirent.Inode.MountSource.Flags.ForcePageCache {
+ writer := secio.NewOffsetWriter(fd.NewReadWriter(f.iops.fileState.FD()), offset)
+ return src.CopyInTo(ctx, safemem.FromIOWriter{writer})
+ }
+ return f.iops.cachingInodeOps.Write(ctx, src, offset)
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ // Would this file block?
+ if f.iops.ReturnsWouldBlock() {
+		// These files can't be memory mapped; assert this. This also
+		// means that reads do not need to synchronize with memory
+		// mappings or metadata cached by this file's fs.Inode.
+ if canMap(file.Dirent.Inode) {
+ panic("files that can return EWOULDBLOCK cannot be memory mapped")
+ }
+ // Ignore the offset, these files don't support reading at
+ // an arbitrary offset.
+ reader := fd.NewReadWriter(f.iops.fileState.FD())
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{reader})
+ if isBlockError(err) {
+ // If we got any data at all, return it as a "completed" partial read
+ // rather than retrying until complete.
+ if n != 0 {
+ err = nil
+ } else {
+ err = syserror.ErrWouldBlock
+ }
+ }
+ return n, err
+ }
+ if !file.Dirent.Inode.MountSource.Flags.ForcePageCache {
+ reader := secio.NewOffsetReader(fd.NewReadWriter(f.iops.fileState.FD()), offset)
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{reader})
+ }
+ return f.iops.cachingInodeOps.Read(ctx, file, dst, offset)
+}
+
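Read and Write split on ReturnsWouldBlock: host FDs that can block (pipes, sockets, TTYs) are not seekable, so the sentry offset is ignored, while regular files use positioned I/O that never moves the kernel's file offset. A small sketch of that distinction, assuming a Linux host:

```go
package main

import (
	"fmt"
	"os"
	"syscall"
)

func main() {
	f, err := os.CreateTemp("", "pread")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	if _, err := f.WriteString("hello world"); err != nil {
		panic(err)
	}

	// Positioned read: the kernel's file offset is left untouched.
	buf := make([]byte, 5)
	if _, err := syscall.Pread(int(f.Fd()), buf, 6); err != nil {
		panic(err)
	}
	fmt.Printf("pread at offset 6: %q\n", buf) // "world"

	// Pipes are not seekable: pread fails with ESPIPE, which is why
	// would-block FDs must take the sequential read path instead.
	var p [2]int
	if err := syscall.Pipe(p[:]); err != nil {
		panic(err)
	}
	defer syscall.Close(p[0])
	defer syscall.Close(p[1])
	if _, err := syscall.Pread(p[0], buf, 0); err == syscall.ESPIPE {
		fmt.Println("pread on a pipe: ESPIPE, as expected")
	}
}
```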
+// Fsync implements fs.FileOperations.Fsync.
+func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error {
+ switch syncType {
+ case fs.SyncAll, fs.SyncData:
+ if err := file.Dirent.Inode.WriteOut(ctx); err != nil {
+ return err
+ }
+ fallthrough
+ case fs.SyncBackingStorage:
+ return syscall.Fsync(f.iops.fileState.FD())
+ }
+ panic("invalid sync type")
+}
+
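The Fsync switch leans on Go's explicit fallthrough: SyncAll and SyncData first write out cached pages and attributes, then fall into the SyncBackingStorage case to fsync the host FD; any other value panics. A reduced sketch of the same control flow (the syncType values below are stand-ins for fs.SyncType, not the patch's types):

```go
package main

import "fmt"

type syncType int

const (
	syncAll syncType = iota
	syncData
	syncBackingStorage
)

func fsync(st syncType) {
	switch st {
	case syncAll, syncData:
		fmt.Println("write out dirty pages and attributes")
		// Go cases don't fall through by default; the keyword makes
		// the next case body run unconditionally.
		fallthrough
	case syncBackingStorage:
		fmt.Println("fsync the backing host FD")
	default:
		panic("invalid sync type")
	}
}

func main() {
	fsync(syncData)           // both steps
	fsync(syncBackingStorage) // fsync only
}
```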
+// Flush implements fs.FileOperations.Flush.
+func (f *fileOperations) Flush(context.Context, *fs.File) error {
+ // This is a no-op because flushing the resource backing this
+ // file would mean closing it. We can't do that because other
+ // open files may depend on the backing host FD.
+ return nil
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (f *fileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ if !canMap(file.Dirent.Inode) {
+ return syserror.ENODEV
+ }
+ return fsutil.GenericConfigureMMap(file, f.iops.cachingInodeOps, opts)
+}
+
+// Seek implements fs.FileOperations.Seek.
+func (f *fileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return fsutil.SeekWithDirCursor(ctx, file, whence, offset, &f.dirCursor)
+}
diff --git a/pkg/sentry/fs/host/host.go b/pkg/sentry/fs/host/host.go
new file mode 100644
index 000000000..081ba1dd8
--- /dev/null
+++ b/pkg/sentry/fs/host/host.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host supports file descriptors imported directly.
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystem is a host filesystem.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+const FilesystemName = "host"
+
+// Name is the name of the filesystem.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// Mount returns an error. Mounting hostfs is not allowed.
+func (*filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, dataObj interface{}) (*fs.Inode, error) {
+ return nil, syserror.EPERM
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList prevents this filesystem from being listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return false
+}
+
+// Flags returns 0, indicating that there is nothing special about this filesystem.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
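host registers itself at package load time but rejects Mount and user mounts entirely, so the registration exists mainly so the filesystem has a name for introspection and save/restore. The init-time registry is a common Go pattern; a generic hedged sketch (the registry and names here are illustrative, not the sentry's fs package):

```go
package main

import "fmt"

// Filesystem mirrors the shape of the interface being registered.
type Filesystem interface{ Name() string }

// registry maps filesystem names to implementations.
var registry = map[string]Filesystem{}

// Register is called from init functions before main runs.
func Register(f Filesystem) { registry[f.Name()] = f }

type hostFS struct{}

func (hostFS) Name() string { return "host" }

func init() { Register(hostFS{}) }

func main() {
	if _, ok := registry["host"]; ok {
		fmt.Println("host filesystem registered before main ran")
	}
}
```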
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
new file mode 100644
index 000000000..fbfba1b58
--- /dev/null
+++ b/pkg/sentry/fs/host/inode.go
@@ -0,0 +1,416 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/secio"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// inodeOperations implements fs.InodeOperations for an fs.Inode backed
+// by a host file descriptor.
+//
+// +stateify savable
+type inodeOperations struct {
+ fsutil.InodeNotVirtual `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+
+ // fileState implements fs.CachedFileObject. It exists
+ // to break a circular load dependency between inodeOperations
+ // and cachingInodeOps (below).
+ fileState *inodeFileState `state:"wait"`
+
+ // cachedInodeOps implements memmap.Mappable.
+ cachingInodeOps *fsutil.CachingInodeOperations
+
+ // readdirMu protects the file offset on the host FD. This is needed
+ // for readdir because getdents must use the kernel offset, so
+ // concurrent readdirs must be exclusive.
+ //
+ // All read/write functions pass the offset directly to the kernel and
+ // thus don't need a lock.
+ readdirMu sync.Mutex `state:"nosave"`
+}
+
+// inodeFileState implements fs.CachedFileObject and otherwise fully
+// encapsulates state that needs to be manually loaded on restore for
+// this file object.
+//
+// This unfortunate structure exists because fsutil.CachingInodeOperations
+// defines afterLoad and therefore cannot be lazily loaded (to break a
+// circular load dependency between it and inodeOperations). Even with
+// lazy loading, this approach defines the dependencies between objects
+// and the expected load behavior more concretely.
+//
+// +stateify savable
+type inodeFileState struct {
+ // descriptor is the backing host FD.
+ descriptor *descriptor `state:"wait"`
+
+ // Event queue for blocking operations.
+ queue waiter.Queue `state:"zerovalue"`
+
+ // sattr is used to restore the inodeOperations.
+ sattr fs.StableAttr `state:"wait"`
+
+ // savedUAttr is only allocated during S/R. It points to the save-time
+ // unstable attributes and is used to validate restore-time ones.
+ //
+ // Note that these unstable attributes are only used to detect cross-S/R
+ // external file system metadata changes. They may differ from the
+ // cached unstable attributes in cachingInodeOps, as that might differ
+ // from the external file system attributes if there had been WriteOut
+ // failures. S/R is transparent to Sentry and the latter will continue
+ // using its cached values after restore.
+ savedUAttr *fs.UnstableAttr
+}
+
+// ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt.
+func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ // TODO(jamieliu): Using safemem.FromIOReader here is wasteful for two
+ // reasons:
+ //
+ // - Using preadv instead of iterated preads saves on host system calls.
+ //
+	// - Host system calls can handle destination memory that would fault in
+	//   guest ring 3 (i.e. they can accept safemem.Blocks with
+	//   NeedSafecopy() == true), so the buffering performed by FromIOReader
+	//   is unnecessary.
+ //
+ // This also applies to the write path below.
+ return safemem.FromIOReader{secio.NewOffsetReader(fd.NewReadWriter(i.FD()), int64(offset))}.ReadToBlocks(dsts)
+}
+
+// WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt.
+func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ return safemem.FromIOWriter{secio.NewOffsetWriter(fd.NewReadWriter(i.FD()), int64(offset))}.WriteFromBlocks(srcs)
+}
+
+// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error {
+ if mask.Empty() {
+ return nil
+ }
+ if mask.UID || mask.GID {
+ return syserror.EPERM
+ }
+ if mask.Perms {
+ if err := syscall.Fchmod(i.FD(), uint32(attr.Perms.LinuxMode())); err != nil {
+ return err
+ }
+ }
+ if mask.Size {
+ if err := syscall.Ftruncate(i.FD(), attr.Size); err != nil {
+ return err
+ }
+ }
+ if mask.AccessTime || mask.ModificationTime {
+ ts := fs.TimeSpec{
+ ATime: attr.AccessTime,
+ ATimeOmit: !mask.AccessTime,
+ MTime: attr.ModificationTime,
+ MTimeOmit: !mask.ModificationTime,
+ }
+ if err := setTimestamps(i.FD(), ts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
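SetMaskedAttributes touches only the attributes named in the mask; for timestamps that means telling the host to leave the unnamed timestamp alone. With utimensat(2) that is expressed by setting tv_nsec to UTIME_OMIT. A sketch via golang.org/x/sys/unix (path-based for brevity; the patch's setTimestamps helper is FD-based and defined elsewhere in this change):

```go
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

// touchMtime sets only the modification time, leaving atime untouched.
func touchMtime(path string, sec int64) error {
	ts := []unix.Timespec{
		{Nsec: unix.UTIME_OMIT}, // atime: leave as-is
		{Sec: sec},              // mtime: set to sec
	}
	return unix.UtimesNanoAt(unix.AT_FDCWD, path, ts, 0)
}

func main() {
	f, err := os.CreateTemp("", "omit")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	if err := touchMtime(f.Name(), 0); err != nil {
		panic(err)
	}
	fi, err := os.Stat(f.Name())
	if err != nil {
		panic(err)
	}
	fmt.Println("mtime now:", fi.ModTime()) // the Unix epoch
}
```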
+// Sync implements fsutil.CachedFileObject.Sync.
+func (i *inodeFileState) Sync(ctx context.Context) error {
+ return syscall.Fsync(i.FD())
+}
+
+// FD implements fsutil.CachedFileObject.FD.
+func (i *inodeFileState) FD() int {
+ return i.descriptor.value
+}
+
+func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.FD(), &s); err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ return unstableAttr(&s), nil
+}
+
+// Allocate implements fsutil.CachedFileObject.Allocate.
+func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
+ return syscall.Fallocate(i.FD(), 0, offset, length)
+}
+
+// inodeOperations implements fs.InodeOperations.
+var _ fs.InodeOperations = (*inodeOperations)(nil)
+
+// newInode returns a new fs.Inode backed by the host FD.
+func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool) (*fs.Inode, error) {
+ // Retrieve metadata.
+ var s syscall.Stat_t
+ err := syscall.Fstat(fd, &s)
+ if err != nil {
+ return nil, err
+ }
+
+ fileState := &inodeFileState{
+ sattr: stableAttr(&s),
+ }
+
+ // Initialize the wrapped host file descriptor.
+ fileState.descriptor, err = newDescriptor(fd, saveable, wouldBlock(&s), &fileState.queue)
+ if err != nil {
+ return nil, err
+ }
+
+ // Build the fs.InodeOperations.
+ uattr := unstableAttr(&s)
+ iops := &inodeOperations{
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: msrc.Flags.ForcePageCache,
+ }),
+ }
+
+ // Return the fs.Inode.
+ return fs.NewInode(ctx, iops, msrc, fileState.sattr), nil
+}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (i *inodeOperations) Mappable(inode *fs.Inode) memmap.Mappable {
+ if !canMap(inode) {
+ return nil
+ }
+ return i.cachingInodeOps
+}
+
+// ReturnsWouldBlock returns true if this host FD can return EWOULDBLOCK for
+// operations that would block.
+func (i *inodeOperations) ReturnsWouldBlock() bool {
+ return i.fileState.descriptor.wouldBlock
+}
+
+// Release implements fs.InodeOperations.Release.
+func (i *inodeOperations) Release(context.Context) {
+ i.fileState.descriptor.Release()
+ i.cachingInodeOps.Release()
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ return nil, syserror.ENOENT
+}
+
+// Create implements fs.InodeOperations.Create.
+func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
+	return nil, syserror.EPERM
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ return syserror.EPERM
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
+ return syserror.EPERM
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+func (*inodeOperations) CreateHardLink(context.Context, *fs.Inode, *fs.Inode, string) error {
+ return syserror.EPERM
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (*inodeOperations) CreateFifo(context.Context, *fs.Inode, string, fs.FilePermissions) error {
+ return syserror.EPERM
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
+ return syserror.EPERM
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
+ return syserror.EPERM
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return syserror.EPERM
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) {
+ return nil, syserror.EOPNOTSUPP
+}
+
+// BoundEndpoint implements fs.InodeOperations.BoundEndpoint.
+func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.BoundEndpoint {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return newFile(ctx, d, flags, i), nil
+}
+
+// canMap returns true if this fs.Inode can be memory mapped.
+func canMap(inode *fs.Inode) bool {
+ // FIXME(b/38213152): Some obscure character devices can be mapped.
+ return fs.IsFile(inode.StableAttr)
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *inodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ // When the kernel supports mapping host FDs, we do so to take
+ // advantage of the host page cache. We forego updating fs.Inodes
+ // because the host manages consistency of its own inode structures.
+ //
+ // For fs.Inodes that can never be mapped we take advantage of
+ // synchronizing metadata updates through host caches.
+ //
+ // So can we use host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+ // Then just obtain the attributes.
+ return i.fileState.unstableAttr(ctx)
+ }
+ // No, we're maintaining consistency of metadata ourselves.
+ return i.cachingInodeOps.UnstableAttr(ctx, inode)
+}
+
+// Check implements fs.InodeOperations.Check.
+func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (i *inodeOperations) SetOwner(context.Context, *fs.Inode, fs.FileOwner) error {
+ return syserror.EPERM
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (i *inodeOperations) SetPermissions(ctx context.Context, inode *fs.Inode, f fs.FilePermissions) bool {
+ // Can we use host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+		// Then just change the permissions on the FD; the host
+ // will synchronize the metadata update with any host
+ // inode and page cache.
+ return syscall.Fchmod(i.fileState.FD(), uint32(f.LinuxMode())) == nil
+ }
+ // Otherwise update our cached metadata.
+ return i.cachingInodeOps.SetPermissions(ctx, inode, f)
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (i *inodeOperations) SetTimestamps(ctx context.Context, inode *fs.Inode, ts fs.TimeSpec) error {
+ // Can we use host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+		// Then just change the timestamps on the FD; the host
+ // will synchronize the metadata update with any host
+ // inode and page cache.
+ return setTimestamps(i.fileState.FD(), ts)
+ }
+ // Otherwise update our cached metadata.
+ return i.cachingInodeOps.SetTimestamps(ctx, inode, ts)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, size int64) error {
+ // Is the file not memory-mappable?
+ if !canMap(inode) {
+		// Then just change the file size on the FD; the host
+ // will synchronize the metadata update with any host
+ // inode and page cache.
+ return syscall.Ftruncate(i.fileState.FD(), size)
+ }
+ // Otherwise we need to go through cachingInodeOps, even if the host page
+ // cache is in use, to invalidate private copies of truncated pages.
+ return i.cachingInodeOps.Truncate(ctx, inode, size)
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
+ // Is the file not memory-mappable?
+ if !canMap(inode) {
+		// Then just send the call to the FD; the host will synchronize the metadata
+ // update with any host inode and page cache.
+ return i.fileState.Allocate(ctx, offset, length)
+ }
+ // Otherwise we need to go through cachingInodeOps, even if the host page
+	// cache is in use, to keep the cached file size consistent.
+ return i.cachingInodeOps.Allocate(ctx, offset, length)
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if inode.MountSource.Flags.ReadOnly {
+ return nil
+ }
+ // Have we been using host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+ // Then the metadata is already up to date on the host.
+ return nil
+ }
+ // Otherwise we need to write out cached pages and attributes
+ // that are dirty.
+ return i.cachingInodeOps.WriteOut(ctx, inode)
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (i *inodeOperations) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ return readLink(i.fileState.FD())
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (i *inodeOperations) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ if !fs.IsSymlink(i.fileState.sattr) {
+ return nil, syserror.ENOLINK
+ }
+ return nil, fs.ErrResolveViaReadlink
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) {
+ return fs.Info{}, syserror.ENOSYS
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (i *inodeOperations) AddLink() {}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (i *inodeOperations) DropLink() {}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
+
+// readdirAll returns all of the directory entries in i.
+func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) {
+ // We only support non-directory file descriptors that have been
+ // imported, so just claim that this isn't a directory, even if it is.
+ return nil, syscall.ENOTDIR
+}
diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go
new file mode 100644
index 000000000..1adbd4562
--- /dev/null
+++ b/pkg/sentry/fs/host/inode_state.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// afterLoad is invoked by stateify.
+func (i *inodeFileState) afterLoad() {
+ // Initialize the descriptor value.
+ if err := i.descriptor.initAfterLoad(i.sattr.InodeID, &i.queue); err != nil {
+ panic(fmt.Sprintf("failed to load value of descriptor: %v", err))
+ }
+
+ // Remap the inode number.
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.FD(), &s); err != nil {
+ panic(fs.ErrCorruption{fmt.Errorf("failed to get metadata for fd %d: %v", i.FD(), err)})
+ }
+ key := device.MultiDeviceKey{
+ Device: s.Dev,
+ Inode: s.Ino,
+ }
+ if !hostFileDevice.Load(key, i.sattr.InodeID) {
+ // This means there was a conflict at s.Dev and s.Ino with
+ // another inode mapping: two files that were unique on the
+ // saved filesystem are no longer unique on this filesystem.
+ // Since this violates the contract that filesystems cannot
+ // change across save and restore, error out.
+ panic(fs.ErrCorruption{fmt.Errorf("host %s conflict in host device mappings: %s", key, hostFileDevice)})
+ }
+}
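The restore path re-stats the FD and re-keys it by host (device, inode) pair; a collision means two files that were distinct at save time now alias a single host inode, which violates the restore contract. The key itself is ordinary stat(2) data, as in this sketch:

```go
package main

import (
	"fmt"
	"os"
	"syscall"
)

// devIno is a host file identity: unique per (device, inode) pair.
type devIno struct {
	dev, ino uint64
}

// key returns the identity of the file backing fd.
func key(fd int) (devIno, error) {
	var s syscall.Stat_t
	if err := syscall.Fstat(fd, &s); err != nil {
		return devIno{}, err
	}
	return devIno{dev: s.Dev, ino: s.Ino}, nil
}

func main() {
	f, err := os.CreateTemp("", "ino")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	g, err := os.Open(f.Name())
	if err != nil {
		panic(err)
	}
	defer g.Close()

	// Two FDs on the same file share one (dev, ino) identity.
	k1, _ := key(int(f.Fd()))
	k2, _ := key(int(g.Fd()))
	fmt.Println("same inode:", k1 == k2) // true
}
```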
diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go
new file mode 100644
index 000000000..c507f57eb
--- /dev/null
+++ b/pkg/sentry/fs/host/inode_test.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+)
+
+// TestCloseFD verifies fds will be closed.
+func TestCloseFD(t *testing.T) {
+ var p [2]int
+ if err := syscall.Pipe(p[0:]); err != nil {
+ t.Fatalf("Failed to create pipe %v", err)
+ }
+ defer syscall.Close(p[0])
+ defer syscall.Close(p[1])
+
+ // Use the write-end because we will detect if it's closed on the read end.
+ ctx := contexttest.Context(t)
+ file, err := NewFile(ctx, p[1])
+ if err != nil {
+ t.Fatalf("Failed to create File: %v", err)
+ }
+ file.DecRef()
+
+ s := make([]byte, 10)
+ if c, err := syscall.Read(p[0], s); c != 0 || err != nil {
+ t.Errorf("want 0, nil (EOF) from read end, got %v, %v", c, err)
+ }
+}
diff --git a/pkg/sentry/fs/host/ioctl_unsafe.go b/pkg/sentry/fs/host/ioctl_unsafe.go
new file mode 100644
index 000000000..150ac8e19
--- /dev/null
+++ b/pkg/sentry/fs/host/ioctl_unsafe.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// LINT.IfChange
+
+func ioctlGetTermios(fd int) (*linux.Termios, error) {
+ var t linux.Termios
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &t, nil
+}
+
+func ioctlSetTermios(fd int, req uint64, t *linux.Termios) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(t)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func ioctlGetWinsize(fd int) (*linux.Winsize, error) {
+ var w linux.Winsize
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCGWINSZ, uintptr(unsafe.Pointer(&w)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &w, nil
+}
+
+func ioctlSetWinsize(fd int, w *linux.Winsize) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCSWINSZ, uintptr(unsafe.Pointer(w)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// LINT.ThenChange(../../fsimpl/host/ioctl_unsafe.go)
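These wrappers issue raw ioctls with pointers to linux.Termios and linux.Winsize. Outside the sentry, golang.org/x/sys/unix wraps the same ioctl numbers; a usage sketch (run it on a real terminal):

```go
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	fd := int(os.Stdout.Fd())

	// TIOCGWINSZ: read the terminal window size.
	ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
	if err != nil {
		fmt.Fprintln(os.Stderr, "not a tty:", err)
		return
	}
	fmt.Printf("%d rows x %d cols\n", ws.Row, ws.Col)

	// TCGETS: read the termios struct. TCSETS/TCSETSW/TCSETSF write it
	// back with different drain/flush semantics.
	tio, err := unix.IoctlGetTermios(fd, unix.TCGETS)
	if err != nil {
		fmt.Fprintln(os.Stderr, "tcgetattr:", err)
		return
	}
	fmt.Println("canonical mode:", tio.Lflag&unix.ICANON != 0)
}
```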
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
new file mode 100644
index 000000000..cfb089e43
--- /dev/null
+++ b/pkg/sentry/fs/host/socket.go
@@ -0,0 +1,384 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// ConnectedEndpoint is a host FD backed implementation of
+// transport.ConnectedEndpoint and transport.Receiver.
+//
+// +stateify savable
+type ConnectedEndpoint struct {
+ // ref keeps track of references to a connectedEndpoint.
+ ref refs.AtomicRefCount
+
+ queue *waiter.Queue
+ path string
+
+	// If srfd >= 0, it is the host FD that the file was imported from.
+ srfd int `state:"wait"`
+
+ // stype is the type of Unix socket.
+ stype linux.SockType
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int64 `state:"nosave"`
+
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // file is an *fd.FD containing the FD backing this endpoint. It must be
+ // set to nil if it has been closed.
+ file *fd.FD `state:"nosave"`
+}
+
+// init performs initialization required for creating new ConnectedEndpoints and
+// for restoring them.
+func (c *ConnectedEndpoint) init() *syserr.Error {
+ family, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if family != syscall.AF_UNIX {
+ // We only allow Unix sockets.
+ return syserr.ErrInvalidEndpointState
+ }
+
+ stype, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if err := syscall.SetNonblock(c.file.FD(), true); err != nil {
+ return syserr.FromError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ c.stype = linux.SockType(stype)
+ c.sndbuf = int64(sndbuf)
+
+ return nil
+}
+
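init validates that a donated FD really is a Unix socket and snapshots its type and send-buffer size before first use; SO_DOMAIN, SO_TYPE and SO_SNDBUF are plain getsockopt(2) queries. The same introspection with the standard library:

```go
package main

import (
	"fmt"
	"syscall"
)

func mustGetsockopt(fd, opt int) int {
	v, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, opt)
	if err != nil {
		panic(err)
	}
	return v
}

func main() {
	pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(pair[0])
	defer syscall.Close(pair[1])

	// The same three queries init performs on a donated FD.
	fd := pair[0]
	fmt.Println("unix family:", mustGetsockopt(fd, syscall.SO_DOMAIN) == syscall.AF_UNIX)
	fmt.Println("stream type:", mustGetsockopt(fd, syscall.SO_TYPE) == syscall.SOCK_STREAM)
	// Note: the kernel roughly doubles the requested SO_SNDBUF value.
	fmt.Println("send buffer:", mustGetsockopt(fd, syscall.SO_SNDBUF), "bytes")
}
```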
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD
+// that will pretend to be bound at a given sentry path.
+//
+// The caller is responsible for calling Init(). Additionally, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewConnectedEndpoint(ctx context.Context, file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *syserr.Error) {
+ e := ConnectedEndpoint{
+ path: path,
+ queue: queue,
+ file: file,
+ srfd: -1,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+
+ e.ref.EnableLeakCheck("host.ConnectedEndpoint")
+
+ return &e, nil
+}
+
+// Init will do initialization required without holding other locks.
+func (c *ConnectedEndpoint) Init() {
+ if err := fdnotifier.AddFD(int32(c.file.FD()), c.queue); err != nil {
+ panic(err)
+ }
+}
+
+// NewSocketWithDirent allocates a new unix socket with a host endpoint.
+//
+// This is currently only used by unsaveable Gofer nodes.
+//
+// NewSocketWithDirent takes ownership of f on success.
+func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.FileFlags) (*fs.File, error) {
+ f2 := fd.New(f.FD())
+ var q waiter.Queue
+ e, err := NewConnectedEndpoint(ctx, f2, &q, "" /* path */)
+ if err != nil {
+ f2.Release()
+ return nil, err.ToError()
+ }
+
+	// Take ownership of the FD.
+ f.Release()
+
+ e.Init()
+
+ ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
+
+ return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil
+}
+
+// newSocket allocates a new unix socket with a host endpoint.
+func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) {
+ ownedfd := orgfd
+ srfd := -1
+ if saveable {
+ var err error
+ ownedfd, err = syscall.Dup(orgfd)
+ if err != nil {
+ return nil, err
+ }
+ srfd = orgfd
+ }
+ f := fd.New(ownedfd)
+ var q waiter.Queue
+ e, err := NewConnectedEndpoint(ctx, f, &q, "" /* path */)
+ if err != nil {
+ if saveable {
+ f.Close()
+ } else {
+ f.Release()
+ }
+ return nil, err.ToError()
+ }
+
+ e.srfd = srfd
+ e.Init()
+
+ ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
+
+ return unixsocket.New(ctx, ep, e.stype), nil
+}
+
+// Send implements transport.ConnectedEndpoint.Send.
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if !controlMessages.Empty() {
+ return 0, false, syserr.ErrInvalidEndpointState
+ }
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == linux.SOCK_STREAM
+
+ n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
+ if n > 0 && err != syserror.EAGAIN {
+ // The caller may need to block to send more data, but
+ // otherwise there isn't anything that can be done about an
+ // error with a partial write.
+ err = nil
+ }
+
+	// There is no need for the caller to call SendNotify because fdWriteVec
+ // uses the host's sendmsg(2) and the host kernel's queue.
+ return n, false, syserr.FromError(err)
+}
+
+// SendNotify implements transport.ConnectedEndpoint.SendNotify.
+func (c *ConnectedEndpoint) SendNotify() {}
+
+// CloseSend implements transport.ConnectedEndpoint.CloseSend.
+func (c *ConnectedEndpoint) CloseSend() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
+func (c *ConnectedEndpoint) CloseNotify() {}
+
+// Writable implements transport.ConnectedEndpoint.Writable.
+func (c *ConnectedEndpoint) Writable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements transport.ConnectedEndpoint.Passcred.
+func (c *ConnectedEndpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil
+}
+
+// EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
+func (c *ConnectedEndpoint) EventUpdate() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.file.FD() != -1 {
+ fdnotifier.UpdateFD(int32(c.file.FD()))
+ }
+}
+
+// Recv implements transport.Receiver.Recv.
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ var cm unet.ControlMessage
+ if numRights > 0 {
+ cm.EnableFDs(int(numRights))
+ }
+
+	// N.B. Unix sockets don't have a receive buffer; the send buffer
+ // serves both purposes.
+ rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf)
+ if rl > 0 && err != nil {
+ // We got some data, so all we need to do on error is return
+ // the data that we got. Short reads are fine, no need to
+ // block.
+ err = nil
+ }
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+	// There is no need for the caller to call RecvNotify because fdReadVec uses
+ // the host's recvmsg(2) and the host kernel's queue.
+
+ // Trim the control data if we received less than the full amount.
+ if cl < uint64(len(cm)) {
+ cm = cm[:cl]
+ }
+
+ // Avoid extra allocations in the case where there isn't any control data.
+ if len(cm) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ }
+
+ fds, err := cm.ExtractFDs()
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ if len(fds) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ }
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+}
+
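Recv pulls payload and ancillary data in one recvmsg(2); nonzero control bytes decode to SCM_RIGHTS, i.e. FDs donated across the socket. The standard library can express the same receive side; a self-contained sketch:

```go
package main

import (
	"fmt"
	"syscall"
)

func main() {
	pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(pair[0])
	defer syscall.Close(pair[1])

	// Send one payload byte plus stdout's FD as SCM_RIGHTS control data.
	rights := syscall.UnixRights(1 /* stdout */)
	if err := syscall.Sendmsg(pair[0], []byte{'x'}, rights, nil, 0); err != nil {
		panic(err)
	}

	buf := make([]byte, 1)
	oob := make([]byte, 128)
	_, oobn, _, _, err := syscall.Recvmsg(pair[1], buf, oob, 0)
	if err != nil {
		panic(err)
	}
	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		panic(err)
	}
	fds, err := syscall.ParseUnixRights(&msgs[0])
	if err != nil {
		panic(err)
	}
	fmt.Println("received FDs:", fds) // a fresh FD aliasing stdout
}
```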
+// close releases all resources related to the endpoint.
+func (c *ConnectedEndpoint) close() {
+ fdnotifier.RemoveFD(int32(c.file.FD()))
+ c.file.Close()
+ c.file = nil
+}
+
+// RecvNotify implements transport.Receiver.RecvNotify.
+func (c *ConnectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements transport.Receiver.CloseRecv.
+func (c *ConnectedEndpoint) CloseRecv() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// Readable implements transport.Receiver.Readable.
+func (c *ConnectedEndpoint) Readable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements transport.Receiver.SendQueuedSize.
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
+ return int64(c.sndbuf)
+}
+
+// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
+}
+
+// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
+func (c *ConnectedEndpoint) Release() {
+ c.ref.DecRefWithDestructor(c.close)
+}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
+
+// LINT.ThenChange(../../fsimpl/host/socket.go)
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
new file mode 100644
index 000000000..5c18dbd5e
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -0,0 +1,117 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// LINT.IfChange
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate is true, bufs whose total length exceeds maxlen are truncated;
+// otherwise an error is returned immediately.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += int64(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > maxlen {
+ if truncate {
+ stopLen = maxlen
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total int64
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := int64(l)
+ if total+stop > stopLen {
+ stop = stopLen - total
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += stop
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
+
+// LINT.ThenChange(../../fsimpl/host/socket_iovec.go)
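buildIovec skips empty buffers, truncates the total at maxlen, and collapses to a single intermediate buffer when the iovec count would exceed UIO_MAXIOV; the result is fed straight to sendmsg/recvmsg. A sketch of the gathering write this enables, using unix.Writev, which assembles the iovec array from a [][]byte by itself:

```go
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.CreateTemp("", "iov")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())

	// One writev(2) call gathers all the buffers; the zero-length slice
	// contributes nothing, mirroring buildIovec's skip of empty bufs.
	bufs := [][]byte{[]byte("hello "), {}, []byte("iovec world\n")}
	n, err := unix.Writev(int(f.Fd()), bufs)
	if err != nil {
		panic(err)
	}
	fmt.Println("wrote", n, "bytes in a single syscall")
}
```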
diff --git a/pkg/sentry/fs/host/socket_state.go b/pkg/sentry/fs/host/socket_state.go
new file mode 100644
index 000000000..498018f0a
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_state.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+)
+
+// beforeSave is invoked by stateify.
+func (c *ConnectedEndpoint) beforeSave() {
+ if c.srfd < 0 {
+ panic("only host file descriptors provided at sentry startup can be saved")
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (c *ConnectedEndpoint) afterLoad() {
+ f, err := syscall.Dup(c.srfd)
+ if err != nil {
+ panic(fmt.Sprintf("failed to dup restored FD %d: %v", c.srfd, err))
+ }
+ c.file = fd.New(f)
+ if err := c.init(); err != nil {
+ panic(fmt.Sprintf("Could not restore host socket FD %d: %v", c.srfd, err))
+ }
+ c.Init()
+}
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
new file mode 100644
index 000000000..affdbcacb
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -0,0 +1,246 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "reflect"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ // Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint.
+ _ = transport.ConnectedEndpoint(new(ConnectedEndpoint))
+
+ // Make sure that ConnectedEndpoint implements transport.Receiver.
+ _ = transport.Receiver(new(ConnectedEndpoint))
+)
+
+func getFl(fd int) (uint32, error) {
+ fl, _, err := syscall.RawSyscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0)
+ if err == 0 {
+ return uint32(fl), nil
+ }
+ return 0, err
+}
+
+func TestSocketIsBlocking(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+
+ fl, err := getFl(pair[0])
+ if err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
+ }
+ if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
+ t.Fatalf("Expected socket %v to be blocking", pair[0])
+ }
+ if fl, err = getFl(pair[1]); err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err)
+ }
+ if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
+ t.Fatalf("Expected socket %v to be blocking", pair[1])
+ }
+ sock, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) failed => %v", pair[0], err)
+ }
+ defer sock.DecRef()
+ // Test that the socket now is non-blocking.
+ if fl, err = getFl(pair[0]); err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
+ }
+ if fl&syscall.O_NONBLOCK != syscall.O_NONBLOCK {
+ t.Errorf("Expected socket %v to have become non-blocking", pair[0])
+ }
+ if fl, err = getFl(pair[1]); err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err)
+ }
+ if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
+ t.Errorf("Did not expect socket %v to become non-blocking", pair[1])
+ }
+}
+
+func TestSocketWritev(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+ socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer socket.DecRef()
+ buf := []byte("hello world\n")
+ n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf))
+ if err != nil {
+ t.Fatalf("socket writev failed: %v", err)
+ }
+
+ if n != int64(len(buf)) {
+ t.Fatalf("socket writev wrote incorrect bytes: %d", n)
+ }
+}
+
+func TestSocketWritevLen0(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+ socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer socket.DecRef()
+ n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil))
+ if err != nil {
+ t.Fatalf("socket writev failed: %v", err)
+ }
+
+ if n != 0 {
+ t.Fatalf("socket writev wrote incorrect bytes: %d", n)
+ }
+}
+
+func TestSocketSendMsgLen0(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+ sfile, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer sfile.DecRef()
+
+ s := sfile.FileOperations.(socket.Socket)
+ n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{})
+ if n != 0 {
+ t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n)
+ }
+
+ if terr != nil {
+ t.Fatalf("socket sendmsg() failed: %v", terr)
+ }
+}
+
+func TestListen(t *testing.T) {
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
+ }
+ sfile1, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer sfile1.DecRef()
+ socket1 := sfile1.FileOperations.(socket.Socket)
+
+ sfile2, err := newSocket(contexttest.Context(t), pair[1], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[1], err)
+ }
+ defer sfile2.DecRef()
+ socket2 := sfile2.FileOperations.(socket.Socket)
+
+	// Socketpairs cannot be listened on.
+ if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
+ t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
+ }
+ if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
+ t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
+ }
+
+ // Create a Unix socket, do not bind it.
+ sock, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
+ }
+ sfile3, err := newSocket(contexttest.Context(t), sock, false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", sock, err)
+ }
+ defer sfile3.DecRef()
+ socket3 := sfile3.FileOperations.(socket.Socket)
+
+ // This socket is not bound so we can't listen on it.
+ if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
+ t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
+ }
+}
+
+func TestPasscred(t *testing.T) {
+ e := &ConnectedEndpoint{}
+ if got, want := e.Passcred(), false; got != want {
+ t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
+ }
+}
+
+func TestGetLocalAddress(t *testing.T) {
+ e := &ConnectedEndpoint{path: "foo"}
+ want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
+ if got, err := e.GetLocalAddress(); err != nil || got != want {
+ t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
+ }
+}
+
+func TestQueuedSize(t *testing.T) {
+ e := &ConnectedEndpoint{}
+ tests := []struct {
+ name string
+ f func() int64
+ }{
+ {"SendQueuedSize", e.SendQueuedSize},
+ {"RecvQueuedSize", e.RecvQueuedSize},
+ }
+
+ for _, test := range tests {
+ if got, want := test.f(), int64(-1); got != want {
+ t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want)
+ }
+ }
+}
+
+func TestRelease(t *testing.T) {
+ f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ want := &ConnectedEndpoint{queue: c.queue}
+ want.ref.DecRef()
+ fdnotifier.AddFD(int32(c.file.FD()), nil)
+ c.Release()
+ if !reflect.DeepEqual(c, want) {
+ t.Errorf("got = %#v, want = %#v", c, want)
+ }
+}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
new file mode 100644
index 000000000..5d4f312cf
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// LINT.IfChange
+
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) {
+ flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
+ if peek {
+ flags |= syscall.MSG_PEEK
+ }
+
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, false, err
+ }
+
+ var msg syscall.Msghdr
+ if len(control) != 0 {
+ msg.Control = &control[0]
+ msg.Controllen = uint64(len(control))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, 0, 0, false, e
+ }
+ n := int64(rawN)
+
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
+ controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
+ if n > length {
+ return length, n, msg.Controllen, controlTrunc, err
+ }
+
+ return n, n, msg.Controllen, controlTrunc, err
+}
+
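fdReadVec always passes MSG_TRUNC, so the host reports the full message length even when the supplied buffers are too small; that is how callers can distinguish readLen (bytes copied) from msgLen (bytes in the message). The behavior is observable on a plain datagram socketpair:

```go
package main

import (
	"fmt"
	"syscall"
)

func main() {
	pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(pair[0])
	defer syscall.Close(pair[1])

	if err := syscall.Sendmsg(pair[0], []byte("0123456789"), nil, nil, 0); err != nil {
		panic(err)
	}

	// Undersized buffer + MSG_TRUNC: n is the real datagram length (10),
	// while only len(buf) bytes were copied; the returned flags carry
	// MSG_TRUNC to mark the truncation.
	buf := make([]byte, 4)
	n, _, flags, _, err := syscall.Recvmsg(pair[1], buf, nil, syscall.MSG_TRUNC)
	if err != nil {
		panic(err)
	}
	fmt.Printf("msgLen=%d readLen=%d truncated=%v\n",
		n, len(buf), flags&syscall.MSG_TRUNC != 0)
}
```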
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
+
+ var msg syscall.Msghdr
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
+ }
+
+ return int64(n), length, err
+}
+
+// LINT.ThenChange(../../fsimpl/host/socket_unsafe.go)
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
new file mode 100644
index 000000000..82a02fcb2
--- /dev/null
+++ b/pkg/sentry/fs/host/tty.go
@@ -0,0 +1,364 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// TTYFileOperations implements fs.FileOperations for a host file descriptor
+// that wraps a TTY FD.
+//
+// +stateify savable
+type TTYFileOperations struct {
+ fileOperations
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // session is the session attached to this TTYFileOperations.
+ session *kernel.Session
+
+ // fgProcessGroup is the foreground process group that is currently
+ // connected to this TTY.
+ fgProcessGroup *kernel.ProcessGroup
+
+ // termios contains the terminal attributes for this TTY.
+ termios linux.KernelTermios
+}
+
+// newTTYFile returns a new fs.File that wraps a TTY FD.
+func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
+ return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
+ fileOperations: fileOperations{iops: iops},
+ termios: linux.DefaultSlaveTermios,
+ })
+}
+
+// InitForegroundProcessGroup sets the foreground process group and session for
+// the TTY. This should only be called once, after the foreground process group
+// has been created, but before it has started running.
+func (t *TTYFileOperations) InitForegroundProcessGroup(pg *kernel.ProcessGroup) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.fgProcessGroup != nil {
+ panic("foreground process group is already set")
+ }
+ t.fgProcessGroup = pg
+ t.session = pg.Session()
+}
+
+// ForegroundProcessGroup returns the foreground process group for the TTY.
+func (t *TTYFileOperations) ForegroundProcessGroup() *kernel.ProcessGroup {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.fgProcessGroup
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+//
+// See drivers/tty/n_tty.c:n_tty_read()=>job_control().
+func (t *TTYFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileOperations.Read(ctx, file, dst, offset)
+}
+
+// Write implements fs.FileOperations.Write.
+func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileOperations.Write(ctx, file, src, offset)
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *TTYFileOperations) Release() {
+ t.mu.Lock()
+ t.fgProcessGroup = nil
+ t.mu.Unlock()
+
+ t.fileOperations.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Ignore arg[0]. This is the real FD:
+ fd := t.fileOperations.iops.fileState.FD()
+ ioctl := args[1].Uint64()
+ switch ioctl {
+ case linux.TCGETS:
+ termios, err := ioctlGetTermios(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TCSETS, linux.TCSETSW, linux.TCSETSF:
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+
+ var termios linux.Termios
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
+ return 0, err
+
+ case linux.TIOCGPGRP:
+ // Args: pid_t *argp
+ // When successful, equivalent to *argp = tcgetpgrp(fd).
+ // Get the process group ID of the foreground process group on
+ // this terminal.
+
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Map the ProcessGroup into a ProcessGroupID in the task's PID
+ // namespace.
+ pgID := pidns.IDOfProcessGroup(t.fgProcessGroup)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSPGRP:
+ // Args: const pid_t *argp
+ // Equivalent to tcsetpgrp(fd, *argp).
+ // Set the foreground process group ID of this terminal.
+
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check that we are allowed to set the process group.
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from
+ // tty_check_change() to -ENOTTY.
+ if err == syserror.EIO {
+ return 0, syserror.ENOTTY
+ }
+ return 0, err
+ }
+
+ // Check that calling task's process group is in the TTY
+ // session.
+ if task.ThreadGroup().Session() != t.session {
+ return 0, syserror.ENOTTY
+ }
+
+ var pgID kernel.ProcessGroupID
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ // pgID must be non-negative.
+ if pgID < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Process group with pgID must exist in this PID namespace.
+ pidns := task.PIDNamespace()
+ pg := pidns.ProcessGroupWithID(pgID)
+ if pg == nil {
+ return 0, syserror.ESRCH
+ }
+
+ // Check that new process group is in the TTY session.
+ if pg.Session() != t.session {
+ return 0, syserror.EPERM
+ }
+
+ t.fgProcessGroup = pg
+ return 0, nil
+
+ case linux.TIOCGWINSZ:
+ // Args: struct winsize *argp
+ // Get window size.
+ winsize, err := ioctlGetWinsize(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSWINSZ:
+ // Args: const struct winsize *argp
+ // Set window size.
+
+ // Unlike setting the termios, any process group (even
+ // background ones) can set the winsize.
+
+ var winsize linux.Winsize
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetWinsize(fd, &winsize)
+ return 0, err
+
+ // Unimplemented commands.
+ case linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ fallthrough
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// checkChange checks that the process group is allowed to read, write, or
+// change the state of the TTY.
+//
+// This corresponds to Linux drivers/tty/tty_io.c:tty_check_change(). The logic
+// is a bit convoluted, but documented inline.
+//
+// Preconditions: t.mu must be held.
+func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ // No task? Linux does not have an analog for this case, but
+ // tty_check_change only blocks specific cases and is
+ // surprisingly permissive. Allowing the change seems
+ // appropriate.
+ return nil
+ }
+
+ tg := task.ThreadGroup()
+ pg := tg.ProcessGroup()
+
+ // If the session for the task is different from the session for the
+ // controlling TTY, then the change is allowed. Seems like a bad idea,
+ // but that's exactly what Linux does.
+ if tg.Session() != t.fgProcessGroup.Session() {
+ return nil
+ }
+
+ // If we are the foreground process group, then the change is allowed.
+ if pg == t.fgProcessGroup {
+ return nil
+ }
+
+ // We are not the foreground process group.
+
+ // Is the provided signal blocked or ignored?
+ if (task.SignalMask()&linux.SignalSetOf(sig) != 0) || tg.SignalHandlers().IsIgnored(sig) {
+ // If the signal is SIGTTIN, then we are attempting to read
+ // from the TTY. Don't send the signal and return EIO.
+ if sig == linux.SIGTTIN {
+ return syserror.EIO
+ }
+
+ // Otherwise, we are writing or changing terminal state. This is allowed.
+ return nil
+ }
+
+ // If the process group is an orphan, return EIO.
+ if pg.IsOrphan() {
+ return syserror.EIO
+ }
+
+ // Otherwise, send the signal to the process group and return ERESTARTSYS.
+ //
+ // Note that Linux also unconditionally sets TIF_SIGPENDING on current,
+ // but this isn't necessary in gVisor because the rationale given in
+ // 040b6362d58f "tty: fix leakage of -ERESTARTSYS to userland" doesn't
+ // apply: the sentry will handle -ERESTARTSYS in
+ // kernel.runApp.execute() even if the kernel.Task isn't interrupted.
+ //
+ // Linux ignores the result of kill_pgrp().
+ _ = pg.SendSignal(kernel.SignalInfoPriv(sig))
+ return kernel.ERESTARTSYS
+}
+
+// LINT.ThenChange(../../fsimpl/host/tty.go)
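
For reference, the ioctls emulated by TTYFileOperations are the same ones a host program issues against its controlling terminal. A small sketch, assuming a Linux host with a TTY on stdin, that exercises TCGETS (including the TOSTOP flag checked by Write above) and TIOCGWINSZ:

```go
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	fd := int(os.Stdin.Fd())

	// TCGETS: fetch the current terminal attributes.
	termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
	if err != nil {
		fmt.Println("not a tty:", err)
		return
	}

	// TOSTOP is the lflag that Write checks: when set, background
	// writers receive SIGTTOU.
	fmt.Println("TOSTOP set:", termios.Lflag&unix.TOSTOP != 0)

	// TIOCGWINSZ: any process group, foreground or background, may
	// query (and set) the window size.
	if ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ); err == nil {
		fmt.Printf("window: %d cols x %d rows\n", ws.Col, ws.Row)
	}
}
```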
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
new file mode 100644
index 000000000..1b0356930
--- /dev/null
+++ b/pkg/sentry/fs/host/util.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func nodeType(s *syscall.Stat_t) fs.InodeType {
+ switch x := (s.Mode & syscall.S_IFMT); x {
+ case syscall.S_IFLNK:
+ return fs.Symlink
+ case syscall.S_IFIFO:
+ return fs.Pipe
+ case syscall.S_IFCHR:
+ return fs.CharacterDevice
+ case syscall.S_IFBLK:
+ return fs.BlockDevice
+ case syscall.S_IFSOCK:
+ return fs.Socket
+ case syscall.S_IFDIR:
+ return fs.Directory
+ case syscall.S_IFREG:
+ return fs.RegularFile
+ default:
+ // This shouldn't happen, but just in case...
+ log.Warningf("unknown host file type %d: assuming regular", x)
+ return fs.RegularFile
+ }
+}
+
+func wouldBlock(s *syscall.Stat_t) bool {
+ typ := nodeType(s)
+ return typ == fs.Pipe || typ == fs.Socket || typ == fs.CharacterDevice
+}
+
+func stableAttr(s *syscall.Stat_t) fs.StableAttr {
+ return fs.StableAttr{
+ Type: nodeType(s),
+ DeviceID: hostFileDevice.DeviceID(),
+ InodeID: hostFileDevice.Map(device.MultiDeviceKey{
+ Device: s.Dev,
+ Inode: s.Ino,
+ }),
+ BlockSize: int64(s.Blksize),
+ }
+}
+
+func owner(s *syscall.Stat_t) fs.FileOwner {
+ return fs.FileOwner{
+ UID: auth.KUID(s.Uid),
+ GID: auth.KGID(s.Gid),
+ }
+}
+
+func unstableAttr(s *syscall.Stat_t) fs.UnstableAttr {
+ return fs.UnstableAttr{
+ Size: s.Size,
+ Usage: s.Blocks * 512,
+ Perms: fs.FilePermsFromMode(linux.FileMode(s.Mode)),
+ Owner: owner(s),
+ AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec),
+ ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec),
+ StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec),
+ Links: uint64(s.Nlink),
+ }
+}
+
+type dirInfo struct {
+ buf []byte // buffer for directory I/O.
+ nbuf int // length of buf; return value from ReadDirent.
+ bufp int // location of next record in buf.
+}
+
+// LINT.IfChange
+
+// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
+// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK {
+ return true
+ }
+ if pe, ok := err.(*os.PathError); ok {
+ return isBlockError(pe.Err)
+ }
+ return false
+}
+
+// LINT.ThenChange(../../fsimpl/host/util.go)
+
+func hostEffectiveKIDs() (uint32, []uint32, error) {
+ gids, err := os.Getgroups()
+ if err != nil {
+ return 0, nil, err
+ }
+ egids := make([]uint32, len(gids))
+ for i, gid := range gids {
+ egids[i] = uint32(gid)
+ }
+ return uint32(os.Geteuid()), append(egids, uint32(os.Getegid())), nil
+}
+
+var hostUID uint32
+var hostGIDs []uint32
+
+func init() {
+ hostUID, hostGIDs, _ = hostEffectiveKIDs()
+}
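
nodeType's switch over the S_IFMT bits can be exercised directly against host files. A hedged, standalone sketch (the paths are illustrative); note it uses Lstat rather than Stat so that symlinks are classified as symlinks instead of being followed:

```go
package main

import (
	"fmt"
	"syscall"
)

// classify mirrors nodeType: mask the mode with S_IFMT and map the
// result to a file type, falling back to "regular file" for anything
// unrecognized.
func classify(path string) (string, error) {
	var s syscall.Stat_t
	if err := syscall.Lstat(path, &s); err != nil {
		return "", err
	}
	switch s.Mode & syscall.S_IFMT {
	case syscall.S_IFLNK:
		return "symlink", nil
	case syscall.S_IFIFO:
		return "pipe", nil
	case syscall.S_IFCHR:
		return "character device", nil
	case syscall.S_IFBLK:
		return "block device", nil
	case syscall.S_IFSOCK:
		return "socket", nil
	case syscall.S_IFDIR:
		return "directory", nil
	default:
		// Like nodeType, treat unknown types as regular files.
		return "regular file", nil
	}
}

func main() {
	for _, p := range []string{"/dev/null", "/tmp", "/etc/hosts"} {
		t, err := classify(p)
		fmt.Println(p, t, err)
	}
}
```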
diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go
new file mode 100644
index 000000000..66da6e9f5
--- /dev/null
+++ b/pkg/sentry/fs/host/util_amd64_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_NEWFSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go
new file mode 100644
index 000000000..e8cb94aeb
--- /dev/null
+++ b/pkg/sentry/fs/host/util_arm64_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_FSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
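
These two files are identical except for the syscall number: the stat-at syscall is SYS_NEWFSTATAT on amd64 and SYS_FSTATAT on arm64, hence the build-tag split. golang.org/x/sys/unix performs the same per-architecture dispatch internally; a minimal sketch (the path is illustrative):

```go
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	var stat unix.Stat_t
	// AT_FDCWD resolves the path relative to the current directory;
	// AT_SYMLINK_NOFOLLOW stats the link itself rather than its target.
	if err := unix.Fstatat(unix.AT_FDCWD, "/etc/hosts", &stat, unix.AT_SYMLINK_NOFOLLOW); err != nil {
		panic(err)
	}
	fmt.Printf("dev=%d ino=%d size=%d\n", stat.Dev, stat.Ino, stat.Size)
}
```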
diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go
new file mode 100644
index 000000000..23bd35d64
--- /dev/null
+++ b/pkg/sentry/fs/host/util_unsafe.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// NulByte is a single NUL byte. It is passed to readlinkat as an empty string.
+var NulByte byte = '\x00'
+
+func readLink(fd int) (string, error) {
+ // Buffer sizing copied from os.Readlink.
+ for l := 128; ; l *= 2 {
+ b := make([]byte, l)
+ n, _, errno := syscall.Syscall6(
+ syscall.SYS_READLINKAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(&NulByte)), // ""
+ uintptr(unsafe.Pointer(&b[0])),
+ uintptr(l),
+ 0, 0)
+ if errno != 0 {
+ return "", errno
+ }
+ if n < uintptr(l) {
+ return string(b[:n]), nil
+ }
+ }
+}
+
+func timespecFromTimestamp(t ktime.Time, omit, setSysTime bool) syscall.Timespec {
+ if omit {
+ return syscall.Timespec{0, linux.UTIME_OMIT}
+ }
+ if setSysTime {
+ return syscall.Timespec{0, linux.UTIME_NOW}
+ }
+ return syscall.NsecToTimespec(t.Nanoseconds())
+}
+
+func setTimestamps(fd int, ts fs.TimeSpec) error {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return nil
+ }
+ var sts [2]syscall.Timespec
+ sts[0] = timespecFromTimestamp(ts.ATime, ts.ATimeOmit, ts.ATimeSetSystemTime)
+ sts[1] = timespecFromTimestamp(ts.MTime, ts.MTimeOmit, ts.MTimeSetSystemTime)
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(fd),
+ 0, /* path */
+ uintptr(unsafe.Pointer(&sts)),
+ 0, /* flags */
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
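
setTimestamps leans on the utimensat(2) convention that a tv_nsec of UTIME_NOW means "use the current time" and UTIME_OMIT means "leave this timestamp alone", which is exactly what timespecFromTimestamp encodes. A sketch of the same convention via golang.org/x/sys/unix (the file path is hypothetical and created only for the demonstration):

```go
package main

import (
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.Create("/tmp/example")
	if err != nil {
		panic(err)
	}
	f.Close()

	ts := []unix.Timespec{
		{Nsec: unix.UTIME_OMIT}, // atime: leave unchanged
		{Nsec: unix.UTIME_NOW},  // mtime: set to the current time
	}
	if err := unix.UtimesNanoAt(unix.AT_FDCWD, "/tmp/example", ts, 0); err != nil {
		panic(err)
	}
}
```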
diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go
new file mode 100644
index 000000000..ce397a5e3
--- /dev/null
+++ b/pkg/sentry/fs/host/wait_test.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestWait(t *testing.T) {
+ var fds [2]int
+ err := syscall.Pipe(fds[:])
+ if err != nil {
+ t.Fatalf("Unable to create pipe: %v", err)
+ }
+
+ defer syscall.Close(fds[1])
+
+ ctx := contexttest.Context(t)
+ file, err := NewFile(ctx, fds[0])
+ if err != nil {
+ syscall.Close(fds[0])
+ t.Fatalf("NewFile failed: %v", err)
+ }
+
+ defer file.DecRef()
+
+ r := file.Readiness(waiter.EventIn)
+ if r != 0 {
+ t.Fatalf("File is ready for read when it shouldn't be.")
+ }
+
+ e, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&e, waiter.EventIn)
+ defer file.EventUnregister(&e)
+
+ // Check that there are no notifications yet.
+ if len(ch) != 0 {
+ t.Fatalf("Channel is non-empty")
+ }
+
+ // Write to the pipe, so the read end should become readable.
+ syscall.Write(fds[1], []byte{1})
+
+ // Check that we get a notification. We need to yield the current thread
+ // so that the fdnotifier can deliver notifications, so we use a
+ // 1-second timeout instead of just checking the length of the channel.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Channel not notified")
+ }
+}
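
The test above drives readiness through the sentry's waiter package; the host-level analogue is poll(2), where the read end of a pipe reports POLLIN only after something is written. A standalone sketch:

```go
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	var fds [2]int
	if err := unix.Pipe(fds[:]); err != nil {
		panic(err)
	}
	defer unix.Close(fds[0])
	defer unix.Close(fds[1])

	pfd := []unix.PollFd{{Fd: int32(fds[0]), Events: unix.POLLIN}}

	// Nothing written yet: a zero-timeout poll reports not ready.
	n, _ := unix.Poll(pfd, 0)
	fmt.Println("ready before write:", n)

	unix.Write(fds[1], []byte{1})

	// After the write, the read end becomes readable.
	n, _ = unix.Poll(pfd, 1000)
	fmt.Println("ready after write:", n)
}
```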
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
new file mode 100644
index 000000000..a34fbc946
--- /dev/null
+++ b/pkg/sentry/fs/inode.go
@@ -0,0 +1,477 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.")
+
+// Inode is a file system object that can be simultaneously referenced by different
+// components of the VFS (Dirent, fs.File, etc).
+//
+// +stateify savable
+type Inode struct {
+ // AtomicRefCount is our reference count.
+ refs.AtomicRefCount
+
+ // InodeOperations is the file system specific behavior of the Inode.
+ InodeOperations InodeOperations
+
+ // StableAttr are stable cached attributes of the Inode.
+ StableAttr StableAttr
+
+ // LockCtx is the file lock context. It manages its own synchronization and tracks
+ // regions of the Inode that have locks held.
+ LockCtx LockCtx
+
+ // Watches is the set of inotify watches for this inode.
+ Watches *Watches
+
+ // MountSource is the mount source this Inode is a part of.
+ MountSource *MountSource
+
+ // overlay is the overlay entry for this Inode.
+ overlay *overlayEntry
+
+ // appendMu is used to synchronize write operations into files which
+ // have been opened with O_APPEND. Operations which change a file size
+ // have to take this lock for read. Write operations to files with
+ // O_APPEND have to take this lock for write.
+ appendMu sync.RWMutex `state:"nosave"`
+}
+
+// LockCtx is an Inode's lock context and contains different personalities of locks; both
+// Posix and BSD style locks are supported.
+//
+// Note that in Linux fcntl(2) and flock(2) locks are _not_ cooperative, because race and
+// deadlock conditions make merging them prohibitive. We do the same and keep them oblivious
+// to each other but provide a "context" as a convenient container.
+//
+// +stateify savable
+type LockCtx struct {
+ // Posix is a set of POSIX-style regional advisory locks, see fcntl(2).
+ Posix lock.Locks
+
+ // BSD is a set of BSD-style advisory file wide locks, see flock(2).
+ BSD lock.Locks
+}
+
+// NewInode constructs an Inode from InodeOperations, a MountSource, and stable attributes.
+//
+// NewInode takes a reference on msrc.
+func NewInode(ctx context.Context, iops InodeOperations, msrc *MountSource, sattr StableAttr) *Inode {
+ msrc.IncRef()
+ i := Inode{
+ InodeOperations: iops,
+ StableAttr: sattr,
+ Watches: newWatches(),
+ MountSource: msrc,
+ }
+ i.EnableLeakCheck("fs.Inode")
+ return &i
+}
+
+// DecRef drops a reference on the Inode.
+func (i *Inode) DecRef() {
+ i.DecRefWithDestructor(i.destroy)
+}
+
+// destroy releases the Inode and releases the msrc reference taken.
+func (i *Inode) destroy() {
+ ctx := context.Background()
+ if err := i.WriteOut(ctx); err != nil {
+ // FIXME(b/65209558): Mark as warning again once noatime is
+ // properly supported.
+ log.Debugf("Inode %+v, failed to sync all metadata: %v", i.StableAttr, err)
+ }
+
+ // If this inode is being destroyed because it was unlinked, queue a
+ // deletion event. This may not be the case for inodes being revalidated.
+ if i.Watches.unlinked {
+ i.Watches.Notify("", linux.IN_DELETE_SELF, 0)
+ }
+
+ // Remove references from the watch owners to the watches on this inode,
+ // since the watches are about to be GCed. Note that we don't need to worry
+ // about the watch pins since if there were any active pins, this inode
+ // wouldn't be in the destructor.
+ i.Watches.targetDestroyed()
+
+ if i.overlay != nil {
+ i.overlay.release()
+ } else {
+ i.InodeOperations.Release(ctx)
+ }
+
+ i.MountSource.DecRef()
+}
+
+// Mappable calls i.InodeOperations.Mappable.
+func (i *Inode) Mappable() memmap.Mappable {
+ if i.overlay != nil {
+ // In an overlay, Mappable is always implemented by
+ // the overlayEntry metadata to synchronize memory
+ // access of files with copy up. But first check if
+ // the Inodes involved would be mappable in the first
+ // place.
+ i.overlay.copyMu.RLock()
+ ok := i.overlay.isMappableLocked()
+ i.overlay.copyMu.RUnlock()
+ if !ok {
+ return nil
+ }
+ return i.overlay
+ }
+ return i.InodeOperations.Mappable(i)
+}
+
+// WriteOut calls i.InodeOperations.WriteOut with i as the Inode.
+func (i *Inode) WriteOut(ctx context.Context) error {
+ if i.overlay != nil {
+ return overlayWriteOut(ctx, i.overlay)
+ }
+ return i.InodeOperations.WriteOut(ctx, i)
+}
+
+// Lookup calls i.InodeOperations.Lookup with i as the directory.
+func (i *Inode) Lookup(ctx context.Context, name string) (*Dirent, error) {
+ if i.overlay != nil {
+ d, _, err := overlayLookup(ctx, i.overlay, i, name)
+ return d, err
+ }
+ return i.InodeOperations.Lookup(ctx, i, name)
+}
+
+// Create calls i.InodeOperations.Create with i as the directory.
+func (i *Inode) Create(ctx context.Context, d *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) {
+ if i.overlay != nil {
+ return overlayCreate(ctx, i.overlay, d, name, flags, perm)
+ }
+ return i.InodeOperations.Create(ctx, i, name, flags, perm)
+}
+
+// CreateDirectory calls i.InodeOperations.CreateDirectory with i as the directory.
+func (i *Inode) CreateDirectory(ctx context.Context, d *Dirent, name string, perm FilePermissions) error {
+ if i.overlay != nil {
+ return overlayCreateDirectory(ctx, i.overlay, d, name, perm)
+ }
+ return i.InodeOperations.CreateDirectory(ctx, i, name, perm)
+}
+
+// CreateLink calls i.InodeOperations.CreateLink with i as the directory.
+func (i *Inode) CreateLink(ctx context.Context, d *Dirent, oldname string, newname string) error {
+ if i.overlay != nil {
+ return overlayCreateLink(ctx, i.overlay, d, oldname, newname)
+ }
+ return i.InodeOperations.CreateLink(ctx, i, oldname, newname)
+}
+
+// CreateHardLink calls i.InodeOperations.CreateHardLink with i as the directory.
+func (i *Inode) CreateHardLink(ctx context.Context, d *Dirent, target *Dirent, name string) error {
+ if i.overlay != nil {
+ return overlayCreateHardLink(ctx, i.overlay, d, target, name)
+ }
+ return i.InodeOperations.CreateHardLink(ctx, i, target.Inode, name)
+}
+
+// CreateFifo calls i.InodeOperations.CreateFifo with i as the directory.
+func (i *Inode) CreateFifo(ctx context.Context, d *Dirent, name string, perm FilePermissions) error {
+ if i.overlay != nil {
+ return overlayCreateFifo(ctx, i.overlay, d, name, perm)
+ }
+ return i.InodeOperations.CreateFifo(ctx, i, name, perm)
+}
+
+// Remove calls i.InodeOperations.Remove/RemoveDirectory with i as the directory.
+func (i *Inode) Remove(ctx context.Context, d *Dirent, remove *Dirent) error {
+ if i.overlay != nil {
+ return overlayRemove(ctx, i.overlay, d, remove)
+ }
+ switch remove.Inode.StableAttr.Type {
+ case Directory, SpecialDirectory:
+ return i.InodeOperations.RemoveDirectory(ctx, i, remove.name)
+ default:
+ return i.InodeOperations.Remove(ctx, i, remove.name)
+ }
+}
+
+// Rename calls i.InodeOperations.Rename with the given arguments.
+func (i *Inode) Rename(ctx context.Context, oldParent *Dirent, renamed *Dirent, newParent *Dirent, newName string, replacement bool) error {
+ if i.overlay != nil {
+ return overlayRename(ctx, i.overlay, oldParent, renamed, newParent, newName, replacement)
+ }
+ return i.InodeOperations.Rename(ctx, renamed.Inode, oldParent.Inode, renamed.name, newParent.Inode, newName, replacement)
+}
+
+// Bind calls i.InodeOperations.Bind with i as the directory.
+func (i *Inode) Bind(ctx context.Context, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+ if i.overlay != nil {
+ return overlayBind(ctx, i.overlay, parent, name, data, perm)
+ }
+ return i.InodeOperations.Bind(ctx, i, name, data, perm)
+}
+
+// BoundEndpoint calls i.InodeOperations.BoundEndpoint with i as the Inode.
+func (i *Inode) BoundEndpoint(path string) transport.BoundEndpoint {
+ if i.overlay != nil {
+ return overlayBoundEndpoint(i.overlay, path)
+ }
+ return i.InodeOperations.BoundEndpoint(i, path)
+}
+
+// GetFile calls i.InodeOperations.GetFile with the given arguments.
+func (i *Inode) GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File, error) {
+ if i.overlay != nil {
+ return overlayGetFile(ctx, i.overlay, d, flags)
+ }
+ opens.Increment()
+ return i.InodeOperations.GetFile(ctx, d, flags)
+}
+
+// UnstableAttr calls i.InodeOperations.UnstableAttr with i as the Inode.
+func (i *Inode) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
+ if i.overlay != nil {
+ return overlayUnstableAttr(ctx, i.overlay)
+ }
+ return i.InodeOperations.UnstableAttr(ctx, i)
+}
+
+// GetXattr calls i.InodeOperations.GetXattr with i as the Inode.
+func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
+ if i.overlay != nil {
+ return overlayGetXattr(ctx, i.overlay, name, size)
+ }
+ return i.InodeOperations.GetXattr(ctx, i, name, size)
+}
+
+// SetXattr calls i.InodeOperations.SetXattr with i as the Inode.
+func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error {
+ if i.overlay != nil {
+ return overlaySetxattr(ctx, i.overlay, d, name, value, flags)
+ }
+ return i.InodeOperations.SetXattr(ctx, i, name, value, flags)
+}
+
+// ListXattr calls i.InodeOperations.ListXattr with i as the Inode.
+func (i *Inode) ListXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ if i.overlay != nil {
+ return overlayListXattr(ctx, i.overlay, size)
+ }
+ return i.InodeOperations.ListXattr(ctx, i, size)
+}
+
+// RemoveXattr calls i.InodeOperations.RemoveXattr with i as the Inode.
+func (i *Inode) RemoveXattr(ctx context.Context, d *Dirent, name string) error {
+ if i.overlay != nil {
+ return overlayRemoveXattr(ctx, i.overlay, d, name)
+ }
+ return i.InodeOperations.RemoveXattr(ctx, i, name)
+}
+
+// CheckPermission will check if the caller may access this file in the
+// requested way for reading, writing, or executing.
+//
+// CheckPermission is like Linux's fs/namei.c:inode_permission. It
+// - checks file system mount flags,
+// - and utilizes InodeOperations.Check to check capabilities and modes.
+func (i *Inode) CheckPermission(ctx context.Context, p PermMask) error {
+ // First check the outer-most mounted filesystem.
+ if p.Write && i.MountSource.Flags.ReadOnly {
+ return syserror.EROFS
+ }
+
+ if i.overlay != nil {
+ // CheckPermission requires some special handling for
+ // an overlay.
+ //
+ // Writes will always be redirected to an upper filesystem,
+ // so ignore all lower layers being read-only.
+ //
+ // But still honor the upper-most filesystem's mount flags;
+ // we should not attempt to modify the writable layer if it
+ // is mounted read-only.
+ if p.Write && overlayUpperMountSource(i.MountSource).Flags.ReadOnly {
+ return syserror.EROFS
+ }
+ }
+
+ return i.check(ctx, p)
+}
+
+func (i *Inode) check(ctx context.Context, p PermMask) error {
+ if i.overlay != nil {
+ return overlayCheck(ctx, i.overlay, p)
+ }
+ if !i.InodeOperations.Check(ctx, i, p) {
+ return syserror.EACCES
+ }
+ return nil
+}
+
+// SetPermissions calls i.InodeOperations.SetPermissions with i as the Inode.
+func (i *Inode) SetPermissions(ctx context.Context, d *Dirent, f FilePermissions) bool {
+ if i.overlay != nil {
+ return overlaySetPermissions(ctx, i.overlay, d, f)
+ }
+ return i.InodeOperations.SetPermissions(ctx, i, f)
+}
+
+// SetOwner calls i.InodeOperations.SetOwner with i as the Inode.
+func (i *Inode) SetOwner(ctx context.Context, d *Dirent, o FileOwner) error {
+ if i.overlay != nil {
+ return overlaySetOwner(ctx, i.overlay, d, o)
+ }
+ return i.InodeOperations.SetOwner(ctx, i, o)
+}
+
+// SetTimestamps calls i.InodeOperations.SetTimestamps with i as the Inode.
+func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error {
+ if i.overlay != nil {
+ return overlaySetTimestamps(ctx, i.overlay, d, ts)
+ }
+ return i.InodeOperations.SetTimestamps(ctx, i, ts)
+}
+
+// Truncate calls i.InodeOperations.Truncate with i as the Inode.
+func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error {
+ if IsDir(i.StableAttr) {
+ return syserror.EISDIR
+ }
+
+ if i.overlay != nil {
+ return overlayTruncate(ctx, i.overlay, d, size)
+ }
+ i.appendMu.RLock()
+ defer i.appendMu.RUnlock()
+ return i.InodeOperations.Truncate(ctx, i, size)
+}
+
+// Allocate calls i.InodeOperations.Allocate with i as the Inode.
+func (i *Inode) Allocate(ctx context.Context, d *Dirent, offset int64, length int64) error {
+ if i.overlay != nil {
+ return overlayAllocate(ctx, i.overlay, d, offset, length)
+ }
+ return i.InodeOperations.Allocate(ctx, i, offset, length)
+}
+
+// Readlink calls i.InodeOperations.Readlink with i as the Inode.
+func (i *Inode) Readlink(ctx context.Context) (string, error) {
+ if i.overlay != nil {
+ return overlayReadlink(ctx, i.overlay)
+ }
+ return i.InodeOperations.Readlink(ctx, i)
+}
+
+// Getlink calls i.InodeOperations.Getlink.
+func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) {
+ if i.overlay != nil {
+ return overlayGetlink(ctx, i.overlay)
+ }
+ return i.InodeOperations.Getlink(ctx, i)
+}
+
+// AddLink calls i.InodeOperations.AddLink.
+func (i *Inode) AddLink() {
+ if i.overlay != nil {
+ // This interface is only used by ramfs to update metadata of
+ // children. These filesystems should _never_ have overlay
+ // Inodes cached as children. So explicitly disallow this
+ // scenario and avoid plumbing Dirents through to do copy up.
+ panic("overlay Inodes cached in ramfs directories are not supported")
+ }
+ i.InodeOperations.AddLink()
+}
+
+// DropLink calls i.InodeOperations.DropLink.
+func (i *Inode) DropLink() {
+ if i.overlay != nil {
+ // Same as AddLink.
+ panic("overlay Inodes cached in ramfs directories are not supported")
+ }
+ i.InodeOperations.DropLink()
+}
+
+// IsVirtual calls i.InodeOperations.IsVirtual.
+func (i *Inode) IsVirtual() bool {
+ if i.overlay != nil {
+ // An overlay configuration does not support virtual files.
+ return false
+ }
+ return i.InodeOperations.IsVirtual()
+}
+
+// StatFS calls i.InodeOperations.StatFS.
+func (i *Inode) StatFS(ctx context.Context) (Info, error) {
+ if i.overlay != nil {
+ return overlayStatFS(ctx, i.overlay)
+ }
+ return i.InodeOperations.StatFS(ctx)
+}
+
+// CheckOwnership checks whether `ctx` owns this Inode or may act as its owner.
+// Compare Linux's fs/inode.c:inode_owner_or_capable().
+func (i *Inode) CheckOwnership(ctx context.Context) bool {
+ uattr, err := i.UnstableAttr(ctx)
+ if err != nil {
+ return false
+ }
+ creds := auth.CredentialsFromContext(ctx)
+ if uattr.Owner.UID == creds.EffectiveKUID {
+ return true
+ }
+ if creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(uattr.Owner.UID).Ok() {
+ return true
+ }
+ return false
+}
+
+// CheckCapability checks whether `ctx` has capability `cp` with respect to
+// operations on this Inode.
+//
+// Compare Linux's kernel/capability.c:capable_wrt_inode_uidgid().
+func (i *Inode) CheckCapability(ctx context.Context, cp linux.Capability) bool {
+ uattr, err := i.UnstableAttr(ctx)
+ if err != nil {
+ return false
+ }
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.UserNamespace.MapFromKUID(uattr.Owner.UID).Ok() {
+ return false
+ }
+ if !creds.UserNamespace.MapFromKGID(uattr.Owner.GID).Ok() {
+ return false
+ }
+ return creds.HasCapability(cp)
+}
+
+func (i *Inode) lockAppendMu(appendMode bool) func() {
+ if appendMode {
+ i.appendMu.Lock()
+ return i.appendMu.Unlock
+ }
+ i.appendMu.RLock()
+ return i.appendMu.RUnlock
+}
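
lockAppendMu above returns the matching unlock function, so callers can take either a read or a write lock with one expression and release it uniformly. A self-contained sketch of the idiom (the file type and section names are illustrative):

```go
package main

import (
	"fmt"
	"sync"
)

type file struct {
	mu sync.RWMutex
}

// lock takes the write lock for appending writers and the read lock
// for everyone else, returning the matching unlock function.
func (f *file) lock(exclusive bool) func() {
	if exclusive {
		f.mu.Lock()
		return f.mu.Unlock
	}
	f.mu.RLock()
	return f.mu.RUnlock
}

func main() {
	var f file

	unlock := f.lock(false) // reader
	fmt.Println("size-stable section")
	unlock()

	defer f.lock(true)() // appender: lock now, unlock at return
	fmt.Println("append section")
}
```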
diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go
new file mode 100644
index 000000000..efd3c962b
--- /dev/null
+++ b/pkg/sentry/fs/inode_inotify.go
@@ -0,0 +1,170 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Watches is the collection of inotify watches on an inode.
+//
+// +stateify savable
+type Watches struct {
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // ws is the map of active watches in this collection, keyed by the inotify
+ // instance id of the owner.
+ ws map[uint64]*Watch
+
+ // unlinked indicates whether the target inode was ever unlinked. This is a
+ // hack to figure out if we should queue a IN_DELETE_SELF event when this
+ // watches collection is being destroyed, since otherwise we have no way of
+ // knowing if the target inode is going down due to a deletion or
+ // revalidation.
+ unlinked bool
+}
+
+func newWatches() *Watches {
+ return &Watches{}
+}
+
+// MarkUnlinked indicates the target for this set of watches to be unlinked.
+// This has implications for the IN_EXCL_UNLINK flag.
+func (w *Watches) MarkUnlinked() {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ w.unlinked = true
+}
+
+// Lookup returns a matching watch with the given id. Returns nil if no such
+// watch exists. Note that the result returned by this method only remains valid
+// if the inotify instance owning the watch is locked, preventing modification
+// of the returned watch and preventing the replacement of the watch by another
+// one from the same instance (since there may be at most one watch per
+// instance, per target).
+func (w *Watches) Lookup(id uint64) *Watch {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.ws[id]
+}
+
+// Add adds a watch to this set of watches. The watch being added must be
+// unique: its ID() must not collide with that of any existing watch.
+func (w *Watches) Add(watch *Watch) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Sanity check, the new watch shouldn't collide with an existing
+ // watch. Silently replacing an existing watch would result in a ref leak on
+ // this inode. We could handle this collision by calling Unpin() on the
+ // existing watch, but then we end up leaking watch descriptor ids at the
+ // inotify level.
+ if _, exists := w.ws[watch.ID()]; exists {
+ panic(fmt.Sprintf("Watch collision with ID %+v", watch.ID()))
+ }
+ if w.ws == nil {
+ w.ws = make(map[uint64]*Watch)
+ }
+ w.ws[watch.ID()] = watch
+}
+
+// Remove removes a watch with the given id from this set of watches. The caller
+// is responsible for generating any watch removal event, as appropriate. The
+// provided id must match an existing watch in this collection.
+func (w *Watches) Remove(id uint64) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if w.ws == nil {
+ // This watch set is being destroyed. The thread executing the
+ // destructor is already in the process of deleting all our watches. We
+ // got here with no refs on the inode because we raced with the
+ // destructor notifying all the watch owners of the inode's destruction.
+ // See the comment in Watches.targetDestroyed for why this race exists.
+ return
+ }
+
+ watch, ok := w.ws[id]
+ if !ok {
+ // While there's technically no problem with silently ignoring a missing
+ // watch, this is almost certainly a bug.
+ panic(fmt.Sprintf("Attempt to remove a watch, but no watch found with provided id %+v.", id))
+ }
+ delete(w.ws, watch.ID())
+}
+
+// Notify queues a new event with all watches in this set.
+func (w *Watches) Notify(name string, events, cookie uint32) {
+ // N.B. We don't defer the unlocks because Notify is in the hot path of
+ // all IO operations, and the defer costs too much for small IO
+ // operations.
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if name != "" && w.unlinked && !watch.NotifyParentAfterUnlink() {
+ // IN_EXCL_UNLINK - By default, when watching events on the children
+ // of a directory, events are generated for children even after they
+ // have been unlinked from the directory. This can result in large
+ // numbers of uninteresting events for some applications (e.g., if
+ // watching /tmp, in which many applications create temporary files
+ // whose names are immediately unlinked). Specifying IN_EXCL_UNLINK
+ // changes the default behavior, so that events are not generated
+ // for children after they have been unlinked from the watched
+ // directory. -- inotify(7)
+ //
+ // We know we're dealing with events for a parent when the name
+ // isn't empty.
+ continue
+ }
+ watch.Notify(name, events, cookie)
+ }
+ w.mu.RUnlock()
+}
+
+// Unpin unpins dirent from all watches in this set.
+func (w *Watches) Unpin(d *Dirent) {
+ w.mu.RLock()
+ defer w.mu.RUnlock()
+ for _, watch := range w.ws {
+ watch.Unpin(d)
+ }
+}
+
+// targetDestroyed is called by the inode destructor to notify the watch owners
+// of the impending destruction of the watch target.
+func (w *Watches) targetDestroyed() {
+ var ws map[uint64]*Watch
+
+ // We can't hold w.mu while calling watch.TargetDestroyed to preserve lock
+ // ordering w.r.t to the owner inotify instances. Instead, atomically move
+ // the watches map into a local variable so we can iterate over it safely.
+ //
+ // Because of this however, it is possible for the watches' owners to reach
+ // this inode while the inode has no refs. This is still safe because the
+ // owners can only reach the inode until this function finishes calling
+ // watch.TargetDestroyed() below and the inode is guaranteed to exist in the
+ // meanwhile. But we still have to be very careful not to rely on inode
+ // state that may have been already destroyed.
+ w.mu.Lock()
+ ws = w.ws
+ w.ws = nil
+ w.mu.Unlock()
+
+ for _, watch := range ws {
+ watch.TargetDestroyed()
+ }
+}
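
targetDestroyed detaches the watch map while holding the lock and only then invokes the owners' callbacks, which is what preserves the lock ordering with the inotify instances. A generic sketch of that snapshot-then-callback pattern (the registry and callback names are illustrative):

```go
package main

import (
	"fmt"
	"sync"
)

type registry struct {
	mu        sync.Mutex
	callbacks map[int]func()
}

// fireAll atomically detaches the callback map, then invokes each
// callback without holding mu, so callbacks may safely re-enter the
// registry (e.g. to unregister themselves).
func (r *registry) fireAll() {
	r.mu.Lock()
	cbs := r.callbacks
	r.callbacks = nil
	r.mu.Unlock()

	for _, cb := range cbs {
		cb()
	}
}

func main() {
	r := &registry{callbacks: map[int]func(){
		1: func() { fmt.Println("watch 1 notified") },
		2: func() { fmt.Println("watch 2 notified") },
	}}
	r.fireAll()
}
```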
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
new file mode 100644
index 000000000..2bbfb72ef
--- /dev/null
+++ b/pkg/sentry/fs/inode_operations.go
@@ -0,0 +1,326 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "errors"
+
+ "gvisor.dev/gvisor/pkg/context"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+var (
+ // ErrResolveViaReadlink is a special error value returned by
+ // InodeOperations.Getlink() to indicate that a link should be
+ // resolved automatically by walking to the path returned by
+ // InodeOperations.Readlink().
+ ErrResolveViaReadlink = errors.New("link should be resolved via Readlink()")
+)
+
+// TimeSpec contains access and modification timestamps. If either ATimeOmit or
+// MTimeOmit is true, then the corresponding timestamp should not be updated.
+// If either ATimeSetSystemTime or MTimeSetSystemTime is set, then the
+// corresponding timestamp should be ignored and the time will be set to the
+// current system time.
+type TimeSpec struct {
+ ATime ktime.Time
+ ATimeOmit bool
+ ATimeSetSystemTime bool
+ MTime ktime.Time
+ MTimeOmit bool
+ MTimeSetSystemTime bool
+}
+
+// InodeOperations are operations on an Inode that diverge per file system.
+//
+// Objects that implement InodeOperations may cache file system "private"
+// data that is useful for implementing these methods. In contrast, Inode
+// contains state that is common to all Inodes; this state may be optionally
+// used by InodeOperations. An object that implements InodeOperations may
+// not take a reference on an Inode.
+type InodeOperations interface {
+ // Release releases all private file system data held by this object.
+ // Once Release is called, this object is dead (no other methods will
+ // ever be called).
+ Release(context.Context)
+
+ // Lookup loads an Inode at name under dir into a Dirent. The name
+ // is a valid component path: it contains no "/"s nor is the empty
+ // string.
+ //
+ // Lookup may return one of:
+ //
+ // * A nil Dirent and a non-nil error. If Lookup failed because the
+ // name does not exist under dir, then it must return
+ // syserror.ENOENT.
+ //
+ // * If name does not exist under dir and the file system wishes this
+ // fact to be cached, a non-nil Dirent containing a nil Inode and a
+ // nil error. This is a negative Dirent and must have exactly one
+ // reference (at-construction reference).
+ //
+ // * If name does exist under this dir, a non-nil Dirent containing a
+ // non-nil Inode, and a nil error. File systems that take extra
+ // references on this Dirent should implement DirentOperations.
+ Lookup(ctx context.Context, dir *Inode, name string) (*Dirent, error)
+
+ // Create creates an Inode at name under dir and returns a new File
+ // whose Dirent backs the new Inode. Implementations must ensure that
+ // name does not already exist. Create may return one of:
+ //
+ // * A nil File and a non-nil error.
+ //
+ // * A non-nil File and a nil error. File.Dirent will be a new Dirent,
+ // with a single reference held by File. File systems that take extra
+ // references on this Dirent should implement DirentOperations.
+ //
+ // The caller must ensure that this operation is permitted.
+ Create(ctx context.Context, dir *Inode, name string, flags FileFlags, perm FilePermissions) (*File, error)
+
+ // CreateDirectory creates a new directory under this dir.
+ // CreateDirectory should otherwise do the same as Create.
+ //
+ // The caller must ensure that this operation is permitted.
+ CreateDirectory(ctx context.Context, dir *Inode, name string, perm FilePermissions) error
+
+ // CreateLink creates a symbolic link under dir between newname
+ // and oldname. CreateLink should otherwise do the same as Create.
+ //
+ // The caller must ensure that this operation is permitted.
+ CreateLink(ctx context.Context, dir *Inode, oldname string, newname string) error
+
+ // CreateHardLink creates a hard link under dir between the target
+ // Inode and name.
+ //
+ // The caller must ensure this operation is permitted.
+ CreateHardLink(ctx context.Context, dir *Inode, target *Inode, name string) error
+
+ // CreateFifo creates a new named pipe under dir at name.
+ //
+ // The caller must ensure that this operation is permitted.
+ CreateFifo(ctx context.Context, dir *Inode, name string, perm FilePermissions) error
+
+ // Remove removes the given named non-directory under dir.
+ //
+ // The caller must ensure that this operation is permitted.
+ Remove(ctx context.Context, dir *Inode, name string) error
+
+ // RemoveDirectory removes the given named directory under dir.
+ //
+ // The caller must ensure that this operation is permitted.
+ //
+ // RemoveDirectory should check that the directory to be
+ // removed is empty.
+ RemoveDirectory(ctx context.Context, dir *Inode, name string) error
+
+ // Rename atomically renames oldName under oldParent to newName under
+ // newParent where oldParent and newParent are directories. inode is
+ // the Inode of this InodeOperations.
+ //
+ // If replacement is true, then newName already exists and this call
+ // will replace it with oldName.
+ //
+ // Implementations are responsible for rejecting renames that replace
+ // non-empty directories.
+ Rename(ctx context.Context, inode *Inode, oldParent *Inode, oldName string, newParent *Inode, newName string, replacement bool) error
+
+ // Bind binds a new socket under dir at the given name.
+ //
+ // The caller must ensure that this operation is permitted.
+ Bind(ctx context.Context, dir *Inode, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error)
+
+ // BoundEndpoint returns the socket endpoint at path stored in
+ // or generated by an Inode.
+ //
+ // The path is only relevant for generated endpoints because stored
+ // endpoints already know their path. It is OK for an endpoint to
+ // hold onto its path because the only way to change a bind
+ // address is to rebind the socket.
+ //
+ // This is valid iff the type of the Inode is a Socket, which
+ // generally implies that this Inode was created via CreateSocket.
+ //
+ // If there is no socket endpoint available, nil will be returned.
+ BoundEndpoint(inode *Inode, path string) transport.BoundEndpoint
+
+ // GetFile returns a new open File backed by a Dirent and FileFlags.
+ //
+ // Special Inode types may block using ctx.Sleeper. RegularFiles,
+ // Directories, and Symlinks must not block (see doCopyUp).
+ //
+ // The returned File will uniquely back an application fd.
+ GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File, error)
+
+ // UnstableAttr returns the most up-to-date "unstable" attributes of
+ // an Inode, where "unstable" means that they change in response to
+ // file system events.
+ UnstableAttr(ctx context.Context, inode *Inode) (UnstableAttr, error)
+
+ // GetXattr retrieves the value of extended attribute specified by name.
+ // Inodes that do not support extended attributes return EOPNOTSUPP. Inodes
+ // that support extended attributes but don't have a value at name return
+ // ENODATA.
+ //
+ // If this is called through the getxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ GetXattr(ctx context.Context, inode *Inode, name string, size uint64) (string, error)
+
+ // SetXattr sets the value of extended attribute specified by name. Inodes
+ // that do not support extended attributes return EOPNOTSUPP.
+ SetXattr(ctx context.Context, inode *Inode, name, value string, flags uint32) error
+
+ // ListXattr returns the set of all extended attributes names that
+ // have values. Inodes that do not support extended attributes return
+ // EOPNOTSUPP.
+ //
+ // If this is called through the listxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely. All size checking is done independently
+ // at the syscall layer.
+ ListXattr(ctx context.Context, inode *Inode, size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes an extended attribute specified by name. Inodes that
+ // do not support extended attributes return EOPNOTSUPP.
+ RemoveXattr(ctx context.Context, inode *Inode, name string) error
+
+ // Check determines whether an Inode can be accessed with the
+ // requested permission mask using the context (which gives access
+ // to Credentials and UserNamespace).
+ Check(ctx context.Context, inode *Inode, p PermMask) bool
+
+ // SetPermissions sets new permissions for an Inode. Returns false
+ // if it was not possible to set the new permissions.
+ //
+ // The caller must ensure that this operation is permitted.
+ SetPermissions(ctx context.Context, inode *Inode, f FilePermissions) bool
+
+ // SetOwner sets the ownership for this file.
+ //
+ // If either UID or GID are set to auth.NoID, its value will not be
+ // changed.
+ //
+ // The caller must ensure that this operation is permitted.
+ SetOwner(ctx context.Context, inode *Inode, owner FileOwner) error
+
+ // SetTimestamps sets the access and modification timestamps of an
+ // Inode according to the access and modification times in the TimeSpec.
+ //
+ // If either ATimeOmit or MTimeOmit is set, then the corresponding
+ // timestamp is not updated.
+ //
+ // If either ATimeSetSystemTime or MTimeSetSystemTime is true, that
+ // timestamp is set to the current time instead.
+ //
+ // The caller must ensure that this operation is permitted.
+ SetTimestamps(ctx context.Context, inode *Inode, ts TimeSpec) error
+
+ // Truncate changes the size of an Inode. Truncate should not check
+ // permissions internally, as it is used for both sys_truncate and
+ // sys_ftruncate.
+ //
+ // Implementations need not check that size >= 0.
+ Truncate(ctx context.Context, inode *Inode, size int64) error
+
+ // Allocate allows the caller to reserve disk space for the inode.
+ // It's equivalent to fallocate(2) with 'mode=0'.
+ Allocate(ctx context.Context, inode *Inode, offset int64, length int64) error
+
+ // WriteOut writes cached Inode state to a backing filesystem in a
+ // synchronous manner.
+ //
+ // File systems that do not cache metadata or data via an Inode
+ // implement WriteOut as a no-op. File systems that are entirely in
+ // memory also implement WriteOut as a no-op. Otherwise file systems
+ // call Inode.Sync to write back page cached data and cached metadata
+ // followed by syncing writeback handles.
+ //
+ // It derives from include/linux/fs.h:super_operations->write_inode.
+ WriteOut(ctx context.Context, inode *Inode) error
+
+ // Readlink reads the symlink path of an Inode.
+ //
+ // Readlink is permitted to return a different path depending on ctx,
+ // the request originator.
+ //
+ // The caller must ensure that this operation is permitted.
+ //
+ // Readlink should check that Inode is a symlink and its content is
+ // at least readable.
+ Readlink(ctx context.Context, inode *Inode) (string, error)
+
+ // Getlink resolves a symlink to a target *Dirent.
+ //
+ // Filesystems that can resolve the link by walking to the path returned
+ // by Readlink should return (nil, ErrResolveViaReadlink), which
+ // triggers link resolution via Readlink and Lookup.
+ //
+ // Some links cannot be followed by Lookup. In this case, Getlink can
+ // return the Dirent of the link target. The caller holds a reference
+ // to the Dirent. Filesystems that return a non-nil *Dirent from Getlink
+ // cannot participate in an overlay because it is impossible for the
+ // overlay to ascertain whether or not the *Dirent should contain an
+ // overlayEntry.
+ //
+ // Any error returned from Getlink other than ErrResolveViaReadlink
+ // indicates the caller's inability to traverse this Inode as a link
+ // (e.g. syserror.ENOLINK indicates that the Inode is not a link,
+ // syscall.EPERM indicates that traversing the link is not allowed, etc).
+ Getlink(context.Context, *Inode) (*Dirent, error)
+
+ // Mappable returns a memmap.Mappable that provides memory mappings of the
+ // Inode's data. Mappable may return nil if this is not supported. The
+ // returned Mappable must remain valid until InodeOperations.Release is
+ // called.
+ Mappable(*Inode) memmap.Mappable
+
+ // The below methods require cleanup.
+
+ // AddLink increments the hard link count of an Inode.
+ //
+ // Remove in favor of Inode.IncLink.
+ AddLink()
+
+ // DropLink decrements the hard link count of an Inode.
+ //
+ // Remove in favor of Inode.DecLink.
+ DropLink()
+
+ // NotifyStatusChange sets the status change time to the current time.
+ //
+ // Remove in favor of updating the Inode's cached status change time.
+ NotifyStatusChange(ctx context.Context)
+
+ // IsVirtual indicates whether or not this corresponds to a virtual
+ // resource.
+ //
+ // If IsVirtual returns true, then caching will be disabled for this
+ // node, and fs.Dirent.Freeze() will not stop operations on the node.
+ //
+ // Remove in favor of freezing specific mounts.
+ IsVirtual() bool
+
+ // StatFS returns a filesystem Info implementation or an error. If
+ // the filesystem does not support this operation (maybe in the future
+ // it will), then ENOSYS should be returned.
+ StatFS(context.Context) (Info, error)
+}
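
The GetXattr contract above (ENODATA for a missing attribute, with ERANGE as an optional response to an undersized size hint) can be illustrated with a toy in-memory store. A hedged sketch, not gVisor's implementation, using Linux errno values from the standard syscall package; this sketch opts in to the ERANGE check:

```go
package main

import (
	"fmt"
	"syscall"
)

type xattrs map[string]string

// Get returns the value for name, syscall.ENODATA if the attribute is
// absent, or syscall.ERANGE when a non-zero size hint is too small.
func (x xattrs) Get(name string, size uint64) (string, error) {
	v, ok := x[name]
	if !ok {
		return "", syscall.ENODATA
	}
	if size != 0 && uint64(len(v)) > size {
		return "", syscall.ERANGE
	}
	return v, nil
}

func main() {
	x := xattrs{"user.comment": "hello"}
	fmt.Println(x.Get("user.comment", 64)) // hello <nil>
	fmt.Println(x.Get("user.comment", 2))  // "" ERANGE
	fmt.Println(x.Get("user.other", 64))   // "" ENODATA
}
```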
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
new file mode 100644
index 000000000..537c8d257
--- /dev/null
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -0,0 +1,737 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func overlayHasWhiteout(ctx context.Context, parent *Inode, name string) bool {
+ s, err := parent.GetXattr(ctx, XattrOverlayWhiteout(name), 1)
+ return err == nil && s == "y"
+}
+
+func overlayCreateWhiteout(ctx context.Context, parent *Inode, name string) error {
+ return parent.InodeOperations.SetXattr(ctx, parent, XattrOverlayWhiteout(name), "y", 0 /* flags */)
+}
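+
+// An illustrative sketch of the whiteout convention the helpers above rely
+// on; the exact key produced by XattrOverlayWhiteout is an assumption here:
+//
+//	overlayCreateWhiteout(ctx, upper, "foo")
+//	// sets an xattr along the lines of "trusted.overlay.whiteout.foo" = "y"
+//	// on the upper directory, masking any "foo" in the lower layer.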
+
+func overlayWriteOut(ctx context.Context, o *overlayEntry) error {
+ // Hot path. Avoid defers.
+ var err error
+ o.copyMu.RLock()
+ if o.upper != nil {
+ err = o.upper.InodeOperations.WriteOut(ctx, o.upper)
+ }
+ o.copyMu.RUnlock()
+ return err
+}
+
+// overlayLookup performs a lookup in parent.
+//
+// If name exists, it returns true if the Dirent is in the upper, false if the
+// Dirent is in the lower.
+func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name string) (*Dirent, bool, error) {
+ // Hot path. Avoid defers.
+ parent.copyMu.RLock()
+
+ // Assert that there is at least one upper or lower entry.
+ if parent.upper == nil && parent.lower == nil {
+ parent.copyMu.RUnlock()
+ panic("invalid overlayEntry, needs at least one Inode")
+ }
+
+ var upperInode *Inode
+ var lowerInode *Inode
+
+ // We must remember whether the upper fs returned a negative dirent,
+ // because it is only safe to return one if the upper did.
+ var negativeUpperChild bool
+
+ // Does the parent directory exist in the upper file system?
+ if parent.upper != nil {
+ // First check if a file object exists in the upper file system.
+ // A file could have been created over a whiteout, so we need to
+ // check if something exists in the upper file system first.
+ child, err := parent.upper.Lookup(ctx, name)
+ if err != nil && err != syserror.ENOENT {
+ // We encountered an error that an overlay cannot handle,
+ // we must propagate it to the caller.
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ if child != nil {
+ if child.IsNegative() {
+ negativeUpperChild = true
+ } else {
+ upperInode = child.Inode
+ upperInode.IncRef()
+ }
+ child.DecRef()
+ }
+
+ // Are we done?
+ if overlayHasWhiteout(ctx, parent.upper, name) {
+ if upperInode == nil {
+ parent.copyMu.RUnlock()
+ if negativeUpperChild {
+ // If the upper fs returned a negative
+ // Dirent, then the upper is OK with
+ // that negative Dirent being cached in
+ // the Dirent tree, so we can return
+ // one from the overlay.
+ return NewNegativeDirent(name), false, nil
+ }
+ // Upper fs is not OK with a negative Dirent
+ // being cached in the Dirent tree, so don't
+ // return one.
+ return nil, false, syserror.ENOENT
+ }
+ entry, err := newOverlayEntry(ctx, upperInode, nil, false)
+ if err != nil {
+ // Don't leak resources.
+ upperInode.DecRef()
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ d, err := NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil
+ parent.copyMu.RUnlock()
+ return d, true, err
+ }
+ }
+
+ // Check the lower file system. We do this unconditionally (even for
+ // non-directories) because we may need to use stable attributes from
+ // the lower filesystem (e.g. device number, inode number) that were
+ // visible before a copy up.
+ if parent.lower != nil {
+ // Check the lower file system.
+ child, err := parent.lower.Lookup(ctx, name)
+ // Same song and dance as above.
+ if err != nil && err != syserror.ENOENT {
+ // Don't leak resources.
+ if upperInode != nil {
+ upperInode.DecRef()
+ }
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ if child != nil {
+ if !child.IsNegative() {
+ if upperInode == nil {
+ // If nothing was in the upper, use what we found in the lower.
+ lowerInode = child.Inode
+ lowerInode.IncRef()
+ } else {
+ // If we have something from the upper, we can only use it if the types
+ // match.
+ // NOTE(b/112312863): Allow SpecialDirectories and Directories to merge.
+ // This is needed to allow submounts in /proc and /sys.
+ if upperInode.StableAttr.Type == child.Inode.StableAttr.Type ||
+ (IsDir(upperInode.StableAttr) && IsDir(child.Inode.StableAttr)) {
+ lowerInode = child.Inode
+ lowerInode.IncRef()
+ }
+ }
+ }
+ child.DecRef()
+ }
+ }
+
+ // Was all of this for naught?
+ if upperInode == nil && lowerInode == nil {
+ parent.copyMu.RUnlock()
+ // We can only return a negative dirent if the upper returned
+ // one as well. See comments above regarding negativeUpperChild
+ // for more info.
+ if negativeUpperChild {
+ return NewNegativeDirent(name), false, nil
+ }
+ return nil, false, syserror.ENOENT
+ }
+
+ // Did we find a lower Inode? Remember this because we may decide we don't
+ // actually need the lower Inode (see below).
+ lowerExists := lowerInode != nil
+
+ // If we found something in the upper filesystem and the lower filesystem,
+ // use the stable attributes from the lower filesystem. If we don't do this,
+ // then it may appear that the file was magically recreated across copy up.
+ if upperInode != nil && lowerInode != nil {
+ // Steal attributes.
+ upperInode.StableAttr = lowerInode.StableAttr
+
+ // For non-directories, the lower filesystem resource is strictly
+ // unnecessary because we don't need to copy-up and we will always
+ // operate (e.g. read/write) on the upper Inode.
+ if !IsDir(upperInode.StableAttr) {
+ lowerInode.DecRef()
+ lowerInode = nil
+ }
+ }
+
+ // Phew, finally done.
+ entry, err := newOverlayEntry(ctx, upperInode, lowerInode, lowerExists)
+ if err != nil {
+ // Well, not quite, we failed at the last moment, how depressing.
+ // Be sure not to leak resources.
+ if upperInode != nil {
+ upperInode.DecRef()
+ }
+ if lowerInode != nil {
+ lowerInode.DecRef()
+ }
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ d, err := NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil
+ parent.copyMu.RUnlock()
+ return d, upperInode != nil, err
+}
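+
+// An illustrative summary of the precedence implemented by overlayLookup
+// above (not part of this change):
+//
+//	whiteout, nothing in upper -> ENOENT (negative Dirent if the upper
+//	                              allows caching one)
+//	child only in upper        -> upper Inode
+//	child only in lower        -> lower Inode
+//	child in both              -> upper Inode with the lower's stable
+//	                              attrs; the lower Inode is kept only
+//	                              for directories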
+
+func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) {
+ // Sanity check.
+ if parent.Inode.overlay == nil {
+ panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations))
+ }
+
+ // Dirent.Create takes renameMu if the Inode is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return nil, err
+ }
+
+ upperFile, err := o.upper.InodeOperations.Create(ctx, o.upper, name, flags, perm)
+ if err != nil {
+ return nil, err
+ }
+
+ // We've added to the directory so we must drop the cache.
+ o.markDirectoryDirty()
+
+ // Take another reference on the upper file's inode, which will be
+ // owned by the overlay entry.
+ upperFile.Dirent.Inode.IncRef()
+ entry, err := newOverlayEntry(ctx, upperFile.Dirent.Inode, nil, false)
+ if err != nil {
+ werr := fmt.Errorf("newOverlayEntry failed: %v", err)
+ cleanupUpper(ctx, o.upper, name, werr)
+ return nil, err
+ }
+
+ // NOTE(b/71766861): Replace the Dirent with a transient Dirent, since
+ // we are about to create the real Dirent: an overlay Dirent.
+ //
+ // This ensures the *fs.File returned from overlayCreate is in the same
+ // state as the *fs.File returned by overlayGetFile, where the upper
+ // file has a transient Dirent.
+ //
+ // This is necessary for Save/Restore, as otherwise the upper Dirent
+ // (which has no path as it is unparented and never reachable by the
+ // user) will clobber the real path for the underlying Inode.
+ upperFile.Dirent.Inode.IncRef()
+ upperDirent := NewTransientDirent(upperFile.Dirent.Inode)
+ upperFile.Dirent.DecRef()
+ upperFile.Dirent = upperDirent
+
+ // Create the overlay inode and dirent. We need this to construct the
+ // overlay file.
+ overlayInode := newOverlayInode(ctx, entry, parent.Inode.MountSource)
+ // d will own the inode reference.
+ overlayDirent := NewDirent(ctx, overlayInode, name)
+ // The overlay file created below with NewFile will take a reference on
+ // the overlayDirent, and it should be the only thing holding a
+ // reference at the time of creation, so we must drop this reference.
+ defer overlayDirent.DecRef()
+
+ // Create a new overlay file that wraps the upper file.
+ flags.Pread = upperFile.Flags().Pread
+ flags.Pwrite = upperFile.Flags().Pwrite
+ overlayFile := NewFile(ctx, overlayDirent, flags, &overlayFileOperations{upper: upperFile})
+
+ return overlayFile, nil
+}
+
+func overlayCreateDirectory(ctx context.Context, o *overlayEntry, parent *Dirent, name string, perm FilePermissions) error {
+ // Dirent.CreateDirectory takes renameMu if the Inode is an overlay
+ // Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ if err := o.upper.InodeOperations.CreateDirectory(ctx, o.upper, name, perm); err != nil {
+ return err
+ }
+ // We've added to the directory so we must drop the cache.
+ o.markDirectoryDirty()
+ return nil
+}
+
+func overlayCreateLink(ctx context.Context, o *overlayEntry, parent *Dirent, oldname string, newname string) error {
+ // Dirent.CreateLink takes renameMu if the Inode is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ if err := o.upper.InodeOperations.CreateLink(ctx, o.upper, oldname, newname); err != nil {
+ return err
+ }
+ // We've added to the directory so we must drop the cache.
+ o.markDirectoryDirty()
+ return nil
+}
+
+func overlayCreateHardLink(ctx context.Context, o *overlayEntry, parent *Dirent, target *Dirent, name string) error {
+ // Dirent.CreateHardLink takes renameMu if the Inode is an overlay
+ // Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ if err := copyUpLockedForRename(ctx, target); err != nil {
+ return err
+ }
+ if err := o.upper.InodeOperations.CreateHardLink(ctx, o.upper, target.Inode.overlay.upper, name); err != nil {
+ return err
+ }
+ // We've added to the directory so we must drop the cache.
+ o.markDirectoryDirty()
+ return nil
+}
+
+func overlayCreateFifo(ctx context.Context, o *overlayEntry, parent *Dirent, name string, perm FilePermissions) error {
+ // Dirent.CreateFifo takes renameMu if the Inode is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ if err := o.upper.InodeOperations.CreateFifo(ctx, o.upper, name, perm); err != nil {
+ return err
+ }
+ // We've added to the directory so we must drop the cache.
+ o.markDirectoryDirty()
+ return nil
+}
+
+func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child *Dirent) error {
+ // Dirent.Remove and Dirent.RemoveDirectory take renameMu if the Inode
+ // is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ child.Inode.overlay.copyMu.RLock()
+ defer child.Inode.overlay.copyMu.RUnlock()
+ if child.Inode.overlay.upper != nil {
+ if child.Inode.StableAttr.Type == Directory {
+ if err := o.upper.InodeOperations.RemoveDirectory(ctx, o.upper, child.name); err != nil {
+ return err
+ }
+ } else {
+ if err := o.upper.InodeOperations.Remove(ctx, o.upper, child.name); err != nil {
+ return err
+ }
+ }
+ }
+ if child.Inode.overlay.lowerExists {
+ if err := overlayCreateWhiteout(ctx, o.upper, child.name); err != nil {
+ return err
+ }
+ }
+ // We've removed from the directory so we must drop the cache.
+ o.markDirectoryDirty()
+ return nil
+}
+
+func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, renamed *Dirent, newParent *Dirent, newName string, replacement bool) error {
+ // To be able to copy these up below, they have to be part of an
+ // overlay file system.
+ //
+ // Maybe some day we can allow the more complicated case of
+ // non-overlay X overlay renames, but that's not necessary right now.
+ if renamed.Inode.overlay == nil || newParent.Inode.overlay == nil || oldParent.Inode.overlay == nil {
+ return syserror.EXDEV
+ }
+
+ if replacement {
+ // Check here if the file to be replaced exists and is a
+ // non-empty directory. If we copy up first, we may end up
+ // copying the directory but none of its children, so the
+ // directory will appear empty in the upper fs, which will then
+ // allow the rename to proceed when it should return ENOTEMPTY.
+ //
+ // NOTE(b/111808347): Ideally, we'd just pass in the replaced
+ // Dirent from Rename, but we must drop the reference on
+ // replaced before we make the rename call, so Rename can't
+ // pass the Dirent to the Inode without significantly
+ // complicating the API. Thus we look it up again here.
+ //
+ // For the same reason we can't use defer here.
+ replaced, inUpper, err := overlayLookup(ctx, newParent.Inode.overlay, newParent.Inode, newName)
+ // If err == ENOENT or a negative Dirent is returned, then
+ // newName has been removed out from under us. That's fine;
+ // filesystems where that can happen must handle stale
+ // 'replaced'.
+ if err != nil && err != syserror.ENOENT {
+ return err
+ }
+ if err == nil {
+ if !inUpper {
+ // newName doesn't exist in
+ // newParent.Inode.overlay.upper, thus from
+ // that Inode's perspective this won't be a
+ // replacing rename.
+ replacement = false
+ }
+
+ if !replaced.IsNegative() && IsDir(replaced.Inode.StableAttr) {
+ children, err := readdirOne(ctx, replaced)
+ if err != nil {
+ replaced.DecRef()
+ return err
+ }
+
+ // readdirOne ensures that "." and ".." are not
+ // included among the returned children, so we don't
+ // need to bother checking for them.
+ if len(children) > 0 {
+ replaced.DecRef()
+ return syserror.ENOTEMPTY
+ }
+ }
+
+ replaced.DecRef()
+ }
+ }
+
+ if err := copyUpLockedForRename(ctx, renamed); err != nil {
+ return err
+ }
+ if err := copyUpLockedForRename(ctx, newParent); err != nil {
+ return err
+ }
+ oldName := renamed.name
+ if err := o.upper.InodeOperations.Rename(ctx, renamed.Inode.overlay.upper, oldParent.Inode.overlay.upper, oldName, newParent.Inode.overlay.upper, newName, replacement); err != nil {
+ return err
+ }
+ if renamed.Inode.overlay.lowerExists {
+ if err := overlayCreateWhiteout(ctx, oldParent.Inode.overlay.upper, oldName); err != nil {
+ return err
+ }
+ }
+ // We've changed the directory so we must drop the cache.
+ oldParent.Inode.overlay.markDirectoryDirty()
+ return nil
+}
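+
+// The replacement check above guards cases like the following illustrative
+// scenario, assuming "dst" is a non-empty directory present only in the
+// lower layer:
+//
+//	rename("src", "dst")
+//	// must fail with ENOTEMPTY: copy-up would bring over dst itself but
+//	// none of its children, so the upper layer alone would look empty.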
+
+func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return nil, err
+ }
+
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ d, err := o.upper.InodeOperations.Bind(ctx, o.upper, name, data, perm)
+ if err != nil {
+ return nil, err
+ }
+
+ // We've added to the directory so we must drop the cache.
+ o.markDirectoryDirty()
+
+ // Grab the inode and drop the dirent, we don't need it.
+ inode := d.Inode
+ inode.IncRef()
+ d.DecRef()
+
+ // Create a new overlay entry and dirent for the socket.
+ entry, err := newOverlayEntry(ctx, inode, nil, false)
+ if err != nil {
+ inode.DecRef()
+ return nil, err
+ }
+ // Use the parent's MountSource, since that corresponds to the overlay,
+ // and not the upper filesystem.
+ return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil
+}
+
+func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ if o.upper != nil {
+ return o.upper.InodeOperations.BoundEndpoint(o.upper, path)
+ }
+
+ return o.lower.BoundEndpoint(path)
+}
+
+func overlayGetFile(ctx context.Context, o *overlayEntry, d *Dirent, flags FileFlags) (*File, error) {
+ // Hot path. Avoid defers.
+ if flags.Write {
+ if err := copyUp(ctx, d); err != nil {
+ return nil, err
+ }
+ }
+
+ o.copyMu.RLock()
+
+ if o.upper != nil {
+ upper, err := overlayFile(ctx, o.upper, flags)
+ if err != nil {
+ o.copyMu.RUnlock()
+ return nil, err
+ }
+ flags.Pread = upper.Flags().Pread
+ flags.Pwrite = upper.Flags().Pwrite
+ f, err := NewFile(ctx, d, flags, &overlayFileOperations{upper: upper}), nil
+ o.copyMu.RUnlock()
+ return f, err
+ }
+
+ lower, err := overlayFile(ctx, o.lower, flags)
+ if err != nil {
+ o.copyMu.RUnlock()
+ return nil, err
+ }
+ flags.Pread = lower.Flags().Pread
+ flags.Pwrite = lower.Flags().Pwrite
+ o.copyMu.RUnlock()
+ return NewFile(ctx, d, flags, &overlayFileOperations{lower: lower}), nil
+}
+
+func overlayUnstableAttr(ctx context.Context, o *overlayEntry) (UnstableAttr, error) {
+ // Hot path. Avoid defers.
+ var (
+ attr UnstableAttr
+ err error
+ )
+ o.copyMu.RLock()
+ if o.upper != nil {
+ attr, err = o.upper.UnstableAttr(ctx)
+ } else {
+ attr, err = o.lower.UnstableAttr(ctx)
+ }
+ o.copyMu.RUnlock()
+ return attr, err
+}
+
+func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uint64) (string, error) {
+ // Hot path. This is how the overlay checks for whiteout files.
+ // Avoid defers.
+ var (
+ s string
+ err error
+ )
+
+ // Don't forward the value of the extended attribute if it would
+ // unexpectedly change the behavior of a wrapping overlay layer.
+ if strings.HasPrefix(name, XattrOverlayPrefix) {
+ return "", syserror.ENODATA
+ }
+
+ o.copyMu.RLock()
+ if o.upper != nil {
+ s, err = o.upper.GetXattr(ctx, name, size)
+ } else {
+ s, err = o.lower.GetXattr(ctx, name, size)
+ }
+ o.copyMu.RUnlock()
+ return s, err
+}
+
+func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error {
+ // Don't allow changes to overlay xattrs through a setxattr syscall.
+ if strings.HasPrefix(name, XattrOverlayPrefix) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.SetXattr(ctx, d, name, value, flags)
+}
+
+func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[string]struct{}, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+ var names map[string]struct{}
+ var err error
+ if o.upper != nil {
+ names, err = o.upper.ListXattr(ctx, size)
+ } else {
+ names, err = o.lower.ListXattr(ctx, size)
+ }
+ for name := range names {
+ // Same as overlayGetXattr, we shouldn't forward along
+ // overlay attributes.
+ if strings.HasPrefix(name, XattrOverlayPrefix) {
+ delete(names, name)
+ }
+ }
+ return names, err
+}
+
+func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error {
+ // Don't allow changes to overlay xattrs through a removexattr syscall.
+ if strings.HasPrefix(name, XattrOverlayPrefix) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.RemoveXattr(ctx, d, name)
+}
+
+func overlayCheck(ctx context.Context, o *overlayEntry, p PermMask) error {
+ o.copyMu.RLock()
+ // Hot path. Avoid defers.
+ var err error
+ if o.upper != nil {
+ err = o.upper.check(ctx, p)
+ } else {
+ err = o.lower.check(ctx, p)
+ }
+ o.copyMu.RUnlock()
+ return err
+}
+
+func overlaySetPermissions(ctx context.Context, o *overlayEntry, d *Dirent, f FilePermissions) bool {
+ if err := copyUp(ctx, d); err != nil {
+ return false
+ }
+ return o.upper.InodeOperations.SetPermissions(ctx, o.upper, f)
+}
+
+func overlaySetOwner(ctx context.Context, o *overlayEntry, d *Dirent, owner FileOwner) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.SetOwner(ctx, o.upper, owner)
+}
+
+func overlaySetTimestamps(ctx context.Context, o *overlayEntry, d *Dirent, ts TimeSpec) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.SetTimestamps(ctx, o.upper, ts)
+}
+
+func overlayTruncate(ctx context.Context, o *overlayEntry, d *Dirent, size int64) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.Truncate(ctx, o.upper, size)
+}
+
+func overlayAllocate(ctx context.Context, o *overlayEntry, d *Dirent, offset, length int64) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.Allocate(ctx, o.upper, offset, length)
+}
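+
+// The mutating helpers above share a copy-up-on-mutate shape; a minimal
+// sketch of the pattern (overlayMutate and SomeMutation are illustrative
+// names only):
+//
+//	func overlayMutate(ctx context.Context, o *overlayEntry, d *Dirent) error {
+//		if err := copyUp(ctx, d); err != nil {
+//			return err // the mutation must land in the upper layer
+//		}
+//		return o.upper.InodeOperations.SomeMutation(ctx, o.upper)
+//	}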
+
+func overlayReadlink(ctx context.Context, o *overlayEntry) (string, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+ if o.upper != nil {
+ return o.upper.Readlink(ctx)
+ }
+ return o.lower.Readlink(ctx)
+}
+
+func overlayGetlink(ctx context.Context, o *overlayEntry) (*Dirent, error) {
+ var dirent *Dirent
+ var err error
+
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ if o.upper != nil {
+ dirent, err = o.upper.Getlink(ctx)
+ } else {
+ dirent, err = o.lower.Getlink(ctx)
+ }
+ if dirent != nil {
+ // This dirent is likely bogus (its Inode likely doesn't contain
+ // the right overlayEntry). So we're forced to drop it on the
+ // ground and claim that jumping around the filesystem like this
+ // is not supported.
+ name, _ := dirent.FullName(nil)
+ dirent.DecRef()
+
+ // Claim that the path is not accessible.
+ err = syserror.EACCES
+ log.Warningf("Getlink not supported in overlay for %q", name)
+ }
+ return nil, err
+}
+
+func overlayStatFS(ctx context.Context, o *overlayEntry) (Info, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ var i Info
+ var err error
+ if o.upper != nil {
+ i, err = o.upper.StatFS(ctx)
+ } else {
+ i, err = o.lower.StatFS(ctx)
+ }
+ if err != nil {
+ return Info{}, err
+ }
+
+ i.Type = linux.OVERLAYFS_SUPER_MAGIC
+
+ return i, nil
+}
+
+// NewTestOverlayDir returns an overlay Inode for tests.
+//
+// If `revalidate` is true, then the upper filesystem will require
+// revalidation.
+func NewTestOverlayDir(ctx context.Context, upper, lower *Inode, revalidate bool) *Inode {
+ fs := &overlayFilesystem{}
+ var upperMsrc *MountSource
+ if revalidate {
+ upperMsrc = NewRevalidatingMountSource(ctx, fs, MountSourceFlags{})
+ } else {
+ upperMsrc = NewNonCachingMountSource(ctx, fs, MountSourceFlags{})
+ }
+ msrc := NewMountSource(ctx, &overlayMountSourceOperations{
+ upper: upperMsrc,
+ lower: NewNonCachingMountSource(ctx, fs, MountSourceFlags{}),
+ }, fs, MountSourceFlags{})
+ overlay := &overlayEntry{
+ upper: upper,
+ lower: lower,
+ }
+ return newOverlayInode(ctx, overlay, msrc)
+}
+
+// TestHasUpperFS returns true if i is an overlay Inode and it has a pointer
+// to an Inode on an upper filesystem.
+func (i *Inode) TestHasUpperFS() bool {
+ return i.overlay != nil && i.overlay.upper != nil
+}
+
+// TestHasLowerFS returns true if i is an overlay Inode and it has a pointer
+// to an Inode on a lower filesystem.
+func (i *Inode) TestHasLowerFS() bool {
+ return i.overlay != nil && i.overlay.lower != nil
+}
diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go
new file mode 100644
index 000000000..389c219d6
--- /dev/null
+++ b/pkg/sentry/fs/inode_overlay_test.go
@@ -0,0 +1,470 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func TestLookup(t *testing.T) {
+ ctx := contexttest.Context(t)
+ for _, test := range []struct {
+ // Test description.
+ desc string
+
+ // Lookup parameters.
+ dir *fs.Inode
+ name string
+
+ // Want from lookup.
+ found bool
+ hasUpper bool
+ hasLower bool
+ }{
+ {
+ desc: "no upper, lower has name",
+ dir: fs.NewTestOverlayDir(ctx,
+ nil, /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: false,
+ hasLower: true,
+ },
+ {
+ desc: "no lower, upper has name",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* upper */
+ nil, /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: true,
+ hasLower: false,
+ },
+ {
+ desc: "upper and lower, only lower has name",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "b",
+ dir: false,
+ },
+ }, nil), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: false,
+ hasLower: true,
+ },
+ {
+ desc: "upper and lower, only upper has name",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "b",
+ dir: false,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: true,
+ hasLower: false,
+ },
+ {
+ desc: "upper and lower, both have file",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: true,
+ hasLower: false,
+ },
+ {
+ desc: "upper and lower, both have directory",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: true,
+ },
+ }, nil), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: true,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: true,
+ hasLower: true,
+ },
+ {
+ desc: "upper and lower, upper negative masks lower file",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, nil, []string{"a"}), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: false,
+ hasUpper: false,
+ hasLower: false,
+ },
+ {
+ desc: "upper and lower, upper negative does not mask lower file",
+ dir: fs.NewTestOverlayDir(ctx,
+ newTestRamfsDir(ctx, nil, []string{"b"}), /* upper */
+ newTestRamfsDir(ctx, []dirContent{
+ {
+ name: "a",
+ dir: false,
+ },
+ }, nil), /* lower */
+ false /* revalidate */),
+ name: "a",
+ found: true,
+ hasUpper: false,
+ hasLower: true,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ dirent, err := test.dir.Lookup(ctx, test.name)
+ if test.found && (err == syserror.ENOENT || dirent.IsNegative()) {
+ t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err)
+ }
+ if !test.found {
+ if err != syserror.ENOENT && !dirent.IsNegative() {
+ t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err)
+ }
+ // Nothing more to check.
+ return
+ }
+ if hasUpper := dirent.Inode.TestHasUpperFS(); hasUpper != test.hasUpper {
+ t.Fatalf("lookup got upper filesystem %v, want %v", hasUpper, test.hasUpper)
+ }
+ if hasLower := dirent.Inode.TestHasLowerFS(); hasLower != test.hasLower {
+ t.Errorf("lookup got lower filesystem %v, want %v", hasLower, test.hasLower)
+ }
+ })
+ }
+}
+
+func TestLookupRevalidation(t *testing.T) {
+ // File name used in the tests.
+ fileName := "foofile"
+ ctx := contexttest.Context(t)
+ for _, tc := range []struct {
+ // Test description.
+ desc string
+
+ // Upper and lower fs for the overlay.
+ upper *fs.Inode
+ lower *fs.Inode
+
+ // Whether the upper requires revalidation.
+ revalidate bool
+
+ // Whether we should get the same dirent on second lookup.
+ wantSame bool
+ }{
+ {
+ desc: "file from upper with no revalidation",
+ upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ lower: newTestRamfsDir(ctx, nil, nil),
+ revalidate: false,
+ wantSame: true,
+ },
+ {
+ desc: "file from upper with revalidation",
+ upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ lower: newTestRamfsDir(ctx, nil, nil),
+ revalidate: true,
+ wantSame: false,
+ },
+ {
+ desc: "file from lower with no revalidation",
+ upper: newTestRamfsDir(ctx, nil, nil),
+ lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ revalidate: false,
+ wantSame: true,
+ },
+ {
+ desc: "file from lower with revalidation",
+ upper: newTestRamfsDir(ctx, nil, nil),
+ lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ revalidate: true,
+ // The file does not exist in the upper, so we do not
+ // need to revalidate it.
+ wantSame: true,
+ },
+ {
+ desc: "file from upper and lower with no revalidation",
+ upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ revalidate: false,
+ wantSame: true,
+ },
+ {
+ desc: "file from upper and lower with revalidation",
+ upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil),
+ revalidate: true,
+ wantSame: false,
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ root := fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root")
+ ctx = &rootContext{
+ Context: ctx,
+ root: root,
+ }
+ overlay := fs.NewDirent(ctx, fs.NewTestOverlayDir(ctx, tc.upper, tc.lower, tc.revalidate), "overlay")
+ // Lookup the file twice through the overlay.
+ first, err := overlay.Walk(ctx, root, fileName)
+ if err != nil {
+ t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err)
+ }
+ second, err := overlay.Walk(ctx, root, fileName)
+ if err != nil {
+ t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err)
+ }
+
+ if tc.wantSame && first != second {
+ t.Errorf("dirent lookup got different dirents, wanted same\nfirst=%+v\nsecond=%+v", first, second)
+ } else if !tc.wantSame && first == second {
+ t.Errorf("dirent lookup got the same dirent, wanted different: %+v", first)
+ }
+ })
+ }
+}
+
+func TestCacheFlush(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ // Upper and lower each have a file.
+ upperFileName := "file-from-upper"
+ lowerFileName := "file-from-lower"
+ upper := newTestRamfsDir(ctx, []dirContent{{name: upperFileName}}, nil)
+ lower := newTestRamfsDir(ctx, []dirContent{{name: lowerFileName}}, nil)
+
+ overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */)
+
+ mns, err := fs.NewMountNamespace(ctx, overlay)
+ if err != nil {
+ t.Fatalf("NewMountNamespace failed: %v", err)
+ }
+ root := mns.Root()
+ defer root.DecRef()
+
+ ctx = &rootContext{
+ Context: ctx,
+ root: root,
+ }
+
+ for _, fileName := range []string{upperFileName, lowerFileName} {
+ // Walk to the file.
+ maxTraversals := uint(0)
+ dirent, err := mns.FindInode(ctx, root, nil, fileName, &maxTraversals)
+ if err != nil {
+ t.Fatalf("FindInode(%q) failed: %v", fileName, err)
+ }
+
+ // Get a file from the dirent.
+ file, err := dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true})
+ if err != nil {
+ t.Fatalf("GetFile() failed: %v", err)
+ }
+
+ // The dirent should have 3 refs, one from us, one from the
+ // file, and one from the dirent cache.
+ if got, want := dirent.ReadRefs(), 3; int(got) != want {
+ t.Errorf("dirent.ReadRefs() got %d want %d", got, want)
+ }
+
+ // Drop the file reference.
+ file.DecRef()
+
+ // Dirent should have 2 refs left.
+ if got, want := dirent.ReadRefs(), 2; int(got) != want {
+ t.Errorf("dirent.ReadRefs() got %d want %d", got, want)
+ }
+
+ // Flush the dirent cache.
+ mns.FlushMountSourceRefs()
+
+ // Dirent should have 1 ref left from the dirent cache.
+ if got, want := dirent.ReadRefs(), 1; int(got) != want {
+ t.Errorf("dirent.ReadRefs() got %d want %d", got, want)
+ }
+
+ // Drop our ref.
+ dirent.DecRef()
+
+ // We should be back to zero refs.
+ if got, want := dirent.ReadRefs(), 0; int(got) != want {
+ t.Errorf("dirent.ReadRefs() got %d want %d", got, want)
+ }
+ }
+}
+
+type dir struct {
+ fs.InodeOperations
+
+ // List of negative child names.
+ negative []string
+
+ // ReaddirCalled records whether Readdir was called on a file
+ // corresponding to this inode.
+ ReaddirCalled bool
+}
+
+// GetXattr implements InodeOperations.GetXattr.
+func (d *dir) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
+ for _, n := range d.negative {
+ if name == fs.XattrOverlayWhiteout(n) {
+ return "y", nil
+ }
+ }
+ return "", syserror.ENOATTR
+}
+
+// GetFile implements InodeOperations.GetFile.
+func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ file, err := d.InodeOperations.GetFile(ctx, dirent, flags)
+ if err != nil {
+ return nil, err
+ }
+ defer file.DecRef()
+ // Wrap the file's FileOperations in a dirFile.
+ fops := &dirFile{
+ FileOperations: file.FileOperations,
+ inode: d,
+ }
+ return fs.NewFile(ctx, dirent, flags, fops), nil
+}
+
+type dirContent struct {
+ name string
+ dir bool
+}
+
+type dirFile struct {
+ fs.FileOperations
+ inode *dir
+}
+
+type inode struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeStaticFileGetter
+}
+
+// Readdir implements fs.FileOperations.Readdir. It sets the ReaddirCalled
+// field on the inode.
+func (f *dirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ f.inode.ReaddirCalled = true
+ return f.FileOperations.Readdir(ctx, file, ser)
+}
+
+func newTestRamfsInode(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ inode := fs.NewInode(ctx, &inode{
+ InodeStaticFileGetter: fsutil.InodeStaticFileGetter{
+ Contents: []byte("foobar"),
+ },
+ }, msrc, fs.StableAttr{Type: fs.RegularFile})
+ return inode
+}
+
+func newTestRamfsDir(ctx context.Context, contains []dirContent, negative []string) *fs.Inode {
+ msrc := fs.NewPseudoMountSource(ctx)
+ contents := make(map[string]*fs.Inode)
+ for _, c := range contains {
+ if c.dir {
+ contents[c.name] = newTestRamfsDir(ctx, nil, nil)
+ } else {
+ contents[c.name] = newTestRamfsInode(ctx, msrc)
+ }
+ }
+ dops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermissions{
+ User: fs.PermMask{Read: true, Execute: true},
+ })
+ return fs.NewInode(ctx, &dir{
+ InodeOperations: dops,
+ negative: negative,
+ }, msrc, fs.StableAttr{Type: fs.Directory})
+}
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
new file mode 100644
index 000000000..e3a715c1f
--- /dev/null
+++ b/pkg/sentry/fs/inotify.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Inotify represents an inotify instance created by inotify_init(2) or
+// inotify_init1(2). Inotify implements the FileOperations interface.
+//
+// Lock ordering:
+// Inotify.mu -> Inode.Watches.mu -> Watch.mu -> Inotify.evMu
+//
+// +stateify savable
+type Inotify struct {
+ // Unique identifier for this inotify instance. We don't just reuse the
+ // inotify fd because fds can be duped. These should not be exposed to the
+ // user, since we may aggressively reuse an id on S/R.
+ id uint64
+
+ waiter.Queue `state:"nosave"`
+
+ // evMu *only* protects the events list. We need a separate lock because
+ // while queuing events, a watch needs to lock the event queue, and using mu
+ // for that would violate lock ordering since at that point the calling
+ // goroutine already holds Watch.target.Watches.mu.
+ evMu sync.Mutex `state:"nosave"`
+
+ // A list of pending events for this inotify instance. Protected by evMu.
+ events eventList
+
+ // A scratch buffer, used to serialize inotify events. We allocate it
+ // ahead of time and reuse it for performance. Protected by evMu.
+ scratch []byte
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // The next watch descriptor number to use for this inotify instance. Note
+ // that Linux starts numbering watch descriptors from 1.
+ nextWatch int32
+
+ // Map from watch descriptors to watch objects.
+ watches map[int32]*Watch
+}
+
+// NewInotify constructs a new Inotify instance.
+func NewInotify(ctx context.Context) *Inotify {
+ return &Inotify{
+ id: uniqueid.GlobalFromContext(ctx),
+ scratch: make([]byte, inotifyEventBaseSize),
+ nextWatch: 1, // Linux starts numbering watch descriptors from 1.
+ watches: make(map[int32]*Watch),
+ }
+}
+
+// Release implements FileOperations.Release. Release removes all watches and
+// frees all resources for an inotify instance.
+func (i *Inotify) Release() {
+ // We need to hold i.mu to avoid a race with concurrent calls to
+ // Inotify.targetDestroyed from Watches. There's no risk of Watches
+ // accessing this Inotify after the destructor ends, because we remove all
+ // references to it below.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ for _, w := range i.watches {
+ // Remove references to the watch from the watch target. We don't need
+ // to worry about the references from the owner instance, since we're in
+ // the owner's destructor.
+ w.target.Watches.Remove(w.ID())
+ // Don't leak any references to the target, held by pins in the watch.
+ w.destroy()
+ }
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// Readiness indicates whether there are pending events for an inotify instance.
+func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if !i.events.Empty() {
+ ready |= waiter.EventIn
+ }
+
+ return mask & ready
+}
+
+// Seek implements FileOperations.Seek.
+func (*Inotify) Seek(context.Context, *File, SeekWhence, int64) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Readdir implements FileOperations.Readdir.
+func (*Inotify) Readdir(context.Context, *File, DentrySerializer) (int64, error) {
+ return 0, syserror.ENOTDIR
+}
+
+// Write implements FileOperations.Write.
+func (*Inotify) Write(context.Context, *File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements FileOperations.Read.
+func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() < inotifyEventBaseSize {
+ return 0, syserror.EINVAL
+ }
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if i.events.Empty() {
+ // Nothing to read yet, tell caller to block.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var writeLen int64
+ for it := i.events.Front(); it != nil; {
+ event := it
+ it = it.Next()
+
+ // Does the buffer have enough remaining space to hold the event we're
+ // about to write out?
+ if dst.NumBytes() < int64(event.sizeOf()) {
+ if writeLen > 0 {
+ // Buffer wasn't big enough for all pending events, but we did
+ // write some events out.
+ return writeLen, nil
+ }
+ return 0, syserror.EINVAL
+ }
+
+ // Linux always dequeues an available event as long as there's enough
+ // buffer space to copy it out, even if the copy below fails. Emulate
+ // this behaviour.
+ i.events.Remove(event)
+
+ // Buffer has enough space, copy event to the read buffer.
+ n, err := event.CopyTo(ctx, i.scratch, dst)
+ if err != nil {
+ return 0, err
+ }
+
+ writeLen += n
+ dst = dst.DropFirst64(n)
+ }
+ return writeLen, nil
+}
+
+// WriteTo implements FileOperations.WriteTo.
+func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Fsync implements FileOperations.Fsync.
+func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
+ return syserror.EINVAL
+}
+
+// ReadFrom implements FileOperations.ReadFrom.
+func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Flush implements FileOperations.Flush.
+func (*Inotify) Flush(context.Context, *File) error {
+ return nil
+}
+
+// ConfigureMMap implements FileOperations.ConfigureMMap.
+func (*Inotify) ConfigureMMap(context.Context, *File, *memmap.MMapOpts) error {
+ return syserror.ENODEV
+}
+
+// UnstableAttr implements FileOperations.UnstableAttr.
+func (i *Inotify) UnstableAttr(ctx context.Context, file *File) (UnstableAttr, error) {
+ return file.Dirent.Inode.UnstableAttr(ctx)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (i *Inotify) Ioctl(ctx context.Context, _ *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch args[1].Int() {
+ case linux.FIONREAD:
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+ var n uint32
+ for e := i.events.Front(); e != nil; e = e.Next() {
+ n += uint32(e.sizeOf())
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], n)
+ _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func (i *Inotify) queueEvent(ev *Event) {
+ i.evMu.Lock()
+
+ // Check if we should coalesce the event we're about to queue with the last
+ // one currently in the queue. Events are coalesced if they are identical.
+ if last := i.events.Back(); last != nil {
+ if ev.equals(last) {
+ // "Coalesce" the two events by simply not queuing the new one. We
+ // don't need to raise a waiter.EventIn notification because no new
+ // data is available for reading.
+ i.evMu.Unlock()
+ return
+ }
+ }
+
+ i.events.PushBack(ev)
+
+ // Release mutex before notifying waiters because we don't control what they
+ // can do.
+ i.evMu.Unlock()
+
+ i.Queue.Notify(waiter.EventIn)
+}
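+
+// Illustrative only: two identical events queued back to back coalesce into
+// a single queue entry, mirroring inotify(7) event merging:
+//
+//	i.queueEvent(newEvent(wd, "", linux.IN_MODIFY, 0))
+//	i.queueEvent(newEvent(wd, "", linux.IN_MODIFY, 0)) // equals the last
+//	                                                   // event; not queued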
+
+// newWatchLocked creates and adds a new watch to target.
+func (i *Inotify) newWatchLocked(target *Dirent, mask uint32) *Watch {
+ wd := i.nextWatch
+ i.nextWatch++
+
+ watch := &Watch{
+ owner: i,
+ wd: wd,
+ mask: mask,
+ target: target.Inode,
+ pins: make(map[*Dirent]bool),
+ }
+
+ i.watches[wd] = watch
+
+ // Grab an extra reference to target to prevent it from being evicted from
+ // memory. This ref is dropped during either watch removal, target
+ // destruction, or inotify instance destruction. See callers of Watch.Unpin.
+ watch.Pin(target)
+ target.Inode.Watches.Add(watch)
+
+ return watch
+}
+
+// targetDestroyed is called by w to notify i that w's target is gone. This
+// automatically generates a watch removal event.
+func (i *Inotify) targetDestroyed(w *Watch) {
+ i.mu.Lock()
+ _, found := i.watches[w.wd]
+ delete(i.watches, w.wd)
+ i.mu.Unlock()
+
+ if found {
+ i.queueEvent(newEvent(w.wd, "", linux.IN_IGNORED, 0))
+ }
+}
+
+// AddWatch constructs a new inotify watch and adds it to the target dirent. It
+// returns the watch descriptor returned by inotify_add_watch(2).
+func (i *Inotify) AddWatch(target *Dirent, mask uint32) int32 {
+ // Note: Locking this inotify instance protects the result returned by
+ // Lookup() below. With the lock held, we know for sure the lookup result
+ // won't become stale because it's impossible for *this* instance to
+ // add/remove watches on target.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ // Does the target already have a watch from this inotify instance?
+ if existing := target.Inode.Watches.Lookup(i.id); existing != nil {
+ // This may be a watch on a different dirent pointing to the
+ // same inode. Obtain an extra reference if necessary.
+ existing.Pin(target)
+
+ newmask := mask
+ if mergeMask := mask&linux.IN_MASK_ADD != 0; mergeMask {
+ // "Add (OR) events to watch mask for this pathname if it already
+ // exists (instead of replacing mask)." -- inotify(7)
+ newmask |= atomic.LoadUint32(&existing.mask)
+ }
+ atomic.StoreUint32(&existing.mask, newmask)
+ return existing.wd
+ }
+
+ // No existing watch, create a new watch.
+ watch := i.newWatchLocked(target, mask)
+ return watch.wd
+}
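+
+// An illustrative use of the IN_MASK_ADD handling above: a second AddWatch
+// on the same target returns the same descriptor, and IN_MASK_ADD ORs the
+// masks instead of replacing them:
+//
+//	wd := i.AddWatch(d, linux.IN_CREATE)
+//	wd2 := i.AddWatch(d, linux.IN_DELETE|linux.IN_MASK_ADD)
+//	// wd2 == wd; the watch now fires for both IN_CREATE and IN_DELETE.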
+
+// RmWatch implements watcher.Watchable.RmWatch.
+//
+// RmWatch looks up an inotify watch for the given 'wd' and configures the
+// target dirent to stop sending events to this inotify instance.
+func (i *Inotify) RmWatch(wd int32) error {
+ i.mu.Lock()
+
+ // Find the watch we were asked to remove.
+ watch, ok := i.watches[wd]
+ if !ok {
+ i.mu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // Remove the watch from this instance.
+ delete(i.watches, wd)
+
+ // Remove the watch from the watch target.
+ watch.target.Watches.Remove(watch.ID())
+
+ // The watch is now isolated and we can safely drop the instance lock. We
+ // need to do so because watch.destroy() acquires Watch.mu, which cannot be
+ // acquired with Inotify.mu held.
+ i.mu.Unlock()
+
+ // Generate the event for the removal.
+ i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
+
+ // Remove all pins.
+ watch.destroy()
+
+ return nil
+}
diff --git a/pkg/sentry/fs/inotify_event.go b/pkg/sentry/fs/inotify_event.go
new file mode 100644
index 000000000..686e1b1cd
--- /dev/null
+++ b/pkg/sentry/fs/inotify_event.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// inotifyEventBaseSize is the base size of linux's struct inotify_event. This
+// must be a power of 2 for the rounding below.
+const inotifyEventBaseSize = 16
+
+// Event represents a struct inotify_event from linux.
+//
+// +stateify savable
+type Event struct {
+ eventEntry
+
+ wd int32
+ mask uint32
+ cookie uint32
+
+ // len is computed from the name field and is set automatically by
+ // Event.setName. It should be 0 when no name is set; otherwise it is the
+ // length of the name slice.
+ len uint32
+
+ // The name field has special padding requirements and should only be set by
+ // calling Event.setName.
+ name []byte
+}
+
+func newEvent(wd int32, name string, events, cookie uint32) *Event {
+ e := &Event{
+ wd: wd,
+ mask: events,
+ cookie: cookie,
+ }
+ if name != "" {
+ e.setName(name)
+ }
+ return e
+}
+
+// paddedBytes converts a go string to a null-terminated c-string, padded with
+// null bytes to a total size of 'l'. 'l' must be large enough for all the bytes
+// in 's' plus at least one null byte.
+func paddedBytes(s string, l uint32) []byte {
+ if l < uint32(len(s)+1) {
+ panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!")
+ }
+ b := make([]byte, l)
+ copy(b, s)
+
+ // b was zero-value initialized during make(), so the rest of the slice is
+ // already filled with null bytes.
+
+ return b
+}
+
+// setName sets the optional name for this event.
+func (e *Event) setName(name string) {
+ // We need to pad the name such that the entire event length ends up a
+ // multiple of inotifyEventBaseSize.
+ unpaddedLen := len(name) + 1
+ // Round up to nearest multiple of inotifyEventBaseSize.
+ e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1))
+ // Make sure we haven't overflowed and wrapped around when rounding.
+ if unpaddedLen > int(e.len) {
+ panic("Overflow when rounding inotify event size, the 'name' field was too big.")
+ }
+ e.name = paddedBytes(name, e.len)
+}
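+
+// A worked example of the rounding above: for name = "foo", unpaddedLen is 4
+// (three bytes plus the terminating NUL), (4+16-1) &^ 15 = 16, so e.len = 16
+// and the whole event occupies 16 + 16 = 32 bytes.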
+
+func (e *Event) sizeOf() int {
+ s := inotifyEventBaseSize + int(e.len)
+ if s < inotifyEventBaseSize {
+ panic("overflow")
+ }
+ return s
+}
+
+// CopyTo serializes this event to dst. buf is used as a scratch buffer to
+// construct the output. We use a buffer allocated ahead of time for
+// performance. buf must be at least inotifyEventBaseSize bytes.
+func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
+ usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ usermem.ByteOrder.PutUint32(buf[4:], e.mask)
+ usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
+ usermem.ByteOrder.PutUint32(buf[12:], e.len)
+
+ writeLen := 0
+
+ n, err := dst.CopyOut(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ dst = dst.DropFirst(n)
+
+ if e.len > 0 {
+ n, err = dst.CopyOut(ctx, e.name)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ }
+
+ // Sanity check.
+ if writeLen != e.sizeOf() {
+ panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %v, wrote %v.", e.sizeOf(), writeLen))
+ }
+
+ return int64(writeLen), nil
+}
+
+func (e *Event) equals(other *Event) bool {
+ return e.wd == other.wd &&
+ e.mask == other.mask &&
+ e.cookie == other.cookie &&
+ e.len == other.len &&
+ bytes.Equal(e.name, other.name)
+}
diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go
new file mode 100644
index 000000000..900cba3ca
--- /dev/null
+++ b/pkg/sentry/fs/inotify_watch.go
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Watch represents a particular inotify watch created by inotify_add_watch.
+//
+// While a watch is active, it ensures the target inode is pinned in memory by
+// holding an extra ref on each dirent known (by inotify) to point to the
+// inode. These are known as pins. For a full discussion, see
+// fs/g3doc/inotify.md.
+//
+// +stateify savable
+type Watch struct {
+ // Inotify instance which owns this watch.
+ owner *Inotify
+
+ // Descriptor for this watch. This is unique across an inotify instance.
+ wd int32
+
+ // The inode being watched. Note that we don't directly hold a reference on
+ // this inode. Instead we hold a reference on the dirent(s) containing the
+ // inode, which we record in pins.
+ target *Inode
+
+ // unpinned indicates whether we have a hard reference on target. This field
+ // may only be modified through atomic ops.
+ unpinned uint32
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // Events being monitored via this watch. Must be accessed atomically,
+ // writes are protected by mu.
+ mask uint32
+
+ // pins is the set of dirents this watch is currently pinning in memory by
+ // holding a reference to them. See Pin()/Unpin().
+ pins map[*Dirent]bool
+}
+
+// ID returns the id of the inotify instance that owns this watch.
+func (w *Watch) ID() uint64 {
+ return w.owner.id
+}
+
+// NotifyParentAfterUnlink indicates whether the parent of the watched object
+// should continue to be notified of events after the target has been
+// unlinked.
+func (w *Watch) NotifyParentAfterUnlink() bool {
+ return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK == 0
+}
+
+// isRenameEvent returns true if eventMask describes a rename event.
+func isRenameEvent(eventMask uint32) bool {
+ return eventMask&(linux.IN_MOVED_FROM|linux.IN_MOVED_TO|linux.IN_MOVE_SELF) != 0
+}
+
+// Notify queues a new event on this watch.
+func (w *Watch) Notify(name string, events uint32, cookie uint32) {
+ mask := atomic.LoadUint32(&w.mask)
+ if mask&events == 0 {
+ // We weren't watching for this event.
+ return
+ }
+
+ // Event mask should include bits matched from the watch plus all control
+ // event bits.
+ unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS
+ effectiveMask := unmaskableBits | mask
+ matchedEvents := effectiveMask & events
+ w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie))
+}
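+
+// A worked example of the masking above: with mask = IN_CREATE and
+// events = IN_CREATE|IN_ISDIR, the queued event carries both bits, because
+// IN_ISDIR lies outside linux.IN_ALL_EVENTS and is therefore one of the
+// unmaskable control bits.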
+
+// Pin acquires a new ref on dirent, which pins the dirent in memory while
+// the watch is active. Calling Pin for a second time on the same dirent for
+// the same watch is a no-op.
+func (w *Watch) Pin(d *Dirent) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if !w.pins[d] {
+ w.pins[d] = true
+ d.IncRef()
+ }
+}
+
+// Unpin drops any extra refs held on dirent due to a previous Pin
+// call. Calling Unpin multiple times for the same dirent, or on a dirent
+// without a corresponding Pin call is a no-op.
+func (w *Watch) Unpin(d *Dirent) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.pins[d] {
+ delete(w.pins, d)
+ d.DecRef()
+ }
+}
+
+// TargetDestroyed notifies the owner of the watch that the watch target is
+// gone. The owner should release its own references to the watcher upon
+// receiving this notification.
+func (w *Watch) TargetDestroyed() {
+ w.owner.targetDestroyed(w)
+}
+
+// destroy prepares the watch for destruction. It unpins all dirents pinned by
+// this watch. Destroy does not cause any new events to be generated. The caller
+// is responsible for ensuring there are no outstanding references to this
+// watch.
+func (w *Watch) destroy() {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ for d := range w.pins {
+ d.DecRef()
+ }
+ w.pins = nil
+}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
new file mode 100644
index 000000000..ae3331737
--- /dev/null
+++ b/pkg/sentry/fs/lock/BUILD
@@ -0,0 +1,58 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "lock_range",
+ out = "lock_range.go",
+ package = "lock",
+ prefix = "Lock",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
+go_template_instance(
+ name = "lock_set",
+ out = "lock_set.go",
+ consts = {
+ "minDegree": "3",
+ },
+ package = "lock",
+ prefix = "Lock",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "LockRange",
+ "Value": "Lock",
+ "Functions": "lockSetFunctions",
+ },
+)
+
+go_library(
+ name = "lock",
+ srcs = [
+ "lock.go",
+ "lock_range.go",
+ "lock_set.go",
+ "lock_set_functions.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "lock_test",
+ size = "small",
+ srcs = [
+ "lock_range_test.go",
+ "lock_test.go",
+ ],
+ library = ":lock",
+)
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
new file mode 100644
index 000000000..8a5d9c7eb
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock.go
@@ -0,0 +1,453 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lock is the API for POSIX-style advisory regional file locks and
+// BSD-style full file locks.
+//
+// Callers needing to enforce these types of locks, like sys_fcntl, can call
+// LockRegion and UnlockRegion on a thread-safe set of Locks. Locks are
+// specific to a unique file (unique device/inode pair) and for this reason
+// should not be shared between files.
+//
+// A Lock has a set of holders identified by UniqueID. Normally this is the
+// pid of the thread attempting to acquire the lock.
+//
+// Since these are advisory locks, they do not need to be integrated into
+// Reads/Writes and for this reason there is no way to *check* if a lock is
+// held. One can only attempt to take a lock or unlock an existing lock.
+//
+// A Lock in a set of Locks is typed: it is either a read lock with any number
+// of readers and no writer, or a write lock with no readers.
+//
+// As expected from POSIX, any attempt to acquire a write lock on a file region
+// when there already exists a write lock held by a different uid will fail. Any
+// attempt to acquire a write lock on a file region when there is more than one
+// reader will fail. Any attempt to acquire a read lock on a file region when
+// there is already a writer will fail.
+//
+// In special cases, a read lock may be upgraded to a write lock and a write lock
+// can be downgraded to a read lock. This can only happen if:
+//
+// * read lock upgrade to write lock: There can be only one reader and the reader
+// must be the same as the requested write lock holder.
+//
+// * write lock downgrade to read lock: The writer must be the same as the requested
+// read lock holder.
+//
+// UnlockRegion always succeeds. If LockRegion fails the caller should normally
+// interpret this as "try again later".
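+//
+// A minimal non-blocking usage sketch (uid is a hypothetical caller-chosen
+// UniqueID; passing a nil Blocker makes the attempt non-blocking):
+//
+//	var locks Locks
+//	r := LockRange{Start: 0, End: 4096}
+//	if locks.LockRegion(uid, WriteLock, r, nil) {
+//		// ... access the region under the advisory lock ...
+//		locks.UnlockRegion(uid, r)
+//	}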
+package lock
+
+import (
+ "fmt"
+ "math"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LockType is a type of regional file lock.
+type LockType int
+
+// UniqueID is a unique identifier of the holder of a regional file lock.
+type UniqueID interface{}
+
+const (
+ // ReadLock describes a POSIX regional file lock to be taken
+ // read only. There may be multiple of these locks on a single
+ // file region as long as there is no writer lock on the same
+ // region.
+ ReadLock LockType = iota
+
+ // WriteLock describes a POSIX regional file lock to be taken
+ // write only. There may be only a single holder of this lock
+ // and no read locks.
+ WriteLock
+)
+
+// LockEOF is the maximal possible end of a regional file lock.
+//
+// A BSD-style full file lock can be represented as a regional file lock from
+// offset 0 to LockEOF.
+const LockEOF = math.MaxUint64
+
+// Lock is a regional file lock. It consists of either a single writer
+// or a set of readers.
+//
+// A Lock may be upgraded from a read lock to a write lock only if there
+// is a single reader and that reader has the same uid as the write lock.
+//
+// A Lock may be downgraded from a write lock to a read lock only if
+// the write lock's uid is the same as the read lock.
+//
+// +stateify savable
+type Lock struct {
+ // Readers are the set of read lock holders identified by UniqueID.
+ // If len(Readers) > 0 then Writer must be nil.
+ Readers map[UniqueID]bool
+
+ // Writer holds the writer unique ID. It's nil if there are no writers.
+ Writer UniqueID
+}
+
+// Locks is a thread-safe wrapper around a LockSet.
+//
+// +stateify savable
+type Locks struct {
+ // mu protects locks below.
+ mu sync.Mutex `state:"nosave"`
+
+ // locks is the set of region locks currently held on an Inode.
+ locks LockSet
+
+ // blockedQueue is the queue of waiters that are waiting on a lock.
+ blockedQueue waiter.Queue `state:"zerovalue"`
+}
+
+// Blocker is the interface used for blocking locks. Passing a nil Blocker
+// will be treated as non-blocking.
+type Blocker interface {
+ Block(C <-chan struct{}) error
+}
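+
+// A hypothetical Blocker sketch (not part of this package's API): it simply
+// waits for an unlock event on the channel.
+//
+//	type sleeper struct{}
+//
+//	func (sleeper) Block(C <-chan struct{}) error {
+//		<-C // Woken by Locks.blockedQueue.Notify on unlock.
+//		return nil
+//	}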
+
+const (
+ // EventMaskAll is the mask we will always use for locks; by using the
+ // same mask all the time we can wake up everyone whenever the lock
+ // changes state.
+ EventMaskAll waiter.EventMask = 0xFFFF
+)
+
+// LockRegion attempts to acquire a typed lock for the uid on a region
+// of a file. It returns true if the region was successfully locked. If false
+// is returned, the caller should normally interpret this as "try again later"
+// when in non-blocking mode, or as "interrupted" when in blocking mode. The
+// Blocker argument supplies the blocking behavior; passing a nil Blocker
+// results in non-blocking behavior.
+func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) bool {
+ for {
+ l.mu.Lock()
+
+ // Blocking locks must run in a loop because we'll be woken up whenever an unlock event
+ // happens for this lock. We will then attempt to take the lock again and,
+ // if that fails, continue blocking.
+ res := l.locks.lock(uid, t, r)
+ if !res && block != nil {
+ e, ch := waiter.NewChannelEntry(nil)
+ l.blockedQueue.EventRegister(&e, EventMaskAll)
+ l.mu.Unlock()
+ if err := block.Block(ch); err != nil {
+ // We were interrupted, the caller can translate this to EINTR if applicable.
+ l.blockedQueue.EventUnregister(&e)
+ return false
+ }
+ l.blockedQueue.EventUnregister(&e)
+ continue // Try again now that someone has unlocked.
+ }
+
+ l.mu.Unlock()
+ return res
+ }
+}
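+
+// A blocking acquisition sketch (blocker is assumed to be supplied by the
+// caller, e.g. a task that can be interrupted):
+//
+//	if !locks.LockRegion(uid, ReadLock, r, blocker) {
+//		// Interrupted while blocking; a syscall layer would return EINTR.
+//	}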
+
+// UnlockRegion attempts to release a lock for the uid on a region of a file.
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (l *Locks) UnlockRegion(uid UniqueID, r LockRange) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.locks.unlock(uid, r)
+
+ // Now that we've released the lock, we need to wake up any waiters.
+ l.blockedQueue.Notify(EventMaskAll)
+}
+
+// makeLock returns a new typed Lock that has either uid as its only reader
+// or uid as its only writer.
+func makeLock(uid UniqueID, t LockType) Lock {
+ value := Lock{Readers: make(map[UniqueID]bool)}
+ switch t {
+ case ReadLock:
+ value.Readers[uid] = true
+ case WriteLock:
+ value.Writer = uid
+ default:
+ panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
+ }
+ return value
+}
+
+// isHeld returns true if uid is a holder of Lock.
+func (l Lock) isHeld(uid UniqueID) bool {
+ return l.Writer == uid || l.Readers[uid]
+}
+
+// lock sets uid as a holder of a typed lock on Lock.
+//
+// Preconditions: canLock is true for the range containing this Lock.
+func (l *Lock) lock(uid UniqueID, t LockType) {
+ switch t {
+ case ReadLock:
+ // If we are already a reader, then this is a no-op.
+ if l.Readers[uid] {
+ return
+ }
+ // We cannot downgrade a write lock to a read lock unless the
+ // uid is the same.
+ if l.Writer != nil {
+ if l.Writer != uid {
+ panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
+ }
+ // Ensure that there is only one reader if upgrading.
+ l.Readers = make(map[UniqueID]bool)
+ // Ensure that there is no longer a writer.
+ l.Writer = nil
+ }
+ l.Readers[uid] = true
+ return
+ case WriteLock:
+ // If we are already the writer, then this is a no-op.
+ if l.Writer == uid {
+ return
+ }
+ // We can only upgrade a read lock to a write lock if there
+ // is only one reader and that reader has the same uid as
+ // the write lock.
+ if readers := len(l.Readers); readers > 0 {
+ if readers != 1 {
+ panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, too many readers %v", uid, l.Readers))
+ }
+ if !l.Readers[uid] {
+ panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, conflicting reader %v", uid, l.Readers))
+ }
+ }
+ // Ensure that there is only a writer.
+ l.Readers = make(map[UniqueID]bool)
+ l.Writer = uid
+ default:
+ panic(fmt.Sprintf("lock: invalid lock type %d", t))
+ }
+}
+
+// lockable returns true if check returns true for every Lock in LockRange.
+// Further, check should return true if Lock meets the caller's requirements
+// for locking Lock.
+func (l LockSet) lockable(r LockRange, check func(value Lock) bool) bool {
+ // Get our starting point.
+ seg := l.LowerBoundSegment(r.Start)
+ for seg.Ok() && seg.Start() < r.End {
+ // Note that we don't care about overrunning the end of the
+ // last segment because if everything checks out we'll just
+ // split the last segment.
+ if !check(seg.Value()) {
+ return false
+ }
+ // Jump to the next segment, ignoring gaps: a gap means no lock
+ // covers that part of the range, so nothing there can conflict.
+ seg = seg.NextSegment()
+ }
+ // No conflict, we can get a lock for uid over the entire range.
+ return true
+}
+
+// canLock returns true if uid will be able to take a Lock of type t on the
+// entire range specified by LockRange.
+func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
+ switch t {
+ case ReadLock:
+ return l.lockable(r, func(value Lock) bool {
+ // If there is no writer, there's no problem adding another reader.
+ if value.Writer == nil {
+ return true
+ }
+ // If there is a writer, then it must be the same uid
+ // in order to downgrade the lock to a read lock.
+ return value.Writer == uid
+ })
+ case WriteLock:
+ return l.lockable(r, func(value Lock) bool {
+ // If there are only readers.
+ if value.Writer == nil {
+ // Then this uid can only take a write lock if this is a private
+ // upgrade, meaning that the only reader is uid.
+ return len(value.Readers) == 1 && value.Readers[uid]
+ }
+ // If the uid is already a writer on this region, then
+ // adding a write lock would be a no-op.
+ return value.Writer == uid
+ })
+ default:
+ panic(fmt.Sprintf("canLock: invalid lock type %d", t))
+ }
+}
+
+// lock returns true if uid took a lock of type t on the entire range of
+// LockRange.
+//
+// Preconditions: r.Start <= r.End (will panic otherwise).
+func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
+ if r.Start > r.End {
+ panic(fmt.Sprintf("lock: r.Start %d > r.End %d", r.Start, r.End))
+ }
+
+ // Don't attempt to insert anything with a range of 0 and treat this
+ // as a successful no-op.
+ if r.Length() == 0 {
+ return true
+ }
+
+ // Do a first-pass check. We *could* hold onto the segments we
+ // checked if canLock would return true, but traversing the segment
+ // set should be fast and this keeps things simple.
+ if !l.canLock(uid, t, r) {
+ return false
+ }
+ // Get our starting point.
+ seg, gap := l.Find(r.Start)
+ if gap.Ok() {
+ // Fill in the gap and get the next segment to modify.
+ seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, t)).NextSegment()
+ } else if seg.Start() < r.Start {
+ // Get our first segment to modify.
+ _, seg = l.Split(seg, r.Start)
+ }
+ for seg.Ok() && seg.Start() < r.End {
+ // Split the last one if necessary.
+ if seg.End() > r.End {
+ seg, _ = l.SplitUnchecked(seg, r.End)
+ }
+
+ // Set the lock on the segment. This is guaranteed to
+ // always be safe, given canLock above.
+ value := seg.ValuePtr()
+ value.lock(uid, t)
+
+ // Fill subsequent gaps.
+ gap = seg.NextGap()
+ if gr := gap.Range().Intersect(r); gr.Length() > 0 {
+ seg = l.Insert(gap, gr, makeLock(uid, t)).NextSegment()
+ } else {
+ seg = gap.NextSegment()
+ }
+ }
+ return true
+}
+
+// unlock is always successful. If uid has no locks held for the range LockRange,
+// unlock is a no-op.
+//
+// Preconditions: same as lock.
+func (l *LockSet) unlock(uid UniqueID, r LockRange) {
+ if r.Start > r.End {
+ panic(fmt.Sprintf("unlock: r.Start %d > r.End %d", r.Start, r.End))
+ }
+
+ // As in lock, treat a zero-length range as a successful no-op.
+ if r.Length() == 0 {
+ return
+ }
+
+ // Get our starting point.
+ seg := l.LowerBoundSegment(r.Start)
+ for seg.Ok() && seg.Start() < r.End {
+ // If this segment doesn't have a lock from uid then
+ // there is no need to fragment the set with Isolate (below).
+ // In this case just move on to the next segment.
+ if !seg.Value().isHeld(uid) {
+ seg = seg.NextSegment()
+ continue
+ }
+
+ // Ensure that if we need to unlock a sub-segment that
+ // we don't unlock/remove that entire segment.
+ seg = l.Isolate(seg, r)
+
+ value := seg.Value()
+ var remove bool
+ if value.Writer == uid {
+ // If we are unlocking a writer, then since there can
+ // only ever be one writer and no readers, then this
+ // lock should always be removed from the set.
+ remove = true
+ } else if value.Readers[uid] {
+ // If uid is the last reader, then just remove the entire
+ // segment.
+ if len(value.Readers) == 1 {
+ remove = true
+ } else {
+ // Otherwise we need to remove this reader without
+ // affecting any other segment's readers. To do
+ // this, we need to make a copy of the Readers map
+ // and not add this uid.
+ newValue := Lock{Readers: make(map[UniqueID]bool)}
+ for k, v := range value.Readers {
+ if k != uid {
+ newValue.Readers[k] = v
+ }
+ }
+ seg.SetValue(newValue)
+ }
+ }
+ if remove {
+ seg = l.Remove(seg).NextSegment()
+ } else {
+ seg = seg.NextSegment()
+ }
+ }
+}
+
+// ComputeRange takes a file offset (pre-computed from l_whence) and computes
+// the start of a LockRange using start (relative to that offset) and the end
+// of the LockRange using length. The values of start and length may be
+// negative, but the resulting LockRange must satisfy LockRange.Start <
+// LockRange.End and LockRange.Start >= 0.
+func ComputeRange(start, length, offset int64) (LockRange, error) {
+ offset += start
+ // fcntl(2): "l_start can be a negative number provided the offset
+ // does not lie before the start of the file"
+ if offset < 0 {
+ return LockRange{}, syscall.EINVAL
+ }
+
+ // fcntl(2): Specifying 0 for l_len has the special meaning: lock all
+ // bytes starting at the location specified by l_whence and l_start
+ // through to the end of file, no matter how large the file grows.
+ end := uint64(LockEOF)
+ if length > 0 {
+ // fcntl(2): If l_len is positive, then the range to be locked
+ // covers bytes l_start up to and including l_start+l_len-1.
+ //
+ // Since LockRange.End is exclusive, we need not subtract 1 from the length.
+ end = uint64(offset + length)
+ } else if length < 0 {
+ // fcntl(2): If l_len is negative, the interval described by
+ // lock covers bytes l_start+l_len up to and including l_start-1.
+ //
+ // Since LockRange.End is exclusive, we need not subtract 1 from the offset.
+ signedEnd := offset
+ // Add to offset using a negative length (subtract).
+ offset += length
+ if offset < 0 {
+ return LockRange{}, syscall.EINVAL
+ }
+ if signedEnd < offset {
+ return LockRange{}, syscall.EOVERFLOW
+ }
+ // At this point signedEnd cannot be negative: we asserted
+ // above that offset is not negative, and signedEnd is not
+ // less than offset.
+ end = uint64(signedEnd)
+ }
+ // Offset is guaranteed to be positive at this point.
+ return LockRange{Start: uint64(offset), End: end}, nil
+}
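+
+// Worked examples, mirroring cases exercised in lock_range_test.go:
+//
+//	ComputeRange(-2048, 2048, 4096) // => LockRange{Start: 2048, End: 4096}
+//	ComputeRange(-2048, 0, 4096)    // => LockRange{Start: 2048, End: LockEOF}
+//	ComputeRange(0, -4096, 4096)    // => LockRange{Start: 0, End: 4096}
+//	ComputeRange(-4096, 0, 0)       // => EINVAL: range starts before the file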
diff --git a/pkg/sentry/fs/lock/lock_range_test.go b/pkg/sentry/fs/lock/lock_range_test.go
new file mode 100644
index 000000000..6221199d1
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_range_test.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lock
+
+import (
+ "syscall"
+ "testing"
+)
+
+func TestComputeRange(t *testing.T) {
+ tests := []struct {
+ // Description of test.
+ name string
+
+ // Requested start of the lock range.
+ start int64
+
+ // Requested length of the lock range,
+ // can be negative :(
+ length int64
+
+ // Pre-computed file offset based on whence.
+ // Will be added to start.
+ offset int64
+
+ // Expected error.
+ err error
+
+ // If error is nil, the expected LockRange.
+ LockRange
+ }{
+ {
+ name: "offset, start, and length all zero",
+ LockRange: LockRange{Start: 0, End: LockEOF},
+ },
+ {
+ name: "zero offset, zero start, positive length",
+ start: 0,
+ length: 4096,
+ offset: 0,
+ LockRange: LockRange{Start: 0, End: 4096},
+ },
+ {
+ name: "zero offset, negative start",
+ start: -4096,
+ offset: 0,
+ err: syscall.EINVAL,
+ },
+ {
+ name: "large offset, negative start, positive length",
+ start: -2048,
+ length: 2048,
+ offset: 4096,
+ LockRange: LockRange{Start: 2048, End: 4096},
+ },
+ {
+ name: "large offset, negative start, zero length",
+ start: -2048,
+ length: 0,
+ offset: 4096,
+ LockRange: LockRange{Start: 2048, End: LockEOF},
+ },
+ {
+ name: "zero offset, zero start, negative length",
+ start: 0,
+ length: -4096,
+ offset: 0,
+ err: syscall.EINVAL,
+ },
+ {
+ name: "large offset, zero start, negative length",
+ start: 0,
+ length: -4096,
+ offset: 4096,
+ LockRange: LockRange{Start: 0, End: 4096},
+ },
+ {
+ name: "offset, start, and length equal, length is negative",
+ start: 1024,
+ length: -1024,
+ offset: 1024,
+ LockRange: LockRange{Start: 1024, End: 2048},
+ },
+ {
+ name: "offset, start, and length equal, start is negative",
+ start: -1024,
+ length: 1024,
+ offset: 1024,
+ LockRange: LockRange{Start: 0, End: 1024},
+ },
+ {
+ name: "offset, start, and length equal, offset is negative",
+ start: 1024,
+ length: 1024,
+ offset: -1024,
+ LockRange: LockRange{Start: 0, End: 1024},
+ },
+ {
+ name: "offset, start, and length equal, all negative",
+ start: -1024,
+ length: -1024,
+ offset: -1024,
+ err: syscall.EINVAL,
+ },
+ {
+ name: "offset, start, and length equal, all positive",
+ start: 1024,
+ length: 1024,
+ offset: 1024,
+ LockRange: LockRange{Start: 2048, End: 3072},
+ },
+ }
+
+ for _, test := range tests {
+ rng, err := ComputeRange(test.start, test.length, test.offset)
+ if err != test.err {
+ t.Errorf("%s: lockRange(%d, %d, %d) got error %v, want %v", test.name, test.start, test.length, test.offset, err, test.err)
+ continue
+ }
+ if err == nil && rng != test.LockRange {
+ t.Errorf("%s: lockRange(%d, %d, %d) got LockRange %v, want %v", test.name, test.start, test.length, test.offset, rng, test.LockRange)
+ }
+ }
+}
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
new file mode 100644
index 000000000..50a16e662
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -0,0 +1,63 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lock
+
+import (
+ "math"
+)
+
+// LockSet maps a set of Locks into a file. The key is the file offset.
+
+type lockSetFunctions struct{}
+
+func (lockSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (lockSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (lockSetFunctions) ClearValue(l *Lock) {
+ *l = Lock{}
+}
+
+func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) (Lock, bool) {
+ // Merge only if the Readers/Writers are identical.
+ if len(val1.Readers) != len(val2.Readers) {
+ return Lock{}, false
+ }
+ for k := range val1.Readers {
+ if !val2.Readers[k] {
+ return Lock{}, false
+ }
+ }
+ if val1.Writer != val2.Writer {
+ return Lock{}, false
+ }
+ return val1, true
+}
+
+func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock) {
+ // Copy the segment so that split segments don't contain map references
+ // to other segments.
+ val0 := Lock{Readers: make(map[UniqueID]bool)}
+ for k, v := range val.Readers {
+ val0.Readers[k] = v
+ }
+ val0.Writer = val.Writer
+
+ return val, val0
+}
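+
+// For example, splitting a segment whose value has Readers {1: true, 2: true}
+// yields two equal values that share no map storage, so a later lock call that
+// adds a reader to one half cannot affect the other.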
diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go
new file mode 100644
index 000000000..fad90984b
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_test.go
@@ -0,0 +1,1060 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lock
+
+import (
+ "reflect"
+ "testing"
+)
+
+type entry struct {
+ Lock
+ LockRange
+}
+
+func equals(e0, e1 []entry) bool {
+ if len(e0) != len(e1) {
+ return false
+ }
+ for i := range e0 {
+ for k := range e0[i].Lock.Readers {
+ if !e1[i].Lock.Readers[k] {
+ return false
+ }
+ }
+ for k := range e1[i].Lock.Readers {
+ if !e0[i].Lock.Readers[k] {
+ return false
+ }
+ }
+ if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) {
+ return false
+ }
+ if e0[i].Lock.Writer != e1[i].Lock.Writer {
+ return false
+ }
+ }
+ return true
+}
+
+// fill constructs a LockSet from consecutive region locks. It panics if
+// the LockRanges are not consecutive.
+func fill(entries []entry) LockSet {
+ l := LockSet{}
+ for _, e := range entries {
+ gap := l.FindGap(e.LockRange.Start)
+ if !gap.Ok() {
+ panic("cannot insert into existing segment")
+ }
+ l.Insert(gap, e.LockRange, e.Lock)
+ }
+ return l
+}
+
+func TestCanLockEmpty(t *testing.T) {
+ l := LockSet{}
+
+ // Expect to be able to take any locks given that the set is empty.
+ eof := l.FirstGap().End()
+ r := LockRange{0, eof}
+ if !l.canLock(1, ReadLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1)
+ }
+ if !l.canLock(2, ReadLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2)
+ }
+ if !l.canLock(1, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1)
+ }
+ if !l.canLock(2, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2)
+ }
+}
+
+func TestCanLock(t *testing.T) {
+ // + -------------- + ---------- + -------------- + --------- +
+ // | Readers 1 & 2 | Readers 1 | Readers 1 & 3 | Writer 1 |
+ // + ------------- + ---------- + -------------- + --------- +
+ // 0 1024 2048 3072 4096
+ l := fill([]entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ LockRange: LockRange{1024, 2048},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true, 3: true}},
+ LockRange: LockRange{2048, 3072},
+ },
+ {
+ Lock: Lock{Writer: 1},
+ LockRange: LockRange{3072, 4096},
+ },
+ })
+
+ // Now that we have a mildly interesting layout, try some checks on different
+ // ranges, uids, and lock types.
+ //
+ // Expect to be able to extend the read lock, despite the writer lock, because
+ // the writer has the same uid as the requested read lock.
+ r := LockRange{0, 8192}
+ if !l.canLock(1, ReadLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1)
+ }
+ // Expect to *not* be able to extend the read lock since there is an overlapping
+ // writer region locked by someone other than the uid.
+ if l.canLock(2, ReadLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got true, want false", ReadLock, r, 2)
+ }
+ // Expect to be able to extend the read lock if there are only other readers in
+ // the way.
+ r = LockRange{64, 3072}
+ if !l.canLock(2, ReadLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2)
+ }
+ // Expect to be able to set a read lock beyond the range of any existing locks.
+ r = LockRange{4096, 10240}
+ if !l.canLock(2, ReadLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2)
+ }
+
+ // Expect to not be able to take a write lock with other readers in the way.
+ r = LockRange{0, 8192}
+ if l.canLock(1, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 1)
+ }
+ // Expect to be able to extend the write lock for the same uid.
+ r = LockRange{3072, 8192}
+ if !l.canLock(1, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1)
+ }
+ // Expect to not be able to overlap a write lock for two different uids.
+ if l.canLock(2, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 2)
+ }
+ // Expect to be able to set a write lock that is beyond the range of any
+ // existing locks.
+ r = LockRange{8192, 10240}
+ if !l.canLock(2, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2)
+ }
+ // Expect to be able to upgrade a read lock (any portion of it).
+ r = LockRange{1024, 2048}
+ if !l.canLock(1, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1)
+ }
+ r = LockRange{1080, 2000}
+ if !l.canLock(1, WriteLock, r) {
+ t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1)
+ }
+}
+
+func TestSetLock(t *testing.T) {
+ tests := []struct {
+ // description of test.
+ name string
+
+ // LockSet entries to pre-fill.
+ before []entry
+
+ // Description of region to lock:
+ //
+ // start is the file offset of the lock.
+ start uint64
+ // end is the end file offset of the lock.
+ end uint64
+ // uid of lock attempter.
+ uid UniqueID
+ // lock type requested.
+ lockType LockType
+
+ // success is true if taking the above
+ // lock should succeed.
+ success bool
+
+ // Expected layout of the set after locking
+ // if success is true.
+ after []entry
+ }{
+ {
+ name: "set zero length ReadLock on empty set",
+ start: 0,
+ end: 0,
+ uid: 0,
+ lockType: ReadLock,
+ success: true,
+ },
+ {
+ name: "set zero length WriteLock on empty set",
+ start: 0,
+ end: 0,
+ uid: 0,
+ lockType: WriteLock,
+ success: true,
+ },
+ {
+ name: "set ReadLock on empty set",
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ lockType: ReadLock,
+ success: true,
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ },
+ {
+ name: "set WriteLock on empty set",
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ lockType: WriteLock,
+ success: true,
+ // + ----------------------------------------- +
+ // | Writer 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ },
+ {
+ name: "set ReadLock on WriteLock same uid",
+ // + ----------------------------------------- +
+ // | Writer 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 0,
+ lockType: ReadLock,
+ success: true,
+ // + ----------- + --------------------------- +
+ // | Readers 0 | Writer 0 |
+ // + ----------- + --------------------------- +
+ // 0 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, 4096},
+ },
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "set WriteLock on ReadLock same uid",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 0,
+ lockType: WriteLock,
+ success: true,
+ // + ----------- + --------------------------- +
+ // | Writer 0 | Readers 0 |
+ // + ----------- + --------------------------- +
+ // 0 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "set ReadLock on WriteLock different uid",
+ // + ----------------------------------------- +
+ // | Writer 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 1,
+ lockType: ReadLock,
+ success: false,
+ },
+ {
+ name: "set WriteLock on ReadLock different uid",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 1,
+ lockType: WriteLock,
+ success: false,
+ },
+ {
+ name: "split ReadLock for overlapping lock at start 0",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 1,
+ lockType: ReadLock,
+ success: true,
+ // + -------------- + --------------------------- +
+ // | Readers 0 & 1 | Readers 0 |
+ // + -------------- + --------------------------- +
+ // 0 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "split ReadLock for overlapping lock at non-zero start",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 4096,
+ end: 8192,
+ uid: 1,
+ lockType: ReadLock,
+ success: true,
+ // + ---------- + -------------- + ----------- +
+ // | Readers 0 | Readers 0 & 1 | Readers 0 |
+ // + ---------- + -------------- + ----------- +
+ // 0 4096 8192 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{4096, 8192},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{8192, LockEOF},
+ },
+ },
+ },
+ {
+ name: "fill front gap with ReadLock",
+ // + --------- + ---------------------------- +
+ // | gap | Readers 0 |
+ // + --------- + ---------------------------- +
+ // 0 1024 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, LockEOF},
+ },
+ },
+ start: 0,
+ end: 8192,
+ uid: 0,
+ lockType: ReadLock,
+ success: true,
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ },
+ {
+ name: "fill end gap with ReadLock",
+ // + ---------------------------- +
+ // | Readers 0 |
+ // + ---------------------------- +
+ // 0 4096
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, 4096},
+ },
+ },
+ start: 1024,
+ end: LockEOF,
+ uid: 0,
+ lockType: ReadLock,
+ success: true,
+ // Note that this is not merged after lock does a Split. This is
+ // fine because the two locks will still *behave* as one. In other
+ // words we can fragment any lock all we want and semantically it
+ // makes no difference.
+ //
+ // + ----------- + --------------------------- +
+ // | Readers 0 | Readers 0 |
+ // + ----------- + --------------------------- +
+ // 0 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, LockEOF},
+ },
+ },
+ },
+ {
+ name: "fill gap with ReadLock and split",
+ // + --------- + ---------------------------- +
+ // | gap | Readers 0 |
+ // + --------- + ---------------------------- +
+ // 0 1024 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 1,
+ lockType: ReadLock,
+ success: true,
+ // + --------- + ------------- + ------------- +
+ // | Reader 1 | Readers 0 & 1 | Reader 0 |
+ // + ----------+ ------------- + ------------- +
+ // 0 1024 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{1024, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "upgrade ReadLock to WriteLock for single uid fill gap",
+ // + ------------- + --------- + --- + ------------- +
+ // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 |
+ // + ------------- + --------- + --- + ------------- +
+ // 0 1024 2048 4096 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, 2048},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ start: 1024,
+ end: 4096,
+ uid: 0,
+ lockType: WriteLock,
+ success: true,
+ // + ------------- + -------- + ------------- +
+ // | Readers 0 & 1 | Writer 0 | Readers 0 & 2 |
+ // + ------------- + -------- + ------------- +
+ // 0 1024 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{1024, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "upgrade ReadLock to WriteLock for single uid keep gap",
+ // + ------------- + --------- + --- + ------------- +
+ // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 |
+ // + ------------- + --------- + --- + ------------- +
+ // 0 1024 2048 4096 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, 2048},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ start: 1024,
+ end: 3072,
+ uid: 0,
+ lockType: WriteLock,
+ success: true,
+ // + ------------- + -------- + --- + ------------- +
+ // | Readers 0 & 1 | Writer 0 | gap | Readers 0 & 2 |
+ // + ------------- + -------- + --- + ------------- +
+ // 0 1024 3072 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{1024, 3072},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "fail to upgrade ReadLock to WriteLock with conflicting Reader",
+ // + ------------- + --------- +
+ // | Readers 0 & 1 | Readers 0 |
+ // + ------------- + --------- +
+ // 0 1024 2048
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, 2048},
+ },
+ },
+ start: 0,
+ end: 2048,
+ uid: 0,
+ lockType: WriteLock,
+ success: false,
+ },
+ {
+ name: "take WriteLock on whole file if all uids are the same",
+ // + ------------- + --------- + --------- + ---------- +
+ // | Writer 0 | Readers 0 | Readers 0 | Readers 0 |
+ // + ------------- + --------- + --------- + ---------- +
+ // 0 1024 2048 4096 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{1024, 2048},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{2048, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ lockType: WriteLock,
+ success: true,
+ // We do not manually merge locks. Semantically a fragmented lock
+ // held by the same uid will behave as one lock so it makes no difference.
+ //
+ // + ------------- + ---------------------------- +
+ // | Writer 0 | Writer 0 |
+ // + ------------- + ---------------------------- +
+ // 0 1024 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{1024, LockEOF},
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
+
+ r := LockRange{Start: test.start, End: test.end}
+ success := l.lock(test.uid, test.lockType, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
+
+ if success != test.success {
+ t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success)
+ return
+ }
+
+ if success {
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
+ }
+ })
+ }
+}
+
+func TestUnlock(t *testing.T) {
+ tests := []struct {
+ // description of test.
+ name string
+
+ // LockSet entries to pre-fill.
+ before []entry
+
+ // Description of region to unlock:
+ //
+ // start is the file offset of the lock.
+ start uint64
+ // end is the end file offset of the lock.
+ end uint64
+ // uid of lock holder.
+ uid UniqueID
+
+ // Expected layout of the set after unlocking.
+ after []entry
+ }{
+ {
+ name: "unlock zero length on empty set",
+ start: 0,
+ end: 0,
+ uid: 0,
+ },
+ {
+ name: "unlock on empty set (no-op)",
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ },
+ {
+ name: "unlock uid not locked (no-op)",
+ // + --------------------------- +
+ // | Readers 1 & 2 |
+ // + --------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 1024,
+ end: 4096,
+ uid: 0,
+ // + --------------------------- +
+ // | Readers 1 & 2 |
+ // + --------------------------- +
+ // 0 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ },
+ {
+ name: "unlock ReadLock over entire file",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ },
+ {
+ name: "unlock WriteLock over entire file",
+ // + ----------------------------------------- +
+ // | Writer 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ },
+ {
+ name: "unlock partial ReadLock (start)",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 0,
+ // + ------ + --------------------------- +
+ // | gap | Readers 0 |
+ // +------- + --------------------------- +
+ // 0 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "unlock partial WriteLock (start)",
+ // + ----------------------------------------- +
+ // | Writer 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 0,
+ end: 4096,
+ uid: 0,
+ // + ------ + --------------------------- +
+ // | gap | Writer 0 |
+ // +------- + --------------------------- +
+ // 0 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "unlock partial ReadLock (end)",
+ // + ----------------------------------------- +
+ // | Readers 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 4096,
+ end: LockEOF,
+ uid: 0,
+ // + --------------------------- +
+ // | Readers 0 |
+ // +---------------------------- +
+ // 0 4096
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ LockRange: LockRange{0, 4096},
+ },
+ },
+ },
+ {
+ name: "unlock partial WriteLock (end)",
+ // + ----------------------------------------- +
+ // | Writer 0 |
+ // + ----------------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 4096,
+ end: LockEOF,
+ uid: 0,
+ // + --------------------------- +
+ // | Writer 0 |
+ // +---------------------------- +
+ // 0 4096
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 4096},
+ },
+ },
+ },
+ {
+ name: "unlock for single uid",
+ // + ------------- + --------- + ------------------- +
+ // | Readers 0 & 1 | Writer 0 | Readers 0 & 1 & 2 |
+ // + ------------- + --------- + ------------------- +
+ // 0 1024 4096 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{1024, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ start: 0,
+ end: LockEOF,
+ uid: 0,
+ // + --------- + --- + --------------- +
+ // | Readers 1 | gap | Readers 1 & 2 |
+ // + --------- + --- + --------------- +
+ // 0 1024 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "unlock subsection locked",
+ // + ------------------------------- +
+ // | Readers 0 & 1 & 2 |
+ // + ------------------------------- +
+ // 0 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ LockRange: LockRange{0, LockEOF},
+ },
+ },
+ start: 1024,
+ end: 4096,
+ uid: 0,
+ // + ----------------- + ------------- + ----------------- +
+ // | Readers 0 & 1 & 2 | Readers 1 & 2 | Readers 0 & 1 & 2 |
+ // + ----------------- + ------------- + ----------------- +
+ // 0 1024 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ LockRange: LockRange{1024, 4096},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "unlock mid-gap to increase gap",
+ // + --------- + ----- + ------------------- +
+ // | Writer 0 | gap | Readers 0 & 1 |
+ // + --------- + ----- + ------------------- +
+ // 0 1024 4096 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ start: 8,
+ end: 2048,
+ uid: 0,
+ // + --------- + ----- + ------------------- +
+ // | Writer 0 | gap | Readers 0 & 1 |
+ // + --------- + ----- + ------------------- +
+ // 0 8 4096 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 8},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ },
+ {
+ name: "unlock split region on uid mid-gap",
+ // + --------- + ----- + ------------------- +
+ // | Writer 0 | gap | Readers 0 & 1 |
+ // + --------- + ----- + ------------------- +
+ // 0 1024 4096 max uint64
+ before: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{4096, LockEOF},
+ },
+ },
+ start: 2048,
+ end: 8192,
+ uid: 0,
+ // + --------- + ----- + --------- + ------------- +
+ // | Writer 0 | gap | Readers 1 | Readers 0 & 1 |
+ // + --------- + ----- + --------- + ------------- +
+ // 0 1024 4096 8192 max uint64
+ after: []entry{
+ {
+ Lock: Lock{Writer: 0},
+ LockRange: LockRange{0, 1024},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ LockRange: LockRange{4096, 8192},
+ },
+ {
+ Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ LockRange: LockRange{8192, LockEOF},
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
+
+ r := LockRange{Start: test.start, End: test.end}
+ l.unlock(test.uid, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fs/mock.go b/pkg/sentry/fs/mock.go
new file mode 100644
index 000000000..1d6ea5736
--- /dev/null
+++ b/pkg/sentry/fs/mock.go
@@ -0,0 +1,176 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// MockInodeOperations implements InodeOperations for testing Inodes.
+type MockInodeOperations struct {
+ InodeOperations
+
+ UAttr UnstableAttr
+
+ createCalled bool
+ createDirectoryCalled bool
+ createLinkCalled bool
+ renameCalled bool
+ walkCalled bool
+}
+
+// NewMockInode returns a mock *Inode using MockInodeOperations.
+func NewMockInode(ctx context.Context, msrc *MountSource, sattr StableAttr) *Inode {
+ return NewInode(ctx, NewMockInodeOperations(ctx), msrc, sattr)
+}
+
+// NewMockInodeOperations returns a *MockInodeOperations.
+func NewMockInodeOperations(ctx context.Context) *MockInodeOperations {
+ return &MockInodeOperations{
+ UAttr: WithCurrentTime(ctx, UnstableAttr{
+ Perms: FilePermsFromMode(0777),
+ }),
+ }
+}
+
+// MockMountSourceOps implements fs.MountSourceOperations.
+type MockMountSourceOps struct {
+ MountSourceOperations
+ keep bool
+ revalidate bool
+}
+
+// NewMockMountSource returns a new *MountSource using MockMountSourceOps.
+func NewMockMountSource(cache *DirentCache) *MountSource {
+ var keep bool
+ if cache != nil {
+ keep = cache.maxSize > 0
+ }
+ return &MountSource{
+ MountSourceOperations: &MockMountSourceOps{keep: keep},
+ fscache: cache,
+ }
+}
+
+// Revalidate implements fs.MountSourceOperations.Revalidate.
+func (n *MockMountSourceOps) Revalidate(context.Context, string, *Inode, *Inode) bool {
+ return n.revalidate
+}
+
+// Keep implements fs.MountSourceOperations.Keep.
+func (n *MockMountSourceOps) Keep(dirent *Dirent) bool {
+ return n.keep
+}
+
+// CacheReaddir implements fs.MountSourceOperations.CacheReaddir.
+func (n *MockMountSourceOps) CacheReaddir() bool {
+ // Common case: cache readdir results if there is a dirent cache.
+ return n.keep
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (n *MockInodeOperations) WriteOut(context.Context, *Inode) error {
+ return nil
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (n *MockInodeOperations) UnstableAttr(context.Context, *Inode) (UnstableAttr, error) {
+ return n.UAttr, nil
+}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (n *MockInodeOperations) IsVirtual() bool {
+ return false
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (n *MockInodeOperations) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) {
+ n.walkCalled = true
+ return NewDirent(ctx, NewInode(ctx, &MockInodeOperations{}, dir.MountSource, StableAttr{}), p), nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (n *MockInodeOperations) SetPermissions(context.Context, *Inode, FilePermissions) bool {
+ return false
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (*MockInodeOperations) SetOwner(context.Context, *Inode, FileOwner) error {
+ return syserror.EINVAL
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (n *MockInodeOperations) SetTimestamps(context.Context, *Inode, TimeSpec) error {
+ return nil
+}
+
+// Create implements fs.InodeOperations.Create.
+func (n *MockInodeOperations) Create(ctx context.Context, dir *Inode, p string, flags FileFlags, perms FilePermissions) (*File, error) {
+ n.createCalled = true
+ d := NewDirent(ctx, NewInode(ctx, &MockInodeOperations{}, dir.MountSource, StableAttr{}), p)
+ return &File{Dirent: d}, nil
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (n *MockInodeOperations) CreateLink(_ context.Context, dir *Inode, oldname string, newname string) error {
+ n.createLinkCalled = true
+ return nil
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (n *MockInodeOperations) CreateDirectory(context.Context, *Inode, string, FilePermissions) error {
+ n.createDirectoryCalled = true
+ return nil
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (n *MockInodeOperations) Rename(ctx context.Context, inode *Inode, oldParent *Inode, oldName string, newParent *Inode, newName string, replacement bool) error {
+ n.renameCalled = true
+ return nil
+}
+
+// Check implements fs.InodeOperations.Check.
+func (n *MockInodeOperations) Check(ctx context.Context, inode *Inode, p PermMask) bool {
+ return ContextCanAccessFile(ctx, inode, p)
+}
+
+// Release implements fs.InodeOperations.Release.
+func (n *MockInodeOperations) Release(context.Context) {}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (n *MockInodeOperations) Truncate(ctx context.Context, inode *Inode, size int64) error {
+ return nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (n *MockInodeOperations) Allocate(ctx context.Context, inode *Inode, offset, length int64) error {
+ return nil
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (n *MockInodeOperations) Remove(context.Context, *Inode, string) error {
+ return nil
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (n *MockInodeOperations) RemoveDirectory(context.Context, *Inode, string) error {
+ return nil
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (n *MockInodeOperations) Getlink(context.Context, *Inode) (*Dirent, error) {
+ return nil, syserror.ENOLINK
+}
diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go
new file mode 100644
index 000000000..37bae6810
--- /dev/null
+++ b/pkg/sentry/fs/mount.go
@@ -0,0 +1,285 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "bytes"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+)
+
+// DirentOperations provides file systems with greater control over how long a
+// Dirent stays pinned in core. Implementations must not take Dirent.mu.
+type DirentOperations interface {
+ // Revalidate is called during lookup each time we encounter a Dirent
+ // in the cache. Implementations may update stale properties of the
+ // child Inode. If Revalidate returns true, then the entire Inode will
+ // be reloaded.
+ //
+ // Revalidate will never be called on an Inode that is mounted.
+ Revalidate(ctx context.Context, name string, parent, child *Inode) bool
+
+ // Keep returns true if the Dirent should be kept in memory for as long
+ // as possible beyond any active references.
+ Keep(dirent *Dirent) bool
+
+ // CacheReaddir returns true if directory entries returned by
+ // FileOperations.Readdir may be cached for future use.
+ //
+ // Postconditions: This method must always return the same value.
+ CacheReaddir() bool
+}
+
+// MountSourceOperations contains filesystem specific operations.
+type MountSourceOperations interface {
+ // DirentOperations provide optional extra management of Dirents.
+ DirentOperations
+
+ // Destroy destroys the MountSource.
+ Destroy()
+
+ // Below are MountSourceOperations that do not conform to Linux.
+
+ // ResetInodeMappings clears all mappings of Inodes before SaveInodeMapping
+ // is called.
+ ResetInodeMappings()
+
+ // SaveInodeMapping is called during saving to store, for each reachable
+ // Inode in the mounted filesystem, a mapping of Inode.StableAttr.InodeID
+ // to the Inode's path relative to its mount point. If an Inode is
+ // reachable at more than one path due to hard links, it is unspecified
+ // which path is mapped. Filesystems that do not use this information to
+ // restore inodes can make SaveInodeMapping a no-op.
+ SaveInodeMapping(inode *Inode, path string)
+}
+
+// InodeMappings defines a fmt.Stringer for a MountSource's Inode mappings.
+type InodeMappings map[uint64]string
+
+// String implements fmt.Stringer.String.
+func (i InodeMappings) String() string {
+ var mappingsBuf bytes.Buffer
+ mappingsBuf.WriteString("\n")
+ for ino, name := range i {
+ mappingsBuf.WriteString(fmt.Sprintf("\t%q\t\tinode number %d\n", name, ino))
+ }
+ return mappingsBuf.String()
+}
+
+// MountSource represents a source of file objects.
+//
+// MountSource corresponds to struct super_block in Linux.
+//
+// A mount source may represent a physical device (or a partition of a physical
+// device) or a virtual source of files such as procfs for a specific PID
+// namespace. There should be only one mount source per logical device, e.g.
+// there should be only one procfs mount source for a given PID namespace.
+//
+// A mount source represents files as inodes. Every inode belongs to exactly
+// one mount source. Each file object may only be represented using one inode
+// object in a sentry instance.
+//
+// TODO(b/63601033): Move Flags out of MountSource to Mount.
+//
+// +stateify savable
+type MountSource struct {
+ refs.AtomicRefCount
+
+ // MountSourceOperations defines filesystem specific behavior.
+ MountSourceOperations
+
+ // FilesystemType is the type of the filesystem backing this mount.
+ FilesystemType string
+
+ // Flags are the flags that this filesystem was mounted with.
+ Flags MountSourceFlags
+
+ // fscache keeps Dirents pinned beyond application references to them.
+ // It must be flushed before kernel.SaveTo.
+ fscache *DirentCache
+
+ // direntRefs is the sum of references on all Dirents in this MountSource.
+ //
+ // direntRefs is increased when a Dirent in MountSource is IncRef'd, and
+ // decreased when a Dirent in MountSource is DecRef'd.
+ //
+ // To cleanly unmount a MountSource, one must check that no direntRefs are
+ // held anymore. To check, one must hold root.parent.dirMu of the
+ // MountSource's root Dirent before reading direntRefs to prevent further
+ // walks to Dirents in this MountSource.
+ //
+ // direntRefs must be atomically changed.
+ direntRefs uint64
+}
+
+// DefaultDirentCacheSize is the maximum number of Dirents that the VFS can
+// hold an extra reference on.
+const DefaultDirentCacheSize uint64 = 1000
+
+// NewMountSource returns a new MountSource. Filesystem may be nil if there is no
+// filesystem backing the mount.
+func NewMountSource(ctx context.Context, mops MountSourceOperations, filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ fsType := "none"
+ if filesystem != nil {
+ fsType = filesystem.Name()
+ }
+ msrc := MountSource{
+ MountSourceOperations: mops,
+ Flags: flags,
+ FilesystemType: fsType,
+ fscache: NewDirentCache(DefaultDirentCacheSize),
+ }
+ msrc.EnableLeakCheck("fs.MountSource")
+ return &msrc
+}
+
+// DirentRefs returns the current mount direntRefs.
+func (msrc *MountSource) DirentRefs() uint64 {
+ return atomic.LoadUint64(&msrc.direntRefs)
+}
+
+// IncDirentRefs increases direntRefs.
+func (msrc *MountSource) IncDirentRefs() {
+ atomic.AddUint64(&msrc.direntRefs, 1)
+}
+
+// DecDirentRefs decrements direntRefs.
+func (msrc *MountSource) DecDirentRefs() {
+ if atomic.AddUint64(&msrc.direntRefs, ^uint64(0)) == ^uint64(0) {
+ panic("Decremented zero mount reference direntRefs")
+ }
+}
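DecDirentRefs relies on the standard unsigned-decrement idiom: adding ^uint64(0) (all bits set, the unsigned representation of -1) subtracts one, and the result equals ^uint64(0) exactly when the counter was already zero, i.e. on underflow. A minimal standalone sketch:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var refs uint64 = 1

	// atomic.AddUint64 with ^uint64(0) decrements an unsigned counter.
	// Wrapping to MaxUint64 (== ^uint64(0)) signals underflow.
	if atomic.AddUint64(&refs, ^uint64(0)) == ^uint64(0) {
		panic("decremented below zero")
	}
	fmt.Println(refs) // 0; a second decrement would trip the check above.
}
```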
+
+func (msrc *MountSource) destroy() {
+ if c := msrc.DirentRefs(); c != 0 {
+ panic(fmt.Sprintf("MountSource with non-zero direntRefs is being destroyed: %d", c))
+ }
+ msrc.MountSourceOperations.Destroy()
+}
+
+// DecRef drops a reference on the MountSource.
+func (msrc *MountSource) DecRef() {
+ msrc.DecRefWithDestructor(msrc.destroy)
+}
+
+// FlushDirentRefs drops all references held by the MountSource on Dirents.
+func (msrc *MountSource) FlushDirentRefs() {
+ msrc.fscache.Invalidate()
+}
+
+// SetDirentCacheMaxSize sets the maximum size of the dirent cache associated
+// with this mount source.
+func (msrc *MountSource) SetDirentCacheMaxSize(max uint64) {
+ msrc.fscache.setMaxSize(max)
+}
+
+// SetDirentCacheLimiter sets the limiter object for the dirent cache associated
+// with this mount source.
+func (msrc *MountSource) SetDirentCacheLimiter(l *DirentCacheLimiter) {
+ msrc.fscache.limit = l
+}
+
+// NewCachingMountSource returns a generic mount that will cache dirents
+// aggressively.
+func NewCachingMountSource(ctx context.Context, filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ return NewMountSource(ctx, &SimpleMountSourceOperations{
+ keep: true,
+ revalidate: false,
+ cacheReaddir: true,
+ }, filesystem, flags)
+}
+
+// NewNonCachingMountSource returns a generic mount that will never cache dirents.
+func NewNonCachingMountSource(ctx context.Context, filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ return NewMountSource(ctx, &SimpleMountSourceOperations{
+ keep: false,
+ revalidate: false,
+ cacheReaddir: false,
+ }, filesystem, flags)
+}
+
+// NewRevalidatingMountSource returns a generic mount that will cache dirents,
+// but will revalidate them on each lookup and always perform uncached readdir.
+func NewRevalidatingMountSource(ctx context.Context, filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ return NewMountSource(ctx, &SimpleMountSourceOperations{
+ keep: true,
+ revalidate: true,
+ cacheReaddir: false,
+ }, filesystem, flags)
+}
+
+// NewPseudoMountSource returns a "pseudo" mount source that is not backed by
+// an actual filesystem. It is always non-caching.
+func NewPseudoMountSource(ctx context.Context) *MountSource {
+ return NewMountSource(ctx, &SimpleMountSourceOperations{
+ keep: false,
+ revalidate: false,
+ cacheReaddir: false,
+ }, nil, MountSourceFlags{})
+}
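The four constructors above differ only in the (keep, revalidate, cacheReaddir) policy they hand to SimpleMountSourceOperations. A hedged sketch of how a caller might pick one; chooseMountSource and its dynamic parameter are illustrative only, not part of this package (NewNonCachingMountSource would cover a fourth, never-cache case for backed filesystems):

```go
// chooseMountSource is a sketch only: it shows when each factory applies.
func chooseMountSource(ctx context.Context, fsys Filesystem, flags MountSourceFlags, dynamic bool) *MountSource {
	switch {
	case fsys == nil:
		// No backing filesystem (e.g. anonymous objects): never cache.
		return NewPseudoMountSource(ctx)
	case dynamic:
		// Contents may change underneath us: keep dirents cached but
		// revalidate them on every lookup, and never cache readdir.
		return NewRevalidatingMountSource(ctx, fsys, flags)
	default:
		// Static contents: cache aggressively.
		return NewCachingMountSource(ctx, fsys, flags)
	}
}
```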
+
+// SimpleMountSourceOperations implements MountSourceOperations.
+//
+// +stateify savable
+type SimpleMountSourceOperations struct {
+ keep bool
+ revalidate bool
+ cacheReaddir bool
+}
+
+// Revalidate implements MountSourceOperations.Revalidate.
+func (smo *SimpleMountSourceOperations) Revalidate(context.Context, string, *Inode, *Inode) bool {
+ return smo.revalidate
+}
+
+// Keep implements MountSourceOperations.Keep.
+func (smo *SimpleMountSourceOperations) Keep(*Dirent) bool {
+ return smo.keep
+}
+
+// CacheReaddir implements MountSourceOperations.CacheReaddir.
+func (smo *SimpleMountSourceOperations) CacheReaddir() bool {
+ return smo.cacheReaddir
+}
+
+// ResetInodeMappings implements MountSourceOperations.ResetInodeMappings.
+func (*SimpleMountSourceOperations) ResetInodeMappings() {}
+
+// SaveInodeMapping implements MountSourceOperations.SaveInodeMapping.
+func (*SimpleMountSourceOperations) SaveInodeMapping(*Inode, string) {}
+
+// Destroy implements MountSourceOperations.Destroy.
+func (*SimpleMountSourceOperations) Destroy() {}
+
+// Info defines attributes of a filesystem.
+type Info struct {
+ // Type is the filesystem type magic value.
+ Type uint64
+
+ // TotalBlocks is the total data blocks in the filesystem.
+ TotalBlocks uint64
+
+ // FreeBlocks is the number of free blocks available.
+ FreeBlocks uint64
+
+ // TotalFiles is the total file nodes in the filesystem.
+ TotalFiles uint64
+
+ // FreeFiles is the number of free file nodes.
+ FreeFiles uint64
+}
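Info carries statfs-style counts rather than byte sizes, so derived metrics compose directly from the fields. A small self-contained sketch; the struct mirrors Info and usedBlocksPercent is a hypothetical helper:

```go
package main

import "fmt"

// info mirrors fs.Info for a standalone example.
type info struct {
	Type        uint64
	TotalBlocks uint64
	FreeBlocks  uint64
	TotalFiles  uint64
	FreeFiles   uint64
}

// usedBlocksPercent shows how the block counts compose.
func usedBlocksPercent(i info) float64 {
	if i.TotalBlocks == 0 {
		return 0
	}
	used := i.TotalBlocks - i.FreeBlocks
	return 100 * float64(used) / float64(i.TotalBlocks)
}

func main() {
	fmt.Printf("%.1f%% used\n", usedBlocksPercent(info{TotalBlocks: 1000, FreeBlocks: 250}))
	// Output: 75.0% used
}
```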
diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go
new file mode 100644
index 000000000..78e35b1e6
--- /dev/null
+++ b/pkg/sentry/fs/mount_overlay.go
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// overlayMountSourceOperations implements MountSourceOperations for an overlay
+// mount point. The upper filesystem determines the caching behavior of the
+// overlay.
+//
+// +stateify savable
+type overlayMountSourceOperations struct {
+ upper *MountSource
+ lower *MountSource
+}
+
+func newOverlayMountSource(ctx context.Context, upper, lower *MountSource, flags MountSourceFlags) *MountSource {
+ upper.IncRef()
+ lower.IncRef()
+ msrc := NewMountSource(ctx, &overlayMountSourceOperations{
+ upper: upper,
+ lower: lower,
+ }, &overlayFilesystem{}, flags)
+
+ // Use the smaller of the two cache sizes to keep resource usage within limits.
+ size := lower.fscache.maxSize
+ if size > upper.fscache.maxSize {
+ size = upper.fscache.maxSize
+ }
+ msrc.fscache.setMaxSize(size)
+
+ return msrc
+}
+
+// Revalidate implements MountSourceOperations.Revalidate for an overlay by
+// delegating to the upper filesystem's Revalidate method. We cannot reload
+// files from the lower filesystem, so we panic if the lower filesystem's
+// Revalidate method returns true.
+func (o *overlayMountSourceOperations) Revalidate(ctx context.Context, name string, parent, child *Inode) bool {
+ if child.overlay == nil {
+ panic("overlay cannot revalidate inode that is not an overlay")
+ }
+
+ // Revalidate is never called on a mount point, so parent and child
+ // must be from the same mount, and thus must both be overlay inodes.
+ if parent.overlay == nil {
+ panic("trying to revalidate an overlay inode but the parent is not an overlay")
+ }
+
+ // We can't revalidate from the lower filesystem.
+ if child.overlay.lower != nil && o.lower.Revalidate(ctx, name, parent.overlay.lower, child.overlay.lower) {
+ panic("an overlay cannot revalidate file objects from the lower fs")
+ }
+
+ var revalidate bool
+ child.overlay.copyMu.RLock()
+ if child.overlay.upper != nil {
+ // Does the upper require revalidation?
+ revalidate = o.upper.Revalidate(ctx, name, parent.overlay.upper, child.overlay.upper)
+ } else {
+ // Nothing to revalidate.
+ revalidate = false
+ }
+ child.overlay.copyMu.RUnlock()
+ return revalidate
+}
+
+// Keep implements MountSourceOperations by delegating to the upper
+// filesystem's Keep method.
+func (o *overlayMountSourceOperations) Keep(dirent *Dirent) bool {
+ return o.upper.Keep(dirent)
+}
+
+// CacheReaddir implements MountSourceOperations.CacheReaddir for an overlay by
+// performing the logical AND of the upper and lower filesystems' CacheReaddir
+// methods.
+//
+// N.B. This is fs-global instead of inode-specific because it must always
+// return the same value. If it was inode-specific, we couldn't guarantee that
+// property across copy up.
+func (o *overlayMountSourceOperations) CacheReaddir() bool {
+ return o.lower.CacheReaddir() && o.upper.CacheReaddir()
+}
+
+// ResetInodeMappings propagates the call to both upper and lower MountSource.
+func (o *overlayMountSourceOperations) ResetInodeMappings() {
+ o.upper.ResetInodeMappings()
+ o.lower.ResetInodeMappings()
+}
+
+// SaveInodeMapping propagates the call to both upper and lower MountSource.
+func (o *overlayMountSourceOperations) SaveInodeMapping(inode *Inode, path string) {
+ inode.overlay.copyMu.RLock()
+ defer inode.overlay.copyMu.RUnlock()
+ if inode.overlay.upper != nil {
+ o.upper.SaveInodeMapping(inode.overlay.upper, path)
+ }
+ if inode.overlay.lower != nil {
+ o.lower.SaveInodeMapping(inode.overlay.lower, path)
+ }
+}
+
+// Destroy drops references on the upper and lower MountSource.
+func (o *overlayMountSourceOperations) Destroy() {
+ o.upper.DecRef()
+ o.lower.DecRef()
+}
+
+// overlayFilesystem is the filesystem for overlay mounts.
+//
+// +stateify savable
+type overlayFilesystem struct{}
+
+// Name implements Filesystem.Name.
+func (ofs *overlayFilesystem) Name() string {
+ return "overlayfs"
+}
+
+// Flags implements Filesystem.Flags.
+func (ofs *overlayFilesystem) Flags() FilesystemFlags {
+ return 0
+}
+
+// AllowUserMount implements Filesystem.AllowUserMount.
+func (ofs *overlayFilesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList implements Filesystem.AllowUserList.
+func (*overlayFilesystem) AllowUserList() bool {
+ return true
+}
+
+// Mount implements Filesystem.Mount.
+func (ofs *overlayFilesystem) Mount(ctx context.Context, device string, flags MountSourceFlags, data string, _ interface{}) (*Inode, error) {
+ panic("overlayFilesystem.Mount should not be called!")
+}
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
new file mode 100644
index 000000000..a3d10770b
--- /dev/null
+++ b/pkg/sentry/fs/mount_test.go
@@ -0,0 +1,272 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+)
+
+// cacheReallyContains iterates through the dirent cache to determine whether
+// it contains the given dirent.
+func cacheReallyContains(cache *DirentCache, d *Dirent) bool {
+ for i := cache.list.Front(); i != nil; i = i.Next() {
+ if i == d {
+ return true
+ }
+ }
+ return false
+}
+
+func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
+ gotPaths := make(map[string]struct{}, len(got))
+ gotStr := make([]string, len(got))
+ for i, g := range got {
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef()
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
+ }
+ if len(got) != len(want) {
+ return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
+ }
+ for _, w := range want {
+ if _, ok := gotPaths[w]; !ok {
+ return fmt.Errorf("no mount with path %q found", w)
+ }
+ }
+ return nil
+}
+
+// TestMountSourceOnlyCachedOnce tests that a Dirent that is mounted over only ends
+// up in a single Dirent Cache. NOTE(b/63848693): Having a dirent in multiple
+// caches causes major consistency issues.
+func TestMountSourceOnlyCachedOnce(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ rootCache := NewDirentCache(100)
+ rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{
+ Type: Directory,
+ })
+ mm, err := NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ t.Fatalf("NewMountNamespace failed: %v", err)
+ }
+ rootDirent := mm.Root()
+ defer rootDirent.DecRef()
+
+ // Get a child of the root which we will mount over. Note that the
+ // MockInodeOperations causes Walk to always succeed.
+ child, err := rootDirent.Walk(ctx, rootDirent, "child")
+ if err != nil {
+ t.Fatalf("failed to walk to child dirent: %v", err)
+ }
+ child.maybeExtendReference() // Cache.
+
+ // Ensure that the root cache contains the child.
+ if !cacheReallyContains(rootCache, child) {
+ t.Errorf("wanted rootCache to contain child dirent, but it did not")
+ }
+
+ // Create a new cache and inode, and mount it over child.
+ submountCache := NewDirentCache(100)
+ submountInode := NewMockInode(ctx, NewMockMountSource(submountCache), StableAttr{
+ Type: Directory,
+ })
+ if err := mm.Mount(ctx, child, submountInode); err != nil {
+ t.Fatalf("failed to mount over child: %v", err)
+ }
+
+ // Walk to the child again.
+ child2, err := rootDirent.Walk(ctx, rootDirent, "child")
+ if err != nil {
+ t.Fatalf("failed to walk to child dirent: %v", err)
+ }
+
+ // Should have a different Dirent than before.
+ if child == child2 {
+ t.Fatalf("expected %v not equal to %v, but they are the same", child, child2)
+ }
+
+ // Neither of the caches should contain the child.
+ if cacheReallyContains(rootCache, child) {
+ t.Errorf("wanted rootCache not to contain child dirent, but it did")
+ }
+ if cacheReallyContains(submountCache, child) {
+ t.Errorf("wanted submountCache not to contain child dirent, but it did")
+ }
+}
+
+func TestAllMountsUnder(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ rootCache := NewDirentCache(100)
+ rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{
+ Type: Directory,
+ })
+ mm, err := NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ t.Fatalf("NewMountNamespace failed: %v", err)
+ }
+ rootDirent := mm.Root()
+ defer rootDirent.DecRef()
+
+ // Add mounts at the following paths:
+ paths := []string{
+ "/foo",
+ "/foo/bar",
+ "/foo/bar/baz",
+ "/foo/qux",
+ "/waldo",
+ }
+
+ var maxTraversals uint
+ for _, p := range paths {
+ maxTraversals = 0
+ d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals)
+ if err != nil {
+ t.Fatalf("could not find path %q in mount manager: %v", p, err)
+ }
+
+ submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{
+ Type: Directory,
+ })
+ if err := mm.Mount(ctx, d, submountInode); err != nil {
+ t.Fatalf("could not mount at %q: %v", p, err)
+ }
+ d.DecRef()
+ }
+
+ // mm root should contain all submounts, including the root mount itself.
+ rootMnt := mm.FindMount(rootDirent)
+ submounts := mm.AllMountsUnder(rootMnt)
+ allPaths := append(paths, "/")
+ if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil {
+ t.Error(err)
+ }
+
+ // Each mount should have a unique ID.
+ foundIDs := make(map[uint64]struct{})
+ for _, m := range submounts {
+ if _, ok := foundIDs[m.ID]; ok {
+ t.Errorf("got multiple mounts with id %d", m.ID)
+ }
+ foundIDs[m.ID] = struct{}{}
+ }
+
+ // Root mount should have no parent.
+ if p := rootMnt.ParentID; p != invalidMountID {
+ t.Errorf("root.Parent got %v wanted nil", p)
+ }
+
+ // Check that "foo" mount has 3 children.
+ maxTraversals = 0
+ d, err := mm.FindLink(ctx, rootDirent, nil, "/foo", &maxTraversals)
+ if err != nil {
+ t.Fatalf("could not find path %q in mount manager: %v", "/foo", err)
+ }
+ defer d.DecRef()
+ submounts = mm.AllMountsUnder(mm.FindMount(d))
+ if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil {
+ t.Error(err)
+ }
+
+ // "waldo" mount should have no children.
+ maxTraversals = 0
+ waldo, err := mm.FindLink(ctx, rootDirent, nil, "/waldo", &maxTraversals)
+ if err != nil {
+ t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err)
+ }
+ defer waldo.DecRef()
+ submounts = mm.AllMountsUnder(mm.FindMount(waldo))
+ if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestUnmount(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ rootCache := NewDirentCache(100)
+ rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{
+ Type: Directory,
+ })
+ mm, err := NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ t.Fatalf("NewMountNamespace failed: %v", err)
+ }
+ rootDirent := mm.Root()
+ defer rootDirent.DecRef()
+
+ // Add mounts at the following paths:
+ paths := []string{
+ "/foo",
+ "/foo/bar",
+ "/foo/bar/goo",
+ "/foo/bar/goo/abc",
+ "/foo/abc",
+ "/foo/def",
+ "/waldo",
+ "/wally",
+ }
+
+ var maxTraversals uint
+ for _, p := range paths {
+ maxTraversals = 0
+ d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals)
+ if err != nil {
+ t.Fatalf("could not find path %q in mount manager: %v", p, err)
+ }
+
+ submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{
+ Type: Directory,
+ })
+ if err := mm.Mount(ctx, d, submountInode); err != nil {
+ t.Fatalf("could not mount at %q: %v", p, err)
+ }
+ d.DecRef()
+ }
+
+ allPaths := make([]string, len(paths)+1)
+ allPaths[0] = "/"
+ copy(allPaths[1:], paths)
+
+ rootMnt := mm.FindMount(rootDirent)
+ for i := len(paths) - 1; i >= 0; i-- {
+ maxTraversals = 0
+ p := paths[i]
+ d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals)
+ if err != nil {
+ t.Fatalf("could not find path %q in mount manager: %v", p, err)
+ }
+
+ if err := mm.Unmount(ctx, d, false); err != nil {
+ t.Fatalf("could not unmount at %q: %v", p, err)
+ }
+ d.DecRef()
+
+ // Remove the path that has been unmounted, then check that the remaining
+ // mounts are still there.
+ allPaths = allPaths[:len(allPaths)-1]
+ submounts := mm.AllMountsUnder(rootMnt)
+ if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil {
+ t.Error(err)
+ }
+ }
+}
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
new file mode 100644
index 000000000..3f2bd0e87
--- /dev/null
+++ b/pkg/sentry/fs/mounts.go
@@ -0,0 +1,623 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "math"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// DefaultTraversalLimit provides a sensible default traversal limit that may
+// be passed to FindInode and FindLink. Individual syscall implementations may
+// need different limits, but for internal functions this default is sensible.
+const DefaultTraversalLimit = 10
+
+const invalidMountID = math.MaxUint64
+
+// Mount represents a mount in the file system. It holds the root dirent for the
+// mount. It also points back to the dirent or mount that it was mounted over,
+// so that it can be restored when unmounted. The chained mount can be either:
+// - Mount: when it's mounted on top of another mount point.
+// - Dirent: when it's mounted on top of a dirent. In this case the mount is
+// called an "undo" mount and only 'root' is set. All other fields are
+// either invalid or nil.
+//
+// +stateify savable
+type Mount struct {
+ // ID is a unique id for this mount. It may be invalidMountID if this is
+ // used to cache a dirent that was mounted over.
+ ID uint64
+
+ // ParentID is the parent's mount unique id. It may be invalidMountID if this
+ // is the root mount or if this is used to cache a dirent that was mounted
+ // over.
+ ParentID uint64
+
+ // root is the root Dirent of this mount. A reference on this Dirent must be
+ // held through the lifetime of the Mount which contains it.
+ root *Dirent
+
+ // previous is the existing dirent or mount that this object was mounted over.
+ // It's nil for the root mount and for the last entry in the chain (always an
+ // "undo" mount).
+ previous *Mount
+}
+
+// newMount creates a new mount, taking a reference on 'root'. Caller must
+// release the reference when it's done with the mount.
+func newMount(id, pid uint64, root *Dirent) *Mount {
+ root.IncRef()
+ return &Mount{
+ ID: id,
+ ParentID: pid,
+ root: root,
+ }
+}
+
+// newRootMount creates a new root mount (no parent), taking a reference on
+// 'root'. Caller must release the reference when it's done with the mount.
+func newRootMount(id uint64, root *Dirent) *Mount {
+ root.IncRef()
+ return &Mount{
+ ID: id,
+ ParentID: invalidMountID,
+ root: root,
+ }
+}
+
+// newUndoMount creates a new undo mount, taking a reference on 'd'. Caller must
+// release the reference when it's done with the mount.
+func newUndoMount(d *Dirent) *Mount {
+ d.IncRef()
+ return &Mount{
+ ID: invalidMountID,
+ ParentID: invalidMountID,
+ root: d,
+ }
+}
+
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been freed. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
+func (m *Mount) Root() *Dirent {
+ if !m.root.TryIncRef() {
+ return nil
+ }
+ return m.root
+}
+
+// IsRoot returns true if the mount has no parent.
+func (m *Mount) IsRoot() bool {
+ return !m.IsUndo() && m.ParentID == invalidMountID
+}
+
+// IsUndo returns true if 'm' is an undo mount that should be used to restore
+// the original dirent during unmount only and it's not a valid mount.
+func (m *Mount) IsUndo() bool {
+ if m.ID == invalidMountID {
+ if m.ParentID != invalidMountID {
+ panic(fmt.Sprintf("Undo mount with valid parentID: %+v", m))
+ }
+ return true
+ }
+ return false
+}
+
+// MountNamespace defines a VFS root. It contains a collection of Mounts that are
+// mounted inside the Dirent tree rooted at the Root Dirent. It provides
+// methods for traversing the Dirent, and for mounting/unmounting in the tree.
+//
+// Note that this does not correspond to a "mount namespace" in Linux. It
+// is more like a unique VFS instance.
+//
+// It's possible for different processes to have different MountNamespaces. In
+// this case, the file systems exposed to the processes are completely
+// distinct.
+//
+// +stateify savable
+type MountNamespace struct {
+ refs.AtomicRefCount
+
+ // userns is the user namespace associated with this mount namespace.
+ //
+ // All privileged operations on this mount namespace must have
+ // appropriate capabilities in this userns.
+ //
+ // userns is immutable.
+ userns *auth.UserNamespace
+
+ // root is the root directory.
+ root *Dirent
+
+ // mu protects mounts and mountID counter.
+ mu sync.Mutex `state:"nosave"`
+
+ // mounts is a map of mounted Dirent -> Mount object. There are three
+ // possible cases:
+ // - Dirent is mounted over a mount point: the stored Mount object will be
+ // the Mount for that mount point.
+ // - Dirent is mounted over a regular (non-mount point) Dirent: the stored
+ // Mount object will be an "undo" mount containing the mounted-over
+ // Dirent.
+ // - Dirent is the root mount: the stored Mount object will be a root mount
+ // containing the Dirent itself.
+ mounts map[*Dirent]*Mount
+
+ // mountID is the next mount id to assign.
+ mountID uint64
+}
+
+// NewMountNamespace returns a new MountNamespace with the provided inode as the
+// root. A root must always be provided.
+func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) {
+ // Set the root dirent and id on the root mount. The reference returned from
+ // NewDirent will be donated to the MountNamespace constructed below.
+ d := NewDirent(ctx, root, "/")
+
+ mnts := map[*Dirent]*Mount{
+ d: newRootMount(1, d),
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ mns := MountNamespace{
+ userns: creds.UserNamespace,
+ root: d,
+ mounts: mnts,
+ mountID: 2,
+ }
+ mns.EnableLeakCheck("fs.MountNamespace")
+ return &mns, nil
+}
+
+// UserNamespace returns the user namespace associated with this mount manager.
+func (mns *MountNamespace) UserNamespace() *auth.UserNamespace {
+ return mns.userns
+}
+
+// Root returns the MountNamespace's root Dirent and increments its reference
+// count. The caller must call DecRef when finished.
+func (mns *MountNamespace) Root() *Dirent {
+ mns.root.IncRef()
+ return mns.root
+}
+
+// FlushMountSourceRefs flushes extra references held by MountSources for all active mount points;
+// see fs/mount.go:MountSource.FlushDirentRefs.
+func (mns *MountNamespace) FlushMountSourceRefs() {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+ mns.flushMountSourceRefsLocked()
+}
+
+func (mns *MountNamespace) flushMountSourceRefsLocked() {
+ // Flush mounts' MountSource references.
+ for _, mp := range mns.mounts {
+ for ; mp != nil; mp = mp.previous {
+ mp.root.Inode.MountSource.FlushDirentRefs()
+ }
+ }
+
+ if mns.root == nil {
+ // No root? This MountSource must have already been destroyed.
+ // This can happen when a Save is triggered while a process is
+ // exiting. There is nothing to flush.
+ return
+ }
+
+ // Flush root's MountSource references.
+ mns.root.Inode.MountSource.FlushDirentRefs()
+}
+
+// destroy drops the Dirent references held by the root and by all mounts, and
+// closes any original nodes.
+//
+// After destroy is called, the MountNamespace may continue to be referenced (for
+// example via /proc/mounts), but should free all resources and shouldn't have
+// Find* methods called.
+func (mns *MountNamespace) destroy() {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ // Flush all mounts' MountSource references to Dirents. This allows for mount
+ // points to be torn down since there should be no remaining references after
+ // this and DecRef below.
+ mns.flushMountSourceRefsLocked()
+
+ // Teardown mounts.
+ for _, mp := range mns.mounts {
+ // Drop the mount reference on all mounted dirents.
+ for ; mp != nil; mp = mp.previous {
+ mp.root.DecRef()
+ }
+ }
+ mns.mounts = nil
+
+ // Drop reference on the root.
+ mns.root.DecRef()
+
+ // Ensure that root cannot be accessed via this MountNamespace any
+ // more.
+ mns.root = nil
+
+ // Wait for asynchronous work (queued by dropping Dirent references
+ // above) to complete before destroying this MountNamespace.
+ AsyncBarrier()
+}
+
+// DecRef implements RefCounter.DecRef with destructor mns.destroy.
+func (mns *MountNamespace) DecRef() {
+ mns.DecRefWithDestructor(mns.destroy)
+}
+
+// withMountLocked prevents further walks to `node`, because `node` is about to
+// be a mount point.
+func (mns *MountNamespace) withMountLocked(node *Dirent, fn func() error) error {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ renameMu.Lock()
+ defer renameMu.Unlock()
+
+ // Linux allows mounting over the root (?). It comes with a strange set
+ // of semantics. We'll just not do this for now.
+ if node.parent == nil {
+ return syserror.EBUSY
+ }
+
+ // For both mount and unmount, we take this lock so we can swap out the
+ // appropriate child in parent.children.
+ //
+ // For unmount, this also ensures that if `node` is a mount point, the
+ // underlying mount's MountSource.direntRefs cannot increase by preventing
+ // walks to node.
+ node.parent.dirMu.Lock()
+ defer node.parent.dirMu.Unlock()
+
+ node.parent.mu.Lock()
+ defer node.parent.mu.Unlock()
+
+ // We need not take node.dirMu since we have parent.dirMu.
+
+ // We need to take node.mu, so that we can check for deletion.
+ node.mu.Lock()
+ defer node.mu.Unlock()
+
+ return fn()
+}
+
+// Mount mounts `inode` over the subtree at `mountPoint`.
+func (mns *MountNamespace) Mount(ctx context.Context, mountPoint *Dirent, inode *Inode) error {
+ return mns.withMountLocked(mountPoint, func() error {
+ replacement, err := mountPoint.mount(ctx, inode)
+ if err != nil {
+ return err
+ }
+ defer replacement.DecRef()
+
+ // Set the mount's root dirent and id.
+ parentMnt := mns.findMountLocked(mountPoint)
+ childMnt := newMount(mns.mountID, parentMnt.ID, replacement)
+ mns.mountID++
+
+ // Drop mountPoint from its dirent cache.
+ mountPoint.dropExtendedReference()
+
+ // If mountPoint is already a mount, push mountPoint on the stack so it can
+ // be recovered on unmount.
+ if prev := mns.mounts[mountPoint]; prev != nil {
+ childMnt.previous = prev
+ mns.mounts[replacement] = childMnt
+ delete(mns.mounts, mountPoint)
+ return nil
+ }
+
+ // Was not already mounted, just add another mount point.
+ childMnt.previous = newUndoMount(mountPoint)
+ mns.mounts[replacement] = childMnt
+ return nil
+ })
+}
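The undo-mount chain that Mount builds is easiest to see in a stripped-down model: each mount remembers what it replaced via previous, and the bottom of every chain is an undo entry holding the original dirent. A sketch with illustrative names only:

```go
package main

import "fmt"

// mount models only the chaining behavior of fs.Mount.
type mount struct {
	name     string
	previous *mount
}

func main() {
	// First mount over /mnt: the original dirent is parked in an undo entry.
	undo := &mount{name: "undo (original /mnt dirent)"}
	first := &mount{name: "mount#2 on /mnt", previous: undo}

	// Mounting again over the same point pushes the existing mount.
	second := &mount{name: "mount#3 on /mnt", previous: first}

	// Unmount pops the chain back toward the undo entry.
	for m := second; m != nil; m = m.previous {
		fmt.Println(m.name)
	}
}
```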
+
+// Unmount ensures no references to the MountSource remain and removes `node` from
+// this subtree. The subtree formerly mounted in `node`'s place will be
+// restored. node's MountSource will be destroyed as soon as the last reference to
+// `node` is dropped, as no references to Dirents within will remain.
+//
+// If detachOnly is set, Unmount merely removes `node` from the subtree, but
+// allows existing references to the MountSource to remain. E.g. if an open file still
+// refers to Dirents in MountSource, the Unmount will succeed anyway and MountSource will
+// be destroyed at a later time when all references to Dirents within are
+// dropped.
+//
+// The caller must hold a reference to node from walking to it.
+func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly bool) error {
+ // This takes locks to prevent further walks to Dirents in this mount
+ // under the assumption that `node` is the root of the mount.
+ return mns.withMountLocked(node, func() error {
+ orig, ok := mns.mounts[node]
+ if !ok {
+ // node is not a mount point.
+ return syserror.EINVAL
+ }
+
+ if orig.previous == nil {
+ panic("cannot unmount initial dirent")
+ }
+
+ m := node.Inode.MountSource
+ if !detachOnly {
+ // Flush all references on the mounted node.
+ m.FlushDirentRefs()
+
+ // At this point, exactly two references must be held
+ // to mount: one mount reference on node, and one due
+ // to walking to node.
+ //
+ // We must also be guaranteed that no more references
+ // can be taken on mount. This is why withMountLocked
+ // must be held at this point to prevent any walks to
+ // and from node.
+ if refs := m.DirentRefs(); refs < 2 {
+ panic(fmt.Sprintf("have %d refs on unmount, expect 2 or more", refs))
+ } else if refs != 2 {
+ return syserror.EBUSY
+ }
+ }
+
+ prev := orig.previous
+ if err := node.unmount(ctx, prev.root); err != nil {
+ return err
+ }
+
+ if prev.previous == nil {
+ if !prev.IsUndo() {
+ panic(fmt.Sprintf("Last mount in the chain must be a undo mount: %+v", prev))
+ }
+ // Drop mount reference taken at the end of MountNamespace.Mount.
+ prev.root.DecRef()
+ } else {
+ mns.mounts[prev.root] = prev
+ }
+ delete(mns.mounts, node)
+
+ return nil
+ })
+}
+
+// FindMount returns the mount that 'd' belongs to. It walks up the dirent tree
+// until a mount is found. It may return nil if no mount was found.
+func (mns *MountNamespace) FindMount(d *Dirent) *Mount {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+ renameMu.Lock()
+ defer renameMu.Unlock()
+
+ return mns.findMountLocked(d)
+}
+
+func (mns *MountNamespace) findMountLocked(d *Dirent) *Mount {
+ for {
+ if mnt := mns.mounts[d]; mnt != nil {
+ return mnt
+ }
+ if d.parent == nil {
+ return nil
+ }
+ d = d.parent
+ }
+}
+
+// AllMountsUnder returns a slice of all mounts under the parent, including
+// itself.
+func (mns *MountNamespace) AllMountsUnder(parent *Mount) []*Mount {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ var rv []*Mount
+ for _, mp := range mns.mounts {
+ if !mp.IsUndo() && mp.root.descendantOf(parent.root) {
+ rv = append(rv, mp)
+ }
+ }
+ return rv
+}
+
+// FindLink returns a Dirent from a given node, which may be a symlink.
+//
+// The root argument is treated as the root directory, and FindLink will not
+// return anything above that. The wd dirent provides the starting directory,
+// and may be nil which indicates the root should be used. You must call DecRef
+// on the resulting Dirent when you are no longer using the object.
+//
+// If wd is nil, then the root will be used as the working directory. If the
+// path is absolute, this has no functional impact.
+//
+// Precondition: root must be non-nil.
+// Precondition: the path must be non-empty.
+func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path string, remainingTraversals *uint) (*Dirent, error) {
+ if root == nil {
+ panic("MountNamespace.FindLink: root must not be nil")
+ }
+ if len(path) == 0 {
+ panic("MountNamespace.FindLink: path is empty")
+ }
+
+ // Split the path.
+ first, remainder := SplitFirst(path)
+
+ // Where does this walk originate?
+ current := wd
+ if current == nil {
+ current = root
+ }
+ for first == "/" {
+ // Special case: it's possible that we have nothing to walk at
+ // all. This is necessary since we're resplitting the path.
+ if remainder == "" {
+ root.IncRef()
+ return root, nil
+ }
+
+ // Start at the root and advance the path component so that the
+ // walk below can proceed. Note at this point, it handles the
+ // no-op walk case perfectly fine.
+ current = root
+ first, remainder = SplitFirst(remainder)
+ }
+
+ current.IncRef() // Transferred during walk.
+
+ for {
+ // Check that the file is a directory and that we have
+ // permissions to walk.
+ //
+ // Note that we elide this check for the root directory as an
+ // optimization; a non-executable root may still be walked. A
+ // non-directory root is hopeless.
+ if current != root {
+ if !IsDir(current.Inode.StableAttr) {
+ current.DecRef() // Drop reference from above.
+ return nil, syserror.ENOTDIR
+ }
+ if err := current.Inode.CheckPermission(ctx, PermMask{Execute: true}); err != nil {
+ current.DecRef() // Drop reference from above.
+ return nil, err
+ }
+ }
+
+ // Move to the next level.
+ next, err := current.Walk(ctx, root, first)
+ if err != nil {
+ // Allow failed walks to cache the dirent, because no
+ // children will acquire a reference at the end.
+ current.maybeExtendReference()
+ current.DecRef()
+ return nil, err
+ }
+
+ // Drop old reference.
+ current.DecRef()
+
+ if remainder != "" {
+ // Ensure it's resolved, unless it's the last level.
+ //
+ // See resolve for reference semantics; on err next
+ // will have one dropped.
+ current, err = mns.resolve(ctx, root, next, remainingTraversals)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ // Allow the file system to take an extra reference on the
+ // found child. This will hold a reference on the containing
+ // directory, so the whole tree will be implicitly cached.
+ next.maybeExtendReference()
+ return next, nil
+ }
+
+ // Move to the next element.
+ first, remainder = SplitFirst(remainder)
+ }
+}
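A hedged sketch of the caller-side contract of FindLink/FindInode: the traversal budget is passed by pointer and decremented across symlink resolution, and the returned Dirent carries a reference the caller must drop. lookupHosts and its path are hypothetical:

```go
// lookupHosts is a sketch only; it is not part of this package.
func lookupHosts(ctx context.Context, mns *MountNamespace, root *Dirent) error {
	remainingTraversals := uint(DefaultTraversalLimit)
	d, err := mns.FindInode(ctx, root, nil /* wd defaults to root */, "/etc/hosts", &remainingTraversals)
	if err != nil {
		return err // e.g. ELOOP if the symlink budget was exhausted.
	}
	defer d.DecRef() // FindInode returns a reference the caller must drop.
	return nil
}
```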
+
+// FindInode is identical to FindLink except the return value is resolved.
+//
+//go:nosplit
+func (mns *MountNamespace) FindInode(ctx context.Context, root, wd *Dirent, path string, remainingTraversals *uint) (*Dirent, error) {
+ d, err := mns.FindLink(ctx, root, wd, path, remainingTraversals)
+ if err != nil {
+ return nil, err
+ }
+
+ // See resolve for reference semantics; on err d will have the
+ // reference dropped.
+ return mns.resolve(ctx, root, d, remainingTraversals)
+}
+
+// resolve resolves the given link.
+//
+// If successful, a reference is dropped on node and one is acquired on the
+// caller's behalf for the returned dirent.
+//
+// If not successful, a reference is _also_ dropped on the node and an error
+// returned. This is for convenience in using resolve directly as a return
+// value.
+func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, remainingTraversals *uint) (*Dirent, error) {
+ // Resolve the path.
+ target, err := node.Inode.Getlink(ctx)
+
+ switch err {
+ case nil:
+ // Make sure we didn't exhaust the traversal budget.
+ if *remainingTraversals == 0 {
+ target.DecRef()
+ return nil, syscall.ELOOP
+ }
+
+ node.DecRef() // Drop the original reference.
+ return target, nil
+
+ case syscall.ENOLINK:
+ // Not a symlink.
+ return node, nil
+
+ case ErrResolveViaReadlink:
+ defer node.DecRef() // See above.
+
+ // First, check if we should traverse.
+ if *remainingTraversals == 0 {
+ return nil, syscall.ELOOP
+ }
+
+ // Read the target path.
+ targetPath, err := node.Inode.Readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ // Find the node; we resolve relative to the current symlink's parent.
+ renameMu.RLock()
+ parent := node.parent
+ renameMu.RUnlock()
+ *remainingTraversals--
+ d, err := mns.FindInode(ctx, root, parent, targetPath, remainingTraversals)
+ if err != nil {
+ return nil, err
+ }
+
+ return d, nil
+
+ default:
+ node.DecRef() // Drop for err; see above.
+
+ // Propagate the error.
+ return nil, err
+ }
+}
+
+// SyncAll calls Dirent.SyncAll on the root.
+func (mns *MountNamespace) SyncAll(ctx context.Context) {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+ mns.root.SyncAll(ctx)
+}
diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go
new file mode 100644
index 000000000..a69b41468
--- /dev/null
+++ b/pkg/sentry/fs/mounts_test.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+)
+
+// Creates a new MountNamespace with filesystem:
+// / (root dir)
+// |-foo (dir)
+// |-bar (file)
+func createMountNamespace(ctx context.Context) (*fs.MountNamespace, error) {
+ perms := fs.FilePermsFromMode(0777)
+ m := fs.NewPseudoMountSource(ctx)
+
+ barFile := fsutil.NewSimpleFileInode(ctx, fs.RootOwner, perms, 0)
+ fooDir := ramfs.NewDir(ctx, map[string]*fs.Inode{
+ "bar": fs.NewInode(ctx, barFile, m, fs.StableAttr{Type: fs.RegularFile}),
+ }, fs.RootOwner, perms)
+ rootDir := ramfs.NewDir(ctx, map[string]*fs.Inode{
+ "foo": fs.NewInode(ctx, fooDir, m, fs.StableAttr{Type: fs.Directory}),
+ }, fs.RootOwner, perms)
+
+ return fs.NewMountNamespace(ctx, fs.NewInode(ctx, rootDir, m, fs.StableAttr{Type: fs.Directory}))
+}
+
+func TestFindLink(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm, err := createMountNamespace(ctx)
+ if err != nil {
+ t.Fatalf("createMountNamespace failed: %v", err)
+ }
+
+ root := mm.Root()
+ defer root.DecRef()
+ foo, err := root.Walk(ctx, root, "foo")
+ if err != nil {
+ t.Fatalf("Error walking to foo: %v", err)
+ }
+
+ // Positive cases.
+ for _, tc := range []struct {
+ findPath string
+ wd *fs.Dirent
+ wantPath string
+ }{
+ {".", root, "/"},
+ {".", foo, "/foo"},
+ {"..", foo, "/"},
+ {"../../..", foo, "/"},
+ {"///foo", foo, "/foo"},
+ {"/foo", foo, "/foo"},
+ {"/foo/bar", foo, "/foo/bar"},
+ {"/foo/.///./bar", foo, "/foo/bar"},
+ {"/foo///bar", foo, "/foo/bar"},
+ {"/foo/../foo/bar", foo, "/foo/bar"},
+ {"foo/bar", root, "/foo/bar"},
+ {"foo////bar", root, "/foo/bar"},
+ {"bar", foo, "/foo/bar"},
+ } {
+ wdPath, _ := tc.wd.FullName(root)
+ maxTraversals := uint(0)
+ if d, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err != nil {
+ t.Errorf("FindLink(%q, wd=%q) failed: %v", tc.findPath, wdPath, err)
+ } else if got, _ := d.FullName(root); got != tc.wantPath {
+ t.Errorf("FindLink(%q, wd=%q) got dirent %q, want %q", tc.findPath, wdPath, got, tc.wantPath)
+ }
+ }
+
+ // Negative cases.
+ for _, tc := range []struct {
+ findPath string
+ wd *fs.Dirent
+ }{
+ {"bar", root},
+ {"/bar", root},
+ {"/foo/../../bar", root},
+ {"foo", foo},
+ } {
+ wdPath, _ := tc.wd.FullName(root)
+ maxTraversals := uint(0)
+ if _, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err == nil {
+ t.Errorf("FindLink(%q, wd=%q) did not return error", tc.findPath, wdPath)
+ }
+ }
+}
diff --git a/pkg/sentry/fs/offset.go b/pkg/sentry/fs/offset.go
new file mode 100644
index 000000000..53b5df175
--- /dev/null
+++ b/pkg/sentry/fs/offset.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// OffsetPageEnd returns the file offset rounded up to the nearest
+// page boundary. OffsetPageEnd panics if rounding up causes overflow,
+// which shouldn't be possible given that offset is an int64.
+func OffsetPageEnd(offset int64) uint64 {
+ end, ok := usermem.Addr(offset).RoundUp()
+ if !ok {
+ panic("impossible overflow")
+ }
+ return uint64(end)
+}
+
+// ReadEndOffset returns an exclusive end offset for a read operation
+// so that the read does not overflow an int64 nor size.
+//
+// Parameters:
+// - offset: the starting offset of the read.
+// - length: the number of bytes to read.
+// - size: the size of the file.
+//
+// Postconditions: The returned offset is >= offset.
+func ReadEndOffset(offset int64, length int64, size int64) int64 {
+ if offset >= size {
+ return offset
+ }
+ end := offset + length
+ // Don't overflow.
+ if end < offset || end > size {
+ end = size
+ }
+ return end
+}
+
+// WriteEndOffset returns an exclusive end offset for a write operation
+// so that the write does not overflow an int64.
+//
+// Parameters:
+// - offset: the starting offset of the write.
+// - length: the number of bytes to write.
+//
+// Postconditions: The returned offset is >= offset.
+func WriteEndOffset(offset int64, length int64) int64 {
+ return ReadEndOffset(offset, length, math.MaxInt64)
+}
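The clamping in ReadEndOffset is easiest to verify with concrete values. A standalone mirror of the function above, with worked cases including the int64-overflow branch:

```go
package main

import (
	"fmt"
	"math"
)

// readEndOffset mirrors fs.ReadEndOffset for a standalone demonstration.
func readEndOffset(offset, length, size int64) int64 {
	if offset >= size {
		return offset
	}
	end := offset + length
	if end < offset || end > size {
		end = size
	}
	return end
}

func main() {
	fmt.Println(readEndOffset(10, 20, 100))  // 30: the read fits in the file
	fmt.Println(readEndOffset(90, 20, 100))  // 100: a read past EOF is clamped to size
	fmt.Println(readEndOffset(100, 20, 100)) // 100: start at EOF, nothing to read

	const max = math.MaxInt64
	fmt.Println(readEndOffset(max-10, 20, max) == max) // true: overflow clamps to size
}
```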
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
new file mode 100644
index 000000000..a8ae7d81d
--- /dev/null
+++ b/pkg/sentry/fs/overlay.go
@@ -0,0 +1,320 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// The virtual filesystem implements an overlay configuration. For a high-level
+// description, see README.md.
+//
+// Note on whiteouts:
+//
+// This implementation does not use the "Docker-style" whiteouts (symlinks with
+// ".wh." prefix). Instead upper filesystem directories support a set of extended
+// attributes to encode whiteouts: "trusted.overlay.whiteout.<filename>". This
+// gives flexibility to persist whiteouts independently of the filesystem layout
+// while additionally preventing name conflicts with files prefixed with ".wh.".
+//
+// Known deficiencies:
+//
+// - The device number of two files under the same overlay mount point may be
+// different. This can happen if a file is found in the lower filesystem (takes
+// the lower filesystem device) and another file is created in the upper
+// filesystem (takes the upper filesystem device). This may appear odd but
+// should not break applications.
+//
+// - Registered events on files (i.e. for notification of read/write readiness)
+// are not copied across copy up. This is fine in the common case of files that
+// do not block. For files that do block, like pipes and sockets, copy up is not
+// supported.
+//
+// - Hardlinks in a lower filesystem are broken by copy up. For this reason, no
+// attempt is made to preserve link count across copy up.
+//
+// - The maximum length of an extended attribute name is the same as the maximum
+// length of a file name in Linux (XATTR_NAME_MAX == NAME_MAX). This means that
+// whiteout attributes, if set directly on the host, are limited additionally by
+// the extra whiteout prefix length (file names must be strictly shorter than
+// NAME_MAX). This is not a problem for in-memory filesystems which don't enforce
+// XATTR_NAME_MAX.
+
+const (
+ // XattrOverlayPrefix is the prefix for extended attributes that affect
+ // the behavior of an overlay.
+ XattrOverlayPrefix = "trusted.overlay."
+
+ // XattrOverlayWhiteoutPrefix is the prefix for extended attributes
+ // that indicate that a whiteout exists.
+ XattrOverlayWhiteoutPrefix = XattrOverlayPrefix + "whiteout."
+)
+
+// XattrOverlayWhiteout returns an extended attribute that indicates a
+// whiteout exists for name. It is supported by directories that wish to
+// mask the existence of name.
+func XattrOverlayWhiteout(name string) string {
+ return XattrOverlayWhiteoutPrefix + name
+}
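Concretely, masking a lower-layer file named "passwd" stores a single extended attribute on the upper directory. A tiny standalone demonstration of the name construction; the constant mirrors XattrOverlayWhiteoutPrefix above:

```go
package main

import "fmt"

const xattrOverlayWhiteoutPrefix = "trusted.overlay.whiteout."

func main() {
	// The attribute the overlay would set on the upper directory:
	fmt.Println(xattrOverlayWhiteoutPrefix + "passwd")
	// Output: trusted.overlay.whiteout.passwd
}
```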
+
+// isXattrOverlay returns whether the given extended attribute configures the
+// overlay.
+func isXattrOverlay(name string) bool {
+ return strings.HasPrefix(name, XattrOverlayPrefix)
+}
+
+// NewOverlayRoot produces the root of an overlay.
+//
+// Preconditions:
+//
+// - upper and lower must be non-nil.
+// - upper must not be an overlay.
+// - lower should not expose character devices, pipes, or sockets, because
+// copying up these types of files is not supported.
+// - lower must not require that file objects be revalidated.
+// - lower must not have dynamic file/directory content.
+func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags MountSourceFlags) (*Inode, error) {
+ if !IsDir(upper.StableAttr) {
+ return nil, fmt.Errorf("upper Inode is a %v, not a directory", upper.StableAttr.Type)
+ }
+ if !IsDir(lower.StableAttr) {
+ return nil, fmt.Errorf("lower Inode is a %v, not a directory", lower.StableAttr.Type)
+ }
+ if upper.overlay != nil {
+ return nil, fmt.Errorf("cannot nest overlay in upper file of another overlay")
+ }
+
+ msrc := newOverlayMountSource(ctx, upper.MountSource, lower.MountSource, flags)
+ overlay, err := newOverlayEntry(ctx, upper, lower, true)
+ if err != nil {
+ msrc.DecRef()
+ return nil, err
+ }
+
+ return newOverlayInode(ctx, overlay, msrc), nil
+}
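A hedged usage sketch of NewOverlayRoot; buildOverlay is a hypothetical caller, with upper, lower, and flags assumed to come from the surrounding mount configuration:

```go
// buildOverlay is a sketch only; it is not part of this package.
func buildOverlay(ctx context.Context, upper, lower *Inode, flags MountSourceFlags) (*Inode, error) {
	root, err := NewOverlayRoot(ctx, upper, lower, flags)
	if err != nil {
		return nil, err // e.g. an inode was not a directory, or upper was an overlay.
	}
	// root.MountSource is the overlay mount source built above; lookups
	// prefer the upper layer and fall back to the lower one.
	return root, nil
}
```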
+
+// NewOverlayRootFile produces the root of an overlay that points to a file.
+//
+// Preconditions:
+//
+// - lower must be non-nil.
+// - lower should not expose character devices, pipes, or sockets, because
+// copying up these types of files is not supported, nor can it be a directory.
+// - lower must not require that file objects be revalidated.
+// - lower must not have dynamic file/directory content.
+func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode, flags MountSourceFlags) (*Inode, error) {
+ if !IsRegular(lower.StableAttr) {
+ return nil, fmt.Errorf("lower Inode is not a regular file")
+ }
+ msrc := newOverlayMountSource(ctx, upperMS, lower.MountSource, flags)
+ overlay, err := newOverlayEntry(ctx, nil, lower, true)
+ if err != nil {
+ msrc.DecRef()
+ return nil, err
+ }
+ return newOverlayInode(ctx, overlay, msrc), nil
+}
+
+// newOverlayInode creates a new Inode for an overlay.
+func newOverlayInode(ctx context.Context, o *overlayEntry, msrc *MountSource) *Inode {
+ var inode *Inode
+ if o.upper != nil {
+ inode = NewInode(ctx, nil, msrc, o.upper.StableAttr)
+ } else {
+ inode = NewInode(ctx, nil, msrc, o.lower.StableAttr)
+ }
+ inode.overlay = o
+ return inode
+}
+
+// overlayEntry is the overlay metadata of an Inode. It implements Mappable.
+//
+// +stateify savable
+type overlayEntry struct {
+ // lowerExists is true if an Inode exists for this file in the lower
+ // filesystem. If lowerExists is true, then the overlay must create
+ // a whiteout entry when renaming and removing this entry to mask the
+ // lower Inode.
+ //
+ // Note that this is distinct from actually holding onto a non-nil
+ // lower Inode (below). The overlay does not need to keep a lower Inode
+ // around unless it needs to operate on it, but it always needs to know
+ // whether the lower Inode exists to correctly execute a rename or
+ // remove operation.
+ lowerExists bool
+
+ // lower is an Inode from a lower filesystem. Modifications are
+ // never made on this Inode.
+ lower *Inode
+
+ // copyMu serializes copy-up for operations above
+ // mm.MemoryManager.mappingMu in the lock order.
+ copyMu sync.RWMutex `state:"nosave"`
+
+ // mapsMu serializes copy-up for operations between
+ // mm.MemoryManager.mappingMu and mm.MemoryManager.activeMu in the lock
+ // order.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks memory mappings of this Mappable so they can be removed
+ // from the lower filesystem Mappable and added to the upper filesystem
+ // Mappable when copy up occurs. It is strictly unnecessary after copy-up.
+ //
+ // mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // dataMu serializes copy-up for operations below mm.MemoryManager.activeMu
+ // in the lock order.
+ dataMu sync.RWMutex `state:"nosave"`
+
+ // upper is an Inode from an upper filesystem. It is non-nil if
+ // the file exists in the upper filesystem. It becomes non-nil
+ // when the Inode that owns this overlayEntry is modified.
+ //
+ // upper is protected by all of copyMu, mapsMu, and dataMu. Holding any of
+ // these locks is sufficient to read upper; holding all three for writing
+ // is required to mutate it.
+ upper *Inode
+
+ // dirCacheMu protects dirCache.
+ dirCacheMu sync.RWMutex `state:"nosave"`
+
+ // dirCache is a cache of DentAttrs from upper and lower Inodes.
+ dirCache *SortedDentryMap
+}
+
+// newOverlayEntry returns a new overlayEntry.
+func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExists bool) (*overlayEntry, error) {
+ if upper == nil && lower == nil {
+ panic("invalid overlayEntry, needs at least one Inode")
+ }
+ if upper != nil && upper.overlay != nil {
+ panic("nested writable layers are not supported")
+ }
+ // Check for supported lower filesystem types.
+ if lower != nil {
+ switch lower.StableAttr.Type {
+ case RegularFile, Directory, Symlink, Socket:
+ default:
+ // We don't support copying up from character devices,
+ // named pipes, or anything weird (like proc files).
+ log.Warningf("%s not supported in lower filesytem", lower.StableAttr.Type)
+ return nil, syserror.EINVAL
+ }
+ }
+ return &overlayEntry{
+ lowerExists: lowerExists,
+ lower: lower,
+ upper: upper,
+ }, nil
+}
+
+func (o *overlayEntry) release() {
+ // We drop a reference on upper and lower file system Inodes
+ // rather than releasing them, because in-memory filesystems
+ // may hold an extra reference to these Inodes so that they
+ // stay in memory.
+ if o.upper != nil {
+ o.upper.DecRef()
+ }
+ if o.lower != nil {
+ o.lower.DecRef()
+ }
+}
+
+// overlayUpperMountSource gives the upper mount of an overlay mount.
+//
+// The caller may not use this MountSource past the lifetime of overlayMountSource and may
+// not call DecRef on it.
+func overlayUpperMountSource(overlayMountSource *MountSource) *MountSource {
+ return overlayMountSource.MountSourceOperations.(*overlayMountSourceOperations).upper
+}
+
+// Preconditions: At least one of o.copyMu, o.mapsMu, or o.dataMu must be locked.
+func (o *overlayEntry) inodeLocked() *Inode {
+ if o.upper != nil {
+ return o.upper
+ }
+ return o.lower
+}
+
+// Preconditions: At least one of o.copyMu, o.mapsMu, or o.dataMu must be locked.
+func (o *overlayEntry) isMappableLocked() bool {
+ return o.inodeLocked().Mappable() != nil
+}
+
+// markDirectoryDirty marks any cached data dirty for this directory. This is
+// necessary in order to ensure that this node does not retain stale state
+// throughout its lifetime across multiple open directory handles.
+//
+// Currently this means invalidating any readdir caches.
+func (o *overlayEntry) markDirectoryDirty() {
+ o.dirCacheMu.Lock()
+ o.dirCache = nil
+ o.dirCacheMu.Unlock()
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ if err := o.inodeLocked().Mappable().AddMapping(ctx, ms, ar, offset, writable); err != nil {
+ return err
+ }
+ o.mappings.AddMapping(ms, ar, offset, writable)
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ o.inodeLocked().Mappable().RemoveMapping(ctx, ms, ar, offset, writable)
+ o.mappings.RemoveMapping(ms, ar, offset, writable)
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ if err := o.inodeLocked().Mappable().CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil {
+ return err
+ }
+ o.mappings.AddMapping(ms, dstAR, offset, writable)
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (o *overlayEntry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ o.dataMu.RLock()
+ defer o.dataMu.RUnlock()
+ return o.inodeLocked().Mappable().Translate(ctx, required, optional, at)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (o *overlayEntry) InvalidateUnsavable(ctx context.Context) error {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ return o.inodeLocked().Mappable().InvalidateUnsavable(ctx)
+}
diff --git a/pkg/sentry/fs/path.go b/pkg/sentry/fs/path.go
new file mode 100644
index 000000000..e4dc02dbb
--- /dev/null
+++ b/pkg/sentry/fs/path.go
@@ -0,0 +1,119 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "path/filepath"
+ "strings"
+)
+
+// TrimTrailingSlashes trims any trailing slashes.
+//
+// The returned boolean indicates whether any changes were made.
+//
+//go:nosplit
+func TrimTrailingSlashes(dir string) (trimmed string, changed bool) {
+ // Trim the trailing slash, except for root.
+ for len(dir) > 1 && dir[len(dir)-1] == '/' {
+ dir = dir[:len(dir)-1]
+ changed = true
+ }
+ return dir, changed
+}
+
+// SplitLast splits the given path into a directory and a file.
+//
+// The "absoluteness" of the path is preserved, but dir is always stripped of
+// trailing slashes.
+//
+//go:nosplit
+func SplitLast(path string) (dir, file string) {
+ path, _ = TrimTrailingSlashes(path)
+ if path == "" {
+ return ".", "."
+ } else if path == "/" {
+ return "/", "."
+ }
+
+ var slash int // Last location of slash in path.
+ for slash = len(path) - 1; slash >= 0 && path[slash] != '/'; slash-- {
+ }
+ switch {
+ case slash < 0:
+ return ".", path
+ case slash == 0:
+ // Directory of the form "/foo", or just "/". We need to
+ // preserve the first slash here, since it indicates an
+ // absolute path.
+ return "/", path[1:]
+ default:
+ // Drop the trailing slash.
+ dir, _ = TrimTrailingSlashes(path[:slash])
+ return dir, path[slash+1:]
+ }
+}
+
+// SplitFirst splits the given path into a first directory and the remainder.
+//
+// If remainder is empty, then the path is a single element.
+//
+//go:nosplit
+func SplitFirst(path string) (current, remainder string) {
+ path, _ = TrimTrailingSlashes(path)
+ if path == "" {
+ return ".", ""
+ }
+
+ var slash int // First location of slash in path.
+ for slash = 0; slash < len(path) && path[slash] != '/'; slash++ {
+ }
+ switch {
+ case slash >= len(path):
+ return path, ""
+ case slash == 0:
+ // See above.
+ return "/", path[1:]
+ default:
+ current = path[:slash]
+ remainder = path[slash+1:]
+ // Strip redundant slashes.
+ for len(remainder) > 0 && remainder[0] == '/' {
+ remainder = remainder[1:]
+ }
+ return current, remainder
+ }
+}
+
+// IsSubpath checks whether the first path is a (strict) descendant of the
+// second. If it is a subpath, then true is returned along with a clean
+// relative path from the second path to the first. Otherwise false is
+// returned.
+func IsSubpath(subpath, path string) (string, bool) {
+ cleanPath := filepath.Clean(path)
+ cleanSubpath := filepath.Clean(subpath)
+
+ // Add a trailing slash to the path if it does not already have one.
+ if len(cleanPath) == 0 || cleanPath[len(cleanPath)-1] != '/' {
+ cleanPath += "/"
+ }
+ if cleanPath == cleanSubpath {
+ // Paths are equal, thus not a strict subpath.
+ return "", false
+ }
+ if strings.HasPrefix(cleanSubpath, cleanPath) {
+ return strings.TrimPrefix(cleanSubpath, cleanPath), true
+ }
+ return "", false
+}
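+
+// A quick illustration of the helpers above (expected values taken from the
+// accompanying tests; the paths themselves are arbitrary):
+//
+//   dir, file := SplitLast("/foo/bar/")    // dir == "/foo", file == "bar"
+//   first, rest := SplitFirst("foo///bar") // first == "foo", rest == "bar"
+//   rel, ok := IsSubpath("/foo/bar", "/")  // rel == "foo/bar", ok == true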
diff --git a/pkg/sentry/fs/path_test.go b/pkg/sentry/fs/path_test.go
new file mode 100644
index 000000000..e6f57ebba
--- /dev/null
+++ b/pkg/sentry/fs/path_test.go
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "testing"
+)
+
+// TestSplitLast tests variants of path splitting.
+func TestSplitLast(t *testing.T) {
+ cases := []struct {
+ path string
+ dir string
+ file string
+ }{
+ {path: "/", dir: "/", file: "."},
+ {path: "/.", dir: "/", file: "."},
+ {path: "/./", dir: "/", file: "."},
+ {path: "/./.", dir: "/.", file: "."},
+ {path: "/././", dir: "/.", file: "."},
+ {path: "/./..", dir: "/.", file: ".."},
+ {path: "/./../", dir: "/.", file: ".."},
+ {path: "/..", dir: "/", file: ".."},
+ {path: "/../", dir: "/", file: ".."},
+ {path: "/../.", dir: "/..", file: "."},
+ {path: "/.././", dir: "/..", file: "."},
+ {path: "/../..", dir: "/..", file: ".."},
+ {path: "/../../", dir: "/..", file: ".."},
+
+ {path: "", dir: ".", file: "."},
+ {path: ".", dir: ".", file: "."},
+ {path: "./", dir: ".", file: "."},
+ {path: "./.", dir: ".", file: "."},
+ {path: "././", dir: ".", file: "."},
+ {path: "./..", dir: ".", file: ".."},
+ {path: "./../", dir: ".", file: ".."},
+ {path: "..", dir: ".", file: ".."},
+ {path: "../", dir: ".", file: ".."},
+ {path: "../.", dir: "..", file: "."},
+ {path: ".././", dir: "..", file: "."},
+ {path: "../..", dir: "..", file: ".."},
+ {path: "../../", dir: "..", file: ".."},
+
+ {path: "/foo", dir: "/", file: "foo"},
+ {path: "/foo/", dir: "/", file: "foo"},
+ {path: "/foo/.", dir: "/foo", file: "."},
+ {path: "/foo/./", dir: "/foo", file: "."},
+ {path: "/foo/./.", dir: "/foo/.", file: "."},
+ {path: "/foo/./..", dir: "/foo/.", file: ".."},
+ {path: "/foo/..", dir: "/foo", file: ".."},
+ {path: "/foo/../", dir: "/foo", file: ".."},
+ {path: "/foo/../.", dir: "/foo/..", file: "."},
+ {path: "/foo/../..", dir: "/foo/..", file: ".."},
+
+ {path: "/foo/bar", dir: "/foo", file: "bar"},
+ {path: "/foo/bar/", dir: "/foo", file: "bar"},
+ {path: "/foo/bar/.", dir: "/foo/bar", file: "."},
+ {path: "/foo/bar/./", dir: "/foo/bar", file: "."},
+ {path: "/foo/bar/./.", dir: "/foo/bar/.", file: "."},
+ {path: "/foo/bar/./..", dir: "/foo/bar/.", file: ".."},
+ {path: "/foo/bar/..", dir: "/foo/bar", file: ".."},
+ {path: "/foo/bar/../", dir: "/foo/bar", file: ".."},
+ {path: "/foo/bar/../.", dir: "/foo/bar/..", file: "."},
+ {path: "/foo/bar/../..", dir: "/foo/bar/..", file: ".."},
+
+ {path: "foo", dir: ".", file: "foo"},
+ {path: "foo", dir: ".", file: "foo"},
+ {path: "foo/", dir: ".", file: "foo"},
+ {path: "foo/.", dir: "foo", file: "."},
+ {path: "foo/./", dir: "foo", file: "."},
+ {path: "foo/./.", dir: "foo/.", file: "."},
+ {path: "foo/./..", dir: "foo/.", file: ".."},
+ {path: "foo/..", dir: "foo", file: ".."},
+ {path: "foo/../", dir: "foo", file: ".."},
+ {path: "foo/../.", dir: "foo/..", file: "."},
+ {path: "foo/../..", dir: "foo/..", file: ".."},
+ {path: "foo/", dir: ".", file: "foo"},
+ {path: "foo/.", dir: "foo", file: "."},
+
+ {path: "foo/bar", dir: "foo", file: "bar"},
+ {path: "foo/bar/", dir: "foo", file: "bar"},
+ {path: "foo/bar/.", dir: "foo/bar", file: "."},
+ {path: "foo/bar/./", dir: "foo/bar", file: "."},
+ {path: "foo/bar/./.", dir: "foo/bar/.", file: "."},
+ {path: "foo/bar/./..", dir: "foo/bar/.", file: ".."},
+ {path: "foo/bar/..", dir: "foo/bar", file: ".."},
+ {path: "foo/bar/../", dir: "foo/bar", file: ".."},
+ {path: "foo/bar/../.", dir: "foo/bar/..", file: "."},
+ {path: "foo/bar/../..", dir: "foo/bar/..", file: ".."},
+ {path: "foo/bar/", dir: "foo", file: "bar"},
+ {path: "foo/bar/.", dir: "foo/bar", file: "."},
+ }
+
+ for _, c := range cases {
+ dir, file := SplitLast(c.path)
+ if dir != c.dir || file != c.file {
+ t.Errorf("SplitLast(%q) got (%q, %q), expected (%q, %q)", c.path, dir, file, c.dir, c.file)
+ }
+ }
+}
+
+// TestSplitFirst tests variants of path splitting.
+func TestSplitFirst(t *testing.T) {
+ cases := []struct {
+ path string
+ first string
+ remainder string
+ }{
+ {path: "/", first: "/", remainder: ""},
+ {path: "/.", first: "/", remainder: "."},
+ {path: "///.", first: "/", remainder: "//."},
+ {path: "/.///", first: "/", remainder: "."},
+ {path: "/./.", first: "/", remainder: "./."},
+ {path: "/././", first: "/", remainder: "./."},
+ {path: "/./..", first: "/", remainder: "./.."},
+ {path: "/./../", first: "/", remainder: "./.."},
+ {path: "/..", first: "/", remainder: ".."},
+ {path: "/../", first: "/", remainder: ".."},
+ {path: "/../.", first: "/", remainder: "../."},
+ {path: "/.././", first: "/", remainder: "../."},
+ {path: "/../..", first: "/", remainder: "../.."},
+ {path: "/../../", first: "/", remainder: "../.."},
+
+ {path: "", first: ".", remainder: ""},
+ {path: ".", first: ".", remainder: ""},
+ {path: "./", first: ".", remainder: ""},
+ {path: ".///", first: ".", remainder: ""},
+ {path: "./.", first: ".", remainder: "."},
+ {path: "././", first: ".", remainder: "."},
+ {path: "./..", first: ".", remainder: ".."},
+ {path: "./../", first: ".", remainder: ".."},
+ {path: "..", first: "..", remainder: ""},
+ {path: "../", first: "..", remainder: ""},
+ {path: "../.", first: "..", remainder: "."},
+ {path: ".././", first: "..", remainder: "."},
+ {path: "../..", first: "..", remainder: ".."},
+ {path: "../../", first: "..", remainder: ".."},
+
+ {path: "/foo", first: "/", remainder: "foo"},
+ {path: "/foo/", first: "/", remainder: "foo"},
+ {path: "/foo///", first: "/", remainder: "foo"},
+ {path: "/foo/.", first: "/", remainder: "foo/."},
+ {path: "/foo/./", first: "/", remainder: "foo/."},
+ {path: "/foo/./.", first: "/", remainder: "foo/./."},
+ {path: "/foo/./..", first: "/", remainder: "foo/./.."},
+ {path: "/foo/..", first: "/", remainder: "foo/.."},
+ {path: "/foo/../", first: "/", remainder: "foo/.."},
+ {path: "/foo/../.", first: "/", remainder: "foo/../."},
+ {path: "/foo/../..", first: "/", remainder: "foo/../.."},
+
+ {path: "/foo/bar", first: "/", remainder: "foo/bar"},
+ {path: "///foo/bar", first: "/", remainder: "//foo/bar"},
+ {path: "/foo///bar", first: "/", remainder: "foo///bar"},
+ {path: "/foo/bar/.", first: "/", remainder: "foo/bar/."},
+ {path: "/foo/bar/./", first: "/", remainder: "foo/bar/."},
+ {path: "/foo/bar/./.", first: "/", remainder: "foo/bar/./."},
+ {path: "/foo/bar/./..", first: "/", remainder: "foo/bar/./.."},
+ {path: "/foo/bar/..", first: "/", remainder: "foo/bar/.."},
+ {path: "/foo/bar/../", first: "/", remainder: "foo/bar/.."},
+ {path: "/foo/bar/../.", first: "/", remainder: "foo/bar/../."},
+ {path: "/foo/bar/../..", first: "/", remainder: "foo/bar/../.."},
+
+ {path: "foo", first: "foo", remainder: ""},
+ {path: "foo", first: "foo", remainder: ""},
+ {path: "foo/", first: "foo", remainder: ""},
+ {path: "foo///", first: "foo", remainder: ""},
+ {path: "foo/.", first: "foo", remainder: "."},
+ {path: "foo/./", first: "foo", remainder: "."},
+ {path: "foo/./.", first: "foo", remainder: "./."},
+ {path: "foo/./..", first: "foo", remainder: "./.."},
+ {path: "foo/..", first: "foo", remainder: ".."},
+ {path: "foo/../", first: "foo", remainder: ".."},
+ {path: "foo/../.", first: "foo", remainder: "../."},
+ {path: "foo/../..", first: "foo", remainder: "../.."},
+ {path: "foo/", first: "foo", remainder: ""},
+ {path: "foo/.", first: "foo", remainder: "."},
+
+ {path: "foo/bar", first: "foo", remainder: "bar"},
+ {path: "foo///bar", first: "foo", remainder: "bar"},
+ {path: "foo/bar/", first: "foo", remainder: "bar"},
+ {path: "foo/bar/.", first: "foo", remainder: "bar/."},
+ {path: "foo/bar/./", first: "foo", remainder: "bar/."},
+ {path: "foo/bar/./.", first: "foo", remainder: "bar/./."},
+ {path: "foo/bar/./..", first: "foo", remainder: "bar/./.."},
+ {path: "foo/bar/..", first: "foo", remainder: "bar/.."},
+ {path: "foo/bar/../", first: "foo", remainder: "bar/.."},
+ {path: "foo/bar/../.", first: "foo", remainder: "bar/../."},
+ {path: "foo/bar/../..", first: "foo", remainder: "bar/../.."},
+ {path: "foo/bar/", first: "foo", remainder: "bar"},
+ {path: "foo/bar/.", first: "foo", remainder: "bar/."},
+ }
+
+ for _, c := range cases {
+ first, remainder := SplitFirst(c.path)
+ if first != c.first || remainder != c.remainder {
+ t.Errorf("SplitFirst(%q) got (%q, %q), expected (%q, %q)", c.path, first, remainder, c.first, c.remainder)
+ }
+ }
+}
+
+// TestIsSubpath tests the IsSubpath function.
+func TestIsSubpath(t *testing.T) {
+ tcs := []struct {
+ // Two absolute paths.
+ pathA string
+ pathB string
+
+ // Whether pathA is a subpath of pathB.
+ wantIsSubpath bool
+
+ // Relative path from pathB to pathA. Only checked if
+ // wantIsSubpath is true.
+ wantRelpath string
+ }{
+ {
+ pathA: "/foo/bar/baz",
+ pathB: "/foo",
+ wantIsSubpath: true,
+ wantRelpath: "bar/baz",
+ },
+ {
+ pathA: "/foo",
+ pathB: "/foo/bar/baz",
+ wantIsSubpath: false,
+ },
+ {
+ pathA: "/foo",
+ pathB: "/foo",
+ wantIsSubpath: false,
+ },
+ {
+ pathA: "/foobar",
+ pathB: "/foo",
+ wantIsSubpath: false,
+ },
+ {
+ pathA: "/foo",
+ pathB: "/foobar",
+ wantIsSubpath: false,
+ },
+ {
+ pathA: "/",
+ pathB: "/foo",
+ wantIsSubpath: false,
+ },
+ {
+ pathA: "/foo",
+ pathB: "/",
+ wantIsSubpath: true,
+ wantRelpath: "foo",
+ },
+ {
+ pathA: "/foo/bar/../bar",
+ pathB: "/foo",
+ wantIsSubpath: true,
+ wantRelpath: "bar",
+ },
+ {
+ pathA: "/foo/bar",
+ pathB: "/foo/../foo",
+ wantIsSubpath: true,
+ wantRelpath: "bar",
+ },
+ }
+
+ for _, tc := range tcs {
+ gotRelpath, gotIsSubpath := IsSubpath(tc.pathA, tc.pathB)
+ if gotRelpath != tc.wantRelpath || gotIsSubpath != tc.wantIsSubpath {
+ t.Errorf("IsSubpath(%q, %q) got %q %t, want %q %t", tc.pathA, tc.pathB, gotRelpath, gotIsSubpath, tc.wantRelpath, tc.wantIsSubpath)
+ }
+ }
+}
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
new file mode 100644
index 000000000..77c2c5c0e
--- /dev/null
+++ b/pkg/sentry/fs/proc/BUILD
@@ -0,0 +1,72 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "proc",
+ srcs = [
+ "cgroup.go",
+ "cpuinfo.go",
+ "exec_args.go",
+ "fds.go",
+ "filesystems.go",
+ "fs.go",
+ "inode.go",
+ "loadavg.go",
+ "meminfo.go",
+ "mounts.go",
+ "net.go",
+ "proc.go",
+ "stat.go",
+ "sys.go",
+ "sys_net.go",
+ "sys_net_state.go",
+ "task.go",
+ "uid_gid_map.go",
+ "uptime.go",
+ "version.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/proc/device",
+ "//pkg/sentry/fs/proc/seqfile",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip/header",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "proc_test",
+ size = "small",
+ srcs = [
+ "net_test.go",
+ "sys_net_test.go",
+ ],
+ library = ":proc",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/inet",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md
new file mode 100644
index 000000000..6667a0916
--- /dev/null
+++ b/pkg/sentry/fs/proc/README.md
@@ -0,0 +1,336 @@
+This document tracks what is implemented in procfs. Refer to
+Documentation/filesystems/proc.txt in the Linux project for information about
+procfs generally.
+
+**NOTE**: This document is not guaranteed to be up to date. If you find an
+inconsistency, please file a bug.
+
+[TOC]
+
+## Kernel data
+
+The following files are implemented:
+
+<!-- mdformat off(don't wrap the table) -->
+
+| File /proc/ | Content |
+| :------------------------ | :---------------------------------------------------- |
+| [cpuinfo](#cpuinfo) | Info about the CPU |
+| [filesystems](#filesystems) | Supported filesystems |
+| [loadavg](#loadavg) | Load average of last 1, 5 & 15 minutes |
+| [meminfo](#meminfo) | Overall memory info |
+| [stat](#stat) | Overall kernel statistics |
+| [sys](#sys) | Change parameters within the kernel |
+| [uptime](#uptime)         | Wall clock since boot, combined idle time of all CPUs |
+| [version](#version) | Kernel version |
+
+<!-- mdformat on -->
+
+### cpuinfo
+
+```bash
+$ cat /proc/cpuinfo
+processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 45
+model name : unknown
+stepping : unknown
+cpu MHz : 1234.588
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic popcnt tsc_deadline_timer aes xsave avx xsaveopt
+bogomips : 1234.59
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:
+
+...
+```
+
+Notable divergences:
+
+Field name | Notes
+:--------------- | :---------------------------------------
+model name | Always unknown
+stepping | Always unknown
+fpu | Always yes
+fpu_exception | Always yes
+wp | Always yes
+bogomips | Bogus value (matches cpu MHz)
+clflush size | Always 64
+cache_alignment | Always 64
+address sizes | Always 46 bits physical, 48 bits virtual
+power management | Always blank
+
+Otherwise fields are derived from the sentry configuration.
+
+### filesystems
+
+```bash
+$ cat /proc/filesystems
+nodev 9p
+nodev devpts
+nodev devtmpfs
+nodev proc
+nodev sysfs
+nodev tmpfs
+```
+
+### loadavg
+
+```bash
+$ cat /proc/loadavg
+0.00 0.00 0.00 0/0 0
+```
+
+Column | Notes
+:------------------------------------ | :----------
+CPU.IO utilization in last 1 minute   | Always zero
+CPU.IO utilization in last 5 minutes  | Always zero
+CPU.IO utilization in last 15 minutes | Always zero
+Num currently running processes       | Always zero
+Total num processes                   | Always zero
+Last process ID used                  | Always zero
+
+TODO(b/62345059): Populate the columns with accurate statistics.
+
+### meminfo
+
+```bash
+$ cat /proc/meminfo
+MemTotal: 2097152 kB
+MemFree: 2083540 kB
+MemAvailable: 2083540 kB
+Buffers: 0 kB
+Cached: 4428 kB
+SwapCache: 0 kB
+Active: 10812 kB
+Inactive: 2216 kB
+Active(anon): 8600 kB
+Inactive(anon): 0 kB
+Active(file): 2212 kB
+Inactive(file): 2216 kB
+Unevictable: 0 kB
+Mlocked: 0 kB
+SwapTotal: 0 kB
+SwapFree: 0 kB
+Dirty: 0 kB
+Writeback: 0 kB
+AnonPages: 8600 kB
+Mapped: 4428 kB
+Shmem: 0 kB
+
+```
+
+Notable divergences:
+
+Field name | Notes
+:---------------- | :-----------------------------------------------------
+Buffers | Always zero, no block devices
+SwapCache | Always zero, no swap
+Inactive(anon) | Always zero, see SwapCache
+Unevictable | Always zero TODO(b/31823263)
+Mlocked | Always zero TODO(b/31823263)
+SwapTotal | Always zero, no swap
+SwapFree | Always zero, no swap
+Dirty | Always zero TODO(b/31823263)
+Writeback | Always zero TODO(b/31823263)
+MemAvailable | Uses the same value as MemFree since there is no swap.
+Slab | Missing
+SReclaimable | Missing
+SUnreclaim | Missing
+KernelStack | Missing
+PageTables | Missing
+NFS_Unstable | Missing
+Bounce | Missing
+WritebackTmp | Missing
+CommitLimit | Missing
+Committed_AS | Missing
+VmallocTotal | Missing
+VmallocUsed | Missing
+VmallocChunk | Missing
+HardwareCorrupted | Missing
+AnonHugePages | Missing
+ShmemHugePages | Missing
+ShmemPmdMapped | Missing
+HugePages_Total | Missing
+HugePages_Free | Missing
+HugePages_Rsvd | Missing
+HugePages_Surp | Missing
+Hugepagesize | Missing
+DirectMap4k | Missing
+DirectMap2M | Missing
+DirectMap1G | Missing
+
+### stat
+
+```bash
+$ cat /proc/stat
+cpu 0 0 0 0 0 0 0 0 0 0
+cpu0 0 0 0 0 0 0 0 0 0 0
+cpu1 0 0 0 0 0 0 0 0 0 0
+cpu2 0 0 0 0 0 0 0 0 0 0
+cpu3 0 0 0 0 0 0 0 0 0 0
+cpu4 0 0 0 0 0 0 0 0 0 0
+cpu5 0 0 0 0 0 0 0 0 0 0
+cpu6 0 0 0 0 0 0 0 0 0 0
+cpu7 0 0 0 0 0 0 0 0 0 0
+intr 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+ctxt 0
+btime 1504040968
+processes 0
+procs_running 0
+procs_blocked 0
+softirq 0 0 0 0 0 0 0 0 0 0 0
+```
+
+All fields except for `btime` are always zero.
+
+TODO(b/37226836): Populate with accurate fields.
+
+### sys
+
+```bash
+$ ls /proc/sys
+kernel vm
+```
+
+Directory | Notes
+:-------- | :----------------------------
+abi | Missing
+debug | Missing
+dev | Missing
+fs | Missing
+kernel | Contains hostname (only)
+net | Missing
+user | Missing
+vm | Contains mmap_min_addr (only)
+
+### uptime
+
+```bash
+$ cat /proc/uptime
+3204.62 0.00
+```
+
+Column | Notes
+:------------------------------- | :----------------------------
+Total num seconds system running | Time since procfs was mounted
+Number of seconds idle | Always zero
+
+### version
+
+```bash
+$ cat /proc/version
+Linux version 4.4 #1 SMP Sun Jan 10 15:06:54 PST 2016
+```
+
+## Process-specific data
+
+The following files are implemented:
+
+File /proc/PID | Content
+:---------------------- | :---------------------------------------------------
+[auxv](#auxv) | Copy of auxiliary vector for the process
+[cmdline](#cmdline) | Command line arguments
+[comm](#comm) | Command name associated with the process
+[environ](#environ) | Process environment
+[exe](#exe) | Symlink to the process's executable
+[fd](#fd) | Directory containing links to open file descriptors
+[fdinfo](#fdinfo) | Information associated with open file descriptors
+[gid_map](#gid_map) | Mappings for group IDs inside the user namespace
+[io](#io) | IO statistics
+[maps](#maps) | Memory mappings (anon, executables, library files)
+[mounts](#mounts) | Mounted filesystems
+[mountinfo](#mountinfo) | Information about mounts
+[ns](#ns) | Directory containing info about supported namespaces
+[stat](#stat) | Process statistics
+[statm](#statm) | Process memory statistics
+[status](#status) | Process status in human readable format
+[task](#task) | Directory containing info about running threads
+[uid_map](#uid_map) | Mappings for user IDs inside the user namespace
+
+### auxv
+
+TODO
+
+### cmdline
+
+TODO
+
+### comm
+
+TODO
+
+### environ
+
+TODO
+
+### exe
+
+TODO
+
+### fd
+
+TODO
+
+### fdinfo
+
+TODO
+
+### gid_map
+
+TODO
+
+### io
+
+Only has data for rchar, wchar, syscr, and syscw.
+
+TODO: add more detail.
+
+### maps
+
+TODO
+
+### mounts
+
+TODO
+
+### mountinfo
+
+TODO
+
+### ns
+
+TODO
+
+### stat
+
+Only has data for pid, comm, state, ppid, utime, stime, cutime, cstime,
+num_threads, and exit_signal.
+
+TODO: add more detail.
+
+### statm
+
+Only has data for vss and rss.
+
+TODO: add more detail.
+
+### status
+
+Contains data for Name, State, Tgid, Pid, Ppid, TracerPid, FDSize, VmSize,
+VmRSS, Threads, CapInh, CapPrm, CapEff, CapBnd, Seccomp.
+
+TODO: add more detail.
+
+### task
+
+TODO
+
+### uid_map
+
+TODO
diff --git a/pkg/sentry/fs/proc/cgroup.go b/pkg/sentry/fs/proc/cgroup.go
new file mode 100644
index 000000000..7c1d9e7e9
--- /dev/null
+++ b/pkg/sentry/fs/proc/cgroup.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// LINT.IfChange
+
+func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) *fs.Inode {
+ // From man 7 cgroups: "For each cgroup hierarchy of which the process
+ // is a member, there is one entry containing three colon-separated
+ // fields: hierarchy-ID:controller-list:cgroup-path"
+
+ // The hierarchy ids must be positive integers (for cgroup v1), but the
+ // exact number does not matter, so long as they are unique. We can
+ // just use a counter, but since Linux sorts this file in descending
+ // order, we must count down to preserve this behavior.
+ i := len(cgroupControllers)
+ var data string
+ for name, dir := range cgroupControllers {
+ data += fmt.Sprintf("%d:%s:%s\n", i, name, dir)
+ i--
+ }
+
+ return newStaticProcInode(ctx, msrc, []byte(data))
+}
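+
+// For example, given cgroupControllers of
+// {"cpu": "/docker/abc", "memory": "/docker/abc"} (a hypothetical mapping),
+// one possible result (Go map iteration order varies) is:
+//
+//   2:cpu:/docker/abc
+//   1:memory:/docker/abc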
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/cpuinfo.go b/pkg/sentry/fs/proc/cpuinfo.go
new file mode 100644
index 000000000..c96533401
--- /dev/null
+++ b/pkg/sentry/fs/proc/cpuinfo.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ k := kernel.KernelFromContext(ctx)
+ features := k.FeatureSet()
+ if features == nil {
+ // Kernel is always initialized with a FeatureSet.
+ panic("cpuinfo read with nil FeatureSet")
+ }
+ var buf bytes.Buffer
+ for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
+ features.WriteCPUInfoTo(i, &buf)
+ }
+ return newStaticProcInode(ctx, msrc, buf.Bytes())
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD
new file mode 100644
index 000000000..52c9aa93d
--- /dev/null
+++ b/pkg/sentry/fs/proc/device/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "device",
+ srcs = ["device.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sentry/device"],
+)
diff --git a/pkg/sentry/fs/proc/device/device.go b/pkg/sentry/fs/proc/device/device.go
new file mode 100644
index 000000000..bbe66e796
--- /dev/null
+++ b/pkg/sentry/fs/proc/device/device.go
@@ -0,0 +1,23 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package device contains the proc device to avoid dependency loops.
+package device
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/device"
+)
+
+// ProcDevice is the kernel proc device.
+var ProcDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
new file mode 100644
index 000000000..8fe626e1c
--- /dev/null
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -0,0 +1,207 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// execArgType enumerates the types of exec arguments that are exposed through
+// proc.
+type execArgType int
+
+const (
+ cmdlineExecArg execArgType = iota
+ environExecArg
+)
+
+// execArgInode is an inode containing the exec args (either cmdline or
+// environ) for a given task.
+//
+// +stateify savable
+type execArgInode struct {
+ fsutil.SimpleFileInode
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+
+ // t is the Task to read the exec arg line from.
+ t *kernel.Task
+}
+
+var _ fs.InodeOperations = (*execArgInode)(nil)
+
+// newExecArgInode creates an inode containing the exec args of the given type.
+func newExecArgInode(t *kernel.Task, msrc *fs.MountSource, arg execArgType) *fs.Inode {
+ if arg != cmdlineExecArg && arg != environExecArg {
+ panic(fmt.Sprintf("unknown exec arg type %v", arg))
+ }
+ f := &execArgInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ arg: arg,
+ t: t,
+ }
+ return newProcInode(t, f, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *execArgInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &execArgFile{
+ arg: i.arg,
+ t: i.t,
+ }), nil
+}
+
+// +stateify savable
+type execArgFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+
+ // t is the Task to read the exec arg line from.
+ t *kernel.Task
+}
+
+var _ fs.FileOperations = (*execArgFile)(nil)
+
+// Read reads the exec arg from the process's address space.
+func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ m, err := getTaskMM(f.t)
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecUsers(ctx)
+
+ // Figure out the bounds of the exec arg we are trying to read.
+ var execArgStart, execArgEnd usermem.Addr
+ switch f.arg {
+ case cmdlineExecArg:
+ execArgStart, execArgEnd = m.ArgvStart(), m.ArgvEnd()
+ case environExecArg:
+ execArgStart, execArgEnd = m.EnvvStart(), m.EnvvEnd()
+ default:
+ panic(fmt.Sprintf("unknown exec arg type %v", f.arg))
+ }
+ if execArgStart == 0 || execArgEnd == 0 {
+ // Don't attempt to read before the start/end are set up.
+ return 0, io.EOF
+ }
+
+ start, ok := execArgStart.AddLength(uint64(offset))
+ if !ok {
+ return 0, io.EOF
+ }
+ if start >= execArgEnd {
+ return 0, io.EOF
+ }
+
+ length := int(execArgEnd - start)
+ if dstlen := dst.NumBytes(); int64(length) > dstlen {
+ length = int(dstlen)
+ }
+
+ buf := make([]byte, length)
+ // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true
+ // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
+ // cmdline and environment").
+ copyN, err := m.CopyIn(ctx, start, buf, usermem.IOOpts{})
+ if copyN == 0 {
+ // Nothing to copy.
+ return 0, err
+ }
+ buf = buf[:copyN]
+
+ // On Linux, if the NUL byte at the end of the argument vector has been
+ // overwritten, it continues reading the environment vector as part of
+ // the argument vector.
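+ //
+ // For example (hypothetical contents): if argv holds "ls\x00-la" with
+ // its final NUL overwritten, the read runs past argv into envp and
+ // stops at the first NUL found there, subject to the page limit below.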
+
+ if f.arg == cmdlineExecArg && buf[copyN-1] != 0 {
+ // Linux will limit the return up to and including the first NUL
+ // character in argv.
+ copyN = bytes.IndexByte(buf, 0)
+ if copyN == -1 {
+ copyN = len(buf)
+ }
+ // If we found a NUL character in argv, return up to and including
+ // that character.
+ if copyN < len(buf) {
+ buf = buf[:copyN]
+ } else { // Otherwise return into envp.
+ lengthEnvv := int(m.EnvvEnd() - m.EnvvStart())
+
+ // Upstream limits the returned amount to one page of slop.
+ // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
+ // We'll return one page total between argv and envp because of the
+ // above page restrictions.
+ if lengthEnvv > usermem.PageSize-len(buf) {
+ lengthEnvv = usermem.PageSize - len(buf)
+ }
+ // Make a new buffer to fit the whole thing.
+ tmp := make([]byte, length+lengthEnvv)
+ copyNE, err := m.CopyIn(ctx, m.EnvvStart(), tmp[copyN:], usermem.IOOpts{})
+ if err != nil {
+ return 0, err
+ }
+
+ // Linux will return envp up to and including the first NUL character, so find it.
+ for i, c := range tmp[copyN:] {
+ if c == 0 {
+ copyNE = i
+ break
+ }
+ }
+
+ copy(tmp, buf)
+ buf = tmp[:copyN+copyNE]
+ }
+ }
+
+ n, dstErr := dst.CopyOut(ctx, buf)
+ if dstErr != nil {
+ return int64(n), dstErr
+ }
+ return int64(n), err
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go)
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
new file mode 100644
index 000000000..35972e23c
--- /dev/null
+++ b/pkg/sentry/fs/proc/fds.go
@@ -0,0 +1,283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// LINT.IfChange
+
+// walkDescriptors finds the descriptor (file-flag pair) for the fd identified
+// by p, and calls the toInode callback with that descriptor. This is a helper
+// method for implementing fs.InodeOperations.Lookup.
+func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDFlags) *fs.Inode) (*fs.Inode, error) {
+ n, err := strconv.ParseUint(p, 10, 64)
+ if err != nil {
+ // Not found.
+ return nil, syserror.ENOENT
+ }
+
+ var file *fs.File
+ var fdFlags kernel.FDFlags
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ file, fdFlags = fdTable.Get(int32(n))
+ }
+ })
+ if file == nil {
+ return nil, syserror.ENOENT
+ }
+ return toInode(file, fdFlags), nil
+}
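+
+// For example, a Lookup of "3" resolves fd 3 through the task's FDTable,
+// while a non-numeric name such as "abc" fails with ENOENT.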
+
+// readDescriptors reads fds in the task starting at offset, and calls the
+// toDentAttr callback for each to get a DentAttr, which it then emits. This is
+// a helper for implementing fs.InodeOperations.Readdir.
+func readDescriptors(t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) {
+ var fds []int32
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.GetFDs()
+ }
+ })
+
+ // Find the appropriate starting point.
+ idx := sort.Search(len(fds), func(i int) bool { return fds[i] >= int32(offset) })
+ if idx == len(fds) {
+ return offset, nil
+ }
+ fds = fds[idx:]
+
+ // Serialize all FDs.
+ for _, fd := range fds {
+ name := strconv.FormatUint(uint64(fd), 10)
+ if err := c.DirEmit(name, toDentAttr(int(fd))); err != nil {
+ // Returned offset is the next fd to serialize.
+ return int64(fd), err
+ }
+ }
+ // We serialized them all. Next offset should be higher than last
+ // serialized fd.
+ return int64(fds[len(fds)-1] + 1), nil
+}
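+
+// For example, with open fds [0, 1, 5] and offset 2, serialization starts at
+// fd 5 and, if DirEmit succeeds, the returned offset is 6.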
+
+// fd implements fs.InodeOperations for a file in /proc/TID/fd/.
+type fd struct {
+ ramfs.Symlink
+ file *fs.File
+}
+
+var _ fs.InodeOperations = (*fd)(nil)
+
+// newFd returns a new fd based on an existing file.
+//
+// This inherits one reference to the file.
+func newFd(t *kernel.Task, f *fs.File, msrc *fs.MountSource) *fs.Inode {
+ fd := &fd{
+ // RootOwner overridden by taskOwnedInodeOps.UnstableAttrs().
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""),
+ file: f,
+ }
+ return newProcInode(t, fd, msrc, fs.Symlink, t)
+}
+
+// GetFile returns the fs.File backing this fd. The dirent and flags
+// arguments are ignored.
+func (f *fd) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error) {
+ // Take a reference on the fs.File.
+ f.file.IncRef()
+ return f.file, nil
+}
+
+// Readlink returns the current target.
+func (f *fd) Readlink(ctx context.Context, _ *fs.Inode) (string, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ n, _ := f.file.Dirent.FullName(root)
+ return n, nil
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (f *fd) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ f.file.Dirent.IncRef()
+ return f.file.Dirent, nil
+}
+
+// Truncate is ignored.
+func (f *fd) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+func (f *fd) Release(ctx context.Context) {
+ f.Symlink.Release(ctx)
+ f.file.DecRef()
+}
+
+// Close releases the reference on the file.
+func (f *fd) Close() error {
+ f.file.DecRef()
+ return nil
+}
+
+// fdDir is an InodeOperations for /proc/TID/fd.
+//
+// +stateify savable
+type fdDir struct {
+ ramfs.Dir
+
+ // We hold a reference on the task's FDTable but only keep an indirect
+ // task pointer to avoid Dirent loading circularity caused by the
+ // table's back pointers into the dirent tree.
+ t *kernel.Task
+}
+
+var _ fs.InodeOperations = (*fdDir)(nil)
+
+// newFdDir creates a new fdDir.
+func newFdDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ f := &fdDir{
+ Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermissions{User: fs.PermMask{Read: true, Execute: true}}),
+ t: t,
+ }
+ return newProcInode(t, f, msrc, fs.SpecialDirectory, t)
+}
+
+// Check implements InodeOperations.Check.
+//
+// This is to match Linux, which uses a special permission handler to guarantee
+// that a process can still access /proc/self/fd after it has executed
+// setuid. See fs/proc/fd.c:proc_fd_permission.
+func (f *fdDir) Check(ctx context.Context, inode *fs.Inode, req fs.PermMask) bool {
+ if fs.ContextCanAccessFile(ctx, inode, req) {
+ return true
+ }
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ // Allow access if the task trying to access it is in the
+ // thread group corresponding to this directory.
+ if f.t.ThreadGroup() == t.ThreadGroup() {
+ return true
+ }
+ }
+ return false
+}
+
+// Lookup loads an Inode in /proc/TID/fd into a Dirent.
+func (f *fdDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
+ n, err := walkDescriptors(f.t, p, func(file *fs.File, _ kernel.FDFlags) *fs.Inode {
+ return newFd(f.t, file, dir.MountSource)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fs.NewDirent(ctx, n, p), nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (f *fdDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ fops := &fdDirFile{
+ isInfoFile: false,
+ t: f.t,
+ }
+ return fs.NewFile(ctx, dirent, flags, fops), nil
+}
+
+// +stateify savable
+type fdDirFile struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ isInfoFile bool
+
+ t *kernel.Task
+}
+
+var _ fs.FileOperations = (*fdDirFile)(nil)
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *fdDirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ dirCtx := &fs.DirCtx{
+ Serializer: ser,
+ }
+ typ := fs.RegularFile
+ if f.isInfoFile {
+ typ = fs.Symlink
+ }
+ return readDescriptors(f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr {
+ return fs.GenericDentAttr(typ, device.ProcDevice)
+ })
+}
+
+// fdInfoDir implements /proc/TID/fdinfo. It is structured like fdDir, but
+// its Lookup returns fd information files rather than the files themselves.
+//
+// +stateify savable
+type fdInfoDir struct {
+ ramfs.Dir
+
+ t *kernel.Task
+}
+
+// newFdInfoDir creates a new fdInfoDir.
+func newFdInfoDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ fdid := &fdInfoDir{
+ Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0500)),
+ t: t,
+ }
+ return newProcInode(t, fdid, msrc, fs.SpecialDirectory, t)
+}
+
+// Lookup loads an fd in /proc/TID/fdinfo into a Dirent.
+func (fdid *fdInfoDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
+ inode, err := walkDescriptors(fdid.t, p, func(file *fs.File, fdFlags kernel.FDFlags) *fs.Inode {
+ // TODO(b/121266871): Using a static inode here means that the
+ // data can be out-of-date if, for instance, the flags on the
+ // FD change before we read this file. We should switch to
+ // generating the data on Read(). Also, we should include pos,
+ // locks, and other data. For now we only have flags.
+ // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+ flags := file.Flags().ToLinux() | fdFlags.ToLinuxFileFlags()
+ file.DecRef()
+ contents := []byte(fmt.Sprintf("flags:\t0%o\n", flags))
+ return newStaticProcInode(ctx, dir.MountSource, contents)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fs.NewDirent(ctx, inode, p), nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (fdid *fdInfoDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ fops := &fdDirFile{
+ isInfoFile: true,
+ t: fdid.t,
+ }
+ return fs.NewFile(ctx, dirent, flags, fops), nil
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/filesystems.go b/pkg/sentry/fs/proc/filesystems.go
new file mode 100644
index 000000000..0a58ac34c
--- /dev/null
+++ b/pkg/sentry/fs/proc/filesystems.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+)
+
+// LINT.IfChange
+
+// filesystemsData backs /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct{}
+
+// NeedsUpdate returns true only for the first generation. The set of
+// registered filesystems doesn't change, so there's no need to generate
+// SeqData more than once.
+func (*filesystemsData) NeedsUpdate(generation int64) bool {
+ return generation == 0
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. It returns
+// the file's contents along with the generation they correspond to.
+func (*filesystemsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // We don't ever expect to see a non-nil SeqHandle.
+ if h != nil {
+ return nil, 0
+ }
+
+ // Generate the file contents.
+ var buf bytes.Buffer
+ for _, sys := range fs.GetFilesystems() {
+ if !sys.AllowUserList() {
+ continue
+ }
+ nodev := "nodev"
+ if sys.Flags()&fs.FilesystemRequiresDev != 0 {
+ nodev = ""
+ }
+ // Matches the format of fs/filesystems.c:filesystems_proc_show.
+ fmt.Fprintf(&buf, "%s\t%s\n", nodev, sys.Name())
+ }
+
+ // Return the SeqData and advance the generation counter.
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*filesystemsData)(nil)}}, 1
+}
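+
+// The generated contents match Linux's format, e.g.:
+//
+//   nodev	proc
+//   nodev	tmpfs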
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/fs.go b/pkg/sentry/fs/proc/fs.go
new file mode 100644
index 000000000..daf1ba781
--- /dev/null
+++ b/pkg/sentry/fs/proc/fs.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// LINT.IfChange
+
+// filesystem is a procfs.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered. The
+// name matches fs/proc/root.c:proc_fs_type.name.
+const FilesystemName = "proc"
+
+// Name is the name of the file system.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns 0, as there is nothing special about this filesystem.
+//
+// In Linux, proc returns FS_USERNS_VISIBLE | FS_USERNS_MOUNT, see fs/proc/root.c.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns the root of a procfs that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, cgroupsInt interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+
+ // Parse generic comma-separated key=value options; this filesystem expects them.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Proc options parsing checks for either a gid= or hidepid= and barfs on
+ // anything else, see fs/proc/root.c:proc_parse_options. Since we don't know
+ // what to do with gid= or hidepid=, we blow up if we get any options.
+ if len(options) > 0 {
+ return nil, fmt.Errorf("unsupported mount options: %v", options)
+ }
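+ // For example, mount data of "gid=100" or "hidepid=2" (or any other
+ // key=value option) fails here.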
+
+ var cgroups map[string]string
+ if cgroupsInt != nil {
+ cgroups = cgroupsInt.(map[string]string)
+ }
+
+ // Construct the procfs root. Since procfs files are all virtual, we
+ // never want them cached.
+ return New(ctx, fs.NewNonCachingMountSource(ctx, f, flags), cgroups)
+}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
new file mode 100644
index 000000000..d2859a4c2
--- /dev/null
+++ b/pkg/sentry/fs/proc/inode.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr
+// method to return either the task or root as the owner, depending on the
+// task's dumpability.
+//
+// +stateify savable
+type taskOwnedInodeOps struct {
+ fs.InodeOperations
+
+ // t is the task that owns this file.
+ t *kernel.Task
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *taskOwnedInodeOps) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := i.InodeOperations.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+
+ // By default, set the task owner as the file owner.
+ creds := i.t.Credentials()
+ uattr.Owner = fs.FileOwner{creds.EffectiveKUID, creds.EffectiveKGID}
+
+ // Linux doesn't apply dumpability adjustments to world
+ // readable/executable directories so that applications can stat
+ // /proc/PID to determine the effective UID of a process. See
+ // fs/proc/base.c:task_dump_owner.
+ if fs.IsDir(inode.StableAttr) && uattr.Perms == fs.FilePermsFromMode(0555) {
+ return uattr, nil
+ }
+
+ // If the task is not dumpable, then root (mapped into the task's user
+ // namespace when possible) owns the file.
+ var m *mm.MemoryManager
+ i.t.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+
+ if m == nil {
+ uattr.Owner.UID = auth.RootKUID
+ uattr.Owner.GID = auth.RootKGID
+ } else if m.Dumpability() != mm.UserDumpable {
+ if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() {
+ uattr.Owner.UID = kuid
+ } else {
+ uattr.Owner.UID = auth.RootKUID
+ }
+ if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() {
+ uattr.Owner.GID = kgid
+ } else {
+ uattr.Owner.GID = auth.RootKGID
+ }
+ }
+
+ return uattr, nil
+}
+
+// staticFileInodeOps is an InodeOperations implementation that can be used to
+// return file contents which are constant. This file is not writable and will
+// always have mode 0444.
+//
+// +stateify savable
+type staticFileInodeOps struct {
+ fsutil.InodeDenyWriteChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeStaticFileGetter
+}
+
+var _ fs.InodeOperations = (*staticFileInodeOps)(nil)
+
+// newStaticProcInode returns a procfs inode with static contents.
+func newStaticProcInode(ctx context.Context, msrc *fs.MountSource, contents []byte) *fs.Inode {
+ iops := &staticFileInodeOps{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ InodeStaticFileGetter: fsutil.InodeStaticFileGetter{
+ Contents: contents,
+ },
+ }
+ return newProcInode(ctx, iops, msrc, fs.SpecialFile, nil)
+}
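+
+// For example, newStaticProcInode(ctx, msrc, []byte("hello\n")) yields a
+// read-only (mode 0444) virtual file whose reads always return "hello\n".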
+
+// newProcInode creates a new inode from the given inode operations.
+func newProcInode(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSource, typ fs.InodeType, t *kernel.Task) *fs.Inode {
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: typ,
+ }
+ if t != nil {
+ iops = &taskOwnedInodeOps{iops, t}
+ }
+ return fs.NewInode(ctx, iops, msrc, sattr)
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/loadavg.go b/pkg/sentry/fs/proc/loadavg.go
new file mode 100644
index 000000000..139d49c34
--- /dev/null
+++ b/pkg/sentry/fs/proc/loadavg.go
@@ -0,0 +1,59 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+)
+
+// LINT.IfChange
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct{}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*loadavgData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (d *loadavgData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+
+ // TODO(b/62345059): Include real data in fields.
+ // Column 1-3: CPU and IO utilization of the last 1, 5, and 15 minute periods.
+ // Column 4-5: currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(&buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+
+ return []seqfile.SeqData{
+ {
+ Buf: buf.Bytes(),
+ Handle: (*loadavgData)(nil),
+ },
+ }, 0
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go
new file mode 100644
index 000000000..91617267d
--- /dev/null
+++ b/pkg/sentry/fs/proc/meminfo.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// meminfoData backs /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*meminfoData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ mf := d.k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
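+ // For example (illustrative): 10 MiB of file pages yields 5 MiB active,
+ // rounded down to a page boundary, and the remainder inactive.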
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := totalSize - totalUsage
+ if memFree > totalSize {
+ // Underflow.
+ memFree = 0
+ }
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented the value of MemAvailable
+ // should change.
+ fmt.Fprintf(&buf, "MemFree: %8d kB\n", memFree/1024)
+ fmt.Fprintf(&buf, "MemAvailable: %8d kB\n", memFree/1024)
+ fmt.Fprintf(&buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(&buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(&buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(&buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(&buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(&buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(&buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(&buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(&buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(&buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(&buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(&buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(&buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(&buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(&buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(&buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(&buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(&buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*meminfoData)(nil)}}, 0
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
new file mode 100644
index 000000000..1fc9c703c
--- /dev/null
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -0,0 +1,232 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+// forEachMount runs fn for the process root mount and each mount that is a
+// descendant of the root.
+func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
+ var fsctx *kernel.FSContext
+ t.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return
+ }
+
+ // All mount points must be relative to the rootDir, and mounts outside
+ // will be excluded.
+ rootDir := fsctx.RootDirectory()
+ if rootDir == nil {
+ // The task has been destroyed. Nothing to show here.
+ return
+ }
+ defer rootDir.DecRef()
+
+ mnt := t.MountNamespace().FindMount(rootDir)
+ if mnt == nil {
+ // Has it just been unmounted?
+ return
+ }
+ ms := t.MountNamespace().AllMountsUnder(mnt)
+ sort.Slice(ms, func(i, j int) bool {
+ return ms[i].ID < ms[j].ID
+ })
+ for _, m := range ms {
+ mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
+ mountPath, desc := mroot.FullName(rootDir)
+ mroot.DecRef()
+ if !desc {
+ // Mounts that are not descendants of the chroot jail are ignored.
+ continue
+ }
+ fn(mountPath, m)
+ }
+}
+
+// mountInfoFile is used to implement /proc/[pid]/mountinfo.
+//
+// +stateify savable
+type mountInfoFile struct {
+ t *kernel.Task
+}
+
+// NeedsUpdate implements SeqSource.NeedsUpdate.
+func (mif *mountInfoFile) NeedsUpdate(_ int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements SeqSource.ReadSeqFileData.
+func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if handle != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
+ // Format:
+ // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
+ // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
+
+ // (1) Mount ID.
+ fmt.Fprintf(&buf, "%d ", m.ID)
+
+ // (2) Parent ID (or this ID if there is no parent).
+ pID := m.ID
+ if !m.IsRoot() && !m.IsUndo() {
+ pID = m.ParentID
+ }
+ fmt.Fprintf(&buf, "%d ", pID)
+
+ // (3) Major:Minor device ID. We don't have a superblock, so we
+ // just use the root inode device number.
+ sa := mroot.Inode.StableAttr
+ fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
+
+ // (4) Root: the pathname of the directory in the filesystem
+ // which forms the root of this mount.
+ //
+ // NOTE(b/78135857): This will always be "/" until we implement
+ // bind mounts.
+ fmt.Fprintf(&buf, "/ ")
+
+ // (5) Mount point (relative to process root).
+ fmt.Fprintf(&buf, "%s ", mountPath)
+
+ // (6) Mount options.
+ flags := mroot.Inode.MountSource.Flags
+ opts := "rw"
+ if flags.ReadOnly {
+ opts = "ro"
+ }
+ if flags.NoAtime {
+ opts += ",noatime"
+ }
+ if flags.NoExec {
+ opts += ",noexec"
+ }
+ fmt.Fprintf(&buf, "%s ", opts)
+
+ // (7) Optional fields: zero or more fields of the form "tag[:value]".
+ // (8) Separator: the end of the optional fields is marked by a single hyphen.
+ fmt.Fprintf(&buf, "- ")
+
+ // (9) Filesystem type.
+ fmt.Fprintf(&buf, "%s ", mroot.Inode.MountSource.FilesystemType)
+
+ // (10) Mount source: filesystem-specific information or "none".
+ fmt.Fprintf(&buf, "none ")
+
+ // (11) Superblock options, and final newline.
+ fmt.Fprintf(&buf, "%s\n", superBlockOpts(mountPath, mroot.Inode.MountSource))
+ })
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountInfoFile)(nil)}}, 0
+}
+
+func superBlockOpts(mountPath string, msrc *fs.MountSource) string {
+ // gVisor doesn't (yet) have a concept of super block options, so we
+ // use the ro/rw bit from the mount flag.
+ opts := "rw"
+ if msrc.Flags.ReadOnly {
+ opts = "ro"
+ }
+
+ // NOTE(b/147673608): If the mount is a cgroup, we also need to include
+ // the cgroup name in the options. For now we just read that from the
+ // path.
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
+ // should get this value from the cgroup itself, and not rely on the
+ // path.
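+ // For example, a cgroup mounted at "/sys/fs/cgroup/memory" yields "rw,memory".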
+ if msrc.FilesystemType == "cgroup" {
+ splitPath := strings.Split(mountPath, "/")
+ cgroupType := splitPath[len(splitPath)-1]
+ opts += "," + cgroupType
+ }
+ return opts
+}
+
+// mountsFile is used to implement /proc/[pid]/mounts.
+//
+// +stateify savable
+type mountsFile struct {
+ t *kernel.Task
+}
+
+// NeedsUpdate implements SeqSource.NeedsUpdate.
+func (mf *mountsFile) NeedsUpdate(_ int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements SeqSource.ReadSeqFileData.
+func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if handle != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ forEachMount(mf.t, func(mountPath string, m *fs.Mount) {
+ // Format:
+ // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
+ //
+ // We use the filesystem name as the first field, since there
+ // is no real block device we can point to, and we also should
+ // not expose anything about the remote filesystem.
+ //
+ // Only ro/rw option is supported for now.
+ //
+ // The "needs dump"and fsck flags are always 0, which is allowed.
+ root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
+ defer root.DecRef()
+
+ flags := root.Inode.MountSource.Flags
+ opts := "rw"
+ if flags.ReadOnly {
+ opts = "ro"
+ }
+ fmt.Fprintf(&buf, "%s %s %s %s %d %d\n", "none", mountPath, root.Inode.MountSource.FilesystemType, opts, 0, 0)
+ })
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountsFile)(nil)}}, 0
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
new file mode 100644
index 000000000..bd18177d4
--- /dev/null
+++ b/pkg/sentry/fs/proc/net.go
@@ -0,0 +1,841 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// newNetDir creates a new proc net entry.
+func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ k := t.Kernel()
+
+ var contents map[string]*fs.Inode
+ if s := t.NetworkNamespace().Stack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
+ contents = map[string]*fs.Inode{
+ "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc),
+ "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc),
+
+ // The following files are simple stubs until they are
+ // implemented in netstack. If the file contains a
+ // header, the stub is just the header; otherwise it is
+ // an empty file.
+ "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")),
+
+ "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")),
+ "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")),
+ "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")),
+ "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")),
+ // Linux sets psched values to: nsec per usec, psched
+ // tick in ns, 1000000, high res timer ticks per sec
+ // (ClockGetres returns 1ns resolution).
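+ // With these constants the line is always "000003e8 00000040 000f4240 3b9aca00".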
+ "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
+ "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")),
+ "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc),
+ "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc),
+ }
+
+ if s.SupportsIPv6() {
+ contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc)
+ contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte(""))
+ contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc)
+ contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"))
+ }
+ }
+ d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
+}
+
+// ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6.
+//
+// +stateify savable
+type ifinet6 struct {
+ s inet.Stack
+}
+
+func (n *ifinet6) contents() []string {
+ var lines []string
+ nics := n.s.Interfaces()
+ for id, naddrs := range n.s.InterfaceAddrs() {
+ nic, ok := nics[id]
+ if !ok {
+ // NIC was added after Interfaces() was called. We'll
+ // just ignore it.
+ continue
+ }
+
+ for _, a := range naddrs {
+ // IPv6 only.
+ if a.Family != linux.AF_INET6 {
+ continue
+ }
+
+ // Fields:
+ // IPv6 address displayed in 32 hexadecimal chars without colons
+ // Netlink device number (interface index) in hexadecimal (use nic id)
+ // Prefix length in hexadecimal
+ // Scope value (use 0)
+ // Interface flags
+ // Device name
+ lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
+ }
+ }
+ return lines
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*ifinet6) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *ifinet6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var data []seqfile.SeqData
+ for _, l := range n.contents() {
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*ifinet6)(nil)})
+ }
+
+ return data, 0
+}
+
+// netDev implements seqfile.SeqSource for /proc/net/dev.
+//
+// +stateify savable
+type netDev struct {
+ s inet.Stack
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (n *netDev) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's
+// net/core/net-procfs.c:dev_seq_show.
+func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ interfaces := n.s.Interfaces()
+ contents := make([]string, 2, 2+len(interfaces))
+ // Add the table header. From net/core/net-procfs.c:dev_seq_show.
+ contents[0] = "Inter-| Receive | Transmit\n"
+ contents[1] = " face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n"
+
+ for _, i := range interfaces {
+ // Implements the same format as
+ // net/core/net-procfs.c:dev_seq_printf_stats.
+ var stats inet.StatDev
+ if err := n.s.Statistics(&stats, i.Name); err != nil {
+ log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err)
+ continue
+ }
+ l := fmt.Sprintf(
+ "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
+ i.Name,
+ // Received
+ stats[0], // bytes
+ stats[1], // packets
+ stats[2], // errors
+ stats[3], // dropped
+ stats[4], // fifo
+ stats[5], // frame
+ stats[6], // compressed
+ stats[7], // multicast
+ // Transmitted
+ stats[8], // bytes
+ stats[9], // packets
+ stats[10], // errors
+ stats[11], // dropped
+ stats[12], // fifo
+ stats[13], // frame
+ stats[14], // compressed
+ stats[15]) // multicast
+ contents = append(contents, l)
+ }
+
+ var data []seqfile.SeqData
+ for _, l := range contents {
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netDev)(nil)})
+ }
+
+ return data, 0
+}
+
+// netSnmp implements seqfile.SeqSource for /proc/net/snmp.
+//
+// +stateify savable
+type netSnmp struct {
+ s inet.Stack
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (n *netSnmp) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+type snmpLine struct {
+ prefix string
+ header string
+}
+
+var snmp = []snmpLine{
+ {
+ prefix: "Ip",
+ header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates",
+ },
+ {
+ prefix: "Icmp",
+ header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps",
+ },
+ {
+ prefix: "IcmpMsg",
+ },
+ {
+ prefix: "Tcp",
+ header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors",
+ },
+ {
+ prefix: "Udp",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+ {
+ prefix: "UdpLite",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+}
+
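+// toSlice converts a pointer to one of the inet.StatSNMP* fixed-size uint64
+// arrays into a []uint64 covering all of its counters, via reflection.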
+func toSlice(a interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(a))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
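+// sprintSlice formats s as space-separated values; e.g. []uint64{1, 2, 3}
+// yields "1 2 3".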
+func sprintSlice(s []uint64) string {
+ if len(s) == 0 {
+ return ""
+ }
+ r := fmt.Sprint(s)
+ return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice.
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's
+// net/ipv4/proc.c:snmp_seq_show.
+func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ contents := make([]string, 0, len(snmp)*2)
+ types := []interface{}{
+ &inet.StatSNMPIP{},
+ &inet.StatSNMPICMP{},
+ nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats.
+ &inet.StatSNMPTCP{},
+ &inet.StatSNMPUDP{},
+ &inet.StatSNMPUDPLite{},
+ }
+ for i, stat := range types {
+ line := snmp[i]
+ if stat == nil {
+ contents = append(
+ contents,
+ fmt.Sprintf("%s:\n", line.prefix),
+ fmt.Sprintf("%s:\n", line.prefix),
+ )
+ continue
+ }
+ if err := n.s.Statistics(stat, line.prefix); err != nil {
+ if err == syserror.EOPNOTSUPP {
+ log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ } else {
+ log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ }
+ }
+ var values string
+ if line.prefix == "Tcp" {
+ tcp := stat.(*inet.StatSNMPTCP)
+ // "Tcp" needs special processing because MaxConn is signed. RFC 2012.
+ values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
+ } else {
+ values = sprintSlice(toSlice(stat))
+ }
+ contents = append(
+ contents,
+ fmt.Sprintf("%s: %s\n", line.prefix, line.header),
+ fmt.Sprintf("%s: %s\n", line.prefix, values),
+ )
+ }
+
+ data := make([]seqfile.SeqData, 0, len(snmp)*2)
+ for _, l := range contents {
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)})
+ }
+
+ return data, 0
+}
+
+// netRoute implements seqfile.SeqSource for /proc/net/route.
+//
+// +stateify savable
+type netRoute struct {
+ s inet.Stack
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (n *netRoute) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
+func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ interfaces := n.s.Interfaces()
+ contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"}
+ for _, rt := range n.s.RouteTable() {
+ // /proc/net/route only includes ipv4 routes.
+ if rt.Family != linux.AF_INET {
+ continue
+ }
+
+ // /proc/net/route does not include broadcast or multicast routes.
+ if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST {
+ continue
+ }
+
+ iface, ok := interfaces[rt.OutputInterface]
+ if !ok || iface.Name == "lo" {
+ continue
+ }
+
+ var (
+ gw uint32
+ prefix uint32
+ flags = linux.RTF_UP
+ )
+ if len(rt.GatewayAddr) == header.IPv4AddressSize {
+ flags |= linux.RTF_GATEWAY
+ gw = usermem.ByteOrder.Uint32(rt.GatewayAddr)
+ }
+ if len(rt.DstAddr) == header.IPv4AddressSize {
+ prefix = usermem.ByteOrder.Uint32(rt.DstAddr)
+ }
+ l := fmt.Sprintf(
+ "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d",
+ iface.Name,
+ prefix,
+ gw,
+ flags,
+ 0, // RefCnt.
+ 0, // Use.
+ 0, // Metric.
+ (uint32(1)<<rt.DstLen)-1, // Mask: the low DstLen bits set, printed in the same __be32 style as the addresses.
+ 0, // MTU.
+ 0, // Window.
+ 0, // RTT.
+ )
+ contents = append(contents, l)
+ }
+
+ var data []seqfile.SeqData
+ for _, l := range contents {
+ l = fmt.Sprintf("%-127s\n", l)
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)})
+ }
+
+ return data, 0
+}
+
+// netUnix implements seqfile.SeqSource for /proc/net/unix.
+//
+// +stateify savable
+type netUnix struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netUnix) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return []seqfile.SeqData{}, 0
+ }
+
+ var buf bytes.Buffer
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ continue
+ }
+ sfile := s.(*fs.File)
+ if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+ s.DecRef()
+ // Not a unix socket.
+ continue
+ }
+ sops := sfile.FileOperations.(*unix.SocketOperations)
+
+ addr, err := sops.Endpoint().GetLocalAddress()
+ if err != nil {
+ log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
+ addr.Addr = "<unknown>"
+ }
+
+ sockFlags := 0
+ if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
+ if ce.Listening() {
+ // For unix domain sockets, Linux reports a single flag
+ // value of __SO_ACCEPTCON if the socket is listening.
+ sockFlags = linux.SO_ACCEPTCON
+ }
+ }
+
+ // In the socket entry below, the value for the 'Num' field requires
+ // some consideration. Linux prints the address to the struct
+ // unix_sock representing a socket in the kernel, but may redact the
+ // value for unprivileged users depending on the kptr_restrict
+ // sysctl.
+ //
+ // One use for this field is to allow a privileged user to
+ // introspect into the kernel memory to determine information about
+ // a socket not available through procfs, such as the socket's peer.
+ //
+ // On gVisor, returning a pointer to our internal structures would
+ // be pointless, as it wouldn't match the memory layout for struct
+ // unix_sock, making introspection difficult. We could populate a
+ // struct unix_sock with the appropriate data, but even that
+ // requires consideration for which kernel version to emulate, as
+ // the definition of this struct changes over time.
+ //
+ // For now, we always redact this pointer.
+ fmt.Fprintf(&buf, "%#016p: %08X %08X %08X %04X %02X %5d",
+ (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
+ sfile.ReadRefs()-1, // RefCount, don't count our own ref.
+ 0, // Protocol, always 0 for UDS.
+ sockFlags, // Flags.
+ sops.Endpoint().Type(), // Type.
+ sops.State(), // State.
+ sfile.InodeID(), // Inode.
+ )
+
+ // Path
+ if len(addr.Addr) != 0 {
+ if addr.Addr[0] == 0 {
+ // Abstract path.
+ fmt.Fprintf(&buf, " @%s", string(addr.Addr[1:]))
+ } else {
+ fmt.Fprintf(&buf, " %s", string(addr.Addr))
+ }
+ }
+ fmt.Fprintf(&buf, "\n")
+
+ s.DecRef()
+ }
+
+ data := []seqfile.SeqData{
+ {
+ Buf: []byte("Num RefCount Protocol Flags Type St Inode Path\n"),
+ Handle: n,
+ },
+ {
+ Buf: buf.Bytes(),
+ Handle: n,
+ },
+ }
+ return data, 0
+}
+
+func networkToHost16(n uint16) uint16 {
+ // n is in network byte order, so is big-endian. The most-significant byte
+ // should be stored in the lower address.
+ //
+ // We manually inline binary.BigEndian.Uint16() because Go does not support
+ // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to
+ // binary.BigEndian.Uint16() require a read of binary.BigEndian and an
+ // interface method call, defeating inlining.
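+// For example, on a little-endian machine the network-order port 80 is read
+// from memory as the uint16 0x5000, and networkToHost16(0x5000) returns 80
+// (0x0050).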
+ buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
+ switch family {
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if i != nil {
+ a = *i.(*linux.SockAddrInet)
+ }
+
+ // linux.SockAddrInet.Port is stored in the network byte order and is
+ // printed like a number in host byte order. Note that all numbers in host
+ // byte order are printed with the most-significant byte first when
+ // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux.
+ port := networkToHost16(a.Port)
+
+ // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order
+ // (i.e. most-significant byte in index 0). Linux represents this as a
+ // __be32 which is a typedef for an unsigned int, and is printed with
+ // %X. This means that for a little-endian machine, Linux prints the
+ // least-significant byte of the address first. To emulate this, we first
+ // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // which makes it have the equivalent encoding to a __be32 on a little
+ // endian machine. Note that this operation is a no-op on a big endian
+ // machine. Then similar to Linux, we format it with %X, which will print
+ // the most-significant byte of the __be32 address first, which is now
+ // actually the least-significant byte of the original address in
+ // linux.SockAddrInet.Addr on little endian machines, due to the conversion.
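+// For example, the loopback address 127.0.0.1 ([127, 0, 0, 1]) is read as
+// 0x0100007F on a little-endian machine and printed as "0100007F", matching
+// Linux.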
+ addr := usermem.ByteOrder.Uint32(a.Addr[:])
+
+ fmt.Fprintf(w, "%08X:%04X ", addr, port)
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if i != nil {
+ a = *i.(*linux.SockAddrInet6)
+ }
+
+ port := networkToHost16(a.Port)
+ addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4])
+ addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8])
+ addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12])
+ addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16])
+ fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port)
+ }
+}
+
+func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ for _, se := range k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(socket.Socket)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ }
+ if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
+ s.DecRef()
+ // Not tcp4 sockets.
+ continue
+ }
+
+ // Linux's documentation for the fields below can be found at
+ // https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt.
+ // For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock().
+ // Note that the header doesn't contain labels for all the fields.
+
+ // Field: sl; entry number.
+ fmt.Fprintf(&buf, "%4d: ", se.ID)
+
+ // Field: local_address.
+ var localAddr linux.SockAddr
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = local
+ }
+ }
+ writeInetAddr(&buf, fa, localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddr
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = remote
+ }
+ }
+ writeInetAddr(&buf, fa, remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(&buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(&buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when; timer active state and number of jiffies
+ // until timer expires. Unimplemented.
+ fmt.Fprintf(&buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt; number of unrecovered RTO timeouts.
+ // Unimplemented.
+ fmt.Fprintf(&buf, "%08X ", 0)
+
+ // Field: uid.
+ uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ fmt.Fprintf(&buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout; number of unanswered 0-window probes.
+ // Unimplemented.
+ fmt.Fprintf(&buf, "%8d ", 0)
+
+ // Field: inode.
+ fmt.Fprintf(&buf, "%8d ", sfile.InodeID())
+
+ // Field: refcount. Don't count the ref we obtain while dereferencing
+ // the weakref to this socket.
+ fmt.Fprintf(&buf, "%d ", sfile.ReadRefs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(&buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: retransmit timeout. Unimplemented.
+ fmt.Fprintf(&buf, "%d ", 0)
+
+ // Field: predicted tick of soft clock (delayed ACK control data).
+ // Unimplemented.
+ fmt.Fprintf(&buf, "%d ", 0)
+
+ // Field: (ack.quick<<1)|ack.pingpong, Unimplemented.
+ fmt.Fprintf(&buf, "%d ", 0)
+
+ // Field: sending congestion window, Unimplemented.
+ fmt.Fprintf(&buf, "%d ", 0)
+
+ // Field: Slow start size threshold, -1 if threshold >= 0xFFFF.
+ // Unimplemented, report as large threshold.
+ fmt.Fprintf(&buf, "%d", -1)
+
+ fmt.Fprintf(&buf, "\n")
+
+ s.DecRef()
+ }
+
+ data := []seqfile.SeqData{
+ {
+ Buf: header,
+ Handle: n,
+ },
+ {
+ Buf: buf.Bytes(),
+ Handle: n,
+ },
+ }
+ return data, 0
+}
+
+// netTCP implements seqfile.SeqSource for /proc/net/tcp.
+//
+// +stateify savable
+type netTCP struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netTCP) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
+ return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header)
+}
+
+// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6.
+//
+// +stateify savable
+type netTCP6 struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netTCP6) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")
+ return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header)
+}
+
+// netUDP implements seqfile.SeqSource for /proc/net/udp.
+//
+// +stateify savable
+type netUDP struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netUDP) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(socket.Socket)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ }
+ if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
+ s.DecRef()
+ // Not udp4 socket.
+ continue
+ }
+
+ // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock().
+
+ // Field: sl; entry number.
+ fmt.Fprintf(&buf, "%5d: ", se.ID)
+
+ // Field: local_address.
+ var localAddr linux.SockAddrInet
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, linux.AF_INET, &localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddrInet
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, linux.AF_INET, &remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(&buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(&buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%08X ", 0)
+
+ // Field: uid.
+ uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ fmt.Fprintf(&buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%8d ", 0)
+
+ // Field: inode.
+ fmt.Fprintf(&buf, "%8d ", sfile.InodeID())
+
+ // Field: ref; reference count on the socket inode. Don't count the ref
+ // we obtain while dereferencing the weakref to this socket.
+ fmt.Fprintf(&buf, "%d ", sfile.ReadRefs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(&buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: drops; number of dropped packets. Unimplemented.
+ fmt.Fprintf(&buf, "%d", 0)
+
+ fmt.Fprintf(&buf, "\n")
+
+ s.DecRef()
+ }
+
+ data := []seqfile.SeqData{
+ {
+ Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n"),
+ Handle: n,
+ },
+ {
+ Buf: buf.Bytes(),
+ Handle: n,
+ },
+ }
+ return data, 0
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task_net.go)
diff --git a/pkg/sentry/fs/proc/net_test.go b/pkg/sentry/fs/proc/net_test.go
new file mode 100644
index 000000000..f18681405
--- /dev/null
+++ b/pkg/sentry/fs/proc/net_test.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+)
+
+func newIPv6TestStack() *inet.TestStack {
+ s := inet.NewTestStack()
+ s.SupportsIPv6Flag = true
+ return s
+}
+
+func TestIfinet6NoAddresses(t *testing.T) {
+ n := &ifinet6{s: newIPv6TestStack()}
+ if got := n.contents(); got != nil {
+ t.Errorf("Got n.contents() = %v, want = %v", got, nil)
+ }
+}
+
+func TestIfinet6(t *testing.T) {
+ s := newIPv6TestStack()
+ s.InterfacesMap[1] = inet.Interface{Name: "eth0"}
+ s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{
+ {
+ Family: linux.AF_INET6,
+ PrefixLen: 128,
+ Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"),
+ },
+ }
+ s.InterfacesMap[2] = inet.Interface{Name: "eth1"}
+ s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{
+ {
+ Family: linux.AF_INET6,
+ PrefixLen: 128,
+ Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"),
+ },
+ }
+ want := map[string]struct{}{
+ "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {},
+ "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {},
+ }
+
+ n := &ifinet6{s: s}
+ contents := n.contents()
+ if len(contents) != len(want) {
+ t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want))
+ }
+ got := map[string]struct{}{}
+ for _, l := range contents {
+ got[l] = struct{}{}
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Got n.contents() = %v, want = %v", got, want)
+ }
+}
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
new file mode 100644
index 000000000..c659224a7
--- /dev/null
+++ b/pkg/sentry/fs/proc/proc.go
@@ -0,0 +1,248 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for procfs.
+package proc
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// LINT.IfChange
+
+// proc is a root proc node.
+//
+// +stateify savable
+type proc struct {
+ ramfs.Dir
+
+ // k is the Kernel containing this proc node.
+ k *kernel.Kernel
+
+ // pidns is the PID namespace of the task that mounted the proc filesystem
+ // that this node represents.
+ pidns *kernel.PIDNamespace
+
+ // cgroupControllers is a map of controller name to directory in the
+ // cgroup hierarchy. These controllers are immutable and will be listed
+ // in /proc/pid/cgroup if not nil.
+ cgroupControllers map[string]string
+}
+
+// New returns the root node of a partial, simple procfs.
+func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) (*fs.Inode, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, fmt.Errorf("procfs requires a kernel")
+ }
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return nil, fmt.Errorf("procfs requires a PID namespace")
+ }
+
+ // Note that these are just the static members. There are dynamic
+ // members populated in Readdir and Lookup below.
+ contents := map[string]*fs.Inode{
+ "cpuinfo": newCPUInfo(ctx, msrc),
+ "filesystems": seqfile.NewSeqFileInode(ctx, &filesystemsData{}, msrc),
+ "loadavg": seqfile.NewSeqFileInode(ctx, &loadavgData{}, msrc),
+ "meminfo": seqfile.NewSeqFileInode(ctx, &meminfoData{k}, msrc),
+ "mounts": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/mounts"), msrc, fs.Symlink, nil),
+ "net": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/net"), msrc, fs.Symlink, nil),
+ "self": newSelf(ctx, pidns, msrc),
+ "stat": seqfile.NewSeqFileInode(ctx, &statData{k}, msrc),
+ "thread-self": newThreadSelf(ctx, pidns, msrc),
+ "uptime": newUptime(ctx, msrc),
+ "version": seqfile.NewSeqFileInode(ctx, &versionData{k}, msrc),
+ }
+
+ // Construct the proc InodeOperations.
+ p := &proc{
+ Dir: *ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)),
+ k: k,
+ pidns: pidns,
+ cgroupControllers: cgroupControllers,
+ }
+
+ // Add more contents that need proc to be initialized.
+ p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
+
+ return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil
+}
+
+// self is a magical link.
+//
+// +stateify savable
+type self struct {
+ ramfs.Symlink
+
+ pidns *kernel.PIDNamespace
+}
+
+// newSelf returns a new "self" node.
+func newSelf(ctx context.Context, pidns *kernel.PIDNamespace, msrc *fs.MountSource) *fs.Inode {
+ s := &self{
+ Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""),
+ pidns: pidns,
+ }
+ return newProcInode(ctx, s, msrc, fs.Symlink, nil)
+}
+
+// newThreadSelf returns a new "threadSelf" node.
+func newThreadSelf(ctx context.Context, pidns *kernel.PIDNamespace, msrc *fs.MountSource) *fs.Inode {
+ s := &threadSelf{
+ Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""),
+ pidns: pidns,
+ }
+ return newProcInode(ctx, s, msrc, fs.Symlink, nil)
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (s *self) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ if tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return strconv.FormatUint(uint64(tgid), 10), nil
+ }
+
+ // Who is reading this link?
+ return "", syserror.EINVAL
+}
+
+// threadSelf is like the "self" link, but resolves to "<tgid>/task/<tid>" for the reading task.
+//
+// +stateify savable
+type threadSelf struct {
+ ramfs.Symlink
+
+ pidns *kernel.PIDNamespace
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (s *threadSelf) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ tid := s.pidns.IDOfTask(t)
+ if tid == 0 || tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return fmt.Sprintf("%d/task/%d", tgid, tid), nil
+ }
+
+ // Who is reading this link?
+ return "", syserror.EINVAL
+}
+
+// Lookup loads an Inode at name into a Dirent.
+func (p *proc) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ dirent, walkErr := p.Dir.Lookup(ctx, dir, name)
+ if walkErr == nil {
+ return dirent, nil
+ }
+
+ // Try to look up a corresponding task.
+ tid, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ // Ignore the parse error and return the original.
+ return nil, walkErr
+ }
+
+ // Grab the other task.
+ otherTask := p.pidns.TaskWithID(kernel.ThreadID(tid))
+ if otherTask == nil {
+ // Per above.
+ return nil, walkErr
+ }
+
+ // Wrap it in a taskDir.
+ td := p.newTaskDir(otherTask, dir.MountSource, true)
+ return fs.NewDirent(ctx, td, name), nil
+}
+
+// GetFile implements fs.InodeOperations.
+func (p *proc) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &rootProcFile{iops: p}), nil
+}
+
+// rootProcFile implements fs.FileOperations for the proc directory.
+//
+// +stateify savable
+type rootProcFile struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ iops *proc
+}
+
+var _ fs.FileOperations = (*rootProcFile)(nil)
+
+// Readdir implements fs.FileOperations.Readdir.
+func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ offset := file.Offset()
+ dirCtx := &fs.DirCtx{
+ Serializer: ser,
+ }
+
+ // Get normal directory contents from ramfs dir.
+ names, m := rpf.iops.Dir.Children()
+
+ // Add dot and dotdot.
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dot, dotdot := file.Dirent.GetDotAttrs(root)
+ names = append(names, ".", "..")
+ m["."] = dot
+ m[".."] = dotdot
+
+ // Collect tasks.
+ // Per Linux, only thread group leaders are included in the directory
+ // listing, but any task can still be reached by walking to its TID.
+ for _, tg := range rpf.iops.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10)
+ m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice)
+ names = append(names, name)
+ }
+ }
+
+ if offset >= int64(len(m)) {
+ return offset, nil
+ }
+ sort.Strings(names)
+ names = names[offset:]
+ for _, name := range names {
+ if err := dirCtx.DirEmit(name, m[name]); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
new file mode 100644
index 000000000..21338d912
--- /dev/null
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -0,0 +1,35 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "seqfile",
+ srcs = ["seqfile.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/proc/device",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "seqfile_test",
+ size = "small",
+ srcs = ["seqfile_test.go"],
+ library = ":seqfile",
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go
new file mode 100644
index 000000000..6121f0e95
--- /dev/null
+++ b/pkg/sentry/fs/proc/seqfile/seqfile.go
@@ -0,0 +1,283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seqfile provides dynamic ordered files.
+package seqfile
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SeqHandle is a helper handle to seek in the file.
+type SeqHandle interface{}
+
+// SeqData holds the data for one unit in the file.
+//
+// +stateify savable
+type SeqData struct {
+ // The data to be returned to the user.
+ Buf []byte
+
+ // A seek handle used to find the next valid unit in ReadSeqFileData.
+ Handle SeqHandle
+}
+
+// SeqSource is a data source for a SeqFile file.
+type SeqSource interface {
+ // NeedsUpdate returns true if the consumer of SeqData should call
+ // ReadSeqFileData again. Generation is the generation returned by
+ // ReadSeqFileData or 0.
+ NeedsUpdate(generation int64) bool
+
+ // ReadSeqFileData returns a slice of SeqData ordered by unit and the
+ // current generation. The first entry in the slice is greater than the handle.
+ // If handle is nil then all known records are returned. Generation
+ // must always be greater than 0.
+ ReadSeqFileData(ctx context.Context, handle SeqHandle) ([]SeqData, int64)
+}
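+
+// A minimal SeqSource sketch for illustration, mirroring the pattern used by
+// the procfs sources in this package (always report NeedsUpdate and return
+// generation 0); the staticData type is hypothetical:
+//
+//	type staticData struct{ b []byte }
+//
+//	func (*staticData) NeedsUpdate(generation int64) bool { return true }
+//
+//	func (d *staticData) ReadSeqFileData(ctx context.Context, h SeqHandle) ([]SeqData, int64) {
+//		if h != nil {
+//			return nil, 0
+//		}
+//		return []SeqData{{Buf: d.b, Handle: (*staticData)(nil)}}, 0
+//	}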
+
+// SeqGenerationCounter keeps track of whether the SeqSource should be
+// updated. SeqGenerationCounter is not thread-safe and should be protected
+// with a mutex.
+type SeqGenerationCounter struct {
+ // The generation that the SeqData is at.
+ generation int64
+}
+
+// SetGeneration sets the generation to the new value; be careful not to set
+// it to a value less than the current one.
+func (s *SeqGenerationCounter) SetGeneration(generation int64) {
+ s.generation = generation
+}
+
+// Update increments the current generation.
+func (s *SeqGenerationCounter) Update() {
+ s.generation++
+}
+
+// Generation returns the current generation counter.
+func (s *SeqGenerationCounter) Generation() int64 {
+ return s.generation
+}
+
+// IsCurrent returns whether the given generation is current or not.
+func (s *SeqGenerationCounter) IsCurrent(generation int64) bool {
+ return s.Generation() == generation
+}
+
+// SeqFile is used to provide dynamic files that can be ordered by record.
+//
+// +stateify savable
+type SeqFile struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleExtendedAttributes
+ fsutil.InodeSimpleAttributes
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ SeqSource
+
+ source []SeqData
+ generation int64
+ lastRead int64
+}
+
+var _ fs.InodeOperations = (*SeqFile)(nil)
+
+// NewSeqFile returns a seqfile suitable for use by external consumers.
+func NewSeqFile(ctx context.Context, source SeqSource) *SeqFile {
+ return &SeqFile{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ SeqSource: source,
+ }
+}
+
+// NewSeqFileInode returns an Inode with SeqFile InodeOperations.
+func NewSeqFileInode(ctx context.Context, source SeqSource, msrc *fs.MountSource) *fs.Inode {
+ iops := NewSeqFile(ctx, source)
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, iops, msrc, sattr)
+}
+
+// UnstableAttr returns unstable attributes of the SeqFile.
+func (s *SeqFile) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := s.InodeSimpleAttributes.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ uattr.ModificationTime = ktime.NowFromContext(ctx)
+ return uattr, nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (s *SeqFile) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &seqFileOperations{seqFile: s}), nil
+}
+
+// findIndexAndOffset finds the unit that corresponds to a certain offset.
+// Returns the unit and the offset within the unit. If there are not enough
+// units, len(data) and the leftover offset are returned.
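+// For example, with units of lengths 10 and 20, offset 15 maps to (1, 5) and
+// offset 35 maps to (2, 5).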
+func findIndexAndOffset(data []SeqData, offset int64) (int, int64) {
+ for i, buf := range data {
+ l := int64(len(buf.Buf))
+ if offset < l {
+ return i, offset
+ }
+ offset -= l
+ }
+ return len(data), offset
+}
+
+// updateSourceLocked requires that s.mu is held.
+func (s *SeqFile) updateSourceLocked(ctx context.Context, record int) {
+ var h SeqHandle
+ if record == 0 {
+ h = nil
+ } else {
+ h = s.source[record-1].Handle
+ }
+ // Save what we have previously read.
+ s.source = s.source[:record]
+ var newSource []SeqData
+ newSource, s.generation = s.SeqSource.ReadSeqFileData(ctx, h)
+ s.source = append(s.source, newSource...)
+}
+
+// seqFileOperations implements fs.FileOperations.
+//
+// +stateify savable
+type seqFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ seqFile *SeqFile
+}
+
+var _ fs.FileOperations = (*seqFileOperations)(nil)
+
+// Write implements fs.FileOperations.Write.
+func (*seqFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EACCES
+}
+
+// Read implements fs.FileOperations.Read.
+func (sfo *seqFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ sfo.seqFile.mu.Lock()
+ defer sfo.seqFile.mu.Unlock()
+
+ sfo.seqFile.NotifyAccess(ctx)
+ defer func() { sfo.seqFile.lastRead = offset }()
+
+ updated := false
+
+ // Try to find where we should start reading this file.
+ i, recordOffset := findIndexAndOffset(sfo.seqFile.source, offset)
+ if i == len(sfo.seqFile.source) {
+ // Ok, we're at EOF. Let's first check to see if there might be
+ // more data available to us. If there is more data, add it to
+ // the end and try reading again.
+ if !sfo.seqFile.SeqSource.NeedsUpdate(sfo.seqFile.generation) {
+ return 0, io.EOF
+ }
+ oldLen := len(sfo.seqFile.source)
+ sfo.seqFile.updateSourceLocked(ctx, len(sfo.seqFile.source))
+ updated = true
+ // We know that we had consumed everything up until this point
+ // so we search in the new slice instead of starting over.
+ i, recordOffset = findIndexAndOffset(sfo.seqFile.source[oldLen:], recordOffset)
+ i += oldLen
+ // i is at most the length of the new slice, which is
+ // len(sfo.seqFile.source) - oldLen, so after adding oldLen it is at
+ // most len(sfo.seqFile.source).
+ if i == len(sfo.seqFile.source) {
+ return 0, io.EOF
+ }
+ }
+
+ var done int64
+ // We're reading parts of a record; finish reading the current object
+ // before continuing on to the next. We don't refresh our data source
+ // before this record is completed.
+ if recordOffset != 0 {
+ n, err := dst.CopyOut(ctx, sfo.seqFile.source[i].Buf[recordOffset:])
+ done += int64(n)
+ dst = dst.DropFirst(n)
+ if dst.NumBytes() == 0 || err != nil {
+ return done, err
+ }
+ i++
+ }
+
+ // Next/New unit, update the source file if necessary. Make an extra
+ // check to see if we've seeked backwards and if so always update our
+ // data source.
+ if !updated && (sfo.seqFile.SeqSource.NeedsUpdate(sfo.seqFile.generation) || sfo.seqFile.lastRead > offset) {
+ sfo.seqFile.updateSourceLocked(ctx, i)
+ // recordOffset is 0 here and we won't update records behind the
+ // current one so recordOffset is still 0 even though source
+ // just got updated. Just read the next record.
+ }
+
+ // Finish by reading all the available data.
+ for _, buf := range sfo.seqFile.source[i:] {
+ n, err := dst.CopyOut(ctx, buf.Buf)
+ done += int64(n)
+ dst = dst.DropFirst(n)
+ if dst.NumBytes() == 0 || err != nil {
+ return done, err
+ }
+ }
+
+ // If the file shrank (entries not yet read were removed above)
+ // while we tried to read, we can end up with nothing read.
+ if done == 0 && dst.NumBytes() != 0 {
+ return 0, io.EOF
+ }
+ return done, nil
+}
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_test.go b/pkg/sentry/fs/proc/seqfile/seqfile_test.go
new file mode 100644
index 000000000..98e394569
--- /dev/null
+++ b/pkg/sentry/fs/proc/seqfile/seqfile_test.go
@@ -0,0 +1,279 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seqfile
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type seqTest struct {
+ actual []SeqData
+ update bool
+}
+
+func (s *seqTest) Init() {
+ var sq []SeqData
+ // Create some SeqData.
+ for i := 0; i < 10; i++ {
+ var b []byte
+ for j := 0; j < 10; j++ {
+ b = append(b, byte(i))
+ }
+ sq = append(sq, SeqData{
+ Buf: b,
+ Handle: &testHandle{i: i},
+ })
+ }
+ s.actual = sq
+}
+
+// NeedsUpdate reports whether we need to update the data we've previously read.
+func (s *seqTest) NeedsUpdate(int64) bool {
+ return s.update
+}
+
+// ReadSeqFileData returns a slice of SeqData containing the elements
+// whose handles are greater than the given handle.
+func (s *seqTest) ReadSeqFileData(ctx context.Context, handle SeqHandle) ([]SeqData, int64) {
+ if handle == nil {
+ return s.actual, 0
+ }
+ h := *handle.(*testHandle)
+ var ret []SeqData
+ for _, b := range s.actual {
+ // We want the next one.
+ h2 := *b.Handle.(*testHandle)
+ if h2.i > h.i {
+ ret = append(ret, b)
+ }
+ }
+ return ret, 0
+}
+
+// flatten flattens a slice of byte slices into one slice.
+func flatten(buf ...[]byte) []byte {
+ var flat []byte
+ for _, b := range buf {
+ flat = append(flat, b...)
+ }
+ return flat
+}
+
+type testHandle struct {
+ i int
+}
+
+type testTable struct {
+ offset int64
+ readBufferSize int
+ expectedData []byte
+ expectedError error
+}
+
+func runTableTests(ctx context.Context, table []testTable, dirent *fs.Dirent) error {
+ for _, tt := range table {
+ file, err := dirent.Inode.InodeOperations.GetFile(ctx, dirent, fs.FileFlags{Read: true})
+ if err != nil {
+ return fmt.Errorf("GetFile returned error: %v", err)
+ }
+
+ data := make([]byte, tt.readBufferSize)
+ resultLen, err := file.Preadv(ctx, usermem.BytesIOSequence(data), tt.offset)
+ if err != tt.expectedError {
+ return fmt.Errorf("t.Preadv(len: %v, offset: %v) (error) => %v expected %v", tt.readBufferSize, tt.offset, err, tt.expectedError)
+ }
+ expectedLen := int64(len(tt.expectedData))
+ if resultLen != expectedLen {
+ // We make this just an error so we fall through and print the data below.
+ return fmt.Errorf("t.Preadv(len: %v, offset: %v) (size) => %v expected %v", tt.readBufferSize, tt.offset, resultLen, expectedLen)
+ }
+ if !bytes.Equal(data[:expectedLen], tt.expectedData) {
+ return fmt.Errorf("t.Preadv(len: %v, offset: %v) (data) => %v expected %v", tt.readBufferSize, tt.offset, data[:expectedLen], tt.expectedData)
+ }
+ }
+ return nil
+}
+
+func TestSeqFile(t *testing.T) {
+ testSource := &seqTest{}
+ testSource.Init()
+
+ // Create a file that can be R/W.
+ ctx := contexttest.Context(t)
+ m := fs.NewPseudoMountSource(ctx)
+ contents := map[string]*fs.Inode{
+ "foo": NewSeqFileInode(ctx, testSource, m),
+ }
+ root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777))
+
+ // How about opening it?
+ inode := fs.NewInode(ctx, root, m, fs.StableAttr{Type: fs.Directory})
+ dirent2, err := root.Lookup(ctx, inode, "foo")
+ if err != nil {
+ t.Fatalf("failed to walk to foo for n2: %v", err)
+ }
+ n2 := dirent2.Inode.InodeOperations
+ file2, err := n2.GetFile(ctx, dirent2, fs.FileFlags{Read: true, Write: true})
+ if err != nil {
+ t.Fatalf("GetFile returned error: %v", err)
+ }
+
+ // Writing?
+ if _, err := file2.Writev(ctx, usermem.BytesIOSequence([]byte("test"))); err == nil {
+ t.Fatalf("managed to write to n2: %v", err)
+ }
+
+ // How about reading?
+ dirent3, err := root.Lookup(ctx, inode, "foo")
+ if err != nil {
+ t.Fatalf("failed to walk to foo: %v", err)
+ }
+ n3 := dirent3.Inode.InodeOperations
+ if n2 != n3 {
+ t.Error("got n2 != n3, want same")
+ }
+
+ testSource.update = true
+
+ table := []testTable{
+ // Read past the end.
+ {100, 4, []byte{}, io.EOF},
+ {110, 4, []byte{}, io.EOF},
+ {200, 4, []byte{}, io.EOF},
+ // Read a truncated first line.
+ {0, 4, testSource.actual[0].Buf[:4], nil},
+ // Read the whole first line.
+ {0, 10, testSource.actual[0].Buf, nil},
+ // Read the whole first line + 5 bytes of second line.
+ {0, 15, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:5]), nil},
+ // First 4 bytes of the second line.
+ {10, 4, testSource.actual[1].Buf[:4], nil},
+ // Read the two first lines.
+ {0, 20, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf), nil},
+ // Read three lines.
+ {0, 30, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf), nil},
+ // Read everything, but use a bigger buffer than necessary.
+ {0, 150, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf, testSource.actual[3].Buf, testSource.actual[4].Buf, testSource.actual[5].Buf, testSource.actual[6].Buf, testSource.actual[7].Buf, testSource.actual[8].Buf, testSource.actual[9].Buf), nil},
+ // Read the last 3 bytes.
+ {97, 10, testSource.actual[9].Buf[7:], nil},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed with testSource.update = %v : %v", testSource.update, err)
+ }
+
+ // Disable updates and do it again.
+ testSource.update = false
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed with testSource.update = %v: %v", testSource.update, err)
+ }
+}
+
+// Test that we behave correctly when the file is updated.
+func TestSeqFileFileUpdated(t *testing.T) {
+ testSource := &seqTest{}
+ testSource.Init()
+ testSource.update = true
+
+ // Create a file that can be R/W.
+ ctx := contexttest.Context(t)
+ m := fs.NewPseudoMountSource(ctx)
+ contents := map[string]*fs.Inode{
+ "foo": NewSeqFileInode(ctx, testSource, m),
+ }
+ root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777))
+
+ // How about opening it?
+ inode := fs.NewInode(ctx, root, m, fs.StableAttr{Type: fs.Directory})
+ dirent2, err := root.Lookup(ctx, inode, "foo")
+ if err != nil {
+ t.Fatalf("failed to walk to foo for dirent2: %v", err)
+ }
+
+ table := []testTable{
+ {0, 16, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:6]), nil},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed: %v", err)
+ }
+ // Delete the first entry.
+ cut := testSource.actual[0].Buf
+ testSource.actual = testSource.actual[1:]
+
+ table = []testTable{
+ // Try reading buffer 0 with an offset. This will not delete the old data.
+ {1, 5, cut[1:6], nil},
+ // Reset our file by reading at offset 0.
+ {0, 10, testSource.actual[0].Buf, nil},
+ {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil},
+ // Read the same data a second time.
+ {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil},
+ // Read the following two lines.
+ {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed after removing first entry: %v", err)
+ }
+
+ // Add a new duplicate line in the middle (6666...)
+ after := testSource.actual[5:]
+ testSource.actual = testSource.actual[:4]
+ // Note the list must be sorted.
+ testSource.actual = append(testSource.actual, after[0])
+ testSource.actual = append(testSource.actual, after...)
+
+ table = []testTable{
+ {50, 20, flatten(testSource.actual[4].Buf, testSource.actual[5].Buf), nil},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed after adding middle entry: %v", err)
+ }
+ // This will be used in a later test.
+ oldTestData := testSource.actual
+
+ // Delete everything.
+ testSource.actual = testSource.actual[:0]
+ table = []testTable{
+ {20, 20, []byte{}, io.EOF},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed after removing all entries: %v", err)
+ }
+ // Restore some of the data.
+ testSource.actual = oldTestData[:1]
+ table = []testTable{
+ {6, 20, testSource.actual[0].Buf[6:], nil},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed after adding first entry back: %v", err)
+ }
+
+ // Re-extend the data.
+ testSource.actual = oldTestData
+ table = []testTable{
+ {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil},
+ }
+ if err := runTableTests(ctx, table, dirent2); err != nil {
+ t.Errorf("runTableTest failed after extending testSource: %v", err)
+ }
+}
diff --git a/pkg/sentry/fs/proc/stat.go b/pkg/sentry/fs/proc/stat.go
new file mode 100644
index 000000000..d4fbd76ac
--- /dev/null
+++ b/pkg/sentry/fs/proc/stat.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+// statData backs /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*statData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// cpuStats contains the breakdown of CPU time for /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
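+
+// For example, the zero value renders as "0 0 0 0 0 0 0 0 0 0", which is
+// the per-CPU breakdown statData currently reports for every CPU line.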
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (s *statData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(&buf, "cpu %s\n", cpu)
+
+ for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(&buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(&buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(&buf, " 0")
+ }
+ fmt.Fprintf(&buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "ctxt 0\n")
+
+ // CLOCK_REALTIME timestamp from boot, in seconds.
+ fmt.Fprintf(&buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled.
+ fmt.Fprintf(&buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(&buf, " 0")
+ }
+ fmt.Fprintf(&buf, "\n")
+
+ return []seqfile.SeqData{
+ {
+ Buf: buf.Bytes(),
+ Handle: (*statData)(nil),
+ },
+ }, 0
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
new file mode 100644
index 000000000..f8aad2dbd
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys.go
@@ -0,0 +1,159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// mmapMinAddrData backs /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*mmapMinAddrData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (d *mmapMinAddrData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+ return []seqfile.SeqData{
+ {
+ Buf: []byte(fmt.Sprintf("%d\n", d.k.Platform.MinUserAddress())),
+ Handle: (*mmapMinAddrData)(nil),
+ },
+ }, 0
+}
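+
+// For example, reading /proc/sys/vm/mmap_min_addr yields a single line such
+// as "65536\n" (an illustrative value; the number comes from the platform's
+// minimum user address).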
+
+// +stateify savable
+type overcommitMemory struct{}
+
+func (*overcommitMemory) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (*overcommitMemory) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+ return []seqfile.SeqData{
+ {
+ Buf: []byte("0\n"),
+ Handle: (*overcommitMemory)(nil),
+ },
+ }, 0
+}
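+
+// The fixed "0" corresponds to Linux's heuristic overcommit mode
+// (vm.overcommit_memory=0), which is the kernel default.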
+
+func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ h := hostname{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ }
+
+ children := map[string]*fs.Inode{
+ "hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil),
+ "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
+ "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
+ "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
+ }
+
+ d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ children := map[string]*fs.Inode{
+ "mmap_min_addr": seqfile.NewSeqFileInode(ctx, &mmapMinAddrData{p.k}, msrc),
+ "overcommit_memory": seqfile.NewSeqFileInode(ctx, &overcommitMemory{}, msrc),
+ }
+ d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newSysDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ children := map[string]*fs.Inode{
+ "kernel": p.newKernelDir(ctx, msrc),
+ "net": p.newSysNetDir(ctx, msrc),
+ "vm": p.newVMDir(ctx, msrc),
+ }
+
+ d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+}
+
+// hostname is the inode for a file containing the system hostname.
+//
+// +stateify savable
+type hostname struct {
+ fsutil.SimpleFileInode
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (h *hostname) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &hostnameFile{}), nil
+}
+
+var _ fs.InodeOperations = (*hostname)(nil)
+
+// +stateify savable
+type hostnameFile struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+// Read implements fs.FileOperations.Read.
+func (hf *hostnameFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ utsns := kernel.UTSNamespaceFromContext(ctx)
+ contents := []byte(utsns.HostName() + "\n")
+ if offset >= int64(len(contents)) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, contents[offset:])
+ return int64(n), err
+}
+
+var _ fs.FileOperations = (*hostnameFile)(nil)
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
new file mode 100644
index 000000000..702fdd392
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -0,0 +1,372 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+type tcpMemDir int
+
+const (
+ tcpRMem tcpMemDir = iota
+ tcpWMem
+)
+
+// tcpMemInode is used to read/write the size of netstack tcp buffers.
+//
+// TODO(b/121381035): If we have multiple proc mounts, concurrent writes can
+// leave netstack and the proc files in an inconsistent state. Since we set the
+// buffer size from these proc files on restore, we may also race and end up in
+// an inconsistent state on restore.
+//
+// +stateify savable
+type tcpMemInode struct {
+ fsutil.SimpleFileInode
+ dir tcpMemDir
+ s inet.Stack `state:"wait"`
+
+ // size stores the tcp buffer size during save, and sets the buffer
+ // size in netstack in restore. We must save/restore this here, since
+ // netstack itself is stateless.
+ size inet.TCPBufferSize
+
+ // mu protects against concurrent reads/writes to files based on this
+ // inode.
+ mu sync.Mutex `state:"nosave"`
+}
+
+var _ fs.InodeOperations = (*tcpMemInode)(nil)
+
+func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode {
+ tm := &tcpMemInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ s: s,
+ dir: dir,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, tm, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil
+}
+
+// +stateify savable
+type tcpMemFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ tcpMemInode *tcpMemInode
+}
+
+var _ fs.FileOperations = (*tcpMemFile)(nil)
+
+// Read implements fs.FileOperations.Read.
+func (f *tcpMemFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+ f.tcpMemInode.mu.Lock()
+ defer f.tcpMemInode.mu.Unlock()
+
+ size, err := readSize(f.tcpMemInode.dir, f.tcpMemInode.s)
+ if err != nil {
+ return 0, err
+ }
+ s := fmt.Sprintf("%d\t%d\t%d\n", size.Min, size.Default, size.Max)
+ n, err := dst.CopyOut(ctx, []byte(s))
+ return int64(n), err
+}
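+
+// For example, with buffer sizes {Min: 4096, Default: 87380, Max: 6291456}
+// (illustrative values), a read at offset 0 yields "4096\t87380\t6291456\n",
+// mirroring Linux's tab-separated tcp_rmem/tcp_wmem format.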
+
+// Write implements fs.FileOperations.Write.
+func (f *tcpMemFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+ f.tcpMemInode.mu.Lock()
+ defer f.tcpMemInode.mu.Unlock()
+
+ src = src.TakeFirst(usermem.PageSize - 1)
+ size, err := readSize(f.tcpMemInode.dir, f.tcpMemInode.s)
+ if err != nil {
+ return 0, err
+ }
+ buf := []int32{int32(size.Min), int32(size.Default), int32(size.Max)}
+ n, cperr := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, buf, src.Opts)
+ newSize := inet.TCPBufferSize{
+ Min: int(buf[0]),
+ Default: int(buf[1]),
+ Max: int(buf[2]),
+ }
+ if err := writeSize(f.tcpMemInode.dir, f.tcpMemInode.s, newSize); err != nil {
+ return n, err
+ }
+ return n, cperr
+}
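+
+// Writes accept up to three space-separated integers and only update the
+// fields actually provided, e.g. (illustrative shell session in a sandbox):
+//
+//	echo "100 200 300" > /proc/sys/net/ipv4/tcp_wmem  # sets Min, Default, Max
+//	echo "100" > /proc/sys/net/ipv4/tcp_wmem          # updates only Min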
+
+func readSize(dirType tcpMemDir, s inet.Stack) (inet.TCPBufferSize, error) {
+ switch dirType {
+ case tcpRMem:
+ return s.TCPReceiveBufferSize()
+ case tcpWMem:
+ return s.TCPSendBufferSize()
+ default:
+ panic(fmt.Sprintf("unknown tcpMemFile type: %v", dirType))
+ }
+}
+
+func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error {
+ switch dirType {
+ case tcpRMem:
+ return s.SetTCPReceiveBufferSize(size)
+ case tcpWMem:
+ return s.SetTCPSendBufferSize(size)
+ default:
+ panic(fmt.Sprintf("unknown tcpMemFile type: %v", dirType))
+ }
+}
+
+// +stateify savable
+type tcpSack struct {
+ fsutil.SimpleFileInode
+
+ stack inet.Stack `state:"wait"`
+ enabled *bool
+}
+
+func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ts := &tcpSack{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, ts, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &tcpSackFile{
+ tcpSack: s,
+ stack: s.stack,
+ }), nil
+}
+
+// +stateify savable
+type tcpSackFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ tcpSack *tcpSack
+
+ stack inet.Stack `state:"wait"`
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *tcpSackFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ if f.tcpSack.enabled == nil {
+ sack, err := f.stack.TCPSACKEnabled()
+ if err != nil {
+ return 0, err
+ }
+ f.tcpSack.enabled = &sack
+ }
+
+ val := "0\n"
+ if *f.tcpSack.enabled {
+ // Technically, this is not quite compatible with Linux. Linux
+ // stores these as an integer, so if you write "2" into
+ // tcp_sack, you should get 2 back. Tough luck.
+ val = "1\n"
+ }
+ n, err := dst.CopyOut(ctx, []byte(val))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return n, err
+ }
+ if f.tcpSack.enabled == nil {
+ f.tcpSack.enabled = new(bool)
+ }
+ *f.tcpSack.enabled = v != 0
+ return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled)
+}
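+
+// For example, writing "2\n" enables SACK, but a subsequent read returns
+// "1\n" rather than "2\n"; see the compatibility note in Read above.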
+
+func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ // The following files are simple stubs until they are implemented in
+ // netstack; most of them are configuration related. We use the value
+ // closest to the actual netstack behavior or an empty file. All of
+ // these files have mode 0444 (read-only for all users).
+ contents := map[string]*fs.Inode{
+ "default_qdisc": newStaticProcInode(ctx, msrc, []byte("pfifo_fast")),
+ "message_burst": newStaticProcInode(ctx, msrc, []byte("10")),
+ "message_cost": newStaticProcInode(ctx, msrc, []byte("5")),
+ "optmem_max": newStaticProcInode(ctx, msrc, []byte("0")),
+ "rmem_default": newStaticProcInode(ctx, msrc, []byte("212992")),
+ "rmem_max": newStaticProcInode(ctx, msrc, []byte("212992")),
+ "somaxconn": newStaticProcInode(ctx, msrc, []byte("128")),
+ "wmem_default": newStaticProcInode(ctx, msrc, []byte("212992")),
+ "wmem_max": newStaticProcInode(ctx, msrc, []byte("212992")),
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ // Add tcp_sack.
+ "tcp_sack": newTCPSackInode(ctx, msrc, s),
+
+ // The following files are simple stubs until they are
+ // implemented in netstack; most of them are configuration
+ // related. We use the value closest to the actual netstack
+ // behavior or an empty file. All of these files have mode
+ // 0444 (read-only for all users).
+ "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")),
+ "ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")),
+ "ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")),
+ "ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")),
+ "ip_no_pmtu_disc": newStaticProcInode(ctx, msrc, []byte("1")),
+
+ // tcp_allowed_congestion_control tells the user what they are
+ // able to do as an unprivileged process, so we leave it empty.
+ "tcp_allowed_congestion_control": newStaticProcInode(ctx, msrc, []byte("")),
+ "tcp_available_congestion_control": newStaticProcInode(ctx, msrc, []byte("reno")),
+ "tcp_congestion_control": newStaticProcInode(ctx, msrc, []byte("reno")),
+
+ // Many of the following stub files are features netstack
+ // doesn't support. The unsupported features return "0" to
+ // indicate they are disabled.
+ "tcp_base_mss": newStaticProcInode(ctx, msrc, []byte("1280")),
+ "tcp_dsack": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_early_retrans": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_fack": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_fastopen": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_fastopen_key": newStaticProcInode(ctx, msrc, []byte("")),
+ "tcp_invalid_ratelimit": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_keepalive_intvl": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_keepalive_probes": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_keepalive_time": newStaticProcInode(ctx, msrc, []byte("7200")),
+ "tcp_mtu_probing": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_no_metrics_save": newStaticProcInode(ctx, msrc, []byte("1")),
+ "tcp_probe_interval": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_probe_threshold": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_retries1": newStaticProcInode(ctx, msrc, []byte("3")),
+ "tcp_retries2": newStaticProcInode(ctx, msrc, []byte("15")),
+ "tcp_rfc1337": newStaticProcInode(ctx, msrc, []byte("1")),
+ "tcp_slow_start_after_idle": newStaticProcInode(ctx, msrc, []byte("1")),
+ "tcp_synack_retries": newStaticProcInode(ctx, msrc, []byte("5")),
+ "tcp_syn_retries": newStaticProcInode(ctx, msrc, []byte("3")),
+ "tcp_timestamps": newStaticProcInode(ctx, msrc, []byte("1")),
+ }
+
+ // Add tcp_rmem.
+ if _, err := s.TCPReceiveBufferSize(); err == nil {
+ contents["tcp_rmem"] = newTCPMemInode(ctx, msrc, s, tcpRMem)
+ }
+
+ // Add tcp_wmem.
+ if _, err := s.TCPSendBufferSize(); err == nil {
+ contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem)
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ var contents map[string]*fs.Inode
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
+ contents = map[string]*fs.Inode{
+ "ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
+ "core": p.newSysNetCore(ctx, msrc, s),
+ }
+ }
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go
new file mode 100644
index 000000000..6eba709c6
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys_net_state.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import "fmt"
+
+// beforeSave is invoked by stateify.
+func (t *tcpMemInode) beforeSave() {
+ size, err := readSize(t.dir, t.s)
+ if err != nil {
+ panic(fmt.Sprintf("failed to read TCP send / receive buffer sizes: %v", err))
+ }
+ t.size = size
+}
+
+// afterLoad is invoked by stateify.
+func (t *tcpMemInode) afterLoad() {
+ if err := writeSize(t.dir, t.s, t.size); err != nil {
+ panic(fmt.Sprintf("failed to write previous TCP send / receive buffer sizes [%v]: %v", t.size, err))
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (s *tcpSack) afterLoad() {
+ if s.enabled != nil {
+ if err := s.stack.SetTCPSACKEnabled(*s.enabled); err != nil {
+ panic(fmt.Sprintf("failed to set previous TCP sack configuration [%v]: %v", *s.enabled, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go
new file mode 100644
index 000000000..355e83d47
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys_net_test.go
@@ -0,0 +1,125 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestQuerySendBufferSize(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+ s.TCPSendBufSize = inet.TCPBufferSize{100, 200, 300}
+ tmi := &tcpMemInode{s: s, dir: tcpWMem}
+ tmf := &tcpMemFile{tcpMemInode: tmi}
+
+ buf := make([]byte, 100)
+ dst := usermem.BytesIOSequence(buf)
+ n, err := tmf.Read(ctx, nil, dst, 0)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ if got, want := string(buf[:n]), "100\t200\t300\n"; got != want {
+ t.Fatalf("Bad string: got %v, want %v", got, want)
+ }
+}
+
+func TestQueryRecvBufferSize(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+ s.TCPRecvBufSize = inet.TCPBufferSize{100, 200, 300}
+ tmi := &tcpMemInode{s: s, dir: tcpRMem}
+ tmf := &tcpMemFile{tcpMemInode: tmi}
+
+ buf := make([]byte, 100)
+ dst := usermem.BytesIOSequence(buf)
+ n, err := tmf.Read(ctx, nil, dst, 0)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ if got, want := string(buf[:n]), "100\t200\t300\n"; got != want {
+ t.Fatalf("Bad string: got %v, want %v", got, want)
+ }
+}
+
+var cases = []struct {
+ str string
+ initial inet.TCPBufferSize
+ final inet.TCPBufferSize
+}{
+ {
+ str: "",
+ initial: inet.TCPBufferSize{1, 2, 3},
+ final: inet.TCPBufferSize{1, 2, 3},
+ },
+ {
+ str: "100\n",
+ initial: inet.TCPBufferSize{1, 100, 200},
+ final: inet.TCPBufferSize{100, 100, 200},
+ },
+ {
+ str: "100 200 300\n",
+ initial: inet.TCPBufferSize{1, 2, 3},
+ final: inet.TCPBufferSize{100, 200, 300},
+ },
+}
+
+func TestConfigureSendBufferSize(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+ for _, c := range cases {
+ s.TCPSendBufSize = c.initial
+ tmi := &tcpMemInode{s: s, dir: tcpWMem}
+ tmf := &tcpMemFile{tcpMemInode: tmi}
+
+ // Write the values.
+ src := usermem.BytesIOSequence([]byte(c.str))
+ if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil {
+ t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str))
+ }
+
+ // Read the values from the stack and check them.
+ if s.TCPSendBufSize != c.final {
+ t.Errorf("TCPSendBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPSendBufSize, c.final)
+ }
+ }
+}
+
+func TestConfigureRecvBufferSize(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+ for _, c := range cases {
+ s.TCPRecvBufSize = c.initial
+ tmi := &tcpMemInode{s: s, dir: tcpRMem}
+ tmf := &tcpMemFile{tcpMemInode: tmi}
+
+ // Write the values.
+ src := usermem.BytesIOSequence([]byte(c.str))
+ if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil {
+ t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str))
+ }
+
+ // Read the values from the stack and check them.
+ if s.TCPRecvBufSize != c.final {
+ t.Errorf("TCPRecvBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPRecvBufSize, c.final)
+ }
+ }
+}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
new file mode 100644
index 000000000..4bbe90198
--- /dev/null
+++ b/pkg/sentry/fs/proc/task.go
@@ -0,0 +1,914 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
+// users count is incremented, and must be decremented by the caller when it is
+// no longer in use.
+func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
+ if t.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+ var m *mm.MemoryManager
+ t.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+ if m == nil || !m.IncUsers() {
+ return nil, io.EOF
+ }
+ return m, nil
+}
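+
+// Callers typically pair getTaskMM with a deferred release of the users
+// count, e.g. (sketch, assuming a task and ctx in scope):
+//
+//	m, err := getTaskMM(task)
+//	if err != nil {
+//		return err
+//	}
+//	defer m.DecUsers(ctx)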
+
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
+// taskDir represents a task-level directory.
+//
+// +stateify savable
+type taskDir struct {
+ ramfs.Dir
+
+ t *kernel.Task
+}
+
+var _ fs.InodeOperations = (*taskDir)(nil)
+
+// newTaskDir creates a new proc task entry.
+func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ "io": newIO(t, msrc, isThreadGroup),
+ "maps": newMaps(t, msrc),
+ "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "net": newNetDir(t, msrc),
+ "ns": newNamespaceDir(t, msrc),
+ "oom_score": newOOMScore(t, msrc),
+ "oom_score_adj": newOOMScoreAdj(t, msrc),
+ "smaps": newSmaps(t, msrc),
+ "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
+ "statm": newStatm(t, msrc),
+ "status": newStatus(t, msrc, p.pidns),
+ "uid_map": newUIDMap(t, msrc),
+ }
+ if isThreadGroup {
+ contents["task"] = p.newSubtasks(t, msrc)
+ }
+ if len(p.cgroupControllers) > 0 {
+ contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers)
+ }
+
+ // N.B. taskOwnedInodeOps enforces dumpability-based ownership.
+ d := &taskDir{
+ Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)),
+ t: t,
+ }
+ return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
+}
+
+// subtasks represents a /proc/TID/task directory.
+//
+// +stateify savable
+type subtasks struct {
+ ramfs.Dir
+
+ t *kernel.Task
+ p *proc
+}
+
+var _ fs.InodeOperations = (*subtasks)(nil)
+
+func (p *proc) newSubtasks(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ s := &subtasks{
+ Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0555)),
+ t: t,
+ p: p,
+ }
+ return newProcInode(t, s, msrc, fs.SpecialDirectory, t)
+}
+
+// UnstableAttr returns unstable attributes of the subtasks.
+func (s *subtasks) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := s.Dir.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ // We can't rely on ramfs' implementation because the task directories are
+ // generated dynamically.
+ uattr.Links = uint64(2 + s.t.ThreadGroup().Count())
+ return uattr, nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (s *subtasks) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &subtasksFile{t: s.t, pidns: s.p.pidns}), nil
+}
+
+// +stateify savable
+type subtasksFile struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ t *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ dirCtx := fs.DirCtx{
+ Serializer: ser,
+ }
+
+ // Note that unlike most Readdir implementations, the offset here is
+ // not an index into the subtasks, but rather the TID of the next
+ // subtask to emit.
+ offset := file.Offset()
+
+ tasks := f.t.ThreadGroup().MemberIDs(f.pidns)
+ if len(tasks) == 0 {
+ return offset, syserror.ENOENT
+ }
+
+ if offset == 0 {
+ // Serialize "." and "..".
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dot, dotdot := file.Dirent.GetDotAttrs(root)
+ if err := dirCtx.DirEmit(".", dot); err != nil {
+ return offset, err
+ }
+ if err := dirCtx.DirEmit("..", dotdot); err != nil {
+ return offset, err
+ }
+ }
+
+ // Serialize tasks.
+ taskInts := make([]int, 0, len(tasks))
+ for _, tid := range tasks {
+ taskInts = append(taskInts, int(tid))
+ }
+
+ sort.Sort(sort.IntSlice(taskInts))
+ // Find the task to start at.
+ idx := sort.SearchInts(taskInts, int(offset))
+ if idx == len(taskInts) {
+ return offset, nil
+ }
+ taskInts = taskInts[idx:]
+
+ var tid int
+ for _, tid = range taskInts {
+ name := strconv.FormatUint(uint64(tid), 10)
+ attr := fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice)
+ if err := dirCtx.DirEmit(name, attr); err != nil {
+ // Returned offset is next tid to serialize.
+ return int64(tid), err
+ }
+ }
+ // We serialized them all. Next offset should be higher than last
+ // serialized tid.
+ return int64(tid) + 1, nil
+}
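+
+// For example (assumed TIDs): with subtasks {5, 7, 9} and a file offset of 8,
+// the search resumes at TID 9; once everything is serialized the returned
+// offset is 10, so the next Readdir emits nothing further.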
+
+var _ fs.FileOperations = (*subtasksFile)(nil)
+
+// Lookup loads an Inode in a task's subtask directory into a Dirent.
+func (s *subtasks) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
+ tid, err := strconv.ParseUint(p, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+
+ task := s.p.pidns.TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return nil, syserror.ENOENT
+ }
+ if task.ThreadGroup() != s.t.ThreadGroup() {
+ return nil, syserror.ENOENT
+ }
+
+ td := s.p.newTaskDir(task, dir.MountSource, false)
+ return fs.NewDirent(ctx, td, p), nil
+}
+
+// exe is an fs.InodeOperations symlink for the /proc/PID/exe file.
+//
+// +stateify savable
+type exe struct {
+ ramfs.Symlink
+
+ t *kernel.Task
+}
+
+func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ exeSymlink := &exe{
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""),
+ t: t,
+ }
+ return newProcInode(t, exeSymlink, msrc, fs.Symlink, t)
+}
+
+func (e *exe) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(e.t); err != nil {
+ return nil, err
+ }
+ e.t.WithMuLocked(func(t *kernel.Task) {
+ mm := t.MemoryManager()
+ if mm == nil {
+ err = syserror.EACCES
+ return
+ }
+
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = syserror.ESRCH
+ }
+ })
+ return
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if !kernel.ContextCanTrace(ctx, e.t, false) {
+ return "", syserror.EACCES
+ }
+
+ // Pull out the executable for /proc/TID/exe.
+ exec, err := e.executable()
+ if err != nil {
+ return "", err
+ }
+ defer exec.DecRef()
+
+ return exec.PathnameWithDeleted(ctx), nil
+}
+
+// namespaceSymlink represents a symlink in the namespacefs, such as the files
+// in /proc/<pid>/ns.
+//
+// +stateify savable
+type namespaceSymlink struct {
+ ramfs.Symlink
+
+ t *kernel.Task
+}
+
+func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.Inode {
+ // TODO(rahat): Namespace symlinks should contain the namespace name and the
+ // inode number for the namespace instance, so for example user:[123456]. We
+ // currently fake the inode number by sticking the symlink inode in its
+ // place.
+ target := fmt.Sprintf("%s:[%d]", name, device.ProcDevice.NextIno())
+ n := &namespaceSymlink{
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, target),
+ t: t,
+ }
+ return newProcInode(t, n, msrc, fs.Symlink, t)
+}
+
+// Readlink reads the symlink value.
+func (n *namespaceSymlink) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if err := checkTaskState(n.t); err != nil {
+ return "", err
+ }
+ return n.Symlink.Readlink(ctx, inode)
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) {
+ if !kernel.ContextCanTrace(ctx, n.t, false) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(n.t); err != nil {
+ return nil, err
+ }
+
+ // Create a new regular file to fake the namespace file.
+ iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC)
+ return fs.NewDirent(ctx, newProcInode(ctx, iops, inode.MountSource, fs.RegularFile, nil), n.Symlink.Target), nil
+}
+
+func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "net": newNamespaceSymlink(t, msrc, "net"),
+ "pid": newNamespaceSymlink(t, msrc, "pid"),
+ "user": newNamespaceSymlink(t, msrc, "user"),
+ }
+ d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0511))
+ return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
+}
+
+// mapsData implements seqfile.SeqSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ t *kernel.Task
+}
+
+func newMaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(t, seqfile.NewSeqFile(t, &mapsData{t}), msrc, fs.SpecialFile, t)
+}
+
+func (md *mapsData) mm() *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ md.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ // No additional reference is taken on mm here. This is safe
+ // because MemoryManager.destroy is required to leave the
+ // MemoryManager in a state where it's still usable as a SeqSource.
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (md *mapsData) NeedsUpdate(generation int64) bool {
+ if mm := md.mm(); mm != nil {
+ return mm.NeedsUpdate(generation)
+ }
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (md *mapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if mm := md.mm(); mm != nil {
+ return mm.ReadMapsSeqFileData(ctx, h)
+ }
+ return []seqfile.SeqData{}, 0
+}
+
+// smapsData implements seqfile.SeqSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ t *kernel.Task
+}
+
+func newSmaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(t, seqfile.NewSeqFile(t, &smapsData{t}), msrc, fs.SpecialFile, t)
+}
+
+func (sd *smapsData) mm() *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ sd.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ // No additional reference is taken on mm here. This is safe
+ // because MemoryManager.destroy is required to leave the
+ // MemoryManager in a state where it's still usable as a SeqSource.
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (sd *smapsData) NeedsUpdate(generation int64) bool {
+ if mm := sd.mm(); mm != nil {
+ return mm.NeedsUpdate(generation)
+ }
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (sd *smapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if mm := sd.mm(); mm != nil {
+ return mm.ReadSmapsSeqFileData(ctx, h)
+ }
+ return []seqfile.SeqData{}, 0
+}
+
+// +stateify savable
+type taskStatData struct {
+ t *kernel.Task
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in t's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this statData.
+ pidns *kernel.PIDNamespace
+}
+
+func newTaskStat(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool, pidns *kernel.PIDNamespace) *fs.Inode {
+ return newProcInode(t, seqfile.NewSeqFile(t, &taskStatData{t, showSubtasks /* tgstats */, pidns}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (s *taskStatData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData, returning
+// the SeqData for the file and the current generation.
+func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+
+ fmt.Fprintf(&buf, "%d ", s.pidns.IDOfTask(s.t))
+ fmt.Fprintf(&buf, "(%s) ", s.t.Name())
+ fmt.Fprintf(&buf, "%c ", s.t.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(&buf, "%d ", ppid)
+ fmt.Fprintf(&buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(&buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
+ fmt.Fprintf(&buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(&buf, "0 " /* flags */)
+ fmt.Fprintf(&buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.t.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.t.CPUStats()
+ }
+ fmt.Fprintf(&buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.t.ThreadGroup().JoinedChildCPUStats()
+ fmt.Fprintf(&buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(&buf, "%d %d ", s.t.Priority(), s.t.Niceness())
+ fmt.Fprintf(&buf, "%d ", s.t.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(&buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(&buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(&buf, "%d %d ", vss, rss/usermem.PageSize)
+
+ // rsslim.
+ fmt.Fprintf(&buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(&buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(&buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(&buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0)
+ if s.t == s.t.ThreadGroup().Leader() {
+ terminationSignal = s.t.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(&buf, "%d ", terminationSignal)
+ fmt.Fprintf(&buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(&buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(&buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(&buf, "0\n" /* exit_code */)
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*taskStatData)(nil)}}, 0
+}
+
+// statmData implements seqfile.SeqSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ t *kernel.Task
+}
+
+func newStatm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(t, seqfile.NewSeqFile(t, &statmData{t}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (s *statmData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statmData)(nil)}}, 0
+}
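+
+// For example, a task with 8 MB of virtual memory and 1 MB resident (assumed
+// values, 4 KiB pages) reads as "2048 256 0 0 0 0 0\n".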
+
+// statusData implements seqfile.SeqSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ t *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+func newStatus(t *kernel.Task, msrc *fs.MountSource, pidns *kernel.PIDNamespace) *fs.Inode {
+ return newProcInode(t, seqfile.NewSeqFile(t, &statusData{t, pidns}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (s *statusData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "Name:\t%s\n", s.t.Name())
+ fmt.Fprintf(&buf, "State:\t%s\n", s.t.StateStatus())
+ fmt.Fprintf(&buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
+ fmt.Fprintf(&buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(&buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0)
+ if tracer := s.t.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(&buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(&buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(&buf, "VmSize:\t%d kB\n", vss>>10)
+ fmt.Fprintf(&buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(&buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(&buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
+ creds := s.t.Credentials()
+ fmt.Fprintf(&buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(&buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(&buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n")
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0
+}
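+
+// An illustrative excerpt of the resulting file (assumed task values):
+//
+//	Name:	sleep
+//	Tgid:	42
+//	Pid:	42
+//	PPid:	1
+//	TracerPid:	0
+//	FDSize:	3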
+
+// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+type ioUsage interface {
+ // IOUsage returns the io usage data.
+ IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+ ioUsage
+}
+
+func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
+ if isThreadGroup {
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+ }
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (i *ioData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage())
+
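+ // The format follows Linux's /proc/[pid]/io. Illustrative output for a
+ // task that has read 4096 bytes via a single read(2): "rchar: 4096" and
+ // "syscr: 1", with the remaining counters reported analogously.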
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead)
+ fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(&buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(&buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(&buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*ioData)(nil)}}, 0
+}
+
+// comm is a file containing the command name for a task.
+//
+// On Linux, /proc/[pid]/comm is writable, and writing to the comm file changes
+// the thread name. We don't implement this yet as there are no known users of
+// this feature.
+//
+// +stateify savable
+type comm struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// newComm returns a new comm file.
+func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ c := &comm{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, c, msrc, fs.SpecialFile, t)
+}
+
+// Check implements fs.InodeOperations.Check.
+func (c *comm) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ // This file can always be read or written by members of the same
+ // thread group. See fs/proc/base.c:proc_tid_comm_permission.
+ //
+ // N.B. This check is currently a no-op as we don't yet support writing
+ // and this file is world-readable anyway.
+ t := kernel.TaskFromContext(ctx)
+ if t != nil && t.ThreadGroup() == c.t.ThreadGroup() && !p.Execute {
+ return true
+ }
+
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (c *comm) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &commFile{t: c.t}), nil
+}
+
+// +stateify savable
+type commFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+var _ fs.FileOperations = (*commFile)(nil)
+
+// Read implements fs.FileOperations.Read.
+func (f *commFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
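+ // The file contents are simply the task name with a trailing newline,
+ // e.g. "sh\n" for a task named "sh".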
+ buf := []byte(f.t.Name() + "\n")
+ if offset >= int64(len(buf)) {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOut(ctx, buf[offset:])
+ return int64(n), err
+}
+
+// auxvec is a file containing the auxiliary vector for a task.
+//
+// +stateify savable
+type auxvec struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// newAuxvec returns a new auxvec file.
+func newAuxvec(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ a := &auxvec{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, a, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (a *auxvec) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &auxvecFile{t: a.t}), nil
+}
+
+// +stateify savable
+type auxvecFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ m, err := getTaskMM(f.t)
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecUsers(ctx)
+ auxv := m.Auxv()
+
+ // Space for the buffer, with an AT_NULL (0) terminator entry at the end.
+ size := (len(auxv) + 1) * 16
+ if offset >= int64(size) {
+ return 0, io.EOF
+ }
+
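+ // Each entry is two native-endian 64-bit words, key then value, e.g.
+ // (AT_PAGESZ, 4096). The terminating AT_NULL entry is all zeroes, which
+ // the zero-initialized buffer already provides.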
+ buf := make([]byte, size)
+ for i, e := range auxv {
+ usermem.ByteOrder.PutUint64(buf[16*i:], e.Key)
+ usermem.ByteOrder.PutUint64(buf[16*i+8:], uint64(e.Value))
+ }
+
+ n, err := dst.CopyOut(ctx, buf[offset:])
+ return int64(n), err
+}
+
+// newOOMScore returns an oom_score file. It is a stub that always returns 0.
+// TODO(gvisor.dev/issue/1967)
+func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newStaticProcInode(t, msrc, []byte("0\n"))
+}
+
+// oomScoreAdj is a file containing the oom_score adjustment for a task.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// +stateify savable
+type oomScoreAdjFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+// newOOMScoreAdj returns an oom_score_adj file.
+func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ i := &oomScoreAdj{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, i, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*oomScoreAdj) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (o *oomScoreAdj) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &oomScoreAdjFile{t: o.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *oomScoreAdjFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d\n", f.t.OOMScoreAdj())
+ if offset >= int64(buf.Len()) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, buf.Bytes()[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit the input size so as not to impact performance on large writes.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
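+ // v now holds the parsed adjustment, e.g. -500 from a write of
+ // "-500\n"; range validation is left to SetOOMScoreAdj below.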
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := f.t.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
new file mode 100644
index 000000000..8d9517b95
--- /dev/null
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -0,0 +1,183 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// idMapInodeOperations implements fs.InodeOperations for
+// /proc/[pid]/{uid,gid}_map.
+//
+// +stateify savable
+type idMapInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ t *kernel.Task
+ gids bool
+}
+
+var _ fs.InodeOperations = (*idMapInodeOperations)(nil)
+
+// newUIDMap returns a new uid_map file.
+func newUIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newIDMap(t, msrc, false /* gids */)
+}
+
+// newGIDMap returns a new gid_map file.
+func newGIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newIDMap(t, msrc, true /* gids */)
+}
+
+func newIDMap(t *kernel.Task, msrc *fs.MountSource, gids bool) *fs.Inode {
+ return newProcInode(t, &idMapInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ t: t,
+ gids: gids,
+ }, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (imio *idMapInodeOperations) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &idMapFileOperations{
+ iops: imio,
+ }), nil
+}
+
+// +stateify savable
+type idMapFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ iops *idMapInodeOperations
+}
+
+var _ fs.FileOperations = (*idMapFileOperations)(nil)
+
+// "There is an (arbitrary) limit on the number of lines in the file. As at
+// Linux 3.18, the limit is five lines." - user_namespaces(7)
+const maxIDMapLines = 5
+
+// Read implements fs.FileOperations.Read.
+func (imfo *idMapFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ var entries []auth.IDMapEntry
+ if imfo.iops.gids {
+ entries = imfo.iops.t.UserNamespace().GIDMap()
+ } else {
+ entries = imfo.iops.t.UserNamespace().UIDMap()
+ }
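+ // Each entry is three decimal fields, ID, parent ID, and length, padded
+ // to 10-character columns; e.g. an identity mapping of a single ID
+ // prints "0 0 1" with that padding.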
+ var buf bytes.Buffer
+ for _, e := range entries {
+ fmt.Fprintf(&buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length)
+ }
+ if offset >= int64(buf.Len()) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, buf.Bytes()[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ // "In addition, the number of bytes written to the file must be less than
+ // the system page size, and the write must be performed at the start of
+ // the file ..." - user_namespaces(7)
+ srclen := src.NumBytes()
+ if srclen >= usermem.PageSize || offset != 0 {
+ return 0, syserror.EINVAL
+ }
+ b := make([]byte, srclen)
+ if _, err := src.CopyIn(ctx, b); err != nil {
+ return 0, err
+ }
+
+ // Truncate from the first NULL byte.
+ nul := int64(bytes.IndexByte(b, 0))
+ if nul == -1 {
+ nul = srclen
+ }
+ b = b[:nul]
+ // Remove the last \n.
+ if nul >= 1 && b[nul-1] == '\n' {
+ b = b[:nul-1]
+ }
+ lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
+ if len(lines) > maxIDMapLines {
+ return 0, syserror.EINVAL
+ }
+
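+ // Each line holds three decimal fields; e.g. writing "0 1000 1" maps
+ // ID 0 inside the namespace to ID 1000 in the parent namespace.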
+ entries := make([]auth.IDMapEntry, len(lines))
+ for i, l := range lines {
+ var e auth.IDMapEntry
+ _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
+ if err != nil {
+ return 0, syserror.EINVAL
+ }
+ entries[i] = e
+ }
+ var err error
+ if imfo.iops.gids {
+ err = imfo.iops.t.UserNamespace().SetGIDMap(ctx, entries)
+ } else {
+ err = imfo.iops.t.UserNamespace().SetUIDMap(ctx, entries)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // On success, Linux's kernel/user_namespace.c:map_write() always returns
+ // count, even if fewer bytes were used.
+ return int64(srclen), nil
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go
new file mode 100644
index 000000000..c0f6fb802
--- /dev/null
+++ b/pkg/sentry/fs/proc/uptime.go
@@ -0,0 +1,91 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// uptime is a file containing the system uptime.
+//
+// +stateify savable
+type uptime struct {
+ fsutil.SimpleFileInode
+
+ // The "start time" of the sandbox.
+ startTime ktime.Time
+}
+
+// newUptime returns a new uptime file.
+func newUptime(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ u := &uptime{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ startTime: ktime.NowFromContext(ctx),
+ }
+ return newProcInode(ctx, u, msrc, fs.SpecialFile, nil)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (u *uptime) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &uptimeFile{startTime: u.startTime}), nil
+}
+
+// +stateify savable
+type uptimeFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ startTime ktime.Time
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
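+ // The format matches Linux's /proc/uptime: uptime and idle time in
+ // seconds, e.g. "1234.56 0.00\n" (illustrative) after ~20 minutes.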
+ now := ktime.NowFromContext(ctx)
+ // Pretend that we've spent zero time sleeping (second number).
+ s := []byte(fmt.Sprintf("%.2f 0.00\n", now.Sub(f.startTime).Seconds()))
+ if offset >= int64(len(s)) {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOut(ctx, s[offset:])
+ return int64(n), err
+}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/version.go b/pkg/sentry/fs/proc/version.go
new file mode 100644
index 000000000..35e258ff6
--- /dev/null
+++ b/pkg/sentry/fs/proc/version.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+// versionData backs /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*versionData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (v *versionData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ init := v.k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+// - COMPILE_USER is the user that built the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
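+ // This yields something like
+ // "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016\n",
+ // depending on the version the syscall table advertises.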
+ return []seqfile.SeqData{
+ {
+ Buf: []byte(fmt.Sprintf("%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)),
+ Handle: (*versionData)(nil),
+ },
+ }, 0
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
new file mode 100644
index 000000000..8ca823fb3
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ramfs",
+ srcs = [
+ "dir.go",
+ "socket.go",
+ "symlink.go",
+ "tree.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "ramfs_test",
+ size = "small",
+ srcs = ["tree_test.go"],
+ library = ":ramfs",
+ deps = [
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ ],
+)
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
new file mode 100644
index 000000000..bfa304552
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -0,0 +1,548 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ramfs provides the fundamentals for a simple in-memory filesystem.
+package ramfs
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// CreateOps represents operations to create different file types.
+type CreateOps struct {
+ // NewDir creates a new directory.
+ NewDir func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error)
+
+ // NewFile creates a new file.
+ NewFile func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error)
+
+ // NewSymlink creates a new symlink with permissions 0777.
+ NewSymlink func(ctx context.Context, dir *fs.Inode, target string) (*fs.Inode, error)
+
+ // NewBoundEndpoint creates a new socket.
+ NewBoundEndpoint func(ctx context.Context, dir *fs.Inode, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error)
+
+ // NewFifo creates a new fifo.
+ NewFifo func(ctx context.Context, dir *fs.Inode, perm fs.FilePermissions) (*fs.Inode, error)
+}
+
+// Dir represents a single directory in the filesystem.
+//
+// +stateify savable
+type Dir struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeIsDirAllocate `state:"nosave"`
+ fsutil.InodeIsDirTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ // CreateOps may be provided.
+ //
+ // These may only be modified during initialization (while the application
+ // is not running). No synchronization is performed when accessing these
+ // operations during syscalls.
+ *CreateOps `state:"nosave"`
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // children are inodes that are in this directory. A reference is held
+ // on each inode while it is in the map.
+ children map[string]*fs.Inode
+
+ // dentryMap is a sortedDentryMap containing entries for all children.
+ // Its entries are kept up-to-date with d.children.
+ dentryMap *fs.SortedDentryMap
+}
+
+var _ fs.InodeOperations = (*Dir)(nil)
+
+// NewDir returns a new Dir with the given contents and attributes. A reference
+// on each fs.Inode in the `contents` map will be donated to this Dir.
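+//
+// A minimal sketch, assuming `child` is an existing *fs.Inode:
+//
+//    d := ramfs.NewDir(ctx, map[string]*fs.Inode{"child": child}, fs.RootOwner,
+//        fs.FilePermsFromMode(0755))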
+func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions) *Dir {
+ d := &Dir{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, perms, linux.RAMFS_MAGIC),
+ }
+
+ if contents == nil {
+ contents = make(map[string]*fs.Inode)
+ }
+ d.children = contents
+
+ // Build the entries map ourselves, rather than calling addChildLocked,
+ // because it will be faster.
+ entries := make(map[string]fs.DentAttr, len(contents))
+ for name, inode := range contents {
+ entries[name] = fs.DentAttr{
+ Type: inode.StableAttr.Type,
+ InodeID: inode.StableAttr.InodeID,
+ }
+ }
+ d.dentryMap = fs.NewSortedDentryMap(entries)
+
+ // Directories have an extra link, corresponding to '.'.
+ d.AddLink()
+
+ return d
+}
+
+// addChildLocked adds the child inode, inheriting its reference.
+func (d *Dir) addChildLocked(ctx context.Context, name string, inode *fs.Inode) {
+ d.children[name] = inode
+ d.dentryMap.Add(name, fs.DentAttr{
+ Type: inode.StableAttr.Type,
+ InodeID: inode.StableAttr.InodeID,
+ })
+
+ // If the child is a directory, increment this dir's link count,
+ // corresponding to '..' from the subdirectory.
+ if fs.IsDir(inode.StableAttr) {
+ d.AddLink()
+ // ctime updated below.
+ }
+
+ // Given we're now adding this inode to the directory, we must also
+ // increase its link count. Similarly we decrement it in removeChildLocked.
+ //
+ // Changing link count updates ctime.
+ inode.AddLink()
+ inode.InodeOperations.NotifyStatusChange(ctx)
+
+ // We've changed the directory. This always updates our mtime and ctime.
+ d.NotifyModificationAndStatusChange(ctx)
+}
+
+// AddChild adds a child to this dir, inheriting its reference.
+func (d *Dir) AddChild(ctx context.Context, name string, inode *fs.Inode) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.addChildLocked(ctx, name, inode)
+}
+
+// FindChild returns (child, true) if the directory contains name.
+func (d *Dir) FindChild(name string) (*fs.Inode, bool) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ child, ok := d.children[name]
+ return child, ok
+}
+
+// Children returns the names and DentAttrs of all children. It can be used to
+// implement Readdir for types that embed ramfs.Dir.
+func (d *Dir) Children() ([]string, map[string]fs.DentAttr) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Return a copy to prevent callers from modifying our children.
+ names, entries := d.dentryMap.GetAll()
+ namesCopy := make([]string, len(names))
+ copy(namesCopy, names)
+
+ entriesCopy := make(map[string]fs.DentAttr)
+ for k, v := range entries {
+ entriesCopy[k] = v
+ }
+
+ return namesCopy, entriesCopy
+}
+
+// removeChildLocked attempts to remove an entry from this directory. It
+// returns the removed fs.Inode along with its reference, which callers are
+// responsible for decrementing.
+func (d *Dir) removeChildLocked(ctx context.Context, name string) (*fs.Inode, error) {
+ inode, ok := d.children[name]
+ if !ok {
+ return nil, syserror.EACCES
+ }
+
+ delete(d.children, name)
+ d.dentryMap.Remove(name)
+ d.NotifyModification(ctx)
+
+ // If the child was a subdirectory, then we must decrement this dir's
+ // link count which was the child's ".." directory entry.
+ if fs.IsDir(inode.StableAttr) {
+ d.DropLink()
+ // ctime changed below.
+ }
+
+ // Given we're now removing this inode from the directory, we must also
+ // decrease its link count. Similarly it is increased in addChildLocked.
+ //
+ // Changing link count updates ctime.
+ inode.DropLink()
+ inode.InodeOperations.NotifyStatusChange(ctx)
+
+ // We've changed the directory. This always updates our mtime and ctime.
+ d.NotifyModificationAndStatusChange(ctx)
+
+ return inode, nil
+}
+
+// Remove removes the named non-directory.
+func (d *Dir) Remove(ctx context.Context, _ *fs.Inode, name string) error {
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ inode, err := d.removeChildLocked(ctx, name)
+ if err != nil {
+ return err
+ }
+
+ // Remove our reference on the inode.
+ inode.DecRef()
+ return nil
+}
+
+// RemoveDirectory removes the named directory.
+func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) error {
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Get the child and make sure it is not empty.
+ childInode, err := d.walkLocked(ctx, name)
+ if err != nil {
+ return err
+ }
+ if ok, err := hasChildren(ctx, childInode); err != nil {
+ return err
+ } else if ok {
+ return syserror.ENOTEMPTY
+ }
+
+ // Child was empty. Proceed with removal.
+ inode, err := d.removeChildLocked(ctx, name)
+ if err != nil {
+ return err
+ }
+
+ // Remove our reference on the inode.
+ inode.DecRef()
+
+ return nil
+}
+
+// Lookup loads an inode at p into a Dirent. It returns the fs.Dirent along
+// with a reference.
+func (d *Dir) Lookup(ctx context.Context, _ *fs.Inode, p string) (*fs.Dirent, error) {
+ if len(p) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ inode, err := d.walkLocked(ctx, p)
+ if err != nil {
+ return nil, err
+ }
+
+ // Take a reference on the inode before returning it. This reference
+ // is owned by the dirent we are about to create.
+ inode.IncRef()
+ return fs.NewDirent(ctx, inode, p), nil
+}
+
+// walkLocked must be called with d.mu held.
+func (d *Dir) walkLocked(ctx context.Context, p string) (*fs.Inode, error) {
+ // Lookup a child node.
+ if inode, ok := d.children[p]; ok {
+ return inode, nil
+ }
+
+ // fs.InodeOperations.Lookup returns syserror.ENOENT if p
+ // does not exist.
+ return nil, syserror.ENOENT
+}
+
+// createInodeOperationsCommon creates a new child node at this dir by calling
+// makeInodeOperations. It is the common logic for creating a new child.
+func (d *Dir) createInodeOperationsCommon(ctx context.Context, name string, makeInodeOperations func() (*fs.Inode, error)) (*fs.Inode, error) {
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ inode, err := makeInodeOperations()
+ if err != nil {
+ return nil, err
+ }
+
+ d.addChildLocked(ctx, name, inode)
+
+ return inode, nil
+}
+
+// Create creates a new Inode with the given name and returns its File.
+func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perms fs.FilePermissions) (*fs.File, error) {
+ if d.CreateOps == nil || d.CreateOps.NewFile == nil {
+ return nil, syserror.EACCES
+ }
+
+ inode, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewFile(ctx, dir, perms)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Take an extra ref on inode, which will be owned by the dirent.
+ inode.IncRef()
+
+ // Create the Dirent and corresponding file.
+ created := fs.NewDirent(ctx, inode, name)
+ defer created.DecRef()
+ return created.Inode.GetFile(ctx, created, flags)
+}
+
+// CreateLink returns a new link.
+func (d *Dir) CreateLink(ctx context.Context, dir *fs.Inode, oldname, newname string) error {
+ if d.CreateOps == nil || d.CreateOps.NewSymlink == nil {
+ return syserror.EACCES
+ }
+ _, err := d.createInodeOperationsCommon(ctx, newname, func() (*fs.Inode, error) {
+ return d.NewSymlink(ctx, dir, oldname)
+ })
+ return err
+}
+
+// CreateHardLink creates a new hard link.
+func (d *Dir) CreateHardLink(ctx context.Context, dir *fs.Inode, target *fs.Inode, name string) error {
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Take an extra reference on the inode and add it to our children.
+ target.IncRef()
+
+ // The link count will be incremented in addChildLocked.
+ d.addChildLocked(ctx, name, target)
+
+ return nil
+}
+
+// CreateDirectory returns a new subdirectory.
+func (d *Dir) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ if d.CreateOps == nil || d.CreateOps.NewDir == nil {
+ return syserror.EACCES
+ }
+ _, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewDir(ctx, dir, perms)
+ })
+ return err
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (d *Dir) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Dirent, error) {
+ if d.CreateOps == nil || d.CreateOps.NewBoundEndpoint == nil {
+ return nil, syserror.EACCES
+ }
+ inode, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewBoundEndpoint(ctx, dir, ep, perms)
+ })
+ if err == syscall.EEXIST {
+ return nil, syscall.EADDRINUSE
+ }
+ if err != nil {
+ return nil, err
+ }
+ // Take another ref on inode which will be donated to the new dirent.
+ inode.IncRef()
+ return fs.NewDirent(ctx, inode, name), nil
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ if d.CreateOps == nil || d.CreateOps.NewFifo == nil {
+ return syserror.EACCES
+ }
+ _, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewFifo(ctx, dir, perms)
+ })
+ return err
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (d *Dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &dirFileOperations{dir: d}), nil
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return Rename(ctx, oldParent.InodeOperations, oldName, newParent.InodeOperations, newName, replacement)
+}
+
+// Release implements fs.InodeOperation.Release.
+func (d *Dir) Release(_ context.Context) {
+ // Drop references on all children.
+ d.mu.Lock()
+ for _, i := range d.children {
+ i.DecRef()
+ }
+ d.mu.Unlock()
+}
+
+// dirFileOperations implements fs.FileOperations for a ramfs directory.
+//
+// +stateify savable
+type dirFileOperations struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // dirCursor contains the name of the last directory entry that was
+ // serialized.
+ dirCursor string
+
+ // dir is the ramfs dir that this file corresponds to.
+ dir *Dir
+}
+
+var _ fs.FileOperations = (*dirFileOperations)(nil)
+
+// Seek implements fs.FileOperations.Seek.
+func (dfo *dirFileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return fsutil.SeekWithDirCursor(ctx, file, whence, offset, &dfo.dirCursor)
+}
+
+// IterateDir implements DirIterator.IterateDir.
+func (dfo *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) {
+ dfo.dir.mu.Lock()
+ defer dfo.dir.mu.Unlock()
+
+ n, err := fs.GenericReaddir(dirCtx, dfo.dir.dentryMap)
+ return offset + n, err
+}
+
+// Readdir implements FileOperations.Readdir.
+func (dfo *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &dfo.dirCursor,
+ }
+ dfo.dir.InodeSimpleAttributes.NotifyAccess(ctx)
+ return fs.DirentReaddir(ctx, file.Dirent, dfo, root, dirCtx, file.Offset())
+}
+
+// hasChildren is a helper method that determines whether an arbitrary inode
+// (not necessarily ramfs) has any children.
+func hasChildren(ctx context.Context, inode *fs.Inode) (bool, error) {
+ // Take an extra ref on inode which will be given to the dirent and
+ // dropped when that dirent is destroyed.
+ inode.IncRef()
+ d := fs.NewTransientDirent(inode)
+ defer d.DecRef()
+
+ file, err := inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ return false, err
+ }
+ defer file.DecRef()
+
+ ser := &fs.CollectEntriesSerializer{}
+ if err := file.Readdir(ctx, ser); err != nil {
+ return false, err
+ }
+ // We will always write "." and "..", so ignore those two.
+ if ser.Written() > 2 {
+ return true, nil
+ }
+ return false, nil
+}
+
+// Rename renames from a *ramfs.Dir to another *ramfs.Dir.
+func Rename(ctx context.Context, oldParent fs.InodeOperations, oldName string, newParent fs.InodeOperations, newName string, replacement bool) error {
+ op, ok := oldParent.(*Dir)
+ if !ok {
+ return syserror.EXDEV
+ }
+ np, ok := newParent.(*Dir)
+ if !ok {
+ return syserror.EXDEV
+ }
+ if len(newName) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ np.mu.Lock()
+ defer np.mu.Unlock()
+
+ // Is this an overwriting rename?
+ if replacement {
+ replaced, ok := np.children[newName]
+ if !ok {
+ panic(fmt.Sprintf("Dirent claims rename is replacement, but %q is missing from %+v", newName, np))
+ }
+
+ // Non-empty directories cannot be replaced.
+ if fs.IsDir(replaced.StableAttr) {
+ if ok, err := hasChildren(ctx, replaced); err != nil {
+ return err
+ } else if ok {
+ return syserror.ENOTEMPTY
+ }
+ }
+
+ // Remove the replaced child and drop our reference on it.
+ inode, err := np.removeChildLocked(ctx, newName)
+ if err != nil {
+ return err
+ }
+ inode.DecRef()
+ }
+
+ // Be careful, we may have already grabbed this mutex above.
+ if op != np {
+ op.mu.Lock()
+ defer op.mu.Unlock()
+ }
+
+ // Do the swap.
+ n := op.children[oldName]
+ op.removeChildLocked(ctx, oldName)
+ np.addChildLocked(ctx, newName, n)
+
+ return nil
+}
diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go
new file mode 100644
index 000000000..29ff004f2
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/socket.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Socket represents a socket.
+//
+// +stateify savable
+type Socket struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ // ep is the bound endpoint.
+ ep transport.BoundEndpoint
+}
+
+var _ fs.InodeOperations = (*Socket)(nil)
+
+// NewSocket returns a new Socket.
+func NewSocket(ctx context.Context, ep transport.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions) *Socket {
+ return &Socket{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, perms, linux.SOCKFS_MAGIC),
+ ep: ep,
+ }
+}
+
+// BoundEndpoint returns the socket data.
+func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint {
+ // ramfs only supports stored sentry internal sockets. Only gofer sockets
+ // care about the path argument.
+ return s.ep
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil
+}
+
+// +stateify savable
+type socketFileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoRead `state:"nosave"`
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*socketFileOperations)(nil)
diff --git a/pkg/sentry/fs/ramfs/symlink.go b/pkg/sentry/fs/ramfs/symlink.go
new file mode 100644
index 000000000..d988349aa
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/symlink.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Symlink represents a symlink.
+//
+// +stateify savable
+type Symlink struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ // Target is the symlink target.
+ Target string
+}
+
+var _ fs.InodeOperations = (*Symlink)(nil)
+
+// NewSymlink returns a new Symlink.
+func NewSymlink(ctx context.Context, owner fs.FileOwner, target string) *Symlink {
+ // A symlink is assumed to always have permissions 0777.
+ return &Symlink{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(0777), linux.RAMFS_MAGIC),
+ Target: target,
+ }
+}
+
+// UnstableAttr returns all attributes of this ramfs symlink.
+func (s *Symlink) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := s.InodeSimpleAttributes.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ uattr.Size = int64(len(s.Target))
+ uattr.Usage = uattr.Size
+ return uattr, nil
+}
+
+// SetPermissions on a symlink is always rejected.
+func (s *Symlink) SetPermissions(context.Context, *fs.Inode, fs.FilePermissions) bool {
+ return false
+}
+
+// Readlink reads the symlink value.
+func (s *Symlink) Readlink(ctx context.Context, _ *fs.Inode) (string, error) {
+ s.NotifyAccess(ctx)
+ return s.Target, nil
+}
+
+// Getlink returns ErrResolveViaReadlink, falling back to walking to the result
+// of Readlink().
+func (*Symlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ return nil, fs.ErrResolveViaReadlink
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (s *Symlink) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &symlinkFileOperations{}), nil
+}
+
+// +stateify savable
+type symlinkFileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoRead `state:"nosave"`
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*symlinkFileOperations)(nil)
diff --git a/pkg/sentry/fs/ramfs/tree.go b/pkg/sentry/fs/ramfs/tree.go
new file mode 100644
index 000000000..dfc9d3453
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/tree.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "fmt"
+ "path"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// MakeDirectoryTree constructs a ramfs tree of all directories containing
+// subdirs. Each element of subdirs must be a clean path, and cannot be empty or
+// "/".
+//
+// All directories in the created tree will have full (read-write-execute)
+// permissions, but note that file creation inside the directories is not
+// actually supported because ramfs.Dir.CreateOps == nil. However, these
+// directory trees are normally "underlayed" under another filesystem (possibly
+// the root), and file creation inside these directories in the overlay will be
+// possible if the upper is writeable.
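+//
+// A minimal usage sketch (paths illustrative):
+//
+//    tree, err := ramfs.MakeDirectoryTree(ctx, msrc, []string{"/dev", "/proc/net"})
+//    if err != nil {
+//        return err
+//    }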
+func MakeDirectoryTree(ctx context.Context, msrc *fs.MountSource, subdirs []string) (*fs.Inode, error) {
+ root := emptyDir(ctx, msrc)
+ for _, subdir := range subdirs {
+ if path.Clean(subdir) != subdir {
+ return nil, fmt.Errorf("cannot add subdir at an unclean path: %q", subdir)
+ }
+ if subdir == "" || subdir == "/" {
+ return nil, fmt.Errorf("cannot add subdir at %q", subdir)
+ }
+ makeSubdir(ctx, msrc, root.InodeOperations.(*Dir), subdir)
+ }
+ return root, nil
+}
+
+// makeSubdir installs into root each component of subdir. The final component is
+// a *ramfs.Dir.
+func makeSubdir(ctx context.Context, msrc *fs.MountSource, root *Dir, subdir string) {
+ for _, c := range strings.Split(subdir, "/") {
+ if len(c) == 0 {
+ continue
+ }
+ child, ok := root.FindChild(c)
+ if !ok {
+ child = emptyDir(ctx, msrc)
+ root.AddChild(ctx, c, child)
+ }
+ root = child.InodeOperations.(*Dir)
+ }
+}
+
+// emptyDir returns an empty *ramfs.Dir with all permissions granted.
+func emptyDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ dir := NewDir(ctx, make(map[string]*fs.Inode), fs.RootOwner, fs.FilePermsFromMode(0777))
+ return fs.NewInode(ctx, dir, msrc, fs.StableAttr{
+ DeviceID: anon.PseudoDevice.DeviceID(),
+ InodeID: anon.PseudoDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go
new file mode 100644
index 000000000..a6ed8b2c5
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/tree_test.go
@@ -0,0 +1,80 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+func TestMakeDirectoryTree(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ subdirs []string
+ }{
+ {
+ name: "abs paths",
+ subdirs: []string{
+ "/tmp",
+ "/tmp/a/b",
+ "/tmp/a/c/d",
+ "/tmp/c",
+ "/proc",
+ "/dev/a/b",
+ "/tmp",
+ },
+ },
+ {
+ name: "rel paths",
+ subdirs: []string{
+ "tmp",
+ "tmp/a/b",
+ "tmp/a/c/d",
+ "tmp/c",
+ "proc",
+ "dev/a/b",
+ "tmp",
+ },
+ },
+ } {
+ ctx := contexttest.Context(t)
+ mount := fs.NewPseudoMountSource(ctx)
+ tree, err := MakeDirectoryTree(ctx, mount, test.subdirs)
+ if err != nil {
+ t.Errorf("%s: failed to make ramfs tree, got error %v, want nil", test.name, err)
+ continue
+ }
+
+ // Expect to be able to find each of the paths.
+ mm, err := fs.NewMountNamespace(ctx, tree)
+ if err != nil {
+ t.Errorf("%s: failed to create mount manager: %v", test.name, err)
+ continue
+ }
+ root := mm.Root()
+ defer mm.DecRef()
+
+ for _, p := range test.subdirs {
+ maxTraversals := uint(0)
+ if _, err := mm.FindInode(ctx, root, nil, p, &maxTraversals); err != nil {
+ t.Errorf("%s: failed to find node %s: %v", test.name, p, err)
+ break
+ }
+ }
+ }
+}
diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go
new file mode 100644
index 000000000..64c6a6ae9
--- /dev/null
+++ b/pkg/sentry/fs/restore.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// RestoreEnvironment is the restore environment for file systems. It consists
+// of things that change across save and restore and therefore cannot be saved
+// in the object graph.
+type RestoreEnvironment struct {
+ // MountSources maps Filesystem.Name() to mount arguments.
+ MountSources map[string][]MountArgs
+
+ // ValidateFileSize indicates file size should not change across S/R.
+ ValidateFileSize bool
+
+ // ValidateFileTimestamp indicates file modification timestamp should
+ // not change across S/R.
+ ValidateFileTimestamp bool
+}
+
+// MountArgs holds arguments to Mount.
+type MountArgs struct {
+ // Dev corresponds to the devname argument of Mount.
+ Dev string
+
+ // Flags corresponds to the flags argument of Mount.
+ Flags MountSourceFlags
+
+ // DataString corresponds to the data argument of Mount.
+ DataString string
+
+ // DataObj corresponds to the data interface argument of Mount.
+ DataObj interface{}
+}
+
+// restoreEnv holds the fs package global RestoreEnvironment.
+var restoreEnv = struct {
+ mu sync.Mutex
+ env RestoreEnvironment
+ set bool
+}{}
+
+// SetRestoreEnvironment sets the RestoreEnvironment. Must be called before
+// state.Load and only once.
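+//
+// A minimal sketch (mount arguments illustrative):
+//
+//    fs.SetRestoreEnvironment(fs.RestoreEnvironment{
+//        MountSources: map[string][]fs.MountArgs{
+//            "9p": {{Dev: "9pfs-/"}},
+//        },
+//        ValidateFileSize: true,
+//    })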
+func SetRestoreEnvironment(r RestoreEnvironment) {
+ restoreEnv.mu.Lock()
+ defer restoreEnv.mu.Unlock()
+ if restoreEnv.set {
+ panic("RestoreEnvironment may only be set once")
+ }
+ restoreEnv.env = r
+ restoreEnv.set = true
+}
+
+// CurrentRestoreEnvironment returns the current, read-only RestoreEnvironment.
+// If no RestoreEnvironment was ever set, returns (_, false).
+func CurrentRestoreEnvironment() (RestoreEnvironment, bool) {
+ restoreEnv.mu.Lock()
+ defer restoreEnv.mu.Unlock()
+ e := restoreEnv.env
+ set := restoreEnv.set
+ return e, set
+}
diff --git a/pkg/sentry/fs/save.go b/pkg/sentry/fs/save.go
new file mode 100644
index 000000000..fe5c76b44
--- /dev/null
+++ b/pkg/sentry/fs/save.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// SaveInodeMappings saves a mapping of path -> inode ID for every
+// user-reachable Dirent.
+//
+// The entire kernel must be frozen to call this, and filesystem state must not
+// change between SaveInodeMappings and state.Save, otherwise the saved state
+// of any MountSource may be incoherent.
+func SaveInodeMappings() {
+ mountsSeen := make(map[*MountSource]struct{})
+ for dirent := range allDirents.dirents {
+ if _, ok := mountsSeen[dirent.Inode.MountSource]; !ok {
+ dirent.Inode.MountSource.ResetInodeMappings()
+ mountsSeen[dirent.Inode.MountSource] = struct{}{}
+ }
+ }
+
+ for dirent := range allDirents.dirents {
+ if dirent.Inode != nil {
+ // We cannot trust the root provided in the mount due
+ // to the overlay. We can trust the overlay to delegate
+ // SaveInodeMappings to the right underlying
+ // filesystems, though.
+ root := dirent
+ for !root.mounted && root.parent != nil {
+ root = root.parent
+ }
+
+ // Add the mapping.
+ n, reachable := dirent.FullName(root)
+ if !reachable {
+ // Something has gone seriously wrong if we can't reach our root.
+ panic(fmt.Sprintf("Unreachable root on dirent file %s", n))
+ }
+ dirent.Inode.MountSource.SaveInodeMapping(dirent.Inode, n)
+ }
+ }
+}
+
+// SaveFileFsyncError converts an fs.File.Fsync error to an error that
+// indicates that the fs.File was not synced sufficiently to be saved.
+func SaveFileFsyncError(err error) error {
+ switch err {
+ case nil:
+ // We succeeded, everything is great.
+ return nil
+ case syscall.EBADF, syscall.EINVAL, syscall.EROFS, syscall.ENOSYS, syscall.EPERM:
+ // These errors mean that the underlying node might not be syncable,
+ // which we expect to be reported as such even from the gofer.
+ log.Infof("failed to sync during save: %v", err)
+ return nil
+ default:
+ // We failed in some way that indicates potential data loss.
+ return fmt.Errorf("failed to sync: %v, data loss may occur", err)
+ }
+}
diff --git a/pkg/sentry/fs/seek.go b/pkg/sentry/fs/seek.go
new file mode 100644
index 000000000..0f43918ad
--- /dev/null
+++ b/pkg/sentry/fs/seek.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+// SeekWhence determines seek direction.
+type SeekWhence int
+
+const (
+ // SeekSet sets the absolute offset.
+ SeekSet SeekWhence = iota
+
+ // SeekCurrent sets relative to the current position.
+ SeekCurrent
+
+ // SeekEnd sets relative to the end of the file.
+ SeekEnd
+)
+
+// String returns a human readable string for whence.
+func (s SeekWhence) String() string {
+ switch s {
+ case SeekSet:
+ return "Set"
+ case SeekCurrent:
+ return "Current"
+ case SeekEnd:
+ return "End"
+ default:
+ return "Unknown"
+ }
+}
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
new file mode 100644
index 000000000..33da82868
--- /dev/null
+++ b/pkg/sentry/fs/splice.go
@@ -0,0 +1,181 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Splice moves data to this file, directly from another.
+//
+// File offsets are updated only if DstOffset and SrcOffset are not set
+// (that is, when the files' own offsets are used).
+func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, error) {
+ // Verify basic file flag permissions.
+ if !dst.Flags().Write || !src.Flags().Read {
+ return 0, syserror.EBADF
+ }
+
+ // Check whether or not the objects being spliced are stream-oriented
+ // (i.e. pipes or sockets). For all stream-oriented files and files
+ // where a specific offset is not requested, we acquire the file mutex.
+ // This has two important side effects. First, it provides the standard
+ // protection against concurrent writes that would mutate the offset.
+ // Second, it prevents Splice deadlocks. Only internal anonymous files
+ // implement the ReadFrom and WriteTo methods directly, and since such
+ // anonymous files are referred to by a unique fs.File object, we know
+ // that the file mutex takes strict precedence over internal locks.
+ // Since we enforce lock ordering here, we can't deadlock by using
+ // a file in two different splice operations simultaneously.
+ srcPipe := !IsRegular(src.Dirent.Inode.StableAttr)
+ dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr)
+ dstAppend := !dstPipe && dst.Flags().Append
+ srcLock := srcPipe || !opts.SrcOffset
+ dstLock := dstPipe || !opts.DstOffset || dstAppend
+
+ switch {
+ case srcLock && dstLock:
+ switch {
+ case dst.UniqueID < src.UniqueID:
+ // Acquire dst first.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ if !src.mu.Lock(ctx) {
+ dst.mu.Unlock()
+ return 0, syserror.ErrInterrupted
+ }
+ case dst.UniqueID > src.UniqueID:
+ // Acquire src first.
+ if !src.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ if !dst.mu.Lock(ctx) {
+ src.mu.Unlock()
+ return 0, syserror.ErrInterrupted
+ }
+ case dst.UniqueID == src.UniqueID:
+ // Acquire only one lock; it's the same file. This is a
+ // bit of an edge case, but presumably it's possible.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ srcLock = false // Only need one unlock.
+ }
+ // Use both offsets (locked).
+ opts.DstStart = dst.offset
+ opts.SrcStart = src.offset
+ case dstLock:
+ // Acquire only dst.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ opts.DstStart = dst.offset // Safe: locked.
+ case srcLock:
+ // Acquire only src.
+ if !src.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ opts.SrcStart = src.offset // Safe: locked.
+ }
+
+ var err error
+ if dstAppend {
+ unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append)
+ defer unlock()
+
+ // Figure out the appropriate offset to use.
+ err = dst.offsetForAppend(ctx, &opts.DstStart)
+ }
+ if err == nil && !dstPipe {
+ // Enforce file limits.
+ limit, ok := dst.checkLimit(ctx, opts.DstStart)
+ switch {
+ case ok && limit == 0:
+ err = syserror.ErrExceedsFileSizeLimit
+ case ok && limit < opts.Length:
+ opts.Length = limit // Cap the write.
+ }
+ }
+ if err != nil {
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+ return 0, err
+ }
+
+ // Construct readers and writers for the splice. These provide a
+ // safer locking path for the WriteTo/ReadFrom operations (which
+ // would otherwise go through public interface methods that
+ // conflict with the locking done above) and simplify the
+ // fallback path.
+ w := &lockedWriter{
+ Ctx: ctx,
+ File: dst,
+ Offset: opts.DstStart,
+ }
+ r := &lockedReader{
+ Ctx: ctx,
+ File: src,
+ Offset: opts.SrcStart,
+ }
+
+ // Attempt to do a WriteTo; this is likely the most efficient.
+ n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup {
+ // Attempt as a ReadFrom. If a WriteTo is not supported, a
+ // ReadFrom may still be more efficient than a copy if buffers
+ // are cached or readily available. (It's unlikely that they
+ // can actually be donated.)
+ n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length)
+ }
+
+ // Support one last fallback option, but only if at least one of
+ // the source or destination is a regular file. This is because
+ // if we block at some point, we could lose data. If the source is
+ // not a pipe, then reading is not destructive; if the destination
+ // is a regular file, then it is guaranteed not to block on write.
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
+ // Fallback to an in-kernel copy.
+ n, err = io.Copy(w, &io.LimitedReader{
+ R: r,
+ N: opts.Length,
+ })
+ }
+
+ // Update offsets, if required.
+ if n > 0 {
+ if !dstPipe && !opts.DstOffset {
+ atomic.StoreInt64(&dst.offset, dst.offset+n)
+ }
+ if !srcPipe && !opts.SrcOffset {
+ atomic.StoreInt64(&src.offset, src.offset+n)
+ }
+ }
+
+ // Drop locks.
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+
+ return n, err
+}
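+
+// As an illustrative sketch (not a call site in this change), a
+// sendfile(2)-like caller that consumes both files' offsets would invoke:
+//
+//   n, err := Splice(ctx, dst, src, SpliceOpts{Length: count})
+//
+// while a caller pinning the source position would additionally set
+// SrcOffset and SrcStart, leaving src's offset untouched.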
diff --git a/pkg/sentry/fs/sync.go b/pkg/sentry/fs/sync.go
new file mode 100644
index 000000000..1fff8059c
--- /dev/null
+++ b/pkg/sentry/fs/sync.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+// SyncType enumerates ways in which a File can be synced.
+type SyncType int
+
+const (
+ // SyncAll indicates that modified in-memory metadata and data should
+ // be written to backing storage. SyncAll implies SyncBackingStorage.
+ SyncAll SyncType = iota
+
+ // SyncData indicates that along with modified in-memory data, only
+ // metadata needed to access that data needs to be written.
+ //
+ // For example, changes to access time or modification time do not
+ // need to be written because they are not necessary for a data read
+ // to be handled correctly, unlike the file size.
+ //
+ // The aim of SyncData is to reduce disk activity for applications
+ // that do not require all metadata to be synchronized with the disk,
+ // see fdatasync(2). File systems that implement SyncData as SyncAll
+ // do not support this optimization.
+ //
+ // SyncData implies SyncBackingStorage.
+ SyncData
+
+ // SyncBackingStorage indicates that in-flight write operations to
+ // backing storage should be flushed.
+ SyncBackingStorage
+)
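+
+// For orientation (a sketch; the syscall layer lives outside this file),
+// fsync(2) corresponds to a full-range SyncAll, and fdatasync(2) to SyncData:
+//
+//   file.Fsync(ctx, 0, FileMaxOffset, SyncAll)  // fsync(2)
+//   file.Fsync(ctx, 0, FileMaxOffset, SyncData) // fdatasync(2)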
diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD
new file mode 100644
index 000000000..f2e8b9932
--- /dev/null
+++ b/pkg/sentry/fs/sys/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sys",
+ srcs = [
+ "device.go",
+ "devices.go",
+ "fs.go",
+ "sys.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/kernel",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/sys/device.go b/pkg/sentry/fs/sys/device.go
new file mode 100644
index 000000000..4e79dbb71
--- /dev/null
+++ b/pkg/sentry/fs/sys/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// sysfsDevice is the sysfs virtual device.
+var sysfsDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/sys/devices.go b/pkg/sentry/fs/sys/devices.go
new file mode 100644
index 000000000..b67065956
--- /dev/null
+++ b/pkg/sentry/fs/sys/devices.go
@@ -0,0 +1,91 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// +stateify savable
+type cpunum struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeStaticFileGetter
+}
+
+var _ fs.InodeOperations = (*cpunum)(nil)
+
+func newPossible(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ var maxCore uint
+ k := kernel.KernelFromContext(ctx)
+ if k != nil {
+ maxCore = k.ApplicationCores() - 1
+ }
+ contents := []byte(fmt.Sprintf("0-%d\n", maxCore))
+
+ c := &cpunum{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.SYSFS_MAGIC),
+ InodeStaticFileGetter: fsutil.InodeStaticFileGetter{
+ Contents: contents,
+ },
+ }
+ return newFile(ctx, c, msrc)
+}
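+
+// For example, with 4 application cores newPossible produces a file whose
+// contents read "0-3\n".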
+
+func newCPU(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ m := map[string]*fs.Inode{
+ "online": newPossible(ctx, msrc),
+ "possible": newPossible(ctx, msrc),
+ "present": newPossible(ctx, msrc),
+ }
+
+ // Add directories for each of the cpus.
+ if k := kernel.KernelFromContext(ctx); k != nil {
+ for i := 0; uint(i) < k.ApplicationCores(); i++ {
+ m[fmt.Sprintf("cpu%d", i)] = newDir(ctx, msrc, nil)
+ }
+ }
+
+ return newDir(ctx, msrc, m)
+}
+
+func newSystemDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ return newDir(ctx, msrc, map[string]*fs.Inode{
+ "cpu": newCPU(ctx, msrc),
+ })
+}
+
+func newDevicesDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ return newDir(ctx, msrc, map[string]*fs.Inode{
+ "system": newSystemDir(ctx, msrc),
+ })
+}
diff --git a/pkg/sentry/fs/sys/fs.go b/pkg/sentry/fs/sys/fs.go
new file mode 100644
index 000000000..fd03a4e38
--- /dev/null
+++ b/pkg/sentry/fs/sys/fs.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// filesystem is a sysfs.
+//
+// +stateify savable
+type filesystem struct{}
+
+var _ fs.Filesystem = (*filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+// Name matches fs/sysfs/mount.c:sysfs_fs_type.name.
+const FilesystemName = "sysfs"
+
+// Name is the name of the file system.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, sysfs returns FS_USERNS_VISIBLE | FS_USERNS_MOUNT, see fs/sysfs/mount.c.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a sysfs root which can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+ // sysfs ignores data, see fs/sysfs/mount.c:sysfs_mount.
+
+ return New(ctx, fs.NewNonCachingMountSource(ctx, f, flags)), nil
+}
diff --git a/pkg/sentry/fs/sys/sys.go b/pkg/sentry/fs/sys/sys.go
new file mode 100644
index 000000000..0891645e4
--- /dev/null
+++ b/pkg/sentry/fs/sys/sys.go
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sys implements a sysfs filesystem.
+package sys
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func newFile(ctx context.Context, node fs.InodeOperations, msrc *fs.MountSource) *fs.Inode {
+ sattr := fs.StableAttr{
+ DeviceID: sysfsDevice.DeviceID(),
+ InodeID: sysfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, node, msrc, sattr)
+}
+
+func newDir(ctx context.Context, msrc *fs.MountSource, contents map[string]*fs.Inode) *fs.Inode {
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return fs.NewInode(ctx, d, msrc, fs.StableAttr{
+ DeviceID: sysfsDevice.DeviceID(),
+ InodeID: sysfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialDirectory,
+ })
+}
+
+// New returns the root node of a partial simple sysfs.
+func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ return newDir(ctx, msrc, map[string]*fs.Inode{
+ // Add a basic set of top-level directories. In Linux, these
+ // are dynamically added depending on the KConfig. Here we just
+ // add the most common ones.
+ "block": newDir(ctx, msrc, nil),
+ "bus": newDir(ctx, msrc, nil),
+ "class": newDir(ctx, msrc, map[string]*fs.Inode{
+ "power_supply": newDir(ctx, msrc, nil),
+ }),
+ "dev": newDir(ctx, msrc, nil),
+ "devices": newDevicesDir(ctx, msrc),
+ "firmware": newDir(ctx, msrc, nil),
+ "fs": newDir(ctx, msrc, nil),
+ "kernel": newDir(ctx, msrc, nil),
+ "module": newDir(ctx, msrc, nil),
+ "power": newDir(ctx, msrc, nil),
+ })
+}
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
new file mode 100644
index 000000000..d16cdb4df
--- /dev/null
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "timerfd",
+ srcs = ["timerfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel/time",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
new file mode 100644
index 000000000..88c344089
--- /dev/null
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package timerfd implements the semantics of Linux timerfd objects as
+// described by timerfd_create(2).
+package timerfd
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TimerOperations implements fs.FileOperations for timerfds.
+//
+// +stateify savable
+type TimerOperations struct {
+ fsutil.FileZeroSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ events waiter.Queue `state:"zerovalue"`
+ timer *ktime.Timer
+
+ // val is the number of timer expirations since the last successful call to
+ // Readv, Preadv, or SetTime. val is accessed using atomic memory
+ // operations.
+ val uint64
+}
+
+// NewFile returns a timerfd File that receives time from c.
+func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[timerfd]")
+ // Release the initial dirent reference after NewFile takes a reference.
+ defer dirent.DecRef()
+ tops := &TimerOperations{}
+ tops.timer = ktime.NewTimer(c, tops)
+ // Timerfds reject writes, but the Write flag must be set in order to
+ // ensure that our Writev/Pwritev methods actually get called to return
+ // the correct errors.
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, tops)
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *TimerOperations) Release() {
+ t.timer.Destroy()
+}
+
+// PauseTimer pauses the associated Timer.
+func (t *TimerOperations) PauseTimer() {
+ t.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (t *TimerOperations) ResumeTimer() {
+ t.timer.Resume()
+}
+
+// Clock returns the associated Timer's Clock.
+func (t *TimerOperations) Clock() ktime.Clock {
+ return t.timer.Clock()
+}
+
+// GetTime returns the associated Timer's setting and the time at which it was
+// observed.
+func (t *TimerOperations) GetTime() (ktime.Time, ktime.Setting) {
+ return t.timer.Get()
+}
+
+// SetTime atomically changes the associated Timer's setting, resets the number
+// of expirations to 0, and returns the previous setting and the time at which
+// it was observed.
+func (t *TimerOperations) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) {
+ return t.timer.SwapAnd(s, func() { atomic.StoreUint64(&t.val, 0) })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (t *TimerOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ if atomic.LoadUint64(&t.val) != 0 {
+ ready |= waiter.EventIn
+ }
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (t *TimerOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.events.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (t *TimerOperations) EventUnregister(e *waiter.Entry) {
+ t.events.EventUnregister(e)
+}
+
+// Read implements fs.FileOperations.Read.
+func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ const sizeofUint64 = 8
+ if dst.NumBytes() < sizeofUint64 {
+ return 0, syserror.EINVAL
+ }
+ if val := atomic.SwapUint64(&t.val, 0); val != 0 {
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
+ // Linux does not undo consuming the number of expirations even if
+ // writing to userspace fails.
+ return 0, err
+ }
+ return sizeofUint64, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// Write implements fs.FileOperations.Write.
+func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ atomic.AddUint64(&t.val, exp)
+ t.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (t *TimerOperations) Destroy() {}
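+
+// A rough usage sketch (clock acquisition and error handling are elided and
+// assumed, not shown in this file):
+//
+//   file := NewFile(ctx, clock) // clock is a ktime.Clock
+//   tops := file.FileOperations.(*TimerOperations)
+//   tops.SetTime(ktime.Setting{Enabled: true, Period: period})
+//   // Each successful Read then returns the number of expirations as a
+//   // native-endian uint64.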
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
new file mode 100644
index 000000000..aa7199014
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -0,0 +1,50 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tmpfs",
+ srcs = [
+ "device.go",
+ "file_regular.go",
+ "fs.go",
+ "inode_file.go",
+ "tmpfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/metric",
+ "//pkg/safemem",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "tmpfs_test",
+ size = "small",
+ srcs = ["file_test.go"],
+ library = ":tmpfs",
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/sentry/usage",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/tmpfs/device.go b/pkg/sentry/fs/tmpfs/device.go
new file mode 100644
index 000000000..ae7c55ee1
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// tmpfsDevice is the kernel tmpfs device.
+var tmpfsDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/tmpfs/file_regular.go b/pkg/sentry/fs/tmpfs/file_regular.go
new file mode 100644
index 000000000..614f8f8a1
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/file_regular.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// regularFileOperations implements fs.FileOperations for a regular
+// tmpfs file.
+//
+// +stateify savable
+type regularFileOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // iops is the InodeOperations of a regular tmpfs file. It is
+ // guaranteed to be the same as file.Dirent.Inode.InodeOperations;
+ // see the operations that take an fs.File below.
+ iops *fileInodeOperations
+}
+
+// Read implements fs.FileOperations.Read.
+func (r *regularFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ return r.iops.read(ctx, file, dst, offset)
+}
+
+// Write implements fs.FileOperations.Write.
+func (r *regularFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ return r.iops.write(ctx, src, offset)
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (r *regularFileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ return fsutil.GenericConfigureMMap(file, r.iops, opts)
+}
diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go
new file mode 100644
index 000000000..aaba35502
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/file_test.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func newFileInode(ctx context.Context) *fs.Inode {
+ m := fs.NewCachingMountSource(ctx, &Filesystem{}, fs.MountSourceFlags{})
+ iops := NewInMemoryFile(ctx, usage.Tmpfs, fs.WithCurrentTime(ctx, fs.UnstableAttr{}))
+ return fs.NewInode(ctx, iops, m, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.RegularFile,
+ })
+}
+
+func newFile(ctx context.Context) *fs.File {
+ inode := newFileInode(ctx)
+ f, _ := inode.GetFile(ctx, fs.NewDirent(ctx, inode, "stub"), fs.FileFlags{Read: true, Write: true})
+ return f
+}
+
+// Write twice at adjacent offsets, then read everything back.
+func TestGrow(t *testing.T) {
+ ctx := contexttest.Context(t)
+ f := newFile(ctx)
+ defer f.DecRef()
+
+ abuf := bytes.Repeat([]byte{'a'}, 68)
+ n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0)
+ if n != int64(len(abuf)) || err != nil {
+ t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(abuf))
+ }
+
+ bbuf := bytes.Repeat([]byte{'b'}, 856)
+ n, err = f.Pwritev(ctx, usermem.BytesIOSequence(bbuf), 68)
+ if n != int64(len(bbuf)) || err != nil {
+ t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(bbuf))
+ }
+
+ rbuf := make([]byte, len(abuf)+len(bbuf))
+ n, err = f.Preadv(ctx, usermem.BytesIOSequence(rbuf), 0)
+ if n != int64(len(rbuf)) || err != nil {
+ t.Fatalf("Preadv got (%d, %v) want (%d, nil)", n, err, len(rbuf))
+ }
+
+ if want := append(abuf, bbuf...); !bytes.Equal(rbuf, want) {
+ t.Fatalf("Read %v, want %v", rbuf, want)
+ }
+}
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
new file mode 100644
index 000000000..bc117ca6a
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+const (
+ // Set initial permissions for the root directory.
+ modeKey = "mode"
+
+ // UID for the root directory.
+ rootUIDKey = "uid"
+
+ // GID for the root directory.
+ rootGIDKey = "gid"
+
+ // cacheKey sets the caching policy for the mount.
+ cacheKey = "cache"
+
+ // cacheAll uses the virtual file system cache for everything (default).
+ cacheAll = "cache"
+
+ // cacheRevalidate allows dirents to be cached, but revalidates them on each
+ // lookup.
+ cacheRevalidate = "revalidate"
+
+ // Permissions that exceed modeMask will be rejected.
+ modeMask = 01777
+
+ // Default permissions are read/write/execute.
+ defaultMode = 0777
+)
+
+// Filesystem is a tmpfs.
+//
+// +stateify savable
+type Filesystem struct{}
+
+var _ fs.Filesystem = (*Filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&Filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+// Name matches mm/shmem.c:shmem_fs_type.name.
+const FilesystemName = "tmpfs"
+
+// Name is the name of the file system.
+func (*Filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*Filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*Filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, tmpfs returns FS_USERNS_MOUNT, see mm/shmem.c.
+func (*Filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a tmpfs root that can be positioned in the vfs.
+func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+
+ // Parse the generic comma-separated key=value options; this file system expects them.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Parse the root directory permissions.
+ perms := fs.FilePermsFromMode(defaultMode)
+ if m, ok := options[modeKey]; ok {
+ i, err := strconv.ParseUint(m, 8, 32)
+ if err != nil {
+ return nil, fmt.Errorf("mode value not parsable 'mode=%s': %v", m, err)
+ }
+ if i&^modeMask != 0 {
+ return nil, fmt.Errorf("invalid mode %q: must be less than %o", m, modeMask)
+ }
+ perms = fs.FilePermsFromMode(linux.FileMode(i))
+ delete(options, modeKey)
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ owner := fs.FileOwnerFromContext(ctx)
+ if uidstr, ok := options[rootUIDKey]; ok {
+ uid, err := strconv.ParseInt(uidstr, 10, 32)
+ if err != nil {
+ return nil, fmt.Errorf("uid value not parsable 'uid=%d': %v", uid, err)
+ }
+ owner.UID = creds.UserNamespace.MapToKUID(auth.UID(uid))
+ delete(options, rootUIDKey)
+ }
+
+ if gidstr, ok := options[rootGIDKey]; ok {
+ gid, err := strconv.ParseInt(gidstr, 10, 32)
+ if err != nil {
+ return nil, fmt.Errorf("gid value not parsable 'gid=%d': %v", gid, err)
+ }
+ owner.GID = creds.UserNamespace.MapToKGID(auth.GID(gid))
+ delete(options, rootGIDKey)
+ }
+
+ // Construct a mount which will follow the cache options provided.
+ //
+ // TODO(gvisor.dev/issue/179): There should be no reason to disable
+ // caching once bind mounts are properly supported.
+ var msrc *fs.MountSource
+ switch options[cacheKey] {
+ case "", cacheAll:
+ msrc = fs.NewCachingMountSource(ctx, f, flags)
+ case cacheRevalidate:
+ msrc = fs.NewRevalidatingMountSource(ctx, f, flags)
+ default:
+ return nil, fmt.Errorf("invalid cache policy option %q", options[cacheKey])
+ }
+ delete(options, cacheKey)
+
+ // Fail if the caller passed us more options than we can parse. They may be
+ // expecting us to set something we can't set.
+ if len(options) > 0 {
+ return nil, fmt.Errorf("unsupported mount options: %v", options)
+ }
+
+ // Construct the tmpfs root.
+ return NewDir(ctx, nil, owner, perms, msrc), nil
+}
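+
+// For example, mount data of "mode=1777,uid=0,gid=0,cache=revalidate"
+// (illustrative values) parses cleanly above, while any unrecognized key
+// causes Mount to fail with "unsupported mount options".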
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
new file mode 100644
index 000000000..1dc75291d
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -0,0 +1,687 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
+ opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
+ reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
+ readWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
+)
+
+// fileInodeOperations implements fs.InodeOperations for a regular tmpfs file.
+// These files are backed by pages allocated from Kernel.MemoryFile(), and
+// may be directly mapped.
+//
+// Lock order: attrMu -> mapsMu -> dataMu.
+//
+// +stateify savable
+type fileInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+
+ fsutil.InodeSimpleExtendedAttributes
+
+ // kernel is used to allocate memory that stores the file's contents.
+ kernel *kernel.Kernel
+
+ // memUsage is the default memory usage that will be reported by this file.
+ memUsage usage.MemoryKind
+
+ attrMu sync.Mutex `state:"nosave"`
+
+ // attr contains the unstable metadata for the file.
+ //
+ // attr is protected by attrMu. attr.Size is protected by both attrMu
+ // and dataMu; reading it requires locking either mutex, while mutating
+ // it requires locking both.
+ attr fs.UnstableAttr
+
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ dataMu sync.RWMutex `state:"nosave"`
+
+ // data maps offsets into the file to offsets into Kernel.MemoryFile()
+ // that store the file's data.
+ //
+ // data is protected by dataMu.
+ data fsutil.FileRangeSet
+
+ // seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
+ seals uint32
+}
+
+var _ fs.InodeOperations = (*fileInodeOperations)(nil)
+
+// NewInMemoryFile returns a new file backed by Kernel.MemoryFile().
+func NewInMemoryFile(ctx context.Context, usage usage.MemoryKind, uattr fs.UnstableAttr) fs.InodeOperations {
+ return &fileInodeOperations{
+ attr: uattr,
+ kernel: kernel.KernelFromContext(ctx),
+ memUsage: usage,
+ seals: linux.F_SEAL_SEAL,
+ }
+}
+
+// NewMemfdInode creates a new inode backing a memfd. Memory used by the memfd
+// is backed by platform memory.
+func NewMemfdInode(ctx context.Context, allowSeals bool) *fs.Inode {
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
+ // S_IRWXUGO.
+ perms := fs.PermMask{Read: true, Write: true, Execute: true}
+ iops := NewInMemoryFile(ctx, usage.Tmpfs, fs.UnstableAttr{
+ Owner: fs.FileOwnerFromContext(ctx),
+ Perms: fs.FilePermissions{User: perms, Group: perms, Other: perms}}).(*fileInodeOperations)
+ if allowSeals {
+ iops.seals = 0
+ }
+ return fs.NewInode(ctx, iops, fs.NewNonCachingMountSource(ctx, nil, fs.MountSourceFlags{}), fs.StableAttr{
+ Type: fs.RegularFile,
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (f *fileInodeOperations) Release(context.Context) {
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+ f.data.DropAll(f.kernel.MemoryFile())
+}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (f *fileInodeOperations) Mappable(*fs.Inode) memmap.Mappable {
+ return f
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ if flags.Write {
+ opensW.Increment()
+ } else if flags.Read {
+ opensRO.Increment()
+ }
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, d, flags, &regularFileOperations{iops: f}), nil
+}
+
+// UnstableAttr returns unstable attributes of this tmpfs file.
+func (f *fileInodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ f.attrMu.Lock()
+ f.dataMu.RLock()
+ attr := f.attr
+ attr.Usage = int64(f.data.Span())
+ f.dataMu.RUnlock()
+ f.attrMu.Unlock()
+ return attr, nil
+}
+
+// Check implements fs.InodeOperations.Check.
+func (f *fileInodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (f *fileInodeOperations) SetPermissions(ctx context.Context, _ *fs.Inode, p fs.FilePermissions) bool {
+ f.attrMu.Lock()
+ f.attr.SetPermissions(ctx, p)
+ f.attrMu.Unlock()
+ return true
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (f *fileInodeOperations) SetTimestamps(ctx context.Context, _ *fs.Inode, ts fs.TimeSpec) error {
+ f.attrMu.Lock()
+ f.attr.SetTimestamps(ctx, ts)
+ f.attrMu.Unlock()
+ return nil
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (f *fileInodeOperations) SetOwner(ctx context.Context, _ *fs.Inode, owner fs.FileOwner) error {
+ f.attrMu.Lock()
+ f.attr.SetOwner(ctx, owner)
+ f.attrMu.Unlock()
+ return nil
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (f *fileInodeOperations) Truncate(ctx context.Context, _ *fs.Inode, size int64) error {
+ f.attrMu.Lock()
+ defer f.attrMu.Unlock()
+
+ f.dataMu.Lock()
+ oldSize := f.attr.Size
+
+ // Check if current seals allow truncation.
+ switch {
+ case size > oldSize && f.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ fallthrough
+ case oldSize > size && f.seals&linux.F_SEAL_SHRINK != 0: // Shrink sealed
+ f.dataMu.Unlock()
+ return syserror.EPERM
+ }
+
+ if oldSize != size {
+ f.attr.Size = size
+ // Update mtime and ctime.
+ now := ktime.NowFromContext(ctx)
+ f.attr.ModificationTime = now
+ f.attr.StatusChangeTime = now
+ }
+ f.dataMu.Unlock()
+
+ // Nothing left to do unless shrinking the file.
+ if oldSize <= size {
+ return nil
+ }
+
+ oldpgend := fs.OffsetPageEnd(oldSize)
+ newpgend := fs.OffsetPageEnd(size)
+
+ // Invalidate past translations of truncated pages.
+ if newpgend != oldpgend {
+ f.mapsMu.Lock()
+ f.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ f.mapsMu.Unlock()
+ }
+
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+ f.data.Truncate(uint64(size), f.kernel.MemoryFile())
+
+ return nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (f *fileInodeOperations) Allocate(ctx context.Context, _ *fs.Inode, offset, length int64) error {
+ newSize := offset + length
+
+ f.attrMu.Lock()
+ defer f.attrMu.Unlock()
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+
+ if newSize <= f.attr.Size {
+ return nil
+ }
+
+ // Check if current seals allow growth.
+ if f.seals&linux.F_SEAL_GROW != 0 {
+ return syserror.EPERM
+ }
+
+ f.attr.Size = newSize
+
+ now := ktime.NowFromContext(ctx)
+ f.attr.ModificationTime = now
+ f.attr.StatusChangeTime = now
+
+ return nil
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (f *fileInodeOperations) AddLink() {
+ f.attrMu.Lock()
+ f.attr.Links++
+ f.attrMu.Unlock()
+}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (f *fileInodeOperations) DropLink() {
+ f.attrMu.Lock()
+ f.attr.Links--
+ f.attrMu.Unlock()
+}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+func (f *fileInodeOperations) NotifyStatusChange(ctx context.Context) {
+ f.attrMu.Lock()
+ f.attr.StatusChangeTime = ktime.NowFromContext(ctx)
+ f.attrMu.Unlock()
+}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (*fileInodeOperations) IsVirtual() bool {
+ return true
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (*fileInodeOperations) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ var start time.Time
+ if fs.RecordWaitTime {
+ start = time.Now()
+ }
+ reads.Increment()
+ // Zero length reads for tmpfs are no-ops.
+ if dst.NumBytes() == 0 {
+ fs.IncrementWait(readWait, start)
+ return 0, nil
+ }
+
+ // Have we reached EOF? We check for this again in
+ // fileReadWriter.ReadToBlocks to avoid holding f.attrMu (which would
+ // serialize reads) or f.dataMu (which would violate lock ordering), but
+ // check here first (before calling into MM) since reading at EOF is
+ // common: getting a return value of 0 from a read syscall is the only way
+ // to detect EOF.
+ //
+ // TODO(jamieliu): Separate out f.attr.Size and use atomics instead of
+ // f.dataMu.
+ f.dataMu.RLock()
+ size := f.attr.Size
+ f.dataMu.RUnlock()
+ if offset >= size {
+ fs.IncrementWait(readWait, start)
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOutFrom(ctx, &fileReadWriter{f, offset})
+ if !file.Dirent.Inode.MountSource.Flags.NoAtime {
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ f.attrMu.Lock()
+ f.attr.AccessTime = ktime.NowFromContext(ctx)
+ f.attrMu.Unlock()
+ }
+ fs.IncrementWait(readWait, start)
+ return n, err
+}
+
+func (f *fileInodeOperations) write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // Zero length writes for tmpfs are no-ops.
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ f.attrMu.Lock()
+ defer f.attrMu.Unlock()
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() => file_update_time().
+ now := ktime.NowFromContext(ctx)
+ f.attr.ModificationTime = now
+ f.attr.StatusChangeTime = now
+ return src.CopyInTo(ctx, &fileReadWriter{f, offset})
+}
+
+type fileReadWriter struct {
+ f *fileInodeOperations
+ offset int64
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.f.dataMu.RLock()
+ defer rw.f.dataMu.RUnlock()
+
+ // Compute the range to read.
+ if rw.offset >= rw.f.attr.Size {
+ return 0, io.EOF
+ }
+ end := fs.ReadEndOffset(rw.offset, int64(dsts.NumBytes()), rw.f.attr.Size)
+ if end == rw.offset { // dsts.NumBytes() == 0?
+ return 0, nil
+ }
+
+ mf := rw.f.kernel.MemoryFile()
+ var done uint64
+ seg, gap := rw.f.data.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+
+ default:
+ // Unreachable while rw.offset < end: Find and NextNonEmpty
+ // always yield a segment or gap covering the next offset.
+ // Note that break exits only this switch.
+ break
+ }
+ }
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ rw.f.dataMu.Lock()
+ defer rw.f.dataMu.Unlock()
+
+ // Compute the range to write.
+ if srcs.NumBytes() == 0 {
+ // Nothing to do.
+ return 0, nil
+ }
+ end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
+ if end == math.MaxInt64 {
+ // Overflow.
+ return 0, syserror.EINVAL
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.f.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ return 0, syserror.EPERM
+ case end > rw.f.attr.Size && rw.f.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artificially truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
+ if pgstart := int64(usermem.Addr(rw.f.attr.Size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.offset {
+ // Truncation would result in no data being written.
+ return 0, syserror.EPERM
+ }
+ }
+
+ defer func() {
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.offset > rw.f.attr.Size {
+ rw.f.attr.Size = rw.offset
+ }
+ }()
+
+ mf := rw.f.kernel.MemoryFile()
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.offset).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var done uint64
+ seg, gap := rw.f.data.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ return done, err
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.offset += int64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := mf.Allocate(gapMR.Length(), rw.f.memUsage)
+ if err != nil {
+ return done, err
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+
+ default:
+ // Unreachable while rw.offset < end: Find and NextNonEmpty
+ // always yield a segment or gap covering the next offset.
+ // Note that break exits only this switch.
+ break
+ }
+ }
+ return done, nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ f.dataMu.RLock()
+ defer f.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if f.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ f.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := f.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ f.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if f.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ f.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := f.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ f.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if f.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (f *fileInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return f.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (f *fileInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(f.attr.Size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := f.kernel.MemoryFile()
+ cerr := f.data.Fill(ctx, required, optional, mf, f.memUsage, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := f.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (f *fileInodeOperations) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// GetSeals returns the current set of seals on a memfd inode.
+func GetSeals(inode *fs.Inode) (uint32, error) {
+ if f, ok := inode.InodeOperations.(*fileInodeOperations); ok {
+ f.dataMu.RLock()
+ defer f.dataMu.RUnlock()
+ return f.seals, nil
+ }
+ // Not a memfd inode.
+ return 0, syserror.EINVAL
+}
+
+// AddSeals adds new file seals to a memfd inode.
+func AddSeals(inode *fs.Inode, val uint32) error {
+ if f, ok := inode.InodeOperations.(*fileInodeOperations); ok {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+
+ if f.seals&linux.F_SEAL_SEAL != 0 {
+ // Seal applied which prevents addition of any new seals.
+ return syserror.EPERM
+ }
+
+ // F_SEAL_WRITE can only be added if there are no active writable maps.
+ if f.seals&linux.F_SEAL_WRITE == 0 && val&linux.F_SEAL_WRITE != 0 {
+ if f.writableMappingPages > 0 {
+ return syserror.EBUSY
+ }
+ }
+
+ // Seals can only be added, never removed.
+ f.seals |= val
+ return nil
+ }
+ // Not a memfd inode.
+ return syserror.EINVAL
+}
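+
+// As a sketch of the expected memfd flow (syscall plumbing omitted;
+// illustrative only):
+//
+//   inode := NewMemfdInode(ctx, true /* allowSeals */)
+//   err := AddSeals(inode, linux.F_SEAL_GROW|linux.F_SEAL_SHRINK)
+//   seals, _ := GetSeals(inode) // now reports GROW and SHRINK (plus any prior seals)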
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
new file mode 100644
index 000000000..b095312fe
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -0,0 +1,356 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tmpfs is a filesystem implementation backed by memory.
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var fsInfo = fs.Info{
+ Type: linux.TMPFS_MAGIC,
+
+ // TODO(b/29637826): allow configuring a tmpfs size and enforce it.
+ TotalBlocks: 0,
+ FreeBlocks: 0,
+}
+
+// rename implements fs.InodeOperations.Rename for tmpfs nodes.
+func rename(ctx context.Context, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ // Don't allow renames across different mounts.
+ if newParent.MountSource != oldParent.MountSource {
+ return syserror.EXDEV
+ }
+
+ op := oldParent.InodeOperations.(*Dir)
+ np := newParent.InodeOperations.(*Dir)
+ return ramfs.Rename(ctx, op.ramfsDir, oldName, np.ramfsDir, newName, replacement)
+}
+
+// Dir is a directory.
+//
+// +stateify savable
+type Dir struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeIsDirTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ // Ideally this would be embedded, so that we "inherit" all of the
+ // InodeOperations implemented by ramfs.Dir for free.
+ //
+ // However, ramfs.dirFileOperations stores a pointer to a ramfs.Dir,
+ // and our save/restore package does not allow saving a pointer to an
+ // embedded field elsewhere.
+ //
+ // Thus, we must make ramfs.Dir a field and delegate all of the
+ // InodeOperations methods to it.
+ ramfsDir *ramfs.Dir
+
+ // kernel is used to allocate memory as storage for tmpfs Files.
+ kernel *kernel.Kernel
+}
+
+var _ fs.InodeOperations = (*Dir)(nil)
+
+// NewDir returns a new directory.
+func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+ d := &Dir{
+ ramfsDir: ramfs.NewDir(ctx, contents, owner, perms),
+ kernel: kernel.KernelFromContext(ctx),
+ }
+
+ // Manually set the CreateOps.
+ d.ramfsDir.CreateOps = d.newCreateOps()
+
+ return fs.NewInode(ctx, d, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+// afterLoad is invoked by stateify.
+func (d *Dir) afterLoad() {
+ // Per NewDir, manually set the CreateOps.
+ d.ramfsDir.CreateOps = d.newCreateOps()
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (d *Dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return d.ramfsDir.GetFile(ctx, dirent, flags)
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (d *Dir) AddLink() {
+ d.ramfsDir.AddLink()
+}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (d *Dir) DropLink() {
+ d.ramfsDir.DropLink()
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (d *Dir) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Dirent, error) {
+ return d.ramfsDir.Bind(ctx, dir, name, ep, perms)
+}
+
+// Create implements fs.InodeOperations.Create.
+func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perms fs.FilePermissions) (*fs.File, error) {
+ return d.ramfsDir.Create(ctx, dir, name, flags, perms)
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (d *Dir) CreateLink(ctx context.Context, dir *fs.Inode, oldname, newname string) error {
+ return d.ramfsDir.CreateLink(ctx, dir, oldname, newname)
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+func (d *Dir) CreateHardLink(ctx context.Context, dir *fs.Inode, target *fs.Inode, name string) error {
+ return d.ramfsDir.CreateHardLink(ctx, dir, target, name)
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (d *Dir) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ return d.ramfsDir.CreateDirectory(ctx, dir, name, perms)
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ return d.ramfsDir.CreateFifo(ctx, dir, name, perms)
+}
+
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (d *Dir) GetXattr(ctx context.Context, i *fs.Inode, name string, size uint64) (string, error) {
+ return d.ramfsDir.GetXattr(ctx, i, name, size)
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (d *Dir) SetXattr(ctx context.Context, i *fs.Inode, name, value string, flags uint32) error {
+ return d.ramfsDir.SetXattr(ctx, i, name, value, flags)
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (d *Dir) ListXattr(ctx context.Context, i *fs.Inode, size uint64) (map[string]struct{}, error) {
+ return d.ramfsDir.ListXattr(ctx, i, size)
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (d *Dir) RemoveXattr(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.RemoveXattr(ctx, i, name)
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (d *Dir) Lookup(ctx context.Context, i *fs.Inode, p string) (*fs.Dirent, error) {
+ return d.ramfsDir.Lookup(ctx, i, p)
+}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+func (d *Dir) NotifyStatusChange(ctx context.Context) {
+ d.ramfsDir.NotifyStatusChange(ctx)
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (d *Dir) Remove(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.Remove(ctx, i, name)
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (d *Dir) RemoveDirectory(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.RemoveDirectory(ctx, i, name)
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (d *Dir) UnstableAttr(ctx context.Context, i *fs.Inode) (fs.UnstableAttr, error) {
+ return d.ramfsDir.UnstableAttr(ctx, i)
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (d *Dir) SetPermissions(ctx context.Context, i *fs.Inode, p fs.FilePermissions) bool {
+ return d.ramfsDir.SetPermissions(ctx, i, p)
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (d *Dir) SetOwner(ctx context.Context, i *fs.Inode, owner fs.FileOwner) error {
+ return d.ramfsDir.SetOwner(ctx, i, owner)
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (d *Dir) SetTimestamps(ctx context.Context, i *fs.Inode, ts fs.TimeSpec) error {
+ return d.ramfsDir.SetTimestamps(ctx, i, ts)
+}
+
+// newCreateOps builds the custom CreateOps for this Dir.
+func (d *Dir) newCreateOps() *ramfs.CreateOps {
+ return &ramfs.CreateOps{
+ NewDir: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ },
+ NewFile: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
+ Owner: fs.FileOwnerFromContext(ctx),
+ Perms: perms,
+ // Always start unlinked.
+ Links: 0,
+ })
+ iops := NewInMemoryFile(ctx, usage.Tmpfs, uattr)
+ return fs.NewInode(ctx, iops, dir.MountSource, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.RegularFile,
+ }), nil
+ },
+ NewSymlink: func(ctx context.Context, dir *fs.Inode, target string) (*fs.Inode, error) {
+ return NewSymlink(ctx, target, fs.FileOwnerFromContext(ctx), dir.MountSource), nil
+ },
+ NewBoundEndpoint: func(ctx context.Context, dir *fs.Inode, socket transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error) {
+ return NewSocket(ctx, socket, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ },
+ NewFifo: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ return NewFifo(ctx, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ },
+ }
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (d *Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (*Dir) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (d *Dir) Allocate(ctx context.Context, node *fs.Inode, offset, length int64) error {
+ return d.ramfsDir.Allocate(ctx, node, offset, length)
+}
+
+// Release implements fs.InodeOperations.Release.
+func (d *Dir) Release(ctx context.Context) {
+ d.ramfsDir.Release(ctx)
+}
+
+// Symlink is a symlink.
+//
+// +stateify savable
+type Symlink struct {
+ ramfs.Symlink
+}
+
+// NewSymlink returns a new symlink.
+func NewSymlink(ctx context.Context, target string, owner fs.FileOwner, msrc *fs.MountSource) *fs.Inode {
+ s := &Symlink{Symlink: *ramfs.NewSymlink(ctx, owner, target)}
+ return fs.NewInode(ctx, s, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Symlink,
+ })
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (s *Symlink) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS returns the tmpfs info.
+func (s *Symlink) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+// Socket is a socket.
+//
+// +stateify savable
+type Socket struct {
+ ramfs.Socket
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+}
+
+// NewSocket returns a new socket with the provided permissions.
+func NewSocket(ctx context.Context, socket transport.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+ s := &Socket{Socket: *ramfs.NewSocket(ctx, socket, owner, perms)}
+ return fs.NewInode(ctx, s, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Socket,
+ })
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (s *Socket) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS returns the tmpfs info.
+func (s *Socket) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+// Fifo is a tmpfs named pipe.
+//
+// +stateify savable
+type Fifo struct {
+ fs.InodeOperations
+}
+
+// NewFifo creates a new named pipe.
+func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+ // First create a pipe.
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+
+ // Build pipe InodeOperations.
+ iops := pipe.NewInodeOperations(ctx, perms, p)
+
+ // Wrap the iops with our Fifo.
+ fifoIops := &Fifo{iops}
+
+ // Build a new Inode.
+ return fs.NewInode(ctx, fifoIops, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Pipe,
+ })
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (f *Fifo) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS returns the tmpfs info.
+func (*Fifo) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
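
Because Dir delegates to ramfs.Dir rather than embedding it, external callers interact only with the constructors above. A hedged sketch of assembling a small tmpfs tree with them; ctx and msrc are assumed to come from the surrounding mount machinery, and buildExampleRoot is a hypothetical helper, not part of this patch:

// buildExampleRoot wires a symlink into a fresh tmpfs root directory.
// It reuses this package's imports (context, fs); illustration only.
func buildExampleRoot(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
	owner := fs.FileOwnerFromContext(ctx)
	perms := fs.FilePermsFromMode(0755)

	// Initial children may be supplied up front; later files arrive via
	// the CreateOps installed by NewDir.
	contents := map[string]*fs.Inode{
		"link": NewSymlink(ctx, "/target", owner, msrc),
	}
	return NewDir(ctx, contents, owner, perms, msrc)
}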
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
new file mode 100644
index 000000000..5cb0e0417
--- /dev/null
+++ b/pkg/sentry/fs/tty/BUILD
@@ -0,0 +1,47 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tty",
+ srcs = [
+ "dir.go",
+ "fs.go",
+ "line_discipline.go",
+ "master.go",
+ "queue.go",
+ "slave.go",
+ "terminal.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/unimpl",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "tty_test",
+ size = "small",
+ srcs = ["tty_test.go"],
+ library = ":tty",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
new file mode 100644
index 000000000..108654827
--- /dev/null
+++ b/pkg/sentry/fs/tty/dir.go
@@ -0,0 +1,342 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tty provides pseudoterminals via a devpts filesystem.
+package tty
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// dirInodeOperations is the root of a devpts mount.
+//
+// This indirectly manages all terminals within the mount.
+//
+// New Terminals are created by masterInodeOperations.GetFile, which registers
+// the slave Inode in this directory for discovery via Lookup/Readdir. The
+// slave inode is unregistered when the master file is Released, as the slave
+// is no longer discoverable at that point.
+//
+// References on the underlying Terminal are held by masterFileOperations and
+// slaveInodeOperations.
+//
+// masterInodeOperations and slaveInodeOperations hold a pointer to
+// dirInodeOperations, which is kept alive by the references their
+// corresponding Dirents hold on their parent (this directory).
+//
+// dirInodeOperations implements fs.InodeOperations.
+//
+// +stateify savable
+type dirInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeIsDirAllocate `state:"nosave"`
+ fsutil.InodeIsDirTruncate `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotRenameable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+
+ // msrc is the super block this directory is on.
+ //
+ // TODO(chrisko): Plumb this through instead of storing it here.
+ msrc *fs.MountSource
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // master is the master PTY inode.
+ master *fs.Inode
+
+ // slaves contains the slave inodes reachable from the directory.
+ //
+ // A new slave is added by allocateTerminal and is removed by
+ // masterFileOperations.Release.
+ //
+ // A reference is held on every slave in the map.
+ slaves map[uint32]*fs.Inode
+
+ // dentryMap is a SortedDentryMap used to implement Readdir containing
+ // the master and all entries in slaves.
+ dentryMap *fs.SortedDentryMap
+
+ // next is the next pty index to use.
+ //
+ // TODO(b/29356795): reuse indices when ptys are closed.
+ next uint32
+}
+
+var _ fs.InodeOperations = (*dirInodeOperations)(nil)
+
+// newDir creates a new dir with a ptmx file and no terminals.
+func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
+ d := &dirInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0555), linux.DEVPTS_SUPER_MAGIC),
+ msrc: m,
+ slaves: make(map[uint32]*fs.Inode),
+ dentryMap: fs.NewSortedDentryMap(nil),
+ }
+ // Linux devpts uses a default mode of 0000 for ptmx, which can be
+ // changed with the ptmxmode mount option. That default is not useful
+ // here, since we would then *always* need the mount option; instead,
+ // make ptmx accessible by default.
+ d.master = newMasterInode(ctx, d, fs.RootOwner, fs.FilePermsFromMode(0666))
+ d.dentryMap.Add("ptmx", fs.DentAttr{
+ Type: d.master.StableAttr.Type,
+ InodeID: d.master.StableAttr.InodeID,
+ })
+
+ return fs.NewInode(ctx, d, m, fs.StableAttr{
+ DeviceID: ptsDevice.DeviceID(),
+ // N.B. Linux always uses inode id 1 for the directory. See
+ // fs/devpts/inode.c:devpts_fill_super.
+ //
+ // TODO(b/75267214): Since ptsDevice must be shared between
+ // different mounts, we must not assign fixed numbers.
+ InodeID: ptsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (d *dirInodeOperations) Release(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ d.master.DecRef()
+ if len(d.slaves) != 0 {
+ panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
+ }
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Master?
+ if name == "ptmx" {
+ d.master.IncRef()
+ return fs.NewDirent(ctx, d.master, name), nil
+ }
+
+ // Slave number?
+ n, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ // Not found.
+ return nil, syserror.ENOENT
+ }
+
+ s, ok := d.slaves[uint32(n)]
+ if !ok {
+ return nil, syserror.ENOENT
+ }
+
+ s.IncRef()
+ return fs.NewDirent(ctx, s, name), nil
+}
+
+// Create implements fs.InodeOperations.Create.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
+ return nil, syserror.EACCES
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ return syserror.EACCES
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname, newname string) error {
+ return syserror.EACCES
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateHardLink(ctx context.Context, dir *fs.Inode, target *fs.Inode, name string) error {
+ return syserror.EACCES
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ return syserror.EACCES
+}
+
+// Remove implements fs.InodeOperations.Remove.
+//
+// Removal is never allowed.
+func (d *dirInodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
+ return syserror.EPERM
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+//
+// Removal is never allowed.
+func (d *dirInodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
+ return syserror.EPERM
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (d *dirInodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) {
+ return nil, syserror.EPERM
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (d *dirInodeOperations) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &dirFileOperations{di: d}), nil
+}
+
+// allocateTerminal creates a new Terminal and installs a pts node for it.
+//
+// The caller must call DecRef when done with the returned Terminal.
+func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ n := d.next
+ if n == math.MaxUint32 {
+ return nil, syserror.ENOMEM
+ }
+
+ if _, ok := d.slaves[n]; ok {
+ panic(fmt.Sprintf("pty index collision; index %d already exists", n))
+ }
+
+ t := newTerminal(ctx, d, n)
+ d.next++
+
+ // The reference returned by newTerminal is returned to the caller.
+ // Take another for the slave inode.
+ t.IncRef()
+
+ // Create a pts node. The owner is based on the context that opens
+ // ptmx.
+ creds := auth.CredentialsFromContext(ctx)
+ uid, gid := creds.EffectiveKUID, creds.EffectiveKGID
+ slave := newSlaveInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666))
+
+ d.slaves[n] = slave
+ d.dentryMap.Add(strconv.FormatUint(uint64(n), 10), fs.DentAttr{
+ Type: slave.StableAttr.Type,
+ InodeID: slave.StableAttr.InodeID,
+ })
+
+ return t, nil
+}
+
+// masterClose is called when the master end of t is closed.
+func (d *dirInodeOperations) masterClose(t *Terminal) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // The slave end disappears from the directory when the master end is
+ // closed, even if the slave end is open elsewhere.
+ //
+ // N.B. since we're using a backdoor method to remove a directory entry
+ // we won't properly fire inotify events like Linux would.
+ s, ok := d.slaves[t.n]
+ if !ok {
+ panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d))
+ }
+
+ s.DecRef()
+ delete(d.slaves, t.n)
+ d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10))
+}
+
+// dirFileOperations are the fs.FileOperations for the directory.
+//
+// This is nearly identical to fsutil.DirFileOperations, except that it takes
+// df.di.mu in IterateDir.
+//
+// +stateify savable
+type dirFileOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // di is the inode operations.
+ di *dirInodeOperations
+
+ // dirCursor contains the name of the last directory entry that was
+ // serialized.
+ dirCursor string
+}
+
+var _ fs.FileOperations = (*dirFileOperations)(nil)
+
+// IterateDir implements DirIterator.IterateDir.
+func (df *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) {
+ df.di.mu.Lock()
+ defer df.di.mu.Unlock()
+
+ n, err := fs.GenericReaddir(dirCtx, df.di.dentryMap)
+ return offset + n, err
+}
+
+// Readdir implements FileOperations.Readdir.
+func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &df.dirCursor,
+ }
+ return fs.DirentReaddir(ctx, file.Dirent, df, root, dirCtx, file.Offset())
+}
+
+// Read implements FileOperations.Read.
+func (df *dirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Write implements FileOperations.Write.
+func (df *dirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
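
From userspace, allocateTerminal and masterClose appear as the usual ptmx lifecycle: opening the master allocates an index, and the slave entry vanishes from the directory when the master closes. A hedged sketch with golang.org/x/sys/unix (assumed available; error handling trimmed):

// Sketch of the pty lifecycle that dirInodeOperations implements.
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	master, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
	if err != nil {
		panic(err)
	}
	defer master.Close() // removes the slave entry from the directory

	// TIOCGPTN reports the index chosen by allocateTerminal.
	n, err := unix.IoctlGetUint32(int(master.Fd()), unix.TIOCGPTN)
	if err != nil {
		panic(err)
	}
	// The slave remains discoverable via Lookup/Readdir until the
	// master is released.
	fmt.Printf("slave: /dev/pts/%d\n", n)
}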
diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go
new file mode 100644
index 000000000..8fe05ebe5
--- /dev/null
+++ b/pkg/sentry/fs/tty/fs.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ptsDevice is the pseudo-filesystem device.
+var ptsDevice = device.NewAnonDevice()
+
+// filesystem is a devpts filesystem.
+//
+// This devpts is always in the new "multi-instance" mode; i.e., it contains a
+// ptmx device tied to this mount.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// Name matches fs/devpts/inode.c:devpts_fs_type.name.
+func (*filesystem) Name() string {
+ return "devpts"
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ // TODO(b/29356795): Users may mount this once the terminals are in a
+ // usable state.
+ return false
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns 0, as there is nothing special about this file system.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a devpts root that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+
+ // No options are supported.
+ if data != "" {
+ return nil, syserror.EINVAL
+ }
+
+ return newDir(ctx, fs.NewMountSource(ctx, &superOperations{}, f, flags)), nil
+}
+
+// superOperations implements fs.MountSourceOperations, preventing caching.
+//
+// +stateify savable
+type superOperations struct{}
+
+// Revalidate implements fs.DirentOperations.Revalidate.
+//
+// It always returns true, forcing a Lookup for all entries.
+//
+// Slave entries are dropped from the directory when their master is closed, so an
+// existing slave Dirent in the tree is not sufficient to guarantee that it
+// still exists on the filesystem.
+func (superOperations) Revalidate(context.Context, string, *fs.Inode, *fs.Inode) bool {
+ return true
+}
+
+// Keep implements fs.DirentOperations.Keep.
+//
+// Keep returns false because Revalidate would force a lookup on cached entries
+// anyway.
+func (superOperations) Keep(*fs.Dirent) bool {
+ return false
+}
+
+// CacheReaddir implements fs.DirentOperations.CacheReaddir.
+//
+// CacheReaddir returns false because entries change on master operations.
+func (superOperations) CacheReaddir() bool {
+ return false
+}
+
+// ResetInodeMappings implements MountSourceOperations.ResetInodeMappings.
+func (superOperations) ResetInodeMappings() {}
+
+// SaveInodeMapping implements MountSourceOperations.SaveInodeMapping.
+func (superOperations) SaveInodeMapping(*fs.Inode, string) {}
+
+// Destroy implements MountSourceOperations.Destroy.
+func (superOperations) Destroy() {}
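
A hedged, test-style sketch of the Mount contract above: any mount data is rejected with EINVAL. It leans on the contexttest helper already listed in this package's test deps; the test itself is an illustrative assumption, not part of this patch:

func TestMountRejectsOptions(t *testing.T) {
	ctx := contexttest.Context(t)
	fsys := &filesystem{}
	// No options are supported, so any data string must fail.
	if _, err := fsys.Mount(ctx, "none", fs.MountSourceFlags{}, "ptmxmode=0666", nil); err != syserror.EINVAL {
		t.Errorf("Mount with options: got %v, want EINVAL", err)
	}
}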
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
new file mode 100644
index 000000000..2e9dd2d55
--- /dev/null
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -0,0 +1,449 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "bytes"
+ "unicode/utf8"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+const (
+ // canonMaxBytes is the number of bytes that fit into a single line of
+ // terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
+ // in include/linux/tty.h.
+ canonMaxBytes = 4096
+
+ // nonCanonMaxBytes is the maximum number of bytes that can be read at
+ // a time in noncanonical mode.
+ nonCanonMaxBytes = canonMaxBytes - 1
+
+ spacesPerTab = 8
+)
+
+// lineDiscipline dictates how input and output are handled between the
+// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
+// pages are good resources for how to affect the line discipline:
+//
+// * termios(3)
+// * tty_ioctl(4)
+//
+// This file corresponds most closely to drivers/tty/n_tty.c.
+//
+// lineDiscipline has a simple structure but supports a multitude of options
+// (see the above man pages). It consists of two queues of bytes: one from the
+// terminal master to slave (the input queue) and one from slave to master (the
+// output queue). When bytes are written to one end of the pty, the line
+// discipline reads the bytes, modifies them or takes special action if
+// required, and enqueues them to be read by the other end of the pty:
+//
+// input from terminal +-------------+ input to process (e.g. bash)
+// +------------------------>| input queue |---------------------------+
+// | (inputQueueWrite) +-------------+ (inputQueueRead) |
+// | |
+// | v
+// masterFD slaveFD
+// ^ |
+// | |
+// | output to terminal +--------------+ output from process |
+// +------------------------| output queue |<--------------------------+
+// (outputQueueRead) +--------------+ (outputQueueWrite)
+//
+// Lock order:
+// termiosMu
+// inQueue.mu
+// outQueue.mu
+//
+// +stateify savable
+type lineDiscipline struct {
+ // sizeMu protects size.
+ sizeMu sync.Mutex `state:"nosave"`
+
+ // size is the terminal size (width and height).
+ size linux.WindowSize
+
+ // inQueue is the input queue of the terminal.
+ inQueue queue
+
+ // outQueue is the output queue of the terminal.
+ outQueue queue
+
+ // termiosMu protects termios.
+ termiosMu sync.RWMutex `state:"nosave"`
+
+ // termios is the terminal configuration used by the lineDiscipline.
+ termios linux.KernelTermios
+
+ // column is the location in a row of the cursor. This is important for
+ // handling certain special characters like backspace.
+ column int
+
+ // masterWaiter is used to wait on the master end of the TTY.
+ masterWaiter waiter.Queue `state:"zerovalue"`
+
+ // slaveWaiter is used to wait on the slave end of the TTY.
+ slaveWaiter waiter.Queue `state:"zerovalue"`
+}
+
+func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
+ ld := lineDiscipline{termios: termios}
+ ld.inQueue.transformer = &inputQueueTransformer{}
+ ld.outQueue.transformer = &outputQueueTransformer{}
+ return &ld
+}
+
+// getTermios gets the linux.Termios for the tty.
+func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ // We must copy a Termios struct, not KernelTermios.
+ t := l.termios.ToTermios()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setTermios sets a linux.Termios for the tty.
+func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.Lock()
+ defer l.termiosMu.Unlock()
+ oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
+ // We must copy a Termios struct, not KernelTermios.
+ var t linux.Termios
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ l.termios.FromTermios(t)
+
+ // If canonical mode is turned off, move bytes from inQueue's wait
+ // buffer to its read buffer. Anything already in the read buffer is
+ // now readable.
+ if oldCanonEnabled && !l.termios.LEnabled(linux.ICANON) {
+ l.inQueue.mu.Lock()
+ l.inQueue.pushWaitBufLocked(l)
+ l.inQueue.readable = true
+ l.inQueue.mu.Unlock()
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+
+ return 0, err
+}
+
+func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) masterReadiness() waiter.EventMask {
+ // We don't have to lock a termios because the default master termios
+ // is immutable.
+ return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
+}
+
+func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
+}
+
+func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.inQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.inQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.outQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.outQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// transformer is a helper interface to make it easier to stateify queue.
+type transformer interface {
+ // transform functions require queue's mutex to be held.
+ transform(*lineDiscipline, *queue, []byte) int
+}
+
+// outputQueueTransformer implements transformer. It performs line discipline
+// transformations on the output queue.
+//
+// +stateify savable
+type outputQueueTransformer struct{}
+
+// transform does output processing for one end of the pty. See
+// drivers/tty/n_tty.c:do_output_char for an analogous kernel function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*outputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // Output transformation is effectively always in noncanonical mode, as
+ // the master termios never has ICANON set.
+
+ if !l.termios.OEnabled(linux.OPOST) {
+ q.readBuf = append(q.readBuf, buf...)
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return len(buf)
+ }
+
+ var ret int
+ for len(buf) > 0 {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ ret += size
+ buf = buf[size:]
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\n':
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ if l.termios.OEnabled(linux.ONLCR) {
+ q.readBuf = append(q.readBuf, '\r', '\n')
+ continue
+ }
+ case '\r':
+ if l.termios.OEnabled(linux.ONOCR) && l.column == 0 {
+ continue
+ }
+ if l.termios.OEnabled(linux.OCRNL) {
+ cBytes[0] = '\n'
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ break
+ }
+ l.column = 0
+ case '\t':
+ spaces := spacesPerTab - l.column%spacesPerTab
+ if l.termios.OutputFlags&linux.TABDLY == linux.XTABS {
+ l.column += spaces
+ q.readBuf = append(q.readBuf, bytes.Repeat([]byte{' '}, spacesPerTab)...)
+ continue
+ }
+ l.column += spaces
+ case '\b':
+ if l.column > 0 {
+ l.column--
+ }
+ default:
+ l.column++
+ }
+ q.readBuf = append(q.readBuf, cBytes...)
+ }
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return ret
+}
+
+// inputQueueTransformer implements transformer. It performs line discipline
+// transformations on the input queue.
+//
+// +stateify savable
+type inputQueueTransformer struct{}
+
+// transform does input processing for one end of the pty. Characters read are
+// transformed according to flags set in the termios struct. See
+// drivers/tty/n_tty.c:n_tty_receive_char_special for an analogous kernel
+// function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*inputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // If there's a line waiting to be read in canonical mode, don't write
+ // anything else to the read buffer.
+ if l.termios.LEnabled(linux.ICANON) && q.readable {
+ return 0
+ }
+
+ maxBytes := nonCanonMaxBytes
+ if l.termios.LEnabled(linux.ICANON) {
+ maxBytes = canonMaxBytes
+ }
+
+ var ret int
+ for len(buf) > 0 && len(q.readBuf) < canonMaxBytes {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\r':
+ if l.termios.IEnabled(linux.IGNCR) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+ if l.termios.IEnabled(linux.ICRNL) {
+ cBytes[0] = '\n'
+ }
+ case '\n':
+ if l.termios.IEnabled(linux.INLCR) {
+ cBytes[0] = '\r'
+ }
+ }
+
+ // In canonical mode, we discard non-terminating characters
+ // after the first 4095.
+ if l.shouldDiscard(q, cBytes) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+
+ // Stop if the buffer would be overfilled.
+ if len(q.readBuf)+size > maxBytes {
+ break
+ }
+ buf = buf[size:]
+ ret += size
+
+ // If we get EOF, make the buffer available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsEOF(cBytes[0]) {
+ q.readable = true
+ break
+ }
+
+ q.readBuf = append(q.readBuf, cBytes...)
+
+ // Anything written to the readBuf will have to be echoed.
+ if l.termios.LEnabled(linux.ECHO) {
+ l.outQueue.writeBytes(cBytes, l)
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+
+ // If we finish a line, make it available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsTerminating(cBytes) {
+ q.readable = true
+ break
+ }
+ }
+
+ // In noncanonical mode, everything is readable.
+ if !l.termios.LEnabled(linux.ICANON) && len(q.readBuf) > 0 {
+ q.readable = true
+ }
+
+ return ret
+}
+
+// shouldDiscard returns whether c should be discarded. In canonical mode, if
+// too many bytes are enqueued, we keep reading input and discarding it until
+// we find a terminating character. Signal/echo processing still occurs.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (l *lineDiscipline) shouldDiscard(q *queue, cBytes []byte) bool {
+ return l.termios.LEnabled(linux.ICANON) && len(q.readBuf)+len(cBytes) >= canonMaxBytes && !l.termios.IsTerminating(cBytes)
+}
+
+// peek returns the size in bytes of the next character to process. As long as
+// b isn't empty, peek returns a value of at least 1.
+func (l *lineDiscipline) peek(b []byte) int {
+ size := 1
+ // If UTF-8 support is enabled, runes might be multiple bytes.
+ if l.termios.IEnabled(linux.IUTF8) {
+ _, size = utf8.DecodeRune(b)
+ }
+ return size
+}
+
+// LINT.ThenChange(../../fsimpl/devpts/line_discipline.go)
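
Clearing ICANON is what triggers the wait-buffer flush in setTermios above. A hedged userspace counterpart using golang.org/x/sys/unix (assumed available; fd is an already-open slave end):

// setRaw switches the slave to noncanonical mode without echo, making
// any partially-buffered line immediately readable. Sketch only.
func setRaw(fd int) error {
	t, err := unix.IoctlGetTermios(fd, unix.TCGETS)
	if err != nil {
		return err
	}
	t.Lflag &^= unix.ICANON | unix.ECHO
	return unix.IoctlSetTermios(fd, unix.TCSETS, t)
}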
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
new file mode 100644
index 000000000..fe07fa929
--- /dev/null
+++ b/pkg/sentry/fs/tty/master.go
@@ -0,0 +1,238 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// masterInodeOperations are the fs.InodeOperations for the master end of the
+// Terminal (ptmx file).
+//
+// +stateify savable
+type masterInodeOperations struct {
+ fsutil.SimpleFileInode
+
+ // d is the containing dir.
+ d *dirInodeOperations
+}
+
+var _ fs.InodeOperations = (*masterInodeOperations)(nil)
+
+// newMasterInode creates an Inode for the master end of a terminal.
+func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode {
+ iops := &masterInodeOperations{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC),
+ d: d,
+ }
+
+ return fs.NewInode(ctx, iops, d.msrc, fs.StableAttr{
+ DeviceID: ptsDevice.DeviceID(),
+ // N.B. Linux always uses inode id 2 for ptmx. See
+ // fs/devpts/inode.c:mknod_ptmx.
+ //
+ // TODO(b/75267214): Since ptsDevice must be shared between
+ // different mounts, we must not assign fixed numbers.
+ InodeID: ptsDevice.NextIno(),
+ Type: fs.CharacterDevice,
+ // See fs/devpts/inode.c:devpts_fill_super.
+ BlockSize: 1024,
+ // The PTY master effectively has two different major/minor
+ // device numbers.
+ //
+ // This one is returned by stat for both opened and unopened
+ // instances of this inode.
+ //
+ // When the inode is opened (GetFile), a new device number is
+ // allocated based on major UNIX98_PTY_MASTER_MAJOR and the tty
+ // index as minor number. However, this device number is only
+ // accessible via ioctl(TIOCGDEV) and /proc/TID/stat.
+ DeviceFileMajor: linux.TTYAUX_MAJOR,
+ DeviceFileMinor: linux.PTMX_MINOR,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (mi *masterInodeOperations) Release(ctx context.Context) {
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+//
+// It allocates a new terminal.
+func (mi *masterInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ t, err := mi.d.allocateTerminal(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return fs.NewFile(ctx, d, flags, &masterFileOperations{
+ d: mi.d,
+ t: t,
+ }), nil
+}
+
+// masterFileOperations are the fs.FileOperations for the master end of a terminal.
+//
+// +stateify savable
+type masterFileOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // d is the containing dir.
+ d *dirInodeOperations
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ fs.FileOperations = (*masterFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (mf *masterFileOperations) Release() {
+ mf.d.masterClose(mf.t)
+ mf.t.DecRef()
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (mf *masterFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ mf.t.ld.masterWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (mf *masterFileOperations) EventUnregister(e *waiter.Entry) {
+ mf.t.ld.masterWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (mf *masterFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mf.t.ld.masterReadiness()
+}
+
+// Read implements fs.FileOperations.Read.
+func (mf *masterFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return mf.t.ld.outputQueueRead(ctx, dst)
+}
+
+// Write implements fs.FileOperations.Write.
+func (mf *masterFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return mf.t.ld.inputQueueWrite(ctx, src)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the output queue read buffer.
+ return 0, mf.t.ld.outputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ // N.B. TCGETS on the master actually returns the configuration
+ // of the slave end.
+ return mf.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ // N.B. TCSETS on the master actually affects the configuration
+ // of the slave end.
+ return mf.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return mf.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mf.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCSPTLCK:
+ // TODO(b/29356795): Implement pty locking. For now just pretend we do.
+ return 0, nil
+ case linux.TIOCGWINSZ:
+ return 0, mf.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, mf.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// maybeEmitUnimplementedEvent emits an unimplemented event if cmd is valid.
+func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
+ switch cmd {
+ case linux.TCGETS,
+ linux.TCSETS,
+ linux.TCSETSW,
+ linux.TCSETSF,
+ linux.TIOCGWINSZ,
+ linux.TIOCSWINSZ,
+ linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+}
+
+// LINT.ThenChange(../../fsimpl/devpts/master.go)
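
The TIOCGWINSZ/TIOCSWINSZ arms above copy a linux.WindowSize in and out of the line discipline. A hedged userspace counterpart with golang.org/x/sys/unix (master is an open ptmx descriptor; resize is a hypothetical helper):

// resize pushes a new size into the line discipline; slave-side readers
// then observe it via TIOCGWINSZ. Sketch only.
func resize(master int, rows, cols uint16) error {
	ws := unix.Winsize{Row: rows, Col: cols}
	return unix.IoctlSetWinsize(master, unix.TIOCSWINSZ, &ws)
}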
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
new file mode 100644
index 000000000..ceabb9b1e
--- /dev/null
+++ b/pkg/sentry/fs/tty/queue.go
@@ -0,0 +1,240 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
+// TTYB_DEFAULT_MEM_LIMIT.
+const waitBufMaxBytes = 131072
+
+// queue represents one of the input or output queues between a pty master and
+// slave. Bytes written to a queue are added to the read buffer until it is
+// full, at which point they are written to the wait buffer. Bytes are
+// processed (i.e. undergo termios transformations) as they are added to the
+// read buffer. The read buffer is readable when its length is nonzero and
+// readable is true.
+//
+// +stateify savable
+type queue struct {
+ // mu protects everything in queue.
+ mu sync.Mutex `state:"nosave"`
+
+ // readBuf is a buffer of data ready to be read when readable is true.
+ // This data has been processed.
+ readBuf []byte
+
+ // waitBuf holds data that can't yet fit into readBuf; it is moved into
+ // the read buffer as space becomes available. Data in waitBuf has not
+ // yet been processed.
+ waitBuf [][]byte
+ waitBufLen uint64
+
+ // readable indicates whether the read buffer can be read from. In
+ // canonical mode, there can be an unterminated line in the read buffer,
+ // so readable must be checked.
+ readable bool
+
+ // transform is the queue's function for transforming bytes
+ // entering the queue. For example, transform might convert all '\r's
+ // entering the queue to '\n's.
+ transformer
+}
+
+// readReadiness returns whether q is ready to be read from.
+func (q *queue) readReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if len(q.readBuf) > 0 && q.readable {
+ return waiter.EventIn
+ }
+ return waiter.EventMask(0)
+}
+
+// writeReadiness returns whether q is ready to be written to.
+func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if q.waitBufLen < waitBufMaxBytes {
+ return waiter.EventOut
+ }
+ return waiter.EventMask(0)
+}
+
+// readableSize writes the number of readable bytes to userspace.
+func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ var size int32
+ if q.readable {
+ size = int32(len(q.readBuf))
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+// read reads from q to userspace. It returns the number of bytes read as well
+// as whether the read caused more readable data to become available (whether
+// data was pushed from the wait buffer to the read buffer).
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.readable {
+ return 0, false, syserror.ErrWouldBlock
+ }
+
+ if dst.NumBytes() > canonMaxBytes {
+ dst = dst.TakeFirst(canonMaxBytes)
+ }
+
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dst safemem.BlockSeq) (uint64, error) {
+ src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(q.readBuf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.readBuf = q.readBuf[n:]
+
+ // If we read everything, this queue is no longer readable.
+ if len(q.readBuf) == 0 {
+ q.readable = false
+ }
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, false, err
+ }
+
+ // Move data from the queue's wait buffer to its read buffer.
+ nPushed := q.pushWaitBufLocked(l)
+
+ return int64(n), nPushed > 0, nil
+}
+
+// write writes to q from userspace.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Copy data into the wait buffer.
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(src safemem.BlockSeq) (uint64, error) {
+ copyLen := src.NumBytes()
+ room := waitBufMaxBytes - q.waitBufLen
+ // If out of room, return EAGAIN.
+ if room == 0 && copyLen > 0 {
+ return 0, syserror.ErrWouldBlock
+ }
+ // Cap the size of the wait buffer.
+ if copyLen > room {
+ copyLen = room
+ src = src.TakeFirst64(room)
+ }
+ buf := make([]byte, copyLen)
+
+ // Copy the data into the wait buffer.
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.waitBufAppend(buf)
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, err
+ }
+
+ // Push data from the wait to the read buffer.
+ q.pushWaitBufLocked(l)
+
+ return n, nil
+}
+
+// writeBytes writes to q from b.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) writeBytes(b []byte, l *lineDiscipline) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Write to the wait buffer.
+ q.waitBufAppend(b)
+ q.pushWaitBufLocked(l)
+}
+
+// pushWaitBufLocked fills the queue's read buffer with data from the wait
+// buffer.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be locked.
+func (q *queue) pushWaitBufLocked(l *lineDiscipline) int {
+ if q.waitBufLen == 0 {
+ return 0
+ }
+
+ // Move data from the wait to the read buffer.
+ var total int
+ var i int
+ for i = 0; i < len(q.waitBuf); i++ {
+ n := q.transform(l, q, q.waitBuf[i])
+ total += n
+ if n != len(q.waitBuf[i]) {
+ // The read buffer filled up without consuming the
+ // entire buffer.
+ q.waitBuf[i] = q.waitBuf[i][n:]
+ break
+ }
+ }
+
+ // Update wait buffer based on consumed data.
+ q.waitBuf = q.waitBuf[i:]
+ q.waitBufLen -= uint64(total)
+
+ return total
+}
+
+// Precondition: q.mu must be locked.
+func (q *queue) waitBufAppend(b []byte) {
+ q.waitBuf = append(q.waitBuf, b)
+ q.waitBufLen += uint64(len(b))
+}
+
+// LINT.ThenChange(../../fsimpl/devpts/queue.go)
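
To make the readBuf/waitBuf split concrete, here is a deliberately simplified toy model of the buffering policy described above; it drops locking and termios processing, and moves whole chunks at once, unlike the real transform-driven pushWaitBufLocked:

// toyQueue models only the overflow behavior of queue: writes spill into
// waitBuf once readBuf is full, and reads drain waitBuf forward.
type toyQueue struct {
	readBuf []byte
	waitBuf [][]byte
	max     int
}

func (q *toyQueue) write(b []byte) {
	if len(q.readBuf)+len(b) <= q.max {
		q.readBuf = append(q.readBuf, b...)
		return
	}
	// Overflow waits here, unprocessed, like queue.waitBuf.
	q.waitBuf = append(q.waitBuf, b)
}

func (q *toyQueue) read(n int) []byte {
	if n > len(q.readBuf) {
		n = len(q.readBuf)
	}
	out := append([]byte(nil), q.readBuf[:n]...)
	q.readBuf = q.readBuf[n:]
	// Pull waiting chunks forward, as pushWaitBufLocked does after a read.
	// Simplification: whole chunks move, so readBuf may briefly exceed max.
	for len(q.waitBuf) > 0 && len(q.readBuf) < q.max {
		q.readBuf = append(q.readBuf, q.waitBuf[0]...)
		q.waitBuf = q.waitBuf[1:]
	}
	return out
}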
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
new file mode 100644
index 000000000..9871f6fc6
--- /dev/null
+++ b/pkg/sentry/fs/tty/slave.go
@@ -0,0 +1,178 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// slaveInodeOperations are the fs.InodeOperations for the slave end of the
+// Terminal (pts file).
+//
+// +stateify savable
+type slaveInodeOperations struct {
+ fsutil.SimpleFileInode
+
+ // d is the containing dir.
+ d *dirInodeOperations
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ fs.InodeOperations = (*slaveInodeOperations)(nil)
+
+// newSlaveInode creates an fs.Inode for the slave end of a terminal.
+//
+// newSlaveInode takes ownership of t.
+func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode {
+ iops := &slaveInodeOperations{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC),
+ d: d,
+ t: t,
+ }
+
+ return fs.NewInode(ctx, iops, d.msrc, fs.StableAttr{
+ DeviceID: ptsDevice.DeviceID(),
+ // N.B. Linux always uses inode id = tty index + 3. See
+ // fs/devpts/inode.c:devpts_pty_new.
+ //
+ // TODO(b/75267214): Since ptsDevice must be shared between
+ // different mounts, we must not assign fixed numbers.
+ InodeID: ptsDevice.NextIno(),
+ Type: fs.CharacterDevice,
+ // See fs/devpts/inode.c:devpts_fill_super.
+ BlockSize: 1024,
+ DeviceFileMajor: linux.UNIX98_PTY_SLAVE_MAJOR,
+ DeviceFileMinor: t.n,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (si *slaveInodeOperations) Release(ctx context.Context) {
+ si.t.DecRef()
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+//
+// This may race with destruction of the terminal. If the terminal is gone, it
+// returns ENOENT.
+func (si *slaveInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &slaveFileOperations{si: si}), nil
+}
+
+// slaveFileOperations are the fs.FileOperations for the slave end of a terminal.
+//
+// +stateify savable
+type slaveFileOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // si is the inode operations.
+ si *slaveInodeOperations
+}
+
+var _ fs.FileOperations = (*slaveFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (sf *slaveFileOperations) Release() {
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sf *slaveFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ sf.si.t.ld.slaveWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sf *slaveFileOperations) EventUnregister(e *waiter.Entry) {
+ sf.si.t.ld.slaveWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sf *slaveFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return sf.si.t.ld.slaveReadiness()
+}
+
+// Read implements fs.FileOperations.Read.
+func (sf *slaveFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return sf.si.t.ld.inputQueueRead(ctx, dst)
+}
+
+// Write implements fs.FileOperations.Write.
+func (sf *slaveFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return sf.si.t.ld.outputQueueWrite(ctx, src)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the input queue read buffer.
+ return 0, sf.si.t.ld.inputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ return sf.si.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ return sf.si.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return sf.si.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sf.si.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCGWINSZ:
+ return 0, sf.si.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, sf.si.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// LINT.ThenChange(../../fsimpl/devpts/slave.go)
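
From the guest side, the switch above is driven by ordinary terminal ioctls. A rough userspace sketch of the TIOCGPTN and TCGETS paths, assuming plain Linux and golang.org/x/sys/unix (on Linux these are normally issued on the master fd; the slave code above also answers TIOCGPTN):

    package main

    import (
    	"fmt"
    	"os"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	// Each open of the ptmx master allocates a fresh terminal pair.
    	master, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
    	if err != nil {
    		panic(err)
    	}
    	defer master.Close()

    	// TIOCGPTN copies the terminal index out to userspace, as in the
    	// case above; the slave end is then /dev/pts/<n>.
    	n, err := unix.IoctlGetInt(int(master.Fd()), unix.TIOCGPTN)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Printf("slave end lives at /dev/pts/%d\n", n)

    	// TCGETS reads back the termios that getTermios serves.
    	termios, err := unix.IoctlGetTermios(int(master.Fd()), unix.TCGETS)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Printf("canonical mode: %v\n", termios.Lflag&unix.ICANON != 0)
    }
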
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
new file mode 100644
index 000000000..ddcccf4da
--- /dev/null
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -0,0 +1,132 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// Terminal is a pseudoterminal.
+//
+// +stateify savable
+type Terminal struct {
+ refs.AtomicRefCount
+
+ // n is the terminal index. It is immutable.
+ n uint32
+
+ // d is the containing directory. It is immutable.
+ d *dirInodeOperations
+
+ // ld is the line discipline of the terminal. It is immutable.
+ ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
+}
+
+func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal {
+ termios := linux.DefaultSlaveTermios
+ t := Terminal{
+ d: d,
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
+ }
+ t.EnableLeakCheck("tty.Terminal")
+ return &t
+}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the ID of tm's foreground process group.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setForegroundProcessGroup sets the foreground process group of tm.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
+
+// LINT.ThenChange(../../fsimpl/devpts/terminal.go)
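
The controlling-TTY and process-group methods above are the kernel half of shell job control. A hedged userspace counterpart, again assuming golang.org/x/sys/unix and a real terminal on stdin (expect ENOTTY otherwise):

    package main

    import (
    	"fmt"
    	"os"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	fd := int(os.Stdin.Fd()) // assumes stdin is a terminal

    	// TIOCGPGRP: read the foreground process group, the value that
    	// foregroundProcessGroup copies out above.
    	pgrp, err := unix.IoctlGetInt(fd, unix.TIOCGPGRP)
    	if err != nil {
    		panic(err) // e.g. ENOTTY when stdin is not a tty
    	}
    	fmt.Printf("foreground pgid: %d (ours: %d)\n", pgrp, unix.Getpgrp())

    	// TIOCSPGRP: hand the foreground to our own process group, using
    	// the pointer-to-pgid convention setForegroundProcessGroup reads.
    	if err := unix.IoctlSetPointerInt(fd, unix.TIOCSPGRP, unix.Getpgrp()); err != nil {
    		panic(err)
    	}
    }
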
diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go
new file mode 100644
index 000000000..2cbc05678
--- /dev/null
+++ b/pkg/sentry/fs/tty/tty_test.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSimpleMasterToSlave(t *testing.T) {
+ ld := newLineDiscipline(linux.DefaultSlaveTermios)
+ ctx := contexttest.Context(t)
+ inBytes := []byte("hello, tty\n")
+ src := usermem.BytesIOSequence(inBytes)
+ outBytes := make([]byte, 32)
+ dst := usermem.BytesIOSequence(outBytes)
+
+ // Write to the input queue.
+ nw, err := ld.inputQueueWrite(ctx, src)
+ if err != nil {
+ t.Fatalf("error writing to input queue: %v", err)
+ }
+ if nw != int64(len(inBytes)) {
+ t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes))
+ }
+
+ // Read from the input queue.
+ nr, err := ld.inputQueueRead(ctx, dst)
+ if err != nil {
+ t.Fatalf("error reading from input queue: %v", err)
+ }
+ if nr != int64(len(inBytes)) {
+ t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes))
+ }
+
+ outStr := string(outBytes[:nr])
+ inStr := string(inBytes)
+ if outStr != inStr {
+ t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr)
+ }
+}
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
new file mode 100644
index 000000000..66e949c95
--- /dev/null
+++ b/pkg/sentry/fs/user/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "user",
+ srcs = [
+ "path.go",
+ "user.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "user_test",
+ size = "small",
+ srcs = ["user_test.go"],
+ library = ":user",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
new file mode 100644
index 000000000..397e96045
--- /dev/null
+++ b/pkg/sentry/fs/user/path.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package user
+
+import (
+ "fmt"
+ "path"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ResolveExecutablePath resolves the given executable name given the working
+// dir and environment.
+func ResolveExecutablePath(ctx context.Context, args *kernel.CreateProcessArgs) (string, error) {
+ name := args.Filename
+ if len(name) == 0 {
+ if len(args.Argv) == 0 {
+ return "", fmt.Errorf("no filename or command provided")
+ }
+ name = args.Argv[0]
+ }
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(name) {
+ return name, nil
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if the working directory is not set.
+ if strings.IndexByte(name, '/') > 0 {
+ wd := args.WorkingDirectory
+ if wd == "" {
+ wd = "/"
+ }
+ if !path.IsAbs(wd) {
+ return "", fmt.Errorf("working directory %q must be absolute", wd)
+ }
+ return path.Join(wd, name), nil
+ }
+
+ // Otherwise, we must look up the name in the paths.
+ paths := getPath(args.Envv)
+ if kernel.VFS2Enabled {
+ f, err := resolveVFS2(ctx, args.Credentials, args.MountNamespaceVFS2, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+ }
+
+ f, err := resolve(ctx, args.MountNamespace, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+}
+
+func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name string) (string, error) {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // Caller has no root. Don't bother traversing anything.
+ return "", syserror.ENOENT
+ }
+ defer root.DecRef()
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe; no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ traversals := uint(linux.MaxSymlinkTraversals)
+ d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ defer d.DecRef()
+
+ // Check that it is a regular file.
+ if !fs.IsRegular(d.Inode.StableAttr) {
+ continue
+ }
+
+ // Check whether we can read and execute the found file.
+ if err := d.Inode.CheckPermission(ctx, fs.PermMask{Read: true, Execute: true}); err != nil {
+ log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err)
+ continue
+ }
+ return path.Join("/", p, name), nil
+ }
+
+ // Couldn't find it.
+ return "", syserror.ENOENT
+}
+
+func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
+ root := mns.Root()
+ defer root.DecRef()
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe; no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(binPath),
+ FollowFinalSymlink: true,
+ }
+ opts := &vfs.OpenOptions{
+ FileExec: true,
+ Flags: linux.O_RDONLY,
+ }
+ dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ dentry.DecRef()
+
+ return binPath, nil
+ }
+
+ // Couldn't find it.
+ return "", syserror.ENOENT
+}
+
+// getPath returns the PATH as a slice of strings given the environment
+// variables.
+func getPath(env []string) []string {
+ const prefix = "PATH="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.Split(strings.TrimPrefix(e, prefix), ":")
+ }
+ }
+ return nil
+}
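
ResolveExecutablePath above follows the classic execvp ordering: absolute names are used as-is, names containing a slash are joined to the working directory, and bare names are searched along $PATH. A host-filesystem sketch of the same ordering (resolveHost is an invented helper; the sentry versions additionally require an absolute working directory and check for a regular file with read/execute permission through the VFS):

    package main

    import (
    	"fmt"
    	"os"
    	"path"
    	"strings"
    )

    // resolveHost mirrors ResolveExecutablePath's ordering on the host fs.
    func resolveHost(name, wd string, env []string) (string, error) {
    	// Absolute paths can be used directly.
    	if path.IsAbs(name) {
    		return name, nil
    	}
    	// Paths with a '/' are relative to the working directory.
    	if strings.Contains(name, "/") {
    		return path.Join(wd, name), nil
    	}
    	// Otherwise, search each entry of $PATH for a regular file.
    	for _, e := range env {
    		if !strings.HasPrefix(e, "PATH=") {
    			continue
    		}
    		for _, p := range strings.Split(strings.TrimPrefix(e, "PATH="), ":") {
    			bin := path.Join(p, name)
    			if fi, err := os.Stat(bin); err == nil && fi.Mode().IsRegular() {
    				return bin, nil
    			}
    		}
    	}
    	return "", os.ErrNotExist
    }

    func main() {
    	p, err := resolveHost("ls", "/", os.Environ())
    	fmt.Println(p, err)
    }
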
diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go
new file mode 100644
index 000000000..f4d525523
--- /dev/null
+++ b/pkg/sentry/fs/user/user.go
@@ -0,0 +1,239 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package user contains methods for resolving filesystem paths based on the
+// user and their environment.
+package user
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type fileReader struct {
+ // Ctx is the context for the file reader.
+ Ctx context.Context
+
+ // File is the file to read from.
+ File *fs.File
+}
+
+// Read implements io.Reader.Read.
+func (r *fileReader) Read(buf []byte) (int, error) {
+ n, err := r.File.Readv(r.Ctx, usermem.BytesIOSequence(buf))
+ return int(n), err
+}
+
+// getExecUserHome returns the executing user's home directory, as read from
+// /etc/passwd in the container filesystem.
+func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) {
+ // The default user home directory to return if no user matching the uid
+ // is found in the /etc/passwd found in the image.
+ const defaultHome = "/"
+
+ // Open the /etc/passwd file from the dirent via the root mount namespace.
+ mnsRoot := rootMns.Root()
+ maxTraversals := uint(linux.MaxSymlinkTraversals)
+ dirent, err := rootMns.FindInode(ctx, mnsRoot, nil, "/etc/passwd", &maxTraversals)
+ if err != nil {
+ // NOTE: Ignore errors opening the passwd file. If the passwd file
+ // doesn't exist, we will return the default home directory.
+ return defaultHome, nil
+ }
+ defer dirent.DecRef()
+
+ // Check read permissions on the file.
+ if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Read: true}); err != nil {
+ // NOTE: Ignore permissions errors here and return default root dir.
+ return defaultHome, nil
+ }
+
+ // Only open regular files. We don't open other files like named pipes as
+ // they may block and might present some attack surface to the container.
+ // Note that runc does not seem to do this kind of checking.
+ if !fs.IsRegular(dirent.Inode.StableAttr) {
+ return defaultHome, nil
+ }
+
+ f, err := dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true, Directory: false})
+ if err != nil {
+ return "", err
+ }
+ defer f.DecRef()
+
+ r := &fileReader{
+ Ctx: ctx,
+ File: f,
+ }
+
+ return findHomeInPasswd(uint32(uid), r, defaultHome)
+}
+
+type fileReaderVFS2 struct {
+ ctx context.Context
+ fd *vfs.FileDescription
+}
+
+func (r *fileReaderVFS2) Read(buf []byte) (int, error) {
+ n, err := r.fd.Read(r.ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ return int(n), err
+}
+
+func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.KUID) (string, error) {
+ const defaultHome = "/"
+
+ root := mns.Root()
+ defer root.DecRef()
+
+ creds := auth.CredentialsFromContext(ctx)
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/etc/passwd"),
+ }
+
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }
+
+ fd, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, target, opts)
+ if err != nil {
+ return defaultHome, nil
+ }
+ defer fd.DecRef()
+
+ r := &fileReaderVFS2{
+ ctx: ctx,
+ fd: fd,
+ }
+
+ homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
+ if err != nil {
+ return "", err
+ }
+
+ return homeDir, nil
+}
+
+// MaybeAddExecUserHome returns a new slice with the HOME environment variable
+// set if the slice does not already contain it, otherwise it returns the
+// original slice unmodified.
+func MaybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHome(ctx, mns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+
+ return append(envv, "HOME="+homeDir), nil
+}
+
+// MaybeAddExecUserHomeVFS2 returns a new slice with the HOME environment
+// variable set if the slice does not already contain it, otherwise it returns
+// the original slice unmodified.
+func MaybeAddExecUserHomeVFS2(ctx context.Context, vmns *vfs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHomeVFS2(ctx, vmns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+ return append(envv, "HOME="+homeDir), nil
+}
+
+// findHomeInPasswd parses a passwd file and returns the given user's home
+// directory. This function does its best to replicate runc's behavior.
+func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) {
+ s := bufio.NewScanner(passwd)
+
+ for s.Scan() {
+ if err := s.Err(); err != nil {
+ return "", err
+ }
+
+ line := strings.TrimSpace(s.Text())
+ if line == "" {
+ continue
+ }
+
+ // Pull out the fields of the passwd entry. Parse loosely, both because
+ // some passwd files may be poorly formatted and for compatibility with
+ // runc.
+ //
+ // Per 'man 5 passwd'
+ // /etc/passwd contains one line for each user account, with seven
+ // fields delimited by colons (“:”). These fields are:
+ //
+ // - login name
+ // - optional encrypted password
+ // - numerical user ID
+ // - numerical group ID
+ // - user name or comment field
+ // - user home directory
+ // - optional user command interpreter
+ parts := strings.Split(line, ":")
+
+ found := false
+ homeDir := ""
+ for i, p := range parts {
+ switch i {
+ case 2:
+ parsedUID, err := strconv.ParseUint(p, 10, 32)
+ if err == nil && parsedUID == uint64(uid) {
+ found = true
+ }
+ case 5:
+ homeDir = p
+ }
+ }
+ if found {
+ // NOTE: If the uid is present but the home directory is not
+ // present in the /etc/passwd entry we return an empty string. This
+ // is, for better or worse, what runc does.
+ return homeDir, nil
+ }
+ }
+
+ return defaultHome, nil
+}
diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go
new file mode 100644
index 000000000..7d8e9ac7c
--- /dev/null
+++ b/pkg/sentry/fs/user/user_test.go
@@ -0,0 +1,198 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package user
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// createEtcPasswd creates /etc/passwd with the given contents and mode. If
+// mode is empty, then no file will be created. If mode is not a regular file
+// mode, then contents is ignored.
+func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode linux.FileMode) error {
+ if err := root.CreateDirectory(ctx, root, "etc", fs.FilePermsFromMode(0755)); err != nil {
+ return err
+ }
+ etc, err := root.Walk(ctx, root, "etc")
+ if err != nil {
+ return err
+ }
+ defer etc.DecRef()
+ switch mode.FileType() {
+ case 0:
+ // Don't create anything.
+ return nil
+ case linux.S_IFREG:
+ passwd, err := etc.Create(ctx, root, "passwd", fs.FileFlags{Write: true}, fs.FilePermsFromMode(mode))
+ if err != nil {
+ return err
+ }
+ defer passwd.DecRef()
+ if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil {
+ return err
+ }
+ return nil
+ case linux.S_IFDIR:
+ return etc.CreateDirectory(ctx, root, "passwd", fs.FilePermsFromMode(mode))
+ case linux.S_IFIFO:
+ return etc.CreateFifo(ctx, root, "passwd", fs.FilePermsFromMode(mode))
+ default:
+ return fmt.Errorf("unknown file type %x", mode.FileType())
+ }
+}
+
+// TestGetExecUserHome tests the getExecUserHome function.
+func TestGetExecUserHome(t *testing.T) {
+ tests := map[string]struct {
+ uid auth.KUID
+ passwdContents string
+ passwdMode linux.FileMode
+ expected string
+ }{
+ "success": {
+ uid: 1000,
+ passwdContents: "adin::1000:1111::/home/adin:/bin/sh",
+ passwdMode: linux.S_IFREG | 0666,
+ expected: "/home/adin",
+ },
+ "no_perms": {
+ uid: 1000,
+ passwdContents: "adin::1000:1111::/home/adin:/bin/sh",
+ passwdMode: linux.S_IFREG,
+ expected: "/",
+ },
+ "no_passwd": {
+ uid: 1000,
+ expected: "/",
+ },
+ "directory": {
+ uid: 1000,
+ passwdMode: linux.S_IFDIR | 0666,
+ expected: "/",
+ },
+ // Currently we don't allow named pipes.
+ "named_pipe": {
+ uid: 1000,
+ passwdMode: linux.S_IFIFO | 0666,
+ expected: "/",
+ },
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ ctx := contexttest.Context(t)
+ msrc := fs.NewPseudoMountSource(ctx)
+ rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc)
+
+ mns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ t.Fatalf("NewMountNamespace failed: %v", err)
+ }
+ defer mns.DecRef()
+ root := mns.Root()
+ defer root.DecRef()
+ ctx = fs.WithRoot(ctx, root)
+
+ if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil {
+ t.Fatalf("createEtcPasswd failed: %v", err)
+ }
+
+ got, err := getExecUserHome(ctx, mns, tc.uid)
+ if err != nil {
+ t.Fatalf("failed to get user home: %v", err)
+ }
+
+ if got != tc.expected {
+ t.Fatalf("expected %v, got: %v", tc.expected, got)
+ }
+ })
+ }
+}
+
+// TestFindHomeInPasswd tests the findHomeInPasswd function's passwd file parsing.
+func TestFindHomeInPasswd(t *testing.T) {
+ tests := map[string]struct {
+ uid uint32
+ passwd string
+ expected string
+ def string
+ }{
+ "empty": {
+ uid: 1000,
+ passwd: "",
+ expected: "/",
+ def: "/",
+ },
+ "whitespace": {
+ uid: 1000,
+ passwd: " ",
+ expected: "/",
+ def: "/",
+ },
+ "full": {
+ uid: 1000,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh",
+ expected: "/home/adin",
+ def: "/",
+ },
+ // For better or worse, this is how runc works.
+ "partial": {
+ uid: 1000,
+ passwd: "adin::1000:1111:",
+ expected: "",
+ def: "/",
+ },
+ "multiple": {
+ uid: 1001,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1001:1111::/home/ian:/bin/sh",
+ expected: "/home/ian",
+ def: "/",
+ },
+ "duplicate": {
+ uid: 1000,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1000:1111::/home/ian:/bin/sh",
+ expected: "/home/adin",
+ def: "/",
+ },
+ "empty_lines": {
+ uid: 1001,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh\n\n\nian::1001:1111::/home/ian:/bin/sh",
+ expected: "/home/ian",
+ def: "/",
+ },
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ got, err := findHomeInPasswd(tc.uid, strings.NewReader(tc.passwd), tc.def)
+ if err != nil {
+ t.Fatalf("error parsing passwd: %v", err)
+ }
+ if tc.expected != got {
+ t.Fatalf("expected %v, got: %v", tc.expected, got)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsbridge/BUILD b/pkg/sentry/fsbridge/BUILD
new file mode 100644
index 000000000..6c798f0bd
--- /dev/null
+++ b/pkg/sentry/fsbridge/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fsbridge",
+ srcs = [
+ "bridge.go",
+ "fs.go",
+ "vfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go
new file mode 100644
index 000000000..8e7590721
--- /dev/null
+++ b/pkg/sentry/fsbridge/bridge.go
@@ -0,0 +1,54 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsbridge provides common interfaces to bridge between VFS1 and VFS2
+// files.
+package fsbridge
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// File provides a common interface to bridge between VFS1 and VFS2 files.
+type File interface {
+ // PathnameWithDeleted returns an absolute pathname to the file, consistent
+ // with Linux's d_path(). In particular, if the file's dentry has been
+ // disowned, PathnameWithDeleted appends " (deleted)" to the returned
+ // pathname.
+ PathnameWithDeleted(ctx context.Context) string
+
+ // ReadFull reads all of dst from the file, starting at offset.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file.
+ ConfigureMMap(context.Context, *memmap.MMapOpts) error
+
+ // Type returns the file type, e.g. linux.S_IFREG.
+ Type(context.Context) (linux.FileMode, error)
+
+ // IncRef increments the file's reference count.
+ IncRef()
+
+ // DecRef decrements the file's reference count.
+ DecRef()
+}
+
+// Lookup provides a common interface to open files.
+type Lookup interface {
+ // OpenPath opens a file.
+ OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error)
+}
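
A caller written against these two interfaces is agnostic to which VFS is in use. A sketch of the intended usage, with OpenAndReadHeader as a hypothetical caller (it behaves identically whether the Lookup came from NewFSLookup or NewVFSLookup):

    package loaderutil

    import (
    	"gvisor.dev/gvisor/pkg/abi/linux"
    	"gvisor.dev/gvisor/pkg/context"
    	"gvisor.dev/gvisor/pkg/sentry/fsbridge"
    	"gvisor.dev/gvisor/pkg/sentry/vfs"
    	"gvisor.dev/gvisor/pkg/usermem"
    )

    // OpenAndReadHeader opens path via the bridge and reads its first four
    // bytes (e.g. for an ELF magic check).
    func OpenAndReadHeader(ctx context.Context, lookup fsbridge.Lookup, path string) ([]byte, error) {
    	remaining := uint(linux.MaxSymlinkTraversals)
    	f, err := lookup.OpenPath(ctx, path, vfs.OpenOptions{
    		Flags:    linux.O_RDONLY,
    		FileExec: true,
    	}, &remaining, true /* resolveFinal */)
    	if err != nil {
    		return nil, err
    	}
    	defer f.DecRef()

    	// ReadFull fails with io.ErrUnexpectedEOF on a short file rather
    	// than returning a partial read.
    	hdr := make([]byte, 4)
    	if _, err := f.ReadFull(ctx, usermem.BytesIOSequence(hdr), 0); err != nil {
    		return nil, err
    	}
    	return hdr, nil
    }
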
diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go
new file mode 100644
index 000000000..093ce1fb3
--- /dev/null
+++ b/pkg/sentry/fsbridge/fs.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fsFile implements the File interface over fs.File.
+//
+// +stateify savable
+type fsFile struct {
+ file *fs.File
+}
+
+var _ File = (*fsFile)(nil)
+
+// NewFSFile creates a new File over fs.File.
+func NewFSFile(file *fs.File) File {
+ return &fsFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *fsFile) PathnameWithDeleted(ctx context.Context) string {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // This doesn't correspond to anything in Linux because the vfs is
+ // global there.
+ return ""
+ }
+ defer root.DecRef()
+
+ name, _ := f.file.Dirent.FullName(root)
+ return name
+}
+
+// ReadFull implements File.
+func (f *fsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.Preadv(ctx, dst, offset+total)
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *fsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *fsFile) Type(context.Context) (linux.FileMode, error) {
+ return linux.FileMode(f.file.Dirent.Inode.StableAttr.Type.LinuxType()), nil
+}
+
+// IncRef implements File.
+func (f *fsFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *fsFile) DecRef() {
+ f.file.DecRef()
+}
+
+// fsLookup implements the Lookup interface using VFS1.
+//
+// +stateify savable
+type fsLookup struct {
+ mntns *fs.MountNamespace
+
+ root *fs.Dirent
+ workingDir *fs.Dirent
+}
+
+var _ Lookup = (*fsLookup)(nil)
+
+// NewFSLookup creates a new Lookup using VFS1.
+func NewFSLookup(mntns *fs.MountNamespace, root, workingDir *fs.Dirent) Lookup {
+ return &fsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error) {
+ var d *fs.Dirent
+ var err error
+ if resolveFinal {
+ d, err = l.mntns.FindInode(ctx, l.root, l.workingDir, path, remainingTraversals)
+ } else {
+ d, err = l.mntns.FindLink(ctx, l.root, l.workingDir, path, remainingTraversals)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer d.DecRef()
+
+ if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
+ return nil, syserror.ELOOP
+ }
+
+ fsPerm := openOptionsToPermMask(&opts)
+ if err := d.Inode.CheckPermission(ctx, fsPerm); err != nil {
+ return nil, err
+ }
+
+ // If they claim it's a directory, then make sure.
+ if strings.HasSuffix(path, "/") {
+ if d.Inode.StableAttr.Type != fs.Directory {
+ return nil, syserror.ENOTDIR
+ }
+ }
+
+ if opts.FileExec && d.Inode.StableAttr.Type != fs.RegularFile {
+ ctx.Infof("%q is not a regular file: %v", path, d.Inode.StableAttr.Type)
+ return nil, syserror.EACCES
+ }
+
+ f, err := d.Inode.GetFile(ctx, d, flagsToFileFlags(opts.Flags))
+ if err != nil {
+ return nil, err
+ }
+
+ return &fsFile{file: f}, nil
+}
+
+func openOptionsToPermMask(opts *vfs.OpenOptions) fs.PermMask {
+ mode := opts.Flags & linux.O_ACCMODE
+ return fs.PermMask{
+ Read: mode == linux.O_RDONLY || mode == linux.O_RDWR,
+ Write: mode == linux.O_WRONLY || mode == linux.O_RDWR,
+ Execute: opts.FileExec,
+ }
+}
+
+func flagsToFileFlags(flags uint32) fs.FileFlags {
+ return fs.FileFlags{
+ Direct: flags&linux.O_DIRECT != 0,
+ DSync: flags&(linux.O_DSYNC|linux.O_SYNC) != 0,
+ Sync: flags&linux.O_SYNC != 0,
+ NonBlocking: flags&linux.O_NONBLOCK != 0,
+ Read: (flags & linux.O_ACCMODE) != linux.O_WRONLY,
+ Write: (flags & linux.O_ACCMODE) != linux.O_RDONLY,
+ Append: flags&linux.O_APPEND != 0,
+ Directory: flags&linux.O_DIRECTORY != 0,
+ Async: flags&linux.O_ASYNC != 0,
+ LargeFile: flags&linux.O_LARGEFILE != 0,
+ Truncate: flags&linux.O_TRUNC != 0,
+ }
+}
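
Both helpers above are pure translations, so their behavior pins down nicely in a test. A hypothetical example of what they yield for a read-write append open:

    package fsbridge

    import (
    	"testing"

    	"gvisor.dev/gvisor/pkg/abi/linux"
    	"gvisor.dev/gvisor/pkg/sentry/vfs"
    )

    func TestFlagTranslation(t *testing.T) {
    	opts := vfs.OpenOptions{Flags: linux.O_RDWR | linux.O_APPEND, FileExec: true}

    	// O_RDWR grants both read and write; FileExec adds execute.
    	if perms := openOptionsToPermMask(&opts); !perms.Read || !perms.Write || !perms.Execute {
    		t.Errorf("openOptionsToPermMask(%+v) = %+v, want read/write/execute", opts, perms)
    	}

    	// The same flags map to the Read, Write, and Append file flags.
    	if flags := flagsToFileFlags(opts.Flags); !flags.Read || !flags.Write || !flags.Append {
    		t.Errorf("flagsToFileFlags(%#x) = %+v, want Read/Write/Append", opts.Flags, flags)
    	}
    }
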
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
new file mode 100644
index 000000000..89168220a
--- /dev/null
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// VFSFile implements the File interface over vfs.FileDescription.
+//
+// +stateify savable
+type VFSFile struct {
+ file *vfs.FileDescription
+}
+
+var _ File = (*VFSFile)(nil)
+
+// NewVFSFile creates a new File over vfs.FileDescription.
+func NewVFSFile(file *vfs.FileDescription) File {
+ return &VFSFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+
+ vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry())
+ return name
+}
+
+// ReadFull implements File.
+func (f *VFSFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{})
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *VFSFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *VFSFile) Type(ctx context.Context) (linux.FileMode, error) {
+ stat, err := f.file.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
+ return linux.FileMode(stat.Mode).FileType(), nil
+}
+
+// IncRef implements File.
+func (f *VFSFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *VFSFile) DecRef() {
+ f.file.DecRef()
+}
+
+// FileDescription returns the FileDescription represented by f. It does not
+// take an additional reference on the returned FileDescription.
+func (f *VFSFile) FileDescription() *vfs.FileDescription {
+ return f.file
+}
+
+// vfsLookup implements the Lookup interface using VFS2.
+//
+// +stateify savable
+type vfsLookup struct {
+ mntns *vfs.MountNamespace
+
+ root vfs.VirtualDentry
+ workingDir vfs.VirtualDentry
+}
+
+var _ Lookup = (*vfsLookup)(nil)
+
+// NewVFSLookup creates a new Lookup using VFS2.
+func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) Lookup {
+ return &vfsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+//
+// remainingTraversals is not configurable in VFS2; all callers use the
+// default anyway.
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+ vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
+ pop := &vfs.PathOperation{
+ Root: l.root,
+ Start: l.workingDir,
+ Path: path,
+ FollowFinalSymlink: resolveFinal,
+ }
+ if path.Absolute {
+ pop.Start = l.root
+ }
+ fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return &VFSFile{file: fd}, nil
+}
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
new file mode 100644
index 000000000..93512c9b6
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -0,0 +1,44 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "devpts",
+ srcs = [
+ "devpts.go",
+ "line_discipline.go",
+ "master.go",
+ "queue.go",
+ "slave.go",
+ "terminal.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "devpts_test",
+ size = "small",
+ srcs = ["devpts_test.go"],
+ library = ":devpts",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
new file mode 100644
index 000000000..e6fda2b4f
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -0,0 +1,233 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package devpts provides a filesystem implementation that behaves like
+// devpts.
+package devpts
+
+import (
+ "fmt"
+ "math"
+ "sort"
+ "strconv"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the filesystem name.
+const Name = "devpts"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ // No data allowed.
+ if opts.Data != "" {
+ return nil, nil, syserror.EINVAL
+ }
+
+ fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ if err != nil {
+ return nil, nil, err
+ }
+ return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// newFilesystem creates a new devpts filesystem with a root directory and a
+// ptmx master inode. It returns the filesystem and its root Dentry.
+func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.Filesystem.VFSFilesystem().Init(vfsObj, fstype, fs)
+
+ // Construct the root directory. This is always inode id 1.
+ root := &rootInode{
+ slaves: make(map[uint32]*slaveInode),
+ }
+ root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
+ root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ root.dentry.Init(root)
+
+ // Construct the pts master inode and dentry. Linux always uses inode
+ // id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
+ master := &masterInode{
+ root: root,
+ }
+ master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
+ master.dentry.Init(master)
+
+ // Add the master as a child of the root.
+ links := root.OrderedChildren.Populate(&root.dentry, map[string]*kernfs.Dentry{
+ "ptmx": &master.dentry,
+ })
+ root.IncLinks(links)
+
+ return fs, &root.dentry, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// rootInode is the root directory inode for the devpts mounts.
+type rootInode struct {
+ kernfs.AlwaysValid
+ kernfs.InodeAttrs
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNotSymlink
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // master is the master pty inode. Immutable.
+ master *masterInode
+
+ // root is the root directory inode for this filesystem. Immutable.
+ root *rootInode
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // slaves maps pty ids to slave inodes.
+ slaves map[uint32]*slaveInode
+
+ // nextIdx is the next pty index to use. Must be accessed atomically.
+ //
+ // TODO(b/29356795): reuse indices when ptys are closed.
+ nextIdx uint32
+}
+
+var _ kernfs.Inode = (*rootInode)(nil)
+
+// allocateTerminal creates a new Terminal and installs a pts node for it.
+func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if i.nextIdx == math.MaxUint32 {
+ return nil, syserror.ENOMEM
+ }
+ idx := i.nextIdx
+ i.nextIdx++
+
+ // Sanity check that slave with idx does not exist.
+ if _, ok := i.slaves[idx]; ok {
+ panic(fmt.Sprintf("pty index collision; index %d already exists", idx))
+ }
+
+ // Create the new terminal and slave.
+ t := newTerminal(idx)
+ slave := &slaveInode{
+ root: i,
+ t: t,
+ }
+ // Linux always uses pty index + 3 as the inode id. See
+ // fs/devpts/inode.c:devpts_pty_new().
+ slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
+ slave.dentry.Init(slave)
+ i.slaves[idx] = slave
+
+ return t, nil
+}
+
+// masterClose is called when the master end of t is closed.
+func (i *rootInode) masterClose(t *Terminal) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ // Sanity check that slave with idx exists.
+ if _, ok := i.slaves[t.n]; !ok {
+ panic(fmt.Sprintf("pty with index %d does not exist", t.n))
+ }
+ delete(i.slaves, t.n)
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// Lookup implements kernfs.Inode.Lookup.
+func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ idx, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if si, ok := i.slaves[uint32(idx)]; ok {
+ si.dentry.IncRef()
+ return si.dentry.VFSDentry(), nil
+ }
+ return nil, syserror.ENOENT
+}
+
+// IterDirents implements kernfs.Inode.IterDirents.
+func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ ids := make([]int, 0, len(i.slaves))
+ for id := range i.slaves {
+ ids = append(ids, int(id))
+ }
+ sort.Ints(ids)
+ for _, id := range ids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(id), 10),
+ Type: linux.DT_CHR,
+ Ino: i.slaves[uint32(id)].InodeAttrs.Ino(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
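
The numbering conventions above mirror fs/devpts in Linux: the root directory is inode 1, ptmx is inode 2, and slave n is named by its decimal index and gets inode n+3. A tiny sketch of the mapping Lookup and IterDirents rely on (slaveName and slaveIno are invented helpers):

    package main

    import (
    	"fmt"
    	"strconv"
    )

    // slaveName and slaveIno spell out the conventions used above: a slave's
    // dentry name is its decimal pty index, and its inode id is index + 3.
    func slaveName(idx uint32) string { return strconv.FormatUint(uint64(idx), 10) }
    func slaveIno(idx uint32) uint64  { return uint64(idx) + 3 }

    func main() {
    	for _, idx := range []uint32{0, 1, 4} {
    		fmt.Printf("/dev/pts/%s -> inode %d\n", slaveName(idx), slaveIno(idx))
    	}
    }
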
diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go
new file mode 100644
index 000000000..b7c149047
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/devpts_test.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSimpleMasterToSlave(t *testing.T) {
+ ld := newLineDiscipline(linux.DefaultSlaveTermios)
+ ctx := contexttest.Context(t)
+ inBytes := []byte("hello, tty\n")
+ src := usermem.BytesIOSequence(inBytes)
+ outBytes := make([]byte, 32)
+ dst := usermem.BytesIOSequence(outBytes)
+
+ // Write to the input queue.
+ nw, err := ld.inputQueueWrite(ctx, src)
+ if err != nil {
+ t.Fatalf("error writing to input queue: %v", err)
+ }
+ if nw != int64(len(inBytes)) {
+ t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes))
+ }
+
+ // Read from the input queue.
+ nr, err := ld.inputQueueRead(ctx, dst)
+ if err != nil {
+ t.Fatalf("error reading from input queue: %v", err)
+ }
+ if nr != int64(len(inBytes)) {
+ t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes))
+ }
+
+ outStr := string(outBytes[:nr])
+ inStr := string(inBytes)
+ if outStr != inStr {
+ t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr)
+ }
+}
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
new file mode 100644
index 000000000..f7bc325d1
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -0,0 +1,445 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "bytes"
+ "unicode/utf8"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // canonMaxBytes is the number of bytes that fit into a single line of
+ // terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
+ // in include/linux/tty.h.
+ canonMaxBytes = 4096
+
+ // nonCanonMaxBytes is the maximum number of bytes that can be read at
+ // a time in noncanonical mode.
+ nonCanonMaxBytes = canonMaxBytes - 1
+
+ spacesPerTab = 8
+)
+
+// lineDiscipline dictates how input and output are handled between the
+// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
+// pages are good resources for how to affect the line discipline:
+//
+// * termios(3)
+// * tty_ioctl(4)
+//
+// This file corresponds most closely to drivers/tty/n_tty.c.
+//
+// lineDiscipline has a simple structure but supports a multitude of options
+// (see the above man pages). It consists of two queues of bytes: one from the
+// terminal master to slave (the input queue) and one from slave to master (the
+// output queue). When bytes are written to one end of the pty, the line
+// discipline reads the bytes, modifies them or takes special action if
+// required, and enqueues them to be read by the other end of the pty:
+//
+// input from terminal +-------------+ input to process (e.g. bash)
+// +------------------------>| input queue |---------------------------+
+// | (inputQueueWrite) +-------------+ (inputQueueRead) |
+// | |
+// | v
+// masterFD slaveFD
+// ^ |
+// | |
+// | output to terminal +--------------+ output from process |
+// +------------------------| output queue |<--------------------------+
+// (outputQueueRead) +--------------+ (outputQueueWrite)
+//
+// Lock order:
+// termiosMu
+// inQueue.mu
+// outQueue.mu
+//
+// +stateify savable
+type lineDiscipline struct {
+ // sizeMu protects size.
+ sizeMu sync.Mutex `state:"nosave"`
+
+ // size is the terminal size (width and height).
+ size linux.WindowSize
+
+ // inQueue is the input queue of the terminal.
+ inQueue queue
+
+ // outQueue is the output queue of the terminal.
+ outQueue queue
+
+ // termiosMu protects termios.
+ termiosMu sync.RWMutex `state:"nosave"`
+
+ // termios is the terminal configuration used by the lineDiscipline.
+ termios linux.KernelTermios
+
+ // column is the location in a row of the cursor. This is important for
+ // handling certain special characters like backspace.
+ column int
+
+ // masterWaiter is used to wait on the master end of the TTY.
+ masterWaiter waiter.Queue `state:"zerovalue"`
+
+ // slaveWaiter is used to wait on the slave end of the TTY.
+ slaveWaiter waiter.Queue `state:"zerovalue"`
+}
+
+func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
+ ld := lineDiscipline{termios: termios}
+ ld.inQueue.transformer = &inputQueueTransformer{}
+ ld.outQueue.transformer = &outputQueueTransformer{}
+ return &ld
+}
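
Complementing the master-to-slave test earlier in this change, the bottom path of the diagram above (process, then output queue, then master) can be exercised the same way. A sketch, assuming the default slave termios enables OPOST|ONLCR so the written "\n" reads back as "\r\n":

    package devpts

    import (
    	"testing"

    	"gvisor.dev/gvisor/pkg/abi/linux"
    	"gvisor.dev/gvisor/pkg/sentry/contexttest"
    	"gvisor.dev/gvisor/pkg/usermem"
    )

    func TestSimpleSlaveToMaster(t *testing.T) {
    	ld := newLineDiscipline(linux.DefaultSlaveTermios)
    	ctx := contexttest.Context(t)

    	// The process writes to the output queue...
    	if _, err := ld.outputQueueWrite(ctx, usermem.BytesIOSequence([]byte("hi\n"))); err != nil {
    		t.Fatalf("error writing to output queue: %v", err)
    	}

    	// ...and the master reads it back, with ONLCR having expanded the
    	// newline to CRLF (assuming the default termios above).
    	out := make([]byte, 8)
    	n, err := ld.outputQueueRead(ctx, usermem.BytesIOSequence(out))
    	if err != nil {
    		t.Fatalf("error reading from output queue: %v", err)
    	}
    	if got, want := string(out[:n]), "hi\r\n"; got != want {
    		t.Fatalf("got %q, want %q", got, want)
    	}
    }
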
+
+// getTermios gets the linux.Termios for the tty.
+func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ // We must copy a Termios struct, not KernelTermios.
+ t := l.termios.ToTermios()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setTermios sets a linux.Termios for the tty.
+func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.Lock()
+ defer l.termiosMu.Unlock()
+ oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
+ // We must copy a Termios struct, not KernelTermios.
+ var t linux.Termios
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ l.termios.FromTermios(t)
+
+ // If canonical mode is turned off, move bytes from inQueue's wait
+ // buffer to its read buffer. Anything already in the read buffer is
+ // now readable.
+ if oldCanonEnabled && !l.termios.LEnabled(linux.ICANON) {
+ l.inQueue.mu.Lock()
+ l.inQueue.pushWaitBufLocked(l)
+ l.inQueue.readable = true
+ l.inQueue.mu.Unlock()
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+
+ return 0, err
+}
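+
+// For example (a sketch): a program putting the slave into raw mode
+// clears ICANON in the local flags it passes to TCSETS, i.e. roughly
+//
+//	t.LocalFlags &^= linux.ICANON // leave canonical mode
+//
+// and the branch above then drains the wait buffer so that any
+// partially typed line becomes immediately readable.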
+
+func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) masterReadiness() waiter.EventMask {
+	// We don't have to lock the termios here because the default master
+	// termios is immutable.
+ return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
+}
+
+func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
+}
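+
+// E.g. (a sketch): with an empty input queue and room in the output
+// queue's wait buffer, slaveReadiness reports only waiter.EventOut;
+// once the master writes a complete line, waiter.EventIn is added.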
+
+func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.inQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.inQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.outQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.outQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// transformer is a helper interface to make it easier to stateify queue.
+type transformer interface {
+	// transform functions require the queue's mutex to be held.
+ transform(*lineDiscipline, *queue, []byte) int
+}
+
+// outputQueueTransformer implements transformer. It performs line discipline
+// transformations on the output queue.
+//
+// +stateify savable
+type outputQueueTransformer struct{}
+
+// transform does output processing for one end of the pty. See
+// drivers/tty/n_tty.c:do_output_char for an analogous kernel function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*outputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // transformOutput is effectively always in noncanonical mode, as the
+ // master termios never has ICANON set.
+
+ if !l.termios.OEnabled(linux.OPOST) {
+ q.readBuf = append(q.readBuf, buf...)
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return len(buf)
+ }
+
+ var ret int
+ for len(buf) > 0 {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ ret += size
+ buf = buf[size:]
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\n':
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ if l.termios.OEnabled(linux.ONLCR) {
+ q.readBuf = append(q.readBuf, '\r', '\n')
+ continue
+ }
+ case '\r':
+ if l.termios.OEnabled(linux.ONOCR) && l.column == 0 {
+ continue
+ }
+ if l.termios.OEnabled(linux.OCRNL) {
+ cBytes[0] = '\n'
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ break
+ }
+ l.column = 0
+ case '\t':
+			// Pad to the next tab stop. If XTABS is set, expand
+			// the tab into the same number of spaces that the
+			// column advances by.
+			spaces := spacesPerTab - l.column%spacesPerTab
+			if l.termios.OutputFlags&linux.TABDLY == linux.XTABS {
+				l.column += spaces
+				q.readBuf = append(q.readBuf, bytes.Repeat([]byte{' '}, spaces)...)
+ continue
+ }
+ l.column += spaces
+ case '\b':
+ if l.column > 0 {
+ l.column--
+ }
+ default:
+ l.column++
+ }
+ q.readBuf = append(q.readBuf, cBytes...)
+ }
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return ret
+}
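+
+// As a concrete illustration (assuming OPOST|ONLCR, which the default
+// slave termios sets): a slave write of "hi\n" is enqueued here as
+// "hi\r\n", which is what the master then reads back.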
+
+// inputQueueTransformer implements transformer. It performs line discipline
+// transformations on the input queue.
+//
+// +stateify savable
+type inputQueueTransformer struct{}
+
+// transform does input processing for one end of the pty. Characters read are
+// transformed according to flags set in the termios struct. See
+// drivers/tty/n_tty.c:n_tty_receive_char_special for an analogous kernel
+// function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*inputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // If there's a line waiting to be read in canonical mode, don't write
+ // anything else to the read buffer.
+ if l.termios.LEnabled(linux.ICANON) && q.readable {
+ return 0
+ }
+
+ maxBytes := nonCanonMaxBytes
+ if l.termios.LEnabled(linux.ICANON) {
+ maxBytes = canonMaxBytes
+ }
+
+ var ret int
+ for len(buf) > 0 && len(q.readBuf) < canonMaxBytes {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\r':
+ if l.termios.IEnabled(linux.IGNCR) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+ if l.termios.IEnabled(linux.ICRNL) {
+ cBytes[0] = '\n'
+ }
+ case '\n':
+ if l.termios.IEnabled(linux.INLCR) {
+ cBytes[0] = '\r'
+ }
+ }
+
+ // In canonical mode, we discard non-terminating characters
+ // after the first 4095.
+ if l.shouldDiscard(q, cBytes) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+
+ // Stop if the buffer would be overfilled.
+ if len(q.readBuf)+size > maxBytes {
+ break
+ }
+ buf = buf[size:]
+ ret += size
+
+ // If we get EOF, make the buffer available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsEOF(cBytes[0]) {
+ q.readable = true
+ break
+ }
+
+ q.readBuf = append(q.readBuf, cBytes...)
+
+ // Anything written to the readBuf will have to be echoed.
+ if l.termios.LEnabled(linux.ECHO) {
+ l.outQueue.writeBytes(cBytes, l)
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+
+ // If we finish a line, make it available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsTerminating(cBytes) {
+ q.readable = true
+ break
+ }
+ }
+
+ // In noncanonical mode, everything is readable.
+ if !l.termios.LEnabled(linux.ICANON) && len(q.readBuf) > 0 {
+ q.readable = true
+ }
+
+ return ret
+}
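+
+// For example (a sketch, assuming the default ICRNL|ICANON|ECHO flags):
+// a master write of "ls\r" is staged in the read buffer as "ls\n", the
+// bytes are echoed back through the output queue, and the line is
+// marked readable because '\n' is a terminating character.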
+
+// shouldDiscard returns whether cBytes should be discarded. In canonical
+// mode, if too many bytes are enqueued, we keep reading input and discarding
+// it until we find a terminating character. Signal/echo processing still
+// occurs.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (l *lineDiscipline) shouldDiscard(q *queue, cBytes []byte) bool {
+ return l.termios.LEnabled(linux.ICANON) && len(q.readBuf)+len(cBytes) >= canonMaxBytes && !l.termios.IsTerminating(cBytes)
+}
+
+// peek returns the size in bytes of the next character to process. As long as
+// b isn't empty, peek returns a value of at least 1.
+func (l *lineDiscipline) peek(b []byte) int {
+ size := 1
+ // If UTF-8 support is enabled, runes might be multiple bytes.
+ if l.termios.IEnabled(linux.IUTF8) {
+ _, size = utf8.DecodeRune(b)
+ }
+ return size
+}
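+
+// E.g. with IUTF8 set, peek([]byte("é")) returns 2, so the transformers
+// above treat the two-byte rune as a single character.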
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
new file mode 100644
index 000000000..69879498a
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -0,0 +1,237 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// masterInode is the inode for the master end of the Terminal.
+type masterInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+}
+
+var _ kernfs.Inode = (*masterInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ t, err := mi.root.allocateTerminal(rp.Credentials())
+ if err != nil {
+ return nil, err
+ }
+
+ mi.IncRef()
+ fd := &masterFileDescription{
+ inode: mi,
+ t: t,
+ }
+ fd.LockFD.Init(&mi.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ mi.DecRef()
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := mi.InodeAttrs.Stat(vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.TTYAUX_MAJOR
+ statx.RdevMinor = linux.PTMX_MINOR
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+type masterFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *masterInode
+ t *Terminal
+}
+
+var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (mfd *masterFileDescription) Release() {
+ mfd.inode.root.masterClose(mfd.t)
+ mfd.inode.DecRef()
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (mfd *masterFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ mfd.t.ld.masterWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (mfd *masterFileDescription) EventUnregister(e *waiter.Entry) {
+ mfd.t.ld.masterWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (mfd *masterFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mfd.t.ld.masterReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (mfd *masterFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return mfd.t.ld.outputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return mfd.t.ld.inputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the output queue read buffer.
+ return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ // N.B. TCGETS on the master actually returns the configuration
+ // of the slave end.
+ return mfd.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ // N.B. TCSETS on the master actually affects the configuration
+ // of the slave end.
+ return mfd.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return mfd.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCSPTLCK:
+ // TODO(b/29356795): Implement pty locking. For now just pretend we do.
+ return 0, nil
+ case linux.TIOCGWINSZ:
+ return 0, mfd.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, mfd.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
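+
+// From userspace the path exercised here is the standard pty open
+// sequence (a sketch; not part of this change):
+//
+//	ptm, _ := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
+//	// ioctl(ptm, TIOCGPTN) -> n; the slave end is /dev/pts/n.
+//	// ioctl(ptm, TIOCSPTLCK, 0) unlocks the slave (a no-op here).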
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return mfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return mfd.inode.Stat(fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence)
+}
+
+// maybeEmitUnimplementedEvent emits an unimplemented event if cmd is a valid
+// ioctl request that is not yet implemented.
+func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
+ switch cmd {
+ case linux.TCGETS,
+ linux.TCSETS,
+ linux.TCSETSW,
+ linux.TCSETSF,
+ linux.TIOCGWINSZ,
+ linux.TIOCSWINSZ,
+ linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+}
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
new file mode 100644
index 000000000..dffb4232c
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -0,0 +1,236 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
+// TTYB_DEFAULT_MEM_LIMIT.
+const waitBufMaxBytes = 131072
+
+// queue represents one of the input or output queues between a pty master and
+// slave. Bytes written to a queue are added to the read buffer until it is
+// full, at which point they are written to the wait buffer. Bytes are
+// processed (i.e. undergo termios transformations) as they are added to the
+// read buffer. The read buffer is readable when its length is nonzero and
+// readable is true.
+//
+// +stateify savable
+type queue struct {
+ // mu protects everything in queue.
+ mu sync.Mutex `state:"nosave"`
+
+	// readBuf is the buffer of data ready to be read when readable is
+	// true. This data has been processed.
+ readBuf []byte
+
+ // waitBuf contains data that can't fit into readBuf. It is put here
+ // until it can be loaded into the read buffer. waitBuf contains data
+ // that hasn't been processed.
+ waitBuf [][]byte
+ waitBufLen uint64
+
+ // readable indicates whether the read buffer can be read from. In
+ // canonical mode, there can be an unterminated line in the read buffer,
+ // so readable must be checked.
+ readable bool
+
+	// transformer is the queue's function for transforming bytes entering
+	// the queue. For example, transform might convert all '\r's entering
+	// the queue to '\n's.
+ transformer
+}
+
+// readReadiness returns whether q is ready to be read from.
+func (q *queue) readReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if len(q.readBuf) > 0 && q.readable {
+ return waiter.EventIn
+ }
+ return waiter.EventMask(0)
+}
+
+// writeReadiness returns whether q is ready to be written to.
+func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if q.waitBufLen < waitBufMaxBytes {
+ return waiter.EventOut
+ }
+ return waiter.EventMask(0)
+}
+
+// readableSize writes the number of readable bytes to userspace.
+func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ var size int32
+ if q.readable {
+ size = int32(len(q.readBuf))
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+	return err
+}
+
+// read reads from q to userspace. It returns the number of bytes read as well
+// as whether the read caused more readable data to become available (whether
+// data was pushed from the wait buffer to the read buffer).
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.readable {
+ return 0, false, syserror.ErrWouldBlock
+ }
+
+ if dst.NumBytes() > canonMaxBytes {
+ dst = dst.TakeFirst(canonMaxBytes)
+ }
+
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dst safemem.BlockSeq) (uint64, error) {
+ src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(q.readBuf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.readBuf = q.readBuf[n:]
+
+ // If we read everything, this queue is no longer readable.
+ if len(q.readBuf) == 0 {
+ q.readable = false
+ }
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, false, err
+ }
+
+ // Move data from the queue's wait buffer to its read buffer.
+ nPushed := q.pushWaitBufLocked(l)
+
+ return int64(n), nPushed > 0, nil
+}
+
+// write writes to q from userspace.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Copy data into the wait buffer.
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(src safemem.BlockSeq) (uint64, error) {
+ copyLen := src.NumBytes()
+ room := waitBufMaxBytes - q.waitBufLen
+ // If out of room, return EAGAIN.
+ if room == 0 && copyLen > 0 {
+ return 0, syserror.ErrWouldBlock
+ }
+ // Cap the size of the wait buffer.
+ if copyLen > room {
+ copyLen = room
+ src = src.TakeFirst64(room)
+ }
+ buf := make([]byte, copyLen)
+
+ // Copy the data into the wait buffer.
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.waitBufAppend(buf)
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, err
+ }
+
+ // Push data from the wait to the read buffer.
+ q.pushWaitBufLocked(l)
+
+ return n, nil
+}
+
+// writeBytes writes to q from b.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) writeBytes(b []byte, l *lineDiscipline) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Write to the wait buffer.
+ q.waitBufAppend(b)
+ q.pushWaitBufLocked(l)
+}
+
+// pushWaitBufLocked fills the queue's read buffer with data from the wait
+// buffer.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be locked.
+func (q *queue) pushWaitBufLocked(l *lineDiscipline) int {
+ if q.waitBufLen == 0 {
+ return 0
+ }
+
+ // Move data from the wait to the read buffer.
+ var total int
+ var i int
+ for i = 0; i < len(q.waitBuf); i++ {
+ n := q.transform(l, q, q.waitBuf[i])
+ total += n
+ if n != len(q.waitBuf[i]) {
+ // The read buffer filled up without consuming the
+ // entire buffer.
+ q.waitBuf[i] = q.waitBuf[i][n:]
+ break
+ }
+ }
+
+ // Update wait buffer based on consumed data.
+ q.waitBuf = q.waitBuf[i:]
+ q.waitBufLen -= uint64(total)
+
+ return total
+}
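+
+// For instance (a sketch): if readBuf already holds a full canonical
+// line, transform consumes nothing and waitBuf is left untouched; once
+// the line is read, the next push re-runs the transformer and moves as
+// much of waitBuf as now fits.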
+
+// Precondition: q.mu must be locked.
+func (q *queue) waitBufAppend(b []byte) {
+ q.waitBuf = append(q.waitBuf, b)
+ q.waitBufLen += uint64(len(b))
+}
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
new file mode 100644
index 000000000..cf1a0f0ac
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -0,0 +1,197 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// slaveInode is the inode for the slave end of the Terminal.
+type slaveInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ kernfs.Inode = (*slaveInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ si.IncRef()
+ fd := &slaveFileDescription{
+ inode: si,
+ }
+ fd.LockFD.Init(&si.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ si.DecRef()
+ return nil, err
+ }
+	return &fd.vfsfd, nil
+}
+
+// Valid implements kernfs.Inode.Valid.
+func (si *slaveInode) Valid(context.Context) bool {
+	// The inode is valid as long as the slave still exists.
+ si.root.mu.Lock()
+ defer si.root.mu.Unlock()
+ _, ok := si.root.slaves[si.t.n]
+ return ok
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := si.InodeAttrs.Stat(vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR
+ statx.RdevMinor = si.t.n
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+type slaveFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *slaveInode
+}
+
+var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (sfd *slaveFileDescription) Release() {
+ sfd.inode.DecRef()
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) {
+ sfd.inode.t.ld.slaveWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return sfd.inode.t.ld.slaveReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return sfd.inode.t.ld.inputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return sfd.inode.t.ld.outputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the input queue read buffer.
+ return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ return sfd.inode.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ return sfd.inode.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return sfd.inode.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCGWINSZ:
+ return 0, sfd.inode.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return sfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return sfd.inode.Stat(fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go
new file mode 100644
index 000000000..7d2781c54
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/terminal.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Terminal is a pseudoterminal.
+//
+// +stateify savable
+type Terminal struct {
+ // n is the terminal index. It is immutable.
+ n uint32
+
+ // ld is the line discipline of the terminal. It is immutable.
+ ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
+}
+
+func newTerminal(n uint32) *Terminal {
+ termios := linux.DefaultSlaveTermios
+ t := Terminal{
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
+ }
+ return &t
+}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the ID of tm's foreground process group.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setForegroundProcessGroup sets the foreground process group of tm.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD
new file mode 100644
index 000000000..aa0c2ad8c
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "devtmpfs",
+ srcs = ["devtmpfs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "devtmpfs_test",
+ size = "small",
+ srcs = ["devtmpfs_test.go"],
+ library = ":devtmpfs",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
new file mode 100644
index 000000000..d0e06cdc0
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package devtmpfs provides an implementation of /dev based on tmpfs,
+// analogous to Linux's devtmpfs.
+package devtmpfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Name is the default filesystem name.
+const Name = "devtmpfs"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct {
+ initOnce sync.Once
+ initErr error
+
+ // fs is the tmpfs filesystem that backs all mounts of this FilesystemType.
+ // root is fs' root. fs and root are immutable.
+ fs *vfs.Filesystem
+ root *vfs.Dentry
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (*FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fst.initOnce.Do(func() {
+ fs, root, err := tmpfs.FilesystemType{}.GetFilesystem(ctx, vfsObj, creds, "" /* source */, vfs.GetFilesystemOptions{
+ Data: "mode=0755", // opts from drivers/base/devtmpfs.c:devtmpfs_init()
+ })
+ if err != nil {
+ fst.initErr = err
+ return
+ }
+ fst.fs = fs
+ fst.root = root
+ })
+ if fst.initErr != nil {
+ return nil, nil, fst.initErr
+ }
+ fst.fs.IncRef()
+ fst.root.IncRef()
+ return fst.fs, fst.root, nil
+}
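+
+// Note that all devtmpfs mounts therefore share one backing tmpfs; e.g.
+// (a sketch) two successive calls observe the same instance:
+//
+//	fs1, root1, _ := fst.GetFilesystem(ctx, vfsObj, creds, "", vfs.GetFilesystemOptions{})
+//	fs2, root2, _ := fst.GetFilesystem(ctx, vfsObj, creds, "", vfs.GetFilesystemOptions{})
+//	// fs1 == fs2 && root1 == root2: device files created through an
+//	// Accessor are visible in every mount.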
+
+// Accessor allows devices to create device special files in devtmpfs.
+type Accessor struct {
+ vfsObj *vfs.VirtualFilesystem
+ mntns *vfs.MountNamespace
+ root vfs.VirtualDentry
+ creds *auth.Credentials
+}
+
+// NewAccessor returns an Accessor that supports creation of device special
+// files in the devtmpfs instance registered with name fsTypeName in vfsObj.
+func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) {
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, err
+ }
+ return &Accessor{
+ vfsObj: vfsObj,
+ mntns: mntns,
+ root: mntns.Root(),
+ creds: creds,
+ }, nil
+}
+
+// Release must be called when a is no longer in use.
+func (a *Accessor) Release() {
+ a.root.DecRef()
+ a.mntns.DecRef()
+}
+
+// accessorContext implements context.Context by extending an existing
+// context.Context with an Accessor's values for VFS-relevant state.
+type accessorContext struct {
+ context.Context
+ a *Accessor
+}
+
+func (a *Accessor) wrapContext(ctx context.Context) *accessorContext {
+ return &accessorContext{
+ Context: ctx,
+ a: a,
+ }
+}
+
+// Value implements context.Context.Value.
+func (ac *accessorContext) Value(key interface{}) interface{} {
+ switch key {
+ case vfs.CtxMountNamespace:
+ ac.a.mntns.IncRef()
+ return ac.a.mntns
+ case vfs.CtxRoot:
+ ac.a.root.IncRef()
+ return ac.a.root
+ default:
+ return ac.Context.Value(key)
+ }
+}
+
+func (a *Accessor) pathOperationAt(pathname string) *vfs.PathOperation {
+ return &vfs.PathOperation{
+ Root: a.root,
+ Start: a.root,
+ Path: fspath.Parse(pathname),
+ }
+}
+
+// CreateDeviceFile creates a device special file at the given pathname in the
+// devtmpfs instance accessed by the Accessor.
+func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind vfs.DeviceKind, major, minor uint32, perms uint16) error {
+ actx := a.wrapContext(ctx)
+
+ mode := (linux.FileMode)(perms)
+ switch kind {
+ case vfs.BlockDevice:
+ mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ mode |= linux.S_IFCHR
+ default:
+ panic(fmt.Sprintf("invalid vfs.DeviceKind: %v", kind))
+ }
+
+ // Create any parent directories. See
+ // devtmpfs.c:handle_create()=>path_create().
+ for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() {
+ pop := a.pathOperationAt(it.String())
+ if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ return fmt.Errorf("failed to create directory %q: %v", it.String(), err)
+ }
+ }
+
+ // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't
+ // create, which it recognizes by storing a pointer to the kdevtmpfs struct
+ // thread in struct inode::i_private. Accessor doesn't yet support deletion
+ // of files at all, and probably won't as long as we don't need to support
+ // kernel modules, so this is moot for now.
+ return a.vfsObj.MknodAt(actx, a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{
+ Mode: mode,
+ DevMajor: major,
+ DevMinor: minor,
+ })
+}
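+
+// Example usage (a sketch; "null" and the 1:3 device numbers are the
+// conventional Linux identity of /dev/null, shown for illustration):
+//
+//	// Create /dev/null as a character device with mode 0666.
+//	err := a.CreateDeviceFile(ctx, "null", vfs.CharDevice, 1, 3, 0666)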
+
+// UserspaceInit creates symbolic links and mount points in the devtmpfs
+// instance accessed by the Accessor that are created by userspace in Linux. It
+// does not create mounts.
+func (a *Accessor) UserspaceInit(ctx context.Context) error {
+ actx := a.wrapContext(ctx)
+
+ // Initialize symlinks.
+ for _, symlink := range []struct {
+ source string
+ target string
+ }{
+ // systemd: src/shared/dev-setup.c:dev_setup()
+ {source: "fd", target: "/proc/self/fd"},
+ {source: "stdin", target: "/proc/self/fd/0"},
+ {source: "stdout", target: "/proc/self/fd/1"},
+ {source: "stderr", target: "/proc/self/fd/2"},
+ // /proc/kcore is not implemented.
+
+ // Linux implements /dev/ptmx as a device node, but advises
+ // container implementations to create /dev/ptmx as a symlink
+ // to pts/ptmx (Documentation/filesystems/devpts.txt). Systemd
+ // follows this advice (src/nspawn/nspawn.c:setup_pts()), while
+ // LXC tries to create a bind mount and falls back to a symlink
+ // (src/lxc/conf.c:lxc_setup_devpts()).
+ {source: "ptmx", target: "pts/ptmx"},
+ } {
+ if err := a.vfsObj.SymlinkAt(actx, a.creds, a.pathOperationAt(symlink.source), symlink.target); err != nil {
+ return fmt.Errorf("failed to create symlink %q => %q: %v", symlink.source, symlink.target, err)
+ }
+ }
+
+ // systemd: src/core/mount-setup.c:mount_table
+ for _, dir := range []string{
+ "shm",
+ "pts",
+ } {
+ if err := a.vfsObj.MkdirAt(actx, a.creds, a.pathOperationAt(dir), &vfs.MkdirOptions{
+ // systemd: src/core/mount-setup.c:mount_one()
+ Mode: 0755,
+ }); err != nil {
+ return fmt.Errorf("failed to create directory %q: %v", dir, err)
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
new file mode 100644
index 000000000..b6d52c015
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devtmpfs
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func TestDevtmpfs(t *testing.T) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ // Register tmpfs just so that we can have a root filesystem that isn't
+ // devtmpfs.
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ vfsObj.MustRegisterFilesystemType("devtmpfs", &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ // Create a test mount namespace with devtmpfs mounted at "/dev".
+ const devPath = "/dev"
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+ defer mntns.DecRef()
+ root := mntns.Root()
+ defer root.DecRef()
+ devpop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(devPath),
+ }
+ if err := vfsObj.MkdirAt(ctx, creds, &devpop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ t.Fatalf("failed to create mount point: %v", err)
+ }
+ if err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil {
+ t.Fatalf("failed to mount devtmpfs: %v", err)
+ }
+
+ a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs")
+ if err != nil {
+ t.Fatalf("failed to create devtmpfs.Accessor: %v", err)
+ }
+ defer a.Release()
+
+ // Create "userspace-initialized" files using a devtmpfs.Accessor.
+ if err := a.UserspaceInit(ctx); err != nil {
+ t.Fatalf("failed to userspace-initialize devtmpfs: %v", err)
+ }
+ // Created files should be visible in the test mount namespace.
+ abspath := devPath + "/fd"
+ target, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(abspath),
+ })
+ if want := "/proc/self/fd"; err != nil || target != want {
+ t.Fatalf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, target, err, want)
+ }
+
+ // Create a dummy device special file using a devtmpfs.Accessor.
+ const (
+ pathInDev = "dummy"
+ kind = vfs.CharDevice
+ major = 12
+ minor = 34
+ perms = 0600
+ wantMode = linux.S_IFCHR | perms
+ )
+ if err := a.CreateDeviceFile(ctx, pathInDev, kind, major, minor, perms); err != nil {
+ t.Fatalf("failed to create device file: %v", err)
+ }
+ // The device special file should be visible in the test mount namespace.
+ abspath = devPath + "/" + pathInDev
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(abspath),
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE,
+ })
+ if err != nil {
+ t.Fatalf("failed to stat device file at %q: %v", abspath, err)
+ }
+ if stat.Mode != wantMode {
+ t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode)
+ }
+ if stat.RdevMajor != major {
+ t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, major)
+ }
+ if stat.RdevMinor != minor {
+ t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, minor)
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD
new file mode 100644
index 000000000..ea167d38c
--- /dev/null
+++ b/pkg/sentry/fsimpl/eventfd/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "eventfd",
+ srcs = ["eventfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/log",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "eventfd_test",
+ size = "small",
+ srcs = ["eventfd_test.go"],
+ library = ":eventfd",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
new file mode 100644
index 000000000..d12d78b84
--- /dev/null
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -0,0 +1,285 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd implements event fds.
+package eventfd
+
+import (
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EventFileDescription implements FileDescriptionImpl for file-based event
+// notification (eventfd). Eventfds are usually internal to the Sentry but in
+// certain situations they may be converted into a host-backed eventfd.
+type EventFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // queue is used to notify interested parties when the event object
+ // becomes readable or writable.
+ queue waiter.Queue `state:"zerovalue"`
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // val is the current value of the event counter.
+ val uint64
+
+ // semMode specifies whether the event is in "semaphore" mode.
+ semMode bool
+
+	// hostfd is the host eventfd backing this event, or -1 if the event
+	// is not passed through to the host.
+ hostfd int
+}
+
+var _ vfs.FileDescriptionImpl = (*EventFileDescription)(nil)
+
+// New creates a new event fd.
+func New(vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) {
+ vd := vfsObj.NewAnonVirtualDentry("[eventfd]")
+ defer vd.DecRef()
+ efd := &EventFileDescription{
+ val: initVal,
+ semMode: semMode,
+ hostfd: -1,
+ }
+ if err := efd.vfsfd.Init(efd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &efd.vfsfd, nil
+}
+
+// HostFD returns the host eventfd associated with this event.
+func (efd *EventFileDescription) HostFD() (int, error) {
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ return efd.hostfd, nil
+ }
+
+ flags := linux.EFD_NONBLOCK
+ if efd.semMode {
+ flags |= linux.EFD_SEMAPHORE
+ }
+
+ fd, _, errno := syscall.Syscall(syscall.SYS_EVENTFD2, uintptr(efd.val), uintptr(flags), 0)
+ if errno != 0 {
+ return -1, errno
+ }
+
+ if err := fdnotifier.AddFD(int32(fd), &efd.queue); err != nil {
+ if closeErr := syscall.Close(int(fd)); closeErr != nil {
+ log.Warningf("close(%d) eventfd failed: %v", fd, closeErr)
+ }
+ return -1, err
+ }
+
+ efd.hostfd = int(fd)
+ return efd.hostfd, nil
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (efd *EventFileDescription) Release() {
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ fdnotifier.RemoveFD(int32(efd.hostfd))
+ if closeErr := syscall.Close(int(efd.hostfd)); closeErr != nil {
+ log.Warningf("close(%d) eventfd failed: %v", efd.hostfd, closeErr)
+ }
+ efd.hostfd = -1
+ }
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ if dst.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := efd.read(ctx, dst); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (efd *EventFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ if src.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := efd.write(ctx, src); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Preconditions: Must be called with efd.mu locked.
+func (efd *EventFileDescription) hostReadLocked(ctx context.Context, dst usermem.IOSequence) error {
+ var buf [8]byte
+ if _, err := syscall.Read(efd.hostfd, buf[:]); err != nil {
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+ }
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequence) error {
+ efd.mu.Lock()
+ if efd.hostfd >= 0 {
+ defer efd.mu.Unlock()
+ return efd.hostReadLocked(ctx, dst)
+ }
+
+ // We can't complete the read if the value is currently zero.
+ if efd.val == 0 {
+ efd.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ // Update the value based on the mode the event is operating in.
+ var val uint64
+ if efd.semMode {
+ val = 1
+ // Consistent with Linux, this is done even if writing to memory fails.
+ efd.val--
+ } else {
+ val = efd.val
+ efd.val = 0
+ }
+
+ efd.mu.Unlock()
+
+ // Notify writers. We do this even if we were already writable because
+ // it is possible that a writer is waiting to write the maximum value
+ // to the event.
+ efd.queue.Notify(waiter.EventOut)
+
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+// Preconditions: Must be called with efd.mu locked.
+func (efd *EventFileDescription) hostWriteLocked(val uint64) error {
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := syscall.Write(efd.hostfd, buf[:])
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+func (efd *EventFileDescription) write(ctx context.Context, src usermem.IOSequence) error {
+ var buf [8]byte
+ if _, err := src.CopyIn(ctx, buf[:]); err != nil {
+ return err
+ }
+ val := usermem.ByteOrder.Uint64(buf[:])
+
+ return efd.Signal(val)
+}
+
+// Signal is an internal function to signal the event fd.
+func (efd *EventFileDescription) Signal(val uint64) error {
+ if val == math.MaxUint64 {
+ return syscall.EINVAL
+ }
+
+ efd.mu.Lock()
+
+ if efd.hostfd >= 0 {
+ defer efd.mu.Unlock()
+ return efd.hostWriteLocked(val)
+ }
+
+ // We only allow writes that won't cause the value to go over the max
+ // uint64 minus 1.
+ if val > math.MaxUint64-1-efd.val {
+ efd.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ efd.val += val
+ efd.mu.Unlock()
+
+ // Always trigger a notification.
+ efd.queue.Notify(waiter.EventIn)
+
+ return nil
+}
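+
+// To illustrate the counter semantics (a sketch): with semMode false,
+// Signal(3) followed by Signal(4) leaves val == 7 and a single Read
+// returns 7, resetting val to 0; with semMode true the same writes are
+// drained by seven reads of 1 each, after which Read blocks.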
+
+// Readiness implements waiter.Waitable.Readiness.
+func (efd *EventFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+
+ if efd.hostfd >= 0 {
+ return fdnotifier.NonBlockingPoll(int32(efd.hostfd), mask)
+ }
+
+ ready := waiter.EventMask(0)
+ if efd.val > 0 {
+ ready |= waiter.EventIn
+ }
+
+ if efd.val < math.MaxUint64-1 {
+ ready |= waiter.EventOut
+ }
+
+ return mask & ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (efd *EventFileDescription) EventRegister(entry *waiter.Entry, mask waiter.EventMask) {
+ efd.queue.EventRegister(entry, mask)
+
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(efd.hostfd))
+ }
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (efd *EventFileDescription) EventUnregister(entry *waiter.Entry) {
+ efd.queue.EventUnregister(entry)
+
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(efd.hostfd))
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_test.go b/pkg/sentry/fsimpl/eventfd/eventfd_test.go
new file mode 100644
index 000000000..20e3adffc
--- /dev/null
+++ b/pkg/sentry/fsimpl/eventfd/eventfd_test.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventfd
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestEventFD(t *testing.T) {
+ initVals := []uint64{
+ 0,
+ // Using a non-zero initial value verifies that writing to an
+ // eventfd signals when the eventfd's counter was already
+ // non-zero.
+ 343,
+ }
+
+ for _, initVal := range initVals {
+ ctx := contexttest.Context(t)
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+
+ // Make a new eventfd that is writable.
+ eventfd, err := New(vfsObj, initVal, false, linux.O_RDWR)
+ if err != nil {
+ t.Fatalf("New() failed: %v", err)
+ }
+ defer eventfd.DecRef()
+
+ // Register a callback for a write event.
+ w, ch := waiter.NewChannelEntry(nil)
+ eventfd.EventRegister(&w, waiter.EventIn)
+ defer eventfd.EventUnregister(&w)
+
+ data := []byte("00000124")
+ // Create and submit a write request.
+ n, err := eventfd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 8 {
+ t.Errorf("eventfd.write wrote %d bytes, not full int64", n)
+ }
+
+ // Check if the callback fired due to the write event.
+ select {
+ case <-ch:
+ default:
+ t.Errorf("Didn't get notified of EventIn after write")
+ }
+ }
+}
+
+func TestEventFDStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+
+ // Make a new eventfd that is writable.
+ eventfd, err := New(vfsObj, 0, false, linux.O_RDWR)
+ if err != nil {
+ t.Fatalf("New() failed: %v", err)
+ }
+ defer eventfd.DecRef()
+
+ statx, err := eventfd.Stat(ctx, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ t.Fatalf("eventfd.Stat failed: %v", err)
+ }
+ if statx.Size != 0 {
+ t.Errorf("eventfd size should be 0")
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
new file mode 100644
index 000000000..ef24f8159
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -0,0 +1,102 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "dirent_list",
+ out = "dirent_list.go",
+ package = "ext",
+ prefix = "dirent",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dirent",
+ "Linker": "*dirent",
+ },
+)
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "ext",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "ext",
+ srcs = [
+ "block_map_file.go",
+ "dentry.go",
+ "directory.go",
+ "dirent_list.go",
+ "ext.go",
+ "extent_file.go",
+ "file_description.go",
+ "filesystem.go",
+ "fstree.go",
+ "inode.go",
+ "regular_file.go",
+ "symlink.go",
+ "utils.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/ext/disklayout",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/syscalls/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "ext_test",
+ size = "small",
+ srcs = [
+ "block_map_test.go",
+ "ext_test.go",
+ "extent_test.go",
+ ],
+ data = [
+ "//pkg/sentry/fsimpl/ext:assets/bigfile.txt",
+ "//pkg/sentry/fsimpl/ext:assets/file.txt",
+ "//pkg/sentry/fsimpl/ext:assets/tiny.ext2",
+ "//pkg/sentry/fsimpl/ext:assets/tiny.ext3",
+ "//pkg/sentry/fsimpl/ext:assets/tiny.ext4",
+ ],
+ library = ":ext",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/ext/disklayout",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/test/testutil",
+ "//pkg/usermem",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/ext/README.md b/pkg/sentry/fsimpl/ext/README.md
new file mode 100644
index 000000000..af00cfda8
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/README.md
@@ -0,0 +1,117 @@
+## EXT(2/3/4) File System
+
+This is a filesystem driver which supports ext2, ext3 and ext4 filesystems.
+Linux has specialized drivers for each variant, but no single driver that
+supports all of them. This library takes advantage of ext's backward
+compatibility and understands the internal organization of on-disk structures
+to support all variants.
+
+This driver implementation diverges from the Linux implementations in being
+more forgiving about versioning. For instance, if a filesystem contains both
+extent-based inodes and classical block-map-based inodes, this driver will
+interpret both correctly without complaint, whereas Linux would treat this as
+an error. This blurs the line between the three ext fs variants.
+
+Ext2 is considered deprecated as of Red Hat Enterprise Linux 7, and ext3 has
+been superseded by ext4, which offers large performance gains. It is therefore
+recommended to upgrade older filesystem images to ext4 using e2fsprogs, as
+sketched below.
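+
+A common e2fsprogs upgrade sequence is shown here as an illustrative sketch
+(run it on a copy of the image; an interrupted conversion can corrupt data):
+
+```bash
+# Enable ext4 features on an existing ext2/ext3 image, then force a check.
+tune2fs -O extents,uninit_bg,dir_index tiny.ext2
+e2fsck -fDp tiny.ext2
+```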
+
+### Read Only
+
+This driver currently only allows read-only operations; many of its design
+decisions are based on this constraint. There are plans to implement write
+support (the process for which is outlined in the future work section).
+
+### Performance
+
+One of the biggest wins of this driver is that it talks directly to the
+underlying block device (or whatever persistent storage is being used),
+instead of making expensive RPCs to a gofer.
+
+Another advantage is that ext fs supports fast concurrent reads. Currently the
+device is represented using an `io.ReaderAt`, which allows for concurrent
+reads. All reads are passed directly to the device driver, which serves the
+read requests in the optimal order. There is no contention due to locking
+while reading at the filesystem level; the sketch below illustrates the
+pattern.
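+
+A minimal, standalone sketch (the image path is hypothetical) of why
+`io.ReaderAt` permits lock-free concurrent reads: each `ReadAt` call carries
+its own offset, so readers never contend over a shared file position.
+
+```go
+package main
+
+import (
+	"io"
+	"log"
+	"os"
+	"sync"
+)
+
+func main() {
+	f, err := os.Open("tiny.ext4") // any file; *os.File satisfies io.ReaderAt.
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer f.Close()
+
+	// On Linux each ReadAt becomes a pread(2), so the goroutines below
+	// never share (or lock) a file offset.
+	var dev io.ReaderAt = f
+
+	var wg sync.WaitGroup
+	for _, off := range []int64{0, 4096} {
+		wg.Add(1)
+		go func(off int64) {
+			defer wg.Done()
+			buf := make([]byte, 4096)
+			if _, err := dev.ReadAt(buf, off); err != nil && err != io.EOF {
+				log.Printf("read at offset %d: %v", off, err)
+			}
+		}(off)
+	}
+	wg.Wait()
+}
+```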
+
+Reads are optimized further in the way file data is transferred over to user
+memory. Ext fs directly copies over file data from disk into user memory with no
+additional allocations on the way. We can only get faster by preloading file
+data into memory (see future work section).
+
+The internal structures used to represent files, inodes and file descriptors
+use a lot of inheritance. With the level of indirection that an interface adds
+with an internal pointer, it can quickly fragment a structure across memory.
+As this runs alongside a full-blown kernel (which is memory intensive), having
+a fragmented struct might hurt performance. Hence these internal structures,
+though interfaced, are tightly packed in memory using the same inheritance
+pattern that pkg/sentry/vfs uses. The pkg/sentry/fsimpl/ext/disklayout package
+makes an exception to this pattern for reasons documented in the package.
+
+### Security
+
+This driver also intends to help sandbox the container better by reducing the
+surface of the host kernel that the application touches. It prevents the
+application from exploiting vulnerabilities in the host filesystem driver. All
+`io.ReaderAt.ReadAt()` calls are translated to `pread(2)` calls, which are
+passed directly to the device driver in the kernel. This reduces the attack
+surface.
+
+The application cannot affect any host filesystem other than the one passed in
+as a block device by the user.
+
+### Future Work
+
+#### Write
+
+To support write operations we would need to modify the block device
+underneath. Currently, the driver does not modify the device at all, not even
+to update access times on reads. Modifying the filesystem incorrectly can
+corrupt it and render it unreadable by other correct ext(x) drivers. Hence
+caution must be exercised while modifying metadata structures.
+
+Ext4 in particular is built for performance and has added a lot of complexity
+to how metadata structures are modified. For instance, files are organized via
+an extent tree, which must be kept balanced, and file data blocks must be
+placed in the same extent as much as possible to increase locality. Such
+properties must be maintained while modifying the tree.
+
+Ext filesystems place a strong emphasis on locality, which plays a big role in
+their performance. The block allocation algorithm in Linux does a good job of
+keeping related data together. This behavior must be preserved as much as
+possible, or we might degrade filesystem performance over time.
+
+Ext4 also supports a wide variety of features which are specialized for varying
+use cases. Implementing all of them can get difficult very quickly.
+
+Ext(x) checksums all its metadata structures to check for corruption, so any
+modification of a metadata struct must be accompanied by re-checksumming it.
+Linux filesystem drivers also order on-disk updates intelligently so as not to
+corrupt the filesystem while remaining performant. The in-memory metadata
+structures must be kept in sync with what is on disk.
+
+Some important structures are also replicated across the filesystem. All
+replicas must be updated when their original copy is updated. There is also
+provisioning for snapshotting, which must be kept in mind, although it should
+not affect this implementation unless we allow users to create filesystem
+snapshots.
+
+Ext3 introduced journaling, and ext4 uses its successor, jbd2. The journal
+must be updated appropriately.
+
+#### Performance
+
+To improve performance we should implement a buffer cache, and optionally, read
+ahead for small files. While doing so we must also keep in mind the memory usage
+and have a reasonable cap on how much file data we want to hold in memory.
+
+#### Features
+
+Our current implementation will work with most ext4 filesystems for read-only
+purposes. However, the following features are not supported yet:
+
+- Journal
+- Snapshotting
+- Extended Attributes
+- Hash Tree Directories
+- Meta Block Groups
+- Multiple Mount Protection
+- Bigalloc
diff --git a/pkg/sentry/fsimpl/ext/assets/README.md b/pkg/sentry/fsimpl/ext/assets/README.md
new file mode 100644
index 000000000..6f1e81b3a
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/README.md
@@ -0,0 +1,36 @@
+### Tiny Ext(2/3/4) Images
+
+The images are 64 KiB in size, which supports 64 blocks of 1 KiB each and 16
+inodes. This is the smallest size mkfs.ext(2/3/4) works with.
+
+These images were generated using the following commands.
+
+```bash
+fallocate -l 64K tiny.ext$VERSION
+mkfs.ext$VERSION -j tiny.ext$VERSION
+```
+
+where `VERSION` is `2`, `3` or `4`.
+
+You can mount it using:
+
+```bash
+sudo mount -o loop tiny.ext$VERSION $MOUNTPOINT
+```
+
+`file.txt`, `bigfile.txt` and `symlink.txt` were added to this image by just
+mounting it and copying (while preserving links) those files to the mountpoint
+directory using:
+
+```bash
+sudo cp -P {file.txt,symlink.txt,bigfile.txt} $MOUNTPOINT
+```
+
+The files in this directory mirror the contents and organization of the files
+stored in the image.
+
+You can unmount the filesystem using:
+
+```bash
+sudo umount $MOUNTPOINT
+```
diff --git a/pkg/sentry/fsimpl/ext/assets/bigfile.txt b/pkg/sentry/fsimpl/ext/assets/bigfile.txt
new file mode 100644
index 000000000..3857cf516
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/bigfile.txt
@@ -0,0 +1,41 @@
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Phasellus faucibus eleifend orci, ut ornare nibh faucibus eu. Cras at condimentum massa. Nullam luctus, elit non porttitor congue, sapien diam feugiat sapien, sed eleifend nulla mauris non arcu. Sed lacinia mauris magna, eu mollis libero varius sit amet. Donec mollis, quam convallis commodo posuere, dolor nisi placerat nisi, in faucibus augue mi eu lorem. In pharetra consectetur faucibus. Ut euismod ex efficitur egestas tincidunt. Maecenas condimentum ut ante in rutrum. Vivamus sed arcu tempor, faucibus turpis et, lacinia diam.
+
+Sed in lacus vel nisl interdum bibendum in sed justo. Nunc tellus risus, molestie vitae arcu sed, molestie tempus ligula. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nunc risus neque, volutpat et ante non, ullamcorper condimentum ante. Aliquam sed metus in urna condimentum convallis. Vivamus ut libero mauris. Proin mollis posuere consequat. Vestibulum placerat mollis est et pulvinar.
+
+Donec rutrum odio ac diam pharetra, id fermentum magna cursus. Pellentesque in dapibus elit, et condimentum orci. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Suspendisse euismod dapibus est, id vestibulum mauris. Nulla facilisi. Nulla cursus gravida nisi. Phasellus vestibulum rutrum lectus, a dignissim mauris hendrerit vitae. In at elementum mauris. Integer vel efficitur velit. Nullam fringilla sapien mi, quis luctus neque efficitur ac. Aenean nec quam dapibus nunc commodo pharetra. Proin sapien mi, fermentum aliquet vulputate non, aliquet porttitor diam. Quisque lacinia, urna et finibus fermentum, nunc lacus vehicula ex, sed congue metus lectus ac quam. Aliquam erat volutpat. Suspendisse sodales, dolor ut tincidunt finibus, augue erat varius tellus, a interdum erat sem at nunc. Vestibulum cursus iaculis sapien, vitae feugiat dui auctor quis.
+
+Pellentesque nec maximus nulla, eu blandit diam. Maecenas quis arcu ornare, congue ante at, vehicula ipsum. Praesent feugiat mauris rutrum sem fermentum, nec luctus ipsum placerat. Pellentesque placerat ipsum at dignissim fringilla. Vivamus et posuere sem, eget hendrerit felis. Aenean vulputate, augue vel mollis feugiat, justo ipsum mollis dolor, eu mollis elit neque ut ipsum. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Fusce bibendum sem quam, vulputate laoreet mi dapibus imperdiet. Sed a purus non nibh pretium aliquet. Integer eget luctus augue, vitae tincidunt magna. Ut eros enim, egestas eu nulla et, lobortis egestas arcu. Cras id ipsum ac justo lacinia rutrum. Vivamus lectus leo, ultricies sed justo at, pellentesque feugiat magna. Ut sollicitudin neque elit, vel ornare mauris commodo id.
+
+Duis dapibus orci et sapien finibus finibus. Mauris eleifend, lacus at vestibulum maximus, quam ligula pharetra erat, sit amet dapibus neque elit vitae neque. In bibendum sollicitudin erat, eget ultricies tortor malesuada at. Sed sit amet orci turpis. Donec feugiat ligula nibh, molestie tincidunt lectus elementum id. Donec volutpat maximus nibh, in vulputate felis posuere eu. Cras tincidunt ullamcorper lacus. Phasellus porta lorem auctor, congue magna a, commodo elit.
+
+Etiam auctor mi quis elit sodales, eu pulvinar arcu condimentum. Aenean imperdiet risus et dapibus tincidunt. Nullam tincidunt dictum dui, sed commodo urna rutrum id. Ut mollis libero vel elit laoreet bibendum. Quisque arcu arcu, tincidunt at ultricies id, vulputate nec metus. In tristique posuere quam sit amet volutpat. Vivamus scelerisque et nunc at dapibus. Fusce finibus libero ut ligula pretium rhoncus. Mauris non elit in arcu finibus imperdiet. Pellentesque nec massa odio. Proin rutrum mauris non sagittis efficitur. Aliquam auctor quam at dignissim faucibus. Ut eget ligula in magna posuere ultricies vitae sit amet turpis. Duis maximus odio nulla. Donec gravida sem tristique tempus scelerisque.
+
+Interdum et malesuada fames ac ante ipsum primis in faucibus. Fusce pharetra magna vulputate aliquet tempus. Duis id hendrerit arcu. Quisque ut ex elit. Integer velit orci, venenatis ut sapien ac, placerat porttitor dui. Interdum et malesuada fames ac ante ipsum primis in faucibus. Nunc hendrerit cursus diam, hendrerit finibus ipsum scelerisque ut. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos.
+
+Nulla non euismod neque. Phasellus vel sapien eu metus pulvinar rhoncus. Suspendisse eu mollis tellus, quis vestibulum tortor. Maecenas interdum dolor sed nulla fermentum maximus. Donec imperdiet ullamcorper condimentum. Nam quis nibh ante. Praesent quis tellus ut tortor pulvinar blandit sit amet ut sapien. Vestibulum est orci, pellentesque vitae tristique sit amet, tristique non felis.
+
+Vivamus sodales pellentesque varius. Sed vel tempus ligula. Nulla tristique nisl vel dui facilisis, ac sodales augue hendrerit. Proin augue nisi, vestibulum quis augue nec, sagittis tincidunt velit. Vestibulum euismod, nulla nec sodales faucibus, urna sapien vulputate magna, id varius metus sapien ut neque. Duis in mollis urna, in scelerisque enim. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nunc condimentum dictum turpis, et egestas neque dapibus eget. Quisque fringilla, dui eu venenatis eleifend, erat nibh lacinia urna, at lacinia lacus sapien eu dui. Duis eu erat ut mi lacinia convallis a sed ex.
+
+Fusce elit metus, tincidunt nec eleifend a, hendrerit nec ligula. Duis placerat finibus sollicitudin. In euismod porta tellus, in luctus justo bibendum bibendum. Maecenas at magna eleifend lectus tincidunt suscipit ut a ligula. Nulla tempor accumsan felis, fermentum dapibus est eleifend vitae. Mauris urna sem, fringilla at ultricies non, ultrices in arcu. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Nam vehicula nunc at laoreet imperdiet. Nunc tristique ut risus id aliquet. Integer eleifend massa orci.
+
+Vestibulum sed ante sollicitudin nisi fringilla bibendum nec vel quam. Sed pretium augue eu ligula congue pulvinar. Donec vitae magna tincidunt, pharetra lacus id, convallis nulla. Cras viverra nisl nisl, varius convallis leo vulputate nec. Morbi at consequat dui, sed aliquet metus. Sed suscipit fermentum mollis. Maecenas nec mi sodales, tincidunt purus in, tristique mauris. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec interdum mi in velit efficitur, quis ultrices ex imperdiet. Sed vestibulum, magna ut tristique pretium, mi ipsum placerat tellus, non tempor enim augue et ex. Pellentesque eget felis quis ante sodales viverra ac sed lacus. Donec suscipit tempus massa, eget laoreet massa molestie at.
+
+Aenean fringilla dui non aliquet consectetur. Fusce cursus quam nec orci hendrerit faucibus. Donec consequat suscipit enim, non volutpat lectus auctor interdum. Proin lorem purus, maximus vel orci vitae, suscipit egestas turpis. Donec risus urna, congue a sem eu, aliquet placerat odio. Morbi gravida tristique turpis, quis efficitur enim. Nunc interdum gravida ipsum vel facilisis. Nunc congue finibus sollicitudin. Quisque euismod aliquet lectus et tincidunt. Curabitur ultrices sem ut mi fringilla fermentum. Morbi pretium, nisi sit amet dapibus congue, dolor enim consectetur risus, a interdum ligula odio sed odio. Quisque facilisis, mi at suscipit gravida, nunc sapien cursus justo, ut luctus odio nulla quis leo. Integer condimentum lobortis mauris, non egestas tellus lobortis sit amet.
+
+In sollicitudin velit ac ante vehicula, vitae varius tortor mollis. In hac habitasse platea dictumst. Quisque et orci lorem. Integer malesuada fringilla luctus. Pellentesque malesuada, mi non lobortis porttitor, ante ligula vulputate ante, nec dictum risus eros sit amet sapien. Nulla aliquam lorem libero, ac varius nulla tristique eget. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Ut pellentesque mauris orci, vel consequat mi varius a. Ut sit amet elit vulputate, lacinia metus non, fermentum nisl. Pellentesque eu nisi sed quam egestas blandit. Duis sit amet lobortis dolor. Donec consectetur sem interdum, tristique elit sit amet, sodales lacus. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Fusce id aliquam augue. Sed pretium congue risus vitae lacinia. Vestibulum non vulputate risus, ut malesuada justo.
+
+Sed odio elit, consectetur ac mauris quis, consequat commodo libero. Fusce sodales velit vulputate pulvinar fermentum. Donec iaculis nec nisl eget faucibus. Mauris at dictum velit. Donec fermentum lectus eu viverra volutpat. Aliquam consequat facilisis lorem, cursus consequat dui bibendum ullamcorper. Pellentesque nulla magna, imperdiet at magna et, cursus egestas enim. Nullam semper molestie lectus sit amet semper. Duis eget tincidunt est. Integer id neque risus. Integer ultricies hendrerit vestibulum. Donec blandit blandit sagittis. Nunc consectetur vitae nisi consectetur volutpat.
+
+Nulla id lorem fermentum, efficitur magna a, hendrerit dui. Vivamus sagittis orci gravida, bibendum quam eget, molestie est. Phasellus nec enim tincidunt, volutpat sapien non, laoreet diam. Nulla posuere enim nec porttitor lobortis. Donec auctor odio ut orci eleifend, ut eleifend purus convallis. Interdum et malesuada fames ac ante ipsum primis in faucibus. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut hendrerit, purus eget viverra tincidunt, sem magna imperdiet libero, et aliquam turpis neque vitae elit. Maecenas semper varius iaculis. Cras non lorem quis quam bibendum eleifend in et libero. Curabitur at purus mauris. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Vivamus porta diam sed elit eleifend gravida.
+
+Nulla facilisi. Ut ultricies diam vel diam consectetur, vel porta augue molestie. Fusce interdum sapien et metus facilisis pellentesque. Nulla convallis sem at nunc vehicula facilisis. Nam ac rutrum purus. Nunc bibendum, dolor sit amet tempus ullamcorper, lorem leo tempor sem, id fringilla nunc augue scelerisque augue. Nullam sit amet rutrum nisl. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Donec sed mauris gravida eros vehicula sagittis at eget orci. Cras elementum, eros at accumsan bibendum, libero neque blandit purus, vitae vestibulum libero massa ac nibh. Integer at placerat nulla. Mauris eu eleifend orci. Aliquam consequat ligula vitae erat porta lobortis. Duis fermentum elit ac aliquet ornare.
+
+Mauris eget cursus tellus, eget sodales purus. Aliquam malesuada, augue id vulputate finibus, nisi ex bibendum nisl, sit amet laoreet quam urna a dolor. Nullam ultricies, sapien eu laoreet consequat, erat eros dignissim diam, ultrices sodales lectus mauris et leo. Morbi lacinia eu ante at tempus. Sed iaculis finibus magna malesuada efficitur. Donec faucibus erat sit amet elementum feugiat. Praesent a placerat nisi. Etiam lacinia gravida diam, et sollicitudin sapien tincidunt ut.
+
+Maecenas felis quam, tincidunt vitae venenatis scelerisque, viverra vitae odio. Phasellus enim neque, ultricies suscipit malesuada sit amet, vehicula sit amet purus. Nulla placerat sit amet dui vel tincidunt. Nam quis neque vel magna commodo egestas. Vestibulum sagittis rutrum lorem ut congue. Maecenas vel ultrices tellus. Donec efficitur, urna ac consequat iaculis, lorem felis pharetra eros, eget faucibus orci lectus sit amet arcu.
+
+Ut a tempus nisi. Nulla facilisi. Praesent vulputate maximus mi et dapibus. Sed sit amet libero ac augue hendrerit efficitur in a sapien. Mauris placerat velit sit amet tellus sollicitudin faucibus. Donec egestas a magna ac suscipit. Duis enim sapien, mollis sed egestas et, vestibulum vel leo.
+
+Proin quis dapibus dui. Donec eu tincidunt nunc. Vivamus eget purus consectetur, maximus ante vitae, tincidunt elit. Aenean mattis dolor a gravida aliquam. Praesent quis tellus id sem maximus vulputate nec sed nulla. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur metus nulla, volutpat volutpat est eu, hendrerit congue erat. Aliquam sollicitudin augue ante. Sed sollicitudin, magna eu consequat elementum, mi augue ullamcorper felis, molestie imperdiet erat metus iaculis est. Proin ac tortor nisi. Pellentesque quis nisi risus. Integer enim sapien, tincidunt quis tortor id, accumsan venenatis mi. Nulla facilisi.
+
+Cras pretium sit amet quam congue maximus. Morbi lacus libero, imperdiet commodo massa sed, scelerisque placerat libero. Cras nisl nisi, consectetur sed bibendum eu, venenatis at enim. Proin sodales justo at quam aliquam, a consectetur mi ornare. Donec porta ac est sit amet efficitur. Suspendisse vestibulum tortor id neque imperdiet, id lacinia risus vehicula. Phasellus ac eleifend purus. Mauris vel gravida ante. Aliquam vitae lobortis risus. Sed vehicula consectetur tincidunt. Nam et justo vitae purus molestie consequat. Pellentesque ipsum ex, convallis quis blandit non, gravida et urna. Donec diam ligula amet.
diff --git a/pkg/sentry/fsimpl/ext/assets/file.txt b/pkg/sentry/fsimpl/ext/assets/file.txt
new file mode 100644
index 000000000..980a0d5f1
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/file.txt
@@ -0,0 +1 @@
+Hello World!
diff --git a/pkg/sentry/fsimpl/ext/assets/symlink.txt b/pkg/sentry/fsimpl/ext/assets/symlink.txt
new file mode 120000
index 000000000..4c330738c
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/symlink.txt
@@ -0,0 +1 @@
+file.txt \ No newline at end of file
diff --git a/pkg/sentry/fsimpl/ext/assets/tiny.ext2 b/pkg/sentry/fsimpl/ext/assets/tiny.ext2
new file mode 100644
index 000000000..381ade9bf
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext2
Binary files differ
diff --git a/pkg/sentry/fsimpl/ext/assets/tiny.ext3 b/pkg/sentry/fsimpl/ext/assets/tiny.ext3
new file mode 100644
index 000000000..0e97a324c
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext3
Binary files differ
diff --git a/pkg/sentry/fsimpl/ext/assets/tiny.ext4 b/pkg/sentry/fsimpl/ext/assets/tiny.ext4
new file mode 100644
index 000000000..a6859736d
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext4
Binary files differ
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
new file mode 100644
index 000000000..6c5a559fd
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "benchmark_test",
+ size = "small",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/ext",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
new file mode 100644
index 000000000..89caee3df
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -0,0 +1,206 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// These benchmarks emulate memfs benchmarks. Before running them, ext4
+// images must be created at /tmp/image-{depth}.ext4, for all depths tested
+// below, using the `make_deep_ext4.sh` script.
+//
+// The benchmarks themselves cannot run the script because the script requires
+// sudo privileges to create the filesystem images.
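+//
+// For example (illustrative, assuming the script is run from this directory):
+//
+//	for d in 1 2 3 8 64 100; do
+//	  sudo bash make_deep_ext4.sh $d /tmp/image-$d.ext4
+//	done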
+package benchmark_test
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+var depths = []int{1, 2, 3, 8, 64, 100}
+
+const filename = "file.txt"
+
+// setUp opens imagePath as an ext Filesystem and returns the elements
+// required to run tests. If error is nil, it also returns a tear down
+// function which must be called after the test is run for clean up.
+func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) {
+ f, err := os.Open(imagePath)
+ if err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ ctx := contexttest.Context(b)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, nil, nil, err
+ }
+ vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
+ if err != nil {
+ f.Close()
+ return nil, nil, nil, nil, err
+ }
+
+ root := mntns.Root()
+
+ tearDown := func() {
+ root.DecRef()
+
+ if err := f.Close(); err != nil {
+ b.Fatalf("tearDown failed: %v", err)
+ }
+ }
+ return ctx, vfsObj, &root, tearDown, nil
+}
+
+// mount mounts extfs at the given path operation. It returns a tear down
+// function which must be called after the test is run for clean up.
+func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vfs.PathOperation) func() {
+ b.Helper()
+
+ f, err := os.Open(imagePath)
+ if err != nil {
+ b.Fatalf("could not open image at %s: %v", imagePath, err)
+ }
+
+ ctx := contexttest.Context(b)
+ creds := auth.CredentialsFromContext(ctx)
+
+ if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: int(f.Fd()),
+ },
+ }); err != nil {
+ b.Fatalf("failed to mount tmpfs submount: %v", err)
+ }
+ return func() {
+ if err := f.Close(); err != nil {
+ b.Fatalf("tearDown failed: %v", err)
+ }
+ }
+}
+
+// BenchmarkVFS2Ext4fsStat emulates BenchmarkVFS2MemfsStat.
+func BenchmarkVFS2Ext4fsStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", depth))
+ if err != nil {
+ b.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ creds := auth.CredentialsFromContext(ctx)
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteByte('/')
+ for i := 1; i <= depth; i++ {
+ filePathBuilder.WriteString(fmt.Sprintf("%d", i))
+ filePathBuilder.WriteByte('/')
+ }
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: *root,
+ Start: *root,
+ Path: fspath.Parse(filePath),
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ // Sanity check. touch(1) always creates files of size 0 (empty).
+ if stat.Size > 0 {
+ b.Fatalf("got wrong file size (%d)", stat.Size)
+ }
+ }
+ })
+ }
+}
+
+// BenchmarkVFS2ExtfsMountStat emulates BenchmarkVFS2MemfsMountStat.
+func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ // Create root extfs with depth 1 so we can mount extfs again at /1/.
+ ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", 1))
+ if err != nil {
+ b.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ creds := auth.CredentialsFromContext(ctx)
+ mountPointName := "/1/"
+ pop := vfs.PathOperation{
+ Root: *root,
+ Start: *root,
+ Path: fspath.Parse(mountPointName),
+ }
+
+ // Save the mount point for later use.
+ mountPoint, err := vfsfs.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ b.Fatalf("failed to walk to mount point: %v", err)
+ }
+ defer mountPoint.DecRef()
+
+ // Create extfs submount.
+ mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop)
+ defer mountTearDown()
+
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteString(mountPointName)
+ for i := 1; i <= depth; i++ {
+ filePathBuilder.WriteString(fmt.Sprintf("%d", i))
+ filePathBuilder.WriteByte('/')
+ }
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: *root,
+ Start: *root,
+ Path: fspath.Parse(filePath),
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ // Sanity check. touch(1) always creates files of size 0 (empty).
+ if stat.Size > 0 {
+ b.Fatalf("got wrong file size (%d)", stat.Size)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh
new file mode 100755
index 000000000..d0910da1f
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script creates an ext4 image with $1 depth of directories and a file in
+# the inner most directory. The created file is at path /1/2/.../depth/file.txt.
+# The ext4 image is written to $2. The image is temporarily mounted at
+# /tmp/mountpoint. This script must be run with sudo privileges.
+
+# Usage:
+# sudo bash make_deep_ext4.sh {depth} {output path}
+
+# Check positional arguments.
+if [ "$#" -ne 2 ]; then
+ echo "Usage: sudo bash make_deep_ext4.sh {depth} {output path}"
+ exit 1
+fi
+
+# Make sure depth is a non-negative number.
+if ! [[ "$1" =~ ^[0-9]+$ ]]; then
+ echo "Depth must be a non-negative number."
+ exit 1
+fi
+
+# Create a 1 MB filesystem image at the requested output path.
+rm -f $2
+fallocate -l 1M $2
+if [ $? -ne 0 ]; then
+ echo "fallocate failed"
+ exit 1
+fi
+
+# Convert that blank into an ext4 image.
+mkfs.ext4 -j $2
+if [ $? -ne 0 ]; then
+ echo "mkfs.ext4 failed"
+ exit 1
+fi
+
+# Mount the image.
+MOUNTPOINT=/tmp/mountpoint
+mkdir -p $MOUNTPOINT
+mount -o loop $2 $MOUNTPOINT
+if [ $? -ne 0 ]; then
+ echo "mount failed"
+ exit 1
+fi
+
+# Create nested directories and the file.
+if [ "$1" -eq 0 ]; then
+ FILEPATH=$MOUNTPOINT/file.txt
+else
+ FILEPATH=$MOUNTPOINT/$(seq -s '/' 1 $1)/file.txt
+fi
+mkdir -p $(dirname $FILEPATH) || exit
+touch $FILEPATH
+
+# Clean up.
+umount $MOUNTPOINT
+rm -rf $MOUNTPOINT
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
new file mode 100644
index 000000000..8bb104ff0
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -0,0 +1,201 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // numDirectBlks is the number of direct blocks in ext block map inodes.
+ numDirectBlks = 12
+)
+
+// blockMapFile is a type of regular file which uses direct/indirect block
+// addressing to store file data. This was deprecated in ext4.
+type blockMapFile struct {
+ regFile regularFile
+
+ // directBlks are the direct block numbers. The physical blocks pointed to
+ // by these hold file data. Contains file blocks 0 to 11.
+ directBlks [numDirectBlks]uint32
+
+ // indirectBlk is the physical block which contains (blkSize/4) direct block
+ // numbers (as uint32 integers).
+ indirectBlk uint32
+
+ // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect
+ // block numbers (as uint32 integers).
+ doubleIndirectBlk uint32
+
+ // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly
+ // indirect block numbers (as uint32 integers).
+ tripleIndirectBlk uint32
+
+ // coverage at (i)th index indicates the amount of file data a node at
+ // height (i) covers. Height 0 is the direct block.
+ coverage [4]uint64
+}
+
+// Compiles only if blockMapFile implements io.ReaderAt.
+var _ io.ReaderAt = (*blockMapFile)(nil)
+
+// newBlockMapFile is the blockMapFile constructor. It initializes the file to
+// physical blocks map with (at most) the first 12 (direct) blocks.
+func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
+ file := &blockMapFile{}
+ file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
+
+ for i := uint(0); i < 4; i++ {
+ file.coverage[i] = getCoverage(file.regFile.inode.blkSize, i)
+ }
+
+ blkMap := file.regFile.inode.diskInode.Data()
+ binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
+ binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
+ binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
+ binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk)
+ return file, nil
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ if off < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ offset := uint64(off)
+ size := f.regFile.inode.diskInode.Size()
+ if offset >= size {
+ return 0, io.EOF
+ }
+
+ // dirBlksEnd is the file offset until which direct blocks cover file data.
+ // Direct blocks cover 0 <= file offset < dirBlksEnd.
+ dirBlksEnd := numDirectBlks * f.coverage[0]
+
+ // indirBlkEnd is the file offset until which the indirect block covers file
+ // data. The indirect block covers dirBlksEnd <= file offset < indirBlkEnd.
+ indirBlkEnd := dirBlksEnd + f.coverage[1]
+
+ // doubIndirBlkEnd is the file offset until which the double indirect block
+ // covers file data. The double indirect block covers the range
+ // indirBlkEnd <= file offset < doubIndirBlkEnd.
+ doubIndirBlkEnd := indirBlkEnd + f.coverage[2]
+
+ read := 0
+ toRead := len(dst)
+ if uint64(toRead)+offset > size {
+ toRead = int(size - offset)
+ }
+ for read < toRead {
+ var err error
+ var curR int
+
+ // Figure out which block to delegate the read to.
+ switch {
+ case offset < dirBlksEnd:
+ // Direct block.
+ curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:])
+ case offset < indirBlkEnd:
+ // Indirect block.
+ curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:])
+ case offset < doubIndirBlkEnd:
+ // Doubly indirect block.
+ curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:])
+ default:
+ // Triply indirect block.
+ curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:])
+ }
+
+ read += curR
+ offset += uint64(curR)
+ if err != nil {
+ return read, err
+ }
+ }
+
+ if read < len(dst) {
+ return read, io.EOF
+ }
+ return read, nil
+}
+
+// read is the recursive step of the ReadAt function. It relies on knowing the
+// current node's location on disk (curPhyBlk) and its height in the block map
+// tree. A height of 0 means that the current node actually holds file data.
+// relFileOff is the offset from which we need to start reading under the
+// current node; it is completely relative to the current node.
+func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, dst []byte) (int, error) {
+ curPhyBlkOff := int64(curPhyBlk) * int64(f.regFile.inode.blkSize)
+ if height == 0 {
+ toRead := int(f.regFile.inode.blkSize - relFileOff)
+ if len(dst) < toRead {
+ toRead = len(dst)
+ }
+
+ n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff))
+ if n < toRead {
+ return n, syserror.EIO
+ }
+ return n, nil
+ }
+
+ childCov := f.coverage[height-1]
+ startIdx := relFileOff / childCov
+ endIdx := f.regFile.inode.blkSize / 4 // This is exclusive.
+ wantEndIdx := (relFileOff + uint64(len(dst))) / childCov
+ wantEndIdx++ // Make this exclusive.
+ if wantEndIdx < endIdx {
+ endIdx = wantEndIdx
+ }
+
+ read := 0
+ curChildOff := relFileOff % childCov
+ for i := startIdx; i < endIdx; i++ {
+ var childPhyBlk uint32
+ err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
+ if err != nil {
+ return read, err
+ }
+
+ n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:])
+ read += n
+ if err != nil {
+ return read, err
+ }
+
+ curChildOff = 0
+ }
+
+ return read, nil
+}
+
+// getCoverage returns the number of bytes a node at the given height covers.
+// Height 0 is the file data block itself. Height 1 is the indirect block.
+//
+// Formula: blkSize * ((blkSize / 4)^height)
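+//
+// For example (worked out for a 1 KiB block size, where each block number is
+// 4 bytes and an indirect block therefore holds 256 entries):
+//   height 0 (direct block):          1024 * 256^0 bytes = 1 KiB
+//   height 1 (indirect block):        1024 * 256^1 bytes = 256 KiB
+//   height 2 (doubly indirect block): 1024 * 256^2 bytes = 64 MiB
+//   height 3 (triply indirect block): 1024 * 256^3 bytes = 16 GiB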
+func getCoverage(blkSize uint64, height uint) uint64 {
+ return blkSize * uint64(math.Pow(float64(blkSize/4), float64(height)))
+}
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
new file mode 100644
index 000000000..6fa84e7aa
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -0,0 +1,156 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+)
+
+// These consts are for mocking the block map tree.
+const (
+ mockBMBlkSize = uint32(16)
+ mockBMDiskSize = 2500
+)
+
+// TestBlockMapReader stress tests block map reader functionality. It performs
+// random length reads from all possible positions in the block map structure.
+func TestBlockMapReader(t *testing.T) {
+ mockBMFile, want := blockMapSetUp(t)
+ n := len(want)
+
+ for from := 0; from < n; from++ {
+ got := make([]byte, n-from)
+
+ if read, err := mockBMFile.ReadAt(got, int64(from)); err != nil {
+ t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err)
+ }
+
+ if diff := cmp.Diff(got, want[from:]); diff != "" {
+ t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff)
+ }
+ }
+}
+
+// blkNumGen is a number generator which gives block numbers for building the
+// block map file on disk. It gives unique numbers in a random order, which
+// facilitates creating an extremely fragmented filesystem.
+type blkNumGen struct {
+ nums []uint32
+}
+
+// newBlkNumGen is the blkNumGen constructor.
+func newBlkNumGen() *blkNumGen {
+ blkNums := &blkNumGen{}
+ lim := mockBMDiskSize / mockBMBlkSize
+ blkNums.nums = make([]uint32, lim)
+ for i := uint32(0); i < lim; i++ {
+ blkNums.nums[i] = i
+ }
+
+ rand.Shuffle(int(lim), func(i, j int) {
+ blkNums.nums[i], blkNums.nums[j] = blkNums.nums[j], blkNums.nums[i]
+ })
+ return blkNums
+}
+
+// next returns the next random block number.
+func (n *blkNumGen) next() uint32 {
+ ret := n.nums[0]
+ n.nums = n.nums[1:]
+ return ret
+}
+
+// blockMapSetUp creates a mock disk and a block map file. It initializes the
+// block map file with 12 direct blocks, 1 indirect block, 1 doubly indirect
+// block and 1 triply indirect block (basically filling it to the brim). It
+// initializes the disk to reflect the inode. It also returns the file data
+// that the inode covers and that is written to disk.
+func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
+ mockDisk := make([]byte, mockBMDiskSize)
+ var fileData []byte
+ blkNums := newBlkNumGen()
+ var data []byte
+
+ // Write the direct blocks.
+ for i := 0; i < numDirectBlks; i++ {
+ curBlkNum := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+ }
+
+ // Write to the indirect block.
+ indirectBlk := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
+
+ // Write to the doubly indirect block.
+ doublyIndirectBlk := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
+
+ // Write to the triply indirect block.
+ triplyIndirectBlk := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: getMockBMFileSize(),
+ },
+ },
+ blkSize: uint64(mockBMBlkSize),
+ }
+ copy(args.diskInode.Data(), data)
+
+ mockFile, err := newBlockMapFile(args)
+ if err != nil {
+ t.Fatalf("newBlockMapFile failed: %v", err)
+ }
+ return mockFile, fileData
+}
+
+// writeFileDataToBlock writes random bytes to the block on disk.
+func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkNumGen) []byte {
+ if height == 0 {
+ start := blkNum * mockBMBlkSize
+ end := start + mockBMBlkSize
+ rand.Read(disk[start:end])
+ return disk[start:end]
+ }
+
+ var fileData []byte
+ for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
+ curBlkNum := blkNums.next()
+ copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
+ fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+ }
+ return fileData
+}
+
+// getMockBMFileSize gets the size of the mock block map file which is used
+// for testing.
+func getMockBMFileSize() uint32 {
+ return uint32(numDirectBlks*getCoverage(uint64(mockBMBlkSize), 0) + getCoverage(uint64(mockBMBlkSize), 1) + getCoverage(uint64(mockBMBlkSize), 2) + getCoverage(uint64(mockBMBlkSize), 3))
+}
diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
new file mode 100644
index 000000000..55902322a
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/dentry.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ // Protected by filesystem.mu.
+ parent *dentry
+ name string
+
+ // inode is the inode represented by this dentry. Multiple Dentries may
+ // share a single non-directory Inode (with hard links). inode is
+ // immutable.
+ inode *inode
+}
+
+// Compiles only if dentry implements vfs.DentryImpl.
+var _ vfs.DentryImpl = (*dentry)(nil)
+
+// newDentry is the dentry constructor.
+func newDentry(in *inode) *dentry {
+ d := &dentry{
+ inode: in,
+ }
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ d.inode.incRef()
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ return d.inode.tryIncRef()
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef() {
+ // FIXME(b/134676337): filesystem.mu may not be locked as required by
+ // inode.decRef().
+ d.inode.decRef()
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {}
+
+// Watches implements vfs.DentryImpl.Watches.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) Watches() *vfs.Watches {
+ return nil
+}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) OnZeroWatches() {}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
new file mode 100644
index 000000000..357512c7e
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -0,0 +1,318 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// directory represents a directory inode. It holds the childList in memory.
+type directory struct {
+ inode inode
+
+ // childCache maps filenames to dentries for children for which dentries
+ // have been instantiated. childCache is protected by filesystem.mu.
+ childCache map[string]*dentry
+
+ // mu serializes the changes to childList.
+ // Lock Order (outermost locks must be taken first):
+ // directory.mu
+ // filesystem.mu
+ mu sync.Mutex
+
+ // childList is a list containing (1) child dirents and (2) fake dirents
+ // (with diskDirent == nil) that represent the iteration position of
+ // directoryFDs. childList is used to support directoryFD.IterDirents()
+ // efficiently. childList is protected by mu.
+ childList direntList
+
+ // childMap maps the child's filename to the dirent structure stored in
+ // childList. This adds some data replication but enables faster path
+ // traversal. For consistency, key == childMap[key].diskDirent.FileName().
+ // Immutable.
+ childMap map[string]*dirent
+}
+
+// newDirectory is the directory constructor.
+func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
+ file := &directory{
+ childCache: make(map[string]*dentry),
+ childMap: make(map[string]*dirent),
+ }
+ file.inode.init(args, file)
+
+ // Initialize childList by reading dirents from the underlying file.
+ if args.diskInode.Flags().Index {
+ // TODO(b/134676337): Support hash tree directories. Currently only the '.'
+ // and '..' entries are read in.
+
+ // Users cannot navigate this hash tree directory yet.
+ log.Warningf("hash tree directory being used which is unsupported")
+ return file, nil
+ }
+
+ // The dirents are organized in a linear array in the file data.
+ // Extract the file data and decode the dirents.
+ regFile, err := newRegularFile(args)
+ if err != nil {
+ return nil, err
+ }
+
+ // buf is used as scratch space for reading in dirents from disk and
+ // unmarshalling them into dirent structs.
+ buf := make([]byte, disklayout.DirentSize)
+ size := args.diskInode.Size()
+ for off, inc := uint64(0), uint64(0); off < size; off += inc {
+ toRead := size - off
+ if toRead > disklayout.DirentSize {
+ toRead = disklayout.DirentSize
+ }
+ if n, err := regFile.impl.ReadAt(buf[:toRead], int64(off)); uint64(n) < toRead {
+ return nil, err
+ }
+
+ var curDirent dirent
+ if newDirent {
+ curDirent.diskDirent = &disklayout.DirentNew{}
+ } else {
+ curDirent.diskDirent = &disklayout.DirentOld{}
+ }
+ binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+
+ if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
+ // An inode number and name length of 0 indicate an unused dirent.
+ file.childList.PushBack(&curDirent)
+ file.childMap[curDirent.diskDirent.FileName()] = &curDirent
+ }
+
+ // The next dirent is placed exactly after this dirent record on disk.
+ inc = uint64(curDirent.diskDirent.RecordSize())
+ }
+
+ return file, nil
+}
+
+func (i *inode) isDir() bool {
+ _, ok := i.impl.(*directory)
+ return ok
+}
+
+// dirent is the directory.childList node.
+type dirent struct {
+ diskDirent disklayout.Dirent
+
+ // direntEntry links dirents into their parent directory.childList.
+ direntEntry
+}
+
+// directoryFD represents a directory file description. It implements
+// vfs.FileDescriptionImpl.
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ // Protected by directory.mu.
+ iter *dirent
+ off int64
+}
+
+// Compiles only if directoryFD implements vfs.FileDescriptionImpl.
+var _ vfs.FileDescriptionImpl = (*directoryFD)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
+ if fd.iter == nil {
+ return
+ }
+
+ dir := fd.inode().impl.(*directory)
+ dir.mu.Lock()
+ dir.childList.Remove(fd.iter)
+ dir.mu.Unlock()
+ fd.iter = nil
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ extfs := fd.filesystem()
+ dir := fd.inode().impl.(*directory)
+
+ dir.mu.Lock()
+ defer dir.mu.Unlock()
+
+ // Ensure that fd.iter exists and is not linked into dir.childList.
+ var child *dirent
+ if fd.iter == nil {
+ // Start iteration at the beginning of dir.
+ child = dir.childList.Front()
+ fd.iter = &dirent{}
+ } else {
+ // Continue iteration from where we left off.
+ child = fd.iter.Next()
+ dir.childList.Remove(fd.iter)
+ }
+ for ; child != nil; child = child.Next() {
+ // Skip other directoryFD iterators.
+ if child.diskDirent != nil {
+ childType, ok := child.diskDirent.FileType()
+ if !ok {
+ // We will need to read the inode off disk. Do not increment
+ // ref count here because this inode is not being added to the
+ // dentry tree.
+ extfs.mu.Lock()
+ childInode, err := extfs.getOrCreateInodeLocked(child.diskDirent.Inode())
+ extfs.mu.Unlock()
+ if err != nil {
+ // Usage of the file description after the error is
+ // undefined. This implementation would continue reading
+ // from the next dirent.
+ fd.off++
+ dir.childList.InsertAfter(child, fd.iter)
+ return err
+ }
+ childType = fs.ToInodeType(childInode.diskInode.Mode().FileType())
+ }
+
+ if err := cb.Handle(vfs.Dirent{
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
+ }); err != nil {
+ dir.childList.InsertBefore(child, fd.iter)
+ return err
+ }
+ fd.off++
+ }
+ }
+ dir.childList.PushBack(fd.iter)
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ if whence != linux.SEEK_SET && whence != linux.SEEK_CUR {
+ return 0, syserror.EINVAL
+ }
+
+ dir := fd.inode().impl.(*directory)
+
+ dir.mu.Lock()
+ defer dir.mu.Unlock()
+
+ // Find resulting offset.
+ if whence == linux.SEEK_CUR {
+ offset += fd.off
+ }
+
+ if offset < 0 {
+ // lseek(2) specifies that EINVAL should be returned if the resulting offset
+ // is negative.
+ return 0, syserror.EINVAL
+ }
+
+ n := int64(len(dir.childMap))
+ realWantOff := offset
+ if realWantOff > n {
+ realWantOff = n
+ }
+ realCurOff := fd.off
+ if realCurOff > n {
+ realCurOff = n
+ }
+
+ // Ensure that fd.iter exists and is linked into dir.childList so we can
+ // intelligently seek from the optimal position.
+ if fd.iter == nil {
+ fd.iter = &dirent{}
+ dir.childList.PushFront(fd.iter)
+ }
+
+ // Guess that iterating from the current position is optimal.
+ child := fd.iter
+ diff := realWantOff - realCurOff // Shows direction and magnitude of travel.
+
+ // See if starting from the beginning or end is better.
+ abDiff := diff
+ if diff < 0 {
+ abDiff = -diff
+ }
+ if abDiff > realWantOff {
+ // Starting from the beginning is best.
+ child = dir.childList.Front()
+ diff = realWantOff
+ } else if abDiff > (n - realWantOff) {
+ // Starting from the end is best.
+ child = dir.childList.Back()
+ // (n - 1) because the last non-nil dirent represents the (n-1)th offset.
+ diff = realWantOff - (n - 1)
+ }
+
+ for child != nil {
+ // Skip other directoryFD iterators.
+ if child.diskDirent != nil {
+ if diff == 0 {
+ if child != fd.iter {
+ dir.childList.Remove(fd.iter)
+ dir.childList.InsertBefore(child, fd.iter)
+ }
+
+ fd.off = offset
+ return offset, nil
+ }
+
+ if diff < 0 {
+ diff++
+ child = child.Prev()
+ } else {
+ diff--
+ child = child.Next()
+ }
+ continue
+ }
+
+ if diff < 0 {
+ child = child.Prev()
+ } else {
+ child = child.Next()
+ }
+ }
+
+ // Reaching here indicates that the offset is beyond the end of the childList.
+ dir.childList.Remove(fd.iter)
+ dir.childList.PushBack(fd.iter)
+ fd.off = offset
+ return offset, nil
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
new file mode 100644
index 000000000..9bd9c76c0
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -0,0 +1,47 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "disklayout",
+ srcs = [
+ "block_group.go",
+ "block_group_32.go",
+ "block_group_64.go",
+ "dirent.go",
+ "dirent_new.go",
+ "dirent_old.go",
+ "disklayout.go",
+ "extent.go",
+ "inode.go",
+ "inode_new.go",
+ "inode_old.go",
+ "superblock.go",
+ "superblock_32.go",
+ "superblock_64.go",
+ "superblock_old.go",
+ "test_utils.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ ],
+)
+
+go_test(
+ name = "disklayout_test",
+ size = "small",
+ srcs = [
+ "block_group_test.go",
+ "dirent_test.go",
+ "extent_test.go",
+ "inode_test.go",
+ "superblock_test.go",
+ ],
+ library = ":disklayout",
+ deps = ["//pkg/sentry/kernel/time"],
+)
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
new file mode 100644
index 000000000..ad6f4fef8
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
@@ -0,0 +1,137 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// BlockGroup represents a Linux ext block group descriptor. An ext file system
+// is split into a series of block groups. This interface provides access to
+// the information needed to locate and use a block group.
+//
+// Location:
+// - The block group descriptor table is always placed in the blocks
+// immediately after the block containing the superblock.
+// - The 1st block group descriptor in the original table is in the
+// (sb.FirstDataBlock() + 1)th block.
+// - See SuperBlock docs to see where the block group descriptor table is
+// replicated.
+// - sb.BgDescSize() must be used as the block group descriptor entry size
+// while reading the table from disk.
+//
+// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors.
+type BlockGroup interface {
+ // InodeTable returns the absolute block number of the block containing the
+ // inode table. This points to an array of Inode structs. Inode tables are
+ // statically allocated at mkfs time. The superblock records the number of
+ // inodes per group (length of this table) and the size of each inode struct.
+ InodeTable() uint64
+
+ // BlockBitmap returns the absolute block number of the block containing the
+ // block bitmap. This bitmap tracks the usage of data blocks within this block
+ // group and has its own checksum.
+ BlockBitmap() uint64
+
+ // InodeBitmap returns the absolute block number of the block containing the
+ // inode bitmap. This bitmap tracks the usage of this group's inode table
+ // entries and has its own checksum.
+ InodeBitmap() uint64
+
+ // ExclusionBitmap returns the absolute block number of the snapshot exclusion
+ // bitmap.
+ ExclusionBitmap() uint64
+
+ // FreeBlocksCount returns the number of free blocks in the group.
+ FreeBlocksCount() uint32
+
+ // FreeInodesCount returns the number of free inodes in the group.
+ FreeInodesCount() uint32
+
+ // DirectoryCount returns the number of inodes that represent directories
+ // under this block group.
+ DirectoryCount() uint32
+
+ // UnusedInodeCount returns the number of unused inodes beyond the last used
+ // inode in this group's inode table. As a result, we needn’t scan past the
+ // (InodesPerGroup - UnusedInodeCount())th entry in the inode table.
+ UnusedInodeCount() uint32
+
+ // BlockBitmapChecksum returns the block bitmap checksum. This is calculated
+ // using crc32c(FS UUID + group number + entire bitmap).
+ BlockBitmapChecksum() uint32
+
+ // InodeBitmapChecksum returns the inode bitmap checksum. This is calculated
+ // using crc32c(FS UUID + group number + entire bitmap).
+ InodeBitmapChecksum() uint32
+
+ // Checksum returns this block group's checksum.
+ //
+ // If SbMetadataCsum feature is set:
+ // - checksum is crc32c(FS UUID + group number + group descriptor
+ // structure) & 0xFFFF.
+ //
+ // If SbGdtCsum feature is set:
+ // - checksum is crc16(FS UUID + group number + group descriptor
+ // structure).
+ //
+ // SbMetadataCsum and SbGdtCsum should not both be set; if they are,
+ // Linux warns and asks the user to run fsck.
+ Checksum() uint16
+
+ // Flags returns BGFlags which represents the block group flags.
+ Flags() BGFlags
+}
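+
+// Editor's sketch (not part of the original patch): how the location rules
+// above translate to byte offsets, assuming sb is a SuperBlock and groupNum
+// is the block group number.
+//
+//	tableOff := (uint64(sb.FirstDataBlock()) + 1) * sb.BlockSize()
+//	bgOff := tableOff + uint64(groupNum)*uint64(sb.BgDescSize())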
+
+// These are the different block group flags.
+const (
+ // BgInodeUninit indicates that inode table and bitmap are not initialized.
+ BgInodeUninit uint16 = 0x1
+
+ // BgBlockUninit indicates that block bitmap is not initialized.
+ BgBlockUninit uint16 = 0x2
+
+ // BgInodeZeroed indicates that inode table is zeroed.
+ BgInodeZeroed uint16 = 0x4
+)
+
+// BGFlags represents all the different combinations of block group flags.
+type BGFlags struct {
+ InodeUninit bool
+ BlockUninit bool
+ InodeZeroed bool
+}
+
+// ToInt converts a BGFlags struct back to its 16-bit representation.
+func (f BGFlags) ToInt() uint16 {
+ var res uint16
+
+ if f.InodeUninit {
+ res |= BgInodeUninit
+ }
+ if f.BlockUninit {
+ res |= BgBlockUninit
+ }
+ if f.InodeZeroed {
+ res |= BgInodeZeroed
+ }
+
+ return res
+}
+
+// BGFlagsFromInt converts the 16-bit flag representation to a BGFlags struct.
+func BGFlagsFromInt(flags uint16) BGFlags {
+ return BGFlags{
+ InodeUninit: flags&BgInodeUninit > 0,
+ BlockUninit: flags&BgBlockUninit > 0,
+ InodeZeroed: flags&BgInodeZeroed > 0,
+ }
+}
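+
+// Editor's sketch (illustrative only): decoding a raw on-disk flags field and
+// round-tripping it.
+//
+//	flags := BGFlagsFromInt(bg.FlagsRaw)
+//	if flags.InodeUninit {
+//		// The inode table and bitmap have not been initialized yet.
+//	}
+//	_ = flags.ToInt() // Reproduces the three defined bits of bg.FlagsRaw.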
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
new file mode 100644
index 000000000..3e16c76db
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// BlockGroup32Bit emulates the first half of struct ext4_group_desc in
+// fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and
+// 32-bit ext4 filesystems. It implements the BlockGroup interface.
+type BlockGroup32Bit struct {
+ BlockBitmapLo uint32
+ InodeBitmapLo uint32
+ InodeTableLo uint32
+ FreeBlocksCountLo uint16
+ FreeInodesCountLo uint16
+ UsedDirsCountLo uint16
+ FlagsRaw uint16
+ ExcludeBitmapLo uint32
+ BlockBitmapChecksumLo uint16
+ InodeBitmapChecksumLo uint16
+ ItableUnusedLo uint16
+ ChecksumRaw uint16
+}
+
+// Compiles only if BlockGroup32Bit implements BlockGroup.
+var _ BlockGroup = (*BlockGroup32Bit)(nil)
+
+// InodeTable implements BlockGroup.InodeTable.
+func (bg *BlockGroup32Bit) InodeTable() uint64 { return uint64(bg.InodeTableLo) }
+
+// BlockBitmap implements BlockGroup.BlockBitmap.
+func (bg *BlockGroup32Bit) BlockBitmap() uint64 { return uint64(bg.BlockBitmapLo) }
+
+// InodeBitmap implements BlockGroup.InodeBitmap.
+func (bg *BlockGroup32Bit) InodeBitmap() uint64 { return uint64(bg.InodeBitmapLo) }
+
+// ExclusionBitmap implements BlockGroup.ExclusionBitmap.
+func (bg *BlockGroup32Bit) ExclusionBitmap() uint64 { return uint64(bg.ExcludeBitmapLo) }
+
+// FreeBlocksCount implements BlockGroup.FreeBlocksCount.
+func (bg *BlockGroup32Bit) FreeBlocksCount() uint32 { return uint32(bg.FreeBlocksCountLo) }
+
+// FreeInodesCount implements BlockGroup.FreeInodesCount.
+func (bg *BlockGroup32Bit) FreeInodesCount() uint32 { return uint32(bg.FreeInodesCountLo) }
+
+// DirectoryCount implements BlockGroup.DirectoryCount.
+func (bg *BlockGroup32Bit) DirectoryCount() uint32 { return uint32(bg.UsedDirsCountLo) }
+
+// UnusedInodeCount implements BlockGroup.UnusedInodeCount.
+func (bg *BlockGroup32Bit) UnusedInodeCount() uint32 { return uint32(bg.ItableUnusedLo) }
+
+// BlockBitmapChecksum implements BlockGroup.BlockBitmapChecksum.
+func (bg *BlockGroup32Bit) BlockBitmapChecksum() uint32 { return uint32(bg.BlockBitmapChecksumLo) }
+
+// InodeBitmapChecksum implements BlockGroup.InodeBitmapChecksum.
+func (bg *BlockGroup32Bit) InodeBitmapChecksum() uint32 { return uint32(bg.InodeBitmapChecksumLo) }
+
+// Checksum implements BlockGroup.Checksum.
+func (bg *BlockGroup32Bit) Checksum() uint16 { return bg.ChecksumRaw }
+
+// Flags implements BlockGroup.Flags.
+func (bg *BlockGroup32Bit) Flags() BGFlags { return BGFlagsFromInt(bg.FlagsRaw) }
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
new file mode 100644
index 000000000..9a809197a
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
@@ -0,0 +1,93 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// BlockGroup64Bit emulates struct ext4_group_desc in fs/ext4/ext4.h.
+// It is the block group descriptor struct for 64-bit ext4 filesystems.
+// It implements the BlockGroup interface. It is an extension of the 32-bit
+// version of BlockGroup.
+type BlockGroup64Bit struct {
+ // We embed the 32-bit struct here because the 64-bit version is just an
+ // extension of the 32-bit version.
+ BlockGroup32Bit
+
+ // 64-bit specific fields.
+ BlockBitmapHi uint32
+ InodeBitmapHi uint32
+ InodeTableHi uint32
+ FreeBlocksCountHi uint16
+ FreeInodesCountHi uint16
+ UsedDirsCountHi uint16
+ ItableUnusedHi uint16
+ ExcludeBitmapHi uint32
+ BlockBitmapChecksumHi uint16
+ InodeBitmapChecksumHi uint16
+ _ uint32 // Padding to 64 bytes.
+}
+
+// Compiles only if BlockGroup64Bit implements BlockGroup.
+var _ BlockGroup = (*BlockGroup64Bit)(nil)
+
+// Methods to override. Checksum() and Flags() are not overridden.
+
+// InodeTable implements BlockGroup.InodeTable.
+func (bg *BlockGroup64Bit) InodeTable() uint64 {
+ return (uint64(bg.InodeTableHi) << 32) | uint64(bg.InodeTableLo)
+}
+
+// BlockBitmap implements BlockGroup.BlockBitmap.
+func (bg *BlockGroup64Bit) BlockBitmap() uint64 {
+ return (uint64(bg.BlockBitmapHi) << 32) | uint64(bg.BlockBitmapLo)
+}
+
+// InodeBitmap implements BlockGroup.InodeBitmap.
+func (bg *BlockGroup64Bit) InodeBitmap() uint64 {
+ return (uint64(bg.InodeBitmapHi) << 32) | uint64(bg.InodeBitmapLo)
+}
+
+// ExclusionBitmap implements BlockGroup.ExclusionBitmap.
+func (bg *BlockGroup64Bit) ExclusionBitmap() uint64 {
+ return (uint64(bg.ExcludeBitmapHi) << 32) | uint64(bg.ExcludeBitmapLo)
+}
+
+// FreeBlocksCount implements BlockGroup.FreeBlocksCount.
+func (bg *BlockGroup64Bit) FreeBlocksCount() uint32 {
+ return (uint32(bg.FreeBlocksCountHi) << 16) | uint32(bg.FreeBlocksCountLo)
+}
+
+// FreeInodesCount implements BlockGroup.FreeInodesCount.
+func (bg *BlockGroup64Bit) FreeInodesCount() uint32 {
+ return (uint32(bg.FreeInodesCountHi) << 16) | uint32(bg.FreeInodesCountLo)
+}
+
+// DirectoryCount implements BlockGroup.DirectoryCount.
+func (bg *BlockGroup64Bit) DirectoryCount() uint32 {
+ return (uint32(bg.UsedDirsCountHi) << 16) | uint32(bg.UsedDirsCountLo)
+}
+
+// UnusedInodeCount implements BlockGroup.UnusedInodeCount.
+func (bg *BlockGroup64Bit) UnusedInodeCount() uint32 {
+ return (uint32(bg.ItableUnusedHi) << 16) | uint32(bg.ItableUnusedLo)
+}
+
+// BlockBitmapChecksum implements BlockGroup.BlockBitmapChecksum.
+func (bg *BlockGroup64Bit) BlockBitmapChecksum() uint32 {
+ return (uint32(bg.BlockBitmapChecksumHi) << 16) | uint32(bg.BlockBitmapChecksumLo)
+}
+
+// InodeBitmapChecksum implements BlockGroup.InodeBitmapChecksum.
+func (bg *BlockGroup64Bit) InodeBitmapChecksum() uint32 {
+ return (uint32(bg.InodeBitmapChecksumHi) << 16) | uint32(bg.InodeBitmapChecksumLo)
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
new file mode 100644
index 000000000..0ef4294c0
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
@@ -0,0 +1,26 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "testing"
+)
+
+// TestBlockGroupSize tests that the block group descriptor structs are of the
+// correct size.
+func TestBlockGroupSize(t *testing.T) {
+ assertSize(t, BlockGroup32Bit{}, 32)
+ assertSize(t, BlockGroup64Bit{}, 64)
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
new file mode 100644
index 000000000..417b6cf65
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+const (
+ // MaxFileName is the maximum length of an ext fs file's name.
+ MaxFileName = 255
+
+ // DirentSize is the size of ext dirent structures (8 bytes of fixed
+ // fields plus MaxFileName bytes).
+ DirentSize = 263
+)
+
+var (
+ // inodeTypeByFileType maps ext4 file types to vfs inode types.
+ //
+ // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#ftype.
+ inodeTypeByFileType = map[uint8]fs.InodeType{
+ 0: fs.Anonymous,
+ 1: fs.RegularFile,
+ 2: fs.Directory,
+ 3: fs.CharacterDevice,
+ 4: fs.BlockDevice,
+ 5: fs.Pipe,
+ 6: fs.Socket,
+ 7: fs.Symlink,
+ }
+)
+
+// The Dirent interface should be implemented by structs representing ext
+// directory entries. These are for the linear classical directories which
+// just store a list of dirent structs. A directory is a series of data blocks
+// where each data block contains a linear array of dirents. The last entry
+// of a block has a record size that takes it to the end of the block. The
+// directory ends once dirInode.Size() bytes have been read from its blocks.
+//
+// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories.
+type Dirent interface {
+ // Inode returns the absolute inode number of the underlying inode.
+ // Inode number 0 signifies an unused dirent.
+ Inode() uint32
+
+ // RecordSize returns the record length of this dirent on disk. The next
+ // dirent in the dirent list should be read after these many bytes from
+ // the current dirent. Must be a multiple of 4.
+ RecordSize() uint16
+
+ // FileName returns the name of the file. It can be at most MaxFileName
+ // (255) bytes in length.
+ FileName() string
+
+ // FileType returns the inode type of the underlying inode. This is a
+ // performance hack so that we do not have to read the underlying inode struct
+ // to know the type of inode. This will only work when the SbDirentFileType
+ // feature is set. If not, the second returned value will be false indicating
+ // that user code has to use the inode mode to extract the file type.
+ FileType() (fs.InodeType, bool)
+}
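+
+// Editor's sketch (hypothetical, mirroring the parsing loop in ext/directory.go
+// earlier in this patch): walking a directory's dirents using RecordSize to
+// hop between records. buf holds DirentSize bytes read at the current offset.
+//
+//	var d DirentNew // Or DirentOld if SbDirentFileType is not set.
+//	binary.Unmarshal(buf, binary.LittleEndian, &d)
+//	if d.Inode() != 0 && len(d.FileName()) != 0 {
+//		// Used dirent; a 0 inode number and name length mark unused ones.
+//	}
+//	off += uint64(d.RecordSize()) // The next dirent follows immediately.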
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
new file mode 100644
index 000000000..29ae4a5c2
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
@@ -0,0 +1,61 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// DirentNew represents the ext4 directory entry struct. This emulates Linux's
+// ext4_dir_entry_2 struct. The FileName cannot be more than 255 bytes so we
+// only need 8 bits to store the NameLength. As a result, NameLength has been
+// shortened and the other 8 bits are used to encode the file type. Use the
+// FileTypeRaw field only if the SbDirentFileType feature is set.
+//
+// Note: This struct can be of variable size on disk. The one described below
+// is of maximum size and the bytes of FileNameRaw beyond NameLength might contain
+// garbage.
+type DirentNew struct {
+ InodeNumber uint32
+ RecordLength uint16
+ NameLength uint8
+ FileTypeRaw uint8
+ FileNameRaw [MaxFileName]byte
+}
+
+// Compiles only if DirentNew implements Dirent.
+var _ Dirent = (*DirentNew)(nil)
+
+// Inode implements Dirent.Inode.
+func (d *DirentNew) Inode() uint32 { return d.InodeNumber }
+
+// RecordSize implements Dirent.RecordSize.
+func (d *DirentNew) RecordSize() uint16 { return d.RecordLength }
+
+// FileName implements Dirent.FileName.
+func (d *DirentNew) FileName() string {
+ return string(d.FileNameRaw[:d.NameLength])
+}
+
+// FileType implements Dirent.FileType.
+func (d *DirentNew) FileType() (fs.InodeType, bool) {
+ if inodeType, ok := inodeTypeByFileType[d.FileTypeRaw]; ok {
+ return inodeType, true
+ }
+
+ panic(fmt.Sprintf("unknown file type %v", d.FileTypeRaw))
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
new file mode 100644
index 000000000..6fff12a6e
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import "gvisor.dev/gvisor/pkg/sentry/fs"
+
+// DirentOld represents the old directory entry struct which does not contain
+// the file type. This emulates Linux's ext4_dir_entry struct.
+//
+// Note: This struct can be of variable size on disk. The one described below
+// is of maximum size and the bytes of FileNameRaw beyond NameLength might contain
+// garbage.
+type DirentOld struct {
+ InodeNumber uint32
+ RecordLength uint16
+ NameLength uint16
+ FileNameRaw [MaxFileName]byte
+}
+
+// Compiles only if DirentOld implements Dirent.
+var _ Dirent = (*DirentOld)(nil)
+
+// Inode implements Dirent.Inode.
+func (d *DirentOld) Inode() uint32 { return d.InodeNumber }
+
+// RecordSize implements Dirent.RecordSize.
+func (d *DirentOld) RecordSize() uint16 { return d.RecordLength }
+
+// FileName implements Dirent.FileName.
+func (d *DirentOld) FileName() string {
+ return string(d.FileNameRaw[:d.NameLength])
+}
+
+// FileType implements Dirent.FileType.
+func (d *DirentOld) FileType() (fs.InodeType, bool) {
+ return fs.Anonymous, false
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
new file mode 100644
index 000000000..934919f8a
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
@@ -0,0 +1,26 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "testing"
+)
+
+// TestDirentSize tests that the dirent structs are of the correct
+// size.
+func TestDirentSize(t *testing.T) {
+ assertSize(t, DirentOld{}, uintptr(DirentSize))
+ assertSize(t, DirentNew{}, uintptr(DirentSize))
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
new file mode 100644
index 000000000..bdf4e2132
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
@@ -0,0 +1,50 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package disklayout provides Linux ext file system's disk level structures
+// which can be read in directly from the underlying device. The structs aim
+// to emulate `exactly` how these structures are laid out on disk.
+//
+// This library aims to be compatible with all ext(2/3/4) systems so it
+// provides a generic interface for all major structures and various
+// implementations (for different versions). The user code is responsible for
+// using appropriate implementations based on the underlying device.
+//
+// Interfacing all major structures here serves a few purposes:
+// - Abstracts away the complexity of the underlying structure from client
+// code. The client only has to figure out versioning at setup and can then
+// use these as black boxes, passing them higher up the stack.
+// - Having pointer receivers forces the user to use pointers to these
+// heavy structs. Hence, prevents the client code from unintentionally
+// copying these by value while passing the interface around.
+// - Version-based implementation selection is resolved at setup, hence
+// avoiding the per-call overhead of choosing an implementation.
+// - All interface methods are lightweight (they do not take in any
+// parameters by design). Passing pointer arguments to interface methods
+// can lead to heap allocation as the compiler won't be able to perform
+// escape analysis on an unknown implementation at compile time.
+//
+// Notes:
+// - All fields in these structs are exported because binary.Read would
+// panic otherwise.
+// - All structures on disk are in little-endian order. Only jbd2 (journal)
+// structures are in big-endian order.
+// - All OS dependent fields in these structures will be interpreted using
+// the Linux version of that field.
+// - The suffix `Lo` in field names stands for lower bits of that field.
+// - The suffix `Hi` in field names stands for upper bits of that field.
+// - The suffix `Raw` has been added to indicate that the field is not split
+// into Lo and Hi fields and also to resolve name collision with the
+// respective interface.
+package disklayout
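+
+// Editor's sketch (illustrative only, using the pkg/binary helpers this patch
+// uses elsewhere): reading one of these little-endian structs from a buffer
+// filled off the device.
+//
+//	var bg BlockGroup32Bit
+//	binary.Unmarshal(buf, binary.LittleEndian, &bg)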
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
new file mode 100644
index 000000000..4110649ab
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// Extents were introduced in ext4 and provide huge performance gains in terms
+// of data locality and reduced metadata block usage. Extents are organized in
+// extent trees. The root node is contained in inode.BlocksRaw.
+//
+// Terminology:
+// - Physical Block:
+// Filesystem data block which is addressed normally wrt the entire
+// filesystem (addressed with 48 bits).
+//
+// - File Block:
+// Data block containing *only* file data and addressed wrt the file
+// with only 32 bits. The (i)th file block contains file data from
+// byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()).
+
+const (
+ // ExtentHeaderSize is the size of the header of an extent tree node.
+ ExtentHeaderSize = 12
+
+ // ExtentEntrySize is the size of an entry in an extent tree node.
+ // This size is the same for both leaf and internal nodes.
+ ExtentEntrySize = 12
+
+ // ExtentMagic is the magic number which must be present in the header.
+ ExtentMagic = 0xf30a
+)
+
+// ExtentEntryPair couples an in-memory ExtentNode with the ExtentEntry that
+// points to it. We want to cache these structs in memory to avoid repeated
+// disk reads.
+//
+// Note: This struct itself does not represent an on-disk struct.
+type ExtentEntryPair struct {
+ // Entry points to the child node on disk.
+ Entry ExtentEntry
+ // Node points to child node in memory. Is nil if the current node is a leaf.
+ Node *ExtentNode
+}
+
+// ExtentNode represents an extent tree node. For internal nodes, all Entries
+// will be ExtentIdxs. For leaf nodes, they will all be Extents.
+//
+// Note: This struct itself does not represent an on-disk struct.
+type ExtentNode struct {
+ Header ExtentHeader
+ Entries []ExtentEntryPair
+}
+
+// ExtentEntry represents an extent tree node entry. The entry can either be
+// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
+type ExtentEntry interface {
+ // FileBlock returns the first file block number covered by this entry.
+ FileBlock() uint32
+
+ // PhysicalBlock returns the child physical block that this entry points to.
+ PhysicalBlock() uint64
+}
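+
+// lookupExtent is an editor's sketch (not part of the original patch; the
+// real traversal lives in the ext package). It descends an in-memory extent
+// tree to the leaf entry covering fileBlock, assuming a well-formed tree with
+// non-empty Entries and cached child Nodes.
+func lookupExtent(node *ExtentNode, fileBlock uint32) ExtentEntry {
+ for {
+ // Entries are sorted by FileBlock(); take the last entry that starts
+ // at or before fileBlock.
+ i := len(node.Entries) - 1
+ for i > 0 && node.Entries[i].Entry.FileBlock() > fileBlock {
+ i--
+ }
+ if node.Header.Height == 0 {
+ return node.Entries[i].Entry // Leaf node: this is an *Extent.
+ }
+ node = node.Entries[i].Node // Internal node: descend into the child.
+ }
+}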
+
+// ExtentHeader emulates the ext4_extent_header struct in ext4. Each extent
+// tree node begins with this and is followed by `NumEntries` entries of:
+// - Extent if `Height` == 0
+// - ExtentIdx otherwise
+type ExtentHeader struct {
+ // Magic is the extent magic number; must be ExtentMagic (0xf30a).
+ Magic uint16
+
+ // NumEntries indicates the number of valid entries following the header.
+ NumEntries uint16
+
+ // MaxEntries is the maximum number of entries that could follow the
+ // header. Used while adding entries.
+ MaxEntries uint16
+
+ // Height represents the distance of this node from the farthest leaf. Please
+ // note that Linux incorrectly calls this `Depth` (which means the distance
+ // of the node from the root).
+ Height uint16
+ _ uint32
+}
+
+// ExtentIdx emulates the ext4_extent_idx struct in ext4. Only present in
+// internal nodes. Sorted in ascending order based on FirstFileBlock since
+// Linux does a binary search on this. This points to a block containing the
+// child node.
+type ExtentIdx struct {
+ FirstFileBlock uint32
+ ChildBlockLo uint32
+ ChildBlockHi uint16
+ _ uint16
+}
+
+// Compiles only if ExtentIdx implements ExtentEntry.
+var _ ExtentEntry = (*ExtentIdx)(nil)
+
+// FileBlock implements ExtentEntry.FileBlock.
+func (ei *ExtentIdx) FileBlock() uint32 {
+ return ei.FirstFileBlock
+}
+
+// PhysicalBlock implements ExtentEntry.PhysicalBlock. It returns the
+// physical block number of the child block.
+func (ei *ExtentIdx) PhysicalBlock() uint64 {
+ return (uint64(ei.ChildBlockHi) << 32) | uint64(ei.ChildBlockLo)
+}
+
+// Extent represents the ext4_extent struct in ext4. Only present in leaf
+// nodes. Sorted in ascending order based on FirstFileBlock since Linux does a
+// binary search on this. This points to an array of data blocks containing the
+// file data. It covers `Length` data blocks starting from `StartBlock`.
+type Extent struct {
+ FirstFileBlock uint32
+ Length uint16
+ StartBlockHi uint16
+ StartBlockLo uint32
+}
+
+// Compiles only if Extent implements ExtentEntry.
+var _ ExtentEntry = (*Extent)(nil)
+
+// FileBlock implements ExtentEntry.FileBlock.
+func (e *Extent) FileBlock() uint32 {
+ return e.FirstFileBlock
+}
+
+// PhysicalBlock implements ExtentEntry.PhysicalBlock. It returns the
+// physical block number of the first data block this extent covers.
+func (e *Extent) PhysicalBlock() uint64 {
+ return (uint64(e.StartBlockHi) << 32) | uint64(e.StartBlockLo)
+}
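+
+// Editor's note (assuming sb is the filesystem's SuperBlock): an extent e
+// maps file bytes starting at uint64(e.FileBlock())*sb.BlockSize() to disk
+// bytes starting at e.PhysicalBlock()*sb.BlockSize(), spanning
+// uint64(e.Length)*sb.BlockSize() bytes.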
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
new file mode 100644
index 000000000..8762b90db
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "testing"
+)
+
+// TestExtentSize tests that the extent structs are of the correct
+// size.
+func TestExtentSize(t *testing.T) {
+ assertSize(t, ExtentHeader{}, ExtentHeaderSize)
+ assertSize(t, ExtentIdx{}, ExtentEntrySize)
+ assertSize(t, Extent{}, ExtentEntrySize)
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go
new file mode 100644
index 000000000..88ae913f5
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go
@@ -0,0 +1,274 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// Special inodes. See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#special-inodes.
+const (
+ // RootDirInode is the inode number of the root directory inode.
+ RootDirInode = 2
+)
+
+// The Inode interface must be implemented by structs representing ext inodes.
+// The inode stores all the metadata pertaining to the file (except for the
+// file name which is held by the directory entry). It does NOT expose all
+// fields and should be extended if need be.
+//
+// Some file systems (e.g. FAT) use the directory entry to store all this
+// information. Ext file systems do not, so that they can support hard links.
+// However, ext4 cheats a little bit and duplicates the file type in the
+// directory entry for performance gains.
+//
+// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes.
+type Inode interface {
+ // Mode returns the linux file mode, which is mainly used to extract
+ // information such as:
+ // - File permissions (read/write/execute by user/group/others).
+ // - Sticky, set UID and GID bits.
+ // - File type.
+ //
+ // Masks to extract this information are provided in pkg/abi/linux/file.go.
+ Mode() linux.FileMode
+
+ // UID returns the owner UID.
+ UID() auth.KUID
+
+ // GID returns the owner GID.
+ GID() auth.KGID
+
+ // Size returns the size of the file in bytes.
+ Size() uint64
+
+ // InodeSize returns the size of this inode struct in bytes.
+ // In ext2 and ext3, the inode struct and the inode disk record were both
+ // fixed at 128 bytes. Ext4 makes it possible for the inode struct to be
+ // bigger. However, access to any field beyond the 128-byte mark must be verified
+ // using this method.
+ InodeSize() uint16
+
+ // AccessTime returns the last access time. Shows when the file was last read.
+ //
+ // If InExtendedAttr is set, then this should NOT be used because the
+ // underlying field is used to store the extended attribute value checksum.
+ AccessTime() time.Time
+
+ // ChangeTime returns the last change time. Shows when the file metadata
+ // (like permissions) was last changed.
+ //
+ // If InExtendedAttr is set, then this should NOT be used because the
+ // underlying field is used to store the lower 32 bits of the attribute
+ // value’s reference count.
+ ChangeTime() time.Time
+
+ // ModificationTime returns the last modification time. Shows when the file
+ // content was last modified.
+ //
+ // If InExtendedAttr is set, then this should NOT be used because
+ // the underlying field contains the number of the inode that owns the
+ // extended attribute.
+ ModificationTime() time.Time
+
+ // DeletionTime returns the deletion time. Inodes are marked as deleted by
+ // writing to the underlying field. FS tools can restore files until they are
+ // actually overwritten.
+ DeletionTime() time.Time
+
+ // LinksCount returns the number of hard links to this inode.
+ //
+ // Normally there is an upper limit on the number of hard links:
+ // - ext2/ext3 = 32,000
+ // - ext4 = 65,000
+ //
+ // This implies that an ext4 directory cannot have more than 64,998
+ // subdirectories because each subdirectory will have a hard link to the
+ // directory via its `..` entry, the directory has a hard link via its own
+ // `.` entry, and the inode starts out with 1 hard link (itself).
+ //
+ // The underlying value is reset to 1 if all the following hold:
+ // - Inode is a directory.
+ // - SbDirNlink is enabled.
+ // - Number of hard links is incremented past 64,999.
+ // A hard link count of 1 for a directory indicates that the number of hard
+ // links is unknown, because a directory has a minimum of 2 hard links (itself
+ // and its `.` entry).
+ LinksCount() uint16
+
+ // Flags returns InodeFlags which represents the inode flags.
+ Flags() InodeFlags
+
+ // Data returns the underlying inode.i_block array as a slice so it's
+ // modifiable. This field is special and is used to store various kinds of
+ // things depending on the filesystem version and inode type. The underlying
+ // field name in Linux is a little misleading.
+ // - In ext2/ext3, it contains the block map.
+ // - In ext4, it contains the extent tree root node.
+ // - For inline files, it contains the file contents.
+ // - For symlinks, it contains the link path (if it fits here).
+ //
+ // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#the-contents-of-inode-i-block.
+ Data() []byte
+}
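+
+// Editor's note: the file type, for instance, is extracted elsewhere in this
+// patch as fs.ToInodeType(inode.Mode().FileType()), using the mode masks
+// mentioned above.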
+
+// Inode flags. This is not comprehensive and flags which were not used in
+// the Linux kernel have been excluded.
+const (
+ // InSync indicates that all writes to the file must be synchronous.
+ InSync = 0x8
+
+ // InImmutable indicates that this file is immutable.
+ InImmutable = 0x10
+
+ // InAppend indicates that this file can only be appended to.
+ InAppend = 0x20
+
+ // InNoDump indicates that the dump(1) utility should not dump this file.
+ InNoDump = 0x40
+
+ // InNoAccessTime indicates that the access time of this inode must not be
+ // updated.
+ InNoAccessTime = 0x80
+
+ // InIndex indicates that this directory has hashed indexes.
+ InIndex = 0x1000
+
+ // InJournalData indicates that file data must always be written through a
+ // journal device.
+ InJournalData = 0x4000
+
+ // InDirSync indicates that all the directory entry data must be written
+ // synchronously.
+ InDirSync = 0x10000
+
+ // InTopDir indicates that this inode is at the top of the directory hierarchy.
+ InTopDir = 0x20000
+
+ // InHugeFile indicates that this is a huge file.
+ InHugeFile = 0x40000
+
+ // InExtents indicates that this inode uses extents.
+ InExtents = 0x80000
+
+ // InExtendedAttr indicates that this inode stores a large extended attribute
+ // value in its data blocks.
+ InExtendedAttr = 0x200000
+
+ // InInline indicates that this inode has inline data.
+ InInline = 0x10000000
+
+ // InReserved indicates that this inode is reserved for the ext4 library.
+ InReserved = 0x80000000
+)
+
+// InodeFlags represents all possible combinations of inode flags. It aims to
+// cover the bit masks and provide a more user-friendly interface.
+type InodeFlags struct {
+ Sync bool
+ Immutable bool
+ Append bool
+ NoDump bool
+ NoAccessTime bool
+ Index bool
+ JournalData bool
+ DirSync bool
+ TopDir bool
+ HugeFile bool
+ Extents bool
+ ExtendedAttr bool
+ Inline bool
+ Reserved bool
+}
+
+// ToInt converts inode flags back to their 32-bit representation.
+func (f InodeFlags) ToInt() uint32 {
+ var res uint32
+
+ if f.Sync {
+ res |= InSync
+ }
+ if f.Immutable {
+ res |= InImmutable
+ }
+ if f.Append {
+ res |= InAppend
+ }
+ if f.NoDump {
+ res |= InNoDump
+ }
+ if f.NoAccessTime {
+ res |= InNoAccessTime
+ }
+ if f.Index {
+ res |= InIndex
+ }
+ if f.JournalData {
+ res |= InJournalData
+ }
+ if f.DirSync {
+ res |= InDirSync
+ }
+ if f.TopDir {
+ res |= InTopDir
+ }
+ if f.HugeFile {
+ res |= InHugeFile
+ }
+ if f.Extents {
+ res |= InExtents
+ }
+ if f.ExtendedAttr {
+ res |= InExtendedAttr
+ }
+ if f.Inline {
+ res |= InInline
+ }
+ if f.Reserved {
+ res |= InReserved
+ }
+
+ return res
+}
+
+// InodeFlagsFromInt converts the integer representation of inode flags to
+// an InodeFlags struct.
+func InodeFlagsFromInt(f uint32) InodeFlags {
+ return InodeFlags{
+ Sync: f&InSync > 0,
+ Immutable: f&InImmutable > 0,
+ Append: f&InAppend > 0,
+ NoDump: f&InNoDump > 0,
+ NoAccessTime: f&InNoAccessTime > 0,
+ Index: f&InIndex > 0,
+ JournalData: f&InJournalData > 0,
+ DirSync: f&InDirSync > 0,
+ TopDir: f&InTopDir > 0,
+ HugeFile: f&InHugeFile > 0,
+ Extents: f&InExtents > 0,
+ ExtendedAttr: f&InExtendedAttr > 0,
+ Inline: f&InInline > 0,
+ Reserved: f&InReserved > 0,
+ }
+}
+
+// These masks define how users can view/modify inode flags. The rest of the
+// flags are for internal kernel usage only.
+const (
+ InUserReadFlagMask = 0x4BDFFF
+ InUserWriteFlagMask = 0x4B80FF
+)
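+
+// Editor's sketch (hypothetical usage, not part of the original patch): an
+// FS_IOC_SETFLAGS-style update would mask user input so that kernel-internal
+// bits are preserved.
+//
+//	newRaw := (oldRaw &^ InUserWriteFlagMask) | (userFlags & InUserWriteFlagMask)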
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
new file mode 100644
index 000000000..8f9f574ce
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
@@ -0,0 +1,96 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+
+// InodeNew represents the ext4 inode structure, which can be bigger than
+// OldInodeSize. The actual size of this struct should be determined using
+// inode.ExtraInodeSize. Accessing any field here should be verified with the
+// actual size. The extra space between the end of the inode struct and end of
+// the inode record can be used to store extended attributes.
+//
+// If the TimeExtra fields are in scope, the lower 2 bits of those are used
+// to extend their counterpart to be 34 bits wide; the remaining (upper) 30 bits
+// are used to provide nanosecond precision. Hence, these timestamps will now
+// overflow in May 2446.
+// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps.
+type InodeNew struct {
+ InodeOld
+
+ ExtraInodeSize uint16
+ ChecksumHi uint16
+ ChangeTimeExtra uint32
+ ModificationTimeExtra uint32
+ AccessTimeExtra uint32
+ CreationTime uint32
+ CreationTimeExtra uint32
+ VersionHi uint32
+ ProjectID uint32
+}
+
+// Compiles only if InodeNew implements Inode.
+var _ Inode = (*InodeNew)(nil)
+
+// fromExtraTime decodes the extra time and constructs the kernel time struct
+// with nanosecond precision.
+func fromExtraTime(lo int32, extra uint32) time.Time {
+ // See description above InodeNew for format.
+ seconds := (int64(extra&0x3) << 32) + int64(lo)
+ nanoseconds := int64(extra >> 2)
+ return time.FromUnix(seconds, nanoseconds)
+}
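+
+// Editor's note: for example, extra bits 0b01 combined with a raw 32-bit time
+// of -1 (all ones) decode to time.FromUnix(0xffffffff, 0), i.e. 2106-02-07;
+// see the table-driven test in inode_test.go.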
+
+// Only override methods which change due to ext4 specific fields.
+
+// Size implements Inode.Size.
+func (in *InodeNew) Size() uint64 {
+ return (uint64(in.SizeHi) << 32) | uint64(in.SizeLo)
+}
+
+// InodeSize implements Inode.InodeSize.
+func (in *InodeNew) InodeSize() uint16 {
+ return OldInodeSize + in.ExtraInodeSize
+}
+
+// ChangeTime implements Inode.ChangeTime.
+func (in *InodeNew) ChangeTime() time.Time {
+ // Apply new timestamp logic if inode.ChangeTimeExtra is in scope.
+ if in.ExtraInodeSize >= 8 {
+ return fromExtraTime(in.ChangeTimeRaw, in.ChangeTimeExtra)
+ }
+
+ return in.InodeOld.ChangeTime()
+}
+
+// ModificationTime implements Inode.ModificationTime.
+func (in *InodeNew) ModificationTime() time.Time {
+ // Apply new timestamp logic if inode.ModificationTimeExtra is in scope.
+ if in.ExtraInodeSize >= 12 {
+ return fromExtraTime(in.ModificationTimeRaw, in.ModificationTimeExtra)
+ }
+
+ return in.InodeOld.ModificationTime()
+}
+
+// AccessTime implements Inode.AccessTime.
+func (in *InodeNew) AccessTime() time.Time {
+ // Apply new timestamp logic if inode.AccessTimeExtra is in scope.
+ if in.ExtraInodeSize >= 16 {
+ return fromExtraTime(in.AccessTimeRaw, in.AccessTimeExtra)
+ }
+
+ return in.InodeOld.AccessTime()
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
new file mode 100644
index 000000000..db25b11b6
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
@@ -0,0 +1,117 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+const (
+ // OldInodeSize is the inode size in ext2/ext3.
+ OldInodeSize = 128
+)
+
+// InodeOld implements the Inode interface. It emulates the ext2/ext3 inode
+// struct.
+// Inode struct size and record size are both 128 bytes for this.
+//
+// All fields representing time are in seconds since the epoch, which means
+// that they will overflow in January 2038.
+type InodeOld struct {
+ ModeRaw uint16
+ UIDLo uint16
+ SizeLo uint32
+
+ // The time fields are signed integers because they could be negative to
+ // represent time before the epoch.
+ AccessTimeRaw int32
+ ChangeTimeRaw int32
+ ModificationTimeRaw int32
+ DeletionTimeRaw int32
+
+ GIDLo uint16
+ LinksCountRaw uint16
+ BlocksCountLo uint32
+ FlagsRaw uint32
+ VersionLo uint32 // This is OS dependent.
+ DataRaw [60]byte
+ Generation uint32
+ FileACLLo uint32
+ SizeHi uint32
+ ObsoFaddr uint32
+
+ // OS dependent fields have been inlined here.
+ BlocksCountHi uint16
+ FileACLHi uint16
+ UIDHi uint16
+ GIDHi uint16
+ ChecksumLo uint16
+ _ uint16
+}
+
+// Compiles only if InodeOld implements Inode.
+var _ Inode = (*InodeOld)(nil)
+
+// Mode implements Inode.Mode.
+func (in *InodeOld) Mode() linux.FileMode { return linux.FileMode(in.ModeRaw) }
+
+// UID implements Inode.UID.
+func (in *InodeOld) UID() auth.KUID {
+ return auth.KUID((uint32(in.UIDHi) << 16) | uint32(in.UIDLo))
+}
+
+// GID implements Inode.GID.
+func (in *InodeOld) GID() auth.KGID {
+ return auth.KGID((uint32(in.GIDHi) << 16) | uint32(in.GIDLo))
+}
+
+// Size implements Inode.Size.
+func (in *InodeOld) Size() uint64 {
+ // In ext2/ext3, in.SizeHi did not exist; it was instead named in.DirACL.
+ return uint64(in.SizeLo)
+}
+
+// InodeSize implements Inode.InodeSize.
+func (in *InodeOld) InodeSize() uint16 { return OldInodeSize }
+
+// AccessTime implements Inode.AccessTime.
+func (in *InodeOld) AccessTime() time.Time {
+ return time.FromUnix(int64(in.AccessTimeRaw), 0)
+}
+
+// ChangeTime implements Inode.ChangeTime.
+func (in *InodeOld) ChangeTime() time.Time {
+ return time.FromUnix(int64(in.ChangeTimeRaw), 0)
+}
+
+// ModificationTime implements Inode.ModificationTime.
+func (in *InodeOld) ModificationTime() time.Time {
+ return time.FromUnix(int64(in.ModificationTimeRaw), 0)
+}
+
+// DeletionTime implements Inode.DeletionTime.
+func (in *InodeOld) DeletionTime() time.Time {
+ return time.FromUnix(int64(in.DeletionTimeRaw), 0)
+}
+
+// LinksCount implements Inode.LinksCount.
+func (in *InodeOld) LinksCount() uint16 { return in.LinksCountRaw }
+
+// Flags implements Inode.Flags.
+func (in *InodeOld) Flags() InodeFlags { return InodeFlagsFromInt(in.FlagsRaw) }
+
+// Data implements Inode.Data.
+func (in *InodeOld) Data() []byte { return in.DataRaw[:] }
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
new file mode 100644
index 000000000..dd03ee50e
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
@@ -0,0 +1,222 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "fmt"
+ "strconv"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// TestInodeSize tests that the inode structs are of the correct size.
+func TestInodeSize(t *testing.T) {
+ assertSize(t, InodeOld{}, OldInodeSize)
+
+ // This was updated from 156 bytes to 160 bytes in Oct 2015.
+ assertSize(t, InodeNew{}, 160)
+}
+
+// TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in
+// ext4 inode structs are decoded correctly.
+//
+// These tests are derived from the table under https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps.
+func TestTimestampSeconds(t *testing.T) {
+ type timestampTest struct {
+ // msbSet tells if the most significant bit of InodeOld.[X]TimeRaw is set.
+ // If this is set then the 32-bit time is negative.
+ msbSet bool
+
+ // lowerBound tells if we should take the lowest possible value of
+ // InodeOld.[X]TimeRaw while satisfying test.msbSet condition. If set to
+ // false it tells to take the highest possible value.
+ lowerBound bool
+
+ // extraBits is InodeNew.[X]TimeExtra.
+ extraBits uint32
+
+ // want is the kernel time struct that is expected.
+ want time.Time
+ }
+
+ tests := []timestampTest{
+ // 1901-12-13
+ {
+ msbSet: true,
+ lowerBound: true,
+ extraBits: 0,
+ want: time.FromUnix(int64(-0x80000000), 0),
+ },
+
+ // 1969-12-31
+ {
+ msbSet: true,
+ lowerBound: false,
+ extraBits: 0,
+ want: time.FromUnix(int64(-1), 0),
+ },
+
+ // 1970-01-01
+ {
+ msbSet: false,
+ lowerBound: true,
+ extraBits: 0,
+ want: time.FromUnix(int64(0), 0),
+ },
+
+ // 2038-01-19
+ {
+ msbSet: false,
+ lowerBound: false,
+ extraBits: 0,
+ want: time.FromUnix(int64(0x7fffffff), 0),
+ },
+
+ // 2038-01-19
+ {
+ msbSet: true,
+ lowerBound: true,
+ extraBits: 1,
+ want: time.FromUnix(int64(0x80000000), 0),
+ },
+
+ // 2106-02-07
+ {
+ msbSet: true,
+ lowerBound: false,
+ extraBits: 1,
+ want: time.FromUnix(int64(0xffffffff), 0),
+ },
+
+ // 2106-02-07
+ {
+ msbSet: false,
+ lowerBound: true,
+ extraBits: 1,
+ want: time.FromUnix(int64(0x100000000), 0),
+ },
+
+ // 2174-02-25
+ {
+ msbSet: false,
+ lowerBound: false,
+ extraBits: 1,
+ want: time.FromUnix(int64(0x17fffffff), 0),
+ },
+
+ // 2174-02-25
+ {
+ msbSet: true,
+ lowerBound: true,
+ extraBits: 2,
+ want: time.FromUnix(int64(0x180000000), 0),
+ },
+
+ // 2242-03-16
+ {
+ msbSet: true,
+ lowerBound: false,
+ extraBits: 2,
+ want: time.FromUnix(int64(0x1ffffffff), 0),
+ },
+
+ // 2242-03-16
+ {
+ msbSet: false,
+ lowerBound: true,
+ extraBits: 2,
+ want: time.FromUnix(int64(0x200000000), 0),
+ },
+
+ // 2310-04-04
+ {
+ msbSet: false,
+ lowerBound: false,
+ extraBits: 2,
+ want: time.FromUnix(int64(0x27fffffff), 0),
+ },
+
+ // 2310-04-04
+ {
+ msbSet: true,
+ lowerBound: true,
+ extraBits: 3,
+ want: time.FromUnix(int64(0x280000000), 0),
+ },
+
+ // 2378-04-22
+ {
+ msbSet: true,
+ lowerBound: false,
+ extraBits: 3,
+ want: time.FromUnix(int64(0x2ffffffff), 0),
+ },
+
+ // 2378-04-22
+ {
+ msbSet: false,
+ lowerBound: true,
+ extraBits: 3,
+ want: time.FromUnix(int64(0x300000000), 0),
+ },
+
+ // 2446-05-10
+ {
+ msbSet: false,
+ lowerBound: false,
+ extraBits: 3,
+ want: time.FromUnix(int64(0x37fffffff), 0),
+ },
+ }
+
+ lowerMSB0 := int32(0) // binary: 00000000 00000000 00000000 00000000
+ upperMSB0 := int32(0x7fffffff) // binary: 01111111 11111111 11111111 11111111
+ lowerMSB1 := int32(-0x80000000) // binary: 10000000 00000000 00000000 00000000
+ upperMSB1 := int32(-1) // binary: 11111111 11111111 11111111 11111111
+
+ get32BitTime := func(test timestampTest) int32 {
+ if test.msbSet {
+ if test.lowerBound {
+ return lowerMSB1
+ }
+
+ return upperMSB1
+ }
+
+ if test.lowerBound {
+ return lowerMSB0
+ }
+
+ return upperMSB0
+ }
+
+ getTestName := func(test timestampTest) string {
+ return fmt.Sprintf(
+ "Tests time decoding with epoch bits 0b%s and 32-bit raw time: MSB set=%t, lower bound=%t",
+ strconv.FormatInt(int64(test.extraBits), 2),
+ test.msbSet,
+ test.lowerBound,
+ )
+ }
+
+ for _, test := range tests {
+ t.Run(getTestName(test), func(t *testing.T) {
+ if got := fromExtraTime(get32BitTime(test), test.extraBits); got != test.want {
+ t.Errorf("Expected: %v, Got: %v", test.want, got)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
new file mode 100644
index 000000000..8bb327006
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
@@ -0,0 +1,471 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+const (
+ // SbOffset is the absolute offset at which the superblock is placed.
+ SbOffset = 1024
+)
+
+// SuperBlock should be implemented by structs representing the ext superblock.
+// The superblock holds a lot of information about the enclosing filesystem.
+// This interface aims to provide access methods to important information held
+// by the superblock. It does NOT expose all fields of the superblock, only the
+// ones necessary. This can be expanded as needed.
+//
+// Location and replication:
+// - The superblock is located at offset 1024 in block group 0.
+// - Redundant copies of the superblock and group descriptors are kept in
+// all groups if SbSparse feature flag is NOT set. If it is set, the
+// replicas only exist in groups whose group number is either 0 or a
+// power of 3, 5, or 7.
+// - There is also a sparse superblock feature v2 in which there are just
+// two replicas saved in the block groups pointed to by sb.s_backup_bgs.
+//
+// Replicas should eventually be updated if the superblock is updated.
+//
+// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block.
+type SuperBlock interface {
+ // InodesCount returns the total number of inodes in this filesystem.
+ InodesCount() uint32
+
+ // BlocksCount returns the total number of data blocks in this filesystem.
+ BlocksCount() uint64
+
+ // FreeBlocksCount returns the number of free blocks in this filesystem.
+ FreeBlocksCount() uint64
+
+ // FreeInodesCount returns the number of free inodes in this filesystem.
+ FreeInodesCount() uint32
+
+ // MountCount returns the number of mounts since the last fsck.
+ MountCount() uint16
+
+ // MaxMountCount returns the number of mounts allowed beyond which a fsck is
+ // needed.
+ MaxMountCount() uint16
+
+ // FirstDataBlock returns the absolute block number of the first data block,
+ // which contains the super block itself.
+ //
+ // If the filesystem has 1kb data blocks then this should return 1. For all
+ // other configurations, this typically returns 0.
+ FirstDataBlock() uint32
+
+ // BlockSize returns the size of one data block in this filesystem.
+ // This can be calculated as 2^(10 + sb.s_log_block_size), which ensures that
+ // the smallest block size is 1kb.
+ BlockSize() uint64
+
+ // BlocksPerGroup returns the number of data blocks in a block group.
+ BlocksPerGroup() uint32
+
+ // ClusterSize returns the block cluster size, set by the administrator at
+ // mkfs time. This can be calculated as 2^(10 + sb.s_log_cluster_size), which
+ // ensures that the smallest cluster size is 1kb.
+ //
+ // sb.s_log_cluster_size must equal sb.s_log_block_size if bigalloc feature
+ // is NOT set and consequently BlockSize() = ClusterSize() in that case.
+ ClusterSize() uint64
+
+ // ClustersPerGroup returns:
+ // - number of clusters per group if bigalloc is enabled.
+ // - BlocksPerGroup() otherwise.
+ ClustersPerGroup() uint32
+
+ // InodeSize returns the size of the on-disk inode record in bytes. Use this
+ // to iterate over inode arrays on disk.
+ //
+ // In ext2 and ext3:
+ // - Each inode had a disk record of 128 bytes.
+ // - The inode struct size was fixed at 128 bytes.
+ //
+ // In ext4 it's possible to allocate larger on-disk inodes:
+ // - Inode disk record size = sb.s_inode_size (function return value).
+ // = 256 (default)
+ // - Inode struct size = 128 + inode.i_extra_isize.
+ // = 128 + 32 = 160 (default)
+ InodeSize() uint16
+
+ // InodesPerGroup returns the number of inodes in a block group.
+ InodesPerGroup() uint32
+
+ // BgDescSize returns the size of the block group descriptor struct.
+ //
+ // In ext2, ext3, ext4 (without 64-bit feature), the block group descriptor
+ // is only 32 bytes long.
+ // In ext4 with 64-bit feature, the block group descriptor expands to AT LEAST
+ // 64 bytes. It might be bigger than that.
+ BgDescSize() uint16
+
+ // CompatibleFeatures returns the CompatFeatures struct which holds all the
+ // compatible features this fs supports.
+ CompatibleFeatures() CompatFeatures
+
+ // IncompatibleFeatures returns the IncompatFeatures struct which holds all the
+ // incompatible features this fs supports.
+ IncompatibleFeatures() IncompatFeatures
+
+ // ReadOnlyCompatibleFeatures returns the RoCompatFeatures struct which holds all the
+ // readonly compatible features this fs supports.
+ ReadOnlyCompatibleFeatures() RoCompatFeatures
+
+ // Magic returns the magic signature which must be 0xef53.
+ Magic() uint16
+
+ // Revision returns the superblock revision. Superblock struct fields from
+ // offset 0x54 till 0x150 should only be used if superblock has DynamicRev.
+ Revision() SbRevision
+}
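+
+// As a reference for the BlockSize and ClusterSize rules above, the
+// 2^(10 + log) computation can be sketched as follows (sizeFromLog is a
+// hypothetical helper, not part of this package):
+//
+// func sizeFromLog(logSize uint32) uint64 {
+// 	return uint64(1024) << logSize // log=0 => 1kb, log=2 => 4kb.
+// }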
+
+// SbRevision is the type for superblock revisions.
+type SbRevision uint32
+
+// Super block revisions.
+const (
+ // OldRev is the good old (original) format.
+ OldRev SbRevision = 0
+
+ // DynamicRev is v2 format w/ dynamic inode sizes.
+ DynamicRev SbRevision = 1
+)
+
+// Superblock compatible features.
+// This is not exhaustive, unused features are not listed.
+const (
+ // SbDirPrealloc indicates directory preallocation.
+ SbDirPrealloc = 0x1
+
+ // SbHasJournal indicates the presence of a journal. jbd2 should only work
+ // with this being set.
+ SbHasJournal = 0x4
+
+ // SbExtAttr indicates extended attributes support.
+ SbExtAttr = 0x8
+
+ // SbResizeInode indicates that the fs has reserved GDT blocks (right after
+ // group descriptors) for fs expansion.
+ SbResizeInode = 0x10
+
+ // SbDirIndex indicates that the fs has directory indices.
+ SbDirIndex = 0x20
+
+ // SbSparseV2 stands for Sparse superblock version 2.
+ SbSparseV2 = 0x200
+)
+
+// CompatFeatures represents a superblock's compatible feature set. If the
+// kernel does not understand any of these features, it can still read/write
+// to this fs.
+type CompatFeatures struct {
+ DirPrealloc bool
+ HasJournal bool
+ ExtAttr bool
+ ResizeInode bool
+ DirIndex bool
+ SparseV2 bool
+}
+
+// ToInt converts superblock compatible features back to its 32-bit rep.
+func (f CompatFeatures) ToInt() uint32 {
+ var res uint32
+
+ if f.DirPrealloc {
+ res |= SbDirPrealloc
+ }
+ if f.HasJournal {
+ res |= SbHasJournal
+ }
+ if f.ExtAttr {
+ res |= SbExtAttr
+ }
+ if f.ResizeInode {
+ res |= SbResizeInode
+ }
+ if f.DirIndex {
+ res |= SbDirIndex
+ }
+ if f.SparseV2 {
+ res |= SbSparseV2
+ }
+
+ return res
+}
+
+// CompatFeaturesFromInt converts the integer representation of superblock
+// compatible features to CompatFeatures struct.
+func CompatFeaturesFromInt(f uint32) CompatFeatures {
+ return CompatFeatures{
+ DirPrealloc: f&SbDirPrealloc > 0,
+ HasJournal: f&SbHasJournal > 0,
+ ExtAttr: f&SbExtAttr > 0,
+ ResizeInode: f&SbResizeInode > 0,
+ DirIndex: f&SbDirIndex > 0,
+ SparseV2: f&SbSparseV2 > 0,
+ }
+}
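+
+// As a usage sketch, CompatFeaturesFromInt and ToInt round-trip for any value
+// composed solely of the flags defined above (illustrative values):
+//
+// raw := uint32(SbHasJournal | SbDirIndex)
+// feats := CompatFeaturesFromInt(raw) // feats.HasJournal, feats.DirIndex are true.
+// same := feats.ToInt() == raw        // true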
+
+// Superblock incompatible features.
+// This is not exhaustive, unused features are not listed.
+const (
+ // SbDirentFileType indicates that directory entries record the file type.
+ // Dirents should then be parsed using the DirentNew struct.
+ SbDirentFileType = 0x2
+
+ // SbRecovery indicates that the filesystem needs recovery.
+ SbRecovery = 0x4
+
+ // SbJournalDev indicates that the filesystem has a separate journal device.
+ SbJournalDev = 0x8
+
+ // SbMetaBG indicates that the filesystem is using Meta block groups. This
+ // moves the group descriptors from the congested first block group into the
+ // first group of each metablock group to increase the maximum block group
+ // limit and hence support much larger filesystems.
+ //
+ // See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#meta-block-groups.
+ SbMetaBG = 0x10
+
+ // SbExtents indicates that the filesystem uses extents. Must be set in ext4
+ // filesystems.
+ SbExtents = 0x40
+
+ // SbIs64Bit indicates that this filesystem addresses blocks with 64 bits,
+ // and hence can support 2^64 data blocks.
+ SbIs64Bit = 0x80
+
+ // SbMMP indicates that this filesystem has multiple mount protection.
+ //
+ // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#multiple-mount-protection.
+ SbMMP = 0x100
+
+ // SbFlexBg indicates that this filesystem has flexible block groups. Several
+ // block groups are tied into one logical block group so that all the metadata
+ // for the block groups (bitmaps and inode tables) are close together for
+ // faster loading. Consequently, large files will be contiguous on disk.
+ // However, this does not affect the placement of redundant superblocks and
+ // group descriptors.
+ //
+ // See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#flexible-block-groups.
+ SbFlexBg = 0x200
+
+ // SbLargeDir indicates that the large directory feature is enabled. A
+ // directory htree can then be 3 levels deep. Directory htrees are allowed to
+ // be 2 levels deep otherwise.
+ SbLargeDir = 0x4000
+
+ // SbInlineData allows inline data in inodes for really small files.
+ SbInlineData = 0x8000
+
+ // SbEncrypted indicates that this fs contains encrypted inodes.
+ SbEncrypted = 0x10000
+)
+
+// IncompatFeatures represents a superblock's incompatible feature set. If the
+// kernel does not understand any of these features, it should refuse to mount.
+type IncompatFeatures struct {
+ DirentFileType bool
+ Recovery bool
+ JournalDev bool
+ MetaBG bool
+ Extents bool
+ Is64Bit bool
+ MMP bool
+ FlexBg bool
+ LargeDir bool
+ InlineData bool
+ Encrypted bool
+}
+
+// ToInt converts superblock incompatible features back to its 32-bit rep.
+func (f IncompatFeatures) ToInt() uint32 {
+ var res uint32
+
+ if f.DirentFileType {
+ res |= SbDirentFileType
+ }
+ if f.Recovery {
+ res |= SbRecovery
+ }
+ if f.JournalDev {
+ res |= SbJournalDev
+ }
+ if f.MetaBG {
+ res |= SbMetaBG
+ }
+ if f.Extents {
+ res |= SbExtents
+ }
+ if f.Is64Bit {
+ res |= SbIs64Bit
+ }
+ if f.MMP {
+ res |= SbMMP
+ }
+ if f.FlexBg {
+ res |= SbFlexBg
+ }
+ if f.LargeDir {
+ res |= SbLargeDir
+ }
+ if f.InlineData {
+ res |= SbInlineData
+ }
+ if f.Encrypted {
+ res |= SbEncrypted
+ }
+
+ return res
+}
+
+// IncompatFeaturesFromInt converts the integer representation of superblock
+// incompatible features to IncompatFeatures struct.
+func IncompatFeaturesFromInt(f uint32) IncompatFeatures {
+ return IncompatFeatures{
+ DirentFileType: f&SbDirentFileType > 0,
+ Recovery: f&SbRecovery > 0,
+ JournalDev: f&SbJournalDev > 0,
+ MetaBG: f&SbMetaBG > 0,
+ Extents: f&SbExtents > 0,
+ Is64Bit: f&SbIs64Bit > 0,
+ MMP: f&SbMMP > 0,
+ FlexBg: f&SbFlexBg > 0,
+ LargeDir: f&SbLargeDir > 0,
+ InlineData: f&SbInlineData > 0,
+ Encrypted: f&SbEncrypted > 0,
+ }
+}
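+
+// A mounter can use the constants above to reject unknown incompatible
+// features with a mask check, sketched below (knownIncompat is a hypothetical
+// mask, not defined by this package):
+//
+// const knownIncompat = SbDirentFileType | SbExtents | SbIs64Bit | SbFlexBg
+//
+// if raw&^knownIncompat != 0 {
+// 	// Unknown incompatible feature; refuse to mount.
+// }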
+
+// Superblock readonly compatible features.
+// This is not exhaustive, unused features are not listed.
+const (
+ // SbSparse indicates sparse superblocks. Only groups with number either 0 or
+ // a power of 3, 5, or 7 will have redundant copies of the superblock and
+ // block descriptors.
+ SbSparse = 0x1
+
+ // SbLargeFile indicates that this fs has been used to store a file >= 2GiB.
+ SbLargeFile = 0x2
+
+ // SbHugeFile indicates that this fs contains files whose sizes are
+ // represented in units of logical blocks, not 512-byte sectors.
+ SbHugeFile = 0x8
+
+ // SbGdtCsum indicates that group descriptors have checksums.
+ SbGdtCsum = 0x10
+
+ // SbDirNlink indicates that the new subdirectory limit is 64,999. Ext3 has a
+ // 32,000 subdirectory limit.
+ SbDirNlink = 0x20
+
+ // SbExtraIsize indicates that large inodes exist on this filesystem.
+ SbExtraIsize = 0x40
+
+ // SbHasSnapshot indicates the existence of a snapshot.
+ SbHasSnapshot = 0x80
+
+ // SbQuota enables usage tracking for all quota types.
+ SbQuota = 0x100
+
+ // SbBigalloc maps to the bigalloc feature. When set, the minimum allocation
+ // unit becomes a cluster rather than a data block. Then block bitmaps track
+ // clusters, not data blocks.
+ //
+ // See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#bigalloc.
+ SbBigalloc = 0x200
+
+ // SbMetadataCsum indicates that the fs supports metadata checksumming.
+ SbMetadataCsum = 0x400
+
+ // SbReadOnly marks this filesystem as readonly. The kernel should refuse to
+ // mount it in read/write mode.
+ SbReadOnly = 0x1000
+)
+
+// RoCompatFeatures represents a superblock's readonly compatible feature set.
+// If the kernel does not understand any of these features, it can still mount
+// readonly. But if the user wants to mount read/write, the kernel should
+// refuse to mount.
+type RoCompatFeatures struct {
+ Sparse bool
+ LargeFile bool
+ HugeFile bool
+ GdtCsum bool
+ DirNlink bool
+ ExtraIsize bool
+ HasSnapshot bool
+ Quota bool
+ Bigalloc bool
+ MetadataCsum bool
+ ReadOnly bool
+}
+
+// ToInt converts superblock readonly compatible features to its 32-bit rep.
+func (f RoCompatFeatures) ToInt() uint32 {
+ var res uint32
+
+ if f.Sparse {
+ res |= SbSparse
+ }
+ if f.LargeFile {
+ res |= SbLargeFile
+ }
+ if f.HugeFile {
+ res |= SbHugeFile
+ }
+ if f.GdtCsum {
+ res |= SbGdtCsum
+ }
+ if f.DirNlink {
+ res |= SbDirNlink
+ }
+ if f.ExtraIsize {
+ res |= SbExtraIsize
+ }
+ if f.HasSnapshot {
+ res |= SbHasSnapshot
+ }
+ if f.Quota {
+ res |= SbQuota
+ }
+ if f.Bigalloc {
+ res |= SbBigalloc
+ }
+ if f.MetadataCsum {
+ res |= SbMetadataCsum
+ }
+ if f.ReadOnly {
+ res |= SbReadOnly
+ }
+
+ return res
+}
+
+// RoCompatFeaturesFromInt converts the integer representation of superblock
+// readonly compatible features to RoCompatFeatures struct.
+func RoCompatFeaturesFromInt(f uint32) RoCompatFeatures {
+ return RoCompatFeatures{
+ Sparse: f&SbSparse > 0,
+ LargeFile: f&SbLargeFile > 0,
+ HugeFile: f&SbHugeFile > 0,
+ GdtCsum: f&SbGdtCsum > 0,
+ DirNlink: f&SbDirNlink > 0,
+ ExtraIsize: f&SbExtraIsize > 0,
+ HasSnapshot: f&SbHasSnapshot > 0,
+ Quota: f&SbQuota > 0,
+ Bigalloc: f&SbBigalloc > 0,
+ MetadataCsum: f&SbMetadataCsum > 0,
+ ReadOnly: f&SbReadOnly > 0,
+ }
+}
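+
+// For reference, the SbSparse placement rule noted above can be expressed as
+// follows (an illustrative sketch; hasSuperBlockCopy is hypothetical):
+//
+// func hasSuperBlockCopy(group uint64) bool {
+// 	if group == 0 {
+// 		return true
+// 	}
+// 	for _, base := range []uint64{3, 5, 7} {
+// 		for n := uint64(1); n <= group; n *= base {
+// 			if n == group {
+// 				return true
+// 			}
+// 		}
+// 	}
+// 	return false
+// }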
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
new file mode 100644
index 000000000..53e515fd3
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
@@ -0,0 +1,76 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of
+// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if
+// RevLevel = DynamicRev and 64-bit feature is disabled.
+type SuperBlock32Bit struct {
+ // We embed the old superblock struct here because the 32-bit version is just
+ // an extension of the old version.
+ SuperBlockOld
+
+ FirstInode uint32
+ InodeSizeRaw uint16
+ BlockGroupNumber uint16
+ FeatureCompat uint32
+ FeatureIncompat uint32
+ FeatureRoCompat uint32
+ UUID [16]byte
+ VolumeName [16]byte
+ LastMounted [64]byte
+ AlgoUsageBitmap uint32
+ PreallocBlocks uint8
+ PreallocDirBlocks uint8
+ ReservedGdtBlocks uint16
+ JournalUUID [16]byte
+ JournalInum uint32
+ JournalDev uint32
+ LastOrphan uint32
+ HashSeed [4]uint32
+ DefaultHashVersion uint8
+ JnlBackupType uint8
+ BgDescSizeRaw uint16
+ DefaultMountOpts uint32
+ FirstMetaBg uint32
+ MkfsTime uint32
+ JnlBlocks [17]uint32
+}
+
+// Compiles only if SuperBlock32Bit implements SuperBlock.
+var _ SuperBlock = (*SuperBlock32Bit)(nil)
+
+// Only override methods which change based on the additional fields above.
+// Not overriding SuperBlock.BgDescSize because it would still return 32 here.
+
+// InodeSize implements SuperBlock.InodeSize.
+func (sb *SuperBlock32Bit) InodeSize() uint16 {
+ return sb.InodeSizeRaw
+}
+
+// CompatibleFeatures implements SuperBlock.CompatibleFeatures.
+func (sb *SuperBlock32Bit) CompatibleFeatures() CompatFeatures {
+ return CompatFeaturesFromInt(sb.FeatureCompat)
+}
+
+// IncompatibleFeatures implements SuperBlock.IncompatibleFeatures.
+func (sb *SuperBlock32Bit) IncompatibleFeatures() IncompatFeatures {
+ return IncompatFeaturesFromInt(sb.FeatureIncompat)
+}
+
+// ReadOnlyCompatibleFeatures implements SuperBlock.ReadOnlyCompatibleFeatures.
+func (sb *SuperBlock32Bit) ReadOnlyCompatibleFeatures() RoCompatFeatures {
+ return RoCompatFeaturesFromInt(sb.FeatureRoCompat)
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
new file mode 100644
index 000000000..7c1053fb4
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
@@ -0,0 +1,95 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// SuperBlock64Bit implements SuperBlock and represents the 64-bit version of
+// the ext4_super_block struct in fs/ext4/ext4.h. This struct is exactly
+// 1024 bytes (the smallest possible block size) long, so the superblock
+// always fits in a single data block. Should only be used when the 64-bit
+// feature is set.
+type SuperBlock64Bit struct {
+ // We embed the 32-bit struct here because 64-bit version is just an extension
+ // of the 32-bit version.
+ SuperBlock32Bit
+
+ BlocksCountHi uint32
+ ReservedBlocksCountHi uint32
+ FreeBlocksCountHi uint32
+ MinInodeSize uint16
+ WantInodeSize uint16
+ Flags uint32
+ RaidStride uint16
+ MmpInterval uint16
+ MmpBlock uint64
+ RaidStripeWidth uint32
+ LogGroupsPerFlex uint8
+ ChecksumType uint8
+ _ uint16
+ KbytesWritten uint64
+ SnapshotInum uint32
+ SnapshotID uint32
+ SnapshotRsrvBlocksCount uint64
+ SnapshotList uint32
+ ErrorCount uint32
+ FirstErrorTime uint32
+ FirstErrorInode uint32
+ FirstErrorBlock uint64
+ FirstErrorFunction [32]byte
+ FirstErrorLine uint32
+ LastErrorTime uint32
+ LastErrorInode uint32
+ LastErrorLine uint32
+ LastErrorBlock uint64
+ LastErrorFunction [32]byte
+ MountOpts [64]byte
+ UserQuotaInum uint32
+ GroupQuotaInum uint32
+ OverheadBlocks uint32
+ BackupBgs [2]uint32
+ EncryptAlgos [4]uint8
+ EncryptPwSalt [16]uint8
+ LostFoundInode uint32
+ ProjectQuotaInode uint32
+ ChecksumSeed uint32
+ WtimeHi uint8
+ MtimeHi uint8
+ MkfsTimeHi uint8
+ LastCheckHi uint8
+ FirstErrorTimeHi uint8
+ LastErrorTimeHi uint8
+ _ [2]uint8
+ Encoding uint16
+ EncodingFlags uint16
+ _ [95]uint32
+ Checksum uint32
+}
+
+// Compiles only if SuperBlock64Bit implements SuperBlock.
+var _ SuperBlock = (*SuperBlock64Bit)(nil)
+
+// Only override methods which change based on the 64-bit feature.
+
+// BlocksCount implements SuperBlock.BlocksCount.
+func (sb *SuperBlock64Bit) BlocksCount() uint64 {
+ return (uint64(sb.BlocksCountHi) << 32) | uint64(sb.BlocksCountLo)
+}
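+
+// For example (illustrative values): BlocksCountHi = 0x1 and
+// BlocksCountLo = 0xffffff00 yield BlocksCount() == 0x1ffffff00.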
+
+// FreeBlocksCount implements SuperBlock.FreeBlocksCount.
+func (sb *SuperBlock64Bit) FreeBlocksCount() uint64 {
+ return (uint64(sb.FreeBlocksCountHi) << 32) | uint64(sb.FreeBlocksCountLo)
+}
+
+// BgDescSize implements SuperBlock.BgDescSize.
+func (sb *SuperBlock64Bit) BgDescSize() uint16 { return sb.BgDescSizeRaw }
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
new file mode 100644
index 000000000..9221e0251
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
@@ -0,0 +1,105 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+// SuperBlockOld implements SuperBlock and represents the old version of the
+// superblock struct. Should be used only if RevLevel = OldRev.
+type SuperBlockOld struct {
+ InodesCountRaw uint32
+ BlocksCountLo uint32
+ ReservedBlocksCount uint32
+ FreeBlocksCountLo uint32
+ FreeInodesCountRaw uint32
+ FirstDataBlockRaw uint32
+ LogBlockSize uint32
+ LogClusterSize uint32
+ BlocksPerGroupRaw uint32
+ ClustersPerGroupRaw uint32
+ InodesPerGroupRaw uint32
+ Mtime uint32
+ Wtime uint32
+ MountCountRaw uint16
+ MaxMountCountRaw uint16
+ MagicRaw uint16
+ State uint16
+ Errors uint16
+ MinorRevLevel uint16
+ LastCheck uint32
+ CheckInterval uint32
+ CreatorOS uint32
+ RevLevel uint32
+ DefResUID uint16
+ DefResGID uint16
+}
+
+// Compiles only if SuperBlockOld implements SuperBlock.
+var _ SuperBlock = (*SuperBlockOld)(nil)
+
+// InodesCount implements SuperBlock.InodesCount.
+func (sb *SuperBlockOld) InodesCount() uint32 { return sb.InodesCountRaw }
+
+// BlocksCount implements SuperBlock.BlocksCount.
+func (sb *SuperBlockOld) BlocksCount() uint64 { return uint64(sb.BlocksCountLo) }
+
+// FreeBlocksCount implements SuperBlock.FreeBlocksCount.
+func (sb *SuperBlockOld) FreeBlocksCount() uint64 { return uint64(sb.FreeBlocksCountLo) }
+
+// FreeInodesCount implements SuperBlock.FreeInodesCount.
+func (sb *SuperBlockOld) FreeInodesCount() uint32 { return sb.FreeInodesCountRaw }
+
+// MountCount implements SuperBlock.MountCount.
+func (sb *SuperBlockOld) MountCount() uint16 { return sb.MountCountRaw }
+
+// MaxMountCount implements SuperBlock.MaxMountCount.
+func (sb *SuperBlockOld) MaxMountCount() uint16 { return sb.MaxMountCountRaw }
+
+// FirstDataBlock implements SuperBlock.FirstDataBlock.
+func (sb *SuperBlockOld) FirstDataBlock() uint32 { return sb.FirstDataBlockRaw }
+
+// BlockSize implements SuperBlock.BlockSize.
+func (sb *SuperBlockOld) BlockSize() uint64 { return 1 << (10 + sb.LogBlockSize) }
+
+// BlocksPerGroup implements SuperBlock.BlocksPerGroup.
+func (sb *SuperBlockOld) BlocksPerGroup() uint32 { return sb.BlocksPerGroupRaw }
+
+// ClusterSize implements SuperBlock.ClusterSize.
+func (sb *SuperBlockOld) ClusterSize() uint64 { return 1 << (10 + sb.LogClusterSize) }
+
+// ClustersPerGroup implements SuperBlock.ClustersPerGroup.
+func (sb *SuperBlockOld) ClustersPerGroup() uint32 { return sb.ClustersPerGroupRaw }
+
+// InodeSize implements SuperBlock.InodeSize.
+func (sb *SuperBlockOld) InodeSize() uint16 { return OldInodeSize }
+
+// InodesPerGroup implements SuperBlock.InodesPerGroup.
+func (sb *SuperBlockOld) InodesPerGroup() uint32 { return sb.InodesPerGroupRaw }
+
+// BgDescSize implements SuperBlock.BgDescSize.
+func (sb *SuperBlockOld) BgDescSize() uint16 { return 32 }
+
+// CompatibleFeatures implements SuperBlock.CompatibleFeatures.
+func (sb *SuperBlockOld) CompatibleFeatures() CompatFeatures { return CompatFeatures{} }
+
+// IncompatibleFeatures implements SuperBlock.IncompatibleFeatures.
+func (sb *SuperBlockOld) IncompatibleFeatures() IncompatFeatures { return IncompatFeatures{} }
+
+// ReadOnlyCompatibleFeatures implements SuperBlock.ReadOnlyCompatibleFeatures.
+func (sb *SuperBlockOld) ReadOnlyCompatibleFeatures() RoCompatFeatures { return RoCompatFeatures{} }
+
+// Magic implements SuperBlock.Magic.
+func (sb *SuperBlockOld) Magic() uint16 { return sb.MagicRaw }
+
+// Revision implements SuperBlock.Revision.
+func (sb *SuperBlockOld) Revision() SbRevision { return SbRevision(sb.RevLevel) }
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
new file mode 100644
index 000000000..463b5ba21
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "testing"
+)
+
+// TestSuperBlockSize tests that the superblock structs are of the correct
+// size.
+func TestSuperBlockSize(t *testing.T) {
+ assertSize(t, SuperBlockOld{}, 84)
+ assertSize(t, SuperBlock32Bit{}, 336)
+ assertSize(t, SuperBlock64Bit{}, 1024)
+}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
new file mode 100644
index 000000000..9c63f04c0
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
@@ -0,0 +1,30 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package disklayout
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
+func assertSize(t *testing.T, v interface{}, want uintptr) {
+ t.Helper()
+
+ if got := binary.Size(v); got != want {
+ t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got)
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
new file mode 100644
index 000000000..dac6effbf
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -0,0 +1,157 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ext implements readonly ext(2/3/4) filesystems.
+package ext
+
+import (
+ "errors"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the name of this filesystem.
+const Name = "ext"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Compiles only if FilesystemType implements vfs.FilesystemType.
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// getDeviceFd returns an io.ReaderAt to the underlying device.
+// Currently there are two ways of mounting an ext(2/3/4) fs:
+// 1. Specify a mount with our internal special MountType in the OCI spec.
+// 2. Expose the device to the container and mount it from application layer.
+func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) {
+ if opts.InternalData == nil {
+ // User mount call.
+ // TODO(b/134676337): Open the device specified by `source` and return that.
+ panic("unimplemented")
+ }
+
+ // GetFilesystem call originated from within the sentry.
+ devFd, ok := opts.InternalData.(int)
+ if !ok {
+ return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device")
+ }
+
+ if devFd < 0 {
+ return nil, fmt.Errorf("ext device file descriptor is not valid: %d", devFd)
+ }
+
+ // The fd.ReadWriter returned from fd.NewReadWriter() does not take ownership
+ // of the file descriptor and hence will not close it when it is garbage
+ // collected.
+ return fd.NewReadWriter(devFd), nil
+}
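+
+// An illustrative sketch of the sentry-internal path, mirroring what the
+// tests in this package do (deviceFile is a hypothetical *os.File for the
+// underlying device):
+//
+// opts := vfs.GetFilesystemOptions{InternalData: int(deviceFile.Fd())}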
+
+// isCompatible checks if the superblock has feature sets which are compatible.
+// We only need to check the superblock incompatible feature set since we are
+// mounting readonly. We will also need to check the readonly compatible
+// feature set when mounting for read/write.
+func isCompatible(sb disklayout.SuperBlock) bool {
+ // Note that these checks are limited because we mount readonly and do not
+ // journal. They must be reevaluated when mounting read/write or with a
+ // journal.
+ incompatFeatures := sb.IncompatibleFeatures()
+ if incompatFeatures.MetaBG {
+ log.Warningf("ext fs: meta block groups are not supported")
+ return false
+ }
+ if incompatFeatures.MMP {
+ log.Warningf("ext fs: multiple mount protection is not supported")
+ return false
+ }
+ if incompatFeatures.Encrypted {
+ log.Warningf("ext fs: encrypted inodes not supported")
+ return false
+ }
+ if incompatFeatures.InlineData {
+ log.Warningf("ext fs: inline files not supported")
+ return false
+ }
+ return true
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ // TODO(b/134676337): Ensure that the user is mounting readonly. If not,
+ // EACCESS should be returned according to mount(2). Filesystem independent
+ // flags (like readonly) are currently not available in pkg/sentry/vfs.
+
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dev, err := getDeviceFd(source, opts)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fs := filesystem{
+ dev: dev,
+ inodeCache: make(map[uint32]*inode),
+ devMinor: devMinor,
+ }
+ fs.vfsfs.Init(vfsObj, &fsType, &fs)
+ fs.sb, err = readSuperBlock(dev)
+ if err != nil {
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+
+ if fs.sb.Magic() != linux.EXT_SUPER_MAGIC {
+ // mount(2) specifies that EINVAL should be returned if the superblock is
+ // invalid.
+ fs.vfsfs.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Refuse to mount if the filesystem is incompatible.
+ if !isCompatible(fs.sb) {
+ fs.vfsfs.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+
+ fs.bgs, err = readBlockGroups(dev, fs.sb)
+ if err != nil {
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+
+ rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode)
+ if err != nil {
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+ rootInode.incRef()
+
+ return &fs.vfsfs, &newDentry(rootInode).vfsd, nil
+}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
new file mode 100644
index 000000000..64e9a579f
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -0,0 +1,921 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "sort"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ assetsDir = "pkg/sentry/fsimpl/ext/assets"
+)
+
+var (
+ ext2ImagePath = path.Join(assetsDir, "tiny.ext2")
+ ext3ImagePath = path.Join(assetsDir, "tiny.ext3")
+ ext4ImagePath = path.Join(assetsDir, "tiny.ext4")
+)
+
+// setUp opens imagePath as an ext Filesystem and returns all necessary
+// elements required to run tests. If the returned error is nil, it also
+// returns a tear-down function which must be called after the test runs to
+// clean up.
+func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) {
+ localImagePath, err := testutil.FindFile(imagePath)
+ if err != nil {
+ return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err)
+ }
+
+ f, err := os.Open(localImagePath)
+ if err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
+ if err != nil {
+ f.Close()
+ return nil, nil, nil, nil, err
+ }
+
+ root := mntns.Root()
+
+ tearDown := func() {
+ root.DecRef()
+
+ if err := f.Close(); err != nil {
+ t.Fatalf("tearDown failed: %v", err)
+ }
+ }
+ return ctx, vfsObj, &root, tearDown, nil
+}
+
+// TODO(b/134676337): Test vfs.FilesystemImpl.ReadlinkAt and
+// vfs.FilesystemImpl.StatFSAt which are not implemented in
+// vfs.VirtualFilesystem yet.
+
+// TestSeek tests vfs.FileDescriptionImpl.Seek functionality.
+func TestSeek(t *testing.T) {
+ type seekTest struct {
+ name string
+ image string
+ path string
+ }
+
+ tests := []seekTest{
+ {
+ name: "ext4 root dir seek",
+ image: ext4ImagePath,
+ path: "/",
+ },
+ {
+ name: "ext3 root dir seek",
+ image: ext3ImagePath,
+ path: "/",
+ },
+ {
+ name: "ext2 root dir seek",
+ image: ext2ImagePath,
+ path: "/",
+ },
+ {
+ name: "ext4 reg file seek",
+ image: ext4ImagePath,
+ path: "/file.txt",
+ },
+ {
+ name: "ext3 reg file seek",
+ image: ext3ImagePath,
+ path: "/file.txt",
+ },
+ {
+ name: "ext2 reg file seek",
+ image: ext2ImagePath,
+ path: "/file.txt",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ fd, err := vfsfs.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt failed: %v", err)
+ }
+
+ if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
+ t.Errorf("expected seek position 0, got %d and error %v", n, err)
+ }
+
+ stat, err := fd.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err)
+ }
+
+ // We should be able to seek beyond the end of file.
+ size := int64(stat.Size)
+ if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size, n, err)
+ }
+
+ // EINVAL should be returned if the resulting offset is negative.
+ if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
+ t.Errorf("expected error EINVAL but got %v", err)
+ }
+
+ if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err)
+ }
+
+ // Make sure negative offsets work with SEEK_CUR.
+ if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
+ }
+
+ // EINVAL should be returned if the resulting offset is negative.
+ if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
+ t.Errorf("expected error EINVAL but got %v", err)
+ }
+
+ // Make sure SEEK_END works with regular files.
+ if _, ok := fd.Impl().(*regularFileFD); ok {
+ // Seek back to 0.
+ if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", 0, n, err)
+ }
+
+ // Seek forward beyond EOF.
+ if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
+ }
+
+ // EINVAL should be returned if the resulting offset is negative.
+ if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
+ t.Errorf("expected error EINVAL but got %v", err)
+ }
+ }
+ })
+ }
+}
+
+// TestStatAt tests filesystem.StatAt functionality.
+func TestStatAt(t *testing.T) {
+ type statAtTest struct {
+ name string
+ image string
+ path string
+ want linux.Statx
+ }
+
+ tests := []statAtTest{
+ {
+ name: "ext4 statx small file",
+ image: ext4ImagePath,
+ path: "/file.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0644 | linux.ModeRegular,
+ Size: 13,
+ },
+ },
+ {
+ name: "ext3 statx small file",
+ image: ext3ImagePath,
+ path: "/file.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0644 | linux.ModeRegular,
+ Size: 13,
+ },
+ },
+ {
+ name: "ext2 statx small file",
+ image: ext2ImagePath,
+ path: "/file.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0644 | linux.ModeRegular,
+ Size: 13,
+ },
+ },
+ {
+ name: "ext4 statx big file",
+ image: ext4ImagePath,
+ path: "/bigfile.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0644 | linux.ModeRegular,
+ Size: 13042,
+ },
+ },
+ {
+ name: "ext3 statx big file",
+ image: ext3ImagePath,
+ path: "/bigfile.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0644 | linux.ModeRegular,
+ Size: 13042,
+ },
+ },
+ {
+ name: "ext2 statx big file",
+ image: ext2ImagePath,
+ path: "/bigfile.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0644 | linux.ModeRegular,
+ Size: 13042,
+ },
+ },
+ {
+ name: "ext4 statx symlink file",
+ image: ext4ImagePath,
+ path: "/symlink.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0777 | linux.ModeSymlink,
+ Size: 8,
+ },
+ },
+ {
+ name: "ext3 statx symlink file",
+ image: ext3ImagePath,
+ path: "/symlink.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0777 | linux.ModeSymlink,
+ Size: 8,
+ },
+ },
+ {
+ name: "ext2 statx symlink file",
+ image: ext2ImagePath,
+ path: "/symlink.txt",
+ want: linux.Statx{
+ Blksize: 0x400,
+ Nlink: 1,
+ UID: 0,
+ GID: 0,
+ Mode: 0777 | linux.ModeSymlink,
+ Size: 8,
+ },
+ },
+ }
+
+ // Ignore the fields that are not supported by filesystem.StatAt yet and
+ // those which are likely to change as the image does.
+ ignoredFields := map[string]bool{
+ "Attributes": true,
+ "AttributesMask": true,
+ "Atime": true,
+ "Blocks": true,
+ "Btime": true,
+ "Ctime": true,
+ "DevMajor": true,
+ "DevMinor": true,
+ "Ino": true,
+ "Mask": true,
+ "Mtime": true,
+ "RdevMajor": true,
+ "RdevMinor": true,
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ got, err := vfsfs.StatAt(ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
+ &vfs.StatOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.StatAt failed for file %s in image %s: %v", test.path, test.image, err)
+ }
+
+ cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
+ _, ok := ignoredFields[p.String()]
+ return ok
+ }, cmp.Ignore())
+ if diff := cmp.Diff(got, test.want, cmpIgnoreFields, cmpopts.IgnoreUnexported(linux.Statx{})); diff != "" {
+ t.Errorf("stat mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestRead tests the read functionality for vfs file descriptions.
+func TestRead(t *testing.T) {
+ type readTest struct {
+ name string
+ image string
+ absPath string
+ }
+
+ tests := []readTest{
+ {
+ name: "ext4 read small file",
+ image: ext4ImagePath,
+ absPath: "/file.txt",
+ },
+ {
+ name: "ext3 read small file",
+ image: ext3ImagePath,
+ absPath: "/file.txt",
+ },
+ {
+ name: "ext2 read small file",
+ image: ext2ImagePath,
+ absPath: "/file.txt",
+ },
+ {
+ name: "ext4 read big file",
+ image: ext4ImagePath,
+ absPath: "/bigfile.txt",
+ },
+ {
+ name: "ext3 read big file",
+ image: ext3ImagePath,
+ absPath: "/bigfile.txt",
+ },
+ {
+ name: "ext2 read big file",
+ image: ext2ImagePath,
+ absPath: "/bigfile.txt",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ fd, err := vfsfs.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.absPath)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt failed: %v", err)
+ }
+
+ // Get a local file descriptor and compare its functionality with a vfs file
+ // description for the same file.
+ localFile, err := testutil.FindFile(path.Join(assetsDir, test.absPath))
+ if err != nil {
+ t.Fatalf("testutil.FindFile failed for %s: %v", test.absPath, err)
+ }
+
+ f, err := os.Open(localFile)
+ if err != nil {
+ t.Fatalf("os.Open failed for %s: %v", localFile, err)
+ }
+ defer f.Close()
+
+ // Read the entire file by reading one byte repeatedly. Doing this
+ // stress-tests the underlying file reader implementation.
+ got := make([]byte, 1)
+ want := make([]byte, 1)
+ for {
+ n, err := f.Read(want)
+ fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
+
+ if diff := cmp.Diff(got, want); diff != "" {
+ t.Errorf("file data mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure there is no more file data left after getting EOF.
+ if n == 0 || err == io.EOF {
+ if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
+ t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image)
+ }
+
+ break
+ }
+
+ if err != nil {
+ t.Fatalf("read failed: %v", err)
+ }
+ }
+ })
+ }
+}
+
+// iterDirentsCb is a simple callback which just keeps adding the dirents to an
+// internal list. Implements vfs.IterDirentsCallback.
+type iterDirentsCb struct {
+ dirents []vfs.Dirent
+}
+
+// Compiles only if iterDirentCb implements vfs.IterDirentsCallback.
+var _ vfs.IterDirentsCallback = (*iterDirentsCb)(nil)
+
+// newIterDirentCb creates a new iterDirentsCb with an empty dirents slice.
+func newIterDirentCb() *iterDirentsCb {
+ return &iterDirentsCb{dirents: make([]vfs.Dirent, 0)}
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) error {
+ cb.dirents = append(cb.dirents, dirent)
+ return nil
+}
+
+// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality.
+func TestIterDirents(t *testing.T) {
+ type iterDirentTest struct {
+ name string
+ image string
+ path string
+ want []vfs.Dirent
+ }
+
+ wantDirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ },
+ {
+ Name: "..",
+ Type: linux.DT_DIR,
+ },
+ {
+ Name: "lost+found",
+ Type: linux.DT_DIR,
+ },
+ {
+ Name: "file.txt",
+ Type: linux.DT_REG,
+ },
+ {
+ Name: "bigfile.txt",
+ Type: linux.DT_REG,
+ },
+ {
+ Name: "symlink.txt",
+ Type: linux.DT_LNK,
+ },
+ }
+ tests := []iterDirentTest{
+ {
+ name: "ext4 root dir iteration",
+ image: ext4ImagePath,
+ path: "/",
+ want: wantDirents,
+ },
+ {
+ name: "ext3 root dir iteration",
+ image: ext3ImagePath,
+ path: "/",
+ want: wantDirents,
+ },
+ {
+ name: "ext2 root dir iteration",
+ image: ext2ImagePath,
+ path: "/",
+ want: wantDirents,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ fd, err := vfsfs.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt failed: %v", err)
+ }
+
+ cb := &iterDirentsCb{}
+ if err = fd.IterDirents(ctx, cb); err != nil {
+ t.Fatalf("dir fd.IterDirents() failed: %v", err)
+ }
+
+ sort.Slice(cb.dirents, func(i int, j int) bool { return cb.dirents[i].Name < cb.dirents[j].Name })
+ sort.Slice(test.want, func(i int, j int) bool { return test.want[i].Name < test.want[j].Name })
+
+ // Ignore the inode number and offset of dirents because those are likely to
+ // change as the underlying image changes.
+ cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
+ return p.String() == "Ino" || p.String() == "NextOff"
+ }, cmp.Ignore())
+ if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
+ t.Errorf("dirents mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestRootDir tests that the root directory inode is correctly initialized and
+// returned from setUp.
+func TestRootDir(t *testing.T) {
+ type inodeProps struct {
+ Mode linux.FileMode
+ UID auth.KUID
+ GID auth.KGID
+ Size uint64
+ InodeSize uint16
+ Links uint16
+ Flags disklayout.InodeFlags
+ }
+
+ type rootDirTest struct {
+ name string
+ image string
+ wantInode inodeProps
+ }
+
+ tests := []rootDirTest{
+ {
+ name: "ext4 root dir",
+ image: ext4ImagePath,
+ wantInode: inodeProps{
+ Mode: linux.ModeDirectory | 0755,
+ Size: 0x400,
+ InodeSize: 0x80,
+ Links: 3,
+ Flags: disklayout.InodeFlags{Extents: true},
+ },
+ },
+ {
+ name: "ext3 root dir",
+ image: ext3ImagePath,
+ wantInode: inodeProps{
+ Mode: linux.ModeDirectory | 0755,
+ Size: 0x400,
+ InodeSize: 0x80,
+ Links: 3,
+ },
+ },
+ {
+ name: "ext2 root dir",
+ image: ext2ImagePath,
+ wantInode: inodeProps{
+ Mode: linux.ModeDirectory | 0755,
+ Size: 0x400,
+ InodeSize: 0x80,
+ Links: 3,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ _, _, vd, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ d, ok := vd.Dentry().Impl().(*dentry)
+ if !ok {
+ t.Fatalf("ext dentry of incorrect type: %T", vd.Dentry().Impl())
+ }
+
+ // Offload inode contents into local structs for comparison.
+ gotInode := inodeProps{
+ Mode: d.inode.diskInode.Mode(),
+ UID: d.inode.diskInode.UID(),
+ GID: d.inode.diskInode.GID(),
+ Size: d.inode.diskInode.Size(),
+ InodeSize: d.inode.diskInode.InodeSize(),
+ Links: d.inode.diskInode.LinksCount(),
+ Flags: d.inode.diskInode.Flags(),
+ }
+
+ if diff := cmp.Diff(gotInode, test.wantInode); diff != "" {
+ t.Errorf("inode mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestFilesystemInit tests that the filesystem superblock and block group
+// descriptors are correctly read in and initialized.
+func TestFilesystemInit(t *testing.T) {
+ // sb only contains the immutable properties of the superblock.
+ type sb struct {
+ InodesCount uint32
+ BlocksCount uint64
+ MaxMountCount uint16
+ FirstDataBlock uint32
+ BlockSize uint64
+ BlocksPerGroup uint32
+ ClusterSize uint64
+ ClustersPerGroup uint32
+ InodeSize uint16
+ InodesPerGroup uint32
+ BgDescSize uint16
+ Magic uint16
+ Revision disklayout.SbRevision
+ CompatFeatures disklayout.CompatFeatures
+ IncompatFeatures disklayout.IncompatFeatures
+ RoCompatFeatures disklayout.RoCompatFeatures
+ }
+
+ // bg only contains the immutable properties of the block group descriptor.
+ type bg struct {
+ InodeTable uint64
+ BlockBitmap uint64
+ InodeBitmap uint64
+ ExclusionBitmap uint64
+ Flags disklayout.BGFlags
+ }
+
+ type fsInitTest struct {
+ name string
+ image string
+ wantSb sb
+ wantBgs []bg
+ }
+
+ tests := []fsInitTest{
+ {
+ name: "ext4 filesystem init",
+ image: ext4ImagePath,
+ wantSb: sb{
+ InodesCount: 0x10,
+ BlocksCount: 0x40,
+ MaxMountCount: 0xffff,
+ FirstDataBlock: 0x1,
+ BlockSize: 0x400,
+ BlocksPerGroup: 0x2000,
+ ClusterSize: 0x400,
+ ClustersPerGroup: 0x2000,
+ InodeSize: 0x80,
+ InodesPerGroup: 0x10,
+ BgDescSize: 0x40,
+ Magic: linux.EXT_SUPER_MAGIC,
+ Revision: disklayout.DynamicRev,
+ CompatFeatures: disklayout.CompatFeatures{
+ ExtAttr: true,
+ ResizeInode: true,
+ DirIndex: true,
+ },
+ IncompatFeatures: disklayout.IncompatFeatures{
+ DirentFileType: true,
+ Extents: true,
+ Is64Bit: true,
+ FlexBg: true,
+ },
+ RoCompatFeatures: disklayout.RoCompatFeatures{
+ Sparse: true,
+ LargeFile: true,
+ HugeFile: true,
+ DirNlink: true,
+ ExtraIsize: true,
+ MetadataCsum: true,
+ },
+ },
+ wantBgs: []bg{
+ {
+ InodeTable: 0x23,
+ BlockBitmap: 0x3,
+ InodeBitmap: 0x13,
+ Flags: disklayout.BGFlags{
+ InodeZeroed: true,
+ },
+ },
+ },
+ },
+ {
+ name: "ext3 filesystem init",
+ image: ext3ImagePath,
+ wantSb: sb{
+ InodesCount: 0x10,
+ BlocksCount: 0x40,
+ MaxMountCount: 0xffff,
+ FirstDataBlock: 0x1,
+ BlockSize: 0x400,
+ BlocksPerGroup: 0x2000,
+ ClusterSize: 0x400,
+ ClustersPerGroup: 0x2000,
+ InodeSize: 0x80,
+ InodesPerGroup: 0x10,
+ BgDescSize: 0x20,
+ Magic: linux.EXT_SUPER_MAGIC,
+ Revision: disklayout.DynamicRev,
+ CompatFeatures: disklayout.CompatFeatures{
+ ExtAttr: true,
+ ResizeInode: true,
+ DirIndex: true,
+ },
+ IncompatFeatures: disklayout.IncompatFeatures{
+ DirentFileType: true,
+ },
+ RoCompatFeatures: disklayout.RoCompatFeatures{
+ Sparse: true,
+ LargeFile: true,
+ },
+ },
+ wantBgs: []bg{
+ {
+ InodeTable: 0x5,
+ BlockBitmap: 0x3,
+ InodeBitmap: 0x4,
+ Flags: disklayout.BGFlags{
+ InodeZeroed: true,
+ },
+ },
+ },
+ },
+ {
+ name: "ext2 filesystem init",
+ image: ext2ImagePath,
+ wantSb: sb{
+ InodesCount: 0x10,
+ BlocksCount: 0x40,
+ MaxMountCount: 0xffff,
+ FirstDataBlock: 0x1,
+ BlockSize: 0x400,
+ BlocksPerGroup: 0x2000,
+ ClusterSize: 0x400,
+ ClustersPerGroup: 0x2000,
+ InodeSize: 0x80,
+ InodesPerGroup: 0x10,
+ BgDescSize: 0x20,
+ Magic: linux.EXT_SUPER_MAGIC,
+ Revision: disklayout.DynamicRev,
+ CompatFeatures: disklayout.CompatFeatures{
+ ExtAttr: true,
+ ResizeInode: true,
+ DirIndex: true,
+ },
+ IncompatFeatures: disklayout.IncompatFeatures{
+ DirentFileType: true,
+ },
+ RoCompatFeatures: disklayout.RoCompatFeatures{
+ Sparse: true,
+ LargeFile: true,
+ },
+ },
+ wantBgs: []bg{
+ {
+ InodeTable: 0x5,
+ BlockBitmap: 0x3,
+ InodeBitmap: 0x4,
+ Flags: disklayout.BGFlags{
+ InodeZeroed: true,
+ },
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ _, _, vd, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ fs, ok := vd.Mount().Filesystem().Impl().(*filesystem)
+ if !ok {
+ t.Fatalf("ext filesystem of incorrect type: %T", vd.Mount().Filesystem().Impl())
+ }
+
+ // Offload the superblock and block group descriptor contents into
+ // local structs for comparison.
+ totalFreeInodes := uint32(0)
+ totalFreeBlocks := uint64(0)
+ gotSb := sb{
+ InodesCount: fs.sb.InodesCount(),
+ BlocksCount: fs.sb.BlocksCount(),
+ MaxMountCount: fs.sb.MaxMountCount(),
+ FirstDataBlock: fs.sb.FirstDataBlock(),
+ BlockSize: fs.sb.BlockSize(),
+ BlocksPerGroup: fs.sb.BlocksPerGroup(),
+ ClusterSize: fs.sb.ClusterSize(),
+ ClustersPerGroup: fs.sb.ClustersPerGroup(),
+ InodeSize: fs.sb.InodeSize(),
+ InodesPerGroup: fs.sb.InodesPerGroup(),
+ BgDescSize: fs.sb.BgDescSize(),
+ Magic: fs.sb.Magic(),
+ Revision: fs.sb.Revision(),
+ CompatFeatures: fs.sb.CompatibleFeatures(),
+ IncompatFeatures: fs.sb.IncompatibleFeatures(),
+ RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(),
+ }
+ gotNumBgs := len(fs.bgs)
+ gotBgs := make([]bg, gotNumBgs)
+ for i := 0; i < gotNumBgs; i++ {
+ gotBgs[i].InodeTable = fs.bgs[i].InodeTable()
+ gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap()
+ gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap()
+ gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap()
+ gotBgs[i].Flags = fs.bgs[i].Flags()
+
+ totalFreeInodes += fs.bgs[i].FreeInodesCount()
+ totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount())
+ }
+
+ if diff := cmp.Diff(gotSb, test.wantSb); diff != "" {
+ t.Errorf("superblock mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" {
+ t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" {
+ t.Errorf("total free inodes mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" {
+ t.Errorf("total free blocks mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
new file mode 100644
index 000000000..c36225a7c
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -0,0 +1,238 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// extentFile is a type of regular file which uses extents to store file data.
+type extentFile struct {
+ regFile regularFile
+
+ // root is the root extent node. This lives in the 60 byte diskInode.Data().
+ // Immutable.
+ root disklayout.ExtentNode
+}
+
+// Compiles only if extentFile implements io.ReaderAt.
+var _ io.ReaderAt = (*extentFile)(nil)
+
+// newExtentFile is the extent file constructor. It reads the entire extent
+// tree into memory.
+// TODO(b/134676337): Build extent tree on demand to reduce memory usage.
+func newExtentFile(args inodeArgs) (*extentFile, error) {
+ file := &extentFile{}
+ file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
+ err := file.buildExtTree()
+ if err != nil {
+ return nil, err
+ }
+ return file, nil
+}
+
+// buildExtTree builds the extent tree by reading it from disk, performing a
+// simple DFS. It first reads the root node from the inode struct in
+// memory. Then it recursively builds the rest of the tree by reading it off
+// disk.
+//
+// Precondition: inode flag InExtents must be set.
+func (f *extentFile) buildExtTree() error {
+ rootNodeData := f.regFile.inode.diskInode.Data()
+
+ binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
+
+ // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
+ if f.root.Header.NumEntries > 4 {
+ // read(2) specifies that EINVAL should be returned if the file is unsuitable
+ // for reading.
+ return syserror.EINVAL
+ }
+
+ f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
+ for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
+ var curEntry disklayout.ExtentEntry
+ if f.root.Header.Height == 0 {
+ // Leaf node.
+ curEntry = &disklayout.Extent{}
+ } else {
+ // Internal node.
+ curEntry = &disklayout.ExtentIdx{}
+ }
+ binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
+ f.root.Entries[i].Entry = curEntry
+ }
+
+ // If this node is internal, perform DFS.
+ if f.root.Header.Height > 0 {
+ for i := uint16(0); i < f.root.Header.NumEntries; i++ {
+ var err error
+ if f.root.Entries[i].Node, err = f.buildExtTreeFromDisk(f.root.Entries[i].Entry); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// buildExtTreeFromDisk reads the extent tree nodes from disk and recursively
+// builds the tree. Performs a simple DFS. It returns the ExtentNode pointed to
+// by the ExtentEntry.
+func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) {
+ var header disklayout.ExtentHeader
+ off := entry.PhysicalBlock() * f.regFile.inode.blkSize
+ err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header)
+ if err != nil {
+ return nil, err
+ }
+
+ entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
+ for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
+ var curEntry disklayout.ExtentEntry
+ if header.Height == 0 {
+ // Leaf node.
+ curEntry = &disklayout.Extent{}
+ } else {
+ // Internal node.
+ curEntry = &disklayout.ExtentIdx{}
+ }
+
+ err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry)
+ if err != nil {
+ return nil, err
+ }
+ entries[i].Entry = curEntry
+ }
+
+ // If this node is internal, perform DFS.
+ if header.Height > 0 {
+ for i := uint16(0); i < header.NumEntries; i++ {
+ var err error
+ entries[i].Node, err = f.buildExtTreeFromDisk(entries[i].Entry)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+
+	return &disklayout.ExtentNode{Header: header, Entries: entries}, nil
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *extentFile) ReadAt(dst []byte, off int64) (int, error) {
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ if off < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ if uint64(off) >= f.regFile.inode.diskInode.Size() {
+ return 0, io.EOF
+ }
+
+ n, err := f.read(&f.root, uint64(off), dst)
+ if n < len(dst) && err == nil {
+ err = io.EOF
+ }
+ return n, err
+}
+
+// read is the recursive step of extentFile.ReadAt which traverses the extent
+// tree from the node passed and reads file data.
+func (f *extentFile) read(node *disklayout.ExtentNode, off uint64, dst []byte) (int, error) {
+	// Perform a binary search for the entry covering bytes starting at off.
+	// A highly fragmented filesystem can have up to 340 entries, so linear
+	// search should be avoided. Find the first entry which does not cover the
+	// file block we want and subtract 1 to get the desired index.
+ fileBlk := uint32(off / f.regFile.inode.blkSize)
+ n := len(node.Entries)
+ found := sort.Search(n, func(i int) bool {
+ return node.Entries[i].Entry.FileBlock() > fileBlk
+ }) - 1
+
+ // We should be in this recursive step only if the data we want exists under
+ // the current node.
+ if found < 0 {
+ panic("searching for a file block in an extent entry which does not cover it")
+ }
+
+ read := 0
+ toRead := len(dst)
+ var curR int
+ var err error
+ for i := found; i < n && read < toRead; i++ {
+ if node.Header.Height == 0 {
+ curR, err = f.readFromExtent(node.Entries[i].Entry.(*disklayout.Extent), off, dst[read:])
+ } else {
+ curR, err = f.read(node.Entries[i].Node, off, dst[read:])
+ }
+
+ read += curR
+ off += uint64(curR)
+ if err != nil {
+ return read, err
+ }
+ }
+
+ return read, nil
+}
+
+// readFromExtent reads file data from the extent. It takes advantage of the
+// sequential nature of extents and reads file data from multiple blocks in one
+// call.
+//
+// A non-nil error indicates that this is a partial read and there is probably
+// more to read from this extent. The caller should propagate the error upward
+// and not move to the next extent in the tree.
+//
+// A subsequent call to extentFile.ReadAt should continue reading from where
+// we left off as expected.
+func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byte) (int, error) {
+ curFileBlk := uint32(off / f.regFile.inode.blkSize)
+ exFirstFileBlk := ex.FileBlock()
+ exLastFileBlk := exFirstFileBlk + uint32(ex.Length) // This is exclusive.
+
+ // We should be in this recursive step only if the data we want exists under
+ // the current extent.
+ if curFileBlk < exFirstFileBlk || exLastFileBlk <= curFileBlk {
+ panic("searching for a file block in an extent which does not cover it")
+ }
+
+ curPhyBlk := uint64(curFileBlk-exFirstFileBlk) + ex.PhysicalBlock()
+ readStart := curPhyBlk*f.regFile.inode.blkSize + (off % f.regFile.inode.blkSize)
+
+ endPhyBlk := ex.PhysicalBlock() + uint64(ex.Length)
+ extentEnd := endPhyBlk * f.regFile.inode.blkSize // This is exclusive.
+
+ toRead := int(extentEnd - readStart)
+ if len(dst) < toRead {
+ toRead = len(dst)
+ }
+
+ n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart))
+ if n < toRead {
+ return n, syserror.EIO
+ }
+ return n, nil
+}
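
The block arithmetic in readFromExtent is the subtle part of this file. As a sanity check, here is a minimal standalone sketch of the same mapping from a file offset to a physical byte range; the extent fields and block size are hypothetical stand-ins for disklayout.Extent and inode.blkSize, not part of the patch.

package main

import "fmt"

// extent is a stand-in for the disklayout.Extent fields used above.
type extent struct {
	firstFileBlk uint32 // first file block covered by this extent
	length       uint16 // number of blocks covered
	startPhyBlk  uint64 // first physical block backing the extent
}

// physicalRange mirrors the computation in readFromExtent: it maps a file
// offset to the physical byte at which the read starts and the exclusive
// physical byte at which the extent ends.
func physicalRange(ex extent, off, blkSize uint64) (readStart, extentEnd uint64) {
	curFileBlk := uint32(off / blkSize)
	curPhyBlk := uint64(curFileBlk-ex.firstFileBlk) + ex.startPhyBlk
	readStart = curPhyBlk*blkSize + (off % blkSize)
	extentEnd = (ex.startPhyBlk + uint64(ex.length)) * blkSize // exclusive
	return readStart, extentEnd
}

func main() {
	// With 64-byte blocks, an extent mapping file blocks [1, 3) to physical
	// blocks [4, 6): file offset 70 lands 6 bytes into physical block 4.
	start, end := physicalRange(extent{firstFileBlk: 1, length: 2, startPhyBlk: 4}, 70, 64)
	fmt.Println(start, end) // 262 384
}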
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
new file mode 100644
index 000000000..cd10d46ee
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -0,0 +1,265 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+)
+
+const (
+ // mockExtentBlkSize is the mock block size used for testing.
+ // No block has more than 1 header + 4 entries.
+ mockExtentBlkSize = uint64(64)
+)
+
+// The tree described below looks like:
+//
+// 0.{Head}[Idx][Idx]
+// / \
+// / \
+// 1.{Head}[Ext][Ext] 2.{Head}[Idx]
+// / | \
+// [Phy] [Phy, Phy] 3.{Head}[Ext]
+// |
+// [Phy, Phy, Phy]
+//
+// Legend:
+// - Head = ExtentHeader
+// - Idx = ExtentIdx
+// - Ext = Extent
+// - Phy = Physical Block
+//
+// Please note that ext4 might not construct extent trees looking like this.
+// This is purely for testing the tree traversal logic.
+var (
+ node3 = &disklayout.ExtentNode{
+ Header: disklayout.ExtentHeader{
+ Magic: disklayout.ExtentMagic,
+ NumEntries: 1,
+ MaxEntries: 4,
+ Height: 0,
+ },
+ Entries: []disklayout.ExtentEntryPair{
+ {
+ Entry: &disklayout.Extent{
+ FirstFileBlock: 3,
+ Length: 3,
+ StartBlockLo: 6,
+ },
+ Node: nil,
+ },
+ },
+ }
+
+ node2 = &disklayout.ExtentNode{
+ Header: disklayout.ExtentHeader{
+ Magic: disklayout.ExtentMagic,
+ NumEntries: 1,
+ MaxEntries: 4,
+ Height: 1,
+ },
+ Entries: []disklayout.ExtentEntryPair{
+ {
+ Entry: &disklayout.ExtentIdx{
+ FirstFileBlock: 3,
+ ChildBlockLo: 2,
+ },
+ Node: node3,
+ },
+ },
+ }
+
+ node1 = &disklayout.ExtentNode{
+ Header: disklayout.ExtentHeader{
+ Magic: disklayout.ExtentMagic,
+ NumEntries: 2,
+ MaxEntries: 4,
+ Height: 0,
+ },
+ Entries: []disklayout.ExtentEntryPair{
+ {
+ Entry: &disklayout.Extent{
+ FirstFileBlock: 0,
+ Length: 1,
+ StartBlockLo: 3,
+ },
+ Node: nil,
+ },
+ {
+ Entry: &disklayout.Extent{
+ FirstFileBlock: 1,
+ Length: 2,
+ StartBlockLo: 4,
+ },
+ Node: nil,
+ },
+ },
+ }
+
+ node0 = &disklayout.ExtentNode{
+ Header: disklayout.ExtentHeader{
+ Magic: disklayout.ExtentMagic,
+ NumEntries: 2,
+ MaxEntries: 4,
+ Height: 2,
+ },
+ Entries: []disklayout.ExtentEntryPair{
+ {
+ Entry: &disklayout.ExtentIdx{
+ FirstFileBlock: 0,
+ ChildBlockLo: 0,
+ },
+ Node: node1,
+ },
+ {
+ Entry: &disklayout.ExtentIdx{
+ FirstFileBlock: 3,
+ ChildBlockLo: 1,
+ },
+ Node: node2,
+ },
+ },
+ }
+)
+
+// TestExtentReader stress tests extentFile.ReadAt. It performs reads starting
+// at every possible offset in the file and extending to the end of the file.
+func TestExtentReader(t *testing.T) {
+ mockExtentFile, want := extentTreeSetUp(t, node0)
+ n := len(want)
+
+ for from := 0; from < n; from++ {
+ got := make([]byte, n-from)
+
+ if read, err := mockExtentFile.ReadAt(got, int64(from)); err != nil {
+ t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err)
+ }
+
+		if diff := cmp.Diff(want[from:], got); diff != "" {
+ t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff)
+ }
+ }
+}
+
+// TestBuildExtentTree tests the extent tree building logic.
+func TestBuildExtentTree(t *testing.T) {
+ mockExtentFile, _ := extentTreeSetUp(t, node0)
+
+ opt := cmpopts.IgnoreUnexported(disklayout.ExtentIdx{}, disklayout.ExtentHeader{})
+	if diff := cmp.Diff(node0, &mockExtentFile.root, opt); diff != "" {
+ t.Errorf("extent tree mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// extentTreeSetUp writes the passed extent tree to a mock disk as an extent
+// tree. It also constructs a mock extent file with the same tree built in it.
+// It also writes random file data to disk and returns it.
+func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []byte) {
+ t.Helper()
+
+ mockDisk := make([]byte, mockExtentBlkSize*10)
+ mockExtentFile := &extentFile{}
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
+ },
+ },
+ blkSize: mockExtentBlkSize,
+ }
+ mockExtentFile.regFile.inode.init(args, &mockExtentFile.regFile)
+
+	fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, root, mockExtentBlkSize)
+
+ if err := mockExtentFile.buildExtTree(); err != nil {
+ t.Fatalf("inode.buildExtTree failed: %v", err)
+ }
+ return mockExtentFile, fileData
+}
+
+// writeTree writes the tree represented by `root` to the inode and disk. It
+// also writes random file data on disk.
+func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
+ rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
+ for _, ep := range root.Entries {
+ rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
+ }
+
+ copy(in.diskInode.Data(), rootData)
+
+ var fileData []byte
+ for _, ep := range root.Entries {
+ if root.Header.Height == 0 {
+ fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...)
+ } else {
+ fileData = append(fileData, writeTreeToDisk(disk, ep)...)
+ }
+ }
+ return fileData
+}
+
+// writeTreeToDisk is the recursive step of writeTree which writes the tree to
+// the disk only. It also writes random file data to disk.
+func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
+ nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
+ for _, ep := range curNode.Node.Entries {
+ nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
+ }
+
+ copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
+
+ var fileData []byte
+ for _, ep := range curNode.Node.Entries {
+ if curNode.Node.Header.Height == 0 {
+ fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...)
+ } else {
+ fileData = append(fileData, writeTreeToDisk(disk, ep)...)
+ }
+ }
+ return fileData
+}
+
+// writeFileDataToExtent writes random bytes to the blocks on disk that the
+// passed extent points to.
+func writeFileDataToExtent(disk []byte, ex *disklayout.Extent) []byte {
+ phyExStartBlk := ex.PhysicalBlock()
+ phyExStartOff := phyExStartBlk * mockExtentBlkSize
+ phyExEndOff := phyExStartOff + uint64(ex.Length)*mockExtentBlkSize
+ rand.Read(disk[phyExStartOff:phyExEndOff])
+ return disk[phyExStartOff:phyExEndOff]
+}
+
+// getNumPhyBlks returns the number of physical blocks covered under the node.
+func getNumPhyBlks(node *disklayout.ExtentNode) uint32 {
+ var res uint32
+ for _, ep := range node.Entries {
+ if node.Header.Height == 0 {
+ res += uint32(ep.Entry.(*disklayout.Extent).Length)
+ } else {
+ res += getNumPhyBlks(ep.Node)
+ }
+ }
+ return res
+}
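
To make the test tree above concrete: its three leaf extents map the file's six blocks onto physical blocks 3 through 8. A throwaway sketch (not part of the test) that prints the mapping:

package main

import "fmt"

func main() {
	// The leaf extents from node1 and node3 above, as
	// {FirstFileBlock, Length, StartBlockLo} triples.
	extents := []struct{ first, length, start uint32 }{
		{0, 1, 3}, // node1, entry 0
		{1, 2, 4}, // node1, entry 1
		{3, 3, 6}, // node3, entry 0
	}
	for _, e := range extents {
		for i := uint32(0); i < e.length; i++ {
			fmt.Printf("file block %d -> physical block %d\n", e.first+i, e.start+i)
		}
	}
}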
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
new file mode 100644
index 000000000..90b086468
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -0,0 +1,65 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// fileDescription is embedded by ext implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) inode() *inode {
+ return fd.vfsfd.Dentry().Impl().(*dentry).inode
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ fd.inode().statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
+func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ var stat linux.Statfs
+ fd.filesystem().statTo(&stat)
+ return stat, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *fileDescription) Sync(ctx context.Context) error {
+ return nil
+}
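
fileDescription works because of Go's embedding promotion: concrete FDs embed it, inherit its Stat/SetStat/StatFS/Sync methods, and the shared methods recover per-file state through type assertions on the mount and dentry. A minimal standalone sketch of the pattern, with made-up names:

package main

import "fmt"

// impl is what a concrete file type stores behind an interface.
type impl interface{ typeName() string }

type regular struct{}

func (regular) typeName() string { return "regular" }

// base plays the role of fileDescription: it holds implementation-specific
// state behind an interface and provides methods shared by every concrete FD.
type base struct{ impl impl }

func (b *base) describe() string { return "fd for a " + b.impl.typeName() }

// regularFD embeds base, like regularFileFD embeds fileDescription above, so
// describe is promoted onto it.
type regularFD struct{ base }

func main() {
	fd := regularFD{base{impl: regular{}}}
	fmt.Println(fd.describe()) // "fd for a regular"
}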
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
new file mode 100644
index 000000000..557963e03
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -0,0 +1,548 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "errors"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var (
+	// errResolveDirent indicates that the vfs.ResolvingPath.Component() does
+	// not exist in the dentry tree but does exist on disk, so it has to be read
+	// in using the in-memory dirent and added to the dentry tree. This usually
+	// indicates the need to lock filesystem.mu for writing.
+ errResolveDirent = errors.New("resolve path component using dirent")
+)
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // mu serializes changes to the Dentry tree.
+ mu sync.RWMutex
+
+ // dev represents the underlying fs device. It does not require protection
+ // because io.ReaderAt permits concurrent read calls to it. It translates to
+ // the pread syscall which passes on the read request directly to the device
+	// driver. Device drivers can serve multiple concurrent read requests in an
+	// optimal order (taking locality into consideration).
+ dev io.ReaderAt
+
+ // inodeCache maps absolute inode numbers to the corresponding Inode struct.
+ // Inodes should be removed from this once their reference count hits 0.
+ //
+	// Protected by mu because most additions (see IterDirents) and all removals
+	// from this correspond to a change in the dentry tree.
+ inodeCache map[uint32]*inode
+
+ // sb represents the filesystem superblock. Immutable after initialization.
+ sb disklayout.SuperBlock
+
+ // bgs represents all the block group descriptors for the filesystem.
+ // Immutable after initialization.
+ bgs []disklayout.BlockGroup
+
+ // devMinor is this filesystem's device minor number. Immutable after
+ // initialization.
+ devMinor uint32
+}
+
+// Compiles only if filesystem implements vfs.FilesystemImpl.
+var _ vfs.FilesystemImpl = (*filesystem)(nil)
+
+// stepLocked resolves rp.Component() in parent directory vfsd. The write
+// parameter indicates whether the caller holds filesystem.mu for writing. If
+// true, an inode that exists on disk but is not yet in the dentry tree can be
+// added to it.
+//
+// stepLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+// - !rp.Done().
+// - inode == vfsd.Impl().(*Dentry).inode.
+func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
+ if !inode.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, nil, err
+ }
+
+ for {
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return vfsd, inode, nil
+ }
+ d := vfsd.Impl().(*dentry)
+ if name == ".." {
+ isRoot, err := rp.CheckRoot(vfsd)
+ if err != nil {
+ return nil, nil, err
+ }
+ if isRoot || d.parent == nil {
+ rp.Advance()
+ return vfsd, inode, nil
+ }
+ if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ return nil, nil, err
+ }
+ rp.Advance()
+ return &d.parent.vfsd, d.parent.inode, nil
+ }
+
+ dir := inode.impl.(*directory)
+ child, ok := dir.childCache[name]
+ if !ok {
+ // We may need to instantiate a new dentry for this child.
+ childDirent, ok := dir.childMap[name]
+ if !ok {
+ // The underlying inode does not exist on disk.
+ return nil, nil, syserror.ENOENT
+ }
+
+ if !write {
+ // filesystem.mu must be held for writing to add to the dentry tree.
+ return nil, nil, errResolveDirent
+ }
+
+ // Create and add the component's dirent to the dentry tree.
+ fs := rp.Mount().Filesystem().Impl().(*filesystem)
+ childInode, err := fs.getOrCreateInodeLocked(childDirent.diskDirent.Inode())
+ if err != nil {
+ return nil, nil, err
+ }
+ // incRef because this is being added to the dentry tree.
+ childInode.incRef()
+ child = newDentry(childInode)
+ child.parent = d
+ child.name = name
+ dir.childCache[name] = child
+ }
+ if err := rp.CheckMount(&child.vfsd); err != nil {
+ return nil, nil, err
+ }
+ if child.inode.isSymlink() && rp.ShouldFollowSymlink() {
+ if err := rp.HandleSymlink(child.inode.impl.(*symlink).target); err != nil {
+ return nil, nil, err
+ }
+ continue
+ }
+ rp.Advance()
+ return &child.vfsd, child.inode, nil
+ }
+}
+
+// walkLocked resolves rp to an existing file. The write parameter indicates
+// whether the caller holds filesystem.mu for writing. If true, additions can
+// be made to the dentry tree while walking. If errResolveDirent is returned,
+// the walk needs to be continued with filesystem.mu upgraded to a write lock.
+//
+// walkLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*dentry).inode
+ for !rp.Done() {
+ var err error
+ vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if rp.MustBeDir() && !inode.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, inode, nil
+}
+
+// walkParentLocked resolves all but the last path component of rp to an
+// existing directory. It does not check that the returned directory is
+// searchable by the provider of rp. The write parameter indicates whether the
+// caller holds filesystem.mu for writing. If true, additions can be made to
+// the dentry tree while walking.
+// If errResolveDirent is returned, the walk needs to be continued with an
+// upgraded filesystem.mu.
+//
+// walkParentLocked is loosely analogous to Linux's fs/namei.c:path_parentat().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+// - !rp.Done().
+func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*dentry).inode
+ for !rp.Final() {
+ var err error
+ vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if !inode.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, inode, nil
+}
+
+// walk resolves rp to an existing file. If parent is true, it resolves rp
+// only up to the parent of the last component, which must be an existing
+// directory. If parent is false, it resolves rp entirely. It attempts to
+// resolve the path as far as it can with a read lock and upgrades to a write
+// lock only if needed.
+func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
+ var (
+ vfsd *vfs.Dentry
+ inode *inode
+ err error
+ )
+
+	// Try walking in the hope that all dentries have already been pulled off
+	// disk. This reduces contention (allows concurrent walks).
+ fs.mu.RLock()
+ if parent {
+ vfsd, inode, err = walkParentLocked(rp, false)
+ } else {
+ vfsd, inode, err = walkLocked(rp, false)
+ }
+ fs.mu.RUnlock()
+
+ if err == errResolveDirent {
+ // Upgrade lock and continue walking. Lock upgrading in the middle of the
+ // walk is fine as this is a read only filesystem.
+ fs.mu.Lock()
+ if parent {
+ vfsd, inode, err = walkParentLocked(rp, true)
+ } else {
+ vfsd, inode, err = walkLocked(rp, true)
+ }
+ fs.mu.Unlock()
+ }
+
+ return vfsd, inode, err
+}
+
+// getOrCreateInodeLocked gets the inode corresponding to the inode number passed in.
+// It creates a new one with the given inode number if one does not exist.
+// The caller must increment the ref count if adding this to the dentry tree.
+//
+// Precondition: must be holding fs.mu for writing.
+func (fs *filesystem) getOrCreateInodeLocked(inodeNum uint32) (*inode, error) {
+ if in, ok := fs.inodeCache[inodeNum]; ok {
+ return in, nil
+ }
+
+ in, err := newInode(fs, inodeNum)
+ if err != nil {
+ return nil, err
+ }
+
+ fs.inodeCache[inodeNum] = in
+ return in, nil
+}
+
+// statTo writes the statfs fields to the output parameter.
+func (fs *filesystem) statTo(stat *linux.Statfs) {
+ stat.Type = uint64(fs.sb.Magic())
+ stat.BlockSize = int64(fs.sb.BlockSize())
+ stat.Blocks = fs.sb.BlocksCount()
+ stat.BlocksFree = fs.sb.FreeBlocksCount()
+ stat.BlocksAvailable = fs.sb.FreeBlocksCount()
+ stat.Files = uint64(fs.sb.InodesCount())
+ stat.FilesFree = uint64(fs.sb.FreeInodesCount())
+ stat.NameLength = disklayout.MaxFileName
+ stat.FragmentSize = int64(fs.sb.BlockSize())
+ // TODO(b/134676337): Set Statfs.Flags and Statfs.FSID.
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return inode.checkPermissions(rp.Credentials(), ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ vfsd, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.CheckSearchable {
+ if !inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+
+ inode.incRef()
+ return vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ vfsd, inode, err := fs.walk(rp, true)
+ if err != nil {
+ return nil, err
+ }
+ inode.incRef()
+ return vfsd, nil
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ vfsd, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // EROFS is returned if write access is needed.
+ if vfs.MayWriteFileWithOpenFlags(opts.Flags) || opts.Flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 {
+ return nil, syserror.EROFS
+ }
+ return inode.open(rp, vfsd, &opts)
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return "", err
+ }
+ symlink, ok := inode.impl.(*symlink)
+ if !ok {
+ return "", syserror.EINVAL
+ }
+ return symlink.target, nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ var stat linux.Statx
+ inode.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ if _, _, err := fs.walk(rp, false); err != nil {
+ return linux.Statfs{}, err
+ }
+
+ var stat linux.Statfs
+ fs.statTo(&stat)
+ return stat, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // This is a readonly filesystem for now.
+ return nil
+}
+
+// The vfs.FilesystemImpl functions below return EROFS because their respective
+// man pages say that EROFS must be returned if the path resolves to a file on
+// this read-only filesystem.
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+
+ if _, _, err := fs.walk(rp, true); err != nil {
+ return err
+ }
+
+ return syserror.EROFS
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+
+ if _, _, err := fs.walk(rp, true); err != nil {
+ return err
+ }
+
+ return syserror.EROFS
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+
+ _, _, err := fs.walk(rp, true)
+ if err != nil {
+ return err
+ }
+
+ return syserror.EROFS
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if rp.Done() {
+ return syserror.ENOENT
+ }
+
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+
+ return syserror.EROFS
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+
+ if !inode.isDir() {
+ return syserror.ENOTDIR
+ }
+
+ return syserror.EROFS
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+
+ return syserror.EROFS
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+
+ _, _, err := fs.walk(rp, true)
+ if err != nil {
+ return err
+ }
+
+ return syserror.EROFS
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+
+ if inode.isDir() {
+ return syserror.EISDIR
+ }
+
+ return syserror.EROFS
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+
+ // TODO(b/134676337): Support sockets.
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
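
walk above follows a generally useful locking recipe for read-mostly structures: do the whole traversal under the read lock, and only if a sentinel error reports a cache miss redo it under the write lock. A minimal standalone sketch of the same recipe over a hypothetical cache, where errMiss plays the role of errResolveDirent:

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errMiss = errors.New("not cached") // sentinel, like errResolveDirent

type cache struct {
	mu sync.RWMutex
	m  map[string]int
}

// lookupLocked mirrors walkLocked: with write=false it may only read the
// cache and must report a miss; with write=true it may also populate it.
func (c *cache) lookupLocked(key string, write bool) (int, error) {
	if v, ok := c.m[key]; ok {
		return v, nil
	}
	if !write {
		return 0, errMiss
	}
	v := len(c.m) // stand-in for reading the entry off disk
	c.m[key] = v
	return v, nil
}

// lookup mirrors filesystem.walk: optimistic RLock first, upgrade on a miss.
func (c *cache) lookup(key string) (int, error) {
	c.mu.RLock()
	v, err := c.lookupLocked(key, false)
	c.mu.RUnlock()
	if err == errMiss {
		c.mu.Lock()
		v, err = c.lookupLocked(key, true)
		c.mu.Unlock()
	}
	return v, err
}

func main() {
	c := &cache{m: make(map[string]int)}
	fmt.Println(c.lookup("a")) // miss, upgrade, populate
	fmt.Println(c.lookup("a")) // served under the read lock
}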
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
new file mode 100644
index 000000000..30636cf66
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -0,0 +1,242 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// inode represents an ext inode.
+//
+// inode uses the same inheritance pattern that pkg/sentry/vfs structures use.
+// This has been done to increase memory locality.
+//
+// Implementations:
+// inode --
+// |-- dir
+// |-- symlink
+// |-- regular--
+// |-- extent file
+// |-- block map file
+type inode struct {
+ // refs is a reference count. refs is accessed using atomic memory operations.
+ refs int64
+
+ // fs is the containing filesystem.
+ fs *filesystem
+
+ // inodeNum is the inode number of this inode on disk. This is used to
+ // identify inodes within the ext filesystem.
+ inodeNum uint32
+
+ // blkSize is the fs data block size. Same as filesystem.sb.BlockSize().
+ blkSize uint64
+
+ // diskInode gives us access to the inode struct on disk. Immutable.
+ diskInode disklayout.Inode
+
+ locks vfs.FileLocks
+
+	// This is immutable. Implementations must have inode as their first field
+	// to preserve memory locality.
+ impl interface{}
+}
+
+// incRef increments the inode ref count.
+func (in *inode) incRef() {
+ atomic.AddInt64(&in.refs, 1)
+}
+
+// tryIncRef tries to increment the ref count. Returns true if successful.
+func (in *inode) tryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&in.refs)
+ if refs == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&in.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// decRef decrements the inode ref count and releases the inode resources if
+// the ref count hits 0.
+//
+// Precondition: Must have locked filesystem.mu.
+func (in *inode) decRef() {
+ if refs := atomic.AddInt64(&in.refs, -1); refs == 0 {
+ delete(in.fs.inodeCache, in.inodeNum)
+ } else if refs < 0 {
+ panic("ext.inode.decRef() called without holding a reference")
+ }
+}
+
+// newInode is the inode constructor. Reads the inode off disk. Identifies
+// inodes based on the absolute inode number on disk.
+func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
+ if inodeNum == 0 {
+ panic("inode number 0 on ext filesystems is not possible")
+ }
+
+ inodeRecordSize := fs.sb.InodeSize()
+ var diskInode disklayout.Inode
+ if inodeRecordSize == disklayout.OldInodeSize {
+ diskInode = &disklayout.InodeOld{}
+ } else {
+ diskInode = &disklayout.InodeNew{}
+ }
+
+ // Calculate where the inode is actually placed.
+ inodesPerGrp := fs.sb.InodesPerGroup()
+ blkSize := fs.sb.BlockSize()
+ inodeTableOff := fs.bgs[getBGNum(inodeNum, inodesPerGrp)].InodeTable() * blkSize
+ inodeOff := inodeTableOff + uint64(uint32(inodeRecordSize)*getBGOff(inodeNum, inodesPerGrp))
+
+ if err := readFromDisk(fs.dev, int64(inodeOff), diskInode); err != nil {
+ return nil, err
+ }
+
+ // Build the inode based on its type.
+ args := inodeArgs{
+ fs: fs,
+ inodeNum: inodeNum,
+ blkSize: blkSize,
+ diskInode: diskInode,
+ }
+
+ switch diskInode.Mode().FileType() {
+ case linux.ModeSymlink:
+ f, err := newSymlink(args)
+ if err != nil {
+ return nil, err
+ }
+ return &f.inode, nil
+ case linux.ModeRegular:
+ f, err := newRegularFile(args)
+ if err != nil {
+ return nil, err
+ }
+ return &f.inode, nil
+ case linux.ModeDirectory:
+ f, err := newDirectory(args, fs.sb.IncompatibleFeatures().DirentFileType)
+ if err != nil {
+ return nil, err
+ }
+ return &f.inode, nil
+ default:
+ // TODO(b/134676337): Return appropriate errors for sockets, pipes and devices.
+ return nil, syserror.EINVAL
+ }
+}
+
+type inodeArgs struct {
+ fs *filesystem
+ inodeNum uint32
+ blkSize uint64
+ diskInode disklayout.Inode
+}
+
+func (in *inode) init(args inodeArgs, impl interface{}) {
+ in.fs = args.fs
+ in.inodeNum = args.inodeNum
+ in.blkSize = args.blkSize
+ in.diskInode = args.diskInode
+ in.impl = impl
+}
+
+// open creates and returns a file description for the dentry passed in.
+func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := in.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ mnt := rp.Mount()
+ switch in.impl.(type) {
+ case *regularFile:
+ var fd regularFileFD
+ fd.LockFD.Init(&in.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case *directory:
+ // Can't open directories writably. This check is not necessary for a read
+ // only filesystem but will be required when write is implemented.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ var fd directoryFD
+ fd.LockFD.Init(&in.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case *symlink:
+ if opts.Flags&linux.O_PATH == 0 {
+ // Can't open symlinks without O_PATH.
+ return nil, syserror.ELOOP
+ }
+ var fd symlinkFD
+ fd.LockFD.Init(&in.locks)
+		if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+			return nil, err
+		}
+		return &fd.vfsfd, nil
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", in.impl))
+ }
+}
+
+func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, in.diskInode.Mode(), in.diskInode.UID(), in.diskInode.GID())
+}
+
+// statTo writes the statx fields to the output parameter.
+func (in *inode) statTo(stat *linux.Statx) {
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+ linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE |
+ linux.STATX_ATIME | linux.STATX_CTIME | linux.STATX_MTIME
+ stat.Blksize = uint32(in.blkSize)
+ stat.Mode = uint16(in.diskInode.Mode())
+ stat.Nlink = uint32(in.diskInode.LinksCount())
+ stat.UID = uint32(in.diskInode.UID())
+ stat.GID = uint32(in.diskInode.GID())
+ stat.Ino = uint64(in.inodeNum)
+ stat.Size = in.diskInode.Size()
+ stat.Atime = in.diskInode.AccessTime().StatxTimestamp()
+ stat.Ctime = in.diskInode.ChangeTime().StatxTimestamp()
+ stat.Mtime = in.diskInode.ModificationTime().StatxTimestamp()
+ stat.DevMajor = linux.UNNAMED_MAJOR
+ stat.DevMinor = in.fs.devMinor
+ // TODO(b/134676337): Set stat.Blocks which is the number of 512 byte blocks
+ // (including metadata blocks) required to represent this file.
+}
+
+// getBGNum returns the block group number that a given inode belongs to.
+func getBGNum(inodeNum uint32, inodesPerGrp uint32) uint32 {
+ return (inodeNum - 1) / inodesPerGrp
+}
+
+// getBGOff returns the offset at which the given inode lives in the block
+// group's inode table, i.e. the index of the inode in the inode table.
+func getBGOff(inodeNum uint32, inodesPerGrp uint32) uint32 {
+ return (inodeNum - 1) % inodesPerGrp
+}
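
Since inode numbers are 1-based, inode N lives in block group (N-1)/InodesPerGroup at index (N-1)%InodesPerGroup of that group's inode table, which is how newInode locates the record above. A standalone sketch of the arithmetic with made-up geometry; in the real code the inode table block comes from the block group descriptor:

package main

import "fmt"

// inodeOffset computes the absolute byte offset of an inode record, mirroring
// newInode above. inodeTableBlk is a hypothetical per-group value here.
func inodeOffset(inodeNum, inodesPerGrp uint32, inodeTableBlk, blkSize uint64, recordSize uint32) uint64 {
	bg := (inodeNum - 1) / inodesPerGrp  // getBGNum
	idx := (inodeNum - 1) % inodesPerGrp // getBGOff
	_ = bg                               // selects fs.bgs[bg] in the real code
	return inodeTableBlk*blkSize + uint64(recordSize*idx)
}

func main() {
	// Inode 11 with 8 inodes per group falls in group 1 at index 2.
	fmt.Println(inodeOffset(11, 8, 100, 1024, 256)) // 100*1024 + 256*2 = 102912
}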
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
new file mode 100644
index 000000000..66d14bb95
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -0,0 +1,162 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// regularFile represents a regular file's inode. This too follows the
+// inheritance pattern prevalent in the vfs layer described in
+// pkg/sentry/vfs/README.md.
+type regularFile struct {
+ inode inode
+
+	// This is immutable. Implementations must have regularFile as their first
+	// field to preserve memory locality.
+ // io.ReaderAt is more strict than io.Reader in the sense that a partial read
+ // is always accompanied by an error. If a read spans past the end of file, a
+ // partial read (within file range) is done and io.EOF is returned.
+ impl io.ReaderAt
+}
+
+// newRegularFile is the regularFile constructor. It figures out what kind of
+// file this is and initializes the appropriate reader (extent or block map).
+func newRegularFile(args inodeArgs) (*regularFile, error) {
+ if args.diskInode.Flags().Extents {
+ file, err := newExtentFile(args)
+ if err != nil {
+ return nil, err
+ }
+ return &file.regFile, nil
+ }
+
+ file, err := newBlockMapFile(args)
+ if err != nil {
+ return nil, err
+ }
+ return &file.regFile, nil
+}
+
+func (in *inode) isRegular() bool {
+ _, ok := in.impl.(*regularFile)
+ return ok
+}
+
+// regularFileFD represents a regular file's file description. It implements
+// vfs.FileDescriptionImpl.
+type regularFileFD struct {
+ fileDescription
+ vfs.LockFD
+
+ // off is the file offset. off is accessed using atomic memory operations.
+ off int64
+
+ // offMu serializes operations that may mutate off.
+ offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ safeReader := safemem.FromIOReaderAt{
+ ReaderAt: fd.inode().impl.(*regularFile).impl,
+ Offset: offset,
+ }
+
+ // Copies data from disk directly into usermem without any intermediate
+ // allocations (if dst is converted into BlockSeq such that it does not need
+ // safe copying).
+ return dst.CopyOutFrom(ctx, safeReader)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.offMu.Lock()
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // write(2) specifies that EBADF must be returned if the fd is not open for
+ // writing.
+ return 0, syserror.EBADF
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.offMu.Lock()
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *regularFileFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ return syserror.ENOTDIR
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as specified.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(fd.inode().diskInode.Size())
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ // TODO(b/134676337): Implement mmap(2).
+ return syserror.ENODEV
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
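
Read above is just PRead plus a mutex-guarded offset, a pattern that turns any stateless io.ReaderAt into a sequential reader. A minimal standalone sketch:

package main

import (
	"fmt"
	"io"
	"strings"
	"sync"
)

// offsetReader mirrors regularFileFD's Read/PRead split: positional reads are
// stateless, and the sequential Read just advances a locked offset.
type offsetReader struct {
	r   io.ReaderAt
	mu  sync.Mutex
	off int64
}

func (o *offsetReader) Read(p []byte) (int, error) {
	o.mu.Lock()
	defer o.mu.Unlock()
	n, err := o.r.ReadAt(p, o.off)
	o.off += int64(n)
	return n, err
}

func main() {
	r := &offsetReader{r: strings.NewReader("hello world")}
	buf := make([]byte, 5)
	r.Read(buf)
	fmt.Println(string(buf)) // "hello"
	r.Read(buf)
	fmt.Println(string(buf)) // " worl"
}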
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
new file mode 100644
index 000000000..62efd4095
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// symlink represents a symlink inode.
+type symlink struct {
+ inode inode
+ target string // immutable
+}
+
+// newSymlink is the symlink constructor. It reads out the symlink target from
+// the inode (however it might have been stored).
+func newSymlink(args inodeArgs) (*symlink, error) {
+ var link []byte
+
+	// If the symlink target is less than 60 bytes, it is stored in inode.Data().
+ // Otherwise either extents or block maps will be used to store the link.
+ size := args.diskInode.Size()
+ if size < 60 {
+ link = args.diskInode.Data()[:size]
+ } else {
+ // Create a regular file out of this inode and read out the target.
+ regFile, err := newRegularFile(args)
+ if err != nil {
+ return nil, err
+ }
+
+ link = make([]byte, size)
+ if n, err := regFile.impl.ReadAt(link, 0); uint64(n) < size {
+ return nil, err
+ }
+ }
+
+ file := &symlink{target: string(link)}
+ file.inode.init(args, file)
+ return file, nil
+}
+
+func (in *inode) isSymlink() bool {
+ _, ok := in.impl.(*symlink)
+ return ok
+}
+
+// symlinkFD represents a symlink file description and implements
+// vfs.FileDescriptionImpl. It may only be used if the open options contain
+// O_PATH. For this reason most of the functions return EBADF.
+type symlinkFD struct {
+ fileDescription
+ vfs.NoLockFD
+}
+
+// Compiles only if symlinkFD implements vfs.FileDescriptionImpl.
+var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *symlinkFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *symlinkFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *symlinkFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *symlinkFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *symlinkFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ return syserror.ENOTDIR
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return syserror.EBADF
+}
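
The 60-byte cutoff in newSymlink is ext's "fast symlink" optimization: short targets fit in the inode's inline data area, while longer ones are stored like regular file data. A hedged standalone sketch of just that decision, where inline and fileData are hypothetical stand-ins for diskInode.Data() and the regular-file reader:

package main

import (
	"fmt"
	"io"
	"strings"
)

// readLinkTarget mirrors newSymlink: targets shorter than 60 bytes live in
// the inode's inline data; longer ones are read like regular file contents.
func readLinkTarget(size uint64, inline []byte, fileData io.ReaderAt) (string, error) {
	if size < 60 {
		return string(inline[:size]), nil
	}
	link := make([]byte, size)
	if n, err := fileData.ReadAt(link, 0); uint64(n) < size {
		return "", err
	}
	return string(link), nil
}

func main() {
	// Fast symlink: the 4-byte target sits at the front of the inline data.
	target, _ := readLinkTarget(4, []byte("/tmp\x00unused"), nil)
	fmt.Println(target) // "/tmp"

	// Slow symlink: an 80-byte target is read through the file reader.
	long := strings.Repeat("a/", 40)
	target, _ = readLinkTarget(uint64(len(long)), nil, strings.NewReader(long))
	fmt.Println(len(target)) // 80
}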
diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
new file mode 100644
index 000000000..d8b728f8c
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -0,0 +1,94 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// readFromDisk performs a binary read from disk into the given struct from
+// the absolute offset provided.
+func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
+ n := binary.Size(v)
+ buf := make([]byte, n)
+ if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
+ return syserror.EIO
+ }
+
+ binary.Unmarshal(buf, binary.LittleEndian, v)
+ return nil
+}
+
+// readSuperBlock reads the SuperBlock from block group 0 in the underlying
+// device. There are three versions of the superblock. This function identifies
+// and returns the correct version.
+func readSuperBlock(dev io.ReaderAt) (disklayout.SuperBlock, error) {
+ var sb disklayout.SuperBlock = &disklayout.SuperBlockOld{}
+ if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil {
+ return nil, err
+ }
+ if sb.Revision() == disklayout.OldRev {
+ return sb, nil
+ }
+
+ sb = &disklayout.SuperBlock32Bit{}
+ if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil {
+ return nil, err
+ }
+ if !sb.IncompatibleFeatures().Is64Bit {
+ return sb, nil
+ }
+
+ sb = &disklayout.SuperBlock64Bit{}
+ if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil {
+ return nil, err
+ }
+ return sb, nil
+}
+
+// blockGroupsCount returns the number of block groups in the ext fs.
+func blockGroupsCount(sb disklayout.SuperBlock) uint64 {
+ blocksCount := sb.BlocksCount()
+ blocksPerGroup := uint64(sb.BlocksPerGroup())
+
+ // Round up the result. float64 can compromise precision so do it manually.
+ return (blocksCount + blocksPerGroup - 1) / blocksPerGroup
+}
+
+// readBlockGroups reads the block group descriptor table from block group 0 in
+// the underlying device.
+func readBlockGroups(dev io.ReaderAt, sb disklayout.SuperBlock) ([]disklayout.BlockGroup, error) {
+ bgCount := blockGroupsCount(sb)
+ bgdSize := uint64(sb.BgDescSize())
+ is64Bit := sb.IncompatibleFeatures().Is64Bit
+ bgds := make([]disklayout.BlockGroup, bgCount)
+
+ for i, off := uint64(0), uint64(sb.FirstDataBlock()+1)*sb.BlockSize(); i < bgCount; i, off = i+1, off+bgdSize {
+ if is64Bit {
+ bgds[i] = &disklayout.BlockGroup64Bit{}
+ } else {
+ bgds[i] = &disklayout.BlockGroup32Bit{}
+ }
+
+ if err := readFromDisk(dev, int64(off), bgds[i]); err != nil {
+ return nil, err
+ }
+ }
+ return bgds, nil
+}
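
readFromDisk relies on gVisor's binary package; with only the standard library, the same "decode one fixed-size struct at an absolute offset" helper can be sketched as follows, assuming the target struct has only fixed-size fields:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// readStruct is a stdlib analogue of readFromDisk: it decodes one fixed-size
// little-endian struct from an absolute offset in dev.
func readStruct(dev io.ReaderAt, off int64, v interface{}) error {
	sr := io.NewSectionReader(dev, off, int64(binary.Size(v)))
	return binary.Read(sr, binary.LittleEndian, v)
}

// header is a toy stand-in for an on-disk structure like ExtentHeader.
type header struct {
	Magic      uint16
	NumEntries uint16
}

func main() {
	disk := bytes.NewReader([]byte{0x0a, 0xf3, 0x02, 0x00, 0xde, 0xad})
	var h header
	if err := readStruct(disk, 0, &h); err != nil {
		panic(err)
	}
	fmt.Printf("%#x %d\n", h.Magic, h.NumEntries) // 0xf30a 2
}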
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
new file mode 100644
index 000000000..41567967d
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fuse",
+ srcs = [
+ "dev.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
new file mode 100644
index 000000000..f6a67d005
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const fuseDevMinor = 229
+
+// fuseDevice implements vfs.Device for /dev/fuse.
+type fuseDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ var fd DeviceFD
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse.
+type DeviceFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+	// TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue
+	// and dequeue requests, control synchronization, and establish communication
+	// between the FUSE kernel module and the /dev/fuse character device.
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DeviceFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Register registers the FUSE device with vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "misc",
+ }); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// CreateDevtmpfsFile creates a device special file in devtmpfs.
+func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
+ if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+ return err
+ }
+
+ return nil
+}
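
For orientation, the two hooks above would be composed during sentry boot roughly as follows. This is a hedged sketch: the boot package and the ctx, vfsObj, and dev values are assumed to come from kernel initialization elsewhere and are not part of this patch.

package boot

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
	"gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// setUpFUSE shows how the two hooks compose: register the character device
// with the VFS, then surface it as /dev/fuse in devtmpfs.
func setUpFUSE(ctx context.Context, vfsObj *vfs.VirtualFilesystem, dev *devtmpfs.Accessor) error {
	if err := fuse.Register(vfsObj); err != nil {
		return fmt.Errorf("registering /dev/fuse: %v", err)
	}
	return fuse.CreateDevtmpfsFile(ctx, dev)
}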
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
new file mode 100644
index 000000000..4a800dcf9
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -0,0 +1,89 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "gofer",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "gofer",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "gofer",
+ srcs = [
+ "dentry_list.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "gofer.go",
+ "handle.go",
+ "host_named_pipe.go",
+ "p9file.go",
+ "regular_file.go",
+ "socket.go",
+ "special_file.go",
+ "symlink.go",
+ "time.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/hostfd",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/unet",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "gofer_test",
+ srcs = ["gofer_test.go"],
+ library = ":gofer",
+ deps = [
+ "//pkg/p9",
+ "//pkg/sentry/contexttest",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
new file mode 100644
index 000000000..8c7c8e1b3
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -0,0 +1,308 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isDir() bool {
+ return d.fileType() == linux.S_IFDIR
+}
+
+// Preconditions: filesystem.renameMu must be locked. d.dirMu must be locked.
+// d.isDir(). child must be a newly-created dentry that has never had a parent.
+func (d *dentry) cacheNewChildLocked(child *dentry, name string) {
+ d.IncRef() // reference held by child on its parent
+ child.parent = d
+ child.name = name
+ if d.children == nil {
+ d.children = make(map[string]*dentry)
+ }
+ d.children[name] = child
+}
+
+// Preconditions: d.dirMu must be locked. d.isDir().
+func (d *dentry) cacheNegativeLookupLocked(name string) {
+ // Don't cache negative lookups if InteropModeShared is in effect (since
+ // this makes remote lookup unavoidable), or if d.isSynthetic() (in which
+ // case the only files in the directory are those for which a dentry exists
+ // in d.children). Instead, just delete any previously-cached dentry.
+ if d.fs.opts.interop == InteropModeShared || d.isSynthetic() {
+ delete(d.children, name)
+ return
+ }
+ if d.children == nil {
+ d.children = make(map[string]*dentry)
+ }
+ d.children[name] = nil
+}
+
+type createSyntheticOpts struct {
+ name string
+ mode linux.FileMode
+ kuid auth.KUID
+ kgid auth.KGID
+
+ // The endpoint for a synthetic socket. endpoint should be nil if the file
+ // being created is not a socket.
+ endpoint transport.BoundEndpoint
+
+ // pipe should be nil if the file being created is not a pipe.
+ pipe *pipe.VFSPipe
+}
+
+// createSyntheticChildLocked creates a synthetic file with the given name
+// in d.
+//
+// Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain
+// a child with the given name.
+func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
+ d2 := &dentry{
+ refs: 1, // held by d
+ fs: d.fs,
+ ino: d.fs.nextSyntheticIno(),
+ mode: uint32(opts.mode),
+ uid: uint32(opts.kuid),
+ gid: uint32(opts.kgid),
+ blockSize: usermem.PageSize, // arbitrary
+ handle: handle{
+ fd: -1,
+ },
+ nlink: uint32(2),
+ }
+ switch opts.mode.FileType() {
+ case linux.S_IFDIR:
+ // Nothing else needs to be done.
+ case linux.S_IFSOCK:
+ d2.endpoint = opts.endpoint
+ case linux.S_IFIFO:
+ d2.pipe = opts.pipe
+ default:
+ panic(fmt.Sprintf("failed to create synthetic file of unrecognized type: %v", opts.mode.FileType()))
+ }
+ d2.pf.dentry = d2
+ d2.vfsd.Init(d2)
+
+ d.cacheNewChildLocked(d2, opts.name)
+ d.syntheticChildren++
+}
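+
+// For example, MknodAt falls back to this when the gofer does not allow
+// creating a remote pipe, creating one kept entirely in memory:
+//
+//	parent.createSyntheticChildLocked(&createSyntheticOpts{
+//		name: name,
+//		mode: opts.Mode,
+//		kuid: creds.EffectiveKUID,
+//		kgid: creds.EffectiveKGID,
+//		pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+//	})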
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ mu sync.Mutex
+ off int64
+ dirents []vfs.Dirent
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ d := fd.dentry()
+ if fd.dirents == nil {
+ ds, err := d.getDirents(ctx)
+ if err != nil {
+ return err
+ }
+ fd.dirents = ds
+ }
+
+ d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ if d.cachedMetadataAuthoritative() {
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+
+ for fd.off < int64(len(fd.dirents)) {
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
+// Preconditions: d.isDir(). There exists at least one directoryFD representing d.
+func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
+ // NOTE(b/135560623): 9P2000.L's readdir does not specify behavior in the
+ // presence of concurrent mutation of an iterated directory, so
+ // implementations may duplicate or omit entries in this case, which
+ // violates POSIX semantics. Thus we read all directory entries while
+ // holding d.dirMu to exclude directory mutations. (Note that it is
+ // impossible for the client to exclude concurrent mutation from other
+ // remote filesystem users. Since there is no way to detect if the server
+ // has incorrectly omitted directory entries, we simply assume that the
+ // server is well-behaved under InteropModeShared.) This is inconsistent
+ // with Linux (which appears to assume that directory fids have the correct
+ // semantics, and translates struct file_operations::readdir calls directly
+ // to readdir RPCs), but is consistent with VFS1.
+
+ // filesystem.renameMu is needed for d.parent, and must be locked before
+ // dentry.dirMu.
+ d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ if d.dirents != nil {
+ return d.dirents, nil
+ }
+
+ // It's not clear if 9P2000.L's readdir is expected to return "." and "..",
+ // so we generate them here.
+ parent := genericParentOrSelf(d)
+ dirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: uint64(d.ino),
+ NextOff: 1,
+ },
+ {
+ Name: "..",
+ Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
+ Ino: uint64(parent.ino),
+ NextOff: 2,
+ },
+ }
+ var realChildren map[string]struct{}
+ if !d.isSynthetic() {
+ if d.syntheticChildren != 0 && d.fs.opts.interop == InteropModeShared {
+ // Record the set of children d actually has so that we don't emit
+ // duplicate entries for synthetic children.
+ realChildren = make(map[string]struct{})
+ }
+ off := uint64(0)
+ const count = 64 * 1024 // for consistency with the vfs1 client
+ d.handleMu.RLock()
+ if !d.handleReadable {
+ // This should not be possible because a readable handle should
+ // have been opened when the calling directoryFD was opened.
+ d.handleMu.RUnlock()
+ panic("gofer.dentry.getDirents called without a readable handle")
+ }
+ for {
+ p9ds, err := d.handle.file.readdir(ctx, off, count)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
+ }
+ if len(p9ds) == 0 {
+ d.handleMu.RUnlock()
+ break
+ }
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: uint64(inoFromPath(p9d.QID.Path)),
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[p9d.Name] = struct{}{}
+ }
+ }
+ off = p9ds[len(p9ds)-1].Offset
+ }
+ }
+ // Emit entries for synthetic children.
+ if d.syntheticChildren != 0 {
+ for _, child := range d.children {
+ if child == nil || !child.isSynthetic() {
+ continue
+ }
+ if _, ok := realChildren[child.name]; ok {
+ continue
+ }
+ dirents = append(dirents, vfs.Dirent{
+ Name: child.name,
+ Type: uint8(atomic.LoadUint32(&child.mode) >> 12),
+ Ino: uint64(child.ino),
+ NextOff: int64(len(dirents) + 1),
+ })
+ }
+ }
+ // Cache dirents for future directoryFDs if permitted.
+ if d.cachedMetadataAuthoritative() {
+ d.dirents = dirents
+ }
+ return dirents, nil
+}
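+
+// Note that each generated dirent's NextOff is its index in dirents plus one,
+// so directoryFD.off doubles as an index into the cached dirents slice.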
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset == 0 {
+ // Ensure that the next call to fd.IterDirents() calls
+ // fd.dentry().getDirents().
+ fd.dirents = nil
+ }
+ fd.off = offset
+ return fd.off, nil
+ case linux.SEEK_CUR:
+ offset += fd.off
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ // Don't clear fd.dirents in this case, even if offset == 0.
+ fd.off = offset
+ return fd.off, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
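+
+// Note that seeking to offset 0 with SEEK_SET also discards cached dirents,
+// so the next IterDirents re-reads the directory; this gives rewinddir(3) its
+// expected refresh semantics.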
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *directoryFD) Sync(ctx context.Context) error {
+ return fd.dentry().handle.sync(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
new file mode 100644
index 000000000..cd5f5049e
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -0,0 +1,1504 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // Snapshot current syncable dentries and special files.
+ fs.syncMu.Lock()
+ ds := make([]*dentry, 0, len(fs.syncableDentries))
+ for d := range fs.syncableDentries {
+ d.IncRef()
+ ds = append(ds, d)
+ }
+ sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs))
+ for sffd := range fs.specialFileFDs {
+ sffd.vfsfd.IncRef()
+ sffds = append(sffds, sffd)
+ }
+ fs.syncMu.Unlock()
+
+ // Return the first error we encounter, but sync everything we can
+ // regardless.
+ var retErr error
+
+ // Sync regular files.
+ for _, d := range ds {
+ err := d.syncSharedHandle(ctx)
+ d.DecRef()
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ // Sync special files, which may be writable but do not use dentry shared
+ // handles (so they won't be synced by the above).
+ for _, sffd := range sffds {
+ err := sffd.Sync(ctx)
+ sffd.vfsfd.DecRef()
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ return retErr
+}
+
+// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
+// encoding of strings, which uses 2 bytes for the length prefix.
+const maxFilenameLen = (1 << 16) - 1
+
+// dentrySlicePool is a pool of *[]*dentry used to store dentries for which
+// dentry.checkCachingLocked() must be called. The pool holds pointers to
+// slices because Go lacks generics: sync.Pool operates on interface{}, so
+// every call to (what should be) sync.Pool<[]*dentry>.Put() would allocate a
+// copy of the slice header on the heap.
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
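+
+// The usual pattern for the two helpers above, as used by the
+// vfs.FilesystemImpl methods below, is:
+//
+//	var ds *[]*dentry
+//	fs.renameMu.RLock()
+//	defer fs.renameMuRUnlockAndCheckCaching(&ds)
+//	d, err := fs.resolveLocked(ctx, rp, &ds)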
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// Dentries which may become cached as a result of the traversal are appended
+// to *ds.
+//
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// !rp.Done(). If !d.cachedMetadataAuthoritative(), then d's cached metadata
+// must be up to date.
+//
+// Postconditions: The returned dentry's cached metadata is up to date.
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ // We must assume that d.parent is correct, because if d has been moved
+ // elsewhere in the remote filesystem so that its parent has changed,
+ // we have no way of determining its new parent's location in the
+ // filesystem.
+ //
+ // Call rp.CheckMount() before updating d.parent's metadata, since if
+ // we traverse to another mount then d.parent's metadata is irrelevant.
+ if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ if d != d.parent && !d.cachedMetadataAuthoritative() {
+ _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return nil, err
+ }
+ d.parent.updateFromP9Attrs(attrMask, &attr)
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds)
+ if err != nil {
+ return nil, err
+ }
+ if child == nil {
+ return nil, syserror.ENOENT
+ }
+ if err := rp.CheckMount(&child.vfsd); err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// getChildLocked returns a dentry representing the child of parent with the
+// given name. If no such child exists, getChildLocked returns (nil, nil).
+//
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+// parent.isDir(). name is not "." or "..".
+//
+// Postconditions: If getChildLocked returns a non-nil dentry, its cached
+// metadata is up to date.
+func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+ child, ok := parent.children[name]
+ if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() {
+ // Whether child is nil or not, it is cached information that is
+ // assumed to be correct.
+ return child, nil
+ }
+ // We either don't have cached information or need to verify that it's
+ // still correct, either of which requires a remote lookup. Check if this
+ // name is valid before performing the lookup.
+ return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds)
+}
+
+// Preconditions: As for getChildLocked. !parent.isSynthetic().
+func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
+ qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
+ if err != nil && err != syserror.ENOENT {
+ return nil, err
+ }
+ if child != nil {
+ if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ // The file at this path hasn't changed. Just update cached metadata.
+ file.close(ctx)
+ child.updateFromP9Attrs(attrMask, &attr)
+ return child, nil
+ }
+ if file.isNil() && child.isSynthetic() {
+ // We have a synthetic file, and no remote file has arisen to
+ // replace it.
+ return child, nil
+ }
+ // The file at this path has changed or no longer exists. Mark the
+ // dentry invalidated, and re-evaluate its caching status (i.e. if it
+ // has 0 references, drop it). Wait to update parent.children until we
+ // know what to replace the existing dentry with (i.e. one of the
+ // returns below), to avoid a redundant map access.
+ vfsObj.InvalidateDentry(&child.vfsd)
+ if child.isSynthetic() {
+ // Normally we don't mark invalidated dentries as deleted since
+ // they may still exist (but at a different path), and also for
+ // consistency with Linux. However, synthetic files are guaranteed
+ // to become unreachable if their dentries are invalidated, so
+ // treat their invalidation as deletion.
+ child.setDeleted()
+ parent.syntheticChildren--
+ child.decRefLocked()
+ parent.dirents = nil
+ }
+ *ds = appendDentry(*ds, child)
+ }
+ if file.isNil() {
+ // No file exists at this path now. Cache the negative lookup if
+ // allowed.
+ parent.cacheNegativeLookupLocked(name)
+ return nil, nil
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ if err != nil {
+ file.close(ctx)
+ delete(parent.children, name)
+ return nil, err
+ }
+ parent.cacheNewChildLocked(child, name)
+ // For now, child has 0 references, so our caller should call
+ // child.checkCachingLocked().
+ *ds = appendDentry(*ds, child)
+ return child, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done(). If
+// !d.cachedMetadataAuthoritative(), then d's cached metadata must be up to
+// date.
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ for !rp.Final() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// Preconditions: fs.renameMu must be locked.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ if !d.cachedMetadataAuthoritative() {
+ // Get updated metadata for rp.Start() as required by fs.stepLocked().
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ for !rp.Done() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// createInRemoteDir (if the parent directory is a real remote directory) or
+// createInSyntheticDir (if the parent directory is synthetic) to do so.
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ if parent.isDeleted() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ if parent.isSynthetic() {
+ if child := parent.children[name]; child != nil {
+ return syserror.EEXIST
+ }
+ if createInSyntheticDir == nil {
+ return syserror.EPERM
+ }
+ if err := createInSyntheticDir(parent, name); err != nil {
+ return err
+ }
+ parent.touchCMtime()
+ parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
+ }
+ if fs.opts.interop == InteropModeShared {
+ if child := parent.children[name]; child != nil && child.isSynthetic() {
+ return syserror.EEXIST
+ }
+ // The existence of a non-synthetic dentry at name would be inconclusive
+ // because the file it represents may have been deleted from the remote
+ // filesystem, so we would need to make an RPC to revalidate the dentry.
+ // Just attempt the file creation RPC instead. If a file does exist, the
+ // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a
+ // stale dentry exists, the dentry will fail revalidation next time it's
+ // used.
+ if err := createInRemoteDir(parent, name); err != nil {
+ return err
+ }
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
+ }
+ if child := parent.children[name]; child != nil {
+ return syserror.EEXIST
+ }
+ // No cached dentry exists; however, there might still be an existing file
+ // at name. As above, we attempt the file creation RPC anyway.
+ if err := createInRemoteDir(parent, name); err != nil {
+ return err
+ }
+ if child, ok := parent.children[name]; ok && child == nil {
+ // Delete the now-stale negative dentry.
+ delete(parent.children, name)
+ }
+ parent.touchCMtime()
+ parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
+}
+
+// Preconditions: !rp.Done().
+func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+
+ name := rp.Component()
+ if dir {
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ } else {
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ child, ok := parent.children[name]
+ if ok && child == nil {
+ return syserror.ENOENT
+ }
+
+ sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
+ if sticky {
+ if !ok {
+ // If the sticky bit is set, we need to retrieve the child to determine
+ // whether removing it is allowed.
+ child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err != nil {
+ return err
+ }
+ } else if child != nil && !child.cachedMetadataAuthoritative() {
+ // Make sure the dentry representing the file at name is up to date
+ // before examining its metadata.
+ child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
+ if err != nil {
+ return err
+ }
+ }
+ if err := parent.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ }
+
+ // If a child dentry exists, prepare to delete it. This should fail if it is
+ // a mount point. We detect mount points by speculatively calling
+ // PrepareDeleteDentry, which fails if child is a mount point. However, we
+ // may need to revalidate the file in this case to make sure that it has not
+ // been deleted or replaced on the remote fs, in which case the mount point
+ // will have disappeared. If calling PrepareDeleteDentry fails again on the
+ // up-to-date dentry, we can be sure that it is a mount point.
+ //
+ // Also note that if child is nil, then it can't be a mount point.
+ if child != nil {
+ // Hold child.dirMu so we can check child.children and
+ // child.syntheticChildren. We don't access these fields until a bit later,
+ // but locking child.dirMu after calling vfs.PrepareDeleteDentry() would
+ // create an inconsistent lock ordering between dentry.dirMu and
+ // vfs.Dentry.mu (in the VFS lock order, it would make dentry.dirMu both "a
+		// FilesystemImpl lock" and "a lock acquired by a FilesystemImpl
+		// between PrepareDeleteDentry and CommitDeleteDentry"). To avoid
+		// this, lock child.dirMu before calling PrepareDeleteDentry.
+ child.dirMu.Lock()
+ defer child.dirMu.Unlock()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ // We can skip revalidation in several cases:
+ // - We are not in InteropModeShared
+ // - The parent directory is synthetic, in which case the child must also
+ // be synthetic
+ // - We already updated the child during the sticky bit check above
+ if parent.cachedMetadataAuthoritative() || sticky {
+ return err
+ }
+ child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
+ if err != nil {
+ return err
+ }
+ if child != nil {
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ }
+ }
+ }
+ flags := uint32(0)
+ // If a dentry exists, use it for best-effort checks on its deletability.
+ if dir {
+ if child != nil {
+ // child must be an empty directory.
+ if child.syntheticChildren != 0 {
+ // This is definitely not an empty directory, irrespective of
+ // fs.opts.interop.
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTEMPTY
+ }
+ // If InteropModeShared is in effect and the first call to
+ // PrepareDeleteDentry above succeeded, then child wasn't
+ // revalidated (so we can't expect its file type to be correct) and
+ // individually revalidating its children (to confirm that they
+ // still exist) would be a waste of time.
+ if child.cachedMetadataAuthoritative() {
+ if !child.isDir() {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTDIR
+ }
+ for _, grandchild := range child.children {
+ if grandchild != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTEMPTY
+ }
+ }
+ }
+ }
+ flags = linux.AT_REMOVEDIR
+ } else {
+ // child must be a non-directory file.
+ if child != nil && child.isDir() {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return syserror.ENOTDIR
+ }
+ }
+ if parent.isSynthetic() {
+ if child == nil {
+ return syserror.ENOENT
+ }
+ } else if child == nil || !child.isSynthetic() {
+ err = parent.file.unlinkAt(ctx, name, flags)
+ if err != nil {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+ }
+
+ // Generate inotify events for rmdir or unlink.
+ if dir {
+ parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ } else {
+ var cw *vfs.Watches
+ if child != nil {
+ cw = &child.watches
+ }
+ vfs.InotifyRemoveChild(cw, &parent.watches, name)
+ }
+
+ if child != nil {
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ child.setDeleted()
+ if child.isSynthetic() {
+ parent.syntheticChildren--
+ child.decRefLocked()
+ }
+ ds = appendDentry(ds, child)
+ }
+ parent.cacheNegativeLookupLocked(name)
+ if parent.cachedMetadataAuthoritative() {
+ parent.dirents = nil
+ parent.touchCMtime()
+ if dir {
+ parent.decLinks()
+ }
+ }
+ return nil
+}
+
+// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls
+// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
+func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkCachingLocked()
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
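+// renameMuUnlockAndCheckCaching is equivalent to
+// renameMuRUnlockAndCheckCaching, but for callers that hold fs.renameMu for
+// writing.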
+func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkCachingLocked()
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
+// AccessAt implements vfs.FilesystemImpl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ // 9P2000.L supports hard links, but we don't.
+ return syserror.EPERM
+ }, nil)
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ creds := rp.Credentials()
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error {
+ if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
+ if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
+ return err
+ }
+ ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
+ }
+ if fs.opts.interop != InteropModeShared {
+ parent.incLinks()
+ }
+ return nil
+ }, func(parent *dentry, name string) error {
+ if !opts.ForSyntheticMountpoint {
+ // Can't create non-synthetic files in synthetic directories.
+ return syserror.EPERM
+ }
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
+ parent.incLinks()
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ // If the gofer does not allow creating a socket or pipe, create a
+ // synthetic one, i.e. one that is kept entirely in memory.
+ if err == syserror.EPERM {
+ switch opts.Mode.FileType() {
+ case linux.S_IFSOCK:
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ endpoint: opts.Endpoint,
+ })
+ return nil
+ case linux.S_IFIFO:
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ })
+ return nil
+ }
+ }
+ return err
+ }, nil)
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+	// Reject O_TMPFILE, which is not supported; supporting it correctly in
+	// the presence of other remote filesystem users requires support from the
+	// remote filesystem, and it isn't clear that there's any way to implement
+	// this in 9P.
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ return nil, syserror.EOPNOTSUPP
+ }
+ mayCreate := opts.Flags&linux.O_CREAT != 0
+ mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL)
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by fs.stepLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ if rp.Done() {
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Determine whether or not we need to create a file.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err == syserror.ENOENT && mayCreate {
+ if parent.isSynthetic() {
+ parent.dirMu.Unlock()
+ return nil, syserror.EPERM
+ }
+ fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts, &ds)
+ parent.dirMu.Unlock()
+ return fd, err
+ }
+ parent.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if !child.isDir() && rp.MustBeDir() {
+ return nil, syserror.ENOTDIR
+ }
+ // Open existing child or follow symlink.
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+
+ trunc := opts.Flags&linux.O_TRUNC != 0 && d.fileType() == linux.S_IFREG
+ if trunc {
+		// Lock metadataMu *while* we open a regular file with O_TRUNC because
+		// open(2) will change the file size on the server.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ }
+
+ var vfd *vfs.FileDescription
+ var err error
+ mnt := rp.Mount()
+ switch d.fileType() {
+ case linux.S_IFREG:
+ if !d.fs.opts.regularFilesUseSpecialFileFD {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil {
+ return nil, err
+ }
+ fd := &regularFileFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ vfd = &fd.vfsfd
+ }
+ case linux.S_IFDIR:
+ // Can't open directories with O_CREAT.
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EISDIR
+ }
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if !d.isSynthetic() {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
+ return nil, err
+ }
+ }
+ fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case linux.S_IFLNK:
+ // Can't open symlinks without O_PATH (which is unimplemented).
+ return nil, syserror.ELOOP
+ case linux.S_IFSOCK:
+ if d.isSynthetic() {
+ return nil, syserror.ENXIO
+ }
+ if d.fs.iopts.OpenSocketsByConnecting {
+ return d.connectSocketLocked(ctx, opts)
+ }
+ case linux.S_IFIFO:
+ if d.isSynthetic() {
+ return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks)
+ }
+ }
+
+ if vfd == nil {
+ if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil {
+ return nil, err
+ }
+ }
+
+ if trunc {
+		// If no errors occurred so far, update the file size in memory. This
+ // step is required even if !d.cachedMetadataAuthoritative() because
+ // d.mappings has to be updated.
+ // d.metadataMu has already been acquired if trunc == true.
+ d.updateFileSizeLocked(0)
+
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
+ }
+ }
+ return vfd, err
+}
+
+func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ fdObj, err := d.file.connect(ctx, p9.AnonymousSocket)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fdObj.FD(), &host.NewFDOptions{
+ HaveFlags: true,
+ Flags: opts.Flags,
+ })
+ if err != nil {
+ fdObj.Close()
+ return nil, err
+ }
+ fdObj.Release()
+ return fd, nil
+}
+
+func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ // We assume that the server silently inserts O_NONBLOCK in the open flags
+ // for all named pipes (because all existing gofers do this).
+ //
+ // NOTE(b/133875563): This makes named pipe opens racy, because the
+ // mechanisms for translating nonblocking to blocking opens can only detect
+ // the instantaneous presence of a peer holding the other end of the pipe
+ // open, not whether the pipe was *previously* opened by a peer that has
+ // since closed its end.
+ isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0
+retry:
+ h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ if err != nil {
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO {
+ // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
+ // with ENXIO if opening the same named pipe with O_WRONLY would
+ // block because there are no readers of the pipe.
+ if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
+ return nil, err
+ }
+ goto retry
+ }
+ return nil, err
+ }
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayRead && h.fd >= 0 {
+ if err := blockUntilNonblockingPipeHasWriter(ctx, h.fd); err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ }
+ fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags)
+ if err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
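+
+// Together, the retry loop and blockUntilNonblockingPipeHasWriter above
+// emulate a blocking open of a named pipe on top of the gofer's nonblocking
+// opens: ENXIO from a write-only open is translated into a sleep and retry,
+// and a successful read-only open then waits for a writer to appear.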
+
+// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
+// !d.isSynthetic().
+func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if d.isDeleted() {
+ return nil, syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer mnt.EndWrite()
+
+ // 9P2000.L's lcreate takes a fid representing the parent directory, and
+ // converts it into an open fid representing the created file, so we need
+ // to duplicate the directory fid first.
+ _, dirfile, err := d.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ creds := rp.Credentials()
+ name := rp.Component()
+ // Filter file creation flags and O_LARGEFILE out; the create RPC already
+ // has the semantics of O_CREAT|O_EXCL, while some servers will choke on
+ // O_LARGEFILE.
+ createFlags := p9.OpenFlags(opts.Flags &^ (vfs.FileCreationFlags | linux.O_LARGEFILE))
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ if err != nil {
+ dirfile.close(ctx)
+ return nil, err
+ }
+	// Then we need to walk to the file we just created to get a non-open fid
+	// representing it, and to get its metadata. This must use d.file since, as
+	// explained above, dirfile was invalidated by dirfile.create().
+ _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+
+ // Construct the new dentry.
+ child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
+ if err != nil {
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+ *ds = appendDentry(*ds, child)
+ // Incorporate the fid that was opened by lcreate.
+ useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
+ if useRegularFileFD {
+ child.handleMu.Lock()
+ child.handle.file = openFile
+ if fdobj != nil {
+ child.handle.fd = int32(fdobj.Release())
+ }
+ child.handleReadable = vfs.MayReadFileWithOpenFlags(opts.Flags)
+ child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags)
+ child.handleMu.Unlock()
+ }
+ // Insert the dentry into the tree.
+ d.cacheNewChildLocked(child, name)
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtime()
+ d.dirents = nil
+ }
+
+ // Finally, construct a file description representing the created file.
+ var childVFSFD *vfs.FileDescription
+ if useRegularFileFD {
+ fd := &regularFileFD{}
+ fd.LockFD.Init(&child.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ childVFSFD = &fd.vfsfd
+ } else {
+ h := handle{
+ file: openFile,
+ fd: -1,
+ }
+ if fdobj != nil {
+ h.fd = int32(fdobj.Release())
+ }
+ fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags)
+ if err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ childVFSFD = &fd.vfsfd
+ }
+ d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
+ return childVFSFD, nil
+}
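+
+// The creation sequence above, in brief: walk() clones the directory fid,
+// create() converts the clone into an open fid for the new file, and
+// walkGetAttrOne() on the original fid yields the non-open fid and attributes
+// used to construct the dentry.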
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ if !d.isSymlink() {
+ return "", syserror.EINVAL
+ }
+ return d.readlink(ctx, rp.Mount())
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ // Requires 9P support.
+ return syserror.EINVAL
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.Lock()
+ defer fs.renameMuUnlockAndCheckCaching(&ds)
+ newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ oldParent := oldParentVD.Dentry().Impl().(*dentry)
+ if !oldParent.cachedMetadataAuthoritative() {
+ if err := oldParent.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ creds := rp.Credentials()
+ if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ vfsObj := rp.VirtualFilesystem()
+ // We need a dentry representing the renamed file since, if it's a
+ // directory, we need to check for write permission on it.
+ oldParent.dirMu.Lock()
+ defer oldParent.dirMu.Unlock()
+ renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds)
+ if err != nil {
+ return err
+ }
+ if renamed == nil {
+ return syserror.ENOENT
+ }
+ if err := oldParent.mayDelete(creds, renamed); err != nil {
+ return err
+ }
+ if renamed.isDir() {
+ if renamed == newParent || genericIsAncestorDentry(renamed, newParent) {
+ return syserror.EINVAL
+ }
+ if oldParent != newParent {
+ if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if oldParent != newParent {
+ if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ newParent.dirMu.Lock()
+ defer newParent.dirMu.Unlock()
+ }
+ if newParent.isDeleted() {
+ return syserror.ENOENT
+ }
+ replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds)
+ if err != nil {
+ return err
+ }
+ var replacedVFSD *vfs.Dentry
+ if replaced != nil {
+ replacedVFSD = &replaced.vfsd
+ if replaced.isDir() {
+ if !renamed.isDir() {
+ return syserror.EISDIR
+ }
+ } else {
+ if rp.MustBeDir() || renamed.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ }
+
+ if oldParent == newParent && oldName == newName {
+ return nil
+ }
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
+ return err
+ }
+
+ // Update the remote filesystem.
+ if !renamed.isSynthetic() {
+ if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ } else if replaced != nil && !replaced.isSynthetic() {
+ // We are replacing an existing real file with a synthetic one, so we
+ // need to unlink the former.
+ flags := uint32(0)
+ if replaced.isDir() {
+ flags = linux.AT_REMOVEDIR
+ }
+ if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ }
+
+ // Update the dentry tree.
+ vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD)
+ if replaced != nil {
+ replaced.setDeleted()
+ if replaced.isSynthetic() {
+ newParent.syntheticChildren--
+ replaced.decRefLocked()
+ }
+ ds = appendDentry(ds, replaced)
+ }
+ oldParent.cacheNegativeLookupLocked(oldName)
+ // We don't use newParent.cacheNewChildLocked() since we don't want to mess
+ // with reference counts and queue oldParent for checkCachingLocked if the
+ // parent isn't actually changing.
+ if oldParent != newParent {
+ ds = appendDentry(ds, oldParent)
+ newParent.IncRef()
+ if renamed.isSynthetic() {
+ oldParent.syntheticChildren--
+ newParent.syntheticChildren++
+ }
+ }
+ renamed.parent = newParent
+ renamed.name = newName
+ if newParent.children == nil {
+ newParent.children = make(map[string]*dentry)
+ }
+ newParent.children[newName] = renamed
+
+ // Update metadata.
+ if renamed.cachedMetadataAuthoritative() {
+ renamed.touchCtime()
+ }
+ if oldParent.cachedMetadataAuthoritative() {
+ oldParent.dirents = nil
+ oldParent.touchCMtime()
+ if renamed.isDir() {
+ oldParent.decLinks()
+ }
+ }
+ if newParent.cachedMetadataAuthoritative() {
+ newParent.dirents = nil
+ newParent.touchCMtime()
+ if renamed.isDir() && (replaced == nil || !replaced.isDir()) {
+ // Increase the link count if we did not replace another directory.
+ newParent.incLinks()
+ }
+ }
+ vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir())
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return fs.unlinkAt(ctx, rp, true /* dir */)
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ if err := d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ // Since walking updates metadata for all traversed dentries under
+ // InteropModeShared, including the returned one, we can return cached
+ // metadata here regardless of fs.opts.interop.
+ var stat linux.Statx
+ d.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ // If d is synthetic, invoke statfs on the first ancestor of d that isn't.
+ for d.isSynthetic() {
+ d = d.parent
+ }
+ fsstat, err := d.file.statFS(ctx)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ nameLen := uint64(fsstat.NameLength)
+ if nameLen > maxFilenameLen {
+ nameLen = maxFilenameLen
+ }
+ return linux.Statfs{
+		// This is primarily useful for distinguishing a gofer file
+		// system in tests; rather than defining something completely
+		// random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ BlockSize: int64(fsstat.BlockSize),
+ Blocks: fsstat.Blocks,
+ BlocksFree: fsstat.BlocksFree,
+ BlocksAvailable: fsstat.BlocksAvailable,
+ Files: fsstat.Files,
+ FilesFree: fsstat.FilesFree,
+ NameLength: nameLen,
+ }, nil
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ return err
+ }, nil)
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return fs.unlinkAt(ctx, rp, false /* dir */)
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if d.isSocket() {
+ if !d.isSynthetic() {
+ d.IncRef()
+ return &endpoint{
+ dentry: d,
+ file: d.file.file,
+ path: opts.Addr,
+ }, nil
+ }
+ return d.endpoint, nil
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ return d.listxattr(ctx, rp.Credentials(), size)
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ return d.getxattr(ctx, rp.Credentials(), &opts)
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ if err := d.removexattr(ctx, rp.Credentials(), name); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
+
+func (fs *filesystem) nextSyntheticIno() inodeNumber {
+ return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
new file mode 100644
index 000000000..2b83094cd
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -0,0 +1,1550 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gofer provides a filesystem implementation that is backed by a 9p
+// server, interchangeably referred to as "gofers" throughout this package.
+//
+// Lock order:
+// regularFileFD/directoryFD.mu
+// filesystem.renameMu
+// dentry.dirMu
+// filesystem.syncMu
+// dentry.metadataMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.handleMu
+// dentry.dataMu
+//
+// Locking dentry.dirMu in multiple dentries requires that either ancestor
+// dentries are locked before descendant dentries, or that filesystem.renameMu
+// is locked for writing.
+package gofer
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Name is the default filesystem name.
+const Name = "9p"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // mfp is used to allocate memory that caches regular file contents. mfp is
+ // immutable.
+ mfp pgalloc.MemoryFileProvider
+
+ // Immutable options.
+ opts filesystemOptions
+ iopts InternalFilesystemOptions
+
+ // client is the client used by this filesystem. client is immutable.
+ client *p9.Client
+
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock ktime.Clock
+
+ // devMinor is the filesystem's minor device number. devMinor is immutable.
+ devMinor uint32
+
+ // renameMu serves two purposes:
+ //
+ // - It synchronizes path resolution with renaming initiated by this
+ // client.
+ //
+ // - It is held by path resolution to ensure that reachable dentries remain
+ // valid. A dentry is reachable by path resolution if it has a non-zero
+ // reference count (such that it is usable as vfs.ResolvingPath.Start() or
+ // is reachable from its children), or if it is a child dentry (such that
+ // it is reachable from its parent).
+ renameMu sync.RWMutex
+
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by renameMu.
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // syncableDentries contains all dentries in this filesystem for which
+ // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
+ // These fields are protected by syncMu.
+ syncMu sync.Mutex
+ syncableDentries map[*dentry]struct{}
+ specialFileFDs map[*specialFileFD]struct{}
+
+	// syntheticSeq stores a counter used to generate unique inodeNumbers for
+	// synthetic dentries.
+ syntheticSeq uint64
+}
+
+// inodeNumber represents the inode number reported in Dirent.Ino. For
+// regular dentries, it comes from QID.Path from the 9P server. Synthetic
+// dentries have their inodeNumber generated sequentially, with the MSB
+// reserved to prevent conflicts with regular dentries.
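+//
+// For example (illustrative values):
+//
+//	inoFromPath(42)         == 42          // regular: QID.Path used directly
+//	fs.nextSyntheticIno()   == 1<<63 | 1   // first synthetic inode number
+//	inoFromPath(1<<63 | 42) == 42          // MSB dropped; a warning is logged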
+type inodeNumber uint64
+
+// Reserve MSB for synthetic mounts.
+const syntheticInoMask = uint64(1) << 63
+
+func inoFromPath(path uint64) inodeNumber {
+ if path&syntheticInoMask != 0 {
+ log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
+ }
+ return inodeNumber(path &^ syntheticInoMask)
+}
+
+type filesystemOptions struct {
+ // "Standard" 9P options.
+ fd int
+ aname string
+ interop InteropMode // derived from the "cache" mount option
+ dfltuid auth.KUID
+ dfltgid auth.KGID
+ msize uint32
+ version string
+
+ // maxCachedDentries is the maximum number of dentries with 0 references
+ // retained by the client.
+ maxCachedDentries uint64
+
+ // If forcePageCache is true, host FDs may not be used for application
+ // memory mappings even if available; instead, the client must perform its
+ // own caching of regular file pages. This is primarily useful for testing.
+ forcePageCache bool
+
+ // If limitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host FD mappings returned by dentry.(memmap.Mappable).Translate(). This
+ // makes memory accounting behavior more consistent between cases where
+ // host FDs are / are not available, but may increase the frequency of
+ // sentry-handled page faults on files for which a host FD is available.
+ limitHostFDTranslation bool
+
+ // If overlayfsStaleRead is true, O_RDONLY host FDs provided by the remote
+ // filesystem may not be coherent with writable host FDs opened later, so
+ // all uses of the former must be replaced by uses of the latter. This is
+ // usually only the case when the remote filesystem is a Linux overlayfs
+ // mount. (Prior to Linux 4.18, patch series centered on commit
+ // d1d04ef8572b "ovl: stack file ops", both I/O and memory mappings were
+ // incoherent between pre-copy-up and post-copy-up FDs; after that patch
+ // series, only memory mappings are incoherent.)
+ overlayfsStaleRead bool
+
+ // If regularFilesUseSpecialFileFD is true, application FDs representing
+ // regular files will use distinct file handles for each FD, in the same
+ // way that application FDs representing "special files" such as sockets
+ // do. Note that this disables client caching and mmap for regular files.
+ regularFilesUseSpecialFileFD bool
+}
+
+// InteropMode controls the client's interaction with other remote filesystem
+// users.
+type InteropMode uint32
+
+const (
+ // InteropModeExclusive is appropriate when the filesystem client is the
+ // only user of the remote filesystem.
+ //
+ // - The client may cache arbitrary filesystem state (file data, metadata,
+ // filesystem structure, etc.).
+ //
+ // - Client changes to filesystem state may be sent to the remote
+ // filesystem asynchronously, except when server permission checks are
+ // necessary.
+ //
+ // - File timestamps are based on client clocks. This ensures that users of
+ // the client observe timestamps that are coherent with their own clocks
+ // and consistent with Linux's semantics. However, since it is not always
+ // possible for clients to set arbitrary atimes and mtimes, and never
+ // possible for clients to set arbitrary ctimes, file timestamp changes are
+ // stored in the client only and never sent to the remote filesystem.
+ InteropModeExclusive InteropMode = iota
+
+ // InteropModeWritethrough is appropriate when there are read-only users of
+ // the remote filesystem that expect to observe changes made by the
+ // filesystem client.
+ //
+ // - The client may cache arbitrary filesystem state.
+ //
+ // - Client changes to filesystem state must be sent to the remote
+ // filesystem synchronously.
+ //
+ // - File timestamps are based on client clocks. As a corollary, access
+ // timestamp changes from other remote filesystem users will not be visible
+ // to the client.
+ InteropModeWritethrough
+
+ // InteropModeShared is appropriate when there are users of the remote
+ // filesystem that may mutate its state other than the client.
+ //
+ // - The client must verify ("revalidate") cached filesystem state before
+ // using it.
+ //
+ // - Client changes to filesystem state must be sent to the remote
+ // filesystem synchronously.
+ //
+ // - File timestamps are based on server clocks. This is necessary to
+ // ensure that timestamp changes are synchronized between remote filesystem
+ // users.
+ //
+ // Note that the correctness of InteropModeShared depends on the server
+ // correctly implementing 9P fids (i.e. each fid immutably represents a
+ // single filesystem object), even in the presence of remote filesystem
+ // mutations from other users. If this is violated, the behavior of the
+ // client is undefined.
+ InteropModeShared
+)
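+
+// For example (illustrative mount data; parsed by GetFilesystem below):
+//
+//	"trans=fd,rfdno=7,wfdno=7,cache=fscache"              => InteropModeExclusive
+//	"trans=fd,rfdno=7,wfdno=7,cache=fscache_writethrough" => InteropModeWritethrough
+//	"trans=fd,rfdno=7,wfdno=7,cache=remote_revalidating"  => InteropModeShared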
+
+// InternalFilesystemOptions may be passed as
+// vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem.
+type InternalFilesystemOptions struct {
+ // If LeakConnection is true, do not close the connection to the server
+ // when the Filesystem is released. This is necessary for deployments in
+ // which servers can handle only a single client and report failure if that
+ // client disconnects.
+ LeakConnection bool
+
+ // If OpenSocketsByConnecting is true, silently translate attempts to open
+ // files identifying as sockets to connect RPCs.
+ OpenSocketsByConnecting bool
+}
+
+// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default
+// UIDs and GIDs used for files that do not provide a specific owner or
+// group, respectively.
+const (
+	// uint32(-2) is not a valid constant expression in Go; 4294967294 is the
+	// equivalent two's-complement value.
+ _V9FS_DEFUID = auth.KUID(4294967294)
+ _V9FS_DEFGID = auth.KGID(4294967294)
+)
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: context does not provide a pgalloc.MemoryFileProvider")
+ return nil, nil, syserror.EINVAL
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
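+
+	// An illustrative opts.Data string accepted by the parsing below (rfdno
+	// and wfdno must refer to the same file descriptor):
+	//
+	//	"trans=fd,rfdno=7,wfdno=7,aname=/,msize=1048576,dentry_cache_limit=1000"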
+ var fsopts filesystemOptions
+
+ // Check that the transport is "fd".
+ trans, ok := mopts["trans"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "trans")
+ if trans != "fd" {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Check that read and write FDs are provided and identical.
+ rfdstr, ok := mopts["rfdno"]
+ if !ok {
+		ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>'")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "rfdno")
+ rfd, err := strconv.Atoi(rfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr)
+ return nil, nil, syserror.EINVAL
+ }
+ wfdstr, ok := mopts["wfdno"]
+ if !ok {
+		ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>'")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "wfdno")
+ wfd, err := strconv.Atoi(wfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr)
+ return nil, nil, syserror.EINVAL
+ }
+ if rfd != wfd {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.fd = rfd
+
+ // Get the attach name.
+ fsopts.aname = "/"
+ if aname, ok := mopts["aname"]; ok {
+ delete(mopts, "aname")
+ fsopts.aname = aname
+ }
+
+ // Parse the cache policy. For historical reasons, this defaults to the
+ // least generally-applicable option, InteropModeExclusive.
+ fsopts.interop = InteropModeExclusive
+ if cache, ok := mopts["cache"]; ok {
+ delete(mopts, "cache")
+ switch cache {
+ case "fscache":
+ fsopts.interop = InteropModeExclusive
+ case "fscache_writethrough":
+ fsopts.interop = InteropModeWritethrough
+ case "none":
+ fsopts.regularFilesUseSpecialFileFD = true
+ fallthrough
+ case "remote_revalidating":
+ fsopts.interop = InteropModeShared
+ default:
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
+ // Parse the default UID and GID.
+ fsopts.dfltuid = _V9FS_DEFUID
+ if dfltuidstr, ok := mopts["dfltuid"]; ok {
+ delete(mopts, "dfltuid")
+ dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ // In Linux, dfltuid is interpreted as a UID and is converted to a KUID
+ // in the caller's user namespace, but goferfs isn't
+ // application-mountable.
+ fsopts.dfltuid = auth.KUID(dfltuid)
+ }
+ fsopts.dfltgid = _V9FS_DEFGID
+ if dfltgidstr, ok := mopts["dfltgid"]; ok {
+ delete(mopts, "dfltgid")
+ dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
+ if err != nil {
+			ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default GID: dfltgid=%s", dfltgidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.dfltgid = auth.KGID(dfltgid)
+ }
+
+ // Parse the 9P message size.
+ fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
+ if msizestr, ok := mopts["msize"]; ok {
+ delete(mopts, "msize")
+ msize, err := strconv.ParseUint(msizestr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.msize = uint32(msize)
+ }
+
+ // Parse the 9P protocol version.
+ fsopts.version = p9.HighestVersionString()
+ if version, ok := mopts["version"]; ok {
+ delete(mopts, "version")
+ fsopts.version = version
+ }
+
+ // Parse the dentry cache limit.
+ fsopts.maxCachedDentries = 1000
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.maxCachedDentries = maxCachedDentries
+ }
+
+ // Handle simple flags.
+ if _, ok := mopts["force_page_cache"]; ok {
+ delete(mopts, "force_page_cache")
+ fsopts.forcePageCache = true
+ }
+ if _, ok := mopts["limit_host_fd_translation"]; ok {
+ delete(mopts, "limit_host_fd_translation")
+ fsopts.limitHostFDTranslation = true
+ }
+ if _, ok := mopts["overlayfs_stale_read"]; ok {
+ delete(mopts, "overlayfs_stale_read")
+ fsopts.overlayfsStaleRead = true
+ }
+ // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
+ // "cache=none".
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Handle internal options.
+ iopts, ok := opts.InternalData.(InternalFilesystemOptions)
+ if opts.InternalData != nil && !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted gofer.InternalFilesystemOptions", opts.InternalData)
+ return nil, nil, syserror.EINVAL
+ }
+ // If !ok, iopts being the zero value is correct.
+
+ // Establish a connection with the server.
+ conn, err := unet.NewSocket(fsopts.fd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Perform version negotiation with the server.
+ ctx.UninterruptibleSleepStart(false)
+ client, err := p9.NewClient(conn, fsopts.msize, fsopts.version)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ conn.Close()
+ return nil, nil, err
+ }
+ // Ownership of conn has been transferred to client.
+
+ // Perform attach to obtain the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := client.Attach(fsopts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ client.Close()
+ return nil, nil, err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ attachFile.close(ctx)
+ client.Close()
+ return nil, nil, err
+ }
+
+ // Construct the filesystem object.
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ attachFile.close(ctx)
+ client.Close()
+ return nil, nil, err
+ }
+ fs := &filesystem{
+ mfp: mfp,
+ opts: fsopts,
+ iopts: iopts,
+ client: client,
+ clock: ktime.RealtimeClockFromContext(ctx),
+ devMinor: devMinor,
+ syncableDentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
+
+ // Construct the root dentry.
+ root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
+ if err != nil {
+ attachFile.close(ctx)
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+ // Set the root's reference count to 2. One reference is returned to the
+ // caller, and the other is deliberately leaked to prevent the root from
+ // being "cached" and subsequently evicted. Its resources will still be
+ // cleaned up by fs.Release().
+ root.refs = 2
+
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ ctx := context.Background()
+ mf := fs.mfp.MemoryFile()
+
+ fs.syncMu.Lock()
+ for d := range fs.syncableDentries {
+ d.handleMu.Lock()
+ d.dataMu.Lock()
+ if d.handleWritable {
+ // Write dirty cached data to the remote file.
+			if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+ log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err)
+ }
+ // TODO(jamieliu): Do we need to flushf/fsync d?
+ }
+ // Discard cached pages.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+ d.dataMu.Unlock()
+ // Close the host fd if one exists.
+ if d.handle.fd >= 0 {
+ syscall.Close(int(d.handle.fd))
+ d.handle.fd = -1
+ }
+ d.handleMu.Unlock()
+ }
+ // There can't be any specialFileFDs still using fs, since each such
+ // FileDescription would hold a reference on a Mount holding a reference on
+ // fs.
+ fs.syncMu.Unlock()
+
+ if !fs.iopts.LeakConnection {
+ // Close the connection to the server. This implicitly clunks all fids.
+ fs.client.Close()
+ }
+
+ fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ // refs is the reference count. Each dentry holds a reference on its
+ // parent, even if disowned. An additional reference is held on all
+ // synthetic dentries until they are unlinked or invalidated. When refs
+ // reaches 0, the dentry may be added to the cache or destroyed. If refs ==
+ // -1, the dentry has already been destroyed. refs is accessed using atomic
+ // memory operations.
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // parent is this dentry's parent directory. Each dentry holds a reference
+ // on its parent. If this dentry is a filesystem root, parent is nil.
+ // parent is protected by filesystem.renameMu.
+ parent *dentry
+
+ // name is the name of this dentry in its parent. If this dentry is a
+ // filesystem root, name is the empty string. name is protected by
+ // filesystem.renameMu.
+ name string
+
+ // We don't support hard links, so each dentry maps 1:1 to an inode.
+
+ // file is the unopened p9.File that backs this dentry. file is immutable.
+ //
+ // If file.isNil(), this dentry represents a synthetic file, i.e. a file
+ // that does not exist on the remote filesystem. As of this writing, the
+ // only files that can be synthetic are sockets, pipes, and directories.
+ file p9file
+
+ // If deleted is non-zero, the file represented by this dentry has been
+ // deleted. deleted is accessed using atomic memory operations.
+ deleted uint32
+
+ // If cached is true, dentryEntry links dentry into
+ // filesystem.cachedDentries. cached and dentryEntry are protected by
+ // filesystem.renameMu.
+ cached bool
+ dentryEntry
+
+ dirMu sync.Mutex
+
+ // If this dentry represents a directory, children contains:
+ //
+ // - Mappings of child filenames to dentries representing those children.
+ //
+ // - Mappings of child filenames that are known not to exist to nil
+ // dentries (only if InteropModeShared is not in effect and the directory
+ // is not synthetic).
+ //
+ // children is protected by dirMu.
+ children map[string]*dentry
+
+ // If this dentry represents a directory, syntheticChildren is the number
+ // of child dentries for which dentry.isSynthetic() == true.
+ // syntheticChildren is protected by dirMu.
+ syntheticChildren int
+
+ // If this dentry represents a directory,
+ // dentry.cachedMetadataAuthoritative() == true, and dirents is not nil, it
+ // is a cache of all entries in the directory, in the order they were
+ // returned by the server. dirents is protected by dirMu.
+ dirents []vfs.Dirent
+
+ // Cached metadata; protected by metadataMu and accessed using atomic
+ // memory operations unless otherwise specified.
+ metadataMu sync.Mutex
+ ino inodeNumber // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+	gid uint32 // auth.KGID, but stored as raw uint32 for sync/atomic
+ blockSize uint32 // 0 if unknown
+ // Timestamps, all nsecs from the Unix epoch.
+ atime int64
+ mtime int64
+ ctime int64
+ btime int64
+ // File size, protected by both metadataMu and dataMu (i.e. both must be
+ // locked to mutate it).
+ size uint64
+
+ // nlink counts the number of hard links to this dentry. It's updated and
+ // accessed using atomic operations. It's not protected by metadataMu like the
+ // other metadata fields.
+ nlink uint32
+
+ mapsMu sync.Mutex
+
+ // If this dentry represents a regular file, mappings tracks mappings of
+ // the file into memmap.MappingSpaces. mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // If this dentry represents a regular file or directory:
+ //
+ // - handle is the I/O handle used by all regularFileFDs/directoryFDs
+ // representing this dentry.
+ //
+ // - handleReadable is true if handle is readable.
+ //
+ // - handleWritable is true if handle is writable.
+ //
+ // Invariants:
+ //
+ // - If handleReadable == handleWritable == false, then handle.file == nil
+ // (i.e. there is no open handle). Conversely, if handleReadable ||
+ // handleWritable == true, then handle.file != nil (i.e. there is an open
+ // handle).
+ //
+ // - handleReadable and handleWritable cannot transition from true to false
+ // (i.e. handles may not be downgraded).
+ //
+ // These fields are protected by handleMu.
+ handleMu sync.RWMutex
+ handle handle
+ handleReadable bool
+ handleWritable bool
+
+ dataMu sync.RWMutex
+
+ // If this dentry represents a regular file that is client-cached, cache
+ // maps offsets into the cached file to offsets into
+ // filesystem.mfp.MemoryFile() that store the file's data. cache is
+ // protected by dataMu.
+ cache fsutil.FileRangeSet
+
+ // If this dentry represents a regular file that is client-cached, dirty
+ // tracks dirty segments in cache. dirty is protected by dataMu.
+ dirty fsutil.DirtySet
+
+ // pf implements platform.File for mappings of handle.fd.
+ pf dentryPlatformFile
+
+ // If this dentry represents a symbolic link, InteropModeShared is not in
+ // effect, and haveTarget is true, target is the symlink target. haveTarget
+ // and target are protected by dataMu.
+ haveTarget bool
+ target string
+
+ // If this dentry represents a synthetic socket file, endpoint is the
+ // transport endpoint bound to this file.
+ endpoint transport.BoundEndpoint
+
+ // If this dentry represents a synthetic named pipe, pipe is the pipe
+ // endpoint bound to this file.
+ pipe *pipe.VFSPipe
+
+ locks vfs.FileLocks
+
+ // Inotify watches for this dentry.
+ watches vfs.Watches
+}
+
+// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
+// gofer client.
+func dentryAttrMask() p9.AttrMask {
+ return p9.AttrMask{
+ Mode: true,
+ UID: true,
+ GID: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ Size: true,
+ BTime: true,
+ }
+}
+
+// newDentry creates a new dentry representing the given file. The dentry
+// initially has no references, but is not cached; it is the caller's
+// responsibility to set the dentry's reference count and/or call
+// dentry.checkCachingLocked() as appropriate.
+//
+// Preconditions: !file.isNil().
+func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) {
+ if !mask.Mode {
+ ctx.Warningf("can't create gofer.dentry without file type")
+ return nil, syserror.EIO
+ }
+ if attr.Mode.FileType() == p9.ModeRegular && !mask.Size {
+ ctx.Warningf("can't create regular file gofer.dentry without file size")
+ return nil, syserror.EIO
+ }
+
+ d := &dentry{
+ fs: fs,
+ file: file,
+ ino: inoFromPath(qid.Path),
+ mode: uint32(attr.Mode),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
+ blockSize: usermem.PageSize,
+ handle: handle{
+ fd: -1,
+ },
+ }
+ d.pf.dentry = d
+ if mask.UID {
+ d.uid = dentryUIDFromP9UID(attr.UID)
+ }
+ if mask.GID {
+ d.gid = dentryGIDFromP9GID(attr.GID)
+ }
+ if mask.Size {
+ d.size = attr.Size
+ }
+ if attr.BlockSize != 0 {
+ d.blockSize = uint32(attr.BlockSize)
+ }
+ if mask.ATime {
+ d.atime = dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds)
+ }
+ if mask.MTime {
+ d.mtime = dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds)
+ }
+ if mask.CTime {
+ d.ctime = dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds)
+ }
+ if mask.BTime {
+ d.btime = dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds)
+ }
+ if mask.NLink {
+ d.nlink = uint32(attr.NLink)
+ }
+ d.vfsd.Init(d)
+
+ fs.syncMu.Lock()
+ fs.syncableDentries[d] = struct{}{}
+ fs.syncMu.Unlock()
+ return d, nil
+}
+
+func (d *dentry) isSynthetic() bool {
+ return d.file.isNil()
+}
+
+func (d *dentry) cachedMetadataAuthoritative() bool {
+ return d.fs.opts.interop != InteropModeShared || d.isSynthetic()
+}
+
+// updateFromP9Attrs is called to update d's metadata after an update from the
+// remote filesystem.
+func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
+ d.metadataMu.Lock()
+ if mask.Mode {
+ if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
+ d.metadataMu.Unlock()
+ panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
+ }
+ atomic.StoreUint32(&d.mode, uint32(attr.Mode))
+ }
+ if mask.UID {
+ atomic.StoreUint32(&d.uid, dentryUIDFromP9UID(attr.UID))
+ }
+ if mask.GID {
+ atomic.StoreUint32(&d.gid, dentryGIDFromP9GID(attr.GID))
+ }
+ // There is no P9_GETATTR_* bit for I/O block size.
+ if attr.BlockSize != 0 {
+ atomic.StoreUint32(&d.blockSize, uint32(attr.BlockSize))
+ }
+ if mask.ATime {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds))
+ }
+ if mask.MTime {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds))
+ }
+ if mask.CTime {
+ atomic.StoreInt64(&d.ctime, dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds))
+ }
+ if mask.BTime {
+ atomic.StoreInt64(&d.btime, dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds))
+ }
+ if mask.NLink {
+ atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
+ }
+ if mask.Size {
+ d.updateFileSizeLocked(attr.Size)
+ }
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: !d.isSynthetic()
+func (d *dentry) updateFromGetattr(ctx context.Context) error {
+ // Use d.handle.file, which represents a 9P fid that has been opened, in
+ // preference to d.file, which represents a 9P fid that has not. This may
+ // be significantly more efficient in some implementations.
+ var (
+ file p9file
+ handleMuRLocked bool
+ )
+ d.handleMu.RLock()
+ if !d.handle.file.isNil() {
+ file = d.handle.file
+ handleMuRLocked = true
+ } else {
+ file = d.file
+ d.handleMu.RUnlock()
+ }
+ _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
+ if handleMuRLocked {
+ d.handleMu.RUnlock()
+ }
+ if err != nil {
+ return err
+ }
+ d.updateFromP9Attrs(attrMask, &attr)
+ return nil
+}
+
+func (d *dentry) fileType() uint32 {
+ return atomic.LoadUint32(&d.mode) & linux.S_IFMT
+}
+
+func (d *dentry) statTo(stat *linux.Statx) {
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME
+ stat.Blksize = atomic.LoadUint32(&d.blockSize)
+ stat.Nlink = atomic.LoadUint32(&d.nlink)
+ if stat.Nlink == 0 {
+ // The remote filesystem doesn't support link count; just make
+ // something up. This is consistent with Linux, where
+ // fs/inode.c:inode_init_always() initializes link count to 1, and
+ // fs/9p/vfs_inode_dotl.c:v9fs_stat2inode_dotl() doesn't touch it if
+ // it's not provided by the remote filesystem.
+ stat.Nlink = 1
+ }
+ stat.UID = atomic.LoadUint32(&d.uid)
+ stat.GID = atomic.LoadUint32(&d.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&d.mode))
+ stat.Ino = uint64(d.ino)
+ stat.Size = atomic.LoadUint64(&d.size)
+ // This is consistent with regularFileFD.Seek(), which treats regular files
+ // as having no holes.
+ stat.Blocks = (stat.Size + 511) / 512
+ stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime))
+ stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
+ stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
+ stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
+ stat.DevMajor = linux.UNNAMED_MAJOR
+ stat.DevMinor = d.fs.devMinor
+}
+
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
+ if stat.Mask == 0 {
+ return nil
+ }
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ setLocalAtime := false
+ setLocalMtime := false
+ if d.cachedMetadataAuthoritative() {
+ // Timestamp updates will be handled locally.
+ setLocalAtime = stat.Mask&linux.STATX_ATIME != 0
+ setLocalMtime = stat.Mask&linux.STATX_MTIME != 0
+ stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME
+
+ // Prepare for truncate.
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ switch d.mode & linux.S_IFMT {
+ case linux.S_IFREG:
+ if !setLocalMtime {
+ // Truncate updates mtime.
+ setLocalMtime = true
+ stat.Mtime.Nsec = linux.UTIME_NOW
+ }
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
+ }
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // The size needs to be changed even when
+ // !d.cachedMetadataAuthoritative() because d.mappings has to be
+ // updated.
+ d.updateFileSizeLocked(stat.Size)
+ }
+ if !d.isSynthetic() {
+ if stat.Mask != 0 {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ return err
+ }
+ }
+ if d.fs.opts.interop == InteropModeShared {
+ // There's no point to updating d's metadata in this case since
+ // it'll be overwritten by revalidation before the next time it's
+ // used anyway. (InteropModeShared inhibits client caching of
+ // regular file data, so there's no cache to truncate either.)
+ return nil
+ }
+ }
+ now := d.fs.clock.Now().Nanoseconds()
+ if stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, stat.UID)
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, stat.GID)
+ }
+ if setLocalAtime {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&d.atime, now)
+ } else {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
+ }
+ // Restore mask bits that we cleared earlier.
+ stat.Mask |= linux.STATX_ATIME
+ }
+ if setLocalMtime {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&d.mtime, now)
+ } else {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
+ }
+ // Restore mask bits that we cleared earlier.
+ stat.Mask |= linux.STATX_MTIME
+ }
+ atomic.StoreInt64(&d.ctime, now)
+ return nil
+}
+
+// Preconditions: d.metadataMu must be locked.
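+//
+// For example, shrinking a regular file from 8192 bytes to 100 bytes
+// invalidates any existing mappings of the page range [4096, 8192) and drops
+// the corresponding cache pages without writeback, since those pages no
+// longer exist in the remote file.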
+func (d *dentry) updateFileSizeLocked(newSize uint64) {
+ d.dataMu.Lock()
+ oldSize := d.size
+ d.size = newSize
+ // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
+ // below. This allows concurrent calls to Read/Translate/etc. These
+ // functions synchronize with truncation by refusing to use cache
+ // contents beyond the new d.size. (We are still holding d.metadataMu,
+ // so we can't race with Write or another truncate.)
+ d.dataMu.Unlock()
+ if d.size < oldSize {
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(d.size)
+ if oldpgend != newpgend {
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ d.mapsMu.Unlock()
+ }
+ // We are now guaranteed that there are no translations of
+ // truncated pages, and can remove them from the cache. Since
+ // truncated pages have been removed from the remote file, they
+ // should be dropped without being written back.
+ d.dataMu.Lock()
+ d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
+ d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
+ d.dataMu.Unlock()
+ }
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid)))
+}
+
+func dentryUIDFromP9UID(uid p9.UID) uint32 {
+ if !uid.Ok() {
+ return uint32(auth.OverflowUID)
+ }
+ return uint32(uid)
+}
+
+func dentryGIDFromP9GID(gid p9.GID) uint32 {
+ if !gid.Ok() {
+ return uint32(auth.OverflowGID)
+ }
+ return uint32(gid)
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ // d.refs may be 0 if d.fs.renameMu is locked, which serializes against
+ // d.checkCachingLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkCachingLocked()
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("gofer.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// decRefLocked decrements d's reference count without calling
+// d.checkCachingLocked, even if d's reference count reaches 0; callers are
+// responsible for ensuring that d.checkCachingLocked will be called later.
+func (d *dentry) decRefLocked() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs < 0 {
+ panic("gofer.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+ if d.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ d.fs.renameMu.RLock()
+ // The ordering below is important, Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted())
+ }
+ d.watches.Notify("", events, cookie, et, d.isDeleted())
+ d.fs.renameMu.RUnlock()
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ return &d.watches
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// If no watches are left on this dentry and it has no references, cache it.
+func (d *dentry) OnZeroWatches() {
+ if atomic.LoadInt64(&d.refs) == 0 {
+ d.fs.renameMu.Lock()
+ d.checkCachingLocked()
+ d.fs.renameMu.Unlock()
+ }
+}
+
+// checkCachingLocked should be called after d's reference count becomes 0 or it
+// becomes disowned.
+//
+// It may be called on a destroyed dentry. For example,
+// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
+// for the same dentry when the dentry is visited more than once in the same
+// operation. One of the calls may destroy the dentry, so subsequent calls will
+// do nothing.
+//
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) checkCachingLocked() {
+ // Dentries with a non-zero reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires renameMu, so if d.refs is zero then it will
+ // remain zero while we hold renameMu for writing.)
+ refs := atomic.LoadInt64(&d.refs)
+ if refs > 0 {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ return
+ }
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
+ // Deleted and invalidated dentries with zero references are no longer
+ // reachable by path resolution and should be dropped immediately.
+ if d.vfsd.IsDead() {
+ if d.isDeleted() {
+ d.watches.HandleDeletion()
+ }
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ d.destroyLocked()
+ return
+ }
+ // If d still has inotify watches and it is not deleted or invalidated, we
+ // cannot cache it and allow it to be evicted. Otherwise, we will lose its
+ // watches, even if a new dentry is created for the same file in the future.
+ // Note that the size of d.watches cannot concurrently transition from zero
+ // to non-zero, because adding a watch requires holding a reference on d.
+ if d.watches.Size() > 0 {
+ return
+ }
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
+ victim := d.fs.cachedDentries.Back()
+ d.fs.cachedDentries.Remove(victim)
+ d.fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(&victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
+ }
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked()
+ }
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.opts.maxCachedDentries, so we don't loop.
+ }
+}
+
+// destroyLocked destroys the dentry. It may flush dirty pages from the
+// cache, close the dentry's p9 file, and drop the dentry's reference on its
+// parent.
+//
+// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is
+// not a child dentry.
+func (d *dentry) destroyLocked() {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
+ ctx := context.Background()
+ d.handleMu.Lock()
+ if !d.handle.file.isNil() {
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ // Write dirty pages back to the remote filesystem.
+ if d.handleWritable {
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+				log.Warningf("gofer.dentry.destroyLocked: failed to write dirty data back: %v", err)
+ }
+ }
+ // Discard cached data.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+ d.dataMu.Unlock()
+ // Clunk open fids and close open host FDs.
+ d.handle.close(ctx)
+ }
+ d.handleMu.Unlock()
+
+ if !d.file.isNil() {
+ d.file.close(ctx)
+ d.file = p9file{}
+ // Remove d from the set of syncable dentries.
+ d.fs.syncMu.Lock()
+ delete(d.fs.syncableDentries, d)
+ d.fs.syncMu.Unlock()
+ }
+ // Drop the reference held by d on its parent without recursively locking
+ // d.fs.renameMu.
+ if d.parent != nil {
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.checkCachingLocked()
+ } else if refs < 0 {
+ panic("gofer.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+func (d *dentry) isDeleted() bool {
+ return atomic.LoadUint32(&d.deleted) != 0
+}
+
+func (d *dentry) setDeleted() {
+ atomic.StoreUint32(&d.deleted, 1)
+}
+
+// We only support xattrs prefixed with "user." (see b/148380782). Currently,
+// there is no need to expose any other xattrs through a gofer.
+func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
+ if d.file.isNil() || !d.userXattrSupported() {
+ return nil, nil
+ }
+ xattrMap, err := d.file.listXattr(ctx, size)
+ if err != nil {
+ return nil, err
+ }
+ xattrs := make([]string, 0, len(xattrMap))
+ for x := range xattrMap {
+ if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ xattrs = append(xattrs, x)
+ }
+ }
+ return xattrs, nil
+}
+
+func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if d.file.isNil() {
+ return "", syserror.ENODATA
+ }
+ if err := d.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ if !d.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
+ return d.file.getXattr(ctx, opts.Name, opts.Size)
+}
+
+func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if d.file.isNil() {
+ return syserror.EPERM
+ }
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !d.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
+}
+
+func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error {
+ if d.file.isNil() {
+ return syserror.EPERM
+ }
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !d.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return d.file.removeXattr(ctx, name)
+}
+
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
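+//
+// For example (illustrative): getxattr of "user.foo" on a regular file is
+// forwarded to the gofer, getxattr of "trusted.foo" fails with EOPNOTSUPP,
+// and getxattr of "user.foo" on a symlink fails with ENODATA.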
+func (d *dentry) userXattrSupported() bool {
+ filetype := linux.S_IFMT & atomic.LoadUint32(&d.mode)
+ return filetype == linux.S_IFREG || filetype == linux.S_IFDIR
+}
+
+// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir().
+func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error {
+ // O_TRUNC unconditionally requires us to obtain a new handle (opened with
+ // O_TRUNC).
+ if !trunc {
+ d.handleMu.RLock()
+ if (!read || d.handleReadable) && (!write || d.handleWritable) {
+ // The current handle is sufficient.
+ d.handleMu.RUnlock()
+ return nil
+ }
+ d.handleMu.RUnlock()
+ }
+
+ haveOldFD := false
+ d.handleMu.Lock()
+ if (read && !d.handleReadable) || (write && !d.handleWritable) || trunc {
+ // Get a new handle.
+ wantReadable := d.handleReadable || read
+ wantWritable := d.handleWritable || write
+ h, err := openHandle(ctx, d.file, wantReadable, wantWritable, trunc)
+ if err != nil {
+ d.handleMu.Unlock()
+ return err
+ }
+ if !d.handle.file.isNil() {
+ // Check that old and new handles are compatible: If the old handle
+ // includes a host file descriptor but the new one does not, or
+ // vice versa, old and new memory mappings may be incoherent.
+ haveOldFD = d.handle.fd >= 0
+ haveNewFD := h.fd >= 0
+ if haveOldFD != haveNewFD {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: can't change host FD availability from %v to %v across dentry handle upgrade", haveOldFD, haveNewFD)
+ h.close(ctx)
+ return syserror.EIO
+ }
+ if haveOldFD {
+ // We may have raced with callers of d.pf.FD() that are now
+ // using the old file descriptor, preventing us from safely
+ // closing it. We could handle this by invalidating existing
+ // memmap.Translations, but this is expensive. Instead, use
+ // dup3 to make the old file descriptor refer to the new file
+ // description, then close the new file descriptor (which is no
+ // longer needed). Racing callers may use the old or new file
+ // description, but this doesn't matter since they refer to the
+ // same file (unless d.fs.opts.overlayfsStaleRead is true,
+ // which we handle separately).
+ if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err)
+ h.close(ctx)
+ return err
+ }
+ syscall.Close(int(h.fd))
+ h.fd = d.handle.fd
+ if d.fs.opts.overlayfsStaleRead {
+ // Replace sentry mappings of the old FD with mappings of
+ // the new FD, since the two are not necessarily coherent.
+ if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err)
+ h.close(ctx)
+ return err
+ }
+ }
+ // Clunk the old fid before making the new handle visible (by
+ // unlocking d.handleMu).
+ d.handle.file.close(ctx)
+ }
+ }
+ // Switch to the new handle.
+ d.handle = h
+ d.handleReadable = wantReadable
+ d.handleWritable = wantWritable
+ }
+ d.handleMu.Unlock()
+
+ if d.fs.opts.overlayfsStaleRead && haveOldFD {
+ // Invalidate application mappings that may be using the old FD; they
+ // will be replaced with mappings using the new FD after future calls
+ // to d.Translate(). This requires holding d.mapsMu, which precedes
+ // d.handleMu in the lock order.
+ d.mapsMu.Lock()
+ d.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+ }
+
+ return nil
+}
+
+// incLinks increments link count.
+func (d *dentry) incLinks() {
+ if atomic.LoadUint32(&d.nlink) == 0 {
+ // The remote filesystem doesn't support link count.
+ return
+ }
+ atomic.AddUint32(&d.nlink, 1)
+}
+
+// decLinks decrements link count.
+func (d *dentry) decLinks() {
+ if atomic.LoadUint32(&d.nlink) == 0 {
+ // The remote filesystem doesn't support link count.
+ return
+ }
+ atomic.AddUint32(&d.nlink, ^uint32(0))
+}
+
+// fileDescription is embedded by gofer implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
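+	// lockLogging ensures that the informational messages in LockBSD and
+	// LockPOSIX below are logged at most once per fileDescription.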
+ lockLogging sync.Once
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ d := fd.dentry()
+ const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
+ if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
+ // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
+ // available?
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ var stat linux.Statx
+ d.statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()); err != nil {
+ return err
+ }
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ d := fd.dentry()
+ if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
+ return err
+ }
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ d := fd.dentry()
+ if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
+ return err
+ }
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("File lock using gofer file handled internally.")
+ })
+ return fd.LockFD.LockBSD(ctx, uid, t, block)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("Range lock using gofer file handled internally.")
+ })
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
new file mode 100644
index 000000000..adff39490
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+)
+
+func TestDestroyIdempotent(t *testing.T) {
+ fs := filesystem{
+ syncableDentries: make(map[*dentry]struct{}),
+ opts: filesystemOptions{
+ // Test relies on no dentry being held in the cache.
+ maxCachedDentries: 0,
+ },
+ }
+
+ ctx := contexttest.Context(t)
+ attr := &p9.Attr{
+ Mode: p9.ModeRegular,
+ }
+ mask := p9.AttrMask{
+ Mode: true,
+ Size: true,
+ }
+ parent, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+
+ child, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+ parent.cacheNewChildLocked(child, "child")
+
+ child.checkCachingLocked()
+ if got := atomic.LoadInt64(&child.refs); got != -1 {
+ t.Fatalf("child.refs=%d, want: -1", got)
+ }
+	// The parent dentry is also destroyed once the child's reference on it
+	// is dropped.
+ if got := atomic.LoadInt64(&parent.refs); got != -1 {
+ t.Fatalf("parent.refs=%d, want: -1", got)
+ }
+ child.checkCachingLocked()
+ child.checkCachingLocked()
+}
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
new file mode 100644
index 000000000..8792ca4f2
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -0,0 +1,141 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+)
+
+// handle represents a remote "open file descriptor", consisting of an opened
+// fid (p9.File) and optionally a host file descriptor.
+type handle struct {
+ file p9file
+ fd int32 // -1 if unavailable
+}
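+
+// A typical lifecycle (illustrative): open a read-only handle, use it, and
+// release it with close:
+//
+//	h, err := openHandle(ctx, d.file, true /* read */, false /* write */, false /* trunc */)
+//	if err != nil {
+//		return err
+//	}
+//	defer h.close(ctx)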
+
+// Preconditions: read || write.
+func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (handle, error) {
+ _, newfile, err := file.walk(ctx, nil)
+ if err != nil {
+ return handle{fd: -1}, err
+ }
+ var flags p9.OpenFlags
+ switch {
+ case read && !write:
+ flags = p9.ReadOnly
+ case !read && write:
+ flags = p9.WriteOnly
+ case read && write:
+ flags = p9.ReadWrite
+ }
+ if trunc {
+ flags |= p9.OpenTruncate
+ }
+ fdobj, _, _, err := newfile.open(ctx, flags)
+ if err != nil {
+ newfile.close(ctx)
+ return handle{fd: -1}, err
+ }
+ fd := int32(-1)
+ if fdobj != nil {
+ fd = int32(fdobj.Release())
+ }
+ return handle{
+ file: newfile,
+ fd: fd,
+ }, nil
+}
+
+func (h *handle) close(ctx context.Context) {
+ h.file.close(ctx)
+ h.file = p9file{}
+ if h.fd >= 0 {
+ syscall.Close(int(h.fd))
+ h.fd = -1
+ }
+}
+
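+// readToBlocksAt reads from the handle into dsts, preferring the host FD
+// (via hostfd.Preadv2) when one is available, and otherwise issuing 9P
+// reads, buffering only when dsts cannot be passed through as a single
+// slice without safecopy.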
+func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := hostfd.Preadv2(h.fd, dsts, int64(offset), 0 /* flags */)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+ }
+ if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() {
+ n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
+ return uint64(n), err
+ }
+ // Buffer the read since p9.File.ReadAt() takes []byte.
+ buf := make([]byte, dsts.NumBytes())
+ n, err := h.file.readAt(ctx, buf, offset)
+ if n == 0 {
+ return 0, err
+ }
+ if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil {
+ return cp, cperr
+ }
+ return uint64(n), err
+}
+
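+// writeFromBlocksAt writes from srcs to the handle, mirroring the strategy
+// of readToBlocksAt: the host FD (via hostfd.Pwritev2) when available,
+// otherwise 9P writes with buffering only when required.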
+func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := hostfd.Pwritev2(h.fd, srcs, int64(offset), 0 /* flags */)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+ }
+ if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() {
+ n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
+ return uint64(n), err
+ }
+ // Buffer the write since p9.File.WriteAt() takes []byte.
+ buf := make([]byte, srcs.NumBytes())
+ cp, cperr := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), srcs)
+ if cp == 0 {
+ return 0, cperr
+ }
+ n, err := h.file.writeAt(ctx, buf[:cp], offset)
+ if err != nil {
+ return uint64(n), err
+ }
+ return cp, cperr
+}
+
+func (h *handle) sync(ctx context.Context) error {
+ // Handle most common case first.
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(h.fd))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ if h.file.isNil() {
+ // File hasn't been touched, there is nothing to sync.
+ return nil
+ }
+ return h.file.fsync(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
new file mode 100644
index 000000000..7294de7d6
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -0,0 +1,97 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create
+// pipes after sentry initialization due to syscall filters.
+var (
+ tempPipeMu sync.Mutex
+ tempPipeReadFD int
+ tempPipeWriteFD int
+ tempPipeBuf [1]byte
+)
+
+func init() {
+ var pipeFDs [2]int
+ if err := unix.Pipe(pipeFDs[:]); err != nil {
+ panic(fmt.Sprintf("failed to create pipe for gofer.blockUntilNonblockingPipeHasWriter: %v", err))
+ }
+ tempPipeReadFD = pipeFDs[0]
+ tempPipeWriteFD = pipeFDs[1]
+}
+
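+// blockUntilNonblockingPipeHasWriter polls until the pipe represented by fd
+// has, or has had, a writer, sleeping between checks. It returns early with
+// an error if the sleep is interrupted.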
+func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error {
+ for {
+ ok, err := nonblockingPipeHasWriter(fd)
+ if err != nil {
+ return err
+ }
+ if ok {
+ return nil
+ }
+ if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
+ return err
+ }
+ }
+}
+
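+// nonblockingPipeHasWriter checks for a writer on fd by tee()ing one byte
+// into the temporary pipe: tee duplicates pipe data without consuming it,
+// so the check does not disturb readers of fd. The duplicated byte is
+// drained from the temporary pipe before returning.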
+func nonblockingPipeHasWriter(fd int32) (bool, error) {
+ tempPipeMu.Lock()
+ defer tempPipeMu.Unlock()
+ // Copy 1 byte from fd into the temporary pipe.
+ n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK)
+ if err == syserror.EAGAIN {
+ // The pipe represented by fd is empty, but has a writer.
+ return true, nil
+ }
+ if err != nil {
+ return false, err
+ }
+ if n == 0 {
+ // The pipe represented by fd is empty and has no writer.
+ return false, nil
+ }
+ // The pipe represented by fd is non-empty, so it either has, or has
+ // previously had, a writer. Remove the byte copied to the temporary pipe
+ // before returning.
+ if n, err := unix.Read(tempPipeReadFD, tempPipeBuf[:]); err != nil || n != 1 {
+ panic(fmt.Sprintf("failed to drain pipe for gofer.blockUntilNonblockingPipeHasWriter: got (%d, %v), wanted (1, nil)", n, err))
+ }
+ return true, nil
+}
+
+func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error {
+ t := time.NewTimer(100 * time.Millisecond)
+ defer t.Stop()
+ cancel := ctx.SleepStart()
+ select {
+ case <-t.C:
+ ctx.SleepFinish(true)
+ return nil
+ case <-cancel:
+ ctx.SleepFinish(false)
+ return syserror.ErrInterrupted
+ }
+}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
new file mode 100644
index 000000000..87f0b877f
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// p9file is a wrapper around p9.File that provides methods that are
+// Context-aware.
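+// Each method brackets its remote RPC with
+// Context.UninterruptibleSleepStart/Finish, so the calling task is treated
+// as being in uninterruptible sleep (and is not interrupted) while the RPC
+// is in flight.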
+type p9file struct {
+ file p9.File
+}
+
+func (f p9file) isNil() bool {
+ return f.file == nil
+}
+
+func (f p9file) walk(ctx context.Context, names []string) ([]p9.QID, p9file, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, err := f.file.Walk(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return qids, p9file{newfile}, err
+}
+
+func (f p9file) walkGetAttr(ctx context.Context, names []string) ([]p9.QID, p9file, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, attrMask, attr, err := f.file.WalkGetAttr(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return qids, p9file{newfile}, attrMask, attr, err
+}
+
+// walkGetAttrOne is a wrapper around p9.File.WalkGetAttr that takes a single
+// path component and returns a single qid.
+func (f p9file) walkGetAttrOne(ctx context.Context, name string) (p9.QID, p9file, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, attrMask, attr, err := f.file.WalkGetAttr([]string{name})
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, err
+ }
+ if len(qids) != 1 {
+ ctx.Warningf("p9.File.WalkGetAttr returned %d qids (%v), wanted 1", len(qids), qids)
+ if newfile != nil {
+ p9file{newfile}.close(ctx)
+ }
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, syserror.EIO
+ }
+ return qids[0], p9file{newfile}, attrMask, attr, nil
+}
+
+func (f p9file) statFS(ctx context.Context) (p9.FSStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fsstat, err := f.file.StatFS()
+ ctx.UninterruptibleSleepFinish(false)
+ return fsstat, err
+}
+
+func (f p9file) getAttr(ctx context.Context, req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, attrMask, attr, err := f.file.GetAttr(req)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, attrMask, attr, err
+}
+
+func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetAttr(valid, attr)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := f.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
+func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := f.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Allocate(mode, offset, length)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) close(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Close()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, qid, iounit, err := f.file.Open(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, qid, iounit, err
+}
+
+func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := f.file.ReadAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := f.file.WriteAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (f p9file) fsync(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.FSync()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) create(ctx context.Context, name string, flags p9.OpenFlags, permissions p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9file, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, newfile, qid, iounit, err := f.file.Create(name, flags, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, p9file{newfile}, qid, iounit, err
+}
+
+func (f p9file) mkdir(ctx context.Context, name string, permissions p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Mkdir(name, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) symlink(ctx context.Context, oldName string, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Symlink(oldName, newName, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) link(ctx context.Context, target p9file, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Link(target.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) mknod(ctx context.Context, name string, mode p9.FileMode, major uint32, minor uint32, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Mknod(name, mode, major, minor, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) rename(ctx context.Context, newDir p9file, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Rename(newDir.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) unlinkAt(ctx context.Context, name string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.UnlinkAt(name, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) readdir(ctx context.Context, offset uint64, count uint32) ([]p9.Dirent, error) {
+ ctx.UninterruptibleSleepStart(false)
+ dirents, err := f.file.Readdir(offset, count)
+ ctx.UninterruptibleSleepFinish(false)
+ return dirents, err
+}
+
+func (f p9file) readlink(ctx context.Context) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ target, err := f.file.Readlink()
+ ctx.UninterruptibleSleepFinish(false)
+ return target, err
+}
+
+func (f p9file) flush(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Flush()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, err := f.file.Connect(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
new file mode 100644
index 000000000..a2f02d9c7
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -0,0 +1,892 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isRegularFile() bool {
+ return d.fileType() == linux.S_IFREG
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset. off is protected by mu.
+ mu sync.Mutex
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *regularFileFD) OnClose(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ // Skip flushing if writes may be buffered by the client, since (as with
+ // the VFS1 client) we don't flush buffered writes on close anyway.
+ d := fd.dentry()
+ if d.fs.opts.interop == InteropModeExclusive {
+ return nil
+ }
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ return d.handle.file.flush(ctx)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ d := fd.dentry()
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ size := offset + length
+
+ // Allocating a smaller size is a noop.
+ if size <= d.size {
+ return nil
+ }
+
+ d.handleMu.Lock()
+ defer d.handleMu.Unlock()
+
+ err := d.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
+ if err != nil {
+ return err
+ }
+ d.size = size
+ if !d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
+ }
+ return nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Check for reading at EOF before calling into MM (but not under
+ // InteropModeShared, which makes d.size unreliable).
+ d := fd.dentry()
+ if d.fs.opts.interop != InteropModeShared && uint64(offset) >= atomic.LoadUint64(&d.size) {
+ return 0, io.EOF
+ }
+
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Lock d.metadataMu for the rest of the read to prevent d.size from
+ // changing.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Write dirty cached pages that will be touched by the read back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, dst.NumBytes()); err != nil {
+ return 0, err
+ }
+ }
+
+ rw := getDentryReadWriter(ctx, d, offset)
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Require the read to go to the remote file.
+ rw.direct = true
+ }
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putDentryReadWriter(rw)
+ if d.fs.opts.interop != InteropModeShared {
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+ return n, err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
+
+ d := fd.dentry()
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.fs.opts.interop != InteropModeShared {
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
+ // file_update_time(). This is d.touchCMtime(), but without locking
+ // d.metadataMu (recursively).
+ d.touchCMtimeLocked()
+ }
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Write dirty cached pages that will be touched by the write back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return 0, err
+ }
+ // Remove touched pages from the cache.
+ pgstart := usermem.PageRoundDown(uint64(offset))
+ pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes()))
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ mr := memmap.MappableRange{pgstart, pgend}
+ var freed []platform.FileRange
+ d.dataMu.Lock()
+ cseg := d.cache.LowerBoundSegment(mr.Start)
+ for cseg.Ok() && cseg.Start() < mr.End {
+ cseg = d.cache.Isolate(cseg, mr)
+ freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ cseg = d.cache.Remove(cseg).NextSegment()
+ }
+ d.dataMu.Unlock()
+ // Invalidate mappings of removed pages.
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(mr, memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+ // Finally free pages removed from the cache.
+ mf := d.fs.mfp.MemoryFile()
+ for _, freedFR := range freed {
+ mf.DecRef(freedFR)
+ }
+ }
+ rw := getDentryReadWriter(ctx, d, offset)
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Require the write to go to the remote file.
+ rw.direct = true
+ }
+ n, err := src.CopyInTo(ctx, rw)
+ putDentryReadWriter(rw)
+ if n != 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ // Write dirty cached pages touched by the write back to the remote
+ // file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return 0, err
+ }
+ // Request the remote filesystem to sync the remote file.
+ if err := d.handle.file.fsync(ctx); err != nil {
+ return 0, err
+ }
+ }
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+type dentryReadWriter struct {
+ ctx context.Context
+ d *dentry
+ off uint64
+ direct bool
+}
+
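+// dentryReadWriterPool recycles dentryReadWriter objects so that the hot
+// read and write paths do not allocate one per call.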
+var dentryReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &dentryReadWriter{}
+ },
+}
+
+func getDentryReadWriter(ctx context.Context, d *dentry, offset int64) *dentryReadWriter {
+ rw := dentryReadWriterPool.Get().(*dentryReadWriter)
+ rw.ctx = ctx
+ rw.d = d
+ rw.off = uint64(offset)
+ rw.direct = false
+ return rw
+}
+
+func putDentryReadWriter(rw *dentryReadWriter) {
+ rw.ctx = nil
+ rw.d = nil
+ dentryReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // If we have a mmappable host FD (which must be used here to ensure
+ // coherence with memory-mapped I/O), or if InteropModeShared is in effect
+ // (which prevents us from caching file contents and makes dentry.size
+ // unreliable), or if the file was opened with O_DIRECT, read directly from
+ // dentry.handle without locking dentry.dataMu.
+ rw.d.handleMu.RLock()
+ if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ n, err := rw.d.handle.readToBlocksAt(rw.ctx, dsts, rw.off)
+ rw.d.handleMu.RUnlock()
+ rw.off += n
+ return n, err
+ }
+
+ // Otherwise read from/through the cache.
+ mf := rw.d.fs.mfp.MemoryFile()
+ fillCache := mf.ShouldCacheEvictable()
+ var dataMuUnlock func()
+ if fillCache {
+ rw.d.dataMu.Lock()
+ dataMuUnlock = rw.d.dataMu.Unlock
+ } else {
+ rw.d.dataMu.RLock()
+ dataMuUnlock = rw.d.dataMu.RUnlock
+ }
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= rw.d.size {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return 0, io.EOF
+ }
+ end := rw.d.size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.d.cache.Find(rw.off)
+ for rw.off < end {
+ mr := memmap.MappableRange{rw.off, end}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += n
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ gapMR := gap.Range().Intersect(mr)
+ if fillCache {
+ // Read into the cache, then re-enter the loop to read from the
+ // cache.
+ gapEnd, _ := usermem.PageRoundUp(gapMR.End)
+ reqMR := memmap.MappableRange{
+ Start: usermem.PageRoundDown(gapMR.Start),
+ End: gapEnd,
+ }
+ optMR := gap.Range()
+ err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt)
+ mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End})
+ seg, gap = rw.d.cache.Find(rw.off)
+ if !seg.Ok() {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+ // err might have occurred in part of gap.Range() outside
+ // gapMR. Forget about it for now; if the error matters and
+ // persists, we'll run into it again in a later iteration of
+ // this loop.
+ } else {
+ // Read directly from the file.
+ gapDsts := dsts.TakeFirst64(gapMR.Length())
+ n, err := rw.d.handle.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start)
+ done += n
+ rw.off += n
+ dsts = dsts.DropFirst64(n)
+ // Partial reads are fine. But we must stop reading.
+ if n != gapDsts.NumBytes() || err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ }
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: rw.d.metadataMu must be locked.
+func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+
+ // If we have a mmappable host FD (which must be used here to ensure
+ // coherence with memory-mapped I/O), or if InteropModeShared is in effect
+ // (which prevents us from caching file contents), or if the file was
+ // opened with O_DIRECT, write directly to dentry.handle without locking
+ // dentry.dataMu.
+ rw.d.handleMu.RLock()
+ if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off)
+ rw.off += n
+ rw.d.dataMu.Lock()
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The write above went directly to the remote file, so its size
+ // has already been extended accordingly.
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
+ return n, err
+ }
+
+ // Otherwise write to/through the cache.
+ mf := rw.d.fs.mfp.MemoryFile()
+ rw.d.dataMu.Lock()
+
+ // Compute the range to write (overflow-checked).
+ start := rw.off
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.d.cache.Find(rw.off)
+ for rw.off < end {
+ mr := memmap.MappableRange{rw.off, end}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ segMR := seg.Range().Intersect(mr)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += n
+ srcs = srcs.DropFirst64(n)
+ rw.d.dirty.MarkDirty(segMR)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Write directly to the file. At present, we never fill the cache
+ // when writing, since doing so can convert small writes into
+ // inefficient read-modify-write cycles, and we have no mechanism
+ // for detecting or avoiding this.
+ gapMR := gap.Range().Intersect(mr)
+ gapSrcs := srcs.TakeFirst64(gapMR.Length())
+ n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start)
+ done += n
+ rw.off += n
+ srcs = srcs.DropFirst64(n)
+ // Partial writes are fine. But we must stop writing.
+ if n != gapSrcs.NumBytes() || err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ // If InteropModeWritethrough is in effect, flush written data back to the
+ // remote filesystem.
+ if rw.d.fs.opts.interop == InteropModeWritethrough && done != 0 {
+ if err := fsutil.SyncDirty(rw.ctx, memmap.MappableRange{
+ Start: start,
+ End: rw.off,
+ }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, rw.d.handle.writeFromBlocksAt); err != nil {
+ // We have no idea how many bytes were actually flushed.
+ rw.off = start
+ done = 0
+ retErr = err
+ }
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
+ return done, retErr
+}
+
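+// writeback flushes dirty cached pages in [offset, offset+size) back to the
+// remote file, clamping the range to the current file size.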
+func (d *dentry) writeback(ctx context.Context, offset, size int64) error {
+ if size == 0 {
+ return nil
+ }
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ // Compute the range of valid bytes (overflow-checked).
+ if uint64(offset) >= d.size {
+ return nil
+ }
+ end := int64(d.size)
+ if rend := offset + size; rend > offset && rend < end {
+ end = rend
+ }
+ return fsutil.SyncDirty(ctx, memmap.MappableRange{
+ Start: uint64(offset),
+ End: uint64(end),
+ }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence)
+ if err != nil {
+ return 0, err
+ }
+ fd.off = newOffset
+ return newOffset, nil
+}
+
+// regularFileSeekLocked calculates the new offset for a seek operation on a
+// regular file.
+func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int64, whence int32) (int64, error) {
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as specified.
+ case linux.SEEK_CUR:
+ offset += fdOffset
+ case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Ensure file size is up to date.
+ if !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, err
+ }
+ }
+ size := int64(atomic.LoadUint64(&d.size))
+ // For SEEK_DATA and SEEK_HOLE, treat the file as a single contiguous
+ // block of data.
+ switch whence {
+ case linux.SEEK_END:
+ offset += size
+ case linux.SEEK_DATA:
+ if offset > size {
+ return 0, syserror.ENXIO
+ }
+ // Use offset as specified.
+ case linux.SEEK_HOLE:
+ if offset > size {
+ return 0, syserror.ENXIO
+ }
+ offset = size
+ }
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *regularFileFD) Sync(ctx context.Context) error {
+ return fd.dentry().syncSharedHandle(ctx)
+}
+
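+// syncSharedHandle writes dirty cached data back through the dentry's
+// shared handle, if the handle is writable, and then syncs the remote file.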
+func (d *dentry) syncSharedHandle(ctx context.Context) error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+
+ if d.handleWritable {
+ d.dataMu.Lock()
+ // Write dirty cached data to the remote file.
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
+ }
+ // Sync the remote file.
+ return d.handle.sync(ctx)
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ d := fd.dentry()
+ switch d.fs.opts.interop {
+ case InteropModeExclusive:
+ // Any mapping is fine.
+ case InteropModeWritethrough:
+ // Shared writable mappings require a host FD, since otherwise we can't
+ // synchronously flush memory-mapped writes to the remote file.
+ if opts.Private || !opts.MaxPerms.Write {
+ break
+ }
+ fallthrough
+ case InteropModeShared:
+ // All mappings require a host FD to be coherent with other filesystem
+ // users.
+ if d.fs.opts.forcePageCache {
+ // Whether or not we have a host FD, we're not allowed to use it.
+ return syserror.ENODEV
+ }
+ d.handleMu.RLock()
+ haveFD := d.handle.fd >= 0
+ d.handleMu.RUnlock()
+ if !haveFD {
+ return syserror.ENODEV
+ }
+ default:
+ panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
+ }
+ // After this point, d may be used as a memmap.Mappable.
+ d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
+}
+
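+// mayCachePages returns whether this dentry's contents may be cached by the
+// sentry: never under InteropModeShared, always when forcePageCache is set,
+// and otherwise only when a host FD is available.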
+func (d *dentry) mayCachePages() bool {
+ if d.fs.opts.interop == InteropModeShared {
+ return false
+ }
+ if d.fs.opts.forcePageCache {
+ return true
+ }
+ d.handleMu.RLock()
+ haveFD := d.handle.fd >= 0
+ d.handleMu.RUnlock()
+ return haveFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ mapped := d.mappings.AddMapping(ms, ar, offset, writable)
+ // Do this unconditionally since whether we have a host FD can change
+ // across save/restore.
+ for _, r := range mapped {
+ d.pf.hostFileMapper.IncRefOn(r)
+ }
+ if d.mayCachePages() {
+ // d.Evict() will refuse to evict memory-mapped pages, so tell the
+ // MemoryFile to not bother trying.
+ mf := d.fs.mfp.MemoryFile()
+ for _, r := range mapped {
+ mf.MarkUnevictable(d, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ }
+ d.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ d.mapsMu.Lock()
+ unmapped := d.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ d.pf.hostFileMapper.DecRefOn(r)
+ }
+ if d.mayCachePages() {
+ // Pages that are no longer referenced by any application memory
+ // mappings are now considered unused; allow MemoryFile to evict them
+ // when necessary.
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ for _, r := range unmapped {
+ // Since these pages are no longer mapped, they are no longer
+ // concurrently dirtyable by a writable memory mapping.
+ d.dirty.AllowClean(r)
+ mf.MarkEvictable(d, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ d.dataMu.Unlock()
+ }
+ d.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return d.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ d.handleMu.RLock()
+ if d.handle.fd >= 0 && !d.fs.opts.forcePageCache {
+ d.handleMu.RUnlock()
+ mr := optional
+ if d.fs.opts.limitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: &d.pf,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+ }
+
+ d.dataMu.Lock()
+
+ // Constrain translations to d.size (rounded up) to prevent translation to
+ // pages that may be concurrently truncated.
+ pgend, _ := usermem.PageRoundUp(d.size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ d.dataMu.Unlock()
+ d.handleMu.RUnlock()
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := d.fs.mfp.MemoryFile()
+ cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, d.handle.readToBlocksAt)
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := d.cache.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ // TODO(jamieliu): Make Translations writable even if writability is
+ // not required if already kept-dirty by another writable translation.
+ perms := usermem.AccessType{
+ Read: true,
+ Execute: true,
+ }
+ if at.Write {
+ // From this point forward, this memory can be dirtied through the
+ // mapping at any time.
+ d.dirty.KeepDirty(segMR)
+ perms.Write = true
+ }
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: perms,
+ })
+ translatedEnd = segMR.End
+ }
+
+ d.dataMu.Unlock()
+ d.handleMu.RUnlock()
+
+ // Don't return the error returned by d.cache.Fill if it occurred outside
+ // of required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
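+// maxFillRange returns the range to fill when populating the cache: all of
+// required, extended into optional by at most maxReadahead bytes.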
+func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange {
+ const maxReadahead = 64 << 10 // 64 KB, chosen arbitrarily
+ if required.Length() >= maxReadahead {
+ return required
+ }
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.Start = required.Start
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.End = optional.Start + maxReadahead
+ return optional
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
+ // Whether we have a host fd (and consequently what platform.File is
+ // mapped) can change across save/restore, so invalidate all translations
+ // unconditionally.
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Write the cache's contents back to the remote file so that if we have a
+ // host fd after restore, the remote file's contents are coherent.
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+ return err
+ }
+
+ // Discard the cache so that it's not stored in saved state. This is safe
+ // because per InvalidateUnsavable invariants, no new translations can have
+ // been returned after we invalidated all existing translations above.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+
+ return nil
+}
+
+// Evict implements pgalloc.EvictableMemoryUser.Evict.
+func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+
+ mr := memmap.MappableRange{er.Start, er.End}
+ mf := d.fs.mfp.MemoryFile()
+ // Only allow pages that are no longer memory-mapped to be evicted.
+ for mgap := d.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() {
+ mgapMR := mgap.Range().Intersect(mr)
+ if mgapMR.Length() == 0 {
+ continue
+ }
+ if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+ log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err)
+ }
+ d.cache.Drop(mgapMR, mf)
+ d.dirty.KeepClean(mgapMR)
+ }
+}
+
+// dentryPlatformFile implements platform.File. It exists solely because dentry
+// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef.
+//
+// dentryPlatformFile is only used when a host FD representing the remote file
+// is available (i.e. dentry.handle.fd >= 0), and that FD is used for
+// application memory mappings (i.e. !filesystem.opts.forcePageCache).
+type dentryPlatformFile struct {
+ *dentry
+
+ // fdRefs counts references on platform.File offsets. fdRefs is protected
+ // by dentry.dataMu.
+ fdRefs fsutil.FrameRefSet
+
+ // If this dentry represents a regular file, and handle.fd >= 0,
+ // hostFileMapper caches mappings of handle.fd.
+ hostFileMapper fsutil.HostFileMapper
+
+ // hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
+ hostFileMapperInitOnce sync.Once
+}
+
+// IncRef implements platform.File.IncRef.
+func (d *dentryPlatformFile) IncRef(fr platform.FileRange) {
+ d.dataMu.Lock()
+ d.fdRefs.IncRefAndAccount(fr)
+ d.dataMu.Unlock()
+}
+
+// DecRef implements platform.File.DecRef.
+func (d *dentryPlatformFile) DecRef(fr platform.FileRange) {
+ d.dataMu.Lock()
+ d.fdRefs.DecRefAndAccount(fr)
+ d.dataMu.Unlock()
+}
+
+// MapInternal implements platform.File.MapInternal.
+func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ d.handleMu.RLock()
+ bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write)
+ d.handleMu.RUnlock()
+ return bs, err
+}
+
+// FD implements platform.File.FD.
+func (d *dentryPlatformFile) FD() int {
+ d.handleMu.RLock()
+ fd := d.handle.fd
+ d.handleMu.RUnlock()
+ return int(fd)
+}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
new file mode 100644
index 000000000..d6dbe9092
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func (d *dentry) isSocket() bool {
+ return d.fileType() == linux.S_IFSOCK
+}
+
+// endpoint is a Gofer-backed transport.BoundEndpoint.
+//
+// An endpoint's lifetime is the time between when filesystem.BoundEndpointAt()
+// is called and either BoundEndpoint.BidirectionalConnect or
+// BoundEndpoint.UnidirectionalConnect is called.
+type endpoint struct {
+ // dentry is the filesystem dentry which produced this endpoint.
+ dentry *dentry
+
+ // file is the p9 file that contains a single unopened fid.
+ file p9.File
+
+ // path is the sentry path where this endpoint is bound.
+ path string
+}
+
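+// sockTypeToP9 maps a Linux socket type to the corresponding
+// p9.ConnectFlags. It returns false if the socket type is unsupported.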
+func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
+ switch t {
+ case linux.SOCK_STREAM:
+ return p9.StreamSocket, true
+ case linux.SOCK_SEQPACKET:
+ return p9.SeqpacketSocket, true
+ case linux.SOCK_DGRAM:
+ return p9.DgramSocket, true
+ }
+ return 0, false
+}
+
+// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
+func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
+ cf, ok := sockTypeToP9(ce.Type())
+ if !ok {
+ return syserr.ErrConnectionRefused
+ }
+
+ // No lock ordering required as only the ConnectingEndpoint has a mutex.
+ ce.Lock()
+
+ // Check connecting state.
+ if ce.Connected() {
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue())
+ if err != nil {
+ ce.Unlock()
+ return err
+ }
+
+ returnConnect(c, c)
+ ce.Unlock()
+ if err := c.Init(); err != nil {
+ return syserr.FromError(err)
+ }
+
+ return nil
+}
+
+// UnidirectionalConnect implements
+// transport.BoundEndpoint.UnidirectionalConnect.
+func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
+ c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{})
+ if err != nil {
+ return nil, err
+ }
+
+ if err := c.Init(); err != nil {
+ return nil, syserr.FromError(err)
+ }
+
+ // We don't need the receiver.
+ c.CloseRecv()
+ c.Release()
+
+ return c, nil
+}
+
+func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
+ hostFile, err := e.file.Connect(flags)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ // Dup the fd so that the new endpoint can manage its lifetime.
+ hostFD, err := syscall.Dup(hostFile.FD())
+ if err != nil {
+ log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err)
+ return nil, syserr.FromError(err)
+ }
+ // After duplicating, we no longer need hostFile.
+ hostFile.Close()
+
+ c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
+ if serr != nil {
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr)
+ return nil, serr
+ }
+ return c, nil
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *endpoint) Release() {
+ e.dentry.DecRef()
+}
+
+// Passcred implements transport.BoundEndpoint.Passcred.
+func (e *endpoint) Passcred() bool {
+ return false
+}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
new file mode 100644
index 000000000..c1e6b13e5
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -0,0 +1,245 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device
+// special files, and (when filesystemOptions.specialRegularFiles is in effect)
+// regular files. specialFileFD differs from regularFileFD by using per-FD
+// handles instead of shared per-dentry handles, and never buffering I/O.
+type specialFileFD struct {
+ fileDescription
+
+ // handle is used for file I/O. handle is immutable.
+ handle handle
+
+ // seekable is true if this file description represents a file for which
+ // file offset is significant, i.e. a regular file. seekable is immutable.
+ seekable bool
+
+ // haveQueue is true if this file description represents a file for which
+ // queue may send I/O readiness events. haveQueue is immutable.
+ haveQueue bool
+ queue waiter.Queue
+
+ // If seekable is true, off is the file offset. off is protected by mu.
+ mu sync.Mutex
+ off int64
+}
+
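+// newSpecialFileFD constructs a specialFileFD for the given handle,
+// registering it with fdnotifier when the host FD can generate I/O
+// readiness events.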
+func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
+ ftype := d.fileType()
+ seekable := ftype == linux.S_IFREG
+ haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0
+ fd := &specialFileFD{
+ handle: h,
+ seekable: seekable,
+ haveQueue: haveQueue,
+ }
+ fd.LockFD.Init(locks)
+ if haveQueue {
+ if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
+ return nil, err
+ }
+ }
+ if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: !seekable,
+ DenyPWrite: !seekable,
+ }); err != nil {
+ if haveQueue {
+ fdnotifier.RemoveFD(h.fd)
+ }
+ return nil, err
+ }
+ return fd, nil
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *specialFileFD) Release() {
+ if fd.haveQueue {
+ fdnotifier.RemoveFD(fd.handle.fd)
+ }
+ fd.handle.close(context.Background())
+ fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+ fs.syncMu.Lock()
+ delete(fs.specialFileFDs, fd)
+ fs.syncMu.Unlock()
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *specialFileFD) OnClose(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ return fd.handle.file.flush(ctx)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if fd.haveQueue {
+ return fdnotifier.NonBlockingPoll(fd.handle.fd, mask)
+ }
+ return fd.fileDescription.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ if fd.haveQueue {
+ fd.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(fd.handle.fd)
+ return
+ }
+ fd.fileDescription.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *specialFileFD) EventUnregister(e *waiter.Entry) {
+ if fd.haveQueue {
+ fd.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(fd.handle.fd)
+ return
+ }
+ fd.fileDescription.EventUnregister(e)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if fd.seekable && offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Going through dst.CopyOutFrom() holds MM locks around file operations of
+ // unknown duration. For regularFileFD, doing so is necessary to support
+ // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
+ // hold here since specialFileFD doesn't client-cache data. Just buffer the
+ // read instead.
+ if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+ buf := make([]byte, dst.NumBytes())
+ n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ if err == syserror.EAGAIN {
+ err = syserror.ErrWouldBlock
+ }
+ if n == 0 {
+ return 0, err
+ }
+ if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
+ return int64(cp), cperr
+ }
+ return int64(n), err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if !fd.seekable {
+ return fd.PRead(ctx, dst, -1, opts)
+ }
+
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if fd.seekable && offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if fd.seekable {
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
+ }
+
+ // Do a buffered write. See rationale in PRead.
+ if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ d.touchCMtime()
+ }
+ buf := make([]byte, src.NumBytes())
+ // Don't do partial writes if we get a partial read from src.
+ if _, err := src.CopyIn(ctx, buf); err != nil {
+ return 0, err
+ }
+ n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ if err == syserror.EAGAIN {
+ err = syserror.ErrWouldBlock
+ }
+ return int64(n), err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ if !fd.seekable {
+ return fd.PWrite(ctx, src, -1, opts)
+ }
+
+ fd.mu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ if !fd.seekable {
+ return 0, syserror.ESPIPE
+ }
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence)
+ if err != nil {
+ return 0, err
+ }
+ fd.off = newOffset
+ return newOffset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *specialFileFD) Sync(ctx context.Context) error {
+ return fd.dentry().syncSharedHandle(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
new file mode 100644
index 000000000..2ec819f86
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func (d *dentry) isSymlink() bool {
+ return d.fileType() == linux.S_IFLNK
+}
+
+// Precondition: d.isSymlink().
+func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(mnt)
+ d.dataMu.Lock()
+ if d.haveTarget {
+ target := d.target
+ d.dataMu.Unlock()
+ return target, nil
+ }
+ }
+ target, err := d.file.readlink(ctx)
+ if d.fs.opts.interop != InteropModeShared {
+ if err == nil {
+ d.haveTarget = true
+ d.target = target
+ }
+ d.dataMu.Unlock()
+ }
+ return target, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
new file mode 100644
index 000000000..0eef4e16e
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
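+// dentryTimestampFromP9 converts a 9P timestamp, given as separate seconds
+// and nanoseconds, to the single nanosecond count stored in dentry
+// timestamp fields.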
+func dentryTimestampFromP9(s, ns uint64) int64 {
+ return int64(s*1e9 + ns)
+}
+
+func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 {
+ return ts.Sec*1e9 + int64(ts.Nsec)
+}
+
+func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
+ return linux.StatxTimestamp{
+ Sec: ns / 1e9,
+ Nsec: uint32(ns % 1e9),
+ }
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true.
+func (d *dentry) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.atime, now)
+ d.metadataMu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// successfully called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCtime() {
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.ctime, now)
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// successfully called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCMtime() {
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// locked d.metadataMu.
+func (d *dentry) touchCMtimeLocked() {
+ now := d.fs.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
new file mode 100644
index 000000000..44a09d87a
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "host",
+ srcs = [
+ "control.go",
+ "host.go",
+ "ioctl_unsafe.go",
+ "mmap.go",
+ "socket.go",
+ "socket_iovec.go",
+ "socket_unsafe.go",
+ "tty.go",
+ "util.go",
+ "util_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/hostfd",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/unet",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go
new file mode 100644
index 000000000..b9082a20f
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/control.go
@@ -0,0 +1,96 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type scmRights struct {
+ fds []int
+}
+
+func newSCMRights(fds []int) control.SCMRightsVFS2 {
+ return &scmRights{fds}
+}
+
+// Files implements control.SCMRightsVFS2.Files.
+func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFilesVFS2, bool) {
+ n := max
+ var trunc bool
+ if l := len(c.fds); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+
+ rf := control.RightsFilesVFS2(fdsToFiles(ctx, c.fds[:n]))
+
+ // Only consume converted FDs (fdsToFiles may convert fewer than n FDs).
+ c.fds = c.fds[len(rf):]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (c *scmRights) Clone() transport.RightsControlMessage {
+ // Host rights never need to be cloned.
+ return nil
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (c *scmRights) Release() {
+ for _, fd := range c.fds {
+ syscall.Close(fd)
+ }
+ c.fds = nil
+}
+
+// fdsToFiles converts donated host fds into vfs.FileDescriptions.
+//
+// If an error is encountered, only files created before the error will be
+// returned. This is what Linux does.
+func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription {
+ files := make([]*vfs.FileDescription, 0, len(fds))
+ for _, fd := range fds {
+ // Get flags. We do it here because they may be modified
+ // by subsequent functions.
+ fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0)
+ if errno != 0 {
+ ctx.Warningf("Error retrieving host FD flags: %v", error(errno))
+ break
+ }
+
+ // Create the file backed by hostFD.
+ file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */)
+ if err != nil {
+ ctx.Warningf("Error creating file from host FD: %v", err)
+ break
+ }
+
+ if err := file.SetStatusFlags(ctx, auth.CredentialsFromContext(ctx), uint32(fileFlags&linux.O_NONBLOCK)); err != nil {
+ ctx.Warningf("Error setting flags on host FD file: %v", err)
+ break
+ }
+
+ files = append(files, file)
+ }
+ return files
+}
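+
+// Illustrative usage (assumption: fds are donated host fds): rights are
+// consumed lazily, so a truncated Files call leaves the remaining fds for
+// the next call:
+//
+//    rights := newSCMRights(fds)
+//    files, trunc := rights.Files(ctx, 2) // at most 2 files; trunc if more remain
+//    rights.Release()                     // closes any unconsumed fds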
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
new file mode 100644
index 000000000..1cd2982cb
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -0,0 +1,766 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host provides a filesystem implementation for host files imported as
+// file descriptors.
+package host
+
+import (
+ "fmt"
+ "math"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// NewFDOptions contains options to NewFD.
+type NewFDOptions struct {
+ // If IsTTY is true, the file descriptor is a TTY.
+ IsTTY bool
+
+ // If HaveFlags is true, use Flags for the new file description. Otherwise,
+ // the new file description will inherit flags from hostFD.
+ HaveFlags bool
+ Flags uint32
+}
+
+// NewFD returns a vfs.FileDescription representing the given host file
+// descriptor. mnt must be Kernel.HostMount().
+func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) (*vfs.FileDescription, error) {
+ fs, ok := mnt.Filesystem().Impl().(*filesystem)
+ if !ok {
+ return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl())
+ }
+
+ // Retrieve metadata.
+ var s unix.Stat_t
+ if err := unix.Fstat(hostFD, &s); err != nil {
+ return nil, err
+ }
+
+ flags := opts.Flags
+ if !opts.HaveFlags {
+ // Get flags for the imported FD.
+ flagsInt, err := unix.FcntlInt(uintptr(hostFD), syscall.F_GETFL, 0)
+ if err != nil {
+ return nil, err
+ }
+ flags = uint32(flagsInt)
+ }
+
+ fileMode := linux.FileMode(s.Mode)
+ fileType := fileMode.FileType()
+
+ // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
+ // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
+ // devices.
+ _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
+ seekable := err != syserror.ESPIPE
+
+ i := &inode{
+ hostFD: hostFD,
+ ino: fs.NextIno(),
+ isTTY: opts.IsTTY,
+ wouldBlock: wouldBlock(uint32(fileType)),
+ seekable: seekable,
+ // NOTE(b/38213152): Technically, some obscure char devices can be memory
+ // mapped, but we only allow regular files.
+ canMap: fileType == linux.S_IFREG,
+ }
+ i.pf.inode = i
+
+ // Non-seekable files can't be memory mapped; assert this.
+ if !i.seekable && i.canMap {
+ panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
+ }
+
+ // If the hostFD would block, we must set it to non-blocking and handle
+ // blocking behavior in the sentry.
+ if i.wouldBlock {
+ if err := syscall.SetNonblock(i.hostFD, true); err != nil {
+ return nil, err
+ }
+ if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil {
+ return nil, err
+ }
+ }
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+
+ // i.open will take a reference on d.
+ defer d.DecRef()
+
+ // For simplicity, fileDescription.offset is set to 0. Technically, we
+ // should only set to 0 on files that are not seekable (sockets, pipes,
+ // etc.), and use the offset from the host fd otherwise when importing.
+ return i.open(ctx, d.VFSDentry(), mnt, flags)
+}
+
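+// Illustrative sketch (not part of this change): the seekability probe above
+// hinges on lseek(2) failing with ESPIPE for stream files:
+//
+//    r, _, _ := os.Pipe()
+//    _, err := unix.Seek(int(r.Fd()), 0, linux.SEEK_CUR)
+//    seekable := err != syserror.ESPIPE // false for the pipe
+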
+// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
+func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
+ return NewFD(ctx, mnt, hostFD, &NewFDOptions{
+ IsTTY: isTTY,
+ })
+}
+
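+// Illustrative usage (assumption: k is the sentry Kernel and hostFD was
+// donated to the sandbox):
+//
+//    fd, err := ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */)
+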
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("host.filesystemType.GetFilesystem should never be called")
+}
+
+// Name implements FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "none"
+}
+
+// NewFilesystem sets up and returns a new hostfs filesystem.
+//
+// Note that there should only ever be one instance of host.filesystem,
+// a global mount for host fds.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, err
+ }
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.VFSFilesystem().Init(vfsObj, filesystemType{}, fs)
+ return fs.VFSFilesystem(), nil
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+func (fs *filesystem) Release() {
+ fs.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ d := vd.Dentry().Impl().(*kernfs.Dentry)
+ inode := d.Inode().(*inode)
+ b.PrependComponent(fmt.Sprintf("host:[%d]", inode.ino))
+ return vfs.PrependPathSyntheticError{}
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // When the reference count reaches zero, the host fd is closed.
+ refs.AtomicRefCount
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // seekable is false if the host fd points to a file representing a stream,
+ // e.g. a socket or a pipe. Such files are not seekable and can return
+ // EWOULDBLOCK for I/O operations.
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // wouldBlock is true if the host FD would return EWOULDBLOCK for
+ // operations that would block.
+ //
+ // This field is initialized at creation time and is immutable.
+ wouldBlock bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
+
+ // canMap specifies whether we allow the file to be memory mapped.
+ //
+ // This field is initialized at creation time and is immutable.
+ canMap bool
+
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex
+
+ // If canMap is true, mappings tracks mappings of hostFD into
+ // memmap.MappingSpaces.
+ mappings memmap.MappingSet
+
+ // pf implements platform.File for mappings of hostFD.
+ pf inodePlatformFile
+}
+
+// CheckPermissions implements kernfs.Inode.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return err
+ }
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid))
+}
+
+// Mode implements kernfs.Inode.
+func (i *inode) Mode() linux.FileMode {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ // Retrieving the mode from the host fd using fstat(2) should not fail.
+ // If the syscall does not succeed, something is fundamentally wrong.
+ panic(fmt.Sprintf("failed to retrieve mode from host fd %d: %v", i.hostFD, err))
+ }
+ return linux.FileMode(s.Mode)
+}
+
+// Stat implements kernfs.Inode.
+func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ if opts.Mask&linux.STATX__RESERVED != 0 {
+ return linux.Statx{}, syserror.EINVAL
+ }
+ if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
+ return linux.Statx{}, syserror.EINVAL
+ }
+
+ fs := vfsfs.Impl().(*filesystem)
+
+ // Limit our host call only to known flags.
+ mask := opts.Mask & linux.STATX_ALL
+ var s unix.Statx_t
+ err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s)
+ if err == syserror.ENOSYS {
+ // Fallback to fstat(2), if statx(2) is not supported on the host.
+ //
+ // TODO(b/151263641): Remove fallback.
+ return i.fstat(fs)
+ }
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ // Unconditionally fill blksize, attributes, and device numbers, as
+ // indicated by /include/uapi/linux/stat.h. Inode number is always
+ // available, since we use our own rather than the host's.
+ ls := linux.Statx{
+ Mask: linux.STATX_INO,
+ Blksize: s.Blksize,
+ Attributes: s.Attributes,
+ Ino: i.ino,
+ AttributesMask: s.Attributes_mask,
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: fs.devMinor,
+ }
+
+ // Copy other fields that were returned by the host. RdevMajor/RdevMinor
+ // are never copied (and therefore left as zero), so as not to expose host
+ // device numbers.
+ ls.Mask |= s.Mask & linux.STATX_ALL
+ if s.Mask&linux.STATX_TYPE != 0 {
+ ls.Mode |= s.Mode & linux.S_IFMT
+ }
+ if s.Mask&linux.STATX_MODE != 0 {
+ ls.Mode |= s.Mode &^ linux.S_IFMT
+ }
+ if s.Mask&linux.STATX_NLINK != 0 {
+ ls.Nlink = s.Nlink
+ }
+ if s.Mask&linux.STATX_UID != 0 {
+ ls.UID = s.Uid
+ }
+ if s.Mask&linux.STATX_GID != 0 {
+ ls.GID = s.Gid
+ }
+ if s.Mask&linux.STATX_ATIME != 0 {
+ ls.Atime = unixToLinuxStatxTimestamp(s.Atime)
+ }
+ if s.Mask&linux.STATX_BTIME != 0 {
+ ls.Btime = unixToLinuxStatxTimestamp(s.Btime)
+ }
+ if s.Mask&linux.STATX_CTIME != 0 {
+ ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime)
+ }
+ if s.Mask&linux.STATX_MTIME != 0 {
+ ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime)
+ }
+ if s.Mask&linux.STATX_SIZE != 0 {
+ ls.Size = s.Size
+ }
+ if s.Mask&linux.STATX_BLOCKS != 0 {
+ ls.Blocks = s.Blocks
+ }
+
+ return ls, nil
+}
+
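+// Illustrative usage (not part of this change): callers must consult the
+// returned mask before trusting optional fields:
+//
+//    s, err := i.Stat(vfsfs, opts)
+//    if err == nil && s.Mask&linux.STATX_MTIME != 0 {
+//        // s.Mtime is valid.
+//    }
+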
+// fstat is a best-effort fallback for inode.Stat() if the host does not
+// support statx(2).
+//
+// We ignore the mask and sync flags in opts and simply supply
+// STATX_BASIC_STATS, as fstat(2) itself does not allow the specification
+// of a mask or sync flags. fstat(2) does not provide any metadata
+// equivalent to Statx.Attributes, Statx.AttributesMask, or Statx.Btime, so
+// those fields remain empty.
+func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(i.hostFD, &s); err != nil {
+ return linux.Statx{}, err
+ }
+
+ // As with inode.Stat(), we always use internal device and inode numbers,
+ // and never expose the host's represented device numbers.
+ return linux.Statx{
+ Mask: linux.STATX_BASIC_STATS,
+ Blksize: uint32(s.Blksize),
+ Nlink: uint32(s.Nlink),
+ UID: s.Uid,
+ GID: s.Gid,
+ Mode: uint16(s.Mode),
+ Ino: i.ino,
+ Size: uint64(s.Size),
+ Blocks: uint64(s.Blocks),
+ Atime: timespecToStatxTimestamp(s.Atim),
+ Ctime: timespecToStatxTimestamp(s.Ctim),
+ Mtime: timespecToStatxTimestamp(s.Mtim),
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: fs.devMinor,
+ }, nil
+}
+
+// SetStat implements kernfs.Inode.
+func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ s := opts.Stat
+
+ m := s.Mask
+ if m == 0 {
+ return nil
+ }
+ if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ return syserror.EPERM
+ }
+ var hostStat syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &hostStat); err != nil {
+ return err
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
+ return err
+ }
+
+ if m&linux.STATX_MODE != 0 {
+ if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil {
+ return err
+ }
+ }
+ if m&linux.STATX_SIZE != 0 {
+ if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
+ return err
+ }
+ oldSize := uint64(hostStat.Size)
+ if s.Size < oldSize {
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(s.Size)
+ if oldpgend != newpgend {
+ i.mapsMu.Lock()
+ i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ i.mapsMu.Unlock()
+ }
+ }
+ }
+ if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ ts := [2]syscall.Timespec{
+ toTimespec(s.Atime, m&linux.STATX_ATIME == 0),
+ toTimespec(s.Mtime, m&linux.STATX_MTIME == 0),
+ }
+ if err := setTimestamps(i.hostFD, &ts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
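+// Illustrative usage (assumption: ctx, fs, and creds come from the caller):
+// truncating the file only requires STATX_SIZE in the mask:
+//
+//    err := i.SetStat(ctx, fs, creds, vfs.SetStatOptions{
+//        Stat: linux.Statx{Mask: linux.STATX_SIZE, Size: 0},
+//    })
+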
+// DecRef implements kernfs.Inode.
+func (i *inode) DecRef() {
+ i.AtomicRefCount.DecRefWithDestructor(i.Destroy)
+}
+
+// Destroy implements kernfs.Inode.
+func (i *inode) Destroy() {
+ if i.wouldBlock {
+ fdnotifier.RemoveFD(int32(i.hostFD))
+ }
+ if err := unix.Close(i.hostFD); err != nil {
+ log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
+ }
+}
+
+// Open implements kernfs.Inode.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/.
+ if i.Mode().FileType() == linux.S_IFSOCK {
+ return nil, syserror.ENXIO
+ }
+ return i.open(ctx, vfsd, rp.Mount(), opts.Flags)
+}
+
+func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return nil, err
+ }
+ fileType := s.Mode & linux.FileTypeMask
+
+ // Constrain flags to a subset we can handle.
+ //
+ // TODO(gvisor.dev/issue/2601): Support O_NONBLOCK by adding RWF_NOWAIT to pread/pwrite calls.
+ flags &= syscall.O_ACCMODE | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
+
+ switch fileType {
+ case syscall.S_IFSOCK:
+ if i.isTTY {
+ log.Warningf("cannot use host socket fd %d as TTY", i.hostFD)
+ return nil, syserror.ENOTTY
+ }
+
+ ep, err := newEndpoint(ctx, i.hostFD, &i.queue)
+ if err != nil {
+ return nil, err
+ }
+ // Currently, we only allow Unix sockets to be imported.
+ return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks)
+
+ case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR:
+ if i.isTTY {
+ fd := &TTYFileDescription{
+ fileDescription: fileDescription{inode: i},
+ termios: linux.DefaultSlaveTermios,
+ }
+ fd.LockFD.Init(&i.locks)
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+ }
+
+ fd := &fileDescription{inode: i}
+ fd.LockFD.Init(&i.locks)
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+
+ default:
+ log.Warningf("cannot import host fd %d with file type %o", i.hostFD, fileType)
+ return nil, syserror.EPERM
+ }
+}
+
+// fileDescription is embedded by host fd implementations of FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but
+ // cached to reduce indirections and casting. fileDescription does not hold
+ // a reference on the inode through the inode field (since one is already
+ // held via the Dentry).
+ //
+ // inode is immutable after fileDescription creation.
+ inode *inode
+
+ // offsetMu protects offset.
+ offsetMu sync.Mutex
+
+ // offset specifies the current file offset. It is only meaningful when
+ // inode.seekable is true.
+ offset int64
+}
+
+// SetStat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Release() {
+ // noop
+}
+
+// Allocate implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ if !f.inode.seekable {
+ return syserror.ESPIPE
+ }
+
+ // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds.
+ return syserror.EOPNOTSUPP
+}
+
+// PRead implements FileDescriptionImpl.
+func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags)
+}
+
+// Read implements FileDescriptionImpl.
+func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ if isBlockError(err) {
+ // If we got any data at all, return it as a "completed" partial read
+ // rather than retrying until complete.
+ if n != 0 {
+ err = nil
+ } else {
+ err = syserror.ErrWouldBlock
+ }
+ }
+ return n, err
+ }
+
+ f.offsetMu.Lock()
+ n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags)
+ f.offset += n
+ f.offsetMu.Unlock()
+ return n, err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+ reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
+ n, err := dst.CopyOutFrom(ctx, reader)
+ hostfd.PutReadWriterAt(reader)
+ return int64(n), err
+}
+
+// PWrite implements FileDescriptionImpl.
+func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if !f.inode.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ return f.writeToHostFD(ctx, src, offset, opts.Flags)
+}
+
+// Write implements FileDescriptionImpl.
+func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ n, err := f.writeToHostFD(ctx, src, -1, opts.Flags)
+ if isBlockError(err) {
+ err = syserror.ErrWouldBlock
+ }
+ return n, err
+ }
+
+ f.offsetMu.Lock()
+ // NOTE(gvisor.dev/issue/2983): O_APPEND may cause memory corruption if
+ // another process modifies the host file between retrieving the file size
+ // and writing to the host fd. This is an unavoidable race condition because
+ // we cannot enforce synchronization on the host.
+ if f.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ f.offsetMu.Unlock()
+ return 0, err
+ }
+ f.offset = s.Size
+ }
+ n, err := f.writeToHostFD(ctx, src, f.offset, opts.Flags)
+ f.offset += n
+ f.offsetMu.Unlock()
+ return n, err
+}
+
+func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ hostFD := f.inode.hostFD
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+ writer := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
+ n, err := src.CopyInTo(ctx, writer)
+ hostfd.PutReadWriterAt(writer)
+ // NOTE(gvisor.dev/issue/2979): We always sync everything, even for O_DSYNC.
+ if n > 0 && f.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ if syncErr := unix.Fsync(hostFD); syncErr != nil {
+ return int64(n), syncErr
+ }
+ }
+ return int64(n), err
+}
+
+// Seek implements FileDescriptionImpl.
+//
+// Note that we do not support seeking on directories, since we do not allow
+// directory fds to be imported in the first place.
+func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ f.offsetMu.Lock()
+ defer f.offsetMu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return f.offset, syserror.EINVAL
+ }
+ f.offset = offset
+
+ case linux.SEEK_CUR:
+ // Check for overflow. Note that underflow cannot occur, since f.offset >= 0.
+ if offset > math.MaxInt64-f.offset {
+ return f.offset, syserror.EOVERFLOW
+ }
+ if f.offset+offset < 0 {
+ return f.offset, syserror.EINVAL
+ }
+ f.offset += offset
+
+ case linux.SEEK_END:
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return f.offset, err
+ }
+ size := s.Size
+
+ // Check for overflow. Note that underflow cannot occur, since size >= 0.
+ if offset > math.MaxInt64-size {
+ return f.offset, syserror.EOVERFLOW
+ }
+ if size+offset < 0 {
+ return f.offset, syserror.EINVAL
+ }
+ f.offset = size + offset
+
+ case linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Modifying the offset in the host file table should not matter, since
+ // this is the only place where we use it.
+ //
+ // For reading and writing, we always rely on our internal offset.
+ n, err := unix.Seek(i.hostFD, offset, int(whence))
+ if err != nil {
+ return f.offset, err
+ }
+ f.offset = n
+
+ default:
+ // Invalid whence.
+ return f.offset, syserror.EINVAL
+ }
+
+ return f.offset, nil
+}
+
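+// Illustrative sketch (not part of this change): the overflow checks above
+// make a SEEK_CUR past the int64 range fail rather than wrap:
+//
+//    f.Seek(ctx, 1, linux.SEEK_SET)             // offset = 1
+//    f.Seek(ctx, math.MaxInt64, linux.SEEK_CUR) // returns EOVERFLOW
+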
+// Sync implements FileDescriptionImpl.
+func (f *fileDescription) Sync(context.Context) error {
+ // TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
+ return unix.Fsync(f.inode.hostFD)
+}
+
+// ConfigureMMap implements FileDescriptionImpl.
+func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
+ if !f.inode.canMap {
+ return syserror.ENODEV
+ }
+ i := f.inode
+ i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+ return vfs.GenericConfigureMMap(&f.vfsfd, i, opts)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ f.inode.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (f *fileDescription) EventUnregister(e *waiter.Entry) {
+ f.inode.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+}
+
+// Readiness uses the poll() syscall to check the status of the underlying FD.
+func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/host/ioctl_unsafe.go b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
new file mode 100644
index 000000000..0983bf7d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func ioctlGetTermios(fd int) (*linux.Termios, error) {
+ var t linux.Termios
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &t, nil
+}
+
+func ioctlSetTermios(fd int, req uint64, t *linux.Termios) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(t)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func ioctlGetWinsize(fd int) (*linux.Winsize, error) {
+ var w linux.Winsize
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCGWINSZ, uintptr(unsafe.Pointer(&w)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &w, nil
+}
+
+func ioctlSetWinsize(fd int, w *linux.Winsize) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCSWINSZ, uintptr(unsafe.Pointer(w)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
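+
+// Illustrative usage (assumption: fd is a host TTY fd): resizing a terminal
+// is a get/modify/set round trip through the helpers above:
+//
+//    w, err := ioctlGetWinsize(fd)
+//    if err == nil {
+//        w.Rows = 50
+//        err = ioctlSetWinsize(fd, w)
+//    }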
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
new file mode 100644
index 000000000..8545a82f0
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// inodePlatformFile implements platform.File. It exists solely because inode
+// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef.
+//
+// inodePlatformFile should only be used if inode.canMap is true.
+type inodePlatformFile struct {
+ *inode
+
+ // fdRefsMu protects fdRefs.
+ fdRefsMu sync.Mutex
+
+ // fdRefs counts references on platform.File offsets. It is used solely for
+ // memory accounting.
+ fdRefs fsutil.FrameRefSet
+
+ // fileMapper caches mappings of the host file represented by this inode.
+ fileMapper fsutil.HostFileMapper
+
+ // fileMapperInitOnce is used to lazily initialize fileMapper.
+ fileMapperInitOnce sync.Once
+}
+
+// IncRef implements platform.File.IncRef.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) IncRef(fr platform.FileRange) {
+ i.fdRefsMu.Lock()
+ i.fdRefs.IncRefAndAccount(fr)
+ i.fdRefsMu.Unlock()
+}
+
+// DecRef implements platform.File.DecRef.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) DecRef(fr platform.FileRange) {
+ i.fdRefsMu.Lock()
+ i.fdRefs.DecRefAndAccount(fr)
+ i.fdRefsMu.Unlock()
+}
+
+// MapInternal implements platform.File.MapInternal.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
+}
+
+// FD implements platform.File.FD.
+func (i *inodePlatformFile) FD() int {
+ return i.hostFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ i.mapsMu.Lock()
+ mapped := i.mappings.AddMapping(ms, ar, offset, writable)
+ for _, r := range mapped {
+ i.pf.fileMapper.IncRefOn(r)
+ }
+ i.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ i.mapsMu.Lock()
+ unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ i.pf.fileMapper.DecRefOn(r)
+ }
+ i.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return i.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ mr := optional
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: &i.pf,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) InvalidateUnsavable(ctx context.Context) error {
+ // We expect the same host fd across save/restore, so all translations
+ // should be valid.
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
new file mode 100644
index 000000000..fd16bd92d
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -0,0 +1,385 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// newEndpoint creates a new host-backed transport.Endpoint from the given fd
+// and its corresponding notification queue.
+func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) {
+ // Set up an external transport.Endpoint using the host fd.
+ addr := fmt.Sprintf("hostfd:[%d]", hostFD)
+ e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */)
+ if err != nil {
+ return nil, err.ToError()
+ }
+ ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), queue, e, e)
+ return ep, nil
+}
+
+// ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and
+// transport.Receiver. It is backed by a host fd that was imported at sentry
+// startup. This fd is shared with a hostfs inode, which retains ownership of
+// it.
+//
+// ConnectedEndpoint is saveable, since we expect that the host will provide
+// the same fd upon restore.
+//
+// As of this writing, we only allow Unix sockets to be imported.
+//
+// +stateify savable
+type ConnectedEndpoint struct {
+ // ref keeps track of references to a ConnectedEndpoint.
+ ref refs.AtomicRefCount
+
+ // mu protects fd below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // fd is the host fd backing this endpoint.
+ fd int
+
+ // addr is the address at which this endpoint is bound.
+ addr string
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // on the host.
+ sndbuf int64 `state:"nosave"`
+
+ // stype is the type of Unix socket.
+ stype linux.SockType
+}
+
+// init performs initialization required for creating new ConnectedEndpoints and
+// for restoring them.
+func (c *ConnectedEndpoint) init() *syserr.Error {
+ family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if family != syscall.AF_UNIX {
+ // We only allow Unix sockets.
+ return syserr.ErrInvalidEndpointState
+ }
+
+ stype, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if err := syscall.SetNonblock(c.fd, true); err != nil {
+ return syserr.FromError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ c.stype = linux.SockType(stype)
+ c.sndbuf = int64(sndbuf)
+
+ return nil
+}
+
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host fd
+// imported at sentry startup.
+//
+// The caller is responsible for calling Init(). Additionally, Release needs
+// to be called twice because ConnectedEndpoint is both a transport.Receiver
+// and a transport.ConnectedEndpoint.
+func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+ e := ConnectedEndpoint{
+ fd: hostFD,
+ addr: addr,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+ e.ref.EnableLeakCheck("host.ConnectedEndpoint")
+ return &e, nil
+}
+
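+// Illustrative usage (not part of this change): since the endpoint serves as
+// both transport.Receiver and transport.ConnectedEndpoint, a full teardown
+// takes two Release calls, matching the two references taken above:
+//
+//    e.Release() // drops the Receiver reference
+//    e.Release() // drops the ConnectedEndpoint reference
+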
+// Send implements transport.ConnectedEndpoint.Send.
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if !controlMessages.Empty() {
+ return 0, false, syserr.ErrInvalidEndpointState
+ }
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == linux.SOCK_STREAM
+
+ n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
+ if n > 0 && err != syserror.EAGAIN {
+ // The caller may need to block to send more data, but
+ // otherwise there isn't anything that can be done about an
+ // error with a partial write.
+ err = nil
+ }
+
+ // There is no need for the callee to call SendNotify because fdWriteVec
+ // uses the host's sendmsg(2) and the host kernel's queue.
+ return n, false, syserr.FromError(err)
+}
+
+// SendNotify implements transport.ConnectedEndpoint.SendNotify.
+func (c *ConnectedEndpoint) SendNotify() {}
+
+// CloseSend implements transport.ConnectedEndpoint.CloseSend.
+func (c *ConnectedEndpoint) CloseSend() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.fd, syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
+func (c *ConnectedEndpoint) CloseNotify() {}
+
+// Writable implements transport.ConnectedEndpoint.Writable.
+func (c *ConnectedEndpoint) Writable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements transport.ConnectedEndpoint.Passcred.
+func (c *ConnectedEndpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil
+}
+
+// EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
+func (c *ConnectedEndpoint) EventUpdate() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.fd != -1 {
+ fdnotifier.UpdateFD(int32(c.fd))
+ }
+}
+
+// Recv implements transport.Receiver.Recv.
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ var cm unet.ControlMessage
+ if numRights > 0 {
+ cm.EnableFDs(int(numRights))
+ }
+
+ // N.B. Unix sockets don't have a receive buffer; the send buffer
+ // serves both purposes.
+ rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf)
+ if rl > 0 && err != nil {
+ // We got some data, so all we need to do on error is return
+ // the data that we got. Short reads are fine, no need to
+ // block.
+ err = nil
+ }
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ // There is no need for the callee to call RecvNotify because fdReadVec uses
+ // the host's recvmsg(2) and the host kernel's queue.
+
+ // Trim the control data if we received less than the full amount.
+ if cl < uint64(len(cm)) {
+ cm = cm[:cl]
+ }
+
+ // Avoid extra allocations in the case where there isn't any control data.
+ if len(cm) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+ }
+
+ fds, err := cm.ExtractFDs()
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ if len(fds) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+ }
+ return rl, ml, control.NewVFS2(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+}
+
+// RecvNotify implements transport.Receiver.RecvNotify.
+func (c *ConnectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements transport.Receiver.CloseRecv.
+func (c *ConnectedEndpoint) CloseRecv() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.fd, syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// Readable implements transport.Receiver.Readable.
+func (c *ConnectedEndpoint) Readable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements transport.Receiver.SendQueuedSize.
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
+ return int64(c.sndbuf)
+}
+
+// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
+}
+
+func (c *ConnectedEndpoint) destroyLocked() {
+ c.fd = -1
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (c *ConnectedEndpoint) Release() {
+ c.ref.DecRefWithDestructor(func() {
+ c.mu.Lock()
+ c.destroyLocked()
+ c.mu.Unlock()
+ })
+}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
+
+// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
+// passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the
+// following differences:
+// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
+// the same descriptor number across S/R.
+// - SCMConnectedEndpoint holds ownership of its fd and notification queue.
+type SCMConnectedEndpoint struct {
+ ConnectedEndpoint
+
+ queue *waiter.Queue
+}
+
+// Init performs the initialization that must be done without holding other
+// locks.
+func (e *SCMConnectedEndpoint) Init() error {
+ return fdnotifier.AddFD(int32(e.fd), e.queue)
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (e *SCMConnectedEndpoint) Release() {
+ e.ref.DecRefWithDestructor(func() {
+ e.mu.Lock()
+ if err := syscall.Close(e.fd); err != nil {
+ log.Warningf("Failed to close host fd %d: %v", err)
+ }
+ fdnotifier.RemoveFD(int32(e.fd))
+ e.destroyLocked()
+ e.mu.Unlock()
+ })
+}
+
+// NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that
+// was passed through a Unix socket.
+//
+// The caller is responsible for calling Init(). Additionally, Release needs
+// to be called twice because SCMConnectedEndpoint is both a
+// transport.Receiver and a transport.ConnectedEndpoint.
+func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
+ e := SCMConnectedEndpoint{
+ ConnectedEndpoint: ConnectedEndpoint{
+ fd: hostFD,
+ addr: addr,
+ },
+ queue: queue,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+ e.ref.EnableLeakCheck("host.SCMConnectedEndpoint")
+ return &e, nil
+}
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
new file mode 100644
index 000000000..584c247d2
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate is true, bufs totaling more than maxlen bytes are truncated;
+// otherwise an error is returned immediately.
+//
+// If the returned length is less than the total length of bufs, err indicates
+// why, even when a truncated iovec is returned.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += int64(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > maxlen {
+ if truncate {
+ stopLen = maxlen
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total int64
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := int64(l)
+ if total+stop > stopLen {
+ stop = stopLen - total
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += stop
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
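+
+// Worked example (not part of this change): with bufs of 3 and 5 bytes,
+// maxlen = 6, and truncate = true, buildIovec yields two iovecs of 3 bytes
+// each, length 6, and err == EAGAIN to flag the truncation:
+//
+//    bufs := [][]byte{make([]byte, 3), make([]byte, 5)}
+//    n, iovs, _, err := buildIovec(bufs, 6, true)
+//    // n == 6, len(iovs) == 2, err == syserror.EAGAIN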
diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go
new file mode 100644
index 000000000..35ded24bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_unsafe.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) {
+ flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
+ if peek {
+ flags |= syscall.MSG_PEEK
+ }
+
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial read to do, return error immediately.
+ return 0, 0, 0, false, err
+ }
+
+ var msg syscall.Msghdr
+ if len(control) != 0 {
+ msg.Control = &control[0]
+ msg.Controllen = uint64(len(control))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, 0, 0, false, e
+ }
+ n := int64(rawN)
+
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
+ controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
+ if n > length {
+ return length, n, msg.Controllen, controlTrunc, err
+ }
+
+ return n, n, msg.Controllen, controlTrunc, err
+}
+
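+// Illustrative note (not part of this change): because MSG_TRUNC is always
+// set, msgLen reports the full datagram size even when the read was capped:
+//
+//    readLen, msgLen, _, _, _ := fdReadVec(fd, bufs, nil, false, maxlen)
+//    // For a 64-byte datagram and 16 bytes of bufs: readLen == 16, msgLen == 64.
+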
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
+
+ var msg syscall.Msghdr
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
+ }
+
+ return int64(n), length, err
+}
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
new file mode 100644
index 000000000..4ee9270cc
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TTYFileDescription implements vfs.FileDescriptionImpl for a host file
+// descriptor that wraps a TTY FD.
+type TTYFileDescription struct {
+ fileDescription
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // session is the session attached to this TTYFileDescription.
+ session *kernel.Session
+
+ // fgProcessGroup is the foreground process group that is currently
+ // connected to this TTY.
+ fgProcessGroup *kernel.ProcessGroup
+
+ // termios contains the terminal attributes for this TTY.
+ termios linux.KernelTermios
+}
+
+// InitForegroundProcessGroup sets the foreground process group and session for
+// the TTY. This should only be called once, after the foreground process group
+// has been created, but before it has started running.
+func (t *TTYFileDescription) InitForegroundProcessGroup(pg *kernel.ProcessGroup) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.fgProcessGroup != nil {
+ panic("foreground process group is already set")
+ }
+ t.fgProcessGroup = pg
+ t.session = pg.Session()
+}
+
+// ForegroundProcessGroup returns the foreground process group for the TTY.
+func (t *TTYFileDescription) ForegroundProcessGroup() *kernel.ProcessGroup {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.fgProcessGroup
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (t *TTYFileDescription) Release() {
+ t.mu.Lock()
+ t.fgProcessGroup = nil
+ t.mu.Unlock()
+
+ t.fileDescription.Release()
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *TTYFileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *TTYFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.Write(ctx, src, opts)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Ignore args[0]; the host FD stored in the inode is the real terminal FD.
+ fd := t.inode.hostFD
+ ioctl := args[1].Uint64()
+ switch ioctl {
+ case linux.TCGETS:
+ termios, err := ioctlGetTermios(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TCSETS, linux.TCSETSW, linux.TCSETSF:
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+
+ var termios linux.Termios
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
+ return 0, err
+
+ case linux.TIOCGPGRP:
+ // Args: pid_t *argp
+ // When successful, equivalent to *argp = tcgetpgrp(fd).
+ // Get the process group ID of the foreground process group on this
+ // terminal.
+
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Map the ProcessGroup into a ProcessGroupID in the task's PID namespace.
+ pgID := pidns.IDOfProcessGroup(t.fgProcessGroup)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSPGRP:
+ // Args: const pid_t *argp
+ // Equivalent to tcsetpgrp(fd, *argp).
+ // Set the foreground process group ID of this terminal.
+
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check that we are allowed to set the process group.
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change()
+ // to -ENOTTY.
+ if err == syserror.EIO {
+ return 0, syserror.ENOTTY
+ }
+ return 0, err
+ }
+
+ // Check that calling task's process group is in the TTY session.
+ if task.ThreadGroup().Session() != t.session {
+ return 0, syserror.ENOTTY
+ }
+
+ var pgID kernel.ProcessGroupID
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ // pgID must be non-negative.
+ if pgID < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Process group with pgID must exist in this PID namespace.
+ pidns := task.PIDNamespace()
+ pg := pidns.ProcessGroupWithID(pgID)
+ if pg == nil {
+ return 0, syserror.ESRCH
+ }
+
+ // Check that new process group is in the TTY session.
+ if pg.Session() != t.session {
+ return 0, syserror.EPERM
+ }
+
+ t.fgProcessGroup = pg
+ return 0, nil
+
+ case linux.TIOCGWINSZ:
+ // Args: struct winsize *argp
+ // Get window size.
+ winsize, err := ioctlGetWinsize(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSWINSZ:
+ // Args: const struct winsize *argp
+ // Set window size.
+
+ // Unlike setting the termios, any process group (even background ones) can
+ // set the winsize.
+
+ var winsize linux.Winsize
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetWinsize(fd, &winsize)
+ return 0, err
+
+ // Unimplemented commands.
+ case linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ fallthrough
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// checkChange checks that the process group is allowed to read, write, or
+// change the state of the TTY.
+//
+// This corresponds to Linux drivers/tty/tty_io.c:tty_check_change(). The logic
+// is a bit convoluted, but documented inline.
+//
+// Preconditions: t.mu must be held.
+func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ // No task? Linux does not have an analog for this case, but
+ // tty_check_change only blocks specific cases and is
+ // surprisingly permissive. Allowing the change seems
+ // appropriate.
+ return nil
+ }
+
+ tg := task.ThreadGroup()
+ pg := tg.ProcessGroup()
+
+ // If the session for the task is different than the session for the
+ // controlling TTY, then the change is allowed. Seems like a bad idea,
+// but that's exactly what Linux does.
+ if tg.Session() != t.fgProcessGroup.Session() {
+ return nil
+ }
+
+ // If we are the foreground process group, then the change is allowed.
+ if pg == t.fgProcessGroup {
+ return nil
+ }
+
+ // We are not the foreground process group.
+
+ // Is the provided signal blocked or ignored?
+ if (task.SignalMask()&linux.SignalSetOf(sig) != 0) || tg.SignalHandlers().IsIgnored(sig) {
+ // If the signal is SIGTTIN, then we are attempting to read
+ // from the TTY. Don't send the signal and return EIO.
+ if sig == linux.SIGTTIN {
+ return syserror.EIO
+ }
+
+ // Otherwise, we are writing or changing terminal state. This is allowed.
+ return nil
+ }
+
+ // If the process group is an orphan, return EIO.
+ if pg.IsOrphan() {
+ return syserror.EIO
+ }
+
+ // Otherwise, send the signal to the process group and return ERESTARTSYS.
+ //
+ // Note that Linux also unconditionally sets TIF_SIGPENDING on current,
+ // but this isn't necessary in gVisor because the rationale given in
+ // 040b6362d58f "tty: fix leakage of -ERESTARTSYS to userland" doesn't
+ // apply: the sentry will handle -ERESTARTSYS in
+ // kernel.runApp.execute() even if the kernel.Task isn't interrupted.
+ //
+ // Linux ignores the result of kill_pgrp().
+ _ = pg.SendSignal(kernel.SignalInfoPriv(sig))
+ return kernel.ERESTARTSYS
+}
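+
+// In summary, checkChange has three outcomes: nil when the task is in the
+// foreground process group, belongs to a different session, or has the
+// (non-SIGTTIN) signal blocked or ignored; EIO when a blocked or ignored
+// SIGTTIN is requested or the process group is orphaned; and ERESTARTSYS
+// after the signal has been sent to the task's process group.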
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
new file mode 100644
index 000000000..412bdb2eb
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func toTimespec(ts linux.StatxTimestamp, omit bool) syscall.Timespec {
+ if omit {
+ return syscall.Timespec{
+ Sec: 0,
+ Nsec: unix.UTIME_OMIT,
+ }
+ }
+ return syscall.Timespec{
+ Sec: ts.Sec,
+ Nsec: int64(ts.Nsec),
+ }
+}
+
+func unixToLinuxStatxTimestamp(ts unix.StatxTimestamp) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: ts.Sec, Nsec: ts.Nsec}
+}
+
+func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
+}
+
+// wouldBlock returns true for file types that can return EWOULDBLOCK
+// for blocking operations, e.g. pipes, character devices, and sockets.
+func wouldBlock(fileType uint32) bool {
+ return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
+}
+
+// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
+// If so, it can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK
+}
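+
+// Callers are expected to perform that transformation themselves, e.g.:
+//
+//	if isBlockError(err) {
+//		err = syserror.ErrWouldBlock
+//	}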
diff --git a/pkg/sentry/fsimpl/host/util_unsafe.go b/pkg/sentry/fsimpl/host/util_unsafe.go
new file mode 100644
index 000000000..5136ac844
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/util_unsafe.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
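+// setTimestamps calls utimensat(2) with a NULL path, a Linux-specific
+// extension that applies ts to the file referred to by fd itself.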
+func setTimestamps(fd int, ts *[2]syscall.Timespec) error {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(fd),
+ 0, /* path */
+ uintptr(unsafe.Pointer(ts)),
+ 0, /* flags */
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
new file mode 100644
index 000000000..179df6c1e
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -0,0 +1,75 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "kernfs",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "Dentry",
+ },
+)
+
+go_template_instance(
+ name = "slot_list",
+ out = "slot_list.go",
+ package = "kernfs",
+ prefix = "slot",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*slot",
+ "Linker": "*slot",
+ },
+)
+
+go_library(
+ name = "kernfs",
+ srcs = [
+ "dynamic_bytes_file.go",
+ "fd_impl_util.go",
+ "filesystem.go",
+ "fstree.go",
+ "inode_impl_util.go",
+ "kernfs.go",
+ "slot_list.go",
+ "symlink.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "kernfs_test",
+ size = "small",
+ srcs = ["kernfs_test.go"],
+ deps = [
+ ":kernfs",
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
new file mode 100644
index 000000000..6886b0876
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -0,0 +1,147 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// DynamicBytesFile implements kernfs.Inode and represents a read-only
+// file whose contents are backed by a vfs.DynamicBytesSource.
+//
+// Must be instantiated with NewDynamicBytesFile or initialized with Init
+// before first use.
+//
+// +stateify savable
+type DynamicBytesFile struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeNotDirectory
+ InodeNotSymlink
+
+ locks vfs.FileLocks
+ data vfs.DynamicBytesSource
+}
+
+var _ Inode = (*DynamicBytesFile)(nil)
+
+// Init initializes a dynamic bytes file.
+func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("perm contains non-permission bits: %x", perm))
+ }
+ f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ f.data = data
+}
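+
+// An illustrative instantiation (src is assumed to implement
+// vfs.DynamicBytesSource):
+//
+//	f := &DynamicBytesFile{}
+//	f.Init(creds, devMajor, devMinor, ino, src, 0444)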
+
+// Open implements Inode.Open.
+func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &DynamicBytesFD{}
+ if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements Inode.SetStat. By default, DynamicBytesFile doesn't
+// allow inode attributes to be changed. To permit changes, override SetStat()
+// so that it delegates to f.InodeAttrs.
+func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a
+// DynamicBytesFile.
+//
+// Must be initialized with Init before first use.
+//
+// +stateify savable
+type DynamicBytesFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.DynamicBytesFileDescriptionImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+ inode Inode
+}
+
+// Init initializes a DynamicBytesFD.
+func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error {
+ fd.LockFD.Init(locks)
+ if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return err
+ }
+ fd.inode = d.Impl().(*Dentry).inode
+ fd.SetDataSource(data)
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Write(ctx, src, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.PWrite(ctx, src, offset, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DynamicBytesFD) Release() {}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error {
+ // DynamicBytesFiles are immutable.
+ return syserror.EPERM
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
new file mode 100644
index 000000000..ca8b8c63b
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -0,0 +1,252 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory
+// inode that uses OrderedChildren to track child nodes. GenericDirectoryFD is
+// not compatible with dynamic directories.
+//
+// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling
+// the IterDirents callback. The IterDirents callback therefore cannot hash or
+// unhash children, or recursively call IterDirents on the same underlying
+// inode.
+//
+// Must be initialized with Init before first use.
+//
+// Lock ordering: mu => children.mu.
+type GenericDirectoryFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+ children *OrderedChildren
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // off is the current directory offset. Protected by "mu".
+ off int64
+}
+
+// NewGenericDirectoryFD creates, initializes, and returns a new
+// GenericDirectoryFD.
+func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) {
+ fd := &GenericDirectoryFD{}
+ if err := fd.Init(children, locks, opts); err != nil {
+ return nil, err
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return fd, nil
+}
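+
+// An illustrative Open implementation (assuming an inode type i that embeds
+// OrderedChildren and has a vfs.FileLocks field named locks):
+//
+//	fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+//	if err != nil {
+//		return nil, err
+//	}
+//	return fd.VFSFileDescription(), nil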
+
+// Init initializes a GenericDirectoryFD. Use it when embedding or overriding
+// GenericDirectoryFD; the caller must then separately call
+// fd.VFSFileDescription().Init() with the correct implementation.
+func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error {
+ if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
+ // Can't open directories for writing.
+ return syserror.EISDIR
+ }
+ fd.LockFD.Init(locks)
+ fd.children = children
+ return nil
+}
+
+// VFSFileDescription returns a pointer to the vfs.FileDescription representing
+// this object.
+func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription {
+ return &fd.vfsfd
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *GenericDirectoryFD) Release() {}
+
+func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
+ return fd.vfsfd.VirtualDentry().Mount().Filesystem()
+}
+
+func (fd *GenericDirectoryFD) inode() Inode {
+ return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents
+// holds fd.mu when calling cb.
+func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ opts := vfs.StatOptions{Mask: linux.STATX_INO}
+ // Handle ".".
+ if fd.off == 0 {
+ stat, err := fd.inode().Stat(fd.filesystem(), opts)
+ if err != nil {
+ return err
+ }
+ dirent := vfs.Dirent{
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: stat.Ino,
+ NextOff: 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ // Handle "..".
+ if fd.off == 1 {
+ vfsd := fd.vfsfd.VirtualDentry().Dentry()
+ parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
+ stat, err := parentInode.Stat(fd.filesystem(), opts)
+ if err != nil {
+ return err
+ }
+ dirent := vfs.Dirent{
+ Name: "..",
+ Type: linux.FileMode(stat.Mode).DirentType(),
+ Ino: stat.Ino,
+ NextOff: 2,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ // Handle static children.
+ fd.children.mu.RLock()
+ defer fd.children.mu.RUnlock()
+ // fd.off accounts for "." and "..", but fd.children does not track
+ // them.
+ childIdx := fd.off - 2
+ for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
+ inode := it.Dentry.Impl().(*Dentry).inode
+ stat, err := inode.Stat(fd.filesystem(), opts)
+ if err != nil {
+ return err
+ }
+ dirent := vfs.Dirent{
+ Name: it.Name,
+ Type: linux.FileMode(stat.Mode).DirentType(),
+ Ino: stat.Ino,
+ NextOff: fd.off + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
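+ // Handle any dynamic children. relOffset is the offset within the
+ // dynamic portion, i.e. fd.off with ".", "..", and the static children
+ // subtracted out.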
+ var err error
+ relOffset := fd.off - int64(len(fd.children.set)) - 2
+ fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ return err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ // TODO(gvisor.dev/issue/1193): This can prevent new files from showing up
+ // if they are added after SEEK_END.
+ offset = math.MaxInt64
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.filesystem()
+ inode := fd.inode()
+ return inode.Stat(fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+ return inode.SetStat(ctx, fd.filesystem(), creds, opts)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
new file mode 100644
index 000000000..8939871c1
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -0,0 +1,840 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+// This file implements vfs.FilesystemImpl for kernfs.
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// stepExistingLocked resolves rp.Component() in parent directory vfsd.
+//
+// stepExistingLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
+//
+// Postcondition: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) {
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ // Directory searchable?
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ // Revalidation must be skipped if name is "." or ".."; d or its parent
+ // respectively can't be expected to transition from invalidated back to
+ // valid, so detecting invalidation and retrying would loop forever. This
+ // is consistent with Linux: fs/namei.c:walk_component() => lookup_fast()
+ // calls d_revalidate(), but walk_component() => handle_dots() does not.
+ if name == "." {
+ rp.Advance()
+ return vfsd, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return vfsd, nil
+ }
+ if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return &d.parent.vfsd, nil
+ }
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ d.dirMu.Lock()
+ next, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, d.children[name])
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.CheckMount(&next.vfsd); err != nil {
+ return nil, err
+ }
+ // Resolve any symlink at current path component.
+ if mayFollowSymlinks && rp.ShouldFollowSymlink() && next.isSymlink() {
+ targetVD, targetPathname, err := next.inode.Getlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
+ return nil, err
+ }
+ }
+ goto afterSymlink
+ }
+ rp.Advance()
+ return &next.vfsd, nil
+}
+
+// revalidateChildLocked must be called after a call to parent.vfsd.Child(name)
+// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be
+// nil) to verify that the returned child (or lack thereof) is correct.
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+// parent.dirMu must be locked. parent.isDir(). name is not "." or "..".
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, child *Dentry) (*Dentry, error) {
+ if child != nil {
+ // Cached dentry exists, revalidate.
+ if !child.inode.Valid(ctx) {
+ delete(parent.children, name)
+ vfsObj.InvalidateDentry(&child.vfsd)
+ fs.deferDecRef(&child.vfsd) // Reference from Lookup.
+ child = nil
+ }
+ }
+ if child == nil {
+ // Dentry isn't cached; it either doesn't exist or failed
+ // revalidation. Attempt to resolve it via Lookup.
+ //
+ // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return
+ // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes
+ // that all dentries in the filesystem are (kernfs.)Dentry and performs
+ // vfs.DentryImpl casts accordingly.
+ childVFSD, err := parent.inode.Lookup(ctx, name)
+ if err != nil {
+ return nil, err
+ }
+ // The reference on childVFSD taken by Lookup is dropped when the
+ // dentry later fails revalidation (see deferDecRef above).
+ child = childVFSD.Impl().(*Dentry)
+ parent.insertChildLocked(name, child)
+ }
+ return child, nil
+}
+
+// walkExistingLocked resolves rp to an existing file.
+//
+// walkExistingLocked is loosely analogous to Linux's
+// fs/namei.c:path_lookupat().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
+ vfsd := rp.Start()
+ for !rp.Done() {
+ var err error
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ d := vfsd.Impl().(*Dentry)
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, d.inode, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory. It does not check that the returned directory is
+// searchable by the provider of rp.
+//
+// walkParentDirLocked is loosely analogous to Linux's
+// fs/namei.c:path_parentat().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
+ vfsd := rp.Start()
+ for !rp.Final() {
+ var err error
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, d.inode, nil
+}
+
+// checkCreateLocked checks that a file named rp.Component() may be created in
+// directory parentVFSD, then returns rp.Component().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. parentInode
+// == parentVFSD.Impl().(*Dentry).inode. parentInode must be a directory.
+func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return "", err
+ }
+ pc := rp.Component()
+ if pc == "." || pc == ".." {
+ return "", syserror.EEXIST
+ }
+ if len(pc) > linux.NAME_MAX {
+ return "", syserror.ENAMETOOLONG
+ }
+ // FIXME(gvisor.dev/issue/1193): Data race due to not holding dirMu.
+ if _, ok := parentVFSD.Impl().(*Dentry).children[pc]; ok {
+ // Return pc along with EEXIST so that callers such as RenameAt can
+ // locate the existing child.
+ return pc, syserror.EEXIST
+ }
+ if parentVFSD.IsDead() {
+ return "", syserror.ENOENT
+ }
+ return pc, nil
+}
+
+// checkDeleteLocked checks that the file represented by vfsd may be deleted.
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
+ parent := vfsd.Impl().(*Dentry).parent
+ if parent == nil {
+ return syserror.EBUSY
+ }
+ if parent.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *Filesystem) Release() {
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *Filesystem) Sync(ctx context.Context) error {
+ // All filesystem state is in-memory.
+ return nil
+}
+
+// AccessAt implements vfs.FilesystemImpl.AccessAt.
+func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
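+ // N.B. Go runs defers LIFO, so processDeferredDecRefs (registered
+ // first) runs after fs.mu has been released.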
+
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ return inode.CheckPermissions(ctx, creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.CheckSearchable {
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ vfsd.IncRef() // Ownership transferred to caller.
+ return vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
+ vfsd, _, err := fs.walkParentDirLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ vfsd.IncRef() // Ownership transferred to caller.
+ return vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+
+ d := vd.Dentry().Impl().(*Dentry)
+ if d.isDir() {
+ return syserror.EPERM
+ }
+
+ childVFSD, err := parentInode.NewLink(ctx, pc, d.inode)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ childVFSD, err := parentInode.NewDir(ctx, pc, opts)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ newVFSD, err := parentInode.NewNode(ctx, pc, opts)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, newVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Filter out flags that are not supported by kernfs. O_DIRECTORY and
+ // O_NOFOLLOW have no effect here (they're handled by VFS by setting
+ // appropriate bits in rp), but are returned by
+ // FileDescriptionImpl.StatusFlags().
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
+ ats := vfs.AccessTypesForOpenFlags(&opts)
+
+ // Do not create new file.
+ if opts.Flags&linux.O_CREAT == 0 {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return inode.Open(ctx, rp, vfsd, opts)
+ }
+
+ // May create new file.
+ mustCreate := opts.Flags&linux.O_EXCL != 0
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*Dentry).inode
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ if rp.Done() {
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return inode.Open(ctx, rp, vfsd, opts)
+ }
+afterTrailingSymlink:
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ pc := rp.Component()
+ if pc == "." || pc == ".." {
+ return nil, syserror.EISDIR
+ }
+ if len(pc) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ // Determine whether or not we need to create a file.
+ childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD, false /* mayFollowSymlinks */)
+ if err == syserror.ENOENT {
+ // Already checked for searchability above; now check for writability.
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer rp.Mount().EndWrite()
+ // Create and open the child.
+ childVFSD, err = parentInode.NewFile(ctx, pc, opts)
+ if err != nil {
+ return nil, err
+ }
+ child := childVFSD.Impl().(*Dentry)
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ return child.inode.Open(ctx, rp, childVFSD, opts)
+ }
+ if err != nil {
+ return nil, err
+ }
+ // Open existing file or follow symlink.
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ child := childVFSD.Impl().(*Dentry)
+ if rp.ShouldFollowSymlink() && child.isSymlink() {
+ targetVD, targetPathname, err := child.inode.Getlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
+ return nil, err
+ }
+ }
+ // rp.Final() may no longer be true since we now need to resolve the
+ // symlink target.
+ goto afterTrailingSymlink
+ }
+ if err := child.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return child.inode.Open(ctx, rp, &child.vfsd, opts)
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ fs.mu.RLock()
+ d, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return "", err
+ }
+ if !d.Impl().(*Dentry).isSymlink() {
+ return "", syserror.EINVAL
+ }
+ return inode.Readlink(ctx)
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ // Only RENAME_NOREPLACE is supported.
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return syserror.EINVAL
+ }
+ noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
+
+ fs.mu.Lock()
+ defer fs.processDeferredDecRefsLocked()
+ defer fs.mu.Unlock()
+
+ // Resolve the destination directory first to verify that it's on this
+ // Mount.
+ dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ dstDir := dstDirVFSD.Impl().(*Dentry)
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ srcDirVFSD := oldParentVD.Dentry()
+ srcDir := srcDirVFSD.Impl().(*Dentry)
+ srcDir.dirMu.Lock()
+ src, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), srcDir, oldName, srcDir.children[oldName])
+ srcDir.dirMu.Unlock()
+ if err != nil {
+ return err
+ }
+ srcVFSD := &src.vfsd
+
+ // Can we remove the src dentry?
+ if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil {
+ return err
+ }
+
+ // Can we create the dst dentry?
+ var dst *Dentry
+ pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode)
+ switch err {
+ case nil:
+ // The destination does not exist; continue with a plain rename.
+ case syserror.EEXIST:
+ if noReplace {
+ // Won't overwrite existing node since RENAME_NOREPLACE was requested.
+ return syserror.EEXIST
+ }
+ dst = dstDir.children[pc]
+ if dst == nil {
+ panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD))
+ }
+ default:
+ return err
+ }
+ var dstVFSD *vfs.Dentry
+ if dst != nil {
+ dstVFSD = &dst.vfsd
+ }
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ virtfs := rp.VirtualFilesystem()
+
+ // We can't deadlock here due to lock ordering because we're protected from
+ // concurrent renames by fs.mu held for writing.
+ srcDir.dirMu.Lock()
+ defer srcDir.dirMu.Unlock()
+ if srcDir != dstDir {
+ dstDir.dirMu.Lock()
+ defer dstDir.dirMu.Unlock()
+ }
+
+ if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil {
+ return err
+ }
+ replaced, err := srcDir.inode.Rename(ctx, src.name, pc, srcVFSD, dstDirVFSD)
+ if err != nil {
+ virtfs.AbortRenameDentry(srcVFSD, dstVFSD)
+ return err
+ }
+ delete(srcDir.children, src.name)
+ if srcDir != dstDir {
+ fs.deferDecRef(srcDirVFSD)
+ dstDir.IncRef()
+ }
+ src.parent = dstDir
+ src.name = pc
+ if dstDir.children == nil {
+ dstDir.children = make(map[string]*Dentry)
+ }
+ dstDir.children[pc] = src
+ virtfs.CommitRenameReplaceDentry(srcVFSD, replaced)
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
+ return err
+ }
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return syserror.ENOTDIR
+ }
+ if inode.HasChildren() {
+ return syserror.ENOTEMPTY
+ }
+ virtfs := rp.VirtualFilesystem()
+ parentDentry := d.parent
+ parentDentry.dirMu.Lock()
+ defer parentDentry.dirMu.Unlock()
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
+ return err
+ }
+ if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil {
+ virtfs.AbortDeleteDentry(vfsd)
+ return err
+ }
+ virtfs.CommitDeleteDentry(vfsd)
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts)
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ return inode.Stat(fs.VFSFilesystem(), opts)
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ // TODO(gvisor.dev/issue/1193): actually implement statfs.
+ return linux.Statfs{}, syserror.ENOSYS
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ childVFSD, err := parentInode.NewSymlink(ctx, pc, target)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ vfsd, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
+ return err
+ }
+ d := vfsd.Impl().(*Dentry)
+ if d.isDir() {
+ return syserror.EISDIR
+ }
+ virtfs := rp.VirtualFilesystem()
+ parentDentry := d.parent
+ parentDentry.dirMu.Lock()
+ defer parentDentry.dirMu.Unlock()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
+ return err
+ }
+ if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil {
+ virtfs.AbortDeleteDentry(vfsd)
+ return err
+ }
+ virtfs.CommitDeleteDentry(vfsd)
+ return nil
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return nil, err
+ }
+ // kernfs currently does not support extended attributes.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return "", err
+ }
+ // kernfs currently does not support extended attributes.
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*Dentry), b)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
new file mode 100644
index 000000000..4cb885d87
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -0,0 +1,613 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// InodeNoopRefCount partially implements the Inode interface, specifically the
+// inodeRefs sub interface. InodeNoopRefCount provides no-op implementations of
+// the reference-counting methods, performing no extra actions when references
+// are obtained or released. This is suitable for simple file inodes that don't
+// reference any resources.
+type InodeNoopRefCount struct {
+}
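+
+// For example (illustrative only), a read-only leaf inode can embed the
+// no-op mixins to satisfy most of the Inode interface without boilerplate:
+//
+//	type staticFileInode struct {
+//		InodeAttrs
+//		InodeNoopRefCount
+//		InodeNotDirectory
+//		InodeNotSymlink
+//	}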
+
+// IncRef implements Inode.IncRef.
+func (InodeNoopRefCount) IncRef() {
+}
+
+// DecRef implements Inode.DecRef.
+func (InodeNoopRefCount) DecRef() {
+}
+
+// TryIncRef implements Inode.TryIncRef.
+func (InodeNoopRefCount) TryIncRef() bool {
+ return true
+}
+
+// Destroy implements Inode.Destroy.
+func (InodeNoopRefCount) Destroy() {
+}
+
+// InodeDirectoryNoNewChildren partially implements the Inode interface.
+// InodeDirectoryNoNewChildren represents a directory inode which does not
+// support creation of new children.
+type InodeDirectoryNoNewChildren struct{}
+
+// NewFile implements Inode.NewFile.
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewDir implements Inode.NewDir.
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewLink implements Inode.NewLink.
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewNode implements Inode.NewNode.
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// InodeNotDirectory partially implements the Inode interface, specifically the
+// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not
+// represent directories can embed this to provide no-op implementations for
+// directory-related functions.
+type InodeNotDirectory struct {
+}
+
+// HasChildren implements Inode.HasChildren.
+func (InodeNotDirectory) HasChildren() bool {
+ return false
+}
+
+// NewFile implements Inode.NewFile.
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+ panic("NewFile called on non-directory inode")
+}
+
+// NewDir implements Inode.NewDir.
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+ panic("NewDir called on non-directory inode")
+}
+
+// NewLink implements Inode.NewLink.
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+ panic("NewLink called on non-directory inode")
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ panic("NewSymlink called on non-directory inode")
+}
+
+// NewNode implements Inode.NewNode.
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ panic("NewNode called on non-directory inode")
+}
+
+// Unlink implements Inode.Unlink.
+func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+ panic("Unlink called on non-directory inode")
+}
+
+// RmDir implements Inode.RmDir.
+func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+ panic("RmDir called on non-directory inode")
+}
+
+// Rename implements Inode.Rename.
+func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+ panic("Rename called on non-directory inode")
+}
+
+// Lookup implements Inode.Lookup.
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ panic("Lookup called on non-directory inode")
+}
+
+// IterDirents implements Inode.IterDirents.
+func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ panic("IterDirents called on non-directory inode")
+}
+
+// Valid implements Inode.Valid.
+func (InodeNotDirectory) Valid(context.Context) bool {
+ return true
+}
+
+// InodeNoDynamicLookup partially implements the Inode interface, specifically
+// the inodeDynamicLookup sub interface. Directory inodes that do not support
+// dynamic entries (i.e. entries that are not "hashed" into the
+// vfs.Dentry.children) can embed this to provide no-op implementations for
+// functions related to dynamic entries.
+type InodeNoDynamicLookup struct{}
+
+// Lookup implements Inode.Lookup.
+func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ return nil, syserror.ENOENT
+}
+
+// IterDirents implements Inode.IterDirents.
+func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return offset, nil
+}
+
+// Valid implements Inode.Valid.
+func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
+ return true
+}
+
+// InodeNotSymlink partially implements the Inode interface, specifically the
+// inodeSymlink sub interface. All inodes that are not symlinks may embed this
+// to return the appropriate errors from symlink-related functions.
+type InodeNotSymlink struct{}
+
+// Readlink implements Inode.Readlink.
+func (InodeNotSymlink) Readlink(context.Context) (string, error) {
+ return "", syserror.EINVAL
+}
+
+// Getlink implements Inode.Getlink.
+func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, "", syserror.EINVAL
+}
+
+// InodeAttrs partially implements the Inode interface, specifically the
+// inodeMetadata sub interface. InodeAttrs provides functionality related to
+// inode attributes.
+//
+// Must be initialized by Init prior to first use.
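+//
+// A typical embedder calls Init from its own initializer, for example
+// (illustrative sketch; fileInode is hypothetical):
+//
+//   func (i *fileInode) Init(creds *auth.Credentials, ino uint64) {
+//       i.InodeAttrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, ino, linux.ModeRegular|0644)
+//   }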
+type InodeAttrs struct {
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+ mode uint32
+ uid uint32
+ gid uint32
+ nlink uint32
+}
+
+// Init initializes this InodeAttrs.
+func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
+ }
+
+ nlink := uint32(1)
+ if mode.FileType() == linux.ModeDirectory {
+ nlink = 2
+ }
+ a.devMajor = devMajor
+ a.devMinor = devMinor
+ atomic.StoreUint64(&a.ino, ino)
+ atomic.StoreUint32(&a.mode, uint32(mode))
+ atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
+ atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
+ atomic.StoreUint32(&a.nlink, nlink)
+}
+
+// DevMajor returns the device major number.
+func (a *InodeAttrs) DevMajor() uint32 {
+ return a.devMajor
+}
+
+// DevMinor returns the device minor number.
+func (a *InodeAttrs) DevMinor() uint32 {
+ return a.devMinor
+}
+
+// Ino returns the inode id.
+func (a *InodeAttrs) Ino() uint64 {
+ return atomic.LoadUint64(&a.ino)
+}
+
+// Mode implements Inode.Mode.
+func (a *InodeAttrs) Mode() linux.FileMode {
+ return linux.FileMode(atomic.LoadUint32(&a.mode))
+}
+
+// Stat partially implements Inode.Stat. Note that this function doesn't provide
+// all the stat fields, and the embedder should consider extending the result
+// with filesystem-specific fields.
+func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+ stat.DevMajor = a.devMajor
+ stat.DevMinor = a.devMinor
+ stat.Ino = atomic.LoadUint64(&a.ino)
+ stat.Mode = uint16(a.Mode())
+ stat.UID = atomic.LoadUint32(&a.uid)
+ stat.GID = atomic.LoadUint32(&a.gid)
+ stat.Nlink = atomic.LoadUint32(&a.nlink)
+
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+
+ return stat, nil
+}
+
+// SetStat implements Inode.SetStat.
+func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
+ return syserror.EPERM
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ return err
+ }
+
+ stat := opts.Stat
+ if stat.Mask&linux.STATX_MODE != 0 {
+ for {
+ old := atomic.LoadUint32(&a.mode)
+ new := old | uint32(stat.Mode & ^uint16(linux.S_IFMT))
+ if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped {
+ break
+ }
+ }
+ }
+
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&a.uid, stat.UID)
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&a.gid, stat.GID)
+ }
+
+ // Note that not all fields are modifiable. For example, the file type and
+ // inode numbers are immutable after node creation.
+
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+ // Also, STATX_SIZE will need some special handling, because read-only static
+ // files should return EIO for truncate operations.
+
+ return nil
+}
+
+// CheckPermissions implements Inode.CheckPermissions.
+func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(
+ creds,
+ ats,
+ a.Mode(),
+ auth.KUID(atomic.LoadUint32(&a.uid)),
+ auth.KGID(atomic.LoadUint32(&a.gid)),
+ )
+}
+
+// IncLinks implements Inode.IncLinks.
+func (a *InodeAttrs) IncLinks(n uint32) {
+ if atomic.AddUint32(&a.nlink, n) <= n {
+ panic("InodeLink.IncLinks called with no existing links")
+ }
+}
+
+// DecLinks implements Inode.DecLinks.
+func (a *InodeAttrs) DecLinks() {
+ if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) {
+ // Negative overflow
+ panic("Inode.DecLinks called at 0 links")
+ }
+}
+
+type slot struct {
+ Name string
+ Dentry *vfs.Dentry
+ slotEntry
+}
+
+// OrderedChildrenOptions contains initialization options for OrderedChildren.
+type OrderedChildrenOptions struct {
+ // Writable indicates whether vfs.FilesystemImpl methods implemented by
+ // OrderedChildren may modify the tracked children. This applies to
+ // operations related to rename, unlink and rmdir. If an OrderedChildren is
+ // not writable, these operations all fail with EPERM.
+ Writable bool
+}
+
+// OrderedChildren partially implements the Inode interface. OrderedChildren can
+// be embedded in directory inodes to keep track of the children in the
+// directory, and can then be used to implement a generic directory FD -- see
+// GenericDirectoryFD. OrderedChildren is not compatible with dynamic
+// directories.
+//
+// Must be initialized with Init before first use.
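+//
+// For example, a directory that allows unlink/rmdir/rename initializes it as
+// (compare kernfs_test.go):
+//
+//   dir.OrderedChildren.Init(OrderedChildrenOptions{Writable: true})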
+type OrderedChildren struct {
+ refs.AtomicRefCount
+
+ // Can children be modified by user syscalls? If set to false, interface
+ // methods that would modify the children return EPERM. Immutable.
+ writable bool
+
+ mu sync.RWMutex
+ order slotList
+ set map[string]*slot
+}
+
+// Init initializes an OrderedChildren.
+func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
+ o.writable = opts.Writable
+ o.set = make(map[string]*slot)
+}
+
+// DecRef implements Inode.DecRef.
+func (o *OrderedChildren) DecRef() {
+ o.AtomicRefCount.DecRefWithDestructor(o.Destroy)
+}
+
+// Destroy cleans up resources referenced by this OrderedChildren.
+func (o *OrderedChildren) Destroy() {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ o.order.Reset()
+ o.set = nil
+}
+
+// Populate inserts children into this OrderedChildren and into d's dentry
+// cache. Populate returns the number of directories inserted, which the caller
+// may use to update the link count for the parent directory.
+//
+// Precondition: d must represent a directory inode. children must not contain
+// any conflicting entries already in o.
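+//
+// Typical usage pairs Populate with IncLinks, as in NewStaticDir below:
+//
+//   links := inode.OrderedChildren.Populate(dentry, children)
+//   inode.IncLinks(links)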
+func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 {
+ var links uint32
+ for name, child := range children {
+ if child.isDir() {
+ links++
+ }
+ if err := o.Insert(name, child.VFSDentry()); err != nil {
+ panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d))
+ }
+ d.InsertChild(name, child)
+ }
+ return links
+}
+
+// HasChildren implements Inode.HasChildren.
+func (o *OrderedChildren) HasChildren() bool {
+ o.mu.RLock()
+ defer o.mu.RUnlock()
+ return len(o.set) > 0
+}
+
+// Insert inserts child into o. This ignores the writability of o, as this is
+// not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
+func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if _, ok := o.set[name]; ok {
+ return syserror.EEXIST
+ }
+ s := &slot{
+ Name: name,
+ Dentry: child,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return nil
+}
+
+// Precondition: caller must hold o.mu for writing.
+func (o *OrderedChildren) removeLocked(name string) {
+ if s, ok := o.set[name]; ok {
+ delete(o.set, name)
+ o.order.Remove(s)
+ }
+}
+
+// Precondition: caller must hold o.mu for writing.
+func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry {
+ if s, ok := o.set[name]; ok {
+ // Existing slot with given name, simply replace the dentry.
+ var old *vfs.Dentry
+ old, s.Dentry = s.Dentry, new
+ return old
+ }
+
+ // No existing slot with given name, create and hash new slot.
+ s := &slot{
+ Name: name,
+ Dentry: new,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return nil
+}
+
+// Precondition: caller must hold o.mu for reading or writing.
+func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error {
+ s, ok := o.set[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if s.Dentry != child {
+ panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child))
+ }
+ return nil
+}
+
+// Unlink implements Inode.Unlink.
+func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error {
+ if !o.writable {
+ return syserror.EPERM
+ }
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if err := o.checkExistingLocked(name, child); err != nil {
+ return err
+ }
+
+ // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
+ o.removeLocked(name)
+ return nil
+}
+
+// RmDir implements Inode.RmDir.
+func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error {
+ // We're not responsible for checking that child is a directory, that it's
+ // empty, or updating any link counts; so this is the same as unlink.
+ return o.Unlink(ctx, name, child)
+}
+
+type renameAcrossDifferentImplementationsError struct{}
+
+func (renameAcrossDifferentImplementationsError) Error() string {
+ return "rename across inodes with different implementations"
+}
+
+// Rename implements Inode.Rename.
+//
+// Precondition: Rename may only be called across two directory inodes with
+// identical implementations of Rename. Practically, this means filesystems that
+// implement Rename by embedding OrderedChildren for any directory
+// implementation must use OrderedChildren for all directory implementations
+// that will support Rename.
+//
+// Postcondition: reference on any replaced dentry transferred to caller.
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) {
+ dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren)
+ if !ok {
+ return nil, renameAcrossDifferentImplementationsError{}
+ }
+ if !o.writable || !dst.writable {
+ return nil, syserror.EPERM
+ }
+ // Note: There's a potential deadlock below if concurrent calls to Rename
+ // refer to the same src and dst directories in reverse. We avoid any
+ // ordering issues because the caller is required to serialize concurrent
+ // calls to Rename in accordance with the interface declaration.
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if dst != o {
+ dst.mu.Lock()
+ defer dst.mu.Unlock()
+ }
+ if err := o.checkExistingLocked(oldname, child); err != nil {
+ return nil, err
+ }
+
+ // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
+ replaced := dst.replaceChildLocked(newname, child)
+ return replaced, nil
+}
+
+// nthLocked returns an iterator to the nth child tracked by this object. The
+// iterator is valid until the caller releases o.mu. Returns nil if the
+// requested index falls out of bounds.
+//
+// Precondition: Caller must hold o.mu for reading.
+func (o *OrderedChildren) nthLocked(i int64) *slot {
+ for it := o.order.Front(); it != nil && i >= 0; it = it.Next() {
+ if i == 0 {
+ return it
+ }
+ i--
+ }
+ return nil
+}
+
+// InodeSymlink partially implements Inode interface for symlinks.
+type InodeSymlink struct {
+ InodeNotDirectory
+}
+
+// Open implements Inode.Open.
+func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.ELOOP
+}
+
+// StaticDirectory is a standard implementation of a directory with static
+// contents.
+//
+// +stateify savable
+type StaticDirectory struct {
+ InodeNotSymlink
+ InodeDirectoryNoNewChildren
+ InodeAttrs
+ InodeNoDynamicLookup
+ OrderedChildren
+
+ locks vfs.FileLocks
+}
+
+var _ Inode = (*StaticDirectory)(nil)
+
+// NewStaticDir creates a new static directory and returns its dentry.
+func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry {
+ inode := &StaticDirectory{}
+ inode.Init(creds, devMajor, devMinor, ino, perm)
+
+ dentry := &Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, children)
+ inode.IncLinks(links)
+
+ return dentry
+}
+
+// Init initializes StaticDirectory.
+func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
+}
+
+// Open implements Inode.Open.
+func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// SetStat implements Inode.SetStat, disallowing inode attribute changes.
+func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+type AlwaysValid struct{}
+
+// Valid implements kernfs.inodeDynamicLookup.Valid.
+func (*AlwaysValid) Valid(context.Context) bool {
+ return true
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
new file mode 100644
index 000000000..596de1edf
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -0,0 +1,456 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kernfs provides the tools to implement inode-based filesystems.
+// Kernfs has two main features:
+//
+// 1. The Inode interface, which maps VFS2's path-based filesystem operations to
+// specific filesystem nodes. Kernfs uses the Inode interface to provide a
+// blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as
+// the synchronization mechanism for all filesystem operations by holding a
+// filesystem-wide lock across all operations.
+//
+// 2. Various utility types which provide generic implementations for various
+// parts of the Inode and vfs.FileDescription interfaces. Client filesystems
+// based on kernfs can embed the appropriate set of these to avoid having to
+// reimplement common filesystem operations. See inode_impl_util.go and
+// fd_impl_util.go.
+//
+// Reference Model:
+//
+// Kernfs dentries represent named pointers to inodes. Dentries and inodes
+// have independent lifetimes and reference counts. A child dentry unconditionally
+// holds a reference on its parent directory's dentry. A dentry also holds a
+// reference on the inode it points to. Multiple dentries can point to the same
+// inode (for example, in the case of hardlinks). File descriptors hold a
+// reference to the dentry they're opened on.
+//
+// Dentries are guaranteed to exist while holding Filesystem.mu for
+// reading. Dropping dentries requires holding Filesystem.mu for writing. To
+// queue dentries for destruction from a read critical section, see
+// Filesystem.deferDecRef.
+//
+// Lock ordering:
+//
+// kernfs.Filesystem.mu
+// kernfs.Dentry.dirMu
+// vfs.VirtualFilesystem.mountMu
+// vfs.Dentry.mu
+// kernfs.Filesystem.droppedDentriesMu
+// (inode implementation locks, if any)
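+//
+// As an example of intended use, a client filesystem can compose the utility
+// types in this package to declare a read-only directory inode (compare the
+// test fixtures in kernfs_test.go):
+//
+//   type readonlyDir struct {
+//       kernfs.InodeAttrs
+//       kernfs.InodeNotSymlink
+//       kernfs.InodeNoDynamicLookup
+//       kernfs.InodeDirectoryNoNewChildren
+//       kernfs.OrderedChildren
+//   }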
+package kernfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
+// filesystem. Concrete implementations are expected to embed this in their own
+// Filesystem type.
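+//
+// For example (compare the test filesystem in kernfs_test.go):
+//
+//   type filesystem struct {
+//       kernfs.Filesystem
+//   }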
+type Filesystem struct {
+ vfsfs vfs.Filesystem
+
+ droppedDentriesMu sync.Mutex
+
+ // droppedDentries is a list of dentries waiting to be DecRef()ed. This is
+ // used to defer dentry destruction until mu can be acquired for
+ // writing. Protected by droppedDentriesMu.
+ droppedDentries []*vfs.Dentry
+
+ // mu synchronizes the lifetime of Dentries on this filesystem. Holding it
+ // for reading guarantees continued existence of any resolved dentries, but
+ // the dentry tree may be modified.
+ //
+ // Kernfs dentries can only be DecRef()ed while holding mu for writing. For
+ // example:
+ //
+ // fs.mu.Lock()
+ // defer fs.mu.Unlock()
+ // ...
+ // dentry1.DecRef()
+ // defer dentry2.DecRef() // Ok, will run before Unlock.
+ //
+ // If discarding dentries in a read context, use Filesystem.deferDecRef. For
+ // example:
+ //
+ // fs.mu.RLock()
+ // ...
+ // fs.deferDecRef(dentry)
+ // fs.mu.RUnlock()
+ // fs.processDeferredDecRefs()
+ mu sync.RWMutex
+
+ // nextInoMinusOne is used to allocate inode numbers on this
+ // filesystem. Must be accessed by atomic operations.
+ nextInoMinusOne uint64
+}
+
+// deferDecRef defers dropping a dentry ref until the next call to
+// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
+//
+// Precondition: d must not already be pending destruction.
+func (fs *Filesystem) deferDecRef(d *vfs.Dentry) {
+ fs.droppedDentriesMu.Lock()
+ fs.droppedDentries = append(fs.droppedDentries, d)
+ fs.droppedDentriesMu.Unlock()
+}
+
+// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
+// droppedDentries list. See comment on Filesystem.mu.
+func (fs *Filesystem) processDeferredDecRefs() {
+ fs.mu.Lock()
+ fs.processDeferredDecRefsLocked()
+ fs.mu.Unlock()
+}
+
+// Precondition: fs.mu must be held for writing.
+func (fs *Filesystem) processDeferredDecRefsLocked() {
+ fs.droppedDentriesMu.Lock()
+ for _, d := range fs.droppedDentries {
+ d.DecRef()
+ }
+ fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
+ fs.droppedDentriesMu.Unlock()
+}
+
+// VFSFilesystem returns the generic vfs filesystem object.
+func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem {
+ return &fs.vfsfs
+}
+
+// NextIno allocates a new inode number on this filesystem.
+func (fs *Filesystem) NextIno() uint64 {
+ return atomic.AddUint64(&fs.nextInoMinusOne, 1)
+}
+
+// These consts are used in the Dentry.flags field.
+const (
+ // Dentry points to a directory inode.
+ dflagsIsDir = 1 << iota
+
+ // Dentry points to a symlink inode.
+ dflagsIsSymlink
+)
+
+// Dentry implements vfs.DentryImpl.
+//
+// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a
+// named reference to an inode. A dentry generally lives as long as it's part of
+// a mounted filesystem tree. Kernfs doesn't cache dentries once all references
+// to them are removed. Dentries hold a single reference to the inode they point
+// to, and child dentries hold a reference on their parent.
+//
+// Must be initialized by Init prior to first use.
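+//
+// A sketch of construction, as in NewStaticDir and the tests:
+//
+//   d := &Dentry{}
+//   d.Init(inode) // Transfers the caller's reference on inode to d.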
+type Dentry struct {
+ vfsd vfs.Dentry
+
+ refs.AtomicRefCount
+
+ // flags caches useful information about the dentry from the inode. See the
+ // dflags* consts above. Must be accessed by atomic ops.
+ flags uint32
+
+ parent *Dentry
+ name string
+
+ // dirMu protects children and the names of child Dentries.
+ dirMu sync.Mutex
+ children map[string]*Dentry
+
+ inode Inode
+}
+
+// Init initializes this dentry.
+//
+// Precondition: Caller must hold a reference on inode.
+//
+// Postcondition: Caller's reference on inode is transferred to the dentry.
+func (d *Dentry) Init(inode Inode) {
+ d.vfsd.Init(d)
+ d.inode = inode
+ ftype := inode.Mode().FileType()
+ if ftype == linux.ModeDirectory {
+ d.flags |= dflagsIsDir
+ }
+ if ftype == linux.ModeSymlink {
+ d.flags |= dflagsIsSymlink
+ }
+}
+
+// VFSDentry returns the generic vfs dentry for this kernfs dentry.
+func (d *Dentry) VFSDentry() *vfs.Dentry {
+ return &d.vfsd
+}
+
+// isDir checks whether the dentry points to a directory inode.
+func (d *Dentry) isDir() bool {
+ return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0
+}
+
+// isSymlink checks whether the dentry points to a symlink inode.
+func (d *Dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef() {
+ d.AtomicRefCount.DecRefWithDestructor(d.destroy)
+}
+
+// Precondition: Dentry must be removed from VFS' dentry cache.
+func (d *Dentry) destroy() {
+ d.inode.DecRef() // IncRef from Init.
+ d.inode = nil
+ if d.parent != nil {
+ d.parent.DecRef() // IncRef from Dentry.InsertChild.
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+//
+// Although Linux technically supports inotify on pseudo filesystems (inotify
+// is implemented at the vfs layer), it is not particularly useful. It is left
+// unimplemented until someone actually needs it.
+func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *Dentry) Watches() *vfs.Watches {
+ return nil
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+func (d *Dentry) OnZeroWatches() {}
+
+// InsertChild inserts child into the vfs dentry cache with the given name under
+// this dentry. This does not update the directory inode, so calling this on
+// its own isn't sufficient to insert a child into a directory. InsertChild
+// takes a reference on d, which is dropped when child is destroyed.
+//
+// Precondition: d must represent a directory inode.
+func (d *Dentry) InsertChild(name string, child *Dentry) {
+ d.dirMu.Lock()
+ d.insertChildLocked(name, child)
+ d.dirMu.Unlock()
+}
+
+// insertChildLocked is equivalent to InsertChild, with additional
+// preconditions.
+//
+// Precondition: d.dirMu must be locked.
+func (d *Dentry) insertChildLocked(name string, child *Dentry) {
+ if !d.isDir() {
+ panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d))
+ }
+ d.IncRef() // DecRef in child's Dentry.destroy.
+ child.parent = d
+ child.name = name
+ if d.children == nil {
+ d.children = make(map[string]*Dentry)
+ }
+ d.children[name] = child
+}
+
+// Inode returns the dentry's inode.
+func (d *Dentry) Inode() Inode {
+ return d.inode
+}
+
+// The Inode interface maps filesystem-level operations that operate on paths to
+// equivalent operations on specific filesystem nodes.
+//
+// The interface methods are grouped into logical categories as sub interfaces
+// below. Generally, an implementation for each sub interface can be provided by
+// embedding an appropriate type from inode_impl_util.go. The sub interfaces
+// are purely organizational. Methods declared directly in the main interface
+// have no generic implementations, and should be explicitly provided by the
+// client filesystem.
+//
+// Generally, implementations are not responsible for tasks that are common to
+// all filesystems. These include:
+//
+// - Checking that dentries passed to methods are of the appropriate file type.
+// - Checking permissions.
+// - Updating link and reference counts.
+//
+// Specific responsibilities of implementations are documented below.
+type Inode interface {
+ // Methods related to reference counting. A generic implementation is
+ // provided by InodeNoopRefCount. These methods are generally called by the
+ // equivalent Dentry methods.
+ inodeRefs
+
+ // Methods related to node metadata. A generic implementation is provided by
+ // InodeAttrs.
+ inodeMetadata
+
+ // Methods for inodes that represent symlinks. InodeNotSymlink provides a
+ // blanket implementation for all non-symlink inodes.
+ inodeSymlink
+
+ // Methods for inodes that represent directories. InodeNotDirectory provides
+ // a blanket implementation for all non-directory inodes.
+ inodeDirectory
+
+ // Methods for inodes that represent dynamic directories and their
+ // children. InodeNoDynamicLookup provides a blanket implementation for all
+ // non-dynamic-directory inodes.
+ inodeDynamicLookup
+
+ // Open creates a file description for the filesystem object represented by
+ // this inode. The returned file description should hold a reference on the
+ // inode for its lifetime.
+ //
+ // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing
+ // the inode on which Open() is being called.
+ Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
+}
+
+type inodeRefs interface {
+ IncRef()
+ DecRef()
+ TryIncRef() bool
+ // Destroy is called when the inode reaches zero references. Destroy
+ // releases all resources (references) on objects referenced by the inode,
+ // any child dentries.
+ Destroy()
+}
+
+type inodeMetadata interface {
+ // CheckPermissions checks that creds may access this inode for the
+ // requested access type, per the rules of
+ // fs/namei.c:generic_permission().
+ CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error
+
+ // Mode returns the (struct stat)::st_mode value for this inode. This is
+ // separated from Stat for performance.
+ Mode() linux.FileMode
+
+ // Stat returns the metadata for this inode. This corresponds to
+ // vfs.FilesystemImpl.StatAt.
+ Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
+
+ // SetStat updates the metadata for this inode. This corresponds to
+ // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking
+ // if the operation can be performed (see vfs.CheckSetStat() for common
+ // checks).
+ SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error
+}
+
+// Precondition: All methods in this interface may only be called on directory
+// inodes.
+type inodeDirectory interface {
+ // The New{File,Dir,Node,Symlink} methods below should return a new inode
+ // hashed into this inode.
+ //
+ // These inode constructors are inode-level operations rather than
+ // filesystem-level operations to allow client filesystems to mix different
+ // implementations based on the new node's location in the
+ // filesystem.
+
+ // HasChildren returns true if the directory inode has any children.
+ HasChildren() bool
+
+ // NewFile creates a new regular file inode.
+ NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error)
+
+ // NewDir creates a new directory inode.
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error)
+
+ // NewLink creates a new hardlink to a specified inode in this
+ // directory. Implementations should create a new kernfs Dentry pointing to
+ // target, and update target's link count.
+ NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error)
+
+ // NewSymlink creates a new symbolic link inode.
+ NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error)
+
+ // NewNode creates a new filesystem node for a mknod syscall.
+ NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error)
+
+ // Unlink removes a child dentry from this directory inode.
+ Unlink(ctx context.Context, name string, child *vfs.Dentry) error
+
+ // RmDir removes an empty child directory from this directory
+ // inode. Implementations must update the parent directory's link count,
+ // if required. Implementations are not responsible for checking that child
+ // is a directory, or that the directory is empty.
+ RmDir(ctx context.Context, name string, child *vfs.Dentry) error
+
+ // Rename is called on the source directory containing an inode being
+ // renamed. child should point to the resolved child in the source
+ // directory. If Rename replaces a dentry in the destination directory, it
+ // should return the replaced dentry or nil otherwise.
+ //
+ // Precondition: Caller must serialize concurrent calls to Rename.
+ Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error)
+}
+
+type inodeDynamicLookup interface {
+ // Lookup should return an appropriate dentry if name should resolve to a
+ // child of this dynamic directory inode. This gives the directory an
+ // opportunity on every lookup to resolve additional entries that aren't
+ // hashed into the directory. This is only called when the inode is a
+ // directory. If the inode is not a directory, or if the directory only
+ // contains a static set of children, the implementer can unconditionally
+ // return an appropriate error (ENOTDIR and ENOENT respectively).
+ //
+ // The child returned by Lookup will be hashed into the VFS dentry tree. Its
+ // lifetime can be controlled by the filesystem implementation with an
+ // appropriate implementation of Valid.
+ //
+ // Lookup returns the child with an extra reference and the caller owns this
+ // reference.
+ Lookup(ctx context.Context, name string) (*vfs.Dentry, error)
+
+ // Valid should return true if this inode is still valid, or false if it
+ // needs to be resolved again by a call to Lookup.
+ Valid(ctx context.Context) bool
+
+ // IterDirents is used to iterate over dynamically created entries. It invokes
+ // cb on each entry in the directory represented by the FileDescription.
+ // 'offset' is the offset for the entire IterDirents call, which may include
+ // results from the caller (e.g. "." and ".."). 'relOffset' is the offset
+ // inside the entries returned by this IterDirents invocation. In other words,
+ // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as
+ // the return value, while 'relOffset' is the place to start iteration.
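+ //
+ // For example, if the caller has already produced entries for "." and
+ // "..", the first invocation typically sees offset=2 and relOffset=0; if
+ // it emits three entries before returning, the next invocation sees
+ // offset=5 and relOffset=3.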
+ IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
+}
+
+type inodeSymlink interface {
+ // Readlink returns the target of a symbolic link. If an inode is not a
+ // symlink, the implementation should return EINVAL.
+ Readlink(ctx context.Context) (string, error)
+
+ // Getlink returns the target of a symbolic link, as used by path
+ // resolution:
+ //
+ // - If the inode is a "magic link" (a link whose target is most accurately
+ // represented as a VirtualDentry), Getlink returns (ok VirtualDentry, "",
+ // nil). A reference is taken on the returned VirtualDentry.
+ //
+ // - If the inode is an ordinary symlink, Getlink returns (zero-value
+ // VirtualDentry, symlink target, nil).
+ //
+ // - If the inode is not a symlink, Getlink returns (zero-value
+ // VirtualDentry, "", EINVAL).
+ Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
new file mode 100644
index 000000000..dc407eb1d
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -0,0 +1,330 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs_test
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const defaultMode linux.FileMode = 01777
+const staticFileContent = "This is sample content for a static test file."
+
+// RootDentryFn is a generator function for creating the root dentry of a test
+// filesystem. See newTestSystem.
+type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
+
+// newTestSystem sets up a minimal environment for running a test, including an
+// instance of a test filesystem. Tests can control the contents of the
+// filesystem by providing an appropriate rootFn, which should return a
+// pre-populated root dentry.
+func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+ v := &vfs.VirtualFilesystem{}
+ if err := v.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("Failed to create testfs root mount: %v", err)
+ }
+ return testutil.NewSystem(ctx, t, v, mns)
+}
+
+type fsType struct {
+ rootFn RootDentryFn
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+type file struct {
+ kernfs.DynamicBytesFile
+ content string
+}
+
+func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
+ f := &file{}
+ f.content = content
+ f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(f)
+ return d
+}
+
+func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%s", f.content)
+ return nil
+}
+
+type attrs struct {
+ kernfs.InodeAttrs
+}
+
+func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+type readonlyDir struct {
+ attrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &readonlyDir{}
+ dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ dir.dentry.Init(dir)
+
+ dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
+
+ return &dir.dentry
+}
+
+func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+type dir struct {
+ attrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoDynamicLookup
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ fs *filesystem
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &dir{}
+ dir.fs = fs
+ dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
+ dir.dentry.Init(dir)
+
+ dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
+
+ return &dir.dentry
+}
+
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ dir := d.fs.newDir(creds, opts.Mode, nil)
+ dirVFSD := dir.VFSDentry()
+ if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil {
+ dir.DecRef()
+ return nil, err
+ }
+ d.IncLinks(1)
+ return dirVFSD, nil
+}
+
+func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ f := d.fs.newFile(creds, "")
+ fVFSD := f.VFSDentry()
+ if err := d.OrderedChildren.Insert(name, fVFSD); err != nil {
+ f.DecRef()
+ return nil, err
+ }
+ return fVFSD, nil
+}
+
+func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (fsType) Name() string {
+ return "kernfs"
+}
+
+func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fs := &filesystem{}
+ fs.VFSFilesystem().Init(vfsObj, &fst, fs)
+ root := fst.rootFn(creds, fs)
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// -------------------- Remainder of the file is test cases --------------------
+
+func TestBasic(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ defer sys.Destroy()
+ sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef()
+}
+
+func TestMkdirGetDentry(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir1": fs.newDir(creds, 0755, nil),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("dir1/a new directory")
+ if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil {
+ t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err)
+ }
+ sys.GetDentryOrDie(pop).DecRef()
+}
+
+func TestReadStaticFile(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("file1")
+ fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef()
+
+ content, err := sys.ReadToEnd(fd)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ if diff := cmp.Diff(staticFileContent, content); diff != "" {
+ t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+func TestCreateNewFileInStaticDir(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir1": fs.newDir(creds, 0755, nil),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("dir1/newfile")
+ opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode}
+ fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, opts)
+ if err != nil {
+ t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err)
+ }
+
+ // Close the file. The file should persist.
+ fd.DecRef()
+
+ fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
+ }
+ fd.DecRef()
+}
+
+func TestDirFDReadWrite(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, nil)
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("/")
+ fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef()
+
+ // Read/Write should fail for directory FDs.
+ if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR {
+ t.Fatalf("Read for directory FD failed with unexpected error: %v", err)
+ }
+ if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EBADF {
+ t.Fatalf("Write for directory FD failed with unexpected error: %v", err)
+ }
+}
+
+func TestDirFDIterDirents(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ // Fill root with nodes backed by various inode implementations.
+ "dir1": fs.newReadonlyDir(creds, 0755, nil),
+ "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir3": fs.newDir(creds, 0755, nil),
+ }),
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("/")
+ sys.AssertAllDirentTypes(sys.ListDirents(pop), map[string]testutil.DirentType{
+ "dir1": linux.DT_DIR,
+ "dir2": linux.DT_DIR,
+ "file1": linux.DT_REG,
+ })
+}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
new file mode 100644
index 000000000..2ab3f53fd
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// StaticSymlink provides an Inode implementation for symlinks that point to
+// an immutable target.
+type StaticSymlink struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeSymlink
+
+ target string
+}
+
+var _ Inode = (*StaticSymlink)(nil)
+
+// NewStaticSymlink creates a new symlink file pointing to 'target'.
+func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) *Dentry {
+ inode := &StaticSymlink{}
+ inode.Init(creds, devMajor, devMinor, ino, target)
+
+ d := &Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Init initializes the instance.
+func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
+ s.target = target
+ s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
+}
+
+// Readlink implements Inode.Readlink.
+func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
+ return s.target, nil
+}
+
+// Getlink implements Inode.Getlink.
+func (s *StaticSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, s.target, nil
+}
+
+// SetStat implements Inode.SetStat, disallowing inode attribute changes.
+func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
new file mode 100644
index 000000000..8cf5b35d3
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "overlay",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "overlay",
+ srcs = [
+ "copy_up.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "non_directory.go",
+ "overlay.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
new file mode 100644
index 000000000..8f8dcfafe
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isCopiedUp() bool {
+ return atomic.LoadUint32(&d.copiedUp) != 0
+}
+
+// copyUpLocked ensures that d exists on the upper layer, i.e. d.upperVD.Ok().
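+//
+// Copy-up proceeds top-down: d's parent directory is copied up first, then d
+// itself is replicated on the upper layer (file contents for regular files,
+// the target for symlinks, device numbers for device nodes), and finally
+// ownership and, for non-directories, the device/inode numbers are
+// propagated.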
+//
+// Preconditions: filesystem.renameMu must be locked.
+func (d *dentry) copyUpLocked(ctx context.Context) error {
+ // Fast path.
+ if d.isCopiedUp() {
+ return nil
+ }
+
+ ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT
+ switch ftype {
+ case linux.S_IFREG, linux.S_IFDIR, linux.S_IFLNK, linux.S_IFBLK, linux.S_IFCHR:
+ // Can be copied-up.
+ default:
+ // Can't be copied-up.
+ return syserror.EPERM
+ }
+
+ // Ensure that our parent directory is copied-up.
+ if d.parent == nil {
+ // d is a filesystem root with no upper layer.
+ return syserror.EROFS
+ }
+ if err := d.parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ if d.upperVD.Ok() {
+ // Raced with another call to d.copyUpLocked().
+ return nil
+ }
+ if d.vfsd.IsDead() {
+ // Raced with deletion of d.
+ return syserror.ENOENT
+ }
+
+ // Perform copy-up.
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ newpop := vfs.PathOperation{
+ Root: d.parent.upperVD,
+ Start: d.parent.upperVD,
+ Path: fspath.Parse(d.name),
+ }
+ cleanupUndoCopyUp := func() {
+ var err error
+ if ftype == linux.S_IFDIR {
+ err = vfsObj.RmdirAt(ctx, d.fs.creds, &newpop)
+ } else {
+ err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop)
+ }
+ if err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)
+ }
+ }
+ switch ftype {
+ case linux.S_IFREG:
+ oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ return err
+ }
+ defer oldFD.DecRef()
+ newFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &newpop, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ })
+ if err != nil {
+ return err
+ }
+ defer newFD.DecRef()
+ bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
+ for {
+ readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
+ if readErr != nil && readErr != io.EOF {
+ cleanupUndoCopyUp()
+ return readErr
+ }
+ total := int64(0)
+ for total < readN {
+ writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
+ total += writeN
+ if writeErr != nil {
+ cleanupUndoCopyUp()
+ return writeErr
+ }
+ }
+ if readErr == io.EOF {
+ break
+ }
+ }
+ if err := newFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = newFD.VirtualDentry()
+ d.upperVD.IncRef()
+
+ case linux.S_IFDIR:
+ if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ }); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ case linux.S_IFLNK:
+ target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ })
+ if err != nil {
+ return err
+ }
+ if err := vfsObj.SymlinkAt(ctx, d.fs.creds, &newpop, target); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID,
+ Mode: uint16(d.mode),
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ case linux.S_IFBLK, linux.S_IFCHR:
+ lowerStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }, &vfs.StatOptions{})
+ if err != nil {
+ return err
+ }
+ if err := vfsObj.MknodAt(ctx, d.fs.creds, &newpop, &vfs.MknodOptions{
+ Mode: linux.FileMode(d.mode),
+ DevMajor: lowerStat.RdevMajor,
+ DevMinor: lowerStat.RdevMinor,
+ }); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ default:
+ // This should have been rejected by the ftype check at the beginning of this function.
+ panic(fmt.Sprintf("unexpected file type %o", ftype))
+ }
+
+ // TODO(gvisor.dev/issue/1199): copy up xattrs
+
+ // Update the dentry's device and inode numbers (except for directories,
+ // for which these remain overlay-assigned).
+ if ftype != linux.S_IFDIR {
+ upperStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_INO,
+ })
+ if err != nil {
+ d.upperVD.DecRef()
+ d.upperVD = vfs.VirtualDentry{}
+ cleanupUndoCopyUp()
+ return err
+ }
+ if upperStat.Mask&linux.STATX_INO == 0 {
+ d.upperVD.DecRef()
+ d.upperVD = vfs.VirtualDentry{}
+ cleanupUndoCopyUp()
+ return syserror.EREMOTE
+ }
+ atomic.StoreUint32(&d.devMajor, upperStat.DevMajor)
+ atomic.StoreUint32(&d.devMinor, upperStat.DevMinor)
+ atomic.StoreUint64(&d.ino, upperStat.Ino)
+ }
+
+ atomic.StoreUint32(&d.copiedUp, 1)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go
new file mode 100644
index 000000000..f5c2462a5
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/directory.go
@@ -0,0 +1,287 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func (d *dentry) isDir() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR
+}
+
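+// collectWhiteoutsForRmdirLocked checks whether d is removable by rmdir, i.e.
+// whether every layer of d contains only whiteout entries. On success it
+// returns the names of all whiteouts found, mapped to whether each whiteout
+// was found on the upper layer; if d contains any non-whiteout entry, it
+// fails with ENOTEMPTY.
+//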
+// Preconditions: d.dirMu must be locked. d.isDir().
+func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string]bool, error) {
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ var readdirErr error
+ whiteouts := make(map[string]bool)
+ var maybeWhiteouts []string
+ d.iterLayers(func(layerVD vfs.VirtualDentry, isUpper bool) bool {
+ layerFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ defer layerFD.DecRef()
+
+ // Reuse slice allocated for maybeWhiteouts from a previous layer to
+ // reduce allocations.
+ maybeWhiteouts = maybeWhiteouts[:0]
+ if err := layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name == "." || dirent.Name == ".." {
+ return nil
+ }
+ if _, ok := whiteouts[dirent.Name]; ok {
+ // This file has been whited-out in a previous layer.
+ return nil
+ }
+ if dirent.Type == linux.DT_CHR {
+ // We have to determine whether this is a whiteout, which doesn't
+ // count against the directory's emptiness. However, we can't do so
+ // while holding the locks taken by layerFD.IterDirents().
+ maybeWhiteouts = append(maybeWhiteouts, dirent.Name)
+ return nil
+ }
+ // Non-whiteout file in the directory prevents rmdir.
+ return syserror.ENOTEMPTY
+ })); err != nil {
+ readdirErr = err
+ return false
+ }
+
+ for _, maybeWhiteoutName := range maybeWhiteouts {
+ stat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ Path: fspath.Parse(maybeWhiteoutName),
+ }, &vfs.StatOptions{})
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ if stat.RdevMajor != 0 || stat.RdevMinor != 0 {
+ // This file is a real character device, not a whiteout.
+ readdirErr = syserror.ENOTEMPTY
+ return false
+ }
+ whiteouts[maybeWhiteoutName] = isUpper
+ }
+ // Continue iteration since we haven't found any non-whiteout files in
+ // this directory yet.
+ return true
+ })
+ return whiteouts, readdirErr
+}
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ mu sync.Mutex
+ off int64
+ dirents []vfs.Dirent
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
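+ // Nothing to do: layer FDs are opened and released within getDirents()
+ // rather than being retained by the directoryFD.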
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ d := fd.dentry()
+ if fd.dirents == nil {
+ ds, err := d.getDirents(ctx)
+ if err != nil {
+ return err
+ }
+ fd.dirents = ds
+ }
+
+ for fd.off < int64(len(fd.dirents)) {
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
+// Preconditions: d.isDir().
+func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
+ d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+
+ if d.dirents != nil {
+ return d.dirents, nil
+ }
+
+ parent := genericParentOrSelf(d)
+ dirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: d.ino,
+ NextOff: 1,
+ },
+ {
+ Name: "..",
+ Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
+ Ino: parent.ino,
+ NextOff: 2,
+ },
+ }
+
+ // Merge dirents from all layers comprising this directory.
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ var readdirErr error
+ prevDirents := make(map[string]struct{})
+ var maybeWhiteouts []vfs.Dirent
+ d.iterLayers(func(layerVD vfs.VirtualDentry, isUpper bool) bool {
+ layerFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ defer layerFD.DecRef()
+
+ // Reuse slice allocated for maybeWhiteouts from a previous layer to
+ // reduce allocations.
+ maybeWhiteouts = maybeWhiteouts[:0]
+ if err := layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name == "." || dirent.Name == ".." {
+ return nil
+ }
+ if _, ok := prevDirents[dirent.Name]; ok {
+ // This file is hidden by, or merged with, another file with
+ // the same name in a previous layer.
+ return nil
+ }
+ prevDirents[dirent.Name] = struct{}{}
+ if dirent.Type == linux.DT_CHR {
+ // We can't determine whether this file is a whiteout while holding
+ // the locks taken by layerFD.IterDirents().
+ maybeWhiteouts = append(maybeWhiteouts, dirent)
+ return nil
+ }
+ dirent.NextOff = int64(len(dirents) + 1)
+ dirents = append(dirents, dirent)
+ return nil
+ })); err != nil {
+ readdirErr = err
+ return false
+ }
+
+ for _, dirent := range maybeWhiteouts {
+ stat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ Path: fspath.Parse(dirent.Name),
+ }, &vfs.StatOptions{})
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ if stat.RdevMajor == 0 && stat.RdevMinor == 0 {
+ // This file is a whiteout; don't emit a dirent for it.
+ continue
+ }
+ dirent.NextOff = int64(len(dirents) + 1)
+ dirents = append(dirents, dirent)
+ }
+ return true
+ })
+ if readdirErr != nil {
+ return nil, readdirErr
+ }
+
+ // Cache dirents for future directoryFDs.
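+ // Mutating operations (doCreateAt, RmdirAt, and UnlinkAt) invalidate the
+ // cache by resetting d.dirents to nil.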
+ d.dirents = dirents
+ return dirents, nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset == 0 {
+ // Ensure that the next call to fd.IterDirents() calls
+ // fd.dentry().getDirents().
+ fd.dirents = nil
+ }
+ fd.off = offset
+ return fd.off, nil
+ case linux.SEEK_CUR:
+ offset += fd.off
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ // Don't clear fd.dirents in this case, even if offset == 0.
+ fd.off = offset
+ return fd.off, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. It forwards the sync to the
+// upper layer, if there is one; the lower layers never change, so they never
+// need to be synced.
+func (fd *directoryFD) Sync(ctx context.Context) error {
+ d := fd.dentry()
+ if !d.isCopiedUp() {
+ return nil
+ }
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }
+ upperFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY})
+ if err != nil {
+ return err
+ }
+ err = upperFD.Sync(ctx)
+ upperFD.DecRef()
+ return err
+}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
new file mode 100644
index 000000000..ff82e1f20
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -0,0 +1,1364 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for
+// opaque directories.
+// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE
+const _OVL_XATTR_OPAQUE = "trusted.overlay.opaque"
+
+func isWhiteout(stat *linux.Statx) bool {
+ return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ if fs.opts.UpperRoot.Ok() {
+ return fs.opts.UpperRoot.Mount().Filesystem().Impl().Sync(ctx)
+ }
+ return nil
+}
+
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
+
+// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls
+// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
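+//
+// The intended pattern, used throughout this file, is:
+//
+//	var ds *[]*dentry
+//	fs.renameMu.RLock()
+//	defer fs.renameMuRUnlockAndCheckDrop(&ds)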
+func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkDropLocked()
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
+func (fs *filesystem) renameMuUnlockAndCheckDrop(ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkDropLocked()
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// Dentries which may have a reference count of zero, and which therefore
+// should be dropped once traversal is complete, are appended to ds.
+//
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// !rp.Done().
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ child, err := fs.getChildLocked(ctx, d, name, ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.CheckMount(&child.vfsd); err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+ if child, ok := parent.children[name]; ok {
+ return child, nil
+ }
+ child, err := fs.lookupLocked(ctx, parent, name)
+ if err != nil {
+ return nil, err
+ }
+ if parent.children == nil {
+ parent.children = make(map[string]*dentry)
+ }
+ parent.children[name] = child
+ // child's refcount is initially 0, so it may be dropped after traversal.
+ *ds = appendDentry(*ds, child)
+ return child, nil
+}
+
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
+ childPath := fspath.Parse(name)
+ child := fs.newDentry()
+ existsOnAnyLayer := false
+ var lookupErr error
+
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ parent.iterLayers(func(parentVD vfs.VirtualDentry, isUpper bool) bool {
+ childVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parentVD,
+ Start: parentVD,
+ Path: childPath,
+ }, &vfs.GetDentryOptions{})
+ if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ // The file doesn't exist on this layer. Proceed to the next one.
+ return true
+ }
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+
+ mask := uint32(linux.STATX_TYPE)
+ if !existsOnAnyLayer {
+ // Mode, UID, GID, and (for non-directories) inode number come from
+ // the topmost layer on which the file exists.
+ mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ }
+ stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childVD,
+ Start: childVD,
+ }, &vfs.StatOptions{
+ Mask: mask,
+ })
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+ if stat.Mask&mask != mask {
+ lookupErr = syserror.EREMOTE
+ return false
+ }
+
+ if isWhiteout(&stat) {
+ // This is a whiteout, so it "doesn't exist" on this layer, and
+ // layers below this one are ignored.
+ return false
+ }
+ isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR
+ if existsOnAnyLayer && !isDir {
+ // Directories are not merged with non-directory files from lower
+ // layers; instead, layers including and below the first
+ // non-directory file are ignored. (This file must be a directory
+ // on previous layers, since lower layers aren't searched for
+ // non-directory files.)
+ return false
+ }
+
+ // Update child to include this layer.
+ if isUpper {
+ child.upperVD = childVD
+ child.copiedUp = 1
+ } else {
+ child.lowerVDs = append(child.lowerVDs, childVD)
+ }
+ if !existsOnAnyLayer {
+ existsOnAnyLayer = true
+ child.mode = uint32(stat.Mode)
+ child.uid = stat.UID
+ child.gid = stat.GID
+ child.devMajor = stat.DevMajor
+ child.devMinor = stat.DevMinor
+ child.ino = stat.Ino
+ }
+
+ // For non-directory files, only the topmost layer that contains a file
+ // matters.
+ if !isDir {
+ return false
+ }
+
+ // Directories are merged with directories from lower layers if they
+ // are not explicitly opaque.
+ opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childVD,
+ Start: childVD,
+ }, &vfs.GetxattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Size: 1,
+ })
+ return !(err == nil && opaqueVal == "y")
+ })
+
+ if lookupErr != nil {
+ child.destroyLocked()
+ return nil, lookupErr
+ }
+ if !existsOnAnyLayer {
+ child.destroyLocked()
+ return nil, syserror.ENOENT
+ }
+
+ // Device and inode numbers were copied from the topmost layer above;
+ // override them if necessary.
+ if child.isDir() {
+ child.devMajor = linux.UNNAMED_MAJOR
+ child.devMinor = fs.dirDevMinor
+ child.ino = fs.newDirIno()
+ } else if !child.upperVD.Ok() {
+ child.devMajor = linux.UNNAMED_MAJOR
+ child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()]
+ }
+
+ parent.IncRef()
+ child.parent = parent
+ child.name = name
+ return child, nil
+}
+
+// lookupLayerLocked is similar to lookupLocked, but only returns information
+// about the file rather than a dentry.
+//
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, name string) (lookupLayer, error) {
+ childPath := fspath.Parse(name)
+ lookupLayer := lookupLayerNone
+ var lookupErr error
+
+ parent.iterLayers(func(parentVD vfs.VirtualDentry, isUpper bool) bool {
+ stat, err := fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parentVD,
+ Start: parentVD,
+ Path: childPath,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_TYPE,
+ })
+ if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ // The file doesn't exist on this layer. Proceed to the next
+ // one.
+ return true
+ }
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 {
+ // Linux's overlayfs tends to return EREMOTE in cases where a file
+ // is unusable for reasons that are not better captured by another
+ // errno.
+ lookupErr = syserror.EREMOTE
+ return false
+ }
+ if isWhiteout(&stat) {
+ // This is a whiteout, so it "doesn't exist" on this layer, and
+ // layers below this one are ignored.
+ if isUpper {
+ lookupLayer = lookupLayerUpperWhiteout
+ }
+ return false
+ }
+ // The file exists; we can stop searching.
+ if isUpper {
+ lookupLayer = lookupLayerUpper
+ } else {
+ lookupLayer = lookupLayerLower
+ }
+ return false
+ })
+
+ return lookupLayer, lookupErr
+}
+
+type lookupLayer int
+
+const (
+ // lookupLayerNone indicates that no file exists at the given path on the
+ // upper layer, and is either whited out or does not exist on lower layers.
+ // Therefore, the file does not exist in the overlay filesystem, and file
+ // creation may proceed normally (if an upper layer exists).
+ lookupLayerNone lookupLayer = iota
+
+ // lookupLayerLower indicates that no file exists at the given path on the
+ // upper layer, but exists on a lower layer. Therefore, the file exists in
+ // the overlay filesystem, but must be copied-up before mutation.
+ lookupLayerLower
+
+ // lookupLayerUpper indicates that a non-whiteout file exists at the given
+ // path on the upper layer. Therefore, the file exists in the overlay
+ // filesystem, and is already copied-up.
+ lookupLayerUpper
+
+ // lookupLayerUpperWhiteout indicates that a whiteout exists at the given
+ // path on the upper layer. Therefore, the file does not exist in the
+ // overlay filesystem, and file creation must remove the whiteout before
+ // proceeding.
+ lookupLayerUpperWhiteout
+)
+
+func (ll lookupLayer) existsInOverlay() bool {
+ return ll == lookupLayerLower || ll == lookupLayerUpper
+}
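+
+// For example, when a file is created at a path whose upper layer holds a
+// whiteout, lookupLayerLocked returns lookupLayerUpperWhiteout:
+// existsInOverlay() is false, so creation may proceed, but the creator must
+// first unlink the whiteout (see the haveUpperWhiteout handling in
+// doCreateAt's create callbacks).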
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done().
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ for !rp.Final() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// Preconditions: fs.renameMu must be locked.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ for !rp.Done() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ if parent.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Determine if a file already exists at name.
+ if _, ok := parent.children[name]; ok {
+ return syserror.EEXIST
+ }
+ childLayer, err := fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if childLayer.existsInOverlay() {
+ return syserror.EEXIST
+ }
+
+ // Ensure that the parent directory is copied-up so that we can create the
+ // new file in the upper layer.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ // Finally create the new file.
+ if err := create(parent, name, childLayer == lookupLayerUpperWhiteout); err != nil {
+ return err
+ }
+ parent.dirents = nil
+ return nil
+}
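+
+// The create callbacks passed to doCreateAt (from LinkAt, MkdirAt, MknodAt,
+// and SymlinkAt below) share a pattern: unlink any upper-layer whiteout,
+// create the file on the upper layer, then chown it to the caller. On
+// failure they undo whatever has been done, recreating the whiteout via
+// cleanupRecreateWhiteout() if one was removed.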
+
+// Preconditions: pop's parent directory has been copied up.
+func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) error {
+ return vfsObj.MknodAt(ctx, fs.creds, pop, &vfs.MknodOptions{
+ Mode: linux.S_IFCHR, // permissions == include/linux/fs.h:WHITEOUT_MODE == 0
+ // DevMajor == DevMinor == 0, from include/linux/fs.h:WHITEOUT_DEV
+ })
+}
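+
+// A node created by createWhiteout is exactly what isWhiteout() matches: a
+// character device with DevMajor == DevMinor == 0. This is also why MknodAt
+// returns EPERM for attempts to create such a device, since it would be
+// indistinguishable from a whiteout.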
+
+func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) {
+ if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)
+ }
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ layerVD := d.topLayer()
+ return fs.vfsfs.VirtualFilesystem().BoundEndpointAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &opts)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ old := vd.Dentry().Impl().(*dentry)
+ if old.isDir() {
+ return syserror.EPERM
+ }
+ if err := old.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ newpop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.LinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: old.upperVD,
+ Start: old.upperVD,
+ }, &newpop); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.MkdirAt(ctx, fs.creds, &pop, &opts); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ if haveUpperWhiteout {
+ // There may be directories on lower layers (previously hidden by
+ // the whiteout) that the new directory should not be merged with.
+ // Mark it opaque to prevent merging.
+ if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Value: "y",
+ }); err != nil {
+ if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)
+ } else {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ }
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ // Disallow attempts to create whiteouts.
+ if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 {
+ return syserror.EPERM
+ }
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.MknodAt(ctx, fs.creds, &pop, &opts); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ mayCreate := opts.Flags&linux.O_CREAT != 0
+ mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL)
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if rp.Done() {
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Determine whether or not we need to create a file.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err == syserror.ENOENT && mayCreate {
+ fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds)
+ parent.dirMu.Unlock()
+ return fd, err
+ }
+ if err != nil {
+ parent.dirMu.Unlock()
+ return nil, err
+ }
+ // Open existing child or follow symlink.
+ parent.dirMu.Unlock()
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ if ats.MayWrite() {
+ if err := d.copyUpLocked(ctx); err != nil {
+ return nil, err
+ }
+ }
+ mnt := rp.Mount()
+
+ // Directory FDs open FDs from each layer when directory entries are read,
+ // so they don't require opening an FD from d.topLayer() up front.
+ ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT
+ if ftype == linux.S_IFDIR {
+ // Can't open directories with O_CREAT.
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EISDIR
+ }
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ }
+
+ layerVD, isUpper := d.topLayerInfo()
+ layerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, opts)
+ if err != nil {
+ return nil, err
+ }
+ layerFlags := layerFD.StatusFlags()
+ fd := &nonDirectoryFD{
+ copiedUp: isUpper,
+ cachedFD: layerFD,
+ cachedFlags: layerFlags,
+ }
+ fd.LockFD.Init(&d.locks)
+ layerFDOpts := layerFD.Options()
+ if err := fd.vfsfd.Init(fd, layerFlags, mnt, &d.vfsd, &layerFDOpts); err != nil {
+ layerFD.DecRef()
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Preconditions: parent.dirMu must be locked. parent does not already contain
+// a child named rp.Component().
+func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+ creds := rp.Credentials()
+ if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if parent.vfsd.IsDead() {
+ return nil, syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer mnt.EndWrite()
+
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return nil, err
+ }
+
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ childName := rp.Component()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ // We don't know if a whiteout exists on the upper layer; speculatively
+ // unlink it.
+ //
+ // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do
+ // know whether a whiteout exists.
+ var haveUpperWhiteout bool
+ switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err {
+ case nil:
+ haveUpperWhiteout = true
+ case syserror.ENOENT:
+ haveUpperWhiteout = false
+ default:
+ return nil, err
+ }
+ // Create the file on the upper layer, and get an FD representing it.
+ upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{
+ Flags: opts.Flags&^vfs.FileCreationFlags | linux.O_CREAT | linux.O_EXCL,
+ Mode: opts.Mode,
+ })
+ if err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Change the file's owner to the caller. We can't use upperFD.SetStat()
+ // because it will pick up creds from ctx.
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Re-lookup to get a dentry representing the new file, which is needed for
+ // the returned FD.
+ child, err := fs.getChildLocked(ctx, parent, childName, ds)
+ if err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Finally construct the overlay FD.
+ upperFlags := upperFD.StatusFlags()
+ fd := &nonDirectoryFD{
+ copiedUp: true,
+ cachedFD: upperFD,
+ cachedFlags: upperFlags,
+ }
+ fd.LockFD.Init(&child.locks)
+ upperFDOpts := upperFD.Options()
+ if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil {
+ upperFD.DecRef()
+ // Don't bother with cleanup; the file was created successfully, and we
+ // just can't open it anymore for some reason.
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ layerVD := d.topLayer()
+ return fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ })
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ return syserror.EINVAL
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.Lock()
+ defer fs.renameMuUnlockAndCheckDrop(&ds)
+ newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ // FIXME(gvisor.dev/issue/1199): Actually implement rename.
+ _ = newParent
+ return syserror.EXDEV
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ name := rp.Component()
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Ensure that parent is copied-up before potentially holding child.copyMu
+ // below.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ // Unlike UnlinkAt, we need a dentry representing the child directory being
+ // removed in order to verify that it's empty.
+ child, err := fs.getChildLocked(ctx, parent, name, &ds)
+ if err != nil {
+ return err
+ }
+ if !child.isDir() {
+ return syserror.ENOTDIR
+ }
+ child.dirMu.Lock()
+ defer child.dirMu.Unlock()
+ whiteouts, err := child.collectWhiteoutsForRmdirLocked(ctx)
+ if err != nil {
+ return err
+ }
+ child.copyMu.RLock()
+ defer child.copyMu.RUnlock()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(name),
+ }
+ if child.upperVD.Ok() {
+ cleanupRecreateWhiteouts := func() {
+ if !child.upperVD.Ok() {
+ return
+ }
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{
+ Root: child.upperVD,
+ Start: child.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil && err != syserror.EEXIST {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)
+ }
+ }
+ }
+ // Remove existing whiteouts on the upper layer.
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: child.upperVD,
+ Start: child.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+ }
+ // Remove the existing directory on the upper layer.
+ if err := vfsObj.RmdirAt(ctx, fs.creds, &pop); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
+ // Don't attempt to recover from this: the original directory is
+ // already gone, so any dentries representing it are invalid, and
+ // creating a new directory won't undo that.
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ delete(parent.children, name)
+ ds = appendDentry(ds, child)
+ parent.dirents = nil
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // Changes to d's attributes are serialized by d.copyMu.
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ if err := d.fs.vfsfs.VirtualFilesystem().SetStatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }, &opts); err != nil {
+ return err
+ }
+ d.updateAfterSetStatLocked(&opts)
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ var stat linux.Statx
+ if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
+ layerVD := d.topLayer()
+ stat, err = fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.StatOptions{
+ Mask: layerMask,
+ Sync: opts.Sync,
+ })
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ d.statInternalTo(ctx, &opts, &stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ return fs.statFS(ctx)
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.SymlinkAt(ctx, fs.creds, &pop, target); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Ensure that parent is copied-up before potentially holding child.copyMu
+ // below.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ child := parent.children[name]
+ var childLayer lookupLayer
+ if child != nil {
+ if child.isDir() {
+ return syserror.EISDIR
+ }
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ // Hold child.copyMu to prevent it from being copied-up during
+ // deletion.
+ child.copyMu.RLock()
+ defer child.copyMu.RUnlock()
+ if child.upperVD.Ok() {
+ childLayer = lookupLayerUpper
+ } else {
+ childLayer = lookupLayerLower
+ }
+ } else {
+ // Determine if the file being unlinked actually exists. Holding
+ // parent.dirMu prevents a dentry from being instantiated for the file,
+ // which in turn prevents it from being copied-up, so this result is
+ // stable.
+ childLayer, err = fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if !childLayer.existsInOverlay() {
+ return syserror.ENOENT
+ }
+ }
+
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(name),
+ }
+ if childLayer == lookupLayerUpper {
+ // Remove the existing file on the upper layer.
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+
+ if child != nil {
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ delete(parent.children, name)
+ ds = appendDentry(ds, child)
+ }
+ parent.dirents = nil
+ return nil
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr,
+ // but not any other xattr syscalls. For now we just reject all of them.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
new file mode 100644
index 000000000..a3c1f7a8d
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -0,0 +1,266 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
+}
+
+func (d *dentry) readlink(ctx context.Context) (string, error) {
+ layerVD := d.topLayer()
+ return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ })
+}
+
+type nonDirectoryFD struct {
+ fileDescription
+
+ // If copiedUp is false, cachedFD represents
+ // fileDescription.dentry().lowerVDs[0]; otherwise, cachedFD represents
+ // fileDescription.dentry().upperVD. cachedFlags is the last known value of
+ // cachedFD.StatusFlags(). copiedUp, cachedFD, and cachedFlags are
+ // protected by mu.
+ mu sync.Mutex
+ copiedUp bool
+ cachedFD *vfs.FileDescription
+ cachedFlags uint32
+}
+
+func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return nil, err
+ }
+ wrappedFD.IncRef()
+ return wrappedFD, nil
+}
+
+func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) {
+ d := fd.dentry()
+ statusFlags := fd.vfsfd.StatusFlags()
+ if !fd.copiedUp && d.isCopiedUp() {
+ // Switch to the copied-up file.
+ upperVD := d.topLayer()
+ upperFD, err := fd.filesystem().vfsfs.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: upperVD,
+ Start: upperVD,
+ }, &vfs.OpenOptions{
+ Flags: statusFlags,
+ })
+ if err != nil {
+ return nil, err
+ }
+ oldOff, oldOffErr := fd.cachedFD.Seek(ctx, 0, linux.SEEK_CUR)
+ if oldOffErr == nil {
+ if _, err := upperFD.Seek(ctx, oldOff, linux.SEEK_SET); err != nil {
+ upperFD.DecRef()
+ return nil, err
+ }
+ }
+ fd.cachedFD.DecRef()
+ fd.copiedUp = true
+ fd.cachedFD = upperFD
+ fd.cachedFlags = statusFlags
+ } else if fd.cachedFlags != statusFlags {
+ if err := fd.cachedFD.SetStatusFlags(ctx, d.fs.creds, statusFlags); err != nil {
+ return nil, err
+ }
+ fd.cachedFlags = statusFlags
+ }
+ return fd.cachedFD, nil
+}
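+
+// Note that currentFDLocked preserves the file offset (where the underlying
+// FDs support seeking) when switching to the copied-up FD, so reads and
+// writes that straddle a copy-up observe a single continuous offset.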
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *nonDirectoryFD) Release() {
+ fd.cachedFD.DecRef()
+ fd.cachedFD = nil
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *nonDirectoryFD) OnClose(ctx context.Context) error {
+ // Linux doesn't define ovl_file_operations.flush at all (i.e. its
+ // equivalent to OnClose is a no-op). We pass through to
+ // fd.cachedFD.OnClose() without upgrading if fd.dentry() has been
+ // copied-up, since OnClose is mostly used to define post-close writeback,
+ // and if fd.cachedFD hasn't been updated then it can't have been used to
+ // mutate fd.dentry() anyway.
+ fd.mu.Lock()
+ if statusFlags := fd.vfsfd.StatusFlags(); fd.cachedFlags != statusFlags {
+ if err := fd.cachedFD.SetStatusFlags(ctx, fd.filesystem().creds, statusFlags); err != nil {
+ fd.mu.Unlock()
+ return err
+ }
+ fd.cachedFlags = statusFlags
+ }
+ wrappedFD := fd.cachedFD
+ // Take a reference before dropping fd.mu so that a concurrent copy-up
+ // can't release wrappedFD out from under the OnClose call.
+ wrappedFD.IncRef()
+ defer wrappedFD.DecRef()
+ fd.mu.Unlock()
+ return wrappedFD.OnClose(ctx)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ stat, err = wrappedFD.Stat(ctx, vfs.StatOptions{
+ Mask: layerMask,
+ Sync: opts.Sync,
+ })
+ wrappedFD.DecRef()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ fd.dentry().statInternalTo(ctx, &opts, &stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ d := fd.dentry()
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ mnt := fd.vfsfd.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // Changes to d's attributes are serialized by d.copyMu.
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return err
+ }
+ if err := wrappedFD.SetStat(ctx, opts); err != nil {
+ return err
+ }
+ d.updateAfterSetStatLocked(&opts)
+ return nil
+}
+
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
+func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
+ return fd.filesystem().statFS(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef()
+ return wrappedFD.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Hold fd.mu during the read to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef()
+ return wrappedFD.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Hold fd.mu during the write to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Write(ctx, src, opts)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Hold fd.mu during the seek to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Seek(ctx, offset, whence)
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
+ fd.mu.Lock()
+ if !fd.dentry().isCopiedUp() {
+ fd.mu.Unlock()
+ return nil
+ }
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ fd.mu.Unlock()
+ return err
+ }
+ wrappedFD.IncRef()
+ defer wrappedFD.DecRef()
+ fd.mu.Unlock()
+ return wrappedFD.Sync(ctx)
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return err
+ }
+ defer wrappedFD.DecRef()
+ return wrappedFD.ConfigureMMap(ctx, opts)
+}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
new file mode 100644
index 000000000..e720d4825
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -0,0 +1,627 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package overlay provides an overlay filesystem implementation, which
+// synthesizes a filesystem by composing one or more immutable filesystems
+// ("lower layers") with an optional mutable filesystem ("upper layer").
+//
+// Lock order:
+//
+// directoryFD.mu / nonDirectoryFD.mu
+// filesystem.renameMu
+// dentry.dirMu
+// dentry.copyMu
+//
+// Locking dentry.dirMu in multiple dentries requires that parent dentries are
+// locked before child dentries, and that filesystem.renameMu is locked to
+// stabilize this relationship.
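+//
+// For example, code that must lock dirMu in both a parent and one of its
+// children should already hold renameMu (at least for reading) and should
+// lock the parent's dirMu before the child's.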
+package overlay
+
+import (
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "overlay"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to
+// FilesystemType.GetFilesystem.
+type FilesystemOptions struct {
+ // Callers passing FilesystemOptions to
+ // overlay.FilesystemType.GetFilesystem() are responsible for ensuring that
+ // the vfs.Mounts comprising the layers of the overlay filesystem do not
+ // contain submounts.
+
+ // If UpperRoot.Ok(), it is the root of the writable upper layer of the
+ // overlay.
+ UpperRoot vfs.VirtualDentry
+
+ // LowerRoots contains the roots of the immutable lower layers of the
+ // overlay. LowerRoots is immutable.
+ LowerRoots []vfs.VirtualDentry
+}
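+
+// A minimal sketch of how a sandbox owner might configure the overlay
+// directly (not part of the build; upperVD and lowerVD are assumed to be
+// vfs.VirtualDentries obtained elsewhere, e.g. via GetDentryAt):
+//
+//	opts := vfs.GetFilesystemOptions{
+//		InternalData: FilesystemOptions{
+//			UpperRoot:  upperVD,
+//			LowerRoots: []vfs.VirtualDentry{lowerVD},
+//		},
+//	}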
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // Immutable options.
+ opts FilesystemOptions
+
+ // creds is a copy of the filesystem's creator's credentials, which are
+ // used for accesses to the filesystem's layers. creds is immutable.
+ creds *auth.Credentials
+
+ // dirDevMinor is the device minor number used for directories. dirDevMinor
+ // is immutable.
+ dirDevMinor uint32
+
+ // lowerDevMinors maps lower layer filesystems to device minor numbers
+ // assigned to non-directory files originating from that filesystem.
+ // lowerDevMinors is immutable.
+ lowerDevMinors map[*vfs.Filesystem]uint32
+
+ // renameMu synchronizes renaming with non-renaming operations in order to
+ // ensure consistent lock ordering between dentry.dirMu in different
+ // dentries.
+ renameMu sync.RWMutex
+
+ // lastDirIno is the last inode number assigned to a directory. lastDirIno
+ // is accessed using atomic memory operations.
+ lastDirIno uint64
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ fsoptsRaw := opts.InternalData
+ fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
+ if fsoptsRaw != nil && !haveFSOpts {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
+ return nil, nil, syserror.EINVAL
+ }
+ if haveFSOpts {
+ if len(fsopts.LowerRoots) == 0 {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ return nil, nil, syserror.EINVAL
+ }
+ // We don't enforce a maximum number of lower layers when the overlay
+ // isn't configured by an application; the sandbox owner can use an
+ // overlay filesystem with any number of lower layers.
+ } else {
+ vfsroot := vfs.RootFromContext(ctx)
+ defer vfsroot.DecRef()
+ upperPathname, ok := mopts["upperdir"]
+ if ok {
+ delete(mopts, "upperdir")
+ // Linux overlayfs also requires a workdir when upperdir is
+ // specified; we don't, so silently ignore this option.
+ delete(mopts, "workdir")
+ upperPath := fspath.Parse(upperPathname)
+ if !upperPath.Absolute {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: upperPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ defer upperRoot.DecRef()
+ privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ defer privateUpperRoot.DecRef()
+ fsopts.UpperRoot = privateUpperRoot
+ }
+ lowerPathnamesStr, ok := mopts["lowerdir"]
+ if !ok {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "lowerdir")
+ lowerPathnames := strings.Split(lowerPathnamesStr, ":")
+ const maxLowerLayers = 500 // Linux: fs/overlayfs/super.c:OVL_MAX_STACK
+ if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(lowerPathnames) > maxLowerLayers {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
+ return nil, nil, syserror.EINVAL
+ }
+ for _, lowerPathname := range lowerPathnames {
+ lowerPath := fspath.Parse(lowerPathname)
+ if !lowerPath.Absolute {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: lowerPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ defer lowerRoot.DecRef()
+ privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ defer privateLowerRoot.DecRef()
+ fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
+ }
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Allocate device numbers.
+ dirDevMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerDevMinors := make(map[*vfs.Filesystem]uint32)
+ for _, lowerRoot := range fsopts.LowerRoots {
+ lowerFS := lowerRoot.Mount().Filesystem()
+ if _, ok := lowerDevMinors[lowerFS]; !ok {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ vfsObj.PutAnonBlockDevMinor(dirDevMinor)
+ for _, lowerDevMinor := range lowerDevMinors {
+ vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
+ }
+ return nil, nil, err
+ }
+ lowerDevMinors[lowerFS] = devMinor
+ }
+ }
+
+ // Take extra references held by the filesystem.
+ if fsopts.UpperRoot.Ok() {
+ fsopts.UpperRoot.IncRef()
+ }
+ for _, lowerRoot := range fsopts.LowerRoots {
+ lowerRoot.IncRef()
+ }
+
+ fs := &filesystem{
+ opts: fsopts,
+ creds: creds.Fork(),
+ dirDevMinor: dirDevMinor,
+ lowerDevMinors: lowerDevMinors,
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
+
+ // Construct the root dentry.
+ root := fs.newDentry()
+ root.refs = 1
+ if fs.opts.UpperRoot.Ok() {
+ fs.opts.UpperRoot.IncRef()
+ root.copiedUp = 1
+ root.upperVD = fs.opts.UpperRoot
+ }
+ for _, lowerRoot := range fs.opts.LowerRoots {
+ lowerRoot.IncRef()
+ root.lowerVDs = append(root.lowerVDs, lowerRoot)
+ }
+ rootTopVD := root.topLayer()
+ // Get metadata from the topmost layer. See fs.lookupLocked().
+ const rootStatMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ rootStat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: rootTopVD,
+ Start: rootTopVD,
+ }, &vfs.StatOptions{
+ Mask: rootStatMask,
+ })
+ if err != nil {
+ root.destroyLocked()
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+ if rootStat.Mask&rootStatMask != rootStatMask {
+ root.destroyLocked()
+ fs.vfsfs.DecRef()
+ return nil, nil, syserror.EREMOTE
+ }
+ if isWhiteout(&rootStat) {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
+ root.destroyLocked()
+ fs.vfsfs.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+ root.mode = uint32(rootStat.Mode)
+ root.uid = rootStat.UID
+ root.gid = rootStat.GID
+ if rootStat.Mode&linux.S_IFMT == linux.S_IFDIR {
+ root.devMajor = linux.UNNAMED_MAJOR
+ root.devMinor = fs.dirDevMinor
+ root.ino = fs.newDirIno()
+ } else if !root.upperVD.Ok() {
+ root.devMajor = linux.UNNAMED_MAJOR
+ root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()]
+ root.ino = rootStat.Ino
+ } else {
+ root.devMajor = rootStat.DevMajor
+ root.devMinor = rootStat.DevMinor
+ root.ino = rootStat.Ino
+ }
+
+ return &fs.vfsfs, &root.vfsd, nil
+}
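+
+// For reference, an application-style configuration of the same filesystem
+// uses the mount options parsed above (hypothetical paths):
+//
+//	mount -t overlay none /merged -o upperdir=/upper,lowerdir=/l1:/l2
+//
+// As in Linux, the first lowerdir entry (/l1) is the topmost lower layer
+// (lowerVDs[0]); unlike Linux, workdir is accepted but ignored.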
+
+// clonePrivateMount creates a non-recursive bind mount rooted at vd, not
+// associated with any MountNamespace, and returns the root of the new mount.
+// (This is required to ensure that each layer of an overlay comprises only a
+// single mount, and therefore can't cross into e.g. the overlay filesystem
+// itself, risking lock recursion.) A reference is held on the returned
+// VirtualDentry.
+func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forceReadOnly bool) (vfs.VirtualDentry, error) {
+ oldmnt := vd.Mount()
+ opts := oldmnt.Options()
+ if forceReadOnly {
+ opts.ReadOnly = true
+ }
+ newmnt, err := vfsObj.NewDisconnectedMount(oldmnt.Filesystem(), vd.Dentry(), &opts)
+ if err != nil {
+ return vfs.VirtualDentry{}, err
+ }
+ return vfs.MakeVirtualDentry(newmnt, vd.Dentry()), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ vfsObj.PutAnonBlockDevMinor(fs.dirDevMinor)
+ for _, lowerDevMinor := range fs.lowerDevMinors {
+ vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
+ }
+ if fs.opts.UpperRoot.Ok() {
+ fs.opts.UpperRoot.DecRef()
+ }
+ for _, lowerRoot := range fs.opts.LowerRoots {
+ lowerRoot.DecRef()
+ }
+}
+
+func (fs *filesystem) statFS(ctx context.Context) (linux.Statfs, error) {
+ // Always statfs the root of the topmost layer. Compare Linux's
+ // fs/overlayfs/super.c:ovl_statfs().
+ var rootVD vfs.VirtualDentry
+ if fs.opts.UpperRoot.Ok() {
+ rootVD = fs.opts.UpperRoot
+ } else {
+ rootVD = fs.opts.LowerRoots[0]
+ }
+ fsstat, err := fs.vfsfs.VirtualFilesystem().StatFSAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: rootVD,
+ Start: rootVD,
+ })
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ fsstat.Type = linux.OVERLAYFS_SUPER_MAGIC
+ return fsstat, nil
+}
+
+func (fs *filesystem) newDirIno() uint64 {
+ return atomic.AddUint64(&fs.lastDirIno, 1)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // mode, uid, and gid are the file mode, owner, and group of the file in
+ // the topmost layer (and therefore the overlay file as well), and are used
+ // for permission checks on this dentry. These fields are protected by
+ // copyMu and accessed using atomic memory operations.
+ mode uint32
+ uid uint32
+ gid uint32
+
+ // copiedUp is 1 if this dentry has been copied-up (i.e. upperVD.Ok()) and
+ // 0 otherwise. copiedUp is accessed using atomic memory operations.
+ copiedUp uint32
+
+ // parent is the dentry corresponding to this dentry's parent directory.
+ // name is this dentry's name in parent. If this dentry is a filesystem
+ // root, parent is nil and name is the empty string. parent and name are
+ // protected by fs.renameMu.
+ parent *dentry
+ name string
+
+ // If this dentry represents a directory, children maps the names of
+ // children for which dentries have been instantiated to those dentries,
+ // and dirents (if not nil) is a cache of dirents as returned by
+ // directoryFDs representing this directory. children is protected by
+ // dirMu.
+ dirMu sync.Mutex
+ children map[string]*dentry
+ dirents []vfs.Dirent
+
+ // upperVD and lowerVDs are the files from the overlay filesystem's layers
+ // that comprise the file on the overlay filesystem.
+ //
+ // If !upperVD.Ok(), it can transition to a valid vfs.VirtualDentry (i.e.
+ // be copied up) with copyMu locked for writing; otherwise, it is
+ // immutable. lowerVDs is always immutable.
+ copyMu sync.RWMutex
+ upperVD vfs.VirtualDentry
+ lowerVDs []vfs.VirtualDentry
+
+ // inlineLowerVDs backs lowerVDs in the common case where len(lowerVDs) <=
+ // len(inlineLowerVDs).
+ inlineLowerVDs [1]vfs.VirtualDentry
+
+ // devMajor, devMinor, and ino are the device major/minor and inode numbers
+ // used by this dentry. These fields are protected by copyMu and accessed
+ // using atomic memory operations.
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+
+ locks vfs.FileLocks
+}
+
+// newDentry creates a new dentry. The dentry initially has no references; it
+// is the caller's responsibility to set the dentry's reference count and/or
+// call dentry.destroy() as appropriate. The dentry is initially invalid in
+// that it contains no layers; the caller is responsible for setting them.
+func (fs *filesystem) newDentry() *dentry {
+ d := &dentry{
+ fs: fs,
+ }
+ d.lowerVDs = d.inlineLowerVDs[:0]
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ // d.refs may be 0 if d.fs.renameMu is locked, which serializes against
+ // d.checkDropLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkDropLocked()
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("overlay.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// checkDropLocked should be called after d's reference count becomes 0 or
+// after it is deleted.
+//
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) checkDropLocked() {
+ // Dentries with a positive reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires renameMu, so if d.refs is zero then it will
+ // remain zero while we hold renameMu for writing.) Dentries with a
+ // negative reference count have already been destroyed.
+ if atomic.LoadInt64(&d.refs) != 0 {
+ return
+ }
+ // Refs is still zero; destroy it.
+ d.destroyLocked()
+ return
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+func (d *dentry) destroyLocked() {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("overlay.dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("overlay.dentry.destroyLocked() called with references on the dentry")
+ }
+
+ if d.upperVD.Ok() {
+ d.upperVD.DecRef()
+ }
+ for _, lowerVD := range d.lowerVDs {
+ lowerVD.DecRef()
+ }
+
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ if !d.vfsd.IsDead() {
+ delete(d.parent.children, d.name)
+ }
+ d.parent.dirMu.Unlock()
+ // Drop the reference held by d on its parent without recursively
+ // locking d.fs.renameMu.
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.checkDropLocked()
+ } else if refs < 0 {
+ panic("overlay.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {
+ // TODO(gvisor.dev/issue/1479): Implement inotify.
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ // TODO(gvisor.dev/issue/1479): Implement inotify.
+ return nil
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// TODO(gvisor.dev/issue/1479): Implement inotify.
+func (d *dentry) OnZeroWatches() {}
+
+// iterLayers invokes yield on each layer comprising d, from top to bottom. If
+// any call to yield returns false, iterLayers stops iteration.
+func (d *dentry) iterLayers(yield func(vd vfs.VirtualDentry, isUpper bool) bool) {
+ if d.isCopiedUp() {
+ if !yield(d.upperVD, true) {
+ return
+ }
+ }
+ for _, lowerVD := range d.lowerVDs {
+ if !yield(lowerVD, false) {
+ return
+ }
+ }
+}
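+
+// A usage sketch (hypothetical caller): count the layers backing d. Returning
+// false from yield would stop the iteration early.
+//
+//	var n int
+//	d.iterLayers(func(vd vfs.VirtualDentry, isUpper bool) bool {
+//		n++
+//		return true // continue to the next layer
+//	})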
+
+func (d *dentry) topLayerInfo() (vd vfs.VirtualDentry, isUpper bool) {
+ if d.isCopiedUp() {
+ return d.upperVD, true
+ }
+ return d.lowerVDs[0], false
+}
+
+func (d *dentry) topLayer() vfs.VirtualDentry {
+ vd, _ := d.topLayerInfo()
+ return vd
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+// statInternalMask is the set of stat fields that is set by
+// dentry.statInternalTo().
+const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+
+// statInternalTo writes fields to stat that are stored in d, and therefore do
+// not require invoking StatAt on the overlay's layers.
+func (d *dentry) statInternalTo(ctx context.Context, opts *vfs.StatOptions, stat *linux.Statx) {
+ stat.Mask |= statInternalMask
+ if d.isDir() {
+ // Linux sets nlink to 1 for merged directories
+ // (fs/overlayfs/inode.c:ovl_getattr()); we set it to 2 because this is
+ // correct more often ("." and the directory's entry in its parent),
+ // and some of our tests expect this.
+ stat.Nlink = 2
+ }
+ stat.UID = atomic.LoadUint32(&d.uid)
+ stat.GID = atomic.LoadUint32(&d.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&d.mode))
+ stat.Ino = atomic.LoadUint64(&d.ino)
+ stat.DevMajor = atomic.LoadUint32(&d.devMajor)
+ stat.DevMinor = atomic.LoadUint32(&d.devMinor)
+}
+
+// Preconditions: d.copyMu must be locked for writing.
+func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) {
+ if opts.Stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, (d.mode&linux.S_IFMT)|uint32(opts.Stat.Mode&^linux.S_IFMT))
+ }
+ if opts.Stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, opts.Stat.UID)
+ }
+ if opts.Stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, opts.Stat.GID)
+ }
+}
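+
+// For example, a chmod to 0644 on a regular file stores S_IFREG|0644 above:
+// the S_IFMT file type bits already in d.mode are preserved and only the
+// permission bits from opts.Stat.Mode are applied.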
+
+// fileDescription is embedded by overlay implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD
new file mode 100644
index 000000000..5950a2d59
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "pipefs",
+ srcs = ["pipefs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
new file mode 100644
index 000000000..dd7eaf4a8
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -0,0 +1,165 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipefs provides the filesystem implementation backing
+// Kernel.PipeMount.
+package pipefs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type filesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "pipefs"
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("pipefs.filesystemType.GetFilesystem should never be called")
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// NewFilesystem sets up and returns a new vfs.Filesystem implemented by pipefs.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, err
+ }
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.Filesystem.VFSFilesystem().Init(vfsObj, filesystemType{}, fs)
+ return fs.Filesystem.VFSFilesystem(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ inode := vd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode)
+ b.PrependComponent(fmt.Sprintf("pipe:[%d]", inode.ino))
+ return vfs.PrependPathSyntheticError{}
+}
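+
+// For example, a pipe whose inode number is 42 has the synthetic path
+// "pipe:[42]", matching the name Linux reports for pipe FDs in
+// /proc/[pid]/fd.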
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoopRefCount
+
+ locks vfs.FileLocks
+ pipe *pipe.VFSPipe
+
+ ino uint64
+ uid auth.KUID
+ gid auth.KGID
+ // We use the creation timestamp for all of atime, mtime, and ctime.
+ ctime ktime.Time
+}
+
+func newInode(ctx context.Context, fs *filesystem) *inode {
+ creds := auth.CredentialsFromContext(ctx)
+ return &inode{
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ ino: fs.Filesystem.NextIno(),
+ uid: creds.EffectiveKUID,
+ gid: creds.EffectiveKGID,
+ ctime: ktime.NowFromContext(ctx),
+ }
+}
+
+const pipeMode = 0600 | linux.S_IFIFO
+
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, pipeMode, i.uid, i.gid)
+}
+
+// Mode implements kernfs.Inode.Mode.
+func (i *inode) Mode() linux.FileMode {
+ return pipeMode
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
+ return linux.Statx{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
+ Blksize: usermem.PageSize,
+ Nlink: 1,
+ UID: uint32(i.uid),
+ GID: uint32(i.gid),
+ Mode: pipeMode,
+ Ino: i.ino,
+ Size: 0,
+ Blocks: 0,
+ Atime: ts,
+ Ctime: ts,
+ Mtime: ts,
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: vfsfs.Impl().(*filesystem).devMinor,
+ }, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// TODO(gvisor.dev/issue/1193): kernfs does not provide a way to implement
+// statfs, which we should use to indicate PIPEFS_MAGIC.
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks)
+}
+
+// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read
+// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2).
+//
+// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
+func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ fs := mnt.Filesystem().Impl().(*filesystem)
+ inode := newInode(ctx, fs)
+ var d kernfs.Dentry
+ d.Init(inode)
+ defer d.DecRef()
+ return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
+}
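+
+// A usage sketch for pipe2(2), assuming the caller holds the Kernel's pipefs
+// mount (the returned FDs must eventually be DecRef'd; error handling and
+// FD-table installation are elided):
+//
+//	r, w := pipefs.NewConnectedPipeFDs(ctx, pipeMount, uint32(linux.O_NONBLOCK))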
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
new file mode 100644
index 000000000..6014138ff
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -0,0 +1,67 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "proc",
+ srcs = [
+ "filesystem.go",
+ "subtasks.go",
+ "task.go",
+ "task_fds.go",
+ "task_files.go",
+ "task_net.go",
+ "tasks.go",
+ "tasks_files.go",
+ "tasks_sys.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip/header",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "proc_test",
+ size = "small",
+ srcs = [
+ "tasks_sys_test.go",
+ "tasks_test.go",
+ ],
+ library = ":proc",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
new file mode 100644
index 000000000..609210253
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -0,0 +1,117 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for procfs.
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Name is the default filesystem name.
+const Name = "proc"
+
+// FilesystemType is the factory class for procfs.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, nil, fmt.Errorf("procfs requires a kernel")
+ }
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return nil, nil, fmt.Errorf("procfs requires a PID namespace")
+ }
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ procfs := &filesystem{
+ devMinor: devMinor,
+ }
+ procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
+
+ var cgroups map[string]string
+ if opts.InternalData != nil {
+ data := opts.InternalData.(*InternalData)
+ cgroups = data.Cgroups
+ }
+
+ _, dentry := procfs.newTasksInode(k, pidns, cgroups)
+ return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// dynamicInode is a convenience interface for the common procfs Inodes that
+// are backed by vfs.DynamicBytesSource types.
+type dynamicInode interface {
+ kernfs.Inode
+ vfs.DynamicBytesSource
+
+ Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+}
+
+func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+type staticFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFile)(nil)
+
+func newStaticFile(data string) *staticFile {
+ return &staticFile{StaticData: vfs.StaticData{Data: data}}
+}
+
+// InternalData contains internal data passed in to the procfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+type InternalData struct {
+ Cgroups map[string]string
+}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
new file mode 100644
index 000000000..36a89540c
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -0,0 +1,182 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// subtasksInode represents the inode for /proc/[pid]/task/ directory.
+//
+// +stateify savable
+type subtasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+
+ locks vfs.FileLocks
+
+ fs *filesystem
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+ cgroupControllers map[string]string
+}
+
+var _ kernfs.Inode = (*subtasksInode)(nil)
+
+func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *kernfs.Dentry {
+ subInode := &subtasksInode{
+ fs: fs,
+ task: task,
+ pidns: pidns,
+ cgroupControllers: cgroupControllers,
+ }
+ // Note: credentials are overridden by taskOwnedInode.
+ subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ inode := &taskOwnedInode{Inode: subInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ tid, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+
+ subTask := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if subTask == nil {
+ return nil, syserror.ENOENT
+ }
+ if subTask.ThreadGroup() != i.task.ThreadGroup() {
+ return nil, syserror.ENOENT
+ }
+
+ subTaskDentry := i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers)
+ return subTaskDentry.VFSDentry(), nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
+ if len(tasks) == 0 {
+ return offset, syserror.ENOENT
+ }
+ if relOffset >= int64(len(tasks)) {
+ return offset, nil
+ }
+
+ tids := make([]int, 0, len(tasks))
+ for _, tid := range tasks {
+ tids = append(tids, int(tid))
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+type subtasksFD struct {
+ kernfs.GenericDirectoryFD
+
+ task *kernel.Task
+}
+
+func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.IterDirents(ctx, cb)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return 0, syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.Seek(ctx, offset, whence)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *subtasksFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return linux.Statx{}, syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.Stat(ctx, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.SetStat(ctx, opts)
+}
+
+// Open implements kernfs.Inode.
+func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &subtasksFD{task: i.task}
+ if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil {
+ return nil, err
+ }
+ if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// Stat implements kernfs.Inode.
+func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ stat.Nlink += uint32(i.task.ThreadGroup().Count())
+ }
+ return stat, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat; it does not allow inode attributes to be changed.
+func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
new file mode 100644
index 000000000..8bb2b0ce1
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -0,0 +1,239 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// taskInode represents the inode for /proc/PID/ directory.
+//
+// +stateify savable
+type taskInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*taskInode)(nil)
+
+func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry {
+ // TODO(gvisor.dev/issue/164): Fail with ESRCH if task exited.
+ contents := map[string]*kernfs.Dentry{
+ "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+ "comm": fs.newComm(task, fs.NextIno(), 0444),
+ "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ "exe": fs.newExeSymlink(task, fs.NextIno()),
+ "fd": fs.newFDDirInode(task),
+ "fdinfo": fs.newFDInfoDirInode(task),
+ "gid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": fs.newTaskOwnedFile(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mountinfo": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
+ "mounts": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountsData{task: task}),
+ "net": fs.newTaskNetDir(task),
+ "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]*kernfs.Dentry{
+ "net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"),
+ "pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"),
+ "user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"),
+ }),
+ "oom_score": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newStaticFile("0\n")),
+ "oom_score_adj": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}),
+ "smaps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &smapsData{task: task}),
+ "stat": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statmData{task: task}),
+ "status": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
+ }
+ if isThreadGroup {
+ contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers)
+ }
+ if len(cgroupControllers) > 0 {
+ contents["cgroup"] = fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ }
+
+ taskInode := &taskInode{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ inode := &taskOwnedInode{Inode: taskInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := taskInode.OrderedChildren.Populate(dentry, contents)
+ taskInode.IncLinks(links)
+
+ return dentry
+}
+
+// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long
+// as the task is still running. When it's dead, another task with the same
+// PID could replace it.
+func (i *taskInode) Valid(ctx context.Context) bool {
+ return i.task.ExitState() != kernel.TaskExitDead
+}
+
+// Open implements kernfs.Inode.
+func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// SetStat implements kernfs.Inode.SetStat; it does not allow inode attributes to be changed.
+func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// taskOwnedInode implements kernfs.Inode and overrides the inode owner with
+// the task's effective user and group.
+type taskOwnedInode struct {
+ kernfs.Inode
+
+ // owner is the task that owns this inode.
+ owner *kernel.Task
+}
+
+var _ kernfs.Inode = (*taskOwnedInode)(nil)
+
+func (fs *filesystem) newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
+
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
+}
+
+func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &kernfs.StaticDirectory{}
+
+ // Note: credentials are overridden by taskOwnedInode.
+ dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm)
+
+ inode := &taskOwnedInode{Inode: dir, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := dir.OrderedChildren.Populate(d, children)
+ dir.IncLinks(links)
+
+ return d
+}
+
+// Stat implements kernfs.Inode.
+func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.Inode.Stat(fs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ uid, gid := i.getOwner(linux.FileMode(stat.Mode))
+ if opts.Mask&linux.STATX_UID != 0 {
+ stat.UID = uint32(uid)
+ }
+ if opts.Mask&linux.STATX_GID != 0 {
+ stat.GID = uint32(gid)
+ }
+ }
+ return stat, nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := i.Mode()
+ uid, gid := i.getOwner(mode)
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
+}
+
+func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
+ // By default, set the task owner as the file owner.
+ creds := i.owner.Credentials()
+ uid := creds.EffectiveKUID
+ gid := creds.EffectiveKGID
+
+ // Linux doesn't apply dumpability adjustments to world readable/executable
+ // directories so that applications can stat /proc/PID to determine the
+ // effective UID of a process. See fs/proc/base.c:task_dump_owner.
+ if mode.FileType() == linux.ModeDirectory && mode.Permissions() == 0555 {
+ return uid, gid
+ }
+
+ // If the task is not dumpable, then root (mapped into the task's user
+ // namespace, if possible) owns the file.
+ m := getMM(i.owner)
+ if m == nil {
+ return auth.RootKUID, auth.RootKGID
+ }
+ if m.Dumpability() != mm.UserDumpable {
+ uid = auth.RootKUID
+ if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() {
+ uid = kuid
+ }
+ gid = auth.RootKGID
+ if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() {
+ gid = kgid
+ }
+ }
+ return uid, gid
+}
+
+func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
+ if isThreadGroup {
+ return &ioData{ioUsage: t.ThreadGroup()}
+ }
+ return &ioData{ioUsage: t}
+}
+
+// newCgroupData creates an inode that shows cgroup information.
+// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
+// member, there is one entry containing three colon-separated fields:
+// hierarchy-ID:controller-list:cgroup-path"
+func newCgroupData(controllers map[string]string) dynamicInode {
+ var buf bytes.Buffer
+
+ // The hierarchy IDs must be positive integers (for cgroup v1), but the
+ // exact values do not matter, so long as they are unique. We can just
+ // use a counter, but since Linux sorts this file in descending order,
+ // we must count down to preserve this behavior.
+ i := len(controllers)
+ for name, dir := range controllers {
+ fmt.Fprintf(&buf, "%d:%s:%s\n", i, name, dir)
+ i--
+ }
+ return newStaticFile(buf.String())
+}
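+
+// For example, controllers = {"cpu": "/docker/abc", "memory": "/docker/abc"}
+// produces two lines with descending hierarchy IDs:
+//
+//	2:cpu:/docker/abc
+//	1:memory:/docker/abc
+//
+// Since Go map iteration order is unspecified, which controller receives
+// which ID can vary between runs; only the descending order of IDs is stable.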
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
new file mode 100644
index 000000000..fea29e5f0
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -0,0 +1,307 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) {
+ var (
+ file *vfs.FileDescription
+ flags kernel.FDFlags
+ )
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdt := t.FDTable(); fdt != nil {
+ file, flags = fdt.GetVFS2(fd)
+ }
+ })
+ return file, flags
+}
+
+func taskFDExists(t *kernel.Task, fd int32) bool {
+ file, _ := getTaskFD(t, fd)
+ if file == nil {
+ return false
+ }
+ file.DecRef()
+ return true
+}
+
+type fdDir struct {
+ locks vfs.FileLocks
+
+ fs *filesystem
+ task *kernel.Task
+
+ // When produceSymlink is set, dirents produced for the FDs are reported
+ // as symlinks. Otherwise, they are reported as regular files.
+ produceSymlink bool
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ var fds []int32
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.GetFDs()
+ }
+ })
+
+ typ := uint8(linux.DT_REG)
+ if i.produceSymlink {
+ typ = linux.DT_LNK
+ }
+
+ // Find the appropriate starting point.
+ idx := sort.Search(len(fds), func(i int) bool { return fds[i] >= int32(relOffset) })
+ if idx >= len(fds) {
+ return offset, nil
+ }
+ for _, fd := range fds[idx:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(fd), 10),
+ Type: typ,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// fdDirInode represents the inode for /proc/[pid]/fd directory.
+//
+// +stateify savable
+type fdDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdDirInode)(nil)
+
+func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
+ inode := &fdDirInode{
+ fdDir: fdDir{
+ fs: fs,
+ task: task,
+ produceSymlink: true,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ fdInt, err := strconv.ParseInt(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno())
+ return taskDentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+//
+// This is to match Linux, which uses a special permission handler to guarantee
+// that a process can still access /proc/self/fd after it has executed
+// setuid. See fs/proc/fd.c:proc_fd_permission.
+func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ err := i.InodeAttrs.CheckPermissions(ctx, creds, ats)
+ if err == nil {
+ // Access granted, no extra check needed.
+ return nil
+ }
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ // Allow access if the task trying to access it is in the thread group
+ // corresponding to this directory.
+ if i.task.ThreadGroup() == t.ThreadGroup() {
+ // Access granted (overridden).
+ return nil
+ }
+ }
+ return err
+}
+
+// fdSymlink is a symlink for the /proc/[pid]/fd/[fd] file.
+//
+// +stateify savable
+type fdSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+ fd int32
+}
+
+var _ kernfs.Inode = (*fdSymlink)(nil)
+
+func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry {
+ inode := &fdSymlink{
+ task: task,
+ fd: fd,
+ }
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return "", syserror.ENOENT
+ }
+ defer file.DecRef()
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+ return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
+}
+
+func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return vfs.VirtualDentry{}, "", syserror.ENOENT
+ }
+ defer file.DecRef()
+ vd := file.VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
+//
+// +stateify savable
+type fdInfoDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdInfoDirInode)(nil)
+
+func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
+ inode := &fdInfoDirInode{
+ fdDir: fdDir{
+ fs: fs,
+ task: task,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ fdInt, err := strconv.ParseInt(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ data := &fdInfoData{
+ task: i.task,
+ fd: fd,
+ }
+ dentry := i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data)
+ return dentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd].
+//
+// +stateify savable
+type fdInfoData struct {
+ kernfs.DynamicBytesFile
+ refs.AtomicRefCount
+
+ task *kernel.Task
+ fd int32
+}
+
+var _ dynamicInode = (*fdInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ file, descriptorFlags := getTaskFD(d.task, d.fd)
+ if file == nil {
+ return syserror.ENOENT
+ }
+ defer file.DecRef()
+ // TODO(b/121266871): Include pos, locks, and other data. For now we only
+ // have flags.
+ // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+ flags := uint(file.StatusFlags()) | descriptorFlags.ToLinuxFileFlags()
+ fmt.Fprintf(buf, "flags:\t0%o\n", flags)
+ return nil
+}
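+
+// The generated content is a single line of the form "flags:\t0<octal>". For
+// example, a close-on-exec descriptor contributes O_CLOEXEC (02000000 octal)
+// to the printed value in addition to the file's status flags.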
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
new file mode 100644
index 000000000..9af43b859
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -0,0 +1,902 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// "There is an (arbitrary) limit on the number of lines in the file. As at
+// Linux 3.18, the limit is five lines." - user_namespaces(7)
+const maxIDMapLines = 5
+
+// getMM returns the task's MemoryManager. No additional reference is taken on
+// mm here. This is safe because MemoryManager.destroy is required to leave the
+// MemoryManager in a state where it's still usable as a DynamicBytesSource.
+func getMM(task *kernel.Task) *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// getMMIncRef returns the task's MemoryManager. If getMMIncRef succeeds, the
+// MemoryManager's users count is incremented, and must be decremented by the
+// caller when it is no longer in use.
+func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
+ if task.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+ var m *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+ if m == nil || !m.IncUsers() {
+ return nil, io.EOF
+ }
+ return m, nil
+}
+
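+// checkTaskState returns nil if the task is still running, EACCES if it is a
+// zombie, and ESRCH if it is dead.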
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
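+// bufferWriter implements safemem.Writer by copying blocks into a
+// bytes.Buffer.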
+type bufferWriter struct {
+ buf *bytes.Buffer
+}
+
+// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
+// the number of bytes written. It may return a partial write without an
+// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should never
+// return a full write with an error (i.e. (srcs.NumBytes(), err) where
+// err != nil).
+func (w *bufferWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ written := srcs.NumBytes()
+ for !srcs.IsEmpty() {
+ w.buf.Write(srcs.Head().ToSlice())
+ srcs = srcs.Tail()
+ }
+ return written, nil
+}
+
+// auxvData implements vfs.DynamicBytesSource for /proc/[pid]/auxv.
+//
+// +stateify savable
+type auxvData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*auxvData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ auxv := m.Auxv()
+ // Grow the buffer to fit all entries plus an AT_NULL (0) terminator at the end.
+ buf.Grow((len(auxv) + 1) * 16)
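+ // Each entry is written as two native-endian 64-bit words, the key
+ // followed by the value, matching the layout of the auxiliary vector.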
+ for _, e := range auxv {
+ var tmp [16]byte
+ usermem.ByteOrder.PutUint64(tmp[:8], e.Key)
+ usermem.ByteOrder.PutUint64(tmp[8:], uint64(e.Value))
+ buf.Write(tmp[:])
+ }
+ var atNull [16]byte
+ buf.Write(atNull[:])
+
+ return nil
+}
+
+// execArgType enumerates the types of exec arguments that are exposed through
+// proc.
+type execArgType int
+
+const (
+ cmdlineDataArg execArgType = iota
+ environDataArg
+)
+
+// cmdlineData implements vfs.DynamicBytesSource for /proc/[pid]/cmdline.
+//
+// +stateify savable
+type cmdlineData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+}
+
+var _ dynamicInode = (*cmdlineData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ // Figure out the bounds of the exec arg we are trying to read.
+ var ar usermem.AddrRange
+ switch d.arg {
+ case cmdlineDataArg:
+ ar = usermem.AddrRange{
+ Start: m.ArgvStart(),
+ End: m.ArgvEnd(),
+ }
+ case environDataArg:
+ ar = usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+ default:
+ panic(fmt.Sprintf("unknown exec arg type %v", d.arg))
+ }
+ if ar.Start == 0 || ar.End == 0 {
+ // Don't attempt to read before the start/end are set up.
+ return io.EOF
+ }
+
+ // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true
+ // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
+ // cmdline and environment").
+ writer := &bufferWriter{buf: buf}
+ if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil {
+ // Nothing to copy or something went wrong.
+ return err
+ }
+
+ // On Linux, if the NULL byte at the end of the argument vector has been
+ // overwritten, it continues reading the environment vector as part of
+ // the argument vector.
+ if d.arg == cmdlineDataArg && buf.Bytes()[buf.Len()-1] != 0 {
+ if end := bytes.IndexByte(buf.Bytes(), 0); end != -1 {
+ // If we found a NULL character somewhere else in argv, truncate the
+ // return up to the NULL terminator (including it).
+ buf.Truncate(end)
+ return nil
+ }
+
+ // There is no NULL terminator in argv, so continue reading into envp.
+ arEnvv := usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+
+ // Upstream limits the returned amount to one page of slop.
+ // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
+ // We'll return one page total between argv and envp because of that
+ // restriction.
+ if buf.Len() >= usermem.PageSize {
+ // Returned at least one page already, nothing else to add.
+ return nil
+ }
+ remaining := usermem.PageSize - buf.Len()
+ if int(arEnvv.Length()) > remaining {
+ end, ok := arEnvv.Start.AddLength(uint64(remaining))
+ if !ok {
+ return syserror.EFAULT
+ }
+ arEnvv.End = end
+ }
+ if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil {
+ return err
+ }
+
+ // Linux will return envp up to and including the first NULL character,
+ // so find it.
+ envStart := int(ar.Length())
+ if nullIdx := bytes.IndexByte(buf.Bytes()[envStart:], 0); nullIdx != -1 {
+ buf.Truncate(envStart + nullIdx)
+ }
+ }
+
+ return nil
+}
+
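+// commInode represents the inode for /proc/[pid]/comm, which allows relaxed
+// permission checks for tasks in the same thread group.
+//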
+// +stateify savable
+type commInode struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry {
+ inode := &commInode{task: task}
+ inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // This file can always be read or written by members of the same thread
+ // group. See fs/proc/base.c:proc_tid_comm_permission.
+ //
+ // N.B. This check is currently a no-op as we don't yet support writing and
+ // this file is world-readable anyway.
+ t := kernel.TaskFromContext(ctx)
+ if t != nil && t.ThreadGroup() == i.task.ThreadGroup() && !ats.MayExec() {
+ return nil
+ }
+
+ return i.DynamicBytesFile.CheckPermissions(ctx, creds, ats)
+}
+
+// commData implements vfs.DynamicBytesSource for /proc/[pid]/comm.
+//
+// +stateify savable
+type commData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*commData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(d.task.Name())
+ buf.WriteString("\n")
+ return nil
+}
+
+// idMapData implements vfs.WritableDynamicBytesSource for
+// /proc/[pid]/{gid_map|uid_map}.
+//
+// +stateify savable
+type idMapData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ gids bool
+}
+
+var _ dynamicInode = (*idMapData)(nil)
+
+// Generate implements vfs.WritableDynamicBytesSource.Generate.
+func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var entries []auth.IDMapEntry
+ if d.gids {
+ entries = d.task.UserNamespace().GIDMap()
+ } else {
+ entries = d.task.UserNamespace().UIDMap()
+ }
+ for _, e := range entries {
+ fmt.Fprintf(buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length)
+ }
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // "In addition, the number of bytes written to the file must be less than
+ // the system page size, and the write must be performed at the start of
+ // the file ..." - user_namespaces(7)
+ srclen := src.NumBytes()
+ if srclen >= usermem.PageSize || offset != 0 {
+ return 0, syserror.EINVAL
+ }
+ b := make([]byte, srclen)
+ if _, err := src.CopyIn(ctx, b); err != nil {
+ return 0, err
+ }
+
+ // Truncate from the first NULL byte.
+ nul := int64(bytes.IndexByte(b, 0))
+ if nul == -1 {
+ nul = srclen
+ }
+ b = b[:nul]
+ // Remove the trailing newline, if present.
+ if nul >= 1 && b[nul-1] == '\n' {
+ b = b[:nul-1]
+ }
+ lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
+ if len(lines) > maxIDMapLines {
+ return 0, syserror.EINVAL
+ }
+
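+ // Each line is expected to hold three fields: the first ID inside the
+ // namespace, the first corresponding ID outside, and the length of the
+ // mapped range. See user_namespaces(7).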
+ entries := make([]auth.IDMapEntry, len(lines))
+ for i, l := range lines {
+ var e auth.IDMapEntry
+ _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
+ if err != nil {
+ return 0, syserror.EINVAL
+ }
+ entries[i] = e
+ }
+ var err error
+ if d.gids {
+ err = d.task.UserNamespace().SetGIDMap(ctx, entries)
+ } else {
+ err = d.task.UserNamespace().SetUIDMap(ctx, entries)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // On success, Linux's kernel/user_namespace.c:map_write() always returns
+ // count, even if fewer bytes were used.
+ return int64(srclen), nil
+}
+
+// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadMapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*smapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadSmapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
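+// taskStatData implements vfs.DynamicBytesSource for /proc/[pid]/stat.
+//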
+// +stateify savable
+type taskStatData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in the task's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this statData.
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*taskStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
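+// Fields are written in the order documented in proc(5); unimplemented
+// fields are hard-coded to 0.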
+func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.task))
+ fmt.Fprintf(buf, "(%s) ", s.task.Name())
+ fmt.Fprintf(buf, "%c ", s.task.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.task.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "%d ", ppid)
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.task.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.task.ThreadGroup().Session()))
+ fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(buf, "0 " /* flags */)
+ fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.task.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.task.CPUStats()
+ }
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.task.ThreadGroup().JoinedChildCPUStats()
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(buf, "%d %d ", s.task.Priority(), s.task.Niceness())
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
+
+ // rsslim.
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0)
+ if s.task == s.task.ThreadGroup().Leader() {
+ terminationSignal = s.task.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(buf, "%d ", terminationSignal)
+ fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(buf, "0\n" /* exit_code */)
+
+ return nil
+}
+
+// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*statmData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var vss, rss uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ return nil
+}
+
+// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*statusData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name())
+ fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus())
+ fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup()))
+ fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task))
+ ppid := kernel.ThreadID(0)
+ if parent := s.task.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0)
+ if tracer := s.task.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
+ fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count())
+ creds := s.task.Credentials()
+ fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(buf, "Seccomp:\t%d\n", s.task.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(buf, "Mems_allowed_list:\t0\n")
+ return nil
+}
+
+// ioUsage is the /proc/[pid]/io and /proc/[pid]/task/[tid]/io data provider.
+type ioUsage interface {
+ // IOUsage returns the io usage data.
+ IOUsage() *usage.IO
+}
+
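+// ioData implements vfs.DynamicBytesSource for /proc/[pid]/io and
+// /proc/[pid]/task/[tid]/io.
+//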
+// +stateify savable
+type ioData struct {
+ kernfs.DynamicBytesFile
+
+ ioUsage
+}
+
+var _ dynamicInode = (*ioData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage())
+
+ fmt.Fprintf(buf, "rchar: %d\n", io.CharsRead)
+ fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+ return nil
+}
+
+// oomScoreAdj is a stub of the /proc/[pid]/oom_score_adj file.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
+ fmt.Fprintf(buf, "%d\n", o.task.OOMScoreAdj())
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
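+ // SetOOMScoreAdj is expected to reject values outside Linux's
+ // [-1000, 1000] oom_score_adj range.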
+ if err := o.task.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// exeSymlink is a symlink for the /proc/[pid]/exe file.
+//
+// +stateify savable
+type exeSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*exeSymlink)(nil)
+
+func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+ inode := &exeSymlink{task: task}
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Readlink implements kernfs.Inode.
+func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return "", syserror.EACCES
+ }
+
+ // Pull out the executable for /proc/[pid]/exe.
+ exec, err := s.executable()
+ if err != nil {
+ return "", err
+ }
+ defer exec.DecRef()
+
+ return exec.PathnameWithDeleted(ctx), nil
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return vfs.VirtualDentry{}, "", syserror.EACCES
+ }
+
+ exec, err := s.executable()
+ if err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ defer exec.DecRef()
+
+ vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+func (s *exeSymlink) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(s.task); err != nil {
+ return nil, err
+ }
+
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ mm := t.MemoryManager()
+ if mm == nil {
+ err = syserror.EACCES
+ return
+ }
+
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = syserror.ESRCH
+ }
+ })
+ return
+}
+
+// mountInfoData is used to implement /proc/[pid]/mountinfo.
+//
+// +stateify savable
+type mountInfoData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var fsctx *kernel.FSContext
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return nil
+ }
+ rootDir := fsctx.RootDirectoryVFS2()
+ if !rootDir.Ok() {
+ // Root has been destroyed. Don't try to read mounts.
+ return nil
+ }
+ defer rootDir.DecRef()
+ i.task.Kernel().VFS().GenerateProcMountInfo(ctx, rootDir, buf)
+ return nil
+}
+
+// mountsData is used to implement /proc/[pid]/mounts.
+//
+// +stateify savable
+type mountsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var fsctx *kernel.FSContext
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return nil
+ }
+ rootDir := fsctx.RootDirectoryVFS2()
+ if !rootDir.Ok() {
+ // Root has been destroyed. Don't try to read mounts.
+ return nil
+ }
+ defer rootDir.DecRef()
+ i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf)
+ return nil
+}
+
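+// namespaceSymlink implements kernfs.Inode for the /proc/[pid]/ns/* symlinks.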
+type namespaceSymlink struct {
+ kernfs.StaticSymlink
+
+ task *kernel.Task
+}
+
+func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+ // Namespace symlinks should contain the namespace name and the inode
+ // number for the namespace instance, e.g. user:[123456]. We currently fake
+ // the inode number by using the symlink's own inode in its place.
+ target := fmt.Sprintf("%s:[%d]", ns, ino)
+
+ inode := &namespaceSymlink{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
+
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
+}
+
+// Readlink implements kernfs.Inode.Readlink.
+func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return "", err
+ }
+ return s.StaticSymlink.Readlink(ctx)
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+
+ // Create a synthetic inode to represent the namespace.
+ dentry := &kernfs.Dentry{}
+ dentry.Init(&namespaceInode{})
+ vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
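+ // Take a reference for the caller, then drop the creation reference so
+ // that the synthetic dentry lives only as long as the returned
+ // VirtualDentry.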
+ vd.IncRef()
+ dentry.DecRef()
+ return vd, "", nil
+}
+
+// namespaceInode is a synthetic inode created to represent a namespace in
+// /proc/[pid]/ns/*.
+type namespaceInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+}
+
+var _ kernfs.Inode = (*namespaceInode)(nil)
+
+// Init initializes a namespace inode.
+func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("only permission bits may be set: %x", perm&^linux.PermissionsMask))
+ }
+ i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+}
+
+// Open implements Inode.Open.
+func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &namespaceFD{inode: i}
+ i.IncRef()
+ fd.LockFD.Init(&i.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// namespaceFD is a synthetic file that represents a namespace in
+// /proc/[pid]/ns/*.
+type namespaceFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+ inode *namespaceInode
+}
+
+var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil)
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ return fd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *namespaceFD) Release() {
+ fd.inode.DecRef()
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
new file mode 100644
index 000000000..6bde27376
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -0,0 +1,810 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry {
+ k := task.Kernel()
+ pidns := task.PIDNamespace()
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+
+ var contents map[string]*kernfs.Dentry
+ if stack := task.NetworkNamespace().Stack(); stack != nil {
+ const (
+ arp = "IP address HW type Flags HW address Mask Device\n"
+ netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
+ packet = "sk RefCnt Type Proto Iface R Rmem User Inode\n"
+ protocols = "protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n"
+ ptype = "Type Device Function\n"
+ udp6 = " sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"
+ )
+ psched := fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond))
+
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
+ contents = map[string]*kernfs.Dentry{
+ "dev": fs.newDentry(root, fs.NextIno(), 0444, &netDevData{stack: stack}),
+ "snmp": fs.newDentry(root, fs.NextIno(), 0444, &netSnmpData{stack: stack}),
+
+ // The following files are simple stubs until they are implemented in
+ // netstack. If the file contains a header, the stub is just the header;
+ // otherwise it is an empty file.
+ "arp": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(arp)),
+ "netlink": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(netlink)),
+ "netstat": fs.newDentry(root, fs.NextIno(), 0444, &netStatData{}),
+ "packet": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(packet)),
+ "protocols": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(protocols)),
+
+ // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
+ // high res timer ticks per sec (ClockGetres returns 1ns resolution).
+ "psched": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(psched)),
+ "ptype": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(ptype)),
+ "route": fs.newDentry(root, fs.NextIno(), 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newDentry(root, fs.NextIno(), 0444, &netTCPData{kernel: k}),
+ "udp": fs.newDentry(root, fs.NextIno(), 0444, &netUDPData{kernel: k}),
+ "unix": fs.newDentry(root, fs.NextIno(), 0444, &netUnixData{kernel: k}),
+ }
+
+ if stack.SupportsIPv6() {
+ contents["if_inet6"] = fs.newDentry(root, fs.NextIno(), 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newDentry(root, fs.NextIno(), 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(udp6))
+ }
+ }
+
+ return fs.newTaskOwnedDir(task, fs.NextIno(), 0555, contents)
+}
+
+// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
+//
+// +stateify savable
+type ifinet6 struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*ifinet6)(nil)
+
+func (n *ifinet6) contents() []string {
+ var lines []string
+ nics := n.stack.Interfaces()
+ for id, naddrs := range n.stack.InterfaceAddrs() {
+ nic, ok := nics[id]
+ if !ok {
+ // NIC was added after Interfaces() was called. We'll just ignore it.
+ continue
+ }
+
+ for _, a := range naddrs {
+ // IPv6 only.
+ if a.Family != linux.AF_INET6 {
+ continue
+ }
+
+ // Fields:
+ // IPv6 address displayed in 32 hexadecimal chars without colons
+ // Netlink device number (interface index) in hexadecimal (use nic id)
+ // Prefix length in hexadecimal
+ // Scope value (use 0)
+ // Interface flags
+ // Device name
+ lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
+ }
+ }
+ return lines
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *ifinet6) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ for _, l := range n.contents() {
+ buf.WriteString(l)
+ }
+ return nil
+}
+
+// netDevData implements vfs.DynamicBytesSource for /proc/net/dev.
+//
+// +stateify savable
+type netDevData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netDevData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *netDevData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ interfaces := n.stack.Interfaces()
+ buf.WriteString("Inter-| Receive | Transmit\n")
+ buf.WriteString(" face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n")
+
+ for _, i := range interfaces {
+ // Implements the same format as
+ // net/core/net-procfs.c:dev_seq_printf_stats.
+ var stats inet.StatDev
+ if err := n.stack.Statistics(&stats, i.Name); err != nil {
+ log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err)
+ continue
+ }
+ fmt.Fprintf(
+ buf,
+ "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
+ i.Name,
+ // Received
+ stats[0], // bytes
+ stats[1], // packets
+ stats[2], // errors
+ stats[3], // dropped
+ stats[4], // fifo
+ stats[5], // frame
+ stats[6], // compressed
+ stats[7], // multicast
+ // Transmitted
+ stats[8], // bytes
+ stats[9], // packets
+ stats[10], // errors
+ stats[11], // dropped
+ stats[12], // fifo
+ stats[13], // collisions
+ stats[14], // carrier
+ stats[15], // compressed
+ )
+ }
+
+ return nil
+}
+
+// netUnixData implements vfs.DynamicBytesSource for /proc/net/unix.
+//
+// +stateify savable
+type netUnixData struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netUnixData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
+ for _, se := range n.kernel.ListSockets() {
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ continue
+ }
+ if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
+ s.DecRef()
+ // Not a unix socket.
+ continue
+ }
+ sops := s.Impl().(*unix.SocketVFS2)
+
+ addr, err := sops.Endpoint().GetLocalAddress()
+ if err != nil {
+ log.Warningf("Failed to retrieve socket name from %+v: %v", s, err)
+ addr.Addr = "<unknown>"
+ }
+
+ sockFlags := 0
+ if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
+ if ce.Listening() {
+ // For unix domain sockets, Linux reports a single flag value,
+ // __SO_ACCEPTCON, if the socket is listening.
+ }
+ }
+
+ // Get inode number.
+ var ino uint64
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_INO})
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve ino for socket file: %v", statErr)
+ } else {
+ ino = stat.Ino
+ }
+
+ // In the socket entry below, the value for the 'Num' field requires
+ // some consideration. Linux prints the address to the struct
+ // unix_sock representing a socket in the kernel, but may redact the
+ // value for unprivileged users depending on the kptr_restrict
+ // sysctl.
+ //
+ // One use for this field is to allow a privileged user to
+ // introspect into the kernel memory to determine information about
+ // a socket not available through procfs, such as the socket's peer.
+ //
+ // In gVisor, returning a pointer to our internal structures would
+ // be pointless, as it wouldn't match the memory layout for struct
+ // unix_sock, making introspection difficult. We could populate a
+ // struct unix_sock with the appropriate data, but even that
+ // requires consideration for which kernel version to emulate, as
+ // the definition of this struct changes over time.
+ //
+ // For now, we always redact this pointer.
+ fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d",
+ (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
+ s.Refs()-1, // RefCount, don't count our own ref.
+ 0, // Protocol, always 0 for UDS.
+ sockFlags, // Flags.
+ sops.Endpoint().Type(), // Type.
+ sops.State(), // State.
+ ino, // Inode.
+ )
+
+ // Path
+ if len(addr.Addr) != 0 {
+ if addr.Addr[0] == 0 {
+ // Abstract path.
+ fmt.Fprintf(buf, " @%s", string(addr.Addr[1:]))
+ } else {
+ fmt.Fprintf(buf, " %s", string(addr.Addr))
+ }
+ }
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef()
+ }
+ return nil
+}
+
+func networkToHost16(n uint16) uint16 {
+ // n is in network byte order, so is big-endian. The most-significant byte
+ // should be stored in the lower address.
+ //
+ // We manually inline binary.BigEndian.Uint16() because Go does not support
+ // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to
+ // binary.BigEndian.Uint16() require a read of binary.BigEndian and an
+ // interface method call, defeating inlining.
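+ //
+ // For example, port 8080 (0x1F90) arrives off the wire as the bytes
+ // {0x1F, 0x90}; read natively on a little-endian host that yields
+ // n = 0x901F, and this function returns 0x1F90 on either endianness.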
+ buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
+ switch family {
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if i != nil {
+ a = *i.(*linux.SockAddrInet)
+ }
+
+ // linux.SockAddrInet.Port is stored in the network byte order and is
+ // printed like a number in host byte order. Note that all numbers in host
+ // byte order are printed with the most-significant byte first when
+ // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux.
+ port := networkToHost16(a.Port)
+
+ // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order
+ // (i.e. most-significant byte in index 0). Linux represents this as a
+ // __be32 which is a typedef for an unsigned int, and is printed with
+ // %X. This means that for a little-endian machine, Linux prints the
+ // least-significant byte of the address first. To emulate this, we first
+ // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // which makes it have the equivalent encoding to a __be32 on a little
+ // endian machine. Note that this operation is a no-op on a big endian
+ // machine. Then similar to Linux, we format it with %X, which will print
+ // the most-significant byte of the __be32 address first, which is now
+ // actually the least-significant byte of the original address in
+ // linux.SockAddrInet.Addr on little endian machines, due to the conversion.
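+ //
+ // For example, the loopback address 127.0.0.1 is stored as the bytes
+ // {127, 0, 0, 1}; on a little-endian host this prints as 0100007F,
+ // matching Linux's output.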
+ addr := usermem.ByteOrder.Uint32(a.Addr[:])
+
+ fmt.Fprintf(w, "%08X:%04X ", addr, port)
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if i != nil {
+ a = *i.(*linux.SockAddrInet6)
+ }
+
+ port := networkToHost16(a.Port)
+ addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4])
+ addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8])
+ addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12])
+ addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16])
+ fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port)
+ }
+}
+
+func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, family int) error {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ for _, se := range k.ListSockets() {
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ continue
+ }
+ sops, ok := s.Impl().(socket.SocketVFS2)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
+ }
+ if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
+ s.DecRef()
+ // Not a TCP socket for the requested family.
+ continue
+ }
+
+ // Linux's documentation for the fields below can be found at
+ // https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt.
+ // For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock().
+ // Note that the header doesn't contain labels for all the fields.
+
+ // Field: sl; entry number.
+ fmt.Fprintf(buf, "%4d: ", se.ID)
+
+ // Field: local_address.
+ var localAddr linux.SockAddr
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = local
+ }
+ }
+ writeInetAddr(buf, family, localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddr
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = remote
+ }
+ }
+ writeInetAddr(buf, family, remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when; timer active state and number of jiffies
+ // until timer expires. Unimplemented.
+ fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt; number of unrecovered RTO timeouts.
+ // Unimplemented.
+ fmt.Fprintf(buf, "%08X ", 0)
+
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
+ // Field: uid.
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout; number of unanswered 0-window probes.
+ // Unimplemented.
+ fmt.Fprintf(buf, "%8d ", 0)
+
+ // Field: inode.
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
+
+ // Field: refcount. Don't count the ref we obtain while dereferencing
+ // the weakref to this socket.
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: retransmit timeout. Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: predicted tick of soft clock (delayed ACK control data).
+ // Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: (ack.quick<<1)|ack.pingpong, Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: sending congestion window, Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: Slow start size threshold, -1 if threshold >= 0xFFFF.
+ // Unimplemented, report as large threshold.
+ fmt.Fprintf(buf, "%d", -1)
+
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef()
+ }
+
+ return nil
+}
+
+// netTCPData implements vfs.DynamicBytesSource for /proc/net/tcp.
+//
+// +stateify savable
+type netTCPData struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netTCPData)(nil)
+
+func (d *netTCPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
+ return commonGenerateTCP(ctx, buf, d.kernel, linux.AF_INET)
+}
+
+// netTCP6Data implements vfs.DynamicBytesSource for /proc/net/tcp6.
+//
+// +stateify savable
+type netTCP6Data struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netTCP6Data)(nil)
+
+func (d *netTCP6Data) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")
+ return commonGenerateTCP(ctx, buf, d.kernel, linux.AF_INET6)
+}
+
+// netUDPData implements vfs.DynamicBytesSource for /proc/net/udp.
+//
+// +stateify savable
+type netUDPData struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netUDPData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n")
+
+ for _, se := range d.kernel.ListSockets() {
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ continue
+ }
+ sops, ok := s.Impl().(socket.SocketVFS2)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
+ }
+ if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
+ s.DecRef()
+ // Not a udp4 socket.
+ continue
+ }
+
+ // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock().
+
+ // Field: sl; entry number.
+ fmt.Fprintf(buf, "%5d: ", se.ID)
+
+ // Field: local_address.
+ var localAddr linux.SockAddrInet
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(buf, linux.AF_INET, &localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddrInet
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(buf, linux.AF_INET, &remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when. Always 0 for UDP.
+ fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt. Always 0 for UDP.
+ fmt.Fprintf(buf, "%08X ", 0)
+
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
+ // Field: uid.
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout. Always 0 for UDP.
+ fmt.Fprintf(buf, "%8d ", 0)
+
+ // Field: inode.
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
+
+ // Field: ref; reference count on the socket inode. Don't count the ref
+ // we obtain while dereferencing the weakref to this socket.
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: drops; number of dropped packets. Unimplemented.
+ fmt.Fprintf(buf, "%d", 0)
+
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef()
+ }
+ return nil
+}
+
+// netSnmpData implements vfs.DynamicBytesSource for /proc/net/snmp.
+//
+// +stateify savable
+type netSnmpData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netSnmpData)(nil)
+
+type snmpLine struct {
+ prefix string
+ header string
+}
+
+var snmp = []snmpLine{
+ {
+ prefix: "Ip",
+ header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates",
+ },
+ {
+ prefix: "Icmp",
+ header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps",
+ },
+ {
+ prefix: "IcmpMsg",
+ },
+ {
+ prefix: "Tcp",
+ header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors",
+ },
+ {
+ prefix: "Udp",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+ {
+ prefix: "UdpLite",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+}
+
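+// toSlice converts a pointer to a fixed-size uint64 array, such as
+// *inet.StatSNMPIP, into a []uint64 via reflection.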
+func toSlice(a interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(a))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
+func sprintSlice(s []uint64) string {
+ if len(s) == 0 {
+ return ""
+ }
+ r := fmt.Sprint(s)
+ return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice.
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ types := []interface{}{
+ &inet.StatSNMPIP{},
+ &inet.StatSNMPICMP{},
+ nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats.
+ &inet.StatSNMPTCP{},
+ &inet.StatSNMPUDP{},
+ &inet.StatSNMPUDPLite{},
+ }
+ for i, stat := range types {
+ line := snmp[i]
+ if stat == nil {
+ fmt.Fprintf(buf, "%s:\n", line.prefix)
+ fmt.Fprintf(buf, "%s:\n", line.prefix)
+ continue
+ }
+ if err := d.stack.Statistics(stat, line.prefix); err != nil {
+ if err == syserror.EOPNOTSUPP {
+ log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ } else {
+ log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ }
+ }
+
+ fmt.Fprintf(buf, "%s: %s\n", line.prefix, line.header)
+
+ if line.prefix == "Tcp" {
+ tcp := stat.(*inet.StatSNMPTCP)
+ // "Tcp" needs special processing because MaxConn is signed. RFC 2012.
+ fmt.Fprintf(buf, "%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
+ } else {
+ fmt.Fprintf(buf, "%s: %s\n", line.prefix, sprintSlice(toSlice(stat)))
+ }
+ }
+ return nil
+}
+
+// netRouteData implements vfs.DynamicBytesSource for /proc/net/route.
+//
+// +stateify savable
+type netRouteData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netRouteData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
+func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT")
+
+ interfaces := d.stack.Interfaces()
+ for _, rt := range d.stack.RouteTable() {
+ // /proc/net/route only includes ipv4 routes.
+ if rt.Family != linux.AF_INET {
+ continue
+ }
+
+ // /proc/net/route does not include broadcast or multicast routes.
+ if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST {
+ continue
+ }
+
+ iface, ok := interfaces[rt.OutputInterface]
+ if !ok || iface.Name == "lo" {
+ continue
+ }
+
+ var (
+ gw uint32
+ prefix uint32
+ flags = linux.RTF_UP
+ )
+ if len(rt.GatewayAddr) == header.IPv4AddressSize {
+ flags |= linux.RTF_GATEWAY
+ gw = usermem.ByteOrder.Uint32(rt.GatewayAddr)
+ }
+ if len(rt.DstAddr) == header.IPv4AddressSize {
+ prefix = usermem.ByteOrder.Uint32(rt.DstAddr)
+ }
+ l := fmt.Sprintf(
+ "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d",
+ iface.Name,
+ prefix,
+ gw,
+ flags,
+ 0, // RefCnt.
+ 0, // Use.
+ 0, // Metric.
+ (uint32(1)<<rt.DstLen)-1,
+ 0, // MTU.
+ 0, // Window.
+ 0, // RTT.
+ )
+ fmt.Fprintf(buf, "%-127s\n", l)
+ }
+ return nil
+}
+
+// netStatData implements vfs.DynamicBytesSource for /proc/net/netstat.
+//
+// +stateify savable
+type netStatData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. This is currently a
+// stub that only emits the TcpExt header line; see Linux's
+// net/ipv4/proc.c:netstat_seq_show for the upstream implementation.
+func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " +
+ "EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps " +
+ "LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive " +
+ "PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost " +
+ "ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog " +
+ "TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser " +
+ "TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging " +
+ "TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo " +
+ "TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit " +
+ "TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans " +
+ "TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes " +
+ "TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail " +
+ "TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent " +
+ "TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose " +
+ "TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed " +
+ "TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld " +
+ "TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected " +
+ "TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback " +
+ "TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter " +
+ "TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail " +
+ "TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK " +
+ "TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail " +
+ "TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow " +
+ "TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets " +
+ "TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv " +
+ "TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect " +
+ "TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd " +
+ "TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq " +
+ "TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge " +
+ "TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
new file mode 100644
index 000000000..2f214d0c2
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -0,0 +1,256 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ selfName = "self"
+ threadSelfName = "thread-self"
+)
+
+// tasksInode represents the inode for the /proc/ directory.
+//
+// +stateify savable
+type tasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+
+ locks vfs.FileLocks
+
+ fs *filesystem
+ pidns *kernel.PIDNamespace
+
+ // '/proc/self' and '/proc/thread-self' have custom directory offsets in
+ // Linux, so they are handled outside of OrderedChildren.
+ selfSymlink *vfs.Dentry
+ threadSelfSymlink *vfs.Dentry
+
+ // cgroupControllers is a map of controller name to directory in the
+ // cgroup hierarchy. These controllers are immutable and will be listed
+ // in /proc/pid/cgroup if not nil.
+ cgroupControllers map[string]string
+}
+
+var _ kernfs.Inode = (*tasksInode)(nil)
+
+func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) {
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+ contents := map[string]*kernfs.Dentry{
+ "cpuinfo": fs.newDentry(root, fs.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newDentry(root, fs.NextIno(), 0444, &filesystemsData{}),
+ "loadavg": fs.newDentry(root, fs.NextIno(), 0444, &loadavgData{}),
+ "sys": fs.newSysDir(root, k),
+ "meminfo": fs.newDentry(root, fs.NextIno(), 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
+ "stat": fs.newDentry(root, fs.NextIno(), 0444, &statData{}),
+ "uptime": fs.newDentry(root, fs.NextIno(), 0444, &uptimeData{}),
+ "version": fs.newDentry(root, fs.NextIno(), 0444, &versionData{}),
+ }
+
+ inode := &tasksInode{
+ pidns: pidns,
+ fs: fs,
+ selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(),
+ threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(),
+ cgroupControllers: cgroupControllers,
+ }
+ inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, contents)
+ inode.IncLinks(links)
+
+ return inode, dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ // Try to look up a corresponding task.
+ tid, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ // If it failed to parse, check if it's one of the specially handled files.
+ switch name {
+ case selfName:
+ return i.selfSymlink, nil
+ case threadSelfName:
+ return i.threadSelfSymlink, nil
+ }
+ return nil, syserror.ENOENT
+ }
+
+ task := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return nil, syserror.ENOENT
+ }
+
+ taskDentry := i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers)
+ return taskDentry.VFSDentry(), nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+ // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
+ const FIRST_PROCESS_ENTRY = 256
+
+ // Use maxTaskID to shortcut searches that will result in 0 entries.
+ const maxTaskID = kernel.TasksLimit + 1
+ if offset >= maxTaskID {
+ return offset, nil
+ }
+
+ // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories
+ // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by
+ // '/proc/thread-self' and then '/proc/[pid]'.
+ if offset < FIRST_PROCESS_ENTRY {
+ offset = FIRST_PROCESS_ENTRY
+ }
+
+ if offset == FIRST_PROCESS_ENTRY {
+ dirent := vfs.Dirent{
+ Name: selfName,
+ Type: linux.DT_LNK,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ if offset == FIRST_PROCESS_ENTRY+1 {
+ dirent := vfs.Dirent{
+ Name: threadSelfName,
+ Type: linux.DT_LNK,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+
+ // Collect all tasks whose TGIDs are at or above the TID derived from the
+ // given offset. Per Linux, a task is included in directory listings only if
+ // it is a thread group leader, but any task's directory can still be walked
+ // to directly.
+ var tids []int
+ startTid := offset - FIRST_PROCESS_ENTRY - 2
+ for _, tg := range i.pidns.ThreadGroups() {
+ tid := i.pidns.IDOfThreadGroup(tg)
+ if int64(tid) < startTid {
+ continue
+ }
+ if leader := tg.Leader(); leader != nil {
+ tids = append(tids, int(tid))
+ }
+ }
+
+ if len(tids) == 0 {
+ return offset, nil
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.fs.NextIno(),
+ NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return maxTaskID, nil
+}
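The offset arithmetic above is subtle: `self` is served at offset 256, `thread-self` at 257, and the directory for PID n at 258+n, with each dirent's `NextOff` pointing one past its own slot. A minimal standalone sketch of that mapping (names here are illustrative, not part of the patch):

```go
package main

import "fmt"

// firstProcessEntry mirrors FIRST_PROCESS_ENTRY from fs/proc/internal.h.
const firstProcessEntry = 256

// procNextOff reproduces the NextOff values IterDirents assigns: one past
// the slot of the entry being emitted.
func procNextOff(name string, pid int64) int64 {
	switch name {
	case "self":
		return firstProcessEntry + 1 // self occupies slot 256
	case "thread-self":
		return firstProcessEntry + 2 // thread-self occupies slot 257
	default:
		return firstProcessEntry + 2 + pid + 1 // PID n occupies slot 258+n
	}
}

func main() {
	fmt.Println(procNextOff("self", 0))        // 257
	fmt.Println(procNextOff("thread-self", 0)) // 258
	fmt.Println(procNextOff("", 3))            // 262; resuming here skips PIDs <= 3
}
```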
+
+// Open implements kernfs.Inode.
+func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ // Add dynamic children to link count.
+ for _, tg := range i.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ stat.Nlink++
+ }
+ }
+ }
+
+ return stat, nil
+}
+
+// staticFileSetStat implements a static file whose inode attributes may be
+// changed. This supports /proc files whose contents are read-only but whose
+// attributes (e.g. mode and ownership) may still be set.
+type staticFileSetStat struct {
+ dynamicBytesFileSetAttr
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFileSetStat)(nil)
+
+func newStaticFileSetStat(data string) *staticFileSetStat {
+ return &staticFileSetStat{StaticData: vfs.StaticData{Data: data}}
+}
+
+func cpuInfoData(k *kernel.Kernel) string {
+ features := k.FeatureSet()
+ if features == nil {
+ // Kernel is always initialized with a FeatureSet.
+ panic("cpuinfo read with nil FeatureSet")
+ }
+ var buf bytes.Buffer
+ for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
+ features.WriteCPUInfoTo(i, &buf)
+ }
+ return buf.String()
+}
+
+func shmData(v uint64) dynamicInode {
+ return newStaticFile(strconv.FormatUint(v, 10))
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
new file mode 100644
index 000000000..7d8983aa5
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -0,0 +1,384 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type selfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*selfSymlink)(nil)
+
+func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &selfSymlink{pidns: pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ if tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return strconv.FormatUint(uint64(tgid), 10), nil
+}
+
+func (s *selfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+type threadSelfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*threadSelfSymlink)(nil)
+
+func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &threadSelfSymlink{pidns: pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ tid := s.pidns.IDOfTask(t)
+ if tid == 0 || tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return fmt.Sprintf("%d/task/%d", tgid, tid), nil
+}
+
+func (s *threadSelfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// dynamicBytesFileSetAttr implements a dynamic file whose inode attributes
+// may be changed. This supports /proc files whose contents are read-only but
+// whose attributes may still be set.
+type dynamicBytesFileSetAttr struct {
+ kernfs.DynamicBytesFile
+}
+
+// SetStat implements Inode.SetStat.
+func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts)
+}
+
+// cpuStats contains the breakdown of CPU time for /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
+
+// statData implements vfs.DynamicBytesSource for /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*statData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(buf, "cpu %s\n", cpu)
+
+ k := kernel.KernelFromContext(ctx)
+ for c, max := uint(0), k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "ctxt 0\n")
+
+ // CLOCK_REALTIME timestamp from boot, in seconds.
+ fmt.Fprintf(buf, "btime %d\n", k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled.
+ fmt.Fprintf(buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+ return nil
+}
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*loadavgData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/62345059): Include real data in fields.
+ // Columns 1-3: CPU and IO utilization over the last 1, 5, and 15 minute periods.
+ // Columns 4-5: the number of currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+ return nil
+}
+
+// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*meminfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := totalSize - totalUsage
+ if memFree > totalSize {
+ // Underflow.
+ memFree = 0
+ }
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented, the value of MemAvailable
+ // should change.
+ fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree/1024)
+ fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree/1024)
+ fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return nil
+}
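The active/inactive file split above rounds `activeFile` down to a page boundary with Go's bit-clear operator (`&^`). A small self-contained sketch of that rounding, assuming the common 4 KiB page size:

```go
package main

import "fmt"

// pageSize stands in for usermem.PageSize; 4 KiB is assumed here.
const pageSize = 4096

// roundDownToPage clears the low-order bits of n, the same trick as
// `(file / 2) &^ (usermem.PageSize - 1)` in meminfoData.Generate.
func roundDownToPage(n uint64) uint64 {
	return n &^ (pageSize - 1)
}

func main() {
	fmt.Println(roundDownToPage(10000)) // 8192
	fmt.Println(roundDownToPage(4096))  // 4096
	fmt.Println(roundDownToPage(123))   // 0
}
```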
+
+// uptimeData implements vfs.DynamicBytesSource for /proc/uptime.
+//
+// +stateify savable
+type uptimeData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*uptimeData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ now := time.NowFromContext(ctx)
+
+ // Pretend that we've spent zero time sleeping (second number).
+ fmt.Fprintf(buf, "%.2f 0.00\n", now.Sub(k.Timekeeper().BootTime()).Seconds())
+ return nil
+}
+
+// versionData implements vfs.DynamicBytesSource for /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*versionData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ init := k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+// - COMPILE_USER is the user that built the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
+ fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
+ return nil
+}
+
+// filesystemsData backs /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*filesystemsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ k.VFS().GenerateProcFilesystems(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
new file mode 100644
index 000000000..6dac2afa4
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -0,0 +1,209 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// newSysDir returns the dentry corresponding to /proc/sys directory.
+func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
+ return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "kernel": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}),
+ "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)),
+ }),
+ "vm": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")),
+ }),
+ "net": fs.newSysNetDir(root, k),
+ })
+}
+
+// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
+func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
+ var contents map[string]*kernfs.Dentry
+
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
+ contents = map[string]*kernfs.Dentry{
+ "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
+
+ // The following files are simple stubs until they are implemented in
+ // netstack; most of them are configuration related. We use the value
+ // closest to the actual netstack behavior, or an empty file. All of
+ // these files have mode 0444 (read-only for all users).
+ "ip_local_port_range": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
+ "ipfrag_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+
+ // tcp_allowed_congestion_control tells the user what an unprivileged
+ // process is allowed to do, so we leave it empty.
+ "tcp_allowed_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
+
+ // Many of the following stub files correspond to features netstack
+ // doesn't support. The unsupported features return "0" to indicate they
+ // are disabled.
+ "tcp_base_mss": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ }),
+ "core": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")),
+ "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
+ "optmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "rmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "rmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "somaxconn": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("128")),
+ "wmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "wmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ }),
+ }
+ }
+
+ return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents)
+}
+
+// mmapMinAddrData implements vfs.DynamicBytesSource for
+// /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ kernfs.DynamicBytesFile
+
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*mmapMinAddrData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
+ return nil
+}
+
+// hostnameData implements vfs.DynamicBytesSource for /proc/sys/kernel/hostname.
+//
+// +stateify savable
+type hostnameData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*hostnameData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*hostnameData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ utsns := kernel.UTSNamespaceFromContext(ctx)
+ buf.WriteString(utsns.HostName())
+ buf.WriteString("\n")
+ return nil
+}
+
+// tcpSackData implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/tcp_sack.
+//
+// +stateify savable
+type tcpSackData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+ enabled *bool
+}
+
+var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.
+func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if d.enabled == nil {
+ sack, err := d.stack.TCPSACKEnabled()
+ if err != nil {
+ return err
+ }
+ d.enabled = &sack
+ }
+
+ val := "0\n"
+ if *d.enabled {
+ // Technically, this is not quite compatible with Linux. Linux stores these
+ // as an integer, so if you write "2" into tcp_sack, you should get 2 back.
+ // Tough luck.
+ val = "1\n"
+ }
+ buf.WriteString(val)
+ return nil
+}
+
+func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit the amount of memory allocated.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return n, err
+ }
+ if d.enabled == nil {
+ d.enabled = new(bool)
+ }
+ *d.enabled = v != 0
+ return n, d.stack.SetTCPSACKEnabled(*d.enabled)
+}
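As the comment in Generate notes, writes to tcp_sack are normalized to a boolean rather than stored as the raw integer. A standalone sketch of that round-trip (helper names are illustrative):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// writeSack mimics tcpSackData.Write: any integer parses, but only its
// truthiness is kept.
func writeSack(enabled *bool, s string) error {
	v, err := strconv.ParseInt(strings.TrimSpace(s), 10, 32)
	if err != nil {
		return err
	}
	*enabled = v != 0
	return nil
}

// readSack mimics tcpSackData.Generate: the stored boolean renders as
// "0\n" or "1\n", never as the integer originally written.
func readSack(enabled bool) string {
	if enabled {
		return "1\n"
	}
	return "0\n"
}

func main() {
	var enabled bool
	if err := writeSack(&enabled, "2"); err != nil {
		panic(err)
	}
	fmt.Print(readSack(enabled)) // "1": Linux would report 2 here.
}
```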
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
new file mode 100644
index 000000000..be54897bb
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
@@ -0,0 +1,78 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+)
+
+func newIPv6TestStack() *inet.TestStack {
+ s := inet.NewTestStack()
+ s.SupportsIPv6Flag = true
+ return s
+}
+
+func TestIfinet6NoAddresses(t *testing.T) {
+ n := &ifinet6{stack: newIPv6TestStack()}
+ var buf bytes.Buffer
+ n.Generate(contexttest.Context(t), &buf)
+ if buf.Len() > 0 {
+ t.Errorf("n.Generate() generated = %v, want = %v", buf.Bytes(), []byte{})
+ }
+}
+
+func TestIfinet6(t *testing.T) {
+ s := newIPv6TestStack()
+ s.InterfacesMap[1] = inet.Interface{Name: "eth0"}
+ s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{
+ {
+ Family: linux.AF_INET6,
+ PrefixLen: 128,
+ Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"),
+ },
+ }
+ s.InterfacesMap[2] = inet.Interface{Name: "eth1"}
+ s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{
+ {
+ Family: linux.AF_INET6,
+ PrefixLen: 128,
+ Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"),
+ },
+ }
+ want := map[string]struct{}{
+ "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {},
+ "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {},
+ }
+
+ n := &ifinet6{stack: s}
+ contents := n.contents()
+ if len(contents) != len(want) {
+ t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want))
+ }
+ got := map[string]struct{}{}
+ for _, l := range contents {
+ got[l] = struct{}{}
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Got n.contents() = %v, want = %v", got, want)
+ }
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
new file mode 100644
index 000000000..19abb5034
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -0,0 +1,505 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "math"
+ "path"
+ "strconv"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // Dynamic /proc entries start at offset 256 by convention; each dirent's
+ // NextOff is one past its own slot.
+ selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1}
+ threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1}
+
+ // /proc/[pid] next offset starts at 256+2 (files above), then adds the
+ // PID, and adds 1 for the next offset.
+ proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1}
+ proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1}
+ proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1}
+)
+
+var (
+ tasksStaticFiles = map[string]testutil.DirentType{
+ "cpuinfo": linux.DT_REG,
+ "filesystems": linux.DT_REG,
+ "loadavg": linux.DT_REG,
+ "meminfo": linux.DT_REG,
+ "mounts": linux.DT_LNK,
+ "net": linux.DT_LNK,
+ "self": linux.DT_LNK,
+ "stat": linux.DT_REG,
+ "sys": linux.DT_DIR,
+ "thread-self": linux.DT_LNK,
+ "uptime": linux.DT_REG,
+ "version": linux.DT_REG,
+ }
+ tasksStaticFilesNextOffs = map[string]int64{
+ "self": selfLink.NextOff,
+ "thread-self": threadSelfLink.NextOff,
+ }
+ taskStaticFiles = map[string]testutil.DirentType{
+ "auxv": linux.DT_REG,
+ "cgroup": linux.DT_REG,
+ "cmdline": linux.DT_REG,
+ "comm": linux.DT_REG,
+ "environ": linux.DT_REG,
+ "exe": linux.DT_LNK,
+ "fd": linux.DT_DIR,
+ "fdinfo": linux.DT_DIR,
+ "gid_map": linux.DT_REG,
+ "io": linux.DT_REG,
+ "maps": linux.DT_REG,
+ "mountinfo": linux.DT_REG,
+ "mounts": linux.DT_REG,
+ "net": linux.DT_DIR,
+ "ns": linux.DT_DIR,
+ "oom_score": linux.DT_REG,
+ "oom_score_adj": linux.DT_REG,
+ "smaps": linux.DT_REG,
+ "stat": linux.DT_REG,
+ "statm": linux.DT_REG,
+ "status": linux.DT_REG,
+ "task": linux.DT_DIR,
+ "uid_map": linux.DT_REG,
+ }
+)
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+ pop := &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil {
+ t.Fatalf("MkDir(/proc): %v", err)
+ }
+
+ pop = &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ mntOpts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: &InternalData{
+ Cgroups: map[string]string{
+ "cpuset": "/foo/cpuset",
+ "memory": "/foo/memory",
+ },
+ },
+ },
+ }
+ if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil {
+ t.Fatalf("MountAt(/proc): %v", err)
+ }
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+func TestTasksEmpty(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
+ s.AssertAllDirentTypes(collector, tasksStaticFiles)
+ s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
+}
+
+func TestTasks(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ expectedDirents := make(map[string]testutil.DirentType)
+ for n, d := range tasksStaticFiles {
+ expectedDirents[n] = d
+ }
+
+ k := kernel.KernelFromContext(s.Ctx)
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ tasks = append(tasks, task)
+ expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR
+ }
+
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
+ s.AssertAllDirentTypes(collector, expectedDirents)
+ s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
+
+ lastPid := 0
+ dirents := collector.OrderedDirents()
+ doneSkippingNonTaskDirs := false
+ for _, d := range dirents {
+ pid, err := strconv.Atoi(d.Name)
+ if err != nil {
+ if !doneSkippingNonTaskDirs {
+ // We haven't gotten to the task dirs yet.
+ continue
+ }
+ t.Fatalf("Invalid process directory %q", d.Name)
+ }
+ doneSkippingNonTaskDirs = true
+ if lastPid > pid {
+ t.Errorf("pids not in order: %v", dirents)
+ }
+ found := false
+ for _, t := range tasks {
+ if k.TaskSet().Root.IDOfTask(t) == kernel.ThreadID(pid) {
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("Additional task ID %d listed: %v", pid, tasks)
+ }
+ // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the
+ // PID, and adds 1 for the next offset.
+ if want := int64(256 + 2 + pid + 1); d.NextOff != want {
+ t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d)
+ }
+ }
+ if !doneSkippingNonTaskDirs {
+ t.Fatalf("Never found any process directories.")
+ }
+
+ // Test lookup.
+ for _, path := range []string{"/proc/1", "/proc/2"} {
+ fd, err := s.VFS.OpenAt(
+ s.Ctx,
+ s.Creds,
+ s.PathOpAtRoot(path),
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
+ }
+ defer fd.DecRef()
+ buf := make([]byte, 1)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
+ t.Errorf("wrong error reading directory: %v", err)
+ }
+ }
+
+ if _, err := s.VFS.OpenAt(
+ s.Ctx,
+ s.Creds,
+ s.PathOpAtRoot("/proc/9999"),
+ &vfs.OpenOptions{},
+ ); err != syserror.ENOENT {
+ t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err)
+ }
+}
+
+func TestTasksOffset(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ for i := 0; i < 3; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root); err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ }
+
+ for _, tc := range []struct {
+ name string
+ offset int64
+ wants map[string]vfs.Dirent
+ }{
+ {
+ name: "small offset",
+ offset: 100,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "offset at start",
+ offset: 256,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip /proc/self",
+ offset: 257,
+ wants: map[string]vfs.Dirent{
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip symlinks",
+ offset: 258,
+ wants: map[string]vfs.Dirent{
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip first process",
+ offset: 260,
+ wants: map[string]vfs.Dirent{
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "last process",
+ offset: 261,
+ wants: map[string]vfs.Dirent{
+ "3": proc3,
+ },
+ },
+ {
+ name: "after last",
+ offset: 262,
+ wants: nil,
+ },
+ {
+ name: "TaskLimit+1",
+ offset: kernel.TasksLimit + 1,
+ wants: nil,
+ },
+ {
+ name: "max",
+ offset: math.MaxInt64,
+ wants: nil,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ s := s.WithSubtest(t)
+ fd, err := s.VFS.OpenAt(
+ s.Ctx,
+ s.Creds,
+ s.PathOpAtRoot("/proc"),
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ defer fd.DecRef()
+ if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil {
+ t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
+ }
+
+ var collector testutil.DirentCollector
+ if err := fd.IterDirents(s.Ctx, &collector); err != nil {
+ t.Fatalf("IterDirent(): %v", err)
+ }
+
+ expectedTypes := make(map[string]testutil.DirentType)
+ expectedOffsets := make(map[string]int64)
+ for name, want := range tc.wants {
+ expectedTypes[name] = want.Type
+ if want.NextOff != 0 {
+ expectedOffsets[name] = want.NextOff
+ }
+ }
+
+ collector.SkipDotsChecks(true) // We seek()ed past the dots.
+ s.AssertAllDirentTypes(&collector, expectedTypes)
+ s.AssertDirentOffsets(&collector, expectedOffsets)
+ })
+ }
+}
+
+func TestTask(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ _, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ collector := s.ListDirents(s.PathOpAtRoot("/proc/1"))
+ s.AssertAllDirentTypes(collector, taskStaticFiles)
+}
+
+func TestProcSelf(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse("/proc/self/"),
+ FollowFinalSymlink: true,
+ })
+ s.AssertAllDirentTypes(collector, taskStaticFiles)
+}
+
+func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) {
+ t.Logf("Iterating: %s", fd.MappedName(ctx))
+
+ var collector testutil.DirentCollector
+ if err := fd.IterDirents(ctx, &collector); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ if err := collector.Contains(".", linux.DT_DIR); err != nil {
+ t.Error(err.Error())
+ }
+ if err := collector.Contains("..", linux.DT_DIR); err != nil {
+ t.Error(err.Error())
+ }
+
+ for _, d := range collector.Dirents() {
+ if d.Name == "." || d.Name == ".." {
+ continue
+ }
+ absPath := path.Join(fd.MappedName(ctx), d.Name)
+ if d.Type == linux.DT_LNK {
+ link, err := s.VFS.ReadlinkAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(absPath)},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", absPath, err)
+ } else {
+ t.Logf("Skipping symlink: %s => %s", absPath, link)
+ }
+ continue
+ }
+
+ t.Logf("Opening: %s", absPath)
+ child, err := s.VFS.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(absPath)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.OpenAt(%v) failed: %v", absPath, err)
+ continue
+ }
+ defer child.DecRef()
+ stat, err := child.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Errorf("Stat(%v) failed: %v", absPath, err)
+ }
+ if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type {
+ t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type)
+ }
+ if d.Type == linux.DT_DIR {
+ // Found another dir, let's do it again!
+ iterateDir(ctx, t, s, child)
+ }
+ }
+}
+
+// TestTree iterates all directories and stats every file.
+func TestTree(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+
+ pop := &vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse("test-file"),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ file, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, opts)
+ if err != nil {
+ t.Fatalf("failed to create test file: %v", err)
+ }
+ defer file.DecRef()
+
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ // Add file to populate /proc/[pid]/fd and fdinfo directories.
+ task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{})
+ tasks = append(tasks, task)
+ }
+
+ ctx := tasks[0]
+ fd, err := s.VFS.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(s.Ctx),
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err)
+ }
+ iterateDir(ctx, t, s, fd)
+ fd.DecRef()
+}
diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD
new file mode 100644
index 000000000..067c1657f
--- /dev/null
+++ b/pkg/sentry/fsimpl/signalfd/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
new file mode 100644
index 000000000..242ba9b5d
--- /dev/null
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -0,0 +1,136 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package signalfd
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalFileDescription implements FileDescriptionImpl for signal fds.
+type SignalFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // target is the original signal target task.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects mask.
+ mu sync.Mutex
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+var _ vfs.FileDescriptionImpl = (*SignalFileDescription)(nil)
+
+// New creates a new signal fd.
+func New(vfsObj *vfs.VirtualFilesystem, target *kernel.Task, mask linux.SignalSet, flags uint32) (*vfs.FileDescription, error) {
+ vd := vfsObj.NewAnonVirtualDentry("[signalfd]")
+ defer vd.DecRef()
+ sfd := &SignalFileDescription{
+ target: target,
+ mask: mask,
+ }
+ if err := sfd.vfsfd.Init(sfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &sfd.vfsfd, nil
+}
+
+// Mask returns the signal mask.
+func (sfd *SignalFileDescription) Mask() linux.SignalSet {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ return sfd.mask
+}
+
+// SetMask sets the signal mask.
+func (sfd *SignalFileDescription) SetMask(mask linux.SignalSet) {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ sfd.mask = mask
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sfd *SignalFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ if mask&waiter.EventIn != 0 && sfd.target.PendingSignals()&sfd.mask != 0 {
+ return waiter.EventIn // Pending signals.
+ }
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sfd *SignalFileDescription) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ // Register for the signal set; ignore the passed events.
+ sfd.target.SignalRegister(entry, waiter.EventMask(sfd.mask))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ sfd.target.SignalUnregister(entry)
+}
+
+// Release implements FileDescriptionImpl.Release.
+func (sfd *SignalFileDescription) Release() {}
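Read above always marshals into a fixed 128-byte buffer because Linux pads struct signalfd_siginfo to exactly 128 bytes, so every read returns one full record. A self-contained sketch of that fixed-size copy-out (the field list is abbreviated):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// siginfoHead models only the leading fields of Linux's signalfd_siginfo;
// the real structure is padded so every record is exactly 128 bytes.
type siginfoHead struct {
	Signo uint32
	Errno int32
	Code  int32
	PID   uint32
	UID   uint32
}

func main() {
	var enc bytes.Buffer
	// 17 is SIGCHLD on x86-64; purely illustrative.
	if err := binary.Write(&enc, binary.LittleEndian, siginfoHead{Signo: 17}); err != nil {
		panic(err)
	}
	record := make([]byte, 128) // fixed-size record, as in Read above
	copy(record, enc.Bytes())
	fmt.Println(len(record), record[:4]) // 128 [17 0 0 0]
}
```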
diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD
new file mode 100644
index 000000000..9453277b8
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "sockfs",
+ srcs = ["sockfs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
new file mode 100644
index 000000000..ee0828a15
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sockfs provides a filesystem implementation for anonymous sockets.
+package sockfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("sockfs.filesystemType.GetFilesystem should never be called")
+}
+
+// Name implements FilesystemType.Name.
+//
+// Note that sockfs does not need to be registered; the only consequence of
+// not doing so is that it will not show up under /proc/filesystems. This is
+// a very minor discrepancy from Linux.
+func (filesystemType) Name() string {
+ return "sockfs"
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// NewFilesystem sets up and returns a new sockfs filesystem.
+//
+// Note that there should only ever be one instance of sockfs.Filesystem,
+// backing a global socket mount.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, err
+ }
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.Filesystem.VFSFilesystem().Init(vfsObj, filesystemType{}, fs)
+ return fs.Filesystem.VFSFilesystem(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ inode := vd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode)
+ b.PrependComponent(fmt.Sprintf("socket:[%d]", inode.InodeAttrs.Ino()))
+ return vfs.PrependPathSyntheticError{}
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.ENXIO
+}
+
+// NewDentry constructs and returns a sockfs dentry.
+//
+// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
+func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
+ fs := mnt.Filesystem().Impl().(*filesystem)
+
+ // File mode matches net/socket.c:sock_alloc.
+ filemode := linux.FileMode(linux.S_IFSOCK | 0600)
+ i := &inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+ return d.VFSDentry()
+}
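PrependPath gives every sockfs dentry a synthetic `socket:[ino]` name and returns PrependPathSyntheticError to signal that the path is not reachable from any mount root; this is the form readlink(2) reports for socket fds under /proc/[pid]/fd. A tiny sketch of the formatting:

```go
package main

import "fmt"

// socketPathComponent renders the synthetic component PrependPath builds.
func socketPathComponent(ino uint64) string {
	return fmt.Sprintf("socket:[%d]", ino)
}

func main() {
	fmt.Println(socketPathComponent(42)) // socket:[42]
}
```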
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
new file mode 100644
index 000000000..a741e2bb6
--- /dev/null
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "sys",
+ srcs = [
+ "sys.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "sys_test",
+ srcs = ["sys_test.go"],
+ deps = [
+ ":sys",
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
new file mode 100644
index 000000000..01ce30a4d
--- /dev/null
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -0,0 +1,151 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sys implements sysfs.
+package sys
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "sysfs"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+ k := kernel.KernelFromContext(ctx)
+ maxCPUCores := k.ApplicationCores()
+ defaultSysDirMode := linux.FileMode(0755)
+
+ root := fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "block": fs.newDir(creds, defaultSysDirMode, nil),
+ "bus": fs.newDir(creds, defaultSysDirMode, nil),
+ "class": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "power_supply": fs.newDir(creds, defaultSysDirMode, nil),
+ }),
+ "dev": fs.newDir(creds, defaultSysDirMode, nil),
+ "devices": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "system": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "cpu": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ }),
+ }),
+ }),
+ "firmware": fs.newDir(creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(creds, defaultSysDirMode, nil),
+ "kernel": fs.newDir(creds, defaultSysDirMode, nil),
+ "module": fs.newDir(creds, defaultSysDirMode, nil),
+ "power": fs.newDir(creds, defaultSysDirMode, nil),
+ })
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// dir implements kernfs.Inode.
+type dir struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ d := &dir{}
+ d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|mode)
+ d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ d.dentry.Init(d)
+
+ d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents))
+
+ return &d.dentry
+}
+
+// SetStat implements kernfs.Inode.SetStat; it does not allow inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// cpuFile implements kernfs.Inode.
+type cpuFile struct {
+ kernfs.DynamicBytesFile
+ maxCores uint
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "0-%d\n", c.maxCores-1)
+ return nil
+}
+
+func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) *kernfs.Dentry {
+ c := &cpuFile{maxCores: maxCores}
+ c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
+ d := &kernfs.Dentry{}
+ d.Init(c)
+ return d
+}
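+
+// For illustration (not part of this change): with, say, 8 application
+// cores, each of the three CPU files generates the same range string:
+//
+//	var buf bytes.Buffer
+//	c := &cpuFile{maxCores: 8}
+//	_ = c.Generate(ctx, &buf)
+//	// buf.String() == "0-7\n"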
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
new file mode 100644
index 000000000..242d5fd12
--- /dev/null
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func newTestSystem(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Failed to create test kernel: %v", err)
+ }
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+ k.VFS().MustRegisterFilesystemType(sys.Name, sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("Failed to create new mount namespace: %v", err)
+ }
+ return testutil.NewSystem(ctx, t, k.VFS(), mns)
+}
+
+func TestReadCPUFile(t *testing.T) {
+ s := newTestSystem(t)
+ defer s.Destroy()
+ k := kernel.KernelFromContext(s.Ctx)
+ maxCPUCores := k.ApplicationCores()
+
+ expected := fmt.Sprintf("0-%d\n", maxCPUCores-1)
+
+ for _, fname := range []string{"online", "possible", "present"} {
+ pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname))
+ fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{})
+ if err != nil {
+ t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
+ }
+ defer fd.DecRef()
+ content, err := s.ReadToEnd(fd)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ if diff := cmp.Diff(expected, content); diff != "" {
+ t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff)
+ }
+ }
+}
+
+func TestSysRootContainsExpectedEntries(t *testing.T) {
+ s := newTestSystem(t)
+ defer s.Destroy()
+ pop := s.PathOpAtRoot("/")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{
+ "block": linux.DT_DIR,
+ "bus": linux.DT_DIR,
+ "class": linux.DT_DIR,
+ "dev": linux.DT_DIR,
+ "devices": linux.DT_DIR,
+ "firmware": linux.DT_DIR,
+ "fs": linux.DT_DIR,
+ "kernel": linux.DT_DIR,
+ "module": linux.DT_DIR,
+ "power": linux.DT_DIR,
+ })
+}
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
new file mode 100644
index 000000000..0e4053a46
--- /dev/null
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = 1,
+ srcs = [
+ "kernel.go",
+ "testutil.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/cpuid",
+ "//pkg/fspath",
+ "//pkg/memutil",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/kvm",
+ "//pkg/sentry/platform/ptrace",
+ "//pkg/sentry/time",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/usermem",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
new file mode 100644
index 000000000..e743e8114
--- /dev/null
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+
+ // Platforms are pluggable.
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace"
+)
+
+var (
+ platformFlag = flag.String("platform", "ptrace", "specify which platform to use")
+)
+
+// Boot initializes a new bare-bones kernel for tests.
+func Boot() (*kernel.Kernel, error) {
+ platformCtr, err := platform.Lookup(*platformFlag)
+ if err != nil {
+ return nil, fmt.Errorf("platform not found: %v", err)
+ }
+ deviceFile, err := platformCtr.OpenDevice()
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+ plat, err := platformCtr.New(deviceFile)
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+
+ kernel.VFS2Enabled = true
+ k := &kernel.Kernel{
+ Platform: plat,
+ }
+
+ mf, err := createMemoryFile()
+ if err != nil {
+ return nil, err
+ }
+ k.SetMemoryFile(mf)
+
+ // Pass k as the platform since it is savable, unlike the actual platform.
+ vdso, err := loader.PrepareVDSO(k)
+ if err != nil {
+ return nil, fmt.Errorf("creating vdso: %v", err)
+ }
+
+ // Create timekeeper.
+ tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
+ if err != nil {
+ return nil, fmt.Errorf("creating timekeeper: %v", err)
+ }
+ tk.SetClocks(time.NewCalibratedClocks())
+
+ creds := auth.NewRootCredentials(auth.NewRootUserNamespace())
+
+ // Initialize the Kernel object, which is required by the Context passed
+ // to createVFS in order to mount (among other things) procfs.
+ if err = k.Init(kernel.InitKernelArgs{
+ ApplicationCores: uint(runtime.GOMAXPROCS(-1)),
+ FeatureSet: cpuid.HostFeatureSet(),
+ Timekeeper: tk,
+ RootUserNamespace: creds.UserNamespace,
+ Vdso: vdso,
+ RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
+ RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
+ RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
+ }); err != nil {
+ return nil, fmt.Errorf("initializing kernel: %v", err)
+ }
+
+ k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ ls, err := limits.NewLinuxLimitSet()
+ if err != nil {
+ return nil, err
+ }
+ tg := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
+ k.TestOnly_SetGlobalInit(tg)
+
+ return k, nil
+}
+
+// CreateTask creates a new bare-bones task for tests.
+func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) {
+ k := kernel.KernelFromContext(ctx)
+ exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root)
+ if err != nil {
+ return nil, err
+ }
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
+ m.SetExecutable(fsbridge.NewVFSFile(exe))
+
+ config := &kernel.TaskConfig{
+ Kernel: k,
+ ThreadGroup: tc,
+ TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m},
+ Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
+ AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
+ UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
+ IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
+ AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ MountNamespaceVFS2: mntns,
+ FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
+ FDTable: k.NewFDTable(),
+ }
+ return k.TaskSet().NewTask(config)
+}
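+
+// An illustrative sketch (not part of this change) of how Boot and
+// CreateTask compose in a test; mntns, root, and cwd are assumptions
+// standing in for values the caller already holds:
+//
+//	k, err := Boot()
+//	if err != nil {
+//		t.Fatalf("Boot failed: %v", err)
+//	}
+//	ctx := k.SupervisorContext()
+//	task, err := CreateTask(ctx, "test-task", k.GlobalInit(), mntns, root, cwd)
+//	if err != nil {
+//		t.Fatalf("CreateTask failed: %v", err)
+//	}
+//	_ = task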
+
+func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) {
+ const name = "executable"
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ return vfsObj.OpenAt(ctx, creds, pop, opts)
+}
+
+func createMemoryFile() (*pgalloc.MemoryFile, error) {
+ const memfileName = "test-memory"
+ memfd, err := memutil.CreateMemFD(memfileName, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating memfd: %v", err)
+ }
+ memfile := os.NewFile(uintptr(memfd), memfileName)
+ mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
+ if err != nil {
+ memfile.Close()
+ return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
+ }
+ return mf, nil
+}
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
new file mode 100644
index 000000000..0556af877
--- /dev/null
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -0,0 +1,284 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil provides common test utilities for kernfs-based
+// filesystems.
+package testutil
+
+import (
+ "fmt"
+ "io"
+ "strings"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// System represents the context for a single test.
+//
+// Test systems must be explicitly destroyed with System.Destroy.
+type System struct {
+ t *testing.T
+ Ctx context.Context
+ Creds *auth.Credentials
+ VFS *vfs.VirtualFilesystem
+ Root vfs.VirtualDentry
+ MntNs *vfs.MountNamespace
+}
+
+// NewSystem constructs a System.
+//
+// Precondition: Caller must hold a reference on MntNs, whose ownership
+// is transferred to the new System.
+func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System {
+ s := &System{
+ t: t,
+ Ctx: ctx,
+ Creds: auth.CredentialsFromContext(ctx),
+ VFS: v,
+ MntNs: mns,
+ Root: mns.Root(),
+ }
+ return s
+}
+
+// WithSubtest creates a temporary test system with a new test harness,
+// referencing all other resources from the original system. This is useful when
+// a system is reused for multiple subtests, and the T needs to change for each
+// case. Note that this is safe when test cases run in parallel, as all
+// resources referenced by the system are immutable, or handle interior
+// mutations in a thread-safe manner.
+//
+// The returned system must not outlive the original and should not be destroyed
+// via System.Destroy.
+func (s *System) WithSubtest(t *testing.T) *System {
+ return &System{
+ t: t,
+ Ctx: s.Ctx,
+ Creds: s.Creds,
+ VFS: s.VFS,
+ MntNs: s.MntNs,
+ Root: s.Root,
+ }
+}
+
+// WithTemporaryContext constructs a temporary test system with a new context
+// ctx. The temporary system borrows all resources and references from the
+// original system. The returned temporary system must not outlive the original
+// system, and should not be destroyed via System.Destroy.
+func (s *System) WithTemporaryContext(ctx context.Context) *System {
+ return &System{
+ t: s.t,
+ Ctx: ctx,
+ Creds: s.Creds,
+ VFS: s.VFS,
+ MntNs: s.MntNs,
+ Root: s.Root,
+ }
+}
+
+// Destroy releases the resources associated with a test system.
+func (s *System) Destroy() {
+ s.Root.DecRef()
+ s.MntNs.DecRef() // Reference on MntNs passed to NewSystem.
+}
+
+// ReadToEnd reads the contents of fd until EOF to a string.
+func (s *System) ReadToEnd(fd *vfs.FileDescription) (string, error) {
+ buf := make([]byte, usermem.PageSize)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ opts := vfs.ReadOptions{}
+
+ var content strings.Builder
+ for {
+ n, err := fd.Read(s.Ctx, bufIOSeq, opts)
+ if n == 0 || err != nil {
+ if err == io.EOF {
+ err = nil
+ }
+ return content.String(), err
+ }
+ content.Write(buf[:n])
+ }
+}
+
+// PathOpAtRoot constructs a PathOperation with the given path from
+// the root of the filesystem.
+func (s *System) PathOpAtRoot(path string) *vfs.PathOperation {
+ return &vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse(path),
+ }
+}
+
+// GetDentryOrDie attempts to resolve a dentry referred to by the
+// provided path operation. If unsuccessful, the test fails.
+func (s *System) GetDentryOrDie(pop *vfs.PathOperation) vfs.VirtualDentry {
+ vd, err := s.VFS.GetDentryAt(s.Ctx, s.Creds, pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err)
+ }
+ return vd
+}
+
+// DirentType is an alias for values for linux_dirent64.d_type.
+type DirentType = uint8
+
+// ListDirents lists the Dirents for a directory at pop.
+func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector {
+ fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{Flags: linux.O_RDONLY})
+ if err != nil {
+ s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef()
+
+ collector := &DirentCollector{}
+ if err := fd.IterDirents(s.Ctx, collector); err != nil {
+ s.t.Fatalf("IterDirent failed: %v", err)
+ }
+ return collector
+}
+
+// AssertAllDirentTypes verifies that the set of dirents in collector contains
+// exactly the specified set of expected entries. AssertAllDirentTypes respects
+// collector.skipDots, and implicitly checks for "." and ".." accordingly.
+func (s *System) AssertAllDirentTypes(collector *DirentCollector, expected map[string]DirentType) {
+ if expected == nil {
+ expected = make(map[string]DirentType)
+ }
+ // Also implicitly check for "." and "..", if enabled.
+ if !collector.skipDots {
+ expected["."] = linux.DT_DIR
+ expected[".."] = linux.DT_DIR
+ }
+
+ dentryTypes := make(map[string]DirentType)
+ collector.mu.Lock()
+ for _, dirent := range collector.dirents {
+ dentryTypes[dirent.Name] = dirent.Type
+ }
+ collector.mu.Unlock()
+ if diff := cmp.Diff(expected, dentryTypes); diff != "" {
+ s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+// AssertDirentOffsets verifies that collector contains at least the entries
+// specified in expected, with the given NextOff field. Entries specified in
+// expected but missing from collector result in failure. Extra entries in
+// collector are ignored. AssertDirentOffsets respects collector.skipDots, and
+// implicitly checks for "." and ".." accordingly.
+func (s *System) AssertDirentOffsets(collector *DirentCollector, expected map[string]int64) {
+ // Also implicitly check for "." and "..", if enabled.
+ if !collector.skipDots {
+ expected["."] = 1
+ expected[".."] = 2
+ }
+
+ dentryNextOffs := make(map[string]int64)
+ collector.mu.Lock()
+ for _, dirent := range collector.dirents {
+ // Ignore extra entries in dentries that are not in expected.
+ if _, ok := expected[dirent.Name]; ok {
+ dentryNextOffs[dirent.Name] = dirent.NextOff
+ }
+ }
+ collector.mu.Unlock()
+ if diff := cmp.Diff(expected, dentryNextOffs); diff != "" {
+ s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+// DirentCollector provides an implementation for vfs.IterDirentsCallback for
+// testing. It simply iterates to the end of a given directory FD and collects
+// all dirents emitted by the callback.
+type DirentCollector struct {
+ mu sync.Mutex
+ order []*vfs.Dirent
+ dirents map[string]*vfs.Dirent
+ // When the collector is used in various Assert* functions, should "." and
+ // ".." be implicitly checked?
+ skipDots bool
+}
+
+// SkipDotsChecks sets whether the implicit checks on "." and ".." are
+// skipped when the collector is used in various Assert* functions. Note that
+// "." and ".." are still collected if passed to d.Handle, so the caller
+// should only skip the checks when those entries aren't expected.
+func (d *DirentCollector) SkipDotsChecks(value bool) {
+ d.skipDots = value
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (d *DirentCollector) Handle(dirent vfs.Dirent) error {
+ d.mu.Lock()
+ if d.dirents == nil {
+ d.dirents = make(map[string]*vfs.Dirent)
+ }
+ d.order = append(d.order, &dirent)
+ d.dirents[dirent.Name] = &dirent
+ d.mu.Unlock()
+ return nil
+}
+
+// Count returns the number of dirents currently in the collector.
+func (d *DirentCollector) Count() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return len(d.dirents)
+}
+
+// Contains checks whether the collector has a dirent with the given name and
+// type.
+func (d *DirentCollector) Contains(name string, typ uint8) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ dirent, ok := d.dirents[name]
+ if !ok {
+ return fmt.Errorf("No dirent named %q found", name)
+ }
+ if dirent.Type != typ {
+ return fmt.Errorf("Dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent)
+ }
+ return nil
+}
+
+// Dirents returns all dirents discovered by this collector.
+func (d *DirentCollector) Dirents() map[string]*vfs.Dirent {
+ d.mu.Lock()
+ dirents := make(map[string]*vfs.Dirent)
+ for n, d := range d.dirents {
+ dirents[n] = d
+ }
+ d.mu.Unlock()
+ return dirents
+}
+
+// OrderedDirents returns an ordered list of dirents as discovered by this
+// collector.
+func (d *DirentCollector) OrderedDirents() []*vfs.Dirent {
+ d.mu.Lock()
+ dirents := make([]*vfs.Dirent, len(d.order))
+ copy(dirents, d.order)
+ d.mu.Unlock()
+ return dirents
+}
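+
+// For illustration (not part of this change), a DirentCollector can also be
+// used directly against any open directory FD, independent of System; the
+// entry name "subdir" is a hypothetical example:
+//
+//	collector := &DirentCollector{}
+//	if err := fd.IterDirents(ctx, collector); err != nil {
+//		t.Fatalf("IterDirents failed: %v", err)
+//	}
+//	if err := collector.Contains("subdir", linux.DT_DIR); err != nil {
+//		t.Fatal(err)
+//	}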
diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD
new file mode 100644
index 000000000..fbb02a271
--- /dev/null
+++ b/pkg/sentry/fsimpl/timerfd/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "timerfd",
+ srcs = ["timerfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
new file mode 100644
index 000000000..2dc90d484
--- /dev/null
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package timerfd implements timer fds.
+package timerfd
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TimerFileDescription implements FileDescriptionImpl for timer fds. It also
+// implements ktime.TimerListener.
+type TimerFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ events waiter.Queue
+ timer *ktime.Timer
+
+ // val is the number of timer expirations since the last successful
+ // call to Read or SetTime. val must be accessed using atomic memory
+ // operations.
+ val uint64
+}
+
+var _ vfs.FileDescriptionImpl = (*TimerFileDescription)(nil)
+var _ ktime.TimerListener = (*TimerFileDescription)(nil)
+
+// New returns a new timer fd.
+func New(vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) {
+ vd := vfsObj.NewAnonVirtualDentry("[timerfd]")
+ defer vd.DecRef()
+ tfd := &TimerFileDescription{}
+ tfd.timer = ktime.NewTimer(clock, tfd)
+ if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &tfd.vfsfd, nil
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ const sizeofUint64 = 8
+ if dst.NumBytes() < sizeofUint64 {
+ return 0, syserror.EINVAL
+ }
+ if val := atomic.SwapUint64(&tfd.val, 0); val != 0 {
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
+ // Linux does not undo consuming the number of
+ // expirations even if writing to userspace fails.
+ return 0, err
+ }
+ return sizeofUint64, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// Clock returns the timer fd's Clock.
+func (tfd *TimerFileDescription) Clock() ktime.Clock {
+ return tfd.timer.Clock()
+}
+
+// GetTime returns the associated Timer's setting and the time at which it was
+// observed.
+func (tfd *TimerFileDescription) GetTime() (ktime.Time, ktime.Setting) {
+ return tfd.timer.Get()
+}
+
+// SetTime atomically changes the associated Timer's setting, resets the number
+// of expirations to 0, and returns the previous setting and the time at which
+// it was observed.
+func (tfd *TimerFileDescription) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) {
+ return tfd.timer.SwapAnd(s, func() { atomic.StoreUint64(&tfd.val, 0) })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (tfd *TimerFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ if atomic.LoadUint64(&tfd.val) != 0 {
+ ready |= waiter.EventIn
+ }
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (tfd *TimerFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ tfd.events.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (tfd *TimerFileDescription) EventUnregister(e *waiter.Entry) {
+ tfd.events.EventUnregister(e)
+}
+
+// PauseTimer pauses the associated Timer.
+func (tfd *TimerFileDescription) PauseTimer() {
+ tfd.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (tfd *TimerFileDescription) ResumeTimer() {
+ tfd.timer.Resume()
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (tfd *TimerFileDescription) Release() {
+ tfd.timer.Destroy()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (tfd *TimerFileDescription) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ atomic.AddUint64(&tfd.val, exp)
+ tfd.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (tfd *TimerFileDescription) Destroy() {}
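+
+// An illustrative sketch (not part of this change) of the intended usage:
+// create the fd, arm the timer, then read expiration counts as a
+// native-endian uint64. The kernel handle k and the 100ms period are
+// assumptions for the example:
+//
+//	fd, err := timerfd.New(vfsObj, k.MonotonicClock(), 0 /* flags */)
+//	if err != nil {
+//		return err
+//	}
+//	defer fd.DecRef()
+//	tfd := fd.Impl().(*timerfd.TimerFileDescription)
+//	tfd.SetTime(ktime.Setting{Enabled: true, Period: 100 * time.Millisecond})
+//	// Reading 8 bytes from fd now yields the number of expirations since
+//	// the last Read or SetTime, or ErrWouldBlock if there have been none.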
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
new file mode 100644
index 000000000..e73732a6b
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -0,0 +1,112 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "tmpfs",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "tmpfs",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "tmpfs",
+ srcs = [
+ "dentry_list.go",
+ "device_file.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "named_pipe.go",
+ "regular_file.go",
+ "socket_file.go",
+ "symlink.go",
+ "tmpfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/memxattr",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "benchmark_test",
+ size = "small",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ ":tmpfs",
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/refs",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "tmpfs_test",
+ size = "small",
+ srcs = [
+ "pipe_test.go",
+ "regular_file_test.go",
+ "stat_test.go",
+ "tmpfs_test.go",
+ ],
+ library = ":tmpfs",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
new file mode 100644
index 000000000..2fb5c4d84
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -0,0 +1,486 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package benchmark_test
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Differences from stat_benchmark:
+//
+// - Syscall interception, CopyInPath, copyOutStat, and overlayfs overheads are
+// not included.
+//
+// - *MountStat benchmarks use a tmpfs root mount and a tmpfs submount at /tmp.
+// Non-MountStat benchmarks use a tmpfs root mount and no submounts.
+// stat_benchmark uses a varying root mount, a tmpfs submount at /tmp, and a
+// subdirectory /tmp/<top_dir> (assuming TEST_TMPDIR == "/tmp"). Thus
+// stat_benchmark at depth 1 does a comparable amount of work to *MountStat
+// benchmarks at depth 2, and non-MountStat benchmarks at depth 3.
+var depths = []int{1, 2, 3, 8, 64, 100}
+
+const (
+ mountPointName = "tmp"
+ filename = "gvisor_test_temp_0_1557494568"
+)
+
+// This is copied from syscalls/linux/sys_file.go, with the dependency on
+// kernel.Task stripped out.
+func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent, dirFD int32, path string, resolve bool, fn func(root *fs.Dirent, d *fs.Dirent) error) error {
+ var (
+ d *fs.Dirent // The file.
+ rel *fs.Dirent // The relative directory for search (if required).
+ err error
+ )
+
+ // Extract the working directory (maybe).
+ if len(path) > 0 && path[0] == '/' {
+ // Absolute path; rel can be nil.
+ } else if dirFD == linux.AT_FDCWD {
+ // Need to reference the working directory.
+ rel = wd
+ } else {
+ // Need to extract the given FD.
+ return syserror.EBADF
+ }
+
+ // Lookup the node.
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ if resolve {
+ d, err = mntns.FindInode(ctx, root, rel, path, &remainingTraversals)
+ } else {
+ d, err = mntns.FindLink(ctx, root, rel, path, &remainingTraversals)
+ }
+ if err != nil {
+ return err
+ }
+
+ err = fn(root, d)
+ d.DecRef()
+ return err
+}
+
+func BenchmarkVFS1TmpfsStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ ctx := contexttest.Context(b)
+
+ // Create VFS.
+ tmpfsFS, ok := fs.FindFilesystem("tmpfs")
+ if !ok {
+ b.Fatalf("failed to find tmpfs filesystem type")
+ }
+ rootInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil)
+ if err != nil {
+ b.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+ mntns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ b.Fatalf("failed to create mount namespace: %v", err)
+ }
+ defer mntns.DecRef()
+
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteByte('/')
+
+ // Create nested directories with given depth.
+ root := mntns.Root()
+ defer root.DecRef()
+ d := root
+ d.IncRef()
+ defer d.DecRef()
+ for i := depth; i > 0; i-- {
+ name := fmt.Sprintf("%d", i)
+ if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil {
+ b.Fatalf("failed to create directory %q: %v", name, err)
+ }
+ next, err := d.Walk(ctx, root, name)
+ if err != nil {
+ b.Fatalf("failed to walk to directory %q: %v", name, err)
+ }
+ d.DecRef()
+ d = next
+ filePathBuilder.WriteString(name)
+ filePathBuilder.WriteByte('/')
+ }
+
+ // Create the file that will be stat'd.
+ file, err := d.Inode.Create(ctx, d, filename, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0644))
+ if err != nil {
+ b.Fatalf("failed to create file %q: %v", filename, err)
+ }
+ file.DecRef()
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ dirPath := false
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ err := fileOpOn(ctx, mntns, root, root, linux.AT_FDCWD, filePath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ uattr, err := d.Inode.UnstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+ // Sanity check.
+ if uattr.Perms.User.Execute {
+ b.Fatalf("got wrong permissions (%0o)", uattr.Perms.LinuxMode())
+ }
+ return nil
+ })
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ }
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
+ })
+ }
+}
+
+func BenchmarkVFS2TmpfsStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ ctx := contexttest.Context(b)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ b.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+ defer mntns.DecRef()
+
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteByte('/')
+
+ // Create nested directories with given depth.
+ root := mntns.Root()
+ defer root.DecRef()
+ vd := root
+ vd.IncRef()
+ for i := depth; i > 0; i-- {
+ name := fmt.Sprintf("%d", i)
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(name),
+ }
+ if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ b.Fatalf("failed to create directory %q: %v", name, err)
+ }
+ nextVD, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ b.Fatalf("failed to walk to directory %q: %v", name, err)
+ }
+ vd.DecRef()
+ vd = nextVD
+ filePathBuilder.WriteString(name)
+ filePathBuilder.WriteByte('/')
+ }
+
+ // Create the file that will be stat'd.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(filename),
+ FollowFinalSymlink: true,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: 0644,
+ })
+ vd.DecRef()
+ vd = vfs.VirtualDentry{}
+ if err != nil {
+ b.Fatalf("failed to create file %q: %v", filename, err)
+ }
+ defer fd.DecRef()
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filePath),
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ // Sanity check.
+ if stat.Mode&^linux.S_IFMT != 0644 {
+ b.Fatalf("got wrong permissions (%0o)", stat.Mode)
+ }
+ }
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
+ })
+ }
+}
+
+func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ ctx := contexttest.Context(b)
+
+ // Create VFS.
+ tmpfsFS, ok := fs.FindFilesystem("tmpfs")
+ if !ok {
+ b.Fatalf("failed to find tmpfs filesystem type")
+ }
+ rootInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil)
+ if err != nil {
+ b.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+ mntns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ b.Fatalf("failed to create mount namespace: %v", err)
+ }
+ defer mntns.DecRef()
+
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteByte('/')
+
+ // Create and mount the submount.
+ root := mntns.Root()
+ defer root.DecRef()
+ if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil {
+ b.Fatalf("failed to create mount point: %v", err)
+ }
+ mountPoint, err := root.Walk(ctx, root, mountPointName)
+ if err != nil {
+ b.Fatalf("failed to walk to mount point: %v", err)
+ }
+ defer mountPoint.DecRef()
+ submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil)
+ if err != nil {
+ b.Fatalf("failed to create tmpfs submount: %v", err)
+ }
+ if err := mntns.Mount(ctx, mountPoint, submountInode); err != nil {
+ b.Fatalf("failed to mount tmpfs submount: %v", err)
+ }
+ filePathBuilder.WriteString(mountPointName)
+ filePathBuilder.WriteByte('/')
+
+ // Create nested directories with given depth.
+ d, err := root.Walk(ctx, root, mountPointName)
+ if err != nil {
+ b.Fatalf("failed to walk to mount root: %v", err)
+ }
+ defer d.DecRef()
+ for i := depth; i > 0; i-- {
+ name := fmt.Sprintf("%d", i)
+ if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil {
+ b.Fatalf("failed to create directory %q: %v", name, err)
+ }
+ next, err := d.Walk(ctx, root, name)
+ if err != nil {
+ b.Fatalf("failed to walk to directory %q: %v", name, err)
+ }
+ d.DecRef()
+ d = next
+ filePathBuilder.WriteString(name)
+ filePathBuilder.WriteByte('/')
+ }
+
+ // Create the file that will be stat'd.
+ file, err := d.Inode.Create(ctx, d, filename, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0644))
+ if err != nil {
+ b.Fatalf("failed to create file %q: %v", filename, err)
+ }
+ file.DecRef()
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ dirPath := false
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ err := fileOpOn(ctx, mntns, root, root, linux.AT_FDCWD, filePath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ uattr, err := d.Inode.UnstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+ // Sanity check.
+ if uattr.Perms.User.Execute {
+ b.Fatalf("got wrong permissions (%0o)", uattr.Perms.LinuxMode())
+ }
+ return nil
+ })
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ }
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
+ })
+ }
+}
+
+func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ ctx := contexttest.Context(b)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ b.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+ defer mntns.DecRef()
+
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteByte('/')
+
+ // Create the mount point.
+ root := mntns.Root()
+ defer root.DecRef()
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(mountPointName),
+ }
+ if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ b.Fatalf("failed to create mount point: %v", err)
+ }
+ // Save the mount point for later use.
+ mountPoint, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ b.Fatalf("failed to walk to mount point: %v", err)
+ }
+ defer mountPoint.DecRef()
+ // Create and mount the submount.
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil {
+ b.Fatalf("failed to mount tmpfs submount: %v", err)
+ }
+ filePathBuilder.WriteString(mountPointName)
+ filePathBuilder.WriteByte('/')
+
+ // Create nested directories with given depth.
+ vd, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ b.Fatalf("failed to walk to mount root: %v", err)
+ }
+ for i := depth; i > 0; i-- {
+ name := fmt.Sprintf("%d", i)
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(name),
+ }
+ if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ b.Fatalf("failed to create directory %q: %v", name, err)
+ }
+ nextVD, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ b.Fatalf("failed to walk to directory %q: %v", name, err)
+ }
+ vd.DecRef()
+ vd = nextVD
+ filePathBuilder.WriteString(name)
+ filePathBuilder.WriteByte('/')
+ }
+
+ // Create the file that will be stat'd.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(filename),
+ FollowFinalSymlink: true,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: 0644,
+ })
+ vd.DecRef()
+ if err != nil {
+ b.Fatalf("failed to create file %q: %v", filename, err)
+ }
+ fd.DecRef()
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filePath),
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ // Sanity check.
+ if stat.Mode&^linux.S_IFMT != 0644 {
+ b.Fatalf("got wrong permissions (%0o)", stat.Mode)
+ }
+ }
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
+ })
+ }
+}
+
+func init() {
+ // Turn off reference leak checking for a fair comparison between vfs1 and
+ // vfs2.
+ refs.SetLeakMode(refs.NoLeakChecking)
+}
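+
+// For illustration (not part of this change): with a plain Go toolchain,
+// these benchmarks would typically be invoked with something like
+//
+//	go test -bench=. -run=^$ ./pkg/sentry/fsimpl/tmpfs/
+//
+// (under Bazel, the equivalent is this package's :benchmark_test target).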
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
new file mode 100644
index 000000000..ac54d420d
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -0,0 +1,49 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type deviceFile struct {
+ inode inode
+ kind vfs.DeviceKind
+ major uint32
+ minor uint32
+}
+
+func (fs *filesystem) newDeviceFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
+ file := &deviceFile{
+ kind: kind,
+ major: major,
+ minor: minor,
+ }
+ switch kind {
+ case vfs.BlockDevice:
+ mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ mode |= linux.S_IFCHR
+ default:
+ panic(fmt.Sprintf("invalid DeviceKind: %v", kind))
+ }
+ file.inode.init(file, fs, kuid, kgid, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
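+
+// For illustration (not part of this change): device inodes are created via
+// mknod(2), which reaches newDeviceFile through MknodAt in filesystem.go. A
+// hypothetical caller, using /dev/null's well-known device numbers:
+//
+//	err := vfsObj.MknodAt(ctx, creds, &pop, &vfs.MknodOptions{
+//		Mode:     linux.S_IFCHR | 0666,
+//		DevMajor: 1,
+//		DevMinor: 3,
+//	})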
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
new file mode 100644
index 000000000..0a1ad4765
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -0,0 +1,232 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type directory struct {
+ // Since directories can't be hard-linked, each directory can only be
+ // associated with a single dentry, which we can store in the directory
+ // struct.
+ dentry dentry
+ inode inode
+
+ // childMap maps the names of the directory's children to their dentries.
+ // childMap is protected by filesystem.mu.
+ childMap map[string]*dentry
+
+ // numChildren is len(childMap), but accessed using atomic memory
+ // operations to avoid locking in inode.statTo().
+ numChildren int64
+
+ // childList is a list containing (1) child dentries and (2) fake dentries
+ // (with inode == nil) that represent the iteration position of
+ // directoryFDs. childList is used to support directoryFD.IterDirents()
+ // efficiently. childList is protected by iterMu.
+ iterMu sync.Mutex
+ childList dentryList
+}
+
+func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *directory {
+ dir := &directory{}
+ dir.inode.init(dir, fs, kuid, kgid, linux.S_IFDIR|mode)
+ dir.inode.nlink = 2 // from "." and the parent directory (or from ".." for the root)
+ dir.dentry.inode = &dir.inode
+ dir.dentry.vfsd.Init(&dir.dentry)
+ return dir
+}
+
+// Preconditions: filesystem.mu must be locked for writing. dir must not
+// already contain a child with the given name.
+func (dir *directory) insertChildLocked(child *dentry, name string) {
+ child.parent = &dir.dentry
+ child.name = name
+ if dir.childMap == nil {
+ dir.childMap = make(map[string]*dentry)
+ }
+ dir.childMap[name] = child
+ atomic.AddInt64(&dir.numChildren, 1)
+ dir.iterMu.Lock()
+ dir.childList.PushBack(child)
+ dir.iterMu.Unlock()
+}
+
+// Preconditions: filesystem.mu must be locked for writing.
+func (dir *directory) removeChildLocked(child *dentry) {
+ delete(dir.childMap, child.name)
+ atomic.AddInt64(&dir.numChildren, -1)
+ dir.iterMu.Lock()
+ dir.childList.Remove(child)
+ dir.iterMu.Unlock()
+}
+
+func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid)))
+}
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ // Protected by directory.iterMu.
+ iter *dentry
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
+ if fd.iter != nil {
+ dir := fd.inode().impl.(*directory)
+ dir.iterMu.Lock()
+ dir.childList.Remove(fd.iter)
+ dir.iterMu.Unlock()
+ fd.iter = nil
+ }
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fs := fd.filesystem()
+ dir := fd.inode().impl.(*directory)
+
+ defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+
+ // fs.mu is required to read d.parent and dentry.name.
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ dir.iterMu.Lock()
+ defer dir.iterMu.Unlock()
+
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+
+ if fd.off == 0 {
+ if err := cb.Handle(vfs.Dirent{
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: dir.inode.ino,
+ NextOff: 1,
+ }); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ if fd.off == 1 {
+ parentInode := genericParentOrSelf(&dir.dentry).inode
+ if err := cb.Handle(vfs.Dirent{
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
+ }); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ var child *dentry
+ if fd.iter == nil {
+ // Start iteration at the beginning of dir.
+ child = dir.childList.Front()
+ fd.iter = &dentry{}
+ } else {
+ // Continue iteration from where we left off.
+ child = fd.iter.Next()
+ dir.childList.Remove(fd.iter)
+ }
+ for child != nil {
+ // Skip other directoryFD iterators.
+ if child.inode != nil {
+ if err := cb.Handle(vfs.Dirent{
+ Name: child.name,
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
+ }); err != nil {
+ dir.childList.InsertBefore(child, fd.iter)
+ return err
+ }
+ fd.off++
+ }
+ child = child.Next()
+ }
+ dir.childList.PushBack(fd.iter)
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ dir := fd.inode().impl.(*directory)
+ dir.iterMu.Lock()
+ defer dir.iterMu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // If the offset isn't changing (e.g. due to lseek(0, SEEK_CUR)), don't
+ // seek even if doing so might reposition the iterator due to concurrent
+ // mutation of the directory. Compare fs/libfs.c:dcache_dir_lseek().
+ if fd.off == offset {
+ return offset, nil
+ }
+
+ fd.off = offset
+ // Compensate for "." and "..".
+ remChildren := int64(0)
+ if offset >= 2 {
+ remChildren = offset - 2
+ }
+
+ // Ensure that fd.iter exists and is not linked into dir.childList.
+ if fd.iter == nil {
+ fd.iter = &dentry{}
+ } else {
+ dir.childList.Remove(fd.iter)
+ }
+ // Insert fd.iter before the remChildren'th child, or at the end of the
+ // list if remChildren >= number of children.
+ child := dir.childList.Front()
+ for child != nil {
+ // Skip other directoryFD iterators.
+ if child.inode != nil {
+ if remChildren == 0 {
+ dir.childList.InsertBefore(child, fd.iter)
+ return offset, nil
+ }
+ remChildren--
+ }
+ child = child.Next()
+ }
+ dir.childList.PushBack(fd.iter)
+ return offset, nil
+}
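+
+// For illustration (not part of this change): offsets 0 and 1 always
+// correspond to "." and "..", so seeking to offset k >= 2 positions the
+// iterator before the (k-2)th child. On a directory containing children
+// a, b, c:
+//
+//	fd.Seek(ctx, 3, linux.SEEK_SET) // the next IterDirents emits b, then c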
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
new file mode 100644
index 000000000..a0f20c2d4
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -0,0 +1,859 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // All filesystem state is in-memory.
+ return nil
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// stepLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
+ dir, ok := d.inode.impl.(*directory)
+ if !ok {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ child, ok := dir.childMap[name]
+ if !ok {
+ return nil, syserror.ENOENT
+ }
+ if err := rp.CheckMount(&child.vfsd); err != nil {
+ return nil, err
+ }
+ if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
+ // Symlink traversal updates access time.
+ child.inode.touchAtime(rp.Mount())
+ if err := rp.HandleSymlink(symlink.target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// walkParentDirLocked is loosely analogous to Linux's
+// fs/namei.c:path_parentat().
+//
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) {
+ for !rp.Final() {
+ next, err := stepLocked(rp, d)
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ dir, ok := d.inode.impl.(*directory)
+ if !ok {
+ return nil, syserror.ENOTDIR
+ }
+ return dir, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
+//
+// Preconditions: filesystem.mu must be locked.
+func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ for !rp.Done() {
+ next, err := stepLocked(rp, d)
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// doCreateAt is loosely analogous to a conjunction of Linux's
+// fs/namei.c:filename_create() and done_path_create().
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+ if _, ok := parentDir.childMap[name]; ok {
+ return syserror.EEXIST
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ // tmpfs never calls VFS.InvalidateDentry(), so parentDir.dentry can only
+ // be dead if it was deleted.
+ if parentDir.dentry.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := create(parentDir, name); err != nil {
+ return err
+ }
+
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ parentDir.inode.touchCMtime()
+ return nil
+}
+
+// AccessAt implements vfs.FilesystemImpl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return err
+ }
+ return d.inode.checkPermissions(creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ dir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return nil, err
+ }
+ dir.dentry.IncRef()
+ return &dir.dentry.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ d := vd.Dentry().Impl().(*dentry)
+ i := d.inode
+ if i.isDir() {
+ return syserror.EPERM
+ }
+ if err := vfs.MayLink(auth.CredentialsFromContext(ctx), linux.FileMode(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
+ if i.nlink == 0 {
+ return syserror.ENOENT
+ }
+ if i.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ i.incLinksLocked()
+ i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */)
+ parentDir.insertChildLocked(fs.newDentry(i), name)
+ return nil
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
+ if parentDir.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ parentDir.inode.incLinksLocked() // from child's ".."
+ childDir := fs.newDirectory(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
+ parentDir.insertChildLocked(&childDir.dentry, name)
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
+ var childInode *inode
+ switch opts.Mode.FileType() {
+ case linux.S_IFREG:
+ childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
+ case linux.S_IFIFO:
+ childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
+ case linux.S_IFBLK:
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
+ case linux.S_IFCHR:
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
+ case linux.S_IFSOCK:
+ childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint)
+ default:
+ return syserror.EINVAL
+ }
+ child := fs.newDentry(childInode)
+ parentDir.insertChildLocked(child, name)
+ return nil
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ // Not yet supported.
+ return nil, syserror.EOPNOTSUPP
+ }
+
+ // Handle O_CREAT and !O_CREAT separately, since in the latter case we
+ // don't need fs.mu for writing.
+ if opts.Flags&linux.O_CREAT == 0 {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ return d.open(ctx, rp, &opts, false /* afterCreate */)
+ }
+
+ mustCreate := opts.Flags&linux.O_EXCL != 0
+ start := rp.Start().Impl().(*dentry)
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ if rp.Done() {
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.open(ctx, rp, &opts, false /* afterCreate */)
+ }
+afterTrailingSymlink:
+ parentDir, err := walkParentDirLocked(rp, start)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return nil, syserror.EISDIR
+ }
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ // Determine whether or not we need to create a file.
+ child, ok := parentDir.childMap[name]
+ if !ok {
+ // Already checked for searchability above; now check for writability.
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer rp.Mount().EndWrite()
+ // Create and open the child.
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
+ parentDir.insertChildLocked(child, name)
+ fd, err := child.open(ctx, rp, &opts, true /* afterCreate */)
+ if err != nil {
+ return nil, err
+ }
+ parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
+ parentDir.inode.touchCMtime()
+ return fd, nil
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ // Is the file mounted over?
+ if err := rp.CheckMount(&child.vfsd); err != nil {
+ return nil, err
+ }
+ // Do we need to resolve a trailing symlink?
+ if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
+ // Symlink traversal updates access time.
+ child.inode.touchAtime(rp.Mount())
+ if err := rp.HandleSymlink(symlink.target); err != nil {
+ return nil, err
+ }
+ start = &parentDir.dentry
+ goto afterTrailingSymlink
+ }
+ // Open the existing file; mustCreate was already rejected above.
+ return child.open(ctx, rp, &opts, false /* afterCreate */)
+}
+
+func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if !afterCreate {
+ if err := d.inode.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ }
+ switch impl := d.inode.impl.(type) {
+ case *regularFile:
+ var fd regularFileFD
+ fd.LockFD.Init(&d.inode.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil {
+ return nil, err
+ }
+ if !afterCreate && opts.Flags&linux.O_TRUNC != 0 {
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
+ }
+ return &fd.vfsfd, nil
+ case *directory:
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ var fd directoryFD
+ fd.LockFD.Init(&d.inode.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case *symlink:
+ // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH.
+ return nil, syserror.ELOOP
+ case *namedPipe:
+ return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks)
+ case *deviceFile:
+ return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
+ case *socketFile:
+ return nil, syserror.ENXIO
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl))
+ }
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return "", err
+ }
+ symlink, ok := d.inode.impl.(*symlink)
+ if !ok {
+ return "", syserror.EINVAL
+ }
+ symlink.inode.touchAtime(rp.Mount())
+ return symlink.target, nil
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ // TODO(b/145974740): Support renameat2 flags.
+ return syserror.EINVAL
+ }
+
+ // Resolve newParent first to verify that it's on this Mount.
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ newParentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ oldParentDir := oldParentVD.Dentry().Impl().(*dentry).inode.impl.(*directory)
+ if err := oldParentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ renamed, ok := oldParentDir.childMap[oldName]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil {
+ return err
+ }
+ // Note that we don't need to call rp.CheckMount(), since if renamed is a
+ // mount point then we want to rename the mount point, not anything in the
+ // mounted filesystem.
+ if renamed.inode.isDir() {
+ if renamed == &newParentDir.dentry || genericIsAncestorDentry(renamed, &newParentDir.dentry) {
+ return syserror.EINVAL
+ }
+ if oldParentDir != newParentDir {
+ // Writability is needed to change renamed's "..".
+ if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if err := newParentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ replaced, ok := newParentDir.childMap[newName]
+ if ok {
+ replacedDir, ok := replaced.inode.impl.(*directory)
+ if ok {
+ if !renamed.inode.isDir() {
+ return syserror.EISDIR
+ }
+ if len(replacedDir.childMap) != 0 {
+ return syserror.ENOTEMPTY
+ }
+ } else {
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ if renamed.inode.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ } else {
+ if renamed.inode.isDir() && newParentDir.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ }
+ // tmpfs never calls VFS.InvalidateDentry(), so newParentDir.dentry can
+ // only be dead if it was deleted.
+ if newParentDir.dentry.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+
+ // Linux places this check before some of those above; we do it here for
+ // simplicity, under the assumption that applications are not intentionally
+ // doing noop renames expecting them to succeed where non-noop renames
+ // would fail.
+ if renamed == replaced {
+ return nil
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ var replacedVFSD *vfs.Dentry
+ if replaced != nil {
+ replacedVFSD = &replaced.vfsd
+ }
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
+ return err
+ }
+ if replaced != nil {
+ newParentDir.removeChildLocked(replaced)
+ if replaced.inode.isDir() {
+ newParentDir.inode.decLinksLocked() // from replaced's ".."
+ }
+ replaced.inode.decLinksLocked()
+ }
+ oldParentDir.removeChildLocked(renamed)
+ newParentDir.insertChildLocked(renamed, newName)
+ vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD)
+ oldParentDir.inode.touchCMtime()
+ if oldParentDir != newParentDir {
+ if renamed.inode.isDir() {
+ oldParentDir.inode.decLinksLocked()
+ newParentDir.inode.incLinksLocked()
+ }
+ newParentDir.inode.touchCMtime()
+ }
+ renamed.inode.touchCtime()
+
+ vfs.InotifyRename(ctx, &renamed.inode.watches, &oldParentDir.inode.watches, &newParentDir.inode.watches, oldName, newName, renamed.inode.isDir())
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ child, ok := parentDir.childMap[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ childDir, ok := child.inode.impl.(*directory)
+ if !ok {
+ return syserror.ENOTDIR
+ }
+ if len(childDir.childMap) != 0 {
+ return syserror.ENOTEMPTY
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ parentDir.removeChildLocked(child)
+ parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ // Remove links for child, child/., and child/..
+ child.inode.decLinksLocked()
+ child.inode.decLinksLocked()
+ parentDir.inode.decLinksLocked()
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ parentDir.inode.touchCMtime()
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ fs.mu.RLock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ var stat linux.Statx
+ d.inode.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ if _, err := resolveLocked(rp); err != nil {
+ return linux.Statfs{}, err
+ }
+ statfs := linux.Statfs{
+ Type: linux.TMPFS_MAGIC,
+ BlockSize: usermem.PageSize,
+ FragmentSize: usermem.PageSize,
+ NameLength: linux.NAME_MAX,
+ // TODO(b/29637826): Allow configuring a tmpfs size and enforce it.
+ Blocks: 0,
+ BlocksFree: 0,
+ }
+ return statfs, nil
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target))
+ parentDir.insertChildLocked(child, name)
+ return nil
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ child, ok := parentDir.childMap[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ if child.inode.isDir() {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+
+ // Generate inotify events. Note that this must take place before the link
+ // count of the child is decremented, or else the watches may be dropped
+ // before these events are added.
+ vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name)
+
+ parentDir.removeChildLocked(child)
+ child.inode.decLinksLocked()
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ parentDir.inode.touchCMtime()
+ return nil
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ switch impl := d.inode.impl.(type) {
+ case *socketFile:
+ return impl.ep, nil
+ default:
+ return nil, syserror.ECONNREFUSED
+ }
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ return d.inode.listxattr(size)
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return "", err
+ }
+ return d.inode.getxattr(rp.Credentials(), &opts)
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.removexattr(rp.Credentials(), name); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ mnt := vd.Mount()
+ d := vd.Dentry().Impl().(*dentry)
+ for {
+ if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
+ return vfs.PrependPathAtVFSRootError{}
+ }
+ if &d.vfsd == mnt.Root() {
+ return nil
+ }
+ if d.parent == nil {
+ if d.name != "" {
+ // This must be an anonymous memfd file.
+ b.PrependComponent("/" + d.name)
+ return vfs.PrependPathSyntheticError{}
+ }
+ return vfs.PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
new file mode 100644
index 000000000..739350cf0
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -0,0 +1,38 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type namedPipe struct {
+ inode inode
+
+ pipe *pipe.VFSPipe
+}
+
+// Preconditions:
+// * fs.mu must be locked.
+// * rp.Mount().CheckBeginWrite() has been called successfully.
+func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
+ file.inode.nlink = 1 // Only the parent has a link.
+ return &file.inode
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
new file mode 100644
index 000000000..1614f2c39
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -0,0 +1,238 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const fileName = "mypipe"
+
+func TestSeparateFDs(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side. This is done concurrently because opening one
+ // end of the pipe blocks until the other end is opened.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
+ FollowFinalSymlink: true,
+ }
+ rfdchan := make(chan *vfs.FileDescription)
+ go func() {
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY}
+ rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ rfdchan <- rfd
+ }()
+
+ // Open the write side.
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ rfd, ok := <-rfdchan
+ if !ok {
+ t.Fatalf("failed to open pipe for reading %q", fileName)
+ }
+ defer rfd.DecRef()
+
+ const msg = "vamos azul"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingRead(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side as nonblocking.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
+ rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
+ }
+ defer rfd.DecRef()
+
+ // Open the write side.
+ openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ const msg = "geh blau"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingWriteError(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the write side as nonblocking, which should return ENXIO.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
+ _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != syserror.ENXIO {
+ t.Fatalf("expected ENXIO, but got error: %v", err)
+ }
+}
+
+func TestSingleFD(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the pipe as readable and writable.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
+ fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer fd.DecRef()
+
+ const msg = "forza blu"
+ checkEmpty(ctx, t, fd)
+ checkWrite(ctx, t, fd, msg)
+ checkRead(ctx, t, fd, msg)
+}
+
+// setup creates a VFS with a pipe in the root directory at path fileName. The
+// returned VirtualDentry must be DecRef()'d by the caller. It calls t.Fatal
+// upon failure.
+func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+
+ // Create the pipe.
+ root := mntns.Root()
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
+ }
+ mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
+ if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
+ t.Fatalf("failed to create file %q: %v", fileName, err)
+ }
+
+ // Sanity check: the pipe file exists and has the correct mode.
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat(%q) failed: %v", fileName, err)
+ }
+ if stat.Mode&^linux.S_IFMT != 0644 {
+ t.Errorf("got wrong permissions (%0o)", stat.Mode)
+ }
+ if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe {
+ t.Errorf("got wrong file type (%0o)", stat.Mode)
+ }
+
+ return ctx, creds, vfsObj, root
+}
+
+// checkEmpty calls t.Fatal if the pipe in fd is not empty.
+func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ readData := make([]byte, 1)
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
+ }
+ if bytesRead != 0 {
+ t.Fatalf("expected to read 0 bytes, but got %d", bytesRead)
+ }
+}
+
+// checkWrite calls t.Fatal if it fails to write all of msg to fd.
+func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ writeData := []byte(msg)
+ src := usermem.BytesIOSequence(writeData)
+ bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("error writing to pipe %q: %v", fileName, err)
+ }
+ if bytesWritten != int64(len(writeData)) {
+ t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten)
+ }
+}
+
+// checkRead calls t.Fatal if it fails to read msg from fd.
+func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ readData := make([]byte, len(msg))
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("error reading from pipe %q: %v", fileName, err)
+ }
+ if bytesRead != int64(len(msg)) {
+ t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead)
+ }
+ if !bytes.Equal(readData, []byte(msg)) {
+ t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData))
+ }
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
new file mode 100644
index 000000000..1cdb46e6f
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -0,0 +1,626 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// regularFile is a regular (=S_IFREG) tmpfs file.
+type regularFile struct {
+ inode inode
+
+ // memFile is a platform.File used to allocate pages to this regularFile.
+ memFile *pgalloc.MemoryFile
+
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
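+ //
+ // (For example, with 4 KiB pages a single 2 MiB writable mapping
+ // contributes 512 to this counter rather than 2097152.)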
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
+
+ // data maps offsets into the file to offsets into memFile that store
+ // the file's data.
+ //
+ // Protected by dataMu.
+ data fsutil.FileRangeSet
+
+ // seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
+ seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
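+ //
+ // A minimal sketch of such a lock-free read (illustrative only):
+ //
+ //	sz := atomic.LoadUint64(&rf.size)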
+ size uint64
+}
+
+func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
+ file := &regularFile{
+ memFile: fs.memFile,
+ seals: linux.F_SEAL_SEAL,
+ }
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
+
+// truncate grows or shrinks the file to the given size. It returns true if the
+// file size was updated.
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
+
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
+ // Nothing to do.
+ return false, nil
+ }
+
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
+ if rf.seals&linux.F_SEAL_GROW != 0 {
+ rf.dataMu.Unlock()
+ return false, syserror.EPERM
+ }
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+ return true, nil
+ }
+
+ // We are shrinking the file. First check if this is allowed.
+ if rf.seals&linux.F_SEAL_SHRINK != 0 {
+ rf.dataMu.Unlock()
+ return false, syserror.EPERM
+ }
+
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
+
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
+ return true, nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to rf.size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by rf.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset. off is accessed using atomic memory operations.
+ // offMu serializes operations that may mutate off.
+ off int64
+ offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {
+ // noop
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
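+//
+// Note that this implementation only ever grows the file: a request lying
+// entirely within the current size (offset+length <= size) is a no-op, and
+// the mode argument is not consulted.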
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ f := fd.inode().impl.(*regularFile)
+
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ oldSize := f.size
+ size := offset + length
+ if oldSize >= size {
+ return nil
+ }
+ _, err := f.truncateLocked(size)
+ return err
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
+ // all state is in-memory.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putRegularFileReadWriter(rw)
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+ return n, err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
+ // all state is in-memory.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ srclen := src.NumBytes()
+ if srclen == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ if end := offset + srclen; end < offset {
+ // Overflow.
+ return 0, syserror.EINVAL
+ }
+
+ var err error
+ srclen, err = vfs.CheckLimit(ctx, offset, srclen)
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(srclen)
+
+ f.inode.mu.Lock()
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ fd.inode().touchCMtimeLocked()
+ f.inode.mu.Unlock()
+ putRegularFileReadWriter(rw)
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size))
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
+// regularFileReadWriter implements safemem.Reader and safemem.Writer.
+type regularFileReadWriter struct {
+ file *regularFile
+
+ // Offset into the file to read/write at. Note that this may be
+ // different from the FD offset if PRead/PWrite is used.
+ off uint64
+}
+
+var regularFileReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &regularFileReadWriter{}
+ },
+}
+
+func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter {
+ rw := regularFileReadWriterPool.Get().(*regularFileReadWriter)
+ rw.file = file
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putRegularFileReadWriter(rw *regularFileReadWriter) {
+ rw.file = nil
+ regularFileReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= size {
+ return 0, io.EOF
+ }
+ end := size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
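+ // (Illustrative: with size >= 8192 and only [0, 4096)
+ // allocated, a read of [0, 8192) copies one page of data
+ // above and zero-fills the second page here.)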
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
+func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
+
+ // Compute the range to write (overflow-checked).
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ return 0, syserror.EPERM
+ case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artificially truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
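+ //
+ // (Illustrative: with 4 KiB pages and rw.file.size == 5000, a
+ // grow-sealed write starting at offset 3000 is clamped to
+ // end == 4096, so at most 1096 bytes are written.)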
+ if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.off {
+ // Truncation would result in no data being written.
+ return 0, syserror.EPERM
+ }
+ }
+
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.off).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += uint64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.off > rw.file.size {
+ rw.file.size = rw.off
+ }
+
+ return done, retErr
+}
+
+// GetSeals returns the current set of seals on a memfd inode.
+func GetSeals(fd *vfs.FileDescription) (uint32, error) {
+ f, ok := fd.Impl().(*regularFileFD)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ rf := f.inode().impl.(*regularFile)
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+ return rf.seals, nil
+}
+
+// AddSeals adds new file seals to a memfd inode.
+func AddSeals(fd *vfs.FileDescription, val uint32) error {
+ f, ok := fd.Impl().(*regularFileFD)
+ if !ok {
+ return syserror.EINVAL
+ }
+ rf := f.inode().impl.(*regularFile)
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ if rf.seals&linux.F_SEAL_SEAL != 0 {
+ // Seal applied which prevents addition of any new seals.
+ return syserror.EPERM
+ }
+
+ // F_SEAL_WRITE can only be added if there are no active writable maps.
+ if rf.seals&linux.F_SEAL_WRITE == 0 && val&linux.F_SEAL_WRITE != 0 {
+ if rf.writableMappingPages > 0 {
+ return syserror.EBUSY
+ }
+ }
+
+ // Seals can only be added, never removed.
+ rf.seals |= val
+ return nil
+}
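+
+// A minimal usage sketch for the seal helpers (illustrative only; assumes fd
+// is backed by a tmpfs regular file, e.g. one created via memfd_create):
+//
+//	seals, err := GetSeals(fd)
+//	if err == nil && seals&linux.F_SEAL_SEAL == 0 {
+//		err = AddSeals(fd, linux.F_SEAL_GROW|linux.F_SEAL_SHRINK)
+//	}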
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
new file mode 100644
index 000000000..146c7fdfe
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -0,0 +1,349 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Test that we can write some data to a file and read it back.
+func TestSimpleWriteRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write.
+ data := []byte("foobarbaz")
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+
+ // Seek back to beginning.
+ if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Seek failed: %v", err)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want {
+ t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want)
+ }
+
+ // Read.
+ buf := make([]byte, len(data))
+ n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Read got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(buf), string(data); got != want {
+ t.Errorf("Read got %q want %s", got, want)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+}
+
+func TestPWrite(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill file with 1k 'a's.
+ data := bytes.Repeat([]byte{'a'}, 1000)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Write "gVisor is awesome" at various offsets.
+ buf := []byte("gVisor is awesome")
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PWrite offset=%d", offset)
+ t.Run(name, func(t *testing.T) {
+ n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{})
+ if err != nil {
+ t.Errorf("fd.PWrite got err %v want nil", err)
+ }
+ if n != int64(len(buf)) {
+ t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf))
+ }
+
+ // Update data to reflect expected file contents.
+ if len(data) < offset+len(buf) {
+ data = append(data, make([]byte, (offset+len(buf))-len(data))...)
+ }
+ copy(data[offset:], buf)
+
+ // Read the whole file and compare with data.
+ readBuf := make([]byte, len(data))
+ n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("fd.PRead failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.PRead got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(readBuf), string(data); got != want {
+ t.Errorf("PRead got %q want %s", got, want)
+ }
+ })
+ }
+}
+
+func TestLocks(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ uid1 := 123
+ uid2 := 456
+ if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+ if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want {
+ t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want)
+ }
+ if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil {
+ t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err)
+ }
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want {
+ t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want)
+ }
+ if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err)
+ }
+}
+
+func TestPRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write 100 sequences of 'gVisor is awesome'.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Read various sizes from various offsets.
+ sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000}
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+
+ for _, size := range sizes {
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PRead offset=%d size=%d", offset, size)
+ t.Run(name, func(t *testing.T) {
+ var (
+ wantRead []byte
+ wantErr error
+ )
+ if offset < len(data) {
+ wantRead = data[offset:]
+ } else if size > 0 {
+ wantErr = io.EOF
+ }
+ if offset+size < len(data) {
+ wantRead = wantRead[:size]
+ }
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{})
+ if err != wantErr {
+ t.Errorf("fd.PRead got err %v want %v", err, wantErr)
+ }
+ if n != int64(len(wantRead)) {
+ t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead))
+ }
+ if got := string(buf[:n]); got != string(wantRead) {
+ t.Errorf("fd.PRead got %q want %q", got, string(wantRead))
+ }
+ })
+ }
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill the file with some data.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+
+ // Size should be same as written.
+ sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE}
+ stat, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := int64(stat.Size), written; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+
+ // Truncate down.
+ newSize := uint64(10)
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateDown.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should only read newSize worth of data.
+ buf := make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime)
+ }
+ if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate up.
+ newSize = 100
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateUp.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should read newSize worth of data.
+ buf = make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Bytes after offset 10 should be zero, since we previously truncated to 10.
+ for i := uint64(10); i < newSize; i++ {
+ if buf[i] != 0 {
+ t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i])
+ break
+ }
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime)
+ }
+ if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate to the current size.
+ newSize = statAfterTruncateUp.Size
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ // Mtime and Ctime should not be bumped, since the operation is a no-op.
+ if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime)
+ }
+ if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime)
+ }
+}
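The expected-result computation in TestPRead above is subtle: a read starting at or beyond EOF only fails when it actually requests bytes, and a read that straddles EOF returns a short read with no error. A minimal standalone sketch of that oracle, outside the patch (expectedPRead is a hypothetical helper name):

```go
package main

import (
	"fmt"
	"io"
)

// expectedPRead mirrors the test's oracle: given the file contents and a
// requested (offset, size), it returns the bytes a pread should yield and
// the expected error. A read at or past EOF returns io.EOF only if size > 0.
func expectedPRead(data []byte, offset, size int) ([]byte, error) {
	if offset >= len(data) {
		if size > 0 {
			return nil, io.EOF
		}
		return nil, nil
	}
	want := data[offset:]
	if size < len(want) {
		want = want[:size] // full read within the file
	}
	return want, nil // possibly a short read at EOF, with no error
}

func main() {
	data := []byte("gVisor is awesome") // 17 bytes
	for _, c := range []struct{ off, size int }{{0, 6}, {10, 100}, {20, 4}, {20, 0}} {
		got, err := expectedPRead(data, c.off, c.size)
		fmt.Printf("off=%d size=%d -> %q err=%v\n", c.off, c.size, got, err)
	}
}
```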
diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
new file mode 100644
index 000000000..3ed650474
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -0,0 +1,34 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// socketFile is a socket (=S_IFSOCK) tmpfs file.
+type socketFile struct {
+ inode inode
+ ep transport.BoundEndpoint
+}
+
+func (fs *filesystem) newSocketFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
+ file := &socketFile{ep: ep}
+ file.inode.init(file, fs, kuid, kgid, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
new file mode 100644
index 000000000..f7ee4aab2
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -0,0 +1,236 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func TestStatAfterCreate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ got, err := fd.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Atime, Ctime, Mtime should all be current time (non-zero).
+ atime, ctime, mtime := got.Atime.ToNsec(), got.Ctime.ToNsec(), got.Mtime.ToNsec()
+ if atime != ctime || ctime != mtime {
+ t.Errorf("got atime=%d ctime=%d mtime=%d, wanted equal values", atime, ctime, mtime)
+ }
+ if atime == 0 {
+ t.Errorf("got atime=%d, want non-zero", atime)
+ }
+
+ // Btime should be 0, as it is not set by tmpfs.
+ if btime := got.Btime.ToNsec(); btime != 0 {
+ t.Errorf("got btime %d, want 0", got.Btime.ToNsec())
+ }
+
+ // Size should be 0 (except for directories, which count 20 bytes
+ // per entry, including the "." and ".." entries present in
+ // otherwise-empty directories).
+ wantSize := uint64(0)
+ if typ == "dir" {
+ wantSize = 40
+ }
+ if got.Size != wantSize {
+ t.Errorf("got size %d, want %d", got.Size, wantSize)
+ }
+
+ // Nlink should be 1 for files, 2 for dirs.
+ wantNlink := uint32(1)
+ if typ == "dir" {
+ wantNlink = 2
+ }
+ if got.Nlink != wantNlink {
+ t.Errorf("got nlink %d, want %d", got.Nlink, wantNlink)
+ }
+
+ // UID and GID are set from context creds.
+ creds := auth.CredentialsFromContext(ctx)
+ if got.UID != uint32(creds.EffectiveKUID) {
+ t.Errorf("got uid %d, want %d", got.UID, uint32(creds.EffectiveKUID))
+ }
+ if got.GID != uint32(creds.EffectiveKGID) {
+ t.Errorf("got gid %d, want %d", got.GID, uint32(creds.EffectiveKGID))
+ }
+
+ // Mode.
+ wantMode := uint16(mode)
+ switch typ {
+ case "file":
+ wantMode |= linux.S_IFREG
+ case "dir":
+ wantMode |= linux.S_IFDIR
+ case "pipe":
+ wantMode |= linux.S_IFIFO
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+
+ if got.Mode != wantMode {
+ t.Errorf("got mode %x, want %x", got.Mode, wantMode)
+ }
+
+ // Ino.
+ if got.Ino == 0 {
+ t.Errorf("got ino %d, want not 0", got.Ino)
+ }
+ })
+ }
+}
+
+func TestSetStatAtime(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v", err)
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v", err)
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+}
+
+func TestSetStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v", err)
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v", err)
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+ })
+ }
+}
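The stat tests above rely on the statx mask contract: a field in the returned struct is only meaningful if the filesystem set the corresponding bit in Mask, which is why the tests request `linux.STATX_ALL`. A minimal sketch of that check, using the Linux ABI bit value for STATX_SIZE (the struct here is a simplified stand-in, not gVisor's type):

```go
package main

import "fmt"

// statxSize is the STATX_SIZE mask bit from include/uapi/linux/stat.h.
const statxSize = 0x00000200

// statxResult is a simplified stand-in for a statx(2) result.
type statxResult struct {
	Mask uint32
	Size uint64
}

// sizeIfValid only trusts Size when the filesystem set STATX_SIZE in Mask.
func sizeIfValid(s statxResult) (uint64, bool) {
	if s.Mask&statxSize != 0 {
		return s.Size, true
	}
	return 0, false
}

func main() {
	s := statxResult{Mask: statxSize, Size: 40}
	if size, ok := sizeIfValid(s); ok {
		fmt.Println("size:", size) // size: 40
	}
}
```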
diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
new file mode 100644
index 000000000..b0de5fabe
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -0,0 +1,37 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+type symlink struct {
+ inode inode
+ target string // immutable
+}
+
+func (fs *filesystem) newSymlink(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, target string) *inode {
+ link := &symlink{
+ target: target,
+ }
+ link.inode.init(link, fs, kuid, kgid, linux.S_IFLNK|mode)
+ link.inode.nlink = 1 // from parent directory
+ return &link.inode
+}
+
+// O_PATH is unimplemented, so there's no way to get a FileDescription
+// representing a symlink yet.
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
new file mode 100644
index 000000000..d7f4f0779
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -0,0 +1,787 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tmpfs provides an in-memory filesystem whose contents are
+// application-mutable, consistent with Linux's tmpfs.
+//
+// Lock order:
+//
+// filesystem.mu
+// inode.mu
+// regularFileFD.offMu
+// *** "memmap.Mappable locks" below this point
+// regularFile.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// regularFile.dataMu
+// directory.iterMu
+package tmpfs
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Name is the default filesystem name.
+const Name = "tmpfs"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // memFile is used to allocate pages for regular files.
+ memFile *pgalloc.MemoryFile
+
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock time.Clock
+
+ // devMinor is the filesystem's minor device number. devMinor is immutable.
+ devMinor uint32
+
+ // mu serializes changes to the Dentry tree.
+ mu sync.RWMutex
+
+ nextInoMinusOne uint64 // accessed using atomic memory operations
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// FilesystemOpts is used to pass configuration data to tmpfs.
+type FilesystemOpts struct {
+ // RootFileType is the FileType of the filesystem root. Valid values
+ // are: S_IFDIR, S_IFREG, and S_IFLNK. Defaults to S_IFDIR.
+ RootFileType uint16
+
+ // RootSymlinkTarget is the target of the root symlink. Only valid if
+ // RootFileType == S_IFLNK.
+ RootSymlinkTarget string
+
+ // FilesystemType allows setting a different FilesystemType for this
+ // tmpfs filesystem. This allows tmpfs to "impersonate" other
+ // filesystems, like ramdiskfs and cgroupfs.
+ FilesystemType vfs.FilesystemType
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
+ if memFileProvider == nil {
+ panic("MemoryFileProviderFromContext returned nil")
+ }
+
+ rootFileType := uint16(linux.S_IFDIR)
+ newFSType := vfs.FilesystemType(&fstype)
+ tmpfsOpts, ok := opts.InternalData.(FilesystemOpts)
+ if ok {
+ if tmpfsOpts.RootFileType != 0 {
+ rootFileType = tmpfsOpts.RootFileType
+ }
+ if tmpfsOpts.FilesystemType != nil {
+ newFSType = tmpfsOpts.FilesystemType
+ }
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ rootMode := linux.FileMode(0777)
+ if rootFileType == linux.S_IFDIR {
+ rootMode = 01777
+ }
+ modeStr, ok := mopts["mode"]
+ if ok {
+ delete(mopts, "mode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode & 07777)
+ }
+ rootKUID := creds.EffectiveKUID
+ uidStr, ok := mopts["uid"]
+ if ok {
+ delete(mopts, "uid")
+ uid, err := strconv.ParseUint(uidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
+ if !kuid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKUID = kuid
+ }
+ rootKGID := creds.EffectiveKGID
+ gidStr, ok := mopts["gid"]
+ if ok {
+ delete(mopts, "gid")
+ gid, err := strconv.ParseUint(gidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
+ if !kgid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKGID = kgid
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ clock := time.RealtimeClockFromContext(ctx)
+ fs := filesystem{
+ memFile: memFileProvider.MemoryFile(),
+ clock: clock,
+ devMinor: devMinor,
+ }
+ fs.vfsfs.Init(vfsObj, newFSType, &fs)
+
+ var root *dentry
+ switch rootFileType {
+ case linux.S_IFREG:
+ root = fs.newDentry(fs.newRegularFile(rootKUID, rootKGID, rootMode))
+ case linux.S_IFLNK:
+ root = fs.newDentry(fs.newSymlink(rootKUID, rootKGID, rootMode, tmpfsOpts.RootSymlinkTarget))
+ case linux.S_IFDIR:
+ root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry
+ default:
+ fs.vfsfs.DecRef()
+ return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
+ }
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// NewFilesystem returns a new tmpfs filesystem.
+func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*vfs.Filesystem, *vfs.Dentry, error) {
+ return FilesystemType{}.GetFilesystem(ctx, vfsObj, creds, "", vfs.GetFilesystemOptions{})
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ // parent is this dentry's parent directory. Each referenced dentry holds a
+ // reference on parent.dentry. If this dentry is a filesystem root, parent
+ // is nil. parent is protected by filesystem.mu.
+ parent *dentry
+
+ // name is the name of this dentry in its parent. If this dentry is a
+ // filesystem root, name is the empty string. name is protected by
+ // filesystem.mu.
+ name string
+
+ // dentryEntry (ugh) links dentries into their parent directory.childList.
+ dentryEntry
+
+ // inode is the inode represented by this dentry. Multiple Dentries may
+ // share a single non-directory inode (with hard links). inode is
+ // immutable.
+ //
+ // tmpfs doesn't count references on dentries; because the dentry tree is
+ // the sole source of truth, it is by definition always consistent with the
+ // state of the filesystem. However, it does count references on inodes,
+ // because inode resources are released when all references are dropped.
+ // dentry therefore forwards reference counting directly to inode.
+ inode *inode
+}
+
+func (fs *filesystem) newDentry(inode *inode) *dentry {
+ d := &dentry{
+ inode: inode,
+ }
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ d.inode.incRef()
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ return d.inode.tryIncRef()
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef() {
+ d.inode.decRef()
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+ if d.inode.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ // tmpfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates
+ // that d was deleted.
+ deleted := d.vfsd.IsDead()
+
+ d.inode.fs.mu.RLock()
+ // The ordering below is important; Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted)
+ }
+ d.inode.watches.Notify("", events, cookie, et, deleted)
+ d.inode.fs.mu.RUnlock()
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ return &d.inode.watches
+}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+func (d *dentry) OnZeroWatches() {}
+
+// inode represents a filesystem object.
+type inode struct {
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // refs is a reference count. refs is accessed using atomic memory
+ // operations.
+ //
+ // A reference is held on all inodes as long as they are reachable in the
+ // filesystem tree, i.e. nlink is nonzero. This reference is dropped when
+ // nlink reaches 0.
+ refs int64
+
+ // xattrs implements extended attributes.
+ //
+ // TODO(b/148380782): Support xattrs other than user.*
+ xattrs memxattr.SimpleExtendedAttributes
+
+ // Inode metadata. Writing multiple fields atomically requires holding
+ // mu, otherwise atomic operations can be used.
+ mu sync.Mutex
+ mode uint32 // file type and mode
+ nlink uint32 // protected by filesystem.mu instead of inode.mu
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but stored as raw uint32 for sync/atomic
+ ino uint64 // immutable
+
+ // Linux's tmpfs has no concept of btime.
+ atime int64 // nanoseconds
+ ctime int64 // nanoseconds
+ mtime int64 // nanoseconds
+
+ locks vfs.FileLocks
+
+ // Inotify watches for this inode.
+ watches vfs.Watches
+
+ impl interface{} // immutable
+}
+
+const maxLinks = math.MaxUint32
+
+func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic("file type is required in FileMode")
+ }
+ i.fs = fs
+ i.refs = 1
+ i.mode = uint32(mode)
+ i.uid = uint32(kuid)
+ i.gid = uint32(kgid)
+ i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
+ // Tmpfs creation sets atime, ctime, and mtime to current time.
+ now := fs.clock.Now().Nanoseconds()
+ i.atime = now
+ i.ctime = now
+ i.mtime = now
+ // i.nlink initialized by caller
+ i.impl = impl
+}
+
+// incLinksLocked increments i's link count.
+//
+// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
+// i.nlink < maxLinks.
+func (i *inode) incLinksLocked() {
+ if i.nlink == 0 {
+ panic("tmpfs.inode.incLinksLocked() called with no existing links")
+ }
+ if i.nlink == maxLinks {
+ panic("tmpfs.inode.incLinksLocked() called with maximum link count")
+ }
+ atomic.AddUint32(&i.nlink, 1)
+}
+
+// decLinksLocked decrements i's link count. If the link count reaches 0, we
+// remove a reference on i as well.
+//
+// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
+func (i *inode) decLinksLocked() {
+ if i.nlink == 0 {
+ panic("tmpfs.inode.decLinksLocked() called with no existing links")
+ }
+ if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 {
+ i.decRef()
+ }
+}
+
+func (i *inode) incRef() {
+ if atomic.AddInt64(&i.refs, 1) <= 1 {
+ panic("tmpfs.inode.incRef() called without holding a reference")
+ }
+}
+
+func (i *inode) tryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&i.refs)
+ if refs == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&i.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+func (i *inode) decRef() {
+ if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
+ i.watches.HandleDeletion()
+ if regFile, ok := i.impl.(*regularFile); ok {
+ // Release memory used by regFile to store data. Since regFile is
+ // no longer usable, we don't need to grab any locks or update any
+ // metadata.
+ regFile.data.DropAll(regFile.memFile)
+ }
+ } else if refs < 0 {
+ panic("tmpfs.inode.decRef() called without holding a reference")
+ }
+}
+
+func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ return vfs.GenericCheckPermissions(creds, ats, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
+}
+
+// Go won't inline this function, and returning linux.Statx (which is quite
+// big) means spending a lot of time in runtime.duffcopy(), so instead it's an
+// output parameter.
+//
+// Note that Linux does not guarantee to return consistent data (in the case of
+// a concurrent modification), so we do not require holding inode.mu.
+func (i *inode) statTo(stat *linux.Statx) {
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+ linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE |
+ linux.STATX_BLOCKS | linux.STATX_ATIME | linux.STATX_CTIME |
+ linux.STATX_MTIME
+ stat.Blksize = usermem.PageSize
+ stat.Nlink = atomic.LoadUint32(&i.nlink)
+ stat.UID = atomic.LoadUint32(&i.uid)
+ stat.GID = atomic.LoadUint32(&i.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&i.mode))
+ stat.Ino = i.ino
+ stat.Atime = linux.NsecToStatxTimestamp(i.atime)
+ stat.Ctime = linux.NsecToStatxTimestamp(i.ctime)
+ stat.Mtime = linux.NsecToStatxTimestamp(i.mtime)
+ stat.DevMajor = linux.UNNAMED_MAJOR
+ stat.DevMinor = i.fs.devMinor
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
+ stat.Size = uint64(atomic.LoadUint64(&impl.size))
+ // TODO(jamieliu): This should be impl.data.Span() / 512, but this is
+ // too expensive to compute here. Cache it in regularFile.
+ stat.Blocks = allocatedBlocksForSize(stat.Size)
+ case *directory:
+ // "20" is mm/shmem.c:BOGO_DIRENT_SIZE.
+ stat.Size = 20 * (2 + uint64(atomic.LoadInt64(&impl.numChildren)))
+ // stat.Blocks is 0.
+ case *symlink:
+ stat.Size = uint64(len(impl.target))
+ // stat.Blocks is 0.
+ case *namedPipe, *socketFile:
+ // stat.Size and stat.Blocks are 0.
+ case *deviceFile:
+ // stat.Size and stat.Blocks are 0.
+ stat.RdevMajor = impl.major
+ stat.RdevMinor = impl.minor
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ }
+}
+
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
+ if stat.Mask == 0 {
+ return nil
+ }
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ var (
+ needsMtimeBump bool
+ needsCtimeBump bool
+ )
+ mask := stat.Mask
+ if mask&linux.STATX_MODE != 0 {
+ ft := atomic.LoadUint32(&i.mode) & linux.S_IFMT
+ atomic.StoreUint32(&i.mode, ft|uint32(stat.Mode&^linux.S_IFMT))
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&i.uid, stat.UID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&i.gid, stat.GID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ updated, err := impl.truncateLocked(stat.Size)
+ if err != nil {
+ return err
+ }
+ if updated {
+ needsMtimeBump = true
+ needsCtimeBump = true
+ }
+ case *directory:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
+ now := i.fs.clock.Now().Nanoseconds()
+ if mask&linux.STATX_ATIME != 0 {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.atime, now)
+ } else {
+ atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ }
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_MTIME != 0 {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.mtime, now)
+ } else {
+ atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ }
+ needsCtimeBump = true
+ // Ignore the mtime bump, since we just set it ourselves.
+ needsMtimeBump = false
+ }
+ if mask&linux.STATX_CTIME != 0 {
+ if stat.Ctime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.ctime, now)
+ } else {
+ atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ }
+ // Ignore the ctime bump, since we just set it ourselves.
+ needsCtimeBump = false
+ }
+ if needsMtimeBump {
+ atomic.StoreInt64(&i.mtime, now)
+ }
+ if needsCtimeBump {
+ atomic.StoreInt64(&i.ctime, now)
+ }
+
+ return nil
+}
+
+// allocatedBlocksForSize returns the number of 512B blocks needed to
+// accommodate the given size in bytes, as appropriate for struct
+// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
+// size is independent of the "preferred block size for I/O", struct
+// stat::st_blksize and struct statx::stx_blksize.)
+func allocatedBlocksForSize(size uint64) uint64 {
+ return (size + 511) / 512
+}
+
+func (i *inode) direntType() uint8 {
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ return linux.DT_REG
+ case *directory:
+ return linux.DT_DIR
+ case *symlink:
+ return linux.DT_LNK
+ case *socketFile:
+ return linux.DT_SOCK
+ case *deviceFile:
+ switch impl.kind {
+ case vfs.BlockDevice:
+ return linux.DT_BLK
+ case vfs.CharDevice:
+ return linux.DT_CHR
+ default:
+ panic(fmt.Sprintf("unknown vfs.DeviceKind: %v", impl.kind))
+ }
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ }
+}
+
+func (i *inode) isDir() bool {
+ return linux.FileMode(i.mode).FileType() == linux.S_IFDIR
+}
+
+func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now := i.fs.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.atime, now)
+ i.mu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCtime() {
+ now := i.fs.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCMtime() {
+ now := i.fs.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds
+// inode.mu.
+func (i *inode) touchCMtimeLocked() {
+ now := i.fs.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+}
+
+func (i *inode) listxattr(size uint64) ([]string, error) {
+ return i.xattrs.Listxattr(size)
+}
+
+func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if err := i.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
+ return i.xattrs.Getxattr(opts)
+}
+
+func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Setxattr(opts)
+}
+
+func (i *inode) removexattr(creds *auth.Credentials, name string) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Removexattr(name)
+}
+
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
+func (i *inode) userXattrSupported() bool {
+ filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode)
+ return filetype == linux.S_IFREG || filetype == linux.S_IFDIR
+}
+
+// fileDescription is embedded by tmpfs implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+func (fd *fileDescription) inode() *inode {
+ return fd.dentry().inode
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ fd.inode().statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ d := fd.dentry()
+ if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil {
+ return err
+ }
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.inode().listxattr(size)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ d := fd.dentry()
+ if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil {
+ return err
+ }
+
+ // Generate inotify events.
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ d := fd.dentry()
+ if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil {
+ return err
+ }
+
+ // Generate inotify events.
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// NewMemfd creates a new tmpfs regular file and file description that can back
+// an anonymous fd created by memfd_create.
+func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name string) (*vfs.FileDescription, error) {
+ fs, ok := mount.Filesystem().Impl().(*filesystem)
+ if !ok {
+ panic("NewMemfd() called with non-tmpfs mount")
+ }
+
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
+ // S_IRWXUGO.
+ inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777)
+ rf := inode.impl.(*regularFile)
+ if allowSeals {
+ rf.seals = 0
+ }
+
+ d := fs.newDentry(inode)
+ defer d.DecRef()
+ d.name = name
+
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
+ // FMODE_READ | FMODE_WRITE.
+ var fd regularFileFD
+ fd.Init(&inode.locks)
+ flags := uint32(linux.O_RDWR)
+ if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all
+// filesystem state is in-memory.
+func (*fileDescription) Sync(context.Context) error {
+ return nil
+}
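allocatedBlocksForSize above is the st_blocks arithmetic that statTo uses for regular files: 512-byte units, rounded up, independent of the 4 KiB Blksize reported to applications. A standalone sketch of the rounding behavior:

```go
package main

import "fmt"

// allocatedBlocksForSize, as in the patch: the number of 512-byte blocks
// needed to hold size bytes, rounded up.
func allocatedBlocksForSize(size uint64) uint64 {
	return (size + 511) / 512
}

func main() {
	for _, size := range []uint64{0, 1, 511, 512, 513, 4096} {
		fmt.Printf("size=%4d -> st_blocks=%d\n", size, allocatedBlocksForSize(size))
	}
	// size=0 -> 0, 1 -> 1, 511 -> 1, 512 -> 1, 513 -> 2, 4096 -> 8
}
```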
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
new file mode 100644
index 000000000..a240fb276
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -0,0 +1,156 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// nextFileID is used to generate unique file names.
+var nextFileID int64
+
+// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error
+// is nil, then cleanup should be called when the root is no longer needed.
+func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ }
+ root := mntns.Root()
+ return vfsObj, root, func() {
+ root.DecRef()
+ mntns.DecRef()
+ }, nil
+}
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned error is nil, then cleanup should be called when the FD is no
+// longer needed.
+func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the file that will be written to and read from.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newDirFD is like newFileFD, but for directories.
+func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the dir.
+ if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.MkdirOptions{
+ Mode: linux.ModeDirectory | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err)
+ }
+
+ // Open the dir and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newPipeFD is like newFileFD, but for pipes.
+func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ name := fmt.Sprintf("tmpfs-test-%d", atomic.AddInt64(&nextFileID, 1))
+
+ if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }, &vfs.MknodOptions{
+ Mode: linux.ModeNamedPipe | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create pipe %q: %v", name, err)
+ }
+
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open pipe %q: %v", name, err)
+ }
+
+ return fd, cleanup, nil
+}
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
new file mode 100644
index 000000000..e6933aa70
--- /dev/null
+++ b/pkg/sentry/hostcpu/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hostcpu",
+ srcs = [
+ "getcpu_amd64.s",
+ "getcpu_arm64.s",
+ "hostcpu.go",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "hostcpu_test",
+ size = "small",
+ srcs = ["hostcpu_test.go"],
+ library = ":hostcpu",
+)
diff --git a/pkg/sentry/hostcpu/getcpu_amd64.s b/pkg/sentry/hostcpu/getcpu_amd64.s
new file mode 100644
index 000000000..aa00316da
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_amd64.s
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ BYTE $0x0f; BYTE $0x01; BYTE $0xf9; // RDTSCP
+ // On Linux, the bottom 12 bits of IA32_TSC_AUX are CPU and the upper 20
+ // are node. See arch/x86/entry/vdso/vma.c:vgetcpu_cpu_init().
+ ANDL $0xfff, CX
+ MOVL CX, cpu+0(FP)
+ RET
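The RDTSCP comment above encodes Linux's IA32_TSC_AUX layout: the low 12 bits carry the CPU number and the remaining upper bits the NUMA node. A hedged Go sketch of that decoding (decodeTSCAux is a hypothetical name, not part of the patch):

```go
package main

import "fmt"

// decodeTSCAux splits an IA32_TSC_AUX value as Linux encodes it for RDTSCP
// (see arch/x86/entry/vdso/vma.c:vgetcpu_cpu_init()): the low 12 bits are
// the CPU number and the upper bits are the node.
func decodeTSCAux(aux uint32) (cpu, node uint32) {
	return aux & 0xfff, aux >> 12
}

func main() {
	cpu, node := decodeTSCAux(0x1007) // hypothetical raw value
	fmt.Printf("cpu=%d node=%d\n", cpu, node) // cpu=7 node=1
}
```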
diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s
new file mode 100644
index 000000000..caf9abb89
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_arm64.s
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall
+// because arm64 lacks an optimized way of getting the current CPU number.
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ MOVW ZR, cpu+0(FP)
+ MOVD $cpu+0(FP), R0
+ MOVD $0x0, R1 // unused
+ MOVD $0x0, R2 // unused
+ MOVD $0xA8, R8 // SYS_GETCPU
+ SVC
+ RET
diff --git a/pkg/sentry/hostcpu/hostcpu.go b/pkg/sentry/hostcpu/hostcpu.go
new file mode 100644
index 000000000..d78f78402
--- /dev/null
+++ b/pkg/sentry/hostcpu/hostcpu.go
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostcpu provides utilities for working with CPU information provided
+// by a host Linux kernel.
+package hostcpu
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+ "unicode"
+)
+
+// GetCPU returns the caller's current CPU number, without using the Linux VDSO
+// (which is not available to the sentry) or the getcpu(2) system call (which
+// is relatively slow).
+func GetCPU() uint32
+
+// MaxPossibleCPU returns the highest possible CPU number, which is guaranteed
+// not to change for the lifetime of the host kernel.
+func MaxPossibleCPU() (uint32, error) {
+ const path = "/sys/devices/system/cpu/possible"
+ data, err := ioutil.ReadFile(path)
+ if err != nil {
+ return 0, err
+ }
+ str := string(data)
+ // Linux: drivers/base/cpu.c:show_cpus_attr() =>
+ // include/linux/cpumask.h:cpumask_print_to_pagebuf() =>
+ // lib/bitmap.c:bitmap_print_to_pagebuf()
+ i, err := maxValueInLinuxBitmap(str)
+ if err != nil {
+ return 0, fmt.Errorf("invalid %s (%q): %v", path, str, err)
+ }
+ return uint32(i), nil
+}
+
+// maxValueInLinuxBitmap returns the maximum value specified in str, which is a
+// string emitted by Linux's lib/bitmap.c:bitmap_print_to_pagebuf(list=true).
+func maxValueInLinuxBitmap(str string) (uint64, error) {
+ str = strings.TrimSpace(str)
+ // Find the last decimal number in str.
+ idx := strings.LastIndexFunc(str, func(c rune) bool {
+ return !unicode.IsDigit(c)
+ })
+ if idx != -1 {
+ str = str[idx+1:]
+ }
+ i, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ return 0, err
+ }
+ return i, nil
+}
diff --git a/pkg/sentry/hostcpu/hostcpu_test.go b/pkg/sentry/hostcpu/hostcpu_test.go
new file mode 100644
index 000000000..7d6885c9e
--- /dev/null
+++ b/pkg/sentry/hostcpu/hostcpu_test.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostcpu
+
+import (
+ "fmt"
+ "testing"
+)
+
+func TestMaxValueInLinuxBitmap(t *testing.T) {
+ for _, test := range []struct {
+ str string
+ max uint64
+ }{
+ {"0", 0},
+ {"0\n", 0},
+ {"0,2", 2},
+ {"0-63", 63},
+ {"0-3,8-11", 11},
+ } {
+ t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) {
+ max, err := maxValueInLinuxBitmap(test.str)
+ if err != nil || max != test.max {
+ t.Errorf("maxValueInLinuxBitmap: got (%d, %v), wanted (%d, nil)", max, err, test.max)
+ }
+ })
+ }
+}
+
+func TestMaxValueInLinuxBitmapErrors(t *testing.T) {
+ for _, str := range []string{"", "\n"} {
+ t.Run(fmt.Sprintf("%q", str), func(t *testing.T) {
+ max, err := maxValueInLinuxBitmap(str)
+ if err == nil {
+ t.Errorf("maxValueInLinuxBitmap: got (%d, nil), wanted (_, error)", max)
+ }
+ t.Log(err)
+ })
+ }
+}
diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD
new file mode 100644
index 000000000..364a78306
--- /dev/null
+++ b/pkg/sentry/hostfd/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "hostfd",
+ srcs = [
+ "hostfd.go",
+ "hostfd_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/safemem",
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/hostfd/hostfd.go b/pkg/sentry/hostfd/hostfd.go
new file mode 100644
index 000000000..70dd9cafb
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostfd provides efficient I/O with host file descriptors.
+package hostfd
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// ReadWriterAt implements safemem.Reader and safemem.Writer by reading from
+// and writing to a host file descriptor respectively. ReadWriterAts should be
+// obtained by calling GetReadWriterAt.
+//
+// Clients should usually prefer to use Preadv2 and Pwritev2 directly.
+type ReadWriterAt struct {
+ fd int32
+ offset int64
+ flags uint32
+}
+
+var rwpool = sync.Pool{
+ New: func() interface{} {
+ return &ReadWriterAt{}
+ },
+}
+
+// GetReadWriterAt returns a ReadWriterAt that reads from / writes to the given
+// host file descriptor, starting at the given offset and using the given
+// preadv2(2)/pwritev2(2) flags. If offset is -1, the host file descriptor's
+// offset is used instead. Users are responsible for ensuring that fd remains
+// valid for the lifetime of the returned ReadWriterAt, and must call
+// PutReadWriterAt when it is no longer needed.
+func GetReadWriterAt(fd int32, offset int64, flags uint32) *ReadWriterAt {
+ rw := rwpool.Get().(*ReadWriterAt)
+ *rw = ReadWriterAt{
+ fd: fd,
+ offset: offset,
+ flags: flags,
+ }
+ return rw
+}
+
+// PutReadWriterAt releases a ReadWriterAt returned by a previous call to
+// GetReadWriterAt that is no longer in use.
+func PutReadWriterAt(rw *ReadWriterAt) {
+ rwpool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *ReadWriterAt) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ n, err := Preadv2(rw.fd, dsts, rw.offset, rw.flags)
+ if rw.offset >= 0 {
+ rw.offset += int64(n)
+ }
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *ReadWriterAt) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ n, err := Pwritev2(rw.fd, srcs, rw.offset, rw.flags)
+ if rw.offset >= 0 {
+ rw.offset += int64(n)
+ }
+ return n, err
+}
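A minimal usage sketch for the pooled ReadWriterAt above, assuming gVisor's safemem helpers BlockFromSafeSlice and BlockSeqOf; readIntoBuf and the package name are hypothetical, not part of the patch:

```go
package hostfdexample

import (
	"gvisor.dev/gvisor/pkg/safemem"
	"gvisor.dev/gvisor/pkg/sentry/hostfd"
)

// readIntoBuf reads from host fd at offset into buf via a pooled
// ReadWriterAt, returning it to the pool when done.
func readIntoBuf(fd int32, offset int64, buf []byte) (uint64, error) {
	rw := hostfd.GetReadWriterAt(fd, offset, 0 /* flags */)
	defer hostfd.PutReadWriterAt(rw)
	dsts := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf))
	return rw.ReadToBlocks(dsts)
}
```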
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
new file mode 100644
index 000000000..cd4dc67fb
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -0,0 +1,85 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostfd
+
+import (
+ "io"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// Preadv2 reads up to dsts.NumBytes() bytes from host file descriptor fd into
+// dsts. offset and flags are interpreted as for preadv2(2).
+//
+// Preconditions: !dsts.IsEmpty().
+func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ var (
+ n uintptr
+ e syscall.Errno
+ )
+ if flags == 0 && dsts.NumBlocks() == 1 {
+ // Use read() or pread() to avoid iovec allocation and copying.
+ dst := dsts.Head()
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_READ, uintptr(fd), dst.Addr(), uintptr(dst.Len()))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PREAD64, uintptr(fd), dst.Addr(), uintptr(dst.Len()), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
+ }
+ if e != 0 {
+ return 0, e
+ }
+ if n == 0 {
+ return 0, io.EOF
+ }
+ return uint64(n), nil
+}
+
+// Pwritev2 writes up to srcs.NumBytes() from srcs into host file descriptor
+// fd. offset and flags are interpreted as for pwritev2(2).
+//
+// Preconditions: !srcs.IsEmpty().
+func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ var (
+ n uintptr
+ e syscall.Errno
+ )
+ if flags == 0 && srcs.NumBlocks() == 1 {
+ // Use write() or pwrite() to avoid iovec allocation and copying.
+ src := srcs.Head()
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_WRITE, uintptr(fd), src.Addr(), uintptr(src.Len()))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITE64, uintptr(fd), src.Addr(), uintptr(src.Len()), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
+ }
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+}
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
new file mode 100644
index 000000000..61c78569d
--- /dev/null
+++ b/pkg/sentry/hostmm/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hostmm",
+ srcs = [
+ "cgroup.go",
+ "hostmm.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/hostmm/cgroup.go b/pkg/sentry/hostmm/cgroup.go
new file mode 100644
index 000000000..e5cc26ab2
--- /dev/null
+++ b/pkg/sentry/hostmm/cgroup.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostmm
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "path"
+ "strings"
+)
+
+// currentCgroupDirectory returns the directory for the cgroup for the given
+// controller in which the calling process resides.
+func currentCgroupDirectory(ctrl string) (string, error) {
+ root, err := cgroupRootDirectory(ctrl)
+ if err != nil {
+ return "", err
+ }
+ cg, err := currentCgroup(ctrl)
+ if err != nil {
+ return "", err
+ }
+ return path.Join(root, cg), nil
+}
+
+// cgroupRootDirectory returns the root directory for the cgroup hierarchy in
+// which the given cgroup controller is mounted in the calling process' mount
+// namespace.
+func cgroupRootDirectory(ctrl string) (string, error) {
+ const path = "/proc/self/mounts"
+ file, err := os.Open(path)
+ if err != nil {
+ return "", err
+ }
+ defer file.Close()
+
+ // Per proc(5) -> fstab(5):
+ // Each line of /proc/self/mounts describes a mount.
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ // Each line consists of 6 space-separated fields. Find the line for
+ // which the third field (fs_vfstype) is cgroup, and the fourth field
+ // (fs_mntops, a comma-separated list of mount options) contains
+ // ctrl.
+ var spec, file, vfstype, mntopts, freq, passno string
+ const nrfields = 6
+ line := scanner.Text()
+ n, err := fmt.Sscan(line, &spec, &file, &vfstype, &mntopts, &freq, &passno)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse %s: %v", path, err)
+ }
+ if n != nrfields {
+ return "", fmt.Errorf("failed to parse %s: line %q: got %d fields, wanted %d", path, line, n, nrfields)
+ }
+ if vfstype != "cgroup" {
+ continue
+ }
+ for _, mntopt := range strings.Split(mntopts, ",") {
+ if mntopt == ctrl {
+ return file, nil
+ }
+ }
+ }
+ return "", fmt.Errorf("no cgroup hierarchy mounted for controller %s", ctrl)
+}
+
+// currentCgroup returns the cgroup for the given controller in which the
+// calling process resides. The returned string is a path that should be
+// interpreted as relative to cgroupRootDirectory(ctrl).
+func currentCgroup(ctrl string) (string, error) {
+ const path = "/proc/self/cgroup"
+ file, err := os.Open(path)
+ if err != nil {
+ return "", err
+ }
+ defer file.Close()
+
+ // Per proc(5) -> cgroups(7):
+	// Each line of /proc/self/cgroup describes a cgroup hierarchy.
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ // Each line consists of 3 colon-separated fields. Find the line for
+ // which the second field (controller-list, a comma-separated list of
+ // cgroup controllers) contains ctrl.
+ line := scanner.Text()
+ const nrfields = 3
+ fields := strings.Split(line, ":")
+ if len(fields) != nrfields {
+ return "", fmt.Errorf("failed to parse %s: line %q: got %d fields, wanted %d", path, line, len(fields), nrfields)
+ }
+ for _, controller := range strings.Split(fields[1], ",") {
+ if controller == ctrl {
+ return fields[2], nil
+ }
+ }
+ }
+ return "", fmt.Errorf("not a member of a cgroup hierarchy for controller %s", ctrl)
+}
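+
+// Example (illustrative, derived from the parsing above): given the
+// /proc/self/cgroup line
+//
+//	4:memory:/user.slice
+//
+// currentCgroup("memory") returns "/user.slice"; if the memory controller is
+// mounted at /sys/fs/cgroup/memory, currentCgroupDirectory("memory") returns
+// "/sys/fs/cgroup/memory/user.slice".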
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
new file mode 100644
index 000000000..506c7864a
--- /dev/null
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostmm provides tools for interacting with the host Linux kernel's
+// virtual memory management subsystem.
+package hostmm
+
+import (
+ "fmt"
+ "os"
+ "path"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// NotifyCurrentMemcgPressureCallback requests that f is called whenever the
+// calling process' memory cgroup indicates memory pressure of the given level,
+// as specified by Linux's Documentation/cgroup-v1/memory.txt.
+//
+// If NotifyCurrentMemcgPressureCallback succeeds, it returns a function that
+// terminates the requested memory pressure notifications. This function may be
+// called at most once.
+func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) {
+ cgdir, err := currentCgroupDirectory("memory")
+ if err != nil {
+ return nil, err
+ }
+
+ pressurePath := path.Join(cgdir, "memory.pressure_level")
+ pressureFile, err := os.Open(pressurePath)
+ if err != nil {
+ return nil, err
+ }
+ defer pressureFile.Close()
+
+ eventControlPath := path.Join(cgdir, "cgroup.event_control")
+ eventControlFile, err := os.OpenFile(eventControlPath, os.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer eventControlFile.Close()
+
+ eventFD, err := newEventFD()
+ if err != nil {
+ return nil, err
+ }
+
+ // Don't use fmt.Fprintf since the whole string needs to be written in a
+ // single syscall.
+ eventControlStr := fmt.Sprintf("%d %d %s", eventFD.FD(), pressureFile.Fd(), level)
+ if n, err := eventControlFile.Write([]byte(eventControlStr)); n != len(eventControlStr) || err != nil {
+ eventFD.Close()
+ return nil, fmt.Errorf("error writing %q to %s: got (%d, %v), wanted (%d, nil)", eventControlStr, eventControlPath, n, err, len(eventControlStr))
+ }
+
+ log.Debugf("Receiving memory pressure level notifications from %s at level %q", pressurePath, level)
+ const sizeofUint64 = 8
+ // The most significant bit of the eventfd value is set by the stop
+ // function, which is practically unambiguous since it's not plausible for
+ // 2**63 pressure events to occur between eventfd reads.
+ const stopVal = 1 << 63
+ stopCh := make(chan struct{})
+ go func() { // S/R-SAFE: f provides synchronization if necessary
+ rw := fd.NewReadWriter(eventFD.FD())
+ var buf [sizeofUint64]byte
+ for {
+ n, err := rw.Read(buf[:])
+ if err != nil {
+ if err == syscall.EINTR {
+ continue
+ }
+ panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err))
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ val := usermem.ByteOrder.Uint64(buf[:])
+ if val >= stopVal {
+ // Assume this was due to the notifier's "destructor" (the
+ // function returned by NotifyCurrentMemcgPressureCallback
+ // below) being called.
+ eventFD.Close()
+ close(stopCh)
+ return
+ }
+ f()
+ }
+ }()
+ return func() {
+ rw := fd.NewReadWriter(eventFD.FD())
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], stopVal)
+ for {
+ n, err := rw.Write(buf[:])
+ if err != nil {
+ if err == syscall.EINTR {
+ continue
+ }
+ panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err))
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ break
+ }
+ <-stopCh
+ }, nil
+}
+
+func newEventFD() (*fd.FD, error) {
+ f, _, e := syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0)
+ if e != 0 {
+ return nil, fmt.Errorf("failed to create eventfd: %v", e)
+ }
+ return fd.New(int(f)), nil
+}
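+
+// Example (hypothetical, for illustration only): running a reclaim callback
+// on "medium" host memory pressure until the notifier is stopped.
+//
+//	stop, err := NotifyCurrentMemcgPressureCallback(func() {
+//		// Release reclaimable memory here.
+//	}, "medium")
+//	if err != nil {
+//		return err
+//	}
+//	defer stop()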
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
new file mode 100644
index 000000000..07bf39fed
--- /dev/null
+++ b/pkg/sentry/inet/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "inet",
+ srcs = [
+ "context.go",
+ "inet.go",
+ "namespace.go",
+ "test_stack.go",
+ ],
+ deps = [
+ "//pkg/context",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/sentry/inet/context.go b/pkg/sentry/inet/context.go
new file mode 100644
index 000000000..e8cc1bffd
--- /dev/null
+++ b/pkg/sentry/inet/context.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the inet package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxStack is a Context.Value key for a network stack.
+ CtxStack contextID = iota
+)
+
+// StackFromContext returns the network stack associated with ctx.
+func StackFromContext(ctx context.Context) Stack {
+ if v := ctx.Value(CtxStack); v != nil {
+ return v.(Stack)
+ }
+ return nil
+}
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
new file mode 100644
index 000000000..2916a0644
--- /dev/null
+++ b/pkg/sentry/inet/inet.go
@@ -0,0 +1,191 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package inet defines semantics for IP stacks.
+package inet
+
+import "gvisor.dev/gvisor/pkg/tcpip/stack"
+
+// Stack represents a TCP/IP stack.
+type Stack interface {
+ // Interfaces returns all network interfaces as a mapping from interface
+ // indexes to interface properties. Interface indices are strictly positive
+ // integers.
+ Interfaces() map[int32]Interface
+
+ // InterfaceAddrs returns all network interface addresses as a mapping from
+ // interface indexes to a slice of associated interface address properties.
+ InterfaceAddrs() map[int32][]InterfaceAddr
+
+ // AddInterfaceAddr adds an address to the network interface identified by
+ // index.
+ AddInterfaceAddr(idx int32, addr InterfaceAddr) error
+
+ // SupportsIPv6 returns true if the stack supports IPv6 connectivity.
+ SupportsIPv6() bool
+
+ // TCPReceiveBufferSize returns TCP receive buffer size settings.
+ TCPReceiveBufferSize() (TCPBufferSize, error)
+
+ // SetTCPReceiveBufferSize attempts to change TCP receive buffer size
+ // settings.
+ SetTCPReceiveBufferSize(size TCPBufferSize) error
+
+ // TCPSendBufferSize returns TCP send buffer size settings.
+ TCPSendBufferSize() (TCPBufferSize, error)
+
+ // SetTCPSendBufferSize attempts to change TCP send buffer size settings.
+ SetTCPSendBufferSize(size TCPBufferSize) error
+
+ // TCPSACKEnabled returns true if RFC 2018 TCP Selective Acknowledgements
+ // are enabled.
+ TCPSACKEnabled() (bool, error)
+
+ // SetTCPSACKEnabled attempts to change TCP selective acknowledgement
+ // settings.
+ SetTCPSACKEnabled(enabled bool) error
+
+ // Statistics reports stack statistics.
+ Statistics(stat interface{}, arg string) error
+
+ // RouteTable returns the network stack's route table.
+ RouteTable() []Route
+
+ // Resume restarts the network stack after restore.
+ Resume()
+
+ // RegisteredEndpoints returns all endpoints which are currently registered.
+ RegisteredEndpoints() []stack.TransportEndpoint
+
+ // CleanupEndpoints returns endpoints currently in the cleanup state.
+ CleanupEndpoints() []stack.TransportEndpoint
+
+ // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+ // for restoring a stack after a save.
+ RestoreCleanupEndpoints([]stack.TransportEndpoint)
+}
+
+// Interface contains information about a network interface.
+type Interface struct {
+ // DeviceType is the device type, a Linux ARPHRD_* constant.
+ DeviceType uint16
+
+ // Flags is the device flags; see netdevice(7), under "Ioctls",
+ // "SIOCGIFFLAGS, SIOCSIFFLAGS".
+ Flags uint32
+
+ // Name is the device name.
+ Name string
+
+ // Addr is the hardware device address.
+ Addr []byte
+
+ // MTU is the maximum transmission unit.
+ MTU uint32
+}
+
+// InterfaceAddr contains information about a network interface address.
+type InterfaceAddr struct {
+ // Family is the address family, a Linux AF_* constant.
+ Family uint8
+
+ // PrefixLen is the address prefix length.
+ PrefixLen uint8
+
+ // Flags is the address flags.
+ Flags uint8
+
+ // Addr is the actual address.
+ Addr []byte
+}
+
+// TCPBufferSize contains settings controlling TCP buffer sizing.
+//
+// +stateify savable
+type TCPBufferSize struct {
+ // Min is the minimum size.
+ Min int
+
+ // Default is the default size.
+ Default int
+
+ // Max is the maximum size.
+ Max int
+}
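+
+// Example (illustrative): Linux exposes the corresponding settings as three
+// integers in e.g. /proc/sys/net/ipv4/tcp_rmem; a line "4096 87380 6291456"
+// corresponds to TCPBufferSize{Min: 4096, Default: 87380, Max: 6291456}.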
+
+// StatDev describes one line of /proc/net/dev, i.e., stats for one network
+// interface.
+type StatDev [16]uint64
+
+// Route contains information about a network route.
+type Route struct {
+ // Family is the address family, a Linux AF_* constant.
+ Family uint8
+
+ // DstLen is the length of the destination address.
+ DstLen uint8
+
+ // SrcLen is the length of the source address.
+ SrcLen uint8
+
+ // TOS is the Type of Service filter.
+ TOS uint8
+
+ // Table is the routing table ID.
+ Table uint8
+
+ // Protocol is the route origin, a Linux RTPROT_* constant.
+ Protocol uint8
+
+ // Scope is the distance to destination, a Linux RT_SCOPE_* constant.
+ Scope uint8
+
+	// Type is the route type, a Linux RTN_* constant.
+ Type uint8
+
+ // Flags are route flags. See rtnetlink(7) under "rtm_flags".
+ Flags uint32
+
+ // DstAddr is the route destination address (RTA_DST).
+ DstAddr []byte
+
+ // SrcAddr is the route source address (RTA_SRC).
+ SrcAddr []byte
+
+ // OutputInterface is the output interface index (RTA_OIF).
+ OutputInterface int32
+
+ // GatewayAddr is the route gateway address (RTA_GATEWAY).
+ GatewayAddr []byte
+}
+
+// The SNMP metrics below are from Linux's usr/include/linux/snmp.h.
+
+// StatSNMPIP describes Ip line of /proc/net/snmp.
+type StatSNMPIP [19]uint64
+
+// StatSNMPICMP describes Icmp line of /proc/net/snmp.
+type StatSNMPICMP [27]uint64
+
+// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp.
+type StatSNMPICMPMSG [512]uint64
+
+// StatSNMPTCP describes Tcp line of /proc/net/snmp.
+type StatSNMPTCP [15]uint64
+
+// StatSNMPUDP describes Udp line of /proc/net/snmp.
+type StatSNMPUDP [8]uint64
+
+// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp.
+type StatSNMPUDPLite [8]uint64
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..029af3025
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+	// creator allows the kernel to create new network stacks for new
+	// network namespaces. If nil, networking will not function in new
+	// network namespaces.
+	//
+	// At afterLoad(), creator will be used to recreate the network stack.
+	// Stateify needs to wait for this field to be loaded before calling
+	// afterLoad().
+ creator NetworkStackCreator `state:"wait"`
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil,
+// networking will not function in new network namespaces.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring the kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created.
+// It is used by the kernel to create new network stacks when new network
+// namespaces are requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
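+
+// Example (hypothetical, for illustration only; hostStack and creator are
+// assumed to come from platform setup):
+//
+//	root := NewRootNamespace(hostStack, creator)
+//	child := NewNamespace(root) // creator builds a fresh stack for child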
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
new file mode 100644
index 000000000..d8961fc94
--- /dev/null
+++ b/pkg/sentry/inet/test_stack.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+import "gvisor.dev/gvisor/pkg/tcpip/stack"
+
+// TestStack is a dummy implementation of Stack for tests.
+type TestStack struct {
+ InterfacesMap map[int32]Interface
+ InterfaceAddrsMap map[int32][]InterfaceAddr
+ RouteList []Route
+ SupportsIPv6Flag bool
+ TCPRecvBufSize TCPBufferSize
+ TCPSendBufSize TCPBufferSize
+ TCPSACKFlag bool
+}
+
+// NewTestStack returns a TestStack with no network interfaces. The value of
+// all other options is unspecified; tests that rely on specific values must
+// set them explicitly.
+func NewTestStack() *TestStack {
+ return &TestStack{
+ InterfacesMap: make(map[int32]Interface),
+ InterfaceAddrsMap: make(map[int32][]InterfaceAddr),
+ }
+}
+
+// Interfaces implements Stack.Interfaces.
+func (s *TestStack) Interfaces() map[int32]Interface {
+ return s.InterfacesMap
+}
+
+// InterfaceAddrs implements Stack.InterfaceAddrs.
+func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr {
+ return s.InterfaceAddrsMap
+}
+
+// AddInterfaceAddr implements Stack.AddInterfaceAddr.
+func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
+ s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr)
+ return nil
+}
+
+// SupportsIPv6 implements Stack.SupportsIPv6.
+func (s *TestStack) SupportsIPv6() bool {
+ return s.SupportsIPv6Flag
+}
+
+// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize.
+func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) {
+ return s.TCPRecvBufSize, nil
+}
+
+// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize.
+func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error {
+ s.TCPRecvBufSize = size
+ return nil
+}
+
+// TCPSendBufferSize implements Stack.TCPSendBufferSize.
+func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) {
+ return s.TCPSendBufSize, nil
+}
+
+// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize.
+func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error {
+ s.TCPSendBufSize = size
+ return nil
+}
+
+// TCPSACKEnabled implements Stack.TCPSACKEnabled.
+func (s *TestStack) TCPSACKEnabled() (bool, error) {
+ return s.TCPSACKFlag, nil
+}
+
+// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled.
+func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
+ s.TCPSACKFlag = enabled
+ return nil
+}
+
+// Statistics implements inet.Stack.Statistics.
+func (s *TestStack) Statistics(stat interface{}, arg string) error {
+ return nil
+}
+
+// RouteTable implements Stack.RouteTable.
+func (s *TestStack) RouteTable() []Route {
+ return s.RouteList
+}
+
+// Resume implements Stack.Resume.
+func (s *TestStack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return nil
+}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
+ return nil
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
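+
+// Example (hypothetical test, for illustration only):
+//
+//	s := NewTestStack()
+//	s.TCPSACKFlag = true
+//	if on, _ := s.TCPSACKEnabled(); !on {
+//		t.Errorf("TCPSACKEnabled() = false, want true")
+//	}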
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
new file mode 100644
index 000000000..25fe1921b
--- /dev/null
+++ b/pkg/sentry/kernel/BUILD
@@ -0,0 +1,241 @@
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "pending_signals_list",
+ out = "pending_signals_list.go",
+ package = "kernel",
+ prefix = "pendingSignal",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*pendingSignal",
+ "Linker": "*pendingSignal",
+ },
+)
+
+go_template_instance(
+ name = "process_group_list",
+ out = "process_group_list.go",
+ package = "kernel",
+ prefix = "processGroup",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*ProcessGroup",
+ "Linker": "*ProcessGroup",
+ },
+)
+
+go_template_instance(
+ name = "seqatomic_taskgoroutineschedinfo",
+ out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
+ package = "kernel",
+ suffix = "TaskGoroutineSchedInfo",
+ template = "//pkg/sync:generic_seqatomic",
+ types = {
+ "Value": "TaskGoroutineSchedInfo",
+ },
+)
+
+go_template_instance(
+ name = "session_list",
+ out = "session_list.go",
+ package = "kernel",
+ prefix = "session",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Session",
+ "Linker": "*Session",
+ },
+)
+
+go_template_instance(
+ name = "task_list",
+ out = "task_list.go",
+ package = "kernel",
+ prefix = "task",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Task",
+ "Linker": "*Task",
+ },
+)
+
+go_template_instance(
+ name = "socket_list",
+ out = "socket_list.go",
+ package = "kernel",
+ prefix = "socket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*SocketEntry",
+ "Linker": "*SocketEntry",
+ },
+)
+
+proto_library(
+ name = "uncaught_signal",
+ srcs = ["uncaught_signal.proto"],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/sentry/arch:registers_proto"],
+)
+
+go_library(
+ name = "kernel",
+ srcs = [
+ "abstract_socket_namespace.go",
+ "aio.go",
+ "context.go",
+ "fd_table.go",
+ "fd_table_unsafe.go",
+ "fs_context.go",
+ "ipc_namespace.go",
+ "kernel.go",
+ "kernel_opts.go",
+ "kernel_state.go",
+ "pending_signals.go",
+ "pending_signals_list.go",
+ "pending_signals_state.go",
+ "posixtimer.go",
+ "process_group_list.go",
+ "ptrace.go",
+ "ptrace_amd64.go",
+ "ptrace_arm64.go",
+ "rseq.go",
+ "seccomp.go",
+ "seqatomic_taskgoroutineschedinfo_unsafe.go",
+ "session_list.go",
+ "sessions.go",
+ "signal.go",
+ "signal_handlers.go",
+ "socket_list.go",
+ "syscalls.go",
+ "syscalls_state.go",
+ "syslog.go",
+ "task.go",
+ "task_acct.go",
+ "task_block.go",
+ "task_clone.go",
+ "task_context.go",
+ "task_exec.go",
+ "task_exit.go",
+ "task_futex.go",
+ "task_identity.go",
+ "task_list.go",
+ "task_log.go",
+ "task_net.go",
+ "task_run.go",
+ "task_sched.go",
+ "task_signals.go",
+ "task_start.go",
+ "task_stop.go",
+ "task_syscall.go",
+ "task_usermem.go",
+ "thread_group.go",
+ "threads.go",
+ "timekeeper.go",
+ "timekeeper_state.go",
+ "tty.go",
+ "uts_namespace.go",
+ "vdso.go",
+ "version.go",
+ ],
+ imports = [
+ "gvisor.dev/gvisor/pkg/bpf",
+ "gvisor.dev/gvisor/pkg/sentry/device",
+ "gvisor.dev/gvisor/pkg/tcpip",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ ":uncaught_signal_go_proto",
+ "//pkg/abi",
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/binary",
+ "//pkg/bits",
+ "//pkg/bpf",
+ "//pkg/context",
+ "//pkg/cpuid",
+ "//pkg/eventchannel",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/secio",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fs/timerfd",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/fsimpl/pipefs",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/fsimpl/timerfd",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/hostcpu",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/epoll",
+ "//pkg/sentry/kernel/futex",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/kernel/semaphore",
+ "//pkg/sentry/kernel/shm",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/netlink/port",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/time",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/unimpl:unimplemented_syscall_go_proto",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/state",
+ "//pkg/state/statefile",
+ "//pkg/state/wire",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/stack",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ ],
+)
+
+go_test(
+ name = "kernel_test",
+ size = "small",
+ srcs = [
+ "fd_table_test.go",
+ "table_test.go",
+ "task_test.go",
+ "timekeeper_test.go",
+ ],
+ library = ":kernel",
+ deps = [
+ "//pkg/abi",
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/filetest",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/time",
+ "//pkg/sentry/usage",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/kernel/README.md b/pkg/sentry/kernel/README.md
new file mode 100644
index 000000000..427311be8
--- /dev/null
+++ b/pkg/sentry/kernel/README.md
@@ -0,0 +1,108 @@
+This package contains:
+
+- A (partial) emulation of the "core Linux kernel", which governs task
+ execution and scheduling, system call dispatch, and signal handling. See
+ below for details.
+
+- The top-level interface for the sentry's Linux kernel emulation in general,
+ used by the `main` function of all versions of the sentry. This interface
+ revolves around the `Kernel` type (defined in `kernel.go`).
+
+# Background
+
+In Linux, each schedulable context is referred to interchangeably as a "task" or
+"thread". Tasks can be divided into userspace and kernel tasks. In the sentry,
+scheduling is managed by the Go runtime, so each schedulable context is a
+goroutine; only "userspace" (application) contexts are referred to as tasks, and
+represented by Task objects. (From this point forward, "task" refers to the
+sentry's notion of a task unless otherwise specified.)
+
+At a high level, Linux application threads can be thought of as repeating a "run
+loop":
+
+- Some amount of application code is executed in userspace.
+
+- A trap (explicit syscall invocation, hardware interrupt or exception, etc.)
+ causes control flow to switch to the kernel.
+
+- Some amount of kernel code is executed in kernelspace, e.g. to handle the
+ cause of the trap.
+
+- The kernel "returns from the trap" into application code.
+
+Analogously, each task in the sentry is associated with a *task goroutine* that
+executes that task's run loop (`Task.run` in `task_run.go`). However, the
+sentry's task run loop differs in structure in order to support saving execution
+state to, and resuming execution from, checkpoints.
+
+While in kernelspace, a Linux thread can be descheduled (cease execution) in a
+variety of ways:
+
+- It can yield or be preempted, becoming temporarily descheduled but still
+ runnable. At present, the sentry delegates scheduling of runnable threads to
+ the Go runtime.
+
+- It can exit, becoming permanently descheduled. The sentry's equivalent is
+ returning from `Task.run`, terminating the task goroutine.
+
+- It can enter interruptible sleep, a state in which it can be woken by a
+ caller-defined wakeup or the receipt of a signal. In the sentry,
+ interruptible sleep (which is ambiguously referred to as *blocking*) is
+ implemented by having all events that can end blocking (including signal
+ notifications) communicated via Go channels and using `select` to multiplex
+ wakeup sources; see `task_block.go`.
+
+- It can enter uninterruptible sleep, a state in which it can only be woken by
+ a caller-defined wakeup. Killable sleep is a closely related variant in
+ which the task can also be woken by SIGKILL. (These definitions also include
+ Linux's "group-stopped" (`TASK_STOPPED`) and "ptrace-stopped"
+ (`TASK_TRACED`) states.)
+
+To maximize compatibility with Linux, sentry checkpointing appears as a spurious
+signal-delivery interrupt on all tasks; interrupted system calls return `EINTR`
+or are automatically restarted as usual. However, these semantics require that
+uninterruptible and killable sleeps do not appear to be interrupted. In other
+words, the state of the task, including its progress through the interrupted
+operation, must be preserved by checkpointing. For many such sleeps, the wakeup
+condition is application-controlled, making it infeasible to wait for the sleep
+to end before checkpointing. Instead, we must support checkpointing progress
+through sleeping operations.
+
+# Implementation
+
+We break the task's control flow graph into *states*, delimited by:
+
+1. Points where uninterruptible and killable sleeps may occur. For example,
+ there exists a state boundary between signal dequeueing and signal delivery
+ because there may be an intervening ptrace signal-delivery-stop.
+
+2. Points where sleep-induced branches may "rejoin" normal execution. For
+ example, the syscall exit state exists because it can be reached immediately
+ following a synchronous syscall, or after a task that is sleeping in
+ `execve()` or `vfork()` resumes execution.
+
+3. Points containing large branches. This is strictly for organizational
+ purposes. For example, the state that processes interrupt-signaled
+ conditions is kept separate from the main "app" state to reduce the size of
+ the latter.
+
+4. `SyscallReinvoke`, which does not correspond to anything in Linux, and
+ exists solely to serve the autosave feature.
+
+![dot -Tpng -Goverlap=false -orun_states.png run_states.dot](g3doc/run_states.png "Task control flow graph")
+
+States before which a stop may occur are represented as implementations of the
+`taskRunState` interface named `run(state)`, allowing them to be saved and
+restored. States that cannot be immediately preceded by a stop are simply `Task`
+methods named `do(state)`.
+
+Conditions that can require task goroutines to cease execution for unknown
+lengths of time are called *stops*. Stops are divided into *internal stops*,
+which are stops whose start and end conditions are implemented within the
+sentry, and *external stops*, which are stops whose start and end conditions are
+not known to the sentry. Hence all uninterruptible and killable sleeps are
+internal stops, and the existence of a pending checkpoint operation is an
+external stop. Internal stops are reified into instances of the `TaskStop` type,
+while external stops are merely counted. The task run loop alternates between
+checking for stops and advancing the task's state. This allows checkpointing to
+hold tasks in a stopped state while waiting for all tasks in the system to stop.
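+
+As a minimal sketch (illustrative only: the real states live in `task_run.go`
+and related files, and the method and helper names here are simplified
+assumptions), the run loop has roughly this shape:
+
+```go
+// taskRunState is implemented by each state that may be preceded by a stop;
+// execute advances the task and returns the next state, or nil on exit.
+type taskRunState interface {
+	execute(t *Task) taskRunState
+}
+
+func (t *Task) run() {
+	state := t.firstRunState() // hypothetical accessor
+	for state != nil {
+		// Between states, the task may be held in a stop (e.g. for
+		// checkpointing) without losing its progress.
+		t.waitForStops() // hypothetical; stop bookkeeping elided
+		state = state.execute(t)
+	}
+}
+```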
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
new file mode 100644
index 000000000..920fe4329
--- /dev/null
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// +stateify savable
+type abstractEndpoint struct {
+ ep transport.BoundEndpoint
+ wr *refs.WeakRef
+ name string
+ ns *AbstractSocketNamespace
+}
+
+// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
+func (e *abstractEndpoint) WeakRefGone() {
+ e.ns.mu.Lock()
+ if e.ns.endpoints[e.name].ep == e.ep {
+ delete(e.ns.endpoints, e.name)
+ }
+ e.ns.mu.Unlock()
+}
+
+// AbstractSocketNamespace is used to implement the Linux abstract socket functionality.
+//
+// +stateify savable
+type AbstractSocketNamespace struct {
+ mu sync.Mutex `state:"nosave"`
+
+	// endpoints maps names to endpoints.
+ endpoints map[string]abstractEndpoint
+}
+
+// NewAbstractSocketNamespace returns a new AbstractSocketNamespace.
+func NewAbstractSocketNamespace() *AbstractSocketNamespace {
+ return &AbstractSocketNamespace{
+ endpoints: make(map[string]abstractEndpoint),
+ }
+}
+
+// A boundEndpoint wraps a transport.BoundEndpoint to maintain a reference on
+// its backing object.
+type boundEndpoint struct {
+ transport.BoundEndpoint
+ rc refs.RefCounter
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *boundEndpoint) Release() {
+ e.rc.DecRef()
+ e.BoundEndpoint.Release()
+}
+
+// BoundEndpoint retrieves the endpoint bound to the given name. The return
+// value is nil if no endpoint was bound.
+func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndpoint {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ ep, ok := a.endpoints[name]
+ if !ok {
+ return nil
+ }
+
+ rc := ep.wr.Get()
+ if rc == nil {
+ delete(a.endpoints, name)
+ return nil
+ }
+
+ return &boundEndpoint{ep.ep, rc}
+}
+
+// Bind binds the given socket.
+//
+// When the last reference managed by rc is dropped, ep may be removed from the
+// namespace.
+func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if ep, ok := a.endpoints[name]; ok {
+ if rc := ep.wr.Get(); rc != nil {
+ rc.DecRef()
+ return syscall.EADDRINUSE
+ }
+ }
+
+ ae := abstractEndpoint{ep: ep, name: name, ns: a}
+ ae.wr = refs.NewWeakRef(rc, &ae)
+ a.endpoints[name] = ae
+ return nil
+}
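+
+// Example (hypothetical, for illustration only): binding an abstract name and
+// resolving it later; rc is the refs.RefCounter backing ep.
+//
+//	if err := ns.Bind("example", ep, rc); err != nil {
+//		return err // syscall.EADDRINUSE if the name is already bound
+//	}
+//	if bep := ns.BoundEndpoint("example"); bep != nil {
+//		defer bep.Release()
+//		// ... connect to bep ...
+//	}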
diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go
new file mode 100644
index 000000000..0ac78c0b8
--- /dev/null
+++ b/pkg/sentry/kernel/aio.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// AIOCallback is a function that does asynchronous I/O on behalf of a task.
+type AIOCallback func(context.Context)
+
+// QueueAIO queues an AIOCallback which will be run asynchronously.
+func (t *Task) QueueAIO(cb AIOCallback) {
+ ctx := taskAsyncContext{t: t}
+ wg := &t.TaskSet().aioGoroutines
+ wg.Add(1)
+ go func() {
+ cb(ctx)
+ wg.Done()
+ }()
+}
+
+type taskAsyncContext struct {
+ context.NoopSleeper
+ t *Task
+}
+
+// Debugf implements log.Logger.Debugf.
+func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) {
+ ctx.t.Debugf(format, v...)
+}
+
+// Infof implements log.Logger.Infof.
+func (ctx taskAsyncContext) Infof(format string, v ...interface{}) {
+ ctx.t.Infof(format, v...)
+}
+
+// Warningf implements log.Logger.Warningf.
+func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) {
+ ctx.t.Warningf(format, v...)
+}
+
+// IsLogging implements log.Logger.IsLogging.
+func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
+ return ctx.t.IsLogging(level)
+}
+
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+ return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+ return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+ return ctx.t.Err()
+}
+
+// Value implements context.Context.Value.
+func (ctx taskAsyncContext) Value(key interface{}) interface{} {
+ return ctx.t.Value(key)
+}
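+
+// Example (hypothetical, for illustration only): queuing asynchronous I/O
+// from a syscall handler.
+//
+//	t.QueueAIO(func(ctx context.Context) {
+//		// Perform the I/O using ctx; logging and Value lookups are
+//		// forwarded to the originating task.
+//	})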
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
new file mode 100644
index 000000000..2bc49483a
--- /dev/null
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -0,0 +1,69 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "atomicptr_credentials",
+ out = "atomicptr_credentials_unsafe.go",
+ package = "auth",
+ suffix = "Credentials",
+ template = "//pkg/sync:generic_atomicptr",
+ types = {
+ "Value": "Credentials",
+ },
+)
+
+go_template_instance(
+ name = "id_map_range",
+ out = "id_map_range.go",
+ package = "auth",
+ prefix = "idMap",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint32",
+ },
+)
+
+go_template_instance(
+ name = "id_map_set",
+ out = "id_map_set.go",
+ consts = {
+ "minDegree": "3",
+ },
+ package = "auth",
+ prefix = "idMap",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint32",
+ "Range": "idMapRange",
+ "Value": "uint32",
+ "Functions": "idMapFunctions",
+ },
+)
+
+go_library(
+ name = "auth",
+ srcs = [
+ "atomicptr_credentials_unsafe.go",
+ "auth.go",
+ "capability_set.go",
+ "context.go",
+ "credentials.go",
+ "id.go",
+ "id_map.go",
+ "id_map_functions.go",
+ "id_map_range.go",
+ "id_map_set.go",
+ "user_namespace.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/bits",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/kernel/auth/auth.go b/pkg/sentry/kernel/auth/auth.go
new file mode 100644
index 000000000..847d121aa
--- /dev/null
+++ b/pkg/sentry/kernel/auth/auth.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package auth implements an access control model that is a subset of Linux's.
+//
+// The auth package supports two kinds of access controls: user/group IDs and
+// capabilities. Each resource in the security model is associated with a user
+// namespace; "privileged" operations check that the operator's credentials
+// have the required user/group IDs or capabilities within the user namespace
+// of accessed resources.
+package auth
diff --git a/pkg/sentry/kernel/auth/capability_set.go b/pkg/sentry/kernel/auth/capability_set.go
new file mode 100644
index 000000000..fc8c6745c
--- /dev/null
+++ b/pkg/sentry/kernel/auth/capability_set.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+)
+
+// A CapabilitySet is a set of capabilities implemented as a bitset. The zero
+// value of CapabilitySet is a set containing no capabilities.
+type CapabilitySet uint64
+
+// AllCapabilities is a CapabilitySet containing all valid capabilities.
+var AllCapabilities = CapabilitySetOf(linux.CAP_LAST_CAP+1) - 1
+
+// CapabilitySetOf returns a CapabilitySet containing only the given
+// capability.
+func CapabilitySetOf(cp linux.Capability) CapabilitySet {
+ return CapabilitySet(bits.MaskOf64(int(cp)))
+}
+
+// CapabilitySetOfMany returns a CapabilitySet containing the given capabilities.
+func CapabilitySetOfMany(cps []linux.Capability) CapabilitySet {
+ var cs uint64
+ for _, cp := range cps {
+ cs |= bits.MaskOf64(int(cp))
+ }
+ return CapabilitySet(cs)
+}
+
+// TaskCapabilities represents all the capability sets for a task. Each of these
+// sets is explained in greater detail in capabilities(7).
+type TaskCapabilities struct {
+ // Permitted is a limiting superset for the effective capabilities that
+ // the thread may assume.
+ PermittedCaps CapabilitySet
+ // Inheritable is a set of capabilities preserved across an execve(2).
+ InheritableCaps CapabilitySet
+ // Effective is the set of capabilities used by the kernel to perform
+ // permission checks for the thread.
+ EffectiveCaps CapabilitySet
+ // Bounding is a limiting superset for the capabilities that a thread
+ // can add to its inheritable set using capset(2).
+ BoundingCaps CapabilitySet
+ // Ambient is a set of capabilities that are preserved across an
+ // execve(2) of a program that is not privileged.
+ AmbientCaps CapabilitySet
+}
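+
+// Example (illustrative): building a set containing CAP_NET_ADMIN and
+// CAP_SYS_ADMIN and testing membership.
+//
+//	cs := CapabilitySetOfMany([]linux.Capability{linux.CAP_NET_ADMIN, linux.CAP_SYS_ADMIN})
+//	hasNet := cs&CapabilitySetOf(linux.CAP_NET_ADMIN) != 0 // true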
diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go
new file mode 100644
index 000000000..ef5723127
--- /dev/null
+++ b/pkg/sentry/kernel/auth/context.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the auth package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxCredentials is a Context.Value key for Credentials.
+ CtxCredentials contextID = iota
+)
+
+// CredentialsFromContext returns a copy of the Credentials used by ctx, or a
+// set of Credentials with no capabilities if ctx does not have Credentials.
+func CredentialsFromContext(ctx context.Context) *Credentials {
+ if v := ctx.Value(CtxCredentials); v != nil {
+ return v.(*Credentials)
+ }
+ return NewAnonymousCredentials()
+}
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
new file mode 100644
index 000000000..6862f2ef5
--- /dev/null
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -0,0 +1,262 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Credentials contains information required to authorize privileged operations
+// in a user namespace.
+//
+// +stateify savable
+type Credentials struct {
+ // Real/effective/saved user/group IDs in the root user namespace. None of
+ // these should ever be NoID.
+ RealKUID KUID
+ EffectiveKUID KUID
+ SavedKUID KUID
+ RealKGID KGID
+ EffectiveKGID KGID
+ SavedKGID KGID
+
+ // Filesystem user/group IDs are not implemented. "... setfsuid() is
+ // nowadays unneeded and should be avoided in new applications (likewise
+ // for setfsgid(2))." - setfsuid(2)
+
+ // Supplementary groups used by set/getgroups.
+ //
+ // ExtraKGIDs slices are immutable, allowing multiple Credentials with the
+ // same ExtraKGIDs to share the same slice.
+ ExtraKGIDs []KGID
+
+ // The capability sets applicable to this set of credentials.
+ PermittedCaps CapabilitySet
+ InheritableCaps CapabilitySet
+ EffectiveCaps CapabilitySet
+ BoundingCaps CapabilitySet
+ // Ambient capabilities are not introduced until Linux 4.3.
+
+	// KeepCaps is the flag for PR_SET_KEEPCAPS, which allows capabilities
+	// to be maintained after a switch from the root user to a non-root
+	// user via setuid().
+ KeepCaps bool
+
+ // The user namespace associated with the owner of the credentials.
+ UserNamespace *UserNamespace
+}
+
+// NewAnonymousCredentials returns a set of credentials with no capabilities in
+// any user namespace.
+func NewAnonymousCredentials() *Credentials {
+ // Create a new root user namespace. Since the new namespace's owner is
+ // KUID 0 and the returned credentials have non-zero KUID/KGID, the
+ // returned credentials do not have any capabilities in the new namespace.
+ // Since the new namespace is not part of any existing user namespace
+ // hierarchy, the returned credentials do not have any capabilities in any
+ // other namespace.
+ return &Credentials{
+ RealKUID: NobodyKUID,
+ EffectiveKUID: NobodyKUID,
+ SavedKUID: NobodyKUID,
+ RealKGID: NobodyKGID,
+ EffectiveKGID: NobodyKGID,
+ SavedKGID: NobodyKGID,
+ UserNamespace: NewRootUserNamespace(),
+ }
+}
+
+// NewRootCredentials returns a set of credentials with KUID and KGID 0 (i.e.
+// global root) in user namespace ns.
+func NewRootCredentials(ns *UserNamespace) *Credentials {
+ // I can't find documentation for this anywhere, but it's correct for the
+ // inheritable capability set to be initially empty (the capabilities test
+ // checks for this property).
+ return &Credentials{
+ RealKUID: RootKUID,
+ EffectiveKUID: RootKUID,
+ SavedKUID: RootKUID,
+ RealKGID: RootKGID,
+ EffectiveKGID: RootKGID,
+ SavedKGID: RootKGID,
+ PermittedCaps: AllCapabilities,
+ EffectiveCaps: AllCapabilities,
+ BoundingCaps: AllCapabilities,
+ UserNamespace: ns,
+ }
+}
+
+// NewUserCredentials returns a set of credentials based on the given UID, GIDs,
+// and capabilities in a given namespace. If all arguments are their zero
+// values, this returns the same credentials as NewRootCredentials.
+func NewUserCredentials(kuid KUID, kgid KGID, extraKGIDs []KGID, capabilities *TaskCapabilities, ns *UserNamespace) *Credentials {
+ creds := NewRootCredentials(ns)
+
+ // Set the UID.
+ uid := kuid
+ creds.RealKUID = uid
+ creds.EffectiveKUID = uid
+ creds.SavedKUID = uid
+
+ // Set GID.
+ gid := kgid
+ creds.RealKGID = gid
+ creds.EffectiveKGID = gid
+ creds.SavedKGID = gid
+
+ // Set additional GIDs.
+ creds.ExtraKGIDs = append(creds.ExtraKGIDs, extraKGIDs...)
+
+ // Set capabilities.
+ if capabilities != nil {
+ creds.PermittedCaps = capabilities.PermittedCaps
+ creds.EffectiveCaps = capabilities.EffectiveCaps
+ creds.BoundingCaps = capabilities.BoundingCaps
+ creds.InheritableCaps = capabilities.InheritableCaps
+ // TODO(nlacasse): Support ambient capabilities.
+ } else {
+ // If no capabilities are specified, grant capabilities consistent with
+ // setresuid + setresgid from NewRootCredentials to the given uid and
+ // gid.
+ if kuid == RootKUID {
+ creds.PermittedCaps = AllCapabilities
+ creds.EffectiveCaps = AllCapabilities
+ } else {
+ creds.PermittedCaps = 0
+ creds.EffectiveCaps = 0
+ }
+ creds.BoundingCaps = AllCapabilities
+ }
+
+ return creds
+}
+
+// Fork generates an identical copy of a set of credentials.
+func (c *Credentials) Fork() *Credentials {
+ nc := new(Credentials)
+ *nc = *c // Copy-by-value; this is legal for all fields.
+ return nc
+}
+
+// InGroup returns true if c is in group kgid. Compare Linux's
+// kernel/groups.c:in_group_p().
+func (c *Credentials) InGroup(kgid KGID) bool {
+ if c.EffectiveKGID == kgid {
+ return true
+ }
+ for _, extraKGID := range c.ExtraKGIDs {
+ if extraKGID == kgid {
+ return true
+ }
+ }
+ return false
+}
+
+// HasCapabilityIn returns true if c has capability cp in ns.
+func (c *Credentials) HasCapabilityIn(cp linux.Capability, ns *UserNamespace) bool {
+ for {
+ // "1. A process has a capability inside a user namespace if it is a member
+ // of that namespace and it has the capability in its effective capability
+ // set." - user_namespaces(7)
+ if c.UserNamespace == ns {
+ return CapabilitySetOf(cp)&c.EffectiveCaps != 0
+ }
+ // "3. ... A process that resides in the parent of the user namespace and
+ // whose effective user ID matches the owner of the namespace has all
+ // capabilities in the namespace."
+ if c.UserNamespace == ns.parent && c.EffectiveKUID == ns.owner {
+ return true
+ }
+ // "2. If a process has a capability in a user namespace, then it has that
+ // capability in all child (and further removed descendant) namespaces as
+ // well."
+ if ns.parent == nil {
+ return false
+ }
+ ns = ns.parent
+ }
+}
+
+// HasCapability returns true if c has capability cp in its user namespace.
+func (c *Credentials) HasCapability(cp linux.Capability) bool {
+ return c.HasCapabilityIn(cp, c.UserNamespace)
+}
+
+// UseUID checks that c can use uid in its user namespace, then translates it
+// to the root user namespace.
+//
+// The checks UseUID does are common, but you should verify that it's doing
+// exactly what you want.
+func (c *Credentials) UseUID(uid UID) (KUID, error) {
+ // uid must be mapped.
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return NoID, syserror.EINVAL
+ }
+ // If c has CAP_SETUID, then it can use any UID in its user namespace.
+ if c.HasCapability(linux.CAP_SETUID) {
+ return kuid, nil
+ }
+ // Otherwise, c must already have the UID as its real, effective, or saved
+ // set-user-ID.
+ if kuid == c.RealKUID || kuid == c.EffectiveKUID || kuid == c.SavedKUID {
+ return kuid, nil
+ }
+ return NoID, syserror.EPERM
+}
+
+// UseGID checks that c can use gid in its user namespace, then translates it
+// to the root user namespace.
+func (c *Credentials) UseGID(gid GID) (KGID, error) {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return NoID, syserror.EINVAL
+ }
+ if c.HasCapability(linux.CAP_SETGID) {
+ return kgid, nil
+ }
+ if kgid == c.RealKGID || kgid == c.EffectiveKGID || kgid == c.SavedKGID {
+ return kgid, nil
+ }
+ return NoID, syserror.EPERM
+}
+
+// SetUID translates the provided uid to the root user namespace and updates c's
+// uids to it. This performs no permission or capability checks; the caller
+// is responsible for ensuring that the calling context is permitted to
+// modify c.
+func (c *Credentials) SetUID(uid UID) error {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKUID = kuid
+ c.EffectiveKUID = kuid
+ c.SavedKUID = kuid
+ return nil
+}
+
+// SetGID translates the provided gid to the root user namespace and updates c's
+// gids to it. This performs no permission or capability checks; the caller
+// is responsible for ensuring that the calling context is permitted to
+// modify c.
+func (c *Credentials) SetGID(gid GID) error {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKGID = kgid
+ c.EffectiveKGID = kgid
+ c.SavedKGID = kgid
+ return nil
+}
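+
+// Example (hypothetical, for illustration only): the common setuid(2) shape
+// built from the helpers above.
+//
+//	if _, err := creds.UseUID(uid); err != nil {
+//		return err // EINVAL if unmapped, EPERM if not permitted
+//	}
+//	nc := creds.Fork()  // copy-by-value snapshot
+//	_ = nc.SetUID(uid)  // performs no permission checks itself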
diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go
new file mode 100644
index 000000000..0a58ba17c
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "math"
+)
+
+// UID is a user ID in an unspecified user namespace.
+type UID uint32
+
+// GID is a group ID in an unspecified user namespace.
+type GID uint32
+
+// In the root user namespace, user/group IDs have a 1-to-1 relationship with
+// the users/groups they represent. In other user namespaces, this is not the
+// case; for example, two different unmapped users may both "have" the overflow
+// UID. This means that it is generally only valid to compare user and group
+// IDs in the root user namespace. We assign distinct types, KUID/KGID, to such
+// IDs to emphasize this distinction. ("k" is for "key", as in "unique key".
+// Linux also uses the prefix "k", but I think they mean "kernel".)
+
+// KUID is a user ID in the root user namespace.
+type KUID uint32
+
+// KGID is a group ID in the root user namespace.
+type KGID uint32
+
+const (
+ // NoID is uint32(-1). -1 is consistently used as a special value, in Linux
+ // and by extension in the auth package, to mean "no ID":
+ //
+ // - ID mapping returns -1 if the ID is not mapped.
+ //
+ // - Most set*id() syscalls accept -1 to mean "do not change this ID".
+ NoID = math.MaxUint32
+
+ // OverflowUID is the default value of /proc/sys/kernel/overflowuid. The
+ // "overflow UID" is usually [1] used when translating a user ID between
+ // namespaces fails because the ID is not mapped. (We don't implement this
+ // file, so the overflow UID is constant.)
+ //
+ // [1] "There is one notable case where unmapped user and group IDs are not
+ // converted to the corresponding overflow ID value. When viewing a uid_map
+ // or gid_map file in which there is no mapping for the second field, that
+ // field is displayed as 4294967295 (-1 as an unsigned integer);" -
+ // user_namespaces(7)
+ OverflowUID = UID(65534)
+ OverflowGID = GID(65534)
+
+ // NobodyKUID is the user ID usually reserved for the least privileged user
+ // "nobody".
+ NobodyKUID = KUID(65534)
+ NobodyKGID = KGID(65534)
+
+ // RootKUID is the user ID usually used for the most privileged user "root".
+ RootKUID = KUID(0)
+ RootKGID = KGID(0)
+ RootUID = UID(0)
+ RootGID = GID(0)
+)
+
+// Ok returns true if uid is not -1.
+func (uid UID) Ok() bool {
+ return uid != NoID
+}
+
+// Ok returns true if gid is not -1.
+func (gid GID) Ok() bool {
+ return gid != NoID
+}
+
+// Ok returns true if kuid is not -1.
+func (kuid KUID) Ok() bool {
+ return kuid != NoID
+}
+
+// Ok returns true if kgid is not -1.
+func (kgid KGID) Ok() bool {
+ return kgid != NoID
+}
+
+// OrOverflow returns uid if it is valid and the overflow UID otherwise.
+func (uid UID) OrOverflow() UID {
+ if uid.Ok() {
+ return uid
+ }
+ return OverflowUID
+}
+
+// OrOverflow returns gid if it is valid and the overflow GID otherwise.
+func (gid GID) OrOverflow() GID {
+ if gid.Ok() {
+ return gid
+ }
+ return OverflowGID
+}
+
+// In translates kuid into user namespace ns. If kuid is not mapped in ns, In
+// returns NoID.
+func (kuid KUID) In(ns *UserNamespace) UID {
+ return ns.MapFromKUID(kuid)
+}
+
+// In translates kgid into user namespace ns. If kgid is not mapped in ns, In
+// returns NoID.
+func (kgid KGID) In(ns *UserNamespace) GID {
+ return ns.MapFromKGID(kgid)
+}
diff --git a/pkg/sentry/kernel/auth/id_map.go b/pkg/sentry/kernel/auth/id_map.go
new file mode 100644
index 000000000..28cbe159d
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id_map.go
@@ -0,0 +1,285 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// MapFromKUID translates kuid, a UID in the root namespace, to a UID in ns.
+func (ns *UserNamespace) MapFromKUID(kuid KUID) UID {
+ if ns.parent == nil {
+ return UID(kuid)
+ }
+ return UID(ns.mapID(&ns.uidMapFromParent, uint32(ns.parent.MapFromKUID(kuid))))
+}
+
+// MapFromKGID translates kgid, a GID in the root namespace, to a GID in ns.
+func (ns *UserNamespace) MapFromKGID(kgid KGID) GID {
+ if ns.parent == nil {
+ return GID(kgid)
+ }
+ return GID(ns.mapID(&ns.gidMapFromParent, uint32(ns.parent.MapFromKGID(kgid))))
+}
+
+// MapToKUID translates uid, a UID in ns, to a UID in the root namespace.
+func (ns *UserNamespace) MapToKUID(uid UID) KUID {
+ if ns.parent == nil {
+ return KUID(uid)
+ }
+ return ns.parent.MapToKUID(UID(ns.mapID(&ns.uidMapToParent, uint32(uid))))
+}
+
+// MapToKGID translates gid, a GID in ns, to a GID in the root namespace.
+func (ns *UserNamespace) MapToKGID(gid GID) KGID {
+ if ns.parent == nil {
+ return KGID(gid)
+ }
+ return ns.parent.MapToKGID(GID(ns.mapID(&ns.gidMapToParent, uint32(gid))))
+}
+
+func (ns *UserNamespace) mapID(m *idMapSet, id uint32) uint32 {
+ if id == NoID {
+ return NoID
+ }
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ if it := m.FindSegment(id); it.Ok() {
+ return it.Value() + (id - it.Start())
+ }
+ return NoID
+}
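+
+// For instance (illustrative): with the single map entry "0 1000 10", the
+// from-parent set holds the segment [1000, 1010) => 0, so mapID resolves 1005
+// to 0 + (1005 - 1000) = 5, and resolves 2000, which lies outside every
+// segment, to NoID.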
+
+// allIDsMapped returns true if all IDs in the range [start, end) are mapped in
+// m.
+//
+// Preconditions: end >= start.
+func (ns *UserNamespace) allIDsMapped(m *idMapSet, start, end uint32) bool {
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ return m.SpanRange(idMapRange{start, end}) == end-start
+}
+
+// An IDMapEntry represents a mapping from a range of contiguous IDs in a user
+// namespace to an equally-sized range of contiguous IDs in the namespace's
+// parent.
+//
+// +stateify savable
+type IDMapEntry struct {
+ // FirstID is the first ID in the range in the namespace.
+ FirstID uint32
+
+ // FirstParentID is the first ID in the range in the parent namespace.
+ FirstParentID uint32
+
+ // Length is the number of IDs in the range.
+ Length uint32
+}
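+
+// For example, the uid_map line "0 1000 1" written by an unprivileged process
+// corresponds to IDMapEntry{FirstID: 0, FirstParentID: 1000, Length: 1}:
+// UID 0 inside the namespace maps to UID 1000 in its parent.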
+
+// SetUIDMap instructs ns to translate UIDs as specified by entries.
+//
+// Note: SetUIDMap does not place an upper bound on the number of entries, but
+// Linux does. This restriction is implemented in SetUIDMap's caller, the
+// implementation of /proc/[pid]/uid_map.
+func (ns *UserNamespace) SetUIDMap(ctx context.Context, entries []IDMapEntry) error {
+ c := CredentialsFromContext(ctx)
+
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ // "After the creation of a new user namespace, the uid_map file of *one*
+ // of the processes in the namespace may be written to *once* to define the
+ // mapping of user IDs in the new user namespace. An attempt to write more
+ // than once to a uid_map file in a user namespace fails with the error
+ // EPERM. Similar rules apply for gid_map files." - user_namespaces(7)
+ if !ns.uidMapFromParent.IsEmpty() {
+ return syserror.EPERM
+ }
+ // "At least one line must be written to the file."
+ if len(entries) == 0 {
+ return syserror.EINVAL
+ }
+ // """
+ // In order for a process to write to the /proc/[pid]/uid_map
+ // (/proc/[pid]/gid_map) file, all of the following requirements must be
+ // met:
+ //
+ // 1. The writing process must have the CAP_SETUID (CAP_SETGID) capability
+ // in the user namespace of the process pid.
+ // """
+ if !c.HasCapabilityIn(linux.CAP_SETUID, ns) {
+ return syserror.EPERM
+ }
+ // "2. The writing process must either be in the user namespace of the process
+ // pid or be in the parent user namespace of the process pid."
+ if c.UserNamespace != ns && c.UserNamespace != ns.parent {
+ return syserror.EPERM
+ }
+ // """
+ // 3. (see trySetUIDMap)
+ //
+ // 4. One of the following two cases applies:
+ //
+ // * Either the writing process has the CAP_SETUID (CAP_SETGID) capability
+ // in the parent user namespace.
+ // """
+ if !c.HasCapabilityIn(linux.CAP_SETUID, ns.parent) {
+ // """
+ // * Or otherwise all of the following restrictions apply:
+ //
+ // + The data written to uid_map (gid_map) must consist of a single line
+ // that maps the writing process' effective user ID (group ID) in the
+ // parent user namespace to a user ID (group ID) in the user namespace.
+ // """
+ if len(entries) != 1 || ns.parent.MapToKUID(UID(entries[0].FirstParentID)) != c.EffectiveKUID || entries[0].Length != 1 {
+ return syserror.EPERM
+ }
+ // """
+ // + The writing process must have the same effective user ID as the
+ // process that created the user namespace.
+ // """
+ if c.EffectiveKUID != ns.owner {
+ return syserror.EPERM
+ }
+ }
+ // trySetUIDMap may leave partial data in the maps if it fails.
+ if err := ns.trySetUIDMap(entries); err != nil {
+ ns.uidMapFromParent.RemoveAll()
+ ns.uidMapToParent.RemoveAll()
+ return err
+ }
+ return nil
+}
+
+func (ns *UserNamespace) trySetUIDMap(entries []IDMapEntry) error {
+ for _, e := range entries {
+ // Determine upper bounds and check for overflow. This implicitly
+ // checks for NoID.
+ lastID := e.FirstID + e.Length
+ if lastID <= e.FirstID {
+ return syserror.EINVAL
+ }
+ lastParentID := e.FirstParentID + e.Length
+ if lastParentID <= e.FirstParentID {
+ return syserror.EINVAL
+ }
+ // "3. The mapped user IDs (group IDs) must in turn have a mapping in
+ // the parent user namespace."
+ // Only the root namespace has a nil parent, and root is assigned
+ // mappings when it's created, so SetUIDMap would have returned EPERM
+ // without reaching this point if ns is root.
+ if !ns.parent.allIDsMapped(&ns.parent.uidMapToParent, e.FirstParentID, lastParentID) {
+ return syserror.EPERM
+ }
+ // If either of these Adds fails, we have an overlapping range.
+ if !ns.uidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) {
+ return syserror.EINVAL
+ }
+ if !ns.uidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) {
+ return syserror.EINVAL
+ }
+ }
+ return nil
+}
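+
+// The wrap-around checks above are what implicitly reject NoID: an entry such
+// as {FirstID: 0xFFFFFFFF, FirstParentID: 0, Length: 1} (purely illustrative)
+// computes lastID = 0xFFFFFFFF + 1, which wraps to 0 <= FirstID, so
+// trySetUIDMap returns EINVAL before NoID can be inserted into the maps.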
+
+// SetGIDMap instructs ns to translate GIDs as specified by entries.
+func (ns *UserNamespace) SetGIDMap(ctx context.Context, entries []IDMapEntry) error {
+ c := CredentialsFromContext(ctx)
+
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ if !ns.gidMapFromParent.IsEmpty() {
+ return syserror.EPERM
+ }
+ if len(entries) == 0 {
+ return syserror.EINVAL
+ }
+ if !c.HasCapabilityIn(linux.CAP_SETGID, ns) {
+ return syserror.EPERM
+ }
+ if c.UserNamespace != ns && c.UserNamespace != ns.parent {
+ return syserror.EPERM
+ }
+ if !c.HasCapabilityIn(linux.CAP_SETGID, ns.parent) {
+ if len(entries) != 1 || ns.parent.MapToKGID(GID(entries[0].FirstParentID)) != c.EffectiveKGID || entries[0].Length != 1 {
+ return syserror.EPERM
+ }
+ // It's correct for this check to still use the effective UID: ns.owner
+ // is the creator's effective KUID, even when writing gid_map.
+ if c.EffectiveKUID != ns.owner {
+ return syserror.EPERM
+ }
+ // "In the case of gid_map, use of the setgroups(2) system call must
+ // first be denied by writing "deny" to the /proc/[pid]/setgroups file
+ // (see below) before writing to gid_map." (This file isn't implemented
+ // in the version of Linux we're emulating; see comment in
+ // UserNamespace.)
+ }
+ if err := ns.trySetGIDMap(entries); err != nil {
+ ns.gidMapFromParent.RemoveAll()
+ ns.gidMapToParent.RemoveAll()
+ return err
+ }
+ return nil
+}
+
+func (ns *UserNamespace) trySetGIDMap(entries []IDMapEntry) error {
+ for _, e := range entries {
+ lastID := e.FirstID + e.Length
+ if lastID <= e.FirstID {
+ return syserror.EINVAL
+ }
+ lastParentID := e.FirstParentID + e.Length
+ if lastParentID <= e.FirstParentID {
+ return syserror.EINVAL
+ }
+ if !ns.parent.allIDsMapped(&ns.parent.gidMapToParent, e.FirstParentID, lastParentID) {
+ return syserror.EPERM
+ }
+ if !ns.gidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) {
+ return syserror.EINVAL
+ }
+ if !ns.gidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) {
+ return syserror.EINVAL
+ }
+ }
+ return nil
+}
+
+// UIDMap returns the user ID mappings configured for ns. If no mappings
+// have been configured, UIDMap returns nil.
+func (ns *UserNamespace) UIDMap() []IDMapEntry {
+ return ns.getIDMap(&ns.uidMapToParent)
+}
+
+// GIDMap returns the group ID mappings configured for ns. If no mappings
+// have been configured, GIDMap returns nil.
+func (ns *UserNamespace) GIDMap() []IDMapEntry {
+ return ns.getIDMap(&ns.gidMapToParent)
+}
+
+func (ns *UserNamespace) getIDMap(m *idMapSet) []IDMapEntry {
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ var entries []IDMapEntry
+ for it := m.FirstSegment(); it.Ok(); it = it.NextSegment() {
+ entries = append(entries, IDMapEntry{
+ FirstID: it.Start(),
+ FirstParentID: it.Value(),
+ Length: it.Range().Length(),
+ })
+ }
+ return entries
+}
diff --git a/pkg/sentry/kernel/auth/id_map_functions.go b/pkg/sentry/kernel/auth/id_map_functions.go
new file mode 100644
index 000000000..432dbfb6d
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id_map_functions.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+// idMapFunctions "implements" generic interface segment.Functions for
+// idMapSet. An idMapSet maps non-overlapping ranges of contiguous IDs in one
+// user namespace to non-overlapping ranges of contiguous IDs in another user
+// namespace. Each such ID mapping is implemented as a range-to-value mapping
+// in the set such that [range.Start(), range.End()) => [value, value +
+// range.Length()).
+type idMapFunctions struct{}
+
+func (idMapFunctions) MinKey() uint32 {
+ return 0
+}
+
+func (idMapFunctions) MaxKey() uint32 {
+ return NoID
+}
+
+func (idMapFunctions) ClearValue(*uint32) {}
+
+func (idMapFunctions) Merge(r1 idMapRange, val1 uint32, r2 idMapRange, val2 uint32) (uint32, bool) {
+ // Mapped ranges have to be contiguous.
+ if val1+r1.Length() != val2 {
+ return 0, false
+ }
+ return val1, true
+}
+
+func (idMapFunctions) Split(r idMapRange, val uint32, split uint32) (uint32, uint32) {
+ return val, val + (split - r.Start)
+}
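+
+// As an example of Merge and Split (illustrative): [0, 10) => 1000 and
+// [10, 20) => 1010 merge into [0, 20) => 1000 because 1000 + 10 == 1010,
+// whereas [0, 10) => 1000 and [10, 20) => 2000 do not; splitting
+// [0, 20) => 1000 at 10 yields the values 1000 and 1010.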
diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go
new file mode 100644
index 000000000..9dd52c860
--- /dev/null
+++ b/pkg/sentry/kernel/auth/user_namespace.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// A UserNamespace represents a user namespace. See user_namespaces(7) for
+// details.
+//
+// +stateify savable
+type UserNamespace struct {
+ // parent is this namespace's parent. If this is the root namespace, parent
+ // is nil. The parent pointer is immutable.
+ parent *UserNamespace
+
+ // owner is the effective UID of the namespace's creator in the root
+ // namespace. owner is immutable.
+ owner KUID
+
+ // mu protects the following fields.
+ //
+ // If mu will be locked in multiple UserNamespaces, it must be locked in
+ // descendant namespaces before ancestors.
+ mu sync.Mutex `state:"nosave"`
+
+ // Mappings of user/group IDs between this namespace and its parent.
+ //
+ // All ID maps, once set, cannot be changed. This means that successful
+ // UID/GID translations cannot be racy.
+ uidMapFromParent idMapSet
+ uidMapToParent idMapSet
+ gidMapFromParent idMapSet
+ gidMapToParent idMapSet
+
+ // TODO(b/27454212): Support disabling setgroups(2).
+}
+
+// NewRootUserNamespace returns a UserNamespace that is appropriate for a
+// system's root user namespace.
+func NewRootUserNamespace() *UserNamespace {
+ var ns UserNamespace
+ // """
+ // The initial user namespace has no parent namespace, but, for
+ // consistency, the kernel provides dummy user and group ID mapping files
+ // for this namespace. Looking at the uid_map file (gid_map is the same)
+ // from a shell in the initial namespace shows:
+ //
+ // $ cat /proc/$$/uid_map
+ // 0 0 4294967295
+ // """ - user_namespaces(7)
+ for _, m := range []*idMapSet{
+ &ns.uidMapFromParent,
+ &ns.uidMapToParent,
+ &ns.gidMapFromParent,
+ &ns.gidMapToParent,
+ } {
+ if !m.Add(idMapRange{0, math.MaxUint32}, 0) {
+ panic("Failed to insert into empty ID map")
+ }
+ }
+ return &ns
+}
+
+// Root returns the root of the user namespace tree containing ns.
+func (ns *UserNamespace) Root() *UserNamespace {
+ for ns.parent != nil {
+ ns = ns.parent
+ }
+ return ns
+}
+
+// "The kernel imposes (since version 3.11) a limit of 32 nested levels of user
+// namespaces." - user_namespaces(7)
+const maxUserNamespaceDepth = 32
+
+func (ns *UserNamespace) depth() int {
+ var i int
+ for ns != nil {
+ i++
+ ns = ns.parent
+ }
+ return i
+}
+
+// NewChildUserNamespace returns a new user namespace created by a caller with
+// credentials c.
+func (c *Credentials) NewChildUserNamespace() (*UserNamespace, error) {
+ if c.UserNamespace.depth() >= maxUserNamespaceDepth {
+ // "... Calls to unshare(2) or clone(2) that would cause this limit to
+ // be exceeded fail with the error EUSERS." - user_namespaces(7)
+ return nil, syserror.EUSERS
+ }
+ // "EPERM: CLONE_NEWUSER was specified in flags, but either the effective
+ // user ID or the effective group ID of the caller does not have a mapping
+ // in the parent namespace (see user_namespaces(7))." - clone(2)
+ // "CLONE_NEWUSER requires that the user ID and group ID of the calling
+ // process are mapped to user IDs and group IDs in the user namespace of
+ // the calling process at the time of the call." - unshare(2)
+ if !c.EffectiveKUID.In(c.UserNamespace).Ok() {
+ return nil, syserror.EPERM
+ }
+ if !c.EffectiveKGID.In(c.UserNamespace).Ok() {
+ return nil, syserror.EPERM
+ }
+ return &UserNamespace{
+ parent: c.UserNamespace,
+ owner: c.EffectiveKUID,
+ // "When a user namespace is created, it starts without a mapping of
+ // user IDs (group IDs) to the parent user namespace." -
+ // user_namespaces(7)
+ }, nil
+}
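+
+// An illustrative lifecycle (the real callers live in the clone(2)/unshare(2)
+// and /proc/[pid]/uid_map paths):
+//
+//	child, err := creds.NewChildUserNamespace() // starts with empty ID maps
+//	...
+//	err = child.SetUIDMap(ctx, []IDMapEntry{
+//		{FirstID: 0, FirstParentID: 1000, Length: 1},
+//	})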
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
new file mode 100644
index 000000000..dd5f0f5fa
--- /dev/null
+++ b/pkg/sentry/kernel/context.go
@@ -0,0 +1,114 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the kernel package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxCanTrace is a Context.Value key for a function with the same
+ // signature and semantics as kernel.Task.CanTrace.
+ CtxCanTrace contextID = iota
+
+ // CtxKernel is a Context.Value key for a Kernel.
+ CtxKernel
+
+ // CtxPIDNamespace is a Context.Value key for a PIDNamespace.
+ CtxPIDNamespace
+
+ // CtxTask is a Context.Value key for a Task.
+ CtxTask
+
+ // CtxUTSNamespace is a Context.Value key for a UTSNamespace.
+ CtxUTSNamespace
+
+ // CtxIPCNamespace is a Context.Value key for an IPCNamespace.
+ CtxIPCNamespace
+)
+
+// ContextCanTrace returns true if ctx is permitted to trace t, in the same sense
+// as kernel.Task.CanTrace.
+func ContextCanTrace(ctx context.Context, t *Task, attach bool) bool {
+ if v := ctx.Value(CtxCanTrace); v != nil {
+ return v.(func(*Task, bool) bool)(t, attach)
+ }
+ return false
+}
+
+// KernelFromContext returns the Kernel in which ctx is executing, or nil if
+// there is no such Kernel.
+func KernelFromContext(ctx context.Context) *Kernel {
+ if v := ctx.Value(CtxKernel); v != nil {
+ return v.(*Kernel)
+ }
+ return nil
+}
+
+// PIDNamespaceFromContext returns the PID namespace in which ctx is executing,
+// or nil if there is no such PID namespace.
+func PIDNamespaceFromContext(ctx context.Context) *PIDNamespace {
+ if v := ctx.Value(CtxPIDNamespace); v != nil {
+ return v.(*PIDNamespace)
+ }
+ return nil
+}
+
+// UTSNamespaceFromContext returns the UTS namespace in which ctx is executing,
+// or nil if there is no such UTS namespace.
+func UTSNamespaceFromContext(ctx context.Context) *UTSNamespace {
+ if v := ctx.Value(CtxUTSNamespace); v != nil {
+ return v.(*UTSNamespace)
+ }
+ return nil
+}
+
+// IPCNamespaceFromContext returns the IPC namespace in which ctx is executing,
+// or nil if there is no such IPC namespace.
+func IPCNamespaceFromContext(ctx context.Context) *IPCNamespace {
+ if v := ctx.Value(CtxIPCNamespace); v != nil {
+ return v.(*IPCNamespace)
+ }
+ return nil
+}
+
+// TaskFromContext returns the Task associated with ctx, or nil if there is no
+// such Task.
+func TaskFromContext(ctx context.Context) *Task {
+ if v := ctx.Value(CtxTask); v != nil {
+ return v.(*Task)
+ }
+ return nil
+}
+
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+ return nil
+}
+
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+ return nil
+}
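+
+// These stubs, together with Task's Value method (defined elsewhere), let a
+// *Task be passed wherever a context.Context is expected. A hypothetical
+// caller can then recover the task:
+//
+//	if t := TaskFromContext(ctx); t != nil {
+//		// ctx is a task context.
+//	}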
diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD
new file mode 100644
index 000000000..9d26392c0
--- /dev/null
+++ b/pkg/sentry/kernel/contexttest/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "contexttest",
+ testonly = 1,
+ srcs = ["contexttest.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ ],
+)
diff --git a/pkg/sentry/kernel/contexttest/contexttest.go b/pkg/sentry/kernel/contexttest/contexttest.go
new file mode 100644
index 000000000..22c340e56
--- /dev/null
+++ b/pkg/sentry/kernel/contexttest/contexttest.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package contexttest provides a test context.Context which includes
+// a dummy kernel pointing to a valid platform.
+package contexttest
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+)
+
+// Context returns a Context that may be used in tests. Uses ptrace as the
+// platform.Platform, and provides a stub kernel that only serves to point to
+// the platform.
+func Context(tb testing.TB) context.Context {
+ ctx := contexttest.Context(tb)
+ k := &kernel.Kernel{
+ Platform: platform.FromContext(ctx),
+ }
+ k.SetMemoryFile(pgalloc.MemoryFileFromContext(ctx))
+ ctx.(*contexttest.TestContext).RegisterValue(kernel.CtxKernel, k)
+ return ctx
+}
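+
+// An illustrative use in a test:
+//
+//	func TestSomething(t *testing.T) {
+//		ctx := contexttest.Context(t)
+//		k := kernel.KernelFromContext(ctx) // the stub kernel registered above
+//		_ = k
+//	}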
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
new file mode 100644
index 000000000..75eedd5a2
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -0,0 +1,51 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "epoll_list",
+ out = "epoll_list.go",
+ package = "epoll",
+ prefix = "pollEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*pollEntry",
+ "Linker": "*pollEntry",
+ },
+)
+
+go_library(
+ name = "epoll",
+ srcs = [
+ "epoll.go",
+ "epoll_list.go",
+ "epoll_state.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/refs",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sync",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "epoll_test",
+ size = "small",
+ srcs = [
+ "epoll_test.go",
+ ],
+ library = ":epoll",
+ deps = [
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs/filetest",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
new file mode 100644
index 000000000..4c0f1e41f
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -0,0 +1,462 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package epoll provides an implementation of Linux's IO event notification
+// facility. See epoll(7) for more details.
+package epoll
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EntryFlags is a bitmask that holds an entry's flags.
+type EntryFlags int
+
+// Valid entry flags.
+const (
+ OneShot EntryFlags = 1 << iota
+ EdgeTriggered
+)
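+
+// For example, a registration made with EPOLLONESHOT|EPOLLET corresponds to
+// the flag value OneShot|EdgeTriggered; a plain level-triggered registration
+// has flags of zero.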
+
+// FileIdentifier identifies a file. We cannot use just the FD because it could
+// potentially be reassigned. We also cannot use just the file pointer because
+// it is possible to have multiple entries for the same file object as long as
+// they are created with different FDs (i.e., the FDs point to the same file).
+//
+// +stateify savable
+type FileIdentifier struct {
+ File *fs.File `state:"wait"`
+ Fd int32
+}
+
+// pollEntry holds all the state associated with an event poll entry, that is,
+// a file being observed by an event poll object.
+//
+// +stateify savable
+type pollEntry struct {
+ pollEntryEntry
+ file *refs.WeakRef `state:"manual"`
+ id FileIdentifier `state:"wait"`
+ userData [2]int32
+ waiter waiter.Entry `state:"manual"`
+ mask waiter.EventMask
+ flags EntryFlags
+
+ epoll *EventPoll
+
+ // We cannot save the current list pointer as it points into the
+ // EventPoll struct, and the state framework currently does not support
+ // such in-struct pointers. Instead, EventPoll will properly set this
+ // field in its loading logic.
+ curList *pollEntryList `state:"nosave"`
+}
+
+// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
+// It is called when the file in the weak reference is destroyed, and removes
+// the poll entry in response.
+func (p *pollEntry) WeakRefGone() {
+ p.epoll.RemoveEntry(p.id)
+}
+
+// EventPoll holds all the state associated with an event poll object, that is,
+// collection of files to observe and their current state.
+//
+// +stateify savable
+type EventPoll struct {
+ fsutil.FilePipeSeek `state:"zerovalue"`
+ fsutil.FileNotDirReaddir `state:"zerovalue"`
+ fsutil.FileNoFsync `state:"zerovalue"`
+ fsutil.FileNoopFlush `state:"zerovalue"`
+ fsutil.FileNoIoctl `state:"zerovalue"`
+ fsutil.FileNoMMap `state:"zerovalue"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // Wait queue is used to notify interested parties when the event poll
+ // object itself becomes readable or writable.
+ waiter.Queue `state:"zerovalue"`
+
+ // files is the map of all the files currently being observed; it is
+ // protected by mu.
+ mu sync.Mutex `state:"nosave"`
+ files map[FileIdentifier]*pollEntry
+
+ // listsMu protects manipulation of the lists below. It needs to be a
+ // different lock to avoid circular lock acquisition order involving
+ // the wait queue mutexes and mu. The full order is mu, observed file
+ // wait queue mutex, then listsMu; this allows listsMu to be acquired
+ // when (*pollEntry).Callback is called.
+ //
+ // An entry is always in one of the following lists:
+ // readyList -- when there's a chance that it's ready to have
+ // events delivered to epoll waiters. Given that being
+ // ready is a transient state, the Readiness() and
+ // readEvents() functions always call the entry's file
+ // Readiness() function to confirm it's ready.
+ // waitingList -- when there's no chance that the entry is ready,
+ // so it's waiting for the (*pollEntry).Callback to be called
+ // on it before it gets moved to the readyList.
+ // disabledList -- when the entry is disabled. This happens when
+ // a one-shot entry gets delivered via readEvents().
+ listsMu sync.Mutex `state:"nosave"`
+ readyList pollEntryList
+ waitingList pollEntryList
+ disabledList pollEntryList
+}
+
+// cycleMu is used to serialize all the cycle checks. This is only used when
+// an event poll file is added as an entry to another event poll. Such checks
+// are serialized to avoid lock acquisition order inversion: if a thread is
+// adding A to B, and another thread is adding B to A, each would acquire A's
+// and B's mutexes in reverse order, and could cause deadlocks. Having this
+// lock prevents this by allowing only one check at a time to happen.
+//
+// We do the cycle check to prevent callers from introducing potentially
+// infinite recursions. If a caller were to add A to B and then B to A, for
+// event poll A to know if it's readable, it would need to check event poll B,
+// which in turn would need event poll A and so on indefinitely.
+var cycleMu sync.Mutex
+
+// NewEventPoll allocates and initializes a new event poll object.
+func NewEventPoll(ctx context.Context) *fs.File {
+ // name matches fs/eventpoll.c:epoll_create1.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[eventpoll]")
+ // Release the initial dirent reference after NewFile takes a reference.
+ defer dirent.DecRef()
+ return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{
+ files: make(map[FileIdentifier]*pollEntry),
+ })
+}
+
+// Release implements fs.FileOperations.Release.
+func (e *EventPoll) Release() {
+ // We need to take the lock now because files may be attempting to
+ // remove entries in parallel if they get destroyed.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Go through all entries and clean up.
+ for _, entry := range e.files {
+ entry.id.File.EventUnregister(&entry.waiter)
+ entry.file.Drop()
+ }
+ e.files = nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (*EventPoll) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syscall.ENOSYS
+}
+
+// Write implements fs.FileOperations.Write.
+func (*EventPoll) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syscall.ENOSYS
+}
+
+// eventsAvailable determines if 'e' has events available for delivery.
+func (e *EventPoll) eventsAvailable() bool {
+ e.listsMu.Lock()
+
+ for it := e.readyList.Front(); it != nil; {
+ entry := it
+ it = it.Next()
+
+ // If the entry is ready, we know 'e' has at least one entry
+ // ready for delivery.
+ ready := entry.id.File.Readiness(entry.mask)
+ if ready != 0 {
+ e.listsMu.Unlock()
+ return true
+ }
+
+ // Entry is not ready, so move it to waiting list.
+ e.readyList.Remove(entry)
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+ }
+
+ e.listsMu.Unlock()
+
+ return false
+}
+
+// Readiness determines if the event poll object is currently readable (i.e.,
+// if there are pending events for delivery).
+func (e *EventPoll) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ if (mask&waiter.EventIn) != 0 && e.eventsAvailable() {
+ ready |= waiter.EventIn
+ }
+
+ return ready
+}
+
+// ReadEvents returns up to max available events.
+func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
+ var local pollEntryList
+ var ret []linux.EpollEvent
+
+ e.listsMu.Lock()
+
+ // Go through all entries we believe may be ready.
+ for it := e.readyList.Front(); it != nil && len(ret) < max; {
+ entry := it
+ it = it.Next()
+
+ // Check the entry's readiness. If it's not really ready, we
+ // just put it back in the waiting list and move on to the next
+ // entry.
+ ready := entry.id.File.Readiness(entry.mask) & entry.mask
+ if ready == 0 {
+ e.readyList.Remove(entry)
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+
+ continue
+ }
+
+ // Add event to the array that will be returned to caller.
+ ret = append(ret, linux.EpollEvent{
+ Events: uint32(ready),
+ Data: entry.userData,
+ })
+
+ // The entry is consumed, so we must move it to the disabled
+ // list if it's one-shot, or back to the waiting list if it's
+ // edge-triggered. If it's neither, we leave it in the ready
+ // list so that its readiness can be checked the next time
+ // around; however, we must move it to the end of the list so
+ // that other events can be delivered as well.
+ e.readyList.Remove(entry)
+ if entry.flags&OneShot != 0 {
+ e.disabledList.PushBack(entry)
+ entry.curList = &e.disabledList
+ } else if entry.flags&EdgeTriggered != 0 {
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+ } else {
+ local.PushBack(entry)
+ }
+ }
+
+ e.readyList.PushBackList(&local)
+
+ e.listsMu.Unlock()
+
+ return ret
+}
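+
+// To make the list rotation above concrete (an illustrative note): a
+// level-triggered entry that remains ready is appended to the local list and
+// pushed to the back of readyList, so repeated ReadEvents(1) calls round-robin
+// among ready files rather than returning the same entry forever.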
+
+// Callback implements waiter.EntryCallback.Callback.
+//
+// Callback is called when one of the files we're polling becomes ready. It
+// moves said file to the readyList if it's currently in the waiting list.
+func (p *pollEntry) Callback(*waiter.Entry) {
+ e := p.epoll
+
+ e.listsMu.Lock()
+
+ if p.curList == &e.waitingList {
+ e.waitingList.Remove(p)
+ e.readyList.PushBack(p)
+ p.curList = &e.readyList
+ e.listsMu.Unlock()
+
+ e.Notify(waiter.EventIn)
+ return
+ }
+
+ e.listsMu.Unlock()
+}
+
+// initEntryReadiness initializes the entry's state with regard to its
+// readiness by placing it in the appropriate list and registering for
+// notifications.
+func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
+ // A new entry starts off in the waiting list.
+ e.listsMu.Lock()
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+ e.listsMu.Unlock()
+
+ // Register for event notifications.
+ f := entry.id.File
+ f.EventRegister(&entry.waiter, entry.mask)
+
+ // Check if the file happens to already be in a ready state.
+ ready := f.Readiness(entry.mask) & entry.mask
+ if ready != 0 {
+ entry.Callback(&entry.waiter)
+ }
+}
+
+// observes checks if event poll object e is directly or indirectly observing
+// event poll object ep. It uses a bounded recursive depth-first search.
+func (e *EventPoll) observes(ep *EventPoll, depthLeft int) bool {
+ // If we reached the maximum depth, we'll consider that we found it
+ // because we don't want to allow chains that are too long.
+ if depthLeft <= 0 {
+ return true
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Go through each observed file and check if it is or observes ep.
+ for id := range e.files {
+ f, ok := id.File.FileOperations.(*EventPoll)
+ if !ok {
+ continue
+ }
+
+ if f == ep || f.observes(ep, depthLeft-1) {
+ return true
+ }
+ }
+
+ return false
+}
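+
+// For instance (illustrative): if event poll A observes B and B observes C,
+// then A.observes(C, 4) is true via the recursive step; any chain deeper than
+// depthLeft is conservatively treated as a cycle.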
+
+// AddEntry adds a new file to the collection of files observed by e.
+func (e *EventPoll) AddEntry(id FileIdentifier, flags EntryFlags, mask waiter.EventMask, data [2]int32) error {
+ // Acquire cycle check lock if another event poll is being added.
+ ep, ok := id.File.FileOperations.(*EventPoll)
+ if ok {
+ cycleMu.Lock()
+ defer cycleMu.Unlock()
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Fail if the file already has an entry.
+ if _, ok := e.files[id]; ok {
+ return syscall.EEXIST
+ }
+
+ // Check if a cycle would be created. We use 4 as the limit because
+ // that's the value used by Linux and we want to emulate it.
+ if ep != nil {
+ if e == ep {
+ return syscall.EINVAL
+ }
+
+ if ep.observes(e, 4) {
+ return syscall.ELOOP
+ }
+ }
+
+ // Create new entry and add it to map.
+ //
+ // N.B. Even though we are creating a weak reference here, we know it
+ // won't trigger a callback because we hold a reference to the file
+ // throughout the execution of this function.
+ entry := &pollEntry{
+ id: id,
+ userData: data,
+ epoll: e,
+ flags: flags,
+ mask: mask,
+ }
+ entry.waiter.Callback = entry
+ e.files[id] = entry
+ entry.file = refs.NewWeakRef(id.File, entry)
+
+ // Initialize the readiness state of the new entry.
+ e.initEntryReadiness(entry)
+
+ return nil
+}
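+
+// An illustrative call, roughly what an EPOLL_CTL_ADD of fd 7 with
+// EPOLLIN|EPOLLONESHOT translates to (userData being the caller's epoll_data):
+//
+//	err := e.AddEntry(FileIdentifier{file, 7}, OneShot, waiter.EventIn, userData)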
+
+// UpdateEntry updates the flags, mask and user data associated with a file that
+// is already part of the collection of observed files.
+func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter.EventMask, data [2]int32) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Fail if the file doesn't have an entry.
+ entry, ok := e.files[id]
+ if !ok {
+ return syscall.ENOENT
+ }
+
+ // Unregister the old mask and remove entry from the list it's in, so
+ // (*pollEntry).Callback is guaranteed not to be called on this entry anymore.
+ entry.id.File.EventUnregister(&entry.waiter)
+
+ // Remove entry from whatever list it's in. This ensures that no other
+ // threads have access to this entry, as the only way left to find it
+ // is via e.files, but we hold e.mu, which prevents that.
+ e.listsMu.Lock()
+ entry.curList.Remove(entry)
+ e.listsMu.Unlock()
+
+ // Initialize new readiness state.
+ entry.flags = flags
+ entry.mask = mask
+ entry.userData = data
+ e.initEntryReadiness(entry)
+
+ return nil
+}
+
+// RemoveEntry removes a file from the collection of observed files.
+func (e *EventPoll) RemoveEntry(id FileIdentifier) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Fail if the file doesn't have an entry.
+ entry, ok := e.files[id]
+ if !ok {
+ return syscall.ENOENT
+ }
+
+ // Unregister from file first so that no concurrent attempts will be
+ // made to manipulate the file.
+ entry.id.File.EventUnregister(&entry.waiter)
+
+ // Remove from the current list.
+ e.listsMu.Lock()
+ entry.curList.Remove(entry)
+ entry.curList = nil
+ e.listsMu.Unlock()
+
+ // Remove file from map, and drop weak reference.
+ delete(e.files, id)
+ entry.file.Drop()
+
+ return nil
+}
+
+// UnregisterEpollWaiters removes the epoll waiter objects from the waiting
+// queues. This is different from Release() in that the files are not dereferenced.
+func (e *EventPoll) UnregisterEpollWaiters() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for _, entry := range e.files {
+ entry.id.File.EventUnregister(&entry.waiter)
+ }
+}
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
new file mode 100644
index 000000000..7c61e0258
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -0,0 +1,51 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epoll
+
+import (
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// afterLoad is invoked by stateify.
+func (p *pollEntry) afterLoad() {
+ p.waiter.Callback = p
+ p.file = refs.NewWeakRef(p.id.File, p)
+ p.id.File.EventRegister(&p.waiter, p.mask)
+}
+
+// afterLoad is invoked by stateify.
+func (e *EventPoll) afterLoad() {
+ e.listsMu.Lock()
+ defer e.listsMu.Unlock()
+
+ for _, ls := range []*pollEntryList{&e.waitingList, &e.readyList, &e.disabledList} {
+ for it := ls.Front(); it != nil; it = it.Next() {
+ it.curList = ls
+ }
+ }
+
+ for it := e.waitingList.Front(); it != nil; {
+ entry := it
+ it = it.Next()
+
+ if entry.id.File.Readiness(entry.mask) != 0 {
+ e.waitingList.Remove(entry)
+ e.readyList.PushBack(entry)
+ entry.curList = &e.readyList
+ e.Notify(waiter.EventIn)
+ }
+ }
+}
diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go
new file mode 100644
index 000000000..22630e9c5
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll_test.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epoll
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs/filetest"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestFileDestroyed(t *testing.T) {
+ f := filetest.NewTestFile(t)
+ id := FileIdentifier{f, 12}
+
+ efile := NewEventPoll(contexttest.Context(t))
+ e := efile.FileOperations.(*EventPoll)
+ if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil {
+ t.Fatalf("addEntry failed: %v", err)
+ }
+
+ // Check that we get an event reported twice in a row.
+ evt := e.ReadEvents(1)
+ if len(evt) != 1 {
+ t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt))
+ }
+
+ evt = e.ReadEvents(1)
+ if len(evt) != 1 {
+ t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt))
+ }
+
+ // Destroy the file. Check that we get no more events.
+ f.DecRef()
+
+ evt = e.ReadEvents(1)
+ if len(evt) != 0 {
+ t.Fatalf("Unexpected number of ready events: want %v, got %v", 0, len(evt))
+ }
+}
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
new file mode 100644
index 000000000..9983a32e5
--- /dev/null
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "eventfd",
+ srcs = ["eventfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "eventfd_test",
+ size = "small",
+ srcs = ["eventfd_test.go"],
+ library = ":eventfd",
+ deps = [
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
new file mode 100644
index 000000000..87951adeb
--- /dev/null
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -0,0 +1,285 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd provides an implementation of Linux's file-based event
+// notification.
+package eventfd
+
+import (
+ "math"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EventOperations represents an event with the semantics of Linux's file-based event
+// notification (eventfd). Eventfds are usually internal to the Sentry but in certain
+// situations they may be converted into a host-backed eventfd.
+//
+// +stateify savable
+type EventOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // Mutex that protects accesses to the fields of this event.
+ mu sync.Mutex `state:"nosave"`
+
+ // Queue is used to notify interested parties when the event object
+ // becomes readable or writable.
+ wq waiter.Queue `state:"zerovalue"`
+
+ // val is the current value of the event counter.
+ val uint64
+
+ // semMode specifies whether the event is in "semaphore" mode.
+ semMode bool
+
+ // hostfd is the host file descriptor backing this eventfd, or -1 if
+ // this eventfd is not passed through to the host.
+ hostfd int
+}
+
+// New creates a new event object with the supplied initial value and mode.
+func New(ctx context.Context, initVal uint64, semMode bool) *fs.File {
+ // name matches fs/eventfd.c:eventfd_file_create.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[eventfd]")
+ // Release the initial dirent reference after NewFile takes a reference.
+ defer dirent.DecRef()
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{
+ val: initVal,
+ semMode: semMode,
+ hostfd: -1,
+ })
+}
+
+// HostFD returns the host eventfd associated with this event, creating it
+// first if it does not already exist.
+func (e *EventOperations) HostFD() (int, error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ return e.hostfd, nil
+ }
+
+ flags := linux.EFD_NONBLOCK
+ if e.semMode {
+ flags |= linux.EFD_SEMAPHORE
+ }
+
+ fd, _, err := syscall.Syscall(syscall.SYS_EVENTFD2, uintptr(e.val), uintptr(flags), 0)
+ if err != 0 {
+ return -1, err
+ }
+
+ if err := fdnotifier.AddFD(int32(fd), &e.wq); err != nil {
+ syscall.Close(int(fd))
+ return -1, err
+ }
+
+ e.hostfd = int(fd)
+ return e.hostfd, nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (e *EventOperations) Release() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ fdnotifier.RemoveFD(int32(e.hostfd))
+ syscall.Close(e.hostfd)
+ e.hostfd = -1
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (e *EventOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := e.read(ctx, dst); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Write implements fs.FileOperations.Write.
+func (e *EventOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ if src.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := e.write(ctx, src); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Must be called with e.mu locked.
+func (e *EventOperations) hostRead(ctx context.Context, dst usermem.IOSequence) error {
+ var buf [8]byte
+
+ if _, err := syscall.Read(e.hostfd, buf[:]); err != nil {
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+ }
+
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) error {
+ e.mu.Lock()
+
+ if e.hostfd >= 0 {
+ defer e.mu.Unlock()
+ return e.hostRead(ctx, dst)
+ }
+
+ // We can't complete the read if the value is currently zero.
+ if e.val == 0 {
+ e.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ // Update the value based on the mode the event is operating in.
+ var val uint64
+ if e.semMode {
+ val = 1
+ // Consistent with Linux, this is done even if writing to memory fails.
+ e.val--
+ } else {
+ val = e.val
+ e.val = 0
+ }
+
+ e.mu.Unlock()
+
+ // Notify writers. We do this even if we were already writable because
+ // it is possible that a writer is waiting to write the maximum value
+ // to the event.
+ e.wq.Notify(waiter.EventOut)
+
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
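+
+// For example (illustrative): with the counter at 3, a read in normal mode
+// returns 3 and resets the counter to 0, while in semaphore mode the same
+// read returns 1 and leaves the counter at 2.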
+
+// Must be called with e.mu locked.
+func (e *EventOperations) hostWrite(val uint64) error {
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := syscall.Write(e.hostfd, buf[:])
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+func (e *EventOperations) write(ctx context.Context, src usermem.IOSequence) error {
+ var buf [8]byte
+ if _, err := src.CopyIn(ctx, buf[:]); err != nil {
+ return err
+ }
+ val := usermem.ByteOrder.Uint64(buf[:])
+
+ return e.Signal(val)
+}
+
+// Signal is an internal function to signal the event fd.
+func (e *EventOperations) Signal(val uint64) error {
+ if val == math.MaxUint64 {
+ return syscall.EINVAL
+ }
+
+ e.mu.Lock()
+
+ if e.hostfd >= 0 {
+ defer e.mu.Unlock()
+ return e.hostWrite(val)
+ }
+
+ // We only allow writes that won't cause the value to go over the max
+ // uint64 minus 1.
+ if val > math.MaxUint64-1-e.val {
+ e.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ e.val += val
+ e.mu.Unlock()
+
+ // Always trigger a notification.
+ e.wq.Notify(waiter.EventIn)
+
+ return nil
+}
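+
+// For example (illustrative): with the counter at its maximum of
+// 0xFFFFFFFFFFFFFFFE, any further Signal with a non-zero value returns
+// ErrWouldBlock until a read drains the counter, and
+// Signal(0xFFFFFFFFFFFFFFFF) is always EINVAL.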
+
+// Readiness returns the ready events for the event fd.
+func (e *EventOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.mu.Lock()
+ if e.hostfd >= 0 {
+ defer e.mu.Unlock()
+ return fdnotifier.NonBlockingPoll(int32(e.hostfd), mask)
+ }
+
+ ready := waiter.EventMask(0)
+ if e.val > 0 {
+ ready |= waiter.EventIn
+ }
+
+ if e.val < math.MaxUint64-1 {
+ ready |= waiter.EventOut
+ }
+ e.mu.Unlock()
+
+ return mask & ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *EventOperations) EventRegister(entry *waiter.Entry, mask waiter.EventMask) {
+ e.wq.EventRegister(entry, mask)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(e.hostfd))
+ }
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *EventOperations) EventUnregister(entry *waiter.Entry) {
+ e.wq.EventUnregister(entry)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(e.hostfd))
+ }
+}
diff --git a/pkg/sentry/kernel/eventfd/eventfd_test.go b/pkg/sentry/kernel/eventfd/eventfd_test.go
new file mode 100644
index 000000000..9b4892f74
--- /dev/null
+++ b/pkg/sentry/kernel/eventfd/eventfd_test.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventfd
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestEventfd(t *testing.T) {
+ initVals := []uint64{
+ 0,
+ // Using a non-zero initial value verifies that writing to an
+ // eventfd signals when the eventfd's counter was already
+ // non-zero.
+ 343,
+ }
+
+ for _, initVal := range initVals {
+ ctx := contexttest.Context(t)
+
+ // Make a new event that is writable.
+ event := New(ctx, initVal, false)
+
+ // Register a callback for a write event.
+ w, ch := waiter.NewChannelEntry(nil)
+ event.EventRegister(&w, waiter.EventIn)
+ defer event.EventUnregister(&w)
+
+ data := []byte("00000124")
+ // Create and submit a write request.
+ n, err := event.Writev(ctx, usermem.BytesIOSequence(data))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 8 {
+ t.Errorf("eventfd.write wrote %d bytes, not full int64", n)
+ }
+
+ // Check if the callback fired due to the write event.
+ select {
+ case <-ch:
+ default:
+ t.Errorf("Didn't get notified of EventIn after write")
+ }
+ }
+}
+
+func TestEventfdStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ // Make a new event that is writable.
+ event := New(ctx, 0, false)
+
+ // Create and submit a stat request.
+ uattr, err := event.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ t.Fatalf("eventfd stat request failed: %v", err)
+ }
+ if uattr.Size != 0 {
+ t.Fatal("EventFD size should be 0")
+ }
+}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
new file mode 100644
index 000000000..2b3955598
--- /dev/null
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fasync",
+ srcs = ["fasync.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
new file mode 100644
index 000000000..153d2cd9b
--- /dev/null
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -0,0 +1,188 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fasync provides FIOASYNC related functionality.
+package fasync
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// New creates a new fs.FileAsync.
+func New() fs.FileAsync {
+ return &FileAsync{}
+}
+
+// NewVFS2 creates a new vfs.FileAsync.
+func NewVFS2() vfs.FileAsync {
+ return &FileAsync{}
+}
+
+// FileAsync sends signals when the registered file is ready for IO.
+//
+// +stateify savable
+type FileAsync struct {
+ // e is immutable after first use (which is protected by mu below).
+ e waiter.Entry
+
+ // regMu protects registration and unregistration actions on e.
+ //
+ // regMu must be held both while making a registration decision and
+ // while carrying out the registration action itself.
+ //
+ // Lock ordering: regMu, mu.
+ regMu sync.Mutex `state:"nosave"`
+
+ // mu protects all following fields.
+ //
+ // Lock ordering: e.mu, mu.
+ mu sync.Mutex `state:"nosave"`
+ requester *auth.Credentials
+ registered bool
+
+ // Only one of the following is allowed to be non-nil.
+ recipientPG *kernel.ProcessGroup
+ recipientTG *kernel.ThreadGroup
+ recipientT *kernel.Task
+}
+
+// Callback sends a signal.
+func (a *FileAsync) Callback(e *waiter.Entry) {
+ a.mu.Lock()
+ if !a.registered {
+ a.mu.Unlock()
+ return
+ }
+ t := a.recipientT
+ tg := a.recipientTG
+ if a.recipientPG != nil {
+ tg = a.recipientPG.Originator()
+ }
+ if tg != nil {
+ t = tg.Leader()
+ }
+ if t == nil {
+ // No recipient has been registered.
+ a.mu.Unlock()
+ return
+ }
+ c := t.Credentials()
+ // Logic from sigio_perm in fs/fcntl.c.
+ if a.requester.EffectiveKUID == 0 ||
+ a.requester.EffectiveKUID == c.SavedKUID ||
+ a.requester.EffectiveKUID == c.RealKUID ||
+ a.requester.RealKUID == c.SavedKUID ||
+ a.requester.RealKUID == c.RealKUID {
+ t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO))
+ }
+ a.mu.Unlock()
+}
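+
+// In other words (an illustrative summary): the signal is delivered only if
+// the fcntl(F_SETOWN) requester was root, or one of its real/effective UIDs
+// matches one of the recipient's real/saved UIDs, mirroring Linux's
+// sigio_perm.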
+
+// Register sets the file which will be monitored for IO events.
+//
+// The file must not be currently registered.
+func (a *FileAsync) Register(w waiter.Waitable) {
+ a.regMu.Lock()
+ defer a.regMu.Unlock()
+ a.mu.Lock()
+
+ if a.registered {
+ a.mu.Unlock()
+ panic("registering already registered file")
+ }
+
+ if a.e.Callback == nil {
+ a.e.Callback = a
+ }
+ a.registered = true
+
+ a.mu.Unlock()
+ w.EventRegister(&a.e, waiter.EventIn|waiter.EventOut|waiter.EventErr|waiter.EventHUp)
+}
+
+// Unregister stops monitoring a file.
+//
+// The file must be currently registered.
+func (a *FileAsync) Unregister(w waiter.Waitable) {
+ a.regMu.Lock()
+ defer a.regMu.Unlock()
+ a.mu.Lock()
+
+ if !a.registered {
+ a.mu.Unlock()
+ panic("unregistering unregistered file")
+ }
+
+ a.registered = false
+
+ a.mu.Unlock()
+ w.EventUnregister(&a.e)
+}
+
+// Owner returns who is currently getting signals. All return values will be
+// nil if no one is set to receive signals.
+func (a *FileAsync) Owner() (*kernel.Task, *kernel.ThreadGroup, *kernel.ProcessGroup) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.recipientT, a.recipientTG, a.recipientPG
+}
+
+// SetOwnerTask sets the owner (who will receive signals) to a specified task.
+// Only this owner will receive signals.
+func (a *FileAsync) SetOwnerTask(requester *kernel.Task, recipient *kernel.Task) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = requester.Credentials()
+ a.recipientT = recipient
+ a.recipientTG = nil
+ a.recipientPG = nil
+}
+
+// SetOwnerThreadGroup sets the owner (who will receive signals) to a specified
+// thread group. Only this owner will receive signals.
+func (a *FileAsync) SetOwnerThreadGroup(requester *kernel.Task, recipient *kernel.ThreadGroup) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = requester.Credentials()
+ a.recipientT = nil
+ a.recipientTG = recipient
+ a.recipientPG = nil
+}
+
+// SetOwnerProcessGroup sets the owner (who will receive signals) to a
+// specified process group. Only this owner will receive signals.
+func (a *FileAsync) SetOwnerProcessGroup(requester *kernel.Task, recipient *kernel.ProcessGroup) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = requester.Credentials()
+ a.recipientT = nil
+ a.recipientTG = nil
+ a.recipientPG = recipient
+}
+
+// ClearOwner unsets the current signal recipient.
+func (a *FileAsync) ClearOwner() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = nil
+ a.recipientT = nil
+ a.recipientTG = nil
+ a.recipientPG = nil
+}
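+
+// Illustrative sketch (not part of this file): the expected FIOASYNC call
+// sequence sets an owner (e.g. from the fcntl(F_SETOWN) path) and registers
+// the async callback against the file's waiter queue. Here t is assumed to
+// be a *kernel.Task and w a waiter.Waitable such as an fs.File:
+//
+//    a := New().(*FileAsync)
+//    a.SetOwnerTask(t, t) // Deliver SIGIO to the requesting task itself.
+//    a.Register(w)        // Start monitoring IO readiness events.
+//    // ... and when FIOASYNC is cleared:
+//    a.Unregister(w)
+//    a.ClearOwner()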
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
new file mode 100644
index 000000000..4b7d234a4
--- /dev/null
+++ b/pkg/sentry/kernel/fd_table.go
@@ -0,0 +1,638 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "math"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FDFlags define flags for an individual descriptor.
+//
+// +stateify savable
+type FDFlags struct {
+ // CloseOnExec indicates the descriptor should be closed on exec.
+ CloseOnExec bool
+}
+
+// ToLinuxFileFlags converts a kernel.FDFlags object to a Linux file flags
+// representation.
+func (f FDFlags) ToLinuxFileFlags() (mask uint) {
+ if f.CloseOnExec {
+ mask |= linux.O_CLOEXEC
+ }
+ return
+}
+
+// ToLinuxFDFlags converts a kernel.FDFlags object to a Linux descriptor flags
+// representation.
+func (f FDFlags) ToLinuxFDFlags() (mask uint) {
+ if f.CloseOnExec {
+ mask |= linux.FD_CLOEXEC
+ }
+ return
+}
+
+// descriptor holds the details about a file descriptor, namely a pointer to
+// the file itself and the descriptor flags.
+//
+// Note that this is immutable and can only be changed via operations on the
+// descriptorTable.
+//
+// It contains both VFS1 and VFS2 file types, but only one of them can be set.
+//
+// +stateify savable
+type descriptor struct {
+ // TODO(gvisor.dev/issue/1624): Remove fs.File.
+ file *fs.File
+ fileVFS2 *vfs.FileDescription
+ flags FDFlags
+}
+
+// FDTable is used to manage File references and flags.
+//
+// +stateify savable
+type FDTable struct {
+ refs.AtomicRefCount
+ k *Kernel
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // next is the start position for searching for an available fd.
+ next int32
+
+ // used contains the number of non-nil entries. It must be accessed
+ // atomically. It may be read atomically without holding mu (but not
+ // written).
+ used int32
+
+ // descriptorTable holds descriptors.
+ descriptorTable `state:".(map[int32]descriptor)"`
+}
+
+func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
+ m := make(map[int32]descriptor)
+ f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ m[fd] = descriptor{
+ file: file,
+ fileVFS2: fileVFS2,
+ flags: flags,
+ }
+ })
+ return m
+}
+
+func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
+ f.init() // Initialize table.
+ for fd, d := range m {
+ f.setAll(fd, d.file, d.fileVFS2, d.flags)
+
+ // Note that we do _not_ need to acquire an extra table reference here. The
+ // table reference will already be accounted for in the file, so we drop the
+ // reference taken by set above.
+ switch {
+ case d.file != nil:
+ d.file.DecRef()
+ case d.fileVFS2 != nil:
+ d.fileVFS2.DecRef()
+ }
+ }
+}
+
+// drop drops the table reference.
+func (f *FDTable) drop(file *fs.File) {
+ // Release locks.
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF})
+
+ // Send inotify events.
+ d := file.Dirent
+ var ev uint32
+ if fs.IsDir(d.Inode.StableAttr) {
+ ev |= linux.IN_ISDIR
+ }
+ if file.Flags().Write {
+ ev |= linux.IN_CLOSE_WRITE
+ } else {
+ ev |= linux.IN_CLOSE_NOWRITE
+ }
+ d.InotifyEvent(ev, 0)
+
+ // Drop the table reference.
+ file.DecRef()
+}
+
+// dropVFS2 drops the table reference.
+func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
+ // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the
+ // entire file.
+ err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET)
+ if err != nil && err != syserror.ENOLCK {
+ panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
+ }
+
+ // Generate inotify events.
+ ev := uint32(linux.IN_CLOSE_NOWRITE)
+ if file.IsWritable() {
+ ev = linux.IN_CLOSE_WRITE
+ }
+ file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent)
+
+ // Drop the table's reference.
+ file.DecRef()
+}
+
+// NewFDTable allocates a new FDTable that may be used by tasks in k.
+func (k *Kernel) NewFDTable() *FDTable {
+ f := &FDTable{k: k}
+ f.init()
+ return f
+}
+
+// destroy removes all of the file descriptors from the map.
+func (f *FDTable) destroy() {
+ f.RemoveIf(func(*fs.File, *vfs.FileDescription, FDFlags) bool {
+ return true
+ })
+}
+
+// DecRef implements RefCounter.DecRef with destructor f.destroy.
+func (f *FDTable) DecRef() {
+ f.DecRefWithDestructor(f.destroy)
+}
+
+// Size returns the number of file descriptors currently installed in the table.
+func (f *FDTable) Size() int {
+ size := atomic.LoadInt32(&f.used)
+ return int(size)
+}
+
+// forEach iterates over all non-nil files in sorted order.
+//
+// It is the caller's responsibility to acquire an appropriate lock.
+func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
+ // retries tracks the number of failed TryIncRef attempts for the same FD.
+ retries := 0
+ fd := int32(0)
+ for {
+ file, fileVFS2, flags, ok := f.getAll(fd)
+ if !ok {
+ break
+ }
+ switch {
+ case file != nil:
+ if !file.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, FileOps: %+v", fd, file, file.FileOperations))
+ }
+ continue // Race caught.
+ }
+ fn(fd, file, nil, flags)
+ file.DecRef()
+ case fileVFS2 != nil:
+ if !fileVFS2.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, Impl: %+v", fd, fileVFS2, fileVFS2.Impl()))
+ }
+ continue // Race caught.
+ }
+ fn(fd, nil, fileVFS2, flags)
+ fileVFS2.DecRef()
+ }
+ retries = 0
+ fd++
+ }
+}
+
+// String is a stringer for FDTable.
+func (f *FDTable) String() string {
+ var buf strings.Builder
+ f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ switch {
+ case file != nil:
+ n, _ := file.Dirent.FullName(nil /* root */)
+ fmt.Fprintf(&buf, "\tfd:%d => name %s\n", fd, n)
+
+ case fileVFS2 != nil:
+ vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem()
+ name, err := vfsObj.PathnameWithDeleted(context.Background(), vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
+ if err != nil {
+ fmt.Fprintf(&buf, "<err: %v>\n", err)
+ return
+ }
+ fmt.Fprintf(&buf, "\tfd:%d => name %s\n", fd, name)
+ }
+ })
+ return buf.String()
+}
+
+// NewFDs allocates new FDs guaranteed to be the lowest number available
+// greater than or equal to the fd parameter. All files will share the given
+// flags. Allocation is all-or-nothing: on failure, no FDs are installed.
+func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return nil, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if fd >= end {
+ return nil, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Start the search at f.next to find an available fd.
+ if fd < f.next {
+ fd = f.next
+ }
+
+ // Install all entries.
+ for i := fd; i < end && len(fds) < len(files); i++ {
+ if d, _, _ := f.get(i); d == nil {
+ f.set(i, files[len(fds)], flags) // Set the descriptor.
+ fds = append(fds, i) // Record the file descriptor.
+ }
+ }
+
+ // Failure? Unwind existing FDs.
+ if len(fds) < len(files) {
+ for _, i := range fds {
+ f.set(i, nil, FDFlags{}) // Zap entry.
+ }
+ return nil, syscall.EMFILE
+ }
+
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
+ return fds, nil
+}
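+
+// Illustrative sketch (not part of this file): a dup(2)-style operation built
+// on NewFDs allocates the lowest free descriptor at or above 0 for an
+// existing file. The variables ctx, table, and file are assumed to come from
+// the caller:
+//
+//    fds, err := table.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{})
+//    if err != nil {
+//            return -1, err // EMFILE if the table is full.
+//    }
+//    newFD := fds[0]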
+
+// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
+// greater than or equal to the fd parameter. All files will share the given
+// flags. Allocation is all-or-nothing: on failure, no FDs are installed.
+func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return nil, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if fd >= end {
+ return nil, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Start the search at f.next to find an available fd.
+ if fd < f.next {
+ fd = f.next
+ }
+
+ // Install all entries.
+ for i := fd; i < end && len(fds) < len(files); i++ {
+ if d, _, _ := f.getVFS2(i); d == nil {
+ f.setVFS2(i, files[len(fds)], flags) // Set the descriptor.
+ fds = append(fds, i) // Record the file descriptor.
+ }
+ }
+
+ // Failure? Unwind existing FDs.
+ if len(fds) < len(files) {
+ for _, i := range fds {
+ f.setVFS2(i, nil, FDFlags{}) // Zap entry.
+ }
+ return nil, syscall.EMFILE
+ }
+
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
+ return fds, nil
+}
+
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Start the search at f.next to find an available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.getVFS2(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
+// NewFDAt sets the file reference for the given FD. If there is an active
+// reference for that FD, the ref count for that existing reference is
+// decremented.
+func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error {
+ return f.newFDAt(ctx, fd, file, nil, flags)
+}
+
+// NewFDAtVFS2 sets the file reference for the given FD. If there is an active
+// reference for that FD, the ref count for that existing reference is
+// decremented.
+func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return f.newFDAt(ctx, fd, nil, file, flags)
+}
+
+func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return syscall.EBADF
+ }
+
+ // Check the limit for the provided file.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
+ return syscall.EMFILE
+ }
+ }
+
+ // Install the entry.
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.setAll(fd, file, fileVFS2, flags)
+ return nil
+}
+
+// SetFlags sets the flags for the given file descriptor.
+//
+// EBADF is returned if fd is negative or does not refer to an open file.
+func (f *FDTable) SetFlags(fd int32, flags FDFlags) error {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return syscall.EBADF
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ file, _, _ := f.get(fd)
+ if file == nil {
+ // No file found.
+ return syscall.EBADF
+ }
+
+ // Update the flags.
+ f.set(fd, file, flags)
+ return nil
+}
+
+// SetFlagsVFS2 sets the flags for the given file descriptor.
+//
+// EBADF is returned if fd is negative or does not refer to an open file.
+func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return syscall.EBADF
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ file, _, _ := f.getVFS2(fd)
+ if file == nil {
+ // No file found.
+ return syscall.EBADF
+ }
+
+ // Update the flags.
+ f.setVFS2(fd, file, flags)
+ return nil
+}
+
+// Get returns a reference to the file and the flags for the FD or nil if no
+// file is defined for the given fd.
+//
+// N.B. Callers are required to use DecRef when they are done.
+//
+//go:nosplit
+func (f *FDTable) Get(fd int32) (*fs.File, FDFlags) {
+ if fd < 0 {
+ return nil, FDFlags{}
+ }
+
+ for {
+ file, flags, _ := f.get(fd)
+ if file != nil {
+ if !file.TryIncRef() {
+ continue // Race caught.
+ }
+ // Reference acquired.
+ return file, flags
+ }
+ // No file available.
+ return nil, FDFlags{}
+ }
+}
+
+// GetVFS2 returns a reference to the file and the flags for the FD or nil if no
+// file is defined for the given fd.
+//
+// N.B. Callers are required to use DecRef when they are done.
+//
+//go:nosplit
+func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
+ if fd < 0 {
+ return nil, FDFlags{}
+ }
+
+ for {
+ file, flags, _ := f.getVFS2(fd)
+ if file != nil {
+ if !file.TryIncRef() {
+ continue // Race caught.
+ }
+ // Reference acquired.
+ return file, flags
+ }
+ // No file available.
+ return nil, FDFlags{}
+ }
+}
+
+// GetFDs returns a sorted list of valid fds.
+//
+// Precondition: The caller must be running on the task goroutine, or Task.mu
+// must be locked.
+func (f *FDTable) GetFDs() []int32 {
+ fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
+ f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ fds = append(fds, fd)
+ })
+ return fds
+}
+
+// GetRefs returns a stable slice of references to all files and bumps the
+// reference count on each. The caller must use DecRef on each reference when
+// they're done using the slice.
+func (f *FDTable) GetRefs() []*fs.File {
+ files := make([]*fs.File, 0, f.Size())
+ f.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ file.IncRef() // Acquire a reference for caller.
+ files = append(files, file)
+ })
+ return files
+}
+
+// GetRefsVFS2 returns a stable slice of references to all files and bumps the
+// reference count on each. The caller must use DecRef on each reference when
+// they're done using the slice.
+func (f *FDTable) GetRefsVFS2() []*vfs.FileDescription {
+ files := make([]*vfs.FileDescription, 0, f.Size())
+ f.forEach(func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) {
+ file.IncRef() // Acquire a reference for caller.
+ files = append(files, file)
+ })
+ return files
+}
+
+// Fork returns an independent FDTable.
+func (f *FDTable) Fork() *FDTable {
+ clone := f.k.NewFDTable()
+
+ f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ // The set function here will acquire an appropriate table
+ // reference for the clone. We don't need anything else.
+ switch {
+ case file != nil:
+ clone.set(fd, file, flags)
+ case fileVFS2 != nil:
+ clone.setVFS2(fd, fileVFS2, flags)
+ }
+ })
+ return clone
+}
+
+// Remove removes an FD from the table and returns the removed file, which is
+// non-nil iff the removal was successful.
+//
+// N.B. Callers are required to use DecRef when they are done.
+func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
+ if fd < 0 {
+ return nil, nil
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Update current available position.
+ if fd < f.next {
+ f.next = fd
+ }
+
+ orig, orig2, _, _ := f.getAll(fd)
+
+ // Add reference for caller.
+ switch {
+ case orig != nil:
+ orig.IncRef()
+ case orig2 != nil:
+ orig2.IncRef()
+ }
+ if orig != nil || orig2 != nil {
+ f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
+ }
+ return orig, orig2
+}
+
+// RemoveIf removes all FDs where cond is true.
+func (f *FDTable) RemoveIf(cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ if cond(file, fileVFS2, flags) {
+ f.set(fd, nil, FDFlags{}) // Clear from table.
+ // Update current available position.
+ if fd < f.next {
+ f.next = fd
+ }
+ }
+ })
+}
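+
+// Close sketch (illustrative, not part of this file): a close(2)-style
+// operation removes the FD from the table and then drops the reference that
+// Remove acquired on the caller's behalf:
+//
+//    file, fileVFS2 := table.Remove(fd)
+//    switch {
+//    case file != nil:
+//            file.DecRef()
+//    case fileVFS2 != nil:
+//            fileVFS2.DecRef()
+//    default:
+//            return syscall.EBADF // Nothing was installed at fd.
+//    }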
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
new file mode 100644
index 000000000..29f95a2c4
--- /dev/null
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -0,0 +1,228 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "runtime"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/filetest"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+const (
+ // maxFD is the maximum FD to try to create in the map.
+ //
+ // This number of open files has been seen in the wild.
+ maxFD = 2 * 1024
+)
+
+func runTest(t testing.TB, fn func(ctx context.Context, fdTable *FDTable, file *fs.File, limitSet *limits.LimitSet)) {
+ t.Helper() // Don't show in stacks.
+
+ // Create the limits and context.
+ limitSet := limits.NewLimitSet()
+ limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true)
+ ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet)
+
+ // Create a test file.
+ file := filetest.NewTestFile(t)
+
+ // Create the table.
+ fdTable := new(FDTable)
+ fdTable.init()
+
+ // Run the test.
+ fn(ctx, fdTable, file, limitSet)
+}
+
+// TestFDTableMany allocates maxFD FDs, maxing out the FDTable, then verifies
+// that further allocation fails, that NewFDAt still works on an existing
+// slot, and that a removed FD is reused by the next allocation.
+func TestFDTableMany(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ for i := 0; i < maxFD; i++ {
+ if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Allocated %v FDs but wanted to allocate %v", i, maxFD)
+ }
+ }
+
+ if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(0, r) in full map: got nil, wanted error")
+ }
+
+ if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil {
+ t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ i := int32(2)
+ fdTable.Remove(i)
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i {
+ t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err)
+ }
+ })
+}
+
+func TestFDTableOverLimit(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error")
+ }
+
+ if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error")
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err)
+ } else {
+ for _, fd := range fds {
+ fdTable.Remove(fd)
+ }
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 {
+ t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Adding an FD to a resized map: got %v, want nil", err)
+ } else if len(fds) != 1 || fds[0] != 0 {
+ t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds)
+ }
+ })
+}
+
+// TestFDTable does a set of simple tests to make sure simple adds, removes,
+// GetRefs, and DecRefs work. The ordering is just weird enough that a
+// table-driven approach seemed clumsy.
+func TestFDTable(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, limitSet *limits.LimitSet) {
+ // Cap the limit at one.
+ limitSet.Set(limits.NumberOfFiles, limits.Limit{1, maxFD}, true)
+
+ if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Adding an FD to an empty 1-size map: got %v, want nil", err)
+ }
+
+ if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err == nil {
+ t.Fatalf("Adding an FD to a filled 1-size map: got nil, wanted an error")
+ }
+
+ // Remove the previous limit.
+ limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true)
+
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Adding an FD to a resized map: got %v, want nil", err)
+ } else if len(fds) != 1 || fds[0] != 1 {
+ t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds)
+ }
+
+ if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil {
+ t.Fatalf("Replacing FD 1 via fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ if err := fdTable.NewFDAt(ctx, maxFD+1, file, FDFlags{}); err == nil {
+ t.Fatalf("Using an FD that was too large via fdTable.NewFDAt(%v, r, FDFlags{}): got nil, wanted an error", maxFD+1)
+ }
+
+ if ref, _ := fdTable.Get(1); ref == nil {
+ t.Fatalf("fdTable.Get(1): got nil, wanted %v", file)
+ }
+
+ if ref, _ := fdTable.Get(2); ref != nil {
+ t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref)
+ }
+
+ ref, _ := fdTable.Remove(1)
+ if ref == nil {
+ t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success")
+ }
+ ref.DecRef()
+
+ if ref, _ := fdTable.Remove(1); ref != nil {
+ t.Fatalf("r.Remove(1) for a removed FD: got success, want failure")
+ }
+ })
+}
+
+func TestDescriptorFlags(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ if err := fdTable.NewFDAt(ctx, 2, file, FDFlags{CloseOnExec: true}); err != nil {
+ t.Fatalf("fdTable.NewFDAt(2, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ newFile, flags := fdTable.Get(2)
+ if newFile == nil {
+ t.Fatalf("fdTable.Get(2): got a %v, wanted nil", newFile)
+ }
+
+ if !flags.CloseOnExec {
+ t.Fatalf("new File flags %v don't match original %d\n", flags, 0)
+ }
+ })
+}
+
+func BenchmarkFDLookupAndDecRef(b *testing.B) {
+ b.StopTimer() // Setup.
+
+ runTest(b, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file, file, file, file, file}, FDFlags{})
+ if err != nil {
+ b.Fatalf("fdTable.NewFDs: got %v, wanted nil", err)
+ }
+
+ b.StartTimer() // Benchmark.
+ for i := 0; i < b.N; i++ {
+ tf, _ := fdTable.Get(fds[i%len(fds)])
+ tf.DecRef()
+ }
+ })
+}
+
+func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) {
+ b.StopTimer() // Setup.
+
+ runTest(b, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file, file, file, file, file}, FDFlags{})
+ if err != nil {
+ b.Fatalf("fdTable.NewFDs: got %v, wanted nil", err)
+ }
+
+ concurrency := runtime.GOMAXPROCS(0)
+ if concurrency < 4 {
+ concurrency = 4
+ }
+ each := b.N / concurrency
+
+ b.StartTimer() // Benchmark.
+ var wg sync.WaitGroup
+ for i := 0; i < concurrency; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < each; i++ {
+ tf, _ := fdTable.Get(fds[i%len(fds)])
+ tf.DecRef()
+ }
+ }()
+ }
+ wg.Wait()
+ })
+}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
new file mode 100644
index 000000000..7fd97dc53
--- /dev/null
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -0,0 +1,169 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type descriptorTable struct {
+ // slice is a *[]unsafe.Pointer, where each element is actually a
+ // *descriptor object, updated atomically.
+ //
+ // Changes to the slice itself require holding FDTable.mu.
+ slice unsafe.Pointer `state:".(map[int32]*descriptor)"`
+}
+
+// init initializes the table.
+func (f *FDTable) init() {
+ var slice []unsafe.Pointer // Empty slice.
+ atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
+}
+
+// get gets a file entry.
+//
+// The boolean indicates whether this was in range.
+//
+//go:nosplit
+func (f *FDTable) get(fd int32) (*fs.File, FDFlags, bool) {
+ file, _, flags, ok := f.getAll(fd)
+ return file, flags, ok
+}
+
+// getVFS2 gets a file entry.
+//
+// The boolean indicates whether this was in range.
+//
+//go:nosplit
+func (f *FDTable) getVFS2(fd int32) (*vfs.FileDescription, FDFlags, bool) {
+ _, file, flags, ok := f.getAll(fd)
+ return file, flags, ok
+}
+
+// getAll gets a file entry.
+//
+// The boolean indicates whether this was in range.
+//
+//go:nosplit
+func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, bool) {
+ slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ if fd >= int32(len(slice)) {
+ return nil, nil, FDFlags{}, false
+ }
+ d := (*descriptor)(atomic.LoadPointer(&slice[fd]))
+ if d == nil {
+ return nil, nil, FDFlags{}, true
+ }
+ if d.file != nil && d.fileVFS2 != nil {
+ panic("VFS1 and VFS2 files set")
+ }
+ return d.file, d.fileVFS2, d.flags, true
+}
+
+// set sets an entry.
+//
+// This handles accounting changes, as well as acquiring and releasing the
+// reference needed by the table iff the file is different.
+//
+// Precondition: mu must be held.
+func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) {
+ f.setAll(fd, file, nil, flags)
+}
+
+// setVFS2 sets an entry.
+//
+// This handles accounting changes, as well as acquiring and releasing the
+// reference needed by the table iff the file is different.
+//
+// Precondition: mu must be held.
+func (f *FDTable) setVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) {
+ f.setAll(fd, nil, file, flags)
+}
+
+// setAll sets an entry.
+//
+// This handles accounting changes, as well as acquiring and releasing the
+// reference needed by the table iff the file is different.
+//
+// Precondition: mu must be held.
+func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ if file != nil && fileVFS2 != nil {
+ panic("VFS1 and VFS2 files set")
+ }
+
+ slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+
+ // Grow the table as required.
+ if last := int32(len(slice)); fd >= last {
+ end := fd + 1
+ if end < 2*last {
+ end = 2 * last
+ }
+ slice = append(slice, make([]unsafe.Pointer, end-last)...)
+ atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
+ }
+
+ var desc *descriptor
+ if file != nil || fileVFS2 != nil {
+ desc = &descriptor{
+ file: file,
+ fileVFS2: fileVFS2,
+ flags: flags,
+ }
+ }
+
+ // Update the single element.
+ orig := (*descriptor)(atomic.SwapPointer(&slice[fd], unsafe.Pointer(desc)))
+
+ // Acquire a table reference.
+ if desc != nil {
+ switch {
+ case desc.file != nil:
+ if orig == nil || desc.file != orig.file {
+ desc.file.IncRef()
+ }
+ case desc.fileVFS2 != nil:
+ if orig == nil || desc.fileVFS2 != orig.fileVFS2 {
+ desc.fileVFS2.IncRef()
+ }
+ }
+ }
+
+ // Drop the table reference.
+ if orig != nil {
+ switch {
+ case orig.file != nil:
+ if desc == nil || desc.file != orig.file {
+ f.drop(orig.file)
+ }
+ case orig.fileVFS2 != nil:
+ if desc == nil || desc.fileVFS2 != orig.fileVFS2 {
+ f.dropVFS2(orig.fileVFS2)
+ }
+ }
+ }
+
+ // Adjust used.
+ switch {
+ case orig == nil && desc != nil:
+ atomic.AddInt32(&f.used, 1)
+ case orig != nil && desc == nil:
+ atomic.AddInt32(&f.used, -1)
+ }
+}
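+
+// Illustrative note (not part of this file): because both the slice header
+// and each element are published with atomic stores, readers such as getAll
+// may run concurrently with setAll without holding FDTable.mu. The growth
+// policy above at least doubles the table: with len(slice) == 4, setting
+// fd == 5 grows the table to 2*4 == 8 slots, while setting fd == 9 grows it
+// directly to fd+1 == 10 slots.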
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
new file mode 100644
index 000000000..47f78df9a
--- /dev/null
+++ b/pkg/sentry/kernel/fs_context.go
@@ -0,0 +1,283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// FSContext contains filesystem context.
+//
+// This includes umask and working directory.
+//
+// +stateify savable
+type FSContext struct {
+ refs.AtomicRefCount
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // root is the filesystem root. Will be nil iff the FSContext has been
+ // destroyed.
+ root *fs.Dirent
+
+ // rootVFS2 is the filesystem root.
+ rootVFS2 vfs.VirtualDentry
+
+ // cwd is the current working directory. Will be nil iff the FSContext
+ // has been destroyed.
+ cwd *fs.Dirent
+
+ // cwdVFS2 is the current working directory.
+ cwdVFS2 vfs.VirtualDentry
+
+ // umask is the current file mode creation mask. When a thread using this
+ // context invokes a syscall that creates a file, bits set in umask are
+ // removed from the permissions that the file is created with.
+ umask uint
+}
+
+// newFSContext returns a new filesystem context.
+func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
+ root.IncRef()
+ cwd.IncRef()
+ f := FSContext{
+ root: root,
+ cwd: cwd,
+ umask: umask,
+ }
+ f.EnableLeakCheck("kernel.FSContext")
+ return &f
+}
+
+// NewFSContextVFS2 returns a new filesystem context.
+func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext {
+ root.IncRef()
+ cwd.IncRef()
+ f := FSContext{
+ rootVFS2: root,
+ cwdVFS2: cwd,
+ umask: umask,
+ }
+ f.EnableLeakCheck("kernel.FSContext")
+ return &f
+}
+
+// destroy is the destructor for an FSContext.
+//
+// This will call DecRef on the root and cwd Dirents (or, under VFS2, on the
+// corresponding VirtualDentries).
+//
+// Note that there may still be calls to WorkingDirectory() or RootDirectory()
+// (that return nil). This is because valid references may still be held via
+// proc files or other mechanisms.
+func (f *FSContext) destroy() {
+ // Hold f.mu so that we don't race with RootDirectory() and
+ // WorkingDirectory().
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if VFS2Enabled {
+ f.rootVFS2.DecRef()
+ f.rootVFS2 = vfs.VirtualDentry{}
+ f.cwdVFS2.DecRef()
+ f.cwdVFS2 = vfs.VirtualDentry{}
+ } else {
+ f.root.DecRef()
+ f.root = nil
+ f.cwd.DecRef()
+ f.cwd = nil
+ }
+}
+
+// DecRef implements RefCounter.DecRef with destructor f.destroy.
+func (f *FSContext) DecRef() {
+ f.DecRefWithDestructor(f.destroy)
+}
+
+// Fork forks this FSContext.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) Fork() *FSContext {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if VFS2Enabled {
+ f.cwdVFS2.IncRef()
+ f.rootVFS2.IncRef()
+ } else {
+ f.cwd.IncRef()
+ f.root.IncRef()
+ }
+
+ return &FSContext{
+ cwd: f.cwd,
+ root: f.root,
+ cwdVFS2: f.cwdVFS2,
+ rootVFS2: f.rootVFS2,
+ umask: f.umask,
+ }
+}
+
+// WorkingDirectory returns the current working directory.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) WorkingDirectory() *fs.Dirent {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.cwd.IncRef()
+ return f.cwd
+}
+
+// WorkingDirectoryVFS2 returns the current working directory.
+//
+// This is not a valid call after destroy(). It returns a VirtualDentry with
+// a reference taken.
+func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.cwdVFS2.IncRef()
+ return f.cwdVFS2
+}
+
+// SetWorkingDirectory sets the current working directory.
+// This will take an extra reference on the Dirent.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
+ if d == nil {
+ panic("FSContext.SetWorkingDirectory called with nil dirent")
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if f.cwd == nil {
+ panic(fmt.Sprintf("FSContext.SetWorkingDirectory(%v)) called after destroy", d))
+ }
+
+ old := f.cwd
+ f.cwd = d
+ d.IncRef()
+ old.DecRef()
+}
+
+// SetWorkingDirectoryVFS2 sets the current working directory.
+// This will take an extra reference on the VirtualDentry.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetWorkingDirectoryVFS2(d vfs.VirtualDentry) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ old := f.cwdVFS2
+ f.cwdVFS2 = d
+ d.IncRef()
+ old.DecRef()
+}
+
+// RootDirectory returns the current filesystem root.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) RootDirectory() *fs.Dirent {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ if f.root != nil {
+ f.root.IncRef()
+ }
+ return f.root
+}
+
+// RootDirectoryVFS2 returns the current filesystem root.
+//
+// This is not a valid call after destroy(). It returns a VirtualDentry with
+// a reference taken.
+func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.rootVFS2.IncRef()
+ return f.rootVFS2
+}
+
+// SetRootDirectory sets the root directory.
+// This will take an extra reference on the Dirent.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
+ if d == nil {
+ panic("FSContext.SetRootDirectory called with nil dirent")
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if f.root == nil {
+ panic(fmt.Sprintf("FSContext.SetRootDirectory(%v)) called after destroy", d))
+ }
+
+ old := f.root
+ f.root = d
+ d.IncRef()
+ old.DecRef()
+}
+
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef()
+}
+
+// Umask returns the current umask.
+func (f *FSContext) Umask() uint {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return f.umask
+}
+
+// SwapUmask atomically sets the current umask and returns the old umask.
+func (f *FSContext) SwapUmask(mask uint) uint {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ old := f.umask
+ f.umask = mask
+ return old
+}
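+
+// Umask usage sketch (illustrative, not part of this file): a create-style
+// syscall typically clears the umask bits from the application-requested
+// mode before creating a file. Here t is assumed to be a *kernel.Task and
+// reqMode the requested linux.FileMode:
+//
+//    mode := reqMode &^ linux.FileMode(t.FSContext().Umask())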
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
new file mode 100644
index 000000000..c5021f2db
--- /dev/null
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -0,0 +1,57 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "atomicptr_bucket",
+ out = "atomicptr_bucket_unsafe.go",
+ package = "futex",
+ suffix = "Bucket",
+ template = "//pkg/sync:generic_atomicptr",
+ types = {
+ "Value": "bucket",
+ },
+)
+
+go_template_instance(
+ name = "waiter_list",
+ out = "waiter_list.go",
+ package = "futex",
+ prefix = "waiter",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Waiter",
+ "Linker": "*Waiter",
+ },
+)
+
+go_library(
+ name = "futex",
+ srcs = [
+ "atomicptr_bucket_unsafe.go",
+ "futex.go",
+ "waiter_list.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/memmap",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "futex_test",
+ size = "small",
+ srcs = ["futex_test.go"],
+ library = ":futex",
+ deps = [
+ "//pkg/sync",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
new file mode 100644
index 000000000..732e66da4
--- /dev/null
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -0,0 +1,795 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package futex provides an implementation of the futex interface as found in
+// the Linux kernel. It allows one to easily transform Wait() calls into waits
+// on a channel, which is useful in a Go-based kernel, for example.
+package futex
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// KeyKind indicates the type of a Key.
+type KeyKind int
+
+const (
+ // KindPrivate indicates a private futex (a futex syscall with the
+ // FUTEX_PRIVATE_FLAG set).
+ KindPrivate KeyKind = iota
+
+ // KindSharedPrivate indicates a shared futex on a private memory mapping.
+ // Although KindPrivate and KindSharedPrivate futexes both use memory
+ // addresses to identify futexes, they do not interoperate (in Linux, the
+ // two are distinguished by the FUT_OFF_MMSHARED flag, which is used in key
+ // comparison).
+ KindSharedPrivate
+
+ // KindSharedMappable indicates a shared futex on a memory mapping other
+ // than a private anonymous memory mapping.
+ KindSharedMappable
+)
+
+// Key represents something that a futex waiter may wait on.
+type Key struct {
+ // Kind is the type of the Key.
+ Kind KeyKind
+
+ // Mappable is the memory-mapped object that is represented by the Key.
+ // Mappable is always nil if Kind is not KindSharedMappable, and may be nil
+ // even if it is.
+ Mappable memmap.Mappable
+
+ // MappingIdentity is the MappingIdentity associated with Mappable.
+ // MappingIdentity is always nil if Mappable is nil, and may be nil even if
+ // it isn't.
+ MappingIdentity memmap.MappingIdentity
+
+ // If Kind is KindPrivate or KindSharedPrivate, Offset is the represented
+ // memory address. Otherwise, Offset is the represented offset into
+ // Mappable.
+ Offset uint64
+}
+
+func (k *Key) release() {
+ if k.MappingIdentity != nil {
+ k.MappingIdentity.DecRef()
+ }
+ k.Mappable = nil
+ k.MappingIdentity = nil
+}
+
+func (k *Key) clone() Key {
+ if k.MappingIdentity != nil {
+ k.MappingIdentity.IncRef()
+ }
+ return *k
+}
+
+// Preconditions: k.Kind == KindPrivate or KindSharedPrivate.
+func (k *Key) addr() usermem.Addr {
+ return usermem.Addr(k.Offset)
+}
+
+// matches returns true if a wakeup on k2 should wake a waiter waiting on k.
+func (k *Key) matches(k2 *Key) bool {
+ // k.MappingIdentity is ignored; it's only used for reference counting.
+ return k.Kind == k2.Kind && k.Mappable == k2.Mappable && k.Offset == k2.Offset
+}
+
+// Target abstracts memory accesses and keys.
+type Target interface {
+ // SwapUint32 gives access to usermem.IO.SwapUint32.
+ SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
+
+ // CompareAndSwap gives access to usermem.IO.CompareAndSwapUint32.
+ CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error)
+
+ // LoadUint32 gives access to usermem.IO.LoadUint32.
+ LoadUint32(addr usermem.Addr) (uint32, error)
+
+ // GetSharedKey returns a Key with kind KindSharedPrivate or
+ // KindSharedMappable corresponding to the memory mapped at address addr.
+ //
+ // If GetSharedKey returns a Key with a non-nil MappingIdentity, a
+ // reference is held on the MappingIdentity, which must be dropped by the
+ // caller when the Key is no longer in use.
+ GetSharedKey(addr usermem.Addr) (Key, error)
+}
+
+// check performs a basic equality check on the given address.
+func check(t Target, addr usermem.Addr, val uint32) error {
+ cur, err := t.LoadUint32(addr)
+ if err != nil {
+ return err
+ }
+ if cur != val {
+ return syserror.EAGAIN
+ }
+ return nil
+}
+
+// atomicOp applies the FUTEX_OP operation encoded in opIn to the given
+// address and returns the result of the encoded comparison.
+func atomicOp(t Target, addr usermem.Addr, opIn uint32) (bool, error) {
+ opType := (opIn >> 28) & 0xf
+ cmp := (opIn >> 24) & 0xf
+ opArg := (opIn >> 12) & 0xfff
+ cmpArg := opIn & 0xfff
+
+ if opType&linux.FUTEX_OP_OPARG_SHIFT != 0 {
+ opArg = 1 << opArg
+ opType &^= linux.FUTEX_OP_OPARG_SHIFT // Clear flag.
+ }
+
+ var (
+ oldVal uint32
+ err error
+ )
+ if opType == linux.FUTEX_OP_SET {
+ oldVal, err = t.SwapUint32(addr, opArg)
+ if err != nil {
+ return false, err
+ }
+ } else {
+ for {
+ oldVal, err = t.LoadUint32(addr)
+ if err != nil {
+ return false, err
+ }
+ var newVal uint32
+ switch opType {
+ case linux.FUTEX_OP_ADD:
+ newVal = oldVal + opArg
+ case linux.FUTEX_OP_OR:
+ newVal = oldVal | opArg
+ case linux.FUTEX_OP_ANDN:
+ newVal = oldVal &^ opArg
+ case linux.FUTEX_OP_XOR:
+ newVal = oldVal ^ opArg
+ default:
+ return false, syserror.ENOSYS
+ }
+ prev, err := t.CompareAndSwapUint32(addr, oldVal, newVal)
+ if err != nil {
+ return false, err
+ }
+ if prev == oldVal {
+ break // Success.
+ }
+ }
+ }
+
+ switch cmp {
+ case linux.FUTEX_OP_CMP_EQ:
+ return oldVal == cmpArg, nil
+ case linux.FUTEX_OP_CMP_NE:
+ return oldVal != cmpArg, nil
+ case linux.FUTEX_OP_CMP_LT:
+ return oldVal < cmpArg, nil
+ case linux.FUTEX_OP_CMP_LE:
+ return oldVal <= cmpArg, nil
+ case linux.FUTEX_OP_CMP_GT:
+ return oldVal > cmpArg, nil
+ case linux.FUTEX_OP_CMP_GE:
+ return oldVal >= cmpArg, nil
+ default:
+ return false, syserror.ENOSYS
+ }
+}
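+
+// For illustration (assuming the encoding follows Linux's FUTEX_OP macro,
+// which matches the shifts decoded above): a FUTEX_WAKE_OP argument that
+// adds 1 to the word at addr2 and compares the old value with "> 0" would be
+// built as:
+//
+//    op := uint32(linux.FUTEX_OP_ADD)<<28 | // opType: add.
+//            uint32(linux.FUTEX_OP_CMP_GT)<<24 | // cmp: greater-than.
+//            (1 << 12) | // opArg: add 1.
+//            0 // cmpArg: compare against 0.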
+
+// Waiter is the struct which gets enqueued into buckets for wake up routines
+// and requeue routines to scan and notify. Once a Waiter has been enqueued by
+// WaitPrepare(), callers may listen on C for wake up events.
+type Waiter struct {
+ // Synchronization:
+ //
+ // - A Waiter that is not enqueued in a bucket is exclusively owned (no
+ // synchronization applies).
+ //
+ // - A Waiter is enqueued in a bucket by calling WaitPrepare(). After this,
+ // waiterEntry, bucket, and key are protected by the bucket.mu ("bucket
+ // lock") of the containing bucket, and bitmask is immutable. Note that
+ // since bucket is mutated using atomic memory operations, bucket.Load()
+ // may be called without holding the bucket lock, although it may change
+ // racily. See WaitComplete().
+ //
+ // - A Waiter is only guaranteed to be no longer queued after calling
+ // WaitComplete().
+
+ // waiterEntry links Waiter into bucket.waiters.
+ waiterEntry
+
+ // bucket is the bucket this waiter is queued in. If bucket is nil, the
+ // waiter is not waiting and is not in any bucket.
+ bucket AtomicPtrBucket
+
+ // C is sent to when the Waiter is woken.
+ C chan struct{}
+
+ // key is what this waiter is waiting on.
+ key Key
+
+ // The bitmask we're waiting on.
+ // This is used in the case of a FUTEX_WAKE_BITSET.
+ bitmask uint32
+
+ // tid is the thread ID for the waiter in case this is a PI mutex.
+ tid uint32
+}
+
+// NewWaiter returns a new unqueued Waiter.
+func NewWaiter() *Waiter {
+ return &Waiter{
+ C: make(chan struct{}, 1),
+ }
+}
+
+// woken returns true if w has been woken since the last call to WaitPrepare.
+func (w *Waiter) woken() bool {
+ return len(w.C) != 0
+}
+
+// bucket holds a list of waiters for a given address hash.
+//
+// +stateify savable
+type bucket struct {
+ // mu protects waiters and contained Waiter state. See comment in Waiter.
+ mu sync.Mutex `state:"nosave"`
+
+ waiters waiterList `state:"zerovalue"`
+}
+
+// wakeLocked wakes up to n waiters in this bucket matching the key and
+// bitmask, and returns the number of waiters woken.
+//
+// Preconditions: b.mu must be locked.
+func (b *bucket) wakeLocked(key *Key, bitmask uint32, n int) int {
+ done := 0
+ for w := b.waiters.Front(); done < n && w != nil; {
+ if !w.key.matches(key) || w.bitmask&bitmask == 0 {
+ // Not matching.
+ w = w.Next()
+ continue
+ }
+
+ // Remove from the bucket and wake the waiter.
+ woke := w
+ w = w.Next() // Next iteration.
+ b.wakeWaiterLocked(woke)
+ done++
+ }
+ return done
+}
+
+func (b *bucket) wakeWaiterLocked(w *Waiter) {
+ // Remove from the bucket and wake the waiter.
+ b.waiters.Remove(w)
+ w.C <- struct{}{}
+
+ // NOTE: The above channel write establishes a write barrier according
+ // to the memory model, so nothing may be ordered around it. Since
+ // we've dequeued w and will never touch it again, we can safely
+ // store nil to w.bucket here and allow WaitComplete() to
+ // short-circuit grabbing the bucket lock. If WaitComplete() somehow
+ // misses the store, it will block acquiring the bucket lock, which we
+ // still hold; once it acquires the lock, it will observe that w is no
+ // longer queued and leave it alone.
+ w.bucket.Store(nil)
+}
+
+// requeueLocked takes up to n waiters matching key from this bucket, rekeys
+// them to nkey, and moves them to the bucket "to".
+//
+// Preconditions: b and to must be locked.
+func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int {
+ done := 0
+ for w := b.waiters.Front(); done < n && w != nil; {
+ if !w.key.matches(key) {
+ // Not matching.
+ w = w.Next()
+ continue
+ }
+
+ requeued := w
+ w = w.Next() // Next iteration.
+ b.waiters.Remove(requeued)
+ requeued.key.release()
+ requeued.key = nkey.clone()
+ to.waiters.PushBack(requeued)
+ requeued.bucket.Store(to)
+ done++
+ }
+ return done
+}
+
+const (
+ // bucketCount is the number of buckets per Manager. By having many of
+ // these we reduce contention when concurrent yet unrelated calls are made.
+ bucketCount = 1 << bucketCountBits
+ bucketCountBits = 10
+)
+
+// getKey returns a Key representing address addr in t.
+func getKey(t Target, addr usermem.Addr, private bool) (Key, error) {
+ // Ensure the address is aligned.
+ // It must be a DWORD boundary.
+ if addr&0x3 != 0 {
+ return Key{}, syserror.EINVAL
+ }
+ if private {
+ return Key{Kind: KindPrivate, Offset: uint64(addr)}, nil
+ }
+ return t.GetSharedKey(addr)
+}
+
+// bucketIndexForAddr returns the index into Manager.buckets for addr.
+func bucketIndexForAddr(addr usermem.Addr) uintptr {
+ // - The bottom 2 bits of addr must be 0, per getKey.
+ //
+ // - On amd64, the top 16 bits of addr (bits 48-63) must be equal to bit 47
+ // for a canonical address, and (on all existing platforms) bit 47 must be
+ // 0 for an application address.
+ //
+ // Thus 19 bits of addr are "useless" for hashing, leaving only 45 "useful"
+ // bits. We choose one of the simplest possible hash functions that at
+ // least uses all 45 useful bits in the output, given that bucketCountBits
+ // == 10. This hash function also has the property that it will usually map
+ // adjacent addresses to adjacent buckets, slightly improving memory
+ // locality when an application synchronization structure uses multiple
+ // nearby futexes.
+ //
+ // Note that despite the large number of arithmetic operations in the
+ // function, many components can be computed in parallel, such that the
+ // critical path is 1 bit shift + 3 additions (2 in h1, then h1 + h2). This
+ // is also why h1 and h2 are grouped separately; for "(addr >> 2) + ... +
+ // (addr >> 42)" without any additional grouping, the compiler puts all 4
+ // additions in the critical path.
+ h1 := uintptr(addr>>2) + uintptr(addr>>12) + uintptr(addr>>22)
+ h2 := uintptr(addr>>32) + uintptr(addr>>42)
+ return (h1 + h2) % bucketCount
+}
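+
+// Worked example (illustrative): for addr == 0x1000, h1 == (0x1000>>2) +
+// (0x1000>>12) + (0x1000>>22) == 0x400 + 0x1 + 0x0 == 0x401, h2 == 0, and
+// the resulting bucket index is 0x401 % 1024 == 1.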
+
+// Manager holds futex state for a single virtual address space.
+//
+// +stateify savable
+type Manager struct {
+ // privateBuckets holds buckets for KindPrivate and KindSharedPrivate
+ // futexes.
+ privateBuckets [bucketCount]bucket `state:"zerovalue"`
+
+ // sharedBucket is the bucket for KindSharedMappable futexes. sharedBucket
+ // may be shared by multiple Managers. The sharedBucket pointer is
+ // immutable.
+ sharedBucket *bucket
+}
+
+// NewManager returns an initialized futex manager.
+func NewManager() *Manager {
+ return &Manager{
+ sharedBucket: &bucket{},
+ }
+}
+
+// Fork returns a new Manager. Shared futex clients using the returned Manager
+// may interoperate with those using m.
+func (m *Manager) Fork() *Manager {
+ return &Manager{
+ sharedBucket: m.sharedBucket,
+ }
+}
+
+// lockBucket returns a locked bucket for the given key.
+func (m *Manager) lockBucket(k *Key) *bucket {
+ var b *bucket
+ if k.Kind == KindSharedMappable {
+ b = m.sharedBucket
+ } else {
+ b = &m.privateBuckets[bucketIndexForAddr(k.addr())]
+ }
+ b.mu.Lock()
+ return b
+}
+
+// lockBuckets returns locked buckets for the given keys.
+func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) {
+ // Buckets must be consistently ordered to avoid circular lock
+ // dependencies. We order buckets in m.privateBuckets by index (lowest
+ // index first), and all buckets in m.privateBuckets precede
+ // m.sharedBucket.
+
+ // Handle the common case first:
+ if k1.Kind != KindSharedMappable && k2.Kind != KindSharedMappable {
+ i1 := bucketIndexForAddr(k1.addr())
+ i2 := bucketIndexForAddr(k2.addr())
+ b1 := &m.privateBuckets[i1]
+ b2 := &m.privateBuckets[i2]
+ switch {
+ case i1 < i2:
+ b1.mu.Lock()
+ b2.mu.Lock()
+ case i2 < i1:
+ b2.mu.Lock()
+ b1.mu.Lock()
+ default:
+ b1.mu.Lock()
+ }
+ return b1, b2
+ }
+
+ // At least one of b1 or b2 should be m.sharedBucket.
+ b1 := m.sharedBucket
+ b2 := m.sharedBucket
+ if k1.Kind != KindSharedMappable {
+ b1 = m.lockBucket(k1)
+ } else if k2.Kind != KindSharedMappable {
+ b2 = m.lockBucket(k2)
+ }
+ m.sharedBucket.mu.Lock()
+ return b1, b2
+}
+
+// Wake wakes up to n waiters matching the bitmask on the given addr.
+// The number of waiters woken is returned.
+func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32, n int) (int, error) {
+ // This function is very hot; avoid defer.
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return 0, err
+ }
+
+ b := m.lockBucket(&k)
+ r := b.wakeLocked(&k, bitmask, n)
+
+ b.mu.Unlock()
+ k.release()
+ return r, nil
+}
+
+func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, checkval bool, val uint32, nwake int, nreq int) (int, error) {
+ k1, err := getKey(t, addr, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k1.release()
+ k2, err := getKey(t, naddr, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k2.release()
+
+ b1, b2 := m.lockBuckets(&k1, &k2)
+ defer b1.mu.Unlock()
+ if b2 != b1 {
+ defer b2.mu.Unlock()
+ }
+
+ if checkval {
+ if err := check(t, addr, val); err != nil {
+ return 0, err
+ }
+ }
+
+ // Wake the number required.
+ done := b1.wakeLocked(&k1, ^uint32(0), nwake)
+
+ // Requeue the number required.
+ b1.requeueLocked(b2, &k1, &k2, nreq)
+
+ return done, nil
+}
+
+// Requeue wakes up to nwake waiters on the given addr, and unconditionally
+// requeues up to nreq waiters on naddr.
+func (m *Manager) Requeue(t Target, addr, naddr usermem.Addr, private bool, nwake int, nreq int) (int, error) {
+ return m.doRequeue(t, addr, naddr, private, false, 0, nwake, nreq)
+}
+
+// RequeueCmp atomically checks that the addr contains val (via the Target),
+// wakes up to nwake waiters on addr and then unconditionally requeues nreq
+// waiters on naddr.
+func (m *Manager) RequeueCmp(t Target, addr, naddr usermem.Addr, private bool, val uint32, nwake int, nreq int) (int, error) {
+ return m.doRequeue(t, addr, naddr, private, true, val, nwake, nreq)
+}
+
+// WakeOp atomically applies op to the memory address addr2, wakes up to nwake1
+// waiters unconditionally from addr1, and, based on the original value at addr2
+// and a comparison encoded in op, wakes up to nwake2 waiters from addr2.
+// It returns the total number of waiters woken.
+func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwake1 int, nwake2 int, op uint32) (int, error) {
+ k1, err := getKey(t, addr1, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k1.release()
+ k2, err := getKey(t, addr2, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k2.release()
+
+ b1, b2 := m.lockBuckets(&k1, &k2)
+ defer b1.mu.Unlock()
+ if b2 != b1 {
+ defer b2.mu.Unlock()
+ }
+
+ done := 0
+ cond, err := atomicOp(t, addr2, op)
+ if err != nil {
+ return 0, err
+ }
+
+ // Wake up up to nwake1 entries from the first bucket.
+ done = b1.wakeLocked(&k1, ^uint32(0), nwake1)
+
+ // Wake up up to nwake2 entries from the second bucket if the
+ // operation yielded true.
+ if cond {
+ done += b2.wakeLocked(&k2, ^uint32(0), nwake2)
+ }
+
+ return done, nil
+}
+
+// WaitPrepare atomically checks that addr contains val (via the Target), then
+// enqueues w to be woken by a send to w.C. If WaitPrepare returns nil, the
+// Waiter must be subsequently removed by calling WaitComplete, whether or not
+// a wakeup is received on w.C.
+func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) error {
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return err
+ }
+ // Ownership of k is transferred to w below.
+
+ // Prepare the Waiter before taking the bucket lock.
+ select {
+ case <-w.C:
+ default:
+ }
+ w.key = k
+ w.bitmask = bitmask
+
+ b := m.lockBucket(&k)
+ // This function is very hot; avoid defer.
+
+ // Perform our atomic check.
+ if err := check(t, addr, val); err != nil {
+ b.mu.Unlock()
+ w.key.release()
+ return err
+ }
+
+ // Add the waiter to the bucket.
+ b.waiters.PushBack(w)
+ w.bucket.Store(b)
+
+ b.mu.Unlock()
+ return nil
+}
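A minimal wait loop sketched from this contract (the Target and address are assumptions supplied by the caller): every successful WaitPrepare must be paired with WaitComplete, whether the wait ends in a wakeup, a timeout, or a signal.

```go
package futexnotes

import (
	"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
	"gvisor.dev/gvisor/pkg/usermem"
)

// waitFor blocks until a wakeup arrives on addr, assuming *addr still
// holds val at enqueue time; otherwise WaitPrepare fails with EAGAIN.
func waitFor(m *futex.Manager, t futex.Target, addr usermem.Addr, val uint32) error {
	w := futex.NewWaiter()
	if err := m.WaitPrepare(w, t, addr, true /* private */, val, ^uint32(0)); err != nil {
		return err
	}
	<-w.C             // Block until a matching Wake sends on w.C.
	m.WaitComplete(w) // Always pair with a successful WaitPrepare.
	return nil
}
```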
+
+// WaitComplete must be called when a Waiter previously added by WaitPrepare is
+// no longer eligible to be woken.
+func (m *Manager) WaitComplete(w *Waiter) {
+ // Remove w from the bucket it's in.
+ for {
+ b := w.bucket.Load()
+
+ // If b is nil, the waiter isn't in any bucket anymore. This can't be
+ // racy because the waiter can't be concurrently re-queued in another
+ // bucket.
+ if b == nil {
+ break
+ }
+
+ // Take the bucket lock. Note that without holding the bucket lock, the
+ // waiter is not guaranteed to stay in that bucket, so after we take
+ // the bucket lock, we must ensure that the bucket hasn't changed: if
+ // it happens to have changed, we release the old bucket lock and try
+ // again with the new bucket; if it hasn't changed, we know it won't
+ // change now because we hold the lock.
+ b.mu.Lock()
+ if b != w.bucket.Load() {
+ b.mu.Unlock()
+ continue
+ }
+
+ // Remove waiter from bucket.
+ b.waiters.Remove(w)
+ w.bucket.Store(nil)
+ b.mu.Unlock()
+ break
+ }
+
+ // Release references held by the waiter.
+ w.key.release()
+}
+
+// LockPI attempts to lock the futex following the priority-inheritance
+// futex rules. The lock is acquired only when 'addr' points to 0. The TID of
+// the calling task is written to 'addr' to indicate that the futex is owned.
+// It returns true if the futex was successfully acquired.
+//
+// FUTEX_OWNER_DIED is only set by Linux when robust lists are in use (see
+// exit_robust_list()). Since robust lists are not supported here, the bit is
+// handled below but never actually set.
+func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, private, try bool) (bool, error) {
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return false, err
+ }
+ // Ownership of k is transferred to w below.
+
+ // Prepare the Waiter before taking the bucket lock.
+ select {
+ case <-w.C:
+ default:
+ }
+ w.key = k
+ w.tid = tid
+
+ b := m.lockBucket(&k)
+ // Hot function: avoid defers.
+
+ success, err := m.lockPILocked(w, t, addr, tid, b, try)
+ if err != nil {
+ w.key.release()
+ b.mu.Unlock()
+ return false, err
+ }
+ if success || try {
+ // Release waiter if it's not going to be a wait.
+ w.key.release()
+ }
+ b.mu.Unlock()
+ return success, nil
+}
+
+func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint32, b *bucket, try bool) (bool, error) {
+ for {
+ cur, err := t.LoadUint32(addr)
+ if err != nil {
+ return false, err
+ }
+ if (cur & linux.FUTEX_TID_MASK) == tid {
+ return false, syserror.EDEADLK
+ }
+
+ if (cur & linux.FUTEX_TID_MASK) == 0 {
+ // No owner and no waiters, try to acquire the futex.
+
+ // Set TID and preserve owner died status.
+ val := tid
+ val |= cur & linux.FUTEX_OWNER_DIED
+ prev, err := t.CompareAndSwapUint32(addr, cur, val)
+ if err != nil {
+ return false, err
+ }
+ if prev != cur {
+ // CAS failed, retry...
+ // Linux reacquires the bucket lock on retries, which will re-lookup the
+ // mapping at the futex address. However, retrying while holding the
+ // lock is more efficient and reduces the chance of another conflict.
+ continue
+ }
+ // Futex acquired.
+ return true, nil
+ }
+
+ // Futex is already owned, prepare to wait.
+
+ if try {
+ // Caller doesn't want to wait.
+ return false, nil
+ }
+
+ // Set waiters bit if not set yet.
+ if cur&linux.FUTEX_WAITERS == 0 {
+ prev, err := t.CompareAndSwapUint32(addr, cur, cur|linux.FUTEX_WAITERS)
+ if err != nil {
+ return false, err
+ }
+ if prev != cur {
+ // CAS failed, retry...
+ continue
+ }
+ }
+
+ // Add the waiter to the bucket.
+ b.waiters.PushBack(w)
+ w.bucket.Store(b)
+ return false, nil
+ }
+}
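For reference, the PI lock word this loop manipulates has the layout Linux defines: bit 31 is FUTEX_WAITERS, bit 30 is FUTEX_OWNER_DIED, and the low 30 bits (FUTEX_TID_MASK) hold the owner's TID. A small decoding sketch using the same linux package constants:

```go
package futexnotes

import "gvisor.dev/gvisor/pkg/abi/linux"

// decodePIWord splits a PI futex word into its Linux-defined fields.
func decodePIWord(cur uint32) (tid uint32, waiters, ownerDied bool) {
	return cur & linux.FUTEX_TID_MASK,
		cur&linux.FUTEX_WAITERS != 0,
		cur&linux.FUTEX_OWNER_DIED != 0
}
```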
+
+// UnlockPI unlocks the futex following the priority-inheritance futex
+// rules. The address provided must contain the caller's TID. If there are
+// waiters, the TID of the next waiter (FIFO) is written to the address and
+// that waiter is woken up. If there are no waiters, 0 is written to the
+// address.
+func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return err
+ }
+ b := m.lockBucket(&k)
+
+ err = m.unlockPILocked(t, addr, tid, b, &k)
+
+ k.release()
+ b.mu.Unlock()
+ return err
+}
+
+func (m *Manager) unlockPILocked(t Target, addr usermem.Addr, tid uint32, b *bucket, key *Key) error {
+ cur, err := t.LoadUint32(addr)
+ if err != nil {
+ return err
+ }
+
+ if (cur & linux.FUTEX_TID_MASK) != tid {
+ return syserror.EPERM
+ }
+
+ var next *Waiter // Who's the next owner?
+ var next2 *Waiter // Who's the one after that?
+ for w := b.waiters.Front(); w != nil; w = w.Next() {
+ if !w.key.matches(key) {
+ continue
+ }
+
+ if next == nil {
+ next = w
+ } else {
+ next2 = w
+ break
+ }
+ }
+
+ if next == nil {
+ // It's safe to set 0 because there are no waiters, no new owner, and the
+ // executing task is the current owner (no owner died bit).
+ prev, err := t.CompareAndSwapUint32(addr, cur, 0)
+ if err != nil {
+ return err
+ }
+ if prev != cur {
+ // Let user mode handle CAS races. This is different from lock, which
+ // retries when the CAS fails.
+ return syserror.EAGAIN
+ }
+ return nil
+ }
+
+ // Set next owner's TID, waiters if there are any. Resets owner died bit, if
+ // set, because the executing task takes over as the owner.
+ val := next.tid
+ if next2 != nil {
+ val |= linux.FUTEX_WAITERS
+ }
+
+ prev, err := t.CompareAndSwapUint32(addr, cur, val)
+ if err != nil {
+ return err
+ }
+ if prev != cur {
+ return syserror.EINVAL
+ }
+
+ b.wakeWaiterLocked(next)
+ return nil
+}
diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go
new file mode 100644
index 000000000..7c5c7665b
--- /dev/null
+++ b/pkg/sentry/kernel/futex/futex_test.go
@@ -0,0 +1,530 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package futex
+
+import (
+ "math"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+ "testing"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// testData implements the Target interface, and allows us to
+// treat the address passed for futex operations as an index in
+// a byte slice for testing simplicity.
+type testData []byte
+
+const sizeofInt32 = 4
+
+func newTestData(size uint) testData {
+ return make([]byte, size)
+}
+
+func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+ val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new)
+ return val, nil
+}
+
+func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+ if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) {
+ return old, nil
+ }
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil
+}
+
+func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) {
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil
+}
+
+func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) {
+ return Key{
+ Kind: KindSharedMappable,
+ Offset: uint64(addr),
+ }, nil
+}
+
+func futexKind(private bool) string {
+ if private {
+ return "private"
+ }
+ return "shared"
+}
+
+func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) *Waiter {
+ w := NewWaiter()
+ if err := m.WaitPrepare(w, ta, addr, private, val, bitmask); err != nil {
+ t.Fatalf("WaitPrepare failed: %v", err)
+ }
+ return w
+}
+
+func TestFutexWake(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(sizeofInt32)
+
+ // Start waiting for wakeup.
+ w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w)
+
+ // Perform a wakeup.
+ if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 {
+ t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err)
+ }
+
+ // Expect the waiter to have been woken.
+ if !w.woken() {
+ t.Error("waiter not woken")
+ }
+ })
+ }
+}
+
+func TestFutexWakeBitmask(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(sizeofInt32)
+
+ // Start waiting for wakeup.
+ w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff)
+ defer m.WaitComplete(w)
+
+ // Perform a wakeup using the wrong bitmask.
+ if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 {
+ t.Errorf("Wake with non-matching bitmask: got (%d, %v), wanted (0, nil)", n, err)
+ }
+
+ // Expect the waiter to still be waiting.
+ if w.woken() {
+ t.Error("waiter woken unexpectedly")
+ }
+
+ // Perform a wakeup using the right bitmask.
+ if n, err := m.Wake(d, 0, private, 0x00000001, 1); err != nil || n != 1 {
+ t.Errorf("Wake with matching bitmask: got (%d, %v), wanted (1, nil)", n, err)
+ }
+
+ // Expect that the waiter was woken.
+ if !w.woken() {
+ t.Error("waiter not woken")
+ }
+ })
+ }
+}
+
+func TestFutexWakeTwo(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(sizeofInt32)
+
+ // Start three waiters waiting for wakeup.
+ var ws [3]*Waiter
+ for i := range ws {
+ ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(ws[i])
+ }
+
+ // Perform two wakeups.
+ const wakeups = 2
+ if n, err := m.Wake(d, 0, private, ^uint32(0), 2); err != nil || n != wakeups {
+ t.Errorf("Wake: got (%d, %v), wanted (%d, nil)", n, err, wakeups)
+ }
+
+ // Expect that exactly two waiters were woken.
+ // We don't get guarantees about exactly which two, though we
+ // expect them to be ws[0] and ws[1].
+ awake := 0
+ for i := range ws {
+ if ws[i].woken() {
+ awake++
+ }
+ }
+ if awake != wakeups {
+ t.Errorf("got %d woken waiters, wanted %d", awake, wakeups)
+ }
+ })
+ }
+}
+
+func TestFutexWakeUnrelated(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(2 * sizeofInt32)
+
+ // Start two waiters waiting for wakeup on different addresses.
+ w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w1)
+ w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w2)
+
+ // Perform two wakeups on the second address.
+ if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 {
+ t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err)
+ }
+
+ // Expect that only the second waiter was woken.
+ if w1.woken() {
+ t.Error("w1 woken unexpectedly")
+ }
+ if !w2.woken() {
+ t.Error("w2 not woken")
+ }
+ })
+ }
+}
+
+func TestWakeOpEmpty(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(2 * sizeofInt32)
+
+ // Perform wakeups with no waiters.
+ if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 0 {
+ t.Fatalf("WakeOp: got (%d, %v), wanted (0, nil)", n, err)
+ }
+ })
+ }
+}
+
+func TestWakeOpFirstNonEmpty(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add two waiters on address 0.
+ w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w1)
+ w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w2)
+
+ // Perform 10 wakeups on address 0.
+ if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 {
+ t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err)
+ }
+
+ // Expect that both waiters were woken.
+ if !w1.woken() {
+ t.Error("w1 not woken")
+ }
+ if !w2.woken() {
+ t.Error("w2 not woken")
+ }
+ })
+ }
+}
+
+func TestWakeOpSecondNonEmpty(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add two waiters on address sizeofInt32.
+ w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w1)
+ w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w2)
+
+ // Perform 10 wakeups on address sizeofInt32 (contingent on
+ // d.Op(0), which should succeed).
+ if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 0); err != nil || n != 2 {
+ t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err)
+ }
+
+ // Expect that both waiters were woken.
+ if !w1.woken() {
+ t.Error("w1 not woken")
+ }
+ if !w2.woken() {
+ t.Error("w2 not woken")
+ }
+ })
+ }
+}
+
+func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add two waiters on address sizeofInt32.
+ w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w1)
+ w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w2)
+
+ // Perform 10 wakeups on address sizeofInt32 (contingent on
+ // d.Op(1), which should fail).
+ if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 1); err != nil || n != 0 {
+ t.Errorf("WakeOp: got (%d, %v), wanted (0, nil)", n, err)
+ }
+
+ // Expect that neither waiter was woken.
+ if w1.woken() {
+ t.Error("w1 woken unexpectedly")
+ }
+ if w2.woken() {
+ t.Error("w2 woken unexpectedly")
+ }
+ })
+ }
+}
+
+func TestWakeOpAllNonEmpty(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add two waiters on address 0.
+ w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w1)
+ w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w2)
+
+ // Add two waiters on address sizeofInt32.
+ w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w3)
+ w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w4)
+
+ // Perform 10 wakeups on address 0 (unconditionally), and 10
+ // wakeups on address sizeofInt32 (contingent on d.Op(0), which
+ // should succeed).
+ if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 4 {
+ t.Errorf("WakeOp: got (%d, %v), wanted (4, nil)", n, err)
+ }
+
+ // Expect that all waiters were woken.
+ if !w1.woken() {
+ t.Error("w1 not woken")
+ }
+ if !w2.woken() {
+ t.Error("w2 not woken")
+ }
+ if !w3.woken() {
+ t.Error("w3 not woken")
+ }
+ if !w4.woken() {
+ t.Error("w4 not woken")
+ }
+ })
+ }
+}
+
+func TestWakeOpAllNonEmptyFailingOp(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add two waiters on address 0.
+ w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w1)
+ w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(w2)
+
+ // Add two waiters on address sizeofInt32.
+ w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w3)
+ w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
+ defer m.WaitComplete(w4)
+
+ // Perform 10 wakeups on address 0 (unconditionally), and 10
+ // wakeups on address sizeofInt32 (contingent on d.Op(1), which
+ // should fail).
+ if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 1); err != nil || n != 2 {
+ t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err)
+ }
+
+ // Expect that only the first two waiters were woken.
+ if !w1.woken() {
+ t.Error("w1 not woken")
+ }
+ if !w2.woken() {
+ t.Error("w2 not woken")
+ }
+ if w3.woken() {
+ t.Error("w3 woken unexpectedly")
+ }
+ if w4.woken() {
+ t.Error("w4 woken unexpectedly")
+ }
+ })
+ }
+}
+
+func TestWakeOpSameAddress(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add four waiters on address 0.
+ var ws [4]*Waiter
+ for i := range ws {
+ ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(ws[i])
+ }
+
+ // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup
+ // on address 0 (contingent on d.Op(0), which should succeed).
+ const wakeups = 2
+ if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 0); err != nil || n != wakeups {
+ t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups)
+ }
+
+ // Expect that exactly two waiters were woken.
+ awake := 0
+ for i := range ws {
+ if ws[i].woken() {
+ awake++
+ }
+ }
+ if awake != wakeups {
+ t.Errorf("got %d woken waiters, wanted %d", awake, wakeups)
+ }
+ })
+ }
+}
+
+func TestWakeOpSameAddressFailingOp(t *testing.T) {
+ for _, private := range []bool{false, true} {
+ t.Run(futexKind(private), func(t *testing.T) {
+ m := NewManager()
+ d := newTestData(8)
+
+ // Add four waiters on address 0.
+ var ws [4]*Waiter
+ for i := range ws {
+ ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
+ defer m.WaitComplete(ws[i])
+ }
+
+ // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup
+ // on address 0 (contingent on d.Op(1), which should fail).
+ const wakeups = 1
+ if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 1); err != nil || n != wakeups {
+ t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups)
+ }
+
+ // Expect that exactly one waiter was woken.
+ awake := 0
+ for i := range ws {
+ if ws[i].woken() {
+ awake++
+ }
+ }
+ if awake != wakeups {
+ t.Errorf("got %d woken waiters, wanted %d", awake, wakeups)
+ }
+ })
+ }
+}
+
+const (
+ testMutexSize = sizeofInt32
+ testMutexLocked uint32 = 1
+ testMutexUnlocked uint32 = 0
+)
+
+// testMutex ties together a testData slice, an address, and a
+// futex manager in order to implement the sync.Locker interface.
+// Beyond being used as a Locker, this is a simple mechanism for
+// changing the underlying values for simpler tests.
+type testMutex struct {
+ a usermem.Addr
+ d testData
+ m *Manager
+}
+
+func newTestMutex(addr usermem.Addr, d testData, m *Manager) *testMutex {
+ return &testMutex{a: addr, d: d, m: m}
+}
+
+// Lock acquires the testMutex.
+// This may wait for it to be available via the futex manager.
+func (t *testMutex) Lock() {
+ for {
+ // Attempt to grab the lock.
+ if atomic.CompareAndSwapUint32(
+ (*uint32)(unsafe.Pointer(&t.d[t.a])),
+ testMutexUnlocked,
+ testMutexLocked) {
+ // Lock held.
+ return
+ }
+
+ // Wait for it to be "not locked".
+ w := NewWaiter()
+ err := t.m.WaitPrepare(w, t.d, t.a, true, testMutexLocked, ^uint32(0))
+ if err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ // Should never happen.
+ panic("WaitPrepare returned unexpected error: " + err.Error())
+ }
+ <-w.C
+ t.m.WaitComplete(w)
+ }
+}
+
+// Unlock releases the testMutex.
+// This will notify any waiters via the futex manager.
+func (t *testMutex) Unlock() {
+ // Unlock.
+ atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked)
+
+ // Notify all waiters.
+ t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32)
+}
+
+// This function was shamelessly stolen from mutex_test.go.
+func HammerMutex(l sync.Locker, loops int, cdone chan bool) {
+ for i := 0; i < loops; i++ {
+ l.Lock()
+ runtime.Gosched()
+ l.Unlock()
+ }
+ cdone <- true
+}
+
+func TestMutexStress(t *testing.T) {
+ m := NewManager()
+ d := newTestData(testMutexSize)
+ tm := newTestMutex(0*testMutexSize, d, m)
+ c := make(chan bool)
+
+ for i := 0; i < 10; i++ {
+ go HammerMutex(tm, 1000, c)
+ }
+
+ for i := 0; i < 10; i++ {
+ <-c
+ }
+}
diff --git a/pkg/sentry/kernel/g3doc/run_states.dot b/pkg/sentry/kernel/g3doc/run_states.dot
new file mode 100644
index 000000000..7861fe1f5
--- /dev/null
+++ b/pkg/sentry/kernel/g3doc/run_states.dot
@@ -0,0 +1,99 @@
+digraph {
+ subgraph {
+ App;
+ }
+ subgraph {
+ Interrupt;
+ InterruptAfterSignalDeliveryStop;
+ }
+ subgraph {
+ Syscall;
+ SyscallAfterPtraceEventSeccomp;
+ SyscallEnter;
+ SyscallAfterSyscallEnterStop;
+ SyscallAfterSysemuStop;
+ SyscallInvoke;
+ SyscallAfterPtraceEventClone;
+ SyscallAfterExecStop;
+ SyscallAfterVforkStop;
+ SyscallReinvoke;
+ SyscallExit;
+ }
+ subgraph {
+ Vsyscall;
+ VsyscallAfterPtraceEventSeccomp;
+ VsyscallInvoke;
+ }
+ subgraph {
+ Exit;
+ ExitMain; // leave thread group, release resources, reparent children, kill PID namespace and wait if TGID 1
+ ExitNotify; // signal parent/tracer, become waitable
+ ExitDone; // represented by t.runState == nil
+ }
+
+ // Task exit
+ Exit -> ExitMain;
+ ExitMain -> ExitNotify;
+ ExitNotify -> ExitDone;
+
+ // Execution of untrusted application code
+ App -> App;
+
+ // Interrupts (usually signal delivery)
+ App -> Interrupt;
+ Interrupt -> Interrupt; // if other interrupt conditions may still apply
+ Interrupt -> Exit; // if killed
+
+ // Syscalls
+ App -> Syscall;
+ Syscall -> SyscallEnter;
+ SyscallEnter -> SyscallInvoke;
+ SyscallInvoke -> SyscallExit;
+ SyscallExit -> App;
+
+ // exit, exit_group
+ SyscallInvoke -> Exit;
+
+ // execve
+ SyscallInvoke -> SyscallAfterExecStop;
+ SyscallAfterExecStop -> SyscallExit;
+ SyscallAfterExecStop -> App; // fatal signal pending
+
+ // vfork
+ SyscallInvoke -> SyscallAfterVforkStop;
+ SyscallAfterVforkStop -> SyscallExit;
+
+ // Vsyscalls
+ App -> Vsyscall;
+ Vsyscall -> VsyscallInvoke;
+ Vsyscall -> App; // fault while reading return address from stack
+ VsyscallInvoke -> App;
+
+ // ptrace-specific branches
+ Interrupt -> InterruptAfterSignalDeliveryStop;
+ InterruptAfterSignalDeliveryStop -> Interrupt;
+ SyscallEnter -> SyscallAfterSyscallEnterStop;
+ SyscallAfterSyscallEnterStop -> SyscallInvoke;
+ SyscallAfterSyscallEnterStop -> SyscallExit; // skipped by tracer
+ SyscallAfterSyscallEnterStop -> App; // fatal signal pending
+ SyscallEnter -> SyscallAfterSysemuStop;
+ SyscallAfterSysemuStop -> SyscallExit;
+ SyscallAfterSysemuStop -> App; // fatal signal pending
+ SyscallInvoke -> SyscallAfterPtraceEventClone;
+ SyscallAfterPtraceEventClone -> SyscallExit;
+ SyscallAfterPtraceEventClone -> SyscallAfterVforkStop;
+
+ // seccomp
+ Syscall -> App; // SECCOMP_RET_TRAP, SECCOMP_RET_ERRNO, SECCOMP_RET_KILL, SECCOMP_RET_TRACE without tracer
+ Syscall -> SyscallAfterPtraceEventSeccomp; // SECCOMP_RET_TRACE
+ SyscallAfterPtraceEventSeccomp -> SyscallEnter;
+ SyscallAfterPtraceEventSeccomp -> SyscallExit; // skipped by tracer
+ SyscallAfterPtraceEventSeccomp -> App; // fatal signal pending
+ Vsyscall -> VsyscallAfterPtraceEventSeccomp;
+ VsyscallAfterPtraceEventSeccomp -> VsyscallInvoke;
+ VsyscallAfterPtraceEventSeccomp -> App;
+
+ // Autosave
+ SyscallInvoke -> SyscallReinvoke;
+ SyscallReinvoke -> SyscallInvoke;
+}
diff --git a/pkg/sentry/kernel/g3doc/run_states.png b/pkg/sentry/kernel/g3doc/run_states.png
new file mode 100644
index 000000000..b63b60f02
--- /dev/null
+++ b/pkg/sentry/kernel/g3doc/run_states.png
Binary files differ
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
new file mode 100644
index 000000000..80a070d7e
--- /dev/null
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/shm"
+)
+
+// IPCNamespace represents an IPC namespace.
+//
+// +stateify savable
+type IPCNamespace struct {
+ // User namespace which owns this IPC namespace. Immutable.
+ userNS *auth.UserNamespace
+
+ semaphores *semaphore.Registry
+ shms *shm.Registry
+}
+
+// NewIPCNamespace creates a new IPC namespace.
+func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace {
+ return &IPCNamespace{
+ userNS: userNS,
+ semaphores: semaphore.NewRegistry(userNS),
+ shms: shm.NewRegistry(userNS),
+ }
+}
+
+// SemaphoreRegistry returns the semaphore set registry for this namespace.
+func (i *IPCNamespace) SemaphoreRegistry() *semaphore.Registry {
+ return i.semaphores
+}
+
+// ShmRegistry returns the shm segment registry for this namespace.
+func (i *IPCNamespace) ShmRegistry() *shm.Registry {
+ return i.shms
+}
+
+// IPCNamespace returns the task's IPC namespace.
+func (t *Task) IPCNamespace() *IPCNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.ipcns
+}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
new file mode 100644
index 000000000..2177b785a
--- /dev/null
+++ b/pkg/sentry/kernel/kernel.go
@@ -0,0 +1,1682 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kernel provides an emulation of the Linux kernel.
+//
+// See README.md for a detailed overview.
+//
+// Lock order (outermost locks must be taken first):
+//
+// Kernel.extMu
+// ThreadGroup.timerMu
+// ktime.Timer.mu (for kernelCPUClockTicker and IntervalTimer)
+// TaskSet.mu
+// SignalHandlers.mu
+// Task.mu
+// runningTasksMu
+//
+// Locking SignalHandlers.mu in multiple SignalHandlers requires locking
+// TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same
+// time requires locking all of their signal mutexes first.
+package kernel
+
+import (
+ "errors"
+ "fmt"
+ "path/filepath"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ oldtimerfd "gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
+ sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// VFS2Enabled is set to true when VFS2 is enabled. It is a global to allow
+// easy access everywhere. To be removed once VFS2 becomes the default.
+var VFS2Enabled = false
+
+// Kernel represents an emulated Linux kernel. It must be initialized by calling
+// Init() or LoadFrom().
+//
+// +stateify savable
+type Kernel struct {
+ // extMu serializes external changes to the Kernel with calls to
+ // Kernel.SaveTo. (Kernel.SaveTo requires that the state of the Kernel
+ // remains frozen for the duration of the call; it requires that the Kernel
+ // is paused as a precondition, which ensures that none of the tasks
+ // running within the Kernel can affect its state, but extMu is required to
+ // ensure that concurrent users of the Kernel *outside* the Kernel's
+ // control cannot affect its state by calling e.g.
+ // Kernel.SendExternalSignal.)
+ extMu sync.Mutex `state:"nosave"`
+
+ // started is true if Start has been called. Unless otherwise specified,
+ // all Kernel fields become immutable once started becomes true.
+ started bool `state:"nosave"`
+
+ // All of the following fields are immutable unless otherwise specified.
+
+ // Platform is the platform that is used to execute tasks in the created
+ // Kernel. See comment on pgalloc.MemoryFileProvider for why Platform is
+ // embedded anonymously (the same issue applies).
+ platform.Platform `state:"nosave"`
+
+ // mf provides application memory.
+ mf *pgalloc.MemoryFile `state:"nosave"`
+
+ // See InitKernelArgs for the meaning of these fields.
+ featureSet *cpuid.FeatureSet
+ timekeeper *Timekeeper
+ tasks *TaskSet
+ rootUserNamespace *auth.UserNamespace
+ rootNetworkNamespace *inet.Namespace
+ applicationCores uint
+ useHostCores bool
+ extraAuxv []arch.AuxEntry
+ vdso *loader.VDSO
+ rootUTSNamespace *UTSNamespace
+ rootIPCNamespace *IPCNamespace
+ rootAbstractSocketNamespace *AbstractSocketNamespace
+
+ // futexes is the "root" futex.Manager, from which all others are forked.
+ // This is necessary to ensure that shared futexes are coherent across all
+ // tasks, including those created by CreateProcess.
+ futexes *futex.Manager
+
+ // globalInit is the thread group whose leader has ID 1 in the root PID
+ // namespace. globalInit is stored separately so that it is accessible even
+ // after all tasks in the thread group have exited, such that ID 1 is no
+ // longer mapped.
+ //
+ // globalInit is mutable until it is assigned by the first successful call
+ // to CreateProcess, and is protected by extMu.
+ globalInit *ThreadGroup
+
+ // realtimeClock is a ktime.Clock based on timekeeper's Realtime.
+ realtimeClock *timekeeperClock
+
+ // monotonicClock is a ktime.Clock based on timekeeper's Monotonic.
+ monotonicClock *timekeeperClock
+
+ // syslog is the kernel log.
+ syslog syslog
+
+ // runningTasksMu synchronizes disable/enable of cpuClockTicker when
+ // the kernel is idle (runningTasks == 0).
+ //
+ // runningTasksMu is used to exclude critical sections when the timer
+ // disables itself and when the first active task enables the timer,
+ // ensuring that tasks always see a valid cpuClock value.
+ runningTasksMu sync.Mutex `state:"nosave"`
+
+ // runningTasks is the total count of tasks currently in
+ // TaskGoroutineRunningSys or TaskGoroutineRunningApp, i.e. tasks that
+ // are not blocked or stopped.
+ //
+ // runningTasks must be accessed atomically. Increments from 0 to 1 are
+ // further protected by runningTasksMu (see incRunningTasks).
+ runningTasks int64
+
+ // cpuClock is incremented every linux.ClockTick. cpuClock is used to
+ // measure task CPU usage, since sampling monotonicClock twice on every
+ // syscall turns out to be unreasonably expensive. This is similar to how
+ // Linux does task CPU accounting on x86 (CONFIG_IRQ_TIME_ACCOUNTING),
+ // although Linux also uses scheduler timing information to improve
+ // resolution (kernel/sched/cputime.c:cputime_adjust()), which we can't do
+ // since "preemptive" scheduling is managed by the Go runtime, which
+ // doesn't provide this information.
+ //
+ // cpuClock is mutable, and is accessed using atomic memory operations.
+ cpuClock uint64
+
+ // cpuClockTicker increments cpuClock.
+ cpuClockTicker *ktime.Timer `state:"nosave"`
+
+ // cpuClockTickerDisabled indicates that cpuClockTicker has been
+ // disabled because no tasks are running.
+ //
+ // cpuClockTickerDisabled is protected by runningTasksMu.
+ cpuClockTickerDisabled bool
+
+ // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the
+ // point it was disabled. It is cached here to avoid a lock ordering
+ // violation with cpuClockTicker.mu when runningTasksMu is held.
+ //
+ // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is
+ // true.
+ //
+ // cpuClockTickerSetting is protected by runningTasksMu.
+ cpuClockTickerSetting ktime.Setting
+
+ // uniqueID is used to generate unique identifiers.
+ //
+ // uniqueID is mutable, and is accessed using atomic memory operations.
+ uniqueID uint64
+
+ // nextInotifyCookie is a monotonically increasing counter used for
+ // generating unique inotify event cookies.
+ //
+ // nextInotifyCookie is mutable, and is accessed using atomic memory
+ // operations.
+ nextInotifyCookie uint32
+
+ // netlinkPorts manages allocation of netlink socket port IDs.
+ netlinkPorts *port.Manager
+
+ // saveErr is the error causing the sandbox to exit during save, if
+ // any. It is protected by extMu.
+ saveErr error `state:"nosave"`
+
+ // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
+ danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
+
+ // sockets is the list of all network sockets in the system.
+ // Protected by extMu.
+ sockets socketList
+
+ // nextSocketEntry is the next entry number to use in sockets. Protected
+ // by extMu.
+ nextSocketEntry uint64
+
+ // deviceRegistry is used to save/restore device.SimpleDevices.
+ deviceRegistry struct{} `state:".(*device.Registry)"`
+
+ // DirentCacheLimiter controls the total number of dirent entries that can
+ // be held in caches. Not all caches use it; only the caches that use host
+ // resources use the limiter. It may be nil if disabled.
+ DirentCacheLimiter *fs.DirentCacheLimiter
+
+ // unimplementedSyscallEmitterOnce is used in the initialization of
+ // unimplementedSyscallEmitter.
+ unimplementedSyscallEmitterOnce sync.Once `state:"nosave"`
+
+ // unimplementedSyscallEmitter is used to emit unimplemented syscall
+ // events. This is initialized lazily on the first unimplemented
+ // syscall.
+ unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"`
+
+ // SpecialOpts contains special kernel options.
+ SpecialOpts
+
+ // VFS keeps the filesystem state used across the kernel.
+ vfs vfs.VirtualFilesystem
+
+ // hostMount is the Mount used for file descriptors that were imported
+ // from the host.
+ hostMount *vfs.Mount
+
+ // pipeMount is the Mount used for pipes created by the pipe() and pipe2()
+ // syscalls (as opposed to named pipes created by mknod()).
+ pipeMount *vfs.Mount
+
+ // shmMount is the Mount used for anonymous files created by the
+ // memfd_create() syscall. It is analogous to Linux's shm_mnt.
+ shmMount *vfs.Mount
+
+ // socketMount is the Mount used for sockets created by the socket() and
+ // socketpair() syscalls. There are several cases where a socket dentry will
+ // not be contained in socketMount:
+ // 1. Socket files created by mknod()
+ // 2. Socket fds imported from the host (Kernel.hostMount is used for these)
+ // 3. Socket files created by binding Unix sockets to a file path
+ socketMount *vfs.Mount
+
+ // If set to true, report address space activation waits as if the task is in
+ // external wait so that the watchdog doesn't report the task stuck.
+ SleepForAddressSpaceActivation bool
+}
+
+// InitKernelArgs holds arguments to Init.
+type InitKernelArgs struct {
+ // FeatureSet is the emulated CPU feature set.
+ FeatureSet *cpuid.FeatureSet
+
+ // Timekeeper manages time for all tasks in the system.
+ Timekeeper *Timekeeper
+
+ // RootUserNamespace is the root user namespace.
+ RootUserNamespace *auth.UserNamespace
+
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
+
+ // ApplicationCores is the number of logical CPUs visible to sandboxed
+ // applications. The set of logical CPU IDs is [0, ApplicationCores); thus
+ // ApplicationCores is analogous to Linux's nr_cpu_ids, the index of the
+ // most significant bit in cpu_possible_mask + 1.
+ ApplicationCores uint
+
+ // If UseHostCores is true, Task.CPU() returns the task goroutine's CPU
+ // instead of a virtualized CPU number, and Task.CopyToCPUMask() is a
+ // no-op. If ApplicationCores is less than hostcpu.MaxPossibleCPU(), it
+ // will be overridden.
+ UseHostCores bool
+
+ // ExtraAuxv contains additional auxiliary vector entries that are added to
+ // each process by the ELF loader.
+ ExtraAuxv []arch.AuxEntry
+
+ // Vdso holds the VDSO and its parameter page.
+ Vdso *loader.VDSO
+
+ // RootUTSNamespace is the root UTS namespace.
+ RootUTSNamespace *UTSNamespace
+
+ // RootIPCNamespace is the root IPC namespace.
+ RootIPCNamespace *IPCNamespace
+
+ // RootAbstractSocketNamespace is the root Abstract Socket namespace.
+ RootAbstractSocketNamespace *AbstractSocketNamespace
+
+ // PIDNamespace is the root PID namespace.
+ PIDNamespace *PIDNamespace
+}
+
+// Init initializes the Kernel with no tasks.
+//
+// Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile
+// before calling Init.
+func (k *Kernel) Init(args InitKernelArgs) error {
+ if args.FeatureSet == nil {
+ return fmt.Errorf("FeatureSet is nil")
+ }
+ if args.Timekeeper == nil {
+ return fmt.Errorf("Timekeeper is nil")
+ }
+ if args.Timekeeper.clocks == nil {
+ return fmt.Errorf("Must call Timekeeper.SetClocks() before Kernel.Init()")
+ }
+ if args.RootUserNamespace == nil {
+ return fmt.Errorf("RootUserNamespace is nil")
+ }
+ if args.ApplicationCores == 0 {
+ return fmt.Errorf("ApplicationCores is 0")
+ }
+
+ k.featureSet = args.FeatureSet
+ k.timekeeper = args.Timekeeper
+ k.tasks = newTaskSet(args.PIDNamespace)
+ k.rootUserNamespace = args.RootUserNamespace
+ k.rootUTSNamespace = args.RootUTSNamespace
+ k.rootIPCNamespace = args.RootIPCNamespace
+ k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
+ k.applicationCores = args.ApplicationCores
+ if args.UseHostCores {
+ k.useHostCores = true
+ maxCPU, err := hostcpu.MaxPossibleCPU()
+ if err != nil {
+ return fmt.Errorf("Failed to get maximum CPU number: %v", err)
+ }
+ minAppCores := uint(maxCPU) + 1
+ if k.applicationCores < minAppCores {
+ log.Infof("UseHostCores enabled: increasing ApplicationCores from %d to %d", k.applicationCores, minAppCores)
+ k.applicationCores = minAppCores
+ }
+ }
+ k.extraAuxv = args.ExtraAuxv
+ k.vdso = args.Vdso
+ k.realtimeClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Realtime}
+ k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
+ k.futexes = futex.NewManager()
+ k.netlinkPorts = port.New()
+
+ if VFS2Enabled {
+ if err := k.vfs.Init(); err != nil {
+ return fmt.Errorf("failed to initialize VFS: %v", err)
+ }
+
+ pipeFilesystem, err := pipefs.NewFilesystem(&k.vfs)
+ if err != nil {
+ return fmt.Errorf("failed to create pipefs filesystem: %v", err)
+ }
+ defer pipeFilesystem.DecRef()
+ pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create pipefs mount: %v", err)
+ }
+ k.pipeMount = pipeMount
+
+ tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(k.SupervisorContext(), &k.vfs, auth.NewRootCredentials(k.rootUserNamespace))
+ if err != nil {
+ return fmt.Errorf("failed to create tmpfs filesystem: %v", err)
+ }
+ defer tmpfsFilesystem.DecRef()
+ defer tmpfsRoot.DecRef()
+ shmMount, err := k.vfs.NewDisconnectedMount(tmpfsFilesystem, tmpfsRoot, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create tmpfs mount: %v", err)
+ }
+ k.shmMount = shmMount
+
+ socketFilesystem, err := sockfs.NewFilesystem(&k.vfs)
+ if err != nil {
+ return fmt.Errorf("failed to create sockfs filesystem: %v", err)
+ }
+ defer socketFilesystem.DecRef()
+ socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create sockfs mount: %v", err)
+ }
+ k.socketMount = socketMount
+ }
+
+ return nil
+}
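A minimal sketch of the call ordering the doc comment requires, assuming the caller already holds a Platform, a MemoryFile, and fully populated InitKernelArgs (all placeholders here, not values this file provides):

```go
package kernelnotes

import (
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	"gvisor.dev/gvisor/pkg/sentry/pgalloc"
	"gvisor.dev/gvisor/pkg/sentry/platform"
)

// newKernel shows only the ordering constraint: Platform and the memory
// file must be installed before Init runs.
func newKernel(p platform.Platform, mf *pgalloc.MemoryFile, args kernel.InitKernelArgs) (*kernel.Kernel, error) {
	k := &kernel.Kernel{}
	k.Platform = p      // Must be set manually, per the doc comment above.
	k.SetMemoryFile(mf) // Likewise.
	if err := k.Init(args); err != nil {
		return nil, err
	}
	return k, nil
}
```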
+
+// SaveTo saves the state of k to w.
+//
+// Preconditions: The kernel must be paused throughout the call to SaveTo.
+func (k *Kernel) SaveTo(w wire.Writer) error {
+ saveStart := time.Now()
+ ctx := k.SupervisorContext()
+
+ // Do not allow other Kernel methods to affect it while it's being saved.
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+
+ // Stop time.
+ k.pauseTimeLocked()
+ defer k.resumeTimeLocked()
+
+ // Evict all evictable MemoryFile allocations.
+ k.mf.StartEvictions()
+ k.mf.WaitForEvictions()
+
+ // Flush write operations on open files so data reaches backing storage.
+ // This must come after MemoryFile eviction since eviction may cause file
+ // writes.
+ if err := k.tasks.flushWritesToFiles(ctx); err != nil {
+ return err
+ }
+
+ // Remove all epoll waiter objects from underlying wait queues.
+ // NOTE: for programs to resume execution in future snapshot scenarios,
+ // we will need to re-establish these waiter objects after saving.
+ k.tasks.unregisterEpollWaiters()
+
+ // Clear the dirent cache before saving because Dirents must be Loaded in a
+ // particular order (parents before children), and Loading dirents from a cache
+ // breaks that order.
+ if err := k.flushMountSourceRefs(); err != nil {
+ return err
+ }
+
+ // Ensure that all inode and mount release operations have completed.
+ fs.AsyncBarrier()
+
+ // Once all fs work has completed (flushed references have all been released),
+ // reset mount mappings. This allows individual mounts to save how inodes map
+ // to filesystem resources. Without this, fs.Inodes cannot be restored.
+ fs.SaveInodeMappings()
+
+ // Discard unsavable mappings, such as those for host file descriptors.
+ // This must be done after waiting for "asynchronous fs work", which
+ // includes async I/O that may touch application memory.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
+
+ // Save the CPUID FeatureSet before the rest of the kernel so we can
+ // verify its compatibility on restore before attempting to restore the
+ // entire kernel, which may fail on an incompatible machine.
+ //
+ // N.B. This will also be saved along with the full kernel save below.
+ cpuidStart := time.Now()
+ if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
+ return err
+ }
+ log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
+
+ // Save the kernel state.
+ kernelStart := time.Now()
+ stats, err := state.Save(k.SupervisorContext(), w, k)
+ if err != nil {
+ return err
+ }
+ log.Infof("Kernel save stats: %s", stats.String())
+ log.Infof("Kernel save took [%s].", time.Since(kernelStart))
+
+ // Save the memory file's state.
+ memoryStart := time.Now()
+ if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
+ return err
+ }
+ log.Infof("Memory save took [%s].", time.Since(memoryStart))
+
+ log.Infof("Overall save took [%s].", time.Since(saveStart))
+
+ return nil
+}
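SaveTo pairs with LoadFrom (below) to form the checkpoint/restore cycle. A sketch of that pairing, assuming the kernel's Pause/Unpause pair is used to satisfy SaveTo's precondition and that the caller supplies the wire.Writer/wire.Reader plumbing:

```go
package kernelnotes

import (
	"gvisor.dev/gvisor/pkg/sentry/inet"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
	"gvisor.dev/gvisor/pkg/state/wire"
)

// checkpoint saves k's state to w; the kernel must stay paused for the
// duration of SaveTo, which Pause/Unpause arrange here.
func checkpoint(k *kernel.Kernel, w wire.Writer) error {
	k.Pause()
	defer k.Unpause()
	return k.SaveTo(w)
}

// restore loads previously saved state into a freshly initialized kernel.
func restore(k *kernel.Kernel, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
	return k.LoadFrom(r, net, clocks)
}
```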
+
+// flushMountSourceRefs flushes the MountSources for all mounted filesystems
+// and open FDs.
+func (k *Kernel) flushMountSourceRefs() error {
+ // Flush all mount sources for currently mounted filesystems in each task.
+ flushed := make(map[*fs.MountNamespace]struct{})
+ k.tasks.mu.RLock()
+ k.tasks.forEachThreadGroupLocked(func(tg *ThreadGroup) {
+ if _, ok := flushed[tg.mounts]; ok {
+ // Already flushed.
+ return
+ }
+ tg.mounts.FlushMountSourceRefs()
+ flushed[tg.mounts] = struct{}{}
+ })
+ k.tasks.mu.RUnlock()
+
+ // There may be some open FDs whose filesystems have been unmounted. We
+ // must flush those as well.
+ return k.tasks.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error {
+ file.Dirent.Inode.MountSource.FlushDirentRefs()
+ return nil
+ })
+}
+
+// forEachFDPaused applies the given function to each open file descriptor in
+// each task.
+//
+// Precondition: Must be called with the kernel paused.
+func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return nil
+ }
+
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ for t := range ts.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if t.fdTable == nil {
+ continue
+ }
+ t.fdTable.forEach(func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) {
+ if lastErr := f(file, fileVFS2); lastErr != nil && err == nil {
+ err = lastErr
+ }
+ })
+ }
+ return err
+}
+
+func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error {
+ if flags := file.Flags(); !flags.Write {
+ return nil
+ }
+ if sattr := file.Dirent.Inode.StableAttr; !fs.IsFile(sattr) && !fs.IsDir(sattr) {
+ return nil
+ }
+ // Here we need all metadata synced.
+ syncErr := file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll)
+ if err := fs.SaveFileFsyncError(syncErr); err != nil {
+ name, _ := file.Dirent.FullName(nil /* root */)
+ // Wrap this error in ErrSaveRejection so that it will trigger a save
+ // error, rather than a panic. This also allows us to distinguish Fsync
+ // errors from state file errors in state.Save.
+ return fs.ErrSaveRejection{
+ Err: fmt.Errorf("%q was not sufficiently synced: %v", name, err),
+ }
+ }
+ return nil
+ })
+}
+
+// Preconditions: The kernel must be paused.
+func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
+ invalidated := make(map[*mm.MemoryManager]struct{})
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t := range k.tasks.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if mm := t.tc.MemoryManager; mm != nil {
+ if _, ok := invalidated[mm]; !ok {
+ if err := mm.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ invalidated[mm] = struct{}{}
+ }
+ }
+ // I really wish we just had a sync.Map of all MMs...
+ if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
+ if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (ts *TaskSet) unregisterEpollWaiters() {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return
+ }
+
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+
+ // Tasks that belong to the same process could potentially point to the
+ // same FDTable. So we retain a map of processed ones to avoid
+ // processing the same FDTable multiple times.
+ processed := make(map[*FDTable]struct{})
+ for t := range ts.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if t.fdTable == nil {
+ continue
+ }
+ if _, ok := processed[t.fdTable]; ok {
+ continue
+ }
+ t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
+ e.UnregisterEpollWaiters()
+ }
+ })
+ processed[t.fdTable] = struct{}{}
+ }
+}
+
+// LoadFrom restores the state of k from r.
+func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+ loadStart := time.Now()
+
+ initAppCores := k.applicationCores
+
+ // Load the pre-saved CPUID FeatureSet.
+ //
+ // N.B. This was also saved along with the full kernel below, so we
+ // don't need to explicitly install it in the Kernel.
+ cpuidStart := time.Now()
+ var features cpuid.FeatureSet
+ if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
+ return err
+ }
+ log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
+
+ // Verify that the FeatureSet is usable on this host. We do this before
+ // Kernel load so that the explicit CPUID mismatch error has priority
+ // over floating point state restore errors that may occur on load on
+ // an incompatible machine.
+ if err := features.CheckHostCompatible(); err != nil {
+ return err
+ }
+
+ // Load the kernel state.
+ kernelStart := time.Now()
+ stats, err := state.Load(k.SupervisorContext(), r, k)
+ if err != nil {
+ return err
+ }
+ log.Infof("Kernel load stats: %s", stats.String())
+ log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
+ // Load the memory file's state.
+ memoryStart := time.Now()
+ if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
+ return err
+ }
+ log.Infof("Memory load took [%s].", time.Since(memoryStart))
+
+ log.Infof("Overall load took [%s]", time.Since(loadStart))
+
+ k.Timekeeper().SetClocks(clocks)
+ if net != nil {
+ net.Resume()
+ }
+
+ // Ensure that all pending asynchronous work is complete:
+ // - namedpipe opening
+ // - inode file opening
+ if err := fs.AsyncErrorBarrier(); err != nil {
+ return err
+ }
+
+ tcpip.AsyncLoading.Wait()
+
+ log.Infof("Overall load took [%s] after async work", time.Since(loadStart))
+
+ // Applications may size per-cpu structures based on k.applicationCores, so
+ // it can't change across save/restore. When we are virtualizing CPU
+ // numbers, this isn't a problem. However, when we are exposing host CPU
+ // assignments, we can't tolerate an increase in the number of host CPUs,
+ // which could result in getcpu(2) returning CPUs that applications expect
+ // not to exist.
+ if k.useHostCores && initAppCores > k.applicationCores {
+ return fmt.Errorf("UseHostCores enabled: can't increase ApplicationCores from %d to %d after restore", k.applicationCores, initAppCores)
+ }
+
+ return nil
+}
+
+// UniqueID returns a unique identifier.
+func (k *Kernel) UniqueID() uint64 {
+ id := atomic.AddUint64(&k.uniqueID, 1)
+ if id == 0 {
+ panic("unique identifier generator wrapped around")
+ }
+ return id
+}
+
+// CreateProcessArgs holds arguments to kernel.CreateProcess.
+type CreateProcessArgs struct {
+ // Filename is the filename to load as the init binary.
+ //
+ // If this is provided as "", File will be checked, then the file will be
+ // guessed via Argv[0].
+ Filename string
+
+ // File is a passed host FD pointing to a file to load as the init binary.
+ //
+ // This is checked if and only if Filename is "".
+ File fsbridge.File
+
+ // Argv is a list of arguments.
+ Argv []string
+
+ // Envv is a list of environment variables.
+ Envv []string
+
+ // WorkingDirectory is the initial working directory.
+ //
+ // This defaults to the root if empty.
+ WorkingDirectory string
+
+ // Credentials is the initial credentials.
+ Credentials *auth.Credentials
+
+ // FDTable is the initial set of file descriptors. If CreateProcess succeeds,
+ // it takes a reference on FDTable.
+ FDTable *FDTable
+
+ // Umask is the initial umask.
+ Umask uint
+
+ // Limits is the initial resource limits.
+ Limits *limits.LimitSet
+
+ // MaxSymlinkTraversals is the maximum number of symlinks to follow
+ // during resolution.
+ MaxSymlinkTraversals uint
+
+ // UTSNamespace is the initial UTS namespace.
+ UTSNamespace *UTSNamespace
+
+ // IPCNamespace is the initial IPC namespace.
+ IPCNamespace *IPCNamespace
+
+ // PIDNamespace is the initial PID Namespace.
+ PIDNamespace *PIDNamespace
+
+ // AbstractSocketNamespace is the initial Abstract Socket namespace.
+ AbstractSocketNamespace *AbstractSocketNamespace
+
+ // MountNamespace optionally contains the mount namespace for this
+ // process. If nil, the init process's mount namespace is used.
+ //
+ // Anyone setting MountNamespace must donate a reference (i.e.
+ // increment it).
+ MountNamespace *fs.MountNamespace
+
+ // MountNamespaceVFS2 optionally contains the mount namespace for this
+ // process. If nil, the init process's mount namespace is used.
+ //
+ // Anyone setting MountNamespaceVFS2 must donate a reference (i.e.
+ // increment it).
+ MountNamespaceVFS2 *vfs.MountNamespace
+
+ // ContainerID is the container that the process belongs to.
+ ContainerID string
+}
+
+// NewContext returns a context.Context that represents the task that will be
+// created by args.
+func (args *CreateProcessArgs) NewContext(k *Kernel) *createProcessContext {
+ return &createProcessContext{
+ Logger: log.Log(),
+ k: k,
+ args: args,
+ }
+}
+
+// createProcessContext is a context.Context that represents the context
+// associated with a task that is being created.
+type createProcessContext struct {
+ context.NoopSleeper
+ log.Logger
+ k *Kernel
+ args *CreateProcessArgs
+}
+
+// Value implements context.Context.Value.
+func (ctx *createProcessContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxKernel:
+ return ctx.k
+ case CtxPIDNamespace:
+ return ctx.args.PIDNamespace
+ case CtxUTSNamespace:
+ return ctx.args.UTSNamespace
+ case CtxIPCNamespace:
+ return ctx.args.IPCNamespace
+ case auth.CtxCredentials:
+ return ctx.args.Credentials
+ case fs.CtxRoot:
+ if ctx.args.MountNamespace != nil {
+ // MountNamespace.Root() will take a reference on the root dirent for us.
+ return ctx.args.MountNamespace.Root()
+ }
+ return nil
+ case vfs.CtxRoot:
+ if ctx.args.MountNamespaceVFS2 == nil {
+ return nil
+ }
+ // MountNamespaceVFS2.Root() takes a reference on the root dirent for us.
+ return ctx.args.MountNamespaceVFS2.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2 takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ case fs.CtxDirentCacheLimiter:
+ return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
+ case ktime.CtxRealtimeClock:
+ return ctx.k.RealtimeClock()
+ case limits.CtxLimits:
+ return ctx.args.Limits
+ case pgalloc.CtxMemoryFile:
+ return ctx.k.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return ctx.k
+ case platform.CtxPlatform:
+ return ctx.k
+ case uniqueid.CtxGlobalUniqueID:
+ return ctx.k.UniqueID()
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return ctx.k
+ case uniqueid.CtxInotifyCookie:
+ return ctx.k.GenerateInotifyCookie()
+ case unimpl.CtxEvents:
+ return ctx.k
+ default:
+ return nil
+ }
+}
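Code running under this context can recover kernel objects through the usual context keys; a small sketch (error handling elided, and assuming the context came from CreateProcessArgs.NewContext):

```go
package kernelnotes

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
)

// kernelFromCtx recovers the Kernel from a context that answers
// CtxKernel; it returns nil for any other context.
func kernelFromCtx(ctx context.Context) *kernel.Kernel {
	k, _ := ctx.Value(kernel.CtxKernel).(*kernel.Kernel)
	return k
}
```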
+
+// CreateProcess creates a new task in a new thread group with the given
+// options. The new task has no parent and is in the root PID namespace.
+//
+// If k.Start() has already been called, then the created process must be
+// started by calling kernel.StartProcess(tg).
+//
+// If k.Start() has not yet been called, then the created task will begin
+// running when k.Start() is called.
+//
+// CreateProcess has no analogue in Linux; it is used to create the initial
+// application task, as well as processes started by the control server.
+func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ log.Infof("EXEC: %v", args.Argv)
+
+ ctx := args.NewContext(k)
+
+ var (
+ opener fsbridge.Lookup
+ fsContext *FSContext
+ mntns *fs.MountNamespace
+ )
+
+ if VFS2Enabled {
+ mntnsVFS2 := args.MountNamespaceVFS2
+ if mntnsVFS2 == nil {
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2()
+ }
+ // Get the root directory from the MountNamespace.
+ root := mntnsVFS2.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef()
+
+ // Grab the working directory.
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: wd,
+ Path: fspath.Parse(args.WorkingDirectory),
+ FollowFinalSymlink: true,
+ }
+ var err error
+ wd, err = k.VFS().GetDentryAt(ctx, args.Credentials, &pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+ opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd)
+ fsContext = NewFSContextVFS2(root, wd, args.Umask)
+
+ } else {
+ mntns = args.MountNamespace
+ if mntns == nil {
+ mntns = k.globalInit.Leader().MountNamespace()
+ mntns.IncRef()
+ }
+ // Get the root directory from the MountNamespace.
+ root := mntns.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef()
+
+ // Grab the working directory.
+ remainingTraversals := args.MaxSymlinkTraversals
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ var err error
+ wd, err = mntns.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+ opener = fsbridge.NewFSLookup(mntns, root, wd)
+ fsContext = newFSContext(root, wd, args.Umask)
+ }
+
+ tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
+
+ // Check which file to start from.
+ switch {
+ case args.Filename != "":
+ // If a filename is given, take that.
+ // Set File to nil so we resolve the path in LoadTaskImage.
+ args.File = nil
+ case args.File != nil:
+ // If File is set, take the File provided directly.
+ default:
+ // Otherwise look at Argv and see if the first argument is a valid path.
+ if len(args.Argv) == 0 {
+ return nil, 0, fmt.Errorf("no filename or command provided")
+ }
+ if !filepath.IsAbs(args.Argv[0]) {
+ return nil, 0, fmt.Errorf("'%s' is not an absolute path", args.Argv[0])
+ }
+ args.Filename = args.Argv[0]
+ }
+
+ // Create a fresh task context.
+ remainingTraversals := args.MaxSymlinkTraversals
+ loadArgs := loader.LoadArgs{
+ Opener: opener,
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: true,
+ Filename: args.Filename,
+ File: args.File,
+ CloseOnExec: false,
+ Argv: args.Argv,
+ Envv: args.Envv,
+ Features: k.featureSet,
+ }
+
+ tc, se := k.LoadTaskImage(ctx, loadArgs)
+ if se != nil {
+ return nil, 0, errors.New(se.String())
+ }
+
+ // Take a reference on the FDTable, which will be transferred to
+ // TaskSet.NewTask().
+ args.FDTable.IncRef()
+
+ // Create the task.
+ config := &TaskConfig{
+ Kernel: k,
+ ThreadGroup: tg,
+ TaskContext: tc,
+ FSContext: fsContext,
+ FDTable: args.FDTable,
+ Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
+ AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
+ UTSNamespace: args.UTSNamespace,
+ IPCNamespace: args.IPCNamespace,
+ AbstractSocketNamespace: args.AbstractSocketNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
+ ContainerID: args.ContainerID,
+ }
+ t, err := k.tasks.NewTask(config)
+ if err != nil {
+ return nil, 0, err
+ }
+ t.traceExecEvent(tc) // Simulate exec for tracing.
+
+ // Success.
+ tgid := k.tasks.Root.IDOfThreadGroup(tg)
+ if k.globalInit == nil {
+ k.globalInit = tg
+ }
+ return tg, tgid, nil
+}
+
+// StartProcess starts running a process that was created with CreateProcess.
+func (k *Kernel) StartProcess(tg *ThreadGroup) {
+ t := tg.Leader()
+ tid := k.tasks.Root.IDOfTask(t)
+ t.Start(tid)
+}
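+
+// A minimal usage sketch (added for illustration; not part of this change),
+// assuming args is a fully-populated CreateProcessArgs and the kernel has
+// already been started:
+//
+//	tg, tgid, err := k.CreateProcess(args)
+//	if err != nil {
+//		return fmt.Errorf("failed to create process: %v", err)
+//	}
+//	log.Infof("created process with tgid %d", tgid)
+//	k.StartProcess(tg)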
+
+// Start starts execution of all tasks in k.
+//
+// Preconditions: Start may be called exactly once.
+func (k *Kernel) Start() error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+
+ if k.globalInit == nil {
+ return fmt.Errorf("kernel contains no tasks")
+ }
+ if k.started {
+ return fmt.Errorf("kernel already started")
+ }
+
+ k.started = true
+ k.cpuClockTicker = ktime.NewTimer(k.monotonicClock, newKernelCPUClockTicker(k))
+ k.cpuClockTicker.Swap(ktime.Setting{
+ Enabled: true,
+ Period: linux.ClockTick,
+ })
+ // If k was created by LoadKernelFrom, timers were stopped during
+ // Kernel.SaveTo and need to be resumed. If k was created by NewKernel,
+ // this is a no-op.
+ k.resumeTimeLocked()
+ // Start task goroutines.
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t, tid := range k.tasks.Root.tids {
+ t.Start(tid)
+ }
+ return nil
+}
+
+// pauseTimeLocked pauses all Timers and Timekeeper updates.
+//
+// Preconditions: Any task goroutines running in k must be stopped. k.extMu
+// must be locked.
+func (k *Kernel) pauseTimeLocked() {
+ // k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before
+ // Kernel.Start().
+ if k.cpuClockTicker != nil {
+ k.cpuClockTicker.Pause()
+ }
+
+ // By precondition, nothing else can be interacting with PIDNamespace.tids
+ // or FDTable.files, so we can iterate them without synchronization. (We
+ // can't hold the TaskSet mutex when pausing thread group timers because
+ // thread group timers call ThreadGroup.SendSignal, which takes the TaskSet
+ // mutex, while holding the Timer mutex.)
+ for t := range k.tasks.Root.tids {
+ if t == t.tg.leader {
+ t.tg.itimerRealTimer.Pause()
+ for _, it := range t.tg.timers {
+ it.PauseTimer()
+ }
+ }
+ // This means we'll iterate FDTables shared by multiple tasks repeatedly,
+ // but ktime.Timer.Pause is idempotent so this is harmless.
+ if t.fdTable != nil {
+ t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok {
+ tfd.PauseTimer()
+ }
+ } else {
+ if tfd, ok := file.FileOperations.(*oldtimerfd.TimerOperations); ok {
+ tfd.PauseTimer()
+ }
+ }
+ })
+ }
+ }
+ k.timekeeper.PauseUpdates()
+}
+
+// resumeTimeLocked resumes all Timers and Timekeeper updates. If
+// pauseTimeLocked has not been previously called, resumeTimeLocked has no
+// effect.
+//
+// Preconditions: Any task goroutines running in k must be stopped. k.extMu
+// must be locked.
+func (k *Kernel) resumeTimeLocked() {
+ if k.cpuClockTicker != nil {
+ k.cpuClockTicker.Resume()
+ }
+
+ k.timekeeper.ResumeUpdates()
+ for t := range k.tasks.Root.tids {
+ if t == t.tg.leader {
+ t.tg.itimerRealTimer.Resume()
+ for _, it := range t.tg.timers {
+ it.ResumeTimer()
+ }
+ }
+ if t.fdTable != nil {
+ t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok {
+ tfd.ResumeTimer()
+ }
+ } else {
+ if tfd, ok := file.FileOperations.(*oldtimerfd.TimerOperations); ok {
+ tfd.ResumeTimer()
+ }
+ }
+ })
+ }
+ }
+}
+
+func (k *Kernel) incRunningTasks() {
+ for {
+ tasks := atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // Standard case. Simply increment.
+ if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) {
+ continue
+ }
+ return
+ }
+
+ // Transition from 0 -> 1. Synchronize with other transitions and timer.
+ k.runningTasksMu.Lock()
+ tasks = atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // We're no longer the first task, no need to
+ // re-enable.
+ atomic.AddInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ if !k.cpuClockTickerDisabled {
+ // Timer was never disabled.
+ atomic.StoreInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ // We need to update cpuClock for all of the ticks missed while we
+ // slept, and then re-enable the timer.
+ //
+ // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify
+ // always increments cpuClock by 1 regardless of the number of
+ // expirations as a heuristic to avoid over-accounting in cases of CPU
+ // throttling.
+ //
+ // We want to cover the normal case, when all time should be accounted,
+ // so we increment for all expirations. Throttling is less concerning
+ // here because the ticker is only disabled from Notify. This means
+ // that Notify must schedule and compensate for the throttled period
+ // before the timer is disabled. Throttling while the timer is disabled
+ // doesn't matter, as nothing is running or reading cpuClock anyways.
+ //
+ // S/R also adds complication, as there are two cases. Recall that
+ // monotonicClock will jump forward on restore.
+ //
+ // 1. If the ticker is enabled during save, then on Restore Notify is
+ // called with many expirations, covering the time jump, but cpuClock
+ // is only incremented by 1.
+ //
+ // 2. If the ticker is disabled during save, then after Restore the
+ // first wakeup will call this function and cpuClock will be
+ // incremented by the number of expirations across the S/R.
+ //
+ // These cases produce very different values of cpuClock. But again, since
+ // nothing was running while the ticker was disabled, those differences
+ // don't matter.
+ setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now())
+ if exp > 0 {
+ atomic.AddUint64(&k.cpuClock, exp)
+ }
+
+ // Now that cpuClock is updated it is safe to allow other tasks to
+ // transition to running.
+ atomic.StoreInt64(&k.runningTasks, 1)
+
+ // N.B. we must unlock before calling Swap to maintain lock ordering.
+ //
+ // cpuClockTickerDisabled need not wait until after Swap to become
+ // false. It is sufficient that the timer *will* be enabled.
+ k.cpuClockTickerDisabled = false
+ k.runningTasksMu.Unlock()
+
+ // This won't call Notify (unless it's been ClockTick since setting.At
+ // above). This means we skip the thread group work in Notify. However,
+ // since nothing was running while we were disabled, none of the timers
+ // could have expired.
+ k.cpuClockTicker.Swap(setting)
+
+ return
+ }
+}
+
+func (k *Kernel) decRunningTasks() {
+ tasks := atomic.AddInt64(&k.runningTasks, -1)
+ if tasks < 0 {
+ panic(fmt.Sprintf("Invalid running count %d", tasks))
+ }
+
+ // Nothing to do. The next CPU clock tick will disable the timer if
+ // there is still nothing running. This provides approximately one tick
+ // of slack in which we can switch back and forth between idle and
+ // active without an expensive transition.
+}
+
+// WaitExited blocks until all tasks in k have exited.
+func (k *Kernel) WaitExited() {
+ k.tasks.liveGoroutines.Wait()
+}
+
+// Kill requests that all tasks in k immediately exit as if group exiting with
+// status es. Kill does not wait for tasks to exit.
+func (k *Kernel) Kill(es ExitStatus) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.Kill(es)
+}
+
+// Pause requests that all tasks in k temporarily stop executing, and blocks
+// until all tasks and asynchronous I/O operations in k have stopped. Multiple
+// calls to Pause nest and require an equal number of calls to Unpause to
+// resume execution.
+func (k *Kernel) Pause() {
+ k.extMu.Lock()
+ k.tasks.BeginExternalStop()
+ k.extMu.Unlock()
+ k.tasks.runningGoroutines.Wait()
+ k.tasks.aioGoroutines.Wait()
+}
+
+// Unpause ends the effect of a previous call to Pause. If Unpause is called
+// without a matching preceding call to Pause, Unpause may panic.
+func (k *Kernel) Unpause() {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.EndExternalStop()
+}
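+
+// Illustrative only: because Pause calls nest, a caller that needs the kernel
+// quiescent (e.g. while saving state) brackets the critical section with a
+// matching Unpause:
+//
+//	k.Pause()
+//	defer k.Unpause()
+//	// ... inspect or save state while no tasks or AIO goroutines run ...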
+
+// SendExternalSignal injects a signal into the kernel.
+//
+// context is used only for debugging to describe how the signal was received.
+//
+// Preconditions: Kernel must have an init process.
+func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.sendExternalSignal(info, context)
+}
+
+// SendExternalSignalThreadGroup injects a signal into a specific ThreadGroup.
+// Unlike SendExternalSignal, this function does not skip any signals.
+func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return tg.SendSignal(info)
+}
+
+// SendContainerSignal sends the given signal to all processes inside the
+// namespace that match the given container ID.
+func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+
+ var lastErr error
+ for tg := range k.tasks.Root.tgids {
+ if tg.leader.ContainerID() == cid {
+ tg.signalHandlers.mu.Lock()
+ infoCopy := *info
+ if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
+ lastErr = err
+ }
+ tg.signalHandlers.mu.Unlock()
+ }
+ }
+ return lastErr
+}
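+
+// For illustration (not part of this change), delivering SIGTERM to every
+// process in a container could look like this, constructing a minimal
+// arch.SignalInfo with only its Signo set:
+//
+//	info := &arch.SignalInfo{Signo: int32(linux.SIGTERM)}
+//	if err := k.SendContainerSignal(cid, info); err != nil {
+//		log.Warningf("failed to signal container %q: %v", cid, err)
+//	}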
+
+// RebuildTraceContexts rebuilds the trace context for all tasks.
+//
+// Unfortunately, if these are built while tracing is not enabled, then we will
+// not have meaningful trace data. Rebuilding here ensures that we can do so
+// after tracing has been enabled.
+func (k *Kernel) RebuildTraceContexts() {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+
+ for t, tid := range k.tasks.Root.tids {
+ t.rebuildTraceContext(tid)
+ }
+}
+
+// FeatureSet returns the FeatureSet.
+func (k *Kernel) FeatureSet() *cpuid.FeatureSet {
+ return k.featureSet
+}
+
+// Timekeeper returns the Timekeeper.
+func (k *Kernel) Timekeeper() *Timekeeper {
+ return k.timekeeper
+}
+
+// TaskSet returns the TaskSet.
+func (k *Kernel) TaskSet() *TaskSet {
+ return k.tasks
+}
+
+// RootUserNamespace returns the root UserNamespace.
+func (k *Kernel) RootUserNamespace() *auth.UserNamespace {
+ return k.rootUserNamespace
+}
+
+// RootUTSNamespace returns the root UTSNamespace.
+func (k *Kernel) RootUTSNamespace() *UTSNamespace {
+ return k.rootUTSNamespace
+}
+
+// RootIPCNamespace returns the root IPCNamespace.
+func (k *Kernel) RootIPCNamespace() *IPCNamespace {
+ return k.rootIPCNamespace
+}
+
+// RootPIDNamespace returns the root PIDNamespace.
+func (k *Kernel) RootPIDNamespace() *PIDNamespace {
+ return k.tasks.Root
+}
+
+// RootAbstractSocketNamespace returns the root AbstractSocketNamespace.
+func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
+ return k.rootAbstractSocketNamespace
+}
+
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
+}
+
+// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
+// nil if no such thread group exists. GlobalInit may return a thread group
+// containing no tasks if the thread group has already exited.
+func (k *Kernel) GlobalInit() *ThreadGroup {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return k.globalInit
+}
+
+// TestOnly_SetGlobalInit sets the thread group with ID 1 in the root PID namespace.
+func (k *Kernel) TestOnly_SetGlobalInit(tg *ThreadGroup) {
+ k.globalInit = tg
+}
+
+// ApplicationCores returns the number of CPUs visible to sandboxed
+// applications.
+func (k *Kernel) ApplicationCores() uint {
+ return k.applicationCores
+}
+
+// RealtimeClock returns the application CLOCK_REALTIME clock.
+func (k *Kernel) RealtimeClock() ktime.Clock {
+ return k.realtimeClock
+}
+
+// MonotonicClock returns the application CLOCK_MONOTONIC clock.
+func (k *Kernel) MonotonicClock() ktime.Clock {
+ return k.monotonicClock
+}
+
+// CPUClockNow returns the current value of k.cpuClock.
+func (k *Kernel) CPUClockNow() uint64 {
+ return atomic.LoadUint64(&k.cpuClock)
+}
+
+// Syslog returns the syslog.
+func (k *Kernel) Syslog() *syslog {
+ return &k.syslog
+}
+
+// GenerateInotifyCookie generates a unique inotify event cookie.
+//
+// Returned values may overlap with previously returned values if the value
+// space is exhausted. 0 is not a valid cookie value; all other values
+// representable in a uint32 are allowed.
+func (k *Kernel) GenerateInotifyCookie() uint32 {
+ id := atomic.AddUint32(&k.nextInotifyCookie, 1)
+ // Wrap-around is explicitly allowed for inotify event cookies.
+ if id == 0 {
+ id = atomic.AddUint32(&k.nextInotifyCookie, 1)
+ }
+ return id
+}
+
+// NetlinkPorts returns the netlink port manager.
+func (k *Kernel) NetlinkPorts() *port.Manager {
+ return k.netlinkPorts
+}
+
+// SaveError returns the sandbox error that caused the kernel to exit during
+// save.
+func (k *Kernel) SaveError() error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return k.saveErr
+}
+
+// SetSaveError sets the sandbox error that caused the kernel to exit during
+// save, if one is not already set.
+func (k *Kernel) SetSaveError(err error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ if k.saveErr == nil {
+ k.saveErr = err
+ }
+}
+
+var _ tcpip.Clock = (*Kernel)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (k *Kernel) NowNanoseconds() int64 {
+ now, err := k.timekeeper.GetTime(sentrytime.Realtime)
+ if err != nil {
+ panic("Kernel.NowNanoseconds: " + err.Error())
+ }
+ return now
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (k *Kernel) NowMonotonic() int64 {
+ now, err := k.timekeeper.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ panic("Kernel.NowMonotonic: " + err.Error())
+ }
+ return now
+}
+
+// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
+// LoadFrom.
+func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
+ k.mf = mf
+}
+
+// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile.
+func (k *Kernel) MemoryFile() *pgalloc.MemoryFile {
+ return k.mf
+}
+
+// SupervisorContext returns a Context with maximum privileges in k. It should
+// only be used by goroutines outside the control of the emulated kernel
+// defined by k.
+//
+// Callers are responsible for ensuring that the returned Context is not used
+// concurrently with changes to the Kernel.
+func (k *Kernel) SupervisorContext() context.Context {
+ return supervisorContext{
+ Logger: log.Log(),
+ k: k,
+ }
+}
+
+// SocketEntry represents a socket recorded in Kernel.sockets. It implements
+// refs.WeakRefUser for sockets stored in the socket table.
+//
+// +stateify savable
+type SocketEntry struct {
+ socketEntry
+ k *Kernel
+ Sock *refs.WeakRef
+ SockVFS2 *vfs.FileDescription
+ ID uint64 // Socket table entry number.
+}
+
+// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
+func (s *SocketEntry) WeakRefGone() {
+ s.k.extMu.Lock()
+ s.k.sockets.Remove(s)
+ s.k.extMu.Unlock()
+}
+
+// RecordSocket adds a socket to the system-wide socket table for tracking.
+//
+// Precondition: Caller must hold a reference to sock.
+func (k *Kernel) RecordSocket(sock *fs.File) {
+ k.extMu.Lock()
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{k: k, ID: id}
+ s.Sock = refs.NewWeakRef(sock, s)
+ k.sockets.PushBack(s)
+ k.extMu.Unlock()
+}
+
+// RecordSocketVFS2 adds a VFS2 socket to the system-wide socket table for
+// tracking.
+//
+// Precondition: Caller must hold a reference to sock.
+//
+// Note that the socket table will not hold a reference on the
+// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{
+ k: k,
+ ID: id,
+ SockVFS2: sock,
+ }
+ k.sockets.PushBack(s)
+ k.extMu.Unlock()
+}
+
+// ListSockets returns a snapshot of all sockets.
+//
+// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// to get a reference on a socket in the table.
+func (k *Kernel) ListSockets() []*SocketEntry {
+ k.extMu.Lock()
+ var socks []*SocketEntry
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, s)
+ }
+ k.extMu.Unlock()
+ return socks
+}
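+
+// An illustrative sketch of the VFS2 pattern described above: snapshot the
+// table, then take a reference before touching each socket, skipping entries
+// that are already being destroyed:
+//
+//	for _, se := range k.ListSockets() {
+//		if se.SockVFS2 == nil || !se.SockVFS2.TryIncRef() {
+//			continue
+//		}
+//		// ... use se.SockVFS2 ...
+//		se.SockVFS2.DecRef()
+//	}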
+
+// supervisorContext is a privileged context.
+type supervisorContext struct {
+ context.NoopSleeper
+ log.Logger
+ k *Kernel
+}
+
+// Value implements context.Context.
+func (ctx supervisorContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxCanTrace:
+ // The supervisor context can trace anything. (None of
+ // supervisorContext's users are expected to invoke ptrace, but ptrace
+ // permissions are required for certain file accesses.)
+ return func(*Task, bool) bool { return true }
+ case CtxKernel:
+ return ctx.k
+ case CtxPIDNamespace:
+ return ctx.k.tasks.Root
+ case CtxUTSNamespace:
+ return ctx.k.rootUTSNamespace
+ case CtxIPCNamespace:
+ return ctx.k.rootIPCNamespace
+ case auth.CtxCredentials:
+ // The supervisor context is global root.
+ return auth.NewRootCredentials(ctx.k.rootUserNamespace)
+ case fs.CtxRoot:
+ if ctx.k.globalInit != nil {
+ return ctx.k.globalInit.mounts.Root()
+ }
+ return nil
+ case vfs.CtxRoot:
+ if ctx.k.globalInit == nil {
+ return vfs.VirtualDentry{}
+ }
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ defer mntns.DecRef()
+ // Root() takes a reference on the root dirent for us.
+ return mntns.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2() takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ case fs.CtxDirentCacheLimiter:
+ return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
+ case ktime.CtxRealtimeClock:
+ return ctx.k.RealtimeClock()
+ case limits.CtxLimits:
+ // No limits apply.
+ return limits.NewLimitSet()
+ case pgalloc.CtxMemoryFile:
+ return ctx.k.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return ctx.k
+ case platform.CtxPlatform:
+ return ctx.k
+ case uniqueid.CtxGlobalUniqueID:
+ return ctx.k.UniqueID()
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return ctx.k
+ case uniqueid.CtxInotifyCookie:
+ return ctx.k.GenerateInotifyCookie()
+ case unimpl.CtxEvents:
+ return ctx.k
+ default:
+ return nil
+ }
+}
+
+// Rate limits for the number of unimplemented syscall events.
+const (
+ unimplementedSyscallsMaxRate = 100 // events per second
+ unimplementedSyscallBurst = 1000 // events
+)
+
+// EmitUnimplementedEvent emits an UnimplementedSyscall event via the event
+// channel.
+func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
+ k.unimplementedSyscallEmitterOnce.Do(func() {
+ k.unimplementedSyscallEmitter = eventchannel.RateLimitedEmitterFrom(eventchannel.DefaultEmitter, unimplementedSyscallsMaxRate, unimplementedSyscallBurst)
+ })
+
+ t := TaskFromContext(ctx)
+ k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{
+ Tid: int32(t.ThreadID()),
+ Registers: t.Arch().StateData().Proto(),
+ })
+}
+
+// VFS returns the virtual filesystem for the kernel.
+func (k *Kernel) VFS() *vfs.VirtualFilesystem {
+ return &k.vfs
+}
+
+// SetHostMount sets the hostfs mount.
+func (k *Kernel) SetHostMount(mnt *vfs.Mount) {
+ if k.hostMount != nil {
+ panic("Kernel.hostMount cannot be set more than once")
+ }
+ k.hostMount = mnt
+}
+
+// HostMount returns the hostfs mount.
+func (k *Kernel) HostMount() *vfs.Mount {
+ return k.hostMount
+}
+
+// PipeMount returns the pipefs mount.
+func (k *Kernel) PipeMount() *vfs.Mount {
+ return k.pipeMount
+}
+
+// ShmMount returns the tmpfs mount.
+func (k *Kernel) ShmMount() *vfs.Mount {
+ return k.shmMount
+}
+
+// SocketMount returns the sockfs mount.
+func (k *Kernel) SocketMount() *vfs.Mount {
+ return k.socketMount
+}
diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go
new file mode 100644
index 000000000..2e66ec587
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_opts.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// SpecialOpts contains non-standard options for the kernel.
+//
+// +stateify savable
+type SpecialOpts struct{}
diff --git a/pkg/sentry/kernel/kernel_state.go b/pkg/sentry/kernel/kernel_state.go
new file mode 100644
index 000000000..909219086
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_state.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// saveDanglingEndpoints is invoked by stateify.
+func (k *Kernel) saveDanglingEndpoints() []tcpip.Endpoint {
+ return tcpip.GetDanglingEndpoints()
+}
+
+// loadDanglingEndpoints is invoked by stateify.
+func (k *Kernel) loadDanglingEndpoints(es []tcpip.Endpoint) {
+ for _, e := range es {
+ tcpip.AddDanglingEndpoint(e)
+ }
+}
+
+// saveDeviceRegistry is invoked by stateify.
+func (k *Kernel) saveDeviceRegistry() *device.Registry {
+ return device.SimpleDevices
+}
+
+// loadDeviceRegistry is invoked by stateify.
+func (k *Kernel) loadDeviceRegistry(r *device.Registry) {
+ device.SimpleDevices.LoadFrom(r)
+}
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
new file mode 100644
index 000000000..4486848d2
--- /dev/null
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memevent",
+ srcs = ["memory_events.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ ":memory_events_go_proto",
+ "//pkg/eventchannel",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usage",
+ "//pkg/sync",
+ ],
+)
+
+proto_library(
+ name = "memory_events",
+ srcs = ["memory_events.proto"],
+ visibility = ["//visibility:public"],
+)
diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go
new file mode 100644
index 000000000..200565bb8
--- /dev/null
+++ b/pkg/sentry/kernel/memevent/memory_events.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memevent implements the memory usage events controller, which
+// periodically emits events via the eventchannel.
+package memevent
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.")
+var totalEvents = metric.MustCreateNewUint64Metric("/memory_events/events", false /*sync*/, "Total number of memory events emitted.")
+
+// MemoryEvents describes the configuration for the global memory event emitter.
+type MemoryEvents struct {
+ k *kernel.Kernel
+
+ // period is how often to emit an event. The memory events goroutine
+ // ensures that at least one event is emitted per period, regardless of
+ // how much memory usage has changed.
+ period time.Duration
+
+ // Writing to this channel indicates the memory goroutine should stop.
+ stop chan struct{}
+
+ // done is used to signal when the memory event goroutine has exited.
+ done sync.WaitGroup
+}
+
+// New creates a new MemoryEvents.
+func New(k *kernel.Kernel, period time.Duration) *MemoryEvents {
+ return &MemoryEvents{
+ k: k,
+ period: period,
+ stop: make(chan struct{}),
+ }
+}
+
+// Stop stops the memory usage events emitter goroutine. Stop must not be called
+// concurrently with Start and may only be called once.
+func (m *MemoryEvents) Stop() {
+ close(m.stop)
+ m.done.Wait()
+}
+
+// Start starts the memory usage events emitter goroutine. Start must not be
+// called concurrently with Stop and may only be called once.
+func (m *MemoryEvents) Start() {
+ if m.period == 0 {
+ return
+ }
+ m.done.Add(1)
+ go m.run() // S/R-SAFE: doesn't interact with saved state.
+}
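+
+// Illustrative wiring (not part of this change): a sandbox that wants a
+// memory usage event at least once a minute for the kernel's lifetime:
+//
+//	m := memevent.New(k, time.Minute)
+//	m.Start()
+//	defer m.Stop()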
+
+func (m *MemoryEvents) run() {
+ defer m.done.Done()
+
+ // Emit the first event immediately on startup.
+ totalTicks.Increment()
+ m.emit()
+
+ ticker := time.NewTicker(m.period)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-m.stop:
+ return
+ case <-ticker.C:
+ totalTicks.Increment()
+ m.emit()
+ }
+ }
+}
+
+func (m *MemoryEvents) emit() {
+ totalPlatform, err := m.k.MemoryFile().TotalUsage()
+ if err != nil {
+ log.Warningf("Failed to fetch memory usage for memory events: %v", err)
+ return
+ }
+ snapshot, _ := usage.MemoryAccounting.Copy()
+ total := totalPlatform + snapshot.Mapped
+
+ totalEvents.Increment()
+ eventchannel.Emit(&pb.MemoryUsageEvent{
+ Mapped: snapshot.Mapped,
+ Total: total,
+ })
+}
diff --git a/pkg/sentry/kernel/memevent/memory_events.proto b/pkg/sentry/kernel/memevent/memory_events.proto
new file mode 100644
index 000000000..bf8029ff5
--- /dev/null
+++ b/pkg/sentry/kernel/memevent/memory_events.proto
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+// MemoryUsageEvent describes the memory usage of the sandbox at a single
+// instant in time. These messages are emitted periodically on the eventchannel.
+message MemoryUsageEvent {
+ // The total memory usage of the sandboxed application in bytes, calculated
+ // using the 'fast' method.
+ uint64 total = 1;
+
+ // Memory used to back memory-mapped regions for files in the application, in
+ // bytes. This corresponds to the usage.MemoryKind.Mapped memory type.
+ uint64 mapped = 2;
+}
diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go
new file mode 100644
index 000000000..77a35b788
--- /dev/null
+++ b/pkg/sentry/kernel/pending_signals.go
@@ -0,0 +1,142 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+const (
+ // stdSignalCap is the maximum number of instances of a given standard
+ // signal that may be pending. ("[If] multiple instances of a standard
+ // signal are delivered while that signal is currently blocked, then only
+ // one instance is queued.") - signal(7)
+ stdSignalCap = 1
+
+ // rtSignalCap is the maximum number of instances of a given realtime
+ // signal that may be pending.
+ //
+ // TODO(igudger): In Linux, the minimum signal queue size is
+ // RLIMIT_SIGPENDING, which is by default max_threads/2.
+ rtSignalCap = 32
+)
+
+// pendingSignals holds a collection of pending signals. The zero value of
+// pendingSignals is a valid empty collection. pendingSignals is thread-unsafe;
+// users must provide synchronization.
+//
+// +stateify savable
+type pendingSignals struct {
+ // signals contains all pending signals.
+ //
+ // Note that signals is zero-indexed, but signal 1 is the first valid
+ // signal, so signals[0] contains signals with signo 1 etc. This offset is
+ // usually handled by using Signal.index().
+ signals [linux.SignalMaximum]pendingSignalQueue `state:".([]savedPendingSignal)"`
+
+ // Bit i of pendingSet is set iff there is at least one signal with signo
+ // i+1 pending.
+ pendingSet linux.SignalSet `state:"manual"`
+}
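+
+// As a concrete example of the indexing convention above (illustrative
+// only): SIGHUP has signo 1, so its queue is signals[0] and its pending
+// state is bit 0 of pendingSet:
+//
+//	sig := linux.Signal(1)                              // SIGHUP
+//	q := &p.signals[sig.Index()]                        // same as &p.signals[0]
+//	pending := p.pendingSet&linux.SignalSetOf(sig) != 0 // tests bit 0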
+
+// pendingSignalQueue holds a pendingSignalList for a single signal number.
+//
+// +stateify savable
+type pendingSignalQueue struct {
+ pendingSignalList
+ length int
+}
+
+// +stateify savable
+type pendingSignal struct {
+ // pendingSignalEntry links into a pendingSignalList.
+ pendingSignalEntry
+ *arch.SignalInfo
+
+ // If timer is not nil, it is the IntervalTimer which sent this signal.
+ timer *IntervalTimer
+}
+
+// enqueue enqueues the given signal. enqueue returns true on success and false
+// on failure (if the given signal's queue is full).
+//
+// Preconditions: info represents a valid signal.
+func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bool {
+ sig := linux.Signal(info.Signo)
+ q := &p.signals[sig.Index()]
+ if sig.IsStandard() {
+ if q.length >= stdSignalCap {
+ return false
+ }
+ } else if q.length >= rtSignalCap {
+ return false
+ }
+ q.pendingSignalList.PushBack(&pendingSignal{SignalInfo: info, timer: timer})
+ q.length++
+ p.pendingSet |= linux.SignalSetOf(sig)
+ return true
+}
+
+// dequeue dequeues and returns any pending signal not masked by mask. If no
+// unmasked signals are pending, dequeue returns nil.
+func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo {
+ // "Real-time signals are delivered in a guaranteed order. Multiple
+ // real-time signals of the same type are delivered in the order they were
+ // sent. If different real-time signals are sent to a process, they are
+ // delivered starting with the lowest-numbered signal. (I.e., low-numbered
+ // signals have highest priority.) By contrast, if multiple standard
+ // signals are pending for a process, the order in which they are delivered
+ // is unspecified. If both standard and real-time signals are pending for a
+ // process, POSIX leaves it unspecified which is delivered first. Linux,
+ // like many other implementations, gives priority to standard signals in
+ // this case." - signal(7)
+ lowestPendingUnblockedBit := bits.TrailingZeros64(uint64(p.pendingSet &^ mask))
+ if lowestPendingUnblockedBit >= linux.SignalMaximum {
+ return nil
+ }
+ return p.dequeueSpecific(linux.Signal(lowestPendingUnblockedBit + 1))
+}
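+
+// Worked example of the priority rule above (illustrative only): if SIGUSR1
+// (signo 10, bit 9) and SIGQUIT (signo 3, bit 2) are both pending and mask
+// blocks SIGQUIT, then pendingSet &^ mask has bit 9 as its lowest set bit,
+// so dequeue delivers SIGUSR1. Once SIGQUIT is unblocked, it is delivered
+// before any further SIGUSR1 instances, being lower-numbered.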
+
+func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *arch.SignalInfo {
+ q := &p.signals[sig.Index()]
+ ps := q.pendingSignalList.Front()
+ if ps == nil {
+ return nil
+ }
+ q.pendingSignalList.Remove(ps)
+ q.length--
+ if q.length == 0 {
+ p.pendingSet &^= linux.SignalSetOf(sig)
+ }
+ if ps.timer != nil {
+ ps.timer.updateDequeuedSignalLocked(ps.SignalInfo)
+ }
+ return ps.SignalInfo
+}
+
+// discardSpecific causes all pending signals with number sig to be discarded.
+func (p *pendingSignals) discardSpecific(sig linux.Signal) {
+ q := &p.signals[sig.Index()]
+ for ps := q.pendingSignalList.Front(); ps != nil; ps = ps.Next() {
+ if ps.timer != nil {
+ ps.timer.signalRejectedLocked()
+ }
+ }
+ q.pendingSignalList.Reset()
+ q.length = 0
+ p.pendingSet &^= linux.SignalSetOf(sig)
+}
diff --git a/pkg/sentry/kernel/pending_signals_state.go b/pkg/sentry/kernel/pending_signals_state.go
new file mode 100644
index 000000000..ca8b4e164
--- /dev/null
+++ b/pkg/sentry/kernel/pending_signals_state.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// +stateify savable
+type savedPendingSignal struct {
+ si *arch.SignalInfo
+ timer *IntervalTimer
+}
+
+// saveSignals is invoked by stateify.
+func (p *pendingSignals) saveSignals() []savedPendingSignal {
+ var pending []savedPendingSignal
+ for _, q := range p.signals {
+ for ps := q.pendingSignalList.Front(); ps != nil; ps = ps.Next() {
+ pending = append(pending, savedPendingSignal{
+ si: ps.SignalInfo,
+ timer: ps.timer,
+ })
+ }
+ }
+ return pending
+}
+
+// loadSignals is invoked by stateify.
+func (p *pendingSignals) loadSignals(pending []savedPendingSignal) {
+ for _, sps := range pending {
+ p.enqueue(sps.si, sps.timer)
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
new file mode 100644
index 000000000..449643118
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -0,0 +1,54 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = [
+ "device.go",
+ "node.go",
+ "pipe.go",
+ "pipe_unsafe.go",
+ "pipe_util.go",
+ "reader.go",
+ "reader_writer.go",
+ "vfs.go",
+ "writer.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/buffer",
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "pipe_test",
+ size = "small",
+ srcs = [
+ "node_test.go",
+ "pipe_test.go",
+ ],
+ library = ":pipe",
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/pipe/device.go b/pkg/sentry/kernel/pipe/device.go
new file mode 100644
index 000000000..89f5d9342
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// pipeDevice is used for all pipe files.
+var pipeDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
new file mode 100644
index 000000000..4b688c627
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// inodeOperations implements fs.InodeOperations for pipes.
+//
+// +stateify savable
+type inodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+
+ // Marking pipe inodes as virtual allows them to be saved and restored
+ // even if they have been unlinked. We can get away with this because
+ // their state exists entirely within the sentry.
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // p is the underlying Pipe object representing this fifo.
+ p *Pipe
+
+ // Channels for synchronizing the creation of new readers and writers of
+ // this fifo. See waitFor and newHandleLocked.
+ //
+ // These are not saved/restored because all waiters are unblocked on save,
+ // and either automatically restart (via ERESTARTSYS) or return EINTR on
+ // resume. On restarts via ERESTARTSYS, the appropriate channel will be
+ // recreated.
+ rWakeup chan struct{} `state:"nosave"`
+ wWakeup chan struct{} `state:"nosave"`
+}
+
+var _ fs.InodeOperations = (*inodeOperations)(nil)
+
+// NewInodeOperations returns a new fs.InodeOperations for a given pipe.
+func NewInodeOperations(ctx context.Context, perms fs.FilePermissions, p *Pipe) *inodeOperations {
+ return &inodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), perms, linux.PIPEFS_MAGIC),
+ p: p,
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile. Named pipes have special blocking
+// semantics during open:
+//
+// "Normally, opening the FIFO blocks until the other end is opened also. A
+// process can open a FIFO in nonblocking mode. In this case, opening for
+// read-only will succeed even if no-one has opened on the write side yet,
+// opening for write-only will fail with ENXIO (no such device or address)
+// unless the other end has already been opened. Under Linux, opening a FIFO
+// for read and write will succeed both in blocking and nonblocking mode. POSIX
+// leaves this behavior undefined. This can be used to open a FIFO for writing
+// while there are no readers available." - fifo(7)
+func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ switch {
+ case flags.Read && !flags.Write: // O_RDONLY.
+ r := i.p.Open(ctx, d, flags)
+ newHandleLocked(&i.rWakeup)
+
+ if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
+ if !waitFor(&i.mu, &i.wWakeup, ctx) {
+ r.DecRef()
+ return nil, syserror.ErrInterrupted
+ }
+ }
+
+ // By now, either we're doing a nonblocking open or we have a writer. On
+ // a nonblocking read-only open, the open succeeds even if no-one has
+ // opened the write side yet.
+ return r, nil
+
+ case flags.Write && !flags.Read: // O_WRONLY.
+ w := i.p.Open(ctx, d, flags)
+ newHandleLocked(&i.wWakeup)
+
+ if i.p.isNamed && !i.p.HasReaders() {
+ // On a nonblocking, write-only open, the open fails with ENXIO if the
+ // read side isn't open yet.
+ if flags.NonBlocking {
+ w.DecRef()
+ return nil, syserror.ENXIO
+ }
+
+ if !waitFor(&i.mu, &i.rWakeup, ctx) {
+ w.DecRef()
+ return nil, syserror.ErrInterrupted
+ }
+ }
+ return w, nil
+
+ case flags.Read && flags.Write: // O_RDWR.
+ // Pipes opened for read-write always succeed without blocking.
+ rw := i.p.Open(ctx, d, flags)
+ newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.wWakeup)
+ return rw, nil
+
+ default:
+ return nil, syserror.EINVAL
+ }
+}
+
+func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return syserror.EPIPE
+}
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
new file mode 100644
index 000000000..ab75a87ff
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -0,0 +1,320 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type sleeper struct {
+ context.Context
+ ch chan struct{}
+}
+
+func newSleeperContext(t *testing.T) context.Context {
+ return &sleeper{
+ Context: contexttest.Context(t),
+ ch: make(chan struct{}),
+ }
+}
+
+func (s *sleeper) SleepStart() <-chan struct{} {
+ return s.ch
+}
+
+func (s *sleeper) SleepFinish(bool) {
+}
+
+func (s *sleeper) Cancel() {
+ s.ch <- struct{}{}
+}
+
+func (s *sleeper) Interrupted() bool {
+ return len(s.ch) != 0
+}
+
+type openResult struct {
+ *fs.File
+ error
+}
+
+var perms fs.FilePermissions = fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+}
+
+func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, doneChan chan<- struct{}) (*fs.File, error) {
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
+ d := fs.NewDirent(ctx, inode, "pipe")
+ file, err := n.GetFile(ctx, d, flags)
+ if err != nil {
+ t.Fatalf("open with flags %+v failed: %v", flags, err)
+ }
+ if doneChan != nil {
+ doneChan <- struct{}{}
+ }
+ return file, err
+}
+
+func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, resChan chan<- openResult) (*fs.File, error) {
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
+ d := fs.NewDirent(ctx, inode, "pipe")
+ file, err := n.GetFile(ctx, d, flags)
+ if resChan != nil {
+ resChan <- openResult{file, err}
+ }
+ return file, err
+}
+
+func newNamedPipe(t *testing.T) *Pipe {
+ return NewPipe(true, DefaultPipeSize, usermem.PageSize)
+}
+
+func newAnonPipe(t *testing.T) *Pipe {
+ return NewPipe(false, DefaultPipeSize, usermem.PageSize)
+}
+
+// assertRecvBlocks ensures that a recv attempt on c blocks for at least
+// blockDuration. This is useful for checking that a goroutine that is supposed
+// to be executing a blocking operation is actually blocking.
+func assertRecvBlocks(t *testing.T, c <-chan struct{}, blockDuration time.Duration, failMsg string) {
+ select {
+ case <-c:
+ t.Fatal(failMsg)
+ case <-time.After(blockDuration):
+ // Ok, blocked for the required duration.
+ }
+}
+
+func TestReadOpenBlocksForWriteOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ rDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone)
+
+ // Verify that the open for read is blocking.
+ assertRecvBlocks(t, rDone, time.Millisecond*100,
+ "open for read not blocking with no writers")
+
+ wDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone)
+
+ <-wDone
+ <-rDone
+}
+
+func TestWriteOpenBlocksForReadOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ wDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone)
+
+ // Verify that the open for write is blocking
+ assertRecvBlocks(t, wDone, time.Millisecond*100,
+ "open for write not blocking with no readers")
+
+ rDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone)
+
+ <-rDone
+ <-wDone
+}
+
+func TestMultipleWriteOpenDoesntCountAsReadOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ rDone1 := make(chan struct{})
+ rDone2 := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone1)
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone2)
+
+ assertRecvBlocks(t, rDone1, time.Millisecond*100,
+ "open for read didn't block with no writers")
+ assertRecvBlocks(t, rDone2, time.Millisecond*100,
+ "open for read didn't block with no writers")
+
+ wDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone)
+
+ <-wDone
+ <-rDone2
+ <-rDone1
+}
+
+func TestClosedReaderBlocksWriteOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil)
+ rFile.DecRef()
+
+ wDone := make(chan struct{})
+ // This open for write should block because the reader is now gone.
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone)
+ assertRecvBlocks(t, wDone, time.Millisecond*100,
+ "open for write didn't block with no concurrent readers")
+
+ // Open for read again. This should unblock the open for write.
+ rDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone)
+
+ <-rDone
+ <-wDone
+}
+
+func TestReadWriteOpenNeverBlocks(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ rwDone := make(chan struct{})
+ // Opens for read-write never wait for a reader or writer, even if the
+ // nonblocking flag is not set.
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true, NonBlocking: false}, rwDone)
+ <-rwDone
+}
+
+func TestReadWriteOpenUnblocksReadOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ rDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone)
+
+ rwDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone)
+
+ <-rwDone
+ <-rDone
+}
+
+func TestReadWriteOpenUnblocksWriteOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ wDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone)
+
+ rwDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone)
+
+ <-rwDone
+ <-wDone
+}
+
+func TestBlockedOpenIsCancellable(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ done := make(chan openResult)
+ go testOpen(ctx, t, f, fs.FileFlags{Read: true}, done)
+ select {
+ case <-done:
+ t.Fatalf("open for read didn't block with no writers")
+ case <-time.After(time.Millisecond * 100):
+ // Ok.
+ }
+
+ ctx.(*sleeper).Cancel()
+ // If the cancel on the sleeper didn't work, the open for read would never
+ // return.
+ res := <-done
+ if res.error != syserror.ErrInterrupted {
+ t.Fatalf("Cancellation didn't cause GetFile to return fs.ErrInterrupted, got %v.",
+ res.error)
+ }
+}
+
+func TestNonblockingReadOpenFileNoWriters(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil {
+ t.Fatalf("Nonblocking open for read failed with error %v.", err)
+ }
+}
+
+func TestNonblockingWriteOpenFileNoReaders(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != syserror.ENXIO {
+ t.Fatalf("Nonblocking open for write failed unexpected error %v.", err)
+ }
+}
+
+func TestNonBlockingReadOpenWithWriter(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ wDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone)
+
+ // Open for write blocks since there are no readers yet.
+ assertRecvBlocks(t, wDone, time.Millisecond*100,
+ "Open for write didn't block with no reader.")
+
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil {
+ t.Fatalf("Nonblocking open for read failed with error %v.", err)
+ }
+
+ // Open for write should now be unblocked.
+ <-wDone
+}
+
+func TestNonBlockingWriteOpenWithReader(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newNamedPipe(t))
+
+ rDone := make(chan struct{})
+ go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone)
+
+ // Open for read blocks since there is no writer yet.
+ assertRecvBlocks(t, rDone, time.Millisecond*100,
+ "Open for reader didn't block with no writer.")
+
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != nil {
+ t.Fatalf("Nonblocking open for write failed with error %v.", err)
+ }
+
+ // Open for write should now be unblocked.
+ <-rDone
+}
+
+func TestAnonReadOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newAnonPipe(t))
+
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true}, nil); err != nil {
+ t.Fatalf("open anon pipe for read failed: %v", err)
+ }
+}
+
+func TestAnonWriteOpen(t *testing.T) {
+ ctx := newSleeperContext(t)
+ f := NewInodeOperations(ctx, perms, newAnonPipe(t))
+
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true}, nil); err != nil {
+ t.Fatalf("open anon pipe for write failed: %v", err)
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
new file mode 100644
index 000000000..79645d7d2
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -0,0 +1,419 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides a pipe implementation.
+package pipe
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // MinimumPipeSize is a hard limit of the minimum size of a pipe.
+ MinimumPipeSize = 64 << 10
+
+ // DefaultPipeSize is the system-wide default size of a pipe in bytes.
+ DefaultPipeSize = MinimumPipeSize
+
+ // MaximumPipeSize is a hard limit on the maximum size of a pipe.
+ MaximumPipeSize = 8 << 20
+)
+
+// Pipe is an encapsulation of a platform-independent pipe.
+// It manages a buffered byte queue shared between a reader/writer
+// pair.
+//
+// +stateify savable
+type Pipe struct {
+ waiter.Queue `state:"nosave"`
+
+ // isNamed indicates whether this is a named pipe.
+ //
+ // This value is immutable.
+ isNamed bool
+
+ // atomicIOBytes is the maximum number of bytes for which the pipe
+ // guarantees that reads and writes are performed atomically.
+ //
+ // This value is immutable.
+ atomicIOBytes int64
+
+ // The number of active readers for this pipe.
+ //
+ // Access atomically.
+ readers int32
+
+ // The number of active writers for this pipe.
+ //
+ // Access atomically.
+ writers int32
+
+ // mu protects all pipe internal state below.
+ mu sync.Mutex `state:"nosave"`
+
+ // view is the underlying set of buffers.
+ //
+ // This is protected by mu.
+ view buffer.View
+
+ // max is the maximum size of the pipe in bytes. When this max has been
+ // reached, writers will get EWOULDBLOCK.
+ //
+ // This is protected by mu.
+ max int64
+
+ // hadWriter indicates if this pipe ever had a writer. Note that this
+ // does not necessarily indicate there is *currently* a writer, just
+ // that there has been a writer at some point since the pipe was
+ // created.
+ //
+ // This is protected by mu.
+ hadWriter bool
+}
+
+// NewPipe initializes and returns a pipe.
+//
+// N.B. The size and atomicIOBytes will be bounded.
+func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
+ var p Pipe
+ initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ return &p
+}
+
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+ if sizeBytes < MinimumPipeSize {
+ sizeBytes = MinimumPipeSize
+ }
+ if sizeBytes > MaximumPipeSize {
+ sizeBytes = MaximumPipeSize
+ }
+ if atomicIOBytes <= 0 {
+ atomicIOBytes = 1
+ }
+ if atomicIOBytes > sizeBytes {
+ atomicIOBytes = sizeBytes
+ }
+ pipe.isNamed = isNamed
+ pipe.max = sizeBytes
+ pipe.atomicIOBytes = atomicIOBytes
+}
+
+// NewConnectedPipe initializes a pipe and returns a pair of objects
+// representing the read and write ends of the pipe.
+func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
+ p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes)
+
+ // Build an fs.Dirent for the pipe which will be shared by both
+ // returned files.
+ perms := fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }
+ iops := NewInodeOperations(ctx, perms, p)
+ ino := pipeDevice.NextIno()
+ sattr := fs.StableAttr{
+ Type: fs.Pipe,
+ DeviceID: pipeDevice.DeviceID(),
+ InodeID: ino,
+ BlockSize: int64(atomicIOBytes),
+ }
+ ms := fs.NewPseudoMountSource(ctx)
+ d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino))
+ // The p.Open calls below will each take a reference on the Dirent. We
+ // must drop the one we already have.
+ defer d.DecRef()
+ return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true})
+}
+
+// Open opens the pipe and returns a new file.
+//
+// Precondition: at least one of flags.Read or flags.Write must be set.
+func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.File {
+ flags.NonSeekable = true
+ switch {
+ case flags.Read && flags.Write:
+ p.rOpen()
+ p.wOpen()
+ return fs.NewFile(ctx, d, flags, &ReaderWriter{
+ Pipe: p,
+ })
+ case flags.Read:
+ p.rOpen()
+ return fs.NewFile(ctx, d, flags, &Reader{
+ ReaderWriter: ReaderWriter{Pipe: p},
+ })
+ case flags.Write:
+ p.wOpen()
+ return fs.NewFile(ctx, d, flags, &Writer{
+ ReaderWriter: ReaderWriter{Pipe: p},
+ })
+ default:
+ // Precondition violated.
+ panic("invalid pipe flags")
+ }
+}
+
+type readOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit limits subsequent reads.
+ limit func(int64)
+
+ // read performs the actual read operation.
+ read func(*buffer.View) (int64, error)
+}
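// Editorial aside — an illustrative sketch, not part of this change: how a
// caller adapts a plain byte slice to the readOps callbacks. This mirrors
// what VFSPipeFD.CopyIn does later in this change; readOpsForSlice is a
// hypothetical helper named here for illustration only.
func readOpsForSlice(dst []byte) readOps {
	return readOps{
		left:  func() int64 { return int64(len(dst)) },
		limit: func(l int64) { dst = dst[:l] },
		read: func(view *buffer.View) (int64, error) {
			// Copy out of the pipe's buffer, then consume what was copied.
			n, err := view.ReadAt(dst, 0)
			view.TrimFront(int64(n))
			return int64(n), err
		},
	}
}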
+
+// read reads data from the pipe as directed by ops and returns the number of
+// bytes read; it returns ErrWouldBlock if the pipe is empty and still has
+// writers.
+//
+// Precondition: this pipe must have readers.
+func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
+ // Don't block for a zero-length read even if the pipe is empty.
+ if ops.left() == 0 {
+ return 0, nil
+ }
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.readLocked(ctx, ops)
+}
+
+func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) {
+ // Is the pipe empty?
+ if p.view.Size() == 0 {
+ if !p.HasWriters() {
+ // There are no writers, return EOF.
+ return 0, nil
+ }
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Limit how much we consume.
+ if ops.left() > p.view.Size() {
+ ops.limit(p.view.Size())
+ }
+
+ // Copy user data; the read op is responsible for trimming.
+ return ops.read(&p.view)
+}
+
+type writeOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit should limit subsequent writes.
+ limit func(int64)
+
+ // write should write to the provided buffer.
+ write func(*buffer.View) (int64, error)
+}
+
+// write writes data into the pipe as directed by ops and returns the number
+// of bytes written. If no bytes are written because the pipe is full (or has
+// less than atomicIOBytes free capacity), write returns ErrWouldBlock.
+//
+// Precondition: this pipe must have writers.
+func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.writeLocked(ctx, ops)
+}
+
+func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
+ // Can't write to a pipe with no readers.
+ if !p.HasReaders() {
+ return 0, syscall.EPIPE
+ }
+
+ // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
+ // atomic, but requires no atomicity for writes larger than this.
+ wanted := ops.left()
+ avail := p.max - p.view.Size()
+ if wanted > avail {
+ if wanted <= p.atomicIOBytes {
+ return 0, syserror.ErrWouldBlock
+ }
+ ops.limit(avail)
+ }
+
+ // Copy user data.
+ done, err := ops.write(&p.view)
+ if err != nil {
+ return done, err
+ }
+
+ if done < avail {
+ // Non-failure, but short write.
+ return done, nil
+ }
+ if done < wanted {
+ // Partial write due to a full pipe. Note that this could also be
+ // the short write case above; in that case we would expect a
+ // second call, and that write would return zero bytes.
+ return done, syserror.ErrWouldBlock
+ }
+
+ return done, nil
+}
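// Editorial aside — an illustrative sketch, not part of this change:
// atomicWriteDecision is a hypothetical helper that restates the PIPE_BUF
// rule applied in writeLocked above. For example, with atomicIOBytes = 4096
// and 1000 bytes free, a 512-byte write blocks entirely (it must be atomic),
// while a 10000-byte write is truncated to the 1000 free bytes.
func atomicWriteDecision(wanted, avail, atomicIOBytes int64) (n int64, wouldBlock bool) {
	if wanted <= avail {
		return wanted, false // fits entirely in the free capacity
	}
	if wanted <= atomicIOBytes {
		return 0, true // must be written atomically; block rather than split
	}
	return avail, false // no atomicity guarantee; write what fits now
}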
+
+// rOpen signals a new reader of the pipe.
+func (p *Pipe) rOpen() {
+ atomic.AddInt32(&p.readers, 1)
+}
+
+// wOpen signals a new writer of the pipe.
+func (p *Pipe) wOpen() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.hadWriter = true
+ atomic.AddInt32(&p.writers, 1)
+}
+
+// rClose signals that a reader has closed their end of the pipe.
+func (p *Pipe) rClose() {
+ newReaders := atomic.AddInt32(&p.readers, -1)
+ if newReaders < 0 {
+ panic(fmt.Sprintf("Refcounting bug, pipe has negative readers: %v", newReaders))
+ }
+}
+
+// wClose signals that a writer has closed their end of the pipe.
+func (p *Pipe) wClose() {
+ newWriters := atomic.AddInt32(&p.writers, -1)
+ if newWriters < 0 {
+ panic(fmt.Sprintf("Refcounting bug, pipe has negative writers: %v.", newWriters))
+ }
+}
+
+// HasReaders returns whether the pipe has any active readers.
+func (p *Pipe) HasReaders() bool {
+ return atomic.LoadInt32(&p.readers) > 0
+}
+
+// HasWriters returns whether the pipe has any active writers.
+func (p *Pipe) HasWriters() bool {
+ return atomic.LoadInt32(&p.writers) > 0
+}
+
+// rReadinessLocked calculates the read readiness.
+//
+// Precondition: mu must be held.
+func (p *Pipe) rReadinessLocked() waiter.EventMask {
+ ready := waiter.EventMask(0)
+ if p.HasReaders() && p.view.Size() != 0 {
+ ready |= waiter.EventIn
+ }
+ if !p.HasWriters() && p.hadWriter {
+ // POLLHUP must be suppressed until the pipe has had at least one writer
+ // at some point. Otherwise a reader thread may poll and immediately get
+ // a POLLHUP before the writer ever opens the pipe, which the reader may
+ // interpret as the writer opening then closing the pipe.
+ ready |= waiter.EventHUp
+ }
+ return ready
+}
+
+// rReadiness returns a mask that states whether the read end of the pipe is
+// ready for reading.
+func (p *Pipe) rReadiness() waiter.EventMask {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.rReadinessLocked()
+}
+
+// wReadinessLocked calculates the write readiness.
+//
+// Precondition: mu must be held.
+func (p *Pipe) wReadinessLocked() waiter.EventMask {
+ ready := waiter.EventMask(0)
+ if p.HasWriters() && p.view.Size() < p.max {
+ ready |= waiter.EventOut
+ }
+ if !p.HasReaders() {
+ ready |= waiter.EventErr
+ }
+ return ready
+}
+
+// wReadiness returns a mask that states whether the write end of the pipe
+// is ready for writing.
+func (p *Pipe) wReadiness() waiter.EventMask {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.wReadinessLocked()
+}
+
+// rwReadiness returns a mask that states whether a read-write handle to the
+// pipe is ready for IO.
+func (p *Pipe) rwReadiness() waiter.EventMask {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.rReadinessLocked() | p.wReadinessLocked()
+}
+
+// queued returns the amount of queued data.
+func (p *Pipe) queued() int64 {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.view.Size()
+}
+
+// FifoSize implements fs.FifoSizer.FifoSize.
+func (p *Pipe) FifoSize(context.Context, *fs.File) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.max, nil
+}
+
+// SetFifoSize implements fs.FifoSizer.SetFifoSize.
+func (p *Pipe) SetFifoSize(size int64) (int64, error) {
+ if size < 0 {
+ return 0, syserror.EINVAL
+ }
+ if size < MinimumPipeSize {
+ size = MinimumPipeSize // Per spec.
+ }
+ if size > MaximumPipeSize {
+ return 0, syserror.EPERM
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if size < p.view.Size() {
+ return 0, syserror.EBUSY
+ }
+ p.max = size
+ return size, nil
+}
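// Editorial aside — an illustrative sketch, not part of this change: the
// fcntl(2) semantics that FifoSize/SetFifoSize emulate, exercised against a
// host Linux pipe. Assumes golang.org/x/sys/unix.
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	_, w, _ := os.Pipe()
	defer w.Close()

	// Grow the pipe to 1 MiB; the kernel may round the request up.
	newSize, err := unix.FcntlInt(w.Fd(), unix.F_SETPIPE_SZ, 1<<20)
	fmt.Println(newSize, err)

	// Requests above the system maximum fail with EPERM for unprivileged
	// callers, matching the EPERM branch in SetFifoSize above.
	_, err = unix.FcntlInt(w.Fd(), unix.F_SETPIPE_SZ, 1<<30)
	fmt.Println(err)
}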
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
new file mode 100644
index 000000000..bda739dbe
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestPipeRW(t *testing.T) {
+ ctx := contexttest.Context(t)
+ r, w := NewConnectedPipe(ctx, 65536, 4096)
+ defer r.DecRef()
+ defer w.DecRef()
+
+ msg := []byte("here's some bytes")
+ wantN := int64(len(msg))
+ n, err := w.Writev(ctx, usermem.BytesIOSequence(msg))
+ if n != wantN || err != nil {
+ t.Fatalf("Writev: got (%d, %v), wanted (%d, nil)", n, err, wantN)
+ }
+
+ buf := make([]byte, len(msg))
+ n, err = r.Readv(ctx, usermem.BytesIOSequence(buf))
+ if n != wantN || err != nil || !bytes.Equal(buf, msg) {
+ t.Fatalf("Readv: got (%d, %v) %q, wanted (%d, nil) %q", n, err, buf, wantN, msg)
+ }
+}
+
+func TestPipeReadBlock(t *testing.T) {
+ ctx := contexttest.Context(t)
+ r, w := NewConnectedPipe(ctx, 65536, 4096)
+ defer r.DecRef()
+ defer w.DecRef()
+
+ n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1)))
+ if n != 0 || err != syserror.ErrWouldBlock {
+ t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, syserror.ErrWouldBlock)
+ }
+}
+
+func TestPipeWriteBlock(t *testing.T) {
+ const atomicIOBytes = 2
+ const capacity = MinimumPipeSize
+
+ ctx := contexttest.Context(t)
+ r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes)
+ defer r.DecRef()
+ defer w.DecRef()
+
+ msg := make([]byte, capacity+1)
+ n, err := w.Writev(ctx, usermem.BytesIOSequence(msg))
+ if wantN, wantErr := int64(capacity), syserror.ErrWouldBlock; n != wantN || err != wantErr {
+ t.Fatalf("Writev: got (%d, %v), wanted (%d, %v)", n, err, wantN, wantErr)
+ }
+}
+
+func TestPipeWriteUntilEnd(t *testing.T) {
+ const atomicIOBytes = 2
+
+ ctx := contexttest.Context(t)
+ r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes)
+ defer r.DecRef()
+ defer w.DecRef()
+
+ msg := []byte("here's some bytes")
+
+ wDone := make(chan struct{})
+ rDone := make(chan struct{})
+ defer func() {
+ // Signal the reader to stop and wait until it does so.
+ close(wDone)
+ <-rDone
+ }()
+
+ go func() {
+ defer close(rDone)
+ // Read from r until wDone is closed.
+ ctx := contexttest.Context(t)
+ buf := make([]byte, len(msg)+1)
+ dst := usermem.BytesIOSequence(buf)
+ e, ch := waiter.NewChannelEntry(nil)
+ r.EventRegister(&e, waiter.EventIn)
+ defer r.EventUnregister(&e)
+ for {
+ n, err := r.Readv(ctx, dst)
+ dst = dst.DropFirst64(n)
+ if err == syserror.ErrWouldBlock {
+ select {
+ case <-ch:
+ continue
+ case <-wDone:
+ // We expect to have 1 byte left in dst since len(buf) ==
+ // len(msg)+1.
+ if dst.NumBytes() != 1 || !bytes.Equal(buf[:len(msg)], msg) {
+ t.Errorf("Reader: got %q (%d bytes remaining), wanted %q", buf, dst.NumBytes(), msg)
+ }
+ return
+ }
+ }
+ if err != nil {
+ // t.Fatalf must not be called from a non-test goroutine;
+ // report the error and stop the reader instead.
+ t.Errorf("Readv: got unexpected error %v", err)
+ return
+ }
+ }
+ }()
+
+ src := usermem.BytesIOSequence(msg)
+ e, ch := waiter.NewChannelEntry(nil)
+ w.EventRegister(&e, waiter.EventOut)
+ defer w.EventUnregister(&e)
+ for src.NumBytes() != 0 {
+ n, err := w.Writev(ctx, src)
+ src = src.DropFirst64(n)
+ if err == syserror.ErrWouldBlock {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("Writev: got (%d, %v)", n, err)
+ }
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go
new file mode 100644
index 000000000..dd60cba24
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "unsafe"
+)
+
+// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be
+// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that
+// concurrent calls cannot deadlock.
+//
+// Preconditions: x != y.
+func lockTwoPipes(x, y *Pipe) {
+ // Lock the two pipes in order of increasing address.
+ if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) {
+ x.mu.Lock()
+ y.mu.Lock()
+ } else {
+ y.mu.Lock()
+ x.mu.Lock()
+ }
+}
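// Editorial aside — an illustrative sketch, not part of this change:
// demoLockOrder is a hypothetical function showing why the address ordering
// above is deadlock-free. Both goroutines acquire the two mutexes in the
// same global order, so neither can hold one lock while waiting for the
// other lock in the reverse order.
func demoLockOrder(a, b *Pipe) {
	var wg sync.WaitGroup // assumes the package's sync import
	wg.Add(2)
	go func() {
		defer wg.Done()
		lockTwoPipes(a, b)
		b.mu.Unlock()
		a.mu.Unlock()
	}()
	go func() {
		defer wg.Done()
		lockTwoPipes(b, a) // same locks, opposite argument order
		a.mu.Unlock()
		b.mu.Unlock()
	}()
	wg.Wait()
}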
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
new file mode 100644
index 000000000..aacf28da2
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -0,0 +1,214 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "io"
+ "math"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains Pipe file functionality that is tied to neither VFS nor
+// the old fs architecture.
+
+// Release cleans up the pipe's state.
+func (p *Pipe) Release() {
+ p.rClose()
+ p.wClose()
+
+ // Wake up readers and writers.
+ p.Notify(waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads from the Pipe into dst.
+func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ n, err := p.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, view)
+ dst = dst.DropFirst64(n)
+ view.TrimFront(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo writes to w from the Pipe.
+func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToWriter(w, count)
+ if !dup {
+ view.TrimFront(n)
+ }
+ count -= n
+ return n, err
+ },
+ }
+ n, err := p.read(ctx, ops)
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// Write writes to the Pipe from src.
+func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(view *buffer.View) (int64, error) {
+ n, err := src.CopyInTo(ctx, view)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom reads from r to the Pipe.
+func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return p.rwReadiness() & mask
+}
+
+// Ioctl implements ioctls on the Pipe.
+func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch int(args[1].Int()) {
+ case linux.FIONREAD:
+ v := p.queued()
+ if v > math.MaxInt32 {
+ v = math.MaxInt32 // Silently truncate.
+ }
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ default:
+ return 0, syscall.ENOTTY
+ }
+}
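// Editorial aside — an illustrative sketch, not part of this change: the
// FIONREAD behavior emulated above, exercised against a host pipe. Assumes
// golang.org/x/sys/unix plus the fmt and os packages.
func demoFIONREAD() {
	r, w, _ := os.Pipe()
	defer r.Close()
	defer w.Close()
	w.Write([]byte("hello"))
	// FIONREAD reports the number of bytes currently queued in the pipe.
	n, _ := unix.IoctlGetInt(int(r.Fd()), unix.FIONREAD)
	fmt.Println(n) // 5
}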
+
+// waitFor blocks until a new reader or writer of the underlying pipe is
+// announced via 'wakeupChan', or until 'sleeper' is cancelled. Whether a call
+// blocks for readers or writers depends on where 'wakeupChan' points. It
+// returns true if it was woken via 'wakeupChan' and false if the sleep was
+// cancelled.
+//
+// mu must be held by the caller. waitFor returns with mu held, but it will
+// drop mu before blocking for any reader/writers.
+func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
+ // Ideally this function would simply use a condition variable. However, the
+ // wait needs to be interruptible via 'sleeper', so we must synchronize via a
+ // channel. The synchronization below relies on the fact that closing a
+ // channel unblocks all receives on the channel.
+
+ // Does an appropriate wakeup channel already exist? If not, create a new
+ // one. This is all done under mu to avoid races.
+ if *wakeupChan == nil {
+ *wakeupChan = make(chan struct{})
+ }
+
+ // Grab a local reference to the wakeup channel since it may disappear as
+ // soon as we drop mu.
+ wakeup := *wakeupChan
+
+ // Drop the lock and prepare to sleep.
+ mu.Unlock()
+ cancel := sleeper.SleepStart()
+
+ // Wait for either a new reader/writer to be signalled via 'wakeup', or
+ // for the sleep to be cancelled.
+ select {
+ case <-wakeup:
+ sleeper.SleepFinish(true)
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ }
+
+ // Take the lock and check if we were woken. If we were woken and
+ // interrupted, the former takes priority.
+ mu.Lock()
+ select {
+ case <-wakeup:
+ return true
+ default:
+ return false
+ }
+}
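// Editorial aside — an illustrative toy version of the pattern above, not
// part of this change: waitOrCancel substitutes a plain cancel channel for
// amutex.Sleeper and omits the final re-check that lets a wakeup win a race
// against cancellation. Closing the wakeup channel unblocks every receiver.
func waitOrCancel(mu *sync.Mutex, wakeupChan *chan struct{}, cancel <-chan struct{}) bool {
	if *wakeupChan == nil {
		*wakeupChan = make(chan struct{})
	}
	wakeup := *wakeupChan // local copy; *wakeupChan may change once mu is dropped
	mu.Unlock()
	defer mu.Lock()
	select {
	case <-wakeup:
		return true
	case <-cancel:
		return false
	}
}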
+
+// newHandleLocked signals a new pipe reader or writer depending on where
+// 'wakeupChan' points. This unblocks any corresponding reader or writer
+// waiting for the other end of the pipe to be opened; see waitFor.
+//
+// Precondition: the mutex protecting wakeupChan must be held.
+func newHandleLocked(wakeupChan *chan struct{}) {
+ if *wakeupChan != nil {
+ close(*wakeupChan)
+ *wakeupChan = nil
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/reader.go b/pkg/sentry/kernel/pipe/reader.go
new file mode 100644
index 000000000..7724b4452
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/reader.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Reader satisfies the fs.FileOperations interface for read-only pipes.
+// Reader should be used with !fs.FileFlags.Write to reject writes.
+//
+// +stateify savable
+type Reader struct {
+ ReaderWriter
+}
+
+// Release implements fs.FileOperations.Release.
+//
+// This overrides ReaderWriter.Release.
+func (r *Reader) Release() {
+ r.Pipe.rClose()
+
+ // Wake up writers.
+ r.Pipe.Notify(waiter.EventOut)
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (r *Reader) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return r.Pipe.rReadiness() & mask
+}
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
new file mode 100644
index 000000000..b2b5691ee
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ReaderWriter satisfies the FileOperations interface and services both
+// read and write requests. This should only be used directly for named pipes.
+// pipe(2) and pipe2(2) only support unidirectional pipes and should use
+// either pipe.Reader or pipe.Writer.
+//
+// +stateify savable
+type ReaderWriter struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ *Pipe
+}
+
+// Read implements fs.FileOperations.Read.
+func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return rw.Pipe.Read(ctx, dst)
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
+ return rw.Pipe.WriteTo(ctx, w, count, dup)
+}
+
+// Write implements fs.FileOperations.Write.
+func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return rw.Pipe.Write(ctx, src)
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ return rw.Pipe.ReadFrom(ctx, r, count)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return rw.Pipe.Ioctl(ctx, io, args)
+}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
new file mode 100644
index 000000000..45d4c5fc1
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -0,0 +1,468 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains types enabling the pipe package to be used with the vfs
+// package.
+
+// VFSPipe represents the actual pipe, analogous to an inode. VFSPipes should
+// not be copied.
+type VFSPipe struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // pipe is the underlying pipe.
+ pipe Pipe
+
+ // Channels for synchronizing the creation of new readers and writers
+ // of this fifo. See waitFor and newHandleLocked.
+ //
+ // These are not saved/restored because all waiters are unblocked on
+ // save, and either automatically restart (via ERESTARTSYS) or return
+ // EINTR on resume. On restarts via ERESTARTSYS, the appropriate
+ // channel will be recreated.
+ rWakeup chan struct{} `state:"nosave"`
+ wWakeup chan struct{} `state:"nosave"`
+}
+
+// NewVFSPipe returns an initialized VFSPipe.
+func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
+ var vp VFSPipe
+ initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
+ return &vp
+}
+
+// ReaderWriterPair returns read-only and write-only FDs for vp.
+//
+// Preconditions: statusFlags should not contain an open access mode.
+func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ // Connected pipes share the same locks.
+ locks := &vfs.FileLocks{}
+ return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
+}
+
+// Open opens the pipe represented by vp.
+func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
+ vp.mu.Lock()
+ defer vp.mu.Unlock()
+
+ readable := vfs.MayReadFileWithOpenFlags(statusFlags)
+ writable := vfs.MayWriteFileWithOpenFlags(statusFlags)
+ if !readable && !writable {
+ return nil, syserror.EINVAL
+ }
+
+ fd := vp.newFD(mnt, vfsd, statusFlags, locks)
+
+ // Named pipes have special blocking semantics during open:
+ //
+ // "Normally, opening the FIFO blocks until the other end is opened also. A
+ // process can open a FIFO in nonblocking mode. In this case, opening for
+ // read-only will succeed even if no-one has opened on the write side yet,
+ // opening for write-only will fail with ENXIO (no such device or address)
+ // unless the other end has already been opened. Under Linux, opening a
+ // FIFO for read and write will succeed both in blocking and nonblocking
+ // mode. POSIX leaves this behavior undefined. This can be used to open a
+ // FIFO for writing while there are no readers available." - fifo(7)
+ switch {
+ case readable && writable:
+ // Pipes opened for read-write always succeed without blocking.
+ newHandleLocked(&vp.rWakeup)
+ newHandleLocked(&vp.wWakeup)
+
+ case readable:
+ newHandleLocked(&vp.rWakeup)
+ // If this pipe is being opened as blocking and there's no
+ // writer, we have to wait for a writer to open the other end.
+ if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ fd.DecRef()
+ return nil, syserror.EINTR
+ }
+
+ case writable:
+ newHandleLocked(&vp.wWakeup)
+
+ if vp.pipe.isNamed && !vp.pipe.HasReaders() {
+ // Non-blocking, write-only opens fail with ENXIO when the read
+ // side isn't open yet.
+ if statusFlags&linux.O_NONBLOCK != 0 {
+ fd.DecRef()
+ return nil, syserror.ENXIO
+ }
+ // Wait for a reader to open the other end.
+ if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ fd.DecRef()
+ return nil, syserror.EINTR
+ }
+ }
+
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return fd, nil
+}
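// Editorial aside — an illustrative sketch, not part of this change: the
// fifo(7) open semantics implemented above, exercised against a host FIFO.
// Assumes golang.org/x/sys/unix plus fmt; the path is hypothetical.
func demoFifoOpen() {
	path := "/tmp/demo.fifo"
	unix.Mkfifo(path, 0600)
	// A nonblocking write-only open fails with ENXIO while no reader exists.
	_, err := unix.Open(path, unix.O_WRONLY|unix.O_NONBLOCK, 0)
	fmt.Println(err) // ENXIO
	// A nonblocking read-only open succeeds even with no writer.
	rfd, _ := unix.Open(path, unix.O_RDONLY|unix.O_NONBLOCK, 0)
	unix.Close(rfd)
}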
+
+// Preconditions: vp.mu must be held.
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription {
+ fd := &VFSPipeFD{
+ pipe: &vp.pipe,
+ }
+ fd.LockFD.Init(locks)
+ fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ })
+
+ switch {
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
+ vp.pipe.rOpen()
+ vp.pipe.wOpen()
+ case fd.vfsfd.IsReadable():
+ vp.pipe.rOpen()
+ case fd.vfsfd.IsWritable():
+ vp.pipe.wOpen()
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return &fd.vfsfd
+}
+
+// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
+// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
+// other FileDescriptions for splice(2) and tee(2).
+type VFSPipeFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ pipe *Pipe
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *VFSPipeFD) Release() {
+ var event waiter.EventMask
+ if fd.vfsfd.IsReadable() {
+ fd.pipe.rClose()
+ event |= waiter.EventOut
+ }
+ if fd.vfsfd.IsWritable() {
+ fd.pipe.wClose()
+ event |= waiter.EventIn | waiter.EventHUp
+ }
+ if event == 0 {
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ fd.pipe.Notify(event)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ switch {
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
+ return fd.pipe.rwReadiness()
+ case fd.vfsfd.IsReadable():
+ return fd.pipe.rReadiness()
+ case fd.vfsfd.IsWritable():
+ return fd.pipe.wReadiness()
+ default:
+ panic("pipe FD is neither readable nor writable")
+ }
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *VFSPipeFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ESPIPE
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.pipe.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *VFSPipeFD) EventUnregister(e *waiter.Entry) {
+ fd.pipe.EventUnregister(e)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return fd.pipe.Read(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return fd.pipe.Write(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.pipe.Ioctl(ctx, uio, args)
+}
+
+// PipeSize implements fcntl(F_GETPIPE_SZ).
+func (fd *VFSPipeFD) PipeSize() int64 {
+ // Inline Pipe.FifoSize() rather than calling it with nil Context and
+ // fs.File and ignoring the returned error (which is always nil).
+ fd.pipe.mu.Lock()
+ defer fd.pipe.mu.Unlock()
+ return fd.pipe.max
+}
+
+// SetPipeSize implements fcntl(F_SETPIPE_SZ).
+func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
+ return fd.pipe.SetFifoSize(size)
+}
+
+// IOSequence returns a usermem.IOSequence that reads up to count bytes from,
+// or writes up to count bytes to, fd.
+func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence {
+ return usermem.IOSequence{
+ IO: fd,
+ Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
+ }
+}
+
+// CopyIn implements usermem.IO.CopyIn.
+func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(dst))
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return int64(len(dst))
+ },
+ limit: func(l int64) {
+ dst = dst[:l]
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadAt(dst, 0)
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// CopyOut implements usermem.IO.CopyOut.
+func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(src))
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return int64(len(src))
+ },
+ limit: func(l int64) {
+ src = src[:l]
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Append(src)
+ return int64(len(src)), nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// ZeroOut implements usermem.IO.ZeroOut.
+func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+ origCount := toZero
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return toZero
+ },
+ limit: func(l int64) {
+ toZero = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Grow(view.Size()+toZero, true /* zero */)
+ return toZero, nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyInTo implements usermem.IO.CopyInTo.
+func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToSafememWriter(dst, uint64(count))
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyOutFrom implements usermem.IO.CopyOutFrom.
+func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromSafememReader(src, uint64(count))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// SwapUint32 implements usermem.IO.SwapUint32.
+func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+ // How did a pipe get passed as the virtual address space to futex(2)?
+ panic("VFSPipeFD.SwapUint32 called unexpectedly")
+}
+
+// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
+func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly")
+}
+
+// LoadUint32 implements usermem.IO.LoadUint32.
+func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.LoadUint32 called unexpectedly")
+}
+
+// Splice reads up to count bytes from src and writes them to dst. It returns
+// the number of bytes moved.
+//
+// Preconditions: count > 0.
+func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */)
+}
+
+// Tee reads up to count bytes from src and writes them to dst, without
+// removing the read bytes from src. It returns the number of bytes copied.
+//
+// Preconditions: count > 0.
+func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */)
+}
+
+// Preconditions: count > 0.
+func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) {
+ if dst.pipe == src.pipe {
+ return 0, syserror.EINVAL
+ }
+
+ lockTwoPipes(dst.pipe, src.pipe)
+ defer dst.pipe.mu.Unlock()
+ defer src.pipe.mu.Unlock()
+
+ n, err := dst.pipe.writeLocked(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(dstView *buffer.View) (int64, error) {
+ return src.pipe.readLocked(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(srcView *buffer.View) (int64, error) {
+ n, err := srcView.ReadToSafememWriter(dstView, uint64(count))
+ if n > 0 && removeFromSrc {
+ srcView.TrimFront(int64(n))
+ }
+ return int64(n), err
+ },
+ })
+ },
+ })
+ if n > 0 {
+ dst.pipe.Notify(waiter.EventIn)
+ src.pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
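// Editorial aside — an illustrative sketch, not part of this change: the
// syscall-level behavior that Splice and Tee above implement, exercised
// against host pipes. Assumes golang.org/x/sys/unix plus os.
func demoSpliceTee() {
	r1, w1, _ := os.Pipe()
	_, w2, _ := os.Pipe()
	w1.Write([]byte("payload"))
	// Tee copies without consuming: r1 still holds the 7 bytes afterwards.
	unix.Tee(int(r1.Fd()), int(w2.Fd()), 7, 0)
	// Splice moves and consumes: r1 is empty afterwards.
	unix.Splice(int(r1.Fd()), nil, int(w2.Fd()), nil, 7, 0)
}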
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/kernel/pipe/writer.go b/pkg/sentry/kernel/pipe/writer.go
new file mode 100644
index 000000000..5bc6aa931
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/writer.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Writer satisfies the fs.FileOperations interface for write-only pipes.
+// Writer should be used with !fs.FileFlags.Read to reject reads.
+//
+// +stateify savable
+type Writer struct {
+ ReaderWriter
+}
+
+// Release implements fs.FileOperations.Release.
+//
+// This overrides ReaderWriter.Release.
+func (w *Writer) Release() {
+ w.Pipe.wClose()
+
+ // Wake up readers.
+ w.Pipe.Notify(waiter.EventHUp)
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (w *Writer) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return w.Pipe.wReadiness() & mask
+}
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
new file mode 100644
index 000000000..2e861a5a8
--- /dev/null
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -0,0 +1,308 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// IntervalTimer represents a POSIX interval timer as described by
+// timer_create(2).
+//
+// +stateify savable
+type IntervalTimer struct {
+ timer *ktime.Timer
+
+ // If target is not nil, it receives signo from timer expirations. If group
+ // is true, these signals are thread-group-directed. These fields are
+ // immutable.
+ target *Task
+ signo linux.Signal
+ id linux.TimerID
+ sigval uint64
+ group bool
+
+ // If sigpending is true, a signal to target is already queued, and timer
+ // expirations should increment overrunCur instead of sending another
+ // signal. sigpending is protected by target's signal mutex. (If target is
+ // nil, the timer will never send signals, so sigpending will be unused.)
+ sigpending bool
+
+ // If sigorphan is true, timer's setting has been changed since sigpending
+ // last became true, such that overruns should no longer be counted in the
+ // pending signals si_overrun. sigorphan is protected by target's signal
+ // mutex.
+ sigorphan bool
+
+ // overrunCur is the number of overruns that have occurred since the last
+ // time a signal was sent. overrunCur is protected by target's signal
+ // mutex.
+ overrunCur uint64
+
+ // Consider the last signal sent by this timer that has been dequeued.
+ // overrunLast is the number of overruns that occurred between when this
+ // signal was sent and when it was dequeued. Equivalently, overrunLast was
+ // the value of overrunCur when this signal was dequeued. overrunLast is
+ // protected by target's signal mutex.
+ overrunLast uint64
+}
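// Editorial aside — an illustrative toy model of the overrun accounting
// described in the field comments above, not part of this change. The real
// logic lives in Notify and updateDequeuedSignalLocked below; toyOverrun is
// a hypothetical helper that compresses it into one sequence.
func toyOverrun(it *IntervalTimer, expirations []uint64) int32 {
	for _, exp := range expirations {
		if it.sigpending {
			// A signal is already queued; these expirations are overruns.
			it.overrunCur += exp
			continue
		}
		it.sigpending = true
		it.overrunCur += exp - 1 // one expiration becomes the signal itself
	}
	// On dequeue, the accumulated count becomes what timer_getoverrun(2)
	// reports, saturated to int32.
	it.overrunLast, it.overrunCur = it.overrunCur, 0
	it.sigpending = false
	return saturateI32FromU64(it.overrunLast)
}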
+
+// DestroyTimer releases its resources.
+func (it *IntervalTimer) DestroyTimer() {
+ it.timer.Destroy()
+ it.timerSettingChanged()
+ // A destroyed IntervalTimer is still potentially reachable via a
+ // pendingSignal; nil out timer so that it won't be saved.
+ it.timer = nil
+}
+
+func (it *IntervalTimer) timerSettingChanged() {
+ if it.target == nil {
+ return
+ }
+ it.target.tg.pidns.owner.mu.RLock()
+ defer it.target.tg.pidns.owner.mu.RUnlock()
+ it.target.tg.signalHandlers.mu.Lock()
+ defer it.target.tg.signalHandlers.mu.Unlock()
+ it.sigorphan = true
+ it.overrunCur = 0
+ it.overrunLast = 0
+}
+
+// PauseTimer pauses the associated Timer.
+func (it *IntervalTimer) PauseTimer() {
+ it.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (it *IntervalTimer) ResumeTimer() {
+ it.timer.Resume()
+}
+
+// Preconditions: it.target's signal mutex must be locked.
+func (it *IntervalTimer) updateDequeuedSignalLocked(si *arch.SignalInfo) {
+ it.sigpending = false
+ if it.sigorphan {
+ return
+ }
+ it.overrunLast = it.overrunCur
+ it.overrunCur = 0
+ si.SetOverrun(saturateI32FromU64(it.overrunLast))
+}
+
+// Preconditions: it.target's signal mutex must be locked.
+func (it *IntervalTimer) signalRejectedLocked() {
+ it.sigpending = false
+ if it.sigorphan {
+ return
+ }
+ it.overrunCur++
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ if it.target == nil {
+ return ktime.Setting{}, false
+ }
+
+ it.target.tg.pidns.owner.mu.RLock()
+ defer it.target.tg.pidns.owner.mu.RUnlock()
+ it.target.tg.signalHandlers.mu.Lock()
+ defer it.target.tg.signalHandlers.mu.Unlock()
+
+ if it.sigpending {
+ it.overrunCur += exp
+ return ktime.Setting{}, false
+ }
+
+ // sigpending must be set before sendSignalTimerLocked() so that it can be
+ // unset if the signal is discarded (in which case sendSignalTimerLocked()
+ // will return nil).
+ it.sigpending = true
+ it.sigorphan = false
+ it.overrunCur += exp - 1
+ si := &arch.SignalInfo{
+ Signo: int32(it.signo),
+ Code: arch.SignalInfoTimer,
+ }
+ si.SetTimerID(it.id)
+ si.SetSigval(it.sigval)
+ // si_overrun is set when the signal is dequeued.
+ if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil {
+ it.signalRejectedLocked()
+ }
+
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy. Users of Timer should call
+// DestroyTimer instead.
+func (it *IntervalTimer) Destroy() {
+}
+
+// IntervalTimerCreate implements timer_create(2).
+func (t *Task) IntervalTimerCreate(c ktime.Clock, sigev *linux.Sigevent) (linux.TimerID, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+
+ // Allocate a timer ID.
+ var id linux.TimerID
+ end := t.tg.nextTimerID
+ for {
+ id = t.tg.nextTimerID
+ _, ok := t.tg.timers[id]
+ t.tg.nextTimerID++
+ if t.tg.nextTimerID < 0 {
+ t.tg.nextTimerID = 0
+ }
+ if !ok {
+ break
+ }
+ if t.tg.nextTimerID == end {
+ return 0, syserror.EAGAIN
+ }
+ }
+
+ // "The implementation of the default case where evp [sic] is NULL is
+ // handled inside glibc, which invokes the underlying system call with a
+ // suitably populated sigevent structure." - timer_create(2). This is
+ // misleading; the timer_create syscall also handles a NULL sevp as
+ // described by the man page
+ // (kernel/time/posix-timers.c:sys_timer_create(), do_timer_create()). This
+ // must be handled here instead of the syscall wrapper since sigval is the
+ // timer ID, which isn't available until we allocate it in this function.
+ if sigev == nil {
+ sigev = &linux.Sigevent{
+ Signo: int32(linux.SIGALRM),
+ Notify: linux.SIGEV_SIGNAL,
+ Value: uint64(id),
+ }
+ }
+
+ // Construct the timer.
+ it := &IntervalTimer{
+ id: id,
+ sigval: sigev.Value,
+ }
+ switch sigev.Notify {
+ case linux.SIGEV_NONE:
+ // leave it.target = nil
+ case linux.SIGEV_SIGNAL, linux.SIGEV_THREAD:
+ // POSIX SIGEV_THREAD semantics are implemented in userspace by libc;
+ // to the kernel, SIGEV_THREAD and SIGEV_SIGNAL are equivalent. (See
+ // Linux's kernel/time/posix-timers.c:good_sigevent().)
+ it.target = t.tg.leader
+ it.group = true
+ case linux.SIGEV_THREAD_ID:
+ t.tg.pidns.owner.mu.RLock()
+ target, ok := t.tg.pidns.tasks[ThreadID(sigev.Tid)]
+ t.tg.pidns.owner.mu.RUnlock()
+ if !ok || target.tg != t.tg {
+ return 0, syserror.EINVAL
+ }
+ it.target = target
+ default:
+ return 0, syserror.EINVAL
+ }
+ if sigev.Notify != linux.SIGEV_NONE {
+ it.signo = linux.Signal(sigev.Signo)
+ if !it.signo.IsValid() {
+ return 0, syserror.EINVAL
+ }
+ }
+ it.timer = ktime.NewTimer(c, it)
+
+ t.tg.timers[id] = it
+ return id, nil
+}
+
+// IntervalTimerDelete implements timer_delete(2).
+func (t *Task) IntervalTimerDelete(id linux.TimerID) error {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return syserror.EINVAL
+ }
+ delete(t.tg.timers, id)
+ it.DestroyTimer()
+ return nil
+}
+
+// IntervalTimerSettime implements timer_settime(2).
+func (t *Task) IntervalTimerSettime(id linux.TimerID, its linux.Itimerspec, abs bool) (linux.Itimerspec, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return linux.Itimerspec{}, syserror.EINVAL
+ }
+
+ newS, err := ktime.SettingFromItimerspec(its, abs, it.timer.Clock())
+ if err != nil {
+ return linux.Itimerspec{}, err
+ }
+ tm, oldS := it.timer.SwapAnd(newS, it.timerSettingChanged)
+ its = ktime.ItimerspecFromSetting(tm, oldS)
+ return its, nil
+}
+
+// IntervalTimerGettime implements timer_gettime(2).
+func (t *Task) IntervalTimerGettime(id linux.TimerID) (linux.Itimerspec, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return linux.Itimerspec{}, syserror.EINVAL
+ }
+
+ tm, s := it.timer.Get()
+ its := ktime.ItimerspecFromSetting(tm, s)
+ return its, nil
+}
+
+// IntervalTimerGetoverrun implements timer_getoverrun(2).
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) IntervalTimerGetoverrun(id linux.TimerID) (int32, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return 0, syserror.EINVAL
+ }
+ // By timer_create(2) invariant, either it.target == nil (in which case
+ // it.overrunLast is immutably 0) or t.tg == it.target.tg; and the fact
+ // that t is executing timer_getoverrun(2) means that t.tg can't be
+ // completing execve, so t.tg.signalHandlers can't be changing, allowing us
+ // to lock t.tg.signalHandlers.mu without holding the TaskSet mutex.
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ // This is consistent with Linux after 78c9c4dfbf8c ("posix-timers:
+ // Sanitize overrun handling").
+ return saturateI32FromU64(it.overrunLast), nil
+}
+
+func saturateI32FromU64(x uint64) int32 {
+ if x > math.MaxInt32 {
+ return math.MaxInt32
+ }
+ return int32(x)
+}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
new file mode 100644
index 000000000..e23e796ef
--- /dev/null
+++ b/pkg/sentry/kernel/ptrace.go
@@ -0,0 +1,1119 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ptraceOptions are the subset of options controlling a task's ptrace behavior
+// that are set by ptrace(PTRACE_SETOPTIONS).
+//
+// +stateify savable
+type ptraceOptions struct {
+ // ExitKill is true if the tracee should be sent SIGKILL when the tracer
+ // exits.
+ ExitKill bool
+
+ // If SysGood is true, set bit 7 in the signal number for
+ // syscall-entry-stop and syscall-exit-stop traps delivered to this task's
+ // tracer.
+ SysGood bool
+
+ // TraceClone is true if the tracer wants to receive PTRACE_EVENT_CLONE
+ // events.
+ TraceClone bool
+
+ // TraceExec is true if the tracer wants to receive PTRACE_EVENT_EXEC
+ // events.
+ TraceExec bool
+
+ // TraceExit is true if the tracer wants to receive PTRACE_EVENT_EXIT
+ // events.
+ TraceExit bool
+
+ // TraceFork is true if the tracer wants to receive PTRACE_EVENT_FORK
+ // events.
+ TraceFork bool
+
+ // TraceSeccomp is true if the tracer wants to receive PTRACE_EVENT_SECCOMP
+ // events.
+ TraceSeccomp bool
+
+ // TraceVfork is true if the tracer wants to receive PTRACE_EVENT_VFORK
+ // events.
+ TraceVfork bool
+
+ // TraceVforkDone is true if the tracer wants to receive
+ // PTRACE_EVENT_VFORK_DONE events.
+ TraceVforkDone bool
+}
+
+// ptraceSyscallMode controls the behavior of a ptraced task at syscall entry
+// and exit.
+type ptraceSyscallMode int
+
+const (
+ // ptraceSyscallNone indicates that the task has never ptrace-stopped, or
+ // that it was resumed from its last ptrace-stop by PTRACE_CONT or
+ // PTRACE_DETACH. The task's syscalls will not be intercepted.
+ ptraceSyscallNone ptraceSyscallMode = iota
+
+ // ptraceSyscallIntercept indicates that the task was resumed from its last
+ // ptrace-stop by PTRACE_SYSCALL. The next time the task enters or exits a
+ // syscall, a ptrace-stop will occur.
+ ptraceSyscallIntercept
+
+ // ptraceSyscallEmu indicates that the task was resumed from its last
+ // ptrace-stop by PTRACE_SYSEMU or PTRACE_SYSEMU_SINGLESTEP. The next time
+ // the task enters a syscall, the syscall will be skipped, and a
+ // ptrace-stop will occur.
+ ptraceSyscallEmu
+)
+
+// CanTrace checks that t is permitted to access target's state, as defined by
+// ptrace(2), subsection "Ptrace access mode checking". If attach is true, it
+// checks for access mode PTRACE_MODE_ATTACH; otherwise, it checks for access
+// mode PTRACE_MODE_READ.
+//
+// NOTE(b/30815691): The result of CanTrace is immediately stale (e.g., a
+// racing setuid(2) may change traceability). This may pose a risk when a task
+// changes from traceable to not traceable. This is only problematic across
+// execve, where privileges may increase.
+//
+// We currently do not implement privileged executables (set-user/group-ID bits
+// and file capabilities), so that case is not reachable.
+func (t *Task) CanTrace(target *Task, attach bool) bool {
+ // "1. If the calling thread and the target thread are in the same thread
+ // group, access is always allowed." - ptrace(2)
+ //
+ // Note: Strictly speaking, prior to 73af963f9f30 ("__ptrace_may_access()
+ // should not deny sub-threads", first released in Linux 3.12), the rule
+ // only applies if t and target are the same task. But, as that commit
+ // message puts it, "[any] security check is pointless when the tasks share
+ // the same ->mm."
+ if t.tg == target.tg {
+ return true
+ }
+
+ // """
+ // 2. If the access mode specifies PTRACE_MODE_FSCREDS (ED: snipped,
+ // doesn't exist until Linux 4.5).
+ //
+ // Otherwise, the access mode specifies PTRACE_MODE_REALCREDS, so use the
+ // caller's real UID and GID for the checks in the next step. (Most APIs
+ // that check the caller's UID and GID use the effective IDs. For
+ // historical reasons, the PTRACE_MODE_REALCREDS check uses the real IDs
+ // instead.)
+ //
+ // 3. Deny access if neither of the following is true:
+ //
+ // - The real, effective, and saved-set user IDs of the target match the
+ // caller's user ID, *and* the real, effective, and saved-set group IDs of
+ // the target match the caller's group ID.
+ //
+ // - The caller has the CAP_SYS_PTRACE capability in the user namespace of
+ // the target.
+ //
+ // 4. Deny access if the target process "dumpable" attribute has a value
+ // other than 1 (SUID_DUMP_USER; see the discussion of PR_SET_DUMPABLE in
+ // prctl(2)), and the caller does not have the CAP_SYS_PTRACE capability in
+ // the user namespace of the target process.
+ //
+ // 5. The kernel LSM security_ptrace_access_check() interface is invoked to
+ // see if ptrace access is permitted. The results depend on the LSM(s). The
+ // implementation of this interface in the commoncap LSM performs the
+ // following steps:
+ //
+ // a) If the access mode includes PTRACE_MODE_FSCREDS, then use the
+ // caller's effective capability set; otherwise (the access mode specifies
+ // PTRACE_MODE_REALCREDS, so) use the caller's permitted capability set.
+ //
+ // b) Deny access if neither of the following is true:
+ //
+ // - The caller and the target process are in the same user namespace, and
+ // the caller's capabilities are a proper superset of the target process's
+ // permitted capabilities.
+ //
+ // - The caller has the CAP_SYS_PTRACE capability in the target process's
+ // user namespace.
+ //
+ // Note that the commoncap LSM does not distinguish between
+ // PTRACE_MODE_READ and PTRACE_MODE_ATTACH. (ED: From earlier in this
+ // section: "the commoncap LSM ... is always invoked".)
+ // """
+ callerCreds := t.Credentials()
+ targetCreds := target.Credentials()
+ if callerCreds.HasCapabilityIn(linux.CAP_SYS_PTRACE, targetCreds.UserNamespace) {
+ return true
+ }
+ if cuid := callerCreds.RealKUID; cuid != targetCreds.RealKUID || cuid != targetCreds.EffectiveKUID || cuid != targetCreds.SavedKUID {
+ return false
+ }
+ if cgid := callerCreds.RealKGID; cgid != targetCreds.RealKGID || cgid != targetCreds.EffectiveKGID || cgid != targetCreds.SavedKGID {
+ return false
+ }
+ var targetMM *mm.MemoryManager
+ target.WithMuLocked(func(t *Task) {
+ targetMM = t.MemoryManager()
+ })
+ if targetMM != nil && targetMM.Dumpability() != mm.UserDumpable {
+ return false
+ }
+ if callerCreds.UserNamespace != targetCreds.UserNamespace {
+ return false
+ }
+ if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 {
+ return false
+ }
+ return true
+}
+
+// Tracer returns t's ptrace Tracer.
+func (t *Task) Tracer() *Task {
+ return t.ptraceTracer.Load().(*Task)
+}
+
+// hasTracer returns true if t has a ptrace tracer attached.
+func (t *Task) hasTracer() bool {
+ // This isn't just inlined into callers so that if Task.Tracer() turns out
+ // to be too expensive because of e.g. interface conversion, we can switch
+ // to having a separate atomic flag more easily.
+ return t.Tracer() != nil
+}
+
+// ptraceStop is a TaskStop placed on tasks in a ptrace-stop.
+//
+// +stateify savable
+type ptraceStop struct {
+ // If frozen is true, the stopped task's tracer is currently operating on
+ // it, so Task.Kill should not remove the stop.
+ frozen bool
+
+ // If listen is true, the stopped task's tracer invoked PTRACE_LISTEN, so
+ // ptraceFreeze should fail.
+ listen bool
+}
+
+// Killable implements TaskStop.Killable.
+func (s *ptraceStop) Killable() bool {
+ return !s.frozen
+}
+
+// beginPtraceStopLocked initiates an unfrozen ptrace-stop on t. If t has been
+// killed, the stop is skipped, and beginPtraceStopLocked returns false.
+//
+// beginPtraceStopLocked does not signal t's tracer or wake it if it is
+// waiting.
+//
+// Preconditions: The TaskSet mutex must be locked. The caller must be running
+// on the task goroutine.
+func (t *Task) beginPtraceStopLocked() bool {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ // This is analogous to Linux's kernel/signal.c:ptrace_stop() => ... =>
+ // kernel/sched/core.c:__schedule() => signal_pending_state() check, which
+ // is what prevents tasks from entering ptrace-stops after being killed.
+ // Note that if t was SIGKILLed and beginPtraceStopLocked is being called
+ // for PTRACE_EVENT_EXIT, the task will have dequeued the signal before
+ // entering the exit path, so t.killedLocked() will no longer return true.
+ // This is consistent with Linux: "Bugs: ... A SIGKILL signal may still
+ // cause a PTRACE_EVENT_EXIT stop before actual signal death. This may be
+ // changed in the future; SIGKILL is meant to always immediately kill tasks
+ // even under ptrace. Last confirmed on Linux 3.13." - ptrace(2)
+ if t.killedLocked() {
+ return false
+ }
+ t.beginInternalStopLocked(&ptraceStop{})
+ return true
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) ptraceTrapLocked(code int32) {
+ // This is unconditional in ptrace_stop().
+ t.tg.signalHandlers.mu.Lock()
+ t.trapStopPending = false
+ t.tg.signalHandlers.mu.Unlock()
+ t.ptraceCode = code
+ t.ptraceSiginfo = &arch.SignalInfo{
+ Signo: int32(linux.SIGTRAP),
+ Code: code,
+ }
+ t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ if t.beginPtraceStopLocked() {
+ tracer := t.Tracer()
+ tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
+ tracer.tg.eventQueue.Notify(EventTraceeStop)
+ }
+}
+
+// ptraceFreeze checks if t is in a ptraceStop. If so, it freezes the
+// ptraceStop, temporarily preventing it from being removed by a concurrent
+// Task.Kill, and returns true. Otherwise it returns false.
+//
+// Preconditions: The TaskSet mutex must be locked. The caller must be running
+// on the task goroutine of t's tracer.
+func (t *Task) ptraceFreeze() bool {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.stop == nil {
+ return false
+ }
+ s, ok := t.stop.(*ptraceStop)
+ if !ok {
+ return false
+ }
+ if s.listen {
+ return false
+ }
+ s.frozen = true
+ return true
+}
+
+// ptraceUnfreeze ends the effect of a previous successful call to
+// ptraceFreeze.
+//
+// Preconditions: t must be in a frozen ptraceStop.
+func (t *Task) ptraceUnfreeze() {
+ // t.tg.signalHandlers is stable because t is in a frozen ptrace-stop,
+ // preventing its thread group from completing execve.
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.ptraceUnfreezeLocked()
+}
+
+// Preconditions: t must be in a frozen ptraceStop. t's signal mutex must be
+// locked.
+func (t *Task) ptraceUnfreezeLocked() {
+ // Do this even if the task has been killed to ensure a panic if t.stop is
+ // nil or not a ptraceStop.
+ t.stop.(*ptraceStop).frozen = false
+ if t.killedLocked() {
+ t.endInternalStopLocked()
+ }
+}
+
+// ptraceUnstop implements ptrace request PTRACE_CONT, PTRACE_SYSCALL,
+// PTRACE_SINGLESTEP, PTRACE_SYSEMU, or PTRACE_SYSEMU_SINGLESTEP depending on
+// mode and singlestep.
+//
+// Preconditions: t must be in a frozen ptrace stop.
+//
+// Postconditions: If ptraceUnstop returns nil, t will no longer be in a ptrace
+// stop.
+func (t *Task) ptraceUnstop(mode ptraceSyscallMode, singlestep bool, sig linux.Signal) error {
+ if sig != 0 && !sig.IsValid() {
+ return syserror.EIO
+ }
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.ptraceCode = int32(sig)
+ t.ptraceSyscallMode = mode
+ t.ptraceSinglestep = singlestep
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.endInternalStopLocked()
+ return nil
+}
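+
+// For reference, the resuming requests dispatched in Task.Ptrace below map
+// onto ptraceUnstop's parameters as follows:
+//
+//	PTRACE_CONT              -> (ptraceSyscallNone, singlestep=false)
+//	PTRACE_SYSCALL           -> (ptraceSyscallIntercept, singlestep=false)
+//	PTRACE_SINGLESTEP        -> (ptraceSyscallNone, singlestep=true)
+//	PTRACE_SYSEMU            -> (ptraceSyscallEmu, singlestep=false)
+//	PTRACE_SYSEMU_SINGLESTEP -> (ptraceSyscallEmu, singlestep=true)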
+
+func (t *Task) ptraceTraceme() error {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if t.hasTracer() {
+ return syserror.EPERM
+ }
+ if t.parent == nil {
+ // In Linux, only init can lack a parent, and init is assumed never to
+ // invoke PTRACE_TRACEME. In the sentry, TGID 1 is an arbitrary user
+ // application that may invoke PTRACE_TRACEME; having no parent can
+ // also occur if all tasks in the parent thread group have exited and t
+ // failed to find a living thread group to reparent to. The former case
+ // is treated as if TGID 1 has an exited parent in an invisible
+ // ancestor PID namespace that is an owner of the root user namespace
+ // (and consequently has CAP_SYS_PTRACE), and the latter case is a
+ // special form of the exited parent case below. In either case,
+ // returning nil here is correct.
+ return nil
+ }
+ if !t.parent.CanTrace(t, true) {
+ return syserror.EPERM
+ }
+ if t.parent.exitState != TaskExitNone {
+ // Fail silently, as if we were successfully attached but then
+ // immediately detached. This is consistent with Linux.
+ return nil
+ }
+ t.ptraceTracer.Store(t.parent)
+ t.parent.ptraceTracees[t] = struct{}{}
+ return nil
+}
+
+// ptraceAttach implements ptrace(PTRACE_ATTACH, target) if seize is false, and
+// ptrace(PTRACE_SEIZE, target, 0, opts) if seize is true. t is the caller.
+func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error {
+ if t.tg == target.tg {
+ return syserror.EPERM
+ }
+ if !t.CanTrace(target, true) {
+ return syserror.EPERM
+ }
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if target.hasTracer() {
+ return syserror.EPERM
+ }
+ // Attaching to zombies and dead tasks is not permitted; the exit
+ // notification logic relies on this. Linux allows attaching to PF_EXITING
+ // tasks, though.
+ if target.exitState >= TaskExitZombie {
+ return syserror.EPERM
+ }
+ if seize {
+ if err := target.ptraceSetOptionsLocked(opts); err != nil {
+ return syserror.EIO
+ }
+ }
+ target.ptraceTracer.Store(t)
+ t.ptraceTracees[target] = struct{}{}
+ target.ptraceSeized = seize
+ target.tg.signalHandlers.mu.Lock()
+ // "Unlike PTRACE_ATTACH, PTRACE_SEIZE does not stop the process." -
+ // ptrace(2)
+ if !seize {
+ target.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGSTOP),
+ Code: arch.SignalInfoUser,
+ }, false /* group */)
+ }
+ // Undocumented Linux feature: If the tracee is already group-stopped (and
+ // consequently will not report the SIGSTOP just sent), force it to leave
+ // and re-enter the stop so that it will switch to a ptrace-stop.
+ if target.stop == (*groupStop)(nil) {
+ target.trapStopPending = true
+ target.endInternalStopLocked()
+ // TODO(jamieliu): Linux blocks ptrace_attach() until the task has
+ // entered the ptrace-stop (or exited) via JOBCTL_TRAPPING.
+ }
+ target.tg.signalHandlers.mu.Unlock()
+ return nil
+}
+
+// ptraceDetach implements ptrace(PTRACE_DETACH, target, 0, sig). t is the
+// caller.
+//
+// Preconditions: target must be a tracee of t in a frozen ptrace stop.
+//
+// Postconditions: If ptraceDetach returns nil, target will no longer be in a
+// ptrace stop.
+func (t *Task) ptraceDetach(target *Task, sig linux.Signal) error {
+ if sig != 0 && !sig.IsValid() {
+ return syserror.EIO
+ }
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ target.ptraceCode = int32(sig)
+ target.forgetTracerLocked()
+ delete(t.ptraceTracees, target)
+ return nil
+}
+
+// exitPtrace is called in the exit path to detach all of t's tracees.
+func (t *Task) exitPtrace() {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ for target := range t.ptraceTracees {
+ if target.ptraceOpts.ExitKill {
+ target.tg.signalHandlers.mu.Lock()
+ target.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGKILL),
+ }, false /* group */)
+ target.tg.signalHandlers.mu.Unlock()
+ }
+ // Leave ptraceCode unchanged so that if the task is ptrace-stopped, it
+ // observes the ptraceCode it set before it entered the stop. I believe
+ // this is consistent with Linux.
+ target.forgetTracerLocked()
+ }
+ // "nil maps cannot be saved"
+ t.ptraceTracees = make(map[*Task]struct{})
+}
+
+// forgetTracerLocked detaches t's tracer and ensures that t is no longer
+// ptrace-stopped.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) forgetTracerLocked() {
+ t.ptraceSeized = false
+ t.ptraceOpts = ptraceOptions{}
+ t.ptraceSyscallMode = ptraceSyscallNone
+ t.ptraceSinglestep = false
+ t.ptraceTracer.Store((*Task)(nil))
+ if t.exitTracerNotified && !t.exitTracerAcked {
+ t.exitTracerAcked = true
+ t.exitNotifyLocked(true)
+ }
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ // Unset t.trapStopPending, which might have been set by PTRACE_INTERRUPT. If
+ // it wasn't, it will be reset via t.groupStopPending after the following.
+ t.trapStopPending = false
+ // If t's thread group is in a group stop and t is eligible to participate,
+ // make it do so. This is essentially the reverse of the special case in
+ // ptraceAttach, which converts a group stop to a ptrace stop. ("Handling
+ // of restart from group-stop is currently buggy, but the "as planned"
+ // behavior is to leave tracee stopped and waiting for SIGCONT." -
+ // ptrace(2))
+ if (t.tg.groupStopComplete || t.tg.groupStopPendingCount != 0) && !t.groupStopPending && t.exitState < TaskExitInitiated {
+ t.groupStopPending = true
+ // t already participated in the group stop when it unset
+ // groupStopPending.
+ t.groupStopAcknowledged = true
+ t.interrupt()
+ }
+ if _, ok := t.stop.(*ptraceStop); ok {
+ t.endInternalStopLocked()
+ }
+}
+
+// ptraceSignalLocked is called after signal dequeueing to check if t should
+// enter ptrace signal-delivery-stop.
+//
+// Preconditions: The signal mutex must be locked. The caller must be running
+// on the task goroutine.
+func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool {
+ if linux.Signal(info.Signo) == linux.SIGKILL {
+ return false
+ }
+ if !t.hasTracer() {
+ return false
+ }
+ // The tracer might change this signal into a stop signal, in which case
+ // any SIGCONT received after the signal was originally dequeued should
+ // cancel it. This is consistent with Linux.
+ t.tg.groupStopDequeued = true
+ // This is unconditional in ptrace_stop().
+ t.trapStopPending = false
+ // Can't lock the TaskSet mutex while holding a signal mutex.
+ t.tg.signalHandlers.mu.Unlock()
+ defer t.tg.signalHandlers.mu.Lock()
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ tracer := t.Tracer()
+ if tracer == nil {
+ return false
+ }
+ t.ptraceCode = info.Signo
+ t.ptraceSiginfo = info
+ t.Debugf("Entering signal-delivery-stop for signal %d", info.Signo)
+ if t.beginPtraceStopLocked() {
+ tracer.signalStop(t, arch.CLD_TRAPPED, info.Signo)
+ tracer.tg.eventQueue.Notify(EventTraceeStop)
+ }
+ return true
+}
+
+// ptraceSeccomp is called when a seccomp-bpf filter returns action
+// SECCOMP_RET_TRACE to check if t should enter PTRACE_EVENT_SECCOMP stop. data
+// is the lower 16 bits of the filter's return value.
+func (t *Task) ptraceSeccomp(data uint16) bool {
+ if !t.hasTracer() {
+ return false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !t.ptraceOpts.TraceSeccomp {
+ return false
+ }
+ t.Debugf("Entering PTRACE_EVENT_SECCOMP stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_SECCOMP, uint64(data))
+ return true
+}
+
+// ptraceSyscallEnter is called immediately before entering a syscall to check
+// if t should enter ptrace syscall-enter-stop.
+func (t *Task) ptraceSyscallEnter() (taskRunState, bool) {
+ if !t.hasTracer() {
+ return nil, false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ switch t.ptraceSyscallMode {
+ case ptraceSyscallNone:
+ return nil, false
+ case ptraceSyscallIntercept:
+ t.Debugf("Entering syscall-enter-stop from PTRACE_SYSCALL")
+ t.ptraceSyscallStopLocked()
+ return (*runSyscallAfterSyscallEnterStop)(nil), true
+ case ptraceSyscallEmu:
+ t.Debugf("Entering syscall-enter-stop from PTRACE_SYSEMU")
+ t.ptraceSyscallStopLocked()
+ return (*runSyscallAfterSysemuStop)(nil), true
+ }
+ panic(fmt.Sprintf("Unknown ptraceSyscallMode: %v", t.ptraceSyscallMode))
+}
+
+// ptraceSyscallExit is called immediately after leaving a syscall to check if
+// t should enter ptrace syscall-exit-stop.
+func (t *Task) ptraceSyscallExit() {
+ if !t.hasTracer() {
+ return
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if t.ptraceSyscallMode != ptraceSyscallIntercept {
+ return
+ }
+ t.Debugf("Entering syscall-exit-stop")
+ t.ptraceSyscallStopLocked()
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) ptraceSyscallStopLocked() {
+ code := int32(linux.SIGTRAP)
+ if t.ptraceOpts.SysGood {
+ code |= 0x80
+ }
+ t.ptraceTrapLocked(code)
+}
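+
+// A worked example: SIGTRAP is 5, so with PTRACE_O_TRACESYSGOOD set the
+// tracer observes WSTOPSIG(status) == (SIGTRAP | 0x80) == 0x85 at both
+// syscall-enter-stop and syscall-exit-stop, which lets it distinguish
+// syscall stops from ordinary SIGTRAP delivery.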
+
+type ptraceCloneKind int32
+
+const (
+ // ptraceCloneKindClone represents a call to Task.Clone where
+ // TerminationSignal is not SIGCHLD and Vfork is false.
+ ptraceCloneKindClone ptraceCloneKind = iota
+
+ // ptraceCloneKindFork represents a call to Task.Clone where
+ // TerminationSignal is SIGCHLD and Vfork is false.
+ ptraceCloneKindFork
+
+ // ptraceCloneKindVfork represents a call to Task.Clone where Vfork is
+ // true.
+ ptraceCloneKindVfork
+)
+
+// ptraceClone is called at the end of a clone or fork syscall to check if t
+// should enter PTRACE_EVENT_CLONE, PTRACE_EVENT_FORK, or PTRACE_EVENT_VFORK
+// stop. child is the new task.
+func (t *Task) ptraceClone(kind ptraceCloneKind, child *Task, opts *CloneOptions) bool {
+ if !t.hasTracer() {
+ return false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ event := false
+ if !opts.Untraced {
+ switch kind {
+ case ptraceCloneKindClone:
+ if t.ptraceOpts.TraceClone {
+ t.Debugf("Entering PTRACE_EVENT_CLONE stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_CLONE, uint64(t.tg.pidns.tids[child]))
+ event = true
+ }
+ case ptraceCloneKindFork:
+ if t.ptraceOpts.TraceFork {
+ t.Debugf("Entering PTRACE_EVENT_FORK stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_FORK, uint64(t.tg.pidns.tids[child]))
+ event = true
+ }
+ case ptraceCloneKindVfork:
+ if t.ptraceOpts.TraceVfork {
+ t.Debugf("Entering PTRACE_EVENT_VFORK stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_VFORK, uint64(t.tg.pidns.tids[child]))
+ event = true
+ }
+ default:
+ panic(fmt.Sprintf("Unknown ptraceCloneKind: %v", kind))
+ }
+ }
+ // "If the PTRACE_O_TRACEFORK, PTRACE_O_TRACEVFORK, or PTRACE_O_TRACECLONE
+ // options are in effect, then children created by, respectively, vfork(2)
+ // or clone(2) with the CLONE_VFORK flag, fork(2) or clone(2) with the exit
+ // signal set to SIGCHLD, and other kinds of clone(2), are automatically
+ // attached to the same tracer which traced their parent. SIGSTOP is
+ // delivered to the children, causing them to enter signal-delivery-stop
+ // after they exit the system call which created them." - ptrace(2)
+ //
+ // clone(2)'s documentation of CLONE_UNTRACED and CLONE_PTRACE is
+ // confusingly wrong; see kernel/fork.c:_do_fork() => copy_process() =>
+ // include/linux/ptrace.h:ptrace_init_task().
+ if event || opts.InheritTracer {
+ tracer := t.Tracer()
+ if tracer != nil {
+ child.ptraceTracer.Store(tracer)
+ tracer.ptraceTracees[child] = struct{}{}
+ // "The "seized" behavior ... is inherited by children that are
+ // automatically attached using PTRACE_O_TRACEFORK,
+ // PTRACE_O_TRACEVFORK, and PTRACE_O_TRACECLONE." - ptrace(2)
+ child.ptraceSeized = t.ptraceSeized
+ // "Flags are inherited by new tracees created and "auto-attached"
+ // via active PTRACE_O_TRACEFORK, PTRACE_O_TRACEVFORK, or
+ // PTRACE_O_TRACECLONE options." - ptrace(2)
+ child.ptraceOpts = t.ptraceOpts
+ child.tg.signalHandlers.mu.Lock()
+ // "PTRACE_SEIZE: ... Automatically attached children stop with
+ // PTRACE_EVENT_STOP and WSTOPSIG(status) returns SIGTRAP instead
+ // of having SIGSTOP signal delivered to them." - ptrace(2)
+ if child.ptraceSeized {
+ child.trapStopPending = true
+ } else {
+ child.pendingSignals.enqueue(&arch.SignalInfo{
+ Signo: int32(linux.SIGSTOP),
+ }, nil)
+ }
+ // The child will self-interrupt() when its task goroutine starts
+ // running, so we don't have to.
+ child.tg.signalHandlers.mu.Unlock()
+ }
+ }
+ return event
+}
+
+// ptraceVforkDone is called after the end of a vfork stop to check if t should
+// enter PTRACE_EVENT_VFORK_DONE stop. child is the new task's thread ID in t's
+// PID namespace.
+func (t *Task) ptraceVforkDone(child ThreadID) bool {
+ if !t.hasTracer() {
+ return false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !t.ptraceOpts.TraceVforkDone {
+ return false
+ }
+ t.Debugf("Entering PTRACE_EVENT_VFORK_DONE stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_VFORK_DONE, uint64(child))
+ return true
+}
+
+// ptraceExec is called at the end of an execve syscall to check if t should
+// enter PTRACE_EVENT_EXEC stop. oldTID is t's thread ID, in its *tracer's* PID
+// namespace, prior to the execve. (If t did not have a tracer at the time
+// oldTID was read, oldTID may be 0. This is consistent with Linux.)
+func (t *Task) ptraceExec(oldTID ThreadID) {
+ if !t.hasTracer() {
+ return
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ // Recheck with the TaskSet mutex locked. Most ptrace points don't need to
+ // do this because detaching resets ptrace options, but PTRACE_EVENT_EXEC
+ // is special because both TraceExec and !TraceExec do something if a
+ // tracer is attached.
+ if !t.hasTracer() {
+ return
+ }
+ if t.ptraceOpts.TraceExec {
+ t.Debugf("Entering PTRACE_EVENT_EXEC stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_EXEC, uint64(oldTID))
+ return
+ }
+ // "If the PTRACE_O_TRACEEXEC option is not in effect for the execing
+ // tracee, and if the tracee was PTRACE_ATTACHed rather that [sic]
+ // PTRACE_SEIZEd, the kernel delivers an extra SIGTRAP to the tracee after
+ // execve(2) returns. This is an ordinary signal (similar to one which can
+ // be generated by `kill -TRAP`), not a special kind of ptrace-stop.
+ // Employing PTRACE_GETSIGINFO for this signal returns si_code set to 0
+ // (SI_USER). This signal may be blocked by signal mask, and thus may be
+ // delivered (much) later." - ptrace(2)
+ if t.ptraceSeized {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGTRAP),
+ Code: arch.SignalInfoUser,
+ }, false /* group */)
+}
+
+// ptraceExit is called early in the task exit path to check if t should enter
+// PTRACE_EVENT_EXIT stop.
+func (t *Task) ptraceExit() {
+ if !t.hasTracer() {
+ return
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !t.ptraceOpts.TraceExit {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ status := t.exitStatus.Status()
+ t.tg.signalHandlers.mu.Unlock()
+ t.Debugf("Entering PTRACE_EVENT_EXIT stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_EXIT, uint64(status))
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) ptraceEventLocked(event int32, msg uint64) {
+ t.ptraceEventMsg = msg
+ // """
+ // PTRACE_EVENT stops are observed by the tracer as waitpid(2) returning
+ // with WIFSTOPPED(status), and WSTOPSIG(status) returns SIGTRAP. An
+ // additional bit is set in the higher byte of the status word: the value
+ // status>>8 will be
+ //
+ // (SIGTRAP | PTRACE_EVENT_foo << 8).
+ //
+ // ...
+ //
+ // """ - ptrace(2)
+ t.ptraceTrapLocked(int32(linux.SIGTRAP) | (event << 8))
+}
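+
+// A worked example: for PTRACE_EVENT_EXEC (4), the code passed to
+// ptraceTrapLocked is SIGTRAP | (4 << 8) = 0x405, so the tracer's
+// waitpid(2) status satisfies status>>8 == (SIGTRAP | PTRACE_EVENT_EXEC << 8).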
+
+// ptraceKill implements ptrace(PTRACE_KILL, target). t is the caller.
+func (t *Task) ptraceKill(target *Task) error {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if target.Tracer() != t {
+ return syserror.ESRCH
+ }
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ // "This operation is deprecated; do not use it! Instead, send a SIGKILL
+ // directly using kill(2) or tgkill(2). The problem with PTRACE_KILL is
+ // that it requires the tracee to be in signal-delivery-stop, otherwise it
+ // may not work (i.e., may complete successfully but won't kill the
+ // tracee)." - ptrace(2)
+ if target.stop == nil {
+ return nil
+ }
+ if _, ok := target.stop.(*ptraceStop); !ok {
+ return nil
+ }
+ target.ptraceCode = int32(linux.SIGKILL)
+ target.endInternalStopLocked()
+ return nil
+}
+
+func (t *Task) ptraceInterrupt(target *Task) error {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if target.Tracer() != t {
+ return syserror.ESRCH
+ }
+ if !target.ptraceSeized {
+ return syserror.EIO
+ }
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if target.killedLocked() || target.exitState >= TaskExitInitiated {
+ return nil
+ }
+ target.trapStopPending = true
+ if s, ok := target.stop.(*ptraceStop); ok && s.listen {
+ target.endInternalStopLocked()
+ }
+ target.interrupt()
+ return nil
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing. t must have a
+// tracer.
+func (t *Task) ptraceSetOptionsLocked(opts uintptr) error {
+ const valid = uintptr(linux.PTRACE_O_EXITKILL |
+ linux.PTRACE_O_TRACESYSGOOD |
+ linux.PTRACE_O_TRACECLONE |
+ linux.PTRACE_O_TRACEEXEC |
+ linux.PTRACE_O_TRACEEXIT |
+ linux.PTRACE_O_TRACEFORK |
+ linux.PTRACE_O_TRACESECCOMP |
+ linux.PTRACE_O_TRACEVFORK |
+ linux.PTRACE_O_TRACEVFORKDONE)
+ if opts&^valid != 0 {
+ return syserror.EINVAL
+ }
+ t.ptraceOpts = ptraceOptions{
+ ExitKill: opts&linux.PTRACE_O_EXITKILL != 0,
+ SysGood: opts&linux.PTRACE_O_TRACESYSGOOD != 0,
+ TraceClone: opts&linux.PTRACE_O_TRACECLONE != 0,
+ TraceExec: opts&linux.PTRACE_O_TRACEEXEC != 0,
+ TraceExit: opts&linux.PTRACE_O_TRACEEXIT != 0,
+ TraceFork: opts&linux.PTRACE_O_TRACEFORK != 0,
+ TraceSeccomp: opts&linux.PTRACE_O_TRACESECCOMP != 0,
+ TraceVfork: opts&linux.PTRACE_O_TRACEVFORK != 0,
+ TraceVforkDone: opts&linux.PTRACE_O_TRACEVFORKDONE != 0,
+ }
+ return nil
+}
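+
+// A worked example: opts == 0x11 (PTRACE_O_TRACESYSGOOD |
+// PTRACE_O_TRACEEXEC) passes the validity mask and yields
+// ptraceOptions{SysGood: true, TraceExec: true}, while opts == 0x200000
+// (PTRACE_O_SUSPEND_SECCOMP, which is not in the mask) fails with EINVAL.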
+
+// Ptrace implements the ptrace system call.
+func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
+ // PTRACE_TRACEME ignores all other arguments.
+ if req == linux.PTRACE_TRACEME {
+ return t.ptraceTraceme()
+ }
+ // All other ptrace requests operate on a current or future tracee
+ // specified by pid.
+ target := t.tg.pidns.TaskWithID(pid)
+ if target == nil {
+ return syserror.ESRCH
+ }
+
+ // PTRACE_ATTACH and PTRACE_SEIZE do not require that target is not already
+ // a tracee.
+ if req == linux.PTRACE_ATTACH || req == linux.PTRACE_SEIZE {
+ seize := req == linux.PTRACE_SEIZE
+ if seize && addr != 0 {
+ return syserror.EIO
+ }
+ return t.ptraceAttach(target, seize, uintptr(data))
+ }
+ // PTRACE_KILL and PTRACE_INTERRUPT require that the target is a tracee,
+ // but do not require that it is ptrace-stopped.
+ if req == linux.PTRACE_KILL {
+ return t.ptraceKill(target)
+ }
+ if req == linux.PTRACE_INTERRUPT {
+ return t.ptraceInterrupt(target)
+ }
+ // All other ptrace requests require that the target is a ptrace-stopped
+ // tracee, and freeze the ptrace-stop so the tracee can be operated on.
+ t.tg.pidns.owner.mu.RLock()
+ if target.Tracer() != t {
+ t.tg.pidns.owner.mu.RUnlock()
+ return syserror.ESRCH
+ }
+ if !target.ptraceFreeze() {
+ t.tg.pidns.owner.mu.RUnlock()
+ // "Most ptrace commands (all except PTRACE_ATTACH, PTRACE_SEIZE,
+ // PTRACE_TRACEME, PTRACE_INTERRUPT, and PTRACE_KILL) require the
+ // tracee to be in a ptrace-stop, otherwise they fail with ESRCH." -
+ // ptrace(2)
+ return syserror.ESRCH
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+ // Even if the target has a ptrace-stop active, the tracee's task goroutine
+ // may not yet have reached Task.doStop; wait for it to do so. This is safe
+ // because there's no way for target to initiate a ptrace-stop and then
+ // block (by calling Task.block) before entering it.
+ //
+ // Caveat: If tasks were just restored, the tracee's first call to
+ // Task.Activate (in Task.run) occurs before its first call to Task.doStop,
+ // which may block if the tracer's address space is active.
+ t.UninterruptibleSleepStart(true)
+ target.waitGoroutineStoppedOrExited()
+ t.UninterruptibleSleepFinish(true)
+
+ // Resuming commands end the ptrace stop, but only if successful.
+ // PTRACE_LISTEN ends the ptrace stop if trapNotifyPending is already
+ // set on the target.
+ switch req {
+ case linux.PTRACE_DETACH:
+ if err := t.ptraceDetach(target, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_CONT:
+ if err := target.ptraceUnstop(ptraceSyscallNone, false, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SYSCALL:
+ if err := target.ptraceUnstop(ptraceSyscallIntercept, false, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SINGLESTEP:
+ if err := target.ptraceUnstop(ptraceSyscallNone, true, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SYSEMU:
+ if err := target.ptraceUnstop(ptraceSyscallEmu, false, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SYSEMU_SINGLESTEP:
+ if err := target.ptraceUnstop(ptraceSyscallEmu, true, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_LISTEN:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !target.ptraceSeized {
+ return syserror.EIO
+ }
+ if target.ptraceSiginfo == nil {
+ return syserror.EIO
+ }
+ if target.ptraceSiginfo.Code>>8 != linux.PTRACE_EVENT_STOP {
+ return syserror.EIO
+ }
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if target.trapNotifyPending {
+ target.endInternalStopLocked()
+ } else {
+ target.stop.(*ptraceStop).listen = true
+ target.ptraceUnfreezeLocked()
+ }
+ return nil
+ }
+
+ // All other ptrace requests expect us to unfreeze the stop.
+ defer target.ptraceUnfreeze()
+
+ switch req {
+ case linux.PTRACE_PEEKTEXT, linux.PTRACE_PEEKDATA:
+ // "At the system call level, the PTRACE_PEEKTEXT, PTRACE_PEEKDATA, and
+ // PTRACE_PEEKUSER requests have a different API: they store the result
+ // at the address specified by the data parameter, and the return value
+ // is the error flag." - ptrace(2)
+ word := t.Arch().Native(0)
+ if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{
+ IgnorePermissions: true,
+ }); err != nil {
+ return err
+ }
+ _, err := t.CopyOut(data, word)
+ return err
+
+ case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
+ _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{
+ IgnorePermissions: true,
+ })
+ return err
+
+ case linux.PTRACE_GETREGSET:
+ // "Read the tracee's registers. addr specifies, in an
+ // architecture-dependent way, the type of registers to be read. ...
+ // data points to a struct iovec, which describes the destination
+ // buffer's location and length. On return, the kernel modifies iov.len
+ // to indicate the actual number of bytes returned." - ptrace(2)
+ ars, err := t.CopyInIovecs(data, 1)
+ if err != nil {
+ return err
+ }
+ ar := ars.Head()
+ n, err := target.Arch().PtraceGetRegSet(uintptr(addr), &usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: ar.Start,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }, int(ar.Length()))
+ if err != nil {
+ return err
+ }
+
+ // Update iovecs to represent the range of the written register set.
+ end, ok := ar.Start.AddLength(uint64(n))
+ if !ok {
+ panic(fmt.Sprintf("%#x + %#x overflows. Invalid reg size > %#x", ar.Start, n, ar.Length()))
+ }
+ ar.End = end
+ return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
+
+ case linux.PTRACE_SETREGSET:
+ ars, err := t.CopyInIovecs(data, 1)
+ if err != nil {
+ return err
+ }
+ ar := ars.Head()
+ n, err := target.Arch().PtraceSetRegSet(uintptr(addr), &usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: ar.Start,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }, int(ar.Length()))
+ if err != nil {
+ return err
+ }
+ ar.End -= usermem.Addr(n)
+ return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
+
+ case linux.PTRACE_GETSIGINFO:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if target.ptraceSiginfo == nil {
+ return syserror.EINVAL
+ }
+ _, err := t.CopyOut(data, target.ptraceSiginfo)
+ return err
+
+ case linux.PTRACE_SETSIGINFO:
+ var info arch.SignalInfo
+ if _, err := t.CopyIn(data, &info); err != nil {
+ return err
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if target.ptraceSiginfo == nil {
+ return syserror.EINVAL
+ }
+ target.ptraceSiginfo = &info
+ return nil
+
+ case linux.PTRACE_GETSIGMASK:
+ if addr != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ _, err := t.CopyOut(data, target.SignalMask())
+ return err
+
+ case linux.PTRACE_SETSIGMASK:
+ if addr != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if _, err := t.CopyIn(data, &mask); err != nil {
+ return err
+ }
+ // The target's task goroutine is stopped, so this is safe:
+ target.SetSignalMask(mask &^ UnblockableSignals)
+ return nil
+
+ case linux.PTRACE_SETOPTIONS:
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ return target.ptraceSetOptionsLocked(uintptr(data))
+
+ case linux.PTRACE_GETEVENTMSG:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg)
+ return err
+
+ // PEEKSIGINFO is unimplemented but seems to have no users anywhere.
+
+ default:
+ return t.ptraceArch(target, req, addr, data)
+ }
+}
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
new file mode 100644
index 000000000..cef1276ec
--- /dev/null
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ptraceArch implements arch-specific ptrace commands.
+func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error {
+ switch req {
+ case linux.PTRACE_PEEKUSR: // aka PTRACE_PEEKUSER
+ n, err := target.Arch().PtracePeekUser(uintptr(addr))
+ if err != nil {
+ return err
+ }
+ _, err = t.CopyOut(data, n)
+ return err
+
+ case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER
+ return target.Arch().PtracePokeUser(uintptr(addr), uintptr(data))
+
+ case linux.PTRACE_GETREGS:
+ // "Copy the tracee's general-purpose ... registers ... to the address
+ // data in the tracer. ... (addr is ignored.) Note that SPARC systems
+ // have the meaning of data and addr reversed ..."
+ _, err := target.Arch().PtraceGetRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ case linux.PTRACE_GETFPREGS:
+ _, err := target.Arch().PtraceGetFPRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ case linux.PTRACE_SETREGS:
+ _, err := target.Arch().PtraceSetRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ case linux.PTRACE_SETFPREGS:
+ _, err := target.Arch().PtraceSetFPRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ default:
+ return syserror.EIO
+ }
+}
diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go
new file mode 100644
index 000000000..d971b96b3
--- /dev/null
+++ b/pkg/sentry/kernel/ptrace_arm64.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ptraceArch implements arch-specific ptrace commands.
+func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error {
+ return syserror.EIO
+}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
new file mode 100644
index 000000000..18416643b
--- /dev/null
+++ b/pkg/sentry/kernel/rseq.go
@@ -0,0 +1,393 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Restartable sequences.
+//
+// We support two different APIs for restartable sequences.
+//
+// 1. The upstream interface added in v4.18.
+// 2. The interface described in https://lwn.net/Articles/650333/.
+//
+// Throughout this file and other parts of the kernel, the latter is referred
+// to as "old rseq". This interface was never merged upstream, but is supported
+// for a limited set of applications that use it regardless.
+
+// OldRSeqCriticalRegion describes an old rseq critical region.
+//
+// +stateify savable
+type OldRSeqCriticalRegion struct {
+ // When a task in this thread group has its CPU preempted (as defined by
+ // platform.ErrContextCPUPreempted) or has a signal delivered to an
+ // application handler while its instruction pointer is in CriticalSection,
+ // set the instruction pointer to Restart and application register r10 (on
+ // amd64) to the former instruction pointer.
+ CriticalSection usermem.AddrRange
+ Restart usermem.Addr
+}
+
+// RSeqAvailable returns true if t supports (old and new) restartable sequences.
+func (t *Task) RSeqAvailable() bool {
+ return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption()
+}
+
+// SetRSeq registers addr as this thread's rseq structure.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr != 0 {
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EINVAL
+ }
+ return syserror.EBUSY
+ }
+
+ // rseq must be aligned and correctly sized.
+ if addr&(linux.AlignOfRSeq-1) != 0 {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok {
+ return syserror.EFAULT
+ }
+
+ t.rseqAddr = addr
+ t.rseqSignature = signature
+
+ // Initialize the CPUID.
+ //
+ // Linux implicitly does this on return from userspace, where failure
+ // would cause SIGSEGV.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return syserror.EFAULT
+ }
+
+ return nil
+}
+
+// ClearRSeq unregisters addr as this thread's rseq structure.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr == 0 {
+ return syserror.EINVAL
+ }
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EPERM
+ }
+
+ if err := t.rseqClearCPU(); err != nil {
+ return err
+ }
+
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ if t.oldRSeqCPUAddr == 0 {
+ // rseqCPU no longer needed.
+ t.rseqCPU = -1
+ }
+
+ return nil
+}
+
+// OldRSeqCriticalRegion returns a copy of t's thread group's current
+// old restartable sequence.
+func (t *Task) OldRSeqCriticalRegion() OldRSeqCriticalRegion {
+ return *t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+}
+
+// SetOldRSeqCriticalRegion replaces t's thread group's old restartable
+// sequence.
+//
+// Preconditions: t.RSeqAvailable() == true.
+func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
+ // These checks are somewhat more lenient than in Linux, which (bizarrely)
+ // requires r.CriticalSection to be non-empty and r.Restart to be
+ // outside of r.CriticalSection, even if r.CriticalSection.Start == 0
+ // (which disables the critical region).
+ if r.CriticalSection.Start == 0 {
+ r.CriticalSection.End = 0
+ r.Restart = 0
+ t.tg.oldRSeqCritical.Store(&r)
+ return nil
+ }
+ if r.CriticalSection.Start >= r.CriticalSection.End {
+ return syserror.EINVAL
+ }
+ if r.CriticalSection.Contains(r.Restart) {
+ return syserror.EINVAL
+ }
+ // TODO(jamieliu): check that r.CriticalSection and r.Restart are in
+ // the application address range, for consistency with Linux.
+ t.tg.oldRSeqCritical.Store(&r)
+ return nil
+}
+
+// OldRSeqCPUAddr returns the address that old rseq will keep updated with t's
+// CPU number.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) OldRSeqCPUAddr() usermem.Addr {
+ return t.oldRSeqCPUAddr
+}
+
+// SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with
+// t's CPU number.
+//
+// Preconditions: t.RSeqAvailable() == true. The caller must be running on the
+// task goroutine. t's AddressSpace must be active.
+func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error {
+ t.oldRSeqCPUAddr = addr
+
+ // Check that addr is writable.
+ //
+ // N.B. rseqUpdateCPU may fail on a bad t.rseqAddr as well. That's
+ // unfortunate, but unlikely in a correct program.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.oldRSeqCPUAddr = 0
+ return syserror.EINVAL // yes, EINVAL, not err or EFAULT
+ }
+ return nil
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqUpdateCPU() error {
+ if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 {
+ t.rseqCPU = -1
+ return nil
+ }
+
+ t.rseqCPU = int32(hostcpu.GetCPU())
+
+ // Update both CPUs, even if one fails.
+ rerr := t.rseqCopyOutCPU()
+ oerr := t.oldRSeqCopyOutCPU()
+
+ if rerr != nil {
+ return rerr
+ }
+ return oerr
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) oldRSeqCopyOutCPU() error {
+ if t.oldRSeqCPUAddr == 0 {
+ return nil
+ }
+
+ buf := t.CopyScratchBuffer(4)
+ usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
+ _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqCopyOutCPU() error {
+ if t.rseqAddr == 0 {
+ return nil
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID
+ // N.B. This write is not atomic, but since it occurs on the task
+ // goroutine, userspace can't observe an invalid value as long as it
+ // uses a single-instruction read.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqClearCPU() error {
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
+ // N.B. This write is not atomic, but since it occurs on the task
+ // goroutine, userspace can't observe an invalid value as long as it
+ // uses a single-instruction read.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
+ return err
+}
+
+// rseqAddrInterrupt checks if IP is in a critical section, and aborts if so.
+//
+// This is a bit complex since both the RSeq and RSeqCriticalSection structs
+// are stored in userspace. So we must:
+//
+// 1. Copy in the address of RSeqCriticalSection from RSeq.
+// 2. Copy in RSeqCriticalSection itself.
+// 3. Validate critical section struct version, address range, abort address.
+// 4. Validate the abort signature (4 bytes preceding abort IP match expected
+// signature).
+// 5. Clear address of RSeqCriticalSection from RSeq.
+// 6. Finally, conditionally abort.
+//
+// See kernel/rseq.c:rseq_ip_fixup for reference.
+//
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqAddrInterrupt() {
+ if t.rseqAddr == 0 {
+ return
+ }
+
+ critAddrAddr, ok := t.rseqAddr.AddLength(linux.OffsetOfRSeqCriticalSection)
+ if !ok {
+ // SetRSeq should validate this.
+ panic(fmt.Sprintf("t.rseqAddr (%#x) not large enough", t.rseqAddr))
+ }
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ t.Debugf("Only 64-bit rseq supported.")
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ if _, err := t.CopyInBytes(critAddrAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf))
+ if critAddr == 0 {
+ return
+ }
+
+ var cs linux.RSeqCriticalSection
+ if _, err := cs.CopyIn(t, critAddr); err != nil {
+ t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ if cs.Version != 0 {
+ t.Debugf("Unknown version in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ start := usermem.Addr(cs.Start)
+ critRange, ok := start.ToRange(cs.PostCommitOffset)
+ if !ok {
+ t.Debugf("Invalid start and offset in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ abort := usermem.Addr(cs.Abort)
+ if critRange.Contains(abort) {
+ t.Debugf("Abort in critical section in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Verify signature.
+ sigAddr := abort - linux.SizeOfRSeqSignature
+
+ buf = t.CopyScratchBuffer(linux.SizeOfRSeqSignature)
+ if _, err := t.CopyInBytes(sigAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section signature from %#x for rseq: %v", sigAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ sig := usermem.ByteOrder.Uint32(buf)
+ if sig != t.rseqSignature {
+ t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Clear the critical section address.
+ //
+ // NOTE(b/143949567): We don't support any rseq flags, so we always
+ // restart if we are in the critical section, and thus *always* clear
+ // critAddrAddr.
+ if _, err := t.MemoryManager().ZeroOut(t, critAddrAddr, int64(t.Arch().Width()), usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ t.Debugf("Failed to clear critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Finally we can actually decide whether or not to restart.
+ if !critRange.Contains(usermem.Addr(t.Arch().IP())) {
+ return
+ }
+
+ t.Arch().SetIP(uintptr(cs.Abort))
+}
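+
+// A worked example of the signature check above: if SetRSeq registered
+// signature 0x53053053 (the value used by Linux's rseq selftests) and
+// cs.Abort is A, then the four bytes at A-4 must decode to 0x53053053;
+// otherwise the task receives SIGSEGV instead of being redirected to the
+// abort handler.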
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) oldRSeqInterrupt() {
+ r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+ if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) {
+ t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart)
+ t.Arch().SetIP(uintptr(r.Restart))
+ t.Arch().SetOldRSeqInterruptedIP(ip)
+ }
+}
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) rseqInterrupt() {
+ t.rseqAddrInterrupt()
+ t.oldRSeqInterrupt()
+}
diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD
new file mode 100644
index 000000000..1b82e087b
--- /dev/null
+++ b/pkg/sentry/kernel/sched/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sched",
+ srcs = [
+ "cpuset.go",
+ "sched.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+)
+
+go_test(
+ name = "sched_test",
+ size = "small",
+ srcs = ["cpuset_test.go"],
+ library = ":sched",
+)
diff --git a/pkg/sentry/kernel/sched/cpuset.go b/pkg/sentry/kernel/sched/cpuset.go
new file mode 100644
index 000000000..c6c436690
--- /dev/null
+++ b/pkg/sentry/kernel/sched/cpuset.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sched
+
+import "math/bits"
+
+const (
+ bitsPerByte = 8
+ bytesPerLong = 8 // only for 64-bit architectures
+)
+
+// CPUSet contains a bitmap to record CPU information.
+//
+// Note that this definition is only correct for little-endian architectures,
+// since Linux's cpumask_t uses unsigned long.
+type CPUSet []byte
+
+// CPUSetSize returns the size in bytes of a CPUSet that can contain num cpus.
+func CPUSetSize(num uint) uint {
+ // NOTE(b/68859821): Applications may expect that the size of a CPUSet in
+ // bytes is always a multiple of sizeof(unsigned long), since this is true
+ // in Linux. Thus we always round up.
+ bytes := (num + bitsPerByte - 1) / bitsPerByte
+ longs := (bytes + bytesPerLong - 1) / bytesPerLong
+ return longs * bytesPerLong
+}
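+
+// Worked examples: CPUSetSize(1) == 8 and CPUSetSize(64) == 8 (both fit in a
+// single unsigned long), while CPUSetSize(65) == 16.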
+
+// NewCPUSet returns a CPUSet for the given number of CPUs which initially
+// contains no CPUs.
+func NewCPUSet(num uint) CPUSet {
+ return CPUSet(make([]byte, CPUSetSize(num)))
+}
+
+// NewFullCPUSet returns a CPUSet for the given number of CPUs, all of which
+// are present in the set.
+func NewFullCPUSet(num uint) CPUSet {
+ c := NewCPUSet(num)
+ var i uint
+ for ; i < num/bitsPerByte; i++ {
+ c[i] = 0xff
+ }
+ if rem := num % bitsPerByte; rem != 0 {
+ c[i] = (1 << rem) - 1
+ }
+ return c
+}
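+
+// A worked example: NewFullCPUSet(10) returns an 8-byte set whose first two
+// bytes are 0xff and 0x03, i.e. CPUs 0-9 are present.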
+
+// Size returns the size of 'c' in bytes.
+func (c CPUSet) Size() uint {
+ return uint(len(c))
+}
+
+// NumCPUs returns how many cpus are set in the CPUSet.
+func (c CPUSet) NumCPUs() uint {
+ var n int
+ for _, b := range c {
+ n += bits.OnesCount8(b)
+ }
+ return uint(n)
+}
+
+// Copy returns a copy of the CPUSet.
+func (c CPUSet) Copy() CPUSet {
+ return append(CPUSet(nil), c...)
+}
+
+// Set sets the bit corresponding to cpu.
+func (c *CPUSet) Set(cpu uint) {
+ (*c)[cpu/bitsPerByte] |= 1 << (cpu % bitsPerByte)
+}
+
+// ClearAbove clears bits corresponding to cpu and all higher cpus.
+func (c *CPUSet) ClearAbove(cpu uint) {
+ i := cpu / bitsPerByte
+ if i >= c.Size() {
+ return
+ }
+ (*c)[i] &^= 0xff << (cpu % bitsPerByte)
+ for i++; i < c.Size(); i++ {
+ (*c)[i] = 0
+ }
+}
+
+// ForEachCPU iterates over the CPUSet and calls fn with the index of each
+// set CPU.
+func (c CPUSet) ForEachCPU(fn func(uint)) {
+ for i := uint(0); i < c.Size()*bitsPerByte; i++ {
+ bit := uint(1) << (i & (bitsPerByte - 1))
+ if uint(c[i/bitsPerByte])&bit == bit {
+ fn(i)
+ }
+ }
+}
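+
+// A minimal usage sketch: iterating a full set visits each set CPU exactly
+// once, so the count below matches NumCPUs:
+//
+//	c := NewFullCPUSet(4)
+//	var n uint
+//	c.ForEachCPU(func(cpu uint) { n++ })
+//	// n == 4 == c.NumCPUs()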
diff --git a/pkg/sentry/kernel/sched/cpuset_test.go b/pkg/sentry/kernel/sched/cpuset_test.go
new file mode 100644
index 000000000..3af9f1197
--- /dev/null
+++ b/pkg/sentry/kernel/sched/cpuset_test.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sched
+
+import (
+ "testing"
+)
+
+func TestNumCPUs(t *testing.T) {
+ for i := uint(0); i < 1024; i++ {
+ c := NewCPUSet(i)
+ for j := uint(0); j < i; j++ {
+ c.Set(j)
+ }
+ n := c.NumCPUs()
+ if n != i {
+ t.Errorf("got wrong number of cpus %d, want %d", n, i)
+ }
+ }
+}
+
+func TestClearAbove(t *testing.T) {
+ const n = 1024
+ c := NewFullCPUSet(n)
+ for i := uint(0); i < n; i++ {
+ cpu := n - i
+ c.ClearAbove(cpu)
+ if got := c.NumCPUs(); got != cpu {
+ t.Errorf("iteration %d: got %d cpus, wanted %d", i, got, cpu)
+ }
+ }
+}
diff --git a/pkg/sentry/kernel/sched/sched.go b/pkg/sentry/kernel/sched/sched.go
new file mode 100644
index 000000000..de18c9d02
--- /dev/null
+++ b/pkg/sentry/kernel/sched/sched.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sched implements scheduler-related features.
+package sched
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
new file mode 100644
index 000000000..c38c5a40c
--- /dev/null
+++ b/pkg/sentry/kernel/seccomp.go
@@ -0,0 +1,217 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const maxSyscallFilterInstructions = 1 << 15
+
+// seccompData is equivalent to struct seccomp_data, which contains the data
+// passed to seccomp-bpf filters.
+type seccompData struct {
+ // nr is the system call number.
+ nr int32
+
+ // arch is an AUDIT_ARCH_* value indicating the system call convention.
+ arch uint32
+
+ // instructionPointer is the value of the instruction pointer at the time
+ // of the system call.
+ instructionPointer uint64
+
+ // args contains the first 6 system call arguments.
+ args [6]uint64
+}
+
+func (d *seccompData) asBPFInput() bpf.Input {
+ return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder}
+}
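+
+// The field order above mirrors Linux's struct seccomp_data, so the
+// marshaled buffer exposes nr at offset 0, arch at offset 4,
+// instructionPointer at offset 8, and args[0..5] at offsets 16 through 56
+// (64 bytes in all), the offsets that seccomp-bpf filters read with
+// BPF_LD|BPF_ABS loads.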
+
+func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo {
+ si := &arch.SignalInfo{
+ Signo: int32(linux.SIGSYS),
+ Errno: errno,
+ Code: arch.SYS_SECCOMP,
+ }
+ si.SetCallAddr(uint64(ip))
+ si.SetSyscall(sysno)
+ si.SetArch(t.SyscallTable().AuditNumber)
+ return si
+}
+
+// checkSeccompSyscall applies the task's seccomp filters before the execution
+// of syscall sysno at instruction pointer ip. (These parameters must be passed
+// in because vsyscalls do not use the values in t.Arch().)
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip usermem.Addr) linux.BPFAction {
+ result := linux.BPFAction(t.evaluateSyscallFilters(sysno, args, ip))
+ action := result & linux.SECCOMP_RET_ACTION
+ switch action {
+ case linux.SECCOMP_RET_TRAP:
+ // "Results in the kernel sending a SIGSYS signal to the triggering
+ // task without executing the system call. ... The SECCOMP_RET_DATA
+ // portion of the return value will be passed as si_errno." -
+ // Documentation/prctl/seccomp_filter.txt
+ t.SendSignal(seccompSiginfo(t, int32(result.Data()), sysno, ip))
+ // "The return value register will contain an arch-dependent value." In
+ // practice, it's ~always the syscall number.
+ t.Arch().SetReturn(uintptr(sysno))
+
+ case linux.SECCOMP_RET_ERRNO:
+ // "Results in the lower 16-bits of the return value being passed to
+ // userland as the errno without executing the system call."
+ t.Arch().SetReturn(-uintptr(result.Data()))
+
+ case linux.SECCOMP_RET_TRACE:
+ // "When returned, this value will cause the kernel to attempt to
+ // notify a ptrace()-based tracer prior to executing the system call.
+ // If there is no tracer present, -ENOSYS is returned to userland and
+ // the system call is not executed."
+ if !t.ptraceSeccomp(result.Data()) {
+ // This useless-looking temporary is needed because Go rejects
+ // -uintptr(syscall.ENOSYS) as a constant expression that overflows
+ // uintptr.
+ tmp := uintptr(syscall.ENOSYS)
+ t.Arch().SetReturn(-tmp)
+ return linux.SECCOMP_RET_ERRNO
+ }
+
+ case linux.SECCOMP_RET_ALLOW:
+ // "Results in the system call being executed."
+
+ case linux.SECCOMP_RET_KILL_THREAD:
+ // "Results in the task exiting immediately without executing the
+ // system call. The exit status of the task will be SIGSYS, not
+ // SIGKILL."
+
+ default:
+ // Consistent with Linux.
+ return linux.SECCOMP_RET_KILL_THREAD
+ }
+ return action
+}
+
+func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 {
+ data := seccompData{
+ nr: sysno,
+ arch: t.tc.st.AuditNumber,
+ instructionPointer: uint64(ip),
+ }
+ // data.args is [6]uint64 and args is []arch.SyscallArgument (uintptr), so
+ // we can't do any slicing tricks or even use copy/append here.
+ for i, arg := range args {
+ if i >= len(data.args) {
+ break
+ }
+ data.args[i] = arg.Uint64()
+ }
+ input := data.asBPFInput()
+
+ ret := uint32(linux.SECCOMP_RET_ALLOW)
+ f := t.syscallFilters.Load()
+ if f == nil {
+ return ret
+ }
+
+ // "Every filter successfully installed will be evaluated (in reverse
+ // order) for each system call the task makes." - kernel/seccomp.c
+ for i := len(f.([]bpf.Program)) - 1; i >= 0; i-- {
+ thisRet, err := bpf.Exec(f.([]bpf.Program)[i], input)
+ if err != nil {
+ t.Debugf("seccomp-bpf filter %d returned error: %v", i, err)
+ thisRet = uint32(linux.SECCOMP_RET_KILL_THREAD)
+ }
+ // "If multiple filters exist, the return value for the evaluation of a
+ // given system call will always use the highest precedent value." -
+ // Documentation/prctl/seccomp_filter.txt
+ //
+ // (Note that this contradicts prctl(2): "If the filters permit prctl()
+ // calls, then additional filters can be added; they are run in order
+ // until the first non-allow result is seen." prctl(2) is incorrect.)
+ //
+ // "The ordering ensures that a min_t() over composed return values
+ // always selects the least permissive choice." -
+ // include/uapi/linux/seccomp.h
+ if (thisRet & linux.SECCOMP_RET_ACTION) < (ret & linux.SECCOMP_RET_ACTION) {
+ ret = thisRet
+ }
+ }
+
+ return ret
+}
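+
+// Worked example (illustrative only): suppose two installed filters return
+//
+//	retA := uint32(linux.SECCOMP_RET_ERRNO) | 5 // errno 5
+//	retB := uint32(linux.SECCOMP_RET_ALLOW)
+//
+// Since (retA & SECCOMP_RET_ACTION) is numerically smaller than
+// (retB & SECCOMP_RET_ACTION), the loop above keeps retA: the less
+// permissive result wins regardless of filter installation order.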
+
+// AppendSyscallFilter adds BPF program p as a system call filter.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) AppendSyscallFilter(p bpf.Program, syncAll bool) error {
+ // Although syscallFilters is an atomic.Value, we must take the mutex to
+ // prevent our read-copy-update from racing with another task syncing
+ // syscall filters to us; this keeps the filters in a consistent state.
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+
+ // Cap the combined length of all syscall filters (plus a penalty of 4
+ // instructions per filter beyond the first) to maxSyscallFilterInstructions.
+ // This restriction is inherited from Linux.
+ totalLength := p.Length()
+ var newFilters []bpf.Program
+
+ if sf := t.syscallFilters.Load(); sf != nil {
+ oldFilters := sf.([]bpf.Program)
+ for _, f := range oldFilters {
+ totalLength += f.Length() + 4
+ }
+ newFilters = append(newFilters, oldFilters...)
+ }
+
+ if totalLength > maxSyscallFilterInstructions {
+ return syserror.ENOMEM
+ }
+
+ newFilters = append(newFilters, p)
+ t.syscallFilters.Store(newFilters)
+
+ if syncAll {
+ // Note: no_new_privs is always assumed to be set.
+ for ot := t.tg.tasks.Front(); ot != nil; ot = ot.Next() {
+ if ot != t {
+ var copiedFilters []bpf.Program
+ copiedFilters = append(copiedFilters, newFilters...)
+ ot.syscallFilters.Store(copiedFilters)
+ }
+ }
+ }
+
+ return nil
+}
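+
+// A minimal usage sketch (hypothetical caller; assumes a compiled
+// bpf.Program prog is already in hand):
+//
+//	if err := t.AppendSyscallFilter(prog, true /* syncAll */); err != nil {
+//		return err // ENOMEM if the combined filter length limit is exceeded
+//	}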
+
+// SeccompMode returns a SECCOMP_MODE_* constant indicating the task's current
+// seccomp syscall filtering mode, appropriate for both prctl(PR_GET_SECCOMP)
+// and /proc/[pid]/status.
+func (t *Task) SeccompMode() int {
+ f := t.syscallFilters.Load()
+ if f != nil && len(f.([]bpf.Program)) > 0 {
+ return linux.SECCOMP_MODE_FILTER
+ }
+ return linux.SECCOMP_MODE_NONE
+}
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
new file mode 100644
index 000000000..65e5427c1
--- /dev/null
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -0,0 +1,49 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "waiter_list",
+ out = "waiter_list.go",
+ package = "semaphore",
+ prefix = "waiter",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*waiter",
+ "Linker": "*waiter",
+ },
+)
+
+go_library(
+ name = "semaphore",
+ srcs = [
+ "semaphore.go",
+ "waiter_list.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "semaphore_test",
+ size = "small",
+ srcs = ["semaphore_test.go"],
+ library = ":semaphore",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
new file mode 100644
index 000000000..c00fa1138
--- /dev/null
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -0,0 +1,572 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package semaphore implements System V semaphores.
+package semaphore
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ valueMax = 32767 // SEMVMX
+
+ // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL).
+ semaphoresMax = 32000
+
+ // setsMax is "system-wide limit on the number of semaphore sets" (SEMMNI).
+ setsMax = 32000
+
+ // semaphoresTotalMax is "system-wide limit on the number of semaphores"
+ // (SEMMNS = SEMMNI*SEMMSL).
+ semaphoresTotalMax = 1024000000
+)
+
+// Registry maintains a set of semaphores that can be found by key or ID.
+//
+// +stateify savable
+type Registry struct {
+ // userNS owning the IPC namespace this registry belongs to. Immutable.
+ userNS *auth.UserNamespace
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+ semaphores map[int32]*Set
+ lastIDUsed int32
+}
+
+// Set represents a set of semaphores that can be operated atomically.
+//
+// +stateify savable
+type Set struct {
+ // registry owning this sem set. Immutable.
+ registry *Registry
+
+ // ID is a handle that identifies the set.
+ ID int32
+
+ // key is a user-provided key that can be shared between processes.
+ key int32
+
+ // creator is the user that created the set. Immutable.
+ creator fs.FileOwner
+
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+ owner fs.FileOwner
+ perms fs.FilePermissions
+ opTime ktime.Time
+ changeTime ktime.Time
+
+ // sems holds all semaphores in the set. The slice itself is immutable after
+ // it's been set; however, each 'sem' object in the slice requires the 'mu'
+ // lock.
+ sems []sem
+
+ // dead is set to true when the set is removed and can't be reached anymore.
+ // All waiters must wake up and fail when set is dead.
+ dead bool
+}
+
+// sem represents a single semaphore from a set.
+//
+// +stateify savable
+type sem struct {
+ value int16
+ waiters waiterList `state:"zerovalue"`
+ pid int32
+}
+
+// waiter represents a caller that is waiting for the semaphore value to
+// become positive or zero.
+//
+// +stateify savable
+type waiter struct {
+ waiterEntry
+
+ // value represents how much resource the waiter needs to wake up.
+ value int16
+ ch chan struct{}
+}
+
+// NewRegistry creates a new semaphore set registry.
+func NewRegistry(userNS *auth.UserNamespace) *Registry {
+ return &Registry{
+ userNS: userNS,
+ semaphores: make(map[int32]*Set),
+ }
+}
+
+// FindOrCreate searches for a semaphore set that matches 'key'. If not found,
+// it may create a new one if requested. If private is true, key is ignored and
+// a new set is always created. If create is false, it fails if a set cannot
+// be found. If exclusive is true, it fails if a set with the same key already
+// exists.
+func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
+ if nsems < 0 || nsems > semaphoresMax {
+ return nil, syserror.EINVAL
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !private {
+ // Look up an existing semaphore.
+ if set := r.findByKey(key); set != nil {
+ set.mu.Lock()
+ defer set.mu.Unlock()
+
+ // Check that caller can access semaphore set.
+ creds := auth.CredentialsFromContext(ctx)
+ if !set.checkPerms(creds, fs.PermsFromMode(mode)) {
+ return nil, syserror.EACCES
+ }
+
+ // Validate parameters.
+ if nsems > int32(set.Size()) {
+ return nil, syserror.EINVAL
+ }
+ if create && exclusive {
+ return nil, syserror.EEXIST
+ }
+ return set, nil
+ }
+
+ if !create {
+ // Semaphore not found and should not be created.
+ return nil, syserror.ENOENT
+ }
+ }
+
+ // Zero is only valid if an existing set is found.
+ if nsems == 0 {
+ return nil, syserror.EINVAL
+ }
+
+ // Apply system limits.
+ if len(r.semaphores) >= setsMax {
+ return nil, syserror.EINVAL
+ }
+ if r.totalSems() > int(semaphoresTotalMax-nsems) {
+ return nil, syserror.EINVAL
+ }
+
+ // Finally create a new set.
+ owner := fs.FileOwnerFromContext(ctx)
+ perms := fs.FilePermsFromMode(mode)
+ return r.newSet(ctx, key, owner, owner, perms, nsems)
+}
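+
+// Usage sketch (hypothetical, mirroring semget(2) flag semantics):
+//
+//	// Equivalent of semget(key, nsems, IPC_CREAT|IPC_EXCL|0600).
+//	set, err := registry.FindOrCreate(ctx, key, nsems, linux.FileMode(0600),
+//		false /* private */, true /* create */, true /* exclusive */)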
+
+// RemoveID removes the set with the given 'id' from the registry and marks it as
+// dead. All waiters will be awakened and fail.
+func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ set := r.semaphores[id]
+ if set == nil {
+ return syserror.EINVAL
+ }
+
+ set.mu.Lock()
+ defer set.mu.Unlock()
+
+ // "The effective user ID of the calling process must match the creator or
+ // owner of the semaphore set, or the caller must be privileged."
+ if !set.checkCredentials(creds) && !set.checkCapability(creds) {
+ return syserror.EACCES
+ }
+
+ delete(r.semaphores, set.ID)
+ set.destroy()
+ return nil
+}
+
+func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.FileOwner, perms fs.FilePermissions, nsems int32) (*Set, error) {
+ set := &Set{
+ registry: r,
+ key: key,
+ owner: owner,
+ creator: creator,
+ perms: perms,
+ changeTime: ktime.NowFromContext(ctx),
+ sems: make([]sem, nsems),
+ }
+
+ // Find the next available ID.
+ for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
+ // Handle wrap around.
+ if id < 0 {
+ id = 0
+ continue
+ }
+ if r.semaphores[id] == nil {
+ r.lastIDUsed = id
+ r.semaphores[id] = set
+ set.ID = id
+ return set, nil
+ }
+ }
+
+ log.Warningf("Semaphore map is full, they must be leaking")
+ return nil, syserror.ENOMEM
+}
+
+// FindByID looks up a set given an ID.
+func (r *Registry) FindByID(id int32) *Set {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.semaphores[id]
+}
+
+func (r *Registry) findByKey(key int32) *Set {
+ for _, v := range r.semaphores {
+ if v.key == key {
+ return v
+ }
+ }
+ return nil
+}
+
+func (r *Registry) totalSems() int {
+ totalSems := 0
+ for _, v := range r.semaphores {
+ totalSems += v.Size()
+ }
+ return totalSems
+}
+
+func (s *Set) findSem(num int32) *sem {
+ if num < 0 || int(num) >= s.Size() {
+ return nil
+ }
+ return &s.sems[num]
+}
+
+// Size returns the number of semaphores in the set. Size is immutable.
+func (s *Set) Size() int {
+ return len(s.sems)
+}
+
+// Change changes some fields from the set atomically.
+func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.FileOwner, perms fs.FilePermissions) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The effective UID of the calling process must match the owner or creator
+ // of the semaphore set, or the caller must be privileged."
+ if !s.checkCredentials(creds) && !s.checkCapability(creds) {
+ return syserror.EACCES
+ }
+
+ s.owner = owner
+ s.perms = perms
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+// SetVal overrides a semaphore value, waking up waiters as needed.
+func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error {
+ if val < 0 || val > valueMax {
+ return syserror.ERANGE
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have alter permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ return syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return syserror.ERANGE
+ }
+
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
+ sem.value = val
+ sem.pid = pid
+ s.changeTime = ktime.NowFromContext(ctx)
+ sem.wakeWaiters()
+ return nil
+}
+
+// SetValAll overrides all semaphore values, waking up waiters as needed. It
+// also sets each semaphore's PID, as Linux does since 4.6.
+//
+// 'len(vals)' must be equal to 's.Size()'.
+func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credentials, pid int32) error {
+ if len(vals) != s.Size() {
+ panic(fmt.Sprintf("vals length (%d) different that Set.Size() (%d)", len(vals), s.Size()))
+ }
+
+ for _, val := range vals {
+ // val is unsigned, so only the upper bound needs checking.
+ if val > valueMax {
+ return syserror.ERANGE
+ }
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have alter permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ return syserror.EACCES
+ }
+
+ for i, val := range vals {
+ sem := &s.sems[i]
+
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
+ sem.value = int16(val)
+ sem.pid = pid
+ sem.wakeWaiters()
+ }
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+// GetVal returns a semaphore value.
+func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ return sem.value, nil
+}
+
+// GetValAll returns value for all semaphores.
+func (s *Set) GetValAll(creds *auth.Credentials) ([]uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return nil, syserror.EACCES
+ }
+
+ vals := make([]uint16, s.Size())
+ for i, sem := range s.sems {
+ vals[i] = uint16(sem.value)
+ }
+ return vals, nil
+}
+
+// GetPID returns the PID set when performing operations in the semaphore.
+func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ return sem.pid, nil
+}
+
+// ExecuteOps attempts to execute a list of operations on the set. It only
+// succeeds when all operations can be applied. No changes are made if it fails.
+//
+// On failure, it either returns an error (in which case retries are hopeless)
+// or a channel that can be waited on before attempting again.
+func (s *Set) ExecuteOps(ctx context.Context, ops []linux.Sembuf, creds *auth.Credentials, pid int32) (chan struct{}, int32, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Did it race with a removal operation?
+ if s.dead {
+ return nil, 0, syserror.EIDRM
+ }
+
+ // Validate the operations.
+ readOnly := true
+ for _, op := range ops {
+ if s.findSem(int32(op.SemNum)) == nil {
+ return nil, 0, syserror.EFBIG
+ }
+ if op.SemOp != 0 {
+ readOnly = false
+ }
+ }
+
+ if !s.checkPerms(creds, fs.PermMask{Read: readOnly, Write: !readOnly}) {
+ return nil, 0, syserror.EACCES
+ }
+
+ ch, num, err := s.executeOps(ctx, ops, pid)
+ if err != nil {
+ return nil, 0, err
+ }
+ return ch, num, nil
+}
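+
+// A hedged sketch of the retry protocol a caller might follow (hypothetical;
+// the real syscall layer also handles interruption and timeouts):
+//
+//	for {
+//		ch, num, err := set.ExecuteOps(ctx, ops, creds, pid)
+//		if err != nil {
+//			return err // retries are hopeless
+//		}
+//		if ch == nil {
+//			return nil // all operations were applied
+//		}
+//		<-ch // blocked; wait to be woken, then retry
+//		_ = num // identifies the semaphore waited on, for AbortWait
+//	}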
+
+func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (chan struct{}, int32, error) {
+ // Changes to semaphores go to this slice temporarily until they all succeed.
+ tmpVals := make([]int16, len(s.sems))
+ for i := range s.sems {
+ tmpVals[i] = s.sems[i].value
+ }
+
+ for _, op := range ops {
+ sem := &s.sems[op.SemNum]
+ if op.SemOp == 0 {
+ // Handle 'wait for zero' operation.
+ if tmpVals[op.SemNum] != 0 {
+ // Semaphore isn't 0, must wait.
+ if op.SemFlg&linux.IPC_NOWAIT != 0 {
+ return nil, 0, syserror.ErrWouldBlock
+ }
+
+ w := newWaiter(op.SemOp)
+ sem.waiters.PushBack(w)
+ return w.ch, int32(op.SemNum), nil
+ }
+ } else {
+ if op.SemOp < 0 {
+ // Handle 'wait' operation.
+ if -op.SemOp > valueMax {
+ return nil, 0, syserror.ERANGE
+ }
+ if -op.SemOp > tmpVals[op.SemNum] {
+ // Not enough resources, must wait.
+ if op.SemFlg&linux.IPC_NOWAIT != 0 {
+ return nil, 0, syserror.ErrWouldBlock
+ }
+
+ w := newWaiter(op.SemOp)
+ sem.waiters.PushBack(w)
+ return w.ch, int32(op.SemNum), nil
+ }
+ } else {
+ // op.SemOp > 0: Handle 'signal' operation.
+ if tmpVals[op.SemNum] > valueMax-op.SemOp {
+ return nil, 0, syserror.ERANGE
+ }
+ }
+
+ tmpVals[op.SemNum] += op.SemOp
+ }
+ }
+
+ // All operations succeeded, apply them.
+ // TODO(gvisor.dev/issue/137): handle undo operations.
+ for i, v := range tmpVals {
+ s.sems[i].value = v
+ s.sems[i].wakeWaiters()
+ s.sems[i].pid = pid
+ }
+ s.opTime = ktime.NowFromContext(ctx)
+ return nil, 0, nil
+}
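+
+// Worked example (illustrative): with s.sems[0].value == 1 and
+// ops = [{SemNum: 0, SemOp: -1}, {SemNum: 0, SemOp: +2}], tmpVals[0] first
+// drops to 0 and then rises to 2; only after every op succeeds are the
+// tmpVals written back, so other tasks never observe the intermediate value.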
+
+// AbortWait notifies that a waiter is giving up and will not wait on the
+// channel anymore.
+func (s *Set) AbortWait(num int32, ch chan struct{}) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ sem := &s.sems[num]
+ for w := sem.waiters.Front(); w != nil; w = w.Next() {
+ if w.ch == ch {
+ sem.waiters.Remove(w)
+ return
+ }
+ }
+ // Waiter may not be found in case it raced with wakeWaiters().
+}
+
+func (s *Set) checkCredentials(creds *auth.Credentials) bool {
+ return s.owner.UID == creds.EffectiveKUID ||
+ s.owner.GID == creds.EffectiveKGID ||
+ s.creator.UID == creds.EffectiveKUID ||
+ s.creator.GID == creds.EffectiveKGID
+}
+
+func (s *Set) checkCapability(creds *auth.Credentials) bool {
+ return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, s.registry.userNS) && creds.UserNamespace.MapFromKUID(s.owner.UID).Ok()
+}
+
+func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool {
+ // Are we owner, or in group, or other?
+ p := s.perms.Other
+ if s.owner.UID == creds.EffectiveKUID {
+ p = s.perms.User
+ } else if creds.InGroup(s.owner.GID) {
+ p = s.perms.Group
+ }
+
+ // Are permissions satisfied without capability checks?
+ if p.SupersetOf(reqPerms) {
+ return true
+ }
+
+ return s.checkCapability(creds)
+}
+
+// destroy destroys the set. Caller must hold 's.mu'.
+func (s *Set) destroy() {
+ // Notify all waiters. They will fail on the next attempt to execute
+ // operations and return error.
+ s.dead = true
+ for i := range s.sems {
+ sem := &s.sems[i]
+ for w := sem.waiters.Front(); w != nil; w = w.Next() {
+ w.ch <- struct{}{}
+ }
+ // Reset the original list in place, not a copy of it.
+ sem.waiters.Reset()
+ }
+}
+
+// wakeWaiters goes over all waiters and checks which of them can be notified.
+func (s *sem) wakeWaiters() {
+ // Note that this will release all waiters waiting for 0 too.
+ for w := s.waiters.Front(); w != nil; {
+ if s.value < w.value {
+ // Still blocked, skip it.
+ w = w.Next()
+ continue
+ }
+ w.ch <- struct{}{}
+ old := w
+ w = w.Next()
+ s.waiters.Remove(old)
+ }
+}
+
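+// newWaiter creates a waiter for the given value. The channel is buffered
+// with capacity 1 so that wakers (wakeWaiters, destroy) can signal a waiter
+// without blocking, even if the waiter has not yet begun receiving.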
+func newWaiter(val int16) *waiter {
+ return &waiter{
+ value: val,
+ ch: make(chan struct{}, 1),
+ }
+}
diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go
new file mode 100644
index 000000000..e47acefdf
--- /dev/null
+++ b/pkg/sentry/kernel/semaphore/semaphore_test.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package semaphore
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} {
+ ch, _, err := set.executeOps(ctx, ops, 123)
+ if err != nil {
+ t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops)
+ }
+ if block {
+ if ch == nil {
+ t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops)
+ }
+ if signalled(ch) {
+ t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops)
+ }
+ } else {
+ if ch != nil {
+ t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops)
+ }
+ }
+ return ch
+}
+
+func signalled(ch chan struct{}) bool {
+ select {
+ case <-ch:
+ return true
+ default:
+ return false
+ }
+}
+
+func TestBasic(t *testing.T) {
+ ctx := contexttest.Context(t)
+ set := &Set{ID: 123, sems: make([]sem, 1)}
+ ops := []linux.Sembuf{
+ {SemOp: 1},
+ }
+ executeOps(ctx, t, set, ops, false)
+
+ ops[0].SemOp = -1
+ executeOps(ctx, t, set, ops, false)
+
+ ops[0].SemOp = -1
+ ch1 := executeOps(ctx, t, set, ops, true)
+
+ ops[0].SemOp = 1
+ executeOps(ctx, t, set, ops, false)
+ if !signalled(ch1) {
+ t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops)
+ }
+}
+
+func TestWaitForZero(t *testing.T) {
+ ctx := contexttest.Context(t)
+ set := &Set{ID: 123, sems: make([]sem, 1)}
+ ops := []linux.Sembuf{
+ {SemOp: 0},
+ }
+ executeOps(ctx, t, set, ops, false)
+
+ ops[0].SemOp = -2
+ ch1 := executeOps(ctx, t, set, ops, true)
+
+ ops[0].SemOp = 0
+ executeOps(ctx, t, set, ops, false)
+
+ ops[0].SemOp = 1
+ executeOps(ctx, t, set, ops, false)
+
+ ops[0].SemOp = 0
+ chZero1 := executeOps(ctx, t, set, ops, true)
+
+ ops[0].SemOp = 0
+ chZero2 := executeOps(ctx, t, set, ops, true)
+
+ ops[0].SemOp = 1
+ executeOps(ctx, t, set, ops, false)
+ if !signalled(ch1) {
+ t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set)
+ }
+
+ ops[0].SemOp = -2
+ executeOps(ctx, t, set, ops, false)
+ if !signalled(chZero1) {
+ t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set)
+ }
+ if !signalled(chZero2) {
+ t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set)
+ }
+}
+
+func TestNoWait(t *testing.T) {
+ ctx := contexttest.Context(t)
+ set := &Set{ID: 123, sems: make([]sem, 1)}
+ ops := []linux.Sembuf{
+ {SemOp: 1},
+ }
+ executeOps(ctx, t, set, ops, false)
+
+ ops[0].SemOp = -2
+ ops[0].SemFlg = linux.IPC_NOWAIT
+ if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock {
+ t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock)
+ }
+
+ ops[0].SemOp = 0
+ ops[0].SemFlg = linux.IPC_NOWAIT
+ if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock {
+ t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock)
+ }
+}
+
+func TestUnregister(t *testing.T) {
+ ctx := contexttest.Context(t)
+ r := NewRegistry(auth.NewRootUserNamespace())
+ set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0600), true, true, true)
+ if err != nil {
+ t.Fatalf("FindOrCreate() failed, err: %v", err)
+ }
+ if got := r.FindByID(set.ID); got.ID != set.ID {
+ t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.ID, got, set)
+ }
+
+ ops := []linux.Sembuf{
+ {SemOp: -1},
+ }
+ chs := make([]chan struct{}, 0, 5)
+ for i := 0; i < 5; i++ {
+ ch := executeOps(ctx, t, set, ops, true)
+ chs = append(chs, ch)
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ if err := r.RemoveID(set.ID, creds); err != nil {
+ t.Fatalf("RemoveID(%d) failed, err: %v", set.ID, err)
+ }
+ if !set.dead {
+ t.Fatalf("set is not dead: %+v", set)
+ }
+ if got := r.FindByID(set.ID); got != nil {
+ t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.ID, got)
+ }
+ for i, ch := range chs {
+ if !signalled(ch) {
+ t.Fatalf("channel %d should have been signalled", i)
+ }
+ }
+}
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
new file mode 100644
index 000000000..0e19286de
--- /dev/null
+++ b/pkg/sentry/kernel/sessions.go
@@ -0,0 +1,528 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SessionID is the public identifier.
+type SessionID ThreadID
+
+// ProcessGroupID is the public identifier.
+type ProcessGroupID ThreadID
+
+// Session contains a leader threadgroup and a list of ProcessGroups.
+//
+// +stateify savable
+type Session struct {
+ refs refs.AtomicRefCount
+
+ // leader is the originator of the Session.
+ //
+ // Note that this may no longer be running (and may be reaped), so the
+ // ID is cached upon initial creation. The leader is still required
+ // however, since its PIDNamespace defines the scope of the Session.
+ //
+ // The leader is immutable.
+ leader *ThreadGroup
+
+ // id is the cached identifier in the leader's namespace.
+ //
+ // The id is immutable.
+ id SessionID
+
+ // foreground is the foreground process group.
+ //
+ // This is protected by TaskSet.mu.
+ foreground *ProcessGroup
+
+ // processGroups is a list of process groups in this Session. This is
+ // protected by TaskSet.mu.
+ processGroups processGroupList
+
+ // sessionEntry is the embed for TaskSet.sessions. This is protected by
+ // TaskSet.mu.
+ sessionEntry
+}
+
+// incRef grabs a reference.
+func (s *Session) incRef() {
+ s.refs.IncRef()
+}
+
+// decRef drops a reference.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (s *Session) decRef() {
+ s.refs.DecRefWithDestructor(func() {
+ // Remove translations from the leader.
+ for ns := s.leader.pidns; ns != nil; ns = ns.parent {
+ id := ns.sids[s]
+ delete(ns.sids, s)
+ delete(ns.sessions, id)
+ }
+
+ // Remove from the list of global Sessions.
+ s.leader.pidns.owner.sessions.Remove(s)
+ })
+}
+
+// ProcessGroup contains an originator threadgroup and a parent Session.
+//
+// +stateify savable
+type ProcessGroup struct {
+ refs refs.AtomicRefCount // not exported.
+
+ // originator is the originator of the group.
+ //
+ // See note re: leader in Session. The same applies here.
+ //
+ // The originator is immutable.
+ originator *ThreadGroup
+
+ // id is the cached identifier in the originator's namespace.
+ //
+ // The id is immutable.
+ id ProcessGroupID
+
+ // Session is the parent Session.
+ //
+ // The session is immutable.
+ session *Session
+
+ // ancestors is the number of thread groups in this process group whose
+ // parent is in a different process group in the same session.
+ //
+ // The name is derived from the fact that process groups where
+ // ancestors is zero are considered "orphans".
+ //
+ // ancestors is protected by TaskSet.mu.
+ ancestors uint32
+
+ // processGroupEntry is the embedded entry for Sessions.groups. This is
+ // protected by TaskSet.mu.
+ processGroupEntry
+}
+
+// Originator returns the originator of the process group.
+func (pg *ProcessGroup) Originator() *ThreadGroup {
+ return pg.originator
+}
+
+// IsOrphan returns true if this process group is an orphan.
+func (pg *ProcessGroup) IsOrphan() bool {
+ pg.originator.TaskSet().mu.RLock()
+ defer pg.originator.TaskSet().mu.RUnlock()
+ return pg.ancestors == 0
+}
+
+// incRefWithParent grabs a reference.
+//
+// This function is called when this ProcessGroup is being associated with some
+// new ThreadGroup, tg. parentPG is the ProcessGroup of tg's parent
+// ThreadGroup. If tg is init, then parentPG may be nil.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (pg *ProcessGroup) incRefWithParent(parentPG *ProcessGroup) {
+ // We acquire an "ancestor" reference in the case of a nil parent.
+ // This is because the process being associated is init, and init can
+ // never be orphaned (we count it as always having an ancestor).
+ if pg != parentPG && (parentPG == nil || pg.session == parentPG.session) {
+ pg.ancestors++
+ }
+
+ pg.refs.IncRef()
+}
+
+// decRefWithParent drops a reference.
+//
+// parentPG is per incRefWithParent.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) {
+ // See incRefWithParent regarding parent == nil.
+ if pg != parentPG && (parentPG == nil || pg.session == parentPG.session) {
+ pg.ancestors--
+ }
+
+ alive := true
+ pg.refs.DecRefWithDestructor(func() {
+ alive = false // don't bother with handleOrphan.
+
+ // Remove translations from the originator.
+ for ns := pg.originator.pidns; ns != nil; ns = ns.parent {
+ id := ns.pgids[pg]
+ delete(ns.pgids, pg)
+ delete(ns.processGroups, id)
+ }
+
+ // Remove from the session's list of process groups.
+ pg.session.processGroups.Remove(pg)
+ pg.session.decRef()
+ })
+ if alive {
+ pg.handleOrphan()
+ }
+}
+
+// parentPG returns the parent process group.
+//
+// Precondition: callers must hold TaskSet.mu.
+func (tg *ThreadGroup) parentPG() *ProcessGroup {
+ if tg.leader.parent != nil {
+ return tg.leader.parent.tg.processGroup
+ }
+ return nil
+}
+
+// handleOrphan checks whether the process group is an orphan and has any
+// stopped jobs. If yes, then appropriate signals are delivered to each thread
+// group within the process group.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (pg *ProcessGroup) handleOrphan() {
+ // Check if this process is an orphan.
+ if pg.ancestors != 0 {
+ return
+ }
+
+ // See if there are any stopped jobs.
+ hasStopped := false
+ pg.originator.pidns.owner.forEachThreadGroupLocked(func(tg *ThreadGroup) {
+ if tg.processGroup != pg {
+ return
+ }
+ tg.signalHandlers.mu.Lock()
+ if tg.groupStopComplete {
+ hasStopped = true
+ }
+ tg.signalHandlers.mu.Unlock()
+ })
+ if !hasStopped {
+ return
+ }
+
+ // Deliver appropriate signals to all thread groups.
+ pg.originator.pidns.owner.forEachThreadGroupLocked(func(tg *ThreadGroup) {
+ if tg.processGroup != pg {
+ return
+ }
+ tg.signalHandlers.mu.Lock()
+ tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGHUP), true /* group */)
+ tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGCONT), true /* group */)
+ tg.signalHandlers.mu.Unlock()
+ })
+}
+
+// Session returns the process group's session without taking a reference.
+func (pg *ProcessGroup) Session() *Session {
+ return pg.session
+}
+
+// SendSignal sends a signal to all processes inside the process group. It is
+// analogous to kernel/signal.c:kill_pgrp.
+func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
+ tasks := pg.originator.TaskSet()
+ tasks.mu.RLock()
+ defer tasks.mu.RUnlock()
+
+ var lastErr error
+ for tg := range tasks.Root.tgids {
+ if tg.processGroup == pg {
+ tg.signalHandlers.mu.Lock()
+ infoCopy := *info
+ if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
+ lastErr = err
+ }
+ tg.signalHandlers.mu.Unlock()
+ }
+ }
+ return lastErr
+}
+
+// CreateSession creates a new Session, with the ThreadGroup as the leader.
+//
+// EPERM may be returned if either the given ThreadGroup is already a Session
+// leader, or a ProcessGroup already exists for the ThreadGroup's ID.
+func (tg *ThreadGroup) CreateSession() error {
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.createSession()
+}
+
+// createSession creates a new session for a threadgroup.
+//
+// Precondition: callers must hold TaskSet.mu and the signal mutex for writing.
+func (tg *ThreadGroup) createSession() error {
+ // Get the ID for this thread in the current namespace.
+ id := tg.pidns.tgids[tg]
+
+ // Check if this ThreadGroup already leads a Session, or
+ // if the proposed group is already taken.
+ for s := tg.pidns.owner.sessions.Front(); s != nil; s = s.Next() {
+ if s.leader.pidns != tg.pidns {
+ continue
+ }
+ if s.leader == tg {
+ return syserror.EPERM
+ }
+ if s.id == SessionID(id) {
+ return syserror.EPERM
+ }
+ for pg := s.processGroups.Front(); pg != nil; pg = pg.Next() {
+ if pg.id == ProcessGroupID(id) {
+ return syserror.EPERM
+ }
+ }
+ }
+
+ // Create a new Session, with a single reference.
+ s := &Session{
+ id: SessionID(id),
+ leader: tg,
+ }
+ s.refs.EnableLeakCheck("kernel.Session")
+
+ // Create a new ProcessGroup, belonging to that Session.
+ // This also has a single reference (assigned below).
+ //
+ // Note that since this is a new session and a new process group, there
+ // will be zero ancestors for this process group. (It is an orphan at
+ // this point.)
+ pg := &ProcessGroup{
+ id: ProcessGroupID(id),
+ originator: tg,
+ session: s,
+ ancestors: 0,
+ }
+ pg.refs.EnableLeakCheck("kernel.ProcessGroup")
+
+ // Tie them and return the result.
+ s.processGroups.PushBack(pg)
+ tg.pidns.owner.sessions.PushBack(s)
+
+ // Leave the current group, and assign the new one.
+ if tg.processGroup != nil {
+ oldParentPG := tg.parentPG()
+ tg.forEachChildThreadGroupLocked(func(childTG *ThreadGroup) {
+ childTG.processGroup.incRefWithParent(pg)
+ childTG.processGroup.decRefWithParent(oldParentPG)
+ })
+ // If tg.processGroup is an orphan, decRefWithParent will lock
+ // the signal mutex of each thread group in tg.processGroup.
+ // However, tg's signal mutex may already be locked at this
+ // point. We change tg's process group before calling
+ // decRefWithParent to avoid locking tg's signal mutex twice.
+ oldPG := tg.processGroup
+ tg.processGroup = pg
+ oldPG.decRefWithParent(oldParentPG)
+ } else {
+ // The current process group may be nil only in the case of an
+ // unparented thread group (i.e. the init process). This would
+ // not normally occur, but we allow it for the convenience of
+ // CreateSession working from that point. There will be no
+ // child processes. We always say that the very first group
+ // created has ancestors (avoids checks elsewhere).
+ //
+ // Note that this mirrors the parent == nil logic in
+ // incRef/decRef/reparent, which counts nil as an ancestor.
+ tg.processGroup = pg
+ tg.processGroup.ancestors++
+ }
+
+ // Ensure a translation is added to all namespaces.
+ for ns := tg.pidns; ns != nil; ns = ns.parent {
+ local := ns.tgids[tg]
+ ns.sids[s] = SessionID(local)
+ ns.sessions[SessionID(local)] = s
+ ns.pgids[pg] = ProcessGroupID(local)
+ ns.processGroups[ProcessGroupID(local)] = pg
+ }
+
+ // Disconnect from the controlling terminal.
+ tg.tty = nil
+
+ return nil
+}
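+
+// Usage sketch (hypothetical caller, analogous to setsid(2)):
+//
+//	if err := tg.CreateSession(); err != nil {
+//		return err // EPERM: already a session leader, or the ID is taken
+//	}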
+
+// CreateProcessGroup creates a new process group.
+//
+// An EPERM error will be returned if the ThreadGroup belongs to a different
+// Session, is a Session leader, or the group already exists.
+func (tg *ThreadGroup) CreateProcessGroup() error {
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+
+ // Get the ID for this thread in the current namespace.
+ id := tg.pidns.tgids[tg]
+
+ // Per above, check for a Session leader or existing group.
+ for s := tg.pidns.owner.sessions.Front(); s != nil; s = s.Next() {
+ if s.leader.pidns != tg.pidns {
+ continue
+ }
+ if s.leader == tg {
+ return syserror.EPERM
+ }
+ for pg := s.processGroups.Front(); pg != nil; pg = pg.Next() {
+ if pg.id == ProcessGroupID(id) {
+ return syserror.EPERM
+ }
+ }
+ }
+
+ // Create a new ProcessGroup, belonging to the current Session.
+ //
+ // We manually adjust the ancestors if the parent is in the same
+ // session.
+ tg.processGroup.session.incRef()
+ pg := ProcessGroup{
+ id: ProcessGroupID(id),
+ originator: tg,
+ session: tg.processGroup.session,
+ }
+ pg.refs.EnableLeakCheck("kernel.ProcessGroup")
+
+ if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session {
+ pg.ancestors++
+ }
+
+ // Assign the new process group; adjust children.
+ oldParentPG := tg.parentPG()
+ tg.forEachChildThreadGroupLocked(func(childTG *ThreadGroup) {
+ childTG.processGroup.incRefWithParent(&pg)
+ childTG.processGroup.decRefWithParent(oldParentPG)
+ })
+ tg.processGroup.decRefWithParent(oldParentPG)
+ tg.processGroup = &pg
+
+ // Add the new process group to the session.
+ pg.session.processGroups.PushBack(&pg)
+
+ // Ensure this translation is added to all namespaces.
+ for ns := tg.pidns; ns != nil; ns = ns.parent {
+ local := ns.tgids[tg]
+ ns.pgids[&pg] = ProcessGroupID(local)
+ ns.processGroups[ProcessGroupID(local)] = &pg
+ }
+
+ return nil
+}
+
+// JoinProcessGroup joins an existing process group.
+//
+// This function will return EACCES if an exec has been performed since fork
+// by the given ThreadGroup, and EPERM if the Sessions are not the same or the
+// group does not exist.
+//
+// If checkExec is set, then the join is not permitted after the process has
+// executed exec at least once.
+func (tg *ThreadGroup) JoinProcessGroup(pidns *PIDNamespace, pgid ProcessGroupID, checkExec bool) error {
+ pidns.owner.mu.Lock()
+ defer pidns.owner.mu.Unlock()
+
+ // Lookup the ProcessGroup.
+ pg := pidns.processGroups[pgid]
+ if pg == nil {
+ return syserror.EPERM
+ }
+
+ // Disallow the join if an execve has been performed, per POSIX.
+ if checkExec && tg.execed {
+ return syserror.EACCES
+ }
+
+ // See if it's in the same session as ours.
+ if pg.session != tg.processGroup.session {
+ return syserror.EPERM
+ }
+
+ // Join the group; adjust children.
+ parentPG := tg.parentPG()
+ pg.incRefWithParent(parentPG)
+ tg.forEachChildThreadGroupLocked(func(childTG *ThreadGroup) {
+ childTG.processGroup.incRefWithParent(pg)
+ childTG.processGroup.decRefWithParent(tg.processGroup)
+ })
+ tg.processGroup.decRefWithParent(parentPG)
+ tg.processGroup = pg
+
+ return nil
+}
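+
+// Usage sketch (hypothetical, mirroring setpgid(2) semantics):
+//
+//	if err := tg.JoinProcessGroup(pidns, pgid, true /* checkExec */); err != nil {
+//		return err // EACCES after exec; EPERM for a cross-session or missing group
+//	}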
+
+// Session returns the ThreadGroup's Session.
+//
+// A reference is not taken on the session.
+func (tg *ThreadGroup) Session() *Session {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.processGroup.session
+}
+
+// IDOfSession returns the ID assigned to Session s in the PID namespace pidns.
+//
+// If the session isn't visible in this namespace, zero will be returned. It is
+// the caller's responsibility to check that before using this function.
+func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.sids[s]
+}
+
+// SessionWithID returns the Session with the given ID in the PID namespace
+// pidns, or nil if the given ID is not defined in this namespace.
+//
+// A reference is not taken on the session.
+func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.sessions[id]
+}
+
+// ProcessGroup returns the ThreadGroup's ProcessGroup.
+//
+// A reference is not taken on the process group.
+func (tg *ThreadGroup) ProcessGroup() *ProcessGroup {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.processGroup
+}
+
+// IDOfProcessGroup returns the ID assigned to process group pg in the PID
+// namespace pidns.
+//
+// The same constraints apply as IDOfSession.
+func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.pgids[pg]
+}
+
+// ProcessGroupWithID returns the ProcessGroup with the given ID in the PID
+// namespace pidns, or nil if the given ID is not defined in this namespace.
+//
+// A reference is not taken on the process group.
+func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.processGroups[id]
+}
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
new file mode 100644
index 000000000..bfd779837
--- /dev/null
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "shm",
+ srcs = [
+ "device.go",
+ "shm.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/usage",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/kernel/shm/device.go b/pkg/sentry/kernel/shm/device.go
new file mode 100644
index 000000000..6b0d5818b
--- /dev/null
+++ b/pkg/sentry/kernel/shm/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shm
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// shmDevice is the kernel shm device.
+var shmDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
new file mode 100644
index 000000000..f66cfcc7f
--- /dev/null
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -0,0 +1,707 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package shm implements sysv shared memory segments.
+//
+// Known missing features:
+//
+// - SHM_LOCK/SHM_UNLOCK are no-ops. The sentry currently doesn't implement
+// memory locking in general.
+//
+// - SHM_HUGETLB and related flags for shmget(2) are ignored. There's no easy
+// way to implement hugetlb support on a per-map basis, and it has no impact
+// on correctness.
+//
+// - SHM_NORESERVE for shmget(2) is ignored, the sentry doesn't implement swap
+// so it's meaningless to reserve space for swap.
+//
+// - No per-process segment size enforcement. This feature probably isn't used
+// much anyway, since Linux sets the per-process limits to the system-wide
+// limits by default.
+//
+// Lock ordering: mm.mappingMu -> shm registry lock -> shm lock
+package shm
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Key represents a shm segment key. Analogous to a file name.
+type Key int32
+
+// ID represents the opaque handle for a shm segment. Analogous to an fd.
+type ID int32
+
+// Registry tracks all shared memory segments in an IPC namespace. The registry
+// provides the mechanisms for creating and finding segments, and reporting
+// global shm parameters.
+//
+// +stateify savable
+type Registry struct {
+ // userNS owns the IPC namespace this registry belongs to. Immutable.
+ userNS *auth.UserNamespace
+
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // shms maps segment ids to segments.
+ //
+ // shms holds all referenced segments, which are removed on the last
+ // DecRef. Thus, it cannot itself hold a reference on the Shm.
+ //
+ // Since removal only occurs after the last (unlocked) DecRef, there
+ // exists a short window during which a Shm still exists in shms, but is
+ // unreferenced. Users must use TryIncRef to determine if the Shm is
+ // still valid.
+ shms map[ID]*Shm
+
+ // keysToShms maps segment keys to segments.
+ //
+ // Shms in keysToShms are guaranteed to be referenced, as they are
+ // removed by dissociateKey before the last DecRef.
+ keysToShms map[Key]*Shm
+
+ // Sum of the sizes of all existing segments rounded up to page size, in
+ // units of page size.
+ totalPages uint64
+
+ // ID assigned to the last created segment. Used to quickly find the next
+ // unused ID.
+ lastIDUsed ID
+}
+
+// NewRegistry creates a new shm registry.
+func NewRegistry(userNS *auth.UserNamespace) *Registry {
+ return &Registry{
+ userNS: userNS,
+ shms: make(map[ID]*Shm),
+ keysToShms: make(map[Key]*Shm),
+ }
+}
+
+// FindByID looks up a segment given an ID.
+//
+// FindByID returns a reference on Shm.
+func (r *Registry) FindByID(id ID) *Shm {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ s := r.shms[id]
+ // Take a reference on s. If TryIncRef fails, s has reached the last
+ // DecRef, but hasn't quite been removed from r.shms yet.
+ if s != nil && s.TryIncRef() {
+ return s
+ }
+ return nil
+}
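+
+// Callers typically pair FindByID with a deferred DecRef (sketch):
+//
+//	if seg := registry.FindByID(id); seg != nil {
+//		defer seg.DecRef()
+//		// ... use seg while the reference is held ...
+//	}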
+
+// dissociateKey removes the association between a segment and its key,
+// preventing it from being discovered in the registry. This doesn't necessarily
+// mean the segment is about to be destroyed. This is analogous to unlinking a
+// file; the segment can still be used by a process already referencing it, but
+// cannot be discovered by a new process.
+func (r *Registry) dissociateKey(s *Shm) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.key != linux.IPC_PRIVATE {
+ delete(r.keysToShms, s.key)
+ s.key = linux.IPC_PRIVATE
+ }
+}
+
+// FindOrCreate looks up or creates a segment in the registry. It's functionally
+// analogous to open(2).
+//
+// FindOrCreate returns a reference on Shm.
+func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
+ if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
+ // "A new segment was to be created and size is less than SHMMIN or
+ // greater than SHMMAX." - man shmget(2)
+ //
+ // Note that 'private' always implies the creation of a new segment
+ // whether IPC_CREAT is specified or not.
+ return nil, syserror.EINVAL
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if len(r.shms) >= linux.SHMMNI {
+ // "All possible shared memory IDs have been taken (SHMMNI) ..."
+ // - man shmget(2)
+ return nil, syserror.ENOSPC
+ }
+
+ if !private {
+ // Look up an existing segment.
+ if shm := r.keysToShms[key]; shm != nil {
+ shm.mu.Lock()
+ defer shm.mu.Unlock()
+
+ // Check that caller can access the segment.
+ if !shm.checkPermissions(ctx, fs.PermsFromMode(mode)) {
+ // "The user does not have permission to access the shared
+ // memory segment, and does not have the CAP_IPC_OWNER
+ // capability in the user namespace that governs its IPC
+ // namespace." - man shmget(2)
+ return nil, syserror.EACCES
+ }
+
+ if size > shm.size {
+ // "A segment for the given key exists, but size is greater than
+ // the size of that segment." - man shmget(2)
+ return nil, syserror.EINVAL
+ }
+
+ if create && exclusive {
+ // "IPC_CREAT and IPC_EXCL were specified in shmflg, but a
+ // shared memory segment already exists for key."
+ // - man shmget(2)
+ return nil, syserror.EEXIST
+ }
+
+ shm.IncRef()
+ return shm, nil
+ }
+
+ if !create {
+ // "No segment exists for the given key, and IPC_CREAT was not
+ // specified." - man shmget(2)
+ return nil, syserror.ENOENT
+ }
+ }
+
+ var sizeAligned uint64
+ if val, ok := usermem.Addr(size).RoundUp(); ok {
+ sizeAligned = uint64(val)
+ } else {
+ return nil, syserror.EINVAL
+ }
+
+ if numPages := sizeAligned / usermem.PageSize; r.totalPages+numPages > linux.SHMALL {
+ // "... allocating a segment of the requested size would cause the
+ // system to exceed the system-wide limit on shared memory (SHMALL)."
+ // - man shmget(2)
+ return nil, syserror.ENOSPC
+ }
+
+ // Need to create a new segment.
+ creator := fs.FileOwnerFromContext(ctx)
+ perms := fs.FilePermsFromMode(mode)
+ s, err := r.newShm(ctx, pid, key, creator, perms, size)
+ if err != nil {
+ return nil, err
+ }
+ // The initial reference is held by s itself. Take another to return to
+ // the caller.
+ s.IncRef()
+ return s, nil
+}
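+
+// Usage sketch (hypothetical, mirroring shmget(2)):
+//
+//	// Equivalent of shmget(IPC_PRIVATE, size, IPC_CREAT|0600).
+//	seg, err := registry.FindOrCreate(ctx, pid, Key(linux.IPC_PRIVATE), size,
+//		linux.FileMode(0600), true /* private */, true /* create */, false)
+//	if err != nil {
+//		return err
+//	}
+//	defer seg.DecRef() // drop the reference returned by FindOrCreate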
+
+// newShm creates a new segment in the registry.
+//
+// Precondition: Caller must hold r.mu.
+func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
+ }
+
+ effectiveSize := uint64(usermem.Addr(size).MustRoundUp())
+ fr, err := mfp.MemoryFile().Allocate(effectiveSize, usage.Anonymous)
+ if err != nil {
+ return nil, err
+ }
+
+ shm := &Shm{
+ mfp: mfp,
+ registry: r,
+ creator: creator,
+ size: size,
+ effectiveSize: effectiveSize,
+ fr: fr,
+ key: key,
+ perms: perms,
+ owner: creator,
+ creatorPID: pid,
+ changeTime: ktime.NowFromContext(ctx),
+ }
+ shm.EnableLeakCheck("kernel.Shm")
+
+ // Find the next available ID.
+ for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
+ // Handle wrap around.
+ if id < 0 {
+ id = 0
+ continue
+ }
+ if r.shms[id] == nil {
+ r.lastIDUsed = id
+
+ shm.ID = id
+ r.shms[id] = shm
+ r.keysToShms[key] = shm
+
+ r.totalPages += effectiveSize / usermem.PageSize
+
+ return shm, nil
+ }
+ }
+
+ log.Warningf("Shm ids exhuasted, they may be leaking")
+ return nil, syserror.ENOSPC
+}
+
+// IPCInfo reports global parameters for sysv shared memory segments on this
+// system. See shmctl(IPC_INFO).
+func (r *Registry) IPCInfo() *linux.ShmParams {
+ return &linux.ShmParams{
+ ShmMax: linux.SHMMAX,
+ ShmMin: linux.SHMMIN,
+ ShmMni: linux.SHMMNI,
+ ShmSeg: linux.SHMSEG,
+ ShmAll: linux.SHMALL,
+ }
+}
+
+// ShmInfo reports linux-specific global parameters for sysv shared memory
+// segments on this system. See shmctl(SHM_INFO).
+func (r *Registry) ShmInfo() *linux.ShmInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ return &linux.ShmInfo{
+ UsedIDs: int32(r.lastIDUsed),
+ ShmTot: r.totalPages,
+ ShmRss: r.totalPages, // We could probably get a better estimate from memory accounting.
+ ShmSwp: 0, // No reclaim at the moment.
+ }
+}
+
+// remove deletes a segment from this registry, deaccounting the memory used by
+// the segment.
+//
+// Precondition: Must follow a call to r.dissociateKey(s).
+func (r *Registry) remove(s *Shm) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.key != linux.IPC_PRIVATE {
+ panic(fmt.Sprintf("Attempted to remove %s from the registry whose key is still associated", s.debugLocked()))
+ }
+
+ delete(r.shms, s.ID)
+ r.totalPages -= s.effectiveSize / usermem.PageSize
+}
+
+// Shm represents a single shared memory segment.
+//
+// Shm segments are backed directly by an allocation from platform memory.
+// Segments are always mapped as a whole, greatly simplifying how mappings are
+// tracked. However, note that mremap and munmap calls may cause the vma for a
+// segment to become fragmented, which requires special care when unmapping a
+// segment. See mm/shm.go.
+//
+// Segments persist until they are explicitly marked for destruction via
+// MarkDestroyed().
+//
+// Shm implements memmap.Mappable and memmap.MappingIdentity.
+//
+// +stateify savable
+type Shm struct {
+ // AtomicRefCount tracks the number of references to this segment.
+ //
+ // A segment holds a reference to itself until it is marked for
+ // destruction.
+ //
+ // In addition to direct users, the MemoryManager will hold references
+ // via MappingIdentity.
+ refs.AtomicRefCount
+
+ mfp pgalloc.MemoryFileProvider
+
+ // registry points to the shm registry containing this segment. Immutable.
+ registry *Registry
+
+ // ID is the kernel identifier for this segment. Immutable.
+ ID ID
+
+ // creator is the user that created the segment. Immutable.
+ creator fs.FileOwner
+
+ // size is the requested size of the segment at creation, in
+ // bytes. Immutable.
+ size uint64
+
+ // effectiveSize is the size of the segment, rounded up to the next page
+ // boundary. Immutable.
+ //
+ // Invariant: effectiveSize must be a multiple of usermem.PageSize.
+ effectiveSize uint64
+
+ // fr is the offset into mfp.MemoryFile() that backs the contents of this
+ // segment. Immutable.
+ fr platform.FileRange
+
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // key is the public identifier for this segment.
+ key Key
+
+ // perms is the access permissions for the segment.
+ perms fs.FilePermissions
+
+ // owner of this segment.
+ owner fs.FileOwner
+ // attachTime is updated on every successful shmat.
+ attachTime ktime.Time
+ // detachTime is updated on every successful shmdt.
+ detachTime ktime.Time
+ // changeTime is updated on every successful change to the segment via
+ // shmctl(IPC_SET).
+ changeTime ktime.Time
+
+ // creatorPID is the PID of the process that created the segment.
+ creatorPID int32
+ // lastAttachDetachPID is the pid of the process that issued the last shmat
+ // or shmdt syscall.
+ lastAttachDetachPID int32
+
+ // pendingDestruction indicates the segment was marked as destroyed through
+ // shmctl(IPC_RMID). When marked as destroyed, the segment will not be found
+ // in the registry and can no longer be attached. When the last user
+ // detaches from the segment, it is destroyed.
+ pendingDestruction bool
+}
+
+// Precondition: Caller must hold s.mu.
+func (s *Shm) debugLocked() string {
+ return fmt.Sprintf("Shm{id: %d, key: %d, size: %d bytes, refs: %d, destroyed: %v}",
+ s.ID, s.key, s.size, s.ReadRefs(), s.pendingDestruction)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (s *Shm) MappedName(ctx context.Context) string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return fmt.Sprintf("SYSV%08d", s.key)
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (s *Shm) DeviceID() uint64 {
+ return shmDevice.DeviceID()
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (s *Shm) InodeID() uint64 {
+ // "shmid gets reported as "inode#" in /proc/pid/maps. proc-ps tools use
+ // this. Changing this will break them." -- Linux, ipc/shm.c:newseg()
+ return uint64(s.ID)
+}
+
+// DecRef overrides refs.RefCount.DecRef with a destructor.
+//
+// Precondition: Caller must not hold s.mu.
+func (s *Shm) DecRef() {
+ s.DecRefWithDestructor(s.destroy)
+}
+
+// Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm
+// segments.
+func (s *Shm) Msync(context.Context, memmap.MappableRange) error {
+ return nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.attachTime = ktime.NowFromContext(ctx)
+ if pid, ok := context.ThreadGroupIDFromContext(ctx); ok {
+ s.lastAttachDetachPID = pid
+ } else {
+ // AddMapping is called during a syscall, so ctx should always be a task
+ // context.
+ log.Warningf("Adding mapping to %s but couldn't get the current pid; not updating the last attach pid", s.debugLocked())
+ }
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ // RemoveMapping may be called during task exit, when ctx
+ // is context.Background. Gracefully handle missing clocks. Failing to
+ // update the detach time in these cases is ok, since no one can observe the
+ // omission.
+ if clock := ktime.RealtimeClockFromContext(ctx); clock != nil {
+ s.detachTime = clock.Now()
+ }
+
+ // If called from a non-task context we also won't have a threadgroup
+ // id. Silently skip updating lastAttachDetachPID in that case.
+ if pid, ok := context.ThreadGroupIDFromContext(ctx); ok {
+ s.lastAttachDetachPID = pid
+ } else {
+ log.Debugf("Couldn't obtain pid when removing mapping to %s, not updating the last detach pid.", s.debugLocked())
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (*Shm) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error {
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ var err error
+ if required.End > s.fr.Length() {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if source := optional.Intersect(memmap.MappableRange{0, s.fr.Length()}); source.Length() != 0 {
+ return []memmap.Translation{
+ {
+ Source: source,
+ File: s.mfp.MemoryFile(),
+ Offset: s.fr.Start + source.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
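+
+// For example (illustrative): for a single-page segment, a fault with
+// optional covering [0, PageSize) yields one Translation backed by the
+// memory file at fr.Start. A required range extending past fr.Length()
+// additionally yields a BusError wrapping EFAULT, returned alongside any
+// translations for the in-bounds portion.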
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (s *Shm) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// AttachOpts describes various flags passed to shmat(2).
+type AttachOpts struct {
+ Execute bool
+ Readonly bool
+ Remap bool
+}
+
+// ConfigureAttach creates an mmap configuration for the segment with the
+// requested attach options.
+//
+// Postconditions: The returned MMapOpts are valid only as long as a reference
+// continues to be held on s.
+func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.pendingDestruction && s.ReadRefs() == 0 {
+ return memmap.MMapOpts{}, syserror.EIDRM
+ }
+
+ if !s.checkPermissions(ctx, fs.PermMask{
+ Read: true,
+ Write: !opts.Readonly,
+ Execute: opts.Execute,
+ }) {
+ // "The calling process does not have the required permissions for the
+ // requested attach type, and does not have the CAP_IPC_OWNER capability
+ // in the user namespace that governs its IPC namespace." - man shmat(2)
+ return memmap.MMapOpts{}, syserror.EACCES
+ }
+ return memmap.MMapOpts{
+ Length: s.size,
+ Offset: 0,
+ Addr: addr,
+ Fixed: opts.Remap,
+ Perms: usermem.AccessType{
+ Read: true,
+ Write: !opts.Readonly,
+ Execute: opts.Execute,
+ },
+ MaxPerms: usermem.AnyAccess,
+ Mappable: s,
+ MappingIdentity: s,
+ }, nil
+}
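+
+// Illustrative only, not part of this change: a shmat(2) implementation
+// might attach a segment roughly as follows, where the segment lookup and
+// reference handling are assumed:
+//
+//	opts, err := segment.ConfigureAttach(t, addr, AttachOpts{Readonly: readOnly})
+//	if err != nil {
+//		return 0, err
+//	}
+//	actualAddr, err := t.MemoryManager().MMap(t, opts)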
+
+// EffectiveSize returns the size of the underlying shared memory segment. This
+// may be larger than the requested size at creation, due to rounding to page
+// boundaries.
+func (s *Shm) EffectiveSize() uint64 {
+ return s.effectiveSize
+}
+
+// IPCStat returns information about a shm. See shmctl(IPC_STAT).
+func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The caller must have read permission on the shared memory segment."
+ // - man shmctl(2)
+ if !s.checkPermissions(ctx, fs.PermMask{Read: true}) {
+ // "IPC_STAT or SHM_STAT is requested and shm_perm.mode does not allow
+ // read access for shmid, and the calling process does not have the
+ // CAP_IPC_OWNER capability in the user namespace that governs its IPC
+ // namespace." - man shmctl(2)
+ return nil, syserror.EACCES
+ }
+
+ var mode uint16
+ if s.pendingDestruction {
+ mode |= linux.SHM_DEST
+ }
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Use the reference count as a rudimentary count of the number of
+ // attaches. We exclude:
+ //
+ // 1. The reference the caller holds.
+ // 2. The self-reference held by s prior to destruction.
+ //
+ // Note that this may still overcount by including transient references
+ // used in concurrent calls.
+ nattach := uint64(s.ReadRefs()) - 1
+ if !s.pendingDestruction {
+ nattach--
+ }
+
+ ds := &linux.ShmidDS{
+ ShmPerm: linux.IPCPerm{
+ Key: uint32(s.key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
+ Mode: mode | uint16(s.perms.LinuxMode()),
+ Seq: 0, // IPC sequences not supported.
+ },
+ ShmSegsz: s.size,
+ ShmAtime: s.attachTime.TimeT(),
+ ShmDtime: s.detachTime.TimeT(),
+ ShmCtime: s.changeTime.TimeT(),
+ ShmCpid: s.creatorPID,
+ ShmLpid: s.lastAttachDetachPID,
+ ShmNattach: nattach,
+ }
+
+ return ds, nil
+}
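+
+// For example (illustrative): a live segment mapped by two tasks has
+// ReadRefs() == 4 (the self-reference, the IPC_STAT caller's reference, and
+// one reference per mapping), so the reported ShmNattach is 4 - 1 - 1 = 2.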
+
+// Set modifies attributes for a segment. See shmctl(IPC_SET).
+func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.checkOwnership(ctx) {
+ return syserror.EPERM
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID))
+ gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID))
+ if !uid.Ok() || !gid.Ok() {
+ return syserror.EINVAL
+ }
+
+ // User may only modify the lower 9 bits of the mode. All the other bits are
+ // always 0 for the underlying inode.
+ mode := linux.FileMode(ds.ShmPerm.Mode & 0x1ff)
+ s.perms = fs.FilePermsFromMode(mode)
+
+ s.owner.UID = uid
+ s.owner.GID = gid
+
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+func (s *Shm) destroy() {
+ s.mfp.MemoryFile().DecRef(s.fr)
+ s.registry.remove(s)
+}
+
+// MarkDestroyed marks a segment for destruction. The segment is actually
+// destroyed once it has no references. MarkDestroyed may be called multiple
+// times, and is safe to call after a segment has already been destroyed. See
+// shmctl(IPC_RMID).
+func (s *Shm) MarkDestroyed() {
+ s.registry.dissociateKey(s)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if !s.pendingDestruction {
+ s.pendingDestruction = true
+ // Drop the self-reference so destruction occurs when all
+ // external references are gone.
+ //
+ // N.B. This cannot be the final DecRef, as the caller also
+ // holds a reference.
+ s.DecRef()
+ return
+ }
+}
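+
+// Illustrative lifecycle sketch, not part of this change: a shmctl(IPC_RMID)
+// handler would typically look up the segment (taking a reference), mark it
+// destroyed, then drop its reference; destroy() runs once the last mapping
+// detaches:
+//
+//	segment.MarkDestroyed()
+//	segment.DecRef() // Final reference only if nothing is attached.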
+
+// checkOwnership verifies whether a segment may be accessed by ctx as an
+// owner. See ipc/util.c:ipcctl_pre_down_nolock() in Linux.
+//
+// Precondition: Caller must hold s.mu.
+func (s *Shm) checkOwnership(ctx context.Context) bool {
+ creds := auth.CredentialsFromContext(ctx)
+ if s.owner.UID == creds.EffectiveKUID || s.creator.UID == creds.EffectiveKUID {
+ return true
+ }
+
+ // Tasks with CAP_SYS_ADMIN may bypass ownership checks. Strangely, Linux
+ // doesn't use CAP_IPC_OWNER for this despite CAP_IPC_OWNER being documented
+ // for use to "override IPC ownership checks".
+ return creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, s.registry.userNS)
+}
+
+// checkPermissions verifies whether a segment is accessible by ctx for access
+// described by req. See ipc/util.c:ipcperms() in Linux.
+//
+// Precondition: Caller must hold s.mu.
+func (s *Shm) checkPermissions(ctx context.Context, req fs.PermMask) bool {
+ creds := auth.CredentialsFromContext(ctx)
+
+ p := s.perms.Other
+ if s.owner.UID == creds.EffectiveKUID {
+ p = s.perms.User
+ } else if creds.InGroup(s.owner.GID) {
+ p = s.perms.Group
+ }
+ if p.SupersetOf(req) {
+ return true
+ }
+
+ // Tasks with CAP_IPC_OWNER may bypass permission checks.
+ return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, s.registry.userNS)
+}
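+
+// For example (illustrative): a segment with mode 0640 grants {Read, Write}
+// to its owner, {Read} to tasks in the owning group, and nothing to anyone
+// else, unless the caller holds CAP_IPC_OWNER in the registry's user
+// namespace.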
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
new file mode 100644
index 000000000..e8cce37d0
--- /dev/null
+++ b/pkg/sentry/kernel/signal.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+)
+
+// SignalPanic is used to panic the running threads. It is a signal which
+// cannot be used by the application: it must be caught and ignored by the
+// runtime (in order to catch possible races).
+const SignalPanic = linux.SIGUSR2
+
+// sendExternalSignal is called when an asynchronous signal is sent to the
+// sentry ("in sentry context"). On some platforms, it may also be called when
+// an asynchronous signal is sent to sandboxed application threads ("in
+// application context").
+//
+// context is used only for debugging to differentiate these cases.
+//
+// Preconditions: Kernel must have an init process.
+func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
+ switch linux.Signal(info.Signo) {
+ case linux.SIGURG:
+ // Sent by the Go 1.14+ runtime for asynchronous goroutine preemption.
+
+ case platform.SignalInterrupt:
+ // Assume that a call to platform.Context.Interrupt() misfired.
+
+ case SignalPanic:
+ // SignalPanic is also specially handled in sentry setup to ensure that
+ // it causes a panic even after tasks exit, but SignalPanic may also
+ // be sent here if it is received while in app context.
+ panic("Signal-induced panic")
+
+ default:
+ log.Infof("Received external signal %d in %s context", info.Signo, context)
+ if k.globalInit == nil {
+ panic(fmt.Sprintf("Received external signal %d before init created", info.Signo))
+ }
+ k.globalInit.SendSignal(info)
+ }
+}
+
+// SignalInfoPriv returns a SignalInfo equivalent to Linux's SEND_SIG_PRIV.
+func SignalInfoPriv(sig linux.Signal) *arch.SignalInfo {
+ return &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoKernel,
+ }
+}
+
+// SignalInfoNoInfo returns a SignalInfo equivalent to Linux's SEND_SIG_NOINFO.
+func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo {
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
+ info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ return info
+}
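+
+// Illustrative only, a sketch rather than the real delivery path: a
+// kill(2)-style sender might use this as:
+//
+//	info := SignalInfoNoInfo(linux.SIGTERM, sender, receiver)
+//	return receiver.tg.SendSignal(info)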
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
new file mode 100644
index 000000000..768fda220
--- /dev/null
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// SignalHandlers holds information about signal actions.
+//
+// +stateify savable
+type SignalHandlers struct {
+ // mu protects actions, as well as the signal state of all tasks and thread
+ // groups using this SignalHandlers object. (See comment on
+ // ThreadGroup.signalHandlers.)
+ mu sync.Mutex `state:"nosave"`
+
+ // actions maps each signal to the action to be taken upon receiving it.
+ actions map[linux.Signal]arch.SignalAct
+}
+
+// NewSignalHandlers returns a new SignalHandlers specifying all default
+// actions.
+func NewSignalHandlers() *SignalHandlers {
+ return &SignalHandlers{
+ actions: make(map[linux.Signal]arch.SignalAct),
+ }
+}
+
+// Fork returns a copy of sh for a new thread group.
+func (sh *SignalHandlers) Fork() *SignalHandlers {
+ sh2 := NewSignalHandlers()
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ for sig, act := range sh.actions {
+ sh2.actions[sig] = act
+ }
+ return sh2
+}
+
+// CopyForExec returns a copy of sh for a thread group that is undergoing an
+// execve. (See comments in Task.finishExec.)
+func (sh *SignalHandlers) CopyForExec() *SignalHandlers {
+ sh2 := NewSignalHandlers()
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ for sig, act := range sh.actions {
+ if act.Handler == arch.SignalActIgnore {
+ sh2.actions[sig] = arch.SignalAct{
+ Handler: arch.SignalActIgnore,
+ }
+ }
+ }
+ return sh2
+}
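+
+// For example (illustrative): if SIGINT was ignored and SIGTERM had a custom
+// handler before execve, the post-exec image starts with SIGINT still
+// ignored and SIGTERM reset to its default action, matching execve(2)
+// semantics.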
+
+// IsIgnored returns true if the signal is ignored.
+func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool {
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ sa, ok := sh.actions[sig]
+ return ok && sa.Handler == arch.SignalActIgnore
+}
+
+// dequeueAction returns the SignalAct that should be used to handle sig.
+//
+// Preconditions: sh.mu must be locked.
+func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct {
+ act := sh.actions[sig]
+ if act.IsResetHandler() {
+ delete(sh.actions, sig)
+ }
+ return act
+}
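+
+// For example (illustrative): an action installed with SA_RESETHAND
+// (IsResetHandler) is consumed here, so the first delivery of the signal
+// runs the handler and any subsequent delivery falls back to the default
+// action.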
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
new file mode 100644
index 000000000..3eb78e91b
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
new file mode 100644
index 000000000..8243bb93e
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -0,0 +1,139 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package signalfd provides an implementation of signal file descriptors.
+package signalfd
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalOperations represents a file with signalfd semantics.
+//
+// +stateify savable
+type SignalOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // target is the original task target.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+// New creates a new signalfd object with the supplied mask.
+func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // No task context? Not valid.
+ return nil, syserror.EINVAL
+ }
+ // name matches fs/signalfd.c:signalfd4.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{
+ target: t,
+ mask: mask,
+ }), nil
+}
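+
+// Illustrative only, with flag handling omitted and the return values
+// simplified: a signalfd4(2) implementation might create and install the
+// file roughly as follows:
+//
+//	file, err := signalfd.New(t, mask)
+//	if err != nil {
+//		return 0, err
+//	}
+//	defer file.DecRef()
+//	fd, err := t.NewFDFrom(0, file, kernel.FDFlags{})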
+
+// Release implements fs.FileOperations.Release.
+func (s *SignalOperations) Release() {}
+
+// Mask returns the signal mask.
+func (s *SignalOperations) Mask() linux.SignalSet {
+ s.mu.Lock()
+ mask := s.mask
+ s.mu.Unlock()
+ return mask
+}
+
+// SetMask sets the signal mask.
+func (s *SignalOperations) SetMask(mask linux.SignalSet) {
+ s.mu.Lock()
+ s.mask = mask
+ s.mu.Unlock()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := s.target.Sigtimedwait(s.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
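+ // A signalfd_siginfo record has a fixed size of 128 bytes, hence the
+ // buffer size below.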
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 && s.target.PendingSignals()&s.Mask() != 0 {
+ return waiter.EventIn // Pending signals.
+ }
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ // Register for the signal set; ignore the passed events.
+ s.target.SignalRegister(entry, waiter.EventMask(s.Mask()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SignalOperations) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ s.target.SignalUnregister(entry)
+}
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
new file mode 100644
index 000000000..413111faf
--- /dev/null
+++ b/pkg/sentry/kernel/syscalls.go
@@ -0,0 +1,364 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// maxSyscallNum is the highest supported syscall number.
+//
+// The types below create fast lookup slices for all syscalls. This maximum
+// serves as a sanity check that we don't allocate huge slices for a very
+// large syscall number. This is checked during registration.
+const maxSyscallNum = 2000
+
+// SyscallSupportLevel is a syscall support level.
+type SyscallSupportLevel int
+
+// String returns a human-readable representation of the support level.
+func (l SyscallSupportLevel) String() string {
+ switch l {
+ case SupportUnimplemented:
+ return "Unimplemented"
+ case SupportPartial:
+ return "Partial Support"
+ case SupportFull:
+ return "Full Support"
+ default:
+ return "Undocumented"
+ }
+}
+
+const (
+ // SupportUndocumented indicates the syscall is not documented yet.
+ SupportUndocumented = iota
+
+ // SupportUnimplemented indicates the syscall is unimplemented.
+ SupportUnimplemented
+
+ // SupportPartial indicates the syscall is partially supported.
+ SupportPartial
+
+ // SupportFull indicates the syscall is fully supported.
+ SupportFull
+)
+
+// Syscall includes the syscall implementation and compatibility information.
+type Syscall struct {
+ // Name is the syscall name.
+ Name string
+ // Fn is the implementation of the syscall.
+ Fn SyscallFn
+ // SupportLevel is the level of support implemented in gVisor.
+ SupportLevel SyscallSupportLevel
+ // Note describes the compatibility of the syscall.
+ Note string
+ // URLs is a set of URLs to any relevant bugs or issues.
+ URLs []string
+}
+
+// SyscallFn is a syscall implementation.
+type SyscallFn func(t *Task, args arch.SyscallArguments) (uintptr, *SyscallControl, error)
+
+// MissingFn is a function called when a syscall implementation is missing.
+type MissingFn func(t *Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error)
+
+// Possible flags for SyscallFlagsTable.enable.
+const (
+ // syscallPresent indicates that this is not a missing syscall.
+ //
+ // This flag is used internally in SyscallFlagsTable.
+ syscallPresent = 1 << iota
+
+ // StraceEnableLog enables syscall log tracing.
+ StraceEnableLog
+
+ // StraceEnableEvent enables syscall event tracing.
+ StraceEnableEvent
+
+ // ExternalBeforeEnable enables the external hook before syscall execution.
+ ExternalBeforeEnable
+
+ // ExternalAfterEnable enables the external hook after syscall execution.
+ ExternalAfterEnable
+)
+
+// StraceEnableBits combines both strace log and event flags.
+const StraceEnableBits = StraceEnableLog | StraceEnableEvent
+
+// SyscallFlagsTable manages a set of enable/disable bit fields on a per-syscall
+// basis.
+type SyscallFlagsTable struct {
+ // mu protects writes to the fields below.
+ //
+ // Atomic loads are always allowed. Atomic stores are allowed only
+ // while mu is held.
+ mu sync.Mutex
+
+ // enable contains the enable bits for each syscall.
+ //
+ // Missing syscalls have the same value in enable as missingEnable to
+ // avoid an extra branch in Word.
+ enable []uint32
+
+ // missingEnable contains the enable bits for missing syscalls.
+ missingEnable uint32
+}
+
+// init initializes the struct, marking all syscalls in table as present.
+//
+// max is the largest syscall number in table.
+func (e *SyscallFlagsTable) init(table map[uintptr]Syscall, max uintptr) {
+ e.enable = make([]uint32, max+1)
+ for num := range table {
+ e.enable[num] = syscallPresent
+ }
+}
+
+// Word returns the enable bitfield for sysno.
+func (e *SyscallFlagsTable) Word(sysno uintptr) uint32 {
+ if sysno < uintptr(len(e.enable)) {
+ return atomic.LoadUint32(&e.enable[sysno])
+ }
+
+ return atomic.LoadUint32(&e.missingEnable)
+}
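+
+// Illustrative only: callers are expected to test feature bits in the
+// returned word, e.g.:
+//
+//	if e.Word(sysno)&StraceEnableBits != 0 {
+//		// The syscall should be traced.
+//	}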
+
+// Enable sets the given enable bit for all syscalls based on s.
+//
+// Syscalls missing from s are disabled.
+//
+// Syscalls missing from the initial table passed to Init cannot be added as
+// individual syscalls. If present in s they will be ignored.
+//
+// Callers to Word may see either the old or new value while this function
+// is executing.
+func (e *SyscallFlagsTable) Enable(bit uint32, s map[uintptr]bool, missingEnable bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ missingVal := atomic.LoadUint32(&e.missingEnable)
+ if missingEnable {
+ missingVal |= bit
+ } else {
+ missingVal &^= bit
+ }
+ atomic.StoreUint32(&e.missingEnable, missingVal)
+
+ for num := range e.enable {
+ val := atomic.LoadUint32(&e.enable[num])
+ if !bits.IsOn32(val, syscallPresent) {
+ // Missing.
+ atomic.StoreUint32(&e.enable[num], missingVal)
+ continue
+ }
+
+ if s[uintptr(num)] {
+ val |= bit
+ } else {
+ val &^= bit
+ }
+ atomic.StoreUint32(&e.enable[num], val)
+ }
+}
+
+// EnableAll sets the given enable bit for all syscalls, present and missing.
+func (e *SyscallFlagsTable) EnableAll(bit uint32) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ missingVal := atomic.LoadUint32(&e.missingEnable)
+ missingVal |= bit
+ atomic.StoreUint32(&e.missingEnable, missingVal)
+
+ for num := range e.enable {
+ val := atomic.LoadUint32(&e.enable[num])
+ if !bits.IsOn32(val, syscallPresent) {
+ // Missing.
+ atomic.StoreUint32(&e.enable[num], missingVal)
+ continue
+ }
+
+ val |= bit
+ atomic.StoreUint32(&e.enable[num], val)
+ }
+}
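+
+// Illustrative only: enabling strace logging for read(2) and write(2) on
+// AMD64 (syscall numbers 0 and 1), while leaving missing syscalls untraced,
+// might look like:
+//
+//	table.FeatureEnable.Enable(StraceEnableLog, map[uintptr]bool{0: true, 1: true}, false)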
+
+// Stracer traces syscall execution.
+type Stracer interface {
+ // SyscallEnter is called on syscall entry.
+ //
+ // The returned private data is passed to SyscallExit.
+ SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{}
+
+ // SyscallExit is called on syscall exit.
+ SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error)
+}
+
+// SyscallTable is a lookup table of system calls.
+//
+// Note that a SyscallTable is not savable directly. Instead, it is saved as
+// an OS/Arch pair, and the table is looked up again on restore.
+type SyscallTable struct {
+ // OS is the operating system that this syscall table implements.
+ OS abi.OS
+
+ // Arch is the architecture that this syscall table targets.
+ Arch arch.Arch
+
+ // The OS version that this syscall table implements.
+ Version Version
+
+ // AuditNumber is a numeric constant that represents the syscall table. If
+ // non-zero, AuditNumber must be one of the AUDIT_ARCH_* values defined by
+ // linux/audit.h.
+ AuditNumber uint32
+
+ // Table is the collection of functions.
+ Table map[uintptr]Syscall
+
+ // lookup is a fixed-size array that holds the syscalls (indexed by
+ // their numbers). It is used for fast lookups.
+ lookup []SyscallFn
+
+ // Emulate is a collection of instruction addresses to emulate. The
+ // keys are addresses, and the values are system call numbers.
+ Emulate map[usermem.Addr]uintptr
+
+ // The function to call in case of a missing system call.
+ Missing MissingFn
+
+ // Stracer traces this syscall table.
+ Stracer Stracer
+
+ // External is used to handle an external callback.
+ External func(*Kernel)
+
+ // ExternalFilterBefore is called before External, before the syscall is
+ // executed.
+ // External is not called if it returns false.
+ ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool
+
+ // ExternalFilterAfter is called before External, after the syscall is
+ // executed.
+ // External is not called if it returns false.
+ ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool
+
+ // FeatureEnable stores the strace and one-shot enable bits.
+ FeatureEnable SyscallFlagsTable
+}
+
+// MaxSysno returns the largest system call number.
+func (s *SyscallTable) MaxSysno() (max uintptr) {
+ for num := range s.Table {
+ if num > max {
+ max = num
+ }
+ }
+ return max
+}
+
+// allSyscallTables contains all known tables.
+var allSyscallTables []*SyscallTable
+
+// SyscallTables returns a read-only slice of registered SyscallTables.
+func SyscallTables() []*SyscallTable {
+ return allSyscallTables
+}
+
+// LookupSyscallTable returns the SyscallTable for the given OS/Arch combination.
+func LookupSyscallTable(os abi.OS, a arch.Arch) (*SyscallTable, bool) {
+ for _, s := range allSyscallTables {
+ if s.OS == os && s.Arch == a {
+ return s, true
+ }
+ }
+ return nil, false
+}
+
+// RegisterSyscallTable registers a new syscall table for use by a Kernel.
+func RegisterSyscallTable(s *SyscallTable) {
+ if max := s.MaxSysno(); max > maxSyscallNum {
+ panic(fmt.Sprintf("SyscallTable %+v contains too large syscall number %d", s, max))
+ }
+ if _, ok := LookupSyscallTable(s.OS, s.Arch); ok {
+ panic(fmt.Sprintf("Duplicate SyscallTable registered for OS %v Arch %v", s.OS, s.Arch))
+ }
+ allSyscallTables = append(allSyscallTables, s)
+ s.Init()
+}
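+
+// Illustrative only (sysRead is a placeholder): a minimal registration looks
+// like:
+//
+//	RegisterSyscallTable(&SyscallTable{
+//		OS:    abi.Linux,
+//		Arch:  arch.AMD64,
+//		Table: map[uintptr]Syscall{0: {Name: "read", Fn: sysRead}},
+//	})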
+
+// Init initializes the system call table.
+//
+// This should normally be called only during registration.
+func (s *SyscallTable) Init() {
+ if s.Table == nil {
+ // Ensure a non-nil syscall table.
+ s.Table = make(map[uintptr]Syscall)
+ }
+ if s.Emulate == nil {
+ // Ensure non-nil emulate table.
+ s.Emulate = make(map[usermem.Addr]uintptr)
+ }
+
+ max := s.MaxSysno() // Checked during RegisterSyscallTable.
+
+ // Initialize the fast-lookup table.
+ s.lookup = make([]SyscallFn, max+1)
+ for num, sc := range s.Table {
+ s.lookup[num] = sc.Fn
+ }
+
+ // Initialize all features.
+ s.FeatureEnable.init(s.Table, max)
+}
+
+// Lookup returns the syscall implementation, if one exists.
+func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
+ if sysno < uintptr(len(s.lookup)) {
+ return s.lookup[sysno]
+ }
+
+ return nil
+}
+
+// LookupName looks up a syscall name.
+func (s *SyscallTable) LookupName(sysno uintptr) string {
+ if sc, ok := s.Table[sysno]; ok {
+ return sc.Name
+ }
+ return fmt.Sprintf("sys_%d", sysno) // Unlikely.
+}
+
+// LookupEmulate looks up an emulation syscall number.
+func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
+ sysno, ok := s.Emulate[addr]
+ return sysno, ok
+}
+
+// mapLookup is similar to Lookup, except that it only uses the syscall table,
+// that is, it skips the fast lookup array. This is available for benchmarking.
+func (s *SyscallTable) mapLookup(sysno uintptr) SyscallFn {
+ if sc, ok := s.Table[sysno]; ok {
+ return sc.Fn
+ }
+ return nil
+}
diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go
new file mode 100644
index 000000000..90f890495
--- /dev/null
+++ b/pkg/sentry/kernel/syscalls_state.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// syscallTableInfo is used to reload the SyscallTable.
+//
+// +stateify savable
+type syscallTableInfo struct {
+ OS abi.OS
+ Arch arch.Arch
+}
+
+// saveSt saves the SyscallTable.
+func (tc *TaskContext) saveSt() syscallTableInfo {
+ return syscallTableInfo{
+ OS: tc.st.OS,
+ Arch: tc.st.Arch,
+ }
+}
+
+// loadSt loads the SyscallTable.
+func (tc *TaskContext) loadSt(sti syscallTableInfo) {
+ st, ok := LookupSyscallTable(sti.OS, sti.Arch)
+ if !ok {
+ panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch))
+ }
+ tc.st = st // Save the table reference.
+}
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
new file mode 100644
index 000000000..4607cde2f
--- /dev/null
+++ b/pkg/sentry/kernel/syslog.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// syslog represents a sentry-global kernel log.
+//
+// Currently, it contains only fun messages for a dmesg easter egg.
+//
+// +stateify savable
+type syslog struct {
+ // mu protects the below.
+ mu sync.Mutex `state:"nosave"`
+
+ // msg is the syslog message buffer. It is lazily initialized.
+ msg []byte
+}
+
+// Log returns a copy of the syslog.
+func (s *syslog) Log() []byte {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.msg != nil {
+ // Already initialized, just return a copy.
+ o := make([]byte, len(s.msg))
+ copy(o, s.msg)
+ return o
+ }
+
+ // Not initialized, create message.
+ allMessages := []string{
+ "Synthesizing system calls...",
+ "Mounting deweydecimalfs...",
+ "Moving files to filing cabinet...",
+ "Digging up root...",
+ "Constructing home...",
+ "Segmenting fault lines...",
+ "Creating bureaucratic processes...",
+ "Searching for needles in stacks...",
+ "Preparing for the zombie uprising...",
+ "Feeding the init monster...",
+ "Creating cloned children...",
+ "Daemonizing children...",
+ "Waiting for children...",
+ "Gathering forks...",
+ "Committing treasure map to memory...",
+ "Reading process obituaries...",
+ "Searching for socket adapter...",
+ "Creating process schedule...",
+ "Generating random numbers by fair dice roll...",
+ "Rewriting operating system in Javascript...",
+ "Reticulating splines...",
+ "Consulting tar man page...",
+ "Forking spaghetti code...",
+ "Checking naughty and nice process list...",
+ "Checking naughty and nice process list...", // Check it up to twice.
+ "Granting licence to kill(2)...", // British spelling for British movie.
+ "Letting the watchdogs out...",
+ }
+
+ selectMessage := func() string {
+ i := rand.Intn(len(allMessages))
+ m := allMessages[i]
+
+ // Delete the selected message.
+ allMessages[i] = allMessages[len(allMessages)-1]
+ allMessages = allMessages[:len(allMessages)-1]
+
+ return m
+ }
+
+ const format = "<6>[%11.6f] %s\n"
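+ // For example, fmt.Sprintf(format, 0.1, "Reticulating splines...") renders
+ // as "<6>[   0.100000] Reticulating splines...\n".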
+
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, 0.0, "Starting gVisor..."))...)
+
+ time := 0.1
+ for i := 0; i < 10; i++ {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...)
+ }
+
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...)
+
+ // Return a copy.
+ o := make([]byte, len(s.msg))
+ copy(o, s.msg)
+ return o
+}
diff --git a/pkg/sentry/kernel/table_test.go b/pkg/sentry/kernel/table_test.go
new file mode 100644
index 000000000..32cf47e05
--- /dev/null
+++ b/pkg/sentry/kernel/table_test.go
@@ -0,0 +1,110 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+const (
+ maxTestSyscall = 1000
+)
+
+func createSyscallTable() *SyscallTable {
+ m := make(map[uintptr]Syscall)
+ for i := uintptr(0); i <= maxTestSyscall; i++ {
+ j := i
+ m[i] = Syscall{
+ Fn: func(*Task, arch.SyscallArguments) (uintptr, *SyscallControl, error) {
+ return j, nil, nil
+ },
+ }
+ }
+
+ s := &SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.AMD64,
+ Table: m,
+ }
+
+ RegisterSyscallTable(s)
+ return s
+}
+
+func TestTable(t *testing.T) {
+ table := createSyscallTable()
+ defer func() {
+ // Cleanup registered tables to keep tests separate.
+ allSyscallTables = []*SyscallTable{}
+ }()
+
+ // Go through all functions and check that they return the right value.
+ for i := uintptr(0); i < maxTestSyscall; i++ {
+ fn := table.Lookup(i)
+ if fn == nil {
+ t.Errorf("Syscall %v is set to nil", i)
+ continue
+ }
+
+ v, _, _ := fn(nil, arch.SyscallArguments{})
+ if v != i {
+ t.Errorf("Wrong return value for syscall %v: expected %v, got %v", i, i, v)
+ }
+ }
+
+ // Check that values outside the range return nil.
+ for i := uintptr(maxTestSyscall + 1); i < maxTestSyscall+100; i++ {
+ fn := table.Lookup(i)
+ if fn != nil {
+ t.Errorf("Syscall %v is not nil: %v", i, fn)
+ continue
+ }
+ }
+}
+
+func BenchmarkTableLookup(b *testing.B) {
+ table := createSyscallTable()
+
+ b.ResetTimer()
+
+ j := uintptr(0)
+ for i := 0; i < b.N; i++ {
+ table.Lookup(j)
+ j = (j + 1) % 310
+ }
+
+ b.StopTimer()
+ // Cleanup registered tables to keep tests separate.
+ allSyscallTables = []*SyscallTable{}
+}
+
+func BenchmarkTableMapLookup(b *testing.B) {
+ table := createSyscallTable()
+
+ b.ResetTimer()
+
+ j := uintptr(0)
+ for i := 0; i < b.N; i++ {
+ table.mapLookup(j)
+ j = (j + 1) % 310
+ }
+
+ b.StopTimer()
+ // Cleanup registered tables to keep tests separate.
+ allSyscallTables = []*SyscallTable{}
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
new file mode 100644
index 000000000..f48247c94
--- /dev/null
+++ b/pkg/sentry/kernel/task.go
@@ -0,0 +1,886 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ gocontext "context"
+ "runtime/trace"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Task represents a thread of execution in the untrusted app. It
+// includes registers and any thread-specific state that you would
+// normally expect.
+//
+// Each task is associated with a goroutine, called the task goroutine, that
+// executes code (application code, system calls, etc.) on behalf of that task.
+// See Task.run (task_run.go).
+//
+// All fields that are "owned by the task goroutine" can only be mutated by the
+// task goroutine while it is running. The task goroutine does not require
+// synchronization to read these fields, although it still requires
+// synchronization as described for those fields to mutate them.
+//
+// All fields that are "exclusive to the task goroutine" can only be accessed
+// by the task goroutine while it is running. The task goroutine does not
+// require synchronization to read or write these fields.
+//
+// +stateify savable
+type Task struct {
+ taskNode
+
+ // runState is what the task goroutine is executing if it is not stopped.
+ // If runState is nil, the task goroutine should exit or has exited.
+ // runState is exclusive to the task goroutine.
+ runState taskRunState
+
+ // haveSyscallReturn is true if tc.Arch().Return() represents a value
+ // returned by a syscall (or set by ptrace after a syscall).
+ //
+ // haveSyscallReturn is exclusive to the task goroutine.
+ haveSyscallReturn bool
+
+ // interruptChan is notified whenever the task goroutine is interrupted
+ // (usually by a pending signal). interruptChan is effectively a condition
+ // variable that can be used in select statements.
+ //
+ // interruptChan is not saved; because saving interrupts all tasks,
+ // interruptChan is always notified after restore (see Task.run).
+ interruptChan chan struct{} `state:"nosave"`
+
+ // gosched contains the current scheduling state of the task goroutine.
+ //
+ // gosched is protected by goschedSeq. gosched is owned by the task
+ // goroutine.
+ goschedSeq sync.SeqCount `state:"nosave"`
+ gosched TaskGoroutineSchedInfo
+
+ // yieldCount is the number of times the task goroutine has called
+ // Task.InterruptibleSleepStart, Task.UninterruptibleSleepStart, or
+ // Task.Yield(), voluntarily ceasing execution.
+ //
+ // yieldCount is accessed using atomic memory operations. yieldCount is
+ // owned by the task goroutine.
+ yieldCount uint64
+
+ // pendingSignals is the set of pending signals that may be handled only by
+ // this task.
+ //
+ // pendingSignals is protected by (taskNode.)tg.signalHandlers.mu
+ // (hereafter "the signal mutex"); see comment on
+ // ThreadGroup.signalHandlers.
+ pendingSignals pendingSignals
+
+ // signalMask is the set of signals whose delivery is currently blocked.
+ //
+ // signalMask is accessed using atomic memory operations, and is protected
+ // by the signal mutex (such that reading signalMask is safe if either the
+ // signal mutex is locked or if atomic memory operations are used, while
+ // writing signalMask requires both). signalMask is owned by the task
+ // goroutine.
+ signalMask linux.SignalSet
+
+ // If the task goroutine is currently executing Task.sigtimedwait,
+ // realSignalMask is the previous value of signalMask, which has temporarily
+ // been replaced by Task.sigtimedwait. Otherwise, realSignalMask is 0.
+ //
+ // realSignalMask is exclusive to the task goroutine.
+ realSignalMask linux.SignalSet
+
+ // If haveSavedSignalMask is true, savedSignalMask is the signal mask that
+ // should be applied after the task has either delivered one signal to a
+ // user handler or is about to resume execution in the untrusted
+ // application.
+ //
+ // Both haveSavedSignalMask and savedSignalMask are exclusive to the task
+ // goroutine.
+ haveSavedSignalMask bool
+ savedSignalMask linux.SignalSet
+
+ // signalStack is the alternate signal stack used by signal handlers for
+ // which the SA_ONSTACK flag is set.
+ //
+ // signalStack is exclusive to the task goroutine.
+ signalStack arch.SignalStack
+
+ // signalQueue is a set of registered waiters for signal-related events.
+ //
+ // signalQueue is protected by the signal mutex. Note that the task does
+ // not implement all queue methods, specifically the readiness checks.
+ // The task only broadcasts a notification on signal delivery.
+ signalQueue waiter.Queue `state:"zerovalue"`
+
+ // If groupStopPending is true, the task should participate in a group
+ // stop in the interrupt path.
+ //
+ // groupStopPending is analogous to JOBCTL_STOP_PENDING in Linux.
+ //
+ // groupStopPending is protected by the signal mutex.
+ groupStopPending bool
+
+ // If groupStopAcknowledged is true, the task has already acknowledged that
+ // it is entering the most recent group stop that has been initiated on its
+ // thread group.
+ //
+ // groupStopAcknowledged is analogous to !JOBCTL_STOP_CONSUME in Linux.
+ //
+ // groupStopAcknowledged is protected by the signal mutex.
+ groupStopAcknowledged bool
+
+ // If trapStopPending is true, the task goroutine should enter a
+ // PTRACE_INTERRUPT-induced stop from the interrupt path.
+ //
+ // trapStopPending is analogous to JOBCTL_TRAP_STOP in Linux, except that
+ // Linux also sets JOBCTL_TRAP_STOP when a ptraced task detects
+ // JOBCTL_STOP_PENDING.
+ //
+ // trapStopPending is protected by the signal mutex.
+ trapStopPending bool
+
+ // If trapNotifyPending is true, this task is PTRACE_SEIZEd, and a group
+ // stop has begun or ended since the last time the task entered a
+ // ptrace-stop from the group-stop path.
+ //
+ // trapNotifyPending is analogous to JOBCTL_TRAP_NOTIFY in Linux.
+ //
+ // trapNotifyPending is protected by the signal mutex.
+ trapNotifyPending bool
+
+ // If stop is not nil, it is the internally-initiated condition that
+ // currently prevents the task goroutine from running.
+ //
+ // stop is protected by the signal mutex.
+ stop TaskStop
+
+ // stopCount is the number of active external stops (calls to
+ // Task.BeginExternalStop that have not been paired with a call to
+ // Task.EndExternalStop), plus 1 if stop is not nil. Hence stopCount is
+ // non-zero if the task goroutine should stop.
+ //
+ // Mutating stopCount requires both locking the signal mutex and using
+ // atomic memory operations. Reading stopCount requires either locking the
+ // signal mutex or using atomic memory operations. This allows Task.doStop
+ // to require only a single atomic read in the common case where stopCount
+ // is 0.
+ //
+ // stopCount is not saved, because external stops cannot be retained across
+ // a save/restore cycle. (Suppose a sentryctl command issues an external
+ // stop; after a save/restore cycle, the restored sentry has no knowledge
+ // of the pre-save sentryctl command, and the stopped task would remain
+ // stopped forever.)
+ stopCount int32 `state:"nosave"`
+
+ // endStopCond is signaled when stopCount transitions to 0. The combination
+ // of stopCount and endStopCond effectively forms a sync.WaitGroup, but
+ // WaitGroup provides no way to read its counter value.
+ //
+ // Invariant: endStopCond.L is the signal mutex. (This is not racy because
+ // sync.Cond.Wait is the only user of sync.Cond.L; only the task goroutine
+ // calls sync.Cond.Wait; and only the task goroutine can change the
+ // identity of the signal mutex, in Task.finishExec.)
+ endStopCond sync.Cond `state:"nosave"`
+
+ // exitStatus is the task's exit status.
+ //
+ // exitStatus is protected by the signal mutex.
+ exitStatus ExitStatus
+
+ // syscallRestartBlock represents a custom restart function to run in
+ // restart_syscall(2) to resume an interrupted syscall.
+ //
+ // syscallRestartBlock is exclusive to the task goroutine.
+ syscallRestartBlock SyscallRestartBlock
+
+ // p provides the mechanism by which the task runs code in userspace. The p
+ // interface object is immutable.
+ p platform.Context `state:"nosave"`
+
+ // k is the Kernel that this task belongs to. The k pointer is immutable.
+ k *Kernel
+
+ // containerID has no equivalent in Linux; it's used by runsc to track all
+ // tasks that belong to a given container since cgroups aren't implemented.
+ // It's inherited by the children, is immutable, and may be empty.
+ //
+ // NOTE: cgroups can be used to track this when implemented.
+ containerID string
+
+ // mu protects some of the following fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // tc holds task data provided by the ELF loader.
+ //
+ // tc is protected by mu, and is owned by the task goroutine.
+ tc TaskContext
+
+ // fsContext is the task's filesystem context.
+ //
+ // fsContext is protected by mu, and is owned by the task goroutine.
+ fsContext *FSContext
+
+ // fdTable is the task's file descriptor table.
+ //
+ // fdTable is protected by mu, and is owned by the task goroutine.
+ fdTable *FDTable
+
+ // If vforkParent is not nil, it is the task that created this task with
+ // vfork() or clone(CLONE_VFORK), and should have its vforkStop ended when
+ // this TaskContext is released.
+ //
+ // vforkParent is protected by the TaskSet mutex.
+ vforkParent *Task
+
+ // exitState is the task's progress through the exit path.
+ //
+ // exitState is protected by the TaskSet mutex. exitState is owned by the
+ // task goroutine.
+ exitState TaskExitState
+
+ // exitTracerNotified is true if the exit path has either signaled the
+ // task's tracer to indicate the exit, or determined that no such signal is
+ // needed. exitTracerNotified can only be true if exitState is
+ // TaskExitZombie or TaskExitDead.
+ //
+ // exitTracerNotified is protected by the TaskSet mutex.
+ exitTracerNotified bool
+
+ // exitTracerAcked is true if exitTracerNotified is true and either the
+ // task's tracer has acknowledged the exit notification, or the exit path
+ // has determined that no such notification is needed.
+ //
+ // exitTracerAcked is protected by the TaskSet mutex.
+ exitTracerAcked bool
+
+ // exitParentNotified is true if the exit path has either signaled the
+ // task's parent to indicate the exit, or determined that no such signal is
+ // needed. exitParentNotified can only be true if exitState is
+ // TaskExitZombie or TaskExitDead.
+ //
+ // exitParentNotified is protected by the TaskSet mutex.
+ exitParentNotified bool
+
+ // exitParentAcked is true if exitParentNotified is true and either the
+ // task's parent has acknowledged the exit notification, or the exit path
+ // has determined that no such acknowledgment is needed.
+ //
+ // exitParentAcked is protected by the TaskSet mutex.
+ exitParentAcked bool
+
+ // goroutineStopped is a WaitGroup whose counter value is 1 when the task
+ // goroutine is running and 0 when the task goroutine is stopped or has
+ // exited.
+ goroutineStopped sync.WaitGroup `state:"nosave"`
+
+ // ptraceTracer is the task that is ptrace-attached to this one. If
+ // ptraceTracer is nil, this task is not being traced. Note that due to
+ // atomic.Value limitations (atomic.Value.Store(nil) panics), a nil
+ // ptraceTracer is always represented as a typed nil (i.e. (*Task)(nil)).
+ //
+ // ptraceTracer is protected by the TaskSet mutex, and accessed with atomic
+ // operations. This allows paths that wouldn't otherwise lock the TaskSet
+ // mutex, notably the syscall path, to check if ptraceTracer is nil without
+ // additional synchronization.
+ ptraceTracer atomic.Value `state:".(*Task)"`
+
+ // ptraceTracees is the set of tasks that this task is ptrace-attached to.
+ //
+ // ptraceTracees is protected by the TaskSet mutex.
+ ptraceTracees map[*Task]struct{}
+
+ // ptraceSeized is true if ptraceTracer attached to this task with
+ // PTRACE_SEIZE.
+ //
+ // ptraceSeized is protected by the TaskSet mutex.
+ ptraceSeized bool
+
+ // ptraceOpts contains ptrace options explicitly set by the tracer. If
+ // ptraceTracer is nil, ptraceOpts is expected to be the zero value.
+ //
+ // ptraceOpts is protected by the TaskSet mutex.
+ ptraceOpts ptraceOptions
+
+ // ptraceSyscallMode controls ptrace behavior around syscall entry and
+ // exit.
+ //
+ // ptraceSyscallMode is protected by the TaskSet mutex.
+ ptraceSyscallMode ptraceSyscallMode
+
+ // If ptraceSinglestep is true, the next time the task executes application
+ // code, single-stepping should be enabled. ptraceSinglestep is stored
+ // independently of the architecture-specific trap flag because tracer
+ // detaching (which can happen concurrently with the tracee's execution if
+ // the tracer exits) must disable single-stepping, and the task's
+ // architectural state is implicitly exclusive to the task goroutine (no
+ // synchronization occurs before passing registers to SwitchToApp).
+ //
+ // ptraceSinglestep is analogous to Linux's TIF_SINGLESTEP.
+ //
+ // ptraceSinglestep is protected by the TaskSet mutex.
+ ptraceSinglestep bool
+
+ // If t is ptrace-stopped, ptraceCode is a ptrace-defined value set at the
+ // time that t entered the ptrace stop, reset to 0 when the tracer
+ // acknowledges the stop with a wait*() syscall. Otherwise, it is the
+ // signal number passed to the ptrace operation that ended the last ptrace
+ // stop on this task. In the latter case, the effect of ptraceCode depends
+ // on the nature of the ptrace stop; signal-delivery-stop uses it to
+ // conditionally override ptraceSiginfo, syscall-entry/exit-stops send the
+ // signal to the task after leaving the stop, and PTRACE_EVENT stops and
+ // traced group stops ignore it entirely.
+ //
+ // Linux contextually stores the equivalent of ptraceCode in
+ // task_struct::exit_code.
+ //
+ // ptraceCode is protected by the TaskSet mutex.
+ ptraceCode int32
+
+ // ptraceSiginfo is the value returned to the tracer by
+ // ptrace(PTRACE_GETSIGINFO) and modified by ptrace(PTRACE_SETSIGINFO).
+ // (Despite the name, PTRACE_PEEKSIGINFO is completely unrelated.)
+ // ptraceSiginfo is nil if the task is in a ptraced group-stop (this is
+ // required for PTRACE_GETSIGINFO to return EINVAL during such stops, which
+ // is in turn required to distinguish group stops from other ptrace stops,
+ // per subsection "Group-stop" in ptrace(2)).
+ //
+ // ptraceSiginfo is analogous to Linux's task_struct::last_siginfo.
+ //
+ // ptraceSiginfo is protected by the TaskSet mutex.
+ ptraceSiginfo *arch.SignalInfo
+
+ // ptraceEventMsg is the value set by PTRACE_EVENT stops and returned to
+ // the tracer by ptrace(PTRACE_GETEVENTMSG).
+ //
+ // ptraceEventMsg is protected by the TaskSet mutex.
+ ptraceEventMsg uint64
+
+ // The struct that holds the IO-related usage. The ioUsage pointer is
+ // immutable.
+ ioUsage *usage.IO
+
+ // logPrefix is a string containing the task's thread ID in the root PID
+ // namespace, and is prepended to log messages emitted by Task.Infof etc.
+ logPrefix atomic.Value `state:"nosave"`
+
+ // traceContext and traceTask are both used for tracing, and are
+ // updated along with the logPrefix in updateInfoLocked.
+ //
+ // These are exclusive to the task goroutine.
+ traceContext gocontext.Context `state:"nosave"`
+ traceTask *trace.Task `state:"nosave"`
+
+ // creds is the task's credentials.
+ //
+ // creds.Load() may be called without synchronization. creds.Store() is
+ // serialized by mu. creds is owned by the task goroutine. All
+ // auth.Credentials objects that creds may point to, or have pointed to
+ // in the past, must be treated as immutable.
+ creds auth.AtomicPtrCredentials
+
+ // utsns is the task's UTS namespace.
+ //
+ // utsns is protected by mu. utsns is owned by the task goroutine.
+ utsns *UTSNamespace
+
+ // ipcns is the task's IPC namespace.
+ //
+ // ipcns is protected by mu. ipcns is owned by the task goroutine.
+ ipcns *IPCNamespace
+
+ // abstractSockets tracks abstract sockets that are in use.
+ //
+ // abstractSockets is protected by mu.
+ abstractSockets *AbstractSocketNamespace
+
+ // mountNamespaceVFS2 is the task's mount namespace.
+ //
+ // It is protected by mu. It is owned by the task goroutine.
+ mountNamespaceVFS2 *vfs.MountNamespace
+
+ // parentDeathSignal is sent to this task's thread group when its parent exits.
+ //
+ // parentDeathSignal is protected by mu.
+ parentDeathSignal linux.Signal
+
+ // syscallFilters is all seccomp-bpf syscall filters applicable to the
+ // task, in the order in which they were installed. The type of the atomic
+ // is []bpf.Program. Writing needs to be protected by the signal mutex.
+ //
+ // syscallFilters is owned by the task goroutine.
+ syscallFilters atomic.Value `state:".([]bpf.Program)"`
+
+ // If cleartid is non-zero, treat it as a pointer to a ThreadID in the
+ // task's virtual address space; when the task exits, set the pointed-to
+ // ThreadID to 0, and wake any futex waiters.
+ //
+ // cleartid is exclusive to the task goroutine.
+ cleartid usermem.Addr
+
+ // allowedCPUMask is mostly a fake CPU mask, kept only for
+ // sched_setaffinity(2)/sched_getaffinity(2), as we don't really control
+ // the affinity.
+ //
+ // Invariant: allowedCPUMask.Size() ==
+ // sched.CPUMaskSize(Kernel.applicationCores).
+ //
+ // allowedCPUMask is protected by mu.
+ allowedCPUMask sched.CPUSet
+
+ // cpu is the fake cpu number returned by getcpu(2). cpu is ignored
+ // entirely if Kernel.useHostCores is true.
+ //
+ // cpu is accessed using atomic memory operations.
+ cpu int32
+
+ // niceness tracks changes made to the process's priority/niceness. It is
+ // mostly used to provide a reasonable return value from getpriority(2)
+ // after a call to setpriority(2) has been made; we do not currently
+ // modify a process's actual scheduling priority.
+ // NOTE: This represents the userspace view of priority (nice), so the
+ // value should be in the range [-20, 19].
+ //
+ // niceness is protected by mu.
+ niceness int
+
+ // numaPolicy and numaNodeMask track the NUMA policy for the current
+ // thread, which can be modified through the set_mempolicy(2) syscall.
+ // Since we always report a single NUMA node, all policies are no-ops; we
+ // track this information only so that we can return reasonable values if
+ // the application calls get_mempolicy(2) after setting a non-default
+ // policy. Note that in the real syscall, nodemask can be longer than a
+ // single unsigned long, but since we always report a single node we never
+ // need to save more than a single bit.
+ //
+ // numaPolicy and numaNodeMask are protected by mu.
+ numaPolicy linux.NumaPolicy
+ numaNodeMask uint64
+
+ // netns is the task's network namespace. netns is never nil.
+ //
+ // netns is protected by mu.
+ netns *inet.Namespace
+
+ // If rseqPreempted is true, before the next call to p.Switch(),
+ // interrupt rseq critical regions as defined by rseqAddr and
+ // tg.oldRSeqCritical and write the task goroutine's CPU number to
+ // rseqAddr/oldRSeqCPUAddr.
+ //
+ // We support two ABIs for restartable sequences:
+ //
+ // 1. The upstream interface added in v4.18,
+ // 2. An "old" interface never merged upstream. In the implementation,
+ // this is referred to as "old rseq".
+ //
+ // rseqPreempted is exclusive to the task goroutine.
+ rseqPreempted bool `state:"nosave"`
+
+ // rseqCPU is the last CPU number written to rseqAddr/oldRSeqCPUAddr.
+ //
+ // If rseq is unused, rseqCPU is -1 for convenient use in
+ // platform.Context.Switch.
+ //
+ // rseqCPU is exclusive to the task goroutine.
+ rseqCPU int32
+
+ // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable.
+ //
+ // oldRSeqCPUAddr is exclusive to the task goroutine.
+ oldRSeqCPUAddr usermem.Addr
+
+ // rseqAddr is a pointer to the userspace linux.RSeq structure.
+ //
+ // rseqAddr is exclusive to the task goroutine.
+ rseqAddr usermem.Addr
+
+ // rseqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ //
+ // rseqSignature is exclusive to the task goroutine.
+ rseqSignature uint32
+
+ // copyScratchBuffer is a buffer available to CopyIn/CopyOut
+ // implementations that require an intermediate buffer to copy data
+ // into/out of. It prevents these buffers from being allocated/zeroed in
+ // each syscall and eventually garbage collected.
+ //
+ // copyScratchBuffer is exclusive to the task goroutine.
+ copyScratchBuffer [copyScratchBufferLen]byte `state:"nosave"`
+
+ // blockingTimer is used for blocking timeouts. blockingTimerChan is the
+ // channel that is sent to when blockingTimer fires.
+ //
+ // blockingTimer is exclusive to the task goroutine.
+ blockingTimer *ktime.Timer `state:"nosave"`
+ blockingTimerChan <-chan struct{} `state:"nosave"`
+
+ // futexWaiter is used for futex(FUTEX_WAIT) syscalls.
+ //
+ // futexWaiter is exclusive to the task goroutine.
+ futexWaiter *futex.Waiter `state:"nosave"`
+
+ // startTime is the real time at which the task started. It is set when
+ // a Task is created or invokes execve(2).
+ //
+ // startTime is protected by mu.
+ startTime ktime.Time
+}
+
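+// The save/load methods below are stateify hooks: atomic.Value fields cannot
+// be serialized directly, so the state:".(T)" tags on those fields (e.g.
+// syscallFilters above) route their (de)serialization through save<Field>/
+// load<Field> pairs like these.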
+func (t *Task) savePtraceTracer() *Task {
+ return t.ptraceTracer.Load().(*Task)
+}
+
+func (t *Task) loadPtraceTracer(tracer *Task) {
+ t.ptraceTracer.Store(tracer)
+}
+
+func (t *Task) saveSyscallFilters() []bpf.Program {
+ if f := t.syscallFilters.Load(); f != nil {
+ return f.([]bpf.Program)
+ }
+ return nil
+}
+
+func (t *Task) loadSyscallFilters(filters []bpf.Program) {
+ t.syscallFilters.Store(filters)
+}
+
+// afterLoad is invoked by stateify.
+func (t *Task) afterLoad() {
+ t.updateInfoLocked()
+ t.interruptChan = make(chan struct{}, 1)
+ t.gosched.State = TaskGoroutineNonexistent
+ if t.stop != nil {
+ t.stopCount = 1
+ }
+ t.endStopCond.L = &t.tg.signalHandlers.mu
+ t.p = t.k.Platform.NewContext()
+ t.rseqPreempted = true
+ t.futexWaiter = futex.NewWaiter()
+}
+
+// copyScratchBufferLen is the length of Task.copyScratchBuffer.
+const copyScratchBufferLen = 144 // sizeof(struct stat)
+
+// CopyScratchBuffer returns a scratch buffer to be used in CopyIn/CopyOut
+// functions. It must only be used within those functions and can only be used
+// by the task goroutine; it exists to improve performance and thus
+// intentionally lacks any synchronization.
+//
+// Callers should pass a constant value as an argument if possible, which will
+// allow the compiler to inline and optimize out the if statement below.
+func (t *Task) CopyScratchBuffer(size int) []byte {
+ if size > copyScratchBufferLen {
+ return make([]byte, size)
+ }
+ return t.copyScratchBuffer[:size]
+}
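+
+// Illustrative usage (a sketch, not part of the interface): a syscall that
+// serializes a struct stat can reuse the per-task buffer instead of
+// allocating:
+//
+//  // Passing a constant size lets the compiler prove that the allocation
+//  // branch in CopyScratchBuffer is dead and elide it entirely.
+//  buf := t.CopyScratchBuffer(144) // 144 == sizeof(struct stat)
+//  // ... fill buf, then copy it out to application memory ...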
+
+// FutexWaiter returns the Task's futex.Waiter.
+func (t *Task) FutexWaiter() *futex.Waiter {
+ return t.futexWaiter
+}
+
+// Kernel returns the Kernel containing t.
+func (t *Task) Kernel() *Kernel {
+ return t.k
+}
+
+// Value implements context.Context.Value.
+//
+// Preconditions: The caller must be running on the task goroutine (as implied
+// by the requirements of context.Context).
+func (t *Task) Value(key interface{}) interface{} {
+ switch key {
+ case CtxCanTrace:
+ return t.CanTrace
+ case CtxKernel:
+ return t.k
+ case CtxPIDNamespace:
+ return t.tg.pidns
+ case CtxUTSNamespace:
+ return t.utsns
+ case CtxIPCNamespace:
+ return t.ipcns
+ case CtxTask:
+ return t
+ case auth.CtxCredentials:
+ return t.Credentials()
+ case context.CtxThreadGroupID:
+ return int32(t.ThreadGroup().ID())
+ case fs.CtxRoot:
+ return t.fsContext.RootDirectory()
+ case vfs.CtxRoot:
+ return t.fsContext.RootDirectoryVFS2()
+ case vfs.CtxMountNamespace:
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
+ case fs.CtxDirentCacheLimiter:
+ return t.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return t.NetworkContext()
+ case ktime.CtxRealtimeClock:
+ return t.k.RealtimeClock()
+ case limits.CtxLimits:
+ return t.tg.limits
+ case pgalloc.CtxMemoryFile:
+ return t.k.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return t.k
+ case platform.CtxPlatform:
+ return t.k
+ case uniqueid.CtxGlobalUniqueID:
+ return t.k.UniqueID()
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return t.k
+ case uniqueid.CtxInotifyCookie:
+ return t.k.GenerateInotifyCookie()
+ case unimpl.CtxEvents:
+ return t.k
+ default:
+ return nil
+ }
+}
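+
+// Because Task implements context.Context, sentry code that only needs a
+// context can be handed a Task directly. A hedged sketch (the helper name is
+// illustrative):
+//
+//  func rootDirent(ctx context.Context) *fs.Dirent {
+//      // When ctx is a *Task, Value(fs.CtxRoot) returns the task's root
+//      // directory, as implemented above.
+//      d, _ := ctx.Value(fs.CtxRoot).(*fs.Dirent)
+//      return d
+//  }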
+
+// SetClearTID sets t's cleartid.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetClearTID(addr usermem.Addr) {
+ t.cleartid = addr
+}
+
+// SetSyscallRestartBlock sets the restart block for use in
+// restart_syscall(2). After registering a restart block, a syscall should
+// return ERESTART_RESTARTBLOCK to request a restart using the block.
+//
+// Precondition: The caller must be running on the task goroutine.
+func (t *Task) SetSyscallRestartBlock(r SyscallRestartBlock) {
+ t.syscallRestartBlock = r
+}
+
+// SyscallRestartBlock returns the currently registered restart block for use in
+// restart_syscall(2). This function is *not* idempotent: it clears the
+// registered block, so it must be called at most once per syscall. It must
+// not be called if a restart block has not been registered for the current
+// syscall.
+//
+// Precondition: The caller must be running on the task goroutine.
+func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
+ r := t.syscallRestartBlock
+ // Explicitly set the restart block to nil so that a future syscall can't
+ // accidentally reuse it.
+ t.syscallRestartBlock = nil
+ return r
+}
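+
+// A hedged sketch of the intended flow (the concrete restart block type is
+// illustrative, and this assumes the SyscallRestartBlock interface exposes a
+// Restart method): a blocking syscall registers a block and returns
+// ERESTART_RESTARTBLOCK; a later restart_syscall(2) consumes it exactly once:
+//
+//  // In the interrupted syscall:
+//  t.SetSyscallRestartBlock(pollRestartBlock{timeout: remaining})
+//  return 0, nil, syserror.ERESTART_RESTARTBLOCK
+//
+//  // In restart_syscall(2):
+//  r := t.SyscallRestartBlock() // clears the block; at most once per syscall
+//  return r.Restart(t)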
+
+// IsChrooted returns true if the root directory of t's FSContext is not the
+// root directory of t's MountNamespace.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) IsChrooted() bool {
+ if VFS2Enabled {
+ realRoot := t.mountNamespaceVFS2.Root()
+ defer realRoot.DecRef()
+ root := t.fsContext.RootDirectoryVFS2()
+ defer root.DecRef()
+ return root != realRoot
+ }
+
+ realRoot := t.tg.mounts.Root()
+ defer realRoot.DecRef()
+ root := t.fsContext.RootDirectory()
+ if root != nil {
+ defer root.DecRef()
+ }
+ return root != realRoot
+}
+
+// TaskContext returns t's TaskContext.
+//
+// Precondition: The caller must be running on the task goroutine, or t.mu must
+// be locked.
+func (t *Task) TaskContext() *TaskContext {
+ return &t.tc
+}
+
+// FSContext returns t's FSContext. FSContext does not take an additional
+// reference on the returned FSContext.
+//
+// Precondition: The caller must be running on the task goroutine, or t.mu must
+// be locked.
+func (t *Task) FSContext() *FSContext {
+ return t.fsContext
+}
+
+// FDTable returns t's FDTable. FDTable does not take an additional reference
+// on the returned FDTable.
+//
+// Precondition: The caller must be running on the task goroutine, or t.mu must
+// be locked.
+func (t *Task) FDTable() *FDTable {
+ return t.fdTable
+}
+
+// GetFile is a convenience wrapper for t.FDTable().Get.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) GetFile(fd int32) *fs.File {
+ f, _ := t.fdTable.Get(fd)
+ return f
+}
+
+// GetFileVFS2 is a convenience wrapper for t.FDTable().GetVFS2.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) GetFileVFS2(fd int32) *vfs.FileDescription {
+ f, _ := t.fdTable.GetVFS2(fd)
+ return f
+}
+
+// NewFDs is a convenience wrapper for t.FDTable().NewFDs.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDs(fd int32, files []*fs.File, flags FDFlags) ([]int32, error) {
+ return t.fdTable.NewFDs(t, fd, files, flags)
+}
+
+// NewFDsVFS2 is a convenience wrapper for t.FDTable().NewFDsVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDsVFS2(fd int32, files []*vfs.FileDescription, flags FDFlags) ([]int32, error) {
+ return t.fdTable.NewFDsVFS2(t, fd, files, flags)
+}
+
+// NewFDFrom is a convenience wrapper for t.FDTable().NewFDs with a single file.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error) {
+ fds, err := t.fdTable.NewFDs(t, fd, []*fs.File{file}, flags)
+ if err != nil {
+ return 0, err
+ }
+ return fds[0], nil
+}
+
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
+// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
+ return t.fdTable.NewFDAt(t, fd, file, flags)
+}
+
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
+// WithMuLocked executes f with t.mu locked.
+func (t *Task) WithMuLocked(f func(*Task)) {
+ t.mu.Lock()
+ f(t)
+ t.mu.Unlock()
+}
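+
+// Illustrative usage (sketch): code running off the task goroutine can
+// satisfy "t.mu must be locked" preconditions, such as FDTable's above, via
+// WithMuLocked:
+//
+//  var fdt *kernel.FDTable
+//  t.WithMuLocked(func(t *kernel.Task) {
+//      fdt = t.FDTable() // precondition "t.mu must be locked" holds here
+//  })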
+
+// MountNamespace returns t's MountNamespace. MountNamespace does not take an
+// additional reference on the returned MountNamespace.
+func (t *Task) MountNamespace() *fs.MountNamespace {
+ return t.tg.mounts
+}
+
+// MountNamespaceVFS2 returns t's MountNamespace. A reference is taken on the
+// returned mount namespace.
+func (t *Task) MountNamespaceVFS2() *vfs.MountNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
+}
+
+// AbstractSockets returns t's AbstractSocketNamespace.
+func (t *Task) AbstractSockets() *AbstractSocketNamespace {
+ return t.abstractSockets
+}
+
+// ContainerID returns t's container ID.
+func (t *Task) ContainerID() string {
+ return t.containerID
+}
+
+// OOMScoreAdj gets the task's thread group's OOM score adjustment.
+func (t *Task) OOMScoreAdj() int32 {
+ return atomic.LoadInt32(&t.tg.oomScoreAdj)
+}
+
+// SetOOMScoreAdj sets the task's thread group's OOM score adjustment. The
+// value should be between -1000 and 1000 inclusive.
+func (t *Task) SetOOMScoreAdj(adj int32) error {
+ if adj > 1000 || adj < -1000 {
+ return syserror.EINVAL
+ }
+ atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
+ return nil
+}
+
+// UID returns t's uid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) UID() uint32 {
+ return uint32(t.Credentials().EffectiveKUID)
+}
+
+// GID returns t's gid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) GID() uint32 {
+ return uint32(t.Credentials().EffectiveKGID)
+}
diff --git a/pkg/sentry/kernel/task_acct.go b/pkg/sentry/kernel/task_acct.go
new file mode 100644
index 000000000..5f3e60fe8
--- /dev/null
+++ b/pkg/sentry/kernel/task_acct.go
@@ -0,0 +1,196 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// Accounting, limits, timers.
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getitimer implements getitimer(2).
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Getitimer(id int32) (linux.ItimerVal, error) {
+ var tm ktime.Time
+ var s ktime.Setting
+ switch id {
+ case linux.ITIMER_REAL:
+ tm, s = t.tg.itimerRealTimer.Get()
+ case linux.ITIMER_VIRTUAL:
+ tm = t.tg.UserCPUClock().Now()
+ t.tg.signalHandlers.mu.Lock()
+ s, _ = t.tg.itimerVirtSetting.At(tm)
+ t.tg.signalHandlers.mu.Unlock()
+ case linux.ITIMER_PROF:
+ tm = t.tg.CPUClock().Now()
+ t.tg.signalHandlers.mu.Lock()
+ s, _ = t.tg.itimerProfSetting.At(tm)
+ t.tg.signalHandlers.mu.Unlock()
+ default:
+ return linux.ItimerVal{}, syserror.EINVAL
+ }
+ val, iv := ktime.SpecFromSetting(tm, s)
+ return linux.ItimerVal{
+ Value: linux.DurationToTimeval(val),
+ Interval: linux.DurationToTimeval(iv),
+ }, nil
+}
+
+// Setitimer implements setitimer(2).
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Setitimer(id int32, newitv linux.ItimerVal) (linux.ItimerVal, error) {
+ var tm ktime.Time
+ var olds ktime.Setting
+ switch id {
+ case linux.ITIMER_REAL:
+ news, err := ktime.SettingFromSpec(newitv.Value.ToDuration(), newitv.Interval.ToDuration(), t.tg.itimerRealTimer.Clock())
+ if err != nil {
+ return linux.ItimerVal{}, err
+ }
+ tm, olds = t.tg.itimerRealTimer.Swap(news)
+ case linux.ITIMER_VIRTUAL:
+ c := t.tg.UserCPUClock()
+ var err error
+ t.k.cpuClockTicker.Atomically(func() {
+ tm = c.Now()
+ var news ktime.Setting
+ news, err = ktime.SettingFromSpecAt(newitv.Value.ToDuration(), newitv.Interval.ToDuration(), tm)
+ if err != nil {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ olds = t.tg.itimerVirtSetting
+ t.tg.itimerVirtSetting = news
+ t.tg.updateCPUTimersEnabledLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ })
+ if err != nil {
+ return linux.ItimerVal{}, err
+ }
+ case linux.ITIMER_PROF:
+ c := t.tg.CPUClock()
+ var err error
+ t.k.cpuClockTicker.Atomically(func() {
+ tm = c.Now()
+ var news ktime.Setting
+ news, err = ktime.SettingFromSpecAt(newitv.Value.ToDuration(), newitv.Interval.ToDuration(), tm)
+ if err != nil {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ olds = t.tg.itimerProfSetting
+ t.tg.itimerProfSetting = news
+ t.tg.updateCPUTimersEnabledLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ })
+ if err != nil {
+ return linux.ItimerVal{}, err
+ }
+ default:
+ return linux.ItimerVal{}, syserror.EINVAL
+ }
+ oldval, oldiv := ktime.SpecFromSetting(tm, olds)
+ return linux.ItimerVal{
+ Value: linux.DurationToTimeval(oldval),
+ Interval: linux.DurationToTimeval(oldiv),
+ }, nil
+}
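+
+// A hedged sketch of the syscall-level wiring (which lives in the syscalls
+// package, not here): setitimer(2) swaps in the new setting and copies the
+// old one back out to the application:
+//
+//  oldVal, err := t.Setitimer(linux.ITIMER_REAL, newVal)
+//  if err != nil {
+//      return 0, nil, err
+//  }
+//  // oldVal now holds the previous value/interval for copy-out.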
+
+// IOUsage returns the IO usage of the thread.
+func (t *Task) IOUsage() *usage.IO {
+ return t.ioUsage
+}
+
+// IOUsage returns the total IO usage of all dead and live threads in the group.
+func (tg *ThreadGroup) IOUsage() *usage.IO {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ io := *tg.ioUsage
+ // Account for active tasks.
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ io.Accumulate(t.IOUsage())
+ }
+ return &io
+}
+
+// Name returns t's name.
+func (t *Task) Name() string {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.tc.Name
+}
+
+// SetName changes t's name.
+func (t *Task) SetName(name string) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.tc.Name = name
+ t.Debugf("Set thread name to %q", name)
+}
+
+// Limits implements context.Context.Limits.
+func (t *Task) Limits() *limits.LimitSet {
+ return t.ThreadGroup().Limits()
+}
+
+// StartTime returns t's start time.
+func (t *Task) StartTime() ktime.Time {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.startTime
+}
+
+// MaxRSS returns the maximum resident set size of the task in bytes. The
+// which parameter should be one of RUSAGE_SELF, RUSAGE_CHILDREN,
+// RUSAGE_THREAD, or RUSAGE_BOTH. See getrusage(2) for documentation on the
+// behavior of these flags.
+func (t *Task) MaxRSS(which int32) uint64 {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+
+ switch which {
+ case linux.RUSAGE_SELF, linux.RUSAGE_THREAD:
+ // If there's an active mm we can use its value.
+ if mm := t.MemoryManager(); mm != nil {
+ if mmMaxRSS := mm.MaxResidentSetSize(); mmMaxRSS > t.tg.maxRSS {
+ return mmMaxRSS
+ }
+ }
+ return t.tg.maxRSS
+ case linux.RUSAGE_CHILDREN:
+ return t.tg.childMaxRSS
+ case linux.RUSAGE_BOTH:
+ maxRSS := t.tg.maxRSS
+ if maxRSS < t.tg.childMaxRSS {
+ maxRSS = t.tg.childMaxRSS
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ if mmMaxRSS := mm.MaxResidentSetSize(); mmMaxRSS > maxRSS {
+ return mmMaxRSS
+ }
+ }
+ return maxRSS
+ default:
+ // We'll only get here if which is invalid.
+ return 0
+ }
+}
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
new file mode 100644
index 000000000..4a4a69ee2
--- /dev/null
+++ b/pkg/sentry/kernel/task_block.go
@@ -0,0 +1,230 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "runtime"
+ "runtime/trace"
+ "time"
+
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// BlockWithTimeout blocks t until an event is received from C, the application
+// monotonic clock indicates that timeout has elapsed (only if haveTimeout is true),
+// or t is interrupted. It returns:
+//
+// - The remaining timeout, which is guaranteed to be 0 if the timeout expired,
+// and is unspecified if haveTimeout is false.
+//
+// - An error which is nil if an event is received from C, ETIMEDOUT if the
+// timeout expired, and syserror.ErrInterrupted if t is interrupted.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) {
+ if !haveTimeout {
+ return timeout, t.block(C, nil)
+ }
+
+ start := t.Kernel().MonotonicClock().Now()
+ deadline := start.Add(timeout)
+ err := t.BlockWithDeadline(C, true, deadline)
+
+ // If the timeout expired, explicitly return a remaining duration of 0.
+ if err == syserror.ETIMEDOUT {
+ return 0, err
+ }
+
+ // Compute the remaining timeout. Note that even if block() above didn't
+ // return due to a timeout, we may have used up some of the remaining time
+ // since then. We cap the remaining timeout at 0 to make it easier to
+ // directly use the returned duration.
+ end := t.Kernel().MonotonicClock().Now()
+ remainingTimeout := timeout - end.Sub(start)
+ if remainingTimeout < 0 {
+ remainingTimeout = 0
+ }
+
+ return remainingTimeout, err
+}
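+
+// A hedged usage sketch for a poll-like syscall with a relative timeout:
+//
+//  remaining, err := t.BlockWithTimeout(ch, haveTimeout, timeout)
+//  if err == syserror.ETIMEDOUT {
+//      err = nil // expiry is not an error for poll-like syscalls
+//  }
+//  // On syserror.ErrInterrupted, remaining (already capped at >= 0) can
+//  // seed a restart block for restart_syscall(2).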
+
+// BlockWithDeadline blocks t until an event is received from C, the
+// application monotonic clock indicates a time of deadline (only if
+// haveDeadline is true), or t is interrupted. It returns nil if an event is
+// received from C, ETIMEDOUT if the deadline expired, and
+// syserror.ErrInterrupted if t is interrupted.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error {
+ if !haveDeadline {
+ return t.block(C, nil)
+ }
+
+ // Start the timeout timer.
+ t.blockingTimer.Swap(ktime.Setting{
+ Enabled: true,
+ Next: deadline,
+ })
+
+ err := t.block(C, t.blockingTimerChan)
+
+ // Stop the timeout timer and drain the channel.
+ t.blockingTimer.Swap(ktime.Setting{})
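+ // The drain below is needed because the timer may have fired and queued
+ // an event after block() returned but before the Swap above disabled it;
+ // a stale event would otherwise spuriously wake the next
+ // BlockWithDeadline.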
+ select {
+ case <-t.blockingTimerChan:
+ default:
+ }
+
+ return err
+}
+
+// BlockWithTimer blocks t until an event is received from C or tchan, or t is
+// interrupted. It returns nil if an event is received from C, ETIMEDOUT if an
+// event is received from tchan, and syserror.ErrInterrupted if t is
+// interrupted.
+//
+// Most clients should use BlockWithDeadline or BlockWithTimeout instead.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) BlockWithTimer(C <-chan struct{}, tchan <-chan struct{}) error {
+ return t.block(C, tchan)
+}
+
+// Block blocks t until an event is received from C or t is interrupted. It
+// returns nil if an event is received from C and syserror.ErrInterrupted if t
+// is interrupted.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Block(C <-chan struct{}) error {
+ return t.block(C, nil)
+}
+
+// block blocks a task on one of many events.
+// N.B. defer is too expensive to be used here.
+func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
+ // Fast path if the request is already done.
+ select {
+ case <-C:
+ return nil
+ default:
+ }
+
+ // Deactivate our address space; we don't need it while blocked.
+ interrupt := t.SleepStart()
+
+ // If the request is not completed, but the timer has already expired,
+ // then ensure that we run through a scheduler cycle. This is because
+ // we may see applications relying on timer slack to yield the thread.
+ // For example, they may attempt to sleep for some number of nanoseconds,
+ // and expect that this will actually yield the CPU and sleep for at
+ // least several microseconds, e.g.:
+ // https://github.com/LMAX-Exchange/disruptor/commit/6ca210f2bcd23f703c479804d583718e16f43c07
+ if len(timerChan) > 0 {
+ runtime.Gosched()
+ }
+
+ region := trace.StartRegion(t.traceContext, blockRegion)
+ select {
+ case <-C:
+ region.End()
+ t.SleepFinish(true)
+ // Woken by event.
+ return nil
+
+ case <-interrupt:
+ region.End()
+ t.SleepFinish(false)
+ // Return the indicated error on interrupt.
+ return syserror.ErrInterrupted
+
+ case <-timerChan:
+ region.End()
+ t.SleepFinish(true)
+ // We've timed out.
+ return syserror.ETIMEDOUT
+ }
+}
+
+// SleepStart implements amutex.Sleeper.SleepStart.
+func (t *Task) SleepStart() <-chan struct{} {
+ t.Deactivate()
+ t.accountTaskGoroutineEnter(TaskGoroutineBlockedInterruptible)
+ return t.interruptChan
+}
+
+// SleepFinish implements amutex.Sleeper.SleepFinish.
+func (t *Task) SleepFinish(success bool) {
+ if !success {
+ // The interrupted notification is consumed only at the top-level
+ // (Run). Therefore we attempt to reset the pending notification.
+ // This will also elide our next entry back into the task, so we
+ // will process signals, state changes, etc.
+ t.interruptSelf()
+ }
+ t.accountTaskGoroutineLeave(TaskGoroutineBlockedInterruptible)
+ t.Activate()
+}
+
+// Interrupted implements amutex.Sleeper.Interrupted.
+func (t *Task) Interrupted() bool {
+ return len(t.interruptChan) != 0
+}
+
+// UninterruptibleSleepStart implements context.Context.UninterruptibleSleepStart.
+func (t *Task) UninterruptibleSleepStart(deactivate bool) {
+ if deactivate {
+ t.Deactivate()
+ }
+ t.accountTaskGoroutineEnter(TaskGoroutineBlockedUninterruptible)
+}
+
+// UninterruptibleSleepFinish implements context.Context.UninterruptibleSleepFinish.
+func (t *Task) UninterruptibleSleepFinish(activate bool) {
+ t.accountTaskGoroutineLeave(TaskGoroutineBlockedUninterruptible)
+ if activate {
+ t.Activate()
+ }
+}
+
+// interrupted returns true if interrupt or interruptSelf has been called at
+// least once since the last call to interrupted.
+func (t *Task) interrupted() bool {
+ select {
+ case <-t.interruptChan:
+ return true
+ default:
+ return false
+ }
+}
+
+// interrupt unblocks the task and interrupts it if it's currently running in
+// userspace.
+func (t *Task) interrupt() {
+ t.interruptSelf()
+ t.p.Interrupt()
+}
+
+// interruptSelf is like interrupt, but can only be called by the task
+// goroutine.
+func (t *Task) interruptSelf() {
+ select {
+ case t.interruptChan <- struct{}{}:
+ t.Debugf("Interrupt queued")
+ default:
+ t.Debugf("Dropping duplicate interrupt")
+ }
+ // platform.Context.Interrupt() is unnecessary since a task goroutine
+ // calling interruptSelf() cannot also be blocked in
+ // platform.Context.Switch().
+}
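+
+// The interrupt machinery above is the common Go "pending flag" pattern: a
+// buffered channel of capacity 1 (interruptChan; see afterLoad in task.go)
+// where a non-blocking send sets the flag and a non-blocking receive tests
+// and clears it. A minimal standalone sketch:
+//
+//  pending := make(chan struct{}, 1)
+//  // Set (idempotent): at most one notification is ever queued.
+//  select { case pending <- struct{}{}: default: }
+//  // Test-and-clear:
+//  select { case <-pending: /* was set */ default: /* was clear */ }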
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
new file mode 100644
index 000000000..e1ecca99e
--- /dev/null
+++ b/pkg/sentry/kernel/task_clone.go
@@ -0,0 +1,540 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SharingOptions controls what resources are shared by a new task created by
+// Task.Clone, or an existing task affected by Task.Unshare.
+type SharingOptions struct {
+ // If NewAddressSpace is true, the task should have an independent virtual
+ // address space.
+ NewAddressSpace bool
+
+ // If NewSignalHandlers is true, the task should use an independent set of
+ // signal handlers.
+ NewSignalHandlers bool
+
+ // If NewThreadGroup is true, the task should be the leader of its own
+ // thread group. TerminationSignal is the signal that the thread group
+ // will send to its parent when it exits. If NewThreadGroup is false,
+ // TerminationSignal is ignored.
+ NewThreadGroup bool
+ TerminationSignal linux.Signal
+
+ // If NewPIDNamespace is true:
+ //
+ // - In the context of Task.Clone, the new task should be the init task
+ // (TID 1) in a new PID namespace.
+ //
+ // - In the context of Task.Unshare, the task should create a new PID
+ // namespace, and all subsequent clones of the task should be members of
+ // the new PID namespace.
+ NewPIDNamespace bool
+
+ // If NewUserNamespace is true, the task should have an independent user
+ // namespace.
+ NewUserNamespace bool
+
+ // If NewNetworkNamespace is true, the task should have an independent
+ // network namespace.
+ NewNetworkNamespace bool
+
+ // If NewFiles is true, the task should use an independent file descriptor
+ // table.
+ NewFiles bool
+
+ // If NewFSContext is true, the task should have an independent FSContext.
+ NewFSContext bool
+
+ // If NewUTSNamespace is true, the task should have an independent UTS
+ // namespace.
+ NewUTSNamespace bool
+
+ // If NewIPCNamespace is true, the task should have an independent IPC
+ // namespace.
+ NewIPCNamespace bool
+}
+
+// CloneOptions controls the behavior of Task.Clone.
+type CloneOptions struct {
+ // SharingOptions defines the set of resources that the new task will share
+ // with its parent.
+ SharingOptions
+
+ // Stack is the initial stack pointer of the new task. If Stack is 0, the
+ // new task will start with the same stack pointer as its parent.
+ Stack usermem.Addr
+
+ // If SetTLS is true, set the new task's TLS (thread-local storage)
+ // descriptor to TLS. If SetTLS is false, TLS is ignored.
+ SetTLS bool
+ TLS usermem.Addr
+
+ // If ChildClearTID is true, when the child exits, 0 is written to the
+ // address ChildTID in the child's memory, and if the write is successful a
+ // futex wake on the same address is performed.
+ //
+ // If ChildSetTID is true, the child's thread ID (in the child's PID
+ // namespace) is written to address ChildTID in the child's memory. (As in
+ // Linux, failed writes are silently ignored.)
+ ChildClearTID bool
+ ChildSetTID bool
+ ChildTID usermem.Addr
+
+ // If ParentSetTID is true, the child's thread ID (in the parent's PID
+ // namespace) is written to address ParentTID in the parent's memory. (As
+ // in Linux, failed writes are silently ignored.)
+ //
+ // Older versions of the clone(2) man page state that CLONE_PARENT_SETTID
+ // causes the child's thread ID to be written to ptid in both the parent
+ // and child's memory, but this is a documentation error fixed by
+ // 87ab04792ced ("clone.2: Fix description of CLONE_PARENT_SETTID").
+ ParentSetTID bool
+ ParentTID usermem.Addr
+
+ // If Vfork is true, place the parent in vforkStop until the cloned task
+ // releases its TaskContext.
+ Vfork bool
+
+ // If Untraced is true, do not report PTRACE_EVENT_CLONE/FORK/VFORK for
+ // this clone(), and do not ptrace-attach the caller's tracer to the new
+ // task. (PTRACE_EVENT_VFORK_DONE will still be reported if appropriate).
+ Untraced bool
+
+ // If InheritTracer is true, ptrace-attach the caller's tracer to the new
+ // task, even if no PTRACE_EVENT_CLONE/FORK/VFORK event would be reported
+ // for it. If both Untraced and InheritTracer are true, no event will be
+ // reported, but tracer inheritance will still occur.
+ InheritTracer bool
+}
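+
+// For orientation, a hedged sketch of how clone(2) flags decode into these
+// options (the real decoding lives in the clone syscall implementation;
+// flags, stack, and tls stand for the raw clone(2) arguments, and only a
+// subset of options is shown):
+//
+//  opts := kernel.CloneOptions{
+//      SharingOptions: kernel.SharingOptions{
+//          NewAddressSpace:   flags&linux.CLONE_VM == 0,
+//          NewSignalHandlers: flags&linux.CLONE_SIGHAND == 0,
+//          NewThreadGroup:    flags&linux.CLONE_THREAD == 0,
+//          TerminationSignal: linux.Signal(flags & 0xff), // low byte is the exit signal
+//          NewPIDNamespace:   flags&linux.CLONE_NEWPID != 0,
+//          NewUserNamespace:  flags&linux.CLONE_NEWUSER != 0,
+//      },
+//      Stack:  stack,
+//      SetTLS: flags&linux.CLONE_SETTLS != 0,
+//      TLS:    tls,
+//      Vfork:  flags&linux.CLONE_VFORK != 0,
+//  }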
+
+// Clone implements the clone(2) syscall and returns the thread ID of the new
+// task in t's PID namespace. Clone may return both a non-zero thread ID and a
+// non-nil error.
+//
+// Preconditions: The caller must be running Task.doSyscallInvoke on the task
+// goroutine.
+func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
+ // Since signal actions may refer to application signal handlers by virtual
+ // address, any set of signal handlers must refer to the same address
+ // space.
+ if !opts.NewSignalHandlers && opts.NewAddressSpace {
+ return 0, nil, syserror.EINVAL
+ }
+ // In order for the behavior of thread-group-directed signals to be sane,
+ // all tasks in a thread group must share signal handlers.
+ if !opts.NewThreadGroup && opts.NewSignalHandlers {
+ return 0, nil, syserror.EINVAL
+ }
+ // All tasks in a thread group must be in the same PID namespace.
+ if !opts.NewThreadGroup && (opts.NewPIDNamespace || t.childPIDNamespace != nil) {
+ return 0, nil, syserror.EINVAL
+ }
+ // The two different ways of specifying a new PID namespace are
+ // incompatible.
+ if opts.NewPIDNamespace && t.childPIDNamespace != nil {
+ return 0, nil, syserror.EINVAL
+ }
+ // Thread groups and FS contexts cannot span user namespaces.
+ if opts.NewUserNamespace && (!opts.NewThreadGroup || !opts.NewFSContext) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If CLONE_NEWUSER is specified along with other CLONE_NEW* flags in a
+ // single clone(2) or unshare(2) call, the user namespace is guaranteed to
+ // be created first, giving the child (clone(2)) or caller (unshare(2))
+ // privileges over the remaining namespaces created by the call." -
+ // user_namespaces(7)
+ creds := t.Credentials()
+ userns := creds.UserNamespace
+ if opts.NewUserNamespace {
+ var err error
+ // "EPERM (since Linux 3.9): CLONE_NEWUSER was specified in flags and
+ // the caller is in a chroot environment (i.e., the caller's root
+ // directory does not match the root directory of the mount namespace
+ // in which it resides)." - clone(2). Neither chroot(2) nor
+ // user_namespaces(7) document this.
+ if t.IsChrooted() {
+ return 0, nil, syserror.EPERM
+ }
+ userns, err = creds.NewChildUserNamespace()
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if (opts.NewPIDNamespace || opts.NewNetworkNamespace || opts.NewUTSNamespace) && !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, userns) {
+ return 0, nil, syserror.EPERM
+ }
+
+ utsns := t.UTSNamespace()
+ if opts.NewUTSNamespace {
+ // Note that this must happen after NewUserNamespace so we get
+ // the new userns if there is one.
+ utsns = t.UTSNamespace().Clone(userns)
+ }
+
+ ipcns := t.IPCNamespace()
+ if opts.NewIPCNamespace {
+ // Note that "If CLONE_NEWIPC is set, then create the process in a new IPC
+ // namespace"
+ ipcns = NewIPCNamespace(userns)
+ }
+
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
+ // TODO(b/63601033): Implement CLONE_NEWNS.
+ mntnsVFS2 := t.mountNamespaceVFS2
+ if mntnsVFS2 != nil {
+ mntnsVFS2.IncRef()
+ }
+
+ tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace)
+ if err != nil {
+ return 0, nil, err
+ }
+ // clone() returns 0 in the child.
+ tc.Arch.SetReturn(0)
+ if opts.Stack != 0 {
+ tc.Arch.SetStack(uintptr(opts.Stack))
+ }
+ if opts.SetTLS {
+ if !tc.Arch.SetTLS(uintptr(opts.TLS)) {
+ return 0, nil, syserror.EPERM
+ }
+ }
+
+ var fsContext *FSContext
+ if opts.NewFSContext {
+ fsContext = t.fsContext.Fork()
+ } else {
+ fsContext = t.fsContext
+ fsContext.IncRef()
+ }
+
+ var fdTable *FDTable
+ if opts.NewFiles {
+ fdTable = t.fdTable.Fork()
+ } else {
+ fdTable = t.fdTable
+ fdTable.IncRef()
+ }
+
+ pidns := t.tg.pidns
+ if t.childPIDNamespace != nil {
+ pidns = t.childPIDNamespace
+ } else if opts.NewPIDNamespace {
+ pidns = pidns.NewChild(userns)
+ }
+
+ tg := t.tg
+ rseqAddr := usermem.Addr(0)
+ rseqSignature := uint32(0)
+ if opts.NewThreadGroup {
+ if tg.mounts != nil {
+ tg.mounts.IncRef()
+ }
+ sh := t.tg.signalHandlers
+ if opts.NewSignalHandlers {
+ sh = sh.Fork()
+ }
+ tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ tg.oomScoreAdj = atomic.LoadInt32(&t.tg.oomScoreAdj)
+ rseqAddr = t.rseqAddr
+ rseqSignature = t.rseqSignature
+ }
+
+ cfg := &TaskConfig{
+ Kernel: t.k,
+ ThreadGroup: tg,
+ SignalMask: t.SignalMask(),
+ TaskContext: tc,
+ FSContext: fsContext,
+ FDTable: fdTable,
+ Credentials: creds,
+ Niceness: t.Niceness(),
+ NetworkNamespace: netns,
+ AllowedCPUMask: t.CPUMask(),
+ UTSNamespace: utsns,
+ IPCNamespace: ipcns,
+ AbstractSocketNamespace: t.abstractSockets,
+ MountNamespaceVFS2: mntnsVFS2,
+ RSeqAddr: rseqAddr,
+ RSeqSignature: rseqSignature,
+ ContainerID: t.ContainerID(),
+ }
+ if opts.NewThreadGroup {
+ cfg.Parent = t
+ } else {
+ cfg.InheritParent = t
+ }
+ nt, err := t.tg.pidns.owner.NewTask(cfg)
+ if err != nil {
+ if opts.NewThreadGroup {
+ tg.release()
+ }
+ return 0, nil, err
+ }
+
+ // "A child process created via fork(2) inherits a copy of its parent's
+ // alternate signal stack settings" - sigaltstack(2).
+ //
+ // However kernel/fork.c:copy_process() adds a limitation to this:
+ // "sigaltstack should be cleared when sharing the same VM".
+ if opts.NewAddressSpace || opts.Vfork {
+ nt.SetSignalStack(t.SignalStack())
+ }
+
+ if userns != creds.UserNamespace {
+ if err := nt.SetUserNamespace(userns); err != nil {
+ // This shouldn't be possible: userns was created from nt.creds, so
+ // nt should have CAP_SYS_ADMIN in userns.
+ panic("Task.Clone: SetUserNamespace failed: " + err.Error())
+ }
+ }
+
+ // This has to happen last, because e.g. ptraceClone may send a SIGSTOP to
+ // nt that it must receive before its task goroutine starts running.
+ tid := nt.k.tasks.Root.IDOfTask(nt)
+ defer nt.Start(tid)
+ t.traceCloneEvent(tid)
+
+ // "If fork/clone and execve are allowed by @prog, any child processes will
+ // be constrained to the same filters and system call ABI as the parent." -
+ // Documentation/prctl/seccomp_filter.txt
+ if f := t.syscallFilters.Load(); f != nil {
+ copiedFilters := append([]bpf.Program(nil), f.([]bpf.Program)...)
+ nt.syscallFilters.Store(copiedFilters)
+ }
+ if opts.Vfork {
+ nt.vforkParent = t
+ }
+
+ if opts.ChildClearTID {
+ nt.SetClearTID(opts.ChildTID)
+ }
+ if opts.ChildSetTID {
+ // Can't use Task.CopyOut, which assumes AddressSpaceActive.
+ usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{})
+ }
+ ntid := t.tg.pidns.IDOfTask(nt)
+ if opts.ParentSetTID {
+ t.CopyOut(opts.ParentTID, ntid)
+ }
+
+ kind := ptraceCloneKindClone
+ if opts.Vfork {
+ kind = ptraceCloneKindVfork
+ } else if opts.TerminationSignal == linux.SIGCHLD {
+ kind = ptraceCloneKindFork
+ }
+ if t.ptraceClone(kind, nt, opts) {
+ if opts.Vfork {
+ return ntid, &SyscallControl{next: &runSyscallAfterPtraceEventClone{vforkChild: nt, vforkChildTID: ntid}}, nil
+ }
+ return ntid, &SyscallControl{next: &runSyscallAfterPtraceEventClone{}}, nil
+ }
+ if opts.Vfork {
+ t.maybeBeginVforkStop(nt)
+ return ntid, &SyscallControl{next: &runSyscallAfterVforkStop{childTID: ntid}}, nil
+ }
+ return ntid, nil, nil
+}
+
+// maybeBeginVforkStop checks if a previously-started vfork child is still
+// running and has not yet released its MM, such that its parent t should enter
+// a vforkStop.
+//
+// Preconditions: The caller must be running on t's task goroutine.
+func (t *Task) maybeBeginVforkStop(child *Task) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.killedLocked() {
+ child.vforkParent = nil
+ return
+ }
+ if child.vforkParent == t {
+ t.beginInternalStopLocked((*vforkStop)(nil))
+ }
+}
+
+func (t *Task) unstopVforkParent() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if p := t.vforkParent; p != nil {
+ p.tg.signalHandlers.mu.Lock()
+ defer p.tg.signalHandlers.mu.Unlock()
+ if _, ok := p.stop.(*vforkStop); ok {
+ p.endInternalStopLocked()
+ }
+ // Parent no longer needs to be unstopped.
+ t.vforkParent = nil
+ }
+}
+
+// +stateify savable
+type runSyscallAfterPtraceEventClone struct {
+ vforkChild *Task
+
+ // If vforkChild is not nil, vforkChildTID is its thread ID in the parent's
+ // PID namespace. vforkChildTID must be stored since the child may exit and
+ // release its TID before the PTRACE_EVENT stop ends.
+ vforkChildTID ThreadID
+}
+
+func (r *runSyscallAfterPtraceEventClone) execute(t *Task) taskRunState {
+ if r.vforkChild != nil {
+ t.maybeBeginVforkStop(r.vforkChild)
+ return &runSyscallAfterVforkStop{r.vforkChildTID}
+ }
+ return (*runSyscallExit)(nil)
+}
+
+// +stateify savable
+type runSyscallAfterVforkStop struct {
+ // childTID has the same meaning as
+ // runSyscallAfterPtraceEventClone.vforkChildTID.
+ childTID ThreadID
+}
+
+func (r *runSyscallAfterVforkStop) execute(t *Task) taskRunState {
+ t.ptraceVforkDone(r.childTID)
+ return (*runSyscallExit)(nil)
+}
+
+// Unshare changes the set of resources t shares with other tasks, as specified
+// by opts.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Unshare(opts *SharingOptions) error {
+ // In Linux unshare(2), NewThreadGroup implies NewSignalHandlers and
+ // NewSignalHandlers implies NewAddressSpace. All three flags are no-ops if
+ // t is the only task using its MM, which due to clone(2)'s rules implies
+ // that it is also the only task using its signal handlers and the only
+ // task in its thread group; otherwise they cause EINVAL to be returned.
+ //
+ // Since we don't count the number of tasks using each address space or set
+ // of signal handlers, we reject NewSignalHandlers and NewAddressSpace
+ // altogether, and interpret NewThreadGroup as requiring that t be the only
+ // member of its thread group. This seems to be logically coherent, in the
+ // sense that clone(2) allows a task to share signal handlers and address
+ // spaces with tasks in other thread groups.
+ if opts.NewAddressSpace || opts.NewSignalHandlers {
+ return syserror.EINVAL
+ }
+ creds := t.Credentials()
+ if opts.NewThreadGroup {
+ t.tg.signalHandlers.mu.Lock()
+ if t.tg.tasksCount != 1 {
+ t.tg.signalHandlers.mu.Unlock()
+ return syserror.EINVAL
+ }
+ t.tg.signalHandlers.mu.Unlock()
+ // This isn't racy because we're the only living task, and therefore
+ // the only task capable of creating new ones, in our thread group.
+ }
+ if opts.NewUserNamespace {
+ if t.IsChrooted() {
+ return syserror.EPERM
+ }
+ newUserNS, err := creds.NewChildUserNamespace()
+ if err != nil {
+ return err
+ }
+ err = t.SetUserNamespace(newUserNS)
+ if err != nil {
+ return err
+ }
+ // Need to reload creds, because t.SetUserNamespace() changed task
+ // credentials.
+ creds = t.Credentials()
+ }
+ haveCapSysAdmin := t.HasCapability(linux.CAP_SYS_ADMIN)
+ if opts.NewPIDNamespace {
+ if !haveCapSysAdmin {
+ return syserror.EPERM
+ }
+ t.childPIDNamespace = t.tg.pidns.NewChild(t.UserNamespace())
+ }
+ t.mu.Lock()
+ // Can't defer unlock: DecRefs must occur without holding t.mu.
+ if opts.NewNetworkNamespace {
+ if !haveCapSysAdmin {
+ t.mu.Unlock()
+ return syserror.EPERM
+ }
+ t.netns = inet.NewNamespace(t.netns)
+ }
+ if opts.NewUTSNamespace {
+ if !haveCapSysAdmin {
+ t.mu.Unlock()
+ return syserror.EPERM
+ }
+ // Note that this must happen after NewUserNamespace, so the
+ // new user namespace is used if there is one.
+ t.utsns = t.utsns.Clone(creds.UserNamespace)
+ }
+ if opts.NewIPCNamespace {
+ if !haveCapSysAdmin {
+ t.mu.Unlock()
+ return syserror.EPERM
+ }
+ // Note that "If CLONE_NEWIPC is set, then create the process in a new IPC
+ // namespace"
+ t.ipcns = NewIPCNamespace(creds.UserNamespace)
+ }
+ var oldFDTable *FDTable
+ if opts.NewFiles {
+ oldFDTable = t.fdTable
+ t.fdTable = oldFDTable.Fork()
+ }
+ var oldFSContext *FSContext
+ if opts.NewFSContext {
+ oldFSContext = t.fsContext
+ t.fsContext = oldFSContext.Fork()
+ }
+ t.mu.Unlock()
+ if oldFDTable != nil {
+ oldFDTable.DecRef()
+ }
+ if oldFSContext != nil {
+ oldFSContext.DecRef()
+ }
+ return nil
+}
+
+// vforkStop is a TaskStop imposed on a task that creates a child with
+// CLONE_VFORK or vfork(2), that ends when the child task ceases to use its
+// current MM. (Normally, CLONE_VFORK is used in conjunction with CLONE_VM, so
+// that the child and parent share mappings until the child execve()s into a
+// new process image or exits.)
+//
+// +stateify savable
+type vforkStop struct{}
+
+// Killable implements TaskStop.Killable.
+func (*vforkStop) Killable() bool { return true }
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
new file mode 100644
index 000000000..9fa528384
--- /dev/null
+++ b/pkg/sentry/kernel/task_context.go
@@ -0,0 +1,169 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC)
+
+// Auxmap contains miscellaneous data for the task.
+type Auxmap map[string]interface{}
+
+// TaskContext is the subset of a task's data that is provided by the loader.
+//
+// +stateify savable
+type TaskContext struct {
+ // Name is the thread name set by the prctl(PR_SET_NAME) system call.
+ Name string
+
+ // Arch is the architecture-specific context (registers, etc.).
+ Arch arch.Context
+
+ // MemoryManager is the task's address space.
+ MemoryManager *mm.MemoryManager
+
+ // fu implements futexes in the address space.
+ fu *futex.Manager
+
+ // st is the task's syscall table.
+ st *SyscallTable `state:".(syscallTableInfo)"`
+}
+
+// release releases all resources held by the TaskContext. release is called by
+// the task when it execs into a new TaskContext or exits.
+func (tc *TaskContext) release() {
+ // Nil out pointers so that if the task is saved after release, it doesn't
+ // follow the pointers to possibly now-invalid objects.
+ if tc.MemoryManager != nil {
+ tc.MemoryManager.DecUsers(context.Background())
+ tc.MemoryManager = nil
+ }
+ tc.fu = nil
+}
+
+// Fork returns a duplicate of tc. The copied TaskContext always has an
+// independent arch.Context. If shareAddressSpace is true, the copied
+// TaskContext shares an address space with the original; otherwise, the copied
+// TaskContext has an independent address space that is initially a duplicate
+// of the original's.
+func (tc *TaskContext) Fork(ctx context.Context, k *Kernel, shareAddressSpace bool) (*TaskContext, error) {
+ newTC := &TaskContext{
+ Name: tc.Name,
+ Arch: tc.Arch.Fork(),
+ st: tc.st,
+ }
+ if shareAddressSpace {
+ newTC.MemoryManager = tc.MemoryManager
+ if newTC.MemoryManager != nil {
+ if !newTC.MemoryManager.IncUsers() {
+ // Shouldn't be possible since tc.MemoryManager should be a
+ // counted user.
+ panic("TaskContext.Fork called with userless TaskContext.MemoryManager")
+ }
+ }
+ newTC.fu = tc.fu
+ } else {
+ newMM, err := tc.MemoryManager.Fork(ctx)
+ if err != nil {
+ return nil, err
+ }
+ newTC.MemoryManager = newMM
+ newTC.fu = k.futexes.Fork()
+ }
+ return newTC, nil
+}
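+
+// In Task.Clone (task_clone.go), shareAddressSpace is passed as
+// !opts.NewAddressSpace: a CLONE_VM-style clone shares the MM (and the futex
+// manager keyed on it), while fork(2) gets a copy-on-write duplicate.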
+
+// Arch returns t's arch.Context.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) Arch() arch.Context {
+ return t.tc.Arch
+}
+
+// MemoryManager returns t's MemoryManager. MemoryManager does not take an
+// additional reference on the returned MM.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) MemoryManager() *mm.MemoryManager {
+ return t.tc.MemoryManager
+}
+
+// SyscallTable returns t's syscall table.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) SyscallTable() *SyscallTable {
+ return t.tc.st
+}
+
+// Stack returns the userspace stack.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) Stack() *arch.Stack {
+ return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())}
+}
+
+// LoadTaskImage loads a specified file into a new TaskContext.
+//
+// args.MemoryManager does not need to be set by the caller.
+func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) {
+ // If File is not nil, we should load that instead of resolving Filename.
+ if args.File != nil {
+ args.Filename = args.File.PathnameWithDeleted(ctx)
+ }
+
+ // Prepare a new user address space to load into.
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
+ defer m.DecUsers(ctx)
+ args.MemoryManager = m
+
+ os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso)
+ if err != nil {
+ return nil, err
+ }
+
+ // Look up our new syscall table.
+ st, ok := LookupSyscallTable(os, ac.Arch())
+ if !ok {
+ // No syscall table found. This means that the ELF binary does not match
+ // the architecture.
+ return nil, errNoSyscalls
+ }
+
+ if !m.IncUsers() {
+ panic("Failed to increment users count on new MM")
+ }
+ return &TaskContext{
+ Name: name,
+ Arch: ac,
+ MemoryManager: m,
+ fu: k.futexes.Fork(),
+ st: st,
+ }, nil
+}
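+
+// A hedged usage sketch (the argument wiring is illustrative and elides the
+// Opener and mount-resolution fields; the real call site is the execve(2)
+// implementation):
+//
+//  tc, se := t.Kernel().LoadTaskImage(t, loader.LoadArgs{
+//      Filename: pathname,
+//      Argv:     argv,
+//      Envv:     envv,
+//  })
+//  if se != nil {
+//      return 0, nil, se.ToError()
+//  }
+//  ctrl, err := t.Execve(tc) // see task_exec.go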
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
new file mode 100644
index 000000000..9b69f3cbe
--- /dev/null
+++ b/pkg/sentry/kernel/task_exec.go
@@ -0,0 +1,277 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file implements the machinery behind the execve() syscall. In brief, a
+// thread executes an execve() by killing all other threads in its thread
+// group, assuming the leader's identity, and then switching process images.
+//
+// This design is effectively mandated by Linux. From ptrace(2):
+//
+// """
+// execve(2) under ptrace
+// When one thread in a multithreaded process calls execve(2), the
+// kernel destroys all other threads in the process, and resets the
+// thread ID of the execing thread to the thread group ID (process ID).
+// (Or, to put things another way, when a multithreaded process does an
+// execve(2), at completion of the call, it appears as though the
+// execve(2) occurred in the thread group leader, regardless of which
+// thread did the execve(2).) This resetting of the thread ID looks
+// very confusing to tracers:
+//
+// * All other threads stop in PTRACE_EVENT_EXIT stop, if the
+// PTRACE_O_TRACEEXIT option was turned on. Then all other threads
+// except the thread group leader report death as if they exited via
+// _exit(2) with exit code 0.
+//
+// * The execing tracee changes its thread ID while it is in the
+// execve(2). (Remember, under ptrace, the "pid" returned from
+// waitpid(2), or fed into ptrace calls, is the tracee's thread ID.)
+// That is, the tracee's thread ID is reset to be the same as its
+// process ID, which is the same as the thread group leader's thread
+// ID.
+//
+// * Then a PTRACE_EVENT_EXEC stop happens, if the PTRACE_O_TRACEEXEC
+// option was turned on.
+//
+// * If the thread group leader has reported its PTRACE_EVENT_EXIT stop
+// by this time, it appears to the tracer that the dead thread leader
+// "reappears from nowhere". (Note: the thread group leader does not
+// report death via WIFEXITED(status) until there is at least one
+// other live thread. This eliminates the possibility that the
+// tracer will see it dying and then reappearing.) If the thread
+// group leader was still alive, for the tracer this may look as if
+// thread group leader returns from a different system call than it
+// entered, or even "returned from a system call even though it was
+// not in any system call". If the thread group leader was not
+// traced (or was traced by a different tracer), then during
+// execve(2) it will appear as if it has become a tracee of the
+// tracer of the execing tracee.
+//
+// All of the above effects are the artifacts of the thread ID change in
+// the tracee.
+// """
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// execStop is a TaskStop that a task sets on itself when it wants to execve
+// and is waiting for the other tasks in its thread group to exit first.
+//
+// +stateify savable
+type execStop struct{}
+
+// Killable implements TaskStop.Killable.
+func (*execStop) Killable() bool { return true }
+
+// Execve implements the execve(2) syscall by killing all other tasks in its
+// thread group and switching to newTC. Execve always takes ownership of newTC.
+//
+// Preconditions: The caller must be running Task.doSyscallInvoke on the task
+// goroutine.
+func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+
+ if t.tg.exiting || t.tg.execing != nil {
+ // We lost to a racing group-exit, kill, or exec from another thread
+ // and should just exit.
+ newTC.release()
+ return nil, syserror.EINTR
+ }
+
+ // Cancel any racing group stops.
+ t.tg.endGroupStopLocked(false)
+
+ // If the task has any siblings, they have to exit before the exec can
+ // continue.
+ t.tg.execing = t
+ if t.tg.tasks.Front() != t.tg.tasks.Back() {
+ // "[All] other threads except the thread group leader report death as
+ // if they exited via _exit(2) with exit code 0." - ptrace(2)
+ for sibling := t.tg.tasks.Front(); sibling != nil; sibling = sibling.Next() {
+ if t != sibling {
+ sibling.killLocked()
+ }
+ }
+ // The last sibling to exit will wake t.
+ t.beginInternalStopLocked((*execStop)(nil))
+ }
+
+ return &SyscallControl{next: &runSyscallAfterExecStop{newTC}, ignoreReturn: true}, nil
+}
+
+// The runSyscallAfterExecStop state continues execve(2) after all siblings of
+// a thread in the execve syscall have exited.
+//
+// +stateify savable
+type runSyscallAfterExecStop struct {
+ tc *TaskContext
+}
+
+func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
+ t.traceExecEvent(r.tc)
+ t.tg.pidns.owner.mu.Lock()
+ t.tg.execing = nil
+ if t.killed() {
+ t.tg.pidns.owner.mu.Unlock()
+ r.tc.release()
+ return (*runInterrupt)(nil)
+ }
+ // We are the thread group leader now. Save our old thread ID for
+ // PTRACE_EVENT_EXEC. This is racy in that if a tracer attaches after this
+ // point it will get a PID of 0, but this is consistent with Linux.
+ oldTID := ThreadID(0)
+ if tracer := t.Tracer(); tracer != nil {
+ oldTID = tracer.tg.pidns.tids[t]
+ }
+ t.promoteLocked()
+ // "POSIX timers are not preserved (timer_create(2))." - execve(2). Handle
+ // this first since POSIX timers are protected by the signal mutex, which
+ // we're about to change. Note that we have to stop and destroy timers
+ // without holding any mutexes to avoid circular lock ordering.
+ var its []*IntervalTimer
+ t.tg.signalHandlers.mu.Lock()
+ for _, it := range t.tg.timers {
+ its = append(its, it)
+ }
+ t.tg.timers = make(map[linux.TimerID]*IntervalTimer)
+ t.tg.signalHandlers.mu.Unlock()
+ t.tg.pidns.owner.mu.Unlock()
+ for _, it := range its {
+ it.DestroyTimer()
+ }
+ t.tg.pidns.owner.mu.Lock()
+ // "During an execve(2), the dispositions of handled signals are reset to
+ // the default; the dispositions of ignored signals are left unchanged. ...
+ // [The] signal mask is preserved across execve(2). ... [The] pending
+ // signal set is preserved across an execve(2)." - signal(7)
+ //
+ // Details:
+ //
+ // - If the thread group is sharing its signal handlers with another thread
+ // group via CLONE_SIGHAND, execve forces the signal handlers to be copied
+ // (see Linux's fs/exec.c:de_thread). We're not reference-counting signal
+ // handlers, so we always make a copy.
+ //
+ // - "Disposition" only means sigaction::sa_handler/sa_sigaction; flags,
+ // restorer (if present), and mask are always reset. (See Linux's
+ // fs/exec.c:setup_new_exec => kernel/signal.c:flush_signal_handlers.)
+ t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec()
+ t.endStopCond.L = &t.tg.signalHandlers.mu
+ // "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2)
+ t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable}
+ // "The termination signal is reset to SIGCHLD (see clone(2))."
+ t.tg.terminationSignal = linux.SIGCHLD
+ // execed indicates that the process can no longer join a process group
+ // in some scenarios (namely, when the parent calls setpgid(2) on the
+ // child). See the JoinProcessGroup function in sessions.go for more
+ // context.
+ t.tg.execed = true
+ // Maximum RSS is preserved across execve(2).
+ t.updateRSSLocked()
+ // Restartable sequence state is discarded.
+ t.rseqPreempted = false
+ t.rseqCPU = -1
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+ t.oldRSeqCPUAddr = 0
+ t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
+ t.tg.pidns.owner.mu.Unlock()
+
+ oldFDTable := t.fdTable
+ t.fdTable = t.fdTable.Fork()
+ oldFDTable.DecRef()
+
+ // Remove FDs with the CloseOnExec flag set.
+ t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
+ return flags.CloseOnExec
+ })
+
+ // NOTE(b/30815691): We currently do not implement privileged
+ // executables (set-user/group-ID bits and file capabilities). This
+ // allows us to unconditionally enable user dumpability on the new mm.
+ // See fs/exec.c:setup_new_exec.
+ r.tc.MemoryManager.SetDumpability(mm.UserDumpable)
+
+ // Switch to the new process.
+ t.MemoryManager().Deactivate()
+ t.mu.Lock()
+ // Update credentials to reflect the execve. This should precede switching
+ // MMs to ensure that dumpability has been reset first, if needed.
+ t.updateCredsForExecLocked()
+ t.tc.release()
+ t.tc = *r.tc
+ t.mu.Unlock()
+ t.unstopVforkParent()
+ // NOTE(b/30316266): All locks must be dropped prior to calling Activate.
+ t.MemoryManager().Activate(t)
+
+ t.ptraceExec(oldTID)
+ return (*runSyscallExit)(nil)
+}
+
+// promoteLocked makes t the leader of its thread group. If t is already the
+// thread group leader, promoteLocked is a no-op.
+//
+// Preconditions: All other tasks in t's thread group, including the existing
+// leader (if it is not t), have reached TaskExitZombie. The TaskSet mutex must
+// be locked for writing.
+func (t *Task) promoteLocked() {
+ oldLeader := t.tg.leader
+ if t == oldLeader {
+ return
+ }
+ // Swap the leader's TIDs with the execing task's. The latter will be
+ // released when the old leader is reaped below.
+ for ns := t.tg.pidns; ns != nil; ns = ns.parent {
+ oldTID, leaderTID := ns.tids[t], ns.tids[oldLeader]
+ ns.tids[oldLeader] = oldTID
+ ns.tids[t] = leaderTID
+ ns.tasks[oldTID] = oldLeader
+ ns.tasks[leaderTID] = t
+ // Neither the ThreadGroup nor TGID change, so no need to
+ // update ns.tgids.
+ }
+
+ // Inherit the old leader's start time.
+ oldStartTime := oldLeader.StartTime()
+ t.mu.Lock()
+ t.startTime = oldStartTime
+ t.mu.Unlock()
+
+ t.tg.leader = t
+ t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t])
+ t.updateInfoLocked()
+ // Reap the original leader. If it has a tracer, detach it instead of
+ // waiting for it to acknowledge the original leader's death.
+ oldLeader.exitParentNotified = true
+ oldLeader.exitParentAcked = true
+ if tracer := oldLeader.Tracer(); tracer != nil {
+ delete(tracer.ptraceTracees, oldLeader)
+ oldLeader.forgetTracerLocked()
+ // Notify the tracer that it will no longer be receiving these events
+ // from the tracee.
+ tracer.tg.eventQueue.Notify(EventExit | EventTraceeStop | EventGroupContinue)
+ }
+ oldLeader.exitNotifyLocked(false)
+}
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
new file mode 100644
index 000000000..c4ade6e8e
--- /dev/null
+++ b/pkg/sentry/kernel/task_exit.go
@@ -0,0 +1,1167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file implements the task exit cycle:
+//
+// - Tasks are asynchronously requested to exit with Task.Kill.
+//
+// - When able, the task goroutine enters the exit path starting from state
+// runExit.
+//
+// - Other tasks observe completed exits with Task.Wait (which implements the
+// wait*() family of syscalls).
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// An ExitStatus is a value communicated from an exiting task or thread group
+// to the party that reaps it.
+//
+// +stateify savable
+type ExitStatus struct {
+ // Code is the numeric value passed to the call to exit or exit_group that
+ // caused the exit. If the exit was not caused by such a call, Code is 0.
+ Code int
+
+ // Signo is the signal that caused the exit. If the exit was not caused by
+ // a signal, Signo is 0.
+ Signo int
+}
+
+// Signaled returns true if the ExitStatus indicates that the exiting task or
+// thread group was killed by a signal.
+func (es ExitStatus) Signaled() bool {
+ return es.Signo != 0
+}
+
+// Status returns the numeric representation of the ExitStatus returned by e.g.
+// the wait4() system call.
+func (es ExitStatus) Status() uint32 {
+ return ((uint32(es.Code) & 0xff) << 8) | (uint32(es.Signo) & 0xff)
+}
+
+// ShellExitCode returns the numeric exit code that Bash would return for an
+// exit status of es.
+func (es ExitStatus) ShellExitCode() int {
+ if es.Signaled() {
+ return 128 + es.Signo
+ }
+ return es.Code
+}
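+
+// For illustration only (a sketch using the methods above; the values are
+// examples, not new API):
+//
+//	es := ExitStatus{Code: 1} // from exit_group(2) with status 1
+//	es.Status()               // 0x0100: WIFEXITED, WEXITSTATUS == 1
+//	es.ShellExitCode()        // 1
+//
+//	es = ExitStatus{Signo: 9} // killed by SIGKILL
+//	es.Status()               // 0x0009: WIFSIGNALED, WTERMSIG == 9
+//	es.ShellExitCode()        // 137 (128 + 9), as Bash reports it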
+
+// TaskExitState represents a step in the task exit path.
+//
+// "Exiting" and "exited" are often ambiguous; prefer to name specific states.
+type TaskExitState int
+
+const (
+ // TaskExitNone indicates that the task has not begun exiting.
+ TaskExitNone TaskExitState = iota
+
+ // TaskExitInitiated indicates that the task goroutine has entered the exit
+ // path, and the task is no longer eligible to participate in group stops
+ // or group signal handling. TaskExitInitiated is analogous to Linux's
+ // PF_EXITING.
+ TaskExitInitiated
+
+ // TaskExitZombie indicates that the task has released its resources, and
+ // the task no longer prevents a sibling thread from completing execve.
+ TaskExitZombie
+
+ // TaskExitDead indicates that the task's thread IDs have been released,
+ // and the task no longer prevents its thread group leader from being
+ // reaped. ("Reaping" refers to the transitioning of a task from
+ // TaskExitZombie to TaskExitDead.)
+ TaskExitDead
+)
+
+// String implements fmt.Stringer.
+func (t TaskExitState) String() string {
+ switch t {
+ case TaskExitNone:
+ return "TaskExitNone"
+ case TaskExitInitiated:
+ return "TaskExitInitiated"
+ case TaskExitZombie:
+ return "TaskExitZombie"
+ case TaskExitDead:
+ return "TaskExitDead"
+ default:
+ return strconv.Itoa(int(t))
+ }
+}
+
+// killLocked marks t as killed by enqueueing a SIGKILL, without causing the
+// thread-group-affecting side effects SIGKILL usually has.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) killLocked() {
+ // Clear killable stops.
+ if t.stop != nil && t.stop.Killable() {
+ t.endInternalStopLocked()
+ }
+ t.pendingSignals.enqueue(&arch.SignalInfo{
+ Signo: int32(linux.SIGKILL),
+ // Linux just sets SIGKILL in the pending signal bitmask without
+ // enqueueing an actual siginfo, such that
+ // kernel/signal.c:collect_signal() initializes si_code to SI_USER.
+ Code: arch.SignalInfoUser,
+ }, nil)
+ t.interrupt()
+}
+
+// killed returns true if t has a SIGKILL pending. killed is analogous to
+// Linux's fatal_signal_pending().
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) killed() bool {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.killedLocked()
+}
+
+func (t *Task) killedLocked() bool {
+ return t.pendingSignals.pendingSet&linux.SignalSetOf(linux.SIGKILL) != 0
+}
+
+// PrepareExit indicates an exit with status es.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) PrepareExit(es ExitStatus) {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.exitStatus = es
+}
+
+// PrepareGroupExit indicates a group exit with status es to t's thread group.
+//
+// PrepareGroupExit is analogous to Linux's do_group_exit(), except that it
+// does not tail-call do_exit(), and that it *does* set Task.exitStatus.
+// (Linux does not do so until within do_exit(), since it reuses exit_code for
+// ptrace.)
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) PrepareGroupExit(es ExitStatus) {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.tg.exiting || t.tg.execing != nil {
+ // Note that if t.tg.exiting is false but t.tg.execing is not nil, i.e.
+ // this "group exit" is being executed by the killed sibling of an
+ // execing task, then Task.Execve never set t.tg.exitStatus, so it's
+ // still the zero value. This is consistent with Linux, both in intent
+ // ("all other threads ... report death as if they exited via _exit(2)
+ // with exit code 0" - ptrace(2), "execve under ptrace") and in
+ // implementation (compare fs/exec.c:de_thread() =>
+ // kernel/signal.c:zap_other_threads() and
+ // kernel/exit.c:do_group_exit() =>
+ // include/linux/sched.h:signal_group_exit()).
+ t.exitStatus = t.tg.exitStatus
+ return
+ }
+ t.tg.exiting = true
+ t.tg.exitStatus = es
+ t.exitStatus = es
+ for sibling := t.tg.tasks.Front(); sibling != nil; sibling = sibling.Next() {
+ if sibling != t {
+ sibling.killLocked()
+ }
+ }
+}
+
+// Kill requests that all tasks in ts exit as if group exiting with status es.
+// Kill does not wait for tasks to exit.
+//
+// Kill has no analogue in Linux; it's provided for save/restore only.
+func (ts *TaskSet) Kill(es ExitStatus) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.Root.exiting = true
+ for t := range ts.Root.tids {
+ t.tg.signalHandlers.mu.Lock()
+ if !t.tg.exiting {
+ t.tg.exiting = true
+ t.tg.exitStatus = es
+ }
+ t.killLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ }
+}
+
+// advanceExitStateLocked checks that t's current exit state is oldExit, then
+// sets it to newExit. If t's current exit state is not oldExit,
+// advanceExitStateLocked panics.
+//
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) advanceExitStateLocked(oldExit, newExit TaskExitState) {
+ if t.exitState != oldExit {
+ panic(fmt.Sprintf("Transitioning from exit state %v to %v: unexpected preceding state %v", oldExit, newExit, t.exitState))
+ }
+ t.Debugf("Transitioning from exit state %v to %v", oldExit, newExit)
+ t.exitState = newExit
+}
+
+// runExit is the entry point into the task exit path.
+//
+// +stateify savable
+type runExit struct{}
+
+func (*runExit) execute(t *Task) taskRunState {
+ t.ptraceExit()
+ return (*runExitMain)(nil)
+}
+
+// +stateify savable
+type runExitMain struct{}
+
+func (*runExitMain) execute(t *Task) taskRunState {
+ t.traceExitEvent()
+ lastExiter := t.exitThreadGroup()
+
+ // If the task has a cleartid, and the thread group wasn't killed by a
+ // signal, handle that before releasing the MM.
+ if t.cleartid != 0 {
+ t.tg.signalHandlers.mu.Lock()
+ signaled := t.tg.exiting && t.tg.exitStatus.Signaled()
+ t.tg.signalHandlers.mu.Unlock()
+ if !signaled {
+ if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil {
+ t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1)
+ }
+ // If the CopyOut fails, there's nothing we can do.
+ }
+ }
+
+ // Deactivate the address space and update max RSS before releasing the
+ // task's MM.
+ t.Deactivate()
+ t.tg.pidns.owner.mu.Lock()
+ t.updateRSSLocked()
+ t.tg.pidns.owner.mu.Unlock()
+ t.mu.Lock()
+ t.tc.release()
+ t.mu.Unlock()
+
+ // Releasing the MM unblocks a blocked CLONE_VFORK parent.
+ t.unstopVforkParent()
+
+ t.fsContext.DecRef()
+ t.fdTable.DecRef()
+
+ t.mu.Lock()
+ if t.mountNamespaceVFS2 != nil {
+ t.mountNamespaceVFS2.DecRef()
+ t.mountNamespaceVFS2 = nil
+ }
+ t.mu.Unlock()
+
+ // If this is the last task to exit from the thread group, release the
+ // thread group's resources.
+ if lastExiter {
+ t.tg.release()
+ }
+
+ // Detach tracees.
+ t.exitPtrace()
+
+ // Reparent the task's children.
+ t.exitChildren()
+
+ // Don't tail-call runExitNotify, as exitChildren may have initiated a stop
+ // to wait for a PID namespace to die.
+ return (*runExitNotify)(nil)
+}
+
+// exitThreadGroup transitions t to TaskExitInitiated, indicating to t's thread
+// group that it is no longer eligible to participate in group activities. It
+// returns true if t is the last task in its thread group to call
+// exitThreadGroup.
+func (t *Task) exitThreadGroup() bool {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.tg.signalHandlers.mu.Lock()
+ // Can't defer unlock: see below.
+
+ t.advanceExitStateLocked(TaskExitNone, TaskExitInitiated)
+ t.tg.activeTasks--
+ last := t.tg.activeTasks == 0
+
+ // Ensure that someone will handle the signals we can't.
+ t.setSignalMaskLocked(^linux.SignalSet(0))
+
+ // Check if this task's exit interacts with an initiated group stop.
+ if !t.groupStopPending {
+ t.tg.signalHandlers.mu.Unlock()
+ return last
+ }
+ t.groupStopPending = false
+ sig := t.tg.groupStopSignal
+ notifyParent := t.participateGroupStopLocked()
+ // signalStop must be called with t's signal mutex unlocked.
+ t.tg.signalHandlers.mu.Unlock()
+ if notifyParent && t.tg.leader.parent != nil {
+ t.tg.leader.parent.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
+ }
+ return last
+}
+
+func (t *Task) exitChildren() {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ newParent := t.findReparentTargetLocked()
+ if newParent == nil {
+ // "If the init process of a PID namespace terminates, the kernel
+ // terminates all of the processes in the namespace via a SIGKILL
+ // signal." - pid_namespaces(7)
+ t.Debugf("Init process terminating, killing namespace")
+ t.tg.pidns.exiting = true
+ for other := range t.tg.pidns.tgids {
+ if other == t.tg {
+ continue
+ }
+ other.signalHandlers.mu.Lock()
+ other.leader.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGKILL),
+ }, true /* group */)
+ other.signalHandlers.mu.Unlock()
+ }
+ // TODO(b/37722272): The init process waits for all processes in the
+ // namespace to exit before completing its own exit
+ // (kernel/pid_namespace.c:zap_pid_ns_processes()). Stop until all
+ // other tasks in the namespace are dead, except possibly for this
+ // thread group's leader (which can't be reaped until this task exits).
+ }
+ // This is correct even if newParent is nil (it ensures that children don't
+ // wait for a parent to reap them).
+ for c := range t.children {
+ if sig := c.ParentDeathSignal(); sig != 0 {
+ siginfo := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ siginfo.SetPid(int32(c.tg.pidns.tids[t]))
+ siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
+ c.tg.signalHandlers.mu.Lock()
+ c.sendSignalLocked(siginfo, true /* group */)
+ c.tg.signalHandlers.mu.Unlock()
+ }
+ c.reparentLocked(newParent)
+ if newParent != nil {
+ newParent.children[c] = struct{}{}
+ }
+ }
+}
+
+// findReparentTargetLocked returns the task to which t's children should be
+// reparented. If no such task exists, findReparentTargetLocked returns nil.
+//
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) findReparentTargetLocked() *Task {
+ // Reparent to any sibling in the same thread group that hasn't begun
+ // exiting.
+ if t2 := t.tg.anyNonExitingTaskLocked(); t2 != nil {
+ return t2
+ }
+ // "A child process that is orphaned within the namespace will be
+ // reparented to [the init process for the namespace] ..." -
+ // pid_namespaces(7)
+ if init := t.tg.pidns.tasks[InitTID]; init != nil {
+ return init.tg.anyNonExitingTaskLocked()
+ }
+ return nil
+}
+
+func (tg *ThreadGroup) anyNonExitingTaskLocked() *Task {
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ if t.exitState == TaskExitNone {
+ return t
+ }
+ }
+ return nil
+}
+
+// reparentLocked changes t's parent. The new parent may be nil.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) reparentLocked(parent *Task) {
+ oldParent := t.parent
+ t.parent = parent
+ // If a thread group leader's parent changes, reset the thread group's
+ // termination signal to SIGCHLD and re-check exit notification. (Compare
+ // kernel/exit.c:reparent_leader().)
+ if t != t.tg.leader {
+ return
+ }
+ if oldParent == nil && parent == nil {
+ return
+ }
+ if oldParent != nil && parent != nil && oldParent.tg == parent.tg {
+ return
+ }
+ t.tg.terminationSignal = linux.SIGCHLD
+ if t.exitParentNotified && !t.exitParentAcked {
+ t.exitParentNotified = false
+ t.exitNotifyLocked(false)
+ }
+}
+
+// When a task exits, other tasks in the system, notably the task's parent and
+// ptracer, may want to be notified. The exit notification system ensures that
+// interested tasks receive signals and/or are woken from blocking calls to
+// wait*() syscalls; these notifications must be resolved before exiting tasks
+// can be reaped and disappear from the system.
+//
+// Each task may have a parent task and/or a tracer task. If both a parent and
+// a tracer exist, they may be the same task, different tasks in the same
+// thread group, or tasks in different thread groups. (In the last case, Linux
+// refers to the task as being ptrace-reparented due to an implementation
+// detail; we avoid this terminology to avoid confusion.)
+//
+// A thread group is *empty* if all non-leader tasks in the thread group are
+// dead, and the leader is either a zombie or dead. The exit of a thread group
+// leader is never waitable - by either the parent or tracer - until the thread
+// group is empty.
+//
+// There are a few ways for an exit notification to be resolved:
+//
+// - The exit notification may be acknowledged by a call to Task.Wait with
+// WaitOptions.ConsumeEvent set (e.g. due to a wait4() syscall).
+//
+// - If the notified party is the parent, and the parent thread group is not
+// also the tracer thread group, and the notification signal is SIGCHLD, the
+// parent may explicitly ignore the notification (see quote in exitNotify).
+// Note that it's possible for the notified party to ignore the signal in other
+// cases, but the notification is only resolved under the above conditions.
+// (Actually, there is one exception; see the last paragraph of the "leader,
+// has tracer, tracer thread group is parent thread group" case below.)
+//
+// - If the notified party is the parent, and the parent does not exist, the
+// notification is resolved as if ignored. (This is only possible in the
+// sentry. In Linux, the only task / thread group without a parent is global
+// init, and killing global init causes a kernel panic.)
+//
+// - If the notified party is a tracer, the tracer may detach the traced task.
+// (Zombie tasks cannot be ptrace-attached, so the reverse is not possible.)
+//
+// In addition, if the notified party is the parent, the parent may exit and
+// cause the notifying task to be reparented to another thread group. This does
+// not resolve the notification; instead, the notification must be resent to
+// the new parent.
+//
+// The series of notifications generated for a given task's exit depend on
+// whether it is a thread group leader; whether the task is ptraced; and, if
+// so, whether the tracer thread group is the same as the parent thread group.
+//
+// - Non-leader, no tracer: No notification is generated; the task is reaped
+// immediately.
+//
+// - Non-leader, has tracer: SIGCHLD is sent to the tracer. When the tracer
+// notification is resolved (by waiting or detaching), the task is reaped. (For
+// non-leaders, whether the tracer and parent thread groups are the same is
+// irrelevant.)
+//
+// - Leader, no tracer: The task remains a zombie, with no notification sent,
+// until all other tasks in the thread group are dead. (In Linux terms, this
+// condition is indicated by include/linux/sched.h:thread_group_empty(); tasks
+// are removed from their thread_group list in kernel/exit.c:release_task() =>
+// __exit_signal() => __unhash_process().) Then the thread group's termination
+// signal is sent to the parent. When the parent notification is resolved (by
+// waiting or ignoring), the task is reaped.
+//
+// - Leader, has tracer, tracer thread group is not parent thread group:
+// SIGCHLD is sent to the tracer. When the tracer notification is resolved (by
+// waiting or detaching), and all other tasks in the thread group are dead, the
+// thread group's termination signal is sent to the parent. (Note that the
+// tracer cannot resolve the exit notification by waiting until the thread
+// group is empty.) When the parent notification is resolved, the task is
+// reaped.
+//
+// - Leader, has tracer, tracer thread group is parent thread group:
+//
+// If all other tasks in the thread group are dead, the thread group's
+// termination signal is sent to the parent. At this point, the notification
+// can only be resolved by waiting. If the parent detaches from the task as a
+// tracer, the notification is not resolved, but the notification can now be
+// resolved by waiting or ignoring. When the parent notification is resolved,
+// the task is reaped.
+//
+// If at least one task in the thread group is not dead, SIGCHLD is sent to the
+// parent. At this point, the notification cannot be resolved at all; once the
+// thread group becomes empty, it can be resolved only by waiting. If the
+// parent detaches from the task as a tracer before all remaining tasks die,
+// then exit notification proceeds as in the case where the leader never had a
+// tracer. If the parent detaches from the task as a tracer after all remaining
+// tasks die, the notification is not resolved, but the notification can now be
+// resolved by waiting or ignoring. When the parent notification is resolved,
+// the task is reaped.
+//
+// In both of the above cases, when the parent detaches from the task as a
+// tracer while the thread group is empty, whether or not the parent resolves
+// the notification by ignoring it is based on the parent's SIGCHLD signal
+// action, whether or not the thread group's termination signal is SIGCHLD
+// (Linux: kernel/ptrace.c:__ptrace_detach() => ignoring_children()).
+//
+// There is one final wrinkle: A leader can become a non-leader due to a
+// sibling execve. In this case, the execing thread detaches the leader's
+// tracer (if one exists) and reaps the leader immediately. In Linux, this is
+// in fs/exec.c:de_thread(); in the sentry, this is in Task.promoteLocked().
+
+// +stateify savable
+type runExitNotify struct{}
+
+func (*runExitNotify) execute(t *Task) taskRunState {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.advanceExitStateLocked(TaskExitInitiated, TaskExitZombie)
+ t.tg.liveTasks--
+ // Check if this completes a sibling's execve.
+ if t.tg.execing != nil && t.tg.liveTasks == 1 {
+ // execing blocks the addition of new tasks to the thread group, so
+ // the sole living task must be the execing one.
+ e := t.tg.execing
+ e.tg.signalHandlers.mu.Lock()
+ if _, ok := e.stop.(*execStop); ok {
+ e.endInternalStopLocked()
+ }
+ e.tg.signalHandlers.mu.Unlock()
+ }
+ t.exitNotifyLocked(false)
+ // The task goroutine will now exit.
+ return nil
+}
+
+// exitNotifyLocked is called after changes to t's state that affect exit
+// notification.
+//
+// If fromPtraceDetach is true, the caller is ptraceDetach or exitPtrace;
+// thanks to Linux's haphazard implementation of this functionality, such cases
+// determine whether parent notifications are ignored based on the parent's
+// handling of SIGCHLD, regardless of what the exited task's thread group's
+// termination signal is.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) exitNotifyLocked(fromPtraceDetach bool) {
+ if t.exitState != TaskExitZombie {
+ return
+ }
+ if !t.exitTracerNotified {
+ t.exitTracerNotified = true
+ tracer := t.Tracer()
+ if tracer == nil {
+ t.exitTracerAcked = true
+ } else if t != t.tg.leader || t.parent == nil || tracer.tg != t.parent.tg {
+ // Don't set exitParentNotified if t is non-leader, even if the
+ // tracer is in the parent thread group, so that if the parent
+ // detaches, the following call to exitNotifyLocked passes through
+ // the !exitParentNotified case below and causes t to be reaped
+ // immediately.
+ //
+ // Tracer notification doesn't care about SIG_IGN/SA_NOCLDWAIT.
+ tracer.tg.signalHandlers.mu.Lock()
+ tracer.sendSignalLocked(t.exitNotificationSignal(linux.SIGCHLD, tracer), true /* group */)
+ tracer.tg.signalHandlers.mu.Unlock()
+ // Wake EventTraceeStop waiters as well since this task will never
+ // ptrace-stop again.
+ tracer.tg.eventQueue.Notify(EventExit | EventTraceeStop)
+ } else {
+ // t is a leader and the tracer is in the parent thread group.
+ t.exitParentNotified = true
+ sig := linux.SIGCHLD
+ if t.tg.tasksCount == 1 {
+ sig = t.tg.terminationSignal
+ }
+ // This notification doesn't care about SIG_IGN/SA_NOCLDWAIT either
+ // (in Linux, the check in do_notify_parent() is gated by
+ // !tsk->ptrace.)
+ t.parent.tg.signalHandlers.mu.Lock()
+ t.parent.sendSignalLocked(t.exitNotificationSignal(sig, t.parent), true /* group */)
+ t.parent.tg.signalHandlers.mu.Unlock()
+ // See below for rationale for this event mask.
+ t.parent.tg.eventQueue.Notify(EventExit | EventChildGroupStop | EventGroupContinue)
+ }
+ }
+ if t.exitTracerAcked && !t.exitParentNotified {
+ if t != t.tg.leader {
+ t.exitParentNotified = true
+ t.exitParentAcked = true
+ } else if t.tg.tasksCount == 1 {
+ t.exitParentNotified = true
+ if t.parent == nil {
+ t.exitParentAcked = true
+ } else {
+ // "POSIX.1-2001 specifies that if the disposition of SIGCHLD is
+ // set to SIG_IGN or the SA_NOCLDWAIT flag is set for SIGCHLD (see
+ // sigaction(2)), then children that terminate do not become
+ // zombies and a call to wait() or waitpid() will block until all
+ // children have terminated, and then fail with errno set to
+ // ECHILD. (The original POSIX standard left the behavior of
+ // setting SIGCHLD to SIG_IGN unspecified. Note that even though
+ // the default disposition of SIGCHLD is "ignore", explicitly
+ // setting the disposition to SIG_IGN results in different
+ // treatment of zombie process children.) Linux 2.6 conforms to
+ // this specification." - wait(2)
+ //
+ // Some undocumented Linux-specific details:
+ //
+ // - All of the above is ignored if the termination signal isn't
+ // SIGCHLD.
+ //
+ // - SA_NOCLDWAIT causes the leader to be immediately reaped, but
+ // does not suppress the SIGCHLD.
+ signalParent := t.tg.terminationSignal.IsValid()
+ t.parent.tg.signalHandlers.mu.Lock()
+ if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach {
+ if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok {
+ if act.Handler == arch.SignalActIgnore {
+ t.exitParentAcked = true
+ signalParent = false
+ } else if act.Flags&arch.SignalFlagNoCldWait != 0 {
+ t.exitParentAcked = true
+ }
+ }
+ }
+ if signalParent {
+ t.parent.tg.leader.sendSignalLocked(t.exitNotificationSignal(t.tg.terminationSignal, t.parent), true /* group */)
+ }
+ t.parent.tg.signalHandlers.mu.Unlock()
+ // If a task in the parent was waiting for a child group stop
+ // or continue, it needs to be notified of the exit, because
+ // there may be no remaining eligible tasks (so that wait
+ // should return ECHILD).
+ t.parent.tg.eventQueue.Notify(EventExit | EventChildGroupStop | EventGroupContinue)
+ }
+ }
+ }
+ if t.exitTracerAcked && t.exitParentAcked {
+ t.advanceExitStateLocked(TaskExitZombie, TaskExitDead)
+ for ns := t.tg.pidns; ns != nil; ns = ns.parent {
+ tid := ns.tids[t]
+ delete(ns.tasks, tid)
+ delete(ns.tids, t)
+ if t == t.tg.leader {
+ delete(ns.tgids, t.tg)
+ }
+ }
+ t.tg.exitedCPUStats.Accumulate(t.CPUStats())
+ t.tg.ioUsage.Accumulate(t.ioUsage)
+ t.tg.signalHandlers.mu.Lock()
+ t.tg.tasks.Remove(t)
+ t.tg.tasksCount--
+ tc := t.tg.tasksCount
+ t.tg.signalHandlers.mu.Unlock()
+ if tc == 1 && t != t.tg.leader {
+ // Our fromPtraceDetach doesn't matter here (in Linux terms, this
+ // is via a call to release_task()).
+ t.tg.leader.exitNotifyLocked(false)
+ } else if tc == 0 {
+ t.tg.processGroup.decRefWithParent(t.tg.parentPG())
+ }
+ if t.parent != nil {
+ delete(t.parent.children, t)
+ t.parent = nil
+ }
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.SignalInfo {
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ }
+ info.SetPid(int32(receiver.tg.pidns.tids[t]))
+ info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ if t.exitStatus.Signaled() {
+ info.Code = arch.CLD_KILLED
+ info.SetStatus(int32(t.exitStatus.Signo))
+ } else {
+ info.Code = arch.CLD_EXITED
+ info.SetStatus(int32(t.exitStatus.Code))
+ }
+ // TODO(b/72102453): Set utime, stime.
+ return info
+}
+
+// ExitStatus returns t's exit status, which is only guaranteed to be
+// meaningful if t.ExitState() != TaskExitNone.
+func (t *Task) ExitStatus() ExitStatus {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.exitStatus
+}
+
+// ExitStatus returns the exit status that would be returned by a consuming
+// wait*() on tg.
+func (tg *ThreadGroup) ExitStatus() ExitStatus {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ if tg.exiting {
+ return tg.exitStatus
+ }
+ return tg.leader.exitStatus
+}
+
+// TerminationSignal returns the thread group's termination signal.
+func (tg *ThreadGroup) TerminationSignal() linux.Signal {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.terminationSignal
+}
+
+// Task events that can be waited for.
+const (
+ // EventExit represents an exit notification generated for a child thread
+ // group leader or a tracee under the conditions specified in the comment
+ // above runExitNotify.
+ EventExit waiter.EventMask = 1 << iota
+
+ // EventChildGroupStop occurs when a child thread group completes a group
+ // stop (i.e. all tasks in the child thread group have entered a stopped
+ // state as a result of a group stop).
+ EventChildGroupStop
+
+ // EventTraceeStop occurs when a task that is ptraced by a task in the
+ // notified thread group enters a ptrace stop (see ptrace(2)).
+ EventTraceeStop
+
+ // EventGroupContinue occurs when a child thread group, or a thread group
+ // whose leader is ptraced by a task in the notified thread group, that had
+ // initiated or completed a group stop leaves the group stop, due to the
+ // child thread group or any task in the child thread group being sent
+ // SIGCONT.
+ EventGroupContinue
+)
+
+// WaitOptions controls the behavior of Task.Wait.
+type WaitOptions struct {
+ // If SpecificTID is non-zero, only events from the task with thread ID
+ // SpecificTID are eligible to be waited for. SpecificTID is resolved in
+ // the PID namespace of the waiter (the method receiver of Task.Wait). If
+ // no such task exists, or that task would not otherwise be eligible to be
+ // waited for by the waiting task, then there are no waitable tasks and
+ // Wait will return ECHILD.
+ SpecificTID ThreadID
+
+ // If SpecificPGID is non-zero, only events from ThreadGroups with a
+ // matching ProcessGroupID are eligible to be waited for. (Same
+ // constraints as SpecificTID apply.)
+ SpecificPGID ProcessGroupID
+
+ // Terminology note: Per waitpid(2), "a clone child is one which delivers
+ // no signal, or a signal other than SIGCHLD to its parent upon
+ // termination." In Linux, termination signal is technically a per-task
+ // property rather than a per-thread-group property. However, clone()
+ // forces no termination signal for tasks created with CLONE_THREAD, and
+ // execve() resets the termination signal to SIGCHLD, so all
+ // non-group-leader threads have no termination signal and are therefore
+ // "clone tasks".
+
+ // If NonCloneTasks is true, events from non-clone tasks are eligible to be
+ // waited for.
+ NonCloneTasks bool
+
+ // If CloneTasks is true, events from clone tasks are eligible to be waited
+ // for.
+ CloneTasks bool
+
+ // If SiblingChildren is true, events from children tasks of any task
+ // in the thread group of the waiter are eligible to be waited for.
+ SiblingChildren bool
+
+ // Events is a bitwise combination of the events defined above that specify
+ // what events are of interest to the call to Wait.
+ Events waiter.EventMask
+
+ // If ConsumeEvent is true, the Wait should consume the event such that it
+ // cannot be returned by a future Wait. Note that if a task exit is
+ // consumed in this way, in most cases the task will be reaped.
+ ConsumeEvent bool
+
+ // If BlockInterruptErr is not nil, Wait will block until either an event
+ // is available or there are no tasks that could produce a waitable event;
+ // if that blocking is interrupted, Wait returns BlockInterruptErr. If
+ // BlockInterruptErr is nil, Wait will not block.
+ BlockInterruptErr error
+}
+
+// Preconditions: The TaskSet mutex must be locked (for reading or writing).
+func (o *WaitOptions) matchesTask(t *Task, pidns *PIDNamespace, tracee bool) bool {
+ if o.SpecificTID != 0 && o.SpecificTID != pidns.tids[t] {
+ return false
+ }
+ if o.SpecificPGID != 0 && o.SpecificPGID != pidns.pgids[t.tg.processGroup] {
+ return false
+ }
+ // Tracees are always eligible.
+ if tracee {
+ return true
+ }
+ if t == t.tg.leader && t.tg.terminationSignal == linux.SIGCHLD {
+ return o.NonCloneTasks
+ }
+ return o.CloneTasks
+}
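+
+// As a hedged sketch of how waitpid(2) flags might map onto the fields
+// consulted above (the actual syscall plumbing lives outside this file):
+// plain waitpid sets NonCloneTasks only, __WCLONE sets CloneTasks only, and
+// __WALL sets both; compare Linux's kernel/exit.c:eligible_child().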
+
+// ErrNoWaitableEvent is returned by non-blocking Task.Waits (e.g.
+// waitpid(WNOHANG)) that find no waitable events, but determine that waitable
+// events may exist in the future. (In contrast, if a non-blocking or blocking
+// Wait determines that there are no tasks that can produce a waitable event,
+// Task.Wait returns ECHILD.)
+var ErrNoWaitableEvent = errors.New("non-blocking Wait found eligible threads but no waitable events")
+
+// WaitResult contains information about a waited-for event.
+type WaitResult struct {
+ // Task is the task that reported the event.
+ Task *Task
+
+ // TID is the thread ID of Task in the PID namespace of the task that
+ // called Wait (that is, the method receiver of the call to Task.Wait). TID
+ // is provided because consuming exit waits cause the thread ID to be
+ // deallocated.
+ TID ThreadID
+
+ // UID is the real UID of Task in the user namespace of the task that
+ // called Wait.
+ UID auth.UID
+
+ // Event is exactly one of the events defined above.
+ Event waiter.EventMask
+
+ // Status is the numeric status associated with the event.
+ Status uint32
+}
+
+// Wait waits for an event from a thread group that is a child of t's thread
+// group, or a task in such a thread group, or a task that is ptraced by t,
+// subject to the options specified in opts.
+func (t *Task) Wait(opts *WaitOptions) (*WaitResult, error) {
+ if opts.BlockInterruptErr == nil {
+ return t.waitOnce(opts)
+ }
+ w, ch := waiter.NewChannelEntry(nil)
+ t.tg.eventQueue.EventRegister(&w, opts.Events)
+ defer t.tg.eventQueue.EventUnregister(&w)
+ for {
+ wr, err := t.waitOnce(opts)
+ if err != ErrNoWaitableEvent {
+ // This includes err == nil.
+ return wr, err
+ }
+ if err := t.Block(ch); err != nil {
+ return wr, syserror.ConvertIntr(err, opts.BlockInterruptErr)
+ }
+ }
+}
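+
+// An informal sketch of how a wait4()-style caller might drive Wait; the
+// exact options used by the syscall layer live outside this file, and pid
+// here is a hypothetical variable:
+//
+//	opts := WaitOptions{
+//		SpecificTID:   pid, // or 0 to wait for any eligible task
+//		NonCloneTasks: true,
+//		Events:        EventExit,
+//		ConsumeEvent:  true, // reap the zombie on success
+//	}
+//	// A nil BlockInterruptErr makes the call non-blocking (WNOHANG-like);
+//	// ErrNoWaitableEvent then means eligible tasks exist but no event yet.
+//	wr, err := t.Wait(&opts)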
+
+func (t *Task) waitOnce(opts *WaitOptions) (*WaitResult, error) {
+ anyWaitableTasks := false
+
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+
+ if opts.SiblingChildren {
+ // We can wait on the children and tracees of any task in the
+ // same thread group.
+ for parent := t.tg.tasks.Front(); parent != nil; parent = parent.Next() {
+ wr, any := t.waitParentLocked(opts, parent)
+ if wr != nil {
+ return wr, nil
+ }
+ anyWaitableTasks = anyWaitableTasks || any
+ }
+ } else {
+ // We can only wait on this task.
+ var wr *WaitResult
+ wr, anyWaitableTasks = t.waitParentLocked(opts, t)
+ if wr != nil {
+ return wr, nil
+ }
+ }
+
+ if anyWaitableTasks {
+ return nil, ErrNoWaitableEvent
+ }
+ return nil, syserror.ECHILD
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitParentLocked(opts *WaitOptions, parent *Task) (*WaitResult, bool) {
+ anyWaitableTasks := false
+
+ for child := range parent.children {
+ if !opts.matchesTask(child, parent.tg.pidns, false) {
+ continue
+ }
+ // Non-leaders don't notify parents on exit and aren't eligible to
+ // be waited on.
+ if opts.Events&EventExit != 0 && child == child.tg.leader && !child.exitParentAcked {
+ anyWaitableTasks = true
+ if wr := t.waitCollectZombieLocked(child, opts, false); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ // Check for group stops and continues. Tasks that have passed
+ // TaskExitInitiated can no longer participate in group stops.
+ if opts.Events&(EventChildGroupStop|EventGroupContinue) == 0 {
+ continue
+ }
+ if child.exitState >= TaskExitInitiated {
+ continue
+ }
+ // If the waiter is in the same thread group as the task's
+ // tracer, do not report its group stops; they will be reported
+ // as ptrace stops instead. This also skips checking for group
+ // continues, but they'll be checked for when scanning tracees
+ // below. (Per kernel/exit.c:wait_consider_task(): "If a
+ // ptracer wants to distinguish the two events for its own
+ // children, it should create a separate process which takes
+ // the role of real parent.")
+ if tracer := child.Tracer(); tracer != nil && tracer.tg == parent.tg {
+ continue
+ }
+ anyWaitableTasks = true
+ if opts.Events&EventChildGroupStop != 0 {
+ if wr := t.waitCollectChildGroupStopLocked(child, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ if opts.Events&EventGroupContinue != 0 {
+ if wr := t.waitCollectGroupContinueLocked(child, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ }
+ for tracee := range parent.ptraceTracees {
+ if !opts.matchesTask(tracee, parent.tg.pidns, true) {
+ continue
+ }
+ // Non-leaders do notify tracers on exit.
+ if opts.Events&EventExit != 0 && !tracee.exitTracerAcked {
+ anyWaitableTasks = true
+ if wr := t.waitCollectZombieLocked(tracee, opts, true); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ if opts.Events&(EventTraceeStop|EventGroupContinue) == 0 {
+ continue
+ }
+ if tracee.exitState >= TaskExitInitiated {
+ continue
+ }
+ anyWaitableTasks = true
+ if opts.Events&EventTraceeStop != 0 {
+ if wr := t.waitCollectTraceeStopLocked(tracee, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ if opts.Events&EventGroupContinue != 0 {
+ if wr := t.waitCollectGroupContinueLocked(tracee, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ }
+
+ return nil, anyWaitableTasks
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectZombieLocked(target *Task, opts *WaitOptions, asPtracer bool) *WaitResult {
+ if asPtracer && !target.exitTracerNotified {
+ return nil
+ }
+ if !asPtracer && !target.exitParentNotified {
+ return nil
+ }
+ // Zombied thread group leaders are never waitable until their thread group
+ // is otherwise empty. Usually this is caught by the
+ // target.exitParentNotified check above, but if t is both (in the thread
+ // group of) target's tracer and parent, asPtracer may be true.
+ if target == target.tg.leader && target.tg.tasksCount != 1 {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ status := target.exitStatus.Status()
+ if !opts.ConsumeEvent {
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventExit,
+ Status: status,
+ }
+ }
+ // Surprisingly, the exit status reported by a non-consuming wait can
+ // differ from that reported by a consuming wait; the latter will return
+ // the group exit code if one is available.
+ if target.tg.exiting {
+ status = target.tg.exitStatus.Status()
+ }
+ // t may be (in the thread group of) target's parent, tracer, or both. We
+ // don't need to check for !exitTracerAcked because tracees are detached
+ // here, and we don't need to check for !exitParentAcked because zombies
+ // will be reaped here.
+ if tracer := target.Tracer(); tracer != nil && tracer.tg == t.tg && target.exitTracerNotified {
+ target.exitTracerAcked = true
+ target.ptraceTracer.Store((*Task)(nil))
+ delete(t.ptraceTracees, target)
+ }
+ if target.parent != nil && target.parent.tg == t.tg && target.exitParentNotified {
+ target.exitParentAcked = true
+ if target == target.tg.leader {
+ // target.tg.exitedCPUStats doesn't include target.CPUStats() yet,
+ // and won't until after target.exitNotifyLocked() (maybe). Include
+ // target.CPUStats() explicitly. This is consistent with Linux,
+ // which accounts an exited task's cputime to its thread group in
+ // kernel/exit.c:release_task() => __exit_signal(), and uses
+ // thread_group_cputime_adjusted() in wait_task_zombie().
+ t.tg.childCPUStats.Accumulate(target.CPUStats())
+ t.tg.childCPUStats.Accumulate(target.tg.exitedCPUStats)
+ t.tg.childCPUStats.Accumulate(target.tg.childCPUStats)
+ // Update t's child max resident set size. The size will be the maximum
+ // of this thread's size and all its children's sizes.
+ if t.tg.childMaxRSS < target.tg.maxRSS {
+ t.tg.childMaxRSS = target.tg.maxRSS
+ }
+ if t.tg.childMaxRSS < target.tg.childMaxRSS {
+ t.tg.childMaxRSS = target.tg.childMaxRSS
+ }
+ }
+ }
+ target.exitNotifyLocked(false)
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventExit,
+ Status: status,
+ }
+}
+
+// updateRSSLocked updates t.tg.maxRSS.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) updateRSSLocked() {
+ if mmMaxRSS := t.MemoryManager().MaxResidentSetSize(); t.tg.maxRSS < mmMaxRSS {
+ t.tg.maxRSS = mmMaxRSS
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectChildGroupStopLocked(target *Task, opts *WaitOptions) *WaitResult {
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if !target.tg.groupStopWaitable {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ sig := target.tg.groupStopSignal
+ if opts.ConsumeEvent {
+ target.tg.groupStopWaitable = false
+ }
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventChildGroupStop,
+ // This is the "stopped" wait status encoding: there are no named
+ // Linux constants for it, but WIFSTOPPED checks that the low byte
+ // is 0x7f and WSTOPSIG extracts the signal from bits 8-15.
+ Status: (uint32(sig)&0xff)<<8 | 0x7f,
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectGroupContinueLocked(target *Task, opts *WaitOptions) *WaitResult {
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if !target.tg.groupContWaitable {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ if opts.ConsumeEvent {
+ target.tg.groupContWaitable = false
+ }
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventGroupContinue,
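+ // 0xffff is the "continued" wait status; Linux's WIFCONTINUED(status)
+ // checks for exactly this value.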
+ Status: 0xffff,
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectTraceeStopLocked(target *Task, opts *WaitOptions) *WaitResult {
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if target.stop == nil {
+ return nil
+ }
+ if _, ok := target.stop.(*ptraceStop); !ok {
+ return nil
+ }
+ if target.ptraceCode == 0 {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ code := target.ptraceCode
+ if opts.ConsumeEvent {
+ target.ptraceCode = 0
+ }
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventTraceeStop,
+ Status: uint32(code)<<8 | 0x7f,
+ }
+}
+
+// ExitState returns t's current progress through the exit path.
+func (t *Task) ExitState() TaskExitState {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ return t.exitState
+}
+
+// ParentDeathSignal returns t's parent death signal.
+func (t *Task) ParentDeathSignal() linux.Signal {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.parentDeathSignal
+}
+
+// SetParentDeathSignal sets t's parent death signal.
+func (t *Task) SetParentDeathSignal(sig linux.Signal) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.parentDeathSignal = sig
+}
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
new file mode 100644
index 000000000..a53e77c9f
--- /dev/null
+++ b/pkg/sentry/kernel/task_futex.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Futex returns t's futex manager.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) Futex() *futex.Manager {
+ return t.tc.fu
+}
+
+// SwapUint32 implements futex.Target.SwapUint32.
+func (t *Task) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+ return t.MemoryManager().SwapUint32(t, addr, new, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CompareAndSwapUint32 implements futex.Target.CompareAndSwapUint32.
+func (t *Task) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+ return t.MemoryManager().CompareAndSwapUint32(t, addr, old, new, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// LoadUint32 implements futex.Target.LoadUint32.
+func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
+ return t.MemoryManager().LoadUint32(t, addr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// GetSharedKey implements futex.Target.GetSharedKey.
+func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
+ return t.MemoryManager().GetSharedFutexKey(t, addr)
+}
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
new file mode 100644
index 000000000..0325967e4
--- /dev/null
+++ b/pkg/sentry/kernel/task_identity.go
@@ -0,0 +1,606 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Credentials returns t's credentials.
+//
+// This value must be considered immutable.
+func (t *Task) Credentials() *auth.Credentials {
+ return t.creds.Load()
+}
+
+// UserNamespace returns the user namespace associated with the task.
+func (t *Task) UserNamespace() *auth.UserNamespace {
+ return t.Credentials().UserNamespace
+}
+
+// HasCapabilityIn checks if the task has capability cp in user namespace ns.
+func (t *Task) HasCapabilityIn(cp linux.Capability, ns *auth.UserNamespace) bool {
+ return t.Credentials().HasCapabilityIn(cp, ns)
+}
+
+// HasCapability checks if the task has capability cp in its user namespace.
+func (t *Task) HasCapability(cp linux.Capability) bool {
+ return t.Credentials().HasCapability(cp)
+}
+
+// SetUID implements the semantics of setuid(2).
+func (t *Task) SetUID(uid auth.UID) error {
+ // setuid considers -1 to be invalid.
+ if !uid.Ok() {
+ return syserror.EINVAL
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ creds := t.Credentials()
+ kuid := creds.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ // "setuid() sets the effective user ID of the calling process. If the
+ // effective UID of the caller is root (more precisely: if the caller has
+ // the CAP_SETUID capability), the real UID and saved set-user-ID are also
+ // set." - setuid(2)
+ if creds.HasCapability(linux.CAP_SETUID) {
+ t.setKUIDsUncheckedLocked(kuid, kuid, kuid)
+ return nil
+ }
+ // "EPERM: The user is not privileged (Linux: does not have the CAP_SETUID
+ // capability) and uid does not match the real UID or saved set-user-ID of
+ // the calling process."
+ if kuid != creds.RealKUID && kuid != creds.SavedKUID {
+ return syserror.EPERM
+ }
+ t.setKUIDsUncheckedLocked(creds.RealKUID, kuid, creds.SavedKUID)
+ return nil
+}
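+
+// An illustrative example (values are hypothetical): an unprivileged task
+// with real UID 1000, effective UID 1001, and saved set-user-ID 1000 may
+// call SetUID(1000), since 1000 matches its real UID; only the effective
+// UID changes. SetUID(0) from the same task fails with EPERM unless it has
+// CAP_SETUID.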
+
+// SetREUID implements the semantics of setreuid(2).
+func (t *Task) SetREUID(r, e auth.UID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // "Supplying a value of -1 for either the real or effective user ID forces
+ // the system to leave that ID unchanged." - setreuid(2)
+ creds := t.Credentials()
+ newR := creds.RealKUID
+ if r.Ok() {
+ newR = creds.UserNamespace.MapToKUID(r)
+ if !newR.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ newE := creds.EffectiveKUID
+ if e.Ok() {
+ newE = creds.UserNamespace.MapToKUID(e)
+ if !newE.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ if !creds.HasCapability(linux.CAP_SETUID) {
+ // "Unprivileged processes may only set the effective user ID to the
+ // real user ID, the effective user ID, or the saved set-user-ID."
+ if newE != creds.RealKUID && newE != creds.EffectiveKUID && newE != creds.SavedKUID {
+ return syserror.EPERM
+ }
+ // "Unprivileged users may only set the real user ID to the real user
+ // ID or the effective user ID."
+ if newR != creds.RealKUID && newR != creds.EffectiveKUID {
+ return syserror.EPERM
+ }
+ }
+ // "If the real user ID is set (i.e., ruid is not -1) or the effective user
+ // ID is set to a value not equal to the previous real user ID, the saved
+ // set-user-ID will be set to the new effective user ID."
+ newS := creds.SavedKUID
+ if r.Ok() || (e.Ok() && newE != creds.EffectiveKUID) {
+ newS = newE
+ }
+ t.setKUIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+// SetRESUID implements the semantics of the setresuid(2) syscall.
+func (t *Task) SetRESUID(r, e, s auth.UID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // "Unprivileged user processes may change the real UID, effective UID, and
+ // saved set-user-ID, each to one of: the current real UID, the current
+ // effective UID or the current saved set-user-ID. Privileged processes (on
+ // Linux, those having the CAP_SETUID capability) may set the real UID,
+ // effective UID, and saved set-user-ID to arbitrary values. If one of the
+ // arguments equals -1, the corresponding value is not changed." -
+ // setresuid(2)
+ var err error
+ creds := t.Credentials()
+ newR := creds.RealKUID
+ if r.Ok() {
+ newR, err = creds.UseUID(r)
+ if err != nil {
+ return err
+ }
+ }
+ newE := creds.EffectiveKUID
+ if e.Ok() {
+ newE, err = creds.UseUID(e)
+ if err != nil {
+ return err
+ }
+ }
+ newS := creds.SavedKUID
+ if s.Ok() {
+ newS, err = creds.UseUID(s)
+ if err != nil {
+ return err
+ }
+ }
+ t.setKUIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+// Preconditions: t.mu must be locked.
+func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) {
+ creds := t.Credentials().Fork() // The credentials object is immutable. See doc for creds.
+ root := creds.UserNamespace.MapToKUID(auth.RootUID)
+ oldR, oldE, oldS := creds.RealKUID, creds.EffectiveKUID, creds.SavedKUID
+ creds.RealKUID, creds.EffectiveKUID, creds.SavedKUID = newR, newE, newS
+
+ // "1. If one or more of the real, effective or saved set user IDs was
+ // previously 0, and as a result of the UID changes all of these IDs have a
+ // nonzero value, then all capabilities are cleared from the permitted and
+ // effective capability sets." - capabilities(7)
+ if (oldR == root || oldE == root || oldS == root) && (newR != root && newE != root && newS != root) {
+ // prctl(2): "PR_SET_KEEPCAP: Set the state of the calling thread's
+ // "keep capabilities" flag, which determines whether the thread's permitted
+ // capability set is cleared when a change is made to the
+ // thread's user IDs such that the thread's real UID, effective
+ // UID, and saved set-user-ID all become nonzero when at least
+ // one of them previously had the value 0. By default, the
+ // permitted capability set is cleared when such a change is
+ // made; setting the "keep capabilities" flag prevents it from
+ // being cleared." (A thread's effective capability set is always
+ // cleared when such a credential change is made,
+ // regardless of the setting of the "keep capabilities" flag.)
+ if !creds.KeepCaps {
+ creds.PermittedCaps = 0
+ creds.EffectiveCaps = 0
+ }
+ }
+ // """
+ // 2. If the effective user ID is changed from 0 to nonzero, then all
+ // capabilities are cleared from the effective set.
+ //
+ // 3. If the effective user ID is changed from nonzero to 0, then the
+ // permitted set is copied to the effective set.
+ // """
+ if oldE == root && newE != root {
+ creds.EffectiveCaps = 0
+ } else if oldE != root && newE == root {
+ creds.EffectiveCaps = creds.PermittedCaps
+ }
+ // "4. If the filesystem user ID is changed from 0 to nonzero (see
+ // setfsuid(2)), then the following capabilities are cleared from the
+ // effective set: ..."
+ // (filesystem UIDs aren't implemented, nor are any of the capabilities in
+ // question)
+
+ if oldE != newE {
+ // "[dumpability] is reset to the current value contained in
+ // the file /proc/sys/fs/suid_dumpable (which by default has
+ // the value 0), in the following circumstances: The process's
+ // effective user or group ID is changed." - prctl(2)
+ //
+ // (suid_dumpable isn't implemented, so we just use the
+ // default.)
+ t.MemoryManager().SetDumpability(mm.NotDumpable)
+
+ // Not documented, but compare Linux's kernel/cred.c:commit_creds().
+ t.parentDeathSignal = 0
+ }
+ t.creds.Store(creds)
+}
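
The UID-transition rules encoded above are easiest to check as a pure function over the old and new IDs. A runnable sketch, not the gVisor API (the names and capSet type here are invented):

package main

import "fmt"

type capSet uint64

// applyUIDTransitionRules mirrors rules 1-3 of capabilities(7) as encoded in
// setKUIDsUncheckedLocked: losing root in all of (real, effective, saved)
// clears the permitted and effective sets unless keepCaps is set, and
// effective-UID transitions to/from root toggle the effective set between
// empty and a copy of the permitted set.
func applyUIDTransitionRules(oldR, oldE, oldS, newR, newE, newS, root uint32,
	keepCaps bool, permitted, effective capSet) (capSet, capSet) {
	if (oldR == root || oldE == root || oldS == root) &&
		(newR != root && newE != root && newS != root) && !keepCaps {
		permitted, effective = 0, 0
	}
	if oldE == root && newE != root {
		effective = 0
	} else if oldE != root && newE == root {
		effective = permitted
	}
	return permitted, effective
}

func main() {
	// Dropping from root to UID 1000 everywhere without KEEPCAPS clears both.
	p, e := applyUIDTransitionRules(0, 0, 0, 1000, 1000, 1000, 0, false, 0xff, 0xff)
	fmt.Printf("permitted=%#x effective=%#x\n", p, e) // permitted=0x0 effective=0x0
}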
+
+// SetGID implements the semantics of setgid(2).
+func (t *Task) SetGID(gid auth.GID) error {
+ if !gid.Ok() {
+ return syserror.EINVAL
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ creds := t.Credentials()
+ kgid := creds.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ if creds.HasCapability(linux.CAP_SETGID) {
+ t.setKGIDsUncheckedLocked(kgid, kgid, kgid)
+ return nil
+ }
+ if kgid != creds.RealKGID && kgid != creds.SavedKGID {
+ return syserror.EPERM
+ }
+ t.setKGIDsUncheckedLocked(creds.RealKGID, kgid, creds.SavedKGID)
+ return nil
+}
+
+// SetREGID implements the semantics of setregid(2).
+func (t *Task) SetREGID(r, e auth.GID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ creds := t.Credentials()
+ newR := creds.RealKGID
+ if r.Ok() {
+ newR = creds.UserNamespace.MapToKGID(r)
+ if !newR.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ newE := creds.EffectiveKGID
+ if e.Ok() {
+ newE = creds.UserNamespace.MapToKGID(e)
+ if !newE.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ if !creds.HasCapability(linux.CAP_SETGID) {
+ if newE != creds.RealKGID && newE != creds.EffectiveKGID && newE != creds.SavedKGID {
+ return syserror.EPERM
+ }
+ if newR != creds.RealKGID && newR != creds.EffectiveKGID {
+ return syserror.EPERM
+ }
+ }
+ newS := creds.SavedKGID
+ if r.Ok() || (e.Ok() && newE != creds.EffectiveKGID) {
+ newS = newE
+ }
+ t.setKGIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+// SetRESGID implements the semantics of the setresgid(2) syscall.
+func (t *Task) SetRESGID(r, e, s auth.GID) error {
+ var err error
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ creds := t.Credentials()
+ newR := creds.RealKGID
+ if r.Ok() {
+ newR, err = creds.UseGID(r)
+ if err != nil {
+ return err
+ }
+ }
+ newE := creds.EffectiveKGID
+ if e.Ok() {
+ newE, err = creds.UseGID(e)
+ if err != nil {
+ return err
+ }
+ }
+ newS := creds.SavedKGID
+ if s.Ok() {
+ newS, err = creds.UseGID(s)
+ if err != nil {
+ return err
+ }
+ }
+ t.setKGIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+func (t *Task) setKGIDsUncheckedLocked(newR, newE, newS auth.KGID) {
+ creds := t.Credentials().Fork() // The credentials object is immutable. See doc for creds.
+ oldE := creds.EffectiveKGID
+ creds.RealKGID, creds.EffectiveKGID, creds.SavedKGID = newR, newE, newS
+
+ if oldE != newE {
+ // "[dumpability] is reset to the current value contained in
+ // the file /proc/sys/fs/suid_dumpable (which by default has
+ // the value 0), in the following circumstances: The process's
+ // effective user or group ID is changed." - prctl(2)
+ //
+ // (suid_dumpable isn't implemented, so we just use the
+ // default.)
+ t.MemoryManager().SetDumpability(mm.NotDumpable)
+
+ // Not documented, but compare Linux's
+ // kernel/cred.c:commit_creds().
+ t.parentDeathSignal = 0
+ }
+ t.creds.Store(creds)
+}
+
+// SetExtraGIDs attempts to change t's supplemental groups. All IDs are
+// interpreted as being in t's user namespace.
+func (t *Task) SetExtraGIDs(gids []auth.GID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ creds := t.Credentials()
+ if !creds.HasCapability(linux.CAP_SETGID) {
+ return syserror.EPERM
+ }
+ kgids := make([]auth.KGID, len(gids))
+ for i, gid := range gids {
+ kgid := creds.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ kgids[i] = kgid
+ }
+ creds = creds.Fork() // The credentials object is immutable. See doc for creds.
+ creds.ExtraKGIDs = kgids
+ t.creds.Store(creds)
+ return nil
+}
+
+// SetCapabilitySets attempts to change t's permitted, inheritable, and
+// effective capability sets.
+func (t *Task) SetCapabilitySets(permitted, inheritable, effective auth.CapabilitySet) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // "Permitted: This is a limiting superset for the effective capabilities
+ // that the thread may assume." - capabilities(7)
+ if effective & ^permitted != 0 {
+ return syserror.EPERM
+ }
+ creds := t.Credentials()
+ // "It is also a limiting superset for the capabilities that may be added
+ // to the inheritable set by a thread that does not have the CAP_SETPCAP
+ // capability in its effective set."
+ if !creds.HasCapability(linux.CAP_SETPCAP) && (inheritable & ^(creds.InheritableCaps|creds.PermittedCaps) != 0) {
+ return syserror.EPERM
+ }
+ // "If a thread drops a capability from its permitted set, it can never
+ // reacquire that capability (unless it execve(2)s ..."
+ if permitted & ^creds.PermittedCaps != 0 {
+ return syserror.EPERM
+ }
+ // "... if a capability is not in the bounding set, then a thread can't add
+ // this capability to its inheritable set, even if it was in its permitted
+ // capabilities ..."
+ if inheritable & ^(creds.InheritableCaps|creds.BoundingCaps) != 0 {
+ return syserror.EPERM
+ }
+ creds = creds.Fork() // The credentials object is immutable. See doc for creds.
+ creds.PermittedCaps = permitted
+ creds.InheritableCaps = inheritable
+ creds.EffectiveCaps = effective
+ t.creds.Store(creds)
+ return nil
+}
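
Each EPERM check above is the same subset test: X is a subset of Y exactly when X &^ Y == 0 (Go's AND NOT), which is what the `& ^` expressions compute. A standalone illustration:

package main

import "fmt"

// isSubset reports whether every bit set in x is also set in y.
func isSubset(x, y uint64) bool { return x&^y == 0 }

func main() {
	fmt.Println(isSubset(0b0011, 0b0111)) // true
	fmt.Println(isSubset(0b1001, 0b0111)) // false: bit 3 is not in y
}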
+
+// DropBoundingCapability attempts to drop capability cp from t's capability
+// bounding set.
+func (t *Task) DropBoundingCapability(cp linux.Capability) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ creds := t.Credentials()
+ if !creds.HasCapability(linux.CAP_SETPCAP) {
+ return syserror.EPERM
+ }
+ creds = creds.Fork() // The credentials object is immutable. See doc for creds.
+ creds.BoundingCaps &^= auth.CapabilitySetOf(cp)
+ t.creds.Store(creds)
+ return nil
+}
+
+// SetUserNamespace attempts to move t into ns.
+func (t *Task) SetUserNamespace(ns *auth.UserNamespace) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ creds := t.Credentials()
+ // "A process reassociating itself with a user namespace must have the
+ // CAP_SYS_ADMIN capability in the target user namespace." - setns(2)
+ //
+ // If t just created ns, then t.creds is guaranteed to have CAP_SYS_ADMIN
+ // in ns (by rule 3 in auth.Credentials.HasCapability).
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, ns) {
+ return syserror.EPERM
+ }
+
+ creds = creds.Fork() // The credentials object is immutable. See doc for creds.
+ creds.UserNamespace = ns
+ // "The child process created by clone(2) with the CLONE_NEWUSER flag
+ // starts out with a complete set of capabilities in the new user
+ // namespace. Likewise, a process that creates a new user namespace using
+ // unshare(2) or joins an existing user namespace using setns(2) gains a
+ // full set of capabilities in that namespace."
+ creds.PermittedCaps = auth.AllCapabilities
+ creds.InheritableCaps = 0
+ creds.EffectiveCaps = auth.AllCapabilities
+ creds.BoundingCaps = auth.AllCapabilities
+ // "A call to clone(2), unshare(2), or setns(2) using the CLONE_NEWUSER
+ // flag sets the "securebits" flags (see capabilities(7)) to their default
+ // values (all flags disabled) in the child (for clone(2)) or caller (for
+ // unshare(2), or setns(2)." - user_namespaces(7)
+ creds.KeepCaps = false
+ t.creds.Store(creds)
+
+ return nil
+}
+
+// SetKeepCaps will set the keep capabilities flag PR_SET_KEEPCAPS.
+func (t *Task) SetKeepCaps(k bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ creds := t.Credentials().Fork() // The credentials object is immutable. See doc for creds.
+ creds.KeepCaps = k
+ t.creds.Store(creds)
+}
+
+// updateCredsForExecLocked updates t.creds to reflect an execve().
+//
+// NOTE(b/30815691): We currently do not implement privileged executables
+// (set-user/group-ID bits and file capabilities). This allows us to make a lot
+// of simplifying assumptions:
+//
+// - We assume the no_new_privs bit (set by prctl(SET_NO_NEW_PRIVS)), which
+// disables the features we don't support anyway, is always set. This
+// drastically simplifies this function.
+//
+// - We don't set AT_SECURE = 1, because no_new_privs always being set means
+// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's
+// security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().)
+//
+// - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since
+// seccomp-bpf is also allowed if the task has no_new_privs set.
+//
+// - Task.ptraceAttach does not serialize with execve as it does in Linux,
+// since no_new_privs being set has the same effect as the presence of an
+// unprivileged tracer.
+//
+// Preconditions: t.mu must be locked.
+func (t *Task) updateCredsForExecLocked() {
+ // """
+ // During an execve(2), the kernel calculates the new capabilities of
+ // the process using the following algorithm:
+ //
+ // P'(permitted) = (P(inheritable) & F(inheritable)) |
+ // (F(permitted) & cap_bset)
+ //
+ // P'(effective) = F(effective) ? P'(permitted) : 0
+ //
+ // P'(inheritable) = P(inheritable) [i.e., unchanged]
+ //
+ // where:
+ //
+ // P denotes the value of a thread capability set before the
+ // execve(2)
+ //
+ // P' denotes the value of a thread capability set after the
+ // execve(2)
+ //
+ // F denotes a file capability set
+ //
+ // cap_bset is the value of the capability bounding set
+ //
+ // ...
+ //
+ // In order to provide an all-powerful root using capability sets, during
+ // an execve(2):
+ //
+ // 1. If a set-user-ID-root program is being executed, or the real user ID
+ // of the process is 0 (root) then the file inheritable and permitted sets
+ // are defined to be all ones (i.e. all capabilities enabled).
+ //
+ // 2. If a set-user-ID-root program is being executed, then the file
+ // effective bit is defined to be one (enabled).
+ //
+ // The upshot of the above rules, combined with the capabilities
+ // transformations described above, is that when a process execve(2)s a
+ // set-user-ID-root program, or when a process with an effective UID of 0
+ // execve(2)s a program, it gains all capabilities in its permitted and
+ // effective capability sets, except those masked out by the capability
+ // bounding set.
+ // """ - capabilities(7)
+ // (ambient capability sets omitted)
+ //
+ // As the last paragraph implies, the case of "a set-user-ID root program
+ // is being executed" also includes the case where (namespace) root is
+ // executing a non-set-user-ID program; the actual check is just based on
+ // the effective user ID.
+ var newPermitted auth.CapabilitySet // since F(inheritable) == F(permitted) == 0
+ fileEffective := false
+ creds := t.Credentials()
+ root := creds.UserNamespace.MapToKUID(auth.RootUID)
+ if creds.EffectiveKUID == root || creds.RealKUID == root {
+ newPermitted = creds.InheritableCaps | creds.BoundingCaps
+ if creds.EffectiveKUID == root {
+ fileEffective = true
+ }
+ }
+
+ creds = creds.Fork() // The credentials object is immutable. See doc for creds.
+
+ // Now we enter poorly-documented, somewhat confusing territory. (The
+ // accompanying comment in Linux's security/commoncap.c:cap_bprm_set_creds
+ // is not very helpful.) My reading of it is:
+ //
+ // If at least one of the following is true:
+ //
+ // A1. The execing task is ptraced, and the tracer did not have
+ // CAP_SYS_PTRACE in the execing task's user namespace at the time of
+ // PTRACE_ATTACH.
+ //
+ // A2. The execing task shares its FS context with at least one task in
+ // another thread group.
+ //
+ // A3. The execing task has no_new_privs set.
+ //
+ // AND at least one of the following is true:
+ //
+ // B1. The new effective user ID (which may come from set-user-ID, or be the
+ // execing task's existing effective user ID) is not equal to the task's
+ // real UID.
+ //
+ // B2. The new effective group ID (which may come from set-group-ID, or be
+ // the execing task's existing effective group ID) is not equal to the
+ // task's real GID.
+ //
+ // B3. The new permitted capability set contains capabilities not in the
+ // task's permitted capability set.
+ //
+ // Then:
+ //
+ // C1. Limit the new permitted capability set to the task's permitted
+ // capability set.
+ //
+ // C2. If either the task does not have CAP_SETUID in its user namespace, or
+ // the task has no_new_privs set, force the new effective UID and GID to
+ // the task's real UID and GID.
+ //
+ // But since no_new_privs is always set (A3 is always true), this becomes
+ // much simpler. If B1 and B2 are false, C2 is a no-op. If B3 is false, C1
+ // is a no-op. So we can just do C1 and C2 unconditionally.
+ if creds.EffectiveKUID != creds.RealKUID || creds.EffectiveKGID != creds.RealKGID {
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ t.parentDeathSignal = 0
+ }
+ // (Saved set-user-ID is always set to the new effective user ID, and saved
+ // set-group-ID is always set to the new effective group ID, regardless of
+ // the above.)
+ creds.SavedKUID = creds.RealKUID
+ creds.SavedKGID = creds.RealKGID
+ creds.PermittedCaps &= newPermitted
+ if fileEffective {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ // prctl(2): The "keep capabilities" value will be reset to 0 on subsequent
+ // calls to execve(2).
+ creds.KeepCaps = false
+
+ // "The bounding set is inherited at fork(2) from the thread's parent, and
+ // is preserved across an execve(2)". So we're done.
+ t.creds.Store(creds)
+}
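
Under the stated assumptions (no file capabilities except the all-ones sets synthesized for root, and no_new_privs always set), the capabilities(7) exec formula collapses to a few bit operations. A hedged sketch with invented names:

package main

import "fmt"

type capSet uint64

// execCaps computes P'(permitted) and P'(effective) per capabilities(7). The
// file sets F are zero unless the (namespace-)root rules apply, and the C1
// clamp reflects no_new_privs: exec can never add permitted capabilities.
func execCaps(pInheritable, pPermitted, bounding capSet, rootIDs, fileEffective bool) (capSet, capSet) {
	var fInheritable, fPermitted capSet // no file capabilities: F sets are zero...
	if rootIDs {
		// ...except that root gets all-ones file sets, per capabilities(7).
		fInheritable, fPermitted = ^capSet(0), ^capSet(0)
	}
	newPermitted := (pInheritable & fInheritable) | (fPermitted & bounding)
	newPermitted &= pPermitted // C1: never gain permitted capabilities
	if fileEffective {
		return newPermitted, newPermitted // P'(effective) = P'(permitted)
	}
	return newPermitted, 0
}

func main() {
	p, e := execCaps(0, 0xff, 0x3f, true, true)
	fmt.Printf("%#x %#x\n", p, e) // 0x3f 0x3f: root gains everything in the bounding set
}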
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
new file mode 100644
index 000000000..eeccaa197
--- /dev/null
+++ b/pkg/sentry/kernel/task_log.go
@@ -0,0 +1,208 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "runtime/trace"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // maxStackDebugBytes is the maximum number of user stack bytes that may be
+ // printed by debugDumpStack.
+ maxStackDebugBytes = 1024
+)
+
+// Infof logs a formatted info message by calling log.Infof.
+func (t *Task) Infof(fmt string, v ...interface{}) {
+ if log.IsLogging(log.Info) {
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
+ }
+}
+
+// Warningf logs a warning string by calling log.Warningf.
+func (t *Task) Warningf(fmt string, v ...interface{}) {
+ if log.IsLogging(log.Warning) {
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
+ }
+}
+
+// Debugf creates a debug string that includes the task ID.
+func (t *Task) Debugf(fmt string, v ...interface{}) {
+ if log.IsLogging(log.Debug) {
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
+ }
+}
+
+// IsLogging returns true iff this level is being logged.
+func (t *Task) IsLogging(level log.Level) bool {
+ return log.IsLogging(level)
+}
+
+// DebugDumpState logs task state at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) DebugDumpState() {
+ t.debugDumpRegisters()
+ t.debugDumpStack()
+ if mm := t.MemoryManager(); mm != nil {
+ t.Debugf("Mappings:\n%s", mm)
+ }
+ t.Debugf("FDTable:\n%s", t.fdTable)
+}
+
+// debugDumpRegisters logs register state at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) debugDumpRegisters() {
+ if !t.IsLogging(log.Debug) {
+ return
+ }
+ regmap, err := t.Arch().RegisterMap()
+ if err != nil {
+ t.Debugf("Registers: %v", err)
+ } else {
+ t.Debugf("Registers:")
+ var regs []string
+ for reg := range regmap {
+ regs = append(regs, reg)
+ }
+ sort.Strings(regs)
+ for _, reg := range regs {
+ t.Debugf("%-8s = %016x", reg, regmap[reg])
+ }
+ }
+}
+
+// debugDumpStack logs user stack contents at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) debugDumpStack() {
+ if !t.IsLogging(log.Debug) {
+ return
+ }
+ m := t.MemoryManager()
+ if m == nil {
+ t.Debugf("Memory manager for task is gone, skipping application stack dump.")
+ return
+ }
+ t.Debugf("Stack:")
+ start := usermem.Addr(t.Arch().Stack())
+ // Round addr down to a 16-byte boundary.
+ start &= ^usermem.Addr(15)
+ // Print 16 bytes per line, one byte at a time.
+ for offset := uint64(0); offset < maxStackDebugBytes; offset += 16 {
+ addr, ok := start.AddLength(offset)
+ if !ok {
+ break
+ }
+ var data [16]byte
+ n, err := m.CopyIn(t, addr, data[:], usermem.IOOpts{
+ IgnorePermissions: true,
+ })
+ // Print as much of the line as we can, even if an error was
+ // encountered.
+ if n > 0 {
+ t.Debugf("%x: % x", addr, data[:n])
+ }
+ if err != nil {
+ t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err)
+ break
+ }
+ }
+}
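
The dump format (start address rounded down to 16 bytes, then one "%x: % x" line per row) can be exercised in isolation; a self-contained sketch:

package main

import "fmt"

func main() {
	data := []byte("sixteen bytes/row, aligned like debugDumpStack")
	start := uint64(0x7ffcdeadbee5)
	start &^= 15 // round down to a 16-byte boundary, as debugDumpStack does
	for off := 0; off < len(data); off += 16 {
		end := off + 16
		if end > len(data) {
			end = len(data)
		}
		// "% x" inserts a space between bytes, matching the dump above.
		fmt.Printf("%x: % x\n", start+uint64(off), data[off:end])
	}
}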
+
+// trace definitions.
+//
+// Note that all region names are prefixed by ':' in order to ensure that they
+// are lexically ordered before all system calls, which use the naked system
+// call name (e.g. "read") for maximum clarity.
+const (
+ traceCategory = "task"
+ runRegion = ":run"
+ blockRegion = ":block"
+ cpuidRegion = ":cpuid"
+ faultRegion = ":fault"
+)
+
+// updateInfoLocked updates the task's cached log prefix and tracing
+// information to reflect its current thread ID.
+//
+// Preconditions: The task's owning TaskSet.mu must be locked.
+func (t *Task) updateInfoLocked() {
+ // Use the task's TID in the root PID namespace for logging.
+ tid := t.tg.pidns.owner.Root.tids[t]
+ t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.rebuildTraceContext(tid)
+}
+
+// rebuildTraceContext rebuilds the trace context.
+//
+// Precondition: the passed tid must be the tid in the root namespace.
+func (t *Task) rebuildTraceContext(tid ThreadID) {
+ // Re-initialize the trace context.
+ if t.traceTask != nil {
+ t.traceTask.End()
+ }
+
+ // Note that we define the "task type" to be the dynamic TID. This does
+ // not align perfectly with the documentation for "tasks" in the
+ // tracing package. Tasks may be assumed to be bounded by analysis
+ // tools. However, if we just use a generic "task" type here, then the
+ // "user-defined tasks" page on the tracing dashboard becomes nearly
+ // unusable, as it loads all traces from all tasks.
+ //
+ // We can assume that the number of tasks in the system is not
+ // arbitrarily large (in general it won't be, especially for cases
+ // where we're collecting a brief profile), so using the TID is a
+ // reasonable compromise in this case.
+ t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid))
+}
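
The standard-library pieces used in this file (trace.NewTask, trace.StartRegion, trace.Logf) compose as follows; a runnable sketch that emits a trace to stderr:

package main

import (
	"context"
	"os"
	"runtime/trace"
)

func main() {
	if err := trace.Start(os.Stderr); err != nil {
		panic(err)
	}
	defer trace.Stop()

	// One trace "task" per kernel task, typed by TID as above.
	ctx, task := trace.NewTask(context.Background(), "tid:1")
	defer task.End()

	// Regions mark spans within the task; logs attach events to it.
	region := trace.StartRegion(ctx, ":run")
	trace.Logf(ctx, "task", "spawn: %d", 2)
	region.End()
}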
+
+// traceCloneEvent is called when a new task is spawned.
+//
+// ntid must be the new task's ThreadID in the root namespace.
+func (t *Task) traceCloneEvent(ntid ThreadID) {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid)
+}
+
+// traceExitEvent is called when a task exits.
+func (t *Task) traceExitEvent() {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status())
+}
+
+// traceExecEvent is called when a task calls exec.
+func (t *Task) traceExecEvent(tc *TaskContext) {
+ if !trace.IsEnabled() {
+ return
+ }
+ file := tc.MemoryManager.Executable()
+ if file == nil {
+ trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
+ return
+ }
+ defer file.DecRef()
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
+}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
new file mode 100644
index 000000000..f7711232c
--- /dev/null
+++ b/pkg/sentry/kernel/task_net.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+)
+
+// IsNetworkNamespaced returns true if t is in a non-root network namespace.
+func (t *Task) IsNetworkNamespaced() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return !t.netns.IsRoot()
+}
+
+// NetworkContext returns the network stack used by the task. NetworkContext
+// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
+func (t *Task) NetworkContext() inet.Stack {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
+}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
new file mode 100644
index 000000000..d654dd997
--- /dev/null
+++ b/pkg/sentry/kernel/task_run.go
@@ -0,0 +1,380 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "runtime"
+ "runtime/trace"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// A taskRunState is a reified state in the task state machine. See README.md
+// for details. The canonical list of all run states, as well as transitions
+// between them, is given in run_states.dot.
+//
+// The set of possible states is enumerable and completely defined by the
+// kernel package, so taskRunState would ideally be represented by a
+// discriminated union. However, Go does not support sum types.
+//
+// Hence, as with TaskStop, data-free taskRunStates should be represented as
+// typecast nils to avoid unnecessary allocation.
+type taskRunState interface {
+ // execute executes the code associated with this state over the given task
+ // and returns the following state. If execute returns nil, the task
+ // goroutine should exit.
+ //
+ // It is valid to tail-call a following state's execute to avoid the
+ // overhead of converting the following state to an interface object and
+ // checking for stops, provided that the tail-call cannot recurse.
+ execute(*Task) taskRunState
+}
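
The resulting driver loop is the familiar state-function pattern: each state returns the next state, and nil terminates. A toy version (empty structs stand in for the typecast nils used here, both being allocation-free):

package main

import "fmt"

type runState interface{ execute() runState }

type boot struct{}
type exit struct{}

func (boot) execute() runState { fmt.Println("boot"); return exit{} }
func (exit) execute() runState { fmt.Println("exit"); return nil }

func main() {
	var s runState = boot{}
	for s != nil { // Task.run's loop has the same shape
		s = s.execute()
	}
}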
+
+// run runs the task goroutine.
+//
+// threadID is a dummy value set to the task's TID in the root PID namespace
+// to make it visible in stack dumps. The goroutine for a given task can be
+// identified by searching for Task.run()'s argument value.
+func (t *Task) run(threadID uintptr) {
+ // Construct t.blockingTimer here. We do this here because we can't
+ // reconstruct t.blockingTimer during restore in Task.afterLoad(), because
+ // kernel.timekeeper.SetClocks() hasn't been called yet.
+ blockingTimerNotifier, blockingTimerChan := ktime.NewChannelNotifier()
+ t.blockingTimer = ktime.NewTimer(t.k.MonotonicClock(), blockingTimerNotifier)
+ defer t.blockingTimer.Destroy()
+ t.blockingTimerChan = blockingTimerChan
+
+ // Activate our address space.
+ t.Activate()
+ // The corresponding t.Deactivate occurs in the exit path
+ // (runExitMain.execute) so that when
+ // Platform.CooperativelySharesAddressSpace() == true, we give up the
+ // AddressSpace before the task goroutine finishes executing.
+
+ // If this is a newly-started task, it should check for participation in
+ // group stops. If this is a task resuming after restore, it was
+ // interrupted by saving. In either case, the task is initially
+ // interrupted.
+ t.interruptSelf()
+
+ for {
+ // Explanation for this ordering:
+ //
+ // - A freshly-started task that is stopped should not do anything
+ // before it enters the stop.
+ //
+ // - If taskRunState.execute returns nil, the task goroutine should
+ // exit without checking for a stop.
+ //
+ // - Task.Start won't start Task.run if t.runState is nil, so this
+ // ordering is safe.
+ t.doStop()
+ t.runState = t.runState.execute(t)
+ if t.runState == nil {
+ t.accountTaskGoroutineEnter(TaskGoroutineNonexistent)
+ t.goroutineStopped.Done()
+ t.tg.liveGoroutines.Done()
+ t.tg.pidns.owner.liveGoroutines.Done()
+ t.tg.pidns.owner.runningGoroutines.Done()
+ t.p.Release()
+
+ // Keep argument alive because stack trace for dead variables may not be correct.
+ runtime.KeepAlive(threadID)
+ return
+ }
+ }
+}
+
+// doStop is called by Task.run to block until the task is not stopped.
+func (t *Task) doStop() {
+ if atomic.LoadInt32(&t.stopCount) == 0 {
+ return
+ }
+ t.Deactivate()
+ // NOTE(b/30316266): t.Activate() must be called without any locks held, so
+ // this defer must precede the defer for unlocking the signal mutex.
+ defer t.Activate()
+ t.accountTaskGoroutineEnter(TaskGoroutineStopped)
+ defer t.accountTaskGoroutineLeave(TaskGoroutineStopped)
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.tg.pidns.owner.runningGoroutines.Add(-1)
+ defer t.tg.pidns.owner.runningGoroutines.Add(1)
+ t.goroutineStopped.Add(-1)
+ defer t.goroutineStopped.Add(1)
+ for t.stopCount > 0 {
+ t.endStopCond.Wait()
+ }
+}
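
doStop is a standard condition-variable wait over a counter. A reduced sketch of the same shape (names invented):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var mu sync.Mutex
	stopCount := 1
	endStop := sync.NewCond(&mu)

	// Some other goroutine lifts the stop and wakes the waiter.
	go func() {
		time.Sleep(10 * time.Millisecond)
		mu.Lock()
		stopCount--
		endStop.Broadcast()
		mu.Unlock()
	}()

	mu.Lock()
	for stopCount > 0 { // wait until every stop has been lifted
		endStop.Wait()
	}
	mu.Unlock()
	fmt.Println("resumed")
}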
+
+func (*runApp) handleCPUIDInstruction(t *Task) error {
+ if len(arch.CPUIDInstruction) == 0 {
+ // CPUID emulation isn't supported, but this code can be
+ // executed, because the ptrace platform returns
+ // ErrContextSignalCPUID on page faults too. Look at
+ // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more
+ // details.
+ return platform.ErrContextSignal
+ }
+ // Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
+ expected := arch.CPUIDInstruction[:]
+ found := make([]byte, len(expected))
+ _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ if err == nil && bytes.Equal(expected, found) {
+ // Skip the cpuid instruction.
+ t.Arch().CPUIDEmulate(t)
+ t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
+
+ return nil
+ }
+ region.End() // Not an actual CPUID, but required copy-in.
+ return platform.ErrContextSignal
+}
+
+// The runApp state checks for interrupts before executing untrusted
+// application code.
+//
+// +stateify savable
+type runApp struct{}
+
+func (app *runApp) execute(t *Task) taskRunState {
+ if t.interrupted() {
+ // Checkpointing instructs tasks to stop by sending an interrupt, so we
+ // must check for stops before entering runInterrupt (instead of
+ // tail-calling it).
+ return (*runInterrupt)(nil)
+ }
+
+ // We're about to switch to the application again. If there's still an
+ // unhandled SyscallRestartErrno that wasn't translated to an EINTR,
+ // restart the syscall that was interrupted. If there's a saved signal
+ // mask, restore it. (Note that restoring the saved signal mask may unblock
+ // a pending signal, causing another interruption, but that signal should
+ // not interact with the interrupted syscall.)
+ if t.haveSyscallReturn {
+ if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
+ if sre == ERESTART_RESTARTBLOCK {
+ t.Debugf("Restarting syscall %d with restart block after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre)
+ t.Arch().RestartSyscallWithRestartBlock()
+ } else {
+ t.Debugf("Restarting syscall %d after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre)
+ t.Arch().RestartSyscall()
+ }
+ }
+ t.haveSyscallReturn = false
+ }
+ if t.haveSavedSignalMask {
+ t.SetSignalMask(t.savedSignalMask)
+ t.haveSavedSignalMask = false
+ if t.interrupted() {
+ return (*runInterrupt)(nil)
+ }
+ }
+
+ // Apply restartable sequences.
+ if t.rseqPreempted {
+ t.rseqPreempted = false
+ if t.rseqAddr != 0 || t.oldRSeqCPUAddr != 0 {
+ // Linux writes the CPU on every preemption. We only do
+ // so if it changed. Thus we may delay delivery of
+ // SIGSEGV if rseqAddr/oldRSeqCPUAddr is invalid.
+ cpu := int32(hostcpu.GetCPU())
+ if t.rseqCPU != cpu {
+ t.rseqCPU = cpu
+ if err := t.rseqCopyOutCPU(); err != nil {
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ if err := t.oldRSeqCopyOutCPU(); err != nil {
+ t.Debugf("Failed to copy CPU to %#x for old rseq: %v", t.oldRSeqCPUAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ }
+ }
+ t.rseqInterrupt()
+ }
+
+ // Check if we need to enable single-stepping. Tracers expect that the
+ // kernel preserves the value of the single-step flag set by PTRACE_SETREGS
+ // whether or not PTRACE_SINGLESTEP/PTRACE_SYSEMU_SINGLESTEP is used (this
+ // includes our ptrace platform, by the way), so we should only clear the
+ // single-step flag if we're responsible for setting it. (clearSinglestep
+ // is therefore analogous to Linux's TIF_FORCED_TF.)
+ //
+ // Strictly speaking, we should also not clear the single-step flag if we
+ // single-step through an instruction that sets the single-step flag
+ // (arch/x86/kernel/step.c:is_setting_trap_flag()). But nobody sets their
+ // own TF. (Famous last words, I know.)
+ clearSinglestep := false
+ if t.hasTracer() {
+ t.tg.pidns.owner.mu.RLock()
+ if t.ptraceSinglestep {
+ clearSinglestep = !t.Arch().SingleStep()
+ t.Arch().SetSingleStep()
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+ }
+
+ region := trace.StartRegion(t.traceContext, runRegion)
+ t.accountTaskGoroutineEnter(TaskGoroutineRunningApp)
+ info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU)
+ t.accountTaskGoroutineLeave(TaskGoroutineRunningApp)
+ region.End()
+
+ if clearSinglestep {
+ t.Arch().ClearSingleStep()
+ }
+
+ switch err {
+ case nil:
+ // Handle application system call.
+ return t.doSyscall()
+
+ case platform.ErrContextInterrupt:
+ // Interrupted by platform.Context.Interrupt(). Re-enter the run
+ // loop to figure out why.
+ return (*runApp)(nil)
+
+ case platform.ErrContextSignalCPUID:
+ if err := app.handleCPUIDInstruction(t); err == nil {
+ // Resume execution.
+ return (*runApp)(nil)
+ }
+
+ // The instruction at the given RIP was not a CPUID, and we
+ // fall through to the default signal delivery behavior below.
+ fallthrough
+
+ case platform.ErrContextSignal:
+ // Looks like a signal has been delivered to us. If it's a synchronous
+ // signal (SEGV, SIGBUS, etc.), it should be sent to the application
+ // thread that received it.
+ sig := linux.Signal(info.Signo)
+
+ // Was it a fault that we should handle internally? If so, this wasn't
+ // an application-generated signal and we should continue execution
+ // normally.
+ if at.Any() {
+ region := trace.StartRegion(t.traceContext, faultRegion)
+ addr := usermem.Addr(info.Addr())
+ err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+ region.End()
+ if err == nil {
+ // The fault was handled appropriately.
+ // We can resume running the application.
+ return (*runApp)(nil)
+ }
+
+ // Is this a vsyscall that we need to emulate?
+ //
+ // Note that we don't track vsyscalls as part of a
+ // specific trace region. This is because regions don't
+ // stack, and the actual system call will count as a
+ // region. We should be able to easily identify
+ // vsyscalls by having a <fault><syscall> pair.
+ if at.Execute {
+ if sysno, ok := t.tc.st.LookupEmulate(addr); ok {
+ return t.doVsyscall(addr, sysno)
+ }
+ }
+
+ // Faults are common, log only at debug level.
+ t.Debugf("Unhandled user fault: addr=%x ip=%x access=%v err=%v", addr, t.Arch().IP(), at, err)
+ t.DebugDumpState()
+
+ // Continue to signal handling.
+ //
+ // Convert a BusError error to a SIGBUS from a SIGSEGV. All
+ // other info bits stay the same (address, etc.).
+ if _, ok := err.(*memmap.BusError); ok {
+ sig = linux.SIGBUS
+ info.Signo = int32(linux.SIGBUS)
+ }
+ }
+
+ switch sig {
+ case linux.SIGILL, linux.SIGSEGV, linux.SIGBUS, linux.SIGFPE, linux.SIGTRAP:
+ // Synchronous signal. Send it to ourselves. Assume the signal is
+ // legitimate and force it (work around the signal being ignored or
+ // blocked) like Linux does. Conveniently, this is even the correct
+ // behavior for SIGTRAP from single-stepping.
+ t.forceSignal(linux.Signal(sig), false /* unconditional */)
+ t.SendSignal(info)
+
+ case platform.SignalInterrupt:
+ // Assume that a call to platform.Context.Interrupt() misfired.
+
+ case linux.SIGPROF:
+ // It's a profiling interrupt: there's not much
+ // we can do. We've already paid a decent cost
+ // by intercepting the signal; at this point we
+ // simply ignore it.
+
+ default:
+ // Asynchronous signal. Let the system deal with it.
+ t.k.sendExternalSignal(info, "application")
+ }
+
+ return (*runApp)(nil)
+
+ case platform.ErrContextCPUPreempted:
+ // Ensure that rseq critical sections are interrupted and per-thread
+ // CPU values are updated before the next platform.Context.Switch().
+ t.rseqPreempted = true
+ return (*runApp)(nil)
+
+ default:
+ // What happened? Can't continue.
+ t.Warningf("Unexpected SwitchToApp error: %v", err)
+ t.PrepareExit(ExitStatus{Code: ExtractErrno(err, -1)})
+ return (*runExit)(nil)
+ }
+}
+
+// waitGoroutineStoppedOrExited blocks until t's task goroutine stops or exits.
+func (t *Task) waitGoroutineStoppedOrExited() {
+ t.goroutineStopped.Wait()
+}
+
+// WaitExited blocks until all task goroutines in tg have exited.
+//
+// WaitExited does not correspond to anything in Linux; it's provided so that
+// external callers of Kernel.CreateProcess can wait for the created thread
+// group to terminate.
+func (tg *ThreadGroup) WaitExited() {
+ tg.liveGoroutines.Wait()
+}
+
+// Yield yields the processor for the calling task.
+func (t *Task) Yield() {
+ atomic.AddUint64(&t.yieldCount, 1)
+ runtime.Gosched()
+}
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
new file mode 100644
index 000000000..09366b60c
--- /dev/null
+++ b/pkg/sentry/kernel/task_sched.go
@@ -0,0 +1,668 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// CPU scheduling, real and fake.
+
+import (
+ "fmt"
+ "math/rand"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TaskGoroutineState is a coarse representation of the current execution
+// status of a kernel.Task goroutine.
+type TaskGoroutineState int
+
+const (
+ // TaskGoroutineNonexistent indicates that the task goroutine has either
+ // not yet been created by Task.Start() or has returned from Task.run().
+ // This must be the zero value for TaskGoroutineState.
+ TaskGoroutineNonexistent TaskGoroutineState = iota
+
+ // TaskGoroutineRunningSys indicates that the task goroutine is executing
+ // sentry code.
+ TaskGoroutineRunningSys
+
+ // TaskGoroutineRunningApp indicates that the task goroutine is executing
+ // application code.
+ TaskGoroutineRunningApp
+
+ // TaskGoroutineBlockedInterruptible indicates that the task goroutine is
+ // blocked in Task.block(), and hence may be woken by Task.interrupt()
+ // (e.g. due to signal delivery).
+ TaskGoroutineBlockedInterruptible
+
+ // TaskGoroutineBlockedUninterruptible indicates that the task goroutine is
+ // blocked outside of Task.block() and Task.doStop(), and hence cannot be
+ // woken by Task.interrupt().
+ TaskGoroutineBlockedUninterruptible
+
+ // TaskGoroutineStopped indicates that the task goroutine is blocked in
+ // Task.doStop(). TaskGoroutineStopped is similar to
+ // TaskGoroutineBlockedUninterruptible, but is a separate state to make it
+ // possible to determine when Task.stop is meaningful.
+ TaskGoroutineStopped
+)
+
+// TaskGoroutineSchedInfo contains task goroutine scheduling state which must
+// be read and updated atomically.
+//
+// +stateify savable
+type TaskGoroutineSchedInfo struct {
+ // Timestamp was the value of Kernel.cpuClock when this
+ // TaskGoroutineSchedInfo was last updated.
+ Timestamp uint64
+
+ // State is the current state of the task goroutine.
+ State TaskGoroutineState
+
+ // UserTicks is the amount of time the task goroutine has spent executing
+ // its associated Task's application code, in units of linux.ClockTick.
+ UserTicks uint64
+
+ // SysTicks is the amount of time the task goroutine has spent executing in
+ // the sentry, in units of linux.ClockTick.
+ SysTicks uint64
+}
+
+// userTicksAt returns the extrapolated value of ts.UserTicks after
+// Kernel.CPUClockNow() indicates a time of now.
+//
+// Preconditions: now <= Kernel.CPUClockNow(). (Since Kernel.cpuClock is
+// monotonic, this is satisfied if now is the result of a previous call to
+// Kernel.CPUClockNow().) This requirement exists because otherwise a racing
+// change to t.gosched can cause userTicksAt to adjust stats by too much,
+// making the observed stats non-monotonic.
+func (ts *TaskGoroutineSchedInfo) userTicksAt(now uint64) uint64 {
+ if ts.Timestamp < now && ts.State == TaskGoroutineRunningApp {
+ // Update stats to reflect execution since the last update.
+ return ts.UserTicks + (now - ts.Timestamp)
+ }
+ return ts.UserTicks
+}
+
+// sysTicksAt returns the extrapolated value of ts.SysTicks after
+// Kernel.CPUClockNow() indicates a time of now.
+//
+// Preconditions: As for userTicksAt.
+func (ts *TaskGoroutineSchedInfo) sysTicksAt(now uint64) uint64 {
+ if ts.Timestamp < now && ts.State == TaskGoroutineRunningSys {
+ return ts.SysTicks + (now - ts.Timestamp)
+ }
+ return ts.SysTicks
+}
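
Both extrapolations amount to: add the ticks elapsed since the last bookkeeping update, but only while the goroutine remains in the relevant state. A worked example as a free function (hypothetical signature):

package main

import "fmt"

// userTicksAt extrapolates a tick counter: ticks elapsed since the last
// update accrue only while the goroutine is still running app code.
func userTicksAt(timestamp, userTicks, now uint64, runningApp bool) uint64 {
	if timestamp < now && runningApp {
		return userTicks + (now - timestamp)
	}
	return userTicks
}

func main() {
	fmt.Println(userTicksAt(100, 7, 104, true))  // 11: four ticks of app time accrue
	fmt.Println(userTicksAt(100, 7, 104, false)) // 7: time since Timestamp wasn't app time
}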
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) {
+ now := t.k.CPUClockNow()
+ if t.gosched.State != TaskGoroutineRunningSys {
+ panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, TaskGoroutineRunningSys, state))
+ }
+ t.goschedSeq.BeginWrite()
+ // This function is very hot; avoid defer.
+ t.gosched.SysTicks += now - t.gosched.Timestamp
+ t.gosched.Timestamp = now
+ t.gosched.State = state
+ t.goschedSeq.EndWrite()
+
+ if state != TaskGoroutineRunningApp {
+ // Task is blocking/stopping.
+ t.k.decRunningTasks()
+ }
+}
+
+// Preconditions: The caller must be running on the task goroutine, and leaving
+// a state indicated by a previous call to
+// t.accountTaskGoroutineEnter(state).
+func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
+ if state != TaskGoroutineRunningApp {
+ // Task is unblocking/continuing.
+ t.k.incRunningTasks()
+ }
+
+ now := t.k.CPUClockNow()
+ if t.gosched.State != state {
+ panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys))
+ }
+ t.goschedSeq.BeginWrite()
+ // This function is very hot; avoid defer.
+ if state == TaskGoroutineRunningApp {
+ t.gosched.UserTicks += now - t.gosched.Timestamp
+ }
+ t.gosched.Timestamp = now
+ t.gosched.State = TaskGoroutineRunningSys
+ t.goschedSeq.EndWrite()
+}
+
+// TaskGoroutineSchedInfo returns a copy of t's task goroutine scheduling info.
+// Most clients should use t.CPUStats() instead.
+func (t *Task) TaskGoroutineSchedInfo() TaskGoroutineSchedInfo {
+ return SeqAtomicLoadTaskGoroutineSchedInfo(&t.goschedSeq, &t.gosched)
+}
+
+// CPUStats returns the CPU usage statistics of t.
+func (t *Task) CPUStats() usage.CPUStats {
+ return t.cpuStatsAt(t.k.CPUClockNow())
+}
+
+// Preconditions: As for TaskGoroutineSchedInfo.userTicksAt.
+func (t *Task) cpuStatsAt(now uint64) usage.CPUStats {
+ tsched := t.TaskGoroutineSchedInfo()
+ return usage.CPUStats{
+ UserTime: time.Duration(tsched.userTicksAt(now) * uint64(linux.ClockTick)),
+ SysTime: time.Duration(tsched.sysTicksAt(now) * uint64(linux.ClockTick)),
+ VoluntarySwitches: atomic.LoadUint64(&t.yieldCount),
+ }
+}
+
+// CPUStats returns the combined CPU usage statistics of all past and present
+// threads in tg.
+func (tg *ThreadGroup) CPUStats() usage.CPUStats {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ // Hack to get a pointer to the Kernel.
+ if tg.leader == nil {
+ // Per comment on tg.leader, this is only possible if nothing in the
+ // ThreadGroup has ever executed anyway.
+ return usage.CPUStats{}
+ }
+ return tg.cpuStatsAtLocked(tg.leader.k.CPUClockNow())
+}
+
+// Preconditions: As for TaskGoroutineSchedInfo.userTicksAt. The TaskSet mutex
+// must be locked.
+func (tg *ThreadGroup) cpuStatsAtLocked(now uint64) usage.CPUStats {
+ stats := tg.exitedCPUStats
+ // Account for live tasks.
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ stats.Accumulate(t.cpuStatsAt(now))
+ }
+ return stats
+}
+
+// JoinedChildCPUStats implements the semantics of RUSAGE_CHILDREN: "Return
+// resource usage statistics for all children of [tg] that have terminated and
+// been waited for. These statistics will include the resources used by
+// grandchildren, and further removed descendants, if all of the intervening
+// descendants waited on their terminated children."
+func (tg *ThreadGroup) JoinedChildCPUStats() usage.CPUStats {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.childCPUStats
+}
+
+// taskClock is a ktime.Clock that measures the time that a task has spent
+// executing. taskClock is primarily used to implement CLOCK_THREAD_CPUTIME_ID.
+//
+// +stateify savable
+type taskClock struct {
+ t *Task
+
+ // If includeSys is true, the taskClock includes both time spent executing
+ // application code as well as time spent in the sentry. Otherwise, the
+ // taskClock includes only time spent executing application code.
+ includeSys bool
+
+ // Implements waiter.Waitable. TimeUntil wouldn't change its estimation
+ // based on either of the clock events, so there's no event to be
+ // notified for.
+ ktime.NoClockEvents `state:"nosave"`
+
+ // Implements ktime.Clock.WallTimeUntil.
+ //
+ // As an upper bound, a task's clock cannot advance faster than CPU
+ // time. It would have to execute at a rate of more than 1 task-second
+ // per 1 CPU-second, which isn't possible.
+ ktime.WallRateClock `state:"nosave"`
+}
+
+// UserCPUClock returns a clock measuring the CPU time the task has spent
+// executing application code.
+func (t *Task) UserCPUClock() ktime.Clock {
+ return &taskClock{t: t, includeSys: false}
+}
+
+// CPUClock returns a clock measuring the CPU time the task has spent executing
+// application and "kernel" code.
+func (t *Task) CPUClock() ktime.Clock {
+ return &taskClock{t: t, includeSys: true}
+}
+
+// Now implements ktime.Clock.Now.
+func (tc *taskClock) Now() ktime.Time {
+ stats := tc.t.CPUStats()
+ if tc.includeSys {
+ return ktime.FromNanoseconds((stats.UserTime + stats.SysTime).Nanoseconds())
+ }
+ return ktime.FromNanoseconds(stats.UserTime.Nanoseconds())
+}
+
+// tgClock is a ktime.Clock that measures the time a thread group has spent
+// executing. tgClock is primarily used to implement CLOCK_PROCESS_CPUTIME_ID.
+//
+// +stateify savable
+type tgClock struct {
+ tg *ThreadGroup
+
+ // If includeSys is true, the tgClock includes both time spent executing
+ // application code as well as time spent in the sentry. Otherwise, the
+ // tgClock includes only time spent executing application code.
+ includeSys bool
+
+ // Implements waiter.Waitable.
+ ktime.ClockEventsQueue `state:"nosave"`
+}
+
+// Now implements ktime.Clock.Now.
+func (tgc *tgClock) Now() ktime.Time {
+ stats := tgc.tg.CPUStats()
+ if tgc.includeSys {
+ return ktime.FromNanoseconds((stats.UserTime + stats.SysTime).Nanoseconds())
+ }
+ return ktime.FromNanoseconds(stats.UserTime.Nanoseconds())
+}
+
+// WallTimeUntil implements ktime.Clock.WallTimeUntil.
+func (tgc *tgClock) WallTimeUntil(t, now ktime.Time) time.Duration {
+ // Thread group CPU time should not exceed wall time * live tasks, since
+ // task goroutines exit after the transition to TaskExitZombie in
+ // runExitNotify.
+ tgc.tg.pidns.owner.mu.RLock()
+ n := tgc.tg.liveTasks
+ tgc.tg.pidns.owner.mu.RUnlock()
+ if n == 0 {
+ if t.Before(now) {
+ return 0
+ }
+ // The timer tick raced with thread group exit, after which no more
+ // tasks can enter the thread group. So tgc.Now() will never advance
+ // again. Return a large delay; the timer should be stopped long before
+ // it comes again anyway.
+ return time.Hour
+ }
+ // This is a lower bound on the amount of time that can elapse before an
+ // associated timer expires, so returning this value tends to result in a
+ // sequence of closely-spaced ticks just before timer expiry. To avoid
+ // this, round up to the nearest ClockTick; CPU usage measurements are
+ // limited to this resolution anyway.
+ remaining := time.Duration(t.Sub(now).Nanoseconds()/int64(n)) * time.Nanosecond
+ return ((remaining + (linux.ClockTick - time.Nanosecond)) / linux.ClockTick) * linux.ClockTick
+}
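
The final expression is a ceiling division in time.Duration arithmetic: round the per-task remaining time up to a whole ClockTick. A sketch (the tick length here is an assumed value for illustration):

package main

import (
	"fmt"
	"time"
)

func main() {
	const clockTick = 10 * time.Millisecond // assumed tick length (CONFIG_HZ=100)
	remaining := 23 * time.Millisecond
	// Adding (tick - 1ns) before the integer division rounds up, as
	// WallTimeUntil does.
	rounded := ((remaining + (clockTick - time.Nanosecond)) / clockTick) * clockTick
	fmt.Println(rounded) // 30ms
}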
+
+// UserCPUClock returns a ktime.Clock that measures the time that a thread
+// group has spent executing.
+func (tg *ThreadGroup) UserCPUClock() ktime.Clock {
+ return &tgClock{tg: tg, includeSys: false}
+}
+
+// CPUClock returns a ktime.Clock that measures the time that a thread group
+// has spent executing, including sentry time.
+func (tg *ThreadGroup) CPUClock() ktime.Clock {
+ return &tgClock{tg: tg, includeSys: true}
+}
+
+type kernelCPUClockTicker struct {
+ k *Kernel
+
+ // These are essentially kernelCPUClockTicker.Notify local variables that
+ // are cached between calls to reduce allocations.
+ rng *rand.Rand
+ tgs []*ThreadGroup
+}
+
+func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker {
+ return &kernelCPUClockTicker{
+ k: k,
+ rng: rand.New(rand.NewSource(rand.Int63())),
+ }
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ // Only increment cpuClock by 1 regardless of the number of expirations.
+ // This approximately compensates for cases where thread throttling or bad
+ // Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and
+ // presumably task goroutines as well, from executing for a long period of
+ // time. It's also necessary to prevent CPU clocks from seeing large
+ // discontinuous jumps.
+ now := atomic.AddUint64(&ticker.k.cpuClock, 1)
+
+ // Check thread group CPU timers.
+ tgs := ticker.k.tasks.Root.ThreadGroupsAppend(ticker.tgs)
+ for _, tg := range tgs {
+ if atomic.LoadUint32(&tg.cpuTimersEnabled) == 0 {
+ continue
+ }
+
+ ticker.k.tasks.mu.RLock()
+ if tg.leader == nil {
+ // No tasks have ever run in this thread group.
+ ticker.k.tasks.mu.RUnlock()
+ continue
+ }
+ // Accumulate thread group CPU stats, and randomly select running tasks
+ // using reservoir sampling to receive CPU timer signals.
+ var virtReceiver *Task
+ nrVirtCandidates := 0
+ var profReceiver *Task
+ nrProfCandidates := 0
+ tgUserTime := tg.exitedCPUStats.UserTime
+ tgSysTime := tg.exitedCPUStats.SysTime
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ tsched := t.TaskGoroutineSchedInfo()
+ tgUserTime += time.Duration(tsched.userTicksAt(now) * uint64(linux.ClockTick))
+ tgSysTime += time.Duration(tsched.sysTicksAt(now) * uint64(linux.ClockTick))
+ switch tsched.State {
+ case TaskGoroutineRunningApp:
+ // Considered by ITIMER_VIRT, ITIMER_PROF, and RLIMIT_CPU
+ // timers.
+ nrVirtCandidates++
+ if int(randInt31n(ticker.rng, int32(nrVirtCandidates))) == 0 {
+ virtReceiver = t
+ }
+ fallthrough
+ case TaskGoroutineRunningSys:
+ // Considered by ITIMER_PROF and RLIMIT_CPU timers.
+ nrProfCandidates++
+ if int(randInt31n(ticker.rng, int32(nrProfCandidates))) == 0 {
+ profReceiver = t
+ }
+ }
+ }
+ tgVirtNow := ktime.FromNanoseconds(tgUserTime.Nanoseconds())
+ tgProfNow := ktime.FromNanoseconds((tgUserTime + tgSysTime).Nanoseconds())
+
+ // All of the following are standard (not real-time) signals, which are
+ // automatically deduplicated, so we ignore the number of expirations.
+ tg.signalHandlers.mu.Lock()
+ // It should only be possible for these timers to advance if we found
+ // at least one running task.
+ if virtReceiver != nil {
+ // ITIMER_VIRTUAL
+ newItimerVirtSetting, exp := tg.itimerVirtSetting.At(tgVirtNow)
+ tg.itimerVirtSetting = newItimerVirtSetting
+ if exp != 0 {
+ virtReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGVTALRM), true)
+ }
+ }
+ if profReceiver != nil {
+ // ITIMER_PROF
+ newItimerProfSetting, exp := tg.itimerProfSetting.At(tgProfNow)
+ tg.itimerProfSetting = newItimerProfSetting
+ if exp != 0 {
+ profReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGPROF), true)
+ }
+ // RLIMIT_CPU soft limit
+ newRlimitCPUSoftSetting, exp := tg.rlimitCPUSoftSetting.At(tgProfNow)
+ tg.rlimitCPUSoftSetting = newRlimitCPUSoftSetting
+ if exp != 0 {
+ profReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGXCPU), true)
+ }
+ // RLIMIT_CPU hard limit
+ rlimitCPUMax := tg.limits.Get(limits.CPU).Max
+ if rlimitCPUMax != limits.Infinity && !tgProfNow.Before(ktime.FromSeconds(int64(rlimitCPUMax))) {
+ profReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGKILL), true)
+ }
+ }
+ tg.signalHandlers.mu.Unlock()
+
+ ticker.k.tasks.mu.RUnlock()
+ }
+
+ // Retain tgs between calls to Notify to reduce allocations.
+ for i := range tgs {
+ tgs[i] = nil
+ }
+ ticker.tgs = tgs[:0]
+
+ // If nothing is running, we can disable the timer.
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks == 0 {
+ ticker.k.runningTasksMu.Lock()
+ defer ticker.k.runningTasksMu.Unlock()
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks != 0 {
+ // Raced with a 0 -> 1 transition.
+ return setting, false
+ }
+
+ // Stop the timer. We must cache the current setting so the
+ // kernel can access it without violating the lock order.
+ ticker.k.cpuClockTickerSetting = setting
+ ticker.k.cpuClockTickerDisabled = true
+ setting.Enabled = false
+ return setting, true
+ }
+
+ return setting, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (ticker *kernelCPUClockTicker) Destroy() {
+}
+
+// randInt31n returns a random integer in [0, n).
+//
+// randInt31n is equivalent to math/rand.Rand.int31n(), which is unexported.
+// See that function for details.
+func randInt31n(rng *rand.Rand, n int32) int32 {
+ v := rng.Uint32()
+ prod := uint64(v) * uint64(n)
+ low := uint32(prod)
+ if low < uint32(n) {
+ thresh := uint32(-n) % uint32(n)
+ for low < thresh {
+ v = rng.Uint32()
+ prod = uint64(v) * uint64(n)
+ low = uint32(prod)
+ }
+ }
+ return int32(prod >> 32)
+}
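
Notify's virtReceiver/profReceiver selection above is single-element reservoir sampling: the nth candidate replaces the current pick with probability 1/n, leaving every candidate equally likely once the loop ends. A sketch using math/rand directly:

package main

import (
	"fmt"
	"math/rand"
)

func main() {
	rng := rand.New(rand.NewSource(1))
	var pick string
	n := 0
	for _, cand := range []string{"t1", "t2", "t3", "t4"} {
		n++
		// Replace the current pick with probability 1/n; after the loop each
		// candidate was chosen with probability 1/len(candidates).
		if rng.Int31n(int32(n)) == 0 {
			pick = cand
		}
	}
	fmt.Println(pick)
}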
+
+// NotifyRlimitCPUUpdated is called by setrlimit.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) NotifyRlimitCPUUpdated() {
+ t.k.cpuClockTicker.Atomically(func() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ rlimitCPU := t.tg.limits.Get(limits.CPU)
+ t.tg.rlimitCPUSoftSetting = ktime.Setting{
+ Enabled: rlimitCPU.Cur != limits.Infinity,
+ Next: ktime.FromNanoseconds((time.Duration(rlimitCPU.Cur) * time.Second).Nanoseconds()),
+ Period: time.Second,
+ }
+ if rlimitCPU.Max != limits.Infinity {
+ // Check if tg is already over the hard limit.
+ tgcpu := t.tg.cpuStatsAtLocked(t.k.CPUClockNow())
+ tgProfNow := ktime.FromNanoseconds((tgcpu.UserTime + tgcpu.SysTime).Nanoseconds())
+ if !tgProfNow.Before(ktime.FromSeconds(int64(rlimitCPU.Max))) {
+ t.sendSignalLocked(SignalInfoPriv(linux.SIGKILL), true)
+ }
+ }
+ t.tg.updateCPUTimersEnabledLocked()
+ })
+}
+
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) updateCPUTimersEnabledLocked() {
+ rlimitCPU := tg.limits.Get(limits.CPU)
+ if tg.itimerVirtSetting.Enabled || tg.itimerProfSetting.Enabled || tg.rlimitCPUSoftSetting.Enabled || rlimitCPU.Max != limits.Infinity {
+ atomic.StoreUint32(&tg.cpuTimersEnabled, 1)
+ } else {
+ atomic.StoreUint32(&tg.cpuTimersEnabled, 0)
+ }
+}
+
+// StateStatus returns a string representation of the task's current state,
+// appropriate for /proc/[pid]/status.
+func (t *Task) StateStatus() string {
+ switch s := t.TaskGoroutineSchedInfo().State; s {
+ case TaskGoroutineNonexistent:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ switch t.exitState {
+ case TaskExitZombie:
+ return "Z (zombie)"
+ case TaskExitDead:
+ return "X (dead)"
+ default:
+ // The task goroutine can't exit before passing through
+ // runExitNotify, so this indicates that the task has been created,
+ // but the task goroutine hasn't yet started. The Linux equivalent
+ // is struct task_struct::state == TASK_NEW
+ // (kernel/fork.c:copy_process() =>
+ // kernel/sched/core.c:sched_fork()), but the TASK_NEW bit is
+ // masked out by TASK_REPORT for /proc/[pid]/status, leaving only
+ // TASK_RUNNING.
+ return "R (running)"
+ }
+ case TaskGoroutineRunningSys, TaskGoroutineRunningApp:
+ return "R (running)"
+ case TaskGoroutineBlockedInterruptible:
+ return "S (sleeping)"
+ case TaskGoroutineStopped:
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ switch t.stop.(type) {
+ case *groupStop:
+ return "T (stopped)"
+ case *ptraceStop:
+ return "t (tracing stop)"
+ }
+ fallthrough
+ case TaskGoroutineBlockedUninterruptible:
+ // This is the name Linux uses for TASK_UNINTERRUPTIBLE and
+ // TASK_KILLABLE (= TASK_UNINTERRUPTIBLE | TASK_WAKEKILL):
+ // fs/proc/array.c:task_state_array.
+ return "D (disk sleep)"
+ default:
+ panic(fmt.Sprintf("Invalid TaskGoroutineState: %v", s))
+ }
+}
+
+// CPUMask returns a copy of t's allowed CPU mask.
+func (t *Task) CPUMask() sched.CPUSet {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.allowedCPUMask.Copy()
+}
+
+// SetCPUMask sets t's allowed CPU mask based on mask. It takes ownership of
+// mask.
+//
+// Preconditions: mask.Size() ==
+// sched.CPUSetSize(t.Kernel().ApplicationCores()).
+func (t *Task) SetCPUMask(mask sched.CPUSet) error {
+ if want := sched.CPUSetSize(t.k.applicationCores); mask.Size() != want {
+ panic(fmt.Sprintf("Invalid CPUSet %v (expected %d bytes)", mask, want))
+ }
+
+ // Remove CPUs in mask above Kernel.applicationCores.
+ mask.ClearAbove(t.k.applicationCores)
+
+ // Ensure that at least 1 CPU is still allowed.
+ if mask.NumCPUs() == 0 {
+ return syserror.EINVAL
+ }
+
+ if t.k.useHostCores {
+ // No-op; pretend the mask was immediately changed back.
+ return nil
+ }
+
+ t.tg.pidns.owner.mu.RLock()
+ rootTID := t.tg.pidns.owner.Root.tids[t]
+ t.tg.pidns.owner.mu.RUnlock()
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.allowedCPUMask = mask
+ atomic.StoreInt32(&t.cpu, assignCPU(mask, rootTID))
+ return nil
+}
+
+// CPU returns the cpu id for a given task.
+func (t *Task) CPU() int32 {
+ if t.k.useHostCores {
+ return int32(hostcpu.GetCPU())
+ }
+
+ return atomic.LoadInt32(&t.cpu)
+}
+
+// assignCPU returns the virtualized CPU number for the task with global TID
+// tid and allowedCPUMask allowed.
+func assignCPU(allowed sched.CPUSet, tid ThreadID) (cpu int32) {
+ // To pretend that threads are evenly distributed to allowed CPUs, choose n
+ // to be less than the number of CPUs in allowed ...
+ n := int(tid) % int(allowed.NumCPUs())
+ // ... then pick the nth CPU in allowed.
+ allowed.ForEachCPU(func(c uint) {
+ if n--; n == 0 {
+ cpu = int32(c)
+ }
+ })
+ return cpu
+}
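+
+// As a worked illustration of the arithmetic above (the values are chosen
+// for the example, not taken from any real configuration): with allowed =
+// {0, 2, 3} (NumCPUs() == 3) and tid = 7, n = 7 % 3 = 1; the first CPU
+// visited decrements n to 0, so:
+//
+// cpu := assignCPU(allowed, 7) // cpu == 0
+//
+// When tid % NumCPUs() == 0, n never reaches 0 inside the loop, and cpu
+// keeps its zero value, likewise selecting CPU 0.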
+
+// Niceness returns t's niceness.
+func (t *Task) Niceness() int {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.niceness
+}
+
+// Priority returns t's priority.
+func (t *Task) Priority() int {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.niceness + 20
+}
+
+// SetNiceness sets t's niceness to n.
+func (t *Task) SetNiceness(n int) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.niceness = n
+}
+
+// NumaPolicy returns t's current numa policy.
+func (t *Task) NumaPolicy() (policy linux.NumaPolicy, nodeMask uint64) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.numaPolicy, t.numaNodeMask
+}
+
+// SetNumaPolicy sets t's numa policy.
+func (t *Task) SetNumaPolicy(policy linux.NumaPolicy, nodeMask uint64) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.numaPolicy = policy
+ t.numaNodeMask = nodeMask
+}
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
new file mode 100644
index 000000000..79766cafe
--- /dev/null
+++ b/pkg/sentry/kernel/task_signals.go
@@ -0,0 +1,1139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file defines the behavior of task signal handling.
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalAction is an internal signal action.
+type SignalAction int
+
+// Available signal actions.
+// Note that although we refer to the complete set internally,
+// the application can only select the Default and Ignore actions
+// via the system call interface.
+const (
+ SignalActionTerm SignalAction = iota
+ SignalActionCore
+ SignalActionStop
+ SignalActionIgnore
+ SignalActionHandler
+)
+
+// Default signal handler actions. Note that for most signals
+// (except SIGKILL and SIGSTOP), these can be overridden by the app.
+var defaultActions = map[linux.Signal]SignalAction{
+ // POSIX.1-1990 standard.
+ linux.SIGHUP: SignalActionTerm,
+ linux.SIGINT: SignalActionTerm,
+ linux.SIGQUIT: SignalActionCore,
+ linux.SIGILL: SignalActionCore,
+ linux.SIGABRT: SignalActionCore,
+ linux.SIGFPE: SignalActionCore,
+ linux.SIGKILL: SignalActionTerm, // but see ThreadGroup.applySignalSideEffects
+ linux.SIGSEGV: SignalActionCore,
+ linux.SIGPIPE: SignalActionTerm,
+ linux.SIGALRM: SignalActionTerm,
+ linux.SIGTERM: SignalActionTerm,
+ linux.SIGUSR1: SignalActionTerm,
+ linux.SIGUSR2: SignalActionTerm,
+ linux.SIGCHLD: SignalActionIgnore,
+ linux.SIGCONT: SignalActionIgnore, // but see ThreadGroup.applySignalSideEffects
+ linux.SIGSTOP: SignalActionStop,
+ linux.SIGTSTP: SignalActionStop,
+ linux.SIGTTIN: SignalActionStop,
+ linux.SIGTTOU: SignalActionStop,
+ // POSIX.1-2001 standard.
+ linux.SIGBUS: SignalActionCore,
+ linux.SIGPROF: SignalActionTerm,
+ linux.SIGSYS: SignalActionCore,
+ linux.SIGTRAP: SignalActionCore,
+ linux.SIGURG: SignalActionIgnore,
+ linux.SIGVTALRM: SignalActionTerm,
+ linux.SIGXCPU: SignalActionCore,
+ linux.SIGXFSZ: SignalActionCore,
+// The rest are Linux-specific.
+ linux.SIGSTKFLT: SignalActionTerm,
+ linux.SIGIO: SignalActionTerm,
+ linux.SIGPWR: SignalActionTerm,
+ linux.SIGWINCH: SignalActionIgnore,
+}
+
+// computeAction figures out what to do given a signal number
+// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop,
+// and SIGKILL always results in a SignalActionTerm.
+// Signal 0 is always ignored, as many programs use it to probe for process
+// existence and don't expect it to do anything.
+//
+// In the event the signal is not one of these, act.Handler determines what
+// happens next.
+// If act.Handler is:
+// 0, the default action is taken;
+// 1, the signal is ignored;
+// anything else, the function returns SignalActionHandler.
+func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction {
+ switch sig {
+ case linux.SIGSTOP:
+ return SignalActionStop
+ case linux.SIGKILL:
+ return SignalActionTerm
+ case linux.Signal(0):
+ return SignalActionIgnore
+ }
+
+ switch act.Handler {
+ case arch.SignalActDefault:
+ return defaultActions[sig]
+ case arch.SignalActIgnore:
+ return SignalActionIgnore
+ default:
+ return SignalActionHandler
+ }
+}
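+
+// For example, with the zero-value arch.SignalAct (whose Handler is
+// arch.SignalActDefault), the defaultActions table above yields:
+//
+// computeAction(linux.SIGTERM, arch.SignalAct{}) // SignalActionTerm
+// computeAction(linux.SIGCHLD, arch.SignalAct{}) // SignalActionIgnore
+// computeAction(linux.SIGKILL, arch.SignalAct{}) // SignalActionTerm (not overridable)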
+
+// UnblockableSignals contains the set of signals which cannot be blocked.
+var UnblockableSignals = linux.MakeSignalSet(linux.SIGKILL, linux.SIGSTOP)
+
+// StopSignals is the set of signals whose default action is SignalActionStop.
+var StopSignals = linux.MakeSignalSet(linux.SIGSTOP, linux.SIGTSTP, linux.SIGTTIN, linux.SIGTTOU)
+
+// dequeueSignalLocked returns a pending signal that is *not* included in mask.
+// If there are no pending unmasked signals, dequeueSignalLocked returns nil.
+//
+// Preconditions: t.tg.signalHandlers.mu must be locked.
+func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *arch.SignalInfo {
+ if info := t.pendingSignals.dequeue(mask); info != nil {
+ return info
+ }
+ return t.tg.pendingSignals.dequeue(mask)
+}
+
+// discardSpecificLocked removes all instances of the given signal from all
+// signal queues in tg.
+//
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) discardSpecificLocked(sig linux.Signal) {
+ tg.pendingSignals.discardSpecific(sig)
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ t.pendingSignals.discardSpecific(sig)
+ }
+}
+
+// PendingSignals returns the set of pending signals.
+func (t *Task) PendingSignals() linux.SignalSet {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.pendingSignals.pendingSet | t.tg.pendingSignals.pendingSet
+}
+
+// deliverSignal delivers the given signal and returns the following run state.
+func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState {
+ sigact := computeAction(linux.Signal(info.Signo), act)
+
+ if t.haveSyscallReturn {
+ if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
+ // Signals that are ignored, cause a thread group stop, or
+ // terminate the thread group do not interact with interrupted
+ // syscalls; in Linux terms, they are never returned to the signal
+ // handling path from get_signal => get_signal_to_deliver. The
+ // behavior of an interrupted syscall is determined by the first
+ // signal that is actually handled (by userspace).
+ if sigact == SignalActionHandler {
+ switch {
+ case sre == ERESTARTNOHAND:
+ fallthrough
+ case sre == ERESTART_RESTARTBLOCK:
+ fallthrough
+ case (sre == ERESTARTSYS && !act.IsRestart()):
+ t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
+ t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
+ default:
+ t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
+ t.Arch().RestartSyscall()
+ }
+ }
+ }
+ }
+
+ switch sigact {
+ case SignalActionTerm, SignalActionCore:
+ // "Default action is to terminate the process." - signal(7)
+ t.Debugf("Signal %d: terminating thread group", info.Signo)
+
+ // Emit an event channel message related to this uncaught signal.
+ ucs := &ucspb.UncaughtSignal{
+ Tid: int32(t.Kernel().TaskSet().Root.IDOfTask(t)),
+ Pid: int32(t.Kernel().TaskSet().Root.IDOfThreadGroup(t.ThreadGroup())),
+ Registers: t.Arch().StateData().Proto(),
+ SignalNumber: info.Signo,
+ }
+
+ // Attach a fault address if appropriate.
+ switch linux.Signal(info.Signo) {
+ case linux.SIGSEGV, linux.SIGFPE, linux.SIGILL, linux.SIGTRAP, linux.SIGBUS:
+ ucs.FaultAddr = info.Addr()
+ }
+
+ eventchannel.Emit(ucs)
+
+ t.PrepareGroupExit(ExitStatus{Signo: int(info.Signo)})
+ return (*runExit)(nil)
+
+ case SignalActionStop:
+ // "Default action is to stop the process."
+ t.initiateGroupStop(info)
+
+ case SignalActionIgnore:
+ // "Default action is to ignore the signal."
+ t.Debugf("Signal %d: ignored", info.Signo)
+
+ case SignalActionHandler:
+ // Try to deliver the signal to the user-configured handler.
+ t.Debugf("Signal %d: delivering to handler", info.Signo)
+ if err := t.deliverSignalToHandler(info, act); err != nil {
+ // This is not a warning, it can occur during normal operation.
+ t.Debugf("Failed to deliver signal %+v to user handler: %v", info, err)
+
+ // Send a forced SIGSEGV. If the signal that couldn't be delivered
+ // was a SIGSEGV, force the handler to SIG_DFL.
+ t.forceSignal(linux.SIGSEGV, linux.Signal(info.Signo) == linux.SIGSEGV /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown signal action %+v, %d?", info, computeAction(linux.Signal(info.Signo), act)))
+ }
+ return (*runInterrupt)(nil)
+}
+
+// deliverSignalToHandler changes the task's userspace state to enter the given
+// user-configured handler for the given signal.
+func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error {
+ // Signal delivery to an application handler interrupts restartable
+ // sequences.
+ t.rseqInterrupt()
+
+ // Are we executing on the main stack,
+ // or on the provided alternate stack?
+ sp := usermem.Addr(t.Arch().Stack())
+
+ // N.B. This is a *copy* of the alternate stack that the user's signal
+ // handler expects to see in its ucontext (even if it's not in use).
+ alt := t.signalStack
+ if act.IsOnStack() && alt.IsEnabled() {
+ alt.SetOnStack()
+ if !alt.Contains(sp) {
+ sp = usermem.Addr(alt.Top())
+ }
+ }
+
+ // Set up the signal handler. If we have a saved signal mask, the signal
+ // handler should run with the current mask, but sigreturn should restore
+ // the saved one.
+ st := &arch.Stack{t.Arch(), t.MemoryManager(), sp}
+ mask := t.signalMask
+ if t.haveSavedSignalMask {
+ mask = t.savedSignalMask
+ }
+
+ // Set up the restorer.
+ // x86-64 should always use SA_RESTORER, but the flag is optional on
+ // other platforms; see the Linux code for reference:
+ // linux/arch/x86/kernel/signal.c:__setup_rt_frame()
+ // If SA_RESTORER is not configured, we can use the sigreturn
+ // trampolines that the vdso provides instead; again, see the Linux
+ // code for reference:
+ // linux/arch/arm64/kernel/signal.c:setup_return()
+ if act.Flags&linux.SA_RESTORER == 0 {
+ act.Restorer = t.MemoryManager().VDSOSigReturn()
+ }
+
+ if err := t.Arch().SignalSetup(st, &act, info, &alt, mask); err != nil {
+ return err
+ }
+ t.haveSavedSignalMask = false
+
+ // Add our signal mask.
+ newMask := t.signalMask | act.Mask
+ if !act.IsNoDefer() {
+ newMask |= linux.SignalSetOf(linux.Signal(info.Signo))
+ }
+ t.SetSignalMask(newMask)
+
+ return nil
+}
+
+var ctrlResume = &SyscallControl{ignoreReturn: true}
+
+// SignalReturn implements sigreturn(2) (if rt is false) or rt_sigreturn(2) (if
+// rt is true).
+func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) {
+ st := t.Stack()
+ sigset, alt, err := t.Arch().SignalRestore(st, rt)
+ if err != nil {
+ return nil, err
+ }
+
+ // Attempt to record the given signal stack. Note that we silently
+ // ignore failures here, as does Linux. Only an EFAULT may be
+ // generated, but SignalRestore has already deserialized the entire
+ // frame successfully.
+ t.SetSignalStack(alt)
+
+ // Restore our signal mask. SIGKILL and SIGSTOP should not be blocked.
+ t.SetSignalMask(sigset &^ UnblockableSignals)
+
+ return ctrlResume, nil
+}
+
+// Sigtimedwait implements the semantics of sigtimedwait(2).
+//
+// Preconditions: The caller must be running on the task goroutine. t.exitState
+// < TaskExitZombie.
+func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) {
+ // set is the set of signals we're interested in; invert it to get the set
+ // of signals to block.
+ mask := ^(set &^ UnblockableSignals)
+
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if info := t.dequeueSignalLocked(mask); info != nil {
+ return info, nil
+ }
+
+ if timeout == 0 {
+ return nil, syserror.EAGAIN
+ }
+
+ // Unblock signals we're waiting for. Remember the original signal mask so
+ // that Task.sendSignalTimerLocked doesn't discard ignored signals that
+ // we're temporarily unblocking.
+ t.realSignalMask = t.signalMask
+ t.setSignalMaskLocked(t.signalMask & mask)
+
+ // Wait for a timeout or new signal.
+ t.tg.signalHandlers.mu.Unlock()
+ _, err := t.BlockWithTimeout(nil, true, timeout)
+ t.tg.signalHandlers.mu.Lock()
+
+ // Restore the original signal mask.
+ t.setSignalMaskLocked(t.realSignalMask)
+ t.realSignalMask = 0
+
+ if info := t.dequeueSignalLocked(mask); info != nil {
+ return info, nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return nil, syserror.EAGAIN
+ }
+ return nil, err
+}
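+
+// A typical use, waiting up to one second for a SIGCHLD (the signal and
+// timeout here are illustrative):
+//
+// set := linux.MakeSignalSet(linux.SIGCHLD)
+// info, err := t.Sigtimedwait(set, time.Second)
+// if err == syserror.EAGAIN {
+//     // No matching signal arrived before the timeout.
+// }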
+
+// SendSignal sends the given signal to t.
+//
+// The following errors may be returned:
+//
+// syserror.ESRCH - The task has exited.
+// syserror.EINVAL - The signal is not valid.
+// syserror.EAGAIN - The signal is realtime, and cannot be queued.
+//
+func (t *Task) SendSignal(info *arch.SignalInfo) error {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.sendSignalLocked(info, false /* group */)
+}
+
+// SendGroupSignal sends the given signal to t's thread group.
+func (t *Task) SendGroupSignal(info *arch.SignalInfo) error {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.sendSignalLocked(info, true /* group */)
+}
+
+// SendSignal sends the given signal to tg, using tg's leader to determine if
+// the signal is blocked.
+func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.leader.sendSignalLocked(info, true /* group */)
+}
+
+func (t *Task) sendSignalLocked(info *arch.SignalInfo, group bool) error {
+ return t.sendSignalTimerLocked(info, group, nil)
+}
+
+func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *IntervalTimer) error {
+ if t.exitState == TaskExitDead {
+ return syserror.ESRCH
+ }
+ sig := linux.Signal(info.Signo)
+ if sig == 0 {
+ return nil
+ }
+ if !sig.IsValid() {
+ return syserror.EINVAL
+ }
+
+ // Signal side effects apply even if the signal is ultimately discarded.
+ t.tg.applySignalSideEffectsLocked(sig)
+
+ // TODO: "Only signals for which the "init" process has established a
+ // signal handler can be sent to the "init" process by other members of the
+ // PID namespace. This restriction applies even to privileged processes,
+ // and prevents other members of the PID namespace from accidentally
+ // killing the "init" process." - pid_namespaces(7). We don't currently do
+ // this for child namespaces, though we should; we also don't do this for
+ // the root namespace (the same restriction applies to global init on
+ // Linux), where whether or not we should is much murkier. In practice,
+ // most sandboxed applications are not prepared to function as an init
+ // process.
+
+ // Unmasked, ignored signals are discarded without being queued, unless
+ // they will be visible to a tracer. Even for group signals, it's the
+ // originally-targeted task's signal mask and tracer that matter; compare
+ // Linux's kernel/signal.c:__send_signal() => prepare_signal() =>
+ // sig_ignored().
+ ignored := computeAction(sig, t.tg.signalHandlers.actions[sig]) == SignalActionIgnore
+ if sigset := linux.SignalSetOf(sig); sigset&t.signalMask == 0 && sigset&t.realSignalMask == 0 && ignored && !t.hasTracer() {
+ t.Debugf("Discarding ignored signal %d", sig)
+ if timer != nil {
+ timer.signalRejectedLocked()
+ }
+ return nil
+ }
+
+ q := &t.pendingSignals
+ if group {
+ q = &t.tg.pendingSignals
+ }
+ if !q.enqueue(info, timer) {
+ if sig.IsRealtime() {
+ return syserror.EAGAIN
+ }
+ t.Debugf("Discarding duplicate signal %d", sig)
+ if timer != nil {
+ timer.signalRejectedLocked()
+ }
+ return nil
+ }
+
+ // Find a receiver to notify. Note that the task we choose to notify, if
+ // any, may not be the task that actually dequeues and handles the signal;
+ // e.g. a racing signal mask change may cause the notified task to become
+ // ineligible, or a racing sibling task may dequeue the signal first.
+ if t.canReceiveSignalLocked(sig) {
+ t.Debugf("Notified of signal %d", sig)
+ t.interrupt()
+ return nil
+ }
+ if group {
+ if nt := t.tg.findSignalReceiverLocked(sig); nt != nil {
+ nt.Debugf("Notified of group signal %d", sig)
+ nt.interrupt()
+ return nil
+ }
+ }
+ t.Debugf("No task notified of signal %d", sig)
+ return nil
+}
+
+func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) {
+ switch {
+ case linux.SignalSetOf(sig)&StopSignals != 0:
+ // Stop signals cause all prior SIGCONT to be discarded. (This has
+ // little practical effect, since SIGCONT's most important side
+ // effect is applied when the signal is sent, in the branch below,
+ // not when the signal is delivered.)
+ tg.discardSpecificLocked(linux.SIGCONT)
+ case sig == linux.SIGCONT:
+ // "The SIGCONT signal has a side effect of waking up (all threads of)
+ // a group-stopped process. This side effect happens before
+ // signal-delivery-stop. The tracer can't suppress this side effect (it
+ // can only suppress signal injection, which only causes the SIGCONT
+ // handler to not be executed in the tracee, if such a handler is
+ // installed)." - ptrace(2)
+ tg.endGroupStopLocked(true)
+ case sig == linux.SIGKILL:
+ // "SIGKILL does not generate signal-delivery-stop and therefore the
+ // tracer can't suppress it. SIGKILL kills even within system calls
+ // (syscall-exit-stop is not generated prior to death by SIGKILL)." -
+ // ptrace(2)
+ //
+ // Note that this differs from ThreadGroup.requestExit in that it
+ // ignores tg.execing.
+ if !tg.exiting {
+ tg.exiting = true
+ tg.exitStatus = ExitStatus{Signo: int(linux.SIGKILL)}
+ }
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ t.killLocked()
+ }
+ }
+}
+
+// canReceiveSignalLocked returns true if t should be interrupted to receive
+// the given signal. canReceiveSignalLocked is analogous to Linux's
+// kernel/signal.c:wants_signal(), but see below for divergences.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
+ // Notify that the signal is queued.
+ t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig)))
+
+ // - Do not choose tasks that are blocking the signal.
+ if linux.SignalSetOf(sig)&t.signalMask != 0 {
+ return false
+ }
+ // - No need to check Task.exitState, as the exit path sets every bit in the
+ // signal mask when it transitions from TaskExitNone to TaskExitInitiated.
+ // - No special case for SIGKILL: SIGKILL already interrupted all tasks in
+ // the thread group via applySignalSideEffects => killLocked.
+ // - Do not choose stopped tasks, which cannot handle signals.
+ if t.stop != nil {
+ return false
+ }
+ // - Do not choose tasks that have already been interrupted, as they may be
+ // busy handling another signal.
+ if len(t.interruptChan) != 0 {
+ return false
+ }
+ return true
+}
+
+// findSignalReceiverLocked returns a task in tg that should be interrupted to
+// receive the given signal. If no such task exists, findSignalReceiverLocked
+// returns nil.
+//
+// Linux, by contrast, records curr_target to balance group signal delivery
+// across tasks.
+//
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) findSignalReceiverLocked(sig linux.Signal) *Task {
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ if t.canReceiveSignalLocked(sig) {
+ return t
+ }
+ }
+ return nil
+}
+
+// forceSignal ensures that the task is not ignoring or blocking the given
+// signal. If unconditional is true, forceSignal takes action even if the
+// signal isn't being ignored or blocked.
+func (t *Task) forceSignal(sig linux.Signal, unconditional bool) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.forceSignalLocked(sig, unconditional)
+}
+
+func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) {
+ blocked := linux.SignalSetOf(sig)&t.signalMask != 0
+ act := t.tg.signalHandlers.actions[sig]
+ ignored := act.Handler == arch.SignalActIgnore
+ if blocked || ignored || unconditional {
+ act.Handler = arch.SignalActDefault
+ t.tg.signalHandlers.actions[sig] = act
+ if blocked {
+ t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig))
+ }
+ }
+}
+
+// SignalMask returns a copy of t's signal mask.
+func (t *Task) SignalMask() linux.SignalSet {
+ return linux.SignalSet(atomic.LoadUint64((*uint64)(&t.signalMask)))
+}
+
+// SetSignalMask sets t's signal mask.
+//
+// Preconditions: SetSignalMask can only be called by the task goroutine.
+// t.exitState < TaskExitZombie.
+func (t *Task) SetSignalMask(mask linux.SignalSet) {
+ // By precondition, t prevents t.tg from completing an execve and mutating
+ // t.tg.signalHandlers, so we can skip the TaskSet mutex.
+ t.tg.signalHandlers.mu.Lock()
+ t.setSignalMaskLocked(mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// Preconditions: The signal mutex must be locked.
+func (t *Task) setSignalMaskLocked(mask linux.SignalSet) {
+ oldMask := t.signalMask
+ atomic.StoreUint64((*uint64)(&t.signalMask), uint64(mask))
+
+ // If the new mask blocks any signals that were not blocked by the old
+ // mask, and at least one such signal is pending in tg.pendingSignals, and
+ // t has been woken, it could be the case that t was woken to handle that
+ // signal, but will no longer do so as a result of its new signal mask, so
+ // we have to pick a replacement.
+ blocked := mask &^ oldMask
+ blockedGroupPending := blocked & t.tg.pendingSignals.pendingSet
+ if blockedGroupPending != 0 && t.interrupted() {
+ linux.ForEachSignal(blockedGroupPending, func(sig linux.Signal) {
+ if nt := t.tg.findSignalReceiverLocked(sig); nt != nil {
+ nt.interrupt()
+ return
+ }
+ })
+ // We have to re-issue the interrupt consumed by t.interrupted() since
+ // it might have been for a different reason.
+ t.interruptSelf()
+ }
+
+ // Conversely, if the new mask unblocks any signals that were blocked by
+ // the old mask, and at least one such signal is pending, we may now need
+ // to handle that signal.
+ unblocked := oldMask &^ mask
+ unblockedPending := unblocked & (t.pendingSignals.pendingSet | t.tg.pendingSignals.pendingSet)
+ if unblockedPending != 0 {
+ t.interruptSelf()
+ }
+}
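+
+// As a concrete instance of the mask arithmetic above: if the old mask
+// blocked only SIGUSR1 and the new mask blocks only SIGUSR2, then
+//
+// blocked := mask &^ oldMask   // == {SIGUSR2}
+// unblocked := oldMask &^ mask // == {SIGUSR1}
+//
+// so a pending SIGUSR1 now causes t.interruptSelf().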
+
+// SetSavedSignalMask sets the saved signal mask (see Task.savedSignalMask's
+// comment).
+//
+// Preconditions: SetSavedSignalMask can only be called by the task goroutine.
+func (t *Task) SetSavedSignalMask(mask linux.SignalSet) {
+ t.savedSignalMask = mask
+ t.haveSavedSignalMask = true
+}
+
+// SignalStack returns the task-private signal stack.
+func (t *Task) SignalStack() arch.SignalStack {
+ alt := t.signalStack
+ if t.onSignalStack(alt) {
+ alt.Flags |= arch.SignalStackFlagOnStack
+ }
+ return alt
+}
+
+// onSignalStack returns true if the task is executing on the given signal stack.
+func (t *Task) onSignalStack(alt arch.SignalStack) bool {
+ sp := usermem.Addr(t.Arch().Stack())
+ return alt.Contains(sp)
+}
+
+// SetSignalStack sets the task-private signal stack.
+//
+// This value may not be changed if the task is currently executing on the
+// signal stack, i.e. if t.onSignalStack returns true. In this case, this
+// function will return false. Otherwise, true is returned.
+func (t *Task) SetSignalStack(alt arch.SignalStack) bool {
+ // Check that we're not executing on the stack.
+ if t.onSignalStack(t.signalStack) {
+ return false
+ }
+
+ if alt.Flags&arch.SignalStackFlagDisable != 0 {
+ // Don't record anything beyond the flags.
+ t.signalStack = arch.SignalStack{
+ Flags: arch.SignalStackFlagDisable,
+ }
+ } else {
+ // Mask out irrelevant parts: only disable matters.
+ alt.Flags &= arch.SignalStackFlagDisable
+ t.signalStack = alt
+ }
+ return true
+}
+
+// SetSignalAct atomically sets the thread group's signal action for signal sig
+// to *actptr (if actptr is not nil) and returns the old signal action.
+func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) {
+ if !sig.IsValid() {
+ return arch.SignalAct{}, syserror.EINVAL
+ }
+
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ sh := tg.signalHandlers
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ oldact := sh.actions[sig]
+ if actptr != nil {
+ if sig == linux.SIGKILL || sig == linux.SIGSTOP {
+ return oldact, syserror.EINVAL
+ }
+
+ act := *actptr
+ act.Mask &^= UnblockableSignals
+ sh.actions[sig] = act
+ // From POSIX, by way of Linux:
+ //
+ // "Setting a signal action to SIG_IGN for a signal that is pending
+ // shall cause the pending signal to be discarded, whether or not it is
+ // blocked."
+ //
+ // "Setting a signal action to SIG_DFL for a signal that is pending and
+ // whose default action is to ignore the signal (for example, SIGCHLD),
+ // shall cause the pending signal to be discarded, whether or not it is
+ // blocked."
+ if computeAction(sig, act) == SignalActionIgnore {
+ tg.discardSpecificLocked(sig)
+ }
+ }
+ return oldact, nil
+}
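+
+// For example, ignoring SIGCHLD for the whole thread group (a sketch;
+// error handling elided):
+//
+// act := arch.SignalAct{Handler: arch.SignalActIgnore}
+// oldact, _ := tg.SetSignalAct(linux.SIGCHLD, &act)
+//
+// Per the POSIX rule quoted above, this also discards any pending SIGCHLD.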
+
+// CopyOutSignalAct converts the given SignalAct into an architecture-specific
+// type and then copies it out to task memory.
+func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
+ n := t.Arch().NewSignalAct()
+ n.SerializeFrom(s)
+ _, err := n.CopyOut(t, addr)
+ return err
+}
+
+// CopyInSignalAct copies an architecture-specific sigaction type from task
+// memory and then converts it into a SignalAct.
+func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
+ n := t.Arch().NewSignalAct()
+ var s arch.SignalAct
+ if _, err := n.CopyIn(t, addr); err != nil {
+ return s, err
+ }
+ n.DeserializeTo(&s)
+ return s, nil
+}
+
+// CopyOutSignalStack converts the given SignalStack into an
+// architecture-specific type and then copies it out to task memory.
+func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error {
+ n := t.Arch().NewSignalStack()
+ n.SerializeFrom(s)
+ _, err := n.CopyOut(t, addr)
+ return err
+}
+
+// CopyInSignalStack copies an architecture-specific stack_t from task memory
+// and then converts it into a SignalStack.
+func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) {
+ n := t.Arch().NewSignalStack()
+ var s arch.SignalStack
+ if _, err := n.CopyIn(t, addr); err != nil {
+ return s, err
+ }
+ n.DeserializeTo(&s)
+ return s, nil
+}
+
+// groupStop is a TaskStop placed on tasks that have received a stop signal
+// (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from
+// the ptrace man page.)
+//
+// +stateify savable
+type groupStop struct{}
+
+// Killable implements TaskStop.Killable.
+func (*groupStop) Killable() bool { return true }
+
+// initiateGroupStop attempts to initiate a group stop based on a
+// previously-dequeued stop signal.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) initiateGroupStop(info *arch.SignalInfo) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.groupStopPending {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing stop signal", info.Signo)
+ return
+ }
+ if !t.tg.groupStopDequeued {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing SIGCONT", info.Signo)
+ return
+ }
+ if t.tg.exiting {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing group exit", info.Signo)
+ return
+ }
+ if t.tg.execing != nil {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing execve", info.Signo)
+ return
+ }
+ if !t.tg.groupStopComplete {
+ t.tg.groupStopSignal = linux.Signal(info.Signo)
+ }
+ t.tg.groupStopPendingCount = 0
+ for t2 := t.tg.tasks.Front(); t2 != nil; t2 = t2.Next() {
+ if t2.killedLocked() || t2.exitState >= TaskExitInitiated {
+ t2.groupStopPending = false
+ continue
+ }
+ t2.groupStopPending = true
+ t2.groupStopAcknowledged = false
+ if t2.ptraceSeized {
+ t2.trapNotifyPending = true
+ if s, ok := t2.stop.(*ptraceStop); ok && s.listen {
+ t2.endInternalStopLocked()
+ }
+ }
+ t2.interrupt()
+ t.tg.groupStopPendingCount++
+ }
+ t.Debugf("Signal %d: stopping %d threads in thread group", info.Signo, t.tg.groupStopPendingCount)
+}
+
+// endGroupStopLocked ensures that all prior stop signals received by tg are
+// not stopping tg and will not stop tg in the future. If broadcast is true,
+// parent and tracer notification will be scheduled if appropriate.
+//
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) endGroupStopLocked(broadcast bool) {
+ // Discard all previously-queued stop signals.
+ linux.ForEachSignal(StopSignals, tg.discardSpecificLocked)
+
+ if tg.groupStopPendingCount == 0 && !tg.groupStopComplete {
+ return
+ }
+
+ completeStr := "incomplete"
+ if tg.groupStopComplete {
+ completeStr = "complete"
+ }
+ tg.leader.Debugf("Ending %s group stop with %d threads pending", completeStr, tg.groupStopPendingCount)
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ t.groupStopPending = false
+ if t.ptraceSeized {
+ t.trapNotifyPending = true
+ if s, ok := t.stop.(*ptraceStop); ok && s.listen {
+ t.endInternalStopLocked()
+ }
+ } else {
+ if _, ok := t.stop.(*groupStop); ok {
+ t.endInternalStopLocked()
+ }
+ }
+ }
+ if broadcast {
+ // Instead of notifying the parent here, set groupContNotify so that
+ // one of the continuing tasks does so. (Linux does something similar.)
+ // The reason we do this is to keep locking sane. In order to send a
+ // signal to the parent, we need to lock its signal mutex, but we're
+ // already holding tg's signal mutex, and the TaskSet mutex must be
+ // locked for writing for us to hold two signal mutexes. Since we don't
+ // want to require this for endGroupStopLocked (which is called from
+ // signal-sending paths), nor do we want to lose atomicity by releasing
+ // the mutexes we're already holding, just let the continuing thread
+ // group deal with it.
+ tg.groupContNotify = true
+ tg.groupContInterrupted = !tg.groupStopComplete
+ tg.groupContWaitable = true
+ }
+ // Unsetting groupStopDequeued will cause racing calls to initiateGroupStop
+ // to recognize that the group stop has been cancelled.
+ tg.groupStopDequeued = false
+ tg.groupStopSignal = 0
+ tg.groupStopPendingCount = 0
+ tg.groupStopComplete = false
+ tg.groupStopWaitable = false
+}
+
+// participateGroupStopLocked is called to handle thread group side effects
+// after t unsets t.groupStopPending. The caller must handle task side effects
+// (e.g. placing the task goroutine into the group stop). It returns true if
+// the caller must notify t.tg.leader's parent of a completed group stop (which
+// participateGroupStopLocked cannot do due to holding the wrong locks).
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) participateGroupStopLocked() bool {
+ if t.groupStopAcknowledged {
+ return false
+ }
+ t.groupStopAcknowledged = true
+ t.tg.groupStopPendingCount--
+ if t.tg.groupStopPendingCount != 0 {
+ return false
+ }
+ if t.tg.groupStopComplete {
+ return false
+ }
+ t.Debugf("Completing group stop")
+ t.tg.groupStopComplete = true
+ t.tg.groupStopWaitable = true
+ t.tg.groupContNotify = false
+ t.tg.groupContWaitable = false
+ return true
+}
+
+// signalStop sends a signal to t's thread group of a new group stop, group
+// continue, or ptrace stop, if appropriate. code and status are set in the
+// signal sent to tg, if any.
+//
+// Preconditions: The TaskSet mutex must be locked (for reading or writing).
+func (t *Task) signalStop(target *Task, code int32, status int32) {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD]
+ if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) {
+ sigchld := &arch.SignalInfo{
+ Signo: int32(linux.SIGCHLD),
+ Code: code,
+ }
+ sigchld.SetPid(int32(t.tg.pidns.tids[target]))
+ sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ sigchld.SetStatus(status)
+ // TODO(b/72102453): Set utime, stime.
+ t.sendSignalLocked(sigchld, true /* group */)
+ }
+}
+
+// The runInterrupt state handles conditions indicated by interrupts.
+//
+// +stateify savable
+type runInterrupt struct{}
+
+func (*runInterrupt) execute(t *Task) taskRunState {
+ // Interrupts are de-duplicated (if t is interrupted twice before
+ // t.interrupted() is called, t.interrupted() will only return true once),
+ // so early exits from this function must re-enter the runInterrupt state
+ // to check for more interrupt-signaled conditions.
+
+ t.tg.signalHandlers.mu.Lock()
+
+ // Did we just leave a group stop?
+ if t.tg.groupContNotify {
+ t.tg.groupContNotify = false
+ sig := t.tg.groupStopSignal
+ intr := t.tg.groupContInterrupted
+ t.tg.signalHandlers.mu.Unlock()
+ t.tg.pidns.owner.mu.RLock()
+ // For consistency with Linux, if the parent and (thread group
+ // leader's) tracer are in the same thread group, deduplicate
+ // notifications.
+ notifyParent := t.tg.leader.parent != nil
+ if tracer := t.tg.leader.Tracer(); tracer != nil {
+ if notifyParent && tracer.tg == t.tg.leader.parent.tg {
+ notifyParent = false
+ }
+ // Sending CLD_STOPPED to the tracer doesn't really make any sense;
+ // the thread group leader may have already entered the stop and
+ // notified its tracer accordingly. But it's consistent with
+ // Linux...
+ if intr {
+ tracer.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ if !notifyParent {
+ tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop | EventChildGroupStop)
+ } else {
+ tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop)
+ }
+ } else {
+ tracer.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ tracer.tg.eventQueue.Notify(EventGroupContinue)
+ }
+ }
+ if notifyParent {
+ // If groupContInterrupted, do as Linux does and pretend the group
+ // stop completed just before it ended. The theoretical behavior in
+ // this case would be to send a SIGCHLD indicating the completed
+ // stop, followed by a SIGCHLD indicating the continue. However,
+ // SIGCHLD is a standard signal, so the latter would always be
+ // dropped. Hence sending only the former is equivalent.
+ if intr {
+ t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue | EventChildGroupStop)
+ } else {
+ t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue)
+ }
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+ return (*runInterrupt)(nil)
+ }
+
+ // Do we need to enter a group stop or related ptrace stop? This path is
+ // analogous to Linux's kernel/signal.c:get_signal() => do_signal_stop()
+ // (with ptrace enabled) and do_jobctl_trap().
+ if t.groupStopPending || t.trapStopPending || t.trapNotifyPending {
+ sig := t.tg.groupStopSignal
+ notifyParent := false
+ if t.groupStopPending {
+ t.groupStopPending = false
+ // We care about t.tg.groupStopSignal (for tracer notification)
+ // even if this doesn't complete a group stop, so keep the
+ // value of sig we've already read.
+ notifyParent = t.participateGroupStopLocked()
+ }
+ t.trapStopPending = false
+ t.trapNotifyPending = false
+ // Drop the signal mutex so we can take the TaskSet mutex.
+ t.tg.signalHandlers.mu.Unlock()
+
+ t.tg.pidns.owner.mu.RLock()
+ if t.tg.leader.parent == nil {
+ notifyParent = false
+ }
+ if tracer := t.Tracer(); tracer != nil {
+ if t.ptraceSeized {
+ if sig == 0 {
+ sig = linux.SIGTRAP
+ }
+ // "If tracee was attached using PTRACE_SEIZE, group-stop is
+ // indicated by PTRACE_EVENT_STOP: status>>16 ==
+ // PTRACE_EVENT_STOP. This allows detection of group-stops
+ // without requiring an extra PTRACE_GETSIGINFO call." -
+ // "Group-stop", ptrace(2)
+ t.ptraceCode = int32(sig) | linux.PTRACE_EVENT_STOP<<8
+ t.ptraceSiginfo = &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: t.ptraceCode,
+ }
+ t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ } else {
+ t.ptraceCode = int32(sig)
+ t.ptraceSiginfo = nil
+ }
+ if t.beginPtraceStopLocked() {
+ tracer.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ // For consistency with Linux, if the parent and tracer are in the
+ // same thread group, deduplicate notification signals.
+ if notifyParent && tracer.tg == t.tg.leader.parent.tg {
+ notifyParent = false
+ tracer.tg.eventQueue.Notify(EventChildGroupStop | EventTraceeStop)
+ } else {
+ tracer.tg.eventQueue.Notify(EventTraceeStop)
+ }
+ }
+ } else {
+ t.tg.signalHandlers.mu.Lock()
+ if !t.killedLocked() {
+ t.beginInternalStopLocked((*groupStop)(nil))
+ }
+ t.tg.signalHandlers.mu.Unlock()
+ }
+ if notifyParent {
+ t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+
+ return (*runInterrupt)(nil)
+ }
+
+ // Are there signals pending?
+ if info := t.dequeueSignalLocked(t.signalMask); info != nil {
+ if linux.SignalSetOf(linux.Signal(info.Signo))&StopSignals != 0 {
+ // Indicate that we've dequeued a stop signal before unlocking the
+ // signal mutex; initiateGroupStop will check for races with
+ // endGroupStopLocked after relocking it.
+ t.tg.groupStopDequeued = true
+ }
+ if t.ptraceSignalLocked(info) {
+ // Dequeueing the signal action must wait until after the
+ // signal-delivery-stop ends since the tracer can change or
+ // suppress the signal.
+ t.tg.signalHandlers.mu.Unlock()
+ return (*runInterruptAfterSignalDeliveryStop)(nil)
+ }
+ act := t.tg.signalHandlers.dequeueAction(linux.Signal(info.Signo))
+ t.tg.signalHandlers.mu.Unlock()
+ return t.deliverSignal(info, act)
+ }
+
+ t.tg.signalHandlers.mu.Unlock()
+ return (*runApp)(nil)
+}
+
+// +stateify savable
+type runInterruptAfterSignalDeliveryStop struct{}
+
+func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
+ t.tg.pidns.owner.mu.Lock()
+ // Can't defer unlock: deliverSignal must be called without holding TaskSet
+ // mutex.
+ sig := linux.Signal(t.ptraceCode)
+ defer func() {
+ t.ptraceSiginfo = nil
+ }()
+ if !sig.IsValid() {
+ t.tg.pidns.owner.mu.Unlock()
+ return (*runInterrupt)(nil)
+ }
+ info := t.ptraceSiginfo
+ if sig != linux.Signal(info.Signo) {
+ info.Signo = int32(sig)
+ info.Errno = 0
+ info.Code = arch.SignalInfoUser
+ // pid isn't a valid field for all signal numbers, but Linux
+ // doesn't care (kernel/signal.c:ptrace_signal()).
+ //
+ // Linux uses t->parent for the tid and uid here, which is the tracer
+ // if it hasn't detached or the real parent otherwise.
+ parent := t.parent
+ if tracer := t.Tracer(); tracer != nil {
+ parent = tracer
+ }
+ if parent == nil {
+ // Tracer has detached and t was created by Kernel.CreateProcess().
+ // Pretend the parent is in an ancestor PID + user namespace.
+ info.SetPid(0)
+ info.SetUid(int32(auth.OverflowUID))
+ } else {
+ info.SetPid(int32(t.tg.pidns.tids[parent]))
+ info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ }
+ }
+ t.tg.signalHandlers.mu.Lock()
+ t.tg.pidns.owner.mu.Unlock()
+ // If the signal is masked, re-queue it.
+ if linux.SignalSetOf(sig)&t.signalMask != 0 {
+ t.sendSignalLocked(info, false /* group */)
+ t.tg.signalHandlers.mu.Unlock()
+ return (*runInterrupt)(nil)
+ }
+ act := t.tg.signalHandlers.dequeueAction(linux.Signal(info.Signo))
+ t.tg.signalHandlers.mu.Unlock()
+ return t.deliverSignal(info, act)
+}
+
+// SignalRegister registers a waiter for pending signals.
+func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventRegister(e, mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// SignalUnregister unregisters a waiter for pending signals.
+func (t *Task) SignalUnregister(e *waiter.Entry) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventUnregister(e)
+ t.tg.signalHandlers.mu.Unlock()
+}
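+
+// A sketch of blocking until a signal is queued, assuming the channel-based
+// entry constructor from the waiter package:
+//
+// e, ch := waiter.NewChannelEntry(nil)
+// t.SignalRegister(&e, waiter.EventMask(linux.MakeSignalSet(linux.SIGCHLD)))
+// defer t.SignalUnregister(&e)
+// <-ch // notified by canReceiveSignalLocked via t.signalQueue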
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
new file mode 100644
index 000000000..8485fb4b6
--- /dev/null
+++ b/pkg/sentry/kernel/task_start.go
@@ -0,0 +1,319 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TaskConfig defines the configuration of a new Task (see below).
+type TaskConfig struct {
+ // Kernel is the owning Kernel.
+ Kernel *Kernel
+
+ // Parent is the new task's parent. Parent may be nil.
+ Parent *Task
+
+ // If InheritParent is not nil, use InheritParent's parent as the new
+ // task's parent.
+ InheritParent *Task
+
+ // ThreadGroup is the ThreadGroup the new task belongs to.
+ ThreadGroup *ThreadGroup
+
+ // SignalMask is the new task's initial signal mask.
+ SignalMask linux.SignalSet
+
+ // TaskContext is the TaskContext of the new task. Ownership of the
+ // TaskContext is transferred to TaskSet.NewTask, whether or not it
+ // succeeds.
+ TaskContext *TaskContext
+
+ // FSContext is the FSContext of the new task. A reference must be held on
+ // FSContext, which is transferred to TaskSet.NewTask whether or not it
+ // succeeds.
+ FSContext *FSContext
+
+ // FDTable is the FDTable of the new task. A reference must be held on
+ // FDTable, which is transferred to TaskSet.NewTask whether or not it
+ // succeeds.
+ FDTable *FDTable
+
+ // Credentials is the Credentials of the new task.
+ Credentials *auth.Credentials
+
+ // Niceness is the niceness of the new task.
+ Niceness int
+
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
+
+ // AllowedCPUMask contains the cpus that this task can run on.
+ AllowedCPUMask sched.CPUSet
+
+ // UTSNamespace is the UTSNamespace of the new task.
+ UTSNamespace *UTSNamespace
+
+ // IPCNamespace is the IPCNamespace of the new task.
+ IPCNamespace *IPCNamespace
+
+ // AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
+ AbstractSocketNamespace *AbstractSocketNamespace
+
+ // MountNamespaceVFS2 is the MountNamespace of the new task.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
+ // RSeqAddr is a pointer to the userspace linux.RSeq structure.
+ RSeqAddr usermem.Addr
+
+ // RSeqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ RSeqSignature uint32
+
+ // ContainerID is the container the new task belongs to.
+ ContainerID string
+}
+
+// NewTask creates a new task defined by cfg.
+//
+// NewTask does not start the returned task; the caller must call Task.Start.
+func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) {
+ t, err := ts.newTask(cfg)
+ if err != nil {
+ cfg.TaskContext.release()
+ cfg.FSContext.DecRef()
+ cfg.FDTable.DecRef()
+ if cfg.MountNamespaceVFS2 != nil {
+ cfg.MountNamespaceVFS2.DecRef()
+ }
+ return nil, err
+ }
+ return t, nil
+}
+
+// newTask is a helper for TaskSet.NewTask that only takes ownership of parts
+// of cfg if it succeeds.
+func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
+ tg := cfg.ThreadGroup
+ tc := cfg.TaskContext
+ t := &Task{
+ taskNode: taskNode{
+ tg: tg,
+ parent: cfg.Parent,
+ children: make(map[*Task]struct{}),
+ },
+ runState: (*runApp)(nil),
+ interruptChan: make(chan struct{}, 1),
+ signalMask: cfg.SignalMask,
+ signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ tc: *tc,
+ fsContext: cfg.FSContext,
+ fdTable: cfg.FDTable,
+ p: cfg.Kernel.Platform.NewContext(),
+ k: cfg.Kernel,
+ ptraceTracees: make(map[*Task]struct{}),
+ allowedCPUMask: cfg.AllowedCPUMask.Copy(),
+ ioUsage: &usage.IO{},
+ niceness: cfg.Niceness,
+ netns: cfg.NetworkNamespace,
+ utsns: cfg.UTSNamespace,
+ ipcns: cfg.IPCNamespace,
+ abstractSockets: cfg.AbstractSocketNamespace,
+ mountNamespaceVFS2: cfg.MountNamespaceVFS2,
+ rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
+ futexWaiter: futex.NewWaiter(),
+ containerID: cfg.ContainerID,
+ }
+ t.creds.Store(cfg.Credentials)
+ t.endStopCond.L = &t.tg.signalHandlers.mu
+ t.ptraceTracer.Store((*Task)(nil))
+ // We don't construct t.blockingTimer until Task.run(); see that function
+ // for justification.
+
+ // Make the new task (and possibly thread group) visible to the rest of
+ // the system atomically.
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ if tg.exiting || tg.execing != nil {
+ // If the caller is in the same thread group, then what we return
+ // doesn't matter too much since the caller will exit before it returns
+ // to userspace. If the caller isn't in the same thread group, then
+ // we're in uncharted territory and can return whatever we want.
+ return nil, syserror.EINTR
+ }
+ if err := ts.assignTIDsLocked(t); err != nil {
+ return nil, err
+ }
+ // Below this point, newTask is expected not to fail (there is no rollback
+ // of assignTIDsLocked or any of the following).
+
+ // Logging on t's behalf will panic if t.logPrefix hasn't been
+ // initialized. This is the earliest point at which we can do so
+ // (since t now has thread IDs).
+ t.updateInfoLocked()
+
+ if cfg.InheritParent != nil {
+ t.parent = cfg.InheritParent.parent
+ }
+ if t.parent != nil {
+ t.parent.children[t] = struct{}{}
+ }
+
+ if tg.leader == nil {
+ // New thread group.
+ tg.leader = t
+ if parentPG := tg.parentPG(); parentPG == nil {
+ tg.createSession()
+ } else {
+ // Inherit the process group and terminal.
+ parentPG.incRefWithParent(parentPG)
+ tg.processGroup = parentPG
+ tg.tty = t.parent.tg.tty
+ }
+ }
+ tg.tasks.PushBack(t)
+ tg.tasksCount++
+ tg.liveTasks++
+ tg.activeTasks++
+
+ // Propagate external TaskSet stops to the new task.
+ t.stopCount = ts.stopCount
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.cpu = assignCPU(t.allowedCPUMask, ts.Root.tids[t])
+
+ t.startTime = t.k.RealtimeClock().Now()
+
+ return t, nil
+}
+
+// assignTIDsLocked ensures that new task t is visible in all PID namespaces in
+// which it should be visible.
+//
+// Preconditions: ts.mu must be locked for writing.
+func (ts *TaskSet) assignTIDsLocked(t *Task) error {
+ type allocatedTID struct {
+ ns *PIDNamespace
+ tid ThreadID
+ }
+ var allocatedTIDs []allocatedTID
+ for ns := t.tg.pidns; ns != nil; ns = ns.parent {
+ tid, err := ns.allocateTID()
+ if err != nil {
+ // Failure. Remove the tids we already allocated in descendant
+ // namespaces.
+ for _, a := range allocatedTIDs {
+ delete(a.ns.tasks, a.tid)
+ delete(a.ns.tids, t)
+ if t.tg.leader == nil {
+ delete(a.ns.tgids, t.tg)
+ }
+ }
+ return err
+ }
+ ns.tasks[tid] = t
+ ns.tids[t] = tid
+ if t.tg.leader == nil {
+ // New thread group.
+ ns.tgids[t.tg] = tid
+ }
+ allocatedTIDs = append(allocatedTIDs, allocatedTID{ns, tid})
+ }
+ return nil
+}
+
+// allocateTID returns an unused ThreadID from ns.
+//
+// Preconditions: ns.owner.mu must be locked for writing.
+func (ns *PIDNamespace) allocateTID() (ThreadID, error) {
+ if ns.exiting {
+ // "In this case, a subsequent fork(2) into this PID namespace will
+ // fail with the error ENOMEM; it is not possible to create a new
+ // processes [sic] in a PID namespace whose init process has
+ // terminated." - pid_namespaces(7)
+ return 0, syserror.ENOMEM
+ }
+ tid := ns.last
+ for {
+ // Next.
+ tid++
+ if tid > TasksLimit {
+ tid = InitTID + 1
+ }
+
+ // Is it available?
+ tidInUse := func() bool {
+ if _, ok := ns.tasks[tid]; ok {
+ return true
+ }
+ if _, ok := ns.processGroups[ProcessGroupID(tid)]; ok {
+ return true
+ }
+ if _, ok := ns.sessions[SessionID(tid)]; ok {
+ return true
+ }
+ return false
+ }()
+
+ if !tidInUse {
+ ns.last = tid
+ return tid, nil
+ }
+
+ // Did we do a full cycle?
+ if tid == ns.last {
+ // No tid available.
+ return 0, syserror.EAGAIN
+ }
+ }
+}
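+
+// For example, if ns.last sits just below TasksLimit, the search tries the
+// remaining high IDs, wraps around to InitTID+1, and walks upward; only if
+// it arrives back at ns.last without finding an ID that is free as a task
+// ID, process group ID, and session ID does it fail with EAGAIN.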
+
+// Start starts the task goroutine. Start must be called exactly once for each
+// task returned by NewTask.
+//
+// 'tid' must be the task's TID in the root PID namespace; it is used for
+// debugging purposes only (passed as a parameter to Task.run so that it is
+// visible in stack dumps).
+func (t *Task) Start(tid ThreadID) {
+ // If the task was restored, it may be "starting" after having already exited.
+ if t.runState == nil {
+ return
+ }
+ t.goroutineStopped.Add(1)
+ t.tg.liveGoroutines.Add(1)
+ t.tg.pidns.owner.liveGoroutines.Add(1)
+ t.tg.pidns.owner.runningGoroutines.Add(1)
+
+ // Task is now running in system mode.
+ t.accountTaskGoroutineLeave(TaskGoroutineNonexistent)
+
+ // Use the task's TID in the root PID namespace to make it visible in stack dumps.
+ go t.run(uintptr(tid)) // S/R-SAFE: synchronizes with saving through stops
+}
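+
+// A sketch of the expected creation sequence (cfg is assembled by the
+// caller):
+//
+// t, err := ts.NewTask(cfg)
+// if err != nil {
+//     return err
+// }
+// t.Start(ts.Root.IDOfTask(t))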
diff --git a/pkg/sentry/kernel/task_stop.go b/pkg/sentry/kernel/task_stop.go
new file mode 100644
index 000000000..10c6e455c
--- /dev/null
+++ b/pkg/sentry/kernel/task_stop.go
@@ -0,0 +1,226 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file implements task stops, which represent the equivalent of Linux's
+// uninterruptible sleep states in a way that is compatible with save/restore.
+// Task stops comprise both internal stops (which form part of the task's
+// "normal" control flow) and external stops (which do not); see README.md for
+// details.
+//
+// There are multiple interfaces for interacting with stops because there are
+// multiple cases to consider:
+//
+// - A task goroutine can begin a stop on its associated task (e.g. a
+// vfork() syscall stopping the calling task until the child task releases its
+// MM). In this case, calling Task.interrupt is both unnecessary (the task
+// goroutine obviously cannot be blocked in Task.block or executing application
+// code) and undesirable (as it may spuriously interrupt an in-progress
+// syscall).
+//
+// Beginning internal stops in this case is implemented by
+// Task.beginInternalStop / Task.beginInternalStopLocked. As of this writing,
+// there are no instances of this case that begin external stops, except for
+// autosave; however, autosave terminates the sentry without ending the
+// external stop, so the spurious interrupt is moot.
+//
+// - An arbitrary goroutine can begin a stop on an unrelated task (e.g. all
+// tasks being stopped in preparation for state checkpointing). If the task
+// goroutine may be in Task.block or executing application code, it must be
+// interrupted by Task.interrupt for it to actually enter the stop; since,
+// strictly speaking, we have no way of determining this, we call
+// Task.interrupt unconditionally.
+//
+// Beginning external stops in this case is implemented by
+// Task.BeginExternalStop. As of this writing, there are no instances of this
+// case that begin internal stops.
+//
+// - An arbitrary goroutine can end a stop on an unrelated task (e.g. an
+// exiting task resuming a sibling task that has been blocked in an execve()
+// syscall waiting for other tasks to exit). In this case, Task.endStopCond
+// must be notified to kick the task goroutine out of Task.doStop.
+//
+// Ending internal stops in this case is implemented by
+// Task.endInternalStopLocked. Ending external stops in this case is
+// implemented by Task.EndExternalStop.
+//
+// - Hypothetically, a task goroutine can end an internal stop on its
+// associated task. As of this writing, there are no instances of this case.
+// However, any instances of this case could still use the above functions,
+// since notifying Task.endStopCond would be unnecessary but harmless.
+
+import (
+ "fmt"
+ "sync/atomic"
+)
+
+// A TaskStop is a condition visible to the task control flow graph that
+// prevents a task goroutine from running or exiting, i.e. an internal stop.
+//
+// NOTE(b/30793614): Most TaskStops don't contain any data; they're
+// distinguished by their type. The obvious way to implement such a TaskStop
+// is:
+//
+// type groupStop struct{}
+// func (groupStop) Killable() bool { return true }
+// ...
+// t.beginInternalStop(groupStop{})
+//
+// However, this doesn't work because the state package can't serialize values,
+// only pointers. Furthermore, the correctness of save/restore depends on the
+// ability to pass a TaskStop to endInternalStop that will compare equal to the
+// TaskStop that was passed to beginInternalStop, even if a save/restore cycle
+// occurred between the two. As a result, the current idiom is to always use a
+// typecast nil for data-free TaskStops:
+//
+// type groupStop struct{}
+// func (*groupStop) Killable() bool { return true }
+// ...
+// t.beginInternalStop((*groupStop)(nil))
+//
+// This is pretty gross, but the alternatives seem grosser.
+type TaskStop interface {
+ // Killable returns true if Task.Kill should end the stop prematurely.
+ // Killable is analogous to Linux's TASK_WAKEKILL.
+ Killable() bool
+}
+
+// beginInternalStop indicates the start of an internal stop that applies to t.
+//
+// Preconditions: The task must not already be in an internal stop (i.e. t.stop
+// == nil). The caller must be running on the task goroutine.
+func (t *Task) beginInternalStop(s TaskStop) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.beginInternalStopLocked(s)
+}
+
+// Preconditions: The signal mutex must be locked. All preconditions for
+// Task.beginInternalStop also apply.
+func (t *Task) beginInternalStopLocked(s TaskStop) {
+ if t.stop != nil {
+ panic(fmt.Sprintf("Attempting to enter internal stop %#v when already in internal stop %#v", s, t.stop))
+ }
+ t.Debugf("Entering internal stop %#v", s)
+ t.stop = s
+ t.beginStopLocked()
+}
+
+// endInternalStopLocked indicates the end of an internal stop that applies to
+// t. endInternalStopLocked does not wait for the task to resume.
+//
+// The caller is responsible for ensuring that the internal stop they expect
+// actually applies to t; this requires holding the signal mutex which protects
+// t.stop, which is why there is no endInternalStop that locks the signal mutex
+// for you.
+//
+// Preconditions: The signal mutex must be locked. The task must be in an
+// internal stop (i.e. t.stop != nil).
+func (t *Task) endInternalStopLocked() {
+ if t.stop == nil {
+ panic("Attempting to leave non-existent internal stop")
+ }
+ t.Debugf("Leaving internal stop %#v", t.stop)
+ t.stop = nil
+ t.endStopLocked()
+}
+
+// BeginExternalStop indicates the start of an external stop that applies to t.
+// BeginExternalStop does not wait for t's task goroutine to stop.
+func (t *Task) BeginExternalStop() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.beginStopLocked()
+ t.interrupt()
+}
+
+// EndExternalStop indicates the end of an external stop started by a previous
+// call to Task.BeginExternalStop. EndExternalStop does not wait for t's task
+// goroutine to resume.
+func (t *Task) EndExternalStop() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.endStopLocked()
+}
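+
+// As a hedged usage sketch (the real callers live elsewhere in the kernel
+// package), an external pause of a single task simply pairs the two calls.
+// Note that BeginExternalStop only requests the stop; a caller that needs the
+// task goroutine to have actually stopped must synchronize with it
+// separately.
+//
+//     t.BeginExternalStop()
+//     // ... t's task goroutine stops at its next stop check ...
+//     t.EndExternalStop()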
+
+// beginStopLocked increments t.stopCount to indicate that a new internal or
+// external stop applies to t.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) beginStopLocked() {
+ if newval := atomic.AddInt32(&t.stopCount, 1); newval <= 0 {
+ // Most likely overflow.
+ panic(fmt.Sprintf("Invalid stopCount: %d", newval))
+ }
+}
+
+// endStopLocked decrements t.stopCount to indicate that an existing internal
+// or external stop no longer applies to t.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) endStopLocked() {
+ if newval := atomic.AddInt32(&t.stopCount, -1); newval < 0 {
+ panic(fmt.Sprintf("Invalid stopCount: %d", newval))
+ } else if newval == 0 {
+ t.endStopCond.Signal()
+ }
+}
+
+// BeginExternalStop indicates the start of an external stop that applies to
+// all current and future tasks in ts. BeginExternalStop does not wait for
+// task goroutines to stop.
+func (ts *TaskSet) BeginExternalStop() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.stopCount++
+ if ts.stopCount <= 0 {
+ panic(fmt.Sprintf("Invalid stopCount: %d", ts.stopCount))
+ }
+ if ts.Root == nil {
+ return
+ }
+ for t := range ts.Root.tids {
+ t.tg.signalHandlers.mu.Lock()
+ t.beginStopLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ t.interrupt()
+ }
+}
+
+// EndExternalStop indicates the end of an external stop started by a previous
+// call to TaskSet.BeginExternalStop. EndExternalStop does not wait for task
+// goroutines to resume.
+func (ts *TaskSet) EndExternalStop() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.stopCount--
+ if ts.stopCount < 0 {
+ panic(fmt.Sprintf("Invalid stopCount: %d", ts.stopCount))
+ }
+ if ts.Root == nil {
+ return
+ }
+ for t := range ts.Root.tids {
+ t.tg.signalHandlers.mu.Lock()
+ t.endStopLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ }
+}
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
new file mode 100644
index 000000000..a5903b0b5
--- /dev/null
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -0,0 +1,469 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "os"
+ "runtime/trace"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SyscallRestartErrno represents an ERESTART* errno defined in the Linux
+// kernel's include/linux/errno.h. These errnos are never returned to
+// userspace directly, but are used to communicate the expected behavior of an
+// interrupted syscall from the syscall code to signal handling.
+type SyscallRestartErrno int
+
+// These numeric values are significant because ptrace syscall exit tracing can
+// observe them.
+//
+// For all of the following errnos, if the syscall is not interrupted by a
+// signal delivered to a user handler, the syscall is restarted.
+const (
+ // ERESTARTSYS is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler without SA_RESTART set, and restarted otherwise.
+ ERESTARTSYS = SyscallRestartErrno(512)
+
+ // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it
+ // should always be restarted.
+ ERESTARTNOINTR = SyscallRestartErrno(513)
+
+ // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler, and restarted otherwise.
+ ERESTARTNOHAND = SyscallRestartErrno(514)
+
+ // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate
+ // that it should be restarted using a custom function. The interrupted
+ // syscall must register a custom restart function by calling
+ // Task.SetRestartSyscallFn.
+ ERESTART_RESTARTBLOCK = SyscallRestartErrno(516)
+)
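+
+// As an illustrative, hypothetical sketch (not a quote of any syscall in this
+// package), a blocking syscall implementation typically propagates an
+// interruption as one of these errnos and lets the signal path decide whether
+// to restart:
+//
+//     if err := t.Block(ch); err == syserror.ErrInterrupted {
+//         return 0, nil, ERESTARTSYS
+//     }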
+
+var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application")
+
+// Error implements error.Error.
+func (e SyscallRestartErrno) Error() string {
+ // Descriptions are borrowed from strace.
+ switch e {
+ case ERESTARTSYS:
+ return "to be restarted if SA_RESTART is set"
+ case ERESTARTNOINTR:
+ return "to be restarted"
+ case ERESTARTNOHAND:
+ return "to be restarted if no handler"
+ case ERESTART_RESTARTBLOCK:
+ return "interrupted by signal"
+ default:
+ return "(unknown interrupt error)"
+ }
+}
+
+// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by
+// rv, the value in a syscall return register.
+func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) {
+ switch int(rv) {
+ case -int(ERESTARTSYS):
+ return ERESTARTSYS, true
+ case -int(ERESTARTNOINTR):
+ return ERESTARTNOINTR, true
+ case -int(ERESTARTNOHAND):
+ return ERESTARTNOHAND, true
+ case -int(ERESTART_RESTARTBLOCK):
+ return ERESTART_RESTARTBLOCK, true
+ default:
+ return 0, false
+ }
+}
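+
+// For example (a hypothetical sketch), signal delivery code can recover the
+// restart disposition from the return register:
+//
+//     if errno, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
+//         // Restart or convert to EINTR based on errno and whether the
+//         // handler was installed with SA_RESTART.
+//     }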
+
+// SyscallRestartBlock represents the restart block for a syscall restartable
+// with a custom function. It encapsulates the state required to restart a
+// syscall across a save/restore (S/R) cycle.
+type SyscallRestartBlock interface {
+ Restart(t *Task) (uintptr, error)
+}
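+
+// A syscall needing custom restart logic registers a block and returns
+// ERESTART_RESTARTBLOCK (a hypothetical sketch; the type and the sleepUntil
+// helper are illustrative only):
+//
+//     type sleepRestartBlock struct {
+//         end ktime.Time
+//     }
+//
+//     func (r *sleepRestartBlock) Restart(t *Task) (uintptr, error) {
+//         return 0, sleepUntil(t, r.end) // hypothetical helper
+//     }
+//
+//     t.SetRestartSyscallFn(&sleepRestartBlock{end: end})
+//     return 0, nil, ERESTART_RESTARTBLOCK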
+
+// SyscallControl is returned by syscalls to control the behavior of
+// Task.doSyscallInvoke.
+type SyscallControl struct {
+ // next is the state that the task goroutine should switch to. If next is
+ // nil, the task goroutine should continue to syscall exit as usual.
+ next taskRunState
+
+ // If ignoreReturn is true, Task.doSyscallInvoke should not store any value
+ // in the task's syscall return value register.
+ ignoreReturn bool
+}
+
+var (
+ // CtrlDoExit is returned by the implementations of the exit and exit_group
+ // syscalls to enter the task exit path directly, skipping syscall exit
+ // tracing.
+ CtrlDoExit = &SyscallControl{next: (*runExit)(nil), ignoreReturn: true}
+
+ // ctrlStopAndReinvokeSyscall is returned by syscalls using the external
+ // feature before syscall execution. This causes Task.doSyscallInvoke
+ // to return runSyscallReinvoke, allowing Task.run to check for stops
+ // before immediately re-invoking the syscall (skipping the re-checking
+ // of seccomp filters and ptrace which would confuse userspace
+ // tracing).
+ ctrlStopAndReinvokeSyscall = &SyscallControl{next: (*runSyscallReinvoke)(nil), ignoreReturn: true}
+
+ // ctrlStopBeforeSyscallExit is returned by syscalls that initiate a stop at
+ // their end. This causes Task.doSyscallInvoke to return runSyscallExit, rather
+ // than tail-calling it, allowing stops to be checked before syscall exit.
+ ctrlStopBeforeSyscallExit = &SyscallControl{next: (*runSyscallExit)(nil)}
+)
+
+func (t *Task) invokeExternal() {
+ t.BeginExternalStop()
+ go func() { // S/R-SAFE: External control flow.
+ defer t.EndExternalStop()
+ t.SyscallTable().External(t.Kernel())
+ }()
+}
+
+func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval uintptr, ctrl *SyscallControl, err error) {
+ s := t.SyscallTable()
+
+ fe := s.FeatureEnable.Word(sysno)
+
+ var straceContext interface{}
+ if bits.IsAnyOn32(fe, StraceEnableBits) {
+ straceContext = s.Stracer.SyscallEnter(t, sysno, args, fe)
+ }
+
+ if bits.IsOn32(fe, ExternalBeforeEnable) && (s.ExternalFilterBefore == nil || s.ExternalFilterBefore(t, sysno, args)) {
+ t.invokeExternal()
+ // Ensure we check for stops, then invoke the syscall again.
+ ctrl = ctrlStopAndReinvokeSyscall
+ } else {
+ fn := s.Lookup(sysno)
+ var region *trace.Region // Only non-nil if tracing == true.
+ if trace.IsEnabled() {
+ region = trace.StartRegion(t.traceContext, s.LookupName(sysno))
+ }
+ if fn != nil {
+ // Call our syscall implementation.
+ rval, ctrl, err = fn(t, args)
+ } else {
+ // Use the missing function if not found.
+ rval, err = t.SyscallTable().Missing(t, sysno, args)
+ }
+ if region != nil {
+ region.End()
+ }
+ }
+
+ if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
+ t.invokeExternal()
+ // Don't reinvoke the syscall.
+ }
+
+ if bits.IsAnyOn32(fe, StraceEnableBits) {
+ s.Stracer.SyscallExit(straceContext, t, sysno, rval, err)
+ }
+
+ return
+}
+
+// doSyscall is the entry point for an invocation of a system call specified by
+// the current state of t's registers.
+//
+// The syscall path is very hot; avoid defer.
+func (t *Task) doSyscall() taskRunState {
+	// Save the value of the register that is clobbered by the
+	// t.Arch().SetReturn(-ENOSYS) operation below. This is only needed on
+	// arm64.
+	//
+	// On x86, the rax register is shared by the syscall number and the
+	// return value; on entry to the syscall handler, rax is saved to
+	// regs.orig_rax, which is exposed to userspace. On arm64, however, the
+	// syscall number is passed in X8, and X0 is shared by the first syscall
+	// argument and the return value. X0 is saved to regs.orig_x0, which is
+	// not exposed to userspace, so we perform the equivalent save of the X0
+	// value into the task context here.
+ t.Arch().SyscallSaveOrig()
+
+ sysno := t.Arch().SyscallNo()
+ args := t.Arch().SyscallArgs()
+
+ // Tracers expect to see this between when the task traps into the kernel
+ // to perform a syscall and when the syscall is actually invoked.
+	// This useless-looking temporary is needed because Go rejects the
+	// constant expression -uintptr(syscall.ENOSYS) (a negative constant
+	// cannot be represented as a uintptr), while negating a variable wraps.
+ tmp := uintptr(syscall.ENOSYS)
+ t.Arch().SetReturn(-tmp)
+
+	// Check seccomp filters. The nil check is for performance (seccomp use
+	// is rare); it is not needed for correctness.
+ if t.syscallFilters.Load() != nil {
+ switch r := t.checkSeccompSyscall(int32(sysno), args, usermem.Addr(t.Arch().IP())); r {
+ case linux.SECCOMP_RET_ERRNO, linux.SECCOMP_RET_TRAP:
+ t.Debugf("Syscall %d: denied by seccomp", sysno)
+ return (*runSyscallExit)(nil)
+ case linux.SECCOMP_RET_ALLOW:
+ // ok
+ case linux.SECCOMP_RET_KILL_THREAD:
+ t.Debugf("Syscall %d: killed by seccomp", sysno)
+ t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
+ return (*runExit)(nil)
+ case linux.SECCOMP_RET_TRACE:
+ t.Debugf("Syscall %d: stopping for PTRACE_EVENT_SECCOMP", sysno)
+ return (*runSyscallAfterPtraceEventSeccomp)(nil)
+ default:
+ panic(fmt.Sprintf("Unknown seccomp result %d", r))
+ }
+ }
+
+ return t.doSyscallEnter(sysno, args)
+}
+
+type runSyscallAfterPtraceEventSeccomp struct{}
+
+func (*runSyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
+ if t.killed() {
+ // "[S]yscall-exit-stop is not generated prior to death by SIGKILL." -
+ // ptrace(2)
+ return (*runInterrupt)(nil)
+ }
+ sysno := t.Arch().SyscallNo()
+ // "The tracer can skip the system call by changing the syscall number to
+ // -1." - Documentation/prctl/seccomp_filter.txt
+ if sysno == ^uintptr(0) {
+ return (*runSyscallExit)(nil).execute(t)
+ }
+ args := t.Arch().SyscallArgs()
+ return t.doSyscallEnter(sysno, args)
+}
+
+func (t *Task) doSyscallEnter(sysno uintptr, args arch.SyscallArguments) taskRunState {
+ if next, ok := t.ptraceSyscallEnter(); ok {
+ return next
+ }
+ return t.doSyscallInvoke(sysno, args)
+}
+
+// +stateify savable
+type runSyscallAfterSyscallEnterStop struct{}
+
+func (*runSyscallAfterSyscallEnterStop) execute(t *Task) taskRunState {
+ if sig := linux.Signal(t.ptraceCode); sig.IsValid() {
+ t.tg.signalHandlers.mu.Lock()
+ t.sendSignalLocked(SignalInfoPriv(sig), false /* group */)
+ t.tg.signalHandlers.mu.Unlock()
+ }
+ if t.killed() {
+ return (*runInterrupt)(nil)
+ }
+ sysno := t.Arch().SyscallNo()
+ if sysno == ^uintptr(0) {
+ return (*runSyscallExit)(nil)
+ }
+ args := t.Arch().SyscallArgs()
+
+ return t.doSyscallInvoke(sysno, args)
+}
+
+// +stateify savable
+type runSyscallAfterSysemuStop struct{}
+
+func (*runSyscallAfterSysemuStop) execute(t *Task) taskRunState {
+ if sig := linux.Signal(t.ptraceCode); sig.IsValid() {
+ t.tg.signalHandlers.mu.Lock()
+ t.sendSignalLocked(SignalInfoPriv(sig), false /* group */)
+ t.tg.signalHandlers.mu.Unlock()
+ }
+ if t.killed() {
+ return (*runInterrupt)(nil)
+ }
+ return (*runSyscallExit)(nil).execute(t)
+}
+
+func (t *Task) doSyscallInvoke(sysno uintptr, args arch.SyscallArguments) taskRunState {
+ rval, ctrl, err := t.executeSyscall(sysno, args)
+
+ if ctrl != nil {
+ if !ctrl.ignoreReturn {
+ t.Arch().SetReturn(rval)
+ }
+ if ctrl.next != nil {
+ return ctrl.next
+ }
+ } else if err != nil {
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
+ t.haveSyscallReturn = true
+ } else {
+ t.Arch().SetReturn(rval)
+ }
+
+ return (*runSyscallExit)(nil).execute(t)
+}
+
+// +stateify savable
+type runSyscallReinvoke struct{}
+
+func (*runSyscallReinvoke) execute(t *Task) taskRunState {
+ if t.killed() {
+ // It's possible that since the last execution, the task has
+		// been forcibly killed. Invoking the system call here could
+ // result in an infinite loop if it is again preempted by an
+ // external stop and reinvoked.
+ return (*runInterrupt)(nil)
+ }
+
+ sysno := t.Arch().SyscallNo()
+ args := t.Arch().SyscallArgs()
+ return t.doSyscallInvoke(sysno, args)
+}
+
+// +stateify savable
+type runSyscallExit struct{}
+
+func (*runSyscallExit) execute(t *Task) taskRunState {
+ t.ptraceSyscallExit()
+ return (*runApp)(nil)
+}
+
+// doVsyscall is the entry point for a vsyscall invocation of syscall sysno, as
+// indicated by an execution fault at address addr. doVsyscall returns the
+// task's next run state.
+func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
+ vsyscallCount.Increment()
+
+ // Grab the caller up front, to make sure there's a sensible stack.
+ caller := t.Arch().Native(uintptr(0))
+ if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil {
+ t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return (*runApp)(nil)
+ }
+
+	// For _vsyscalls_, there is no need to translate the System V calling
+	// convention to the syscall ABI: both use RDI, RSI, and RDX for the first
+	// three arguments, and none of the vsyscalls uses more than two arguments.
+ args := t.Arch().SyscallArgs()
+ if t.syscallFilters.Load() != nil {
+ switch r := t.checkSeccompSyscall(int32(sysno), args, addr); r {
+ case linux.SECCOMP_RET_ERRNO, linux.SECCOMP_RET_TRAP:
+ t.Debugf("vsyscall %d, caller %x: denied by seccomp", sysno, t.Arch().Value(caller))
+ return (*runApp)(nil)
+ case linux.SECCOMP_RET_ALLOW:
+ // ok
+ case linux.SECCOMP_RET_TRACE:
+ t.Debugf("vsyscall %d, caller %x: stopping for PTRACE_EVENT_SECCOMP", sysno, t.Arch().Value(caller))
+ return &runVsyscallAfterPtraceEventSeccomp{addr, sysno, caller}
+ case linux.SECCOMP_RET_KILL_THREAD:
+ t.Debugf("vsyscall %d: killed by seccomp", sysno)
+ t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
+ return (*runExit)(nil)
+ default:
+ panic(fmt.Sprintf("Unknown seccomp result %d", r))
+ }
+ }
+
+ return t.doVsyscallInvoke(sysno, args, caller)
+}
+
+type runVsyscallAfterPtraceEventSeccomp struct {
+ addr usermem.Addr
+ sysno uintptr
+ caller interface{}
+}
+
+func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
+ if t.killed() {
+ return (*runInterrupt)(nil)
+ }
+ sysno := t.Arch().SyscallNo()
+ // "... the syscall may not be changed to another system call using the
+ // orig_rax register. It may only be changed to -1 order [sic] to skip the
+ // currently emulated call. ... The tracer MUST NOT modify rip or rsp." -
+ // Documentation/prctl/seccomp_filter.txt. On Linux, changing orig_ax or ip
+ // causes do_exit(SIGSYS), and changing sp is ignored.
+ if (sysno != ^uintptr(0) && sysno != r.sysno) || usermem.Addr(t.Arch().IP()) != r.addr {
+ t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
+ return (*runExit)(nil)
+ }
+ if sysno == ^uintptr(0) {
+ return (*runApp)(nil)
+ }
+ return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller)
+}
+
+func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState {
+ rval, ctrl, err := t.executeSyscall(sysno, args)
+ if ctrl != nil {
+ t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl)
+ // Set the return value. The stack has already been adjusted.
+ t.Arch().SetReturn(0)
+ } else if err == nil {
+ t.Debugf("vsyscall %d, caller %x: successfully emulated syscall", sysno, t.Arch().Value(caller))
+ // Set the return value. The stack has already been adjusted.
+ t.Arch().SetReturn(uintptr(rval))
+ } else {
+ t.Debugf("vsyscall %d, caller %x: emulated syscall returned error: %v", sysno, t.Arch().Value(caller), err)
+ if err == syserror.EFAULT {
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // A return is not emulated in this case.
+ return (*runApp)(nil)
+ }
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
+ }
+ t.Arch().SetIP(t.Arch().Value(caller))
+ t.Arch().SetStack(t.Arch().Stack() + uintptr(t.Arch().Width()))
+ return (*runApp)(nil)
+}
+
+// ExtractErrno extracts an integer error number from the error.
+// The syscall number is used purely for context in the error case. Use -1 if
+// the syscall number is unknown.
+func ExtractErrno(err error, sysno int) int {
+ switch err := err.(type) {
+ case nil:
+ return 0
+ case syscall.Errno:
+ return int(err)
+ case SyscallRestartErrno:
+ return int(err)
+ case *memmap.BusError:
+ // Bus errors may generate SIGBUS, but for syscalls they still
+ // return EFAULT. See case in task_run.go where the fault is
+ // handled (and the SIGBUS is delivered).
+ return int(syscall.EFAULT)
+ case *os.PathError:
+ return ExtractErrno(err.Err, sysno)
+ case *os.LinkError:
+ return ExtractErrno(err.Err, sysno)
+ case *os.SyscallError:
+ return ExtractErrno(err.Err, sysno)
+ default:
+ if errno, ok := syserror.TranslateError(err); ok {
+ return int(errno)
+ }
+ }
+ panic(fmt.Sprintf("Unknown syscall %d error: %v", sysno, err))
+}
diff --git a/pkg/sentry/kernel/task_test.go b/pkg/sentry/kernel/task_test.go
new file mode 100644
index 000000000..cfcde9a7a
--- /dev/null
+++ b/pkg/sentry/kernel/task_test.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+)
+
+func TestTaskCPU(t *testing.T) {
+ for _, test := range []struct {
+ mask sched.CPUSet
+ tid ThreadID
+ cpu int32
+ }{
+ {
+ mask: []byte{0xff},
+ tid: 1,
+ cpu: 0,
+ },
+ {
+ mask: []byte{0xff},
+ tid: 10,
+ cpu: 1,
+ },
+ {
+ // more than 8 cpus.
+ mask: []byte{0xff, 0xff},
+ tid: 10,
+ cpu: 9,
+ },
+ {
+ // missing the first cpu.
+ mask: []byte{0xfe},
+ tid: 1,
+ cpu: 1,
+ },
+ {
+ mask: []byte{0xfe},
+ tid: 10,
+ cpu: 3,
+ },
+ {
+ // missing the fifth cpu.
+ mask: []byte{0xef},
+ tid: 10,
+ cpu: 2,
+ },
+ } {
+ assigned := assignCPU(test.mask, test.tid)
+ if test.cpu != assigned {
+ t.Errorf("assignCPU(%v, %v) got %v, want %v", test.mask, test.tid, assigned, test.cpu)
+ }
+ }
+
+}
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
new file mode 100644
index 000000000..b02044ad2
--- /dev/null
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -0,0 +1,301 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// MAX_RW_COUNT is the maximum size in bytes of a single read or write.
+// Reads and writes that exceed this size may be silently truncated.
+// (Linux: include/linux/fs.h:MAX_RW_COUNT)
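+// With the usual 4 KiB page size, this evaluates to 0x7ffff000 bytes,
+// matching Linux's INT_MAX & PAGE_MASK.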
+var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
+
+// Activate ensures that the task has an active address space.
+func (t *Task) Activate() {
+ if mm := t.MemoryManager(); mm != nil {
+ if err := mm.Activate(t); err != nil {
+ panic("unable to activate mm: " + err.Error())
+ }
+ }
+}
+
+// Deactivate relinquishes the task's active address space.
+func (t *Task) Deactivate() {
+ if mm := t.MemoryManager(); mm != nil {
+ mm.Deactivate()
+ }
+}
+
+// CopyIn copies a fixed-size value or slice of fixed-size values in from the
+// task's memory. The copy will fail with syscall.EFAULT if it traverses user
+// memory that is unmapped or not readable by the user.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) {
+ return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyInBytes is a fast version of CopyIn if the caller can serialize the
+// data without reflection and pass in a byte slice.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ return t.MemoryManager().CopyIn(t, addr, dst, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyOut copies a fixed-size value or slice of fixed-size values out to the
+// task's memory. The copy will fail with syscall.EFAULT if it traverses user
+// memory that is unmapped or not writeable by the user.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) {
+ return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyOutBytes is a fast version of CopyOut if the caller can serialize the
+// data without reflection and pass in a byte slice.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ return t.MemoryManager().CopyOut(t, addr, src, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyInString copies a NUL-terminated string of length at most maxlen in from
+// the task's memory. The copy will fail with syscall.EFAULT if it traverses
+// user memory that is unmapped or not readable by the user.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyInString(addr usermem.Addr, maxlen int) (string, error) {
+ return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxlen, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyInVector copies a NULL-terminated vector of strings from the task's
+// memory. The copy will fail with syscall.EFAULT if it traverses
+// user memory that is unmapped or not readable by the user.
+//
+// maxElemSize is the maximum size of each individual element.
+//
+// maxTotalSize is the maximum total length of all elements plus the total
+// number of elements. For example, the following strings correspond to
+// the following set of sizes:
+//
+// { "a", "b", "c" } => 6 (3 for lengths, 3 for elements)
+// { "abc" } => 4 (3 for length, 1 for elements)
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([]string, error) {
+ var v []string
+ for {
+ argAddr := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, argAddr); err != nil {
+ return v, err
+ }
+ if t.Arch().Value(argAddr) == 0 {
+ break
+ }
+		// Each string has its zero terminating byte counted, so copying in a
+		// string requires at least one byte of space. Also, see the
+		// calculation below.
+ if maxTotalSize <= 0 {
+ return nil, syserror.ENOMEM
+ }
+ thisMax := maxElemSize
+ if maxTotalSize < thisMax {
+ thisMax = maxTotalSize
+ }
+ arg, err := t.CopyInString(usermem.Addr(t.Arch().Value(argAddr)), thisMax)
+ if err != nil {
+ return v, err
+ }
+ v = append(v, arg)
+ addr += usermem.Addr(t.Arch().Width())
+ maxTotalSize -= len(arg) + 1
+ }
+ return v, nil
+}
+
+// CopyOutIovecs converts src to an array of struct iovecs and copies it to the
+// memory mapped at addr.
+//
+// Preconditions: As for usermem.IO.CopyOut. The caller must be running on the
+// task goroutine. t's AddressSpace must be active.
+func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error {
+ switch t.Arch().Width() {
+ case 8:
+ const itemLen = 16
+ if _, ok := addr.AddLength(uint64(src.NumRanges()) * itemLen); !ok {
+ return syserror.EFAULT
+ }
+
+ b := t.CopyScratchBuffer(itemLen)
+ for ; !src.IsEmpty(); src = src.Tail() {
+ ar := src.Head()
+ usermem.ByteOrder.PutUint64(b[0:8], uint64(ar.Start))
+ usermem.ByteOrder.PutUint64(b[8:16], uint64(ar.Length()))
+ if _, err := t.CopyOutBytes(addr, b); err != nil {
+ return err
+ }
+ addr += itemLen
+ }
+
+ default:
+ return syserror.ENOSYS
+ }
+
+ return nil
+}
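+
+// Both CopyOutIovecs and CopyInIovecs assume the 64-bit struct iovec layout,
+// shown here as an equivalent Go sketch (not a type actually used by this
+// package):
+//
+//     type iovec struct { // 16 bytes on 64-bit architectures
+//         base uint64 // iov_base: bytes [0:8]
+//         len  uint64 // iov_len:  bytes [8:16]
+//     }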
+
+// CopyInIovecs copies an array of numIovecs struct iovecs from the memory
+// mapped at addr, converts them to usermem.AddrRanges, and returns them as a
+// usermem.AddrRangeSeq.
+//
+// CopyInIovecs shares the following properties with Linux's
+// lib/iov_iter.c:import_iovec() => fs/read_write.c:rw_copy_check_uvector():
+//
+// - If the length of any AddrRange would exceed the range of an ssize_t,
+// CopyInIovecs returns EINVAL.
+//
+// - If the length of any AddrRange would cause its end to overflow,
+// CopyInIovecs returns EFAULT.
+//
+// - If any AddrRange would include addresses outside the application address
+// range, CopyInIovecs returns EFAULT.
+//
+// - The combined length of all AddrRanges is limited to MAX_RW_COUNT. If the
+// combined length of all AddrRanges would otherwise exceed this amount, ranges
+// beyond MAX_RW_COUNT are silently truncated.
+//
+// Preconditions: As for usermem.IO.CopyIn. The caller must be running on the
+// task goroutine. t's AddressSpace must be active.
+func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRangeSeq, error) {
+ if numIovecs == 0 {
+ return usermem.AddrRangeSeq{}, nil
+ }
+
+ var dst []usermem.AddrRange
+ if numIovecs > 1 {
+ dst = make([]usermem.AddrRange, 0, numIovecs)
+ }
+
+ switch t.Arch().Width() {
+ case 8:
+ const itemLen = 16
+ if _, ok := addr.AddLength(uint64(numIovecs) * itemLen); !ok {
+ return usermem.AddrRangeSeq{}, syserror.EFAULT
+ }
+
+ b := t.CopyScratchBuffer(itemLen)
+ for i := 0; i < numIovecs; i++ {
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return usermem.AddrRangeSeq{}, err
+ }
+
+ base := usermem.Addr(usermem.ByteOrder.Uint64(b[0:8]))
+ length := usermem.ByteOrder.Uint64(b[8:16])
+ if length > math.MaxInt64 {
+ return usermem.AddrRangeSeq{}, syserror.EINVAL
+ }
+ ar, ok := t.MemoryManager().CheckIORange(base, int64(length))
+ if !ok {
+ return usermem.AddrRangeSeq{}, syserror.EFAULT
+ }
+
+ if numIovecs == 1 {
+ // Special case to avoid allocating dst.
+ return usermem.AddrRangeSeqOf(ar).TakeFirst(MAX_RW_COUNT), nil
+ }
+ dst = append(dst, ar)
+
+ addr += itemLen
+ }
+
+ default:
+ return usermem.AddrRangeSeq{}, syserror.ENOSYS
+ }
+
+ // Truncate to MAX_RW_COUNT.
+ var total uint64
+ for i := range dst {
+ dstlen := uint64(dst[i].Length())
+ if rem := uint64(MAX_RW_COUNT) - total; rem < dstlen {
+ dst[i].End -= usermem.Addr(dstlen - rem)
+ dstlen = rem
+ }
+ total += dstlen
+ }
+
+ return usermem.AddrRangeSeqFromSlice(dst), nil
+}
+
+// SingleIOSequence returns a usermem.IOSequence representing [addr,
+// addr+length) in t's address space. If this contains addresses outside the
+// application address range, it returns EFAULT. If length exceeds
+// MAX_RW_COUNT, the range is silently truncated.
+//
+// SingleIOSequence is analogous to Linux's
+// lib/iov_iter.c:import_single_range(). (Note that the non-vectorized read and
+// write syscalls in Linux do not use import_single_range(). However, they check
+// access_ok() in fs/read_write.c:vfs_read/vfs_write, and overflowing address
+// ranges are truncated to MAX_RW_COUNT by fs/read_write.c:rw_verify_area().)
+func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOpts) (usermem.IOSequence, error) {
+ if length > MAX_RW_COUNT {
+ length = MAX_RW_COUNT
+ }
+ ar, ok := t.MemoryManager().CheckIORange(addr, int64(length))
+ if !ok {
+ return usermem.IOSequence{}, syserror.EFAULT
+ }
+ return usermem.IOSequence{
+ IO: t.MemoryManager(),
+ Addrs: usermem.AddrRangeSeqOf(ar),
+ Opts: opts,
+ }, nil
+}
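+
+// For example (a hypothetical sketch; addr, size, and file stand in for
+// syscall arguments), a read(2) implementation might build its destination
+// sequence with:
+//
+//     dst, err := t.SingleIOSequence(addr, size, usermem.IOOpts{
+//         AddressSpaceActive: true,
+//     })
+//     if err != nil {
+//         return 0, nil, err
+//     }
+//     n, err := file.Readv(t, dst)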
+
+// IovecsIOSequence returns a usermem.IOSequence representing the array of
+// iovcnt struct iovecs at addr in t's address space. opts applies to the
+// returned IOSequence, not the reading of the struct iovec array.
+//
+// IovecsIOSequence is analogous to Linux's lib/iov_iter.c:import_iovec().
+//
+// Preconditions: As for Task.CopyInIovecs.
+func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) {
+ if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+ ars, err := t.CopyInIovecs(addr, iovcnt)
+ if err != nil {
+ return usermem.IOSequence{}, err
+ }
+ return usermem.IOSequence{
+ IO: t.MemoryManager(),
+ Addrs: ars,
+ Opts: opts,
+ }, nil
+}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
new file mode 100644
index 000000000..4dfd2c990
--- /dev/null
+++ b/pkg/sentry/kernel/thread_group.go
@@ -0,0 +1,531 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// A ThreadGroup is a logical grouping of tasks that has widespread
+// significance to other kernel features (e.g. signal handling). ("Thread
+// groups" are usually called "processes" in userspace documentation.)
+//
+// ThreadGroup is a superset of Linux's struct signal_struct.
+//
+// +stateify savable
+type ThreadGroup struct {
+ threadGroupNode
+
+ // signalHandlers is the set of signal handlers used by every task in this
+ // thread group. (signalHandlers may also be shared with other thread
+ // groups.)
+ //
+ // signalHandlers.mu (hereafter "the signal mutex") protects state related
+ // to signal handling, as well as state that usually needs to be atomic
+ // with signal handling, for all ThreadGroups and Tasks using
+ // signalHandlers. (This is analogous to Linux's use of struct
+ // sighand_struct::siglock.)
+ //
+ // The signalHandlers pointer can only be mutated during an execve
+ // (Task.finishExec). Consequently, when it's possible for a task in the
+ // thread group to be completing an execve, signalHandlers is protected by
+ // the owning TaskSet.mu. Otherwise, it is possible to read the
+ // signalHandlers pointer without synchronization. In particular,
+ // completing an execve requires that all other tasks in the thread group
+ // have exited, so task goroutines do not need the owning TaskSet.mu to
+ // read the signalHandlers pointer of their thread groups.
+ signalHandlers *SignalHandlers
+
+ // pendingSignals is the set of pending signals that may be handled by any
+ // task in this thread group.
+ //
+ // pendingSignals is protected by the signal mutex.
+ pendingSignals pendingSignals
+
+ // If groupStopDequeued is true, a task in the thread group has dequeued a
+ // stop signal, but has not yet initiated the group stop.
+ //
+ // groupStopDequeued is analogous to Linux's JOBCTL_STOP_DEQUEUED.
+ //
+ // groupStopDequeued is protected by the signal mutex.
+ groupStopDequeued bool
+
+ // groupStopSignal is the signal that caused a group stop to be initiated.
+ //
+ // groupStopSignal is protected by the signal mutex.
+ groupStopSignal linux.Signal
+
+ // groupStopPendingCount is the number of active tasks in the thread group
+ // for which Task.groupStopPending is set.
+ //
+ // groupStopPendingCount is analogous to Linux's
+ // signal_struct::group_stop_count.
+ //
+ // groupStopPendingCount is protected by the signal mutex.
+ groupStopPendingCount int
+
+ // If groupStopComplete is true, groupStopPendingCount transitioned from
+ // non-zero to zero without an intervening SIGCONT.
+ //
+ // groupStopComplete is analogous to Linux's SIGNAL_STOP_STOPPED.
+ //
+ // groupStopComplete is protected by the signal mutex.
+ groupStopComplete bool
+
+ // If groupStopWaitable is true, the thread group is indicating a waitable
+ // group stop event (as defined by EventChildGroupStop).
+ //
+ // Linux represents the analogous state as SIGNAL_STOP_STOPPED being set
+ // and group_exit_code being non-zero.
+ //
+ // groupStopWaitable is protected by the signal mutex.
+ groupStopWaitable bool
+
+ // If groupContNotify is true, then a SIGCONT has recently ended a group
+ // stop on this thread group, and the first task to observe it should
+ // notify its parent. groupContInterrupted is true iff SIGCONT ended an
+ // incomplete group stop. If groupContNotify is false, groupContInterrupted is
+ // meaningless.
+ //
+ // Analogues in Linux:
+ //
+ // - groupContNotify && groupContInterrupted is represented by
+ // SIGNAL_CLD_STOPPED.
+ //
+ // - groupContNotify && !groupContInterrupted is represented by
+ // SIGNAL_CLD_CONTINUED.
+ //
+ // - !groupContNotify is represented by neither flag being set.
+ //
+ // groupContNotify and groupContInterrupted are protected by the signal
+ // mutex.
+ groupContNotify bool
+ groupContInterrupted bool
+
+ // If groupContWaitable is true, the thread group is indicating a waitable
+ // continue event (as defined by EventGroupContinue).
+ //
+ // groupContWaitable is analogous to Linux's SIGNAL_STOP_CONTINUED.
+ //
+ // groupContWaitable is protected by the signal mutex.
+ groupContWaitable bool
+
+ // exiting is true if all tasks in the ThreadGroup should exit. exiting is
+ // analogous to Linux's SIGNAL_GROUP_EXIT.
+ //
+ // exiting is protected by the signal mutex. exiting can only transition
+ // from false to true.
+ exiting bool
+
+ // exitStatus is the thread group's exit status.
+ //
+ // While exiting is false, exitStatus is protected by the signal mutex.
+ // When exiting becomes true, exitStatus becomes immutable.
+ exitStatus ExitStatus
+
+ // terminationSignal is the signal that this thread group's leader will
+ // send to its parent when it exits.
+ //
+ // terminationSignal is protected by the TaskSet mutex.
+ terminationSignal linux.Signal
+
+ // liveGoroutines is the number of non-exited task goroutines in the thread
+ // group.
+ //
+ // liveGoroutines is not saved; it is reset as task goroutines are
+ // restarted by Task.Start.
+ liveGoroutines sync.WaitGroup `state:"nosave"`
+
+ timerMu sync.Mutex `state:"nosave"`
+
+ // itimerRealTimer implements ITIMER_REAL for the thread group.
+ itimerRealTimer *ktime.Timer
+
+ // itimerVirtSetting is the ITIMER_VIRTUAL setting for the thread group.
+ //
+ // itimerVirtSetting is protected by the signal mutex.
+ itimerVirtSetting ktime.Setting
+
+ // itimerProfSetting is the ITIMER_PROF setting for the thread group.
+ //
+ // itimerProfSetting is protected by the signal mutex.
+ itimerProfSetting ktime.Setting
+
+ // rlimitCPUSoftSetting is the setting for RLIMIT_CPU soft limit
+ // notifications for the thread group.
+ //
+ // rlimitCPUSoftSetting is protected by the signal mutex.
+ rlimitCPUSoftSetting ktime.Setting
+
+ // cpuTimersEnabled is non-zero if itimerVirtSetting.Enabled is true,
+ // itimerProfSetting.Enabled is true, rlimitCPUSoftSetting.Enabled is true,
+ // or limits.Get(CPU) is finite.
+ //
+ // cpuTimersEnabled is protected by the signal mutex. cpuTimersEnabled is
+ // accessed using atomic memory operations.
+ cpuTimersEnabled uint32
+
+ // timers is the thread group's POSIX interval timers. nextTimerID is the
+ // TimerID at which allocation should begin searching for an unused ID.
+ //
+ // timers and nextTimerID are protected by timerMu.
+ timers map[linux.TimerID]*IntervalTimer
+ nextTimerID linux.TimerID
+
+ // exitedCPUStats is the CPU usage for all exited tasks in the thread
+ // group. exitedCPUStats is protected by the TaskSet mutex.
+ exitedCPUStats usage.CPUStats
+
+ // childCPUStats is the CPU usage of all joined descendants of this thread
+ // group. childCPUStats is protected by the TaskSet mutex.
+ childCPUStats usage.CPUStats
+
+ // ioUsage is the I/O usage for all exited tasks in the thread group.
+ // The ioUsage pointer is immutable.
+ ioUsage *usage.IO
+
+ // maxRSS is the historical maximum resident set size of the thread group, updated when:
+ //
+ // - A task in the thread group exits, since after all tasks have
+ // exited the MemoryManager is no longer reachable.
+ //
+ // - The thread group completes an execve, since this changes
+ // MemoryManagers.
+ //
+ // maxRSS is protected by the TaskSet mutex.
+ maxRSS uint64
+
+ // childMaxRSS is the maximum resident set size in bytes of all joined
+ // descendants of this thread group.
+ //
+ // childMaxRSS is protected by the TaskSet mutex.
+ childMaxRSS uint64
+
+ // Resource limits for this ThreadGroup. The limits pointer is immutable.
+ limits *limits.LimitSet
+
+ // processGroup is the processGroup for this thread group.
+ //
+ // processGroup is protected by the TaskSet mutex.
+ processGroup *ProcessGroup
+
+ // execed indicates an exec has occurred since creation. This will be
+	// set by finishExec, and new ThreadGroups will have this field cleared.
+ // When execed is set, the processGroup may no longer be changed.
+ //
+ // execed is protected by the TaskSet mutex.
+ execed bool
+
+ // oldRSeqCritical is the thread group's old rseq critical region.
+ oldRSeqCritical atomic.Value `state:".(*OldRSeqCriticalRegion)"`
+
+ // mounts is the thread group's mount namespace. This does not really
+ // correspond to a "mount namespace" in Linux, but is more like a
+ // complete VFS that need not be shared between processes. See the
+ // comment in mounts.go for more information.
+ //
+ // mounts is immutable.
+ mounts *fs.MountNamespace
+
+ // tty is the thread group's controlling terminal. If nil, there is no
+ // controlling terminal.
+ //
+ // tty is protected by the signal mutex.
+ tty *TTY
+
+ // oomScoreAdj is the thread group's OOM score adjustment. This is
+ // currently not used but is maintained for consistency.
+ // TODO(gvisor.dev/issue/1967)
+ //
+ // oomScoreAdj is accessed using atomic memory operations.
+ oomScoreAdj int32
+}
+
+// NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
+// thread group leader will send its parent terminationSignal when it exits.
+// The new thread group isn't visible to the system until a task has been
+// created inside of it by a successful call to TaskSet.NewTask.
+func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet) *ThreadGroup {
+ tg := &ThreadGroup{
+ threadGroupNode: threadGroupNode{
+ pidns: pidns,
+ },
+ signalHandlers: sh,
+ terminationSignal: terminationSignal,
+ ioUsage: &usage.IO{},
+ limits: limits,
+ mounts: mntns,
+ }
+ tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
+ tg.timers = make(map[linux.TimerID]*IntervalTimer)
+ tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
+ return tg
+}
+
+// saveOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) saveOldRSeqCritical() *OldRSeqCriticalRegion {
+ return tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+}
+
+// loadOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) loadOldRSeqCritical(r *OldRSeqCriticalRegion) {
+ tg.oldRSeqCritical.Store(r)
+}
+
+// SignalHandlers returns the signal handlers used by tg.
+//
+// Preconditions: The caller must provide the synchronization required to read
+// tg.signalHandlers, as described in the field's comment.
+func (tg *ThreadGroup) SignalHandlers() *SignalHandlers {
+ return tg.signalHandlers
+}
+
+// Limits returns tg's limits.
+func (tg *ThreadGroup) Limits() *limits.LimitSet {
+ return tg.limits
+}
+
+// release releases the thread group's resources.
+func (tg *ThreadGroup) release() {
+ // Timers must be destroyed without holding the TaskSet or signal mutexes
+ // since timers send signals with Timer.mu locked.
+ tg.itimerRealTimer.Destroy()
+ var its []*IntervalTimer
+ tg.pidns.owner.mu.Lock()
+ tg.signalHandlers.mu.Lock()
+ for _, it := range tg.timers {
+ its = append(its, it)
+ }
+ tg.timers = make(map[linux.TimerID]*IntervalTimer) // nil maps can't be saved
+ tg.signalHandlers.mu.Unlock()
+ tg.pidns.owner.mu.Unlock()
+ for _, it := range its {
+ it.DestroyTimer()
+ }
+ if tg.mounts != nil {
+ tg.mounts.DecRef()
+ }
+}
+
+// forEachChildThreadGroupLocked applies fn to each child ThreadGroup of tg.
+//
+// Precondition: TaskSet.mu must be held.
+func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) {
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ for child := range t.children {
+ if child == child.tg.leader {
+ fn(child.tg)
+ }
+ }
+ }
+}
+
+// SetControllingTTY sets tty as the controlling terminal of tg.
+func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might be asked to set the controlling terminal of multiple
+ // processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "The calling process must be a session leader and not have a
+ // controlling terminal already." - tty_ioctl(4)
+ if tg.processGroup.session.leader != tg || tg.tty != nil {
+ return syserror.EINVAL
+ }
+
+ // "If this terminal is already the controlling terminal of a different
+ // session group, then the ioctl fails with EPERM, unless the caller
+ // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the
+ // terminal is stolen, and all processes that had it as controlling
+ // terminal lose it." - tty_ioctl(4)
+ if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session {
+ // Stealing requires CAP_SYS_ADMIN in the root user namespace.
+ if creds := auth.CredentialsFromContext(tg.leader); !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) || arg != 1 {
+ return syserror.EPERM
+ }
+ // Steal the TTY away. Unlike TIOCNOTTY, don't send signals.
+ for othertg := range tg.pidns.owner.Root.tgids {
+ // This won't deadlock by locking tg.signalHandlers
+ // because at this point:
+ // - We only lock signalHandlers if it's in the same
+ // session as the tty's controlling thread group.
+ // - We know that the calling thread group is not in
+ // the same session as the tty's controlling thread
+ // group.
+ if othertg.processGroup.session == tty.tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+ }
+
+ // Set the controlling terminal and foreground process group.
+ tg.tty = tty
+ tg.processGroup.session.foreground = tg.processGroup
+ // Set this as the controlling process of the terminal.
+ tty.tg = tg
+
+ return nil
+}
+
+// ReleaseControllingTTY gives up tty as the controlling tty of tg.
+func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+	// We might be asked to release the controlling terminal of multiple
+	// processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ // Just below, we may re-lock signalHandlers in order to send signals.
+ // Thus we can't defer Unlock here.
+ tg.signalHandlers.mu.Lock()
+
+ if tg.tty == nil || tg.tty != tty {
+ tg.signalHandlers.mu.Unlock()
+ return syserror.ENOTTY
+ }
+
+ // "If the process was session leader, then send SIGHUP and SIGCONT to
+ // the foreground process group and all processes in the current
+ // session lose their controlling terminal." - tty_ioctl(4)
+ // Remove tty as the controlling tty for each process in the session,
+ // then send them SIGHUP and SIGCONT.
+
+ // If we're not the session leader, we don't have to do much.
+ if tty.tg != tg {
+ tg.tty = nil
+ tg.signalHandlers.mu.Unlock()
+ return nil
+ }
+
+ tg.signalHandlers.mu.Unlock()
+
+ // We're the session leader. SIGHUP and SIGCONT the foreground process
+ // group and remove all controlling terminals in the session.
+ var lastErr error
+ for othertg := range tg.pidns.owner.Root.tgids {
+ if othertg.processGroup.session == tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ if othertg.processGroup == tg.processGroup.session.foreground {
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ }
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+
+ return lastErr
+}
+
+// ForegroundProcessGroup returns the process group ID of tty's foreground
+// process group.
+func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "When fd does not refer to the controlling terminal of the calling
+ // process, -1 is returned" - tcgetpgrp(3)
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ return int32(tg.processGroup.session.foreground.id), nil
+}
+
+// SetForegroundProcessGroup sets the foreground process group of tty to pgid.
+func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // TODO(b/129283598): "If tcsetpgrp() is called by a member of a
+ // background process group in its session, and the calling process is
+ // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all
+ // members of this background process group."
+
+ // tty must be the controlling terminal.
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ // pgid must be positive.
+ if pgid < 0 {
+ return -1, syserror.EINVAL
+ }
+
+ // pg must not be empty. Empty process groups are removed from their
+ // pid namespaces.
+ pg, ok := tg.pidns.processGroups[pgid]
+ if !ok {
+ return -1, syserror.ESRCH
+ }
+
+ // pg must be part of this process's session.
+ if tg.processGroup.session != pg.session {
+ return -1, syserror.EPERM
+ }
+
+ tg.processGroup.session.foreground.id = pgid
+ return 0, nil
+}
+
+// itimerRealListener implements ktime.Listener for ITIMER_REAL expirations.
+//
+// +stateify savable
+type itimerRealListener struct {
+ tg *ThreadGroup
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM))
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (l *itimerRealListener) Destroy() {
+}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
new file mode 100644
index 000000000..872e1a82d
--- /dev/null
+++ b/pkg/sentry/kernel/threads.go
@@ -0,0 +1,478 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TasksLimit is the maximum number of threads for an untrusted application.
+// Linux doesn't limit this directly; rather, it is limited by total memory
+// size, allocated stacks, and a global maximum. There's no real reason for us
+// to limit it either (especially since threads are backed by goroutines), and
+// we would expect to hit resource limits long before hitting this number.
+// However, for correctness, we still check that the user doesn't exceed this
+// number.
+//
+// Note that because of the way futexes are implemented, there *are* in fact
+// serious restrictions on valid thread IDs. They are limited to 2^30 - 1
+// (kernel/fork.c:MAX_THREADS).
+const TasksLimit = (1 << 16)
+
+// ThreadID is a generic thread identifier.
+type ThreadID int32
+
+// String returns a decimal representation of the ThreadID.
+func (tid ThreadID) String() string {
+ return fmt.Sprintf("%d", tid)
+}
+
+// InitTID is the TID given to the first task added to each PID namespace. The
+// thread group led by InitTID is called the namespace's init process. The
+// death of a PID namespace's init process causes all tasks visible in that
+// namespace to be killed.
+const InitTID ThreadID = 1
+
+// A TaskSet comprises all tasks in a system.
+//
+// +stateify savable
+type TaskSet struct {
+	// mu protects all relationships between tasks and thread groups in the
+ // TaskSet. (mu is approximately equivalent to Linux's tasklist_lock.)
+ mu sync.RWMutex `state:"nosave"`
+
+ // Root is the root PID namespace, in which all tasks in the TaskSet are
+ // visible. The Root pointer is immutable.
+ Root *PIDNamespace
+
+ // sessions is the set of all sessions.
+ sessions sessionList
+
+ // stopCount is the number of active external stops applicable to all tasks
+ // in the TaskSet (calls to TaskSet.BeginExternalStop that have not been
+ // paired with a call to TaskSet.EndExternalStop). stopCount is protected
+ // by mu.
+ //
+ // stopCount is not saved for the same reason as Task.stopCount; it is
+ // always reset to zero after restore.
+ stopCount int32 `state:"nosave"`
+
+ // liveGoroutines is the number of non-exited task goroutines in the
+ // TaskSet.
+ //
+ // liveGoroutines is not saved; it is reset as task goroutines are
+ // restarted by Task.Start.
+ liveGoroutines sync.WaitGroup `state:"nosave"`
+
+ // runningGoroutines is the number of running task goroutines in the
+ // TaskSet.
+ //
+ // runningGoroutines is not saved; its counter value is required to be zero
+ // at time of save (but note that this is not necessarily the same thing as
+ // sync.WaitGroup's zero value).
+ runningGoroutines sync.WaitGroup `state:"nosave"`
+
+ // aioGoroutines is the number of goroutines running async I/O
+ // callbacks.
+ //
+ // aioGoroutines is not saved but is required to be zero at the time of
+ // save.
+ aioGoroutines sync.WaitGroup `state:"nosave"`
+}
+
+// newTaskSet returns a new, empty TaskSet.
+func newTaskSet(pidns *PIDNamespace) *TaskSet {
+ ts := &TaskSet{Root: pidns}
+ pidns.owner = ts
+ return ts
+}
+
+// forEachThreadGroupLocked applies f to each thread group in ts.
+//
+// Preconditions: ts.mu must be locked (for reading or writing).
+func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) {
+ for tg := range ts.Root.tgids {
+ f(tg)
+ }
+}
+
+// A PIDNamespace represents a PID namespace, a bimap between thread IDs and
+// tasks. See the pid_namespaces(7) man page for further details.
+//
+// N.B. A task is said to be visible in a PID namespace if the PID namespace
+// contains a thread ID that maps to that task.
+//
+// +stateify savable
+type PIDNamespace struct {
+ // owner is the TaskSet that this PID namespace belongs to. The owner
+ // pointer is immutable.
+ owner *TaskSet
+
+ // parent is the PID namespace of the process that created this one. If
+ // this is the root PID namespace, parent is nil. The parent pointer is
+ // immutable.
+ //
+ // Invariant: All tasks that are visible in this namespace are also visible
+ // in all ancestor namespaces.
+ parent *PIDNamespace
+
+ // userns is the user namespace with which this PID namespace is
+ // associated. Privileged operations on this PID namespace must have
+ // appropriate capabilities in userns. The userns pointer is immutable.
+ userns *auth.UserNamespace
+
+ // The following fields are protected by owner.mu.
+
+ // last is the last ThreadID to be allocated in this namespace.
+ last ThreadID
+
+ // tasks is a mapping from ThreadIDs in this namespace to tasks visible in
+ // the namespace.
+ tasks map[ThreadID]*Task
+
+ // tids is a mapping from tasks visible in this namespace to their
+ // identifiers in this namespace.
+ tids map[*Task]ThreadID
+
+ // tgids is a mapping from thread groups visible in this namespace to
+ // their identifiers in this namespace.
+ //
+ // The content of tgids is equivalent to tids[tg.leader]. This exists
+ // primarily as an optimization to quickly find all thread groups.
+ tgids map[*ThreadGroup]ThreadID
+
+ // sessions is a mapping from SessionIDs in this namespace to sessions
+ // visible in the namespace.
+ sessions map[SessionID]*Session
+
+ // sids is a mapping from sessions visible in this namespace to their
+ // identifiers in this namespace.
+ sids map[*Session]SessionID
+
+ // processGroups is a mapping from ProcessGroupIDs in this namespace to
+ // process groups visible in the namespace.
+ processGroups map[ProcessGroupID]*ProcessGroup
+
+ // pgids is a mapping from process groups visible in this namespace to
+ // their identifiers in this namespace.
+ pgids map[*ProcessGroup]ProcessGroupID
+
+ // exiting indicates that the namespace's init process is exiting or has
+ // exited.
+ exiting bool
+}
+
+func newPIDNamespace(ts *TaskSet, parent *PIDNamespace, userns *auth.UserNamespace) *PIDNamespace {
+ return &PIDNamespace{
+ owner: ts,
+ parent: parent,
+ userns: userns,
+ tasks: make(map[ThreadID]*Task),
+ tids: make(map[*Task]ThreadID),
+ tgids: make(map[*ThreadGroup]ThreadID),
+ sessions: make(map[SessionID]*Session),
+ sids: make(map[*Session]SessionID),
+ processGroups: make(map[ProcessGroupID]*ProcessGroup),
+ pgids: make(map[*ProcessGroup]ProcessGroupID),
+ }
+}
+
+// NewRootPIDNamespace creates the root PID namespace. 'owner' is not
+// available yet when the root namespace is created, so it must be set by the
+// caller.
+func NewRootPIDNamespace(userns *auth.UserNamespace) *PIDNamespace {
+ return newPIDNamespace(nil, nil, userns)
+}
+
+// NewChild returns a new, empty PID namespace that is a child of ns. Authority
+// over the new PID namespace is controlled by userns.
+func (ns *PIDNamespace) NewChild(userns *auth.UserNamespace) *PIDNamespace {
+ return newPIDNamespace(ns.owner, ns, userns)
+}
+
+// TaskWithID returns the task with thread ID tid in PID namespace ns. If no
+// task has that TID, TaskWithID returns nil.
+func (ns *PIDNamespace) TaskWithID(tid ThreadID) *Task {
+ ns.owner.mu.RLock()
+ t := ns.tasks[tid]
+ ns.owner.mu.RUnlock()
+ return t
+}
+
+// ThreadGroupWithID returns the thread group led by the task with thread ID
+// tid in PID namespace ns. If no task has that TID, or if the task with that
+// TID is not a thread group leader, ThreadGroupWithID returns nil.
+func (ns *PIDNamespace) ThreadGroupWithID(tid ThreadID) *ThreadGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ t := ns.tasks[tid]
+ if t == nil {
+ return nil
+ }
+ if t != t.tg.leader {
+ return nil
+ }
+ return t.tg
+}
+
+// IDOfTask returns the TID assigned to the given task in PID namespace ns. If
+// the task is not visible in that namespace, IDOfTask returns 0. (This return
+// value is significant in some cases, e.g. getppid() is documented as
+// returning 0 if the caller's parent is in an ancestor namespace and
+// consequently not visible to the caller.) If the task is nil, IDOfTask returns
+// 0.
+func (ns *PIDNamespace) IDOfTask(t *Task) ThreadID {
+ ns.owner.mu.RLock()
+ id := ns.tids[t]
+ ns.owner.mu.RUnlock()
+ return id
+}
+
+// IDOfThreadGroup returns the TID assigned to tg's leader in PID namespace ns.
+// If the task is not visible in that namespace, IDOfThreadGroup returns 0.
+func (ns *PIDNamespace) IDOfThreadGroup(tg *ThreadGroup) ThreadID {
+ ns.owner.mu.RLock()
+ id := ns.tgids[tg]
+ ns.owner.mu.RUnlock()
+ return id
+}
+
+// Tasks returns a snapshot of the tasks in ns.
+func (ns *PIDNamespace) Tasks() []*Task {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ tasks := make([]*Task, 0, len(ns.tasks))
+ for t := range ns.tids {
+ tasks = append(tasks, t)
+ }
+ return tasks
+}
+
+// ThreadGroups returns a snapshot of the thread groups in ns.
+func (ns *PIDNamespace) ThreadGroups() []*ThreadGroup {
+ return ns.ThreadGroupsAppend(nil)
+}
+
+// ThreadGroupsAppend appends a snapshot of the thread groups in ns to tgs.
+func (ns *PIDNamespace) ThreadGroupsAppend(tgs []*ThreadGroup) []*ThreadGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ for tg := range ns.tgids {
+ tgs = append(tgs, tg)
+ }
+ return tgs
+}
+
+// UserNamespace returns the user namespace associated with PID namespace ns.
+func (ns *PIDNamespace) UserNamespace() *auth.UserNamespace {
+ return ns.userns
+}
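+
+// Illustrative sketch (not part of the original change; tasks and namespaces
+// here are hypothetical): the same task generally has a different ID in each
+// PID namespace in which it is visible, and lookups return zero values
+// elsewhere.
+//
+//	childNS := rootNS.NewChild(userns)
+//	rootNS.IDOfTask(t)    // e.g. 42, t's TID in the root namespace
+//	childNS.IDOfTask(t)   // e.g. 2, t's TID in the child namespace
+//	otherNS.IDOfTask(t)   // 0 if t is not visible in otherNS
+//	otherNS.TaskWithID(2) // nil for the same reason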
+
+// A threadGroupNode defines the relationship between a thread group and the
+// rest of the system. Conceptually, threadGroupNode is data belonging to the
+// owning TaskSet, as if TaskSet contained a field `nodes
+// map[*ThreadGroup]*threadGroupNode`. However, for practical reasons,
+// threadGroupNode is embedded in the ThreadGroup it represents.
+// (threadGroupNode is an anonymous field in ThreadGroup; this is to expose
+// threadGroupEntry's methods on ThreadGroup to make it implement
+// threadGroupLinker.)
+//
+// +stateify savable
+type threadGroupNode struct {
+ // pidns is the PID namespace containing the thread group and all of its
+ // member tasks. The pidns pointer is immutable.
+ pidns *PIDNamespace
+
+ // eventQueue is notified whenever an event of interest to Task.Wait occurs
+ // in a child of this thread group, or a ptrace tracee of a task in this
+ // thread group. Events are defined in task_exit.go.
+ //
+ // Note that we cannot check and save this wait queue similarly to other
+ // wait queues, as the queue will not be empty by the time of saving, due
+ // to the wait sourced from Exec().
+ eventQueue waiter.Queue `state:"nosave"`
+
+ // leader is the thread group's leader, which is the oldest task in the
+ // thread group; usually the last task in the thread group to call
+ // execve(), or if no such task exists then the first task in the thread
+ // group, which was created by a call to fork() or clone() without
+ // CLONE_THREAD. Once a thread group has been made visible to the rest of
+ // the system by TaskSet.newTask, leader is never nil.
+ //
+ // Note that it's possible for the leader to exit without causing the rest
+ // of the thread group to exit; in such a case, leader will still be valid
+ // and non-nil, but leader will not be in tasks.
+ //
+ // leader is protected by the TaskSet mutex.
+ leader *Task
+
+ // If execing is not nil, it is a task in the thread group that has killed
+ // all other tasks so that it can become the thread group leader and
+ // perform an execve. (execing may already be the thread group leader.)
+ //
+ // execing is analogous to Linux's signal_struct::group_exit_task.
+ //
+ // execing is protected by the TaskSet mutex.
+ execing *Task
+
+ // tasks is all tasks in the thread group that have not yet been reaped.
+ //
+ // tasks is protected by both the TaskSet mutex and the signal mutex:
+ // Mutating tasks requires locking the TaskSet mutex for writing *and*
+ // locking the signal mutex. Reading tasks requires locking the TaskSet
+ // mutex *or* locking the signal mutex.
+ tasks taskList
+
+ // tasksCount is the number of tasks in the thread group that have not yet
+ // been reaped; equivalently, tasksCount is the number of tasks in tasks.
+ //
+ // tasksCount is protected by both the TaskSet mutex and the signal mutex,
+ // as with tasks.
+ tasksCount int
+
+ // liveTasks is the number of tasks in the thread group that have not yet
+ // reached TaskExitZombie.
+ //
+ // liveTasks is protected by the TaskSet mutex (NOT the signal mutex).
+ liveTasks int
+
+ // activeTasks is the number of tasks in the thread group that have not yet
+ // reached TaskExitInitiated.
+ //
+ // activeTasks is protected by both the TaskSet mutex and the signal mutex,
+ // as with tasks.
+ activeTasks int
+}
+
+// PIDNamespace returns the PID namespace containing tg.
+func (tg *ThreadGroup) PIDNamespace() *PIDNamespace {
+ return tg.pidns
+}
+
+// TaskSet returns the TaskSet containing tg.
+func (tg *ThreadGroup) TaskSet() *TaskSet {
+ return tg.pidns.owner
+}
+
+// Leader returns tg's leader.
+func (tg *ThreadGroup) Leader() *Task {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.leader
+}
+
+// Count returns the number of non-exited threads in the group.
+func (tg *ThreadGroup) Count() int {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ var count int
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ count++
+ }
+ return count
+}
+
+// MemberIDs returns a snapshot of the ThreadIDs (in PID namespace pidns) for
+// all tasks in tg.
+func (tg *ThreadGroup) MemberIDs(pidns *PIDNamespace) []ThreadID {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ var tasks []ThreadID
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ if id, ok := pidns.tids[t]; ok {
+ tasks = append(tasks, id)
+ }
+ }
+ return tasks
+}
+
+// ID returns tg's leader's thread ID in its own PID namespace. If tg's leader
+// is dead, ID returns 0.
+func (tg *ThreadGroup) ID() ThreadID {
+ tg.pidns.owner.mu.RLock()
+ id := tg.pidns.tgids[tg]
+ tg.pidns.owner.mu.RUnlock()
+ return id
+}
+
+// A taskNode defines the relationship between a task and the rest of the
+// system. The comments on threadGroupNode also apply to taskNode.
+//
+// +stateify savable
+type taskNode struct {
+ // tg is the thread group that this task belongs to. The tg pointer is
+ // immutable.
+ tg *ThreadGroup `state:"wait"`
+
+ // taskEntry links into tg.tasks. Note that this means that
+ // Task.Next/Prev/SetNext/SetPrev refer to sibling tasks in the same thread
+ // group. See threadGroupNode.tasks for synchronization info.
+ taskEntry
+
+ // parent is the task's parent. parent may be nil.
+ //
+ // parent is protected by the TaskSet mutex.
+ parent *Task
+
+ // children is this task's children.
+ //
+ // children is protected by the TaskSet mutex.
+ children map[*Task]struct{}
+
+ // If childPIDNamespace is not nil, all new tasks created by this task will
+ // be members of childPIDNamespace rather than this one. (As a corollary,
+ // this task becomes unable to create sibling tasks in the same thread
+ // group.)
+ //
+ // childPIDNamespace is exclusive to the task goroutine.
+ childPIDNamespace *PIDNamespace
+}
+
+// ThreadGroup returns the thread group containing t.
+func (t *Task) ThreadGroup() *ThreadGroup {
+ return t.tg
+}
+
+// PIDNamespace returns the PID namespace containing t.
+func (t *Task) PIDNamespace() *PIDNamespace {
+ return t.tg.pidns
+}
+
+// TaskSet returns the TaskSet containing t.
+func (t *Task) TaskSet() *TaskSet {
+ return t.tg.pidns.owner
+}
+
+// Timekeeper returns the system Timekeeper.
+func (t *Task) Timekeeper() *Timekeeper {
+ return t.k.timekeeper
+}
+
+// Parent returns t's parent.
+func (t *Task) Parent() *Task {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ return t.parent
+}
+
+// ThreadID returns t's thread ID in its own PID namespace. If the task is
+// dead, ThreadID returns 0.
+func (t *Task) ThreadID() ThreadID {
+ return t.tg.pidns.IDOfTask(t)
+}
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
new file mode 100644
index 000000000..7ba7dc50c
--- /dev/null
+++ b/pkg/sentry/kernel/time/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "time",
+ srcs = [
+ "context.go",
+ "time.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/time/context.go b/pkg/sentry/kernel/time/context.go
new file mode 100644
index 000000000..00b729d88
--- /dev/null
+++ b/pkg/sentry/kernel/time/context.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the time package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxRealtimeClock is a Context.Value key for the current real time.
+ CtxRealtimeClock contextID = iota
+)
+
+// RealtimeClockFromContext returns the real time clock associated with context
+// ctx.
+func RealtimeClockFromContext(ctx context.Context) Clock {
+ if v := ctx.Value(CtxRealtimeClock); v != nil {
+ return v.(Clock)
+ }
+ return nil
+}
+
+// NowFromContext returns the current real time associated with context ctx.
+func NowFromContext(ctx context.Context) Time {
+ if clk := RealtimeClockFromContext(ctx); clk != nil {
+ return clk.Now()
+ }
+ panic("encountered context without RealtimeClock")
+}
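+
+// Illustrative usage sketch (hypothetical context plumbing): kernel-created
+// contexts are expected to carry the realtime clock under CtxRealtimeClock,
+// so callers can write:
+//
+//	clk := RealtimeClockFromContext(ctx) // nil if the value is absent
+//	now := NowFromContext(ctx)           // panics if no RealtimeClock is set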
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
new file mode 100644
index 000000000..e959700f2
--- /dev/null
+++ b/pkg/sentry/kernel/time/time.go
@@ -0,0 +1,709 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package time defines the Timer type, which provides a periodic timer that
+// works by sampling a user-provided clock.
+package time
+
+import (
+ "fmt"
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Events that may be generated by a Clock.
+const (
+ // ClockEventSet occurs when a Clock undergoes a discontinuous change.
+ ClockEventSet waiter.EventMask = 1 << iota
+
+ // ClockEventRateIncrease occurs when the rate at which a Clock advances
+ // increases significantly, such that values returned by previous calls to
+ // Clock.WallTimeUntil may be too large.
+ ClockEventRateIncrease
+)
+
+// Time represents an instant in time with nanosecond precision.
+//
+// Time may represent time with respect to any clock and may not have any
+// meaning in the real world.
+//
+// +stateify savable
+type Time struct {
+ ns int64
+}
+
+var (
+ // MinTime is the lowest possible time that can be represented by Time.
+ MinTime = Time{ns: math.MinInt64}
+
+ // MaxTime is the highest possible time that can be represented by
+ // Time.
+ MaxTime = Time{ns: math.MaxInt64}
+
+ // ZeroTime represents the zero time in an unspecified Clock's domain.
+ ZeroTime = Time{ns: 0}
+)
+
+const (
+ // MinDuration is the minimum duration representable by time.Duration.
+ MinDuration = time.Duration(math.MinInt64)
+
+ // MaxDuration is the maximum duration representable by time.Duration.
+ MaxDuration = time.Duration(math.MaxInt64)
+)
+
+// FromNanoseconds returns a Time representing the point ns nanoseconds after
+// an unspecified Clock's zero time.
+func FromNanoseconds(ns int64) Time {
+ return Time{ns}
+}
+
+// FromSeconds returns a Time representing the point s seconds after an
+// unspecified Clock's zero time.
+func FromSeconds(s int64) Time {
+ if s > math.MaxInt64/time.Second.Nanoseconds() {
+ return MaxTime
+ }
+ return Time{s * 1e9}
+}
+
+// FromUnix converts from Unix seconds and nanoseconds to Time, assuming a real
+// time Unix clock domain.
+func FromUnix(s int64, ns int64) Time {
+ if s > math.MaxInt64/time.Second.Nanoseconds() {
+ return MaxTime
+ }
+ t := s * 1e9
+ if t > math.MaxInt64-ns {
+ return MaxTime
+ }
+ return Time{t + ns}
+}
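+
+// For example (illustrative), out-of-range inputs saturate to MaxTime rather
+// than overflowing:
+//
+//	FromUnix(math.MaxInt64, 1) == MaxTime
+//	FromSeconds(math.MaxInt64) == MaxTime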
+
+// FromTimespec converts from Linux Timespec to Time.
+func FromTimespec(ts linux.Timespec) Time {
+ return Time{ts.ToNsecCapped()}
+}
+
+// FromTimeval converts a Linux Timeval to Time.
+func FromTimeval(tv linux.Timeval) Time {
+ return Time{tv.ToNsecCapped()}
+}
+
+// Nanoseconds returns nanoseconds elapsed since the zero time in t's Clock
+// domain. If t represents walltime, this is nanoseconds since the Unix epoch.
+func (t Time) Nanoseconds() int64 {
+ return t.ns
+}
+
+// Seconds returns seconds elapsed since the zero time in t's Clock domain. If
+// t represents walltime, this is seconds since Unix epoch.
+func (t Time) Seconds() int64 {
+ return t.Nanoseconds() / time.Second.Nanoseconds()
+}
+
+// Timespec converts Time to a Linux timespec.
+func (t Time) Timespec() linux.Timespec {
+ return linux.NsecToTimespec(t.Nanoseconds())
+}
+
+// Unix returns the (seconds, nanoseconds) representation of t such that
+// seconds*1e9 + nanoseconds = t.
+func (t Time) Unix() (s int64, ns int64) {
+ s = t.ns / 1e9
+ ns = t.ns % 1e9
+ return
+}
+
+// TimeT converts Time to a Linux time_t.
+func (t Time) TimeT() linux.TimeT {
+ return linux.NsecToTimeT(t.Nanoseconds())
+}
+
+// Timeval converts Time to a Linux timeval.
+func (t Time) Timeval() linux.Timeval {
+ return linux.NsecToTimeval(t.Nanoseconds())
+}
+
+// StatxTimestamp converts Time to a Linux statx_timestamp.
+func (t Time) StatxTimestamp() linux.StatxTimestamp {
+ return linux.NsecToStatxTimestamp(t.Nanoseconds())
+}
+
+// Add adds the duration of d to t.
+func (t Time) Add(d time.Duration) Time {
+ if t.ns > 0 && d.Nanoseconds() > math.MaxInt64-int64(t.ns) {
+ return MaxTime
+ }
+ if t.ns < 0 && d.Nanoseconds() < math.MinInt64-int64(t.ns) {
+ return MinTime
+ }
+ return Time{int64(t.ns) + d.Nanoseconds()}
+}
+
+// AddTime adds the duration of u to t.
+func (t Time) AddTime(u Time) Time {
+ return t.Add(time.Duration(u.ns))
+}
+
+// Equal reports whether the two times represent the same instant in time.
+func (t Time) Equal(u Time) bool {
+ return t.ns == u.ns
+}
+
+// Before reports whether the instant t is before the instant u.
+func (t Time) Before(u Time) bool {
+ return t.ns < u.ns
+}
+
+// After reports whether the instant t is after the instant u.
+func (t Time) After(u Time) bool {
+ return t.ns > u.ns
+}
+
+// Sub returns the duration of t - u.
+//
+// N.B. This measure may not make sense for every Time returned by ktime.Clock.
+// Callers who need wall time duration can use ktime.Clock.WallTimeUntil to
+// estimate that wall time.
+func (t Time) Sub(u Time) time.Duration {
+ dur := time.Duration(int64(t.ns)-int64(u.ns)) * time.Nanosecond
+ switch {
+ case u.Add(dur).Equal(t):
+ return dur
+ case t.Before(u):
+ return MinDuration
+ default:
+ return MaxDuration
+ }
+}
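+
+// For example (illustrative), Sub saturates on overflow instead of wrapping:
+//
+//	MaxTime.Sub(MinTime) == MaxDuration // not a wrapped negative duration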
+
+// IsMin returns whether t represents the lowest possible time instant.
+func (t Time) IsMin() bool {
+ return t == MinTime
+}
+
+// IsZero returns whether t represents the zero time instant in t's Clock domain.
+func (t Time) IsZero() bool {
+ return t == ZeroTime
+}
+
+// String returns the time represented in nanoseconds as a string.
+func (t Time) String() string {
+ return fmt.Sprintf("%dns", t.Nanoseconds())
+}
+
+// A Clock is an abstract time source.
+type Clock interface {
+ // Now returns the current time in nanoseconds according to the Clock.
+ Now() Time
+
+ // WallTimeUntil returns the estimated wall time until Now will return a
+ // value greater than or equal to t, given that a recent call to Now
+ // returned now. If t has already passed, WallTimeUntil may return 0 or a
+ // negative value.
+ //
+ // WallTimeUntil must be abstract to support Clocks that do not represent
+ // wall time (e.g. thread group execution timers). Clocks that represent
+ // wall times may embed the WallRateClock type to obtain an appropriate
+ // trivial implementation of WallTimeUntil.
+ //
+ // WallTimeUntil is used to determine when associated Timers should next
+ // check for expirations. Returning too small a value may result in
+ // spurious Timer goroutine wakeups, while returning too large a value may
+ // result in late expirations. Implementations should usually err on the
+ // side of underestimating.
+ WallTimeUntil(t, now Time) time.Duration
+
+ // Waitable methods may be used to subscribe to Clock events. Waiters will
+ // not be preserved by Save and must be re-established during restore.
+ //
+ // Since Clock events are transient, implementations of
+ // waiter.Waitable.Readiness should return 0.
+ waiter.Waitable
+}
+
+// WallRateClock implements Clock.WallTimeUntil for Clocks that elapse at the
+// same rate as wall time.
+type WallRateClock struct{}
+
+// WallTimeUntil implements Clock.WallTimeUntil.
+func (*WallRateClock) WallTimeUntil(t, now Time) time.Duration {
+ return t.Sub(now)
+}
+
+// NoClockEvents implements waiter.Waitable for Clocks that do not generate
+// events.
+type NoClockEvents struct{}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (*NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (*NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (*NoClockEvents) EventUnregister(e *waiter.Entry) {
+}
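+
+// A minimal Clock sketch (illustrative only, not part of the original
+// change); it assumes a source that elapses at wall rate and generates no
+// events, so it can embed the two helper types above:
+//
+//	type hostRealtimeClock struct {
+//		WallRateClock
+//		NoClockEvents
+//	}
+//
+//	func (hostRealtimeClock) Now() Time {
+//		return FromNanoseconds(time.Now().UnixNano())
+//	}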
+
+// ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and
+// defining waiter.Waitable.Readiness as required by Clock.
+type ClockEventsQueue struct {
+ waiter.Queue
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (*ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// A TimerListener receives expirations from a Timer.
+type TimerListener interface {
+ // Notify is called when its associated Timer expires. exp is the number of
+ // expirations. setting is the next timer Setting.
+ //
+ // Notify is called with the associated Timer's mutex locked, so Notify
+ // must not take any locks that precede Timer.mu in lock order.
+ //
+ // If Notify returns update=true, the timer will use the returned
+ // setting rather than the passed one.
+ //
+ // Preconditions: exp > 0.
+ Notify(exp uint64, setting Setting) (newSetting Setting, update bool)
+
+ // Destroy is called when the timer is destroyed.
+ Destroy()
+}
+
+// Setting contains user-controlled mutable Timer properties.
+//
+// +stateify savable
+type Setting struct {
+ // Enabled is true if the timer is running.
+ Enabled bool
+
+ // Next is the time in nanoseconds of the next expiration.
+ Next Time
+
+ // Period is the time in nanoseconds between expirations. If Period is
+ // zero, the timer will not automatically restart after expiring.
+ //
+ // Invariant: Period >= 0.
+ Period time.Duration
+}
+
+// SettingFromSpec converts a (value, interval) pair to a Setting based on a
+// reading from c. value is interpreted as a time relative to c.Now().
+func SettingFromSpec(value time.Duration, interval time.Duration, c Clock) (Setting, error) {
+ return SettingFromSpecAt(value, interval, c.Now())
+}
+
+// SettingFromSpecAt converts a (value, interval) pair to a Setting. value is
+// interpreted as a time relative to now.
+func SettingFromSpecAt(value time.Duration, interval time.Duration, now Time) (Setting, error) {
+ if value < 0 {
+ return Setting{}, syserror.EINVAL
+ }
+ if value == 0 {
+ return Setting{Period: interval}, nil
+ }
+ return Setting{
+ Enabled: true,
+ Next: now.Add(value),
+ Period: interval,
+ }, nil
+}
+
+// SettingFromAbsSpec converts a (value, interval) pair to a Setting. value is
+// interpreted as an absolute time.
+func SettingFromAbsSpec(value Time, interval time.Duration) (Setting, error) {
+ if value.Before(ZeroTime) {
+ return Setting{}, syserror.EINVAL
+ }
+ if value.IsZero() {
+ return Setting{Period: interval}, nil
+ }
+ return Setting{
+ Enabled: true,
+ Next: value,
+ Period: interval,
+ }, nil
+}
+
+// SettingFromItimerspec converts a linux.Itimerspec to a Setting. If abs is
+// true, its.Value is interpreted as an absolute time. Otherwise, it is
+// interpreted as a time relative to c.Now().
+func SettingFromItimerspec(its linux.Itimerspec, abs bool, c Clock) (Setting, error) {
+ if abs {
+ return SettingFromAbsSpec(FromTimespec(its.Value), its.Interval.ToDuration())
+ }
+ return SettingFromSpec(its.Value.ToDuration(), its.Interval.ToDuration(), c)
+}
+
+// SpecFromSetting converts a timestamp and a Setting to a (relative value,
+// interval) pair, as used by most Linux syscalls that return a struct
+// itimerval or struct itimerspec.
+func SpecFromSetting(now Time, s Setting) (value, period time.Duration) {
+ if !s.Enabled {
+ return 0, s.Period
+ }
+ return s.Next.Sub(now), s.Period
+}
+
+// ItimerspecFromSetting converts a Setting to a linux.Itimerspec.
+func ItimerspecFromSetting(now Time, s Setting) linux.Itimerspec {
+ val, iv := SpecFromSetting(now, s)
+ return linux.Itimerspec{
+ Interval: linux.DurationToTimespec(iv),
+ Value: linux.DurationToTimespec(val),
+ }
+}
+
+// At returns an updated Setting and the number of expirations that have
+// occurred as of the time now indicated by the associated Clock.
+//
+// Successive calls to At may be made with decreasing values of now (i.e. time
+// may appear to go backward). Supporting this is required both for
+// non-monotonic clocks and to allow Timer.clock.Now() to be called without
+// holding Timer.mu.
+func (s Setting) At(now Time) (Setting, uint64) {
+ if !s.Enabled {
+ return s, 0
+ }
+ if s.Next.After(now) {
+ return s, 0
+ }
+ if s.Period == 0 {
+ s.Enabled = false
+ return s, 1
+ }
+ exp := 1 + uint64(now.Sub(s.Next).Nanoseconds())/uint64(s.Period)
+ s.Next = s.Next.Add(time.Duration(uint64(s.Period) * exp))
+ return s, exp
+}
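+
+// Worked example (illustrative numbers): for a Setting with Next=10ns and
+// Period=5ns, At(now=22ns) yields exp = 1 + (22-10)/5 = 3 expirations and
+// advances Next to 10 + 3*5 = 25ns, the first expiration strictly after now.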
+
+// Timer is an optionally-periodic timer driven by sampling a user-specified
+// Clock. Timer's semantics support the requirements of Linux's interval timers
+// (setitimer(2), timer_create(2), timerfd_create(2)).
+//
+// Timers should be created using NewTimer and must be cleaned up by calling
+// Timer.Destroy when no longer used.
+//
+// +stateify savable
+type Timer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // listener is notified of expirations. listener is immutable.
+ listener TimerListener
+
+ // mu protects the following mutable fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // setting is the timer setting. setting is protected by mu.
+ setting Setting
+
+ // paused is true if the Timer is paused. paused is protected by mu.
+ paused bool
+
+ // kicker is used to wake the Timer goroutine. The kicker pointer is
+ // immutable, but its state is protected by mu.
+ kicker *time.Timer `state:"nosave"`
+
+ // entry is registered with clock.EventRegister. entry is immutable.
+ //
+ // Per comment in Clock, entry must be re-registered after restore; per
+ // comment in Timer.Load, this is done in Timer.Resume.
+ entry waiter.Entry `state:"nosave"`
+
+ // events is the channel that will be notified whenever entry receives an
+ // event. It is also closed by Timer.Destroy to instruct the Timer
+ // goroutine to exit.
+ events chan struct{} `state:"nosave"`
+}
+
+// timerTickEvents are Clock events that require the Timer goroutine to Tick
+// prematurely.
+const timerTickEvents = ClockEventSet | ClockEventRateIncrease
+
+// NewTimer returns a new Timer that will obtain time from clock and send
+// expirations to listener. The Timer is initially stopped and has no first
+// expiration or period configured.
+func NewTimer(clock Clock, listener TimerListener) *Timer {
+ t := &Timer{
+ clock: clock,
+ listener: listener,
+ }
+ t.init()
+ return t
+}
+
+// After waits for the duration to elapse according to clock and then sends a
+// notification on the returned channel. The timer is started immediately and
+// will fire exactly once. The second return value is the start time used with
+// the duration.
+//
+// Callers must call Timer.Destroy.
+func After(clock Clock, duration time.Duration) (*Timer, Time, <-chan struct{}) {
+ notifier, tchan := NewChannelNotifier()
+ t := NewTimer(clock, notifier)
+ now := clock.Now()
+
+ t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: now.Add(duration),
+ })
+ return t, now, tchan
+}
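+
+// Illustrative usage sketch (hypothetical caller):
+//
+//	t, start, ch := After(clk, 100*time.Millisecond)
+//	defer t.Destroy()
+//	<-ch // fires once when clk reaches start.Add(100 * time.Millisecond)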
+
+// init initializes Timer state that is not preserved across save/restore. If
+// init has already been called, calling it again is a no-op.
+//
+// Preconditions: t.mu must be locked, or the caller must have exclusive access
+// to t.
+func (t *Timer) init() {
+ if t.kicker != nil {
+ return
+ }
+ // If t.kicker is nil, the Timer goroutine can't be running, so we can't
+ // race with it.
+ t.kicker = time.NewTimer(0)
+ t.entry, t.events = waiter.NewChannelEntry(nil)
+ t.clock.EventRegister(&t.entry, timerTickEvents)
+ go t.runGoroutine() // S/R-SAFE: synchronized by t.mu
+}
+
+// Destroy releases resources owned by the Timer. A Destroyed Timer must not be
+// used again; in particular, a Destroyed Timer should not be Saved.
+func (t *Timer) Destroy() {
+ // Stop the Timer, ensuring that the Timer goroutine will not call
+ // t.kicker.Reset, before calling t.kicker.Stop.
+ t.mu.Lock()
+ t.setting.Enabled = false
+ t.mu.Unlock()
+ t.kicker.Stop()
+ // Unregister t.entry, ensuring that the Clock will not send to t.events,
+ // before closing t.events to instruct the Timer goroutine to exit.
+ t.clock.EventUnregister(&t.entry)
+ close(t.events)
+ t.listener.Destroy()
+}
+
+func (t *Timer) runGoroutine() {
+ for {
+ select {
+ case <-t.kicker.C:
+ case _, ok := <-t.events:
+ if !ok {
+ // Channel closed by Destroy.
+ return
+ }
+ }
+ t.Tick()
+ }
+}
+
+// Tick requests that the Timer immediately check for expirations and
+// re-evaluate when it should next check for expirations.
+func (t *Timer) Tick() {
+ now := t.clock.Now()
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.paused {
+ return
+ }
+ s, exp := t.setting.At(now)
+ t.setting = s
+ if exp > 0 {
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
+ }
+ t.resetKickerLocked(now)
+}
+
+// Pause pauses the Timer, ensuring that it does not generate any further
+// expirations until Resume is called. If the Timer is already paused, Pause
+// has no effect.
+func (t *Timer) Pause() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.paused = true
+ // t.kicker may be nil if we were restored but never resumed.
+ if t.kicker != nil {
+ t.kicker.Stop()
+ }
+}
+
+// Resume ends the effect of Pause. If the Timer is not paused, Resume has no
+// effect.
+func (t *Timer) Resume() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if !t.paused {
+ return
+ }
+ t.paused = false
+
+ // Lazily initialize the Timer. We can't call Timer.init until Timer.Resume
+ // because save/restore will restore Timers before
+ // kernel.Timekeeper.SetClocks() has been called, so if t.clock is backed
+ // by a kernel.Timekeeper then the Timer goroutine will panic if it calls
+ // t.clock.Now().
+ t.init()
+
+ // Kick the Timer goroutine in case it was already initialized, but the
+ // Timer goroutine was sleeping.
+ t.kicker.Reset(0)
+}
+
+// Get returns a snapshot of the Timer's current Setting and the time
+// (according to the Timer's Clock) at which the snapshot was taken.
+//
+// Preconditions: The Timer must not be paused (since its Setting cannot
+// be advanced to the current time while it is paused).
+func (t *Timer) Get() (Time, Setting) {
+ now := t.clock.Now()
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.paused {
+ panic(fmt.Sprintf("Timer.Get called on paused Timer %p", t))
+ }
+ s, exp := t.setting.At(now)
+ t.setting = s
+ if exp > 0 {
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
+ }
+ t.resetKickerLocked(now)
+ return now, s
+}
+
+// Swap atomically changes the Timer's Setting and returns the Timer's previous
+// Setting and the time (according to the Timer's Clock) at which the snapshot
+// was taken. Setting s.Enabled to true starts the Timer, while setting
+// s.Enabled to false stops it.
+//
+// Preconditions: The Timer must not be paused.
+func (t *Timer) Swap(s Setting) (Time, Setting) {
+ return t.SwapAnd(s, nil)
+}
+
+// SwapAnd atomically changes the Timer's Setting, calls f if it is not nil,
+// and returns the Timer's previous Setting and the time (according to the
+// Timer's Clock) at which the Setting was changed. Setting s.Enabled to true
+// starts the timer, while setting s.Enabled to false stops it.
+//
+// Preconditions: The Timer must not be paused. f cannot call any Timer methods
+// since it is called with the Timer mutex locked.
+func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
+ now := t.clock.Now()
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.paused {
+ panic(fmt.Sprintf("Timer.SwapAnd called on paused Timer %p", t))
+ }
+ oldS, oldExp := t.setting.At(now)
+ if oldExp > 0 {
+ t.listener.Notify(oldExp, oldS)
+ // N.B. The returned Setting doesn't matter because we're about
+ // to overwrite.
+ }
+ if f != nil {
+ f()
+ }
+ newS, newExp := s.At(now)
+ t.setting = newS
+ if newExp > 0 {
+ if newS, ok := t.listener.Notify(newExp, t.setting); ok {
+ t.setting = newS
+ }
+ }
+ t.resetKickerLocked(now)
+ return now, oldS
+}
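+
+// Illustrative usage sketch (hypothetical values): arming a one-shot timer
+// 50ms out and reporting the previous setting, in the style of a
+// setitimer-like syscall:
+//
+//	now, old := t.Swap(Setting{
+//		Enabled: true,
+//		Next:    t.Clock().Now().Add(50 * time.Millisecond),
+//	})
+//	oldIts := ItimerspecFromSetting(now, old) // previous value for userspace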
+
+// Atomically invokes f such that the call is atomic with respect to
+// expirations of t; that is, t cannot generate expirations while f is being
+// called.
+//
+// Preconditions: f cannot call any Timer methods since it is called with the
+// Timer mutex locked.
+func (t *Timer) Atomically(f func()) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ f()
+}
+
+// Preconditions: t.mu must be locked.
+func (t *Timer) resetKickerLocked(now Time) {
+ if t.setting.Enabled {
+ // Clock.WallTimeUntil may return a negative value. This is fine;
+ // time.when treats negative Durations as 0.
+ t.kicker.Reset(t.clock.WallTimeUntil(t.setting.Next, now))
+ }
+ // We don't call t.kicker.Stop if !t.setting.Enabled because in most cases
+ // resetKickerLocked will be called from the Timer goroutine itself, in
+ // which case t.kicker has already fired and t.kicker.Stop will be an
+ // expensive no-op (time.Timer.Stop => time.stopTimer => runtime.stopTimer
+ // => runtime.deltimer).
+}
+
+// Clock returns the Clock used by t.
+func (t *Timer) Clock() Clock {
+ return t.clock
+}
+
+// ChannelNotifier is a TimerListener that sends a message on an empty struct
+// channel.
+//
+// ChannelNotifier cannot be saved or loaded.
+type ChannelNotifier struct {
+ // tchan must be a buffered channel.
+ tchan chan struct{}
+}
+
+// NewChannelNotifier creates a new channel notifier.
+//
+// If the notifier is used with a timer, Timer.Destroy will close the channel
+// returned here.
+func NewChannelNotifier() (TimerListener, <-chan struct{}) {
+ tchan := make(chan struct{}, 1)
+ return &ChannelNotifier{tchan}, tchan
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) {
+ select {
+ case c.tchan <- struct{}{}:
+ default:
+ }
+
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy and will close the channel.
+func (c *ChannelNotifier) Destroy() {
+ close(c.tchan)
+}
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
new file mode 100644
index 000000000..0adf25691
--- /dev/null
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -0,0 +1,325 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/log"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Timekeeper manages all of the kernel clocks.
+//
+// +stateify savable
+type Timekeeper struct {
+ // clocks are the clock sources.
+ //
+ // These are not saved directly, as the new machine's clock may behave
+ // differently.
+ //
+ // It is set only once, by SetClocks.
+ clocks sentrytime.Clocks `state:"nosave"`
+
+ // bootTime is the realtime when the system "booted", i.e., when
+ // SetClocks was called in the initial (not restored) run.
+ bootTime ktime.Time
+
+ // monotonicOffset is the offset to apply to the monotonic clock output
+ // from clocks.
+ //
+ // It is set only once, by SetClocks.
+ monotonicOffset int64 `state:"nosave"`
+
+ // monotonicLowerBound is the lower bound for monotonic time.
+ monotonicLowerBound int64 `state:"nosave"`
+
+ // restored, if non-nil, indicates that this Timekeeper was restored
+ // from a state file. The clocks are not set until restored is closed.
+ restored chan struct{} `state:"nosave"`
+
+ // saveMonotonic is the (offset) value of the monotonic clock at the
+ // time of save.
+ //
+ // It is only valid if restored is non-nil.
+ //
+ // It is only used in SetClocks after restore to compute the new
+ // monotonicOffset.
+ saveMonotonic int64
+
+ // saveRealtime is the value of the realtime clock at the time of save.
+ //
+ // It is only valid if restored is non-nil.
+ //
+ // It is only used in SetClocks after restore to compute the new
+ // monotonicOffset.
+ saveRealtime int64
+
+ // params manages the parameter page.
+ params *VDSOParamPage
+
+ // mu protects destruction with stop and wg.
+ mu sync.Mutex `state:"nosave"`
+
+ // stop is used to tell the update goroutine to exit.
+ stop chan struct{} `state:"nosave"`
+
+ // wg is used to indicate that the update goroutine has exited.
+ wg sync.WaitGroup `state:"nosave"`
+}
+
+// NewTimekeeper returns a Timekeeper that is automatically kept up-to-date.
+// NewTimekeeper does not take ownership of paramPage.
+//
+// SetClocks must be called on the returned Timekeeper before it is usable.
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) {
+ return &Timekeeper{
+ params: NewVDSOParamPage(mfp, paramPage),
+ }, nil
+}
+
+// SetClocks sets the backing clock source.
+//
+// SetClocks must be called before the Timekeeper is used, and it may not be
+// called more than once, as changing the clock source without extra correction
+// could cause time discontinuities.
+//
+// It must also be called after Load.
+func (t *Timekeeper) SetClocks(c sentrytime.Clocks) {
+ // Update the params, marking them "not ready", as we may need to
+ // restart calibration on this new machine.
+ if t.restored != nil {
+ if err := t.params.Write(func() vdsoParams {
+ return vdsoParams{}
+ }); err != nil {
+ panic("unable to reset VDSO params: " + err.Error())
+ }
+ }
+
+ if t.clocks != nil {
+ panic("SetClocks called on previously-initialized Timekeeper")
+ }
+
+ t.clocks = c
+
+ // Compute the offset of the monotonic clock from the base Clocks.
+ //
+ // In a fresh (not restored) sentry, monotonic time starts at zero.
+ //
+ // In a restored sentry, monotonic time jumps forward by approximately
+ // the same amount as real time. There are no guarantees here, we are
+ // just making a best-effort attempt to make it appear that the app
+ // was simply not scheduled for a long period, rather than that the
+ // real time clock was changed.
+ //
+ // If real time went backwards, it remains the same.
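+ //
+ // Worked example (numbers mirror timekeeper_test.go): with
+ // saveMonotonic=100000 and saveRealtime=400000 at save, and
+ // nowRealtime=600000 at restore, wantMonotonic becomes
+ // 100000 + (600000-400000) = 300000 regardless of the new host
+ // monotonic reading.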
+ wantMonotonic := int64(0)
+
+ nowMonotonic, err := t.clocks.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ panic("Unable to get current monotonic time: " + err.Error())
+ }
+
+ nowRealtime, err := t.clocks.GetTime(sentrytime.Realtime)
+ if err != nil {
+ panic("Unable to get current realtime: " + err.Error())
+ }
+
+ if t.restored != nil {
+ wantMonotonic = t.saveMonotonic
+ elapsed := nowRealtime - t.saveRealtime
+ if elapsed > 0 {
+ wantMonotonic += elapsed
+ }
+ }
+
+ t.monotonicOffset = wantMonotonic - nowMonotonic
+
+ if t.restored == nil {
+ // Hold on to the initial "boot" time.
+ t.bootTime = ktime.FromNanoseconds(nowRealtime)
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.startUpdater()
+
+ if t.restored != nil {
+ close(t.restored)
+ }
+}
+
+// startUpdater starts an update goroutine that keeps the clocks updated.
+//
+// mu must be held.
+func (t *Timekeeper) startUpdater() {
+ if t.stop != nil {
+ // Timekeeper already started
+ return
+ }
+ t.stop = make(chan struct{})
+
+ // Keep the clocks up to date.
+ //
+ // Note that the Go runtime uses host CLOCK_MONOTONIC to service the
+ // timer, so it may run at a *slightly* different rate from the
+ // application CLOCK_MONOTONIC. That is fine, as we only need to update
+ // at approximately this rate.
+ timer := time.NewTicker(sentrytime.ApproxUpdateInterval)
+ t.wg.Add(1)
+ go func() { // S/R-SAFE: stopped during save.
+ defer t.wg.Done()
+ for {
+ // Start with an update immediately, so the clocks are
+ // ready ASAP.
+
+ // Call Update within a Write block to prevent the VDSO
+ // from using the old params between Update and
+ // Write.
+ if err := t.params.Write(func() vdsoParams {
+ monotonicParams, monotonicOk, realtimeParams, realtimeOk := t.clocks.Update()
+
+ var p vdsoParams
+ if monotonicOk {
+ p.monotonicReady = 1
+ p.monotonicBaseCycles = int64(monotonicParams.BaseCycles)
+ p.monotonicBaseRef = int64(monotonicParams.BaseRef) + t.monotonicOffset
+ p.monotonicFrequency = monotonicParams.Frequency
+ }
+ if realtimeOk {
+ p.realtimeReady = 1
+ p.realtimeBaseCycles = int64(realtimeParams.BaseCycles)
+ p.realtimeBaseRef = int64(realtimeParams.BaseRef)
+ p.realtimeFrequency = realtimeParams.Frequency
+ }
+
+ log.Debugf("Updating VDSO parameters: %+v", p)
+
+ return p
+ }); err != nil {
+ log.Warningf("Unable to update VDSO parameter page: %v", err)
+ }
+
+ select {
+ case <-timer.C:
+ case <-t.stop:
+ return
+ }
+ }
+ }()
+}
+
+// stopUpdater stops the update goroutine, blocking until it exits.
+//
+// mu must be held.
+func (t *Timekeeper) stopUpdater() {
+ if t.stop == nil {
+ // Updater not running.
+ return
+ }
+
+ close(t.stop)
+ t.wg.Wait()
+ t.stop = nil
+}
+
+// Destroy destroys the Timekeeper, freeing all associated resources.
+func (t *Timekeeper) Destroy() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.stopUpdater()
+}
+
+// PauseUpdates stops clock parameter updates. This should only be used when
+// Tasks are not running and thus cannot access the clock.
+func (t *Timekeeper) PauseUpdates() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.stopUpdater()
+}
+
+// ResumeUpdates restarts clock parameter updates stopped by PauseUpdates.
+func (t *Timekeeper) ResumeUpdates() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.startUpdater()
+}
+
+// GetTime returns the current time in nanoseconds.
+func (t *Timekeeper) GetTime(c sentrytime.ClockID) (int64, error) {
+ if t.clocks == nil {
+ if t.restored == nil {
+ panic("Timekeeper used before initialized with SetClocks")
+ }
+ <-t.restored
+ }
+ now, err := t.clocks.GetTime(c)
+ if err == nil && c == sentrytime.Monotonic {
+ now += t.monotonicOffset
+ for {
+ // It's possible that the clock is shaky. This may be due to
+ // platform issues, e.g. the KVM platform relies on the guest
+ // TSC and host TSC, which may not be perfectly in sync. To
+ // work around this issue, ensure that the monotonic time is
+ // always bounded by the last time read.
+ oldLowerBound := atomic.LoadInt64(&t.monotonicLowerBound)
+ if now < oldLowerBound {
+ now = oldLowerBound
+ break
+ }
+ if atomic.CompareAndSwapInt64(&t.monotonicLowerBound, oldLowerBound, now) {
+ break
+ }
+ }
+ }
+ return now, err
+}
+
+// BootTime returns the system boot real time.
+func (t *Timekeeper) BootTime() ktime.Time {
+ return t.bootTime
+}
+
+// timekeeperClock is a ktime.Clock that reads time from a
+// kernel.Timekeeper-managed clock.
+//
+// +stateify savable
+type timekeeperClock struct {
+ tk *Timekeeper
+ c sentrytime.ClockID
+
+ // Implements ktime.Clock.WallTimeUntil.
+ ktime.WallRateClock `state:"nosave"`
+
+ // Implements waiter.Waitable. (We have no ability to detect
+ // discontinuities from external changes to CLOCK_REALTIME).
+ ktime.NoClockEvents `state:"nosave"`
+}
+
+// Now implements ktime.Clock.Now.
+func (tc *timekeeperClock) Now() ktime.Time {
+ now, err := tc.tk.GetTime(tc.c)
+ if err != nil {
+ panic(fmt.Sprintf("timekeeperClock(ClockID=%v)).Now: %v", tc.c, err))
+ }
+ return ktime.FromNanoseconds(now)
+}
diff --git a/pkg/sentry/kernel/timekeeper_state.go b/pkg/sentry/kernel/timekeeper_state.go
new file mode 100644
index 000000000..8e961c832
--- /dev/null
+++ b/pkg/sentry/kernel/timekeeper_state.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/time"
+)
+
+// beforeSave is invoked by stateify.
+func (t *Timekeeper) beforeSave() {
+ if t.stop != nil {
+ panic("PauseUpdates must be called before Save")
+ }
+
+ // N.B. we want the *offset* monotonic time.
+ var err error
+ if t.saveMonotonic, err = t.GetTime(time.Monotonic); err != nil {
+ panic("unable to get current monotonic time: " + err.Error())
+ }
+
+ if t.saveRealtime, err = t.GetTime(time.Realtime); err != nil {
+ panic("unable to get current realtime: " + err.Error())
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (t *Timekeeper) afterLoad() {
+ t.restored = make(chan struct{})
+}
diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go
new file mode 100644
index 000000000..cf2f7ca72
--- /dev/null
+++ b/pkg/sentry/kernel/timekeeper_test.go
@@ -0,0 +1,156 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// mockClocks is a sentrytime.Clocks that simply returns the times in the
+// struct.
+type mockClocks struct {
+ monotonic int64
+ realtime int64
+}
+
+// Update implements sentrytime.Clocks.Update. It does nothing.
+func (*mockClocks) Update() (monotonicParams sentrytime.Parameters, monotonicOk bool, realtimeParam sentrytime.Parameters, realtimeOk bool) {
+ return
+}
+
+// GetTime implements sentrytime.Clocks.GetTime.
+func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) {
+ switch id {
+ case sentrytime.Monotonic:
+ return c.monotonic, nil
+ case sentrytime.Realtime:
+ return c.realtime, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// stateTestClocklessTimekeeper returns a test Timekeeper which has not had
+// SetClocks called.
+func stateTestClocklessTimekeeper(tb testing.TB) *Timekeeper {
+ ctx := contexttest.Context(tb)
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ fr, err := mfp.MemoryFile().Allocate(usermem.PageSize, usage.Anonymous)
+ if err != nil {
+ tb.Fatalf("failed to allocate memory: %v", err)
+ }
+ return &Timekeeper{
+ params: NewVDSOParamPage(mfp, fr),
+ }
+}
+
+func stateTestTimekeeper(tb testing.TB) *Timekeeper {
+ t := stateTestClocklessTimekeeper(tb)
+ t.SetClocks(sentrytime.NewCalibratedClocks())
+ return t
+}
+
+// TestTimekeeperMonotonicZero tests that monotonic time starts at zero.
+func TestTimekeeperMonotonicZero(t *testing.T) {
+ c := &mockClocks{
+ monotonic: 100000,
+ }
+
+ tk := stateTestClocklessTimekeeper(t)
+ tk.SetClocks(c)
+ defer tk.Destroy()
+
+ now, err := tk.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ t.Errorf("GetTime err got %v want nil", err)
+ }
+ if now != 0 {
+ t.Errorf("GetTime got %d want 0", now)
+ }
+
+ c.monotonic += 10
+
+ now, err = tk.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ t.Errorf("GetTime err got %v want nil", err)
+ }
+ if now != 10 {
+ t.Errorf("GetTime got %d want 10", now)
+ }
+}
+
+// TestTimekeeperMonotonicForward tests that monotonic time jumps forward
+// after restore.
+func TestTimekeeperMonotonicForward(t *testing.T) {
+ c := &mockClocks{
+ monotonic: 900000,
+ realtime: 600000,
+ }
+
+ tk := stateTestClocklessTimekeeper(t)
+ tk.restored = make(chan struct{})
+ tk.saveMonotonic = 100000
+ tk.saveRealtime = 400000
+ tk.SetClocks(c)
+ defer tk.Destroy()
+
+ // The monotonic clock should jump ahead by 200000 to 300000.
+ //
+ // The new system monotonic time (900000) is irrelevant to what the app
+ // sees.
+ now, err := tk.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ t.Errorf("GetTime err got %v want nil", err)
+ }
+ if now != 300000 {
+ t.Errorf("GetTime got %d want 300000", now)
+ }
+}
+
+// TestTimekeeperMonotonicJumpBackwards tests that monotonic time does not jump
+// backwards when realtime goes backwards.
+func TestTimekeeperMonotonicJumpBackwards(t *testing.T) {
+ c := &mockClocks{
+ monotonic: 900000,
+ realtime: 400000,
+ }
+
+ tk := stateTestClocklessTimekeeper(t)
+ tk.restored = make(chan struct{})
+ tk.saveMonotonic = 100000
+ tk.saveRealtime = 600000
+ tk.SetClocks(c)
+ defer tk.Destroy()
+
+ // The monotonic clock should remain at 100000.
+ //
+ // The new system monotonic time (900000) is irrelevant to what the app
+ // sees and we don't want to jump the monotonic clock backwards like
+ // realtime did.
+ now, err := tk.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ t.Errorf("GetTime err got %v want nil", err)
+ }
+ if now != 100000 {
+ t.Errorf("GetTime got %d want 100000", now)
+ }
+}
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
new file mode 100644
index 000000000..d0e0810e8
--- /dev/null
+++ b/pkg/sentry/kernel/tty.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "gvisor.dev/gvisor/pkg/sync"
+
+// TTY defines the relationship between a thread group and its controlling
+// terminal.
+//
+// +stateify savable
+type TTY struct {
+ // Index is the terminal index. It is immutable.
+ Index uint32
+
+ mu sync.Mutex `state:"nosave"`
+
+ // tg is protected by mu.
+ tg *ThreadGroup
+}
+
+// TTY returns the thread group's controlling terminal. If nil, there is no
+// controlling terminal.
+func (tg *ThreadGroup) TTY() *TTY {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.tty
+}
diff --git a/pkg/sentry/kernel/uncaught_signal.proto b/pkg/sentry/kernel/uncaught_signal.proto
new file mode 100644
index 000000000..0bdb062cb
--- /dev/null
+++ b/pkg/sentry/kernel/uncaught_signal.proto
@@ -0,0 +1,37 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+import "pkg/sentry/arch/registers.proto";
+
+message UncaughtSignal {
+ // Thread ID.
+ int32 tid = 1;
+
+ // Process ID.
+ int32 pid = 2;
+
+ // Registers at the time of the fault or signal.
+ Registers registers = 3;
+
+ // Signal number.
+ int32 signal_number = 4;
+
+ // The memory location which caused the fault (set if applicable, 0
+ // otherwise). This will be set for SIGILL, SIGFPE, SIGSEGV, and SIGBUS.
+ uint64 fault_addr = 5;
+}
diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go
new file mode 100644
index 000000000..8ccf04bd1
--- /dev/null
+++ b/pkg/sentry/kernel/uts_namespace.go
@@ -0,0 +1,101 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// UTSNamespace represents a UTS namespace, a holder of two system identifiers:
+// the hostname and domain name.
+//
+// +stateify savable
+type UTSNamespace struct {
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+ hostName string
+ domainName string
+
+ // userns is the user namespace associated with the UTSNamespace.
+ // Privileged operations on this UTSNamespace must have appropriate
+ // capabilities in userns.
+ //
+ // userns is immutable.
+ userns *auth.UserNamespace
+}
+
+// NewUTSNamespace creates a new UTS namespace.
+func NewUTSNamespace(hostName, domainName string, userns *auth.UserNamespace) *UTSNamespace {
+ return &UTSNamespace{
+ hostName: hostName,
+ domainName: domainName,
+ userns: userns,
+ }
+}
+
+// UTSNamespace returns the task's UTS namespace.
+func (t *Task) UTSNamespace() *UTSNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.utsns
+}
+
+// HostName returns the host name of this UTS namespace.
+func (u *UTSNamespace) HostName() string {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return u.hostName
+}
+
+// SetHostName sets the host name of this UTS namespace.
+func (u *UTSNamespace) SetHostName(host string) {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ u.hostName = host
+}
+
+// DomainName returns the domain name of this UTS namespace.
+func (u *UTSNamespace) DomainName() string {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return u.domainName
+}
+
+// SetDomainName sets the domain name of this UTS namespace.
+func (u *UTSNamespace) SetDomainName(domain string) {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ u.domainName = domain
+}
+
+// UserNamespace returns the user namespace associated with this UTS namespace.
+func (u *UTSNamespace) UserNamespace() *auth.UserNamespace {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return u.userns
+}
+
+// Clone makes a copy of this UTS namespace, associating the given user
+// namespace.
+func (u *UTSNamespace) Clone(userns *auth.UserNamespace) *UTSNamespace {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return &UTSNamespace{
+ hostName: u.hostName,
+ domainName: u.domainName,
+ userns: userns,
+ }
+}
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
new file mode 100644
index 000000000..f1b3c212c
--- /dev/null
+++ b/pkg/sentry/kernel/vdso.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// vdsoParams are the parameters exposed to the VDSO.
+//
+// They are exposed to the VDSO via a parameter page managed by VDSOParamPage,
+// which also includes a sequence counter.
+type vdsoParams struct {
+ monotonicReady uint64
+ monotonicBaseCycles int64
+ monotonicBaseRef int64
+ monotonicFrequency uint64
+
+ realtimeReady uint64
+ realtimeBaseCycles int64
+ realtimeBaseRef int64
+ realtimeFrequency uint64
+}
+
+// VDSOParamPage manages a VDSO parameter page.
+//
+// Its memory layout looks like:
+//
+// type page struct {
+// // seq is a sequence counter that protects the fields below.
+// seq uint64
+// vdsoParams
+// }
+//
+// Everything in the struct is 8 bytes for easy alignment.
+//
+// It must be kept in sync with params in vdso/vdso_time.cc.
+//
+// +stateify savable
+type VDSOParamPage struct {
+ // The parameter page is fr, allocated from mfp.MemoryFile().
+ mfp pgalloc.MemoryFileProvider
+ fr platform.FileRange
+
+ // seq is the current sequence count written to the page.
+ //
+ // A write is in progress if the counter is odd (i.e. the low bit is set).
+ //
+ // Timekeeper's updater goroutine may call Write before equality is
+ // checked in state_test_util tests, causing this field to change across
+ // save / restore.
+ seq uint64
+}
+
+// NewVDSOParamPage returns a VDSOParamPage.
+//
+// Preconditions:
+//
+// * fr is a single page allocated from mfp.MemoryFile(). VDSOParamPage does
+// not take ownership of fr; it must remain allocated for the lifetime of the
+// VDSOParamPage.
+//
+// * VDSOParamPage must be the only writer to fr.
+//
+// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
+func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage {
+ return &VDSOParamPage{mfp: mfp, fr: fr}
+}
+
+// access returns a mapping of the param page.
+func (v *VDSOParamPage) access() (safemem.Block, error) {
+ bs, err := v.mfp.MemoryFile().MapInternal(v.fr, usermem.ReadWrite)
+ if err != nil {
+ return safemem.Block{}, err
+ }
+ if bs.NumBlocks() != 1 {
+ panic(fmt.Sprintf("Multiple blocks (%d) in VDSO param BlockSeq", bs.NumBlocks()))
+ }
+ return bs.Head(), nil
+}
+
+// incrementSeq increments the sequence counter in the param page.
+func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error {
+ next := v.seq + 1
+ old, err := safemem.SwapUint64(paramPage, next)
+ if err != nil {
+ return err
+ }
+
+ if old != v.seq {
+ return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq)
+ }
+
+ v.seq = next
+ return nil
+}
+
+// Write updates the VDSO parameters.
+//
+// Write starts a write block, calls f to get the new parameters, writes
+// out the new parameters, then ends the write block.
+func (v *VDSOParamPage) Write(f func() vdsoParams) error {
+ paramPage, err := v.access()
+ if err != nil {
+ return err
+ }
+
+ // Write begin.
+ next := v.seq + 1
+ if next%2 != 1 {
+ panic("Out-of-order sequence count")
+ }
+
+ err = v.incrementSeq(paramPage)
+ if err != nil {
+ return err
+ }
+
+ // Get the new params.
+ p := f()
+ buf := binary.Marshal(nil, usermem.ByteOrder, p)
+
+ // Skip the sequence counter.
+ if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil {
+ panic(fmt.Sprintf("Unable to get set VDSO parameters: %v", err))
+ }
+
+ // Write end.
+ return v.incrementSeq(paramPage)
+}
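+
+// For reference, a sketch of the reader side of this sequence counter (the
+// real reader is the VDSO itself, in vdso/vdso_time.cc; seqPtr and
+// readParams are hypothetical stand-ins): retry while the counter is odd
+// (write in progress) or changes mid-read.
+//
+//	for {
+//		seq1 := atomic.LoadUint64(seqPtr)
+//		if seq1%2 != 0 {
+//			continue // write in progress
+//		}
+//		params := readParams() // copy the vdsoParams fields
+//		if atomic.LoadUint64(seqPtr) == seq1 {
+//			return params // consistent snapshot
+//		}
+//	}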
diff --git a/pkg/sentry/kernel/version.go b/pkg/sentry/kernel/version.go
new file mode 100644
index 000000000..5640dd71d
--- /dev/null
+++ b/pkg/sentry/kernel/version.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// Version defines the application-visible system version.
+type Version struct {
+ // Operating system name (e.g. "Linux").
+ Sysname string
+
+ // Operating system release (e.g. "4.4-amd64").
+ Release string
+
+ // Operating system version. On Linux this takes the shape
+ // "#VERSION CONFIG_FLAGS TIMESTAMP"
+ // where:
+ // - VERSION is a sequence counter incremented on every successful build
+ // - CONFIG_FLAGS is a space-separated list of major enabled kernel features
+ // (e.g. "SMP" and "PREEMPT")
+ // - TIMESTAMP is the build timestamp as returned by `date`
+ Version string
+}
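+
+// For example (illustrative values only):
+//
+//	v := Version{
+//		Sysname: "Linux",
+//		Release: "4.4-amd64",
+//		Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016",
+//	}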
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
new file mode 100644
index 000000000..cf591c4c1
--- /dev/null
+++ b/pkg/sentry/limits/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "limits",
+ srcs = [
+ "context.go",
+ "limits.go",
+ "linux.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "limits_test",
+ size = "small",
+ srcs = [
+ "limits_test.go",
+ ],
+ library = ":limits",
+)
diff --git a/pkg/sentry/limits/context.go b/pkg/sentry/limits/context.go
new file mode 100644
index 000000000..77e1fe217
--- /dev/null
+++ b/pkg/sentry/limits/context.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package limits
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the limit package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxLimits is a Context.Value key for a LimitSet.
+ CtxLimits contextID = iota
+)
+
+// FromContext returns the limits that apply to ctx.
+func FromContext(ctx context.Context) *LimitSet {
+ if v := ctx.Value(CtxLimits); v != nil {
+ return v.(*LimitSet)
+ }
+ return nil
+}
diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go
new file mode 100644
index 000000000..31b9e9ff6
--- /dev/null
+++ b/pkg/sentry/limits/limits.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package limits provides resource limits.
+package limits
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// LimitType defines a type of resource limit.
+type LimitType int
+
+// Set of constants defining the different types of resource limits.
+const (
+ CPU LimitType = iota
+ FileSize
+ Data
+ Stack
+ Core
+ Rss
+ ProcessCount
+ NumberOfFiles
+ MemoryLocked
+ AS
+ Locks
+ SignalsPending
+ MessageQueueBytes
+ Nice
+ RealTimePriority
+ Rttime
+)
+
+// Infinity is a constant representing a resource with no limit.
+const Infinity = ^uint64(0)
+
+// Limit specifies a system limit.
+//
+// +stateify savable
+type Limit struct {
+ // Cur specifies the current limit.
+ Cur uint64
+ // Max specifies the maximum settable limit.
+ Max uint64
+}
+
+// LimitSet represents the Limits that correspond to each LimitType.
+//
+// +stateify savable
+type LimitSet struct {
+ mu sync.Mutex `state:"nosave"`
+ data map[LimitType]Limit
+}
+
+// NewLimitSet creates a new, empty LimitSet.
+func NewLimitSet() *LimitSet {
+ return &LimitSet{
+ data: make(map[LimitType]Limit),
+ }
+}
+
+// GetCopy returns a clone of the LimitSet.
+func (l *LimitSet) GetCopy() *LimitSet {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ copyData := make(map[LimitType]Limit)
+ for k, v := range l.data {
+ copyData[k] = v
+ }
+ return &LimitSet{
+ data: copyData,
+ }
+}
+
+// Get returns the resource limit associated with LimitType t.
+// If no limit is provided, it defaults to an infinite limit (Infinity).
+func (l *LimitSet) Get(t LimitType) Limit {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ s, ok := l.data[t]
+ if !ok {
+ return Limit{Cur: Infinity, Max: Infinity}
+ }
+ return s
+}
+
+// GetCapped returns the current value for the limit, capped as specified.
+func (l *LimitSet) GetCapped(t LimitType, max uint64) uint64 {
+ s := l.Get(t)
+ if s.Cur == Infinity || s.Cur > max {
+ return max
+ }
+ return s.Cur
+}
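+
+// A small sketch of the Get/GetCapped semantics (values are illustrative):
+//
+//	ls := NewLimitSet()
+//	_ = ls.Get(NumberOfFiles) // {Infinity, Infinity}: unset limits are unlimited
+//	ls.SetUnchecked(NumberOfFiles, Limit{Cur: 1024, Max: 4096})
+//	_ = ls.GetCapped(NumberOfFiles, 512)  // 512: Cur exceeds the cap
+//	_ = ls.GetCapped(NumberOfFiles, 2048) // 1024: Cur is below the cap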
+
+// SetUnchecked assigns value v to resource of LimitType t.
+func (l *LimitSet) SetUnchecked(t LimitType, v Limit) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.data[t] = v
+}
+
+// Set assigns value v to resource of LimitType t and returns the old value.
+// privileged should be true only when the caller has CAP_SYS_RESOURCE or is
+// creating limits for a new kernel.
+func (l *LimitSet) Set(t LimitType, v Limit, privileged bool) (Limit, error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ // If a limit is already set, make sure the new limit doesn't
+ // exceed the previous max limit.
+ if _, ok := l.data[t]; ok {
+ // Unprivileged users can only lower their hard limits.
+ if l.data[t].Max < v.Max && !privileged {
+ return Limit{}, syscall.EPERM
+ }
+ if v.Cur > v.Max {
+ return Limit{}, syscall.EINVAL
+ }
+ }
+ old := l.data[t]
+ l.data[t] = v
+ return old, nil
+}
diff --git a/pkg/sentry/limits/limits_test.go b/pkg/sentry/limits/limits_test.go
new file mode 100644
index 000000000..658a20f56
--- /dev/null
+++ b/pkg/sentry/limits/limits_test.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package limits
+
+import (
+ "syscall"
+ "testing"
+)
+
+func TestSet(t *testing.T) {
+ testCases := []struct {
+ limit Limit
+ privileged bool
+ expectedErr error
+ }{
+ {limit: Limit{Cur: 50, Max: 50}, privileged: false, expectedErr: nil},
+ {limit: Limit{Cur: 20, Max: 50}, privileged: false, expectedErr: nil},
+ {limit: Limit{Cur: 20, Max: 60}, privileged: false, expectedErr: syscall.EPERM},
+ {limit: Limit{Cur: 60, Max: 50}, privileged: false, expectedErr: syscall.EINVAL},
+ {limit: Limit{Cur: 11, Max: 10}, privileged: false, expectedErr: syscall.EINVAL},
+ {limit: Limit{Cur: 20, Max: 60}, privileged: true, expectedErr: nil},
+ }
+
+ ls := NewLimitSet()
+ for _, tc := range testCases {
+ if _, err := ls.Set(1, tc.limit, tc.privileged); err != tc.expectedErr {
+ t.Fatalf("Tried to set Limit to %+v and privilege %t: got %v, wanted %v", tc.limit, tc.privileged, err, tc.expectedErr)
+ }
+ }
+
+}
diff --git a/pkg/sentry/limits/linux.go b/pkg/sentry/limits/linux.go
new file mode 100644
index 000000000..3f71abecc
--- /dev/null
+++ b/pkg/sentry/limits/linux.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package limits
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// FromLinuxResource maps linux resources to sentry LimitTypes.
+var FromLinuxResource = map[int]LimitType{
+ linux.RLIMIT_CPU: CPU,
+ linux.RLIMIT_FSIZE: FileSize,
+ linux.RLIMIT_DATA: Data,
+ linux.RLIMIT_STACK: Stack,
+ linux.RLIMIT_CORE: Core,
+ linux.RLIMIT_RSS: Rss,
+ linux.RLIMIT_NPROC: ProcessCount,
+ linux.RLIMIT_NOFILE: NumberOfFiles,
+ linux.RLIMIT_MEMLOCK: MemoryLocked,
+ linux.RLIMIT_AS: AS,
+ linux.RLIMIT_LOCKS: Locks,
+ linux.RLIMIT_SIGPENDING: SignalsPending,
+ linux.RLIMIT_MSGQUEUE: MessageQueueBytes,
+ linux.RLIMIT_NICE: Nice,
+ linux.RLIMIT_RTPRIO: RealTimePriority,
+ linux.RLIMIT_RTTIME: Rttime,
+}
+
+// FromLinux maps linux rlimit values to sentry Limits, being careful to handle
+// infinities.
+func FromLinux(rl uint64) uint64 {
+ if rl == linux.RLimInfinity {
+ return Infinity
+ }
+ return rl
+}
+
+// ToLinux maps sentry Limits to linux rlimit values, being careful to handle
+// infinities.
+func ToLinux(l uint64) uint64 {
+ if l == Infinity {
+ return linux.RLimInfinity
+ }
+ return l
+}
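+
+// FromLinux and ToLinux are inverses on the values they translate; in
+// particular the infinities round-trip (sketch):
+//
+//	ToLinux(FromLinux(linux.RLimInfinity)) == linux.RLimInfinity
+//	FromLinux(ToLinux(Infinity)) == Infinity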
+
+// NewLinuxLimitSet returns a LimitSet whose values match the default rlimits
+// in Linux.
+func NewLinuxLimitSet() (*LimitSet, error) {
+ ls := NewLimitSet()
+ for rlt, rl := range linux.InitRLimits {
+ lt, ok := FromLinuxResource[rlt]
+ if !ok {
+ return nil, fmt.Errorf("unknown rlimit type %v", rlt)
+ }
+ ls.SetUnchecked(lt, Limit{
+ Cur: FromLinux(rl.Cur),
+ Max: FromLinux(rl.Max),
+ })
+ }
+ return ls, nil
+}
+
+// NewLinuxDistroLimitSet returns a new LimitSet whose values are typical
+// for a booted Linux distro.
+//
+// Many Linux init systems adjust the default Linux limits to values more
+// expected by the rest of the userspace. NewLinuxDistroLimitSet returns a
+// LimitSet with sensible defaults for applications that aren't starting
+// their own init system.
+func NewLinuxDistroLimitSet() (*LimitSet, error) {
+ ls, err := NewLinuxLimitSet()
+ if err != nil {
+ return nil, err
+ }
+
+ // Adjust ProcessCount to a lower value because GNU bash allocates 16
+ // bytes per proc and OOMs if this number is set too high. Value was
+ // picked arbitrarily.
+ //
+ // 1,048,576 ought to be enough for anyone.
+ l := ls.Get(ProcessCount)
+ l.Cur = 1 << 20
+ ls.Set(ProcessCount, l, true /* privileged */)
+ return ls, nil
+}
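+
+// A usage sketch (error handling elided): callers that sandbox ordinary
+// applications would typically prefer these defaults over the raw kernel
+// ones:
+//
+//	ls, err := NewLinuxDistroLimitSet()
+//	if err != nil {
+//		panic(err)
+//	}
+//	_ = ls.Get(ProcessCount).Cur // 1 << 20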
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
new file mode 100644
index 000000000..34bdb0b69
--- /dev/null
+++ b/pkg/sentry/loader/BUILD
@@ -0,0 +1,46 @@
+load("//tools:defs.bzl", "go_embed_data", "go_library")
+
+package(licenses = ["notice"])
+
+go_embed_data(
+ name = "vdso_bin",
+ src = "//vdso:vdso.so",
+ package = "loader",
+ var = "vdsoBin",
+)
+
+go_library(
+ name = "loader",
+ srcs = [
+ "elf.go",
+ "interpreter.go",
+ "loader.go",
+ "vdso.go",
+ "vdso_state.go",
+ ":vdso_bin",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi",
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/cpuid",
+ "//pkg/log",
+ "//pkg/rand",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
new file mode 100644
index 000000000..ddeaff3db
--- /dev/null
+++ b/pkg/sentry/loader/elf.go
@@ -0,0 +1,700 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "bytes"
+ "debug/elf"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // elfMagic identifies an ELF file.
+ elfMagic = "\x7fELF"
+
+ // maxTotalPhdrSize is the maximum combined size of all program
+ // headers. Linux limits this to one page.
+ maxTotalPhdrSize = usermem.PageSize
+)
+
+var (
+ // header64Size is the size of elf.Header64.
+ header64Size = int(binary.Size(elf.Header64{}))
+
+ // prog64Size is the size of elf.Prog64.
+ prog64Size = int(binary.Size(elf.Prog64{}))
+)
+
+func progFlagsAsPerms(f elf.ProgFlag) usermem.AccessType {
+ var p usermem.AccessType
+ if f&elf.PF_R == elf.PF_R {
+ p.Read = true
+ }
+ if f&elf.PF_W == elf.PF_W {
+ p.Write = true
+ }
+ if f&elf.PF_X == elf.PF_X {
+ p.Execute = true
+ }
+ return p
+}
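+
+// For example (illustrative): progFlagsAsPerms(elf.PF_R|elf.PF_X) yields
+// usermem.AccessType{Read: true, Execute: true}, the typical permissions of
+// a text segment.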
+
+// elfInfo contains the metadata needed to load an ELF binary.
+type elfInfo struct {
+ // os is the target OS of the ELF.
+ os abi.OS
+
+ // arch is the target architecture of the ELF.
+ arch arch.Arch
+
+ // entry is the program entry point.
+ entry usermem.Addr
+
+ // phdrs are the program headers.
+ phdrs []elf.ProgHeader
+
+ // phdrSize is the size of a single program header in the ELF.
+ phdrSize int
+
+ // phdrOff is the offset of the program headers in the file.
+ phdrOff uint64
+
+ // sharedObject is true if the ELF represents a shared object.
+ sharedObject bool
+}
+
+// fullReader interface extracts the ReadFull method from fsbridge.File so that
+// client code does not need to define an entire fsbridge.File when only read
+// functionality is needed.
+//
+// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap
+// vfs.FileDescription's PRead/Read instead.
+type fullReader interface {
+ // ReadFull is the same as fsbridge.File.ReadFull.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+}
+
+// parseHeader parses the ELF header, verifying that this is a supported ELF
+// file and returning the ELF program headers.
+//
+// This is similar to elf.NewFile, except that it is more strict about what it
+// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
+func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
+ // Check ident first; it will tell us the endianness of the rest of the
+ // structs.
+ var ident [elf.EI_NIDENT]byte
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(ident[:]), 0)
+ if err != nil {
+ log.Infof("Error reading ELF ident: %v", err)
+ // The entire ident array always exists.
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ err = syserror.ENOEXEC
+ }
+ return elfInfo{}, err
+ }
+
+ // Only some callers pre-check the ELF magic.
+ if !bytes.Equal(ident[:len(elfMagic)], []byte(elfMagic)) {
+ log.Infof("File is not an ELF")
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ // We only support 64-bit, little-endian binaries.
+ if class := elf.Class(ident[elf.EI_CLASS]); class != elf.ELFCLASS64 {
+ log.Infof("Unsupported ELF class: %v", class)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ if endian := elf.Data(ident[elf.EI_DATA]); endian != elf.ELFDATA2LSB {
+ log.Infof("Unsupported ELF endianness: %v", endian)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ byteOrder := binary.LittleEndian
+
+ if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT {
+ log.Infof("Unsupported ELF version: %v", version)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ // EI_OSABI is ignored by Linux, which is the only OS supported.
+ os := abi.Linux
+
+ var hdr elf.Header64
+ hdrBuf := make([]byte, header64Size)
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0)
+ if err != nil {
+ log.Infof("Error reading ELF header: %v", err)
+ // The entire header always exists.
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ err = syserror.ENOEXEC
+ }
+ return elfInfo{}, err
+ }
+ binary.Unmarshal(hdrBuf, byteOrder, &hdr)
+
+ // We support amd64 and arm64.
+ var a arch.Arch
+ switch machine := elf.Machine(hdr.Machine); machine {
+ case elf.EM_X86_64:
+ a = arch.AMD64
+ case elf.EM_AARCH64:
+ a = arch.ARM64
+ default:
+ log.Infof("Unsupported ELF machine %d", machine)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ var sharedObject bool
+ elfType := elf.Type(hdr.Type)
+ switch elfType {
+ case elf.ET_EXEC:
+ sharedObject = false
+ case elf.ET_DYN:
+ sharedObject = true
+ default:
+ log.Infof("Unsupported ELF type %v", elfType)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ if int(hdr.Phentsize) != prog64Size {
+ log.Infof("Unsupported phdr size %d", hdr.Phentsize)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ totalPhdrSize := prog64Size * int(hdr.Phnum)
+ if totalPhdrSize < prog64Size {
+ log.Warningf("No phdrs or total phdr size overflows: prog64Size: %d phnum: %d", prog64Size, int(hdr.Phnum))
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ if totalPhdrSize > maxTotalPhdrSize {
+ log.Infof("Too many phdrs (%d): total size %d > %d", hdr.Phnum, totalPhdrSize, maxTotalPhdrSize)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ phdrBuf := make([]byte, totalPhdrSize)
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
+ if err != nil {
+ log.Infof("Error reading ELF phdrs: %v", err)
+ // If phdrs were specified, they should all exist.
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ err = syserror.ENOEXEC
+ }
+ return elfInfo{}, err
+ }
+
+ phdrs := make([]elf.ProgHeader, hdr.Phnum)
+ for i := range phdrs {
+ var prog64 elf.Prog64
+ binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64)
+ phdrBuf = phdrBuf[prog64Size:]
+ phdrs[i] = elf.ProgHeader{
+ Type: elf.ProgType(prog64.Type),
+ Flags: elf.ProgFlag(prog64.Flags),
+ Off: prog64.Off,
+ Vaddr: prog64.Vaddr,
+ Paddr: prog64.Paddr,
+ Filesz: prog64.Filesz,
+ Memsz: prog64.Memsz,
+ Align: prog64.Align,
+ }
+ }
+
+ return elfInfo{
+ os: os,
+ arch: a,
+ entry: usermem.Addr(hdr.Entry),
+ phdrs: phdrs,
+ phdrOff: hdr.Phoff,
+ phdrSize: prog64Size,
+ sharedObject: sharedObject,
+ }, nil
+}
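+
+// A minimal sketch of driving parseHeader (assumes an in-memory ELF image
+// elfBytes; byteFullReader is the fullReader defined in vdso.go):
+//
+//	info, err := parseHeader(ctx, &byteFullReader{data: elfBytes})
+//	if err != nil {
+//		return err // e.g. ENOEXEC for a non-64-bit or big-endian image
+//	}
+//	_ = info.entry // entry point taken from the ELF header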
+
+// mapSegment maps a phdr into the Task's address space. offset is the
+// offset to apply to phdr.Vaddr.
+func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
+ // We must make a page-aligned mapping.
+ adjust := usermem.Addr(phdr.Vaddr).PageOffset()
+
+ addr, ok := offset.AddLength(phdr.Vaddr)
+ if !ok {
+ // If offset != 0 we should have ensured this would fit.
+ ctx.Warningf("Computed segment load address overflows: %#x + %#x", phdr.Vaddr, offset)
+ return syserror.ENOEXEC
+ }
+ addr -= usermem.Addr(adjust)
+
+ fileSize := phdr.Filesz + adjust
+ if fileSize < phdr.Filesz {
+ ctx.Infof("Computed segment file size overflows: %#x + %#x", phdr.Filesz, adjust)
+ return syserror.ENOEXEC
+ }
+ ms, ok := usermem.Addr(fileSize).RoundUp()
+ if !ok {
+ ctx.Infof("fileSize %#x too large", fileSize)
+ return syserror.ENOEXEC
+ }
+ mapSize := uint64(ms)
+
+ if mapSize > 0 {
+ // This must result in a page-aligned offset. i.e., the original
+ // phdr.Off must have the same alignment as phdr.Vaddr. If that is not
+ // true, MMap will reject the mapping.
+ fileOffset := phdr.Off - adjust
+
+ prot := progFlagsAsPerms(phdr.Flags)
+ mopts := memmap.MMapOpts{
+ Length: mapSize,
+ Offset: fileOffset,
+ Addr: addr,
+ Fixed: true,
+ // Linux will happily allow conflicting segments to map over
+ // one another.
+ Unmap: true,
+ Private: true,
+ Perms: prot,
+ MaxPerms: usermem.AnyAccess,
+ }
+ defer func() {
+ if mopts.MappingIdentity != nil {
+ mopts.MappingIdentity.DecRef()
+ }
+ }()
+ if err := f.ConfigureMMap(ctx, &mopts); err != nil {
+ ctx.Infof("File is not memory-mappable: %v", err)
+ return err
+ }
+ if _, err := m.MMap(ctx, mopts); err != nil {
+ ctx.Infof("Error mapping PT_LOAD segment %+v at %#x: %v", phdr, addr, err)
+ return err
+ }
+
+ // We need to clear the end of the last page that exceeds fileSize so
+ // we don't map part of the file beyond fileSize.
+ //
+ // Note that Linux *does not* clear the portion of the first page
+ // before phdr.Off.
+ if mapSize > fileSize {
+ zeroAddr, ok := addr.AddLength(fileSize)
+ if !ok {
+ panic(fmt.Sprintf("successfully mmaped address overflows? %#x + %#x", addr, fileSize))
+ }
+ zeroSize := int64(mapSize - fileSize)
+ if zeroSize < 0 {
+ panic(fmt.Sprintf("zeroSize too big? %#x", uint64(zeroSize)))
+ }
+ if _, err := m.ZeroOut(ctx, zeroAddr, zeroSize, usermem.IOOpts{IgnorePermissions: true}); err != nil {
+ ctx.Warningf("Failed to zero end of page [%#x, %#x): %v", zeroAddr, zeroAddr+usermem.Addr(zeroSize), err)
+ return err
+ }
+ }
+ }
+
+ memSize := phdr.Memsz + adjust
+ if memSize < phdr.Memsz {
+ ctx.Infof("Computed segment mem size overflows: %#x + %#x", phdr.Memsz, adjust)
+ return syserror.ENOEXEC
+ }
+
+ // Allocate more anonymous pages if necessary.
+ if mapSize < memSize {
+ anonAddr, ok := addr.AddLength(mapSize)
+ if !ok {
+ panic(fmt.Sprintf("anonymous memory doesn't fit in pre-sized range? %#x + %#x", addr, mapSize))
+ }
+ anonSize, ok := usermem.Addr(memSize - mapSize).RoundUp()
+ if !ok {
+ ctx.Infof("extra anon pages too large: %#x", memSize-mapSize)
+ return syserror.ENOEXEC
+ }
+
+ // N.B. Linux uses vm_brk_flags to map these pages, which only
+ // honors the X bit and always maps at least RW, ignoring the
+ // other permission bits. These pages are not included in the
+ // final brk region.
+ prot := usermem.ReadWrite
+ if phdr.Flags&elf.PF_X == elf.PF_X {
+ prot.Execute = true
+ }
+
+ if _, err := m.MMap(ctx, memmap.MMapOpts{
+ Length: uint64(anonSize),
+ Addr: anonAddr,
+ // Fixed without Unmap will fail the mmap if something is
+ // already at addr.
+ Fixed: true,
+ Private: true,
+ Perms: prot,
+ MaxPerms: usermem.AnyAccess,
+ }); err != nil {
+ ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
+ return err
+ }
+ }
+
+ return nil
+}
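+
+// A worked example of the alignment arithmetic above (illustrative numbers,
+// 4 KiB pages): for phdr.Vaddr = 0x401234 and phdr.Off = 0x1234, adjust is
+// 0x234, so the mapping lands at addr = 0x401000 with fileOffset = 0x1000;
+// both stay page-aligned because Vaddr and Off share the same page offset.
+// fileSize and memSize grow by the same adjust, so the tail beyond Filesz is
+// zeroed and any remaining Memsz is backed by anonymous pages.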
+
+// loadedELF describes an ELF that has been successfully loaded.
+type loadedELF struct {
+ // os is the target OS of the ELF.
+ os abi.OS
+
+ // arch is the target architecture of the ELF.
+ arch arch.Arch
+
+ // entry is the entry point of the ELF.
+ entry usermem.Addr
+
+ // start is the start of the ELF.
+ start usermem.Addr
+
+ // end is the end of the ELF.
+ end usermem.Addr
+
+ // interpreter is the path to the ELF interpreter.
+ interpreter string
+
+ // phdrAddr is the address of the ELF program headers.
+ phdrAddr usermem.Addr
+
+ // phdrSize is the size of a single program header in the ELF.
+ phdrSize int
+
+ // phdrNum is the number of program headers.
+ phdrNum int
+
+ // auxv contains a subset of ELF-specific auxiliary vector entries:
+ // * AT_PHDR
+ // * AT_PHENT
+ // * AT_PHNUM
+ // * AT_BASE
+ // * AT_ENTRY
+ auxv arch.Auxv
+}
+
+// loadParsedELF loads f into mm.
+//
+// info is the parsed elfInfo from the header.
+//
+// It does not load the ELF interpreter, or return any auxv entries.
+//
+// Preconditions:
+// * f is an ELF file
+func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
+ first := true
+ var start, end usermem.Addr
+ var interpreter string
+ for _, phdr := range info.phdrs {
+ switch phdr.Type {
+ case elf.PT_LOAD:
+ vaddr := usermem.Addr(phdr.Vaddr)
+ if first {
+ first = false
+ start = vaddr
+ }
+ if vaddr < end {
+ // NOTE(b/37474556): Linux allows out-of-order
+ // segments, in violation of the spec.
+ ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+ var ok bool
+ end, ok = vaddr.AddLength(phdr.Memsz)
+ if !ok {
+ ctx.Infof("PT_LOAD header size overflows. %#x + %#x", vaddr, phdr.Memsz)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ case elf.PT_INTERP:
+ if phdr.Filesz < 2 {
+ ctx.Infof("PT_INTERP path too small: %v", phdr.Filesz)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+ if phdr.Filesz > linux.PATH_MAX {
+ ctx.Infof("PT_INTERP path too big: %v", phdr.Filesz)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ path := make([]byte, phdr.Filesz)
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(path), int64(phdr.Off))
+ if err != nil {
+ // If an interpreter was specified, it should exist.
+ ctx.Infof("Error reading PT_INTERP path: %v", err)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ if path[len(path)-1] != 0 {
+ ctx.Infof("PT_INTERP path not NUL-terminated: %v", path)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ // Strip NUL-terminator and everything beyond from
+ // string. Note that there may be a NUL-terminator
+ // before len(path)-1.
+ interpreter = string(path[:bytes.IndexByte(path, '\x00')])
+ if interpreter == "" {
+ // Linux actually attempts to open_exec("\0").
+ // open_exec -> do_open_execat fails to check
+ // that name != '\0' before calling
+ // do_filp_open, which thus opens the working
+ // directory. do_open_execat returns EACCES
+ // because the directory is not a regular file.
+ //
+ // We bypass that nonsense and simply
+ // short-circuit with EACCES. Though this does
+ // mean that there may be some edge cases where
+ // the open path would return a different
+ // error.
+ ctx.Infof("PT_INTERP path is empty: %v", path)
+ return loadedELF{}, syserror.EACCES
+ }
+ }
+ }
+
+ // Shared objects don't have fixed load addresses. We need to pick a
+ // base address big enough to fit all segments, so we first create a
+ // mapping for the total size just to find a region that is big enough.
+ //
+ // It is safe to unmap it immediately without racing with another mapping
+ // because we are the only one in control of the MemoryManager.
+ //
+ // Note that the vaddr of the first PT_LOAD segment is ignored when
+ // choosing the load address (even if it is non-zero). The vaddr does
+ // become an offset from that load address.
+ var offset usermem.Addr
+ if info.sharedObject {
+ totalSize := end - start
+ totalSize, ok := totalSize.RoundUp()
+ if !ok {
+ ctx.Infof("ELF PT_LOAD segments too big")
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ var err error
+ offset, err = m.MMap(ctx, memmap.MMapOpts{
+ Length: uint64(totalSize),
+ Addr: sharedLoadOffset,
+ Private: true,
+ })
+ if err != nil {
+ ctx.Infof("Error allocating address space for shared object: %v", err)
+ return loadedELF{}, err
+ }
+ if err := m.MUnmap(ctx, offset, uint64(totalSize)); err != nil {
+ panic(fmt.Sprintf("Failed to unmap base address: %v", err))
+ }
+
+ start, ok = start.AddLength(uint64(offset))
+ if !ok {
+ panic(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset))
+ }
+
+ end, ok = end.AddLength(uint64(offset))
+ if !ok {
+ panic(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset))
+ }
+
+ info.entry, ok = info.entry.AddLength(uint64(offset))
+ if !ok {
+ ctx.Infof("Entrypoint %#x + offset %#x overflows? Is the entrypoint within a segment?", info.entry, offset)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+ }
+
+ // Map PT_LOAD segments.
+ for _, phdr := range info.phdrs {
+ switch phdr.Type {
+ case elf.PT_LOAD:
+ if phdr.Memsz == 0 {
+ // No need to load segments with size 0, but
+ // they exist in some binaries.
+ continue
+ }
+
+ if err := mapSegment(ctx, m, f, &phdr, offset); err != nil {
+ ctx.Infof("Failed to map PT_LOAD segment: %+v", phdr)
+ return loadedELF{}, err
+ }
+ }
+ }
+
+ // This assumes that the first segment contains the ELF headers. This
+ // may not be true in a malformed ELF, but Linux makes the same
+ // assumption.
+ phdrAddr, ok := start.AddLength(info.phdrOff)
+ if !ok {
+ ctx.Warningf("ELF start address %#x + phdr offset %#x overflows", start, info.phdrOff)
+ phdrAddr = 0
+ }
+
+ return loadedELF{
+ os: info.os,
+ arch: info.arch,
+ entry: info.entry,
+ start: start,
+ end: end,
+ interpreter: interpreter,
+ phdrAddr: phdrAddr,
+ phdrSize: info.phdrSize,
+ phdrNum: len(info.phdrs),
+ }, nil
+}
+
+// loadInitialELF loads f into mm.
+//
+// It creates an arch.Context for the ELF and prepares the mm for this arch.
+//
+// It does not load the ELF interpreter, or return any auxv entries.
+//
+// Preconditions:
+// * f is an ELF file
+// * f is the first ELF loaded into m
+func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f fsbridge.File) (loadedELF, arch.Context, error) {
+ info, err := parseHeader(ctx, f)
+ if err != nil {
+ ctx.Infof("Failed to parse initial ELF: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ // Check Image Compatibility.
+ if arch.Host != info.arch {
+ ctx.Warningf("Found mismatch for platform %s with ELF type %s", arch.Host.String(), info.arch.String())
+ return loadedELF{}, nil, syserror.ENOEXEC
+ }
+
+ // Create the arch.Context now so we can prepare the mmap layout before
+ // mapping anything.
+ ac := arch.New(info.arch, fs)
+
+ l, err := m.SetMmapLayout(ac, limits.FromContext(ctx))
+ if err != nil {
+ ctx.Warningf("Failed to set mmap layout: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ // PIELoadAddress tries to move the ELF out of the way of the default
+ // mmap base to ensure that the initial brk has sufficient space to
+ // grow.
+ le, err := loadParsedELF(ctx, m, f, info, ac.PIELoadAddress(l))
+ return le, ac, err
+}
+
+// loadInterpreterELF loads f into mm.
+//
+// The interpreter must be for the same OS/Arch as the initial ELF.
+//
+// It does not return any auxv entries.
+//
+// Preconditions:
+// * f is an ELF file
+func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) {
+ info, err := parseHeader(ctx, f)
+ if err != nil {
+ if err == syserror.ENOEXEC {
+ // Bad interpreter.
+ err = syserror.ELIBBAD
+ }
+ return loadedELF{}, err
+ }
+
+ if info.os != initial.os {
+ ctx.Infof("Initial ELF OS %v and interpreter ELF OS %v differ", initial.os, info.os)
+ return loadedELF{}, syserror.ELIBBAD
+ }
+ if info.arch != initial.arch {
+ ctx.Infof("Initial ELF arch %v and interpreter ELF arch %v differ", initial.arch, info.arch)
+ return loadedELF{}, syserror.ELIBBAD
+ }
+
+ // The interpreter is not given a load offset, as its location does not
+ // affect brk.
+ return loadParsedELF(ctx, m, f, info, 0)
+}
+
+// loadELF loads args.File into the Task address space.
+//
+// If loadELF returns ErrSwitchFile it should be called again with the returned
+// path and argv.
+//
+// Preconditions:
+// * args.File is an ELF file
+func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) {
+ bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File)
+ if err != nil {
+ ctx.Infof("Error loading binary: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ var interp loadedELF
+ if bin.interpreter != "" {
+ // Even if we do not allow the final link of the script to be
+ // resolved, the interpreter should still be resolved if it is
+ // a symlink.
+ args.ResolveFinal = true
+ // Refresh the traversal limit.
+ *args.RemainingTraversals = linux.MaxSymlinkTraversals
+ args.Filename = bin.interpreter
+ intFile, err := openPath(ctx, args)
+ if err != nil {
+ ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
+ return loadedELF{}, nil, err
+ }
+ defer intFile.DecRef()
+
+ interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin)
+ if err != nil {
+ ctx.Infof("Error loading interpreter: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ if interp.interpreter != "" {
+ // No recursive interpreters!
+ ctx.Infof("Interpreter requires an interpreter")
+ return loadedELF{}, nil, syserror.ENOEXEC
+ }
+ }
+
+ // ELF-specific auxv entries.
+ bin.auxv = arch.Auxv{
+ arch.AuxEntry{linux.AT_PHDR, bin.phdrAddr},
+ arch.AuxEntry{linux.AT_PHENT, usermem.Addr(bin.phdrSize)},
+ arch.AuxEntry{linux.AT_PHNUM, usermem.Addr(bin.phdrNum)},
+ arch.AuxEntry{linux.AT_ENTRY, bin.entry},
+ }
+ if bin.interpreter != "" {
+ bin.auxv = append(bin.auxv, arch.AuxEntry{linux.AT_BASE, interp.start})
+
+ // Start in the interpreter.
+ // N.B. AT_ENTRY above contains the *original* entry point.
+ bin.entry = interp.entry
+ } else {
+ // Always add AT_BASE even if there is no interpreter.
+ bin.auxv = append(bin.auxv, arch.AuxEntry{linux.AT_BASE, 0})
+ }
+
+ return bin, ac, nil
+}
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
new file mode 100644
index 000000000..3886b4d33
--- /dev/null
+++ b/pkg/sentry/loader/interpreter.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "bytes"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // interpreterScriptMagic identifies an interpreter script.
+ interpreterScriptMagic = "#!"
+
+ // interpMaxLineLength is the maximum length for the first line of an
+ // interpreter script.
+ //
+ // From execve(2): "A maximum line length of 127 characters is allowed
+ // for the first line in a #! executable shell script."
+ interpMaxLineLength = 127
+)
+
+// parseInterpreterScript returns the interpreter path and argv.
+func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.File, argv []string) (newpath string, newargv []string, err error) {
+ line := make([]byte, interpMaxLineLength)
+ n, err := f.ReadFull(ctx, usermem.BytesIOSequence(line), 0)
+ // Short read is OK.
+ if err != nil && err != io.ErrUnexpectedEOF {
+ if err == io.EOF {
+ err = syserror.ENOEXEC
+ }
+ return "", []string{}, err
+ }
+ line = line[:n]
+
+ if !bytes.Equal(line[:2], []byte(interpreterScriptMagic)) {
+ return "", []string{}, syserror.ENOEXEC
+ }
+ // Ignore #!.
+ line = line[2:]
+
+ // Ignore everything after newline.
+ // Linux silently truncates the remainder of the line if it exceeds
+ // interpMaxLineLength.
+ i := bytes.IndexByte(line, '\n')
+ if i > 0 {
+ line = line[:i]
+ }
+
+ // Skip any whitespace before the interpreter.
+ line = bytes.TrimLeft(line, " \t")
+
+ // Linux only looks for spaces or tabs delimiting the interpreter and
+ // arg.
+ //
+ // execve(2): "On Linux, the entire string following the interpreter
+ // name is passed as a single argument to the interpreter, and this
+ // string can include white space."
+ interp := line
+ var arg []byte
+ i = bytes.IndexAny(line, " \t")
+ if i >= 0 {
+ interp = line[:i]
+ arg = bytes.TrimLeft(line[i:], " \t")
+ }
+
+ if string(interp) == "" {
+ ctx.Infof("Interpreter script contains no interpreter: %v", line)
+ return "", []string{}, syserror.ENOEXEC
+ }
+
+ // Build the new argument list:
+ //
+ // 1. The interpreter.
+ newargv = append(newargv, string(interp))
+
+ // 2. The optional interpreter argument.
+ if len(arg) > 0 {
+ newargv = append(newargv, string(arg))
+ }
+
+ // 3. The original arguments. The original argv[0] is replaced with the
+ // full script filename.
+ if len(argv) > 0 {
+ argv[0] = filename
+ } else {
+ argv = []string{filename}
+ }
+ newargv = append(newargv, argv...)
+
+ return string(interp), newargv, nil
+}
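+
+// A worked example (paths and arguments are illustrative): for a script
+// /tmp/run.sh whose first line is "#!/bin/sh -e", invoked with
+// argv = ["./run.sh", "arg1"], the result is:
+//
+//	newpath = "/bin/sh"
+//	newargv = ["/bin/sh", "-e", "/tmp/run.sh", "arg1"]
+//
+// Everything after the interpreter name is passed as the single argument
+// "-e", and the original argv[0] is replaced by the full script filename.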
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
new file mode 100644
index 000000000..986c7fb4d
--- /dev/null
+++ b/pkg/sentry/loader/loader.go
@@ -0,0 +1,315 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package loader loads an executable file into a MemoryManager.
+package loader
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "path"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LoadArgs holds specifications for an executable file to be loaded.
+type LoadArgs struct {
+ // MemoryManager is the memory manager to load the executable into.
+ MemoryManager *mm.MemoryManager
+
+ // RemainingTraversals is the maximum number of symlinks to follow to
+ // resolve Filename. This counter is passed by reference to keep it
+ // updated throughout the call stack.
+ RemainingTraversals *uint
+
+ // ResolveFinal indicates whether the final link of Filename should be
+ // resolved, if it is a symlink.
+ ResolveFinal bool
+
+ // Filename is the path for the executable.
+ Filename string
+
+ // File is an open fs.File object of the executable. If File is not
+ // nil, then File will be loaded and Filename will be ignored.
+ //
+ // The caller is responsible for checking that the user can execute this file.
+ File fsbridge.File
+
+ // Opener is used to open the executable file when 'File' is nil.
+ Opener fsbridge.Lookup
+
+ // CloseOnExec indicates that the executable (or one of its parent
+ // directories) was opened with O_CLOEXEC. If the executable is an
+ // interpreter script, this causes an ENOENT error, since the script
+ // would otherwise be inaccessible to the interpreter.
+ CloseOnExec bool
+
+ // Argv is the vector of arguments to pass to the executable.
+ Argv []string
+
+ // Envv is the vector of environment variables to pass to the
+ // executable.
+ Envv []string
+
+ // Features specifies the CPU feature set for the executable.
+ Features *cpuid.FeatureSet
+}
+
+// openPath opens args.Filename and checks that it is valid for loading.
+//
+// openPath returns an fsbridge.File for args.Filename, which is not
+// installed in the Task FDTable. The caller takes ownership of it.
+//
+// args.Filename must be a readable, executable, regular file.
+func openPath(ctx context.Context, args LoadArgs) (fsbridge.File, error) {
+ if args.Filename == "" {
+ ctx.Infof("cannot open empty name")
+ return nil, syserror.ENOENT
+ }
+
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ opts := vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ }
+ return args.Opener.OpenPath(ctx, args.Filename, opts, args.RemainingTraversals, args.ResolveFinal)
+}
+
+// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc.
+func checkIsRegularFile(ctx context.Context, file fsbridge.File, filename string) error {
+ t, err := file.Type(ctx)
+ if err != nil {
+ return err
+ }
+ if t != linux.ModeRegular {
+ ctx.Infof("%q is not a regular file: %v", filename, t)
+ return syserror.EACCES
+ }
+ return nil
+}
+
+// allocStack allocates and maps a stack in to any available part of the address space.
+func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch.Stack, error) {
+ ar, err := m.MapStack(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return &arch.Stack{a, m, ar.End}, nil
+}
+
+const (
+ // maxLoaderAttempts is the maximum number of attempts to try to load
+ // an interpreter script, to prevent loops. 6 (initial + 5 changes) is
+ // what the Linux kernel allows (fs/exec.c:search_binary_handler).
+ maxLoaderAttempts = 6
+)
+
+// loadExecutable loads an executable that is pointed to by args.File. The
+// caller is responsible for checking that the user can execute this file.
+// If args.File is nil, the path args.Filename is resolved and loaded instead
+// (in that case, the check that the user can execute the file is done here).
+// If the executable is an interpreter script rather than an ELF, the binary
+// of the corresponding interpreter will be loaded.
+//
+// It returns:
+// * loadedELF, description of the loaded binary
+// * arch.Context matching the binary arch
+// * fs.Dirent of the binary file
+// * Possibly updated args.Argv
+func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, fsbridge.File, []string, error) {
+ for i := 0; i < maxLoaderAttempts; i++ {
+ if args.File == nil {
+ var err error
+ args.File, err = openPath(ctx, args)
+ if err != nil {
+ ctx.Infof("Error opening %s: %v", args.Filename, err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ // Ensure the file is released in case the code loops or errors out.
+ defer args.File.DecRef()
+ } else {
+ if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil {
+ return loadedELF{}, nil, nil, nil, err
+ }
+ }
+
+ // Check the header. Is this an ELF or interpreter script?
+ var hdr [4]uint8
+ // N.B. We assume that reading from a regular file cannot block.
+ _, err := args.File.ReadFull(ctx, usermem.BytesIOSequence(hdr[:]), 0)
+ // Allow unexpected EOF, as a valid executable could be only three bytes
+ // (e.g., #!a).
+ if err != nil && err != io.ErrUnexpectedEOF {
+ if err == io.EOF {
+ err = syserror.ENOEXEC
+ }
+ return loadedELF{}, nil, nil, nil, err
+ }
+
+ switch {
+ case bytes.Equal(hdr[:], []byte(elfMagic)):
+ loaded, ac, err := loadELF(ctx, args)
+ if err != nil {
+ ctx.Infof("Error loading ELF: %v", err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ // An ELF is always terminal. Hold on to file.
+ args.File.IncRef()
+ return loaded, ac, args.File, args.Argv, err
+
+ case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
+ if args.CloseOnExec {
+ return loadedELF{}, nil, nil, nil, syserror.ENOENT
+ }
+ args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv)
+ if err != nil {
+ ctx.Infof("Error loading interpreter script: %v", err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ // Refresh the traversal limit for the interpreter.
+ *args.RemainingTraversals = linux.MaxSymlinkTraversals
+
+ default:
+ ctx.Infof("Unknown magic: %v", hdr)
+ return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
+ }
+ // Set to nil in case we loop on an interpreter script.
+ args.File = nil
+ }
+
+ return loadedELF{}, nil, nil, nil, syserror.ELOOP
+}
+
+// Load loads args.File into a MemoryManager. If args.File is nil, the path
+// args.Filename is resolved and loaded instead.
+//
+// If Load returns ErrSwitchFile it should be called again with the returned
+// path and argv.
+//
+// Preconditions:
+// * The Task MemoryManager is empty.
+// * Load is called on the Task goroutine.
+func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
+ // Load the executable itself.
+ loaded, ac, file, newArgv, err := loadExecutable(ctx, args)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
+ }
+ defer file.DecRef()
+
+ // Load the VDSO.
+ vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Set up the heap. brk starts at the next page after the end of the
+ // executable. Userspace can assume that the remainder of the page after
+ // loaded.end is available for its use.
+ e, ok := loaded.end.RoundUp()
+ if !ok {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC)
+ }
+ args.MemoryManager.BrkSetup(ctx, e)
+
+ // Allocate our stack.
+ stack, err := allocStack(ctx, args.MemoryManager, ac)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Push the original filename to the stack, for AT_EXECFN.
+ execfn, err := stack.Push(args.Filename)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Push 16 random bytes on the stack which AT_RANDOM will point to.
+ var b [16]byte
+ if _, err := rand.Read(b[:]); err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux())
+ }
+ random, err := stack.Push(b)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ c := auth.CredentialsFromContext(ctx)
+
+ // Add generic auxv entries.
+ auxv := append(loaded.auxv, arch.Auxv{
+ arch.AuxEntry{linux.AT_UID, usermem.Addr(c.RealKUID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ // The conditions that require AT_SECURE = 1 never arise. See
+ // kernel.Task.updateCredsForExecLocked.
+ arch.AuxEntry{linux.AT_SECURE, 0},
+ arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
+ arch.AuxEntry{linux.AT_EXECFN, execfn},
+ arch.AuxEntry{linux.AT_RANDOM, random},
+ arch.AuxEntry{linux.AT_PAGESZ, usermem.PageSize},
+ arch.AuxEntry{linux.AT_SYSINFO_EHDR, vdsoAddr},
+ }...)
+ auxv = append(auxv, extraAuxv...)
+
+ sl, err := stack.Load(newArgv, args.Envv, auxv)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ m := args.MemoryManager
+ m.SetArgvStart(sl.ArgvStart)
+ m.SetArgvEnd(sl.ArgvEnd)
+ m.SetEnvvStart(sl.EnvvStart)
+ m.SetEnvvEnd(sl.EnvvEnd)
+ m.SetAuxv(auxv)
+ m.SetExecutable(file)
+
+ symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Found rt_sigreturn.
+ addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink
+ m.SetVDSOSigReturn(addr)
+
+ ac.SetIP(uintptr(loaded.entry))
+ ac.SetStack(uintptr(stack.Bottom))
+
+ name := path.Base(args.Filename)
+ if len(name) > linux.TASK_COMM_LEN-1 {
+ name = name[:linux.TASK_COMM_LEN-1]
+ }
+
+ return loaded.os, ac, name, nil
+}
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
new file mode 100644
index 000000000..05a294fe6
--- /dev/null
+++ b/pkg/sentry/loader/vdso.go
@@ -0,0 +1,382 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "bytes"
+ "debug/elf"
+ "fmt"
+ "io"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const vdsoPrelink = 0xffffffffff700000
+
+type fileContext struct {
+ context.Context
+}
+
+func (f *fileContext) Value(key interface{}) interface{} {
+ switch key {
+ case uniqueid.CtxGlobalUniqueID:
+ return uint64(0)
+ default:
+ return f.Context.Value(key)
+ }
+}
+
+type byteFullReader struct {
+ data []byte
+}
+
+func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset >= int64(len(b.data)) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, b.data[offset:])
+ return int64(n), err
+}
+
+// validateVDSO checks that the VDSO can be loaded by loadVDSO.
+//
+// VDSOs are special (see below). Since we are going to map the VDSO directly
+// rather than using a normal loading process, we require that the PT_LOAD
+// segments have the same layout in the ELF as they expect to have in memory.
+//
+// Namely, this means that we must verify:
+// * PT_LOAD file offsets are equivalent to the memory offset from the first
+// segment.
+// * No extra zeroed space (memsz) is required.
+// * PT_LOAD segments are in order.
+// * No two PT_LOAD segments occupy parts of the same page.
+// * PT_LOAD segments don't extend beyond the end of the file.
+//
+// ctx may be nil if f does not need it.
+func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) {
+ info, err := parseHeader(ctx, f)
+ if err != nil {
+ log.Infof("Unable to parse VDSO header: %v", err)
+ return elfInfo{}, err
+ }
+
+ var first *elf.ProgHeader
+ var prev *elf.ProgHeader
+ var prevEnd usermem.Addr
+ for i, phdr := range info.phdrs {
+ if phdr.Type != elf.PT_LOAD {
+ continue
+ }
+
+ if first == nil {
+ first = &info.phdrs[i]
+ if phdr.Off != 0 {
+ log.Warningf("First PT_LOAD segment has non-zero file offset")
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ }
+
+ memoryOffset := phdr.Vaddr - first.Vaddr
+ if memoryOffset != phdr.Off {
+ log.Warningf("PT_LOAD segment memory offset %#x != file offset %#x", memoryOffset, phdr.Off)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ // memsz larger than filesz means that extra zeroed space should be
+ // provided at the end of the segment. Since we are mapping the ELF
+ // directly, we don't want to just overwrite part of the ELF with
+ // zeroes.
+ if phdr.Memsz != phdr.Filesz {
+ log.Warningf("PT_LOAD segment memsz %#x != filesz %#x", phdr.Memsz, phdr.Filesz)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ start := usermem.Addr(memoryOffset)
+ end, ok := start.AddLength(phdr.Memsz)
+ if !ok {
+ log.Warningf("PT_LOAD segment size overflows: %#x + %#x", start, end)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ if uint64(end) > size {
+ log.Warningf("PT_LOAD segment end %#x extends beyond end of file %#x", end, size)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ if prev != nil {
+ if start < prevEnd {
+ log.Warningf("PT_LOAD segments out of order")
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ // We mprotect entire pages, so each segment must be in
+ // its own page.
+ prevEndPage := prevEnd.RoundDown()
+ startPage := start.RoundDown()
+ if prevEndPage >= startPage {
+ log.Warningf("PT_LOAD segments share a page: %#x", prevEndPage)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ }
+ prev = &info.phdrs[i]
+ prevEnd = end
+ }
+
+ return info, nil
+}
+
+// VDSO describes a VDSO.
+//
+// NOTE(mpratt): to support multiple architectures or operating systems, this
+// would need to contain a VDSO for each.
+//
+// +stateify savable
+type VDSO struct {
+ // ParamPage is the VDSO parameter page. The kernel should keep this page
+ // updated with its timekeeping data so that the VDSO can read it.
+ ParamPage *mm.SpecialMappable
+
+ // vdso is the VDSO ELF itself.
+ vdso *mm.SpecialMappable
+
+ // os is the operating system targeted by the VDSO.
+ os abi.OS
+
+ // arch is the architecture targeted by the VDSO.
+ arch arch.Arch
+
+ // phdrs are the VDSO ELF phdrs.
+ phdrs []elf.ProgHeader `state:".([]elfProgHeader)"`
+}
+
+// getSymbolValueFromVDSO returns the value of the named symbol in vdso.so.
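+//
+// Note that the lookup matches by substring (strings.Contains), so a query
+// such as "gettimeofday" (a hypothetical example) would also match a symbol
+// named "__vdso_gettimeofday".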
+func getSymbolValueFromVDSO(symbol string) (uint64, error) {
+ f, err := elf.NewFile(bytes.NewReader(vdsoBin))
+ if err != nil {
+ return 0, err
+ }
+ syms, err := f.Symbols()
+ if err != nil {
+ return 0, err
+ }
+
+ for _, sym := range syms {
+ if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF {
+ if strings.Contains(sym.Name, symbol) {
+ return sym.Value, nil
+ }
+ }
+ }
+ return 0, fmt.Errorf("no %v in vdso.so", symbol)
+}
+
+// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
+// param page for updating by the kernel.
+func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
+ vdsoFile := &byteFullReader{data: vdsoBin}
+
+ // First make sure the VDSO is valid. vdsoFile does not use ctx, so a
+ // nil context can be passed.
+ info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin)))
+ if err != nil {
+ return nil, err
+ }
+
+ // Then copy it into a VDSO mapping.
+ size, ok := usermem.Addr(len(vdsoBin)).RoundUp()
+ if !ok {
+ return nil, fmt.Errorf("VDSO size overflows? %#x", len(vdsoBin))
+ }
+
+ mf := mfp.MemoryFile()
+ vdso, err := mf.Allocate(uint64(size), usage.System)
+ if err != nil {
+ return nil, fmt.Errorf("unable to allocate VDSO memory: %v", err)
+ }
+
+ ims, err := mf.MapInternal(vdso, usermem.ReadWrite)
+ if err != nil {
+ mf.DecRef(vdso)
+ return nil, fmt.Errorf("unable to map VDSO memory: %v", err)
+ }
+
+ _, err = safemem.CopySeq(ims, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(vdsoBin)))
+ if err != nil {
+ mf.DecRef(vdso)
+ return nil, fmt.Errorf("unable to copy VDSO into memory: %v", err)
+ }
+
+ // Finally, allocate a param page for this VDSO.
+ paramPage, err := mf.Allocate(usermem.PageSize, usage.System)
+ if err != nil {
+ mf.DecRef(vdso)
+ return nil, fmt.Errorf("unable to allocate VDSO param page: %v", err)
+ }
+
+ return &VDSO{
+ ParamPage: mm.NewSpecialMappable("[vvar]", mfp, paramPage),
+ // TODO(gvisor.dev/issue/157): Don't advertise the VDSO, as
+ // some applications may not be able to handle multiple [vdso]
+ // hints.
+ vdso: mm.NewSpecialMappable("", mfp, vdso),
+ os: info.os,
+ arch: info.arch,
+ phdrs: info.phdrs,
+ }, nil
+}
+
+// loadVDSO loads the VDSO into m.
+//
+// VDSOs are special.
+//
+// VDSOs are fully position independent. However, instead of loading a VDSO
+// like a normal ELF binary, mapping only the PT_LOAD segments, the Linux
+// kernel simply directly maps the entire file into process memory, with very
+// little real ELF parsing.
+//
+// NOTE(b/25323870): This means that userspace can, and unfortunately does,
+// depend on parts of the ELF that would normally not be mapped. To maintain
+// compatibility with such binaries, we load the VDSO much like Linux.
+//
+// loadVDSO takes a reference on the VDSO and parameter page FrameRegions.
+func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (usermem.Addr, error) {
+ if v.os != bin.os {
+ ctx.Warningf("Binary ELF OS %v and VDSO ELF OS %v differ", bin.os, v.os)
+ return 0, syserror.ENOEXEC
+ }
+ if v.arch != bin.arch {
+ ctx.Warningf("Binary ELF arch %v and VDSO ELF arch %v differ", bin.arch, v.arch)
+ return 0, syserror.ENOEXEC
+ }
+
+ // Reserve address space for the VDSO and its parameter page, which is
+ // mapped just before the VDSO.
+ mapSize := v.vdso.Length() + v.ParamPage.Length()
+ addr, err := m.MMap(ctx, memmap.MMapOpts{
+ Length: mapSize,
+ Private: true,
+ })
+ if err != nil {
+ ctx.Infof("Unable to reserve VDSO address space: %v", err)
+ return 0, err
+ }
+
+ // Now map the param page.
+ _, err = m.MMap(ctx, memmap.MMapOpts{
+ Length: v.ParamPage.Length(),
+ MappingIdentity: v.ParamPage,
+ Mappable: v.ParamPage,
+ Addr: addr,
+ Fixed: true,
+ Unmap: true,
+ Private: true,
+ Perms: usermem.Read,
+ MaxPerms: usermem.Read,
+ })
+ if err != nil {
+ ctx.Infof("Unable to map VDSO param page: %v", err)
+ return 0, err
+ }
+
+ // Now map the VDSO itself.
+ vdsoAddr, ok := addr.AddLength(v.ParamPage.Length())
+ if !ok {
+ panic(fmt.Sprintf("Part of mapped range overflows? %#x + %#x", addr, v.ParamPage.Length()))
+ }
+ _, err = m.MMap(ctx, memmap.MMapOpts{
+ Length: v.vdso.Length(),
+ MappingIdentity: v.vdso,
+ Mappable: v.vdso,
+ Addr: vdsoAddr,
+ Fixed: true,
+ Unmap: true,
+ Private: true,
+ Perms: usermem.Read,
+ MaxPerms: usermem.AnyAccess,
+ })
+ if err != nil {
+ ctx.Infof("Unable to map VDSO: %v", err)
+ return 0, err
+ }
+
+ vdsoEnd, ok := vdsoAddr.AddLength(v.vdso.Length())
+ if !ok {
+ panic(fmt.Sprintf("VDSO mapping overflows? %#x + %#x", vdsoAddr, v.vdso.Length()))
+ }
+
+ // Set additional protections for the individual segments.
+ var first *elf.ProgHeader
+ for i, phdr := range v.phdrs {
+ if phdr.Type != elf.PT_LOAD {
+ continue
+ }
+
+ if first == nil {
+ first = &v.phdrs[i]
+ }
+
+ memoryOffset := phdr.Vaddr - first.Vaddr
+ segAddr, ok := vdsoAddr.AddLength(memoryOffset)
+ if !ok {
+ ctx.Warningf("PT_LOAD segment address overflows: %#x + %#x", segAddr, memoryOffset)
+ return 0, syserror.ENOEXEC
+ }
+ segPage := segAddr.RoundDown()
+ segSize := usermem.Addr(phdr.Memsz)
+ segSize, ok = segSize.AddLength(segAddr.PageOffset())
+ if !ok {
+ ctx.Warningf("PT_LOAD segment memsize %#x + offset %#x overflows", phdr.Memsz, segAddr.PageOffset())
+ return 0, syserror.ENOEXEC
+ }
+ segSize, ok = segSize.RoundUp()
+ if !ok {
+ ctx.Warningf("PT_LOAD segment size overflows: %#x", phdr.Memsz+segAddr.PageOffset())
+ return 0, syserror.ENOEXEC
+ }
+ segEnd, ok := segPage.AddLength(uint64(segSize))
+ if !ok {
+ ctx.Warningf("PT_LOAD segment range overflows: %#x + %#x", segAddr, segSize)
+ return 0, syserror.ENOEXEC
+ }
+ if segEnd > vdsoEnd {
+ ctx.Warningf("PT_LOAD segment ends beyond VDSO: %#x > %#x", segEnd, vdsoEnd)
+ return 0, syserror.ENOEXEC
+ }
+
+ perms := progFlagsAsPerms(phdr.Flags)
+ if perms != usermem.Read {
+ if err := m.MProtect(segPage, uint64(segSize), perms, false); err != nil {
+ ctx.Warningf("Unable to set PT_LOAD segment protections %+v at [%#x, %#x): %v", perms, segAddr, segEnd, err)
+ return 0, syserror.ENOEXEC
+ }
+ }
+ }
+
+ return vdsoAddr, nil
+}
diff --git a/pkg/sentry/loader/vdso_state.go b/pkg/sentry/loader/vdso_state.go
new file mode 100644
index 000000000..db378e90a
--- /dev/null
+++ b/pkg/sentry/loader/vdso_state.go
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "debug/elf"
+)
+
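+// elfProgHeader mirrors elf.ProgHeader field-for-field so that the two types
+// remain directly convertible; it exists to give VDSO.phdrs a savable form.
+//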
+// +stateify savable
+type elfProgHeader struct {
+ Type elf.ProgType
+ Flags elf.ProgFlag
+ Off uint64
+ Vaddr uint64
+ Paddr uint64
+ Filesz uint64
+ Memsz uint64
+ Align uint64
+}
+
+// savePhdrs is invoked by stateify.
+func (v *VDSO) savePhdrs() []elfProgHeader {
+ s := make([]elfProgHeader, 0, len(v.phdrs))
+ for _, h := range v.phdrs {
+ s = append(s, elfProgHeader(h))
+ }
+ return s
+}
+
+// loadPhdrs is invoked by stateify.
+func (v *VDSO) loadPhdrs(s []elfProgHeader) {
+ v.phdrs = make([]elf.ProgHeader, 0, len(s))
+ for _, h := range s {
+ v.phdrs = append(v.phdrs, elf.ProgHeader(h))
+ }
+}
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
new file mode 100644
index 000000000..a98b66de1
--- /dev/null
+++ b/pkg/sentry/memmap/BUILD
@@ -0,0 +1,55 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "mappable_range",
+ out = "mappable_range.go",
+ package = "memmap",
+ prefix = "Mappable",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
+go_template_instance(
+ name = "mapping_set_impl",
+ out = "mapping_set_impl.go",
+ package = "memmap",
+ prefix = "Mapping",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "MappableRange",
+ "Value": "MappingsOfRange",
+ "Functions": "mappingSetFunctions",
+ },
+)
+
+go_library(
+ name = "memmap",
+ srcs = [
+ "mappable_range.go",
+ "mapping_set.go",
+ "mapping_set_impl.go",
+ "memmap.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/platform",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "memmap_test",
+ size = "small",
+ srcs = ["mapping_set_test.go"],
+ library = ":memmap",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go
new file mode 100644
index 000000000..d609c1ae0
--- /dev/null
+++ b/pkg/sentry/memmap/mapping_set.go
@@ -0,0 +1,253 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memmap
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// MappingSet maps offsets into a Mappable to mappings of those offsets. It is
+// used to implement Mappable.AddMapping and RemoveMapping for Mappables that
+// may need to call MappingSpace.Invalidate.
+//
+// type MappingSet <generated by go_generics>
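+//
+// For example (illustrative): if two MappingSpaces each map offsets
+// [0x0, 0x2000) of the same Mappable, the MappingSet holds a single segment
+// for [0x0, 0x2000) whose value contains both mappings, and invalidating
+// offsets [0x1000, 0x2000) invalidates the corresponding page in both spaces.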
+
+// MappingsOfRange is the value type of MappingSet, and represents the set of
+// all mappings of the corresponding MappableRange.
+//
+// Using a map offers O(1) lookups in RemoveMapping and
+// mappingSetFunctions.Merge.
+type MappingsOfRange map[MappingOfRange]struct{}
+
+// MappingOfRange represents a mapping of a MappableRange.
+//
+// +stateify savable
+type MappingOfRange struct {
+ MappingSpace MappingSpace
+ AddrRange usermem.AddrRange
+ Writable bool
+}
+
+func (r MappingOfRange) invalidate(opts InvalidateOpts) {
+ r.MappingSpace.Invalidate(r.AddrRange, opts)
+}
+
+// String implements fmt.Stringer.String.
+func (r MappingOfRange) String() string {
+ return fmt.Sprintf("%#v", r.AddrRange)
+}
+
+// mappingSetFunctions implements segment.Functions for MappingSet.
+type mappingSetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (mappingSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (mappingSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (mappingSetFunctions) ClearValue(v *MappingsOfRange) {
+ *v = MappingsOfRange{}
+}
+
+// Merge implements segment.Functions.Merge.
+//
+// Since each value is a map of MappingOfRanges, values can only be merged if
+// all MappingOfRanges in each map have an exact pair in the other map, forming
+// one contiguous region.
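+//
+// For example (illustrative): a mapping of r1 = [0x0, 0x1000) at address
+// 0x10000 and a mapping of r2 = [0x1000, 0x2000) at address 0x11000, made by
+// the same MappingSpace with the same writability, merge into a single
+// mapping of [0x0, 0x2000) at [0x10000, 0x12000).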
+func (mappingSetFunctions) Merge(r1 MappableRange, val1 MappingsOfRange, r2 MappableRange, val2 MappingsOfRange) (MappingsOfRange, bool) {
+ if len(val1) != len(val2) {
+ return nil, false
+ }
+
+ merged := make(MappingsOfRange, len(val1))
+
+ // Each MappingOfRange in val1 must have a matching region in val2, forming
+ // one contiguous region.
+ for k1 := range val1 {
+ // We expect val2 to contain a key that forms a contiguous
+ // region with k1.
+ k2 := MappingOfRange{
+ MappingSpace: k1.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k1.AddrRange.End,
+ End: k1.AddrRange.End + usermem.Addr(r2.Length()),
+ },
+ Writable: k1.Writable,
+ }
+ if _, ok := val2[k2]; !ok {
+ return nil, false
+ }
+
+ // OK. Add it to the merged map.
+ merged[MappingOfRange{
+ MappingSpace: k1.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k1.AddrRange.Start,
+ End: k2.AddrRange.End,
+ },
+ Writable: k1.Writable,
+ }] = struct{}{}
+ }
+
+ return merged, true
+}
+
+// Split implements segment.Functions.Split.
+func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uint64) (MappingsOfRange, MappingsOfRange) {
+ if split <= r.Start || split >= r.End {
+ panic(fmt.Sprintf("split is not within range %v", r))
+ }
+
+ m1 := make(MappingsOfRange, len(val))
+ m2 := make(MappingsOfRange, len(val))
+
+ // split is a value in MappableRange, we need the offset into the
+ // corresponding MappingsOfRange.
+ offset := usermem.Addr(split - r.Start)
+ for k := range val {
+ k1 := MappingOfRange{
+ MappingSpace: k.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k.AddrRange.Start,
+ End: k.AddrRange.Start + offset,
+ },
+ Writable: k.Writable,
+ }
+ m1[k1] = struct{}{}
+
+ k2 := MappingOfRange{
+ MappingSpace: k.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k.AddrRange.Start + offset,
+ End: k.AddrRange.End,
+ },
+ Writable: k.Writable,
+ }
+ m2[k2] = struct{}{}
+ }
+
+ return m1, m2
+}
+
+// subsetMapping returns the MappingOfRange that maps subsetRange, given that
+// ms maps wholeRange beginning at addr.
+//
+// For instance, suppose wholeRange = [0x0, 0x2000) and addr = 0x4000,
+// indicating that ms maps addresses [0x4000, 0x6000) to MappableRange [0x0,
+// 0x2000). Then for subsetRange = [0x1000, 0x2000), subsetMapping returns a
+// MappingOfRange for which AddrRange = [0x5000, 0x6000).
+func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr usermem.Addr, writable bool) MappingOfRange {
+ if !wholeRange.IsSupersetOf(subsetRange) {
+ panic(fmt.Sprintf("%v is not a superset of %v", wholeRange, subsetRange))
+ }
+
+ offset := subsetRange.Start - wholeRange.Start
+ start := addr + usermem.Addr(offset)
+ return MappingOfRange{
+ MappingSpace: ms,
+ AddrRange: usermem.AddrRange{
+ Start: start,
+ End: start + usermem.Addr(subsetRange.Length()),
+ },
+ Writable: writable,
+ }
+}
+
+// AddMapping adds the given mapping and returns the set of MappableRanges that
+// previously had no mappings.
+//
+// Preconditions: As for Mappable.AddMapping.
+func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange {
+ mr := MappableRange{offset, offset + uint64(ar.Length())}
+ var mapped []MappableRange
+ seg, gap := s.Find(mr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < mr.End:
+ seg = s.Isolate(seg, mr)
+ seg.Value()[subsetMapping(mr, seg.Range(), ms, ar.Start, writable)] = struct{}{}
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok() && gap.Start() < mr.End:
+ gapMR := gap.Range().Intersect(mr)
+ mapped = append(mapped, gapMR)
+ // Insert a set and continue from the above case.
+ seg, gap = s.Insert(gap, gapMR, make(MappingsOfRange)), MappingGapIterator{}
+
+ default:
+ return mapped
+ }
+ }
+}
+
+// RemoveMapping removes the given mapping and returns the set of
+// MappableRanges that now have no mappings.
+//
+// Preconditions: As for Mappable.RemoveMapping.
+func (s *MappingSet) RemoveMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange {
+ mr := MappableRange{offset, offset + uint64(ar.Length())}
+ var unmapped []MappableRange
+
+ seg := s.FindSegment(mr.Start)
+ if !seg.Ok() {
+ panic(fmt.Sprintf("MappingSet.RemoveMapping(%v): no segment containing %#x: %v", mr, mr.Start, s))
+ }
+ for seg.Ok() && seg.Start() < mr.End {
+ // Ensure this segment is limited to our range.
+ seg = s.Isolate(seg, mr)
+
+ // Remove this part of the mapping.
+ mappings := seg.Value()
+ delete(mappings, subsetMapping(mr, seg.Range(), ms, ar.Start, writable))
+
+ if len(mappings) == 0 {
+ unmapped = append(unmapped, seg.Range())
+ seg = s.Remove(seg).NextSegment()
+ } else {
+ seg = seg.NextSegment()
+ }
+ }
+ s.MergeAdjacent(mr)
+ return unmapped
+}
+
+// Invalidate calls MappingSpace.Invalidate for all mappings of offsets in mr.
+func (s *MappingSet) Invalidate(mr MappableRange, opts InvalidateOpts) {
+ for seg := s.LowerBoundSegment(mr.Start); seg.Ok() && seg.Start() < mr.End; seg = seg.NextSegment() {
+ segMR := seg.Range()
+ for m := range seg.Value() {
+ region := subsetMapping(segMR, segMR.Intersect(mr), m.MappingSpace, m.AddrRange.Start, m.Writable)
+ region.invalidate(opts)
+ }
+ }
+}
+
+// InvalidateAll calls MappingSpace.Invalidate for all mappings of s.
+func (s *MappingSet) InvalidateAll(opts InvalidateOpts) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for m := range seg.Value() {
+ m.invalidate(opts)
+ }
+ }
+}
diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go
new file mode 100644
index 000000000..d39efe38f
--- /dev/null
+++ b/pkg/sentry/memmap/mapping_set_test.go
@@ -0,0 +1,260 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memmap
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type testMappingSpace struct {
+ // Ideally we'd store the full ranges that were invalidated, rather
+ // than individual calls to Invalidate, as they are an implementation
+ // detail, but this is the simplest way for now.
+ inv []usermem.AddrRange
+}
+
+func (n *testMappingSpace) reset() {
+ n.inv = []usermem.AddrRange{}
+}
+
+func (n *testMappingSpace) Invalidate(ar usermem.AddrRange, opts InvalidateOpts) {
+ n.inv = append(n.inv, ar)
+}
+
+func TestAddRemoveMapping(t *testing.T) {
+ set := MappingSet{}
+ ms := &testMappingSpace{}
+
+ mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true)
+ if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings (usermem.AddrRanges => memmap.MappableRange):
+ // [0x10000, 0x12000) => [0x1000, 0x3000)
+ t.Log(&set)
+
+ mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ if len(mapped) != 0 {
+ t.Errorf("AddMapping: got %+v, wanted []", mapped)
+ }
+
+ // Mappings:
+ // [0x10000, 0x11000) => [0x1000, 0x2000)
+ // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000)
+ t.Log(&set)
+
+ mapped = set.AddMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true)
+ if got, want := mapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x10000, 0x11000) => [0x1000, 0x2000)
+ // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000)
+ // [0x30000, 0x31000) => [0x4000, 0x5000)
+ t.Log(&set)
+
+ mapped = set.AddMapping(ms, usermem.AddrRange{0x12000, 0x15000}, 0x3000, true)
+ if got, want := mapped, []MappableRange{{0x3000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x10000, 0x11000) => [0x1000, 0x2000)
+ // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000)
+ // [0x12000, 0x13000) => [0x3000, 0x4000)
+ // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000)
+ // [0x14000, 0x15000) => [0x5000, 0x6000)
+ t.Log(&set)
+
+ unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, true)
+ if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000)
+ // [0x12000, 0x13000) => [0x3000, 0x4000)
+ // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000)
+ // [0x14000, 0x15000) => [0x5000, 0x6000)
+ t.Log(&set)
+
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ if len(unmapped) != 0 {
+ t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
+ }
+
+ // Mappings:
+ // [0x11000, 0x13000) => [0x2000, 0x4000)
+ // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000)
+ // [0x14000, 0x15000) => [0x5000, 0x6000)
+ t.Log(&set)
+
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x15000}, 0x2000, true)
+ if got, want := unmapped, []MappableRange{{0x2000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x30000, 0x31000) => [0x4000, 0x5000)
+ t.Log(&set)
+
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true)
+ if got, want := unmapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
+ }
+}
+
+func TestInvalidateWholeMapping(t *testing.T) {
+ set := MappingSet{}
+ ms := &testMappingSpace{}
+
+ set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true)
+ // Mappings:
+ // [0x10000, 0x11000) => [0, 0x1000)
+ t.Log(&set)
+ set.Invalidate(MappableRange{0, 0x1000}, InvalidateOpts{})
+ if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("Invalidate: got %+v, wanted %+v", got, want)
+ }
+}
+
+func TestInvalidatePartialMapping(t *testing.T) {
+ set := MappingSet{}
+ ms := &testMappingSpace{}
+
+ set.AddMapping(ms, usermem.AddrRange{0x10000, 0x13000}, 0, true)
+ // Mappings:
+ // [0x10000, 0x13000) => [0, 0x3000)
+ t.Log(&set)
+ set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{})
+ if got, want := ms.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("Invalidate: got %+v, wanted %+v", got, want)
+ }
+}
+
+func TestInvalidateMultipleMappings(t *testing.T) {
+ set := MappingSet{}
+ ms := &testMappingSpace{}
+
+ set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true)
+ set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ // Mappings:
+ // [0x10000, 0x11000) => [0, 0x1000)
+ // [0x20000, 0x21000) => [0x2000, 0x3000)
+ t.Log(&set)
+ set.Invalidate(MappableRange{0, 0x3000}, InvalidateOpts{})
+ if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("Invalidate: got %+v, wanted %+v", got, want)
+ }
+}
+
+func TestInvalidateOverlappingMappings(t *testing.T) {
+ set := MappingSet{}
+ ms1 := &testMappingSpace{}
+ ms2 := &testMappingSpace{}
+
+ set.AddMapping(ms1, usermem.AddrRange{0x10000, 0x12000}, 0, true)
+ set.AddMapping(ms2, usermem.AddrRange{0x20000, 0x22000}, 0x1000, true)
+ // Mappings:
+ // ms1:[0x10000, 0x12000) => [0, 0x2000)
+ // ms2:[0x20000, 0x22000) => [0x1000, 0x3000)
+ t.Log(&set)
+ set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{})
+ if got, want := ms1.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want)
+ }
+ if got, want := ms2.inv, []usermem.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want)
+ }
+}
+
+func TestMixedWritableMappings(t *testing.T) {
+ set := MappingSet{}
+ ms := &testMappingSpace{}
+
+ mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true)
+ if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x10000, 0x12000) writable => [0x1000, 0x3000)
+ t.Log(&set)
+
+ mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x22000}, 0x2000, false)
+ if got, want := mapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x10000, 0x11000) writable => [0x1000, 0x2000)
+ // [0x11000, 0x12000) writable and [0x20000, 0x21000) readonly => [0x2000, 0x3000)
+ // [0x21000, 0x22000) readonly => [0x3000, 0x4000)
+ t.Log(&set)
+
+ // Unmap should fail because the specified address range is mapped readonly,
+ // but we asked to unmap a writable mapping.
+ unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ if len(unmapped) != 0 {
+ t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
+ }
+
+ // The readonly mapping is removed, but a writable mapping still exists in
+ // the range, so no mappable range becomes fully unmapped.
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, false)
+ if len(unmapped) != 0 {
+ t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
+ }
+
+ // Mappings:
+ // [0x10000, 0x12000) writable => [0x1000, 0x3000)
+ // [0x21000, 0x22000) readonly => [0x3000, 0x4000)
+ t.Log(&set)
+
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x12000}, 0x2000, true)
+ if got, want := unmapped, []MappableRange{{0x2000, 0x3000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x10000, 0x11000) writable => [0x1000, 0x2000)
+ // [0x21000, 0x22000) readonly => [0x3000, 0x4000)
+ t.Log(&set)
+
+ // Unmap should fail since writable bit doesn't match.
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, false)
+ if len(unmapped) != 0 {
+ t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
+ }
+
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, true)
+ if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
+ }
+
+ // Mappings:
+ // [0x21000, 0x22000) readonly => [0x3000, 0x4000)
+ t.Log(&set)
+
+ unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x21000, 0x22000}, 0x3000, false)
+ if got, want := unmapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
+ }
+}
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
new file mode 100644
index 000000000..c6db9fc8f
--- /dev/null
+++ b/pkg/sentry/memmap/memmap.go
@@ -0,0 +1,363 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memmap defines semantics for memory mappings.
+package memmap
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mappable represents a memory-mappable object, a mutable mapping from uint64
+// offsets to (platform.File, uint64 File offset) pairs.
+//
+// See mm/mm.go for Mappable's place in the lock order.
+//
+// Preconditions: For all Mappable methods, usermem.AddrRanges and
+// MappableRanges must be non-empty (Length() != 0), and usermem.Addrs and
+// Mappable offsets must be page-aligned.
+type Mappable interface {
+ // AddMapping notifies the Mappable of a mapping from addresses ar in ms to
+ // offsets [offset, offset+ar.Length()) in this Mappable.
+ //
+ // The writable flag indicates whether the backing data for a Mappable can
+ // be modified through the mapping. Effectively, this means a shared mapping
+ // where Translate may be called with at.Write == true. This is a property
+ // established at mapping creation and must remain constant throughout the
+ // lifetime of the mapping.
+ //
+ // Preconditions: offset+ar.Length() does not overflow.
+ AddMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error
+
+ // RemoveMapping notifies the Mappable of the removal of a mapping from
+ // addresses ar in ms to offsets [offset, offset+ar.Length()) in this
+ // Mappable.
+ //
+ // Preconditions: offset+ar.Length() does not overflow. The removed mapping
+ // must exist. writable must match the corresponding call to AddMapping.
+ RemoveMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool)
+
+ // CopyMapping notifies the Mappable of an attempt to copy a mapping in ms
+ // from srcAR to dstAR. For most Mappables, this is equivalent to
+ // AddMapping. Note that it is possible that srcAR.Length() != dstAR.Length(),
+ // and also that srcAR.Length() == 0.
+ //
+ // CopyMapping is only called when a mapping is copied within a given
+ // MappingSpace; it is analogous to Linux's vm_operations_struct::mremap.
+ //
+ // Preconditions: offset+srcAR.Length() and offset+dstAR.Length() do not
+ // overflow. The mapping at srcAR must exist. writable must match the
+ // corresponding call to AddMapping.
+ CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error
+
+ // Translate returns the Mappable's current mappings for at least the range
+ // of offsets specified by required, and at most the range of offsets
+ // specified by optional. at is the set of access types that may be
+ // performed using the returned Translations. If not all required offsets
+ // are translated, it returns a non-nil error explaining why.
+ //
+ // Translations are valid until invalidated by a callback to
+ // MappingSpace.Invalidate or until the caller removes its mapping of the
+ // translated range. Mappable implementations must ensure that at least one
+ // reference is held on all pages in a platform.File that may be the result
+ // of a valid Translation.
+ //
+ // Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
+ // required and optional must be page-aligned. The caller must have
+ // established a mapping for all of the queried offsets via a previous call
+ // to AddMapping. The caller is responsible for ensuring that calls to
+ // Translate synchronize with invalidation.
+ //
+ // Postconditions: See CheckTranslateResult.
+ Translate(ctx context.Context, required, optional MappableRange, at usermem.AccessType) ([]Translation, error)
+
+ // InvalidateUnsavable requests that the Mappable invalidate Translations
+ // that cannot be preserved across save/restore.
+ //
+ // Invariant: InvalidateUnsavable never races with concurrent calls to any
+ // other Mappable methods.
+ InvalidateUnsavable(ctx context.Context) error
+}
+
+// Translations are returned by Mappable.Translate.
+type Translation struct {
+ // Source is the translated range in the Mappable.
+ Source MappableRange
+
+ // File is the mapped file.
+ File platform.File
+
+ // Offset is the offset into File at which this Translation begins.
+ Offset uint64
+
+ // Perms is the set of permissions for which platform.AddressSpace.MapFile
+ // and platform.AddressSpace.MapInternal on this Translation is permitted.
+ Perms usermem.AccessType
+}
+
+// FileRange returns the platform.FileRange represented by t.
+func (t Translation) FileRange() platform.FileRange {
+ return platform.FileRange{t.Offset, t.Offset + t.Source.Length()}
+}
+
+// CheckTranslateResult returns an error if (ts, terr) does not satisfy all
+// postconditions for Mappable.Translate(required, optional, at).
+//
+// Preconditions: As for Mappable.Translate.
+func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error {
+ // Verify that the inputs to Mappable.Translate were valid.
+ if !required.WellFormed() || required.Length() <= 0 {
+ panic(fmt.Sprintf("invalid required range: %v", required))
+ }
+ if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() {
+ panic(fmt.Sprintf("unaligned required range: %v", required))
+ }
+ if !optional.IsSupersetOf(required) {
+ panic(fmt.Sprintf("optional range %v is not a superset of required range %v", optional, required))
+ }
+ if !usermem.Addr(optional.Start).IsPageAligned() || !usermem.Addr(optional.End).IsPageAligned() {
+ panic(fmt.Sprintf("unaligned optional range: %v", optional))
+ }
+
+ // The first Translation must include required.Start.
+ if len(ts) != 0 && !ts[0].Source.Contains(required.Start) {
+ return fmt.Errorf("first Translation %+v does not cover start of required range %v", ts[0], required)
+ }
+ for i, t := range ts {
+ if !t.Source.WellFormed() || t.Source.Length() <= 0 {
+ return fmt.Errorf("Translation %+v has invalid Source", t)
+ }
+ if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() {
+ return fmt.Errorf("Translation %+v has unaligned Source", t)
+ }
+ if t.File == nil {
+ return fmt.Errorf("Translation %+v has nil File", t)
+ }
+ if !usermem.Addr(t.Offset).IsPageAligned() {
+ return fmt.Errorf("Translation %+v has unaligned Offset", t)
+ }
+ // Translations must be contiguous and in increasing order of
+ // Translation.Source.
+ if i > 0 && ts[i-1].Source.End != t.Source.Start {
+ return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t)
+ }
+ // At least part of each Translation must be required.
+ if t.Source.Intersect(required).Length() == 0 {
+ return fmt.Errorf("Translation %+v lies entirely outside required range %v", t, required)
+ }
+ // Translations must be constrained to the optional range.
+ if !optional.IsSupersetOf(t.Source) {
+ return fmt.Errorf("Translation %+v lies outside optional range %v", t, optional)
+ }
+ // Each Translation must permit a superset of requested accesses.
+ if !t.Perms.SupersetOf(at) {
+ return fmt.Errorf("Translation %+v does not permit all requested access types %v", t, at)
+ }
+ }
+ // If the set of Translations does not cover the entire required range,
+ // Translate must return a non-nil error explaining why.
+ if terr == nil {
+ if len(ts) == 0 {
+ return fmt.Errorf("no Translations and no error")
+ }
+ if t := ts[len(ts)-1]; !t.Source.Contains(required.End - 1) {
+ return fmt.Errorf("last Translation %+v does not reach end of required range %v, but Translate returned no error", t, required)
+ }
+ }
+ return nil
+}
+
+// BusError may be returned by implementations of Mappable.Translate for errors
+// that should result in SIGBUS delivery if they cause application page fault
+// handling to fail.
+type BusError struct {
+ // Err is the original error.
+ Err error
+}
+
+// Error implements error.Error.
+func (b *BusError) Error() string {
+ return fmt.Sprintf("BusError: %v", b.Err.Error())
+}
+
+// MappableRange represents a range of uint64 offsets into a Mappable.
+//
+// type MappableRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (mr MappableRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", mr.Start, mr.End)
+}
+
+// MappingSpace represents a mutable mapping from usermem.Addrs to (Mappable,
+// uint64 offset) pairs.
+type MappingSpace interface {
+ // Invalidate is called to notify the MappingSpace that values returned by
+ // previous calls to Mappable.Translate for offsets mapped by addresses in
+ // ar are no longer valid.
+ //
+ // Invalidate must not take any locks preceding mm.MemoryManager.activeMu
+ // in the lock order.
+ //
+ // Preconditions: ar.Length() != 0. ar must be page-aligned.
+ Invalidate(ar usermem.AddrRange, opts InvalidateOpts)
+}
+
+// InvalidateOpts holds options to MappingSpace.Invalidate.
+type InvalidateOpts struct {
+ // InvalidatePrivate is true if private pages in the invalidated region
+ // should also be discarded, causing their data to be lost.
+ InvalidatePrivate bool
+}
+
+// MappingIdentity controls the lifetime of a Mappable, and provides
+// information about the Mappable for /proc/[pid]/maps. It is distinct from
+// Mappable because all Mappables that are coherent must compare equal to
+// support the implementation of shared futexes, but different
+// MappingIdentities may represent the same Mappable, in the same way that
+// multiple fs.Files may represent the same fs.Inode. (This similarity is not
+// coincidental; fs.File implements MappingIdentity, and some
+// fs.InodeOperations implement Mappable.)
+type MappingIdentity interface {
+ // IncRef increments the MappingIdentity's reference count.
+ IncRef()
+
+ // DecRef decrements the MappingIdentity's reference count.
+ DecRef()
+
+ // MappedName returns the application-visible name shown in
+ // /proc/[pid]/maps.
+ MappedName(ctx context.Context) string
+
+ // DeviceID returns the device number shown in /proc/[pid]/maps.
+ DeviceID() uint64
+
+ // InodeID returns the inode number shown in /proc/[pid]/maps.
+ InodeID() uint64
+
+ // Msync has the same semantics as fs.FileOperations.Fsync(ctx,
+ // int64(mr.Start), int64(mr.End-1), fs.SyncData).
+ // (fs.FileOperations.Fsync() takes an inclusive end, but mr.End is
+ // exclusive, hence mr.End-1.) It is defined rather than Fsync so that
+ // implementors don't need to depend on the fs package for fs.SyncType.
+ Msync(ctx context.Context, mr MappableRange) error
+}
+
+// MLockMode specifies the memory locking behavior of a memory mapping.
+type MLockMode int
+
+// Note that the ordering of MLockModes is significant; see
+// mm.MemoryManager.defMLockMode.
+const (
+ // MLockNone specifies that a mapping has no memory locking behavior.
+ //
+ // This must be the zero value for MLockMode.
+ MLockNone MLockMode = iota
+
+ // MLockEager specifies that a mapping is memory-locked, as by mlock() or
+ // similar. Pages in the mapping should be made, and kept, resident in
+ // physical memory as soon as possible.
+ //
+ // As of this writing, MLockEager does not cause memory-locking to be
+ // requested from the host; it only affects the sentry's memory management
+ // behavior.
+ //
+ // MLockEager is analogous to Linux's VM_LOCKED.
+ MLockEager
+
+ // MLockLazy specifies that a mapping is memory-locked, as by mlock() or
+ // similar. Pages in the mapping should be kept resident in physical memory
+ // once they have been made resident due to e.g. a page fault.
+ //
+ // As of this writing, MLockLazy does not cause memory-locking to be
+ // requested from the host; in fact, it has virtually no effect, except for
+ // interactions between mlocked pages and other syscalls.
+ //
+ // MLockLazy is analogous to Linux's VM_LOCKED | VM_LOCKONFAULT.
+ MLockLazy
+)
+
+// MMapOpts specifies a request to create a memory mapping.
+type MMapOpts struct {
+ // Length is the length of the mapping.
+ Length uint64
+
+ // MappingIdentity controls the lifetime of Mappable, and provides
+ // properties of the mapping shown in /proc/[pid]/maps. If MMapOpts is used
+ // to successfully create a memory mapping, a reference is taken on
+ // MappingIdentity.
+ MappingIdentity MappingIdentity
+
+ // Mappable is the Mappable to be mapped. If Mappable is nil, the mapping
+ // is anonymous. If Mappable is not nil, it must remain valid as long as a
+ // reference is held on MappingIdentity.
+ Mappable Mappable
+
+ // Offset is the offset into Mappable to map. If Mappable is nil, Offset is
+ // ignored.
+ Offset uint64
+
+ // Addr is the suggested address for the mapping.
+ Addr usermem.Addr
+
+ // Fixed specifies whether this is a fixed mapping (it must be located at
+ // Addr).
+ Fixed bool
+
+ // Unmap specifies whether existing mappings in the range being mapped may
+ // be replaced. If Unmap is true, Fixed must be true.
+ Unmap bool
+
+ // If Map32Bit is true, all addresses in the created mapping must fit in a
+ // 32-bit integer. (Note that the "end address" of the mapping, i.e. the
+ // address of the first byte *after* the mapping, need not fit in a 32-bit
+ // integer.) Map32Bit is ignored if Fixed is true.
+ Map32Bit bool
+
+ // Perms is the set of permissions to be applied to this mapping.
+ Perms usermem.AccessType
+
+ // MaxPerms limits the set of permissions that may ever apply to this
+ // mapping. If Mappable is not nil, all memmap.Translations returned by
+ // Mappable.Translate must support all accesses in MaxPerms.
+ //
+ // Preconditions: MaxPerms should be an effective AccessType, as
+ // access cannot be limited beyond effective AccessTypes.
+ MaxPerms usermem.AccessType
+
+ // Private is true if writes to the mapping should be propagated to a copy
+ // that is exclusive to the MemoryManager.
+ Private bool
+
+ // GrowsDown is true if the mapping should be automatically expanded
+ // downward on guard page faults.
+ GrowsDown bool
+
+ // Precommit is true if the platform should eagerly commit resources to the
+ // mapping (see platform.AddressSpace.MapFile).
+ Precommit bool
+
+ // MLockMode specifies the memory locking behavior of the mapping.
+ MLockMode MLockMode
+
+ // Hint is the name used for the mapping in /proc/[pid]/maps. If Hint is
+ // empty, MappingIdentity.MappedName() will be used instead.
+ //
+ // TODO(jamieliu): Replace entirely with MappingIdentity?
+ Hint string
+}
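+
+// For illustration only (the length and permissions below are hypothetical),
+// an anonymous, private, read/write mapping might be requested from a
+// MemoryManager m as:
+//
+//	addr, err := m.MMap(ctx, memmap.MMapOpts{
+//		Length:   0x10000,
+//		Private:  true,
+//		Perms:    usermem.ReadWrite,
+//		MaxPerms: usermem.AnyAccess,
+//	})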
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
new file mode 100644
index 000000000..a036ce53c
--- /dev/null
+++ b/pkg/sentry/mm/BUILD
@@ -0,0 +1,142 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "file_refcount_set",
+ out = "file_refcount_set.go",
+ imports = {
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "mm",
+ prefix = "fileRefcount",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "platform.FileRange",
+ "Value": "int32",
+ "Functions": "fileRefcountSetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "vma_set",
+ out = "vma_set.go",
+ consts = {
+ "minDegree": "8",
+ "trackGaps": "1",
+ },
+ imports = {
+ "usermem": "gvisor.dev/gvisor/pkg/usermem",
+ },
+ package = "mm",
+ prefix = "vma",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "usermem.Addr",
+ "Range": "usermem.AddrRange",
+ "Value": "vma",
+ "Functions": "vmaSetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "pma_set",
+ out = "pma_set.go",
+ consts = {
+ "minDegree": "8",
+ },
+ imports = {
+ "usermem": "gvisor.dev/gvisor/pkg/usermem",
+ },
+ package = "mm",
+ prefix = "pma",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "usermem.Addr",
+ "Range": "usermem.AddrRange",
+ "Value": "pma",
+ "Functions": "pmaSetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "io_list",
+ out = "io_list.go",
+ package = "mm",
+ prefix = "io",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*ioResult",
+ "Linker": "*ioResult",
+ },
+)
+
+go_library(
+ name = "mm",
+ srcs = [
+ "address_space.go",
+ "aio_context.go",
+ "aio_context_state.go",
+ "debug.go",
+ "file_refcount_set.go",
+ "io.go",
+ "io_list.go",
+ "lifecycle.go",
+ "metadata.go",
+ "mm.go",
+ "pma.go",
+ "pma_set.go",
+ "procfs.go",
+ "save_restore.go",
+ "shm.go",
+ "special_mappable.go",
+ "syscalls.go",
+ "vma.go",
+ "vma_set.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/atomicbitops",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safecopy",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/proc/seqfile",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/futex",
+ "//pkg/sentry/kernel/shm",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/usage",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip/buffer",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "mm_test",
+ size = "small",
+ srcs = ["mm_test.go"],
+ library = ":mm",
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/mm/README.md b/pkg/sentry/mm/README.md
new file mode 100644
index 000000000..f4d43d927
--- /dev/null
+++ b/pkg/sentry/mm/README.md
@@ -0,0 +1,280 @@
+This package provides an emulation of Linux semantics for application virtual
+memory mappings.
+
+For completeness, this document also describes aspects of the memory management
+subsystem defined outside this package.
+
+# Background
+
+We begin by describing semantics for virtual memory in Linux.
+
+A virtual address space is defined as a collection of mappings from virtual
+addresses to physical memory. However, userspace applications do not configure
+mappings to physical memory directly. Instead, applications configure memory
+mappings from virtual addresses to offsets into a file using the `mmap` system
+call.[^mmap-anon] For example, a call to:
+
+ mmap(
+ /* addr = */ 0x400000,
+ /* length = */ 0x1000,
+ PROT_READ | PROT_WRITE,
+ MAP_SHARED,
+ /* fd = */ 3,
+ /* offset = */ 0);
+
+creates a mapping of length 0x1000 bytes, starting at virtual address (VA)
+0x400000, to offset 0 in the file represented by file descriptor (FD) 3. Within
+the Linux kernel, virtual memory mappings are represented by *virtual memory
+areas* (VMAs). Supposing that FD 3 represents file /tmp/foo, the state of the
+virtual memory subsystem after the `mmap` call may be depicted as:
+
+ VMA: VA:0x400000 -> /tmp/foo:0x0
+
+Establishing a virtual memory area does not necessarily establish a mapping to a
+physical address, because Linux has not necessarily provisioned physical memory
+to store the file's contents. Thus, if the application attempts to read the
+contents of VA 0x400000, it may incur a *page fault*, a CPU exception that
+forces the kernel to create such a mapping to service the read.
+
+For a file, doing so consists of several logical phases:
+
+1. The kernel allocates physical memory to store the contents of the required
+ part of the file, and copies file contents to the allocated memory.
+ Supposing that the kernel chooses the physical memory at physical address
+ (PA) 0x2fb000, the resulting state of the system is:
+
+ VMA: VA:0x400000 -> /tmp/foo:0x0
+ Filemap: /tmp/foo:0x0 -> PA:0x2fb000
+
+ (In Linux the state of the mapping from file offset to physical memory is
+ stored in `struct address_space`, but to avoid confusion with other notions
+ of address space we will refer to this system as filemap, named after Linux
+ kernel source file `mm/filemap.c`.)
+
+2. The kernel stores the effective mapping from virtual to physical address in
+ a *page table entry* (PTE) in the application's *page tables*, which are
+ used by the CPU's virtual memory hardware to perform address translation.
+ The resulting state of the system is:
+
+ VMA: VA:0x400000 -> /tmp/foo:0x0
+ Filemap: /tmp/foo:0x0 -> PA:0x2fb000
+ PTE: VA:0x400000 -----------------> PA:0x2fb000
+
+ The PTE is required for the application to actually use the contents of the
+ mapped file as virtual memory. However, the PTE is derived from the VMA and
+ filemap state, both of which are independently mutable, such that mutations
+ to either will affect the PTE. For example:
+
+ - The application may remove the VMA using the `munmap` system call. This
+ breaks the mapping from VA:0x400000 to /tmp/foo:0x0, and consequently
+ the mapping from VA:0x400000 to PA:0x2fb000. However, it does not
+ necessarily break the mapping from /tmp/foo:0x0 to PA:0x2fb000, so a
+ future mapping of the same file offset may reuse this physical memory.
+
+ - The application may invalidate the file's contents by passing a length
+ of 0 to the `ftruncate` system call. This breaks the mapping from
+ /tmp/foo:0x0 to PA:0x2fb000, and consequently the mapping from
+ VA:0x400000 to PA:0x2fb000. However, it does not break the mapping from
+ VA:0x400000 to /tmp/foo:0x0, so future changes to the file's contents
+ may again be made visible at VA:0x400000 after another page fault
+ results in the allocation of a new physical address.
+
+ Note that, in order to correctly break the mapping from VA:0x400000 to
+ PA:0x2fb000 in the latter case, filemap must also store a *reverse mapping*
+ from /tmp/foo:0x0 to VA:0x400000 so that it can locate and remove the PTE.
+
+[^mmap-anon]: Memory mappings to non-files are discussed in later sections.
+
+## Private Mappings
+
+The preceding example considered VMAs created using the `MAP_SHARED` flag, which
+means that PTEs derived from the mapping should always use physical memory that
+represents the current state of the mapped file.[^mmap-dev-zero] Applications
+can alternatively pass the `MAP_PRIVATE` flag to create a *private mapping*.
+Private mappings are *copy-on-write*.
+
+Suppose that the application instead created a private mapping in the previous
+example. In Linux, the state of the system after a read page fault would be:
+
+ VMA: VA:0x400000 -> /tmp/foo:0x0 (private)
+ Filemap: /tmp/foo:0x0 -> PA:0x2fb000
+ PTE: VA:0x400000 -----------------> PA:0x2fb000 (read-only)
+
+Now suppose the application attempts to write to VA:0x400000. For a shared
+mapping, the write would be propagated to PA:0x2fb000, and the kernel would be
+responsible for ensuring that the write is later propagated to the mapped file.
+For a private mapping, the write incurs another page fault since the PTE is
+marked read-only. In response, the kernel allocates physical memory to store the
+mapping's *private copy* of the file's contents, copies file contents to the
+allocated memory, and changes the PTE to map to the private copy. Supposing that
+the kernel chooses the physical memory at physical address (PA) 0x5ea000, the
+resulting state of the system is:
+
+ VMA: VA:0x400000 -> /tmp/foo:0x0 (private)
+ Filemap: /tmp/foo:0x0 -> PA:0x2fb000
+ PTE: VA:0x400000 -----------------> PA:0x5ea000
+
+Note that the filemap mapping from /tmp/foo:0x0 to PA:0x2fb000 may still exist,
+but is now irrelevant to this mapping.
+
+[^mmap-dev-zero]: Modulo files with special mmap semantics such as `/dev/zero`.
+
+## Anonymous Mappings
+
+Instead of passing a file to the `mmap` system call, applications can instead
+request an *anonymous* mapping by passing the `MAP_ANONYMOUS` flag.
+Semantically, an anonymous mapping is essentially a mapping to an ephemeral file
+initially filled with zero bytes. Practically speaking, this is how shared
+anonymous mappings are implemented, but private anonymous mappings do not result
+in the creation of an ephemeral file; since there would be no way to modify the
+contents of the underlying file through a private mapping, all private anonymous
+mappings use a single shared page filled with zero bytes until copy-on-write
+occurs.
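+
+For example, a private anonymous mapping may be requested as follows (an
+illustrative sketch using the standard Linux mmap flags):
+
+ mmap(
+ /* addr = */ 0,
+ /* length = */ 0x1000,
+ PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS,
+ /* fd = */ -1,
+ /* offset = */ 0);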
+
+# Virtual Memory in the Sentry
+
+The sentry implements application virtual memory atop a host kernel, introducing
+an additional level of indirection to the above.
+
+Consider the same scenario as in the previous section. Since the sentry handles
+application system calls, the effect of an application `mmap` system call is to
+create a VMA in the sentry (as opposed to the host kernel):
+
+ Sentry VMA: VA:0x400000 -> /tmp/foo:0x0
+
+When the application first incurs a page fault on this address, the host kernel
+delivers information about the page fault to the sentry in a platform-dependent
+manner, and the sentry handles the fault:
+
+1. The sentry allocates memory to store the contents of the required part of
+ the file, and copies file contents to the allocated memory. However, since
+ the sentry is implemented atop a host kernel, it does not configure mappings
+ to physical memory directly. Instead, mappable "memory" in the sentry is
+ represented by a host file descriptor and offset, since (as noted in
+ "Background") this is the memory mapping primitive provided by the host
+ kernel. In general, memory is allocated from a temporary host file using the
+ [`pgalloc`][pgalloc] package. Supposing that the sentry allocates offset 0x3000 from
+ host file "memory-file", the resulting state is:
+
+ Sentry VMA: VA:0x400000 -> /tmp/foo:0x0
+ Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000
+
+2. The sentry stores the effective mapping from virtual address to host file in
+ a host VMA by invoking the `mmap` system call:
+
+ Sentry VMA: VA:0x400000 -> /tmp/foo:0x0
+ Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000
+ Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000
+
+3. The sentry returns control to the application, which immediately incurs the
+ page fault again.[^mmap-populate] However, since a host VMA now exists for
+ the faulting virtual address, the host kernel now handles the page fault as
+ described in "Background":
+
+ Sentry VMA: VA:0x400000 -> /tmp/foo:0x0
+ Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000
+ Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000
+ Host filemap: host:memory-file:0x3000 -> PA:0x2fb000
+ Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000
+
+Thus, from an implementation standpoint, host VMAs serve the same purpose in the
+sentry that PTEs do in Linux. As in Linux, sentry VMA and filemap state is
+independently mutable, and the desired state of host VMAs is derived from that
+state.
+
+[^mmap-populate]: The sentry could force the host kernel to establish PTEs when
+ it creates the host VMA by passing the `MAP_POPULATE` flag to
+ the `mmap` system call, but usually does not. This is because,
+ to reduce the number of page faults that require handling by
+ the sentry and (correspondingly) the number of host `mmap`
+ system calls, the sentry usually creates host VMAs that are
+ much larger than the single faulting page.
+
+## Private Mappings
+
+The sentry implements private mappings consistently with Linux. Before
+copy-on-write, the private mapping example given in the Background results in:
+
+ Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private)
+ Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000
+ Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 (read-only)
+ Host filemap: host:memory-file:0x3000 -> PA:0x2fb000
+ Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 (read-only)
+
+When the application attempts to write to this address, the host kernel delivers
+information about the resulting page fault to the sentry. Analogous to Linux,
+the sentry allocates memory to store the mapping's private copy of the file's
+contents, copies file contents to the allocated memory, and changes the host VMA
+to map to the private copy. Supposing that the sentry chooses the offset 0x4000
+in host file `memory-file` to store the private copy, the state of the system
+after copy-on-write is:
+
+ Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private)
+ Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000
+ Host VMA: VA:0x400000 -----------------> host:memory-file:0x4000
+ Host filemap: host:memory-file:0x4000 -> PA:0x5ea000
+ Host PTE: VA:0x400000 --------------------------------------------> PA:0x5ea000
+
+However, this highlights an important difference between Linux and the sentry.
+In Linux, page tables are concrete (architecture-dependent) data structures
+owned by the kernel. Conversely, the sentry has the ability to create and
+destroy host VMAs using host system calls, but it does not have direct access to
+their state. Thus, as written, if the application invokes the `munmap` system
+call to remove the sentry VMA, it is non-trivial for the sentry to determine
+that it should deallocate `host:memory-file:0x4000`. This implies that the
+sentry must retain information about the host VMAs that it has created.
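+
+A minimal sketch of that bookkeeping, continuing the illustrative (non-sentry)
+types from the previous section, records each host VMA alongside whether its
+backing memory-file range was privately allocated, so that `munmap` can free
+private copies:
+
+    // hostVMARecord remembers a host VMA created by the sentry.
+    type hostVMARecord struct {
+        va, length uint64 // virtual address range [va, va+length)
+        hostOff    uint64 // backing offset in host:memory-file
+        private    bool   // hostOff was allocated for a private (COW) copy
+    }
+
+    // onMUnmap drops records wholly inside [va, va+length), freeing the
+    // memory-file ranges backing private copies via freeRange.
+    func onMUnmap(records []hostVMARecord, va, length uint64, freeRange func(off, n uint64)) []hostVMARecord {
+        kept := records[:0]
+        for _, r := range records {
+            if r.va >= va && r.va+r.length <= va+length {
+                if r.private {
+                    freeRange(r.hostOff, r.length)
+                }
+                continue
+            }
+            kept = append(kept, r)
+        }
+        return kept
+    }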
+
+## Anonymous Mappings
+
+The sentry implements anonymous mappings consistently with Linux, except that
+there is no shared zero page.
+
+# Implementation Constructs
+
+In Linux:
+
+- A virtual address space is represented by `struct mm_struct`.
+
+- VMAs are represented by `struct vm_area_struct`, stored in `struct
+ mm_struct::mmap`.
+
+- Mappings from file offsets to physical memory are stored in `struct
+ address_space`.
+
+- Reverse mappings from file offsets to virtual mappings are stored in `struct
+ address_space::i_mmap`.
+
+- Physical memory pages are represented by a pointer to `struct page` or an
+ index called a *page frame number* (PFN), represented by `pfn_t`.
+
+- PTEs are represented by architecture-dependent type `pte_t`, stored in a
+ table hierarchy rooted at `struct mm_struct::pgd`.
+
+In the sentry:
+
+- A virtual address space is represented by type [`mm.MemoryManager`][mm].
+
+- Sentry VMAs are represented by type [`mm.vma`][mm], stored in
+ `mm.MemoryManager.vmas`.
+
+- Mappings from sentry file offsets to host file offsets are abstracted
+ through interface method [`memmap.Mappable.Translate`][memmap].
+
+- Reverse mappings from sentry file offsets to virtual mappings are abstracted
+ through interface methods
+ [`memmap.Mappable.AddMapping` and `memmap.Mappable.RemoveMapping`][memmap].
+
+- Host files that may be mapped into host VMAs are represented by type
+ [`platform.File`][platform].
+
+- Host VMAs are represented in the sentry by type [`mm.pma`][mm] ("platform
+ mapping area"), stored in `mm.MemoryManager.pmas`.
+
+- Creation and destruction of host VMAs is abstracted through interface
+ methods
+ [`platform.AddressSpace.MapFile` and `platform.AddressSpace.Unmap`][platform].
+
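+As a concrete illustration of `memmap.Mappable.Translate`, a Mappable backed
+by a single range `fr` of the sentry memory file `mf` might translate requests
+as follows. This sketch is adapted from `aioMappable.Translate` in
+`pkg/sentry/mm/aio_context.go`, which appears later in this change;
+`fixedMappable` is an invented name:
+
+    // Translate maps the requested mappable range to offsets in the sentry
+    // memory file mf, failing with a bus error past the end of fr.
+    func (m *fixedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+        var err error
+        if required.End > m.fr.Length() {
+            err = &memmap.BusError{syserror.EFAULT}
+        }
+        if source := optional.Intersect(memmap.MappableRange{0, m.fr.Length()}); source.Length() != 0 {
+            return []memmap.Translation{{
+                Source: source,
+                File:   m.mf,
+                Offset: m.fr.Start + source.Start,
+                Perms:  usermem.AnyAccess,
+            }}, err
+        }
+        return nil, err
+    }
+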
+[memmap]: https://github.com/google/gvisor/blob/master/pkg/sentry/memmap/memmap.go
+[mm]: https://github.com/google/gvisor/blob/master/pkg/sentry/mm/mm.go
+[pgalloc]: https://github.com/google/gvisor/blob/master/pkg/sentry/pgalloc/pgalloc.go
+[platform]: https://github.com/google/gvisor/blob/master/pkg/sentry/platform/platform.go
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
new file mode 100644
index 000000000..5c667117c
--- /dev/null
+++ b/pkg/sentry/mm/address_space.go
@@ -0,0 +1,236 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// AddressSpace returns the platform.AddressSpace bound to mm.
+//
+// Preconditions: The caller must have called mm.Activate().
+func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
+ if atomic.LoadInt32(&mm.active) == 0 {
+ panic("trying to use inactive address space?")
+ }
+ return mm.as
+}
+
+// Activate ensures this MemoryManager has a platform.AddressSpace.
+//
+// The caller must not hold any locks when calling Activate.
+//
+// When this MemoryManager is no longer needed by a task, it should call
+// Deactivate to release the reference.
+func (mm *MemoryManager) Activate(ctx context.Context) error {
+ // Fast path: the MemoryManager already has an active
+ // platform.AddressSpace, and we just need to indicate that we need it too.
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 0 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active+1) {
+ return nil
+ }
+ }
+
+ for {
+ // Slow path: may need to synchronize with other goroutines changing
+ // mm.active to or from zero.
+ mm.activeMu.Lock()
+ // Inline Unlock instead of using a defer for performance since this
+ // method is commonly in the hot-path.
+
+ // Check if we raced with another goroutine performing activation.
+ if atomic.LoadInt32(&mm.active) > 0 {
+ // This can't race; Deactivate can't decrease mm.active from 1 to 0
+ // without holding activeMu.
+ atomic.AddInt32(&mm.active, 1)
+ mm.activeMu.Unlock()
+ return nil
+ }
+
+		// Do we still have an address space? If so, then we never unmapped
+		// it. This can only be the case if
+		// !mm.p.CooperativelySchedulesAddressSpace().
+ if mm.as != nil {
+ atomic.StoreInt32(&mm.active, 1)
+ mm.activeMu.Unlock()
+ return nil
+ }
+
+		// Get a new address space. If requested, we force unmapping by
+		// passing nil to NewAddressSpace (the nil interface value, not a
+		// typed nil).
+ mappingsID := (interface{})(mm)
+ if mm.unmapAllOnActivate {
+ mappingsID = nil
+ }
+ as, c, err := mm.p.NewAddressSpace(mappingsID)
+ if err != nil {
+ mm.activeMu.Unlock()
+ return err
+ }
+ if as == nil {
+ // AddressSpace is unavailable, we must wait.
+ //
+ // activeMu must not be held while waiting, as the user of the address
+ // space we are waiting on may attempt to take activeMu.
+ mm.activeMu.Unlock()
+
+ sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation
+ if sleep {
+				// Mark this task as sleeping while waiting for the address
+				// space, so that the watchdog does not report it as a stuck
+				// task.
+ ctx.UninterruptibleSleepStart(false)
+ }
+ <-c
+ if sleep {
+ ctx.UninterruptibleSleepFinish(false)
+ }
+ continue
+ }
+
+		// We could restore all mappings at this point, but instead we let
+		// them fault back in on demand.
+ mm.as = as
+
+ // Unmapping is done, if necessary.
+ mm.unmapAllOnActivate = false
+
+ // Now that m.as has been assigned, we can set m.active to a non-zero value
+ // to enable the fast path.
+ atomic.StoreInt32(&mm.active, 1)
+
+ mm.activeMu.Unlock()
+ return nil
+ }
+}
+
+// Deactivate releases a reference to the MemoryManager.
+func (mm *MemoryManager) Deactivate() {
+ // Fast path: this is not the last goroutine to deactivate the
+ // MemoryManager.
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 1 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active-1) {
+ return
+ }
+ }
+
+ mm.activeMu.Lock()
+	// As in Activate, inline Unlock instead of using a defer, since this
+	// method is commonly in the hot path.
+
+ // Still active?
+ if atomic.AddInt32(&mm.active, -1) > 0 {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Can we hold on to the address space?
+ if !mm.p.CooperativelySchedulesAddressSpace() {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Release the address space.
+ mm.as.Release()
+
+ // Lost it.
+ mm.as = nil
+ mm.activeMu.Unlock()
+}
+
+// mapASLocked maps addresses in ar into mm.as. If precommit is true, mappings
+// for all addresses in ar should be precommitted.
+//
+// Preconditions: mm.activeMu must be locked. mm.as != nil. ar.Length() != 0.
+// ar must be page-aligned. pseg == mm.pmas.LowerBoundSegment(ar.Start).
+func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, precommit bool) error {
+ // By default, map entire pmas at a time, under the assumption that there
+ // is no cost to mapping more of a pma than necessary.
+ mapAR := usermem.AddrRange{0, ^usermem.Addr(usermem.PageSize - 1)}
+ if precommit {
+ // When explicitly precommitting, only map ar, since overmapping may
+ // incur unexpected resource usage.
+ mapAR = ar
+ } else if mapUnit := mm.p.MapUnit(); mapUnit != 0 {
+ // Limit the range we map to ar, aligned to mapUnit.
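+		// For example (illustrative values): with a 2 MiB MapUnit and
+		// ar = [0x401000, 0x403000), mapAR becomes [0x400000, 0x600000).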
+ mapMask := usermem.Addr(mapUnit - 1)
+ mapAR.Start = ar.Start &^ mapMask
+ // If rounding ar.End up overflows, just keep the existing mapAR.End.
+ if end := (ar.End + mapMask) &^ mapMask; end >= ar.End {
+ mapAR.End = end
+ }
+ }
+ if checkInvariants {
+ if !mapAR.IsSupersetOf(ar) {
+ panic(fmt.Sprintf("mapAR %#v is not a superset of ar %#v", mapAR, ar))
+ }
+ }
+
+ // Since this checks ar.End and not mapAR.End, we will never map a pma that
+ // is not required.
+ for pseg.Ok() && pseg.Start() < ar.End {
+ pma := pseg.ValuePtr()
+ pmaAR := pseg.Range()
+ pmaMapAR := pmaAR.Intersect(mapAR)
+ perms := pma.effectivePerms
+ if pma.needCOW {
+ perms.Write = false
+ }
+ if perms.Any() { // MapFile precondition
+ if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
+ return err
+ }
+ }
+ pseg = pseg.NextSegment()
+ }
+ return nil
+}
+
+// unmapASLocked removes all AddressSpace mappings for addresses in ar.
+//
+// Preconditions: mm.activeMu must be locked.
+func (mm *MemoryManager) unmapASLocked(ar usermem.AddrRange) {
+ if mm.as == nil {
+ // No AddressSpace? Force all mappings to be unmapped on the next
+ // Activate.
+ mm.unmapAllOnActivate = true
+ return
+ }
+
+ // unmapASLocked doesn't require vmas or pmas to exist for ar, so it can be
+ // passed ranges that include addresses that can't be mapped by the
+ // application.
+ ar = ar.Intersect(mm.applicationAddrRange())
+
+ // Note that this AddressSpace may or may not be active. If the
+ // platform does not require cooperative sharing of AddressSpaces, they
+ // are retained between Deactivate/Activate calls. Despite not being
+ // active, it is still valid to perform operations on these address
+ // spaces.
+ mm.as.Unmap(ar.Start, uint64(ar.Length()))
+}
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
new file mode 100644
index 000000000..379148903
--- /dev/null
+++ b/pkg/sentry/mm/aio_context.go
@@ -0,0 +1,429 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// aioManager creates and manages asynchronous I/O contexts.
+//
+// +stateify savable
+type aioManager struct {
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // aioContexts is the set of asynchronous I/O contexts.
+ contexts map[uint64]*AIOContext
+}
+
+func (a *aioManager) destroy() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ for _, ctx := range a.contexts {
+ ctx.destroy()
+ }
+}
+
+// newAIOContext creates a new context for asynchronous I/O.
+//
+// Returns false if 'id' is currently in use.
+func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if _, ok := a.contexts[id]; ok {
+ return false
+ }
+
+ a.contexts[id] = &AIOContext{
+ requestReady: make(chan struct{}, 1),
+ maxOutstanding: events,
+ }
+ return true
+}
+
+// destroyAIOContext destroys an asynchronous I/O context. It doesn't wait
+// for pending requests to complete. Returns the destroyed AIOContext so it can
+// be drained.
+//
+// Nil is returned if the context does not exist.
+func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ctx, ok := a.contexts[id]
+ if !ok {
+ return nil
+ }
+ delete(a.contexts, id)
+ ctx.destroy()
+ return ctx
+}
+
+// lookupAIOContext looks up the given context.
+//
+// Returns false if the context does not exist.
+func (a *aioManager) lookupAIOContext(id uint64) (*AIOContext, bool) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ctx, ok := a.contexts[id]
+ return ctx, ok
+}
+
+// ioResult is a completed I/O operation.
+//
+// +stateify savable
+type ioResult struct {
+ data interface{}
+ ioEntry
+}
+
+// AIOContext is a single asynchronous I/O context.
+//
+// +stateify savable
+type AIOContext struct {
+ // requestReady is the notification channel used for all requests.
+ requestReady chan struct{} `state:"nosave"`
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // results is the set of completed requests.
+ results ioList
+
+ // maxOutstanding is the maximum number of outstanding entries; this value
+ // is immutable.
+ maxOutstanding uint32
+
+	// outstanding is the number of requests outstanding; this is effectively
+	// the number of entries already in the result list plus the number of
+	// entries expected to be added to it.
+ outstanding uint32
+
+ // dead is set when the context is destroyed.
+ dead bool `state:"zerovalue"`
+}
+
+// destroy marks the context dead.
+func (ctx *AIOContext) destroy() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ ctx.dead = true
+ ctx.checkForDone()
+}
+
+// Preconditions: ctx.mu must be held by caller.
+func (ctx *AIOContext) checkForDone() {
+ if ctx.dead && ctx.outstanding == 0 {
+ close(ctx.requestReady)
+ ctx.requestReady = nil
+ }
+}
+
+// Prepare reserves space for a new request, returning true if available.
+// Returns false if the context is busy.
+func (ctx *AIOContext) Prepare() bool {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ if ctx.outstanding >= ctx.maxOutstanding {
+ return false
+ }
+ ctx.outstanding++
+ return true
+}
+
+// PopRequest pops a completed request if one is available; this function does
+// not block. Returns false if no request is available.
+func (ctx *AIOContext) PopRequest() (interface{}, bool) {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ // Is there anything ready?
+ if e := ctx.results.Front(); e != nil {
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
+ }
+ ctx.outstanding--
+ ctx.results.Remove(e)
+ ctx.checkForDone()
+ return e.data, true
+ }
+ return nil, false
+}
+
+// FinishRequest finishes a pending request. It queues up the data
+// and notifies listeners.
+func (ctx *AIOContext) FinishRequest(data interface{}) {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ // Push to the list and notify opportunistically. The channel notify
+ // here is guaranteed to be safe because outstanding must be non-zero.
+ // The requestReady channel is only closed when outstanding reaches zero.
+ ctx.results.PushBack(&ioResult{data: data})
+
+ select {
+ case ctx.requestReady <- struct{}{}:
+ default:
+ }
+}
+
+// WaitChannel returns a channel that is notified when an AIO request is
+// completed. Returns nil if the context is destroyed and there are no more
+// outstanding requests.
+func (ctx *AIOContext) WaitChannel() chan struct{} {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ return ctx.requestReady
+}
+
+// Dead returns true if the context has been destroyed.
+func (ctx *AIOContext) Dead() bool {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ return ctx.dead
+}
+
+// CancelPendingRequest forgets about a request that hasn't yet completed.
+func (ctx *AIOContext) CancelPendingRequest() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
+ }
+ ctx.outstanding--
+ ctx.checkForDone()
+}
+
+// Drain drops all completed requests. Pending requests remain untouched.
+func (ctx *AIOContext) Drain() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ return
+ }
+ size := uint32(ctx.results.Len())
+ if ctx.outstanding < size {
+ panic("AIOContext outstanding is going negative")
+ }
+ ctx.outstanding -= size
+ ctx.results.Reset()
+ ctx.checkForDone()
+}
+
+// aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO
+// ring buffers.
+//
+// +stateify savable
+type aioMappable struct {
+ refs.AtomicRefCount
+
+ mfp pgalloc.MemoryFileProvider
+ fr platform.FileRange
+}
+
+var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
+
+func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) {
+ fr, err := mfp.MemoryFile().Allocate(aioRingBufferSize, usage.Anonymous)
+ if err != nil {
+ return nil, err
+ }
+ m := aioMappable{mfp: mfp, fr: fr}
+ m.EnableLeakCheck("mm.aioMappable")
+ return &m, nil
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+func (m *aioMappable) DecRef() {
+ m.AtomicRefCount.DecRefWithDestructor(func() {
+ m.mfp.MemoryFile().DecRef(m.fr)
+ })
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (m *aioMappable) MappedName(ctx context.Context) string {
+ return "[aio]"
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (m *aioMappable) DeviceID() uint64 {
+ return 0
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (m *aioMappable) InodeID() uint64 {
+ return 0
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (m *aioMappable) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ // Linux: aio_ring_fops.fsync == NULL
+ return syserror.EINVAL
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar usermem.AddrRange, offset uint64, _ bool) error {
+ // Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap()
+ // sets VM_DONTEXPAND).
+ if offset != 0 || uint64(ar.Length()) != aioRingBufferSize {
+ return syserror.EFAULT
+ }
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (m *aioMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, _ bool) error {
+ // Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap()
+ // sets VM_DONTEXPAND).
+ if offset != 0 || uint64(dstAR.Length()) != aioRingBufferSize {
+ return syserror.EFAULT
+ }
+ // Require that the mapping correspond to a live AIOContext. Compare
+ // Linux's fs/aio.c:aio_ring_mremap().
+ mm, ok := ms.(*MemoryManager)
+ if !ok {
+ return syserror.EINVAL
+ }
+ am := &mm.aioManager
+ am.mu.Lock()
+ defer am.mu.Unlock()
+ oldID := uint64(srcAR.Start)
+ aioCtx, ok := am.contexts[oldID]
+ if !ok {
+ return syserror.EINVAL
+ }
+ aioCtx.mu.Lock()
+ defer aioCtx.mu.Unlock()
+ if aioCtx.dead {
+ return syserror.EINVAL
+ }
+ // Use the new ID for the AIOContext.
+ am.contexts[uint64(dstAR.Start)] = aioCtx
+ delete(am.contexts, oldID)
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ var err error
+ if required.End > m.fr.Length() {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if source := optional.Intersect(memmap.MappableRange{0, m.fr.Length()}); source.Length() != 0 {
+ return []memmap.Translation{
+ {
+ Source: source,
+ File: m.mfp.MemoryFile(),
+ Offset: m.fr.Start + source.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (m *aioMappable) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// NewAIOContext creates a new context for asynchronous I/O.
+//
+// NewAIOContext is analogous to Linux's fs/aio.c:ioctx_alloc().
+func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint64, error) {
+	// libaio get_ioevents() expects the context "handle" to be a valid
+	// address that it can peek inside, looking for a magic number. This
+	// function allocates a page per context and keeps it zeroed to ensure
+	// that it never matches AIO_RING_MAGIC, which keeps libaio happy.
+ m, err := newAIOMappable(mm.mfp)
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecRef()
+ addr, err := mm.MMap(ctx, memmap.MMapOpts{
+ Length: aioRingBufferSize,
+ MappingIdentity: m,
+ Mappable: m,
+ // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in
+ // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC,
+ // user mode should not write to this page.
+ Perms: usermem.Read,
+ MaxPerms: usermem.Read,
+ })
+ if err != nil {
+ return 0, err
+ }
+ id := uint64(addr)
+ if !mm.aioManager.newAIOContext(events, id) {
+ mm.MUnmap(ctx, addr, aioRingBufferSize)
+ return 0, syserror.EINVAL
+ }
+ return id, nil
+}
+
+// DestroyAIOContext destroys an asynchronous I/O context. It returns the
+// destroyed context, or nil if the context does not exist.
+func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
+ if _, ok := mm.LookupAIOContext(ctx, id); !ok {
+ return nil
+ }
+
+	// Only unmap after it is assured that the address is a valid AIO context
+	// to prevent random memory from being unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+	// This is, however, the way Linux implements AIO. We keep the same [weird]
+	// semantics in case anyone relies on them.
+ mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+
+ return mm.aioManager.destroyAIOContext(id)
+}
+
+// LookupAIOContext looks up the given context. It returns false if the context
+// does not exist.
+func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOContext, bool) {
+ aioCtx, ok := mm.aioManager.lookupAIOContext(id)
+ if !ok {
+ return nil, false
+ }
+
+ // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes
+ // from id).
+ var buf [4]byte
+ _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ if err != nil {
+ return nil, false
+ }
+
+ return aioCtx, true
+}
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
new file mode 100644
index 000000000..3dabac1af
--- /dev/null
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+// afterLoad is invoked by stateify.
+func (a *AIOContext) afterLoad() {
+ a.requestReady = make(chan struct{}, 1)
+}
diff --git a/pkg/sentry/mm/debug.go b/pkg/sentry/mm/debug.go
new file mode 100644
index 000000000..c273c982e
--- /dev/null
+++ b/pkg/sentry/mm/debug.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+const (
+ // If checkInvariants is true, perform runtime checks for invariants
+ // expected by the mm package. This is normally disabled since MM is a
+ // significant hot path in general, and some such checks (notably
+ // memmap.CheckTranslateResult) are very expensive.
+ checkInvariants = false
+
+ // If logIOErrors is true, log I/O errors that originate from MM before
+ // converting them to EFAULT.
+ logIOErrors = false
+)
+
+// String implements fmt.Stringer.String.
+func (mm *MemoryManager) String() string {
+ return mm.DebugString(context.Background())
+}
+
+// DebugString returns a string containing information about mm for debugging.
+func (mm *MemoryManager) DebugString(ctx context.Context) string {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ mm.activeMu.RLock()
+ defer mm.activeMu.RUnlock()
+ return mm.debugStringLocked(ctx)
+}
+
+// Preconditions: mm.mappingMu and mm.activeMu must be locked.
+func (mm *MemoryManager) debugStringLocked(ctx context.Context) string {
+ var b bytes.Buffer
+ b.WriteString("VMAs:\n")
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ b.Write(mm.vmaMapsEntryLocked(ctx, vseg))
+ }
+ b.WriteString("PMAs:\n")
+ for pseg := mm.pmas.FirstSegment(); pseg.Ok(); pseg = pseg.NextSegment() {
+ b.Write(pseg.debugStringEntryLocked())
+ }
+	return b.String()
+}
+
+// Preconditions: mm.activeMu must be locked.
+func (pseg pmaIterator) debugStringEntryLocked() []byte {
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "%08x-%08x ", pseg.Start(), pseg.End())
+
+ pma := pseg.ValuePtr()
+ if pma.effectivePerms.Read {
+ b.WriteByte('r')
+ } else {
+ b.WriteByte('-')
+ }
+ if pma.effectivePerms.Write {
+ if pma.needCOW {
+ b.WriteByte('c')
+ } else {
+ b.WriteByte('w')
+ }
+ } else {
+ b.WriteByte('-')
+ }
+ if pma.effectivePerms.Execute {
+ b.WriteByte('x')
+ } else {
+ b.WriteByte('-')
+ }
+ if pma.private {
+ b.WriteByte('p')
+ } else {
+ b.WriteByte('s')
+ }
+
+ fmt.Fprintf(&b, " %08x %T\n", pma.off, pma.file)
+ return b.Bytes()
+}
diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go
new file mode 100644
index 000000000..fa776f9c6
--- /dev/null
+++ b/pkg/sentry/mm/io.go
@@ -0,0 +1,639 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// There are two supported ways to copy data to/from application virtual
+// memory:
+//
+// 1. Internally-mapped copying: Determine the platform.File that backs the
+// copied-to/from virtual address, obtain a mapping of its pages, and read or
+// write to the mapping.
+//
+// 2. AddressSpace copying: If platform.Platform.SupportsAddressSpaceIO() is
+// true, AddressSpace permissions are applicable, and an AddressSpace is
+// available, copy directly through the AddressSpace, handling faults as
+// needed.
+//
+// (Given that internally-mapped copying requires that backing memory is always
+// implemented using a host file descriptor, we could also preadv/pwritev to it
+// instead. But this would incur a host syscall for each use of the mapped
+// page, whereas mmap is a one-time cost.)
+//
+// The fixed overhead of internally-mapped copying is expected to be higher
+// than that of AddressSpace copying since the former always needs to translate
+// addresses, whereas the latter only needs to do so when faults occur.
+// However, the throughput of internally-mapped copying is expected to be
+// somewhat higher than that of AddressSpace copying due to the high cost of
+// page faults and because implementations of the latter usually rely on
+// safecopy, which doesn't use AVX registers. So we prefer to use AddressSpace
+// copying (when available) for smaller copies, and switch to internally-mapped
+// copying once a size threshold is exceeded.
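+//
+// For example (illustrative sizes): with AddressSpace IO available, a 4 KB
+// CopyOut goes through mm.as.CopyOut directly, while a 64 KB CopyOut
+// (at or above copyMapMinBytes) maps the backing pages into the sentry and
+// copies via safemem.CopySeq.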
+const (
+ // copyMapMinBytes is the size threshold for switching to internally-mapped
+ // copying in CopyOut, CopyIn, and ZeroOut.
+ copyMapMinBytes = 32 << 10 // 32 KB
+
+ // rwMapMinBytes is the size threshold for switching to internally-mapped
+ // copying in CopyOutFrom and CopyInTo. It's lower than copyMapMinBytes
+ // since AddressSpace copying in this case requires additional buffering;
+ // see CopyOutFrom for details.
+ rwMapMinBytes = 512
+)
+
+// CheckIORange is similar to usermem.Addr.ToRange, but applies bounds checks
+// consistent with Linux's arch/x86/include/asm/uaccess.h:access_ok().
+//
+// Preconditions: length >= 0.
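+//
+// For example (illustrative values): with mm.layout.MaxAddr = 0x7ffffffff000,
+// CheckIORange(0x7fffffffe000, 0x2000) fails because the range's end,
+// 0x800000000000, exceeds MaxAddr even though the addition does not overflow.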
+func (mm *MemoryManager) CheckIORange(addr usermem.Addr, length int64) (usermem.AddrRange, bool) {
+ // Note that access_ok() constrains end even if length == 0.
+ ar, ok := addr.ToRange(uint64(length))
+ return ar, (ok && ar.End <= mm.layout.MaxAddr)
+}
+
+// checkIOVec applies bounds checks consistent with Linux's
+// arch/x86/include/asm/uaccess.h:access_ok() to ars.
+func (mm *MemoryManager) checkIOVec(ars usermem.AddrRangeSeq) bool {
+ for !ars.IsEmpty() {
+ ar := ars.Head()
+ if _, ok := mm.CheckIORange(ar.Start, int64(ar.Length())); !ok {
+ return false
+ }
+ ars = ars.Tail()
+ }
+ return true
+}
+
+func (mm *MemoryManager) asioEnabled(opts usermem.IOOpts) bool {
+ return mm.haveASIO && !opts.IgnorePermissions && opts.AddressSpaceActive
+}
+
+// translateIOError converts errors to EFAULT, as is usually reported for all
+// I/O errors originating from MM in Linux.
+func translateIOError(ctx context.Context, err error) error {
+ if err == nil {
+ return nil
+ }
+ if logIOErrors {
+ ctx.Debugf("MM I/O error: %v", err)
+ }
+ return syserror.EFAULT
+}
+
+// CopyOut implements usermem.IO.CopyOut.
+func (mm *MemoryManager) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+ ar, ok := mm.CheckIORange(addr, int64(len(src)))
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ if len(src) == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && len(src) < copyMapMinBytes {
+ return mm.asCopyOut(ctx, addr, src)
+ }
+
+ // Go through internal mappings.
+ n64, err := mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ n, err := safemem.CopySeq(ims, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src)))
+ return n, translateIOError(ctx, err)
+ })
+ return int(n64), err
+}
+
+func (mm *MemoryManager) asCopyOut(ctx context.Context, addr usermem.Addr, src []byte) (int, error) {
+ var done int
+ for {
+ n, err := mm.as.CopyOut(addr+usermem.Addr(done), src[done:])
+ done += n
+ if err == nil {
+ return done, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ ar, _ := addr.ToRange(uint64(len(src)))
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil {
+ return done, err
+ }
+ continue
+ }
+ return done, translateIOError(ctx, err)
+ }
+}
+
+// CopyIn implements usermem.IO.CopyIn.
+func (mm *MemoryManager) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+ ar, ok := mm.CheckIORange(addr, int64(len(dst)))
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && len(dst) < copyMapMinBytes {
+ return mm.asCopyIn(ctx, addr, dst)
+ }
+
+ // Go through internal mappings.
+ n64, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ n, err := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), ims)
+ return n, translateIOError(ctx, err)
+ })
+ return int(n64), err
+}
+
+func (mm *MemoryManager) asCopyIn(ctx context.Context, addr usermem.Addr, dst []byte) (int, error) {
+ var done int
+ for {
+ n, err := mm.as.CopyIn(addr+usermem.Addr(done), dst[done:])
+ done += n
+ if err == nil {
+ return done, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ ar, _ := addr.ToRange(uint64(len(dst)))
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil {
+ return done, err
+ }
+ continue
+ }
+ return done, translateIOError(ctx, err)
+ }
+}
+
+// ZeroOut implements usermem.IO.ZeroOut.
+func (mm *MemoryManager) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+ ar, ok := mm.CheckIORange(addr, toZero)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ if toZero == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && toZero < copyMapMinBytes {
+ return mm.asZeroOut(ctx, addr, toZero)
+ }
+
+ // Go through internal mappings.
+ return mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(dsts safemem.BlockSeq) (uint64, error) {
+ n, err := safemem.ZeroSeq(dsts)
+ return n, translateIOError(ctx, err)
+ })
+}
+
+func (mm *MemoryManager) asZeroOut(ctx context.Context, addr usermem.Addr, toZero int64) (int64, error) {
+ var done int64
+ for {
+ n, err := mm.as.ZeroOut(addr+usermem.Addr(done), uintptr(toZero-done))
+ done += int64(n)
+ if err == nil {
+ return done, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ ar, _ := addr.ToRange(uint64(toZero))
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil {
+ return done, err
+ }
+ continue
+ }
+ return done, translateIOError(ctx, err)
+ }
+}
+
+// CopyOutFrom implements usermem.IO.CopyOutFrom.
+func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+ if !mm.checkIOVec(ars) {
+ return 0, syserror.EFAULT
+ }
+
+ if ars.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && ars.NumBytes() < rwMapMinBytes {
+ // We have to introduce a buffered copy, instead of just passing a
+ // safemem.BlockSeq representing addresses in the AddressSpace to src.
+ // This is because usermem.IO.CopyOutFrom() guarantees that it calls
+ // src.ReadToBlocks() at most once, which is incompatible with handling
+ // faults between calls. In the future, this is probably best resolved
+ // by introducing a CopyOutFrom variant or option that allows it to
+ // call src.ReadToBlocks() any number of times.
+ //
+ // This issue applies to CopyInTo as well.
+ buf := make([]byte, int(ars.NumBytes()))
+ bufN, bufErr := src.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)))
+ var done int64
+ for done < int64(bufN) {
+ ar := ars.Head()
+ cplen := int64(ar.Length())
+ if cplen > int64(bufN)-done {
+ cplen = int64(bufN) - done
+ }
+ n, err := mm.asCopyOut(ctx, ar.Start, buf[int(done):int(done+cplen)])
+ done += int64(n)
+ if err != nil {
+ return done, err
+ }
+ ars = ars.Tail()
+ }
+ // Do not convert errors returned by src to EFAULT.
+ return done, bufErr
+ }
+
+ // Go through internal mappings.
+ return mm.withVecInternalMappings(ctx, ars, usermem.Write, opts.IgnorePermissions, src.ReadToBlocks)
+}
+
+// CopyInTo implements usermem.IO.CopyInTo.
+func (mm *MemoryManager) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+ if !mm.checkIOVec(ars) {
+ return 0, syserror.EFAULT
+ }
+
+ if ars.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && ars.NumBytes() < rwMapMinBytes {
+ buf := make([]byte, int(ars.NumBytes()))
+ var done int
+ var bufErr error
+ for !ars.IsEmpty() {
+ ar := ars.Head()
+ var n int
+ n, bufErr = mm.asCopyIn(ctx, ar.Start, buf[done:done+int(ar.Length())])
+ done += n
+ if bufErr != nil {
+ break
+ }
+ ars = ars.Tail()
+ }
+ n, err := dst.WriteFromBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:done])))
+ if err != nil {
+ return int64(n), err
+ }
+ // Do not convert errors returned by dst to EFAULT.
+ return int64(n), bufErr
+ }
+
+ // Go through internal mappings.
+ return mm.withVecInternalMappings(ctx, ars, usermem.Read, opts.IgnorePermissions, dst.WriteFromBlocks)
+}
+
+// SwapUint32 implements usermem.IO.SwapUint32.
+func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+ ar, ok := mm.CheckIORange(addr, 4)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.haveASIO && opts.AddressSpaceActive && !opts.IgnorePermissions {
+ for {
+ old, err := mm.as.SwapUint32(addr, new)
+ if err == nil {
+ return old, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil {
+ return 0, err
+ }
+ continue
+ }
+ return 0, translateIOError(ctx, err)
+ }
+ }
+
+ // Go through internal mappings.
+ var old uint32
+ _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
+ // Atomicity is unachievable across mappings.
+ return 0, syserror.EFAULT
+ }
+ im := ims.Head()
+ var err error
+ old, err = safemem.SwapUint32(im, new)
+ if err != nil {
+ return 0, translateIOError(ctx, err)
+ }
+ // Return the number of bytes read.
+ return 4, nil
+ })
+ return old, err
+}
+
+// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
+func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+ ar, ok := mm.CheckIORange(addr, 4)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.haveASIO && opts.AddressSpaceActive && !opts.IgnorePermissions {
+ for {
+ prev, err := mm.as.CompareAndSwapUint32(addr, old, new)
+ if err == nil {
+ return prev, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil {
+ return 0, err
+ }
+ continue
+ }
+ return 0, translateIOError(ctx, err)
+ }
+ }
+
+ // Go through internal mappings.
+ var prev uint32
+ _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
+ // Atomicity is unachievable across mappings.
+ return 0, syserror.EFAULT
+ }
+ im := ims.Head()
+ var err error
+ prev, err = safemem.CompareAndSwapUint32(im, old, new)
+ if err != nil {
+ return 0, translateIOError(ctx, err)
+ }
+ // Return the number of bytes read.
+ return 4, nil
+ })
+ return prev, err
+}
+
+// LoadUint32 implements usermem.IO.LoadUint32.
+func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+ ar, ok := mm.CheckIORange(addr, 4)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.haveASIO && opts.AddressSpaceActive && !opts.IgnorePermissions {
+ for {
+ val, err := mm.as.LoadUint32(addr)
+ if err == nil {
+ return val, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil {
+ return 0, err
+ }
+ continue
+ }
+ return 0, translateIOError(ctx, err)
+ }
+ }
+
+ // Go through internal mappings.
+ var val uint32
+ _, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
+ // Atomicity is unachievable across mappings.
+ return 0, syserror.EFAULT
+ }
+ im := ims.Head()
+ var err error
+ val, err = safemem.LoadUint32(im)
+ if err != nil {
+ return 0, translateIOError(ctx, err)
+ }
+ // Return the number of bytes read.
+ return 4, nil
+ })
+ return val, err
+}
+
+// handleASIOFault handles a page fault at address addr for an AddressSpaceIO
+// operation spanning ioar.
+//
+// Preconditions: mm.as != nil. ioar.Length() != 0. ioar.Contains(addr).
+func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, ioar usermem.AddrRange, at usermem.AccessType) error {
+ // Try to map all remaining pages in the I/O operation. This RoundUp can't
+ // overflow because otherwise it would have been caught by CheckIORange.
+ end, _ := ioar.End.RoundUp()
+ ar := usermem.AddrRange{addr.RoundDown(), end}
+
+ // Don't bother trying existingPMAsLocked; in most cases, if we did have
+ // existing pmas, we wouldn't have faulted.
+
+	// Ensure that we have usable vmas. Here and below, only return early if we
+	// can't map the first (faulting) page; failures to map later pages are
+	// silently ignored. This maximizes partial success.
+ mm.mappingMu.RLock()
+ vseg, vend, err := mm.getVMAsLocked(ctx, ar, at, false)
+ if vendaddr := vend.Start(); vendaddr < ar.End {
+ if vendaddr <= ar.Start {
+ mm.mappingMu.RUnlock()
+ return translateIOError(ctx, err)
+ }
+ ar.End = vendaddr
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pseg, pend, err := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if pendaddr := pend.Start(); pendaddr < ar.End {
+ if pendaddr <= ar.Start {
+ mm.activeMu.Unlock()
+ return translateIOError(ctx, err)
+ }
+ ar.End = pendaddr
+ }
+
+ // Downgrade to a read-lock on activeMu since we don't need to mutate pmas
+ // anymore.
+ mm.activeMu.DowngradeLock()
+
+ err = mm.mapASLocked(pseg, ar, false)
+ mm.activeMu.RUnlock()
+ return translateIOError(ctx, err)
+}
+
+// withInternalMappings ensures that pmas exist for all addresses in ar,
+// support access of type (at, ignorePermissions), and have internal mappings
+// cached. It then calls f with mm.activeMu locked for reading, passing
+// internal mappings for the subrange of ar for which this property holds.
+//
+// withInternalMappings takes a function returning uint64 since many safemem
+// functions have this property, but returns an int64 since this is usually
+// more useful for usermem.IO methods.
+//
+// Preconditions: 0 < ar.Length() <= math.MaxInt64.
+func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+ // If pmas are already available, we can do IO without touching mm.vmas or
+ // mm.mappingMu.
+ mm.activeMu.RLock()
+ if pseg := mm.existingPMAsLocked(ar, at, ignorePermissions, true /* needInternalMappings */); pseg.Ok() {
+ n, err := f(mm.internalMappingsLocked(pseg, ar))
+ mm.activeMu.RUnlock()
+ // Do not convert errors returned by f to EFAULT.
+ return int64(n), err
+ }
+ mm.activeMu.RUnlock()
+
+ // Ensure that we have usable vmas.
+ mm.mappingMu.RLock()
+ vseg, vend, verr := mm.getVMAsLocked(ctx, ar, at, ignorePermissions)
+ if vendaddr := vend.Start(); vendaddr < ar.End {
+ if vendaddr <= ar.Start {
+ mm.mappingMu.RUnlock()
+ return 0, translateIOError(ctx, verr)
+ }
+ ar.End = vendaddr
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pseg, pend, perr := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if pendaddr := pend.Start(); pendaddr < ar.End {
+ if pendaddr <= ar.Start {
+ mm.activeMu.Unlock()
+ return 0, translateIOError(ctx, perr)
+ }
+ ar.End = pendaddr
+ }
+ imend, imerr := mm.getPMAInternalMappingsLocked(pseg, ar)
+ mm.activeMu.DowngradeLock()
+ if imendaddr := imend.Start(); imendaddr < ar.End {
+ if imendaddr <= ar.Start {
+ mm.activeMu.RUnlock()
+ return 0, translateIOError(ctx, imerr)
+ }
+ ar.End = imendaddr
+ }
+
+ // Do I/O.
+ un, err := f(mm.internalMappingsLocked(pseg, ar))
+ mm.activeMu.RUnlock()
+ n := int64(un)
+
+ // Return the first error in order of progress through ar.
+ if err != nil {
+ // Do not convert errors returned by f to EFAULT.
+ return n, err
+ }
+ if imerr != nil {
+ return n, translateIOError(ctx, imerr)
+ }
+ if perr != nil {
+ return n, translateIOError(ctx, perr)
+ }
+ return n, translateIOError(ctx, verr)
+}
+
+// withVecInternalMappings ensures that pmas exist for all addresses in ars,
+// support access of type (at, ignorePermissions), and have internal mappings
+// cached. It then calls f with mm.activeMu locked for reading, passing
+// internal mappings for the subset of ars for which this property holds.
+//
+// Preconditions: !ars.IsEmpty().
+func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+ // withInternalMappings is faster than withVecInternalMappings because of
+ // iterator plumbing (this isn't generally practical in the vector case due
+ // to iterator invalidation between AddrRanges). Use it if possible.
+ if ars.NumRanges() == 1 {
+ return mm.withInternalMappings(ctx, ars.Head(), at, ignorePermissions, f)
+ }
+
+ // If pmas are already available, we can do IO without touching mm.vmas or
+ // mm.mappingMu.
+ mm.activeMu.RLock()
+ if mm.existingVecPMAsLocked(ars, at, ignorePermissions, true /* needInternalMappings */) {
+ n, err := f(mm.vecInternalMappingsLocked(ars))
+ mm.activeMu.RUnlock()
+ // Do not convert errors returned by f to EFAULT.
+ return int64(n), err
+ }
+ mm.activeMu.RUnlock()
+
+ // Ensure that we have usable vmas.
+ mm.mappingMu.RLock()
+ vars, verr := mm.getVecVMAsLocked(ctx, ars, at, ignorePermissions)
+ if vars.NumBytes() == 0 {
+ mm.mappingMu.RUnlock()
+ return 0, translateIOError(ctx, verr)
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pars, perr := mm.getVecPMAsLocked(ctx, vars, at)
+ mm.mappingMu.RUnlock()
+ if pars.NumBytes() == 0 {
+ mm.activeMu.Unlock()
+ return 0, translateIOError(ctx, perr)
+ }
+ imars, imerr := mm.getVecPMAInternalMappingsLocked(pars)
+ mm.activeMu.DowngradeLock()
+ if imars.NumBytes() == 0 {
+ mm.activeMu.RUnlock()
+ return 0, translateIOError(ctx, imerr)
+ }
+
+ // Do I/O.
+ un, err := f(mm.vecInternalMappingsLocked(imars))
+ mm.activeMu.RUnlock()
+ n := int64(un)
+
+ // Return the first error in order of progress through ars.
+ if err != nil {
+ // Do not convert errors from f to EFAULT.
+ return n, err
+ }
+ if imerr != nil {
+ return n, translateIOError(ctx, imerr)
+ }
+ if perr != nil {
+ return n, translateIOError(ctx, perr)
+ }
+ return n, translateIOError(ctx, verr)
+}
+
+// truncatedAddrRangeSeq returns a copy of ars, truncated so that it ends at
+// or before address end within AddrRange arsit.Head(). It is used in vector
+// I/O paths to truncate usermem.AddrRangeSeq when errors occur.
+//
+// Preconditions: !arsit.IsEmpty(). end <= arsit.Head().End.
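+//
+// For example (illustrative values): if ars covers [0x1000, 0x2000) and
+// [0x3000, 0x4000), arsit = ars.Tail(), and end = 0x3800, the result covers
+// [0x1000, 0x2000) and [0x3000, 0x3800).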
+func truncatedAddrRangeSeq(ars, arsit usermem.AddrRangeSeq, end usermem.Addr) usermem.AddrRangeSeq {
+ ar := arsit.Head()
+ if end <= ar.Start {
+ return ars.TakeFirst64(ars.NumBytes() - arsit.NumBytes())
+ }
+ return ars.TakeFirst64(ars.NumBytes() - arsit.NumBytes() + int64(end-ar.Start))
+}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
new file mode 100644
index 000000000..aac56679b
--- /dev/null
+++ b/pkg/sentry/mm/lifecycle.go
@@ -0,0 +1,283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager {
+ return &MemoryManager{
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ dumpability: UserDumpable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: sleepForActivation,
+ }
+}
+
+// SetMmapLayout initializes mm's layout from the given arch.Context.
+//
+// Preconditions: mm contains no mappings and is not used concurrently.
+func (mm *MemoryManager) SetMmapLayout(ac arch.Context, r *limits.LimitSet) (arch.MmapLayout, error) {
+ layout, err := ac.NewMmapLayout(mm.p.MinUserAddress(), mm.p.MaxUserAddress(), r)
+ if err != nil {
+ return arch.MmapLayout{}, err
+ }
+ mm.layout = layout
+ return layout, nil
+}
+
+// Fork creates a copy of mm with 1 user, as for Linux syscalls fork() or
+// clone() (without CLONE_VM).
+func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ mm2 := &MemoryManager{
+ p: mm.p,
+ mfp: mm.mfp,
+ haveASIO: mm.haveASIO,
+ layout: mm.layout,
+ privateRefs: mm.privateRefs,
+ users: 1,
+ brk: mm.brk,
+ usageAS: mm.usageAS,
+ dataAS: mm.dataAS,
+ // "The child does not inherit its parent's memory locks (mlock(2),
+ // mlockall(2))." - fork(2). So lockedAS is 0 and defMLockMode is
+ // MLockNone, both of which are zero values. vma.mlockMode is reset
+ // when copied below.
+ captureInvalidations: true,
+ argv: mm.argv,
+ envv: mm.envv,
+ auxv: append(arch.Auxv(nil), mm.auxv...),
+ // IncRef'd below, once we know that there isn't an error.
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: mm.sleepForActivation,
+ vdsoSigReturnAddr: mm.vdsoSigReturnAddr,
+ }
+
+ // Copy vmas.
+ dontforks := false
+ dstvgap := mm2.vmas.FirstGap()
+ for srcvseg := mm.vmas.FirstSegment(); srcvseg.Ok(); srcvseg = srcvseg.NextSegment() {
+ vma := srcvseg.Value() // makes a copy of the vma
+ vmaAR := srcvseg.Range()
+
+ if vma.dontfork {
+ length := uint64(vmaAR.Length())
+ mm2.usageAS -= length
+ if vma.isPrivateDataLocked() {
+ mm2.dataAS -= length
+ }
+ dontforks = true
+ continue
+ }
+
+ // Inform the Mappable, if any, of the new mapping.
+ if vma.mappable != nil {
+ if err := vma.mappable.AddMapping(ctx, mm2, vmaAR, vma.off, vma.canWriteMappableLocked()); err != nil {
+ mm2.removeVMAsLocked(ctx, mm2.applicationAddrRange())
+ return nil, err
+ }
+ }
+ if vma.id != nil {
+ vma.id.IncRef()
+ }
+ vma.mlockMode = memmap.MLockNone
+ dstvgap = mm2.vmas.Insert(dstvgap, vmaAR, vma).NextGap()
+ // We don't need to update mm2.usageAS since we copied it from mm
+ // above.
+ }
+
+ // Copy pmas. We have to lock mm.activeMu for writing to make existing
+ // private pmas copy-on-write. We also have to lock mm2.activeMu since
+ // after copying vmas above, memmap.Mappables may call mm2.Invalidate. We
+ // only copy private pmas, since in the common case where fork(2) is
+ // immediately followed by execve(2), copying non-private pmas that can be
+ // regenerated by calling memmap.Mappable.Translate is a waste of time.
+ // (Linux does the same; compare kernel/fork.c:dup_mmap() =>
+ // mm/memory.c:copy_page_range().)
+ mm2.activeMu.Lock()
+ defer mm2.activeMu.Unlock()
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+ if dontforks {
+ defer mm.pmas.MergeRange(mm.applicationAddrRange())
+ }
+ srcvseg := mm.vmas.FirstSegment()
+ dstpgap := mm2.pmas.FirstGap()
+ var unmapAR usermem.AddrRange
+ for srcpseg := mm.pmas.FirstSegment(); srcpseg.Ok(); srcpseg = srcpseg.NextSegment() {
+ pma := srcpseg.ValuePtr()
+ if !pma.private {
+ continue
+ }
+
+ if dontforks {
+ // Find the 'vma' that contains the starting address
+ // associated with the 'pma' (there must be one).
+ srcvseg = srcvseg.seekNextLowerBound(srcpseg.Start())
+ if checkInvariants {
+ if !srcvseg.Ok() {
+ panic(fmt.Sprintf("no vma covers pma range %v", srcpseg.Range()))
+ }
+ if srcpseg.Start() < srcvseg.Start() {
+ panic(fmt.Sprintf("vma %v ran ahead of pma %v", srcvseg.Range(), srcpseg.Range()))
+ }
+ }
+
+ srcpseg = mm.pmas.Isolate(srcpseg, srcvseg.Range())
+ if srcvseg.ValuePtr().dontfork {
+ continue
+ }
+ pma = srcpseg.ValuePtr()
+ }
+
+ if !pma.needCOW {
+ pma.needCOW = true
+ if pma.effectivePerms.Write {
+ // We don't want to unmap the whole address space, even though
+ // doing so would reduce calls to unmapASLocked(), because mm
+ // will most likely continue to be used after the fork, so
+ // unmapping pmas unnecessarily will result in extra page
+ // faults. But we do want to merge consecutive AddrRanges
+ // across pma boundaries.
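+				// For example (illustrative): writable private pmas covering
+				// [0xa000, 0xb000) and [0xb000, 0xc000) are unmapped with a
+				// single unmapASLocked call covering [0xa000, 0xc000).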
+ if unmapAR.End == srcpseg.Start() {
+ unmapAR.End = srcpseg.End()
+ } else {
+ if unmapAR.Length() != 0 {
+ mm.unmapASLocked(unmapAR)
+ }
+ unmapAR = srcpseg.Range()
+ }
+ pma.effectivePerms.Write = false
+ }
+ pma.maxPerms.Write = false
+ }
+ fr := srcpseg.fileRange()
+ mm2.incPrivateRef(fr)
+ srcpseg.ValuePtr().file.IncRef(fr)
+ addrRange := srcpseg.Range()
+ mm2.addRSSLocked(addrRange)
+ dstpgap = mm2.pmas.Insert(dstpgap, addrRange, *pma).NextGap()
+ }
+ if unmapAR.Length() != 0 {
+ mm.unmapASLocked(unmapAR)
+ }
+
+ // Between when we call memmap.Mappable.AddMapping while copying vmas and
+ // when we lock mm2.activeMu to copy pmas, calls to mm2.Invalidate() are
+ // ineffective because the pmas they invalidate haven't yet been copied,
+ // possibly allowing mm2 to get invalidated translations:
+ //
+ // Invalidating Mappable mm.Fork
+ // --------------------- -------
+ //
+ // mm2.Invalidate()
+ // mm.activeMu.Lock()
+ // mm.Invalidate() /* blocks */
+ // mm2.activeMu.Lock()
+ // (mm copies invalidated pma to mm2)
+ //
+ // This would technically be both safe (since we only copy private pmas,
+ // which will still hold a reference on their memory) and consistent with
+ // Linux, but we avoid it anyway by setting mm2.captureInvalidations during
+ // construction, causing calls to mm2.Invalidate() to be captured in
+ // mm2.capturedInvalidations, to be replayed after pmas are copied - i.e.
+ // here.
+ mm2.captureInvalidations = false
+ for _, invArgs := range mm2.capturedInvalidations {
+ mm2.invalidateLocked(invArgs.ar, invArgs.opts.InvalidatePrivate, true)
+ }
+ mm2.capturedInvalidations = nil
+
+ if mm2.executable != nil {
+ mm2.executable.IncRef()
+ }
+ return mm2, nil
+}
+
+// IncUsers increments mm's user count and returns true. If the user count is
+// already 0, IncUsers does nothing and returns false.
+func (mm *MemoryManager) IncUsers() bool {
+ for {
+ users := atomic.LoadInt32(&mm.users)
+ if users == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt32(&mm.users, users, users+1) {
+ return true
+ }
+ }
+}
+
+// DecUsers decrements mm's user count. If the user count reaches 0, all
+// mappings in mm are unmapped.
+func (mm *MemoryManager) DecUsers(ctx context.Context) {
+ if users := atomic.AddInt32(&mm.users, -1); users > 0 {
+ return
+ } else if users < 0 {
+ panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users))
+ }
+
+ mm.aioManager.destroy()
+
+ mm.metadataMu.Lock()
+ exe := mm.executable
+ mm.executable = nil
+ mm.metadataMu.Unlock()
+ if exe != nil {
+ exe.DecRef()
+ }
+
+ mm.activeMu.Lock()
+ // Sanity check.
+ if atomic.LoadInt32(&mm.active) != 0 {
+ panic("active address space lost?")
+ }
+ // Make sure the AddressSpace is returned.
+ if mm.as != nil {
+ mm.as.Release()
+ mm.as = nil
+ }
+ mm.activeMu.Unlock()
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ // If mm is being dropped before mm.SetMmapLayout was called,
+ // mm.applicationAddrRange() will be empty.
+ if ar := mm.applicationAddrRange(); ar.Length() != 0 {
+ mm.unmapLocked(ctx, ar)
+ }
+}
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
new file mode 100644
index 000000000..28e5057f7
--- /dev/null
+++ b/pkg/sentry/mm/metadata.go
@@ -0,0 +1,183 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Dumpability describes if and how core dumps should be created.
+type Dumpability int
+
+const (
+ // NotDumpable indicates that core dumps should never be created.
+ NotDumpable Dumpability = iota
+
+ // UserDumpable indicates that core dumps should be created, owned by
+ // the current user.
+ UserDumpable
+
+ // RootDumpable indicates that core dumps should be created, owned by
+ // root.
+ RootDumpable
+)
+
+// Dumpability returns the dumpability.
+func (mm *MemoryManager) Dumpability() Dumpability {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.dumpability
+}
+
+// SetDumpability sets the dumpability.
+func (mm *MemoryManager) SetDumpability(d Dumpability) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.dumpability = d
+}
+
+// ArgvStart returns the start of the application argument vector.
+//
+// There is no guarantee that this value is sensible w.r.t. ArgvEnd.
+func (mm *MemoryManager) ArgvStart() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.argv.Start
+}
+
+// SetArgvStart sets the start of the application argument vector.
+func (mm *MemoryManager) SetArgvStart(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.argv.Start = a
+}
+
+// ArgvEnd returns the end of the application argument vector.
+//
+// There is no guarantee that this value is sensible w.r.t. ArgvStart.
+func (mm *MemoryManager) ArgvEnd() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.argv.End
+}
+
+// SetArgvEnd sets the end of the application argument vector.
+func (mm *MemoryManager) SetArgvEnd(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.argv.End = a
+}
+
+// EnvvStart returns the start of the application environment vector.
+//
+// There is no guarantee that this value is sensible w.r.t. EnvvEnd.
+func (mm *MemoryManager) EnvvStart() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.envv.Start
+}
+
+// SetEnvvStart sets the start of the application environment vector.
+func (mm *MemoryManager) SetEnvvStart(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.envv.Start = a
+}
+
+// EnvvEnd returns the end of the application environment vector.
+//
+// There is no guarantee that this value is sensible w.r.t. EnvvStart.
+func (mm *MemoryManager) EnvvEnd() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.envv.End
+}
+
+// SetEnvvEnd sets the end of the application environment vector.
+func (mm *MemoryManager) SetEnvvEnd(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.envv.End = a
+}
+
+// Auxv returns a copy of the auxiliary vector.
+func (mm *MemoryManager) Auxv() arch.Auxv {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return append(arch.Auxv(nil), mm.auxv...)
+}
+
+// SetAuxv sets the auxiliary vector, storing a copy.
+func (mm *MemoryManager) SetAuxv(auxv arch.Auxv) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.auxv = append(arch.Auxv(nil), auxv...)
+}
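+
+// auxvCopyExample is an illustrative sketch, not part of the original
+// change: Auxv and SetAuxv copy the vector in both directions, so a caller
+// cannot mutate mm.auxv through the returned slice. arch.AuxEntry's zero
+// value is assumed to be a valid element here.
+func auxvCopyExample(mm *MemoryManager) {
+	v := mm.Auxv()
+	if len(v) > 0 {
+		v[0] = arch.AuxEntry{} // modifies only the copy, not mm.auxv
+	}
+}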
+
+// Executable returns the executable, if available.
+//
+// An additional reference will be taken in the case of a non-nil executable,
+// which must be released by the caller.
+func (mm *MemoryManager) Executable() fsbridge.File {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+
+ if mm.executable == nil {
+ return nil
+ }
+
+ mm.executable.IncRef()
+ return mm.executable
+}
+
+// SetExecutable sets the executable.
+//
+// This takes a reference on file.
+func (mm *MemoryManager) SetExecutable(file fsbridge.File) {
+ mm.metadataMu.Lock()
+
+ // Grab a new reference.
+ file.IncRef()
+
+ // Set the executable.
+ orig := mm.executable
+ mm.executable = file
+
+ mm.metadataMu.Unlock()
+
+ // Release the old reference.
+ //
+ // Do this without holding the lock, since it may wind up doing some
+// I/O to sync the underlying file, etc.
+ if orig != nil {
+ orig.DecRef()
+ }
+}
+
+// VDSOSigReturn returns the address of vdso_sigreturn.
+func (mm *MemoryManager) VDSOSigReturn() uint64 {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.vdsoSigReturnAddr
+}
+
+// SetVDSOSigReturn sets the address of vdso_sigreturn.
+func (mm *MemoryManager) SetVDSOSigReturn(addr uint64) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.vdsoSigReturnAddr = addr
+}
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
new file mode 100644
index 000000000..6db7c3d40
--- /dev/null
+++ b/pkg/sentry/mm/mm.go
@@ -0,0 +1,478 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package mm provides a memory management subsystem. See README.md for a
+// detailed overview.
+//
+// Lock order:
+//
+// fs locks, except for memmap.Mappable locks
+// mm.MemoryManager.metadataMu
+// mm.MemoryManager.mappingMu
+// Locks taken by memmap.Mappable methods other than Translate
+// mm.MemoryManager.activeMu
+// Locks taken by memmap.Mappable.Translate
+// mm.privateRefs.mu
+// platform.AddressSpace locks
+// platform.File locks
+// mm.aioManager.mu
+// mm.AIOContext.mu
+//
+// Only mm.MemoryManager.Fork is permitted to lock mm.MemoryManager.activeMu in
+// multiple mm.MemoryManagers, as it does so in a well-defined order (forked
+// child first).
+package mm
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
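+
+// lockOrderExample is an illustrative sketch, not part of the original
+// change: a caller that needs both vmas and pmas must follow the lock order
+// documented above, taking mappingMu before activeMu and releasing in
+// reverse order.
+func lockOrderExample(mm *MemoryManager) {
+	mm.mappingMu.RLock()
+	mm.activeMu.Lock()
+	// ... read mm.vmas (mappingMu) and mutate mm.pmas (activeMu) ...
+	mm.activeMu.Unlock()
+	mm.mappingMu.RUnlock()
+}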
+
+// MemoryManager implements a virtual address space.
+//
+// +stateify savable
+type MemoryManager struct {
+ // p and mfp are immutable.
+ p platform.Platform
+ mfp pgalloc.MemoryFileProvider
+
+ // haveASIO is the cached result of p.SupportsAddressSpaceIO(). Aside from
+ // eliminating an indirect call in the hot I/O path, this makes
+ // MemoryManager.asioEnabled() a leaf function, allowing it to be inlined.
+ //
+ // haveASIO is immutable.
+ haveASIO bool `state:"nosave"`
+
+ // layout is the memory layout.
+ //
+ // layout is set by the binary loader before the MemoryManager can be used.
+ layout arch.MmapLayout
+
+ // privateRefs stores reference counts for private memory (memory whose
+ // ownership is shared by one or more pmas instead of being owned by a
+ // memmap.Mappable).
+ //
+ // privateRefs is immutable.
+ privateRefs *privateRefs
+
+ // users is the number of dependencies on the mappings in the MemoryManager.
+ // When the number of references in users reaches zero, all mappings are
+ // unmapped.
+ //
+ // users is accessed using atomic memory operations.
+ users int32
+
+ // mappingMu is analogous to Linux's struct mm_struct::mmap_sem.
+ mappingMu sync.RWMutex `state:"nosave"`
+
+ // vmas stores virtual memory areas. Since vmas are stored by value,
+ // clients should usually use vmaIterator.ValuePtr() instead of
+ // vmaIterator.Value() to get a pointer to the vma rather than a copy.
+ //
+ // Invariants: vmas are always page-aligned.
+ //
+ // vmas is protected by mappingMu.
+ vmas vmaSet
+
+ // brk is the mm's brk, which is manipulated using the brk(2) system call.
+ // The brk is initially set up by the loader which maps an executable
+ // binary into the mm.
+ //
+ // brk is protected by mappingMu.
+ brk usermem.AddrRange
+
+ // usageAS is vmas.Span(), cached to accelerate RLIMIT_AS checks.
+ //
+ // usageAS is protected by mappingMu.
+ usageAS uint64
+
+ // lockedAS is the combined size in bytes of all vmas with vma.mlockMode !=
+ // memmap.MLockNone.
+ //
+ // lockedAS is protected by mappingMu.
+ lockedAS uint64
+
+ // dataAS is the size of private data segments, analogous to
+ // mm_struct->data_vm: the span of vmas that are private, writable, and
+ // not stack.
+ //
+ // dataAS is protected by mappingMu.
+ dataAS uint64
+
+ // New VMAs created by MMap use whichever of memmap.MMapOpts.MLockMode or
+ // defMLockMode is greater.
+ //
+ // defMLockMode is protected by mappingMu.
+ defMLockMode memmap.MLockMode
+
+ // activeMu is loosely analogous to Linux's struct
+ // mm_struct::page_table_lock.
+ activeMu sync.RWMutex `state:"nosave"`
+
+ // pmas stores platform mapping areas used to implement vmas. Since pmas
+ // are stored by value, clients should usually use pmaIterator.ValuePtr()
+ // instead of pmaIterator.Value() to get a pointer to the pma rather than
+ // a copy.
+ //
+ // Inserting or removing segments from pmas should happen along with a
+ // call to mm.addRSSLocked or mm.removeRSSLocked.
+ //
+ // Invariants: pmas are always page-aligned. If a pma exists for a given
+ // address, a vma must also exist for that address.
+ //
+ // pmas is protected by activeMu.
+ pmas pmaSet
+
+ // curRSS is pmas.Span(), cached to accelerate updates to maxRSS. It is
+ // reported as the MemoryManager's RSS.
+ //
+ // curRSS should be modified only via addRSSLocked and removeRSSLocked,
+ // not directly.
+ //
+ // curRSS is protected by activeMu.
+ curRSS uint64
+
+ // maxRSS is the maximum resident set size in bytes of a MemoryManager.
+ // It is tracked as the application adds and removes mappings to pmas.
+ //
+ // maxRSS should be modified only via addRSSLocked, not directly.
+ //
+ // maxRSS is protected by activeMu.
+ maxRSS uint64
+
+ // as is the platform.AddressSpace that pmas are mapped into. active is the
+ // number of contexts that require as to be non-nil; if active == 0, as may
+ // be nil.
+ //
+ // as is protected by activeMu. active is manipulated with atomic memory
+ // operations; transitions to and from zero are additionally protected by
+ // activeMu. (This is because such transitions may need to be atomic with
+ // changes to as.)
+ as platform.AddressSpace `state:"nosave"`
+ active int32 `state:"zerovalue"`
+
+ // unmapAllOnActivate indicates that the next Activate call should activate
+ // an empty AddressSpace.
+ //
+ // This is used to ensure that an AddressSpace cached in
+ // NewAddressSpace is not used after some change in the MemoryManager
+ // or VMAs has made that AddressSpace stale.
+ //
+ // unmapAllOnActivate is protected by activeMu. It must only be set when
+ // there is no active or cached AddressSpace. If as != nil, then
+ // invalidations should be propagated immediately.
+ unmapAllOnActivate bool `state:"nosave"`
+
+ // If captureInvalidations is true, calls to MM.Invalidate() are recorded
+ // in capturedInvalidations rather than being applied immediately to pmas.
+ // This is to avoid a race condition in MM.Fork(); see that function for
+ // details.
+ //
+ // Both captureInvalidations and capturedInvalidations are protected by
+ // activeMu. Neither needs to be saved since captureInvalidations is only
+ // enabled during MM.Fork(), during which saving can't occur.
+ captureInvalidations bool `state:"zerovalue"`
+ capturedInvalidations []invalidateArgs `state:"nosave"`
+
+ metadataMu sync.Mutex `state:"nosave"`
+
+ // argv is the application argv. This is set up by the loader and may be
+ // modified by prctl(PR_SET_MM_ARG_START/PR_SET_MM_ARG_END). No
+ // requirements apply to argv; we do not require that argv.WellFormed().
+ //
+ // argv is protected by metadataMu.
+ argv usermem.AddrRange
+
+ // envv is the application envv. This is set up by the loader and may be
+ // modified by prctl(PR_SET_MM_ENV_START/PR_SET_MM_ENV_END). No
+ // requirements apply to envv; we do not require that envv.WellFormed().
+ //
+ // envv is protected by metadataMu.
+ envv usermem.AddrRange
+
+ // auxv is the ELF's auxiliary vector.
+ //
+ // auxv is protected by metadataMu.
+ auxv arch.Auxv
+
+ // executable is the executable for this MemoryManager. If executable
+ // is not nil, it holds a reference on the Dirent.
+ //
+ // executable is protected by metadataMu.
+ executable fsbridge.File
+
+ // dumpability describes if and how this MemoryManager may be dumped to
+ // userspace.
+ //
+ // dumpability is protected by metadataMu.
+ dumpability Dumpability
+
+ // aioManager keeps track of AIOContexts used for async IOs. AIOManager
+ // must be cloned when CLONE_VM is used.
+ aioManager aioManager
+
+ // sleepForActivation indicates whether the task should report itself as
+ // sleeping while trying to activate the address space. When true, delays in
+ // activation are not reported as stuck tasks by the watchdog.
+ sleepForActivation bool
+
+ // vdsoSigReturnAddr is the address of 'vdso_sigreturn'.
+ vdsoSigReturnAddr uint64
+}
+
+// vma represents a virtual memory area.
+//
+// +stateify savable
+type vma struct {
+ // mappable is the virtual memory object mapped by this vma. If mappable is
+ // nil, the vma represents a private anonymous mapping.
+ mappable memmap.Mappable
+
+ // off is the offset into mappable at which this vma begins. If mappable is
+ // nil, off is meaningless.
+ off uint64
+
+ // To speed up vma save/restore, we group and save the following booleans
+ // as a single integer.
+
+ // realPerms are the memory permissions on this vma, as defined by the
+ // application.
+ realPerms usermem.AccessType `state:".(int)"`
+
+ // effectivePerms are the memory permissions on this vma which are
+ // actually used to control access.
+ //
+ // Invariant: effectivePerms == realPerms.Effective().
+ effectivePerms usermem.AccessType `state:"manual"`
+
+ // maxPerms limits the set of permissions that may ever apply to this
+ // memory, as well as accesses for which usermem.IOOpts.IgnorePermissions
+ // is true (e.g. ptrace(PTRACE_POKEDATA)).
+ //
+ // Invariant: maxPerms == maxPerms.Effective().
+ maxPerms usermem.AccessType `state:"manual"`
+
+ // private is true if this is a MAP_PRIVATE mapping, such that writes to
+ // the mapping are propagated to a copy.
+ private bool `state:"manual"`
+
+ // growsDown is true if the mapping may be automatically extended downward
+ // under certain conditions. If growsDown is true, mappable must be nil.
+ //
+ // There is currently no corresponding growsUp flag; in Linux, the only
+ // architectures that can have VM_GROWSUP mappings are ia64, parisc, and
+ // metag, none of which we currently support.
+ growsDown bool `state:"manual"`
+
+ // dontfork is the MADV_DONTFORK setting for this vma configured by madvise().
+ dontfork bool
+
+ mlockMode memmap.MLockMode
+
+ // numaPolicy is the NUMA policy for this vma set by mbind().
+ numaPolicy linux.NumaPolicy
+
+ // numaNodemask is the NUMA nodemask for this vma set by mbind().
+ numaNodemask uint64
+
+ // If id is not nil, it controls the lifecycle of mappable and provides vma
+ // metadata shown in /proc/[pid]/maps, and the vma holds a reference.
+ id memmap.MappingIdentity
+
+ // If hint is non-empty, it is a description of the vma printed in
+ // /proc/[pid]/maps. hint takes priority over id.MappedName().
+ hint string
+}
+
+const (
+ vmaRealPermsRead = 1 << iota
+ vmaRealPermsWrite
+ vmaRealPermsExecute
+ vmaEffectivePermsRead
+ vmaEffectivePermsWrite
+ vmaEffectivePermsExecute
+ vmaMaxPermsRead
+ vmaMaxPermsWrite
+ vmaMaxPermsExecute
+ vmaPrivate
+ vmaGrowsDown
+)
+
+func (v *vma) saveRealPerms() int {
+ var b int
+ if v.realPerms.Read {
+ b |= vmaRealPermsRead
+ }
+ if v.realPerms.Write {
+ b |= vmaRealPermsWrite
+ }
+ if v.realPerms.Execute {
+ b |= vmaRealPermsExecute
+ }
+ if v.effectivePerms.Read {
+ b |= vmaEffectivePermsRead
+ }
+ if v.effectivePerms.Write {
+ b |= vmaEffectivePermsWrite
+ }
+ if v.effectivePerms.Execute {
+ b |= vmaEffectivePermsExecute
+ }
+ if v.maxPerms.Read {
+ b |= vmaMaxPermsRead
+ }
+ if v.maxPerms.Write {
+ b |= vmaMaxPermsWrite
+ }
+ if v.maxPerms.Execute {
+ b |= vmaMaxPermsExecute
+ }
+ if v.private {
+ b |= vmaPrivate
+ }
+ if v.growsDown {
+ b |= vmaGrowsDown
+ }
+ return b
+}
+
+func (v *vma) loadRealPerms(b int) {
+ if b&vmaRealPermsRead > 0 {
+ v.realPerms.Read = true
+ }
+ if b&vmaRealPermsWrite > 0 {
+ v.realPerms.Write = true
+ }
+ if b&vmaRealPermsExecute > 0 {
+ v.realPerms.Execute = true
+ }
+ if b&vmaEffectivePermsRead > 0 {
+ v.effectivePerms.Read = true
+ }
+ if b&vmaEffectivePermsWrite > 0 {
+ v.effectivePerms.Write = true
+ }
+ if b&vmaEffectivePermsExecute > 0 {
+ v.effectivePerms.Execute = true
+ }
+ if b&vmaMaxPermsRead > 0 {
+ v.maxPerms.Read = true
+ }
+ if b&vmaMaxPermsWrite > 0 {
+ v.maxPerms.Write = true
+ }
+ if b&vmaMaxPermsExecute > 0 {
+ v.maxPerms.Execute = true
+ }
+ if b&vmaPrivate > 0 {
+ v.private = true
+ }
+ if b&vmaGrowsDown > 0 {
+ v.growsDown = true
+ }
+}
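+
+// permsPackingExample is an illustrative sketch, not part of the original
+// change: saveRealPerms and loadRealPerms round-trip the grouped booleans,
+// so a vma restored from the packed integer carries identical permissions
+// and flags.
+func permsPackingExample(v vma) bool {
+	var restored vma
+	restored.loadRealPerms(v.saveRealPerms())
+	return restored.realPerms == v.realPerms &&
+		restored.effectivePerms == v.effectivePerms &&
+		restored.maxPerms == v.maxPerms &&
+		restored.private == v.private &&
+		restored.growsDown == v.growsDown
+}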
+
+// pma represents a platform mapping area.
+//
+// +stateify savable
+type pma struct {
+ // file is the file mapped by this pma. Only pmas for which file ==
+ // MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to
+ // the corresponding file range while they exist.
+ file platform.File `state:"nosave"`
+
+ // off is the offset into file at which this pma begins.
+ //
+ // Note that pmas do *not* hold references on offsets in file! If private
+ // is true, MemoryManager.privateRefs holds the reference instead. If
+ // private is false, the corresponding memmap.Mappable holds the reference
+ // instead (per memmap.Mappable.Translate requirement).
+ off uint64
+
+ // translatePerms is the permissions returned by memmap.Mappable.Translate.
+ // If private is true, translatePerms is usermem.AnyAccess.
+ translatePerms usermem.AccessType
+
+ // effectivePerms is the permissions allowed for non-ignorePermissions
+ // accesses. maxPerms is the permissions allowed for ignorePermissions
+ // accesses. These are vma.effectivePerms and vma.maxPerms respectively,
+ // masked by pma.translatePerms and with Write disallowed if pma.needCOW is
+ // true.
+ //
+ // These are stored in the pma so that the IO implementation can avoid
+ // iterating mm.vmas when pmas already exist.
+ effectivePerms usermem.AccessType
+ maxPerms usermem.AccessType
+
+ // needCOW is true if writes to the mapping must be propagated to a copy.
+ needCOW bool
+
+ // private is true if this pma represents private memory.
+ //
+ // If private is true, file must be MemoryManager.mfp.MemoryFile(), the pma
+ // holds a reference on the mapped memory that is tracked in privateRefs,
+ // and calls to Invalidate for which
+ // memmap.InvalidateOpts.InvalidatePrivate is false should ignore the pma.
+ //
+ // If private is false, this pma caches a translation from the
+ // corresponding vma's memmap.Mappable.Translate.
+ private bool
+
+ // If internalMappings is not empty, it is the cached return value of
+ // file.MapInternal for the platform.FileRange mapped by this pma.
+ internalMappings safemem.BlockSeq `state:"nosave"`
+}
+
+// +stateify savable
+type privateRefs struct {
+ mu sync.Mutex `state:"nosave"`
+
+ // refs maps offsets into MemoryManager.mfp.MemoryFile() to the number of
+ // pmas (or, equivalently, MemoryManagers) that share ownership of the
+ // memory at that offset.
+ refs fileRefcountSet
+}
+
+type invalidateArgs struct {
+ ar usermem.AddrRange
+ opts memmap.InvalidateOpts
+}
+
+// fileRefcountSetFunctions implements segment.Functions for fileRefcountSet.
+type fileRefcountSetFunctions struct{}
+
+func (fileRefcountSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (fileRefcountSetFunctions) MaxKey() uint64 {
+ return ^uint64(0)
+}
+
+func (fileRefcountSetFunctions) ClearValue(_ *int32) {
+}
+
+func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) {
+ return rc1, rc1 == rc2
+}
+
+func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) {
+ return rc, rc
+}
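+
+// fileRefcountMergeExample is an illustrative sketch, not part of the
+// original change: fileRefcountSet segments merge only when their refcounts
+// are equal, which keeps the set fully merged as incPrivateRef and
+// decPrivateRef (in pma.go) require.
+func fileRefcountMergeExample() bool {
+	fr1 := platform.FileRange{0, 0x1000}
+	fr2 := platform.FileRange{0x1000, 0x2000}
+	rc, ok := fileRefcountSetFunctions{}.Merge(fr1, 2, fr2, 2)
+	return ok && rc == 2 // unequal counts (e.g. 1 and 2) would not merge
+}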
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
new file mode 100644
index 000000000..fdc308542
--- /dev/null
+++ b/pkg/sentry/mm/mm_test.go
@@ -0,0 +1,230 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func testMemoryManager(ctx context.Context) *MemoryManager {
+ p := platform.FromContext(ctx)
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ mm := NewMemoryManager(p, mfp, false)
+ mm.layout = arch.MmapLayout{
+ MinAddr: p.MinUserAddress(),
+ MaxAddr: p.MaxUserAddress(),
+ BottomUpBase: p.MinUserAddress(),
+ TopDownBase: p.MaxUserAddress(),
+ }
+ return mm
+}
+
+func (mm *MemoryManager) realUsageAS() uint64 {
+ return uint64(mm.vmas.Span())
+}
+
+func TestUsageASUpdates(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ addr, err := mm.MMap(ctx, memmap.MMapOpts{
+ Length: 2 * usermem.PageSize,
+ })
+ if err != nil {
+ t.Fatalf("MMap got err %v want nil", err)
+ }
+ realUsage := mm.realUsageAS()
+ if mm.usageAS != realUsage {
+ t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage)
+ }
+
+ mm.MUnmap(ctx, addr, usermem.PageSize)
+ realUsage = mm.realUsageAS()
+ if mm.usageAS != realUsage {
+ t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage)
+ }
+}
+
+func (mm *MemoryManager) realDataAS() uint64 {
+ var sz uint64
+ for seg := mm.vmas.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ vma := seg.Value()
+ if vma.isPrivateDataLocked() {
+ sz += uint64(seg.Range().Length())
+ }
+ }
+ return sz
+}
+
+func TestDataASUpdates(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ addr, err := mm.MMap(ctx, memmap.MMapOpts{
+ Length: 3 * usermem.PageSize,
+ Private: true,
+ Perms: usermem.Write,
+ MaxPerms: usermem.AnyAccess,
+ })
+ if err != nil {
+ t.Fatalf("MMap got err %v want nil", err)
+ }
+ if mm.dataAS == 0 {
+ t.Fatalf("dataAS is 0, wanted not 0")
+ }
+ realDataAS := mm.realDataAS()
+ if mm.dataAS != realDataAS {
+ t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
+ }
+
+ mm.MUnmap(ctx, addr, usermem.PageSize)
+ realDataAS = mm.realDataAS()
+ if mm.dataAS != realDataAS {
+ t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
+ }
+
+ mm.MProtect(addr+usermem.PageSize, usermem.PageSize, usermem.Read, false)
+ realDataAS = mm.realDataAS()
+ if mm.dataAS != realDataAS {
+ t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
+ }
+
+ mm.MRemap(ctx, addr+2*usermem.PageSize, usermem.PageSize, 2*usermem.PageSize, MRemapOpts{
+ Move: MRemapMayMove,
+ })
+ realDataAS = mm.realDataAS()
+ if mm.dataAS != realDataAS {
+ t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
+ }
+}
+
+func TestBrkDataLimitUpdates(t *testing.T) {
+ limitSet := limits.NewLimitSet()
+ limitSet.Set(limits.Data, limits.Limit{}, true /* privileged */) // zero RLIMIT_DATA
+
+ ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ // Try to extend the brk by one page and expect doing so to fail.
+ oldBrk, _ := mm.Brk(ctx, 0)
+ if newBrk, _ := mm.Brk(ctx, oldBrk+usermem.PageSize); newBrk != oldBrk {
+ t.Errorf("brk() increased data segment above RLIMIT_DATA (old brk = %#x, new brk = %#x", oldBrk, newBrk)
+ }
+}
+
+// TestIOAfterUnmap ensures that IO fails after unmap.
+func TestIOAfterUnmap(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ addr, err := mm.MMap(ctx, memmap.MMapOpts{
+ Length: usermem.PageSize,
+ Private: true,
+ Perms: usermem.Read,
+ MaxPerms: usermem.AnyAccess,
+ })
+ if err != nil {
+ t.Fatalf("MMap got err %v want nil", err)
+ }
+
+ // IO works before munmap.
+ b := make([]byte, 1)
+ n, err := mm.CopyIn(ctx, addr, b, usermem.IOOpts{})
+ if err != nil {
+ t.Errorf("CopyIn got err %v want nil", err)
+ }
+ if n != 1 {
+ t.Errorf("CopyIn got %d want 1", n)
+ }
+
+ err = mm.MUnmap(ctx, addr, usermem.PageSize)
+ if err != nil {
+ t.Fatalf("MUnmap got err %v want nil", err)
+ }
+
+ n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{})
+ if err != syserror.EFAULT {
+ t.Errorf("CopyIn got err %v want EFAULT", err)
+ }
+ if n != 0 {
+ t.Errorf("CopyIn got %d want 0", n)
+ }
+}
+
+// TestIOAfterMProtect tests IO interaction with mprotect permissions.
+func TestIOAfterMProtect(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ addr, err := mm.MMap(ctx, memmap.MMapOpts{
+ Length: usermem.PageSize,
+ Private: true,
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.AnyAccess,
+ })
+ if err != nil {
+ t.Fatalf("MMap got err %v want nil", err)
+ }
+
+ // Writing works before mprotect.
+ b := make([]byte, 1)
+ n, err := mm.CopyOut(ctx, addr, b, usermem.IOOpts{})
+ if err != nil {
+ t.Errorf("CopyOut got err %v want nil", err)
+ }
+ if n != 1 {
+ t.Errorf("CopyOut got %d want 1", n)
+ }
+
+ err = mm.MProtect(addr, usermem.PageSize, usermem.Read, false)
+ if err != nil {
+ t.Errorf("MProtect got err %v want nil", err)
+ }
+
+ // Without IgnorePermissions, CopyOut should no longer succeed.
+ n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{})
+ if err != syserror.EFAULT {
+ t.Errorf("CopyOut got err %v want EFAULT", err)
+ }
+ if n != 0 {
+ t.Errorf("CopyOut got %d want 0", n)
+ }
+
+ // With IgnorePermissions, CopyOut should succeed despite mprotect.
+ n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{
+ IgnorePermissions: true,
+ })
+ if err != nil {
+ t.Errorf("CopyOut got err %v want nil", err)
+ }
+ if n != 1 {
+ t.Errorf("CopyOut got %d want 1", n)
+ }
+}
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
new file mode 100644
index 000000000..62e4c20af
--- /dev/null
+++ b/pkg/sentry/mm/pma.go
@@ -0,0 +1,1036 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// existingPMAsLocked checks that pmas exist for all addresses in ar, and
+// support access of type (at, ignorePermissions). If so, it returns an
+// iterator to the pma containing ar.Start. Otherwise it returns a terminal
+// iterator.
+//
+// Preconditions: mm.activeMu must be locked. ar.Length() != 0.
+func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ first := mm.pmas.FindSegment(ar.Start)
+ pseg := first
+ for pseg.Ok() {
+ pma := pseg.ValuePtr()
+ perms := pma.effectivePerms
+ if ignorePermissions {
+ perms = pma.maxPerms
+ }
+ if !perms.SupersetOf(at) {
+ return pmaIterator{}
+ }
+ if needInternalMappings && pma.internalMappings.IsEmpty() {
+ return pmaIterator{}
+ }
+
+ if ar.End <= pseg.End() {
+ return first
+ }
+ pseg, _ = pseg.NextNonEmpty()
+ }
+
+ // Ran out of pmas before reaching ar.End.
+ return pmaIterator{}
+}
+
+// existingVecPMAsLocked returns true if pmas exist for all addresses in ars,
+// and support access of type (at, ignorePermissions).
+//
+// Preconditions: mm.activeMu must be locked.
+func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) bool {
+ for ; !ars.IsEmpty(); ars = ars.Tail() {
+ if ar := ars.Head(); ar.Length() != 0 && !mm.existingPMAsLocked(ar, at, ignorePermissions, needInternalMappings).Ok() {
+ return false
+ }
+ }
+ return true
+}
+
+// getPMAsLocked ensures that pmas exist for all addresses in ar, and support
+// access of type at. It returns:
+//
+// - An iterator to the pma containing ar.Start. If no pma contains ar.Start,
+// the iterator is unspecified.
+//
+// - An iterator to the gap after the last pma containing an address in ar. If
+// pmas exist for no addresses in ar, the iterator is to a gap that begins
+// before ar.Start.
+//
+// - An error that is non-nil if pmas exist for only a subset of ar.
+//
+// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for
+// writing. ar.Length() != 0. vseg.Range().Contains(ar.Start). vmas must exist
+// for all addresses in ar, and support accesses of type at (i.e. permission
+// checks must have been performed against vmas).
+func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if !vseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial vma %v does not cover start of ar %v", vseg.Range(), ar))
+ }
+ }
+
+ // Page-align ar so that all AddrRanges are aligned.
+ end, ok := ar.End.RoundUp()
+ var alignerr error
+ if !ok {
+ end = ar.End.RoundDown()
+ alignerr = syserror.EFAULT
+ }
+ ar = usermem.AddrRange{ar.Start.RoundDown(), end}
+
+ pstart, pend, perr := mm.getPMAsInternalLocked(ctx, vseg, ar, at)
+ if pend.Start() <= ar.Start {
+ return pmaIterator{}, pend, perr
+ }
+ // getPMAsInternalLocked may not have returned pstart due to iterator
+ // invalidation.
+ if !pstart.Ok() {
+ pstart = mm.findOrSeekPrevUpperBoundPMA(ar.Start, pend)
+ }
+ if perr != nil {
+ return pstart, pend, perr
+ }
+ return pstart, pend, alignerr
+}
+
+// getVecPMAsLocked ensures that pmas exist for all addresses in ars, and
+// support access of type at. It returns the subset of ars for which pmas
+// exist. If this is not equal to ars, it returns a non-nil error explaining
+// why.
+//
+// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for
+// writing. vmas must exist for all addresses in ars, and support accesses of
+// type at (i.e. permission checks must have been performed against vmas).
+func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType) (usermem.AddrRangeSeq, error) {
+ for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
+ ar := arsit.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ if checkInvariants {
+ if !ar.WellFormed() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // Page-align ar so that all AddrRanges are aligned.
+ end, ok := ar.End.RoundUp()
+ var alignerr error
+ if !ok {
+ end = ar.End.RoundDown()
+ alignerr = syserror.EFAULT
+ }
+ ar = usermem.AddrRange{ar.Start.RoundDown(), end}
+
+ _, pend, perr := mm.getPMAsInternalLocked(ctx, mm.vmas.FindSegment(ar.Start), ar, at)
+ if perr != nil {
+ return truncatedAddrRangeSeq(ars, arsit, pend.Start()), perr
+ }
+ if alignerr != nil {
+ return truncatedAddrRangeSeq(ars, arsit, pend.Start()), alignerr
+ }
+ }
+
+ return ars, nil
+}
+
+// getPMAsInternalLocked is equivalent to getPMAsLocked, with the following
+// exceptions:
+//
+// - getPMAsInternalLocked returns a pmaIterator on a best-effort basis (that
+// is, the returned iterator may be terminal, even if a pma that contains
+// ar.Start exists). Returning this iterator on a best-effort basis allows
+// callers that require it to use it when it's cheaply available, while also
+// avoiding the overhead of retrieving it when it's not.
+//
+// - getPMAsInternalLocked additionally requires that ar is page-aligned.
+//
+// getPMAsInternalLocked is an implementation helper for getPMAsLocked and
+// getVecPMAsLocked; other clients should call one of those instead.
+func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if !vseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial vma %v does not cover start of ar %v", vseg.Range(), ar))
+ }
+ }
+
+ mf := mm.mfp.MemoryFile()
+ // Limit the range we allocate to ar, aligned to privateAllocUnit.
+ maskAR := privateAligned(ar)
+ didUnmapAS := false
+ // The range in which we iterate vmas and pmas is still limited to ar, to
+ // ensure that we don't allocate or COW-break a pma we don't need.
+ pseg, pgap := mm.pmas.Find(ar.Start)
+ pstart := pseg
+ for {
+ // Get pmas for this vma.
+ vsegAR := vseg.Range().Intersect(ar)
+ vma := vseg.ValuePtr()
+ pmaLoop:
+ for {
+ switch {
+ case pgap.Ok() && pgap.Start() < vsegAR.End:
+ // Need a pma here.
+ optAR := vseg.Range().Intersect(pgap.Range())
+ if checkInvariants {
+ if optAR.Length() <= 0 {
+ panic(fmt.Sprintf("vseg %v and pgap %v do not overlap", vseg, pgap))
+ }
+ }
+ if vma.mappable == nil {
+ // Private anonymous mappings get pmas by allocating.
+ allocAR := optAR.Intersect(maskAR)
+ fr, err := mf.Allocate(uint64(allocAR.Length()), usage.Anonymous)
+ if err != nil {
+ return pstart, pgap, err
+ }
+ if checkInvariants {
+ if !fr.WellFormed() || fr.Length() != uint64(allocAR.Length()) {
+ panic(fmt.Sprintf("Allocate(%v) returned invalid FileRange %v", allocAR.Length(), fr))
+ }
+ }
+ mm.addRSSLocked(allocAR)
+ mm.incPrivateRef(fr)
+ mf.IncRef(fr)
+ pseg, pgap = mm.pmas.Insert(pgap, allocAR, pma{
+ file: mf,
+ off: fr.Start,
+ translatePerms: usermem.AnyAccess,
+ effectivePerms: vma.effectivePerms,
+ maxPerms: vma.maxPerms,
+ // Since we just allocated this memory and have the
+ // only reference, the new pma does not need
+ // copy-on-write.
+ private: true,
+ }).NextNonEmpty()
+ pstart = pmaIterator{} // iterators invalidated
+ } else {
+ // Other mappings get pmas by translating.
+ optMR := vseg.mappableRangeOf(optAR)
+ reqAR := optAR.Intersect(ar)
+ reqMR := vseg.mappableRangeOf(reqAR)
+ perms := at
+ if vma.private {
+ // This pma will be copy-on-write; don't require write
+ // permission, but do require read permission to
+ // facilitate the copy.
+ //
+ // If at.Write is true, we will need to break
+ // copy-on-write immediately, which occurs after
+ // translation below.
+ perms.Read = true
+ perms.Write = false
+ }
+ ts, err := vma.mappable.Translate(ctx, reqMR, optMR, perms)
+ if checkInvariants {
+ if err := memmap.CheckTranslateResult(reqMR, optMR, perms, ts, err); err != nil {
+ panic(fmt.Sprintf("Mappable(%T).Translate(%v, %v, %v): %v", vma.mappable, reqMR, optMR, perms, err))
+ }
+ }
+ // Install a pma for each translation.
+ if len(ts) == 0 {
+ return pstart, pgap, err
+ }
+ pstart = pmaIterator{} // iterators invalidated
+ for _, t := range ts {
+ newpmaAR := vseg.addrRangeOf(t.Source)
+ newpma := pma{
+ file: t.File,
+ off: t.Offset,
+ translatePerms: t.Perms,
+ effectivePerms: vma.effectivePerms.Intersect(t.Perms),
+ maxPerms: vma.maxPerms.Intersect(t.Perms),
+ }
+ if vma.private {
+ newpma.effectivePerms.Write = false
+ newpma.maxPerms.Write = false
+ newpma.needCOW = true
+ }
+ mm.addRSSLocked(newpmaAR)
+ t.File.IncRef(t.FileRange())
+ // This is valid because memmap.Mappable.Translate is
+ // required to return Translations in increasing
+ // Translation.Source order.
+ pseg = mm.pmas.Insert(pgap, newpmaAR, newpma)
+ pgap = pseg.NextGap()
+ }
+ // The error returned by Translate is only significant if
+ // it occurred before ar.End.
+ if err != nil && vseg.addrRangeOf(ts[len(ts)-1].Source).End < ar.End {
+ return pstart, pgap, err
+ }
+ // Rewind pseg to the first pma inserted and continue the
+ // loop to check if we need to break copy-on-write.
+ pseg, pgap = mm.findOrSeekPrevUpperBoundPMA(vseg.addrRangeOf(ts[0].Source).Start, pgap), pmaGapIterator{}
+ continue
+ }
+
+ case pseg.Ok() && pseg.Start() < vsegAR.End:
+ oldpma := pseg.ValuePtr()
+ if at.Write && mm.isPMACopyOnWriteLocked(vseg, pseg) {
+ // Break copy-on-write by copying.
+ if checkInvariants {
+ if !oldpma.maxPerms.Read {
+ panic(fmt.Sprintf("pma %v needs to be copied for writing, but is not readable: %v", pseg.Range(), oldpma))
+ }
+ }
+ // The majority of copy-on-write breaks on executable pages
+ // come from:
+ //
+ // - The ELF loader, which must zero out bytes on the last
+ // page of each segment after the end of the segment.
+ //
+ // - gdb's use of ptrace to insert breakpoints.
+ //
+ // Neither of these cases has enough spatial locality to
+ // benefit from copying nearby pages, so if the vma is
+ // executable, only copy the pages required.
+ var copyAR usermem.AddrRange
+ if vseg.ValuePtr().effectivePerms.Execute {
+ copyAR = pseg.Range().Intersect(ar)
+ } else {
+ copyAR = pseg.Range().Intersect(maskAR)
+ }
+ // Get internal mappings from the pma to copy from.
+ if err := pseg.getInternalMappingsLocked(); err != nil {
+ return pstart, pseg.PrevGap(), err
+ }
+ // Copy contents.
+ fr, err := mf.AllocateAndFill(uint64(copyAR.Length()), usage.Anonymous, &safemem.BlockSeqReader{mm.internalMappingsLocked(pseg, copyAR)})
+ if _, ok := err.(safecopy.BusError); ok {
+ // If we got SIGBUS during the copy, deliver SIGBUS to
+ // userspace (instead of SIGSEGV) if we're breaking
+ // copy-on-write due to application page fault.
+ err = &memmap.BusError{err}
+ }
+ if fr.Length() == 0 {
+ return pstart, pseg.PrevGap(), err
+ }
+ // Unmap all of maskAR, not just copyAR, to minimize host
+ // syscalls. AddressSpace mappings must be removed before
+ // mm.decPrivateRef().
+ if !didUnmapAS {
+ mm.unmapASLocked(maskAR)
+ didUnmapAS = true
+ }
+ // Replace the pma with a copy in the part of the address
+ // range where copying was successful. This doesn't change
+ // RSS.
+ copyAR.End = copyAR.Start + usermem.Addr(fr.Length())
+ if copyAR != pseg.Range() {
+ pseg = mm.pmas.Isolate(pseg, copyAR)
+ pstart = pmaIterator{} // iterators invalidated
+ }
+ oldpma = pseg.ValuePtr()
+ if oldpma.private {
+ mm.decPrivateRef(pseg.fileRange())
+ }
+ oldpma.file.DecRef(pseg.fileRange())
+ mm.incPrivateRef(fr)
+ mf.IncRef(fr)
+ oldpma.file = mf
+ oldpma.off = fr.Start
+ oldpma.translatePerms = usermem.AnyAccess
+ oldpma.effectivePerms = vma.effectivePerms
+ oldpma.maxPerms = vma.maxPerms
+ oldpma.needCOW = false
+ oldpma.private = true
+ oldpma.internalMappings = safemem.BlockSeq{}
+ // Try to merge the pma with its neighbors.
+ if prev := pseg.PrevSegment(); prev.Ok() {
+ if merged := mm.pmas.Merge(prev, pseg); merged.Ok() {
+ pseg = merged
+ pstart = pmaIterator{} // iterators invalidated
+ }
+ }
+ if next := pseg.NextSegment(); next.Ok() {
+ if merged := mm.pmas.Merge(pseg, next); merged.Ok() {
+ pseg = merged
+ pstart = pmaIterator{} // iterators invalidated
+ }
+ }
+ // The error returned by AllocateAndFill is only
+ // significant if it occurred before ar.End.
+ if err != nil && pseg.End() < ar.End {
+ return pstart, pseg.NextGap(), err
+ }
+ // Ensure pseg and pgap are correct for the next iteration
+ // of the loop.
+ pseg, pgap = pseg.NextNonEmpty()
+ } else if !oldpma.translatePerms.SupersetOf(at) {
+ // Get new pmas (with sufficient permissions) by calling
+ // memmap.Mappable.Translate again.
+ if checkInvariants {
+ if oldpma.private {
+ panic(fmt.Sprintf("private pma %v has non-maximal pma.translatePerms: %v", pseg.Range(), oldpma))
+ }
+ }
+ // Allow the entire pma to be replaced.
+ optAR := pseg.Range()
+ optMR := vseg.mappableRangeOf(optAR)
+ reqAR := optAR.Intersect(ar)
+ reqMR := vseg.mappableRangeOf(reqAR)
+ perms := oldpma.translatePerms.Union(at)
+ ts, err := vma.mappable.Translate(ctx, reqMR, optMR, perms)
+ if checkInvariants {
+ if err := memmap.CheckTranslateResult(reqMR, optMR, perms, ts, err); err != nil {
+ panic(fmt.Sprintf("Mappable(%T).Translate(%v, %v, %v): %v", vma.mappable, reqMR, optMR, perms, err))
+ }
+ }
+ // Remove the part of the existing pma covered by new
+ // Translations, then insert new pmas. This doesn't change
+ // RSS. Note that we don't need to call unmapASLocked: any
+ // existing AddressSpace mappings are still valid (though
+ // less permissive than the new pmas indicate) until
+ // Invalidate is called, and will be replaced by future
+ // calls to mapASLocked.
+ if len(ts) == 0 {
+ return pstart, pseg.PrevGap(), err
+ }
+ transMR := memmap.MappableRange{ts[0].Source.Start, ts[len(ts)-1].Source.End}
+ transAR := vseg.addrRangeOf(transMR)
+ pseg = mm.pmas.Isolate(pseg, transAR)
+ pseg.ValuePtr().file.DecRef(pseg.fileRange())
+ pgap = mm.pmas.Remove(pseg)
+ pstart = pmaIterator{} // iterators invalidated
+ for _, t := range ts {
+ newpmaAR := vseg.addrRangeOf(t.Source)
+ newpma := pma{
+ file: t.File,
+ off: t.Offset,
+ translatePerms: t.Perms,
+ effectivePerms: vma.effectivePerms.Intersect(t.Perms),
+ maxPerms: vma.maxPerms.Intersect(t.Perms),
+ }
+ if vma.private {
+ newpma.effectivePerms.Write = false
+ newpma.maxPerms.Write = false
+ newpma.needCOW = true
+ }
+ t.File.IncRef(t.FileRange())
+ pseg = mm.pmas.Insert(pgap, newpmaAR, newpma)
+ pgap = pseg.NextGap()
+ }
+ // The error returned by Translate is only significant if
+ // it occurred before ar.End.
+ if err != nil && pseg.End() < ar.End {
+ return pstart, pgap, err
+ }
+ // Ensure pseg and pgap are correct for the next iteration
+ // of the loop.
+ if pgap.Range().Length() == 0 {
+ pseg, pgap = pgap.NextSegment(), pmaGapIterator{}
+ } else {
+ pseg = pmaIterator{}
+ }
+ } else {
+ // We have a usable pma; continue.
+ pseg, pgap = pseg.NextNonEmpty()
+ }
+
+ default:
+ break pmaLoop
+ }
+ }
+ // Go to the next vma.
+ if ar.End <= vseg.End() {
+ if pgap.Ok() {
+ return pstart, pgap, nil
+ }
+ return pstart, pseg.PrevGap(), nil
+ }
+ vseg = vseg.NextSegment()
+ }
+}
+
+const (
+ // When memory is allocated for a private pma, align the allocated address
+ // range to a privateAllocUnit boundary when possible. Larger values of
+ // privateAllocUnit may reduce page faults by allowing fewer, larger pmas
+ // to be mapped, but may result in larger amounts of wasted memory in the
+ // presence of fragmentation. privateAllocUnit must be a power-of-2
+ // multiple of usermem.PageSize.
+ privateAllocUnit = usermem.HugePageSize
+
+ privateAllocMask = privateAllocUnit - 1
+)
+
+func privateAligned(ar usermem.AddrRange) usermem.AddrRange {
+ aligned := usermem.AddrRange{ar.Start &^ privateAllocMask, ar.End}
+ if end := (ar.End + privateAllocMask) &^ privateAllocMask; end >= ar.End {
+ aligned.End = end
+ }
+ if checkInvariants {
+ if !aligned.IsSupersetOf(ar) {
+ panic(fmt.Sprintf("aligned AddrRange %#v is not a superset of ar %#v", aligned, ar))
+ }
+ }
+ return aligned
+}
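+
+// privateAlignedExample is an illustrative sketch, not part of the original
+// change: with privateAllocUnit = usermem.HugePageSize (2MB), a range
+// straddling a huge-page boundary expands to whole 2MB units; End is only
+// rounded up when doing so does not overflow.
+func privateAlignedExample() usermem.AddrRange {
+	ar := usermem.AddrRange{Start: 0x201000, End: 0x202000}
+	return privateAligned(ar) // == usermem.AddrRange{0x200000, 0x400000}
+}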
+
+// isPMACopyOnWriteLocked returns true if the contents of the pma represented
+// by pseg must be copied to a new private pma to be written to.
+//
+// If the pma is a copy-on-write private pma, and holds the only reference on
+// the memory it maps, isPMACopyOnWriteLocked will take ownership of the memory
+// and update the pma to indicate that it does not require copy-on-write.
+//
+// Preconditions: vseg.Range().IsSupersetOf(pseg.Range()). mm.mappingMu must be
+// locked. mm.activeMu must be locked for writing.
+func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterator) bool {
+ pma := pseg.ValuePtr()
+ if !pma.needCOW {
+ return false
+ }
+ if !pma.private {
+ return true
+ }
+ // If we have the only reference on private memory to be copied, just take
+ // ownership of it instead of copying. If we do hold the only reference,
+ // additional references can only be taken by mm.Fork(), which is excluded
+ // by mm.activeMu, so this isn't racy.
+ mm.privateRefs.mu.Lock()
+ defer mm.privateRefs.mu.Unlock()
+ fr := pseg.fileRange()
+ // This check relies on mm.privateRefs.refs being kept fully merged.
+ rseg := mm.privateRefs.refs.FindSegment(fr.Start)
+ if rseg.Ok() && rseg.Value() == 1 && fr.End <= rseg.End() {
+ pma.needCOW = false
+ // pma.private => pma.translatePerms == usermem.AnyAccess
+ vma := vseg.ValuePtr()
+ pma.effectivePerms = vma.effectivePerms
+ pma.maxPerms = vma.maxPerms
+ return false
+ }
+ return true
+}
+
+// Invalidate implements memmap.MappingSpace.Invalidate.
+func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+ if mm.captureInvalidations {
+ mm.capturedInvalidations = append(mm.capturedInvalidations, invalidateArgs{ar, opts})
+ return
+ }
+ mm.invalidateLocked(ar, opts.InvalidatePrivate, true)
+}
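+
+// invalidateExample is an illustrative sketch, not part of the original
+// change: a memmap.Mappable typically calls Invalidate on each MappingSpace
+// mapping it when its contents change (e.g. on truncation), forcing pmas and
+// AddressSpace mappings for the range to be dropped and re-translated on the
+// next access.
+func invalidateExample(ms memmap.MappingSpace, ar usermem.AddrRange) {
+	ms.Invalidate(ar, memmap.InvalidateOpts{InvalidatePrivate: false})
+}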
+
+// invalidateLocked removes pmas and AddressSpace mappings of those pmas for
+// addresses in ar.
+//
+// Preconditions: mm.activeMu must be locked for writing. ar.Length() != 0. ar
+// must be page-aligned.
+func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ var didUnmapAS bool
+ pseg := mm.pmas.LowerBoundSegment(ar.Start)
+ for pseg.Ok() && pseg.Start() < ar.End {
+ pma := pseg.ValuePtr()
+ if (invalidatePrivate && pma.private) || (invalidateShared && !pma.private) {
+ pseg = mm.pmas.Isolate(pseg, ar)
+ pma = pseg.ValuePtr()
+ if !didUnmapAS {
+ // Unmap all of ar, not just pseg.Range(), to minimize host
+ // syscalls. AddressSpace mappings must be removed before
+ // mm.decPrivateRef().
+ mm.unmapASLocked(ar)
+ didUnmapAS = true
+ }
+ if pma.private {
+ mm.decPrivateRef(pseg.fileRange())
+ }
+ mm.removeRSSLocked(pseg.Range())
+ pma.file.DecRef(pseg.fileRange())
+ pseg = mm.pmas.Remove(pseg).NextSegment()
+ } else {
+ pseg = pseg.NextSegment()
+ }
+ }
+}
+
+// Pin returns the platform.File ranges currently mapped by addresses in ar in
+// mm, acquiring a reference on the returned ranges which the caller must
+// release by calling Unpin. If not all addresses are mapped, Pin returns a
+// non-nil error. Note that Pin may return both a non-empty slice of
+// PinnedRanges and a non-nil error.
+//
+// Pin does not prevent mapped ranges from changing, making it unsuitable for
+// most I/O. It should only be used in contexts that would use get_user_pages()
+// in the Linux kernel.
+//
+// Preconditions: ar.Length() != 0. ar must be page-aligned.
+func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // Ensure that we have usable vmas.
+ mm.mappingMu.RLock()
+ vseg, vend, verr := mm.getVMAsLocked(ctx, ar, at, ignorePermissions)
+ if vendaddr := vend.Start(); vendaddr < ar.End {
+ if vendaddr <= ar.Start {
+ mm.mappingMu.RUnlock()
+ return nil, verr
+ }
+ ar.End = vendaddr
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pseg, pend, perr := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if pendaddr := pend.Start(); pendaddr < ar.End {
+ if pendaddr <= ar.Start {
+ mm.activeMu.Unlock()
+ return nil, perr
+ }
+ ar.End = pendaddr
+ }
+
+ // Gather pmas.
+ var prs []PinnedRange
+ for pseg.Ok() && pseg.Start() < ar.End {
+ psar := pseg.Range().Intersect(ar)
+ f := pseg.ValuePtr().file
+ fr := pseg.fileRangeOf(psar)
+ f.IncRef(fr)
+ prs = append(prs, PinnedRange{
+ Source: psar,
+ File: f,
+ Offset: fr.Start,
+ })
+ pseg = pseg.NextSegment()
+ }
+ mm.activeMu.Unlock()
+
+ // Return the first error in order of progress through ar.
+ if perr != nil {
+ return prs, perr
+ }
+ return prs, verr
+}
+
+// PinnedRanges are returned by MemoryManager.Pin.
+type PinnedRange struct {
+ // Source is the corresponding range of addresses.
+ Source usermem.AddrRange
+
+ // File is the mapped file.
+ File platform.File
+
+ // Offset is the offset into File at which this PinnedRange begins.
+ Offset uint64
+}
+
+// FileRange returns the platform.File offsets mapped by pr.
+func (pr PinnedRange) FileRange() platform.FileRange {
+ return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
+}
+
+// Unpin releases the reference held by prs.
+func Unpin(prs []PinnedRange) {
+ for i := range prs {
+ prs[i].File.DecRef(prs[i].FileRange())
+ }
+}
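+
+// pinExample is an illustrative sketch, not part of the original change: a
+// get_user_pages()-style caller pins ar for reading, walks the pinned file
+// ranges, and releases the references when done. Note that Pin may return
+// pinned ranges together with a non-nil error.
+func pinExample(ctx context.Context, mm *MemoryManager, ar usermem.AddrRange) error {
+	prs, err := mm.Pin(ctx, ar, usermem.Read, false /* ignorePermissions */)
+	defer Unpin(prs) // drop the references acquired by Pin
+	for _, pr := range prs {
+		_ = pr.FileRange() // the platform.File offsets backing pr.Source
+	}
+	return err
+}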
+
+// movePMAsLocked moves all pmas in oldAR to newAR.
+//
+// Preconditions: mm.activeMu must be locked for writing. oldAR.Length() != 0.
+// oldAR.Length() <= newAR.Length(). !oldAR.Overlaps(newAR).
+// mm.pmas.IsEmptyRange(newAR). oldAR and newAR must be page-aligned.
+func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
+ if checkInvariants {
+ if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() {
+ panic(fmt.Sprintf("invalid oldAR: %v", oldAR))
+ }
+ if !newAR.WellFormed() || newAR.Length() <= 0 || !newAR.IsPageAligned() {
+ panic(fmt.Sprintf("invalid newAR: %v", newAR))
+ }
+ if oldAR.Length() > newAR.Length() {
+ panic(fmt.Sprintf("old address range %v may contain pmas that will not fit in new address range %v", oldAR, newAR))
+ }
+ if oldAR.Overlaps(newAR) {
+ panic(fmt.Sprintf("old and new address ranges overlap: %v, %v", oldAR, newAR))
+ }
+ // mm.pmas.IsEmptyRange is checked by mm.pmas.Insert.
+ }
+
+ type movedPMA struct {
+ oldAR usermem.AddrRange
+ pma pma
+ }
+ var movedPMAs []movedPMA
+ pseg := mm.pmas.LowerBoundSegment(oldAR.Start)
+ for pseg.Ok() && pseg.Start() < oldAR.End {
+ pseg = mm.pmas.Isolate(pseg, oldAR)
+ movedPMAs = append(movedPMAs, movedPMA{
+ oldAR: pseg.Range(),
+ pma: pseg.Value(),
+ })
+ pseg = mm.pmas.Remove(pseg).NextSegment()
+ // No RSS change is needed since we're re-inserting the same pmas
+ // below.
+ }
+
+ off := newAR.Start - oldAR.Start
+ pgap := mm.pmas.FindGap(newAR.Start)
+ for i := range movedPMAs {
+ mpma := &movedPMAs[i]
+ pmaNewAR := usermem.AddrRange{mpma.oldAR.Start + off, mpma.oldAR.End + off}
+ pgap = mm.pmas.Insert(pgap, pmaNewAR, mpma.pma).NextGap()
+ }
+
+ mm.unmapASLocked(oldAR)
+}
+
+// getPMAInternalMappingsLocked ensures that pmas for all addresses in ar have
+// cached internal mappings. It returns:
+//
+// - An iterator to the gap after the last pma with internal mappings
+// containing an address in ar. If internal mappings exist for no addresses in
+// ar, the iterator is to a gap that begins before ar.Start.
+//
+// - An error that is non-nil if internal mappings exist for only a subset of
+// ar.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+// pseg.Range().Contains(ar.Start). pmas must exist for all addresses in ar.
+// ar.Length() != 0.
+//
+// Postconditions: getPMAInternalMappingsLocked does not invalidate iterators
+// into mm.pmas.
+func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !pseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial pma %v does not cover start of ar %v", pseg.Range(), ar))
+ }
+ }
+
+ for {
+ if err := pseg.getInternalMappingsLocked(); err != nil {
+ return pseg.PrevGap(), err
+ }
+ if ar.End <= pseg.End() {
+ return pseg.NextGap(), nil
+ }
+ pseg, _ = pseg.NextNonEmpty()
+ }
+}
+
+// getVecPMAInternalMappingsLocked ensures that pmas for all addresses in ars
+// have cached internal mappings. It returns the subset of ars for which
+// internal mappings exist. If this is not equal to ars, it returns a non-nil
+// error explaining why.
+//
+// Preconditions: mm.activeMu must be locked for writing. pmas must exist for
+// all addresses in ars.
+//
+// Postconditions: getVecPMAInternalMappingsLocked does not invalidate iterators
+// into mm.pmas.
+func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSeq) (usermem.AddrRangeSeq, error) {
+ for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
+ ar := arsit.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ if pend, err := mm.getPMAInternalMappingsLocked(mm.pmas.FindSegment(ar.Start), ar); err != nil {
+ return truncatedAddrRangeSeq(ars, arsit, pend.Start()), err
+ }
+ }
+ return ars, nil
+}
+
+// internalMappingsLocked returns internal mappings for addresses in ar.
+//
+// Preconditions: mm.activeMu must be locked. Internal mappings must have been
+// previously established for all addresses in ar. ar.Length() != 0.
+// pseg.Range().Contains(ar.Start).
+func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !pseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial pma %v does not cover start of ar %v", pseg.Range(), ar))
+ }
+ }
+
+ if ar.End <= pseg.End() {
+ // Since only one pma is involved, we can use pma.internalMappings
+ // directly, avoiding a slice allocation.
+ offset := uint64(ar.Start - pseg.Start())
+ return pseg.ValuePtr().internalMappings.DropFirst64(offset).TakeFirst64(uint64(ar.Length()))
+ }
+
+ var ims []safemem.Block
+ for {
+ pr := pseg.Range().Intersect(ar)
+ for pims := pseg.ValuePtr().internalMappings.DropFirst64(uint64(pr.Start - pseg.Start())).TakeFirst64(uint64(pr.Length())); !pims.IsEmpty(); pims = pims.Tail() {
+ ims = append(ims, pims.Head())
+ }
+ if ar.End <= pseg.End() {
+ break
+ }
+ pseg = pseg.NextSegment()
+ }
+ return safemem.BlockSeqFromSlice(ims)
+}
+
+// vecInternalMappingsLocked returns internal mappings for addresses in ars.
+//
+// Preconditions: mm.activeMu must be locked. Internal mappings must have been
+// previously established for all addresses in ars.
+func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) safemem.BlockSeq {
+ var ims []safemem.Block
+ for ; !ars.IsEmpty(); ars = ars.Tail() {
+ ar := ars.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ for pims := mm.internalMappingsLocked(mm.pmas.FindSegment(ar.Start), ar); !pims.IsEmpty(); pims = pims.Tail() {
+ ims = append(ims, pims.Head())
+ }
+ }
+ return safemem.BlockSeqFromSlice(ims)
+}
+
+// incPrivateRef acquires a reference on private pages in fr.
+func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
+ mm.privateRefs.mu.Lock()
+ defer mm.privateRefs.mu.Unlock()
+ refSet := &mm.privateRefs.refs
+ seg, gap := refSet.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = refSet.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ seg, gap = refSet.InsertWithoutMerging(gap, gap.Range().Intersect(fr), 1).NextNonEmpty()
+ default:
+ refSet.MergeAdjacent(fr)
+ return
+ }
+ }
+}
+
+// decPrivateRef releases a reference on private pages in fr.
+func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) {
+ var freed []platform.FileRange
+
+ mm.privateRefs.mu.Lock()
+ refSet := &mm.privateRefs.refs
+ seg := refSet.LowerBoundSegment(fr.Start)
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = refSet.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ freed = append(freed, seg.Range())
+ seg = refSet.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ refSet.MergeAdjacent(fr)
+ mm.privateRefs.mu.Unlock()
+
+ mf := mm.mfp.MemoryFile()
+ for _, fr := range freed {
+ mf.DecRef(fr)
+ }
+}
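+
+// Worked example (illustrative): if privateRefs.refs holds a single segment
+// [0x0, 0x3000) with count 1, incPrivateRef([0x1000, 0x2000)) isolates the
+// middle page, leaving counts:
+//
+//	[0x0, 0x1000) -> 1, [0x1000, 0x2000) -> 2, [0x2000, 0x3000) -> 1
+//
+// A matching decPrivateRef([0x1000, 0x2000)) restores count 1 throughout and
+// MergeAdjacent re-coalesces the three segments; only when a count drops from
+// 1 to 0 is the range freed back to the MemoryFile via mf.DecRef.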
+
+// addRSSLocked updates the current and maximum resident set size of a
+// MemoryManager to reflect the insertion of a pma at ar.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+func (mm *MemoryManager) addRSSLocked(ar usermem.AddrRange) {
+ mm.curRSS += uint64(ar.Length())
+ if mm.curRSS > mm.maxRSS {
+ mm.maxRSS = mm.curRSS
+ }
+}
+
+// removeRSSLocked updates the current resident set size of a MemoryManager to
+// reflect the removal of a pma at ar.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+func (mm *MemoryManager) removeRSSLocked(ar usermem.AddrRange) {
+ mm.curRSS -= uint64(ar.Length())
+}
+
+// pmaSetFunctions implements segment.Functions for pmaSet.
+type pmaSetFunctions struct{}
+
+func (pmaSetFunctions) MinKey() usermem.Addr {
+ return 0
+}
+
+func (pmaSetFunctions) MaxKey() usermem.Addr {
+ return ^usermem.Addr(0)
+}
+
+func (pmaSetFunctions) ClearValue(pma *pma) {
+ pma.file = nil
+ pma.internalMappings = safemem.BlockSeq{}
+}
+
+func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRange, pma2 pma) (pma, bool) {
+ if pma1.file != pma2.file ||
+ pma1.off+uint64(ar1.Length()) != pma2.off ||
+ pma1.translatePerms != pma2.translatePerms ||
+ pma1.effectivePerms != pma2.effectivePerms ||
+ pma1.maxPerms != pma2.maxPerms ||
+ pma1.needCOW != pma2.needCOW ||
+ pma1.private != pma2.private {
+ return pma{}, false
+ }
+
+ // Discard internal mappings instead of trying to merge them, since merging
+ // them requires an allocation and getting them again from the
+ // platform.File might not.
+ pma1.internalMappings = safemem.BlockSeq{}
+ return pma1, true
+}
+
+func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (pma, pma) {
+ newlen1 := uint64(split - ar.Start)
+ p2 := p
+ p2.off += newlen1
+ if !p.internalMappings.IsEmpty() {
+ p.internalMappings = p.internalMappings.TakeFirst64(newlen1)
+ p2.internalMappings = p2.internalMappings.DropFirst64(newlen1)
+ }
+ return p, p2
+}
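+
+// Worked example (illustrative): splitting a pma covering [0x1000, 0x3000)
+// with file offset 0x8000 at split address 0x2000 gives newlen1 = 0x1000: the
+// left half keeps off = 0x8000, the right half gets off = 0x9000, and any
+// cached internal mappings are divided at the same byte offset. Merge is the
+// inverse: two pmas coalesce only if their file offsets are contiguous and
+// all other fields compared above are equal.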
+
+// findOrSeekPrevUpperBoundPMA returns mm.pmas.UpperBoundSegment(addr), but may do
+// so by scanning linearly backward from pgap.
+//
+// Preconditions: mm.activeMu must be locked. addr <= pgap.Start().
+func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr usermem.Addr, pgap pmaGapIterator) pmaIterator {
+ if checkInvariants {
+ if !pgap.Ok() {
+ panic("terminal pma iterator")
+ }
+ if addr > pgap.Start() {
+ panic(fmt.Sprintf("can't seek backward to %#x from %#x", addr, pgap.Start()))
+ }
+ }
+ // Optimistically check if pgap.PrevSegment() is the PMA we're looking for,
+ // which is the case if findOrSeekPrevUpperBoundPMA is called to find the
+ // start of a range containing only a single PMA.
+ if pseg := pgap.PrevSegment(); pseg.Start() <= addr {
+ return pseg
+ }
+ return mm.pmas.UpperBoundSegment(addr)
+}
+
+// getInternalMappingsLocked ensures that pseg.ValuePtr().internalMappings is
+// non-empty.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+func (pseg pmaIterator) getInternalMappingsLocked() error {
+ pma := pseg.ValuePtr()
+ if pma.internalMappings.IsEmpty() {
+ // This must use maxPerms (instead of perms) because some permission
+ // constraints are only visible to vmas; for example, mappings of
+ // read-only files have vma.maxPerms.Write unset, but this may not be
+ // visible to the memmap.Mappable.
+ perms := pma.maxPerms
+ // We will never execute application code through an internal mapping.
+ perms.Execute = false
+ ims, err := pma.file.MapInternal(pseg.fileRange(), perms)
+ if err != nil {
+ return err
+ }
+ pma.internalMappings = ims
+ }
+ return nil
+}
+
+func (pseg pmaIterator) fileRange() platform.FileRange {
+ return pseg.fileRangeOf(pseg.Range())
+}
+
+// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0.
+func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
+ if checkInvariants {
+ if !pseg.Ok() {
+ panic("terminal pma iterator")
+ }
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !pseg.Range().IsSupersetOf(ar) {
+ panic(fmt.Sprintf("ar %v out of bounds %v", ar, pseg.Range()))
+ }
+ }
+
+ pma := pseg.ValuePtr()
+ pstart := pseg.Start()
+ return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
+}
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
new file mode 100644
index 000000000..6efe5102b
--- /dev/null
+++ b/pkg/sentry/mm/procfs.go
@@ -0,0 +1,329 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // devMinorBits is the number of minor bits in a device number. Linux:
+ // include/linux/kdev_t.h:MINORBITS
+ devMinorBits = 20
+
+ vsyscallEnd = usermem.Addr(0xffffffffff601000)
+ vsyscallMapsEntry = "ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 [vsyscall]\n"
+ vsyscallSmapsEntry = vsyscallMapsEntry +
+ "Size: 4 kB\n" +
+ "Rss: 0 kB\n" +
+ "Pss: 0 kB\n" +
+ "Shared_Clean: 0 kB\n" +
+ "Shared_Dirty: 0 kB\n" +
+ "Private_Clean: 0 kB\n" +
+ "Private_Dirty: 0 kB\n" +
+ "Referenced: 0 kB\n" +
+ "Anonymous: 0 kB\n" +
+ "AnonHugePages: 0 kB\n" +
+ "Shared_Hugetlb: 0 kB\n" +
+ "Private_Hugetlb: 0 kB\n" +
+ "Swap: 0 kB\n" +
+ "SwapPss: 0 kB\n" +
+ "KernelPageSize: 4 kB\n" +
+ "MMUPageSize: 4 kB\n" +
+ "Locked: 0 kB\n" +
+ "VmFlags: rd ex \n"
+)
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (mm *MemoryManager) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadMapsDataInto is called by fsimpl/proc.mapsData.Generate to
+// implement /proc/[pid]/maps.
+func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var start usermem.Addr
+
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
+ }
+
+ // We always emulate vsyscall, so advertise it here. Everything about a
+ // vsyscall region is static, so just hard code the maps entry since we
+ // don't have a real vma backing it. The vsyscall region is at the end of
+ // the virtual address space so nothing should be mapped after it (if
+ // something is really mapped in the tiny ~10 MiB segment afterwards, we'll
+ // get the sorting on the maps file wrong at worst; but that's not possible
+ // on any current platform).
+ //
+// There is no seqfile resumption handle here, so start is always 0 and the
+// vsyscall entry is always emitted.
+ if start != vsyscallEnd {
+ buf.WriteString(vsyscallMapsEntry)
+ }
+}
+
+// ReadMapsSeqFileData is called by fs/proc.mapsData.ReadSeqFileData to
+// implement /proc/[pid]/maps.
+func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var data []seqfile.SeqData
+ var start usermem.Addr
+ if handle != nil {
+ start = *handle.(*usermem.Addr)
+ }
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ vmaAddr := vseg.End()
+ data = append(data, seqfile.SeqData{
+ Buf: mm.vmaMapsEntryLocked(ctx, vseg),
+ Handle: &vmaAddr,
+ })
+ }
+
+ // We always emulate vsyscall, so advertise it here. Everything about a
+ // vsyscall region is static, so just hard code the maps entry since we
+ // don't have a real vma backing it. The vsyscall region is at the end of
+ // the virtual address space so nothing should be mapped after it (if
+ // something is really mapped in the tiny ~10 MiB segment afterwards, we'll
+ // get the sorting on the maps file wrong at worst; but that's not possible
+ // on any current platform).
+ //
+// Artificially adjust the seqfile handle so we only output the vsyscall
+// entry once.
+ if start != vsyscallEnd {
+ vmaAddr := vsyscallEnd
+ data = append(data, seqfile.SeqData{
+ Buf: []byte(vsyscallMapsEntry),
+ Handle: &vmaAddr,
+ })
+ }
+ return data, 1
+}
+
+// vmaMapsEntryLocked returns a /proc/[pid]/maps entry for the vma iterated by
+// vseg, including the trailing newline.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) vmaMapsEntryLocked(ctx context.Context, vseg vmaIterator) []byte {
+ var b bytes.Buffer
+ mm.appendVMAMapsEntryLocked(ctx, vseg, &b)
+ return b.Bytes()
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaIterator, b *bytes.Buffer) {
+ vma := vseg.ValuePtr()
+ private := "p"
+ if !vma.private {
+ private = "s"
+ }
+
+ var dev, ino uint64
+ if vma.id != nil {
+ dev = vma.id.DeviceID()
+ ino = vma.id.InodeID()
+ }
+ devMajor := uint32(dev >> devMinorBits)
+ devMinor := uint32(dev & ((1 << devMinorBits) - 1))
+
+ // Do not include the guard page: fs/proc/task_mmu.c:show_map_vma() =>
+ // stack_guard_page_start().
+ lineLen, _ := fmt.Fprintf(b, "%08x-%08x %s%s %08x %02x:%02x %d ",
+ vseg.Start(), vseg.End(), vma.realPerms, private, vma.off, devMajor, devMinor, ino)
+
+ // Figure out our filename or hint.
+ var s string
+ if vma.hint != "" {
+ s = vma.hint
+ } else if vma.id != nil {
+ // FIXME(jamieliu): We are holding mm.mappingMu here, which is
+ // consistent with Linux's holding mmap_sem in
+ // fs/proc/task_mmu.c:show_map_vma() => fs/seq_file.c:seq_file_path().
+ // However, it's not clear that fs.File.MappedName() is actually
+ // consistent with this lock order.
+ s = vma.id.MappedName(ctx)
+ }
+ if s != "" {
+		// Per Linux, we pad until the 74th character.
+ if pad := 73 - lineLen; pad > 0 {
+ b.WriteString(strings.Repeat(" ", pad))
+ }
+ b.WriteString(s)
+ }
+ b.WriteString("\n")
+}
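+
+// For reference, an entry emitted by appendVMAMapsEntryLocked has the same
+// shape as Linux's /proc/[pid]/maps lines, e.g. (illustrative values):
+//
+//	00400000-00452000 r-xp 00000000 08:01 2354173                    /usr/bin/app
+//
+// i.e. address range, permissions plus p/s, file offset, device major:minor,
+// inode number, and (padded to column 74) the mapped name or hint.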
+
+// ReadSmapsDataInto is called by fsimpl/proc.smapsData.Generate to
+// implement /proc/[pid]/smaps.
+func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffer) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var start usermem.Addr
+
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
+ }
+
+ // We always emulate vsyscall, so advertise it here. See
+ // ReadMapsSeqFileData for additional commentary.
+ if start != vsyscallEnd {
+ buf.WriteString(vsyscallSmapsEntry)
+ }
+}
+
+// ReadSmapsSeqFileData is called by fs/proc.smapsData.ReadSeqFileData to
+// implement /proc/[pid]/smaps.
+func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var data []seqfile.SeqData
+ var start usermem.Addr
+ if handle != nil {
+ start = *handle.(*usermem.Addr)
+ }
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ vmaAddr := vseg.End()
+ data = append(data, seqfile.SeqData{
+ Buf: mm.vmaSmapsEntryLocked(ctx, vseg),
+ Handle: &vmaAddr,
+ })
+ }
+
+ // We always emulate vsyscall, so advertise it here. See
+ // ReadMapsSeqFileData for additional commentary.
+ if start != vsyscallEnd {
+ vmaAddr := vsyscallEnd
+ data = append(data, seqfile.SeqData{
+ Buf: []byte(vsyscallSmapsEntry),
+ Handle: &vmaAddr,
+ })
+ }
+ return data, 1
+}
+
+// vmaSmapsEntryLocked returns a /proc/[pid]/smaps entry for the vma iterated
+// by vseg, including the trailing newline.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterator) []byte {
+ var b bytes.Buffer
+ mm.vmaSmapsEntryIntoLocked(ctx, vseg, &b)
+ return b.Bytes()
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) vmaSmapsEntryIntoLocked(ctx context.Context, vseg vmaIterator, b *bytes.Buffer) {
+ mm.appendVMAMapsEntryLocked(ctx, vseg, b)
+ vma := vseg.ValuePtr()
+
+	// We take mm.activeMu here in each call to vmaSmapsEntryIntoLocked, instead of
+ // requiring it to be locked as a precondition, to reduce the latency
+ // impact of reading /proc/[pid]/smaps on concurrent performance-sensitive
+ // operations requiring activeMu for writing like faults.
+ mm.activeMu.RLock()
+ var rss uint64
+ var anon uint64
+ vsegAR := vseg.Range()
+ for pseg := mm.pmas.LowerBoundSegment(vsegAR.Start); pseg.Ok() && pseg.Start() < vsegAR.End; pseg = pseg.NextSegment() {
+ psegAR := pseg.Range().Intersect(vsegAR)
+ size := uint64(psegAR.Length())
+ rss += size
+ if pseg.ValuePtr().private {
+ anon += size
+ }
+ }
+ mm.activeMu.RUnlock()
+
+ fmt.Fprintf(b, "Size: %8d kB\n", vseg.Range().Length()/1024)
+ fmt.Fprintf(b, "Rss: %8d kB\n", rss/1024)
+ // Currently we report PSS = RSS, i.e. we pretend each page mapped by a pma
+ // is only mapped by that pma. This avoids having to query memmap.Mappables
+ // for reference count information on each page. As a corollary, all pages
+ // are accounted as "private" whether or not the vma is private; compare
+ // Linux's fs/proc/task_mmu.c:smaps_account().
+ fmt.Fprintf(b, "Pss: %8d kB\n", rss/1024)
+ fmt.Fprintf(b, "Shared_Clean: %8d kB\n", 0)
+ fmt.Fprintf(b, "Shared_Dirty: %8d kB\n", 0)
+ // Pretend that all pages are dirty if the vma is writable, and clean otherwise.
+ clean := rss
+ if vma.effectivePerms.Write {
+ clean = 0
+ }
+ fmt.Fprintf(b, "Private_Clean: %8d kB\n", clean/1024)
+ fmt.Fprintf(b, "Private_Dirty: %8d kB\n", (rss-clean)/1024)
+ // Pretend that all pages are "referenced" (recently touched).
+ fmt.Fprintf(b, "Referenced: %8d kB\n", rss/1024)
+ fmt.Fprintf(b, "Anonymous: %8d kB\n", anon/1024)
+ // Hugepages (hugetlb and THP) are not implemented.
+ fmt.Fprintf(b, "AnonHugePages: %8d kB\n", 0)
+ fmt.Fprintf(b, "Shared_Hugetlb: %8d kB\n", 0)
+ fmt.Fprintf(b, "Private_Hugetlb: %7d kB\n", 0)
+ // Swap is not implemented.
+ fmt.Fprintf(b, "Swap: %8d kB\n", 0)
+ fmt.Fprintf(b, "SwapPss: %8d kB\n", 0)
+ fmt.Fprintf(b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024)
+ fmt.Fprintf(b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024)
+ locked := rss
+ if vma.mlockMode == memmap.MLockNone {
+ locked = 0
+ }
+ fmt.Fprintf(b, "Locked: %8d kB\n", locked/1024)
+
+ b.WriteString("VmFlags: ")
+ if vma.realPerms.Read {
+ b.WriteString("rd ")
+ }
+ if vma.realPerms.Write {
+ b.WriteString("wr ")
+ }
+ if vma.realPerms.Execute {
+ b.WriteString("ex ")
+ }
+ if vma.canWriteMappableLocked() { // VM_SHARED
+ b.WriteString("sh ")
+ }
+ if vma.maxPerms.Read {
+ b.WriteString("mr ")
+ }
+ if vma.maxPerms.Write {
+ b.WriteString("mw ")
+ }
+ if vma.maxPerms.Execute {
+ b.WriteString("me ")
+ }
+ if !vma.private { // VM_MAYSHARE
+ b.WriteString("ms ")
+ }
+ if vma.growsDown {
+ b.WriteString("gd ")
+ }
+ if vma.mlockMode != memmap.MLockNone { // VM_LOCKED
+ b.WriteString("lo ")
+ }
+ if vma.mlockMode == memmap.MLockLazy { // VM_LOCKONFAULT
+ b.WriteString("?? ") // no explicit encoding in fs/proc/task_mmu.c:show_smap_vma_flags()
+ }
+ if vma.private && vma.effectivePerms.Write { // VM_ACCOUNT
+ b.WriteString("ac ")
+ }
+ b.WriteString("\n")
+}
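+
+// For example (illustrative), a private read/write anonymous mapping with
+// maxPerms rwx and no mlock produces the flags line:
+//
+//	VmFlags: rd wr mr mw me ac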
diff --git a/pkg/sentry/mm/save_restore.go b/pkg/sentry/mm/save_restore.go
new file mode 100644
index 000000000..f56215d9a
--- /dev/null
+++ b/pkg/sentry/mm/save_restore.go
@@ -0,0 +1,57 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// InvalidateUnsavable invokes memmap.Mappable.InvalidateUnsavable on all
+// Mappables mapped by mm.
+func (mm *MemoryManager) InvalidateUnsavable(ctx context.Context) error {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ if vma := vseg.ValuePtr(); vma.mappable != nil {
+ if err := vma.mappable.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// beforeSave is invoked by stateify.
+func (mm *MemoryManager) beforeSave() {
+ mf := mm.mfp.MemoryFile()
+ for pseg := mm.pmas.FirstSegment(); pseg.Ok(); pseg = pseg.NextSegment() {
+ if pma := pseg.ValuePtr(); pma.file != mf {
+ // InvalidateUnsavable should have caused all such pmas to be
+ // invalidated.
+ panic(fmt.Sprintf("Can't save pma %#v with non-MemoryFile of type %T:\n%s", pseg.Range(), pma.file, mm))
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (mm *MemoryManager) afterLoad() {
+ mm.haveASIO = mm.p.SupportsAddressSpaceIO()
+ mf := mm.mfp.MemoryFile()
+ for pseg := mm.pmas.FirstSegment(); pseg.Ok(); pseg = pseg.NextSegment() {
+ pseg.ValuePtr().file = mf
+ }
+}
diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go
new file mode 100644
index 000000000..6432731d4
--- /dev/null
+++ b/pkg/sentry/mm/shm.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/shm"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// DetachShm unmaps a sysv shared memory segment.
+func (mm *MemoryManager) DetachShm(ctx context.Context, addr usermem.Addr) error {
+ if addr != addr.RoundDown() {
+ // "... shmaddr is not aligned on a page boundary." - man shmdt(2)
+ return syserror.EINVAL
+ }
+
+ var detached *shm.Shm
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+
+ // Find and remove the first vma containing an address >= addr that maps a
+ // segment originally attached at addr.
+ vseg := mm.vmas.LowerBoundSegment(addr)
+ for vseg.Ok() {
+ vma := vseg.ValuePtr()
+ if shm, ok := vma.mappable.(*shm.Shm); ok && vseg.Start() >= addr && uint64(vseg.Start()-addr) == vma.off {
+ detached = shm
+ vseg = mm.unmapLocked(ctx, vseg.Range()).NextSegment()
+ break
+ } else {
+ vseg = vseg.NextSegment()
+ }
+ }
+
+ if detached == nil {
+ // There is no shared memory segment attached at addr.
+ return syserror.EINVAL
+ }
+
+ // Remove all vmas that could have been created by the same attach.
+ end := addr + usermem.Addr(detached.EffectiveSize())
+ for vseg.Ok() && vseg.End() <= end {
+ vma := vseg.ValuePtr()
+ if vma.mappable == detached && uint64(vseg.Start()-addr) == vma.off {
+ vseg = mm.unmapLocked(ctx, vseg.Range()).NextSegment()
+ } else {
+ vseg = vseg.NextSegment()
+ }
+ }
+
+ return nil
+}
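+
+// Worked example (illustrative): if a 3-page segment is attached at addr A
+// and a later mprotect splits the attach into three vmas with offsets 0,
+// 0x1000, and 0x2000, DetachShm(ctx, A) matches the first vma by offset 0,
+// unmaps it, and then unmaps the remaining two because their offsets equal
+// their distance from A and they end within A + EffectiveSize().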
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
new file mode 100644
index 000000000..9ad52082d
--- /dev/null
+++ b/pkg/sentry/mm/special_mappable.go
@@ -0,0 +1,157 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SpecialMappable implements memmap.MappingIdentity and memmap.Mappable with
+// semantics similar to Linux's mm/mmap.c:_install_special_mapping(), except
+// that SpecialMappable takes ownership of the memory that it represents
+// (_install_special_mapping() does not).
+//
+// +stateify savable
+type SpecialMappable struct {
+ refs.AtomicRefCount
+
+ mfp pgalloc.MemoryFileProvider
+ fr platform.FileRange
+ name string
+}
+
+// NewSpecialMappable returns a SpecialMappable that owns fr, which represents
+// offsets in mfp.MemoryFile() that contain the SpecialMappable's data. The
+// SpecialMappable will use the given name in /proc/[pid]/maps.
+//
+// Preconditions: fr.Length() != 0.
+func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable {
+ m := SpecialMappable{mfp: mfp, fr: fr, name: name}
+ m.EnableLeakCheck("mm.SpecialMappable")
+ return &m
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+func (m *SpecialMappable) DecRef() {
+ m.AtomicRefCount.DecRefWithDestructor(func() {
+ m.mfp.MemoryFile().DecRef(m.fr)
+ })
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (m *SpecialMappable) MappedName(ctx context.Context) string {
+ return m.name
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (m *SpecialMappable) DeviceID() uint64 {
+ return 0
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (m *SpecialMappable) InodeID() uint64 {
+ return 0
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (m *SpecialMappable) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ // Linux: vm_file is NULL, causing msync to skip it entirely.
+ return nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (*SpecialMappable) AddMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) error {
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (*SpecialMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (*SpecialMappable) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error {
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (m *SpecialMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ var err error
+ if required.End > m.fr.Length() {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if source := optional.Intersect(memmap.MappableRange{0, m.fr.Length()}); source.Length() != 0 {
+ return []memmap.Translation{
+ {
+ Source: source,
+ File: m.mfp.MemoryFile(),
+ Offset: m.fr.Start + source.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
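+
+// For example (illustrative): for a SpecialMappable of length 0x3000,
+// Translate(required = [0x0, 0x1000), optional = [0x0, 0x3000)) returns a
+// single Translation covering [0x0, 0x3000) at offset m.fr.Start with
+// AnyAccess permissions; a required range extending past m.fr.Length()
+// instead yields a memmap.BusError alongside any partial translation.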
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (m *SpecialMappable) InvalidateUnsavable(ctx context.Context) error {
+ // Since data is stored in pgalloc.MemoryFile, the contents of which are
+ // preserved across save/restore, we don't need to do anything.
+ return nil
+}
+
+// MemoryFileProvider returns the MemoryFileProvider whose MemoryFile stores
+// the SpecialMappable's contents.
+func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider {
+ return m.mfp
+}
+
+// FileRange returns the offsets into MemoryFileProvider().MemoryFile() that
+// store the SpecialMappable's contents.
+func (m *SpecialMappable) FileRange() platform.FileRange {
+ return m.fr
+}
+
+// Length returns the length of the SpecialMappable.
+func (m *SpecialMappable) Length() uint64 {
+ return m.fr.Length()
+}
+
+// NewSharedAnonMappable returns a SpecialMappable that implements the
+// semantics of mmap(MAP_SHARED|MAP_ANONYMOUS) and mappings of /dev/zero.
+//
+// TODO(jamieliu): The use of SpecialMappable is a lazy code reuse hack. Linux
+// uses an ephemeral file created by mm/shmem.c:shmem_zero_setup(); we should
+// do the same to get non-zero device and inode IDs.
+func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) {
+ if length == 0 {
+ return nil, syserror.EINVAL
+ }
+ alignedLen, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return nil, syserror.EINVAL
+ }
+ fr, err := mfp.MemoryFile().Allocate(uint64(alignedLen), usage.Anonymous)
+ if err != nil {
+ return nil, err
+ }
+ return NewSpecialMappable("/dev/zero (deleted)", mfp, fr), nil
+}
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
new file mode 100644
index 000000000..3f496aa9f
--- /dev/null
+++ b/pkg/sentry/mm/syscalls.go
@@ -0,0 +1,1286 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+ mrand "math/rand"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// HandleUserFault handles an application page fault. sp is the faulting
+// application thread's stack pointer.
+//
+// Preconditions: mm.as != nil.
+func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr usermem.Addr, at usermem.AccessType, sp usermem.Addr) error {
+ ar, ok := addr.RoundDown().ToRange(usermem.PageSize)
+ if !ok {
+ return syserror.EFAULT
+ }
+
+ // Don't bother trying existingPMAsLocked; in most cases, if we did have
+ // existing pmas, we wouldn't have faulted.
+
+ // Ensure that we have a usable vma. Here and below, since we are only
+ // asking for a single page, there is no possibility of partial success,
+ // and any error is immediately fatal.
+ mm.mappingMu.RLock()
+ vseg, _, err := mm.getVMAsLocked(ctx, ar, at, false)
+ if err != nil {
+ mm.mappingMu.RUnlock()
+ return err
+ }
+
+ // Ensure that we have a usable pma.
+ mm.activeMu.Lock()
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if err != nil {
+ mm.activeMu.Unlock()
+ return err
+ }
+
+ // Downgrade to a read-lock on activeMu since we don't need to mutate pmas
+ // anymore.
+ mm.activeMu.DowngradeLock()
+
+ // Map the faulted page into the active AddressSpace.
+ err = mm.mapASLocked(pseg, ar, false)
+ mm.activeMu.RUnlock()
+ return err
+}
+
+// MMap establishes a memory mapping.
+func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error) {
+ if opts.Length == 0 {
+ return 0, syserror.EINVAL
+ }
+ length, ok := usermem.Addr(opts.Length).RoundUp()
+ if !ok {
+ return 0, syserror.ENOMEM
+ }
+ opts.Length = uint64(length)
+
+ if opts.Mappable != nil {
+ // Offset must be aligned.
+ if usermem.Addr(opts.Offset).RoundDown() != usermem.Addr(opts.Offset) {
+ return 0, syserror.EINVAL
+ }
+ // Offset + length must not overflow.
+ if end := opts.Offset + opts.Length; end < opts.Offset {
+ return 0, syserror.ENOMEM
+ }
+ } else {
+ opts.Offset = 0
+ if !opts.Private {
+ if opts.MappingIdentity != nil {
+ return 0, syserror.EINVAL
+ }
+ m, err := NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecRef()
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ }
+ }
+
+ if opts.Addr.RoundDown() != opts.Addr {
+ // MAP_FIXED requires addr to be page-aligned; non-fixed mappings
+ // don't.
+ if opts.Fixed {
+ return 0, syserror.EINVAL
+ }
+ opts.Addr = opts.Addr.RoundDown()
+ }
+
+ if !opts.MaxPerms.SupersetOf(opts.Perms) {
+ return 0, syserror.EACCES
+ }
+ if opts.Unmap && !opts.Fixed {
+ return 0, syserror.EINVAL
+ }
+ if opts.GrowsDown && opts.Mappable != nil {
+ return 0, syserror.EINVAL
+ }
+
+ // Get the new vma.
+ mm.mappingMu.Lock()
+ if opts.MLockMode < mm.defMLockMode {
+ opts.MLockMode = mm.defMLockMode
+ }
+ vseg, ar, err := mm.createVMALocked(ctx, opts)
+ if err != nil {
+ mm.mappingMu.Unlock()
+ return 0, err
+ }
+
+ // TODO(jamieliu): In Linux, VM_LOCKONFAULT (which may be set on the new
+ // vma by mlockall(MCL_FUTURE|MCL_ONFAULT) => mm_struct::def_flags) appears
+ // to effectively disable MAP_POPULATE by unsetting FOLL_POPULATE in
+ // mm/util.c:vm_mmap_pgoff() => mm/gup.c:__mm_populate() =>
+ // populate_vma_page_range(). Confirm this behavior.
+ switch {
+ case opts.Precommit || opts.MLockMode == memmap.MLockEager:
+ // Get pmas and map with precommit as requested.
+ mm.populateVMAAndUnlock(ctx, vseg, ar, true)
+
+ case opts.Mappable == nil && length <= privateAllocUnit:
+ // NOTE(b/63077076, b/63360184): Get pmas and map eagerly in the hope
+ // that doing so will save on future page faults. We only do this for
+ // anonymous mappings, since otherwise the cost of
+ // memmap.Mappable.Translate is unknown; and only for small mappings,
+ // to avoid needing to allocate large amounts of memory that we may
+ // subsequently need to checkpoint.
+ mm.populateVMAAndUnlock(ctx, vseg, ar, false)
+
+ default:
+ mm.mappingMu.Unlock()
+ }
+
+ return ar.Start, nil
+}
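+
+// Illustrative call sketch (hypothetical values): a one-page private
+// anonymous read/write mapping at no particular address might be created
+// with:
+//
+//	addr, err := mm.MMap(ctx, memmap.MMapOpts{
+//		Length:   usermem.PageSize,
+//		Private:  true,
+//		Perms:    usermem.ReadWrite,
+//		MaxPerms: usermem.AnyAccess,
+//	})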
+
+// populateVMA obtains pmas for addresses in ar in the given vma, and maps them
+// into mm.as if it is active.
+//
+// Preconditions: mm.mappingMu must be locked. vseg.Range().IsSupersetOf(ar).
+func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) {
+ if !vseg.ValuePtr().effectivePerms.Any() {
+ // Linux doesn't populate inaccessible pages. See
+ // mm/gup.c:populate_vma_page_range.
+ return
+ }
+
+ mm.activeMu.Lock()
+ // Can't defer mm.activeMu.Unlock(); see below.
+
+ // Even if we get new pmas, we can't actually map them if we don't have an
+ // AddressSpace.
+ if mm.as == nil {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Ensure that we have usable pmas.
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess)
+ if err != nil {
+ // mm/util.c:vm_mmap_pgoff() ignores the error, if any, from
+ // mm/gup.c:mm_populate(). If it matters, we'll get it again when
+ // userspace actually tries to use the failing page.
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Downgrade to a read-lock on activeMu since we don't need to mutate pmas
+ // anymore.
+ mm.activeMu.DowngradeLock()
+
+ // As above, errors are silently ignored.
+ mm.mapASLocked(pseg, ar, precommit)
+ mm.activeMu.RUnlock()
+}
+
+// populateVMAAndUnlock is equivalent to populateVMA, but also unconditionally
+// unlocks mm.mappingMu. In cases where populateVMAAndUnlock is usable, it is
+// preferable to populateVMA since it unlocks mm.mappingMu before performing
+// expensive operations that don't require it to be locked.
+//
+// Preconditions: mm.mappingMu must be locked for writing.
+// vseg.Range().IsSupersetOf(ar).
+//
+// Postconditions: mm.mappingMu will be unlocked.
+func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) {
+ // See populateVMA above for commentary.
+ if !vseg.ValuePtr().effectivePerms.Any() {
+ mm.mappingMu.Unlock()
+ return
+ }
+
+ mm.activeMu.Lock()
+
+ if mm.as == nil {
+ mm.activeMu.Unlock()
+ mm.mappingMu.Unlock()
+ return
+ }
+
+ // mm.mappingMu doesn't need to be write-locked for getPMAsLocked, and it
+ // isn't needed at all for mapASLocked.
+ mm.mappingMu.DowngradeLock()
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess)
+ mm.mappingMu.RUnlock()
+ if err != nil {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ mm.activeMu.DowngradeLock()
+ mm.mapASLocked(pseg, ar, precommit)
+ mm.activeMu.RUnlock()
+}
+
+// MapStack allocates the initial process stack.
+func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error) {
+ // maxStackSize is the maximum supported process stack size in bytes.
+ //
+ // This limit exists because stack growing isn't implemented, so the entire
+ // process stack must be mapped up-front.
+ const maxStackSize = 128 << 20
+
+ stackSize := limits.FromContext(ctx).Get(limits.Stack)
+ r, ok := usermem.Addr(stackSize.Cur).RoundUp()
+ sz := uint64(r)
+ if !ok {
+ // RLIM_INFINITY rounds up to 0.
+ sz = linux.DefaultStackSoftLimit
+ } else if sz > maxStackSize {
+ ctx.Warningf("Capping stack size from RLIMIT_STACK of %v down to %v.", sz, maxStackSize)
+ sz = maxStackSize
+ } else if sz == 0 {
+ return usermem.AddrRange{}, syserror.ENOMEM
+ }
+ szaddr := usermem.Addr(sz)
+ ctx.Debugf("Allocating stack with size of %v bytes", sz)
+
+ // Determine the stack's desired location. Unlike Linux, address
+ // randomization can't be disabled.
+ stackEnd := mm.layout.MaxAddr - usermem.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown()
+ if stackEnd < szaddr {
+ return usermem.AddrRange{}, syserror.ENOMEM
+ }
+ stackStart := stackEnd - szaddr
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ _, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
+ Length: sz,
+ Addr: stackStart,
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.AnyAccess,
+ Private: true,
+ GrowsDown: true,
+ MLockMode: mm.defMLockMode,
+ Hint: "[stack]",
+ })
+ return ar, err
+}
+
+// MUnmap implements the semantics of Linux's munmap(2).
+func (mm *MemoryManager) MUnmap(ctx context.Context, addr usermem.Addr, length uint64) error {
+ if addr != addr.RoundDown() {
+ return syserror.EINVAL
+ }
+ if length == 0 {
+ return syserror.EINVAL
+ }
+ la, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return syserror.EINVAL
+ }
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ mm.unmapLocked(ctx, ar)
+ return nil
+}
+
+// MRemapOpts specifies options to MRemap.
+type MRemapOpts struct {
+ // Move controls whether MRemap moves the remapped mapping to a new address.
+ Move MRemapMoveMode
+
+ // NewAddr is the new address for the remapping. NewAddr is ignored unless
+// Move is MRemapMustMove.
+ NewAddr usermem.Addr
+}
+
+// MRemapMoveMode controls MRemap's moving behavior.
+type MRemapMoveMode int
+
+const (
+ // MRemapNoMove prevents MRemap from moving the remapped mapping.
+ MRemapNoMove MRemapMoveMode = iota
+
+ // MRemapMayMove allows MRemap to move the remapped mapping.
+ MRemapMayMove
+
+ // MRemapMustMove requires MRemap to move the remapped mapping to
+ // MRemapOpts.NewAddr, replacing any existing mappings in the remapped
+ // range.
+ MRemapMustMove
+)
+
+// MRemap implements the semantics of Linux's mremap(2).
+func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (usermem.Addr, error) {
+ // "Note that old_address has to be page aligned." - mremap(2)
+ if oldAddr.RoundDown() != oldAddr {
+ return 0, syserror.EINVAL
+ }
+
+ // Linux treats an old_size that rounds up to 0 as 0, which is otherwise a
+ // valid size. However, new_size can't be 0 after rounding.
+ oldSizeAddr, _ := usermem.Addr(oldSize).RoundUp()
+ oldSize = uint64(oldSizeAddr)
+ newSizeAddr, ok := usermem.Addr(newSize).RoundUp()
+ if !ok || newSizeAddr == 0 {
+ return 0, syserror.EINVAL
+ }
+ newSize = uint64(newSizeAddr)
+
+ oldEnd, ok := oldAddr.AddLength(oldSize)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+
+ // All cases require that a vma exists at oldAddr.
+ vseg := mm.vmas.FindSegment(oldAddr)
+ if !vseg.Ok() {
+ return 0, syserror.EFAULT
+ }
+
+ // Behavior matrix:
+ //
+ // Move | oldSize = 0 | oldSize < newSize | oldSize = newSize | oldSize > newSize
+ // ---------+-------------+-------------------+-------------------+------------------
+ // NoMove | ENOMEM [1] | Grow in-place | No-op | Shrink in-place
+ // MayMove | Copy [1] | Grow in-place or | No-op | Shrink in-place
+ // | | move | |
+ // MustMove | Copy | Move and grow | Move | Shrink and move
+ //
+ // [1] In-place growth is impossible because the vma at oldAddr already
+ // occupies at least part of the destination. Thus the NoMove case always
+ // fails and the MayMove case always falls back to copying.
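+	//
+	// For example (illustrative): MRemap(ctx, oldAddr, 0x1000, 0x2000,
+	// MRemapOpts{Move: MRemapMayMove}) first tries to grow in place by
+	// creating a mergeable vma at oldEnd; if that range is occupied, it
+	// falls back to finding a free range, copying the vma, and moving
+	// pmas, as in the MustMove case.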
+
+ if vma := vseg.ValuePtr(); newSize > oldSize && vma.mlockMode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK. Unlike mmap, mlock, and mlockall,
+ // mremap in Linux does not check mm/mlock.c:can_do_mlock() and
+ // therefore does not return EPERM if RLIMIT_MEMLOCK is 0 and
+ // !CAP_IPC_LOCK.
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ if newLockedAS := mm.lockedAS - oldSize + newSize; newLockedAS > mlockLimit {
+ return 0, syserror.EAGAIN
+ }
+ }
+ }
+
+ if opts.Move != MRemapMustMove {
+ // Handle no-ops and in-place shrinking. These cases don't care if
+ // [oldAddr, oldEnd) maps to a single vma, or is even mapped at all
+ // (aside from oldAddr).
+ if newSize <= oldSize {
+ if newSize < oldSize {
+ // If oldAddr+oldSize didn't overflow, oldAddr+newSize can't
+ // either.
+ newEnd := oldAddr + usermem.Addr(newSize)
+ mm.unmapLocked(ctx, usermem.AddrRange{newEnd, oldEnd})
+ }
+ return oldAddr, nil
+ }
+
+ // Handle in-place growing.
+
+ // Check that oldEnd maps to the same vma as oldAddr.
+ if vseg.End() < oldEnd {
+ return 0, syserror.EFAULT
+ }
+ // "Grow" the existing vma by creating a new mergeable one.
+ vma := vseg.ValuePtr()
+ var newOffset uint64
+ if vma.mappable != nil {
+ newOffset = vseg.mappableRange().End
+ }
+ vseg, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
+ Length: newSize - oldSize,
+ MappingIdentity: vma.id,
+ Mappable: vma.mappable,
+ Offset: newOffset,
+ Addr: oldEnd,
+ Fixed: true,
+ Perms: vma.realPerms,
+ MaxPerms: vma.maxPerms,
+ Private: vma.private,
+ GrowsDown: vma.growsDown,
+ MLockMode: vma.mlockMode,
+ Hint: vma.hint,
+ })
+ if err == nil {
+ if vma.mlockMode == memmap.MLockEager {
+ mm.populateVMA(ctx, vseg, ar, true)
+ }
+ return oldAddr, nil
+ }
+ // In-place growth failed. In the MRemapMayMove case, fall through to
+ // copying/moving below.
+ if opts.Move == MRemapNoMove {
+ return 0, err
+ }
+ }
+
+ // Find a location for the new mapping.
+ var newAR usermem.AddrRange
+ switch opts.Move {
+ case MRemapMayMove:
+ newAddr, err := mm.findAvailableLocked(newSize, findAvailableOpts{})
+ if err != nil {
+ return 0, err
+ }
+ newAR, _ = newAddr.ToRange(newSize)
+
+ case MRemapMustMove:
+ newAddr := opts.NewAddr
+ if newAddr.RoundDown() != newAddr {
+ return 0, syserror.EINVAL
+ }
+ var ok bool
+ newAR, ok = newAddr.ToRange(newSize)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ if (usermem.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that the new region is valid.
+ _, err := mm.findAvailableLocked(newSize, findAvailableOpts{
+ Addr: newAddr,
+ Fixed: true,
+ Unmap: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // Unmap any mappings at the destination.
+ mm.unmapLocked(ctx, newAR)
+
+ // If the sizes specify shrinking, unmap everything between the new and
+ // old sizes at the source. Unmapping before the following checks is
+ // correct: compare Linux's mm/mremap.c:mremap_to() => do_munmap(),
+ // vma_to_resize().
+ if newSize < oldSize {
+ oldNewEnd := oldAddr + usermem.Addr(newSize)
+ mm.unmapLocked(ctx, usermem.AddrRange{oldNewEnd, oldEnd})
+ oldEnd = oldNewEnd
+ }
+
+ // unmapLocked may have invalidated vseg; look it up again.
+ vseg = mm.vmas.FindSegment(oldAddr)
+ }
+
+ oldAR := usermem.AddrRange{oldAddr, oldEnd}
+
+ // Check that oldEnd maps to the same vma as oldAddr.
+ if vseg.End() < oldEnd {
+ return 0, syserror.EFAULT
+ }
+
+ // Check against RLIMIT_AS.
+ newUsageAS := mm.usageAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
+ return 0, syserror.ENOMEM
+ }
+
+ if vma := vseg.ValuePtr(); vma.mappable != nil {
+ // Check that offset+length does not overflow.
+ if vma.off+uint64(newAR.Length()) < vma.off {
+ return 0, syserror.EINVAL
+ }
+ // Inform the Mappable, if any, of the new mapping.
+ if err := vma.mappable.CopyMapping(ctx, mm, oldAR, newAR, vseg.mappableOffsetAt(oldAR.Start), vma.canWriteMappableLocked()); err != nil {
+ return 0, err
+ }
+ }
+
+ if oldSize == 0 {
+ // Handle copying.
+ //
+ // We can't use createVMALocked because it calls Mappable.AddMapping,
+ // whereas we've already called Mappable.CopyMapping (which is
+ // consistent with Linux). Call vseg.Value() (rather than
+ // vseg.ValuePtr()) to make a copy of the vma.
+ vma := vseg.Value()
+ if vma.mappable != nil {
+ vma.off = vseg.mappableOffsetAt(oldAR.Start)
+ }
+ if vma.id != nil {
+ vma.id.IncRef()
+ }
+ vseg := mm.vmas.Insert(mm.vmas.FindGap(newAR.Start), newAR, vma)
+ mm.usageAS += uint64(newAR.Length())
+ if vma.isPrivateDataLocked() {
+ mm.dataAS += uint64(newAR.Length())
+ }
+ if vma.mlockMode != memmap.MLockNone {
+ mm.lockedAS += uint64(newAR.Length())
+ if vma.mlockMode == memmap.MLockEager {
+ mm.populateVMA(ctx, vseg, newAR, true)
+ }
+ }
+ return newAR.Start, nil
+ }
+
+ // Handle moving.
+ //
+ // Remove the existing vma before inserting the new one to minimize
+ // iterator invalidation. We do this directly (instead of calling
+ // removeVMAsLocked) because:
+ //
+ // 1. We can't drop the reference on vma.id, which will be transferred to
+ // the new vma.
+ //
+ // 2. We can't call vma.mappable.RemoveMapping, because pmas are still at
+ // oldAR, so calling RemoveMapping could cause us to miss an invalidation
+ // overlapping oldAR.
+ //
+ // Call vseg.Value() (rather than vseg.ValuePtr()) to make a copy of the
+ // vma.
+ vseg = mm.vmas.Isolate(vseg, oldAR)
+ vma := vseg.Value()
+ mm.vmas.Remove(vseg)
+ vseg = mm.vmas.Insert(mm.vmas.FindGap(newAR.Start), newAR, vma)
+ mm.usageAS = mm.usageAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ if vma.isPrivateDataLocked() {
+ mm.dataAS = mm.dataAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ }
+ if vma.mlockMode != memmap.MLockNone {
+ mm.lockedAS = mm.lockedAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ }
+
+ // Move pmas. This is technically optional for non-private pmas, which
+ // could just go through memmap.Mappable.Translate again, but it's required
+ // for private pmas.
+ mm.activeMu.Lock()
+ mm.movePMAsLocked(oldAR, newAR)
+ mm.activeMu.Unlock()
+
+ // Now that pmas have been moved to newAR, we can notify vma.mappable that
+ // oldAR is no longer mapped.
+ if vma.mappable != nil {
+ vma.mappable.RemoveMapping(ctx, mm, oldAR, vma.off, vma.canWriteMappableLocked())
+ }
+
+ if vma.mlockMode == memmap.MLockEager {
+ mm.populateVMA(ctx, vseg, newAR, true)
+ }
+
+ return newAR.Start, nil
+}
+
+// MProtect implements the semantics of Linux's mprotect(2).
+func (mm *MemoryManager) MProtect(addr usermem.Addr, length uint64, realPerms usermem.AccessType, growsDown bool) error {
+ if addr.RoundDown() != addr {
+ return syserror.EINVAL
+ }
+ if length == 0 {
+ return nil
+ }
+ rlength, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return syserror.ENOMEM
+ }
+ ar, ok := addr.ToRange(uint64(rlength))
+ if !ok {
+ return syserror.ENOMEM
+ }
+ effectivePerms := realPerms.Effective()
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ // Non-growsDown mprotect requires that all of ar is mapped, and stops at
+ // the first non-empty gap. growsDown mprotect requires that the first vma
+ // be growsDown, but does not require it to extend all the way to ar.Start;
+ // vmas after the first must be contiguous but need not be growsDown, like
+ // the non-growsDown case.
+ vseg := mm.vmas.LowerBoundSegment(ar.Start)
+ if !vseg.Ok() {
+ return syserror.ENOMEM
+ }
+ if growsDown {
+ if !vseg.ValuePtr().growsDown {
+ return syserror.EINVAL
+ }
+ if ar.End <= vseg.Start() {
+ return syserror.ENOMEM
+ }
+ ar.Start = vseg.Start()
+ } else {
+ if ar.Start < vseg.Start() {
+ return syserror.ENOMEM
+ }
+ }
+
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+ defer func() {
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ mm.pmas.MergeRange(ar)
+ mm.pmas.MergeAdjacent(ar)
+ }()
+ pseg := mm.pmas.LowerBoundSegment(ar.Start)
+ var didUnmapAS bool
+ for {
+ // Check for permission validity before splitting vmas, for consistency
+ // with Linux.
+ if !vseg.ValuePtr().maxPerms.SupersetOf(effectivePerms) {
+ return syserror.EACCES
+ }
+ vseg = mm.vmas.Isolate(vseg, ar)
+
+ // Update vma permissions.
+ vma := vseg.ValuePtr()
+ vmaLength := vseg.Range().Length()
+ if vma.isPrivateDataLocked() {
+ mm.dataAS -= uint64(vmaLength)
+ }
+
+ vma.realPerms = realPerms
+ vma.effectivePerms = effectivePerms
+ if vma.isPrivateDataLocked() {
+ mm.dataAS += uint64(vmaLength)
+ }
+
+ // Propagate vma permission changes to pmas.
+ for pseg.Ok() && pseg.Start() < vseg.End() {
+ if pseg.Range().Overlaps(vseg.Range()) {
+ pseg = mm.pmas.Isolate(pseg, vseg.Range())
+ pma := pseg.ValuePtr()
+ if !effectivePerms.SupersetOf(pma.effectivePerms) && !didUnmapAS {
+ // Unmap all of ar, not just vseg.Range(), to minimize host
+ // syscalls.
+ mm.unmapASLocked(ar)
+ didUnmapAS = true
+ }
+ pma.effectivePerms = effectivePerms.Intersect(pma.translatePerms)
+ if pma.needCOW {
+ pma.effectivePerms.Write = false
+ }
+ }
+ pseg = pseg.NextSegment()
+ }
+
+ // Continue to the next vma.
+ if ar.End <= vseg.End() {
+ return nil
+ }
+ vseg, _ = vseg.NextNonEmpty()
+ if !vseg.Ok() {
+ return syserror.ENOMEM
+ }
+ }
+}
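+
+// Worked example (illustrative): for a growsDown stack vma covering
+// [0x500000, 0x600000), MProtect(0x580000, 0x10000, perms, true) extends the
+// start of the affected range back to the vma's start, changing permissions
+// on [0x500000, 0x590000) rather than just the requested pages.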
+
+// BrkSetup sets mm's brk address to addr and its brk size to 0.
+func (mm *MemoryManager) BrkSetup(ctx context.Context, addr usermem.Addr) {
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ // Unmap the existing brk.
+ if mm.brk.Length() != 0 {
+ mm.unmapLocked(ctx, mm.brk)
+ }
+ mm.brk = usermem.AddrRange{addr, addr}
+}
+
+// Brk implements the semantics of Linux's brk(2), except that it returns an
+// error on failure.
+func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Addr, error) {
+ mm.mappingMu.Lock()
+ // Can't defer mm.mappingMu.Unlock(); see below.
+
+ if addr < mm.brk.Start {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, syserror.EINVAL
+ }
+
+ // TODO(gvisor.dev/issue/156): This enforces RLIMIT_DATA, but is
+ // slightly more permissive than the usual data limit. In particular,
+ // this only limits the size of the heap; a true RLIMIT_DATA limits the
+ // size of heap + data + bss. The segment sizes need to be plumbed from
+ // the loader package to fully enforce RLIMIT_DATA.
+ if uint64(addr-mm.brk.Start) > limits.FromContext(ctx).Get(limits.Data).Cur {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, syserror.ENOMEM
+ }
+
+ oldbrkpg, _ := mm.brk.End.RoundUp()
+ newbrkpg, ok := addr.RoundUp()
+ if !ok {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, syserror.EFAULT
+ }
+
+ switch {
+ case oldbrkpg < newbrkpg:
+ vseg, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
+ Length: uint64(newbrkpg - oldbrkpg),
+ Addr: oldbrkpg,
+ Fixed: true,
+ // Compare Linux's
+ // arch/x86/include/asm/page_types.h:VM_DATA_DEFAULT_FLAGS.
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.AnyAccess,
+ Private: true,
+ // Linux: mm/mmap.c:sys_brk() => do_brk_flags() includes
+ // mm->def_flags.
+ MLockMode: mm.defMLockMode,
+ Hint: "[heap]",
+ })
+ if err != nil {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, err
+ }
+ mm.brk.End = addr
+ if mm.defMLockMode == memmap.MLockEager {
+ mm.populateVMAAndUnlock(ctx, vseg, ar, true)
+ } else {
+ mm.mappingMu.Unlock()
+ }
+
+ case newbrkpg < oldbrkpg:
+ mm.unmapLocked(ctx, usermem.AddrRange{newbrkpg, oldbrkpg})
+ fallthrough
+
+ default:
+ mm.brk.End = addr
+ mm.mappingMu.Unlock()
+ }
+
+ return addr, nil
+}
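+
+// Worked example (illustrative): with brk = [0x601000, 0x601000), a call to
+// Brk(ctx, 0x603800) maps [0x601000, 0x604000) as "[heap]" (page-rounded)
+// and sets brk.End to 0x603800; a subsequent Brk(ctx, 0x601000) unmaps
+// [0x601000, 0x604000) and shrinks the heap back to zero length.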
+
+// MLock implements the semantics of Linux's mlock()/mlock2()/munlock(),
+// depending on mode.
+func (mm *MemoryManager) MLock(ctx context.Context, addr usermem.Addr, length uint64, mode memmap.MLockMode) error {
+ // Linux allows this to overflow.
+ la, _ := usermem.Addr(length + addr.PageOffset()).RoundUp()
+ ar, ok := addr.RoundDown().ToRange(uint64(la))
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ // Can't defer mm.mappingMu.Unlock(); see below.
+
+ if mode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK.
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if mlockLimit == 0 {
+ mm.mappingMu.Unlock()
+ return syserror.EPERM
+ }
+ if newLockedAS := mm.lockedAS + uint64(ar.Length()) - mm.mlockedBytesRangeLocked(ar); newLockedAS > mlockLimit {
+ mm.mappingMu.Unlock()
+ return syserror.ENOMEM
+ }
+ }
+ }
+
+ // Check this after RLIMIT_MEMLOCK for consistency with Linux.
+ if ar.Length() == 0 {
+ mm.mappingMu.Unlock()
+ return nil
+ }
+
+ // Apply the new mlock mode to vmas.
+ var unmapped bool
+ vseg := mm.vmas.FindSegment(ar.Start)
+ for {
+ if !vseg.Ok() {
+ unmapped = true
+ break
+ }
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vma := vseg.ValuePtr()
+ prevMode := vma.mlockMode
+ vma.mlockMode = mode
+ if mode != memmap.MLockNone && prevMode == memmap.MLockNone {
+ mm.lockedAS += uint64(vseg.Range().Length())
+ } else if mode == memmap.MLockNone && prevMode != memmap.MLockNone {
+ mm.lockedAS -= uint64(vseg.Range().Length())
+ }
+ if ar.End <= vseg.End() {
+ break
+ }
+ vseg, _ = vseg.NextNonEmpty()
+ }
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ if unmapped {
+ mm.mappingMu.Unlock()
+ return syserror.ENOMEM
+ }
+
+ if mode == memmap.MLockEager {
+ // Ensure that we have usable pmas. Since we didn't return ENOMEM
+ // above, ar must be fully covered by vmas, so we can just use
+ // NextSegment below.
+ mm.activeMu.Lock()
+ mm.mappingMu.DowngradeLock()
+ for vseg := mm.vmas.FindSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ if !vseg.ValuePtr().effectivePerms.Any() {
+ // Linux: mm/gup.c:__get_user_pages() returns EFAULT in this
+ // case, which is converted to ENOMEM by mlock.
+ mm.activeMu.Unlock()
+ mm.mappingMu.RUnlock()
+ return syserror.ENOMEM
+ }
+ _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), usermem.NoAccess)
+ if err != nil {
+ mm.activeMu.Unlock()
+ mm.mappingMu.RUnlock()
+ // Linux: mm/mlock.c:__mlock_posix_error_return()
+ if err == syserror.EFAULT {
+ return syserror.ENOMEM
+ }
+ if err == syserror.ENOMEM {
+ return syserror.EAGAIN
+ }
+ return err
+ }
+ }
+
+ // Map pmas into the active AddressSpace, if we have one.
+ mm.mappingMu.RUnlock()
+ if mm.as != nil {
+ mm.activeMu.DowngradeLock()
+ err := mm.mapASLocked(mm.pmas.LowerBoundSegment(ar.Start), ar, true /* precommit */)
+ mm.activeMu.RUnlock()
+ if err != nil {
+ return err
+ }
+ } else {
+ mm.activeMu.Unlock()
+ }
+ } else {
+ mm.mappingMu.Unlock()
+ }
+
+ return nil
+}
+
+// MLockAllOpts holds options to MLockAll.
+type MLockAllOpts struct {
+ // If Current is true, change the memory-locking behavior of all mappings
+ // to Mode. If Future is true, upgrade the memory-locking behavior of all
+ // future mappings to Mode. At least one of Current or Future must be true.
+ Current bool
+ Future bool
+ Mode memmap.MLockMode
+}
+
+// MLockAll implements the semantics of Linux's mlockall()/munlockall(),
+// depending on opts.
+func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error {
+ if !opts.Current && !opts.Future {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ // Can't defer mm.mappingMu.Unlock(); see below.
+
+ if opts.Current {
+ if opts.Mode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK.
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if mlockLimit == 0 {
+ mm.mappingMu.Unlock()
+ return syserror.EPERM
+ }
+ if uint64(mm.vmas.Span()) > mlockLimit {
+ mm.mappingMu.Unlock()
+ return syserror.ENOMEM
+ }
+ }
+ }
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ vma := vseg.ValuePtr()
+ prevMode := vma.mlockMode
+ vma.mlockMode = opts.Mode
+ if opts.Mode != memmap.MLockNone && prevMode == memmap.MLockNone {
+ mm.lockedAS += uint64(vseg.Range().Length())
+ } else if opts.Mode == memmap.MLockNone && prevMode != memmap.MLockNone {
+ mm.lockedAS -= uint64(vseg.Range().Length())
+ }
+ }
+ }
+
+ if opts.Future {
+ mm.defMLockMode = opts.Mode
+ }
+
+ if opts.Current && opts.Mode == memmap.MLockEager {
+ // Linux: mm/mlock.c:sys_mlockall() => include/linux/mm.h:mm_populate()
+ // ignores the return value of __mm_populate(), so all errors below are
+ // ignored.
+ //
+ // Try to get usable pmas.
+ mm.activeMu.Lock()
+ mm.mappingMu.DowngradeLock()
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ if vseg.ValuePtr().effectivePerms.Any() {
+ mm.getPMAsLocked(ctx, vseg, vseg.Range(), usermem.NoAccess)
+ }
+ }
+
+ // Map all pmas into the active AddressSpace, if we have one.
+ mm.mappingMu.RUnlock()
+ if mm.as != nil {
+ mm.activeMu.DowngradeLock()
+ mm.mapASLocked(mm.pmas.FirstSegment(), mm.applicationAddrRange(), true /* precommit */)
+ mm.activeMu.RUnlock()
+ } else {
+ mm.activeMu.Unlock()
+ }
+ } else {
+ mm.mappingMu.Unlock()
+ }
+ return nil
+}
+
+// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR).
+func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64, error) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ vseg := mm.vmas.FindSegment(addr)
+ if !vseg.Ok() {
+ return 0, 0, syserror.EFAULT
+ }
+ vma := vseg.ValuePtr()
+ return vma.numaPolicy, vma.numaNodemask, nil
+}
+
+// SetNumaPolicy implements the semantics of Linux's mbind().
+func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error {
+ if !addr.IsPageAligned() {
+ return syserror.EINVAL
+ }
+ // Linux allows this to overflow.
+ la, _ := usermem.Addr(length).RoundUp()
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ar.Length() == 0 {
+ return nil
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ defer func() {
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ }()
+ vseg := mm.vmas.LowerBoundSegment(ar.Start)
+ lastEnd := ar.Start
+ for {
+ if !vseg.Ok() || lastEnd < vseg.Start() {
+ // "EFAULT: ... there was an unmapped hole in the specified memory
+ // range specified [sic] by addr and len." - mbind(2)
+ return syserror.EFAULT
+ }
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vma := vseg.ValuePtr()
+ vma.numaPolicy = policy
+ vma.numaNodemask = nodemask
+ lastEnd = vseg.End()
+ if ar.End <= lastEnd {
+ return nil
+ }
+ vseg, _ = vseg.NextNonEmpty()
+ }
+}
+
+// SetDontFork implements the semantics of madvise MADV_DONTFORK.
+func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork bool) error {
+ ar, ok := addr.ToRange(length)
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ defer func() {
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ }()
+
+ for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vma := vseg.ValuePtr()
+ vma.dontfork = dontfork
+ }
+
+ if mm.vmas.SpanRange(ar) != ar.Length() {
+ return syserror.ENOMEM
+ }
+ return nil
+}
+
+// Decommit implements the semantics of Linux's madvise(MADV_DONTNEED).
+func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
+ ar, ok := addr.ToRange(length)
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+
+ // Linux's mm/madvise.c:madvise_dontneed() => mm/memory.c:zap_page_range()
+ // is analogous to our mm.invalidateLocked(ar, true, true). We inline this
+ // here, with the special case that we synchronously decommit
+ // uniquely-owned (non-copy-on-write) pages for private anonymous vmas,
+ // which is the common case for MADV_DONTNEED. Invalidating these pmas, and
+ // allowing them to be reallocated when touched again, increases pma
+ // fragmentation, which may significantly reduce performance for
+ // non-vectored I/O implementations. Also, decommitting synchronously
+ // ensures that Decommit immediately reduces host memory usage.
+ var didUnmapAS bool
+ pseg := mm.pmas.LowerBoundSegment(ar.Start)
+ mf := mm.mfp.MemoryFile()
+ for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ vma := vseg.ValuePtr()
+ if vma.mlockMode != memmap.MLockNone {
+ return syserror.EINVAL
+ }
+ vsegAR := vseg.Range().Intersect(ar)
+ // pseg should already correspond to either this vma or a later one,
+ // since there can't be a pma without a corresponding vma.
+ if checkInvariants {
+ if pseg.Ok() && pseg.End() <= vsegAR.Start {
+ panic(fmt.Sprintf("pma %v precedes vma %v", pseg.Range(), vsegAR))
+ }
+ }
+ for pseg.Ok() && pseg.Start() < vsegAR.End {
+ pma := pseg.ValuePtr()
+ if pma.private && !mm.isPMACopyOnWriteLocked(vseg, pseg) {
+ psegAR := pseg.Range().Intersect(ar)
+ if vsegAR.IsSupersetOf(psegAR) && vma.mappable == nil {
+ if err := mf.Decommit(pseg.fileRangeOf(psegAR)); err == nil {
+ pseg = pseg.NextSegment()
+ continue
+ }
+ // If an error occurs, fall through to the general
+ // invalidation case below.
+ }
+ }
+ pseg = mm.pmas.Isolate(pseg, vsegAR)
+ pma = pseg.ValuePtr()
+ if !didUnmapAS {
+ // Unmap all of ar, not just pseg.Range(), to minimize host
+ // syscalls. AddressSpace mappings must be removed before
+ // mm.decPrivateRef().
+ mm.unmapASLocked(ar)
+ didUnmapAS = true
+ }
+ if pma.private {
+ mm.decPrivateRef(pseg.fileRange())
+ }
+ pma.file.DecRef(pseg.fileRange())
+ mm.removeRSSLocked(pseg.Range())
+ pseg = mm.pmas.Remove(pseg).NextSegment()
+ }
+ }
+
+ // "If there are some parts of the specified address space that are not
+ // mapped, the Linux version of madvise() ignores them and applies the call
+ // to the rest (but returns ENOMEM from the system call, as it should)." -
+ // madvise(2)
+ if mm.vmas.SpanRange(ar) != ar.Length() {
+ return syserror.ENOMEM
+ }
+ return nil
+}
+
+// MSyncOpts holds options to MSync.
+type MSyncOpts struct {
+ // Sync has the semantics of MS_SYNC.
+ Sync bool
+
+ // Invalidate has the semantics of MS_INVALIDATE.
+ Invalidate bool
+}
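+
+// For illustration, the usual msync(2) flag decoding that would produce
+// these options (flag names assumed from the abi/linux package):
+//
+//	opts := MSyncOpts{
+//		Sync:       flags&linux.MS_SYNC != 0,
+//		Invalidate: flags&linux.MS_INVALIDATE != 0,
+//	}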
+
+// MSync implements the semantics of Linux's msync().
+func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length uint64, opts MSyncOpts) error {
+ if addr != addr.RoundDown() {
+ return syserror.EINVAL
+ }
+ if length == 0 {
+ return nil
+ }
+ la, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return syserror.ENOMEM
+ }
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return syserror.ENOMEM
+ }
+
+ mm.mappingMu.RLock()
+ // Can't defer mm.mappingMu.RUnlock(); see below.
+ vseg := mm.vmas.LowerBoundSegment(ar.Start)
+ if !vseg.Ok() {
+ mm.mappingMu.RUnlock()
+ return syserror.ENOMEM
+ }
+ var unmapped bool
+ lastEnd := ar.Start
+ for {
+ if !vseg.Ok() {
+ mm.mappingMu.RUnlock()
+ unmapped = true
+ break
+ }
+ if lastEnd < vseg.Start() {
+ unmapped = true
+ }
+ lastEnd = vseg.End()
+ vma := vseg.ValuePtr()
+ if opts.Invalidate && vma.mlockMode != memmap.MLockNone {
+ mm.mappingMu.RUnlock()
+ return syserror.EBUSY
+ }
+ // It's only possible to have dirtied the Mappable through a shared
+ // mapping. Don't check if the mapping is writable, because mprotect
+ // may have changed this, and also because Linux doesn't.
+ if id := vma.id; opts.Sync && id != nil && vma.mappable != nil && !vma.private {
+ // We can't call memmap.MappingIdentity.Msync while holding
+ // mm.mappingMu since it may take fs locks that precede it in the
+ // lock order.
+ id.IncRef()
+ mr := vseg.mappableRangeOf(vseg.Range().Intersect(ar))
+ mm.mappingMu.RUnlock()
+ err := id.Msync(ctx, mr)
+ id.DecRef()
+ if err != nil {
+ return err
+ }
+ if lastEnd >= ar.End {
+ break
+ }
+ mm.mappingMu.RLock()
+ vseg = mm.vmas.LowerBoundSegment(lastEnd)
+ } else {
+ if lastEnd >= ar.End {
+ mm.mappingMu.RUnlock()
+ break
+ }
+ vseg = vseg.NextSegment()
+ }
+ }
+
+ if unmapped {
+ return syserror.ENOMEM
+ }
+ return nil
+}
+
+// GetSharedFutexKey is used by kernel.Task.GetSharedKey.
+func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr usermem.Addr) (futex.Key, error) {
+ ar, ok := addr.ToRange(4) // sizeof(int32).
+ if !ok {
+ return futex.Key{}, syserror.EFAULT
+ }
+
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ vseg, _, err := mm.getVMAsLocked(ctx, ar, usermem.Read, false)
+ if err != nil {
+ return futex.Key{}, err
+ }
+ vma := vseg.ValuePtr()
+
+ if vma.private {
+ return futex.Key{
+ Kind: futex.KindSharedPrivate,
+ Offset: uint64(addr),
+ }, nil
+ }
+
+ if vma.id != nil {
+ vma.id.IncRef()
+ }
+ return futex.Key{
+ Kind: futex.KindSharedMappable,
+ Mappable: vma.mappable,
+ MappingIdentity: vma.id,
+ Offset: vseg.mappableOffsetAt(addr),
+ }, nil
+}
+
+// VirtualMemorySize returns the combined length in bytes of all mappings in
+// mm.
+func (mm *MemoryManager) VirtualMemorySize() uint64 {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ return mm.usageAS
+}
+
+// VirtualMemorySizeRange returns the combined length in bytes of all mappings
+// in ar in mm.
+func (mm *MemoryManager) VirtualMemorySizeRange(ar usermem.AddrRange) uint64 {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ return uint64(mm.vmas.SpanRange(ar))
+}
+
+// ResidentSetSize returns the value advertised as mm's RSS in bytes.
+func (mm *MemoryManager) ResidentSetSize() uint64 {
+ mm.activeMu.RLock()
+ defer mm.activeMu.RUnlock()
+ return mm.curRSS
+}
+
+// MaxResidentSetSize returns the value advertised as mm's max RSS in bytes.
+func (mm *MemoryManager) MaxResidentSetSize() uint64 {
+ mm.activeMu.RLock()
+ defer mm.activeMu.RUnlock()
+ return mm.maxRSS
+}
+
+// VirtualDataSize returns the size of private data segments in mm.
+func (mm *MemoryManager) VirtualDataSize() uint64 {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ return mm.dataAS
+}
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
new file mode 100644
index 000000000..16d8207e9
--- /dev/null
+++ b/pkg/sentry/mm/vma.go
@@ -0,0 +1,568 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Preconditions: mm.mappingMu must be locked for writing. opts must be valid
+// as defined by the checks in MMap.
+func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, usermem.AddrRange, error) {
+ if opts.MaxPerms != opts.MaxPerms.Effective() {
+ panic(fmt.Sprintf("Non-effective MaxPerms %s cannot be enforced", opts.MaxPerms))
+ }
+
+ // Find a usable range.
+ addr, err := mm.findAvailableLocked(opts.Length, findAvailableOpts{
+ Addr: opts.Addr,
+ Fixed: opts.Fixed,
+ Unmap: opts.Unmap,
+ Map32Bit: opts.Map32Bit,
+ })
+ if err != nil {
+ return vmaIterator{}, usermem.AddrRange{}, err
+ }
+ ar, _ := addr.ToRange(opts.Length)
+
+ // Check against RLIMIT_AS.
+ newUsageAS := mm.usageAS + opts.Length
+ if opts.Unmap {
+ newUsageAS -= uint64(mm.vmas.SpanRange(ar))
+ }
+ if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
+ return vmaIterator{}, usermem.AddrRange{}, syserror.ENOMEM
+ }
+
+ if opts.MLockMode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK.
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if mlockLimit == 0 {
+ return vmaIterator{}, usermem.AddrRange{}, syserror.EPERM
+ }
+ newLockedAS := mm.lockedAS + opts.Length
+ if opts.Unmap {
+ newLockedAS -= mm.mlockedBytesRangeLocked(ar)
+ }
+ if newLockedAS > mlockLimit {
+ return vmaIterator{}, usermem.AddrRange{}, syserror.EAGAIN
+ }
+ }
+ }
+
+ // Remove overwritten mappings. This ordering is consistent with Linux:
+ // compare Linux's mm/mmap.c:mmap_region() => do_munmap(),
+ // file->f_op->mmap().
+ var vgap vmaGapIterator
+ if opts.Unmap {
+ vgap = mm.unmapLocked(ctx, ar)
+ } else {
+ vgap = mm.vmas.FindGap(ar.Start)
+ }
+
+ // Inform the Mappable, if any, of the new mapping.
+ if opts.Mappable != nil {
+ // The expression for writable is vma.canWriteMappableLocked(), but we
+ // don't yet have a vma.
+ if err := opts.Mappable.AddMapping(ctx, mm, ar, opts.Offset, !opts.Private && opts.MaxPerms.Write); err != nil {
+ return vmaIterator{}, usermem.AddrRange{}, err
+ }
+ }
+
+ // Take a reference on opts.MappingIdentity before inserting the vma since
+ // vma merging can drop the reference.
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.IncRef()
+ }
+
+ // Finally insert the vma.
+ v := vma{
+ mappable: opts.Mappable,
+ off: opts.Offset,
+ realPerms: opts.Perms,
+ effectivePerms: opts.Perms.Effective(),
+ maxPerms: opts.MaxPerms,
+ private: opts.Private,
+ growsDown: opts.GrowsDown,
+ mlockMode: opts.MLockMode,
+ numaPolicy: linux.MPOL_DEFAULT,
+ id: opts.MappingIdentity,
+ hint: opts.Hint,
+ }
+
+ vseg := mm.vmas.Insert(vgap, ar, v)
+ mm.usageAS += opts.Length
+ if v.isPrivateDataLocked() {
+ mm.dataAS += opts.Length
+ }
+ if opts.MLockMode != memmap.MLockNone {
+ mm.lockedAS += opts.Length
+ }
+
+ return vseg, ar, nil
+}
+
+type findAvailableOpts struct {
+ // These fields are equivalent to those in memmap.MMapOpts, except that:
+ //
+ // - Addr must be page-aligned.
+ //
+ // - Unmap allows existing guard pages in the returned range.
+
+ Addr usermem.Addr
+ Fixed bool
+ Unmap bool
+ Map32Bit bool
+}
+
+// map32Start/End are the bounds to which MAP_32BIT mappings are constrained,
+// and are equivalent to Linux's MAP32_BASE and MAP32_MAX respectively.
+const (
+ map32Start = 0x40000000
+ map32End = 0x80000000
+)
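+
+// With the values above, MAP_32BIT mappings are confined to the 1 GiB range
+// [0x40000000, 0x80000000), matching Linux on x86-64.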
+
+// findAvailableLocked finds an allocatable range.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOpts) (usermem.Addr, error) {
+ if opts.Fixed {
+ opts.Map32Bit = false
+ }
+ allowedAR := mm.applicationAddrRange()
+ if opts.Map32Bit {
+ allowedAR = allowedAR.Intersect(usermem.AddrRange{map32Start, map32End})
+ }
+
+ // Does the provided suggestion work?
+ if ar, ok := opts.Addr.ToRange(length); ok {
+ if allowedAR.IsSupersetOf(ar) {
+ if opts.Unmap {
+ return ar.Start, nil
+ }
+ // Check for the presence of an existing vma or guard page.
+ if vgap := mm.vmas.FindGap(ar.Start); vgap.Ok() && vgap.availableRange().IsSupersetOf(ar) {
+ return ar.Start, nil
+ }
+ }
+ }
+
+ // Fixed mappings accept only the requested address.
+ if opts.Fixed {
+ return 0, syserror.ENOMEM
+ }
+
+ // Prefer hugepage alignment if a hugepage or more is requested.
+ alignment := uint64(usermem.PageSize)
+ if length >= usermem.HugePageSize {
+ alignment = usermem.HugePageSize
+ }
+
+ if opts.Map32Bit {
+ return mm.findLowestAvailableLocked(length, alignment, allowedAR)
+ }
+ if mm.layout.DefaultDirection == arch.MmapBottomUp {
+ return mm.findLowestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.BottomUpBase, mm.layout.MaxAddr})
+ }
+ return mm.findHighestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.MinAddr, mm.layout.TopDownBase})
+}
+
+func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange {
+ return usermem.AddrRange{mm.layout.MinAddr, mm.layout.MaxAddr}
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
+ for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) {
+ if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
+ // Can we shift up to match the alignment?
+ if offset := uint64(gr.Start) % alignment; offset != 0 {
+ if uint64(gr.Length()) >= length+alignment-offset {
+ // Yes, we're aligned.
+ return gr.Start + usermem.Addr(alignment-offset), nil
+ }
+ }
+
+ // Either aligned perfectly, or can't align it.
+ return gr.Start, nil
+ }
+ }
+ return 0, syserror.ENOMEM
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
+ for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) {
+ if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
+ // Can we shift down to match the alignment?
+ start := gr.End - usermem.Addr(length)
+ if offset := uint64(start) % alignment; offset != 0 {
+ if gr.Start <= start-usermem.Addr(offset) {
+ // Yes, we're aligned.
+ return start - usermem.Addr(offset), nil
+ }
+ }
+
+ // Either aligned perfectly, or can't align it.
+ return start, nil
+ }
+ }
+ return 0, syserror.ENOMEM
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
+ var total uint64
+ for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ if vseg.ValuePtr().mlockMode != memmap.MLockNone {
+ total += uint64(vseg.Range().Intersect(ar).Length())
+ }
+ }
+ return total
+}
+
+// getVMAsLocked ensures that vmas exist for all addresses in ar, and support
+// access of type (at, ignorePermissions). It returns:
+//
+// - An iterator to the vma containing ar.Start. If no vma contains ar.Start,
+// the iterator is unspecified.
+//
+// - An iterator to the gap after the last vma containing an address in ar. If
+// vmas exist for no addresses in ar, the iterator is to a gap that begins
+// before ar.Start.
+//
+// - An error that is non-nil if vmas exist for only a subset of ar.
+//
+// Preconditions: mm.mappingMu must be locked for reading; it may be
+// temporarily unlocked. ar.Length() != 0.
+func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // Inline mm.vmas.LowerBoundSegment so that we have the preceding gap if
+ // !vbegin.Ok().
+ vbegin, vgap := mm.vmas.Find(ar.Start)
+ if !vbegin.Ok() {
+ vbegin = vgap.NextSegment()
+ // vseg.Ok() is checked before entering the following loop.
+ } else {
+ vgap = vbegin.PrevGap()
+ }
+
+ addr := ar.Start
+ vseg := vbegin
+ for vseg.Ok() {
+ // Loop invariants: vgap = vseg.PrevGap(); addr < vseg.End().
+ vma := vseg.ValuePtr()
+ if addr < vseg.Start() {
+ // TODO(jamieliu): Implement vma.growsDown here.
+ return vbegin, vgap, syserror.EFAULT
+ }
+
+ perms := vma.effectivePerms
+ if ignorePermissions {
+ perms = vma.maxPerms
+ }
+ if !perms.SupersetOf(at) {
+ return vbegin, vgap, syserror.EPERM
+ }
+
+ addr = vseg.End()
+ vgap = vseg.NextGap()
+ if addr >= ar.End {
+ return vbegin, vgap, nil
+ }
+ vseg = vgap.NextSegment()
+ }
+
+ // Ran out of vmas before ar.End.
+ return vbegin, vgap, syserror.EFAULT
+}
+
+// getVecVMAsLocked ensures that vmas exist for all addresses in ars, and
+// support access of type (at, ignorePermissions). It returns the subset of
+// ars for which vmas exist. If this is not equal to ars, it returns a non-nil
+// error explaining why.
+//
+// Preconditions: mm.mappingMu must be locked for reading; it may be
+// temporarily unlocked.
+//
+// Postconditions: ars is not mutated.
+func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool) (usermem.AddrRangeSeq, error) {
+ for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
+ ar := arsit.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ if _, vend, err := mm.getVMAsLocked(ctx, ar, at, ignorePermissions); err != nil {
+ return truncatedAddrRangeSeq(ars, arsit, vend.Start()), err
+ }
+ }
+ return ars, nil
+}
+
+// vma extension will not shrink the number of unmapped bytes between the start
+// of a growsDown vma and the end of its predecessor non-growsDown vma below
+// guardBytes.
+//
+// guardBytes is equivalent to Linux's stack_guard_gap after upstream
+// 1be7107fbe18 "mm: larger stack guard gap, between vmas".
+const guardBytes = 256 * usermem.PageSize
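+
+// With the standard 4 KiB page size, guardBytes is 256 * 4096 bytes = 1 MiB,
+// the same default gap enforced by Linux's stack_guard_gap.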
+
+// unmapLocked unmaps all addresses in ar and returns the resulting gap in
+// mm.vmas.
+//
+// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0.
+// ar must be page-aligned.
+func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // AddressSpace mappings and pmas must be invalidated before
+ // mm.removeVMAsLocked() => memmap.Mappable.RemoveMapping().
+ mm.Invalidate(ar, memmap.InvalidateOpts{InvalidatePrivate: true})
+ return mm.removeVMAsLocked(ctx, ar)
+}
+
+// removeVMAsLocked removes vmas for addresses in ar and returns the resulting
+// gap in mm.vmas. It does not remove pmas or AddressSpace mappings; clients
+// must do so before calling removeVMAsLocked.
+//
+// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0. ar
+// must be page-aligned.
+func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ vseg, vgap := mm.vmas.Find(ar.Start)
+ if vgap.Ok() {
+ vseg = vgap.NextSegment()
+ }
+ for vseg.Ok() && vseg.Start() < ar.End {
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vmaAR := vseg.Range()
+ vma := vseg.ValuePtr()
+ if vma.mappable != nil {
+ vma.mappable.RemoveMapping(ctx, mm, vmaAR, vma.off, vma.canWriteMappableLocked())
+ }
+ if vma.id != nil {
+ vma.id.DecRef()
+ }
+ mm.usageAS -= uint64(vmaAR.Length())
+ if vma.isPrivateDataLocked() {
+ mm.dataAS -= uint64(vmaAR.Length())
+ }
+ if vma.mlockMode != memmap.MLockNone {
+ mm.lockedAS -= uint64(vmaAR.Length())
+ }
+ vgap = mm.vmas.Remove(vseg)
+ vseg = vgap.NextSegment()
+ }
+ return vgap
+}
+
+// canWriteMappableLocked returns true if it is possible for vma.mappable to be
+// written to via this vma, i.e. if it is possible that
+// vma.mappable.Translate(at.Write=true) may be called as a result of this vma.
+// This includes via I/O with usermem.IOOpts.IgnorePermissions = true, such as
+// PTRACE_POKEDATA.
+//
+// canWriteMappableLocked is equivalent to Linux's VM_SHARED.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (vma *vma) canWriteMappableLocked() bool {
+ return !vma.private && vma.maxPerms.Write
+}
+
+// isPrivateDataLocked identifies the data segments: private, writable, and
+// not stack.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (vma *vma) isPrivateDataLocked() bool {
+ return vma.realPerms.Write && vma.private && !vma.growsDown
+}
+
+// vmaSetFunctions implements segment.Functions for vmaSet.
+type vmaSetFunctions struct{}
+
+func (vmaSetFunctions) MinKey() usermem.Addr {
+ return 0
+}
+
+func (vmaSetFunctions) MaxKey() usermem.Addr {
+ return ^usermem.Addr(0)
+}
+
+func (vmaSetFunctions) ClearValue(vma *vma) {
+ vma.mappable = nil
+ vma.id = nil
+ vma.hint = ""
+}
+
+func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRange, vma2 vma) (vma, bool) {
+ if vma1.mappable != vma2.mappable ||
+ (vma1.mappable != nil && vma1.off+uint64(ar1.Length()) != vma2.off) ||
+ vma1.realPerms != vma2.realPerms ||
+ vma1.maxPerms != vma2.maxPerms ||
+ vma1.private != vma2.private ||
+ vma1.growsDown != vma2.growsDown ||
+ vma1.mlockMode != vma2.mlockMode ||
+ vma1.numaPolicy != vma2.numaPolicy ||
+ vma1.numaNodemask != vma2.numaNodemask ||
+ vma1.dontfork != vma2.dontfork ||
+ vma1.id != vma2.id ||
+ vma1.hint != vma2.hint {
+ return vma{}, false
+ }
+
+ if vma2.id != nil {
+ vma2.id.DecRef()
+ }
+ return vma1, true
+}
+
+func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (vma, vma) {
+ v2 := v
+ if v2.mappable != nil {
+ v2.off += uint64(split - ar.Start)
+ }
+ if v2.id != nil {
+ v2.id.IncRef()
+ }
+ return v, v2
+}
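+
+// Worked example (illustration only): splitting a file-backed vma over
+// ar = [0x400000, 0x408000) with off = 0 at split = 0x404000 yields v over
+// [0x400000, 0x404000) with off = 0 and v2 over [0x404000, 0x408000) with
+// off = 0x4000; v2 also takes its own reference on the MappingIdentity.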
+
+// Preconditions: vseg.ValuePtr().mappable != nil. vseg.Range().Contains(addr).
+func (vseg vmaIterator) mappableOffsetAt(addr usermem.Addr) uint64 {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if vseg.ValuePtr().mappable == nil {
+ panic("Mappable offset is meaningless for anonymous vma")
+ }
+ if !vseg.Range().Contains(addr) {
+ panic(fmt.Sprintf("addr %v out of bounds %v", addr, vseg.Range()))
+ }
+ }
+
+ vma := vseg.ValuePtr()
+ vstart := vseg.Start()
+ return vma.off + uint64(addr-vstart)
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil.
+func (vseg vmaIterator) mappableRange() memmap.MappableRange {
+ return vseg.mappableRangeOf(vseg.Range())
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil.
+// vseg.Range().IsSupersetOf(ar). ar.Length() != 0.
+func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRange {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if vseg.ValuePtr().mappable == nil {
+ panic("MappableRange is meaningless for anonymous vma")
+ }
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !vseg.Range().IsSupersetOf(ar) {
+ panic(fmt.Sprintf("ar %v out of bounds %v", ar, vseg.Range()))
+ }
+ }
+
+ vma := vseg.ValuePtr()
+ vstart := vseg.Start()
+ return memmap.MappableRange{vma.off + uint64(ar.Start-vstart), vma.off + uint64(ar.End-vstart)}
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil.
+// vseg.mappableRange().IsSupersetOf(mr). mr.Length() != 0.
+func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if vseg.ValuePtr().mappable == nil {
+ panic("MappableRange is meaningless for anonymous vma")
+ }
+ if !mr.WellFormed() || mr.Length() <= 0 {
+ panic(fmt.Sprintf("invalid mr: %v", mr))
+ }
+ if !vseg.mappableRange().IsSupersetOf(mr) {
+ panic(fmt.Sprintf("mr %v out of bounds %v", mr, vseg.mappableRange()))
+ }
+ }
+
+ vma := vseg.ValuePtr()
+ vstart := vseg.Start()
+ return usermem.AddrRange{vstart + usermem.Addr(mr.Start-vma.off), vstart + usermem.Addr(mr.End-vma.off)}
+}
+
+// seekNextLowerBound returns mm.vmas.LowerBoundSegment(addr), but does so by
+// scanning linearly forward from vseg.
+//
+// Preconditions: mm.mappingMu must be locked. addr >= vseg.Start().
+func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if addr < vseg.Start() {
+ panic(fmt.Sprintf("can't seek forward to %#x from %#x", addr, vseg.Start()))
+ }
+ }
+ for vseg.Ok() && addr >= vseg.End() {
+ vseg = vseg.NextSegment()
+ }
+ return vseg
+}
+
+// availableRange returns the subset of vgap.Range() in which new vmas may be
+// created without MMapOpts.Unmap == true.
+func (vgap vmaGapIterator) availableRange() usermem.AddrRange {
+ ar := vgap.Range()
+ next := vgap.NextSegment()
+ if !next.Ok() || !next.ValuePtr().growsDown {
+ return ar
+ }
+ // Exclude guard pages.
+ if ar.Length() < guardBytes {
+ return usermem.AddrRange{ar.Start, ar.Start}
+ }
+ ar.End -= guardBytes
+ return ar
+}
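+
+// Worked example (illustration only): for vgap.Range() = [0x100000, 0x300000)
+// followed by a growsDown vma, availableRange excludes the 1 MiB guard region
+// and returns [0x100000, 0x200000); a gap shorter than guardBytes yields the
+// empty range [ar.Start, ar.Start).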
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
new file mode 100644
index 000000000..e1fcb175f
--- /dev/null
+++ b/pkg/sentry/pgalloc/BUILD
@@ -0,0 +1,108 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "evictable_range",
+ out = "evictable_range.go",
+ package = "pgalloc",
+ prefix = "Evictable",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
+go_template_instance(
+ name = "evictable_range_set",
+ out = "evictable_range_set.go",
+ package = "pgalloc",
+ prefix = "evictableRange",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "EvictableRange",
+ "Value": "evictableRangeSetValue",
+ "Functions": "evictableRangeSetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "usage_set",
+ out = "usage_set.go",
+ consts = {
+ "minDegree": "10",
+ "trackGaps": "1",
+ },
+ imports = {
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "pgalloc",
+ prefix = "usage",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "platform.FileRange",
+ "Value": "usageInfo",
+ "Functions": "usageSetFunctions",
+ },
+)
+
+go_template_instance(
+ name = "reclaim_set",
+ out = "reclaim_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "pgalloc",
+ prefix = "reclaim",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "platform.FileRange",
+ "Value": "reclaimSetValue",
+ "Functions": "reclaimSetFunctions",
+ },
+)
+
+go_library(
+ name = "pgalloc",
+ srcs = [
+ "context.go",
+ "evictable_range.go",
+ "evictable_range_set.go",
+ "pgalloc.go",
+ "pgalloc_unsafe.go",
+ "reclaim_set.go",
+ "save_restore.go",
+ "usage_set.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/memutil",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/hostmm",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/usage",
+ "//pkg/state",
+ "//pkg/state/wire",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "pgalloc_test",
+ size = "small",
+ srcs = ["pgalloc_test.go"],
+ library = ":pgalloc",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/sentry/pgalloc/context.go b/pkg/sentry/pgalloc/context.go
new file mode 100644
index 000000000..d25215418
--- /dev/null
+++ b/pkg/sentry/pgalloc/context.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is this package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxMemoryFile is a Context.Value key for a MemoryFile.
+ CtxMemoryFile contextID = iota
+
+ // CtxMemoryFileProvider is a Context.Value key for a MemoryFileProvider.
+ CtxMemoryFileProvider
+)
+
+// MemoryFileFromContext returns the MemoryFile used by ctx, or nil if no such
+// MemoryFile exists.
+func MemoryFileFromContext(ctx context.Context) *MemoryFile {
+ if v := ctx.Value(CtxMemoryFile); v != nil {
+ return v.(*MemoryFile)
+ }
+ return nil
+}
+
+// MemoryFileProviderFromContext returns the MemoryFileProvider used by ctx, or nil if no such
+// MemoryFileProvider exists.
+func MemoryFileProviderFromContext(ctx context.Context) MemoryFileProvider {
+ if v := ctx.Value(CtxMemoryFileProvider); v != nil {
+ return v.(MemoryFileProvider)
+ }
+ return nil
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
new file mode 100644
index 000000000..afab97c0a
--- /dev/null
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -0,0 +1,1279 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pgalloc contains the page allocator subsystem, which manages memory
+// that may be mapped into application address spaces.
+//
+// Lock order:
+//
+// pgalloc.MemoryFile.mu
+// pgalloc.MemoryFile.mappingsMu
+package pgalloc
+
+import (
+ "fmt"
+ "math"
+ "os"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostmm"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// MemoryFile is a platform.File whose pages may be allocated to arbitrary
+// users.
+type MemoryFile struct {
+ // opts holds options passed to NewMemoryFile. opts is immutable.
+ opts MemoryFileOpts
+
+ // MemoryFile owns a single backing file, which is modeled as follows:
+ //
+ // Each page in the file can be committed or uncommitted. A page is
+ // committed if the host kernel is spending resources to store its contents
+ // and uncommitted otherwise. This definition includes pages that the host
+ // kernel has swapped; this is intentional, to ensure that accounting does
+ // not change even if host kernel swapping behavior changes, and that
+ // memory used by pseudo-swap mechanisms like zswap is still accounted.
+ //
+ // The initial contents of uncommitted pages are implicitly zero bytes. A
+ // read or write to the contents of an uncommitted page causes it to be
+ // committed. This is the only event that can cause an uncommitted page to
+ // be committed.
+ //
+ // fallocate(FALLOC_FL_PUNCH_HOLE) (MemoryFile.Decommit) causes committed
+ // pages to be uncommitted. This is the only event that can cause a
+ // committed page to be uncommitted.
+ //
+ // Memory accounting is based on identifying the set of committed pages.
+ // Since we do not have direct access to the MMU, tracking reads and writes
+ // to uncommitted pages to detect commitment would introduce additional
+ // page faults, which would be prohibitively expensive. Instead, we query
+ // the host kernel to determine which pages are committed.
+
+ // file is the backing file. The file pointer is immutable.
+ file *os.File
+
+ mu sync.Mutex
+
+ // usage maps each page in the file to metadata for that page. Pages for
+ // which no segment exists in usage are both unallocated (not in use) and
+ // uncommitted.
+ //
+ // Since usage stores usageInfo objects by value, clients should usually
+ // use usageIterator.ValuePtr() instead of usageIterator.Value() to get a
+ // pointer to the usageInfo rather than a copy.
+ //
+ // usage must be kept maximally merged (that is, there should never be two
+ // adjacent segments with the same values). At least markReclaimed depends
+ // on this property.
+ //
+ // usage is protected by mu.
+ usage usageSet
+
+ // The UpdateUsage function scans all segments with knownCommitted set
+ // to false, sees which pages are committed and creates corresponding
+ // segments with knownCommitted set to true.
+ //
+ // In order to avoid unnecessary scans, usageExpected tracks the total
+ // file blocks expected. This is used to elide the scan when this
+ // matches the underlying file blocks.
+ //
+ // To track swapped pages, usageSwapped tracks the discrepancy between
+ // what is observed in core and what is reported by the file. When
+ // usageSwapped is non-zero, a sweep will be performed at least every
+ // second. The start of the last sweep is recorded in usageLast.
+ //
+ // All usage attributes are protected by mu.
+ usageExpected uint64
+ usageSwapped uint64
+ usageLast time.Time
+
+ // fileSize is the size of the backing memory file in bytes. fileSize is
+ // always a power-of-two multiple of chunkSize.
+ //
+ // fileSize is protected by mu.
+ fileSize int64
+
+ // Pages from the backing file are mapped into the local address space on
+ // the granularity of large pieces called chunks. mappings is a []uintptr
+ // that stores, for each chunk, the start address of a mapping of that
+ // chunk in the current process' address space, or 0 if no such mapping
+ // exists. Once a chunk is mapped, it is never remapped or unmapped until
+ // the MemoryFile is destroyed.
+ //
+ // Mutating the mappings slice or its contents requires both holding
+ // mappingsMu and using atomic memory operations. (The slice is mutated
+ // whenever the file is expanded. Per the above, the only permitted
+ // mutation of the slice's contents is the assignment of a mapping to a
+ // chunk that was previously unmapped.) Reading the slice or its contents
+ // only requires *either* holding mappingsMu or using atomic memory
+ // operations. This allows MemoryFile.MapInternal to avoid locking in the
+ // common case where chunk mappings already exist.
+ mappingsMu sync.Mutex
+ mappings atomic.Value
+
+ // destroyed is set by Destroy to instruct the reclaimer goroutine to
+ // release resources and exit. destroyed is protected by mu.
+ destroyed bool
+
+ // reclaimable is true if usage may contain reclaimable pages. reclaimable
+ // is protected by mu.
+ reclaimable bool
+
+ // reclaim is the collection of regions for reclaim. reclaim is protected
+ // by mu.
+ reclaim reclaimSet
+
+ // reclaimCond is signaled (with mu locked) when reclaimable or destroyed
+ // transitions from false to true.
+ reclaimCond sync.Cond
+
+ // evictable maps EvictableMemoryUsers to eviction state.
+ //
+ // evictable is protected by mu.
+ evictable map[EvictableMemoryUser]*evictableMemoryUserInfo
+
+ // evictionWG counts the number of goroutines currently performing evictions.
+ evictionWG sync.WaitGroup
+
+ // stopNotifyPressure stops memory cgroup pressure level
+ // notifications used to drive eviction. stopNotifyPressure is
+ // immutable.
+ stopNotifyPressure func()
+}
+
+// MemoryFileOpts provides options to NewMemoryFile.
+type MemoryFileOpts struct {
+ // DelayedEviction controls the extent to which the MemoryFile may delay
+ // eviction of evictable allocations.
+ DelayedEviction DelayedEvictionType
+
+ // If UseHostMemcgPressure is true, use host memory cgroup pressure level
+ // notifications to determine when eviction is necessary. This option has
+ // no effect unless DelayedEviction is DelayedEvictionEnabled.
+ UseHostMemcgPressure bool
+
+ // If ManualZeroing is true, MemoryFile must not assume that new pages
+ // obtained from the host are zero-filled, such that MemoryFile must manually
+ // zero newly-allocated pages.
+ ManualZeroing bool
+}
+
+// DelayedEvictionType is the type of MemoryFileOpts.DelayedEviction.
+type DelayedEvictionType int
+
+const (
+ // DelayedEvictionDefault has unspecified behavior.
+ DelayedEvictionDefault DelayedEvictionType = iota
+
+ // DelayedEvictionDisabled requires that evictable allocations are evicted
+ // as soon as possible.
+ DelayedEvictionDisabled
+
+ // DelayedEvictionEnabled requests that the MemoryFile delay eviction of
+ // evictable allocations until doing so is considered necessary to avoid
+ // performance degradation due to host memory pressure, or OOM kills.
+ //
+ // As of this writing, the behavior of DelayedEvictionEnabled depends on
+ // whether or not MemoryFileOpts.UseHostMemcgPressure is enabled:
+ //
+ // - If UseHostMemcgPressure is true, evictions are delayed until memory
+ // pressure is indicated.
+ //
+ // - Otherwise, evictions are only delayed until the reclaimer goroutine
+ // is out of work (pages to reclaim).
+ DelayedEvictionEnabled
+
+ // DelayedEvictionManual requires that evictable allocations are only
+ // evicted when MemoryFile.StartEvictions() is called. This is extremely
+ // dangerous outside of tests.
+ DelayedEvictionManual
+)
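+
+// For illustration, a configuration that delays eviction until the host
+// reports memory-cgroup pressure (a sketch, not a recommended default):
+//
+//	opts := MemoryFileOpts{
+//		DelayedEviction:      DelayedEvictionEnabled,
+//		UseHostMemcgPressure: true,
+//	}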
+
+// usageInfo tracks usage information.
+//
+// +stateify savable
+type usageInfo struct {
+ // kind is the usage kind.
+ kind usage.MemoryKind
+
+ // knownCommitted is true if the tracked region is definitely committed.
+ // (If it is false, the tracked region may or may not be committed.)
+ knownCommitted bool
+
+ refs uint64
+}
+
+// An EvictableMemoryUser represents a user of MemoryFile-allocated memory that
+// may be asked to deallocate that memory in the presence of memory pressure.
+type EvictableMemoryUser interface {
+ // Evict requests that the EvictableMemoryUser deallocate memory used by
+ // er, which was registered as evictable by a previous call to
+ // MemoryFile.MarkEvictable.
+ //
+ // Evict is not required to deallocate memory. In particular, since pgalloc
+ // must call Evict without holding locks to avoid circular lock ordering,
+ // it is possible that the passed range has already been marked as
+ // unevictable by a racing call to MemoryFile.MarkUnevictable.
+ // Implementations of EvictableMemoryUser must detect such races and handle
+ // them by making Evict have no effect on unevictable ranges.
+ //
+ // After a call to Evict, the MemoryFile will consider the evicted range
+ // unevictable (i.e. it will not call Evict on the same range again) until
+ // informed otherwise by a subsequent call to MarkEvictable.
+ Evict(ctx context.Context, er EvictableRange)
+}
+
+// An EvictableRange represents a range of uint64 offsets in an
+// EvictableMemoryUser.
+//
+// In practice, most EvictableMemoryUsers will probably be implementations of
+// memmap.Mappable, and EvictableRange therefore corresponds to
+// memmap.MappableRange. However, this package cannot depend on the memmap
+// package, since doing so would create a circular dependency.
+//
+// type EvictableRange <generated using go_generics>
+
+// evictableMemoryUserInfo is the value type of MemoryFile.evictable.
+type evictableMemoryUserInfo struct {
+ // ranges tracks all evictable ranges for the given user.
+ ranges evictableRangeSet
+
+ // If evicting is true, there is a goroutine currently evicting all
+ // evictable ranges for this user.
+ evicting bool
+}
+
+const (
+ chunkShift = 30
+ chunkSize = 1 << chunkShift // 1 GB
+ chunkMask = chunkSize - 1
+
+ // maxPage is the highest 64-bit page.
+ maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
+)
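+
+// Worked example of the chunk arithmetic (illustration only):
+//
+//	offset := uint64(0x48000000)     // 1 GiB + 128 MiB
+//	chunk := offset >> chunkShift    // 1: the second 1 GiB chunk
+//	within := offset & chunkMask     // 0x08000000: offset inside that chunk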
+
+// NewMemoryFile creates a MemoryFile backed by the given file. If
+// NewMemoryFile succeeds, ownership of file is transferred to the returned
+// MemoryFile.
+func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
+ switch opts.DelayedEviction {
+ case DelayedEvictionDefault:
+ opts.DelayedEviction = DelayedEvictionEnabled
+ case DelayedEvictionDisabled, DelayedEvictionManual:
+ opts.UseHostMemcgPressure = false
+ case DelayedEvictionEnabled:
+ // ok
+ default:
+ return nil, fmt.Errorf("invalid MemoryFileOpts.DelayedEviction: %v", opts.DelayedEviction)
+ }
+
+ // Truncate the file to 0 bytes first to ensure that it's empty.
+ if err := file.Truncate(0); err != nil {
+ return nil, err
+ }
+ f := &MemoryFile{
+ opts: opts,
+ file: file,
+ evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
+ }
+ f.mappings.Store(make([]uintptr, 0))
+ f.reclaimCond.L = &f.mu
+
+ if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure {
+ stop, err := hostmm.NotifyCurrentMemcgPressureCallback(func() {
+ f.mu.Lock()
+ startedAny := f.startEvictionsLocked()
+ f.mu.Unlock()
+ if startedAny {
+ log.Debugf("pgalloc.MemoryFile performing evictions due to memcg pressure")
+ }
+ }, "low")
+ if err != nil {
+ return nil, fmt.Errorf("failed to configure memcg pressure level notifications: %v", err)
+ }
+ f.stopNotifyPressure = stop
+ }
+
+ go f.runReclaim() // S/R-SAFE: f.mu
+
+ // The Linux kernel contains an optional feature called "Integrity
+ // Measurement Architecture" (IMA). If IMA is enabled, it will checksum
+ // binaries the first time they are mapped PROT_EXEC. This is bad news for
+ // executable pages mapped from our backing file, which can grow to
+ // terabytes in (sparse) size. If IMA attempts to checksum a file that
+ // large, it will allocate all of the sparse pages and quickly exhaust all
+ // memory.
+ //
+ // Work around IMA by immediately creating a temporary PROT_EXEC mapping,
+ // while the backing file is still small. IMA will ignore any future
+ // mappings.
+ m, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ 0,
+ usermem.PageSize,
+ syscall.PROT_EXEC,
+ syscall.MAP_SHARED,
+ file.Fd(),
+ 0)
+ if errno != 0 {
+ // This isn't fatal (IMA may not even be in use). Log the error, but
+ // don't return it.
+ log.Warningf("Failed to pre-map MemoryFile PROT_EXEC: %v", errno)
+ } else {
+ if _, _, errno := syscall.Syscall(
+ syscall.SYS_MUNMAP,
+ m,
+ usermem.PageSize,
+ 0); errno != 0 {
+ panic(fmt.Sprintf("failed to unmap PROT_EXEC MemoryFile mapping: %v", errno))
+ }
+ }
+
+ return f, nil
+}
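+
+// For illustration, a MemoryFile might be constructed from an anonymous host
+// memfd; the CreateMemFD helper is assumed to come from pkg/memutil:
+//
+//	fd, err := memutil.CreateMemFD("app-memory", 0)
+//	if err != nil {
+//		return nil, err
+//	}
+//	mf, err := NewMemoryFile(os.NewFile(uintptr(fd), "app-memory"), MemoryFileOpts{})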
+
+// Destroy releases all resources used by f.
+//
+// Preconditions: All pages allocated by f have been freed.
+//
+// Postconditions: None of f's methods may be called after Destroy.
+func (f *MemoryFile) Destroy() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.destroyed = true
+ f.reclaimCond.Signal()
+}
+
+// Allocate returns a range of initially-zeroed pages of the given length with
+// the given accounting kind and a single reference held by the caller. When
+// the last reference on an allocated page is released, ownership of the page
+// is returned to the MemoryFile, allowing it to be returned by a future call
+// to Allocate.
+//
+// Preconditions: length must be page-aligned and non-zero.
+func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) {
+ if length == 0 || length%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid allocation length: %#x", length))
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Align hugepage-and-larger allocations on hugepage boundaries to try
+ // to take advantage of hugetmpfs.
+ alignment := uint64(usermem.PageSize)
+ if length >= usermem.HugePageSize {
+ alignment = usermem.HugePageSize
+ }
+
+ // Find a range in the underlying file.
+ fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
+ if !ok {
+ return platform.FileRange{}, syserror.ENOMEM
+ }
+
+ // Expand the file if needed.
+ if int64(fr.End) > f.fileSize {
+ // Round the new file size up to be chunk-aligned.
+ newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask
+ if err := f.file.Truncate(newFileSize); err != nil {
+ return platform.FileRange{}, err
+ }
+ f.fileSize = newFileSize
+ f.mappingsMu.Lock()
+ oldMappings := f.mappings.Load().([]uintptr)
+ newMappings := make([]uintptr, newFileSize>>chunkShift)
+ copy(newMappings, oldMappings)
+ f.mappings.Store(newMappings)
+ f.mappingsMu.Unlock()
+ }
+
+ // Mark selected pages as in use.
+ if f.opts.ManualZeroing {
+ if err := f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ }); err != nil {
+ return platform.FileRange{}, err
+ }
+ }
+ if !f.usage.Add(fr, usageInfo{
+ kind: kind,
+ refs: 1,
+ }) {
+ panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage))
+ }
+
+ return fr, nil
+}
+
+// findAvailableRange returns an available range in the usageSet.
+//
+// Note that scanning for available slots takes place from the end of the
+// file backward first, then forward. This heuristic has important
+// consequences for how sequential
+// mappings can be merged in the host VMAs, given that addresses for both
+// application and sentry mappings are allocated top-down (from higher to
+// lower addresses). The file is also grown exponentially in order to create
+// space for mappings to be allocated downwards.
+//
+// Precondition: alignment must be a power of 2.
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) {
+ alignmentMask := alignment - 1
+
+ // Search for space in existing gaps, starting at the current end of the
+ // file and working backward.
+ lastGap := usage.LastGap()
+ gap := lastGap
+ for {
+ end := gap.End()
+ if end > uint64(fileSize) {
+ end = uint64(fileSize)
+ }
+
+ // Try to allocate from the end of this gap, with the start of the
+ // allocated range aligned down to alignment.
+ unalignedStart := end - length
+ if unalignedStart > end {
+ // Negative overflow: this and all preceding gaps are too small to
+ // accommodate length.
+ break
+ }
+ if start := unalignedStart &^ alignmentMask; start >= gap.Start() {
+ return platform.FileRange{start, start + length}, true
+ }
+
+ gap = gap.PrevLargeEnoughGap(length)
+ if !gap.Ok() {
+ break
+ }
+ }
+
+ // Check that it's possible to fit this allocation at the end of a file of any size.
+ min := lastGap.Start()
+ min = (min + alignmentMask) &^ alignmentMask
+ if min+length < min {
+ // Overflow: allocation would exceed the range of uint64.
+ return platform.FileRange{}, false
+ }
+
+ // Determine the minimum file size required to fit this allocation at its end.
+ for {
+ newFileSize := 2 * fileSize
+ if newFileSize <= fileSize {
+ if fileSize != 0 {
+ // Overflow: allocation would exceed the range of int64.
+ return platform.FileRange{}, false
+ }
+ newFileSize = chunkSize
+ }
+ fileSize = newFileSize
+
+ unalignedStart := uint64(fileSize) - length
+ if unalignedStart > uint64(fileSize) {
+ // Negative overflow: fileSize is still inadequate.
+ continue
+ }
+ if start := unalignedStart &^ alignmentMask; start >= min {
+ return platform.FileRange{start, start + length}, true
+ }
+ }
+}
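+
+// Worked example (illustration only): in a fresh file (fileSize = 0) the
+// growth loop first sets fileSize to chunkSize (1 GiB) and doubles it as
+// needed; a 2 MiB-aligned 4 MiB allocation is therefore placed at the top of
+// the first chunk, [1 GiB - 4 MiB, 1 GiB).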
+
+// AllocateAndFill allocates memory of the given kind and fills it by calling
+// r.ReadToBlocks() repeatedly until either length bytes are read or a non-nil
+// error is returned. It returns the memory filled by r, truncated down to the
+// nearest page. If this is shorter than length bytes due to an error returned
+// by r.ReadToBlocks(), it returns that error.
+//
+// Preconditions: length > 0. length must be page-aligned.
+func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) {
+ fr, err := f.Allocate(length, kind)
+ if err != nil {
+ return platform.FileRange{}, err
+ }
+ dsts, err := f.MapInternal(fr, usermem.Write)
+ if err != nil {
+ f.DecRef(fr)
+ return platform.FileRange{}, err
+ }
+ n, err := safemem.ReadFullToBlocks(r, dsts)
+ un := uint64(usermem.Addr(n).RoundDown())
+ if un < length {
+ // Free unused memory and update fr to contain only the memory that is
+ // still allocated.
+ f.DecRef(platform.FileRange{fr.Start + un, fr.End})
+ fr.End = fr.Start + un
+ }
+ return fr, err
+}
+
+// fallocate(2) modes, defined in Linux's include/uapi/linux/falloc.h.
+const (
+ _FALLOC_FL_KEEP_SIZE = 1
+ _FALLOC_FL_PUNCH_HOLE = 2
+)
+
+// Decommit releases resources associated with maintaining the contents of the
+// given pages. If Decommit succeeds, future accesses of the decommitted pages
+// will read zeroes.
+//
+// Preconditions: fr.Length() > 0.
+func (f *MemoryFile) Decommit(fr platform.FileRange) error {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+
+ // "After a successful call, subsequent reads from this range will
+ // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with
+ // FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2)
+ err := syscall.Fallocate(
+ int(f.file.Fd()),
+ _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE,
+ int64(fr.Start),
+ int64(fr.Length()))
+ if err != nil {
+ return err
+ }
+ f.markDecommitted(fr)
+ return nil
+}
+
+func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ // Since we're changing the knownCommitted attribute, we need to merge
+ // across the entire range to ensure that the usage tree is minimal.
+ gap := f.usage.ApplyContiguous(fr, func(seg usageIterator) {
+ val := seg.ValuePtr()
+ if val.knownCommitted {
+ // Drop the usageExpected appropriately.
+ amount := seg.Range().Length()
+ usage.MemoryAccounting.Dec(amount, val.kind)
+ f.usageExpected -= amount
+ val.knownCommitted = false
+ }
+ })
+ if gap.Ok() {
+ panic(fmt.Sprintf("Decommit(%v): attempted to decommit unallocated pages %v:\n%v", fr, gap.Range(), &f.usage))
+ }
+ f.usage.MergeRange(fr)
+}
+
+// IncRef implements platform.File.IncRef.
+func (f *MemoryFile) IncRef(fr platform.FileRange) {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ gap := f.usage.ApplyContiguous(fr, func(seg usageIterator) {
+ seg.ValuePtr().refs++
+ })
+ if gap.Ok() {
+ panic(fmt.Sprintf("IncRef(%v): attempted to IncRef on unallocated pages %v:\n%v", fr, gap.Range(), &f.usage))
+ }
+
+ f.usage.MergeAdjacent(fr)
+}
+
+// DecRef implements platform.File.DecRef.
+func (f *MemoryFile) DecRef(fr platform.FileRange) {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+
+ var freed bool
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ for seg := f.usage.FindSegment(fr.Start); seg.Ok() && seg.Start() < fr.End; seg = seg.NextSegment() {
+ seg = f.usage.Isolate(seg, fr)
+ val := seg.ValuePtr()
+ if val.refs == 0 {
+ panic(fmt.Sprintf("DecRef(%v): 0 existing references on %v:\n%v", fr, seg.Range(), &f.usage))
+ }
+ val.refs--
+ if val.refs == 0 {
+ f.reclaim.Add(seg.Range(), reclaimSetValue{})
+ freed = true
+ // Reclassify memory as System, until it's freed by the reclaim
+ // goroutine.
+ if val.knownCommitted {
+ usage.MemoryAccounting.Move(seg.Range().Length(), usage.System, val.kind)
+ }
+ val.kind = usage.System
+ }
+ }
+ f.usage.MergeAdjacent(fr)
+
+ if freed {
+ f.reclaimable = true
+ f.reclaimCond.Signal()
+ }
+}
+
+// MapInternal implements platform.File.MapInternal.
+func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ if !fr.WellFormed() || fr.Length() == 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+ if at.Execute {
+ return safemem.BlockSeq{}, syserror.EACCES
+ }
+
+ chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
+ if chunks == 1 {
+ // Avoid an unnecessary slice allocation.
+ var seq safemem.BlockSeq
+ err := f.forEachMappingSlice(fr, func(bs []byte) {
+ seq = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(bs))
+ })
+ return seq, err
+ }
+ blocks := make([]safemem.Block, 0, chunks)
+ err := f.forEachMappingSlice(fr, func(bs []byte) {
+ blocks = append(blocks, safemem.BlockFromSafeSlice(bs))
+ })
+ return safemem.BlockSeqFromSlice(blocks), err
+}
+
+// forEachMappingSlice invokes fn on a sequence of byte slices that
+// collectively map all bytes in fr.
+func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error {
+ mappings := f.mappings.Load().([]uintptr)
+ for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
+ chunk := int(chunkStart >> chunkShift)
+ m := atomic.LoadUintptr(&mappings[chunk])
+ if m == 0 {
+ var err error
+ mappings, m, err = f.getChunkMapping(chunk)
+ if err != nil {
+ return err
+ }
+ }
+ startOff := uint64(0)
+ if chunkStart < fr.Start {
+ startOff = fr.Start - chunkStart
+ }
+ endOff := uint64(chunkSize)
+ if chunkStart+chunkSize > fr.End {
+ endOff = fr.End - chunkStart
+ }
+ fn(unsafeSlice(m, chunkSize)[startOff:endOff])
+ }
+ return nil
+}
+
+func (f *MemoryFile) getChunkMapping(chunk int) ([]uintptr, uintptr, error) {
+ f.mappingsMu.Lock()
+ defer f.mappingsMu.Unlock()
+ // Another thread may have replaced f.mappings altogether due to file
+ // expansion.
+ mappings := f.mappings.Load().([]uintptr)
+ // Another thread may have already mapped the chunk.
+ if m := mappings[chunk]; m != 0 {
+ return mappings, m, nil
+ }
+ m, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ 0,
+ chunkSize,
+ syscall.PROT_READ|syscall.PROT_WRITE,
+ syscall.MAP_SHARED,
+ f.file.Fd(),
+ uintptr(chunk<<chunkShift))
+ if errno != 0 {
+ return nil, 0, errno
+ }
+ atomic.StoreUintptr(&mappings[chunk], m)
+ return mappings, m, nil
+}
+
+// MarkEvictable allows f to request memory deallocation by calling
+// user.Evict(er) in the future.
+//
+// Redundantly marking an already-evictable range as evictable has no effect.
+func (f *MemoryFile) MarkEvictable(user EvictableMemoryUser, er EvictableRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ info, ok := f.evictable[user]
+ if !ok {
+ info = &evictableMemoryUserInfo{}
+ f.evictable[user] = info
+ }
+ gap := info.ranges.LowerBoundGap(er.Start)
+ for gap.Ok() && gap.Start() < er.End {
+ gapER := gap.Range().Intersect(er)
+ if gapER.Length() == 0 {
+ gap = gap.NextGap()
+ continue
+ }
+ gap = info.ranges.Insert(gap, gapER, evictableRangeSetValue{}).NextGap()
+ }
+ if !info.evicting {
+ switch f.opts.DelayedEviction {
+ case DelayedEvictionDisabled:
+ // Kick off eviction immediately.
+ f.startEvictionGoroutineLocked(user, info)
+ case DelayedEvictionEnabled:
+ if !f.opts.UseHostMemcgPressure {
+ // Ensure that the reclaimer goroutine is running, so that it
+ // can start eviction when necessary.
+ f.reclaimCond.Signal()
+ }
+ }
+ }
+}
+
+// MarkUnevictable informs f that user no longer considers er to be evictable,
+// so the MemoryFile should no longer call user.Evict(er). Note that, per
+// EvictableMemoryUser.Evict's documentation, user.Evict(er) may still be
+// called even after MarkUnevictable returns due to race conditions, and
+// implementations of EvictableMemoryUser must handle this possibility.
+//
+// Redundantly marking an already-unevictable range as unevictable has no
+// effect.
+func (f *MemoryFile) MarkUnevictable(user EvictableMemoryUser, er EvictableRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ info, ok := f.evictable[user]
+ if !ok {
+ return
+ }
+ seg := info.ranges.LowerBoundSegment(er.Start)
+ for seg.Ok() && seg.Start() < er.End {
+ seg = info.ranges.Isolate(seg, er)
+ seg = info.ranges.Remove(seg).NextSegment()
+ }
+ // We can only remove info if there's no eviction goroutine running on its
+ // behalf.
+ if !info.evicting && info.ranges.IsEmpty() {
+ delete(f.evictable, user)
+ }
+}
+
+// MarkAllUnevictable informs f that user no longer considers any offsets to be
+// evictable. It otherwise has the same semantics as MarkUnevictable.
+func (f *MemoryFile) MarkAllUnevictable(user EvictableMemoryUser) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ info, ok := f.evictable[user]
+ if !ok {
+ return
+ }
+ info.ranges.RemoveAll()
+ // We can only remove info if there's no eviction goroutine running on its
+ // behalf.
+ if !info.evicting {
+ delete(f.evictable, user)
+ }
+}
+
+// ShouldCacheEvictable returns true if f is meaningfully delaying evictions of
+// evictable memory, such that it may be advantageous to cache data in
+// evictable memory. The value returned by ShouldCacheEvictable may change
+// between calls.
+func (f *MemoryFile) ShouldCacheEvictable() bool {
+ return f.opts.DelayedEviction == DelayedEvictionManual || f.opts.UseHostMemcgPressure
+}
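+
+// A hypothetical caller pattern:
+//
+//	if f.ShouldCacheEvictable() {
+//		// Keep the decoded data in evictable memory: eviction is
+//		// delayed, so the cached copy is likely to remain useful.
+//	}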
+
+// UpdateUsage ensures that the memory usage statistics in
+// usage.MemoryAccounting are up to date.
+func (f *MemoryFile) UpdateUsage() error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // If the underlying usage matches what the usage tree already
+ // represents, then we can skip the scan entirely (we know it's
+ // accurate).
+ currentUsage, err := f.TotalUsage()
+ if err != nil {
+ return err
+ }
+ if currentUsage == f.usageExpected && f.usageSwapped == 0 {
+ log.Debugf("UpdateUsage: skipped with usageSwapped=0.")
+ return nil
+ }
+ // If the current usage matches the expected usage plus swap
+ // accounting, then only rescan if at least a second has passed
+ // since the last scan.
+ if currentUsage == f.usageExpected+f.usageSwapped && time.Now().Before(f.usageLast.Add(time.Second)) {
+ log.Debugf("UpdateUsage: skipped with usageSwapped!=0.")
+ return nil
+ }
+
+ f.usageLast = time.Now()
+ err = f.updateUsageLocked(currentUsage, mincore)
+ log.Debugf("UpdateUsage: currentUsage=%d, usageExpected=%d, usageSwapped=%d.",
+ currentUsage, f.usageExpected, f.usageSwapped)
+ log.Debugf("UpdateUsage: took %v.", time.Since(f.usageLast))
+ return err
+}
+
+// updateUsageLocked attempts to detect commitment of previously-uncommitted
+// pages by invoking checkCommitted, which is a function that, for each page i
+// in bs, sets committed[i] to 1 if the page is committed and 0 otherwise.
+//
+// Precondition: f.mu must be held.
+func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error {
+ // Track if anything changed to elide the merge. In the common case, we
+ // expect all segments to be committed and no merge to occur.
+ changedAny := false
+ defer func() {
+ if changedAny {
+ f.usage.MergeAll()
+ }
+
+ // Adjust the swap usage to reflect reality.
+ if f.usageExpected < currentUsage {
+ // Since no pages may be marked decommitted while we hold mu, we
+ // know that usage may have only increased since we got the last
+ // current usage. Therefore, if usageExpected is still short of
+ // currentUsage, we must assume that the difference is in pages
+ // that have been swapped.
+ newUsageSwapped := currentUsage - f.usageExpected
+ if f.usageSwapped < newUsageSwapped {
+ usage.MemoryAccounting.Inc(newUsageSwapped-f.usageSwapped, usage.System)
+ } else {
+ usage.MemoryAccounting.Dec(f.usageSwapped-newUsageSwapped, usage.System)
+ }
+ f.usageSwapped = newUsageSwapped
+ } else if f.usageSwapped != 0 {
+ // We have more usage accounted for than the file itself.
+ // That's fine, we probably caught a race where pages were
+ // being committed while the above loop was running. Just
+ // report the higher number that we found and ignore swap.
+ usage.MemoryAccounting.Dec(f.usageSwapped, usage.System)
+ f.usageSwapped = 0
+ }
+ }()
+
+ // Reused mincore buffer; it will generally be <= 4096 bytes.
+ var buf []byte
+
+ // Iterate over all usage data. There will only be usage segments
+ // present when there is an associated reference.
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ val := seg.Value()
+
+ // Already known to be committed; ignore.
+ if val.knownCommitted {
+ continue
+ }
+
+ // Assume that reclaimable pages (that aren't already known to be
+ // committed) are not committed. This isn't necessarily true, even
+ // after the reclaimer does Decommit(), because the kernel may
+ // subsequently back the hugepage-sized region containing the
+ // decommitted page with a hugepage. However, it's consistent with our
+ // treatment of unallocated pages, which have the same property.
+ if val.refs == 0 {
+ continue
+ }
+
+ // Get the range for this segment. As we touch slices, r.Start
+ // is advanced past each slice (see below).
+ r := seg.Range()
+
+ var checkErr error
+ err := f.forEachMappingSlice(r, func(s []byte) {
+ if checkErr != nil {
+ return
+ }
+
+ // Ensure that we have sufficient buffer for the call
+ // (one byte per page). The length of each slice must
+ // be page-aligned.
+ bufLen := len(s) / usermem.PageSize
+ if len(buf) < bufLen {
+ buf = make([]byte, bufLen)
+ }
+
+ // Query for new pages in core.
+ if err := checkCommitted(s, buf); err != nil {
+ checkErr = err
+ return
+ }
+
+ // Scan each page and switch out segments.
+ populatedRun := false
+ populatedRunStart := 0
+ for i := 0; i <= bufLen; i++ {
+ // We run past the end of the slice here to
+ // simplify the logic and only set populated if
+ // we're still looking at elements.
+ populated := false
+ if i < bufLen {
+ populated = buf[i]&0x1 != 0
+ }
+
+ switch {
+ case populated == populatedRun:
+ // Keep the run going.
+ continue
+ case populated && !populatedRun:
+ // Begin the run.
+ populatedRun = true
+ populatedRunStart = i
+ // Keep going.
+ continue
+ case !populated && populatedRun:
+ // Finish the run by changing this segment.
+ runRange := platform.FileRange{
+ Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
+ End: r.Start + uint64(i*usermem.PageSize),
+ }
+ seg = f.usage.Isolate(seg, runRange)
+ seg.ValuePtr().knownCommitted = true
+ // Advance the segment only if we still
+ // have work to do in the context of
+ // the original segment from the for
+ // loop. Otherwise, the for loop itself
+ // will advance the segment
+ // appropriately.
+ if runRange.End != r.End {
+ seg = seg.NextSegment()
+ }
+ amount := runRange.Length()
+ usage.MemoryAccounting.Inc(amount, val.kind)
+ f.usageExpected += amount
+ changedAny = true
+ populatedRun = false
+ }
+ }
+
+ // Advance r.Start.
+ r.Start += uint64(len(s))
+ })
+ if checkErr != nil {
+ return checkErr
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
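+
+// To illustrate the checkCommitted contract used above: for a three-page
+// slice bs in which only the middle page is resident, checkCommitted must
+// set the first three bytes of committed to [0, 1, 0]. The mincore-based
+// implementation in pgalloc_unsafe.go satisfies this because mincore(2)
+// sets bit 0 of each output byte for resident pages, and the scan above
+// tests only buf[i]&0x1.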
+
+// TotalUsage returns an aggregate usage for all memory statistics except
+// Mapped (which is external to MemoryFile). This is generally much cheaper
+// than UpdateUsage, but will not provide a fine-grained breakdown.
+func (f *MemoryFile) TotalUsage() (uint64, error) {
+ // Stat the underlying file to discover its usage. stat(2) always
+ // reports the allocated block count in units of 512 bytes. This
+ // includes pages in the page cache and swapped pages.
+ var stat syscall.Stat_t
+ if err := syscall.Fstat(int(f.file.Fd()), &stat); err != nil {
+ return 0, err
+ }
+ return uint64(stat.Blocks * 512), nil
+}
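+
+// For example, a backing file with 32 allocated 512-byte blocks yields
+// 32 * 512 = 16384 bytes of total usage, i.e. four 4 KiB pages, whether
+// those pages are in the page cache or swapped out.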
+
+// TotalSize returns the current size of the backing file in bytes, which is an
+// upper bound on the amount of memory that can currently be allocated from the
+// MemoryFile. The value returned by TotalSize is permitted to change.
+func (f *MemoryFile) TotalSize() uint64 {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return uint64(f.fileSize)
+}
+
+// File returns the backing file.
+func (f *MemoryFile) File() *os.File {
+ return f.file
+}
+
+// FD implements platform.File.FD.
+func (f *MemoryFile) FD() int {
+ return int(f.file.Fd())
+}
+
+// String implements fmt.Stringer.String.
+//
+// Note that because f.String locks f.mu, calling f.String internally
+// (including indirectly through the fmt package) risks recursive locking.
+// Within the pgalloc package, use f.usage directly instead.
+func (f *MemoryFile) String() string {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return f.usage.String()
+}
+
+// runReclaim implements the reclaimer goroutine, which continuously decommits
+// reclaimable pages in order to reduce memory usage and make them available
+// for allocation.
+func (f *MemoryFile) runReclaim() {
+ for {
+ // N.B. We must call f.markReclaimed on the returned FileRange.
+ fr, ok := f.findReclaimable()
+ if !ok {
+ break
+ }
+
+ if err := f.Decommit(fr); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
+ // Zero the pages manually. This won't reduce memory usage, but at
+ // least ensures that the pages will be zero when reallocated.
+ f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ })
+ // Pretend the pages were decommitted even though they weren't,
+ // since the memory accounting implementation has no idea how to
+ // deal with this.
+ f.markDecommitted(fr)
+ }
+ f.markReclaimed(fr)
+ }
+
+ // We only get here if findReclaimable finds f.destroyed set and returns
+ // false.
+ f.mu.Lock()
+ if !f.destroyed {
+ f.mu.Unlock()
+ panic("findReclaimable broke out of reclaim loop, but destroyed is no longer set")
+ }
+ f.file.Close()
+ // Ensure that any attempts to use f.file.Fd() fail instead of getting a fd
+ // that has possibly been reassigned.
+ f.file = nil
+ f.mappingsMu.Lock()
+ defer f.mappingsMu.Unlock()
+ mappings := f.mappings.Load().([]uintptr)
+ for i, m := range mappings {
+ if m != 0 {
+ _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m, chunkSize, 0)
+ if errno != 0 {
+ log.Warningf("Failed to unmap mapping %#x for MemoryFile chunk %d: %v", m, i, errno)
+ }
+ }
+ }
+ // Similarly, invalidate f.mappings. (atomic.Value.Store(nil) panics.)
+ f.mappings.Store([]uintptr{})
+ f.mu.Unlock()
+
+ // This must be called without holding f.mu to avoid circular lock
+ // ordering.
+ if f.stopNotifyPressure != nil {
+ f.stopNotifyPressure()
+ }
+}
+
+// findReclaimable finds memory that has been marked for reclaim.
+//
+// Note that the returned range will be removed from tracking. It
+// must be reclaimed (removed from f.usage) at this point.
+func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for {
+ for {
+ if f.destroyed {
+ return platform.FileRange{}, false
+ }
+ if f.reclaimable {
+ break
+ }
+ if f.opts.DelayedEviction == DelayedEvictionEnabled && !f.opts.UseHostMemcgPressure {
+ // No work to do. Evict any pending evictable allocations to
+ // get more reclaimable pages before going to sleep.
+ f.startEvictionsLocked()
+ }
+ f.reclaimCond.Wait()
+ }
+ // Allocate works from the back of the file inwards, so reclaim
+ // preserves this order to minimize the cost of the search.
+ if seg := f.reclaim.LastSegment(); seg.Ok() {
+ fr := seg.Range()
+ f.reclaim.Remove(seg)
+ return fr, true
+ }
+ // Nothing is reclaimable.
+ f.reclaimable = false
+ }
+}
+
+func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ seg := f.usage.FindSegment(fr.Start)
+ // All of fr should be mapped to a single uncommitted reclaimable
+ // segment accounted to System.
+ if !seg.Ok() {
+ panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage))
+ }
+ if !seg.Range().IsSupersetOf(fr) {
+ panic(fmt.Sprintf("reclaimed pages %v are not entirely contained in segment %v with state %v:\n%v", fr, seg.Range(), seg.Value(), &f.usage))
+ }
+ if got, want := seg.Value(), (usageInfo{
+ kind: usage.System,
+ knownCommitted: false,
+ refs: 0,
+ }); got != want {
+ panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage))
+ }
+ // Deallocate reclaimed pages. Even though all of seg is reclaimable,
+ // the caller of markReclaimed may not have decommitted it, so we can
+ // only mark fr as reclaimed.
+ f.usage.Remove(f.usage.Isolate(seg, fr))
+}
+
+// StartEvictions requests that f evict all evictable allocations. It does not
+// wait for eviction to complete; for this, see MemoryFile.WaitForEvictions.
+func (f *MemoryFile) StartEvictions() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.startEvictionsLocked()
+}
+
+// Preconditions: f.mu must be locked.
+func (f *MemoryFile) startEvictionsLocked() bool {
+ startedAny := false
+ for user, info := range f.evictable {
+ // Don't start multiple goroutines to evict the same user's
+ // allocations.
+ if !info.evicting {
+ f.startEvictionGoroutineLocked(user, info)
+ startedAny = true
+ }
+ }
+ return startedAny
+}
+
+// Preconditions: info == f.evictable[user]. !info.evicting. f.mu must be
+// locked.
+func (f *MemoryFile) startEvictionGoroutineLocked(user EvictableMemoryUser, info *evictableMemoryUserInfo) {
+ info.evicting = true
+ f.evictionWG.Add(1)
+ go func() { // S/R-SAFE: f.evictionWG
+ defer f.evictionWG.Done()
+ for {
+ f.mu.Lock()
+ info, ok := f.evictable[user]
+ if !ok {
+ // This shouldn't happen: only this goroutine is permitted
+ // to delete this entry.
+ f.mu.Unlock()
+ panic(fmt.Sprintf("evictableMemoryUserInfo for EvictableMemoryUser %v deleted while eviction goroutine running", user))
+ }
+ if info.ranges.IsEmpty() {
+ delete(f.evictable, user)
+ f.mu.Unlock()
+ return
+ }
+ // Evict from the end of info.ranges, under the assumption that
+ // if ranges in user start being used again (and are
+ // consequently marked unevictable), such uses are more likely
+ // to start from the beginning of user.
+ seg := info.ranges.LastSegment()
+ er := seg.Range()
+ info.ranges.Remove(seg)
+ // user.Evict() must be called without holding f.mu to avoid
+ // circular lock ordering.
+ f.mu.Unlock()
+ user.Evict(context.Background(), er)
+ }
+ }()
+}
+
+// WaitForEvictions blocks until f is no longer evicting any evictable
+// allocations.
+func (f *MemoryFile) WaitForEvictions() {
+ f.evictionWG.Wait()
+}
+
+type usageSetFunctions struct{}
+
+func (usageSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (usageSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (usageSetFunctions) ClearValue(val *usageInfo) {
+}
+
+func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) {
+ return val1, val1 == val2
+}
+
+func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
+ return val, val
+}
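+
+// Merge permits the usage set to coalesce adjacent segments only when their
+// usageInfo values are identical: for instance, two adjacent ranges that are
+// both {kind: usage.System, knownCommitted: true, refs: 1} collapse into a
+// single segment, while adjacent ranges differing in refs remain distinct.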
+
+// evictableRangeSetValue is the value type of evictableRangeSet.
+type evictableRangeSetValue struct{}
+
+type evictableRangeSetFunctions struct{}
+
+func (evictableRangeSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (evictableRangeSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (evictableRangeSetFunctions) ClearValue(val *evictableRangeSetValue) {
+}
+
+func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetValue, _ EvictableRange, _ evictableRangeSetValue) (evictableRangeSetValue, bool) {
+ return evictableRangeSetValue{}, true
+}
+
+func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) {
+ return evictableRangeSetValue{}, evictableRangeSetValue{}
+}
+
+// reclaimSetValue is the value type of reclaimSet.
+type reclaimSetValue struct{}
+
+type reclaimSetFunctions struct{}
+
+func (reclaimSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (reclaimSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
+}
+
+func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+ return reclaimSetValue{}, true
+}
+
+func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+ return reclaimSetValue{}, reclaimSetValue{}
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go
new file mode 100644
index 000000000..405db141f
--- /dev/null
+++ b/pkg/sentry/pgalloc/pgalloc_test.go
@@ -0,0 +1,246 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ page = usermem.PageSize
+ hugepage = usermem.HugePageSize
+ topPage = (1 << 63) - page
+)
+
+func TestFindUnallocatedRange(t *testing.T) {
+ for _, test := range []struct {
+ desc string
+ usage *usageSegmentDataSlices
+ fileSize int64
+ length uint64
+ alignment uint64
+ start uint64
+ expectFail bool
+ }{
+ {
+ desc: "Initial allocation succeeds",
+ usage: &usageSegmentDataSlices{},
+ length: page,
+ alignment: page,
+ start: chunkSize - page, // Grows by chunkSize, allocate down.
+ },
+ {
+ desc: "Allocation finds empty space at start of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page},
+ End: []uint64{2 * page},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocation finds empty space at end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0},
+ End: []uint64{page},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: page,
+ },
+ {
+ desc: "In-use frames are not allocatable",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, page},
+ End: []uint64{page, 2 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 3 * page, // Double fileSize, allocate top-down.
+ },
+ {
+ desc: "Reclaimable frames are not allocatable",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, page, 2 * page},
+ End: []uint64{page, 2 * page, 3 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}},
+ },
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: 5 * page, // Double fileSize, grow down.
+ },
+ {
+ desc: "Gaps between in-use frames are allocatable",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2 * page},
+ End: []uint64{page, 3 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: page,
+ },
+ {
+ desc: "Inadequately-sized gaps are rejected",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2 * page},
+ End: []uint64{page, 3 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 3 * page,
+ length: 2 * page,
+ alignment: page,
+ start: 4 * page, // Double fileSize, grow down.
+ },
+ {
+ desc: "Alignment is honored at end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, hugepage + page},
+ // Hugepage-sized gap here that shouldn't be allocated from
+ // since it's incorrectly aligned.
+ End: []uint64{page, hugepage + 2*page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down.
+ },
+ {
+ desc: "Alignment is honored before end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2*hugepage + page},
+ // Page will need to be shifted down from top.
+ End: []uint64{page, 2*hugepage + 2*page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 2*hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: hugepage,
+ },
+ {
+ desc: "Allocation doubles file size more than once if necessary",
+ usage: &usageSegmentDataSlices{},
+ fileSize: page,
+ length: 4 * page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocations are compact if possible",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page},
+ End: []uint64{2 * page, 4 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}},
+ },
+ fileSize: 4 * page,
+ length: page,
+ alignment: page,
+ start: 2 * page,
+ },
+ {
+ desc: "Top-down allocation within one gap",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 4 * page, 7 * page},
+ End: []uint64{2 * page, 5 * page, 8 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 6 * page,
+ },
+ {
+ desc: "Top-down allocation between multiple gaps",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page, 5 * page},
+ End: []uint64{2 * page, 4 * page, 6 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 6 * page,
+ length: page,
+ alignment: page,
+ start: 4 * page,
+ },
+ {
+ desc: "Top-down allocation with large top gap",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page},
+ End: []uint64{2 * page, 4 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}},
+ },
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 7 * page,
+ },
+ {
+ desc: "Gaps found with possible overflow",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, topPage - page},
+ End: []uint64{2 * page, topPage},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: topPage,
+ length: page,
+ alignment: page,
+ start: topPage - 2*page,
+ },
+ {
+ desc: "Overflow detected",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page},
+ End: []uint64{topPage},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: topPage,
+ length: 2 * page,
+ alignment: page,
+ expectFail: true,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ var usage usageSet
+ if err := usage.ImportSortedSlices(test.usage); err != nil {
+ t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err)
+ }
+ fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment)
+ if !test.expectFail && !ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if test.expectFail && ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if ok && fr.Start != test.start {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if ok && fr.End != test.start+test.length {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_unsafe.go b/pkg/sentry/pgalloc/pgalloc_unsafe.go
new file mode 100644
index 000000000..a4b5d581c
--- /dev/null
+++ b/pkg/sentry/pgalloc/pgalloc_unsafe.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+)
+
+func unsafeSlice(addr uintptr, length int) (slice []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ sh.Data = addr
+ sh.Len = length
+ sh.Cap = length
+ return
+}
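+
+// forEachMappingSlice uses this to view an entire mmap(2)-ed chunk as a
+// []byte without copying; the mapping must outlive any use of the returned
+// slice.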
+
+func mincore(s []byte, buf []byte) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_MINCORE,
+ uintptr(unsafe.Pointer(&s[0])),
+ uintptr(len(s)),
+ uintptr(unsafe.Pointer(&buf[0]))); errno != 0 {
+ return errno
+ }
+ return nil
+}
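+
+// Per mincore(2), each byte of buf corresponds to one page of s, with bit 0
+// set if the page is resident and the remaining bits reserved; callers such
+// as updateUsageLocked accordingly test only buf[i]&0x1.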
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
new file mode 100644
index 000000000..78317fa35
--- /dev/null
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -0,0 +1,212 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SaveTo writes f's state to the given stream.
+func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error {
+ // Wait for reclaim.
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for f.reclaimable {
+ f.reclaimCond.Signal()
+ f.mu.Unlock()
+ runtime.Gosched()
+ f.mu.Lock()
+ }
+
+ // Ensure that there are no pending evictions.
+ if len(f.evictable) != 0 {
+ panic(fmt.Sprintf("evictions still pending for %d users; call StartEvictions and WaitForEvictions before SaveTo", len(f.evictable)))
+ }
+
+ // Ensure that all pages that contain data have knownCommitted set, since
+ // we only store knownCommitted pages below.
+ zeroPage := make([]byte, usermem.PageSize)
+ err := f.updateUsageLocked(0, func(bs []byte, committed []byte) error {
+ for pgoff := 0; pgoff < len(bs); pgoff += usermem.PageSize {
+ i := pgoff / usermem.PageSize
+ pg := bs[pgoff : pgoff+usermem.PageSize]
+ if !bytes.Equal(pg, zeroPage) {
+ committed[i] = 1
+ continue
+ }
+ committed[i] = 0
+ // Reading the page caused it to be committed; decommit it to
+ // reduce memory usage.
+ //
+ // "MADV_REMOVE [...] Free up a given range of pages and its
+ // associated backing store. This is equivalent to punching a hole
+ // in the corresponding byte range of the backing store (see
+ // fallocate(2))." - madvise(2)
+ if err := syscall.Madvise(pg, syscall.MADV_REMOVE); err != nil {
+ // This doesn't impact the correctness of saved memory; it
+ // just means that we're incrementally more likely to OOM.
+ // Complain, but don't abort saving.
+ log.Warningf("Decommitting page %p while saving failed: %v", pg, err)
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ // Save metadata.
+ if _, err := state.Save(ctx, w, &f.fileSize); err != nil {
+ return err
+ }
+ if _, err := state.Save(ctx, w, &f.usage); err != nil {
+ return err
+ }
+
+ // Dump out committed pages.
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if !seg.Value().knownCommitted {
+ continue
+ }
+ // Write a header to distinguish from objects.
+ if err := state.WriteHeader(w, uint64(seg.Range().Length()), false); err != nil {
+ return err
+ }
+ // Write out data.
+ var ioErr error
+ err := f.forEachMappingSlice(seg.Range(), func(s []byte) {
+ if ioErr != nil {
+ return
+ }
+ _, ioErr = w.Write(s)
+ })
+ if ioErr != nil {
+ return ioErr
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
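+
+// The stream written above is thus: f.fileSize, then f.usage, then, for each
+// knownCommitted segment in offset order, a length header followed by the
+// segment's raw page contents. LoadFrom below consumes exactly this sequence.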
+
+// LoadFrom loads MemoryFile state from the given stream.
+func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error {
+ // Load metadata.
+ if _, err := state.Load(ctx, r, &f.fileSize); err != nil {
+ return err
+ }
+ if err := f.file.Truncate(f.fileSize); err != nil {
+ return err
+ }
+ newMappings := make([]uintptr, f.fileSize>>chunkShift)
+ f.mappings.Store(newMappings)
+ if _, err := state.Load(ctx, r, &f.usage); err != nil {
+ return err
+ }
+
+ // Try to map committed chunks concurrently: for any given chunk, either
+ // the mapper goroutine below or the load loop that follows will mmap the
+ // chunk first and cache it in f.mappings for the other, but the mapper is
+ // likely to run ahead since it does no work between mmaps. The rest of
+ // this function doesn't mutate f.usage, so it's safe to iterate
+ // concurrently.
+ mapperDone := make(chan struct{})
+ mapperCanceled := int32(0)
+ go func() { // S/R-SAFE: see comment
+ defer func() { close(mapperDone) }()
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if atomic.LoadInt32(&mapperCanceled) != 0 {
+ return
+ }
+ if seg.Value().knownCommitted {
+ f.forEachMappingSlice(seg.Range(), func(s []byte) {})
+ }
+ }
+ }()
+ defer func() {
+ atomic.StoreInt32(&mapperCanceled, 1)
+ <-mapperDone
+ }()
+
+ // Load committed pages.
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if !seg.Value().knownCommitted {
+ continue
+ }
+ // Verify header.
+ length, object, err := state.ReadHeader(r)
+ if err != nil {
+ return err
+ }
+ if object {
+ // Not expected.
+ return fmt.Errorf("unexpected object")
+ }
+ if expected := uint64(seg.Range().Length()); length != expected {
+ // Size mismatch.
+ return fmt.Errorf("mismatched segment: expected %d, got %d", expected, length)
+ }
+ // Read data.
+ var ioErr error
+ err = f.forEachMappingSlice(seg.Range(), func(s []byte) {
+ if ioErr != nil {
+ return
+ }
+ _, ioErr = io.ReadFull(r, s)
+ })
+ if ioErr != nil {
+ return ioErr
+ }
+ if err != nil {
+ return err
+ }
+
+ // Update accounting for restored pages. We need to do this here since
+ // these segments are marked as "known committed", and will be skipped
+ // over on accounting scans.
+ usage.MemoryAccounting.Inc(seg.End()-seg.Start(), seg.Value().kind)
+ }
+
+ return nil
+}
+
+// MemoryFileProvider provides the MemoryFile method.
+//
+// This type exists to work around a save/restore defect. The only object in a
+// saved object graph that S/R allows to be replaced at time of restore is the
+// starting point of the restore, kernel.Kernel. However, the MemoryFile
+// changes between save and restore as well, so objects that need persistent
+// access to the MemoryFile must instead store a pointer to the Kernel and call
+// Kernel.MemoryFile() as required. In most cases, depending on the kernel
+// package directly would create a package dependency loop, so the stored
+// pointer must instead be a MemoryFileProvider interface object.
+// kernel.Kernel is the only implementation of this interface.
+type MemoryFileProvider interface {
+ // MemoryFile returns the Kernel MemoryFile.
+ MemoryFile() *MemoryFile
+}
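+
+// A minimal conforming implementation, sketched here with a hypothetical
+// type (kernel.Kernel is the only real implementation, as noted above):
+//
+//	type exampleProvider struct {
+//		mf *MemoryFile
+//	}
+//
+//	// MemoryFile implements MemoryFileProvider.MemoryFile.
+//	func (p *exampleProvider) MemoryFile() *MemoryFile { return p.mf }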
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
new file mode 100644
index 000000000..453241eca
--- /dev/null
+++ b/pkg/sentry/platform/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "file_range",
+ out = "file_range.go",
+ package = "platform",
+ prefix = "File",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
+go_library(
+ name = "platform",
+ srcs = [
+ "context.go",
+ "file_range.go",
+ "mmap_min_addr.go",
+ "platform.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/atomicbitops",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/safecopy",
+ "//pkg/safemem",
+ "//pkg/seccomp",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/usage",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/platform/context.go b/pkg/sentry/platform/context.go
new file mode 100644
index 000000000..6759cda65
--- /dev/null
+++ b/pkg/sentry/platform/context.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package platform
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the platform package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxPlatform is a Context.Value key for a Platform.
+ CtxPlatform contextID = iota
+)
+
+// FromContext returns the Platform that is used to execute ctx's application
+// code, or nil if no such Platform exists.
+func FromContext(ctx context.Context) Platform {
+ if v := ctx.Value(CtxPlatform); v != nil {
+ return v.(Platform)
+ }
+ return nil
+}
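+
+// A hypothetical call site, assuming ctx was populated with a CtxPlatform
+// value:
+//
+//	if p := platform.FromContext(ctx); p != nil {
+//		// Use p to create application address spaces and contexts.
+//	}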
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
new file mode 100644
index 000000000..83b385f14
--- /dev/null
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "interrupt",
+ srcs = [
+ "interrupt.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
+)
+
+go_test(
+ name = "interrupt_test",
+ size = "small",
+ srcs = ["interrupt_test.go"],
+ library = ":interrupt",
+)
diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go
new file mode 100644
index 000000000..57be41647
--- /dev/null
+++ b/pkg/sentry/platform/interrupt/interrupt.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package interrupt provides an interrupt helper.
+package interrupt
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Receiver receives interrupt notifications from a Forwarder.
+type Receiver interface {
+ // NotifyInterrupt is called when the Receiver receives an interrupt.
+ NotifyInterrupt()
+}
+
+// Forwarder is a helper for delivering delayed signal interruptions.
+//
+// This helps platform implementations with Interrupt semantics.
+type Forwarder struct {
+ // mu protects the below.
+ mu sync.Mutex
+
+ // dst is the Receiver that is notified when NotifyInterrupt() is
+ // called. If dst is nil, pending will be set instead, causing the
+ // next call to Enable() to return false.
+ dst Receiver
+ pending bool
+}
+
+// Enable attempts to enable interrupt forwarding to r. If f has already
+// received an interrupt, Enable does nothing and returns false. Otherwise,
+// future calls to f.NotifyInterrupt() cause r.NotifyInterrupt() to be called,
+// and Enable returns true.
+//
+// Usage:
+//
+// if !f.Enable(r) {
+// // There was an interrupt.
+// return
+// }
+// defer f.Disable()
+//
+// Preconditions: r must not be nil. f must not already be forwarding
+// interrupts to a Receiver.
+func (f *Forwarder) Enable(r Receiver) bool {
+ if r == nil {
+ panic("nil Receiver")
+ }
+ f.mu.Lock()
+ if f.dst != nil {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("already forwarding interrupts to %+v", f.dst))
+ }
+ if f.pending {
+ f.pending = false
+ f.mu.Unlock()
+ return false
+ }
+ f.dst = r
+ f.mu.Unlock()
+ return true
+}
+
+// Disable stops interrupt forwarding. If interrupt forwarding is already
+// disabled, Disable is a no-op.
+func (f *Forwarder) Disable() {
+ f.mu.Lock()
+ f.dst = nil
+ f.mu.Unlock()
+}
+
+// NotifyInterrupt implements Receiver.NotifyInterrupt. If interrupt forwarding
+// is enabled, the configured Receiver will be notified. Otherwise the
+// interrupt will be delivered to the next call to Enable.
+func (f *Forwarder) NotifyInterrupt() {
+ f.mu.Lock()
+ if f.dst != nil {
+ f.dst.NotifyInterrupt()
+ } else {
+ f.pending = true
+ }
+ f.mu.Unlock()
+}
diff --git a/pkg/sentry/platform/interrupt/interrupt_test.go b/pkg/sentry/platform/interrupt/interrupt_test.go
new file mode 100644
index 000000000..0ecdf6e7a
--- /dev/null
+++ b/pkg/sentry/platform/interrupt/interrupt_test.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package interrupt
+
+import (
+ "testing"
+)
+
+type countingReceiver struct {
+ interrupts int
+}
+
+// NotifyInterrupt implements Receiver.NotifyInterrupt.
+func (r *countingReceiver) NotifyInterrupt() {
+ r.interrupts++
+}
+
+func TestSingleInterruptBeforeEnable(t *testing.T) {
+ var (
+ f Forwarder
+ r countingReceiver
+ )
+ f.NotifyInterrupt()
+ // The interrupt should cause the first Enable to fail.
+ if f.Enable(&r) {
+ f.Disable()
+ t.Fatalf("Enable: got true, wanted false")
+ }
+ // The failing Enable "acknowledges" the interrupt, allowing future Enables
+ // to succeed.
+ if !f.Enable(&r) {
+ t.Fatalf("Enable: got false, wanted true")
+ }
+ f.Disable()
+}
+
+func TestMultipleInterruptsBeforeEnable(t *testing.T) {
+ var (
+ f Forwarder
+ r countingReceiver
+ )
+ f.NotifyInterrupt()
+ f.NotifyInterrupt()
+ // The interrupts should cause the first Enable to fail.
+ if f.Enable(&r) {
+ f.Disable()
+ t.Fatalf("Enable: got true, wanted false")
+ }
+ // Interrupts are deduplicated while the Forwarder is disabled, so the
+ // failing Enable "acknowledges" all interrupts, allowing future Enables to
+ // succeed.
+ if !f.Enable(&r) {
+ t.Fatalf("Enable: got false, wanted true")
+ }
+ f.Disable()
+}
+
+func TestSingleInterruptAfterEnable(t *testing.T) {
+ var (
+ f Forwarder
+ r countingReceiver
+ )
+ if !f.Enable(&r) {
+ t.Fatalf("Enable: got false, wanted true")
+ }
+ defer f.Disable()
+ f.NotifyInterrupt()
+ if r.interrupts != 1 {
+ t.Errorf("interrupts: got %d, wanted 1", r.interrupts)
+ }
+}
+
+func TestMultipleInterruptsAfterEnable(t *testing.T) {
+ var (
+ f Forwarder
+ r countingReceiver
+ )
+ if !f.Enable(&r) {
+ t.Fatalf("Enable: got false, wanted true")
+ }
+ defer f.Disable()
+ f.NotifyInterrupt()
+ f.NotifyInterrupt()
+ if r.interrupts != 2 {
+ t.Errorf("interrupts: got %d, wanted 2", r.interrupts)
+ }
+}
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
new file mode 100644
index 000000000..4792454c4
--- /dev/null
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -0,0 +1,80 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "kvm",
+ srcs = [
+ "address_space.go",
+ "bluepill.go",
+ "bluepill_allocator.go",
+ "bluepill_amd64.go",
+ "bluepill_amd64.s",
+ "bluepill_amd64_unsafe.go",
+ "bluepill_arm64.go",
+ "bluepill_arm64.s",
+ "bluepill_arm64_unsafe.go",
+ "bluepill_fault.go",
+ "bluepill_unsafe.go",
+ "context.go",
+ "filters_amd64.go",
+ "filters_arm64.go",
+ "kvm.go",
+ "kvm_amd64.go",
+ "kvm_amd64_unsafe.go",
+ "kvm_arm64.go",
+ "kvm_arm64_unsafe.go",
+ "kvm_const.go",
+ "kvm_const_arm64.go",
+ "machine.go",
+ "machine_amd64.go",
+ "machine_amd64_unsafe.go",
+ "machine_arm64.go",
+ "machine_arm64_unsafe.go",
+ "machine_unsafe.go",
+ "physical_map.go",
+ "physical_map_amd64.go",
+ "physical_map_arm64.go",
+ "virtual_map.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/atomicbitops",
+ "//pkg/cpuid",
+ "//pkg/log",
+ "//pkg/procid",
+ "//pkg/safecopy",
+ "//pkg/seccomp",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/interrupt",
+ "//pkg/sentry/platform/ring0",
+ "//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/sentry/time",
+ "//pkg/sync",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "kvm_test",
+ srcs = [
+ "kvm_test.go",
+ "virtual_map_test.go",
+ ],
+ library = ":kvm",
+ tags = [
+ "manual",
+ "nogotsan",
+ "requires-kvm",
+ ],
+ deps = [
+ "//pkg/sentry/arch",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/kvm/testutil",
+ "//pkg/sentry/platform/ring0",
+ "//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
new file mode 100644
index 000000000..faf1d5e1c
--- /dev/null
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -0,0 +1,249 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// dirtySet tracks vCPUs for invalidation.
+type dirtySet struct {
+ vCPUMasks []uint64
+}
+
+// forEach iterates over all CPUs in the dirty set.
+//
+//go:nosplit
+func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) {
+ for index := range ds.vCPUMasks {
+ mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0)
+ if mask != 0 {
+ for bit := 0; bit < 64; bit++ {
+ if mask&(1<<uint64(bit)) == 0 {
+ continue
+ }
+ id := 64*index + bit
+ fn(m.vCPUsByID[id])
+ }
+ }
+ }
+}
+
+// mark marks the given vCPU as dirty and returns whether it was previously
+// clean. Being previously clean implies that a flush is needed on entry.
+func (ds *dirtySet) mark(c *vCPU) bool {
+ index := uint64(c.id) / 64
+ bit := uint64(1) << uint(c.id%64)
+
+ oldValue := atomic.LoadUint64(&ds.vCPUMasks[index])
+ if oldValue&bit != 0 {
+ return false // Not clean.
+ }
+
+ // Set the bit unilaterally, and ensure that a flush takes place. Note
+ // that it's possible for races to occur here, but since the flush is
+ // taking place long after these lines there's no race in practice.
+ atomicbitops.OrUint64(&ds.vCPUMasks[index], bit)
+ return true // Previously clean.
+}
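+
+// As a worked example: a vCPU with id 70 maps to index = 70/64 = 1 and
+// bit = 1 << (70%64) = 1 << 6, i.e. bit 6 of vCPUMasks[1]; forEach later
+// recovers the id as 64*1 + 6 = 70 when it drains the mask.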
+
+// addressSpace is a wrapper for PageTables.
+type addressSpace struct {
+ platform.NoAddressSpaceIO
+
+ // mu is the lock for modifications to the address space.
+ //
+ // Note that the page tables themselves are not locked.
+ mu sync.Mutex
+
+ // machine is the underlying machine.
+ machine *machine
+
+ // pageTables are for this particular address space.
+ pageTables *pagetables.PageTables
+
+ // dirtySet is the set of dirty vCPUs.
+ dirtySet *dirtySet
+}
+
+// invalidate is the implementation for Invalidate.
+func (as *addressSpace) invalidate() {
+ as.dirtySet.forEach(as.machine, func(c *vCPU) {
+ if c.active.get() == as { // If this happens to be active,
+ c.BounceToKernel() // ... force a kernel transition.
+ }
+ })
+}
+
+// Invalidate interrupts all dirty contexts.
+func (as *addressSpace) Invalidate() {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+ as.invalidate()
+}
+
+// Touch adds the given vCPU to the dirty list.
+//
+// The return value indicates whether a flush is required.
+func (as *addressSpace) Touch(c *vCPU) bool {
+ return as.dirtySet.mark(c)
+}
+
+type hostMapEntry struct {
+ addr uintptr
+ length uintptr
+}
+
+// mapLocked maps the given host entry.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+ for m.length > 0 {
+ physical, length, ok := translateToPhysical(m.addr)
+ if !ok {
+ panic("unable to translate segment")
+ }
+ if length > m.length {
+ length = m.length
+ }
+
+ // Ensure that this map has physical mappings. If the page does
+ // not have physical mappings, the KVM module may inject
+ // spurious exceptions when emulation fails (i.e. it tries to
+ // emulate because the RIP is pointed at those pages).
+ as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
+
+ // Install the page table mappings. Note that the ordering is
+ // important; if the pagetable mappings were installed before
+ // ensuring the physical pages were available, then some other
+ // thread could theoretically access them.
+ inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
+ AccessType: at,
+ User: true,
+ }, physical) || inv
+ m.addr += length
+ m.length -= length
+ addr += usermem.Addr(length)
+ }
+
+ return inv
+}
+
+// MapFile implements platform.AddressSpace.MapFile.
+func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+
+ // Get mappings in the sentry's address space, which are guaranteed to be
+ // valid as long as a reference is held on the mapped pages (which is in
+ // turn required by AddressSpace.MapFile precondition).
+ //
+ // If precommit is true, we will touch mappings to commit them, so ensure
+ // that mappings are readable from sentry context.
+ //
+ // We don't execute from application file-mapped memory, and guest page
+ // tables don't care if we have execute permission (but they do need pages
+ // to be readable).
+ bs, err := f.MapInternal(fr, usermem.AccessType{
+ Read: at.Read || at.Execute || precommit,
+ Write: at.Write,
+ })
+ if err != nil {
+ return err
+ }
+
+ // See block in mapLocked.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+
+ // Map the mappings in the sentry's address space (guest physical memory)
+ // into the application's address space (guest virtual memory).
+ inv := false
+ for !bs.IsEmpty() {
+ b := bs.Head()
+ bs = bs.Tail()
+ // Since fr was page-aligned, b should also be page-aligned. We do the
+ // lookup in our host page tables for this translation.
+ if precommit {
+ s := b.ToSlice()
+ for i := 0; i < len(s); i += usermem.PageSize {
+ _ = s[i] // Touch to commit.
+ }
+ }
+
+ // See bluepill_allocator.go.
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ // Perform the mapping.
+ prev := as.mapLocked(addr, hostMapEntry{
+ addr: b.Addr(),
+ length: uintptr(b.Len()),
+ }, at)
+ inv = inv || prev
+ addr += usermem.Addr(b.Len())
+ }
+ if inv {
+ as.invalidate()
+ }
+
+ return nil
+}
+
+// unmapLocked is an escape-checked wrapper around Unmap.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool {
+ return as.pageTables.Unmap(addr, uintptr(length))
+}
+
+// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
+func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+
+ // See above & bluepill_allocator.go.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ if prev := as.unmapLocked(addr, length); prev {
+ // Invalidate all active vCPUs.
+ as.invalidate()
+
+ // Recycle any freed intermediate pages.
+ as.pageTables.Allocator.Recycle()
+ }
+}
+
+// Release releases the page tables.
+func (as *addressSpace) Release() {
+ as.Unmap(0, ^uint64(0))
+
+ // Free all pages from the allocator.
+ as.pageTables.Allocator.(*allocator).base.Drain()
+
+ // Drop all cached machine references.
+ as.machine.dropPageTables(as.pageTables)
+}
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
new file mode 100644
index 000000000..4b23f7803
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -0,0 +1,96 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+// bluepill enters guest mode.
+func bluepill(*vCPU)
+
+// sighandler is the signal entry point.
+func sighandler()
+
+// dieTrampoline is the assembly trampoline. This calls dieHandler.
+//
+// This uses an architecture-specific calling convention, documented in
+// dieArchSetup and the assembly implementation for dieTrampoline.
+func dieTrampoline()
+
+var (
+ // bounceSignal is the signal used for bouncing KVM.
+ //
+ // We use SIGCHLD because it is not masked by the runtime, and
+ // it will be ignored properly by other parts of the kernel.
+ bounceSignal = syscall.SIGCHLD
+
+ // bounceSignalMask has only bounceSignal set.
+ bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
+
+ // bounce is the interrupt vector used to return to the kernel.
+ bounce = uint32(ring0.VirtualizationException)
+
+ // savedHandler is a pointer to the previous handler.
+ //
+ // This is called by bluepillHandler.
+ savedHandler uintptr
+
+ // dieTrampolineAddr is the address of dieTrampoline.
+ dieTrampolineAddr uintptr
+)
+
+// redpill invokes a syscall with -1.
+//
+//go:nosplit
+func redpill() {
+ syscall.RawSyscall(^uintptr(0), 0, 0, 0)
+}
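+
+// When executed in guest mode, this syscall traps back to the host; the -1
+// syscall number is recognized by KernelSyscall (see bluepill_amd64.go),
+// which skips the usual instruction rewind for it. The bluepill assembly
+// calls redpill when it finds itself running on the wrong vCPU.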
+
+// dieHandler is called by dieTrampoline.
+//
+//go:nosplit
+func dieHandler(c *vCPU) {
+ throw(c.dieState.message)
+}
+
+// die is called to set the vCPU up to panic.
+//
+// This loads vCPU state, and sets up a call for the trampoline.
+//
+//go:nosplit
+func (c *vCPU) die(context *arch.SignalContext64, msg string) {
+ // Save the death message, which will be thrown.
+ c.dieState.message = msg
+
+ // Setup the trampoline.
+ dieArchSetup(c, context, &c.dieState.guestRegs)
+}
+
+func init() {
+ // Install the handler.
+ if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
+ }
+
+ // Extract the address for the trampoline.
+ dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go
new file mode 100644
index 000000000..9485e1301
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_allocator.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+)
+
+type allocator struct {
+ base pagetables.RuntimeAllocator
+
+ // cpu must be set prior to any pagetable operation.
+ //
+ // Due to the way KVM's shadow paging implementation works,
+ // modifications to the page tables while in host mode may not be
+ // trapped, leading to the shadow pages being out of sync. Therefore,
+ // we need to ensure that we are in guest mode for page table
+ // modifications. See the call to bluepill, below.
+ cpu *vCPU
+}
+
+// newAllocator is used to define the allocator.
+func newAllocator() *allocator {
+ a := new(allocator)
+ a.base.Init()
+ return a
+}
+
+// NewPTEs implements pagetables.Allocator.NewPTEs.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (a *allocator) NewPTEs() *pagetables.PTEs {
+ ptes := a.base.NewPTEs() // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
+ return ptes
+}
+
+// PhysicalFor returns the physical address for a set of PTEs.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
+ virtual := a.base.PhysicalFor(ptes)
+ physical, _, ok := translateToPhysical(virtual)
+ if !ok {
+ panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic.
+ }
+ return physical
+}
+
+// LookupPTEs implements pagetables.Allocator.LookupPTEs.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
+ virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
+ if !ok {
+ panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic.
+ }
+ return a.base.LookupPTEs(virtualStart + (physical - physicalStart))
+}
+
+// FreePTEs implements pagetables.Allocator.FreePTEs.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (a *allocator) FreePTEs(ptes *pagetables.PTEs) {
+ a.base.FreePTEs(ptes) // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
+}
+
+// Recycle implements pagetables.Allocator.Recycle.
+//
+//go:nosplit
+func (a *allocator) Recycle() {
+ a.base.Recycle()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
new file mode 100644
index 000000000..ddc1554d5
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+var (
+ // The action for bluepillSignal is changed by sigaction().
+ bluepillSignal = syscall.SIGSEGV
+)
+
+// bluepillArchEnter is called during bluepillEnter.
+//
+//go:nosplit
+func bluepillArchEnter(context *arch.SignalContext64) *vCPU {
+ c := vCPUPtr(uintptr(context.Rax))
+ regs := c.CPU.Registers()
+ regs.R8 = context.R8
+ regs.R9 = context.R9
+ regs.R10 = context.R10
+ regs.R11 = context.R11
+ regs.R12 = context.R12
+ regs.R13 = context.R13
+ regs.R14 = context.R14
+ regs.R15 = context.R15
+ regs.Rdi = context.Rdi
+ regs.Rsi = context.Rsi
+ regs.Rbp = context.Rbp
+ regs.Rbx = context.Rbx
+ regs.Rdx = context.Rdx
+ regs.Rax = context.Rax
+ regs.Rcx = context.Rcx
+ regs.Rsp = context.Rsp
+ regs.Rip = context.Rip
+ regs.Eflags = context.Eflags
+ regs.Eflags &^= uint64(ring0.KernelFlagsClear)
+ regs.Eflags |= ring0.KernelFlagsSet
+ regs.Cs = uint64(ring0.Kcode)
+ regs.Ds = uint64(ring0.Udata)
+ regs.Es = uint64(ring0.Udata)
+ regs.Ss = uint64(ring0.Kdata)
+ return c
+}
+
+// KernelSyscall handles kernel syscalls.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *vCPU) KernelSyscall() {
+ regs := c.Registers()
+ if regs.Rax != ^uint64(0) {
+ regs.Rip -= 2 // Rewind.
+ }
+ // We only trigger a bluepill entry in the bluepill function, and can
+ // therefore be guaranteed that there is no floating point state to be
+ // loaded on resuming from halt. We only worry about saving on exit.
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
+ ring0.Halt()
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
+}
+
+// KernelException handles kernel exceptions.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *vCPU) KernelException(vector ring0.Vector) {
+ regs := c.Registers()
+ if vector == ring0.Vector(bounce) {
+ // These should not interrupt kernel execution; point the Rip
+ // to zero to ensure that we get a reasonable panic and a full
+ // stack trace when we attempt to return.
+ regs.Rip = 0
+ }
+ // See above.
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
+ ring0.Halt()
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
+}
+
+// bluepillArchExit is called from bluepillGuestExit when returning to host mode.
+//
+//go:nosplit
+func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
+ regs := c.CPU.Registers()
+ context.R8 = regs.R8
+ context.R9 = regs.R9
+ context.R10 = regs.R10
+ context.R11 = regs.R11
+ context.R12 = regs.R12
+ context.R13 = regs.R13
+ context.R14 = regs.R14
+ context.R15 = regs.R15
+ context.Rdi = regs.Rdi
+ context.Rsi = regs.Rsi
+ context.Rbp = regs.Rbp
+ context.Rbx = regs.Rbx
+ context.Rdx = regs.Rdx
+ context.Rax = regs.Rax
+ context.Rcx = regs.Rcx
+ context.Rsp = regs.Rsp
+ context.Rip = regs.Rip
+ context.Eflags = regs.Eflags
+
+ // Set the context pointer to the saved floating point state. This is
+ // where the guest data has been serialized; the kernel will restore
+ // from this new pointer value.
+ context.Fpstate = uint64(uintptrValue((*byte)(c.floatingPointState)))
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
new file mode 100644
index 000000000..2bc34a435
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// VCPU_CPU is the location of the CPU in the vCPU struct.
+//
+// This is guaranteed to be zero.
+#define VCPU_CPU 0x0
+
+// CPU_SELF is the self reference in ring0's percpu.
+//
+// This is guaranteed to be zero.
+#define CPU_SELF 0x0
+
+// Context offsets.
+//
+// Only limited use of the context is done in the assembly stub below; most is
+// done in the Go handlers. However, the RIP must be examined.
+#define CONTEXT_RAX 0x90
+#define CONTEXT_RIP 0xa8
+#define CONTEXT_FP 0xe0
+
+// CLI is the literal byte for the disable interrupts instruction.
+//
+// This is checked as the source of the fault.
+#define CLI $0xfa
+
+// See bluepill.go.
+TEXT ·bluepill(SB),NOSPLIT,$0
+begin:
+ MOVQ vcpu+0(FP), AX
+ LEAQ VCPU_CPU(AX), BX
+ BYTE CLI;
+check_vcpu:
+ MOVQ CPU_SELF(GS), CX
+ CMPQ BX, CX
+ JE right_vCPU
+wrong_vcpu:
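+ // We entered the guest on a different vCPU than requested: leave
+ // guest mode via redpill and retry; the CLI above will fault again
+ // and re-enter on the correct vCPU.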
+ CALL ·redpill(SB)
+ JMP begin
+right_vCPU:
+ RET
+
+// sighandler: see bluepill.go for documentation.
+//
+// The arguments are the following:
+//
+// DI - The signal number.
+// SI - Pointer to siginfo_t structure.
+// DX - Pointer to ucontext structure.
+//
+TEXT ·sighandler(SB),NOSPLIT,$0
+ // Check if the signal is from the kernel.
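+ // (si_code lives at offset 0x8 of siginfo_t; SI_KERNEL is 0x80.)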
+ MOVQ $0x80, CX
+ CMPL CX, 0x8(SI)
+ JNE fallback
+
+ // Check if RIP is disable interrupts.
+ MOVQ CONTEXT_RIP(DX), CX
+ CMPQ CX, $0x0
+ JE fallback
+ CMPB 0(CX), CLI
+ JNE fallback
+
+ // Call the bluepillHandler.
+ PUSHQ DX // First argument (context).
+ CALL ·bluepillHandler(SB) // Call the handler.
+ POPQ DX // Discard the argument.
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ XORQ CX, CX
+ MOVQ ·savedHandler(SB), AX
+ JMP AX
+
+// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
+TEXT ·dieTrampoline(SB),NOSPLIT,$0
+ PUSHQ BX // First argument (vCPU).
+ PUSHQ AX // Fake the old RIP as caller.
+ JMP ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
new file mode 100644
index 000000000..03a98512e
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+// dieArchSetup initializes the state for dieTrampoline.
+//
+// The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP
+// to be in AX. The trampoline then simulates a call to dieHandler from the
+// provided RIP.
+//
+//go:nosplit
+func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // Reload all registers to have an accurate stack trace when we return
+ // to host mode. This means that the stack should be unwound correctly.
+ if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
+ throw(c.dieState.message)
+ }
+
+ // If the vCPU is in user mode, we set the stack to the stored stack
+ // value in the vCPU itself. We don't want to unwind the user stack.
+ if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet {
+ regs := c.CPU.Registers()
+ context.Rax = regs.Rax
+ context.Rsp = regs.Rsp
+ context.Rbp = regs.Rbp
+ } else {
+ context.Rax = guestRegs.RIP
+ context.Rsp = guestRegs.RSP
+ context.Rbp = guestRegs.RBP
+ context.Eflags = guestRegs.RFLAGS
+ }
+ context.Rbx = uint64(uintptr(unsafe.Pointer(c)))
+ context.Rip = uint64(dieTrampolineAddr)
+}
+
+// getHypercallID returns the hypercall ID.
+//
+// Hypercalls are dispatched via MMIO only on arm64; on amd64 this always
+// returns _KVM_HYPERCALL_MAX.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ return _KVM_HYPERCALL_MAX
+}
+
+// bluepillStopGuest is responsible for injecting an interrupt.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ // Interrupt: we must have requested an interrupt
+ // window; set the interrupt line.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_INTERRUPT,
+ uintptr(unsafe.Pointer(&bounce))); errno != 0 {
+ throw("interrupt injection failed")
+ }
+ // Clear previous injection request.
+ c.runData.requestInterruptWindow = 0
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return c.runData.readyForInterruptInjection != 0
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
new file mode 100644
index 000000000..dba563160
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -0,0 +1,124 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+var (
+ // The action for bluepillSignal is changed by sigaction().
+ bluepillSignal = syscall.SIGILL
+
+ // vcpuSErr is the event of system error.
+ vcpuSErr = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 0,
+ pad: [6]uint8{0, 0, 0, 0, 0, 0},
+ sErrEsr: 1,
+ },
+ rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
+)
+
+// bluepillArchEnter is called from bluepillHandler when entering the guest.
+//
+//go:nosplit
+func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) {
+ c = vCPUPtr(uintptr(context.Regs[8]))
+ regs := c.CPU.Registers()
+ regs.Regs = context.Regs
+ regs.Sp = context.Sp
+ regs.Pc = context.Pc
+ regs.Pstate = context.Pstate
+ regs.Pstate &^= uint64(ring0.KernelFlagsClear)
+ regs.Pstate |= ring0.KernelFlagsSet
+ return
+}
+
+// bluepillArchExit is called from bluepillGuestExit when returning to host mode.
+//
+//go:nosplit
+func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
+ regs := c.CPU.Registers()
+ context.Regs = regs.Regs
+ context.Sp = regs.Sp
+ context.Pc = regs.Pc
+ context.Pstate = regs.Pstate
+ context.Pstate &^= uint64(ring0.UserFlagsClear)
+ context.Pstate |= ring0.UserFlagsSet
+
+ lazyVfp := c.GetLazyVFP()
+ if lazyVfp != 0 {
+ fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+ context.Fpsimd64.Fpsr = fpsimd.Fpsr
+ context.Fpsimd64.Fpcr = fpsimd.Fpcr
+ context.Fpsimd64.Vregs = fpsimd.Vregs
+ }
+}
+
+// KernelSyscall handles kernel syscalls.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *vCPU) KernelSyscall() {
+ regs := c.Registers()
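+ // As on amd64, rewinding the four-byte SVC instruction lets the
+ // host re-execute the syscall after we halt; x8 == -1 marks the
+ // internal redpill transition, which must not be replayed.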
+ if regs.Regs[8] != ^uint64(0) {
+ regs.Pc -= 4 // Rewind.
+ }
+
+ vfpEnable := ring0.CPACREL1()
+ if vfpEnable != 0 {
+ fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+ fpcr := ring0.GetFPCR()
+ fpsr := ring0.GetFPSR()
+ fpsimd.Fpcr = uint32(fpcr)
+ fpsimd.Fpsr = uint32(fpsr)
+ ring0.SaveVRegs((*byte)(c.floatingPointState))
+ }
+
+ ring0.Halt()
+}
+
+// KernelException handles kernel exceptions.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *vCPU) KernelException(vector ring0.Vector) {
+ regs := c.Registers()
+ if vector == ring0.Vector(bounce) {
+ regs.Pc = 0
+ }
+
+ vfpEnable := ring0.CPACREL1()
+ if vfpEnable != 0 {
+ fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+ fpcr := ring0.GetFPCR()
+ fpsr := ring0.GetFPSR()
+ fpsimd.Fpcr = uint32(fpcr)
+ fpsimd.Fpsr = uint32(fpsr)
+ ring0.SaveVRegs((*byte)(c.floatingPointState))
+ }
+
+ ring0.Halt()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
new file mode 100644
index 000000000..04efa0147
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// VCPU_CPU is the location of the CPU in the vCPU struct.
+//
+// This is guaranteed to be zero.
+#define VCPU_CPU 0x0
+
+// CPU_SELF is the self reference in ring0's percpu.
+//
+// This is guaranteed to be zero.
+#define CPU_SELF 0x0
+
+// Context offsets.
+//
+// Only limited use of the context is done in the assembly stub below; most is
+// done in the Go handlers.
+#define SIGINFO_SIGNO 0x0
+#define CONTEXT_PC 0x1B8
+#define CONTEXT_R0 0xB8
+
+// See bluepill.go.
+TEXT ·bluepill(SB),NOSPLIT,$0
+begin:
+ MOVD vcpu+0(FP), R8
+ MOVD $VCPU_CPU(R8), R9
+ ORR $0xffff000000000000, R9, R9
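+ // The CPU self-pointer stored in TPIDR_EL1 is a kernel-half address,
+ // so form the upper-half alias of &vcpu.CPU before comparing.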
+ // Trigger sigill.
+ // In ring0.Start(), the value of R8 will be stored into tpidr_el1.
+ // Once the context has been loaded into the vCPU successfully,
+ // we check whether the values in R10 and R9 are the same.
+ WORD $0xd538d08a // MRS TPIDR_EL1, R10
+check_vcpu:
+ CMP R10, R9
+ BEQ right_vCPU
+wrong_vcpu:
+ CALL ·redpill(SB)
+ B begin
+right_vCPU:
+ RET
+
+// sighandler: see bluepill.go for documentation.
+//
+// The arguments are the following:
+//
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+//
+TEXT ·sighandler(SB),NOSPLIT,$0
+ // si_signo should be SIGILL (4).
+ MOVD SIGINFO_SIGNO(R1), R7
+ CMPW $4, R7
+ BNE fallback
+
+ MOVD CONTEXT_PC(R2), R7
+ CMPW $0, R7
+ BEQ fallback
+
+ MOVD R2, 8(RSP)
+ BL ·bluepillHandler(SB) // Call the handler.
+
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ MOVD ·savedHandler(SB), R7
+ B (R7)
+
+// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
+TEXT ·dieTrampoline(SB),NOSPLIT,$0
+ // R0: Fake the old PC as caller
+ // R1: First argument (vCPU)
+ MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
+ MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
+ B ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
new file mode 100644
index 000000000..8b64f3a1e
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -0,0 +1,97 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+// fpsimdPtr returns a fpsimd64 for the given address.
+//
+//go:nosplit
+func fpsimdPtr(addr *byte) *arch.FpsimdContext {
+ return (*arch.FpsimdContext)(unsafe.Pointer(addr))
+}
+
+// dieArchSetup initializes the state for dieTrampoline.
+//
+// The arm64 dieTrampoline requires the vCPU to be set in R1, and the last PC
+// to be in R0. The trampoline then simulates a call to dieHandler from the
+// provided PC.
+//
+//go:nosplit
+func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // If the vCPU is in user mode, we set the stack to the stored stack
+ // value in the vCPU itself. We don't want to unwind the user stack.
+ if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t {
+ regs := c.CPU.Registers()
+ context.Regs[0] = regs.Regs[0]
+ context.Sp = regs.Sp
+ context.Regs[29] = regs.Regs[29] // stack base address
+ } else {
+ context.Regs[0] = guestRegs.Regs.Pc
+ context.Sp = guestRegs.Regs.Sp
+ context.Regs[29] = guestRegs.Regs.Regs[29]
+ context.Pstate = guestRegs.Regs.Pstate
+ }
+ context.Regs[1] = uint64(uintptr(unsafe.Pointer(c)))
+ context.Pc = uint64(dieTrampolineAddr)
+}
+
+// bluepillArchFpContext returns the arch-specific fpsimd context.
+//
+//go:nosplit
+func bluepillArchFpContext(context unsafe.Pointer) *arch.FpsimdContext {
+ return &((*arch.SignalContext64)(context).Fpsimd64)
+}
+
+// getHypercallID returns the hypercall ID for the given MMIO address.
+//
+// On Arm64, the MMIO address should be 64-bit aligned.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ if addr < arm64HypercallMMIOBase || addr >= (arm64HypercallMMIOBase+_AARCH64_HYPERCALL_MMIO_SIZE) {
+ return _KVM_HYPERCALL_MAX
+ } else {
+ return int(((addr) - arm64HypercallMMIOBase) >> 3)
+ }
+}
+
+// bluepillStopGuest is responsible for injecting an sError.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return true
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
new file mode 100644
index 000000000..e34f46aeb
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // faultBlockSize is the size used for servicing memory faults.
+ //
+ // This should be large enough to avoid frequent faults and avoid using
+ // all available KVM slots (~512), but small enough that KVM does not
+ // complain about slot sizes (~4GB). See handleBluepillFault for how
+ // this block is used.
+ faultBlockSize = 2 << 30
+
+ // faultBlockMask is the mask for the fault blocks.
+ //
+ // This must be typed to avoid overflow complaints (ugh).
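+ //
+ // For example, with faultBlockSize = 2<<30, a fault at physical
+ // address 0x81234000 falls in the 2GB-aligned block starting at
+ // 0x80000000 (0x81234000 &^ (2<<30 - 1)); calculateBluepillFault
+ // then clamps that block to the containing physical region.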
+ faultBlockMask = ^uintptr(faultBlockSize - 1)
+)
+
+// yield yields the CPU.
+//
+//go:nosplit
+func yield() {
+ syscall.RawSyscall(syscall.SYS_SCHED_YIELD, 0, 0, 0)
+}
+
+// calculateBluepillFault calculates the fault address range.
+//
+//go:nosplit
+func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) {
+ alignedPhysical := physical &^ uintptr(usermem.PageSize-1)
+ for _, pr := range phyRegions {
+ end := pr.physical + pr.length
+ if physical < pr.physical || physical >= end {
+ continue
+ }
+
+ // Adjust the block to match our size.
+ physicalStart = alignedPhysical & faultBlockMask
+ if physicalStart < pr.physical {
+ // Bound the starting point to the start of the region.
+ physicalStart = pr.physical
+ }
+ virtualStart = pr.virtual + (physicalStart - pr.physical)
+ physicalEnd := physicalStart + faultBlockSize
+ if physicalEnd > end {
+ physicalEnd = end
+ }
+ length = physicalEnd - physicalStart
+ return virtualStart, physicalStart, length, true
+ }
+
+ return 0, 0, 0, false
+}
+
+// handleBluepillFault handles a physical fault.
+//
+// The corresponding virtual address is returned. This may throw on error.
+//
+//go:nosplit
+func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) {
+ // Paging fault: we need to map the underlying physical pages for this
+ // fault. This all has to be done in this function because we're in a
+ // signal handler context. (We can't call any functions that might
+ // split the stack.)
+ virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
+ if !ok {
+ return 0, false
+ }
+
+ // Set the KVM slot.
+ //
+ // First, we need to acquire the exclusive right to set a slot. See
+ // machine.nextSlot for information about the protocol.
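+ // The sentinel value ^uint32(0) acts as a lock: whichever thread
+ // swaps it in owns the right to program the next slot, and stores
+ // a real value back when done.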
+ slot := atomic.SwapUint32(&m.nextSlot, ^uint32(0))
+ for slot == ^uint32(0) {
+ yield() // Race with another call.
+ slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0))
+ }
+ errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags)
+ if errno == 0 {
+ // Successfully added region; we can increment nextSlot and
+ // allow another set to proceed here.
+ atomic.StoreUint32(&m.nextSlot, slot+1)
+ return virtualStart + (physical - physicalStart), true
+ }
+
+ // Release our slot (still available).
+ atomic.StoreUint32(&m.nextSlot, slot)
+
+ switch errno {
+ case syscall.EEXIST:
+ // The region already exists. It's possible that we raced with
+ // another vCPU here. We just revert nextSlot and return true,
+ // because this must have been satisfied by some other vCPU.
+ return virtualStart + (physical - physicalStart), true
+ case syscall.EINVAL:
+ throw("set memory region failed; out of slots")
+ case syscall.ENOMEM:
+ throw("set memory region failed: out of memory")
+ case syscall.EFAULT:
+ throw("set memory region failed: invalid physical range")
+ default:
+ throw("set memory region failed: unknown reason")
+ }
+
+ panic("unreachable")
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
new file mode 100644
index 000000000..bf357de1a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -0,0 +1,232 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package kvm
+
+import (
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+//go:linkname throw runtime.throw
+func throw(string)
+
+// vCPUPtr returns a CPU for the given address.
+//
+//go:nosplit
+func vCPUPtr(addr uintptr) *vCPU {
+ return (*vCPU)(unsafe.Pointer(addr))
+}
+
+// bytePtr returns a bytePtr for the given address.
+//
+//go:nosplit
+func bytePtr(addr uintptr) *byte {
+ return (*byte)(unsafe.Pointer(addr))
+}
+
+// uintptrValue returns a uintptr for the given address.
+//
+//go:nosplit
+func uintptrValue(addr *byte) uintptr {
+ return (uintptr)(unsafe.Pointer(addr))
+}
+
+// bluepillArchContext returns the UContext64.
+//
+//go:nosplit
+func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
+ return &((*arch.UContext64)(context).MContext)
+}
+
+// bluepillGuestExit is responsible for handling a VM exit.
+//
+//go:nosplit
+func bluepillGuestExit(c *vCPU, context unsafe.Pointer) {
+ // Copy out registers.
+ bluepillArchExit(c, bluepillArchContext(context))
+
+ // Return to the vCPUReady state; notify any waiters.
+ user := atomic.LoadUint32(&c.state) & vCPUUser
+ switch atomic.SwapUint32(&c.state, user) {
+ case user | vCPUGuest: // Expected case.
+ case user | vCPUGuest | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+}
+
+// bluepillHandler is called from the signal stub.
+//
+// The world may be stopped while this is executing, and it executes on the
+// signal stack. It should only execute raw system calls and functions that are
+// explicitly marked go:nosplit.
+//
+// +checkescape:all
+//
+//go:nosplit
+func bluepillHandler(context unsafe.Pointer) {
+ // Sanitize the registers; interrupts must always be disabled.
+ c := bluepillArchEnter(bluepillArchContext(context))
+
+ // Increment the number of switches.
+ atomic.AddUint32(&c.switches, 1)
+
+ // Mark this as guest mode.
+ switch atomic.SwapUint32(&c.state, vCPUGuest|vCPUUser) {
+ case vCPUUser: // Expected case.
+ case vCPUUser | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+
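+ // Run the vCPU in a loop: each KVM_RUN either completes with an
+ // exit reason in runData, or is interrupted by a pending signal
+ // (EINTR), which is expected to be the bounce signal.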
+ for {
+ _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no.
+ switch errno {
+ case 0: // Expected case.
+ case syscall.EINTR:
+ // First, we process whatever pending signal
+ // interrupted KVM. Since we're in a signal handler
+ // currently, all signals are masked and the signal
+ // must have been delivered directly to this thread.
+ timeout := syscall.Timespec{}
+ sig, _, errno := syscall.RawSyscall6( // escapes: no.
+ syscall.SYS_RT_SIGTIMEDWAIT,
+ uintptr(unsafe.Pointer(&bounceSignalMask)),
+ 0, // siginfo.
+ uintptr(unsafe.Pointer(&timeout)), // timeout.
+ 8, // sigset size.
+ 0, 0)
+ if errno == syscall.EAGAIN {
+ continue
+ }
+ if errno != 0 {
+ throw("error waiting for pending signal")
+ }
+ if sig != uintptr(bounceSignal) {
+ throw("unexpected signal")
+ }
+
+ // Check whether the current state of the vCPU is ready
+ // for interrupt injection. Because we don't have a
+ // PIC, we can't inject an interrupt while interrupts
+ // are masked. We need to request a window if the vCPU
+ // is not ready.
+ if bluepillReadyStopGuest(c) {
+ // Force injection below; the vCPU is ready.
+ c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN
+ } else {
+ c.runData.requestInterruptWindow = 1
+ continue // Rerun vCPU.
+ }
+ case syscall.EFAULT:
+ // If a fault is not serviceable due to the host
+ // backing pages' permissions, we receive EFAULT from
+ // the run ioctl instead of an MMIO exit. We
+ // always inject an NMI here since we may be in kernel
+ // mode and have interrupts disabled.
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_NMI, 0); errno != 0 {
+ throw("NMI injection failed")
+ }
+ continue // Rerun vCPU.
+ default:
+ throw("run failed")
+ }
+
+ switch c.runData.exitReason {
+ case _KVM_EXIT_EXCEPTION:
+ c.die(bluepillArchContext(context), "exception")
+ return
+ case _KVM_EXIT_IO:
+ c.die(bluepillArchContext(context), "I/O")
+ return
+ case _KVM_EXIT_INTERNAL_ERROR:
+ // An internal error is typically thrown when emulation
+ // fails. This can occur via the MMIO path below (and
+ // it might fail because we have multiple regions that
+ // are not mapped). We would actually prefer that no
+ // emulation occur, and don't mind at all if it fails.
+ case _KVM_EXIT_HYPERCALL:
+ c.die(bluepillArchContext(context), "hypercall")
+ return
+ case _KVM_EXIT_DEBUG:
+ c.die(bluepillArchContext(context), "debug")
+ return
+ case _KVM_EXIT_HLT:
+ bluepillGuestExit(c, context)
+ return
+ case _KVM_EXIT_MMIO:
+ physical := uintptr(c.runData.data[0])
+ if getHypercallID(physical) == _KVM_HYPERCALL_VMEXIT {
+ bluepillGuestExit(c, context)
+ return
+ }
+
+ // Increment the fault count.
+ atomic.AddUint32(&c.faults, 1)
+
+ virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
+ if !ok {
+ c.die(bluepillArchContext(context), "invalid physical address")
+ return
+ }
+
+ // We now need to fill in the data appropriately. KVM
+ // expects us to provide the result of the given MMIO
+ // operation in the runData struct. This is safe
+ // because, if a fault occurs here, the same fault
+ // would have occurred in guest mode. The kernel should
+ // not create invalid page table mappings.
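+ // The exit data mirrors the mmio union of struct kvm_run:
+ // data[0] is the physical address, data[1] holds the 8-byte
+ // data buffer, and data[2] packs the length (low 32 bits) and
+ // the is_write flag (bits 32-39).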
+ data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1]))
+ length := (uintptr)((uint32)(c.runData.data[2]))
+ write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0
+ for i := uintptr(0); i < length; i++ {
+ b := bytePtr(uintptr(virtual) + i)
+ if write {
+ // Write to the given address.
+ *b = data[i]
+ } else {
+ // Read from the given address.
+ data[i] = *b
+ }
+ }
+ case _KVM_EXIT_IRQ_WINDOW_OPEN:
+ bluepillStopGuest(c)
+ case _KVM_EXIT_SHUTDOWN:
+ c.die(bluepillArchContext(context), "shutdown")
+ return
+ case _KVM_EXIT_FAIL_ENTRY:
+ c.die(bluepillArchContext(context), "entry failed")
+ return
+ default:
+ c.die(bluepillArchContext(context), "unknown")
+ return
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
new file mode 100644
index 000000000..6507121ea
--- /dev/null
+++ b/pkg/sentry/platform/kvm/context.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// context is an implementation of the platform context.
+//
+// This is a thin wrapper around the machine.
+type context struct {
+ // machine is the parent machine, and is immutable.
+ machine *machine
+
+ // info is the arch.SignalInfo cached for this context.
+ info arch.SignalInfo
+
+ // interrupt is the interrupt context.
+ interrupt interrupt.Forwarder
+}
+
+// Switch runs the provided context in the given address space.
+func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) {
+ localAS := as.(*addressSpace)
+
+ // Grab a vCPU.
+ cpu := c.machine.Get()
+
+ // Enable interrupts (i.e. calls to vCPU.Notify).
+ if !c.interrupt.Enable(cpu) {
+ c.machine.Put(cpu) // Already preempted.
+ return nil, usermem.NoAccess, platform.ErrContextInterrupt
+ }
+
+ // Set the active address space.
+ //
+ // This must be done prior to the call to Touch below. If the address
+ // space is invalidated between this line and the call below, we will
+ // flag on entry anyways. When the active address space below is
+ // cleared, it indicates that we don't need an explicit interrupt and
+ // that the flush can occur naturally on the next user entry.
+ cpu.active.set(localAS)
+
+ // Prepare switch options.
+ switchOpts := ring0.SwitchOpts{
+ Registers: &ac.StateData().Regs,
+ FloatingPointState: (*byte)(ac.FloatingPointData()),
+ PageTables: localAS.pageTables,
+ Flush: localAS.Touch(cpu),
+ FullRestore: ac.FullRestore(),
+ }
+
+ // Take the blue pill.
+ at, err := cpu.SwitchToUser(switchOpts, &c.info)
+
+ // Clear the address space.
+ cpu.active.set(nil)
+
+ // Release resources.
+ c.machine.Put(cpu)
+
+ // All done.
+ c.interrupt.Disable()
+ return &c.info, at, err
+}
+
+// Interrupt interrupts the running context.
+func (c *context) Interrupt() {
+ c.interrupt.NotifyInterrupt()
+}
+
+// Release implements platform.Context.Release().
+func (c *context) Release() {}
diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go
new file mode 100644
index 000000000..7d949f1dd
--- /dev/null
+++ b/pkg/sentry/platform/kvm/filters_amd64.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// SyscallFilters returns syscalls made exclusively by the KVM platform.
+func (*KVM) SyscallFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: {},
+ syscall.SYS_IOCTL: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_RT_SIGSUSPEND: {},
+ syscall.SYS_RT_SIGTIMEDWAIT: {},
+ 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host.
+ }
+}
diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go
new file mode 100644
index 000000000..9245d07c2
--- /dev/null
+++ b/pkg/sentry/platform/kvm/filters_arm64.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// SyscallFilters returns syscalls made exclusively by the KVM platform.
+func (*KVM) SyscallFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_IOCTL: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_RT_SIGSUSPEND: {},
+ syscall.SYS_RT_SIGTIMEDWAIT: {},
+ 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host.
+ }
+}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
new file mode 100644
index 000000000..ae813e24e
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -0,0 +1,201 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kvm provides a kvm-based implementation of the platform interface.
+package kvm
+
+import (
+ "fmt"
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// userMemoryRegion is a region of physical memory.
+//
+// This mirrors kvm_memory_region.
+type userMemoryRegion struct {
+ slot uint32
+ flags uint32
+ guestPhysAddr uint64
+ memorySize uint64
+ userspaceAddr uint64
+}
+
+// runData is the run structure. This may be mapped for synchronous register
+// access (although that doesn't appear to be supported by my kernel at least).
+//
+// This mirrors kvm_run.
+type runData struct {
+ requestInterruptWindow uint8
+ _ [7]uint8
+
+ exitReason uint32
+ readyForInterruptInjection uint8
+ ifFlag uint8
+ _ [2]uint8
+
+ cr8 uint64
+ apicBase uint64
+
+ // This is the union data for exits. Interpretation depends entirely on
+ // the exitReason above (see vCPU code for more information).
+ data [32]uint64
+}
+
+// KVM represents a lightweight VM context.
+type KVM struct {
+ platform.NoCPUPreemptionDetection
+
+ // machine is the backing VM.
+ machine *machine
+}
+
+var (
+ globalOnce sync.Once
+ globalErr error
+)
+
+// OpenDevice opens the KVM device at /dev/kvm and returns the File.
+func OpenDevice() (*os.File, error) {
+ f, err := os.OpenFile("/dev/kvm", syscall.O_RDWR, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error opening /dev/kvm: %v", err)
+ }
+ return f, nil
+}
+
+// New returns a new KVM-based implementation of the platform interface.
+func New(deviceFile *os.File) (*KVM, error) {
+ fd := deviceFile.Fd()
+
+ // Ensure global initialization is done.
+ globalOnce.Do(func() {
+ globalErr = updateGlobalOnce(int(fd))
+ })
+ if globalErr != nil {
+ return nil, globalErr
+ }
+
+ // Create a new VM fd.
+ var (
+ vm uintptr
+ errno syscall.Errno
+ )
+ for {
+ vm, _, errno = syscall.Syscall(syscall.SYS_IOCTL, fd, _KVM_CREATE_VM, 0)
+ if errno == syscall.EINTR {
+ continue
+ }
+ if errno != 0 {
+ return nil, fmt.Errorf("creating VM: %v", errno)
+ }
+ break
+ }
+ // We are done with the device file.
+ deviceFile.Close()
+
+ // Create a VM context.
+ machine, err := newMachine(int(vm))
+ if err != nil {
+ return nil, err
+ }
+
+ // All set.
+ return &KVM{
+ machine: machine,
+ }, nil
+}
+
+// SupportsAddressSpaceIO implements platform.Platform.SupportsAddressSpaceIO.
+func (*KVM) SupportsAddressSpaceIO() bool {
+ return false
+}
+
+// CooperativelySchedulesAddressSpace implements platform.Platform.CooperativelySchedulesAddressSpace.
+func (*KVM) CooperativelySchedulesAddressSpace() bool {
+ return false
+}
+
+// MapUnit implements platform.Platform.MapUnit.
+func (*KVM) MapUnit() uint64 {
+ // We greedily create PTEs in MapFile, so extremely large mappings can
+ // be expensive. Not _that_ expensive since we allow super pages, but
+ // it can still get out of hand if you're creating multi-terabyte
+ // mappings. For this reason, we limit mappings to an arbitrary 16MB.
+ return 16 << 20
+}
+
+// MinUserAddress returns the lowest available address.
+func (*KVM) MinUserAddress() usermem.Addr {
+ return usermem.PageSize
+}
+
+// MaxUserAddress returns the first address that may not be used.
+func (*KVM) MaxUserAddress() usermem.Addr {
+ return usermem.Addr(ring0.MaximumUserAddress)
+}
+
+// NewAddressSpace returns a new pagetable root.
+func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
+ // Allocate page tables and install system mappings.
+ pageTables := pagetables.New(newAllocator())
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ // Map the kernel in the upper half.
+ pageTables.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+ return true // Keep iterating.
+ })
+
+ // Return the new address space.
+ return &addressSpace{
+ machine: k.machine,
+ pageTables: pageTables,
+ dirtySet: k.machine.newDirtySet(),
+ }, nil, nil
+}
+
+// NewContext returns an interruptible context.
+func (k *KVM) NewContext() platform.Context {
+ return &context{
+ machine: k.machine,
+ }
+}
+
+type constructor struct{}
+
+func (*constructor) New(f *os.File) (platform.Platform, error) {
+ return New(f)
+}
+
+func (*constructor) OpenDevice() (*os.File, error) {
+ return OpenDevice()
+}
+
+// Requirements implements platform.Constructor.Requirements().
+func (*constructor) Requirements() platform.Requirements {
+ return platform.Requirements{}
+}
+
+func init() {
+ platform.Register("kvm", &constructor{})
+}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go
new file mode 100644
index 000000000..093497bc4
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64.go
@@ -0,0 +1,190 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+// userRegs represents KVM user registers.
+//
+// This mirrors kvm_regs.
+type userRegs struct {
+ RAX uint64
+ RBX uint64
+ RCX uint64
+ RDX uint64
+ RSI uint64
+ RDI uint64
+ RSP uint64
+ RBP uint64
+ R8 uint64
+ R9 uint64
+ R10 uint64
+ R11 uint64
+ R12 uint64
+ R13 uint64
+ R14 uint64
+ R15 uint64
+ RIP uint64
+ RFLAGS uint64
+}
+
+// systemRegs represents KVM system registers.
+//
+// This mirrors kvm_sregs.
+type systemRegs struct {
+ CS segment
+ DS segment
+ ES segment
+ FS segment
+ GS segment
+ SS segment
+ TR segment
+ LDT segment
+ GDT descriptor
+ IDT descriptor
+ CR0 uint64
+ CR2 uint64
+ CR3 uint64
+ CR4 uint64
+ CR8 uint64
+ EFER uint64
+ apicBase uint64
+ interruptBitmap [(_KVM_NR_INTERRUPTS + 63) / 64]uint64
+}
+
+// segment is the expanded form of a segment register.
+//
+// This mirrors kvm_segment.
+type segment struct {
+ base uint64
+ limit uint32
+ selector uint16
+ typ uint8
+ present uint8
+ DPL uint8
+ DB uint8
+ S uint8
+ L uint8
+ G uint8
+ AVL uint8
+ unusable uint8
+ _ uint8
+}
+
+// Clear clears the segment and marks it unusable.
+func (s *segment) Clear() {
+ *s = segment{unusable: 1}
+}
+
+// selector is a segment selector.
+type selector uint16
+
+// tobool converts a non-zero flag to 1, and zero to 0.
+func tobool(x ring0.SegmentDescriptorFlags) uint8 {
+ if x != 0 {
+ return 1
+ }
+ return 0
+}
+
+// Load loads the segment described by d into the segment s.
+//
+// The argument sel is recorded as the segment selector index.
+func (s *segment) Load(d *ring0.SegmentDescriptor, sel ring0.Selector) {
+ flag := d.Flags()
+ if flag&ring0.SegmentDescriptorPresent == 0 {
+ s.Clear()
+ return
+ }
+ s.base = uint64(d.Base())
+ s.limit = d.Limit()
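+ // Bits 8-11 of the descriptor flags hold the segment type; the
+ // accessed bit (bit 0 of the type) is always set.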
+ s.typ = uint8((flag>>8)&0xF) | 1
+ s.S = tobool(flag & ring0.SegmentDescriptorSystem)
+ s.DPL = uint8(d.DPL())
+ s.present = tobool(flag & ring0.SegmentDescriptorPresent)
+ s.AVL = tobool(flag & ring0.SegmentDescriptorAVL)
+ s.L = tobool(flag & ring0.SegmentDescriptorLong)
+ s.DB = tobool(flag & ring0.SegmentDescriptorDB)
+ s.G = tobool(flag & ring0.SegmentDescriptorG)
+ if s.L != 0 {
+ s.limit = 0xffffffff
+ }
+ s.unusable = 0
+ s.selector = uint16(sel)
+}
+
+// descriptor describes a region of physical memory.
+//
+// It corresponds to the pseudo-descriptor used in the x86 LGDT and LIDT
+// instructions, and mirrors kvm_dtable.
+type descriptor struct {
+ base uint64
+ limit uint16
+ _ [3]uint16
+}
+
+// modelControlRegister is an MSR entry.
+//
+// This mirrors kvm_msr_entry.
+type modelControlRegister struct {
+ index uint32
+ _ uint32
+ data uint64
+}
+
+// modelControlRegisters is a collection of MSRs.
+//
+// This mirrors kvm_msrs.
+type modelControlRegisters struct {
+ nmsrs uint32
+ _ uint32
+ entries [16]modelControlRegister
+}
+
+// cpuidEntry is a single CPUID entry.
+//
+// This mirrors kvm_cpuid_entry2.
+type cpuidEntry struct {
+ function uint32
+ index uint32
+ flags uint32
+ eax uint32
+ ebx uint32
+ ecx uint32
+ edx uint32
+ _ [3]uint32
+}
+
+// cpuidEntries is a collection of CPUID entries.
+//
+// This mirrors kvm_cpuid2.
+type cpuidEntries struct {
+ nr uint32
+ _ uint32
+ entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry
+}
+
+// updateGlobalOnce does global initialization. It must be called exactly once.
+func updateGlobalOnce(fd int) error {
+ physicalInit()
+ err := updateSystemValues(int(fd))
+ ring0.Init(cpuid.HostFeatureSet())
+ return err
+}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go b/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go
new file mode 100644
index 000000000..46c4b9113
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+)
+
+var (
+ runDataSize int
+ hasGuestPCID bool
+ cpuidSupported = cpuidEntries{nr: _KVM_NR_CPUID_ENTRIES}
+)
+
+func updateSystemValues(fd int) error {
+ // Extract the mmap size.
+ sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0)
+ if errno != 0 {
+ return fmt.Errorf("getting VCPU mmap size: %v", errno)
+ }
+
+ // Save the data.
+ runDataSize = int(sz)
+
+ // Must do the dance to figure out the number of entries.
+ _, _, errno = syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(fd),
+ _KVM_GET_SUPPORTED_CPUID,
+ uintptr(unsafe.Pointer(&cpuidSupported)))
+ if errno != 0 && errno != syscall.ENOMEM {
+ // Some other error occurred.
+ return fmt.Errorf("getting supported CPUID: %v", errno)
+ }
+
+ // The number should now be correct.
+ _, _, errno = syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(fd),
+ _KVM_GET_SUPPORTED_CPUID,
+ uintptr(unsafe.Pointer(&cpuidSupported)))
+ if errno != 0 {
+ // Didn't work with the right number.
+ return fmt.Errorf("getting supported CPUID (2nd attempt): %v", errno)
+ }
+
+ // Calculate whether guestPCID is supported.
+ //
+ // FIXME(ascannell): These should go through the much more pleasant
+ // cpuid package interfaces, once a way to accept raw kvm CPUID entries
+ // is plumbed (or some rough equivalent).
+ for i := 0; i < int(cpuidSupported.nr); i++ {
+ entry := cpuidSupported.entries[i]
+ if entry.function == 1 && entry.index == 0 && entry.ecx&(1<<17) != 0 {
+ hasGuestPCID = true // Found matching PCID in guest feature set.
+ }
+ }
+
+ // Success.
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
new file mode 100644
index 000000000..0b06a923a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+type kvmOneReg struct {
+ id uint64
+ addr uint64
+}
+
+// arm64HypercallMMIOBase is MMIO base address used to dispatch hypercalls.
+var arm64HypercallMMIOBase uintptr
+
+const KVM_NR_SPSR = 5
+
+type userFpsimdState struct {
+ vregs [64]uint64
+ fpsr uint32
+ fpcr uint32
+ reserved [2]uint32
+}
+
+type userRegs struct {
+ Regs arch.Registers
+ sp_el1 uint64
+ elr_el1 uint64
+ spsr [KVM_NR_SPSR]uint64
+ fpRegs userFpsimdState
+}
+
+type exception struct {
+ sErrPending uint8
+ sErrHasEsr uint8
+ pad [6]uint8
+ sErrEsr uint64
+}
+
+type kvmVcpuEvents struct {
+ exception
+ rsvd [12]uint32
+}
+
+// updateGlobalOnce does global initialization. It must be called exactly once.
+func updateGlobalOnce(fd int) error {
+ physicalInit()
+ err := updateSystemValues(int(fd))
+ ring0.Init()
+ return err
+}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
new file mode 100644
index 000000000..6531bae1d
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "fmt"
+ "syscall"
+)
+
+var (
+ runDataSize int
+)
+
+func updateSystemValues(fd int) error {
+ // Extract the mmap size.
+ sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0)
+ if errno != 0 {
+ return fmt.Errorf("getting VCPU mmap size: %v", errno)
+ }
+ // Save the data.
+ runDataSize = int(sz)
+
+ // Success.
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
new file mode 100644
index 000000000..3bf918446
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+// KVM ioctls.
+//
+// Only the ioctls we need in Go appear here; some additional ioctls are used
+// within the assembly stubs (KVM_INTERRUPT, etc.).
+const (
+ _KVM_CREATE_VM = 0xae01
+ _KVM_GET_VCPU_MMAP_SIZE = 0xae04
+ _KVM_CREATE_VCPU = 0xae41
+ _KVM_SET_TSS_ADDR = 0xae47
+ _KVM_RUN = 0xae80
+ _KVM_NMI = 0xae9a
+ _KVM_CHECK_EXTENSION = 0xae03
+ _KVM_INTERRUPT = 0x4004ae86
+ _KVM_SET_MSRS = 0x4008ae89
+ _KVM_SET_USER_MEMORY_REGION = 0x4020ae46
+ _KVM_SET_REGS = 0x4090ae82
+ _KVM_SET_SREGS = 0x4138ae84
+ _KVM_GET_REGS = 0x8090ae81
+ _KVM_GET_SUPPORTED_CPUID = 0xc008ae05
+ _KVM_SET_CPUID2 = 0x4008ae90
+ _KVM_SET_SIGNAL_MASK = 0x4004ae8b
+ _KVM_GET_VCPU_EVENTS = 0x8040ae9f
+ _KVM_SET_VCPU_EVENTS = 0x4040aea0
+)
+
+// KVM exit reasons.
+const (
+ _KVM_EXIT_EXCEPTION = 0x1
+ _KVM_EXIT_IO = 0x2
+ _KVM_EXIT_HYPERCALL = 0x3
+ _KVM_EXIT_DEBUG = 0x4
+ _KVM_EXIT_HLT = 0x5
+ _KVM_EXIT_MMIO = 0x6
+ _KVM_EXIT_IRQ_WINDOW_OPEN = 0x7
+ _KVM_EXIT_SHUTDOWN = 0x8
+ _KVM_EXIT_FAIL_ENTRY = 0x9
+ _KVM_EXIT_INTERNAL_ERROR = 0x11
+ _KVM_EXIT_SYSTEM_EVENT = 0x18
+)
+
+// KVM capability options.
+const (
+ _KVM_CAP_MAX_VCPUS = 0x42
+ _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
+ _KVM_CAP_VCPU_EVENTS = 0x29
+ _KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e
+)
+
+// KVM limits.
+const (
+ _KVM_NR_VCPUS = 0xff
+ _KVM_NR_INTERRUPTS = 0x100
+ _KVM_NR_CPUID_ENTRIES = 0x100
+)
+
+// KVM kvm_memory_region::flags.
+const (
+ _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0
+ _KVM_MEM_READONLY = uint32(1) << 1
+ _KVM_MEM_FLAGS_NONE = 0
+)
+
+// KVM hypercall list: the canonical list of supported hypercalls.
+const (
+ // On amd64, 'HLT' is used to leave the guest.
+ // Unlike amd64, arm64 can only use mmio_exit/psci to leave the guest,
+ // so _KVM_HYPERCALL_VMEXIT is only used on arm64 for now.
+ _KVM_HYPERCALL_VMEXIT int = iota
+ _KVM_HYPERCALL_MAX
+)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
new file mode 100644
index 000000000..6f0539c29
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -0,0 +1,140 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+// KVM ioctls for Arm64.
+const (
+ _KVM_GET_ONE_REG = 0x4010aeab
+ _KVM_SET_ONE_REG = 0x4010aeac
+
+ _KVM_ARM_TARGET_GENERIC_V8 = 5
+ _KVM_ARM_PREFERRED_TARGET = 0x8020aeaf
+ _KVM_ARM_VCPU_INIT = 0x4020aeae
+ _KVM_ARM64_REGS_PSTATE = 0x6030000000100042
+ _KVM_ARM64_REGS_SP_EL1 = 0x6030000000100044
+ _KVM_ARM64_REGS_R0 = 0x6030000000100000
+ _KVM_ARM64_REGS_R1 = 0x6030000000100002
+ _KVM_ARM64_REGS_R2 = 0x6030000000100004
+ _KVM_ARM64_REGS_R3 = 0x6030000000100006
+ _KVM_ARM64_REGS_R8 = 0x6030000000100010
+ _KVM_ARM64_REGS_R18 = 0x6030000000100024
+ _KVM_ARM64_REGS_PC = 0x6030000000100040
+ _KVM_ARM64_REGS_MAIR_EL1 = 0x603000000013c510
+ _KVM_ARM64_REGS_TCR_EL1 = 0x603000000013c102
+ _KVM_ARM64_REGS_TTBR0_EL1 = 0x603000000013c100
+ _KVM_ARM64_REGS_TTBR1_EL1 = 0x603000000013c101
+ _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080
+ _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082
+ _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600
+)
+
+// Arm64: Architectural Feature Access Control Register EL1.
+const (
+ _FPEN_NOTRAP = 3
+ _FPEN_SHIFT = 20
+)
+
+// Arm64: System Control Register EL1.
+const (
+ _SCTLR_M = 1 << 0
+ _SCTLR_C = 1 << 2
+ _SCTLR_I = 1 << 12
+)
+
+// Arm64: Translation Control Register EL1.
+const (
+ _TCR_IPS_40BITS = 2 << 32 // PA=40
+ _TCR_IPS_48BITS = 5 << 32 // PA=48
+
+ _TCR_T0SZ_OFFSET = 0
+ _TCR_T1SZ_OFFSET = 16
+ _TCR_IRGN0_SHIFT = 8
+ _TCR_IRGN1_SHIFT = 24
+ _TCR_ORGN0_SHIFT = 10
+ _TCR_ORGN1_SHIFT = 26
+ _TCR_SH0_SHIFT = 12
+ _TCR_SH1_SHIFT = 28
+ _TCR_TG0_SHIFT = 14
+ _TCR_TG1_SHIFT = 30
+
+ _TCR_T0SZ_VA48 = 64 - 48 // VA=48
+ _TCR_T1SZ_VA48 = 64 - 48 // VA=48
+
+ _TCR_ASID16 = 1 << 36
+ _TCR_TBI0 = 1 << 37
+
+ _TCR_TXSZ_VA48 = (_TCR_T0SZ_VA48 << _TCR_T0SZ_OFFSET) | (_TCR_T1SZ_VA48 << _TCR_T1SZ_OFFSET)
+
+ _TCR_TG0_4K = 0 << _TCR_TG0_SHIFT // 4K
+ _TCR_TG0_64K = 1 << _TCR_TG0_SHIFT // 64K
+
+ _TCR_TG1_4K = 2 << _TCR_TG1_SHIFT
+
+ _TCR_TG_FLAGS = _TCR_TG0_4K | _TCR_TG1_4K
+
+ _TCR_IRGN0_WBWA = 1 << _TCR_IRGN0_SHIFT
+ _TCR_IRGN1_WBWA = 1 << _TCR_IRGN1_SHIFT
+ _TCR_IRGN_WBWA = _TCR_IRGN0_WBWA | _TCR_IRGN1_WBWA
+
+ _TCR_ORGN0_WBWA = 1 << _TCR_ORGN0_SHIFT
+ _TCR_ORGN1_WBWA = 1 << _TCR_ORGN1_SHIFT
+
+ _TCR_ORGN_WBWA = _TCR_ORGN0_WBWA | _TCR_ORGN1_WBWA
+
+ _TCR_SHARED = (3 << _TCR_SH0_SHIFT) | (3 << _TCR_SH1_SHIFT)
+
+ _TCR_CACHE_FLAGS = _TCR_IRGN_WBWA | _TCR_ORGN_WBWA
+)
+
+// Arm64: Memory Attribute Indirection Register EL1.
+const (
+ _MT_DEVICE_nGnRnE = 0
+ _MT_DEVICE_nGnRE = 1
+ _MT_DEVICE_GRE = 2
+ _MT_NORMAL_NC = 3
+ _MT_NORMAL = 4
+ _MT_NORMAL_WT = 5
+ _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8)
+)
+
+const (
+ _KVM_ARM_VCPU_POWER_OFF = 0 // CPU is started in OFF state
+ _KVM_ARM_VCPU_PSCI_0_2 = 2 // CPU uses PSCI v0.2
+)
+
+// Arm64: Exception Syndrome Register EL1.
+const (
+ _ESR_ELx_FSC = 0x3F
+
+ _ESR_SEGV_MAPERR_L0 = 0x4
+ _ESR_SEGV_MAPERR_L1 = 0x5
+ _ESR_SEGV_MAPERR_L2 = 0x6
+ _ESR_SEGV_MAPERR_L3 = 0x7
+
+ _ESR_SEGV_ACCERR_L1 = 0x9
+ _ESR_SEGV_ACCERR_L2 = 0xa
+ _ESR_SEGV_ACCERR_L3 = 0xb
+
+ _ESR_SEGV_PEMERR_L1 = 0xd
+ _ESR_SEGV_PEMERR_L2 = 0xe
+ _ESR_SEGV_PEMERR_L3 = 0xf
+)
+
+// Arm64: MMIO base address used to dispatch hypercalls.
+const (
+ // On arm64, the MMIO address must be 64-bit aligned.
+ // Currently, we only need 1 hypercall: hypercall_vmexit.
+ _AARCH64_HYPERCALL_MMIO_SIZE = 1 << 3
+)
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
new file mode 100644
index 000000000..6c8f4fa28
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -0,0 +1,533 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "math/rand"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var dummyFPState = (*byte)(arch.NewFloatingPointData())
+
+type testHarness interface {
+ Errorf(format string, args ...interface{})
+ Fatalf(format string, args ...interface{})
+}
+
+func kvmTest(t testHarness, setup func(*KVM), fn func(*vCPU) bool) {
+ // Create the machine.
+ deviceFile, err := OpenDevice()
+ if err != nil {
+ t.Fatalf("error opening device file: %v", err)
+ }
+ k, err := New(deviceFile)
+ if err != nil {
+ t.Fatalf("error creating KVM instance: %v", err)
+ }
+ defer k.machine.Destroy()
+
+ // Call additional setup.
+ if setup != nil {
+ setup(k)
+ }
+
+ var c *vCPU // For recovery.
+ defer func() {
+ redpill()
+ if c != nil {
+ k.machine.Put(c)
+ }
+ }()
+ for {
+ c = k.machine.Get()
+ if !fn(c) {
+ break
+ }
+
+ // We put the vCPU here and clear the value so that the
+ // deferred recovery will not re-put it above.
+ k.machine.Put(c)
+ c = nil
+ }
+}
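+
+// For illustration, a minimal use of the harness above (bluepillTest below is
+// exactly this pattern): run one function on a vCPU, then stop iterating.
+//
+//	kvmTest(t, nil, func(c *vCPU) bool {
+//		bluepill(c) // Enter guest mode.
+//		return false // Done; do not loop.
+//	})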
+
+func bluepillTest(t testHarness, fn func(*vCPU)) {
+ kvmTest(t, nil, func(c *vCPU) bool {
+ bluepill(c)
+ fn(c)
+ return false
+ })
+}
+
+func TestKernelSyscall(t *testing.T) {
+ bluepillTest(t, func(c *vCPU) {
+ redpill() // Leave guest mode.
+ if got := atomic.LoadUint32(&c.state); got != vCPUUser {
+ t.Errorf("vCPU not in ready state: got %v", got)
+ }
+ })
+}
+
+func hostFault() {
+ defer func() {
+ recover()
+ }()
+ var foo *int
+ *foo = 0
+}
+
+func TestKernelFault(t *testing.T) {
+ hostFault() // Ensure recovery works.
+ bluepillTest(t, func(c *vCPU) {
+ hostFault()
+ if got := atomic.LoadUint32(&c.state); got != vCPUUser {
+ t.Errorf("vCPU not in ready state: got %v", got)
+ }
+ })
+}
+
+func TestKernelFloatingPoint(t *testing.T) {
+ bluepillTest(t, func(c *vCPU) {
+ if !testutil.FloatingPointWorks() {
+ t.Errorf("floating point does not work, and it should!")
+ }
+ })
+}
+
+func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *arch.Registers, *pagetables.PageTables) bool) {
+ // Initialize registers & page tables.
+ var (
+ regs arch.Registers
+ pt *pagetables.PageTables
+ )
+ testutil.SetTestTarget(&regs, target)
+
+ kvmTest(t, func(k *KVM) {
+ // Create new page tables.
+ as, _, err := k.NewAddressSpace(nil /* invalidator */)
+ if err != nil {
+ t.Fatalf("can't create new address space: %v", err)
+ }
+ pt = as.(*addressSpace).pageTables
+
+ if useHostMappings {
+ // Apply the physical mappings to these page tables.
+ // (This is normally dangerous, since they point to
+ // physical pages that may not exist. This shouldn't be
+ // done for regular user code, but is fine for test
+ // purposes.)
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ pt.Map(usermem.Addr(pr.virtual), pr.length, pagetables.MapOpts{
+ AccessType: usermem.AnyAccess,
+ User: true,
+ }, pr.physical)
+ return true // Keep iterating.
+ })
+ }
+ }, func(c *vCPU) bool {
+ // Invoke the function with the extra data.
+ return fn(c, &regs, pt)
+ })
+}
+
+func TestApplicationSyscall(t *testing.T) {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if err != nil {
+ t.Errorf("application syscall with full restore failed: %v", err)
+ }
+ return false
+ })
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if err != nil {
+ t.Errorf("application syscall with partial restore failed: %v", err)
+ }
+ return false
+ })
+}
+
+func TestApplicationFault(t *testing.T) {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTouchTarget(regs, nil) // Cause fault.
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) {
+ t.Errorf("application fault with full restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal)
+ }
+ return false
+ })
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTouchTarget(regs, nil) // Cause fault.
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) {
+ t.Errorf("application fault with partial restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal)
+ }
+ return false
+ })
+}
+
+func TestRegistersSyscall(t *testing.T) {
+ applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTestRegs(regs) // Fill values for all registers.
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != nil {
+ t.Errorf("application register check with partial restore got unexpected error: %v", err)
+ }
+ if err := testutil.CheckTestRegs(regs, false); err != nil {
+ t.Errorf("application register check with partial restore failed: %v", err)
+ }
+ break // Done.
+ }
+ return false
+ })
+}
+
+func TestRegistersFault(t *testing.T) {
+ applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTestRegs(regs) // Fill values for all registers.
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) {
+ t.Errorf("application register check with full restore got unexpected error: %v", err)
+ }
+ if err := testutil.CheckTestRegs(regs, true); err != nil {
+ t.Errorf("application register check with full restore failed: %v", err)
+ }
+ break // Done.
+ }
+ return false
+ })
+}
+
+func TestSegments(t *testing.T) {
+ applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTestSegments(regs)
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != nil {
+ t.Errorf("application segment check with full restore got unexpected error: %v", err)
+ }
+ if err := testutil.CheckTestSegments(regs); err != nil {
+ t.Errorf("application segment check with full restore failed: %v", err)
+ }
+ break // Done.
+ }
+ return false
+ })
+}
+
+func TestBounce(t *testing.T) {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ go func() {
+ time.Sleep(time.Millisecond)
+ c.BounceToKernel()
+ }()
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err != platform.ErrContextInterrupt {
+ t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
+ }
+ return false
+ })
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ go func() {
+ time.Sleep(time.Millisecond)
+ c.BounceToKernel()
+ }()
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err != platform.ErrContextInterrupt {
+ t.Errorf("application full restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
+ }
+ return false
+ })
+}
+
+func TestBounceStress(t *testing.T) {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ randomSleep := func() {
+ // O(hundreds of microseconds) is appropriate to ensure
+ // different overlaps and different schedules.
+ if n := rand.Intn(1000); n > 100 {
+ time.Sleep(time.Duration(n) * time.Microsecond)
+ }
+ }
+ for i := 0; i < 1000; i++ {
+ // Start an asynchronously executing goroutine that
+ // calls Bounce at pseudo-random point in time.
+ // This should wind up calling Bounce when the
+ // kernel is in various stages of the switch.
+ go func() {
+ randomSleep()
+ c.BounceToKernel()
+ }()
+ randomSleep()
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err != platform.ErrContextInterrupt {
+ t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
+ }
+ c.unlock()
+ randomSleep()
+ c.lock()
+ }
+ return false
+ })
+}
+
+func TestInvalidate(t *testing.T) {
+ var data uintptr // Used below.
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTouchTarget(regs, &data) // Read legitimate value.
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != nil {
+ t.Errorf("application partial restore: got %v, wanted nil", err)
+ }
+ break // Done.
+ }
+ // Unmap the page containing data & invalidate.
+ pt.Unmap(usermem.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(usermem.PageSize-1)), usermem.PageSize)
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ Flush: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != platform.ErrContextSignal {
+ t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextSignal)
+ }
+ break // Success.
+ }
+ return false
+ })
+}
+
+// IsFault returns true iff the given signal represents a fault.
+func IsFault(err error, si *arch.SignalInfo) bool {
+ return err == platform.ErrContextSignal && si.Signo == int32(syscall.SIGSEGV)
+}
+
+func TestEmptyAddressSpace(t *testing.T) {
+ applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if !IsFault(err, &si) {
+ t.Errorf("first fault with partial restore failed got %v", err)
+ t.Logf("registers: %#v", &regs)
+ }
+ return false
+ })
+ applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if !IsFault(err, &si) {
+ t.Errorf("first fault with full restore failed got %v", err)
+ t.Logf("registers: %#v", &regs)
+ }
+ return false
+ })
+}
+
+func TestWrongVCPU(t *testing.T) {
+ kvmTest(t, nil, func(c1 *vCPU) bool {
+ kvmTest(t, nil, func(c2 *vCPU) bool {
+ // Basic test, one then the other.
+ bluepill(c1)
+ bluepill(c2)
+ if c2.switches == 0 {
+ // Don't allow the test to proceed if this fails.
+ t.Fatalf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ }
+
+ // Alternate vCPUs; we expect to need to trigger the
+ // wrong vCPU path on each switch.
+ for i := 0; i < 100; i++ {
+ bluepill(c1)
+ bluepill(c2)
+ }
+ if count := c1.switches; count < 90 {
+ t.Errorf("wrong vCPU#1 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ }
+ if count := c2.switches; count < 90 {
+ t.Errorf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ }
+ return false
+ })
+ return false
+ })
+ kvmTest(t, nil, func(c1 *vCPU) bool {
+ kvmTest(t, nil, func(c2 *vCPU) bool {
+ bluepill(c1)
+ bluepill(c2)
+ return false
+ })
+ return false
+ })
+}
+
+func BenchmarkApplicationSyscall(b *testing.B) {
+ var (
+ i int // Iteration includes machine.Get() / machine.Put().
+ a int // Count for ErrContextInterrupt.
+ )
+ applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ a++
+ return true // Ignore.
+ } else if err != nil {
+ b.Fatalf("benchmark failed: %v", err)
+ }
+ i++
+ return i < b.N
+ })
+ if a != 0 {
+ b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i)
+ }
+}
+
+func BenchmarkKernelSyscall(b *testing.B) {
+ // Note that the target passed here is irrelevant; we never execute SwitchToUser.
+ applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ // Iteration does not include machine.Get() / machine.Put().
+ for i := 0; i < b.N; i++ {
+ testutil.Getpid()
+ }
+ return false
+ })
+}
+
+func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
+ // See BenchmarkApplicationSyscall.
+ var (
+ i int
+ a int
+ )
+ applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ }, &si); err == platform.ErrContextInterrupt {
+ a++
+ return true // Ignore.
+ } else if err != nil {
+ b.Fatalf("benchmark failed: %v", err)
+ }
+ // This will intentionally cause the world switch. By executing
+ // a host syscall here, we force the transition between guest
+ // and host mode.
+ testutil.Getpid()
+ i++
+ return i < b.N
+ })
+ if a != 0 {
+ b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i)
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
new file mode 100644
index 000000000..6c54712d1
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -0,0 +1,575 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/procid"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// machine contains state associated with the VM as a whole.
+type machine struct {
+ // fd is the vm fd.
+ fd int
+
+ // nextSlot is the next slot for setMemoryRegion.
+ //
+ // This must be accessed atomically. If nextSlot is ^uint32(0), then
+ // slots are currently being updated, and the caller should retry.
+ nextSlot uint32
+
+ // kernel is the set of global structures.
+ kernel ring0.Kernel
+
+ // mappingCache is used for mapPhysical.
+ mappingCache sync.Map
+
+ // mu protects vCPUs.
+ mu sync.RWMutex
+
+ // available is notified when vCPUs are available.
+ available sync.Cond
+
+ // vCPUsByTID are the machine vCPUs.
+ //
+ // These are populated dynamically.
+ vCPUsByTID map[uint64]*vCPU
+
+ // vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID.
+ vCPUsByID []*vCPU
+
+ // maxVCPUs is the maximum number of vCPUs supported by the machine.
+ maxVCPUs int
+
+ // nextID is the next vCPU ID.
+ nextID uint32
+}
+
+const (
+ // vCPUReady is an alias for all the below clear.
+ vCPUReady uint32 = 0
+
+ // vCPUUser indicates that the vCPU is in or about to enter user mode.
+ vCPUUser uint32 = 1 << 0
+
+ // vCPUGuest indicates the vCPU is in guest mode.
+ vCPUGuest uint32 = 1 << 1
+
+ // vCPUWaiter indicates that there is a waiter.
+ //
+ // If this is set, then notify must be called on any state transitions.
+ vCPUWaiter uint32 = 1 << 2
+)
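+
+// For illustration: state is a bitmask, so combinations are meaningful. A
+// vCPU running application code holds vCPUUser|vCPUGuest, and one that
+// another thread wants kicked out of guest mode holds vCPUGuest|vCPUWaiter.
+// A sketch of a state check (always via atomics, as elsewhere):
+//
+//	if atomic.LoadUint32(&c.state)&vCPUGuest != 0 {
+//		// The vCPU is currently executing in guest mode.
+//	}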
+
+// vCPU is a single KVM vCPU.
+type vCPU struct {
+ // CPU is the kernel CPU data.
+ //
+ // This must be the first element of this structure, it is referenced
+ // by the bluepill code (see bluepill_amd64.s).
+ ring0.CPU
+
+ // id is the vCPU id.
+ id int
+
+ // fd is the vCPU fd.
+ fd int
+
+ // tid is the last set tid.
+ tid uint64
+
+ // switches is a count of world switches (informational only).
+ switches uint32
+
+ // faults is a count of world faults (informational only).
+ faults uint32
+
+ // state is the vCPU state.
+ //
+ // This is a bitmask of the three fields (vCPU*) described above.
+ state uint32
+
+ // runData for this vCPU.
+ runData *runData
+
+ // machine associated with this vCPU.
+ machine *machine
+
+ // active is the current addressSpace: this is set and read atomically,
+ // it is used to elide unnecessary interrupts due to invalidations.
+ active atomicAddressSpace
+
+ // vCPUArchState is the architecture-specific state.
+ vCPUArchState
+
+ dieState dieState
+}
+
+type dieState struct {
+ // message is thrown from die.
+ message string
+
+ // guestRegs is used to store register state during vCPU.die() to prevent
+ // allocation inside nosplit function.
+ guestRegs userRegs
+}
+
+// newVCPU creates and returns a new vCPU.
+//
+// Precondition: mu must be held.
+func (m *machine) newVCPU() *vCPU {
+ // Create the vCPU.
+ id := int(atomic.AddUint32(&m.nextID, 1) - 1)
+ fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id))
+ if errno != 0 {
+ panic(fmt.Sprintf("error creating new vCPU: %v", errno))
+ }
+
+ c := &vCPU{
+ id: id,
+ fd: int(fd),
+ machine: m,
+ }
+ c.CPU.Init(&m.kernel, c)
+ m.vCPUsByID[c.id] = c
+
+ // Ensure the signal mask is correct.
+ if err := c.setSignalMask(); err != nil {
+ panic(fmt.Sprintf("error setting signal mask: %v", err))
+ }
+
+ // Map the run data.
+ runData, err := mapRunData(int(fd))
+ if err != nil {
+ panic(fmt.Sprintf("error mapping run data: %v", err))
+ }
+ c.runData = runData
+
+ // Initialize architecture state.
+ if err := c.initArchState(); err != nil {
+ panic(fmt.Sprintf("error initialization vCPU state: %v", err))
+ }
+
+ return c // Done.
+}
+
+// newMachine returns a new VM context.
+func newMachine(vm int) (*machine, error) {
+ // Create the machine.
+ m := &machine{fd: vm}
+ m.available.L = &m.mu
+ m.kernel.Init(ring0.KernelOpts{
+ PageTables: pagetables.New(newAllocator()),
+ })
+
+ maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
+ if errno != 0 {
+ m.maxVCPUs = _KVM_NR_VCPUS
+ } else {
+ m.maxVCPUs = int(maxVCPUs)
+ }
+ log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
+
+ // Create the vCPUs map/slices.
+ m.vCPUsByTID = make(map[uint64]*vCPU)
+ m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+
+ // Apply the physical mappings. Note that these mappings may point to
+ // guest physical addresses that are not actually available. These
+ // physical pages are mapped on demand, see kernel_unsafe.go.
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ // Map everything in the lower half.
+ m.kernel.PageTables.Map(
+ usermem.Addr(pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ // And keep everything in the upper half.
+ m.kernel.PageTables.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ return true // Keep iterating.
+ })
+
+ physicalRegionsReadOnly := rdonlyRegionsForSetMem()
+ physicalRegionsAvailable := availableRegionsForSetMem()
+
+ // Map all read-only regions.
+ for _, r := range physicalRegionsReadOnly {
+ m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
+ }
+
+ // Ensure that the currently mapped virtual regions are actually
+ // available in the VM. Note that this doesn't guarantee no future
+ // faults, however it should guarantee that everything is available to
+ // ensure successful vCPU entry.
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+
+ for _, r := range physicalRegionsReadOnly {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+
+ for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
+ physical, length, ok := translateToPhysical(virtual)
+ if !ok {
+ // This must be an invalid region that was
+ // knocked out by creation of the physical map.
+ return
+ }
+ if virtual+length > vr.virtual+vr.length {
+ // Cap the length to the end of the area.
+ length = vr.virtual + vr.length - virtual
+ }
+
+ // Ensure the physical range is mapped.
+ m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
+ virtual += length
+ }
+ })
+
+ // Initialize architecture state.
+ if err := m.initArchState(); err != nil {
+ m.Destroy()
+ return nil, err
+ }
+
+ // Ensure the machine is cleaned up properly.
+ runtime.SetFinalizer(m, (*machine).Destroy)
+ return m, nil
+}
+
+// mapPhysical checks for the mapping of a physical range, and installs one if
+// not available. This attempts to be efficient for calls in the hot path.
+//
+// This panics on error.
+//
+//go:nosplit
+func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
+ for end := physical + length; physical < end; {
+ _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
+ if !ok {
+ // Should never happen.
+ panic("mapPhysical on unknown physical address")
+ }
+
+ if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok {
+ // Not present in the cache; requires setting the slot.
+ if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok {
+ panic("handleBluepillFault failed")
+ }
+ }
+
+ // Move to the next chunk.
+ physical = physicalStart + length
+ }
+}
+
+// Destroy frees associated resources.
+//
+// Destroy should only be called once all active users of the machine are gone.
+// The machine object should not be used after calling Destroy.
+//
+// Precondition: all vCPUs must be returned to the machine.
+func (m *machine) Destroy() {
+ runtime.SetFinalizer(m, nil)
+
+ // Destroy vCPUs.
+ for _, c := range m.vCPUsByID {
+ if c == nil {
+ continue
+ }
+
+ // Ensure the vCPU is not still running in guest mode. This is
+ // possible iff teardown has been done by other threads, and
+ // somehow a single thread has not executed any system calls.
+ c.BounceToHost()
+
+ // Note that the runData may not be mapped if an error occurs
+ // during the middle of initialization.
+ if c.runData != nil {
+ if err := unmapRunData(c.runData); err != nil {
+ panic(fmt.Sprintf("error unmapping rundata: %v", err))
+ }
+ }
+ if err := syscall.Close(int(c.fd)); err != nil {
+ panic(fmt.Sprintf("error closing vCPU fd: %v", err))
+ }
+ }
+
+ // vCPUs are gone: teardown machine state.
+ if err := syscall.Close(m.fd); err != nil {
+ panic(fmt.Sprintf("error closing VM fd: %v", err))
+ }
+}
+
+// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
+func (m *machine) Get() *vCPU {
+ m.mu.RLock()
+ runtime.LockOSThread()
+ tid := procid.Current()
+
+ // Check for an exact match.
+ if c := m.vCPUsByTID[tid]; c != nil {
+ c.lock()
+ m.mu.RUnlock()
+ return c
+ }
+
+ // The happy path failed. We now proceed to acquire an exclusive lock
+ // (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
+ m.mu.RUnlock()
+ runtime.UnlockOSThread()
+ m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUsByTID[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
+
+ for {
+ // Scan for an available vCPU.
+ for origTID, c := range m.vCPUsByTID {
+ if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
+ m.mu.Unlock()
+ c.loadSegments(tid)
+ return c
+ }
+ }
+
+ // Create a new vCPU (maybe).
+ if int(m.nextID) < m.maxVCPUs {
+ c := m.newVCPU()
+ c.lock()
+ m.vCPUsByTID[tid] = c
+ m.mu.Unlock()
+ c.loadSegments(tid)
+ return c
+ }
+
+ // Scan for something not in user mode.
+ for origTID, c := range m.vCPUsByTID {
+ if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) {
+ continue
+ }
+
+ // The vCPU is not able to transition to
+ // vCPUGuest|vCPUUser or to vCPUUser because that
+ // transition requires holding the machine mutex, as we
+ // do now. There is no path to register a waiter on
+ // just the vCPUReady state.
+ for {
+ c.waitUntilNot(vCPUGuest | vCPUWaiter)
+ if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
+ break
+ }
+ }
+
+ // Steal the vCPU.
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
+ m.mu.Unlock()
+ c.loadSegments(tid)
+ return c
+ }
+
+ // Everything is executing in user mode. Wait until something
+ // is available. Note that signaling the condition variable
+ // will have the extra effect of kicking the vCPUs out of guest
+ // mode if that's where they were.
+ m.available.Wait()
+ }
+}
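+
+// For illustration, the expected pairing (as in initArchState in
+// machine_amd64.go); Get returns with the OS thread locked, so Put is
+// typically deferred:
+//
+//	c := m.Get()
+//	defer m.Put(c)
+//	bluepill(c) // Switch to guest mode on this vCPU.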
+
+// Put puts the current vCPU.
+func (m *machine) Put(c *vCPU) {
+ c.unlock()
+ runtime.UnlockOSThread()
+
+ m.mu.RLock()
+ m.available.Signal()
+ m.mu.RUnlock()
+}
+
+// newDirtySet returns a new dirty set.
+func (m *machine) newDirtySet() *dirtySet {
+ return &dirtySet{
+ vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
+ }
+}
+
+// lock marks the vCPU as in user mode.
+//
+// This should only be called directly when known to be safe, i.e. when
+// the vCPU is owned by the current TID with no chance of theft.
+//
+//go:nosplit
+func (c *vCPU) lock() {
+ atomicbitops.OrUint32(&c.state, vCPUUser)
+}
+
+// unlock clears the vCPUUser bit.
+//
+//go:nosplit
+func (c *vCPU) unlock() {
+ if atomic.CompareAndSwapUint32(&c.state, vCPUUser|vCPUGuest, vCPUGuest) {
+ // Happy path: no exits are forced, and we can continue
+ // executing on our merry way with a single atomic access.
+ return
+ }
+
+ // Clear the lock.
+ origState := atomic.LoadUint32(&c.state)
+ atomicbitops.AndUint32(&c.state, ^vCPUUser)
+ switch origState {
+ case vCPUUser:
+ // Normal state.
+ case vCPUUser | vCPUGuest | vCPUWaiter:
+ // Force a transition: this must trigger a notification when we
+ // return from guest mode. We must clear vCPUWaiter here
+ // anyway, because BounceToKernel will force a transition only
+ // from ring3 to ring0, which will not clear this bit. Halt may
+ // work around the issue, but if there is no exception or
+ // syscall in this period, BounceToKernel will hang.
+ atomicbitops.AndUint32(&c.state, ^vCPUWaiter)
+ c.notify()
+ case vCPUUser | vCPUWaiter:
+ // Waiting for the lock to be released; the responsibility is
+ // on us to notify the waiter and clear the associated bit.
+ atomicbitops.AndUint32(&c.state, ^vCPUWaiter)
+ c.notify()
+ default:
+ panic("invalid state")
+ }
+}
+
+// NotifyInterrupt implements interrupt.Receiver.NotifyInterrupt.
+//
+//go:nosplit
+func (c *vCPU) NotifyInterrupt() {
+ c.BounceToKernel()
+}
+
+// pid is used below in bounce.
+var pid = syscall.Getpid()
+
+// bounce forces a return to the kernel or to host mode.
+//
+// This effectively unwinds the state machine.
+func (c *vCPU) bounce(forceGuestExit bool) {
+ for {
+ switch state := atomic.LoadUint32(&c.state); state {
+ case vCPUReady, vCPUWaiter:
+ // There is nothing to be done, we're already in the
+ // kernel pre-acquisition. The Bounce criteria have
+ // been satisfied.
+ return
+ case vCPUUser:
+ // We need to register a waiter for the actual guest
+ // transition. When the transition takes place, then we
+ // can inject an interrupt to ensure a return to host
+ // mode.
+ atomic.CompareAndSwapUint32(&c.state, state, state|vCPUWaiter)
+ case vCPUUser | vCPUWaiter:
+ // Wait for the transition to guest mode. This should
+ // come from the bluepill handler.
+ c.waitUntilNot(state)
+ case vCPUGuest, vCPUUser | vCPUGuest:
+ if state == vCPUGuest && !forceGuestExit {
+ // The vCPU is already not acquired, so there's
+ // no need to do a fresh injection here.
+ return
+ }
+ // The vCPU is in user or kernel mode. Attempt to
+ // register a notification on change.
+ if !atomic.CompareAndSwapUint32(&c.state, state, state|vCPUWaiter) {
+ break // Retry.
+ }
+ for {
+ // We need to spin here until the signal is
+ // delivered, because Tgkill can return EAGAIN
+ // under memory pressure. Since we already
+ // marked ourselves as a waiter, we need to
+ // ensure that a signal is actually delivered.
+ if err := syscall.Tgkill(pid, int(atomic.LoadUint64(&c.tid)), bounceSignal); err == nil {
+ break
+ } else if err.(syscall.Errno) == syscall.EAGAIN {
+ continue
+ } else {
+ // Nothing else should be returned by tgkill.
+ panic(fmt.Sprintf("unexpected tgkill error: %v", err))
+ }
+ }
+ case vCPUGuest | vCPUWaiter, vCPUUser | vCPUGuest | vCPUWaiter:
+ if state == vCPUGuest|vCPUWaiter && !forceGuestExit {
+ // See above.
+ return
+ }
+ // Wait for the transition. This again should happen
+ // from the bluepill handler, but on the way out.
+ c.waitUntilNot(state)
+ default:
+ // Should not happen: the above is exhaustive.
+ panic("invalid state")
+ }
+ }
+}
+
+// BounceToKernel ensures that the vCPU bounces back to the kernel.
+//
+//go:nosplit
+func (c *vCPU) BounceToKernel() {
+ c.bounce(false)
+}
+
+// BounceToHost ensures that the vCPU is in host mode.
+//
+//go:nosplit
+func (c *vCPU) BounceToHost() {
+ c.bounce(true)
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
new file mode 100644
index 000000000..acc823ba6
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -0,0 +1,347 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "runtime/debug"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// initArchState initializes architecture-specific state.
+func (m *machine) initArchState() error {
+ // Set the legacy TSS address. This address is covered by the reserved
+ // range (up to 4GB). In fact, this is the main reason it exists.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_TSS_ADDR,
+ uintptr(reservedMemory-(3*usermem.PageSize))); errno != 0 {
+ return errno
+ }
+
+ // Enable CPUID faulting, if possible. Note that this also serves as a
+ // basic platform sanity test, since we will enter guest mode for the
+ // first time here. The recovery is necessary, since if we fail to read
+ // the platform info register, we will return to host mode and
+ // ultimately need to handle a segmentation fault.
+ old := debug.SetPanicOnFault(true)
+ defer func() {
+ recover()
+ debug.SetPanicOnFault(old)
+ }()
+ c := m.Get()
+ defer m.Put(c)
+ bluepill(c)
+ ring0.SetCPUIDFaulting(true)
+
+ return nil
+}
+
+type vCPUArchState struct {
+ // PCIDs is the set of PCIDs for this vCPU.
+ //
+ // This starts above fixedKernelPCID.
+ PCIDs *pagetables.PCIDs
+
+ // floatingPointState is the floating point state buffer used in guest
+ // to host transitions. See usage in bluepill_amd64.go.
+ floatingPointState *arch.FloatingPointData
+}
+
+const (
+ // fixedKernelPCID is a fixed kernel PCID used for the kernel page
+ // tables. We must start allocating user PCIDs above this in order to
+ // avoid any conflict (see below).
+ fixedKernelPCID = 1
+
+ // poolPCIDs is the number of PCIDs to record in the database. As this
+ // grows, assignment can take longer, since it is a simple linear scan.
+ // Beyond a relatively small number, there are likely few performance
+ // benefits, since the TLB has likely long since lost any translations
+ // from more than a few PCIDs past.
+ poolPCIDs = 8
+)
+
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
+ }
+}
+
+// initArchState initializes architecture-specific state.
+func (c *vCPU) initArchState() error {
+ var (
+ kernelSystemRegs systemRegs
+ kernelUserRegs userRegs
+ )
+
+ // Set base control registers.
+ kernelSystemRegs.CR0 = c.CR0()
+ kernelSystemRegs.CR4 = c.CR4()
+ kernelSystemRegs.EFER = c.EFER()
+
+ // Set the IDT & GDT in the registers.
+ kernelSystemRegs.IDT.base, kernelSystemRegs.IDT.limit = c.IDT()
+ kernelSystemRegs.GDT.base, kernelSystemRegs.GDT.limit = c.GDT()
+ kernelSystemRegs.CS.Load(&ring0.KernelCodeSegment, ring0.Kcode)
+ kernelSystemRegs.DS.Load(&ring0.UserDataSegment, ring0.Udata)
+ kernelSystemRegs.ES.Load(&ring0.UserDataSegment, ring0.Udata)
+ kernelSystemRegs.SS.Load(&ring0.KernelDataSegment, ring0.Kdata)
+ kernelSystemRegs.FS.Load(&ring0.UserDataSegment, ring0.Udata)
+ kernelSystemRegs.GS.Load(&ring0.UserDataSegment, ring0.Udata)
+ tssBase, tssLimit, tss := c.TSS()
+ kernelSystemRegs.TR.Load(tss, ring0.Tss)
+ kernelSystemRegs.TR.base = tssBase
+ kernelSystemRegs.TR.limit = uint32(tssLimit)
+
+ // Point to kernel page tables, with no initial PCID.
+ kernelSystemRegs.CR3 = c.machine.kernel.PageTables.CR3(false, 0)
+
+ // Initialize the PCID database.
+ if hasGuestPCID {
+ // Note that NewPCIDs may return a nil table here, in which
+ // case we simply don't use PCID support (see below). In
+ // practice, this should not happen, however.
+ c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
+ }
+
+ // Set the CPUID; this is required before setting system registers,
+ // since KVM will reject several CR4 bits if the CPUID does not
+ // indicate the support is available.
+ if err := c.setCPUID(); err != nil {
+ return err
+ }
+
+ // Set the entrypoint for the kernel.
+ kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer())
+ kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ kernelUserRegs.RFLAGS = ring0.KernelFlagsSet
+
+ // Set the system registers.
+ if err := c.setSystemRegisters(&kernelSystemRegs); err != nil {
+ return err
+ }
+
+ // Set the user registers.
+ if err := c.setUserRegisters(&kernelUserRegs); err != nil {
+ return err
+ }
+
+ // Allocate some floating point state save area for the local vCPU.
+ // This will be saved prior to leaving the guest, and we restore from
+ // this always. We cannot use the pointer in the context alone because
+ // we don't know how large the area there is in reality.
+ c.floatingPointState = arch.NewFloatingPointData()
+
+ // Set the time offset to the host native time.
+ return c.setSystemTime()
+}
+
+// nonCanonical generates the fault return for a non-canonical address.
+//
+//go:nosplit
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ *info = arch.SignalInfo{
+ Signo: signal,
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(addr) // Include address.
+ return usermem.NoAccess, platform.ErrContextSignal
+}
+
+// fault generates an appropriate fault return.
+//
+//go:nosplit
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ bluepill(c) // Probably no-op, but may not be.
+ faultAddr := ring0.ReadCR2()
+ code, user := c.ErrorCode()
+ if !user {
+ // The last fault serviced by this CPU was not a user
+ // fault, so we can't reliably trust the faultAddr or
+ // the code provided here. We need to re-execute.
+ return usermem.NoAccess, platform.ErrContextInterrupt
+ }
+ // Reset the pointed SignalInfo.
+ *info = arch.SignalInfo{Signo: signal}
+ info.SetAddr(uint64(faultAddr))
+ accessType := usermem.AccessType{
+ Read: code&(1<<1) == 0,
+ Write: code&(1<<1) != 0,
+ Execute: code&(1<<4) != 0,
+ }
+ if !accessType.Write && !accessType.Execute {
+ info.Code = 1 // SEGV_MAPERR.
+ } else {
+ info.Code = 2 // SEGV_ACCERR.
+ }
+ return accessType, platform.ErrContextSignal
+}
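+
+// For illustration: the decode above follows the x86 page-fault error code,
+// in which bit 1 (W/R) is set for writes and bit 4 (I/D) is set for
+// instruction fetches; a plain read fault has both clear, and is the only
+// case reported as SEGV_MAPERR here.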
+
+// SwitchToUser unpacks architectural details.
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+ // Check for canonical addresses.
+ if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) {
+ return nonCanonical(regs.Rip, int32(syscall.SIGSEGV), info)
+ } else if !ring0.IsCanonical(regs.Rsp) {
+ return nonCanonical(regs.Rsp, int32(syscall.SIGBUS), info)
+ } else if !ring0.IsCanonical(regs.Fs_base) {
+ return nonCanonical(regs.Fs_base, int32(syscall.SIGBUS), info)
+ } else if !ring0.IsCanonical(regs.Gs_base) {
+ return nonCanonical(regs.Gs_base, int32(syscall.SIGBUS), info)
+ }
+
+ // Assign PCIDs.
+ if c.PCIDs != nil {
+ var requireFlushPCID bool // Force a flush?
+ switchOpts.UserPCID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables)
+ switchOpts.KernelPCID = fixedKernelPCID
+ switchOpts.Flush = switchOpts.Flush || requireFlushPCID
+ }
+
+ // See below.
+ var vector ring0.Vector
+
+ // Past this point, stack growth can cause system calls (and a break
+ // from guest mode). So we need to ensure that between the bluepill
+ // call here and the switch call immediately below, no additional
+ // allocations occur.
+ entersyscall()
+ bluepill(c)
+ vector = c.CPU.SwitchToUser(switchOpts)
+ exitsyscall()
+
+ switch vector {
+ case ring0.Syscall, ring0.SyscallInt80:
+ // Fast path: system call executed.
+ return usermem.NoAccess, nil
+
+ case ring0.PageFault:
+ return c.fault(int32(syscall.SIGSEGV), info)
+
+ case ring0.Debug, ring0.Breakpoint:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGTRAP),
+ Code: 1, // TRAP_BRKPT (breakpoint).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.GeneralProtectionFault,
+ ring0.SegmentNotPresent,
+ ring0.BoundRangeExceeded,
+ ring0.InvalidTSS,
+ ring0.StackSegmentFault:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGSEGV),
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ if vector == ring0.GeneralProtectionFault {
+ // When CPUID faulting is enabled, we will generate a #GP(0) when
+ // userspace executes a CPUID instruction. This is handled above,
+ // because we need to be able to map and read user memory.
+ return usermem.AccessType{}, platform.ErrContextSignalCPUID
+ }
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.InvalidOpcode:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGILL),
+ Code: 1, // ILL_ILLOPC (illegal opcode).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.DivideByZero:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGFPE),
+ Code: 1, // FPE_INTDIV (divide by zero).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.Overflow:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGFPE),
+ Code: 2, // FPE_INTOVF (integer overflow).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.X87FloatingPointException,
+ ring0.SIMDFloatingPointException:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGFPE),
+ Code: 7, // FPE_FLTINV (invalid operation).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.Vector(bounce): // ring0.VirtualizationException
+ return usermem.NoAccess, platform.ErrContextInterrupt
+
+ case ring0.AlignmentCheck:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGBUS),
+ Code: 2, // BUS_ADRERR (physical address does not exist).
+ }
+ return usermem.NoAccess, platform.ErrContextSignal
+
+ case ring0.NMI:
+ // An NMI is generated only when a fault is not serviceable by
+ // KVM itself, so we think some mapping is writeable but it's
+ // really not. This could happen, e.g. if some file is
+ // truncated (and would generate a SIGBUS) and we map it
+ // directly into the instance.
+ return c.fault(int32(syscall.SIGBUS), info)
+
+ case ring0.DeviceNotAvailable,
+ ring0.DoubleFault,
+ ring0.CoprocessorSegmentOverrun,
+ ring0.MachineCheck,
+ ring0.SecurityException:
+ fallthrough
+ default:
+ panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
+ }
+}
+
+// On x86 platform, the flags for "setMemoryRegion" can always be set as 0.
+// There is no need to return read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ return nil
+}
+
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ return physicalRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
new file mode 100644
index 000000000..290f035dd
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -0,0 +1,177 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+)
+
+// loadSegments copies the current segments.
+//
+// This may be called from within the signal context and throws on error.
+//
+//go:nosplit
+func (c *vCPU) loadSegments(tid uint64) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_ARCH_PRCTL,
+ linux.ARCH_GET_FS,
+ uintptr(unsafe.Pointer(&c.CPU.Registers().Fs_base)),
+ 0); errno != 0 {
+ throw("getting FS segment")
+ }
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_ARCH_PRCTL,
+ linux.ARCH_GET_GS,
+ uintptr(unsafe.Pointer(&c.CPU.Registers().Gs_base)),
+ 0); errno != 0 {
+ throw("getting GS segment")
+ }
+ atomic.StoreUint64(&c.tid, tid)
+}
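+
+// For illustration: arch_prctl(ARCH_GET_FS) and arch_prctl(ARCH_GET_GS) copy
+// the host thread's segment bases into the vCPU's register state so that
+// they carry over across the world switch; RawSyscall is used because this
+// may run from a signal context.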
+
+// setCPUID sets the CPUID to be used by the guest.
+func (c *vCPU) setCPUID() error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_CPUID2,
+ uintptr(unsafe.Pointer(&cpuidSupported))); errno != 0 {
+ return fmt.Errorf("error setting CPUID: %v", errno)
+ }
+ return nil
+}
+
+// setSystemTime sets the TSC for the vCPU.
+//
+// This has to make the call many times in order to minimize the intrinsic
+// error in the offset. Unfortunately KVM does not expose a relative offset via
+// the API, so this is an approximation. We do this via an iterative algorithm.
+// This has the advantage that it can generally deal with highly variable
+// system call times and should converge on the correct offset.
+func (c *vCPU) setSystemTime() error {
+ const (
+ _MSR_IA32_TSC = 0x00000010
+ calibrateTries = 10
+ )
+ registers := modelControlRegisters{
+ nmsrs: 1,
+ }
+ registers.entries[0] = modelControlRegister{
+ index: _MSR_IA32_TSC,
+ }
+ target := uint64(^uint32(0))
+ for done := 0; done < calibrateTries; {
+ start := uint64(time.Rdtsc())
+ registers.entries[0].data = start + target
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_MSRS,
+ uintptr(unsafe.Pointer(&registers))); errno != 0 {
+ return fmt.Errorf("error setting system time: %v", errno)
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result. So we only count attempts
+ // within +/- 6.25% of our minimum as an attempt.
+ end := uint64(time.Rdtsc())
+ if end < start {
+ continue // Totally bogus.
+ }
+ half := (end - start) / 2
+ if half < target {
+ target = half
+ }
+ if (half - target) < target/8 {
+ done++
+ }
+ }
+ return nil
+}
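+
+// For illustration: each attempt above programs the guest TSC to
+// start+target, predicting that the _KVM_SET_MSRS ioctl costs about target
+// cycles one way; half the measured round trip refines that estimate:
+//
+//	half := (end - start) / 2 // Approximate one-way ioctl cost.
+//	if half < target {
+//		target = half // New minimum; later attempts predict this.
+//	}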
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+ return nil
+}
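+
+// For illustration: the vCPU must block every signal except the bounce
+// signal while in KVM_RUN, so the mask written above is the complement of
+// bounceSignalMask, split into 32-bit halves for the fixed 8-byte sigset:
+//
+//	mask1 := ^uint32(bounceSignalMask & 0xffffffff) // Low word.
+//	mask2 := ^uint32(bounceSignalMask >> 32)        // High word.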
+
+// setUserRegisters sets user registers in the vCPU.
+func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
+ }
+ return nil
+}
+
+// getUserRegisters reloads user registers in the vCPU.
+//
+// This is safe to call from a nosplit context.
+//
+//go:nosplit
+func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
+
+// setSystemRegisters sets system registers.
+func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return fmt.Errorf("error setting system registers: %v", errno)
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
new file mode 100644
index 000000000..f3bf973de
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type vCPUArchState struct {
+ // PCIDs is the set of PCIDs for this vCPU.
+ //
+ // This starts above fixedKernelPCID.
+ PCIDs *pagetables.PCIDs
+
+ // floatingPointState is the floating point state buffer used in guest
+ // to host transitions. See usage in bluepill_arm64.go.
+ floatingPointState *arch.FloatingPointData
+}
+
+const (
+ // fixedKernelPCID is a fixed kernel PCID used for the kernel page
+ // tables. We must start allocating user PCIDs above this in order to
+ // avoid any conflict (see below).
+ fixedKernelPCID = 1
+
+ // poolPCIDs is the number of PCIDs to record in the database. As this
+ // grows, assignment can take longer, since it is a simple linear scan.
+ // Beyond a relatively small number, there are likely few performance
+ // benefits, since the TLB has likely long since lost any translations
+ // from more than a few PCIDs past.
+ poolPCIDs = 8
+)
+
+// Get all read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ var rdonlyRegions []region
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return
+ }
+
+ if !vr.accessType.Write && vr.accessType.Read {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
+
+ // TODO(gvisor.dev/issue/2686): PROT_NONE should be specially treated.
+ // Workaround: treat it as rdonly temporarily.
+ if !vr.accessType.Write && !vr.accessType.Read && !vr.accessType.Execute {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
+ })
+
+ for _, r := range rdonlyRegions {
+ physical, _, ok := translateToPhysical(r.virtual)
+ if !ok {
+ continue
+ }
+
+ phyRegions = append(phyRegions, physicalRegion{
+ region: region{
+ virtual: r.virtual,
+ length: r.length,
+ },
+ physical: physical,
+ })
+ }
+
+ return phyRegions
+}
+
+// Get all available physicalRegions.
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ var excludeRegions []region
+ applyVirtualRegions(func(vr virtualRegion) {
+ if !vr.accessType.Write {
+ excludeRegions = append(excludeRegions, vr.region)
+ }
+ })
+
+ phyRegions = computePhysicalRegions(excludeRegions)
+
+ return phyRegions
+}
+
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUsByID {
+ if c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
+ }
+}
+
+// nonCanonical generates the fault return for a non-canonical address.
+//
+//go:nosplit
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ *info = arch.SignalInfo{
+ Signo: signal,
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(addr) // Include address.
+ return usermem.NoAccess, platform.ErrContextSignal
+}
+
+// fault generates an appropriate fault return.
+//
+//go:nosplit
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ faultAddr := c.GetFaultAddr()
+ code, user := c.ErrorCode()
+
+ // Reset the pointed SignalInfo.
+ *info = arch.SignalInfo{Signo: signal}
+ info.SetAddr(uint64(faultAddr))
+
+ read := true
+ write := false
+ execute := true
+
+ ret := code & _ESR_ELx_FSC
+ switch ret {
+ case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3:
+ info.Code = 1 // SEGV_MAPERR.
+ read = false
+ write = true
+ execute = false
+ case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3:
+ info.Code = 2 // SEGV_ACCERR.
+ read = true
+ write = false
+ execute = false
+ default:
+ info.Code = 2
+ }
+
+ if !user {
+ read = true
+ write = false
+ execute = true
+ }
+ accessType := usermem.AccessType{
+ Read: read,
+ Write: write,
+ Execute: execute,
+ }
+
+ return accessType, platform.ErrContextSignal
+}
+
+// retryInGuest runs the given function in guest mode.
+//
+// If the function does not complete in guest mode (for example, due to a
+// system call triggered by a GC stall), then it will be retried. The
+// given function must be idempotent as a result of the retry mechanism.
+func (m *machine) retryInGuest(fn func()) {
+ c := m.Get()
+ defer m.Put(c)
+ for {
+ c.ClearErrorCode() // See below.
+ bluepill(c) // Force guest mode.
+ fn() // Execute the given function.
+ _, user := c.ErrorCode()
+ if user {
+ // If user is set, then we haven't bailed back to host
+ // mode via a kernel exception or system call. We
+ // consider the full function to have executed in guest
+ // mode and we can return.
+ break
+ }
+ }
+}
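+
+// For illustration, a hypothetical caller (someGuestOnlyOp is a placeholder,
+// not a real function); the callback must be idempotent since it may run
+// more than once:
+//
+//	m.retryInGuest(func() {
+//		someGuestOnlyOp() // Re-executed if a syscall bounced us to host mode.
+//	})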
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
new file mode 100644
index 000000000..8bed34922
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -0,0 +1,282 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type kvmVcpuInit struct {
+ target uint32
+ features [7]uint32
+}
+
+var vcpuInit kvmVcpuInit
+
+// initArchState initializes architecture-specific state.
+func (m *machine) initArchState() error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_ARM_PREFERRED_TARGET,
+ uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+ panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno))
+ }
+ return nil
+}
+
+// initArchState initializes architecture-specific state.
+func (c *vCPU) initArchState() error {
+ var (
+ reg kvmOneReg
+ data uint64
+ regGet kvmOneReg
+ dataGet uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer())
+
+ vcpuInit.target = _KVM_ARM_TARGET_GENERIC_V8
+ vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_ARM_VCPU_INIT,
+ uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+ panic(fmt.Sprintf("error setting KVM_ARM_VCPU_INIT failed: %v", errno))
+ }
+
+ // cpacr_el1
+ reg.id = _KVM_ARM64_REGS_CPACR_EL1
+ // It is off by default, and it is turned on only when in use.
+ data = 0 // Disable fpsimd.
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // sctlr_el1
+ regGet.id = _KVM_ARM64_REGS_SCTLR_EL1
+ if err := c.getOneRegister(&regGet); err != nil {
+ return err
+ }
+
+ dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I)
+ data = dataGet
+ reg.id = _KVM_ARM64_REGS_SCTLR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // tcr_el1
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
+ reg.id = _KVM_ARM64_REGS_TCR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // mair_el1
+ data = _MT_EL1_INIT
+ reg.id = _KVM_ARM64_REGS_MAIR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // ttbr0_el1
+ data = c.machine.kernel.PageTables.TTBR0_EL1(false, 0)
+
+ reg.id = _KVM_ARM64_REGS_TTBR0_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ c.SetTtbr0Kvm(uintptr(data))
+
+ // ttbr1_el1
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
+
+ reg.id = _KVM_ARM64_REGS_TTBR1_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // sp_el1
+ data = c.CPU.StackTop()
+ reg.id = _KVM_ARM64_REGS_SP_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // pc
+ reg.id = _KVM_ARM64_REGS_PC
+ data = uint64(reflect.ValueOf(ring0.Start).Pointer())
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // r8
+ reg.id = _KVM_ARM64_REGS_R8
+ data = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // vbar_el1
+ reg.id = _KVM_ARM64_REGS_VBAR_EL1
+
+ fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+ offset := fromLocation & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ toLocation := fromLocation + offset
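+ // The computation above rounds the vector table up to the next 2KB
+ // boundary, since vbar_el1 ignores the low 11 bits. For example
+ // (illustrative): fromLocation = 0x12345 gives offset = 0x800 - 0x345 =
+ // 0x4bb, so toLocation = 0x12800.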
+ data = uint64(ring0.KernelStartAddress | toLocation)
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // Use the address of the exception vector table as
+ // the MMIO address base.
+ arm64HypercallMMIOBase = toLocation
+
+ data = ring0.PsrDefaultSet | ring0.KernelFlagsSet
+ reg.id = _KVM_ARM64_REGS_PSTATE
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ c.floatingPointState = arch.NewFloatingPointData()
+ return nil
+}
+
+//go:nosplit
+func (c *vCPU) loadSegments(tid uint64) {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // Get TLS from tpidr_el0.
+ atomic.StoreUint64(&c.tid, tid)
+}
+
+func (c *vCPU) setOneRegister(reg *kvmOneReg) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_ONE_REG,
+ uintptr(unsafe.Pointer(reg))); errno != 0 {
+ return fmt.Errorf("error setting one register: %v", errno)
+ }
+ return nil
+}
+
+func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_ONE_REG,
+ uintptr(unsafe.Pointer(reg))); errno != 0 {
+ return fmt.Errorf("error setting one register: %v", errno)
+ }
+ return nil
+}
+
+// setCPUID sets the CPUID to be used by the guest.
+func (c *vCPU) setCPUID() error {
+ return nil
+}
+
+// setSystemTime sets the TSC for the vCPU.
+func (c *vCPU) setSystemTime() error {
+ return nil
+}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The kernel expects a fixed layout for the sigset argument (a length
+ // followed by the mask words), which is not necessarily the layout the
+ // Go compiler would choose for an equivalent Go type, so the structure
+ // is assembled by hand here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
+
+// SwitchToUser unpacks architectural details and performs the context switch.
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+ // Check for canonical addresses.
+ if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
+ return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
+ } else if !ring0.IsCanonical(regs.Sp) {
+ return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+ }
+
+ var vector ring0.Vector
+ ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0)
+ c.SetTtbr0App(uintptr(ttbr0App))
+
+ // TODO(gvisor.dev/issue/1238): full context-switch support for Arm64.
+ // The Arm64 user-mode execution state consists of:
+ // x0-x30
+ // PC, SP, PSTATE
+ // V0-V31: 32 128-bit registers for floating point, and simd
+ // FPSR
+ // TPIDR_EL0, used for TLS
+ appRegs := switchOpts.Registers
+ c.SetAppAddr(ring0.KernelStartAddress | uintptr(unsafe.Pointer(appRegs)))
+
+ entersyscall()
+ bluepill(c)
+ vector = c.CPU.SwitchToUser(switchOpts)
+ exitsyscall()
+
+ switch vector {
+ case ring0.Syscall:
+ // Fast path: system call executed.
+ return usermem.NoAccess, nil
+
+ case ring0.PageFault:
+ return c.fault(int32(syscall.SIGSEGV), info)
+ case ring0.Vector(bounce): // ring0.VirtualizationException
+ return usermem.NoAccess, platform.ErrContextInterrupt
+ default:
+ return usermem.NoAccess, platform.ErrContextSignal
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
new file mode 100644
index 000000000..9f86f6a7a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package kvm
+
+import (
+ "fmt"
+ "math"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
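+// entersyscall and exitsyscall are the unexported Go runtime hooks that mark
+// the calling goroutine as entering and leaving a blocking system call; they
+// are linked in so that time spent in guest mode is accounted for by the Go
+// scheduler like any other blocking syscall.
+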
+//go:linkname entersyscall runtime.entersyscall
+func entersyscall()
+
+//go:linkname exitsyscall runtime.exitsyscall
+func exitsyscall()
+
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno {
+ userRegion := userMemoryRegion{
+ slot: uint32(slot),
+ flags: flags,
+ guestPhysAddr: uint64(physical),
+ memorySize: uint64(length),
+ userspaceAddr: uint64(virtual),
+ }
+
+ // Set the region.
+ _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_USER_MEMORY_REGION,
+ uintptr(unsafe.Pointer(&userRegion)))
+ return errno
+}
+
+// mapRunData maps the vCPU run data.
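+//
+// The fd is a KVM vCPU descriptor: the kernel exposes the shared kvm_run
+// structure through an mmap of that descriptor, with runDataSize typically
+// obtained via the KVM_GET_VCPU_MMAP_SIZE ioctl elsewhere.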
+func mapRunData(fd int) (*runData, error) {
+ r, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP,
+ 0,
+ uintptr(runDataSize),
+ syscall.PROT_READ|syscall.PROT_WRITE,
+ syscall.MAP_SHARED,
+ uintptr(fd),
+ 0)
+ if errno != 0 {
+ return nil, fmt.Errorf("error mapping runData: %v", errno)
+ }
+ return (*runData)(unsafe.Pointer(r)), nil
+}
+
+// unmapRunData unmaps the vCPU run data.
+func unmapRunData(r *runData) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_MUNMAP,
+ uintptr(unsafe.Pointer(r)),
+ uintptr(runDataSize),
+ 0); errno != 0 {
+ return fmt.Errorf("error unmapping runData: %v", errno)
+ }
+ return nil
+}
+
+// atomicAddressSpace is an atomic address space pointer.
+type atomicAddressSpace struct {
+ pointer unsafe.Pointer
+}
+
+// set sets the address space value.
+//
+//go:nosplit
+func (a *atomicAddressSpace) set(as *addressSpace) {
+ atomic.StorePointer(&a.pointer, unsafe.Pointer(as))
+}
+
+// get gets the address space value.
+//
+// Note that this should be considered best-effort, and may have changed by the
+// time this function returns.
+//
+//go:nosplit
+func (a *atomicAddressSpace) get() *addressSpace {
+ return (*addressSpace)(atomic.LoadPointer(&a.pointer))
+}
+
+// notify notifies that the vCPU has transitioned modes.
+//
+// This may be called by a signal handler and therefore throws on error.
+//
+//go:nosplit
+func (c *vCPU) notify() {
+ _, _, errno := syscall.RawSyscall6( // escapes: no.
+ syscall.SYS_FUTEX,
+ uintptr(unsafe.Pointer(&c.state)),
+ linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG,
+ math.MaxInt32, // Number of waiters.
+ 0, 0, 0)
+ if errno != 0 {
+ throw("futex wake error")
+ }
+}
+
+// waitUntilNot waits for the vCPU to transition modes.
+//
+// The state should have been previously set to vCPUWaiter after performing an
+// appropriate action to cause a transition (e.g. interrupt injection).
+//
+// This panics on error.
+func (c *vCPU) waitUntilNot(state uint32) {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_FUTEX,
+ uintptr(unsafe.Pointer(&c.state)),
+ linux.FUTEX_WAIT|linux.FUTEX_PRIVATE_FLAG,
+ uintptr(state),
+ 0, 0, 0)
+ if errno != 0 && errno != syscall.EINTR && errno != syscall.EAGAIN {
+ panic("futex wait error")
+ }
+}
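+
+// An illustrative pairing (a sketch, not part of this change): a caller that
+// triggers a transition and then blocks until the vCPU leaves guest mode
+// might look like
+//
+//	atomic.StoreUint32(&c.state, vCPUWaiter)
+//	injectInterrupt(c) // hypothetical transition trigger
+//	c.waitUntilNot(vCPUWaiter)
+//
+// where notify() on the other side wakes the futex once the mode transition
+// completes.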
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
new file mode 100644
index 000000000..f7fa2f98d
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -0,0 +1,214 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+ "sort"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type region struct {
+ virtual uintptr
+ length uintptr
+}
+
+type physicalRegion struct {
+ region
+ physical uintptr
+}
+
+// physicalRegions contains a list of available physical regions.
+//
+// The physical value used in physicalRegions is a number indicating the
+// physical offset, aligned appropriately and starting above reservedMemory.
+var physicalRegions []physicalRegion
+
+// fillAddressSpace fills the host address space with PROT_NONE mappings until
+// we have a host address space size that is less than or equal to the physical
+// address space. This allows us to have an injective host virtual to guest
+// physical mapping.
+//
+// The excluded regions are returned.
+func fillAddressSpace() (excludedRegions []region) {
+ // We can cut vSize in half, because the kernel will be using the top
+ // half and we ignore it while constructing mappings. It's as if we've
+ // already excluded half the possible addresses.
+ vSize := ring0.UserspaceSize
+
+ // We exclude reservedMemory below from our physical memory size, so it
+ // needs to be dropped here as well. Otherwise, we could end up with
+ // physical addresses that are beyond what is mapped.
+ pSize := uintptr(1) << ring0.PhysicalAddressBits()
+ pSize -= reservedMemory
+
+ // Add specifically excluded regions; see excludeVirtualRegion.
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ excludedRegions = append(excludedRegions, vr.region)
+ vSize -= vr.length
+ log.Infof("excluded: virtual [%x,%x)", vr.virtual, vr.virtual+vr.length)
+ }
+ })
+
+ // Do we need any more work?
+ if vSize < pSize {
+ return excludedRegions
+ }
+
+ // Calculate the required space and fill it.
+ //
+ // Note carefully that we add faultBlockSize to required up front, and
+ // on each iteration of the loop below (i.e. each new physical region
+ // we define), we add faultBlockSize again. This is done because the
+ // computation of physical regions will ensure proper alignments with
+ // faultBlockSize, potentially causing up to faultBlockSize bytes in
+ // internal fragmentation for each physical region. So we need to
+ // account for this properly during allocation.
+ requiredAddr, ok := usermem.Addr(vSize - pSize + faultBlockSize).RoundUp()
+ if !ok {
+ panic(fmt.Sprintf(
+ "overflow for vSize (%x) - pSize (%x) + faultBlockSize (%x)",
+ vSize, pSize, faultBlockSize))
+ }
+ required := uintptr(requiredAddr)
+ current := required // Attempted mmap size.
+ for filled := uintptr(0); filled < required && current > 0; {
+ addr, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP,
+ 0, // Suggested address.
+ current,
+ syscall.PROT_NONE,
+ syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE|syscall.MAP_NORESERVE,
+ 0, 0)
+ if errno != 0 {
+ // Attempt half the size; overflow not possible.
+ currentAddr, _ := usermem.Addr(current >> 1).RoundUp()
+ current = uintptr(currentAddr)
+ continue
+ }
+ // We filled a block.
+ filled += current
+ excludedRegions = append(excludedRegions, region{
+ virtual: addr,
+ length: current,
+ })
+ // See comment above.
+ if filled != required {
+ required += faultBlockSize
+ }
+ }
+ if current == 0 {
+ panic("filling address space failed")
+ }
+ sort.Slice(excludedRegions, func(i, j int) bool {
+ return excludedRegions[i].virtual < excludedRegions[j].virtual
+ })
+ for _, r := range excludedRegions {
+ log.Infof("region: virtual [%x,%x)", r.virtual, r.virtual+r.length)
+ }
+ return excludedRegions
+}
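+
+// An illustrative walk through the accounting above (with a hypothetical
+// faultBlockSize B): required starts at vSize - pSize + B. Each successful
+// mmap adds its size to filled and, if more is still needed, grows required
+// by another B to cover the up-to-B bytes of alignment slack that the new
+// region may introduce when physical regions are computed.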
+
+// computePhysicalRegions computes physical regions.
+func computePhysicalRegions(excludedRegions []region) (physicalRegions []physicalRegion) {
+ physical := uintptr(reservedMemory)
+ addValidRegion := func(virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ if virtual == 0 {
+ virtual += usermem.PageSize
+ length -= usermem.PageSize
+ }
+ if end := virtual + length; end > ring0.MaximumUserAddress {
+ length -= (end - ring0.MaximumUserAddress)
+ }
+ if length == 0 {
+ return
+ }
+ // Round physical up to the same alignment as the virtual
+ // address (with respect to faultBlockSize).
+ if offset := virtual &^ faultBlockMask; physical&^faultBlockMask != offset {
+ if newPhysical := (physical & faultBlockMask) + offset; newPhysical > physical {
+ physical = newPhysical // Round up by only a little bit.
+ } else {
+ physical = ((physical + faultBlockSize) & faultBlockMask) + offset
+ }
+ }
+ physicalRegions = append(physicalRegions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: physical,
+ })
+ physical += length
+ }
+ lastExcludedEnd := uintptr(0)
+ for _, r := range excludedRegions {
+ addValidRegion(lastExcludedEnd, r.virtual-lastExcludedEnd)
+ lastExcludedEnd = r.virtual + r.length
+ }
+ addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd)
+
+ // Dump all physical regions.
+ for _, r := range physicalRegions {
+ log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)",
+ r.virtual, r.virtual+r.length, r.physical, r.physical+r.length)
+ }
+ return physicalRegions
+}
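+
+// An illustrative rounding step from addValidRegion above, assuming a
+// hypothetical faultBlockSize of 0x1000: with virtual = 0x5340 and
+// physical = 0x2100, offset = 0x340 while the physical low bits are 0x100,
+// so physical is bumped to (0x2000 + 0x340) = 0x2340, preserving alignment
+// modulo the block size.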
+
+// physicalInit initializes physical address mappings.
+func physicalInit() {
+ physicalRegions = computePhysicalRegions(fillAddressSpace())
+}
+
+// applyPhysicalRegions applies the given function on physical regions.
+//
+// Iteration continues as long as true is returned. The return value is the
+// return from the last call to fn, or true if there are no entries.
+//
+// Precondition: physicalInit must have been called.
+func applyPhysicalRegions(fn func(pr physicalRegion) bool) bool {
+ for _, pr := range physicalRegions {
+ if !fn(pr) {
+ return false
+ }
+ }
+ return true
+}
+
+// translateToPhysical translates the given virtual address.
+//
+// Precondition: physicalInit must have been called.
+//
+//go:nosplit
+func translateToPhysical(virtual uintptr) (physical uintptr, length uintptr, ok bool) {
+ for _, pr := range physicalRegions {
+ if pr.virtual <= virtual && virtual < pr.virtual+pr.length {
+ physical = pr.physical + (virtual - pr.virtual)
+ length = pr.length - (virtual - pr.virtual)
+ ok = true
+ return
+ }
+ }
+ return
+}
diff --git a/pkg/sentry/platform/kvm/physical_map_amd64.go b/pkg/sentry/platform/kvm/physical_map_amd64.go
new file mode 100644
index 000000000..c5adfb577
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map_amd64.go
@@ -0,0 +1,22 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+const (
+ // reservedMemory is a chunk of physical memory reserved starting at
+ // physical address zero. There are some special pages in this region,
+ // so we just call the whole thing off.
+ reservedMemory = 0x100000000
+)
diff --git a/pkg/sentry/platform/kvm/physical_map_arm64.go b/pkg/sentry/platform/kvm/physical_map_arm64.go
new file mode 100644
index 000000000..4d8561453
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map_arm64.go
@@ -0,0 +1,19 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+const (
+ reservedMemory = 0
+)
diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD
new file mode 100644
index 000000000..f7feb8683
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = 1,
+ srcs = [
+ "testutil.go",
+ "testutil_amd64.go",
+ "testutil_amd64.s",
+ "testutil_arm64.go",
+ "testutil_arm64.s",
+ ],
+ visibility = ["//pkg/sentry/platform/kvm:__pkg__"],
+ deps = ["//pkg/sentry/arch"],
+)
diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go
new file mode 100644
index 000000000..5c1efa0fd
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil provides common assembly stubs for testing.
+package testutil
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Getpid executes a trivial system call.
+func Getpid()
+
+// Touch touches the value in the first register.
+func Touch()
+
+// SyscallLoop executes a syscall and loops.
+func SyscallLoop()
+
+// SpinLoop spins on the CPU.
+func SpinLoop()
+
+// HaltLoop immediately halts and loops.
+func HaltLoop()
+
+// TwiddleRegsFault twiddles registers then faults.
+func TwiddleRegsFault()
+
+// TwiddleRegsSyscall twiddles registers then executes a syscall.
+func TwiddleRegsSyscall()
+
+// FloatingPointWorks is a floating point test.
+//
+// It returns true or false.
+func FloatingPointWorks() bool
+
+// RegisterMismatchError is used for checking registers.
+type RegisterMismatchError []string
+
+// Error returns a human-readable error.
+func (r RegisterMismatchError) Error() string {
+ return strings.Join([]string(r), ";")
+}
+
+// addRegisterMismatch allows simple chaining of register mismatches.
+func addRegisterMismatch(err error, reg string, got, expected interface{}) error {
+ errStr := fmt.Sprintf("%s got %08x, expected %08x", reg, got, expected)
+ switch r := err.(type) {
+ case nil:
+ // Return a new register mismatch.
+ return RegisterMismatchError{errStr}
+ case RegisterMismatchError:
+ // Append the error.
+ r = append(r, errStr)
+ return r
+ default:
+ // Leave as is.
+ return err
+ }
+}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
new file mode 100644
index 000000000..8048eedec
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package testutil
+
+import (
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// TwiddleSegments reads segments into known registers.
+func TwiddleSegments()
+
+// SetTestTarget sets the rip appropriately.
+func SetTestTarget(regs *arch.Registers, fn func()) {
+ regs.Rip = uint64(reflect.ValueOf(fn).Pointer())
+}
+
+// SetTouchTarget sets rax appropriately.
+func SetTouchTarget(regs *arch.Registers, target *uintptr) {
+ if target != nil {
+ regs.Rax = uint64(reflect.ValueOf(target).Pointer())
+ } else {
+ regs.Rax = 0
+ }
+}
+
+// RewindSyscall rewinds a syscall RIP.
+func RewindSyscall(regs *arch.Registers) {
+ regs.Rip -= 2
+}
+
+// SetTestRegs initializes registers to known values.
+func SetTestRegs(regs *arch.Registers) {
+ regs.R15 = 0x15
+ regs.R14 = 0x14
+ regs.R13 = 0x13
+ regs.R12 = 0x12
+ regs.Rbp = 0xb9
+ regs.Rbx = 0xb4
+ regs.R11 = 0x11
+ regs.R10 = 0x10
+ regs.R9 = 0x09
+ regs.R8 = 0x08
+ regs.Rax = 0x44
+ regs.Rcx = 0xc4
+ regs.Rdx = 0xd4
+ regs.Rsi = 0x51
+ regs.Rdi = 0xd1
+ regs.Rsp = 0x59
+}
+
+// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
+func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
+ if need := ^uint64(0x15); regs.R15 != need {
+ err = addRegisterMismatch(err, "R15", regs.R15, need)
+ }
+ if need := ^uint64(0x14); regs.R14 != need {
+ err = addRegisterMismatch(err, "R14", regs.R14, need)
+ }
+ if need := ^uint64(0x13); regs.R13 != need {
+ err = addRegisterMismatch(err, "R13", regs.R13, need)
+ }
+ if need := ^uint64(0x12); regs.R12 != need {
+ err = addRegisterMismatch(err, "R12", regs.R12, need)
+ }
+ if need := ^uint64(0xb9); regs.Rbp != need {
+ err = addRegisterMismatch(err, "Rbp", regs.Rbp, need)
+ }
+ if need := ^uint64(0xb4); regs.Rbx != need {
+ err = addRegisterMismatch(err, "Rbx", regs.Rbx, need)
+ }
+ if need := ^uint64(0x10); regs.R10 != need {
+ err = addRegisterMismatch(err, "R10", regs.R10, need)
+ }
+ if need := ^uint64(0x09); regs.R9 != need {
+ err = addRegisterMismatch(err, "R9", regs.R9, need)
+ }
+ if need := ^uint64(0x08); regs.R8 != need {
+ err = addRegisterMismatch(err, "R8", regs.R8, need)
+ }
+ if need := ^uint64(0x44); regs.Rax != need {
+ err = addRegisterMismatch(err, "Rax", regs.Rax, need)
+ }
+ if need := ^uint64(0xd4); regs.Rdx != need {
+ err = addRegisterMismatch(err, "Rdx", regs.Rdx, need)
+ }
+ if need := ^uint64(0x51); regs.Rsi != need {
+ err = addRegisterMismatch(err, "Rsi", regs.Rsi, need)
+ }
+ if need := ^uint64(0xd1); regs.Rdi != need {
+ err = addRegisterMismatch(err, "Rdi", regs.Rdi, need)
+ }
+ if need := ^uint64(0x59); regs.Rsp != need {
+ err = addRegisterMismatch(err, "Rsp", regs.Rsp, need)
+ }
+ // Rcx & R11 are ignored if !full is set.
+ if need := ^uint64(0x11); full && regs.R11 != need {
+ err = addRegisterMismatch(err, "R11", regs.R11, need)
+ }
+ if need := ^uint64(0xc4); full && regs.Rcx != need {
+ err = addRegisterMismatch(err, "Rcx", regs.Rcx, need)
+ }
+ return
+}
+
+var fsData uint64 = 0x55
+var gsData uint64 = 0x85
+
+// SetTestSegments initializes segments to known values.
+func SetTestSegments(regs *arch.Registers) {
+ regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer())
+ regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer())
+}
+
+// CheckTestSegments checks that registers were twiddled per TwiddleSegments.
+func CheckTestSegments(regs *arch.Registers) (err error) {
+ if regs.Rax != fsData {
+ err = addRegisterMismatch(err, "Rax", regs.Rax, fsData)
+ }
+ if regs.Rbx != gsData {
+ err = addRegisterMismatch(err, "Rbx", regs.Rcx, gsData)
+ }
+ return
+}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s
new file mode 100644
index 000000000..491ec0c2a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+// testutil_amd64.s provides AMD64 test functions.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+TEXT ·Getpid(SB),NOSPLIT,$0
+ NO_LOCAL_POINTERS
+ MOVQ $39, AX // getpid
+ SYSCALL
+ RET
+
+TEXT ·Touch(SB),NOSPLIT,$0
+start:
+ MOVQ 0(AX), BX // deref AX
+ MOVQ $39, AX // getpid
+ SYSCALL
+ JMP start
+
+TEXT ·HaltLoop(SB),NOSPLIT,$0
+start:
+ HLT
+ JMP start
+
+TEXT ·SyscallLoop(SB),NOSPLIT,$0
+start:
+ SYSCALL
+ JMP start
+
+TEXT ·SpinLoop(SB),NOSPLIT,$0
+start:
+ JMP start
+
+TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8
+ NO_LOCAL_POINTERS
+ MOVQ $1, AX
+ MOVQ AX, X0
+ MOVQ $39, AX // getpid
+ SYSCALL
+ MOVQ X0, AX
+ CMPQ AX, $1
+ SETEQ ret+0(FP)
+ RET
+
+#define TWIDDLE_REGS() \
+ NOTQ R15; \
+ NOTQ R14; \
+ NOTQ R13; \
+ NOTQ R12; \
+ NOTQ BP; \
+ NOTQ BX; \
+ NOTQ R11; \
+ NOTQ R10; \
+ NOTQ R9; \
+ NOTQ R8; \
+ NOTQ AX; \
+ NOTQ CX; \
+ NOTQ DX; \
+ NOTQ SI; \
+ NOTQ DI; \
+ NOTQ SP;
+
+TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ SYSCALL
+ RET // never reached
+
+TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ JMP AX // must fault
+ RET // never reached
+
+#define READ_FS() BYTE $0x64; BYTE $0x48; BYTE $0x8b; BYTE $0x00;
+#define READ_GS() BYTE $0x65; BYTE $0x48; BYTE $0x8b; BYTE $0x00;
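+
+// The raw bytes above encode instructions the Go assembler cannot express:
+// 0x64 and 0x65 are the FS and GS segment-override prefixes, 0x48 is REX.W,
+// and 0x8b 0x00 is MOVQ (AX), AX; so READ_FS is effectively
+// movq %fs:(%rax), %rax and READ_GS is movq %gs:(%rax), %rax.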
+
+TEXT ·TwiddleSegments(SB),NOSPLIT,$0
+ MOVQ $0x0, AX
+ READ_GS()
+ MOVQ AX, BX
+ MOVQ $0x0, AX
+ READ_FS()
+ SYSCALL
+ RET // never reached
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
new file mode 100644
index 000000000..ca902c8c1
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package testutil
+
+import (
+ "fmt"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// SetTestTarget sets the pc appropriately.
+func SetTestTarget(regs *arch.Registers, fn func()) {
+ regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
+}
+
+// SetTouchTarget sets the touch target register (R8) appropriately.
+func SetTouchTarget(regs *arch.Registers, target *uintptr) {
+ if target != nil {
+ regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer())
+ } else {
+ regs.Regs[8] = 0
+ }
+}
+
+// RewindSyscall rewinds the syscall pc.
+func RewindSyscall(regs *arch.Registers) {
+ regs.Pc -= 4
+}
+
+// SetTestRegs initializes registers to known values.
+func SetTestRegs(regs *arch.Registers) {
+ for i := 0; i <= 30; i++ {
+ regs.Regs[i] = uint64(i) + 1
+ }
+}
+
+// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
+func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
+ for i := 0; i <= 30; i++ {
+ if need := ^uint64(i + 1); regs.Regs[i] != need {
+ err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need)
+ }
+ }
+ return
+}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
new file mode 100644
index 000000000..0bebee852
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -0,0 +1,106 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+// testutil_arm64.s provides ARM64 test functions.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define SYS_GETPID 172
+
+// Getpid executes the getpid syscall.
+TEXT ·Getpid(SB),NOSPLIT,$0
+ NO_LOCAL_POINTERS
+ MOVD $SYS_GETPID, R8
+ SVC
+ RET
+
+TEXT ·Touch(SB),NOSPLIT,$0
+start:
+ MOVD 0(R8), R1
+ MOVD $SYS_GETPID, R8 // getpid
+ SVC
+ B start
+
+TEXT ·HaltLoop(SB),NOSPLIT,$0
+start:
+ HLT
+ B start
+
+// SyscallLoop executes syscalls in a loop.
+TEXT ·SyscallLoop(SB),NOSPLIT,$0
+start:
+ SVC
+ B start
+
+TEXT ·SpinLoop(SB),NOSPLIT,$0
+start:
+ B start
+
+TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8
+ NO_LOCAL_POINTERS
+ FMOVD $(9.9), F0
+ MOVD $SYS_GETPID, R8 // getpid
+ SVC
+ FMOVD $(9.9), F1
+ FCMPD F0, F1
+ BNE isNaN
+ MOVD $1, R0
+ MOVD R0, ret+0(FP)
+ RET
+isNaN:
+ MOVD $0, ret+0(FP)
+ RET
+
+// MVN: bitwise logical NOT.
+// This macro simulates an application that has modified R0-R30.
+#define TWIDDLE_REGS() \
+ MVN R0, R0; \
+ MVN R1, R1; \
+ MVN R2, R2; \
+ MVN R3, R3; \
+ MVN R4, R4; \
+ MVN R5, R5; \
+ MVN R6, R6; \
+ MVN R7, R7; \
+ MVN R8, R8; \
+ MVN R9, R9; \
+ MVN R10, R10; \
+ MVN R11, R11; \
+ MVN R12, R12; \
+ MVN R13, R13; \
+ MVN R14, R14; \
+ MVN R15, R15; \
+ MVN R16, R16; \
+ MVN R17, R17; \
+ MVN R18_PLATFORM, R18_PLATFORM; \
+ MVN R19, R19; \
+ MVN R20, R20; \
+ MVN R21, R21; \
+ MVN R22, R22; \
+ MVN R23, R23; \
+ MVN R24, R24; \
+ MVN R25, R25; \
+ MVN R26, R26; \
+ MVN R27, R27; \
+ MVN g, g; \
+ MVN R29, R29; \
+ MVN R30, R30;
+
+TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ SVC
+ RET // never reached
diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go
new file mode 100644
index 000000000..c8897d34f
--- /dev/null
+++ b/pkg/sentry/platform/kvm/virtual_map.go
@@ -0,0 +1,113 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "os"
+ "regexp"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type virtualRegion struct {
+ region
+ accessType usermem.AccessType
+ shared bool
+ offset uintptr
+ filename string
+}
+
+// mapsLine matches a single line from /proc/PID/maps.
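+//
+// An illustrative line it must match:
+//
+//	7f5b8c021000-7f5b8c042000 rw-p 00000000 08:01 1234567    /usr/lib/libc.so
+//
+// The captured groups are the start address, end address, permissions,
+// offset, and filename; the device and inode columns are matched but not
+// captured.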
+var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)")
+
+// excludeVirtualRegion returns true if these regions should be excluded from the
+// physical map. Virtual regions need to be excluded if get_user_pages will
+// fail on those addresses, preventing KVM from satisfying EPT faults.
+//
+// This includes the VVAR page, which may be mapped as I/O memory, and the
+// VDSO page, which is knocked out because the VVAR page is not even recorded
+// in /proc/self/maps on older kernels; excluding the VDSO prevents code in
+// the VDSO from accessing the VVAR address.
+//
+// This is called by the physical map functions, not applyVirtualRegions.
+func excludeVirtualRegion(r virtualRegion) bool {
+ return r.filename == "[vvar]" || r.filename == "[vdso]"
+}
+
+// applyVirtualRegions parses the process maps file.
+//
+// Unlike mappedRegions, these are not consistent over time.
+func applyVirtualRegions(fn func(vr virtualRegion)) error {
+ // Open /proc/self/maps.
+ f, err := os.Open("/proc/self/maps")
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ // Parse all entries.
+ r := bufio.NewReader(f)
+ for {
+ b, err := r.ReadBytes('\n')
+ if b != nil && len(b) > 0 {
+ m := mapsLine.FindSubmatch(b)
+ if m == nil {
+ // This should not happen: kernel bug?
+ return fmt.Errorf("badly formed line: %v", string(b))
+ }
+ start, err := strconv.ParseUint(string(m[1]), 16, 64)
+ if err != nil {
+ return fmt.Errorf("bad start address: %v", string(b))
+ }
+ end, err := strconv.ParseUint(string(m[2]), 16, 64)
+ if err != nil {
+ return fmt.Errorf("bad end address: %v", string(b))
+ }
+ read := m[3][0] == 'r'
+ write := m[3][1] == 'w'
+ execute := m[3][2] == 'x'
+ shared := m[3][3] == 's'
+ offset, err := strconv.ParseUint(string(m[4]), 16, 64)
+ if err != nil {
+ return fmt.Errorf("bad offset: %v", string(b))
+ }
+ fn(virtualRegion{
+ region: region{
+ virtual: uintptr(start),
+ length: uintptr(end - start),
+ },
+ accessType: usermem.AccessType{
+ Read: read,
+ Write: write,
+ Execute: execute,
+ },
+ shared: shared,
+ offset: uintptr(offset),
+ filename: string(m[5]),
+ })
+ }
+ if err != nil && err == io.EOF {
+ break
+ } else if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go
new file mode 100644
index 000000000..327e2be4f
--- /dev/null
+++ b/pkg/sentry/platform/kvm/virtual_map_test.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type checker struct {
+ ok bool
+ accessType usermem.AccessType
+}
+
+func (c *checker) Containing(addr uintptr) func(virtualRegion) {
+ c.ok = false // Reset for below calls.
+ return func(vr virtualRegion) {
+ if vr.virtual <= addr && addr < vr.virtual+vr.length {
+ c.ok = true
+ c.accessType = vr.accessType
+ }
+ }
+}
+
+func TestParseMaps(t *testing.T) {
+ c := new(checker)
+
+ // Simple test.
+ if err := applyVirtualRegions(c.Containing(0)); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // MMap a new page.
+ addr, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP, 0, usermem.PageSize,
+ syscall.PROT_READ|syscall.PROT_WRITE,
+ syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE, 0, 0)
+ if errno != 0 {
+ t.Fatalf("unexpected map error: %v", errno)
+ }
+
+ // Re-parse maps.
+ if err := applyVirtualRegions(c.Containing(addr)); err != nil {
+ syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0)
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Assert that it now does contain the region.
+ if !c.ok {
+ syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0)
+ t.Fatalf("updated map does not contain 0x%08x, expected true", addr)
+ }
+
+ // Map the region as PROT_NONE.
+ newAddr, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP, addr, usermem.PageSize,
+ syscall.PROT_NONE,
+ syscall.MAP_ANONYMOUS|syscall.MAP_FIXED|syscall.MAP_PRIVATE, 0, 0)
+ if errno != 0 {
+ t.Fatalf("unexpected map error: %v", errno)
+ }
+ if newAddr != addr {
+ t.Fatalf("unable to remap address: got 0x%08x, wanted 0x%08x", newAddr, addr)
+ }
+
+ // Re-parse maps.
+ if err := applyVirtualRegions(c.Containing(addr)); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !c.ok {
+ t.Fatalf("final map does not contain 0x%08x, expected true", addr)
+ }
+ if c.accessType.Read || c.accessType.Write || c.accessType.Execute {
+ t.Fatalf("final map has incorrect permissions for 0x%08x", addr)
+ }
+
+ // Unmap the region.
+ syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0)
+}
diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go
new file mode 100644
index 000000000..091c2e365
--- /dev/null
+++ b/pkg/sentry/platform/mmap_min_addr.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package platform
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// systemMMapMinAddrSource is the source file.
+const systemMMapMinAddrSource = "/proc/sys/vm/mmap_min_addr"
+
+// systemMMapMinAddr is the system's minimum map address.
+var systemMMapMinAddr uint64
+
+// SystemMMapMinAddr returns the minimum system address.
+func SystemMMapMinAddr() usermem.Addr {
+ return usermem.Addr(systemMMapMinAddr)
+}
+
+// MMapMinAddr is a size zero struct that implements MinUserAddress based on
+// the system minimum address. It is suitable for embedding in platforms that
+// rely on the system mmap, and thus require the system minimum.
+type MMapMinAddr struct {
+}
+
+// MinUserAddress implements platform.MinUserAddress.
+func (*MMapMinAddr) MinUserAddress() usermem.Addr {
+ return SystemMMapMinAddr()
+}
+
+func init() {
+ // Read the source file.
+ b, err := ioutil.ReadFile(systemMMapMinAddrSource)
+ if err != nil {
+ panic(fmt.Sprintf("couldn't open %s: %v", systemMMapMinAddrSource, err))
+ }
+
+ // Parse the result.
+ systemMMapMinAddr, err = strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
+ if err != nil {
+ panic(fmt.Sprintf("couldn't parse %s from %s: %v", string(b), systemMMapMinAddrSource, err))
+ }
+}
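+
+// On typical Linux systems this file contains a single small decimal value
+// such as "65536\n", which is why the TrimSpace above is needed before
+// parsing.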
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
new file mode 100644
index 000000000..171513f3f
--- /dev/null
+++ b/pkg/sentry/platform/platform.go
@@ -0,0 +1,398 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package platform provides a Platform abstraction.
+//
+// See Platform for more information.
+package platform
+
+import (
+ "fmt"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Platform provides abstractions for execution contexts (Context,
+// AddressSpace).
+type Platform interface {
+ // SupportsAddressSpaceIO returns true if AddressSpaces returned by this
+ // Platform support AddressSpaceIO methods.
+ //
+ // The value returned by SupportsAddressSpaceIO is guaranteed to remain
+ // unchanged over the lifetime of the Platform.
+ SupportsAddressSpaceIO() bool
+
+ // CooperativelySchedulesAddressSpace returns true if the Platform has a
+ // limited number of AddressSpaces, such that mm.MemoryManager.Deactivate
+ // should call AddressSpace.Release when there are no goroutines that
+ // require the mm.MemoryManager to have an active AddressSpace.
+ //
+ // The value returned by CooperativelySchedulesAddressSpace is guaranteed
+ // to remain unchanged over the lifetime of the Platform.
+ CooperativelySchedulesAddressSpace() bool
+
+ // DetectsCPUPreemption returns true if Contexts returned by the Platform
+ // can reliably return ErrContextCPUPreempted.
+ DetectsCPUPreemption() bool
+
+ // MapUnit returns the alignment used for optional mappings into this
+ // platform's AddressSpaces. Higher values indicate lower per-page costs
+ // for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates
+ // that the cost of AddressSpace.MapFile is effectively independent of the
+ // number of pages mapped. If MapUnit is non-zero, it must be a power-of-2
+ // multiple of usermem.PageSize.
+ MapUnit() uint64
+
+ // MinUserAddress returns the minimum mappable address on this
+ // platform.
+ MinUserAddress() usermem.Addr
+
+ // MaxUserAddress returns the maximum mappable address on this
+ // platform.
+ MaxUserAddress() usermem.Addr
+
+ // NewAddressSpace returns a new memory context for this platform.
+ //
+ // If mappingsID is not nil, the platform may assume that (1) all calls
+ // to NewAddressSpace with the same mappingsID represent the same
+ // (mutable) set of mappings, and (2) the set of mappings has not
+ // changed since the last time AddressSpace.Release was called on an
+ // AddressSpace returned by a call to NewAddressSpace with the same
+ // mappingsID.
+ //
+ // If a new AddressSpace cannot be created immediately, a nil
+ // AddressSpace is returned, along with channel that is closed when
+ // the caller should retry a call to NewAddressSpace.
+ //
+ // In general, this blocking behavior only occurs when
+ // CooperativelySchedulesAddressSpace (above) returns false.
+ NewAddressSpace(mappingsID interface{}) (AddressSpace, <-chan struct{}, error)
+
+ // NewContext returns a new execution context.
+ NewContext() Context
+
+ // PreemptAllCPUs causes all concurrent calls to Context.Switch(), as well
+ // as the first following call to Context.Switch() for each Context, to
+ // return ErrContextCPUPreempted.
+ //
+ // PreemptAllCPUs is only supported if DetectsCPUPreemption() == true.
+ // Platforms for which this does not hold may panic if PreemptAllCPUs is
+ // called.
+ PreemptAllCPUs() error
+
+ // SyscallFilters returns syscalls made exclusively by this platform.
+ SyscallFilters() seccomp.SyscallRules
+}
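+
+// An illustrative retry loop for NewAddressSpace (a sketch, not part of this
+// change):
+//
+//	for {
+//		as, c, err := p.NewAddressSpace(id)
+//		if err != nil {
+//			return err
+//		}
+//		if as != nil {
+//			return use(as) // hypothetical
+//		}
+//		<-c // Closed when a retry may succeed.
+//	}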
+
+// NoCPUPreemptionDetection implements Platform.DetectsCPUPreemption and
+// dependent methods for Platforms that do not support this feature.
+type NoCPUPreemptionDetection struct{}
+
+// DetectsCPUPreemption implements Platform.DetectsCPUPreemption.
+func (NoCPUPreemptionDetection) DetectsCPUPreemption() bool {
+ return false
+}
+
+// PreemptAllCPUs implements Platform.PreemptAllCPUs.
+func (NoCPUPreemptionDetection) PreemptAllCPUs() error {
+ panic("This platform does not support CPU preemption detection")
+}
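+
+// An illustrative embedding (a sketch, not part of this change):
+//
+//	type myPlatform struct { // hypothetical
+//		NoCPUPreemptionDetection
+//		// other fields...
+//	}
+//
+// satisfies both DetectsCPUPreemption and PreemptAllCPUs with no further
+// code.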
+
+// Context represents the execution context for a single thread.
+type Context interface {
+ // Switch resumes execution of the thread specified by the arch.Context
+ // in the provided address space. This call will block while the thread
+ // is executing.
+ //
+ // If cpu is non-negative, and it is not the number of the CPU that the
+ // thread executes on, Context should return ErrContextCPUPreempted. cpu
+ // can only be non-negative if Platform.DetectsCPUPreemption() is true;
+ // Contexts from Platforms for which this does not hold may ignore cpu, or
+ // panic if cpu is non-negative.
+ //
+ // Switch may return one of the following special errors:
+ //
+ // - nil: The Context invoked a system call.
+ //
+ // - ErrContextSignal: The Context was interrupted by a signal. The
+ // returned *arch.SignalInfo contains information about the signal. If
+ // arch.SignalInfo.Signo == SIGSEGV, the returned usermem.AccessType
+ // contains the access type of the triggering fault. The caller owns
+ // the returned SignalInfo.
+ //
+ // - ErrContextInterrupt: The Context was interrupted by a call to
+ // Interrupt(). Switch() may return ErrContextInterrupt spuriously. In
+ // particular, most implementations of Interrupt() will cause the first
+ // following call to Switch() to return ErrContextInterrupt if there is no
+ // concurrent call to Switch().
+ //
+ // - ErrContextCPUPreempted: See the definition of that error for details.
+ Switch(as AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error)
+
+ // Interrupt interrupts a concurrent call to Switch(), causing it to return
+ // ErrContextInterrupt.
+ Interrupt()
+
+ // Release() releases any resources associated with this context.
+ Release()
+}
+
+var (
+ // ErrContextSignal is returned by Context.Switch() to indicate that the
+ // Context was interrupted by a signal.
+ ErrContextSignal = fmt.Errorf("interrupted by signal")
+
+ // ErrContextSignalCPUID is equivalent to ErrContextSignal, except that
+ // a check should be done for execution of the CPUID instruction. If
+ // the current instruction pointer is a CPUID instruction, then this
+ // should be emulated appropriately. If not, then the given signal
+ // should be handled per above.
+ ErrContextSignalCPUID = fmt.Errorf("interrupted by signal, possible CPUID")
+
+ // ErrContextInterrupt is returned by Context.Switch() to indicate that the
+ // Context was interrupted by a call to Context.Interrupt().
+ ErrContextInterrupt = fmt.Errorf("interrupted by platform.Context.Interrupt()")
+
+ // ErrContextCPUPreempted is returned by Context.Switch() to indicate that
+ // one of the following occurred:
+ //
+ // - The CPU executing the Context is not the CPU passed to
+ // Context.Switch().
+ //
+ // - The CPU executing the Context may have executed another Context since
+ // the last time it executed this one; or the CPU has previously executed
+ // another Context, and has never executed this one.
+ //
+ // - Platform.PreemptAllCPUs() was called since the last return from
+ // Context.Switch().
+ ErrContextCPUPreempted = fmt.Errorf("interrupted by CPU preemption")
+)
+
+// SignalInterrupt is a signal reserved for use by implementations of
+// Context.Interrupt(). The sentry guarantees that it will ignore delivery of
+// this signal both to Contexts and to the sentry itself, under the assumption
+// that they originate from races with Context.Interrupt().
+//
+// NOTE(b/23420492): The Go runtime only guarantees that a small subset of
+// signals will always be unblocked on all threads, one of which is SIGCHLD.
+const SignalInterrupt = linux.SIGCHLD
+
+// AddressSpace represents a virtual address space in which a Context can
+// execute.
+type AddressSpace interface {
+ // MapFile creates a shared mapping of offsets fr from f at address addr.
+ // Any existing overlapping mappings are silently replaced.
+ //
+ // If precommit is true, the platform should eagerly commit resources (e.g.
+ // physical memory) to the mapping. The precommit flag is advisory and
+ // implementations may choose to ignore it.
+ //
+ // Preconditions: addr and fr must be page-aligned. fr.Length() > 0.
+ // at.Any() == true. At least one reference must be held on all pages in
+ // fr, and must continue to be held as long as pages are mapped.
+ MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error
+
+ // Unmap unmaps the given range.
+ //
+ // Preconditions: addr is page-aligned. length > 0.
+ Unmap(addr usermem.Addr, length uint64)
+
+ // Release releases this address space. After releasing, a new AddressSpace
+ // must be acquired via platform.NewAddressSpace().
+ Release()
+
+ // AddressSpaceIO methods are supported iff the associated platform's
+ // Platform.SupportsAddressSpaceIO() == true. AddressSpaces for which this
+ // does not hold may panic if AddressSpaceIO methods are invoked.
+ AddressSpaceIO
+}
+
+// AddressSpaceIO supports IO through the memory mappings installed in an
+// AddressSpace.
+//
+// AddressSpaceIO implementors are responsible for ensuring that address ranges
+// are application-mappable.
+type AddressSpaceIO interface {
+ // CopyOut copies len(src) bytes from src to the memory mapped at addr. It
+ // returns the number of bytes copied. If the number of bytes copied is <
+ // len(src), it returns a non-nil error explaining why.
+ CopyOut(addr usermem.Addr, src []byte) (int, error)
+
+ // CopyIn copies len(dst) bytes from the memory mapped at addr to dst.
+ // It returns the number of bytes copied. If the number of bytes copied is
+ // < len(dst), it returns a non-nil error explaining why.
+ CopyIn(addr usermem.Addr, dst []byte) (int, error)
+
+ // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number
+ // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a
+ // non-nil error explaining why.
+ ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error)
+
+ // SwapUint32 atomically sets the uint32 value at addr to new and returns
+ // the previous value.
+ //
+ // Preconditions: addr must be aligned to a 4-byte boundary.
+ SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
+
+ // CompareAndSwapUint32 atomically compares the uint32 value at addr to
+ // old; if they are equal, the value in memory is replaced by new. In
+ // either case, the previous value stored in memory is returned.
+ //
+ // Preconditions: addr must be aligned to a 4-byte boundary.
+ CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error)
+
+ // LoadUint32 atomically loads the uint32 value at addr and returns it.
+ //
+ // Preconditions: addr must be aligned to a 4-byte boundary.
+ LoadUint32(addr usermem.Addr) (uint32, error)
+}
+
+// NoAddressSpaceIO implements AddressSpaceIO methods by panicking.
+type NoAddressSpaceIO struct{}
+
+// CopyOut implements AddressSpaceIO.CopyOut.
+func (NoAddressSpaceIO) CopyOut(addr usermem.Addr, src []byte) (int, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// CopyIn implements AddressSpaceIO.CopyIn.
+func (NoAddressSpaceIO) CopyIn(addr usermem.Addr, dst []byte) (int, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// ZeroOut implements AddressSpaceIO.ZeroOut.
+func (NoAddressSpaceIO) ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// SwapUint32 implements AddressSpaceIO.SwapUint32.
+func (NoAddressSpaceIO) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// CompareAndSwapUint32 implements AddressSpaceIO.CompareAndSwapUint32.
+func (NoAddressSpaceIO) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// LoadUint32 implements AddressSpaceIO.LoadUint32.
+func (NoAddressSpaceIO) LoadUint32(addr usermem.Addr) (uint32, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// SegmentationFault is an error returned by AddressSpaceIO methods when IO
+// fails due to access of an unmapped page, or a mapped page with insufficient
+// permissions.
+type SegmentationFault struct {
+ // Addr is the address at which the fault occurred.
+ Addr usermem.Addr
+}
+
+// Error implements error.Error.
+func (f SegmentationFault) Error() string {
+ return fmt.Sprintf("segmentation fault at %#x", f.Addr)
+}
+
+// File represents a host file that may be mapped into an AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+}
+
+// Requirements is used to specify platform specific requirements.
+type Requirements struct {
+ // RequiresCurrentPIDNS indicates that the sandbox has to be started in the
+ // current pid namespace.
+ RequiresCurrentPIDNS bool
+ // RequiresCapSysPtrace indicates that the sandbox has to be started with
+ // the CAP_SYS_PTRACE capability.
+ RequiresCapSysPtrace bool
+}
+
+// Constructor represents a platform type.
+type Constructor interface {
+ // New returns a new platform instance.
+ //
+ // Arguments:
+ //
+ // * deviceFile - the device file (e.g. /dev/kvm for the KVM platform).
+ New(deviceFile *os.File) (Platform, error)
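+
+ // OpenDevice opens the platform's device file, if any (e.g. /dev/kvm
+ // for the KVM platform).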
+ OpenDevice() (*os.File, error)
+
+ // Requirements returns platform specific requirements.
+ Requirements() Requirements
+}
+
+// platforms contains all available platform types.
+var platforms = map[string]Constructor{}
+
+// Register registers a new platform type.
+func Register(name string, platform Constructor) {
+ platforms[name] = platform
+}
+
+// Lookup looks up the platform constructor by name.
+func Lookup(name string) (Constructor, error) {
+ p, ok := platforms[name]
+ if !ok {
+ return nil, fmt.Errorf("unknown platform: %v", name)
+ }
+ return p, nil
+}
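+
+// An illustrative registration (a sketch, not part of this change): platform
+// packages typically register a Constructor from an init function,
+//
+//	func init() {
+//		platform.Register("myplatform", &constructor{}) // hypothetical
+//	}
+//
+// and runners resolve it later with Lookup("myplatform").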
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
new file mode 100644
index 000000000..30402c2df
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ptrace",
+ srcs = [
+ "filters.go",
+ "ptrace.go",
+ "ptrace_amd64.go",
+ "ptrace_arm64.go",
+ "ptrace_arm64_unsafe.go",
+ "ptrace_unsafe.go",
+ "stub_amd64.s",
+ "stub_arm64.s",
+ "stub_unsafe.go",
+ "subprocess.go",
+ "subprocess_amd64.go",
+ "subprocess_arm64.go",
+ "subprocess_linux.go",
+ "subprocess_linux_unsafe.go",
+ "subprocess_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/log",
+ "//pkg/procid",
+ "//pkg/safecopy",
+ "//pkg/seccomp",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/hostcpu",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/interrupt",
+ "//pkg/sync",
+ "//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go
new file mode 100644
index 000000000..1e07cfd0d
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/filters.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// SyscallFilters returns syscalls made exclusively by the ptrace platform.
+func (*PTrace) SyscallFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ unix.SYS_GETCPU: {},
+ unix.SYS_SCHED_SETAFFINITY: {},
+ syscall.SYS_PTRACE: {},
+ syscall.SYS_TGKILL: {},
+ syscall.SYS_WAIT4: {},
+ }
+}
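+
+// These rules are meant to be merged into the Sentry's overall filter set
+// and installed once at startup; a rough sketch, assuming pkg/seccomp
+// exposes an Install helper taking a SyscallRules set:
+//
+//  rules := (&PTrace{}).SyscallFilters()
+//  // ... merge with the Sentry's other rules ...
+//  if err := seccomp.Install(rules); err != nil {
+//      panic(err)
+//  }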
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
new file mode 100644
index 000000000..08d055e05
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -0,0 +1,266 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ptrace provides a ptrace-based implementation of the platform
+// interface. This is useful for development and testing purposes primarily,
+// and runs on stock kernels without special permissions.
+//
+// In a nutshell, it works as follows:
+//
+// The creation of a new address space creates a new child process with a
+// single thread, which is traced by a single goroutine.
+//
+// A context is just a collection of temporary variables. Calling Switch on a
+// context does the following:
+//
+// Locks the runtime thread.
+//
+// Looks up a traced subprocess thread for the current runtime thread. If
+// none exists, the dedicated goroutine is asked to create a new stopped
+// thread in the subprocess. This stopped subprocess thread is then traced
+// by the current thread and this information is stored for subsequent
+// switches.
+//
+// The context is then bound with information about the subprocess thread
+// so that the context may be appropriately interrupted via a signal.
+//
+// The requested operation is performed in the traced subprocess thread
+// (e.g. set registers, execute, return).
+//
+// Lock order:
+//
+// subprocess.mu
+// context.mu
+package ptrace
+
+import (
+ "os"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // stubStart is the link address for our stub, and determines the
+ // maximum user address. This is valid only after a call to stubInit.
+ //
+ // We attempt to link the stub here, and adjust downward as needed.
+ stubStart uintptr = stubInitAddress
+
+ // stubEnd is the first byte past the end of the stub; as with
+ // stubStart, this is valid only after a call to stubInit.
+ stubEnd uintptr
+
+ // stubInitialized controls one-time stub initialization.
+ stubInitialized sync.Once
+)
+
+type context struct {
+ // signalInfo is the signal info, if and when a signal is received.
+ signalInfo arch.SignalInfo
+
+ // interrupt is the interrupt context.
+ interrupt interrupt.Forwarder
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // If lastFaultSP is non-nil, the last context switch was due to a fault
+ // received while executing lastFaultSP. Only context.Switch may set
+ // lastFaultSP to a non-nil value.
+ lastFaultSP *subprocess
+
+ // lastFaultAddr is the last faulting address; this is only meaningful if
+ // lastFaultSP is non-nil.
+ lastFaultAddr usermem.Addr
+
+ // lastFaultIP is the address of the last faulting instruction;
+ // this is also only meaningful if lastFaultSP is non-nil.
+ lastFaultIP usermem.Addr
+}
+
+// Switch runs the provided context in the given address space.
+func (c *context) Switch(as platform.AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) {
+ s := as.(*subprocess)
+ isSyscall := s.switchToApp(c, ac)
+
+ var (
+ faultSP *subprocess
+ faultAddr usermem.Addr
+ faultIP usermem.Addr
+ )
+ if !isSyscall && linux.Signal(c.signalInfo.Signo) == linux.SIGSEGV {
+ faultSP = s
+ faultAddr = usermem.Addr(c.signalInfo.Addr())
+ faultIP = usermem.Addr(ac.IP())
+ }
+
+ // Update the context to reflect the outcome of this context switch.
+ c.mu.Lock()
+ lastFaultSP := c.lastFaultSP
+ lastFaultAddr := c.lastFaultAddr
+ lastFaultIP := c.lastFaultIP
+ // At this point, c may not yet be in s.contexts, so c.lastFaultSP won't be
+ // updated by s.Unmap(). This is fine; we only need to synchronize with
+ // calls to s.Unmap() that occur after the handling of this fault.
+ c.lastFaultSP = faultSP
+ c.lastFaultAddr = faultAddr
+ c.lastFaultIP = faultIP
+ c.mu.Unlock()
+
+ // Update subprocesses to reflect the outcome of this context switch.
+ if lastFaultSP != faultSP {
+ if lastFaultSP != nil {
+ lastFaultSP.mu.Lock()
+ delete(lastFaultSP.contexts, c)
+ lastFaultSP.mu.Unlock()
+ }
+ if faultSP != nil {
+ faultSP.mu.Lock()
+ faultSP.contexts[c] = struct{}{}
+ faultSP.mu.Unlock()
+ }
+ }
+
+ if isSyscall {
+ return nil, usermem.NoAccess, nil
+ }
+
+ si := c.signalInfo
+
+ if faultSP == nil {
+ // Non-fault signal.
+ return &si, usermem.NoAccess, platform.ErrContextSignal
+ }
+
+ // Got a page fault. Ideally, we'd get the real fault type here, but ptrace
+ // doesn't expose this information. Instead, we use a simple heuristic:
+ //
+ // It was an instruction fault iff the faulting addr == instruction
+ // pointer.
+ //
+ // It was a write fault if the fault is immediately repeated.
+ at := usermem.Read
+ if faultAddr == faultIP {
+ at.Execute = true
+ }
+ if lastFaultSP == faultSP &&
+ lastFaultAddr == faultAddr &&
+ lastFaultIP == faultIP {
+ at.Write = true
+ }
+
+ // Unfortunately, we have to unilaterally return ErrContextSignalCPUID
+ // here, in case this fault was generated by a CPUID exception. There
+ // is no way to distinguish between CPUID-generated faults and regular
+ // page faults.
+ return &si, at, platform.ErrContextSignalCPUID
+}
+
+// Interrupt interrupts the running guest application associated with this context.
+func (c *context) Interrupt() {
+ c.interrupt.NotifyInterrupt()
+}
+
+// Release implements platform.Context.Release().
+func (c *context) Release() {}
+
+// PTrace represents a collection of ptrace subprocesses.
+type PTrace struct {
+ platform.MMapMinAddr
+ platform.NoCPUPreemptionDetection
+}
+
+// New returns a new ptrace-based implementation of the platform interface.
+func New() (*PTrace, error) {
+ stubInitialized.Do(func() {
+ // Initialize the stub.
+ stubInit()
+
+ // Create the master process for the global pool. This must be
+ // done before initializing any other processes.
+ master, err := newSubprocess(createStub)
+ if err != nil {
+ // Should never happen.
+ panic("unable to initialize ptrace master: " + err.Error())
+ }
+
+ // Set the master on the globalPool.
+ globalPool.master = master
+ })
+
+ return &PTrace{}, nil
+}
+
+// SupportsAddressSpaceIO implements platform.Platform.SupportsAddressSpaceIO.
+func (*PTrace) SupportsAddressSpaceIO() bool {
+ return false
+}
+
+// CooperativelySchedulesAddressSpace implements platform.Platform.CooperativelySchedulesAddressSpace.
+func (*PTrace) CooperativelySchedulesAddressSpace() bool {
+ return false
+}
+
+// MapUnit implements platform.Platform.MapUnit.
+func (*PTrace) MapUnit() uint64 {
+ // The host kernel manages page tables and arbitrary-sized mappings
+ // have effectively the same cost.
+ return 0
+}
+
+// MaxUserAddress returns the first address that may not be used by user
+// applications.
+func (*PTrace) MaxUserAddress() usermem.Addr {
+ return usermem.Addr(stubStart)
+}
+
+// NewAddressSpace returns a new subprocess.
+func (p *PTrace) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
+ as, err := newSubprocess(globalPool.master.createStub)
+ return as, nil, err
+}
+
+// NewContext returns an interruptible context.
+func (*PTrace) NewContext() platform.Context {
+ return &context{}
+}
+
+type constructor struct{}
+
+func (*constructor) New(*os.File) (platform.Platform, error) {
+ return New()
+}
+
+func (*constructor) OpenDevice() (*os.File, error) {
+ return nil, nil
+}
+
+// Requirements implements platform.Constructor.Requirements.
+func (*constructor) Requirements() platform.Requirements {
+ // TODO(b/75837838): Also set a new PID namespace so that we limit
+ // access to other host processes.
+ return platform.Requirements{
+ RequiresCapSysPtrace: true,
+ RequiresCurrentPIDNS: true,
+ }
+}
+
+func init() {
+ platform.Register("ptrace", &constructor{})
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go
new file mode 100644
index 000000000..3b9a870a5
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
+func fpRegSet(useXsave bool) uintptr {
+ if useXsave {
+ return linux.NT_X86_XSTATE
+ }
+ return linux.NT_PRFPREG
+}
+
+func stackPointer(r *arch.Registers) uintptr {
+ return uintptr(r.Rsp)
+}
+
+// x86 uses the fs_base register to store the TLS pointer, which is read and
+// written together with the general purpose registers in
+// "func (t *thread) getRegs/setRegs(regs *arch.Registers)".
+// Both getTLS() and setTLS() are therefore no-ops here.
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ return nil
+}
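+
+// For illustration: on amd64 the TLS pointer travels with the general
+// purpose register set, so it can be read from there after getRegs
+// (sketch):
+//
+//  var regs arch.Registers
+//  if err := t.getRegs(&regs); err == nil {
+//      tls := regs.Fs_base // The TLS pointer on x86-64.
+//      _ = tls
+//  }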
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64.go b/pkg/sentry/platform/ptrace/ptrace_arm64.go
new file mode 100644
index 000000000..5c869926a
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
+func fpRegSet(_ bool) uintptr {
+ return linux.NT_PRFPREG
+}
+
+func stackPointer(r *arch.Registers) uintptr {
+ return uintptr(r.Sp)
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
new file mode 100644
index 000000000..32b8a6be9
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ptrace
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
new file mode 100644
index 000000000..8b72d24e8
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// getRegs gets the general purpose register set.
+func (t *thread) getRegs(regs *arch.Registers) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(regs)),
+ Len: uint64(unsafe.Sizeof(*regs)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ linux.NT_PRSTATUS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setRegs sets the general purpose register set.
+func (t *thread) setRegs(regs *arch.Registers) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(regs)),
+ Len: uint64(unsafe.Sizeof(*regs)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ linux.NT_PRSTATUS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// getFPRegs gets the floating-point data via the GETREGSET ptrace syscall.
+func (t *thread) getFPRegs(fpState *arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(fpState),
+ Len: fpLen,
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ fpRegSet(useXsave),
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setFPRegs sets the floating-point data via the SETREGSET ptrace syscall.
+func (t *thread) setFPRegs(fpState *arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(fpState),
+ Len: fpLen,
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ fpRegSet(useXsave),
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// getSignalInfo retrieves information about the signal that caused the stop.
+func (t *thread) getSignalInfo(si *arch.SignalInfo) error {
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETSIGINFO,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(si)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// clone creates a new thread from this one.
+//
+// The returned thread will be stopped and available for any system thread to
+// call attach on it.
+//
+// Precondition: the OS thread must be locked and own t.
+func (t *thread) clone() (*thread, error) {
+ r, ok := usermem.Addr(stackPointer(&t.initRegs)).RoundUp()
+ if !ok {
+ return nil, syscall.EINVAL
+ }
+ rval, err := t.syscallIgnoreInterrupt(
+ &t.initRegs,
+ syscall.SYS_CLONE,
+ arch.SyscallArgument{Value: uintptr(
+ syscall.CLONE_FILES |
+ syscall.CLONE_FS |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_THREAD |
+ syscall.CLONE_PTRACE |
+ syscall.CLONE_VM)},
+ // The stack pointer is just made up, but we have it be
+ // something sensible so the kernel doesn't think we're
+ // up to no good. Which we are.
+ arch.SyscallArgument{Value: uintptr(r)},
+ arch.SyscallArgument{},
+ arch.SyscallArgument{},
+ // We use these registers initially, but really they
+ // could be anything. We're going to stop immediately.
+ arch.SyscallArgument{Value: uintptr(unsafe.Pointer(&t.initRegs))})
+ if err != nil {
+ return nil, err
+ }
+
+ return &thread{
+ tgid: t.tgid,
+ tid: int32(rval),
+ cpu: ^uint32(0),
+ }, nil
+}
+
+// getEventMessage retrieves a message about the ptrace event that just happened.
+func (t *thread) getEventMessage() (uintptr, error) {
+ var msg uintptr
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETEVENTMSG,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(&msg)),
+ 0, 0)
+ if errno != 0 {
+ return msg, errno
+ }
+ return msg, nil
+}
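+
+// For example, on a PTRACE_EVENT_EXIT stop the event message carries the
+// thread's wait status, which can be decoded as in unexpectedStubExit
+// (subprocess.go):
+//
+//  msg, err := t.getEventMessage()
+//  status := syscall.WaitStatus(msg)
+//  if status.Signaled() {
+//      // The thread was killed by status.Signal().
+//  }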
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
new file mode 100644
index 000000000..16f9c523e
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -0,0 +1,119 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define SYS_GETPID 39
+#define SYS_EXIT 60
+#define SYS_KILL 62
+#define SYS_GETPPID 110
+#define SYS_PRCTL 157
+
+#define SIGKILL 9
+#define SIGSTOP 19
+
+#define PR_SET_PDEATHSIG 1
+
+// stub bootstraps the child and sends itself SIGSTOP to wait for attach.
+//
+// R15 contains the expected PPID. R15 is used instead of a more typical DI
+// since syscalls will clobber DI and createStub wants to pass a new PPID to
+// grandchildren.
+//
+// This should not be used outside the context of a new ptrace child (as the
+// function is otherwise a bunch of nonsense).
+TEXT ·stub(SB),NOSPLIT,$0
+begin:
+ // N.B. This loop only executes in the context of a single-threaded
+ // fork child.
+
+ MOVQ $SYS_PRCTL, AX
+ MOVQ $PR_SET_PDEATHSIG, DI
+ MOVQ $SIGKILL, SI
+ SYSCALL
+
+ CMPQ AX, $0
+ JNE error
+
+ // If the parent already died before we called PR_SET_PDEATHSIG then
+ // we'll have an unexpected PPID.
+ MOVQ $SYS_GETPPID, AX
+ SYSCALL
+
+ CMPQ AX, $0
+ JL error
+
+ CMPQ AX, R15
+ JNE parent_dead
+
+ MOVQ $SYS_GETPID, AX
+ SYSCALL
+
+ CMPQ AX, $0
+ JL error
+
+ MOVQ $0, BX
+
+ // SIGSTOP to wait for attach.
+ //
+ // The SYSCALL instruction will be used for future syscall injection by
+ // thread.syscall.
+ MOVQ AX, DI
+ MOVQ $SYS_KILL, AX
+ MOVQ $SIGSTOP, SI
+ SYSCALL
+
+ // The sentry sets BX to 1 when creating a stub process.
+ CMPQ BX, $1
+ JE clone
+
+ // Notify the Sentry that the syscall exited.
+done:
+ INT $3
+ // Be paranoid.
+ JMP done
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
+ // thus executing this code. We set up the PDEATHSIG before SIGSTOPing
+ // ourselves for attach by the tracer.
+ //
+ // R15 has been updated with the expected PPID.
+ CMPQ AX, $0
+ JE begin
+
+ // The clone syscall returns a non-zero value.
+ JMP done
+error:
+ // Exit with -errno.
+ MOVQ AX, DI
+ NEGQ DI
+ MOVQ $SYS_EXIT, AX
+ SYSCALL
+ HLT
+
+parent_dead:
+ MOVQ $SYS_EXIT, AX
+ MOVQ $1, DI
+ SYSCALL
+ HLT
+
+// stubCall calls the stub function at the given address with the given PPID.
+//
+// This is a distinct function because stub, above, may be mapped at any
+// arbitrary location, and stub has a specific binary API (see above).
+TEXT ·stubCall(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), AX
+ MOVQ pid+8(FP), R15
+ JMP AX
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
new file mode 100644
index 000000000..6162df02a
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -0,0 +1,112 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define SYS_GETPID 172
+#define SYS_EXIT 93
+#define SYS_KILL 129
+#define SYS_GETPPID 173
+#define SYS_PRCTL 167
+
+#define SIGKILL 9
+#define SIGSTOP 19
+
+#define PR_SET_PDEATHSIG 1
+
+// stub bootstraps the child and sends itself SIGSTOP to wait for attach.
+//
+// R7 contains the expected PPID.
+//
+// This should not be used outside the context of a new ptrace child (as the
+// function is otherwise a bunch of nonsense).
+TEXT ·stub(SB),NOSPLIT,$0
+begin:
+ // N.B. This loop only executes in the context of a single-threaded
+ // fork child.
+
+ MOVD $SYS_PRCTL, R8
+ MOVD $PR_SET_PDEATHSIG, R0
+ MOVD $SIGKILL, R1
+ SVC
+
+ CMN $4095, R0
+ BCS error
+
+ // If the parent already died before we called PR_SET_PDEATHSIG then
+ // we'll have an unexpected PPID.
+ MOVD $SYS_GETPPID, R8
+ SVC
+
+ CMP R0, R7
+ BNE parent_dead
+
+ MOVD $SYS_GETPID, R8
+ SVC
+
+ CMP $0x0, R0
+ BLT error
+
+ MOVD $0, R9
+
+ // SIGSTOP to wait for attach.
+ //
+ // The SVC instruction will be used for future syscall injection by
+ // thread.syscall.
+ MOVD $SYS_KILL, R8
+ MOVD $SIGSTOP, R1
+ SVC
+
+ // The sentry sets R9 to 1 when creating a stub process.
+ CMP $1, R9
+ BEQ clone
+
+done:
+ // Notify the Sentry that the syscall exited.
+ BRK $3
+ B done // Be paranoid.
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
+ // thus executing this code. We set up the PDEATHSIG before SIGSTOPing
+ // ourselves for attach by the tracer.
+ //
+ // R7 has been updated with the expected PPID.
+ CMP $0, R0
+ BEQ begin
+
+ // The clone system call returned a non-zero value.
+ B done
+
+error:
+ // Exit with -errno.
+ NEG R0, R0
+ MOVD $SYS_EXIT, R8
+ SVC
+ HLT
+
+parent_dead:
+ MOVD $SYS_EXIT, R8
+ MOVD $1, R0
+ SVC
+ HLT
+
+// stubCall calls the stub function at the given address with the given PPID.
+//
+// This is a distinct function because stub, above, may be mapped at any
+// arbitrary location, and stub has a specific binary API (see above).
+TEXT ·stubCall(SB),NOSPLIT,$0-16
+ MOVD addr+0(FP), R0
+ MOVD pid+8(FP), R7
+ B (R0)
diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go
new file mode 100644
index 000000000..341dde143
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/stub_unsafe.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// stub is defined in arch-specific assembly.
+func stub()
+
+// stubCall calls the stub at the given address with the given pid.
+func stubCall(addr, pid uintptr)
+
+// unsafeSlice returns a slice for the given address and length.
+func unsafeSlice(addr uintptr, length int) (slice []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ sh.Data = addr
+ sh.Len = length
+ sh.Cap = length
+ return
+}
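+
+// Usage note: unsafeSlice lets stubInit treat raw mapped memory as an
+// ordinary []byte for copy(). On newer toolchains (Go 1.17+, an assumption
+// about the build environment) the equivalent would be
+// unsafe.Slice((*byte)(unsafe.Pointer(addr)), length); the
+// reflect.SliceHeader form is kept for compatibility with older toolchains.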
+
+// stubInit initializes the stub.
+func stubInit() {
+ // Grab the existing stub.
+ stubBegin := reflect.ValueOf(stub).Pointer()
+ stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin)
+ stubSlice := unsafeSlice(stubBegin, stubLen)
+ mapLen := uintptr(stubLen)
+ if offset := mapLen % usermem.PageSize; offset != 0 {
+ mapLen += usermem.PageSize - offset
+ }
+
+ for stubStart > 0 {
+ // Map the target address for the stub.
+ //
+ // We don't use FIXED here because we don't want to unmap
+ // something that may have been there already. We just walk
+ // down the address space until we find a place where the stub
+ // can be placed.
+ addr, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP,
+ stubStart,
+ mapLen,
+ syscall.PROT_WRITE|syscall.PROT_READ,
+ syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS,
+ 0 /* fd */, 0 /* offset */)
+ if addr != stubStart || errno != 0 {
+ if addr != 0 {
+ // Unmap the region we've mapped accidentally.
+ syscall.RawSyscall(syscall.SYS_MUNMAP, addr, mapLen, 0)
+ }
+
+ // Attempt to begin at a lower address.
+ stubStart -= uintptr(usermem.PageSize)
+ continue
+ }
+
+ // Copy the stub to the address.
+ targetSlice := unsafeSlice(addr, stubLen)
+ copy(targetSlice, stubSlice)
+
+ // Make the stub executable.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_MPROTECT,
+ stubStart,
+ mapLen,
+ syscall.PROT_EXEC|syscall.PROT_READ); errno != 0 {
+ panic("mprotect failed: " + errno.Error())
+ }
+
+ // Set the end.
+ stubEnd = stubStart + mapLen
+ return
+ }
+
+ // This will happen only if we exhaust the entire address
+ // space, and it will take a long, long time.
+ panic("failed to map stub")
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
new file mode 100644
index 000000000..2389423b0
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -0,0 +1,663 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/procid"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Linux kernel errnos which "should never be seen by user programs", but will
+// be revealed to ptrace syscall exit tracing.
+//
+// These constants are only used in subprocess.go.
+const (
+ ERESTARTSYS = syscall.Errno(512)
+ ERESTARTNOINTR = syscall.Errno(513)
+ ERESTARTNOHAND = syscall.Errno(514)
+)
+
+// globalPool exists to solve two distinct problems:
+//
+// 1) Subprocesses can't always be killed properly (see Release).
+//
+// 2) Any seccomp filters that have been installed will apply to subprocesses
+// created here. Therefore we use the intermediary (master), which is created
+// on initialization of the platform.
+var globalPool struct {
+ mu sync.Mutex
+ master *subprocess
+ available []*subprocess
+}
+
+// thread is a traced thread; it is a thread identifier.
+//
+// This is a convenience type for defining ptrace operations.
+type thread struct {
+ tgid int32
+ tid int32
+ cpu uint32
+
+ // initRegs are the initial registers for the first thread.
+ //
+ // These are used for the register set for system calls.
+ initRegs arch.Registers
+}
+
+// threadPool is a collection of threads.
+type threadPool struct {
+ // mu protects below.
+ mu sync.Mutex
+
+ // threads is the collection of threads.
+ //
+ // This map is indexed by system TID (the calling thread); which will
+ // be the tracer for the given *thread, and therefore capable of using
+ // relevant ptrace calls.
+ threads map[int32]*thread
+}
+
+// lookupOrCreate looks up a given thread or creates one.
+//
+// newThread will generally be subprocess.newThread.
+//
+// Precondition: the runtime OS thread must be locked.
+func (tp *threadPool) lookupOrCreate(currentTID int32, newThread func() *thread) *thread {
+ tp.mu.Lock()
+ t, ok := tp.threads[currentTID]
+ if !ok {
+ // Before creating a new thread, see if we can find a thread
+ // whose system tid has disappeared.
+ //
+ // TODO(b/77216482): Other parts of this package depend on
+ // threads never exiting.
+ for origTID, t := range tp.threads {
+ // Signal zero is an easy existence check.
+ if err := syscall.Tgkill(syscall.Getpid(), int(origTID), 0); err != nil {
+ // This thread has been abandoned; reuse it.
+ delete(tp.threads, origTID)
+ tp.threads[currentTID] = t
+ tp.mu.Unlock()
+ return t
+ }
+ }
+
+ // Create a new thread.
+ t = newThread()
+ tp.threads[currentTID] = t
+ }
+ tp.mu.Unlock()
+ return t
+}
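+
+// Typical usage, as in switchToApp and subprocess.syscall below: the key is
+// the calling OS thread's TID, so that the returned thread's tracer is
+// always the current thread:
+//
+//  currentTID := int32(procid.Current())
+//  t := s.sysemuThreads.lookupOrCreate(currentTID, s.newThread)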
+
+// subprocess is a collection of threads being traced.
+type subprocess struct {
+ platform.NoAddressSpaceIO
+
+ // requests is used to signal creation of new threads.
+ requests chan chan *thread
+
+ // sysemuThreads are reserved for emulation.
+ sysemuThreads threadPool
+
+ // syscallThreads are reserved for syscalls (except clone, which is
+ // handled in the dedicated goroutine corresponding to requests above).
+ syscallThreads threadPool
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // contexts is the set of contexts for which it's possible that
+ // context.lastFaultSP == this subprocess.
+ contexts map[*context]struct{}
+}
+
+// newSubprocess returns a usable subprocess.
+//
+// This will either be a newly created subprocess, or one from the global pool.
+// The create function will be called in the latter case, which is guaranteed
+// to happen with the runtime thread locked.
+func newSubprocess(create func() (*thread, error)) (*subprocess, error) {
+ // See Release.
+ globalPool.mu.Lock()
+ if len(globalPool.available) > 0 {
+ sp := globalPool.available[len(globalPool.available)-1]
+ globalPool.available = globalPool.available[:len(globalPool.available)-1]
+ globalPool.mu.Unlock()
+ return sp, nil
+ }
+ globalPool.mu.Unlock()
+
+ // The following goroutine is responsible for creating the first traced
+ // thread, and responding to requests to make additional threads in the
+ // traced process. Note that the requests channel is never closed; Release
+ // below returns the subprocess to the global pool rather than tearing it
+ // down.
+ errChan := make(chan error)
+ requests := make(chan chan *thread)
+ go func() { // S/R-SAFE: Platform-related.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ // Initialize the first thread.
+ firstThread, err := create()
+ if err != nil {
+ errChan <- err
+ return
+ }
+ firstThread.grabInitRegs()
+
+ // Ready to handle requests.
+ errChan <- nil
+
+ // Wait for requests to create threads.
+ for r := range requests {
+ t, err := firstThread.clone()
+ if err != nil {
+ // Should not happen: not recoverable.
+ panic(fmt.Sprintf("error initializing first thread: %v", err))
+ }
+
+ // Since the new thread was created with
+ // clone(CLONE_PTRACE), it will begin execution with
+ // SIGSTOP pending and with this thread as its tracer.
+ // (Hopefully nobody tgkilled it with a signal <
+ // SIGSTOP before the SIGSTOP was delivered, in which
+ // case that signal would be delivered before SIGSTOP.)
+ if sig := t.wait(stopped); sig != syscall.SIGSTOP {
+ panic(fmt.Sprintf("error waiting for new clone: expected SIGSTOP, got %v", sig))
+ }
+
+ // Detach the thread.
+ t.detach()
+ t.initRegs = firstThread.initRegs
+
+ // Return the thread.
+ r <- t
+ }
+
+ // Requests should never be closed.
+ panic("unreachable")
+ }()
+
+ // Wait until error or readiness.
+ if err := <-errChan; err != nil {
+ return nil, err
+ }
+
+ // Ready.
+ sp := &subprocess{
+ requests: requests,
+ sysemuThreads: threadPool{
+ threads: make(map[int32]*thread),
+ },
+ syscallThreads: threadPool{
+ threads: make(map[int32]*thread),
+ },
+ contexts: make(map[*context]struct{}),
+ }
+
+ sp.unmap()
+ return sp, nil
+}
+
+// unmap unmaps non-stub regions of the process.
+//
+// This will panic on failure (which should never happen).
+func (s *subprocess) unmap() {
+ s.Unmap(0, uint64(stubStart))
+ if maximumUserAddress != stubEnd {
+ s.Unmap(usermem.Addr(stubEnd), uint64(maximumUserAddress-stubEnd))
+ }
+}
+
+// Release kills the subprocess.
+//
+// Just kidding! We can't safely coordinate the detaching of all the
+// tracees (since the tracers are random runtime threads, and the process
+// won't exit until the tracers have been notified).
+//
+// Therefore we simply unmap everything in the subprocess and return it to the
+// globalPool. This has the added benefit of reducing creation time for new
+// subprocesses.
+func (s *subprocess) Release() {
+ go func() { // S/R-SAFE: Platform.
+ s.unmap()
+ globalPool.mu.Lock()
+ globalPool.available = append(globalPool.available, s)
+ globalPool.mu.Unlock()
+ }()
+}
+
+// newThread creates a new traced thread.
+//
+// Precondition: the OS thread must be locked.
+func (s *subprocess) newThread() *thread {
+ // Ask the first thread to create a new one.
+ r := make(chan *thread)
+ s.requests <- r
+ t := <-r
+
+ // Attach the subprocess to this one.
+ t.attach()
+
+ // Return the new thread, which is now bound.
+ return t
+}
+
+// attach attaches to the thread.
+func (t *thread) attach() {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("unable to attach: %v", errno))
+ }
+
+ // PTRACE_ATTACH sends SIGSTOP, and wakes the tracee if it was already
+ // stopped from the SIGSTOP queued by CLONE_PTRACE (see inner loop of
+ // newSubprocess), so we always expect to see signal-delivery-stop with
+ // SIGSTOP.
+ if sig := t.wait(stopped); sig != syscall.SIGSTOP {
+ panic(fmt.Sprintf("wait failed: expected SIGSTOP, got %v", sig))
+ }
+
+ // Initialize options.
+ t.init()
+}
+
+func (t *thread) grabInitRegs() {
+ // Grab registers.
+ //
+ // Note that we adjust the current register RIP value to be just before
+ // the current system call executed. This depends on the definition of
+ // the stub itself.
+ if err := t.getRegs(&t.initRegs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+ t.adjustInitRegsRip()
+}
+
+// detach detaches from the thread.
+//
+// Because the SIGSTOP is not suppressed, the thread will enter group-stop.
+func (t *thread) detach() {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_DETACH, uintptr(t.tid), 0, uintptr(syscall.SIGSTOP), 0, 0); errno != 0 {
+ panic(fmt.Sprintf("can't detach new clone: %v", errno))
+ }
+}
+
+// waitOutcome is used for wait below.
+type waitOutcome int
+
+const (
+ // stopped indicates that the process was stopped.
+ stopped waitOutcome = iota
+
+ // killed indicates that the process was killed.
+ killed
+)
+
+func (t *thread) dumpAndPanic(message string) {
+ var regs arch.Registers
+ message += "\n"
+ if err := t.getRegs(&regs); err == nil {
+ message += dumpRegs(&regs)
+ } else {
+ log.Warningf("unable to get registers: %v", err)
+ }
+ message += fmt.Sprintf("stubStart\t = %016x\n", stubStart)
+ panic(message)
+}
+
+func (t *thread) unexpectedStubExit() {
+ msg, err := t.getEventMessage()
+ status := syscall.WaitStatus(msg)
+ if status.Signaled() && status.Signal() == syscall.SIGKILL {
+ // SIGKILL can only be sent by a user or the OOM-killer. In both
+ // of these cases, we don't need to panic; there is no reason to
+ // think that something is wrong in gVisor.
+ log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
+ pid := os.Getpid()
+ syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL))
+ }
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+}
+
+// wait waits for a stop event.
+//
+// Precondition: outcome is a valid waitOutcome.
+func (t *thread) wait(outcome waitOutcome) syscall.Signal {
+ var status syscall.WaitStatus
+
+ for {
+ r, err := syscall.Wait4(int(t.tid), &status, syscall.WALL|syscall.WUNTRACED, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ // Wait was interrupted; wait again.
+ continue
+ } else if err != nil {
+ panic(fmt.Sprintf("ptrace wait failed: %v", err))
+ }
+ if int(r) != int(t.tid) {
+ panic(fmt.Sprintf("ptrace wait returned %v, expected %v", r, t.tid))
+ }
+ switch outcome {
+ case stopped:
+ if !status.Stopped() {
+ t.dumpAndPanic(fmt.Sprintf("ptrace status unexpected: got %v, wanted stopped", status))
+ }
+ stopSig := status.StopSignal()
+ if stopSig == 0 {
+ continue // Spurious stop.
+ }
+ if stopSig == syscall.SIGTRAP {
+ if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
+ t.unexpectedStubExit()
+ }
+ // Re-encode the trap cause in the way callers expect.
+ return stopSig | syscall.Signal(status.TrapCause()<<8)
+ }
+ // Not a trap signal.
+ return stopSig
+ case killed:
+ if !status.Exited() && !status.Signaled() {
+ t.dumpAndPanic(fmt.Sprintf("ptrace status unexpected: got %v, wanted exited", status))
+ }
+ return syscall.Signal(status.ExitStatus())
+ default:
+ // Should not happen.
+ t.dumpAndPanic(fmt.Sprintf("unknown outcome: %v", outcome))
+ }
+ }
+}
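+
+// Callers match on the re-encoded value to distinguish a syscall-enter stop
+// from a plain signal stop, as in switchToApp:
+//
+//  if sig := t.wait(stopped); sig == (syscallEvent | syscall.SIGTRAP) {
+//      // Reached a syscall-enter-stop.
+//  }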
+
+// destroy kills the thread.
+//
+// Note that this should not be used in the general case; the death of threads
+// will typically cause the death of the parent. This is a utility method for
+// manually created threads.
+func (t *thread) destroy() {
+ t.detach()
+ syscall.Tgkill(int(t.tgid), int(t.tid), syscall.Signal(syscall.SIGKILL))
+ t.wait(killed)
+}
+
+// init initializes trace options.
+func (t *thread) init() {
+ // Set the TRACESYSGOOD option to differentiate real SIGTRAP.
+ // set PTRACE_O_EXITKILL to ensure that the unexpected exit of the
+ // sentry will immediately kill the associated stubs.
+ const PTRACE_O_EXITKILL = 0x100000
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETOPTIONS,
+ uintptr(t.tid),
+ 0,
+ syscall.PTRACE_O_TRACESYSGOOD|syscall.PTRACE_O_TRACEEXIT|PTRACE_O_EXITKILL,
+ 0, 0)
+ if errno != 0 {
+ panic(fmt.Sprintf("ptrace set options failed: %v", errno))
+ }
+}
+
+// syscall executes a system call cycle in the traced context.
+//
+// This is _not_ for use by application system calls; rather, it is for use when
+// a system call must be injected into the remote context (e.g. mmap, munmap).
+// Note that clones are handled separately.
+func (t *thread) syscall(regs *arch.Registers) (uintptr, error) {
+ // Set registers.
+ if err := t.setRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Execute the syscall instruction. The task has to stop on the
+ // trap instruction which is right after the syscall
+ // instruction.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_CONT, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == syscall.SIGTRAP {
+ // Reached syscall-enter-stop.
+ break
+ } else {
+ // Some other signal caused a thread stop; ignore.
+ if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD {
+ log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig)
+ }
+ continue
+ }
+ }
+
+ // Grab registers.
+ if err := t.getRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+
+ return syscallReturnValue(regs)
+}
+
+// syscallIgnoreInterrupt ignores interrupts on the system call thread and
+// restarts the syscall if the kernel indicates that should happen.
+func (t *thread) syscallIgnoreInterrupt(
+ initRegs *arch.Registers,
+ sysno uintptr,
+ args ...arch.SyscallArgument) (uintptr, error) {
+ for {
+ regs := createSyscallRegs(initRegs, sysno, args...)
+ rval, err := t.syscall(&regs)
+ switch err {
+ case ERESTARTSYS:
+ continue
+ case ERESTARTNOINTR:
+ continue
+ case ERESTARTNOHAND:
+ continue
+ default:
+ return rval, err
+ }
+ }
+}
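+
+// For example, subprocess.Unmap (below) funnels an injected munmap through
+// this helper via subprocess.syscall:
+//
+//  _, err := s.syscall(
+//      syscall.SYS_MUNMAP,
+//      arch.SyscallArgument{Value: uintptr(addr)},
+//      arch.SyscallArgument{Value: uintptr(length)})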
+
+// NotifyInterrupt implements interrupt.Receiver.NotifyInterrupt.
+func (t *thread) NotifyInterrupt() {
+ syscall.Tgkill(int(t.tgid), int(t.tid), syscall.Signal(platform.SignalInterrupt))
+}
+
+// switchToApp is called from the main SwitchToApp entrypoint.
+//
+// This function returns true on a system call, false on a signal.
+func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
+ // Lock the thread for ptrace operations.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ // Extract floating point state.
+ fpState := ac.FloatingPointData()
+ fpLen, _ := ac.FeatureSet().ExtendedStateSize()
+ useXsave := ac.FeatureSet().UseXsave()
+
+ // Grab our thread from the pool.
+ currentTID := int32(procid.Current())
+ t := s.sysemuThreads.lookupOrCreate(currentTID, s.newThread)
+
+ // Reset necessary registers.
+ regs := &ac.StateData().Regs
+ t.resetSysemuRegs(regs)
+
+ // Extract TLS register
+ tls := uint64(ac.TLS())
+
+ // Check for interrupts, and ensure that future interrupts will signal t.
+ if !c.interrupt.Enable(t) {
+ // Pending interrupt; simulate.
+ c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)}
+ return false
+ }
+ defer c.interrupt.Disable()
+
+ // Ensure that the CPU set is bound appropriately; this makes the
+ // emulation below several times faster, presumably by avoiding
+ // interprocessor wakeups and by simplifying the schedule.
+ t.bind()
+
+ // Set registers.
+ if err := t.setRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs (%+v) failed: %v", regs, err))
+ }
+ if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
+ panic(fmt.Sprintf("ptrace set fpregs (%+v) failed: %v", fpState, err))
+ }
+ if err := t.setTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace set tls (%+v) failed: %v", tls, err))
+ }
+
+ for {
+ // Start running until the next system call.
+ if isSingleStepping(regs) {
+ if _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ unix.PTRACE_SYSEMU_SINGLESTEP,
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
+ }
+ } else {
+ if _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ unix.PTRACE_SYSEMU,
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
+ }
+ }
+
+ // Wait for the syscall-enter stop.
+ sig := t.wait(stopped)
+
+ // Refresh all registers.
+ if err := t.getRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+ if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
+ panic(fmt.Sprintf("ptrace get fpregs failed: %v", err))
+ }
+ if err := t.getTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace get tls failed: %v", err))
+ }
+ if !ac.SetTLS(uintptr(tls)) {
+ panic(fmt.Sprintf("tls value %v is invalid", tls))
+ }
+
+ // Is it a system call?
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Ensure registers are sane.
+ updateSyscallRegs(regs)
+ return true
+ } else if sig == syscall.SIGSTOP {
+ // SIGSTOP was delivered to another thread in the same thread
+ // group, which initiated another group stop. Just ignore it.
+ continue
+ }
+
+ // Grab signal information.
+ if err := t.getSignalInfo(&c.signalInfo); err != nil {
+ // Should never happen.
+ panic(fmt.Sprintf("ptrace get signal info failed: %v", err))
+ }
+
+ // We have a signal. We verify, however, that the signal was
+ // either delivered from the kernel or from this process; we
+ // don't respect other signals.
+ if c.signalInfo.Code > 0 {
+ // The signal was generated by the kernel. We inspect
+ // the signal information, and may patch it in order to
+ // facilitate vsyscall emulation. See patchSignalInfo.
+ patchSignalInfo(regs, &c.signalInfo)
+ return false
+ } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) {
+ // The signal was generated by this process. That means
+ // that it was an interrupt or something else that we
+ // should bail out on. Note that we ignore signals
+ // generated by other processes.
+ return false
+ }
+ }
+}
+
+// syscall executes the given system call without handling interruptions.
+func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintptr, error) {
+ // Grab a thread.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ currentTID := int32(procid.Current())
+ t := s.syscallThreads.lookupOrCreate(currentTID, s.newThread)
+
+ return t.syscallIgnoreInterrupt(&t.initRegs, sysno, args...)
+}
+
+// MapFile implements platform.AddressSpace.MapFile.
+func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+ var flags int
+ if precommit {
+ flags |= syscall.MAP_POPULATE
+ }
+ _, err := s.syscall(
+ syscall.SYS_MMAP,
+ arch.SyscallArgument{Value: uintptr(addr)},
+ arch.SyscallArgument{Value: uintptr(fr.Length())},
+ arch.SyscallArgument{Value: uintptr(at.Prot())},
+ arch.SyscallArgument{Value: uintptr(flags | syscall.MAP_SHARED | syscall.MAP_FIXED)},
+ arch.SyscallArgument{Value: uintptr(f.FD())},
+ arch.SyscallArgument{Value: uintptr(fr.Start)})
+ return err
+}
+
+// Unmap implements platform.AddressSpace.Unmap.
+func (s *subprocess) Unmap(addr usermem.Addr, length uint64) {
+ ar, ok := addr.ToRange(length)
+ if !ok {
+ panic(fmt.Sprintf("addr %#x + length %#x overflows", addr, length))
+ }
+ s.mu.Lock()
+ for c := range s.contexts {
+ c.mu.Lock()
+ if c.lastFaultSP == s && ar.Contains(c.lastFaultAddr) {
+ // Forget the last fault so that if c faults again, the fault isn't
+ // incorrectly reported as a write fault. If this is being called
+ // due to munmap() of the corresponding vma, handling of the second
+ // fault will fail anyway.
+ c.lastFaultSP = nil
+ delete(s.contexts, c)
+ }
+ c.mu.Unlock()
+ }
+ s.mu.Unlock()
+ _, err := s.syscall(
+ syscall.SYS_MUNMAP,
+ arch.SyscallArgument{Value: uintptr(addr)},
+ arch.SyscallArgument{Value: uintptr(length)})
+ if err != nil {
+ // We never expect this to happen.
+ panic(fmt.Sprintf("munmap(%x, %x)) failed: %v", addr, length, err))
+ }
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
new file mode 100644
index 000000000..84b699f0d
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -0,0 +1,259 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ptrace
+
+import (
+ "fmt"
+ "strings"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+const (
+ // maximumUserAddress is the largest possible user address.
+ maximumUserAddress = 0x7ffffffff000
+
+ // stubInitAddress is the initial attempt link address for the stub.
+ stubInitAddress = 0x7fffffff0000
+
+ // initRegsRipAdjustment is the size of the syscall instruction.
+ initRegsRipAdjustment = 2
+)
+
+// resetSysemuRegs sets up emulation registers.
+//
+// This should be called prior to calling sysemu.
+func (t *thread) resetSysemuRegs(regs *arch.Registers) {
+ regs.Cs = t.initRegs.Cs
+ regs.Ss = t.initRegs.Ss
+ regs.Ds = t.initRegs.Ds
+ regs.Es = t.initRegs.Es
+ regs.Fs = t.initRegs.Fs
+ regs.Gs = t.initRegs.Gs
+}
+
+// createSyscallRegs sets up syscall registers.
+//
+// This should be called to generate registers for a system call.
+func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers {
+ // Copy initial registers.
+ regs := *initRegs
+
+ // Set our syscall number.
+ regs.Rax = uint64(sysno)
+ if len(args) >= 1 {
+ regs.Rdi = args[0].Uint64()
+ }
+ if len(args) >= 2 {
+ regs.Rsi = args[1].Uint64()
+ }
+ if len(args) >= 3 {
+ regs.Rdx = args[2].Uint64()
+ }
+ if len(args) >= 4 {
+ regs.R10 = args[3].Uint64()
+ }
+ if len(args) >= 5 {
+ regs.R8 = args[4].Uint64()
+ }
+ if len(args) >= 6 {
+ regs.R9 = args[5].Uint64()
+ }
+
+ return regs
+}
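+
+// Illustrative example: injecting munmap(addr, length) builds registers per
+// the System V AMD64 syscall convention (rax = number, rdi and rsi = the
+// first two arguments):
+//
+//  regs := createSyscallRegs(&t.initRegs, syscall.SYS_MUNMAP,
+//      arch.SyscallArgument{Value: uintptr(addr)},
+//      arch.SyscallArgument{Value: uintptr(length)})
+//  // Now regs.Rax == SYS_MUNMAP, regs.Rdi == addr, regs.Rsi == length.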
+
+// isSingleStepping determines if the registers indicate single-stepping.
+func isSingleStepping(regs *arch.Registers) bool {
+ return (regs.Eflags & arch.X86TrapFlag) != 0
+}
+
+// updateSyscallRegs updates registers after finishing sysemu.
+func updateSyscallRegs(regs *arch.Registers) {
+ // Ptrace puts -ENOSYS in rax on syscall-enter-stop.
+ regs.Rax = regs.Orig_rax
+}
+
+// syscallReturnValue extracts a sensible return from registers.
+func syscallReturnValue(regs *arch.Registers) (uintptr, error) {
+ rval := int64(regs.Rax)
+ if rval < 0 {
+ return 0, syscall.Errno(-rval)
+ }
+ return uintptr(rval), nil
+}
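+
+// For example, if the injected call failed with EFAULT, the kernel leaves
+// -14 in rax:
+//
+//  regs.Rax = uint64(-int64(syscall.EFAULT)) // i.e. ^uint64(13)
+//  rval, err := syscallReturnValue(&regs)    // rval == 0, err == syscall.EFAULT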
+
+func dumpRegs(regs *arch.Registers) string {
+ var m strings.Builder
+
+ fmt.Fprintf(&m, "Registers:\n")
+ fmt.Fprintf(&m, "\tR15\t = %016x\n", regs.R15)
+ fmt.Fprintf(&m, "\tR14\t = %016x\n", regs.R14)
+ fmt.Fprintf(&m, "\tR13\t = %016x\n", regs.R13)
+ fmt.Fprintf(&m, "\tR12\t = %016x\n", regs.R12)
+ fmt.Fprintf(&m, "\tRbp\t = %016x\n", regs.Rbp)
+ fmt.Fprintf(&m, "\tRbx\t = %016x\n", regs.Rbx)
+ fmt.Fprintf(&m, "\tR11\t = %016x\n", regs.R11)
+ fmt.Fprintf(&m, "\tR10\t = %016x\n", regs.R10)
+ fmt.Fprintf(&m, "\tR9\t = %016x\n", regs.R9)
+ fmt.Fprintf(&m, "\tR8\t = %016x\n", regs.R8)
+ fmt.Fprintf(&m, "\tRax\t = %016x\n", regs.Rax)
+ fmt.Fprintf(&m, "\tRcx\t = %016x\n", regs.Rcx)
+ fmt.Fprintf(&m, "\tRdx\t = %016x\n", regs.Rdx)
+ fmt.Fprintf(&m, "\tRsi\t = %016x\n", regs.Rsi)
+ fmt.Fprintf(&m, "\tRdi\t = %016x\n", regs.Rdi)
+ fmt.Fprintf(&m, "\tOrig_rax = %016x\n", regs.Orig_rax)
+ fmt.Fprintf(&m, "\tRip\t = %016x\n", regs.Rip)
+ fmt.Fprintf(&m, "\tCs\t = %016x\n", regs.Cs)
+ fmt.Fprintf(&m, "\tEflags\t = %016x\n", regs.Eflags)
+ fmt.Fprintf(&m, "\tRsp\t = %016x\n", regs.Rsp)
+ fmt.Fprintf(&m, "\tSs\t = %016x\n", regs.Ss)
+ fmt.Fprintf(&m, "\tFs_base\t = %016x\n", regs.Fs_base)
+ fmt.Fprintf(&m, "\tGs_base\t = %016x\n", regs.Gs_base)
+ fmt.Fprintf(&m, "\tDs\t = %016x\n", regs.Ds)
+ fmt.Fprintf(&m, "\tEs\t = %016x\n", regs.Es)
+ fmt.Fprintf(&m, "\tFs\t = %016x\n", regs.Fs)
+ fmt.Fprintf(&m, "\tGs\t = %016x\n", regs.Gs)
+
+ return m.String()
+}
+
+// adjustInitRegsRip adjusts the current RIP register value to point
+// just before the system call instruction's execution.
+func (t *thread) adjustInitRegsRip() {
+ t.initRegs.Rip -= initRegsRipAdjustment
+}
+
+// Pass the expected PPID to the child via R15 when creating stub process.
+func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
+ initregs.R15 = uint64(ppid)
+ // Rbx has to be set to 1 when creating stub process.
+ initregs.Rbx = 1
+}
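+
+// This mirrors the stub's binary contract (see stub_amd64.s): R15 carries
+// the expected PPID for the PR_SET_PDEATHSIG race check, and the stub
+// compares BX against 1 to select its clone path.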
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current RIP. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Rip = signalInfo.Addr()
+ regs.Rsp -= 8
+ }
+}
+
+// enableCpuidFault enables cpuid-faulting.
+//
+// This may fail on older kernels or hardware, so we just disregard the result.
+// Host CPUID will be enabled.
+//
+// This is safe to call in an afterFork context.
+//
+//go:nosplit
+func enableCpuidFault() {
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
+}
+
+// appendArchSeccompRules appends architecture-specific seccomp rules when creating the BPF program.
+// See attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ rules = append(rules,
+ // Rules for trapping vsyscall access.
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_TIME: {},
+ unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ Vsyscall: true,
+ })
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules,
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ return rules
+}
+
+// probeSeccomp returns true iff seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8. This check is
+// dynamic because kernels may have backported the newer behavior.
+//
+// See createStub for more information.
+//
+// Precondition: the runtime OS thread must be locked.
+func probeSeccomp() bool {
+ // Create a completely new, destroyable process.
+ t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
+ if err != nil {
+ panic(fmt.Sprintf("seccomp probe failed: %v", err))
+ }
+ defer t.destroy()
+
+ // Set registers to the yield system call. This call is not allowed
+ // by the filters specified in the attachedThread function.
+ regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
+ if err := t.setRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Attempt an emulation.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Did the seccomp errno hook already run? This would
+ // indicate that seccomp is first in line and we're
+ // less than 4.8.
+ if err := t.getRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
+ }
+ if _, err := syscallReturnValue(&regs); err == nil {
+ // The seccomp errno mode ran first, and reset
+ // the error in the registers.
+ return false
+ }
+ // The seccomp hook did not run yet, and therefore it
+ // is safe to use RET_KILL mode for dispatched calls.
+ return true
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
new file mode 100644
index 000000000..bd618fae8
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -0,0 +1,174 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ptrace
+
+import (
+ "fmt"
+ "strings"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+const (
+ // maximumUserAddress is the largest possible user address.
+ maximumUserAddress = 0xfffffffff000
+
+ // stubInitAddress is the initial attempt link address for the stub.
+ // Only 48-bit VAs are currently supported.
+ stubInitAddress = 0xffffffff0000
+
+ // initRegsRipAdjustment is the size of the svc instruction.
+ initRegsRipAdjustment = 4
+)
+
+// resetSysemuRegs sets up emulation registers.
+//
+// This should be called prior to calling sysemu.
+func (t *thread) resetSysemuRegs(regs *arch.Registers) {
+}
+
+// createSyscallRegs sets up syscall registers.
+//
+// This should be called to generate registers for a system call.
+func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers {
+ // Copy initial registers (Pc, Sp, etc.).
+ regs := *initRegs
+
+ // Set our syscall number.
+ // r8 for the syscall number.
+ // r0-r6 is used to store the parameters.
+ regs.Regs[8] = uint64(sysno)
+ if len(args) >= 1 {
+ regs.Regs[0] = args[0].Uint64()
+ }
+ if len(args) >= 2 {
+ regs.Regs[1] = args[1].Uint64()
+ }
+ if len(args) >= 3 {
+ regs.Regs[2] = args[2].Uint64()
+ }
+ if len(args) >= 4 {
+ regs.Regs[3] = args[3].Uint64()
+ }
+ if len(args) >= 5 {
+ regs.Regs[4] = args[4].Uint64()
+ }
+ if len(args) >= 6 {
+ regs.Regs[5] = args[5].Uint64()
+ }
+
+ return regs
+}
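
As a usage sketch (t is a thread as elsewhere in this package), registers for a no-argument call such as getpid(2) are assembled per the convention above and installed with setRegs:

    // Build registers for getpid(2): x8 carries the syscall number and
    // x0-x5 would carry up to six arguments.
    regs := createSyscallRegs(&t.initRegs, syscall.SYS_GETPID)
    if err := t.setRegs(&regs); err != nil {
        panic(fmt.Sprintf("ptrace set regs failed: %v", err))
    }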
+
+// isSingleStepping determines if the registers indicate single-stepping.
+func isSingleStepping(regs *arch.Registers) bool {
+ // Refer to the ARM Architecture Reference Manual, D2.12.3: software step state machine.
+ // return (regs.Pstate.SS == 1) && (MDSCR_EL1.SS == 1).
+ //
+ // Since the host Linux kernel will set MDSCR_EL1.SS on our behalf
+ // when we call a single-step ptrace command, we only need to check
+ // the Pstate.SS bit here.
+ return (regs.Pstate & arch.ARMTrapFlag) != 0
+}
+
+// updateSyscallRegs updates registers after finishing sysemu.
+func updateSyscallRegs(regs *arch.Registers) {
+ // No special work is necessary.
+}
+
+// syscallReturnValue extracts a sensible return from registers.
+func syscallReturnValue(regs *arch.Registers) (uintptr, error) {
+ rval := int64(regs.Regs[0])
+ if rval < 0 {
+ return 0, syscall.Errno(-rval)
+ }
+ return uintptr(rval), nil
+}
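
A small sketch of the -errno convention this decodes; the register value is hypothetical:

    regs.Regs[0] = uint64(-int64(syscall.ENOENT)) // as if the stub's call failed
    if _, err := syscallReturnValue(&regs); err == syscall.ENOENT {
        // A negative value in x0 round-trips to the matching Go errno.
    }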
+
+func dumpRegs(regs *arch.Registers) string {
+ var m strings.Builder
+
+ fmt.Fprintf(&m, "Registers:\n")
+
+ for i := 0; i < 31; i++ {
+ fmt.Fprintf(&m, "\tRegs[%d]\t = %016x\n", i, regs.Regs[i])
+ }
+ fmt.Fprintf(&m, "\tSp\t = %016x\n", regs.Sp)
+ fmt.Fprintf(&m, "\tPc\t = %016x\n", regs.Pc)
+ fmt.Fprintf(&m, "\tPstate\t = %016x\n", regs.Pstate)
+
+ return m.String()
+}
+
+// adjustInitRegsRip adjusts the current PC value to be just before the
+// system call instruction's execution.
+func (t *thread) adjustInitRegsRip() {
+ t.initRegs.Pc -= initRegsRipAdjustment
+}
+
+// Pass the expected PPID to the child via X7 when creating stub process.
+func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
+ initregs.Regs[7] = uint64(ppid)
+ // R9 has to be set to 1 when creating stub process.
+ initregs.Regs[9] = 1
+}
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// PC as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current PC. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Pc = signalInfo.Addr()
+ regs.Sp -= 8
+ }
+}
+
+// Noop on arm64.
+//
+//go:nosplit
+func enableCpuidFault() {
+}
+
+// appendArchSeccompRules appends architecture-specific seccomp rules when
+// creating the BPF program. See attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ return rules
+}
+
+// probeSeccomp returns true if seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8.
+//
+// On arm64, support for PTRACE_SYSEMU was added in the 5.3 kernel, so
+// probeSeccomp can always return true.
+func probeSeccomp() bool {
+ return true
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
new file mode 100644
index 000000000..2ce528601
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -0,0 +1,259 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package ptrace
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/procid"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+const syscallEvent syscall.Signal = 0x80
+
+// createStub creates a fresh stub process.
+//
+// Precondition: the runtime OS thread must be locked.
+func createStub() (*thread, error) {
+ // The exact interactions of ptrace and seccomp are complex, and
+ // changed in recent kernel versions. Before commit 93e35efb8de45, the
+ // seccomp check is done before the ptrace emulation check. This means
+ // that any calls not matching this list will trigger the seccomp
+ // default action instead of notifying ptrace.
+ //
+ // After commit 93e35efb8de45, the seccomp check is done after the
+ // ptrace emulation check. This simplifies using SYSEMU, since seccomp
+ // will never run for emulation. Seccomp will only run for injected
+ // system calls, and thus we can use RET_KILL as our violation action.
+ var defaultAction linux.BPFAction
+ if probeSeccomp() {
+ log.Infof("Latest seccomp behavior found (kernel >= 4.8 likely)")
+ defaultAction = linux.SECCOMP_RET_KILL_THREAD
+ } else {
+ // We must rely on SYSEMU behavior; tracing with SYSEMU is broken.
+ log.Infof("Legacy seccomp behavior found (kernel < 4.8 likely)")
+ defaultAction = linux.SECCOMP_RET_ALLOW
+ }
+
+ // When creating the new child process, we specify SIGKILL as the
+ // signal to deliver when the child exits. We never expect a subprocess
+ // to exit; they are pooled and reused. This is done to ensure that if
+ // a subprocess is OOM-killed, this process (and all other stubs,
+ // transitively) will be killed as well. It's simply not possible to
+ // safely handle a single stub getting killed: the exact state of
+ // execution is unknown and not recoverable.
+ //
+ // In addition, we set the PTRACE_O_TRACEEXIT option to log more
+ // information about a stub process when it receives a fatal signal.
+ return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction)
+}
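
A hedged sketch of how a caller satisfies the locking precondition; the error propagation is an assumption for illustration:

    runtime.LockOSThread()
    defer runtime.UnlockOSThread()

    t, err := createStub()
    if err != nil {
        return err // hypothetical propagation
    }
    defer t.destroy()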
+
+// attachedThread returns a new attached thread.
+//
+// Precondition: the runtime OS thread must be locked.
+func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, error) {
+ // Create a BPF program that allows only the system calls needed by the
+ // stub and all its children. This is used to create child stubs
+ // (below), so we must include the ability to fork, but otherwise lock
+ // down available calls only to what is needed.
+ rules := []seccomp.RuleSet{}
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules, seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_CLONE: []seccomp.Rule{
+ // Allow creation of new subprocesses (used by the master).
+ {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)},
+ // Allow creation of new threads within a single address space (used by address spaces).
+ {seccomp.AllowValue(
+ syscall.CLONE_FILES |
+ syscall.CLONE_FS |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_THREAD |
+ syscall.CLONE_PTRACE |
+ syscall.CLONE_VM)},
+ },
+
+ // For the initial process creation.
+ syscall.SYS_WAIT4: {},
+ syscall.SYS_EXIT: {},
+
+ // For the stub prctl dance (all).
+ syscall.SYS_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)},
+ },
+ syscall.SYS_GETPPID: {},
+
+ // For the stub to stop itself (all).
+ syscall.SYS_GETPID: {},
+ syscall.SYS_KILL: []seccomp.Rule{
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)},
+ },
+
+ // Injected to support the address space operations.
+ syscall.SYS_MMAP: {},
+ syscall.SYS_MUNMAP: {},
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ rules = appendArchSeccompRules(rules, defaultAction)
+ instrs, err := seccomp.BuildProgram(rules, defaultAction)
+ if err != nil {
+ return nil, err
+ }
+
+ // Declare all variables up front in order to ensure that there's no
+ // need for allocations between beforeFork & afterFork.
+ var (
+ pid uintptr
+ ppid uintptr
+ errno syscall.Errno
+ )
+
+ // Remember the current ppid for the pdeathsig race.
+ ppid, _, _ = syscall.RawSyscall(syscall.SYS_GETPID, 0, 0, 0)
+
+ // Among other things, beforeFork masks all signals.
+ beforeFork()
+
+ // Do the clone.
+ pid, _, errno = syscall.RawSyscall6(syscall.SYS_CLONE, flags, 0, 0, 0, 0, 0)
+ if errno != 0 {
+ afterFork()
+ return nil, errno
+ }
+
+ // Is this the parent?
+ if pid != 0 {
+ // Among other things, restore signal mask.
+ afterFork()
+
+ // Initialize the first thread.
+ t := &thread{
+ tgid: int32(pid),
+ tid: int32(pid),
+ cpu: ^uint32(0),
+ }
+ if sig := t.wait(stopped); sig != syscall.SIGSTOP {
+ return nil, fmt.Errorf("wait failed: expected SIGSTOP, got %v", sig)
+ }
+ t.attach()
+ t.grabInitRegs()
+
+ return t, nil
+ }
+
+ // Move the stub to a new session (and thus a new process group). This
+ // prevents the stub from getting PTY job control signals intended only
+ // for the sentry process. We must call this before restoring signal
+ // mask.
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_SETSID, 0, 0, 0); errno != 0 {
+ syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
+ }
+
+ // afterForkInChild resets all signals to their default dispositions
+ // and restores the signal mask to its pre-fork state.
+ afterForkInChild()
+
+ // Explicitly unmask all signals to ensure that the tracer can see
+ // them.
+ if errno := unmaskAllSignals(); errno != 0 {
+ syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
+ }
+
+ // Set an aggressive BPF filter for the stub and all its children. See
+ // the description of the BPF program built above.
+ if errno := seccomp.SetFilter(instrs); errno != 0 {
+ syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
+ }
+
+ // Enable cpuid-faulting.
+ enableCpuidFault()
+
+ // Call the stub; should not return.
+ stubCall(stubStart, ppid)
+ panic("unreachable")
+}
+
+// createStub creates a stub process as a child of an existing subprocess.
+//
+// Precondition: the runtime OS thread must be locked.
+func (s *subprocess) createStub() (*thread, error) {
+ // There's no need to lock the runtime thread here, as this can only be
+ // called from a context that is already locked.
+ currentTID := int32(procid.Current())
+ t := s.syscallThreads.lookupOrCreate(currentTID, s.newThread)
+
+ // Pass the expected PPID to the child via initChildProcessPPID (R15 on amd64, X7 on arm64).
+ regs := t.initRegs
+ initChildProcessPPID(&regs, t.tgid)
+
+ // Call fork in a subprocess.
+ //
+ // The new child must set up PDEATHSIG to ensure it dies if this
+ // process dies. Since this process could die at any time, this cannot
+ // be done via instrumentation from here.
+ //
+ // Instead, we create the child untraced, which will do the PDEATHSIG
+ // setup and then SIGSTOP itself for our attach below.
+ //
+ // See above re: SIGKILL.
+ pid, err := t.syscallIgnoreInterrupt(
+ &regs,
+ syscall.SYS_CLONE,
+ arch.SyscallArgument{Value: uintptr(syscall.SIGKILL | syscall.CLONE_FILES)},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0})
+ if err != nil {
+ return nil, fmt.Errorf("creating stub process: %v", err)
+ }
+
+ // Wait for child to enter group-stop, so we don't stop its
+ // bootstrapping work with t.attach below.
+ //
+ // We unfortunately don't have a handy part of memory to write the wait
+ // status. If the wait succeeds, we'll assume that it was the SIGSTOP.
+ // If the child actually exited, the attach below will fail.
+ _, err = t.syscallIgnoreInterrupt(
+ &t.initRegs,
+ syscall.SYS_WAIT4,
+ arch.SyscallArgument{Value: uintptr(pid)},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: syscall.WALL | syscall.WUNTRACED},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0})
+ if err != nil {
+ return nil, fmt.Errorf("waiting on stub process: %v", err)
+ }
+
+ childT := &thread{
+ tgid: int32(pid),
+ tid: int32(pid),
+ cpu: ^uint32(0),
+ }
+ childT.attach()
+
+ return childT, nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
new file mode 100644
index 000000000..245b20722
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -0,0 +1,95 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+// +build amd64 arm64
+
+package ptrace
+
+import (
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
+// runtime.NumCPU doesn't actually record the number of CPUs on the system;
+// it just records the number of CPUs available in the scheduler affinity set
+// at startup. This may a) change over time and b) give a number far lower
+// than the maximum indexable CPU. To prevent lots of allocation in the hot
+// path, we use a pool to store large masks that we can reuse during bind.
+var maskPool = sync.Pool{
+ New: func() interface{} {
+ const maxCPUs = 1024 // Not a hard limit; see below.
+ return make([]uintptr, maxCPUs/64)
+ },
+}
+
+// unmaskAllSignals unmasks all signals on the current thread.
+//
+//go:nosplit
+func unmaskAllSignals() syscall.Errno {
+ var set linux.SignalSet
+ _, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0)
+ return errno
+}
+
+// setCPU sets the CPU affinity.
+func (t *thread) setCPU(cpu uint32) error {
+ mask := maskPool.Get().([]uintptr)
+ n := int(cpu / 64)
+ v := uintptr(1 << uintptr(cpu%64))
+ if n >= len(mask) {
+ // See maskPool note above. We've actually exceeded the number
+ // of available cores. Grow the mask; the larger mask will be
+ // put back into the pool below.
+ mask = make([]uintptr, n+1)
+ }
+ mask[n] |= v
+ if _, _, errno := syscall.RawSyscall(
+ unix.SYS_SCHED_SETAFFINITY,
+ uintptr(t.tid),
+ uintptr(len(mask)*8),
+ uintptr(unsafe.Pointer(&mask[0]))); errno != 0 {
+ return errno
+ }
+ mask[n] &^= v
+ maskPool.Put(mask)
+ return nil
+}
+
+// bind attempts to ensure that the thread is on the same CPU as the current
+// thread. This provides no guarantees as it is fundamentally a racy operation:
+// CPU sets may change and we may be rescheduled in the middle of this
+// operation. As a result, no failures are reported.
+//
+// Precondition: the current runtime thread should be locked.
+func (t *thread) bind() {
+ currentCPU := hostcpu.GetCPU()
+
+ if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU {
+ // Set the affinity on the thread and save the CPU for next
+ // round; we don't expect CPUs to bounce around too frequently.
+ //
+ // (It's worth noting that we could move CPUs between this point
+ // and when the tracee finishes executing. But that would be
+ // roughly the status quo anyways -- we're just maximizing our
+ // chances of colocation, not guaranteeing it.)
+ t.setCPU(currentCPU)
+ }
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
new file mode 100644
index 000000000..0bee995e4
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package ptrace
+
+import (
+ _ "unsafe" // required for go:linkname.
+)
+
+//go:linkname beforeFork syscall.runtime_BeforeFork
+func beforeFork()
+
+//go:linkname afterFork syscall.runtime_AfterFork
+func afterFork()
+
+//go:linkname afterForkInChild syscall.runtime_AfterForkInChild
+func afterForkInChild()
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
new file mode 100644
index 000000000..679b287c3
--- /dev/null
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -0,0 +1,86 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template(
+ name = "defs_amd64",
+ srcs = [
+ "defs.go",
+ "defs_amd64.go",
+ "offsets_amd64.go",
+ "x86.go",
+ ],
+ visibility = [":__subpackages__"],
+)
+
+go_template(
+ name = "defs_arm64",
+ srcs = [
+ "aarch64.go",
+ "defs.go",
+ "defs_arm64.go",
+ "offsets_arm64.go",
+ ],
+ visibility = [":__subpackages__"],
+)
+
+go_template_instance(
+ name = "defs_impl_amd64",
+ out = "defs_impl_amd64.go",
+ package = "ring0",
+ template = ":defs_amd64",
+)
+
+go_template_instance(
+ name = "defs_impl_arm64",
+ out = "defs_impl_arm64.go",
+ package = "ring0",
+ template = ":defs_arm64",
+)
+
+genrule(
+ name = "entry_impl_amd64",
+ srcs = ["entry_amd64.s"],
+ outs = ["entry_impl_amd64.s"],
+ cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
+
+genrule(
+ name = "entry_impl_arm64",
+ srcs = ["entry_arm64.s"],
+ outs = ["entry_impl_arm64.s"],
+ cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
+
+go_library(
+ name = "ring0",
+ srcs = [
+ "defs_impl_amd64.go",
+ "defs_impl_arm64.go",
+ "entry_amd64.go",
+ "entry_arm64.go",
+ "entry_impl_amd64.s",
+ "entry_impl_arm64.s",
+ "kernel.go",
+ "kernel_amd64.go",
+ "kernel_arm64.go",
+ "kernel_unsafe.go",
+ "lib_amd64.go",
+ "lib_amd64.s",
+ "lib_arm64.go",
+ "lib_arm64.s",
+ "lib_arm64_unsafe.go",
+ "ring0.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/cpuid",
+ "//pkg/safecopy",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
new file mode 100644
index 000000000..8122ac6e2
--- /dev/null
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -0,0 +1,110 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// Useful bits.
+const (
+ _PGD_PGT_BASE = 0x1000
+ _PGD_PGT_SIZE = 0x1000
+ _PUD_PGT_BASE = 0x2000
+ _PUD_PGT_SIZE = 0x1000
+ _PMD_PGT_BASE = 0x3000
+ _PMD_PGT_SIZE = 0x4000
+ _PTE_PGT_BASE = 0x7000
+ _PTE_PGT_SIZE = 0x1000
+
+ _PSR_D_BIT = 0x00000200
+ _PSR_A_BIT = 0x00000100
+ _PSR_I_BIT = 0x00000080
+ _PSR_F_BIT = 0x00000040
+)
+
+const (
+ // PSR bits
+ PSR_MODE_EL0t = 0x00000000
+ PSR_MODE_EL1t = 0x00000004
+ PSR_MODE_EL1h = 0x00000005
+ PSR_MODE_MASK = 0x0000000f
+
+ // KernelFlagsSet should always be set in the kernel.
+ KernelFlagsSet = PSR_MODE_EL1h
+
+ // UserFlagsSet is always set in userspace.
+ UserFlagsSet = PSR_MODE_EL0t
+
+ KernelFlagsClear = PSR_MODE_MASK
+ UserFlagsClear = PSR_MODE_MASK
+
+ PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT
+)
+
+// Vector is an exception vector.
+type Vector uintptr
+
+// Exception vectors.
+const (
+ El1SyncInvalid = iota
+ El1IrqInvalid
+ El1FiqInvalid
+ El1ErrorInvalid
+ El1Sync
+ El1Irq
+ El1Fiq
+ El1Error
+ El0Sync
+ El0Irq
+ El0Fiq
+ El0Error
+ El0Sync_invalid
+ El0Irq_invalid
+ El0Fiq_invalid
+ El0Error_invalid
+ El1Sync_da
+ El1Sync_ia
+ El1Sync_sp_pc
+ El1Sync_undef
+ El1Sync_dbg
+ El1Sync_inv
+ El0Sync_svc
+ El0Sync_da
+ El0Sync_ia
+ El0Sync_fpsimd_acc
+ El0Sync_sve_acc
+ El0Sync_sys
+ El0Sync_sp_pc
+ El0Sync_undef
+ El0Sync_dbg
+ El0Sync_inv
+ _NR_INTERRUPTS
+)
+
+// System call vectors.
+const (
+ Syscall Vector = El0Sync_svc
+ PageFault Vector = El0Sync_da
+ VirtualizationException Vector = El0Error
+)
+
+// VirtualAddressBits returns the number of bits available for virtual addresses.
+func VirtualAddressBits() uint32 {
+ return 48
+}
+
+// PhysicalAddressBits returns the number of bits available for physical addresses.
+func PhysicalAddressBits() uint32 {
+ return 40
+}
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
new file mode 100644
index 000000000..e6daf24df
--- /dev/null
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -0,0 +1,109 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+)
+
+// Kernel is a global kernel object.
+//
+// This contains global state, shared by multiple CPUs.
+type Kernel struct {
+ KernelArchState
+}
+
+// Hooks are hooks for kernel functions.
+type Hooks interface {
+ // KernelSyscall is called for kernel system calls.
+ //
+ // Return from this call will restore registers and return to the kernel: the
+ // registers must be modified directly.
+ //
+ // If this function is not provided, a kernel system call results in halt.
+ //
+ // This must be go:nosplit, as this will be on the interrupt stack.
+ // Closures are permitted, as the pointer to the closure frame is not
+ // passed on the stack.
+ KernelSyscall()
+
+ // KernelException handles an exception during kernel execution.
+ //
+ // Return from this call will restore registers and return to the kernel: the
+ // registers must be modified directly.
+ //
+ // If this function is not provided, a kernel exception results in halt.
+ //
+ // This must be go:nosplit, as this will be on the interrupt stack.
+ // Closures are permitted, as the pointer to the closure frame is not
+ // passed on the stack.
+ KernelException(Vector)
+}
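
A minimal sketch of a conforming implementation, assuming that halting is an acceptable response to kernel events (the real hooks live in the platform code); note that it preserves the go:nosplit requirement documented above:

    type haltHooks struct{}

    // KernelSyscall halts on any kernel system call.
    //go:nosplit
    func (haltHooks) KernelSyscall() { Halt() }

    // KernelException halts on any kernel exception.
    //go:nosplit
    func (haltHooks) KernelException(Vector) { Halt() }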
+
+// CPU is the per-CPU struct.
+type CPU struct {
+ // self is a self reference.
+ //
+ // This is always guaranteed to be at offset zero.
+ self *CPU
+
+ // kernel is reference to the kernel that this CPU was initialized
+ // with. This reference is kept for garbage collection purposes: CPU
+ // registers may refer to objects within the Kernel object that cannot
+ // be safely freed.
+ kernel *Kernel
+
+ // CPUArchState is architecture-specific state.
+ CPUArchState
+
+ // registers is a set of registers; these may be used on kernel system
+ // calls and exceptions via the Registers function.
+ registers arch.Registers
+
+ // hooks are kernel hooks.
+ hooks Hooks
+}
+
+// Registers returns a modifiable-copy of the kernel registers.
+//
+// This is explicitly safe to call during KernelException and KernelSyscall.
+//
+//go:nosplit
+func (c *CPU) Registers() *arch.Registers {
+ return &c.registers
+}
+
+// SwitchOpts are passed to the Switch function.
+type SwitchOpts struct {
+ // Registers are the user register state.
+ Registers *arch.Registers
+
+ // FloatingPointState is a byte pointer where floating point state is
+ // saved and restored.
+ FloatingPointState *byte
+
+ // PageTables are the application page tables.
+ PageTables *pagetables.PageTables
+
+ // Flush indicates that a TLB flush should be forced on switch.
+ Flush bool
+
+ // FullRestore indicates that an iret-based restore should be used.
+ FullRestore bool
+
+ // SwitchArchOpts are architecture-specific options.
+ SwitchArchOpts
+}
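
A hedged sketch of populating SwitchOpts before a switch; appRegs, fpState, and appPageTables are placeholders standing in for real platform state:

    opts := SwitchOpts{
        Registers:          appRegs,       // user register state (assumed)
        FloatingPointState: fpState,       // saved FP state (assumed)
        PageTables:         appPageTables, // application page tables (assumed)
        FullRestore:        true,          // restore the full register set via iret
    }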
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
new file mode 100644
index 000000000..9c6c2cf5c
--- /dev/null
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+)
+
+// Segment indices and Selectors.
+const (
+ // Index into GDT array.
+ _ = iota // Null descriptor first.
+ _ // Reserved (Linux uses this slot for the 32-bit kernel code segment).
+ segKcode // Kernel code (64-bit).
+ segKdata // Kernel data.
+ segUcode32 // User code (32-bit).
+ segUdata // User data.
+ segUcode64 // User code (64-bit).
+ segTss // Task segment descriptor.
+ segTssHi // Upper bits for TSS.
+ segLast // Last segment (terminal, not included).
+)
+
+// Selectors.
+const (
+ Kcode Selector = segKcode << 3
+ Kdata Selector = segKdata << 3
+ Ucode32 Selector = (segUcode32 << 3) | 3
+ Udata Selector = (segUdata << 3) | 3
+ Ucode64 Selector = (segUcode64 << 3) | 3
+ Tss Selector = segTss << 3
+)
+
+// Standard segments.
+var (
+ UserCodeSegment32 SegmentDescriptor
+ UserDataSegment SegmentDescriptor
+ UserCodeSegment64 SegmentDescriptor
+ KernelCodeSegment SegmentDescriptor
+ KernelDataSegment SegmentDescriptor
+)
+
+// KernelOpts has initialization options for the kernel.
+type KernelOpts struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+}
+
+// KernelArchState contains architecture-specific state.
+type KernelArchState struct {
+ KernelOpts
+
+ // globalIDT is our set of interrupt gates.
+ globalIDT idt64
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
+ // stack is the stack used for interrupts on this CPU.
+ stack [256]byte
+
+ // errorCode is the error code from the last exception.
+ errorCode uintptr
+
+ // errorType indicates the type of error code here; it is always set
+ // along with the errorCode value above.
+ //
+ // It will either be 1, which indicates a user error, or 0, indicating a
+ // kernel error. When it is 0, ErrorCode reports a kernel error and the
+ // errorCode value cannot provide relevant information about the last
+ // exception.
+ errorType uintptr
+
+ // gdt is the CPU's descriptor table.
+ gdt descriptorTable
+
+ // tss is the CPU's task state.
+ tss TaskState64
+}
+
+// ErrorCode returns the last error code.
+//
+// The returned boolean indicates whether the error code corresponds to the
+// last user error or not. If it does not, then fault information must be
+// ignored. This is generally the result of a kernel fault while servicing a
+// user fault.
+//
+//go:nosplit
+func (c *CPU) ErrorCode() (value uintptr, user bool) {
+ return c.errorCode, c.errorType != 0
+}
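
A sketch of consuming the (value, user) pair; the fault-handling context and helper are assumptions:

    if code, user := c.ErrorCode(); user {
        handleUserFault(code) // hypothetical helper; code is meaningful here
    } else {
        // Kernel fault while servicing a user fault: ignore the code.
    }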
+
+// ClearErrorCode resets the error code.
+//
+//go:nosplit
+func (c *CPU) ClearErrorCode() {
+ c.errorCode = 0 // No code.
+ c.errorType = 1 // User mode.
+}
+
+// SwitchArchOpts are embedded in SwitchOpts.
+type SwitchArchOpts struct {
+ // UserPCID indicates the application PCID to be used on switch,
+ // assuming that PCIDs are supported.
+ //
+ // Per pagetables_x86.go, a zero PCID implies a flush.
+ UserPCID uint16
+
+ // KernelPCID indicates the kernel PCID to be used on return,
+ // assuming that PCIDs are supported.
+ //
+ // Per pagetables_x86.go, a zero PCID implies a flush.
+ KernelPCID uint16
+}
+
+func init() {
+ KernelCodeSegment.setCode64(0, 0, 0)
+ KernelDataSegment.setData(0, 0xffffffff, 0)
+ UserCodeSegment32.setCode64(0, 0, 3)
+ UserDataSegment.setData(0, 0xffffffff, 3)
+ UserCodeSegment64.setCode64(0, 0, 3)
+}
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
new file mode 100644
index 000000000..0e2ab716c
--- /dev/null
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits())
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+)
+
+// KernelOpts has initialization options for the kernel.
+type KernelOpts struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+}
+
+// KernelArchState contains architecture-specific state.
+type KernelArchState struct {
+ KernelOpts
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
+ // stack is the stack used for interrupts on this CPU.
+ stack [512]byte
+
+ // errorCode is the error code from the last exception.
+ errorCode uintptr
+
+ // errorType indicates the type of error code here; it is always set
+ // along with the errorCode value above.
+ //
+ // It will either be 1, which indicates a user error, or 0, indicating a
+ // kernel error. When it is 0, ErrorCode reports a kernel error and the
+ // errorCode value cannot provide relevant information about the last
+ // exception.
+ errorType uintptr
+
+ // faultAddr is the value of far_el1.
+ faultAddr uintptr
+
+ // ttbr0Kvm is the value of ttbr0_el1 for sentry.
+ ttbr0Kvm uintptr
+
+ // ttbr0App is the value of ttbr0_el1 for application.
+ ttbr0App uintptr
+
+ // vecCode is the exception vector.
+ vecCode Vector
+
+ // appAddr is the application context pointer.
+ appAddr uintptr
+
+ // lazyVFP is the value of cpacr_el1.
+ lazyVFP uintptr
+}
+
+// ErrorCode returns the last error code.
+//
+// The returned boolean indicates whether the error code corresponds to the
+// last user error or not. If it does not, then fault information must be
+// ignored. This is generally the result of a kernel fault while servicing a
+// user fault.
+//
+//go:nosplit
+func (c *CPU) ErrorCode() (value uintptr, user bool) {
+ return c.errorCode, c.errorType != 0
+}
+
+// ClearErrorCode resets the error code.
+//
+//go:nosplit
+func (c *CPU) ClearErrorCode() {
+ c.errorCode = 0 // No code.
+ c.errorType = 1 // User mode.
+}
+
+//go:nosplit
+func (c *CPU) GetFaultAddr() (value uintptr) {
+ return c.faultAddr
+}
+
+//go:nosplit
+func (c *CPU) SetTtbr0Kvm(value uintptr) {
+ c.ttbr0Kvm = value
+}
+
+//go:nosplit
+func (c *CPU) SetTtbr0App(value uintptr) {
+ c.ttbr0App = value
+}
+
+//go:nosplit
+func (c *CPU) GetVector() (value Vector) {
+ return c.vecCode
+}
+
+//go:nosplit
+func (c *CPU) SetAppAddr(value uintptr) {
+ c.appAddr = value
+}
+
+// GetLazyVFP returns the value of cpacr_el1.
+//go:nosplit
+func (c *CPU) GetLazyVFP() (value uintptr) {
+ return c.lazyVFP
+}
+
+// SwitchArchOpts are embedded in SwitchOpts.
+type SwitchArchOpts struct {
+ // UserASID indicates the application ASID to be used on switch.
+ UserASID uint16
+
+ // KernelASID indicates the kernel ASID to be used on return.
+ KernelASID uint16
+}
+
+func init() {
+}
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
new file mode 100644
index 000000000..7fa43c2f5
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -0,0 +1,128 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// This is an assembly function.
+//
+// The sysenter function is invoked in two situations:
+//
+// (1) The guest kernel has executed a system call.
+// (2) The guest application has executed a system call.
+//
+// The interrupt flag is examined to determine whether the system call was
+// executed from kernel mode or not and the appropriate stub is called.
+func sysenter()
+
+// swapgs swaps the current GS value.
+//
+// This must be called prior to sysret/iret.
+func swapgs()
+
+// sysret returns to userspace from a system call.
+//
+// The return code is the vector that interrupted execution.
+//
+// See stubs.go for a note regarding the frame size of this function.
+func sysret(*CPU, *arch.Registers) Vector
+
+// "iret is the cadillac of CPL switching."
+//
+// -- Neel Natu
+//
+// iret is nearly identical to sysret, except an iret is used to fully restore
+// all user state. This must be called in cases where all registers need to be
+// restored.
+func iret(*CPU, *arch.Registers) Vector
+
+// exception is the generic exception entry.
+//
+// This is called by the individual stub definitions.
+func exception()
+
+// resume is a stub that restores the CPU kernel registers.
+//
+// This is used when processing kernel exceptions and syscalls.
+func resume()
+
+// Start is the CPU entrypoint.
+//
+// The following start conditions must be satisfied:
+//
+// * AX should contain the CPU pointer.
+// * c.GDT() should be loaded as the GDT.
+// * c.IDT() should be loaded as the IDT.
+// * c.CR0() should be the current CR0 value.
+// * c.CR3() should be set to the kernel PageTables.
+// * c.CR4() should be the current CR4 value.
+// * c.EFER() should be the current EFER value.
+//
+// The CPU state will be set to c.Registers().
+func Start()
+
+// Exception stubs.
+func divideByZero()
+func debug()
+func nmi()
+func breakpoint()
+func overflow()
+func boundRangeExceeded()
+func invalidOpcode()
+func deviceNotAvailable()
+func doubleFault()
+func coprocessorSegmentOverrun()
+func invalidTSS()
+func segmentNotPresent()
+func stackSegmentFault()
+func generalProtectionFault()
+func pageFault()
+func x87FloatingPointException()
+func alignmentCheck()
+func machineCheck()
+func simdFloatingPointException()
+func virtualizationException()
+func securityException()
+func syscallInt80()
+
+// Exception handler index.
+var handlers = map[Vector]func(){
+ DivideByZero: divideByZero,
+ Debug: debug,
+ NMI: nmi,
+ Breakpoint: breakpoint,
+ Overflow: overflow,
+ BoundRangeExceeded: boundRangeExceeded,
+ InvalidOpcode: invalidOpcode,
+ DeviceNotAvailable: deviceNotAvailable,
+ DoubleFault: doubleFault,
+ CoprocessorSegmentOverrun: coprocessorSegmentOverrun,
+ InvalidTSS: invalidTSS,
+ SegmentNotPresent: segmentNotPresent,
+ StackSegmentFault: stackSegmentFault,
+ GeneralProtectionFault: generalProtectionFault,
+ PageFault: pageFault,
+ X87FloatingPointException: x87FloatingPointException,
+ AlignmentCheck: alignmentCheck,
+ MachineCheck: machineCheck,
+ SIMDFloatingPointException: simdFloatingPointException,
+ VirtualizationException: virtualizationException,
+ SecurityException: securityException,
+ SyscallInt80: syscallInt80,
+}
diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s
new file mode 100644
index 000000000..02df38331
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_amd64.s
@@ -0,0 +1,319 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// NB: Offsets are programmatically generated (see BUILD).
+//
+// This file is concatenated with the definitions.
+
+// Saves a register set.
+//
+// This is a macro because it may need to be executed in contexts where a
+// stack is not available for calls.
+//
+// The following registers are not saved: AX, SP, IP, FLAGS, all segments.
+#define REGISTERS_SAVE(reg, offset) \
+ MOVQ R15, offset+PTRACE_R15(reg); \
+ MOVQ R14, offset+PTRACE_R14(reg); \
+ MOVQ R13, offset+PTRACE_R13(reg); \
+ MOVQ R12, offset+PTRACE_R12(reg); \
+ MOVQ BP, offset+PTRACE_RBP(reg); \
+ MOVQ BX, offset+PTRACE_RBX(reg); \
+ MOVQ CX, offset+PTRACE_RCX(reg); \
+ MOVQ DX, offset+PTRACE_RDX(reg); \
+ MOVQ R11, offset+PTRACE_R11(reg); \
+ MOVQ R10, offset+PTRACE_R10(reg); \
+ MOVQ R9, offset+PTRACE_R9(reg); \
+ MOVQ R8, offset+PTRACE_R8(reg); \
+ MOVQ SI, offset+PTRACE_RSI(reg); \
+ MOVQ DI, offset+PTRACE_RDI(reg);
+
+// Loads a register set.
+//
+// This is a macro because it may need to be executed in contexts where a
+// stack is not available for calls.
+//
+// The following registers are not loaded: AX, SP, IP, FLAGS, all segments.
+#define REGISTERS_LOAD(reg, offset) \
+ MOVQ offset+PTRACE_R15(reg), R15; \
+ MOVQ offset+PTRACE_R14(reg), R14; \
+ MOVQ offset+PTRACE_R13(reg), R13; \
+ MOVQ offset+PTRACE_R12(reg), R12; \
+ MOVQ offset+PTRACE_RBP(reg), BP; \
+ MOVQ offset+PTRACE_RBX(reg), BX; \
+ MOVQ offset+PTRACE_RCX(reg), CX; \
+ MOVQ offset+PTRACE_RDX(reg), DX; \
+ MOVQ offset+PTRACE_R11(reg), R11; \
+ MOVQ offset+PTRACE_R10(reg), R10; \
+ MOVQ offset+PTRACE_R9(reg), R9; \
+ MOVQ offset+PTRACE_R8(reg), R8; \
+ MOVQ offset+PTRACE_RSI(reg), SI; \
+ MOVQ offset+PTRACE_RDI(reg), DI;
+
+// SWAP_GS swaps the kernel GS (CPU).
+#define SWAP_GS() \
+ BYTE $0x0F; BYTE $0x01; BYTE $0xf8;
+
+// IRET returns from an interrupt frame.
+#define IRET() \
+ BYTE $0x48; BYTE $0xcf;
+
+// SYSRET64 executes the sysret instruction.
+#define SYSRET64() \
+ BYTE $0x48; BYTE $0x0f; BYTE $0x07;
+
+// LOAD_KERNEL_ADDRESS loads a kernel address.
+#define LOAD_KERNEL_ADDRESS(from, to) \
+ MOVQ from, to; \
+ ORQ ·KernelStartAddress(SB), to;
+
+// LOAD_KERNEL_STACK loads the kernel stack.
+#define LOAD_KERNEL_STACK(from) \
+ LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \
+ LEAQ CPU_STACK_TOP(SP), SP;
+
+// See kernel.go.
+TEXT ·Halt(SB),NOSPLIT,$0
+ HLT
+ RET
+
+// See entry_amd64.go.
+TEXT ·swapgs(SB),NOSPLIT,$0
+ SWAP_GS()
+ RET
+
+// See entry_amd64.go.
+TEXT ·sysret(SB),NOSPLIT,$0-24
+ // Save original state.
+ LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
+ LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
+ MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+
+ // Restore user register state.
+ REGISTERS_LOAD(AX, 0)
+ MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET.
+ MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET.
+ MOVQ PTRACE_RSP(AX), SP // Restore the stack directly.
+ MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ SYSRET64()
+
+// See entry_amd64.go.
+TEXT ·iret(SB),NOSPLIT,$0-24
+ // Save original state.
+ LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
+ LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
+ MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+
+ // Build an IRET frame & restore state.
+ LOAD_KERNEL_STACK(BX)
+ MOVQ PTRACE_SS(AX), BX; PUSHQ BX
+ MOVQ PTRACE_RSP(AX), CX; PUSHQ CX
+ MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX
+ MOVQ PTRACE_CS(AX), DI; PUSHQ DI
+ MOVQ PTRACE_RIP(AX), SI; PUSHQ SI
+ REGISTERS_LOAD(AX, 0) // Restore most registers.
+ MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ IRET()
+
+// See entry_amd64.go.
+TEXT ·resume(SB),NOSPLIT,$0
+ // See iret, above.
+ MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX
+ MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX
+ MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX
+ MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI
+ MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI
+ REGISTERS_LOAD(GS, CPU_REGISTERS)
+ MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX
+ IRET()
+
+// See entry_amd64.go.
+TEXT ·Start(SB),NOSPLIT,$0
+ LOAD_KERNEL_STACK(AX) // Set the stack.
+ PUSHQ $0x0 // Previous frame pointer.
+ MOVQ SP, BP // Set frame pointer.
+ PUSHQ AX // First argument (CPU).
+ CALL ·start(SB) // Call Go hook.
+ JMP ·resume(SB) // Restore to registers.
+
+// See entry_amd64.go.
+TEXT ·sysenter(SB),NOSPLIT,$0
+ // Interrupts are always disabled while we're executing in kernel mode
+ // and always enabled while executing in user mode. Therefore, we can
+ // reliably look at the flags in R11 to determine where this syscall
+ // was from.
+ TESTL $_RFLAGS_IF, R11
+ JZ kernel
+
+user:
+ SWAP_GS()
+ XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks.
+ XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs).
+ REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value.
+ MOVQ BX, PTRACE_RAX(AX) // Save everything else.
+ MOVQ BX, PTRACE_ORIGRAX(AX)
+ MOVQ CX, PTRACE_RIP(AX)
+ MOVQ R11, PTRACE_FLAGS(AX)
+ MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX)
+ MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
+ MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+
+ // Return to the kernel, where the frame is:
+ //
+ // vector (sp+24)
+ // regs (sp+16)
+ // cpu (sp+8)
+ // vcpu.Switch (sp+0)
+ //
+ MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
+ MOVQ $Syscall, 24(SP) // Output vector.
+ RET
+
+kernel:
+ // We can't restore the original stack, but we can access the registers
+ // in the CPU state directly. No need for temporary juggling.
+ MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
+ REGISTERS_SAVE(GS, CPU_REGISTERS)
+ MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS)
+ MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS)
+ MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
+ MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+
+ // Call the syscall trampoline.
+ LOAD_KERNEL_STACK(GS)
+ MOVQ CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ POPQ AX // Pop vCPU.
+ JMP ·resume(SB)
+
+// exception is a generic exception handler.
+//
+// There are two cases handled:
+//
+// 1) An exception in kernel mode: this results in saving the state at the time
+// of the exception and calling the defined hook.
+//
+// 2) An exception in guest mode: the original kernel frame is restored, and
+// the vector & error codes are pushed as return values.
+//
+// See below for the stubs that call exception.
+TEXT ·exception(SB),NOSPLIT,$0
+ // Determine whether the exception occurred in kernel mode or user
+ // mode, based on the flags. We expect the following stack:
+ //
+ // SS (sp+48)
+ // SP (sp+40)
+ // FLAGS (sp+32)
+ // CS (sp+24)
+ // IP (sp+16)
+ // ERROR_CODE (sp+8)
+ // VECTOR (sp+0)
+ //
+ TESTL $_RFLAGS_IF, 32(SP)
+ JZ kernel
+
+user:
+ SWAP_GS()
+ ADDQ $-8, SP // Adjust for flags.
+ MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ).
+ XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs.
+ REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX.
+ MOVQ BX, PTRACE_RAX(AX) // Save it.
+ MOVQ BX, PTRACE_ORIGRAX(AX)
+ MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX)
+ MOVQ 24(SP), CX; MOVQ CX, PTRACE_CS(AX)
+ MOVQ 32(SP), DX; MOVQ DX, PTRACE_FLAGS(AX)
+ MOVQ 40(SP), DI; MOVQ DI, PTRACE_RSP(AX)
+ MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX)
+
+ // Copy out and return.
+ MOVQ 0(SP), BX // Load vector.
+ MOVQ 8(SP), CX // Load error code.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version).
+ MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
+ MOVQ CX, CPU_ERROR_CODE(GS) // Set error code.
+ MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+ MOVQ BX, 24(SP) // Output vector.
+ RET
+
+kernel:
+ // As per above, we can save directly.
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
+ MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
+ REGISTERS_SAVE(GS, CPU_REGISTERS)
+ MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS)
+ MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS)
+ MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS)
+
+ // Set the error code and adjust the stack.
+ MOVQ 8(SP), AX // Load the error code.
+ MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU.
+ MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ 0(SP), BX // BX contains the vector.
+ ADDQ $48, SP // Drop the exception frame.
+
+ // Call the exception trampoline.
+ LOAD_KERNEL_STACK(GS)
+ MOVQ CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ BX // Second argument (vector).
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelException(SB) // Call the trampoline.
+ POPQ BX // Pop vector.
+ POPQ AX // Pop vCPU.
+ JMP ·resume(SB)
+
+#define EXCEPTION_WITH_ERROR(value, symbol) \
+TEXT symbol,NOSPLIT,$0; \
+ PUSHQ $value; \
+ JMP ·exception(SB);
+
+#define EXCEPTION_WITHOUT_ERROR(value, symbol) \
+TEXT symbol,NOSPLIT,$0; \
+ PUSHQ $0x0; \
+ PUSHQ $value; \
+ JMP ·exception(SB);
+
+EXCEPTION_WITHOUT_ERROR(DivideByZero, ·divideByZero(SB))
+EXCEPTION_WITHOUT_ERROR(Debug, ·debug(SB))
+EXCEPTION_WITHOUT_ERROR(NMI, ·nmi(SB))
+EXCEPTION_WITHOUT_ERROR(Breakpoint, ·breakpoint(SB))
+EXCEPTION_WITHOUT_ERROR(Overflow, ·overflow(SB))
+EXCEPTION_WITHOUT_ERROR(BoundRangeExceeded, ·boundRangeExceeded(SB))
+EXCEPTION_WITHOUT_ERROR(InvalidOpcode, ·invalidOpcode(SB))
+EXCEPTION_WITHOUT_ERROR(DeviceNotAvailable, ·deviceNotAvailable(SB))
+EXCEPTION_WITH_ERROR(DoubleFault, ·doubleFault(SB))
+EXCEPTION_WITHOUT_ERROR(CoprocessorSegmentOverrun, ·coprocessorSegmentOverrun(SB))
+EXCEPTION_WITH_ERROR(InvalidTSS, ·invalidTSS(SB))
+EXCEPTION_WITH_ERROR(SegmentNotPresent, ·segmentNotPresent(SB))
+EXCEPTION_WITH_ERROR(StackSegmentFault, ·stackSegmentFault(SB))
+EXCEPTION_WITH_ERROR(GeneralProtectionFault, ·generalProtectionFault(SB))
+EXCEPTION_WITH_ERROR(PageFault, ·pageFault(SB))
+EXCEPTION_WITHOUT_ERROR(X87FloatingPointException, ·x87FloatingPointException(SB))
+EXCEPTION_WITH_ERROR(AlignmentCheck, ·alignmentCheck(SB))
+EXCEPTION_WITHOUT_ERROR(MachineCheck, ·machineCheck(SB))
+EXCEPTION_WITHOUT_ERROR(SIMDFloatingPointException, ·simdFloatingPointException(SB))
+EXCEPTION_WITHOUT_ERROR(VirtualizationException, ·virtualizationException(SB))
+EXCEPTION_WITH_ERROR(SecurityException, ·securityException(SB))
+EXCEPTION_WITHOUT_ERROR(SyscallInt80, ·syscallInt80(SB))
diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go
new file mode 100644
index 000000000..62a93f3d6
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_arm64.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// The functions below are assembly stubs, defined in entry_arm64.s: the
+// exception vector entries for EL0 and EL1.
+
+func El1_sync_invalid()
+func El1_irq_invalid()
+func El1_fiq_invalid()
+func El1_error_invalid()
+
+func El1_sync()
+func El1_irq()
+func El1_fiq()
+func El1_error()
+
+func El0_sync()
+func El0_irq()
+func El0_fiq()
+func El0_error()
+
+func El0_sync_invalid()
+func El0_irq_invalid()
+func El0_fiq_invalid()
+func El0_error_invalid()
+
+func Vectors()
+
+// Start is the CPU entrypoint.
+//
+// The CPU state will be set to c.Registers().
+func Start()
+func kernelExitToEl1()
+
+func kernelExitToEl0()
+
+// Shutdown stops execution.
+func Shutdown()
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
new file mode 100644
index 000000000..2bc5f3ecd
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -0,0 +1,782 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// NB: Offsets are programmatically generated (see BUILD).
+//
+// This file is concatenated with the definitions.
+
+// ERET returns using the ELR and SPSR for the current exception level.
+#define ERET() \
+ WORD $0xd69f03e0
+
+// RSV_REG is a register that holds el1 information temporarily.
+#define RSV_REG R18_PLATFORM
+
+// RSV_REG_APP is a register that holds el0 information temporarily.
+#define RSV_REG_APP R9
+
+#define FPEN_NOTRAP 0x3
+#define FPEN_SHIFT 20
+
+#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+
+// Saves a register set.
+//
+// This is a macro because it may need to be executed in contexts where a
+// stack is not available for calls.
+//
+// The following registers are not saved: R9, R18.
+#define REGISTERS_SAVE(reg, offset) \
+ MOVD R0, offset+PTRACE_R0(reg); \
+ MOVD R1, offset+PTRACE_R1(reg); \
+ MOVD R2, offset+PTRACE_R2(reg); \
+ MOVD R3, offset+PTRACE_R3(reg); \
+ MOVD R4, offset+PTRACE_R4(reg); \
+ MOVD R5, offset+PTRACE_R5(reg); \
+ MOVD R6, offset+PTRACE_R6(reg); \
+ MOVD R7, offset+PTRACE_R7(reg); \
+ MOVD R8, offset+PTRACE_R8(reg); \
+ MOVD R10, offset+PTRACE_R10(reg); \
+ MOVD R11, offset+PTRACE_R11(reg); \
+ MOVD R12, offset+PTRACE_R12(reg); \
+ MOVD R13, offset+PTRACE_R13(reg); \
+ MOVD R14, offset+PTRACE_R14(reg); \
+ MOVD R15, offset+PTRACE_R15(reg); \
+ MOVD R16, offset+PTRACE_R16(reg); \
+ MOVD R17, offset+PTRACE_R17(reg); \
+ MOVD R19, offset+PTRACE_R19(reg); \
+ MOVD R20, offset+PTRACE_R20(reg); \
+ MOVD R21, offset+PTRACE_R21(reg); \
+ MOVD R22, offset+PTRACE_R22(reg); \
+ MOVD R23, offset+PTRACE_R23(reg); \
+ MOVD R24, offset+PTRACE_R24(reg); \
+ MOVD R25, offset+PTRACE_R25(reg); \
+ MOVD R26, offset+PTRACE_R26(reg); \
+ MOVD R27, offset+PTRACE_R27(reg); \
+ MOVD g, offset+PTRACE_R28(reg); \
+ MOVD R29, offset+PTRACE_R29(reg); \
+ MOVD R30, offset+PTRACE_R30(reg);
+
+// Loads a register set.
+//
+// This is a macro because it may need to be executed in contexts where a stack is
+// not available for calls.
+//
+// The following registers are not loaded: R9, R18.
+#define REGISTERS_LOAD(reg, offset) \
+ MOVD offset+PTRACE_R0(reg), R0; \
+ MOVD offset+PTRACE_R1(reg), R1; \
+ MOVD offset+PTRACE_R2(reg), R2; \
+ MOVD offset+PTRACE_R3(reg), R3; \
+ MOVD offset+PTRACE_R4(reg), R4; \
+ MOVD offset+PTRACE_R5(reg), R5; \
+ MOVD offset+PTRACE_R6(reg), R6; \
+ MOVD offset+PTRACE_R7(reg), R7; \
+ MOVD offset+PTRACE_R8(reg), R8; \
+ MOVD offset+PTRACE_R10(reg), R10; \
+ MOVD offset+PTRACE_R11(reg), R11; \
+ MOVD offset+PTRACE_R12(reg), R12; \
+ MOVD offset+PTRACE_R13(reg), R13; \
+ MOVD offset+PTRACE_R14(reg), R14; \
+ MOVD offset+PTRACE_R15(reg), R15; \
+ MOVD offset+PTRACE_R16(reg), R16; \
+ MOVD offset+PTRACE_R17(reg), R17; \
+ MOVD offset+PTRACE_R19(reg), R19; \
+ MOVD offset+PTRACE_R20(reg), R20; \
+ MOVD offset+PTRACE_R21(reg), R21; \
+ MOVD offset+PTRACE_R22(reg), R22; \
+ MOVD offset+PTRACE_R23(reg), R23; \
+ MOVD offset+PTRACE_R24(reg), R24; \
+ MOVD offset+PTRACE_R25(reg), R25; \
+ MOVD offset+PTRACE_R26(reg), R26; \
+ MOVD offset+PTRACE_R27(reg), R27; \
+ MOVD offset+PTRACE_R28(reg), g; \
+ MOVD offset+PTRACE_R29(reg), R29; \
+ MOVD offset+PTRACE_R30(reg), R30;
+
+// nop31Instructions emits 31 NOP instructions (0xd503201f is the AArch64 NOP encoding).
+#define nop31Instructions() \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f;
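+
+// Each vector slot in the exception table is 0x80 bytes, i.e. 32 4-byte
+// instructions: a vector entry is encoded below as one branch followed by
+// nop31Instructions() to fill out the slot exactly.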
+
+#define ESR_ELx_EC_UNKNOWN (0x00)
+#define ESR_ELx_EC_WFx (0x01)
+/* Unallocated EC: 0x02 */
+#define ESR_ELx_EC_CP15_32 (0x03)
+#define ESR_ELx_EC_CP15_64 (0x04)
+#define ESR_ELx_EC_CP14_MR (0x05)
+#define ESR_ELx_EC_CP14_LS (0x06)
+#define ESR_ELx_EC_FP_ASIMD (0x07)
+#define ESR_ELx_EC_CP10_ID (0x08) /* EL2 only */
+#define ESR_ELx_EC_PAC (0x09) /* EL2 and above */
+/* Unallocated EC: 0x0A - 0x0B */
+#define ESR_ELx_EC_CP14_64 (0x0C)
+/* Unallocated EC: 0x0d */
+#define ESR_ELx_EC_ILL (0x0E)
+/* Unallocated EC: 0x0F - 0x10 */
+#define ESR_ELx_EC_SVC32 (0x11)
+#define ESR_ELx_EC_HVC32 (0x12) /* EL2 only */
+#define ESR_ELx_EC_SMC32 (0x13) /* EL2 and above */
+/* Unallocated EC: 0x14 */
+#define ESR_ELx_EC_SVC64 (0x15)
+#define ESR_ELx_EC_HVC64 (0x16) /* EL2 and above */
+#define ESR_ELx_EC_SMC64 (0x17) /* EL2 and above */
+#define ESR_ELx_EC_SYS64 (0x18)
+#define ESR_ELx_EC_SVE (0x19)
+/* Unallocated EC: 0x1A - 0x1E */
+#define ESR_ELx_EC_IMP_DEF (0x1f) /* EL3 only */
+#define ESR_ELx_EC_IABT_LOW (0x20)
+#define ESR_ELx_EC_IABT_CUR (0x21)
+#define ESR_ELx_EC_PC_ALIGN (0x22)
+/* Unallocated EC: 0x23 */
+#define ESR_ELx_EC_DABT_LOW (0x24)
+#define ESR_ELx_EC_DABT_CUR (0x25)
+#define ESR_ELx_EC_SP_ALIGN (0x26)
+/* Unallocated EC: 0x27 */
+#define ESR_ELx_EC_FP_EXC32 (0x28)
+/* Unallocated EC: 0x29 - 0x2B */
+#define ESR_ELx_EC_FP_EXC64 (0x2C)
+/* Unallocated EC: 0x2D - 0x2E */
+#define ESR_ELx_EC_SERROR (0x2F)
+#define ESR_ELx_EC_BREAKPT_LOW (0x30)
+#define ESR_ELx_EC_BREAKPT_CUR (0x31)
+#define ESR_ELx_EC_SOFTSTP_LOW (0x32)
+#define ESR_ELx_EC_SOFTSTP_CUR (0x33)
+#define ESR_ELx_EC_WATCHPT_LOW (0x34)
+#define ESR_ELx_EC_WATCHPT_CUR (0x35)
+/* Unallocated EC: 0x36 - 0x37 */
+#define ESR_ELx_EC_BKPT32 (0x38)
+/* Unallocated EC: 0x39 */
+#define ESR_ELx_EC_VECTOR32 (0x3A) /* EL2 only */
+/* Unallocated EC: 0x3B */
+#define ESR_ELx_EC_BRK64 (0x3C)
+/* Unallocated EC: 0x3D - 0x3F */
+#define ESR_ELx_EC_MAX (0x3F)
+
+#define ESR_ELx_EC_SHIFT (26)
+#define ESR_ELx_EC_MASK (UL(0x3F) << ESR_ELx_EC_SHIFT)
+#define ESR_ELx_EC(esr) (((esr) & ESR_ELx_EC_MASK) >> ESR_ELx_EC_SHIFT)
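+
+// Worked example (illustrative value): an ESR_EL1 of 0x56000000 has
+// EC == 0x15 (ESR_ELx_EC_SVC64) and the IL bit set, i.e. an SVC taken from
+// AArch64 state.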
+
+#define ESR_ELx_IL_SHIFT (25)
+#define ESR_ELx_IL (UL(1) << ESR_ELx_IL_SHIFT)
+#define ESR_ELx_ISS_MASK (ESR_ELx_IL - 1)
+
+/* ISS field definitions shared by different classes */
+#define ESR_ELx_WNR_SHIFT (6)
+#define ESR_ELx_WNR (UL(1) << ESR_ELx_WNR_SHIFT)
+
+/* Asynchronous Error Type */
+#define ESR_ELx_IDS_SHIFT (24)
+#define ESR_ELx_IDS (UL(1) << ESR_ELx_IDS_SHIFT)
+#define ESR_ELx_AET_SHIFT (10)
+#define ESR_ELx_AET (UL(0x7) << ESR_ELx_AET_SHIFT)
+
+#define ESR_ELx_AET_UC (UL(0) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UEU (UL(1) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UEO (UL(2) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UER (UL(3) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_CE (UL(6) << ESR_ELx_AET_SHIFT)
+
+/* Shared ISS field definitions for Data/Instruction aborts */
+#define ESR_ELx_SET_SHIFT (11)
+#define ESR_ELx_SET_MASK (UL(3) << ESR_ELx_SET_SHIFT)
+#define ESR_ELx_FnV_SHIFT (10)
+#define ESR_ELx_FnV (UL(1) << ESR_ELx_FnV_SHIFT)
+#define ESR_ELx_EA_SHIFT (9)
+#define ESR_ELx_EA (UL(1) << ESR_ELx_EA_SHIFT)
+#define ESR_ELx_S1PTW_SHIFT (7)
+#define ESR_ELx_S1PTW (UL(1) << ESR_ELx_S1PTW_SHIFT)
+
+/* Shared ISS fault status code (IFSC/DFSC) for Data/Instruction aborts */
+#define ESR_ELx_FSC (0x3F)
+#define ESR_ELx_FSC_TYPE (0x3C)
+#define ESR_ELx_FSC_EXTABT (0x10)
+#define ESR_ELx_FSC_SERROR (0x11)
+#define ESR_ELx_FSC_ACCESS (0x08)
+#define ESR_ELx_FSC_FAULT (0x04)
+#define ESR_ELx_FSC_PERM (0x0C)
+
+/* ISS field definitions for Data Aborts */
+#define ESR_ELx_ISV_SHIFT (24)
+#define ESR_ELx_ISV (UL(1) << ESR_ELx_ISV_SHIFT)
+#define ESR_ELx_SAS_SHIFT (22)
+#define ESR_ELx_SAS (UL(3) << ESR_ELx_SAS_SHIFT)
+#define ESR_ELx_SSE_SHIFT (21)
+#define ESR_ELx_SSE (UL(1) << ESR_ELx_SSE_SHIFT)
+#define ESR_ELx_SRT_SHIFT (16)
+#define ESR_ELx_SRT_MASK (UL(0x1F) << ESR_ELx_SRT_SHIFT)
+#define ESR_ELx_SF_SHIFT (15)
+#define ESR_ELx_SF (UL(1) << ESR_ELx_SF_SHIFT)
+#define ESR_ELx_AR_SHIFT (14)
+#define ESR_ELx_AR (UL(1) << ESR_ELx_AR_SHIFT)
+#define ESR_ELx_CM_SHIFT (8)
+#define ESR_ELx_CM (UL(1) << ESR_ELx_CM_SHIFT)
+
+/* ISS field definitions for exceptions taken into Hyp */
+#define ESR_ELx_CV (UL(1) << 24)
+#define ESR_ELx_COND_SHIFT (20)
+#define ESR_ELx_COND_MASK (UL(0xF) << ESR_ELx_COND_SHIFT)
+#define ESR_ELx_WFx_ISS_TI (UL(1) << 0)
+#define ESR_ELx_WFx_ISS_WFI (UL(0) << 0)
+#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
+#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+
+// LOAD_KERNEL_ADDRESS loads a kernel address.
+#define LOAD_KERNEL_ADDRESS(from, to) \
+ MOVD from, to; \
+ ORR $0xffff000000000000, to, to;
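+
+// For example (illustrative address), 0x0000004000001000 maps to its
+// upper-half kernel alias 0xffff004000001000 after the ORR above.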
+
+// LOAD_KERNEL_STACK loads the kernel temporary stack.
+#define LOAD_KERNEL_STACK(from) \
+ LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \
+ MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \
+ MOVD RSV_REG, RSP; \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ ISB $15; \
+ DSB $15;
+
+// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
+#define SWITCH_TO_APP_PAGETABLE(from) \
+ MOVD CPU_TTBR0_APP(from), RSV_REG; \
+ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ ISB $15; \
+ DSB $15;
+
+// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
+#define SWITCH_TO_KVM_PAGETABLE(from) \
+ MOVD CPU_TTBR0_KVM(from), RSV_REG; \
+ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ ISB $15; \
+ DSB $15;
+
+// DAIFClr unmasks (enables) exceptions; DAIFSet masks (disables) them.
+#define IRQ_ENABLE \
+	MSR $2, DAIFClr;
+
+#define IRQ_DISABLE \
+	MSR $2, DAIFSet;
+
+#define VFP_ENABLE \
+ MOVD $FPEN_ENABLE, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+#define VFP_DISABLE \
+ MOVD $0x0, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+// KERNEL_ENTRY_FROM_EL0 is the entry code of the vcpu from el0 to el1.
+#define KERNEL_ENTRY_FROM_EL0 \
+ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
+	WORD $0xd538d092; \ // MRS TPIDR_EL1, R18; step2, switch to the kernel (KVM) pagetable.
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG); \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer.
+ REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context.
+ MOVD RSV_REG_APP, R20; \
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \
+ ADD $16, RSP, RSP; \
+ MOVD RSV_REG, PTRACE_R18(R20); \
+ MOVD RSV_REG_APP, PTRACE_R9(R20); \
+ MOVD R20, RSV_REG_APP; \
+ WORD $0xd5384003; \ // MRS SPSR_EL1, R3
+ MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \
+ MRS ELR_EL1, R3; \
+ MOVD R3, PTRACE_PC(RSV_REG_APP); \
+ WORD $0xd5384103; \ // MRS SP_EL0, R3
+ MOVD R3, PTRACE_SP(RSV_REG_APP);
+
+// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1.
+#define KERNEL_ENTRY_FROM_EL1 \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
+ WORD $0xd5384004; \ // MRS SPSR_EL1, R4
+ MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
+ MRS ELR_EL1, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \
+ MOVD RSP, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
+ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+
+// Halt halts execution.
+TEXT ·Halt(SB),NOSPLIT,$0
+ // Clear bluepill.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ CMP RSV_REG, R9
+ BNE mmio_exit
+ MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+mmio_exit:
+ // Disable fpsimd.
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, CPU_LAZY_VFP(RSV_REG)
+ VFP_DISABLE
+
+	// Trigger MMIO_EXIT/_KVM_HYPERCALL_VMEXIT.
+	//
+	// To keep things simple, the address of the exception vector table is
+	// used as the MMIO base address, so that an MMIO-exit can be triggered
+	// by forcibly writing to this read-only space.
+	// The region is long enough to encode a sufficient number of hypercall
+	// IDs, so the host user space can recover, from the faulting address,
+	// which hypercall was issued.
+ MRS VBAR_EL1, R9
+ MOVD R0, 0x0(R9)
+
+ RET
+
+// HaltAndResume halts execution and resumes via kernelExitToEl1.
+TEXT ·HaltAndResume(SB),NOSPLIT,$0
+ BL ·Halt(SB)
+ B ·kernelExitToEl1(SB) // Resume.
+
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resumes.
+TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
+// Shutdown stops the guest.
+TEXT ·Shutdown(SB),NOSPLIT,$0
+	// PSCI SYSTEM_RESET event.
+ MOVD $0x84000009, R0
+ HVC $0
+
+// See kernel.go.
+TEXT ·Current(SB),NOSPLIT,$0-8
+ MOVD CPU_SELF(RSV_REG), R8
+ MOVD R8, ret+0(FP)
+ RET
+
+#define STACK_FRAME_SIZE 16
+
+// kernelExitToEl0 is the entrypoint to the application at guest EL0.
+// It prepares the vCPU environment for the container application.
+TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
+ // Step1, save sentry context into memory.
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+
+ WORD $0xd5384003 // MRS SPSR_EL1, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
+ MOVD R30, CPU_REGISTERS+PTRACE_PC(RSV_REG)
+ MOVD RSP, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_SP(RSV_REG)
+
+ MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3
+
+ // Step2, save SP_EL1, PSTATE into kernel temporary stack.
+ // switch to temporary stack.
+ LOAD_KERNEL_STACK(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12
+ STP (R11, R12), 16*0(RSP)
+
+ MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12
+
+	// Step3, test the user pagetable.
+	// If the user pagetable is empty, this traps into el1_ia.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+	// If the pagetable is not empty, recover the kernel temporary stack.
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
+ // Step4, load app context pointer.
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP
+
+ // Step5, prepare the environment for container application.
+ // set sp_el0.
+ MOVD PTRACE_SP(RSV_REG_APP), R1
+ WORD $0xd5184101 //MSR R1, SP_EL0
+ // set pc.
+ MOVD PTRACE_PC(RSV_REG_APP), R1
+ MSR R1, ELR_EL1
+ // set pstate.
+ MOVD PTRACE_PSTATE(RSV_REG_APP), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ // RSV_REG & RSV_REG_APP will be loaded at the end.
+ REGISTERS_LOAD(RSV_REG_APP, 0)
+
+ // switch to user pagetable.
+ MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
+ MOVD PTRACE_R9(RSV_REG_APP), RSV_REG_APP
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
+ ERET()
+
+// kernelExitToEl1 is the entrypoint to the sentry at guest EL1.
+// It prepares the vCPU environment for the sentry.
+TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1
+ MSR R1, ELR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
+ MOVD R1, RSP
+
+ REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
+ MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
+
+ ERET()
+
+// Start is the CPU entrypoint.
+TEXT ·Start(SB),NOSPLIT,$0
+ IRQ_DISABLE
+ MOVD R8, RSV_REG
+ ORR $0xffff000000000000, RSV_REG, RSV_REG
+ WORD $0xd518d092 //MSR R18, TPIDR_EL1
+
+ B ·kernelExitToEl1(SB)
+
+// El1_sync_invalid is the handler for an invalid EL1_sync.
+TEXT ·El1_sync_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_irq_invalid is the handler for an invalid El1_irq.
+TEXT ·El1_irq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_fiq_invalid is the handler for an invalid El1_fiq.
+TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_error_invalid is the handler for an invalid El1_error.
+TEXT ·El1_error_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_sync is the handler for El1_sync.
+TEXT ·El1_sync(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL1
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ LSR $ESR_ELx_EC_SHIFT, R25, R24
+ CMP $ESR_ELx_EC_DABT_CUR, R24
+ BEQ el1_da
+ CMP $ESR_ELx_EC_IABT_CUR, R24
+ BEQ el1_ia
+ CMP $ESR_ELx_EC_SYS64, R24
+ BEQ el1_undef
+ CMP $ESR_ELx_EC_SP_ALIGN, R24
+ BEQ el1_sp_pc
+ CMP $ESR_ELx_EC_PC_ALIGN, R24
+ BEQ el1_sp_pc
+ CMP $ESR_ELx_EC_UNKNOWN, R24
+ BEQ el1_undef
+ CMP $ESR_ELx_EC_SVC64, R24
+ BEQ el1_svc
+ CMP $ESR_ELx_EC_BREAKPT_CUR, R24
+ BGE el1_dbg
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el1_fpsimd_acc
+ B el1_invalid
+
+el1_da:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+
+ MOVD $PageFault, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
+
+el1_ia:
+ B ·HaltAndResume(SB)
+
+el1_sp_pc:
+ B ·Shutdown(SB)
+
+el1_undef:
+ B ·Shutdown(SB)
+
+el1_svc:
+ MOVD $0, CPU_ERROR_CODE(RSV_REG)
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+ B ·HaltEl1SvcAndResume(SB)
+
+el1_dbg:
+ B ·Shutdown(SB)
+
+el1_fpsimd_acc:
+ VFP_ENABLE
+ B ·kernelExitToEl1(SB) // Resume.
+
+el1_invalid:
+ B ·Shutdown(SB)
+
+// El1_irq is the handler for El1_irq.
+TEXT ·El1_irq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_fiq is the handler for El1_fiq.
+TEXT ·El1_fiq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_error is the handler for El1_error.
+TEXT ·El1_error(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El0_sync is the handler for El0_sync.
+TEXT ·El0_sync(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL0
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ LSR $ESR_ELx_EC_SHIFT, R25, R24
+ CMP $ESR_ELx_EC_SVC64, R24
+ BEQ el0_svc
+ CMP $ESR_ELx_EC_DABT_LOW, R24
+ BEQ el0_da
+ CMP $ESR_ELx_EC_IABT_LOW, R24
+ BEQ el0_ia
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el0_fpsimd_acc
+ CMP $ESR_ELx_EC_SVE, R24
+ BEQ el0_sve_acc
+ CMP $ESR_ELx_EC_FP_EXC64, R24
+ BEQ el0_fpsimd_exc
+ CMP $ESR_ELx_EC_SP_ALIGN, R24
+ BEQ el0_sp_pc
+ CMP $ESR_ELx_EC_PC_ALIGN, R24
+ BEQ el0_sp_pc
+ CMP $ESR_ELx_EC_UNKNOWN, R24
+ BEQ el0_undef
+ CMP $ESR_ELx_EC_BREAKPT_LOW, R24
+ BGE el0_dbg
+ B el0_invalid
+
+el0_svc:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ MOVD $0, CPU_ERROR_CODE(RSV_REG) // Clear error code.
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $Syscall, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
+
+el0_da:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $PageFault, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
+
+el0_ia:
+ B ·Shutdown(SB)
+
+el0_fpsimd_acc:
+ B ·Shutdown(SB)
+
+el0_sve_acc:
+ B ·Shutdown(SB)
+
+el0_fpsimd_exc:
+ B ·Shutdown(SB)
+
+el0_sp_pc:
+ B ·Shutdown(SB)
+
+el0_undef:
+ B ·Shutdown(SB)
+
+el0_dbg:
+ B ·Shutdown(SB)
+
+el0_invalid:
+ B ·Shutdown(SB)
+
+TEXT ·El0_irq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_fiq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_error(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $VirtualizationException, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
+
+TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_irq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_error_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// Vectors implements the exception vector table.
+TEXT ·Vectors(SB),NOSPLIT,$0
+ B ·El1_sync_invalid(SB)
+ nop31Instructions()
+ B ·El1_irq_invalid(SB)
+ nop31Instructions()
+ B ·El1_fiq_invalid(SB)
+ nop31Instructions()
+ B ·El1_error_invalid(SB)
+ nop31Instructions()
+
+ B ·El1_sync(SB)
+ nop31Instructions()
+ B ·El1_irq(SB)
+ nop31Instructions()
+ B ·El1_fiq(SB)
+ nop31Instructions()
+ B ·El1_error(SB)
+ nop31Instructions()
+
+ B ·El0_sync(SB)
+ nop31Instructions()
+ B ·El0_irq(SB)
+ nop31Instructions()
+ B ·El0_fiq(SB)
+ nop31Instructions()
+ B ·El0_error(SB)
+ nop31Instructions()
+
+ B ·El0_sync_invalid(SB)
+ nop31Instructions()
+ B ·El0_irq_invalid(SB)
+ nop31Instructions()
+ B ·El0_fiq_invalid(SB)
+ nop31Instructions()
+ B ·El0_error_invalid(SB)
+ nop31Instructions()
+
+	// The exception vector table must be 2KB-aligned (the low 11 bits of
+	// its address must be zero). See the Linux source as a reference:
+	// arch/arm64/kernel/entry.S.
+	// For gVisor, the table is defined as 4KB in length, with the second
+	// 2KB filled with NOPs, so that the first 2KB can safely be moved to
+	// an address with the required alignment (see rewriteVectors).
+ WORD $0xd503201f //nop
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
new file mode 100644
index 000000000..549f3d228
--- /dev/null
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_binary")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "defs_impl_arm64",
+ out = "defs_impl_arm64.go",
+ package = "main",
+ template = "//pkg/sentry/platform/ring0:defs_arm64",
+)
+
+go_template_instance(
+ name = "defs_impl_amd64",
+ out = "defs_impl_amd64.go",
+ package = "main",
+ template = "//pkg/sentry/platform/ring0:defs_amd64",
+)
+
+go_binary(
+ name = "gen_offsets",
+ srcs = [
+ "defs_impl_amd64.go",
+ "defs_impl_arm64.go",
+ "main.go",
+ ],
+ visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
+ deps = [
+ "//pkg/cpuid",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/usermem",
+ ],
+)
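+
+# A usage sketch (assumed invocation): the binary prints C-style #define
+# lines to stdout, which the ring0 BUILD rules are expected to capture into
+# the generated offsets header, e.g.:
+#
+#   bazel run //pkg/sentry/platform/ring0/gen_offsets:gen_offsets > offsets.h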
diff --git a/pkg/sentry/platform/ring0/gen_offsets/main.go b/pkg/sentry/platform/ring0/gen_offsets/main.go
new file mode 100644
index 000000000..a4927da2f
--- /dev/null
+++ b/pkg/sentry/platform/ring0/gen_offsets/main.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary gen_offsets is a helper for generating offset headers.
+package main
+
+import (
+ "os"
+)
+
+func main() {
+ Emit(os.Stdout)
+}
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
new file mode 100644
index 000000000..021693791
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ring0
+
+// Init initializes a new kernel.
+//
+// N.B. the constraints on KernelOpts must be satisfied.
+//
+//go:nosplit
+func (k *Kernel) Init(opts KernelOpts) {
+ k.init(opts)
+}
+
+// Halt halts execution.
+func Halt()
+
+// defaultHooks implements hooks.
+type defaultHooks struct{}
+
+// KernelSyscall implements Hooks.KernelSyscall.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (defaultHooks) KernelSyscall() {
+ Halt()
+}
+
+// KernelException implements Hooks.KernelException.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (defaultHooks) KernelException(Vector) {
+ Halt()
+}
+
+// kernelSyscall is a trampoline.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func kernelSyscall(c *CPU) {
+ c.hooks.KernelSyscall()
+}
+
+// kernelException is a trampoline.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func kernelException(c *CPU, vector Vector) {
+ c.hooks.KernelException(vector)
+}
+
+// Init initializes a new CPU.
+//
+// Init allows embedding in other objects.
+func (c *CPU) Init(k *Kernel, hooks Hooks) {
+ c.self = c // Set self reference.
+ c.kernel = k // Set kernel reference.
+ c.init() // Perform architectural init.
+
+	// Set hooks, falling back to the defaults.
+ if hooks != nil {
+ c.hooks = hooks
+ } else {
+ c.hooks = defaultHooks{}
+ }
+}
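+
+// A minimal usage sketch (hypothetical caller; pt stands in for root page
+// tables built elsewhere):
+//
+//	k := &Kernel{}
+//	k.Init(KernelOpts{PageTables: pt})
+//	c := &CPU{}
+//	c.Init(k, nil) // nil falls back to defaultHooks, which simply Halt().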
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
new file mode 100644
index 000000000..d37981dbf
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -0,0 +1,281 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "encoding/binary"
+)
+
+// init initializes architecture-specific state.
+func (k *Kernel) init(opts KernelOpts) {
+ // Save the root page tables.
+ k.PageTables = opts.PageTables
+
+ // Setup the IDT, which is uniform.
+ for v, handler := range handlers {
+ // Allow Breakpoint and Overflow to be called from all
+ // privilege levels.
+ dpl := 0
+ if v == Breakpoint || v == Overflow {
+ dpl = 3
+ }
+		// Note that we set all traps to use the interrupt stack; the
+		// stack itself is configured below when setting up the TSS
+		// (see CPU.init).
+ k.globalIDT[v].setInterrupt(Kcode, uint64(kernelFunc(handler)), dpl, 1 /* ist */)
+ }
+}
+
+// init initializes architecture-specific state.
+func (c *CPU) init() {
+ // Null segment.
+ c.gdt[0].setNull()
+
+ // Kernel & user segments.
+ c.gdt[segKcode] = KernelCodeSegment
+ c.gdt[segKdata] = KernelDataSegment
+ c.gdt[segUcode32] = UserCodeSegment32
+ c.gdt[segUdata] = UserDataSegment
+ c.gdt[segUcode64] = UserCodeSegment64
+
+	// The task segment; this spans two entries.
+ tssBase, tssLimit, _ := c.TSS()
+ c.gdt[segTss].set(
+ uint32(tssBase),
+ uint32(tssLimit),
+ 0, // Privilege level zero.
+ SegmentDescriptorPresent|
+ SegmentDescriptorAccess|
+ SegmentDescriptorWrite|
+ SegmentDescriptorExecute)
+ c.gdt[segTssHi].setHi(uint32((tssBase) >> 32))
+
+ // Set the kernel stack pointer in the TSS (virtual address).
+ stackAddr := c.StackTop()
+ c.tss.rsp0Lo = uint32(stackAddr)
+ c.tss.rsp0Hi = uint32(stackAddr >> 32)
+ c.tss.ist1Lo = uint32(stackAddr)
+ c.tss.ist1Hi = uint32(stackAddr >> 32)
+
+ // Set the I/O bitmap base address beyond the last byte in the TSS
+ // to block access to the entire I/O address range.
+ //
+ // From section 18.5.2 "I/O Permission Bit Map" from Intel SDM vol1:
+ // I/O addresses not spanned by the map are treated as if they had set
+ // bits in the map.
+ c.tss.ioPerm = tssLimit + 1
+
+ // Permanently set the kernel segments.
+ c.registers.Cs = uint64(Kcode)
+ c.registers.Ds = uint64(Kdata)
+ c.registers.Es = uint64(Kdata)
+ c.registers.Ss = uint64(Kdata)
+ c.registers.Fs = uint64(Kdata)
+ c.registers.Gs = uint64(Kdata)
+
+ // Set mandatory flags.
+ c.registers.Eflags = KernelFlagsSet
+}
+
+// StackTop returns the kernel's stack address.
+//
+//go:nosplit
+func (c *CPU) StackTop() uint64 {
+ return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack))
+}
+
+// IDT returns the CPU's IDT base and limit.
+//
+//go:nosplit
+func (c *CPU) IDT() (uint64, uint16) {
+ return uint64(kernelAddr(&c.kernel.globalIDT[0])), uint16(binary.Size(&c.kernel.globalIDT) - 1)
+}
+
+// GDT returns the CPU's GDT base and limit.
+//
+//go:nosplit
+func (c *CPU) GDT() (uint64, uint16) {
+ return uint64(kernelAddr(&c.gdt[0])), uint16(8*segLast - 1)
+}
+
+// TSS returns the CPU's TSS base, limit and value.
+//
+//go:nosplit
+func (c *CPU) TSS() (uint64, uint16, *SegmentDescriptor) {
+ return uint64(kernelAddr(&c.tss)), uint16(binary.Size(&c.tss) - 1), &c.gdt[segTss]
+}
+
+// CR0 returns the CPU's CR0 value.
+//
+//go:nosplit
+func (c *CPU) CR0() uint64 {
+ return _CR0_PE | _CR0_PG | _CR0_AM | _CR0_ET
+}
+
+// CR4 returns the CPU's CR4 value.
+//
+//go:nosplit
+func (c *CPU) CR4() uint64 {
+ cr4 := uint64(_CR4_PAE | _CR4_PSE | _CR4_OSFXSR | _CR4_OSXMMEXCPT)
+ if hasPCID {
+ cr4 |= _CR4_PCIDE
+ }
+ if hasXSAVE {
+ cr4 |= _CR4_OSXSAVE
+ }
+ if hasSMEP {
+ cr4 |= _CR4_SMEP
+ }
+ if hasFSGSBASE {
+ cr4 |= _CR4_FSGSBASE
+ }
+ return cr4
+}
+
+// EFER returns the CPU's EFER value.
+//
+//go:nosplit
+func (c *CPU) EFER() uint64 {
+ return _EFER_LME | _EFER_LMA | _EFER_SCE | _EFER_NX
+}
+
+// IsCanonical indicates whether addr is canonical per the amd64 spec.
+//
+//go:nosplit
+func IsCanonical(addr uint64) bool {
+ return addr <= 0x00007fffffffffff || addr > 0xffff800000000000
+}
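+
+// For example, 0x00007fffffffffff (top of the lower half) is canonical,
+// while 0x0000800000000000, the first address of the non-canonical hole,
+// is not.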
+
+// SwitchToUser performs either a sysret or an iret.
+//
+// The return value is the vector that interrupted execution.
+//
+// This function will not split the stack. Callers will probably want to call
+// runtime.entersyscall (and pair with a call to runtime.exitsyscall) prior to
+// calling this function.
+//
+// When this is done, this region is quite sensitive to things like system
+// calls. After calling entersyscall, any memory used must have been allocated
+// and no function calls without go:nosplit are permitted. Any calls made here
+// are protected appropriately (e.g. IsCanonical and CR3).
+//
+// Also note that this function transitively depends on the compiler generating
+// code that uses IP-relative addressing inside of absolute addresses. That's
+// the case for amd64, but may not be the case for other architectures.
+//
+// Precondition: the Rip, Rsp, Fs and Gs registers must be canonical.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+ userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
+ kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)
+
+ // Sanitize registers.
+ regs := switchOpts.Registers
+ regs.Eflags &= ^uint64(UserFlagsClear)
+ regs.Eflags |= UserFlagsSet
+ regs.Cs = uint64(Ucode64) // Required for iret.
+ regs.Ss = uint64(Udata) // Ditto.
+
+ // Perform the switch.
+ swapgs() // GS will be swapped on return.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
+ jumpToKernel() // Switch to upper half.
+ writeCR3(uintptr(userCR3)) // Change to user address space.
+ if switchOpts.FullRestore {
+ vector = iret(c, regs)
+ } else {
+ vector = sysret(c, regs)
+ }
+ writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
+ jumpToUser() // Return to lower half.
+ SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
+ return
+}
+
+// start is the CPU entrypoint.
+//
+// This is called from the Start asm stub (see entry_amd64.go); on return the
+// registers in c.registers will be restored (not segments).
+//
+//go:nosplit
+func start(c *CPU) {
+ // Save per-cpu & FS segment.
+ WriteGS(kernelAddr(c))
+ WriteFS(uintptr(c.registers.Fs_base))
+
+ // Initialize floating point.
+ //
+	// Note that on Skylake, the valid XCR0 mask reported seems to be 0xff.
+ // This breaks down as:
+ //
+ // bit0 - x87
+ // bit1 - SSE
+ // bit2 - AVX
+ // bit3-4 - MPX
+ // bit5-7 - AVX512
+ //
+	// For some reason, enabling MPX & AVX512 on platforms that report them
+	// seems to cause a general protection fault. (Maybe there are some
+	// virtualization issues and these aren't exported to the guest cpuid.)
+ // This needs further investigation, but we can limit the floating
+ // point operations to x87, SSE & AVX for now.
+ fninit()
+ xsetbv(0, validXCR0Mask&0x7)
+
+ // Set the syscall target.
+ wrmsr(_MSR_LSTAR, kernelFunc(sysenter))
+ wrmsr(_MSR_SYSCALL_MASK, KernelFlagsClear|_RFLAGS_DF)
+
+ // NOTE: This depends on having the 64-bit segments immediately
+ // following the 32-bit user segments. This is simply the way the
+ // sysret instruction is designed to work (it assumes they follow).
+ wrmsr(_MSR_STAR, uintptr(uint64(Kcode)<<32|uint64(Ucode32)<<48))
+ wrmsr(_MSR_CSTAR, kernelFunc(sysenter))
+}
+
+// SetCPUIDFaulting sets CPUID faulting per the boolean value.
+//
+// True is returned if faulting could be set.
+//
+//go:nosplit
+func SetCPUIDFaulting(on bool) bool {
+ // Per the SDM (Vol 3, Table 2-43), PLATFORM_INFO bit 31 denotes support
+ // for CPUID faulting, and we enable and disable via the MISC_FEATURES MSR.
+ if rdmsr(_MSR_PLATFORM_INFO)&_PLATFORM_INFO_CPUID_FAULT != 0 {
+ features := rdmsr(_MSR_MISC_FEATURES)
+ if on {
+ features |= _MISC_FEATURE_CPUID_TRAP
+ } else {
+ features &^= _MISC_FEATURE_CPUID_TRAP
+ }
+ wrmsr(_MSR_MISC_FEATURES, features)
+ return true // Setting successful.
+ }
+ return false
+}
+
+// ReadCR2 reads the current CR2 value.
+//
+//go:nosplit
+func ReadCR2() uintptr {
+ return readCR2()
+}
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
new file mode 100644
index 000000000..ccacaea6b
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// HaltAndResume halts execution and resumes via kernelExitToEl1.
+//go:nosplit
+func HaltAndResume()
+
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resumes.
+//go:nosplit
+func HaltEl1SvcAndResume()
+
+// init initializes architecture-specific state.
+func (k *Kernel) init(opts KernelOpts) {
+ // Save the root page tables.
+ k.PageTables = opts.PageTables
+}
+
+// init initializes architecture-specific state.
+func (c *CPU) init() {
+	// Set the kernel stack pointer (virtual address).
+	c.registers.Sp = uint64(c.StackTop())
+}
+
+// StackTop returns the kernel's stack address.
+//
+//go:nosplit
+func (c *CPU) StackTop() uint64 {
+ return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack))
+}
+
+// IsCanonical indicates whether addr is canonical per the arm64 spec.
+//
+//go:nosplit
+func IsCanonical(addr uint64) bool {
+ return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
+}
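+
+// For example, 0x0000ffffffffffff is canonical, while 0x0001000000000000
+// falls into the non-canonical hole and is not.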
+
+// SwitchToUser performs an eret-based switch to user (EL0) and returns the
+// vector that interrupted execution.
+//
+//go:nosplit
+func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+	// Sanitize registers.
+	regs := switchOpts.Registers
+	regs.Pstate &= ^uint64(UserFlagsClear)
+	regs.Pstate |= UserFlagsSet
+
+	// Perform the switch.
+	kernelExitToEl0()
+	vector = c.vecCode
+	return
+}
diff --git a/pkg/sentry/platform/ring0/kernel_unsafe.go b/pkg/sentry/platform/ring0/kernel_unsafe.go
new file mode 100644
index 000000000..16955ad91
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ring0
+
+import (
+ "unsafe"
+)
+
+// eface mirrors runtime.eface.
+type eface struct {
+ typ uintptr
+ data unsafe.Pointer
+}
+
+// kernelAddr returns the kernel virtual address for the given object.
+//
+//go:nosplit
+func kernelAddr(obj interface{}) uintptr {
+ e := (*eface)(unsafe.Pointer(&obj))
+ return KernelStartAddress | uintptr(e.data)
+}
+
+// kernelFunc returns the address of the given function.
+//
+//go:nosplit
+func kernelFunc(fn func()) uintptr {
+ fnptr := (**uintptr)(unsafe.Pointer(&fn))
+ return KernelStartAddress | **fnptr
+}
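+
+// As an illustration, kernelAddr(&c.stack[0]) yields the CPU stack as seen
+// through the kernel's upper-half mapping; StackTop in kernel_amd64.go is
+// computed exactly this way.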
diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go
new file mode 100644
index 000000000..ca968a036
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_amd64.go
@@ -0,0 +1,131 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/cpuid"
+)
+
+// LoadFloatingPoint loads floating point state by the most efficient mechanism
+// available (set by Init).
+var LoadFloatingPoint func(*byte)
+
+// SaveFloatingPoint saves floating point state by the most efficient mechanism
+// available (set by Init).
+var SaveFloatingPoint func(*byte)
+
+// fxrstor uses fxrstor64 to load floating point state.
+func fxrstor(*byte)
+
+// xrstor uses xrstor to load floating point state.
+func xrstor(*byte)
+
+// fxsave uses fxsave64 to save floating point state.
+func fxsave(*byte)
+
+// xsave uses xsave to save floating point state.
+func xsave(*byte)
+
+// xsaveopt uses xsaveopt to save floating point state.
+func xsaveopt(*byte)
+
+// WriteFS sets the FS base address (set by Init).
+var WriteFS func(addr uintptr)
+
+// wrfsbase writes to the FS base address.
+func wrfsbase(addr uintptr)
+
+// wrfsmsr writes to the FS_BASE MSR.
+func wrfsmsr(addr uintptr)
+
+// WriteGS sets the GS base address (set by Init).
+var WriteGS func(addr uintptr)
+
+// wrgsbase writes to the GS base address.
+func wrgsbase(addr uintptr)
+
+// wrgsmsr writes to the GS_BASE MSR.
+func wrgsmsr(addr uintptr)
+
+// writeCR3 writes the CR3 value.
+func writeCR3(phys uintptr)
+
+// readCR3 reads the current CR3 value.
+func readCR3() uintptr
+
+// readCR2 reads the current CR2 value.
+func readCR2() uintptr
+
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
+// jumpToUser jumps to the user version of the current RIP.
+func jumpToUser()
+
+// fninit initializes the floating point unit.
+func fninit()
+
+// xsetbv writes to an extended control register.
+func xsetbv(reg, value uintptr)
+
+// xgetbv reads an extended control register.
+func xgetbv(reg uintptr) uintptr
+
+// wrmsr writes to the given MSR.
+func wrmsr(reg, value uintptr)
+
+// rdmsr reads the given MSR.
+func rdmsr(reg uintptr) uintptr
+
+// Mostly-constants set by Init.
+var (
+ hasSMEP bool
+ hasPCID bool
+ hasXSAVEOPT bool
+ hasXSAVE bool
+ hasFSGSBASE bool
+ validXCR0Mask uintptr
+)
+
+// Init sets function pointers based on architectural features.
+//
+// This must be called prior to using ring0.
+func Init(featureSet *cpuid.FeatureSet) {
+ hasSMEP = featureSet.HasFeature(cpuid.X86FeatureSMEP)
+ hasPCID = featureSet.HasFeature(cpuid.X86FeaturePCID)
+ hasXSAVEOPT = featureSet.UseXsaveopt()
+ hasXSAVE = featureSet.UseXsave()
+ hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase)
+ validXCR0Mask = uintptr(featureSet.ValidXCR0Mask())
+ if hasXSAVEOPT {
+ SaveFloatingPoint = xsaveopt
+ LoadFloatingPoint = xrstor
+ } else if hasXSAVE {
+ SaveFloatingPoint = xsave
+ LoadFloatingPoint = xrstor
+ } else {
+ SaveFloatingPoint = fxsave
+ LoadFloatingPoint = fxrstor
+ }
+ if hasFSGSBASE {
+ WriteFS = wrfsbase
+ WriteGS = wrgsbase
+ } else {
+ WriteFS = wrfsmsr
+ WriteGS = wrgsmsr
+ }
+}
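+
+// A usage sketch (assuming the host's feature set is the desired source):
+//
+//	ring0.Init(cpuid.HostFeatureSet())
+//
+// This must run before any Kernel or CPU is used.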
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
new file mode 100644
index 000000000..75d742750
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -0,0 +1,247 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// fxrstor loads floating point state.
+//
+// The code corresponds to:
+//
+// fxrstor64 (%rbx)
+//
+TEXT ·fxrstor(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), BX
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x0b;
+ RET
+
+// xrstor loads floating point state.
+//
+// The code corresponds to:
+//
+// xrstor (%rdi)
+//
+TEXT ·xrstor(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), DI
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f;
+ RET
+
+// fxsave saves floating point state.
+//
+// The code corresponds to:
+//
+// fxsave64 (%rbx)
+//
+TEXT ·fxsave(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), BX
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x03;
+ RET
+
+// xsave saves floating point state.
+//
+// The code corresponds to:
+//
+// xsave (%rdi)
+//
+TEXT ·xsave(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), DI
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27;
+ RET
+
+// xsaveopt saves floating point state.
+//
+// The code corresponds to:
+//
+// xsaveopt (%rdi)
+//
+TEXT ·xsaveopt(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), DI
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37;
+ RET
+
+// wrfsbase writes to the FS base.
+//
+// The code corresponds to:
+//
+// wrfsbase %rax
+//
+TEXT ·wrfsbase(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0xd0;
+ RET
+
+// wrfsmsr writes to the FSBASE MSR.
+//
+// The code corresponds to:
+//
+// wrmsr (writes EDX:EAX to the MSR in ECX)
+//
+TEXT ·wrfsmsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ MOVQ AX, DX
+ SHRQ $32, DX
+ MOVQ $0xc0000100, CX // MSR_FS_BASE
+ BYTE $0x0f; BYTE $0x30;
+ RET
+
+// wrgsbase writes to the GS base.
+//
+// The code corresponds to:
+//
+// wrgsbase %rax
+//
+TEXT ·wrgsbase(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0xd8;
+ RET
+
+// wrgsmsr writes to the GSBASE MSR.
+//
+// See wrfsmsr.
+TEXT ·wrgsmsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ MOVQ AX, DX
+ SHRQ $32, DX
+ MOVQ $0xc0000101, CX // MSR_GS_BASE
+ BYTE $0x0f; BYTE $0x30; // WRMSR
+ RET
+
+// jumpToUser changes execution to the user address.
+//
+// This works by changing the return value to the user version.
+TEXT ·jumpToUser(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ MOVQ ·KernelStartAddress(SB), BX
+ NOTQ BX
+ ANDQ BX, SP // Switch the stack.
+ ANDQ BX, BP // Switch the frame pointer.
+ ANDQ BX, AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
+// jumpToKernel changes execution to the kernel address space.
+//
+// This works by changing the return value to the kernel version.
+TEXT ·jumpToKernel(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ MOVQ ·KernelStartAddress(SB), BX
+ ORQ BX, SP // Switch the stack.
+ ORQ BX, BP // Switch the frame pointer.
+ ORQ BX, AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
+// writeCR3 writes the given CR3 value.
+//
+// The code corresponds to:
+//
+// mov %rax, %cr3
+//
+TEXT ·writeCR3(SB),NOSPLIT,$0-8
+ MOVQ cr3+0(FP), AX
+ BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
+ RET
+
+// readCR3 reads the current CR3 value.
+//
+// The code corresponds to:
+//
+// mov %cr3, %rax
+//
+TEXT ·readCR3(SB),NOSPLIT,$0-8
+ BYTE $0x0f; BYTE $0x20; BYTE $0xd8;
+ MOVQ AX, ret+0(FP)
+ RET
+
+// readCR2 reads the current CR2 value.
+//
+// The code corresponds to:
+//
+// mov %cr2, %rax
+//
+TEXT ·readCR2(SB),NOSPLIT,$0-8
+ BYTE $0x0f; BYTE $0x20; BYTE $0xd0;
+ MOVQ AX, ret+0(FP)
+ RET
+
+// fninit initializes the floating point unit.
+//
+// The code corresponds to:
+//
+// fninit
+TEXT ·fninit(SB),NOSPLIT,$0
+ BYTE $0xdb; BYTE $0xe3;
+ RET
+
+// xsetbv writes to an extended control register.
+//
+// The code corresponds to:
+//
+// xsetbv
+//
+TEXT ·xsetbv(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ MOVL value+8(FP), AX
+ MOVL value+12(FP), DX
+ BYTE $0x0f; BYTE $0x01; BYTE $0xd1;
+ RET
+
+// xgetbv reads an extended control register.
+//
+// The code corresponds to:
+//
+// xgetbv
+//
+TEXT ·xgetbv(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ BYTE $0x0f; BYTE $0x01; BYTE $0xd0;
+ MOVL AX, ret+8(FP)
+ MOVL DX, ret+12(FP)
+ RET
+
+// wrmsr writes to a model-specific register.
+//
+// The code corresponds to:
+//
+// wrmsr
+//
+TEXT ·wrmsr(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ MOVL value+8(FP), AX
+ MOVL value+12(FP), DX
+ BYTE $0x0f; BYTE $0x30;
+ RET
+
+// rdmsr reads a model-specific register.
+//
+// The code corresponds to:
+//
+// rdmsr
+//
+TEXT ·rdmsr(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ BYTE $0x0f; BYTE $0x32;
+ MOVL AX, ret+8(FP)
+ MOVL DX, ret+12(FP)
+ RET
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
new file mode 100644
index 000000000..a6345010d
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -0,0 +1,52 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// CPACREL1 returns the value of the CPACR_EL1 register.
+func CPACREL1() (value uintptr)
+
+// GetFPCR returns the value of the FPCR register.
+func GetFPCR() (value uintptr)
+
+// SetFPCR writes the FPCR value.
+func SetFPCR(value uintptr)
+
+// GetFPSR returns the value of the FPSR register.
+func GetFPSR() (value uintptr)
+
+// SetFPSR writes the FPSR value.
+func SetFPSR(value uintptr)
+
+// SaveVRegs saves V0-V31 registers.
+// V0-V31: 32 128-bit registers for floating point and simd.
+func SaveVRegs(*byte)
+
+// LoadVRegs loads V0-V31 registers.
+func LoadVRegs(*byte)
+
+// GetTLS returns the value of TPIDR_EL0 register.
+func GetTLS() (value uint64)
+
+// SetTLS writes the TPIDR_EL0 value.
+func SetTLS(value uint64)
+
+// Init sets function pointers based on architectural features.
+//
+// This must be called prior to using ring0.
+func Init() {
+ rewriteVectors()
+}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
new file mode 100644
index 000000000..b63e14b41
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -0,0 +1,131 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+TEXT ·GetTLS(SB),NOSPLIT,$0-8
+ MRS TPIDR_EL0, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetTLS(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ MSR R1, TPIDR_EL0
+ RET
+
+TEXT ·CPACREL1(SB),NOSPLIT,$0-8
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·GetFPCR(SB),NOSPLIT,$0-8
+	WORD $0xd53b4401 // MRS FPCR, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·GetFPSR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4421 // MRS FPSR, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetFPCR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+	WORD $0xd51b4401 // MSR R1, FPCR
+ RET
+
+TEXT ·SetFPSR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4421 // MSR R1, FPSR
+ RET
+
+TEXT ·SaveVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD F0, 16*1(R0)
+ FMOVD F1, 16*2(R0)
+ FMOVD F2, 16*3(R0)
+ FMOVD F3, 16*4(R0)
+ FMOVD F4, 16*5(R0)
+ FMOVD F5, 16*6(R0)
+ FMOVD F6, 16*7(R0)
+ FMOVD F7, 16*8(R0)
+ FMOVD F8, 16*9(R0)
+ FMOVD F9, 16*10(R0)
+ FMOVD F10, 16*11(R0)
+ FMOVD F11, 16*12(R0)
+ FMOVD F12, 16*13(R0)
+ FMOVD F13, 16*14(R0)
+ FMOVD F14, 16*15(R0)
+ FMOVD F15, 16*16(R0)
+ FMOVD F16, 16*17(R0)
+ FMOVD F17, 16*18(R0)
+ FMOVD F18, 16*19(R0)
+ FMOVD F19, 16*20(R0)
+ FMOVD F20, 16*21(R0)
+ FMOVD F21, 16*22(R0)
+ FMOVD F22, 16*23(R0)
+ FMOVD F23, 16*24(R0)
+ FMOVD F24, 16*25(R0)
+ FMOVD F25, 16*26(R0)
+ FMOVD F26, 16*27(R0)
+ FMOVD F27, 16*28(R0)
+ FMOVD F28, 16*29(R0)
+ FMOVD F29, 16*30(R0)
+ FMOVD F30, 16*31(R0)
+ FMOVD F31, 16*32(R0)
+ ISB $15
+
+ RET
+
+TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD 16*1(R0), F0
+ FMOVD 16*2(R0), F1
+ FMOVD 16*3(R0), F2
+ FMOVD 16*4(R0), F3
+ FMOVD 16*5(R0), F4
+ FMOVD 16*6(R0), F5
+ FMOVD 16*7(R0), F6
+ FMOVD 16*8(R0), F7
+ FMOVD 16*9(R0), F8
+ FMOVD 16*10(R0), F9
+ FMOVD 16*11(R0), F10
+ FMOVD 16*12(R0), F11
+ FMOVD 16*13(R0), F12
+ FMOVD 16*14(R0), F13
+ FMOVD 16*15(R0), F14
+ FMOVD 16*16(R0), F15
+ FMOVD 16*17(R0), F16
+ FMOVD 16*18(R0), F17
+ FMOVD 16*19(R0), F18
+ FMOVD 16*20(R0), F19
+ FMOVD 16*21(R0), F20
+ FMOVD 16*22(R0), F21
+ FMOVD 16*23(R0), F22
+ FMOVD 16*24(R0), F23
+ FMOVD 16*25(R0), F24
+ FMOVD 16*26(R0), F25
+ FMOVD 16*27(R0), F26
+ FMOVD 16*28(R0), F27
+ FMOVD 16*29(R0), F28
+ FMOVD 16*30(R0), F29
+ FMOVD 16*31(R0), F30
+ FMOVD 16*32(R0), F31
+ ISB $15
+
+ RET
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
new file mode 100644
index 000000000..c05166fea
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
@@ -0,0 +1,108 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ nopInstruction = 0xd503201f
+ instSize = unsafe.Sizeof(uint32(0))
+ vectorsRawLen = 0x800
+)
+
+func unsafeSlice(addr uintptr, length int) (slice []uint32) {
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ hdr.Data = addr
+ hdr.Len = length / int(instSize)
+ hdr.Cap = length / int(instSize)
+ return slice
+}
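+
+// Note: on Go 1.17+ the same slice could be built with
+// unsafe.Slice((*uint32)(unsafe.Pointer(addr)), length/int(instSize));
+// the reflect.SliceHeader form above also works on older toolchains.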
+
+// Workaround: move ring0.Vectors() to an address whose low 11 bits are zero
+// (2KB alignment).
+//
+// According to the Arm64 architecture documentation, the start address of
+// the exception vector table must be 2KB-aligned.
+// See the Linux kernel as a reference: arch/arm64/kernel/entry.S.
+// However, Go provides no way to place a function at a specific alignment;
+// this question has been raised with the Go community:
+// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
+// This function can be removed once Go supports the feature.
+//
+// Two jobs are performed here:
+// 1. Move the exception vector table to an address with the required alignment.
+// 2. Adjust the branch offset of each relocated instruction.
+func rewriteVectors() {
+ vectorsBegin := reflect.ValueOf(Vectors).Pointer()
+
+	// The exception vector table must be 2KB-aligned and is 0x800 bytes
+	// long. See the documentation as a reference:
+	// https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
+	//
+	// Since Go does not allow placing a function at a specific address,
+	// the table is defined as 4KB long, with the second 2KB filled with
+	// NOPs, so that the first 2KB can safely be moved to an address with
+	// the required alignment.
+	//
+	// The prerequisites for this function to work correctly are:
+	// vectorsSafeLen >= 0x1000
+	// vectorsRawLen = 0x800
+ vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
+ if vectorsSafeLen < 2*vectorsRawLen {
+ panic("Can't update vectors")
+ }
+
+ vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
+ vectorsRawLen32 := vectorsRawLen / int(instSize)
+
+ offset := vectorsBegin & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
+
+ _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
+ if errno != 0 {
+ panic(errno.Error())
+ }
+
+ offset = offset / instSize // By index, not bytes.
+	// Move the exception vector table to the aligned address. The copy
+	// runs backwards, memmove-style, because the ranges may overlap.
+ for i := 1; i <= vectorsRawLen32; i++ {
+ vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
+ }
+
+	// Adjust each branch offset, since the instructions were moved forward.
+ for i := 0; i < vectorsRawLen32; i++ {
+ if vectorsSafeTable[int(offset)+i] != nopInstruction {
+ vectorsSafeTable[int(offset)+i] -= uint32(offset)
+ }
+ }
+
+ _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
+ if errno != 0 {
+ panic(errno.Error())
+ }
+}
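+
+// A worked example of the offset math above (illustrative address): if
+// Vectors begins at an address ending in 0x880, then offset = 0x800 -
+// (0x880 & 0x7ff) = 0x780 bytes, i.e. 0x1e0 instructions. The 2KB table is
+// copied forward to the next 2KB boundary, and each relocated branch
+// immediate is reduced by 0x1e0 so it still targets the same handler.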
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
new file mode 100644
index 000000000..b8ab120a0
--- /dev/null
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// Emit prints architecture-specific offsets.
+func Emit(w io.Writer) {
+ fmt.Fprintf(w, "// Automatically generated, do not edit.\n")
+
+ c := &CPU{}
+ fmt.Fprintf(w, "\n// CPU offsets.\n")
+ fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
+ fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+
+ fmt.Fprintf(w, "\n// Bits.\n")
+ fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF)
+ fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
+
+ fmt.Fprintf(w, "\n// Vectors.\n")
+ fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero)
+ fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug)
+ fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI)
+ fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint)
+ fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow)
+ fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded)
+ fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode)
+ fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable)
+ fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault)
+ fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun)
+ fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS)
+ fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent)
+ fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault)
+ fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault)
+ fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
+ fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException)
+ fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck)
+ fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck)
+ fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException)
+ fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException)
+ fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException)
+ fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80)
+ fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
+
+ p := &arch.Registers{}
+ fmt.Fprintf(w, "\n// Ptrace registers.\n")
+ fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer())
+}
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
new file mode 100644
index 000000000..f3de962f0
--- /dev/null
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -0,0 +1,127 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// Emit prints architecture-specific offsets.
+func Emit(w io.Writer) {
+ fmt.Fprintf(w, "// Automatically generated, do not edit.\n")
+
+ c := &CPU{}
+ fmt.Fprintf(w, "\n// CPU offsets.\n")
+ fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
+ fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_FAULT_ADDR 0x%02x\n", reflect.ValueOf(&c.faultAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_TTBR0_KVM 0x%02x\n", reflect.ValueOf(&c.ttbr0Kvm).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer())
+
+ fmt.Fprintf(w, "\n// Bits.\n")
+ fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
+
+ fmt.Fprintf(w, "\n// Vectors.\n")
+ fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid)
+ fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid)
+ fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid)
+ fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid)
+
+ fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync)
+ fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq)
+ fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq)
+ fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error)
+
+ fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync)
+ fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq)
+ fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq)
+ fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error)
+
+ fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid)
+ fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid)
+ fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid)
+ fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid)
+
+ fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da)
+ fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia)
+ fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc)
+ fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef)
+ fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg)
+ fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv)
+
+ fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc)
+ fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da)
+ fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia)
+ fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc)
+ fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc)
+ fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys)
+ fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc)
+ fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef)
+ fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg)
+ fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv)
+
+ fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
+ fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
+ fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException)
+
+ p := &arch.Registers{}
+ fmt.Fprintf(w, "\n// Ptrace registers.\n")
+ fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R2 0x%02x\n", reflect.ValueOf(&p.Regs[2]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R3 0x%02x\n", reflect.ValueOf(&p.Regs[3]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R4 0x%02x\n", reflect.ValueOf(&p.Regs[4]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R5 0x%02x\n", reflect.ValueOf(&p.Regs[5]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R6 0x%02x\n", reflect.ValueOf(&p.Regs[6]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R7 0x%02x\n", reflect.ValueOf(&p.Regs[7]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.Regs[8]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.Regs[9]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.Regs[10]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.Regs[11]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.Regs[12]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.Regs[13]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.Regs[14]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.Regs[15]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R16 0x%02x\n", reflect.ValueOf(&p.Regs[16]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R17 0x%02x\n", reflect.ValueOf(&p.Regs[17]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R18 0x%02x\n", reflect.ValueOf(&p.Regs[18]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R19 0x%02x\n", reflect.ValueOf(&p.Regs[19]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R20 0x%02x\n", reflect.ValueOf(&p.Regs[20]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R21 0x%02x\n", reflect.ValueOf(&p.Regs[21]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R22 0x%02x\n", reflect.ValueOf(&p.Regs[22]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R23 0x%02x\n", reflect.ValueOf(&p.Regs[23]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R24 0x%02x\n", reflect.ValueOf(&p.Regs[24]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R25 0x%02x\n", reflect.ValueOf(&p.Regs[25]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R26 0x%02x\n", reflect.ValueOf(&p.Regs[26]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R27 0x%02x\n", reflect.ValueOf(&p.Regs[27]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R28 0x%02x\n", reflect.ValueOf(&p.Regs[28]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R29 0x%02x\n", reflect.ValueOf(&p.Regs[29]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R30 0x%02x\n", reflect.ValueOf(&p.Regs[30]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer())
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
new file mode 100644
index 000000000..16d5f478b
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -0,0 +1,115 @@
+load("//tools:defs.bzl", "go_library", "go_test", "select_arch")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template(
+ name = "generic_walker",
+ srcs = select_arch(
+ amd64 = ["walker_amd64.go"],
+ arm64 = ["walker_arm64.go"],
+ ),
+ opt_types = [
+ "Visitor",
+ ],
+ visibility = [":__pkg__"],
+)
+
+go_template_instance(
+ name = "walker_map",
+ out = "walker_map.go",
+ package = "pagetables",
+ prefix = "map",
+ template = ":generic_walker",
+ types = {
+ "Visitor": "mapVisitor",
+ },
+)
+
+go_template_instance(
+ name = "walker_unmap",
+ out = "walker_unmap.go",
+ package = "pagetables",
+ prefix = "unmap",
+ template = ":generic_walker",
+ types = {
+ "Visitor": "unmapVisitor",
+ },
+)
+
+go_template_instance(
+ name = "walker_lookup",
+ out = "walker_lookup.go",
+ package = "pagetables",
+ prefix = "lookup",
+ template = ":generic_walker",
+ types = {
+ "Visitor": "lookupVisitor",
+ },
+)
+
+go_template_instance(
+ name = "walker_empty",
+ out = "walker_empty.go",
+ package = "pagetables",
+ prefix = "empty",
+ template = ":generic_walker",
+ types = {
+ "Visitor": "emptyVisitor",
+ },
+)
+
+go_template_instance(
+ name = "walker_check",
+ out = "walker_check.go",
+ package = "pagetables",
+ prefix = "check",
+ template = ":generic_walker",
+ types = {
+ "Visitor": "checkVisitor",
+ },
+)
+
+go_library(
+ name = "pagetables",
+ srcs = [
+ "allocator.go",
+ "allocator_unsafe.go",
+ "pagetables.go",
+ "pagetables_aarch64.go",
+ "pagetables_amd64.go",
+ "pagetables_arm64.go",
+ "pagetables_x86.go",
+ "pcids.go",
+ "pcids_aarch64.go",
+ "pcids_aarch64.s",
+ "pcids_x86.go",
+ "walker_amd64.go",
+ "walker_arm64.go",
+ "walker_empty.go",
+ "walker_lookup.go",
+ "walker_map.go",
+ "walker_unmap.go",
+ ],
+ visibility = [
+ "//pkg/sentry/platform/kvm:__subpackages__",
+ "//pkg/sentry/platform/ring0:__subpackages__",
+ ],
+ deps = [
+ "//pkg/sync",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "pagetables_test",
+ size = "small",
+ srcs = [
+ "pagetables_amd64_test.go",
+ "pagetables_arm64_test.go",
+ "pagetables_test.go",
+ "walker_check.go",
+ ],
+ library = ":pagetables",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go
new file mode 100644
index 000000000..8d75b7599
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/allocator.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Allocator is used to allocate and map PTEs.
+//
+// Note that allocators may be called concurrently.
+type Allocator interface {
+ // NewPTEs returns a new set of PTEs and their physical address.
+ NewPTEs() *PTEs
+
+ // PhysicalFor gives the physical address for a set of PTEs.
+ PhysicalFor(ptes *PTEs) uintptr
+
+ // LookupPTEs looks up PTEs by physical address.
+ LookupPTEs(physical uintptr) *PTEs
+
+ // FreePTEs marks a set of PTEs as freed; they may not be available
+ // for use again until Recycle is called, below.
+ FreePTEs(ptes *PTEs)
+
+ // Recycle makes freed PTEs available for use again.
+ Recycle()
+}
+
+// RuntimeAllocator is a trivial allocator.
+type RuntimeAllocator struct {
+ // used is the set of PTEs that have been allocated. This includes any
+ // PTEs that may be in the pool below. PTEs are only freed from this
+ // map by the Drain call.
+ //
+ // This exists to prevent accidental garbage collection.
+ used map[*PTEs]struct{}
+
+ // pool is the set of free-to-use PTEs.
+ pool []*PTEs
+
+ // freed is the set of recently-freed PTEs.
+ freed []*PTEs
+}
+
+// NewRuntimeAllocator returns an allocator that uses runtime allocation.
+func NewRuntimeAllocator() *RuntimeAllocator {
+ r := new(RuntimeAllocator)
+ r.Init()
+ return r
+}
+
+// Init initializes a RuntimeAllocator.
+func (r *RuntimeAllocator) Init() {
+ r.used = make(map[*PTEs]struct{})
+}
+
+// Recycle returns freed pages to the pool.
+func (r *RuntimeAllocator) Recycle() {
+ r.pool = append(r.pool, r.freed...)
+ r.freed = r.freed[:0]
+}
+
+// Drain empties the pool.
+func (r *RuntimeAllocator) Drain() {
+ r.Recycle()
+ for i, ptes := range r.pool {
+ // Zap the entry in the underlying array to ensure that it can
+ // be properly garbage collected.
+ r.pool[i] = nil
+ // Similarly, free the reference held by the used map (these
+ // also apply for the pool entries).
+ delete(r.used, ptes)
+ }
+ r.pool = r.pool[:0]
+}
+
+// NewPTEs implements Allocator.NewPTEs.
+//
+// Note that the "physical" address here is actually the virtual address of the
+// PTEs structure. The entries are tracked only to avoid garbage collection.
+//
+// This is guaranteed not to split as long as the pool is sufficiently full.
+//
+//go:nosplit
+func (r *RuntimeAllocator) NewPTEs() *PTEs {
+ // Pull from the pool if we can.
+ if len(r.pool) > 0 {
+ ptes := r.pool[len(r.pool)-1]
+ r.pool = r.pool[:len(r.pool)-1]
+ return ptes
+ }
+
+ // Allocate a new entry.
+ ptes := newAlignedPTEs()
+ r.used[ptes] = struct{}{}
+ return ptes
+}
+
+// PhysicalFor returns the physical address for the given PTEs.
+//
+//go:nosplit
+func (r *RuntimeAllocator) PhysicalFor(ptes *PTEs) uintptr {
+ return physicalFor(ptes)
+}
+
+// LookupPTEs implements Allocator.LookupPTEs.
+//
+//go:nosplit
+func (r *RuntimeAllocator) LookupPTEs(physical uintptr) *PTEs {
+ return fromPhysical(physical)
+}
+
+// FreePTEs implements Allocator.FreePTEs.
+//
+//go:nosplit
+func (r *RuntimeAllocator) FreePTEs(ptes *PTEs) {
+ r.freed = append(r.freed, ptes)
+}
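
(A brief lifecycle sketch for the allocator above; illustrative only, using the exported API defined in this file.)

a := pagetables.NewRuntimeAllocator()

ptes := a.NewPTEs()         // Allocate a page-aligned PTE set.
phys := a.PhysicalFor(ptes) // The "physical" (here: virtual) address.
_ = a.LookupPTEs(phys)      // Round-trips back to ptes.

a.FreePTEs(ptes) // Marked freed, but not yet reusable...
a.Recycle()      // ...now returned to the pool for NewPTEs.
a.Drain()        // Empties the pool and drops used-map references.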
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
new file mode 100644
index 000000000..d08bfdeb3
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+import (
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// newAlignedPTEs returns a set of aligned PTEs.
+func newAlignedPTEs() *PTEs {
+ ptes := new(PTEs)
+ offset := physicalFor(ptes) & (usermem.PageSize - 1)
+ if offset == 0 {
+ // Already aligned.
+ return ptes
+ }
+
+ // Need to force an aligned allocation.
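+ // Over-allocating (2*PageSize)-1 bytes guarantees that the buffer
+ // spans a page boundary, so some offset in [0, PageSize) yields a
+ // fully contained, page-aligned PTEs value wherever the runtime
+ // places the allocation.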
+ unaligned := make([]byte, (2*usermem.PageSize)-1)
+ offset = uintptr(unsafe.Pointer(&unaligned[0])) & (usermem.PageSize - 1)
+ if offset != 0 {
+ offset = usermem.PageSize - offset
+ }
+ return (*PTEs)(unsafe.Pointer(&unaligned[offset]))
+}
+
+// physicalFor returns the "physical" address for PTEs.
+//
+//go:nosplit
+func physicalFor(ptes *PTEs) uintptr {
+ return uintptr(unsafe.Pointer(ptes))
+}
+
+// fromPhysical returns the PTEs from the "physical" address.
+//
+//go:nosplit
+func fromPhysical(physical uintptr) *PTEs {
+ return (*PTEs)(unsafe.Pointer(physical))
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
new file mode 100644
index 000000000..7f18ac296
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -0,0 +1,220 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pagetables provides a generic implementation of pagetables.
+//
+// The core functions must be safe to call from a nosplit context. Furthermore,
+// this pagetables implementation goes to lengths to ensure that all functions
+// are free from runtime allocation. Calls to NewPTEs/FreePTEs may be made
+// during walks, but these can be cached elsewhere if required.
+package pagetables
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// PageTables is a set of page tables.
+type PageTables struct {
+ // Allocator is used to allocate nodes.
+ Allocator Allocator
+
+ // root is the pagetable root.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
+ // archPageTables includes architecture-specific features.
+ archPageTables
+}
+
+// New returns new PageTables.
+func New(a Allocator) *PageTables {
+ p := new(PageTables)
+ p.Init(a)
+ return p
+}
+
+// mapVisitor is used for map.
+type mapVisitor struct {
+ target uintptr // Input.
+ physical uintptr // Input.
+ opts MapOpts // Input.
+ prev bool // Output.
+}
+
+// visit is used for map.
+//
+//go:nosplit
+func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ p := v.physical + (start - uintptr(v.target))
+ if pte.Valid() && (pte.Address() != p || pte.Opts() != v.opts) {
+ v.prev = true
+ }
+ if p&align != 0 {
+ // We will install entries at a smaller granularity if we don't
+ // install a valid entry here; however, we must zap any existing
+ // entry to ensure this happens.
+ pte.Clear()
+ return
+ }
+ pte.Set(p, v.opts)
+}
+
+//go:nosplit
+func (*mapVisitor) requiresAlloc() bool { return true }
+
+//go:nosplit
+func (*mapVisitor) requiresSplit() bool { return true }
+
+// Map installs a mapping with the given physical address.
+//
+// True is returned iff there was a previous mapping in the range.
+//
+// Precondition: addr & length must be page-aligned, their sum must not overflow.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
+ if !opts.AccessType.Any() {
+ return p.Unmap(addr, length)
+ }
+ w := mapWalker{
+ pageTables: p,
+ visitor: mapVisitor{
+ target: uintptr(addr),
+ physical: physical,
+ opts: opts,
+ },
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+length)
+ return w.visitor.prev
+}
+
+// unmapVisitor is used for unmap.
+type unmapVisitor struct {
+ count int
+}
+
+//go:nosplit
+func (*unmapVisitor) requiresAlloc() bool { return false }
+
+//go:nosplit
+func (*unmapVisitor) requiresSplit() bool { return true }
+
+// visit unmaps the given entry.
+//
+//go:nosplit
+func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ pte.Clear()
+ v.count++
+}
+
+// Unmap unmaps the given range.
+//
+// True is returned iff there was a previous mapping in the range.
+//
+// Precondition: addr & length must be page-aligned.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
+ w := unmapWalker{
+ pageTables: p,
+ visitor: unmapVisitor{
+ count: 0,
+ },
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+length)
+ return w.visitor.count > 0
+}
+
+// emptyVisitor is used for emptiness checks.
+type emptyVisitor struct {
+ count int
+}
+
+//go:nosplit
+func (*emptyVisitor) requiresAlloc() bool { return false }
+
+//go:nosplit
+func (*emptyVisitor) requiresSplit() bool { return false }
+
+// visit counts each entry in the range; any entry makes it non-empty.
+//
+//go:nosplit
+func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ v.count++
+}
+
+// IsEmpty checks if the given range is empty.
+//
+// Precondition: addr & length must be page-aligned.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
+ w := emptyWalker{
+ pageTables: p,
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+length)
+ return w.visitor.count == 0
+}
+
+// lookupVisitor is used for lookup.
+type lookupVisitor struct {
+ target uintptr // Input.
+ physical uintptr // Output.
+ opts MapOpts // Output.
+}
+
+// visit matches the given address.
+//
+//go:nosplit
+func (v *lookupVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ if !pte.Valid() {
+ return
+ }
+ v.physical = pte.Address() + (start - uintptr(v.target))
+ v.opts = pte.Opts()
+}
+
+//go:nosplit
+func (*lookupVisitor) requiresAlloc() bool { return false }
+
+//go:nosplit
+func (*lookupVisitor) requiresSplit() bool { return false }
+
+// Lookup returns the physical address for the given virtual address.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) {
+ mask := uintptr(usermem.PageSize - 1)
+ offset := uintptr(addr) & mask
+ w := lookupWalker{
+ pageTables: p,
+ visitor: lookupVisitor{
+ target: uintptr(addr &^ usermem.Addr(mask)),
+ },
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+1)
+ return w.visitor.physical + offset, w.visitor.opts
+}
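
(Putting the pieces together: a hedged usage sketch of the entry points above; addresses and the physical target are illustrative, and the MapOpts comparison assumes the amd64 option reconstruction.)

pt := pagetables.New(pagetables.NewRuntimeAllocator())

opts := pagetables.MapOpts{AccessType: usermem.ReadWrite}
prev := pt.Map(0x400000, usermem.PageSize, opts, 0x1234000) // false: no prior mapping.

phys, got := pt.Lookup(0x400010) // phys == 0x1234010, got == opts.

pt.Unmap(0x400000, usermem.PageSize)
empty := pt.IsEmpty(0x400000, usermem.PageSize) // true
_, _, _, _ = prev, phys, got, empty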
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
new file mode 100644
index 000000000..78510ebed
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -0,0 +1,212 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// archPageTables is architecture-specific data.
+type archPageTables struct {
+ // root is the pagetable root for kernel space.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
+ asid uint16
+}
+
+// TTBR0_EL1 returns the translation table base register 0.
+//
+//go:nosplit
+func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
+ return uint64(p.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+}
+
+// TTBR1_EL1 returns the translation table base register 1.
+//
+//go:nosplit
+func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
+ return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+}
+
+// Bits in page table entries.
+const (
+ typeTable = 0x3 << 0
+ typeSect = 0x1 << 0
+ typePage = 0x3 << 0
+ pteValid = 0x1 << 0
+ pteTableBit = 0x1 << 1
+ pteTypeMask = 0x3 << 0
+ present = pteValid | pteTableBit
+ user = 0x1 << 6 /* AP[1] */
+ readOnly = 0x1 << 7 /* AP[2] */
+ accessed = 0x1 << 10
+ dbm = 0x1 << 51
+ writable = dbm
+ cont = 0x1 << 52
+ pxn = 0x1 << 53
+ xn = 0x1 << 54
+ dirty = 0x1 << 55
+ nG = 0x1 << 11
+ shared = 0x3 << 8
+)
+
+const (
+ mtNormal = 0x4 << 2
+)
+
+const (
+ executeDisable = xn
+ optionMask = 0xfff | 0xfff<<48
+ protDefault = accessed | shared | mtNormal
+)
+
+// MapOpts are arm64 options.
+type MapOpts struct {
+ // AccessType defines permissions.
+ AccessType usermem.AccessType
+
+ // Global indicates the page is globally accessible.
+ Global bool
+
+ // User indicates the page is a user page.
+ User bool
+}
+
+// PTE is a page table entry.
+type PTE uintptr
+
+// Clear clears this PTE, including sect page information.
+//
+//go:nosplit
+func (p *PTE) Clear() {
+ atomic.StoreUintptr((*uintptr)(p), 0)
+}
+
+// Valid returns true iff this entry is valid.
+//
+//go:nosplit
+func (p *PTE) Valid() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&present != 0
+}
+
+// Opts returns the PTE options.
+//
+// These are all options except Valid and Sect.
+//
+//go:nosplit
+func (p *PTE) Opts() MapOpts {
+ v := atomic.LoadUintptr((*uintptr)(p))
+
+ return MapOpts{
+ AccessType: usermem.AccessType{
+ Read: true,
+ Write: v&readOnly == 0,
+ Execute: v&xn == 0,
+ },
+ Global: v&nG == 0,
+ User: v&user != 0,
+ }
+}
+
+// SetSect sets this page as a sect page.
+//
+// The page must not be valid or a panic will result.
+//
+//go:nosplit
+func (p *PTE) SetSect() {
+ if p.Valid() {
+ // This is not allowed.
+ panic("SetSect called on valid page!")
+ }
+ atomic.StoreUintptr((*uintptr)(p), typeSect)
+}
+
+// IsSect returns true iff this page is a sect page.
+//
+//go:nosplit
+func (p *PTE) IsSect() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&pteTypeMask == typeSect
+}
+
+// Set sets this PTE value.
+//
+// This does not change the sect page property.
+//
+//go:nosplit
+func (p *PTE) Set(addr uintptr, opts MapOpts) {
+ if !opts.AccessType.Any() {
+ p.Clear()
+ return
+ }
+ v := (addr &^ optionMask) | protDefault | nG | readOnly
+
+ if p.IsSect() {
+ // Note that this is inherited from the previous instance. Set
+ // does not change the value of Sect. See above.
+ v |= typeSect
+ } else {
+ v |= typePage
+ }
+
+ if opts.Global {
+ v = v &^ nG
+ }
+
+ if opts.AccessType.Execute {
+ v = v &^ executeDisable
+ } else {
+ v |= executeDisable
+ }
+ if opts.AccessType.Write {
+ v = v &^ readOnly
+ }
+
+ if opts.User {
+ v |= user
+ } else {
+ v = v &^ user
+ }
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// setPageTable sets this PTE value and forces the write bit and sect bit to
+// be cleared. This is used explicitly for breaking sect pages.
+//
+//go:nosplit
+func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
+ addr := pt.Allocator.PhysicalFor(ptes)
+ if addr&^optionMask != addr {
+ // This should never happen.
+ panic("unaligned physical address!")
+ }
+ v := addr | typeTable | protDefault
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// Address extracts the address. This should only be used if Valid returns true.
+//
+//go:nosplit
+func (p *PTE) Address() uintptr {
+ return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
new file mode 100644
index 000000000..0c153cf8c
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Address constraints.
+//
+// The lowerTop and upperBottom currently apply to four-level pagetables;
+// additional refactoring would be necessary to support five-level pagetables.
+const (
+ lowerTop = 0x00007fffffffffff
+ upperBottom = 0xffff800000000000
+
+ pteShift = 12
+ pmdShift = 21
+ pudShift = 30
+ pgdShift = 39
+
+ pteMask = 0x1ff << pteShift
+ pmdMask = 0x1ff << pmdShift
+ pudMask = 0x1ff << pudShift
+ pgdMask = 0x1ff << pgdShift
+
+ pteSize = 1 << pteShift
+ pmdSize = 1 << pmdShift
+ pudSize = 1 << pudShift
+ pgdSize = 1 << pgdShift
+
+ executeDisable = 1 << 63
+ entriesPerPage = 512
+)
+
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+}
+
+// PTEs is a collection of entries.
+type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go
new file mode 100644
index 000000000..54e8e554f
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package pagetables
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func Test2MAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a huge page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
+ {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read}},
+ })
+}
+
+func Test1GAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a super page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
+ {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read}},
+ })
+}
+
+func TestSplit1GPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a super page and knock out the middle.
+ pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*42)
+ pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read}},
+ {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read}},
+ })
+}
+
+func TestSplit2MPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a huge page and knock out the middle.
+ pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*42)
+ pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read}},
+ {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read}},
+ })
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
new file mode 100644
index 000000000..1a49f12a2
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -0,0 +1,57 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Address constraints.
+//
+// The lowerTop and upperBottom currently apply to four-level pagetables;
+// additional refactoring would be necessary to support five-level pagetables.
+const (
+ lowerTop = 0x0000ffffffffffff
+ upperBottom = 0xffff000000000000
+ pteShift = 12
+ pmdShift = 21
+ pudShift = 30
+ pgdShift = 39
+
+ pteMask = 0x1ff << pteShift
+ pmdMask = 0x1ff << pmdShift
+ pudMask = 0x1ff << pudShift
+ pgdMask = 0x1ff << pgdShift
+
+ pteSize = 1 << pteShift
+ pmdSize = 1 << pmdShift
+ pudSize = 1 << pudShift
+ pgdSize = 1 << pgdShift
+
+ ttbrASIDOffset = 55
+ ttbrASIDMask = 0xff
+
+ entriesPerPage = 512
+)
+
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+ p.archPageTables.root = p.Allocator.NewPTEs()
+ p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+}
+
+// PTEs is a collection of entries.
+type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go
new file mode 100644
index 000000000..2f73d424f
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go
@@ -0,0 +1,80 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func Test2MAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a huge page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47)
+
+ pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42)
+ pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
+ {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}},
+ {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}},
+ })
+}
+
+func Test1GAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a super page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
+ {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
+
+func TestSplit1GPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a super page and knock out the middle.
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42)
+ pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
+
+func TestSplit2MPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a huge page and knock out the middle.
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42)
+ pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go
new file mode 100644
index 000000000..5c88d087d
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go
@@ -0,0 +1,156 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type mapping struct {
+ start uintptr
+ length uintptr
+ addr uintptr
+ opts MapOpts
+}
+
+type checkVisitor struct {
+ expected []mapping // Input.
+ current int // Temporary.
+ found []mapping // Output.
+ failed string // Output.
+}
+
+func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ v.found = append(v.found, mapping{
+ start: start,
+ length: align + 1,
+ addr: pte.Address(),
+ opts: pte.Opts(),
+ })
+ if v.failed != "" {
+ // Don't keep looking for errors.
+ return
+ }
+
+ if v.current >= len(v.expected) {
+ v.failed = "more mappings than expected"
+ } else if v.expected[v.current].start != start {
+ v.failed = "start didn't match expected"
+ } else if v.expected[v.current].length != (align + 1) {
+ v.failed = "length didn't match expected"
+ } else if v.expected[v.current].addr != pte.Address() {
+ v.failed = "address didn't match expected"
+ } else if v.expected[v.current].opts != pte.Opts() {
+ v.failed = "opts didn't match"
+ }
+ v.current++
+}
+
+func (*checkVisitor) requiresAlloc() bool { return false }
+
+func (*checkVisitor) requiresSplit() bool { return false }
+
+func checkMappings(t *testing.T, pt *PageTables, m []mapping) {
+ // Iterate over all the mappings.
+ w := checkWalker{
+ pageTables: pt,
+ visitor: checkVisitor{
+ expected: m,
+ },
+ }
+ w.iterateRange(0, ^uintptr(0))
+
+ // Were fewer mappings found than expected?
+ if w.visitor.failed == "" && w.visitor.current != len(w.visitor.expected) {
+ w.visitor.failed = "insufficient mappings found"
+ }
+
+ // Emit a meaningful error message on failure.
+ if w.visitor.failed != "" {
+ t.Errorf("%s; got %#v, wanted %#v", w.visitor.failed, w.visitor.found, w.visitor.expected)
+ }
+}
+
+func TestUnmap(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map and unmap one entry.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Unmap(0x400000, pteSize)
+
+ checkMappings(t, pt, nil)
+}
+
+func TestReadOnly(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map one entry.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}},
+ })
+}
+
+func TestReadWrite(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map one entry.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
+ })
+}
+
+func TestSerialEntries(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map two sequential entries.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Map(0x401000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
+ {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.ReadWrite}},
+ })
+}
+
+func TestSpanningEntries(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Span a pgd with two pages.
+ pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42)
+
+ checkMappings(t, pt, []mapping{
+ {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}},
+ {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: usermem.Read}},
+ })
+}
+
+func TestSparseEntries(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map two entries in different pgds.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
+ {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.Read}},
+ })
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
new file mode 100644
index 000000000..157438d9b
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
@@ -0,0 +1,180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package pagetables
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// archPageTables is architecture-specific data.
+type archPageTables struct {
+ // pcid is the value assigned by PCIDs.Assign.
+ //
+ // Note that zero is a valid PCID.
+ pcid uint16
+}
+
+// CR3 returns the CR3 value for these tables.
+//
+// This may be called in interrupt contexts. A PCID of zero always implies a
+// flush and should be passed when PCIDs are not enabled. See pcids_x86.go for
+// more information.
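+//
+// For example (illustrative root address), with rootPhysical == 0x5000:
+// CR3(false, 3) == 0x5003, CR3(true, 3) == 0x8000000000005003 (bit 63
+// suppresses the flush), and CR3(true, 0) == 0x5000.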
+//
+//go:nosplit
+func (p *PageTables) CR3(noFlush bool, pcid uint16) uint64 {
+ // Bit 63 is set to avoid flushing the PCID (per SDM 4.10.4.1).
+ const noFlushBit uint64 = 0x8000000000000000
+ if noFlush && pcid != 0 {
+ return noFlushBit | uint64(p.rootPhysical) | uint64(pcid)
+ }
+ return uint64(p.rootPhysical) | uint64(pcid)
+}
+
+// Bits in page table entries.
+const (
+ present = 0x001
+ writable = 0x002
+ user = 0x004
+ writeThrough = 0x008
+ cacheDisable = 0x010
+ accessed = 0x020
+ dirty = 0x040
+ super = 0x080
+ global = 0x100
+ optionMask = executeDisable | 0xfff
+)
+
+// MapOpts are x86 options.
+type MapOpts struct {
+ // AccessType defines permissions.
+ AccessType usermem.AccessType
+
+ // Global indicates the page is globally accessible.
+ Global bool
+
+ // User indicates the page is a user page.
+ User bool
+}
+
+// PTE is a page table entry.
+type PTE uintptr
+
+// Clear clears this PTE, including super page information.
+//
+//go:nosplit
+func (p *PTE) Clear() {
+ atomic.StoreUintptr((*uintptr)(p), 0)
+}
+
+// Valid returns true iff this entry is valid.
+//
+//go:nosplit
+func (p *PTE) Valid() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&present != 0
+}
+
+// Opts returns the PTE options.
+//
+// These are all options except Valid and Super.
+//
+//go:nosplit
+func (p *PTE) Opts() MapOpts {
+ v := atomic.LoadUintptr((*uintptr)(p))
+ return MapOpts{
+ AccessType: usermem.AccessType{
+ Read: v&present != 0,
+ Write: v&writable != 0,
+ Execute: v&executeDisable == 0,
+ },
+ Global: v&global != 0,
+ User: v&user != 0,
+ }
+}
+
+// SetSuper sets this page as a super page.
+//
+// The page must not be valid or a panic will result.
+//
+//go:nosplit
+func (p *PTE) SetSuper() {
+ if p.Valid() {
+ // This is not allowed.
+ panic("SetSuper called on valid page!")
+ }
+ atomic.StoreUintptr((*uintptr)(p), super)
+}
+
+// IsSuper returns true iff this page is a super page.
+//
+//go:nosplit
+func (p *PTE) IsSuper() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&super != 0
+}
+
+// Set sets this PTE value.
+//
+// This does not change the super page property.
+//
+//go:nosplit
+func (p *PTE) Set(addr uintptr, opts MapOpts) {
+ if !opts.AccessType.Any() {
+ p.Clear()
+ return
+ }
+ v := (addr &^ optionMask) | present | accessed
+ if opts.User {
+ v |= user
+ }
+ if opts.Global {
+ v |= global
+ }
+ if !opts.AccessType.Execute {
+ v |= executeDisable
+ }
+ if opts.AccessType.Write {
+ v |= writable | dirty
+ }
+ if p.IsSuper() {
+ // Note that this is inherited from the previous instance. Set
+ // does not change the value of Super. See above.
+ v |= super
+ }
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// setPageTable sets this PTE value and forces the write bit and super bit to
+// be cleared. This is used explicitly for breaking super pages.
+//
+//go:nosplit
+func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
+ addr := pt.Allocator.PhysicalFor(ptes)
+ if addr&^optionMask != addr {
+ // This should never happen.
+ panic("unaligned physical address!")
+ }
+ v := addr | present | user | writable | accessed | dirty
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// Address extracts the address. This should only be used if Valid returns true.
+//
+//go:nosplit
+func (p *PTE) Address() uintptr {
+ return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask
+}
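
(A worked encoding example using the bit values above; the address is illustrative. Set(0x1234000, MapOpts{AccessType: usermem.ReadWrite, User: true}) stores present|accessed|user|writable|dirty == 0x067 in the low bits and sets executeDisable (bit 63), i.e. 0x8000000001234067; Address() masks optionMask back off and recovers 0x1234000.)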
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
new file mode 100644
index 000000000..964496aac
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -0,0 +1,104 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// PCIDs is a simple PCID database.
+//
+// Access is serialized by an internal mutex, but the database is intended
+// for use with a single CPU at a time.
+type PCIDs struct {
+ // mu protects below.
+ mu sync.Mutex
+
+ // cache are the assigned page tables.
+ cache map[*PageTables]uint16
+
+ // avail are available PCIDs.
+ avail []uint16
+}
+
+// NewPCIDs returns a new PCID database.
+//
+// start is the first index to assign. Typically this will be one, as the zero
+// pcid will always be flushed on transition (see pagetables_x86.go). This may
+// be more than one if specific PCIDs are reserved.
+//
+// Nil is returned iff start+size exceeds the maximum PCID (limitPCID).
+func NewPCIDs(start, size uint16) *PCIDs {
+ if start+uint16(size) > limitPCID {
+ return nil // See comment.
+ }
+ p := &PCIDs{
+ cache: make(map[*PageTables]uint16),
+ }
+ for pcid := start; pcid < start+size; pcid++ {
+ p.avail = append(p.avail, pcid)
+ }
+ return p
+}
+
+// Assign assigns a PCID to the given PageTables.
+//
+// This may overwrite a previous assignment. If that is the case, true is
+// returned to indicate that the PCID should be flushed.
+func (p *PCIDs) Assign(pt *PageTables) (uint16, bool) {
+ p.mu.Lock()
+ if pcid, ok := p.cache[pt]; ok {
+ p.mu.Unlock()
+ return pcid, false // No flush.
+ }
+
+ // Is there something available?
+ if len(p.avail) > 0 {
+ pcid := p.avail[len(p.avail)-1]
+ p.avail = p.avail[:len(p.avail)-1]
+ p.cache[pt] = pcid
+
+ // We need to flush because while this is in the available
+ // pool, it may have been used previously.
+ p.mu.Unlock()
+ return pcid, true
+ }
+
+ // Evict an existing table.
+ for old, pcid := range p.cache {
+ delete(p.cache, old)
+ p.cache[pt] = pcid
+
+ // A flush is definitely required in this case, as these page
+ // tables may still be active. (They will just be assigned some
+ // other PCID if and when they hit the given CPU again.)
+ p.mu.Unlock()
+ return pcid, true
+ }
+
+ // No PCID.
+ p.mu.Unlock()
+ return 0, false
+}
+
+// Drop drops references to a set of page tables.
+func (p *PCIDs) Drop(pt *PageTables) {
+ p.mu.Lock()
+ if pcid, ok := p.cache[pt]; ok {
+ delete(p.cache, pt)
+ p.avail = append(p.avail, pcid)
+ }
+ p.mu.Unlock()
+}
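
(A brief usage sketch tying the PCID database to CR3 on x86; illustrative only, per the single-CPU note above.)

pcids := pagetables.NewPCIDs(1, 16) // PCIDs 1..16 available.

pt := pagetables.New(pagetables.NewRuntimeAllocator())
pcid, flush := pcids.Assign(pt) // First assignment: flush == true.
cr3 := pt.CR3(!flush, pcid)     // Skip the TLB flush only when safe.

pcids.Drop(pt) // Return the PCID to the available pool.
_ = cr3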
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go
new file mode 100644
index 000000000..fbfd41d83
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+// limitPCID is the maximum value of PCIDs.
+//
+// In VMSAv8-64, the PCID(ASID) size is an IMPLEMENTATION DEFINED choice
+// of 8 bits or 16 bits, and ID_AA64MMFR0_EL1.ASIDBits identifies the
+// supported size. When an implementation supports a 16-bit ASID, TCR_ELx.AS
+// selects whether the top 8 bits of the ASID are used.
+var limitPCID uint16
+
+// GetASIDBits returns the system ASID size, either 8 or 16 bits.
+func GetASIDBits() uint8
+
+func init() {
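+ // With 8-bit ASIDs this yields a limit of 255; with 16-bit ASIDs,
+ // 65535.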
+ limitPCID = uint16(1)<<GetASIDBits() - 1
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s
new file mode 100644
index 000000000..e9d62d768
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define ID_AA64MMFR0_ASIDBITS_SHIFT 4
+#define ID_AA64MMFR0_ASIDBITS_16 2
+#define TCR_EL1_AS_BIT 36
+
+// GetASIDBits returns the system ASID size, either 8 or 16 bits.
+//
+// func GetASIDBits() uint8
+TEXT ·GetASIDBits(SB),NOSPLIT,$0-1
+ // First, check whether a 16-bit ASID is supported.
+ // ID_AA64MMFR0_EL1.ASIDBITS[7:4] == 0010.
+ WORD $0xd5380700 // MRS ID_AA64MMFR0_EL1, R0
+ UBFX $ID_AA64MMFR0_ASIDBITS_SHIFT, R0, $4, R0
+ CMPW $ID_AA64MMFR0_ASIDBITS_16, R0
+ BNE bits_8
+
+ // Second, check whether the 16-bit ASID is enabled.
+ // TCR_EL1.AS[36] == 1.
+ WORD $0xd5382040 // MRS TCR_EL1, R0
+ TBZ $TCR_EL1_AS_BIT, R0, bits_8
+ MOVD $16, R0
+ B done
+bits_8:
+ MOVD $8, R0
+done:
+ MOVB R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
new file mode 100644
index 000000000..91fc5e8dd
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build i386 amd64
+
+package pagetables
+
+// limitPCID is the maximum value of a valid PCID.
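+// On x86, the PCID is the low 12-bit field of CR3, so valid values are 0
+// through 4095.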
+const limitPCID = 4095
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go b/pkg/sentry/platform/ring0/pagetables/walker_amd64.go
new file mode 100644
index 000000000..8f9dacd93
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_amd64.go
@@ -0,0 +1,307 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package pagetables
+
+// Visitor is the interface implemented by a page table walk's visitor.
+type Visitor interface {
+ // visit is called on each PTE.
+ visit(start uintptr, pte *PTE, align uintptr)
+
+ // requiresAlloc indicates that new entries should be allocated within
+ // the walked range.
+ requiresAlloc() bool
+
+ // requiresSplit indicates that entries in the given range should be
+ // split if they are huge or jumbo pages.
+ requiresSplit() bool
+}
+
+// Walker walks page tables.
+type Walker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // visitor is called for each PTE during the walk.
+ visitor Visitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is super pages. If a valid super page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of super pages whenever
+// possible. Whether a super page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *Walker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
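+// For example, next(0x1234, 0x1000) returns 0x2000: start is rounded down to
+// its size-aligned boundary and then advanced by one full region.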
+//
+//go:nosplit
+func next(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *Walker) iterateRangeCanonical(start, end uintptr) {
+ for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &w.pageTables.root[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ start = next(start, pgdSize)
+ continue
+ }
+
+ // Allocate a new pgd.
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ // Map the next level.
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPUDEntries++
+ start = next(start, pudSize)
+ continue
+ }
+
+ // This level has 1-GB super pages. Is this
+ // entire region at least as large as a single
+ // PUD entry? If so, we can skip allocating a
+ // new page for the pmd.
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = next(start, pudSize)
+ continue
+ }
+ }
+
+ // Allocate a new pud.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSuper() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) {
+ // Install the relevant entries.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSuper()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+ // A super page to be checked directly.
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ // Might have been cleared.
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ // Note that the super page was changed.
+ start = next(start, pudSize)
+ continue
+ }
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPMDEntries++
+ start = next(start, pmdSize)
+ continue
+ }
+
+ // This level has 2-MB huge pages. Is this
+ // region contained in a single PMD entry? If so,
+ // as above, we can skip allocating a new page.
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = next(start, pmdSize)
+ continue
+ }
+ }
+
+ // Allocate a new pmd.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSuper() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) {
+ // Install the relevant entries.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+ // A huge page to be checked directly.
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ // Might have been cleared.
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ // Note that the huge page was changed.
+ start = next(start, pmdSize)
+ continue
+ }
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ // At this point, we are guaranteed that start%pteSize == 0.
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ // Note that the pte was changed.
+ start += pteSize
+ continue
+ }
+
+ // Check if we no longer need this page.
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
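To make the Visitor contract concrete, a minimal read-only visitor (defined
within this package, since visit is unexported) might look like the sketch
below; checkVisitor is hypothetical and not part of this change.

// checkVisitor counts valid mappings without allocating or splitting.
type checkVisitor struct {
	valid int
}

// visit is called for each PTE in the walked range.
func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) {
	if pte.Valid() {
		v.valid++
	}
}

// requiresAlloc is false, so gaps in the tables are simply skipped.
func (*checkVisitor) requiresAlloc() bool { return false }

// requiresSplit is false, so huge pages are visited whole.
func (*checkVisitor) requiresSplit() bool { return false }

It would be driven as:

	v := checkVisitor{}
	w := Walker{pageTables: pt, visitor: &v}
	w.iterateRange(start, end)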
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
new file mode 100644
index 000000000..c261d393a
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+// Visitor is the interface implemented by a page table walk's visitor.
+type Visitor interface {
+ // visit is called on each PTE.
+ visit(start uintptr, pte *PTE, align uintptr)
+
+ // requiresAlloc indicates that new entries should be allocated within
+ // the walked range.
+ requiresAlloc() bool
+
+ // requiresSplit indicates that entries in the given range should be
+ // split if they are huge or jumbo pages.
+ requiresSplit() bool
+}
+
+// Walker walks page tables.
+type Walker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // visitor is called for each PTE during the walk.
+ visitor Visitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is sect pages. If a valid sect page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of sect pages whenever
+// possible. Whether a sect page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *Walker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func next(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *Walker) iterateRangeCanonical(start, end uintptr) {
+ pgdEntryIndex := w.pageTables.root
+ if start >= upperBottom {
+ pgdEntryIndex = w.pageTables.archPageTables.root
+ }
+
+ for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &pgdEntryIndex[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ start = next(start, pgdSize)
+ continue
+ }
+
+ // Allocate a new pgd.
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ // Map the next level.
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPUDEntries++
+ start = next(start, pudSize)
+ continue
+ }
+
+ // This level has 1-GB sect pages. Is this
+ // entire region at least as large as a single
+ // PUD entry? If so, we can skip allocating a
+ // new page for the pmd.
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSect()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = next(start, pudSize)
+ continue
+ }
+ }
+
+ // Allocate a new pud.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) {
+ // Install the relevant entries.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSect()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+ // A sect page to be checked directly.
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ // Might have been cleared.
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ // Note that the sect page was changed.
+ start = next(start, pudSize)
+ continue
+ }
+
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPMDEntries++
+ start = next(start, pmdSize)
+ continue
+ }
+
+ // This level has 2-MB huge pages. Is this
+ // region contained in a single PMD entry? If so,
+ // as above, we can skip allocating a new page.
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSect()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = next(start, pmdSize)
+ continue
+ }
+ }
+
+ // Allocate a new pmd.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) {
+ // Install the relevant entries.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+ // A huge page to be checked directly.
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ // Might have been cleared.
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ // Note that the huge page was changed.
+ start = next(start, pmdSize)
+ continue
+ }
+
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ // At this point, we are guaranteed that start%pteSize == 0.
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ // Note that the pte was changed.
+ start += pteSize
+ continue
+ }
+
+ // Check if we no longer need this page.
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ring0/ring0.go b/pkg/sentry/platform/ring0/ring0.go
new file mode 100644
index 000000000..cdeb1b43a
--- /dev/null
+++ b/pkg/sentry/platform/ring0/ring0.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ring0 provides basic operating system-level stubs.
+package ring0
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
new file mode 100644
index 000000000..9da0ea685
--- /dev/null
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -0,0 +1,264 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/cpuid"
+)
+
+// Useful bits.
+const (
+ _CR0_PE = 1 << 0
+ _CR0_ET = 1 << 4
+ _CR0_AM = 1 << 18
+ _CR0_PG = 1 << 31
+
+ _CR4_PSE = 1 << 4
+ _CR4_PAE = 1 << 5
+ _CR4_PGE = 1 << 7
+ _CR4_OSFXSR = 1 << 9
+ _CR4_OSXMMEXCPT = 1 << 10
+ _CR4_FSGSBASE = 1 << 16
+ _CR4_PCIDE = 1 << 17
+ _CR4_OSXSAVE = 1 << 18
+ _CR4_SMEP = 1 << 20
+
+ _RFLAGS_AC = 1 << 18
+ _RFLAGS_NT = 1 << 14
+ _RFLAGS_IOPL = 3 << 12
+ _RFLAGS_DF = 1 << 10
+ _RFLAGS_IF = 1 << 9
+ _RFLAGS_STEP = 1 << 8
+ _RFLAGS_RESERVED = 1 << 1
+
+ _EFER_SCE = 0x001
+ _EFER_LME = 0x100
+ _EFER_LMA = 0x400
+ _EFER_NX = 0x800
+
+ _MSR_STAR = 0xc0000081
+ _MSR_LSTAR = 0xc0000082
+ _MSR_CSTAR = 0xc0000083
+ _MSR_SYSCALL_MASK = 0xc0000084
+ _MSR_PLATFORM_INFO = 0xce
+ _MSR_MISC_FEATURES = 0x140
+
+ _PLATFORM_INFO_CPUID_FAULT = 1 << 31
+
+ _MISC_FEATURE_CPUID_TRAP = 0x1
+)
+
+const (
+ // KernelFlagsSet should always be set in the kernel.
+ KernelFlagsSet = _RFLAGS_RESERVED
+
+ // UserFlagsSet are always set in userspace.
+ UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF
+
+ // KernelFlagsClear should always be clear in the kernel.
+ KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT
+
+ // UserFlagsClear are always cleared in userspace.
+ UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL
+)
+
+// Vector is an exception vector.
+type Vector uintptr
+
+// Exception vectors.
+const (
+ DivideByZero Vector = iota
+ Debug
+ NMI
+ Breakpoint
+ Overflow
+ BoundRangeExceeded
+ InvalidOpcode
+ DeviceNotAvailable
+ DoubleFault
+ CoprocessorSegmentOverrun
+ InvalidTSS
+ SegmentNotPresent
+ StackSegmentFault
+ GeneralProtectionFault
+ PageFault
+ _
+ X87FloatingPointException
+ AlignmentCheck
+ MachineCheck
+ SIMDFloatingPointException
+ VirtualizationException
+ SecurityException = 0x1e
+ SyscallInt80 = 0x80
+ _NR_INTERRUPTS = SyscallInt80 + 1
+)
+
+// System call vectors.
+const (
+ Syscall Vector = _NR_INTERRUPTS
+)
+
+// VirtualAddressBits returns the number of bits available for virtual addresses.
+//
+// Note that sign-extension semantics apply to the highest order bit.
+//
+// FIXME(b/69382326): This should use the cpuid passed to Init.
+func VirtualAddressBits() uint32 {
+ ax, _, _, _ := cpuid.HostID(0x80000008, 0)
+ return (ax >> 8) & 0xff
+}
+
+// PhysicalAddressBits returns the number of bits available for physical addresses.
+//
+// FIXME(b/69382326): This should use the cpuid passed to Init.
+func PhysicalAddressBits() uint32 {
+ ax, _, _, _ := cpuid.HostID(0x80000008, 0)
+ return ax & 0xff
+}
+
+// Selector is a segment Selector.
+type Selector uint16
+
+// SegmentDescriptor is a segment descriptor.
+type SegmentDescriptor struct {
+ bits [2]uint32
+}
+
+// descriptorTable is a collection of descriptors.
+type descriptorTable [32]SegmentDescriptor
+
+// SegmentDescriptorFlags are typed flags within a descriptor.
+type SegmentDescriptorFlags uint32
+
+// SegmentDescriptorFlag declarations.
+const (
+ SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set).
+ SegmentDescriptorWrite = 1 << 9 // Write permission.
+ SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used.
+ SegmentDescriptorExecute = 1 << 11 // Execute permission.
+ SegmentDescriptorSystem = 1 << 12 // 0 => system, 1 => user code/data.
+ SegmentDescriptorPresent = 1 << 15 // Present.
+ SegmentDescriptorAVL = 1 << 20 // Available.
+ SegmentDescriptorLong = 1 << 21 // Long mode.
+ SegmentDescriptorDB = 1 << 22 // 16 or 32-bit.
+ SegmentDescriptorG = 1 << 23 // Granularity: page or byte.
+)
+
+// Base returns the descriptor's base linear address.
+func (d *SegmentDescriptor) Base() uint32 {
+ return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16
+}
+
+// Limit returns the descriptor size.
+func (d *SegmentDescriptor) Limit() uint32 {
+ l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000
+ if d.bits[1]&uint32(SegmentDescriptorG) != 0 {
+ l <<= 12
+ l |= 0xFFF
+ }
+ return l
+}
+
+// Flags returns descriptor flags.
+func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags {
+ return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00)
+}
+
+// DPL returns the descriptor privilege level.
+func (d *SegmentDescriptor) DPL() int {
+ return int((d.bits[1] >> 13) & 3)
+}
+
+func (d *SegmentDescriptor) setNull() {
+ d.bits[0] = 0
+ d.bits[1] = 0
+}
+
+func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) {
+ flags |= SegmentDescriptorPresent
+ if limit>>12 != 0 {
+ limit >>= 12
+ flags |= SegmentDescriptorG
+ }
+ d.bits[0] = base<<16 | limit&0xFFFF
+ d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13
+}
+
+func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) {
+ d.set(base, limit, dpl,
+ SegmentDescriptorDB|
+ SegmentDescriptorExecute|
+ SegmentDescriptorSystem)
+}
+
+func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) {
+ d.set(base, limit, dpl,
+ SegmentDescriptorG|
+ SegmentDescriptorLong|
+ SegmentDescriptorExecute|
+ SegmentDescriptorSystem)
+}
+
+func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) {
+ d.set(base, limit, dpl,
+ SegmentDescriptorWrite|
+ SegmentDescriptorSystem)
+}
+
+// setHi is only used for the TSS segment, which is magically 64-bits.
+func (d *SegmentDescriptor) setHi(base uint32) {
+ d.bits[0] = base
+ d.bits[1] = 0
+}
+
+// Gate64 is a 64-bit task, trap, or interrupt gate.
+type Gate64 struct {
+ bits [4]uint32
+}
+
+// idt64 is a 64-bit interrupt descriptor table.
+type idt64 [_NR_INTERRUPTS]Gate64
+
+func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) {
+ g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF
+ g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7
+ g.bits[2] = uint32(rip >> 32)
+}
+
+func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) {
+ g.setInterrupt(cs, rip, dpl, ist)
+ g.bits[1] |= 1 << 8
+}
+
+// TaskState64 is a 64-bit task state structure.
+type TaskState64 struct {
+ _ uint32
+ rsp0Lo, rsp0Hi uint32
+ rsp1Lo, rsp1Hi uint32
+ rsp2Lo, rsp2Hi uint32
+ _ [2]uint32
+ ist1Lo, ist1Hi uint32
+ ist2Lo, ist2Hi uint32
+ ist3Lo, ist3Hi uint32
+ ist4Lo, ist4Hi uint32
+ ist5Lo, ist5Hi uint32
+ ist6Lo, ist6Hi uint32
+ ist7Lo, ist7Hi uint32
+ _ [2]uint32
+ _ uint16
+ ioPerm uint16
+}
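To make the descriptor bit layout concrete, here is a sketch (within this
package) of setting and then decoding a flat 64-bit code segment; the expected
values follow from set and the accessors above, and exampleDescriptor is
hypothetical.

func exampleDescriptor() {
	var d SegmentDescriptor
	d.setCode64(0, 0xFFFFFFFF, 0) // flat ring-0 code segment

	// The limit does not fit in 20 bits, so set shifted it right by 12
	// and set the granularity bit; Limit reverses that scaling.
	_ = d.Base()  // 0: long-mode segments are flat
	_ = d.Limit() // 0xFFFFFFFF
	_ = d.DPL()   // 0: ring 0
}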
diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD
new file mode 100644
index 000000000..6c38a3f44
--- /dev/null
+++ b/pkg/sentry/sighandling/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sighandling",
+ srcs = [
+ "sighandling.go",
+ "sighandling_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/abi/linux"],
+)
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
new file mode 100644
index 000000000..83195d5a1
--- /dev/null
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -0,0 +1,102 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sighandling contains helpers for handling signals to applications.
+package sighandling
+
+import (
+ "os"
+ "os/signal"
+ "reflect"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// numSignals is the number of normal (non-realtime) signals on Linux.
+const numSignals = 32
+
+// handleSignals listens for incoming signals and calls the given handler
+// function.
+//
+// It stops when the stop channel is closed. The done channel is closed once
+// no more signals will be delivered to the handler.
+func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) {
+ // Build a select case.
+ sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}}
+ for _, sigchan := range sigchans {
+ sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)})
+ }
+
+ for {
+ // Wait for a notification.
+ index, _, ok := reflect.Select(sc)
+
+ // Was it the stop channel?
+ if index == 0 {
+ if !ok {
+ // Stop forwarding and notify that it's done.
+ close(done)
+ return
+ }
+ continue
+ }
+
+ // How about a different close?
+ if !ok {
+ panic("signal channel closed unexpectedly")
+ }
+
+ // Otherwise, it was a signal on channel N. Index 0 represents the stop
+ // channel, so index N represents the channel for signal N.
+ handler(linux.Signal(index))
+ }
+}
+
+// StartSignalForwarding ensures that synchronous signals are passed to the
+// given handler function and returns a callback that stops signal delivery.
+//
+// Note that this function permanently takes over signal handling. After the
+// stop callback, signals revert to the default Go runtime behavior, which
+// cannot be overridden with external calls to signal.Notify.
+func StartSignalForwarding(handler func(linux.Signal)) func() {
+ stop := make(chan struct{})
+ done := make(chan struct{})
+
+ // Register individual channels. One channel per standard signal is
+ // required as signal.Notify() is non-blocking and may drop signals. To
+ // avoid this, standard signals have to be queued separately. Channel
+ // size 1 is enough for standard signals, as their semantics allow
+ // de-duplication.
+ //
+ // External real-time signals are not supported. We rely on the Go
+ // runtime for their handling.
+ var sigchans []chan os.Signal
+ for sig := 1; sig <= numSignals+1; sig++ {
+ sigchan := make(chan os.Signal, 1)
+ sigchans = append(sigchans, sigchan)
+
+ // SIGURG is used by Go's runtime scheduler.
+ if sig == int(linux.SIGURG) {
+ continue
+ }
+ signal.Notify(sigchan, syscall.Signal(sig))
+ }
+ // Start up our listener.
+ go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
+
+ return func() {
+ close(stop)
+ <-done
+ }
+}
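A minimal sketch of the intended use, assuming only this package's exported
API; a real sentry routes signals into the kernel instead of printing them.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/sentry/sighandling"
)

func main() {
	// Forward standard signals to the handler until stop is called.
	stop := sighandling.StartSignalForwarding(func(sig linux.Signal) {
		fmt.Printf("forwarded signal %d\n", sig)
	})
	defer stop() // closes the stop channel and waits for done

	select {} // block; incoming signals now reach the handler above
}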
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
new file mode 100644
index 000000000..1ebe22d34
--- /dev/null
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sighandling
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
+// pkg/sentry/arch.
+type sigaction struct {
+ handler uintptr
+ flags uint64
+ restorer uintptr
+ mask uint64
+}
+
+// IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not
+// generate SIGCHLD when they stop.
+func IgnoreChildStop() error {
+ var sa sigaction
+
+ // Get the existing signal handler information, and set the flag.
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(syscall.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 {
+ return e
+ }
+ sa.flags |= linux.SA_NOCLDSTOP
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(syscall.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
new file mode 100644
index 000000000..c40c6d673
--- /dev/null
+++ b/pkg/sentry/socket/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "socket",
+ srcs = ["socket.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
new file mode 100644
index 000000000..ca16d0381
--- /dev/null
+++ b/pkg/sentry/socket/control/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "control",
+ srcs = [
+ "control.go",
+ "control_vfs2.go",
+ ],
+ imports = [
+ "gvisor.dev/gvisor/pkg/sentry/fs",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
new file mode 100644
index 000000000..8b439a078
--- /dev/null
+++ b/pkg/sentry/socket/control/control.go
@@ -0,0 +1,591 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control provides internal representations of socket control
+// messages.
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const maxInt = int(^uint(0) >> 1)
+
+// SCMCredentials represents a SCM_CREDENTIALS socket control message.
+type SCMCredentials interface {
+ transport.CredentialsControlMessage
+
+ // Credentials returns properly namespaced values for the sender's pid, uid
+ // and gid.
+ Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
+}
+
+// LINT.IfChange
+
+// SCMRights represents a SCM_RIGHTS socket control message.
+type SCMRights interface {
+ transport.RightsControlMessage
+
+ // Files returns up to max RightsFiles.
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFiles, truncated bool)
+}
+
+// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
+// maintained for each fs.File and is released either when an FD is created or
+// when the Release method is called.
+//
+// +stateify savable
+type RightsFiles []*fs.File
+
+// NewSCMRights creates a new SCM_RIGHTS socket control message representation
+// using local sentry FDs.
+func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
+ files := make(RightsFiles, 0, len(fds))
+ for _, fd := range fds {
+ file := t.GetFile(fd)
+ if file == nil {
+ files.Release()
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRights.Files.
+func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
+ n := max
+ var trunc bool
+ if l := len(*fs); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (fs *RightsFiles) Clone() transport.RightsControlMessage {
+ nfs := append(RightsFiles(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (fs *RightsFiles) Release() {
+ for _, f := range *fs {
+ f.DecRef()
+ }
+ *fs = nil
+}
+
+// rightsFDs gets up to the specified maximum number of FDs.
+func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
+ fds := make([]int32, 0, len(files))
+ for i := 0; i < max && len(files) > 0; i++ {
+ fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
+ CloseOnExec: cloexec,
+ })
+ files[0].DecRef()
+ files = files[1:]
+ if err != nil {
+ t.Warningf("Error inserting FD: %v", err)
+ // This is what Linux does.
+ break
+ }
+
+ fds = append(fds, int32(fd))
+ }
+ return fds, trunc
+}
+
+// PackRights packs as many FDs as will fit into the unused capacity of buf.
+func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDs(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
+ }
+ align := t.Arch().Width()
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
+}
+
+// LINT.ThenChange(./control_vfs2.go)
+
+// scmCredentials represents an SCM_CREDENTIALS socket control message.
+//
+// +stateify savable
+type scmCredentials struct {
+ t *kernel.Task
+ kuid auth.KUID
+ kgid auth.KGID
+}
+
+// NewSCMCredentials creates a new SCM_CREDENTIALS socket control message
+// representation.
+func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SCMCredentials, error) {
+ tcred := t.Credentials()
+ kuid, err := tcred.UseUID(auth.UID(cred.UID))
+ if err != nil {
+ return nil, err
+ }
+ kgid, err := tcred.UseGID(auth.GID(cred.GID))
+ if err != nil {
+ return nil, err
+ }
+ if kernel.ThreadID(cred.PID) != t.ThreadGroup().ID() && !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.PIDNamespace().UserNamespace()) {
+ return nil, syserror.EPERM
+ }
+ return &scmCredentials{t, kuid, kgid}, nil
+}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool {
+ if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc {
+ return true
+ }
+ return false
+}
+
+func putUint64(buf []byte, n uint64) []byte {
+ usermem.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
+ return buf[:len(buf)+8]
+}
+
+func putUint32(buf []byte, n uint32) []byte {
+ usermem.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
+ return buf[:len(buf)+4]
+}
+
+// putCmsg writes a control message header and as much data as will fit into
+// the unused capacity of a buffer.
+func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
+
+ // We can't write to space that doesn't exist, so if we are going to align
+ // the available space, we must align down.
+ //
+ // align must be >= 4 and each data int32 is 4 bytes. The length of the
+ // header is already aligned, so if we align to the width of the data there
+ // are two cases:
+ // 1. The aligned length is less than the length of the header. The
+ // unaligned length was also less than the length of the header, so we
+ // can't write anything.
+ // 2. The aligned length is greater than or equal to the length of the
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
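+ // For example, with align=8 and a single int32 of data, length below
+ // is the 16-byte header plus 4 bytes of data, and alignSlice pads the
+ // result out to 24 bytes, capacity permitting.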
+ if space < linux.SizeOfControlMessageHeader {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+
+ length := 4*len(data) + linux.SizeOfControlMessageHeader
+ if length > space {
+ length = space
+ }
+ buf = putUint64(buf, uint64(length))
+ buf = putUint32(buf, linux.SOL_SOCKET)
+ buf = putUint32(buf, msgType)
+ for _, d := range data {
+ if len(buf)+4 > cap(buf) {
+ flags |= linux.MSG_CTRUNC
+ break
+ }
+ buf = putUint32(buf, uint32(d))
+ }
+ return alignSlice(buf, align), flags
+}
+
+func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte {
+ if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
+ return buf
+ }
+ ob := buf
+
+ buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader))
+ buf = putUint32(buf, msgLevel)
+ buf = putUint32(buf, msgType)
+
+ hdrBuf := buf
+
+ buf = binary.Marshal(buf, usermem.ByteOrder, data)
+
+ // If the control message data brought us over capacity, omit it.
+ if cap(buf) != cap(ob) {
+ return hdrBuf
+ }
+
+ // Update control message length to include data.
+ putUint64(ob, uint64(len(buf)-len(ob)))
+
+ return alignSlice(buf, align)
+}
+
+// Credentials implements SCMCredentials.Credentials.
+func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ // "When a process's user and group IDs are passed over a UNIX domain
+ // socket to a process in a different user namespace (see the description
+ // of SCM_CREDENTIALS in unix(7)), they are translated into the
+ // corresponding values as per the receiving process's user and group ID
+ // mappings." - user_namespaces(7)
+ pid := t.PIDNamespace().IDOfTask(c.t)
+ uid := c.kuid.In(t.UserNamespace()).OrOverflow()
+ gid := c.kgid.In(t.UserNamespace()).OrOverflow()
+
+ return pid, uid, gid
+}
+
+// PackCredentials packs the credentials in the control message (or default
+// credentials if none) into a buffer.
+func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) {
+ align := t.Arch().Width()
+
+ // Default credentials if none are available.
+ pid := kernel.ThreadID(0)
+ uid := auth.UID(auth.NobodyKUID)
+ gid := auth.GID(auth.NobodyKGID)
+
+ if creds != nil {
+ pid, uid, gid = creds.Credentials(t)
+ }
+ c := []int32{int32(pid), int32(uid), int32(gid)}
+ return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
+}
+
+// alignSlice extends a slice's length (up to the capacity) to align it.
+func alignSlice(buf []byte, align uint) []byte {
+ aligned := binary.AlignUp(len(buf), align)
+ if aligned > cap(buf) {
+ // Linux allows unaligned data if there isn't room for alignment.
+ // Since there isn't room for alignment, there isn't room for any
+ // additional messages either.
+ return buf
+ }
+ return buf[:aligned]
+}
+
+// PackTimestamp packs a SO_TIMESTAMP socket control message.
+func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_SOCKET,
+ linux.SO_TIMESTAMP,
+ t.Arch().Width(),
+ linux.NsecToTimeval(timestamp),
+ )
+}
+
+// PackInq packs a TCP_INQ socket control message.
+func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_TCP,
+ linux.TCP_INQ,
+ t.Arch().Width(),
+ inq,
+ )
+}
+
+// PackTOS packs an IP_TOS socket control message.
+func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_TOS,
+ t.Arch().Width(),
+ tos,
+ )
+}
+
+// PackTClass packs an IPV6_TCLASS socket control message.
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_TCLASS,
+ t.Arch().Width(),
+ tClass,
+ )
+}
+
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
+// PackControlMessages packs control messages into the given buffer.
+//
+// We skip control messages specific to Unix domain sockets.
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+ if cmsgs.IP.HasTimestamp {
+ buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
+ }
+
+ if cmsgs.IP.HasInq {
+ // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
+ buf = PackInq(t, cmsgs.IP.Inq, buf)
+ }
+
+ if cmsgs.IP.HasTOS {
+ buf = PackTOS(t, cmsgs.IP.TOS, buf)
+ }
+
+ if cmsgs.IP.HasTClass {
+ buf = PackTClass(t, cmsgs.IP.TClass, buf)
+ }
+
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
+ return buf
+}
+
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
+// Parse parses a raw socket control message into portable objects.
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
+ var (
+ cmsgs socket.ControlMessages
+ fds linux.ControlMessageRights
+ )
+
+ for i := 0; i < len(buf); {
+ if i+linux.SizeOfControlMessageHeader > len(buf) {
+ return cmsgs, syserror.EINVAL
+ }
+
+ var h linux.ControlMessageHeader
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+
+ if h.Length < uint64(linux.SizeOfControlMessageHeader) {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ if h.Length > uint64(len(buf)-i) {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ i += linux.SizeOfControlMessageHeader
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+
+ // The use of t.Arch().Width() is analogous to Linux's use of
+ // sizeof(long) in CMSG_ALIGN.
+ width := t.Arch().Width()
+
+ switch h.Level {
+ case linux.SOL_SOCKET:
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += binary.AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ scmCreds, err := NewSCMCredentials(t, creds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Credentials = scmCreds
+ i += binary.AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ case linux.SOL_IP:
+ switch h.Type {
+ case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTOS = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ case linux.SOL_IPV6:
+ switch h.Type {
+ case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTClass = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ }
+
+ if cmsgs.Unix.Credentials == nil {
+ cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
+ }
+
+ if len(fds) > 0 {
+ if kernel.VFS2Enabled {
+ rights, err := NewSCMRightsVFS2(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
+ } else {
+ rights, err := NewSCMRights(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
+ }
+ }
+
+ return cmsgs, nil
+}
+
+func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
+ if t == nil || socketOrEndpoint == nil {
+ return nil
+ }
+ if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) {
+ return MakeCreds(t)
+ }
+ return nil
+}
+
+// MakeCreds creates default SCMCredentials.
+func MakeCreds(t *kernel.Task) SCMCredentials {
+ if t == nil {
+ return nil
+ }
+ tcred := t.Credentials()
+ return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
+}
+
+// LINT.IfChange
+
+// New creates default control messages if needed.
+func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages {
+ return transport.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
+
+// LINT.ThenChange(./control_vfs2.go)
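A sketch of how a socket implementation might use the packing helpers above
while honoring truncation; packCreds is hypothetical, and t is assumed to be
the *kernel.Task of the calling syscall context.

func packCreds(t *kernel.Task, creds control.SCMCredentials) ([]byte, int) {
	// Capacity, not length, bounds how much putCmsg may append.
	buf := make([]byte, 0, 64)
	buf, flags := control.PackCredentials(t, creds, buf, 0)
	if flags&linux.MSG_CTRUNC != 0 {
		// The buffer was too small: as in Linux, the message was
		// truncated and MSG_CTRUNC is reported to the caller.
	}
	return buf, flags
}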
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
new file mode 100644
index 000000000..fd08179be
--- /dev/null
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+type SCMRightsVFS2 interface {
+ transport.RightsControlMessage
+
+ // Files returns up to max RightsFiles.
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
+}
+
+// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference
+// is maintained for each vfs.FileDescription and is released either when an FD
+// is created or when the Release method is called.
+type RightsFilesVFS2 []*vfs.FileDescription
+
+// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
+// representation using local sentry FDs.
+func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) {
+ files := make(RightsFilesVFS2, 0, len(fds))
+ for _, fd := range fds {
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ files.Release()
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRightsVFS2.Files.
+func (fs *RightsFilesVFS2) Files(ctx context.Context, max int) (RightsFilesVFS2, bool) {
+ n := max
+ var trunc bool
+ if l := len(*fs); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage {
+ nfs := append(RightsFilesVFS2(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (fs *RightsFilesVFS2) Release() {
+ for _, f := range *fs {
+ f.DecRef()
+ }
+ *fs = nil
+}
+
+// rightsFDsVFS2 gets up to the specified maximum number of FDs.
+func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
+ fds := make([]int32, 0, len(files))
+ for i := 0; i < max && len(files) > 0; i++ {
+ fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{
+ CloseOnExec: cloexec,
+ })
+ files[0].DecRef()
+ files = files[1:]
+ if err != nil {
+ t.Warningf("Error inserting FD: %v", err)
+ // This is what Linux does.
+ break
+ }
+
+ fds = append(fds, int32(fd))
+ }
+ return fds, trunc
+}
+
+// PackRightsVFS2 packs as many FDs as will fit into the unused capacity of buf.
+func PackRightsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, buf []byte, flags int) ([]byte, int) {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDsVFS2(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
+ }
+ align := t.Arch().Width()
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
+}
+
+// NewVFS2 creates default control messages if needed.
+func NewVFS2(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRightsVFS2) transport.ControlMessages {
+ return transport.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
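The capacity computation in PackRightsVFS2 can be sanity-checked in isolation. Below is a minimal sketch of the same arithmetic, assuming the 16-byte cmsghdr of 64-bit Linux in place of linux.SizeOfControlMessageHeader.

    package main

    import "fmt"

    // fdsThatFit mirrors the arithmetic in PackRightsVFS2: each FD costs 4
    // bytes, and the FDs share buf's unused capacity with one cmsg header.
    func fdsThatFit(bufCap, bufLen, cmsgHdrSize int) int {
        n := (bufCap - bufLen - cmsgHdrSize) / 4
        if n < 0 {
            return 0
        }
        return n
    }

    func main() {
        // With an empty 64-byte buffer and a 16-byte cmsghdr:
        // (64 - 0 - 16) / 4 = 12 descriptors fit.
        fmt.Println(fdsThatFit(64, 0, 16)) // 12
        // No room even for the header: Linux (and the code above) sets
        // MSG_CTRUNC and sends no FDs at all.
        fmt.Println(fdsThatFit(8, 0, 16)) // 0
    }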
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
new file mode 100644
index 000000000..ff81ea6e6
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hostinet",
+ srcs = [
+ "device.go",
+ "hostinet.go",
+ "save_restore.go",
+ "socket.go",
+ "socket_unsafe.go",
+ "socket_vfs2.go",
+ "sockopt_impl.go",
+ "stack.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/hostfd",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip/stack",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/hostinet/device.go b/pkg/sentry/socket/hostinet/device.go
new file mode 100644
index 000000000..27049d65f
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/device.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+var socketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/hostinet/hostinet.go b/pkg/sentry/socket/hostinet/hostinet.go
new file mode 100644
index 000000000..0d6f51d2b
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/hostinet.go
@@ -0,0 +1,17 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostinet implements AF_INET and AF_INET6 sockets using the host's
+// network stack.
+package hostinet
diff --git a/pkg/sentry/socket/hostinet/save_restore.go b/pkg/sentry/socket/hostinet/save_restore.go
new file mode 100644
index 000000000..1dec33897
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+// beforeSave is invoked by stateify.
+func (*socketOperations) beforeSave() {
+ panic("host.socketOperations is not savable")
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
new file mode 100644
index 000000000..a92aed2c9
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -0,0 +1,713 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ sizeofInt32 = 4
+
+ // sizeofSockaddr is the size in bytes of the largest sockaddr type
+ // supported by this package.
+ sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+
+ // maxControlLen is the maximum size of a control message buffer used in a
+ // recvmsg or sendmsg syscall.
+ maxControlLen = 1024
+)
+
+// LINT.IfChange
+
+// socketOperations implements fs.FileOperations and socket.Socket for a socket
+// implemented using a host socket.
+type socketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+ queue waiter.Queue
+
+ // fd is the host socket fd. It must have O_NONBLOCK, so that operations
+ // will return EWOULDBLOCK instead of blocking on the host. This allows us to
+ // handle blocking behavior independently in the sentry.
+ fd int
+}
+
+var _ = socket.Socket(&socketOperations{})
+
+func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ s := &socketOperations{
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
+ }
+ if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ dirent := socket.NewDirent(ctx, socketDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ fdnotifier.RemoveFD(int32(s.fd))
+ syscall.Close(s.fd)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, io, args)
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if dsts.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Read(s.fd, dsts.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
+ }))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if srcs.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Write(s.fd, srcs.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
+ }))
+ return int64(n), err
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+
+ if errno == 0 {
+ return nil
+ }
+ if errno != syscall.EINPROGRESS || !blocking {
+ return syserr.FromError(translateIOSyscallError(errno))
+ }
+
+ // "EINPROGRESS: The socket is nonblocking and the connection cannot be
+ // completed immediately. It is possible to select(2) or poll(2) for
+ // completion by selecting the socket for writing. After select(2)
+ // indicates writability, use getsockopt(2) to read the SO_ERROR option at
+	// level SOL_SOCKET to determine whether connect() completed successfully
+ // (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error
+ // codes listed here, explaining the reason for the failure)." - connect(2)
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ if s.Readiness(waiter.EventOut)&waiter.EventOut == 0 {
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+ }
+ val, err := syscall.GetsockoptInt(s.fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ if val != 0 {
+ return syserr.FromError(syscall.Errno(uintptr(val)))
+ }
+ return nil
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ var peerAddr linux.SockAddr
+ var peerAddrBuf []byte
+ var peerAddrlen uint32
+ var peerAddrPtr *byte
+ var peerAddrlenPtr *uint32
+ if peerRequested {
+ peerAddrBuf = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddrBuf))
+ peerAddrPtr = &peerAddrBuf[0]
+ peerAddrlenPtr = &peerAddrlen
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOpsCommon requires it.
+ fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ if blocking {
+ var ch chan struct{}
+ for syscallErr == syserror.ErrWouldBlock {
+ if ch != nil {
+ if syscallErr = t.Block(ch); syscallErr != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ fd, syscallErr = accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ }
+ }
+
+ if peerRequested {
+ peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
+ }
+ if syscallErr != nil {
+ return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
+ }
+
+ var (
+ kfd int32
+ kerr error
+ )
+ if kernel.VFS2Enabled {
+ f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK))
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
+
+ kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocketVFS2(f)
+ } else {
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
+
+ kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocket(f)
+ }
+
+ return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_BIND, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.FromError(syscall.Listen(s.fd, backlog))
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ switch how {
+ case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR:
+ return syserr.FromError(syscall.Shutdown(s.fd, how))
+ default:
+ return syserr.ErrInvalidArgument
+ }
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ if outLen < 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only allow known and safe options.
+ optlen := getSockOptLen(t, level, name)
+ switch level {
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
+ switch name {
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ optlen = sizeofInt32
+ case linux.SO_LINGER:
+ optlen = syscall.SizeofLinger
+ }
+ case linux.SOL_TCP:
+ switch name {
+ case linux.TCP_NODELAY:
+ optlen = sizeofInt32
+ case linux.TCP_INFO:
+ optlen = int(linux.SizeOfTCPInfo)
+ }
+ }
+
+ if optlen == 0 {
+ return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
+ }
+ if outLen < optlen {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ opt, err := getsockopt(s.fd, level, name, optlen)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return opt, nil
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // Only allow known and safe options.
+ optlen := setSockOptLen(t, level, name)
+ switch level {
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS:
+ optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
+ }
+ case linux.SOL_IPV6:
+ switch name {
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_TCP:
+ switch name {
+ case linux.TCP_NODELAY:
+ optlen = sizeofInt32
+ }
+ }
+
+ if optlen == 0 {
+ // Pretend to accept socket options we don't understand. This seems
+ // dangerous, but it's what netstack does...
+ return nil
+ }
+ if len(opt) < optlen {
+ return syserr.ErrInvalidArgument
+ }
+ opt = opt[:optlen]
+
+ _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(s.fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(len(opt)), 0)
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
+ //
+ // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
+ // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
+ // Socket interface's dependence on netstack.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
+
+ var senderAddr linux.SockAddr
+ var senderAddrBuf []byte
+ if senderRequested {
+ senderAddrBuf = make([]byte, sizeofSockaddr)
+ }
+
+ var controlBuf []byte
+ var msgFlags int
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(senderAddrBuf) != 0 {
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(len(senderAddrBuf))
+ }
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0, err
+ }
+ senderAddrBuf = senderAddrBuf[:msg.Namelen]
+ msgFlags = int(msg.Flags)
+ controlLen = uint64(msg.Controllen)
+ return n, nil
+ })
+
+ var ch chan struct{}
+ n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyOutFrom: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
+ }
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ if senderRequested {
+ senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
+ }
+
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ controlMessages := socket.ControlMessages{}
+ for _, unixCmsg := range unixControlMessages {
+ switch unixCmsg.Header.Level {
+ case syscall.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case syscall.IP_TOS:
+ controlMessages.IP.HasTOS = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ }
+
+ case syscall.SOL_IPV6:
+ switch unixCmsg.Header.Type {
+ case syscall.IPV6_TCLASS:
+ controlMessages.IP.HasTClass = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+ }
+ }
+ }
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Only allow known and safe flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
+ sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() && len(controlBuf) == 0 {
+ return 0, nil
+ }
+
+ // We always do a non-blocking send*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
+ // Skip allocating []syscall.Iovec.
+ src := srcs.Head()
+ n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(to) != 0 {
+ msg.Name = &to[0]
+ msg.Namelen = uint32(len(to))
+ }
+ if len(controlBuf) != 0 {
+ msg.Control = &controlBuf[0]
+ msg.Controllen = uint64(len(controlBuf))
+ }
+ return sendmsg(s.fd, &msg, sysflags)
+ })
+
+ var ch chan struct{}
+ n, err := src.CopyInTo(t, sendmsgFromBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ }
+ n, err = src.CopyInTo(t, sendmsgFromBlocks)
+ }
+ }
+
+ return int(n), syserr.FromError(err)
+}
+
+func translateIOSyscallError(err error) error {
+ if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+// State implements socket.Socket.State.
+func (s *socketOpsCommon) State() uint32 {
+ info := linux.TCPInfo{}
+ buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo)
+ if err != nil {
+ if err != syscall.ENOPROTOOPT {
+ log.Warningf("Failed to get TCP socket info from %+v: %v", s, err)
+ }
+ // For non-TCP sockets, silently ignore the failure.
+ return 0
+ }
+ if len(buf) != linux.SizeOfTCPInfo {
+ // Unmarshal below will panic if getsockopt returns a buffer of
+ // unexpected size.
+ log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo)
+ return 0
+ }
+
+ binary.Unmarshal(buf, usermem.ByteOrder, &info)
+ return uint32(info.State)
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
+type socketProvider struct {
+ family int
+}
+
+// Socket implements socket.Provider.Socket.
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := stypeflags & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
+
+// LINT.ThenChange(./socket_vfs2.go)
+
+func init() {
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socket.RegisterProvider(family, &socketProvider{family})
+ socket.RegisterProviderVFS2(family, &socketProviderVFS2{})
+ }
+}
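All of the blocking paths above (Connect, Accept, RecvMsg, SendMsg) share one shape: the host FD is permanently O_NONBLOCK, the syscall is attempted, and on EWOULDBLOCK the task blocks until readiness and retries. Here is a standalone sketch of that loop, with poll(2) standing in for the sentry's waiter.Entry/t.Block machinery; the socketpair is illustrative only.

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    // readRetry mirrors the sentry's pattern: the FD stays nonblocking, and
    // blocking is implemented around the syscall rather than inside it.
    func readRetry(fd int, buf []byte) (int, error) {
        for {
            n, err := unix.Read(fd, buf)
            if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
                // Block until the FD is readable, then retry the read.
                pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
                if _, perr := unix.Poll(pfd, -1); perr != nil {
                    return 0, perr
                }
                continue
            }
            return n, err
        }
    }

    func main() {
        fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
        if err != nil {
            panic(err)
        }
        unix.SetNonblock(fds[0], true)
        go unix.Write(fds[1], []byte("hello"))

        buf := make([]byte, 16)
        n, err := readRetry(fds[0], buf)
        if err != nil {
            panic(err)
        }
        fmt.Printf("read %q\n", buf[:n])
    }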
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
new file mode 100644
index 000000000..3f420c2ec
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func firstBytePtr(bs []byte) unsafe.Pointer {
+ if bs == nil {
+ return nil
+ }
+ return unsafe.Pointer(&bs[0])
+}
+
+// Preconditions: len(dsts) != 0.
+func readv(fd int, dsts []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&dsts[0])), uintptr(len(dsts)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+// Preconditions: len(srcs) != 0.
+func writev(fd int, srcs []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&srcs[0])), uintptr(len(srcs)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := uintptr(args[1].Int()); cmd {
+ case syscall.TIOCINQ, syscall.TIOCOUTQ:
+ var val int32
+ if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], uint32(val))
+ _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func accept4(fd int, addr *byte, addrlen *uint32, flags int) (int, error) {
+ afd, _, errno := syscall.Syscall6(syscall.SYS_ACCEPT4, uintptr(fd), uintptr(unsafe.Pointer(addr)), uintptr(unsafe.Pointer(addrlen)), uintptr(flags), 0, 0)
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return int(afd), nil
+}
+
+func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
+ opt := make([]byte, optlen)
+ optlen32 := int32(len(opt))
+ _, _, errno := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(unsafe.Pointer(&optlen32)), 0)
+ if errno != 0 {
+ return nil, errno
+ }
+ return opt[:optlen32], nil
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
+}
+
+func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {
+ fromLen := uint32(len(*from))
+ n, _, errno := syscall.Syscall6(syscall.SYS_RECVFROM, uintptr(fd), uintptr(firstBytePtr(dst)), uintptr(len(dst)), uintptr(flags), uintptr(firstBytePtr(*from)), uintptr(unsafe.Pointer(&fromLen)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ *from = (*from)[:fromLen]
+ return uint64(n), nil
+}
+
+func recvmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+func sendmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
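firstBytePtr deserves a note: taking &bs[0] of an empty slice would panic, and the host kernel expects NULL when there is no buffer. A small sketch of the same guard follows; this version also handles the empty-but-non-nil case, which the helper above leaves to its callers.

    package main

    import (
        "fmt"
        "unsafe"
    )

    // firstBytePtr maps a nil or empty slice to a nil pointer, since &bs[0]
    // panics on an empty slice and raw syscalls expect NULL for "no buffer".
    func firstBytePtr(bs []byte) unsafe.Pointer {
        if len(bs) == 0 {
            return nil
        }
        return unsafe.Pointer(&bs[0])
    }

    func main() {
        fmt.Println(firstBytePtr(nil))         // <nil>
        fmt.Println(firstBytePtr([]byte{}))    // <nil>
        fmt.Println(firstBytePtr([]byte{0x1})) // non-nil address
    }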
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
new file mode 100644
index 000000000..8f192c62f
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -0,0 +1,202 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type socketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ // We store metadata for hostinet sockets internally. Technically, we should
+ // access metadata (e.g. through stat, chmod) on the host for correctness,
+ // but this is not very useful for inet socket fds, which do not belong to a
+ // concrete file anyway.
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&socketVFS2{})
+
+func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ s := &socketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
+ }
+ s.LockFD.Init(&vfs.FileLocks{})
+ if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ vfsfd := &s.vfsfd
+ if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, uio, args)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
+ n, err := dst.CopyOutFrom(ctx, reader)
+ hostfd.PutReadWriterAt(reader)
+ return int64(n), err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
+ n, err := src.CopyInTo(ctx, writer)
+ hostfd.PutReadWriterAt(writer)
+ return int64(n), err
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
+type socketProviderVFS2 struct {
+ family int
+}
+
+// Socket implements socket.ProviderVFS2.Socket.
+func (p *socketProviderVFS2) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := stypeflags & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newVFS2Socket(t, p.family, stype, protocol, fd, uint32(stypeflags&syscall.SOCK_NONBLOCK))
+}
+
+// Pair implements socket.ProviderVFS2.Pair.
+func (p *socketProviderVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
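Both Socket implementations separate the socket type from the flag bits OR'd into the type argument of socket(2). A standalone sketch of that masking, assuming SOCK_TYPE_MASK is 0xf as on Linux:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // Callers of socket(2) may OR flag bits into the type argument; the
        // providers above mask them apart with linux.SOCK_TYPE_MASK.
        const sockTypeMask = 0xf // assumption: the Linux SOCK_TYPE_MASK value

        stypeflags := unix.SOCK_STREAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC
        stype := stypeflags & sockTypeMask

        fmt.Println(stype == unix.SOCK_STREAM)          // true
        fmt.Println(stypeflags&unix.SOCK_NONBLOCK != 0) // true: file flag
        fmt.Println(stypeflags&unix.SOCK_CLOEXEC != 0)  // true: FD flag
    }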
diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go
new file mode 100644
index 000000000..8a783712e
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/sockopt_impl.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+func getSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
+
+func setSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
new file mode 100644
index 000000000..a48082631
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -0,0 +1,459 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var defaultRecvBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 87380,
+ Max: 6291456,
+}
+
+var defaultSendBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 16384,
+ Max: 4194304,
+}
+
+// Stack implements inet.Stack for host sockets.
+type Stack struct {
+ // Stack is immutable.
+ interfaces map[int32]inet.Interface
+ interfaceAddrs map[int32][]inet.InterfaceAddr
+ routes []inet.Route
+ supportsIPv6 bool
+ tcpRecvBufSize inet.TCPBufferSize
+ tcpSendBufSize inet.TCPBufferSize
+ tcpSACKEnabled bool
+ netDevFile *os.File
+ netSNMPFile *os.File
+}
+
+// NewStack returns an empty Stack containing no configuration.
+func NewStack() *Stack {
+ return &Stack{
+ interfaces: make(map[int32]inet.Interface),
+ interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
+ }
+}
+
+// Configure sets up the stack using the current state of the host network.
+func (s *Stack) Configure() error {
+ if err := addHostInterfaces(s); err != nil {
+ return err
+ }
+
+ if err := addHostRoutes(s); err != nil {
+ return err
+ }
+
+ if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
+ s.supportsIPv6 = true
+ }
+
+ s.tcpRecvBufSize = defaultRecvBufSize
+ if tcpRMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_rmem"); err == nil {
+ s.tcpRecvBufSize = tcpRMem
+ } else {
+ log.Warningf("Failed to read TCP receive buffer size, using default values")
+ }
+
+ s.tcpSendBufSize = defaultSendBufSize
+ if tcpWMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_wmem"); err == nil {
+ s.tcpSendBufSize = tcpWMem
+ } else {
+ log.Warningf("Failed to read TCP send buffer size, using default values")
+ }
+
+	// SACK is important for performance and even compatibility; assume it's
+	// enabled if we can't find the actual value.
+ s.tcpSACKEnabled = true
+ if sack, err := ioutil.ReadFile("/proc/sys/net/ipv4/tcp_sack"); err == nil {
+ s.tcpSACKEnabled = strings.TrimSpace(string(sack)) != "0"
+ } else {
+ log.Warningf("Failed to read if TCP SACK if enabled, setting to true")
+ }
+
+ if f, err := os.Open("/proc/net/dev"); err != nil {
+ log.Warningf("Failed to open /proc/net/dev: %v", err)
+ } else {
+ s.netDevFile = f
+ }
+
+ if f, err := os.Open("/proc/net/snmp"); err != nil {
+ log.Warningf("Failed to open /proc/net/snmp: %v", err)
+ } else {
+ s.netSNMPFile = f
+ }
+
+ return nil
+}
+
+// ExtractHostInterfaces populates the given interfaces and interfaceAddrs maps
+// from the given netlink messages, which should be the results of RTM_GETLINK
+// and RTM_GETADDR requests respectively.
+func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.NetlinkMessage, interfaces map[int32]inet.Interface, interfaceAddrs map[int32][]inet.InterfaceAddr) error {
+ for _, link := range links {
+ if link.Header.Type != syscall.RTM_NEWLINK {
+ continue
+ }
+ if len(link.Data) < syscall.SizeofIfInfomsg {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), syscall.SizeofIfInfomsg)
+ }
+ var ifinfo syscall.IfInfomsg
+ binary.Unmarshal(link.Data[:syscall.SizeofIfInfomsg], usermem.ByteOrder, &ifinfo)
+ inetIF := inet.Interface{
+ DeviceType: ifinfo.Type,
+ Flags: ifinfo.Flags,
+ }
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct ifinfomsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&link)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFLA_ADDRESS:
+ inetIF.Addr = attr.Value
+ case syscall.IFLA_IFNAME:
+ inetIF.Name = string(attr.Value[:len(attr.Value)-1])
+ }
+ }
+ interfaces[ifinfo.Index] = inetIF
+ }
+
+ for _, addr := range addrs {
+ if addr.Header.Type != syscall.RTM_NEWADDR {
+ continue
+ }
+ if len(addr.Data) < syscall.SizeofIfAddrmsg {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), syscall.SizeofIfAddrmsg)
+ }
+ var ifaddr syscall.IfAddrmsg
+ binary.Unmarshal(addr.Data[:syscall.SizeofIfAddrmsg], usermem.ByteOrder, &ifaddr)
+ inetAddr := inet.InterfaceAddr{
+ Family: ifaddr.Family,
+ PrefixLen: ifaddr.Prefixlen,
+ Flags: ifaddr.Flags,
+ }
+ attrs, err := syscall.ParseNetlinkRouteAttr(&addr)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFA_ADDRESS:
+ inetAddr.Addr = attr.Value
+ }
+ }
+ interfaceAddrs[int32(ifaddr.Index)] = append(interfaceAddrs[int32(ifaddr.Index)], inetAddr)
+ }
+
+ return nil
+}
+
+// ExtractHostRoutes populates the given routes slice with the data from the
+// host route table.
+func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) {
+ var routes []inet.Route
+ for _, routeMsg := range routeMsgs {
+ if routeMsg.Header.Type != syscall.RTM_NEWROUTE {
+ continue
+ }
+
+ var ifRoute syscall.RtMsg
+ binary.Unmarshal(routeMsg.Data[:syscall.SizeofRtMsg], usermem.ByteOrder, &ifRoute)
+ inetRoute := inet.Route{
+ Family: ifRoute.Family,
+ DstLen: ifRoute.Dst_len,
+ SrcLen: ifRoute.Src_len,
+ TOS: ifRoute.Tos,
+ Table: ifRoute.Table,
+ Protocol: ifRoute.Protocol,
+ Scope: ifRoute.Scope,
+ Type: ifRoute.Type,
+ Flags: ifRoute.Flags,
+ }
+
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct rtmsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&routeMsg)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid rtattrs: %v", err)
+ }
+
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.RTA_DST:
+ inetRoute.DstAddr = attr.Value
+ case syscall.RTA_SRC:
+ inetRoute.SrcAddr = attr.Value
+ case syscall.RTA_GATEWAY:
+ inetRoute.GatewayAddr = attr.Value
+ case syscall.RTA_OIF:
+ expected := int(binary.Size(inetRoute.OutputInterface))
+ if len(attr.Value) != expected {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
+ }
+ binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
+ }
+ }
+
+ routes = append(routes, inetRoute)
+ }
+
+ return routes, nil
+}
+
+func addHostInterfaces(s *Stack) error {
+ links, err := doNetlinkRouteRequest(syscall.RTM_GETLINK)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK failed: %v", err)
+ }
+
+ addrs, err := doNetlinkRouteRequest(syscall.RTM_GETADDR)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR failed: %v", err)
+ }
+
+ return ExtractHostInterfaces(links, addrs, s.interfaces, s.interfaceAddrs)
+}
+
+func addHostRoutes(s *Stack) error {
+ routes, err := doNetlinkRouteRequest(syscall.RTM_GETROUTE)
+ if err != nil {
+ return fmt.Errorf("RTM_GETROUTE failed: %v", err)
+ }
+
+ s.routes, err = ExtractHostRoutes(routes)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
+ data, err := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, err
+ }
+ return syscall.ParseNetlinkMessage(data)
+}
+
+func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {
+ contents, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to read %s: %v", filename, err)
+ }
+ ioseq := usermem.BytesIOSequence(contents)
+ fields := make([]int32, 3)
+ if n, err := usermem.CopyInt32StringsInVec(context.Background(), ioseq.IO, ioseq.Addrs, fields, ioseq.Opts); n != ioseq.NumBytes() || err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to parse %s (%q): got %v after %d/%d bytes", filename, contents, err, n, ioseq.NumBytes())
+ }
+ return inet.TCPBufferSize{
+ Min: int(fields[0]),
+ Default: int(fields[1]),
+ Max: int(fields[2]),
+ }, nil
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ interfaces := make(map[int32]inet.Interface)
+ for k, v := range s.interfaces {
+ interfaces[k] = v
+ }
+ return interfaces
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ addrs := make(map[int32][]inet.InterfaceAddr)
+ for k, v := range s.interfaceAddrs {
+ addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ }
+ return addrs
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
+// SupportsIPv6 implements inet.Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ return s.supportsIPv6
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpRecvBufSize, nil
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpSendBufSize, nil
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ return s.tcpSACKEnabled, nil
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserror.EACCES
+}
+
+// getLine reads the line with the specified prefix from the given proc file.
+// The withHeader argument specifies whether the first matching line is a
+// column header that should be skipped.
+func getLine(f *os.File, prefix string, withHeader bool) string {
+ data := make([]byte, 4096)
+
+ if _, err := f.Seek(0, 0); err != nil {
+ return ""
+ }
+
+ if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF {
+ return ""
+ }
+
+ prefix = prefix + ":"
+ lines := strings.Split(string(data), "\n")
+ for _, l := range lines {
+ l = strings.TrimSpace(l)
+ if strings.HasPrefix(l, prefix) {
+ if withHeader {
+ withHeader = false
+ continue
+ }
+ return l
+ }
+ }
+ return ""
+}
+
+func toSlice(i interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(i))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
+// Statistics implements inet.Stack.Statistics.
+func (s *Stack) Statistics(stat interface{}, arg string) error {
+ var (
+ snmpTCP bool
+ rawLine string
+ sliceStat []uint64
+ )
+
+ switch stat.(type) {
+ case *inet.StatDev:
+ if s.netDevFile == nil {
+ return fmt.Errorf("/proc/net/dev is not opened for hostinet")
+ }
+		rawLine = getLine(s.netDevFile, arg, false /* withHeader */)
+ case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite:
+ if s.netSNMPFile == nil {
+ return fmt.Errorf("/proc/net/snmp is not opened for hostinet")
+ }
+ rawLine = getLine(s.netSNMPFile, arg, true)
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+
+ if rawLine == "" {
+ return fmt.Errorf("Failed to get raw line")
+ }
+
+ parts := strings.SplitN(rawLine, ":", 2)
+ if len(parts) != 2 {
+ return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ }
+
+ sliceStat = toSlice(stat)
+ fields := strings.Fields(strings.TrimSpace(parts[1]))
+ if len(fields) != len(sliceStat) {
+ return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ }
+ if _, ok := stat.(*inet.StatSNMPTCP); ok {
+ snmpTCP = true
+ }
+ for i := 0; i < len(sliceStat); i++ {
+ var err error
+ if snmpTCP && i == 3 {
+ var tmp int64
+			// The MaxConn field is signed, per RFC 2012.
+ tmp, err = strconv.ParseInt(fields[i], 10, 64)
+ sliceStat[i] = uint64(tmp) // Convert back to int before use.
+ } else {
+ sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
+ }
+ if err != nil {
+ return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ }
+ }
+
+ return nil
+}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ return append([]inet.Route(nil), s.routes...)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
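readTCPBufferSizeFile expects the three-integer layout of /proc/sys/net/ipv4/tcp_rmem and tcp_wmem. Below is a standalone sketch of the same parse, using only the standard library in place of the usermem helpers; the function name is hypothetical.

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    // parseTCPBufferSize extracts what readTCPBufferSizeFile reads from
    // tcp_rmem/tcp_wmem: three whitespace-separated integers meaning the
    // min, default, and max buffer sizes in bytes.
    func parseTCPBufferSize(contents string) (min, def, max int, err error) {
        fields := strings.Fields(strings.TrimSpace(contents))
        if len(fields) != 3 {
            return 0, 0, 0, fmt.Errorf("expected 3 fields, got %d", len(fields))
        }
        vals := make([]int, 3)
        for i, f := range fields {
            vals[i], err = strconv.Atoi(f)
            if err != nil {
                return 0, 0, 0, err
            }
        }
        return vals[0], vals[1], vals[2], nil
    }

    func main() {
        // The same defaults hard-coded above, in the kernel's layout.
        min, def, max, err := parseTCPBufferSize("4096\t87380\t6291456\n")
        if err != nil {
            panic(err)
        }
        fmt.Println(min, def, max) // 4096 87380 6291456
    }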
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
new file mode 100644
index 000000000..721094bbf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "netfilter",
+ srcs = [
+ "extensions.go",
+ "netfilter.go",
+ "owner_matcher.go",
+ "targets.go",
+ "tcp_matcher.go",
+ "udp_matcher.go",
+ ],
+ # This target depends on netstack and should only be used by epsocket,
+ # which is allowed to depend on netstack.
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/log",
+ "//pkg/sentry/kernel",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
new file mode 100644
index 000000000..0336a32d8
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TODO(gvisor.dev/issue/170): The following per-matcher params should be
+// supported:
+// - Table name
+// - Match size
+// - User size
+// - Hooks
+// - Proto
+// - Family
+
+// matchMaker knows how to (un)marshal the matcher named name().
+type matchMaker interface {
+ // name is the matcher name as stored in the xt_entry_match struct.
+ name() string
+
+ // marshal converts from a stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
+
+ // unmarshal converts from the ABI matcher struct to a
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
+}
+
+// matchMakers maps the name of supported matchers to the matchMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var matchMakers = map[string]matchMaker{}
+
+// registerMatchMaker should be called by match extensions to register them
+// with the netfilter package.
+func registerMatchMaker(mm matchMaker) {
+ if _, ok := matchMakers[mm.name()]; ok {
+ panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name()))
+ }
+ matchMakers[mm.name()] = mm
+}
+
+func marshalMatcher(matcher stack.Matcher) []byte {
+ matchMaker, ok := matchMakers[matcher.Name()]
+ if !ok {
+ panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
+ }
+ return matchMaker.marshal(matcher)
+}
+
+// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and
+// data appended at the end.
+func marshalEntryMatch(name string, data []byte) []byte {
+ nflog("marshaling matcher %q", name)
+
+ // We have to pad this struct size to a multiple of 8 bytes.
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ matcher := linux.KernelXTEntryMatch{
+ XTEntryMatch: linux.XTEntryMatch{
+ MatchSize: uint16(size),
+ },
+ Data: data,
+ }
+ copy(matcher.Name[:], name)
+
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, matcher)
+ return append(buf, make([]byte, size-len(buf))...)
+}
+
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
+ matchMaker, ok := matchMakers[match.Name.String()]
+ if !ok {
+ return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
+ }
+ return matchMaker.unmarshal(buf, filter)
+}
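For context, match extensions plug into this registry by implementing matchMaker and calling registerMatchMaker from init, exactly as the owner, tcp, and udp matchers later in this change do. A hypothetical sketch of the smallest possible extension, written as a file inside package netfilter (the noop names are invented for illustration and are not part of this change):

// noop_matcher.go - hypothetical file inside package netfilter.
package netfilter

import "gvisor.dev/gvisor/pkg/tcpip/stack"

const matcherNameNoop = "noop"

func init() {
	// Make the matcher visible to marshalMatcher and unmarshalMatcher.
	registerMatchMaker(noopMarshaler{})
}

// noopMarshaler implements matchMaker for a matcher with no payload.
type noopMarshaler struct{}

func (noopMarshaler) name() string { return matcherNameNoop }

func (noopMarshaler) marshal(stack.Matcher) []byte {
	// marshalEntryMatch pads the resulting struct to 8-byte alignment.
	return marshalEntryMatch(matcherNameNoop, nil)
}

func (noopMarshaler) unmarshal([]byte, stack.IPHeaderFilter) (stack.Matcher, error) {
	return &noopMatcher{}, nil
}

// noopMatcher matches every packet.
type noopMatcher struct{}

func (*noopMatcher) Name() string { return matcherNameNoop }

func (*noopMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
	return true, false // Match, and never hotdrop.
}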
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
new file mode 100644
index 000000000..f7abe77d3
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -0,0 +1,761 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netfilter helps the sentry interact with netstack's netfilter
+// capabilities.
+package netfilter
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// errorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const errorTargetName = "ERROR"
+
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
+// enableLogging controls whether to log the (de)serialization of netfilter
+// structs between userspace and netstack. These logs are useful when
+// developing iptables, but can pollute sentry logs otherwise.
+const enableLogging = false
+
+// emptyFilter is used to compare against a rule's filter to determine whether
+// the filter is empty. It is immutable.
+var emptyFilter = stack.IPHeaderFilter{
+ Dst: "\x00\x00\x00\x00",
+ DstMask: "\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00",
+}
+
+// nflog logs messages related to the writing and reading of iptables.
+func nflog(format string, args ...interface{}) {
+ if enableLogging && log.IsLogging(log.Debug) {
+ log.Debugf("netfilter: "+format, args...)
+ }
+}
+
+// GetInfo returns information about iptables.
+func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+ // Read in the struct and table name.
+ var info linux.IPTGetinfo
+ if _, err := t.CopyIn(outPtr, &info); err != nil {
+ return linux.IPTGetinfo{}, syserr.FromError(err)
+ }
+
+ _, info, err := convertNetstackToBinary(stack, info.Name)
+ if err != nil {
+ nflog("couldn't convert iptables: %v", err)
+ return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
+ }
+
+ nflog("returning info: %+v", info)
+ return info, nil
+}
+
+// GetEntries returns netstack's iptables rules encoded for the iptables tool.
+func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+ // Read in the struct and table name.
+ var userEntries linux.IPTGetEntries
+ if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ nflog("couldn't copy in entries %q", userEntries.Name)
+ return linux.KernelIPTGetEntries{}, syserr.FromError(err)
+ }
+
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary(stack, userEntries.Name)
+ if err != nil {
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+
+ return entries, nil
+}
+
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
+ table, ok := stack.IPTables().GetTable(tablename.String())
+ if !ok {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ var entries linux.KernelIPTGetEntries
+ var info linux.IPTGetinfo
+ info.ValidHooks = table.ValidHooks()
+
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ }
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], tablename[:])
+
+ for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
+ // Is this a chain entry point?
+ for hook, hookRuleIdx := range table.BuiltinChains {
+ if hookRuleIdx == ruleIdx {
+ nflog("convert to binary: found hook %d at offset %d", hook, entries.Size)
+ info.HookEntry[hook] = entries.Size
+ }
+ }
+ // Is this a chain underflow point?
+ for underflow, underflowRuleIdx := range table.Underflows {
+ if underflowRuleIdx == ruleIdx {
+ nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size)
+ info.Underflow[underflow] = entries.Size
+ }
+ }
+
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ IP: linux.IPTIP{
+ Protocol: uint16(rule.Filter.Protocol),
+ },
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src)
+ copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+ entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+
+ nflog("convert to binary: adding entry: %+v", entry)
+
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ info.NumEntries++
+ }
+
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ info.Size = entries.Size
+ return entries, info, nil
+}
+
+func marshalTarget(target stack.Target) []byte {
+ switch tg := target.(type) {
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
+ return marshalErrorTarget(errorTargetName)
+ case stack.UserChainTarget:
+ return marshalErrorTarget(tg.Name)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget(tg)
+ case JumpTarget:
+ return marshalJumpTarget(tg)
+ default:
+ panic(fmt.Errorf("unknown target of type %T", target))
+ }
+}
+
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
+ nflog("convert to binary: marshalling standard target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ Verdict: translateFromStandardVerdict(verdict),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalErrorTarget(errorName string) []byte {
+ // This is an error target; its name is either "ERROR" or a user-defined chain name.
+ target := linux.XTErrorTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTErrorTarget,
+ },
+ }
+ copy(target.Name[:], errorName)
+ copy(target.Target.Name[:], errorTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalRedirectTarget(rt stack.RedirectTarget) []byte {
+ // This is a redirect target named REDIRECT.
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ target.NfRange.RangeSize = 1
+ if rt.RangeProtoSpecified {
+ target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ }
+ // Convert port from little endian to big endian.
+ port := make([]byte, 2)
+ binary.LittleEndian.PutUint16(port, rt.MinPort)
+ target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port)
+ binary.LittleEndian.PutUint16(port, rt.MaxPort)
+ target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateFromStandardVerdict translates verdicts the same way as the iptables
+// tool.
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
+ switch verdict {
+ case stack.RuleAccept:
+ return -linux.NF_ACCEPT - 1
+ case stack.RuleDrop:
+ return -linux.NF_DROP - 1
+ case stack.RuleReturn:
+ return linux.NF_RETURN
+ default:
+ // TODO(gvisor.dev/issue/170): Support Jump.
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ }
+}
+
+// translateToStandardTarget translates from the value in a
+// linux.XTStandardTarget to a stack.Target.
+func translateToStandardTarget(val int32) (stack.Target, error) {
+ // TODO(gvisor.dev/issue/170): Support other verdicts.
+ switch val {
+ case -linux.NF_ACCEPT - 1:
+ return stack.AcceptTarget{}, nil
+ case -linux.NF_DROP - 1:
+ return stack.DropTarget{}, nil
+ case -linux.NF_QUEUE - 1:
+ return nil, errors.New("unsupported iptables verdict QUEUE")
+ case linux.NF_RETURN:
+ return stack.ReturnTarget{}, nil
+ default:
+ return nil, fmt.Errorf("unknown iptables verdict %d", val)
+ }
+}
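The standard-verdict encoding is easiest to see with concrete values: Linux defines NF_DROP = 0 and NF_ACCEPT = 1, so the iptables encoding -(verdict) - 1 maps DROP to -1 and ACCEPT to -2, leaving non-negative values free to carry jump offsets. A tiny sketch of the arithmetic:

package main

import "fmt"

// Linux values (include/uapi/linux/netfilter.h).
const (
	nfDrop   = 0
	nfAccept = 1
)

func main() {
	// The iptables tool encodes standard verdicts as -(verdict) - 1,
	// so non-negative values are free to mean jump offsets.
	fmt.Println(-nfAccept - 1) // -2 encodes ACCEPT
	fmt.Println(-nfDrop - 1)   // -1 encodes DROP
	fmt.Println(24)            // >= 0 encodes a jump to byte offset 24
}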
+
+// SetEntries sets iptables rules for a single table. See
+// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
+ // Get the basic rules data (struct ipt_replace).
+ if len(optVal) < linux.SizeOfIPTReplace {
+ nflog("optVal has insufficient size for replace %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+
+ // TODO(gvisor.dev/issue/170): Support other tables.
+ var table stack.Table
+ switch replace.Name.String() {
+ case stack.TablenameFilter:
+ table = stack.EmptyFilterTable()
+ case stack.TablenameNat:
+ table = stack.EmptyNatTable()
+ default:
+ nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ return syserr.ErrInvalidArgument
+ }
+
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIPTEntry {
+ nflog("optVal has insufficient size for entry %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var entry linux.IPTEntry
+ buf := optVal[:linux.SizeOfIPTEntry]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIPTEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIPTEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support more IPTIP
+ // filtering fields.
+ filter, err := filterFromIPTIP(entry.IP)
+ if err != nil {
+ nflog("bad iptip: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+ optVal = optVal[matchersSize:]
+
+ // Get the target of the rule.
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(filter, optVal[:targetSize])
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, stack.Rule{
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
+ })
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ }
+
+ // Go through the list of supported hooks for this table and, for each
+ // one, set the rule it corresponds to.
+ for hook := range replace.HookEntry {
+ if table.ValidHooks()&(1<<hook) != 0 {
+ hk := hookFromLinux(hook)
+ for offset, ruleIdx := range offsets {
+ if offset == replace.HookEntry[hook] {
+ table.BuiltinChains[hk] = ruleIdx
+ }
+ if offset == replace.Underflow[hook] {
+ if !validUnderflow(table.Rules[ruleIdx]) {
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ return syserr.ErrInvalidArgument
+ }
+ table.Underflows[hk] = ruleIdx
+ }
+ }
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
+ nflog("hook %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
+ nflog("underflow %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+
+ // Add the user chains.
+ for ruleIdx, rule := range table.Rules {
+ target, ok := rule.Target.(stack.UserChainTarget)
+ if !ok {
+ continue
+ }
+
+ // We found a user chain. Before inserting it into the table,
+ // check that:
+ // - There's some other rule after it.
+ // - There are no matchers.
+ if ruleIdx == len(table.Rules)-1 {
+ nflog("user chain must have a rule or default policy")
+ return syserr.ErrInvalidArgument
+ }
+ if len(table.Rules[ruleIdx].Matchers) != 0 {
+ nflog("user chain's first node must have no matchers")
+ return syserr.ErrInvalidArgument
+ }
+ table.UserChains[target.Name] = ruleIdx + 1
+ }
+
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other chains.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chains
+ // right now, make sure all other chains point to ACCEPT rules.
+ for hook, ruleIdx := range table.BuiltinChains {
+ if hook == stack.Forward || hook == stack.Postrouting {
+ if !isUnconditionalAccept(table.Rules[ruleIdx]) {
+ nflog("hook %d is unsupported.", hook)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+
+ // TODO(gvisor.dev/issue/170): Check the following conditions:
+ // - There are no loops.
+ // - There are no chains without an unconditional final rule.
+ // - There are no chains without an unconditional underflow rule.
+
+ stk.IPTables().ReplaceTable(replace.Name.String(), table)
+
+ return nil
+}
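SetEntries resolves rules by byte offset twice: once to anchor builtin chains and underflows, and once to resolve jump targets. Stripped of everything else, the offsets bookkeeping reduces to the pattern below (a standalone sketch; the entry sizes are invented):

package main

import "fmt"

func main() {
	// NextOffset of each parsed entry, as read from the ipt_replace blob.
	nextOffsets := []uint32{112, 152, 112}

	// offsets maps each rule's byte offset to its index in table.Rules.
	offsets := map[uint32]int{}
	var offset uint32
	for i, n := range nextOffsets {
		offsets[offset] = i
		offset += n
	}

	// A jump target stored as byte offset 264 resolves to rule index 2.
	fmt.Println(offsets[264]) // 2
}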
+
+// parseMatchers parses 0 or more matchers from optVal. optVal should contain
+// only the matchers.
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
+ nflog("set entries: parsing matchers of size %d", len(optVal))
+ var matchers []stack.Matcher
+ for len(optVal) > 0 {
+ nflog("set entries: optVal has len %d", len(optVal))
+
+ // Get the XTEntryMatch.
+ if len(optVal) < linux.SizeOfXTEntryMatch {
+ return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
+ }
+ var match linux.XTEntryMatch
+ buf := optVal[:linux.SizeOfXTEntryMatch]
+ binary.Unmarshal(buf, usermem.ByteOrder, &match)
+ nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
+
+ // Check some invariants.
+ if match.MatchSize < linux.SizeOfXTEntryMatch {
+ return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
+ }
+ if len(optVal) < int(match.MatchSize) {
+ return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal))
+ }
+
+ // Parse the specific matcher.
+ matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ if err != nil {
+ return nil, fmt.Errorf("failed to create matcher: %v", err)
+ }
+ matchers = append(matchers, matcher)
+
+ // TODO(gvisor.dev/issue/170): Check the revision field.
+ optVal = optVal[match.MatchSize:]
+ }
+
+ if len(optVal) != 0 {
+ return nil, errors.New("optVal should be exhausted after parsing matchers")
+ }
+
+ return matchers, nil
+}
+
+// parseTarget parses a target from optVal. optVal should contain only the
+// target.
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
+ nflog("set entries: parsing target of size %d", len(optVal))
+ if len(optVal) < linux.SizeOfXTEntryTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
+ }
+ var target linux.XTEntryTarget
+ buf := optVal[:linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &target)
+ switch target.Name.String() {
+ case "":
+ // Standard target.
+ if len(optVal) != linux.SizeOfXTStandardTarget {
+ return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = optVal[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
+ }
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
+
+ case errorTargetName:
+ // Error target.
+ if len(optVal) != linux.SizeOfXTErrorTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = optVal[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named errorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch name := errorTarget.Name.String(); name {
+ case errorTargetName:
+ nflog("set entries: error target")
+ return stack.ErrorTarget{}, nil
+ default:
+ // User defined chain.
+ nflog("set entries: user-defined target %q", name)
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
+ }
+
+ // Unknown target.
+ return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
+}
+
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
+ if containsUnsupportedFields(iptip) {
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
+ }
+ if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
+ }
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
+ }, nil
+}
+
+func containsUnsupportedFields(iptip linux.IPTIP) bool {
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - Src and SrcMask
+ // - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
+ var emptyInterface = [linux.IFNAMSIZ]byte{}
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
+ iptip.InputInterfaceMask != emptyInterface ||
+ iptip.Flags != 0 ||
+ iptip.InverseFlags&^inverseMask != 0
+}
+
+func validUnderflow(rule stack.Rule) bool {
+ if len(rule.Matchers) != 0 {
+ return false
+ }
+ if rule.Filter != emptyFilter {
+ return false
+ }
+ switch rule.Target.(type) {
+ case stack.AcceptTarget, stack.DropTarget:
+ return true
+ default:
+ return false
+ }
+}
+
+func isUnconditionalAccept(rule stack.Rule) bool {
+ if !validUnderflow(rule) {
+ return false
+ }
+ _, ok := rule.Target.(stack.AcceptTarget)
+ return ok
+}
+
+func hookFromLinux(hook int) stack.Hook {
+ switch hook {
+ case linux.NF_INET_PRE_ROUTING:
+ return stack.Prerouting
+ case linux.NF_INET_LOCAL_IN:
+ return stack.Input
+ case linux.NF_INET_FORWARD:
+ return stack.Forward
+ case linux.NF_INET_LOCAL_OUT:
+ return stack.Output
+ case linux.NF_INET_POST_ROUTING:
+ return stack.Postrouting
+ }
+ panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
+}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..1b4e0ad79
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Set the UID and GID match flags.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ if matcher.invertUID {
+ iptOwnerInfo.Invert = linux.XT_OWNER_UID
+ }
+ }
+ if matcher.matchGID {
+ iptOwnerInfo.Match |= linux.XT_OWNER_GID
+ if matcher.invertGID {
+ iptOwnerInfo.Invert |= linux.XT_OWNER_GID
+ }
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+
+ // Check flags.
+ if matchData.Match&linux.XT_OWNER_UID != 0 {
+ owner.matchUID = true
+ if matchData.Invert&linux.XT_OWNER_UID != 0 {
+ owner.invertUID = true
+ }
+ }
+ if matchData.Match&linux.XT_OWNER_GID != 0 {
+ owner.matchGID = true
+ if matchData.Invert&linux.XT_OWNER_GID != 0 {
+ owner.invertGID = true
+ }
+ }
+
+ return &owner, nil
+}
+
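+// OwnerMatcher matches packets based on the UID and/or GID of the socket
+// that generated them. It implements Matcher.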
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invertUID bool
+ invertGID bool
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Only the OUTPUT chain is supported.
+ // TODO(gvisor.dev/issue/170): Support the POSTROUTING chain as well.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ if pkt.Owner == nil {
+ return false, true
+ }
+
+ var matches bool
+ // Check for UID match.
+ if om.matchUID {
+ if pkt.Owner.UID() == om.uid {
+ matches = true
+ }
+ if matches == om.invertUID {
+ return false, false
+ }
+ }
+
+ // Check for GID match.
+ if om.matchGID {
+ matches = false
+ if pkt.Owner.GID() == om.gid {
+ matches = true
+ }
+ if matches == om.invertGID {
+ return false, false
+ }
+ }
+
+ return true, false
+}
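The two booleans returned by Match are (matched, hotdrop): the first reports whether the rule applies to the packet, the second requests that the packet be dropped immediately. A small sketch of how a hypothetical caller would interpret the pair (the evaluate helper is illustrative, not the actual iptables check loop):

package main

import "fmt"

// evaluate mimics how a rule evaluator might consume the two booleans
// returned by Matcher.Match: (matched, hotdrop).
func evaluate(matched, hotdrop bool) string {
	switch {
	case hotdrop:
		return "drop the packet immediately"
	case !matched:
		return "rule does not apply; try the next rule"
	default:
		return "run this rule's target"
	}
}

func main() {
	fmt.Println(evaluate(false, true)) // e.g. packet owner unset on OUTPUT
	fmt.Println(evaluate(true, false)) // e.g. UID and GID checks passed
}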
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..b91ba3ab3
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// JumpTarget implements stack.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
new file mode 100644
index 000000000..4f98ee2d5
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -0,0 +1,130 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameTCP = "tcp"
+
+func init() {
+ registerMatchMaker(tcpMarshaler{})
+}
+
+// tcpMarshaler implements matchMaker for TCP matching.
+type tcpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (tcpMarshaler) name() string {
+ return matcherNameTCP
+}
+
+// marshal implements matchMaker.marshal.
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*TCPMatcher)
+ xttcp := linux.XTTCP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTTCP)
+ return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfXTTCP {
+ return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.XTTCP
+ binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTTCP: %+v", matchData)
+
+ if matchData.Option != 0 ||
+ matchData.FlagMask != 0 ||
+ matchData.FlagCompare != 0 ||
+ matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported TCP matcher flags set")
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber {
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ }
+
+ return &TCPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// TCPMatcher matches TCP packets and their headers. It implements Matcher.
+type TCPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*TCPMatcher) Name() string {
+ return matcherNameTCP
+}
+
+// Match implements Matcher.Match.
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
+
+ // We don't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if len(tcpHeader) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we drop the packet immediately.
+ return false, true
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
new file mode 100644
index 000000000..3f20fc891
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameUDP = "udp"
+
+func init() {
+ registerMatchMaker(udpMarshaler{})
+}
+
+// udpMarshaler implements matchMaker for UDP matching.
+type udpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (udpMarshaler) name() string {
+ return matcherNameUDP
+}
+
+// marshal implements matchMaker.marshal.
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*UDPMatcher)
+ xtudp := linux.XTUDP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTUDP)
+ return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfXTUDP {
+ return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may exceed what's
+ // strictly necessary to hold matchData.
+ var matchData linux.XTUDP
+ binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTUDP: %+v", matchData)
+
+ if matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported UDP matcher inverse flags set")
+ }
+
+ if filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ }
+
+ return &UDPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// UDPMatcher matches UDP packets and their headers. It implements Matcher.
+type UDPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*UDPMatcher) Name() string {
+ return matcherNameUDP
+}
+
+// Match implements Matcher.Match.
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
+ // into the stack.Check codepath as matchers are added.
+ if netHeader.TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
+
+ // We don't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ udpHeader := header.UDP(pkt.TransportHeader)
+ if len(udpHeader) < header.UDPMinimumSize {
+ // There's no valid UDP header here, so we drop the packet immediately.
+ return false, true
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
new file mode 100644
index 000000000..d5ca3ac56
--- /dev/null
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "netlink",
+ srcs = [
+ "message.go",
+ "provider.go",
+ "provider_vfs2.go",
+ "socket.go",
+ "socket_vfs2.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/netlink/port",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "netlink_test",
+ size = "small",
+ srcs = [
+ "message_test.go",
+ ],
+ deps = [
+ ":netlink",
+ "//pkg/abi/linux",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
new file mode 100644
index 000000000..0899c61d1
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message.go
@@ -0,0 +1,281 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// alignPad returns the length of padding required for alignment.
+//
+// Preconditions: align is a power of two.
+func alignPad(length int, align uint) int {
+ return binary.AlignUp(length, align) - length
+}
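As a quick check of the arithmetic: with NLMSG_ALIGNTO = 4, an 18-byte message needs 2 bytes of padding, since AlignUp(18, 4) = 20. A self-contained sketch using the same power-of-two rounding trick binary.AlignUp presumably uses:

package main

import "fmt"

// alignUp rounds length up to the next multiple of align, a power of two.
func alignUp(length int, align uint) int {
	return (length + int(align) - 1) &^ (int(align) - 1)
}

func main() {
	const nlmsgAlignTo = 4
	length := 18
	fmt.Println(alignUp(length, nlmsgAlignTo) - length) // 2 bytes of padding
}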
+
+// Message contains a complete serialized netlink message.
+type Message struct {
+ hdr linux.NetlinkMessageHeader
+ buf []byte
+}
+
+// NewMessage creates a new Message containing the passed header.
+//
+// The header length will be updated by Finalize.
+func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
+ return &Message{
+ hdr: hdr,
+ buf: binary.Marshal(nil, usermem.ByteOrder, hdr),
+ }
+}
+
+// ParseMessage parses the first message seen at buf, returning the rest of the
+// buffer. If the message is malformed, ok is false. For the last message, the
+// padding check is loose: if there isn't enough padding, the whole buf is
+// consumed and ok is set to true.
+func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
+ b := BytesView(buf)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+ var hdr linux.NetlinkMessageHeader
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ // Msg portion.
+ totalMsgLen := int(hdr.Length)
+ _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+
+ // Padding.
+ numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message to be unaligned; just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return
+ }
+
+ return &Message{
+ hdr: hdr,
+ buf: buf[:totalMsgLen],
+ }, []byte(b), true
+}
+
+// Header returns the header of this message.
+func (m *Message) Header() linux.NetlinkMessageHeader {
+ return m.hdr
+}
+
+// GetData unmarshals the payload message header from this netlink message, and
+// returns the attributes portion.
+func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+ b := BytesView(m.buf)
+
+ _, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return nil, false
+ }
+
+ size := int(binary.Size(msg))
+ msgBytes, ok := b.Extract(size)
+ if !ok {
+ return nil, false
+ }
+ binary.Unmarshal(msgBytes, usermem.ByteOrder, msg)
+
+ numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message to be unaligned; just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return nil, false
+ }
+
+ return AttrsView(b), true
+}
+
+// Finalize returns the []byte containing the entire message, with the total
+// length set in the message header. The Message must not be modified after
+// calling Finalize.
+func (m *Message) Finalize() []byte {
+ // Update length, which is the first 4 bytes of the header.
+ usermem.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
+
+ // Align the message. Note that the message length in the header (set
+ // above) is the useful length of the message, not the total aligned
+ // length. See net/netlink/af_netlink.c:__nlmsg_put.
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ m.putZeros(aligned - len(m.buf))
+ return m.buf
+}
+
+// putZeros adds n zeros to the message.
+func (m *Message) putZeros(n int) {
+ for n > 0 {
+ m.buf = append(m.buf, 0)
+ n--
+ }
+}
+
+// Put serializes v into the message.
+func (m *Message) Put(v interface{}) {
+ m.buf = binary.Marshal(m.buf, usermem.ByteOrder, v)
+}
+
+// PutAttr adds v to the message as a netlink attribute.
+//
+// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
+// binary.Size(v)) fits in math.MaxUint16 bytes.
+func (m *Message) PutAttr(atype uint16, v interface{}) {
+ l := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+ if l > math.MaxUint16 {
+ panic(fmt.Sprintf("attribute too large: %d", l))
+ }
+
+ m.Put(linux.NetlinkAttrHeader{
+ Type: atype,
+ Length: uint16(l),
+ })
+ m.Put(v)
+
+ // Align the attribute.
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ m.putZeros(aligned - l)
+}
+
+// PutAttrString adds s to the message as a netlink attribute.
+func (m *Message) PutAttrString(atype uint16, s string) {
+ l := linux.NetlinkAttrHeaderSize + len(s) + 1
+ m.Put(linux.NetlinkAttrHeader{
+ Type: atype,
+ Length: uint16(l),
+ })
+
+ // String + NUL-termination.
+ m.Put([]byte(s))
+ m.putZeros(1)
+
+ // Align the attribute.
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ m.putZeros(aligned - l)
+}
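Putting Message's pieces together, a handler might build a reply like the sketch below; the header values and the RTM_NEWLINK/IFLA_IFNAME constants are chosen purely for illustration, assuming the usual route-netlink definitions in pkg/abi/linux:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)

func main() {
	// Length is filled in by Finalize; the other fields are illustrative.
	m := netlink.NewMessage(linux.NetlinkMessageHeader{
		Type:   linux.RTM_NEWLINK,
		Flags:  linux.NLM_F_MULTI,
		Seq:    1,
		PortID: 42,
	})

	// Attributes are padded to NLA_ALIGNTO internally.
	m.PutAttrString(linux.IFLA_IFNAME, "lo")

	buf := m.Finalize()
	fmt.Printf("% x\n", buf)
}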
+
+// MessageSet contains a series of netlink messages.
+type MessageSet struct {
+ // Multi indicates that this is a multi-part message, to be terminated by
+ // NLMSG_DONE. NLMSG_DONE is sent even if the set contains only one
+ // Message.
+ //
+ // If Multi is set, all added messages will have NLM_F_MULTI set.
+ Multi bool
+
+ // PortID is the destination port for all messages.
+ PortID int32
+
+ // Seq is the sequence counter for all messages in the set.
+ Seq uint32
+
+ // Messages contains the messages in the set.
+ Messages []*Message
+}
+
+// NewMessageSet creates a new MessageSet.
+//
+// portID is the destination port to set as PortID in all messages.
+//
+// seq is the sequence counter to set as seq in all messages in the set.
+func NewMessageSet(portID int32, seq uint32) *MessageSet {
+ return &MessageSet{
+ PortID: portID,
+ Seq: seq,
+ }
+}
+
+// AddMessage adds a new message to the set and returns it for further
+// additions.
+//
+// The passed header will have Seq, PortID and the multi flag set
+// automatically.
+func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
+ hdr.Seq = ms.Seq
+ hdr.PortID = uint32(ms.PortID)
+ if ms.Multi {
+ hdr.Flags |= linux.NLM_F_MULTI
+ }
+
+ m := NewMessage(hdr)
+ ms.Messages = append(ms.Messages, m)
+ return m
+}
+
+// AttrsView is a view into the attributes portion of a netlink message.
+type AttrsView []byte
+
+// Empty reports whether there are no attributes left in v.
+func (v AttrsView) Empty() bool {
+ return len(v) == 0
+}
+
+// ParseFirst parses the first netlink attribute at the beginning of v.
+func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) {
+ b := BytesView(v)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+
+ _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO))
+ if !ok {
+ return
+ }
+
+ return hdr, value, AttrsView(b), ok
+}
+
+// BytesView supports extracting data from a byte slice with bounds checking.
+type BytesView []byte
+
+// Extract removes the first n bytes from v and returns them. If n is out of
+// bounds, it returns false.
+func (v *BytesView) Extract(n int) ([]byte, bool) {
+ if n < 0 || n > len(*v) {
+ return nil, false
+ }
+ extracted := (*v)[:n]
+ *v = (*v)[n:]
+ return extracted, true
+}
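On the receive side, the attributes of a parsed message are typically walked with ParseFirst until the view is empty. A standalone sketch with one hand-assembled attribute (the dumpAttrs helper is invented for illustration):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)

// dumpAttrs walks an AttrsView with ParseFirst until it is exhausted,
// the same pattern a request handler would follow after msg.GetData.
func dumpAttrs(attrs netlink.AttrsView) error {
	for !attrs.Empty() {
		hdr, value, rest, ok := attrs.ParseFirst()
		if !ok {
			return fmt.Errorf("malformed attribute")
		}
		fmt.Printf("type=%d len=%d value=% x\n", hdr.Type, hdr.Length, value)
		attrs = rest
	}
	return nil
}

func main() {
	// One attribute: Length=6, Type=1, two data bytes, two padding bytes.
	raw := []byte{0x06, 0x00, 0x01, 0x00, 0x30, 0x31, 0x00, 0x00}
	if err := dumpAttrs(netlink.AttrsView(raw)); err != nil {
		fmt.Println(err)
	}
}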
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
new file mode 100644
index 000000000..ef13d9386
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -0,0 +1,312 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package message_test
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+)
+
+type dummyNetlinkMsg struct {
+ Foo uint16
+}
+
+func TestParseMessage(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ header linux.NetlinkMessageHeader
+ dataMsg *dummyNetlinkMsg
+ restLen int
+ ok bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid with next message",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ 0xFF, // Next message (rest)
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 1,
+ ok: true,
+ },
+ {
+ desc: "valid for last message without padding",
+ input: []byte{
+ 0x12, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 18,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid for last message not to be aligned",
+ input: []byte{
+ 0x13, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ 0x00, // Excessive 1 byte permitted at end
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 19,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "header.Length too short",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header.Length too long",
+ input: []byte{
+ 0xFF, 0xFF, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header incomplete",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ },
+ ok: false,
+ },
+ {
+ desc: "empty message",
+ input: []byte{},
+ ok: false,
+ },
+ }
+ for _, test := range tests {
+ msg, rest, ok := netlink.ParseMessage(test.input)
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ continue
+ }
+ if !test.ok {
+ continue
+ }
+ if !reflect.DeepEqual(msg.Header(), test.header) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
+ }
+
+ dataMsg := &dummyNetlinkMsg{}
+ _, dataOk := msg.GetData(dataMsg)
+ if !dataOk {
+ t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
+ } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
+ }
+
+ if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
+
+func TestAttrView(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ // Outputs for ParseFirst.
+ hdr linux.NetlinkAttrHeader
+ value []byte
+ restLen int
+ ok bool
+
+ // Outputs for Empty.
+ isEmpty bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x06, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 6,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment with rest data",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ 0xFF, 0xFE, // Rest data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 2,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too long",
+ input: []byte{
+ 0xFF, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too short",
+ input: []byte{
+ 0x01, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "empty",
+ input: []byte{},
+ ok: false,
+ isEmpty: true,
+ },
+ }
+ for _, test := range tests {
+ attrs := netlink.AttrsView(test.input)
+
+ // Test ParseFirst().
+ hdr, value, rest, ok := attrs.ParseFirst()
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ } else if test.ok {
+ if !reflect.DeepEqual(hdr, test.hdr) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr)
+ }
+ if !bytes.Equal(value, test.value) {
+ t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value)
+ }
+ if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest)
+ }
+ }
+
+ // Test Empty().
+ if got, want := attrs.Empty(), test.isEmpty; got != want {
+ t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
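(Reader's aid, not part of the diff.) The test vectors above hand-encode each header and pad payloads to 4-byte boundaries. A minimal sketch of the alignment rule they assume, mirroring Linux's NLMSG_ALIGN macro:

	// nlmsgAlignTo is the boundary netlink messages are padded to.
	const nlmsgAlignTo = 4

	// nlmsgAlign rounds n up to the next multiple of nlmsgAlignTo.
	func nlmsgAlign(n int) int {
		return (n + nlmsgAlignTo - 1) &^ (nlmsgAlignTo - 1)
	}

	// nlmsgAlign(18) == 20: a 16-byte header plus 2 data bytes pads out to
	// the 20-byte buffers used by the "valid" cases above.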
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
new file mode 100644
index 000000000..3a22923d8
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "port",
+ srcs = ["port.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
+)
+
+go_test(
+ name = "port_test",
+ srcs = ["port_test.go"],
+ library = ":port",
+)
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
new file mode 100644
index 000000000..2cd3afc22
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -0,0 +1,117 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package port provides port ID allocation for netlink sockets.
+//
+// A netlink port is any int32 value. Positive ports are typically equivalent
+// to the PID of the binding process. If that port is unavailable, negative
+// ports are searched to find a free port that will not conflict with other
+// PIDs.
+package port
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// maxPorts is a sanity limit on the maximum number of ports to allocate per
+// protocol.
+const maxPorts = 10000
+
+// Manager allocates netlink port IDs.
+//
+// +stateify savable
+type Manager struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // ports contains a map of allocated ports for each protocol.
+ ports map[int]map[int32]struct{}
+}
+
+// New creates a new Manager.
+func New() *Manager {
+ return &Manager{
+ ports: make(map[int]map[int32]struct{}),
+ }
+}
+
+// Allocate reserves a new port ID for protocol. hint will be taken if
+// available.
+func (m *Manager) Allocate(protocol int, hint int32) (int32, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ proto = make(map[int32]struct{})
+ // Port 0 is reserved for the kernel.
+ proto[0] = struct{}{}
+ m.ports[protocol] = proto
+ }
+
+ if len(proto) >= maxPorts {
+ return 0, false
+ }
+
+ if _, ok := proto[hint]; !ok {
+ // Hint is available, reserve it.
+ proto[hint] = struct{}{}
+ return hint, true
+ }
+
+ // Search for any free port in [math.MinInt32, -4096). The positive
+ // port space is left open for pid-based allocations. This behavior is
+ // consistent with Linux.
+ start := int32(math.MinInt32 + rand.Int63n(math.MaxInt32-4096+1))
+ curr := start
+ for {
+ if _, ok := proto[curr]; !ok {
+ proto[curr] = struct{}{}
+ return curr, true
+ }
+
+ curr--
+ if curr >= -4096 {
+ curr = -4097
+ }
+ if curr == start {
+ // Nothing found. We should always find a free port
+ // because maxPorts < -4096 - MinInt32.
+ panic(fmt.Sprintf("No free port found in %+v", proto))
+ }
+ }
+}
+
+// Release frees the specified port for protocol.
+//
+// Preconditions: port is already allocated.
+func (m *Manager) Release(protocol int, port int32) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ panic(fmt.Sprintf("Released port %d for protocol %d which has no allocations", port, protocol))
+ }
+
+ if _, ok := proto[port]; !ok {
+ panic(fmt.Sprintf("Released port %d for protocol %d is not allocated", port, protocol))
+ }
+
+ delete(proto, port)
+}
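(Illustration, not part of the diff.) A hedged sketch of the Manager's lifecycle as a caller might drive it; the protocol value 0 and the hint 1234 are placeholders:

	m := port.New()

	// Prefer a PID-style hint; Allocate falls back to the negative port
	// space if the hint is already taken.
	p, ok := m.Allocate(0, 1234)
	if !ok {
		// All maxPorts ports for this protocol are in use.
		return
	}

	// ... bind the socket to port p ...

	// Release exactly once per allocation; releasing an unallocated port
	// panics.
	m.Release(0, p)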
diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go
new file mode 100644
index 000000000..516f6cd6c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port_test.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package port
+
+import (
+ "testing"
+)
+
+func TestAllocateHint(t *testing.T) {
+ m := New()
+
+ // We can get the hint port.
+ p, ok := m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(0, 1) got %d want 1", p)
+ }
+
+ // Hint is taken.
+ p, ok = m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p == 1 {
+ t.Errorf("m.Allocate(0, 1) got 1 want anything else")
+ }
+
+ // Hint is available for a different protocol.
+ p, ok = m.Allocate(1, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(1, 1) got %d want 1", p)
+ }
+
+ m.Release(0, 1)
+
+ // Hint is available again after release.
+ p, ok = m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(0, 1) got %d want 1", p)
+ }
+}
+
+func TestAllocateExhausted(t *testing.T) {
+ m := New()
+
+ // Fill all ports (0 is already reserved).
+ for i := int32(1); i < maxPorts; i++ {
+ p, ok := m.Allocate(0, i)
+ if !ok {
+ t.Fatalf("m.Allocate got !ok want ok")
+ }
+ if p != i {
+ t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i)
+ }
+ }
+
+ // Now no more can be allocated.
+ p, ok := m.Allocate(0, 1)
+ if ok {
+ t.Errorf("m.Allocate got %d, ok want !ok", p)
+ }
+}
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
new file mode 100644
index 000000000..0d45e5053
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -0,0 +1,116 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol is the implementation of a netlink socket protocol.
+type Protocol interface {
+ // Protocol returns the Linux netlink protocol value.
+ Protocol() int
+
+ // CanSend returns true if this protocol may ever send messages.
+ //
+ // TODO(gvisor.dev/issue/1119): This is a workaround to allow
+ // advertising support for otherwise unimplemented features on sockets
+ // that will never send messages, thus making those features no-ops.
+ CanSend() bool
+
+ // ProcessMessage processes a single message from userspace.
+ //
+ // If err == nil, any messages added to ms will be sent back to the
+ // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE
+ // message to be sent even if ms contains no messages.
+ ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error
+}
+
+// Provider is a function that creates a new Protocol for a specific netlink
+// protocol.
+//
+// Note that this is distinct from socket.Provider, which is used for all
+// socket families.
+type Provider func(t *kernel.Task) (Protocol, *syserr.Error)
+
+// protocols holds a map of all known address protocols and their provider.
+var protocols = make(map[int]Provider)
+
+// RegisterProvider registers the provider of a given address protocol so that
+// netlink sockets of that type can be created via socket(2).
+//
+// Preconditions: May only be called before any netlink sockets are created.
+func RegisterProvider(protocol int, provider Provider) {
+ if p, ok := protocols[protocol]; ok {
+ panic(fmt.Sprintf("Netlink protocol %d already provided by %+v", protocol, p))
+ }
+
+ protocols[protocol] = provider
+}
+
+// LINT.IfChange
+
+// socketProvider implements socket.Provider.
+type socketProvider struct {
+}
+
+// Socket implements socket.Provider.Socket.
+func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Netlink sockets must be specified as datagram or raw, but they
+ // behave the same regardless of type.
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
+ return nil, syserr.ErrSocketNotSupported
+ }
+
+ provider, ok := protocols[protocol]
+ if !ok {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ p, err := provider(t)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewSocket(t, stype, p)
+ if err != nil {
+ return nil, err
+ }
+
+ d := socket.NewDirent(t, netlinkSocketDevice)
+ defer d.DecRef()
+ return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ // Netlink sockets never support creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
+
+// LINT.ThenChange(./provider_vfs2.go)
+
+// init registers the socket provider.
+func init() {
+ socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{})
+ socket.RegisterProviderVFS2(linux.AF_NETLINK, &socketProviderVFS2{})
+}
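(Illustration, not part of the diff.) A hedged sketch of what RegisterProvider expects from a protocol implementation; the protocol number 99 and the echo behavior are hypothetical:

	type echoProtocol struct{}

	func (*echoProtocol) Protocol() int { return 99 } // Hypothetical number.
	func (*echoProtocol) CanSend() bool { return false }

	func (*echoProtocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
		// Reply with one empty message echoing the request type.
		ms.AddMessage(linux.NetlinkMessageHeader{Type: msg.Header().Type})
		return nil
	}

	func init() {
		netlink.RegisterProvider(99, func(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
			return &echoProtocol{}, nil
		})
	}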
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
new file mode 100644
index 000000000..bb205be0d
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// socketProviderVFS2 implements socket.Provider.
+type socketProviderVFS2 struct {
+}
+
+// Socket implements socket.Provider.Socket.
+func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Netlink sockets must be specified as datagram or raw, but they
+ // behave the same regardless of type.
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
+ return nil, syserr.ErrSocketNotSupported
+ }
+
+ provider, ok := protocols[protocol]
+ if !ok {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ p, err := provider(t)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewVFS2(t, stype, p)
+ if err != nil {
+ return nil, err
+ }
+
+ vfsfd := &s.vfsfd
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+ if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+func (*socketProviderVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Netlink sockets never support creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
new file mode 100644
index 000000000..93127398d
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "route",
+ srcs = [
+ "protocol.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
new file mode 100644
index 000000000..c84d8bd7c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -0,0 +1,498 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package route provides a NETLINK_ROUTE socket protocol.
+package route
+
+import (
+ "bytes"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// commandKind describes the operational class of a message type.
+//
+// The route message types use the lower 2 bits of the type to describe class
+// of command.
+type commandKind int
+
+const (
+ kindNew commandKind = 0x0
+ kindDel commandKind = 0x1
+ kindGet commandKind = 0x2
+ kindSet commandKind = 0x3
+)
+
+func typeKind(typ uint16) commandKind {
+ return commandKind(typ & 0x3)
+}
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_ROUTE netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_ROUTE
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return true
+}
+
+// dumpLinks handles RTM_GETLINK dump requests.
+func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
+ // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some
+ // userspace applications (including glibc) still include rtgenmsg.
+ // Linux has a workaround based on the total message length.
+ //
+ // We don't bother to check for either, since we don't support any
+ // extra attributes that may be included anyways.
+ //
+ // The message may also contain netlink attribute IFLA_EXT_MASK, which
+ // we don't support.
+
+ // The RTM_GETLINK dump response is a set of messages each containing
+ // an InterfaceInfoMessage followed by a set of netlink attributes.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ for idx, i := range stack.Interfaces() {
+ addNewLinkMessage(ms, idx, i)
+ }
+
+ return nil
+}
+
+// getLink handles RTM_GETLINK requests.
+func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ // Parse message.
+ var ifi linux.InterfaceInfoMessage
+ attrs, ok := msg.GetData(&ifi)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Parse attributes.
+ var byName []byte
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ switch ahdr.Type {
+ case linux.IFLA_IFNAME:
+ if len(value) < 1 {
+ return syserr.ErrInvalidArgument
+ }
+ byName = value[:len(value)-1] // Strip the NUL terminator.
+
+ // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK.
+ }
+ }
+
+ found := false
+ for idx, i := range stack.Interfaces() {
+ switch {
+ case ifi.Index > 0:
+ if idx != ifi.Index {
+ continue
+ }
+ case byName != nil:
+ if string(byName) != i.Name {
+ continue
+ }
+ default:
+ // Criteria not specified.
+ return syserr.ErrInvalidArgument
+ }
+
+ addNewLinkMessage(ms, idx, i)
+ found = true
+ break
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
+ return nil
+}
+
+// addNewLinkMessage appends an RTM_NEWLINK message for the given interface to
+// the message set.
+func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWLINK,
+ })
+
+ m.Put(linux.InterfaceInfoMessage{
+ Family: linux.AF_UNSPEC,
+ Type: i.DeviceType,
+ Index: idx,
+ Flags: i.Flags,
+ })
+
+ m.PutAttrString(linux.IFLA_IFNAME, i.Name)
+ m.PutAttr(linux.IFLA_MTU, i.MTU)
+
+ mac := make([]byte, 6)
+ brd := mac
+ if len(i.Addr) > 0 {
+ mac = i.Addr
+ brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
+ }
+ m.PutAttr(linux.IFLA_ADDRESS, mac)
+ m.PutAttr(linux.IFLA_BROADCAST, brd)
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+}
+
+// dumpAddrs handles RTM_GETADDR dump requests.
+func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // RTM_GETADDR dump requests need not contain anything more than the
+ // netlink header and 1 byte protocol family common to all
+ // NETLINK_ROUTE requests.
+ //
+ // TODO(b/68878065): Filter output by passed protocol family.
+
+ // The RTM_GETADDR dump response is a set of RTM_NEWADDR messages each
+ // containing an InterfaceAddrMessage followed by a set of netlink
+ // attributes.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ for id, as := range stack.InterfaceAddrs() {
+ for _, a := range as {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWADDR,
+ })
+
+ m.Put(linux.InterfaceAddrMessage{
+ Family: a.Family,
+ PrefixLen: a.PrefixLen,
+ Index: uint32(id),
+ })
+
+ m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
+ m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+ }
+ }
+
+ return nil
+}
+
+// commonPrefixLen returns the length, in bits, of the longest common prefix
+// of a and b. This is a simplified version of Go's src/net/addrselect.go.
+func commonPrefixLen(a, b []byte) (cpl int) {
+ for len(a) > 0 {
+ if a[0] == b[0] {
+ cpl += 8
+ a = a[1:]
+ b = b[1:]
+ continue
+ }
+ bits := 8
+ ab, bb := a[0], b[0]
+ for {
+ ab >>= 1
+ bb >>= 1
+ bits--
+ if ab == bb {
+ cpl += bits
+ return
+ }
+ }
+ }
+ return
+}
+
+// fillRoute returns the Route for addr chosen by longest-prefix matching
+// (LPM). Refer to Linux's net/ipv4/route.c:rt_fill_info().
+func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) {
+ family := uint8(linux.AF_INET)
+ if len(addr) != 4 {
+ family = linux.AF_INET6
+ }
+
+ idx := -1 // Index of the Route rule to be returned.
+ idxDef := -1 // Index of the default route rule.
+ prefix := 0 // Current longest prefix.
+ for i, route := range routes {
+ if route.Family != family {
+ continue
+ }
+
+ if len(route.GatewayAddr) > 0 && route.DstLen == 0 {
+ idxDef = i
+ continue
+ }
+
+ cpl := commonPrefixLen(addr, route.DstAddr)
+ if cpl < int(route.DstLen) {
+ continue
+ }
+ cpl = int(route.DstLen)
+ if cpl > prefix {
+ idx = i
+ prefix = cpl
+ }
+ }
+ if idx == -1 {
+ idx = idxDef
+ }
+ if idx == -1 {
+ return inet.Route{}, syserr.ErrNoRoute
+ }
+
+ route := routes[idx]
+ if family == linux.AF_INET {
+ route.DstLen = 32
+ } else {
+ route.DstLen = 128
+ }
+ route.DstAddr = addr
+ route.Flags |= linux.RTM_F_CLONED // This route is cloned.
+ return route, nil
+}
+
+// parseForDestination parses a message laid out as a RouteMessage followed by
+// an RTA_DST attribute.
+func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) {
+ var rtMsg linux.RouteMessage
+ attrs, ok := msg.GetData(&rtMsg)
+ if !ok {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See
+ // commit bc234301af12. Note that for backward compatibility we do not
+ // require this flag to be set.
+ if rtMsg.Flags != 0 && rtMsg.Flags != linux.RTM_F_LOOKUP_TABLE {
+ return nil, syserr.ErrNotSupported
+ }
+
+ // Expect the first attribute to be RTA_DST.
+ if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST {
+ return value, nil
+ }
+ return nil, syserr.ErrInvalidArgument
+}
+
+// dumpRoutes handles RTM_GETROUTE requests.
+func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // RTM_GETROUTE dump requests need not contain anything more than the
+ // netlink header and 1 byte protocol family common to all
+ // NETLINK_ROUTE requests.
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network routes.
+ return nil
+ }
+
+ hdr := msg.Header()
+ routeTables := stack.RouteTable()
+
+ if hdr.Flags == linux.NLM_F_REQUEST {
+ dst, err := parseForDestination(msg)
+ if err != nil {
+ return err
+ }
+ route, err := fillRoute(routeTables, dst)
+ if err != nil {
+ // TODO(gvisor.dev/issue/1237): return NLMSG_ERROR with ENETUNREACH.
+ return syserr.ErrNotSupported
+ }
+ routeTables = []inet.Route{route}
+ } else if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+ } else {
+ // TODO(b/68878065): Only the above cases are supported.
+ return syserr.ErrNotSupported
+ }
+
+ for _, rt := range routeTables {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWROUTE,
+ })
+
+ m.Put(linux.RouteMessage{
+ Family: rt.Family,
+ DstLen: rt.DstLen,
+ SrcLen: rt.SrcLen,
+ TOS: rt.TOS,
+
+ // Always return the main table since we don't have multiple
+ // routing tables.
+ Table: linux.RT_TABLE_MAIN,
+ Protocol: rt.Protocol,
+ Scope: rt.Scope,
+ Type: rt.Type,
+
+ Flags: rt.Flags,
+ })
+
+ m.PutAttr(254, []byte{123})
+ if rt.DstLen > 0 {
+ m.PutAttr(linux.RTA_DST, rt.DstAddr)
+ }
+ if rt.SrcLen > 0 {
+ m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+ }
+ if rt.OutputInterface != 0 {
+ m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+ }
+ if len(rt.GatewayAddr) > 0 {
+ m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+ }
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+ }
+
+ return nil
+}
+
+// newAddr handles RTM_NEWADDR requests.
+func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err == syscall.EEXIST {
+ flags := msg.Header().Flags
+ if flags&linux.NLM_F_EXCL != 0 {
+ return syserr.ErrExists
+ }
+ } else if err != nil {
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+ return nil
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ hdr := msg.Header()
+
+ // All messages start with a 1 byte protocol family.
+ var family uint8
+ if _, ok := msg.GetData(&family); !ok {
+ // Linux ignores messages missing the protocol family. See
+ // net/core/rtnetlink.c:rtnetlink_rcv_msg.
+ return nil
+ }
+
+ // Non-GET message types require CAP_NET_ADMIN.
+ if typeKind(hdr.Type) != kindGet {
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_ADMIN) {
+ return syserr.ErrPermissionDenied
+ }
+ }
+
+ if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // TODO(b/68878065): Only the dump variants of the types below are
+ // supported.
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.dumpLinks(ctx, msg, ms)
+ case linux.RTM_GETADDR:
+ return p.dumpAddrs(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+ } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST {
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.getLink(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ case linux.RTM_NEWADDR:
+ return p.newAddr(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+ }
+ return syserr.ErrNotSupported
+}
+
+// init registers the NETLINK_ROUTE provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_ROUTE, NewProtocol)
+}
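(Illustration, not part of the diff.) Making the longest-prefix match in fillRoute concrete: given a /16 and a /24 that both cover the destination, the /24 wins and is rewritten as a cloned host route:

	routes := []inet.Route{
		{Family: linux.AF_INET, DstLen: 16, DstAddr: []byte{10, 0, 0, 0}},
		{Family: linux.AF_INET, DstLen: 24, DstAddr: []byte{10, 0, 1, 0}},
	}
	// Both entries match 10.0.1.5; the /24 has the longer prefix.
	r, err := fillRoute(routes, []byte{10, 0, 1, 5})
	// err == nil; r is the /24 entry with DstLen rewritten to 32, DstAddr
	// set to 10.0.1.5, and RTM_F_CLONED set in Flags.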
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
new file mode 100644
index 000000000..81f34c5a2
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -0,0 +1,780 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netlink provides core functionality for netlink sockets.
+package netlink
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const sizeOfInt32 int = 4
+
+const (
+ // minSendBufferSize is the smallest size of a send buffer.
+ minSendBufferSize = 4 << 10 // 4096 bytes.
+
+ // defaultSendBufferSize is the default size for the send buffer.
+ defaultSendBufferSize = 16 * 1024
+
+ // maxSendBufferSize is the largest size a send buffer can grow to.
+ maxSendBufferSize = 4 << 20 // 4 MB.
+)
+
+var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
+
+// netlinkSocketDevice is the netlink socket virtual device.
+var netlinkSocketDevice = device.NewAnonDevice()
+
+// LINT.IfChange
+
+// Socket is the base socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// Socket implements socket.Socket and transport.Credentialer.
+//
+// +stateify savable
+type Socket struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+
+ // ports provides netlink port allocation.
+ ports *port.Manager
+
+ // protocol is the netlink protocol implementation.
+ protocol Protocol
+
+ // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
+ // netlink sockets.
+ skType linux.SockType
+
+ // ep is a datagram unix endpoint used to buffer messages sent from the
+ // kernel to userspace. RecvMsg reads messages from this endpoint.
+ ep transport.Endpoint
+
+ // connection is the kernel's connection to ep, used to write messages
+ // sent to userspace.
+ connection transport.ConnectedEndpoint
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // bound indicates that portid is valid.
+ bound bool
+
+ // portID is the port ID allocated for this socket.
+ portID int32
+
+ // sendBufferSize is the send buffer "size". We don't actually have a
+ // fixed buffer but only consume this many bytes.
+ sendBufferSize uint32
+
+ // passcred indicates if this socket wants SCM credentials.
+ passcred bool
+
+ // filter indicates that this socket has a BPF filter "installed".
+ //
+ // TODO(gvisor.dev/issue/1119): We don't actually support filtering,
+ // this is just bookkeeping for tracking add/remove.
+ filter bool
+}
+
+var _ socket.Socket = (*Socket)(nil)
+var _ transport.Credentialer = (*Socket)(nil)
+
+// NewSocket creates a new Socket.
+func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
+ // Datagram endpoint used to buffer kernel -> user messages.
+ ep := transport.NewConnectionless(t)
+
+ // Bind the endpoint for good measure so we can connect to it. The
+ // bound address will never be exposed.
+ if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ // Create a connection from which the kernel can write messages.
+ connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ return &Socket{
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
+ }, nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ s.connection.Release()
+ s.ep.Close()
+
+ if s.bound {
+ s.ports.Release(s.protocol.Protocol(), s.portID)
+ }
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // ep holds messages to be read and thus handles EventIn readiness.
+ ready := s.ep.Readiness(mask)
+
+ if mask&waiter.EventOut == waiter.EventOut {
+ // sendMsg handles messages synchronously and is thus always
+ // ready for writing.
+ ready |= waiter.EventOut
+ }
+
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+ // Writable readiness never changes, so no registration is needed.
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// Passcred implements transport.Credentialer.Passcred.
+func (s *socketOpsCommon) Passcred() bool {
+ s.mu.Lock()
+ passcred := s.passcred
+ s.mu.Unlock()
+ return passcred
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *socketOpsCommon) ConnectedPasscred() bool {
+ // This socket is connected to the kernel, which doesn't need creds.
+ //
+ // This is arbitrary, as ConnectedPasscred on this type has no callers.
+ return false
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
+ // TODO(b/68878065): no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// ExtractSockAddr extracts the SockAddrNetlink from b.
+func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
+ if len(b) < linux.SockAddrNetlinkSize {
+ return nil, syserr.ErrBadAddress
+ }
+
+ var sa linux.SockAddrNetlink
+ binary.Unmarshal(b[:linux.SockAddrNetlinkSize], usermem.ByteOrder, &sa)
+
+ if sa.Family != linux.AF_NETLINK {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return &sa, nil
+}
+
+// bindPort binds this socket to a port, preferring 'port' if it is available.
+//
+// A port of 0 defaults to the ThreadGroup ID.
+//
+// Preconditions: mu is held.
+func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error {
+ if s.bound {
+ // Re-binding is only allowed if the port doesn't change.
+ if port != s.portID {
+ return syserr.ErrInvalidArgument
+ }
+
+ return nil
+ }
+
+ if port == 0 {
+ port = int32(t.ThreadGroup().ID())
+ }
+ port, ok := s.ports.Allocate(s.protocol.Protocol(), port)
+ if !ok {
+ return syserr.ErrBusy
+ }
+
+ s.portID = port
+ s.bound = true
+ return nil
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ a, err := ExtractSockAddr(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return s.bindPort(t, int32(a.PortID))
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ a, err := ExtractSockAddr(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if a.PortID == 0 {
+ // Netlink sockets default to being connected to the kernel, but
+ // connecting anyway automatically binds the socket if it is not
+ // already bound.
+ if !s.bound {
+ // Pass port 0 to get an auto-selected port ID.
+ return s.bindPort(t, 0)
+ }
+ return nil
+ }
+
+ // We don't support non-kernel destination ports. Linux returns EPERM
+ // if applications attempt to do this without NL_CFG_F_NONROOT_SEND, so
+ // we emulate that.
+ return syserr.ErrPermissionDenied
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Netlink sockets never support accept.
+ return 0, nil, 0, syserr.ErrNotSupported
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ // Netlink sockets never support listen.
+ return syserr.ErrNotSupported
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ // Netlink sockets never support shutdown.
+ return syserr.ErrNotSupported
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return int32(s.sendBufferSize), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // We don't have a limit on the receive size.
+ return int32(math.MaxInt32), nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ var passcred int32
+ if s.Passcred() {
+ passcred = 1
+ }
+ return passcred, nil
+
+ default:
+ socket.GetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ case linux.SOL_NETLINK:
+ switch name {
+ case linux.NETLINK_BROADCAST_ERROR,
+ linux.NETLINK_CAP_ACK,
+ linux.NETLINK_DUMP_STRICT_CHK,
+ linux.NETLINK_EXT_ACK,
+ linux.NETLINK_LIST_MEMBERSHIPS,
+ linux.NETLINK_NO_ENOBUFS,
+ linux.NETLINK_PKTINFO:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+ }
+ // TODO(b/68878065): other sockopts are not supported.
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ switch level {
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ size := usermem.ByteOrder.Uint32(opt)
+ if size < minSendBufferSize {
+ size = minSendBufferSize
+ } else if size > maxSendBufferSize {
+ size = maxSendBufferSize
+ }
+ s.mu.Lock()
+ s.sendBufferSize = size
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_RCVBUF:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ // We don't have a limit on the receive size, so just accept
+ // anything as valid for compatibility.
+ return nil
+
+ case linux.SO_PASSCRED:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ passcred := usermem.ByteOrder.Uint32(opt)
+
+ s.mu.Lock()
+ s.passcred = passcred != 0
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_ATTACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): We don't actually
+ // support filtering. If this socket can't ever send
+ // messages, then there is nothing to filter and we can
+ // advertise support. Otherwise, be conservative and
+ // return an error.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ s.filter = true
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_DETACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): See above.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ filter := s.filter
+ s.filter = false
+ s.mu.Unlock()
+
+ if !filter {
+ return errNoFilter
+ }
+
+ return nil
+
+ default:
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ case linux.SOL_NETLINK:
+ switch name {
+ case linux.NETLINK_ADD_MEMBERSHIP,
+ linux.NETLINK_BROADCAST_ERROR,
+ linux.NETLINK_CAP_ACK,
+ linux.NETLINK_DROP_MEMBERSHIP,
+ linux.NETLINK_DUMP_STRICT_CHK,
+ linux.NETLINK_EXT_ACK,
+ linux.NETLINK_LISTEN_ALL_NSID,
+ linux.NETLINK_NO_ENOBUFS,
+ linux.NETLINK_PKTINFO:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ }
+ // TODO(b/68878065): other sockopts are not supported.
+ return syserr.ErrProtocolNotAvailable
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ sa := &linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: uint32(s.portID),
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ sa := &linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ // TODO(b/68878065): Support non-kernel peers. For now the peer
+ // must be the kernel.
+ PortID: 0,
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ from := &linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: 0,
+ }
+ fromLen := uint32(binary.Size(from))
+
+ trunc := flags&linux.MSG_TRUNC != 0
+
+ r := unix.EndpointReader{
+ Ctx: t,
+ Endpoint: s.ep,
+ Peek: flags&linux.MSG_PEEK != 0,
+ }
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+ }
+
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // receive all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &unix.EndpointReader{
+ Endpoint: s.ep,
+ })
+}
+
+// kernelSCM implements control.SCMCredentials with credentials that represent
+// the kernel itself rather than a Task.
+//
+// +stateify savable
+type kernelSCM struct{}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
+ _, ok := oc.(kernelSCM)
+ return ok
+}
+
+// Credentials implements control.SCMCredentials.Credentials.
+func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ return 0, auth.RootUID, auth.RootGID
+}
+
+// kernelCreds is the singleton kernelSCM used for all kernel-originated credentials.
+var kernelCreds = &kernelSCM{}
+
+// sendResponse sends the response messages in ms back to userspace.
+func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
+ // Linux combines multiple netlink messages into a single datagram.
+ bufs := make([][]byte, 0, len(ms.Messages))
+ for _, m := range ms.Messages {
+ bufs = append(bufs, m.Finalize())
+ }
+
+ // All messages are from the kernel.
+ cms := transport.ControlMessages{
+ Credentials: kernelCreds,
+ }
+
+ if len(bufs) > 0 {
+ // RecvMsg never receives the address, so we don't need to send
+ // one.
+ _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{})
+ // If the buffer is full, we simply drop messages, just like
+ // Linux.
+ if err != nil && err != syserr.ErrWouldBlock {
+ return err
+ }
+ if notify {
+ s.connection.SendNotify()
+ }
+ }
+
+ // N.B. multi-part messages should still send NLMSG_DONE even if
+ // MessageSet contains no messages.
+ //
+ // N.B. NLMSG_DONE is always sent in a different datagram. See
+ // net/netlink/af_netlink.c:netlink_dump.
+ if ms.Multi {
+ m := NewMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_DONE,
+ Flags: linux.NLM_F_MULTI,
+ Seq: ms.Seq,
+ PortID: uint32(ms.PortID),
+ })
+
+ // Add the dump_done_errno payload.
+ m.Put(int64(0))
+
+ _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
+ if err != nil && err != syserr.ErrWouldBlock {
+ return err
+ }
+ if notify {
+ s.connection.SendNotify()
+ }
+ }
+
+ return nil
+}
+
+func dumpErrorMessage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+ m.Put(linux.NetlinkErrorMessage{
+ Error: int32(-err.ToLinux().Number()),
+ Header: hdr,
+ })
+}
+
+func dumpAckMessage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+ m.Put(linux.NetlinkErrorMessage{
+ Error: 0,
+ Header: hdr,
+ })
+}
+
+// processMessages parses each message in buf and passes it to the protocol
+// handler for processing.
+func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error {
+ for len(buf) > 0 {
+ msg, rest, ok := ParseMessage(buf)
+ if !ok {
+ // Linux ignores messages that are too short. See
+ // net/netlink/af_netlink.c:netlink_rcv_skb.
+ break
+ }
+ buf = rest
+ hdr := msg.Header()
+
+ // Ignore control messages.
+ if hdr.Type < linux.NLMSG_MIN_TYPE {
+ continue
+ }
+
+ ms := NewMessageSet(s.portID, hdr.Seq)
+ if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
+ dumpErrorMessage(hdr, ms, err)
+ } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
+ dumpAckMessage(hdr, ms)
+ }
+
+ if err := s.sendResponse(ctx, ms); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// sendMsg is the core of message send, used for SendMsg and Write.
+func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ dstPort := int32(0)
+
+ if len(to) != 0 {
+ a, err := ExtractSockAddr(to)
+ if err != nil {
+ return 0, err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ dstPort = int32(a.PortID)
+ }
+
+ if dstPort != 0 {
+ // Non-kernel destinations not supported yet. Treat as if
+ // NL_CFG_F_NONROOT_SEND is not set.
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // For simplicity, and consistency with Linux, we copy in the entire
+ // message up front.
+ if src.NumBytes() > int64(s.sendBufferSize) {
+ return 0, syserr.ErrMessageTooLong
+ }
+
+ buf := make([]byte, src.NumBytes())
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ // Don't partially consume messages.
+ return 0, syserr.FromError(err)
+ }
+
+ if err := s.processMessages(ctx, buf); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ return s.sendMsg(t, src, to, flags, controlMessages)
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
+ return int64(n), err.ToError()
+}
+
+// State implements socket.Socket.State.
+func (s *socketOpsCommon) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
+}
+
+// LINT.ThenChange(./socket_vfs2.go)
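(Illustration, not part of the diff.) The smallest datagram that sendMsg/processMessages acts on is a bare dump request: a 16-byte header plus the one-byte protocol family. A hedged sketch of an RTM_GETADDR dump request, little-endian:

	req := []byte{
		0x11, 0x00, 0x00, 0x00, // Length = 17 (header + 1-byte family)
		0x16, 0x00, // Type = RTM_GETADDR (22)
		0x01, 0x03, // Flags = NLM_F_REQUEST | NLM_F_DUMP
		0x01, 0x00, 0x00, 0x00, // Seq = 1
		0x00, 0x00, 0x00, 0x00, // PortID = 0 (kernel)
		0x00, // Family = AF_UNSPEC
		0x00, 0x00, 0x00, // Pad to NLMSG_ALIGN(17) = 20
	}
	// Writing req to a NETLINK_ROUTE socket produces a multi-part
	// RTM_NEWADDR dump terminated by NLMSG_DONE, via dumpAddrs and
	// sendResponse above.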
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
new file mode 100644
index 000000000..dbcd8b49a
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -0,0 +1,152 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 is the base VFS2 socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ socket.SocketVFS2 = (*SocketVFS2)(nil)
+var _ transport.Credentialer = (*SocketVFS2)(nil)
+
+// NewVFS2 creates a new SocketVFS2.
+func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketVFS2, *syserr.Error) {
+ // Datagram endpoint used to buffer kernel -> user messages.
+ ep := transport.NewConnectionless(t)
+
+ // Bind the endpoint for good measure so we can connect to it. The
+ // bound address will never be exposed.
+ if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ // Create a connection from which the kernel can write messages.
+ connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ fd := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
+ }
+ fd.LockFD.Init(&vfs.FileLocks{})
+ return fd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (*SocketVFS2) Ioctl(context.Context, usermem.IO, arch.SyscallArguments) (uintptr, error) {
+ // TODO(b/68878065): no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &unix.EndpointReader{
+ Endpoint: s.ep,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
+ return int64(n), err.ToError()
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD
new file mode 100644
index 000000000..b6434923c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "uevent",
+ srcs = ["protocol.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go
new file mode 100644
index 000000000..029ba21b5
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/protocol.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol.
+//
+// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does
+// not support any device events, so these sockets never send any messages.
+package uevent
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_KOBJECT_UEVENT
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return false
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // Silently ignore all messages.
+ return nil
+}
+
+// init registers the NETLINK_KOBJECT_UEVENT provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol)
+}
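
For context, a uevent listener such as udev would open this kind of socket as sketched below; since gVisor generates no device events, a reader on such a socket simply blocks forever. Illustrative sketch only, not part of this patch (uses golang.org/x/sys/unix):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	// Open a uevent netlink socket, as udev does.
	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_KOBJECT_UEVENT)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	// Subscribe to the kernel uevent multicast group.
	sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK, Groups: 1}
	if err := unix.Bind(fd, sa); err != nil {
		panic(err)
	}
	fmt.Println("listening for uevents (under gVisor, none will ever arrive)")
}
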
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
new file mode 100644
index 000000000..ea6ebd0e2
--- /dev/null
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -0,0 +1,56 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "netstack",
+ srcs = [
+ "device.go",
+ "netstack.go",
+ "netstack_vfs2.go",
+ "provider.go",
+ "provider_vfs2.go",
+ "save_restore.go",
+ "stack.go",
+ ],
+ visibility = [
+ "//pkg/sentry:internal",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/netfilter",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/netstack/device.go b/pkg/sentry/socket/netstack/device.go
new file mode 100644
index 000000000..fbeb89fb8
--- /dev/null
+++ b/pkg/sentry/socket/netstack/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// netstackDevice is the endpoint socket virtual device.
+var netstackDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
new file mode 100644
index 000000000..3b248a953
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -0,0 +1,3143 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netstack provides an implementation of the socket.Socket interface
+// that is backed by a tcpip.Endpoint.
+//
+// It does not depend on any particular endpoint implementation, and thus can
+// be used to expose certain endpoints to the sentry while leaving others out,
+// for example, TCP endpoints and Unix-domain endpoints.
+//
+// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
+// this operation.
+package netstack
+
+import (
+ "bytes"
+ "io"
+ "math"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func mustCreateMetric(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+func mustCreateGauge(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+// Metrics contains metrics exported by netstack.
+var Metrics = tcpip.Stats{
+ UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."),
+ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
+ DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
+ ICMP: tcpip.ICMPStats{
+ V4PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
+ },
+ V6PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ },
+ },
+ IP: tcpip.IPStats{
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
+ InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
+ MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
+ },
+ TCP: tcpip.TCPStats{
+ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
+ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
+ CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."),
+ CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."),
+ EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state."),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."),
+ ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
+ ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
+ ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
+ ListenOverflowSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_rcvd", "Number of times a SYN cookie was received."),
+ ListenOverflowInvalidSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_invalid_syn_cookie_rcvd", "Number of times an invalid SYN cookie was received."),
+ FailedConnectionAttempts: mustCreateMetric("/netstack/tcp/failed_connection_attempts", "Number of calls to Connect or Listen (active and passive openings, respectively) that end in an error."),
+ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."),
+ InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."),
+ SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."),
+ SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."),
+ ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."),
+ ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."),
+ Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."),
+ FastRecovery: mustCreateMetric("/netstack/tcp/fast_recovery", "Number of times fast recovery was used to recover from packet loss."),
+ SACKRecovery: mustCreateMetric("/netstack/tcp/sack_recovery", "Number of times SACK recovery was used to recover from packet loss."),
+ SlowStartRetransmits: mustCreateMetric("/netstack/tcp/slow_start_retransmits", "Number of segments retransmitted in slow start mode."),
+ FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
+ Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
+ ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
+ },
+ UDP: tcpip.UDPStats{
+ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
+ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."),
+ ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
+ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
+ PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
+ ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
+ },
+}
+
+// DefaultTTL is linux's default TTL. All network protocols in all stacks used
+// with this package must have this value set as their default TTL.
+const DefaultTTL = 64
+
+const sizeOfInt32 int = 4
+
+var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
+
+// ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func htons(v uint16) uint16 {
+ return ntohs(v)
+}
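
As a quick worked example of the byte swap above (hypothetical snippet, not part of the patch):

package main

import "fmt"

// ntohs mirrors the helper above: swapping the two bytes of a uint16
// converts between network and host byte order on a little-endian host.
func ntohs(v uint16) uint16 { return v<<8 | v>>8 }

func main() {
	// TCP port 80 is 0x0050 in network byte order; read naively on a
	// little-endian host it appears as 0x5000, and one swap recovers it.
	fmt.Printf("%#x\n", ntohs(0x5000)) // 0x50, i.e. port 80
}
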
+
+// commonEndpoint represents the intersection of a tcpip.Endpoint and a
+// transport.Endpoint.
+type commonEndpoint interface {
+ // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and
+ // transport.Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and
+ // transport.Endpoint.GetRemoteAddress.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Readiness implements tcpip.Endpoint.Readiness and
+ // transport.Endpoint.Readiness.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt implements tcpip.Endpoint.SetSockOpt and
+ // transport.Endpoint.SetSockOpt.
+ SetSockOpt(interface{}) *tcpip.Error
+
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
+
+ // GetSockOpt implements tcpip.Endpoint.GetSockOpt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
+ // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+}
+
+// LINT.IfChange
+
+// SocketOperations encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+ *waiter.Queue
+
+ family int
+ Endpoint tcpip.Endpoint
+ skType linux.SockType
+ protocol int
+
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
+ // readMu protects access to the below fields.
+ readMu sync.Mutex `state:"nosave"`
+ // readView contains the remaining payload from the last packet.
+ readView buffer.View
+ // readCM holds control message information for the last packet read
+ // from Endpoint.
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+
+ // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
+ // of returned messages can be returned via control messages. When
+ // false, the same timestamp is instead stored and can be read via the
+ // SIOCGSTAMP ioctl. It is protected by readMu. See socket(7).
+ sockOptTimestamp bool
+ // timestampValid indicates whether timestamp for SIOCGSTAMP has been
+ // set. It is protected by readMu.
+ timestampValid bool
+ // timestampNS holds the timestamp to use with SIOCGSTAMP. It is only
+ // valid when timestampValid is true. It is protected by readMu.
+ timestampNS int64
+
+ // sockOptInq corresponds to TCP_INQ. It is implemented at this level
+ // because it takes into account data from readView.
+ sockOptInq bool
+}
+
+// New creates a new endpoint socket.
+func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ }
+
+ dirent := socket.NewDirent(t, netstackDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{
+ socketOpsCommon: socketOpsCommon{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ protocol: protocol,
+ },
+ }), nil
+}
+
+var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
+var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
+
+// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation, taking the unspecified (all-zero) address into account.
+func bytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: bytesToIPAddress(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: bytesToIPAddress(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
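
To make the expected wire layout concrete, here is a sketch that builds the 16-byte linux sockaddr_in the AF_INET branch parses: the family field is read in host (little-endian, as this package assumes) byte order, while the port is stored big-endian and converted via ntohs. Illustrative only, using only the standard library:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// sockaddr_in for 127.0.0.1:8080, laid out as the parser expects.
	addr := make([]byte, 16)
	binary.LittleEndian.PutUint16(addr[0:2], 2) // sin_family = AF_INET
	binary.BigEndian.PutUint16(addr[2:4], 8080) // sin_port, network order
	copy(addr[4:8], []byte{127, 0, 0, 1})       // sin_addr
	// addr[8:16] is zero padding (sin_zero).

	family := binary.LittleEndian.Uint16(addr[0:2])
	port := binary.BigEndian.Uint16(addr[2:4])
	fmt.Printf("family=%d addr=%v port=%d\n", family, addr[4:8], port)
}
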
+
+func (s *socketOpsCommon) isPacketBased() bool {
+ return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
+}
+
+// fetchReadView updates the readView field of the socket if it's currently
+// empty. It assumes that the socket is locked.
+//
+// Precondition: s.readMu must be held.
+func (s *socketOpsCommon) fetchReadView() *syserr.Error {
+ if len(s.readView) > 0 {
+ return nil
+ }
+ s.readView = nil
+ s.sender = tcpip.FullAddress{}
+
+ v, cms, err := s.Endpoint.Read(&s.sender)
+ if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ return syserr.TranslateNetstackError(err)
+ }
+
+ s.readView = v
+ s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
+
+ return nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ s.Endpoint.Close()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup: a dup (tee-style) read must
+ // not consume the data, so we stop after the current view; the
+ // expectation is that the caller will then call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payloader.
+//
+// It copies user memory bytes on demand based on the requested size.
+type ioSequencePayload struct {
+ ctx context.Context
+ src usermem.IOSequence
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
+ }
+ v := buffer.NewView(size)
+ if _, err := i.src.CopyIn(i.ctx, v); err != nil {
+ return nil, tcpip.ErrBadAddress
+ }
+ return v, nil
+}
+
+// DropFirst drops the first n bytes from the underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+ i.src = i.src.DropFirst(int(n))
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ }
+
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
+}
+
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propagation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
+// Readiness returns a mask of ready events for socket s.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ r := s.Endpoint.Readiness(mask)
+
+ // Check our cached value iff the caller asked for readability and the
+ // endpoint itself is currently not readable.
+ if (mask & ^r & waiter.EventIn) != 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
+ r |= waiter.EventIn
+ }
+ }
+
+ return r
+}
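
The arrangement above generalizes: readability of the buffered view is mirrored into an atomic flag written under the mutex but read without it, so poll-path readiness checks never take the lock and cannot deadlock against a reader holding it. A stripped-down sketch of the same pattern (hypothetical type and names):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type bufferedReader struct {
	mu      sync.Mutex // protects view
	view    []byte
	hasData uint32 // 1 iff view is non-empty; accessed atomically
}

func (b *bufferedReader) fill(data []byte) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.view = append(b.view, data...)
	atomic.StoreUint32(&b.hasData, 1)
}

// readable reports whether data is buffered; safe without holding mu.
func (b *bufferedReader) readable() bool {
	return atomic.LoadUint32(&b.hasData) == 1
}

func main() {
	var b bufferedReader
	b.fill([]byte("hello"))
	fmt.Println(b.readable()) // true
}
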
+
+func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
+ if family == uint16(s.family) {
+ return nil
+ }
+ if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
+ v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ if !v {
+ return nil
+ }
+ }
+ return syserr.ErrInvalidArgument
+}
+
+// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
+// receiver's family is AF_INET6.
+//
+// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
+// represented by the empty string.
+//
+// TODO(gvisor.dev/issue/1556): remove this function.
+func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
+ if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
+ addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ }
+ return addr
+}
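
The string constant above is the IPv4-mapped IPv6 ANY address ::ffff:0.0.0.0, which dual-stack (AF_INET6) endpoints treat as the IPv4 wildcard. A quick standard-library check (illustrative only):

package main

import (
	"fmt"
	"net"
)

func main() {
	mapped := net.IP("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00")
	fmt.Println(mapped.String())     // "0.0.0.0": renders as its IPv4 form
	fmt.Println(mapped.To4() != nil) // true: it is a v4-mapped address
}
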
+
+// Connect implements the linux syscall connect(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ addr, family, err := AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
+
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
+
+ // Always return right away in the non-blocking case.
+ if !blocking {
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+ }
+
+ // Register for notification when the endpoint becomes writable, then
+ // initiate the connection.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // TCP, unlike UDP, returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
+ return syserr.TranslateNetstackError(err)
+ }
+
+ // It's pending, so we have to wait for a notification, and fetch the
+ // result once the wait completes.
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+
+ // Call Connect() again after blocking to find connect's result.
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+}
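
The wait-and-retry flow above mirrors what a userspace caller does with a non-blocking socket: connect(2) returns EINPROGRESS, the caller polls for writability, then queries SO_ERROR for the result. An illustrative sketch, not part of this patch (uses golang.org/x/sys/unix):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	// A non-blocking connect typically returns EINPROGRESS immediately.
	sa := &unix.SockaddrInet4{Port: 80, Addr: [4]byte{127, 0, 0, 1}}
	if err := unix.Connect(fd, sa); err != nil && err != unix.EINPROGRESS {
		panic(err)
	}

	// Wait for writability, then ask for SO_ERROR -- the userspace
	// analog of calling Endpoint.Connect again after t.Block(ch).
	pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}}
	if _, err := unix.Poll(pfd, -1); err != nil {
		panic(err)
	}
	soErr, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
	if err != nil {
		panic(err)
	}
	fmt.Println("SO_ERROR:", soErr) // 0 on success; e.g. ECONNREFUSED otherwise
}
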
+
+// Bind implements the linux syscall bind(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
+ }
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only the family, protocol and
+ // ifindex. AddressAndFamily also checks the hardware address
+ // length, which an AF_PACKET bind does not need, so we parse the
+ // sockaddr here instead.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
+ }
+
+ // Issue the bind request to the endpoint.
+ return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accepted, it will block until one becomes ready.
+func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection again; if it fails, then wait until we
+ // get a notification.
+ for {
+ if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ return ep, wq, syserr.TranslateNetstackError(err)
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, terr := s.Endpoint.Accept()
+ if terr != nil {
+ if terr != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(terr)
+ }
+
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := New(t, s.family, s.skType, s.protocol, wq, ep)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+
+ t.Kernel().RecordSocket(ns)
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// ConvertShutdown converts Linux shutdown flags into tcpip shutdown flags.
+func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
+ var f tcpip.ShutdownFlags
+ switch how {
+ case linux.SHUT_RD:
+ f = tcpip.ShutdownRead
+ case linux.SHUT_WR:
+ f = tcpip.ShutdownWrite
+ case linux.SHUT_RDWR:
+ f = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ default:
+ return 0, syserr.ErrInvalidArgument
+ }
+ return f, nil
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return syserr.TranslateNetstackError(s.Endpoint.Shutdown(f))
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketOperations rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptTimestamp {
+ val = 1
+ }
+ return val, nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return entries, nil
+
+ }
+ }
+
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+}
+
+// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
+// sockets backed by a commonEndpoint.
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case linux.SOL_SOCKET:
+ return getSockOptSocket(t, s, ep, family, skType, name, outLen)
+
+ case linux.SOL_TCP:
+ return getSockOptTCP(t, ep, name, outLen)
+
+ case linux.SOL_IPV6:
+ return getSockOptIPv6(t, ep, name, outLen)
+
+ case linux.SOL_IP:
+ return getSockOptIP(t, ep, name, outLen, family)
+
+ case linux.SOL_UDP,
+ linux.SOL_ICMPV6,
+ linux.SOL_RAW,
+ linux.SOL_PACKET:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+func boolToInt32(v bool) int32 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
+// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
+ switch name {
+ case linux.SO_ERROR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Get the last error and convert it.
+ err := ep.GetSockOpt(tcpip.ErrorOption{})
+ if err == nil {
+ return int32(0), nil
+ }
+ return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ case linux.SO_PEERCRED:
+ if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ tcred := t.Credentials()
+ return syscall.Ucred{
+ Pid: int32(t.ThreadGroup().ID()),
+ Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }, nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_REUSEADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_REUSEPORT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if v == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
+
+ case linux.SO_BROADCAST:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_KEEPALIVE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_LINGER:
+ if outLen < linux.SizeOfLinger {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return linux.Linger{}, nil
+
+ case linux.SO_SNDTIMEO:
+ // TODO(igudger): Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+
+ case linux.SO_RCVTIMEO:
+ // TODO(igudger): Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.RecvTimeout()), nil
+
+ case linux.SO_OOBINLINE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OutOfBandInlineOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_NO_CHECK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ default:
+ socket.GetSockOptEmitUnimplementedEvent(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.TCP_NODELAY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(!v), nil
+
+ case linux.TCP_CORK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.TCP_QUICKACK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.TCP_MAXSEG:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_KEEPIDLE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveIdleOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_KEEPINTVL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveIntervalOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_KEEPCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Millisecond), nil
+
+ case linux.TCP_INFO:
+ var v tcpip.TCPInfoOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // TODO(b/64800844): Translate fields once they are added to
+ // tcpip.TCPInfoOption.
+ info := linux.TCPInfo{}
+
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &info)
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+
+ return ib, nil
+
+ case linux.TCP_CC_INFO,
+ linux.TCP_NOTSENT_LOWAT,
+ linux.TCP_ZEROCOPY_RECEIVE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ case linux.TCP_CONGESTION:
+ if outLen <= 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.CongestionControlOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // We match linux behaviour here, returning the lower of
+ // TCP_CA_NAME_MAX bytes or the option length.
+ //
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const tcpCANameMax = 16
+
+ toCopy := tcpCANameMax
+ if outLen < tcpCANameMax {
+ toCopy = outLen
+ }
+ b := make([]byte, toCopy)
+ copy(b, v)
+ return b, nil
+
+ case linux.TCP_LINGER2:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPLingerTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_DEFER_ACCEPT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPDeferAcceptOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_SYNCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_WINDOW_CLAMP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+ default:
+ emitUnimplementedEventTCP(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
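
Note the inversion in the TCP_NODELAY case above: netstack stores whether delay (Nagle's algorithm) is enabled, while the Linux option reports whether it is disabled, hence boolToInt32(!v). From the application side (illustrative sketch, not part of this patch):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	// TCP_NODELAY=1 disables Nagle; inside netstack this corresponds
	// to tcpip.DelayOption=false.
	if err := unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1); err != nil {
		panic(err)
	}
	v, err := unix.GetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_NODELAY)
	if err != nil {
		panic(err)
	}
	fmt.Println("TCP_NODELAY =", v) // 1
}
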
+
+// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.IPV6_V6ONLY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.IPV6_PATHMTU:
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ case linux.IPV6_TCLASS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return make([]byte, 0), nil
+ }
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ uintv := uint32(v)
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ // Handle cases where outLen is less than sizeOfInt32.
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+ return ib, nil
+
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ default:
+ emitUnimplementedEventIPv6(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptIP implements GetSockOpt when level is SOL_IP.
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.IP_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // Fill in the default value, if needed.
+ if v == 0 {
+ v = DefaultTTL
+ }
+
+ return int32(v), nil
+
+ case linux.IP_MULTICAST_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.IP_MULTICAST_IF:
+ if outLen < len(linux.InetAddr{}) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.MulticastInterfaceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+
+ return a.(*linux.SockAddrInet).Addr, nil
+
+ case linux.IP_MULTICAST_LOOP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.IP_TOS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return []byte(nil), nil
+ }
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if outLen < sizeOfInt32 {
+ return uint8(v), nil
+ }
+ return int32(v), nil
+
+ case linux.IP_RECVTOS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ default:
+ emitUnimplementedEventIP(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
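+
+// Illustrative note on the IP_TOS length handling above: outLen == 0
+// succeeds with an empty result, 0 < outLen < 4 yields a single byte, and
+// outLen >= 4 yields a full 32-bit value; IPV6_TCLASS instead marshals four
+// bytes and truncates them to outLen.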
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketOperations rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+}
+
+// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
+// sockets backed by a commonEndpoint.
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+ switch level {
+ case linux.SOL_SOCKET:
+ return setSockOptSocket(t, s, ep, name, optVal)
+
+ case linux.SOL_TCP:
+ return setSockOptTCP(t, ep, name, optVal)
+
+ case linux.SOL_IPV6:
+ return setSockOptIPv6(t, ep, name, optVal)
+
+ case linux.SOL_IP:
+ return setSockOptIP(t, ep, name, optVal)
+
+ case linux.SOL_UDP,
+ linux.SOL_ICMPV6,
+ linux.SOL_RAW,
+ linux.SOL_PACKET:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
+
+ case linux.SO_RCVBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
+
+ case linux.SO_REUSEADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
+
+ case linux.SO_REUSEPORT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
+
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
+
+ case linux.SO_BROADCAST:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
+
+ case linux.SO_PASSCRED:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
+
+ case linux.SO_KEEPALIVE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
+
+ case linux.SO_SNDTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+
+ case linux.SO_RCVTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetRecvTimeout(v.ToNsecCapped())
+ return nil
+
+ case linux.SO_OOBINLINE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+
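+ // A zero value requests out-of-line OOB data, which is not
+ // implemented, so note the attempt before handing the value to the
+ // endpoint.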
+ if v == 0 {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+
+ case linux.SO_NO_CHECK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+
+ case linux.SO_LINGER:
+ if len(optVal) < linux.SizeOfLinger {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Linger
+ binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v)
+
+ if v != (linux.Linger{}) {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ return nil
+
+ default:
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
+func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.TCP_NODELAY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
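+ // Note the sense inversion: enabling TCP_NODELAY (a non-zero value)
+ // clears the endpoint's delay option, and vice versa.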
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+
+ case linux.TCP_CORK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+
+ case linux.TCP_QUICKACK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+
+ case linux.TCP_MAXSEG:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
+
+ case linux.TCP_KEEPIDLE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPIDLE {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_KEEPINTVL:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPINTVL {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_KEEPCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPCNT {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v)))
+
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
+ case linux.TCP_CONGESTION:
+ v := tcpip.CongestionControlOption(optVal)
+ if err := ep.SetSockOpt(v); err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ return nil
+
+ case linux.TCP_LINGER2:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_DEFER_ACCEPT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_SYNCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v)))
+
+ case linux.TCP_WINDOW_CLAMP:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v)))
+
+ case linux.TCP_REPAIR_OPTIONS:
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEventTCP(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
+func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.IPV6_V6ONLY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+
+ case linux.IPV6_ADD_MEMBERSHIP,
+ linux.IPV6_DROP_MEMBERSHIP,
+ linux.IPV6_IPSEC_POLICY,
+ linux.IPV6_JOIN_ANYCAST,
+ linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
+ linux.IPV6_PKTINFO,
+ linux.IPV6_ROUTER_ALERT,
+ linux.IPV6_XFRM_POLICY,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_JOIN_GROUP,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_UNBLOCK_SOURCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ case linux.IPV6_TCLASS:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < -1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ if v == -1 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
+
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+
+ default:
+ emitUnimplementedEventIPv6(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+var (
+ inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
+ inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
+)
+
+// copyInMulticastRequest copies in a variable-size multicast request. The
+// kernel determines which structure was passed by its length. IP_MULTICAST_IF
+// supports ip_mreqn, ip_mreq and in_addr, while IP_ADD_MEMBERSHIP and
+// IP_DROP_MEMBERSHIP only support ip_mreqn and ip_mreq. To handle this,
+// allowAddr controls whether in_addr is accepted or rejected.
+func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastRequestWithNIC, *syserr.Error) {
+ if len(optVal) < len(linux.InetAddr{}) {
+ return linux.InetMulticastRequestWithNIC{}, syserr.ErrInvalidArgument
+ }
+
+ if len(optVal) < inetMulticastRequestSize {
+ if !allowAddr {
+ return linux.InetMulticastRequestWithNIC{}, syserr.ErrInvalidArgument
+ }
+
+ var req linux.InetMulticastRequestWithNIC
+ copy(req.InterfaceAddr[:], optVal)
+ return req, nil
+ }
+
+ if len(optVal) >= inetMulticastRequestWithNICSize {
+ var req linux.InetMulticastRequestWithNIC
+ binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], usermem.ByteOrder, &req)
+ return req, nil
+ }
+
+ var req linux.InetMulticastRequestWithNIC
+ binary.Unmarshal(optVal[:inetMulticastRequestSize], usermem.ByteOrder, &req.InetMulticastRequest)
+ return req, nil
+}
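+
+// For illustration, the accepted layouts (with the usual 4-byte addresses
+// and 4-byte interface index) are:
+//
+//	in_addr (4 bytes):   interface address only (IP_MULTICAST_IF only)
+//	ip_mreq (8 bytes):   multicast address, then interface address
+//	ip_mreqn (12 bytes): ip_mreq fields plus an interface index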
+
+// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf.
+//
+// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options.
+func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
+ if len(buf) == 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ if len(buf) >= sizeOfInt32 {
+ return int32(usermem.ByteOrder.Uint32(buf)), nil
+ }
+
+ return int32(buf[0]), nil
+}
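+
+// For example, on a little-endian target parseIntOrChar([]byte{1}) and
+// parseIntOrChar([]byte{1, 0, 0, 0}) both yield 1, while an empty buffer
+// yields EINVAL.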
+
+// setSockOptIP implements SetSockOpt when level is SOL_IP.
+func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.IP_MULTICAST_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ if v == -1 {
+ // Linux translates -1 to 1.
+ v = 1
+ }
+ if v < 0 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
+
+ case linux.IP_ADD_MEMBERSHIP:
+ req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ // TODO(igudger): Change AddMembership to use the standard
+ // any address representation.
+ InterfaceAddr: tcpip.Address(req.InterfaceAddr[:]),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IP_DROP_MEMBERSHIP:
+ req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ // TODO(igudger): Change DropMembership to use the standard
+ // any address representation.
+ InterfaceAddr: tcpip.Address(req.InterfaceAddr[:]),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IP_MULTICAST_IF:
+ req, err := copyInMulticastRequest(optVal, true /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ }))
+
+ case linux.IP_MULTICAST_LOOP:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+
+ case linux.MCAST_JOIN_GROUP:
+ // FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return syserr.ErrInvalidArgument
+
+ case linux.IP_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ // -1 means default TTL.
+ if v == -1 {
+ v = 0
+ } else if v < 1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
+
+ case linux.IP_TOS:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
+
+ case linux.IP_RECVTOS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
+ case linux.IP_HDRINCL:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+
+ case linux.IP_ADD_SOURCE_MEMBERSHIP,
+ linux.IP_BIND_ADDRESS_NO_PORT,
+ linux.IP_BLOCK_SOURCE,
+ linux.IP_CHECKSUM,
+ linux.IP_DROP_SOURCE_MEMBERSHIP,
+ linux.IP_FREEBIND,
+ linux.IP_IPSEC_POLICY,
+ linux.IP_MINTTL,
+ linux.IP_MSFILTER,
+ linux.IP_MTU_DISCOVER,
+ linux.IP_MULTICAST_ALL,
+ linux.IP_NODEFRAG,
+ linux.IP_OPTIONS,
+ linux.IP_PASSSEC,
+ linux.IP_RECVERR,
+ linux.IP_RECVFRAGSIZE,
+ linux.IP_RECVOPTS,
+ linux.IP_RECVORIGDSTADDR,
+ linux.IP_RECVTTL,
+ linux.IP_RETOPTS,
+ linux.IP_TRANSPARENT,
+ linux.IP_UNBLOCK_SOURCE,
+ linux.IP_UNICAST_IF,
+ linux.IP_XFRM_POLICY,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_MSFILTER,
+ linux.MCAST_UNBLOCK_SOURCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// emitUnimplementedEventTCP emits an unimplemented event if the name is
+// valid. It covers the names that are common between Get and SetSockOpt when
+// the level is SOL_TCP.
+func emitUnimplementedEventTCP(t *kernel.Task, name int) {
+ switch name {
+ case linux.TCP_CONGESTION,
+ linux.TCP_CORK,
+ linux.TCP_FASTOPEN,
+ linux.TCP_FASTOPEN_CONNECT,
+ linux.TCP_FASTOPEN_KEY,
+ linux.TCP_FASTOPEN_NO_COOKIE,
+ linux.TCP_QUEUE_SEQ,
+ linux.TCP_REPAIR,
+ linux.TCP_REPAIR_QUEUE,
+ linux.TCP_REPAIR_WINDOW,
+ linux.TCP_SAVED_SYN,
+ linux.TCP_SAVE_SYN,
+ linux.TCP_THIN_DUPACK,
+ linux.TCP_THIN_LINEAR_TIMEOUTS,
+ linux.TCP_TIMESTAMP,
+ linux.TCP_ULP:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// emitUnimplementedEventIPv6 emits an unimplemented event if the name is
+// valid. It covers the names that are common between Get and SetSockOpt when
+// the level is SOL_IPV6.
+func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
+ switch name {
+ case linux.IPV6_2292DSTOPTS,
+ linux.IPV6_2292HOPLIMIT,
+ linux.IPV6_2292HOPOPTS,
+ linux.IPV6_2292PKTINFO,
+ linux.IPV6_2292PKTOPTIONS,
+ linux.IPV6_2292RTHDR,
+ linux.IPV6_ADDR_PREFERENCES,
+ linux.IPV6_AUTOFLOWLABEL,
+ linux.IPV6_DONTFRAG,
+ linux.IPV6_DSTOPTS,
+ linux.IPV6_FLOWINFO,
+ linux.IPV6_FLOWINFO_SEND,
+ linux.IPV6_FLOWLABEL_MGR,
+ linux.IPV6_FREEBIND,
+ linux.IPV6_HOPOPTS,
+ linux.IPV6_MINHOPCOUNT,
+ linux.IPV6_MTU,
+ linux.IPV6_MTU_DISCOVER,
+ linux.IPV6_MULTICAST_ALL,
+ linux.IPV6_MULTICAST_HOPS,
+ linux.IPV6_MULTICAST_IF,
+ linux.IPV6_MULTICAST_LOOP,
+ linux.IPV6_RECVDSTOPTS,
+ linux.IPV6_RECVERR,
+ linux.IPV6_RECVFRAGSIZE,
+ linux.IPV6_RECVHOPLIMIT,
+ linux.IPV6_RECVHOPOPTS,
+ linux.IPV6_RECVORIGDSTADDR,
+ linux.IPV6_RECVPATHMTU,
+ linux.IPV6_RECVPKTINFO,
+ linux.IPV6_RECVRTHDR,
+ linux.IPV6_RTHDR,
+ linux.IPV6_RTHDRDSTOPTS,
+ linux.IPV6_TCLASS,
+ linux.IPV6_TRANSPARENT,
+ linux.IPV6_UNICAST_HOPS,
+ linux.IPV6_UNICAST_IF,
+ linux.MCAST_MSFILTER,
+ linux.IPV6_ADDRFORM:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// emitUnimplementedEventIP emits an unimplemented event if the name is
+// valid. It covers the names that are common between Get and SetSockOpt when
+// the level is SOL_IP.
+func emitUnimplementedEventIP(t *kernel.Task, name int) {
+ switch name {
+ case linux.IP_TOS,
+ linux.IP_TTL,
+ linux.IP_HDRINCL,
+ linux.IP_OPTIONS,
+ linux.IP_ROUTER_ALERT,
+ linux.IP_RECVOPTS,
+ linux.IP_RETOPTS,
+ linux.IP_PKTINFO,
+ linux.IP_PKTOPTIONS,
+ linux.IP_MTU_DISCOVER,
+ linux.IP_RECVERR,
+ linux.IP_RECVTTL,
+ linux.IP_RECVTOS,
+ linux.IP_MTU,
+ linux.IP_FREEBIND,
+ linux.IP_IPSEC_POLICY,
+ linux.IP_XFRM_POLICY,
+ linux.IP_PASSSEC,
+ linux.IP_TRANSPARENT,
+ linux.IP_ORIGDSTADDR,
+ linux.IP_MINTTL,
+ linux.IP_NODEFRAG,
+ linux.IP_CHECKSUM,
+ linux.IP_BIND_ADDRESS_NO_PORT,
+ linux.IP_RECVFRAGSIZE,
+ linux.IP_MULTICAST_IF,
+ linux.IP_MULTICAST_TTL,
+ linux.IP_MULTICAST_LOOP,
+ linux.IP_ADD_MEMBERSHIP,
+ linux.IP_DROP_MEMBERSHIP,
+ linux.IP_UNBLOCK_SOURCE,
+ linux.IP_BLOCK_SOURCE,
+ linux.IP_ADD_SOURCE_MEMBERSHIP,
+ linux.IP_DROP_SOURCE_MEMBERSHIP,
+ linux.IP_MSFILTER,
+ linux.MCAST_JOIN_GROUP,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_UNBLOCK_SOURCE,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_MSFILTER,
+ linux.IP_MULTICAST_ALL,
+ linux.IP_UNICAST_IF:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
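+
+// For example, fe80::1 and febf::1 are link-local (fe80::/10 spans fe80
+// through febf), while fec0::1 is not.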
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // The null terminator may be omitted when the path exactly fills the
+ // buffer. Abstract and empty paths always return the full exact
+ // length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
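+
+// As an illustration of the AF_INET6 case above, a 4-byte IPv4 address such
+// as 192.0.2.1 is emitted in v4-mapped form ::ffff:192.0.2.1: bytes 10 and
+// 11 are set to 0xff and the IPv4 address occupies bytes 12 through 15.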
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.Endpoint.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := ConvertAddress(s.family, addr)
+ return a, l, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.Endpoint.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := ConvertAddress(s.family, addr)
+ return a, l, nil
+}
+
+// coalescingRead is the fast path for the non-blocking, non-peek,
+// stream-based case. It coalesces as many packets as possible before
+// returning to the caller.
+//
+// Precondition: s.readMu must be locked.
+func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
+ var err *syserr.Error
+ var copied int
+
+ // Copy as many views as possible into the user-provided buffer.
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
+ err = s.fetchReadView()
+ if err != nil {
+ break
+ }
+ if dst.NumBytes() == 0 {
+ break
+ }
+
+ var n int
+ var e error
+ if discard {
+ n = len(s.readView)
+ if int64(n) > dst.NumBytes() {
+ n = int(dst.NumBytes())
+ }
+ } else {
+ n, e = dst.CopyOut(ctx, s.readView)
+ // Set the control message, even if 0 bytes were read.
+ if e == nil {
+ s.updateTimestamp()
+ }
+ }
+ copied += n
+ s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
+ dst = dst.DropFirst(n)
+ if e != nil {
+ err = syserr.FromError(e)
+ break
+ }
+ }
+
+ // If we managed to copy something, we must deliver it.
+ if copied > 0 {
+ s.Endpoint.ModerateRecvBuf(copied)
+ return copied, nil
+ }
+
+ return 0, err
+}
+
+func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
+ if !s.sockOptInq {
+ return
+ }
+ rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return
+ }
+ cmsg.IP.HasInq = true
+ cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+}
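+
+// For example, with 20 already-fetched bytes remaining in readView and 100
+// bytes still queued in the endpoint, the TCP_INQ control message reports
+// 120 bytes available.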
+
+// nonBlockingRead issues a non-blocking read.
+//
+// TODO(b/78348848): Support timestamps for stream sockets.
+func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ isPacket := s.isPacketBased()
+
+ // Fast path for regular reads from stream (e.g., TCP) endpoints. Note
+ // that senderRequested is ignored for stream sockets.
+ if !peek && !isPacket {
+ // TCP sockets discard the data if MSG_TRUNC is set.
+ //
+ // This behavior is documented in man 7 tcp:
+ // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
+ // argument of recv(2) (and recvmsg(2)). This flag causes the received
+ // bytes of data to be discarded, rather than passed back in a
+ // caller-supplied buffer.
+ s.readMu.Lock()
+ n, err := s.coalescingRead(ctx, dst, trunc)
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
+ return n, 0, nil, 0, cmsg, err
+ }
+
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+
+ if err := s.fetchReadView(); err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if !isPacket && peek && trunc {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read.
+ rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
+ }
+ available := len(s.readView) + int(rql)
+ bufLen := int(dst.NumBytes())
+ if available < bufLen {
+ return available, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+ return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+
+ n, err := dst.CopyOut(ctx, s.readView)
+ // Set the control message, even if 0 bytes were read.
+ if err == nil {
+ s.updateTimestamp()
+ }
+ var addr linux.SockAddr
+ var addrLen uint32
+ if isPacket && senderRequested {
+ addr, addrLen = ConvertAddress(s.family, s.sender)
+ }
+
+ if peek {
+ if l := len(s.readView); trunc && l > n {
+ // isPacket must be true.
+ return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ }
+
+ if isPacket || err != nil {
+ return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ }
+
+ // We need to peek beyond the first message.
+ dst = dst.DropFirst(n)
+ num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
+ n, _, err := s.Endpoint.Peek(dsts)
+ // TODO(b/78348848): Handle peek timestamp.
+ if err != nil {
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+ }
+ return int64(n), nil
+ }})
+ n += int(num)
+ if err == syserror.ErrWouldBlock && n > 0 {
+ // We got some data, so no need to return an error.
+ err = nil
+ }
+ return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ }
+
+ var msgLen int
+ if isPacket {
+ msgLen = len(s.readView)
+ s.readView = nil
+ } else {
+ msgLen = int(n)
+ s.readView.TrimFront(int(n))
+ }
+
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
+ var flags int
+ if msgLen > int(n) {
+ flags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = msgLen
+ }
+
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
+}
+
+func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
+ return socket.ControlMessages{
+ IP: tcpip.ControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ },
+ }
+}
+
+// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after
+// successfully writing packet data out to userspace.
+//
+// Precondition: s.readMu must be locked.
+func (s *socketOpsCommon) updateTimestamp() {
+ // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
+ if !s.sockOptTimestamp {
+ s.timestampValid = true
+ s.timestampNS = s.readCM.Timestamp
+ }
+}
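+
+// In other words, SIOCGSTAMP reports the timestamp of the most recent read
+// performed while SO_TIMESTAMP was disabled; once SO_TIMESTAMP is enabled,
+// timestamps are delivered via control messages instead.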
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+ if senderRequested && !s.isPacketBased() {
+ // Stream sockets ignore the sender address.
+ senderRequested = false
+ }
+ n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+
+ if s.isPacketBased() && err == syserr.ErrClosedForReceive && dontWait {
+ // In this situation we should return EAGAIN.
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
+ // Read failed and we should not retry.
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
+ // We got all the data we need.
+ return
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst(n)
+
+ // We'll have to block. Register for notifications and keep trying to
+ // receive all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ var rn int
+ rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n += rn
+ if err != nil && err != syserr.ErrWouldBlock {
+ // Always stop on errors other than would block as we generally
+ // won't be able to get any more data. Eat the error if we got
+ // any data.
+ if n > 0 {
+ err = nil
+ }
+ return
+ }
+ if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
+ // We got all the data we need.
+ return
+ }
+ dst = dst.DropFirst(rn)
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if n > 0 {
+ return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Reject Unix control messages.
+ if !controlMessages.Unix.Empty() {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ var addr *tcpip.FullAddress
+ if len(to) > 0 {
+ addrBuf, family, err := AddressAndFamily(to)
+ if err != nil {
+ return 0, err
+ }
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return 0, err
+ }
+ addrBuf = s.mapFamily(addrBuf, family)
+
+ addr = &addrBuf
+ }
+
+ opts := tcpip.WriteOptions{
+ To: addr,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }
+
+ v := &ioSequencePayload{t, src}
+ n, resCh, err := s.Endpoint.Write(v, opts)
+ if resCh != nil {
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err)
+ }
+ n, _, err = s.Endpoint.Write(v, opts)
+ }
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
+ // Complete write.
+ return int(n), nil
+ }
+ if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
+ return int(n), syserr.TranslateNetstackError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ v.DropFirst(int(n))
+ total := n
+ for {
+ n, _, err = s.Endpoint.Write(v, opts)
+ v.DropFirst(int(n))
+ total += n
+
+ if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
+ return 0, syserr.TranslateNetstackError(err)
+ }
+
+ if (err == nil && v.src.NumBytes() == 0) || (err != nil && err != tcpip.ErrWouldBlock) {
+ return int(total), nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
+ // handleIOError will consume errors from t.Block if needed.
+ return int(total), syserr.FromError(err)
+ }
+ }
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return s.socketOpsCommon.ioctl(ctx, io, args)
+}
+
+func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
+ // sockets.
+ // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
+ switch args[1].Int() {
+ case syscall.SIOCGSTAMP:
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if !s.timestampValid {
+ return 0, syserror.ENOENT
+ }
+
+ tv := linux.NsecToTimeval(s.timestampNS)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ // Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
+ v += len(s.readView)
+ s.readMu.Unlock()
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ }
+
+ return Ioctl(ctx, s.Endpoint, io, args)
+}
+
+// Ioctl performs a socket ioctl.
+func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch arg := int(args[1].Int()); arg {
+ case syscall.SIOCGIFFLAGS,
+ syscall.SIOCGIFADDR,
+ syscall.SIOCGIFBRDADDR,
+ syscall.SIOCGIFDSTADDR,
+ syscall.SIOCGIFHWADDR,
+ syscall.SIOCGIFINDEX,
+ syscall.SIOCGIFMAP,
+ syscall.SIOCGIFMETRIC,
+ syscall.SIOCGIFMTU,
+ syscall.SIOCGIFNAME,
+ syscall.SIOCGIFNETMASK,
+ syscall.SIOCGIFTXQLEN:
+
+ var ifr linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
+ return 0, err.ToError()
+ }
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case syscall.SIOCGIFCONF:
+ // Return a list of interface addresses or the buffer size
+ // necessary to hold the list.
+ var ifc linux.IFConf
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ return 0, err
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+
+ return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCOUTQ:
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+
+ return 0, syserror.ENOTTY
+}
+
+// interfaceIoctl implements interface requests.
+func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+ var (
+ iface inet.Interface
+ index int32
+ found bool
+ )
+
+ // Find the network stack for this context.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+
+ // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
+ // identify a device.
+ if arg == syscall.SIOCGIFNAME {
+ // Gets the name of the interface given the interface index
+ // stored in ifr_ifindex.
+ index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
+ if iface, ok := stack.Interfaces()[index]; ok {
+ ifr.SetName(iface.Name)
+ return nil
+ }
+ return syserr.ErrNoDevice
+ }
+
+ // Find the relevant device.
+ for index, iface = range stack.Interfaces() {
+ if iface.Name == ifr.Name() {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
+
+ switch arg {
+ case syscall.SIOCGIFINDEX:
+ // Copy out the index to the data.
+ usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
+
+ case syscall.SIOCGIFHWADDR:
+ // Copy the hardware address out.
+ ifr.Data[0] = 6 // IEEE802.2 arp type.
+ ifr.Data[1] = 0
+ n := copy(ifr.Data[2:], iface.Addr)
+ for i := 2 + n; i < len(ifr.Data); i++ {
+ ifr.Data[i] = 0 // Clear padding.
+ }
+ usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
+
+ case syscall.SIOCGIFFLAGS:
+ f, err := interfaceStatusFlags(stack, iface.Name)
+ if err != nil {
+ return err
+ }
+ // Drop the flags that don't fit in the size that we need to return. This
+ // matches Linux behavior.
+ usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
+
+ case syscall.SIOCGIFADDR:
+ // Copy the IPv4 address out.
+ for _, addr := range stack.InterfaceAddrs()[index] {
+ // This ioctl is only compatible with AF_INET addresses.
+ if addr.Family != linux.AF_INET {
+ continue
+ }
+ copy(ifr.Data[4:8], addr.Addr)
+ break
+ }
+
+ case syscall.SIOCGIFMETRIC:
+ // Gets the metric of the device. As per netdevice(7), this
+ // always just sets ifr_metric to 0.
+ usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
+
+ case syscall.SIOCGIFMTU:
+ // Gets the MTU of the device.
+ usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
+
+ case syscall.SIOCGIFMAP:
+ // Gets the hardware parameters of the device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFTXQLEN:
+ // Gets the transmit queue length of the device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFDSTADDR:
+ // Gets the destination address of a point-to-point device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFBRDADDR:
+ // Gets the broadcast address of a device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFNETMASK:
+ // Gets the network mask of a device.
+ for _, addr := range stack.InterfaceAddrs()[index] {
+ // This ioctl is only compatible with AF_INET addresses.
+ if addr.Family != linux.AF_INET {
+ continue
+ }
+ // Populate ifr.ifr_netmask (type sockaddr).
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ var mask uint32 = 0xffffffff << (32 - addr.PrefixLen)
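+ // For example, PrefixLen 24 yields 0xffffff00 (255.255.255.0).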
+ // Netmask is expected to be returned as a big endian
+ // value.
+ binary.BigEndian.PutUint32(ifr.Data[4:8], mask)
+ break
+ }
+
+ default:
+ // Not a valid call.
+ return syserr.ErrInvalidArgument
+ }
+
+ return nil
+}
+
+// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
+func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+ // If Ptr is NULL, return the necessary buffer size via Len.
+ // Otherwise, write up to Len bytes starting at Ptr containing ifreq
+ // structs.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ return syserr.ErrNoDevice.ToError()
+ }
+
+ if ifc.Ptr == 0 {
+ ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ return nil
+ }
+
+ max := ifc.Len
+ ifc.Len = 0
+ for key, ifaceAddrs := range stack.InterfaceAddrs() {
+ iface := stack.Interfaces()[key]
+ for _, ifaceAddr := range ifaceAddrs {
+ // Don't write past the end of the buffer.
+ if ifc.Len+int32(linux.SizeOfIFReq) > max {
+ break
+ }
+ if ifaceAddr.Family != linux.AF_INET {
+ continue
+ }
+
+ // Populate ifr.ifr_addr.
+ ifr := linux.IFReq{}
+ ifr.SetName(iface.Name)
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
+
+ // Copy the ifr to userspace.
+ dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
+ ifc.Len += int32(linux.SizeOfIFReq)
+ if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
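+
+// Callers typically drive SIOCGIFCONF in two passes, as on Linux; a sketch
+// (illustrative only):
+//
+//	ifc := linux.IFConf{Ptr: 0}
+//	ioctl(SIOCGIFCONF, &ifc)  // sets ifc.Len to the required size
+//	// allocate ifc.Len bytes, point ifc.Ptr at them, and call again
+//	ioctl(SIOCGIFCONF, &ifc)  // writes the ifreq entries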
+
+// interfaceStatusFlags returns status flags for an interface in the stack.
+// Flag values and meanings are described in greater detail in netdevice(7) in
+// the SIOCGIFFLAGS section.
+func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) {
+ // We should only ever be passed a netstack.Stack.
+ epstack, ok := stack.(*Stack)
+ if !ok {
+ return 0, errStackType
+ }
+
+ // Find the NIC corresponding to this interface.
+ for _, info := range epstack.Stack.NICInfo() {
+ if info.Name == name {
+ return nicStateFlagsToLinux(info.Flags), nil
+ }
+ }
+ return 0, syserr.ErrNoDevice
+}
+
+func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
+ var rv uint32
+ if f.Up {
+ rv |= linux.IFF_UP | linux.IFF_LOWER_UP
+ }
+ if f.Running {
+ rv |= linux.IFF_RUNNING
+ }
+ if f.Promiscuous {
+ rv |= linux.IFF_PROMISC
+ }
+ if f.Loopback {
+ rv |= linux.IFF_LOOPBACK
+ }
+ return rv
+}
+
+// State implements socket.Socket.State. State translates the internal state
+// returned by netstack to values defined by Linux.
+func (s *socketOpsCommon) State() uint32 {
+ if s.family != linux.AF_INET && s.family != linux.AF_INET6 {
+ // States not implemented for this socket's family.
+ return 0
+ }
+
+ switch {
+ case s.skType == linux.SOCK_STREAM && (s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP):
+ // TCP socket.
+ switch tcp.EndpointState(s.Endpoint.State()) {
+ case tcp.StateEstablished:
+ return linux.TCP_ESTABLISHED
+ case tcp.StateSynSent:
+ return linux.TCP_SYN_SENT
+ case tcp.StateSynRecv:
+ return linux.TCP_SYN_RECV
+ case tcp.StateFinWait1:
+ return linux.TCP_FIN_WAIT1
+ case tcp.StateFinWait2:
+ return linux.TCP_FIN_WAIT2
+ case tcp.StateTimeWait:
+ return linux.TCP_TIME_WAIT
+ case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError:
+ return linux.TCP_CLOSE
+ case tcp.StateCloseWait:
+ return linux.TCP_CLOSE_WAIT
+ case tcp.StateLastAck:
+ return linux.TCP_LAST_ACK
+ case tcp.StateListen:
+ return linux.TCP_LISTEN
+ case tcp.StateClosing:
+ return linux.TCP_CLOSING
+ default:
+ // Internal or unknown state.
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && (s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP):
+ // UDP socket.
+ switch udp.EndpointState(s.Endpoint.State()) {
+ case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ return linux.TCP_CLOSE
+ case udp.StateConnected:
+ return linux.TCP_ESTABLISHED
+ default:
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && (s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6):
+ // TODO(b/112063468): Export states for ICMP sockets.
+ case s.skType == linux.SOCK_RAW:
+ // TODO(b/112063468): Export states for raw sockets.
+ default:
+ // Unknown transport protocol, how did we make this socket?
+ log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem())
+ return 0
+ }
+
+ return 0
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.skType, s.protocol
+}
+
+// LINT.ThenChange(./netstack_vfs2.go)
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
new file mode 100644
index 000000000..d65a89316
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -0,0 +1,330 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ socket.SocketVFS2 = (*SocketVFS2)(nil)
+
+// NewVFS2 creates a new endpoint socket.
+func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ }
+
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ s := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ protocol: protocol,
+ },
+ }
+ s.LockFD.Init(&vfs.FileLocks{})
+ vfsfd := &s.vfsfd
+ if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ }
+
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, terr := s.Endpoint.Accept()
+ if terr != nil {
+ if terr != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(terr)
+ }
+
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := NewVFS2(t, s.family, s.skType, s.protocol, wq, ep)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, 0, syserr.FromError(err)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+
+ t.Kernel().RecordSocketVFS2(ns)
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return s.socketOpsCommon.ioctl(ctx, uio, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketVFS2 rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptTimestamp {
+ val = 1
+ }
+ return val, nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return entries, nil
+
+ }
+ }
+
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketVFS2 rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
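
Both option paths above decode optVal with usermem.ByteOrder.Uint32, so SO_TIMESTAMP and TCP_INQ take a 4-byte integer in the sentry's native byte order. A minimal sketch of producing such a buffer, assuming a little-endian host (what usermem.ByteOrder resolves to on x86); the helper name is illustrative, not part of this patch:

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // encodeInt32Opt builds the 4-byte optVal that the SO_TIMESTAMP and
    // TCP_INQ paths above expect. A little-endian host is assumed.
    func encodeInt32Opt(enabled bool) []byte {
        v := uint32(0)
        if enabled {
            v = 1
        }
        buf := make([]byte, 4)
        binary.LittleEndian.PutUint32(buf, v)
        return buf
    }

    func main() {
        fmt.Printf("optVal for SO_TIMESTAMP=1: %v\n", encodeInt32Opt(true))
    }
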
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
new file mode 100644
index 000000000..ead3b2b79
--- /dev/null
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -0,0 +1,199 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// provider is an inet socket provider.
+type provider struct {
+ family int
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// getTransportProtocol figures out the transport protocol. Currently only
+// TCP, UDP, and ICMP are supported. The bool return value is true when this
+// socket is associated with a transport protocol. This is only false for
+// SOCK_RAW, IPPROTO_RAW sockets.
+func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, bool, *syserr.Error) {
+ switch stype {
+ case linux.SOCK_STREAM:
+ if protocol != 0 && protocol != syscall.IPPROTO_TCP {
+ return 0, true, syserr.ErrInvalidArgument
+ }
+ return tcp.ProtocolNumber, true, nil
+
+ case linux.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ return udp.ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
+ }
+
+ case linux.SOCK_RAW:
+ // Raw sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return 0, true, syserr.ErrNotPermitted
+ }
+
+ switch protocol {
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
+ case syscall.IPPROTO_UDP:
+ return header.UDPProtocolNumber, true, nil
+ case syscall.IPPROTO_TCP:
+ return header.TCPProtocolNumber, true, nil
+ // IPPROTO_RAW signifies that the raw socket isn't assigned to
+ // a transport protocol. Users will be able to write packets'
+ // IP headers and won't receive anything.
+ case syscall.IPPROTO_RAW:
+ return tcpip.TransportProtocolNumber(0), false, nil
+ }
+ }
+ return 0, true, syserr.ErrProtocolNotSupported
+}
+
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
+func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
+ // Figure out the transport protocol.
+ transProto, associated, err := getTransportProtocol(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
+ wq := &waiter.Queue{}
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
+ }
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return New(t, p.family, stype, int(transProto), wq, ep)
+}
+
+func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
+// LINT.ThenChange(./provider_vfs2.go)
+
+// Pair just returns nil sockets (not supported).
+func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ return nil, nil, nil
+}
+
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
+func init() {
+ // Providers backed by netstack.
+ p := []provider{
+ {
+ family: linux.AF_INET,
+ netProto: ipv4.ProtocolNumber,
+ },
+ {
+ family: linux.AF_INET6,
+ netProto: ipv6.ProtocolNumber,
+ },
+ {
+ family: linux.AF_PACKET,
+ },
+ }
+
+ for i := range p {
+ socket.RegisterProvider(p[i].family, &p[i])
+ }
+}
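
packetSocket converts the protocol with ntohs, which is defined elsewhere in this package and not shown in this patch. On a little-endian host it reduces to a byte swap; a minimal equivalent sketch:

    package main

    import "fmt"

    // ntohs converts a 16-bit value from network (big-endian) to host
    // byte order. The real helper lives elsewhere in the netstack
    // package; on a little-endian host it is just a byte swap.
    func ntohs(v uint16) uint16 {
        return v<<8 | v>>8
    }

    func main() {
        // ETH_P_IP is 0x0800 on the wire; read as a little-endian
        // uint16 it appears as 0x0008 until swapped back.
        fmt.Printf("%#04x\n", ntohs(0x0008))
    }
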
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
new file mode 100644
index 000000000..2a01143f6
--- /dev/null
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// providerVFS2 is an inet socket provider.
+type providerVFS2 struct {
+ family int
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
+func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocketVFS2(t, eps, stype, protocol)
+ }
+
+ // Figure out the transport protocol.
+ transProto, associated, err := getTransportProtocol(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
+ wq := &waiter.Queue{}
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
+ }
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return NewVFS2(t, p.family, stype, int(transProto), wq, ep)
+}
+
+func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return NewVFS2(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
+// Pair just returns nil sockets (not supported).
+func (*providerVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ return nil, nil, nil
+}
+
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
+func init() {
+ // Providers backed by netstack.
+ p := []providerVFS2{
+ {
+ family: linux.AF_INET,
+ netProto: ipv4.ProtocolNumber,
+ },
+ {
+ family: linux.AF_INET6,
+ netProto: ipv6.ProtocolNumber,
+ },
+ {
+ family: linux.AF_PACKET,
+ },
+ }
+
+ for i := range p {
+ socket.RegisterProviderVFS2(p[i].family, &p[i])
+ }
+}
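
For orientation, here is the userspace half of the byte-order contract enforced by packetSocketVFS2: AF_PACKET callers pass the protocol through htons, and the provider swaps it back for netstack. A hedged sketch (Linux-only; requires CAP_NET_RAW, exactly as the provider checks):

    package main

    import (
        "fmt"
        "syscall"
    )

    // htons mirrors the userspace side of the conversion: AF_PACKET
    // expects the protocol argument in network byte order.
    func htons(v uint16) uint16 { return v<<8 | v>>8 }

    func main() {
        const ethPAll = 0x0003 // ETH_P_ALL
        fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, int(htons(ethPAll)))
        if err != nil {
            fmt.Println("socket:", err) // fails without CAP_NET_RAW
            return
        }
        defer syscall.Close(fd)
        fmt.Println("packet socket fd:", fd)
    }
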
diff --git a/pkg/sentry/socket/netstack/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go
new file mode 100644
index 000000000..c7aaf722a
--- /dev/null
+++ b/pkg/sentry/socket/netstack/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// afterLoad is invoked by stateify.
+func (s *Stack) afterLoad() {
+ s.Stack = stack.StackFromEnv // FIXME(b/36201077)
+ if s.Stack == nil {
+ panic("can't restore without netstack/tcpip/stack.Stack")
+ }
+}
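
The `state:"manual"` tag on Stack.Stack (declared in stack.go below) is what makes this hook necessary: stateify skips the field on save, so afterLoad must reattach it or panic. A standalone sketch of the pattern, with hypothetical names (runtimeHandle and lookupRuntimeHandle are stand-ins, not gVisor APIs):

    package main

    // runtimeHandle stands in for state that cannot be serialized.
    type runtimeHandle struct{}

    func lookupRuntimeHandle() *runtimeHandle { return &runtimeHandle{} }

    type restorable struct {
        // live is skipped by stateify on save because of the tag, so
        // afterLoad must rebuild it on restore.
        live *runtimeHandle `state:"manual"`
    }

    // afterLoad is invoked by stateify once the serialized fields have
    // been restored; it reattaches the manual state or fails loudly.
    func (r *restorable) afterLoad() {
        r.live = lookupRuntimeHandle()
        if r.live == nil {
            panic("cannot restore without a runtime handle")
        }
    }

    func main() {
        r := &restorable{}
        r.afterLoad()
    }
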
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
new file mode 100644
index 000000000..548442b96
--- /dev/null
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -0,0 +1,386 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+// Stack implements inet.Stack for netstack/tcpip/stack.Stack.
+//
+// +stateify savable
+type Stack struct {
+ Stack *stack.Stack `state:"manual"`
+}
+
+// SupportsIPv6 implements inet.Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ is := make(map[int32]inet.Interface)
+ for id, ni := range s.Stack.NICInfo() {
+ var devType uint16
+ if ni.Flags.Loopback {
+ devType = linux.ARPHRD_LOOPBACK
+ }
+ is[int32(id)] = inet.Interface{
+ Name: ni.Name,
+ Addr: []byte(ni.LinkAddress),
+ Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
+ DeviceType: devType,
+ MTU: ni.MTU,
+ }
+ }
+ return is
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ nicAddrs := make(map[int32][]inet.InterfaceAddr)
+ for id, ni := range s.Stack.NICInfo() {
+ var addrs []inet.InterfaceAddr
+ for _, a := range ni.ProtocolAddresses {
+ var family uint8
+ switch a.Protocol {
+ case ipv4.ProtocolNumber:
+ family = linux.AF_INET
+ case ipv6.ProtocolNumber:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in %+v", a)
+ continue
+ }
+
+ addrs = append(addrs, inet.InterfaceAddr{
+ Family: family,
+ PrefixLen: uint8(a.AddressWithPrefix.PrefixLen),
+ Addr: []byte(a.AddressWithPrefix.Address),
+ // TODO(b/68878065): Other fields.
+ })
+ }
+ nicAddrs[int32(id)] = addrs
+ }
+ return nicAddrs
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ var (
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ )
+ switch addr.Family {
+ case linux.AF_INET:
+ if len(addr.Addr) < header.IPv4AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv4AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv4.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
+
+ case linux.AF_INET6:
+ if len(addr.Addr) < header.IPv6AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv6AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv6.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
+
+ default:
+ return syserror.ENOTSUP
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: int(addr.PrefixLen),
+ },
+ }
+
+ // Attach address to interface.
+ if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network.
+ s.Stack.AddRoute(tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: tcpip.NICID(idx),
+ })
+ return nil
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ var rs tcp.ReceiveBufferSizeOption
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs)
+ return inet.TCPBufferSize{
+ Min: rs.Min,
+ Default: rs.Default,
+ Max: rs.Max,
+ }, syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ rs := tcp.ReceiveBufferSizeOption{
+ Min: size.Min,
+ Default: size.Default,
+ Max: size.Max,
+ }
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError()
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ var ss tcp.SendBufferSizeOption
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss)
+ return inet.TCPBufferSize{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }, syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ ss := tcp.SendBufferSizeOption{
+ Min: size.Min,
+ Default: size.Default,
+ Max: size.Max,
+ }
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError()
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ var sack tcp.SACKEnabled
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack)
+ return bool(sack), syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
+}
+
+// Statistics implements inet.Stack.Statistics.
+func (s *Stack) Statistics(stat interface{}, arg string) error {
+ switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
+ case *inet.StatSNMPIP:
+ ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
+ *stats = inet.StatSNMPIP{
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
+ ip.PacketsReceived.Value(), // InReceives.
+ 0, // Ip/InHdrErrors.
+ ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
+ ip.PacketsDelivered.Value(), // InDelivers.
+ ip.PacketsSent.Value(), // OutRequests.
+ ip.OutgoingPacketErrors.Value(), // OutDiscards.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
+ }
+ case *inet.StatSNMPICMP:
+ in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
+ *stats = inet.StatSNMPICMP{
+ 0, // Icmp/InMsgs.
+ Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ 0, // Icmp/InCsumErrors.
+ in.DstUnreachable.Value(), // InDestUnreachs.
+ in.TimeExceeded.Value(), // InTimeExcds.
+ in.ParamProblem.Value(), // InParmProbs.
+ in.SrcQuench.Value(), // InSrcQuenchs.
+ in.Redirect.Value(), // InRedirects.
+ in.Echo.Value(), // InEchos.
+ in.EchoReply.Value(), // InEchoReps.
+ in.Timestamp.Value(), // InTimestamps.
+ in.TimestampReply.Value(), // InTimestampReps.
+ in.InfoRequest.Value(), // InAddrMasks.
+ in.InfoReply.Value(), // InAddrMaskReps.
+ 0, // Icmp/OutMsgs.
+ Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
+ }
+ case *inet.StatSNMPTCP:
+ tcp := Metrics.TCP
+ // RFC 2012 (updates 1213): SNMPv2-MIB-TCP.
+ *stats = inet.StatSNMPTCP{
+ 1, // RtoAlgorithm.
+ 200, // RtoMin.
+ 120000, // RtoMax.
+ (1<<64 - 1), // MaxConn.
+ tcp.ActiveConnectionOpenings.Value(), // ActiveOpens.
+ tcp.PassiveConnectionOpenings.Value(), // PassiveOpens.
+ tcp.FailedConnectionAttempts.Value(), // AttemptFails.
+ tcp.EstablishedResets.Value(), // EstabResets.
+ tcp.CurrentEstablished.Value(), // CurrEstab.
+ tcp.ValidSegmentsReceived.Value(), // InSegs.
+ tcp.SegmentsSent.Value(), // OutSegs.
+ tcp.Retransmits.Value(), // RetransSegs.
+ tcp.InvalidSegmentsReceived.Value(), // InErrs.
+ tcp.ResetsSent.Value(), // OutRsts.
+ tcp.ChecksumErrors.Value(), // InCsumErrors.
+ }
+ case *inet.StatSNMPUDP:
+ udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
+ *stats = inet.StatSNMPUDP{
+ udp.PacketsReceived.Value(), // InDatagrams.
+ udp.UnknownPortErrors.Value(), // NoPorts.
+ 0, // Udp/InErrors.
+ udp.PacketsSent.Value(), // OutDatagrams.
+ udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
+ 0, // Udp/SndbufErrors.
+ udp.ChecksumErrors.Value(), // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
+ }
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+ return nil
+}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ var routeTable []inet.Route
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ var family uint8
+ switch len(rt.Destination.ID()) {
+ case header.IPv4AddressSize:
+ family = linux.AF_INET
+ case header.IPv6AddressSize:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in route %+v", rt)
+ continue
+ }
+
+ routeTable = append(routeTable, inet.Route{
+ Family: family,
+ DstLen: uint8(rt.Destination.Prefix()), // The CIDR prefix for the destination.
+
+ // Always return unspecified protocol since we have no notion of
+ // protocol for routes.
+ Protocol: linux.RTPROT_UNSPEC,
+ // Set statically to LINK scope for now.
+ //
+ // TODO(gvisor.dev/issue/595): Set scope for routes.
+ Scope: linux.RT_SCOPE_LINK,
+ Type: linux.RTN_UNICAST,
+
+ DstAddr: []byte(rt.Destination.ID()),
+ OutputInterface: int32(rt.NIC),
+ GatewayAddr: []byte(rt.Gateway),
+ })
+ }
+
+ return routeTable
+}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() (*stack.IPTables, error) {
+ return s.Stack.IPTables(), nil
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {
+ s.Stack.Resume()
+}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return s.Stack.RegisteredEndpoints()
+}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint {
+ return s.Stack.CleanupEndpoints()
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
+ s.Stack.RestoreCleanupEndpoints(es)
+}
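
The four TCP buffer methods translate between netstack's tcp.{Receive,Send}BufferSizeOption and the {min, default, max} triple that Linux exposes as /proc/sys/net/ipv4/tcp_rmem and tcp_wmem. A sketch of that triple's shape, using the stock Linux tcp_rmem defaults purely for illustration:

    package main

    import "fmt"

    // tcpBufferSize mirrors inet.TCPBufferSize: the {min, default, max}
    // triple behind /proc/sys/net/ipv4/tcp_rmem, which
    // TCPReceiveBufferSize/SetTCPReceiveBufferSize above translate to
    // and from netstack's option type.
    type tcpBufferSize struct {
        Min, Default, Max int
    }

    func main() {
        // Stock Linux defaults, shown for illustration only.
        rmem := tcpBufferSize{Min: 4096, Default: 87380, Max: 6291456}
        fmt.Printf("net.ipv4.tcp_rmem = %d\t%d\t%d\n", rmem.Min, rmem.Default, rmem.Max)
    }
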
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
new file mode 100644
index 000000000..fcd7f9d7f
--- /dev/null
+++ b/pkg/sentry/socket/socket.go
@@ -0,0 +1,461 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package socket provides the interfaces that need to be provided by socket
+// implementations and providers, as well as per family demultiplexing of socket
+// creation.
+package socket
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ControlMessages represents the union of unix control messages and tcpip
+// control messages.
+type ControlMessages struct {
+ Unix transport.ControlMessages
+ IP tcpip.ControlMessages
+}
+
+// Release releases Unix domain socket credentials and rights.
+func (c *ControlMessages) Release() {
+ c.Unix.Release()
+}
+
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
+type Socket interface {
+ fs.FileOperations
+ SocketOps
+}
+
+// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
+ // Connect implements the connect(2) linux syscall.
+ Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
+
+ // Accept implements the accept4(2) linux syscall.
+ // Returns fd, real peer address length and error. The peer address
+ // length is only set if peerRequested is true.
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error)
+
+ // Bind implements the bind(2) linux syscall.
+ Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
+
+ // Listen implements the listen(2) linux syscall.
+ Listen(t *kernel.Task, backlog int) *syserr.Error
+
+ // Shutdown implements the shutdown(2) linux syscall.
+ Shutdown(t *kernel.Task, how int) *syserr.Error
+
+ // GetSockOpt implements the getsockopt(2) linux syscall.
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+
+ // SetSockOpt implements the setsockopt(2) linux syscall.
+ SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
+
+ // GetSockName implements the getsockname(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
+
+ // GetPeerName implements the getpeername(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
+
+ // RecvMsg implements the recvmsg(2) linux syscall.
+ //
+ // senderAddrLen is the address length to be returned to the application,
+ // not necessarily the actual length of the address.
+ //
+ // flags control how RecvMsg should be completed. msgFlags indicate how
+ // the RecvMsg call was completed. Note that control message truncation
+ // may still be required even if the MSG_CTRUNC bit is not set in
+ // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately.
+ //
+ // If err != nil, the recv was not successful.
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
+
+ // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
+ // ownership of the ControlMessage on error.
+ //
+ // If n > 0, err will either be nil or an error from t.Block.
+ SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages ControlMessages) (n int, err *syserr.Error)
+
+ // SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetRecvTimeout(nanoseconds int64)
+
+ // RecvTimeout gets the current timeout (in ns) for recv operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ RecvTimeout() int64
+
+ // SetSendTimeout sets the timeout (in ns) for send operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetSendTimeout(nanoseconds int64)
+
+ // SendTimeout gets the current timeout (in ns) for send operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ SendTimeout() int64
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs. The returned state value is protocol-specific.
+ State() uint32
+
+ // Type returns the family, socket type and protocol of the socket.
+ Type() (family int, skType linux.SockType, protocol int)
+}
+
+// Provider is the interface implemented by providers of sockets for specific
+// address families (e.g., AF_INET).
+type Provider interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
+}
+
+// families holds a map of all known address families and their providers.
+var families = make(map[int][]Provider)
+
+// RegisterProvider registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProvider(family int, provider Provider) {
+ families[family] = append(families[family], provider)
+}
+
+// New creates a new socket with the given family, type and protocol.
+func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ for _, p := range families[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocket(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// Pair creates a new connected socket pair with the given family, type and
+// protocol.
+func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ providers, ok := families[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocket(s1)
+ k.RecordSocket(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
+// NewDirent returns a sockfs fs.Dirent that resides on device d.
+func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
+ ino := d.NextIno()
+ iops := &fsutil.SimpleFileInode{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }, linux.SOCKFS_MAGIC),
+ }
+ inode := fs.NewInode(ctx, iops, fs.NewPseudoMountSource(ctx), fs.StableAttr{
+ Type: fs.Socket,
+ DeviceID: d.DeviceID(),
+ InodeID: ino,
+ BlockSize: usermem.PageSize,
+ })
+
+ // Dirent name matches net/socket.c:sockfs_dname.
+ return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
+}
+
+// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 holds a map of all known address families and their providers.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+ familiesVFS2[family] = append(familiesVFS2[family], provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ for _, p := range familiesVFS2[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocketVFS2(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type and
+// protocol.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ providers, ok := familiesVFS2[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocketVFS2(s1)
+ k.RecordSocketVFS2(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
+// SendReceiveTimeout stores timeouts for send and receive calls.
+//
+// It is meant to be embedded into Socket implementations to help satisfy the
+// interface.
+//
+// Care must be taken when copying SendReceiveTimeout as it contains atomic
+// variables.
+//
+// +stateify savable
+type SendReceiveTimeout struct {
+ // send is the length of the send timeout in nanoseconds.
+ //
+ // send must be accessed atomically.
+ send int64
+
+ // recv is the length of the receive timeout in nanoseconds.
+ //
+ // recv must be accessed atomically.
+ recv int64
+}
+
+// SetRecvTimeout implements Socket.SetRecvTimeout.
+func (to *SendReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.recv, nanoseconds)
+}
+
+// RecvTimeout implements Socket.RecvTimeout.
+func (to *SendReceiveTimeout) RecvTimeout() int64 {
+ return atomic.LoadInt64(&to.recv)
+}
+
+// SetSendTimeout implements Socket.SetSendTimeout.
+func (to *SendReceiveTimeout) SetSendTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.send, nanoseconds)
+}
+
+// SendTimeout implements Socket.SendTimeout.
+func (to *SendReceiveTimeout) SendTimeout() int64 {
+ return atomic.LoadInt64(&to.send)
+}
+
+// GetSockOptEmitUnimplementedEvent emits an unimplemented event if name is
+// valid. It covers names that are valid for GetSockOpt when level is
+// SOL_SOCKET.
+func GetSockOptEmitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_ACCEPTCONN,
+ linux.SO_BPF_EXTENSIONS,
+ linux.SO_COOKIE,
+ linux.SO_DOMAIN,
+ linux.SO_ERROR,
+ linux.SO_GET_FILTER,
+ linux.SO_INCOMING_NAPI_ID,
+ linux.SO_MEMINFO,
+ linux.SO_PEERCRED,
+ linux.SO_PEERGROUPS,
+ linux.SO_PEERNAME,
+ linux.SO_PEERSEC,
+ linux.SO_PROTOCOL,
+ linux.SO_SNDLOWAT,
+ linux.SO_TYPE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEvent(t, name)
+ }
+}
+
+// SetSockOptEmitUnimplementedEvent emits an unimplemented event if name is
+// valid. It covers names that are valid for SetSockOpt when level is
+// SOL_SOCKET.
+func SetSockOptEmitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_ATTACH_BPF,
+ linux.SO_ATTACH_FILTER,
+ linux.SO_ATTACH_REUSEPORT_CBPF,
+ linux.SO_ATTACH_REUSEPORT_EBPF,
+ linux.SO_CNX_ADVICE,
+ linux.SO_DETACH_FILTER,
+ linux.SO_RCVBUFFORCE,
+ linux.SO_SNDBUFFORCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEvent(t, name)
+ }
+}
+
+// emitUnimplementedEvent emits an unimplemented event if name is valid. It
+// covers names that are common between GetSockOpt and SetSockOpt when level
+// is SOL_SOCKET.
+func emitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_BINDTODEVICE,
+ linux.SO_BROADCAST,
+ linux.SO_BSDCOMPAT,
+ linux.SO_BUSY_POLL,
+ linux.SO_DEBUG,
+ linux.SO_DONTROUTE,
+ linux.SO_INCOMING_CPU,
+ linux.SO_KEEPALIVE,
+ linux.SO_LINGER,
+ linux.SO_LOCK_FILTER,
+ linux.SO_MARK,
+ linux.SO_MAX_PACING_RATE,
+ linux.SO_NOFCS,
+ linux.SO_OOBINLINE,
+ linux.SO_PASSCRED,
+ linux.SO_PASSSEC,
+ linux.SO_PEEK_OFF,
+ linux.SO_PRIORITY,
+ linux.SO_RCVBUF,
+ linux.SO_RCVLOWAT,
+ linux.SO_RCVTIMEO,
+ linux.SO_REUSEADDR,
+ linux.SO_REUSEPORT,
+ linux.SO_RXQ_OVFL,
+ linux.SO_SELECT_ERR_QUEUE,
+ linux.SO_SNDBUF,
+ linux.SO_SNDTIMEO,
+ linux.SO_TIMESTAMP,
+ linux.SO_TIMESTAMPING,
+ linux.SO_TIMESTAMPNS,
+ linux.SO_TXTIME,
+ linux.SO_WIFI_STATUS,
+ linux.SO_ZEROCOPY:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// UnmarshalSockAddr unmarshals memory representing a struct sockaddr to one of
+// the ABI socket address types.
+//
+// Precondition: data must be long enough to represent a socket address of the
+// given family.
+func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
+ switch family {
+ case syscall.AF_INET:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(data[:syscall.SizeofSockaddrInet4], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_INET6:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(data[:syscall.SizeofSockaddrInet6], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_UNIX:
+ var addr linux.SockAddrUnix
+ binary.Unmarshal(data[:syscall.SizeofSockaddrUnix], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_NETLINK:
+ var addr linux.SockAddrNetlink
+ binary.Unmarshal(data[:syscall.SizeofSockaddrNetlink], usermem.ByteOrder, &addr)
+ return &addr
+ default:
+ panic(fmt.Sprintf("Unsupported socket family %v", family))
+ }
+}
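
The registry above is a per-family fan-out: RegisterProvider appends, and New walks the list until some provider returns a non-nil socket, treating (nil, nil) as "try the next one". A trimmed-down, self-contained sketch of that dispatch shape (types simplified; the real interfaces take a *kernel.Task and return *fs.File):

    package main

    import "fmt"

    // provider is a stand-in for socket.Provider: returning ("", nil)
    // means "protocol not supported, try the next provider".
    type provider interface {
        socket(stype, protocol int) (string, error)
    }

    var families = map[int][]provider{}

    func register(family int, p provider) {
        families[family] = append(families[family], p)
    }

    // newSocket mirrors the dispatch shape of socket.New above.
    func newSocket(family, stype, protocol int) (string, error) {
        for _, p := range families[family] {
            s, err := p.socket(stype, protocol)
            if err != nil {
                return "", err
            }
            if s != "" {
                return s, nil
            }
        }
        return "", fmt.Errorf("address family %d not supported", family)
    }

    type inetProvider struct{}

    func (inetProvider) socket(stype, protocol int) (string, error) {
        return "inet socket", nil
    }

    func main() {
        register(2 /* AF_INET */, inetProvider{})
        fmt.Println(newSocket(2, 1 /* SOCK_STREAM */, 0))
    }
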
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
new file mode 100644
index 000000000..cca5e70f1
--- /dev/null
+++ b/pkg/sentry/socket/unix/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "unix",
+ srcs = [
+ "device.go",
+ "io.go",
+ "unix.go",
+ "unix_vfs2.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/unix/device.go b/pkg/sentry/socket/unix/device.go
new file mode 100644
index 000000000..db01ac4c9
--- /dev/null
+++ b/pkg/sentry/socket/unix/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// unixSocketDevice is the unix socket virtual device.
+var unixSocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
new file mode 100644
index 000000000..129949990
--- /dev/null
+++ b/pkg/sentry/socket/unix/io.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// EndpointWriter implements safemem.Writer, writing to a transport.Endpoint.
+//
+// EndpointWriter is not thread-safe.
+type EndpointWriter struct {
+ Ctx context.Context
+
+ // Endpoint is the transport.Endpoint to write to.
+ Endpoint transport.Endpoint
+
+ // Control is the control messages to send.
+ Control transport.ControlMessages
+
+ // To is the endpoint to send to. May be nil.
+ To transport.BoundEndpoint
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecWriterFunc{func(bufs [][]byte) (int64, error) {
+ n, err := w.Endpoint.SendMsg(w.Ctx, bufs, w.Control, w.To)
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.WriteFromBlocks(srcs)
+}
+
+// EndpointReader implements safemem.Reader, reading from a
+// transport.Endpoint.
+//
+// EndpointReader is not thread-safe.
+type EndpointReader struct {
+ Ctx context.Context
+
+ // Endpoint is the transport.Endpoint to read from.
+ Endpoint transport.Endpoint
+
+ // Creds indicates if credential control messages are requested.
+ Creds bool
+
+ // NumRights is the number of SCM_RIGHTS FDs requested.
+ NumRights int
+
+ // Peek indicates that the data should not be consumed from the
+ // endpoint.
+ Peek bool
+
+ // MsgSize is the size of the message that was read. For stream
+ // sockets, it is the amount read.
+ MsgSize int64
+
+ // From, if not nil, will be set with the address read from.
+ From *tcpip.FullAddress
+
+ // Control contains the received control messages.
+ Control transport.ControlMessages
+
+ // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on
+ // the value of NumRights.
+ ControlTrunc bool
+}
+
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.ReadToBlocks(dsts)
+}
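
These adapters exist so the unix socket layer can funnel user memory through safemem. A hedged sketch of the typical call shape, based on usermem.IOSequence's CopyInTo/CopyOutFrom hooks (the actual call sites live in unix.go, which is not shown in this hunk):

    package example

    import (
        "gvisor.dev/gvisor/pkg/context"
        "gvisor.dev/gvisor/pkg/sentry/socket/unix"
        "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
        "gvisor.dev/gvisor/pkg/usermem"
    )

    // sendAll sketches the write side: IOSequence.CopyInTo accepts any
    // safemem.Writer, including EndpointWriter.
    func sendAll(ctx context.Context, ep transport.Endpoint, src usermem.IOSequence) (int64, error) {
        w := &unix.EndpointWriter{Ctx: ctx, Endpoint: ep}
        return src.CopyInTo(ctx, w)
    }

    // recvAll mirrors the read side with EndpointReader and CopyOutFrom.
    func recvAll(ctx context.Context, ep transport.Endpoint, dst usermem.IOSequence) (int64, error) {
        r := &unix.EndpointReader{Ctx: ctx, Endpoint: ep}
        return dst.CopyOutFrom(ctx, r)
    }
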
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
new file mode 100644
index 000000000..c708b6030
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "transport_message_list",
+ out = "transport_message_list.go",
+ package = "transport",
+ prefix = "message",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*message",
+ "Linker": "*message",
+ },
+)
+
+go_library(
+ name = "transport",
+ srcs = [
+ "connectioned.go",
+ "connectioned_state.go",
+ "connectionless.go",
+ "queue.go",
+ "transport_message_list.go",
+ "unix.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/ilist",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/waiter",
+ ],
+)
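
go_template_instance stamps out pkg/ilist's generic intrusive list for *message. The generated transport_message_list.go is produced by go_generics, so the sketch below is orientation only: a rough, abridged shape of the type it yields, with the method set trimmed:

    package transport

    // message carries the intrusive links by embedding the entry, so
    // list membership needs no extra allocation.
    type message struct {
        messageEntry
    }

    type messageEntry struct {
        next, prev *message
    }

    type messageList struct {
        head, tail *message
    }

    func (l *messageList) Empty() bool     { return l.head == nil }
    func (l *messageList) Front() *message { return l.head }

    // PushBack links e at the tail in O(1).
    func (l *messageList) PushBack(e *message) {
        e.prev, e.next = l.tail, nil
        if l.tail != nil {
            l.tail.next = e
        } else {
            l.head = e
        }
        l.tail = e
    }
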
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
new file mode 100644
index 000000000..a1e49cc57
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -0,0 +1,486 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// UniqueIDProvider generates a sequence of unique identifiers useful for,
+// among other things, lock ordering.
+type UniqueIDProvider interface {
+ // UniqueID returns a new unique identifier.
+ UniqueID() uint64
+}
+
+// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
+// establish a bidirectional connection with a BoundEndpoint.
+type ConnectingEndpoint interface {
+ // ID returns the endpoint's globally unique identifier. This identifier
+ // must be used to determine locking order if more than one endpoint is
+ // to be locked in the same codepath. The endpoint with the smaller
+ // identifier must be locked before endpoints with larger identifiers.
+ ID() uint64
+
+ // Passcred implements socket.Credentialer.Passcred.
+ Passcred() bool
+
+ // Type returns the socket type, typically either SockStream or
+ // SockSeqpacket. The connection attempt must be aborted if this
+ // value doesn't match the ConnectableEndpoint's type.
+ Type() linux.SockType
+
+ // GetLocalAddress returns the bound path.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Locker protects the following methods. While locked, only the holder of
+ // the lock can change the return value of the protected methods.
+ sync.Locker
+
+ // Connected returns true iff the ConnectingEndpoint is in the connected
+ // state. ConnectingEndpoints can only be connected to a single endpoint,
+ // so the connection attempt must be aborted if this returns true.
+ Connected() bool
+
+ // Listening returns true iff the ConnectingEndpoint is in the listening
+ // state. ConnectingEndpoints cannot make connections while listening, so
+ // the connection attempt must be aborted if this returns true.
+ Listening() bool
+
+ // WaiterQueue returns a pointer to the endpoint's waiter queue.
+ WaiterQueue() *waiter.Queue
+}
+
+// connectionedEndpoint is a Unix-domain connected or connectable endpoint
+// and implements ConnectingEndpoint, BoundEndpoint and Endpoint.
+//
+// connectionedEndpoints must be in connected state in order to transfer data.
+//
+// This implementation includes STREAM and SEQPACKET Unix sockets created with
+// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
+// socketpair(2). See unix_connectionless.go for the implementation of DGRAM
+// Unix sockets created with socket(2).
+//
+// The state is much simpler than a TCP endpoint, so it is not encoded
+// explicitly. Instead we enforce the following invariants:
+//
+// receiver != nil, connected != nil => connected.
+// path != "" && acceptedChan == nil => bound, not listening.
+// path != "" && acceptedChan != nil => bound and listening.
+//
+// Only one of these will be true at any moment.
+//
+// +stateify savable
+type connectionedEndpoint struct {
+ baseEndpoint
+
+ // id is the unique endpoint identifier. This is used exclusively for
+ // lock ordering within connect.
+ id uint64
+
+ // idGenerator is used to generate new unique endpoint identifiers.
+ idGenerator UniqueIDProvider
+
+ // stype is used by connecting sockets to ensure that they are the
+ // same type. The value is typically either tcpip.SockSeqpacket or
+ // tcpip.SockStream.
+ stype linux.SockType
+
+ // acceptedChan is per the TCP endpoint implementation. Note that the
+ // sockets in this channel are _already in the connected state_, and
+ // have another associated connectionedEndpoint.
+ //
+ // If nil, then no listen call has been made.
+ acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
+}
+
+var (
+ _ = BoundEndpoint((*connectionedEndpoint)(nil))
+ _ = Endpoint((*connectionedEndpoint)(nil))
+)
+
+// NewConnectioned creates a new unbound connectionedEndpoint.
+func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
+func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+ a := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+ b := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+
+ q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q1.EnableLeakCheck("transport.queue")
+ q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+ q2.EnableLeakCheck("transport.queue")
+
+ if stype == linux.SOCK_STREAM {
+ a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
+ b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
+ } else {
+ a.receiver = &queueReceiver{q1}
+ b.receiver = &queueReceiver{q2}
+ }
+
+ q2.IncRef()
+ a.connected = &connectedEndpoint{
+ endpoint: b,
+ writeQueue: q2,
+ }
+ q1.IncRef()
+ b.connected = &connectedEndpoint{
+ endpoint: a,
+ writeQueue: q1,
+ }
+
+ return a, b
+}
+
+// NewExternal creates a new externally backed Endpoint. It behaves like a
+// socketpair.
+func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// ID implements ConnectingEndpoint.ID.
+func (e *connectionedEndpoint) ID() uint64 {
+ return e.id
+}
+
+// Type implements ConnectingEndpoint.Type and Endpoint.Type.
+func (e *connectionedEndpoint) Type() linux.SockType {
+ return e.stype
+}
+
+// WaiterQueue implements ConnectingEndpoint.WaiterQueue.
+func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
+ return e.Queue
+}
+
+// isBound returns true iff the connectionedEndpoint is bound (but not
+// listening).
+func (e *connectionedEndpoint) isBound() bool {
+ return e.path != "" && e.acceptedChan == nil
+}
+
+// Listening implements ConnectingEndpoint.Listening.
+func (e *connectionedEndpoint) Listening() bool {
+ return e.acceptedChan != nil
+}
+
+// Close puts the connectionedEndpoint in a closed state and frees all
+// resources associated with it.
+//
+// The socket will be in a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionedEndpoint) Close() {
+ e.Lock()
+ var c ConnectedEndpoint
+ var r Receiver
+ switch {
+ case e.Connected():
+ e.connected.CloseSend()
+ e.receiver.CloseRecv()
+ // Still have unread data? If yes, we flag the write end so that
+ // the peer gets ECONNRESET when it does read.
+ if e.receiver.RecvQueuedSize() > 0 {
+ e.connected.CloseUnread()
+ }
+ c = e.connected
+ r = e.receiver
+ e.connected = nil
+ e.receiver = nil
+ case e.isBound():
+ e.path = ""
+ case e.Listening():
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.Close()
+ }
+ e.acceptedChan = nil
+ e.path = ""
+ }
+ e.Unlock()
+ if c != nil {
+ c.CloseNotify()
+ c.Release()
+ }
+ if r != nil {
+ r.CloseNotify()
+ r.Release()
+ }
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ if ce.Type() != e.stype {
+ return syserr.ErrWrongProtocolForSocket
+ }
+
+ // Check if ce is e to avoid a deadlock.
+ if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Do a dance to safely acquire locks on both endpoints.
+ if e.id < ce.ID() {
+ e.Lock()
+ ce.Lock()
+ } else {
+ ce.Lock()
+ e.Lock()
+ }
+
+ // Check connecting state.
+ if ce.Connected() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Check bound state.
+ if !e.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+
+ // Create a newly bound connectionedEndpoint.
+ ne := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{
+ path: e.path,
+ Queue: &waiter.Queue{},
+ },
+ id: e.idGenerator.UniqueID(),
+ idGenerator: e.idGenerator,
+ stype: e.stype,
+ }
+
+ readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ readQueue.EnableLeakCheck("transport.queue")
+ ne.connected = &connectedEndpoint{
+ endpoint: ce,
+ writeQueue: readQueue,
+ }
+
+ writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ writeQueue.EnableLeakCheck("transport.queue")
+ if e.stype == linux.SOCK_STREAM {
+ ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
+ } else {
+ ne.receiver = &queueReceiver{readQueue: writeQueue}
+ }
+
+ select {
+ case e.acceptedChan <- ne:
+ // Commit state.
+ writeQueue.IncRef()
+ connected := &connectedEndpoint{
+ endpoint: ne,
+ writeQueue: writeQueue,
+ }
+ readQueue.IncRef()
+ if e.stype == linux.SOCK_STREAM {
+ returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
+ } else {
+ returnConnect(&queueReceiver{readQueue: readQueue}, connected)
+ }
+
+ // Notify can deadlock if we are holding these locks.
+ e.Unlock()
+ ce.Unlock()
+
+ // Notify on both ends.
+ e.Notify(waiter.EventIn)
+ ce.WaiterQueue().Notify(waiter.EventOut)
+
+ return nil
+ default:
+ // Busy; return ECONNREFUSED per spec.
+ ne.Close()
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+}
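+
+// The lock "dance" above is the standard fix for ABBA deadlocks, sketched
+// here for illustration: if two endpoints may lock each other in either
+// order, a globally consistent order (ascending unique ID) is imposed.
+//
+//	// Without ordering, concurrent a.Connect(b) and b.Connect(a) can
+//	// deadlock: a holds a's lock waiting for b's, and vice versa.
+//	if e.id < ce.ID() {
+//		e.Lock()
+//		ce.Lock()
+//	} else {
+//		ce.Lock()
+//		e.Lock()
+//	}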
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
+ return nil, syserr.ErrConnectionRefused
+}
+
+// Connect attempts to directly connect to another Endpoint.
+// Implements Endpoint.Connect.
+func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
+ returnConnect := func(r Receiver, ce ConnectedEndpoint) {
+ e.receiver = r
+ e.connected = ce
+ }
+
+ return server.BidirectionalConnect(ctx, e, returnConnect)
+}
+
+// Listen starts listening on the connection.
+func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.Listening() {
+ // Adjust the size of the channel iff we can fit all existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return syserr.ErrInvalidEndpointState
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+ if !e.isBound() {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Normal case.
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ return nil
+}
+
+// Accept accepts a new connection.
+func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+ e.Lock()
+ defer e.Unlock()
+
+ if !e.Listening() {
+ return nil, syserr.ErrInvalidEndpointState
+ }
+
+ select {
+ case ne := <-e.acceptedChan:
+ return ne, nil
+
+ default:
+ // Nothing left.
+ return nil, syserr.ErrWouldBlock
+ }
+}
+
+// Bind binds the connection.
+//
+// For Unix connectionedEndpoints, this _only sets the address associated with
+// the socket_. Work associated with sockets in the filesystem or finding those
+// sockets must be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() || e.Listening() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
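+
+// Putting Bind, Listen and Accept together, a minimal server-side sketch
+// (for illustration only; error handling is elided, and the constructor
+// name is assumed from earlier in this file):
+//
+//	server := NewConnectioned(ctx, linux.SOCK_STREAM, uid)
+//	_ = server.Bind(tcpip.FullAddress{Addr: "/tmp/sock"}, nil)
+//	_ = server.Listen(10)
+//	ep, err := server.Accept() // ErrWouldBlock until a peer connects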
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
+ // Stream sockets do not support specifying the endpoint. Seqpacket
+ // sockets ignore the passed endpoint.
+ if e.stype == linux.SOCK_STREAM && to != nil {
+ return 0, syserr.ErrNotSupported
+ }
+ return e.baseEndpoint.SendMsg(ctx, data, c, to)
+}
+
+// Readiness returns the current readiness of the connectionedEndpoint. For
+// example, if waiter.EventIn is set, the connectionedEndpoint is immediately
+// readable.
+func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ switch {
+ case e.Connected():
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ case e.Listening():
+ if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 {
+ ready |= waiter.EventIn
+ }
+ }
+
+ return ready
+}
+
+// State implements socket.Socket.State.
+func (e *connectionedEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
+ if e.Connected() {
+ return linux.SS_CONNECTED
+ }
+ return linux.SS_UNCONNECTED
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
new file mode 100644
index 000000000..7e02a5db8
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+// saveAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint {
+ // If acceptedChan is nil (i.e. we are not listening) then we will save nil.
+ // Otherwise we create a (possibly empty) slice of the values in acceptedChan and
+ // save that.
+ var acceptedSlice []*connectionedEndpoint
+ if e.acceptedChan != nil {
+ // Swap out acceptedChan with a new empty channel of the same capacity.
+ saveChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan))
+
+ // Create a new slice with the same len and capacity as the channel.
+ acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan))
+ // Drain saveChan into acceptedSlice, and fill up the new
+ // acceptedChan at the same time.
+ for i := range acceptedSlice {
+ ep := <-saveChan
+ acceptedSlice[i] = ep
+ e.acceptedChan <- ep
+ }
+ close(saveChan)
+ }
+ return acceptedSlice
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) {
+ // If acceptedSlice is nil, then acceptedChan should also be nil.
+ if acceptedSlice != nil {
+ // Otherwise, create a new channel with the same capacity as acceptedSlice.
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice))
+ // Seed the channel with values from acceptedSlice.
+ for _, ep := range acceptedSlice {
+ e.acceptedChan <- ep
+ }
+ }
+}
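+
+// For illustration, save and load round-trip: saveAcceptedChan drains the
+// channel into a slice (refilling the live channel as it goes), and
+// loadAcceptedChan rebuilds an equivalent channel from that slice, so a
+// restored listening endpoint keeps both its backlog contents and its
+// capacity:
+//
+//	pending := e.saveAcceptedChan()
+//	e.loadAcceptedChan(pending)
+//	// len(e.acceptedChan) and cap(e.acceptedChan) are unchanged.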
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
new file mode 100644
index 000000000..4b06d63ac
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -0,0 +1,218 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// connectionlessEndpoint is a unix endpoint for unix sockets that support
+// operating in a connectionless fashion.
+//
+// Specifically, this means datagram unix sockets not created with
+// socketpair(2).
+//
+// +stateify savable
+type connectionlessEndpoint struct {
+ baseEndpoint
+}
+
+var (
+ _ = BoundEndpoint((*connectionlessEndpoint)(nil))
+ _ = Endpoint((*connectionlessEndpoint)(nil))
+)
+
+// NewConnectionless creates a new unbound dgram endpoint.
+func NewConnectionless(ctx context.Context) Endpoint {
+ ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
+ q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
+ q.EnableLeakCheck("transport.queue")
+ ep.receiver = &queueReceiver{readQueue: &q}
+ return ep
+}
+
+// isBound returns true iff the endpoint is bound.
+func (e *connectionlessEndpoint) isBound() bool {
+ return e.path != ""
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it.
+func (e *connectionlessEndpoint) Close() {
+ e.Lock()
+ if e.connected != nil {
+ e.connected.Release()
+ e.connected = nil
+ }
+
+ if e.isBound() {
+ e.path = ""
+ }
+
+ e.receiver.CloseRecv()
+ r := e.receiver
+ e.receiver = nil
+ e.Unlock()
+
+ r.CloseNotify()
+ r.Release()
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ return syserr.ErrConnectionRefused
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
+ e.Lock()
+ r := e.receiver
+ e.Unlock()
+ if r == nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ q := r.(*queueReceiver).readQueue
+ if !q.TryIncRef() {
+ return nil, syserr.ErrConnectionRefused
+ }
+ return &connectedEndpoint{
+ endpoint: e,
+ writeQueue: q,
+ }, nil
+}
+
+// SendMsg writes data and a control message to the specified endpoint.
+// This method does not block if the data cannot be written.
+func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
+ if to == nil {
+ return e.baseEndpoint.SendMsg(ctx, data, c, nil)
+ }
+
+ connected, err := to.UnidirectionalConnect(ctx)
+ if err != nil {
+ return 0, syserr.ErrInvalidEndpointState
+ }
+ defer connected.Release()
+
+ e.Lock()
+ n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ connected.SendNotify()
+ }
+
+ return n, err
+}
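+
+// A minimal sendto(2)-style sketch, for illustration (client is assumed to
+// be an Endpoint from NewConnectionless and serverBound a BoundEndpoint):
+//
+//	n, err := client.SendMsg(ctx, [][]byte{payload}, ControlMessages{}, serverBound)
+//	// The datagram is delivered to serverBound's receive queue via a
+//	// transient UnidirectionalConnect that is released after the send.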
+
+// Type implements Endpoint.Type.
+func (e *connectionlessEndpoint) Type() linux.SockType {
+ return linux.SOCK_DGRAM
+}
+
+// Connect attempts to connect directly to server.
+func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
+ connected, err := server.UnidirectionalConnect(ctx)
+ if err != nil {
+ return err
+ }
+
+ e.Lock()
+ if e.connected != nil {
+ e.connected.Release()
+ }
+ e.connected = connected
+ e.Unlock()
+
+ return nil
+}
+
+// Listen starts listening on the connection.
+func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+ return syserr.ErrNotSupported
+}
+
+// Accept accepts a new connection.
+func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+ return nil, syserr.ErrNotSupported
+}
+
+// Bind binds the connection.
+//
+// For Unix endpoints, this _only sets the address associated with the socket_.
+// Work associated with sockets in the filesystem or finding those sockets must
+// be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+
+ if e.Connected() {
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ }
+
+ return ready
+}
+
+// State implements socket.Socket.State.
+func (e *connectionlessEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
+ switch {
+ case e.isBound():
+ return linux.SS_UNCONNECTED
+ case e.Connected():
+ return linux.SS_CONNECTING
+ default:
+ return linux.SS_DISCONNECTING
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
new file mode 100644
index 000000000..d8f3ad63d
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -0,0 +1,247 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// queue is a buffer queue.
+//
+// +stateify savable
+type queue struct {
+ refs.AtomicRefCount
+
+ ReaderQueue *waiter.Queue
+ WriterQueue *waiter.Queue
+
+ mu sync.Mutex `state:"nosave"`
+ closed bool
+ unread bool
+ used int64
+ limit int64
+ dataList messageList
+}
+
+// Close closes q for reading and writing. It is immediately not writable and
+// will become unreadable when no more data is pending.
+//
+// Both the read and write queues must be notified after closing:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Close() {
+ q.mu.Lock()
+ q.closed = true
+ q.mu.Unlock()
+}
+
+// Reset empties the queue and Releases all of the Entries.
+//
+// Both the read and write queues must be notified after resetting:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Reset() {
+ q.mu.Lock()
+ for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release()
+ }
+ q.dataList.Reset()
+ q.used = 0
+ q.mu.Unlock()
+}
+
+// DecRef implements RefCounter.DecRef with destructor q.Reset.
+func (q *queue) DecRef() {
+ q.DecRefWithDestructor(q.Reset)
+ // We don't need to notify after resetting because no one cares about
+ // this queue after all references have been dropped.
+}
+
+// IsReadable determines if q is currently readable.
+func (q *queue) IsReadable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.dataList.Front() != nil
+}
+
+// bufWritable returns true if there is space for writing.
+//
+// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is
+// free.
+//
+// See net/unix/af_unix.c:unix_writeable.
+func (q *queue) bufWritable() bool {
+ return 4*q.used < q.limit
+}
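+
+// Worked example: with limit == initialLimit (16 KiB), the queue stops
+// reporting writable once used reaches 4096 bytes, since 4*4096 == 16384 is
+// no longer strictly less than the limit; writers are then gated until more
+// than 75% of the buffer is free again.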
+
+// IsWritable determines if q is currently writable.
+func (q *queue) IsWritable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.bufWritable()
+}
+
+// Enqueue adds an entry to the data queue if room is available.
+//
+// If discardEmpty is true and there are zero bytes of data, the packet is
+// dropped.
+//
+// If truncate is true, Enqueue may truncate the message before enqueuing it.
+// Otherwise, the entire message must fit. If l is less than the size of data,
+// err indicates why.
+//
+// If notify is true, ReaderQueue.Notify must be called:
+// q.ReaderQueue.Notify(waiter.EventIn)
+func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.closed {
+ q.mu.Unlock()
+ return 0, false, syserr.ErrClosedForSend
+ }
+
+ for _, d := range data {
+ l += int64(len(d))
+ }
+ if discardEmpty && l == 0 {
+ q.mu.Unlock()
+ c.Release()
+ return 0, false, nil
+ }
+
+ free := q.limit - q.used
+
+ if l > free && truncate {
+ if free == 0 {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ l = free
+ err = syserr.ErrWouldBlock
+ }
+
+ if l > q.limit {
+ // Message is too big to ever fit.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrMessageTooLong
+ }
+
+ if l > free {
+ // Message can't fit right now, and could not be truncated.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ // Aggregate l bytes of data. This will truncate the data if l is less than
+ // the total bytes held in data.
+ v := make([]byte, l)
+ for i, b := 0, v; i < len(data) && len(b) > 0; i++ {
+ n := copy(b, data[i])
+ b = b[n:]
+ }
+
+ notify = q.dataList.Front() == nil
+ q.used += l
+ q.dataList.PushBack(&message{
+ Data: buffer.View(v),
+ Control: c,
+ Address: from,
+ })
+
+ q.mu.Unlock()
+
+ return l, notify, err
+}
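+
+// Behavior sketch for Enqueue, for illustration. Assume limit == 16 and
+// used == 12, so free == 4:
+//
+//	Enqueue(20 bytes, truncate=true)  -> l == 4, ErrWouldBlock  // partial write (stream)
+//	Enqueue(20 bytes, truncate=false) -> ErrMessageTooLong      // can never fit (datagram)
+//	Enqueue(8 bytes,  truncate=false) -> ErrWouldBlock          // could fit once drained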
+
+// Dequeue removes the first entry in the data queue, if one exists.
+//
+// If notify is true, WriterQueue.Notify must be called:
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ err = syserr.ErrClosedForReceive
+ if q.unread {
+ err = syserr.ErrConnectionReset
+ }
+ }
+ q.mu.Unlock()
+
+ return nil, false, err
+ }
+
+ notify = !q.bufWritable()
+
+ e = q.dataList.Front()
+ q.dataList.Remove(e)
+ q.used -= e.Length()
+
+ notify = notify && q.bufWritable()
+
+ q.mu.Unlock()
+
+ return e, notify, nil
+}
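+
+// Note that notify means "the queue was not writable before this Dequeue
+// and is writable after it". For example, with limit == 16, a Dequeue that
+// drops used from 12 to 3 flips bufWritable from false (4*12 >= 16) to
+// true (4*3 < 16), so the caller must wake any blocked writers.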
+
+// Peek returns the first entry in the data queue, if one exists.
+func (q *queue) Peek() (*message, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ if err = syserr.ErrClosedForReceive; q.unread {
+ err = syserr.ErrConnectionReset
+ }
+ }
+ return nil, err
+ }
+
+ return q.dataList.Front().Peek(), nil
+}
+
+// QueuedSize returns the number of bytes currently in the queue, that is, the
+// number of readable bytes.
+func (q *queue) QueuedSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ return q.used
+}
+
+// MaxQueueSize returns the maximum number of bytes storable in the queue.
+func (q *queue) MaxQueueSize() int64 {
+ return q.limit
+}
+
+// CloseUnread sets a flag to indicate that the peer closed (rather than shut
+// down) with data still unread, so reads on this queue will return the
+// ECONNRESET error.
+func (q *queue) CloseUnread() {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.unread = true
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
new file mode 100644
index 000000000..2f1b127df
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -0,0 +1,1006 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport contains the implementation of Unix endpoints.
+package transport
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// initialLimit is the starting limit for the socket buffers.
+const initialLimit = 16 * 1024
+
+// A RightsControlMessage is a control message containing FDs.
+type RightsControlMessage interface {
+ // Clone returns a copy of the RightsControlMessage.
+ Clone() RightsControlMessage
+
+ // Release releases any resources owned by the RightsControlMessage.
+ Release()
+}
+
+// A CredentialsControlMessage is a control message containing Unix credentials.
+type CredentialsControlMessage interface {
+ // Equals returns true iff the two messages are equal.
+ Equals(CredentialsControlMessage) bool
+}
+
+// A ControlMessages represents a collection of socket control messages.
+//
+// +stateify savable
+type ControlMessages struct {
+ // Rights is a control message containing FDs.
+ Rights RightsControlMessage
+
+ // Credentials is a control message containing Unix credentials.
+ Credentials CredentialsControlMessage
+}
+
+// Empty returns true iff the ControlMessages does not contain either
+// credentials or rights.
+func (c *ControlMessages) Empty() bool {
+ return c.Rights == nil && c.Credentials == nil
+}
+
+// Clone clones both the credentials and the rights.
+func (c *ControlMessages) Clone() ControlMessages {
+ cm := ControlMessages{}
+ if c.Rights != nil {
+ cm.Rights = c.Rights.Clone()
+ }
+ cm.Credentials = c.Credentials
+ return cm
+}
+
+// Release releases both the credentials and the rights.
+func (c *ControlMessages) Release() {
+ if c.Rights != nil {
+ c.Rights.Release()
+ }
+ *c = ControlMessages{}
+}
+
+// Endpoint is the interface implemented by Unix transport protocol
+// implementations that expose functionality like sendmsg, recvmsg, connect,
+// etc. to Unix socket implementations.
+type Endpoint interface {
+ Credentialer
+ waiter.Waitable
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // RecvMsg reads data and a control message from the endpoint. This method
+ // does not block if there is no data pending.
+ //
+ // creds indicates if credential control messages are requested by the
+ // caller. This is useful for determining if control messages can be
+ // coalesced. creds is a hint and can be safely ignored by the
+ // implementation if no coalescing is possible. It is fine to return
+ // credential control messages when none were requested or to not return
+ // credential control messages when they were requested.
+ //
+ // numRights is the number of SCM_RIGHTS FDs requested by the caller. This
+ // is useful if one must allocate a buffer to receive a SCM_RIGHTS message
+ // or determine if control messages can be coalesced. numRights is a hint
+ // and can be safely ignored by the implementation if the number of
+ // available SCM_RIGHTS FDs is known and no coalescing is possible. It is
+ // fine for the returned number of SCM_RIGHTS FDs to be either higher or
+ // lower than the requested number.
+ //
+ // If peek is true, no data should be consumed from the Endpoint. Any and
+ // all data returned from a peek should be available in the next call to
+ // RecvMsg.
+ //
+ // recvLen is the number of bytes copied into data.
+ //
+ // msgLen is the length of the read message consumed for datagram Endpoints.
+ // msgLen is always the same as recvLen for stream Endpoints.
+ //
+ // CMTruncated indicates that the numRights hint was used to receive fewer
+ // than the total available SCM_RIGHTS FDs. Additional truncation may be
+ // required by the caller.
+ RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+
+ // SendMsg writes data and a control message to the endpoint's peer.
+ // This method does not block if the data cannot be written.
+ //
+ // SendMsg does not take ownership of any of its arguments on error.
+ SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (int64, *syserr.Error)
+
+ // Connect connects this endpoint directly to another.
+ //
+ // This should be called on the client endpoint, and the (bound)
+ // endpoint passed in as a parameter.
+ //
+ // The error codes are the same as Connect.
+ Connect(ctx context.Context, server BoundEndpoint) *syserr.Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags tcpip.ShutdownFlags) *syserr.Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *syserr.Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *syserr.Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ //
+ // An optional commit function will be executed atomically with respect
+ // to binding the endpoint. If this returns an error, the bind will not
+ // occur and the error will be propagated back to the caller.
+ Bind(address tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error
+
+ // Type returns the socket type, typically either SockStream,
+ // SockDgram or SockSeqpacket.
+ Type() linux.SockType
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
+ // types.
+ SetSockOpt(opt interface{}) *tcpip.Error
+
+ // SetSockOptBool sets a socket option for simple cases when a value
+ // has the bool type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // tcpip.*Option types.
+ GetSockOpt(opt interface{}) *tcpip.Error
+
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the bool type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
+ // GetSockOptInt gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs.
+ State() uint32
+}
+
+// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
+// option.
+type Credentialer interface {
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
+ // ConnectedPasscred returns whether or not the SO_PASSCRED socket option
+ // is enabled on the connected end.
+ ConnectedPasscred() bool
+}
+
+// A BoundEndpoint is a unix endpoint that can be connected to.
+type BoundEndpoint interface {
+ // BidirectionalConnect establishes a bi-directional connection between two
+ // unix endpoints in an all-or-nothing manner. If an error occurs during
+ // connecting, the state of neither endpoint should be modified.
+ //
+ // In order for an endpoint to establish such a bidirectional connection
+ // with a BoundEndpoint, the endpoint calls the BidirectionalConnect method
+ // on the BoundEndpoint and sends a representation of itself (the
+ // ConnectingEndpoint) and a callback (returnConnect) to receive the
+ // connection information (Receiver and ConnectedEndpoint) upon a
+ // successful connect. The callback should only be called on a successful
+ // connect.
+ //
+ // For a connection attempt to be successful, the ConnectingEndpoint must
+ // be unconnected and not listening and the BoundEndpoint whose
+ // BidirectionalConnect method is being called must be listening.
+ //
+ // This method will return syserr.ErrConnectionRefused on endpoints with a
+ // type that isn't SockStream or SockSeqpacket.
+ BidirectionalConnect(ctx context.Context, ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error
+
+ // UnidirectionalConnect establishes a write-only connection to a unix
+ // endpoint.
+ //
+ // An endpoint which calls UnidirectionalConnect and supports it itself must
+ // not hold its own lock when calling UnidirectionalConnect.
+ //
+ // This method will return syserr.ErrConnectionRefused on a non-SockDgram
+ // endpoint.
+ UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error)
+
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
+ // Release releases any resources held by the BoundEndpoint. It must be
+ // called before dropping all references to a BoundEndpoint returned by a
+ // function.
+ Release()
+}
+
+// message represents a message passed over a Unix domain socket.
+//
+// +stateify savable
+type message struct {
+ messageEntry
+
+ // Data is the Message payload.
+ Data buffer.View
+
+ // Control is auxiliary control message data that goes along with the
+ // data.
+ Control ControlMessages
+
+ // Address is the bound address of the endpoint that sent the message.
+ //
+ // If the endpoint that sent the message is not bound, Address.Addr
+ // is the empty string.
+ Address tcpip.FullAddress
+}
+
+// Length returns number of bytes stored in the message.
+func (m *message) Length() int64 {
+ return int64(len(m.Data))
+}
+
+// Release releases any resources held by the message.
+func (m *message) Release() {
+ m.Control.Release()
+}
+
+// Peek returns a copy of the message.
+func (m *message) Peek() *message {
+ return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address}
+}
+
+// Truncate reduces the length of the message payload to n bytes.
+//
+// Preconditions: n <= m.Length().
+func (m *message) Truncate(n int64) {
+ m.Data.CapLength(int(n))
+}
+
+// A Receiver can be used to receive Messages.
+type Receiver interface {
+ // Recv receives a single message. This method does not block.
+ //
+ // See Endpoint.RecvMsg for documentation on shared arguments.
+ //
+ // notify indicates if RecvNotify should be called.
+ Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+
+ // RecvNotify notifies the Receiver of a successful Recv. This must not be
+ // called while holding any endpoint locks.
+ RecvNotify()
+
+ // CloseRecv prevents the receiving of additional Messages.
+ //
+ // After CloseRecv is called, CloseNotify must also be called.
+ CloseRecv()
+
+ // CloseNotify notifies the Receiver of recv being closed. This must not be
+ // called while holding any endpoint locks.
+ CloseNotify()
+
+ // Readable returns whether receiving messages should be attempted.
+ // This includes the case where reading has been shut down.
+ Readable() bool
+
+ // RecvQueuedSize returns the total amount of data currently receivable.
+ // RecvQueuedSize should return -1 if the operation isn't supported.
+ RecvQueuedSize() int64
+
+ // RecvMaxQueueSize returns maximum value for RecvQueuedSize.
+ // RecvMaxQueueSize should return -1 if the operation isn't supported.
+ RecvMaxQueueSize() int64
+
+ // Release releases any resources owned by the Receiver. It should be
+ // called before dropping all references to a Receiver.
+ Release()
+}
+
+// queueReceiver implements Receiver for datagram sockets.
+//
+// +stateify savable
+type queueReceiver struct {
+ readQueue *queue
+}
+
+// Recv implements Receiver.Recv.
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ var m *message
+ var notify bool
+ var err *syserr.Error
+ if peek {
+ m, err = q.readQueue.Peek()
+ } else {
+ m, notify, err = q.readQueue.Dequeue()
+ }
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ src := []byte(m.Data)
+ var copied int64
+ for i := 0; i < len(data) && len(src) > 0; i++ {
+ n := copy(data[i], src)
+ copied += int64(n)
+ src = src[n:]
+ }
+ return copied, int64(len(m.Data)), m.Control, false, m.Address, notify, nil
+}
+
+// RecvNotify implements Receiver.RecvNotify.
+func (q *queueReceiver) RecvNotify() {
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseNotify implements Receiver.CloseNotify.
+func (q *queueReceiver) CloseNotify() {
+ q.readQueue.ReaderQueue.Notify(waiter.EventIn)
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseRecv implements Receiver.CloseRecv.
+func (q *queueReceiver) CloseRecv() {
+ q.readQueue.Close()
+}
+
+// Readable implements Receiver.Readable.
+func (q *queueReceiver) Readable() bool {
+ return q.readQueue.IsReadable()
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *queueReceiver) RecvQueuedSize() int64 {
+ return q.readQueue.QueuedSize()
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *queueReceiver) RecvMaxQueueSize() int64 {
+ return q.readQueue.MaxQueueSize()
+}
+
+// Release implements Receiver.Release.
+func (q *queueReceiver) Release() {
+ q.readQueue.DecRef()
+}
+
+// streamQueueReceiver implements Receiver for stream sockets.
+//
+// +stateify savable
+type streamQueueReceiver struct {
+ queueReceiver
+
+ mu sync.Mutex `state:"nosave"`
+ buffer []byte
+ control ControlMessages
+ addr tcpip.FullAddress
+}
+
+func vecCopy(data [][]byte, buf []byte) (int64, [][]byte, []byte) {
+ var copied int64
+ for len(data) > 0 && len(buf) > 0 {
+ n := copy(data[0], buf)
+ copied += int64(n)
+ buf = buf[n:]
+ data[0] = data[0][n:]
+ if len(data[0]) == 0 {
+ data = data[1:]
+ }
+ }
+ return copied, data, buf
+}
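+
+// For illustration: copying a 5-byte buffer into two 2-byte vectors copies
+// 4 bytes and leaves the remainder in buf:
+//
+//	copied, rest, buf := vecCopy([][]byte{make([]byte, 2), make([]byte, 2)}, []byte{1, 2, 3, 4, 5})
+//	// copied == 4, len(rest) == 0, buf == []byte{5}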
+
+// Readable implements Receiver.Readable.
+func (q *streamQueueReceiver) Readable() bool {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ r := q.readQueue.IsReadable()
+ q.mu.Unlock()
+ // We're readable if we have data in our buffer or if the queue receiver is
+ // readable.
+ return bl > 0 || r
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *streamQueueReceiver) RecvQueuedSize() int64 {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ qs := q.readQueue.QueuedSize()
+ q.mu.Unlock()
+ return int64(bl) + qs
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
+ // The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the
+ // largest message we can buffer, which is also the largest message we
+ // can receive.
+ return 2 * q.readQueue.MaxQueueSize()
+}
+
+// Recv implements Receiver.Recv.
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ var notify bool
+
+ // If we have no data in the endpoint, we need to get some.
+ if len(q.buffer) == 0 {
+ // Load the next message into a buffer, even if we are peeking. Peeking
+ // won't consume the message, so it will be still available to be read
+ // the next time Recv() is called.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ notify = n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+ }
+
+ var copied int64
+ if peek {
+ // Don't consume control message if we are peeking.
+ c := q.control.Clone()
+
+ // Don't consume data since we are peeking.
+ copied, data, _ = vecCopy(data, q.buffer)
+
+ return copied, copied, c, false, q.addr, notify, nil
+ }
+
+ // Consume data and control message since we are not peeking.
+ copied, data, q.buffer = vecCopy(data, q.buffer)
+
+ // Save the original state of q.control.
+ c := q.control
+
+ // Remove rights from q.control and leave behind just the creds.
+ q.control.Rights = nil
+ if !wantCreds {
+ c.Credentials = nil
+ }
+
+ var cmTruncated bool
+ if c.Rights != nil && numRights == 0 {
+ c.Rights.Release()
+ c.Rights = nil
+ cmTruncated = true
+ }
+
+ haveRights := c.Rights != nil
+
+ // If we have more capacity for data and haven't received any usable
+ // rights.
+ //
+ // Linux never coalesces rights control messages.
+ for !haveRights && len(data) > 0 {
+ // Get a message from the readQueue.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ // We already got some data, so ignore this error. This will
+ // manifest as a short read to the user, which is what Linux
+ // does.
+ break
+ }
+ notify = notify || n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+
+ if wantCreds {
+ if (q.control.Credentials == nil) != (c.Credentials == nil) {
+ // One message has credentials, the other does not.
+ break
+ }
+
+ if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) {
+ // Both messages have credentials, but they don't match.
+ break
+ }
+ }
+
+ if numRights != 0 && c.Rights != nil && q.control.Rights != nil {
+ // Both messages have rights.
+ break
+ }
+
+ var cpd int64
+ cpd, data, q.buffer = vecCopy(data, q.buffer)
+ copied += cpd
+
+ if cpd == 0 {
+ // data was actually full.
+ break
+ }
+
+ if q.control.Rights != nil {
+ // Consume rights.
+ if numRights == 0 {
+ cmTruncated = true
+ q.control.Rights.Release()
+ } else {
+ c.Rights = q.control.Rights
+ haveRights = true
+ }
+ q.control.Rights = nil
+ }
+ }
+ return copied, copied, c, cmTruncated, q.addr, notify, nil
+}
+
+// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
+type ConnectedEndpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Send sends a single message. This method does not block.
+ //
+ // notify indicates if SendNotify should be called.
+ //
+ // syserr.ErrWouldBlock can be returned along with a partial write if
+ // the caller should block to send the rest of the data.
+ Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
+
+ // SendNotify notifies the ConnectedEndpoint of a successful Send. This
+ // must not be called while holding any endpoint locks.
+ SendNotify()
+
+ // CloseSend prevents the sending of additional Messages.
+ //
+ // After CloseSend is called, CloseNotify must also be called.
+ CloseSend()
+
+ // CloseNotify notifies the ConnectedEndpoint of send being closed. This
+ // must not be called while holding any endpoint locks.
+ CloseNotify()
+
+ // Writable returns whether sending messages should be attempted.
+ // This includes the case where writing has been shut down.
+ Writable() bool
+
+ // EventUpdate lets the ConnectedEndpoint know that event registrations
+ // have changed.
+ EventUpdate()
+
+ // SendQueuedSize returns the total amount of data currently queued for
+ // sending. SendQueuedSize should return -1 if the operation isn't
+ // supported.
+ SendQueuedSize() int64
+
+ // SendMaxQueueSize returns maximum value for SendQueuedSize.
+ // SendMaxQueueSize should return -1 if the operation isn't supported.
+ SendMaxQueueSize() int64
+
+ // Release releases any resources owned by the ConnectedEndpoint. It should
+ // be called before dropping all references to a ConnectedEndpoint.
+ Release()
+
+ // CloseUnread marks this end as closed with data still unread, so
+ // that reads on the peer socket return ECONNRESET.
+ CloseUnread()
+}
+
+// +stateify savable
+type connectedEndpoint struct {
+ // endpoint represents the subset of the Endpoint functionality needed by
+ // the connectedEndpoint. It is implemented by both connectionedEndpoint
+ // and connectionlessEndpoint and allows the use of types which don't
+ // fully implement Endpoint.
+ endpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Type implements Endpoint.Type.
+ Type() linux.SockType
+ }
+
+ writeQueue *queue
+}
+
+// Passcred implements ConnectedEndpoint.Passcred.
+func (e *connectedEndpoint) Passcred() bool {
+ return e.endpoint.Passcred()
+}
+
+// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
+func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return e.endpoint.GetLocalAddress()
+}
+
+// Send implements ConnectedEndpoint.Send.
+func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ discardEmpty := false
+ truncate := false
+ if e.endpoint.Type() == linux.SOCK_STREAM {
+ // Discard empty stream packets. Since stream sockets don't
+ // preserve message boundaries, sending zero bytes is a no-op.
+ // In Linux, the receiver actually uses a zero-length receive
+ // as an indication that the stream was closed.
+ discardEmpty = true
+
+ // Since stream sockets don't preserve message boundaries, we
+ // can write only as much of the message as fits in the queue.
+ truncate = true
+ }
+
+ return e.writeQueue.Enqueue(data, c, from, discardEmpty, truncate)
+}
+
+// SendNotify implements ConnectedEndpoint.SendNotify.
+func (e *connectedEndpoint) SendNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+}
+
+// CloseNotify implements ConnectedEndpoint.CloseNotify.
+func (e *connectedEndpoint) CloseNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+ e.writeQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseSend implements ConnectedEndpoint.CloseSend.
+func (e *connectedEndpoint) CloseSend() {
+ e.writeQueue.Close()
+}
+
+// Writable implements ConnectedEndpoint.Writable.
+func (e *connectedEndpoint) Writable() bool {
+ return e.writeQueue.IsWritable()
+}
+
+// EventUpdate implements ConnectedEndpoint.EventUpdate.
+func (*connectedEndpoint) EventUpdate() {}
+
+// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize.
+func (e *connectedEndpoint) SendQueuedSize() int64 {
+ return e.writeQueue.QueuedSize()
+}
+
+// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize.
+func (e *connectedEndpoint) SendMaxQueueSize() int64 {
+ return e.writeQueue.MaxQueueSize()
+}
+
+// Release implements ConnectedEndpoint.Release.
+func (e *connectedEndpoint) Release() {
+ e.writeQueue.DecRef()
+}
+
+// CloseUnread implements ConnectedEndpoint.CloseUnread.
+func (e *connectedEndpoint) CloseUnread() {
+ e.writeQueue.CloseUnread()
+}
+
+// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
+// unix domain socket Endpoint implementations.
+//
+// Not to be used on its own.
+//
+// +stateify savable
+type baseEndpoint struct {
+ *waiter.Queue
+
+ // passcred specifies whether SCM_CREDENTIALS socket control messages are
+ // enabled on this endpoint. Must be accessed atomically.
+ passcred int32
+
+ // Mutex protects the below fields.
+ sync.Mutex `state:"nosave"`
+
+ // receiver allows Messages to be received.
+ receiver Receiver
+
+ // connected allows messages to be sent and state information about the
+ // connected endpoint to be read.
+ connected ConnectedEndpoint
+
+ // path is not empty if the endpoint has been bound. It may also be
+ // set if the endpoint is connected.
+ path string
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ e.Queue.EventRegister(we, mask)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
+ e.Queue.EventUnregister(we)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// Passcred implements Credentialer.Passcred.
+func (e *baseEndpoint) Passcred() bool {
+ return atomic.LoadInt32(&e.passcred) != 0
+}
+
+// ConnectedPasscred implements Credentialer.ConnectedPasscred.
+func (e *baseEndpoint) ConnectedPasscred() bool {
+ e.Lock()
+ defer e.Unlock()
+ return e.connected != nil && e.connected.Passcred()
+}
+
+func (e *baseEndpoint) setPasscred(pc bool) {
+ if pc {
+ atomic.StoreInt32(&e.passcred, 1)
+ } else {
+ atomic.StoreInt32(&e.passcred, 0)
+ }
+}
+
+// Connected implements ConnectingEndpoint.Connected.
+func (e *baseEndpoint) Connected() bool {
+ return e.receiver != nil && e.connected != nil
+}
+
+// RecvMsg reads data and a control message from the endpoint.
+func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (int64, int64, ControlMessages, bool, *syserr.Error) {
+ e.Lock()
+
+ if e.receiver == nil {
+ e.Unlock()
+ return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
+ }
+
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ e.Unlock()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, err
+ }
+
+ if notify {
+ e.receiver.RecvNotify()
+ }
+
+ if addr != nil {
+ *addr = a
+ }
+ return recvLen, msgLen, cms, cmt, nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return 0, syserr.ErrNotConnected
+ }
+ if to != nil {
+ e.Unlock()
+ return 0, syserr.ErrAlreadyConnected
+ }
+
+ n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ e.connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ case tcpip.PasscredOption:
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
+ return nil
+}
+
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
+ return nil
+}
+
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v = int(e.receiver.RecvQueuedSize())
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return v, nil
+
+ case tcpip.SendQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v := e.connected.SendQueuedSize()
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ case tcpip.SendBufferSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v := e.connected.SendMaxQueueSize()
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.Lock()
+ if e.receiver == nil {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v := e.receiver.RecvMaxQueueSize()
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ default:
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return syserr.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseRecv()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseSend()
+ }
+
+ e.Unlock()
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseNotify()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseNotify()
+ }
+
+ return nil
+}
+
+// GetLocalAddress returns the bound path.
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ defer e.Unlock()
+ return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
+}
+
+// GetRemoteAddress returns the local address of the connected endpoint (if
+// available).
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ c := e.connected
+ e.Unlock()
+ if c != nil {
+ return c.GetLocalAddress()
+ }
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Release implements BoundEndpoint.Release.
+func (*baseEndpoint) Release() {
+ // Binding a baseEndpoint doesn't take a reference.
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
new file mode 100644
index 000000000..4bb2b6ff4
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix.go
@@ -0,0 +1,772 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unix provides an implementation of the socket.Socket interface for
+// the AF_UNIX protocol family.
+package unix
+
+import (
+ "fmt"
+ "strings"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketOperations is a Unix socket. It is similar to a netstack socket,
+// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// New creates a new unix socket.
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
+ dirent := socket.NewDirent(ctx, unixSocketDevice)
+ defer dirent.DecRef()
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true})
+}
+
+// NewWithDirent creates a new unix socket using an existing dirent.
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ s := SocketOperations{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+ s.EnableLeakCheck("unix.SocketOperations")
+
+ return fs.NewFile(ctx, d, flags, &s)
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+}
+
+// DecRef implements RefCounter.DecRef.
+func (s *socketOpsCommon) DecRef() {
+ s.DecRefWithDestructor(func() {
+ s.ep.Close()
+ })
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+func (s *socketOpsCommon) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
+// Endpoint extracts the transport.Endpoint.
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
+ return s.ep
+}
+
+// extractPath extracts and validates the address.
+func extractPath(sockaddr []byte) (string, *syserr.Error) {
+ addr, family, err := netstack.AddressAndFamily(sockaddr)
+ if err != nil {
+ if err == syserr.ErrAddressFamilyNotSupported {
+ err = syserr.ErrInvalidArgument
+ }
+ return "", err
+ }
+ if family != linux.AF_UNIX {
+ return "", syserr.ErrInvalidArgument
+ }
+
+ // The address is trimmed by AddressAndFamily.
+ p := string(addr.Addr)
+ if p == "" {
+ // Not allowed.
+ return "", syserr.ErrInvalidArgument
+ }
+ if p[len(p)-1] == '/' {
+ // Weird, they tried to bind '/a/b/c/'?
+ return "", syserr.ErrIsDir
+ }
+
+ return p, nil
+}
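+
+// For illustration, the address forms extractPath accepts and rejects:
+//
+//	"/tmp/sock"     -> ok (filesystem path)
+//	"\x00abstract"  -> ok (abstract namespace; see the p[0] == 0 check in Bind)
+//	""              -> ErrInvalidArgument
+//	"/tmp/dir/"     -> ErrIsDir (trailing slash)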
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.ep.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.ep.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, io, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return s.ep.Listen(backlog)
+}
+
+// blockingAccept implements a blocking version of accept(2); that is, if no
+// connections are ready to be accepted, it will block until one becomes ready.
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
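This register/try/block shape recurs throughout the sentry: attempt the
non-blocking operation, and if it would block, sleep on the waiter channel and
retry. A self-contained analogue using a plain channel as the notification
source (all names here are hypothetical, not part of this change):

    package main

    import (
        "errors"
        "fmt"
        "time"
    )

    var errWouldBlock = errors.New("would block")

    // blockingTry retries op until it stops returning errWouldBlock,
    // sleeping on notify between attempts, the same shape as
    // blockingAccept above.
    func blockingTry(op func() error, notify <-chan struct{}) error {
        for {
            if err := op(); err != errWouldBlock {
                return err
            }
            <-notify
        }
    }

    func main() {
        notify := make(chan struct{}, 1)
        attempts := 0
        op := func() error {
            attempts++
            if attempts < 3 {
                return errWouldBlock
            }
            return nil
        }
        go func() {
            for i := 0; i < 2; i++ {
                time.Sleep(10 * time.Millisecond)
                notify <- struct{}{}
            }
        }()
        fmt.Println(blockingTry(op, notify), "after", attempts, "attempts")
    }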
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns := New(t, ep, s.stype)
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocket(ns)
+
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ // The parent and name.
+ var d *fs.Dirent
+ var name string
+
+ cwd := t.FSContext().WorkingDirectory()
+ defer cwd.DecRef()
+
+ // Is there no slash at all?
+ if !strings.Contains(p, "/") {
+ d = cwd
+ name = p
+ } else {
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+ // Find the last path component, we know that something follows
+ // that final slash, otherwise extractPath() would have failed.
+ lastSlash := strings.LastIndex(p, "/")
+ subPath := p[:lastSlash]
+ if subPath == "" {
+ // Fix up subpath in case file is in root.
+ subPath = "/"
+ }
+ var err error
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, err = t.MountNamespace().FindInode(t, root, cwd, subPath, &remainingTraversals)
+ if err != nil {
+ // No path available.
+ return syserr.ErrNoSuchFile
+ }
+ defer d.DecRef()
+ name = p[lastSlash+1:]
+ }
+
+ // Create the socket.
+ //
+ // Note that the file permissions here are not set correctly (see
+ // gvisor.dev/issue/2324). There is no convenient way to get permissions
+ // on the socket referred to by s, so we will leave this discrepancy
+ // unresolved until VFS2 replaces this code.
+ childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
+ if err != nil {
+ return syserr.ErrPortInUse
+ }
+ childDir.DecRef()
+ }
+
+ return nil
+ })
+}
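The two branches above mirror what Linux applications observe: an address whose
first byte is NUL lands in the per-namespace abstract table, while any other
address creates a socket inode at that path. From the application side the
split looks like this (standard library, Linux only; the '@' prefix is Go's
spelling of the leading NUL byte):

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        // Abstract namespace: no inode is created; the name lives in a
        // kernel-side table and vanishes when the listener closes.
        la, err := net.Listen("unix", "@demo-abstract")
        fmt.Println(la, err)

        // Filesystem namespace: this creates a socket inode at the path.
        lf, err := net.Listen("unix", "/tmp/demo.sock")
        fmt.Println(lf, err)

        if la != nil {
            la.Close()
        }
        if lf != nil {
            lf.Close() // also unlinks /tmp/demo.sock
        }
    }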
+
+// extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix
+// socket path. Release must be called on the returned transport.BoundEndpoint
+// when the caller is done with it.
+func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) {
+ path, err := extractPath(sockaddr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Is it abstract?
+ if path[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ ep := t.AbstractSockets().BoundEndpoint(path[1:])
+ if ep == nil {
+ // No socket found.
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ return ep, nil
+ }
+
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
+ root.DecRef()
+ if relPath {
+ start.DecRef()
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
+ // Find the node in the filesystem.
+ root := t.FSContext().RootDirectory()
+ cwd := t.FSContext().WorkingDirectory()
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals)
+ cwd.DecRef()
+ root.DecRef()
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+
+ // Extract the endpoint if one is there.
+ ep := d.Inode.BoundEndpoint(path)
+ d.DecRef()
+ if ep == nil {
+ // No socket!
+ return nil, syserr.ErrConnectionRefused
+ }
+ return ep, nil
+}
+
+// Connect implements the linux syscall connect(2) for unix sockets.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ ep, err := extractEndpoint(t, sockaddr)
+ if err != nil {
+ return err
+ }
+ defer ep.Release()
+
+ // Connect the server endpoint.
+ err = s.ep.Connect(t, ep)
+
+ if err == syserr.ErrWrongProtocolForSocket {
+ // Linux for abstract sockets returns ErrConnectionRefused
+ // instead of ErrWrongProtocolForSocket.
+ path, _ := extractPath(sockaddr)
+ if len(path) > 0 && path[0] == 0 {
+ err = syserr.ErrConnectionRefused
+ }
+ }
+
+ return err
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ w := EndpointWriter{
+ Ctx: t,
+ Endpoint: s.ep,
+ Control: controlMessages.Unix,
+ To: nil,
+ }
+ if len(to) > 0 {
+ switch s.stype {
+ case linux.SOCK_SEQPACKET:
+ to = nil
+ case linux.SOCK_STREAM:
+ if s.State() == linux.SS_CONNECTED {
+ return 0, syserr.ErrAlreadyConnected
+ }
+ return 0, syserr.ErrNotSupported
+ default:
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release()
+ w.To = ep
+
+ if ep.Passcred() && w.Control.Credentials == nil {
+ w.Control.Credentials = control.MakeCreds(t)
+ }
+ }
+ }
+
+ n, err := src.CopyInTo(t, &w)
+ if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ n, err = src.CopyInTo(t, &w)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ return int(total), syserr.FromError(err)
+}
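The loop above is the standard blocking partial-write pattern: count what the
endpoint accepted, drop it from the source, and retry until everything is sent
or a real error occurs. Stripped of the waiter machinery, it reduces to this
hypothetical sketch over an io.Writer:

    package main

    import (
        "bytes"
        "fmt"
        "io"
    )

    // writeAll keeps writing src, dropping bytes already accepted, until
    // everything is written or a real error occurs, the same bookkeeping
    // as the SendMsg loop above.
    func writeAll(w io.Writer, src []byte) (int, error) {
        total := 0
        for len(src) > 0 {
            n, err := w.Write(src)
            total += n
            src = src[n:] // analogous to src.DropFirst64(n)
            if err != nil {
                return total, err
            }
        }
        return total, nil
    }

    func main() {
        var buf bytes.Buffer
        n, err := writeAll(&buf, []byte("hello unix sockets"))
        fmt.Println(n, err, buf.String())
    }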
+
+// Passcred implements transport.Credentialer.Passcred.
+func (s *socketOpsCommon) Passcred() bool {
+ return s.ep.Passcred()
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *socketOpsCommon) ConnectedPasscred() bool {
+ return s.ep.ConnectedPasscred()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.ep.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := netstack.ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return s.ep.Shutdown(f)
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
+
+ // Calculate the number of FDs for which we have space, and whether we
+ // are requesting credentials.
+ var wantCreds bool
+ rightsLen := int(controlDataLen) - syscall.SizeofCmsghdr
+ if s.Passcred() {
+ // Credentials take priority if they are enabled and there is space.
+ wantCreds = rightsLen > 0
+ if !wantCreds {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+ credLen := syscall.CmsgSpace(syscall.SizeofUcred)
+ rightsLen -= credLen
+ }
+ // FDs are 32 bit (4 byte) ints.
+ numRights := rightsLen / 4
+ if numRights < 0 {
+ numRights = 0
+ }
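+
+ // Worked example (hypothetical numbers, amd64): controlDataLen = 64
+ // gives rightsLen = 64 - 16 (SizeofCmsghdr) - 32 (CmsgSpace of a
+ // ucred) = 16, i.e. room for numRights = 4 descriptors.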
+
+ r := EndpointReader{
+ Ctx: t,
+ Endpoint: s.ep,
+ Creds: wantCreds,
+ NumRights: numRights,
+ Peek: peek,
+ }
+ if senderRequested {
+ r.From = &tcpip.FullAddress{}
+ }
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+ }
+
+ var total int64
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
+ var from linux.SockAddr
+ var fromLen uint32
+ if r.From != nil && len([]byte(r.From.Addr)) != 0 {
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+
+ return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ total += n
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // receive all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
+ var from linux.SockAddr
+ var fromLen uint32
+ if r.From != nil {
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if trunc {
+ // n and r.MsgSize are the same for streams.
+ total += int64(r.MsgSize)
+ } else {
+ total += n
+ }
+
+ streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
+ if total > 0 {
+ err = nil
+ }
+ if isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if total > 0 {
+ err = nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// State implements socket.Socket.State.
+func (s *socketOpsCommon) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
+// provider is a unix domain socket provider.
+type provider struct{}
+
+// Socket returns a new unix domain socket.
+func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return New(t, ep, stype), nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
+
+ return s1, s2, nil
+}
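Pair is what ultimately services socketpair(2) from a sandboxed application.
For reference, exercising it from a Go program uses the ordinary host API
(standard syscall package, Linux):

    package main

    import (
        "fmt"
        "syscall"
    )

    func main() {
        // socketpair(2) is the syscall that reaches provider.Pair above.
        fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
        if err != nil {
            fmt.Println(err)
            return
        }
        defer syscall.Close(fds[0])
        defer syscall.Close(fds[1])

        if _, err := syscall.Write(fds[0], []byte("ping")); err != nil {
            fmt.Println(err)
            return
        }
        buf := make([]byte, 4)
        n, _ := syscall.Read(fds[1], buf)
        fmt.Println(string(buf[:n])) // ping
    }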
+
+func init() {
+ socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
+}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..ff2149250
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,371 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&SocketVFS2{})
+
+// NewSockfsFile creates a new socket file in the global sockfs mount and
+// returns a corresponding file description.
+func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return fd, nil
+}
+
+// NewFileDescription creates and returns a socket file description
+// corresponding to the given mount and dentry.
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ sock := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+ sock.LockFD.Init(locks)
+ vfsfd := &sock.vfsfd
+ if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2); that is, if no
+// connections are ready to be accepted, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := NewSockfsFile(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE})
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // File permissions correspond to net/unix/af_unix.c:unix_bind.
+ Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewSockfsFile(t, ep, stype)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewSockfsFile(t, ep1, stype)
+ if err != nil {
+ ep1.Close()
+ ep2.Close()
+ return nil, nil, err
+ }
+ s2, err := NewSockfsFile(t, ep2, stype)
+ if err != nil {
+ s1.DecRef()
+ ep2.Close()
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
new file mode 100644
index 000000000..0ea4aab8b
--- /dev/null
+++ b/pkg/sentry/state/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "state",
+ srcs = [
+ "state.go",
+ "state_metadata.go",
+ "state_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/log",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/time",
+ "//pkg/sentry/watchdog",
+ "//pkg/state/statefile",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
new file mode 100644
index 000000000..9eb626b76
--- /dev/null
+++ b/pkg/sentry/state/state.go
@@ -0,0 +1,119 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides high-level state wrappers.
+package state
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/state/statefile"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var previousMetadata map[string]string
+
+// ErrStateFile is returned when an error is encountered writing the statefile
+// (which may occur during open or close calls in addition to write).
+type ErrStateFile struct {
+ err error
+}
+
+// Error implements error.Error().
+func (e ErrStateFile) Error() string {
+ return fmt.Sprintf("statefile error: %v", e.err)
+}
+
+// SaveOpts contains save-related options.
+type SaveOpts struct {
+ // Destination is the save target.
+ Destination io.Writer
+
+ // Key is used for state integrity check.
+ Key []byte
+
+ // Metadata is save metadata.
+ Metadata map[string]string
+
+ // Callback is called prior to unpause, with any save error.
+ Callback func(err error)
+}
+
+// Save saves the system state.
+func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
+ log.Infof("Sandbox save started, pausing all tasks.")
+ k.Pause()
+ defer k.Unpause()
+ defer log.Infof("Tasks resumed after save.")
+
+ w.Stop()
+ defer w.Start()
+
+ // Supplement the metadata.
+ if opts.Metadata == nil {
+ opts.Metadata = make(map[string]string)
+ }
+ addSaveMetadata(opts.Metadata)
+
+ // Open the statefile.
+ wc, err := statefile.NewWriter(opts.Destination, opts.Key, opts.Metadata)
+ if err != nil {
+ err = ErrStateFile{err}
+ } else {
+ // Save the kernel.
+ err = k.SaveTo(wc)
+
+ // ENOSPC is a state file error. This error can only come from
+ // writing the state file, and not from fs.FileOperations.Fsync
+ // because we wrap those in kernel.TaskSet.flushWritesToFiles.
+ if err == syserror.ENOSPC {
+ err = ErrStateFile{err}
+ }
+
+ if closeErr := wc.Close(); err == nil && closeErr != nil {
+ err = ErrStateFile{closeErr}
+ }
+ }
+ opts.Callback(err)
+ return err
+}
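Note the error precedence: a write error wins, but a failed Close is still
surfaced when the write itself succeeded, and either way the caller sees an
ErrStateFile. The pattern in isolation (hypothetical standalone sketch):

    package main

    import (
        "fmt"
        "io"
        "os"
    )

    // writeThenClose mirrors Save's error precedence: a write error wins,
    // but a close error is still reported when the write succeeded.
    func writeThenClose(wc io.WriteCloser, data []byte) error {
        _, err := wc.Write(data)
        if closeErr := wc.Close(); err == nil && closeErr != nil {
            err = closeErr
        }
        return err
    }

    func main() {
        f, err := os.CreateTemp("", "statefile-*")
        if err != nil {
            fmt.Println(err)
            return
        }
        defer os.Remove(f.Name())
        fmt.Println(writeThenClose(f, []byte("kernel state"))) // <nil>
    }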
+
+// LoadOpts contains load-related options.
+type LoadOpts struct {
+ // Source is the load source.
+ Source io.Reader
+
+ // Key is used for state integrity check.
+ Key []byte
+}
+
+// Load loads the given kernel, setting the provided network stack and clocks.
+func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error {
+ // Open the file.
+ r, m, err := statefile.NewReader(opts.Source, opts.Key)
+ if err != nil {
+ return ErrStateFile{err}
+ }
+
+ previousMetadata = m
+
+ // Restore the Kernel object graph.
+ return k.LoadFrom(r, n, clocks)
+}
diff --git a/pkg/sentry/state/state_metadata.go b/pkg/sentry/state/state_metadata.go
new file mode 100644
index 000000000..cefd20b9b
--- /dev/null
+++ b/pkg/sentry/state/state_metadata.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// The save metadata keys.
+const (
+ cpuUsage = "cpu_usage"
+ metadataTimestamp = "timestamp"
+)
+
+func addSaveMetadata(m map[string]string) {
+ t, err := CPUTime()
+ if err != nil {
+ log.Warningf("Error getting cpu time: %v", err)
+ }
+ if previousMetadata != nil {
+ p, err := time.ParseDuration(previousMetadata[cpuUsage])
+ if err != nil {
+ log.Warningf("Error parsing previous runs' cpu time: %v", err)
+ }
+ t += p
+ }
+ m[cpuUsage] = t.String()
+
+ m[metadataTimestamp] = fmt.Sprintf("%v", time.Now())
+}
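Because cpu_usage is stored as a Duration string, successive saves accumulate:
each save parses the figure recorded by the previous run and adds its own.
Round-trip sketch with hypothetical values:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // What addSaveMetadata does across two runs, in miniature.
        m := map[string]string{"cpu_usage": (1500 * time.Millisecond).String()}
        prev, err := time.ParseDuration(m["cpu_usage"])
        fmt.Println(prev, err) // 1.5s <nil>
        m["cpu_usage"] = (prev + 2*time.Second).String()
        fmt.Println(m["cpu_usage"]) // 3.5s
    }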
diff --git a/pkg/sentry/state/state_unsafe.go b/pkg/sentry/state/state_unsafe.go
new file mode 100644
index 000000000..d271c6fc9
--- /dev/null
+++ b/pkg/sentry/state/state_unsafe.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// CPUTime returns the CPU time usage by Sentry and app.
+func CPUTime() (time.Duration, error) {
+ var ts syscall.Timespec
+ _, _, errno := syscall.RawSyscall(syscall.SYS_CLOCK_GETTIME, uintptr(linux.CLOCK_PROCESS_CPUTIME_ID), uintptr(unsafe.Pointer(&ts)), 0)
+ if errno != 0 {
+ return 0, fmt.Errorf("failed calling clock_gettime(CLOCK_PROCESS_CPUTIME_ID): errno=%d", errno)
+ }
+ return time.Duration(ts.Nano()), nil
+}
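Timespec.Nano folds the two fields into a single nanosecond count (sec*1e9 +
nsec), which converts directly to a time.Duration. The same arithmetic on
plain values:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // Equivalent of syscall.Timespec{Sec: 2, Nsec: 250000000}.Nano().
        sec, nsec := int64(2), int64(250000000)
        d := time.Duration(sec*1e9 + nsec)
        fmt.Println(d) // 2.25s
    }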
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
new file mode 100644
index 000000000..88d5db9fc
--- /dev/null
+++ b/pkg/sentry/strace/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "strace",
+ srcs = [
+ "capability.go",
+ "clone.go",
+ "epoll.go",
+ "futex.go",
+ "linux64_amd64.go",
+ "linux64_arm64.go",
+ "open.go",
+ "poll.go",
+ "ptrace.go",
+ "select.go",
+ "signal.go",
+ "socket.go",
+ "strace.go",
+ "syscalls.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ ":strace_go_proto",
+ "//pkg/abi",
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/bits",
+ "//pkg/eventchannel",
+ "//pkg/seccomp",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/syscalls/linux",
+ "//pkg/usermem",
+ ],
+)
+
+proto_library(
+ name = "strace",
+ srcs = ["strace.proto"],
+ visibility = ["//visibility:public"],
+)
diff --git a/pkg/sentry/strace/capability.go b/pkg/sentry/strace/capability.go
new file mode 100644
index 000000000..3255dc18d
--- /dev/null
+++ b/pkg/sentry/strace/capability.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// CapabilityBitset is the set of capabilities in a bitset.
+var CapabilityBitset = abi.FlagSet{
+ {
+ Flag: 1 << uint32(linux.CAP_CHOWN),
+ Name: "CAP_CHOWN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_DAC_OVERRIDE),
+ Name: "CAP_DAC_OVERRIDE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_DAC_READ_SEARCH),
+ Name: "CAP_DAC_READ_SEARCH",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_FOWNER),
+ Name: "CAP_FOWNER",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_FSETID),
+ Name: "CAP_FSETID",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_KILL),
+ Name: "CAP_KILL",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETGID),
+ Name: "CAP_SETGID",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETUID),
+ Name: "CAP_SETUID",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETPCAP),
+ Name: "CAP_SETPCAP",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_LINUX_IMMUTABLE),
+ Name: "CAP_LINUX_IMMUTABLE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_BIND_SERVICE),
+ Name: "CAP_NET_BIND_SERVICE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_BROADCAST),
+ Name: "CAP_NET_BROADCAST",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_ADMIN),
+ Name: "CAP_NET_ADMIN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_RAW),
+ Name: "CAP_NET_RAW",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_IPC_LOCK),
+ Name: "CAP_IPC_LOCK",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_IPC_OWNER),
+ Name: "CAP_IPC_OWNER",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_MODULE),
+ Name: "CAP_SYS_MODULE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_RAWIO),
+ Name: "CAP_SYS_RAWIO",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_CHROOT),
+ Name: "CAP_SYS_CHROOT",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_PTRACE),
+ Name: "CAP_SYS_PTRACE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_PACCT),
+ Name: "CAP_SYS_PACCT",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_ADMIN),
+ Name: "CAP_SYS_ADMIN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_BOOT),
+ Name: "CAP_SYS_BOOT",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_NICE),
+ Name: "CAP_SYS_NICE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_RESOURCE),
+ Name: "CAP_SYS_RESOURCE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_TIME),
+ Name: "CAP_SYS_TIME",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_TTY_CONFIG),
+ Name: "CAP_SYS_TTY_CONFIG",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_MKNOD),
+ Name: "CAP_MKNOD",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_LEASE),
+ Name: "CAP_LEASE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_AUDIT_WRITE),
+ Name: "CAP_AUDIT_WRITE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_AUDIT_CONTROL),
+ Name: "CAP_AUDIT_CONTROL",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETFCAP),
+ Name: "CAP_SETFCAP",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_MAC_OVERRIDE),
+ Name: "CAP_MAC_OVERRIDE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_MAC_ADMIN),
+ Name: "CAP_MAC_ADMIN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYSLOG),
+ Name: "CAP_SYSLOG",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_WAKE_ALARM),
+ Name: "CAP_WAKE_ALARM",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_BLOCK_SUSPEND),
+ Name: "CAP_BLOCK_SUSPEND",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_AUDIT_READ),
+ Name: "CAP_AUDIT_READ",
+ },
+}
diff --git a/pkg/sentry/strace/clone.go b/pkg/sentry/strace/clone.go
new file mode 100644
index 000000000..e99158712
--- /dev/null
+++ b/pkg/sentry/strace/clone.go
@@ -0,0 +1,113 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi"
+)
+
+// CloneFlagSet is the set of clone(2) flags.
+var CloneFlagSet = abi.FlagSet{
+ {
+ Flag: syscall.CLONE_VM,
+ Name: "CLONE_VM",
+ },
+ {
+ Flag: syscall.CLONE_FS,
+ Name: "CLONE_FS",
+ },
+ {
+ Flag: syscall.CLONE_FILES,
+ Name: "CLONE_FILES",
+ },
+ {
+ Flag: syscall.CLONE_SIGHAND,
+ Name: "CLONE_SIGHAND",
+ },
+ {
+ Flag: syscall.CLONE_PTRACE,
+ Name: "CLONE_PTRACE",
+ },
+ {
+ Flag: syscall.CLONE_VFORK,
+ Name: "CLONE_VFORK",
+ },
+ {
+ Flag: syscall.CLONE_PARENT,
+ Name: "CLONE_PARENT",
+ },
+ {
+ Flag: syscall.CLONE_THREAD,
+ Name: "CLONE_THREAD",
+ },
+ {
+ Flag: syscall.CLONE_NEWNS,
+ Name: "CLONE_NEWNS",
+ },
+ {
+ Flag: syscall.CLONE_SYSVSEM,
+ Name: "CLONE_SYSVSEM",
+ },
+ {
+ Flag: syscall.CLONE_SETTLS,
+ Name: "CLONE_SETTLS",
+ },
+ {
+ Flag: syscall.CLONE_PARENT_SETTID,
+ Name: "CLONE_PARENT_SETTID",
+ },
+ {
+ Flag: syscall.CLONE_CHILD_CLEARTID,
+ Name: "CLONE_CHILD_CLEARTID",
+ },
+ {
+ Flag: syscall.CLONE_DETACHED,
+ Name: "CLONE_DETACHED",
+ },
+ {
+ Flag: syscall.CLONE_UNTRACED,
+ Name: "CLONE_UNTRACED",
+ },
+ {
+ Flag: syscall.CLONE_CHILD_SETTID,
+ Name: "CLONE_CHILD_SETTID",
+ },
+ {
+ Flag: syscall.CLONE_NEWUTS,
+ Name: "CLONE_NEWUTS",
+ },
+ {
+ Flag: syscall.CLONE_NEWIPC,
+ Name: "CLONE_NEWIPC",
+ },
+ {
+ Flag: syscall.CLONE_NEWUSER,
+ Name: "CLONE_NEWUSER",
+ },
+ {
+ Flag: syscall.CLONE_NEWPID,
+ Name: "CLONE_NEWPID",
+ },
+ {
+ Flag: syscall.CLONE_NEWNET,
+ Name: "CLONE_NEWNET",
+ },
+ {
+ Flag: syscall.CLONE_IO,
+ Name: "CLONE_IO",
+ },
+}
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..a6e48b836
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
+ fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
+ continue
+ }
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPOLLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
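abi.FlagSet.Parse renders a bitmask as OR-joined names, consuming the bits it
knows and printing any residue in hex. A standalone analogue over two of the
epoll flags (EPOLLIN=0x1, EPOLLOUT=0x4, from <sys/epoll.h>):

    package main

    import (
        "fmt"
        "strings"
    )

    // parseFlags is a hypothetical miniature of abi.FlagSet.Parse.
    func parseFlags(v uint64) string {
        flags := []struct {
            bit  uint64
            name string
        }{{0x1, "EPOLLIN"}, {0x4, "EPOLLOUT"}}
        var names []string
        for _, f := range flags {
            if v&f.bit != 0 {
                names = append(names, f.name)
                v &^= f.bit
            }
        }
        if v != 0 || len(names) == 0 {
            names = append(names, fmt.Sprintf("%#x", v))
        }
        return strings.Join(names, "|")
    }

    func main() {
        fmt.Println(parseFlags(0x5)) // EPOLLIN|EPOLLOUT
    }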
diff --git a/pkg/sentry/strace/futex.go b/pkg/sentry/strace/futex.go
new file mode 100644
index 000000000..d55c4080e
--- /dev/null
+++ b/pkg/sentry/strace/futex.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// FutexCmd are the possible futex(2) commands.
+var FutexCmd = abi.ValueSet{
+ linux.FUTEX_WAIT: "FUTEX_WAIT",
+ linux.FUTEX_WAKE: "FUTEX_WAKE",
+ linux.FUTEX_FD: "FUTEX_FD",
+ linux.FUTEX_REQUEUE: "FUTEX_REQUEUE",
+ linux.FUTEX_CMP_REQUEUE: "FUTEX_CMP_REQUEUE",
+ linux.FUTEX_WAKE_OP: "FUTEX_WAKE_OP",
+ linux.FUTEX_LOCK_PI: "FUTEX_LOCK_PI",
+ linux.FUTEX_UNLOCK_PI: "FUTEX_UNLOCK_PI",
+ linux.FUTEX_TRYLOCK_PI: "FUTEX_TRYLOCK_PI",
+ linux.FUTEX_WAIT_BITSET: "FUTEX_WAIT_BITSET",
+ linux.FUTEX_WAKE_BITSET: "FUTEX_WAKE_BITSET",
+ linux.FUTEX_WAIT_REQUEUE_PI: "FUTEX_WAIT_REQUEUE_PI",
+ linux.FUTEX_CMP_REQUEUE_PI: "FUTEX_CMP_REQUEUE_PI",
+}
+
+func futex(op uint64) string {
+ cmd := op &^ (linux.FUTEX_PRIVATE_FLAG | linux.FUTEX_CLOCK_REALTIME)
+ clockRealtime := (op & linux.FUTEX_CLOCK_REALTIME) == linux.FUTEX_CLOCK_REALTIME
+ private := (op & linux.FUTEX_PRIVATE_FLAG) == linux.FUTEX_PRIVATE_FLAG
+
+ s := FutexCmd.Parse(cmd)
+ if clockRealtime {
+ s += "|FUTEX_CLOCK_REALTIME"
+ }
+ if private {
+ s += "|FUTEX_PRIVATE_FLAG"
+ }
+ return s
+}
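The decoder masks off the two modifier bits before looking up the base
command, then appends the modifiers back by name. With the raw values from
<linux/futex.h> (FUTEX_PRIVATE_FLAG=128, FUTEX_CLOCK_REALTIME=256), a
standalone version:

    package main

    import "fmt"

    const (
        futexWait          = 0
        futexWake          = 1
        futexPrivateFlag   = 128
        futexClockRealtime = 256
    )

    // decode mirrors the futex() helper above with plain constants.
    func decode(op uint64) string {
        names := map[uint64]string{futexWait: "FUTEX_WAIT", futexWake: "FUTEX_WAKE"}
        cmd := op &^ (futexPrivateFlag | futexClockRealtime)
        s := names[cmd]
        if op&futexClockRealtime != 0 {
            s += "|FUTEX_CLOCK_REALTIME"
        }
        if op&futexPrivateFlag != 0 {
            s += "|FUTEX_PRIVATE_FLAG"
        }
        return s
    }

    func main() {
        fmt.Println(decode(129)) // FUTEX_WAKE|FUTEX_PRIVATE_FLAG
    }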
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
new file mode 100644
index 000000000..71b92eaee
--- /dev/null
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -0,0 +1,384 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// linuxAMD64 provides a mapping of the Linux amd64 syscalls and their argument
+// types for display / formatting.
+var linuxAMD64 = SyscallMap{
+ 0: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 1: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 2: makeSyscallInfo("open", Path, OpenFlags, Mode),
+ 3: makeSyscallInfo("close", FD),
+ 4: makeSyscallInfo("stat", Path, Stat),
+ 5: makeSyscallInfo("fstat", FD, Stat),
+ 6: makeSyscallInfo("lstat", Path, Stat),
+ 7: makeSyscallInfo("poll", PollFDs, Hex, Hex),
+ 8: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 9: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 10: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 11: makeSyscallInfo("munmap", Hex, Hex),
+ 12: makeSyscallInfo("brk", Hex),
+ 13: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction, Hex),
+ 14: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 15: makeSyscallInfo("rt_sigreturn"),
+ 16: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 17: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 18: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 19: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 20: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 21: makeSyscallInfo("access", Path, Oct),
+ 22: makeSyscallInfo("pipe", PipeFDs),
+ 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval),
+ 24: makeSyscallInfo("sched_yield"),
+ 25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 26: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 27: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 28: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 29: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 30: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 31: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 32: makeSyscallInfo("dup", FD),
+ 33: makeSyscallInfo("dup2", FD, FD),
+ 34: makeSyscallInfo("pause"),
+ 35: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 36: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 37: makeSyscallInfo("alarm", Hex),
+ 38: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 39: makeSyscallInfo("getpid"),
+ 40: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 41: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 42: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 43: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 44: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 45: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 46: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 47: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 48: makeSyscallInfo("shutdown", FD, Hex),
+ 49: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 50: makeSyscallInfo("listen", FD, Hex),
+ 51: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 52: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 53: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 54: makeSyscallInfo("setsockopt", FD, SockOptLevel, SockOptName, SetSockOptVal, Hex /* length by value, not a pointer */),
+ 55: makeSyscallInfo("getsockopt", FD, SockOptLevel, SockOptName, GetSockOptVal, SockLen),
+ 56: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 57: makeSyscallInfo("fork"),
+ 58: makeSyscallInfo("vfork"),
+ 59: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 60: makeSyscallInfo("exit", Hex),
+ 61: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 62: makeSyscallInfo("kill", Hex, Signal),
+ 63: makeSyscallInfo("uname", Uname),
+ 64: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 65: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 66: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 67: makeSyscallInfo("shmdt", Hex),
+ 68: makeSyscallInfo("msgget", Hex, Hex),
+ 69: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 70: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 71: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 72: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 73: makeSyscallInfo("flock", FD, Hex),
+ 74: makeSyscallInfo("fsync", FD),
+ 75: makeSyscallInfo("fdatasync", FD),
+ 76: makeSyscallInfo("truncate", Path, Hex),
+ 77: makeSyscallInfo("ftruncate", FD, Hex),
+ 78: makeSyscallInfo("getdents", FD, Hex, Hex),
+ 79: makeSyscallInfo("getcwd", PostPath, Hex),
+ 80: makeSyscallInfo("chdir", Path),
+ 81: makeSyscallInfo("fchdir", FD),
+ 82: makeSyscallInfo("rename", Path, Path),
+ 83: makeSyscallInfo("mkdir", Path, Oct),
+ 84: makeSyscallInfo("rmdir", Path),
+ 85: makeSyscallInfo("creat", Path, Oct),
+ 86: makeSyscallInfo("link", Path, Path),
+ 87: makeSyscallInfo("unlink", Path),
+ 88: makeSyscallInfo("symlink", Path, Path),
+ 89: makeSyscallInfo("readlink", Path, ReadBuffer, Hex),
+ 90: makeSyscallInfo("chmod", Path, Mode),
+ 91: makeSyscallInfo("fchmod", FD, Mode),
+ 92: makeSyscallInfo("chown", Path, Hex, Hex),
+ 93: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 94: makeSyscallInfo("lchown", Path, Hex, Hex),
+ 95: makeSyscallInfo("umask", Hex),
+ 96: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 97: makeSyscallInfo("getrlimit", Hex, Hex),
+ 98: makeSyscallInfo("getrusage", Hex, Rusage),
+ 99: makeSyscallInfo("sysinfo", Hex),
+ 100: makeSyscallInfo("times", Hex),
+ 101: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 102: makeSyscallInfo("getuid"),
+ 103: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 104: makeSyscallInfo("getgid"),
+ 105: makeSyscallInfo("setuid", Hex),
+ 106: makeSyscallInfo("setgid", Hex),
+ 107: makeSyscallInfo("geteuid"),
+ 108: makeSyscallInfo("getegid"),
+ 109: makeSyscallInfo("setpgid", Hex, Hex),
+ 110: makeSyscallInfo("getppid"),
+ 111: makeSyscallInfo("getpgrp"),
+ 112: makeSyscallInfo("setsid"),
+ 113: makeSyscallInfo("setreuid", Hex, Hex),
+ 114: makeSyscallInfo("setregid", Hex, Hex),
+ 115: makeSyscallInfo("getgroups", Hex, Hex),
+ 116: makeSyscallInfo("setgroups", Hex, Hex),
+ 117: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 118: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 119: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 120: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 121: makeSyscallInfo("getpgid", Hex),
+ 122: makeSyscallInfo("setfsuid", Hex),
+ 123: makeSyscallInfo("setfsgid", Hex),
+ 124: makeSyscallInfo("getsid", Hex),
+ 125: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 126: makeSyscallInfo("capset", CapHeader, CapData),
+ 127: makeSyscallInfo("rt_sigpending", Hex),
+ 128: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 129: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 130: makeSyscallInfo("rt_sigsuspend", Hex),
+ 131: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 132: makeSyscallInfo("utime", Path, Utimbuf),
+ 133: makeSyscallInfo("mknod", Path, Mode, Hex),
+ 134: makeSyscallInfo("uselib", Hex),
+ 135: makeSyscallInfo("personality", Hex),
+ 136: makeSyscallInfo("ustat", Hex, Hex),
+ 137: makeSyscallInfo("statfs", Path, Hex),
+ 138: makeSyscallInfo("fstatfs", FD, Hex),
+ 139: makeSyscallInfo("sysfs", Hex, Hex, Hex),
+ 140: makeSyscallInfo("getpriority", Hex, Hex),
+ 141: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 142: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 143: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 144: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 145: makeSyscallInfo("sched_getscheduler", Hex),
+ 146: makeSyscallInfo("sched_get_priority_max", Hex),
+ 147: makeSyscallInfo("sched_get_priority_min", Hex),
+ 148: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 149: makeSyscallInfo("mlock", Hex, Hex),
+ 150: makeSyscallInfo("munlock", Hex, Hex),
+ 151: makeSyscallInfo("mlockall", Hex),
+ 152: makeSyscallInfo("munlockall"),
+ 153: makeSyscallInfo("vhangup"),
+ 154: makeSyscallInfo("modify_ldt", Hex, Hex, Hex),
+ 155: makeSyscallInfo("pivot_root", Path, Path),
+ 156: makeSyscallInfo("_sysctl", Hex),
+ 157: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 158: makeSyscallInfo("arch_prctl", Hex, Hex),
+ 159: makeSyscallInfo("adjtimex", Hex),
+ 160: makeSyscallInfo("setrlimit", Hex, Hex),
+ 161: makeSyscallInfo("chroot", Path),
+ 162: makeSyscallInfo("sync"),
+ 163: makeSyscallInfo("acct", Hex),
+ 164: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 165: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 166: makeSyscallInfo("umount2", Path, Hex),
+ 167: makeSyscallInfo("swapon", Hex, Hex),
+ 168: makeSyscallInfo("swapoff", Hex),
+ 169: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 170: makeSyscallInfo("sethostname", Hex, Hex),
+ 171: makeSyscallInfo("setdomainname", Hex, Hex),
+ 172: makeSyscallInfo("iopl", Hex),
+ 173: makeSyscallInfo("ioperm", Hex, Hex, Hex),
+ 174: makeSyscallInfo("create_module", Path, Hex),
+ 175: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 176: makeSyscallInfo("delete_module", Hex, Hex),
+ 177: makeSyscallInfo("get_kernel_syms", Hex),
+ // 178: query_module (only present in Linux < 2.6)
+ 179: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 180: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ // 181: getpmsg (not implemented in the Linux kernel)
+ // 182: putpmsg (not implemented in the Linux kernel)
+ // 183: afs_syscall (not implemented in the Linux kernel)
+ // 184: tuxcall (not implemented in the Linux kernel)
+ // 185: security (not implemented in the Linux kernel)
+ 186: makeSyscallInfo("gettid"),
+ 187: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 188: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 189: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 190: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 191: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 192: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 193: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 194: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 195: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 196: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 197: makeSyscallInfo("removexattr", Path, Path),
+ 198: makeSyscallInfo("lremovexattr", Path, Path),
+ 199: makeSyscallInfo("fremovexattr", FD, Path),
+ 200: makeSyscallInfo("tkill", Hex, Signal),
+ 201: makeSyscallInfo("time", Hex),
+ 202: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 203: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 204: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 205: makeSyscallInfo("set_thread_area", Hex),
+ 206: makeSyscallInfo("io_setup", Hex, Hex),
+ 207: makeSyscallInfo("io_destroy", Hex),
+ 208: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 209: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 210: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 211: makeSyscallInfo("get_thread_area", Hex),
+ 212: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 213: makeSyscallInfo("epoll_create", Hex),
+ // 214: epoll_ctl_old (not implemented in the Linux kernel)
+ // 215: epoll_wait_old (not implemented in the Linux kernel)
+ 216: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 218: makeSyscallInfo("set_tid_address", Hex),
+ 219: makeSyscallInfo("restart_syscall"),
+ 220: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 222: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 223: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 224: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 225: makeSyscallInfo("timer_getoverrun", Hex),
+ 226: makeSyscallInfo("timer_delete", Hex),
+ 227: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 228: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 231: makeSyscallInfo("exit_group", Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 235: makeSyscallInfo("utimes", Path, Timeval),
+ // 236: vserver (not implemented in the Linux kernel)
+ 237: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 238: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 239: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 241: makeSyscallInfo("mq_unlink", Hex),
+ 242: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 243: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 244: makeSyscallInfo("mq_notify", Hex, Hex),
+ 245: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 246: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 247: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 248: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 249: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 250: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 251: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 252: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 253: makeSyscallInfo("inotify_init"),
+ 254: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 255: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 256: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 257: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 258: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 259: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 260: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 261: makeSyscallInfo("futimesat", FD, Path, Hex),
+ 262: makeSyscallInfo("newfstatat", FD, Path, Stat, Hex),
+ 263: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 264: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 265: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 266: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 268: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet),
+ 271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 272: makeSyscallInfo("unshare", CloneFlags),
+ 273: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 274: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 275: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 276: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 277: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
+ 282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
+ 283: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 284: makeSyscallInfo("eventfd", Hex),
+ 285: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 287: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 288: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 289: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 290: makeSyscallInfo("eventfd2", Hex, Hex),
+ 291: makeSyscallInfo("epoll_create1", Hex),
+ 292: makeSyscallInfo("dup3", FD, FD, Hex),
+ 293: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 294: makeSyscallInfo("inotify_init1", Hex),
+ 295: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 296: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 297: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 298: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 299: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+ 300: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 301: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 302: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 303: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 304: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 305: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 306: makeSyscallInfo("syncfs", FD),
+ 307: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 308: makeSyscallInfo("setns", FD, Hex),
+ 309: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 310: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 311: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 312: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 313: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 314: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 318: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close.
+ 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex),
+ 321: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex),
+ 323: makeSyscallInfo("userfaultfd", Hex),
+ 324: makeSyscallInfo("membarrier", Hex, Hex),
+ 325: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex),
+ 330: makeSyscallInfo("pkey_alloc", Hex, Hex),
+ 331: makeSyscallInfo("pkey_free", Hex),
+ 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.AMD64,
+ syscalls: linuxAMD64,
+ },
+ )
+}
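
For orientation, each entry above pairs a syscall number with a name and one formatter per argument. A minimal standalone sketch of that lookup-and-format flow (the types and helper names below are illustrative stand-ins, not gVisor's richer SyscallMap/SyscallInfo):

    package main

    import (
        "fmt"
        "strings"
    )

    // formatter renders one raw syscall argument for display.
    type formatter func(arg uint64) string

    type syscallInfo struct {
        name string
        args []formatter
    }

    func hex(a uint64) string { return fmt.Sprintf("%#x", a) }

    // A two-entry stand-in for the amd64 table above.
    var table = map[uintptr]syscallInfo{
        60:  {name: "exit", args: []formatter{hex}},
        231: {name: "exit_group", args: []formatter{hex}},
    }

    func format(num uintptr, args ...uint64) string {
        info, ok := table[num]
        if !ok {
            return fmt.Sprintf("unknown_syscall_%d", num)
        }
        parts := make([]string, 0, len(info.args))
        for i := range info.args {
            if i < len(args) {
                parts = append(parts, info.args[i](args[i]))
            }
        }
        return fmt.Sprintf("%s(%s)", info.name, strings.Join(parts, ", "))
    }

    func main() {
        fmt.Println(format(231, 1)) // exit_group(0x1)
    }
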
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
new file mode 100644
index 000000000..bd7361a52
--- /dev/null
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -0,0 +1,325 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// linuxARM64 provides a mapping of the Linux arm64 syscalls and their argument
+// types for display / formatting.
+var linuxARM64 = SyscallMap{
+ 0: makeSyscallInfo("io_setup", Hex, Hex),
+ 1: makeSyscallInfo("io_destroy", Hex),
+ 2: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 3: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 4: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 5: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 6: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 7: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 8: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 9: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 10: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 11: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 12: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 13: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 14: makeSyscallInfo("removexattr", Path, Path),
+ 15: makeSyscallInfo("lremovexattr", Path, Path),
+ 16: makeSyscallInfo("fremovexattr", FD, Path),
+ 17: makeSyscallInfo("getcwd", PostPath, Hex),
+ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 19: makeSyscallInfo("eventfd2", Hex, Hex),
+ 20: makeSyscallInfo("epoll_create1", Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
+ 23: makeSyscallInfo("dup", FD),
+ 24: makeSyscallInfo("dup3", FD, FD, Hex),
+ 25: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 26: makeSyscallInfo("inotify_init1", Hex),
+ 27: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 28: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 29: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 30: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 31: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 32: makeSyscallInfo("flock", FD, Hex),
+ 33: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 34: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 35: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 36: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 37: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 38: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 39: makeSyscallInfo("umount2", Path, Hex),
+ 40: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 41: makeSyscallInfo("pivot_root", Path, Path),
+ 42: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ 43: makeSyscallInfo("statfs", Path, Hex),
+ 44: makeSyscallInfo("fstatfs", FD, Hex),
+ 45: makeSyscallInfo("truncate", Path, Hex),
+ 46: makeSyscallInfo("ftruncate", FD, Hex),
+ 47: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 48: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 49: makeSyscallInfo("chdir", Path),
+ 50: makeSyscallInfo("fchdir", FD),
+ 51: makeSyscallInfo("chroot", Path),
+ 52: makeSyscallInfo("fchmod", FD, Mode),
+ 53: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 54: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 55: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 56: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 57: makeSyscallInfo("close", FD),
+ 58: makeSyscallInfo("vhangup"),
+ 59: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 60: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 61: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 62: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 63: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 64: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 65: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 66: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 67: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 68: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 69: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 70: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 71: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 72: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 73: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 74: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 75: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 76: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 77: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 78: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 79: makeSyscallInfo("fstatat", FD, Path, Stat, Hex),
+ 80: makeSyscallInfo("fstat", FD, Stat),
+ 81: makeSyscallInfo("sync"),
+ 82: makeSyscallInfo("fsync", FD),
+ 83: makeSyscallInfo("fdatasync", FD),
+ 84: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 85: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 86: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 87: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 88: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 89: makeSyscallInfo("acct", Hex),
+ 90: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 91: makeSyscallInfo("capset", CapHeader, CapData),
+ 92: makeSyscallInfo("personality", Hex),
+ 93: makeSyscallInfo("exit", Hex),
+ 94: makeSyscallInfo("exit_group", Hex),
+ 95: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 96: makeSyscallInfo("set_tid_address", Hex),
+ 97: makeSyscallInfo("unshare", CloneFlags),
+ 98: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 99: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 100: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 101: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 102: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 103: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 104: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 105: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 106: makeSyscallInfo("delete_module", Hex, Hex),
+ 107: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 108: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 109: makeSyscallInfo("timer_getoverrun", Hex),
+ 110: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 111: makeSyscallInfo("timer_delete", Hex),
+ 112: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 113: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 114: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 115: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 116: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 117: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 118: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 119: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 120: makeSyscallInfo("sched_getscheduler", Hex),
+ 121: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 122: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 123: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 124: makeSyscallInfo("sched_yield"),
+ 125: makeSyscallInfo("sched_get_priority_max", Hex),
+ 126: makeSyscallInfo("sched_get_priority_min", Hex),
+ 127: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 128: makeSyscallInfo("restart_syscall"),
+ 129: makeSyscallInfo("kill", Hex, Signal),
+ 130: makeSyscallInfo("tkill", Hex, Signal),
+ 131: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 132: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 133: makeSyscallInfo("rt_sigsuspend", Hex),
+ 134: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction, Hex),
+ 135: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 136: makeSyscallInfo("rt_sigpending", Hex),
+ 137: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 138: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 139: makeSyscallInfo("rt_sigreturn"),
+ 140: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 141: makeSyscallInfo("getpriority", Hex, Hex),
+ 142: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 143: makeSyscallInfo("setregid", Hex, Hex),
+ 144: makeSyscallInfo("setgid", Hex),
+ 145: makeSyscallInfo("setreuid", Hex, Hex),
+ 146: makeSyscallInfo("setuid", Hex),
+ 147: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 148: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 149: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 150: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 151: makeSyscallInfo("setfsuid", Hex),
+ 152: makeSyscallInfo("setfsgid", Hex),
+ 153: makeSyscallInfo("times", Hex),
+ 154: makeSyscallInfo("setpgid", Hex, Hex),
+ 155: makeSyscallInfo("getpgid", Hex),
+ 156: makeSyscallInfo("getsid", Hex),
+ 157: makeSyscallInfo("setsid"),
+ 158: makeSyscallInfo("getgroups", Hex, Hex),
+ 159: makeSyscallInfo("setgroups", Hex, Hex),
+ 160: makeSyscallInfo("uname", Uname),
+ 161: makeSyscallInfo("sethostname", Hex, Hex),
+ 162: makeSyscallInfo("setdomainname", Hex, Hex),
+ 163: makeSyscallInfo("getrlimit", Hex, Hex),
+ 164: makeSyscallInfo("setrlimit", Hex, Hex),
+ 165: makeSyscallInfo("getrusage", Hex, Rusage),
+ 166: makeSyscallInfo("umask", Hex),
+ 167: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 168: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 169: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 170: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 171: makeSyscallInfo("adjtimex", Hex),
+ 172: makeSyscallInfo("getpid"),
+ 173: makeSyscallInfo("getppid"),
+ 174: makeSyscallInfo("getuid"),
+ 175: makeSyscallInfo("geteuid"),
+ 176: makeSyscallInfo("getgid"),
+ 177: makeSyscallInfo("getegid"),
+ 178: makeSyscallInfo("gettid"),
+ 179: makeSyscallInfo("sysinfo", Hex),
+ 180: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 181: makeSyscallInfo("mq_unlink", Hex),
+ 182: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 183: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 184: makeSyscallInfo("mq_notify", Hex, Hex),
+ 185: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 186: makeSyscallInfo("msgget", Hex, Hex),
+ 187: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 188: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 189: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 190: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 191: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 192: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 193: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 194: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 195: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 196: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 197: makeSyscallInfo("shmdt", Hex),
+ 198: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 199: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 200: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 201: makeSyscallInfo("listen", FD, Hex),
+ 202: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 203: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 204: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 205: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 206: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 207: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 208: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
+ 209: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 210: makeSyscallInfo("shutdown", FD, Hex),
+ 211: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 212: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 213: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 214: makeSyscallInfo("brk", Hex),
+ 215: makeSyscallInfo("munmap", Hex, Hex),
+ 216: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 218: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 224: makeSyscallInfo("swapon", Hex, Hex),
+ 225: makeSyscallInfo("swapoff", Hex),
+ 226: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 227: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 228: makeSyscallInfo("mlock", Hex, Hex),
+ 229: makeSyscallInfo("munlock", Hex, Hex),
+ 230: makeSyscallInfo("mlockall", Hex),
+ 231: makeSyscallInfo("munlockall"),
+ 232: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 233: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 234: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 235: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 236: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 237: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 238: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 239: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 241: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 242: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 243: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+ // 244-259: reserved for architecture-specific syscalls.
+ 260: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 261: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 262: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 263: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 264: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 265: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 266: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 267: makeSyscallInfo("syncfs", FD),
+ 268: makeSyscallInfo("setns", FD, Hex),
+ 269: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 270: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 271: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 272: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 273: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 274: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 275: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 276: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 277: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 278: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 279: makeSyscallInfo("memfd_create", Path, Hex),
+ 280: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 281: makeSyscallInfo("execveat", FD, Path, Hex, Hex, Hex),
+ 282: makeSyscallInfo("userfaultfd", Hex),
+ 283: makeSyscallInfo("membarrier", Hex),
+ 284: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 285: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 287: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 291: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 292: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 293: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.ARM64,
+ syscalls: linuxARM64,
+ },
+ )
+}
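
Note that the same call gets a different number per architecture (openat is 257 in the amd64 table above but 56 here), which is why each init registers a separate table keyed by (os, arch). A toy version of that dispatch, standard library only:

    package main

    import "fmt"

    type archID int

    const (
        amd64 archID = iota
        arm64
    )

    // Per-arch number-to-name tables, mirroring the registration above.
    var tables = map[archID]map[uintptr]string{
        amd64: {257: "openat"},
        arm64: {56: "openat"},
    }

    func name(a archID, num uintptr) string {
        if n, ok := tables[a][num]; ok {
            return n
        }
        return fmt.Sprintf("syscall_%d", num)
    }

    func main() {
        fmt.Println(name(amd64, 257), name(arm64, 56)) // openat openat
    }
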
diff --git a/pkg/sentry/strace/open.go b/pkg/sentry/strace/open.go
new file mode 100644
index 000000000..e40bcb53b
--- /dev/null
+++ b/pkg/sentry/strace/open.go
@@ -0,0 +1,96 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi"
+)
+
+// OpenMode represents the mode to open(2) a file.
+var OpenMode = abi.ValueSet{
+ syscall.O_RDWR: "O_RDWR",
+ syscall.O_WRONLY: "O_WRONLY",
+ syscall.O_RDONLY: "O_RDONLY",
+}
+
+// OpenFlagSet is the set of open(2) flags.
+var OpenFlagSet = abi.FlagSet{
+ {
+ Flag: syscall.O_APPEND,
+ Name: "O_APPEND",
+ },
+ {
+ Flag: syscall.O_ASYNC,
+ Name: "O_ASYNC",
+ },
+ {
+ Flag: syscall.O_CLOEXEC,
+ Name: "O_CLOEXEC",
+ },
+ {
+ Flag: syscall.O_CREAT,
+ Name: "O_CREAT",
+ },
+ {
+ Flag: syscall.O_DIRECT,
+ Name: "O_DIRECT",
+ },
+ {
+ Flag: syscall.O_DIRECTORY,
+ Name: "O_DIRECTORY",
+ },
+ {
+ Flag: syscall.O_EXCL,
+ Name: "O_EXCL",
+ },
+ {
+ Flag: syscall.O_NOATIME,
+ Name: "O_NOATIME",
+ },
+ {
+ Flag: syscall.O_NOCTTY,
+ Name: "O_NOCTTY",
+ },
+ {
+ Flag: syscall.O_NOFOLLOW,
+ Name: "O_NOFOLLOW",
+ },
+ {
+ Flag: syscall.O_NONBLOCK,
+ Name: "O_NONBLOCK",
+ },
+ {
+ Flag: 0x200000, // O_PATH
+ Name: "O_PATH",
+ },
+ {
+ Flag: syscall.O_SYNC,
+ Name: "O_SYNC",
+ },
+ {
+ Flag: syscall.O_TRUNC,
+ Name: "O_TRUNC",
+ },
+}
+
+func open(val uint64) string {
+ s := OpenMode.Parse(val & syscall.O_ACCMODE)
+ if flags := OpenFlagSet.Parse(val &^ syscall.O_ACCMODE); flags != "" {
+ s += "|" + flags
+ }
+ return s
+}
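
The open() helper above decodes the two-bit access mode by value and the remaining bits by flag. A self-contained check of that split, with only a few flags included for brevity (the constants come from Go's syscall package on Linux; 0x241 is O_WRONLY|O_CREAT|O_TRUNC on Linux/amd64):

    package main

    import (
        "fmt"
        "strings"
        "syscall"
    )

    // decodeOpen mirrors the mode/flag split in open() above.
    func decodeOpen(val uint64) string {
        modes := map[uint64]string{
            syscall.O_RDONLY: "O_RDONLY",
            syscall.O_WRONLY: "O_WRONLY",
            syscall.O_RDWR:   "O_RDWR",
        }
        flags := []struct {
            bit  uint64
            name string
        }{
            {syscall.O_CREAT, "O_CREAT"},
            {syscall.O_TRUNC, "O_TRUNC"},
            {syscall.O_APPEND, "O_APPEND"},
        }
        s := modes[val&syscall.O_ACCMODE]
        var extra []string
        for _, f := range flags {
            if val&f.bit != 0 {
                extra = append(extra, f.name)
            }
        }
        if len(extra) > 0 {
            s += "|" + strings.Join(extra, "|")
        }
        return s
    }

    func main() {
        // 0x241 on Linux/amd64.
        fmt.Println(decodeOpen(syscall.O_WRONLY | syscall.O_CREAT | syscall.O_TRUNC))
        // Output: O_WRONLY|O_CREAT|O_TRUNC
    }
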
diff --git a/pkg/sentry/strace/poll.go b/pkg/sentry/strace/poll.go
new file mode 100644
index 000000000..074e80f9b
--- /dev/null
+++ b/pkg/sentry/strace/poll.go
@@ -0,0 +1,71 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// PollEventSet is the set of poll(2) event flags.
+var PollEventSet = abi.FlagSet{
+ {Flag: linux.POLLIN, Name: "POLLIN"},
+ {Flag: linux.POLLPRI, Name: "POLLPRI"},
+ {Flag: linux.POLLOUT, Name: "POLLOUT"},
+ {Flag: linux.POLLERR, Name: "POLLERR"},
+ {Flag: linux.POLLHUP, Name: "POLLHUP"},
+ {Flag: linux.POLLNVAL, Name: "POLLNVAL"},
+ {Flag: linux.POLLRDNORM, Name: "POLLRDNORM"},
+ {Flag: linux.POLLRDBAND, Name: "POLLRDBAND"},
+ {Flag: linux.POLLWRNORM, Name: "POLLWRNORM"},
+ {Flag: linux.POLLWRBAND, Name: "POLLWRBAND"},
+ {Flag: linux.POLLMSG, Name: "POLLMSG"},
+ {Flag: linux.POLLREMOVE, Name: "POLLREMOVE"},
+ {Flag: linux.POLLRDHUP, Name: "POLLRDHUP"},
+ {Flag: linux.POLLFREE, Name: "POLLFREE"},
+ {Flag: linux.POLL_BUSY_LOOP, Name: "POLL_BUSY_LOOP"},
+}
+
+func pollFD(t *kernel.Task, pfd *linux.PollFD, post bool) string {
+ revents := "..."
+ if post {
+ revents = PollEventSet.Parse(uint64(pfd.REvents))
+ }
+ return fmt.Sprintf("{FD: %s, Events: %s, REvents: %s}", fd(t, pfd.FD), PollEventSet.Parse(uint64(pfd.Events)), revents)
+}
+
+func pollFDs(t *kernel.Task, addr usermem.Addr, nfds uint, post bool) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ pfds, err := slinux.CopyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding pollfds: %s)", addr, err)
+ }
+
+ s := make([]string, 0, len(pfds))
+ for i := range pfds {
+ s = append(s, pollFD(t, &pfds[i], post))
+ }
+
+ return fmt.Sprintf("%#x [%s]", addr, strings.Join(s, ", "))
+}
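
pollFD above always prints Events but prints REvents only post-syscall, since the kernel has not filled them in beforehand (hence the "..."). The flag rendering itself is the usual bit walk; a standalone sketch using the standard Linux event values:

    package main

    import (
        "fmt"
        "strings"
    )

    // Standard Linux poll event bits.
    const (
        pollIn  = 0x0001
        pollPri = 0x0002
        pollOut = 0x0004
        pollErr = 0x0008
    )

    func events(mask uint16) string {
        names := []struct {
            bit  uint16
            name string
        }{
            {pollIn, "POLLIN"},
            {pollPri, "POLLPRI"},
            {pollOut, "POLLOUT"},
            {pollErr, "POLLERR"},
        }
        var out []string
        for _, n := range names {
            if mask&n.bit != 0 {
                out = append(out, n.name)
            }
        }
        if len(out) == 0 {
            return "0"
        }
        return strings.Join(out, "|")
    }

    func main() {
        fmt.Println(events(pollIn | pollOut)) // POLLIN|POLLOUT
    }
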
diff --git a/pkg/sentry/strace/ptrace.go b/pkg/sentry/strace/ptrace.go
new file mode 100644
index 000000000..338bafc6c
--- /dev/null
+++ b/pkg/sentry/strace/ptrace.go
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// PtraceRequestSet are the possible ptrace(2) requests.
+var PtraceRequestSet = abi.ValueSet{
+ linux.PTRACE_TRACEME: "PTRACE_TRACEME",
+ linux.PTRACE_PEEKTEXT: "PTRACE_PEEKTEXT",
+ linux.PTRACE_PEEKDATA: "PTRACE_PEEKDATA",
+ linux.PTRACE_PEEKUSR: "PTRACE_PEEKUSR",
+ linux.PTRACE_POKETEXT: "PTRACE_POKETEXT",
+ linux.PTRACE_POKEDATA: "PTRACE_POKEDATA",
+ linux.PTRACE_POKEUSR: "PTRACE_POKEUSR",
+ linux.PTRACE_CONT: "PTRACE_CONT",
+ linux.PTRACE_KILL: "PTRACE_KILL",
+ linux.PTRACE_SINGLESTEP: "PTRACE_SINGLESTEP",
+ linux.PTRACE_ATTACH: "PTRACE_ATTACH",
+ linux.PTRACE_DETACH: "PTRACE_DETACH",
+ linux.PTRACE_SYSCALL: "PTRACE_SYSCALL",
+ linux.PTRACE_SETOPTIONS: "PTRACE_SETOPTIONS",
+ linux.PTRACE_GETEVENTMSG: "PTRACE_GETEVENTMSG",
+ linux.PTRACE_GETSIGINFO: "PTRACE_GETSIGINFO",
+ linux.PTRACE_SETSIGINFO: "PTRACE_SETSIGINFO",
+ linux.PTRACE_GETREGSET: "PTRACE_GETREGSET",
+ linux.PTRACE_SETREGSET: "PTRACE_SETREGSET",
+ linux.PTRACE_SEIZE: "PTRACE_SEIZE",
+ linux.PTRACE_INTERRUPT: "PTRACE_INTERRUPT",
+ linux.PTRACE_LISTEN: "PTRACE_LISTEN",
+ linux.PTRACE_PEEKSIGINFO: "PTRACE_PEEKSIGINFO",
+ linux.PTRACE_GETSIGMASK: "PTRACE_GETSIGMASK",
+ linux.PTRACE_SETSIGMASK: "PTRACE_SETSIGMASK",
+ linux.PTRACE_GETREGS: "PTRACE_GETREGS",
+ linux.PTRACE_SETREGS: "PTRACE_SETREGS",
+ linux.PTRACE_GETFPREGS: "PTRACE_GETFPREGS",
+ linux.PTRACE_SETFPREGS: "PTRACE_SETFPREGS",
+ linux.PTRACE_GETFPXREGS: "PTRACE_GETFPXREGS",
+ linux.PTRACE_SETFPXREGS: "PTRACE_SETFPXREGS",
+ linux.PTRACE_OLDSETOPTIONS: "PTRACE_OLDSETOPTIONS",
+ linux.PTRACE_GET_THREAD_AREA: "PTRACE_GET_THREAD_AREA",
+ linux.PTRACE_SET_THREAD_AREA: "PTRACE_SET_THREAD_AREA",
+ linux.PTRACE_ARCH_PRCTL: "PTRACE_ARCH_PRCTL",
+ linux.PTRACE_SYSEMU: "PTRACE_SYSEMU",
+ linux.PTRACE_SYSEMU_SINGLESTEP: "PTRACE_SYSEMU_SINGLESTEP",
+ linux.PTRACE_SINGLEBLOCK: "PTRACE_SINGLEBLOCK",
+}
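
ValueSet.Parse, as used with PtraceRequestSet, renders a known value by name, so unlisted requests still print usefully. A toy equivalent, assuming the hex fallback gVisor's abi.ValueSet provides for unknown values:

    package main

    import "fmt"

    // parse renders a known value by name and falls back to hex otherwise.
    func parse(set map[uint64]string, v uint64) string {
        if s, ok := set[v]; ok {
            return s
        }
        return fmt.Sprintf("%#x", v)
    }

    func main() {
        reqs := map[uint64]string{
            0:  "PTRACE_TRACEME",
            16: "PTRACE_ATTACH",
        }
        fmt.Println(parse(reqs, 16), parse(reqs, 0xff)) // PTRACE_ATTACH 0xff
    }
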
diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go
new file mode 100644
index 000000000..3a4c32aa0
--- /dev/null
+++ b/pkg/sentry/strace/select.go
@@ -0,0 +1,56 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func fdsFromSet(t *kernel.Task, set []byte) []int {
+ var fds []int
+ // Append n if the n-th bit is 1.
+ for i, v := range set {
+ for j := 0; j < 8; j++ {
+ if (v>>j)&1 == 1 {
+ fds = append(fds, i*8+j)
+ }
+ }
+ }
+ return fds
+}
+
+func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string {
+ if nfds < 0 {
+ return fmt.Sprintf("%#x (negative nfds)", addr)
+ }
+ if addr == 0 {
+ return "null"
+ }
+
+ // Calculate the size of the fd set (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set))
+}
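
fdsFromSet walks each byte of the fd_set LSB-first, so bit j of byte i maps to fd i*8+j. A runnable spot check of that bit walk (the unused Task parameter is dropped here):

    package main

    import "fmt"

    // fdsFromSet duplicates the bit walk above.
    func fdsFromSet(set []byte) []int {
        var fds []int
        for i, v := range set {
            for j := 0; j < 8; j++ {
                if (v>>uint(j))&1 == 1 {
                    fds = append(fds, i*8+j)
                }
            }
        }
        return fds
    }

    func main() {
        // Byte 0 = 0x0a sets bits 1 and 3; byte 1 = 0x01 sets bit 0.
        fmt.Println(fdsFromSet([]byte{0x0a, 0x01})) // [1 3 8]
    }
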
diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go
new file mode 100644
index 000000000..c41f36e3f
--- /dev/null
+++ b/pkg/sentry/strace/signal.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// signalNames contains the names of all named signals.
+var signalNames = abi.ValueSet{
+ uint64(linux.SIGABRT): "SIGABRT",
+ uint64(linux.SIGALRM): "SIGALRM",
+ uint64(linux.SIGBUS): "SIGBUS",
+ uint64(linux.SIGCHLD): "SIGCHLD",
+ uint64(linux.SIGCONT): "SIGCONT",
+ uint64(linux.SIGFPE): "SIGFPE",
+ uint64(linux.SIGHUP): "SIGHUP",
+ uint64(linux.SIGILL): "SIGILL",
+ uint64(linux.SIGINT): "SIGINT",
+ uint64(linux.SIGIO): "SIGIO",
+ uint64(linux.SIGKILL): "SIGKILL",
+ uint64(linux.SIGPIPE): "SIGPIPE",
+ uint64(linux.SIGPROF): "SIGPROF",
+ uint64(linux.SIGPWR): "SIGPWR",
+ uint64(linux.SIGQUIT): "SIGQUIT",
+ uint64(linux.SIGSEGV): "SIGSEGV",
+ uint64(linux.SIGSTKFLT): "SIGSTKFLT",
+ uint64(linux.SIGSTOP): "SIGSTOP",
+ uint64(linux.SIGSYS): "SIGSYS",
+ uint64(linux.SIGTERM): "SIGTERM",
+ uint64(linux.SIGTRAP): "SIGTRAP",
+ uint64(linux.SIGTSTP): "SIGTSTP",
+ uint64(linux.SIGTTIN): "SIGTTIN",
+ uint64(linux.SIGTTOU): "SIGTTOU",
+ uint64(linux.SIGURG): "SIGURG",
+ uint64(linux.SIGUSR1): "SIGUSR1",
+ uint64(linux.SIGUSR2): "SIGUSR2",
+ uint64(linux.SIGVTALRM): "SIGVTALRM",
+ uint64(linux.SIGWINCH): "SIGWINCH",
+ uint64(linux.SIGXCPU): "SIGXCPU",
+ uint64(linux.SIGXFSZ): "SIGXFSZ",
+}
+
+var signalMaskActions = abi.ValueSet{
+ linux.SIG_BLOCK: "SIG_BLOCK",
+ linux.SIG_UNBLOCK: "SIG_UNBLOCK",
+ linux.SIG_SETMASK: "SIG_SETMASK",
+}
+
+var sigActionFlags = abi.FlagSet{
+ {
+ Flag: linux.SA_NOCLDSTOP,
+ Name: "SA_NOCLDSTOP",
+ },
+ {
+ Flag: linux.SA_NOCLDWAIT,
+ Name: "SA_NOCLDWAIT",
+ },
+ {
+ Flag: linux.SA_SIGINFO,
+ Name: "SA_SIGINFO",
+ },
+ {
+ Flag: linux.SA_RESTORER,
+ Name: "SA_RESTORER",
+ },
+ {
+ Flag: linux.SA_ONSTACK,
+ Name: "SA_ONSTACK",
+ },
+ {
+ Flag: linux.SA_RESTART,
+ Name: "SA_RESTART",
+ },
+ {
+ Flag: linux.SA_NODEFER,
+ Name: "SA_NODEFER",
+ },
+ {
+ Flag: linux.SA_RESETHAND,
+ Name: "SA_RESETHAND",
+ },
+}
+
+func sigSet(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var b [linux.SignalSetSize]byte
+ if _, err := t.CopyInBytes(addr, b[:]); err != nil {
+ return fmt.Sprintf("%#x (error copying sigset: %v)", addr, err)
+ }
+
+ set := linux.SignalSet(usermem.ByteOrder.Uint64(b[:]))
+
+ return fmt.Sprintf("%#x %s", addr, formatSigSet(set))
+}
+
+func formatSigSet(set linux.SignalSet) string {
+ var signals []string
+ linux.ForEachSignal(set, func(sig linux.Signal) {
+ signals = append(signals, signalNames.ParseDecimal(uint64(sig)))
+ })
+
+ return fmt.Sprintf("[%v]", strings.Join(signals, " "))
+}
+
+func sigAction(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ sa, err := t.CopyInSignalAct(addr)
+ if err != nil {
+ return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err)
+ }
+
+ var handler string
+ switch sa.Handler {
+ case linux.SIG_IGN:
+ handler = "SIG_IGN"
+ case linux.SIG_DFL:
+ handler = "SIG_DFL"
+ default:
+ handler = fmt.Sprintf("%#x", sa.Handler)
+ }
+
+ return fmt.Sprintf("%#x {Handler: %s, Flags: %s, Restorer: %#x, Mask: %s}", addr, handler, sigActionFlags.Parse(sa.Flags), sa.Restorer, formatSigSet(sa.Mask))
+}
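
Linux signal masks store signal N in bit N-1, which is the iteration linux.ForEachSignal performs for formatSigSet above. A standalone sketch of that decode with two named signals (unknown signals fall back to their number, as with ParseDecimal):

    package main

    import (
        "fmt"
        "strings"
    )

    // Two of the named signals; the rest fall back to their number.
    var names = map[int]string{2: "SIGINT", 15: "SIGTERM"}

    func formatMask(mask uint64) string {
        var out []string
        for sig := 1; sig <= 64; sig++ {
            if mask&(1<<uint(sig-1)) != 0 {
                n, ok := names[sig]
                if !ok {
                    n = fmt.Sprintf("%d", sig)
                }
                out = append(out, n)
            }
        }
        return "[" + strings.Join(out, " ") + "]"
    }

    func main() {
        // Bits 1 and 14 set: signals 2 (SIGINT) and 15 (SIGTERM).
        fmt.Println(formatMask(1<<1 | 1<<14)) // [SIGINT SIGTERM]
    }
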
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
new file mode 100644
index 000000000..c0512de89
--- /dev/null
+++ b/pkg/sentry/strace/socket.go
@@ -0,0 +1,644 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SocketFamily are the possible socket(2) families.
+var SocketFamily = abi.ValueSet{
+ linux.AF_UNSPEC: "AF_UNSPEC",
+ linux.AF_UNIX: "AF_UNIX",
+ linux.AF_INET: "AF_INET",
+ linux.AF_AX25: "AF_AX25",
+ linux.AF_IPX: "AF_IPX",
+ linux.AF_APPLETALK: "AF_APPLETALK",
+ linux.AF_NETROM: "AF_NETROM",
+ linux.AF_BRIDGE: "AF_BRIDGE",
+ linux.AF_ATMPVC: "AF_ATMPVC",
+ linux.AF_X25: "AF_X25",
+ linux.AF_INET6: "AF_INET6",
+ linux.AF_ROSE: "AF_ROSE",
+ linux.AF_DECnet: "AF_DECnet",
+ linux.AF_NETBEUI: "AF_NETBEUI",
+ linux.AF_SECURITY: "AF_SECURITY",
+ linux.AF_KEY: "AF_KEY",
+ linux.AF_NETLINK: "AF_NETLINK",
+ linux.AF_PACKET: "AF_PACKET",
+ linux.AF_ASH: "AF_ASH",
+ linux.AF_ECONET: "AF_ECONET",
+ linux.AF_ATMSVC: "AF_ATMSVC",
+ linux.AF_RDS: "AF_RDS",
+ linux.AF_SNA: "AF_SNA",
+ linux.AF_IRDA: "AF_IRDA",
+ linux.AF_PPPOX: "AF_PPPOX",
+ linux.AF_WANPIPE: "AF_WANPIPE",
+ linux.AF_LLC: "AF_LLC",
+ linux.AF_IB: "AF_IB",
+ linux.AF_MPLS: "AF_MPLS",
+ linux.AF_CAN: "AF_CAN",
+ linux.AF_TIPC: "AF_TIPC",
+ linux.AF_BLUETOOTH: "AF_BLUETOOTH",
+ linux.AF_IUCV: "AF_IUCV",
+ linux.AF_RXRPC: "AF_RXRPC",
+ linux.AF_ISDN: "AF_ISDN",
+ linux.AF_PHONET: "AF_PHONET",
+ linux.AF_IEEE802154: "AF_IEEE802154",
+ linux.AF_CAIF: "AF_CAIF",
+ linux.AF_ALG: "AF_ALG",
+ linux.AF_NFC: "AF_NFC",
+ linux.AF_VSOCK: "AF_VSOCK",
+}
+
+// SocketType are the possible socket(2) types.
+var SocketType = abi.ValueSet{
+ uint64(linux.SOCK_STREAM): "SOCK_STREAM",
+ uint64(linux.SOCK_DGRAM): "SOCK_DGRAM",
+ uint64(linux.SOCK_RAW): "SOCK_RAW",
+ uint64(linux.SOCK_RDM): "SOCK_RDM",
+ uint64(linux.SOCK_SEQPACKET): "SOCK_SEQPACKET",
+ uint64(linux.SOCK_DCCP): "SOCK_DCCP",
+ uint64(linux.SOCK_PACKET): "SOCK_PACKET",
+}
+
+// SocketFlagSet are the possible socket(2) flags.
+var SocketFlagSet = abi.FlagSet{
+ {
+ Flag: linux.SOCK_CLOEXEC,
+ Name: "SOCK_CLOEXEC",
+ },
+ {
+ Flag: linux.SOCK_NONBLOCK,
+ Name: "SOCK_NONBLOCK",
+ },
+}
+
+// ipProtocol are the possible socket(2) protocols for INET and INET6 sockets.
+var ipProtocol = abi.ValueSet{
+ linux.IPPROTO_IP: "IPPROTO_IP",
+ linux.IPPROTO_ICMP: "IPPROTO_ICMP",
+ linux.IPPROTO_IGMP: "IPPROTO_IGMP",
+ linux.IPPROTO_IPIP: "IPPROTO_IPIP",
+ linux.IPPROTO_TCP: "IPPROTO_TCP",
+ linux.IPPROTO_EGP: "IPPROTO_EGP",
+ linux.IPPROTO_PUP: "IPPROTO_PUP",
+ linux.IPPROTO_UDP: "IPPROTO_UDP",
+ linux.IPPROTO_IDP: "IPPROTO_IDP",
+ linux.IPPROTO_TP: "IPPROTO_TP",
+ linux.IPPROTO_DCCP: "IPPROTO_DCCP",
+ linux.IPPROTO_IPV6: "IPPROTO_IPV6",
+ linux.IPPROTO_RSVP: "IPPROTO_RSVP",
+ linux.IPPROTO_GRE: "IPPROTO_GRE",
+ linux.IPPROTO_ESP: "IPPROTO_ESP",
+ linux.IPPROTO_AH: "IPPROTO_AH",
+ linux.IPPROTO_MTP: "IPPROTO_MTP",
+ linux.IPPROTO_BEETPH: "IPPROTO_BEETPH",
+ linux.IPPROTO_ENCAP: "IPPROTO_ENCAP",
+ linux.IPPROTO_PIM: "IPPROTO_PIM",
+ linux.IPPROTO_COMP: "IPPROTO_COMP",
+ linux.IPPROTO_SCTP: "IPPROTO_SCTP",
+ linux.IPPROTO_UDPLITE: "IPPROTO_UDPLITE",
+ linux.IPPROTO_MPLS: "IPPROTO_MPLS",
+ linux.IPPROTO_RAW: "IPPROTO_RAW",
+}
+
+// SocketProtocol are the possible socket(2) protocols for each protocol family.
+var SocketProtocol = map[int32]abi.ValueSet{
+ linux.AF_INET: ipProtocol,
+ linux.AF_INET6: ipProtocol,
+ linux.AF_NETLINK: {
+ linux.NETLINK_ROUTE: "NETLINK_ROUTE",
+ linux.NETLINK_UNUSED: "NETLINK_UNUSED",
+ linux.NETLINK_USERSOCK: "NETLINK_USERSOCK",
+ linux.NETLINK_FIREWALL: "NETLINK_FIREWALL",
+ linux.NETLINK_SOCK_DIAG: "NETLINK_SOCK_DIAG",
+ linux.NETLINK_NFLOG: "NETLINK_NFLOG",
+ linux.NETLINK_XFRM: "NETLINK_XFRM",
+ linux.NETLINK_SELINUX: "NETLINK_SELINUX",
+ linux.NETLINK_ISCSI: "NETLINK_ISCSI",
+ linux.NETLINK_AUDIT: "NETLINK_AUDIT",
+ linux.NETLINK_FIB_LOOKUP: "NETLINK_FIB_LOOKUP",
+ linux.NETLINK_CONNECTOR: "NETLINK_CONNECTOR",
+ linux.NETLINK_NETFILTER: "NETLINK_NETFILTER",
+ linux.NETLINK_IP6_FW: "NETLINK_IP6_FW",
+ linux.NETLINK_DNRTMSG: "NETLINK_DNRTMSG",
+ linux.NETLINK_KOBJECT_UEVENT: "NETLINK_KOBJECT_UEVENT",
+ linux.NETLINK_GENERIC: "NETLINK_GENERIC",
+ linux.NETLINK_SCSITRANSPORT: "NETLINK_SCSITRANSPORT",
+ linux.NETLINK_ECRYPTFS: "NETLINK_ECRYPTFS",
+ linux.NETLINK_RDMA: "NETLINK_RDMA",
+ linux.NETLINK_CRYPTO: "NETLINK_CRYPTO",
+ },
+}
+
+var controlMessageType = map[int32]string{
+ linux.SCM_RIGHTS: "SCM_RIGHTS",
+ linux.SCM_CREDENTIALS: "SCM_CREDENTIALS",
+ linux.SO_TIMESTAMP: "SO_TIMESTAMP",
+}
+
+func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) string {
+ if length > maxBytes {
+ return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length)
+ }
+
+ buf := make([]byte, length)
+ if _, err := t.CopyIn(addr, &buf); err != nil {
+ return fmt.Sprintf("%#x (error decoding control: %v)", addr, err)
+ }
+
+ var strs []string
+
+ for i := 0; i < len(buf); {
+ if i+linux.SizeOfControlMessageHeader > len(buf) {
+ strs = append(strs, "{invalid control message (too short)}")
+ break
+ }
+
+ var h linux.ControlMessageHeader
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+
+ var skipData bool
+ level := "SOL_SOCKET"
+ if h.Level != linux.SOL_SOCKET {
+ skipData = true
+ level = fmt.Sprint(h.Level)
+ }
+
+ typ, ok := controlMessageType[h.Type]
+ if !ok {
+ skipData = true
+ typ = fmt.Sprint(h.Type)
+ }
+
+ if h.Length > uint64(len(buf)-i) {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content extends beyond buffer}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ i += linux.SizeOfControlMessageHeader
+ width := t.Arch().Width()
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length < 0 {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ if skipData {
+ strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
+ i += binary.AlignUp(length, width)
+ continue
+ }
+
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+ fds := make(linux.ControlMessageRights, numRights)
+ binary.Unmarshal(buf[i:i+rightsSize], usermem.ByteOrder, &fds)
+
+ rights := make([]string, 0, len(fds))
+ for _, fd := range fds {
+ rights = append(rights, fmt.Sprint(fd))
+ }
+
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content: %s}",
+ level,
+ typ,
+ h.Length,
+ strings.Join(rights, ","),
+ ))
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
+ level,
+ typ,
+ h.Length,
+ creds.PID,
+ creds.UID,
+ creds.GID,
+ ))
+
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ var tv linux.Timeval
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &tv)
+
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
+ level,
+ typ,
+ h.Length,
+ tv.Sec,
+ tv.Usec,
+ ))
+
+ default:
+ panic("unreachable")
+ }
+ i += binary.AlignUp(length, width)
+ }
+
+ return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
+}
+
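
The binary.AlignUp calls in cmsghdr above implement control-message padding: each payload is padded out to the task's word width before the next header is read. A division-free sketch of that rounding for power-of-two widths (gVisor's binary.AlignUp may be implemented differently but rounds the same way for these inputs):

    package main

    import "fmt"

    // alignUp rounds length up to the next multiple of width, which must
    // be a power of two.
    func alignUp(length, width int) int {
        return (length + width - 1) &^ (width - 1)
    }

    func main() {
        // A 4-byte payload on a 64-bit task occupies 8 bytes before the
        // next control message header.
        fmt.Println(alignUp(4, 8), alignUp(8, 8)) // 8 8
    }
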
+func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string {
+ var msg slinux.MessageHeader64
+ if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil {
+ return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err)
+ }
+ s := fmt.Sprintf(
+ "%#x {name=%#x, namelen=%d, iovecs=%s",
+ addr,
+ msg.Name,
+ msg.NameLen,
+ iovecs(t, usermem.Addr(msg.Iov), int(msg.IovLen), printContent, maxBytes),
+ )
+ if printContent {
+ s = fmt.Sprintf("%s, control={%s}", s, cmsghdr(t, usermem.Addr(msg.Control), msg.ControlLen, maxBytes))
+ } else {
+ s = fmt.Sprintf("%s, control=%#x, control_len=%d", s, msg.Control, msg.ControlLen)
+ }
+ return fmt.Sprintf("%s, flags=%d}", s, msg.Flags)
+}
+
+func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ b, err := slinux.CaptureAddress(t, addr, length)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading address: %v}", addr, err)
+ }
+
+ // Extract address family.
+ if len(b) < 2 {
+ return fmt.Sprintf("%#x {address too short: %d bytes}", addr, len(b))
+ }
+ family := usermem.ByteOrder.Uint16(b)
+
+ familyStr := SocketFamily.Parse(uint64(family))
+
+ switch family {
+ case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
+ fa, _, err := netstack.AddressAndFamily(b)
+ if err != nil {
+ return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
+ }
+
+ if family == linux.AF_UNIX {
+ return fmt.Sprintf("%#x {Family: %s, Addr: %q}", addr, familyStr, string(fa.Addr))
+ }
+
+ return fmt.Sprintf("%#x {Family: %s, Addr: %v, Port: %d}", addr, familyStr, fa.Addr, fa.Port)
+ case linux.AF_NETLINK:
+ sa, err := netlink.ExtractSockAddr(b)
+ if err != nil {
+ return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
+ }
+ return fmt.Sprintf("%#x {Family: %s, PortID: %d, Groups: %d}", addr, familyStr, sa.PortID, sa.Groups)
+ default:
+ return fmt.Sprintf("%#x {Family: %s, family addr format unknown}", addr, familyStr)
+ }
+}
+
+func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ if lengthPtr == 0 {
+ return fmt.Sprintf("%#x {length null}", addr)
+ }
+
+ l, err := copySockLen(t, lengthPtr)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", addr, err)
+ }
+
+ return sockAddr(t, addr, l)
+}
+
+func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) {
+ // socklen_t is 32 bits.
+ var l uint32
+ _, err := t.CopyIn(addr, &l)
+ return l, err
+}
+
+func sockLenPointer(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+ l, err := copySockLen(t, addr)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", addr, err)
+ }
+ return fmt.Sprintf("%#x {length=%v}", addr, l)
+}
+
+func sockType(stype int32) string {
+ s := SocketType.Parse(uint64(stype & linux.SOCK_TYPE_MASK))
+ if flags := SocketFlagSet.Parse(uint64(stype &^ linux.SOCK_TYPE_MASK)); flags != "" {
+ s += "|" + flags
+ }
+ return s
+}
+
+func sockProtocol(family, protocol int32) string {
+ protocols, ok := SocketProtocol[family]
+ if !ok {
+ return fmt.Sprintf("%#x", protocol)
+ }
+ return protocols.Parse(uint64(protocol))
+}
+
+func sockFlags(flags int32) string {
+ if flags == 0 {
+ return "0"
+ }
+ return SocketFlagSet.Parse(uint64(flags))
+}
+
+func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen usermem.Addr, maximumBlobSize uint, rval uintptr) string {
+ if int(rval) < 0 {
+ return hexNum(uint64(optVal))
+ }
+ if optVal == 0 {
+ return "null"
+ }
+ l, err := copySockLen(t, optLen)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", optLen, err)
+ }
+ return sockOptVal(t, level, optname, optVal, uint64(l), maximumBlobSize)
+}
+
+func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string {
+ switch optLen {
+ case 1:
+ var v uint8
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ case 2:
+ var v uint16
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ case 4:
+ var v uint32
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ default:
+ return dump(t, optVal, uint(optLen), maximumBlobSize)
+ }
+}
+
+var sockOptLevels = abi.ValueSet{
+ linux.SOL_IP: "SOL_IP",
+ linux.SOL_SOCKET: "SOL_SOCKET",
+ linux.SOL_TCP: "SOL_TCP",
+ linux.SOL_UDP: "SOL_UDP",
+ linux.SOL_IPV6: "SOL_IPV6",
+ linux.SOL_ICMPV6: "SOL_ICMPV6",
+ linux.SOL_RAW: "SOL_RAW",
+ linux.SOL_PACKET: "SOL_PACKET",
+ linux.SOL_NETLINK: "SOL_NETLINK",
+}
+
+var sockOptNames = map[uint64]abi.ValueSet{
+ linux.SOL_IP: {
+ linux.IP_TTL: "IP_TTL",
+ linux.IP_MULTICAST_TTL: "IP_MULTICAST_TTL",
+ linux.IP_MULTICAST_IF: "IP_MULTICAST_IF",
+ linux.IP_MULTICAST_LOOP: "IP_MULTICAST_LOOP",
+ linux.IP_TOS: "IP_TOS",
+ linux.IP_RECVTOS: "IP_RECVTOS",
+ linux.IPT_SO_GET_INFO: "IPT_SO_GET_INFO",
+ linux.IPT_SO_GET_ENTRIES: "IPT_SO_GET_ENTRIES",
+ linux.IP_ADD_MEMBERSHIP: "IP_ADD_MEMBERSHIP",
+ linux.IP_DROP_MEMBERSHIP: "IP_DROP_MEMBERSHIP",
+ linux.MCAST_JOIN_GROUP: "MCAST_JOIN_GROUP",
+ linux.IP_ADD_SOURCE_MEMBERSHIP: "IP_ADD_SOURCE_MEMBERSHIP",
+ linux.IP_BIND_ADDRESS_NO_PORT: "IP_BIND_ADDRESS_NO_PORT",
+ linux.IP_BLOCK_SOURCE: "IP_BLOCK_SOURCE",
+ linux.IP_CHECKSUM: "IP_CHECKSUM",
+ linux.IP_DROP_SOURCE_MEMBERSHIP: "IP_DROP_SOURCE_MEMBERSHIP",
+ linux.IP_FREEBIND: "IP_FREEBIND",
+ linux.IP_HDRINCL: "IP_HDRINCL",
+ linux.IP_IPSEC_POLICY: "IP_IPSEC_POLICY",
+ linux.IP_MINTTL: "IP_MINTTL",
+ linux.IP_MSFILTER: "IP_MSFILTER",
+ linux.IP_MTU_DISCOVER: "IP_MTU_DISCOVER",
+ linux.IP_MULTICAST_ALL: "IP_MULTICAST_ALL",
+ linux.IP_NODEFRAG: "IP_NODEFRAG",
+ linux.IP_OPTIONS: "IP_OPTIONS",
+ linux.IP_PASSSEC: "IP_PASSSEC",
+ linux.IP_PKTINFO: "IP_PKTINFO",
+ linux.IP_RECVERR: "IP_RECVERR",
+ linux.IP_RECVFRAGSIZE: "IP_RECVFRAGSIZE",
+ linux.IP_RECVOPTS: "IP_RECVOPTS",
+ linux.IP_RECVORIGDSTADDR: "IP_RECVORIGDSTADDR",
+ linux.IP_RECVTTL: "IP_RECVTTL",
+ linux.IP_RETOPTS: "IP_RETOPTS",
+ linux.IP_TRANSPARENT: "IP_TRANSPARENT",
+ linux.IP_UNBLOCK_SOURCE: "IP_UNBLOCK_SOURCE",
+ linux.IP_UNICAST_IF: "IP_UNICAST_IF",
+ linux.IP_XFRM_POLICY: "IP_XFRM_POLICY",
+ linux.MCAST_BLOCK_SOURCE: "MCAST_BLOCK_SOURCE",
+ linux.MCAST_JOIN_SOURCE_GROUP: "MCAST_JOIN_SOURCE_GROUP",
+ linux.MCAST_LEAVE_GROUP: "MCAST_LEAVE_GROUP",
+ linux.MCAST_LEAVE_SOURCE_GROUP: "MCAST_LEAVE_SOURCE_GROUP",
+ linux.MCAST_MSFILTER: "MCAST_MSFILTER",
+ linux.MCAST_UNBLOCK_SOURCE: "MCAST_UNBLOCK_SOURCE",
+ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT",
+ linux.IP_PKTOPTIONS: "IP_PKTOPTIONS",
+ linux.IP_MTU: "IP_MTU",
+ },
+ linux.SOL_SOCKET: {
+ linux.SO_ERROR: "SO_ERROR",
+ linux.SO_PEERCRED: "SO_PEERCRED",
+ linux.SO_PASSCRED: "SO_PASSCRED",
+ linux.SO_SNDBUF: "SO_SNDBUF",
+ linux.SO_RCVBUF: "SO_RCVBUF",
+ linux.SO_REUSEADDR: "SO_REUSEADDR",
+ linux.SO_REUSEPORT: "SO_REUSEPORT",
+ linux.SO_BINDTODEVICE: "SO_BINDTODEVICE",
+ linux.SO_BROADCAST: "SO_BROADCAST",
+ linux.SO_KEEPALIVE: "SO_KEEPALIVE",
+ linux.SO_LINGER: "SO_LINGER",
+ linux.SO_SNDTIMEO: "SO_SNDTIMEO",
+ linux.SO_RCVTIMEO: "SO_RCVTIMEO",
+ linux.SO_OOBINLINE: "SO_OOBINLINE",
+ linux.SO_TIMESTAMP: "SO_TIMESTAMP",
+ },
+ linux.SOL_TCP: {
+ linux.TCP_NODELAY: "TCP_NODELAY",
+ linux.TCP_CORK: "TCP_CORK",
+ linux.TCP_QUICKACK: "TCP_QUICKACK",
+ linux.TCP_MAXSEG: "TCP_MAXSEG",
+ linux.TCP_KEEPIDLE: "TCP_KEEPIDLE",
+ linux.TCP_KEEPINTVL: "TCP_KEEPINTVL",
+ linux.TCP_USER_TIMEOUT: "TCP_USER_TIMEOUT",
+ linux.TCP_INFO: "TCP_INFO",
+ linux.TCP_CC_INFO: "TCP_CC_INFO",
+ linux.TCP_NOTSENT_LOWAT: "TCP_NOTSENT_LOWAT",
+ linux.TCP_ZEROCOPY_RECEIVE: "TCP_ZEROCOPY_RECEIVE",
+ linux.TCP_CONGESTION: "TCP_CONGESTION",
+ linux.TCP_LINGER2: "TCP_LINGER2",
+ linux.TCP_DEFER_ACCEPT: "TCP_DEFER_ACCEPT",
+ linux.TCP_REPAIR_OPTIONS: "TCP_REPAIR_OPTIONS",
+ linux.TCP_INQ: "TCP_INQ",
+ linux.TCP_FASTOPEN: "TCP_FASTOPEN",
+ linux.TCP_FASTOPEN_CONNECT: "TCP_FASTOPEN_CONNECT",
+ linux.TCP_FASTOPEN_KEY: "TCP_FASTOPEN_KEY",
+ linux.TCP_FASTOPEN_NO_COOKIE: "TCP_FASTOPEN_NO_COOKIE",
+ linux.TCP_KEEPCNT: "TCP_KEEPCNT",
+ linux.TCP_QUEUE_SEQ: "TCP_QUEUE_SEQ",
+ linux.TCP_REPAIR: "TCP_REPAIR",
+ linux.TCP_REPAIR_QUEUE: "TCP_REPAIR_QUEUE",
+ linux.TCP_REPAIR_WINDOW: "TCP_REPAIR_WINDOW",
+ linux.TCP_SAVED_SYN: "TCP_SAVED_SYN",
+ linux.TCP_SAVE_SYN: "TCP_SAVE_SYN",
+ linux.TCP_SYNCNT: "TCP_SYNCNT",
+ linux.TCP_THIN_DUPACK: "TCP_THIN_DUPACK",
+ linux.TCP_THIN_LINEAR_TIMEOUTS: "TCP_THIN_LINEAR_TIMEOUTS",
+ linux.TCP_TIMESTAMP: "TCP_TIMESTAMP",
+ linux.TCP_ULP: "TCP_ULP",
+ linux.TCP_WINDOW_CLAMP: "TCP_WINDOW_CLAMP",
+ },
+ linux.SOL_IPV6: {
+ linux.IPV6_V6ONLY: "IPV6_V6ONLY",
+ linux.IPV6_PATHMTU: "IPV6_PATHMTU",
+ linux.IPV6_TCLASS: "IPV6_TCLASS",
+ linux.IPV6_ADD_MEMBERSHIP: "IPV6_ADD_MEMBERSHIP",
+ linux.IPV6_DROP_MEMBERSHIP: "IPV6_DROP_MEMBERSHIP",
+ linux.IPV6_IPSEC_POLICY: "IPV6_IPSEC_POLICY",
+ linux.IPV6_JOIN_ANYCAST: "IPV6_JOIN_ANYCAST",
+ linux.IPV6_LEAVE_ANYCAST: "IPV6_LEAVE_ANYCAST",
+ linux.IPV6_PKTINFO: "IPV6_PKTINFO",
+ linux.IPV6_ROUTER_ALERT: "IPV6_ROUTER_ALERT",
+ linux.IPV6_XFRM_POLICY: "IPV6_XFRM_POLICY",
+ linux.MCAST_BLOCK_SOURCE: "MCAST_BLOCK_SOURCE",
+ linux.MCAST_JOIN_GROUP: "MCAST_JOIN_GROUP",
+ linux.MCAST_JOIN_SOURCE_GROUP: "MCAST_JOIN_SOURCE_GROUP",
+ linux.MCAST_LEAVE_GROUP: "MCAST_LEAVE_GROUP",
+ linux.MCAST_LEAVE_SOURCE_GROUP: "MCAST_LEAVE_SOURCE_GROUP",
+ linux.MCAST_UNBLOCK_SOURCE: "MCAST_UNBLOCK_SOURCE",
+ linux.IPV6_2292DSTOPTS: "IPV6_2292DSTOPTS",
+ linux.IPV6_2292HOPLIMIT: "IPV6_2292HOPLIMIT",
+ linux.IPV6_2292HOPOPTS: "IPV6_2292HOPOPTS",
+ linux.IPV6_2292PKTINFO: "IPV6_2292PKTINFO",
+ linux.IPV6_2292PKTOPTIONS: "IPV6_2292PKTOPTIONS",
+ linux.IPV6_2292RTHDR: "IPV6_2292RTHDR",
+ linux.IPV6_ADDR_PREFERENCES: "IPV6_ADDR_PREFERENCES",
+ linux.IPV6_AUTOFLOWLABEL: "IPV6_AUTOFLOWLABEL",
+ linux.IPV6_DONTFRAG: "IPV6_DONTFRAG",
+ linux.IPV6_DSTOPTS: "IPV6_DSTOPTS",
+ linux.IPV6_FLOWINFO: "IPV6_FLOWINFO",
+ linux.IPV6_FLOWINFO_SEND: "IPV6_FLOWINFO_SEND",
+ linux.IPV6_FLOWLABEL_MGR: "IPV6_FLOWLABEL_MGR",
+ linux.IPV6_FREEBIND: "IPV6_FREEBIND",
+ linux.IPV6_HOPOPTS: "IPV6_HOPOPTS",
+ linux.IPV6_MINHOPCOUNT: "IPV6_MINHOPCOUNT",
+ linux.IPV6_MTU: "IPV6_MTU",
+ linux.IPV6_MTU_DISCOVER: "IPV6_MTU_DISCOVER",
+ linux.IPV6_MULTICAST_ALL: "IPV6_MULTICAST_ALL",
+ linux.IPV6_MULTICAST_HOPS: "IPV6_MULTICAST_HOPS",
+ linux.IPV6_MULTICAST_IF: "IPV6_MULTICAST_IF",
+ linux.IPV6_MULTICAST_LOOP: "IPV6_MULTICAST_LOOP",
+ linux.IPV6_RECVDSTOPTS: "IPV6_RECVDSTOPTS",
+ linux.IPV6_RECVERR: "IPV6_RECVERR",
+ linux.IPV6_RECVFRAGSIZE: "IPV6_RECVFRAGSIZE",
+ linux.IPV6_RECVHOPLIMIT: "IPV6_RECVHOPLIMIT",
+ linux.IPV6_RECVHOPOPTS: "IPV6_RECVHOPOPTS",
+ linux.IPV6_RECVORIGDSTADDR: "IPV6_RECVORIGDSTADDR",
+ linux.IPV6_RECVPATHMTU: "IPV6_RECVPATHMTU",
+ linux.IPV6_RECVPKTINFO: "IPV6_RECVPKTINFO",
+ linux.IPV6_RECVRTHDR: "IPV6_RECVRTHDR",
+ linux.IPV6_RECVTCLASS: "IPV6_RECVTCLASS",
+ linux.IPV6_RTHDR: "IPV6_RTHDR",
+ linux.IPV6_RTHDRDSTOPTS: "IPV6_RTHDRDSTOPTS",
+ linux.IPV6_TRANSPARENT: "IPV6_TRANSPARENT",
+ linux.IPV6_UNICAST_HOPS: "IPV6_UNICAST_HOPS",
+ linux.IPV6_UNICAST_IF: "IPV6_UNICAST_IF",
+ linux.MCAST_MSFILTER: "MCAST_MSFILTER",
+ linux.IPV6_ADDRFORM: "IPV6_ADDRFORM",
+ },
+ linux.SOL_NETLINK: {
+ linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR",
+ linux.NETLINK_CAP_ACK: "NETLINK_CAP_ACK",
+ linux.NETLINK_DUMP_STRICT_CHK: "NETLINK_DUMP_STRICT_CHK",
+ linux.NETLINK_EXT_ACK: "NETLINK_EXT_ACK",
+ linux.NETLINK_LIST_MEMBERSHIPS: "NETLINK_LIST_MEMBERSHIPS",
+ linux.NETLINK_NO_ENOBUFS: "NETLINK_NO_ENOBUFS",
+ linux.NETLINK_PKTINFO: "NETLINK_PKTINFO",
+ },
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
new file mode 100644
index 000000000..68ca537c8
--- /dev/null
+++ b/pkg/sentry/strace/strace.go
@@ -0,0 +1,874 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package strace implements the logic to print out the arguments and the
+// return value of each traced syscall.
+package strace
+
+import (
+ "encoding/binary"
+ "fmt"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ pb "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// DefaultLogMaximumSize is the default LogMaximumSize.
+const DefaultLogMaximumSize = 1024
+
+// LogMaximumSize determines the maximum display size for data blobs (read,
+// write, etc.).
+var LogMaximumSize uint = DefaultLogMaximumSize
+
+// EventMaximumSize determines the maximum size for data blobs (read, write,
+// etc.) sent over the event channel. The default is 0 because most clients
+// cannot do anything useful with a textual dump of binary byte array arguments.
+var EventMaximumSize uint
+
+// ItimerTypes are the possible itimer types.
+var ItimerTypes = abi.ValueSet{
+ linux.ITIMER_REAL: "ITIMER_REAL",
+ linux.ITIMER_VIRTUAL: "ITIMER_VIRTUAL",
+ linux.ITIMER_PROF: "ITIMER_PROF",
+}
+
+func hexNum(num uint64) string {
+ return "0x" + strconv.FormatUint(num, 16)
+}
+
+func hexArg(arg arch.SyscallArgument) string {
+ return hexNum(arg.Uint64())
+}
+
+func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, maxBytes uint64) string {
+ if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
+ return fmt.Sprintf("%#x (error decoding iovecs: invalid iovcnt)", addr)
+ }
+ ars, err := t.CopyInIovecs(addr, iovcnt)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding iovecs: %v)", addr, err)
+ }
+
+ var totalBytes uint64
+ var truncated bool
+ iovs := make([]string, iovcnt)
+ for i := 0; !ars.IsEmpty(); i, ars = i+1, ars.Tail() {
+ ar := ars.Head()
+ if ar.Length() == 0 || !printContent {
+ iovs[i] = fmt.Sprintf("{base=%#x, len=%d}", ar.Start, ar.Length())
+ continue
+ }
+
+ size := uint64(ar.Length())
+ if truncated || totalBytes+size > maxBytes {
+ truncated = true
+ size = maxBytes - totalBytes
+ } else {
+ totalBytes += uint64(ar.Length())
+ }
+
+ b := make([]byte, size)
+ amt, err := t.CopyIn(ar.Start, b)
+ if err != nil {
+ iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err)
+ continue
+ }
+
+ dot := ""
+ if truncated {
+ // Indicate truncation.
+ dot = "..."
+ }
+ iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q%s}", ar.Start, ar.Length(), b[:amt], dot)
+ }
+
+ return fmt.Sprintf("%#x %s", addr, strings.Join(iovs, ", "))
+}
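+
+// For illustration only (addresses and contents hypothetical), iovecs might
+// render a two-entry vector, truncated at maxBytes, as:
+//
+//	0x7f00dead0000 {base=0x7f00beef0000, len=5, "hello"}, {base=0x7f00beef1000, len=4, "wo"...}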
+
+func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) string {
+ origSize := size
+ if size > maximumBlobSize {
+ size = maximumBlobSize
+ }
+ if size == 0 {
+ return ""
+ }
+
+ b := make([]byte, size)
+ amt, err := t.CopyIn(addr, b)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding string: %s)", addr, err)
+ }
+
+ dot := ""
+ if uint(amt) < origSize {
+ // Append "..." if we truncated the dump.
+ dot = "..."
+ }
+
+ return fmt.Sprintf("%#x %q%s", addr, b[:amt], dot)
+}
+
+func path(t *kernel.Task, addr usermem.Addr) string {
+ path, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding path: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x %s", addr, path)
+}
+
+func fd(t *kernel.Task, fd int32) string {
+ if kernel.VFS2Enabled {
+ return fdVFS2(t, fd)
+ }
+
+ root := t.FSContext().RootDirectory()
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ if fd == linux.AT_FDCWD {
+ wd := t.FSContext().WorkingDirectory()
+ var name string
+ if wd != nil {
+ defer wd.DecRef()
+ name, _ = wd.FullName(root)
+ } else {
+ name = "(unknown cwd)"
+ }
+ return fmt.Sprintf("AT_FDCWD %s", name)
+ }
+
+ file := t.GetFile(fd)
+ if file == nil {
+ // Cast FD to uint64 to avoid printing negative hex.
+ return fmt.Sprintf("%#x (bad FD)", uint64(fd))
+ }
+ defer file.DecRef()
+
+ name, _ := file.Dirent.FullName(root)
+ return fmt.Sprintf("%#x %s", fd, name)
+}
+
+func fdVFS2(t *kernel.Task, fd int32) string {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+
+ vfsObj := root.Mount().Filesystem().VirtualFilesystem()
+ if fd == linux.AT_FDCWD {
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, wd)
+ return fmt.Sprintf("AT_FDCWD %s", name)
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // Cast FD to uint64 to avoid printing negative hex.
+ return fmt.Sprintf("%#x (bad FD)", uint64(fd))
+ }
+ defer file.DecRef()
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry())
+ return fmt.Sprintf("%#x %s", fd, name)
+}
+
+func fdpair(t *kernel.Task, addr usermem.Addr) string {
+ var fds [2]int32
+ _, err := t.CopyIn(addr, &fds)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x [%d %d]", addr, fds[0], fds[1])
+}
+
+func uname(t *kernel.Task, addr usermem.Addr) string {
+ var u linux.UtsName
+ if _, err := t.CopyIn(addr, &u); err != nil {
+ return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x %s", addr, u)
+}
+
+func utimensTimespec(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var tim linux.Timespec
+ if _, err := t.CopyIn(addr, &tim); err != nil {
+ return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
+ }
+
+ var ns string
+ switch tim.Nsec {
+ case linux.UTIME_NOW:
+ ns = "UTIME_NOW"
+ case linux.UTIME_OMIT:
+ ns = "UTIME_OMIT"
+ default:
+ ns = fmt.Sprintf("%v", tim.Nsec)
+ }
+ return fmt.Sprintf("%#x {sec=%v nsec=%s}", addr, tim.Sec, ns)
+}
+
+func timespec(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var tim linux.Timespec
+ if _, err := t.CopyIn(addr, &tim); err != nil {
+ return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec)
+}
+
+func timeval(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var tim linux.Timeval
+ if _, err := t.CopyIn(addr, &tim); err != nil {
+ return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x {sec=%v usec=%v}", addr, tim.Sec, tim.Usec)
+}
+
+func utimbuf(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var utim syscall.Utimbuf
+ if _, err := t.CopyIn(addr, &utim); err != nil {
+ return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x {actime=%v, modtime=%v}", addr, utim.Actime, utim.Modtime)
+}
+
+func stat(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var stat linux.Stat
+ if _, err := t.CopyIn(addr, &stat); err != nil {
+ return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec))
+}
+
+func itimerval(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ interval := timeval(t, addr)
+ value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{})))
+ return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
+}
+
+func itimerspec(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ interval := timespec(t, addr)
+ value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{})))
+ return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
+}
+
+func stringVector(t *kernel.Task, addr usermem.Addr) string {
+ vec, err := t.CopyInVector(addr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return fmt.Sprintf("%#x {error copying vector: %v}", addr, err)
+ }
+ s := fmt.Sprintf("%#x [", addr)
+ for i, v := range vec {
+ if i != 0 {
+ s += ", "
+ }
+ s += fmt.Sprintf("%q", v)
+ }
+ s += "]"
+ return s
+}
+
+func rusage(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var ru linux.Rusage
+ if _, err := t.CopyIn(addr, &ru); err != nil {
+ return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x %+v", addr, ru)
+}
+
+func capHeader(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(addr, &hdr); err != nil {
+ return fmt.Sprintf("%#x (error decoding header: %s)", addr, err)
+ }
+
+ var version string
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ version = "1"
+ case linux.LINUX_CAPABILITY_VERSION_2:
+ version = "2"
+ case linux.LINUX_CAPABILITY_VERSION_3:
+ version = "3"
+ default:
+ version = strconv.FormatUint(uint64(hdr.Version), 16)
+ }
+
+ return fmt.Sprintf("%#x {Version: %s, Pid: %d}", addr, version, hdr.Pid)
+}
+
+func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
+ if dataAddr == 0 {
+ return "null"
+ }
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err)
+ }
+
+ var p, i, e uint64
+
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ var data linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
+ }
+ p = uint64(data.Permitted)
+ i = uint64(data.Inheritable)
+ e = uint64(data.Effective)
+ case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
+ var data [2]linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
+ }
+ p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32)
+ i = uint64(data[0].Inheritable) | (uint64(data[1].Inheritable) << 32)
+ e = uint64(data[0].Effective) | (uint64(data[1].Effective) << 32)
+ default:
+ return fmt.Sprintf("%#x (unknown version %d)", dataAddr, hdr.Version)
+ }
+
+ return fmt.Sprintf("%#x {Permitted: %s, Inheritable: %s, Effective: %s}", dataAddr, CapabilityBitset.Parse(p), CapabilityBitset.Parse(i), CapabilityBitset.Parse(e))
+}
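+
+// As a hypothetical example (bitset rendering approximate), capData might
+// produce:
+//
+//	0x7f0000001000 {Permitted: CAP_CHOWN|CAP_KILL, Inheritable: 0x0, Effective: CAP_CHOWN}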
+
+// pre fills in the pre-execution arguments for a system call. If an argument
+// cannot be interpreted before the system call is executed, then a hex value
+// will be used. Note that the output slice contains one entry per formatted
+// argument, that is len(return) == min(len(args), len(i.format)).
+func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlobSize uint) []string {
+ var output []string
+
+ for arg := range args {
+ if arg >= len(i.format) {
+ break
+ }
+ switch i.format[arg] {
+ case FD:
+ output = append(output, fd(t, args[arg].Int()))
+ case WriteBuffer:
+ output = append(output, dump(t, args[arg].Pointer(), args[arg+1].SizeT(), maximumBlobSize))
+ case WriteIOVec:
+ output = append(output, iovecs(t, args[arg].Pointer(), int(args[arg+1].Int()), true /* content */, uint64(maximumBlobSize)))
+ case IOVec:
+ output = append(output, iovecs(t, args[arg].Pointer(), int(args[arg+1].Int()), false /* content */, uint64(maximumBlobSize)))
+ case SendMsgHdr:
+ output = append(output, msghdr(t, args[arg].Pointer(), true /* content */, uint64(maximumBlobSize)))
+ case RecvMsgHdr:
+ output = append(output, msghdr(t, args[arg].Pointer(), false /* content */, uint64(maximumBlobSize)))
+ case Path:
+ output = append(output, path(t, args[arg].Pointer()))
+ case ExecveStringVector:
+ output = append(output, stringVector(t, args[arg].Pointer()))
+ case SetSockOptVal:
+ output = append(output, sockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Uint64() /* optLen */, maximumBlobSize))
+ case SockOptLevel:
+ output = append(output, sockOptLevels.Parse(args[arg].Uint64()))
+ case SockOptName:
+ output = append(output, sockOptNames[args[arg-1].Uint64() /* level */].Parse(args[arg].Uint64()))
+ case SockAddr:
+ output = append(output, sockAddr(t, args[arg].Pointer(), uint32(args[arg+1].Uint64())))
+ case SockLen:
+ output = append(output, sockLenPointer(t, args[arg].Pointer()))
+ case SockFamily:
+ output = append(output, SocketFamily.Parse(uint64(args[arg].Int())))
+ case SockType:
+ output = append(output, sockType(args[arg].Int()))
+ case SockProtocol:
+ output = append(output, sockProtocol(args[arg-2].Int(), args[arg].Int()))
+ case SockFlags:
+ output = append(output, sockFlags(args[arg].Int()))
+ case Timespec:
+ output = append(output, timespec(t, args[arg].Pointer()))
+ case UTimeTimespec:
+ output = append(output, utimensTimespec(t, args[arg].Pointer()))
+ case ItimerVal:
+ output = append(output, itimerval(t, args[arg].Pointer()))
+ case ItimerSpec:
+ output = append(output, itimerspec(t, args[arg].Pointer()))
+ case Timeval:
+ output = append(output, timeval(t, args[arg].Pointer()))
+ case Utimbuf:
+ output = append(output, utimbuf(t, args[arg].Pointer()))
+ case CloneFlags:
+ output = append(output, CloneFlagSet.Parse(uint64(args[arg].Uint())))
+ case OpenFlags:
+ output = append(output, open(uint64(args[arg].Uint())))
+ case Mode:
+ output = append(output, linux.FileMode(args[arg].ModeT()).String())
+ case FutexOp:
+ output = append(output, futex(uint64(args[arg].Uint())))
+ case PtraceRequest:
+ output = append(output, PtraceRequestSet.Parse(args[arg].Uint64()))
+ case ItimerType:
+ output = append(output, ItimerTypes.Parse(uint64(args[arg].Int())))
+ case Signal:
+ output = append(output, signalNames.ParseDecimal(args[arg].Uint64()))
+ case SignalMaskAction:
+ output = append(output, signalMaskActions.Parse(uint64(args[arg].Int())))
+ case SigSet:
+ output = append(output, sigSet(t, args[arg].Pointer()))
+ case SigAction:
+ output = append(output, sigAction(t, args[arg].Pointer()))
+ case CapHeader:
+ output = append(output, capHeader(t, args[arg].Pointer()))
+ case CapData:
+ output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
+ case PollFDs:
+ output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
+ case SelectFDSet:
+ output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
+ case Oct:
+ output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
+ case Hex:
+ fallthrough
+ default:
+ output = append(output, hexArg(args[arg]))
+ }
+ }
+
+ return output
+}
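+
+// As a sketch (syscall and values hypothetical): with a format of
+// {FD, WriteBuffer, Hex}, pre on write(1, "hi", 2) would produce
+// {"0x1 /dev/tty0", "0x400000 \"hi\"", "0x2"}.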
+
+// post fills in the post-execution arguments for a system call. This modifies
+// the given output slice in place with arguments that may only be interpreted
+// after the system call has been executed.
+func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uintptr, output []string, maximumBlobSize uint) {
+ for arg := range output {
+ if arg >= len(i.format) {
+ break
+ }
+ switch i.format[arg] {
+ case ReadBuffer:
+ output[arg] = dump(t, args[arg].Pointer(), uint(rval), maximumBlobSize)
+ case ReadIOVec:
+ printLength := uint64(rval)
+ if printLength > uint64(maximumBlobSize) {
+ printLength = uint64(maximumBlobSize)
+ }
+ output[arg] = iovecs(t, args[arg].Pointer(), int(args[arg+1].Int()), true /* content */, printLength)
+ case WriteIOVec, IOVec, WriteBuffer:
+ // The contents were already printed at syscall entry; elide them here.
+ output[arg] = "..."
+ case SendMsgHdr:
+ output[arg] = msghdr(t, args[arg].Pointer(), false /* content */, uint64(maximumBlobSize))
+ case RecvMsgHdr:
+ output[arg] = msghdr(t, args[arg].Pointer(), true /* content */, uint64(maximumBlobSize))
+ case PostPath:
+ output[arg] = path(t, args[arg].Pointer())
+ case PipeFDs:
+ output[arg] = fdpair(t, args[arg].Pointer())
+ case Uname:
+ output[arg] = uname(t, args[arg].Pointer())
+ case Stat:
+ output[arg] = stat(t, args[arg].Pointer())
+ case PostSockAddr:
+ output[arg] = postSockAddr(t, args[arg].Pointer(), args[arg+1].Pointer())
+ case SockLen:
+ output[arg] = sockLenPointer(t, args[arg].Pointer())
+ case PostTimespec:
+ output[arg] = timespec(t, args[arg].Pointer())
+ case PostItimerVal:
+ output[arg] = itimerval(t, args[arg].Pointer())
+ case PostItimerSpec:
+ output[arg] = itimerspec(t, args[arg].Pointer())
+ case Timeval:
+ output[arg] = timeval(t, args[arg].Pointer())
+ case Rusage:
+ output[arg] = rusage(t, args[arg].Pointer())
+ case PostSigSet:
+ output[arg] = sigSet(t, args[arg].Pointer())
+ case PostSigAction:
+ output[arg] = sigAction(t, args[arg].Pointer())
+ case PostCapData:
+ output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
+ case PollFDs:
+ output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
+ case GetSockOptVal:
+ output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
+ case SetSockOptVal:
+ // No need to print the value again. The string version
+ // of this argument, while usually short, can be long.
+ output[arg] = hexArg(args[arg])
+ }
+ }
+}
+
+// printEnter prints the given system call entry.
+func (i *SyscallInfo) printEnter(t *kernel.Task, args arch.SyscallArguments) []string {
+ output := i.pre(t, args, LogMaximumSize)
+
+ switch len(output) {
+ case 0:
+ t.Infof("%s E %s()", t.Name(), i.name)
+ case 1:
+ t.Infof("%s E %s(%s)", t.Name(), i.name,
+ output[0])
+ case 2:
+ t.Infof("%s E %s(%s, %s)", t.Name(), i.name,
+ output[0], output[1])
+ case 3:
+ t.Infof("%s E %s(%s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2])
+ case 4:
+ t.Infof("%s E %s(%s, %s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2], output[3])
+ case 5:
+ t.Infof("%s E %s(%s, %s, %s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4])
+ case 6:
+ t.Infof("%s E %s(%s, %s, %s, %s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4], output[5])
+ }
+
+ return output
+}
+
+// printExit prints the given system call exit.
+func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output []string, args arch.SyscallArguments, retval uintptr, err error, errno int) {
+ var rval string
+ if err == nil {
+ // Fill in the output after successful execution.
+ i.post(t, args, retval, output, LogMaximumSize)
+ rval = fmt.Sprintf("%#x (%v)", retval, elapsed)
+ } else {
+ rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed)
+ }
+
+ switch len(output) {
+ case 0:
+ t.Infof("%s X %s() = %s", t.Name(), i.name,
+ rval)
+ case 1:
+ t.Infof("%s X %s(%s) = %s", t.Name(), i.name,
+ output[0], rval)
+ case 2:
+ t.Infof("%s X %s(%s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], rval)
+ case 3:
+ t.Infof("%s X %s(%s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], rval)
+ case 4:
+ t.Infof("%s X %s(%s, %s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], rval)
+ case 5:
+ t.Infof("%s X %s(%s, %s, %s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4], rval)
+ case 6:
+ t.Infof("%s X %s(%s, %s, %s, %s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4], output[5], rval)
+ }
+}
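+
+// Together, printEnter and printExit yield log pairs of the form (task name
+// and values hypothetical):
+//
+//	app E read(0x3 /tmp/data, 0x7f0000000000, 0x1000)
+//	app X read(0x3 /tmp/data, 0x7f0000000000 "abcd", 0x1000) = 0x4 (10.5µs)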
+
+// sendEnter sends the syscall enter to the event log.
+func (i *SyscallInfo) sendEnter(t *kernel.Task, args arch.SyscallArguments) []string {
+ output := i.pre(t, args, EventMaximumSize)
+
+ event := pb.Strace{
+ Process: t.Name(),
+ Function: i.name,
+ Info: &pb.Strace_Enter{
+ Enter: &pb.StraceEnter{},
+ },
+ }
+ for _, arg := range output {
+ event.Args = append(event.Args, arg)
+ }
+ eventchannel.Emit(&event)
+
+ return output
+}
+
+// sendExit sends the syscall exit to the event log.
+func (i *SyscallInfo) sendExit(t *kernel.Task, elapsed time.Duration, output []string, args arch.SyscallArguments, rval uintptr, err error, errno int) {
+ if err == nil {
+ // Fill in the output after successful execution.
+ i.post(t, args, rval, output, EventMaximumSize)
+ }
+
+ exit := &pb.StraceExit{
+ Return: fmt.Sprintf("%#x", rval),
+ ElapsedNs: elapsed.Nanoseconds(),
+ }
+ if err != nil {
+ exit.Error = err.Error()
+ exit.ErrNo = int64(errno)
+ }
+ event := pb.Strace{
+ Process: t.Name(),
+ Function: i.name,
+ Info: &pb.Strace_Exit{Exit: exit},
+ }
+ for _, arg := range output {
+ event.Args = append(event.Args, arg)
+ }
+ eventchannel.Emit(&event)
+}
+
+type syscallContext struct {
+ info SyscallInfo
+ args arch.SyscallArguments
+ start time.Time
+ logOutput []string
+ eventOutput []string
+ flags uint32
+}
+
+// SyscallEnter implements kernel.Stracer.SyscallEnter. It logs the syscall
+// entry trace.
+func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{} {
+ info, ok := s[sysno]
+ if !ok {
+ info = SyscallInfo{
+ name: fmt.Sprintf("sys_%d", sysno),
+ format: defaultFormat,
+ }
+ }
+
+ var output, eventOutput []string
+ if bits.IsOn32(flags, kernel.StraceEnableLog) {
+ output = info.printEnter(t, args)
+ }
+ if bits.IsOn32(flags, kernel.StraceEnableEvent) {
+ eventOutput = info.sendEnter(t, args)
+ }
+
+ return &syscallContext{
+ info: info,
+ args: args,
+ start: time.Now(),
+ logOutput: output,
+ eventOutput: eventOutput,
+ flags: flags,
+ }
+}
+
+// SyscallExit implements kernel.Stracer.SyscallExit. It logs the syscall
+// exit trace.
+func (s SyscallMap) SyscallExit(context interface{}, t *kernel.Task, sysno, rval uintptr, err error) {
+ errno := kernel.ExtractErrno(err, int(sysno))
+ c := context.(*syscallContext)
+
+ elapsed := time.Since(c.start)
+ if bits.IsOn32(c.flags, kernel.StraceEnableLog) {
+ c.info.printExit(t, elapsed, c.logOutput, c.args, rval, err, errno)
+ }
+ if bits.IsOn32(c.flags, kernel.StraceEnableEvent) {
+ c.info.sendExit(t, elapsed, c.eventOutput, c.args, rval, err, errno)
+ }
+}
+
+// ConvertToSysnoMap converts the given syscall names to a map keyed on the
+// syscall number, with each value set to true.
+//
+// The map is in a convenient format to pass to SyscallFlagsTable.Enable().
+func (s SyscallMap) ConvertToSysnoMap(syscalls []string) (map[uintptr]bool, error) {
+ if syscalls == nil {
+ // Sentinel: no list.
+ return nil, nil
+ }
+
+ l := make(map[uintptr]bool)
+ for _, sc := range syscalls {
+ // Try to match this system call.
+ sysno, ok := s.ConvertToSysno(sc)
+ if !ok {
+ return nil, fmt.Errorf("syscall %q not found", sc)
+ }
+ l[sysno] = true
+ }
+
+ // Success.
+ return l, nil
+}
+
+// ConvertToSysno converts the name to a system call number. Returns false
+// if a syscall with that name is not found.
+func (s SyscallMap) ConvertToSysno(syscall string) (uintptr, bool) {
+ for sysno, info := range s {
+ if info.name != "" && info.name == syscall {
+ return sysno, true
+ }
+ }
+ return 0, false
+}
+
+// Name returns the syscall name.
+func (s SyscallMap) Name(sysno uintptr) string {
+ if info, ok := s[sysno]; ok {
+ return info.name
+ }
+ return fmt.Sprintf("sys_%d", sysno)
+}
+
+// Initialize prepares all syscall tables for use by this package.
+//
+// N.B. This is not in an init function because we can't be sure all syscall
+// tables are registered with the kernel when init runs.
+func Initialize() {
+ for _, table := range kernel.SyscallTables() {
+ // Is this known?
+ sys, ok := Lookup(table.OS, table.Arch)
+ if !ok {
+ continue
+ }
+
+ table.Stracer = sys
+ }
+}
+
+// SinkType defines where straces are sent.
+type SinkType uint32
+
+const (
+ // SinkTypeLog sends straces to the text log.
+ SinkTypeLog SinkType = 1 << iota
+
+ // SinkTypeEvent sends straces to the event log.
+ SinkTypeEvent
+)
+
+func convertToSyscallFlag(sinks SinkType) uint32 {
+ ret := uint32(0)
+ if bits.IsOn32(uint32(sinks), uint32(SinkTypeLog)) {
+ ret |= kernel.StraceEnableLog
+ }
+ if bits.IsOn32(uint32(sinks), uint32(SinkTypeEvent)) {
+ ret |= kernel.StraceEnableEvent
+ }
+ return ret
+}
+
+// Enable enables the syscalls in whitelist in all syscall tables.
+//
+// Preconditions: Initialize has been called.
+func Enable(whitelist []string, sinks SinkType) error {
+ flags := convertToSyscallFlag(sinks)
+ for _, table := range kernel.SyscallTables() {
+ // Is this known?
+ sys, ok := Lookup(table.OS, table.Arch)
+ if !ok {
+ continue
+ }
+
+ // Convert to a set of system calls numbers.
+ wl, err := sys.ConvertToSysnoMap(whitelist)
+ if err != nil {
+ return err
+ }
+
+ table.FeatureEnable.Enable(flags, wl, true)
+ }
+
+ // Done.
+ return nil
+}
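+
+// A minimal usage sketch (assumes syscall tables are already registered with
+// the kernel):
+//
+//	strace.Initialize()
+//	if err := strace.Enable([]string{"read", "write"}, strace.SinkTypeLog); err != nil {
+//		// Handle unknown syscall names.
+//	}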
+
+// Disable disables strace for all syscalls, including missing (unregistered)
+// syscalls.
+//
+// Preconditions: Initialize has been called.
+func Disable(sinks SinkType) {
+ flags := convertToSyscallFlag(sinks)
+ for _, table := range kernel.SyscallTables() {
+ // Strace will be disabled for all syscalls including missing.
+ table.FeatureEnable.Enable(flags, nil, false)
+ }
+}
+
+// EnableAll enables all syscalls in all syscall tables.
+//
+// Preconditions: Initialize has been called.
+func EnableAll(sinks SinkType) {
+ flags := convertToSyscallFlag(sinks)
+ for _, table := range kernel.SyscallTables() {
+ // Is this known?
+ if _, ok := Lookup(table.OS, table.Arch); !ok {
+ continue
+ }
+
+ table.FeatureEnable.EnableAll(flags)
+ }
+}
+
+func init() {
+ t, ok := Lookup(abi.Host, arch.Host)
+ if ok {
+ // Provide the native table as the lookup for seccomp
+ // debugging. This is best-effort, and done this way to avoid
+ // a dependency from seccomp on this package.
+ seccomp.SyscallName = t.Name
+ }
+}
diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto
new file mode 100644
index 000000000..906c52c51
--- /dev/null
+++ b/pkg/sentry/strace/strace.proto
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+message Strace {
+ // Process name that made the syscall.
+ string process = 1;
+
+ // Syscall function name.
+ string function = 2;
+
+ // List of syscall arguments formatted as strings.
+ repeated string args = 3;
+
+ oneof info {
+ StraceEnter enter = 4;
+ StraceExit exit = 5;
+ }
+}
+
+message StraceEnter {}
+
+message StraceExit {
+ // Return value formatted as string.
+ string return = 1;
+
+ // Formatted error string in case syscall failed.
+ string error = 2;
+
+ // Value of errno upon syscall exit.
+ int64 err_no = 3; // errno is a macro and gets expanded :-(
+
+ // Time elapsed between syscall enter and exit.
+ int64 elapsed_ns = 4;
+}
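+
+// For illustration (field values hypothetical), an exit event rendered as a
+// text proto:
+//
+//	process: "app"
+//	function: "read"
+//	args: "0x3 /tmp/data"
+//	exit {
+//	  return: "0x4"
+//	  elapsed_ns: 10500
+//	}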
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
new file mode 100644
index 000000000..7e69b9279
--- /dev/null
+++ b/pkg/sentry/strace/syscalls.go
@@ -0,0 +1,292 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// FormatSpecifier values describe how an individual syscall argument should be
+// formatted.
+type FormatSpecifier int
+
+// Valid FormatSpecifiers.
+//
+// Unless otherwise specified, values are formatted before syscall execution
+// and not updated after syscall execution (the same value is output).
+const (
+ // Hex is just a hexadecimal number.
+ Hex FormatSpecifier = iota
+
+ // Oct is just an octal number.
+ Oct
+
+ // FD is a file descriptor.
+ FD
+
+ // ReadBuffer is a buffer for a read-style call. The syscall return
+ // value is used for the length.
+ //
+ // Formatted after syscall execution.
+ ReadBuffer
+
+ // WriteBuffer is a buffer for a write-style call. The following arg is
+ // used for the length.
+ //
+ // Contents omitted after syscall execution.
+ WriteBuffer
+
+ // ReadIOVec is a pointer to a struct iovec for a readv-style call.
+ // The following arg is used for the length. The return value is used
+ // for the total length.
+ //
+ // Complete contents only formatted after syscall execution.
+ ReadIOVec
+
+ // WriteIOVec is a pointer to a struct iovec for a writev-style call.
+ // The following arg is used for the length.
+ //
+ // Complete contents only formatted before syscall execution, omitted
+ // after.
+ WriteIOVec
+
+ // IOVec is a generic pointer to a struct iovec. Contents are not dumped.
+ IOVec
+
+ // SendMsgHdr is a pointer to a struct msghdr for a sendmsg-style call.
+ // Contents formatted only before syscall execution, omitted after.
+ SendMsgHdr
+
+ // RecvMsgHdr is a pointer to a struct msghdr for a recvmsg-style call.
+ // Contents formatted only after syscall execution.
+ RecvMsgHdr
+
+ // Path is a pointer to a char* path.
+ Path
+
+ // PostPath is a pointer to a char* path, formatted after syscall
+ // execution.
+ PostPath
+
+ // ExecveStringVector is a NULL-terminated array of strings. Enforces
+ // the maximum execve array length.
+ ExecveStringVector
+
+ // PipeFDs is an array of two FDs, formatted after syscall execution.
+ PipeFDs
+
+ // Uname is a pointer to a struct uname, formatted after syscall execution.
+ Uname
+
+ // Stat is a pointer to a struct stat, formatted after syscall execution.
+ Stat
+
+ // SockAddr is a pointer to a struct sockaddr. The following arg is
+ // used for length.
+ SockAddr
+
+ // PostSockAddr is a pointer to a struct sockaddr, formatted after
+ // syscall execution. The following arg is a pointer to the socklen_t
+ // length.
+ PostSockAddr
+
+ // SockLen is a pointer to a socklen_t, formatted before and after
+ // syscall execution.
+ SockLen
+
+ // SockFamily is a socket protocol family value.
+ SockFamily
+
+ // SockType is a socket type and flags value.
+ SockType
+
+ // SockProtocol is a socket protocol value. Argument n-2 is the socket
+ // protocol family.
+ SockProtocol
+
+ // SockFlags are socket flags.
+ SockFlags
+
+ // Timespec is a pointer to a struct timespec.
+ Timespec
+
+ // PostTimespec is a pointer to a struct timespec, formatted after
+ // syscall execution.
+ PostTimespec
+
+ // UTimeTimespec is a pointer to a struct timespec. Formatting includes
+ // UTIME_NOW and UTIME_OMIT.
+ UTimeTimespec
+
+ // ItimerVal is a pointer to a struct itimerval.
+ ItimerVal
+
+ // PostItimerVal is a pointer to a struct itimerval, formatted after
+ // syscall execution.
+ PostItimerVal
+
+ // ItimerSpec is a pointer to a struct itimerspec.
+ ItimerSpec
+
+ // PostItimerSpec is a pointer to a struct itimerspec, formatted after
+ // syscall execution.
+ PostItimerSpec
+
+ // Timeval is a pointer to a struct timeval, formatted before and after
+ // syscall execution.
+ Timeval
+
+ // Utimbuf is a pointer to a struct utimbuf.
+ Utimbuf
+
+ // Rusage is a struct rusage, formatted after syscall execution.
+ Rusage
+
+ // CloneFlags are clone(2) flags.
+ CloneFlags
+
+ // OpenFlags are open(2) flags.
+ OpenFlags
+
+ // Mode is a mode_t.
+ Mode
+
+ // FutexOp is the futex(2) operation.
+ FutexOp
+
+ // PtraceRequest is the ptrace(2) request.
+ PtraceRequest
+
+ // ItimerType is an itimer type (ITIMER_REAL, etc).
+ ItimerType
+
+ // Signal is a signal number.
+ Signal
+
+ // SignalMaskAction is a signal mask action passed to rt_sigprocmask(2).
+ SignalMaskAction
+
+ // SigSet is a signal set.
+ SigSet
+
+ // PostSigSet is a signal set, formatted after syscall execution.
+ PostSigSet
+
+ // SigAction is a struct sigaction.
+ SigAction
+
+ // PostSigAction is a struct sigaction, formatted after syscall execution.
+ PostSigAction
+
+ // CapHeader is a cap_user_header_t.
+ CapHeader
+
+ // CapData is the data argument to capget(2)/capset(2). The previous
+ // argument must be CapHeader.
+ CapData
+
+ // PostCapData is the data argument to capget(2)/capset(2), formatted
+ // after syscall execution. The previous argument must be CapHeader.
+ PostCapData
+
+ // PollFDs is an array of struct pollfd. The number of entries in the
+ // array is in the next argument.
+ PollFDs
+
+ // SelectFDSet is an fd_set argument in select(2)/pselect(2). The
+ // number of FDs represented must be the first argument.
+ SelectFDSet
+
+ // GetSockOptVal is the optval argument in getsockopt(2).
+ //
+ // Formatted after syscall execution.
+ GetSockOptVal
+
+ // SetSockOptVal is the optval argument in setsockopt(2).
+ //
+ // Contents omitted after syscall execution.
+ SetSockOptVal
+
+ // SockOptLevel is the level argument in getsockopt(2) and
+ // setsockopt(2).
+ SockOptLevel
+
+ // SockOptName is the optname argument in getsockopt(2) and
+ // setsockopt(2).
+ SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
+)
+
+// defaultFormat is the syscall argument format to use if the actual format is
+// not known. It formats all six arguments as hex.
+var defaultFormat = []FormatSpecifier{Hex, Hex, Hex, Hex, Hex, Hex}
+
+// SyscallInfo captures the name and printing format of a syscall.
+type SyscallInfo struct {
+ // name is the name of the syscall.
+ name string
+
+ // format contains the format specifiers for each argument.
+ //
+ // Syscalls can have up to six arguments. Arguments without a
+ // corresponding entry in format will not be printed.
+ format []FormatSpecifier
+}
+
+// makeSyscallInfo returns a SyscallInfo for a syscall.
+func makeSyscallInfo(name string, f ...FormatSpecifier) SyscallInfo {
+ return SyscallInfo{name: name, format: f}
+}
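+
+// For example, a strace table entry for write(2) could be declared as:
+//
+//	makeSyscallInfo("write", FD, WriteBuffer, Hex)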
+
+// SyscallMap maps syscalls into names and printing formats.
+type SyscallMap map[uintptr]SyscallInfo
+
+var _ kernel.Stracer = (SyscallMap)(nil)
+
+// syscallTable contains the syscalls for a specific OS/Arch.
+type syscallTable struct {
+ // os is the operating system this table targets.
+ os abi.OS
+
+ // arch is the architecture this table targets.
+ arch arch.Arch
+
+ // syscalls contains the syscall mappings.
+ syscalls SyscallMap
+}
+
+var syscallTables []syscallTable
+
+// Lookup returns the SyscallMap for the OS/Arch combination. The returned map
+// must not be changed.
+func Lookup(os abi.OS, a arch.Arch) (SyscallMap, bool) {
+ for _, s := range syscallTables {
+ if s.os == os && s.arch == a {
+ return s.syscalls, true
+ }
+ }
+ return nil, false
+}
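+
+// Callers typically fetch the map once per table, e.g. (sketch):
+//
+//	if sys, ok := Lookup(abi.Linux, arch.AMD64); ok {
+//		name := sys.Name(0) // "read" on amd64
+//	}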
diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD
new file mode 100644
index 000000000..b8d1bd415
--- /dev/null
+++ b/pkg/sentry/syscalls/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "syscalls",
+ srcs = [
+ "epoll.go",
+ "syscalls.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/epoll",
+ "//pkg/sentry/kernel/time",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
new file mode 100644
index 000000000..d9fb808c0
--- /dev/null
+++ b/pkg/sentry/syscalls/epoll.go
@@ -0,0 +1,173 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syscalls
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// CreateEpoll implements the epoll_create(2) linux syscall.
+func CreateEpoll(t *kernel.Task, closeOnExec bool) (int32, error) {
+ file := epoll.NewEventPoll(t)
+ defer file.DecRef()
+
+ fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: closeOnExec,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ return fd, nil
+}
+
+// AddEpoll implements the epoll_ctl(2) linux syscall when op is EPOLL_CTL_ADD.
+func AddEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, mask waiter.EventMask, userData [2]int32) error {
+ // Get epoll from the file descriptor.
+ epollfile := t.GetFile(epfd)
+ if epollfile == nil {
+ return syserror.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Get the target file id.
+ file := t.GetFile(fd)
+ if file == nil {
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the EventPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return syserror.EBADF
+ }
+
+ // Try to add the entry.
+ return e.AddEntry(epoll.FileIdentifier{file, fd}, flags, mask, userData)
+}
+
+// UpdateEpoll implements the epoll_ctl(2) linux syscall when op is EPOLL_CTL_MOD.
+func UpdateEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, mask waiter.EventMask, userData [2]int32) error {
+ // Get epoll from the file descriptor.
+ epollfile := t.GetFile(epfd)
+ if epollfile == nil {
+ return syserror.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Get the target file id.
+ file := t.GetFile(fd)
+ if file == nil {
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the EventPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return syserror.EBADF
+ }
+
+ // Try to update the entry.
+ return e.UpdateEntry(epoll.FileIdentifier{file, fd}, flags, mask, userData)
+}
+
+// RemoveEpoll implements the epoll_ctl(2) linux syscall when op is EPOLL_CTL_DEL.
+func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
+ // Get epoll from the file descriptor.
+ epollfile := t.GetFile(epfd)
+ if epollfile == nil {
+ return syserror.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Get the target file id.
+ file := t.GetFile(fd)
+ if file == nil {
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the EventPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return syserror.EBADF
+ }
+
+ // Try to remove the entry.
+ return e.RemoveEntry(epoll.FileIdentifier{file, fd})
+}
+
+// WaitEpoll implements the epoll_wait(2) linux syscall.
+func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) {
+ // Get epoll from the file descriptor.
+ epollfile := t.GetFile(fd)
+ if epollfile == nil {
+ return nil, syserror.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Extract the EventPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return nil, syserror.EBADF
+ }
+
+ // Try to read events and return right away if we got them or if the
+ // caller requested a non-blocking "wait".
+ r := e.ReadEvents(max)
+ if len(r) != 0 || timeout == 0 {
+ return r, nil
+ }
+
+ // We'll have to wait. Set up the timer if a timeout was specified
+ // and register with the epoll object for readability events.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if timeout > 0 {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+
+ w, ch := waiter.NewChannelEntry(nil)
+ e.EventRegister(&w, waiter.EventIn)
+ defer e.EventUnregister(&w)
+
+ // Try to read the events again until we succeed, time out, or get
+ // interrupted.
+ for {
+ r = e.ReadEvents(max)
+ if len(r) != 0 {
+ return r, nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+ }
+}
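+
+// For example (values hypothetical): WaitEpoll(t, fd, 64, 100) blocks for up
+// to 100ms and returns a nil slice and nil error on timeout, mirroring
+// epoll_wait(2) returning 0.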
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
new file mode 100644
index 000000000..217fcfef2
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -0,0 +1,103 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "linux",
+ srcs = [
+ "error.go",
+ "flags.go",
+ "linux64.go",
+ "sigset.go",
+ "sys_aio.go",
+ "sys_capability.go",
+ "sys_clone_amd64.go",
+ "sys_clone_arm64.go",
+ "sys_epoll.go",
+ "sys_eventfd.go",
+ "sys_file.go",
+ "sys_futex.go",
+ "sys_getdents.go",
+ "sys_identity.go",
+ "sys_inotify.go",
+ "sys_lseek.go",
+ "sys_mempolicy.go",
+ "sys_mmap.go",
+ "sys_mount.go",
+ "sys_pipe.go",
+ "sys_poll.go",
+ "sys_prctl.go",
+ "sys_random.go",
+ "sys_read.go",
+ "sys_rlimit.go",
+ "sys_rseq.go",
+ "sys_rusage.go",
+ "sys_sched.go",
+ "sys_seccomp.go",
+ "sys_sem.go",
+ "sys_shm.go",
+ "sys_signal.go",
+ "sys_socket.go",
+ "sys_splice.go",
+ "sys_stat.go",
+ "sys_stat_amd64.go",
+ "sys_stat_arm64.go",
+ "sys_sync.go",
+ "sys_sysinfo.go",
+ "sys_syslog.go",
+ "sys_thread.go",
+ "sys_time.go",
+ "sys_timer.go",
+ "sys_timerfd.go",
+ "sys_tls_amd64.go",
+ "sys_tls_arm64.go",
+ "sys_utsname.go",
+ "sys_write.go",
+ "sys_xattr.go",
+ "timespec.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi",
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/bpf",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/rand",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fs/timerfd",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/epoll",
+ "//pkg/sentry/kernel/eventfd",
+ "//pkg/sentry/kernel/fasync",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/kernel/shm",
+ "//pkg/sentry/kernel/signalfd",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/syscalls",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
new file mode 100644
index 000000000..64de56ac5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -0,0 +1,157 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var (
+ partialResultMetric = metric.MustCreateNewUint64Metric("/syscalls/partial_result", true /* sync */, "Whether or not a partial result has occurred for this sandbox.")
+ partialResultOnce sync.Once
+)
+
+// HandleIOErrorVFS2 handles special error cases for partial results. For some
+// errors, we may consume the error and return only the partial read/write.
+//
+// op and f are used only for debug logging.
+func HandleIOErrorVFS2(t *kernel.Task, partialResult bool, err, intr error, op string, f *vfs.FileDescription) error {
+ known, err := handleIOErrorImpl(t, partialResult, err, intr, op)
+ if err != nil {
+ return err
+ }
+ if !known {
+ // An unknown error is encountered with a partial read/write.
+ fs := f.Mount().Filesystem().VirtualFilesystem()
+ root := vfs.RootFromContext(t)
+ name, _ := fs.PathnameWithDeleted(t, root, f.VirtualDentry())
+ log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, err, err, op, name)
+ partialResultOnce.Do(partialResultMetric.Increment)
+ }
+ return nil
+}
+
+// handleIOError handles special error cases for partial results. For some
+// errors, we may consume the error and return only the partial read/write.
+//
+// op and f are used only for debug logging.
+func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op string, f *fs.File) error {
+ known, err := handleIOErrorImpl(t, partialResult, err, intr, op)
+ if err != nil {
+ return err
+ }
+ if !known {
+ // An unknown error is encountered with a partial read/write.
+ name, _ := f.Dirent.FullName(nil /* ignore chroot */)
+ log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, err, err, op, name, f.FileOperations)
+ partialResultOnce.Do(partialResultMetric.Increment)
+ }
+ return nil
+}
+
+// handleIOErrorImpl handles special error cases for partial results. For some
+// errors, we may consume the error and return only the partial read/write.
+//
+// Returns false if the error is unknown.
+func handleIOErrorImpl(t *kernel.Task, partialResult bool, err, intr error, op string) (bool, error) {
+ switch err {
+ case nil:
+ // Typical successful syscall.
+ return true, nil
+ case io.EOF:
+ // EOF is always consumed. If this is a partial read/write
+ // (result != 0), the application will see that, otherwise
+ // they will see 0.
+ return true, nil
+ case syserror.ErrExceedsFileSizeLimit:
+ // Ignore partialResult because this error only applies to
+ // normal files, and for those files we cannot accumulate
+ // write results.
+ //
+ // Do not consume the error: return it as EFBIG, and
+ // simultaneously send a SIGXFSZ per setrlimit(2).
+ t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
+ return true, syserror.EFBIG
+ case syserror.ErrInterrupted:
+ // The syscall was interrupted. Return nil if it completed
+ // partially, otherwise return the error code that the syscall
+ // needs (to indicate to the kernel what it should do).
+ if partialResult {
+ return true, nil
+ }
+ return true, intr
+ }
+
+ if !partialResult {
+ // Typical syscall error.
+ return true, err
+ }
+
+ switch err {
+ case syserror.EINTR:
+ // Syscall interrupted, but completed a partial
+ // read/write. Like ErrWouldBlock, since we have a
+ // partial read/write, we consume the error and return
+ // the partial result.
+ return true, nil
+ case syserror.EFAULT:
+ // EFAULT is only shown to the user if nothing was
+ // read/written. If we read something (this case), they see
+ // a partial read/write. They will then presumably try again
+ // with an incremented buffer, which will EFAULT with
+ // result == 0.
+ return true, nil
+ case syserror.EPIPE:
+ // Writes to a pipe or socket will return EPIPE if the other
+ // side is gone. The partial write is returned. EPIPE will be
+ // returned on the next call.
+ //
+ // TODO(gvisor.dev/issue/161): In some cases SIGPIPE should
+ // also be sent to the application.
+ return true, nil
+ case syserror.ENOSPC:
+ // Similar to EPIPE. Return what we wrote this time, and let
+ // ENOSPC be returned on the next call.
+ return true, nil
+ case syserror.ECONNRESET:
+ // For TCP sendfile connections, we may have a reset. But we
+ // should just return n as the result.
+ return true, nil
+ case syserror.ErrWouldBlock:
+ // Syscall would block, but completed a partial read/write.
+ // This case should only be returned by IssueIO for nonblocking
+ // files. Since we have a partial read/write, we consume
+ // ErrWouldBlock, returning the partial result.
+ return true, nil
+ }
+
+ switch err.(type) {
+ case kernel.SyscallRestartErrno:
+ // Identical to the EINTR case.
+ return true, nil
+ }
+
+ // Error is unknown and cannot be properly handled.
+ return false, nil
+}
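+
+// As a worked example (hypothetical): a write(2) that copies 512 bytes and
+// then faults on the user buffer is reported to the application as a
+// successful 512-byte write; only the retry with the same bad buffer fails
+// with EFAULT and result == 0.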
diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go
new file mode 100644
index 000000000..07961dad9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/flags.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// flagsToPermissions returns a PermMask derived from Linux open(2) flags.
+// This includes write permission if O_TRUNC is set in the mask.
+func flagsToPermissions(mask uint) (p fs.PermMask) {
+ if mask&linux.O_TRUNC != 0 {
+ p.Write = true
+ }
+ switch mask & linux.O_ACCMODE {
+ case linux.O_WRONLY:
+ p.Write = true
+ case linux.O_RDWR:
+ p.Write = true
+ p.Read = true
+ case linux.O_RDONLY:
+ p.Read = true
+ }
+ return
+}
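+
+// For instance, flagsToPermissions(linux.O_RDWR|linux.O_TRUNC) yields
+// fs.PermMask{Read: true, Write: true}.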
+
+// linuxToFlags converts Linux file flags to a FileFlags object.
+func linuxToFlags(mask uint) fs.FileFlags {
+ return fs.FileFlags{
+ Direct: mask&linux.O_DIRECT != 0,
+ DSync: mask&(linux.O_DSYNC|linux.O_SYNC) != 0,
+ Sync: mask&linux.O_SYNC != 0,
+ NonBlocking: mask&linux.O_NONBLOCK != 0,
+ Read: (mask & linux.O_ACCMODE) != linux.O_WRONLY,
+ Write: (mask & linux.O_ACCMODE) != linux.O_RDONLY,
+ Append: mask&linux.O_APPEND != 0,
+ Directory: mask&linux.O_DIRECTORY != 0,
+ Async: mask&linux.O_ASYNC != 0,
+ LargeFile: mask&linux.O_LARGEFILE != 0,
+ Truncate: mask&linux.O_TRUNC != 0,
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
new file mode 100644
index 000000000..ea4f9b1a7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -0,0 +1,736 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linux provides syscall tables for amd64 and arm64 Linux.
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // LinuxSysname is the OS name advertised by gVisor.
+ LinuxSysname = "Linux"
+
+ // LinuxRelease is the Linux release version number advertised by gVisor.
+ LinuxRelease = "4.4.0"
+
+ // LinuxVersion is the version info advertised by gVisor.
+ LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
+)
+
+// AMD64 is the table of the Linux amd64 syscall API, with the corresponding
+// syscall numbers from Linux 4.4.
+var AMD64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.AMD64,
+ Version: kernel.Version{
+ // Version 4.4 is chosen as a stable, longterm version of Linux, which
+ // guides the interface provided by this syscall table. The build
+ // version is that for a clean build with default kernel config, at 5
+ // minutes after v4.4 was tagged.
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
+ },
+ AuditNumber: linux.AUDIT_ARCH_X86_64,
+ Table: map[uintptr]kernel.Syscall{
+ 0: syscalls.Supported("read", Read),
+ 1: syscalls.Supported("write", Write),
+ 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil),
+ 3: syscalls.Supported("close", Close),
+ 4: syscalls.Supported("stat", Stat),
+ 5: syscalls.Supported("fstat", Fstat),
+ 6: syscalls.Supported("lstat", Lstat),
+ 7: syscalls.Supported("poll", Poll),
+ 8: syscalls.Supported("lseek", Lseek),
+ 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC, MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
+ 10: syscalls.Supported("mprotect", Mprotect),
+ 11: syscalls.Supported("munmap", Munmap),
+ 12: syscalls.Supported("brk", Brk),
+ 13: syscalls.Supported("rt_sigaction", RtSigaction),
+ 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
+ 15: syscalls.Supported("rt_sigreturn", RtSigreturn),
+ 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
+ 17: syscalls.Supported("pread64", Pread64),
+ 18: syscalls.Supported("pwrite64", Pwrite64),
+ 19: syscalls.Supported("readv", Readv),
+ 20: syscalls.Supported("writev", Writev),
+ 21: syscalls.Supported("access", Access),
+ 22: syscalls.Supported("pipe", Pipe),
+ 23: syscalls.Supported("select", Select),
+ 24: syscalls.Supported("sched_yield", SchedYield),
+ 25: syscalls.Supported("mremap", Mremap),
+ 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
+ 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
+ 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
+ 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
+ 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
+ 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
+ 32: syscalls.Supported("dup", Dup),
+ 33: syscalls.Supported("dup2", Dup2),
+ 34: syscalls.Supported("pause", Pause),
+ 35: syscalls.Supported("nanosleep", Nanosleep),
+ 36: syscalls.Supported("getitimer", Getitimer),
+ 37: syscalls.Supported("alarm", Alarm),
+ 38: syscalls.Supported("setitimer", Setitimer),
+ 39: syscalls.Supported("getpid", Getpid),
+ 40: syscalls.Supported("sendfile", Sendfile),
+ 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
+ 42: syscalls.Supported("connect", Connect),
+ 43: syscalls.Supported("accept", Accept),
+ 44: syscalls.Supported("sendto", SendTo),
+ 45: syscalls.Supported("recvfrom", RecvFrom),
+ 46: syscalls.Supported("sendmsg", SendMsg),
+ 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
+ 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
+ 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
+ 50: syscalls.Supported("listen", Listen),
+ 51: syscalls.Supported("getsockname", GetSockName),
+ 52: syscalls.Supported("getpeername", GetPeerName),
+ 53: syscalls.Supported("socketpair", SocketPair),
+ 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
+ 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
+ 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
+ 57: syscalls.Supported("fork", Fork),
+ 58: syscalls.Supported("vfork", Vfork),
+ 59: syscalls.Supported("execve", Execve),
+ 60: syscalls.Supported("exit", Exit),
+ 61: syscalls.Supported("wait4", Wait4),
+ 62: syscalls.Supported("kill", Kill),
+ 63: syscalls.Supported("uname", Uname),
+ 64: syscalls.Supported("semget", Semget),
+ 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 67: syscalls.Supported("shmdt", Shmdt),
+ 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
+ 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
+ 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
+ 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
+ 76: syscalls.Supported("truncate", Truncate),
+ 77: syscalls.Supported("ftruncate", Ftruncate),
+ 78: syscalls.Supported("getdents", Getdents),
+ 79: syscalls.Supported("getcwd", Getcwd),
+ 80: syscalls.Supported("chdir", Chdir),
+ 81: syscalls.Supported("fchdir", Fchdir),
+ 82: syscalls.Supported("rename", Rename),
+ 83: syscalls.Supported("mkdir", Mkdir),
+ 84: syscalls.Supported("rmdir", Rmdir),
+ 85: syscalls.Supported("creat", Creat),
+ 86: syscalls.Supported("link", Link),
+ 87: syscalls.Supported("unlink", Unlink),
+ 88: syscalls.Supported("symlink", Symlink),
+ 89: syscalls.Supported("readlink", Readlink),
+ 90: syscalls.Supported("chmod", Chmod),
+ 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
+ 92: syscalls.Supported("chown", Chown),
+ 93: syscalls.Supported("fchown", Fchown),
+ 94: syscalls.Supported("lchown", Lchown),
+ 95: syscalls.Supported("umask", Umask),
+ 96: syscalls.Supported("gettimeofday", Gettimeofday),
+ 97: syscalls.Supported("getrlimit", Getrlimit),
+ 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
+ 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
+ 100: syscalls.Supported("times", Times),
+ 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
+ 102: syscalls.Supported("getuid", Getuid),
+ 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
+ 104: syscalls.Supported("getgid", Getgid),
+ 105: syscalls.Supported("setuid", Setuid),
+ 106: syscalls.Supported("setgid", Setgid),
+ 107: syscalls.Supported("geteuid", Geteuid),
+ 108: syscalls.Supported("getegid", Getegid),
+ 109: syscalls.Supported("setpgid", Setpgid),
+ 110: syscalls.Supported("getppid", Getppid),
+ 111: syscalls.Supported("getpgrp", Getpgrp),
+ 112: syscalls.Supported("setsid", Setsid),
+ 113: syscalls.Supported("setreuid", Setreuid),
+ 114: syscalls.Supported("setregid", Setregid),
+ 115: syscalls.Supported("getgroups", Getgroups),
+ 116: syscalls.Supported("setgroups", Setgroups),
+ 117: syscalls.Supported("setresuid", Setresuid),
+ 118: syscalls.Supported("getresuid", Getresuid),
+ 119: syscalls.Supported("setresgid", Setresgid),
+ 120: syscalls.Supported("getresgid", Getresgid),
+ 121: syscalls.Supported("getpgid", Getpgid),
+ 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 124: syscalls.Supported("getsid", Getsid),
+ 125: syscalls.Supported("capget", Capget),
+ 126: syscalls.Supported("capset", Capset),
+ 127: syscalls.Supported("rt_sigpending", RtSigpending),
+ 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
+ 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
+ 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
+ 131: syscalls.Supported("sigaltstack", Sigaltstack),
+ 132: syscalls.Supported("utime", Utime),
+ 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
+ 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete.", nil),
+ 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
+ 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
+ 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
+ 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
+ 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
+ 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
+ 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
+ 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
+ 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
+ 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
+ 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
+ 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
+ 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil),
+ 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
+ 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil),
+ 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
+ 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil),
+ 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
+ 161: syscalls.Supported("chroot", Chroot),
+ 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
+ 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
+ 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
+ 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 170: syscalls.Supported("sethostname", Sethostname),
+ 171: syscalls.Supported("setdomainname", Setdomainname),
+ 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil),
+ 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil),
+ 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
+ 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 186: syscalls.Supported("gettid", Gettid),
+ 187: syscalls.Supported("readahead", Readahead),
+ 188: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
+ 189: syscalls.PartiallySupported("lsetxattr", LSetXattr, "Only supported for tmpfs.", nil),
+ 190: syscalls.PartiallySupported("fsetxattr", FSetXattr, "Only supported for tmpfs.", nil),
+ 191: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
+ 192: syscalls.PartiallySupported("lgetxattr", LGetXattr, "Only supported for tmpfs.", nil),
+ 193: syscalls.PartiallySupported("fgetxattr", FGetXattr, "Only supported for tmpfs.", nil),
+ 194: syscalls.PartiallySupported("listxattr", ListXattr, "Only supported for tmpfs.", nil),
+ 195: syscalls.PartiallySupported("llistxattr", LListXattr, "Only supported for tmpfs.", nil),
+ 196: syscalls.PartiallySupported("flistxattr", FListXattr, "Only supported for tmpfs.", nil),
+ 197: syscalls.PartiallySupported("removexattr", RemoveXattr, "Only supported for tmpfs.", nil),
+ 198: syscalls.PartiallySupported("lremovexattr", LRemoveXattr, "Only supported for tmpfs.", nil),
+ 199: syscalls.PartiallySupported("fremovexattr", FRemoveXattr, "Only supported for tmpfs.", nil),
+ 200: syscalls.Supported("tkill", Tkill),
+ 201: syscalls.Supported("time", Time),
+ 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
+ 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
+ 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
+ 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 213: syscalls.Supported("epoll_create", EpollCreate),
+ 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil),
+ 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil),
+ 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 217: syscalls.Supported("getdents64", Getdents64),
+ 218: syscalls.Supported("set_tid_address", SetTidAddress),
+ 219: syscalls.Supported("restart_syscall", RestartSyscall),
+ 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
+ 222: syscalls.Supported("timer_create", TimerCreate),
+ 223: syscalls.Supported("timer_settime", TimerSettime),
+ 224: syscalls.Supported("timer_gettime", TimerGettime),
+ 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
+ 226: syscalls.Supported("timer_delete", TimerDelete),
+ 227: syscalls.Supported("clock_settime", ClockSettime),
+ 228: syscalls.Supported("clock_gettime", ClockGettime),
+ 229: syscalls.Supported("clock_getres", ClockGetres),
+ 230: syscalls.Supported("clock_nanosleep", ClockNanosleep),
+ 231: syscalls.Supported("exit_group", ExitGroup),
+ 232: syscalls.Supported("epoll_wait", EpollWait),
+ 233: syscalls.Supported("epoll_ctl", EpollCtl),
+ 234: syscalls.Supported("tgkill", Tgkill),
+ 235: syscalls.Supported("utimes", Utimes),
+ 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
+ 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
+ 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 247: syscalls.Supported("waitid", Waitid),
+ 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
+ 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
+ 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
+ 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil),
+ 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
+ 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
+ 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 257: syscalls.Supported("openat", Openat),
+ 258: syscalls.Supported("mkdirat", Mkdirat),
+ 259: syscalls.Supported("mknodat", Mknodat),
+ 260: syscalls.Supported("fchownat", Fchownat),
+ 261: syscalls.Supported("futimesat", Futimesat),
+ 262: syscalls.Supported("fstatat", Fstatat),
+ 263: syscalls.Supported("unlinkat", Unlinkat),
+ 264: syscalls.Supported("renameat", Renameat),
+ 265: syscalls.Supported("linkat", Linkat),
+ 266: syscalls.Supported("symlinkat", Symlinkat),
+ 267: syscalls.Supported("readlinkat", Readlinkat),
+ 268: syscalls.Supported("fchmodat", Fchmodat),
+ 269: syscalls.Supported("faccessat", Faccessat),
+ 270: syscalls.Supported("pselect", Pselect),
+ 271: syscalls.Supported("ppoll", Ppoll),
+ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
+ 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 275: syscalls.Supported("splice", Splice),
+ 276: syscalls.Supported("tee", Tee),
+ 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
+ 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 280: syscalls.Supported("utimensat", Utimensat),
+ 281: syscalls.Supported("epoll_pwait", EpollPwait),
+ 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 283: syscalls.Supported("timerfd_create", TimerfdCreate),
+ 284: syscalls.Supported("eventfd", Eventfd),
+ 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
+ 286: syscalls.Supported("timerfd_settime", TimerfdSettime),
+ 287: syscalls.Supported("timerfd_gettime", TimerfdGettime),
+ 288: syscalls.Supported("accept4", Accept4),
+ 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 290: syscalls.Supported("eventfd2", Eventfd2),
+ 291: syscalls.Supported("epoll_create1", EpollCreate1),
+ 292: syscalls.Supported("dup3", Dup3),
+ 293: syscalls.Supported("pipe2", Pipe2),
+ 294: syscalls.Supported("inotify_init1", InotifyInit1),
+ 295: syscalls.Supported("preadv", Preadv),
+ 296: syscalls.Supported("pwritev", Pwritev),
+ 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
+ 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
+ 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 302: syscalls.Supported("prlimit64", Prlimit64),
+ 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
+ 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
+ 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 309: syscalls.Supported("getcpu", Getcpu),
+ 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 317: syscalls.Supported("seccomp", Seccomp),
+ 318: syscalls.Supported("getrandom", GetRandom),
+ 319: syscalls.Supported("memfd_create", MemfdCreate),
+ 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
+ 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 322: syscalls.Supported("execveat", Execveat),
+ 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+
+ // Syscalls implemented after 325 are "backports" from versions
+ // of Linux after 4.4.
+ 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 327: syscalls.Supported("preadv2", Preadv2),
+ 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
+ 332: syscalls.Supported("statx", Statx),
+ 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ },
+ Emulate: map[usermem.Addr]uintptr{
+ 0xffffffffff600000: 96, // vsyscall gettimeofday(2)
+ 0xffffffffff600400: 201, // vsyscall time(2)
+ 0xffffffffff600800: 309, // vsyscall getcpu(2)
+ },
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
+
+// ARM64 is the table of the Linux arm64 syscall API, with the corresponding
+// syscall numbers from Linux 4.4.
+var ARM64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.ARM64,
+ Version: kernel.Version{
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
+ },
+ AuditNumber: linux.AUDIT_ARCH_AARCH64,
+ Table: map[uintptr]kernel.Syscall{
+ 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 5: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
+ 6: syscalls.PartiallySupported("lsetxattr", LSetXattr, "Only supported for tmpfs.", nil),
+ 7: syscalls.PartiallySupported("fsetxattr", FSetXattr, "Only supported for tmpfs.", nil),
+ 8: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
+ 9: syscalls.PartiallySupported("lgetxattr", LGetXattr, "Only supported for tmpfs.", nil),
+ 10: syscalls.PartiallySupported("fgetxattr", FGetXattr, "Only supported for tmpfs.", nil),
+ 11: syscalls.PartiallySupported("listxattr", ListXattr, "Only supported for tmpfs.", nil),
+ 12: syscalls.PartiallySupported("llistxattr", LListXattr, "Only supported for tmpfs.", nil),
+ 13: syscalls.PartiallySupported("flistxattr", FListXattr, "Only supported for tmpfs.", nil),
+ 14: syscalls.PartiallySupported("removexattr", RemoveXattr, "Only supported for tmpfs.", nil),
+ 15: syscalls.PartiallySupported("lremovexattr", LRemoveXattr, "Only supported for tmpfs.", nil),
+ 16: syscalls.PartiallySupported("fremovexattr", FRemoveXattr, "Only supported for tmpfs.", nil),
+ 17: syscalls.Supported("getcwd", Getcwd),
+ 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 19: syscalls.Supported("eventfd2", Eventfd2),
+ 20: syscalls.Supported("epoll_create1", EpollCreate1),
+ 21: syscalls.Supported("epoll_ctl", EpollCtl),
+ 22: syscalls.Supported("epoll_pwait", EpollPwait),
+ 23: syscalls.Supported("dup", Dup),
+ 24: syscalls.Supported("dup3", Dup3),
+ 25: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
+ 26: syscalls.Supported("inotify_init1", InotifyInit1),
+ 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
+ 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
+ 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
+ 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
+ 33: syscalls.Supported("mknodat", Mknodat),
+ 34: syscalls.Supported("mkdirat", Mkdirat),
+ 35: syscalls.Supported("unlinkat", Unlinkat),
+ 36: syscalls.Supported("symlinkat", Symlinkat),
+ 37: syscalls.Supported("linkat", Linkat),
+ 38: syscalls.Supported("renameat", Renameat),
+ 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
+ 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
+ 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
+ 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 43: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
+ 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
+ 45: syscalls.Supported("truncate", Truncate),
+ 46: syscalls.Supported("ftruncate", Ftruncate),
+ 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
+ 48: syscalls.Supported("faccessat", Faccessat),
+ 49: syscalls.Supported("chdir", Chdir),
+ 50: syscalls.Supported("fchdir", Fchdir),
+ 51: syscalls.Supported("chroot", Chroot),
+ 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
+ 53: syscalls.Supported("fchmodat", Fchmodat),
+ 54: syscalls.Supported("fchownat", Fchownat),
+ 55: syscalls.Supported("fchown", Fchown),
+ 56: syscalls.Supported("openat", Openat),
+ 57: syscalls.Supported("close", Close),
+ 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 59: syscalls.Supported("pipe2", Pipe2),
+ 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 61: syscalls.Supported("getdents64", Getdents64),
+ 62: syscalls.Supported("lseek", Lseek),
+ 63: syscalls.Supported("read", Read),
+ 64: syscalls.Supported("write", Write),
+ 65: syscalls.Supported("readv", Readv),
+ 66: syscalls.Supported("writev", Writev),
+ 67: syscalls.Supported("pread64", Pread64),
+ 68: syscalls.Supported("pwrite64", Pwrite64),
+ 69: syscalls.Supported("preadv", Preadv),
+ 70: syscalls.Supported("pwritev", Pwritev),
+ 71: syscalls.Supported("sendfile", Sendfile),
+ 72: syscalls.Supported("pselect", Pselect),
+ 73: syscalls.Supported("ppoll", Ppoll),
+ 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 77: syscalls.Supported("tee", Tee),
+ 78: syscalls.Supported("readlinkat", Readlinkat),
+ 79: syscalls.Supported("fstatat", Fstatat),
+ 80: syscalls.Supported("fstat", Fstat),
+ 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
+ 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
+ 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
+ 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
+ 85: syscalls.Supported("timerfd_create", TimerfdCreate),
+ 86: syscalls.Supported("timerfd_settime", TimerfdSettime),
+ 87: syscalls.Supported("timerfd_gettime", TimerfdGettime),
+ 88: syscalls.Supported("utimensat", Utimensat),
+ 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 90: syscalls.Supported("capget", Capget),
+ 91: syscalls.Supported("capset", Capset),
+ 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 93: syscalls.Supported("exit", Exit),
+ 94: syscalls.Supported("exit_group", ExitGroup),
+ 95: syscalls.Supported("waitid", Waitid),
+ 96: syscalls.Supported("set_tid_address", SetTidAddress),
+ 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
+ 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
+ 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 101: syscalls.Supported("nanosleep", Nanosleep),
+ 102: syscalls.Supported("getitimer", Getitimer),
+ 103: syscalls.Supported("setitimer", Setitimer),
+ 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 107: syscalls.Supported("timer_create", TimerCreate),
+ 108: syscalls.Supported("timer_gettime", TimerGettime),
+ 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
+ 110: syscalls.Supported("timer_settime", TimerSettime),
+ 111: syscalls.Supported("timer_delete", TimerDelete),
+ 112: syscalls.Supported("clock_settime", ClockSettime),
+ 113: syscalls.Supported("clock_gettime", ClockGettime),
+ 114: syscalls.Supported("clock_getres", ClockGetres),
+ 115: syscalls.Supported("clock_nanosleep", ClockNanosleep),
+ 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
+ 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
+ 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
+ 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
+ 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
+ 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
+ 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
+ 124: syscalls.Supported("sched_yield", SchedYield),
+ 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
+ 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
+ 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
+ 128: syscalls.Supported("restart_syscall", RestartSyscall),
+ 129: syscalls.Supported("kill", Kill),
+ 130: syscalls.Supported("tkill", Tkill),
+ 131: syscalls.Supported("tgkill", Tgkill),
+ 132: syscalls.Supported("sigaltstack", Sigaltstack),
+ 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
+ 134: syscalls.Supported("rt_sigaction", RtSigaction),
+ 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
+ 136: syscalls.Supported("rt_sigpending", RtSigpending),
+ 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
+ 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
+ 139: syscalls.Supported("rt_sigreturn", RtSigreturn),
+ 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
+ 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
+ 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 143: syscalls.Supported("setregid", Setregid),
+ 144: syscalls.Supported("setgid", Setgid),
+ 145: syscalls.Supported("setreuid", Setreuid),
+ 146: syscalls.Supported("setuid", Setuid),
+ 147: syscalls.Supported("setresuid", Setresuid),
+ 148: syscalls.Supported("getresuid", Getresuid),
+ 149: syscalls.Supported("setresgid", Setresgid),
+ 150: syscalls.Supported("getresgid", Getresgid),
+ 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 153: syscalls.Supported("times", Times),
+ 154: syscalls.Supported("setpgid", Setpgid),
+ 155: syscalls.Supported("getpgid", Getpgid),
+ 156: syscalls.Supported("getsid", Getsid),
+ 157: syscalls.Supported("setsid", Setsid),
+ 158: syscalls.Supported("getgroups", Getgroups),
+ 159: syscalls.Supported("setgroups", Setgroups),
+ 160: syscalls.Supported("uname", Uname),
+ 161: syscalls.Supported("sethostname", Sethostname),
+ 162: syscalls.Supported("setdomainname", Setdomainname),
+ 163: syscalls.Supported("getrlimit", Getrlimit),
+ 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
+ 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
+ 166: syscalls.Supported("umask", Umask),
+ 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
+ 168: syscalls.Supported("getcpu", Getcpu),
+ 169: syscalls.Supported("gettimeofday", Gettimeofday),
+ 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 172: syscalls.Supported("getpid", Getpid),
+ 173: syscalls.Supported("getppid", Getppid),
+ 174: syscalls.Supported("getuid", Getuid),
+ 175: syscalls.Supported("geteuid", Geteuid),
+ 176: syscalls.Supported("getgid", Getgid),
+ 177: syscalls.Supported("getegid", Getegid),
+ 178: syscalls.Supported("gettid", Gettid),
+ 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
+ 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 190: syscalls.Supported("semget", Semget),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
+ 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
+ 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
+ 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
+ 197: syscalls.Supported("shmdt", Shmdt),
+ 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
+ 199: syscalls.Supported("socketpair", SocketPair),
+ 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
+ 201: syscalls.Supported("listen", Listen),
+ 202: syscalls.Supported("accept", Accept),
+ 203: syscalls.Supported("connect", Connect),
+ 204: syscalls.Supported("getsockname", GetSockName),
+ 205: syscalls.Supported("getpeername", GetPeerName),
+ 206: syscalls.Supported("sendto", SendTo),
+ 207: syscalls.Supported("recvfrom", RecvFrom),
+ 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
+ 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
+ 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
+ 211: syscalls.Supported("sendmsg", SendMsg),
+ 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
+ 213: syscalls.Supported("readahead", Readahead),
+ 214: syscalls.Supported("brk", Brk),
+ 215: syscalls.Supported("munmap", Munmap),
+ 216: syscalls.Supported("mremap", Mremap),
+ 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
+ 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
+ 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
+ 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
+ 221: syscalls.Supported("execve", Execve),
+ 222: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC, MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
+ 223: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
+ 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 226: syscalls.Supported("mprotect", Mprotect),
+ 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
+ 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
+ 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
+ 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
+ 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
+ 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
+ 242: syscalls.Supported("accept4", Accept4),
+ 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
+ 260: syscalls.Supported("wait4", Wait4),
+ 261: syscalls.Supported("prlimit64", Prlimit64),
+ 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
+ 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
+ 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 277: syscalls.Supported("seccomp", Seccomp),
+ 278: syscalls.Supported("getrandom", GetRandom),
+ 279: syscalls.Supported("memfd_create", MemfdCreate),
+ 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 281: syscalls.Supported("execveat", Execveat),
+ 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+
+ // Syscalls after 284 are "backports" from versions of Linux after 4.4.
+ 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 286: syscalls.Supported("preadv2", Preadv2),
+ 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
+ 291: syscalls.Supported("statx", Statx),
+ 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ },
+ Emulate: map[usermem.Addr]uintptr{},
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
+
+func init() {
+ kernel.RegisterSyscallTable(AMD64)
+ kernel.RegisterSyscallTable(ARM64)
+}
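The Emulate map above is what keeps the legacy x86-64 vsyscall page working: the three fixed addresses are rewritten to ordinary syscall numbers, while any number absent from Table falls through to Missing, which records an unimplemented-syscall event and returns ENOSYS. A minimal sketch of the address lookup follows, using a hypothetical helper assumed to live alongside the table in this package:

// vsyscallNumber is a hypothetical illustration of the Emulate map; it is
// not part of the change.
func vsyscallNumber(addr usermem.Addr) (uintptr, bool) {
	// 0xffffffffff600000 -> 96 (gettimeofday), 0xffffffffff600400 -> 201
	// (time), 0xffffffffff600800 -> 309 (getcpu); anything else is not a
	// vsyscall entry point.
	sysno, ok := AMD64.Emulate[addr]
	return sysno, ok
}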
diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go
new file mode 100644
index 000000000..434559b80
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sigset.go
@@ -0,0 +1,71 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// CopyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and
+// STOP are clear.
+//
+// TODO(gvisor.dev/issue/1624): This is only exported because
+// syscalls/vfs2/signal.go depends on it. Once vfs1 is deleted and the vfs2
+// syscalls are moved into this package, it can be unexported.
+func CopyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.SignalSet, error) {
+ if size != linux.SignalSetSize {
+ return 0, syserror.EINVAL
+ }
+ b := t.CopyScratchBuffer(8)
+ if _, err := t.CopyInBytes(sigSetAddr, b); err != nil {
+ return 0, err
+ }
+ mask := usermem.ByteOrder.Uint64(b[:])
+ return linux.SignalSet(mask) &^ kernel.UnblockableSignals, nil
+}
+
+// copyOutSigSet copies out a sigset_t.
+func copyOutSigSet(t *kernel.Task, sigSetAddr usermem.Addr, mask linux.SignalSet) error {
+ b := t.CopyScratchBuffer(8)
+ usermem.ByteOrder.PutUint64(b, uint64(mask))
+ _, err := t.CopyOutBytes(sigSetAddr, b)
+ return err
+}
+
+// copyInSigSetWithSize copies in a structure of the following form
+//
+// struct {
+// sigset_t* sigset_addr;
+// size_t sizeof_sigset;
+// };
+//
+// and returns sigset_addr and size.
+func copyInSigSetWithSize(t *kernel.Task, addr usermem.Addr) (usermem.Addr, uint, error) {
+ switch t.Arch().Width() {
+ case 8:
+ in := t.CopyScratchBuffer(16)
+ if _, err := t.CopyInBytes(addr, in); err != nil {
+ return 0, 0, err
+ }
+ maskAddr := usermem.Addr(usermem.ByteOrder.Uint64(in[0:]))
+ maskSize := uint(usermem.ByteOrder.Uint64(in[8:]))
+ return maskAddr, maskSize, nil
+ default:
+ return 0, 0, syserror.ENOSYS
+ }
+}
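Together these helpers cover the two ways Linux passes signal masks to syscalls: a bare sigset_t pointer with an explicit size (the rt_sig* family), and the indirect {sigset_t *, size_t} pair used by pselect6-style interfaces. A hedged sketch of chaining them follows; the helper name is hypothetical, and Task.SignalMask is assumed to return the task's current mask:

// sigSetFromPair is a hypothetical example of combining the helpers above;
// it is not part of the change.
func sigSetFromPair(t *kernel.Task, pairAddr usermem.Addr) (linux.SignalSet, error) {
	maskAddr, size, err := copyInSigSetWithSize(t, pairAddr)
	if err != nil {
		return 0, err
	}
	if maskAddr == 0 {
		// No mask was supplied; keep the current one (assumed accessor).
		return t.SignalMask(), nil
	}
	// CopyInSigSet enforces size == sizeof(sigset_t) and clears KILL/STOP.
	return CopyInSigSet(t, maskAddr, size)
}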
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
new file mode 100644
index 000000000..ba2557c52
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -0,0 +1,382 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// IoSetup implements linux syscall io_setup(2).
+func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nrEvents := args[0].Int()
+ idAddr := args[1].Pointer()
+
+ // Linux uses the native long as the aio ID.
+ //
+ // The context pointer _must_ be zero initially.
+ var idIn uint64
+ if _, err := t.CopyIn(idAddr, &idIn); err != nil {
+ return 0, nil, err
+ }
+ if idIn != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ id, err := t.MemoryManager().NewAIOContext(t, uint32(nrEvents))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out the new ID.
+ if _, err := t.CopyOut(idAddr, &id); err != nil {
+ t.MemoryManager().DestroyAIOContext(t, id)
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// IoDestroy implements linux syscall io_destroy(2).
+func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+
+ ctx := t.MemoryManager().DestroyAIOContext(t, id)
+ if ctx == nil {
+ // Does not exist.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Drain completed requests and wait for pending requests until there are no
+ // more.
+ for {
+ ctx.Drain()
+
+ ch := ctx.WaitChannel()
+ if ch == nil {
+ // No more requests, we're done.
+ return 0, nil, nil
+ }
+ // The task cannot be interrupted during the wait. Equivalent to
+ // TASK_UNINTERRUPTIBLE in Linux.
+ t.UninterruptibleSleepStart(true /* deactivate */)
+ <-ch
+ t.UninterruptibleSleepFinish(true /* activate */)
+ }
+}
+
+// IoGetevents implements linux syscall io_getevents(2).
+func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ minEvents := args[1].Int()
+ events := args[2].Int()
+ eventsAddr := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+
+ // Sanity check arguments.
+ if minEvents < 0 || minEvents > events {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ctx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Setup the timeout.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if timespecAddr != 0 {
+ d, err := copyTimespecIn(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if !d.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(d.ToDuration())
+ haveDeadline = true
+ }
+
+ // Loop over all requests.
+ for count := int32(0); count < events; count++ {
+ // Get a request: block only while fewer than minEvents have been
+ // delivered; the remainder are popped without waiting.
+ var v interface{}
+ if count >= minEvents {
+ var ok bool
+ v, ok = ctx.PopRequest()
+ if !ok {
+ return uintptr(count), nil, nil
+ }
+ } else {
+ var err error
+ v, err = waitForRequest(ctx, t, haveDeadline, deadline)
+ if err != nil {
+ if count > 0 || err == syserror.ETIMEDOUT {
+ return uintptr(count), nil, nil
+ }
+ return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+ }
+
+ ev := v.(*linux.IOEvent)
+
+ // Copy out the result.
+ if _, err := t.CopyOut(eventsAddr, ev); err != nil {
+ if count > 0 {
+ return uintptr(count), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Keep rolling.
+ eventsAddr += usermem.Addr(linux.IOEventSize)
+ }
+
+ // Everything finished.
+ return uintptr(events), nil, nil
+}
+
+func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) {
+ for {
+ if v, ok := ctx.PopRequest(); ok {
+ // Request was readily available. Just return it.
+ return v, nil
+ }
+
+ // Need to wait for request completion.
+ done := ctx.WaitChannel()
+ if done == nil {
+ // Context has been destroyed.
+ return nil, syserror.EINVAL
+ }
+ if err := t.BlockWithDeadline(done, haveDeadline, deadline); err != nil {
+ return nil, err
+ }
+ }
+}
+
+// memoryFor returns appropriate memory for the given callback.
+func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) {
+ bytes := int(cb.Bytes)
+ if bytes < 0 {
+ // Linux also requires that this field fit in ssize_t.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+
+ // Since this I/O will be asynchronous with respect to t's task goroutine,
+ // we have no guarantee that t's AddressSpace will be active during the
+ // I/O.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
+ return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
+ return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP:
+ return usermem.IOSequence{}, nil
+
+ default:
+ // Not a supported command.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+}
+
+// IoCancel implements linux syscall io_cancel(2).
+//
+// It is not presently supported (ENOSYS indicates no support on this
+// architecture).
+func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ENOSYS
+}
+
+// LINT.IfChange
+
+func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback {
+ return func(ctx context.Context) {
+ if actx.Dead() {
+ actx.CancelPendingRequest()
+ return
+ }
+ ev := &linux.IOEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
+
+ var err error
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV:
+ ev.Result, err = file.Preadv(ctx, ioseq, cb.Offset)
+ case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ ev.Result, err = file.Pwritev(ctx, ioseq, cb.Offset)
+ case linux.IOCB_CMD_FSYNC:
+ err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll)
+ case linux.IOCB_CMD_FDSYNC:
+ err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncData)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
+ }
+
+ file.DecRef()
+
+ // Queue the result for delivery.
+ actx.FinishRequest(ev)
+
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFile != nil {
+ eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
+ eventFile.DecRef()
+ }
+ }
+}
+
+// submitCallback processes a single callback.
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
+ file := t.GetFile(cb.FD)
+ if file == nil {
+ // File not found.
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Was there an eventFD? Extract it.
+ var eventFile *fs.File
+ if cb.Flags&linux.IOCB_FLAG_RESFD != 0 {
+ eventFile = t.GetFile(cb.ResFD)
+ if eventFile == nil {
+ // Bad FD.
+ return syserror.EBADF
+ }
+ defer eventFile.DecRef()
+
+ // Check that it is an eventfd.
+ if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok {
+ // Not an event FD.
+ return syserror.EINVAL
+ }
+ }
+
+ ioseq, err := memoryFor(t, cb)
+ if err != nil {
+ return err
+ }
+
+ // Check offset for reads/writes.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ if cb.Offset < 0 {
+ return syserror.EINVAL
+ }
+ }
+
+ // Prepare the request.
+ ctx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ready := ctx.Prepare(); !ready {
+ // Context is busy.
+ return syserror.EAGAIN
+ }
+
+ if eventFile != nil {
+ // The request is set up. Take an extra reference on the eventfd
+ // file: the callback executes on completion, and it is the
+ // callback that releases this reference.
+ eventFile.IncRef()
+ }
+
+ // Perform the request asynchronously.
+ file.IncRef()
+ t.QueueAIO(getAIOCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile))
+
+ // All set.
+ return nil
+}
+
+// IoSubmit implements linux syscall io_submit(2).
+func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ nrEvents := args[1].Int()
+ addr := args[2].Pointer()
+
+ if nrEvents < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ for i := int32(0); i < nrEvents; i++ {
+ // Copy in the address.
+ cbAddrNative := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Copy in this callback.
+ var cb linux.IOCallback
+ cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
+ if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Process this callback.
+ if err := submitCallback(t, id, &cb, cbAddr); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Advance to the next one.
+ addr += usermem.Addr(t.Arch().Width())
+ }
+
+ return uintptr(nrEvents), nil, nil
+}
+
+// LINT.ThenChange(vfs2/aio.go)
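
As a sanity check on the io_setup(2)/io_destroy(2) semantics above, here is a minimal userspace round trip using raw syscalls; a sketch assuming golang.org/x/sys/unix on linux/amd64. Note the context word must be zero going in, matching the EINVAL check in IoSetup:

    package main

    import (
        "fmt"
        "unsafe"

        "golang.org/x/sys/unix"
    )

    func main() {
        // The aio_context_t must be zero before io_setup(2).
        var id uint64
        if _, _, errno := unix.Syscall(unix.SYS_IO_SETUP, 128,
            uintptr(unsafe.Pointer(&id)), 0); errno != 0 {
            panic(errno)
        }
        fmt.Printf("aio context: %#x\n", id)

        // io_destroy(2) drains completed requests and waits for
        // pending ones, as in IoDestroy above.
        if _, _, errno := unix.Syscall(unix.SYS_IO_DESTROY,
            uintptr(id), 0, 0); errno != 0 {
            panic(errno)
        }
    }
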
diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go
new file mode 100644
index 000000000..adf5ea5f2
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_capability.go
@@ -0,0 +1,149 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func lookupCaps(t *kernel.Task, tid kernel.ThreadID) (permitted, inheritable, effective auth.CapabilitySet, err error) {
+ if tid < 0 {
+ err = syserror.EINVAL
+ return
+ }
+ if tid > 0 {
+ t = t.PIDNamespace().TaskWithID(tid)
+ }
+ if t == nil {
+ err = syserror.ESRCH
+ return
+ }
+ creds := t.Credentials()
+ permitted, inheritable, effective = creds.PermittedCaps, creds.InheritableCaps, creds.EffectiveCaps
+ return
+}
+
+// Capget implements Linux syscall capget.
+func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ hdrAddr := args[0].Pointer()
+ dataAddr := args[1].Pointer()
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ // hdr.Pid doesn't need to be valid if this capget() is a "version probe"
+ // (hdr.Version is unrecognized and dataAddr is null), so we can't do the
+ // lookup yet.
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ if dataAddr == 0 {
+ return 0, nil, nil
+ }
+ p, i, e, err := lookupCaps(t, kernel.ThreadID(hdr.Pid))
+ if err != nil {
+ return 0, nil, err
+ }
+ data := linux.CapUserData{
+ Effective: uint32(e),
+ Permitted: uint32(p),
+ Inheritable: uint32(i),
+ }
+ _, err = t.CopyOut(dataAddr, &data)
+ return 0, nil, err
+
+ case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
+ if dataAddr == 0 {
+ return 0, nil, nil
+ }
+ p, i, e, err := lookupCaps(t, kernel.ThreadID(hdr.Pid))
+ if err != nil {
+ return 0, nil, err
+ }
+ data := [2]linux.CapUserData{
+ {
+ Effective: uint32(e),
+ Permitted: uint32(p),
+ Inheritable: uint32(i),
+ },
+ {
+ Effective: uint32(e >> 32),
+ Permitted: uint32(p >> 32),
+ Inheritable: uint32(i >> 32),
+ },
+ }
+ _, err = t.CopyOut(dataAddr, &data)
+ return 0, nil, err
+
+ default:
+ hdr.Version = linux.HighestCapabilityVersion
+ if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ if dataAddr != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ return 0, nil, nil
+ }
+}
+
+// Capset implements Linux syscall capset.
+func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ hdrAddr := args[0].Pointer()
+ dataAddr := args[1].Pointer()
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ if tid := kernel.ThreadID(hdr.Pid); tid != 0 && tid != t.ThreadID() {
+ return 0, nil, syserror.EPERM
+ }
+ var data linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return 0, nil, err
+ }
+ p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities
+ i := auth.CapabilitySet(data.Inheritable) & auth.AllCapabilities
+ e := auth.CapabilitySet(data.Effective) & auth.AllCapabilities
+ return 0, nil, t.SetCapabilitySets(p, i, e)
+
+ case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
+ if tid := kernel.ThreadID(hdr.Pid); tid != 0 && tid != t.ThreadID() {
+ return 0, nil, syserror.EPERM
+ }
+ var data [2]linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return 0, nil, err
+ }
+ p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities
+ i := (auth.CapabilitySet(data[0].Inheritable) | (auth.CapabilitySet(data[1].Inheritable) << 32)) & auth.AllCapabilities
+ e := (auth.CapabilitySet(data[0].Effective) | (auth.CapabilitySet(data[1].Effective) << 32)) & auth.AllCapabilities
+ return 0, nil, t.SetCapabilitySets(p, i, e)
+
+ default:
+ hdr.Version = linux.HighestCapabilityVersion
+ if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, syserror.EINVAL
+ }
+}
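
The version-probe path in Capget can be exercised directly. A sketch assuming golang.org/x/sys/unix; for _LINUX_CAPABILITY_VERSION_3 the kernel reads and writes two CapUserData elements, so a two-element array is passed via its first element:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // Version probe: an unrecognized version plus a nil data
        // pointer makes the kernel write back the highest version it
        // supports and return success, per the default case above.
        hdr := unix.CapUserHeader{Version: 0}
        if err := unix.Capget(&hdr, nil); err != nil {
            panic(err)
        }
        fmt.Printf("highest capability version: %#x\n", hdr.Version)

        // Read this task's capability sets in the 64-bit format.
        hdr = unix.CapUserHeader{Version: unix.LINUX_CAPABILITY_VERSION_3}
        var data [2]unix.CapUserData
        if err := unix.Capget(&hdr, &data[0]); err != nil {
            panic(err)
        }
        eff := uint64(data[0].Effective) | uint64(data[1].Effective)<<32
        fmt.Printf("effective caps: %#x\n", eff)
    }
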
diff --git a/pkg/sentry/syscalls/linux/sys_clone_amd64.go b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
new file mode 100644
index 000000000..dd43cf18d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors. We implement the default one in linux 3.11
+// x86_64:
+// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ childTID := args[3].Pointer()
+ tls := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_clone_arm64.go b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
new file mode 100644
index 000000000..cf68a8949
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors, and we implement the default one in linux 3.11
+// arm64(kernel/fork.c with CONFIG_CLONE_BACKWARDS defined in the config file):
+// sys_clone(clone_flags, newsp, parent_tidptr, tls_val, child_tidptr)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ tls := args[3].Pointer()
+ childTID := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
new file mode 100644
index 000000000..7f460d30b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -0,0 +1,147 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// EpollCreate1 implements the epoll_create1(2) linux syscall.
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags & ^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ closeOnExec := flags&linux.EPOLL_CLOEXEC != 0
+ fd, err := syscalls.CreateEpoll(t, closeOnExec)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements the epoll_create(2) linux syscall.
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ fd, err := syscalls.CreateEpoll(t, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements the epoll_ctl(2) linux syscall.
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ // Capture the event state if needed.
+ flags := epoll.EntryFlags(0)
+ mask := waiter.EventMask(0)
+ var data [2]int32
+ if op != linux.EPOLL_CTL_DEL {
+ var e linux.EpollEvent
+ if _, err := e.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+
+ if e.Events&linux.EPOLLONESHOT != 0 {
+ flags |= epoll.OneShot
+ }
+
+ if e.Events&linux.EPOLLET != 0 {
+ flags |= epoll.EdgeTriggered
+ }
+
+ mask = waiter.EventMaskFromLinux(e.Events)
+ data = e.Data
+ }
+
+ // Perform the requested operations.
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ // See fs/eventpoll.c.
+ mask |= waiter.EventHUp | waiter.EventErr
+ return 0, nil, syscalls.AddEpoll(t, epfd, fd, flags, mask, data)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, syscalls.RemoveEpoll(t, epfd, fd)
+ case linux.EPOLL_CTL_MOD:
+ // Same as EPOLL_CTL_ADD.
+ mask |= waiter.EventHUp | waiter.EventErr
+ return 0, nil, syscalls.UpdateEpoll(t, epfd, fd, flags, mask, data)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements the epoll_wait(2) linux syscall.
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout)
+ if err != nil {
+ return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ if len(r) != 0 {
+ if _, err := linux.CopyEpollEventSliceOut(t, eventsAddr, r); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return uintptr(len(r)), nil, nil
+}
+
+// EpollPwait implements the epoll_pwait(2) linux syscall.
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if maskAddr != 0 {
+ mask, err := CopyInSigSet(t, maskAddr, maskSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+
+ return EpollWait(t, args)
+}
+
+// LINT.ThenChange(vfs2/epoll.go)
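
For reference, the flag translation above (EPOLLONESHOT to epoll.OneShot, EPOLLET to epoll.EdgeTriggered) corresponds to ordinary usage along these lines; a sketch assuming golang.org/x/sys/unix:

    package main

    import "golang.org/x/sys/unix"

    func main() {
        epfd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
        if err != nil {
            panic(err)
        }
        defer unix.Close(epfd)

        // Register stdin for edge-triggered, one-shot readability.
        ev := unix.EpollEvent{
            Events: unix.EPOLLIN | unix.EPOLLET | unix.EPOLLONESHOT,
            Fd:     0,
        }
        if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, 0, &ev); err != nil {
            panic(err)
        }

        // Wait up to one second; -1 would block indefinitely.
        events := make([]unix.EpollEvent, 8)
        n, err := unix.EpollWait(epfd, events, 1000)
        if err != nil {
            panic(err)
        }
        _ = events[:n]
    }
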
diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go
new file mode 100644
index 000000000..ed3413ca6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_eventfd.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Eventfd2 implements linux syscall eventfd2(2).
+func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ initVal := args[0].Int()
+ flags := uint(args[1].Uint())
+ allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC)
+
+ if flags & ^allOps != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ event := eventfd.New(t, uint64(initVal), flags&linux.EFD_SEMAPHORE != 0)
+ event.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.EFD_NONBLOCK != 0,
+ })
+ defer event.DecRef()
+
+ fd, err := t.NewFDFrom(0, event, kernel.FDFlags{
+ CloseOnExec: flags&linux.EFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// Eventfd implements linux syscall eventfd(2).
+func Eventfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[1].Value = 0
+ return Eventfd2(t, args)
+}
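
A short sketch of the semaphore mode this maps to, assuming golang.org/x/sys/unix: each 8-byte write adds to the counter, and with EFD_SEMAPHORE each read consumes exactly one unit:

    package main

    import (
        "encoding/binary"
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        fd, err := unix.Eventfd(0, unix.EFD_SEMAPHORE|unix.EFD_CLOEXEC)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fd)

        // Post three units to the semaphore.
        buf := make([]byte, 8)
        binary.LittleEndian.PutUint64(buf, 3)
        if _, err := unix.Write(fd, buf); err != nil {
            panic(err)
        }

        // Semaphore mode: this read returns 1, leaving 2 units.
        if _, err := unix.Read(fd, buf); err != nil {
            panic(err)
        }
        fmt.Println("read:", binary.LittleEndian.Uint64(buf))
    }
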
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
new file mode 100644
index 000000000..2797c6a72
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -0,0 +1,2238 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/fasync"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fileOpAt performs an operation on the second-to-last component in the path.
+func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent, d *fs.Dirent, name string, remainingTraversals uint) error) error {
+ // Extract the last component.
+ dir, name := fs.SplitLast(path)
+ if dir == "/" {
+ // Common case: we are accessing a file in the root.
+ root := t.FSContext().RootDirectory()
+ err := fn(root, root, name, linux.MaxSymlinkTraversals)
+ root.DecRef()
+ return err
+ } else if dir == "." && dirFD == linux.AT_FDCWD {
+ // Common case: we are accessing a file relative to the current
+ // working directory; skip the look-up.
+ wd := t.FSContext().WorkingDirectory()
+ root := t.FSContext().RootDirectory()
+ err := fn(root, wd, name, linux.MaxSymlinkTraversals)
+ wd.DecRef()
+ root.DecRef()
+ return err
+ }
+
+ return fileOpOn(t, dirFD, dir, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, remainingTraversals uint) error {
+ return fn(root, d, name, remainingTraversals)
+ })
+}
+
+// fileOpOn performs an operation on the last entry of the path.
+func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(root *fs.Dirent, d *fs.Dirent, remainingTraversals uint) error) error {
+ var (
+ d *fs.Dirent // The file.
+ wd *fs.Dirent // The working directory (if required.)
+ rel *fs.Dirent // The relative directory for search (if required.)
+ f *fs.File // The file corresponding to dirFD (if required.)
+ err error
+ )
+
+ // Extract the working directory (maybe).
+ if len(path) > 0 && path[0] == '/' {
+ // Absolute path; rel can be nil.
+ } else if dirFD == linux.AT_FDCWD {
+ // Need to reference the working directory.
+ wd = t.FSContext().WorkingDirectory()
+ rel = wd
+ } else {
+ // Need to extract the given FD.
+ f = t.GetFile(dirFD)
+ if f == nil {
+ return syserror.EBADF
+ }
+ rel = f.Dirent
+ if !fs.IsDir(rel.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ }
+
+ // Grab the root (always required.)
+ root := t.FSContext().RootDirectory()
+
+ // Lookup the node.
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ if resolve {
+ d, err = t.MountNamespace().FindInode(t, root, rel, path, &remainingTraversals)
+ } else {
+ d, err = t.MountNamespace().FindLink(t, root, rel, path, &remainingTraversals)
+ }
+ root.DecRef()
+ if wd != nil {
+ wd.DecRef()
+ }
+ if f != nil {
+ f.DecRef()
+ }
+ if err != nil {
+ return err
+ }
+
+ err = fn(root, d, remainingTraversals)
+ d.DecRef()
+ return err
+}
+
+// copyInPath copies a path in.
+func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string, dirPath bool, err error) {
+ path, err = t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return "", false, err
+ }
+ if path == "" && !allowEmpty {
+ return "", false, syserror.ENOENT
+ }
+
+ // If the path ends with a /, then checks must be enforced in various
+ // ways in the different callers. We pass this back to the caller.
+ path, dirPath = fs.TrimTrailingSlashes(path)
+
+ return path, dirPath, nil
+}
+
+// LINT.IfChange
+
+func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, err
+ }
+
+ resolve := flags&linux.O_NOFOLLOW == 0
+ err = fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // First check a few things about the filesystem before trying to get the file
+ // reference.
+ //
+ // It's required that Check does not try to open files that aren't backed by
+ // this dirent (e.g. pipes and sockets) because this would result in opening these
+ // files an extra time just to check permissions.
+ if err := d.Inode.CheckPermission(t, flagsToPermissions(flags)); err != nil {
+ return err
+ }
+
+ if fs.IsSymlink(d.Inode.StableAttr) && !resolve {
+ return syserror.ELOOP
+ }
+
+ fileFlags := linuxToFlags(flags)
+ // Linux always adds the O_LARGEFILE flag when running in 64-bit mode.
+ fileFlags.LargeFile = true
+ if fs.IsDir(d.Inode.StableAttr) {
+ // Don't allow directories to be opened writable.
+ if fileFlags.Write {
+ return syserror.EISDIR
+ }
+ } else {
+ // If O_DIRECTORY is set, but the file is not a directory, then fail.
+ if fileFlags.Directory {
+ return syserror.ENOTDIR
+ }
+ // If the path ends in a trailing slash, it must refer to a
+ // directory.
+ if dirPath {
+ return syserror.ENOTDIR
+ }
+ }
+
+ // Truncate is called when O_TRUNC is specified for any kind of
+ // existing Dirent. Behavior is delegated to the entry's Truncate
+ // implementation.
+ if flags&linux.O_TRUNC != 0 {
+ if err := d.Inode.Truncate(t, d, 0); err != nil {
+ return err
+ }
+ }
+
+ file, err := d.Inode.GetFile(t, d, fileFlags)
+ if err != nil {
+ return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ defer file.DecRef()
+
+ // Success.
+ newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return err
+ }
+
+ // Set return result in frame.
+ fd = uintptr(newFD)
+
+ // Generate notification for opened file.
+ d.InotifyEvent(linux.IN_OPEN, 0)
+
+ return nil
+ })
+ return fd, err // Use result in frame.
+}
+
+func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Do we have the appropriate permissions on the parent?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // Attempt a creation.
+ perms := fs.FilePermsFromMode(mode &^ linux.FileMode(t.FSContext().Umask()))
+
+ switch mode.FileType() {
+ case 0:
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ fallthrough
+ case linux.ModeRegular:
+ // We are not going to return the file, so the actual
+ // flags used don't matter, but they cannot be empty or
+ // Create will complain.
+ flags := fs.FileFlags{Read: true, Write: true}
+ file, err := d.Create(t, root, name, flags, perms)
+ if err != nil {
+ return err
+ }
+ file.DecRef()
+ return nil
+
+ case linux.ModeNamedPipe:
+ return d.CreateFifo(t, root, name, perms)
+
+ case linux.ModeSocket:
+ // While it is possible to create a unix domain socket file on linux
+ // using mknod(2), in practice this is pretty useless to an
+ // application. Linux internally uses mknod() to create the socket
+ // node during bind(2), but we implement bind(2) independently. If
+ // an application explicitly creates a socket node using mknod(), it
+ // cannot bind() or connect() to the resulting socket.
+ //
+ // Instead of emulating this seemingly useless behaviour, we'll
+ // indicate that the filesystem doesn't support the creation of
+ // sockets.
+ return syserror.EOPNOTSUPP
+
+ case linux.ModeCharacterDevice:
+ fallthrough
+ case linux.ModeBlockDevice:
+ // TODO(b/72101894): We don't support creating block or character
+ // devices at the moment.
+ //
+ // When we start supporting block and character devices, we'll
+ // need to check for CAP_MKNOD here.
+ return syserror.EPERM
+
+ default:
+ // "EINVAL - mode requested creation of something other than a
+ // regular file, device special file, FIFO or socket." - mknod(2)
+ return syserror.EINVAL
+ }
+ })
+}
+
+// Mknod implements the linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ path := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+ // We don't need this argument until we support creation of device nodes.
+ _ = args[2].Uint() // dev
+
+ return 0, nil, mknodAt(t, linux.AT_FDCWD, path, mode)
+}
+
+// Mknodat implements the linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ path := args[1].Pointer()
+ mode := linux.FileMode(args[2].ModeT())
+ // We don't need this argument until we support creation of device nodes.
+ _ = args[3].Uint() // dev
+
+ return 0, nil, mknodAt(t, dirFD, path, mode)
+}
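
Of the file types handled above, only regular files and FIFOs can actually be created. A sketch of the FIFO case, assuming golang.org/x/sys/unix (the path is illustrative):

    package main

    import "golang.org/x/sys/unix"

    func main() {
        // S_IFIFO selects the ModeNamedPipe branch; dev is ignored
        // for FIFOs. This is what mkfifo(3) does under the hood.
        if err := unix.Mknod("/tmp/example.fifo", unix.S_IFIFO|0644, 0); err != nil {
            panic(err)
        }

        // A socket node, by contrast, fails with EOPNOTSUPP here:
        //     unix.Mknod("/tmp/example.sock", unix.S_IFSOCK|0644, 0)
    }
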
+
+func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode linux.FileMode) (fd uintptr, err error) {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, err
+ }
+ if dirPath {
+ return 0, syserror.ENOENT
+ }
+
+ fileFlags := linuxToFlags(flags)
+ // Linux always adds the O_LARGEFILE flag when running in 64-bit mode.
+ fileFlags.LargeFile = true
+
+ err = fileOpAt(t, dirFD, path, func(root *fs.Dirent, parent *fs.Dirent, name string, remainingTraversals uint) error {
+ // Resolve the name to see if it exists, and follow any
+ // symlinks along the way. We must do the symlink resolution
+ // manually because if the symlink target does not exist, we
+ // must create the target (and not the symlink itself).
+ var (
+ found *fs.Dirent
+ err error
+ )
+ for {
+ if !fs.IsDir(parent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Start by looking up the dirent at 'name'.
+ found, err = t.MountNamespace().FindLink(t, root, parent, name, &remainingTraversals)
+ if err != nil {
+ break
+ }
+ defer found.DecRef()
+
+ // We found something (possibly a symlink). If the
+ // O_EXCL flag was passed, then we can immediately
+ // return EEXIST.
+ if flags&linux.O_EXCL != 0 {
+ return syserror.EEXIST
+ }
+
+ // If we have a non-symlink, then we can proceed.
+ if !fs.IsSymlink(found.Inode.StableAttr) {
+ break
+ }
+
+ // If O_NOFOLLOW was passed, then don't try to resolve
+ // anything.
+ if flags&linux.O_NOFOLLOW != 0 {
+ return syserror.ELOOP
+ }
+
+ // Try to resolve the symlink directly to a Dirent.
+ var resolved *fs.Dirent
+ resolved, err = found.Inode.Getlink(t)
+ if err == nil {
+ // No more resolution necessary.
+ defer resolved.DecRef()
+ break
+ }
+ if err != fs.ErrResolveViaReadlink {
+ return err
+ }
+
+ // Are we able to resolve further?
+ if remainingTraversals == 0 {
+ return syscall.ELOOP
+ }
+
+ // Resolve the symlink to a path via Readlink.
+ var path string
+ path, err = found.Inode.Readlink(t)
+ if err != nil {
+ break
+ }
+ remainingTraversals--
+
+ // Get the new parent from the target path.
+ var newParent *fs.Dirent
+ newParentPath, newName := fs.SplitLast(path)
+ newParent, err = t.MountNamespace().FindInode(t, root, parent, newParentPath, &remainingTraversals)
+ if err != nil {
+ break
+ }
+ defer newParent.DecRef()
+
+ // Repeat the process with the parent and name of the
+ // symlink target.
+ parent = newParent
+ name = newName
+ }
+
+ var newFile *fs.File
+ switch err {
+ case nil:
+ // Like sys_open, check for a few things about the
+ // filesystem before trying to get a reference to the
+ // fs.File. The same constraints on Check apply.
+ if err := found.Inode.CheckPermission(t, flagsToPermissions(flags)); err != nil {
+ return err
+ }
+
+ // Truncate is called when O_TRUNC is specified for any kind of
+ // existing Dirent. Behavior is delegated to the entry's Truncate
+ // implementation.
+ if flags&linux.O_TRUNC != 0 {
+ if err := found.Inode.Truncate(t, found, 0); err != nil {
+ return err
+ }
+ }
+
+ // Create a new fs.File.
+ newFile, err = found.Inode.GetFile(t, found, fileFlags)
+ if err != nil {
+ return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ defer newFile.DecRef()
+ case syserror.ENOENT:
+ // File does not exist. Proceed with creation.
+
+ // Do we have write permissions on the parent?
+ if err := parent.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // Attempt a creation.
+ perms := fs.FilePermsFromMode(mode &^ linux.FileMode(t.FSContext().Umask()))
+ newFile, err = parent.Create(t, root, name, fileFlags, perms)
+ if err != nil {
+ // No luck, bail.
+ return err
+ }
+ defer newFile.DecRef()
+ found = newFile.Dirent
+ default:
+ return err
+ }
+
+ // Success.
+ newFD, err := t.NewFDFrom(0, newFile, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return err
+ }
+
+ // Set result in frame.
+ fd = uintptr(newFD)
+
+ // Queue the open inotify event. The creation event is
+ // automatically queued when the dirent is found. The open
+ // events are implemented at the syscall layer so we need to
+ // manually queue one here.
+ found.InotifyEvent(linux.IN_OPEN, 0)
+
+ return nil
+ })
+ return fd, err // Use result in frame.
+}
+
+// Open implements linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := uint(args[1].Uint())
+ if flags&linux.O_CREAT != 0 {
+ mode := linux.FileMode(args[2].ModeT())
+ n, err := createAt(t, linux.AT_FDCWD, addr, flags, mode)
+ return n, nil, err
+ }
+ n, err := openAt(t, linux.AT_FDCWD, addr, flags)
+ return n, nil, err
+}
+
+// Openat implements linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ addr := args[1].Pointer()
+ flags := uint(args[2].Uint())
+ if flags&linux.O_CREAT != 0 {
+ mode := linux.FileMode(args[3].ModeT())
+ n, err := createAt(t, dirFD, addr, flags, mode)
+ return n, nil, err
+ }
+ n, err := openAt(t, dirFD, addr, flags)
+ return n, nil, err
+}
+
+// Creat implements linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+ n, err := createAt(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_TRUNC, mode)
+ return n, nil, err
+}
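
As Creat shows, creat(2) is just open(2) with O_WRONLY|O_CREAT|O_TRUNC. A sketch assuming golang.org/x/sys/unix (the path is illustrative):

    package main

    import "golang.org/x/sys/unix"

    func main() {
        // These two calls are equivalent.
        fd1, err := unix.Creat("/tmp/a.txt", 0644)
        if err != nil {
            panic(err)
        }
        unix.Close(fd1)

        fd2, err := unix.Open("/tmp/a.txt",
            unix.O_WRONLY|unix.O_CREAT|unix.O_TRUNC, 0644)
        if err != nil {
            panic(err)
        }
        unix.Close(fd2)
    }
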
+
+// accessContext is a context that overrides the credentials used, but
+// otherwise carries the same values as the embedded context.
+//
+// accessContext should only be used for access(2).
+type accessContext struct {
+ context.Context
+ creds *auth.Credentials
+}
+
+// Value implements context.Context.
+func (ac accessContext) Value(key interface{}) interface{} {
+ switch key {
+ case auth.CtxCredentials:
+ return ac.creds
+ default:
+ return ac.Context.Value(key)
+ }
+}
+
+func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error {
+ const rOK = 4
+ const wOK = 2
+ const xOK = 1
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ // Sanity check the mode.
+ if mode&^(rOK|wOK|xOK) != 0 {
+ return syserror.EINVAL
+ }
+
+ return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // access(2) and faccessat(2) check permissions using real
+ // UID/GID, not effective UID/GID.
+ //
+ // "access() needs to use the real uid/gid, not the effective
+ // uid/gid. We do this by temporarily clearing all FS-related
+ // capabilities and switching the fsuid/fsgid around to the
+ // real ones." -fs/open.c:faccessat
+ creds := t.Credentials().Fork()
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ ctx := &accessContext{
+ Context: t,
+ creds: creds,
+ }
+
+ return d.Inode.CheckPermission(ctx, fs.PermMask{
+ Read: mode&rOK != 0,
+ Write: mode&wOK != 0,
+ Execute: mode&xOK != 0,
+ })
+ })
+}
+
+// Access implements linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Faccessat implements linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+
+ return 0, nil, accessAt(t, dirFD, addr, mode)
+}
+
+// LINT.ThenChange(vfs2/filesystem.go)
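
Because the permission check runs with real rather than effective credentials, a setuid binary can use these calls to ask what its invoking user may do. A minimal sketch assuming golang.org/x/sys/unix:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // Checks use the real UID/GID, per the comment above.
        err := unix.Access("/etc/shadow", unix.R_OK)
        fmt.Println("readable by real uid:", err == nil)

        // The raw faccessat takes a dirfd but, as noted, no flags.
        err = unix.Faccessat(unix.AT_FDCWD, "/tmp", unix.W_OK|unix.X_OK, 0)
        fmt.Println("/tmp writable:", err == nil)
    }
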
+
+// LINT.IfChange
+
+// Ioctl implements linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ request := int(args[1].Int())
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Shared flags between file and socket.
+ switch request {
+ case linux.FIONCLEX:
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: false,
+ })
+ return 0, nil, nil
+ case linux.FIOCLEX:
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: true,
+ })
+ return 0, nil, nil
+
+ case linux.FIONBIO:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.Flags()
+ if set != 0 {
+ flags.NonBlocking = true
+ } else {
+ flags.NonBlocking = false
+ }
+ file.SetFlags(flags.Settable())
+ return 0, nil, nil
+
+ case linux.FIOASYNC:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.Flags()
+ if set != 0 {
+ flags.Async = true
+ } else {
+ flags.Async = false
+ }
+ file.SetFlags(flags.Settable())
+ return 0, nil, nil
+
+ case linux.FIOSETOWN, linux.SIOCSPGRP:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ fSetOwn(t, file, set)
+ return 0, nil, nil
+
+ case linux.FIOGETOWN, linux.SIOCGPGRP:
+ who := fGetOwn(t, file)
+ _, err := t.CopyOut(args[2].Pointer(), &who)
+ return 0, nil, err
+
+ default:
+ ret, err := file.FileOperations.Ioctl(t, file, t.MemoryManager(), args)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return ret, nil, nil
+ }
+}
+
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
+// Getcwd implements the linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+ cwd := t.FSContext().WorkingDirectory()
+ defer cwd.DecRef()
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+
+ // Get our full name from the root and prepend "(unreachable)" if the root
+ // was not reachable from our current dirent; this is the same behavior as
+ // on linux.
+ s, reachable := cwd.FullName(root)
+ if !reachable {
+ s = "(unreachable)" + s
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Copy out the path name for the node.
+ bytes, err := t.CopyOutBytes(addr, []byte(s))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Top it off with a terminator.
+ _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00"))
+ return uintptr(bytes + 1), nil, err
+}
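
The >= size check is why callers grow their buffer on ERANGE. A sketch of the usual retry loop, assuming golang.org/x/sys/unix:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // Grow the buffer until the path and its NUL terminator fit.
        for size := 64; ; size *= 2 {
            buf := make([]byte, size)
            n, err := unix.Getcwd(buf)
            if err == unix.ERANGE {
                continue
            }
            if err != nil {
                panic(err)
            }
            fmt.Println("cwd:", string(buf[:n-1])) // n includes the NUL
            return
        }
    }
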
+
+// Chroot implements the linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // Is it a directory?
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Does it have execute permissions?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Execute: true}); err != nil {
+ return err
+ }
+
+ t.FSContext().SetRootDirectory(d)
+ return nil
+ })
+}
+
+// Chdir implements the linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // Is it a directory?
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Does it have execute permissions?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Execute: true}); err != nil {
+ return err
+ }
+
+ t.FSContext().SetWorkingDirectory(d)
+ return nil
+ })
+}
+
+// Fchdir implements the linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is it a directory?
+ if !fs.IsDir(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ENOTDIR
+ }
+
+ // Does it have execute permissions?
+ if err := file.Dirent.Inode.CheckPermission(t, fs.PermMask{Execute: true}); err != nil {
+ return 0, nil, err
+ }
+
+ t.FSContext().SetWorkingDirectory(file.Dirent)
+ return 0, nil, nil
+}
+
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
+// Close implements linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ file, _ := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.Flush(t)
+ return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ // If oldfd is a valid file descriptor, and newfd has the same value as oldfd,
+ // then dup2() does nothing, and returns newfd.
+ if oldfd == newfd {
+ oldFile := t.GetFile(oldfd)
+ if oldFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer oldFile.DecRef()
+
+ return uintptr(newfd), nil, nil
+ }
+
+ // Zero out flags arg to be used by Dup3.
+ args[2].Value = 0
+ return Dup3(t, args)
+}
+
+// Dup3 implements linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ oldFile := t.GetFile(oldfd)
+ if oldFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer oldFile.DecRef()
+
+ err := t.NewFDAt(newfd, oldFile, kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0})
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(newfd), nil, nil
+}
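
The divergence between Dup2 and Dup3 is observable from userspace: dup2 with equal descriptors is a validity-checked no-op, while dup3 rejects it. A sketch assuming golang.org/x/sys/unix on an architecture that still provides dup2 (e.g. amd64; arm64 only has dup3):

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // dup2(fd, fd) succeeds and returns fd when fd is valid.
        fmt.Println("dup2(1, 1):", unix.Dup2(1, 1)) // <nil>

        // dup3(fd, fd, 0) must fail with EINVAL.
        fmt.Println("dup3(1, 1, 0):", unix.Dup3(1, 1, 0)) // invalid argument

        // The flags argument is how O_CLOEXEC is requested atomically.
        if err := unix.Dup3(1, 9, unix.O_CLOEXEC); err != nil {
            panic(err)
        }
    }
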
+
+func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx {
+ ma := file.Async(nil)
+ if ma == nil {
+ return linux.FOwnerEx{}
+ }
+ a := ma.(*fasync.FileAsync)
+ ot, otg, opg := a.Owner()
+ switch {
+ case ot != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }
+ case otg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }
+ case opg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }
+ default:
+ return linux.FOwnerEx{}
+ }
+}
+
+func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+ owner := fGetOwnEx(t, file)
+ if owner.Type == linux.F_OWNER_PGRP {
+ return -owner.PID
+ }
+ return owner.PID
+}
+
+// fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux.
+//
+// If who is positive, it represents a PID. If negative, it represents a PGID.
+// If the PID or PGID is invalid, the owner is silently unset.
+func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
+ a := file.Async(fasync.New).(*fasync.FileAsync)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return syserror.EINVAL
+ }
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who))
+ a.SetOwnerProcessGroup(t, pg)
+ } else {
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who))
+ a.SetOwnerThreadGroup(t, tg)
+ }
+ return nil
+}
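
These owner helpers back plain fcntl(2) calls. A sketch assuming golang.org/x/sys/unix; a negative F_SETOWN argument would name a process group, per the comment above:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // Direct SIGIO/SIGURG for fd 0 at this process.
        if _, err := unix.FcntlInt(0, unix.F_SETOWN, unix.Getpid()); err != nil {
            panic(err)
        }

        // F_GETOWN reports the pid (or a negated pgid for group owners).
        who, err := unix.FcntlInt(0, unix.F_GETOWN, 0)
        if err != nil {
            panic(err)
        }
        fmt.Println("owner:", who)
    }
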
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().Get(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ from := args[2].Int()
+ fd, err := t.NewFDFrom(from, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ err := t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, err
+ case linux.F_GETFL:
+ return uintptr(file.Flags().ToLinux()), nil, nil
+ case linux.F_SETFL:
+ flags := uint(args[2].Uint())
+ file.SetFlags(linuxToFlags(flags).Settable())
+ return 0, nil, nil
+ case linux.F_SETLK, linux.F_SETLKW:
+ // In Linux the file system can choose to provide lock operations for an inode.
+ // Normally pipe and socket types lack lock operations. We diverge and use a heavy
+ // hammer by only allowing locks on files and directories.
+ if !fs.IsFile(file.Dirent.Inode.StableAttr) && !fs.IsDir(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Copy in the lock request.
+ flockAddr := args[2].Pointer()
+ var flock linux.Flock
+ if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ return 0, nil, err
+ }
+
+ // Compute the lock whence.
+ var sw fs.SeekWhence
+ switch flock.Whence {
+ case 0:
+ sw = fs.SeekSet
+ case 1:
+ sw = fs.SeekCurrent
+ case 2:
+ sw = fs.SeekEnd
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Compute the lock offset.
+ var off int64
+ switch sw {
+ case fs.SeekSet:
+ off = 0
+ case fs.SeekCurrent:
+ // Note that Linux does not hold any mutexes while retrieving the file offset,
+ // see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
+ off = file.Offset()
+ case fs.SeekEnd:
+ uattr, err := file.Dirent.Inode.UnstableAttr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ off = uattr.Size
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Compute the lock range.
+ rng, err := lock.ComputeRange(flock.Start, flock.Len, off)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // These locks don't block; execute the non-blocking operation using the inode's lock
+ // context directly.
+ switch flock.Type {
+ case linux.F_RDLCK:
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+ if cmd == linux.F_SETLK {
+ // Non-blocking lock, provide a nil lock.Blocker.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) {
+ return 0, nil, syserror.EAGAIN
+ }
+ } else {
+ // Blocking lock, pass in the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ return 0, nil, nil
+ case linux.F_WRLCK:
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+ if cmd == linux.F_SETLK {
+ // Non-blocking lock, provide a nil lock.Blocker.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) {
+ return 0, nil, syserror.EAGAIN
+ }
+ } else {
+ // Blocking lock, pass in the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ return 0, nil, nil
+ case linux.F_UNLCK:
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng)
+ return 0, nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ case linux.F_GETOWN:
+ return uintptr(fGetOwn(t, file)), nil, nil
+ case linux.F_SETOWN:
+ return 0, nil, fSetOwn(t, file, args[2].Int())
+ case linux.F_GETOWN_EX:
+ addr := args[2].Pointer()
+ owner := fGetOwnEx(t, file)
+ _, err := t.CopyOut(addr, &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ addr := args[2].Pointer()
+ var owner linux.FOwnerEx
+ n, err := t.CopyIn(addr, &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ a := file.Async(fasync.New).(*fasync.FileAsync)
+ switch owner.Type {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return uintptr(n), nil, nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID))
+ if tg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return uintptr(n), nil, nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID))
+ if pg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return uintptr(n), nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ case linux.F_GET_SEALS:
+ val, err := tmpfs.GetSeals(file.Dirent.Inode)
+ return uintptr(val), nil, err
+ case linux.F_ADD_SEALS:
+ if !file.Flags().Write {
+ return 0, nil, syserror.EPERM
+ }
+ err := tmpfs.AddSeals(file.Dirent.Inode, args[2].Uint())
+ return 0, nil, err
+ case linux.F_GETPIPE_SZ:
+ sz, ok := file.FileOperations.(fs.FifoSizer)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ size, err := sz.FifoSize(t, file)
+ return uintptr(size), nil, err
+ case linux.F_SETPIPE_SZ:
+ sz, ok := file.FileOperations.(fs.FifoSizer)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ n, err := sz.SetFifoSize(int64(args[2].Int()))
+ return uintptr(n), nil, err
+ default:
+ // Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
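
The whence/offset computation in the F_SETLK arm mirrors what callers encode in struct flock. A sketch of taking and releasing a non-blocking write lock on the first 100 bytes, assuming golang.org/x/sys/unix (the path is illustrative):

    package main

    import "golang.org/x/sys/unix"

    func main() {
        fd, err := unix.Open("/tmp/lock.dat", unix.O_RDWR|unix.O_CREAT, 0644)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fd)

        // Lock bytes [0, 100) from the start of the file. F_SETLK
        // fails with EAGAIN if the range is held; F_SETLKW blocks.
        lk := unix.Flock_t{
            Type:   unix.F_WRLCK,
            Whence: unix.SEEK_SET,
            Start:  0,
            Len:    100,
        }
        if err := unix.FcntlFlock(uintptr(fd), unix.F_SETLK, &lk); err != nil {
            panic(err)
        }

        // Release the same range.
        lk.Type = unix.F_UNLCK
        if err := unix.FcntlFlock(uintptr(fd), unix.F_SETLK, &lk); err != nil {
            panic(err)
        }
    }
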
+
+// Fadvise64 implements linux syscall fadvise64(2).
+// This implementation currently ignores the provided advice.
+func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[2].Int64()
+ advice := args[3].Int()
+
+ // Note: offset is allowed to be negative.
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // If the FD refers to a pipe or FIFO, return error.
+ if fs.IsPipe(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ switch advice {
+ case linux.POSIX_FADV_NORMAL:
+ case linux.POSIX_FADV_RANDOM:
+ case linux.POSIX_FADV_SEQUENTIAL:
+ case linux.POSIX_FADV_WILLNEED:
+ case linux.POSIX_FADV_DONTNEED:
+ case linux.POSIX_FADV_NOREUSE:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // The advice is accepted and ignored.
+ return 0, nil, nil
+}
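
Since the advice itself is ignored, calls like the following succeed as no-ops once the arguments validate. A sketch assuming golang.org/x/sys/unix:

    package main

    import "golang.org/x/sys/unix"

    func main() {
        fd, err := unix.Open("/etc/hosts", unix.O_RDONLY, 0)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fd)

        // Accepted and ignored; length 0 means "to the end of file".
        if err := unix.Fadvise(fd, 0, 0, unix.FADV_SEQUENTIAL); err != nil {
            panic(err)
        }

        // A negative length is rejected with EINVAL before the
        // advice is even examined.
        _ = unix.Fadvise(fd, 0, -1, unix.FADV_WILLNEED) // EINVAL
    }
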
+
+func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Does this directory exist already?
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ f, err := t.MountNamespace().FindInode(t, root, d, name, &remainingTraversals)
+ switch err {
+ case nil:
+ // The directory existed.
+ defer f.DecRef()
+ return syserror.EEXIST
+ case syserror.EACCES:
+ // Permission denied while walking to the directory.
+ return err
+ default:
+ // Do we have write permissions on the parent?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // Create the directory.
+ perms := fs.FilePermsFromMode(mode &^ linux.FileMode(t.FSContext().Umask()))
+ return d.CreateDirectory(t, root, name, perms)
+ }
+ })
+}
+
+// Mkdir implements linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+
+ return 0, nil, mkdirAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ addr := args[1].Pointer()
+ mode := linux.FileMode(args[2].ModeT())
+
+ return 0, nil, mkdirAt(t, dirFD, addr, mode)
+}
+
+func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ // Special case: removing the root always returns EBUSY.
+ if path == "/" {
+ return syserror.EBUSY
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Linux returns different errnos when the path ends in a single
+ // dot vs. double dots.
+ switch name {
+ case ".":
+ return syserror.EINVAL
+ case "..":
+ return syserror.ENOTEMPTY
+ }
+
+ if err := d.MayDelete(t, root, name); err != nil {
+ return err
+ }
+
+ return d.RemoveDirectory(t, root, name)
+ })
+}
+
+// Rmdir implements linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ return 0, nil, rmdirAt(t, linux.AT_FDCWD, addr)
+}
+
+func symlinkAt(t *kernel.Task, dirFD int32, newAddr usermem.Addr, oldAddr usermem.Addr) error {
+ newPath, dirPath, err := copyInPath(t, newAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ // The oldPath is copied in verbatim. This is because the symlink
+ // will include all details, including trailing slashes.
+ oldPath, err := t.CopyInString(oldAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ if oldPath == "" {
+ return syserror.ENOENT
+ }
+
+ return fileOpAt(t, dirFD, newPath, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Make sure we have write permissions on the parent directory.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ return d.CreateLink(t, root, oldPath, name)
+ })
+}
+
+// Symlink implements linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ newAddr := args[1].Pointer()
+
+ return 0, nil, symlinkAt(t, linux.AT_FDCWD, newAddr, oldAddr)
+}
+
+// Symlinkat implements linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ dirFD := args[1].Int()
+ newAddr := args[2].Pointer()
+
+ return 0, nil, symlinkAt(t, dirFD, newAddr, oldAddr)
+}
+
+// mayLinkAt determines whether t can create a hard link to target.
+//
+// This corresponds to Linux's fs/namei.c:may_linkat.
+func mayLinkAt(t *kernel.Task, target *fs.Inode) error {
+ // Linux will impose the following restrictions on hard links only if
+ // sysctl_protected_hardlinks is enabled. The kernel disables this
+ // setting by default for backward compatibility (see commit
+ // 561ec64ae67e), but also recommends that distributions enable it (and
+ // Debian does:
+ // https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=889098).
+ //
+ // gVisor currently behaves as though sysctl_protected_hardlinks is
+ // always enabled, and thus imposes the following restrictions on hard
+ // links.
+
+ if target.CheckOwnership(t) {
+ // fs/namei.c:may_linkat: "Source inode owner (or CAP_FOWNER)
+ // can hardlink all they like."
+ return nil
+ }
+
+ // If we are not the owner, then the file must be regular and have
+ // Read+Write permissions.
+ if !fs.IsRegular(target.StableAttr) {
+ return syserror.EPERM
+ }
+ if target.CheckPermission(t, fs.PermMask{Read: true, Write: true}) != nil {
+ return syserror.EPERM
+ }
+
+ return nil
+}
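+
+// Illustrative consequence of the rules above (not normative): a task
+// that does not own the target may hard link it only if the target is a
+// regular file that is both readable and writable to the task; linking
+// a non-owned symlink or device node fails with EPERM.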
+
+// linkAt creates a hard link to the target specified by oldDirFD and oldAddr,
+// at the path specified by newDirFD and newAddr. If resolve is true, then
+// symlinks will be followed when evaluating the target.
+func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr, resolve, allowEmpty bool) error {
+ oldPath, _, err := copyInPath(t, oldAddr, allowEmpty)
+ if err != nil {
+ return err
+ }
+ newPath, dirPath, err := copyInPath(t, newAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ if allowEmpty && oldPath == "" {
+ target := t.GetFile(oldDirFD)
+ if target == nil {
+ return syserror.EBADF
+ }
+ defer target.DecRef()
+ if err := mayLinkAt(t, target.Dirent.Inode); err != nil {
+ return err
+ }
+
+ // Resolve the target directory.
+ return fileOpAt(t, newDirFD, newPath, func(root *fs.Dirent, newParent *fs.Dirent, newName string, _ uint) error {
+ if !fs.IsDir(newParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Make sure we have write permissions on the parent directory.
+ if err := newParent.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ return newParent.CreateHardLink(t, root, target.Dirent, newName)
+ })
+ }
+
+ // Resolve oldDirFD and oldAddr to a dirent. The "resolve" argument
+ // only applies to this name.
+ return fileOpOn(t, oldDirFD, oldPath, resolve, func(root *fs.Dirent, target *fs.Dirent, _ uint) error {
+ if err := mayLinkAt(t, target.Inode); err != nil {
+ return err
+ }
+
+ // Next resolve newDirFD and newAddr to the parent dirent and name.
+ return fileOpAt(t, newDirFD, newPath, func(root *fs.Dirent, newParent *fs.Dirent, newName string, _ uint) error {
+ if !fs.IsDir(newParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Make sure we have write permissions on the parent directory.
+ if err := newParent.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ return newParent.CreateHardLink(t, root, target, newName)
+ })
+ })
+}
+
+// Link implements linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ newAddr := args[1].Pointer()
+
+ // man link(2):
+ // POSIX.1-2001 says that link() should dereference oldpath if it is a
+ // symbolic link. However, since kernel 2.0, Linux does not do so: if
+ // oldpath is a symbolic link, then newpath is created as a (hard) link
+ // to the same symbolic link file (i.e., newpath becomes a symbolic
+ // link to the same file that oldpath refers to).
+ resolve := false
+ return 0, nil, linkAt(t, linux.AT_FDCWD, oldAddr, linux.AT_FDCWD, newAddr, resolve, false /* allowEmpty */)
+}
+
+// Linkat implements linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldDirFD := args[0].Int()
+ oldAddr := args[1].Pointer()
+ newDirFD := args[2].Int()
+ newAddr := args[3].Pointer()
+
+ // man linkat(2):
+ // By default, linkat(), does not dereference oldpath if it is a
+ // symbolic link (like link(2)). Since Linux 2.6.18, the flag
+ // AT_SYMLINK_FOLLOW can be specified in flags to cause oldpath to be
+ // dereferenced if it is a symbolic link.
+ flags := args[4].Int()
+
+ // Sanity check flags.
+ if flags&^(linux.AT_SYMLINK_FOLLOW|linux.AT_EMPTY_PATH) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ resolve := flags&linux.AT_SYMLINK_FOLLOW == linux.AT_SYMLINK_FOLLOW
+ allowEmpty := flags&linux.AT_EMPTY_PATH == linux.AT_EMPTY_PATH
+
+ if allowEmpty && !t.HasCapabilityIn(linux.CAP_DAC_READ_SEARCH, t.UserNamespace().Root()) {
+ return 0, nil, syserror.ENOENT
+ }
+
+ return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
+}
+
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
+func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, err
+ }
+ if dirPath {
+ return 0, syserror.ENOENT
+ }
+
+ err = fileOpOn(t, dirFD, path, false /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // Check for Read permission.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Read: true}); err != nil {
+ return err
+ }
+
+ s, err := d.Inode.Readlink(t)
+ if err == syserror.ENOLINK {
+ return syserror.EINVAL
+ }
+ if err != nil {
+ return err
+ }
+
+ buffer := []byte(s)
+ if uint(len(buffer)) > size {
+ buffer = buffer[:size]
+ }
+
+ n, err := t.CopyOutBytes(bufAddr, buffer)
+
+ // Update frame return value.
+ copied = uintptr(n)
+
+ return err
+ })
+ return copied, err // Return frame value.
+}
+
+// Readlink implements linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ n, err := readlinkAt(t, linux.AT_FDCWD, addr, bufAddr, size)
+ return n, nil, err
+}
+
+// Readlinkat implements linux syscall readlinkat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ addr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ n, err := readlinkAt(t, dirFD, addr, bufAddr, size)
+ return n, nil, err
+}
+
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
+func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ if err := d.MayDelete(t, root, name); err != nil {
+ return err
+ }
+
+ return d.Remove(t, root, name, dirPath)
+ })
+}
+
+// Unlink implements linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ return 0, nil, unlinkAt(t, linux.AT_FDCWD, addr)
+}
+
+// Unlinkat implements linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirAt(t, dirFD, addr)
+ }
+ return 0, nil, unlinkAt(t, dirFD, addr)
+}
+
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
+// Truncate implements linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+ if dirPath {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(linux.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if fs.IsDir(d.Inode.StableAttr) {
+ return syserror.EISDIR
+ }
+ // In contrast to open(O_TRUNC), truncate(2) is only valid for file
+ // types.
+ if !fs.IsFile(d.Inode.StableAttr) {
+ return syserror.EINVAL
+ }
+
+ // Reject truncation if the access permissions do not allow truncation.
+ // This is different from the behavior of sys_ftruncate, see below.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if err := d.Inode.Truncate(t, d, length); err != nil {
+ return err
+ }
+
+ // File length modified, generate notification.
+ d.InotifyEvent(linux.IN_MODIFY, 0)
+
+ return nil
+ })
+}
+
+// Ftruncate implements linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Reject truncation if the file flags do not permit this operation.
+ // This is different from truncate(2) above.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // In contrast to open(O_TRUNC), truncate(2) is only valid for file
+ // types. Note that this is different from truncate(2) above, where a
+ // directory returns EISDIR.
+ if !fs.IsFile(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(linux.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ if err := file.Dirent.Inode.Truncate(t, file.Dirent, length); err != nil {
+ return 0, nil, err
+ }
+
+ // File length modified, generate notification.
+ file.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+
+ return 0, nil, nil
+}
+
+// LINT.ThenChange(vfs2/setstat.go)
+
+// Umask implements linux syscall umask(2).
+func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ mask := args[0].ModeT()
+ mask = t.FSContext().SwapUmask(mask & 0777)
+ return uintptr(mask), nil, nil
+}
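+
+// For example (illustrative): umask(0022) installs 022 as the task's new
+// file creation mask and returns the previously installed mask; bits
+// outside 0777 are discarded.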
+
+// LINT.IfChange
+
+// Change ownership of a file.
+//
+// uid and gid may be -1, in which case they will not be changed.
+func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
+ owner := fs.FileOwner{
+ UID: auth.NoID,
+ GID: auth.NoID,
+ }
+
+ uattr, err := d.Inode.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ c := t.Credentials()
+ hasCap := d.Inode.CheckCapability(t, linux.CAP_CHOWN)
+ isOwner := uattr.Owner.UID == c.EffectiveKUID
+ if uid.Ok() {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ // Valid UID must be supplied if UID is to be changed.
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+
+ // "Only a privileged process (CAP_CHOWN) may change the owner
+ // of a file." -chown(2)
+ //
+ // Linux also allows chown if you own the file and are
+ // explicitly not changing its UID.
+ isNoop := uattr.Owner.UID == kuid
+ if !(hasCap || (isOwner && isNoop)) {
+ return syserror.EPERM
+ }
+
+ owner.UID = kuid
+ }
+ if gid.Ok() {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ // Valid GID must be supplied if GID is to be changed.
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+
+ // "The owner of a file may change the group of the file to any
+ // group of which that owner is a member. A privileged process
+ // (CAP_CHOWN) may change the group arbitrarily." -chown(2)
+ isNoop := uattr.Owner.GID == kgid
+ isMemberGroup := c.InGroup(kgid)
+ if !(hasCap || (isOwner && (isNoop || isMemberGroup))) {
+ return syserror.EPERM
+ }
+
+ owner.GID = kgid
+ }
+
+ // FIXME(b/62949101): This is racy; the inode's owner may have changed in
+ // the meantime. (Linux holds i_mutex while calling
+ // fs/attr.c:notify_change() => inode_operations::setattr =>
+ // inode_change_ok().)
+ if err := d.Inode.SetOwner(t, d, owner); err != nil {
+ return err
+ }
+
+ // When the owner or group are changed by an unprivileged user,
+ // chown(2) also clears the set-user-ID and set-group-ID bits, but
+ // we do not support them.
+ return nil
+}
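+
+// Illustrative application of the rules above (not normative): an
+// unprivileged owner may change a file's group to any group in its
+// supplementary set, and may "change" the UID only to the value it
+// already has; any other UID change requires CAP_CHOWN.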
+
+func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bool, uid auth.UID, gid auth.GID) error {
+ path, _, err := copyInPath(t, addr, allowEmpty)
+ if err != nil {
+ return err
+ }
+
+ if path == "" {
+ // An empty path (AT_EMPTY_PATH) means the fd itself is the
+ // target; operate on it directly, as fchown would.
+ file := t.GetFile(fd)
+ if file == nil {
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return chown(t, file.Dirent, uid, gid)
+ }
+
+ return fileOpOn(t, fd, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return chown(t, d, uid, gid)
+ })
+}
+
+// Chown implements linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ uid := auth.UID(args[1].Uint())
+ gid := auth.GID(args[2].Uint())
+
+ return 0, nil, chownAt(t, linux.AT_FDCWD, addr, true /* resolve */, false /* allowEmpty */, uid, gid)
+}
+
+// Lchown implements linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ uid := auth.UID(args[1].Uint())
+ gid := auth.GID(args[2].Uint())
+
+ return 0, nil, chownAt(t, linux.AT_FDCWD, addr, false /* resolve */, false /* allowEmpty */, uid, gid)
+}
+
+// Fchown implements linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ uid := auth.UID(args[1].Uint())
+ gid := auth.GID(args[2].Uint())
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, chown(t, file.Dirent, uid, gid)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ addr := args[1].Pointer()
+ uid := auth.UID(args[2].Uint())
+ gid := auth.GID(args[3].Uint())
+ flags := args[4].Int()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, chownAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, flags&linux.AT_EMPTY_PATH != 0, uid, gid)
+}
+
+func chmod(t *kernel.Task, d *fs.Dirent, mode linux.FileMode) error {
+ // Must own file to change mode.
+ if !d.Inode.CheckOwnership(t) {
+ return syserror.EPERM
+ }
+
+ p := fs.FilePermsFromMode(mode)
+ if !d.Inode.SetPermissions(t, d, p) {
+ return syserror.EPERM
+ }
+
+ // File attribute changed, generate notification.
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+
+ return nil
+}
+
+func chmodAt(t *kernel.Task, fd int32, addr usermem.Addr, mode linux.FileMode) error {
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpOn(t, fd, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return chmod(t, d, mode)
+ })
+}
+
+// Chmod implements linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+
+ return 0, nil, chmodAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Fchmod implements linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := linux.FileMode(args[1].ModeT())
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, chmod(t, file.Dirent, mode)
+}
+
+// Fchmodat implements linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := linux.FileMode(args[2].ModeT())
+
+ return 0, nil, chmodAt(t, fd, addr, mode)
+}
+
+// defaultSetToSystemTimeSpec returns a TimeSpec that will set ATime and MTime
+// to the system time.
+func defaultSetToSystemTimeSpec() fs.TimeSpec {
+ return fs.TimeSpec{
+ ATimeSetSystemTime: true,
+ MTimeSetSystemTime: true,
+ }
+}
+
+func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, resolve bool) error {
+ setTimestamp := func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // Does the task own the file?
+ if !d.Inode.CheckOwnership(t) {
+ // Trying to set a specific time? Must be owner.
+ if (ts.ATimeOmit || !ts.ATimeSetSystemTime) && (ts.MTimeOmit || !ts.MTimeSetSystemTime) {
+ return syserror.EPERM
+ }
+
+ // Trying to set to current system time? Must have write access.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+ }
+
+ if err := d.Inode.SetTimestamps(t, d, ts); err != nil {
+ return err
+ }
+
+ // File attribute changed, generate notification.
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
+ }
+
+ // From utimes.c:
+ // "If filename is NULL and dfd refers to an open file, then operate on
+ // the file. Otherwise look up filename, possibly using dfd as a
+ // starting point."
+ if addr == 0 && dirFD != linux.AT_FDCWD {
+ if !resolve {
+ // Linux returns EINVAL in this case. See utimes.c.
+ return syserror.EINVAL
+ }
+ f := t.GetFile(dirFD)
+ if f == nil {
+ return syserror.EBADF
+ }
+ defer f.DecRef()
+
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+
+ return setTimestamp(root, f.Dirent, linux.MaxSymlinkTraversals)
+ }
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpOn(t, dirFD, path, resolve, setTimestamp)
+}
+
+// Utime implements linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ filenameAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ // A null timesAddr is interpreted as the current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times linux.Utime
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ ts = fs.TimeSpec{
+ ATime: ktime.FromSeconds(times.Actime),
+ MTime: ktime.FromSeconds(times.Modtime),
+ }
+ }
+ return 0, nil, utimes(t, linux.AT_FDCWD, filenameAddr, ts, true)
+}
+
+// Utimes implements linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ filenameAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ // A null timesAddr is interpreted as the current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ ts = fs.TimeSpec{
+ ATime: ktime.FromTimeval(times[0]),
+ MTime: ktime.FromTimeval(times[1]),
+ }
+ }
+ return 0, nil, utimes(t, linux.AT_FDCWD, filenameAddr, ts, true)
+}
+
+// timespecIsValid checks that the timespec is valid for use in utimensat.
+func timespecIsValid(ts linux.Timespec) bool {
+ // Nsec must be UTIME_OMIT, UTIME_NOW, or less than 10^9.
+ return ts.Nsec == linux.UTIME_OMIT || ts.Nsec == linux.UTIME_NOW || ts.Nsec < 1e9
+}
+
+// Utimensat implements linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ // A null timesAddr is interpreted as the current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // If both are UTIME_OMIT, this is a noop.
+ if times[0].Nsec == linux.UTIME_OMIT && times[1].Nsec == linux.UTIME_OMIT {
+ return 0, nil, nil
+ }
+
+ ts = fs.TimeSpec{
+ ATime: ktime.FromTimespec(times[0]),
+ ATimeOmit: times[0].Nsec == linux.UTIME_OMIT,
+ ATimeSetSystemTime: times[0].Nsec == linux.UTIME_NOW,
+ MTime: ktime.FromTimespec(times[1]),
+ MTimeOmit: times[1].Nsec == linux.UTIME_OMIT,
+ MTimeSetSystemTime: times[1].Nsec == linux.UTIME_NOW,
+ }
+ }
+ return 0, nil, utimes(t, dirFD, pathnameAddr, ts, flags&linux.AT_SYMLINK_NOFOLLOW == 0)
+}
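+
+// Illustrative: times[0] is atime and times[1] is mtime, so a times array
+// whose Nsec fields are {UTIME_NOW, UTIME_OMIT} updates atime to the
+// current time while leaving mtime untouched.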
+
+// Futimesat implements linux syscall futimesat(2).
+func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+
+ // A null timesAddr is interpreted as the current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ if times[0].Usec >= 1e6 || times[0].Usec < 0 ||
+ times[1].Usec >= 1e6 || times[1].Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ts = fs.TimeSpec{
+ ATime: ktime.FromTimeval(times[0]),
+ MTime: ktime.FromTimeval(times[1]),
+ }
+ }
+ return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
+}
+
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
+func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
+ newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ oldPath, _, err := copyInPath(t, oldAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpAt(t, oldDirFD, oldPath, func(root *fs.Dirent, oldParent *fs.Dirent, oldName string, _ uint) error {
+ if !fs.IsDir(oldParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Rename rejects paths that end in ".", "..", or are empty
+ // (i.e. the root) with EBUSY.
+ switch oldName {
+ case "", ".", "..":
+ return syserror.EBUSY
+ }
+
+ return fileOpAt(t, newDirFD, newPath, func(root *fs.Dirent, newParent *fs.Dirent, newName string, _ uint) error {
+ if !fs.IsDir(newParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Rename rejects paths that end in ".", "..", or are
+ // empty (i.e. the root) with EBUSY.
+ switch newName {
+ case "", ".", "..":
+ return syserror.EBUSY
+ }
+
+ return fs.Rename(t, root, oldParent, oldName, newParent, newName)
+ })
+ })
+}
+
+// Rename implements linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldPathAddr := args[0].Pointer()
+ newPathAddr := args[1].Pointer()
+ return 0, nil, renameAt(t, linux.AT_FDCWD, oldPathAddr, linux.AT_FDCWD, newPathAddr)
+}
+
+// Renameat implements linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldDirFD := args[0].Int()
+ oldPathAddr := args[1].Pointer()
+ newDirFD := args[2].Int()
+ newPathAddr := args[3].Pointer()
+ return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
+}
+
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// Fallocate implements linux system call fallocate(2).
+func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].Int64()
+ offset := args[2].Int64()
+ length := args[3].Int64()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ if offset < 0 || length <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if mode != 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOTSUP
+ }
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+ if fs.IsPipe(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ESPIPE
+ }
+ if fs.IsDir(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EISDIR
+ }
+ if !fs.IsRegular(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ENODEV
+ }
+ size := offset + length
+ if size < 0 {
+ // offset + length overflowed int64.
+ return 0, nil, syserror.EFBIG
+ }
+ if uint64(size) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(linux.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ if err := file.Dirent.Inode.Allocate(t, file.Dirent, offset, length); err != nil {
+ return 0, nil, err
+ }
+
+ // File length modified, generate notification.
+ file.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+
+ return 0, nil, nil
+}
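+
+// Illustrative usage (not normative): fallocate(fd, 0, 0, 1<<20)
+// preallocates 1 MiB at the start of a regular, writable file; any
+// nonzero mode (e.g. FALLOC_FL_PUNCH_HOLE) currently returns ENOTSUP.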
+
+// Flock implements linux syscall flock(2).
+func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ operation := args[1].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ // flock(2): EBADF fd is not an open file descriptor.
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ nonblocking := operation&linux.LOCK_NB != 0
+ operation &^= linux.LOCK_NB
+
+ // A BSD style lock spans the entire file.
+ rng := lock.LockRange{
+ Start: 0,
+ End: lock.LockEOF,
+ }
+
+ switch operation {
+ case linux.LOCK_EX:
+ if nonblocking {
+ // Since we're nonblocking we pass a nil lock.Blocker implementation.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) {
+ return 0, nil, syserror.EWOULDBLOCK
+ }
+ } else {
+ // Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ case linux.LOCK_SH:
+ if nonblocking {
+ // Since we're nonblocking we pass a nil lock.Blocker implementation.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) {
+ return 0, nil, syserror.EWOULDBLOCK
+ }
+ } else {
+ // Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ case linux.LOCK_UN:
+ file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng)
+ default:
+ // flock(2): EINVAL operation is invalid.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
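+
+// Illustrative decoding (derived from the switch above, not normative):
+// flock(fd, LOCK_SH|LOCK_NB) requests a nonblocking read lock spanning
+// the whole file and fails with EWOULDBLOCK if an exclusive lock is
+// already held.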
+
+const (
+ memfdPrefix = "/memfd:"
+ memfdAllFlags = uint32(linux.MFD_CLOEXEC | linux.MFD_ALLOW_SEALING)
+ memfdMaxNameLen = linux.NAME_MAX - len(memfdPrefix) + 1
+)
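+
+// Illustrative arithmetic: with NAME_MAX = 255 and len(memfdPrefix) = 7,
+// memfdMaxNameLen = 255 - 7 + 1 = 249, matching Linux's limit of 249
+// user-supplied name bytes (the +1 compensates for the leading '/' in
+// the prefix here, which Linux's "memfd:" prefix lacks).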
+
+// MemfdCreate implements the linux syscall memfd_create(2).
+func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+
+ if flags&^memfdAllFlags != 0 {
+ // Unknown bits in flags.
+ return 0, nil, syserror.EINVAL
+ }
+
+ allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
+ cloExec := flags&linux.MFD_CLOEXEC != 0
+
+ name, err := t.CopyInString(addr, syscall.PathMax-len(memfdPrefix))
+ if err != nil {
+ return 0, nil, err
+ }
+ if len(name) > memfdMaxNameLen {
+ return 0, nil, syserror.EINVAL
+ }
+ name = memfdPrefix + name
+
+ inode := tmpfs.NewMemfdInode(t, allowSeals)
+ dirent := fs.NewDirent(t, inode, name)
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
+ // FMODE_READ | FMODE_WRITE.
+ file, err := inode.GetFile(t, dirent, fs.FileFlags{Read: true, Write: true})
+ if err != nil {
+ return 0, nil, err
+ }
+
+ defer dirent.DecRef()
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: cloExec,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(newFD), nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
new file mode 100644
index 000000000..b68261f72
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -0,0 +1,288 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// futexWaitRestartBlock encapsulates the state required to restart futex(2)
+// via restart_syscall(2).
+//
+// +stateify savable
+type futexWaitRestartBlock struct {
+ duration time.Duration
+
+ // addr stored as uint64 since uintptr is not save-able.
+ addr uint64
+ private bool
+ val uint32
+ mask uint32
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (f *futexWaitRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return futexWaitDuration(t, f.duration, false, usermem.Addr(f.addr), f.private, f.val, f.mask)
+}
+
+// futexWaitAbsolute performs a FUTEX_WAIT_BITSET, blocking until the wait is
+// complete.
+//
+// The wait blocks forever if forever is true, otherwise it blocks until ts.
+//
+// If blocking is interrupted, the syscall is restarted with the original
+// arguments.
+func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) {
+ w := t.FutexWaiter()
+ err := t.Futex().WaitPrepare(w, t, addr, private, val, mask)
+ if err != nil {
+ return 0, err
+ }
+
+ if forever {
+ err = t.Block(w.C)
+ } else if clockRealtime {
+ notifier, tchan := ktime.NewChannelNotifier()
+ timer := ktime.NewTimer(t.Kernel().RealtimeClock(), notifier)
+ timer.Swap(ktime.Setting{
+ Enabled: true,
+ Next: ktime.FromTimespec(ts),
+ })
+ err = t.BlockWithTimer(w.C, tchan)
+ timer.Destroy()
+ } else {
+ err = t.BlockWithDeadline(w.C, true, ktime.FromTimespec(ts))
+ }
+
+ t.Futex().WaitComplete(w)
+ return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// futexWaitDuration performs a FUTEX_WAIT, blocking until the wait is
+// complete.
+//
+// The wait blocks forever if forever is true, otherwise it blocks for
+// duration.
+//
+// If blocking is interrupted, forever determines how to restart the
+// syscall. If forever is true, the syscall is restarted with the original
+// arguments. If forever is false, duration is a relative timeout and the
+// syscall is restarted with the remaining timeout.
+func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) {
+ w := t.FutexWaiter()
+ err := t.Futex().WaitPrepare(w, t, addr, private, val, mask)
+ if err != nil {
+ return 0, err
+ }
+
+ remaining, err := t.BlockWithTimeout(w.C, !forever, duration)
+ t.Futex().WaitComplete(w)
+ if err == nil {
+ return 0, nil
+ }
+
+ // The wait was unsuccessful for some reason other than interruption. Simply
+ // forward the error.
+ if err != syserror.ErrInterrupted {
+ return 0, err
+ }
+
+ // The wait was interrupted and we need to restart. Decide how.
+
+ // The wait duration was absolute, restart with the original arguments.
+ if forever {
+ return 0, kernel.ERESTARTSYS
+ }
+
+ // The wait duration was relative, restart with the remaining duration.
+ t.SetSyscallRestartBlock(&futexWaitRestartBlock{
+ duration: remaining,
+ addr: uint64(addr),
+ private: private,
+ val: val,
+ mask: mask,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+}
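+
+// For example (illustrative): a FUTEX_WAIT with a 1s relative timeout
+// interrupted by a signal after 400ms is restarted through
+// futexWaitRestartBlock with the remaining ~600ms rather than the
+// original full timeout.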
+
+// futexLockPI performs a FUTEX_LOCK_PI, blocking until the futex is acquired
+// or, if forever is false, until the absolute CLOCK_REALTIME deadline ts.
+func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.Addr, private bool) error {
+ w := t.FutexWaiter()
+ locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, false)
+ if err != nil {
+ return err
+ }
+ if locked {
+ // Futex acquired, we're done!
+ return nil
+ }
+
+ if forever {
+ err = t.Block(w.C)
+ } else {
+ notifier, tchan := ktime.NewChannelNotifier()
+ timer := ktime.NewTimer(t.Kernel().RealtimeClock(), notifier)
+ timer.Swap(ktime.Setting{
+ Enabled: true,
+ Next: ktime.FromTimespec(ts),
+ })
+ err = t.BlockWithTimer(w.C, tchan)
+ timer.Destroy()
+ }
+
+ t.Futex().WaitComplete(w)
+ return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// tryLockPI attempts a FUTEX_TRYLOCK_PI, returning EWOULDBLOCK if the futex
+// could not be acquired immediately.
+func tryLockPI(t *kernel.Task, addr usermem.Addr, private bool) error {
+ w := t.FutexWaiter()
+ locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, true)
+ if err != nil {
+ return err
+ }
+ if !locked {
+ return syserror.EWOULDBLOCK
+ }
+ return nil
+}
+
+// Futex implements linux syscall futex(2).
+// It provides a method for a program to wait for a value at a given address to
+// change, and a method to wake up anyone waiting on a particular address.
+func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ futexOp := args[1].Int()
+ val := int(args[2].Int())
+ nreq := int(args[3].Int())
+ timeout := args[3].Pointer()
+ naddr := args[4].Pointer()
+ val3 := args[5].Int()
+
+ cmd := futexOp &^ (linux.FUTEX_PRIVATE_FLAG | linux.FUTEX_CLOCK_REALTIME)
+ private := (futexOp & linux.FUTEX_PRIVATE_FLAG) != 0
+ clockRealtime := (futexOp & linux.FUTEX_CLOCK_REALTIME) == linux.FUTEX_CLOCK_REALTIME
+ mask := uint32(val3)
+
+ switch cmd {
+ case linux.FUTEX_WAIT, linux.FUTEX_WAIT_BITSET:
+ // WAIT{_BITSET} waits forever if no timeout is passed.
+ forever := (timeout == 0)
+
+ var timespec linux.Timespec
+ if !forever {
+ var err error
+ timespec, err = copyTimespecIn(t, timeout)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ switch cmd {
+ case linux.FUTEX_WAIT:
+ // WAIT uses a relative timeout.
+ mask = ^uint32(0)
+ var timeoutDur time.Duration
+ if !forever {
+ timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
+ }
+ n, err := futexWaitDuration(t, timeoutDur, forever, addr, private, uint32(val), mask)
+ return n, nil, err
+
+ case linux.FUTEX_WAIT_BITSET:
+ // WAIT_BITSET uses an absolute timeout which is either
+ // CLOCK_MONOTONIC or CLOCK_REALTIME.
+ if mask == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ n, err := futexWaitAbsolute(t, clockRealtime, timespec, forever, addr, private, uint32(val), mask)
+ return n, nil, err
+ default:
+ panic("unreachable")
+ }
+
+ case linux.FUTEX_WAKE:
+ mask = ^uint32(0)
+ fallthrough
+
+ case linux.FUTEX_WAKE_BITSET:
+ if mask == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
+ n, err := t.Futex().Wake(t, addr, private, mask, val)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_REQUEUE:
+ n, err := t.Futex().Requeue(t, addr, naddr, private, val, nreq)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_CMP_REQUEUE:
+ // 'val3' contains the value to be checked at 'addr' and
+ // 'val' is the number of waiters that should be woken up.
+ nval := uint32(val3)
+ n, err := t.Futex().RequeueCmp(t, addr, naddr, private, nval, val, nreq)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_WAKE_OP:
+ op := uint32(val3)
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
+ n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_LOCK_PI:
+ forever := (timeout == 0)
+
+ var timespec linux.Timespec
+ if !forever {
+ var err error
+ timespec, err = copyTimespecIn(t, timeout)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ err := futexLockPI(t, timespec, forever, addr, private)
+ return 0, nil, err
+
+ case linux.FUTEX_TRYLOCK_PI:
+ err := tryLockPI(t, addr, private)
+ return 0, nil, err
+
+ case linux.FUTEX_UNLOCK_PI:
+ err := t.Futex().UnlockPI(t, addr, uint32(t.ThreadID()), private)
+ return 0, nil, err
+
+ case linux.FUTEX_WAIT_REQUEUE_PI, linux.FUTEX_CMP_REQUEUE_PI:
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+
+ default:
+ // We don't even know about this command.
+ return 0, nil, syserror.ENOSYS
+ }
+}
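+
+// Illustrative mapping (derived from the dispatch above, not normative):
+// futex(uaddr, FUTEX_WAIT, val, &ts) waits with a relative timeout and an
+// all-ones mask via futexWaitDuration, while futex(uaddr,
+// FUTEX_WAIT_BITSET|FUTEX_CLOCK_REALTIME, val, &ts, 0, mask) waits until
+// an absolute CLOCK_REALTIME deadline via futexWaitAbsolute.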
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
new file mode 100644
index 000000000..b126fecc0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "bytes"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// Getdents implements linux syscall getdents(2) for 64-bit systems.
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ minSize := int(smallestDirent(t.Arch()))
+ if size < minSize {
+ // size is smaller than smallest possible dirent.
+ return 0, nil, syserror.EINVAL
+ }
+
+ n, err := getdents(t, fd, addr, size, (*dirent).Serialize)
+ return n, nil, err
+}
+
+// Getdents64 implements linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ minSize := int(smallestDirent64(t.Arch()))
+ if size < minSize {
+ // size is smaller than smallest possible dirent.
+ return 0, nil, syserror.EINVAL
+ }
+
+ n, err := getdents(t, fd, addr, size, (*dirent).Serialize64)
+ return n, nil, err
+}
+
+// getdents implements the core of getdents(2)/getdents64(2).
+// f is the syscall implementation dirent serialization function.
+func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dirent, io.Writer) (int, error)) (uintptr, error) {
+ dir := t.GetFile(fd)
+ if dir == nil {
+ return 0, syserror.EBADF
+ }
+ defer dir.DecRef()
+
+ w := &usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: addr,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }
+
+ ds := newDirentSerializer(f, w, t.Arch(), size)
+ rerr := dir.Readdir(t, ds)
+
+ switch err := handleIOError(t, ds.Written() > 0, rerr, kernel.ERESTARTSYS, "getdents", dir); err {
+ case nil:
+ dir.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ return uintptr(ds.Written()), nil
+ case io.EOF:
+ return 0, nil
+ default:
+ return 0, err
+ }
+}
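+
+// Illustrative usage (not normative): userspace typically calls
+// getdents64 in a loop, treating a return of 0 (io.EOF above) as
+// end-of-directory and any positive return as the number of bytes of
+// serialized dirents written to the buffer.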
+
+// oldDirentHdr is a fixed sized header matching the fixed size
+// fields found in the old linux dirent struct.
+type oldDirentHdr struct {
+ Ino uint64
+ Off uint64
+ Reclen uint16
+}
+
+// direntHdr is a fixed sized header matching the fixed size
+// fields found in the new linux dirent struct.
+type direntHdr struct {
+ OldHdr oldDirentHdr
+ Typ uint8
+}
+
+// dirent contains the data pointed to by a new linux dirent struct.
+type dirent struct {
+ Hdr direntHdr
+ Name []byte
+}
+
+// newDirent returns a dirent for the given name, fs.DentAttr, and offset.
+func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent {
+ d := &dirent{
+ Hdr: direntHdr{
+ OldHdr: oldDirentHdr{
+ Ino: attr.InodeID,
+ Off: offset,
+ },
+ Typ: fs.ToDirentType(attr.Type),
+ },
+ Name: []byte(name),
+ }
+ d.Hdr.OldHdr.Reclen = d.padRec(int(width))
+ return d
+}
+
+// smallestDirent returns the size of the smallest possible dirent using
+// the old linux dirent format.
+func smallestDirent(a arch.Context) uint {
+ d := dirent{}
+ return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1
+}
+
+// smallestDirent64 returns the size of the smallest possible dirent using
+// the new linux dirent format.
+func smallestDirent64(a arch.Context) uint {
+ d := dirent{}
+ return uint(binary.Size(d.Hdr)) + a.Width()
+}
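+
+// Illustrative sizes, assuming a 64-bit arch (Width() == 8) and packed
+// struct layouts: smallestDirent = 18 + 8 + 1 = 27 bytes and
+// smallestDirent64 = 19 + 8 = 27 bytes; the old format's extra byte is
+// the type byte that Serialize appends after the name.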
+
+// padRec pads the name field until the rec length is a multiple of the width,
+// which must be a power of 2. It returns the padded rec length.
+func (d *dirent) padRec(width int) uint16 {
+ a := int(binary.Size(d.Hdr)) + len(d.Name)
+ r := (a + width) &^ (width - 1)
+ padding := r - a
+ d.Name = append(d.Name, make([]byte, padding)...)
+ return uint16(r)
+}
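+
+// For example (illustrative): with width = 8 and a 4-byte name, the
+// packed header (19 bytes) plus name gives a = 23, so
+// r = (23 + 8) &^ 7 = 24 and one zero byte of padding is appended. A
+// record whose length is already aligned still gains a full width of
+// padding, guaranteeing at least one NUL byte after the name.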
+
+// Serialize64 serializes a Dirent struct to a byte slice, keeping the new
+// linux dirent format. Returns the number of bytes serialized or an error.
+func (d *dirent) Serialize64(w io.Writer) (int, error) {
+ n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr))
+ if err != nil {
+ return 0, err
+ }
+ n2, err := w.Write(d.Name)
+ if err != nil {
+ return 0, err
+ }
+ return n1 + n2, nil
+}
+
+// Serialize serializes a Dirent struct to a byte slice, using the old linux
+// dirent format.
+// Returns the number of bytes serialized or an error.
+func (d *dirent) Serialize(w io.Writer) (int, error) {
+ n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr))
+ if err != nil {
+ return 0, err
+ }
+ n2, err := w.Write(d.Name)
+ if err != nil {
+ return 0, err
+ }
+ n3, err := w.Write([]byte{d.Hdr.Typ})
+ if err != nil {
+ return 0, err
+ }
+ return n1 + n2 + n3, nil
+}
+
+// direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an
+// io.Writer.
+type direntSerializer struct {
+ serialize func(*dirent, io.Writer) (int, error)
+ w io.Writer
+ // width is the arch native value width.
+ width uint
+ // offset is the current dirent offset.
+ offset uint64
+ // written is the total bytes serialized.
+ written int
+ // size is the size of the buffer to serialize into.
+ size int
+}
+
+func newDirentSerializer(f func(d *dirent, w io.Writer) (int, error), w io.Writer, ac arch.Context, size int) *direntSerializer {
+ return &direntSerializer{
+ serialize: f,
+ w: w,
+ width: ac.Width(),
+ size: size,
+ }
+}
+
+// CopyOut implements fs.InodeOperationsInfoSerializer.CopyOut.
+// It serializes and writes the fs.DentAttr to the direntSerializer io.Writer.
+func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
+ ds.offset++
+
+ d := newDirent(ds.width, name, attr, ds.offset)
+
+ // Serialize dirent into a temp buffer.
+ var b bytes.Buffer
+ n, err := ds.serialize(d, &b)
+ if err != nil {
+ ds.offset--
+ return err
+ }
+
+ // Check that we have enough room remaining to write the dirent.
+ if n > (ds.size - ds.written) {
+ ds.offset--
+ return io.EOF
+ }
+
+ // Write out the temp buffer.
+ if _, err := b.WriteTo(ds.w); err != nil {
+ ds.offset--
+ return err
+ }
+
+ ds.written += n
+ return nil
+}
+
+// Written returns the total number of bytes written.
+func (ds *direntSerializer) Written() int {
+ return ds.written
+}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go
new file mode 100644
index 000000000..715ac45e6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_identity.go
@@ -0,0 +1,180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // As NGROUPS_MAX in include/uapi/linux/limits.h.
+ maxNGroups = 65536
+)
+
+// Getuid implements the Linux syscall getuid.
+func Getuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ ruid := c.RealKUID.In(c.UserNamespace).OrOverflow()
+ return uintptr(ruid), nil, nil
+}
+
+// Geteuid implements the Linux syscall geteuid.
+func Geteuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow()
+ return uintptr(euid), nil, nil
+}
+
+// Getresuid implements the Linux syscall getresuid.
+func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ruidAddr := args[0].Pointer()
+ euidAddr := args[1].Pointer()
+ suidAddr := args[2].Pointer()
+ c := t.Credentials()
+ ruid := c.RealKUID.In(c.UserNamespace).OrOverflow()
+ euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow()
+ suid := c.SavedKUID.In(c.UserNamespace).OrOverflow()
+ if _, err := t.CopyOut(ruidAddr, ruid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(euidAddr, euid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(suidAddr, suid); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, nil
+}
+
+// Getgid implements the Linux syscall getgid.
+func Getgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ rgid := c.RealKGID.In(c.UserNamespace).OrOverflow()
+ return uintptr(rgid), nil, nil
+}
+
+// Getegid implements the Linux syscall getegid.
+func Getegid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow()
+ return uintptr(egid), nil, nil
+}
+
+// Getresgid implements the Linux syscall getresgid.
+func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ rgidAddr := args[0].Pointer()
+ egidAddr := args[1].Pointer()
+ sgidAddr := args[2].Pointer()
+ c := t.Credentials()
+ rgid := c.RealKGID.In(c.UserNamespace).OrOverflow()
+ egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow()
+ sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow()
+ if _, err := t.CopyOut(rgidAddr, rgid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(egidAddr, egid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(sgidAddr, sgid); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, nil
+}
+
+// Setuid implements the Linux syscall setuid.
+func Setuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ uid := auth.UID(args[0].Int())
+ return 0, nil, t.SetUID(uid)
+}
+
+// Setreuid implements the Linux syscall setreuid.
+func Setreuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ruid := auth.UID(args[0].Int())
+ euid := auth.UID(args[1].Int())
+ return 0, nil, t.SetREUID(ruid, euid)
+}
+
+// Setresuid implements the Linux syscall setresuid.
+func Setresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ruid := auth.UID(args[0].Int())
+ euid := auth.UID(args[1].Int())
+ suid := auth.UID(args[2].Int())
+ return 0, nil, t.SetRESUID(ruid, euid, suid)
+}
+
+// Setgid implements the Linux syscall setgid.
+func Setgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ gid := auth.GID(args[0].Int())
+ return 0, nil, t.SetGID(gid)
+}
+
+// Setregid implements the Linux syscall setregid.
+func Setregid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ rgid := auth.GID(args[0].Int())
+ egid := auth.GID(args[1].Int())
+ return 0, nil, t.SetREGID(rgid, egid)
+}
+
+// Setresgid implements the Linux syscall setresgid.
+func Setresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ rgid := auth.GID(args[0].Int())
+ egid := auth.GID(args[1].Int())
+ sgid := auth.GID(args[2].Int())
+ return 0, nil, t.SetRESGID(rgid, egid, sgid)
+}
+
+// Getgroups implements the Linux syscall getgroups.
+func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := int(args[0].Int())
+ if size < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ kgids := t.Credentials().ExtraKGIDs
+ // "If size is zero, list is not modified, but the total number of
+ // supplementary group IDs for the process is returned." - getgroups(2)
+ if size == 0 {
+ return uintptr(len(kgids)), nil, nil
+ }
+ if size < len(kgids) {
+ return 0, nil, syserror.EINVAL
+ }
+ gids := make([]auth.GID, len(kgids))
+ for i, kgid := range kgids {
+ gids[i] = kgid.In(t.UserNamespace()).OrOverflow()
+ }
+ if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(len(gids)), nil, nil
+}
+
+// Setgroups implements the Linux syscall setgroups.
+func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+ if size < 0 || size > maxNGroups {
+ return 0, nil, syserror.EINVAL
+ }
+ if size == 0 {
+ return 0, nil, t.SetExtraGIDs(nil)
+ }
+ gids := make([]auth.GID, size)
+ if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, t.SetExtraGIDs(gids)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go
new file mode 100644
index 000000000..b2c7b3444
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_inotify.go
@@ -0,0 +1,133 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const allFlags = int(linux.IN_NONBLOCK | linux.IN_CLOEXEC)
+
+// InotifyInit1 implements the inotify_init1() syscall.
+func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+
+ if flags&^allFlags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ dirent := fs.NewDirent(t, anon.NewInode(t), "inotify")
+ fileFlags := fs.FileFlags{
+ Read: true,
+ Write: true,
+ NonBlocking: flags&linux.IN_NONBLOCK != 0,
+ }
+ n := fs.NewFile(t, dirent, fileFlags, fs.NewInotify(t))
+ defer n.DecRef()
+
+ fd, err := t.NewFDFrom(0, n, kernel.FDFlags{
+ CloseOnExec: flags&linux.IN_CLOEXEC != 0,
+ })
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// InotifyInit implements the inotify_init() syscall.
+func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[0].Value = 0
+ return InotifyInit1(t, args)
+}
+
+// fdToInotify resolves an fd to an inotify object. If successful, the file will
+// have an extra ref and the caller is responsible for releasing the ref.
+func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) {
+ file := t.GetFile(fd)
+ if file == nil {
+ // Invalid fd.
+ return nil, nil, syserror.EBADF
+ }
+
+ ino, ok := file.FileOperations.(*fs.Inotify)
+ if !ok {
+ // Not an inotify fd.
+ file.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+
+ return ino, file, nil
+}
+
+// InotifyAddWatch implements the inotify_add_watch() syscall.
+func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ mask := args[2].Uint()
+
+ // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link."
+ // -- inotify(7)
+ resolve := mask&linux.IN_DONT_FOLLOW == 0
+
+ // "EINVAL: The given event mask contains no valid events."
+ // -- inotify_add_watch(2)
+ if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ino, file, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolve, func(root *fs.Dirent, dirent *fs.Dirent, _ uint) error {
+ // "IN_ONLYDIR: Only watch pathname if it is a directory." -- inotify(7)
+ if onlyDir := mask&linux.IN_ONLYDIR != 0; onlyDir && !fs.IsDir(dirent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Copy out to the return frame.
+ fd = ino.AddWatch(dirent, mask)
+
+ return nil
+ })
+ return uintptr(fd), nil, err // Return the watch descriptor copied into fd above.
+}
+
+// InotifyRmWatch implements the inotify_rm_watch() syscall.
+func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ wd := args[1].Int()
+
+ ino, file, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ return 0, nil, ino.RmWatch(wd)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
new file mode 100644
index 000000000..3f7691eae
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// LINT.IfChange
+
+// Lseek implements linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var sw fs.SeekWhence
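+ // The raw whence values 0, 1 and 2 correspond to SEEK_SET, SEEK_CUR and
+ // SEEK_END respectively.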
+ switch whence {
+ case 0:
+ sw = fs.SeekSet
+ case 1:
+ sw = fs.SeekCurrent
+ case 2:
+ sw = fs.SeekEnd
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ offset, serr := file.Seek(t, sw, offset)
+ err := handleIOError(t, false /* partialResult */, serr, kernel.ERESTARTSYS, "lseek", file)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(offset), nil, err
+}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go
new file mode 100644
index 000000000..9b4a5c3f1
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go
@@ -0,0 +1,312 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// We unconditionally report a single NUMA node. This also means that our
+// "nodemask_t" is a single unsigned long (uint64).
+const (
+ maxNodes = 1
+ allowedNodemask = (1 << maxNodes) - 1
+)
+
+func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, error) {
+ // "nodemask points to a bit mask of node IDs that contains up to maxnode
+ // bits. The bit mask size is rounded to the next multiple of
+ // sizeof(unsigned long), but the kernel will use bits only up to maxnode.
+ // A NULL value of nodemask or a maxnode value of zero specifies the empty
+ // set of nodes. If the value of maxnode is zero, the nodemask argument is
+ // ignored." - set_mempolicy(2). Unfortunately, most of this is inaccurate
+ // because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses
+ // maxnode-1, not maxnode, as the number of bits.
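+ // For example, maxnode=65 yields bits=64, so exactly one unsigned long of
+ // the nodemask is examined.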
+ bits := maxnode - 1
+ if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0
+ return 0, syserror.EINVAL
+ }
+ if bits == 0 {
+ return 0, nil
+ }
+ // Copy in the whole nodemask.
+ numUint64 := (bits + 63) / 64
+ buf := t.CopyScratchBuffer(int(numUint64) * 8)
+ if _, err := t.CopyInBytes(addr, buf); err != nil {
+ return 0, err
+ }
+ val := usermem.ByteOrder.Uint64(buf)
+ // Check that only allowed bits in the first unsigned long in the nodemask
+ // are set.
+ if val&^allowedNodemask != 0 {
+ return 0, syserror.EINVAL
+ }
+ // Check that all remaining bits in the nodemask are 0.
+ for i := 8; i < len(buf); i++ {
+ if buf[i] != 0 {
+ return 0, syserror.EINVAL
+ }
+ }
+ return val, nil
+}
+
+func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint64) error {
+ // mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of
+ // bits.
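+ // For example, maxnode=129 yields bits=128: the first unsigned long is
+ // copied out and the remaining one is zeroed below.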
+ bits := maxnode - 1
+ if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0
+ return syserror.EINVAL
+ }
+ if bits == 0 {
+ return nil
+ }
+ // Copy out the first unsigned long in the nodemask.
+ buf := t.CopyScratchBuffer(8)
+ usermem.ByteOrder.PutUint64(buf, val)
+ if _, err := t.CopyOutBytes(addr, buf); err != nil {
+ return err
+ }
+ // Zero out remaining unsigned longs in the nodemask.
+ if bits > 64 {
+ remAddr, ok := addr.AddLength(8)
+ if !ok {
+ return syserror.EFAULT
+ }
+ remUint64 := (bits - 1) / 64
+ if _, err := t.MemoryManager().ZeroOut(t, remAddr, int64(remUint64)*8, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// GetMempolicy implements the syscall get_mempolicy(2).
+func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ mode := args[0].Pointer()
+ nodemask := args[1].Pointer()
+ maxnode := args[2].Uint()
+ addr := args[3].Pointer()
+ flags := args[4].Uint()
+
+ if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ nodeFlag := flags&linux.MPOL_F_NODE != 0
+ addrFlag := flags&linux.MPOL_F_ADDR != 0
+ memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0
+
+ // "EINVAL: The value specified by maxnode is less than the number of node
+ // IDs supported by the system." - get_mempolicy(2)
+ if nodemask != 0 && maxnode < maxNodes {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is
+ // ignored and the set of nodes (memories) that the thread is allowed to
+ // specify in subsequent calls to mbind(2) or set_mempolicy(2) (in the
+ // absence of any mode flags) is returned in nodemask."
+ if memsAllowed {
+ // "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either
+ // MPOL_F_ADDR or MPOL_F_NODE."
+ if nodeFlag || addrFlag {
+ return 0, nil, syserror.EINVAL
+ }
+ if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, nil
+ }
+
+ // "If flags specifies MPOL_F_ADDR, then information is returned about the
+ // policy governing the memory address given in addr. ... If the mode
+ // argument is not NULL, then get_mempolicy() will store the policy mode
+ // and any optional mode flags of the requested NUMA policy in the location
+ // pointed to by this argument. If nodemask is not NULL, then the nodemask
+ // associated with the policy will be stored in the location pointed to by
+ // this argument."
+ if addrFlag {
+ policy, nodemaskVal, err := t.MemoryManager().NumaPolicy(addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if nodeFlag {
+ // "If flags specifies both MPOL_F_NODE and MPOL_F_ADDR,
+ // get_mempolicy() will return the node ID of the node on which the
+ // address addr is allocated into the location pointed to by mode.
+ // If no page has yet been allocated for the specified address,
+ // get_mempolicy() will allocate a page as if the thread had
+ // performed a read (load) access to that address, and return the
+ // ID of the node where that page was allocated."
+ buf := t.CopyScratchBuffer(1)
+ _, err := t.CopyInBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ policy = linux.MPOL_DEFAULT // maxNodes == 1
+ }
+ if mode != 0 {
+ if _, err := policy.CopyOut(t, mode); err != nil {
+ return 0, nil, err
+ }
+ }
+ if nodemask != 0 {
+ if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+ }
+
+ // "EINVAL: ... flags specified MPOL_F_ADDR and addr is NULL, or flags did
+ // not specify MPOL_F_ADDR and addr is not NULL." This is partially
+ // inaccurate: if flags specifies MPOL_F_ADDR,
+ // mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will
+ // just (usually) fail to find a VMA at address 0 and return EFAULT.
+ if addr != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If flags is specified as 0, then information about the calling thread's
+ // default policy (as set by set_mempolicy(2)) is returned, in the buffers
+ // pointed to by mode and nodemask. ... If flags specifies MPOL_F_NODE, but
+ // not MPOL_F_ADDR, and the thread's current policy is MPOL_INTERLEAVE,
+ // then get_mempolicy() will return in the location pointed to by a
+ // non-NULL mode argument, the node ID of the next node that will be used
+ // for interleaving of internal kernel pages allocated on behalf of the
+ // thread."
+ policy, nodemaskVal := t.NumaPolicy()
+ if nodeFlag {
+ if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
+ return 0, nil, syserror.EINVAL
+ }
+ policy = linux.MPOL_DEFAULT // maxNodes == 1
+ }
+ if mode != 0 {
+ if _, err := policy.CopyOut(t, mode); err != nil {
+ return 0, nil, err
+ }
+ }
+ if nodemask != 0 {
+ if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// SetMempolicy implements the syscall set_mempolicy(2).
+func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ modeWithFlags := linux.NumaPolicy(args[0].Int())
+ nodemask := args[1].Pointer()
+ maxnode := args[2].Uint()
+
+ modeWithFlags, nodemaskVal, err := copyInMempolicyNodemask(t, modeWithFlags, nodemask, maxnode)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ t.SetNumaPolicy(modeWithFlags, nodemaskVal)
+ return 0, nil, nil
+}
+
+// Mbind implements the syscall mbind(2).
+func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint64()
+ mode := linux.NumaPolicy(args[2].Int())
+ nodemask := args[3].Pointer()
+ maxnode := args[4].Uint()
+ flags := args[5].Uint()
+
+ if flags&^linux.MPOL_MF_VALID != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be
+ // privileged (CAP_SYS_NICE) to use this flag." - mbind(2)
+ if flags&linux.MPOL_MF_MOVE_ALL != 0 && !t.HasCapability(linux.CAP_SYS_NICE) {
+ return 0, nil, syserror.EPERM
+ }
+
+ mode, nodemaskVal, err := copyInMempolicyNodemask(t, mode, nodemask, maxnode)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Since we claim to have only a single node, all flags can be ignored
+ // (since all pages must already be on that single node).
+ err = t.MemoryManager().SetNumaPolicy(addr, length, mode, nodemaskVal)
+ return 0, nil, err
+}
+
+func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask usermem.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
+ flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS)
+ mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
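+ // linux.MPOL_MODE_FLAGS is MPOL_F_STATIC_NODES|MPOL_F_RELATIVE_NODES, so
+ // flags equals it only when both mode flags are set at once.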
+ if flags == linux.MPOL_MODE_FLAGS {
+ // Can't specify both mode flags simultaneously.
+ return 0, 0, syserror.EINVAL
+ }
+ if mode < 0 || mode >= linux.MPOL_MAX {
+ // Must specify a valid mode.
+ return 0, 0, syserror.EINVAL
+ }
+
+ var nodemaskVal uint64
+ if nodemask != 0 {
+ var err error
+ nodemaskVal, err = copyInNodemask(t, nodemask, maxnode)
+ if err != nil {
+ return 0, 0, err
+ }
+ }
+
+ switch mode {
+ case linux.MPOL_DEFAULT:
+ // "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate;
+ // Linux allows a nodemask to be specified, as long as it is empty.
+ if nodemaskVal != 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ case linux.MPOL_BIND, linux.MPOL_INTERLEAVE:
+ // These require a non-empty nodemask.
+ if nodemaskVal == 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ case linux.MPOL_PREFERRED:
+ // This permits an empty nodemask, as long as no flags are set.
+ if nodemaskVal == 0 && flags != 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ case linux.MPOL_LOCAL:
+ // This requires an empty nodemask and no flags set ...
+ if nodemaskVal != 0 || flags != 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ // ... and is implemented as MPOL_PREFERRED.
+ mode = linux.MPOL_PREFERRED
+ default:
+ // Unknown mode, which we should have rejected above.
+ panic(fmt.Sprintf("unknown mode: %v", mode))
+ }
+
+ return mode | flags, nodemaskVal, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
new file mode 100644
index 000000000..91694d374
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -0,0 +1,332 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Brk implements linux syscall brk(2).
+func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr, _ := t.MemoryManager().Brk(t, args[0].Pointer())
+ // "However, the actual Linux system call returns the new program break on
+ // success. On failure, the system call returns the current break." -
+ // brk(2). Matching that behavior, the error is discarded and addr, which
+ // is the current break on failure, is returned directly.
+ return uintptr(addr), nil, nil
+}
+
+// LINT.IfChange
+
+// Mmap implements linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
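+ // private == shared is true both when neither flag is set and when both
+ // are set, so one comparison rejects both invalid combinations.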
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ flags := file.Flags()
+ // mmap unconditionally requires that the FD is readable.
+ if !flags.Read {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !flags.Write {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
+
+// LINT.ThenChange(vfs2/mmap.go)
+
+// Munmap implements linux syscall munmap(2).
+func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
+}
+
+// Mremap implements linux syscall mremap(2).
+func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ oldSize := args[1].Uint64()
+ newSize := args[2].Uint64()
+ flags := args[3].Uint64()
+ newAddr := args[4].Pointer()
+
+ if flags&^(linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ mayMove := flags&linux.MREMAP_MAYMOVE != 0
+ fixed := flags&linux.MREMAP_FIXED != 0
+ var moveMode mm.MRemapMoveMode
+ switch {
+ case !mayMove && !fixed:
+ moveMode = mm.MRemapNoMove
+ case mayMove && !fixed:
+ moveMode = mm.MRemapMayMove
+ case mayMove && fixed:
+ moveMode = mm.MRemapMustMove
+ case !mayMove && fixed:
+ // "If MREMAP_FIXED is specified, then MREMAP_MAYMOVE must also be
+ // specified." - mremap(2)
+ return 0, nil, syserror.EINVAL
+ }
+
+ rv, err := t.MemoryManager().MRemap(t, oldAddr, oldSize, newSize, mm.MRemapOpts{
+ Move: moveMode,
+ NewAddr: newAddr,
+ })
+ return uintptr(rv), nil, err
+}
+
+// Mprotect implements linux syscall mprotect(2).
+func Mprotect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ length := args[1].Uint64()
+ prot := args[2].Int()
+ err := t.MemoryManager().MProtect(args[0].Pointer(), length, usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ }, linux.PROT_GROWSDOWN&prot != 0)
+ return 0, nil, err
+}
+
+// Madvise implements linux syscall madvise(2).
+func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := uint64(args[1].SizeT())
+ adv := args[2].Int()
+
+ // "The Linux implementation requires that the address addr be
+ // page-aligned, and allows length to be zero." - madvise(2)
+ if addr.RoundDown() != addr {
+ return 0, nil, syserror.EINVAL
+ }
+ if length == 0 {
+ return 0, nil, nil
+ }
+ // Not explicitly stated: length need not be page-aligned.
+ lenAddr, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ length = uint64(lenAddr)
+
+ switch adv {
+ case linux.MADV_DONTNEED:
+ return 0, nil, t.MemoryManager().Decommit(addr, length)
+ case linux.MADV_DOFORK:
+ return 0, nil, t.MemoryManager().SetDontFork(addr, length, false)
+ case linux.MADV_DONTFORK:
+ return 0, nil, t.MemoryManager().SetDontFork(addr, length, true)
+ case linux.MADV_HUGEPAGE, linux.MADV_NOHUGEPAGE:
+ fallthrough
+ case linux.MADV_MERGEABLE, linux.MADV_UNMERGEABLE:
+ fallthrough
+ case linux.MADV_DONTDUMP, linux.MADV_DODUMP:
+ // TODO(b/72045799): Core dumping isn't implemented, so these are
+ // no-ops.
+ fallthrough
+ case linux.MADV_NORMAL, linux.MADV_RANDOM, linux.MADV_SEQUENTIAL, linux.MADV_WILLNEED:
+ // Do nothing; we totally ignore the suggestions above.
+ return 0, nil, nil
+ case linux.MADV_REMOVE:
+ // These "suggestions" have application-visible side effects, so we
+ // have to indicate that we don't support them.
+ return 0, nil, syserror.ENOSYS
+ case linux.MADV_HWPOISON:
+ // Only privileged processes are allowed to poison pages.
+ return 0, nil, syserror.EPERM
+ default:
+ // If adv is not a valid value, tell the caller.
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// Mincore implements the syscall mincore(2).
+func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ vec := args[2].Pointer()
+
+ if addr != addr.RoundDown() {
+ return 0, nil, syserror.EINVAL
+ }
+ // "The length argument need not be a multiple of the page size, but since
+ // residency information is returned for whole pages, length is effectively
+ // rounded up to the next multiple of the page size." - mincore(2)
+ la, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return 0, nil, syserror.ENOMEM
+ }
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return 0, nil, syserror.ENOMEM
+ }
+
+ // Pretend that all mapped pages are "resident in core".
+ mapped := t.MemoryManager().VirtualMemorySizeRange(ar)
+ // "ENOMEM: addr to addr + length contained unmapped memory."
+ if mapped != uint64(la) {
+ return 0, nil, syserror.ENOMEM
+ }
+ resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize))
+ _, err := t.CopyOut(vec, resident)
+ return 0, nil, err
+}
+
+// Msync implements Linux syscall msync(2).
+func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ flags := args[2].Int()
+
+ // "The flags argument should specify exactly one of MS_ASYNC and MS_SYNC,
+ // and may additionally include the MS_INVALIDATE bit. ... However, Linux
+ // permits a call to msync() that specifies neither of these flags, with
+ // semantics that are (currently) equivalent to specifying MS_ASYNC." -
+ // msync(2)
+ if flags&^(linux.MS_ASYNC|linux.MS_SYNC|linux.MS_INVALIDATE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ sync := flags&linux.MS_SYNC != 0
+ if sync && flags&linux.MS_ASYNC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ err := t.MemoryManager().MSync(t, addr, uint64(length), mm.MSyncOpts{
+ Sync: sync,
+ Invalidate: flags&linux.MS_INVALIDATE != 0,
+ })
+ // MSync calls fsync, so the same interrupt conversion rules apply; see
+ // mm/msync.c and the POSIX.1-2008 specification of fsync.
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// Mlock implements linux syscall mlock(2).
+func Mlock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+
+ return 0, nil, t.MemoryManager().MLock(t, addr, uint64(length), memmap.MLockEager)
+}
+
+// Mlock2 implements linux syscall mlock2(2).
+func Mlock2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ flags := args[2].Int()
+
+ if flags&^(linux.MLOCK_ONFAULT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ mode := memmap.MLockEager
+ if flags&linux.MLOCK_ONFAULT != 0 {
+ mode = memmap.MLockLazy
+ }
+ return 0, nil, t.MemoryManager().MLock(t, addr, uint64(length), mode)
+}
+
+// Munlock implements linux syscall munlock(2).
+func Munlock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+
+ return 0, nil, t.MemoryManager().MLock(t, addr, uint64(length), memmap.MLockNone)
+}
+
+// Mlockall implements linux syscall mlockall(2).
+func Mlockall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+
+ if flags&^(linux.MCL_CURRENT|linux.MCL_FUTURE|linux.MCL_ONFAULT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ mode := memmap.MLockEager
+ if flags&linux.MCL_ONFAULT != 0 {
+ mode = memmap.MLockLazy
+ }
+ return 0, nil, t.MemoryManager().MLockAll(t, mm.MLockAllOpts{
+ Current: flags&linux.MCL_CURRENT != 0,
+ Future: flags&linux.MCL_FUTURE != 0,
+ Mode: mode,
+ })
+}
+
+// Munlockall implements linux syscall munlockall(2).
+func Munlockall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.MemoryManager().MLockAll(t, mm.MLockAllOpts{
+ Current: true,
+ Future: true,
+ Mode: memmap.MLockNone,
+ })
+}
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
new file mode 100644
index 000000000..eb5ff48f5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ sourcePath, _, err := copyInPath(t, sourceAddr, true /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, _, err := copyInPath(t, targetAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
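+ // (MS_MGC_VAL is the historical magic number 0xC0ED0000; MS_MGC_MSK masks
+ // the top 16 bits that carried it.)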
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespace().UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODEV |
+ linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ rsys, ok := fs.FindFilesystem(fsType)
+ if !ok {
+ return 0, nil, syserror.ENODEV
+ }
+ if !rsys.AllowUserMount() {
+ return 0, nil, syserror.EPERM
+ }
+
+ var superFlags fs.MountSourceFlags
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ superFlags.NoAtime = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ superFlags.ReadOnly = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ superFlags.NoExec = true
+ }
+
+ rootInode, err := rsys.Mount(t, sourcePath, superFlags, data, nil)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if err := fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // Mount will take a reference on rootInode if successful.
+ return t.MountNamespace().Mount(t, d, rootInode)
+ }); err != nil {
+ // Something went wrong. Drop our ref on rootInode before
+ // returning the error.
+ rootInode.DecRef()
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespace().UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+
+ resolve := flags&linux.UMOUNT_NOFOLLOW != linux.UMOUNT_NOFOLLOW
+ detachOnly := flags&linux.MNT_DETACH == linux.MNT_DETACH
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return t.MountNamespace().Unmount(t, d, detachOnly)
+ })
+}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
new file mode 100644
index 000000000..43c510930
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// pipe2 implements the logic shared by pipe(2) and pipe2(2), honoring flags.
+func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
+ if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
+ return 0, syserror.EINVAL
+ }
+ r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
+
+ r.SetFlags(linuxToFlags(flags).Settable())
+ defer r.DecRef()
+
+ w.SetFlags(linuxToFlags(flags).Settable())
+ defer w.DecRef()
+
+ fds, err := t.NewFDs(0, []*fs.File{r, w}, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, err
+ }
+
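+ // If the fd array cannot be copied out, the newly allocated fds must be
+ // removed from the table and released so they do not leak.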
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, err
+ }
+ return 0, nil
+}
+
+// Pipe implements linux syscall pipe(2).
+func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ n, err := pipe2(t, addr, 0)
+ return n, nil, err
+}
+
+// Pipe2 implements linux syscall pipe2(2).
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := uint(args[1].Uint())
+
+ n, err := pipe2(t, addr, flags)
+ return n, nil, err
+}
+
+// LINT.ThenChange(vfs2/pipe.go)
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
new file mode 100644
index 000000000..f0198141c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -0,0 +1,545 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable number of files for poll & select.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file descriptor and waiter of a PollFD.
+type pollState struct {
+ file *fs.File
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFile(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, the number of ready PollFDs, and an
+// error. The remaining timeout is always 0 if the wait timed out, and is 0 or
+// positive if the wait was interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ forever := timeout < 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, !forever, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := CopyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
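+ // For example, nfds=10 gives nBytes=2 and nBitsInLastPartialByte=2, so
+ // the last byte is masked with 0x03 to keep only fds 8 and 9.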
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
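+ // v &= (v - 1) clears the lowest set bit on each iteration (Kernighan's
+ // bit-counting trick), so the inner loop runs once per set bit.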
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note that another
+ // thread might be about to close fd. This is racy, but
+ // that's OK. Linux is racy in the same way.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Poll the requested FDs, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return copyTimespecOut(t, timespecAddr, &tsRemaining)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return copyTimevalOut(t, timevalAddr, &tvRemaining)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt, poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskAddr != 0 {
+ mask, err := CopyInSigSet(t, maskAddr, maskSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
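+ // The saved mask is restored on return to userspace, or after any
+ // signal handler runs, matching ppoll(2) semantics.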
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ timeval, err := copyTimevalIn(t, timevalAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ maskAddr, size, err := copyInSigSetWithSize(t, maskWithSizeAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if maskAddr != 0 {
+ mask, err := CopyInSigSet(t, maskAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
new file mode 100644
index 000000000..f92bf8096
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -0,0 +1,228 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Prctl implements linux syscall prctl(2).
+// It dispatches to a number of subfunctions that operate on the process; the
+// meaning of the remaining arguments depends on the selected subfunction.
+func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ option := args[0].Int()
+
+ switch option {
+ case linux.PR_SET_PDEATHSIG:
+ sig := linux.Signal(args[1].Int())
+ if sig != 0 && !sig.IsValid() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.SetParentDeathSignal(sig)
+ return 0, nil, nil
+
+ case linux.PR_GET_PDEATHSIG:
+ _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal()))
+ return 0, nil, err
+
+ case linux.PR_GET_DUMPABLE:
+ d := t.MemoryManager().Dumpability()
+ switch d {
+ case mm.NotDumpable:
+ return linux.SUID_DUMP_DISABLE, nil, nil
+ case mm.UserDumpable:
+ return linux.SUID_DUMP_USER, nil, nil
+ case mm.RootDumpable:
+ return linux.SUID_DUMP_ROOT, nil, nil
+ default:
+ panic(fmt.Sprintf("Unknown dumpability %v", d))
+ }
+
+ case linux.PR_SET_DUMPABLE:
+ var d mm.Dumpability
+ switch args[1].Int() {
+ case linux.SUID_DUMP_DISABLE:
+ d = mm.NotDumpable
+ case linux.SUID_DUMP_USER:
+ d = mm.UserDumpable
+ default:
+ // N.B. Userspace may not pass SUID_DUMP_ROOT.
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().SetDumpability(d)
+ return 0, nil, nil
+
+ case linux.PR_GET_KEEPCAPS:
+ if t.Credentials().KeepCaps {
+ return 1, nil, nil
+ }
+
+ return 0, nil, nil
+
+ case linux.PR_SET_KEEPCAPS:
+ val := args[1].Int()
+ // prctl(2): arg2 must be either 0 (permitted capabilities are cleared)
+ // or 1 (permitted capabilities are kept).
+ if val == 0 {
+ t.SetKeepCaps(false)
+ } else if val == 1 {
+ t.SetKeepCaps(true)
+ } else {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+
+ case linux.PR_SET_NAME:
+ addr := args[1].Pointer()
+ name, err := t.CopyInString(addr, linux.TASK_COMM_LEN-1)
+ if err != nil && err != syserror.ENAMETOOLONG {
+ return 0, nil, err
+ }
+ t.SetName(name)
+
+ case linux.PR_GET_NAME:
+ addr := args[1].Pointer()
+ buf := t.CopyScratchBuffer(linux.TASK_COMM_LEN)
+ n := copy(buf, t.Name())
+ if n < linux.TASK_COMM_LEN {
+ buf[n] = 0
+ n++
+ }
+ _, err := t.CopyOut(addr, buf[:n])
+ if err != nil {
+ return 0, nil, err
+ }
+
+ case linux.PR_SET_MM:
+ if !t.HasCapability(linux.CAP_SYS_RESOURCE) {
+ return 0, nil, syserror.EPERM
+ }
+
+ switch args[1].Int() {
+ case linux.PR_SET_MM_EXE_FILE:
+ fd := args[2].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Are they trying to set exe to a non-file?
+ if !fs.IsFile(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Set the underlying executable.
+ t.MemoryManager().SetExecutable(fsbridge.NewFSFile(file))
+
+ case linux.PR_SET_MM_AUXV,
+ linux.PR_SET_MM_START_CODE,
+ linux.PR_SET_MM_END_CODE,
+ linux.PR_SET_MM_START_DATA,
+ linux.PR_SET_MM_END_DATA,
+ linux.PR_SET_MM_START_STACK,
+ linux.PR_SET_MM_START_BRK,
+ linux.PR_SET_MM_BRK,
+ linux.PR_SET_MM_ARG_START,
+ linux.PR_SET_MM_ARG_END,
+ linux.PR_SET_MM_ENV_START,
+ linux.PR_SET_MM_ENV_END:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ case linux.PR_SET_NO_NEW_PRIVS:
+ if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
+ return 0, nil, nil
+
+ case linux.PR_GET_NO_NEW_PRIVS:
+ if args[1].Int() != 0 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ return 1, nil, nil
+
+ case linux.PR_SET_SECCOMP:
+ if args[1].Int() != linux.SECCOMP_MODE_FILTER {
+ // Unsupported mode.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, seccomp(t, linux.SECCOMP_SET_MODE_FILTER, 0, args[2].Pointer())
+
+ case linux.PR_GET_SECCOMP:
+ return uintptr(t.SeccompMode()), nil, nil
+
+ case linux.PR_CAPBSET_READ:
+ cp := linux.Capability(args[1].Uint64())
+ if !cp.Ok() {
+ return 0, nil, syserror.EINVAL
+ }
+ var rv uintptr
+ if auth.CapabilitySetOf(cp)&t.Credentials().BoundingCaps != 0 {
+ rv = 1
+ }
+ return rv, nil, nil
+
+ case linux.PR_CAPBSET_DROP:
+ cp := linux.Capability(args[1].Uint64())
+ if !cp.Ok() {
+ return 0, nil, syserror.EINVAL
+ }
+ return 0, nil, t.DropBoundingCapability(cp)
+
+ case linux.PR_GET_TIMING,
+ linux.PR_SET_TIMING,
+ linux.PR_GET_TSC,
+ linux.PR_SET_TSC,
+ linux.PR_TASK_PERF_EVENTS_DISABLE,
+ linux.PR_TASK_PERF_EVENTS_ENABLE,
+ linux.PR_GET_TIMERSLACK,
+ linux.PR_SET_TIMERSLACK,
+ linux.PR_MCE_KILL,
+ linux.PR_MCE_KILL_GET,
+ linux.PR_GET_TID_ADDRESS,
+ linux.PR_SET_CHILD_SUBREAPER,
+ linux.PR_GET_CHILD_SUBREAPER,
+ linux.PR_GET_THP_DISABLE,
+ linux.PR_SET_THP_DISABLE,
+ linux.PR_MPX_ENABLE_MANAGEMENT,
+ linux.PR_MPX_DISABLE_MANAGEMENT:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go
new file mode 100644
index 000000000..c0aa0fd60
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_random.go
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ _GRND_NONBLOCK = 0x1
+ _GRND_RANDOM = 0x2
+)
+
+// GetRandom implements the linux syscall getrandom(2).
+//
+// In a multi-tenant/shared environment, the only valid implementation is to
+// fetch data from the urandom pool, otherwise starvation attacks become
+// possible. The urandom pool is also expected to have plenty of entropy, thus
+// the GRND_RANDOM flag is ignored. The GRND_NONBLOCK flag does not apply, as
+// the pool will already be initialized.
+func GetRandom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ flags := args[2].Int()
+
+ // Flags are checked for validity but otherwise ignored. See above.
+ if flags&^(_GRND_NONBLOCK|_GRND_RANDOM) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if length > math.MaxInt32 {
+ length = math.MaxInt32
+ }
+ ar, ok := addr.ToRange(uint64(length))
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+
+ // "If the urandom source has been initialized, reads of up to 256 bytes
+ // will always return as many bytes as requested and will not be
+ // interrupted by signals. No such guarantees apply for larger buffer
+ // sizes." - getrandom(2)
+ min := int(length)
+ if min > 256 {
+ min = 256
+ }
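+ // The randReader below upholds this guarantee even if rand.Reader returns
+ // short reads, by reading at least min bytes before allowing partial reads.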
+ n, err := t.MemoryManager().CopyOutFrom(t, usermem.AddrRangeSeqOf(ar), safemem.FromIOReader{&randReader{-1, min}}, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if n >= int64(min) {
+ return uintptr(n), nil, nil
+ }
+ return 0, nil, err
+}
+
+// randReader is an io.Reader that handles partial reads from rand.Reader.
+type randReader struct {
+ done int
+ min int
+}
+
+// Read implements io.Reader.Read.
+func (r *randReader) Read(dst []byte) (int, error) {
+ if r.done >= r.min {
+ return rand.Reader.Read(dst)
+ }
+ min := r.min - r.done
+ if min > len(dst) {
+ min = len(dst)
+ }
+ return io.ReadAtLeast(rand.Reader, dst, min)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
new file mode 100644
index 000000000..071b4bacc
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -0,0 +1,394 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+const (
+ // EventMaskRead contains events that can be triggered on reads.
+ EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements linux syscall read(2). Note that we try to get a buffer that
+// is exactly the size requested because some applications like qemu expect
+// they can do large reads all at once; we are bug-for-bug compatible here.
+// The same applies to the other read calls below.
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := readv(t, file, dst)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL: if the underlying file type does not support readahead,
+ // Linux returns EINVAL to indicate as much. In the future, we may extend
+ // this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
+
+// Pread64 implements linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is reading at an offset supported?
+ if !file.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := preadv(t, file, dst, offset)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Readv implements linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := readv(t, file, dst)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+// Preadv implements linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is reading at an offset supported?
+ if !file.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := preadv(t, file, dst, offset)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the syscall is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the linux internal call
+ // (https://elixir.bootlin.com/linux/v4.18/source/fs/read_write.c#L1248)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 5th argument.
+
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := int(args[5].Int())
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is reading at an offset supported?
+ if offset > -1 && !file.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check flags field.
+ // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
+ // accepted as a valid flag argument for preadv2.
+ if flags&^linux.RWF_VALID != 0 {
+ return 0, nil, syserror.EOPNOTSUPP
+ }
+
+ // Read the iovecs that specify the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // If preadv2 is called with an offset of -1, readv is called.
+ if offset == -1 {
+ n, err := readv(t, file, dst)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+ }
+
+ n, err := preadv(t, file, dst, offset)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
+ n, err := f.Readv(t, dst)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+ return n, err
+ }
+
+ // Sockets support read timeouts.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if s, ok := f.FileOperations.(socket.Socket); ok {
+ dl := s.RecvTimeout()
+ if dl < 0 && err == syserror.ErrWouldBlock {
+ return n, err
+ }
+ if dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ }
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst64(n)
+
+ // Issue the request and break out if it completes with anything
+ // other than "would block".
+ n, err = f.Readv(t, dst)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+
+ return total, err
+}
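+
+// The loop above is the sentry's canonical pattern for turning the
+// non-blocking internal Readv into a blocking syscall. A simplified sketch
+// (partial-read accounting, deadlines, and inotify elided):
+//
+//  w, ch := waiter.NewChannelEntry(nil)
+//  f.EventRegister(&w, EventMaskRead)
+//  defer f.EventUnregister(&w)
+//  for {
+//      n, err := f.Readv(t, dst)
+//      if err != syserror.ErrWouldBlock {
+//          return n, err
+//      }
+//      if err := t.Block(ch); err != nil {
+//          return n, err
+//      }
+//  }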
+
+func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ n, err := f.Preadv(t, dst, offset)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst64(n)
+
+ // Issue the request and break out if it completes with anything
+ // other than "would block".
+ n, err = f.Preadv(t, dst, offset+total)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.Block(ch); err != nil {
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+
+ return total, err
+}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
new file mode 100644
index 000000000..d5d5b6959
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// rlimit describes an implementation of 'struct rlimit', which may vary from
+// system-to-system.
+type rlimit interface {
+ // toLimit converts an rlimit to a limits.Limit.
+ toLimit() *limits.Limit
+
+ // fromLimit converts a limits.Limit to an rlimit.
+ fromLimit(lim limits.Limit)
+
+ // copyIn copies an rlimit from the untrusted app to the kernel.
+ copyIn(t *kernel.Task, addr usermem.Addr) error
+
+ // copyOut copies an rlimit from the kernel to the untrusted app.
+ copyOut(t *kernel.Task, addr usermem.Addr) error
+}
+
+// newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system.
+func newRlimit(t *kernel.Task) (rlimit, error) {
+ switch t.Arch().Width() {
+ case 8:
+ // On a 64-bit system, struct rlimit and struct rlimit64 are identical.
+ return &rlimit64{}, nil
+ default:
+ return nil, syserror.ENOSYS
+ }
+}
+
+type rlimit64 struct {
+ Cur uint64
+ Max uint64
+}
+
+func (r *rlimit64) toLimit() *limits.Limit {
+ return &limits.Limit{
+ Cur: limits.FromLinux(r.Cur),
+ Max: limits.FromLinux(r.Max),
+ }
+}
+
+func (r *rlimit64) fromLimit(lim limits.Limit) {
+ *r = rlimit64{
+ Cur: limits.ToLinux(lim.Cur),
+ Max: limits.ToLinux(lim.Max),
+ }
+}
+
+func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error {
+ _, err := t.CopyIn(addr, r)
+ return err
+}
+
+func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error {
+ _, err := t.CopyOut(addr, *r)
+ return err
+}
+
+func makeRlimit64(lim limits.Limit) *rlimit64 {
+ return &rlimit64{Cur: lim.Cur, Max: lim.Max}
+}
+
+// setableLimits is the set of supported settable limits.
+var setableLimits = map[limits.LimitType]struct{}{
+ limits.NumberOfFiles: {},
+ limits.AS: {},
+ limits.CPU: {},
+ limits.Data: {},
+ limits.FileSize: {},
+ limits.MemoryLocked: {},
+ limits.Stack: {},
+ // These are not enforced, but we include them here to avoid returning
+ // EPERM, since some apps expect them to succeed.
+ limits.Core: {},
+ limits.ProcessCount: {},
+}
+
+func prlimit64(t *kernel.Task, resource limits.LimitType, newLim *limits.Limit) (limits.Limit, error) {
+ if newLim == nil {
+ return t.ThreadGroup().Limits().Get(resource), nil
+ }
+
+ if _, ok := setableLimits[resource]; !ok {
+ return limits.Limit{}, syserror.EPERM
+ }
+
+ // "A privileged process (under Linux: one with the CAP_SYS_RESOURCE
+ // capability in the initial user namespace) may make arbitrary changes
+ // to either limit value."
+ privileged := t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.Kernel().RootUserNamespace())
+
+ oldLim, err := t.ThreadGroup().Limits().Set(resource, *newLim, privileged)
+ if err != nil {
+ return limits.Limit{}, err
+ }
+
+ if resource == limits.CPU {
+ t.NotifyRlimitCPUUpdated()
+ }
+ return oldLim, nil
+}
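+
+// prlimit64 doubles as getter and setter: a nil newLim reads the current
+// limit, while a non-nil newLim installs it and returns the previous value.
+// A sketch of both uses (limits.NumberOfFiles chosen arbitrarily):
+//
+//  cur, err := prlimit64(t, limits.NumberOfFiles, nil)     // get
+//  old, err := prlimit64(t, limits.NumberOfFiles, &newLim) // set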
+
+// Getrlimit implements linux syscall getrlimit(2).
+func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ resource, ok := limits.FromLinuxResource[int(args[0].Int())]
+ if !ok {
+ // Unknown limit; return EINVAL.
+ return 0, nil, syserror.EINVAL
+ }
+ addr := args[1].Pointer()
+ rlim, err := newRlimit(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ lim, err := prlimit64(t, resource, nil)
+ if err != nil {
+ return 0, nil, err
+ }
+ rlim.fromLimit(lim)
+ return 0, nil, rlim.copyOut(t, addr)
+}
+
+// Setrlimit implements linux syscall setrlimit(2).
+func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ resource, ok := limits.FromLinuxResource[int(args[0].Int())]
+ if !ok {
+ // Unknown limit; return EINVAL.
+ return 0, nil, syserror.EINVAL
+ }
+ addr := args[1].Pointer()
+ rlim, err := newRlimit(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ if err := rlim.copyIn(t, addr); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ _, err = prlimit64(t, resource, rlim.toLimit())
+ return 0, nil, err
+}
+
+// Prlimit64 implements linux syscall prlimit64(2).
+func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ resource, ok := limits.FromLinuxResource[int(args[1].Int())]
+ if !ok {
+ // Unknown limit; return EINVAL.
+ return 0, nil, syserror.EINVAL
+ }
+ newRlimAddr := args[2].Pointer()
+ oldRlimAddr := args[3].Pointer()
+
+ var newLim *limits.Limit
+ if newRlimAddr != 0 {
+ var nrl rlimit64
+ if err := nrl.copyIn(t, newRlimAddr); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ newLim = nrl.toLimit()
+ }
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ ot := t
+ if tid > 0 {
+ if ot = t.PIDNamespace().TaskWithID(tid); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // "To set or get the resources of a process other than itself, the caller
+ // must have the CAP_SYS_RESOURCE capability, or the real, effective, and
+ // saved set user IDs of the target process must match the real user ID of
+ // the caller and the real, effective, and saved set group IDs of the
+ // target process must match the real group ID of the caller."
+ if ot != t && !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) {
+ cred, tcred := t.Credentials(), ot.Credentials()
+ if cred.RealKUID != tcred.RealKUID ||
+ cred.RealKUID != tcred.EffectiveKUID ||
+ cred.RealKUID != tcred.SavedKUID ||
+ cred.RealKGID != tcred.RealKGID ||
+ cred.RealKGID != tcred.EffectiveKGID ||
+ cred.RealKGID != tcred.SavedKGID {
+ return 0, nil, syserror.EPERM
+ }
+ }
+
+ oldLim, err := prlimit64(ot, resource, newLim)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if oldRlimAddr != 0 {
+ if err := makeRlimit64(oldLim).copyOut(t, oldRlimAddr); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
new file mode 100644
index 000000000..90db10ea6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// RSeq implements linux syscall rseq(2).
+func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint()
+ flags := args[2].Int()
+ signature := args[3].Uint()
+
+ if !t.RSeqAvailable() {
+ // Event for applications that want rseq on a configuration
+ // that doesn't support restartable sequences.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ switch flags {
+ case 0:
+ // Register.
+ return 0, nil, t.SetRSeq(addr, length, signature)
+ case linux.RSEQ_FLAG_UNREGISTER:
+ return 0, nil, t.ClearRSeq(addr, length, signature)
+ default:
+ // Unknown flag.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go
new file mode 100644
index 000000000..1674c7445
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rusage.go
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func getrusage(t *kernel.Task, which int32) linux.Rusage {
+ var cs usage.CPUStats
+
+ switch which {
+ case linux.RUSAGE_SELF:
+ cs = t.ThreadGroup().CPUStats()
+
+ case linux.RUSAGE_CHILDREN:
+ cs = t.ThreadGroup().JoinedChildCPUStats()
+
+ case linux.RUSAGE_THREAD:
+ cs = t.CPUStats()
+
+ case linux.RUSAGE_BOTH:
+ tg := t.ThreadGroup()
+ cs = tg.CPUStats()
+ cs.Accumulate(tg.JoinedChildCPUStats())
+ }
+
+ return linux.Rusage{
+ UTime: linux.NsecToTimeval(cs.UserTime.Nanoseconds()),
+ STime: linux.NsecToTimeval(cs.SysTime.Nanoseconds()),
+ NVCSw: int64(cs.VoluntarySwitches),
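+ // Linux reports ru_maxrss in kilobytes; t.MaxRSS returns bytes, hence /1024.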
+ MaxRSS: int64(t.MaxRSS(which) / 1024),
+ }
+}
+
+// Getrusage implements linux syscall getrusage(2).
+// marked "y" are supported now
+// marked "*" are not used on Linux
+// marked "p" are pending for support
+//
+// y struct timeval ru_utime; /* user CPU time used */
+// y struct timeval ru_stime; /* system CPU time used */
+// p long ru_maxrss; /* maximum resident set size */
+// * long ru_ixrss; /* integral shared memory size */
+// * long ru_idrss; /* integral unshared data size */
+// * long ru_isrss; /* integral unshared stack size */
+// p long ru_minflt; /* page reclaims (soft page faults) */
+// p long ru_majflt; /* page faults (hard page faults) */
+// * long ru_nswap; /* swaps */
+// p long ru_inblock; /* block input operations */
+// p long ru_oublock; /* block output operations */
+// * long ru_msgsnd; /* IPC messages sent */
+// * long ru_msgrcv; /* IPC messages received */
+// * long ru_nsignals; /* signals received */
+// y long ru_nvcsw; /* voluntary context switches */
+// y long ru_nivcsw; /* involuntary context switches */
+func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ which := args[0].Int()
+ addr := args[1].Pointer()
+
+ if which != linux.RUSAGE_SELF && which != linux.RUSAGE_CHILDREN && which != linux.RUSAGE_THREAD {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ru := getrusage(t, which)
+ _, err := t.CopyOut(addr, &ru)
+ return 0, nil, err
+}
+
+// Times implements linux syscall times(2).
+func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ // Calculate the ticks first, and figure out if any additional work is
+ // necessary. Linux allows for a NULL addr, in which case only the
+ // return value is meaningful. We don't need to do anything else.
+ ticks := uintptr(ktime.NowFromContext(t).Nanoseconds() / linux.ClockTick.Nanoseconds())
+ if addr == 0 {
+ return ticks, nil, nil
+ }
+
+ cs1 := t.ThreadGroup().CPUStats()
+ cs2 := t.ThreadGroup().JoinedChildCPUStats()
+ r := linux.Tms{
+ UTime: linux.ClockTFromDuration(cs1.UserTime),
+ STime: linux.ClockTFromDuration(cs1.SysTime),
+ CUTime: linux.ClockTFromDuration(cs2.UserTime),
+ CSTime: linux.ClockTFromDuration(cs2.SysTime),
+ }
+ if _, err := t.CopyOut(addr, &r); err != nil {
+ return 0, nil, err
+ }
+
+ return ticks, nil, nil
+}
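+
+// For example, with the common Linux CLK_TCK of 100 (one tick per 10ms), a
+// monotonic clock reading of 1s yields a return value of 100 ticks; the same
+// ClockTick conversion is applied to the Tms fields via ClockTFromDuration.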
diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go
new file mode 100644
index 000000000..99f6993f5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sched.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ onlyScheduler = linux.SCHED_NORMAL
+ onlyPriority = 0
+)
+
+// SchedParam replicates struct sched_param in sched.h.
+type SchedParam struct {
+ schedPriority int64
+}
+
+// SchedGetparam implements linux syscall sched_getparam(2).
+func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := args[0].Int()
+ param := args[1].Pointer()
+ if param == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if pid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ r := SchedParam{schedPriority: onlyPriority}
+ if _, err := t.CopyOut(param, r); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// SchedGetscheduler implements linux syscall sched_getscheduler(2).
+func SchedGetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := args[0].Int()
+ if pid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ return onlyScheduler, nil, nil
+}
+
+// SchedSetscheduler implements linux syscall sched_setscheduler(2).
+func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := args[0].Int()
+ policy := args[1].Int()
+ param := args[2].Pointer()
+ if pid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if policy != onlyScheduler {
+ return 0, nil, syserror.EINVAL
+ }
+ if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ var r SchedParam
+ if _, err := t.CopyIn(param, &r); err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+ if r.schedPriority != onlyPriority {
+ return 0, nil, syserror.EINVAL
+ }
+ return 0, nil, nil
+}
+
+// SchedGetPriorityMax implements linux syscall sched_get_priority_max(2).
+func SchedGetPriorityMax(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return onlyPriority, nil, nil
+}
+
+// SchedGetPriorityMin implements linux syscall sched_get_priority_min(2).
+func SchedGetPriorityMin(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return onlyPriority, nil, nil
+}
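+
+// Because the sentry supports only SCHED_NORMAL at priority 0, both bounds
+// collapse to 0: sched_get_priority_max(2) and sched_get_priority_min(2)
+// return the same value for every policy the sentry accepts.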
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
new file mode 100644
index 000000000..5b7a66f4d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
+type userSockFprog struct {
+ // Len is the length of the filter in BPF instructions.
+ Len uint16
+
+ _ [6]byte // padding for alignment
+
+ // Filter is a user pointer to the struct sock_filter array that makes up
+ // the filter program. Filter is a uint64 rather than a usermem.Addr
+ // because usermem.Addr is actually uintptr, which is not a fixed-size
+ // type, and encoding/binary.Read objects to this.
+ Filter uint64
+}
+
+// seccomp applies a seccomp policy to the current task.
+func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error {
+ // We only support SECCOMP_SET_MODE_FILTER at the moment.
+ if mode != linux.SECCOMP_SET_MODE_FILTER {
+ // Unsupported mode.
+ return syserror.EINVAL
+ }
+
+ tsync := flags&linux.SECCOMP_FILTER_FLAG_TSYNC != 0
+
+ // The only flag we support now is SECCOMP_FILTER_FLAG_TSYNC.
+ if flags&^linux.SECCOMP_FILTER_FLAG_TSYNC != 0 {
+ // Unsupported flag.
+ return syserror.EINVAL
+ }
+
+ var fprog userSockFprog
+ if _, err := t.CopyIn(addr, &fprog); err != nil {
+ return err
+ }
+ filter := make([]linux.BPFInstruction, int(fprog.Len))
+ if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil {
+ return err
+ }
+ compiledFilter, err := bpf.Compile(filter)
+ if err != nil {
+ t.Debugf("Invalid seccomp-bpf filter: %v", err)
+ return syserror.EINVAL
+ }
+
+ return t.AppendSyscallFilter(compiledFilter, tsync)
+}
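+
+// For illustration, the smallest filter a caller could install through this
+// path is a single "allow everything" instruction. A sketch, assuming the
+// instruction constants exported by pkg/bpf:
+//
+//  filter := []linux.BPFInstruction{{
+//      OpCode: bpf.Ret | bpf.K,
+//      K:      linux.SECCOMP_RET_ALLOW,
+//  }}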
+
+// Seccomp implements linux syscall seccomp(2).
+func Seccomp(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, seccomp(t, args[0].Uint64(), args[1].Uint64(), args[2].Pointer())
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
new file mode 100644
index 000000000..5f54f2456
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -0,0 +1,241 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const opsMax = 500 // SEMOPM
+
+// Semget handles: semget(key_t key, int nsems, int semflg)
+func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ key := args[0].Int()
+ nsems := args[1].Int()
+ flag := args[2].Int()
+
+ private := key == linux.IPC_PRIVATE
+ create := flag&linux.IPC_CREAT == linux.IPC_CREAT
+ exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
+ mode := linux.FileMode(flag & 0777)
+
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set, err := r.FindOrCreate(t, key, nsems, mode, private, create, exclusive)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(set.ID), nil, nil
+}
+
+// Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
+func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Int()
+ sembufAddr := args[1].Pointer()
+ nsops := args[2].SizeT()
+
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, nil, syserror.EINVAL
+ }
+ if nsops <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nsops > opsMax {
+ return 0, nil, syserror.E2BIG
+ }
+
+ ops := make([]linux.Sembuf, nsops)
+ if _, err := t.CopyIn(sembufAddr, ops); err != nil {
+ return 0, nil, err
+ }
+
+ creds := auth.CredentialsFromContext(t)
+ pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
+ for {
+ ch, num, err := set.ExecuteOps(t, ops, creds, int32(pid))
+ if ch == nil || err != nil {
+ // We're done (either on success or a failure).
+ return 0, nil, err
+ }
+ if err = t.Block(ch); err != nil {
+ set.AbortWait(num, ch)
+ return 0, nil, err
+ }
+ }
+}
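+
+// For illustration, a classic "P" (down/acquire) operation on semaphore 0
+// reaches this syscall as a single-element ops slice:
+//
+//  ops := []linux.Sembuf{{SemNum: 0, SemOp: -1, SemFlg: 0}}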
+
+// Semctl handles: semctl(int semid, int semnum, int cmd, ...)
+func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Int()
+ num := args[1].Int()
+ cmd := args[2].Int()
+
+ switch cmd {
+ case linux.SETVAL:
+ val := args[3].Int()
+ if val > math.MaxInt16 {
+ return 0, nil, syserror.ERANGE
+ }
+ return 0, nil, setVal(t, id, num, int16(val))
+
+ case linux.SETALL:
+ array := args[3].Pointer()
+ return 0, nil, setValAll(t, id, array)
+
+ case linux.GETVAL:
+ v, err := getVal(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.GETALL:
+ array := args[3].Pointer()
+ return 0, nil, getValAll(t, id, array)
+
+ case linux.IPC_RMID:
+ return 0, nil, remove(t, id)
+
+ case linux.IPC_SET:
+ arg := args[3].Pointer()
+ s := linux.SemidDS{}
+ if _, err := t.CopyIn(arg, &s); err != nil {
+ return 0, nil, err
+ }
+
+ perms := fs.FilePermsFromMode(linux.FileMode(s.SemPerm.Mode & 0777))
+ return 0, nil, ipcSet(t, id, auth.UID(s.SemPerm.UID), auth.GID(s.SemPerm.GID), perms)
+
+ case linux.GETPID:
+ v, err := getPID(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.IPC_INFO,
+ linux.SEM_INFO,
+ linux.IPC_STAT,
+ linux.SEM_STAT,
+ linux.SEM_STAT_ANY,
+ linux.GETNCNT,
+ linux.GETZCNT:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+func remove(t *kernel.Task, id int32) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ creds := auth.CredentialsFromContext(t)
+ return r.RemoveID(id, creds)
+}
+
+func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+
+ creds := auth.CredentialsFromContext(t)
+ kuid := creds.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ owner := fs.FileOwner{UID: kuid, GID: kgid}
+ return set.Change(t, creds, owner, perms)
+}
+
+func setVal(t *kernel.Task, id int32, num int32, val int16) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
+ return set.SetVal(t, num, val, creds, int32(pid))
+}
+
+func setValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ vals := make([]uint16, set.Size())
+ if _, err := t.CopyIn(array, vals); err != nil {
+ return err
+ }
+ creds := auth.CredentialsFromContext(t)
+ pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
+ return set.SetValAll(t, vals, creds, int32(pid))
+}
+
+func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.GetVal(num, creds)
+}
+
+func getValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ vals, err := set.GetValAll(creds)
+ if err != nil {
+ return err
+ }
+ _, err = t.CopyOut(array, vals)
+ return err
+}
+
+func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ gpid, err := set.GetPID(num, creds)
+ if err != nil {
+ return 0, err
+ }
+ // Convert pid from init namespace to the caller's namespace.
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(gpid))
+ if tg == nil {
+ return 0, nil
+ }
+ return int32(tg.ID()), nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
new file mode 100644
index 000000000..4a8bc24a2
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/shm"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Shmget implements shmget(2).
+func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ key := shm.Key(args[0].Int())
+ size := uint64(args[1].SizeT())
+ flag := args[2].Int()
+
+ private := key == linux.IPC_PRIVATE
+ create := flag&linux.IPC_CREAT == linux.IPC_CREAT
+ exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
+ mode := linux.FileMode(flag & 0777)
+
+ pid := int32(t.ThreadGroup().ID())
+ r := t.IPCNamespace().ShmRegistry()
+ segment, err := r.FindOrCreate(t, pid, key, size, mode, private, create, exclusive)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer segment.DecRef()
+ return uintptr(segment.ID), nil, nil
+}
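+
+// For example, shmget(IPC_PRIVATE, 4096, IPC_CREAT|0600) arrives above as
+// private=true, create=true, exclusive=false, mode=0600.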
+
+// findSegment retrieves a shm segment by the given ID.
+//
+// findSegment returns a reference on Shm.
+func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
+ r := t.IPCNamespace().ShmRegistry()
+ segment := r.FindByID(id)
+ if segment == nil {
+ // No segment with provided id.
+ return nil, syserror.EINVAL
+ }
+ return segment, nil
+}
+
+// Shmat implements shmat(2).
+func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := shm.ID(args[0].Int())
+ addr := args[1].Pointer()
+ flag := args[2].Int()
+
+ segment, err := findSegment(t, id)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+ defer segment.DecRef()
+
+ opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
+ Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
+ Readonly: flag&linux.SHM_RDONLY == linux.SHM_RDONLY,
+ Remap: flag&linux.SHM_REMAP == linux.SHM_REMAP,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ addr, err = t.MemoryManager().MMap(t, opts)
+ return uintptr(addr), nil, err
+}
+
+// Shmdt implements shmdt(2).
+func Shmdt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ err := t.MemoryManager().DetachShm(t, addr)
+ return 0, nil, err
+}
+
+// Shmctl implements shmctl(2).
+func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := shm.ID(args[0].Int())
+ cmd := args[1].Int()
+ buf := args[2].Pointer()
+
+ r := t.IPCNamespace().ShmRegistry()
+
+ switch cmd {
+ case linux.SHM_STAT:
+ // Technically, we should be treating id as "an index into the kernel's
+ // internal array that maintains information about all shared memory
+ // segments on the system". Since we don't track segments in an array,
+ // we'll just pretend the shmid is the index and do the same thing as
+ // IPC_STAT. Linux also uses the index as the shmid.
+ fallthrough
+ case linux.IPC_STAT:
+ segment, err := findSegment(t, id)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+ defer segment.DecRef()
+
+ stat, err := segment.IPCStat(t)
+ if err == nil {
+ _, err = t.CopyOut(buf, stat)
+ }
+ return 0, nil, err
+
+ case linux.IPC_INFO:
+ params := r.IPCInfo()
+ _, err := t.CopyOut(buf, params)
+ return 0, nil, err
+
+ case linux.SHM_INFO:
+ info := r.ShmInfo()
+ _, err := t.CopyOut(buf, info)
+ return 0, nil, err
+ }
+
+ // Remaining commands refer to a specific segment.
+ segment, err := findSegment(t, id)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+ defer segment.DecRef()
+
+ switch cmd {
+ case linux.IPC_SET:
+ var ds linux.ShmidDS
+ _, err = t.CopyIn(buf, &ds)
+ if err != nil {
+ return 0, nil, err
+ }
+ err = segment.Set(t, &ds)
+ return 0, nil, err
+
+ case linux.IPC_RMID:
+ segment.MarkDestroyed()
+ return 0, nil, nil
+
+ case linux.SHM_LOCK, linux.SHM_UNLOCK:
+ // We currently do not support memory locking anywhere.
+ // mlock(2)/munlock(2) are currently stubbed out as no-ops so do the
+ // same here.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, nil
+
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
new file mode 100644
index 000000000..d2b0012ae
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -0,0 +1,590 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// "For a process to have permission to send a signal it must
+// - either be privileged (CAP_KILL), or
+// - the real or effective user ID of the sending process must be equal to the
+// real or saved set-user-ID of the target process.
+//
+// In the case of SIGCONT it suffices when the sending and receiving processes
+// belong to the same session." - kill(2)
+//
+// Equivalent to kernel/signal.c:check_kill_permission.
+func mayKill(t *kernel.Task, target *kernel.Task, sig linux.Signal) bool {
+ // kernel/signal.c:check_kill_permission also allows a signal if the
+ // sending and receiving tasks share a thread group, which is not
+ // mentioned in kill(2) since kill does not allow task-level
+ // granularity in signal sending.
+ if t.ThreadGroup() == target.ThreadGroup() {
+ return true
+ }
+
+ if t.HasCapabilityIn(linux.CAP_KILL, target.UserNamespace()) {
+ return true
+ }
+
+ creds := t.Credentials()
+ tcreds := target.Credentials()
+ if creds.EffectiveKUID == tcreds.SavedKUID ||
+ creds.EffectiveKUID == tcreds.RealKUID ||
+ creds.RealKUID == tcreds.SavedKUID ||
+ creds.RealKUID == tcreds.RealKUID {
+ return true
+ }
+
+ if sig == linux.SIGCONT && target.ThreadGroup().Session() == t.ThreadGroup().Session() {
+ return true
+ }
+ return false
+}
+
+// Kill implements linux syscall kill(2).
+func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := kernel.ThreadID(args[0].Int())
+ sig := linux.Signal(args[1].Int())
+
+ switch {
+ case pid > 0:
+ // "If pid is positive, then signal sig is sent to the process with the
+ // ID specified by pid." - kill(2)
+ // This loops to handle races with execve where target dies between
+ // TaskWithID and SendGroupSignal. Compare Linux's
+ // kernel/signal.c:kill_pid_info().
+ for {
+ target := t.PIDNamespace().TaskWithID(pid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(target.PIDNamespace().IDOfTask(t)))
+ info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
+ if err := target.SendGroupSignal(info); err != syserror.ESRCH {
+ return 0, nil, err
+ }
+ }
+ case pid == -1:
+ // "If pid equals -1, then sig is sent to every process for which the
+ // calling process has permission to send signals, except for process 1
+ // (init), but see below. ... POSIX.1-2001 requires that kill(-1,sig)
+ // send sig to all processes that the calling process may send signals
+ // to, except possibly for some implementation-defined system
+ // processes. Linux allows a process to signal itself, but on Linux the
+ // call kill(-1,sig) does not signal the calling process."
+ var (
+ lastErr error
+ delivered int
+ )
+ for _, tg := range t.PIDNamespace().ThreadGroups() {
+ if tg == t.ThreadGroup() {
+ continue
+ }
+ if t.PIDNamespace().IDOfThreadGroup(tg) == kernel.InitTID {
+ continue
+ }
+
+ // If pid == -1, the returned error is the last non-EPERM error
+ // from any call to group_send_sig_info.
+ if !mayKill(t, tg.Leader(), sig) {
+ continue
+ }
+ // Here and below, whether or not kill returns an error may
+ // depend on the iteration order. We at least implement the
+ // semantics documented by the man page: "On success (at least
+ // one signal was sent), zero is returned."
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ err := tg.SendSignal(info)
+ if err == syserror.ESRCH {
+ // ESRCH is ignored because it means the task
+ // exited while we were iterating. This is a
+ // race which would not normally exist on
+ // Linux, so we suppress it.
+ continue
+ }
+ delivered++
+ if err != nil {
+ lastErr = err
+ }
+ }
+ if delivered > 0 {
+ return 0, nil, lastErr
+ }
+ return 0, nil, syserror.ESRCH
+ default:
+ // "If pid equals 0, then sig is sent to every process in the process
+ // group of the calling process."
+ //
+ // "If pid is less than -1, then sig is sent to every process
+ // in the process group whose ID is -pid."
+ pgid := kernel.ProcessGroupID(-pid)
+ if pgid == 0 {
+ pgid = t.PIDNamespace().IDOfProcessGroup(t.ThreadGroup().ProcessGroup())
+ }
+
+ // If pid != -1 (i.e. signalling a process group), the returned error
+ // is the last error from any call to group_send_sig_info.
+ lastErr := syserror.ESRCH
+ for _, tg := range t.PIDNamespace().ThreadGroups() {
+ if t.PIDNamespace().IDOfProcessGroup(tg.ProcessGroup()) == pgid {
+ if !mayKill(t, tg.Leader(), sig) {
+ lastErr = syserror.EPERM
+ continue
+ }
+
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ // See the note above regarding the ESRCH race.
+ if err := tg.SendSignal(info); err != syserror.ESRCH {
+ lastErr = err
+ }
+ }
+ }
+
+ return 0, nil, lastErr
+ }
+}
+
+func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalInfo {
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoTkill,
+ }
+ info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
+ info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ return info
+}
+
+// Tkill implements linux syscall tkill(2).
+func Tkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ sig := linux.Signal(args[1].Int())
+
+ // N.B. Inconsistent with the man page, Linux actually rejects calls with
+ // tid <= 0 with EINVAL. This isn't the same for all signal calls.
+ if tid <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ target := t.PIDNamespace().TaskWithID(tid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, target.SendSignal(tkillSigInfo(t, target, sig))
+}
+
+// Tgkill implements linux syscall tgkill(2).
+func Tgkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tgid := kernel.ThreadID(args[0].Int())
+ tid := kernel.ThreadID(args[1].Int())
+ sig := linux.Signal(args[2].Int())
+
+ // N.B. Inconsistent with the man page, Linux actually rejects calls with
+ // tgid/tid <= 0 with EINVAL. This isn't the same for all signal calls.
+ if tgid <= 0 || tid <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ targetTG := t.PIDNamespace().ThreadGroupWithID(tgid)
+ target := t.PIDNamespace().TaskWithID(tid)
+ if targetTG == nil || target == nil || target.ThreadGroup() != targetTG {
+ return 0, nil, syserror.ESRCH
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, target.SendSignal(tkillSigInfo(t, target, sig))
+}
+
+// RtSigaction implements linux syscall rt_sigaction(2).
+func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sig := linux.Signal(args[0].Int())
+ newactarg := args[1].Pointer()
+ oldactarg := args[2].Pointer()
+ sigsetsize := args[3].SizeT()
+
+ if sigsetsize != linux.SignalSetSize {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var newactptr *arch.SignalAct
+ if newactarg != 0 {
+ newact, err := t.CopyInSignalAct(newactarg)
+ if err != nil {
+ return 0, nil, err
+ }
+ newactptr = &newact
+ }
+ oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if oldactarg != 0 {
+ if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// Sigreturn implements linux syscall sigreturn(2).
+func Sigreturn(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ctrl, err := t.SignalReturn(false)
+ return 0, ctrl, err
+}
+
+// RtSigreturn implements linux syscall rt_sigreturn(2).
+func RtSigreturn(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ctrl, err := t.SignalReturn(true)
+ return 0, ctrl, err
+}
+
+// RtSigprocmask implements linux syscall rt_sigprocmask(2).
+func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ how := args[0].Int()
+ setaddr := args[1].Pointer()
+ oldaddr := args[2].Pointer()
+ sigsetsize := args[3].SizeT()
+
+ if sigsetsize != linux.SignalSetSize {
+ return 0, nil, syserror.EINVAL
+ }
+ oldmask := t.SignalMask()
+ if setaddr != 0 {
+ mask, err := CopyInSigSet(t, setaddr, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch how {
+ case linux.SIG_BLOCK:
+ t.SetSignalMask(oldmask | mask)
+ case linux.SIG_UNBLOCK:
+ t.SetSignalMask(oldmask &^ mask)
+ case linux.SIG_SETMASK:
+ t.SetSignalMask(mask)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ }
+ if oldaddr != 0 {
+ return 0, nil, copyOutSigSet(t, oldaddr, oldmask)
+ }
+
+ return 0, nil, nil
+}
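+
+// The three operations above are plain set algebra on the 64-bit signal
+// mask:
+//
+//  SIG_BLOCK:   new = old | mask
+//  SIG_UNBLOCK: new = old &^ mask
+//  SIG_SETMASK: new = mask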
+
+// Sigaltstack implements linux syscall sigaltstack(2).
+func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ setaddr := args[0].Pointer()
+ oldaddr := args[1].Pointer()
+
+ alt := t.SignalStack()
+ if oldaddr != 0 {
+ if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil {
+ return 0, nil, err
+ }
+ }
+ if setaddr != 0 {
+ alt, err := t.CopyInSignalStack(setaddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // The signal stack cannot be changed if the task is currently
+ // on the stack. This is enforced at the lowest level because
+ // these semantics apply to changing the signal stack via a
+ // ucontext during a signal handler.
+ if !t.SetSignalStack(alt) {
+ return 0, nil, syserror.EPERM
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// Pause implements linux syscall pause(2).
+func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND)
+}
+
+// RtSigpending implements linux syscall rt_sigpending(2).
+func RtSigpending(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ pending := t.PendingSignals()
+ _, err := pending.CopyOut(t, addr)
+ return 0, nil, err
+}
+
+// RtSigtimedwait implements linux syscall rt_sigtimedwait(2).
+func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sigset := args[0].Pointer()
+ siginfo := args[1].Pointer()
+ timespec := args[2].Pointer()
+ sigsetsize := args[3].SizeT()
+
+ mask, err := CopyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var timeout time.Duration
+ if timespec != 0 {
+ d, err := copyTimespecIn(t, timespec)
+ if err != nil {
+ return 0, nil, err
+ }
+ if !d.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(d.ToNsecCapped())
+ } else {
+ timeout = time.Duration(math.MaxInt64)
+ }
+
+ si, err := t.Sigtimedwait(mask, timeout)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if siginfo != 0 {
+ si.FixSignalCodeForUser()
+ if _, err := si.CopyOut(t, siginfo); err != nil {
+ return 0, nil, err
+ }
+ }
+ return uintptr(si.Signo), nil, nil
+}
+
+// RtSigqueueinfo implements linux syscall rt_sigqueueinfo(2).
+func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := kernel.ThreadID(args[0].Int())
+ sig := linux.Signal(args[1].Int())
+ infoAddr := args[2].Pointer()
+
+ // Copy in the info.
+ //
+ // We must ensure that the Signo is set (Linux overrides this in the
+ // same way), and that the code is in the allowed set. This same logic
+ // appears below in RtTgsigqueueinfo and should be kept in sync.
+ var info arch.SignalInfo
+ if _, err := info.CopyIn(t, infoAddr); err != nil {
+ return 0, nil, err
+ }
+ info.Signo = int32(sig)
+
+ // This must loop to handle the race with execve described in Kill.
+ for {
+ // Deliver to the given task's thread group.
+ target := t.PIDNamespace().TaskWithID(pid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ // If the sender is not the receiver, it can't use si_codes used by the
+ // kernel or SI_TKILL.
+ if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ return 0, nil, syserror.EPERM
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+
+ if err := target.SendGroupSignal(&info); err != syserror.ESRCH {
+ return 0, nil, err
+ }
+ }
+}
+
+// RtTgsigqueueinfo implements linux syscall rt_tgsigqueueinfo(2).
+func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tgid := kernel.ThreadID(args[0].Int())
+ tid := kernel.ThreadID(args[1].Int())
+ sig := linux.Signal(args[2].Int())
+ infoAddr := args[3].Pointer()
+
+ // N.B. Inconsistent with the man page, Linux actually rejects calls with
+ // tgid/tid <= 0 with EINVAL. This isn't the same for all signal calls.
+ if tgid <= 0 || tid <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy in the info. See RtSigqueueinfo above.
+ var info arch.SignalInfo
+ if _, err := info.CopyIn(t, infoAddr); err != nil {
+ return 0, nil, err
+ }
+ info.Signo = int32(sig)
+
+ // Deliver to the given task.
+ targetTG := t.PIDNamespace().ThreadGroupWithID(tgid)
+ target := t.PIDNamespace().TaskWithID(tid)
+ if targetTG == nil || target == nil || target.ThreadGroup() != targetTG {
+ return 0, nil, syserror.ESRCH
+ }
+
+ // If the sender is not the receiver, it can't use si_codes used by the
+ // kernel or SI_TKILL.
+ if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ return 0, nil, syserror.EPERM
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, target.SendSignal(&info)
+}
+
+// RtSigsuspend implements linux syscall rt_sigsuspend(2).
+func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sigset := args[0].Pointer()
+
+ // Copy in the signal mask.
+ var mask linux.SignalSet
+ if _, err := mask.CopyIn(t, sigset); err != nil {
+ return 0, nil, err
+ }
+ mask &^= kernel.UnblockableSignals
+
+ // Swap the mask.
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+
+ // Perform the wait.
+ return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND)
+}
+
+// RestartSyscall implements the linux syscall restart_syscall(2).
+func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if r := t.SyscallRestartBlock(); r != nil {
+ n, err := r.Restart(t)
+ return n, nil, err
+ }
+ // The restart block should never be nil here, but it's possible
+ // ERESTART_RESTARTBLOCK was set by ptrace without the current syscall
+ // setting up a restart block. If ptrace didn't manipulate the return value,
+ // finding a nil restart block is a bug. Linux ensures that the restart
+ // function is never null by (re)initializing it with one that translates
+ // the restart into EINTR. We'll emulate that behaviour.
+ t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
+ return 0, nil, syserror.EINTR
+}
+
+// sharedSignalfd is shared between the two calls.
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := CopyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is this a signalfd?
+ if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
+ s.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create a new file.
+ file, err := signalfd.New(t, mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ // Set appropriate flags.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.SFD_NONBLOCK != 0,
+ })
+
+ // Create a new descriptor.
+ fd, err = t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done.
+ return uintptr(fd), nil, nil
+}
+
+// Signalfd implements the linux syscall signalfd(2).
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2).
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
new file mode 100644
index 000000000..0760af77b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -0,0 +1,1138 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024 * 8
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers up to INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset from the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset from the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
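+// sizeOfInt32 is the size in bytes of an int32, used to validate socket
+// option lengths.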
+const sizeOfInt32 = 4
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipleMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32
+}
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64
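+ // msgLen is set to the number of bytes sent or received for this message.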
+ msgLen uint32
+ _ int32
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
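+ // The 52-byte buffer covers all fields through Flags (bytes 0-51); the
+ // struct's trailing 4 bytes of padding need not be copied in.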
+ b := t.CopyScratchBuffer(52)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen {
+ return nil, syserror.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the untrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 {
+ return syserror.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
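+ // With no address to write, only the length is copied back.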
+ if addr == nil {
+ return nil
+ }
+
+ if bufLen > addrLen {
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) {
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
+
+// Socket implements the linux syscall socket(2).
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+
+ // Check and initialize the flags.
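+ // The lower four bits of stype hold the socket type; the rest must be
+ // known flags.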
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ s.SetFlags(fs.SettableFileFlags{
+ NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
+ })
+ defer s.DecRef()
+
+ fd, err := t.NewFDFrom(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+ socks := args[3].Pointer()
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ fileFlags := fs.SettableFileFlags{
+ NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ s1.SetFlags(fileFlags)
+ s2.SetFlags(fileFlags)
+ defer s1.DecRef()
+ defer s2.DecRef()
+
+ // Create the FDs for the sockets.
+ fds, err := t.NewFDs(0, []*fs.File{s1, s2}, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Copy the file descriptors out.
+ if _, err := t.CopyOut(socks, fds); err != nil {
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Connect implements the linux syscall connect(2).
+func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ blocking := !file.Flags().NonBlocking
+ return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS)
+}
+
+// accept is the implementation of the accept syscall. It is called by accept
+// and accept4 syscall handlers.
+func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+ // Check that no unsupported flags are passed in.
+ if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ // Call the syscall implementation for this socket, then copy the
+ // output address if one is specified.
+ blocking := !file.Flags().NonBlocking
+
+ peerRequested := addrLen != 0
+ nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ if peerRequested {
+ // NOTE(magi): Linux does not give you an error if it can't
+ // write the data back out so neither do we.
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ return 0, err
+ }
+ }
+ return uintptr(nfd), nil
+}
+
+// Accept4 implements the linux syscall accept4(2).
+func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+ flags := int(args[3].Int())
+
+ n, err := accept(t, fd, addr, addrlen, flags)
+ return n, nil, err
+}
+
+// Accept implements the linux syscall accept(2).
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ n, err := accept(t, fd, addr, addrlen, 0)
+ return n, nil, err
+}
+
+// Bind implements the linux syscall bind(2).
+func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, s.Bind(t, a).ToError()
+}
+
+// Listen implements the linux syscall listen(2).
+func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ backlog := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Per Linux, the backlog is silently capped to reasonable values.
+ if backlog <= 0 {
+ backlog = minListenBacklog
+ }
+ if backlog > maxListenBacklog {
+ backlog = maxListenBacklog
+ }
+
+ return 0, nil, s.Listen(t, int(backlog)).ToError()
+}
+
+// Shutdown implements the linux syscall shutdown(2).
+func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ how := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Validate how, then call syscall implementation.
+ switch how {
+ case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, s.Shutdown(t, int(how)).ToError()
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2).
+func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLenAddr := args[4].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Read the length. Reject negative values.
+ optLen := int32(0)
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Call syscall implementation then copy both value and value len out.
+ v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen))
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
+ }
+
+ if v != nil {
+ if _, err := t.CopyOut(optValAddr, v); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// getSockOpt tries to handle common socket options, or dispatches to a specific
+// socket implementation.
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+ if level == linux.SOL_SOCKET {
+ switch name {
+ case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
+ if len < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+
+ switch name {
+ case linux.SO_TYPE:
+ _, skType, _ := s.Type()
+ return int32(skType), nil
+ case linux.SO_DOMAIN:
+ family, _, _ := s.Type()
+ return int32(family), nil
+ case linux.SO_PROTOCOL:
+ _, _, protocol := s.Type()
+ return int32(protocol), nil
+ }
+ }
+
+ return s.GetSockOpt(t, level, name, optValAddr, len)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2).
+//
+// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
+func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLen := args[4].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if optLen > maxOptLen {
+ return 0, nil, syserror.EINVAL
+ }
+ buf := t.CopyScratchBuffer(int(optLen))
+ if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ return 0, nil, err
+ }
+
+ // Call syscall implementation.
+ if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2).
+func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket name and copy it to the caller.
+ v, vl, err := s.GetSockName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// GetPeerName implements the linux syscall getpeername(2).
+func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket peer name and copy it to the caller.
+ v, vl, err := s.GetPeerName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// RecvMsg implements the linux syscall recvmsg(2).
+func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
+ return n, nil, err
+}
+
+// RecvMMsg implements the linux syscall recvmmsg(2).
+func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+ toPtr := args[4].Pointer()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if toPtr != 0 {
+ ts, err := copyTimespecIn(t, toPtr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
+ haveDeadline = true
+ }
+
+ if !haveDeadline {
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
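+ // As in Linux, an error is reported only if no messages were received;
+ // otherwise the partial count is returned.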
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+ // Capture the message header and io vectors.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // FIXME(b/63594852): Pretend we have an empty error queue.
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return 0, syserror.EAGAIN
+ }
+
+ // Fast path when neither control-message nor name buffers are provided.
+ if msg.ControlLen == 0 && msg.NameLen == 0 {
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
+ }
+ if !cms.Unix.Empty() {
+ mflags |= linux.MSG_CTRUNC
+ cms.Release()
+ }
+
+ if int(msg.Flags) != mflags {
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+ }
+
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ defer cms.Release()
+
+ controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
+
+ if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
+ creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
+ controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
+ }
+
+ if cms.Unix.Rights != nil {
+ controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
+ }
+
+ // Copy the address to the caller.
+ if msg.NameLen != 0 {
+ if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy the control data to the caller.
+ if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ return 0, err
+ }
+ if len(controlData) > 0 {
+ if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+
+ return uintptr(n), nil
+}
+
+// recvFrom is the implementation of the recvfrom syscall. It is called by
+// recvfrom and recv syscall handlers.
+func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
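+ // Reject lengths so large they overflow a signed int.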
+ if int(bufLen) < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
+ cm.Release()
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+
+ // Copy the address to the caller.
+ if nameLenPtr != 0 {
+ if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+}
+
+// RecvFrom implements the linux syscall recvfrom(2).
+func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLenPtr := args[5].Pointer()
+
+ n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
+ return n, nil, err
+}
+
+// SendMsg implements the linux syscall sendmsg(2).
+func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := sendSingleMsg(t, s, file, msgPtr, flags)
+ return n, nil, err
+}
+
+// SendMMsg implements the linux syscall sendmmsg(2).
+func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
+ break
+ }
+
+ // Copy the sent length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
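+ // As above: report an error only when no messages were sent.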
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+ // Capture the message header.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ var controlData []byte
+ if msg.ControlLen > 0 {
+ // Put an upper bound to prevent large allocations.
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ controlData = make([]byte, msg.ControlLen)
+ if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ if msg.NameLen != 0 {
+ var err error
+ to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Read data then call the sendmsg implementation.
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ controlMessages, err := control.Parse(t, s, controlData)
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
+ err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
+ if err != nil {
+ controlMessages.Release()
+ }
+ return uintptr(n), err
+}
+
+// sendTo is the implementation of the sendto syscall. It is called by sendto
+// and send syscall handlers.
+func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+ bl := int(bufLen)
+ if bl < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ var err error
+ if namePtr != 0 {
+ to, err = CaptureAddress(t, namePtr, nameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
+ return uintptr(n), handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file)
+}
+
+// SendTo implements the linux syscall sendto(2).
+func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLen := args[5].Uint()
+
+ n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
+ return n, nil, err
+}
+
+// LINT.ThenChange(./vfs2/socket.go)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
new file mode 100644
index 000000000..77c78889d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -0,0 +1,337 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// doSplice implements a blocking splice operation.
+func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
+ if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
+ return 0, syserror.EINVAL
+ }
+
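+ // Cap the length at MAX_RW_COUNT, mirroring Linux's per-call I/O limit.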
+ if opts.Length > int64(kernel.MAX_RW_COUNT) {
+ opts.Length = int64(kernel.MAX_RW_COUNT)
+ }
+
+ var (
+ total int64
+ n int64
+ err error
+ inCh chan struct{}
+ outCh chan struct{}
+ )
+ for opts.Length > 0 {
+ n, err = fs.Splice(t, outFile, inFile, opts)
+ opts.Length -= n
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ } else if err == syserror.ErrWouldBlock && nonBlocking {
+ break
+ }
+
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(EventMaskRead) == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, EventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
+ }
+ if outFile.Readiness(EventMaskWrite) == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
+ }
+ }
+
+ if total > 0 {
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+ return total, err
+}
+
+// Sendfile implements linux system call sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := args[0].Int()
+ inFD := args[1].Int()
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ // Get files.
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ if !inFile.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ outFile := t.GetFile(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ if !outFile.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that the outfile Append flag is not set.
+ if outFile.Flags().Append {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Verify that we have a regular infile. This is a requirement; the
+ // same check appears in Linux (fs/splice.c:splice_direct_to_actor).
+ if !fs.IsRegular(inFile.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var (
+ n int64
+ err error
+ )
+ if offsetAddr != 0 {
+ // When the offset address is non-NULL, the input file must be
+ // seekable. The fs.Splice routine itself validates basic read.
+ if !inFile.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Copy in the offset.
+ var offset int64
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ return 0, nil, err
+ }
+
+ // Do the splice.
+ n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
+ Length: count,
+ SrcOffset: true,
+ SrcStart: offset,
+ }, outFile.Flags().NonBlocking)
+
+ // Copy out the new offset.
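+ // The file offset itself is left unchanged, per sendfile(2).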
+ if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
+ return 0, nil, err
+ }
+ } else {
+ // Send data using splice.
+ n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
+ Length: count,
+ }, outFile.Flags().NonBlocking)
+ }
+
+ // Sendfile can't lose any data because inFD is always a regular file.
+ if n != 0 {
+ err = nil
+ }
+
+ // We can only pass a single file to handleIOError, so pick inFile
+ // arbitrarily. This is used only for debugging purposes.
+ return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile)
+}
+
+// Splice implements splice(2).
+func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ inOffset := args[1].Pointer()
+ outFD := args[2].Int()
+ outOffset := args[3].Pointer()
+ count := int64(args[4].SizeT())
+ flags := args[5].Int()
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get files.
+ outFile := t.GetFile(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // Construct our options.
+ //
+ // Note that exactly one of the underlying buffers must be a pipe. We
+ // don't actually have this constraint internally, but we enforce it
+ // for the semantics of the call.
+ opts := fs.SpliceOpts{
+ Length: count,
+ }
+ inFileAttr := inFile.Dirent.Inode.StableAttr
+ outFileAttr := outFile.Dirent.Inode.StableAttr
+ switch {
+ case fs.IsPipe(inFileAttr) && !fs.IsPipe(outFileAttr):
+ if inOffset != 0 {
+ return 0, nil, syserror.ESPIPE
+ }
+ if outOffset != 0 {
+ if !outFile.Flags().Pwrite {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var offset int64
+ if _, err := t.CopyIn(outOffset, &offset); err != nil {
+ return 0, nil, err
+ }
+
+ // Use the destination offset.
+ opts.DstOffset = true
+ opts.DstStart = offset
+ }
+ case !fs.IsPipe(inFileAttr) && fs.IsPipe(outFileAttr):
+ if outOffset != 0 {
+ return 0, nil, syserror.ESPIPE
+ }
+ if inOffset != 0 {
+ if !inFile.Flags().Pread {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var offset int64
+ if _, err := t.CopyIn(inOffset, &offset); err != nil {
+ return 0, nil, err
+ }
+
+ // Use the source offset.
+ opts.SrcOffset = true
+ opts.SrcStart = offset
+ }
+ case fs.IsPipe(inFileAttr) && fs.IsPipe(outFileAttr):
+ if inOffset != 0 || outOffset != 0 {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFileAttr.InodeID == outFileAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Splice data.
+ n, err := doSplice(t, outFile, inFile, opts, nonBlock)
+
+ // Special files can have additional requirements for granularity. For
+ // example, a read from an eventfd returns EINVAL if the size is less
+ // than 8 bytes. Inotify is another example: read returns EINVAL if the
+ // buffer is too small to hold the next event, but the size of an event
+ // isn't fixed; it is sizeof(struct inotify_event) + {NAME_LEN} + 1.
+ if n != 0 && err != nil && (fs.IsAnonymous(inFileAttr) || fs.IsAnonymous(outFileAttr)) {
+ err = nil
+ }
+
+ // See above; inFile is chosen arbitrarily here.
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
+}
+
+// Tee implements tee(2).
+func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ outFD := args[1].Int()
+ count := int64(args[2].SizeT())
+ flags := args[3].Int()
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get files.
+ outFile := t.GetFile(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ // All files must be pipes.
+ if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // We may not refer to the same pipe; see above.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // Splice data.
+ n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{
+ Length: count,
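+ // Dup duplicates the data rather than consuming it from the input pipe.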
+ Dup: true,
+ }, nonBlock)
+
+ // Tee doesn't change the state of inFD, so it can't lose any data.
+ if n != 0 {
+ err = nil
+ }
+
+ // See above; inFile is chosen arbitrarily here.
+ return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
new file mode 100644
index 000000000..46ebf27a2
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -0,0 +1,290 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// Stat implements linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return stat(t, d, dirPath, statAddr)
+ })
+}
+
+// Fstatat implements linux syscall newfstatat, i.e. fstatat(2).
+func Fstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ path, dirPath, err := copyInPath(t, addr, flags&linux.AT_EMPTY_PATH != 0)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if path == "" {
+ // Annoying. What's wrong with fstat?
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, fstat(t, file, statAddr)
+ }
+
+ // If the path ends in a slash (i.e. dirPath is true) or if AT_SYMLINK_NOFOLLOW is unset,
+ // then we must resolve the final component.
+ resolve := dirPath || flags&linux.AT_SYMLINK_NOFOLLOW == 0
+
+ return 0, nil, fileOpOn(t, fd, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return stat(t, d, dirPath, statAddr)
+ })
+}
+
+// Lstat implements linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // If the path ends in a slash (i.e. dirPath is true), then we *do*
+ // want to resolve the final component.
+ resolve := dirPath
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return stat(t, d, dirPath, statAddr)
+ })
+}
+
+// Fstat implements linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, fstat(t, file, statAddr)
+}
+
+// stat implements stat from the given *fs.Dirent.
+func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ uattr, err := d.Inode.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ s := statFromAttrs(t, d.Inode.StableAttr, uattr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
+}
+
+// fstat implements fstat for the given *fs.File.
+func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
+ uattr, err := f.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
+}
+
+// Statx implements linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
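+ // The top mask bit is reserved; Linux rejects any request that sets it.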
+ if mask&linux.STATX__RESERVED != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, dirPath, err := copyInPath(t, pathAddr, flags&linux.AT_EMPTY_PATH != 0)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if path == "" {
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ uattr, err := file.UnstableAttr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, statx(t, file.Dirent.Inode.StableAttr, uattr, statxAddr)
+ }
+
+ resolve := dirPath || flags&linux.AT_SYMLINK_NOFOLLOW == 0
+
+ return 0, nil, fileOpOn(t, fd, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ uattr, err := d.Inode.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ return statx(t, d.Inode.StableAttr, uattr, statxAddr)
+ })
+}
+
+func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr usermem.Addr) error {
+ // "[T]he kernel may return fields that weren't requested and may fail to
+ // return fields that were requested, depending on what the backing
+ // filesystem supports.
+ // [...]
+ // A filesystem may also fill in fields that the caller didn't ask for
+ // if it has values for them available and the information is available
+ // at no extra cost. If this happens, the corresponding bits will be
+ // set in stx_mask." -- statx(2)
+ //
+ // We fill in all the values we have (which currently does not include
+ // btime, see b/135608823), regardless of what the user asked for. The
+ // STATX_BASIC_STATS mask indicates that all fields are present except
+ // for btime.
+
+ devMajor, devMinor := linux.DecodeDeviceID(uint32(sattr.DeviceID))
+ s := linux.Statx{
+ // TODO(b/135608823): Support btime, and then change this to
+ // STATX_ALL to indicate presence of btime.
+ Mask: linux.STATX_BASIC_STATS,
+
+ // No attributes, and none supported.
+ Attributes: 0,
+ AttributesMask: 0,
+
+ Blksize: uint32(sattr.BlockSize),
+ Nlink: uint32(uattr.Links),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Mode: uint16(sattr.Type.LinuxType()) | uint16(uattr.Perms.LinuxMode()),
+ Ino: sattr.InodeID,
+ Size: uint64(uattr.Size),
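+ // stx_blocks is counted in 512-byte units regardless of the block size.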
+ Blocks: uint64(uattr.Usage) / 512,
+ Atime: uattr.AccessTime.StatxTimestamp(),
+ Ctime: uattr.StatusChangeTime.StatxTimestamp(),
+ Mtime: uattr.ModificationTime.StatxTimestamp(),
+ RdevMajor: uint32(sattr.DeviceFileMajor),
+ RdevMinor: sattr.DeviceFileMinor,
+ DevMajor: uint32(devMajor),
+ DevMinor: devMinor,
+ }
+ _, err := t.CopyOut(statxAddr, &s)
+ return err
+}
+
+// Statfs implements linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ statfsAddr := args[1].Pointer()
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return statfsImpl(t, d, statfsAddr)
+ })
+}
+
+// Fstatfs implements linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statfsAddr := args[1].Pointer()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, statfsImpl(t, file.Dirent, statfsAddr)
+}
+
+// statfsImpl implements the linux syscall statfs and fstatfs based on a Dirent,
+// copying the statfs structure out to addr on success; otherwise an error is
+// returned.
+func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
+ info, err := d.Inode.StatFS(t)
+ if err != nil {
+ return err
+ }
+ // Construct the statfs structure and copy it out.
+ statfs := linux.Statfs{
+ Type: info.Type,
+ // Treat block size and fragment size as the same, as
+ // most consumers of this structure will expect one
+ // or the other to be filled in.
+ BlockSize: d.Inode.StableAttr.BlockSize,
+ Blocks: info.TotalBlocks,
+ // We don't have the concept of reserved blocks, so
+ // report blocks free the same as available blocks.
+ // This is a normal thing for filesystems to do; see
+ // udf, hugetlbfs, and tmpfs, among others.
+ BlocksFree: info.FreeBlocks,
+ BlocksAvailable: info.FreeBlocks,
+ Files: info.TotalFiles,
+ FilesFree: info.FreeFiles,
+ // Same as Linux for simple_statfs, see fs/libfs.c.
+ NameLength: linux.NAME_MAX,
+ FragmentSize: d.Inode.StableAttr.BlockSize,
+ // Leave other fields 0 like simple_statfs does.
+ }
+ _, err = t.CopyOut(addr, &statfs)
+ return err
+}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
new file mode 100644
index 000000000..0a04a6113
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uattr.Links,
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: sattr.BlockSize,
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_amd64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
new file mode 100644
index 000000000..5a3b1bfad
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uint32(uattr.Links),
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: int32(sattr.BlockSize),
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_arm64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
new file mode 100644
index 000000000..5ad465ae3
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// LINT.IfChange
+
+// Sync implements linux system call sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.MountNamespace().SyncAll(t)
+ // Sync is always successful.
+ return 0, nil, nil
+}
+
+// Syncfs implements linux system call syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Use "sync-the-world" for now, it's guaranteed that fd is at least
+ // on the root filesystem.
+ return Sync(t, args)
+}
+
+// Fsync implements linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll)
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// Fdatasync implements linux syscall fdatasync(2).
+//
+// At the moment, it just calls Fsync, which is a big hammer, but correct.
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData)
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// SyncFileRange implements linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ var err error
+
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ uflags := args[3].Uint()
+
+ if offset < 0 || offset+nbytes < offset {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if uflags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|
+ linux.SYNC_FILE_RANGE_WRITE|
+ linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if nbytes == 0 {
+ nbytes = fs.FileMaxOffset
+ }
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // SYNC_FILE_RANGE_WAIT_BEFORE waits upon write-out of all pages in the
+ // specified range that have already been submitted to the device
+ // driver for write-out before performing any write.
+ if uflags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
+ uflags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+	// SYNC_FILE_RANGE_WRITE initiates write-out of all dirty pages in the
+	// specified range which are not presently submitted for write-out.
+	//
+	// Implementing this faithfully would require a massive rework of the
+	// vfs subsystem: file.Fsync() takes a file lock for the entire
+	// operation, so even if it ran in a goroutine, it would block other
+	// file operations instead of flushing data in the background.
+	//
+	// It should be safe to skip this flag as long as nobody uses
+	// SYNC_FILE_RANGE_WAIT_BEFORE.
+
+ // SYNC_FILE_RANGE_WAIT_AFTER waits upon write-out of all pages in the
+ // range after performing any write.
+ //
+	// In Linux, sync_file_range() doesn't write out the file's
+	// metadata, but fdatasync() does if the file size has changed.
+ if uflags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 {
+ err = file.Fsync(t, offset, fs.FileMaxOffset, fs.SyncData)
+ }
+
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
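+
+// A minimal userspace sketch of the one flag combination above that
+// actually syncs (initiate write-out, then wait), assuming
+// golang.org/x/sys/unix on Linux:
+//
+//	// Flush the first MiB of fd and wait for the write-out to finish.
+//	err := unix.SyncFileRange(fd, 0, 1<<20,
+//		unix.SYNC_FILE_RANGE_WRITE|unix.SYNC_FILE_RANGE_WAIT_AFTER)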
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
new file mode 100644
index 000000000..297de052a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// Sysinfo implements the sysinfo syscall as described in man 2 sysinfo.
+func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ mf := t.Kernel().MemoryFile()
+ mf.UpdateUsage()
+ _, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ memFree := totalSize - totalUsage
+ if memFree > totalSize {
+ // Underflow.
+ memFree = 0
+ }
+
+	// Only a subset of the fields in sysinfo_t makes sense to return.
+ si := linux.Sysinfo{
+ Procs: uint16(len(t.PIDNamespace().Tasks())),
+ Uptime: t.Kernel().MonotonicClock().Now().Seconds(),
+ TotalRAM: totalSize,
+ FreeRAM: memFree,
+ Unit: 1,
+ }
+ _, err := t.CopyOut(addr, si)
+ return 0, nil, err
+}
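+
+// A minimal userspace sketch reading these fields back, assuming
+// golang.org/x/sys/unix on Linux:
+//
+//	var si unix.Sysinfo_t
+//	if err := unix.Sysinfo(&si); err == nil {
+//		// TotalRAM and FreeRAM above surface as Totalram and
+//		// Freeram, scaled by Unit (always 1 here, i.e. bytes).
+//		fmt.Printf("total=%d free=%d procs=%d\n",
+//			si.Totalram, si.Freeram, si.Procs)
+//	}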
diff --git a/pkg/sentry/syscalls/linux/sys_syslog.go b/pkg/sentry/syscalls/linux/sys_syslog.go
new file mode 100644
index 000000000..40c8bb061
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_syslog.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ _SYSLOG_ACTION_READ_ALL = 3
+ _SYSLOG_ACTION_SIZE_BUFFER = 10
+)
+
+// logBufLen is the default syslog buffer size on Linux.
+const logBufLen = 1 << 17
+
+// Syslog implements part of Linux syscall syslog.
+//
+// Only the unprivileged commands are implemented, allowing applications to
+// read a fun dmesg.
+func Syslog(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ command := args[0].Int()
+ buf := args[1].Pointer()
+ size := int(args[2].Int())
+
+ switch command {
+ case _SYSLOG_ACTION_READ_ALL:
+ if size < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if size > logBufLen {
+ size = logBufLen
+ }
+
+ log := t.Kernel().Syslog().Log()
+ if len(log) > size {
+ log = log[:size]
+ }
+
+ n, err := t.CopyOutBytes(buf, log)
+ return uintptr(n), nil, err
+ case _SYSLOG_ACTION_SIZE_BUFFER:
+ return logBufLen, nil, nil
+ default:
+ return 0, nil, syserror.ENOSYS
+ }
+}
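+
+// A minimal userspace sketch of the two supported commands, assuming
+// golang.org/x/sys/unix on Linux:
+//
+//	n, _ := unix.Klogctl(unix.SYSLOG_ACTION_SIZE_BUFFER, nil)
+//	buf := make([]byte, n)
+//	n, err := unix.Klogctl(unix.SYSLOG_ACTION_READ_ALL, buf)
+//	if err == nil {
+//		fmt.Print(string(buf[:n])) // the sentry's "fun" dmesg
+//	}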
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
new file mode 100644
index 000000000..00915fdde
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -0,0 +1,769 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "path"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // ExecMaxTotalSize is the maximum length of all argv and envv entries.
+ //
+	// N.B. The behavior here is different from Linux. Linux provides a limit on
+	// individual arguments of 32 pages, and an aggregate limit of at least 32 pages
+	// but otherwise bounded by min(stack size / 4, 8 MB * 3 / 4). We don't implement
+	// any behavior based on the stack size, and instead provide a fixed hard limit of
+	// 2 MB (which should work well given that 8 MB stack limits are common).
+ ExecMaxTotalSize = 2 * 1024 * 1024
+
+ // ExecMaxElemSize is the maximum length of a single argv or envv entry.
+ ExecMaxElemSize = 32 * usermem.PageSize
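+	// (With the standard 4 KiB page size, this is 128 KiB per entry.)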
+
+ // exitSignalMask is the signal mask to be sent at exit. Same as CSIGNAL in linux.
+ exitSignalMask = 0xff
+)
+
+// Getppid implements linux syscall getppid(2).
+func Getppid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ parent := t.Parent()
+ if parent == nil {
+ return 0, nil, nil
+ }
+ return uintptr(t.PIDNamespace().IDOfThreadGroup(parent.ThreadGroup())), nil, nil
+}
+
+// Getpid implements linux syscall getpid(2).
+func Getpid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return uintptr(t.ThreadGroup().ID()), nil, nil
+}
+
+// Gettid implements linux syscall gettid(2).
+func Gettid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return uintptr(t.ThreadID()), nil, nil
+}
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ filenameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+
+ return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+
+ return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, ExecMaxElemSize, ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, ExecMaxElemSize, ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ atEmptyPath := flags&linux.AT_EMPTY_PATH != 0
+ if !atEmptyPath && len(pathname) == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0
+
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+
+ var wd *fs.Dirent
+ var executable fsbridge.File
+ var closeOnExec bool
+ if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
+ // Even if the pathname is absolute, we may still need the wd
+ // for interpreter scripts if the path of the interpreter is
+ // relative.
+ wd = t.FSContext().WorkingDirectory()
+ } else {
+ // Need to extract the given FD.
+ f, fdFlags := t.FDTable().Get(dirFD)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+ closeOnExec = fdFlags.CloseOnExec
+
+ if atEmptyPath && len(pathname) == 0 {
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ if err := f.Dirent.Inode.CheckPermission(t, fs.PermMask{Read: true, Execute: true}); err != nil {
+ return 0, nil, err
+ }
+ executable = fsbridge.NewFSFile(f)
+ } else {
+ wd = f.Dirent
+ wd.IncRef()
+ if !fs.IsDir(wd.Inode.StableAttr) {
+ return 0, nil, syserror.ENOTDIR
+ }
+ }
+ }
+ if wd != nil {
+ defer wd.DecRef()
+ }
+
+ // Load the new TaskContext.
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewFSLookup(t.MountNamespace(), root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: resolveFinal,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
+
+// Exit implements linux syscall exit(2).
+func Exit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ status := int(args[0].Int())
+ t.PrepareExit(kernel.ExitStatus{Code: status})
+ return 0, kernel.CtrlDoExit, nil
+}
+
+// ExitGroup implements linux syscall exit_group(2).
+func ExitGroup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ status := int(args[0].Int())
+ t.PrepareGroupExit(kernel.ExitStatus{Code: status})
+ return 0, kernel.CtrlDoExit, nil
+}
+
+// clone is used by Clone, Fork, and VFork.
+func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr, childTID usermem.Addr, tls usermem.Addr) (uintptr, *kernel.SyscallControl, error) {
+ opts := kernel.CloneOptions{
+ SharingOptions: kernel.SharingOptions{
+ NewAddressSpace: flags&linux.CLONE_VM == 0,
+ NewSignalHandlers: flags&linux.CLONE_SIGHAND == 0,
+ NewThreadGroup: flags&linux.CLONE_THREAD == 0,
+ TerminationSignal: linux.Signal(flags & exitSignalMask),
+ NewPIDNamespace: flags&linux.CLONE_NEWPID == linux.CLONE_NEWPID,
+ NewUserNamespace: flags&linux.CLONE_NEWUSER == linux.CLONE_NEWUSER,
+ NewNetworkNamespace: flags&linux.CLONE_NEWNET == linux.CLONE_NEWNET,
+ NewFiles: flags&linux.CLONE_FILES == 0,
+ NewFSContext: flags&linux.CLONE_FS == 0,
+ NewUTSNamespace: flags&linux.CLONE_NEWUTS == linux.CLONE_NEWUTS,
+ NewIPCNamespace: flags&linux.CLONE_NEWIPC == linux.CLONE_NEWIPC,
+ },
+ Stack: stack,
+ SetTLS: flags&linux.CLONE_SETTLS == linux.CLONE_SETTLS,
+ TLS: tls,
+ ChildClearTID: flags&linux.CLONE_CHILD_CLEARTID == linux.CLONE_CHILD_CLEARTID,
+ ChildSetTID: flags&linux.CLONE_CHILD_SETTID == linux.CLONE_CHILD_SETTID,
+ ChildTID: childTID,
+ ParentSetTID: flags&linux.CLONE_PARENT_SETTID == linux.CLONE_PARENT_SETTID,
+ ParentTID: parentTID,
+ Vfork: flags&linux.CLONE_VFORK == linux.CLONE_VFORK,
+ Untraced: flags&linux.CLONE_UNTRACED == linux.CLONE_UNTRACED,
+ InheritTracer: flags&linux.CLONE_PTRACE == linux.CLONE_PTRACE,
+ }
+ ntid, ctrl, err := t.Clone(&opts)
+ return uintptr(ntid), ctrl, err
+}
+
+// Fork implements Linux syscall fork(2).
+func Fork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // "A call to fork() is equivalent to a call to clone(2) specifying flags
+ // as just SIGCHLD." - fork(2)
+ return clone(t, int(linux.SIGCHLD), 0, 0, 0, 0)
+}
+
+// Vfork implements Linux syscall vfork(2).
+func Vfork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // """
+ // A call to vfork() is equivalent to calling clone(2) with flags specified as:
+ //
+ // CLONE_VM | CLONE_VFORK | SIGCHLD
+ // """ - vfork(2)
+ return clone(t, linux.CLONE_VM|linux.CLONE_VFORK|int(linux.SIGCHLD), 0, 0, 0, 0)
+}
+
+// parseCommonWaitOptions applies the options common to wait4 and waitid to
+// wopts.
+func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error {
+ switch options & (linux.WCLONE | linux.WALL) {
+ case 0:
+ wopts.NonCloneTasks = true
+ case linux.WCLONE:
+ wopts.CloneTasks = true
+ case linux.WALL:
+ wopts.NonCloneTasks = true
+ wopts.CloneTasks = true
+ default:
+ return syserror.EINVAL
+ }
+ if options&linux.WCONTINUED != 0 {
+ wopts.Events |= kernel.EventGroupContinue
+ }
+ if options&linux.WNOHANG == 0 {
+ wopts.BlockInterruptErr = kernel.ERESTARTSYS
+ }
+ if options&linux.WNOTHREAD == 0 {
+ wopts.SiblingChildren = true
+ }
+ return nil
+}
+
+// wait4 waits for the given child process to exit.
+func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusageAddr usermem.Addr) (uintptr, error) {
+ if options&^(linux.WNOHANG|linux.WUNTRACED|linux.WCONTINUED|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
+ return 0, syserror.EINVAL
+ }
+ wopts := kernel.WaitOptions{
+ Events: kernel.EventExit | kernel.EventTraceeStop,
+ ConsumeEvent: true,
+ }
+ // There are four cases to consider:
+ //
+ // pid < -1 any child process whose process group ID is equal to the absolute value of pid
+ // pid == -1 any child process
+ // pid == 0 any child process whose process group ID is equal to that of the calling process
+ // pid > 0 the child whose process ID is equal to the value of pid
+ switch {
+ case pid < -1:
+ wopts.SpecificPGID = kernel.ProcessGroupID(-pid)
+ case pid == -1:
+ // Any process is the default.
+ case pid == 0:
+ wopts.SpecificPGID = t.PIDNamespace().IDOfProcessGroup(t.ThreadGroup().ProcessGroup())
+ default:
+ wopts.SpecificTID = kernel.ThreadID(pid)
+ }
+
+ if err := parseCommonWaitOptions(&wopts, options); err != nil {
+ return 0, err
+ }
+ if options&linux.WUNTRACED != 0 {
+ wopts.Events |= kernel.EventChildGroupStop
+ }
+
+ wr, err := t.Wait(&wopts)
+ if err != nil {
+ if err == kernel.ErrNoWaitableEvent {
+ return 0, nil
+ }
+ return 0, err
+ }
+ if statusAddr != 0 {
+ if _, err := t.CopyOut(statusAddr, wr.Status); err != nil {
+ return 0, err
+ }
+ }
+ if rusageAddr != 0 {
+ ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
+ if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ return 0, err
+ }
+ }
+ return uintptr(wr.TID), nil
+}
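+
+// A minimal userspace sketch of the pid conventions above, assuming
+// golang.org/x/sys/unix on Linux:
+//
+//	var ws unix.WaitStatus
+//	var ru unix.Rusage
+//	// Reap any child without blocking; pid == 0 means none were ready.
+//	pid, err := unix.Wait4(-1, &ws, unix.WNOHANG, &ru)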
+
+// Wait4 implements linux syscall wait4(2).
+func Wait4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := int(args[0].Int())
+ statusAddr := args[1].Pointer()
+ options := int(args[2].Uint())
+ rusageAddr := args[3].Pointer()
+
+ n, err := wait4(t, pid, statusAddr, options, rusageAddr)
+ return n, nil, err
+}
+
+// WaitPid implements linux syscall waitpid(2).
+func WaitPid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := int(args[0].Int())
+ statusAddr := args[1].Pointer()
+ options := int(args[2].Uint())
+
+ n, err := wait4(t, pid, statusAddr, options, 0)
+ return n, nil, err
+}
+
+// Waitid implements linux syscall waitid(2).
+func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ idtype := args[0].Int()
+ id := args[1].Int()
+ infop := args[2].Pointer()
+ options := int(args[3].Uint())
+ rusageAddr := args[4].Pointer()
+
+ if options&^(linux.WNOHANG|linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED|linux.WNOWAIT|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if options&(linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED) == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ wopts := kernel.WaitOptions{
+ Events: kernel.EventTraceeStop,
+ ConsumeEvent: options&linux.WNOWAIT == 0,
+ }
+ switch idtype {
+ case linux.P_ALL:
+ case linux.P_PID:
+ wopts.SpecificTID = kernel.ThreadID(id)
+ case linux.P_PGID:
+ wopts.SpecificPGID = kernel.ProcessGroupID(id)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ if err := parseCommonWaitOptions(&wopts, options); err != nil {
+ return 0, nil, err
+ }
+ if options&linux.WEXITED != 0 {
+ wopts.Events |= kernel.EventExit
+ }
+ if options&linux.WSTOPPED != 0 {
+ wopts.Events |= kernel.EventChildGroupStop
+ }
+
+ wr, err := t.Wait(&wopts)
+ if err != nil {
+ if err == kernel.ErrNoWaitableEvent {
+ err = nil
+ // "If WNOHANG was specified in options and there were no children
+ // in a waitable state, then waitid() returns 0 immediately and the
+ // state of the siginfo_t structure pointed to by infop is
+ // unspecified." - waitid(2). But Linux's waitid actually zeroes
+ // out the fields it would set for a successful waitid in this case
+ // as well.
+ if infop != 0 {
+ var si arch.SignalInfo
+ _, err = t.CopyOut(infop, &si)
+ }
+ }
+ return 0, nil, err
+ }
+ if rusageAddr != 0 {
+ ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
+ if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ return 0, nil, err
+ }
+ }
+ if infop == 0 {
+ return 0, nil, nil
+ }
+ si := arch.SignalInfo{
+ Signo: int32(linux.SIGCHLD),
+ }
+ si.SetPid(int32(wr.TID))
+ si.SetUid(int32(wr.UID))
+ // TODO(b/73541790): convert kernel.ExitStatus to functions and make
+ // WaitResult.Status a linux.WaitStatus.
+ s := syscall.WaitStatus(wr.Status)
+ switch {
+ case s.Exited():
+ si.Code = arch.CLD_EXITED
+ si.SetStatus(int32(s.ExitStatus()))
+ case s.Signaled():
+ si.Code = arch.CLD_KILLED
+ si.SetStatus(int32(s.Signal()))
+ case s.CoreDump():
+ si.Code = arch.CLD_DUMPED
+ si.SetStatus(int32(s.Signal()))
+ case s.Stopped():
+ if wr.Event == kernel.EventTraceeStop {
+ si.Code = arch.CLD_TRAPPED
+ si.SetStatus(int32(s.TrapCause()))
+ } else {
+ si.Code = arch.CLD_STOPPED
+ si.SetStatus(int32(s.StopSignal()))
+ }
+ case s.Continued():
+ si.Code = arch.CLD_CONTINUED
+ si.SetStatus(int32(linux.SIGCONT))
+ default:
+ t.Warningf("waitid got incomprehensible wait status %d", s)
+ }
+ _, err = t.CopyOut(infop, &si)
+ return 0, nil, err
+}
+
+// SetTidAddress implements linux syscall set_tid_address(2).
+func SetTidAddress(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ // Always succeed, return caller's tid.
+ t.SetClearTID(addr)
+ return uintptr(t.ThreadID()), nil, nil
+}
+
+// Unshare implements linux syscall unshare(2).
+func Unshare(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ opts := kernel.SharingOptions{
+ NewAddressSpace: flags&linux.CLONE_VM == linux.CLONE_VM,
+ NewSignalHandlers: flags&linux.CLONE_SIGHAND == linux.CLONE_SIGHAND,
+ NewThreadGroup: flags&linux.CLONE_THREAD == linux.CLONE_THREAD,
+ NewPIDNamespace: flags&linux.CLONE_NEWPID == linux.CLONE_NEWPID,
+ NewUserNamespace: flags&linux.CLONE_NEWUSER == linux.CLONE_NEWUSER,
+ NewNetworkNamespace: flags&linux.CLONE_NEWNET == linux.CLONE_NEWNET,
+ NewFiles: flags&linux.CLONE_FILES == linux.CLONE_FILES,
+ NewFSContext: flags&linux.CLONE_FS == linux.CLONE_FS,
+ NewUTSNamespace: flags&linux.CLONE_NEWUTS == linux.CLONE_NEWUTS,
+ NewIPCNamespace: flags&linux.CLONE_NEWIPC == linux.CLONE_NEWIPC,
+ }
+ // "CLONE_NEWPID automatically implies CLONE_THREAD as well." - unshare(2)
+ if opts.NewPIDNamespace {
+ opts.NewThreadGroup = true
+ }
+ // "... specifying CLONE_NEWUSER automatically implies CLONE_THREAD. Since
+ // Linux 3.9, CLONE_NEWUSER also automatically implies CLONE_FS."
+ if opts.NewUserNamespace {
+ opts.NewThreadGroup = true
+ opts.NewFSContext = true
+ }
+ return 0, nil, t.Unshare(&opts)
+}
+
+// SchedYield implements linux syscall sched_yield(2).
+func SchedYield(t *kernel.Task, _ arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.Yield()
+ return 0, nil, nil
+}
+
+// SchedSetaffinity implements linux syscall sched_setaffinity(2).
+func SchedSetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := args[0].Int()
+ size := args[1].SizeT()
+ maskAddr := args[2].Pointer()
+
+ var task *kernel.Task
+ if tid == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ mask := sched.NewCPUSet(t.Kernel().ApplicationCores())
+ if size > mask.Size() {
+ size = mask.Size()
+ }
+ if _, err := t.CopyInBytes(maskAddr, mask[:size]); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, task.SetCPUMask(mask)
+}
+
+// SchedGetaffinity implements linux syscall sched_getaffinity(2).
+func SchedGetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := args[0].Int()
+ size := args[1].SizeT()
+ maskAddr := args[2].Pointer()
+
+	// Linux stores the cpumask in an array of "unsigned long", so the
+	// buffer size must be a multiple of the word size.
+ if size&(t.Arch().Width()-1) > 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var task *kernel.Task
+ if tid == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ mask := task.CPUMask()
+ // The buffer needs to be big enough to hold a cpumask with
+ // all possible cpus.
+ if size < mask.Size() {
+ return 0, nil, syserror.EINVAL
+ }
+ _, err := t.CopyOutBytes(maskAddr, mask)
+
+	// NOTE: The syscall interface is slightly different from the glibc
+	// interface. The raw sched_getaffinity syscall returns the number of
+	// bytes used to represent a cpu mask.
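+	// For example, with 16 application cores the mask occupies one
+	// 64-bit word, so the raw syscall returns 8, whereas glibc's
+	// wrapper returns 0 on success.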
+ return uintptr(mask.Size()), nil, err
+}
+
+// Getcpu implements linux syscall getcpu(2).
+func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ cpu := args[0].Pointer()
+ node := args[1].Pointer()
+	// The third argument to this system call is nowadays unused.
+
+ if cpu != 0 {
+ buf := t.CopyScratchBuffer(4)
+ usermem.ByteOrder.PutUint32(buf, uint32(t.CPU()))
+ if _, err := t.CopyOutBytes(cpu, buf); err != nil {
+ return 0, nil, err
+ }
+ }
+ // We always return node 0.
+ if node != 0 {
+ if _, err := t.MemoryManager().ZeroOut(t, node, 4, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// Setpgid implements the linux syscall setpgid(2).
+func Setpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Note that throughout this function, pgid is interpreted with respect
+ // to t's namespace, not with respect to the selected ThreadGroup's
+ // namespace (which may be different).
+ pid := kernel.ThreadID(args[0].Int())
+ pgid := kernel.ProcessGroupID(args[1].Int())
+
+ // "If pid is zero, then the process ID of the calling process is used."
+ tg := t.ThreadGroup()
+ if pid != 0 {
+ ot := t.PIDNamespace().TaskWithID(pid)
+ if ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ tg = ot.ThreadGroup()
+ if tg.Leader() != ot {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Setpgid only operates on child threadgroups.
+ if tg != t.ThreadGroup() && (tg.Leader().Parent() == nil || tg.Leader().Parent().ThreadGroup() != t.ThreadGroup()) {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // "If pgid is zero, then the PGID of the process specified by pid is made
+ // the same as its process ID."
+ defaultPGID := kernel.ProcessGroupID(t.PIDNamespace().IDOfThreadGroup(tg))
+ if pgid == 0 {
+ pgid = defaultPGID
+ } else if pgid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // If the pgid is the same as the group, then create a new one. Otherwise,
+ // we attempt to join an existing process group.
+ if pgid == defaultPGID {
+ // For convenience, errors line up with Linux syscall API.
+ if err := tg.CreateProcessGroup(); err != nil {
+ // Is the process group already as expected? If so,
+ // just return success. This is the same behavior as
+ // Linux.
+ if t.PIDNamespace().IDOfProcessGroup(tg.ProcessGroup()) == defaultPGID {
+ return 0, nil, nil
+ }
+ return 0, nil, err
+ }
+ } else {
+ // Same as CreateProcessGroup, above.
+ if err := tg.JoinProcessGroup(t.PIDNamespace(), pgid, tg != t.ThreadGroup()); err != nil {
+ // See above.
+ if t.PIDNamespace().IDOfProcessGroup(tg.ProcessGroup()) == pgid {
+ return 0, nil, nil
+ }
+ return 0, nil, err
+ }
+ }
+
+ // Success.
+ return 0, nil, nil
+}
+
+// Getpgrp implements the linux syscall getpgrp(2).
+func Getpgrp(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return uintptr(t.PIDNamespace().IDOfProcessGroup(t.ThreadGroup().ProcessGroup())), nil, nil
+}
+
+// Getpgid implements the linux syscall getpgid(2).
+func Getpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ if tid == 0 {
+ return Getpgrp(t, args)
+ }
+
+ target := t.PIDNamespace().TaskWithID(tid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ return uintptr(t.PIDNamespace().IDOfProcessGroup(target.ThreadGroup().ProcessGroup())), nil, nil
+}
+
+// Setsid implements the linux syscall setsid(2).
+func Setsid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.ThreadGroup().CreateSession()
+}
+
+// Getsid implements the linux syscall getsid(2).
+func Getsid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ if tid == 0 {
+ return uintptr(t.PIDNamespace().IDOfSession(t.ThreadGroup().Session())), nil, nil
+ }
+
+ target := t.PIDNamespace().TaskWithID(tid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ return uintptr(t.PIDNamespace().IDOfSession(target.ThreadGroup().Session())), nil, nil
+}
+
+// Getpriority pretends to implement the linux syscall getpriority(2).
+//
+// This is a stub; real priorities require a full scheduler.
+func Getpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ which := args[0].Int()
+ who := kernel.ThreadID(args[1].Int())
+
+ switch which {
+ case linux.PRIO_PROCESS:
+ // Look for who, return ESRCH if not found.
+ var task *kernel.Task
+ if who == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(who)
+ }
+
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ // From kernel/sys.c:getpriority:
+ // "To avoid negative return values, 'getpriority()'
+ // will not return the normal nice-value, but a negated
+ // value that has been offset by 20"
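+		// For example, the default niceness of 0 is returned as 20,
+		// and a niceness of -5 as 25.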
+ return uintptr(20 - task.Niceness()), nil, nil
+ case linux.PRIO_USER:
+ fallthrough
+ case linux.PRIO_PGRP:
+ // PRIO_USER and PRIO_PGRP have no further implementation yet.
+ return 0, nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// Setpriority pretends to implement the linux syscall setpriority(2).
+//
+// This is a stub; real priorities require a full scheduler.
+func Setpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ which := args[0].Int()
+ who := kernel.ThreadID(args[1].Int())
+ niceval := int(args[2].Int())
+
+ // In the kernel's implementation, values outside the range
+ // of [-20, 19] are truncated to these minimum and maximum
+ // values.
+ if niceval < -20 /* min niceval */ {
+ niceval = -20
+ } else if niceval > 19 /* max niceval */ {
+ niceval = 19
+ }
+
+ switch which {
+ case linux.PRIO_PROCESS:
+ // Look for who, return ESRCH if not found.
+ var task *kernel.Task
+ if who == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(who)
+ }
+
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ task.SetNiceness(niceval)
+ case linux.PRIO_USER:
+ fallthrough
+ case linux.PRIO_PGRP:
+ // PRIO_USER and PRIO_PGRP have no further implementation yet.
+ return 0, nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
+
+// Ptrace implements linux system call ptrace(2).
+func Ptrace(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ req := args[0].Int64()
+ pid := kernel.ThreadID(args[1].Int())
+ addr := args[2].Pointer()
+ data := args[3].Pointer()
+
+ return 0, nil, t.Ptrace(req, pid, addr, data)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
new file mode 100644
index 000000000..2d2aa0819
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -0,0 +1,342 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// The most significant 29 bits hold either a pid or a file descriptor.
+func pidOfClockID(c int32) kernel.ThreadID {
+ return kernel.ThreadID(^(c >> 3))
+}
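+
+// For example, a per-process CPUCLOCK_PROF clock for pid 100 is encoded
+// as ((^100) << 3) | CPUCLOCK_PROF = -808, and pidOfClockID(-808)
+// recovers ^(-808 >> 3) = 100.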
+
+// whichCPUClock returns one of CPUCLOCK_PERF, CPUCLOCK_VIRT, CPUCLOCK_SCHED or
+// CLOCK_FD.
+func whichCPUClock(c int32) int32 {
+ return c & linux.CPUCLOCK_CLOCK_MASK
+}
+
+// isCPUClockPerThread returns true if the CPUCLOCK_PERTHREAD bit is set in the
+// clock id.
+func isCPUClockPerThread(c int32) bool {
+ return c&linux.CPUCLOCK_PERTHREAD_MASK != 0
+}
+
+// isValidCPUClock checks that the cpu clock id is valid.
+func isValidCPUClock(c int32) bool {
+ // Bits 0, 1, and 2 cannot all be set.
+ if c&7 == 7 {
+ return false
+ }
+ if whichCPUClock(c) >= linux.CPUCLOCK_MAX {
+ return false
+ }
+ return true
+}
+
+// targetTask returns the kernel.Task for the given clock id.
+func targetTask(t *kernel.Task, c int32) *kernel.Task {
+ pid := pidOfClockID(c)
+ if pid == 0 {
+ return t
+ }
+ return t.PIDNamespace().TaskWithID(pid)
+}
+
+// ClockGetres implements linux syscall clock_getres(2).
+func ClockGetres(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := int32(args[0].Int())
+ addr := args[1].Pointer()
+ r := linux.Timespec{
+ Sec: 0,
+ Nsec: 1,
+ }
+
+ if _, err := getClock(t, clockID); err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if addr == 0 {
+ // Don't need to copy out.
+ return 0, nil, nil
+ }
+
+ return 0, nil, copyTimespecOut(t, addr, &r)
+}
+
+type cpuClocker interface {
+ UserCPUClock() ktime.Clock
+ CPUClock() ktime.Clock
+}
+
+func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) {
+ if clockID < 0 {
+ if !isValidCPUClock(clockID) {
+ return nil, syserror.EINVAL
+ }
+
+ targetTask := targetTask(t, clockID)
+ if targetTask == nil {
+ return nil, syserror.EINVAL
+ }
+
+ var target cpuClocker
+ if isCPUClockPerThread(clockID) {
+ target = targetTask
+ } else {
+ target = targetTask.ThreadGroup()
+ }
+
+ switch whichCPUClock(clockID) {
+ case linux.CPUCLOCK_VIRT:
+ return target.UserCPUClock(), nil
+ case linux.CPUCLOCK_PROF, linux.CPUCLOCK_SCHED:
+ // CPUCLOCK_SCHED is approximated by CPUCLOCK_PROF.
+ return target.CPUClock(), nil
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+
+ switch clockID {
+ case linux.CLOCK_REALTIME, linux.CLOCK_REALTIME_COARSE:
+ return t.Kernel().RealtimeClock(), nil
+ case linux.CLOCK_MONOTONIC, linux.CLOCK_MONOTONIC_COARSE,
+ linux.CLOCK_MONOTONIC_RAW, linux.CLOCK_BOOTTIME:
+ // CLOCK_MONOTONIC approximates CLOCK_MONOTONIC_RAW.
+ // CLOCK_BOOTTIME is internally mapped to CLOCK_MONOTONIC, as:
+ // - CLOCK_BOOTTIME should behave as CLOCK_MONOTONIC while also
+ // including suspend time.
+ // - gVisor has no concept of suspend/resume.
+ // - CLOCK_MONOTONIC already includes save/restore time, which is
+ // the closest to suspend time.
+ return t.Kernel().MonotonicClock(), nil
+ case linux.CLOCK_PROCESS_CPUTIME_ID:
+ return t.ThreadGroup().CPUClock(), nil
+ case linux.CLOCK_THREAD_CPUTIME_ID:
+ return t.CPUClock(), nil
+ default:
+ return nil, syserror.EINVAL
+ }
+}
+
+// ClockGettime implements linux syscall clock_gettime(2).
+func ClockGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := int32(args[0].Int())
+ addr := args[1].Pointer()
+
+ c, err := getClock(t, clockID)
+ if err != nil {
+ return 0, nil, err
+ }
+ ts := c.Now().Timespec()
+ return 0, nil, copyTimespecOut(t, addr, &ts)
+}
+
+// ClockSettime implements linux syscall clock_settime(2).
+func ClockSettime(*kernel.Task, arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.EPERM
+}
+
+// Time implements linux syscall time(2).
+func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ r := t.Kernel().RealtimeClock().Now().TimeT()
+ if addr == usermem.Addr(0) {
+ return uintptr(r), nil, nil
+ }
+
+ if _, err := t.CopyOut(addr, r); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r), nil, nil
+}
+
+// clockNanosleepRestartBlock encapsulates the state required to restart
+// clock_nanosleep(2) via restart_syscall(2).
+//
+// +stateify savable
+type clockNanosleepRestartBlock struct {
+ c ktime.Clock
+ duration time.Duration
+ rem usermem.Addr
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (n *clockNanosleepRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return 0, clockNanosleepFor(t, n.c, n.duration, n.rem)
+}
+
+// clockNanosleepUntil blocks until a specified time.
+//
+// If blocking is interrupted, the syscall is restarted with the original
+// arguments.
+func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error {
+ notifier, tchan := ktime.NewChannelNotifier()
+ timer := ktime.NewTimer(c, notifier)
+
+ // Turn on the timer.
+ timer.Swap(ktime.Setting{
+ Period: 0,
+ Enabled: true,
+ Next: ktime.FromTimespec(ts),
+ })
+
+ err := t.BlockWithTimer(nil, tchan)
+
+ timer.Destroy()
+
+ // Did we just block until the timeout happened?
+ if err == syserror.ETIMEDOUT {
+ return nil
+ }
+
+ return syserror.ConvertIntr(err, kernel.ERESTARTNOHAND)
+}
+
+// clockNanosleepFor blocks for a specified duration.
+//
+// If blocking is interrupted, the syscall is restarted with the remaining
+// duration timeout.
+func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem usermem.Addr) error {
+ timer, start, tchan := ktime.After(c, dur)
+
+ err := t.BlockWithTimer(nil, tchan)
+
+ after := c.Now()
+
+ timer.Destroy()
+
+ switch err {
+ case syserror.ETIMEDOUT:
+ // Slept for entire timeout.
+ return nil
+ case syserror.ErrInterrupted:
+ // Interrupted.
+ remaining := dur - after.Sub(start)
+ if remaining < 0 {
+ remaining = time.Duration(0)
+ }
+
+ // Copy out remaining time.
+ if rem != 0 {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
+ }
+
+ // Arrange for a restart with the remaining duration.
+ t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
+ c: c,
+ duration: remaining,
+ rem: rem,
+ })
+ return kernel.ERESTART_RESTARTBLOCK
+ default:
+ panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
+ }
+}
+
+// Nanosleep implements linux syscall nanosleep(2).
+func Nanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ rem := args[1].Pointer()
+
+ ts, err := copyTimespecIn(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+
+	// Just like Linux, we cap the timeout at the maximum value an int64
+	// can represent, which is roughly 292 years.
+ dur := time.Duration(ts.ToNsecCapped()) * time.Nanosecond
+ return 0, nil, clockNanosleepFor(t, t.Kernel().MonotonicClock(), dur, rem)
+}
+
+// ClockNanosleep implements linux syscall clock_nanosleep(2).
+func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := int32(args[0].Int())
+ flags := args[1].Int()
+ addr := args[2].Pointer()
+ rem := args[3].Pointer()
+
+ req, err := copyTimespecIn(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !req.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Only allow clock constants also allowed by Linux.
+ if clockID > 0 {
+ if clockID != linux.CLOCK_REALTIME &&
+ clockID != linux.CLOCK_MONOTONIC &&
+ clockID != linux.CLOCK_PROCESS_CPUTIME_ID {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ c, err := getClock(t, clockID)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if flags&linux.TIMER_ABSTIME != 0 {
+ return 0, nil, clockNanosleepUntil(t, c, req)
+ }
+
+ dur := time.Duration(req.ToNsecCapped()) * time.Nanosecond
+ return 0, nil, clockNanosleepFor(t, c, dur, rem)
+}
+
+// Gettimeofday implements linux syscall gettimeofday(2).
+func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tv := args[0].Pointer()
+ tz := args[1].Pointer()
+
+ if tv != usermem.Addr(0) {
+ nowTv := t.Kernel().RealtimeClock().Now().Timeval()
+ if err := copyTimevalOut(t, tv, &nowTv); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if tz != usermem.Addr(0) {
+ // Ask the time package for the timezone.
+ _, offset := time.Now().Zone()
+ // This int32 array mimics linux's struct timezone.
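+	// For example, UTC+2 (offset 7200 seconds) is reported as
+	// tz_minuteswest = -120.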
+ timezone := [2]int32{-int32(offset) / 60, 0}
+ _, err := t.CopyOut(tz, timezone)
+ return 0, nil, err
+ }
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
new file mode 100644
index 000000000..a4c400f87
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const nsecPerSec = int64(time.Second)
+
+// copyItimerValIn copies an ItimerVal from the untrusted app range to the
+// kernel. The ItimerVal may be either 32 or 64 bits.
+// A NULL address is allowed because Linux allows
+// setitimer(which, NULL, &old_value), which disables the timer.
+// There is a KERN_WARN message saying this misfeature will be removed.
+// However, that hasn't happened as of 3.19, so we continue to support it.
+func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) {
+ if addr == usermem.Addr(0) {
+ return linux.ItimerVal{}, nil
+ }
+
+ switch t.Arch().Width() {
+ case 8:
+ // Native size, just copy directly.
+ var itv linux.ItimerVal
+ if _, err := t.CopyIn(addr, &itv); err != nil {
+ return linux.ItimerVal{}, err
+ }
+
+ return itv, nil
+ default:
+ return linux.ItimerVal{}, syserror.ENOSYS
+ }
+}
+
+// copyItimerValOut copies an ItimerVal to the untrusted app range.
+// The ItimerVal may be either 32 or 64 bits.
+// A NULL address is allowed, in which case no copy takes place.
+func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error {
+ if addr == usermem.Addr(0) {
+ return nil
+ }
+
+ switch t.Arch().Width() {
+ case 8:
+ // Native size, just copy directly.
+ _, err := t.CopyOut(addr, itv)
+ return err
+ default:
+ return syserror.ENOSYS
+ }
+}
+
+// Getitimer implements linux syscall getitimer(2).
+func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := args[0].Int()
+ val := args[1].Pointer()
+
+ olditv, err := t.Getitimer(timerID)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, copyItimerValOut(t, val, &olditv)
+}
+
+// Setitimer implements linux syscall setitimer(2).
+func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := args[0].Int()
+ newVal := args[1].Pointer()
+ oldVal := args[2].Pointer()
+
+ newitv, err := copyItimerValIn(t, newVal)
+ if err != nil {
+ return 0, nil, err
+ }
+ olditv, err := t.Setitimer(timerID, newitv)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, copyItimerValOut(t, oldVal, &olditv)
+}
+
+// Alarm implements linux syscall alarm(2).
+func Alarm(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ duration := time.Duration(args[0].Uint()) * time.Second
+
+ olditv, err := t.Setitimer(linux.ITIMER_REAL, linux.ItimerVal{
+ Value: linux.DurationToTimeval(duration),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ olddur := olditv.Value.ToDuration()
+ secs := olddur.Round(time.Second).Nanoseconds() / nsecPerSec
+ if secs == 0 && olddur != 0 {
+ // We can't return 0 if an alarm was previously scheduled.
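+		// For example, 400ms remaining rounds to 0 seconds but must
+		// be reported as 1.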
+ secs = 1
+ }
+ return uintptr(secs), nil, nil
+}
+
+// TimerCreate implements linux syscall timer_create(2).
+func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ sevp := args[1].Pointer()
+ timerIDp := args[2].Pointer()
+
+ c, err := getClock(t, clockID)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var sev *linux.Sigevent
+ if sevp != 0 {
+ sev = &linux.Sigevent{}
+ if _, err = t.CopyIn(sevp, sev); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ id, err := t.IntervalTimerCreate(c, sev)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if _, err := t.CopyOut(timerIDp, &id); err != nil {
+ t.IntervalTimerDelete(id)
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// TimerSettime implements linux syscall timer_settime(2).
+func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0)
+ if err != nil {
+ return 0, nil, err
+ }
+ if oldValAddr != 0 {
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerGettime implements linux syscall timer_gettime(2).
+func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+ curValAddr := args[1].Pointer()
+
+ curVal, err := t.IntervalTimerGettime(timerID)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
+
+// TimerGetoverrun implements linux syscall timer_getoverrun(2).
+func TimerGetoverrun(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+
+ o, err := t.IntervalTimerGetoverrun(timerID)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(o), nil, nil
+}
+
+// TimerDelete implements linux syscall timer_delete(2).
+func TimerDelete(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+ return 0, nil, t.IntervalTimerDelete(timerID)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
new file mode 100644
index 000000000..cf49b43db
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TimerfdCreate implements Linux syscall timerfd_create(2).
+func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ flags := args[1].Int()
+
+ if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var c ktime.Clock
+ switch clockID {
+ case linux.CLOCK_REALTIME:
+ c = t.Kernel().RealtimeClock()
+ case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
+ c = t.Kernel().MonotonicClock()
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ f := timerfd.NewFile(t, c)
+ defer f.DecRef()
+ f.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.TFD_NONBLOCK != 0,
+ })
+
+ fd, err := t.NewFDFrom(0, f, kernel.FDFlags{
+ CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// TimerfdSettime implements Linux syscall timerfd_settime(2).
+func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ tf, ok := f.FileOperations.(*timerfd.TimerOperations)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock())
+ if err != nil {
+ return 0, nil, err
+ }
+ tm, oldS := tf.SetTime(newS)
+ if oldValAddr != 0 {
+ oldVal := ktime.ItimerspecFromSetting(tm, oldS)
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerfdGettime implements Linux syscall timerfd_gettime(2).
+func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ curValAddr := args[1].Pointer()
+
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ tf, ok := f.FileOperations.(*timerfd.TimerOperations)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ tm, s := tf.GetTime()
+ curVal := ktime.ItimerspecFromSetting(tm, s)
+ _, err := t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
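+
+// A minimal userspace sketch of the three calls above, assuming
+// golang.org/x/sys/unix on Linux:
+//
+//	fd, err := unix.TimerfdCreate(unix.CLOCK_MONOTONIC, unix.TFD_CLOEXEC)
+//	if err == nil {
+//		// Fire once, one second from now.
+//		spec := unix.ItimerSpec{Value: unix.Timespec{Sec: 1}}
+//		err = unix.TimerfdSettime(fd, 0, &spec, nil)
+//	}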
diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
new file mode 100644
index 000000000..b3eb96a1c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ArchPrctl implements linux syscall arch_prctl(2).
+// It sets architecture-specific process or thread state for t.
+func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ switch args[0].Int() {
+ case linux.ARCH_GET_FS:
+ addr := args[1].Pointer()
+ fsbase := t.Arch().TLS()
+ _, err := t.CopyOut(addr, uint64(fsbase))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ case linux.ARCH_SET_FS:
+ fsbase := args[1].Uint64()
+ if !t.Arch().SetTLS(uintptr(fsbase)) {
+ return 0, nil, syserror.EPERM
+ }
+
+ case linux.ARCH_GET_GS, linux.ARCH_SET_GS:
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_tls_arm64.go b/pkg/sentry/syscalls/linux/sys_tls_arm64.go
new file mode 100644
index 000000000..fb08a356e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_tls_arm64.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ArchPrctl is not defined for ARM64.
+func ArchPrctl(*kernel.Task, arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ENOSYS
+}
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
new file mode 100644
index 000000000..e9d702e8e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -0,0 +1,95 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Uname implements linux syscall uname(2).
+func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ version := t.SyscallTable().Version
+
+ uts := t.UTSNamespace()
+
+ // Fill in structure fields.
+ var u linux.UtsName
+ copy(u.Sysname[:], version.Sysname)
+ copy(u.Nodename[:], uts.HostName())
+ copy(u.Release[:], version.Release)
+ copy(u.Version[:], version.Version)
+	// Pick the machine string for the syscall table's architecture.
+ switch t.SyscallTable().Arch {
+ case arch.AMD64:
+ copy(u.Machine[:], "x86_64")
+ case arch.ARM64:
+ copy(u.Machine[:], "aarch64")
+ default:
+ copy(u.Machine[:], "unknown")
+ }
+ copy(u.Domainname[:], uts.DomainName())
+
+ // Copy out the result.
+ va := args[0].Pointer()
+ _, err := t.CopyOut(va, u)
+ return 0, nil, err
+}
+
+// Setdomainname implements Linux syscall setdomainname(2).
+func Setdomainname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nameAddr := args[0].Pointer()
+ size := args[1].Int()
+
+ utsns := t.UTSNamespace()
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, utsns.UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+ if size < 0 || size > linux.UTSLen {
+ return 0, nil, syserror.EINVAL
+ }
+
+ name, err := t.CopyInString(nameAddr, int(size))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ utsns.SetDomainName(name)
+ return 0, nil, nil
+}
+
+// Sethostname implements Linux syscall sethostname(2).
+func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nameAddr := args[0].Pointer()
+ size := args[1].Int()
+
+ utsns := t.UTSNamespace()
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, utsns.UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+ if size < 0 || size > linux.UTSLen {
+ return 0, nil, syserror.EINVAL
+ }
+
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
+ return 0, nil, err
+ }
+
+ utsns.SetHostName(string(name))
+ return 0, nil, nil
+}
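
Both setters above gate on CAP_SYS_ADMIN in the user namespace that owns the task's UTS namespace, which unprivileged callers observe as EPERM. A minimal sketch (the hostname string is arbitrary), using golang.org/x/sys/unix:

    package main

    import (
    	"log"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	// Fails with EPERM unless the caller holds CAP_SYS_ADMIN in the
    	// relevant user namespace, matching the check in Sethostname above.
    	if err := unix.Sethostname([]byte("sandbox")); err != nil {
    		log.Fatal(err)
    	}
    }
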
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
new file mode 100644
index 000000000..6ec0de96e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -0,0 +1,364 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+const (
+ // EventMaskWrite contains events that can be triggered on writes.
+ //
+ // Note that EventHUp is not going to happen for pipes but may for
+ // implementations of poll on some sockets, see net/core/datagram.c.
+ EventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Write implements linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := writev(t, file, src)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Pwrite64 implements linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is writing at an offset supported?
+ if !file.Flags().Pwrite {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwritev(t, file, src, offset)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Writev implements linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := writev(t, file, src)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+// Pwritev implements linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is writing at an offset supported?
+ if !file.Flags().Pwrite {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwritev(t, file, src, offset)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the syscall is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the linux internal call
+ // (https://elixir.bootlin.com/linux/v4.18/source/fs/read_write.c#L1354)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures, so flags is passed as the sixth argument (args[5]).
+
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := int(args[5].Int())
+
+ if int(args[4].Int())&0x4 != 0 {
+ return 0, nil, syserror.EACCES
+ }
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is writing at an offset supported?
+ if offset > -1 && !file.Flags().Pwrite {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
+ // accepted as a valid flag argument for pwritev2.
+ if flags&^linux.RWF_VALID != 0 {
+ return uintptr(flags), nil, syserror.EOPNOTSUPP
+ }
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // If pwritev2 is called with an offset of -1, writev is called.
+ if offset == -1 {
+ n, err := writev(t, file, src)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+ }
+
+ n, err := pwritev(t, file, src, offset)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
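
The offset == -1 fallback above means pwritev2 can serve as a drop-in writev: the write lands at, and advances, the current file offset. A sketch, assuming the Pwritev2 wrapper exists in golang.org/x/sys/unix and using a throwaway path:

    package main

    import (
    	"fmt"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	fd, err := unix.Open("/tmp/pwritev2-demo", unix.O_CREAT|unix.O_WRONLY, 0644)
    	if err != nil {
    		panic(err)
    	}
    	defer unix.Close(fd)

    	iovs := [][]byte{[]byte("hello "), []byte("world\n")}
    	// offset -1: use and advance the file offset, exactly like writev(2).
    	n, err := unix.Pwritev2(fd, iovs, -1, 0)
    	fmt.Println(n, err)
    }
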
+
+func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
+ n, err := f.Writev(t, src)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+ return n, err
+ }
+
+ // Sockets support write timeouts.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if s, ok := f.FileOperations.(socket.Socket); ok {
+ dl := s.SendTimeout()
+ if dl < 0 && err == syserror.ErrWouldBlock {
+ return n, err
+ }
+ if dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ }
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ // Issue the request and break out if it completes with
+ // anything other than "would block".
+ n, err = f.Writev(t, src)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+
+ return total, err
+}
+
+func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ n, err := f.Pwritev(t, src, offset)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ // Issue the request and break out if it completes with
+ // anything other than "would block".
+ n, err = f.Pwritev(t, src, offset+total)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.Block(ch); err != nil {
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+
+ return total, err
+}
+
+// LINT.ThenChange(vfs2/read_write.go)
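
The blocking path in writev/pwritev above (register a waiter entry, retry on ErrWouldBlock, drop what was already written) is the in-sentry analogue of a user-space non-blocking write loop. A minimal sketch of that pattern with golang.org/x/sys/unix:

    package main

    import "golang.org/x/sys/unix"

    func writeAll(fd int, buf []byte) error {
    	for len(buf) > 0 {
    		n, err := unix.Write(fd, buf)
    		if n > 0 {
    			buf = buf[n:] // drop bytes already written, like src.DropFirst64(n)
    		}
    		if err == unix.EAGAIN {
    			// Wait until the fd is writable, like EventRegister + Block.
    			pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}}
    			if _, perr := unix.Poll(pfd, -1); perr != nil {
    				return perr
    			}
    			continue
    		}
    		if err != nil {
    			return err
    		}
    	}
    	return nil
    }

    func main() {
    	_ = writeAll(1, []byte("hello\n")) // fd 1 is stdout
    }
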
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
new file mode 100644
index 000000000..c24946160
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -0,0 +1,432 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// GetXattr implements linux syscall getxattr(2).
+func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getXattrFromPath(t, args, true)
+}
+
+// LGetXattr implements linux syscall lgetxattr(2).
+func LGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getXattrFromPath(t, args, false)
+}
+
+// FGetXattr implements linux syscall fgetxattr(2).
+func FGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ n, err := getXattr(t, f.Dirent, nameAddr, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func getXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ n, err = getXattr(t, d, nameAddr, valueAddr, size)
+ return err
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+// getXattr implements getxattr(2) from the given *fs.Dirent.
+func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64) (int, error) {
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil {
+ return 0, err
+ }
+
+ // TODO(b/148380782): Support xattrs in namespaces other than "user".
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // If getxattr(2) is called with size 0, the call succeeds and returns the
+ // size of the value, even when that size is nonzero. In that case, we need
+ // to retrieve the entire attribute value so we can return the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+
+ value, err := d.Inode.GetXattr(t, name, requestedSize)
+ if err != nil {
+ return 0, err
+ }
+ n := len(value)
+ if uint64(n) > requestedSize {
+ return 0, syserror.ERANGE
+ }
+
+ // Don't copy out the attribute value if size is 0.
+ if size == 0 {
+ return n, nil
+ }
+
+ if _, err = t.CopyOutBytes(valueAddr, []byte(value)); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
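
The size == 0 behavior implemented above exists so callers can size their buffers: probe with an empty buffer, allocate, then fetch. A sketch of that two-step pattern from user space (the path and attribute name are hypothetical), using golang.org/x/sys/unix:

    package main

    import (
    	"fmt"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	const path, attr = "/tmp/demo", "user.comment"
    	sz, err := unix.Getxattr(path, attr, nil) // size 0: returns value length
    	if err != nil {
    		panic(err)
    	}
    	buf := make([]byte, sz)
    	if _, err := unix.Getxattr(path, attr, buf); err != nil {
    		panic(err)
    	}
    	fmt.Printf("%s = %q\n", attr, buf)
    }
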
+
+// SetXattr implements linux syscall setxattr(2).
+func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return setXattrFromPath(t, args, true)
+}
+
+// LSetXattr implements linux syscall lsetxattr(2).
+func LSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return setXattrFromPath(t, args, false)
+}
+
+// FSetXattr implements linux syscall fsetxattr(2).
+func FSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+ flags := args[4].Uint()
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ return 0, nil, setXattr(t, f.Dirent, nameAddr, valueAddr, size, flags)
+}
+
+func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+ flags := args[4].Uint()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ return setXattr(t, d, nameAddr, valueAddr, size, flags)
+ })
+}
+
+// setXattr implements setxattr(2) from the given *fs.Dirent.
+func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64, flags uint32) error {
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if size > linux.XATTR_SIZE_MAX {
+ return syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return err
+ }
+ value := string(buf)
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+
+ if err := d.Inode.SetXattr(t, d, name, value, flags); err != nil {
+ return err
+ }
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+// Restrict xattrs to regular files and directories.
+//
+// TODO(b/148380782): In Linux, this restriction technically only applies to
+// xattrs in the "user.*" namespace. Make file type checks specific to the
+// namespace once we allow other xattr prefixes.
+func xattrFileTypeOk(i *fs.Inode) bool {
+ return fs.IsRegular(i.StableAttr) || fs.IsDir(i.StableAttr)
+}
+
+func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error {
+ // Restrict xattrs to regular files and directories.
+ if !xattrFileTypeOk(i) {
+ if perms.Write {
+ return syserror.EPERM
+ }
+ return syserror.ENODATA
+ }
+
+ return i.CheckPermission(t, perms)
+}
+
+// ListXattr implements linux syscall listxattr(2).
+func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listXattrFromPath(t, args, true)
+}
+
+// LListXattr implements linux syscall llistxattr(2).
+func LListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listXattrFromPath(t, args, false)
+}
+
+// FListXattr implements linux syscall flistxattr(2).
+func FListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := uint64(args[2].SizeT())
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ n, err := listXattr(t, f.Dirent, listAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func listXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := uint64(args[2].SizeT())
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ n, err = listXattr(t, d, listAddr, size)
+ return err
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func listXattr(t *kernel.Task, d *fs.Dirent, addr usermem.Addr, size uint64) (int, error) {
+ if !xattrFileTypeOk(d.Inode) {
+ return 0, nil
+ }
+
+ // If listxattr(2) is called with size 0, the call succeeds and returns the
+ // buffer size needed to hold the xattr list, even when that size is
+ // nonzero. In that case, we need to retrieve the entire list so we can
+ // compute and return the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+ xattrs, err := d.Inode.ListXattr(t, requestedSize)
+ if err != nil {
+ return 0, err
+ }
+
+ // TODO(b/148380782): support namespaces other than "user".
+ for x := range xattrs {
+ if !strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ delete(xattrs, x)
+ }
+ }
+
+ listSize := xattrListSize(xattrs)
+ if listSize > linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ if uint64(listSize) > requestedSize {
+ return 0, syserror.ERANGE
+ }
+
+ // Don't copy out the attributes if size is 0.
+ if size == 0 {
+ return listSize, nil
+ }
+
+ buf := make([]byte, 0, listSize)
+ for x := range xattrs {
+ buf = append(buf, []byte(x)...)
+ buf = append(buf, 0)
+ }
+ if _, err := t.CopyOutBytes(addr, buf); err != nil {
+ return 0, err
+ }
+
+ return len(buf), nil
+}
+
+func xattrListSize(xattrs map[string]struct{}) int {
+ size := 0
+ for x := range xattrs {
+ size += len(x) + 1
+ }
+ return size
+}
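
listXattr copies names out as a NUL-separated sequence, which is why xattrListSize charges len(name)+1 bytes per entry. A user-space caller splits the buffer on NUL bytes; a sketch (the path is hypothetical), using golang.org/x/sys/unix:

    package main

    import (
    	"bytes"
    	"fmt"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	const path = "/tmp/demo"
    	sz, err := unix.Listxattr(path, nil) // size 0: returns list length
    	if err != nil {
    		panic(err)
    	}
    	buf := make([]byte, sz)
    	n, err := unix.Listxattr(path, buf)
    	if err != nil {
    		panic(err)
    	}
    	for _, name := range bytes.Split(buf[:n], []byte{0}) {
    		if len(name) > 0 {
    			fmt.Println(string(name))
    		}
    	}
    }
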
+
+// RemoveXattr implements linux syscall removexattr(2).
+func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return removeXattrFromPath(t, args, true)
+}
+
+// LRemoveXattr implements linux syscall lremovexattr(2).
+func LRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return removeXattrFromPath(t, args, false)
+}
+
+// FRemoveXattr implements linux syscall fremovexattr(2).
+func FRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ return 0, nil, removeXattr(t, f.Dirent, nameAddr)
+}
+
+func removeXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ return removeXattr(t, d, nameAddr)
+ })
+}
+
+// removeXattr implements removexattr(2) from the given *fs.Dirent.
+func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+
+ if err := d.Inode.RemoveXattr(t, d, name); err != nil {
+ return err
+ }
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
+}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go
new file mode 100644
index 000000000..ddc3ee26e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/timespec.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// copyTimespecIn copies a Timespec from the untrusted app range to the kernel.
+func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) {
+ switch t.Arch().Width() {
+ case 8:
+ ts := linux.Timespec{}
+ in := t.CopyScratchBuffer(16)
+ _, err := t.CopyInBytes(addr, in)
+ if err != nil {
+ return ts, err
+ }
+ ts.Sec = int64(usermem.ByteOrder.Uint64(in[0:]))
+ ts.Nsec = int64(usermem.ByteOrder.Uint64(in[8:]))
+ return ts, nil
+ default:
+ return linux.Timespec{}, syserror.ENOSYS
+ }
+}
+
+// copyTimespecOut copies a Timespec to the untrusted app range.
+func copyTimespecOut(t *kernel.Task, addr usermem.Addr, ts *linux.Timespec) error {
+ switch t.Arch().Width() {
+ case 8:
+ out := t.CopyScratchBuffer(16)
+ usermem.ByteOrder.PutUint64(out[0:], uint64(ts.Sec))
+ usermem.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec))
+ _, err := t.CopyOutBytes(addr, out)
+ return err
+ default:
+ return syserror.ENOSYS
+ }
+}
+
+// copyTimevalIn copies a Timeval from the untrusted app range to the kernel.
+func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) {
+ switch t.Arch().Width() {
+ case 8:
+ tv := linux.Timeval{}
+ in := t.CopyScratchBuffer(16)
+ _, err := t.CopyInBytes(addr, in)
+ if err != nil {
+ return tv, err
+ }
+ tv.Sec = int64(usermem.ByteOrder.Uint64(in[0:]))
+ tv.Usec = int64(usermem.ByteOrder.Uint64(in[8:]))
+ return tv, nil
+ default:
+ return linux.Timeval{}, syserror.ENOSYS
+ }
+}
+
+// copyTimevalOut copies a Timeval to the untrusted app range.
+func copyTimevalOut(t *kernel.Task, addr usermem.Addr, tv *linux.Timeval) error {
+ switch t.Arch().Width() {
+ case 8:
+ out := t.CopyScratchBuffer(16)
+ usermem.ByteOrder.PutUint64(out[0:], uint64(tv.Sec))
+ usermem.ByteOrder.PutUint64(out[8:], uint64(tv.Usec))
+ _, err := t.CopyOutBytes(addr, out)
+ return err
+ default:
+ return syserror.ENOSYS
+ }
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ timespec, err := copyTimespecIn(t, timespecAddr)
+ if err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
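
The capping that ToNsecCapped provides matters because sec*1e9 overflows int64 for large timespec values. A standalone sketch of the same saturating conversion (illustrative only, not the sentry's implementation):

    package main

    import (
    	"fmt"
    	"math"
    	"time"
    )

    // durationCapped converts a valid (sec >= 0, 0 <= nsec < 1e9) timespec to
    // a time.Duration, saturating at the maximum instead of overflowing.
    func durationCapped(sec, nsec int64) time.Duration {
    	if sec > math.MaxInt64/int64(time.Second) {
    		return time.Duration(math.MaxInt64)
    	}
    	d := sec * int64(time.Second)
    	if d > math.MaxInt64-nsec {
    		return time.Duration(math.MaxInt64)
    	}
    	return time.Duration(d + nsec)
    }

    func main() {
    	fmt.Println(durationCapped(5, 500000000))       // 5.5s
    	fmt.Println(durationCapped(math.MaxInt64/2, 0)) // capped
    }
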
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
new file mode 100644
index 000000000..0c740335b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -0,0 +1,76 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "vfs2",
+ srcs = [
+ "aio.go",
+ "epoll.go",
+ "eventfd.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "inotify.go",
+ "ioctl.go",
+ "lock.go",
+ "memfd.go",
+ "mmap.go",
+ "mount.go",
+ "path.go",
+ "pipe.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "signal.go",
+ "socket.go",
+ "splice.go",
+ "stat.go",
+ "stat_amd64.go",
+ "stat_arm64.go",
+ "sync.go",
+ "timerfd.go",
+ "vfs2.go",
+ "xattr.go",
+ ],
+ marshal = True,
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/bits",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/gohacks",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/eventfd",
+ "//pkg/sentry/fsimpl/pipefs",
+ "//pkg/sentry/fsimpl/signalfd",
+ "//pkg/sentry/fsimpl/timerfd",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/fasync",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/syscalls",
+ "//pkg/sentry/syscalls/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
new file mode 100644
index 000000000..e5cdefc50
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -0,0 +1,216 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// IoSubmit implements linux syscall io_submit(2).
+func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ nrEvents := args[1].Int()
+ addr := args[2].Pointer()
+
+ if nrEvents < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ for i := int32(0); i < nrEvents; i++ {
+ // Copy in the address.
+ cbAddrNative := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Copy in this callback.
+ var cb linux.IOCallback
+ cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
+ if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if i > 0 {
+ // Some have been successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Process this callback.
+ if err := submitCallback(t, id, &cb, cbAddr); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Advance to the next one.
+ addr += usermem.Addr(t.Arch().Width())
+ }
+
+ return uintptr(nrEvents), nil, nil
+}
+
+// submitCallback processes a single callback.
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
+ if cb.Reserved2 != 0 {
+ return syserror.EINVAL
+ }
+
+ fd := t.GetFileVFS2(cb.FD)
+ if fd == nil {
+ return syserror.EBADF
+ }
+ defer fd.DecRef()
+
+ // Was there an eventFD? Extract it.
+ var eventFD *vfs.FileDescription
+ if cb.Flags&linux.IOCB_FLAG_RESFD != 0 {
+ eventFD = t.GetFileVFS2(cb.ResFD)
+ if eventFD == nil {
+ return syserror.EBADF
+ }
+ defer eventFD.DecRef()
+
+ // Check that it is an eventfd.
+ if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok {
+ return syserror.EINVAL
+ }
+ }
+
+ ioseq, err := memoryFor(t, cb)
+ if err != nil {
+ return err
+ }
+
+ // Check offset for reads/writes.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ if cb.Offset < 0 {
+ return syserror.EINVAL
+ }
+ }
+
+ // Prepare the request.
+ aioCtx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ready := aioCtx.Prepare(); !ready {
+ // Context is busy.
+ return syserror.EAGAIN
+ }
+
+ if eventFD != nil {
+ // The request is set. Make sure there's a ref on the file.
+ //
+ // This is necessary when the callback executes on completion,
+ // which is also what will release this reference.
+ eventFD.IncRef()
+ }
+
+ // Perform the request asynchronously.
+ fd.IncRef()
+ t.QueueAIO(getAIOCallback(t, fd, eventFD, cbAddr, cb, ioseq, aioCtx))
+ return nil
+}
+
+func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback {
+ return func(ctx context.Context) {
+ if aioCtx.Dead() {
+ aioCtx.CancelPendingRequest()
+ return
+ }
+ ev := &linux.IOEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
+
+ var err error
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV:
+ ev.Result, err = fd.PRead(ctx, ioseq, cb.Offset, vfs.ReadOptions{})
+ case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ ev.Result, err = fd.PWrite(ctx, ioseq, cb.Offset, vfs.WriteOptions{})
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC:
+ err = fd.Sync(ctx)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = slinux.HandleIOErrorVFS2(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", fd)
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
+ }
+
+ fd.DecRef()
+
+ // Queue the result for delivery.
+ aioCtx.FinishRequest(ev)
+
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFD != nil {
+ eventFD.Impl().(*eventfd.EventFileDescription).Signal(1)
+ eventFD.DecRef()
+ }
+ }
+}
+
+// memoryFor returns appropriate memory for the given callback.
+func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) {
+ bytes := int(cb.Bytes)
+ if bytes < 0 {
+ // Linux also requires that this field fit in ssize_t.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+
+ // Since this I/O will be asynchronous with respect to t's task goroutine,
+ // we have no guarantee that t's AddressSpace will be active during the
+ // I/O.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
+ return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
+ return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP:
+ return usermem.IOSequence{}, nil
+
+ default:
+ // Not a supported command.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..34c90ae3e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,228 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var sizeofEpollEvent = (*linux.EpollEvent)(nil).SizeBytes()
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n])
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
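
Tying the three entry points together from the caller's side: create an instance (epoll_create1), register interest (epoll_ctl), then harvest events (epoll_wait). A minimal sketch on a pipe, using golang.org/x/sys/unix:

    package main

    import (
    	"fmt"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	epfd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
    	if err != nil {
    		panic(err)
    	}
    	defer unix.Close(epfd)

    	var p [2]int
    	if err := unix.Pipe(p[:]); err != nil {
    		panic(err)
    	}
    	ev := unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(p[0])}
    	if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, p[0], &ev); err != nil {
    		panic(err)
    	}

    	unix.Write(p[1], []byte("x")) // make the read end readable

    	events := make([]unix.EpollEvent, 16)
    	n, err := unix.EpollWait(epfd, events, 1000 /* ms */)
    	fmt.Println(n, err) // expect n == 1
    }
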
diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
new file mode 100644
index 000000000..aff1a2070
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Eventfd2 implements linux syscall eventfd2(2).
+func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ initVal := uint64(args[0].Uint())
+ flags := uint(args[1].Uint())
+ allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC)
+
+ if flags & ^allOps != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ vfsObj := t.Kernel().VFS()
+ fileFlags := uint32(linux.O_RDWR)
+ if flags&linux.EFD_NONBLOCK != 0 {
+ fileFlags |= linux.O_NONBLOCK
+ }
+ semMode := flags&linux.EFD_SEMAPHORE != 0
+ eventfd, err := eventfd.New(vfsObj, initVal, semMode, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer eventfd.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, eventfd, kernel.FDFlags{
+ CloseOnExec: flags&linux.EFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// Eventfd implements linux syscall eventfd(2).
+func Eventfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[1].Value = 0
+ return Eventfd2(t, args)
+}
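
From user space, the descriptor created above carries a 64-bit counter: writes add to it, reads drain it (or decrement it by one in EFD_SEMAPHORE mode). A sketch using golang.org/x/sys/unix; the counter is exchanged in host byte order, and little-endian is assumed here:

    package main

    import (
    	"encoding/binary"
    	"fmt"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	fd, err := unix.Eventfd(3, unix.EFD_CLOEXEC) // initial counter value 3
    	if err != nil {
    		panic(err)
    	}
    	defer unix.Close(fd)

    	buf := make([]byte, 8)
    	binary.LittleEndian.PutUint64(buf, 2)
    	unix.Write(fd, buf) // counter += 2

    	unix.Read(fd, buf) // returns and resets the counter
    	fmt.Println(binary.LittleEndian.Uint64(buf)) // 5
    }
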
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..aef0078a8
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef()
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..517394ba9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,355 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/fasync"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef()
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, err
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ case linux.F_SETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ n, err := pipefile.SetPipeSize(int64(args[2].Int()))
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+ case linux.F_GETOWN:
+ owner, hasOwner := getAsyncOwner(t, file)
+ if !hasOwner {
+ return 0, nil, nil
+ }
+ if owner.Type == linux.F_OWNER_PGRP {
+ return uintptr(-owner.PID), nil, nil
+ }
+ return uintptr(owner.PID), nil, nil
+ case linux.F_SETOWN:
+ who := args[2].Int()
+ ownerType := int32(linux.F_OWNER_PID)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return 0, nil, syserror.EINVAL
+ }
+ ownerType = linux.F_OWNER_PGRP
+ who = -who
+ }
+ return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ case linux.F_GETOWN_EX:
+ owner, hasOwner := getAsyncOwner(t, file)
+ if !hasOwner {
+ return 0, nil, nil
+ }
+ _, err := t.CopyOut(args[2].Pointer(), &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ var owner linux.FOwnerEx
+ n, err := t.CopyIn(args[2].Pointer(), &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ case linux.F_GETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ return uintptr(pipefile.PipeSize()), nil, nil
+ case linux.F_GET_SEALS:
+ val, err := tmpfs.GetSeals(file)
+ return uintptr(val), nil, err
+ case linux.F_ADD_SEALS:
+ if !file.IsWritable() {
+ return 0, nil, syserror.EPERM
+ }
+ err := tmpfs.AddSeals(file, args[2].Uint())
+ return 0, nil, err
+ case linux.F_SETLK, linux.F_SETLKW:
+ return 0, nil, posixLock(t, args, file, cmd)
+ default:
+ // TODO(gvisor.dev/issue/2920): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwnerEx, hasOwner bool) {
+ a := fd.AsyncHandler()
+ if a == nil {
+ return linux.FOwnerEx{}, false
+ }
+
+ ot, otg, opg := a.(*fasync.FileAsync).Owner()
+ switch {
+ case ot != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }, true
+ case otg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }, true
+ case opg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }, true
+ default:
+ return linux.FOwnerEx{}, true
+ }
+}
+
+func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+ switch ownerType {
+ case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
+ // Acceptable type.
+ default:
+ return syserror.EINVAL
+ }
+
+ a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ if pid == 0 {
+ a.ClearOwner()
+ return nil
+ }
+
+ switch ownerType {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid))
+ if task == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(pid))
+ if tg == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(pid))
+ if pg == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return nil
+ default:
+ return syserror.EINVAL
+ }
+}
+
+func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error {
+ // Copy in the lock request.
+ flockAddr := args[2].Pointer()
+ var flock linux.Flock
+ if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ return err
+ }
+
+ var blocker lock.Blocker
+ if cmd == linux.F_SETLKW {
+ blocker = t
+ }
+
+ switch flock.Type {
+ case linux.F_RDLCK:
+ if !file.IsReadable() {
+ return syserror.EBADF
+ }
+ return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+
+ case linux.F_WRLCK:
+ if !file.IsWritable() {
+ return syserror.EBADF
+ }
+ return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+
+ case linux.F_UNLCK:
+ return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence)
+
+ default:
+ return syserror.EINVAL
+ }
+}
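+
+// Illustrative only (not part of this change): a userspace request such as
+//
+//	fl := unix.Flock_t{Type: unix.F_WRLCK, Start: 0, Len: 100}
+//	unix.FcntlFlock(uintptr(fd), unix.F_SETLKW, &fl)
+//
+// arrives here with cmd == F_SETLKW, so blocker is set to the task and the
+// F_WRLCK case takes a write lock, blocking until it is available.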
+
+// Fadvise64 implements Linux syscall fadvise64(2).
+// This implementation currently ignores the provided advice.
+func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[2].Int64()
+ advice := args[3].Int()
+
+ // Note: offset is allowed to be negative.
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // If the FD refers to a pipe or FIFO, return ESPIPE.
+ if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ switch advice {
+ case linux.POSIX_FADV_NORMAL:
+ case linux.POSIX_FADV_RANDOM:
+ case linux.POSIX_FADV_SEQUENTIAL:
+ case linux.POSIX_FADV_WILLNEED:
+ case linux.POSIX_FADV_DONTNEED:
+ case linux.POSIX_FADV_NOREUSE:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // The advice is validated above but otherwise ignored.
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..6b14c2bef
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,384 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
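+
+// Illustrative only: linkat(fd, "", AT_FDCWD, "newname", AT_EMPTY_PATH)
+// links the file referred to by fd itself; linkat(2) restricts this to
+// callers with CAP_DAC_READ_SEARCH, which is what the check above enforces.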
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ if mode.FileType() == 0 {
+ mode |= linux.ModeRegular
+ }
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: mode &^ linux.FileMode(t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
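+
+// Illustrative only: for a character device like /dev/null (major 1, minor
+// 3), dev arrives as the encoded ID produced by, e.g., unix.Mkdev(1, 3),
+// and DecodeDeviceID above recovers the (1, 3) pair.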
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags | linux.O_LARGEFILE,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
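+
+// Illustrative only: open("cfg", O_RDONLY|O_CLOEXEC) resolves "cfg" relative
+// to the working directory (dirfd == AT_FDCWD), forces O_LARGEFILE on as
+// above, and installs the new FD with CloseOnExec set.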
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Fallocate implements Linux syscall fallocate(2).
+func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].Uint64()
+ offset := args[2].Int64()
+ length := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ if !file.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ if mode != 0 {
+ return 0, nil, syserror.ENOTSUP
+ }
+
+ if offset < 0 || length <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ size := offset + length
+
+ if size < 0 {
+ return 0, nil, syserror.EFBIG
+ }
+
+ limit := limits.FromContext(t).Get(limits.FileSize).Cur
+
+ if uint64(size) >= limit {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(linux.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ // A successful allocation modifies the file length and should generate a
+ // notification.
+ // TODO(gvisor.dev/issue/1479): Reenable when Inotify is ported.
+ // file.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ return 0, nil, file.Impl().Allocate(t, mode, uint64(offset), uint64(length))
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ if len(target) == 0 {
+ return syserror.ENOENT
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..317409a18
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef()
+ wd.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need space for the NUL terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
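+
+// Illustrative only: for a working directory "/tmp", len(s) == 4, so any
+// size <= 4 fails with ERANGE; with a larger buffer, 5 bytes ("/tmp" plus
+// the NUL terminator) are copied out and 5 is returned.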
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..c7c7bf7ce
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ size = (size + 7) &^ 7 // round up to multiple of 8
+ if size > cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name.
+ bufTail := buf[19+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ size = (size + 7) &^ 7 // round up to multiple of sizeof(long)
+ if size > cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name and the zero padding byte between the name and
+ // dirent type.
+ bufTail := buf[18+len(dirent.Name) : size-1]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
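+
+// Illustrative only, for the 64-bit layout above: with d_name "foo", the
+// fixed fields take 8+8+2+1 = 19 bytes; adding 3 name bytes and a NUL gives
+// 23, which rounds up to a d_reclen of 24.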
diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go
new file mode 100644
index 000000000..5d98134a5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC
+
+// InotifyInit1 implements the inotify_init1() syscall.
+func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^allFlags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer ino.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{
+ CloseOnExec: flags&linux.IN_CLOEXEC != 0,
+ })
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// InotifyInit implements the inotify_init() syscall.
+func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[0].Value = 0
+ return InotifyInit1(t, args)
+}
+
+// fdToInotify resolves an fd to an inotify object. If successful, the file will
+// have an extra ref and the caller is responsible for releasing the ref.
+func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, error) {
+ f := t.GetFileVFS2(fd)
+ if f == nil {
+ // Invalid fd.
+ return nil, nil, syserror.EBADF
+ }
+
+ ino, ok := f.Impl().(*vfs.Inotify)
+ if !ok {
+ // Not an inotify fd.
+ f.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+
+ return ino, f, nil
+}
+
+// InotifyAddWatch implements the inotify_add_watch() syscall.
+func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ mask := args[2].Uint()
+
+ // "EINVAL: The given event mask contains no valid events."
+ // -- inotify_add_watch(2)
+ if mask&linux.ALL_INOTIFY_BITS == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link."
+ // -- inotify(7)
+ follow := followFinalSymlink
+ if mask&linux.IN_DONT_FOLLOW != 0 {
+ follow = nofollowFinalSymlink
+ }
+
+ ino, f, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer f.DecRef()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if mask&linux.IN_ONLYDIR != 0 {
+ path.Dir = true
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, follow)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+ d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ return 0, nil, err
+ }
+ defer d.DecRef()
+
+ fd, err = ino.AddWatch(d.Dentry(), mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// InotifyRmWatch implements the inotify_rm_watch() syscall.
+func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ wd := args[1].Int()
+
+ ino, f, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer f.DecRef()
+ return 0, nil, ino.RmWatch(wd)
+}
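+
+// Illustrative only: a typical userspace sequence served by this file:
+//
+//	fd, _ := unix.InotifyInit1(unix.IN_CLOEXEC)
+//	wd, _ := unix.InotifyAddWatch(fd, "/tmp", unix.IN_CREATE|unix.IN_DELETE)
+//	_, _ = unix.InotifyRmWatch(fd, uint32(wd))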
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
new file mode 100644
index 000000000..fd6ab94b2
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -0,0 +1,107 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Handle ioctls that apply to all FDs.
+ switch args[1].Int() {
+ case linux.FIONCLEX:
+ t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: false,
+ })
+ return 0, nil, nil
+
+ case linux.FIOCLEX:
+ t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: true,
+ })
+ return 0, nil, nil
+
+ case linux.FIONBIO:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.StatusFlags()
+ if set != 0 {
+ flags |= linux.O_NONBLOCK
+ } else {
+ flags &^= linux.O_NONBLOCK
+ }
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), flags)
+
+ case linux.FIOASYNC:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.StatusFlags()
+ if set != 0 {
+ flags |= linux.O_ASYNC
+ } else {
+ flags &^= linux.O_ASYNC
+ }
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), flags)
+
+ case linux.FIOGETOWN, linux.SIOCGPGRP:
+ var who int32
+ owner, hasOwner := getAsyncOwner(t, file)
+ if hasOwner {
+ if owner.Type == linux.F_OWNER_PGRP {
+ who = -owner.PID
+ } else {
+ who = owner.PID
+ }
+ }
+ _, err := t.CopyOut(args[2].Pointer(), &who)
+ return 0, nil, err
+
+ case linux.FIOSETOWN, linux.SIOCSPGRP:
+ var who int32
+ if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil {
+ return 0, nil, err
+ }
+ ownerType := int32(linux.F_OWNER_PID)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return 0, nil, syserror.EINVAL
+ }
+ ownerType = linux.F_OWNER_PGRP
+ who = -who
+ }
+ return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ }
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
+}
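+
+// Illustrative only: ioctl(fd, FIONBIO, &one) handled above is equivalent to
+// fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) | O_NONBLOCK): both paths end in
+// SetStatusFlags toggling only O_NONBLOCK.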
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
new file mode 100644
index 000000000..bf19028c4
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Flock implements Linux syscall flock(2).
+func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ operation := args[1].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // flock(2): EBADF fd is not an open file descriptor.
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ nonblocking := operation&linux.LOCK_NB != 0
+ operation &^= linux.LOCK_NB
+
+ var blocker lock.Blocker
+ if !nonblocking {
+ blocker = t
+ }
+
+ switch operation {
+ case linux.LOCK_EX:
+ if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_SH:
+ if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_UN:
+ if err := file.UnlockBSD(t); err != nil {
+ return 0, nil, err
+ }
+ default:
+ // flock(2): EINVAL operation is invalid.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
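+
+// Illustrative only: unix.Flock(fd, unix.LOCK_SH|unix.LOCK_NB) reaches the
+// LOCK_SH case above with blocker == nil, so a conflicting exclusive lock
+// fails immediately rather than blocking.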
diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go
new file mode 100644
index 000000000..bbe248d17
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ memfdPrefix = "memfd:"
+ memfdMaxNameLen = linux.NAME_MAX - len(memfdPrefix)
+ memfdAllFlags = uint32(linux.MFD_CLOEXEC | linux.MFD_ALLOW_SEALING)
+)
+
+// MemfdCreate implements Linux syscall memfd_create(2).
+func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+
+ if flags&^memfdAllFlags != 0 {
+ // Unknown bits in flags.
+ return 0, nil, syserror.EINVAL
+ }
+
+ allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
+ cloExec := flags&linux.MFD_CLOEXEC != 0
+
+ name, err := t.CopyInString(addr, memfdMaxNameLen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ shmMount := t.Kernel().ShmMount()
+ file, err := tmpfs.NewMemfd(shmMount, t.Credentials(), allowSeals, memfdPrefix+name)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: cloExec,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
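+
+// Illustrative only: memfd_create("buf", MFD_CLOEXEC|MFD_ALLOW_SEALING)
+// yields an anonymous tmpfs file named "memfd:buf" on the shm mount, with
+// sealing enabled and the returned FD marked close-on-exec.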
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..60a43f0a0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
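+
+// Illustrative only: mmap(NULL, 4096, PROT_READ, MAP_SHARED, fd, 0) on a
+// read-only FD succeeds, but MaxPerms.Write is cleared above, so a later
+// mprotect(PROT_WRITE) on the mapping will fail.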
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
new file mode 100644
index 000000000..adeaa39cc
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ // For null-terminated strings related to mount(2), Linux copies in at most
+ // a page worth of data. See fs/namespace.c:copy_mount_string().
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ source, err := t.CopyInString(sourceAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, err := copyInPath(t, targetAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the current mount namespace's associated user
+ // namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODEV |
+ linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var opts vfs.MountOptions
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ opts.Flags.NoATime = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ opts.Flags.NoExec = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ opts.ReadOnly = true
+ }
+ opts.GetFilesystemOptions.Data = data
+
+ target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer target.Release()
+
+ return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts)
+}
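+
+// Illustrative only: mount("none", "/mnt/t", "tmpfs", MS_NOATIME|MS_RDONLY,
+// "mode=0700") reaches MountAt above with NoATime and ReadOnly set and the
+// string "mode=0700" passed through in GetFilesystemOptions.Data.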
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ opts := vfs.UmountOptions{
+ Flags: uint32(flags),
+ }
+
+ return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..97da6c647
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef()
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef()
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef()
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release() {
+ tpop.pop.Root.DecRef()
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef()
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
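+
+// The typed booleans above exist so that call sites stay self-describing:
+// getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+// reads unambiguously, where two bare true/false arguments would not.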
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
new file mode 100644
index 000000000..4a01e4209
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Pipe implements Linux syscall pipe(2).
+func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ return 0, nil, pipe2(t, addr, 0)
+}
+
+// Pipe2 implements Linux syscall pipe2(2).
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+ return 0, nil, pipe2(t, addr, flags)
+}
+
+func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
+ if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
+ return syserror.EINVAL
+ }
+ r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ defer r.DecRef()
+ defer w.DecRef()
+
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return err
+ }
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return err
+ }
+ return nil
+}
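+
+// Illustrative only: pipe2(fds, O_CLOEXEC) fills fds[0] with the read end
+// and fds[1] with the write end; if the copy-out above fails, both FDs are
+// removed from the table again so no descriptors leak.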
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..ff1b25d7b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,586 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum number of files allowed for poll & select. This has
+// no equivalent in Linux; it exists in gVisor since allocation failure in Go
+// is unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout: 0 if the timeout expired, and 0 or
+// positive if the wait was interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
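+
+// Illustrative only: with nfds == 10, nBytes == 2 and only bits 0 and 1 of
+// the last byte are meaningful, so the mask above clears its upper six bits.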
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
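+
+// As an informal illustration of the bitset layout above (an assumed
+// example, not part of the build): fd 9 lives in byte 9/8 = 1 at bit
+// 9%8 = 1, so marking it readable means setting r[1] |= 1 << 1. After
+// polling, doSelect clears that bit again unless an event in
+// selectReadEvents was reported for fd 9.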
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ _, err := tsRemaining.CopyOut(t, timespecAddr)
+ return err
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ _, err := tvRemaining.CopyOut(t, timevalAddr)
+ return err
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // If interrupted, poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
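+
+// An interrupted poll(2) therefore does not surface EINTR to user space;
+// it schedules a restart_syscall(2) that re-enters poll with only the
+// unexpired portion of the original timeout.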
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // ppoll(2), like poll(2), uses an unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if _, err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // pselect(2), like select(2), uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if _, err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
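+// setTempSignalSet temporarily replaces the task's signal mask for the
+// duration of ppoll(2)/pselect(2). The previous mask is recorded via
+// SetSavedSignalMask and is expected to be restored when the task returns
+// to user space, which gives these syscalls their atomic temporary-mask
+// semantics.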
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if _, err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..cd25597a7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,641 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
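+
+// Both masks include EventHUp and EventErr because a blocked reader or
+// writer must also be woken by a hangup or error, not only by plain
+// readiness.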
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
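+// read implements the blocking retry loop shared in structure with pread,
+// write, and pwrite below: issue the request once, and if the file reports
+// ErrWouldBlock while blockPolicy allows blocking, alternate between
+// retries and t.BlockWithDeadline until the request completes, the
+// deadline expires, or the task is interrupted.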
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+func blockPolicy(t *kernel.Task, file *vfs.FileDescription) (allowBlock bool, deadline ktime.Time, hasDeadline bool) {
+ if file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return false, ktime.Time{}, false
+ }
+ // Sockets support read/write timeouts.
+ if s, ok := file.Impl().(socket.SocketVFS2); ok {
+ dl := s.RecvTimeout()
+ if dl < 0 {
+ return false, ktime.Time{}, false
+ }
+ if dl > 0 {
+ return true, t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond), true
+ }
+ }
+ return true, ktime.Time{}, false
+}
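+
+// For example: a file opened with O_NONBLOCK never blocks, while a socket
+// with a 2s receive timeout may block with a deadline two seconds from
+// now. Note that the receive timeout (RecvTimeout) is consulted here
+// regardless of whether the caller is reading or writing.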
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
+
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL: Linux returns EINVAL when the underlying file type
+ // does not support readahead, and we do not yet support readahead
+ // hints for any file type. In the future, we may extend this function
+ // to actually support them.
+ return 0, nil, syserror.EINVAL
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..09ecfed26
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,428 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
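+
+// chmod(2) thus ignores the file-type bits of the mode: for example, a
+// mode of 0100644 (S_IFREG | 0644) is reduced to 0644 by chmodMask.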
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ err = setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+ return 0, nil, handleSetSizeError(t, err)
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+ return 0, nil, handleSetSizeError(t, err)
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if _, err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimes(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimesat implements Linux syscall futimesat(2).
+func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+
+ // "If filename is NULL and dfd refers to an open file, then operate on the
+ // file. Otherwise look up filename, possibly using dfd as a starting
+ // point." - fs/utimes.c
+ var path fspath.Path
+ shouldAllowEmptyPath := allowEmptyPath
+ if dirfd == linux.AT_FDCWD || pathAddr != 0 {
+ var err error
+ path, err = copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ shouldAllowEmptyPath = disallowEmptyPath
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimes(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, followFinalSymlink, &opts)
+}
+
+func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ return nil
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ // Linux requires that the UTIME_OMIT check occur before checking path or
+ // flags.
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+ if opts.Stat.Mask == 0 {
+ return 0, nil, nil
+ }
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If filename is NULL and dfd refers to an open file, then operate on the
+ // file. Otherwise look up filename, possibly using dfd as a starting
+ // point." - fs/utimes.c
+ var path fspath.Path
+ shouldAllowEmptyPath := allowEmptyPath
+ if dirfd == linux.AT_FDCWD || pathAddr != 0 {
+ var err error
+ path, err = copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ shouldAllowEmptyPath = disallowEmptyPath
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ if times[0].Nsec != linux.UTIME_NOW && (times[0].Nsec < 0 || times[0].Nsec > 999999999) {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ if times[1].Nsec != linux.UTIME_NOW && (times[1].Nsec < 0 || times[1].Nsec > 999999999) {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
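+
+// In other words, per utimensat(2): a Nsec of UTIME_OMIT leaves that
+// timestamp untouched, UTIME_NOW sets it to the current time, and any
+// other Nsec must be a valid nanosecond count in [0, 999999999].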
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef()
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
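+
+// setstatat thus follows the usual *at(2) convention: an absolute path
+// ignores dirfd, AT_FDCWD resolves relative to the working directory, an
+// empty path (where allowed) operates on dirfd itself, and any other
+// relative path is resolved starting at dirfd.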
+
+func handleSetSizeError(t *kernel.Task, err error) error {
+ if err == syserror.ErrExceedsFileSizeLimit {
+ // Convert error to EFBIG and send a SIGXFSZ per setrlimit(2).
+ t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
+ return syserror.EFBIG
+ }
+ return err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go
new file mode 100644
index 000000000..623992f6f
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/signal.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/signalfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// sharedSignalfd implements the common logic of signalfd(2) and signalfd4(2).
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := slinux.CopyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is this a signalfd?
+ if sfd, ok := file.Impl().(*signalfd.SignalFileDescription); ok {
+ sfd.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd.
+ return 0, nil, syserror.EINVAL
+ }
+
+ fileFlags := uint32(linux.O_RDWR)
+ if flags&linux.SFD_NONBLOCK != 0 {
+ fileFlags |= linux.O_NONBLOCK
+ }
+
+ // Create a new file.
+ vfsObj := t.Kernel().VFS()
+ file, err := signalfd.New(vfsObj, t, mask, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ // Create a new descriptor.
+ fd, err = t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done.
+ return uintptr(fd), nil, nil
+}
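+
+// A typical call sequence (an informal sketch): signalfd4(-1, &mask, 8,
+// SFD_CLOEXEC) creates a new signalfd descriptor, while passing an
+// existing signalfd descriptor instead of -1 only updates its mask.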
+
+// Signalfd implements the linux syscall signalfd(2).
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2).
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
new file mode 100644
index 000000000..10b668477
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -0,0 +1,1139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024 * 8
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers upto INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset from the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset from the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
+const sizeOfInt32 = 4
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipleMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32
+}
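+
+// This layout is what the nameLenOffset (8), controlLenOffset (40), and
+// flagsOffset (48) constants above describe: Name sits at byte 0, NameLen
+// at 8, Iov at 16, IovLen at 24, Control at 32, ControlLen at 40, and
+// Flags at 48.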
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64
+ msgLen uint32
+ _ int32
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
+ b := t.CopyScratchBuffer(52)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen {
+ return nil, syserror.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the untrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 {
+ return syserror.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
+ if addr == nil {
+ return nil
+ }
+
+ if bufLen > addrLen {
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) {
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
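+
+// For example: with a 16-byte sockaddr_in and a caller buffer of only 8
+// bytes, writeAddress still writes the full length (16) to addrLenPtr but
+// copies just the first 8 bytes of the encoded address, matching the
+// truncation semantics of getsockname(2) and friends.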
+
+// Socket implements the linux syscall socket(2).
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.NewVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ defer s.DecRef()
+
+ if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+ addr := args[3].Pointer()
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.PairVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ // Adding to the FD table will cause an extra reference to be acquired.
+ defer s1.DecRef()
+ defer s2.DecRef()
+
+ nonblocking := uint32(stype & linux.SOCK_NONBLOCK)
+ if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+ if err := s2.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+
+ // Create the FDs for the sockets.
+ flags := kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ }
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{s1, s2}, flags)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Connect implements the linux syscall connect(2).
+func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+ return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS)
+}
+
+// accept is the implementation of the accept syscall. It is called by the
+// accept and accept4 syscall handlers.
+func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+ // Check that no unsupported flags are passed in.
+ if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ // Call the syscall implementation for this socket, then copy the
+ // output address if one is specified.
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+
+ peerRequested := addrLen != 0
+ nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ if peerRequested {
+ // NOTE(magi): Linux does not give you an error if it can't
+ // write the data back out, so neither do we.
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ return 0, err
+ }
+ }
+ return uintptr(nfd), nil
+}
+
+// Accept4 implements the linux syscall accept4(2).
+func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+ flags := int(args[3].Int())
+
+ n, err := accept(t, fd, addr, addrlen, flags)
+ return n, nil, err
+}
+
+// Accept implements the linux syscall accept(2).
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ n, err := accept(t, fd, addr, addrlen, 0)
+ return n, nil, err
+}
+
+// Bind implements the linux syscall bind(2).
+func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, s.Bind(t, a).ToError()
+}
+
+// Listen implements the linux syscall listen(2).
+func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ backlog := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Per Linux, the backlog is silently capped to reasonable values.
+ if backlog <= 0 {
+ backlog = minListenBacklog
+ }
+ if backlog > maxListenBacklog {
+ backlog = maxListenBacklog
+ }
+
+ return 0, nil, s.Listen(t, int(backlog)).ToError()
+}
+
+// Shutdown implements the linux syscall shutdown(2).
+func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ how := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Validate how, then call syscall implementation.
+ switch how {
+ case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, s.Shutdown(t, int(how)).ToError()
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2).
+func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLenAddr := args[4].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Read the length. Reject negative values.
+ optLen := int32(0)
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Call syscall implementation then copy both value and value len out.
+ v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen))
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
+ }
+
+ if v != nil {
+ if _, err := t.CopyOut(optValAddr, v); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// getSockOpt tries to handle common socket options, or dispatches to a specific
+// socket implementation.
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+ if level == linux.SOL_SOCKET {
+ switch name {
+ case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
+ if len < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+
+ switch name {
+ case linux.SO_TYPE:
+ _, skType, _ := s.Type()
+ return int32(skType), nil
+ case linux.SO_DOMAIN:
+ family, _, _ := s.Type()
+ return int32(family), nil
+ case linux.SO_PROTOCOL:
+ _, _, protocol := s.Type()
+ return int32(protocol), nil
+ }
+ }
+
+ return s.GetSockOpt(t, level, name, optValAddr, len)
+}
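+
+// For illustration (a sketch, not part of the syscall path): querying the
+// socket type takes the SOL_SOCKET fast path above, e.g.
+//
+//   v, e := getSockOpt(t, s, linux.SOL_SOCKET, linux.SO_TYPE, optValAddr, sizeOfInt32)
+//
+// where, on success, v is an int32 such as linux.SOCK_STREAM and e is nil.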
+
+// SetSockOpt implements the linux syscall setsockopt(2).
+//
+// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
+func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLen := args[4].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if optLen > maxOptLen {
+ return 0, nil, syserror.EINVAL
+ }
+ buf := t.CopyScratchBuffer(int(optLen))
+ if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ return 0, nil, err
+ }
+
+ // Call syscall implementation.
+ if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2).
+func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket name and copy it to the caller.
+ v, vl, err := s.GetSockName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// GetPeerName implements the linux syscall getpeername(2).
+func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket peer name and copy it to the caller.
+ v, vl, err := s.GetPeerName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// RecvMsg implements the linux syscall recvmsg(2).
+func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
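+ // RecvTimeout is the SO_RCVTIMEO value in nanoseconds: zero blocks
+ // indefinitely, a positive value establishes a deadline, and a negative
+ // value is the don't-block sentinel.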
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
+ return n, nil, err
+}
+
+// RecvMMsg implements the linux syscall recvmmsg(2).
+func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+ toPtr := args[4].Pointer()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if toPtr != 0 {
+ var ts linux.Timespec
+ if _, err := ts.CopyIn(t, toPtr); err != nil {
+ return 0, nil, err
+ }
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
+ haveDeadline = true
+ }
+
+ if !haveDeadline {
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
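+ // As with recvmmsg(2) on Linux, if some messages have already been
+ // received, report their count and let the error surface on a later call.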
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+ // Capture the message header and io vectors.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
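+
+ // MessageHeader64 mirrors the userspace struct msghdr on 64-bit; the
+ // *Offset constants used below index into that layout and are defined
+ // with the message header helpers in this package.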
+
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // FIXME(b/63594852): Pretend we have an empty error queue.
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return 0, syserror.EAGAIN
+ }
+
+ // Fast path when neither control message nor name buffers are provided.
+ if msg.ControlLen == 0 && msg.NameLen == 0 {
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
+ }
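+ // No control buffer was provided, so any control messages that
+ // arrived must be dropped; signal the truncation via MSG_CTRUNC.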
+ if !cms.Unix.Empty() {
+ mflags |= linux.MSG_CTRUNC
+ cms.Release()
+ }
+
+ if int(msg.Flags) != mflags {
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+ }
+
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ defer cms.Release()
+
+ controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
+
+ if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
+ creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
+ controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
+ }
+
+ if cms.Unix.Rights != nil {
+ controlData, mflags = control.PackRightsVFS2(t, cms.Unix.Rights.(control.SCMRightsVFS2), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
+ }
+
+ // Copy the address to the caller.
+ if msg.NameLen != 0 {
+ if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy the control data to the caller.
+ if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ return 0, err
+ }
+ if len(controlData) > 0 {
+ if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+
+ return uintptr(n), nil
+}
+
+// recvFrom is the implementation of the recvfrom syscall. It is called by
+// recvfrom and recv syscall handlers.
+func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+ if int(bufLen) < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
+ cm.Release()
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+
+ // Copy the address to the caller.
+ if nameLenPtr != 0 {
+ if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+}
+
+// RecvFrom implements the linux syscall recvfrom(2).
+func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLenPtr := args[5].Pointer()
+
+ n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
+ return n, nil, err
+}
+
+// SendMsg implements the linux syscall sendmsg(2).
+func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := sendSingleMsg(t, s, file, msgPtr, flags)
+ return n, nil, err
+}
+
+// SendMMsg implements the linux syscall sendmmsg(2).
+func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
+ break
+ }
+
+ // Copy the sent length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
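+ // As with sendmmsg(2) on Linux, a partial success reports the number of
+ // messages sent; the error is returned only if nothing was sent.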
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+ // Capture the message header.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ var controlData []byte
+ if msg.ControlLen > 0 {
+ // Put an upper bound to prevent large allocations.
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ controlData = make([]byte, msg.ControlLen)
+ if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ if msg.NameLen != 0 {
+ var err error
+ to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Read data then call the sendmsg implementation.
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ controlMessages, err := control.Parse(t, s, controlData)
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
+ err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
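+ // On success, ownership of the parsed control messages passes to the
+ // socket implementation; release them only if the send failed.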
+ if err != nil {
+ controlMessages.Release()
+ }
+ return uintptr(n), err
+}
+
+// sendTo is the implementation of the sendto syscall. It is called by sendto
+// and send syscall handlers.
+func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+ bl := int(bufLen)
+ if bl < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ var err error
+ if namePtr != 0 {
+ to, err = CaptureAddress(t, namePtr, nameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
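+ // sendto(2) carries no msg_control, so pass a fresh control message set
+ // with no rights; control.New supplies default credentials where the
+ // socket expects them.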
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
+ return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file)
+}
+
+// SendTo implements the linux syscall sendto(2).
+func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLen := args[5].Uint()
+
+ n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
new file mode 100644
index 000000000..945a364a7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -0,0 +1,291 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Splice implements Linux syscall splice(2).
+func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ inOffsetPtr := args[1].Pointer()
+ outFD := args[2].Int()
+ outOffsetPtr := args[3].Pointer()
+ count := int64(args[4].SizeT())
+ flags := args[5].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // At least one file description must represent a pipe.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe && !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy in offsets.
+ inOffset := int64(-1)
+ if inOffsetPtr != 0 {
+ if inIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ if inOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+ outOffset := int64(-1)
+ if outOffsetPtr != 0 {
+ if outIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if outFile.Options().DenyPWrite {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ if outOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Move data.
+ var (
+ n int64
+ err error
+ inCh chan struct{}
+ outCh chan struct{}
+ )
+ for {
+ // If both input and output are pipes, delegate to the pipe
+ // implementation. Otherwise, exactly one end is a pipe, which we
+ // ensure is consistently ordered after the non-pipe FD's locks by
+ // passing the pipe FD as usermem.IO to the non-pipe end.
+ switch {
+ case inIsPipe && outIsPipe:
+ n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
+ case inIsPipe:
+ if outOffset != -1 {
+ n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{})
+ outOffset += n
+ } else {
+ n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{})
+ }
+ case outIsPipe:
+ if inOffset != -1 {
+ n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{})
+ inOffset += n
+ } else {
+ n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
+ }
+ }
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, eventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
+ }
+ if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, eventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
+ }
+ }
+
+ // Copy updated offsets out.
+ if inOffsetPtr != 0 {
+ if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+ if outOffsetPtr != 0 {
+ if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Tee implements Linux syscall tee(2).
+func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ outFD := args[1].Int()
+ count := int64(args[2].SizeT())
+ flags := args[3].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // Both file descriptions must represent pipes.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe || !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy data.
+ var (
+ inCh chan struct{}
+ outCh chan struct{}
+ )
+ for {
+ n, err := pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 {
+ return uintptr(n), nil, nil
+ }
+ if err != syserror.ErrWouldBlock || nonBlock {
+ return 0, nil, err
+ }
+
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the tee operation.
+ if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, eventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err := t.Block(inCh); err != nil {
+ return 0, nil, err
+ }
+ }
+ if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, eventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err := t.Block(outCh); err != nil {
+ return 0, nil, err
+ }
+ }
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..bb1d5cac4
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,388 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ _, err = stat.CopyOut(t, statAddr)
+ return 0, nil, err
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // Make sure that only one sync type option is set.
+ syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
+ if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
+ return 0, nil, syserror.EINVAL
+ }
+ if mask&linux.STATX__RESERVED != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: syncType,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
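+ // Convert kernel IDs into the caller's user namespace; IDs with no
+ // mapping become the overflow IDs, matching Linux.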
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Faccessat implements Linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+
+ return 0, nil, accessAt(t, dirfd, addr, mode)
+}
+
+func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ const rOK = 4
+ const wOK = 2
+ const xOK = 1
+
+ // Sanity check the mode.
+ if mode&^(rOK|wOK|xOK) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ // access(2) and faccessat(2) check permissions using real
+ // UID/GID, not effective UID/GID.
+ //
+ // "access() needs to use the real uid/gid, not the effective
+ // uid/gid. We do this by temporarily clearing all FS-related
+ // capabilities and switching the fsuid/fsgid around to the
+ // real ones." -fs/open.c:faccessat
+ creds := t.Credentials().Fork()
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ return t.Kernel().VFS().AccessAt(t, creds, vfs.AccessTypes(mode), &tpop.pop)
+}
+
+// Readlinkat implements Linux syscall readlinkat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
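+ // Per readlink(2), a too-small buffer silently truncates the result;
+ // no error is returned.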
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
new file mode 100644
index 000000000..2da538fc6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
new file mode 100644
index 000000000..88b9c7627
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint32(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int32(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..0d0ebf46a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ // Check for negative values and overflow.
+ if offset < 0 || offset+nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support
+ // is a full-file sync, i.e. fsync(2). As a result, there are severe
+ // limitations on how much we support sync_file_range:
+ // - In Linux, sync_file_range(2) doesn't write out the file's metadata, even
+ // if the file size is changed. We do.
+ // - We always sync the entire file instead of [offset, offset+nbytes).
+ // - We do not support the use of WAIT_BEFORE without WAIT_AFTER. For
+ // correctness, we would have to perform a write-out every time WAIT_BEFORE
+ // was used, but this would be much more expensive than expected if there
+ // were no write-out operations in progress.
+ // - Whenever WAIT_AFTER is used, we sync the file.
+ // - Ignore WRITE. If this flag is used with WAIT_AFTER, then the file will
+ // be synced anyway. If this flag is used without WAIT_AFTER, then it is
+ // safe (and less expensive) to do nothing, because the syscall will not
+ // wait for the write-out to complete--we only need to make sure that the
+ // next time WAIT_BEFORE or WAIT_AFTER is used, the write-out completes.
+ // - According to fs/sync.c, WAIT_BEFORE|WAIT_AFTER "will detect any I/O
+ // errors or ENOSPC conditions and will return those to the caller, after
+ // clearing the EIO and ENOSPC flags in the address_space." We don't do
+ // this.
+
+ if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
+ flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 {
+ if err := file.Sync(t); err != nil {
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ }
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
new file mode 100644
index 000000000..5ac79bc09
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TimerfdCreate implements Linux syscall timerfd_create(2).
+func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ flags := args[1].Int()
+
+ if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Timerfds aren't writable per se (their implementation of Write just
+ // returns EINVAL), but they are "opened for writing", which is necessary
+ // to actually reach said implementation of Write.
+ fileFlags := uint32(linux.O_RDWR)
+ if flags&linux.TFD_NONBLOCK != 0 {
+ fileFlags |= linux.O_NONBLOCK
+ }
+
+ var clock ktime.Clock
+ switch clockID {
+ case linux.CLOCK_REALTIME:
+ clock = t.Kernel().RealtimeClock()
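+ // CLOCK_BOOTTIME is served by the monotonic clock: the sentry never
+ // suspends, so the two clocks advance identically here.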
+ case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
+ clock = t.Kernel().MonotonicClock()
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ vfsObj := t.Kernel().VFS()
+ file, err := timerfd.New(vfsObj, clock, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// TimerfdSettime implements Linux syscall timerfd_settime(2).
+func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock())
+ if err != nil {
+ return 0, nil, err
+ }
+ tm, oldS := tfd.SetTime(newS)
+ if oldValAddr != 0 {
+ oldVal := ktime.ItimerspecFromSetting(tm, oldS)
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerfdGettime implements Linux syscall timerfd_gettime(2).
+func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ curValAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ tm, s := tfd.GetTime()
+ curVal := ktime.ItimerspecFromSetting(tm, s)
+ _, err := t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
new file mode 100644
index 000000000..8f497ecc7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -0,0 +1,168 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package vfs2 provides syscall implementations that use VFS2.
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+)
+
+// Override syscall table to add syscall implementations from this package.
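+//
+// Override is assumed to run once during early boot, before the tables are
+// used; the trailing Init() calls let each table rebuild any derived lookup
+// state after these mutations.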
+func Override() {
+ // Override AMD64.
+ s := linux.AMD64
+ s.Table[0] = syscalls.Supported("read", Read)
+ s.Table[1] = syscalls.Supported("write", Write)
+ s.Table[2] = syscalls.Supported("open", Open)
+ s.Table[3] = syscalls.Supported("close", Close)
+ s.Table[4] = syscalls.Supported("stat", Stat)
+ s.Table[5] = syscalls.Supported("fstat", Fstat)
+ s.Table[6] = syscalls.Supported("lstat", Lstat)
+ s.Table[7] = syscalls.Supported("poll", Poll)
+ s.Table[8] = syscalls.Supported("lseek", Lseek)
+ s.Table[9] = syscalls.Supported("mmap", Mmap)
+ s.Table[16] = syscalls.Supported("ioctl", Ioctl)
+ s.Table[17] = syscalls.Supported("pread64", Pread64)
+ s.Table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ s.Table[19] = syscalls.Supported("readv", Readv)
+ s.Table[20] = syscalls.Supported("writev", Writev)
+ s.Table[21] = syscalls.Supported("access", Access)
+ s.Table[22] = syscalls.Supported("pipe", Pipe)
+ s.Table[23] = syscalls.Supported("select", Select)
+ s.Table[32] = syscalls.Supported("dup", Dup)
+ s.Table[33] = syscalls.Supported("dup2", Dup2)
+ delete(s.Table, 40) // sendfile
+ s.Table[41] = syscalls.Supported("socket", Socket)
+ s.Table[42] = syscalls.Supported("connect", Connect)
+ s.Table[43] = syscalls.Supported("accept", Accept)
+ s.Table[44] = syscalls.Supported("sendto", SendTo)
+ s.Table[45] = syscalls.Supported("recvfrom", RecvFrom)
+ s.Table[46] = syscalls.Supported("sendmsg", SendMsg)
+ s.Table[47] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[48] = syscalls.Supported("shutdown", Shutdown)
+ s.Table[49] = syscalls.Supported("bind", Bind)
+ s.Table[50] = syscalls.Supported("listen", Listen)
+ s.Table[51] = syscalls.Supported("getsockname", GetSockName)
+ s.Table[52] = syscalls.Supported("getpeername", GetPeerName)
+ s.Table[53] = syscalls.Supported("socketpair", SocketPair)
+ s.Table[54] = syscalls.Supported("setsockopt", SetSockOpt)
+ s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt)
+ s.Table[59] = syscalls.Supported("execve", Execve)
+ s.Table[72] = syscalls.Supported("fcntl", Fcntl)
+ s.Table[73] = syscalls.Supported("flock", Flock)
+ s.Table[74] = syscalls.Supported("fsync", Fsync)
+ s.Table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ s.Table[76] = syscalls.Supported("truncate", Truncate)
+ s.Table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[78] = syscalls.Supported("getdents", Getdents)
+ s.Table[79] = syscalls.Supported("getcwd", Getcwd)
+ s.Table[80] = syscalls.Supported("chdir", Chdir)
+ s.Table[81] = syscalls.Supported("fchdir", Fchdir)
+ s.Table[82] = syscalls.Supported("rename", Rename)
+ s.Table[83] = syscalls.Supported("mkdir", Mkdir)
+ s.Table[84] = syscalls.Supported("rmdir", Rmdir)
+ s.Table[85] = syscalls.Supported("creat", Creat)
+ s.Table[86] = syscalls.Supported("link", Link)
+ s.Table[87] = syscalls.Supported("unlink", Unlink)
+ s.Table[88] = syscalls.Supported("symlink", Symlink)
+ s.Table[89] = syscalls.Supported("readlink", Readlink)
+ s.Table[90] = syscalls.Supported("chmod", Chmod)
+ s.Table[91] = syscalls.Supported("fchmod", Fchmod)
+ s.Table[92] = syscalls.Supported("chown", Chown)
+ s.Table[93] = syscalls.Supported("fchown", Fchown)
+ s.Table[94] = syscalls.Supported("lchown", Lchown)
+ s.Table[132] = syscalls.Supported("utime", Utime)
+ s.Table[133] = syscalls.Supported("mknod", Mknod)
+ s.Table[137] = syscalls.Supported("statfs", Statfs)
+ s.Table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ s.Table[161] = syscalls.Supported("chroot", Chroot)
+ s.Table[162] = syscalls.Supported("sync", Sync)
+ s.Table[165] = syscalls.Supported("mount", Mount)
+ s.Table[166] = syscalls.Supported("umount2", Umount2)
+ s.Table[187] = syscalls.Supported("readahead", Readahead)
+ s.Table[188] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ s.Table[191] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ s.Table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ s.Table[194] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ s.Table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ s.Table[197] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
+ s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
+ s.Table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ s.Table[217] = syscalls.Supported("getdents64", Getdents64)
+ s.Table[221] = syscalls.PartiallySupported("fadvise64", Fadvise64, "The syscall is 'supported', but ignores all provided advice.", nil)
+ s.Table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ s.Table[235] = syscalls.Supported("utimes", Utimes)
+ s.Table[253] = syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil)
+ s.Table[254] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[255] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[257] = syscalls.Supported("openat", Openat)
+ s.Table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ s.Table[259] = syscalls.Supported("mknodat", Mknodat)
+ s.Table[260] = syscalls.Supported("fchownat", Fchownat)
+ s.Table[261] = syscalls.Supported("futimesat", Futimesat)
+ s.Table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ s.Table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ s.Table[264] = syscalls.Supported("renameat", Renameat)
+ s.Table[265] = syscalls.Supported("linkat", Linkat)
+ s.Table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ s.Table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ s.Table[269] = syscalls.Supported("faccessat", Faccessat)
+ s.Table[270] = syscalls.Supported("pselect", Pselect)
+ s.Table[271] = syscalls.Supported("ppoll", Ppoll)
+ s.Table[275] = syscalls.Supported("splice", Splice)
+ s.Table[276] = syscalls.Supported("tee", Tee)
+ s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ s.Table[280] = syscalls.Supported("utimensat", Utimensat)
+ s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ s.Table[282] = syscalls.Supported("signalfd", Signalfd)
+ s.Table[283] = syscalls.Supported("timerfd_create", TimerfdCreate)
+ s.Table[284] = syscalls.Supported("eventfd", Eventfd)
+ s.Table[285] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil)
+ s.Table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ s.Table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ s.Table[288] = syscalls.Supported("accept4", Accept4)
+ s.Table[289] = syscalls.Supported("signalfd4", Signalfd4)
+ s.Table[290] = syscalls.Supported("eventfd2", Eventfd2)
+ s.Table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ s.Table[292] = syscalls.Supported("dup3", Dup3)
+ s.Table[293] = syscalls.Supported("pipe2", Pipe2)
+ s.Table[294] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
+ s.Table[295] = syscalls.Supported("preadv", Preadv)
+ s.Table[296] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[299] = syscalls.Supported("recvmmsg", RecvMMsg)
+ s.Table[306] = syscalls.Supported("syncfs", Syncfs)
+ s.Table[307] = syscalls.Supported("sendmmsg", SendMMsg)
+ s.Table[316] = syscalls.Supported("renameat2", Renameat2)
+ s.Table[319] = syscalls.Supported("memfd_create", MemfdCreate)
+ s.Table[322] = syscalls.Supported("execveat", Execveat)
+ s.Table[327] = syscalls.Supported("preadv2", Preadv2)
+ s.Table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ s.Table[332] = syscalls.Supported("statx", Statx)
+ s.Init()
+
+ // Override ARM64.
+ s = linux.ARM64
+ s.Table[63] = syscalls.Supported("read", Read)
+ s.Init()
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..af455d5c1
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,356 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ names, err := file.Listxattr(t, uint64(size))
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{
+ Name: name,
+ Size: uint64(size),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)})
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accommodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
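+
+// As a worked example (the names here are hypothetical), names =
+// []string{"user.foo", "user.bar"} encodes to the NUL-separated list
+// "user.foo\x00user.bar\x00", 18 bytes in total, so a caller probing with
+// size == 0 learns that an 18-byte buffer is required:
+//
+//	n, _ := copyOutXattrNameList(t, listAddr, 0, names)      // n == 18
+//	n, _ = copyOutXattrNameList(t, listAddr, uint(n), names) // copies 18 bytes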
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accommodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/syscalls/syscalls.go b/pkg/sentry/syscalls/syscalls.go
new file mode 100644
index 000000000..f88055676
--- /dev/null
+++ b/pkg/sentry/syscalls/syscalls.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syscalls is the interface from the application to the kernel.
+// Traditionally, syscalls are the interface used by applications to request
+// services from the kernel of an operating system. We provide a user-mode
+// kernel that must handle those requests coming from unmodified applications,
+// so we retain the term "syscalls" to denote this interface.
+//
+// Note that the stubs in this package may merely provide the interface, not
+// the actual implementation; they simply make writing syscall stubs
+// straightforward.
+package syscalls
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Supported returns a syscall that is fully supported.
+func Supported(name string, fn kernel.SyscallFn) kernel.Syscall {
+ return kernel.Syscall{
+ Name: name,
+ Fn: fn,
+ SupportLevel: kernel.SupportFull,
+ Note: "Fully Supported.",
+ }
+}
+
+// PartiallySupported returns a syscall that has a partial implementation.
+func PartiallySupported(name string, fn kernel.SyscallFn, note string, urls []string) kernel.Syscall {
+ return kernel.Syscall{
+ Name: name,
+ Fn: fn,
+ SupportLevel: kernel.SupportPartial,
+ Note: note,
+ URLs: urls,
+ }
+}
+
+// Error returns a syscall handler that always returns the passed error.
+func Error(name string, err error, note string, urls []string) kernel.Syscall {
+ if note != "" {
+ note = note + "; "
+ }
+ return kernel.Syscall{
+ Name: name,
+ Fn: func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, err
+ },
+ SupportLevel: kernel.SupportUnimplemented,
+ Note: fmt.Sprintf("%sReturns %q.", note, err.Error()),
+ URLs: urls,
+ }
+}
+
+// ErrorWithEvent gives a syscall function that sends an unimplemented
+// syscall event via the event channel and returns the passed error.
+func ErrorWithEvent(name string, err error, note string, urls []string) kernel.Syscall {
+ if note != "" {
+ note = note + "; "
+ }
+ return kernel.Syscall{
+ Name: name,
+ Fn: func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, err
+ },
+ SupportLevel: kernel.SupportUnimplemented,
+ Note: fmt.Sprintf("%sReturns %q.", note, err.Error()),
+ URLs: urls,
+ }
+}
+
+// CapError gives a syscall function that checks for capability c. If the task
+// has the capability, it returns ENOSYS; otherwise it returns EPERM. Either
+// way, to unprivileged tasks the syscall appears to be implemented.
+func CapError(name string, c linux.Capability, note string, urls []string) kernel.Syscall {
+ if note != "" {
+ note = note + "; "
+ }
+ return kernel.Syscall{
+ Name: name,
+ Fn: func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if !t.HasCapability(c) {
+ return 0, nil, syserror.EPERM
+ }
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ },
+ SupportLevel: kernel.SupportUnimplemented,
+ Note: fmt.Sprintf("%sReturns %q if the process does not have %s; %q otherwise.", note, syserror.EPERM, c.String(), syserror.ENOSYS),
+ URLs: urls,
+ }
+}
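+
+// As an illustrative sketch (these exact entries are hypothetical; the real
+// per-arch tables live elsewhere), a syscall table mixes these helpers:
+//
+//	table[134] = syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil)
+//	table[169] = syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil)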
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
new file mode 100644
index 000000000..04f81a35b
--- /dev/null
+++ b/pkg/sentry/time/BUILD
@@ -0,0 +1,50 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "seqatomic_parameters",
+ out = "seqatomic_parameters_unsafe.go",
+ package = "time",
+ suffix = "Parameters",
+ template = "//pkg/sync:generic_seqatomic",
+ types = {
+ "Value": "Parameters",
+ },
+)
+
+go_library(
+ name = "time",
+ srcs = [
+ "arith_arm64.go",
+ "calibrated_clock.go",
+ "clock_id.go",
+ "clocks.go",
+ "muldiv_amd64.s",
+ "muldiv_arm64.s",
+ "parameters.go",
+ "sampler.go",
+ "sampler_unsafe.go",
+ "seqatomic_parameters_unsafe.go",
+ "tsc_amd64.s",
+ "tsc_arm64.s",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "time_test",
+ srcs = [
+ "calibrated_clock_test.go",
+ "parameters_test.go",
+ "sampler_test.go",
+ ],
+ library = ":time",
+)
diff --git a/pkg/sentry/time/LICENSE b/pkg/sentry/time/LICENSE
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/pkg/sentry/time/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/pkg/sentry/time/arith_arm64.go b/pkg/sentry/time/arith_arm64.go
new file mode 100644
index 000000000..b94740c2a
--- /dev/null
+++ b/pkg/sentry/time/arith_arm64.go
@@ -0,0 +1,70 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file provides a generic Go implementation of uint128 divided by uint64.
+
+// The code is derived from Go's generic math/big.divWW_g
+// (src/math/big/arith.go), but is only used on ARM64.
+
+package time
+
+import "math/bits"
+
+type word uint
+
+const (
+ _W = bits.UintSize // word size in bits
+ _W2 = _W / 2 // half word size in bits
+ _B2 = 1 << _W2 // half digit base
+ _M2 = _B2 - 1 // half digit mask
+)
+
+// nlz returns the number of leading zeros in x.
+// Wraps bits.LeadingZeros call for convenience.
+func nlz(x word) uint {
+ return uint(bits.LeadingZeros(uint(x)))
+}
+
+// q = (u1<<_W + u0 - r)/v
+// Adapted from Warren, Hacker's Delight, p. 152.
+func divWW(u1, u0, v word) (q, r word) {
+ if u1 >= v {
+ return 1<<_W - 1, 1<<_W - 1
+ }
+
+ s := nlz(v)
+ v <<= s
+
+ vn1 := v >> _W2
+ vn0 := v & _M2
+ un32 := u1<<s | u0>>(_W-s)
+ un10 := u0 << s
+ un1 := un10 >> _W2
+ un0 := un10 & _M2
+ q1 := un32 / vn1
+ rhat := un32 - q1*vn1
+
+ for q1 >= _B2 || q1*vn0 > _B2*rhat+un1 {
+ q1--
+ rhat += vn1
+
+ if rhat >= _B2 {
+ break
+ }
+ }
+
+ un21 := un32*_B2 + un1 - q1*v
+ q0 := un21 / vn1
+ rhat = un21 - q0*vn1
+
+ for q0 >= _B2 || q0*vn0 > _B2*rhat+un0 {
+ q0--
+ rhat += vn1
+ if rhat >= _B2 {
+ break
+ }
+ }
+
+ return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s
+}
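+
+// As a quick sanity check (illustrative, not part of the derived code),
+// dividing the 128-bit value 1<<64 by 4 on a 64-bit platform:
+//
+//	q, r := divWW(1, 0, 4) // q == 1<<62, r == 0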
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
new file mode 100644
index 000000000..f9a93115d
--- /dev/null
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -0,0 +1,269 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package time provides a calibrated clock synchronized to a system reference
+// clock.
+package time
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// fallbackMetric tracks failed updates. It is not sync, as it is not critical
+// that all occurrences are captured and CalibratedClock may fall back many
+// times.
+var fallbackMetric = metric.MustCreateNewUint64Metric("/time/fallback", false /* sync */, "Incremented when a clock falls back to system calls due to a failed update")
+
+// CalibratedClock implements a clock that tracks a reference clock.
+//
+// Users should call Update at regular intervals of around ApproxUpdateInterval
+// to ensure that the clock does not drift significantly from the reference
+// clock.
+type CalibratedClock struct {
+ // mu protects the fields below.
+ // TODO(mpratt): consider a sequence counter for read locking.
+ mu sync.RWMutex
+
+ // ref samples the reference clock that this clock is calibrated
+ // against.
+ ref *sampler
+
+ // ready indicates that the fields below are ready for use in
+ // computing time.
+ ready bool
+
+ // params are the current timekeeping parameters.
+ params Parameters
+
+ // errorNS is the estimated clock error in nanoseconds.
+ errorNS ReferenceNS
+}
+
+// NewCalibratedClock creates a CalibratedClock that tracks the given ClockID.
+func NewCalibratedClock(c ClockID) *CalibratedClock {
+ return &CalibratedClock{
+ ref: newSampler(c),
+ }
+}
+
+// Debugf logs at debug level.
+func (c *CalibratedClock) Debugf(format string, v ...interface{}) {
+ if log.IsLogging(log.Debug) {
+ args := []interface{}{c.ref.clockID}
+ args = append(args, v...)
+ log.Debugf("CalibratedClock(%v): "+format, args...)
+ }
+}
+
+// Infof logs at info level.
+func (c *CalibratedClock) Infof(format string, v ...interface{}) {
+ if log.IsLogging(log.Info) {
+ args := []interface{}{c.ref.clockID}
+ args = append(args, v...)
+ log.Infof("CalibratedClock(%v): "+format, args...)
+ }
+}
+
+// Warningf logs at warning level.
+func (c *CalibratedClock) Warningf(format string, v ...interface{}) {
+ if log.IsLogging(log.Warning) {
+ args := []interface{}{c.ref.clockID}
+ args = append(args, v...)
+ log.Warningf("CalibratedClock(%v): "+format, args...)
+ }
+}
+
+// reset forces the clock to restart the calibration process, logging the
+// passed message.
+func (c *CalibratedClock) reset(str string, v ...interface{}) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.resetLocked(str, v...)
+}
+
+// resetLocked is equivalent to reset with c.mu already held for writing.
+func (c *CalibratedClock) resetLocked(str string, v ...interface{}) {
+ c.Warningf(str+" Resetting clock; time may jump.", v...)
+ c.ready = false
+ c.ref.Reset()
+ fallbackMetric.Increment()
+}
+
+// updateParams updates the timekeeping parameters based on the passed
+// parameters.
+//
+// actual is the latest estimated timekeeping parameters. The stored parameters
+// may need to be adjusted slightly from these values to compensate for error.
+//
+// Preconditions: c.mu must be held for writing.
+func (c *CalibratedClock) updateParams(actual Parameters) {
+ if !c.ready {
+ // At initial calibration there is nothing to correct.
+ c.params = actual
+ c.ready = true
+
+ c.Infof("ready")
+
+ return
+ }
+
+ // Otherwise, adjust the params to correct for errors.
+ newParams, errorNS, err := errorAdjust(c.params, actual, actual.BaseCycles)
+ if err != nil {
+ // Something is very wrong. Reset and try again from the
+ // beginning.
+ c.resetLocked("Unable to update params: %v.", err)
+ return
+ }
+ logErrorAdjustment(c.ref.clockID, errorNS, c.params, newParams)
+
+ if errorNS.Magnitude() >= MaxClockError {
+ // We should never get such extreme error, something is very
+ // wrong. Reset everything and start again.
+ //
+ // N.B. logErrorAdjustment will have already logged the error
+ // at warning level.
+ //
+ // TODO(mpratt): We could allow Realtime clock jumps here.
+ c.resetLocked("Extreme clock error.")
+ return
+ }
+
+ c.params = newParams
+ c.errorNS = errorNS
+}
+
+// Update runs the update step of the clock, updating its synchronization with
+// the reference clock.
+//
+// Update returns the new timekeeping parameters and true if the clock is
+// calibrated. Update should be called regularly to prevent the clock from
+// drifting significantly out of sync with the reference clock.
+//
+// The returned timekeeping parameters are invalidated on the next call to
+// Update.
+func (c *CalibratedClock) Update() (Parameters, bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := c.ref.Sample(); err != nil {
+ c.resetLocked("Unable to update calibrated clock: %v.", err)
+ return Parameters{}, false
+ }
+
+ oldest, newest, ok := c.ref.Range()
+ if !ok {
+ // Not ready yet.
+ return Parameters{}, false
+ }
+
+ minCount := uint64(newest.before - oldest.after)
+ maxCount := uint64(newest.after - oldest.before)
+ refInterval := uint64(newest.ref - oldest.ref)
+
+ // freq (Hz) = (count cycles / refInterval ns) * (nsPerS ns / 1 s)
+ nsPerS := uint64(time.Second.Nanoseconds())
+
+ minHz, ok := muldiv64(minCount, nsPerS, refInterval)
+ if !ok {
+ c.resetLocked("Unable to update calibrated clock: (%v - %v) * %v / %v overflows.", newest.before, oldest.after, nsPerS, refInterval)
+ return Parameters{}, false
+ }
+
+ maxHz, ok := muldiv64(maxCount, nsPerS, refInterval)
+ if !ok {
+ c.resetLocked("Unable to update calibrated clock: (%v - %v) * %v / %v overflows.", newest.after, oldest.before, nsPerS, refInterval)
+ return Parameters{}, false
+ }
+
+ c.updateParams(Parameters{
+ Frequency: (minHz + maxHz) / 2,
+ BaseRef: newest.ref,
+ BaseCycles: newest.after,
+ })
+
+ return c.params, true
+}
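+
+// As a worked illustration of the bounds above: with oldest =
+// sample{before: 0, after: 10} and newest = sample{before: 1000010,
+// after: 1000020} taken 1s of reference time apart, the frequency is bounded
+// to [1000000 Hz, 1000020 Hz] and the midpoint, 1000010 Hz, is used as the
+// estimate.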
+
+// GetTime returns the current time based on the clock calibration.
+func (c *CalibratedClock) GetTime() (int64, error) {
+ c.mu.RLock()
+
+ if !c.ready {
+ // Fall back to a syscall.
+ now, err := c.ref.Syscall()
+ c.mu.RUnlock()
+ return int64(now), err
+ }
+
+ now := c.ref.Cycles()
+ v, ok := c.params.ComputeTime(now)
+ if !ok {
+ // Something is seriously wrong with the clock. Try
+ // again with syscalls.
+ //
+ // N.B. resetLocked requires the write lock, so drop the
+ // read lock and reacquire it via reset.
+ params := c.params
+ c.mu.RUnlock()
+ c.reset("Time computation overflowed. params = %+v, now = %v.", params, now)
+ now, err := c.ref.Syscall()
+ return int64(now), err
+ }
+
+ c.mu.RUnlock()
+ return v, nil
+}
+
+// CalibratedClocks contains calibrated monotonic and realtime clocks.
+//
+// TODO(mpratt): We know that Linux runs the monotonic and realtime clocks at
+// the same rate, so rather than tracking both individually, we could do one
+// calibration for both clocks.
+type CalibratedClocks struct {
+ // monotonic is the clock tracking the system monotonic clock.
+ monotonic *CalibratedClock
+
+ // realtime is the realtime equivalent of monotonic.
+ realtime *CalibratedClock
+}
+
+// NewCalibratedClocks creates a CalibratedClocks.
+func NewCalibratedClocks() *CalibratedClocks {
+ return &CalibratedClocks{
+ monotonic: NewCalibratedClock(Monotonic),
+ realtime: NewCalibratedClock(Realtime),
+ }
+}
+
+// Update implements Clocks.Update.
+func (c *CalibratedClocks) Update() (Parameters, bool, Parameters, bool) {
+ monotonicParams, monotonicOk := c.monotonic.Update()
+ realtimeParams, realtimeOk := c.realtime.Update()
+
+ return monotonicParams, monotonicOk, realtimeParams, realtimeOk
+}
+
+// GetTime implements Clocks.GetTime.
+func (c *CalibratedClocks) GetTime(id ClockID) (int64, error) {
+ switch id {
+ case Monotonic:
+ return c.monotonic.GetTime()
+ case Realtime:
+ return c.realtime.GetTime()
+ default:
+ return 0, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/time/calibrated_clock_test.go b/pkg/sentry/time/calibrated_clock_test.go
new file mode 100644
index 000000000..d6622bfe2
--- /dev/null
+++ b/pkg/sentry/time/calibrated_clock_test.go
@@ -0,0 +1,186 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "testing"
+ "time"
+)
+
+// newTestCalibratedClock returns a CalibratedClock that collects samples from
+// the given sample list and cycle counts from the given cycle list.
+func newTestCalibratedClock(samples []sample, cycles []TSCValue) *CalibratedClock {
+ return &CalibratedClock{
+ ref: newTestSampler(samples, cycles),
+ }
+}
+
+func TestConstantFrequency(t *testing.T) {
+ // Perfectly constant frequency.
+ samples := []sample{
+ {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100},
+ {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200},
+ {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300},
+ {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400},
+ {before: 500000, after: 500000 + defaultOverheadCycles, ref: 500},
+ {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600},
+ {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700},
+ }
+
+ c := newTestCalibratedClock(samples, nil)
+
+ // Update from all samples.
+ for range samples {
+ c.Update()
+ }
+
+ c.mu.RLock()
+ if !c.ready {
+ c.mu.RUnlock()
+ t.Fatalf("clock not ready")
+ }
+ // A bit after the last sample.
+ now, ok := c.params.ComputeTime(750000)
+ c.mu.RUnlock()
+ if !ok {
+ t.Fatalf("ComputeTime ok got %v want true", ok)
+ }
+
+ t.Logf("now: %v", now)
+
+ // Time should be between the current sample and where we'd expect the
+ // next sample.
+ if now < 700 || now > 800 {
+ t.Errorf("now got %v want > 700 && < 800", now)
+ }
+}
+
+func TestErrorCorrection(t *testing.T) {
+ testCases := []struct {
+ name string
+ samples [5]sample
+ projectedTimeStart int64
+ projectedTimeEnd int64
+ }{
+ // Initial calibration should be ~1MHz for each of these, and
+ // the reference clock changes in samples[2].
+ {
+ name: "slow-down",
+ samples: [5]sample{
+ {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())},
+ {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())},
+ // Reference clock has slowed down, causing 10ms of error.
+ {before: 3010000, after: 3010001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())},
+ {before: 4020000, after: 4020001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())},
+ {before: 5030000, after: 5030001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())},
+ },
+ projectedTimeStart: 3005 * time.Millisecond.Nanoseconds(),
+ projectedTimeEnd: 3015 * time.Millisecond.Nanoseconds(),
+ },
+ {
+ name: "speed-up",
+ samples: [5]sample{
+ {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())},
+ {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())},
+ // Reference clock has sped up, causing 10ms of error.
+ {before: 2990000, after: 2990001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())},
+ {before: 3980000, after: 3980001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())},
+ {before: 4970000, after: 4970001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())},
+ },
+ projectedTimeStart: 2985 * time.Millisecond.Nanoseconds(),
+ projectedTimeEnd: 2995 * time.Millisecond.Nanoseconds(),
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := newTestCalibratedClock(tc.samples[:], nil)
+
+ // Initial calibration takes two updates.
+ _, ok := c.Update()
+ if ok {
+ t.Fatalf("Update ready too early")
+ }
+
+ params, ok := c.Update()
+ if !ok {
+ t.Fatalf("Update not ready")
+ }
+
+ // Initial calibration is ~1MHz.
+ hz := params.Frequency
+ if hz < 990000 || hz > 1010000 {
+ t.Fatalf("Frequency got %v want > 990kHz && < 1010kHz", hz)
+ }
+
+ // Project time at the next update. Given the 1MHz
+ // calibration, it is expected to be ~3.01s/2.99s, not
+ // the actual 3s.
+ //
+ // N.B. the next update time is the "after" time above.
+ projected, ok := params.ComputeTime(tc.samples[2].after)
+ if !ok {
+ t.Fatalf("ComputeTime ok got %v want true", ok)
+ }
+ if projected < tc.projectedTimeStart || projected > tc.projectedTimeEnd {
+ t.Fatalf("ComputeTime(%v) got %v want > %v && < %v", tc.samples[2].after, projected, tc.projectedTimeStart, tc.projectedTimeEnd)
+ }
+
+ // Update again to see the changed reference clock.
+ params, ok = c.Update()
+ if !ok {
+ t.Fatalf("Update not ready")
+ }
+
+ // We now know that TSC = tc.samples[2].after -> 3s,
+ // but the previous params indicated that TSC =
+ // tc.samples[2].after -> ~3.01s/2.99s. We can't allow the
+ // clock to go backwards, and having the clock jump
+ // forwards is undesirable. There should be a smooth
+ // transition that corrects the clock error over time.
+ // Check that the clock is continuous at TSC =
+ // tc.samples[2].after.
+ newProjected, ok := params.ComputeTime(tc.samples[2].after)
+ if !ok {
+ t.Fatalf("ComputeTime ok got %v want true", ok)
+ }
+ if newProjected != projected {
+ t.Errorf("Discontinuous time; ComputeTime(%v) got %v want %v", tc.samples[2].after, newProjected, projected)
+ }
+
+ // As the reference clock stabilizes, ensure that the clock error
+ // decreases.
+ initialErr := c.errorNS
+ t.Logf("initial error: %v ns", initialErr)
+
+ _, ok = c.Update()
+ if !ok {
+ t.Fatalf("Update not ready")
+ }
+ if c.errorNS.Magnitude() > initialErr.Magnitude() {
+ t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr)
+ }
+
+ _, ok = c.Update()
+ if !ok {
+ t.Fatalf("Update not ready")
+ }
+ if c.errorNS.Magnitude() > initialErr.Magnitude() {
+ t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr)
+ }
+
+ t.Logf("final error: %v ns", c.errorNS)
+ })
+ }
+}
diff --git a/pkg/sentry/time/clock_id.go b/pkg/sentry/time/clock_id.go
new file mode 100644
index 000000000..724f59dd9
--- /dev/null
+++ b/pkg/sentry/time/clock_id.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "strconv"
+)
+
+// ClockID is a Linux clock identifier.
+type ClockID int32
+
+// These are the supported Linux clock identifiers.
+const (
+ Realtime ClockID = iota
+ Monotonic
+)
+
+// String implements fmt.Stringer.String.
+func (c ClockID) String() string {
+ switch c {
+ case Realtime:
+ return "Realtime"
+ case Monotonic:
+ return "Monotonic"
+ default:
+ return strconv.Itoa(int(c))
+ }
+}
diff --git a/pkg/sentry/time/clocks.go b/pkg/sentry/time/clocks.go
new file mode 100644
index 000000000..837e86094
--- /dev/null
+++ b/pkg/sentry/time/clocks.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+// Clocks represents a clock source that contains both a monotonic and realtime
+// clock.
+type Clocks interface {
+ // Update performs an update step, keeping the clocks in sync with the
+ // reference host clocks, and returning the new timekeeping parameters.
+ //
+ // Update should be called at approximately ApproxUpdateInterval.
+ Update() (monotonicParams Parameters, monotonicOk bool, realtimeParams Parameters, realtimeOk bool)
+
+ // GetTime returns the current time in nanoseconds for the given clock.
+ //
+ // Clocks implementations must support at least Monotonic and
+ // Realtime.
+ GetTime(c ClockID) (int64, error)
+}
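+
+// As an illustrative usage sketch (assumed caller code, with sentrytime as an
+// import alias for this package), Update is driven from a periodic ticker and
+// time is read in between:
+//
+//	clocks := sentrytime.NewCalibratedClocks()
+//	go func() {
+//		for range time.Tick(sentrytime.ApproxUpdateInterval) {
+//			clocks.Update()
+//		}
+//	}()
+//	nowNS, err := clocks.GetTime(sentrytime.Monotonic)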
diff --git a/pkg/sentry/time/muldiv_amd64.s b/pkg/sentry/time/muldiv_amd64.s
new file mode 100644
index 000000000..028c6684e
--- /dev/null
+++ b/pkg/sentry/time/muldiv_amd64.s
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// Documentation is available in parameters.go.
+//
+// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
+TEXT ·muldiv64(SB),NOSPLIT,$0-33
+ MOVQ value+0(FP), AX
+ MOVQ multiplier+8(FP), BX
+ MOVQ divisor+16(FP), CX
+
+ // Multiply AX*BX and store result in DX:AX.
+ MULQ BX
+
+ // If divisor <= (value*multiplier) / 2^64, then the division will overflow.
+ //
+ // (value*multiplier) / 2^64 is DX:AX >> 64, or simply DX.
+ CMPQ CX, DX
+ JLE overflow
+
+ // Divide DX:AX by CX.
+ DIVQ CX
+
+ MOVQ AX, result+24(FP)
+ MOVB $1, ok+32(FP)
+ RET
+
+overflow:
+ MOVQ $0, result+24(FP)
+ MOVB $0, ok+32(FP)
+ RET
diff --git a/pkg/sentry/time/muldiv_arm64.s b/pkg/sentry/time/muldiv_arm64.s
new file mode 100644
index 000000000..8afc62d53
--- /dev/null
+++ b/pkg/sentry/time/muldiv_arm64.s
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// Documentation is available in parameters.go.
+//
+// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
+TEXT ·muldiv64(SB),NOSPLIT,$40-33
+ GO_ARGS
+ NO_LOCAL_POINTERS
+ MOVD value+0(FP), R0
+ MOVD multiplier+8(FP), R1
+ MOVD divisor+16(FP), R2
+
+ UMULH R0, R1, R3
+ MUL R0, R1, R4
+
+ CMP R2, R3
+ BHS overflow
+
+ MOVD R3, 8(RSP)
+ MOVD R4, 16(RSP)
+ MOVD R2, 24(RSP)
+ CALL ·divWW(SB)
+ MOVD 32(RSP), R0
+ MOVD R0, result+24(FP)
+ MOVD $1, R0
+ MOVB R0, ok+32(FP)
+ RET
+
+overflow:
+ MOVD ZR, result+24(FP)
+ MOVB ZR, ok+32(FP)
+ RET
diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go
new file mode 100644
index 000000000..65868cb26
--- /dev/null
+++ b/pkg/sentry/time/parameters.go
@@ -0,0 +1,239 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+const (
+ // ApproxUpdateInterval is the approximate interval that parameters
+ // should be updated at.
+ //
+ // Error correction assumes that the next update will occur after this
+ // much time.
+ //
+ // If an update occurs before ApproxUpdateInterval passes, it has no
+ // adverse effect on error correction behavior.
+ //
+ // If an update occurs after ApproxUpdateInterval passes, the clock
+ // will overshoot its error correction target and begin accumulating
+ // error in the other direction.
+ //
+ // If updates occur after more than 2*ApproxUpdateInterval passes, the
+ // clock becomes unstable, accumulating more error than it had
+ // originally. Repeated updates after more than 2*ApproxUpdateInterval
+ // will cause unbounded increases in error.
+ //
+ // These statements assume that the host clock does not change. Actual
+ // error will depend upon host clock changes.
+ //
+ // TODO(b/68779214): make error correction more robust to delayed
+ // updates.
+ ApproxUpdateInterval = 1 * time.Second
+
+ // MaxClockError is the maximum amount of error that the clocks will
+ // try to correct.
+ //
+ // This limit:
+ //
+ // * Puts a limit on cases of otherwise unbounded increases in error.
+ //
+ // * Avoids unreasonably large frequency adjustments required to
+ // correct large errors over a single update interval.
+ MaxClockError = ReferenceNS(ApproxUpdateInterval) / 4
+)
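+
+// For example, if an update computes 10ms of error, the frequency is adjusted
+// so that the error reaches zero one ApproxUpdateInterval (1s) later. If the
+// next update instead arrives after 2s, the correction keeps running past its
+// target and the clock overshoots by roughly 10ms in the other direction.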
+
+// Parameters are the timekeeping parameters needed to compute the current
+// time.
+type Parameters struct {
+ // BaseCycles was the TSC counter value when the time was BaseRef.
+ BaseCycles TSCValue
+
+ // BaseRef is the reference clock time in nanoseconds corresponding to
+ // BaseCycles.
+ BaseRef ReferenceNS
+
+ // Frequency is the frequency of the cycle clock in Hertz.
+ Frequency uint64
+}
+
+// muldiv64 multiplies two 64-bit numbers, then divides the result by another
+// 64-bit number.
+//
+// It requires that the result fit in 64 bits, but doesn't require that
+// intermediate values do; in particular, the result of the multiplication may
+// require 128 bits.
+//
+// It returns !ok if divisor is zero or the result does not fit in 64 bits.
+func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
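+
+// For instance (an illustrative call), converting 1<<40 cycles at 2.5 GHz to
+// nanoseconds requires the 128-bit intermediate product:
+//
+//	ns, ok := muldiv64(1<<40, 1e9, 2500000000)
+//	// 1<<40 * 1e9 overflows uint64, but the quotient fits:
+//	// ns == 439804651110, ok == true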
+
+// ComputeTime calculates the current time from a "now" TSC value.
+//
+// time = ref + (now - base) / f
+func (p Parameters) ComputeTime(nowCycles TSCValue) (int64, bool) {
+ diffCycles := nowCycles - p.BaseCycles
+ if diffCycles < 0 {
+ log.Warningf("now cycles %v < base cycles %v", nowCycles, p.BaseCycles)
+ diffCycles = 0
+ }
+
+ // Overflow "won't ever happen". If diffCycles is the max value
+ // (2^63 - 1), then to overflow,
+ //
+ // frequency <= ((2^63 - 1) * 10^9) / 2^64 ~= 500MHz
+ //
+ // A TSC running at 2GHz takes about 146 years to reach 2^63-1; about 585
+ // years at 500MHz.
+ diffNS, ok := muldiv64(uint64(diffCycles), uint64(time.Second.Nanoseconds()), p.Frequency)
+ return int64(uint64(p.BaseRef) + diffNS), ok
+}
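+
+// For example (mirroring the unit tests), with Parameters{BaseCycles: 10000,
+// BaseRef: ReferenceNS(5e9), Frequency: 10000}, nowCycles = 15000 is 5000
+// cycles past base, i.e. 0.5s at 10 kHz, so ComputeTime returns 5.5e9 ns.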
+
+// errorAdjust returns a new Parameters struct "adjusted" that satisfies:
+//
+// 1. adjusted.ComputeTime(now) = prevParams.ComputeTime(now)
+// * i.e., the current time does not jump.
+//
+// 2. adjusted.ComputeTime(TSC at next update) = newParams.ComputeTime(TSC at next update)
+// * i.e., Any error between prevParams and newParams will be corrected over
+// the course of the next update period.
+//
+// errorAdjust also returns the current clock error.
+//
+// Preconditions:
+// * newParams.BaseCycles >= prevParams.BaseCycles; i.e., TSC must not go
+// backwards.
+// * newParams.BaseCycles <= now; i.e., the new parameters must be computed at
+// or before now.
+func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Parameters, ReferenceNS, error) {
+ if newParams.BaseCycles < prevParams.BaseCycles {
+ // Oh dear! Something is very wrong.
+ return Parameters{}, 0, fmt.Errorf("TSC went backwards in updated clock params: %v < %v", newParams.BaseCycles, prevParams.BaseCycles)
+ }
+ if newParams.BaseCycles > now {
+ return Parameters{}, 0, fmt.Errorf("parameters contain base cycles later than now: %v > %v", newParams.BaseCycles, now)
+ }
+
+ intervalNS := int64(ApproxUpdateInterval.Nanoseconds())
+ nsPerSec := uint64(time.Second.Nanoseconds())
+
+ // Current time as computed by prevParams.
+ oldNowNS, ok := prevParams.ComputeTime(now)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("old now time computation overflowed. params = %+v, now = %v", prevParams, now)
+ }
+
+ // We expect the update ticker to run based on this clock (i.e., it has
+ // been using prevParams and will use the returned adjusted
+ // parameters). Hence it will decide to fire intervalNS from the
+ // current (oldNowNS) "now".
+ nextNS := oldNowNS + intervalNS
+
+ if nextNS <= int64(newParams.BaseRef) {
+ // The next update time already passed before the new
+ // parameters were created! We definitely can't correct the
+ // error by then.
+ return Parameters{}, 0, fmt.Errorf("unable to correct error in single period. oldNowNS = %v, nextNS = %v, p = %v", oldNowNS, nextNS, newParams)
+ }
+
+ // For what TSC value next will newParams.ComputeTime(next) = nextNS?
+ //
+ // Solve ComputeTime for next:
+ //
+ // next = newParams.Frequency * (nextNS - newParams.BaseRef) + newParams.BaseCycles
+ c, ok := muldiv64(newParams.Frequency, uint64(nextNS-int64(newParams.BaseRef)), nsPerSec)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("%v * (%v - %v) / %v overflows", newParams.Frequency, nextNS, newParams.BaseRef, nsPerSec)
+ }
+
+ cycles := TSCValue(c)
+ next := cycles + newParams.BaseCycles
+
+ if next <= now {
+ // The next update time already passed now with the new
+ // parameters! We can't correct the error in a single period.
+ return Parameters{}, 0, fmt.Errorf("unable to correct error in single period. oldNowNS = %v, nextNS = %v, now = %v, next = %v", oldNowNS, nextNS, now, next)
+ }
+
+ // We want to solve for parameters that satisfy:
+ //
+ // adjusted.ComputeTime(now) = oldNowNS
+ //
+ // adjusted.ComputeTime(next) = nextNS
+ //
+ // i.e., the current time does not change, but by the time we reach
+ // next we reach the same time as newParams.
+
+ // We choose to keep BaseCycles fixed.
+ adjusted := Parameters{
+ BaseCycles: newParams.BaseCycles,
+ }
+
+ // We want a slope such that time goes from oldNowNS to nextNS when
+ // we reach next.
+ //
+ // In other words, cycles should increase by next - now in the next
+ // interval.
+
+ cycles = next - now
+ ns := intervalNS
+
+ // adjusted.Frequency = cycles / ns
+ adjusted.Frequency, ok = muldiv64(uint64(cycles), nsPerSec, uint64(ns))
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("(%v - %v) * %v / %v overflows", next, now, nsPerSec, ns)
+ }
+
+ // Now choose a base reference such that the current time remains the
+ // same. Note that this is just ComputeTime, solving for BaseRef:
+ //
+ // oldNowNS = BaseRef + (now - BaseCycles) / Frequency
+ // BaseRef = oldNowNS - (now - BaseCycles) / Frequency
+ diffNS, ok := muldiv64(uint64(now-adjusted.BaseCycles), nsPerSec, adjusted.Frequency)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("(%v - %v) * %v / %v overflows", now, adjusted.BaseCycles, nsPerSec, adjusted.Frequency)
+ }
+
+ adjusted.BaseRef = ReferenceNS(oldNowNS - int64(diffNS))
+
+ // The error is the difference between the current time and what the
+ // new parameters say the current time should be.
+ newNowNS, ok := newParams.ComputeTime(now)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("new now time computation overflowed. params = %+v, now = %v", newParams, now)
+ }
+
+ errorNS := ReferenceNS(oldNowNS - newNowNS)
+
+ return adjusted, errorNS, nil
+}
+
+// logErrorAdjustment logs the clock error and associated error correction
+// frequency adjustment.
+//
+// The log level is determined by the error severity.
+func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) {
+ fn := log.Debugf
+ if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() {
+ fn = log.Warningf
+ } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() {
+ fn = log.Infof
+ }
+
+ fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency)
+}
diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go
new file mode 100644
index 000000000..0ce1257f6
--- /dev/null
+++ b/pkg/sentry/time/parameters_test.go
@@ -0,0 +1,501 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "math"
+ "testing"
+ "time"
+)
+
+func TestParametersComputeTime(t *testing.T) {
+ testCases := []struct {
+ name string
+ params Parameters
+ now TSCValue
+ want int64
+ }{
+ {
+ // Now is the same as the base cycles.
+ name: "base-cycles",
+ params: Parameters{
+ BaseCycles: 10000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 10000,
+ want: 5000 * time.Millisecond.Nanoseconds(),
+ },
+ {
+ // Now is behind the base cycles. Time is frozen.
+ name: "backwards",
+ params: Parameters{
+ BaseCycles: 10000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 9000,
+ want: 5000 * time.Millisecond.Nanoseconds(),
+ },
+ {
+ // Now is ahead of the base cycles.
+ name: "ahead",
+ params: Parameters{
+ BaseCycles: 10000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 15000,
+ want: 5500 * time.Millisecond.Nanoseconds(),
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got, ok := tc.params.ComputeTime(tc.now)
+ if !ok {
+ t.Errorf("ComputeTime ok got %v want true", got)
+ }
+ if got != tc.want {
+ t.Errorf("ComputeTime got %+v want %+v", got, tc.want)
+ }
+ })
+ }
+}
+
+func TestParametersErrorAdjust(t *testing.T) {
+ testCases := []struct {
+ name string
+ oldParams Parameters
+ now TSCValue
+ newParams Parameters
+ want Parameters
+ errorNS ReferenceNS
+ wantErr bool
+ }{
+ {
+ // newParams are perfectly continuous with oldParams
+ // and don't need adjustment.
+ name: "continuous",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 50000,
+ newParams: Parameters{
+ BaseCycles: 50000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ want: Parameters{
+ BaseCycles: 50000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ },
+ {
+ // Same as "continuous", but with now ahead of
+ // newParams.BaseCycles. The result is the same as
+ // there is no error to correct.
+ name: "continuous-nowdiff",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 60000,
+ newParams: Parameters{
+ BaseCycles: 50000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ want: Parameters{
+ BaseCycles: 50000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ },
+ {
+ // errorAdjust bails out if the TSC goes backwards.
+ name: "tsc-backwards",
+ oldParams: Parameters{
+ BaseCycles: 10000,
+ BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 9000,
+ newParams: Parameters{
+ BaseCycles: 9000,
+ BaseRef: ReferenceNS(1100 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ wantErr: true,
+ },
+ {
+ // errorAdjust bails out if new params are from after now.
+ name: "params-after-now",
+ oldParams: Parameters{
+ BaseCycles: 10000,
+ BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 11000,
+ newParams: Parameters{
+ BaseCycles: 12000,
+ BaseRef: ReferenceNS(1200 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ wantErr: true,
+ },
+ {
+ // Host clock sped up.
+ name: "speed-up",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 45000,
+ // Host frequency changed to 9000 immediately after
+ // oldParams was returned.
+ newParams: Parameters{
+ BaseCycles: 45000,
+ // From oldParams, we think ref = 4.5s at cycles = 45000.
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 9000,
+ },
+ want: Parameters{
+ BaseCycles: 45000,
+ BaseRef: ReferenceNS(4500 * time.Millisecond.Nanoseconds()),
+ // We must decrease the new frequency by 50% to
+ // correct 0.5s of error in 1s
+ // (ApproxUpdateInterval).
+ Frequency: 4500,
+ },
+ errorNS: ReferenceNS(-500 * time.Millisecond.Nanoseconds()),
+ },
+ {
+ // Host clock sped up, with now ahead of newParams.
+ name: "speed-up-nowdiff",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 50000,
+ // Host frequency changed to 9000 immediately after
+ // oldParams was returned.
+ newParams: Parameters{
+ BaseCycles: 45000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 9000,
+ },
+ // nextRef = 6000ms
+ // nextCycles = 9000 * (6000ms - 5000ms) + 45000
+ // nextCycles = 9000 * (1s) + 45000
+ // nextCycles = 54000
+ // f = (54000 - 50000) / 1s = 4000
+ //
+ // ref = 5000ms - (50000 - 45000) / 4000
+ // ref = 3.75s
+ want: Parameters{
+ BaseCycles: 45000,
+ BaseRef: ReferenceNS(3750 * time.Millisecond.Nanoseconds()),
+ Frequency: 4000,
+ },
+ // oldNow = 50000 / 10000 = 5s
+ // newNow = (50000 - 45000) / 9000 + 5s = 5.555s
+ errorNS: ReferenceNS((5000*time.Millisecond - 5555555555).Nanoseconds()),
+ },
+ {
+ // Host clock sped up. The new parameters are so far
+ // ahead that the next update time already passed.
+ name: "speed-up-uncorrectable-baseref",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 50000,
+ // Host frequency changed to 5000 immediately after
+ // oldParams was returned.
+ newParams: Parameters{
+ BaseCycles: 45000,
+ BaseRef: ReferenceNS(9000 * time.Millisecond.Nanoseconds()),
+ Frequency: 5000,
+ },
+ // The next update should be at 10s, but newParams
+ // already passed 6s. Thus it is impossible to correct
+ // the clock by then.
+ wantErr: true,
+ },
+ {
+ // Host clock sped up. The new parameters are moving so
+ // fast that the next update should be before now.
+ name: "speed-up-uncorrectable-frequency",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 55000,
+ // Host frequency changed to 7500 immediately after
+ // oldParams was returned.
+ newParams: Parameters{
+ BaseCycles: 45000,
+ BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()),
+ Frequency: 7500,
+ },
+ // The next update should be at 6.5s, but newParams are
+ // so far ahead and fast that they reach 6.5s at cycle
+			// 48750, which is before now! Thus it is impossible to
+ // correct the clock by then.
+ wantErr: true,
+ },
+ {
+ // Host clock slowed down.
+ name: "slow-down",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 55000,
+ // Host frequency changed to 11000 immediately after
+ // oldParams was returned.
+ newParams: Parameters{
+ BaseCycles: 55000,
+ // From oldParams, we think ref = 5.5s at cycles = 55000.
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 11000,
+ },
+ want: Parameters{
+ BaseCycles: 55000,
+ BaseRef: ReferenceNS(5500 * time.Millisecond.Nanoseconds()),
+ // We must increase the new frequency by 50% to
+ // correct 0.5s of error in 1s
+ // (ApproxUpdateInterval).
+ Frequency: 16500,
+ },
+ errorNS: ReferenceNS(500 * time.Millisecond.Nanoseconds()),
+ },
+ {
+ // Host clock slowed down, with now ahead of newParams.
+ name: "slow-down-nowdiff",
+ oldParams: Parameters{
+ BaseCycles: 0,
+ BaseRef: 0,
+ Frequency: 10000,
+ },
+ now: 60000,
+ // Host frequency changed to 11000 immediately after
+ // oldParams was returned.
+ newParams: Parameters{
+ BaseCycles: 55000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 11000,
+ },
+ // nextRef = 7000ms
+ // nextCycles = 11000 * (7000ms - 5000ms) + 55000
+ // nextCycles = 11000 * (2000ms) + 55000
+ // nextCycles = 77000
+ // f = (77000 - 60000) / 1s = 17000
+ //
+ // ref = 6000ms - (60000 - 55000) / 17000
+ // ref = 5705882353ns
+ want: Parameters{
+ BaseCycles: 55000,
+ BaseRef: ReferenceNS(5705882353),
+ Frequency: 17000,
+ },
+			// oldNow = 60000 / 10000 = 6s
+ // newNow = (60000 - 55000) / 11000 + 5s = 5.4545s
+ errorNS: ReferenceNS((6*time.Second - 5454545454).Nanoseconds()),
+ },
+ {
+ // Host time went backwards.
+ name: "time-backwards",
+ oldParams: Parameters{
+ BaseCycles: 50000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 60000,
+ newParams: Parameters{
+ BaseCycles: 60000,
+ // From oldParams, we think ref = 6s at cycles = 60000.
+ BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ want: Parameters{
+ BaseCycles: 60000,
+ BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()),
+ // We must increase the frequency by 200% to
+ // correct 2s of error in 1s
+ // (ApproxUpdateInterval).
+ Frequency: 30000,
+ },
+ errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()),
+ },
+ {
+ // Host time went backwards, with now ahead of newParams.
+ name: "time-backwards-nowdiff",
+ oldParams: Parameters{
+ BaseCycles: 50000,
+ BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ now: 65000,
+ // nextRef = 7500ms
+ // nextCycles = 10000 * (7500ms - 4000ms) + 60000
+ // nextCycles = 10000 * (3500ms) + 60000
+ // nextCycles = 95000
+ // f = (95000 - 65000) / 1s = 30000
+ //
+ // ref = 6500ms - (65000 - 60000) / 30000
+			// ref = 6333333334ns (the subtracted term is truncated)
+ newParams: Parameters{
+ BaseCycles: 60000,
+ BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()),
+ Frequency: 10000,
+ },
+ want: Parameters{
+ BaseCycles: 60000,
+ BaseRef: ReferenceNS(6333333334),
+ Frequency: 30000,
+ },
+			// oldNow = 65000 / 10000 = 6.5s
+ // newNow = (65000 - 60000) / 10000 + 4s = 4.5s
+ errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()),
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got, errorNS, err := errorAdjust(tc.oldParams, tc.newParams, tc.now)
+ if err != nil && !tc.wantErr {
+ t.Errorf("err got %v want nil", err)
+ } else if err == nil && tc.wantErr {
+ t.Errorf("err got nil want non-nil")
+ }
+
+ if got != tc.want {
+ t.Errorf("Parameters got %+v want %+v", got, tc.want)
+ }
+ if errorNS != tc.errorNS {
+ t.Errorf("errorNS got %v want %v", errorNS, tc.errorNS)
+ }
+ })
+ }
+}
+
+func testMuldiv(t *testing.T, v uint64) {
+ for i := uint64(1); i <= 1000000; i++ {
+ mult := uint64(1000000000)
+ div := i * mult
+ res, ok := muldiv64(v, mult, div)
+ if !ok {
+ t.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div)
+ }
+ if want := v / i; res != want {
+ t.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want)
+ }
+ }
+}
+
+func TestMulDiv(t *testing.T) {
+ testMuldiv(t, math.MaxUint64)
+ for i := int64(-10); i <= 10; i++ {
+ testMuldiv(t, uint64(i))
+ }
+}
+
+func TestMulDivZero(t *testing.T) {
+ if r, ok := muldiv64(2, 4, 0); ok {
+ t.Errorf("muldiv64(2, 4, 0) got %d, ok want !ok", r)
+ }
+
+ if r, ok := muldiv64(0, 0, 0); ok {
+ t.Errorf("muldiv64(0, 0, 0) got %d, ok want !ok", r)
+ }
+}
+
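+// TestMulDivOverflow tests that muldiv64 reports overflow when the final
+// result (not merely the 128-bit intermediate product) exceeds 64 bits. Case
+// names give the magnitude of the would-be result.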
+func TestMulDivOverflow(t *testing.T) {
+ testCases := []struct {
+ name string
+ val uint64
+ mult uint64
+ div uint64
+ ok bool
+ ret uint64
+ }{
+ {
+ name: "2^62",
+ val: 1 << 63,
+ mult: 4,
+ div: 8,
+ ok: true,
+ ret: 1 << 62,
+ },
+ {
+ name: "2^64-1",
+ val: 0xffffffffffffffff,
+ mult: 1,
+ div: 1,
+ ok: true,
+ ret: 0xffffffffffffffff,
+ },
+ {
+ name: "2^64",
+ val: 1 << 63,
+ mult: 4,
+ div: 2,
+ ok: false,
+ },
+ {
+ name: "2^125",
+ val: 1 << 63,
+ mult: 1 << 63,
+ div: 2,
+ ok: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ r, ok := muldiv64(tc.val, tc.mult, tc.div)
+ if ok != tc.ok {
+ t.Errorf("ok got %v want %v", ok, tc.ok)
+ }
+ if tc.ok && r != tc.ret {
+ t.Errorf("ret got %v want %v", r, tc.ret)
+ }
+ })
+ }
+}
+
+func BenchmarkMuldiv64(b *testing.B) {
+ var v uint64 = math.MaxUint64
+ for i := uint64(1); i <= 1000000; i++ {
+ mult := uint64(1000000000)
+ div := i * mult
+ res, ok := muldiv64(v, mult, div)
+ if !ok {
+ b.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div)
+ }
+ if want := v / i; res != want {
+ b.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want)
+ }
+ }
+}
diff --git a/pkg/sentry/time/sampler.go b/pkg/sentry/time/sampler.go
new file mode 100644
index 000000000..4ac9c4474
--- /dev/null
+++ b/pkg/sentry/time/sampler.go
@@ -0,0 +1,225 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "errors"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+const (
+	// defaultOverheadCycles is the default estimated syscall overhead in TSC cycles.
+ // It is further refined as syscalls are made.
+ defaultOverheadCycles = 1 * 1000
+
+ // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles.
+ maxOverheadCycles = 100 * defaultOverheadCycles
+
+ // maxSampleLoops is the maximum number of times to try to get a clock sample
+ // under the expected overhead.
+ maxSampleLoops = 5
+
+ // maxSamples is the maximum number of samples to collect.
+ maxSamples = 10
+)
+
+// errOverheadTooHigh is returned from sampler.Sample if the syscall
+// overhead is too high.
+var errOverheadTooHigh = errors.New("time syscall overhead exceeds maximum")
+
+// TSCValue is a value from the TSC.
+type TSCValue int64
+
+// Rdtsc reads the TSC.
+//
+// Intel SDM, Vol 3, Ch 17.15:
+// "The RDTSC instruction reads the time-stamp counter and is guaranteed to
+// return a monotonically increasing unique value whenever executed, except for
+// a 64-bit counter wraparound. Intel guarantees that the time-stamp counter
+// will not wraparound within 10 years after being reset."
+//
+// We use int64, so we have 5 years before wrap-around.
+func Rdtsc() TSCValue
+
+// ReferenceNS are nanoseconds in the reference clock domain.
+// int64 gives us ~290 years before this overflows.
+type ReferenceNS int64
+
+// Magnitude returns the absolute value of r.
+func (r ReferenceNS) Magnitude() ReferenceNS {
+ if r < 0 {
+ return -r
+ }
+ return r
+}
+
+// cycleClock is a TSC-based cycle clock.
+type cycleClock interface {
+ // Cycles returns a count value from the TSC.
+ Cycles() TSCValue
+}
+
+// tscCycleClock is a cycleClock that uses the real TSC.
+type tscCycleClock struct{}
+
+// Cycles implements cycleClock.Cycles.
+func (tscCycleClock) Cycles() TSCValue {
+ return Rdtsc()
+}
+
+// sample contains a sample from the reference clock, with TSC values from
+// before and after the reference clock value was captured.
+type sample struct {
+ before TSCValue
+ after TSCValue
+ ref ReferenceNS
+}
+
+// Overhead returns the sample overhead in TSC cycles.
+func (s *sample) Overhead() TSCValue {
+ return s.after - s.before
+}
+
+// referenceClocks collects individual samples from a reference clock ID and
+// TSC.
+type referenceClocks interface {
+ cycleClock
+
+ // Sample returns a single sample from the reference clock ID.
+ Sample(c ClockID) (sample, error)
+}
+
+// sampler collects samples from a reference system clock, minimizing
+// the overhead in each sample.
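+//
+// Illustrative use, with names as defined in this package:
+//
+//	s := newSampler(c) // c is a reference ClockID
+//	if err := s.Sample(); err == nil {
+//		if first, last, ok := s.Range(); ok {
+//			// The TSC frequency can be estimated from the cycles
+//			// elapsed (last.after - first.before) over the
+//			// reference nanoseconds elapsed (last.ref - first.ref).
+//		}
+//	}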
+type sampler struct {
+ // clockID is the reference clock ID (e.g., CLOCK_MONOTONIC).
+ clockID ClockID
+
+ // clocks provides raw samples.
+ clocks referenceClocks
+
+ // overhead is the estimated sample overhead in TSC cycles.
+ overhead TSCValue
+
+ // samples is a ring buffer of the latest samples collected.
+ samples []sample
+}
+
+// newSampler creates a sampler for clockID.
+func newSampler(c ClockID) *sampler {
+ return &sampler{
+ clockID: c,
+ clocks: syscallTSCReferenceClocks{},
+ overhead: defaultOverheadCycles,
+ }
+}
+
+// Reset discards previously collected clock samples.
+func (s *sampler) Reset() {
+ s.overhead = defaultOverheadCycles
+ s.samples = []sample{}
+}
+
+// lowOverheadSample returns a reference clock sample with minimized syscall overhead.
+func (s *sampler) lowOverheadSample() (sample, error) {
+ for {
+ for i := 0; i < maxSampleLoops; i++ {
+ samp, err := s.clocks.Sample(s.clockID)
+ if err != nil {
+ return sample{}, err
+ }
+
+ if samp.before > samp.after {
+ log.Warningf("TSC went backwards: %v > %v", samp.before, samp.after)
+ continue
+ }
+
+ if samp.Overhead() <= s.overhead {
+ return samp, nil
+ }
+ }
+
+ // Couldn't get a sample with the current overhead. Increase it.
+ newOverhead := 2 * s.overhead
+ if newOverhead > maxOverheadCycles {
+ // We'll give it one more shot with the max overhead.
+
+ if s.overhead == maxOverheadCycles {
+ return sample{}, errOverheadTooHigh
+ }
+
+ newOverhead = maxOverheadCycles
+ }
+
+ s.overhead = newOverhead
+ log.Debugf("Time: Adjusting syscall overhead up to %v", s.overhead)
+ }
+}
+
+// Sample collects a reference clock sample.
+func (s *sampler) Sample() error {
+ sample, err := s.lowOverheadSample()
+ if err != nil {
+ return err
+ }
+
+ s.samples = append(s.samples, sample)
+ if len(s.samples) > maxSamples {
+ s.samples = s.samples[1:]
+ }
+
+ // If the 4 most recent samples all have an overhead less than half the
+ // expected overhead, adjust downwards.
+ if len(s.samples) < 4 {
+ return nil
+ }
+
+ for _, sample := range s.samples[len(s.samples)-4:] {
+ if sample.Overhead() > s.overhead/2 {
+ return nil
+ }
+ }
+
+ s.overhead -= s.overhead / 8
+ log.Debugf("Time: Adjusting syscall overhead down to %v", s.overhead)
+
+ return nil
+}
+
+// Syscall returns the current raw reference time without storing TSC
+// samples.
+func (s *sampler) Syscall() (ReferenceNS, error) {
+ sample, err := s.clocks.Sample(s.clockID)
+ if err != nil {
+ return 0, err
+ }
+
+ return sample.ref, nil
+}
+
+// Cycles returns a raw TSC value.
+func (s *sampler) Cycles() TSCValue {
+ return s.clocks.Cycles()
+}
+
+// Range returns the widest range of clock samples available.
+func (s *sampler) Range() (sample, sample, bool) {
+ if len(s.samples) < 2 {
+ return sample{}, sample{}, false
+ }
+
+ return s.samples[0], s.samples[len(s.samples)-1], true
+}
diff --git a/pkg/sentry/time/sampler_test.go b/pkg/sentry/time/sampler_test.go
new file mode 100644
index 000000000..3e70a1134
--- /dev/null
+++ b/pkg/sentry/time/sampler_test.go
@@ -0,0 +1,183 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "errors"
+ "testing"
+)
+
+// errNoSamples is returned when testReferenceClocks runs out of samples.
+var errNoSamples = errors.New("no samples available")
+
+// testReferenceClocks returns a preset list of samples and cycle counts.
+type testReferenceClocks struct {
+ samples []sample
+ cycles []TSCValue
+}
+
+// Sample implements referenceClocks.Sample, returning the next sample in the list.
+func (t *testReferenceClocks) Sample(_ ClockID) (sample, error) {
+ if len(t.samples) == 0 {
+ return sample{}, errNoSamples
+ }
+
+ s := t.samples[0]
+ if len(t.samples) == 1 {
+ t.samples = nil
+ } else {
+ t.samples = t.samples[1:]
+ }
+
+ return s, nil
+}
+
+// Cycles implements referenceClocks.Cycles, returning the next TSCValue in the list.
+func (t *testReferenceClocks) Cycles() TSCValue {
+ if len(t.cycles) == 0 {
+ return 0
+ }
+
+ c := t.cycles[0]
+ if len(t.cycles) == 1 {
+ t.cycles = nil
+ } else {
+ t.cycles = t.cycles[1:]
+ }
+
+ return c
+}
+
+// newTestSampler returns a sampler that collects samples from
+// the given sample list and cycle counts from the given cycle list.
+func newTestSampler(samples []sample, cycles []TSCValue) *sampler {
+ return &sampler{
+ clocks: &testReferenceClocks{
+ samples: samples,
+ cycles: cycles,
+ },
+ overhead: defaultOverheadCycles,
+ }
+}
+
+// generateSamples generates n samples with the given overhead.
+func generateSamples(n int, overhead TSCValue) []sample {
+ samples := []sample{{before: 1000000, after: 1000000 + overhead, ref: 100}}
+ for i := 0; i < n-1; i++ {
+ prev := samples[len(samples)-1]
+ samples = append(samples, sample{
+ before: prev.before + 1000000,
+ after: prev.after + 1000000,
+ ref: prev.ref + 100,
+ })
+ }
+ return samples
+}
+
+// TestSample ensures that samples can be collected.
+func TestSample(t *testing.T) {
+ testCases := []struct {
+ name string
+ samples []sample
+ err error
+ }{
+ {
+ name: "basic",
+ samples: []sample{
+ {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100},
+ },
+ err: nil,
+ },
+ {
+			// A sample with a backwards TSC is ignored.
+			// The sampler should retry and get errNoSamples.
+ name: "backwards-tsc-ignored",
+ samples: []sample{
+ {before: 100000, after: 90000, ref: 100},
+ },
+ err: errNoSamples,
+ },
+ {
+			// A sample far above the expected overhead is skipped.
+			// The sampler should retry and get errNoSamples.
+ name: "reject-overhead",
+ samples: []sample{
+ {before: 100000, after: 100000 + 5*defaultOverheadCycles, ref: 100},
+ },
+ err: errNoSamples,
+ },
+ {
+			// Overhead above the maximum allowed is rejected.
+ name: "over-max-overhead",
+			// Generate a bunch of samples. The sampler needs a
+			// while to ramp up its expected overhead.
+ samples: generateSamples(100, 2*maxOverheadCycles),
+ err: errOverheadTooHigh,
+ },
+ {
+			// Overhead exactly at the maximum is allowed.
+ name: "max-overhead",
+			// Generate a bunch of samples. The sampler needs a
+			// while to ramp up its expected overhead.
+ samples: generateSamples(100, maxOverheadCycles),
+ err: nil,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := newTestSampler(tc.samples, nil)
+ err := s.Sample()
+ if err != tc.err {
+ t.Errorf("Sample err got %v want %v", err, tc.err)
+ }
+ })
+ }
+}
+
+// TestOutliersIgnored tests that the sampler ignores samples with very high
+// overhead.
+func TestOutliersIgnored(t *testing.T) {
+ s := newTestSampler([]sample{
+ {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100},
+ {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200},
+ {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300},
+ {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400},
+ {before: 500000, after: 500000 + 5*defaultOverheadCycles, ref: 500}, // Ignored
+ {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600},
+ {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700},
+ }, nil)
+
+ // Collect 5 samples.
+ for i := 0; i < 5; i++ {
+ err := s.Sample()
+ if err != nil {
+ t.Fatalf("Unexpected error while sampling: %v", err)
+ }
+ }
+
+ oldest, newest, ok := s.Range()
+ if !ok {
+ t.Fatalf("Range not ok")
+ }
+
+ if oldest.ref != 100 {
+ t.Errorf("oldest.ref got %v want %v", oldest.ref, 100)
+ }
+
+ // We skipped the high-overhead sample.
+ if newest.ref != 600 {
+ t.Errorf("newest.ref got %v want %v", newest.ref, 600)
+ }
+}
diff --git a/pkg/sentry/time/sampler_unsafe.go b/pkg/sentry/time/sampler_unsafe.go
new file mode 100644
index 000000000..e76180217
--- /dev/null
+++ b/pkg/sentry/time/sampler_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// syscallTSCReferenceClocks is the standard referenceClocks, collecting
+// samples using CLOCK_GETTIME and RDTSC.
+type syscallTSCReferenceClocks struct {
+ tscCycleClock
+}
+
+// Sample implements referenceClocks.Sample.
+func (syscallTSCReferenceClocks) Sample(c ClockID) (sample, error) {
+ var s sample
+
+ s.before = Rdtsc()
+
+	// Don't call clockGettime here: the extra function call may grow the
+	// stack via morestack.
+ var ts syscall.Timespec
+ _, _, e := syscall.RawSyscall(syscall.SYS_CLOCK_GETTIME, uintptr(c), uintptr(unsafe.Pointer(&ts)), 0)
+ if e != 0 {
+ return sample{}, e
+ }
+
+ s.after = Rdtsc()
+ s.ref = ReferenceNS(ts.Nano())
+
+ return s, nil
+}
+
+// clockGettime calls SYS_CLOCK_GETTIME, returning time in nanoseconds.
+func clockGettime(c ClockID) (ReferenceNS, error) {
+ var ts syscall.Timespec
+ _, _, e := syscall.RawSyscall(syscall.SYS_CLOCK_GETTIME, uintptr(c), uintptr(unsafe.Pointer(&ts)), 0)
+ if e != 0 {
+ return 0, e
+ }
+
+ return ReferenceNS(ts.Nano()), nil
+}
diff --git a/pkg/sentry/time/tsc_amd64.s b/pkg/sentry/time/tsc_amd64.s
new file mode 100644
index 000000000..6a8eed664
--- /dev/null
+++ b/pkg/sentry/time/tsc_amd64.s
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+TEXT ·Rdtsc(SB),NOSPLIT,$0-8
+	// N.B. We need LFENCE on Intel; AMD is more complicated.
+	// Modern AMD CPUs with modern kernels make LFENCE behave like it does
+	// on Intel via MSR_F10H_DECFG_LFENCE_SERIALIZE_BIT. MFENCE is
+	// otherwise needed on AMD.
+ LFENCE
+ RDTSC
+ SHLQ $32, DX
+ ADDQ DX, AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/sentry/time/tsc_arm64.s b/pkg/sentry/time/tsc_arm64.s
new file mode 100644
index 000000000..da9fa4112
--- /dev/null
+++ b/pkg/sentry/time/tsc_arm64.s
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+TEXT ·Rdtsc(SB),NOSPLIT,$0-8
+ // Get the virtual counter.
+ ISB $15
+ WORD $0xd53be040 //MRS CNTVCT_EL0, R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
new file mode 100644
index 000000000..5d4aa3a63
--- /dev/null
+++ b/pkg/sentry/unimpl/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+proto_library(
+ name = "unimplemented_syscall",
+ srcs = ["unimplemented_syscall.proto"],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/sentry/arch:registers_proto"],
+)
+
+go_library(
+ name = "unimpl",
+ srcs = ["events.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/context",
+ "//pkg/log",
+ ],
+)
diff --git a/pkg/sentry/unimpl/events.go b/pkg/sentry/unimpl/events.go
new file mode 100644
index 000000000..73ed9372f
--- /dev/null
+++ b/pkg/sentry/unimpl/events.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unimpl contains an interface to emit events about unimplemented
+// features.
+package unimpl
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// contextID is the events package's type for context.Context.Value keys.
+type contextID int
+
+const (
+	// CtxEvents is a Context.Value key for an Events.
+ CtxEvents contextID = iota
+)
+
+// Events defines methods to emit events about unsupported features.
+type Events interface {
+ EmitUnimplementedEvent(context.Context)
+}
+
+// EmitUnimplementedEvent emits an unsupported syscall event to the context.
+func EmitUnimplementedEvent(ctx context.Context) {
+ e := ctx.Value(CtxEvents)
+ if e == nil {
+ log.Warningf("Context.Value(CtxEvents) not present, unimplemented syscall event not reported.")
+ return
+ }
+ e.(Events).EmitUnimplementedEvent(ctx)
+}
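+
+// A minimal Events implementation might look like the following (illustrative
+// only; the kernel provides the real implementation):
+//
+//	type logEvents struct{}
+//
+//	func (logEvents) EmitUnimplementedEvent(ctx context.Context) {
+//		log.Warningf("unimplemented feature used")
+//	}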
diff --git a/pkg/sentry/unimpl/unimplemented_syscall.proto b/pkg/sentry/unimpl/unimplemented_syscall.proto
new file mode 100644
index 000000000..0d7a94be7
--- /dev/null
+++ b/pkg/sentry/unimpl/unimplemented_syscall.proto
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+import "pkg/sentry/arch/registers.proto";
+
+message UnimplementedSyscall {
+ // Task ID.
+ int32 tid = 1;
+
+ // Registers at the time of the call.
+ Registers registers = 2;
+}
diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD
new file mode 100644
index 000000000..7467e6398
--- /dev/null
+++ b/pkg/sentry/uniqueid/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "uniqueid",
+ srcs = ["context.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/socket/unix/transport",
+ ],
+)
diff --git a/pkg/sentry/uniqueid/context.go b/pkg/sentry/uniqueid/context.go
new file mode 100644
index 000000000..1fb884a90
--- /dev/null
+++ b/pkg/sentry/uniqueid/context.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uniqueid defines context.Context keys for obtaining system-wide
+// unique identifiers.
+package uniqueid
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// contextID is the kernel package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxGlobalUniqueID is a Context.Value key for a system-wide
+ // unique identifier.
+ CtxGlobalUniqueID contextID = iota
+
+ // CtxGlobalUniqueIDProvider is a Context.Value key for a
+ // system-wide unique identifier generator.
+ CtxGlobalUniqueIDProvider
+
+ // CtxInotifyCookie is a Context.Value key for a unique inotify
+ // event cookie.
+ CtxInotifyCookie
+)
+
+// GlobalFromContext returns a system-wide unique identifier from ctx.
+func GlobalFromContext(ctx context.Context) uint64 {
+ return ctx.Value(CtxGlobalUniqueID).(uint64)
+}
+
+// GlobalProviderFromContext returns a system-wide unique identifier provider
+// from ctx.
+func GlobalProviderFromContext(ctx context.Context) transport.UniqueIDProvider {
+ return ctx.Value(CtxGlobalUniqueIDProvider).(transport.UniqueIDProvider)
+}
+
+// InotifyCookie generates a unique inotify event cookie from ctx.
+func InotifyCookie(ctx context.Context) uint32 {
+ return ctx.Value(CtxInotifyCookie).(uint32)
+}
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
new file mode 100644
index 000000000..099315613
--- /dev/null
+++ b/pkg/sentry/usage/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "usage",
+ srcs = [
+ "cpu.go",
+ "io.go",
+ "memory.go",
+ "memory_unsafe.go",
+ "usage.go",
+ ],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/bits",
+ "//pkg/memutil",
+ "//pkg/sync",
+ ],
+)
diff --git a/pkg/sentry/usage/cpu.go b/pkg/sentry/usage/cpu.go
new file mode 100644
index 000000000..bfc282d69
--- /dev/null
+++ b/pkg/sentry/usage/cpu.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "time"
+)
+
+// CPUStats contains the subset of struct rusage fields that relate to CPU
+// scheduling.
+//
+// +stateify savable
+type CPUStats struct {
+ // UserTime is the amount of time spent executing application code.
+ UserTime time.Duration
+
+ // SysTime is the amount of time spent executing sentry code.
+ SysTime time.Duration
+
+ // VoluntarySwitches is the number of times control has been voluntarily
+ // ceded due to blocking, etc.
+ VoluntarySwitches uint64
+
+ // InvoluntarySwitches (struct rusage::ru_nivcsw) is unsupported, since
+ // "preemptive" scheduling is managed by the Go runtime, which doesn't
+ // provide this information.
+}
+
+// Accumulate adds s2 to s.
+func (s *CPUStats) Accumulate(s2 CPUStats) {
+ s.UserTime += s2.UserTime
+ s.SysTime += s2.SysTime
+ s.VoluntarySwitches += s2.VoluntarySwitches
+}
diff --git a/pkg/sentry/usage/io.go b/pkg/sentry/usage/io.go
new file mode 100644
index 000000000..dfcd3a49d
--- /dev/null
+++ b/pkg/sentry/usage/io.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "sync/atomic"
+)
+
+// IO contains I/O-related statistics.
+//
+// +stateify savable
+type IO struct {
+ // CharsRead is the number of bytes read by read syscalls.
+ CharsRead uint64
+
+ // CharsWritten is the number of bytes written by write syscalls.
+ CharsWritten uint64
+
+ // ReadSyscalls is the number of read syscalls.
+ ReadSyscalls uint64
+
+ // WriteSyscalls is the number of write syscalls.
+ WriteSyscalls uint64
+
+	// The following counters are only meaningful when the Sentry has an
+	// internal pagecache.
+
+ // BytesRead is the number of bytes actually read into pagecache.
+ BytesRead uint64
+
+ // BytesWritten is the number of bytes actually written from pagecache.
+ BytesWritten uint64
+
+ // BytesWriteCancelled is the number of bytes not written out due to
+ // truncation.
+ BytesWriteCancelled uint64
+}
+
+// AccountReadSyscall does the accounting for a read syscall.
+func (i *IO) AccountReadSyscall(bytes int64) {
+ atomic.AddUint64(&i.ReadSyscalls, 1)
+ if bytes > 0 {
+ atomic.AddUint64(&i.CharsRead, uint64(bytes))
+ }
+}
+
+// AccountWriteSyscall does the accounting for a write syscall.
+func (i *IO) AccountWriteSyscall(bytes int64) {
+ atomic.AddUint64(&i.WriteSyscalls, 1)
+ if bytes > 0 {
+ atomic.AddUint64(&i.CharsWritten, uint64(bytes))
+ }
+}
+
+// AccountReadIO does the accounting for a read IO from the file system.
+func (i *IO) AccountReadIO(bytes int64) {
+ if bytes > 0 {
+ atomic.AddUint64(&i.BytesRead, uint64(bytes))
+ }
+}
+
+// AccountWriteIO does the accounting for a write IO into the file system.
+func (i *IO) AccountWriteIO(bytes int64) {
+ if bytes > 0 {
+ atomic.AddUint64(&i.BytesWritten, uint64(bytes))
+ }
+}
+
+// Accumulate adds io's stats to i.
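+//
+// Each counter is read atomically, but the set of counters is not read as one
+// atomic snapshot, so Accumulate may mix values from slightly different
+// instants.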
+func (i *IO) Accumulate(io *IO) {
+ atomic.AddUint64(&i.CharsRead, atomic.LoadUint64(&io.CharsRead))
+ atomic.AddUint64(&i.CharsWritten, atomic.LoadUint64(&io.CharsWritten))
+ atomic.AddUint64(&i.ReadSyscalls, atomic.LoadUint64(&io.ReadSyscalls))
+ atomic.AddUint64(&i.WriteSyscalls, atomic.LoadUint64(&io.WriteSyscalls))
+ atomic.AddUint64(&i.BytesRead, atomic.LoadUint64(&io.BytesRead))
+ atomic.AddUint64(&i.BytesWritten, atomic.LoadUint64(&io.BytesWritten))
+ atomic.AddUint64(&i.BytesWriteCancelled, atomic.LoadUint64(&io.BytesWriteCancelled))
+}
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
new file mode 100644
index 000000000..ab1d140d2
--- /dev/null
+++ b/pkg/sentry/usage/memory.go
@@ -0,0 +1,291 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "fmt"
+ "os"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// MemoryKind represents a type of memory used by the application.
+//
+// For efficiency reasons, it is assumed that the Memory implementation is
+// responsible for specific stats (documented below), and those may be reported
+// in aggregate independently. See the platform.Memory interface as well as the
+// control.Usage.Collect method for more information.
+type MemoryKind int
+
+const (
+ // System represents miscellaneous system memory. This may include
+ // memory that is in the process of being reclaimed, system caches,
+ // page tables, swap, etc.
+ //
+ // This memory kind is backed by platform memory.
+ System MemoryKind = iota
+
+ // Anonymous represents anonymous application memory.
+ //
+ // This memory kind is backed by platform memory.
+ Anonymous
+
+ // PageCache represents memory allocated to back sandbox-visible files that
+ // do not have a local fd. The contents of these files are buffered in
+ // memory to support application mmaps.
+ //
+ // This memory kind is backed by platform memory.
+ PageCache
+
+ // Tmpfs represents memory used by the sandbox-visible tmpfs.
+ //
+ // This memory kind is backed by platform memory.
+ Tmpfs
+
+ // Ramdiskfs represents memory used by the ramdiskfs.
+ //
+ // This memory kind is backed by platform memory.
+ Ramdiskfs
+
+ // Mapped represents memory related to files which have a local fd on the
+ // host, and thus can be directly mapped. Typically these are files backed
+ // by gofers with donated-fd support. Note that this value may not track the
+ // exact amount of memory used by mapping on the host, because we don't have
+ // any visibility into the host kernel memory management. In particular,
+ // once we map some part of a host file, the host kernel is free to
+	// arbitrarily populate/decommit the pages, which it may do for various
+ // reasons (ex. host memory reclaim, NUMA balancing).
+ //
+ // This memory kind is backed by the host pagecache, via host mmaps.
+ Mapped
+)
+
+// MemoryStats tracks application memory usage in bytes. All fields correspond
+// to the memory category with the same name. This object is thread-safe if
+// accessed through the provided methods. The public fields may be safely
+// accessed directly on a copy of the object obtained from MemoryLocked.Copy().
+type MemoryStats struct {
+ System uint64
+ Anonymous uint64
+ PageCache uint64
+ Tmpfs uint64
+ // Lazily updated based on the value in RTMapped.
+ Mapped uint64
+ Ramdiskfs uint64
+}
+
+// RTMemoryStats contains the memory usage values that need to be directly
+// exposed through a shared memory file for real-time access. These are
+// categories not backed by platform memory. For details about how this works,
+// see the memory accounting docs.
+//
+// N.B. Please keep the struct in sync with the API. Notably, changes to this
+// struct require a version bump and the addition of compatibility logic in the
+// control server. As a special case, adding fields without re-ordering existing
+// ones does not require a version bump because the mapped page we use is
+// initially zeroed. Any added field will be ignored by an older API and will be
+// zero if read by a newer API.
+type RTMemoryStats struct {
+ RTMapped uint64
+}
+
+// MemoryLocked is Memory with access methods.
+type MemoryLocked struct {
+ mu sync.RWMutex
+ // MemoryStats records the memory stats.
+ MemoryStats
+ // RTMemoryStats records the memory stats that need to be exposed through
+ // shared page.
+ *RTMemoryStats
+ // File is the backing file storing the memory stats.
+ File *os.File
+}
+
+// Init initializes global 'MemoryAccounting'.
+func Init() error {
+ const name = "memory-usage"
+ fd, err := memutil.CreateMemFD(name, 0)
+ if err != nil {
+ return fmt.Errorf("error creating usage file: %v", err)
+ }
+ file := os.NewFile(uintptr(fd), name)
+ if err := file.Truncate(int64(RTMemoryStatsSize)); err != nil {
+ return fmt.Errorf("error truncating usage file: %v", err)
+ }
+ // Note: We rely on the returned page being initially zeroed. This will
+ // always be the case for a newly mapped page from /dev/shm. If we obtain
+ // the shared memory through some other means in the future, we may have to
+ // explicitly zero the page.
+ mmap, err := syscall.Mmap(int(file.Fd()), 0, int(RTMemoryStatsSize), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ return fmt.Errorf("error mapping usage file: %v", err)
+ }
+
+ MemoryAccounting = &MemoryLocked{
+ File: file,
+ RTMemoryStats: RTMemoryStatsPointer(mmap),
+ }
+ return nil
+}
+
+// MemoryAccounting is the global memory stats.
+//
+// There is no need to save or restore the global memory accounting object,
+// because individual frame kinds are saved and charged only when they become
+// resident.
+var MemoryAccounting *MemoryLocked
+
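+// incLocked adds val to the counter for kind. Callers must hold mu; a read
+// lock suffices, since the counters themselves are updated atomically, while
+// Total and Copy take the write lock to obtain a consistent snapshot.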
+func (m *MemoryLocked) incLocked(val uint64, kind MemoryKind) {
+ switch kind {
+ case System:
+ atomic.AddUint64(&m.System, val)
+ case Anonymous:
+ atomic.AddUint64(&m.Anonymous, val)
+ case PageCache:
+ atomic.AddUint64(&m.PageCache, val)
+ case Mapped:
+ atomic.AddUint64(&m.RTMapped, val)
+ case Tmpfs:
+ atomic.AddUint64(&m.Tmpfs, val)
+ case Ramdiskfs:
+ atomic.AddUint64(&m.Ramdiskfs, val)
+ default:
+ panic(fmt.Sprintf("invalid memory kind: %v", kind))
+ }
+}
+
+// Inc adds an additional usage of 'val' bytes to memory category 'kind'.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Inc(val uint64, kind MemoryKind) {
+ m.mu.RLock()
+ m.incLocked(val, kind)
+ m.mu.RUnlock()
+}
+
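+// decLocked subtracts val from the counter for kind. There is no atomic
+// subtraction for uint64, so we add the two's complement of val:
+// ^(val-1) == -val (mod 2^64).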
+func (m *MemoryLocked) decLocked(val uint64, kind MemoryKind) {
+ switch kind {
+ case System:
+ atomic.AddUint64(&m.System, ^(val - 1))
+ case Anonymous:
+ atomic.AddUint64(&m.Anonymous, ^(val - 1))
+ case PageCache:
+ atomic.AddUint64(&m.PageCache, ^(val - 1))
+ case Mapped:
+ atomic.AddUint64(&m.RTMapped, ^(val - 1))
+ case Tmpfs:
+ atomic.AddUint64(&m.Tmpfs, ^(val - 1))
+ case Ramdiskfs:
+ atomic.AddUint64(&m.Ramdiskfs, ^(val - 1))
+ default:
+ panic(fmt.Sprintf("invalid memory kind: %v", kind))
+ }
+}
+
+// Dec removes a usage of 'val' bytes from memory category 'kind'.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Dec(val uint64, kind MemoryKind) {
+ m.mu.RLock()
+ m.decLocked(val, kind)
+ m.mu.RUnlock()
+}
+
+// Move moves a usage of 'val' bytes from 'from' to 'to'.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Move(val uint64, to MemoryKind, from MemoryKind) {
+ m.mu.RLock()
+	// Just call decLocked and incLocked directly. We hold the RLock to
+	// protect against concurrent callers of Total().
+ m.decLocked(val, from)
+ m.incLocked(val, to)
+ m.mu.RUnlock()
+}
+
+// totalLocked returns a total usage.
+//
+// Precondition: must be called when locked.
+func (m *MemoryLocked) totalLocked() (total uint64) {
+ total += atomic.LoadUint64(&m.System)
+ total += atomic.LoadUint64(&m.Anonymous)
+ total += atomic.LoadUint64(&m.PageCache)
+ total += atomic.LoadUint64(&m.RTMapped)
+ total += atomic.LoadUint64(&m.Tmpfs)
+ total += atomic.LoadUint64(&m.Ramdiskfs)
+ return
+}
+
+// Total returns a total memory usage.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Total() uint64 {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.totalLocked()
+}
+
+// Copy returns a copy of the structure with a total.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Copy() (MemoryStats, uint64) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ ms := m.MemoryStats
+ ms.Mapped = m.RTMapped
+ return ms, m.totalLocked()
+}
+
+// These options control how much total memory is reported to the application.
+// They may only be set before the application starts executing, and must not
+// be modified.
+var (
+ // MinimumTotalMemoryBytes is the minimum reported total system memory.
+ MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB
+
+ // MaximumTotalMemoryBytes is the maximum reported total system memory.
+ // The 0 value indicates no maximum.
+ MaximumTotalMemoryBytes uint64
+)
+
+// TotalMemory returns the "total usable memory" available.
+//
+// This number doesn't really have a true value, so it's based on the following
+// inputs and further bounded to be above MinimumTotalMemoryBytes and below
+// MaximumTotalMemoryBytes:
+//
+// memSize should be the platform memory size reported by platform.Memory.TotalSize();
+// used is the total memory reported by MemoryLocked.Total().
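+//
+// For example (illustrative values, assuming the default 2 GB minimum and no
+// maximum): TotalMemory(1<<30, 0) returns 2 GB, and TotalMemory(1<<31, 3<<30)
+// rounds the reported total up to the next power of two, 4 GB, so that the
+// derived MemFree value stays nonzero.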
+func TotalMemory(memSize, used uint64) uint64 {
+ if memSize < MinimumTotalMemoryBytes {
+ memSize = MinimumTotalMemoryBytes
+ }
+ if memSize < used {
+ memSize = used
+		// Bump memSize to the next largest power of 2, if one exists, so
+ // that MemFree isn't 0.
+ if msb := bits.MostSignificantOne64(memSize); msb < 63 {
+ memSize = uint64(1) << (uint(msb) + 1)
+ }
+ }
+ if MaximumTotalMemoryBytes > 0 && memSize > MaximumTotalMemoryBytes {
+ memSize = MaximumTotalMemoryBytes
+ }
+ return memSize
+}
diff --git a/pkg/sentry/usage/memory_unsafe.go b/pkg/sentry/usage/memory_unsafe.go
new file mode 100644
index 000000000..9e0014ca0
--- /dev/null
+++ b/pkg/sentry/usage/memory_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "unsafe"
+)
+
+// RTMemoryStatsSize is the size of the RTMemoryStats struct.
+var RTMemoryStatsSize = unsafe.Sizeof(RTMemoryStats{})
+
+// RTMemoryStatsPointer casts the address of the byte slice into a RTMemoryStats pointer.
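+//
+// The caller must ensure that len(b) >= RTMemoryStatsSize and that b remains
+// mapped for the lifetime of the returned pointer.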
+func RTMemoryStatsPointer(b []byte) *RTMemoryStats {
+ return (*RTMemoryStats)(unsafe.Pointer(&b[0]))
+}
diff --git a/pkg/sentry/usage/usage.go b/pkg/sentry/usage/usage.go
new file mode 100644
index 000000000..e3d33a965
--- /dev/null
+++ b/pkg/sentry/usage/usage.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package usage provides representations of resource usage.
+package usage
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
new file mode 100644
index 000000000..642769e7c
--- /dev/null
+++ b/pkg/sentry/vfs/BUILD
@@ -0,0 +1,100 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "epoll_interest_list",
+ out = "epoll_interest_list.go",
+ package = "vfs",
+ prefix = "epollInterest",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*epollInterest",
+ "Linker": "*epollInterest",
+ },
+)
+
+go_template_instance(
+ name = "event_list",
+ out = "event_list.go",
+ package = "vfs",
+ prefix = "event",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Event",
+ "Linker": "*Event",
+ },
+)
+
+go_library(
+ name = "vfs",
+ srcs = [
+ "anonfs.go",
+ "context.go",
+ "debug.go",
+ "dentry.go",
+ "device.go",
+ "epoll.go",
+ "epoll_interest_list.go",
+ "event_list.go",
+ "file_description.go",
+ "file_description_impl_util.go",
+ "filesystem.go",
+ "filesystem_impl_util.go",
+ "filesystem_type.go",
+ "inotify.go",
+ "lock.go",
+ "mount.go",
+ "mount_unsafe.go",
+ "options.go",
+ "pathname.go",
+ "permissions.go",
+ "resolving_path.go",
+ "vfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/fspath",
+ "//pkg/gohacks",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "vfs_test",
+ size = "small",
+ srcs = [
+ "file_description_impl_util_test.go",
+ "mount_test.go",
+ ],
+ library = ":vfs",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
new file mode 100644
index 000000000..4b9faf2ea
--- /dev/null
+++ b/pkg/sentry/vfs/README.md
@@ -0,0 +1,195 @@
+# The gVisor Virtual Filesystem
+
+THIS PACKAGE IS CURRENTLY EXPERIMENTAL AND NOT READY OR ENABLED FOR PRODUCTION
+USE. For the filesystem implementation currently used by gVisor, see the `fs`
+package.
+
+## Implementation Notes
+
+### Reference Counting
+
+Filesystem, Dentry, Mount, MountNamespace, and FileDescription are all
+reference-counted. Mount and MountNamespace are exclusively VFS-managed; when
+their reference count reaches zero, VFS releases their resources. Filesystem and
+FileDescription management is shared between VFS and filesystem implementations;
+when their reference count reaches zero, VFS notifies the implementation by
+calling `FilesystemImpl.Release()` or `FileDescriptionImpl.Release()`
+respectively and then releases VFS-owned resources. Dentries are exclusively
+managed by filesystem implementations; reference count changes are abstracted
+through DentryImpl, which should release resources when reference count reaches
+zero.
+
+Filesystem references are held by:
+
+- Mount: Each referenced Mount holds a reference on the mounted Filesystem.
+
+Dentry references are held by:
+
+- FileDescription: Each referenced FileDescription holds a reference on the
+ Dentry through which it was opened, via `FileDescription.vd.dentry`.
+
+- Mount: Each referenced Mount holds a reference on its mount point and on the
+ mounted filesystem root. The mount point is mutable (`mount(MS_MOVE)`).
+
+Mount references are held by:
+
+- FileDescription: Each referenced FileDescription holds a reference on the
+ Mount on which it was opened, via `FileDescription.vd.mount`.
+
+- Mount: Each referenced Mount holds a reference on its parent, which is the
+ mount containing its mount point.
+
+- VirtualFilesystem: A reference is held on each Mount that has been connected
+ to a mount point, but not yet umounted.
+
+MountNamespace and FileDescription references are held by users of VFS. The
+expectation is that each `kernel.Task` holds a reference on its corresponding
+MountNamespace, and each file descriptor holds a reference on its represented
+FileDescription.
+
+Notes:
+
+- Dentries do not hold a reference on their owning Filesystem. Instead, all
+ uses of a Dentry occur in the context of a Mount, which holds a reference on
+ the relevant Filesystem (see e.g. the VirtualDentry type). As a corollary,
+ when releasing references on both a Dentry and its corresponding Mount, the
+ Dentry's reference must be released first (because releasing the Mount's
+ reference may release the last reference on the Filesystem, whose state may
+ be required to release the Dentry reference).
+
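+A minimal sketch of the resulting release order (illustrative only; the method
+names are hypothetical, but VFS encapsulates this logic in the VirtualDentry
+type mentioned above):
+
+```go
+func (vd VirtualDentry) decRefSketch() {
+	// Drop the Dentry reference before the Mount reference: dropping the
+	// Mount reference may release the last Filesystem reference, whose
+	// state may be needed to release the Dentry reference.
+	vd.dentry.DecRef()
+	vd.mount.DecRef()
+}
+```
+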
+### The Inheritance Pattern
+
+Filesystem, Dentry, and FileDescription are all concepts featuring both state
+that must be shared between VFS and filesystem implementations, and operations
+that are implementation-defined. To facilitate this, each of these three
+concepts follows the same pattern, shown below for Dentry:
+
+```go
+// Dentry represents a node in a filesystem tree.
+type Dentry struct {
+ // VFS-required dentry state.
+ parent *Dentry
+ // ...
+
+ // impl is the DentryImpl associated with this Dentry. impl is immutable.
+ // This should be the last field in Dentry.
+ impl DentryImpl
+}
+
+// Init must be called before first use of d.
+func (d *Dentry) Init(impl DentryImpl) {
+ d.impl = impl
+}
+
+// Impl returns the DentryImpl associated with d.
+func (d *Dentry) Impl() DentryImpl {
+ return d.impl
+}
+
+// DentryImpl contains implementation-specific details of a Dentry.
+// Implementations of DentryImpl should contain their associated Dentry by
+// value as their first field.
+type DentryImpl interface {
+ // VFS-required implementation-defined dentry operations.
+ IncRef()
+ // ...
+}
+```
+
+This construction, which is essentially a type-safe analogue to Linux's
+`container_of` pattern, has the following properties:
+
+- VFS works almost exclusively with pointers to Dentry rather than DentryImpl
+ interface objects, such as in the type of `Dentry.parent`. This avoids
+ interface method calls (which are somewhat expensive to perform, and defeat
+ inlining and escape analysis), reduces the size of VFS types (since an
+ interface object is two pointers in size), and allows pointers to be loaded
+ and stored atomically using `sync/atomic`. Implementation-defined behavior
+ is accessed via `Dentry.impl` when required.
+
+- Filesystem implementations can access the implementation-defined state
+ associated with objects of VFS types by type-asserting or type-switching
+ (e.g. `Dentry.Impl().(*myDentry)`). Type assertions to a concrete type
+ require only an equality comparison of the interface object's type pointer
+ to a static constant, and are consequently very fast.
+
+- Filesystem implementations can access the VFS state associated with objects
+ of implementation-defined types directly.
+
+- VFS and implementation-defined state for a given type occupy the same
+ object, minimizing memory allocations and maximizing memory locality. `impl`
+ is the last field in `Dentry`, and `Dentry` is the first field in
+ `DentryImpl` implementations, for similar reasons: this tends to cause
+ fetching of the `Dentry.impl` interface object to also fetch `DentryImpl`
+ fields, either because they are in the same cache line or via next-line
+ prefetching.
+
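+As an illustrative sketch, a filesystem implementation following this pattern
+might define (`myDentry` is hypothetical, matching the type assertion above;
+most `DentryImpl` methods are omitted):
+
+```go
+// myDentry contains its associated Dentry by value as its first field.
+type myDentry struct {
+	vfsd vfs.Dentry
+	refs int64 // accessed using atomic memory operations
+	// ... implementation-defined state ...
+}
+
+func newMyDentry() *myDentry {
+	d := &myDentry{refs: 1}
+	d.vfsd.Init(d) // d implements vfs.DentryImpl
+	return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *myDentry) IncRef() {
+	atomic.AddInt64(&d.refs, 1)
+}
+```
+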
+## Future Work
+
+- Most `mount(2)` features, and unmounting, are incomplete.
+
+- VFS1 filesystems are not directly compatible with VFS2. It may be possible
+ to implement shims that implement `vfs.FilesystemImpl` for
+ `fs.MountNamespace`, `vfs.DentryImpl` for `fs.Dirent`, and
+ `vfs.FileDescriptionImpl` for `fs.File`, which may be adequate for
+ filesystems that are not performance-critical (e.g. sysfs); however, it is
+ not clear that this will be less effort than simply porting the filesystems
+ in question. Practically speaking, the following filesystems will probably
+ need to be ported or made compatible through a shim to evaluate filesystem
+ performance on realistic workloads:
+
+ - devfs/procfs/sysfs, which will realistically be necessary to execute
+ most applications. (Note that procfs and sysfs do not support hard
+ links, so they do not require the complexity of separate inode objects.
+ Also note that Linux's /dev is actually a variant of tmpfs called
+ devtmpfs.)
+
+ - tmpfs. This should be relatively straightforward: copy/paste memfs,
+ store regular file contents in pgalloc-allocated memory instead of
+ `[]byte`, and add support for file timestamps. (In fact, it probably
+ makes more sense to convert memfs to tmpfs and not keep the former.)
+
+ - A remote filesystem, either lisafs (if it is ready by the time that
+ other benchmarking prerequisites are) or v9fs (aka 9P, aka gofers).
+
+ - epoll files.
+
+ Filesystems that will need to be ported before switching to VFS2, but can
+ probably be skipped for early testing:
+
+ - overlayfs, which is needed for (at least) synthetic mount points.
+
+ - Support for host ttys.
+
+ - timerfd files.
+
+ Filesystems that can be probably dropped:
+
+ - ashmem, which is far too incomplete to use.
+
+ - binder, which is similarly far too incomplete to use.
+
+- Save/restore. For instance, it is unclear if the current implementation of
+ the `state` package supports the inheritance pattern described above.
+
+- Many features that were previously implemented by VFS must now be
+ implemented by individual filesystems (though, in most cases, this should
+ consist of calls to hooks or libraries provided by `vfs` or other packages).
+ This includes, but is not necessarily limited to:
+
+ - Block and character device special files
+
+ - Inotify
+
+ - File locking
+
+ - `O_ASYNC`
+
+- Reference counts in the `vfs` package do not use the `refs` package since
+ `refs.AtomicRefCount` adds 64 bytes of overhead to each 8-byte reference
+ count, resulting in considerable cache bloat. 24 bytes of this overhead is
+  for weak reference support, which has poor performance and will not be used
+  by VFS2. The remaining 40 bytes is used to store a descriptive string and stack
+ trace for reference leak checking; we can support reference leak checking
+ without incurring this space overhead by including the applicable
+ information directly in finalizers for applicable types.
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
new file mode 100644
index 000000000..641e3e502
--- /dev/null
+++ b/pkg/sentry/vfs/anonfs.go
@@ -0,0 +1,314 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// NewAnonVirtualDentry returns a VirtualDentry with the given synthetic name,
+// consistent with Linux's fs/anon_inodes.c:anon_inode_getfile(). References
+// are taken on the returned VirtualDentry.
+func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry {
+ d := anonDentry{
+ name: name,
+ }
+ d.vfsd.Init(&d)
+ vfs.anonMount.IncRef()
+ // anonDentry no-ops refcounting.
+ return VirtualDentry{
+ mount: vfs.anonMount,
+ dentry: &d.vfsd,
+ }
+}
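+
+// Illustrative use, from a caller's perspective (compare
+// VirtualFilesystem.NewEpollInstanceFD in epoll.go; vfsObj is a hypothetical
+// *VirtualFilesystem):
+//
+//	vd := vfsObj.NewAnonVirtualDentry("[eventpoll]")
+//	defer vd.DecRef()
+//	// Initialize a FileDescription at vd via FileDescription.Init().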
+
+const (
+ anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+
+ // Mode, UID, and GID for a generic anonfs file.
+	anonFileMode = 0600 // having no file type bits is correct
+ anonFileUID = auth.RootKUID
+ anonFileGID = auth.RootKGID
+)
+
+// anonFilesystemType implements FilesystemType.
+type anonFilesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (anonFilesystemType) GetFilesystem(context.Context, *VirtualFilesystem, *auth.Credentials, string, GetFilesystemOptions) (*Filesystem, *Dentry, error) {
+	panic("cannot instantiate an anon filesystem")
+}
+
+// Name implements FilesystemType.Name.
+func (anonFilesystemType) Name() string {
+ return "none"
+}
+
+// anonFilesystem is the implementation of FilesystemImpl that backs
+// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
+//
+// Since all Dentries in anonFilesystem are non-directories, all FilesystemImpl
+// methods that would require an anonDentry to be a directory return ENOTDIR.
+type anonFilesystem struct {
+ vfsfs Filesystem
+
+ devMinor uint32
+}
+
+type anonDentry struct {
+ vfsd Dentry
+
+ name string
+}
+
+// Release implements FilesystemImpl.Release.
+func (fs *anonFilesystem) Release() {
+}
+
+// Sync implements FilesystemImpl.Sync.
+func (fs *anonFilesystem) Sync(ctx context.Context) error {
+ return nil
+}
+
+// AccessAt implements FilesystemImpl.AccessAt.
+func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return GenericCheckPermissions(creds, ats, anonFileMode, anonFileUID, anonFileGID)
+}
+
+// GetDentryAt implements FilesystemImpl.GetDentryAt.
+func (fs *anonFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) {
+ if !rp.Done() {
+ return nil, syserror.ENOTDIR
+ }
+ if opts.CheckSearchable {
+ return nil, syserror.ENOTDIR
+ }
+ // anonDentry no-ops refcounting.
+ return rp.Start(), nil
+}
+
+// GetParentDentryAt implements FilesystemImpl.GetParentDentryAt.
+func (fs *anonFilesystem) GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error) {
+ if !rp.Final() {
+ return nil, syserror.ENOTDIR
+ }
+ // anonDentry no-ops refcounting.
+ return rp.Start(), nil
+}
+
+// LinkAt implements FilesystemImpl.LinkAt.
+func (fs *anonFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// MkdirAt implements FilesystemImpl.MkdirAt.
+func (fs *anonFilesystem) MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// MknodAt implements FilesystemImpl.MknodAt.
+func (fs *anonFilesystem) MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// OpenAt implements FilesystemImpl.OpenAt.
+func (fs *anonFilesystem) OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error) {
+ if !rp.Done() {
+ return nil, syserror.ENOTDIR
+ }
+ return nil, syserror.ENODEV
+}
+
+// ReadlinkAt implements FilesystemImpl.ReadlinkAt.
+func (fs *anonFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error) {
+ if !rp.Done() {
+ return "", syserror.ENOTDIR
+ }
+ return "", syserror.EINVAL
+}
+
+// RenameAt implements FilesystemImpl.RenameAt.
+func (fs *anonFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// RmdirAt implements FilesystemImpl.RmdirAt.
+func (fs *anonFilesystem) RmdirAt(ctx context.Context, rp *ResolvingPath) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// SetStatAt implements FilesystemImpl.SetStatAt.
+func (fs *anonFilesystem) SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ // Linux actually permits anon_inode_inode's metadata to be set, which is
+ // visible to all users of anon_inode_inode. We just silently ignore
+ // metadata changes.
+ return nil
+}
+
+// StatAt implements FilesystemImpl.StatAt.
+func (fs *anonFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts StatOptions) (linux.Statx, error) {
+ if !rp.Done() {
+ return linux.Statx{}, syserror.ENOTDIR
+ }
+ // See fs/anon_inodes.c:anon_inode_init() => fs/libfs.c:alloc_anon_inode().
+ return linux.Statx{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
+ Blksize: anonfsBlockSize,
+ Nlink: 1,
+ UID: uint32(anonFileUID),
+ GID: uint32(anonFileGID),
+ Mode: anonFileMode,
+ Ino: 1,
+ Size: 0,
+ Blocks: 0,
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: fs.devMinor,
+ }, nil
+}
+
+// StatFSAt implements FilesystemImpl.StatFSAt.
+func (fs *anonFilesystem) StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error) {
+ if !rp.Done() {
+ return linux.Statfs{}, syserror.ENOTDIR
+ }
+ return linux.Statfs{
+ Type: linux.ANON_INODE_FS_MAGIC,
+ BlockSize: anonfsBlockSize,
+ }, nil
+}
+
+// SymlinkAt implements FilesystemImpl.SymlinkAt.
+func (fs *anonFilesystem) SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// UnlinkAt implements FilesystemImpl.UnlinkAt.
+func (fs *anonFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ if !rp.Final() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := GenericCheckPermissions(rp.Credentials(), MayWrite, anonFileMode, anonFileUID, anonFileGID); err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements FilesystemImpl.ListxattrAt.
+func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) {
+ if !rp.Done() {
+ return nil, syserror.ENOTDIR
+ }
+ return nil, nil
+}
+
+// GetxattrAt implements FilesystemImpl.GetxattrAt.
+func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) {
+ if !rp.Done() {
+ return "", syserror.ENOTDIR
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements FilesystemImpl.SetxattrAt.
+func (fs *anonFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// RemovexattrAt implements FilesystemImpl.RemovexattrAt.
+func (fs *anonFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// PrependPath implements FilesystemImpl.PrependPath.
+func (fs *anonFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error {
+ b.PrependComponent(fmt.Sprintf("anon_inode:%s", vd.dentry.impl.(*anonDentry).name))
+ return PrependPathSyntheticError{}
+}
+
+// IncRef implements DentryImpl.IncRef.
+func (d *anonDentry) IncRef() {
+ // no-op
+}
+
+// TryIncRef implements DentryImpl.TryIncRef.
+func (d *anonDentry) TryIncRef() bool {
+ return true
+}
+
+// DecRef implements DentryImpl.DecRef.
+func (d *anonDentry) DecRef() {
+ // no-op
+}
+
+// InotifyWithParent implements DentryImpl.InotifyWithParent.
+//
+// Although Linux technically supports inotify on pseudo filesystems (inotify
+// is implemented at the vfs layer), it is not particularly useful. It is left
+// unimplemented until someone actually needs it.
+func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {}
+
+// Watches implements DentryImpl.Watches.
+func (d *anonDentry) Watches() *Watches {
+ return nil
+}
+
+// OnZeroWatches implements DentryImpl.OnZeroWatches.
+func (d *anonDentry) OnZeroWatches() {}
diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go
new file mode 100644
index 000000000..c9e724fef
--- /dev/null
+++ b/pkg/sentry/vfs/context.go
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is this package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxMountNamespace is a Context.Value key for a MountNamespace.
+ CtxMountNamespace contextID = iota
+
+ // CtxRoot is a Context.Value key for a VFS root.
+ CtxRoot
+)
+
+// MountNamespaceFromContext returns the MountNamespace used by ctx. If ctx is
+// not associated with a MountNamespace, MountNamespaceFromContext returns nil.
+//
+// A reference is taken on the returned MountNamespace.
+func MountNamespaceFromContext(ctx context.Context) *MountNamespace {
+ if v := ctx.Value(CtxMountNamespace); v != nil {
+ return v.(*MountNamespace)
+ }
+ return nil
+}
+
+// RootFromContext returns the VFS root used by ctx. It takes a reference on
+// the returned VirtualDentry. If ctx does not have a specific VFS root,
+// RootFromContext returns a zero-value VirtualDentry.
+func RootFromContext(ctx context.Context) VirtualDentry {
+ if v := ctx.Value(CtxRoot); v != nil {
+ return v.(VirtualDentry)
+ }
+ return VirtualDentry{}
+}
+
+type rootContext struct {
+ context.Context
+ root VirtualDentry
+}
+
+// WithRoot returns a copy of ctx with the given root.
+func WithRoot(ctx context.Context, root VirtualDentry) context.Context {
+ return &rootContext{
+ Context: ctx,
+ root: root,
+ }
+}
+
+// Value implements Context.Value.
+func (rc rootContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxRoot:
+ rc.root.IncRef()
+ return rc.root
+ default:
+ return rc.Context.Value(key)
+ }
+}
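+
+// Illustrative use: a caller that needs path resolution relative to a
+// specific root can derive a context with WithRoot and read the root back
+// later with RootFromContext:
+//
+//	ctx = WithRoot(ctx, root)
+//	// ... later, possibly in another function:
+//	vdRoot := RootFromContext(ctx) // takes a reference on the returned root
+//	defer vdRoot.DecRef()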
diff --git a/pkg/sentry/vfs/debug.go b/pkg/sentry/vfs/debug.go
new file mode 100644
index 000000000..0ed20f249
--- /dev/null
+++ b/pkg/sentry/vfs/debug.go
@@ -0,0 +1,22 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+const (
+ // If checkInvariants is true, perform runtime checks for invariants
+ // expected by the vfs package. This is normally disabled since VFS is
+ // often a hot path.
+ checkInvariants = false
+)
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
new file mode 100644
index 000000000..cea3e6955
--- /dev/null
+++ b/pkg/sentry/vfs/dentry.go
@@ -0,0 +1,324 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Dentry represents a node in a Filesystem tree at which a file exists.
+//
+// Dentries are reference-counted. Unless otherwise specified, all Dentry
+// methods require that a reference is held.
+//
+// Dentry is loosely analogous to Linux's struct dentry, but:
+//
+// - VFS does not associate Dentries with inodes. gVisor interacts primarily
+// with filesystems that are accessed through filesystem APIs (as opposed to
+// raw block devices); many such APIs support only paths and file descriptors,
+// and not inodes. Furthermore, when parties outside the scope of VFS can
+// rename inodes on such filesystems, VFS generally cannot "follow" the rename,
+// both due to synchronization issues and because it may not even be able to
+// name the destination path; this implies that it would in fact be incorrect
+// for Dentries to be associated with inodes on such filesystems. Consequently,
+// operations that are inode operations in Linux are FilesystemImpl methods
+// and/or FileDescriptionImpl methods in gVisor's VFS. Filesystems that do
+// support inodes may store appropriate state in implementations of DentryImpl.
+//
+// - VFS does not require that Dentries are instantiated for all paths accessed
+// through VFS, only those that are tracked beyond the scope of a single
+// Filesystem operation. This includes file descriptions, mount points, mount
+// roots, process working directories, and chroots. This avoids instantiation
+// of Dentries for operations on mutable remote filesystems that can't actually
+// cache any state in the Dentry.
+//
+// - VFS does not track filesystem structure (i.e. relationships between
+// Dentries), since both the relevant state and synchronization are
+// filesystem-specific.
+//
+// - For the reasons above, VFS is not directly responsible for managing Dentry
+// lifetime. Dentry reference counts only indicate the extent to which VFS
+// requires Dentries to exist; Filesystems may elect to cache or discard
+// Dentries with zero references.
+//
+// +stateify savable
+type Dentry struct {
+ // mu synchronizes deletion/invalidation and mounting over this Dentry.
+ mu sync.Mutex `state:"nosave"`
+
+ // dead is true if the file represented by this Dentry has been deleted (by
+ // CommitDeleteDentry or CommitRenameReplaceDentry) or invalidated (by
+ // InvalidateDentry). dead is protected by mu.
+ dead bool
+
+ // mounts is the number of Mounts for which this Dentry is Mount.point.
+ // mounts is accessed using atomic memory operations.
+ mounts uint32
+
+ // impl is the DentryImpl associated with this Dentry. impl is immutable.
+ // This should be the last field in Dentry.
+ impl DentryImpl
+}
+
+// Init must be called before first use of d.
+func (d *Dentry) Init(impl DentryImpl) {
+ d.impl = impl
+}
+
+// Impl returns the DentryImpl associated with d.
+func (d *Dentry) Impl() DentryImpl {
+ return d.impl
+}
+
+// DentryImpl contains implementation details for a Dentry. Implementations of
+// DentryImpl should contain their associated Dentry by value as their first
+// field.
+type DentryImpl interface {
+ // IncRef increments the Dentry's reference count. A Dentry with a non-zero
+ // reference count must remain coherent with the state of the filesystem.
+ IncRef()
+
+ // TryIncRef increments the Dentry's reference count and returns true. If
+ // the Dentry's reference count is zero, TryIncRef may do nothing and
+ // return false. (It is also permitted to succeed if it can restore the
+ // guarantee that the Dentry is coherent with the state of the filesystem.)
+ //
+ // TryIncRef does not require that a reference is held on the Dentry.
+ TryIncRef() bool
+
+ // DecRef decrements the Dentry's reference count.
+ DecRef()
+
+ // InotifyWithParent notifies all watches on the targets represented by this
+ // dentry and its parent. The parent's watches are notified first, followed
+ // by this dentry's.
+ //
+ // InotifyWithParent automatically adds the IN_ISDIR flag for dentries
+ // representing directories.
+ //
+ // Note that the events may not actually propagate up to the user, depending
+ // on the event masks.
+ InotifyWithParent(events, cookie uint32, et EventType)
+
+ // Watches returns the set of inotify watches for the file corresponding to
+ // the Dentry. Dentries that are hard links to the same underlying file
+ // share the same watches.
+ //
+ // Watches may return nil if the dentry belongs to a FilesystemImpl that
+ // does not support inotify. If an implementation returns a non-nil watch
+ // set, it must always return a non-nil watch set. Likewise, if an
+ // implementation returns a nil watch set, it must always return a nil watch
+ // set.
+ //
+ // The caller does not need to hold a reference on the dentry.
+ Watches() *Watches
+
+ // OnZeroWatches is called whenever the number of watches on a dentry drops
+ // to zero. This is needed by some FilesystemImpls (e.g. gofer) to manage
+ // dentry lifetime.
+ //
+ // The caller does not need to hold a reference on the dentry. OnZeroWatches
+ // may acquire inotify locks, so to prevent deadlock, no inotify locks should
+ // be held by the caller.
+ OnZeroWatches()
+}
+
+// IncRef increments d's reference count.
+func (d *Dentry) IncRef() {
+ d.impl.IncRef()
+}
+
+// TryIncRef increments d's reference count and returns true. If d's reference
+// count is zero, TryIncRef may instead do nothing and return false.
+func (d *Dentry) TryIncRef() bool {
+ return d.impl.TryIncRef()
+}
+
+// DecRef decrements d's reference count.
+func (d *Dentry) DecRef() {
+ d.impl.DecRef()
+}
+
+// IsDead returns true if d has been deleted or invalidated by its owning
+// filesystem.
+func (d *Dentry) IsDead() bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dead
+}
+
+func (d *Dentry) isMounted() bool {
+ return atomic.LoadUint32(&d.mounts) != 0
+}
+
+// InotifyWithParent notifies all watches on the targets represented by d and
+// its parent of events.
+func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) {
+ d.impl.InotifyWithParent(events, cookie, et)
+}
+
+// Watches returns the set of inotify watches associated with d.
+//
+// Watches will return nil if d belongs to a FilesystemImpl that does not
+// support inotify.
+func (d *Dentry) Watches() *Watches {
+ return d.impl.Watches()
+}
+
+// OnZeroWatches performs cleanup tasks whenever the number of watches on a
+// dentry drops to zero.
+func (d *Dentry) OnZeroWatches() {
+ d.impl.OnZeroWatches()
+}
+
+// The following functions are exported so that filesystem implementations can
+// use them. The vfs package, and users of VFS, should not call these
+// functions.
+
+// PrepareDeleteDentry must be called before attempting to delete the file
+// represented by d. If PrepareDeleteDentry succeeds, the caller must call
+// AbortDeleteDentry or CommitDeleteDentry depending on the deletion's outcome.
+func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dentry) error {
+ vfs.mountMu.Lock()
+ if mntns.mountpoints[d] != 0 {
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ d.mu.Lock()
+ vfs.mountMu.Unlock()
+ // Return with d.mu locked to block attempts to mount over it; it will be
+ // unlocked by AbortDeleteDentry or CommitDeleteDentry.
+ return nil
+}
+
+// AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion
+// fails.
+func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) {
+ d.mu.Unlock()
+}
+
+// CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion
+// succeeds.
+func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) {
+ d.dead = true
+ d.mu.Unlock()
+ if d.isMounted() {
+ vfs.forgetDeadMountpoint(d)
+ }
+}
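+
+// Illustrative use of the deletion protocol by a FilesystemImpl, where
+// removeFile is a hypothetical filesystem-specific helper:
+//
+//	if err := vfsObj.PrepareDeleteDentry(mntns, d); err != nil {
+//		return err
+//	}
+//	if err := fs.removeFile(ctx, d); err != nil {
+//		vfsObj.AbortDeleteDentry(d)
+//		return err
+//	}
+//	vfsObj.CommitDeleteDentry(d)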
+
+// InvalidateDentry is called when d ceases to represent the file it formerly
+// did for reasons outside of VFS' control (e.g. d represents the local state
+// of a file on a remote filesystem on which the file has already been
+// deleted).
+func (vfs *VirtualFilesystem) InvalidateDentry(d *Dentry) {
+ d.mu.Lock()
+ d.dead = true
+ d.mu.Unlock()
+ if d.isMounted() {
+ vfs.forgetDeadMountpoint(d)
+ }
+}
+
+// PrepareRenameDentry must be called before attempting to rename the file
+// represented by from. If to is not nil, it represents the file that will be
+// replaced or exchanged by the rename. If PrepareRenameDentry succeeds, the
+// caller must call AbortRenameDentry, CommitRenameReplaceDentry, or
+// CommitRenameExchangeDentry depending on the rename's outcome.
+//
+// Preconditions: If to is not nil, it must be a child Dentry from the same
+// Filesystem. from != to.
+func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error {
+ vfs.mountMu.Lock()
+ if mntns.mountpoints[from] != 0 {
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ if to != nil {
+ if mntns.mountpoints[to] != 0 {
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ to.mu.Lock()
+ }
+ from.mu.Lock()
+ vfs.mountMu.Unlock()
+ // Return with from.mu and to.mu locked, which will be unlocked by
+ // AbortRenameDentry, CommitRenameReplaceDentry, or
+ // CommitRenameExchangeDentry.
+ return nil
+}
+
+// AbortRenameDentry must be called after PrepareRenameDentry if the rename
+// fails.
+func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
+ from.mu.Unlock()
+ if to != nil {
+ to.mu.Unlock()
+ }
+}
+
+// CommitRenameReplaceDentry must be called after the file represented by from
+// is renamed without RENAME_EXCHANGE. If to is not nil, it represents the file
+// that was replaced by from.
+//
+// Preconditions: PrepareRenameDentry was previously called on from and to.
+func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, to *Dentry) {
+ from.mu.Unlock()
+ if to != nil {
+ to.dead = true
+ to.mu.Unlock()
+ if to.isMounted() {
+ vfs.forgetDeadMountpoint(to)
+ }
+ }
+}
+
+// CommitRenameExchangeDentry must be called after the files represented by
+// from and to are exchanged by rename(RENAME_EXCHANGE).
+//
+// Preconditions: PrepareRenameDentry was previously called on from and to.
+func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) {
+ from.mu.Unlock()
+ to.mu.Unlock()
+}
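+
+// The rename protocol is used analogously, where renameFile is a
+// hypothetical filesystem-specific helper:
+//
+//	if err := vfsObj.PrepareRenameDentry(mntns, from, to); err != nil {
+//		return err
+//	}
+//	if err := fs.renameFile(ctx, from, to); err != nil {
+//		vfsObj.AbortRenameDentry(from, to)
+//		return err
+//	}
+//	vfsObj.CommitRenameReplaceDentry(from, to)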
+
+// forgetDeadMountpoint is called when a mount point is deleted or invalidated
+// to umount all mounts using it in all other mount namespaces.
+//
+// forgetDeadMountpoint is analogous to Linux's
+// fs/namespace.c:__detach_mounts().
+func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) {
+ var (
+ vdsToDecRef []VirtualDentry
+ mountsToDecRef []*Mount
+ )
+ vfs.mountMu.Lock()
+ vfs.mounts.seq.BeginWrite()
+ for mnt := range vfs.mountpoints[d] {
+ vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(mnt, &umountRecursiveOptions{}, vdsToDecRef, mountsToDecRef)
+ }
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
+}
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
new file mode 100644
index 000000000..1e9dffc8f
--- /dev/null
+++ b/pkg/sentry/vfs/device.go
@@ -0,0 +1,132 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// DeviceKind indicates whether a device is a block or character device.
+type DeviceKind uint32
+
+const (
+ // BlockDevice indicates a block device.
+ BlockDevice DeviceKind = iota
+
+ // CharDevice indicates a character device.
+ CharDevice
+)
+
+// String implements fmt.Stringer.String.
+func (kind DeviceKind) String() string {
+ switch kind {
+ case BlockDevice:
+ return "block"
+ case CharDevice:
+ return "character"
+ default:
+ return fmt.Sprintf("invalid device kind %d", kind)
+ }
+}
+
+type devTuple struct {
+ kind DeviceKind
+ major uint32
+ minor uint32
+}
+
+// A Device backs device special files.
+type Device interface {
+ // Open returns a FileDescription representing this device.
+ Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error)
+}
+
+// +stateify savable
+type registeredDevice struct {
+ dev Device
+ opts RegisterDeviceOptions
+}
+
+// RegisterDeviceOptions contains options to
+// VirtualFilesystem.RegisterDevice().
+//
+// +stateify savable
+type RegisterDeviceOptions struct {
+ // GroupName is the name shown for this device registration in
+ // /proc/devices. If GroupName is empty, this registration will not be
+ // shown in /proc/devices.
+ GroupName string
+}
+
+// RegisterDevice registers the given Device in vfs with the given major and
+// minor device numbers.
+func (vfs *VirtualFilesystem) RegisterDevice(kind DeviceKind, major, minor uint32, dev Device, opts *RegisterDeviceOptions) error {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.Lock()
+ defer vfs.devicesMu.Unlock()
+ if existing, ok := vfs.devices[tup]; ok {
+ return fmt.Errorf("%s device number (%d, %d) is already registered to device type %T", kind, major, minor, existing.dev)
+ }
+ vfs.devices[tup] = &registeredDevice{
+ dev: dev,
+ opts: *opts,
+ }
+ return nil
+}
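+
+// Illustrative registration (hypothetical device numbers and hypothetical
+// Device implementation memDevice):
+//
+//	if err := vfsObj.RegisterDevice(CharDevice, 1, 3, memDevice{}, &RegisterDeviceOptions{
+//		GroupName: "mem",
+//	}); err != nil {
+//		return err
+//	}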
+
+// OpenDeviceSpecialFile returns a FileDescription representing the given
+// device.
+func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mount, d *Dentry, kind DeviceKind, major, minor uint32, opts *OpenOptions) (*FileDescription, error) {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.RLock()
+ defer vfs.devicesMu.RUnlock()
+ rd, ok := vfs.devices[tup]
+ if !ok {
+ return nil, syserror.ENXIO
+ }
+ return rd.dev.Open(ctx, mnt, d, *opts)
+}
+
+// GetAnonBlockDevMinor allocates and returns an unused minor device number for
+// an "anonymous" block device with major number UNNAMED_MAJOR.
+func (vfs *VirtualFilesystem) GetAnonBlockDevMinor() (uint32, error) {
+ vfs.anonBlockDevMinorMu.Lock()
+ defer vfs.anonBlockDevMinorMu.Unlock()
+ minor := vfs.anonBlockDevMinorNext
+ const maxDevMinor = (1 << 20) - 1
+ for minor < maxDevMinor {
+ if _, ok := vfs.anonBlockDevMinor[minor]; !ok {
+ vfs.anonBlockDevMinor[minor] = struct{}{}
+ vfs.anonBlockDevMinorNext = minor + 1
+ return minor, nil
+ }
+ minor++
+ }
+ return 0, syserror.EMFILE
+}
+
+// PutAnonBlockDevMinor deallocates a minor device number returned by a
+// previous call to GetAnonBlockDevMinor.
+func (vfs *VirtualFilesystem) PutAnonBlockDevMinor(minor uint32) {
+ vfs.anonBlockDevMinorMu.Lock()
+ defer vfs.anonBlockDevMinorMu.Unlock()
+ delete(vfs.anonBlockDevMinor, minor)
+ if minor < vfs.anonBlockDevMinorNext {
+ vfs.anonBlockDevMinorNext = minor
+ }
+}
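+
+// Illustrative use by a FilesystemImpl (compare anonFilesystem.devMinor in
+// anonfs.go): allocate a minor at filesystem creation, report it via
+// Statx.DevMinor with DevMajor == linux.UNNAMED_MAJOR, and release it with
+// PutAnonBlockDevMinor when the filesystem is released:
+//
+//	devMinor, err := vfsObj.GetAnonBlockDevMinor()
+//	if err != nil {
+//		return err
+//	}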
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
new file mode 100644
index 000000000..599c3131c
--- /dev/null
+++ b/pkg/sentry/vfs/epoll.go
@@ -0,0 +1,383 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// epollCycleMu serializes attempts to register EpollInstances with other
+// EpollInstances in order to check for cycles.
+var epollCycleMu sync.Mutex
+
+// EpollInstance represents an epoll instance, as described by epoll(7).
+type EpollInstance struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+ NoLockFD
+
+ // q holds waiters on this EpollInstance.
+ q waiter.Queue
+
+ // interest is the set of file descriptors that are registered with the
+ // EpollInstance for monitoring. interest is protected by interestMu.
+ interestMu sync.Mutex
+ interest map[epollInterestKey]*epollInterest
+
+ // mu protects fields in registered epollInterests.
+ mu sync.Mutex
+
+ // ready is the set of file descriptors that may be "ready" for I/O. Note
+ // that this must be an ordered list, not a map: "If more than maxevents
+ // file descriptors are ready when epoll_wait() is called, then successive
+ // epoll_wait() calls will round robin through the set of ready file
+ // descriptors. This behavior helps avoid starvation scenarios, where a
+ // process fails to notice that additional file descriptors are ready
+ // because it focuses on a set of file descriptors that are already known
+ // to be ready." - epoll_wait(2)
+ ready epollInterestList
+}
+
+type epollInterestKey struct {
+ // file is the registered FileDescription. No reference is held on file;
+ // instead, when the last reference is dropped, FileDescription.DecRef()
+ // removes the FileDescription from all EpollInstances. file is immutable.
+ file *FileDescription
+
+ // num is the file descriptor number with which this entry was registered.
+ // num is immutable.
+ num int32
+}
+
+// epollInterest represents an EpollInstance's interest in a file descriptor.
+type epollInterest struct {
+ // epoll is the owning EpollInstance. epoll is immutable.
+ epoll *EpollInstance
+
+ // key is the file to which this epollInterest applies. key is immutable.
+ key epollInterestKey
+
+	// waiter is registered with key.file. waiter is protected by epoll.mu.
+ waiter waiter.Entry
+
+ // mask is the event mask associated with this registration, including
+ // flags EPOLLET and EPOLLONESHOT. mask is protected by epoll.mu.
+ mask uint32
+
+ // ready is true if epollInterestEntry is linked into epoll.ready. ready
+ // and epollInterestEntry are protected by epoll.mu.
+ ready bool
+ epollInterestEntry
+
+ // userData is the struct epoll_event::data associated with this
+ // epollInterest. userData is protected by epoll.mu.
+ userData [2]int32
+}
+
+// NewEpollInstanceFD returns a FileDescription representing a new epoll
+// instance. A reference is taken on the returned FileDescription.
+func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) {
+ vd := vfs.NewAnonVirtualDentry("[eventpoll]")
+ defer vd.DecRef()
+ ep := &EpollInstance{
+ interest: make(map[epollInterestKey]*epollInterest),
+ }
+ if err := ep.vfsfd.Init(ep, linux.O_RDWR, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &ep.vfsfd, nil
+}
+
+// Release implements FileDescriptionImpl.Release.
+func (ep *EpollInstance) Release() {
+ // Unregister all polled fds.
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+ for key, epi := range ep.interest {
+ file := key.file
+ file.epollMu.Lock()
+ delete(file.epolls, epi)
+ file.epollMu.Unlock()
+ file.EventUnregister(&epi.waiter)
+ }
+ ep.interest = nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (ep *EpollInstance) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn == 0 {
+ return 0
+ }
+ ep.mu.Lock()
+ for epi := ep.ready.Front(); epi != nil; epi = epi.Next() {
+ wmask := waiter.EventMaskFromLinux(epi.mask)
+ if epi.key.file.Readiness(wmask)&wmask != 0 {
+ ep.mu.Unlock()
+ return waiter.EventIn
+ }
+ }
+ ep.mu.Unlock()
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (ep *EpollInstance) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ ep.q.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (ep *EpollInstance) EventUnregister(e *waiter.Entry) {
+ ep.q.EventUnregister(e)
+}
+
+// Seek implements FileDescriptionImpl.Seek.
+func (ep *EpollInstance) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Linux: fs/eventpoll.c:eventpoll_fops.llseek == noop_llseek
+ return 0, nil
+}
+
+// AddInterest implements the semantics of EPOLL_CTL_ADD.
+//
+// Preconditions: A reference must be held on file.
+func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
+ // Check for cyclic polling if necessary.
+ subep, _ := file.impl.(*EpollInstance)
+ if subep != nil {
+ epollCycleMu.Lock()
+ // epollCycleMu must be locked for the rest of AddInterest to ensure
+ // that cyclic polling is not introduced after the check.
+ defer epollCycleMu.Unlock()
+ if subep.mightPoll(ep) {
+ return syserror.ELOOP
+ }
+ }
+
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+
+ // Fail if the key is already registered.
+ key := epollInterestKey{
+ file: file,
+ num: num,
+ }
+ if _, ok := ep.interest[key]; ok {
+ return syserror.EEXIST
+ }
+
+ // Register interest in file.
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
+ epi := &epollInterest{
+ epoll: ep,
+ key: key,
+ mask: mask,
+ userData: event.Data,
+ }
+ epi.waiter.Callback = epi
+ ep.interest[key] = epi
+ wmask := waiter.EventMaskFromLinux(mask)
+ file.EventRegister(&epi.waiter, wmask)
+
+ // Check if the file is already ready.
+ if file.Readiness(wmask)&wmask != 0 {
+ epi.Callback(nil)
+ }
+
+ // Add epi to file.epolls so that it is removed when the last
+ // FileDescription reference is dropped.
+ file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
+ file.epolls[epi] = struct{}{}
+ file.epollMu.Unlock()
+
+ return nil
+}
+
+func (ep *EpollInstance) mightPoll(ep2 *EpollInstance) bool {
+ return ep.mightPollRecursive(ep2, 4) // Linux: fs/eventpoll.c:EP_MAX_NESTS
+}
+
+func (ep *EpollInstance) mightPollRecursive(ep2 *EpollInstance, remainingRecursion int) bool {
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+ for key := range ep.interest {
+ nextep, ok := key.file.impl.(*EpollInstance)
+ if !ok {
+ continue
+ }
+ if nextep == ep2 {
+ return true
+ }
+ if remainingRecursion == 0 {
+ return true
+ }
+ if nextep.mightPollRecursive(ep2, remainingRecursion-1) {
+ return true
+ }
+ }
+ return false
+}
+
+// ModifyInterest implements the semantics of EPOLL_CTL_MOD.
+//
+// Preconditions: A reference must be held on file.
+func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+
+ // Fail if the key is not already registered.
+ epi, ok := ep.interest[epollInterestKey{
+ file: file,
+ num: num,
+ }]
+ if !ok {
+ return syserror.ENOENT
+ }
+
+ // Update epi for the next call to ep.ReadEvents().
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
+ ep.mu.Lock()
+ epi.mask = mask
+ epi.userData = event.Data
+ ep.mu.Unlock()
+
+ // Re-register with the new mask.
+ file.EventUnregister(&epi.waiter)
+ wmask := waiter.EventMaskFromLinux(mask)
+ file.EventRegister(&epi.waiter, wmask)
+
+ // Check if the file is already ready with the new mask.
+ if file.Readiness(wmask)&wmask != 0 {
+ epi.Callback(nil)
+ }
+
+ return nil
+}
+
+// DeleteInterest implements the semantics of EPOLL_CTL_DEL.
+//
+// Preconditions: A reference must be held on file.
+func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error {
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+
+ // Fail if the key is not already registered.
+ epi, ok := ep.interest[epollInterestKey{
+ file: file,
+ num: num,
+ }]
+ if !ok {
+ return syserror.ENOENT
+ }
+
+ // Unregister from the file so that epi will no longer be readied.
+ file.EventUnregister(&epi.waiter)
+
+ // Forget about epi.
+ ep.removeLocked(epi)
+
+ file.epollMu.Lock()
+ delete(file.epolls, epi)
+ file.epollMu.Unlock()
+
+ return nil
+}
+
+// Callback implements waiter.EntryCallback.Callback.
+func (epi *epollInterest) Callback(*waiter.Entry) {
+ newReady := false
+ epi.epoll.mu.Lock()
+ if !epi.ready {
+ newReady = true
+ epi.ready = true
+ epi.epoll.ready.PushBack(epi)
+ }
+ epi.epoll.mu.Unlock()
+ if newReady {
+ epi.epoll.q.Notify(waiter.EventIn)
+ }
+}
+
+// Preconditions: ep.interestMu must be locked.
+func (ep *EpollInstance) removeLocked(epi *epollInterest) {
+ delete(ep.interest, epi.key)
+ ep.mu.Lock()
+ if epi.ready {
+ epi.ready = false
+ ep.ready.Remove(epi)
+ }
+ ep.mu.Unlock()
+}
+
+// ReadEvents reads up to len(events) ready events into events and returns the
+// number of events read.
+//
+// Preconditions: len(events) != 0.
+func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
+ i := 0
+ // Hot path: avoid defer.
+ ep.mu.Lock()
+ var next *epollInterest
+ var requeue epollInterestList
+ for epi := ep.ready.Front(); epi != nil; epi = next {
+ next = epi.Next()
+ // Regardless of what else happens, epi is initially removed from the
+ // ready list.
+ ep.ready.Remove(epi)
+ wmask := waiter.EventMaskFromLinux(epi.mask)
+ ievents := epi.key.file.Readiness(wmask) & wmask
+ if ievents == 0 {
+ // Leave epi off the ready list.
+ epi.ready = false
+ continue
+ }
+ // Determine what we should do with epi.
+ switch {
+ case epi.mask&linux.EPOLLONESHOT != 0:
+ // Clear all events from the mask; they must be re-added by
+ // EPOLL_CTL_MOD.
+ epi.mask &= linux.EP_PRIVATE_BITS
+ fallthrough
+ case epi.mask&linux.EPOLLET != 0:
+ // Leave epi off the ready list.
+ epi.ready = false
+ default:
+ // Queue epi to be moved to the end of the ready list.
+ requeue.PushBack(epi)
+ }
+ // Report ievents.
+ events[i] = linux.EpollEvent{
+ Events: ievents.ToLinux(),
+ Data: epi.userData,
+ }
+ i++
+ if i == len(events) {
+ break
+ }
+ }
+ ep.ready.PushBackList(&requeue)
+ ep.mu.Unlock()
+ return i
+}
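+
+// Illustrative use, sketching an epoll_wait(2)-style caller (maxEvents is a
+// hypothetical bound on the number of returned events):
+//
+//	events := make([]linux.EpollEvent, maxEvents)
+//	n := ep.ReadEvents(events)
+//	// Copy events[:n] out to the application; if n == 0 and the call
+//	// should block, wait on ep (a waiter.Waitable) and retry.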
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
new file mode 100644
index 000000000..0c42574db
--- /dev/null
+++ b/pkg/sentry/vfs/file_description.go
@@ -0,0 +1,837 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// A FileDescription represents an open file description, which is the entity
+// referred to by a file descriptor (POSIX.1-2017 3.258 "Open File
+// Description").
+//
+// FileDescriptions are reference-counted. Unless otherwise specified, all
+// FileDescription methods require that a reference is held.
+//
+// FileDescription is analogous to Linux's struct file.
+type FileDescription struct {
+ // refs is the reference count. refs is accessed using atomic memory
+ // operations.
+ refs int64
+
+ // flagsMu protects statusFlags and asyncHandler below.
+ flagsMu sync.Mutex
+
+ // statusFlags contains status flags, "initialized by open(2) and possibly
+ // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic
+ // memory operations when it does not need to be synchronized with an
+ // access to asyncHandler.
+ statusFlags uint32
+
+ // asyncHandler handles O_ASYNC signal generation. It is set with the
+ // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
+ // also be set by fcntl(2).
+ asyncHandler FileAsync
+
+ // epolls is the set of epollInterests registered for this FileDescription.
+ // epolls is protected by epollMu.
+ epollMu sync.Mutex
+ epolls map[*epollInterest]struct{}
+
+ // vd is the filesystem location at which this FileDescription was opened.
+ // A reference is held on vd. vd is immutable.
+ vd VirtualDentry
+
+ // opts contains options passed to FileDescription.Init(). opts is
+ // immutable.
+ opts FileDescriptionOptions
+
+ // readable is MayReadFileWithOpenFlags(statusFlags). readable is
+ // immutable.
+ //
+ // readable is analogous to Linux's FMODE_READ.
+ readable bool
+
+ // writable is MayWriteFileWithOpenFlags(statusFlags). If writable is true,
+ // the FileDescription holds a write count on vd.mount. writable is
+ // immutable.
+ //
+ // writable is analogous to Linux's FMODE_WRITE.
+ writable bool
+
+ usedLockBSD uint32
+
+ // impl is the FileDescriptionImpl associated with this Filesystem. impl is
+ // immutable. This should be the last field in FileDescription.
+ impl FileDescriptionImpl
+}
+
+// FileDescriptionOptions contains options to FileDescription.Init().
+type FileDescriptionOptions struct {
+ // If AllowDirectIO is true, allow O_DIRECT to be set on the file.
+ AllowDirectIO bool
+
+ // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE.
+ DenyPRead bool
+
+ // If DenyPWrite is true, calls to FileDescription.PWrite() return
+ // ESPIPE.
+ DenyPWrite bool
+
+ // If UseDentryMetadata is true, calls to FileDescription methods that
+ // interact with file and filesystem metadata (Stat, SetStat, StatFS,
+ // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
+ // the corresponding FilesystemImpl methods instead of the corresponding
+ // FileDescriptionImpl methods.
+ //
+ // UseDentryMetadata is intended for file descriptions that are implemented
+ // outside of individual filesystems, such as pipes, sockets, and device
+ // special files. FileDescriptions for which UseDentryMetadata is true may
+ // embed DentryMetadataFileDescriptionImpl to obtain appropriate
+ // implementations of FileDescriptionImpl methods that should not be
+ // called.
+ UseDentryMetadata bool
+}
+
+// FileCreationFlags are the set of flags passed to FileDescription.Init() but
+// omitted from FileDescription.StatusFlags().
+const FileCreationFlags = linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC
+
+// Init must be called before first use of fd. If it succeeds, it takes
+// references on mnt and d. flags is the initial file description flags, which
+// is usually the full set of flags passed to open(2).
+func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) error {
+ writable := MayWriteFileWithOpenFlags(flags)
+ if writable {
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ }
+
+ fd.refs = 1
+
+ // Remove "file creation flags" to mirror the behavior from file.f_flags in
+ // fs/open.c:do_dentry_open.
+ fd.statusFlags = flags &^ FileCreationFlags
+ fd.vd = VirtualDentry{
+ mount: mnt,
+ dentry: d,
+ }
+ mnt.IncRef()
+ d.IncRef()
+ fd.opts = *opts
+ fd.readable = MayReadFileWithOpenFlags(flags)
+ fd.writable = writable
+ fd.impl = impl
+ return nil
+}
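+
+// Illustrative use (compare EpollInstance.vfsfd in epoll.go): a
+// FileDescriptionImpl embeds a FileDescription as its first field and
+// initializes it in place:
+//
+//	if err := impl.vfsfd.Init(impl, linux.O_RDWR, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}); err != nil {
+//		return nil, err
+//	}
+//	return &impl.vfsfd, nil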
+
+// IncRef increments fd's reference count.
+func (fd *FileDescription) IncRef() {
+ atomic.AddInt64(&fd.refs, 1)
+}
+
+// TryIncRef increments fd's reference count and returns true. If fd's
+// reference count is already zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fd.
+func (fd *FileDescription) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fd.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef decrements fd's reference count.
+func (fd *FileDescription) DecRef() {
+ if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 {
+ // Unregister fd from all epoll instances.
+ fd.epollMu.Lock()
+ epolls := fd.epolls
+ fd.epolls = nil
+ fd.epollMu.Unlock()
+ for epi := range epolls {
+ ep := epi.epoll
+ ep.interestMu.Lock()
+ // Check that epi has not been concurrently unregistered by
+ // EpollInstance.DeleteInterest() or EpollInstance.Release().
+ if _, ok := ep.interest[epi.key]; ok {
+ fd.EventUnregister(&epi.waiter)
+ ep.removeLocked(epi)
+ }
+ ep.interestMu.Unlock()
+ }
+
+ // If BSD locks were used, release any lock that it may have acquired.
+ if atomic.LoadUint32(&fd.usedLockBSD) != 0 {
+ fd.impl.UnlockBSD(context.Background(), fd)
+ }
+
+ // Release implementation resources.
+ fd.impl.Release()
+ if fd.writable {
+ fd.vd.mount.EndWrite()
+ }
+ fd.vd.DecRef()
+ fd.flagsMu.Lock()
+ // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Unregister(fd)
+ }
+ fd.asyncHandler = nil
+ fd.flagsMu.Unlock()
+ } else if refs < 0 {
+ panic("FileDescription.DecRef() called without holding a reference")
+ }
+}
+
+// Refs returns the current number of references. The returned count
+// is inherently racy and is unsafe to use without external synchronization.
+func (fd *FileDescription) Refs() int64 {
+ return atomic.LoadInt64(&fd.refs)
+}
+
+// Mount returns the mount on which fd was opened. It does not take a reference
+// on the returned Mount.
+func (fd *FileDescription) Mount() *Mount {
+ return fd.vd.mount
+}
+
+// Dentry returns the dentry at which fd was opened. It does not take a
+// reference on the returned Dentry.
+func (fd *FileDescription) Dentry() *Dentry {
+ return fd.vd.dentry
+}
+
+// VirtualDentry returns the location at which fd was opened. It does not take
+// a reference on the returned VirtualDentry.
+func (fd *FileDescription) VirtualDentry() VirtualDentry {
+ return fd.vd
+}
+
+// Options returns the options passed to fd.Init().
+func (fd *FileDescription) Options() FileDescriptionOptions {
+ return fd.opts
+}
+
+// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
+func (fd *FileDescription) StatusFlags() uint32 {
+ return atomic.LoadUint32(&fd.statusFlags)
+}
+
+// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL).
+func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error {
+ // Compare Linux's fs/fcntl.c:setfl().
+ oldFlags := fd.StatusFlags()
+ // Linux documents this check as "O_APPEND cannot be cleared if the file is
+ // marked as append-only and the file is open for write", which would make
+ // sense. However, the check as actually implemented seems to be "O_APPEND
+ // cannot be changed if the file is marked as append-only".
+ if (flags^oldFlags)&linux.O_APPEND != 0 {
+ stat, err := fd.Stat(ctx, StatOptions{
+ // There is no mask bit for stx_attributes.
+ Mask: 0,
+ // Linux just reads inode::i_flags directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) {
+ return syserror.EPERM
+ }
+ }
+ if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) {
+ stat, err := fd.Stat(ctx, StatOptions{
+ Mask: linux.STATX_UID,
+ // Linux's inode_owner_or_capable() just reads inode::i_uid
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if stat.Mask&linux.STATX_UID == 0 {
+ return syserror.EPERM
+ }
+ if !CanActAsOwner(creds, auth.KUID(stat.UID)) {
+ return syserror.EPERM
+ }
+ }
+ if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO {
+ return syserror.EINVAL
+ }
+ // TODO(jamieliu): FileDescriptionImpl.SetOAsync()?
+ const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK
+ fd.flagsMu.Lock()
+ if fd.asyncHandler != nil {
+ // Use fd.statusFlags instead of oldFlags, which may have become outdated,
+ // to avoid double registering/unregistering.
+ if fd.statusFlags&linux.O_ASYNC == 0 && flags&linux.O_ASYNC != 0 {
+ fd.asyncHandler.Register(fd)
+ } else if fd.statusFlags&linux.O_ASYNC != 0 && flags&linux.O_ASYNC == 0 {
+ fd.asyncHandler.Unregister(fd)
+ }
+ }
+ fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags)
+ fd.flagsMu.Unlock()
+ return nil
+}
+
+// IsReadable returns true if fd was opened for reading.
+func (fd *FileDescription) IsReadable() bool {
+ return fd.readable
+}
+
+// IsWritable returns true if fd was opened for writing.
+func (fd *FileDescription) IsWritable() bool {
+ return fd.writable
+}
+
+// Impl returns the FileDescriptionImpl associated with fd.
+func (fd *FileDescription) Impl() FileDescriptionImpl {
+ return fd.impl
+}
+
+// FileDescriptionImpl contains implementation details for a FileDescription.
+// Implementations of FileDescriptionImpl should contain their associated
+// FileDescription by value as their first field.
+//
+// For all functions that return linux.Statx, Statx.Uid and Statx.Gid will
+// be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID and
+// auth.KGID respectively).
+//
+// All methods may return errors other than those specified.
+//
+// FileDescriptionImpl is analogous to Linux's struct file_operations.
+type FileDescriptionImpl interface {
+ // Release is called when the associated FileDescription reaches zero
+ // references.
+ Release()
+
+ // OnClose is called when a file descriptor representing the
+ // FileDescription is closed. Note that returning a non-nil error does not
+ // prevent the file descriptor from being closed.
+ OnClose(ctx context.Context) error
+
+ // Stat returns metadata for the file represented by the FileDescription.
+ Stat(ctx context.Context, opts StatOptions) (linux.Statx, error)
+
+ // SetStat updates metadata for the file represented by the
+ // FileDescription. Implementations are responsible for checking if the
+ // operation can be performed (see vfs.CheckSetStat() for common checks).
+ SetStat(ctx context.Context, opts SetStatOptions) error
+
+ // StatFS returns metadata for the filesystem containing the file
+ // represented by the FileDescription.
+ StatFS(ctx context.Context) (linux.Statfs, error)
+
+	// Allocate grows the file represented by the FileDescription to
+	// offset + length bytes. Only mode == 0 is currently supported.
+ Allocate(ctx context.Context, mode, offset, length uint64) error
+
+ // waiter.Waitable methods may be used to poll for I/O events.
+ waiter.Waitable
+
+ // PRead reads from the file into dst, starting at the given offset, and
+ // returns the number of bytes read. PRead is permitted to return partial
+ // reads with a nil error.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for reading.
+ // FileDescriptionOptions.DenyPRead == false.
+ PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error)
+
+ // Read is similar to PRead, but does not specify an offset.
+ //
+ // For files with an implicit FileDescription offset (e.g. regular files),
+ // Read begins at the FileDescription offset, and advances the offset by
+ // the number of bytes read; note that POSIX 2.9.7 "Thread Interactions
+ // with Regular File Operations" requires that all operations that may
+ // mutate the FileDescription offset are serialized.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, Read returns EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for reading.
+ Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error)
+
+ // PWrite writes src to the file, starting at the given offset, and returns
+ // the number of bytes written. PWrite is permitted to return partial
+ // writes with a nil error.
+ //
+ // As in Linux (but not POSIX), if O_APPEND is in effect for the
+ // FileDescription, PWrite should ignore the offset and append data to the
+ // end of the file.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, PWrite returns
+ // EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for writing.
+ // FileDescriptionOptions.DenyPWrite == false.
+ PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error)
+
+ // Write is similar to PWrite, but does not specify an offset, which is
+ // implied as for Read.
+ //
+ // Write is a FileDescriptionImpl method, instead of a wrapper around
+ // PWrite that uses a FileDescription offset, to make it possible for
+ // remote filesystems to implement O_APPEND correctly (i.e. atomically with
+ // respect to writers outside the scope of VFS).
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, Write returns EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for writing.
+ Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error)
+
+ // IterDirents invokes cb on each entry in the directory represented by the
+ // FileDescription. If IterDirents has been called since the last call to
+ // Seek, it continues iteration from the end of the last call.
+ IterDirents(ctx context.Context, cb IterDirentsCallback) error
+
+ // Seek changes the FileDescription offset (assuming one exists) and
+ // returns its new value.
+ //
+ // For directories, if whence == SEEK_SET and offset == 0, the caller is
+ // rewinddir(), such that Seek "shall also cause the directory stream to
+ // refer to the current state of the corresponding directory" -
+ // POSIX.1-2017.
+ Seek(ctx context.Context, offset int64, whence int32) (int64, error)
+
+ // Sync requests that cached state associated with the file represented by
+ // the FileDescription is synchronized with persistent storage, and blocks
+ // until this is complete.
+ Sync(ctx context.Context) error
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file. Most
+ // implementations that support memory mapping can call
+ // GenericConfigureMMap with the appropriate memmap.Mappable.
+ ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error
+
+ // Ioctl implements the ioctl(2) syscall.
+ Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
+
+ // Listxattr returns all extended attribute names for the file.
+ Listxattr(ctx context.Context, size uint64) ([]string, error)
+
+ // Getxattr returns the value associated with the given extended attribute
+ // for the file.
+ Getxattr(ctx context.Context, opts GetxattrOptions) (string, error)
+
+ // Setxattr changes the value associated with the given extended attribute
+ // for the file.
+ Setxattr(ctx context.Context, opts SetxattrOptions) error
+
+ // Removexattr removes the given extended attribute from the file.
+ Removexattr(ctx context.Context, name string) error
+
+ // LockBSD tries to acquire a BSD-style advisory file lock.
+ LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
+
+ // UnlockBSD releases a BSD-style advisory file lock.
+ UnlockBSD(ctx context.Context, uid lock.UniqueID) error
+
+ // LockPOSIX tries to acquire a POSIX-style advisory file lock.
+ LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error
+
+ // UnlockPOSIX releases a POSIX-style advisory file lock.
+ UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error
+}
+
+// Dirent holds the information contained in struct linux_dirent64.
+type Dirent struct {
+ // Name is the filename.
+ Name string
+
+ // Type is the file type, a linux.DT_* constant.
+ Type uint8
+
+ // Ino is the inode number.
+ Ino uint64
+
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
+}
+
+// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
+type IterDirentsCallback interface {
+ // Handle handles the given iterated Dirent. If Handle returns a non-nil
+ // error, FileDescriptionImpl.IterDirents must stop iteration and return
+ // the error; the next call to FileDescriptionImpl.IterDirents should
+ // restart with the same Dirent.
+ Handle(dirent Dirent) error
+}
+
+// IterDirentsCallbackFunc implements IterDirentsCallback for a function with
+// the semantics of IterDirentsCallback.Handle.
+type IterDirentsCallbackFunc func(dirent Dirent) error
+
+// Handle implements IterDirentsCallback.Handle.
+func (f IterDirentsCallbackFunc) Handle(dirent Dirent) error {
+ return f(dirent)
+}
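+
+// As a hypothetical example, a caller might count directory entries by
+// wrapping a closure in IterDirentsCallbackFunc:
+//
+// var count int
+// err := fd.IterDirents(ctx, IterDirentsCallbackFunc(func(dirent Dirent) error {
+// count++
+// return nil
+// }))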
+
+// OnClose is called when a file descriptor representing the FileDescription is
+// closed. Returning a non-nil error should not prevent the file descriptor
+// from being closed.
+func (fd *FileDescription) OnClose(ctx context.Context) error {
+ return fd.impl.OnClose(ctx)
+}
+
+// Stat returns metadata for the file represented by fd.
+func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return stat, err
+ }
+ return fd.impl.Stat(ctx, opts)
+}
+
+// SetStat updates metadata for the file represented by fd.
+func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
+ return fd.impl.SetStat(ctx, opts)
+}
+
+// StatFS returns metadata for the filesystem containing the file represented
+// by fd.
+func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
+ vfsObj.putResolvingPath(rp)
+ return statfs, err
+ }
+ return fd.impl.StatFS(ctx)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// It returns fd's I/O readiness.
+func (fd *FileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fd.impl.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+//
+// It registers e for I/O readiness events in mask.
+func (fd *FileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.impl.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+//
+// It unregisters e for I/O readiness events.
+func (fd *FileDescription) EventUnregister(e *waiter.Entry) {
+ fd.impl.EventUnregister(e)
+}
+
+// PRead reads from the file represented by fd into dst, starting at the given
+// offset, and returns the number of bytes read. PRead is permitted to return
+// partial reads with a nil error.
+func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ if fd.opts.DenyPRead {
+ return 0, syserror.ESPIPE
+ }
+ if !fd.readable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.PRead(ctx, dst, offset, opts)
+}
+
+// Read is similar to PRead, but does not specify an offset.
+func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.Read(ctx, dst, opts)
+}
+
+// PWrite writes src to the file represented by fd, starting at the given
+// offset, and returns the number of bytes written. PWrite is permitted to
+// return partial writes with a nil error.
+func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ if fd.opts.DenyPWrite {
+ return 0, syserror.ESPIPE
+ }
+ if !fd.writable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.PWrite(ctx, src, offset, opts)
+}
+
+// Write is similar to PWrite, but does not specify an offset.
+func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.Write(ctx, src, opts)
+}
+
+// IterDirents invokes cb on each entry in the directory represented by fd. If
+// IterDirents has been called since the last call to Seek, it continues
+// iteration from the end of the last call.
+func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
+ return fd.impl.IterDirents(ctx, cb)
+}
+
+// Seek changes fd's offset (assuming one exists) and returns its new value.
+func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.impl.Seek(ctx, offset, whence)
+}
+
+// Sync has the semantics of fsync(2).
+func (fd *FileDescription) Sync(ctx context.Context) error {
+ return fd.impl.Sync(ctx)
+}
+
+// ConfigureMMap mutates opts to implement mmap(2) for the file represented by
+// fd.
+func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.impl.ConfigureMMap(ctx, opts)
+}
+
+// Ioctl implements the ioctl(2) syscall.
+func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.impl.Ioctl(ctx, uio, args)
+}
+
+// Listxattr returns all extended attribute names for the file represented by
+// fd.
+//
+// If the size of the list (including a NUL terminating byte after every
+// entry) would exceed size, ERANGE may be returned (note that implementations
+// are free to ignore size entirely and return without error). In all cases,
+// if size is 0, the list should be returned without error, regardless of size.
+func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size)
+ vfsObj.putResolvingPath(rp)
+ return names, err
+ }
+ names, err := fd.impl.Listxattr(ctx, size)
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by default
+ // don't exist.
+ return nil, nil
+ }
+ return names, err
+}
+
+// Getxattr returns the value associated with the given extended attribute for
+// the file represented by fd.
+//
+// If the size of the return value exceeds opts.Size, ERANGE may be returned
+// (note that implementations are free to ignore opts.Size entirely and return
+// without error). In all cases, if opts.Size is 0, the value should be
+// returned without error, regardless of size.
+func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
+ vfsObj.putResolvingPath(rp)
+ return val, err
+ }
+ return fd.impl.Getxattr(ctx, *opts)
+}
+
+// Setxattr changes the value associated with the given extended attribute for
+// the file represented by fd.
+func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
+ return fd.impl.Setxattr(ctx, *opts)
+}
+
+// Removexattr removes the given extended attribute from the file represented
+// by fd.
+func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
+ return fd.impl.Removexattr(ctx, name)
+}
+
+// SyncFS instructs the filesystem containing fd to execute the semantics of
+// syncfs(2).
+func (fd *FileDescription) SyncFS(ctx context.Context) error {
+ return fd.vd.mount.fs.impl.Sync(ctx)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (fd *FileDescription) MappedName(ctx context.Context) string {
+ vfsroot := RootFromContext(ctx)
+ s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd)
+ if vfsroot.Ok() {
+ vfsroot.DecRef()
+ }
+ return s
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (fd *FileDescription) DeviceID() uint64 {
+ stat, err := fd.Stat(context.Background(), StatOptions{
+ // There is no STATX_DEV; we assume that Stat will return it if it's
+ // available regardless of mask.
+ Mask: 0,
+ // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_sb->s_dev
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return 0
+ }
+ return uint64(linux.MakeDeviceID(uint16(stat.DevMajor), stat.DevMinor))
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (fd *FileDescription) InodeID() uint64 {
+ stat, err := fd.Stat(context.Background(), StatOptions{
+ Mask: linux.STATX_INO,
+ // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil || stat.Mask&linux.STATX_INO == 0 {
+ return 0
+ }
+ return stat.Ino
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return fd.Sync(ctx)
+}
+
+// LockBSD tries to acquire a BSD-style advisory file lock.
+func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error {
+ atomic.StoreUint32(&fd.usedLockBSD, 1)
+ return fd.impl.LockBSD(ctx, fd, lockType, blocker)
+}
+
+// UnlockBSD releases a BSD-style advisory file lock.
+func (fd *FileDescription) UnlockBSD(ctx context.Context) error {
+ return fd.impl.UnlockBSD(ctx, fd)
+}
+
+// LockPOSIX locks a POSIX-style file range lock.
+func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error {
+ return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block)
+}
+
+// UnlockPOSIX unlocks a POSIX-style file range lock.
+func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error {
+ return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence)
+}
+
+// A FileAsync sends signals to its owner when the registered file is ready
+// for IO. This is only implemented by pkg/sentry/fasync:FileAsync, but we
+// unfortunately need this interface to avoid circular dependencies.
+type FileAsync interface {
+ Register(w waiter.Waitable)
+ Unregister(w waiter.Waitable)
+}
+
+// AsyncHandler returns the FileAsync for fd.
+func (fd *FileDescription) AsyncHandler() FileAsync {
+ fd.flagsMu.Lock()
+ defer fd.flagsMu.Unlock()
+ return fd.asyncHandler
+}
+
+// SetAsyncHandler sets fd.asyncHandler if it has not been set before and
+// returns it.
+func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsync {
+ fd.flagsMu.Lock()
+ defer fd.flagsMu.Unlock()
+ if fd.asyncHandler == nil {
+ fd.asyncHandler = newHandler()
+ if fd.statusFlags&linux.O_ASYNC != 0 {
+ fd.asyncHandler.Register(fd)
+ }
+ }
+ return fd.asyncHandler
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
new file mode 100644
index 000000000..6b8b4ad49
--- /dev/null
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -0,0 +1,428 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// The following design pattern is strongly recommended for filesystem
+// implementations to adopt:
+// - Have a local fileDescription struct (containing FileDescription) which
+// embeds FileDescriptionDefaultImpl and overrides the default methods that
+// are common to all fd implementations for that filesystem, such as
+// StatusFlags, SetStatusFlags, Stat, SetStat, and StatFS. A sketch of this
+// pattern follows this comment.
+// - This struct should be embedded in all file description implementations
+// as the first field by value.
+// - Directory FDs would additionally embed DirectoryFileDescriptionDefaultImpl.
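+//
+// A minimal, hypothetical sketch of this pattern (the "foo" names are
+// illustrative and not part of this package):
+//
+// // fileDescription is the common fd struct for fooFS.
+// type fileDescription struct {
+// vfsfd FileDescription
+// FileDescriptionDefaultImpl
+// }
+//
+// // directoryFD represents an open fooFS directory.
+// type directoryFD struct {
+// fileDescription
+// DirectoryFileDescriptionDefaultImpl
+// }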
+
+// FileDescriptionDefaultImpl may be embedded by implementations of
+// FileDescriptionImpl to obtain implementations of many FileDescriptionImpl
+// methods with default behavior analogous to Linux's.
+type FileDescriptionDefaultImpl struct{}
+
+// OnClose implements FileDescriptionImpl.OnClose analogously to
+// file_operations::flush == NULL in Linux.
+func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error {
+ return nil
+}
+
+// StatFS implements FileDescriptionImpl.StatFS analogously to
+// super_operations::statfs == NULL in Linux.
+func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, error) {
+ return linux.Statfs{}, syserror.ENOSYS
+}
+
+// Allocate implements FileDescriptionImpl.Allocate analogously to
+// fallocate called on a regular file, directory, or FIFO in Linux.
+func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
+// Readiness implements waiter.Waitable.Readiness analogously to
+// file_operations::poll == NULL in Linux.
+func (FileDescriptionDefaultImpl) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // include/linux/poll.h:vfs_poll() => DEFAULT_POLLMASK
+ return waiter.EventIn | waiter.EventOut
+}
+
+// EventRegister implements waiter.Waitable.EventRegister analogously to
+// file_operations::poll == NULL in Linux.
+func (FileDescriptionDefaultImpl) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister analogously to
+// file_operations::poll == NULL in Linux.
+func (FileDescriptionDefaultImpl) EventUnregister(e *waiter.Entry) {
+}
+
+// PRead implements FileDescriptionImpl.PRead analogously to
+// file_operations::read == file_operations::read_iter == NULL in Linux.
+func (FileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// Read implements FileDescriptionImpl.Read analogously to
+// file_operations::read == file_operations::read_iter == NULL in Linux.
+func (FileDescriptionDefaultImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// PWrite implements FileDescriptionImpl.PWrite analogously to
+// file_operations::write == file_operations::write_iter == NULL in Linux.
+func (FileDescriptionDefaultImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// Write implements FileDescriptionImpl.Write analogously to
+// file_operations::write == file_operations::write_iter == NULL in Linux.
+func (FileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// IterDirents implements FileDescriptionImpl.IterDirents analogously to
+// file_operations::iterate == file_operations::iterate_shared == NULL in
+// Linux.
+func (FileDescriptionDefaultImpl) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
+ return syserror.ENOTDIR
+}
+
+// Seek implements FileDescriptionImpl.Seek analogously to
+// file_operations::llseek == NULL in Linux.
+func (FileDescriptionDefaultImpl) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Sync implements FileDescriptionImpl.Sync analogously to
+// file_operations::fsync == NULL in Linux.
+func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error {
+ return syserror.EINVAL
+}
+
+// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to
+// file_operations::mmap == NULL in Linux.
+func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return syserror.ENODEV
+}
+
+// Ioctl implements FileDescriptionImpl.Ioctl analogously to
+// file_operations::unlocked_ioctl == NULL in Linux.
+func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return 0, syserror.ENOTTY
+}
+
+// Listxattr implements FileDescriptionImpl.Listxattr analogously to
+// inode_operations::listxattr == NULL in Linux.
+func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ // This isn't exactly accurate; see FileDescription.Listxattr.
+ return nil, syserror.ENOTSUP
+}
+
+// Getxattr implements FileDescriptionImpl.Getxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) {
+ return "", syserror.ENOTSUP
+}
+
+// Setxattr implements FileDescriptionImpl.Setxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ return syserror.ENOTSUP
+}
+
+// Removexattr implements FileDescriptionImpl.Removexattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error {
+ return syserror.ENOTSUP
+}
+
+// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
+// FileDescriptionImpl that always represent directories to obtain
+// implementations of non-directory I/O methods that return EISDIR.
+type DirectoryFileDescriptionDefaultImpl struct{}
+
+// Allocate implements FileDescriptionImpl.Allocate.
+func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.EISDIR
+}
+
+// PRead implements FileDescriptionImpl.PRead.
+func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (DirectoryFileDescriptionDefaultImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// PWrite implements FileDescriptionImpl.PWrite.
+func (DirectoryFileDescriptionDefaultImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// DentryMetadataFileDescriptionImpl may be embedded by implementations of
+// FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is
+// true to obtain implementations of Stat and SetStat that panic.
+type DentryMetadataFileDescriptionImpl struct{}
+
+// Stat implements FileDescriptionImpl.Stat.
+func (DentryMetadataFileDescriptionImpl) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.Stat")
+}
+
+// SetStat implements FileDescriptionImpl.SetStat.
+func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetStatOptions) error {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.SetStat")
+}
+
+// DynamicBytesSource represents a data source for a
+// DynamicBytesFileDescriptionImpl.
+type DynamicBytesSource interface {
+ // Generate writes the file's contents to buf.
+ Generate(ctx context.Context, buf *bytes.Buffer) error
+}
+
+// StaticData implements DynamicBytesSource over a static string.
+type StaticData struct {
+ Data string
+}
+
+// Generate implements DynamicBytesSource.
+func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(s.Data)
+ return nil
+}
+
+// WritableDynamicBytesSource extends DynamicBytesSource to allow writes to the
+// underlying source.
+type WritableDynamicBytesSource interface {
+ DynamicBytesSource
+
+ // Write sends writes to the source.
+ Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error)
+}
+
+// DynamicBytesFileDescriptionImpl may be embedded by implementations of
+// FileDescriptionImpl that represent read-only regular files whose contents
+// are backed by a bytes.Buffer that is regenerated when necessary, consistent
+// with Linux's fs/seq_file.c:single_open().
+//
+// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
+// use.
+type DynamicBytesFileDescriptionImpl struct {
+ data DynamicBytesSource // immutable
+ mu sync.Mutex // protects the following fields
+ buf bytes.Buffer
+ off int64
+ lastRead int64 // offset at which the last Read, PRead, or Seek ended
+}
+
+// SetDataSource must be called exactly once on fd before first use.
+func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
+ fd.data = data
+}
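+
+// As a hypothetical example, a read-only file backed by a fixed string can
+// be wired up via StaticData (assuming fd embeds
+// DynamicBytesFileDescriptionImpl):
+//
+// fd.SetDataSource(&StaticData{Data: "hello\n"})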
+
+// Preconditions: fd.mu must be locked.
+func (fd *DynamicBytesFileDescriptionImpl) preadLocked(ctx context.Context, dst usermem.IOSequence, offset int64, opts *ReadOptions) (int64, error) {
+ // Regenerate the buffer if it's empty, or before pread() at a new offset.
+ // Compare fs/seq_file.c:seq_read() => traverse().
+ switch {
+ case offset != fd.lastRead:
+ fd.buf.Reset()
+ fallthrough
+ case fd.buf.Len() == 0:
+ if err := fd.data.Generate(ctx, &fd.buf); err != nil {
+ fd.buf.Reset()
+ // fd.off is not updated in this case.
+ fd.lastRead = 0
+ return 0, err
+ }
+ }
+ bs := fd.buf.Bytes()
+ if offset >= int64(len(bs)) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, bs[offset:])
+ fd.lastRead = offset + int64(n)
+ return int64(n), err
+}
+
+// PRead implements FileDescriptionImpl.PRead.
+func (fd *DynamicBytesFileDescriptionImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.preadLocked(ctx, dst, offset, &opts)
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (fd *DynamicBytesFileDescriptionImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.preadLocked(ctx, dst, fd.off, &opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Seek implements FileDescriptionImpl.Seek.
+func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ default:
+ // fs/seq_file:seq_lseek() rejects SEEK_END etc.
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset != fd.lastRead {
+ // Regenerate the file's contents immediately. Compare
+ // fs/seq_file.c:seq_lseek() => traverse().
+ fd.buf.Reset()
+ if err := fd.data.Generate(ctx, &fd.buf); err != nil {
+ fd.buf.Reset()
+ fd.off = 0
+ fd.lastRead = 0
+ return 0, err
+ }
+ fd.lastRead = offset
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Preconditions: fd.mu must be locked.
+func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+ limit, err := CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
+
+ writable, ok := fd.data.(WritableDynamicBytesSource)
+ if !ok {
+ return 0, syserror.EIO
+ }
+ n, err := writable.Write(ctx, src, offset)
+ if err != nil {
+ return 0, err
+ }
+
+ // Invalidate cached data that might exist prior to this call.
+ fd.buf.Reset()
+ return n, nil
+}
+
+// PWrite implements FileDescriptionImpl.PWrite.
+func (fd *DynamicBytesFileDescriptionImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.pwriteLocked(ctx, src, offset, opts)
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (fd *DynamicBytesFileDescriptionImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.pwriteLocked(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// GenericConfigureMMap may be used by most implementations of
+// FileDescriptionImpl.ConfigureMMap.
+func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ opts.Mappable = m
+ opts.MappingIdentity = fd
+ fd.IncRef()
+ return nil
+}
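+
+// A hypothetical use in a FileDescriptionImpl whose contents are backed by a
+// memmap.Mappable (regularFileFD and fd.mappable are illustrative names, not
+// part of this package):
+//
+// func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+// return GenericConfigureMMap(&fd.vfsfd, fd.mappable, opts)
+// }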
+
+// LockFD may be used by most implementations of the FileDescriptionImpl.Lock*
+// methods. Callers must call Init() before first use.
+type LockFD struct {
+ locks *FileLocks
+}
+
+// Init initializes fd with FileLocks to use.
+func (fd *LockFD) Init(locks *FileLocks) {
+ fd.locks = locks
+}
+
+// Locks returns the locks associated with this file.
+func (fd *LockFD) Locks() *FileLocks {
+ return fd.locks
+}
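+
+// A hypothetical sketch of embedding LockFD; the FileLocks would typically
+// live on a longer-lived object (e.g. an inode) so that all fds for the same
+// file share one lock set:
+//
+// type fileDescription struct {
+// vfsfd FileDescription
+// FileDescriptionDefaultImpl
+// LockFD
+// }
+//
+// fd.LockFD.Init(&inode.locks)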
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return fd.locks.LockBSD(uid, t, block)
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ fd.locks.UnlockBSD(uid)
+ return nil
+}
+
+// NoLockFD implements the Lock*/Unlock* portion of the FileDescriptionImpl
+// interface by returning ENOLCK.
+type NoLockFD struct{}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ return syserror.ENOLCK
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return syserror.ENOLCK
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
new file mode 100644
index 000000000..3b7e1c273
--- /dev/null
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -0,0 +1,224 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fileDescription is the common fd struct which a filesystem implementation
+// embeds in all of its file description implementations as required.
+type fileDescription struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ NoLockFD
+}
+
+// genCount contains the number of times its DynamicBytesSource.Generate()
+// implementation has been called.
+type genCount struct {
+ count uint64 // accessed using atomic memory ops
+}
+
+// Generate implements DynamicBytesSource.Generate.
+func (g *genCount) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d", atomic.AddUint64(&g.count, 1))
+ return nil
+}
+
+type storeData struct {
+ data string
+}
+
+var _ WritableDynamicBytesSource = (*storeData)(nil)
+
+// Generate implements DynamicBytesSource.
+func (d *storeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(d.data)
+ return nil
+}
+
+// Write implements WritableDynamicBytesSource.Write.
+//
+// Only writes at offset 0 are supported, consistent with the EINVAL
+// expectations in TestWritable below.
+func (d *storeData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, syserror.EINVAL
+ }
+ buf := make([]byte, src.NumBytes())
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+
+ d.data = string(buf[:n])
+ return int64(n), nil
+}
+
+// testFD is a read-only FileDescriptionImpl representing a regular file.
+type testFD struct {
+ fileDescription
+ DynamicBytesFileDescriptionImpl
+
+ data DynamicBytesSource
+}
+
+func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription {
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef()
+ var fd testFD
+ fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
+ fd.DynamicBytesFileDescriptionImpl.SetDataSource(data)
+ return &fd.vfsfd
+}
+
+// Release implements FileDescriptionImpl.Release.
+func (fd *testFD) Release() {
+}
+
+// Stat implements FileDescriptionImpl.Stat.
+func (fd *testFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ // Note that Statx.Mask == 0 in the return value.
+ return linux.Statx{}, nil
+}
+
+// SetStat implements FileDescriptionImpl.SetStat.
+func (fd *testFD) SetStat(ctx context.Context, opts SetStatOptions) error {
+ return syserror.EPERM
+}
+
+func TestGenCountFD(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{})
+ defer fd.DecRef()
+
+ // The first read causes Generate to be called to fill the FD's buffer.
+ buf := make([]byte, 2)
+ ioseq := usermem.BytesIOSequence(buf)
+ n, err := fd.Read(ctx, ioseq, ReadOptions{})
+ if n != 1 || (err != nil && err != io.EOF) {
+ t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err)
+ }
+ if want := byte('1'); buf[0] != want {
+ t.Errorf("first Read: got byte %c, wanted %c", buf[0], want)
+ }
+
+ // A second read without seeking is still at EOF.
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
+ if n != 0 || err != io.EOF {
+ t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err)
+ }
+
+ // Seeking to the beginning of the file causes it to be regenerated.
+ n, err = fd.Seek(ctx, 0, linux.SEEK_SET)
+ if n != 0 || err != nil {
+ t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err)
+ }
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
+ if n != 1 || (err != nil && err != io.EOF) {
+ t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err)
+ }
+ if want := byte('2'); buf[0] != want {
+ t.Errorf("Read after Seek: got byte %c, wanted %c", buf[0], want)
+ }
+
+ // PRead at the beginning of the file also causes it to be regenerated.
+ n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{})
+ if n != 1 || (err != nil && err != io.EOF) {
+ t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err)
+ }
+ if want := byte('3'); buf[0] != want {
+ t.Errorf("PRead: got byte %c, wanted %c", buf[0], want)
+ }
+
+ // Write and PWrite fail, since genCount is not a WritableDynamicBytesSource.
+ if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO {
+ t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
+ }
+ if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO {
+ t.Errorf("PWrite: got err %v, wanted %v", err, syserror.EIO)
+ }
+ }
+}
+
+func TestWritable(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"})
+ defer fd.DecRef()
+
+ buf := make([]byte, 10)
+ ioseq := usermem.BytesIOSequence(buf)
+ if n, err := fd.Read(ctx, ioseq, ReadOptions{}); n != 4 || (err != nil && err != io.EOF) {
+ t.Fatalf("Read: got (%v, %v), wanted (4, nil or EOF)", n, err)
+ }
+ if want := "init"; want != string(buf[:len(want)]) {
+ t.Fatalf("Read: got %q, wanted %q", string(buf[:len(want)]), want)
+ }
+
+ // Test PWrite.
+ want := "write"
+ writeIOSeq := usermem.BytesIOSequence([]byte(want))
+ if n, err := fd.PWrite(ctx, writeIOSeq, 0, WriteOptions{}); int(n) != len(want) || err != nil {
+ t.Errorf("PWrite: got (%v, %v), wanted (%v, nil)", n, err, len(want))
+ }
+ if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) || (err != nil && err != io.EOF) {
+ t.Fatalf("PRead: got (%v, %v), wanted (%v, nil or EOF)", n, err, len(want))
+ }
+ if want != string(buf[:len(want)]) {
+ t.Fatalf("PRead: got %q, wanted %q", string(buf[:len(want)]), want)
+ }
+
+ // Test Seek to 0 followed by Write.
+ want = "write2"
+ writeIOSeq = usermem.BytesIOSequence([]byte(want))
+ if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
+ t.Errorf("Seek: got (%v, %v), wanted (0, nil)", n, err)
+ }
+ if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); int(n) != len(want) || err != nil {
+ t.Errorf("Write: got (%v, %v), wanted (%v, nil)", n, err, len(want))
+ }
+ if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) || (err != nil && err != io.EOF) {
+ t.Fatalf("PRead: got (%v, %v), wanted (%v, nil or EOF)", n, err, len(want))
+ }
+ if want != string(buf[:len(want)]) {
+ t.Fatalf("PRead: got %q, wanted %q", string(buf[:len(want)]), want)
+ }
+
+ // Test failure if offset != 0.
+ if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 1 || err != nil {
+ t.Errorf("Seek: got (%v, %v), wanted (1, nil)", n, err)
+ }
+ if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 || err != syserror.EINVAL {
+ t.Errorf("Write: got (%v, %v), wanted (0, EINVAL)", n, err)
+ }
+ if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 || err != syserror.EINVAL {
+ t.Errorf("PWrite: got (%v, %v), wanted (0, EINVAL)", n, err)
+ }
+}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
new file mode 100644
index 000000000..6bb9ca180
--- /dev/null
+++ b/pkg/sentry/vfs/filesystem.go
@@ -0,0 +1,556 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// A Filesystem is a tree of nodes represented by Dentries, which forms part of
+// a VirtualFilesystem.
+//
+// Filesystems are reference-counted. Unless otherwise specified, all
+// Filesystem methods require that a reference is held.
+//
+// Filesystem is analogous to Linux's struct super_block.
+//
+// +stateify savable
+type Filesystem struct {
+ // refs is the reference count. refs is accessed using atomic memory
+ // operations.
+ refs int64
+
+ // vfs is the VirtualFilesystem that uses this Filesystem. vfs is
+ // immutable.
+ vfs *VirtualFilesystem
+
+ // fsType is the FilesystemType of this Filesystem.
+ fsType FilesystemType
+
+ // impl is the FilesystemImpl associated with this Filesystem. impl is
+ // immutable. This should be the last field in Filesystem.
+ impl FilesystemImpl
+}
+
+// Init must be called before first use of fs.
+func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) {
+ fs.refs = 1
+ fs.vfs = vfsObj
+ fs.fsType = fsType
+ fs.impl = impl
+ vfsObj.filesystemsMu.Lock()
+ vfsObj.filesystems[fs] = struct{}{}
+ vfsObj.filesystemsMu.Unlock()
+}
+
+// FilesystemType returns the FilesystemType for this Filesystem.
+func (fs *Filesystem) FilesystemType() FilesystemType {
+ return fs.fsType
+}
+
+// VirtualFilesystem returns the containing VirtualFilesystem.
+func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem {
+ return fs.vfs
+}
+
+// Impl returns the FilesystemImpl associated with fs.
+func (fs *Filesystem) Impl() FilesystemImpl {
+ return fs.impl
+}
+
+// IncRef increments fs' reference count.
+func (fs *Filesystem) IncRef() {
+ if atomic.AddInt64(&fs.refs, 1) <= 1 {
+ panic("Filesystem.IncRef() called without holding a reference")
+ }
+}
+
+// TryIncRef increments fs' reference count and returns true. If fs' reference
+// count is zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fs.
+func (fs *Filesystem) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fs.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
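+
+// TryIncRef is typically useful when fs is reachable through a structure
+// that does not itself hold a reference, e.g. a cache. A hypothetical sketch:
+//
+// if fs.TryIncRef() {
+// // fs is guaranteed to remain valid until the matching DecRef().
+// defer fs.DecRef()
+// }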
+
+// DecRef decrements fs' reference count.
+func (fs *Filesystem) DecRef() {
+ if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 {
+ fs.vfs.filesystemsMu.Lock()
+ delete(fs.vfs.filesystems, fs)
+ fs.vfs.filesystemsMu.Unlock()
+ fs.impl.Release()
+ } else if refs < 0 {
+ panic("Filesystem.decRef() called without holding a reference")
+ }
+}
+
+// FilesystemImpl contains implementation details for a Filesystem.
+// Implementations of FilesystemImpl should contain their associated Filesystem
+// by value as their first field; see the sketch following this interface.
+//
+// All methods that take a ResolvingPath must resolve the path before
+// performing any other checks, including rejection of the operation if not
+// supported by the FilesystemImpl. This is because the final FilesystemImpl
+// (responsible for actually implementing the operation) isn't known until path
+// resolution is complete.
+//
+// Unless otherwise specified, FilesystemImpl methods are responsible for
+// performing permission checks. In many cases, vfs package functions in
+// permissions.go may be used to help perform these checks.
+//
+// When multiple specified error conditions apply to a given method call, the
+// implementation may return any applicable errno unless otherwise specified,
+// but returning the earliest error specified is preferable to maximize
+// compatibility with Linux.
+//
+// All methods may return errors not specified, notably including:
+//
+// - ENOENT if a required path component does not exist.
+//
+// - ENOTDIR if an intermediate path component is not a directory.
+//
+// - Errors from vfs-package functions (ResolvingPath.Resolve*(),
+// Mount.CheckBeginWrite(), permission-checking functions, etc.)
+//
+// For all methods that take or return linux.Statx, Statx.Uid and Statx.Gid
+// should be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID
+// and auth.KGID respectively).
+//
+// FilesystemImpl combines elements of Linux's struct super_operations and
+// struct inode_operations, for reasons described in the documentation for
+// Dentry.
+type FilesystemImpl interface {
+ // Release is called when the associated Filesystem reaches zero
+ // references.
+ Release()
+
+ // Sync "causes all pending modifications to filesystem metadata and cached
+ // file data to be written to the underlying [filesystem]", as by syncfs(2).
+ Sync(ctx context.Context) error
+
+ // AccessAt checks whether a user with creds can access the file at rp.
+ AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error
+
+ // GetDentryAt returns a Dentry representing the file at rp. A reference is
+ // taken on the returned Dentry.
+ //
+ // GetDentryAt does not correspond directly to a Linux syscall; it is used
+ // in the implementation of:
+ //
+ // - Syscalls that need to resolve two paths: link(), linkat().
+ //
+ // - Syscalls that need to refer to a filesystem position outside the
+ // context of a file description: chdir(), fchdir(), chroot(), mount(),
+ // umount().
+ GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error)
+
+ // GetParentDentryAt returns a Dentry representing the directory at the
+ // second-to-last path component in rp. (Note that, despite the name, this
+ // is not necessarily the parent directory of the file at rp, since the
+ // last path component in rp may be "." or "..".) A reference is taken on
+ // the returned Dentry.
+ //
+ // GetParentDentryAt does not correspond directly to a Linux syscall; it is
+ // used in the implementation of the rename() family of syscalls, which
+ // must resolve the parent directories of two paths.
+ //
+ // Preconditions: !rp.Done().
+ //
+ // Postconditions: If GetParentDentryAt returns a nil error, then
+ // rp.Final(). If GetParentDentryAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error)
+
+ // LinkAt creates a hard link at rp representing the same file as vd. It
+ // does not take ownership of references on vd.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", LinkAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, LinkAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), LinkAt returns ENOENT.
+ //
+ // - If the directory in which the link would be created has been removed
+ // by RmdirAt or RenameAt, LinkAt returns ENOENT.
+ //
+ // - If rp.Mount != vd.Mount(), LinkAt returns EXDEV.
+ //
+ // - If vd represents a directory, LinkAt returns EPERM.
+ //
+ // - If vd represents a file for which all existing links have been
+ // removed, or a file created by open(O_TMPFILE|O_EXCL), LinkAt returns
+ // ENOENT. Equivalently, if vd represents a file with a link count of 0 not
+ // created by open(O_TMPFILE) without O_EXCL, LinkAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If LinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error
+
+ // MkdirAt creates a directory at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", MkdirAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, MkdirAt returns EEXIST.
+ //
+ // - If the directory in which the new directory would be created has been
+ // removed by RmdirAt or RenameAt, MkdirAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If MkdirAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error
+
+ // MknodAt creates a regular file, device special file, or named pipe at
+ // rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", MknodAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, MknodAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), MknodAt returns ENOENT.
+ //
+ // - If the directory in which the file would be created has been removed
+ // by RmdirAt or RenameAt, MknodAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If MknodAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error
+
+ // OpenAt returns an FileDescription providing access to the file at rp. A
+ // reference is taken on the returned FileDescription.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies O_TMPFILE and this feature is unsupported by
+ // the implementation, OpenAt returns EOPNOTSUPP. (All other unsupported
+ // features are silently ignored, consistent with Linux's open*(2).)
+ OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error)
+
+ // ReadlinkAt returns the target of the symbolic link at rp.
+ //
+ // Errors:
+ //
+ // - If the file at rp is not a symbolic link, ReadlinkAt returns EINVAL.
+ ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error)
+
+ // RenameAt renames the file named oldName in directory oldParentVD to rp.
+ // It does not take ownership of references on oldParentVD.
+ //
+ // Errors [1]:
+ //
+ // - If opts.Flags specifies unsupported options, RenameAt returns EINVAL.
+ //
+ // - If the last path component in rp is "." or "..", and opts.Flags
+ // contains RENAME_NOREPLACE, RenameAt returns EEXIST.
+ //
+ // - If the last path component in rp is "." or "..", and opts.Flags does
+ // not contain RENAME_NOREPLACE, RenameAt returns EBUSY.
+ //
+ // - If rp.Mount != oldParentVD.Mount(), RenameAt returns EXDEV.
+ //
+ // - If the renamed file is not a directory, and opts.MustBeDir is true,
+ // RenameAt returns ENOTDIR.
+ //
+ // - If renaming would replace an existing file and opts.Flags contains
+ // RENAME_NOREPLACE, RenameAt returns EEXIST.
+ //
+ // - If there is no existing file at rp and opts.Flags contains
+ // RENAME_EXCHANGE, RenameAt returns ENOENT.
+ //
+ // - If there is an existing non-directory file at rp, and rp.MustBeDir()
+ // is true, RenameAt returns ENOTDIR.
+ //
+ // - If the renamed file is not a directory, opts.Flags does not contain
+ // RENAME_EXCHANGE, and rp.MustBeDir() is true, RenameAt returns ENOTDIR.
+ // (This check is not subsumed by the check for directory replacement below
+ // since it applies even if there is no file to replace.)
+ //
+ // - If the renamed file is a directory, and the new parent directory of
+ // the renamed file is either the renamed directory or a descendant
+ // subdirectory of the renamed directory, RenameAt returns EINVAL.
+ //
+ // - If renaming would exchange the renamed file with an ancestor directory
+ // of the renamed file, RenameAt returns EINVAL.
+ //
+ // - If renaming would replace an ancestor directory of the renamed file,
+ // RenameAt returns ENOTEMPTY. (This check would be subsumed by the
+ // non-empty directory check below; however, this check takes place before
+ // the self-rename check.)
+ //
+ // - If the renamed file would replace or exchange with itself (i.e. the
+ // source and destination paths resolve to the same file), RenameAt returns
+ // nil, skipping the checks described below.
+ //
+ // - If the source or destination directory is not writable by the provider
+ // of rp.Credentials(), RenameAt returns EACCES.
+ //
+ // - If the renamed file is a directory, and renaming would replace a
+ // non-directory file, RenameAt returns ENOTDIR.
+ //
+ // - If the renamed file is not a directory, and renaming would replace a
+ // directory, RenameAt returns EISDIR.
+ //
+ // - If the new parent directory of the renamed file has been removed by
+ // RmdirAt or a preceding call to RenameAt, RenameAt returns ENOENT.
+ //
+ // - If the renamed file is a directory, it is not writable by the
+ // provider of rp.Credentials(), and the source and destination parent
+ // directories are different, RenameAt returns EACCES. (This is nominally
+ // required to change the ".." entry in the renamed directory.)
+ //
+ // - If renaming would replace a non-empty directory, RenameAt returns
+ // ENOTEMPTY.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink(). oldParentVD.Dentry() was obtained from a
+ // previous call to
+ // oldParentVD.Mount().Filesystem().Impl().GetParentDentryAt(). oldName is
+ // not "." or "..".
+ //
+ // Postconditions: If RenameAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ //
+ // [1] "The worst of all namespace operations - renaming directory.
+ // "Perverted" doesn't even start to describe it. Somebody in UCB had a
+ // heck of a trip..." - fs/namei.c:vfs_rename()
+ RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error
+
+ // RmdirAt removes the directory at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is ".", RmdirAt returns EINVAL.
+ //
+ // - If the last path component in rp is "..", RmdirAt returns ENOTEMPTY.
+ //
+ // - If no file exists at rp, RmdirAt returns ENOENT.
+ //
+ // - If the file at rp exists but is not a directory, RmdirAt returns
+ // ENOTDIR.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If RmdirAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ RmdirAt(ctx context.Context, rp *ResolvingPath) error
+
+ // SetStatAt updates metadata for the file at the given path. Implementations
+ // are responsible for checking if the operation can be performed
+ // (see vfs.CheckSetStat() for common checks).
+ //
+ // Errors:
+ //
+ // - If opts specifies unsupported options, SetStatAt returns EINVAL.
+ SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error
+
+ // StatAt returns metadata for the file at rp.
+ StatAt(ctx context.Context, rp *ResolvingPath, opts StatOptions) (linux.Statx, error)
+
+ // StatFSAt returns metadata for the filesystem containing the file at rp.
+ // (This method takes a path because a FilesystemImpl may consist of any
+ // number of constituent filesystems.)
+ StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error)
+
+ // SymlinkAt creates a symbolic link at rp referring to the given target.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", SymlinkAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, SymlinkAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), SymlinkAt returns ENOENT.
+ //
+ // - If the directory in which the symbolic link would be created has been
+ // removed by RmdirAt or RenameAt, SymlinkAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If SymlinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error
+
+ // UnlinkAt removes the file at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", UnlinkAt returns
+ // EISDIR.
+ //
+ // - If no file exists at rp, UnlinkAt returns ENOENT.
+ //
+ // - If rp.MustBeDir(), and the file at rp exists and is not a directory,
+ // UnlinkAt returns ENOTDIR.
+ //
+ // - If the file at rp exists but is a directory, UnlinkAt returns EISDIR.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If UnlinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ UnlinkAt(ctx context.Context, rp *ResolvingPath) error
+
+ // ListxattrAt returns all extended attribute names for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // ListxattrAt returns ENOTSUP.
+ //
+ // - If the size of the list (including a NUL terminating byte after every
+ // entry) would exceed size, ERANGE may be returned (note that
+ // implementations are free to ignore size entirely and return without
+ // error). In all cases, if size is 0, the list should be returned without
+ // error, regardless of size.
+ ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error)
+
+ // GetxattrAt returns the value associated with the given extended
+ // attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, GetxattrAt
+ // returns ENOTSUP.
+ //
+ // - If an extended attribute named opts.Name does not exist, ENODATA is
+ // returned.
+ //
+ // - If the size of the return value exceeds opts.Size, ERANGE may be
+ // returned (note that implementations are free to ignore opts.Size entirely
+ // and return without error). In all cases, if opts.Size is 0, the value
+ // should be returned without error, regardless of size.
+ GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error)
+
+ // SetxattrAt changes the value associated with the given extended
+ // attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, SetxattrAt
+ // returns ENOTSUP.
+ //
+ // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists,
+ // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist,
+ // ENODATA is returned.
+ SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
+
+ // RemovexattrAt removes the given extended attribute from the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // RemovexattrAt returns ENOTSUP.
+ //
+ // - If name does not exist, ENODATA is returned.
+ RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+
+ // BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
+ //
+ // Errors:
+ //
+ // - If the file does not have write permissions, then BoundEndpointAt
+ // returns EACCES.
+ //
+ // - If a non-socket file exists at rp, then BoundEndpointAt returns
+ // ECONNREFUSED.
+ BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error)
+
+ // PrependPath prepends a path from vd to vd.Mount().Root() to b.
+ //
+ // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered
+ // before vd.Mount().Root(), PrependPath should stop prepending path
+ // components and return a PrependPathAtVFSRootError.
+ //
+ // If traversal of vd.Dentry()'s ancestors encounters an independent
+ // ("root") Dentry that is not vd.Mount().Root() (i.e. vd.Dentry() is not a
+ // descendant of vd.Mount().Root()), PrependPath should stop prepending
+ // path components and return a PrependPathAtNonMountRootError.
+ //
+ // Filesystems for which Dentries do not have meaningful paths may prepend
+ // an arbitrary descriptive string to b and then return a
+ // PrependPathSyntheticError.
+ //
+ // Most implementations can acquire the appropriate locks to ensure that
+ // Dentry.Name() and Dentry.Parent() are fixed for vd.Dentry() and all of
+ // its ancestors, then call GenericPrependPath.
+ //
+ // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
+ PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
+}
+
+// PrependPathAtVFSRootError is returned by implementations of
+// FilesystemImpl.PrependPath() when they encounter the contextual VFS root.
+type PrependPathAtVFSRootError struct{}
+
+// Error implements error.Error.
+func (PrependPathAtVFSRootError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() reached VFS root"
+}
+
+// PrependPathAtNonMountRootError is returned by implementations of
+// FilesystemImpl.PrependPath() when they encounter an independent ancestor
+// Dentry that is not the Mount root.
+type PrependPathAtNonMountRootError struct{}
+
+// Error implements error.Error.
+func (PrependPathAtNonMountRootError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() reached root other than Mount root"
+}
+
+// PrependPathSyntheticError is returned by implementations of
+// FilesystemImpl.PrependPath() for which prepended names do not represent real
+// paths.
+type PrependPathSyntheticError struct{}
+
+// Error implements error.Error.
+func (PrependPathSyntheticError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() prepended synthetic name"
+}
diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go
new file mode 100644
index 000000000..465e610e0
--- /dev/null
+++ b/pkg/sentry/vfs/filesystem_impl_util.go
@@ -0,0 +1,43 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "strings"
+)
+
+// GenericParseMountOptions parses a comma-separated list of options of the
+// form "key" or "key=value", where neither key nor value contain commas, and
+// returns it as a map. If str contains duplicate keys, then the last value
+// wins. For example:
+//
+// str = "key0=value0,key1,key2=value2,key0=value3" -> map{'key0':'value3','key1':'','key2':'value2'}
+//
+// GenericParseMountOptions is not appropriate if values may contain commas,
+// e.g. in the case of the mpol mount option for tmpfs(5).
+func GenericParseMountOptions(str string) map[string]string {
+ m := make(map[string]string)
+ for _, opt := range strings.Split(str, ",") {
+ if len(opt) > 0 {
+ res := strings.SplitN(opt, "=", 2)
+ if len(res) == 2 {
+ m[res[0]] = res[1]
+ } else {
+ m[opt] = ""
+ }
+ }
+ }
+ return m
+}
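+
+// As an illustrative sketch (the "mode" and "noatime" option names below are
+// hypothetical), a FilesystemType.GetFilesystem implementation might consume
+// the result as follows:
+//
+//    mopts := GenericParseMountOptions(opts.Data)
+//    if mode, ok := mopts["mode"]; ok {
+//        _ = mode // Parse and apply; valueless options like "noatime" map to "".
+//    }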
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
new file mode 100644
index 000000000..f2298f7f6
--- /dev/null
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -0,0 +1,117 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// A FilesystemType constructs filesystems.
+//
+// FilesystemType is analogous to Linux's struct file_system_type.
+type FilesystemType interface {
+ // GetFilesystem returns a Filesystem configured by the given options,
+ // along with its mount root. A reference is taken on the returned
+ // Filesystem and Dentry.
+ GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error)
+
+ // Name returns the name of this FilesystemType.
+ Name() string
+}
+
+// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
+type GetFilesystemOptions struct {
+ // Data is the string passed as the 5th argument to mount(2), which is
+ // usually a comma-separated list of filesystem-specific mount options.
+ Data string
+
+ // InternalData holds opaque FilesystemType-specific data. There is
+ // intentionally no way for applications to specify InternalData; if it is
+ // not nil, the call to GetFilesystem originates from within the sentry.
+ InternalData interface{}
+}
+
+// +stateify savable
+type registeredFilesystemType struct {
+ fsType FilesystemType
+ opts RegisterFilesystemTypeOptions
+}
+
+// RegisterFilesystemTypeOptions contains options to
+// VirtualFilesystem.RegisterFilesystem().
+type RegisterFilesystemTypeOptions struct {
+ // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt()
+ // for which MountOptions.InternalMount == false to use this filesystem
+ // type.
+ AllowUserMount bool
+
+ // If AllowUserList is true, make this filesystem type visible in
+ // /proc/filesystems.
+ AllowUserList bool
+
+ // If RequiresDevice is true, indicate in /proc/filesystems that mounting
+ // this filesystem requires a block device as the mount source.
+ RequiresDevice bool
+}
+
+// RegisterFilesystemType registers the given FilesystemType in vfs with the
+// given name.
+func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) error {
+ vfs.fsTypesMu.Lock()
+ defer vfs.fsTypesMu.Unlock()
+ if existing, ok := vfs.fsTypes[name]; ok {
+ return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing.fsType)
+ }
+ vfs.fsTypes[name] = &registeredFilesystemType{
+ fsType: fsType,
+ opts: *opts,
+ }
+ return nil
+}
+
+// MustRegisterFilesystemType is equivalent to RegisterFilesystemType but
+// panics on failure.
+func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) {
+ if err := vfs.RegisterFilesystemType(name, fsType, opts); err != nil {
+ panic(fmt.Sprintf("failed to register filesystem type %T: %v", fsType, err))
+ }
+}
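+
+// A typical registration, as a sketch (the "tmpfs" name and the tmpfs
+// FilesystemType value are illustrative, not prescribed by this package):
+//
+//    vfsObj.MustRegisterFilesystemType("tmpfs", &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+//        AllowUserMount: true,
+//        AllowUserList:  true,
+//    })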
+
+func (vfs *VirtualFilesystem) getFilesystemType(name string) *registeredFilesystemType {
+ vfs.fsTypesMu.RLock()
+ defer vfs.fsTypesMu.RUnlock()
+ return vfs.fsTypes[name]
+}
+
+// GenerateProcFilesystems emits the contents of /proc/filesystems for vfs to
+// buf.
+func (vfs *VirtualFilesystem) GenerateProcFilesystems(buf *bytes.Buffer) {
+ vfs.fsTypesMu.RLock()
+ defer vfs.fsTypesMu.RUnlock()
+ for name, rft := range vfs.fsTypes {
+ if !rft.opts.AllowUserList {
+ continue
+ }
+ var nodev string
+ if !rft.opts.RequiresDevice {
+ nodev = "nodev"
+ }
+ fmt.Fprintf(buf, "%s\t%s\n", nodev, name)
+ }
+}
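+
+// For example, with one listable type that needs no device and one that
+// requires a device (names hypothetical), the emitted lines would resemble:
+//
+//    nodev   tmpfs
+//            ext4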
diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md
new file mode 100644
index 000000000..e7da49faa
--- /dev/null
+++ b/pkg/sentry/vfs/g3doc/inotify.md
@@ -0,0 +1,210 @@
+# Inotify
+
+Inotify is a mechanism for monitoring filesystem events in Linux--see
+inotify(7). An inotify instance can be used to monitor files and directories for
+modifications, creation/deletion, etc. The inotify API consists of system calls
+that create inotify instances (inotify_init/inotify_init1) and add/remove
+watches on files to an instance (inotify_add_watch/inotify_rm_watch). Events are
+generated from various places in the sentry, including the syscall layer, the
+vfs layer, the process fd table, and within each filesystem implementation. This
+document outlines the implementation details of inotify in VFS2.
+
+## Inotify Objects
+
+Inotify data structures are implemented in the vfs package.
+
+### vfs.Inotify
+
+Inotify instances are represented by vfs.Inotify objects, which implement
+vfs.FileDescriptionImpl. As in Linux, inotify fds are backed by a
+pseudo-filesystem (anonfs). Each inotify instance receives events from a set of
+vfs.Watch objects, which can be modified with inotify_add_watch(2) and
+inotify_rm_watch(2). An application can retrieve events by reading the inotify
+fd.
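+
+For orientation, here is a minimal sketch of the application-facing API that
+this machinery implements (using the wrappers in golang.org/x/sys/unix; the
+watched path is arbitrary and most error handling is elided):
+
+```go
+package main
+
+import (
+    "fmt"
+
+    "golang.org/x/sys/unix"
+)
+
+func main() {
+    // Create an inotify instance; in the sentry, this corresponds to a
+    // vfs.Inotify backed by anonfs.
+    fd, _ := unix.InotifyInit1(unix.IN_NONBLOCK)
+    defer unix.Close(fd)
+
+    // Add a watch; this corresponds to a vfs.Watch held by both the
+    // instance and the target's vfs.Watches.
+    wd, _ := unix.InotifyAddWatch(fd, "/tmp", unix.IN_CREATE|unix.IN_DELETE)
+    fmt.Println("watch descriptor:", wd)
+
+    // Events are retrieved by reading the inotify fd.
+    buf := make([]byte, 4096)
+    n, err := unix.Read(fd, buf)
+    fmt.Println(n, err) // err == unix.EAGAIN if no events are pending.
+}
+```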
+
+### vfs.Watches
+
+The set of all watches held on a single file (i.e., the watch target) is stored
+in vfs.Watches. Each watch will belong to a different inotify instance (an
+instance can only have one watch on any watch target). The watches are stored in
+a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions
+to a single file will all share the same vfs.Watches. Activity on the target
+causes its vfs.Watches to generate notifications on its watches’ inotify
+instances.
+
+### vfs.Watch
+
+A single watch, owned by one inotify instance and applied to one watch target.
+Both the vfs.Inotify owner and vfs.Watches on the target will hold a vfs.Watch,
+which leads to some complicated locking behavior (see Lock Ordering). Whenever a
+watch is notified of an event on its target, it will queue events to its inotify
+instance for delivery to the user.
+
+### vfs.Event
+
+vfs.Event is a simple struct encapsulating all the fields for an inotify event.
+It is generated by vfs.Watches and forwarded to the watches' owners. It is
+serialized to the user during read(2) syscalls on the associated vfs.Inotify's
+fd.
+
+## Lock Ordering
+
+There are three locks related to the inotify implementation:
+
+* Inotify.mu: the inotify instance lock.
+* Inotify.evMu: the inotify event queue lock.
+* Watches.mu: the watch set lock, used to protect the collection of watches
+ on a target.
+
+The correct lock ordering for inotify code is:
+
+Inotify.mu -> Watches.mu -> Inotify.evMu.
+
+Note that we use a distinct lock to protect the inotify event queue. If we
+simply used Inotify.mu, we could simultaneously have locks being acquired in the
+order of Inotify.mu -> Watches.mu and Watches.mu -> Inotify.mu, which would
+cause deadlocks. For instance, adding a watch to an inotify instance would
+require locking Inotify.mu, and then adding the same watch to the target would
+cause Watches.mu to be held. At the same time, generating an event on the target
+would require Watches.mu to be held before iterating through each watch, and
+then notifying the owner of each watch would cause Inotify.mu to be held.
+
+See the vfs package comment to understand how inotify locks fit into the overall
+ordering of filesystem locks.
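+
+To make the inverted ordering concrete, the following sketch (hypothetical
+types, not the actual implementation) shows why event queuing uses evMu rather
+than the instance lock, preserving the single order mu -> Watches.mu -> evMu:
+
+```go
+package main
+
+import "sync"
+
+type instance struct {
+    mu   sync.Mutex // watch bookkeeping, like Inotify.mu
+    evMu sync.Mutex // event queue only, like Inotify.evMu
+    evs  []string
+}
+
+type watches struct {
+    mu     sync.Mutex // like Watches.mu
+    owners []*instance
+}
+
+// add acquires instance.mu, then watches.mu.
+func (w *watches) add(i *instance) {
+    i.mu.Lock()
+    defer i.mu.Unlock()
+    w.mu.Lock()
+    defer w.mu.Unlock()
+    w.owners = append(w.owners, i)
+}
+
+// notify acquires watches.mu first. Taking instance.mu here would invert the
+// order used by add and permit deadlock, so only evMu is taken.
+func (w *watches) notify(ev string) {
+    w.mu.Lock()
+    defer w.mu.Unlock()
+    for _, i := range w.owners {
+        i.evMu.Lock()
+        i.evs = append(i.evs, ev)
+        i.evMu.Unlock()
+    }
+}
+
+func main() {
+    w := &watches{}
+    w.add(&instance{})
+    w.notify("IN_CREATE")
+}
+```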
+
+## Watch Targets in Different Filesystem Implementations
+
+In Linux, watches reside on inodes at the virtual filesystem layer. As a result,
+all hard links and file descriptions on a single file will all share the same
+watch set. In VFS2, there is no common inode structure across filesystem types
+(some may not even have inodes), so we have to plumb inotify support through
+each specific filesystem implementation. Some of the technical considerations
+are outlined below.
+
+### Tmpfs
+
+For filesystems with inodes, like tmpfs, the design is quite similar to that of
+Linux, where watches reside on the inode.
+
+### Pseudo-filesystems
+
+Technically, because inotify is implemented at the vfs layer in Linux,
+pseudo-filesystems on top of kernfs support inotify passively. However, watches
+can only track explicit filesystem operations like read/write, open/close,
+mknod, etc., so watches on a target like /proc/self/fd will not generate events
+every time a new fd is added or removed. As of this writing, we leave inotify
+unimplemented in kernfs and anonfs; it does not seem particularly useful.
+
+### Gofer Filesystem (fsimpl/gofer)
+
+The gofer filesystem has several traits that make it difficult to support
+inotify:
+
+* **There are no inodes.** A file is represented as a dentry that holds an
+ unopened p9 file (and possibly an open FID), through which the Sentry
+ interacts with the gofer.
+ * *Solution:* Because there is no inode structure stored in the sandbox,
+ inotify watches must be held on the dentry. This would be an issue in
+ the presence of hard links, where multiple dentries would need to share
+ the same set of watches, but in VFS2, we do not support the internal
+ creation of hard links on gofer fs. As a result, we make the assumption
+ that every dentry corresponds to a unique inode. However, the next point
+ raises an issue with this assumption:
+* **The Sentry cannot always be aware of hard links on the remote
+ filesystem.** There is no way for us to confirm whether two files on the
+ remote filesystem are actually links to the same inode. QIDs and inodes are
+ not always 1:1. The assumption that dentries and inodes are 1:1 is
+ inevitably broken if there are remote hard links that we cannot detect.
+ * *Solution:* this is an issue with gofer fs in general, not only inotify,
+ and we will have to live with it.
+* **Dentries can be cached, and then evicted.** Dentry lifetime does not
+ correspond to file lifetime. Because gofer fs is not entirely in-memory, the
+ absence of a dentry does not mean that the corresponding file does not
+ exist, nor does a dentry reaching zero references mean that the
+ corresponding file no longer exists. When a dentry reaches zero references,
+ it will be cached, in case the file at that path is needed again in the
+ future. However, the dentry may be evicted from the cache, which will cause
+ a new dentry to be created next time the same file path is used. The
+ existing watches will be lost.
+ * *Solution:* When a dentry reaches zero references, do not cache it if it
+ has any watches, so we can avoid eviction/destruction. Note that if the
+ dentry was deleted or invalidated (d.vfsd.IsDead()), we should still
+ destroy it along with its watches. Additionally, when a dentry’s last
+ watch is removed, we cache it if it also has zero references. This way,
+ the dentry can eventually be evicted from memory if it is no longer
+ needed.
+* **Dentries can be invalidated.** Another issue with dentry lifetime is that
+ the remote file at the file path represented may change from underneath the
+ dentry. In this case, the next time that the dentry is used, it will be
+ invalidated and a new dentry will replace it. In this case, it is not clear
+ what should be done with the watches on the old dentry.
+ * *Solution:* Silently destroy the watches when invalidation occurs. We
+ have no way of knowing exactly what happened or when it happened. Inotify
+ instances on NFS files in Linux probably behave in a similar fashion,
+ since inotify is implemented at the vfs layer and is not aware of the
+ complexities of remote file systems.
+ * An alternative would be to issue some kind of event upon invalidation,
+ e.g. a delete event, but this has several issues:
+ * We cannot discern whether the remote file was invalidated because it was
+ moved, deleted, etc. This information is crucial, because these cases
+ should result in different events. Furthermore, the watches should only
+ be destroyed if the file has been deleted.
+ * Moreover, the mechanism for detecting whether the underlying file has
+ changed is to check whether a new QID is given by the gofer. This may
+ result in false positives, e.g. suppose that the server closed and
+ re-opened the same file, which may result in a new QID.
+ * Finally, the time of the event may be completely different from the time
+ of the file modification, since a dentry is not immediately notified
+ when the underlying file has changed. It would be quite unexpected to
+ receive the notification when invalidation was triggered, i.e. the next
+ time the file was accessed within the sandbox, because then the
+ read/write/etc. operation on the file would not result in the expected
+ event.
+ * Another point in favor of the first solution: inotify in Linux can
+ already be lossy on local filesystems (one of the sacrifices made so
+ that filesystem performance isn’t killed), and it is lossy on NFS for
+ similar reasons to gofer fs. Therefore, it is better for inotify to be
+ silent than to emit incorrect notifications.
+* **There may be external users of the remote filesystem.** We can only track
+ operations performed on the file within the sandbox. This is sufficient
+ under InteropModeExclusive, but whenever there are external users, the set
+ of actions we are aware of is incomplete.
+ * *Solution:* We could either return an error or just issue a warning when
+ inotify is used without InteropModeExclusive. Although faulty, VFS1
+ allows it when the filesystem is shared, and Linux does the same for
+ remote filesystems (as mentioned above, inotify sits at the vfs level).
+
+## Dentry Interface
+
+For events that must be generated above the vfs layer, we provide the following
+DentryImpl methods to allow interactions with targets on any FilesystemImpl
+(sketched as an interface after the list):
+
+* **InotifyWithParent()** generates events on the dentry’s watches as well as
+ its parent’s.
+* **Watches()** retrieves the watch set of the target represented by the
+ dentry. This is used to access and modify watches on a target.
+* **OnZeroWatches()** performs cleanup tasks after the last watch is removed
+ from a dentry. This is needed by gofer fs, which must allow a watched dentry
+ to be cached once it has no more watches. Most implementations can just do
+ nothing. Note that OnZeroWatches() must be called after all inotify locks
+ are released to preserve lock ordering, since it may acquire
+ FilesystemImpl-specific locks.
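+
+As a sketch, these might be declared as follows (signatures illustrative, per
+the descriptions above):
+
+```go
+type DentryImpl interface {
+    // ... existing DentryImpl methods ...
+
+    // InotifyWithParent notifies watches on this dentry and on its parent.
+    InotifyWithParent(events uint32, cookie uint32, et EventType)
+
+    // Watches returns the watch set of the target represented by this
+    // dentry, or nil if the dentry is not watchable.
+    Watches() *Watches
+
+    // OnZeroWatches is called after the last watch is removed from this
+    // dentry.
+    OnZeroWatches()
+}
+```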
+
+## IN_EXCL_UNLINK
+
+There are several options that can be set for a watch, specified as part of the
+mask in inotify_add_watch(2). In particular, IN_EXCL_UNLINK requires some
+additional support in each filesystem.
+
+A watch with IN_EXCL_UNLINK will not generate events for its target if it
+corresponds to a path that was unlinked. For instance, if an fd is opened on
+“foo/bar” and “foo/bar” is subsequently unlinked, any reads/writes/etc. on the
+fd will be ignored by watches on “foo” or “foo/bar” with IN_EXCL_UNLINK. This
+requires each DentryImpl to keep track of whether it has been unlinked, in order
+to determine whether events should be sent to watches with IN_EXCL_UNLINK.
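+
+As a sketch of the userspace-visible behavior (again via golang.org/x/sys/unix;
+the path is arbitrary):
+
+```go
+package main
+
+import (
+    "fmt"
+
+    "golang.org/x/sys/unix"
+)
+
+func main() {
+    fd, _ := unix.InotifyInit1(0)
+    defer unix.Close(fd)
+
+    // Without IN_EXCL_UNLINK, writes through an fd that was opened on
+    // "foo/bar" keep generating IN_MODIFY on "foo" even after "foo/bar" is
+    // unlinked; with the flag, those events are suppressed.
+    wd, err := unix.InotifyAddWatch(fd, "foo", unix.IN_MODIFY|unix.IN_EXCL_UNLINK)
+    fmt.Println(wd, err)
+}
+```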
+
+## IN_ONESHOT
+
+One-shot watches expire after generating a single event. When an event occurs,
+all one-shot watches on the target that successfully generated an event are
+removed. Lock ordering can cause the management of one-shot watches to be quite
+expensive; see Watches.Notify() for more information.
diff --git a/pkg/sentry/vfs/genericfstree/BUILD b/pkg/sentry/vfs/genericfstree/BUILD
new file mode 100644
index 000000000..d8fd92677
--- /dev/null
+++ b/pkg/sentry/vfs/genericfstree/BUILD
@@ -0,0 +1,16 @@
+load("//tools/go_generics:defs.bzl", "go_template")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+go_template(
+ name = "generic_fstree",
+ srcs = [
+ "genericfstree.go",
+ ],
+ types = [
+ "Dentry",
+ ],
+)
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
new file mode 100644
index 000000000..8882fa84a
--- /dev/null
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package genericfstree provides tools for implementing vfs.FilesystemImpls
+// where a single statically-determined lock or set of locks is sufficient to
+// ensure that a Dentry's name and parent are contextually immutable.
+//
+// Clients using this package must use the go_template_instance rule in
+// tools/go_generics/defs.bzl to create an instantiation of this template
+// package, providing types to use in place of Dentry.
+package genericfstree
+
+import (
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Dentry is a required type parameter that is a struct with the given fields.
+type Dentry struct {
+ // vfsd is the embedded vfs.Dentry corresponding to this vfs.DentryImpl.
+ vfsd vfs.Dentry
+
+ // parent is the parent of this Dentry in the filesystem's tree. If this
+ // Dentry is a filesystem root, parent is nil.
+ parent *Dentry
+
+ // name is the name of this Dentry in its parent. If this Dentry is a
+ // filesystem root, name is unspecified.
+ name string
+}
+
+// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is
+// either d2's parent or an ancestor of d2's parent.
+func IsAncestorDentry(d, d2 *Dentry) bool {
+ for d2 != nil { // Stop at root, where d2.parent == nil.
+ if d2.parent == d {
+ return true
+ }
+ if d2.parent == d2 {
+ return false
+ }
+ d2 = d2.parent
+ }
+ return false
+}
+
+// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d.
+func ParentOrSelf(d *Dentry) *Dentry {
+ if d.parent != nil {
+ return d.parent
+ }
+ return d
+}
+
+// PrependPath is a generic implementation of FilesystemImpl.PrependPath().
+func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath.Builder) error {
+ for {
+ if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
+ return vfs.PrependPathAtVFSRootError{}
+ }
+ if &d.vfsd == mnt.Root() {
+ return nil
+ }
+ if d.parent == nil {
+ return vfs.PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
+}
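+
+// A FilesystemImpl would typically delegate here after acquiring whatever
+// lock fixes dentry names and parents, as a sketch (the filesystem and dentry
+// types are hypothetical):
+//
+//    func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+//        fs.mu.RLock()
+//        defer fs.mu.RUnlock()
+//        return genericfstree.PrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+//    }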
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
new file mode 100644
index 000000000..c2e21ac5f
--- /dev/null
+++ b/pkg/sentry/vfs/inotify.go
@@ -0,0 +1,774 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// inotifyEventBaseSize is the base size of linux's struct inotify_event. This
+// must be a power of 2 for rounding below.
+const inotifyEventBaseSize = 16
+
+// EventType defines different kinds of inotify events.
+//
+// The way events are labelled appears somewhat arbitrary, but they must match
+// Linux so that IN_EXCL_UNLINK behaves as it does in Linux.
+type EventType uint8
+
+// PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and
+// FSNOTIFY_EVENT_INODE in Linux.
+const (
+ PathEvent EventType = iota
+ InodeEvent EventType = iota
+)
+
+// Inotify represents an inotify instance created by inotify_init(2) or
+// inotify_init1(2). Inotify implements FileDescriptionImpl.
+//
+// +stateify savable
+type Inotify struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+ NoLockFD
+
+ // Unique identifier for this inotify instance. We don't just reuse the
+ // inotify fd because fds can be duped. These should not be exposed to the
+ // user, since we may aggressively reuse an id on S/R.
+ id uint64
+
+ // queue is used to notify interested parties when the inotify instance
+ // becomes readable or writable.
+ queue waiter.Queue `state:"nosave"`
+
+ // evMu *only* protects the events list. We need a separate lock while
+ // queuing events: using mu may violate lock ordering, since at that point
+ // the calling goroutine may already hold Watches.mu.
+ evMu sync.Mutex `state:"nosave"`
+
+ // A list of pending events for this inotify instance. Protected by evMu.
+ events eventList
+
+ // A scratch buffer, used to serialize inotify events. Allocate this
+ // ahead of time for the sake of performance. Protected by evMu.
+ scratch []byte
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // nextWatchMinusOne is used to allocate watch descriptors on this Inotify
+ // instance. Note that Linux starts numbering watch descriptors from 1.
+ nextWatchMinusOne int32
+
+ // Map from watch descriptors to watch objects.
+ watches map[int32]*Watch
+}
+
+var _ FileDescriptionImpl = (*Inotify)(nil)
+
+// NewInotifyFD constructs a new Inotify instance.
+func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) (*FileDescription, error) {
+ // O_CLOEXEC affects file descriptors, so it must be handled outside of vfs.
+ flags &^= linux.O_CLOEXEC
+ if flags&^linux.O_NONBLOCK != 0 {
+ return nil, syserror.EINVAL
+ }
+
+ id := uniqueid.GlobalFromContext(ctx)
+ vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id))
+ defer vd.DecRef()
+ fd := &Inotify{
+ id: id,
+ scratch: make([]byte, inotifyEventBaseSize),
+ watches: make(map[int32]*Watch),
+ }
+ if err := fd.vfsfd.Init(fd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Release implements FileDescriptionImpl.Release. Release removes all
+// watches and frees all resources for an inotify instance.
+func (i *Inotify) Release() {
+ var ds []*Dentry
+
+ // We need to hold i.mu to avoid a race with concurrent calls to
+ // Inotify.handleDeletion from Watches. There's no risk of Watches
+ // accessing this Inotify after the destructor ends, because we remove all
+ // references to it below.
+ i.mu.Lock()
+ for _, w := range i.watches {
+ // Remove references to the watch from the watches set on the target. We
+ // don't need to worry about the references from i.watches, since this
+ // file description is about to be destroyed.
+ d := w.target
+ ws := d.Watches()
+ // Watchable dentries should never return a nil watch set.
+ if ws == nil {
+ panic("Cannot remove watch from an unwatchable dentry")
+ }
+ ws.Remove(i.id)
+ if ws.Size() == 0 {
+ ds = append(ds, d)
+ }
+ }
+ i.mu.Unlock()
+
+ for _, d := range ds {
+ d.OnZeroWatches()
+ }
+}
+
+// Allocate implements FileDescriptionImpl.Allocate.
+func (i *Inotify) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ panic("Allocate should not be called on read-only inotify fds")
+}
+
+// EventRegister implements waiter.Waitable.
+func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ i.queue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.
+func (i *Inotify) EventUnregister(e *waiter.Entry) {
+ i.queue.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// Readiness indicates whether there are pending events for an inotify instance.
+func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if !i.events.Empty() {
+ ready |= waiter.EventIn
+ }
+
+ return mask & ready
+}
+
+// PRead implements FileDescriptionImpl.PRead.
+func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// PWrite implements FileDescriptionImpl.PWrite.
+func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ if dst.NumBytes() < inotifyEventBaseSize {
+ return 0, syserror.EINVAL
+ }
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if i.events.Empty() {
+ // Nothing to read yet, tell caller to block.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var writeLen int64
+ for it := i.events.Front(); it != nil; {
+ // Advance `it` before the element is removed from the list, or else
+ // it.Next() will always be nil.
+ event := it
+ it = it.Next()
+
+ // Does the buffer have enough remaining space to hold the event we're
+ // about to write out?
+ if dst.NumBytes() < int64(event.sizeOf()) {
+ if writeLen > 0 {
+ // Buffer wasn't big enough for all pending events, but we did
+ // write some events out.
+ return writeLen, nil
+ }
+ return 0, syserror.EINVAL
+ }
+
+ // Linux always dequeues an available event as long as there's enough
+ // buffer space to copy it out, even if the copy below fails. Emulate
+ // this behaviour.
+ i.events.Remove(event)
+
+ // Buffer has enough space, copy event to the read buffer.
+ n, err := event.CopyTo(ctx, i.scratch, dst)
+ if err != nil {
+ return 0, err
+ }
+
+ writeLen += n
+ dst = dst.DropFirst64(n)
+ }
+ return writeLen, nil
+}
+
+// Ioctl implements FileDescriptionImpl.Ioctl.
+func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch args[1].Int() {
+ case linux.FIONREAD:
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+ var n uint32
+ for e := i.events.Front(); e != nil; e = e.Next() {
+ n += uint32(e.sizeOf())
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], n)
+ _, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func (i *Inotify) queueEvent(ev *Event) {
+ i.evMu.Lock()
+
+ // Check if we should coalesce the event we're about to queue with the last
+ // one currently in the queue. Events are coalesced if they are identical.
+ if last := i.events.Back(); last != nil {
+ if ev.equals(last) {
+ // "Coalesce" the two events by simply not queuing the new one. We
+ // don't need to raise a waiter.EventIn notification because no new
+ // data is available for reading.
+ i.evMu.Unlock()
+ return
+ }
+ }
+
+ i.events.PushBack(ev)
+
+ // Release mutex before notifying waiters because we don't control what they
+ // can do.
+ i.evMu.Unlock()
+
+ i.queue.Notify(waiter.EventIn)
+}
+
+// newWatchLocked creates and adds a new watch to target.
+//
+// Precondition: i.mu must be locked. ws must be the watch set for target d.
+func (i *Inotify) newWatchLocked(d *Dentry, ws *Watches, mask uint32) *Watch {
+ w := &Watch{
+ owner: i,
+ wd: i.nextWatchIDLocked(),
+ target: d,
+ mask: mask,
+ }
+
+ // Hold the watch in this inotify instance as well as the watch set on the
+ // target.
+ i.watches[w.wd] = w
+ ws.Add(w)
+ return w
+}
+
+// nextWatchIDLocked allocates and returns a new watch descriptor.
+//
+// Precondition: i.mu must be locked.
+func (i *Inotify) nextWatchIDLocked() int32 {
+ i.nextWatchMinusOne++
+ return i.nextWatchMinusOne
+}
+
+// AddWatch constructs a new inotify watch and adds it to the target. It
+// returns the watch descriptor returned by inotify_add_watch(2).
+//
+// The caller must hold a reference on target.
+func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) {
+ // Note: Locking this inotify instance protects the result returned by
+ // Lookup() below. With the lock held, we know for sure the lookup result
+ // won't become stale because it's impossible for *this* instance to
+ // add/remove watches on target.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ ws := target.Watches()
+ if ws == nil {
+ // While Linux supports inotify watches on all filesystem types, watches on
+ // filesystems like kernfs are not generally useful, so we do not.
+ return 0, syserror.EPERM
+ }
+ // Does the target already have a watch from this inotify instance?
+ if existing := ws.Lookup(i.id); existing != nil {
+ newmask := mask
+ if mask&linux.IN_MASK_ADD != 0 {
+ // "Add (OR) events to watch mask for this pathname if it already
+ // exists (instead of replacing mask)." -- inotify(7)
+ newmask |= atomic.LoadUint32(&existing.mask)
+ }
+ atomic.StoreUint32(&existing.mask, newmask)
+ return existing.wd, nil
+ }
+
+ // No existing watch, create a new watch.
+ w := i.newWatchLocked(target, ws, mask)
+ return w.wd, nil
+}
+
+// RmWatch looks up an inotify watch for the given 'wd' and configures the
+// target to stop sending events to this inotify instance.
+func (i *Inotify) RmWatch(wd int32) error {
+ i.mu.Lock()
+
+ // Find the watch we were asked to remove.
+ w, ok := i.watches[wd]
+ if !ok {
+ i.mu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // Remove the watch from this instance.
+ delete(i.watches, wd)
+
+ // Remove the watch from the watch target.
+ ws := w.target.Watches()
+ // AddWatch ensures that w.target has a non-nil watch set.
+ if ws == nil {
+ panic("Watched dentry cannot have nil watch set")
+ }
+ ws.Remove(w.OwnerID())
+ remaining := ws.Size()
+ i.mu.Unlock()
+
+ if remaining == 0 {
+ w.target.OnZeroWatches()
+ }
+
+ // Generate the event for the removal.
+ i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0))
+
+ return nil
+}
+
+// Watches is the collection of all inotify watches on a single file.
+//
+// +stateify savable
+type Watches struct {
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // ws is the map of active watches in this collection, keyed by the inotify
+ // instance id of the owner.
+ ws map[uint64]*Watch
+}
+
+// Size returns the number of watches held by w.
+func (w *Watches) Size() int {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return len(w.ws)
+}
+
+// Lookup returns the watch owned by an inotify instance with the given id.
+// Returns nil if no such watch exists.
+//
+// Precondition: the inotify instance with the given id must be locked to
+// prevent the returned watch from being concurrently modified or replaced in
+// Inotify.watches.
+func (w *Watches) Lookup(id uint64) *Watch {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.ws[id]
+}
+
+// Add adds watch into this set of watches.
+//
+// Precondition: the inotify instance with the given id must be locked.
+func (w *Watches) Add(watch *Watch) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ owner := watch.OwnerID()
+ // Sanity check: we should never have two watches for one owner on the
+ // same target.
+ if _, exists := w.ws[owner]; exists {
+ panic(fmt.Sprintf("Watch collision with ID %+v", owner))
+ }
+ if w.ws == nil {
+ w.ws = make(map[uint64]*Watch)
+ }
+ w.ws[owner] = watch
+}
+
+// Remove removes a watch with the given id from this set of watches and
+// releases it. The caller is responsible for generating any watch removal
+// event, as appropriate. The provided id must match an existing watch in this
+// collection.
+//
+// Precondition: the inotify instance with the given id must be locked.
+func (w *Watches) Remove(id uint64) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if w.ws == nil {
+ // This watch set is being destroyed. The thread executing the
+ // destructor is already in the process of deleting all our watches. We
+ // got here with no references on the target because we raced with the
+ // destructor notifying all the watch owners of destruction. See the
+ // comment in Watches.HandleDeletion for why this race exists.
+ return
+ }
+
+ // It is possible for w.Remove() to be called for the same watch multiple
+ // times. See the treatment of one-shot watches in Watches.Notify().
+ if _, ok := w.ws[id]; ok {
+ delete(w.ws, id)
+ }
+}
+
+// Notify queues a new event with watches in this set. Watches with
+// IN_EXCL_UNLINK are skipped if the event is coming from a child that has been
+// unlinked.
+func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) {
+ var hasExpired bool
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if unlinked && watch.ExcludeUnlinked() && et == PathEvent {
+ continue
+ }
+ if watch.Notify(name, events, cookie) {
+ hasExpired = true
+ }
+ }
+ w.mu.RUnlock()
+
+ if hasExpired {
+ w.cleanupExpiredWatches()
+ }
+}
+
+// This function is relatively expensive and should only be called when there
+// are expired watches.
+func (w *Watches) cleanupExpiredWatches() {
+ // Because of lock ordering, we cannot acquire Inotify.mu for each watch
+ // owner while holding w.mu. As a result, store expired watches locally
+ // before removing.
+ var toRemove []*Watch
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if atomic.LoadInt32(&watch.expired) == 1 {
+ toRemove = append(toRemove, watch)
+ }
+ }
+ w.mu.RUnlock()
+ for _, watch := range toRemove {
+ watch.owner.RmWatch(watch.wd)
+ }
+}
+
+// HandleDeletion is called when the watch target is destroyed. It clears the
+// watch set, detaches watches from the inotify instances they belong to, and
+// generates the appropriate events.
+func (w *Watches) HandleDeletion() {
+ w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */)
+
+ // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for
+ // the owner of each watch being deleted. Instead, atomically store the
+ // watches map in a local variable and set it to nil so we can iterate over
+ // it with the assurance that there will be no concurrent accesses.
+ var ws map[uint64]*Watch
+ w.mu.Lock()
+ ws = w.ws
+ w.ws = nil
+ w.mu.Unlock()
+
+ // Remove each watch from its owner's watch set, and generate a corresponding
+ // watch removal event.
+ for _, watch := range ws {
+ i := watch.owner
+ i.mu.Lock()
+ _, found := i.watches[watch.wd]
+ delete(i.watches, watch.wd)
+
+ // Release mutex before notifying waiters because we don't control what
+ // they can do.
+ i.mu.Unlock()
+
+ // If watch was not found, it was removed from the inotify instance before
+ // we could get to it, in which case we should not generate an event.
+ if found {
+ i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
+ }
+ }
+}
+
+// Watch represents a particular inotify watch created by inotify_add_watch.
+//
+// +stateify savable
+type Watch struct {
+ // Inotify instance which owns this watch.
+ //
+ // This field is immutable after creation.
+ owner *Inotify
+
+ // Descriptor for this watch. This is unique across an inotify instance.
+ //
+ // This field is immutable after creation.
+ wd int32
+
+ // target is a dentry representing the watch target. Its watch set contains this watch.
+ //
+ // This field is immutable after creation.
+ target *Dentry
+
+ // Events being monitored via this watch. Must be accessed with atomic
+ // memory operations.
+ mask uint32
+
+ // expired is set to 1 to indicate that this watch is a one-shot that has
+ // already sent a notification and therefore can be removed. Must be accessed
+ // with atomic memory operations.
+ expired int32
+}
+
+// OwnerID returns the id of the inotify instance that owns this watch.
+func (w *Watch) OwnerID() uint64 {
+ return w.owner.id
+}
+
+// ExcludeUnlinked indicates whether the watched object should continue to be
+// notified of events originating from a path that has been unlinked.
+//
+// For example, if "foo/bar" is opened and then unlinked, operations on the
+// open fd may be ignored by watches on "foo" and "foo/bar" with IN_EXCL_UNLINK.
+func (w *Watch) ExcludeUnlinked() bool {
+ return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0
+}
+
+// Notify queues a new event on this watch. Returns true if this is a one-shot
+// watch that should be deleted after this event is successfully queued.
+func (w *Watch) Notify(name string, events uint32, cookie uint32) bool {
+ if atomic.LoadInt32(&w.expired) == 1 {
+ // This is a one-shot watch that is already in the process of being
+ // removed. This may happen if a second event reaches the watch target
+ // before this watch has been removed.
+ return false
+ }
+
+ mask := atomic.LoadUint32(&w.mask)
+ if mask&events == 0 {
+ // We weren't watching for this event.
+ return false
+ }
+
+ // Event mask should include bits matched from the watch plus all control
+ // event bits.
+ unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS
+ effectiveMask := unmaskableBits | mask
+ matchedEvents := effectiveMask & events
+ w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie))
+ if mask&linux.IN_ONESHOT != 0 {
+ atomic.StoreInt32(&w.expired, 1)
+ return true
+ }
+ return false
+}
+
+// Event represents a struct inotify_event from linux.
+//
+// +stateify savable
+type Event struct {
+ eventEntry
+
+ wd int32
+ mask uint32
+ cookie uint32
+
+ // len is computed based on the name field and is set automatically by
+ // Event.setName. It should be 0 when no name is set; otherwise it is the
+ // length of the name slice.
+ len uint32
+
+ // The name field has special padding requirements and should only be set by
+ // calling Event.setName.
+ name []byte
+}
+
+func newEvent(wd int32, name string, events, cookie uint32) *Event {
+ e := &Event{
+ wd: wd,
+ mask: events,
+ cookie: cookie,
+ }
+ if name != "" {
+ e.setName(name)
+ }
+ return e
+}
+
+// paddedBytes converts a Go string to a null-terminated C string, padded with
+// null bytes to a total size of 'l'. 'l' must be large enough for all the
+// bytes in 's' plus at least one null byte.
+func paddedBytes(s string, l uint32) []byte {
+ if l < uint32(len(s)+1) {
+ panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!")
+ }
+ b := make([]byte, l)
+ copy(b, s)
+
+ // b was zero-value initialized during make(), so the rest of the slice is
+ // already filled with null bytes.
+
+ return b
+}
+
+// setName sets the optional name for this event.
+func (e *Event) setName(name string) {
+ // We need to pad the name such that the entire event length ends up a
+ // multiple of inotifyEventBaseSize.
+ unpaddedLen := len(name) + 1
+ // Round up to nearest multiple of inotifyEventBaseSize.
+ e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1))
+ // Make sure we haven't overflowed and wrapped around when rounding.
+ if unpaddedLen > int(e.len) {
+ panic("Overflow when rounding inotify event size, the 'name' field was too big.")
+ }
+ e.name = paddedBytes(name, e.len)
+}
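+
+// For example, setName("abc") gives an unpadded length of 4 (3 name bytes
+// plus the null terminator), which rounds up to e.len = 16, so the serialized
+// event occupies 16 (header) + 16 (name) = 32 bytes.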
+
+func (e *Event) sizeOf() int {
+ s := inotifyEventBaseSize + int(e.len)
+ if s < inotifyEventBaseSize {
+ panic("Overflowed event size")
+ }
+ return s
+}
+
+// CopyTo serializes this event to dst. buf is used as a scratch buffer to
+// construct the output. We use a buffer allocated ahead of time for
+// performance. buf must be at least inotifyEventBaseSize bytes.
+func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
+ usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ usermem.ByteOrder.PutUint32(buf[4:], e.mask)
+ usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
+ usermem.ByteOrder.PutUint32(buf[12:], e.len)
+
+ writeLen := 0
+
+ n, err := dst.CopyOut(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ dst = dst.DropFirst(n)
+
+ if e.len > 0 {
+ n, err = dst.CopyOut(ctx, e.name)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ }
+
+ // Sanity check.
+ if writeLen != e.sizeOf() {
+ panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %d, wrote %d.", e.sizeOf(), writeLen))
+ }
+
+ return int64(writeLen), nil
+}
+
+func (e *Event) equals(other *Event) bool {
+ return e.wd == other.wd &&
+ e.mask == other.mask &&
+ e.cookie == other.cookie &&
+ e.len == other.len &&
+ bytes.Equal(e.name, other.name)
+}
+
+// InotifyEventFromStatMask generates the appropriate events for an operation
+// that sets the stats specified in mask.
+func InotifyEventFromStatMask(mask uint32) uint32 {
+ var ev uint32
+ if mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_MODE) != 0 {
+ ev |= linux.IN_ATTRIB
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ ev |= linux.IN_MODIFY
+ }
+
+ if (mask & (linux.STATX_ATIME | linux.STATX_MTIME)) == (linux.STATX_ATIME | linux.STATX_MTIME) {
+ // Both times indicate a utime(s) call.
+ ev |= linux.IN_ATTRIB
+ } else if mask&linux.STATX_ATIME != 0 {
+ ev |= linux.IN_ACCESS
+ } else if mask&linux.STATX_MTIME != 0 {
+ ev |= linux.IN_MODIFY
+ }
+ return ev
+}
+
+// InotifyRemoveChild sends the appropriate notifications to the watch sets of
+// the child being removed and its parent. Note that unlike most pairs of
+// parent/child notifications, the child is notified first in this case.
+func InotifyRemoveChild(self, parent *Watches, name string) {
+ if self != nil {
+ self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */)
+ }
+ if parent != nil {
+ parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */)
+ }
+}
+
+// InotifyRename sends the appropriate notifications to the watch sets of the
+// file being renamed and its old/new parents.
+func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, oldName, newName string, isDir bool) {
+ var dirEv uint32
+ if isDir {
+ dirEv = linux.IN_ISDIR
+ }
+ cookie := uniqueid.InotifyCookie(ctx)
+ if oldParent != nil {
+ oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */)
+ }
+ if newParent != nil {
+ newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */)
+ }
+ // Somewhat surprisingly, self move events do not have a cookie.
+ if renamed != nil {
+ renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */)
+ }
+}
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
new file mode 100644
index 000000000..6c7583a81
--- /dev/null
+++ b/pkg/sentry/vfs/lock.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file provides POSIX and BSD style file locking for VFS2 file
+// implementations.
+//
+// The actual implementations can be found in the lock package under
+// sentry/fs/lock.
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2)
+// and flock(2) respectively in Linux. It can be embedded into various file
+// implementations for VFS2 that support locking.
+//
+// Note that in Linux these two types of locks are _not_ cooperative, because
+// race and deadlock conditions make merging them prohibitive. We do the same
+// and keep them oblivious to each other.
+type FileLocks struct {
+ // bsd is a set of BSD-style advisory file wide locks, see flock(2).
+ bsd fslock.Locks
+
+ // posix is a set of POSIX-style regional advisory locks, see fcntl(2).
+ posix fslock.Locks
+}
+
+// LockBSD tries to acquire a BSD-style lock on the entire file.
+func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) {
+ return nil
+ }
+ return syserror.ErrWouldBlock
+}
+
+// UnlockBSD releases a BSD-style lock on the entire file.
+//
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) {
+ fl.bsd.UnlockRegion(uid, fslock.LockRange{0, fslock.LockEOF})
+}
+
+// LockPOSIX tries to acquire a POSIX-style lock on a file region.
+func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ rng, err := computeRange(ctx, fd, start, length, whence)
+ if err != nil {
+ return err
+ }
+ if fl.posix.LockRegion(uid, t, rng, block) {
+ return nil
+ }
+ return syserror.ErrWouldBlock
+}
+
+// UnlockPOSIX releases a POSIX-style lock on a file region.
+//
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ rng, err := computeRange(ctx, fd, start, length, whence)
+ if err != nil {
+ return err
+ }
+ fl.posix.UnlockRegion(uid, rng)
+ return nil
+}
+
+func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) {
+ var off int64
+ switch whence {
+ case linux.SEEK_SET:
+ off = 0
+ case linux.SEEK_CUR:
+ // Note that Linux does not hold any mutexes while retrieving the file
+ // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
+ curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR)
+ if err != nil {
+ return fslock.LockRange{}, err
+ }
+ off = curOff
+ case linux.SEEK_END:
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE})
+ if err != nil {
+ return fslock.LockRange{}, err
+ }
+ off = int64(stat.Size)
+ default:
+ return fslock.LockRange{}, syserror.EINVAL
+ }
+
+ return fslock.ComputeRange(int64(start), int64(length), off)
+}
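+
+// For illustration: with whence == SEEK_CUR and the file offset at 100,
+// start = 10 and length = 20 describe the byte range [110, 130); a length of
+// 0 extends the range from 110 to EOF.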
diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD
new file mode 100644
index 000000000..d8c4d27b9
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memxattr",
+ srcs = ["xattr.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
new file mode 100644
index 000000000..cc1e7d764
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memxattr provides a default, in-memory extended attribute
+// implementation.
+package memxattr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SimpleExtendedAttributes implements extended attributes using a map of
+// names to values.
+//
+// +stateify savable
+type SimpleExtendedAttributes struct {
+ // mu protects the below fields.
+ mu sync.RWMutex `state:"nosave"`
+ xattrs map[string]string
+}
+
+// Getxattr returns the value at 'name'.
+func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) {
+ x.mu.RLock()
+ value, ok := x.xattrs[opts.Name]
+ x.mu.RUnlock()
+ if !ok {
+ return "", syserror.ENODATA
+ }
+ // Check that the size of the buffer provided in getxattr(2) is large enough
+ // to contain the value.
+ if opts.Size != 0 && uint64(len(value)) > opts.Size {
+ return "", syserror.ERANGE
+ }
+ return value, nil
+}
+
+// Setxattr sets 'value' at 'name'.
+func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if x.xattrs == nil {
+ if opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+ x.xattrs = make(map[string]string)
+ }
+
+ _, ok := x.xattrs[opts.Name]
+ if ok && opts.Flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
+ x.xattrs[opts.Name] = opts.Value
+ return nil
+}
+
+// Listxattr returns all names in xattrs.
+func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) {
+ // Keep track of the size of the buffer needed in listxattr(2) for the list.
+ listSize := 0
+ x.mu.RLock()
+ names := make([]string, 0, len(x.xattrs))
+ for n := range x.xattrs {
+ names = append(names, n)
+ // Add one byte per null terminator.
+ listSize += len(n) + 1
+ }
+ x.mu.RUnlock()
+ if size != 0 && uint64(listSize) > size {
+ return nil, syserror.ERANGE
+ }
+ return names, nil
+}
+
+// Removexattr removes the xattr at 'name'.
+func (x *SimpleExtendedAttributes) Removexattr(name string) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if _, ok := x.xattrs[name]; !ok {
+ return syserror.ENODATA
+ }
+ delete(x.xattrs, name)
+ return nil
+}
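+
+// A sketch of typical use (the inode type is hypothetical): embed the struct
+// in a filesystem's inode and forward the vfs xattr hooks to it.
+//
+//    type inode struct {
+//        xattrs memxattr.SimpleExtendedAttributes
+//        // ...
+//    }
+//
+// with the filesystem's GetxattrAt calling inode.xattrs.Getxattr(&opts), etc.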
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
new file mode 100644
index 000000000..32f901bd8
--- /dev/null
+++ b/pkg/sentry/vfs/mount.go
@@ -0,0 +1,903 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "sort"
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem
+// (Mount.key.parent.fs) with a Dentry (Mount.root) from another Filesystem
+// (Mount.fs), which applies to path resolution in the context of a particular
+// Mount (Mount.key.parent).
+//
+// Mounts are reference-counted. Unless otherwise specified, all Mount methods
+// require that a reference is held.
+//
+// Mount and Filesystem are distinct types because it's possible for a single
+// Filesystem to be mounted at multiple locations and/or in multiple mount
+// namespaces.
+//
+// Mount is analogous to Linux's struct mount. (gVisor does not distinguish
+// between struct mount and struct vfsmount.)
+//
+// +stateify savable
+type Mount struct {
+ // vfs, fs, root are immutable. References are held on fs and root.
+ //
+ // Invariant: root belongs to fs.
+ vfs *VirtualFilesystem
+ fs *Filesystem
+ root *Dentry
+
+ // ID is the immutable mount ID.
+ ID uint64
+
+ // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers". Immutable.
+ Flags MountFlags
+
+ // key is protected by VirtualFilesystem.mountMu and
+ // VirtualFilesystem.mounts.seq, and may be nil. References are held on
+ // key.parent and key.point if they are not nil.
+ //
+ // Invariant: key.parent != nil iff key.point != nil. key.point belongs to
+ // key.parent.fs.
+ key mountKey
+
+ // ns is the namespace in which this Mount was mounted. ns is protected by
+ // VirtualFilesystem.mountMu.
+ ns *MountNamespace
+
+ // The lower 63 bits of refs are a reference count. The MSB of refs is set
+ // if the Mount has been eagerly umounted, as by umount(2) without the
+ // MNT_DETACH flag. refs is accessed using atomic memory operations.
+ refs int64
+
+ // children is the set of all Mounts for which Mount.key.parent is this
+ // Mount. children is protected by VirtualFilesystem.mountMu.
+ children map[*Mount]struct{}
+
+ // umounted is true if VFS.umountRecursiveLocked() has been called on this
+ // Mount. VirtualFilesystem does not hold a reference on Mounts for which
+ // umounted is true. umounted is protected by VirtualFilesystem.mountMu.
+ umounted bool
+
+ // The lower 63 bits of writers is the number of calls to
+ // Mount.CheckBeginWrite() that have not yet been paired with a call to
+ // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
+ // writers is accessed using atomic memory operations.
+ writers int64
+}
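+
+// For illustration: since the MSB of refs doubles as the eager-umount flag,
+// the live reference count is always refs &^ math.MinInt64. An eagerly
+// umounted Mount with two outstanding references stores
+// refs == math.MinInt64 | 2, which tryIncMountedRef() rejects because
+// refs < 0.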
+
+func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
+ mnt := &Mount{
+ ID: atomic.AddUint64(&vfs.lastMountID, 1),
+ Flags: opts.Flags,
+ vfs: vfs,
+ fs: fs,
+ root: root,
+ ns: mntns,
+ refs: 1,
+ }
+ if opts.ReadOnly {
+ mnt.setReadOnlyLocked(true)
+ }
+ return mnt
+}
+
+// Options returns a copy of the MountOptions currently applicable to mnt.
+func (mnt *Mount) Options() MountOptions {
+ mnt.vfs.mountMu.Lock()
+ defer mnt.vfs.mountMu.Unlock()
+ return MountOptions{
+ Flags: mnt.Flags,
+ ReadOnly: mnt.readOnly(),
+ }
+}
+
+// A MountNamespace is a collection of Mounts.
+//
+// MountNamespaces are reference-counted. Unless otherwise specified, all
+// MountNamespace methods require that a reference is held.
+//
+// MountNamespace is analogous to Linux's struct mnt_namespace.
+//
+// +stateify savable
+type MountNamespace struct {
+	// Owner is the user namespace that owns this mount namespace.
+ Owner *auth.UserNamespace
+
+ // root is the MountNamespace's root mount. root is immutable.
+ root *Mount
+
+ // refs is the reference count. refs is accessed using atomic memory
+ // operations.
+ refs int64
+
+ // mountpoints maps all Dentries which are mount points in this namespace
+ // to the number of Mounts for which they are mount points. mountpoints is
+ // protected by VirtualFilesystem.mountMu.
+ //
+ // mountpoints is used to determine if a Dentry can be moved or removed
+ // (which requires that the Dentry is not a mount point in the calling
+ // namespace).
+ //
+ // mountpoints is maintained even if there are no references held on the
+ // MountNamespace; this is required to ensure that
+ // VFS.PrepareDeleteDentry() and VFS.PrepareRemoveDentry() operate
+ // correctly on unreferenced MountNamespaces.
+ mountpoints map[*Dentry]uint32
+}
+
+// NewMountNamespace returns a new mount namespace with a root filesystem
+// configured by the given arguments. A reference is taken on the returned
+// MountNamespace.
+func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
+ ctx.Warningf("Unknown filesystem type: %s", fsTypeName)
+ return nil, syserror.ENODEV
+ }
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
+ if err != nil {
+ return nil, err
+ }
+ mntns := &MountNamespace{
+ Owner: creds.UserNamespace,
+ refs: 1,
+ mountpoints: make(map[*Dentry]uint32),
+ }
+ mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{})
+ return mntns, nil
+}
+
+// NewDisconnectedMount returns a Mount representing fs with the given root
+// (which may be nil). The new Mount is not associated with any MountNamespace
+// and is not connected to any other Mounts. References are taken on fs and
+// root.
+func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, opts *MountOptions) (*Mount, error) {
+ fs.IncRef()
+ if root != nil {
+ root.IncRef()
+ }
+ return newMount(vfs, fs, root, nil /* mntns */, opts), nil
+}
+
+// MountDisconnected creates a Filesystem configured by the given arguments,
+// then returns a Mount representing it. The new Mount is not associated with
+// any MountNamespace and is not connected to any other Mounts.
+func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth.Credentials, source string, fsTypeName string, opts *MountOptions) (*Mount, error) {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
+ return nil, syserror.ENODEV
+ }
+ if !opts.InternalMount && !rft.opts.AllowUserMount {
+ return nil, syserror.ENODEV
+ }
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
+ if err != nil {
+ return nil, err
+ }
+ defer root.DecRef()
+ defer fs.DecRef()
+ return vfs.NewDisconnectedMount(fs, root, opts)
+}
+
+// ConnectMountAt connects mnt at the path represented by target.
+//
+// Preconditions: mnt must be disconnected.
+func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Credentials, mnt *Mount, target *PathOperation) error {
+ // We can't hold vfs.mountMu while calling FilesystemImpl methods due to
+ // lock ordering.
+ vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ vfs.mountMu.Lock()
+ vd.dentry.mu.Lock()
+ for {
+ if vd.dentry.dead {
+ vd.dentry.mu.Unlock()
+ vfs.mountMu.Unlock()
+ vd.DecRef()
+ return syserror.ENOENT
+ }
+ // vd might have been mounted over between vfs.GetDentryAt() and
+ // vfs.mountMu.Lock().
+ if !vd.dentry.isMounted() {
+ break
+ }
+ nextmnt := vfs.mounts.Lookup(vd.mount, vd.dentry)
+ if nextmnt == nil {
+ break
+ }
+ // It's possible that nextmnt has been umounted but not disconnected,
+ // in which case vfs no longer holds a reference on it, and the last
+ // reference may be concurrently dropped even though we're holding
+ // vfs.mountMu.
+ if !nextmnt.tryIncMountedRef() {
+ break
+ }
+ // This can't fail since we're holding vfs.mountMu.
+ nextmnt.root.IncRef()
+ vd.dentry.mu.Unlock()
+ vd.DecRef()
+ vd = VirtualDentry{
+ mount: nextmnt,
+ dentry: nextmnt.root,
+ }
+ vd.dentry.mu.Lock()
+ }
+ // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount
+ // point and the mount root are directories, or neither are, and returns
+ // ENOTDIR if this is not the case.
+ mntns := vd.mount.ns
+ vfs.mounts.seq.BeginWrite()
+ vfs.connectLocked(mnt, vd, mntns)
+ vfs.mounts.seq.EndWrite()
+ vd.dentry.mu.Unlock()
+ vfs.mountMu.Unlock()
+ return nil
+}
+
+// MountAt creates and mounts a Filesystem configured by the given arguments.
+func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
+ mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts)
+ if err != nil {
+ return err
+ }
+ defer mnt.DecRef()
+ if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil {
+ return err
+ }
+ return nil
+}
+
+// UmountAt removes the Mount at the given path.
+func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error {
+ if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 {
+ return syserror.EINVAL
+ }
+
+ // MNT_FORCE is currently unimplemented except for the permission check.
+ // Force unmounting specifically requires CAP_SYS_ADMIN in the root user
+ // namespace, and not in the owner user namespace for the target mount. See
+ // fs/namespace.c:SYSCALL_DEFINE2(umount, ...)
+	if opts.Flags&linux.MNT_FORCE != 0 && !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
+ return syserror.EPERM
+ }
+
+ vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ defer vd.DecRef()
+ if vd.dentry != vd.mount.root {
+ return syserror.EINVAL
+ }
+ vfs.mountMu.Lock()
+ if mntns := MountNamespaceFromContext(ctx); mntns != nil {
+ defer mntns.DecRef()
+ if mntns != vd.mount.ns {
+ vfs.mountMu.Unlock()
+ return syserror.EINVAL
+ }
+ }
+
+ // TODO(gvisor.dev/issue/1035): Linux special-cases umount of the caller's
+ // root, which we don't implement yet (we'll just fail it since the caller
+ // holds a reference on it).
+
+ vfs.mounts.seq.BeginWrite()
+ if opts.Flags&linux.MNT_DETACH == 0 {
+ if len(vd.mount.children) != 0 {
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ // We are holding a reference on vd.mount.
+ expectedRefs := int64(1)
+ if !vd.mount.umounted {
+ expectedRefs = 2
+ }
+ if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ }
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{
+ eager: opts.Flags&linux.MNT_DETACH == 0,
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
+ return nil
+}
+
+type umountRecursiveOptions struct {
+ // If eager is true, ensure that future calls to Mount.tryIncMountedRef()
+ // on umounted mounts fail.
+ //
+ // eager is analogous to Linux's UMOUNT_SYNC.
+ eager bool
+
+ // If disconnectHierarchy is true, Mounts that are umounted hierarchically
+ // should be disconnected from their parents. (Mounts whose parents are not
+ // umounted, which in most cases means the Mount passed to the initial call
+ // to umountRecursiveLocked, are unconditionally disconnected for
+ // consistency with Linux.)
+ //
+ // disconnectHierarchy is analogous to Linux's !UMOUNT_CONNECTED.
+ disconnectHierarchy bool
+}
+
+// umountRecursiveLocked marks mnt and its descendants as umounted. It does not
+// release mount or dentry references; instead, it appends VirtualDentries and
+// Mounts on which references must be dropped to vdsToDecRef and mountsToDecRef
+// respectively, and returns updated slices. (This is necessary because
+// filesystem locks possibly taken by DentryImpl.DecRef() may precede
+// vfs.mountMu in the lock order, and Mount.DecRef() may lock vfs.mountMu.)
+//
+// umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree().
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section.
+func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) {
+ if !mnt.umounted {
+ mnt.umounted = true
+ mountsToDecRef = append(mountsToDecRef, mnt)
+ if parent := mnt.parent(); parent != nil && (opts.disconnectHierarchy || !parent.umounted) {
+ vdsToDecRef = append(vdsToDecRef, vfs.disconnectLocked(mnt))
+ }
+ }
+ if opts.eager {
+ for {
+ refs := atomic.LoadInt64(&mnt.refs)
+ if refs < 0 {
+ break
+ }
+ if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs|math.MinInt64) {
+ break
+ }
+ }
+ }
+ for child := range mnt.children {
+ vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(child, opts, vdsToDecRef, mountsToDecRef)
+ }
+ return vdsToDecRef, mountsToDecRef
+}
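+
+// Callers typically drop the references returned by umountRecursiveLocked
+// only after releasing vfs.mountMu, as UmountAt does:
+//
+//	vds, mnts := vfs.umountRecursiveLocked(mnt, opts, nil, nil)
+//	// ... seq.EndWrite(), mountMu.Unlock() ...
+//	for _, vd := range vds {
+//		vd.DecRef()
+//	}
+//	for _, m := range mnts {
+//		m.DecRef()
+//	}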
+
+// connectLocked makes vd the mount parent/point for mnt. It consumes
+// references held by vd.
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt
+// must not already be connected.
+func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) {
+ if checkInvariants {
+ if mnt.parent() != nil {
+ panic("VFS.connectLocked called on connected mount")
+ }
+ }
+ mnt.IncRef() // dropped by callers of umountRecursiveLocked
+ mnt.storeKey(vd)
+ if vd.mount.children == nil {
+ vd.mount.children = make(map[*Mount]struct{})
+ }
+ vd.mount.children[mnt] = struct{}{}
+ atomic.AddUint32(&vd.dentry.mounts, 1)
+ mnt.ns = mntns
+ mntns.mountpoints[vd.dentry]++
+ vfs.mounts.insertSeqed(mnt)
+ vfsmpmounts, ok := vfs.mountpoints[vd.dentry]
+ if !ok {
+ vfsmpmounts = make(map[*Mount]struct{})
+ vfs.mountpoints[vd.dentry] = vfsmpmounts
+ }
+ vfsmpmounts[mnt] = struct{}{}
+}
+
+// disconnectLocked makes vd have no mount parent/point and returns its old
+// mount parent/point with a reference held.
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section. mnt.parent() != nil.
+func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
+ vd := mnt.loadKey()
+ if checkInvariants {
+		if vd.mount == nil {
+ panic("VFS.disconnectLocked called on disconnected mount")
+ }
+ }
+ mnt.storeKey(VirtualDentry{})
+ delete(vd.mount.children, mnt)
+ atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1
+ mnt.ns.mountpoints[vd.dentry]--
+ if mnt.ns.mountpoints[vd.dentry] == 0 {
+ delete(mnt.ns.mountpoints, vd.dentry)
+ }
+ vfs.mounts.removeSeqed(mnt)
+ vfsmpmounts := vfs.mountpoints[vd.dentry]
+ delete(vfsmpmounts, mnt)
+ if len(vfsmpmounts) == 0 {
+ delete(vfs.mountpoints, vd.dentry)
+ }
+ return vd
+}
+
+// tryIncMountedRef increments mnt's reference count and returns true. If
+// mnt's reference count is already zero, or if mnt has been eagerly umounted,
+// tryIncMountedRef does nothing and returns false.
+//
+// tryIncMountedRef does not require that a reference is held on mnt.
+func (mnt *Mount) tryIncMountedRef() bool {
+ for {
+ refs := atomic.LoadInt64(&mnt.refs)
+ if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// IncRef increments mnt's reference count.
+func (mnt *Mount) IncRef() {
+ // In general, negative values for mnt.refs are valid because the MSB is
+ // the eager-unmount bit.
+ atomic.AddInt64(&mnt.refs, 1)
+}
+
+// DecRef decrements mnt's reference count.
+func (mnt *Mount) DecRef() {
+ refs := atomic.AddInt64(&mnt.refs, -1)
+ if refs&^math.MinInt64 == 0 { // mask out MSB
+ var vd VirtualDentry
+ if mnt.parent() != nil {
+ mnt.vfs.mountMu.Lock()
+ mnt.vfs.mounts.seq.BeginWrite()
+ vd = mnt.vfs.disconnectLocked(mnt)
+ mnt.vfs.mounts.seq.EndWrite()
+ mnt.vfs.mountMu.Unlock()
+ }
+ mnt.root.DecRef()
+ mnt.fs.DecRef()
+ if vd.Ok() {
+ vd.DecRef()
+ }
+ }
+}
+
+// IncRef increments mntns' reference count.
+func (mntns *MountNamespace) IncRef() {
+ if atomic.AddInt64(&mntns.refs, 1) <= 1 {
+ panic("MountNamespace.IncRef() called without holding a reference")
+ }
+}
+
+// DecRef decrements mntns' reference count.
+func (mntns *MountNamespace) DecRef() {
+ vfs := mntns.root.fs.VirtualFilesystem()
+ if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 {
+ vfs.mountMu.Lock()
+ vfs.mounts.seq.BeginWrite()
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
+ } else if refs < 0 {
+ panic("MountNamespace.DecRef() called without holding a reference")
+ }
+}
+
+// getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes
+// a reference on the returned Mount. If (mnt, d) is not a mount point,
+// getMountAt returns nil.
+//
+// getMountAt is analogous to Linux's fs/namei.c:follow_mount().
+//
+// Preconditions: References are held on mnt and d.
+func (vfs *VirtualFilesystem) getMountAt(mnt *Mount, d *Dentry) *Mount {
+ // The first mount is special-cased:
+ //
+ // - The caller is assumed to have checked d.isMounted() already. (This
+ // isn't a precondition because it doesn't matter for correctness.)
+ //
+ // - We return nil, instead of mnt, if there is no mount at (mnt, d).
+ //
+ // - We don't drop the caller's references on mnt and d.
+retryFirst:
+ next := vfs.mounts.Lookup(mnt, d)
+ if next == nil {
+ return nil
+ }
+ if !next.tryIncMountedRef() {
+ // Raced with umount.
+ goto retryFirst
+ }
+ mnt = next
+ d = next.root
+ // We don't need to take Dentry refs anywhere in this function because
+ // Mounts hold references on Mount.root, which is immutable.
+ for d.isMounted() {
+ next := vfs.mounts.Lookup(mnt, d)
+ if next == nil {
+ break
+ }
+ if !next.tryIncMountedRef() {
+ // Raced with umount.
+ continue
+ }
+ mnt.DecRef()
+ mnt = next
+ d = next.root
+ }
+ return mnt
+}
+
+// getMountpointAt returns the mount point for the stack of Mounts including
+// mnt. It takes a reference on the returned VirtualDentry. If no such mount
+// point exists (i.e. mnt is a root mount), getMountpointAt returns an empty
+// VirtualDentry.
+//
+// Preconditions: References are held on mnt and vfsroot. vfsroot is not
+// (mnt, mnt.root).
+func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry {
+ // The first mount is special-cased:
+ //
+ // - The caller must have already checked mnt against vfsroot.
+ //
+	// - We return an empty VirtualDentry if there is no mount point for mnt.
+ //
+ // - We don't drop the caller's reference on mnt.
+retryFirst:
+ epoch := vfs.mounts.seq.BeginRead()
+ parent, point := mnt.parent(), mnt.point()
+ if !vfs.mounts.seq.ReadOk(epoch) {
+ goto retryFirst
+ }
+ if parent == nil {
+ return VirtualDentry{}
+ }
+ if !parent.tryIncMountedRef() {
+ // Raced with umount.
+ goto retryFirst
+ }
+ if !point.TryIncRef() {
+ // Since Mount holds a reference on Mount.key.point, this can only
+ // happen due to a racing change to Mount.key.
+ parent.DecRef()
+ goto retryFirst
+ }
+ if !vfs.mounts.seq.ReadOk(epoch) {
+ point.DecRef()
+ parent.DecRef()
+ goto retryFirst
+ }
+ mnt = parent
+ d := point
+ for {
+ if mnt == vfsroot.mount && d == vfsroot.dentry {
+ break
+ }
+ if d != mnt.root {
+ break
+ }
+ retryNotFirst:
+ epoch := vfs.mounts.seq.BeginRead()
+ parent, point := mnt.parent(), mnt.point()
+ if !vfs.mounts.seq.ReadOk(epoch) {
+ goto retryNotFirst
+ }
+ if parent == nil {
+ break
+ }
+ if !parent.tryIncMountedRef() {
+ // Raced with umount.
+ goto retryNotFirst
+ }
+ if !point.TryIncRef() {
+ // Since Mount holds a reference on Mount.key.point, this can
+ // only happen due to a racing change to Mount.key.
+ parent.DecRef()
+ goto retryNotFirst
+ }
+ if !vfs.mounts.seq.ReadOk(epoch) {
+ point.DecRef()
+ parent.DecRef()
+ goto retryNotFirst
+ }
+ d.DecRef()
+ mnt.DecRef()
+ mnt = parent
+ d = point
+ }
+ return VirtualDentry{mnt, d}
+}
+
+// CheckBeginWrite increments the counter of in-progress write operations on
+// mnt. If mnt is mounted MS_RDONLY, CheckBeginWrite does nothing and returns
+// EROFS.
+//
+// If CheckBeginWrite succeeds, EndWrite must be called when the write
+// operation is finished.
+func (mnt *Mount) CheckBeginWrite() error {
+ if atomic.AddInt64(&mnt.writers, 1) < 0 {
+ atomic.AddInt64(&mnt.writers, -1)
+ return syserror.EROFS
+ }
+ return nil
+}
+
+// EndWrite indicates that a write operation signaled by a previous successful
+// call to CheckBeginWrite has finished.
+func (mnt *Mount) EndWrite() {
+ atomic.AddInt64(&mnt.writers, -1)
+}
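+
+// A write path typically brackets its mutation with this pair, e.g.:
+//
+//	if err := mnt.CheckBeginWrite(); err != nil {
+//		return err // EROFS if the mount is read-only
+//	}
+//	defer mnt.EndWrite()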
+
+// Preconditions: VirtualFilesystem.mountMu must be locked.
+func (mnt *Mount) setReadOnlyLocked(ro bool) error {
+ if oldRO := atomic.LoadInt64(&mnt.writers) < 0; oldRO == ro {
+ return nil
+ }
+ if ro {
+ if !atomic.CompareAndSwapInt64(&mnt.writers, 0, math.MinInt64) {
+ return syserror.EBUSY
+ }
+ return nil
+ }
+ // Unset MSB without dropping any temporary increments from failed calls to
+ // mnt.CheckBeginWrite().
+ atomic.AddInt64(&mnt.writers, math.MinInt64)
+ return nil
+}
+
+func (mnt *Mount) readOnly() bool {
+ return atomic.LoadInt64(&mnt.writers) < 0
+}
+
+// Filesystem returns the mounted Filesystem. It does not take a reference on
+// the returned Filesystem.
+func (mnt *Mount) Filesystem() *Filesystem {
+ return mnt.fs
+}
+
+// submountsLocked returns this Mount and all Mounts that are descendants of
+// it.
+//
+// Precondition: mnt.vfs.mountMu must be held.
+func (mnt *Mount) submountsLocked() []*Mount {
+ mounts := []*Mount{mnt}
+ for m := range mnt.children {
+ mounts = append(mounts, m.submountsLocked()...)
+ }
+ return mounts
+}
+
+// Root returns the mount's root. It does not take a reference on the returned
+// Dentry.
+func (mnt *Mount) Root() *Dentry {
+ return mnt.root
+}
+
+// Root returns mntns' root. A reference is taken on the returned
+// VirtualDentry.
+func (mntns *MountNamespace) Root() VirtualDentry {
+ vd := VirtualDentry{
+ mount: mntns.root,
+ dentry: mntns.root.root,
+ }
+ vd.IncRef()
+ return vd
+}
+
+// GenerateProcMounts emits the contents of /proc/[pid]/mounts for vfs to buf.
+//
+// Preconditions: taskRootDir.Ok().
+func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
+ vfs.mountMu.Lock()
+ defer vfs.mountMu.Unlock()
+ rootMnt := taskRootDir.mount
+ mounts := rootMnt.submountsLocked()
+ sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+ for _, mnt := range mounts {
+ // Get the path to this mount relative to task root.
+ mntRootVD := VirtualDentry{
+ mount: mnt,
+ dentry: mnt.root,
+ }
+ path, err := vfs.PathnameReachable(ctx, taskRootDir, mntRootVD)
+ if err != nil {
+ // For some reason we didn't get a path. Log a warning
+ // and run with empty path.
+ ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ path = ""
+ }
+ if path == "" {
+ // Either an error occurred, or path is not reachable
+ // from root.
+			continue
+ }
+
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+ if mnt.Flags.NoATime {
+			opts += ",noatime"
+ }
+ if mnt.Flags.NoExec {
+ opts += ",noexec"
+ }
+
+ // Format:
+ // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
+ //
+ // The "needs dump" and "fsck order" flags are always 0, which
+ // is allowed.
+ fmt.Fprintf(buf, "%s %s %s %s %d %d\n", "none", path, mnt.fs.FilesystemType().Name(), opts, 0, 0)
+ }
+}
+
+// GenerateProcMountInfo emits the contents of /proc/[pid]/mountinfo for vfs to
+// buf.
+//
+// Preconditions: taskRootDir.Ok().
+func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
+ vfs.mountMu.Lock()
+ defer vfs.mountMu.Unlock()
+ rootMnt := taskRootDir.mount
+ mounts := rootMnt.submountsLocked()
+ sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+ for _, mnt := range mounts {
+ // Get the path to this mount relative to task root.
+ mntRootVD := VirtualDentry{
+ mount: mnt,
+ dentry: mnt.root,
+ }
+ path, err := vfs.PathnameReachable(ctx, taskRootDir, mntRootVD)
+ if err != nil {
+ // For some reason we didn't get a path. Log a warning
+ // and run with empty path.
+ ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ path = ""
+ }
+ if path == "" {
+ // Either an error occurred, or path is not reachable
+ // from root.
+			continue
+ }
+ // Stat the mount root to get the major/minor device numbers.
+ pop := &PathOperation{
+ Root: mntRootVD,
+ Start: mntRootVD,
+ }
+ statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{})
+ if err != nil {
+ // Well that's not good. Ignore this mount.
+			continue
+ }
+
+ // Format:
+ // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
+ // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
+
+ // (1) Mount ID.
+ fmt.Fprintf(buf, "%d ", mnt.ID)
+
+ // (2) Parent ID (or this ID if there is no parent).
+ pID := mnt.ID
+ if p := mnt.parent(); p != nil {
+ pID = p.ID
+ }
+ fmt.Fprintf(buf, "%d ", pID)
+
+ // (3) Major:Minor device ID. We don't have a superblock, so we
+ // just use the root inode device number.
+ fmt.Fprintf(buf, "%d:%d ", statx.DevMajor, statx.DevMinor)
+
+ // (4) Root: the pathname of the directory in the filesystem
+ // which forms the root of this mount.
+ //
+ // NOTE(b/78135857): This will always be "/" until we implement
+ // bind mounts.
+ fmt.Fprintf(buf, "/ ")
+
+ // (5) Mount point (relative to process root).
+ fmt.Fprintf(buf, "%s ", manglePath(path))
+
+ // (6) Mount options.
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+ if mnt.Flags.NoATime {
+			opts += ",noatime"
+ }
+ if mnt.Flags.NoExec {
+ opts += ",noexec"
+ }
+ fmt.Fprintf(buf, "%s ", opts)
+
+ // (7) Optional fields: zero or more fields of the form "tag[:value]".
+ // (8) Separator: the end of the optional fields is marked by a single hyphen.
+ fmt.Fprintf(buf, "- ")
+
+ // (9) Filesystem type.
+ fmt.Fprintf(buf, "%s ", mnt.fs.FilesystemType().Name())
+
+ // (10) Mount source: filesystem-specific information or "none".
+ fmt.Fprintf(buf, "none ")
+
+ // (11) Superblock options, and final newline.
+ fmt.Fprintf(buf, "%s\n", superBlockOpts(path, mnt))
+ }
+}
+
+// manglePath replaces ' ', '\t', '\n', and '\\' with their octal equivalents.
+// See Linux fs/seq_file.c:mangle_path.
+func manglePath(p string) string {
+ r := strings.NewReplacer(" ", "\\040", "\t", "\\011", "\n", "\\012", "\\", "\\134")
+ return r.Replace(p)
+}
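+
+// For example, manglePath("/mnt/a b") returns "/mnt/a\040b", matching the
+// escaping Linux applies to mount points in /proc/[pid]/mountinfo.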
+
+// superBlockOpts returns the super block options string for the mount at
+// the given path.
+func superBlockOpts(mountPath string, mnt *Mount) string {
+ // gVisor doesn't (yet) have a concept of super block options, so we
+ // use the ro/rw bit from the mount flag.
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+
+ // NOTE(b/147673608): If the mount is a cgroup, we also need to include
+ // the cgroup name in the options. For now we just read that from the
+ // path.
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
+ // should get this value from the cgroup itself, and not rely on the
+ // path.
+ if mnt.fs.FilesystemType().Name() == "cgroup" {
+ splitPath := strings.Split(mountPath, "/")
+ cgroupType := splitPath[len(splitPath)-1]
+ opts += "," + cgroupType
+ }
+ return opts
+}
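+
+// For example, a writable cgroup filesystem mounted at
+// /sys/fs/cgroup/memory yields "rw,memory", with the cgroup type recovered
+// from the final path component.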
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
new file mode 100644
index 000000000..3335e4057
--- /dev/null
+++ b/pkg/sentry/vfs/mount_test.go
@@ -0,0 +1,458 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+ "runtime"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestMountTableLookupEmpty(t *testing.T) {
+ var mt mountTable
+ mt.Init()
+
+ parent := &Mount{}
+ point := &Dentry{}
+ if m := mt.Lookup(parent, point); m != nil {
+ t.Errorf("empty mountTable lookup: got %p, wanted nil", m)
+ }
+}
+
+func TestMountTableInsertLookup(t *testing.T) {
+ var mt mountTable
+ mt.Init()
+
+ mount := &Mount{}
+ mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
+ mt.Insert(mount)
+
+ if m := mt.Lookup(mount.parent(), mount.point()); m != mount {
+ t.Errorf("mountTable positive lookup: got %p, wanted %p", m, mount)
+ }
+
+ otherParent := &Mount{}
+ if m := mt.Lookup(otherParent, mount.point()); m != nil {
+ t.Errorf("mountTable lookup with wrong mount parent: got %p, wanted nil", m)
+ }
+ otherPoint := &Dentry{}
+ if m := mt.Lookup(mount.parent(), otherPoint); m != nil {
+ t.Errorf("mountTable lookup with wrong mount point: got %p, wanted nil", m)
+ }
+}
+
+// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal.
+
+// must be powers of 2
+var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8}
+
+// For all of the following:
+//
+// - BenchmarkMountTableFoo tests usage pattern "Foo" for mountTable.
+//
+// - BenchmarkMountMapFoo tests usage pattern "Foo" for a
+// sync.RWMutex-protected map. (Mutator benchmarks do not use a RWMutex, since
+// mountTable also requires external synchronization between mutators.)
+//
+// - BenchmarkMountSyncMapFoo tests usage pattern "Foo" for a sync.Map.
+//
+// ParallelLookup is by far the most common and performance-sensitive operation
+// for this application. NegativeLookup is also important, but less so (only
+// relevant with multiple mount namespaces and significant differences in
+// mounts between them). Insertion and removal are benchmarked for
+// completeness.
+const enableComparativeBenchmarks = false
+
+func newBenchMount() *Mount {
+ mount := &Mount{}
+ mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
+ return mount
+}
+
+func BenchmarkMountTableParallelLookup(b *testing.B) {
+ for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 {
+ for _, numMounts := range benchNumMounts {
+ desc := fmt.Sprintf("%dx%d", numG, numMounts)
+ b.Run(desc, func(b *testing.B) {
+ var mt mountTable
+ mt.Init()
+ keys := make([]VirtualDentry, 0, numMounts)
+ for i := 0; i < numMounts; i++ {
+ mount := newBenchMount()
+ mt.Insert(mount)
+ keys = append(keys, mount.loadKey())
+ }
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for g := 0; g < numG; g++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ defer end.Done()
+ ready.Done()
+ <-begin
+ for i := 0; i < b.N; i++ {
+ k := keys[i&(numMounts-1)]
+ m := mt.Lookup(k.mount, k.dentry)
+ if m == nil {
+ b.Fatalf("lookup failed")
+ }
+ if parent := m.parent(); parent != k.mount {
+ b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ }
+ if point := m.point(); point != k.dentry {
+ b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ }
+ }
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+ }
+}
+
+func BenchmarkMountMapParallelLookup(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 {
+ for _, numMounts := range benchNumMounts {
+ desc := fmt.Sprintf("%dx%d", numG, numMounts)
+ b.Run(desc, func(b *testing.B) {
+ var mu sync.RWMutex
+ ms := make(map[VirtualDentry]*Mount)
+ keys := make([]VirtualDentry, 0, numMounts)
+ for i := 0; i < numMounts; i++ {
+ mount := newBenchMount()
+ key := mount.loadKey()
+ ms[key] = mount
+ keys = append(keys, key)
+ }
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for g := 0; g < numG; g++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ defer end.Done()
+ ready.Done()
+ <-begin
+ for i := 0; i < b.N; i++ {
+ k := keys[i&(numMounts-1)]
+ mu.RLock()
+ m := ms[k]
+ mu.RUnlock()
+ if m == nil {
+ b.Fatalf("lookup failed")
+ }
+ if parent := m.parent(); parent != k.mount {
+ b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ }
+ if point := m.point(); point != k.dentry {
+ b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ }
+ }
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+ }
+}
+
+func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 {
+ for _, numMounts := range benchNumMounts {
+ desc := fmt.Sprintf("%dx%d", numG, numMounts)
+ b.Run(desc, func(b *testing.B) {
+ var ms sync.Map
+ keys := make([]VirtualDentry, 0, numMounts)
+ for i := 0; i < numMounts; i++ {
+ mount := newBenchMount()
+ key := mount.loadKey()
+ ms.Store(key, mount)
+ keys = append(keys, key)
+ }
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for g := 0; g < numG; g++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ defer end.Done()
+ ready.Done()
+ <-begin
+ for i := 0; i < b.N; i++ {
+ k := keys[i&(numMounts-1)]
+ mi, ok := ms.Load(k)
+ if !ok {
+ b.Fatalf("lookup failed")
+ }
+ m := mi.(*Mount)
+ if parent := m.parent(); parent != k.mount {
+ b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ }
+ if point := m.point(); point != k.dentry {
+ b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ }
+ }
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+ }
+}
+
+func BenchmarkMountTableNegativeLookup(b *testing.B) {
+ for _, numMounts := range benchNumMounts {
+ desc := fmt.Sprintf("%d", numMounts)
+ b.Run(desc, func(b *testing.B) {
+ var mt mountTable
+ mt.Init()
+ for i := 0; i < numMounts; i++ {
+ mt.Insert(newBenchMount())
+ }
+ negkeys := make([]VirtualDentry, 0, numMounts)
+ for i := 0; i < numMounts; i++ {
+ negkeys = append(negkeys, VirtualDentry{
+ mount: &Mount{},
+ dentry: &Dentry{},
+ })
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ k := negkeys[i&(numMounts-1)]
+ m := mt.Lookup(k.mount, k.dentry)
+ if m != nil {
+ b.Fatalf("lookup got %p, wanted nil", m)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMountMapNegativeLookup(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ for _, numMounts := range benchNumMounts {
+ desc := fmt.Sprintf("%d", numMounts)
+ b.Run(desc, func(b *testing.B) {
+ var mu sync.RWMutex
+ ms := make(map[VirtualDentry]*Mount)
+ for i := 0; i < numMounts; i++ {
+ mount := newBenchMount()
+ ms[mount.loadKey()] = mount
+ }
+ negkeys := make([]VirtualDentry, 0, numMounts)
+ for i := 0; i < numMounts; i++ {
+ negkeys = append(negkeys, VirtualDentry{
+ mount: &Mount{},
+ dentry: &Dentry{},
+ })
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ k := negkeys[i&(numMounts-1)]
+ mu.RLock()
+ m := ms[k]
+ mu.RUnlock()
+ if m != nil {
+ b.Fatalf("lookup got %p, wanted nil", m)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ for _, numMounts := range benchNumMounts {
+ desc := fmt.Sprintf("%d", numMounts)
+ b.Run(desc, func(b *testing.B) {
+ var ms sync.Map
+ for i := 0; i < numMounts; i++ {
+ mount := newBenchMount()
+ ms.Store(mount.loadKey(), mount)
+ }
+ negkeys := make([]VirtualDentry, 0, numMounts)
+ for i := 0; i < numMounts; i++ {
+ negkeys = append(negkeys, VirtualDentry{
+ mount: &Mount{},
+ dentry: &Dentry{},
+ })
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ k := negkeys[i&(numMounts-1)]
+ m, _ := ms.Load(k)
+ if m != nil {
+ b.Fatalf("lookup got %p, wanted nil", m)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMountTableInsert(b *testing.B) {
+ // Preallocate Mounts so that allocation time isn't included in the
+ // benchmark.
+ mounts := make([]*Mount, 0, b.N)
+ for i := 0; i < b.N; i++ {
+ mounts = append(mounts, newBenchMount())
+ }
+
+ var mt mountTable
+ mt.Init()
+ b.ResetTimer()
+ for i := range mounts {
+ mt.Insert(mounts[i])
+ }
+}
+
+func BenchmarkMountMapInsert(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ // Preallocate Mounts so that allocation time isn't included in the
+ // benchmark.
+ mounts := make([]*Mount, 0, b.N)
+ for i := 0; i < b.N; i++ {
+ mounts = append(mounts, newBenchMount())
+ }
+
+ ms := make(map[VirtualDentry]*Mount)
+ b.ResetTimer()
+ for i := range mounts {
+ mount := mounts[i]
+ ms[mount.loadKey()] = mount
+ }
+}
+
+func BenchmarkMountSyncMapInsert(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ // Preallocate Mounts so that allocation time isn't included in the
+ // benchmark.
+ mounts := make([]*Mount, 0, b.N)
+ for i := 0; i < b.N; i++ {
+ mounts = append(mounts, newBenchMount())
+ }
+
+ var ms sync.Map
+ b.ResetTimer()
+ for i := range mounts {
+ mount := mounts[i]
+ ms.Store(mount.loadKey(), mount)
+ }
+}
+
+func BenchmarkMountTableRemove(b *testing.B) {
+ mounts := make([]*Mount, 0, b.N)
+ for i := 0; i < b.N; i++ {
+ mounts = append(mounts, newBenchMount())
+ }
+ var mt mountTable
+ mt.Init()
+ for i := range mounts {
+ mt.Insert(mounts[i])
+ }
+
+ b.ResetTimer()
+ for i := range mounts {
+ mt.Remove(mounts[i])
+ }
+}
+
+func BenchmarkMountMapRemove(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ mounts := make([]*Mount, 0, b.N)
+ for i := 0; i < b.N; i++ {
+ mounts = append(mounts, newBenchMount())
+ }
+ ms := make(map[VirtualDentry]*Mount)
+ for i := range mounts {
+ mount := mounts[i]
+ ms[mount.loadKey()] = mount
+ }
+
+ b.ResetTimer()
+ for i := range mounts {
+ mount := mounts[i]
+ delete(ms, mount.loadKey())
+ }
+}
+
+func BenchmarkMountSyncMapRemove(b *testing.B) {
+ if !enableComparativeBenchmarks {
+ b.Skipf("comparative benchmarks are disabled")
+ }
+
+ mounts := make([]*Mount, 0, b.N)
+ for i := 0; i < b.N; i++ {
+ mounts = append(mounts, newBenchMount())
+ }
+ var ms sync.Map
+ for i := range mounts {
+ mount := mounts[i]
+ ms.Store(mount.loadKey(), mount)
+ }
+
+ b.ResetTimer()
+ for i := range mounts {
+ mount := mounts[i]
+ ms.Delete(mount.loadKey())
+ }
+}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
new file mode 100644
index 000000000..70f850ca4
--- /dev/null
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -0,0 +1,364 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package vfs
+
+import (
+ "fmt"
+ "math/bits"
+ "reflect"
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// mountKey represents the location at which a Mount is mounted. It is
+// structurally identical to VirtualDentry, but stores its fields as
+// unsafe.Pointer since mutators synchronize with VFS path traversal using
+// seqcounts.
+type mountKey struct {
+ parent unsafe.Pointer // *Mount
+ point unsafe.Pointer // *Dentry
+}
+
+func (mnt *Mount) parent() *Mount {
+ return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
+}
+
+func (mnt *Mount) point() *Dentry {
+ return (*Dentry)(atomic.LoadPointer(&mnt.key.point))
+}
+
+func (mnt *Mount) loadKey() VirtualDentry {
+ return VirtualDentry{
+ mount: mnt.parent(),
+ dentry: mnt.point(),
+ }
+}
+
+// Invariant: mnt.key.parent == nil. vd.Ok().
+func (mnt *Mount) storeKey(vd VirtualDentry) {
+ atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
+ atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
+}
+
+// mountTable maps (mount parent, mount point) pairs to mounts. It supports
+// efficient concurrent lookup, even in the presence of concurrent mutators
+// (provided mutation is sufficiently uncommon).
+//
+// mountTable.Init() must be called on new mountTables before use.
+//
+// +stateify savable
+type mountTable struct {
+ // mountTable is implemented as a seqcount-protected hash table that
+ // resolves collisions with linear probing, featuring Robin Hood insertion
+ // and backward shift deletion. These minimize probe length variance,
+ // significantly improving the performance of linear probing at high load
+ // factors. (mountTable doesn't use bucketing, which is the other major
+ // technique commonly used in high-performance hash tables; the efficiency
+ // of bucketing is largely due to SIMD lookup, and Go lacks both SIMD
+ // intrinsics and inline assembly, limiting the performance of this
+ // approach.)
+
+ seq sync.SeqCount `state:"nosave"`
+ seed uint32 // for hashing keys
+
+ // size holds both length (number of elements) and capacity (number of
+ // slots): capacity is stored as its base-2 log (referred to as order) in
+ // the least significant bits of size, and length is stored in the
+ // remaining bits. Go defines bit shifts >= width of shifted unsigned
+ // operand as shifting to 0, which differs from x86's SHL, so the Go
+ // compiler inserts a bounds check for each bit shift unless we mask order
+ // anyway (cf. runtime.bucketShift()), and length isn't used by lookup;
+ // thus this bit packing gets us more bits for the length (vs. storing
+ // length and cap in separate uint32s) for ~free.
+ size uint64
+
+ slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
+}
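+
+// For example, a table holding 3 mounts in 8 (1<<3) slots stores
+// size == 3<<mtSizeLenLSB | 3: lookup recovers the capacity as
+// 1 << (size & mtSizeOrderMask), and mutators recover the length as
+// size >> mtSizeLenLSB.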
+
+type mountSlot struct {
+ // We don't store keys in slots; instead, we just check Mount.parent and
+ // Mount.point directly. Any practical use of lookup will need to touch
+ // Mounts anyway, and comparing hashes means that false positives are
+ // extremely rare, so this isn't an extra cache line touch overall.
+ value unsafe.Pointer // *Mount
+ hash uintptr
+}
+
+const (
+ mtSizeOrderBits = 6 // log2 of pointer size in bits
+ mtSizeOrderMask = (1 << mtSizeOrderBits) - 1
+ mtSizeOrderOne = 1
+ mtSizeLenLSB = mtSizeOrderBits
+ mtSizeLenOne = 1 << mtSizeLenLSB
+ mtSizeLenNegOne = ^uint64(mtSizeOrderMask) // uint64(-1) << mtSizeLenLSB
+
+ mountSlotBytes = unsafe.Sizeof(mountSlot{})
+ mountKeyBytes = unsafe.Sizeof(mountKey{})
+
+ // Tuning parameters.
+ //
+ // Essentially every mountTable will contain at least /proc, /sys, and
+ // /dev/shm, so there is ~no reason for mtInitCap to be < 4.
+ mtInitOrder = 2
+ mtInitCap = 1 << mtInitOrder
+ mtMaxLoadNum = 13
+ mtMaxLoadDen = 16
+)
+
+func init() {
+ // We can't just define mtSizeOrderBits as follows because Go doesn't have
+ // constexpr.
+ if ptrBits := uint(unsafe.Sizeof(uintptr(0)) * 8); mtSizeOrderBits != bits.TrailingZeros(ptrBits) {
+ panic(fmt.Sprintf("mtSizeOrderBits (%d) must be %d = log2 of pointer size in bits (%d)", mtSizeOrderBits, bits.TrailingZeros(ptrBits), ptrBits))
+ }
+ if bits.OnesCount(uint(mountSlotBytes)) != 1 {
+ panic(fmt.Sprintf("sizeof(mountSlotBytes) (%d) must be a power of 2 to use bit masking for wraparound", mountSlotBytes))
+ }
+ if mtInitCap <= 1 {
+ panic(fmt.Sprintf("mtInitCap (%d) must be at least 2 since mountTable methods assume that there will always be at least one empty slot", mtInitCap))
+ }
+ if mtMaxLoadNum >= mtMaxLoadDen {
+ panic(fmt.Sprintf("invalid mountTable maximum load factor (%d/%d)", mtMaxLoadNum, mtMaxLoadDen))
+ }
+}
+
+// Init must be called exactly once on each mountTable before use.
+func (mt *mountTable) Init() {
+ mt.seed = rand32()
+ mt.size = mtInitOrder
+ mt.slots = newMountTableSlots(mtInitCap)
+}
+
+func newMountTableSlots(cap uintptr) unsafe.Pointer {
+ slice := make([]mountSlot, cap, cap)
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ return unsafe.Pointer(hdr.Data)
+}
+
+// Lookup returns the Mount with the given parent, mounted at the given point.
+// If no such Mount exists, Lookup returns nil.
+//
+// Lookup may be called even if there are concurrent mutators of mt.
+func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
+ key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+
+loop:
+ for {
+ epoch := mt.seq.BeginRead()
+ size := atomic.LoadUint64(&mt.size)
+ slots := atomic.LoadPointer(&mt.slots)
+ if !mt.seq.ReadOk(epoch) {
+ continue
+ }
+ tcap := uintptr(1) << (size & mtSizeOrderMask)
+ mask := tcap - 1
+ off := (hash & mask) * mountSlotBytes
+ offmask := mask * mountSlotBytes
+ for {
+ // This avoids bounds checking.
+ slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
+ slotValue := atomic.LoadPointer(&slot.value)
+ slotHash := atomic.LoadUintptr(&slot.hash)
+ if !mt.seq.ReadOk(epoch) {
+ // The element we're looking for might have been moved into a
+ // slot we've previously checked, so restart entirely.
+ continue loop
+ }
+ if slotValue == nil {
+ return nil
+ }
+ if slotHash == hash {
+ mount := (*Mount)(slotValue)
+ var mountKey mountKey
+ mountKey.parent = atomic.LoadPointer(&mount.key.parent)
+ mountKey.point = atomic.LoadPointer(&mount.key.point)
+ if !mt.seq.ReadOk(epoch) {
+ continue loop
+ }
+ if key == mountKey {
+ return mount
+ }
+ }
+ off = (off + mountSlotBytes) & offmask
+ }
+ }
+}
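+
+// Note that a Lookup racing with table expansion or backward-shift deletion
+// may observe stale (size, slots) values or a partially shifted element; the
+// seq.ReadOk checks above detect this and restart the lookup rather than
+// returning a result read from a torn table.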
+
+// Insert inserts the given mount into mt.
+//
+// Preconditions: mt must not already contain a Mount with the same mount point
+// and parent.
+func (mt *mountTable) Insert(mount *Mount) {
+ mt.seq.BeginWrite()
+ mt.insertSeqed(mount)
+ mt.seq.EndWrite()
+}
+
+// insertSeqed inserts the given mount into mt.
+//
+// Preconditions: mt.seq must be in a writer critical section. mt must not
+// already contain a Mount with the same mount point and parent.
+func (mt *mountTable) insertSeqed(mount *Mount) {
+ hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+
+ // We're under the maximum load factor if:
+ //
+ // (len+1) / cap <= mtMaxLoadNum / mtMaxLoadDen
+ // (len+1) * mtMaxLoadDen <= mtMaxLoadNum * cap
+ tlen := mt.size >> mtSizeLenLSB
+ order := mt.size & mtSizeOrderMask
+ tcap := uintptr(1) << order
+ if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) {
+ // Atomically insert the new element into the table.
+ atomic.AddUint64(&mt.size, mtSizeLenOne)
+ mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash)
+ return
+ }
+
+ // Otherwise, we have to expand. Double the number of slots in the new
+ // table.
+ newOrder := order + 1
+ if newOrder > mtSizeOrderMask {
+ panic("mount table size overflow")
+ }
+ newCap := uintptr(1) << newOrder
+ newSlots := newMountTableSlots(newCap)
+ // Copy existing elements to the new table.
+ oldCur := mt.slots
+ // Go does not permit pointers to the end of allocated objects, so we
+ // must use a pointer to the last element of the old table. The
+ // following expression is equivalent to
+ // `slots+(cap-1)*mountSlotBytes` but has a critical path length of 2
+ // arithmetic instructions instead of 3.
+ oldLast := unsafe.Pointer((uintptr(mt.slots) - mountSlotBytes) + (tcap * mountSlotBytes))
+ for {
+ oldSlot := (*mountSlot)(oldCur)
+ if oldSlot.value != nil {
+ mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash)
+ }
+ if oldCur == oldLast {
+ break
+ }
+ oldCur = unsafe.Pointer(uintptr(oldCur) + mountSlotBytes)
+ }
+ // Insert the new element into the new table.
+ mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash)
+ // Switch to the new table.
+ atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne)
+ atomic.StorePointer(&mt.slots, newSlots)
+}
+
+// Preconditions: There are no concurrent mutators of the table (slots, cap).
+// If the table is visible to readers, then mt.seq must be in a writer critical
+// section. cap must be a power of 2.
+func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, hash uintptr) {
+ mask := cap - 1
+ off := (hash & mask) * mountSlotBytes
+ offmask := mask * mountSlotBytes
+ disp := uintptr(0)
+ for {
+ slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
+ slotValue := slot.value
+ if slotValue == nil {
+ atomic.StorePointer(&slot.value, value)
+ atomic.StoreUintptr(&slot.hash, hash)
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotHash := slot.hash
+ slotDisp := ((off / mountSlotBytes) - slotHash) & mask
+ if disp > slotDisp {
+ atomic.StorePointer(&slot.value, value)
+ atomic.StoreUintptr(&slot.hash, hash)
+ value = slotValue
+ hash = slotHash
+ disp = slotDisp
+ }
+ off = (off + mountSlotBytes) & offmask
+ disp++
+ }
+}
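+
+// For example, if the element being inserted has probed 3 slots past its
+// home slot (disp == 3) while the occupant of the current slot is only 1
+// slot from its own (slotDisp == 1), the two swap and insertion continues
+// with the displaced element, keeping probe-length variance low.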
+
+// Remove removes the given mount from mt.
+//
+// Preconditions: mt must contain mount.
+func (mt *mountTable) Remove(mount *Mount) {
+ mt.seq.BeginWrite()
+ mt.removeSeqed(mount)
+ mt.seq.EndWrite()
+}
+
+// removeSeqed removes the given mount from mt.
+//
+// Preconditions: mt.seq must be in a writer critical section. mt must contain
+// mount.
+func (mt *mountTable) removeSeqed(mount *Mount) {
+ hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
+ mask := tcap - 1
+ slots := mt.slots
+ off := (hash & mask) * mountSlotBytes
+ offmask := mask * mountSlotBytes
+ for {
+ slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
+ slotValue := slot.value
+ if slotValue == unsafe.Pointer(mount) {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ nextOff := (off + mountSlotBytes) & offmask
+ nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff))
+ nextSlotValue := nextSlot.value
+ if nextSlotValue == nil {
+ break
+ }
+ nextSlotHash := nextSlot.hash
+ if (nextOff / mountSlotBytes) == (nextSlotHash & mask) {
+ break
+ }
+ atomic.StorePointer(&slot.value, nextSlotValue)
+ atomic.StoreUintptr(&slot.hash, nextSlotHash)
+ off = nextOff
+ slot = nextSlot
+ }
+ atomic.StorePointer(&slot.value, nil)
+ atomic.AddUint64(&mt.size, mtSizeLenNegOne)
+ return
+ }
+ if checkInvariants && slotValue == nil {
+ panic(fmt.Sprintf("mountTable.Remove() called on missing Mount %v", mount))
+ }
+ off = (off + mountSlotBytes) & offmask
+ }
+}
+
+//go:linkname memhash runtime.memhash
+func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
+
+//go:linkname rand32 runtime.fastrand
+func rand32() uint32
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
new file mode 100644
index 000000000..f223aeda8
--- /dev/null
+++ b/pkg/sentry/vfs/options.go
@@ -0,0 +1,235 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and
+// FilesystemImpl.GetDentryAt().
+type GetDentryOptions struct {
+ // If CheckSearchable is true, FilesystemImpl.GetDentryAt() must check that
+ // the returned Dentry is a directory for which creds has search
+ // permission.
+ CheckSearchable bool
+}
+
+// MkdirOptions contains options to VirtualFilesystem.MkdirAt() and
+// FilesystemImpl.MkdirAt().
+type MkdirOptions struct {
+ // Mode is the file mode bits for the created directory.
+ Mode linux.FileMode
+
+ // If ForSyntheticMountpoint is true, FilesystemImpl.MkdirAt() may create
+ // the given directory in memory only (as opposed to persistent storage).
+ // The created directory should be able to support the creation of
+ // subdirectories with ForSyntheticMountpoint == true. It does not need to
+ // support the creation of subdirectories with ForSyntheticMountpoint ==
+ // false, or files of other types.
+ //
+ // FilesystemImpls are permitted to ignore the ForSyntheticMountpoint
+ // option.
+ //
+ // The ForSyntheticMountpoint option exists because, unlike mount(2), the
+ // OCI Runtime Specification permits the specification of mount points that
+ // do not exist, under the expectation that container runtimes will create
+ // them. (More accurately, the OCI Runtime Specification completely fails
+ // to document this feature, but it's implemented by runc.)
+ // ForSyntheticMountpoint allows such mount points to be created even when
+ // the underlying persistent filesystem is immutable.
+ ForSyntheticMountpoint bool
+}
+
+// MknodOptions contains options to VirtualFilesystem.MknodAt() and
+// FilesystemImpl.MknodAt().
+type MknodOptions struct {
+ // Mode is the file type and mode bits for the created file.
+ Mode linux.FileMode
+
+ // If Mode specifies a character or block device special file, DevMajor and
+ // DevMinor are the major and minor device numbers for the created device.
+ DevMajor uint32
+ DevMinor uint32
+
+ // Endpoint is the endpoint to bind to the created file, if a socket file is
+ // being created for bind(2) on a Unix domain socket.
+ Endpoint transport.BoundEndpoint
+}
+
+// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+// MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers.
+type MountFlags struct {
+ // NoExec is equivalent to MS_NOEXEC.
+ NoExec bool
+
+ // NoATime is equivalent to MS_NOATIME and indicates that the
+ // filesystem should not update access time in-place.
+ NoATime bool
+}
+
+// MountOptions contains options to VirtualFilesystem.MountAt().
+type MountOptions struct {
+ // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+ Flags MountFlags
+
+ // ReadOnly is equivalent to MS_RDONLY.
+ ReadOnly bool
+
+ // GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
+ GetFilesystemOptions GetFilesystemOptions
+
+ // If InternalMount is true, allow the use of filesystem types for which
+ // RegisterFilesystemTypeOptions.AllowUserMount == false.
+ InternalMount bool
+}
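+
+// For example, a sentry-internal read-only mount of a filesystem type with
+// AllowUserMount == false might be configured as:
+//
+//	opts := MountOptions{ReadOnly: true, InternalMount: true}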
+
+// OpenOptions contains options to VirtualFilesystem.OpenAt() and
+// FilesystemImpl.OpenAt().
+type OpenOptions struct {
+ // Flags contains access mode and flags as specified for open(2).
+ //
+ // FilesystemImpls are responsible for implementing the following flags:
+ // O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC,
+ // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and
+ // O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and
+ // O_NOFOLLOW. VFS users are responsible for handling O_CLOEXEC, since file
+ // descriptors are mostly outside the scope of VFS.
+ Flags uint32
+
+ // If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the
+ // created file.
+ Mode linux.FileMode
+
+ // FileExec is set when the file is being opened to be executed.
+ // VirtualFilesystem.OpenAt() checks that the caller has execute permissions
+ // on the file, that the file is a regular file, and that the mount doesn't
+ // have MS_NOEXEC set.
+ FileExec bool
+}
+
+// ReadOptions contains options to FileDescription.PRead(),
+// FileDescriptionImpl.PRead(), FileDescription.Read(), and
+// FileDescriptionImpl.Read().
+type ReadOptions struct {
+ // Flags contains flags as specified for preadv2(2).
+ Flags uint32
+}
+
+// RenameOptions contains options to VirtualFilesystem.RenameAt() and
+// FilesystemImpl.RenameAt().
+type RenameOptions struct {
+ // Flags contains flags as specified for renameat2(2).
+ Flags uint32
+
+ // If MustBeDir is true, the renamed file must be a directory.
+ MustBeDir bool
+}
+
+// SetStatOptions contains options to VirtualFilesystem.SetStatAt(),
+// FilesystemImpl.SetStatAt(), FileDescription.SetStat(), and
+// FileDescriptionImpl.SetStat().
+type SetStatOptions struct {
+ // Stat is the metadata that should be set. Only fields indicated by
+ // Stat.Mask should be set.
+ //
+ // If Stat specifies that a timestamp should be set,
+ // FilesystemImpl.SetStatAt() and FileDescriptionImpl.SetStat() must
+ // special-case StatxTimestamp.Nsec == UTIME_NOW as described by
+ // utimensat(2); however, they do not need to check for StatxTimestamp.Nsec
+ // == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask
+ // instead).
+ Stat linux.Statx
+}
+
+// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
+// and FilesystemImpl.BoundEndpointAt().
+type BoundEndpointOptions struct {
+ // Addr is the path of the file whose socket endpoint is being retrieved.
+ // It is generally irrelevant: most endpoints are stored at a dentry that
+ // was created through a bind syscall, so the path can be stored on creation.
+ // However, if the endpoint was created in FilesystemImpl.BoundEndpointAt(),
+ // then we may not know what the original bind address was.
+ //
+	// For example, if connect(2) is called with address "foo" which corresponds
+	// to a remote named socket in goferfs, we need to generate an endpoint wrapping
+ // that file. In this case, we can use Addr to set the endpoint address to
+ // "foo". Note that Addr is only a best-effort attempt--we still do not know
+ // the exact address that was used on the remote fs to bind the socket (it
+ // may have been "foo", "./foo", etc.).
+ Addr string
+}
+
+// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(),
+// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and
+// FileDescriptionImpl.Getxattr().
+type GetxattrOptions struct {
+ // Name is the name of the extended attribute to retrieve.
+ Name string
+
+ // Size is the maximum value size that the caller will tolerate. If the value
+ // is larger than size, getxattr methods may return ERANGE, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ Size uint64
+}
+
+// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
+// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
+// FileDescriptionImpl.Setxattr().
+type SetxattrOptions struct {
+ // Name is the name of the extended attribute being mutated.
+ Name string
+
+ // Value is the extended attribute's new value.
+ Value string
+
+ // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2).
+ Flags uint32
+}
+
+// StatOptions contains options to VirtualFilesystem.StatAt(),
+// FilesystemImpl.StatAt(), FileDescription.Stat(), and
+// FileDescriptionImpl.Stat().
+type StatOptions struct {
+ // Mask is the set of fields in the returned Statx that the FilesystemImpl
+ // or FileDescriptionImpl should provide. Bits are as in linux.Statx.Mask.
+ //
+ // The FilesystemImpl or FileDescriptionImpl may return fields not
+ // requested in Mask, and may fail to return fields requested in Mask that
+ // are not supported by the underlying filesystem implementation, without
+ // returning an error.
+ Mask uint32
+
+ // Sync specifies the synchronization required, and is one of
+ // linux.AT_STATX_SYNC_AS_STAT (which is 0, and therefore the default),
+ // linux.AT_STATX_SYNC_FORCE_SYNC, or linux.AT_STATX_SYNC_DONT_SYNC.
+ Sync uint32
+}
+
+// UmountOptions contains options to VirtualFilesystem.UmountAt().
+type UmountOptions struct {
+ // Flags contains flags as specified for umount2(2).
+ Flags uint32
+}
+
+// WriteOptions contains options to FileDescription.PWrite(),
+// FileDescriptionImpl.PWrite(), FileDescription.Write(), and
+// FileDescriptionImpl.Write().
+type WriteOptions struct {
+ // Flags contains flags as specified for pwritev2(2).
+ Flags uint32
+}
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
new file mode 100644
index 000000000..cd78d66bc
--- /dev/null
+++ b/pkg/sentry/vfs/pathname.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var fspathBuilderPool = sync.Pool{
+ New: func() interface{} {
+ return &fspath.Builder{}
+ },
+}
+
+func getFSPathBuilder() *fspath.Builder {
+ return fspathBuilderPool.Get().(*fspath.Builder)
+}
+
+func putFSPathBuilder(b *fspath.Builder) {
+ // No methods can be called on b after b.String(), so reset it to its zero
+ // value (as returned by fspathBuilderPool.New) instead.
+ *b = fspath.Builder{}
+ fspathBuilderPool.Put(b)
+}
+
+// PathnameWithDeleted returns an absolute pathname to vd, consistent with
+// Linux's d_path(). In particular, if vd.Dentry() has been disowned,
+// PathnameWithDeleted appends " (deleted)" to the returned pathname.
+func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+
+ origD := vd.dentry
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ // genericfstree.PrependPath() will have returned
+ // PrependPathAtVFSRootError in this case since it checks
+ // against vfsroot before mnt.root, but other implementations
+ // of FilesystemImpl.PrependPath() may return nil instead.
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ break loop
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ // continue loop
+ case PrependPathSyntheticError:
+ // Skip prepending "/" and appending " (deleted)".
+ return b.String(), nil
+ case PrependPathAtVFSRootError, PrependPathAtNonMountRootError:
+ break loop
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ if origD.IsDead() {
+ b.AppendString(" (deleted)")
+ }
+ return b.String(), nil
+}
+
+// PathnameReachable returns an absolute pathname to vd, consistent with
+// Linux's __d_path() (as used by seq_path_root()). If vfsroot.Ok() and vd is
+// not reachable from vfsroot, such that seq_path_root() would return SEQ_SKIP
+// (causing the entire containing entry to be skipped), PathnameReachable
+// returns ("", nil).
+func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ return "", nil
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ return "", nil
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ return b.String(), nil
+}
+
+// PathnameForGetcwd returns an absolute pathname to vd, consistent with
+// Linux's sys_getcwd().
+func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ if vd.dentry.IsDead() {
+ return "", syserror.ENOENT
+ }
+
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+ unreachable := false
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ unreachable = true
+ break loop
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ unreachable = true
+ break loop
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ if unreachable {
+ b.PrependString("(unreachable)")
+ }
+ return b.String(), nil
+}
+
+// As of this writing, we do not have equivalents to:
+//
+// - d_absolute_path(), which returns EINVAL if (effectively) any call to
+// FilesystemImpl.PrependPath() would return PrependPathAtNonMountRootError.
+//
+// - dentry_path(), which does not walk up mounts (and only returns the path
+// relative to Filesystem root), but also appends "//deleted" for disowned
+// Dentries.
+//
+// These should be added as necessary.
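
A sketch of typical use, e.g. when rendering a /proc/[pid]/maps-style name
(hypothetical; vfsroot is assumed to come from the calling task's filesystem
context, and fmt is imported):

func describe(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vfsroot, vd vfs.VirtualDentry) string {
	name, err := vfsObj.PathnameWithDeleted(ctx, vfsroot, vd)
	if err != nil {
		return fmt.Sprintf("<error: %v>", err)
	}
	return name // e.g. "/tmp/foo", or "/tmp/foo (deleted)" after unlink
}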
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
new file mode 100644
index 000000000..9cb050597
--- /dev/null
+++ b/pkg/sentry/vfs/permissions.go
@@ -0,0 +1,280 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// AccessTypes is a bitmask of Unix file permissions.
+type AccessTypes uint16
+
+// Bits in AccessTypes.
+const (
+ MayExec AccessTypes = 1
+ MayWrite AccessTypes = 2
+ MayRead AccessTypes = 4
+)
+
+// OnlyRead returns true if access _only_ allows read.
+func (a AccessTypes) OnlyRead() bool {
+ return a == MayRead
+}
+
+// MayRead returns true if access allows read.
+func (a AccessTypes) MayRead() bool {
+ return a&MayRead != 0
+}
+
+// MayWrite returns true if access allows write.
+func (a AccessTypes) MayWrite() bool {
+ return a&MayWrite != 0
+}
+
+// MayExec returns true if access allows exec.
+func (a AccessTypes) MayExec() bool {
+ return a&MayExec != 0
+}
+
+// GenericCheckPermissions checks that creds has the given access rights on a
+// file with the given permissions, UID, and GID, subject to the rules of
+// fs/namei.c:generic_permission().
+func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ // Check permission bits.
+ perms := uint16(mode.Permissions())
+ if creds.EffectiveKUID == kuid {
+ perms >>= 6
+ } else if creds.InGroup(kgid) {
+ perms >>= 3
+ }
+ if uint16(ats)&perms == uint16(ats) {
+ // All permission bits match, access granted.
+ return nil
+ }
+
+ // Caller capabilities require that the file's KUID and KGID are mapped in
+ // the caller's user namespace; compare
+ // kernel/capability.c:privileged_wrt_inode_uidgid().
+ if !kuid.In(creds.UserNamespace).Ok() || !kgid.In(creds.UserNamespace).Ok() {
+ return syserror.EACCES
+ }
+ // CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
+ // directories, and read arbitrary non-directory files.
+ if (mode.IsDir() && !ats.MayWrite()) || ats.OnlyRead() {
+ if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return nil
+ }
+ }
+ // CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
+ // access to non-directory files, and execute access to non-directory files
+ // for which at least one execute bit is set.
+ if mode.IsDir() || !ats.MayExec() || (mode.Permissions()&0111 != 0) {
+ if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
+ return nil
+ }
+ }
+ return syserror.EACCES
+}
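
The shift above is the classic owner/group/other digit selection; a
self-contained sketch of just that step:

// permBitsFor picks the rwx digit that applies to the caller: owner bits
// live at <<6, group bits at <<3, other bits at <<0.
func permBitsFor(mode uint16, isOwner, inGroup bool) uint16 {
	perms := mode & 0777
	if isOwner {
		perms >>= 6
	} else if inGroup {
		perms >>= 3
	}
	return perms & 7
}

// permBitsFor(0640, false, true) == 4: a group member may read but not
// write, so a MayWrite request falls through to the capability checks.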
+
+// MayLink determines whether creating a hard link to a file with the given
+// mode, kuid, and kgid is permitted.
+//
+// This corresponds to Linux's fs/namei.c:may_linkat.
+func MayLink(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ // Source inode owner can hardlink all they like; otherwise, it must be a
+ // safe source.
+ if CanActAsOwner(creds, kuid) {
+ return nil
+ }
+
+ // Only regular files can be hard linked.
+ if mode.FileType() != linux.S_IFREG {
+ return syserror.EPERM
+ }
+
+ // Setuid files should not get pinned to the filesystem.
+ if mode&linux.S_ISUID != 0 {
+ return syserror.EPERM
+ }
+
+ // Executable setgid files should not get pinned to the filesystem, but we
+ // don't support S_IXGRP anyway.
+
+ // Hardlinking to unreadable or unwritable sources is dangerous.
+ if err := GenericCheckPermissions(creds, MayRead|MayWrite, mode, kuid, kgid); err != nil {
+ return syserror.EPERM
+ }
+ return nil
+}
+
+// AccessTypesForOpenFlags returns the access types required to open a file
+// with the given OpenOptions.Flags. Note that this is NOT the same thing as
+// the set of accesses permitted for the opened file:
+//
+// - O_TRUNC causes MayWrite to be set in the returned AccessTypes (since it
+// mutates the file), but does not permit writing to the open file description
+// thereafter.
+//
+// - "Linux reserves the special, nonstandard access mode 3 (binary 11) in
+// flags to mean: check for read and write permission on the file and return a
+// file descriptor that can't be used for reading or writing." - open(2). Thus
+// AccessTypesForOpenFlags returns MayRead|MayWrite in this case.
+//
+// Use May{Read,Write}FileWithOpenFlags() for these checks instead.
+func AccessTypesForOpenFlags(opts *OpenOptions) AccessTypes {
+ ats := AccessTypes(0)
+ if opts.FileExec {
+ ats |= MayExec
+ }
+
+ switch opts.Flags & linux.O_ACCMODE {
+ case linux.O_RDONLY:
+ if opts.Flags&linux.O_TRUNC != 0 {
+ return ats | MayRead | MayWrite
+ }
+ return ats | MayRead
+ case linux.O_WRONLY:
+ return ats | MayWrite
+ default:
+ return ats | MayRead | MayWrite
+ }
+}
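
Both special cases in the doc comment are observable directly (sketch;
assumes the vfs and linux packages from this diff, plus fmt):

func demoAccessTypes() {
	// O_RDONLY|O_TRUNC: truncation mutates the file, so write permission
	// is checked even though the resulting FD is read-only.
	ats := vfs.AccessTypesForOpenFlags(&vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_TRUNC})
	fmt.Println(ats.MayRead(), ats.MayWrite()) // true true

	// Access mode 3: checked as read+write, usable as neither.
	ats = vfs.AccessTypesForOpenFlags(&vfs.OpenOptions{Flags: 3})
	fmt.Println(ats.MayRead(), ats.MayWrite()) // true true
}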
+
+// MayReadFileWithOpenFlags returns true if a file with the given open flags
+// should be readable.
+func MayReadFileWithOpenFlags(flags uint32) bool {
+ switch flags & linux.O_ACCMODE {
+ case linux.O_RDONLY, linux.O_RDWR:
+ return true
+ default:
+ return false
+ }
+}
+
+// MayWriteFileWithOpenFlags returns true if a file with the given open flags
+// should be writable.
+func MayWriteFileWithOpenFlags(flags uint32) bool {
+ switch flags & linux.O_ACCMODE {
+ case linux.O_WRONLY, linux.O_RDWR:
+ return true
+ default:
+ return false
+ }
+}
+
+// CheckSetStat checks that creds has permission to change the metadata of a
+// file with the given permissions, UID, and GID as specified by stat, subject
+// to the rules of Linux's fs/attr.c:setattr_prepare().
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ limit, err := CheckLimit(ctx, 0, int64(stat.Size))
+ if err != nil {
+ return err
+ }
+ if limit < int64(stat.Size) {
+ return syserror.ErrExceedsFileSizeLimit
+ }
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ return syserror.EPERM
+ }
+ // TODO(b/30815691): "If the calling process is not privileged (Linux:
+ // does not have the CAP_FSETID capability), and the group of the file
+ // does not match the effective group ID of the process or one of its
+ // supplementary group IDs, the S_ISGID bit will be turned off, but
+ // this will not cause an error to be returned." - chmod(2)
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
+ return syserror.EPERM
+ }
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// CheckDeleteSticky checks whether the sticky bit is set on a directory with
+// the given file mode, and if so, checks whether creds has permission to
+// remove a file owned by childKUID from a directory with the given mode.
+// CheckDeleteSticky is consistent with Linux's include/linux/fs.h:check_sticky().
+func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error {
+ if parentMode&linux.ModeSticky == 0 {
+ return nil
+ }
+ if CanActAsOwner(creds, childKUID) {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// CanActAsOwner returns true if creds can act as the owner of a file with the
+// given owning UID, consistent with Linux's
+// fs/inode.c:inode_owner_or_capable().
+func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool {
+ if creds.EffectiveKUID == kuid {
+ return true
+ }
+ return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok()
+}
+
+// HasCapabilityOnFile returns true if creds has the given capability with
+// respect to a file with the given owning UID and GID, consistent with Linux's
+// kernel/capability.c:capable_wrt_inode_uidgid().
+func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool {
+ return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok()
+}
+
+// CheckLimit enforces file size rlimits. It returns an error if the write
+// operation must not proceed. Otherwise it returns the maximum length that
+// may be written without violating the limit.
+func CheckLimit(ctx context.Context, offset, size int64) (int64, error) {
+ fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur
+ if fileSizeLimit > math.MaxInt64 {
+ return size, nil
+ }
+ if offset >= int64(fileSizeLimit) {
+ return 0, syserror.ErrExceedsFileSizeLimit
+ }
+ remaining := int64(fileSizeLimit) - offset
+ if remaining < size {
+ return remaining, nil
+ }
+ return size, nil
+}
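
Worked examples of the clamping behavior, assuming a hypothetical
RLIMIT_FSIZE of 100:

//	CheckLimit(ctx, 0, 64)   -> (64, nil)                     whole write fits
//	CheckLimit(ctx, 80, 64)  -> (20, nil)                     write is shortened
//	CheckLimit(ctx, 100, 64) -> (0, ErrExceedsFileSizeLimit)  nothing may be written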
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
new file mode 100644
index 000000000..9d047ff88
--- /dev/null
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -0,0 +1,466 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ResolvingPath represents the state of an in-progress path resolution, shared
+// between VFS and FilesystemImpl methods that take a path.
+//
+// From the perspective of FilesystemImpl methods, a ResolvingPath represents a
+// starting Dentry on the associated Filesystem (on which a reference is
+// already held), a stream of path components relative to that Dentry, and
+// elements of the invoking Context that are commonly required by
+// FilesystemImpl methods.
+//
+// ResolvingPath is loosely analogous to Linux's struct nameidata.
+type ResolvingPath struct {
+ vfs *VirtualFilesystem
+ root VirtualDentry // refs borrowed from PathOperation
+ mount *Mount
+ start *Dentry
+ pit fspath.Iterator
+
+ flags uint16
+ mustBeDir bool // final file must be a directory?
+ mustBeDirOrig bool
+ symlinks uint8 // number of symlinks traversed
+ symlinksOrig uint8
+ curPart uint8 // index into parts
+ numOrigParts uint8
+
+ creds *auth.Credentials
+
+ // Data associated with resolve*Errors, stored in ResolvingPath so that
+ // those errors don't need to allocate.
+ nextMount *Mount // ref held if not nil
+ nextStart *Dentry // ref held if not nil
+ absSymlinkTarget fspath.Path
+
+ // ResolvingPath must track up to two relative paths: the "current"
+ // relative path, which is updated whenever a relative symlink is
+ // encountered, and the "original" relative path, which is updated from the
+ // current relative path by handleError() when resolution must change
+ // filesystems (due to reaching a mount boundary or absolute symlink) and
+ // overwrites the current relative path when Restart() is called.
+ parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
+ origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
+}
+
+const (
+ rpflagsHaveMountRef = 1 << iota // do we hold a reference on mount?
+ rpflagsHaveStartRef // do we hold a reference on start?
+ rpflagsFollowFinalSymlink // same as PathOperation.FollowFinalSymlink
+)
+
+func init() {
+ if maxParts := len(ResolvingPath{}.parts); maxParts > 255 {
+ panic(fmt.Sprintf("uint8 is insufficient to accommodate len(ResolvingPath.parts) (%d)", maxParts))
+ }
+}
+
+// Error types that communicate state from the FilesystemImpl-caller,
+// VFS-callee side of path resolution (i.e. errors returned by
+// ResolvingPath.Resolve*()) to the VFS-caller, FilesystemImpl-callee side
+// (i.e. VFS methods => ResolvingPath.handleError()). These are empty structs
+// rather than error values because Go doesn't support non-primitive constants,
+// so error "constants" are really mutable vars, necessitating somewhat
+// expensive interface object comparisons.
+
+type resolveMountRootOrJumpError struct{}
+
+// Error implements error.Error.
+func (resolveMountRootOrJumpError) Error() string {
+ return "resolving mount root or jump"
+}
+
+type resolveMountPointError struct{}
+
+// Error implements error.Error.
+func (resolveMountPointError) Error() string {
+ return "resolving mount point"
+}
+
+type resolveAbsSymlinkError struct{}
+
+// Error implements error.Error.
+func (resolveAbsSymlinkError) Error() string {
+ return "resolving absolute symlink"
+}
+
+var resolvingPathPool = sync.Pool{
+ New: func() interface{} {
+ return &ResolvingPath{}
+ },
+}
+
+func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath {
+ rp := resolvingPathPool.Get().(*ResolvingPath)
+ rp.vfs = vfs
+ rp.root = pop.Root
+ rp.mount = pop.Start.mount
+ rp.start = pop.Start.dentry
+ rp.pit = pop.Path.Begin
+ rp.flags = 0
+ if pop.FollowFinalSymlink {
+ rp.flags |= rpflagsFollowFinalSymlink
+ }
+ rp.mustBeDir = pop.Path.Dir
+ rp.mustBeDirOrig = pop.Path.Dir
+ rp.symlinks = 0
+ rp.curPart = 0
+ rp.numOrigParts = 1
+ rp.creds = creds
+ rp.parts[0] = pop.Path.Begin
+ rp.origParts[0] = pop.Path.Begin
+ return rp
+}
+
+func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) {
+ rp.root = VirtualDentry{}
+ rp.decRefStartAndMount()
+ rp.mount = nil
+ rp.start = nil
+ rp.releaseErrorState()
+ resolvingPathPool.Put(rp)
+}
+
+func (rp *ResolvingPath) decRefStartAndMount() {
+ if rp.flags&rpflagsHaveStartRef != 0 {
+ rp.start.DecRef()
+ }
+ if rp.flags&rpflagsHaveMountRef != 0 {
+ rp.mount.DecRef()
+ }
+}
+
+func (rp *ResolvingPath) releaseErrorState() {
+ if rp.nextStart != nil {
+ rp.nextStart.DecRef()
+ rp.nextStart = nil
+ }
+ if rp.nextMount != nil {
+ rp.nextMount.DecRef()
+ rp.nextMount = nil
+ }
+}
+
+// VirtualFilesystem returns the containing VirtualFilesystem.
+func (rp *ResolvingPath) VirtualFilesystem() *VirtualFilesystem {
+ return rp.vfs
+}
+
+// Credentials returns the credentials of rp's provider.
+func (rp *ResolvingPath) Credentials() *auth.Credentials {
+ return rp.creds
+}
+
+// Mount returns the Mount on which path resolution is currently occurring. It
+// does not take a reference on the returned Mount.
+func (rp *ResolvingPath) Mount() *Mount {
+ return rp.mount
+}
+
+// Start returns the starting Dentry represented by rp. It does not take a
+// reference on the returned Dentry.
+func (rp *ResolvingPath) Start() *Dentry {
+ return rp.start
+}
+
+// Done returns true if there are no remaining path components in the stream
+// represented by rp.
+func (rp *ResolvingPath) Done() bool {
+ // We don't need to check for rp.curPart == 0 because rp.Advance() won't
+ // set rp.pit to a terminal iterator otherwise.
+ return !rp.pit.Ok()
+}
+
+// Final returns true if there is exactly one remaining path component in the
+// stream represented by rp.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) Final() bool {
+ return rp.curPart == 0 && !rp.pit.NextOk()
+}
+
+// Component returns the current path component in the stream represented by
+// rp.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) Component() string {
+ if checkInvariants {
+ if !rp.pit.Ok() {
+ panic("ResolvingPath.Component() called at end of relative path")
+ }
+ }
+ return rp.pit.String()
+}
+
+// Advance advances the stream of path components represented by rp.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) Advance() {
+ if checkInvariants {
+ if !rp.pit.Ok() {
+ panic("ResolvingPath.Advance() called at end of relative path")
+ }
+ }
+ next := rp.pit.Next()
+ if next.Ok() || rp.curPart == 0 { // have next component, or at end of path
+ rp.pit = next
+ } else { // at end of path segment, continue with next one
+ rp.curPart--
+ rp.pit = rp.parts[rp.curPart]
+ }
+}
+
+// Restart resets the stream of path components represented by rp to its state
+// on entry to the current FilesystemImpl method.
+func (rp *ResolvingPath) Restart() {
+ rp.pit = rp.origParts[rp.numOrigParts-1]
+ rp.mustBeDir = rp.mustBeDirOrig
+ rp.symlinks = rp.symlinksOrig
+ rp.curPart = rp.numOrigParts - 1
+ copy(rp.parts[:], rp.origParts[:rp.numOrigParts])
+ rp.releaseErrorState()
+}
+
+func (rp *ResolvingPath) relpathCommit() {
+ rp.mustBeDirOrig = rp.mustBeDir
+ rp.symlinksOrig = rp.symlinks
+ rp.numOrigParts = rp.curPart + 1
+ copy(rp.origParts[:rp.curPart], rp.parts[:])
+ rp.origParts[rp.curPart] = rp.pit
+}
+
+// CheckRoot is called before resolving the parent of the Dentry d. If the
+// Dentry is contextually a VFS root, such that path resolution should treat
+// d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the
+// root of a non-root mount, such that path resolution should switch to another
+// Mount, CheckRoot returns (unspecified, non-nil error). Otherwise, path
+// resolution should resolve d's parent normally, and CheckRoot returns (false,
+// nil).
+func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) {
+ if d == rp.root.dentry && rp.mount == rp.root.mount {
+ // At contextual VFS root (due to e.g. chroot(2)).
+ return true, nil
+ } else if d == rp.mount.root {
+ // At mount root ...
+ vd := rp.vfs.getMountpointAt(rp.mount, rp.root)
+ if vd.Ok() {
+ // ... of non-root mount.
+ rp.nextMount = vd.mount
+ rp.nextStart = vd.dentry
+ return false, resolveMountRootOrJumpError{}
+ }
+ // ... of root mount.
+ return true, nil
+ }
+ return false, nil
+}
+
+// CheckMount is called after resolving the parent or child of another Dentry
+// to d. If d is a mount point, such that path resolution should switch to
+// another Mount, CheckMount returns a non-nil error. Otherwise, CheckMount
+// returns nil.
+func (rp *ResolvingPath) CheckMount(d *Dentry) error {
+ if !d.isMounted() {
+ return nil
+ }
+ if mnt := rp.vfs.getMountAt(rp.mount, d); mnt != nil {
+ rp.nextMount = mnt
+ return resolveMountPointError{}
+ }
+ return nil
+}
+
+// ShouldFollowSymlink returns true if, supposing that the current path
+// component in rp represents a symbolic link, the symbolic link should be
+// followed.
+//
+// If path is terminated with '/', the '/' is considered the last element and
+// any symlink before that is followed:
+// - For most non-creating walks, the last path component is handled by
+// fs/namei.c:lookup_last(), which sets LOOKUP_FOLLOW if the first byte
+// after the path component is non-NULL (which is only possible if it's '/')
+// and the path component is of type LAST_NORM.
+//
+// - For open/openat/openat2 without O_CREAT, the last path component is
+// handled by fs/namei.c:do_last(), which does the same, though without the
+// LAST_NORM check.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) ShouldFollowSymlink() bool {
+ // Non-final symlinks are always followed. Paths terminated with '/' are also
+ // always followed.
+ return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final() || rp.MustBeDir()
+}
+
+// HandleSymlink is called when the current path component is a symbolic link
+// to the given target. If the calling Filesystem method should continue path
+// traversal, HandleSymlink updates the path component stream to reflect the
+// symlink target and returns nil. Otherwise it returns a non-nil error.
+//
+// Preconditions: !rp.Done().
+//
+// Postconditions: If HandleSymlink returns a nil error, then !rp.Done().
+func (rp *ResolvingPath) HandleSymlink(target string) error {
+ if rp.symlinks >= linux.MaxSymlinkTraversals {
+ return syserror.ELOOP
+ }
+ if len(target) == 0 {
+ return syserror.ENOENT
+ }
+ rp.symlinks++
+ targetPath := fspath.Parse(target)
+ if targetPath.Absolute {
+ rp.absSymlinkTarget = targetPath
+ return resolveAbsSymlinkError{}
+ }
+ // Consume the path component that represented the symlink.
+ rp.Advance()
+ // Prepend the symlink target to the relative path.
+ if checkInvariants {
+ if !targetPath.HasComponents() {
+ panic(fmt.Sprintf("non-empty pathname %q parsed to relative path with no components", target))
+ }
+ }
+ rp.relpathPrepend(targetPath)
+ return nil
+}
+
+// Preconditions: path.HasComponents().
+func (rp *ResolvingPath) relpathPrepend(path fspath.Path) {
+ if rp.pit.Ok() {
+ rp.parts[rp.curPart] = rp.pit
+ rp.pit = path.Begin
+ rp.curPart++
+ } else {
+ // The symlink was the final path component, so now the symlink target
+ // is the whole path.
+ rp.pit = path.Begin
+ // Symlink targets can set rp.mustBeDir (if they end in a trailing /),
+ // but can't unset it.
+ if path.Dir {
+ rp.mustBeDir = true
+ }
+ }
+}
+
+// HandleJump is called when the current path component is a "magic" link to
+// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem
+// method should continue path traversal, HandleJump updates the path
+// component stream to reflect the magic link target and returns nil. Otherwise
+// it returns a non-nil error.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) HandleJump(target VirtualDentry) error {
+ if rp.symlinks >= linux.MaxSymlinkTraversals {
+ return syserror.ELOOP
+ }
+ rp.symlinks++
+ // Consume the path component that represented the magic link.
+ rp.Advance()
+ // Unconditionally return a resolveMountRootOrJumpError, even if the Mount
+ // isn't changing, to force restarting at the new Dentry.
+ target.IncRef()
+ rp.nextMount = target.mount
+ rp.nextStart = target.dentry
+ return resolveMountRootOrJumpError{}
+}
+
+func (rp *ResolvingPath) handleError(err error) bool {
+ switch err.(type) {
+ case resolveMountRootOrJumpError:
+ // Switch to the new Mount. We hold references on the Mount and Dentry.
+ rp.decRefStartAndMount()
+ rp.mount = rp.nextMount
+ rp.start = rp.nextStart
+ rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef
+ rp.nextMount = nil
+ rp.nextStart = nil
+ // Commit the previous FilesystemImpl's progress through the relative
+ // path. (Don't consume the path component that caused us to traverse
+ // through the mount root - i.e. the ".." - because we still need to
+ // resolve the mount point's parent in the new FilesystemImpl.)
+ rp.relpathCommit()
+ // Restart path resolution on the new Mount. Don't bother calling
+ // rp.releaseErrorState() since we already set nextMount and nextStart
+ // to nil above.
+ return true
+
+ case resolveMountPointError:
+ // Switch to the new Mount. We hold a reference on the Mount, but
+ // borrow the reference on the mount root from the Mount.
+ rp.decRefStartAndMount()
+ rp.mount = rp.nextMount
+ rp.start = rp.nextMount.root
+ rp.flags = rp.flags&^rpflagsHaveStartRef | rpflagsHaveMountRef
+ rp.nextMount = nil
+ // Consume the path component that represented the mount point.
+ rp.Advance()
+ // Commit the previous FilesystemImpl's progress through the relative
+ // path.
+ rp.relpathCommit()
+ // Restart path resolution on the new Mount.
+ rp.releaseErrorState()
+ return true
+
+ case resolveAbsSymlinkError:
+ // Switch to the new Mount. References are borrowed from rp.root.
+ rp.decRefStartAndMount()
+ rp.mount = rp.root.mount
+ rp.start = rp.root.dentry
+ rp.flags &^= rpflagsHaveMountRef | rpflagsHaveStartRef
+ // Consume the path component that represented the symlink.
+ rp.Advance()
+ // Prepend the symlink target to the relative path.
+ rp.relpathPrepend(rp.absSymlinkTarget)
+ // Commit the previous FilesystemImpl's progress through the relative
+ // path, including the symlink target we just prepended.
+ rp.relpathCommit()
+ // Restart path resolution on the new Mount.
+ rp.releaseErrorState()
+ return true
+
+ default:
+ // Not an error we can handle.
+ return false
+ }
+}
+
+// canHandleError returns true if err is an error returned by rp.Resolve*()
+// that rp.handleError() may attempt to handle.
+func (rp *ResolvingPath) canHandleError(err error) bool {
+ switch err.(type) {
+ case resolveMountRootOrJumpError, resolveMountPointError, resolveAbsSymlinkError:
+ return true
+ default:
+ return false
+ }
+}
+
+// MustBeDir returns true if the file traversed by rp must be a directory.
+func (rp *ResolvingPath) MustBeDir() bool {
+ return rp.mustBeDir
+}
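
On the FilesystemImpl side, the intended consumption pattern is a
component-at-a-time walk; a hypothetical sketch (lookup and dentry
bookkeeping are filesystem-specific and elided as comments):

func walk(ctx context.Context, rp *vfs.ResolvingPath) error {
	for !rp.Done() {
		name := rp.Component()
		// target, isSymlink := lookup(name) // filesystem-specific
		// if isSymlink && rp.ShouldFollowSymlink() {
		//	if err := rp.HandleSymlink(target); err != nil {
		//		return err // possibly a resolve* error for VFS to handle
		//	}
		//	continue // HandleSymlink already advanced past the link
		// }
		_ = name
		rp.Advance()
	}
	return nil
}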
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
new file mode 100644
index 000000000..522e27475
--- /dev/null
+++ b/pkg/sentry/vfs/vfs.go
@@ -0,0 +1,849 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package vfs implements a virtual filesystem layer.
+//
+// Lock order:
+//
+// EpollInstance.interestMu
+// FileDescription.epollMu
+// FilesystemImpl/FileDescriptionImpl locks
+// VirtualFilesystem.mountMu
+// Dentry.mu
+// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry
+// VirtualFilesystem.filesystemsMu
+// EpollInstance.mu
+// Inotify.mu
+// Watches.mu
+// Inotify.evMu
+// VirtualFilesystem.fsTypesMu
+//
+// Locking Dentry.mu in multiple Dentries requires holding
+// VirtualFilesystem.mountMu. Locking EpollInstance.interestMu in multiple
+// EpollInstances requires holding epollCycleMu.
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts.
+//
+// There is no analogue to the VirtualFilesystem type in Linux, as the
+// equivalent state in Linux is global.
+//
+// +stateify savable
+type VirtualFilesystem struct {
+ // mountMu serializes mount mutations.
+ //
+ // mountMu is analogous to Linux's namespace_sem.
+ mountMu sync.Mutex `state:"nosave"`
+
+ // mounts maps (mount parent, mount point) pairs to mounts. (Since mounts
+ // are uniquely namespaced, including mount parent in the key correctly
+ // handles both bind mounts and mount namespaces; Linux does the same.)
+ // Synchronization between mutators and readers is provided by mounts.seq;
+ // synchronization between mutators is provided by mountMu.
+ //
+ // mounts is used to follow mount points during path traversal. We use a
+ // single table rather than per-Dentry tables to reduce size (and therefore
+ // cache footprint) for the vast majority of Dentries that are not mount
+ // points.
+ //
+ // mounts is analogous to Linux's mount_hashtable.
+ mounts mountTable
+
+ // mountpoints maps mount points to mounts at those points in all
+ // namespaces. mountpoints is protected by mountMu.
+ //
+ // mountpoints is used to find mounts that must be umounted due to
+ // removal of a mount point Dentry from another mount namespace. ("A file
+ // or directory that is a mount point in one namespace that is not a mount
+ // point in another namespace, may be renamed, unlinked, or removed
+ // (rmdir(2)) in the mount namespace in which it is not a mount point
+ // (subject to the usual permission checks)." - mount_namespaces(7))
+ //
+ // mountpoints is analogous to Linux's mountpoint_hashtable.
+ mountpoints map[*Dentry]map[*Mount]struct{}
+
+ // lastMountID is the last allocated mount ID. lastMountID is accessed
+ // using atomic memory operations.
+ lastMountID uint64
+
+ // anonMount is a Mount, not included in mounts or mountpoints,
+ // representing an anonFilesystem. anonMount is used to back
+ // VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
+ // anonMount is immutable.
+ //
+ // anonMount is analogous to Linux's anon_inode_mnt.
+ anonMount *Mount
+
+ // devices contains all registered Devices. devices is protected by
+ // devicesMu.
+ devicesMu sync.RWMutex `state:"nosave"`
+ devices map[devTuple]*registeredDevice
+
+ // anonBlockDevMinor contains all allocated anonymous block device minor
+ // numbers. anonBlockDevMinorNext is a lower bound for the smallest
+ // unallocated anonymous block device number. anonBlockDevMinorNext and
+ // anonBlockDevMinor are protected by anonBlockDevMinorMu.
+ anonBlockDevMinorMu sync.Mutex `state:"nosave"`
+ anonBlockDevMinorNext uint32
+ anonBlockDevMinor map[uint32]struct{}
+
+ // fsTypes contains all registered FilesystemTypes. fsTypes is protected by
+ // fsTypesMu.
+ fsTypesMu sync.RWMutex `state:"nosave"`
+ fsTypes map[string]*registeredFilesystemType
+
+ // filesystems contains all Filesystems. filesystems is protected by
+ // filesystemsMu.
+ filesystemsMu sync.Mutex `state:"nosave"`
+ filesystems map[*Filesystem]struct{}
+}
+
+// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
+func (vfs *VirtualFilesystem) Init() error {
+ if vfs.mountpoints != nil {
+ panic("VFS already initialized")
+ }
+ vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{})
+ vfs.devices = make(map[devTuple]*registeredDevice)
+ vfs.anonBlockDevMinorNext = 1
+ vfs.anonBlockDevMinor = make(map[uint32]struct{})
+ vfs.fsTypes = make(map[string]*registeredFilesystemType)
+ vfs.filesystems = make(map[*Filesystem]struct{})
+ vfs.mounts.Init()
+
+ // Construct vfs.anonMount.
+ anonfsDevMinor, err := vfs.GetAnonBlockDevMinor()
+ if err != nil {
+ // This shouldn't be possible since anonBlockDevMinorNext was
+ // initialized to 1 above (no device numbers have been allocated yet).
+ panic(fmt.Sprintf("VirtualFilesystem.Init: device number allocation for anonfs failed: %v", err))
+ }
+ anonfs := anonFilesystem{
+ devMinor: anonfsDevMinor,
+ }
+ anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs)
+ defer anonfs.vfsfs.DecRef()
+ anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
+ if err != nil {
+ // We should not be passing any MountOptions that would cause
+ // construction of this mount to fail.
+ panic(fmt.Sprintf("VirtualFilesystem.Init: anonfs mount failed: %v", err))
+ }
+ vfs.anonMount = anonMount
+
+ return nil
+}
+
+// PathOperation specifies the path operated on by a VFS method.
+//
+// PathOperation is passed to VFS methods by pointer to reduce memory copying:
+// it's somewhat large and should never escape. (Options structs are passed by
+// pointer to VFS and FileDescription methods for the same reason.)
+type PathOperation struct {
+ // Root is the VFS root. References on Root are borrowed from the provider
+ // of the PathOperation.
+ //
+ // Invariants: Root.Ok().
+ Root VirtualDentry
+
+ // Start is the starting point for the path traversal. References on Start
+ // are borrowed from the provider of the PathOperation (i.e. the caller of
+ // the VFS method to which the PathOperation was passed).
+ //
+ // Invariants: Start.Ok(). If Path.Absolute, then Start == Root.
+ Start VirtualDentry
+
+ // Path is the pathname traversed by this operation.
+ Path fspath.Path
+
+ // If FollowFinalSymlink is true, and the Dentry traversed by the final
+ // path component represents a symbolic link, the symbolic link should be
+ // followed.
+ FollowFinalSymlink bool
+}
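
A sketch of a fully-populated PathOperation (hypothetical; root would come
from the calling task, and fspath.Parse is the same parser used elsewhere in
this diff):

func passwdPathOp(root vfs.VirtualDentry) vfs.PathOperation {
	return vfs.PathOperation{
		Root:               root,
		Start:              root, // the path is absolute, so Start == Root
		Path:               fspath.Parse("/etc/passwd"),
		FollowFinalSymlink: true, // stat-like semantics; false for lstat-like
	}
}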
+
+// AccessAt checks whether a user with creds has access to the file at
+// the given path.
+func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credentials, ats AccessTypes, pop *PathOperation) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
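
AccessAt is the simplest instance of the retry loop shared by every *At
wrapper below: call the current FilesystemImpl, and if it reports a mount
boundary or symlink via a resolve* error, let handleError() switch rp to the
new Mount and try again. Condensed (pseudocode, not new API; SomeOpAt stands
in for any FilesystemImpl method):

//	rp := vfs.getResolvingPath(creds, pop)
//	for {
//		result, err := rp.mount.fs.impl.SomeOpAt(ctx, rp, ...)
//		if err == nil || !rp.handleError(err) {
//			vfs.putResolvingPath(rp)
//			return result, err
//		}
//		// handleError() moved rp onto another Mount; retry there.
//	}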
+
+// GetDentryAt returns a VirtualDentry representing the given path, at which a
+// file must exist. A reference is taken on the returned VirtualDentry.
+func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
+ if err == nil {
+ vd := VirtualDentry{
+ mount: rp.mount,
+ dentry: d,
+ }
+ rp.mount.IncRef()
+ vfs.putResolvingPath(rp)
+ return vd, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return VirtualDentry{}, err
+ }
+ }
+}
+
+// Preconditions: pop.Path.Begin.Ok().
+func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (VirtualDentry, string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ parent, err := rp.mount.fs.impl.GetParentDentryAt(ctx, rp)
+ if err == nil {
+ parentVD := VirtualDentry{
+ mount: rp.mount,
+ dentry: parent,
+ }
+ rp.mount.IncRef()
+ name := rp.Component()
+ vfs.putResolvingPath(rp)
+ return parentVD, name, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return VirtualDentry{}, "", err
+ }
+ }
+}
+
+// LinkAt creates a hard link at newpop representing the existing file at
+// oldpop.
+func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error {
+ oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+
+ if !newpop.Path.Begin.Ok() {
+ oldVD.DecRef()
+ if newpop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if newpop.FollowFinalSymlink {
+ oldVD.DecRef()
+ ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, newpop)
+ for {
+ err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ oldVD.DecRef()
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ oldVD.DecRef()
+ return err
+ }
+ }
+}
+
+// MkdirAt creates a directory at the given path.
+func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+ // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
+ // also honored." - mkdir(2)
+ opts.Mode &= 0777 | linux.S_ISVTX
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// MknodAt creates a file of the given mode at the given path. It returns an
+// error from the syserror package.
+func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// OpenAt returns a FileDescription providing access to the file at the given
+// path. A reference is taken on the returned FileDescription.
+func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+ // Remove:
+ //
+ // - O_CLOEXEC, which affects file descriptors and therefore must be
+ // handled outside of VFS.
+ //
+ // - Unknown flags.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_LARGEFILE | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
+ // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
+ if opts.Flags&linux.O_SYNC != 0 {
+ opts.Flags |= linux.O_DSYNC
+ }
+ // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
+ // with O_DIRECTORY and a writable access mode (to ensure that it fails on
+ // filesystem implementations that do not support it).
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ if opts.Flags&linux.O_DIRECTORY == 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
+ return nil, syserror.EINVAL
+ }
+ }
+ // O_PATH causes most other flags to be ignored.
+ if opts.Flags&linux.O_PATH != 0 {
+ opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
+ }
+ // "On Linux, the following bits are also honored in mode: [S_ISUID,
+ // S_ISGID, S_ISVTX]" - open(2)
+ opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+ if opts.Flags&linux.O_NOFOLLOW != 0 {
+ pop.FollowFinalSymlink = false
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ if opts.Flags&linux.O_DIRECTORY != 0 {
+ rp.mustBeDir = true
+ rp.mustBeDirOrig = true
+ }
+ for {
+ fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+
+ if opts.FileExec {
+ if fd.Mount().Flags.NoExec {
+ fd.DecRef()
+ return nil, syserror.EACCES
+ }
+
+ // Only a regular file can be executed.
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
+ if err != nil {
+ fd.DecRef()
+ return nil, err
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef()
+ return nil, syserror.EACCES
+ }
+ }
+
+ fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent)
+ return fd, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
+// ReadlinkAt returns the target of the symbolic link at the given path.
+func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return target, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return "", err
+ }
+ }
+}
+
+// RenameAt renames the file at oldpop to newpop.
+func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error {
+ if !oldpop.Path.Begin.Ok() {
+ if oldpop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if oldpop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ oldParentVD, oldName, err := vfs.getParentDirAndName(ctx, creds, oldpop)
+ if err != nil {
+ return err
+ }
+ if oldName == "." || oldName == ".." {
+ oldParentVD.DecRef()
+ return syserror.EBUSY
+ }
+
+ if !newpop.Path.Begin.Ok() {
+ oldParentVD.DecRef()
+ if newpop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if newpop.FollowFinalSymlink {
+ oldParentVD.DecRef()
+ ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, newpop)
+ renameOpts := *opts
+ if oldpop.Path.Dir {
+ renameOpts.MustBeDir = true
+ }
+ for {
+ err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ oldParentVD.DecRef()
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ oldParentVD.DecRef()
+ return err
+ }
+ }
+}
+
+// RmdirAt removes the directory at the given path.
+func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.RmdirAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// SetStatAt changes metadata for the file at the given path.
+func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// StatAt returns metadata for the file at the given path.
+func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return stat, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return linux.Statx{}, err
+ }
+ }
+}
+
+// StatFSAt returns metadata for the filesystem containing the file at the
+// given path.
+func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return statfs, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return linux.Statfs{}, err
+ }
+ }
+}
+
+// SymlinkAt creates a symbolic link at the given path with the given target.
+func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// UnlinkAt deletes the non-directory file at the given path.
+func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
+func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return nil, syserror.ECONNREFUSED
+ }
+ return nil, syserror.ENOENT
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return bep, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
+// ListxattrAt returns all extended attribute names for the file at the given
+// path.
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return names, nil
+ }
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by
+ // default don't exist.
+ vfs.putResolvingPath(rp)
+ return nil, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
+// GetxattrAt returns the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return val, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return "", err
+ }
+ }
+}
+
+// SetxattrAt changes the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// RemovexattrAt removes the given extended attribute from the file at rp.
+func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// SyncAllFilesystems has the semantics of Linux's sync(2).
+func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ fss := make(map[*Filesystem]struct{})
+ vfs.filesystemsMu.Lock()
+ for fs := range vfs.filesystems {
+ if !fs.TryIncRef() {
+ continue
+ }
+ fss[fs] = struct{}{}
+ }
+ vfs.filesystemsMu.Unlock()
+ var retErr error
+ for fs := range fss {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef()
+ }
+ return retErr
+}
+
+// A VirtualDentry represents a node in a VFS tree, by combining a Dentry
+// (which represents a node in a Filesystem's tree) and a Mount (which
+// represents the Filesystem's position in a VFS mount tree).
+//
+// VirtualDentry's semantics are similar to that of a Go interface object
+// representing a pointer: it is a copyable value type that represents
+// references to another entity. The zero value of VirtualDentry is an "empty
+// VirtualDentry", directly analogous to a nil interface object.
+// VirtualDentry.Ok() checks that a VirtualDentry is not zero-valued; unless
+// otherwise specified, all other VirtualDentry methods require
+// VirtualDentry.Ok() == true.
+//
+// Mounts and Dentries are reference-counted, requiring that users call
+// VirtualDentry.{Inc,Dec}Ref() as appropriate. We often colloquially refer to
+// references on the Mount and Dentry referred to by a VirtualDentry as
+// references on the VirtualDentry itself. Unless otherwise specified, all
+// VirtualDentry methods require that a reference is held on the VirtualDentry.
+//
+// VirtualDentry is analogous to Linux's struct path.
+//
+// +stateify savable
+type VirtualDentry struct {
+ mount *Mount
+ dentry *Dentry
+}
+
+// MakeVirtualDentry creates a VirtualDentry.
+func MakeVirtualDentry(mount *Mount, dentry *Dentry) VirtualDentry {
+ return VirtualDentry{
+ mount: mount,
+ dentry: dentry,
+ }
+}
+
+// Ok returns true if vd is not empty. It does not require that a reference is
+// held.
+func (vd VirtualDentry) Ok() bool {
+ return vd.mount != nil
+}
+
+// IncRef increments the reference counts on the Mount and Dentry represented
+// by vd.
+func (vd VirtualDentry) IncRef() {
+ vd.mount.IncRef()
+ vd.dentry.IncRef()
+}
+
+// DecRef decrements the reference counts on the Mount and Dentry represented
+// by vd.
+func (vd VirtualDentry) DecRef() {
+ vd.dentry.DecRef()
+ vd.mount.DecRef()
+}
+
+// Mount returns the Mount associated with vd. It does not take a reference on
+// the returned Mount.
+func (vd VirtualDentry) Mount() *Mount {
+ return vd.mount
+}
+
+// Dentry returns the Dentry associated with vd. It does not take a reference
+// on the returned Dentry.
+func (vd VirtualDentry) Dentry() *Dentry {
+ return vd.dentry
+}
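
A minimal sketch of the reference-counting discipline described above, assuming `mnt` and `d` are a Mount and Dentry on which the caller already holds references (MakeVirtualDentry itself takes no new ones):

```go
// Sketch: mnt (*Mount) and d (*Dentry) are assumed to carry references
// owned by the caller.
vd := vfs.MakeVirtualDentry(mnt, d)
if !vd.Ok() {
	return // empty VirtualDentry, analogous to a nil interface value
}
vd.IncRef() // take an extra reference before handing vd to a second holder
defer vd.DecRef()
m := vd.Mount()  // borrowed; no reference taken
dn := vd.Dentry() // likewise
_, _ = m, dn
```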
diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD
new file mode 100644
index 000000000..1c5a1c9b6
--- /dev/null
+++ b/pkg/sentry/watchdog/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "watchdog",
+ srcs = ["watchdog.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sync",
+ ],
+)
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
new file mode 100644
index 000000000..748273366
--- /dev/null
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -0,0 +1,374 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package watchdog is responsible for monitoring the sentry for tasks that may
+// potentially be stuck or looping indeterminately, causing hard-to-debug hangs
+// in the untrusted app.
+//
+// It works by periodically querying all tasks to check whether they are in user
+// mode (RunUser), kernel mode (RunSys), or blocked in the kernel (OffCPU). Tasks
+// that have been running in kernel mode for a long time in the same syscall
+// without blocking are considered stuck and are reported.
+//
+// When a stuck task is detected, the watchdog can take one of the following actions:
+// 1. LogWarning: Logs a warning message followed by a stack dump of all goroutines.
+// If a task continues to be stuck, the message will repeat every minute, unless
+// a new stuck task is detected.
+// 2. Panic: same as above, followed by panic()
+//
+package watchdog
+
+import (
+ "bytes"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Opts configures the watchdog.
+type Opts struct {
+ // TaskTimeout is the amount of time to allow a task to execute the
+ // same syscall without blocking before it's declared stuck.
+ TaskTimeout time.Duration
+
+ // TaskTimeoutAction indicates what action to take when a stuck task
+ // is detected.
+ TaskTimeoutAction Action
+
+ // StartupTimeout is the amount of time to allow between watchdog
+ // creation and calling watchdog.Start.
+ StartupTimeout time.Duration
+
+ // StartupTimeoutAction indicates what action to take when
+ // watchdog.Start is not called within the timeout.
+ StartupTimeoutAction Action
+}
+
+// DefaultOpts is a default set of options for the watchdog.
+var DefaultOpts = Opts{
+ // Task timeout.
+ TaskTimeout: 3 * time.Minute,
+ TaskTimeoutAction: LogWarning,
+
+ // Startup timeout.
+ StartupTimeout: 30 * time.Second,
+ StartupTimeoutAction: LogWarning,
+}
+
+// descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period
+// is discounted from the task's last update time. It's set high enough that small scheduling delays won't
+// trigger it.
+const descheduleThreshold = 1 * time.Second
+
+var (
+ stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout")
+ stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+)
+
+// Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck.
+var stackDumpSameTaskPeriod = time.Minute
+
+// Action defines what action to take when a stuck task is detected.
+type Action int
+
+const (
+ // LogWarning logs warning message followed by stack trace.
+ LogWarning Action = iota
+
+ // Panic will do the same logging as LogWarning and panic().
+ Panic
+)
+
+// String returns Action's string representation.
+func (a Action) String() string {
+ switch a {
+ case LogWarning:
+ return "LogWarning"
+ case Panic:
+ return "Panic"
+ default:
+ panic(fmt.Sprintf("Invalid action: %d", a))
+ }
+}
+
+// Watchdog is the main watchdog type. It controls a goroutine that
+// periodically analyzes all tasks and reports if any of them appear to be stuck.
+type Watchdog struct {
+ // Configuration options are embedded.
+ Opts
+
+ // period indicates how often to check all tasks. It's calculated based on
+ // opts.TaskTimeout.
+ period time.Duration
+
+ // k is where the tasks come from.
+ k *kernel.Kernel
+
+ // stop is used to notify the watchdog that it should stop.
+ stop chan struct{}
+
+ // done is used to notify when the watchdog has stopped.
+ done chan struct{}
+
+ // offenders map contains all tasks that are currently stuck.
+ offenders map[*kernel.Task]*offender
+
+ // lastStackDump tracks the last time a stack dump was generated to prevent
+ // spamming the log.
+ lastStackDump time.Time
+
+ // lastRun is set to the last time the watchdog executed a monitoring loop.
+ lastRun ktime.Time
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // running is true if the watchdog is running.
+ running bool
+
+ // startCalled is true if Start has ever been called. It remains true
+ // even if Stop is called.
+ startCalled bool
+}
+
+type offender struct {
+ lastUpdateTime ktime.Time
+}
+
+// New creates a new watchdog.
+func New(k *kernel.Kernel, opts Opts) *Watchdog {
+ // The divisor 4 is arbitrary; it just must not prolong detection much beyond 'TaskTimeout'.
+ period := opts.TaskTimeout / 4
+ w := &Watchdog{
+ Opts: opts,
+ k: k,
+ period: period,
+ offenders: make(map[*kernel.Task]*offender),
+ stop: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+
+ // Handle StartupTimeout if it exists.
+ if w.StartupTimeout > 0 {
+ log.Infof("Watchdog waiting %v for startup", w.StartupTimeout)
+ go w.waitForStart() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
+ }
+
+ return w
+}
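
As a usage illustration, here is a sketch of wiring the watchdog into sentry startup. It assumes an already-initialized `*kernel.Kernel` is available; the surrounding boot code is not part of this diff.

```go
// startWatchdog is a sketch; k is a hypothetical, already-constructed
// *kernel.Kernel provided by the surrounding boot code.
func startWatchdog(k *kernel.Kernel) *watchdog.Watchdog {
	opts := watchdog.DefaultOpts
	opts.TaskTimeout = time.Minute             // tighten stuck-task detection
	opts.StartupTimeoutAction = watchdog.Panic // fail hard if Start never runs
	w := watchdog.New(k, opts)
	w.Start()
	return w // callers should arrange for w.Stop() on shutdown
}
```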
+
+// Start starts the watchdog.
+func (w *Watchdog) Start() {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ w.startCalled = true
+
+ if w.running {
+ return
+ }
+
+ if w.TaskTimeout == 0 {
+ log.Infof("Watchdog task timeout disabled")
+ return
+ }
+ w.lastRun = w.k.MonotonicClock().Now()
+
+ log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction)
+ go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
+ w.running = true
+}
+
+// Stop requests the watchdog to stop and waits for it to do so.
+func (w *Watchdog) Stop() {
+ if w.TaskTimeout == 0 {
+ return
+ }
+
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if !w.running {
+ return
+ }
+ log.Infof("Stopping watchdog")
+ w.stop <- struct{}{}
+ <-w.done
+ w.running = false
+ log.Infof("Watchdog stopped")
+}
+
+// waitForStart waits for Start to be called and takes action if it does not
+// happen within the startup timeout.
+func (w *Watchdog) waitForStart() {
+ <-time.After(w.StartupTimeout)
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.startCalled {
+ // We are fine.
+ return
+ }
+
+ stuckStartup.Increment()
+
+ var buf bytes.Buffer
+ buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
+ w.doAction(w.StartupTimeoutAction, false, &buf)
+}
+
+// loop is the main watchdog routine. It only returns when 'Stop()' is called.
+func (w *Watchdog) loop() {
+ // Loop until someone stops it.
+ for {
+ select {
+ case <-w.stop:
+ w.done <- struct{}{}
+ return
+ case <-time.After(w.period):
+ w.runTurn()
+ }
+ }
+}
+
+// runTurn runs a single pass over all tasks and reports anything it finds.
+func (w *Watchdog) runTurn() {
+ // Someone needs to watch the watchdog. The call below can get stuck if there
+ // is a deadlock affecting root's PID namespace mutex. Run it in a goroutine
+ // and report if it takes too long to return.
+ var tasks []*kernel.Task
+ done := make(chan struct{})
+ go func() { // S/R-SAFE: watchdog is stopped and restarted during S/R.
+ tasks = w.k.TaskSet().Root.Tasks()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(w.TaskTimeout):
+ // Report if the watchdog is not making progress.
+ // No one is watching the watchdog watcher though.
+ w.reportStuckWatchdog()
+ <-done
+ }
+
+ newOffenders := make(map[*kernel.Task]*offender)
+ newTaskFound := false
+ now := ktime.FromNanoseconds(int64(w.k.CPUClockNow() * uint64(linux.ClockTick)))
+
+ // The process may be running with a low CPU limit, making tasks appear stuck because
+ // they are starved of CPU cycles. An estimate is that tasks could have been starved
+ // since the last time the watchdog ran. If the watchdog detects that scheduling
+ // is off, it will discount the entire duration since the last run from 'lastUpdateTime'.
+ discount := time.Duration(0)
+ if now.Sub(w.lastRun.Add(w.period)) > descheduleThreshold {
+ discount = now.Sub(w.lastRun)
+ }
+ w.lastRun = now
+
+ log.Infof("Watchdog starting loop, tasks: %d, discount: %v", len(tasks), discount)
+ for _, t := range tasks {
+ tsched := t.TaskGoroutineSchedInfo()
+
+ // An offender is a task running inside the kernel for longer than the specified timeout.
+ if tsched.State == kernel.TaskGoroutineRunningSys {
+ lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick)))
+ elapsed := now.Sub(lastUpdateTime) - discount
+ if elapsed > w.TaskTimeout {
+ tc, ok := w.offenders[t]
+ if !ok {
+ // New stuck task detected.
+ //
+ // Note that tasks blocked doing IO may be considered stuck in kernel,
+ // unless they are surrounded by
+ // Task.UninterruptibleSleepStart/Finish.
+ tc = &offender{lastUpdateTime: lastUpdateTime}
+ stuckTasks.Increment()
+ newTaskFound = true
+ }
+ newOffenders[t] = tc
+ }
+ }
+ }
+ if len(newOffenders) > 0 {
+ w.report(newOffenders, newTaskFound, now)
+ }
+
+ // Remember which tasks have been reported.
+ w.offenders = newOffenders
+}
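
The discount arithmetic above can be hard to follow; the following self-contained sketch replays it with made-up numbers (the constants mirror the defaults defined earlier in this file):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Made-up values mirroring the watchdog defaults.
	const (
		taskTimeout         = 3 * time.Minute
		descheduleThreshold = 1 * time.Second
	)
	period := taskTimeout / 4 // 45s, as computed in New()

	// Suppose the monitoring loop was descheduled: 2 minutes passed since
	// lastRun instead of one period.
	sinceLastRun := 2 * time.Minute

	// Matches runTurn: now.Sub(lastRun.Add(period)) > descheduleThreshold.
	discount := time.Duration(0)
	if sinceLastRun-period > descheduleThreshold {
		discount = sinceLastRun // discount the whole gap
	}

	// A task that apparently sat in RunSys for 3.5 minutes is not reported,
	// because the descheduled gap is subtracted first.
	apparent := 3*time.Minute + 30*time.Second
	elapsed := apparent - discount
	fmt.Println(elapsed > taskTimeout) // false: not considered stuck
}
```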
+
+// report takes appropriate action when a stuck task is detected.
+func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound bool, now ktime.Time) {
+ var buf bytes.Buffer
+ buf.WriteString(fmt.Sprintf("Sentry detected %d stuck task(s):\n", len(offenders)))
+ for t, o := range offenders {
+ tid := w.k.TaskSet().Root.IDOfTask(t)
+ buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime)))
+ }
+
+ buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine")
+
+ // Force stack dump only if a new task is detected.
+ w.doAction(w.TaskTimeoutAction, newTaskFound, &buf)
+}
+
+func (w *Watchdog) reportStuckWatchdog() {
+ var buf bytes.Buffer
+ buf.WriteString("Watchdog goroutine is stuck")
+ w.doAction(w.TaskTimeoutAction, false, &buf)
+}
+
+// doAction will take the given action. If the action is LogWarning, the stack
+// is not always dumped to the log to prevent log flooding. "forceStack"
+// guarantees that the stack will be dumped regardless.
+func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) {
+ switch action {
+ case LogWarning:
+ // Dump stack only if forced or some time has passed since the last time a
+ // stack dump was generated.
+ if !forceStack && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
+ msg.WriteString("\n...[stack dump skipped]...")
+ log.Warningf(msg.String())
+ return
+ }
+ log.TracebackAll(msg.String())
+ w.lastStackDump = time.Now()
+
+ case Panic:
+ // Panic will skip over running tasks, which is likely the culprit here. So manually
+ // dump all stacks before panicking.
+ log.TracebackAll(msg.String())
+
+ // Attempt to flush metrics, timeout and move on in case metrics are stuck as well.
+ metricsEmitted := make(chan struct{}, 1)
+ go func() { // S/R-SAFE: watchdog is stopped during save and restarted after restore.
+ // Flush metrics before killing process.
+ metric.EmitMetricUpdate()
+ metricsEmitted <- struct{}{}
+ }()
+ select {
+ case <-metricsEmitted:
+ case <-time.After(1 * time.Second):
+ }
+ panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String()))
+
+ default:
+ panic(fmt.Sprintf("Unknown watchdog action %v", action))
+
+ }
+}
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
new file mode 100644
index 000000000..e131455f7
--- /dev/null
+++ b/pkg/sleep/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sleep",
+ srcs = [
+ "commit_amd64.s",
+ "commit_arm64.s",
+ "commit_asm.go",
+ "commit_noasm.go",
+ "sleep_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "sleep_test",
+ size = "medium",
+ srcs = [
+ "sleep_test.go",
+ ],
+ library = ":sleep",
+)
diff --git a/pkg/sleep/commit_amd64.s b/pkg/sleep/commit_amd64.s
new file mode 100644
index 000000000..bc4ac2c3c
--- /dev/null
+++ b/pkg/sleep/commit_amd64.s
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+#define preparingG 1
+
+// See commit_noasm.go for a description of commitSleep.
+//
+// func commitSleep(g uintptr, waitingG *uintptr) bool
+TEXT ·commitSleep(SB),NOSPLIT,$0-24
+ MOVQ waitingG+8(FP), CX
+ MOVQ g+0(FP), DX
+
+ // Store the G in waitingG if it's still preparingG. If it's anything
+ // else it means a waker has aborted the sleep.
+ MOVQ $preparingG, AX
+ LOCK
+ CMPXCHGQ DX, 0(CX)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s
new file mode 100644
index 000000000..d0ef15b20
--- /dev/null
+++ b/pkg/sleep/commit_arm64.s
@@ -0,0 +1,38 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+#define preparingG 1
+
+// See commit_noasm.go for a description of commitSleep.
+//
+// func commitSleep(g uintptr, waitingG *uintptr) bool
+TEXT ·commitSleep(SB),NOSPLIT,$0-24
+ MOVD waitingG+8(FP), R0
+ MOVD $preparingG, R1
+ MOVD G+0(FP), R2
+
+ // Store the G in waitingG if it's still preparingG. If it's anything
+ // else it means a waker has aborted the sleep.
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
diff --git a/pkg/sleep/commit_asm.go b/pkg/sleep/commit_asm.go
new file mode 100644
index 000000000..75728a97d
--- /dev/null
+++ b/pkg/sleep/commit_asm.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 arm64
+
+package sleep
+
+// See commit_noasm.go for a description of commitSleep.
+func commitSleep(g uintptr, waitingG *uintptr) bool
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
new file mode 100644
index 000000000..f59061f37
--- /dev/null
+++ b/pkg/sleep/commit_noasm.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+// +build !amd64,!arm64
+
+package sleep
+
+import "sync/atomic"
+
+// commitSleep signals to wakers that the given g is now sleeping. Wakers can
+// then fetch it and wake it.
+//
+// The commit may fail if wakers have been asserted after our last check, in
+// which case they will have set s.waitingG to zero.
+//
+// On supported architectures, commitSleep is written in assembly because it
+// is called from g0, which doesn't have a race context.
+func commitSleep(g uintptr, waitingG *uintptr) bool {
+ // Try to store the G so that wakers know who to wake.
+ return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
+}
diff --git a/pkg/sleep/empty.s b/pkg/sleep/empty.s
new file mode 100644
index 000000000..fb37360ac
--- /dev/null
+++ b/pkg/sleep/empty.s
@@ -0,0 +1,15 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Empty assembly file so empty func definitions work.
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
new file mode 100644
index 000000000..af47e2ba1
--- /dev/null
+++ b/pkg/sleep/sleep_test.go
@@ -0,0 +1,573 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sleep
+
+import (
+ "math/rand"
+ "runtime"
+ "testing"
+ "time"
+)
+
+// TestZeroWakerNotAsserted tests that a zero-value waker is in non-asserted state.
+func TestZeroWakerNotAsserted(t *testing.T) {
+ var w Waker
+ if w.IsAsserted() {
+ t.Fatalf("Zero waker is asserted")
+ }
+
+ if w.Clear() {
+ t.Fatalf("Zero waker is asserted")
+ }
+}
+
+// TestAssertedWakerAfterAssert tests that a waker properly reports its state as
+// asserted once its Assert() method is called.
+func TestAssertedWakerAfterAssert(t *testing.T) {
+ var w Waker
+ w.Assert()
+ if !w.IsAsserted() {
+ t.Fatalf("Asserted waker is not reported as such")
+ }
+
+ if !w.Clear() {
+ t.Fatalf("Asserted waker is not reported as such")
+ }
+}
+
+// TestAssertedWakerAfterTwoAsserts tests that a waker properly reports its
+// state as asserted once its Assert() method is called twice.
+func TestAssertedWakerAfterTwoAsserts(t *testing.T) {
+ var w Waker
+ w.Assert()
+ w.Assert()
+ if !w.IsAsserted() {
+ t.Fatalf("Asserted waker is not reported as such")
+ }
+
+ if !w.Clear() {
+ t.Fatalf("Asserted waker is not reported as such")
+ }
+}
+
+// TestNotAssertedWakerWithSleeper tests that a waker properly reports its state
+// as not asserted after a sleeper is associated with it.
+func TestNotAssertedWakerWithSleeper(t *testing.T) {
+ var w Waker
+ var s Sleeper
+ s.AddWaker(&w, 0)
+ if w.IsAsserted() {
+ t.Fatalf("Non-asserted waker is reported as asserted")
+ }
+
+ if w.Clear() {
+ t.Fatalf("Non-asserted waker is reported as asserted")
+ }
+}
+
+// TestNotAssertedWakerAfterWake tests that a waker properly reports its state
+// as not asserted after a previous assert is consumed by a sleeper. That is, it
+// tests the "edge-triggered" behavior.
+func TestNotAssertedWakerAfterWake(t *testing.T) {
+ var w Waker
+ var s Sleeper
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ if w.IsAsserted() {
+ t.Fatalf("Consumed waker is reported as asserted")
+ }
+
+ if w.Clear() {
+ t.Fatalf("Consumed waker is reported as asserted")
+ }
+}
+
+// TestAssertedWakerBeforeAdd tests that a waker causes a sleeper to not sleep
+// if it's already asserted before being added.
+func TestAssertedWakerBeforeAdd(t *testing.T) {
+ var w Waker
+ var s Sleeper
+ w.Assert()
+ s.AddWaker(&w, 0)
+
+ if _, ok := s.Fetch(false); !ok {
+ t.Fatalf("Fetch failed even though asserted waker was added")
+ }
+}
+
+// TestClearedWaker tests that a waker properly reports its state as not
+// asserted after it is cleared.
+func TestClearedWaker(t *testing.T) {
+ var w Waker
+ w.Assert()
+ w.Clear()
+ if w.IsAsserted() {
+ t.Fatalf("Cleared waker is reported as asserted")
+ }
+
+ if w.Clear() {
+ t.Fatalf("Cleared waker is reported as asserted")
+ }
+}
+
+// TestClearedWakerWithSleeper tests that a waker properly reports its state as
+// not asserted when it is cleared while it has a sleeper associated with it.
+func TestClearedWakerWithSleeper(t *testing.T) {
+ var w Waker
+ var s Sleeper
+ s.AddWaker(&w, 0)
+ w.Clear()
+ if w.IsAsserted() {
+ t.Fatalf("Cleared waker is reported as asserted")
+ }
+
+ if w.Clear() {
+ t.Fatalf("Cleared waker is reported as asserted")
+ }
+}
+
+// TestClearedWakerAssertedWithSleeper tests that a waker properly reports its
+// state as not asserted when it is cleared while it has a sleeper associated
+// with it and has been asserted.
+func TestClearedWakerAssertedWithSleeper(t *testing.T) {
+ var w Waker
+ var s Sleeper
+ s.AddWaker(&w, 0)
+ w.Assert()
+ w.Clear()
+ if w.IsAsserted() {
+ t.Fatalf("Cleared waker is reported as asserted")
+ }
+
+ if w.Clear() {
+ t.Fatalf("Cleared waker is reported as asserted")
+ }
+}
+
+// TestBlock tests that a sleeper actually blocks waiting for the waker to
+// assert its state.
+func TestBlock(t *testing.T) {
+ var w Waker
+ var s Sleeper
+
+ s.AddWaker(&w, 0)
+
+ // Assert waker after one second.
+ before := time.Now()
+ go func() {
+ time.Sleep(1 * time.Second)
+ w.Assert()
+ }()
+
+ // Fetch the result and make sure it took at least 500ms.
+ if _, ok := s.Fetch(true); !ok {
+ t.Fatalf("Fetch failed unexpectedly")
+ }
+ if d := time.Now().Sub(before); d < 500*time.Millisecond {
+ t.Fatalf("Duration was too short: %v", d)
+ }
+
+ // Check that already-asserted waker completes inline.
+ w.Assert()
+ if _, ok := s.Fetch(true); !ok {
+ t.Fatalf("Fetch failed unexpectedly")
+ }
+
+ // Check that fetch sleeps if waker had been asserted but was reset
+ // before Fetch is called.
+ w.Assert()
+ w.Clear()
+ before = time.Now()
+ go func() {
+ time.Sleep(1 * time.Second)
+ w.Assert()
+ }()
+ if _, ok := s.Fetch(true); !ok {
+ t.Fatalf("Fetch failed unexpectedly")
+ }
+ if d := time.Now().Sub(before); d < 500*time.Millisecond {
+ t.Fatalf("Duration was too short: %v", d)
+ }
+}
+
+// TestNonBlock checks that a sleeper won't block if waker isn't asserted.
+func TestNonBlock(t *testing.T) {
+ var w Waker
+ var s Sleeper
+
+ // Don't block when there's no waker.
+ if _, ok := s.Fetch(false); ok {
+ t.Fatalf("Fetch succeeded when there is no waker")
+ }
+
+ // Don't block when waker isn't asserted.
+ s.AddWaker(&w, 0)
+ if _, ok := s.Fetch(false); ok {
+ t.Fatalf("Fetch succeeded when waker was not asserted")
+ }
+
+ // Don't block when waker was asserted, but isn't anymore.
+ w.Assert()
+ w.Clear()
+ if _, ok := s.Fetch(false); ok {
+ t.Fatalf("Fetch succeeded when waker was not asserted anymore")
+ }
+
+ // Don't block when waker was consumed by previous Fetch().
+ w.Assert()
+ if _, ok := s.Fetch(false); !ok {
+ t.Fatalf("Fetch failed even though waker was asserted")
+ }
+
+ if _, ok := s.Fetch(false); ok {
+ t.Fatalf("Fetch succeeded when waker had been consumed")
+ }
+}
+
+// TestMultiple checks that a sleeper can wait for and receive notifications
+// from multiple wakers.
+func TestMultiple(t *testing.T) {
+ s := Sleeper{}
+ w1 := Waker{}
+ w2 := Waker{}
+
+ s.AddWaker(&w1, 0)
+ s.AddWaker(&w2, 1)
+
+ w1.Assert()
+ w2.Assert()
+
+ v, ok := s.Fetch(false)
+ if !ok {
+ t.Fatalf("Fetch failed when there are asserted wakers")
+ }
+
+ if v != 0 && v != 1 {
+ t.Fatalf("Unexpected waker id: %v", v)
+ }
+
+ want := 1 - v
+ v, ok = s.Fetch(false)
+ if !ok {
+ t.Fatalf("Fetch failed when there is an asserted waker")
+ }
+
+ if v != want {
+ t.Fatalf("Unexpected waker id, got %v, want %v", v, want)
+ }
+}
+
+// TestDoneFunction tests if calling Done() on a sleeper works properly.
+func TestDoneFunction(t *testing.T) {
+ // Trivial case of no waker.
+ s := Sleeper{}
+ s.Done()
+
+ // Cases when the sleeper has n wakers, but none are asserted.
+ for n := 1; n < 20; n++ {
+ s := Sleeper{}
+ w := make([]Waker, n)
+ for j := 0; j < n; j++ {
+ s.AddWaker(&w[j], j)
+ }
+ s.Done()
+ }
+
+ // Cases when the sleeper has n wakers, and only the i-th one is
+ // asserted.
+ for n := 1; n < 20; n++ {
+ for i := 0; i < n; i++ {
+ s := Sleeper{}
+ w := make([]Waker, n)
+ for j := 0; j < n; j++ {
+ s.AddWaker(&w[j], j)
+ }
+ w[i].Assert()
+ s.Done()
+ }
+ }
+
+ // Cases when the sleeper has n wakers, and the i-th one is asserted
+ // and cleared.
+ for n := 1; n < 20; n++ {
+ for i := 0; i < n; i++ {
+ s := Sleeper{}
+ w := make([]Waker, n)
+ for j := 0; j < n; j++ {
+ s.AddWaker(&w[j], j)
+ }
+ w[i].Assert()
+ w[i].Clear()
+ s.Done()
+ }
+ }
+
+ // Cases when the sleeper has n wakers, with a random number of them
+ // asserted.
+ for n := 1; n < 20; n++ {
+ for iters := 0; iters < 1000; iters++ {
+ s := Sleeper{}
+ w := make([]Waker, n)
+ for j := 0; j < n; j++ {
+ s.AddWaker(&w[j], j)
+ }
+
+ // Pick the number of asserted elements, then assert
+ // random wakers.
+ asserted := rand.Int() % (n + 1)
+ for j := 0; j < asserted; j++ {
+ w[rand.Int()%n].Assert()
+ }
+ s.Done()
+ }
+ }
+}
+
+// TestRace tests that multiple wakers can continuously send wake requests to
+// the sleeper.
+func TestRace(t *testing.T) {
+ const wakers = 100
+ const wakeRequests = 10000
+
+ counts := make([]int, wakers)
+ w := make([]Waker, wakers)
+ s := Sleeper{}
+
+ // Associate each waker and start goroutines that will assert them.
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ go func(w *Waker) {
+ n := 0
+ for n < wakeRequests {
+ if !w.IsAsserted() {
+ w.Assert()
+ n++
+ } else {
+ runtime.Gosched()
+ }
+ }
+ }(&w[i])
+ }
+
+ // Wait for all wake up notifications from all wakers.
+ for i := 0; i < wakers*wakeRequests; i++ {
+ v, _ := s.Fetch(true)
+ counts[v]++
+ }
+
+ // Check that we got the right number for each.
+ for i, v := range counts {
+ if v != wakeRequests {
+ t.Errorf("Waker %v only got %v wakes", i, v)
+ }
+ }
+}
+
+// TestRaceInOrder tests that multiple wakers can continuously send wake requests to
+// the sleeper and that the wakers are retrieved in the order asserted.
+func TestRaceInOrder(t *testing.T) {
+ const wakers = 100
+ const wakeRequests = 10000
+
+ w := make([]Waker, wakers)
+ s := Sleeper{}
+
+ // Associate each waker and start goroutines that will assert them.
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ }
+ go func() {
+ n := 0
+ for n < wakeRequests {
+ wk := &w[n%len(w)] // take the address: asserting a copied Waker would enqueue the copy
+ wk.Assert()
+ n++
+ }
+ }()
+
+ // Wait for all wake up notifications from all wakers.
+ for i := 0; i < wakeRequests; i++ {
+ v, _ := s.Fetch(true)
+ if got, want := v, i%wakers; got != want {
+ t.Fatalf("got %d want %d", got, want)
+ }
+ }
+}
+
+// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up
+// from 4 wakers when at least one is already asserted.
+func BenchmarkSleeperMultiSelect(b *testing.B) {
+ const count = 4
+ s := Sleeper{}
+ w := make([]Waker, count)
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w[count-1].Assert()
+ s.Fetch(true)
+ }
+}
+
+// BenchmarkGoMultiSelect measures how long it takes to fetch a zero-length
+// struct from one of 4 channels when at least one is ready.
+func BenchmarkGoMultiSelect(b *testing.B) {
+ const count = 4
+ ch := make([]chan struct{}, count)
+ for i := range ch {
+ ch[i] = make(chan struct{}, 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ch[count-1] <- struct{}{}
+ select {
+ case <-ch[0]:
+ case <-ch[1]:
+ case <-ch[2]:
+ case <-ch[3]:
+ }
+ }
+}
+
+// BenchmarkSleeperSingleSelect measures how long it takes to fetch a wake up
+// from one waker that is already asserted.
+func BenchmarkSleeperSingleSelect(b *testing.B) {
+ s := Sleeper{}
+ w := Waker{}
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+// BenchmarkGoSingleSelect measures how long it takes to fetch a zero-length
+// struct from a channel that already has it buffered.
+func BenchmarkGoSingleSelect(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ch <- struct{}{}
+ <-ch
+ }
+}
+
+// BenchmarkSleeperAssertNonWaiting measures how long it takes to assert a
+// waker that is already asserted.
+func BenchmarkSleeperAssertNonWaiting(b *testing.B) {
+ w := Waker{}
+ w.Assert()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+
+}
+
+// BenchmarkGoAssertNonWaiting measures how long it takes to write to a channel
+// that has already something written to it.
+func BenchmarkGoAssertNonWaiting(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkSleeperWaitOnSingleSelect measures how long it takes to wait on one
+// waker channel while another goroutine wakes up the sleeper. This assumes that
+// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine
+// is allowed to go to sleep before the new goroutine has a chance to run).
+func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) {
+ s := Sleeper{}
+ w := Waker{}
+ s.AddWaker(&w, 0)
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+
+}
+
+// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one
+// channel while another goroutine wakes up the sleeper. This assumes that a new
+// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+func BenchmarkGoWaitOnSingleSelect(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ch <- struct{}{}
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkSleeperWaitOnMultiSelect measures how long it takes to wait on 4
+// wakers while another goroutine wakes up the sleeper. This assumes that a new
+// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) {
+ const count = 4
+ s := Sleeper{}
+ w := make([]Waker, count)
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w[count-1].Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+// BenchmarkGoWaitOnMultiSelect measures how long it takes to wait on 4 channels
+// while another goroutine wakes up the sleeper. This assumes that a new
+// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+func BenchmarkGoWaitOnMultiSelect(b *testing.B) {
+ const count = 4
+ ch := make([]chan struct{}, count)
+ for i := range ch {
+ ch[i] = make(chan struct{}, 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ch[count-1] <- struct{}{}
+ }()
+ select {
+ case <-ch[0]:
+ case <-ch[1]:
+ case <-ch[2]:
+ case <-ch[3]:
+ }
+ }
+}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
new file mode 100644
index 000000000..f68c12620
--- /dev/null
+++ b/pkg/sleep/sleep_unsafe.go
@@ -0,0 +1,400 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+// Package sleep allows goroutines to efficiently sleep on multiple sources of
+// notifications (wakers). It offers O(1) complexity, which is different from
+// multi-channel selects which have O(n) complexity (where n is the number of
+// channels) and a considerable constant factor.
+//
+// It is similar to edge-triggered epoll waits, where the user registers each
+// object of interest once, and then can repeatedly wait on all of them.
+//
+// A Waker object is used to wake a sleeping goroutine (G) up, or prevent it
+// from going to sleep next. A Sleeper object is used to receive notifications
+// from wakers, and if no notifications are available, to optionally sleep until
+// one becomes available.
+//
+// A Waker can be associated with at most one Sleeper, but a Sleeper can be
+// associated with multiple Wakers. A Sleeper has a list of asserted (ready)
+// wakers; when Fetch() is called repeatedly, elements from this list are
+// returned until the list becomes empty, in which case the goroutine goes to
+// sleep. When Assert() is called on a Waker, it adds itself to the Sleeper's
+// asserted list and wakes the G up from its sleep if needed.
+//
+// Sleeper objects are expected to be used as follows, with just one goroutine
+// executing this code:
+//
+// // One time set-up.
+// s := sleep.Sleeper{}
+// s.AddWaker(&w1, constant1)
+// s.AddWaker(&w2, constant2)
+//
+// // Called repeatedly.
+// for {
+// switch id, _ := s.Fetch(true); id {
+// case constant1:
+// // Do work triggered by w1 being asserted.
+// case constant2:
+// // Do work triggered by w2 being asserted.
+// }
+// }
+//
+// And Waker objects are expected to call w.Assert() when they want the sleeper
+// to wake up and perform work.
+//
+// The notifications are edge-triggered, which means that if a Waker calls
+// Assert() several times before the sleeper has the chance to wake up, it will
+// only be notified once and should perform all pending work (alternatively, it
+// can also call Assert() on the waker, to ensure that it will wake up again).
+//
+// The "unsafeness" here is in the casts to/from unsafe.Pointer, which is safe
+// when only one type is used for each unsafe.Pointer (which is the case here),
+// we should just make sure that this remains the case in the future. The usage
+// of unsafe package could be confined to sharedWaker and sharedSleeper types
+// that would hold pointers in atomic.Pointers, but the go compiler currently
+// can't optimize these as well (it won't inline their method calls), which
+// reduces performance.
+package sleep
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+const (
+ // preparingG is stored in sleepers to indicate that they're preparing
+ // to sleep.
+ preparingG = 1
+)
+
+var (
+ // assertedSleeper is a sentinel sleeper. A pointer to it is stored in
+ // wakers that are asserted.
+ assertedSleeper Sleeper
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g uintptr, traceskip int)
+
+// Sleeper allows a goroutine to sleep and receive wake up notifications from
+// Wakers in an efficient way.
+//
+// This is similar to edge-triggered epoll in that wakers are added to the
+// sleeper once and the sleeper can then repeatedly sleep in O(1) time while
+// waiting on all wakers.
+//
+// None of the methods in a Sleeper can be called concurrently. Wakers that have
+// been added to a sleeper A can only be added to another sleeper after A.Done()
+// returns. These restrictions allow this to be implemented lock-free.
+//
+// This struct is thread-compatible.
+type Sleeper struct {
+ // sharedList is a "stack" of asserted wakers. They atomically add
+ // themselves to the front of this list as they become asserted.
+ sharedList unsafe.Pointer
+
+ // localList is a list of asserted wakers that is only accessible to the
+ // waiter, and thus doesn't have to be accessed atomically. When
+ // fetching more wakers, the waiter will first go through this list, and
+ // only when it's empty will it atomically fetch wakers from
+ // sharedList.
+ localList *Waker
+
+ // allWakers is a list with all wakers that have been added to this
+ // sleeper. It is used during cleanup to remove associations.
+ allWakers *Waker
+
+ // waitingG holds the G that is sleeping, if any. It is used by wakers
+ // to determine which G, if any, they should wake.
+ waitingG uintptr
+}
+
+// AddWaker associates the given waker with the sleeper. id is the value to be
+// returned when the sleeper is woken by the given waker.
+func (s *Sleeper) AddWaker(w *Waker, id int) {
+ // Add the waker to the list of all wakers.
+ w.allWakersNext = s.allWakers
+ s.allWakers = w
+ w.id = id
+
+ // Try to associate the waker with the sleeper. If it's already
+ // asserted, we simply enqueue it in the "ready" list.
+ for {
+ p := (*Sleeper)(atomic.LoadPointer(&w.s))
+ if p == &assertedSleeper {
+ s.enqueueAssertedWaker(w)
+ return
+ }
+
+ if atomic.CompareAndSwapPointer(&w.s, usleeper(p), usleeper(s)) {
+ return
+ }
+ }
+}
+
+// nextWaker returns the next waker in the notification list, blocking if
+// needed.
+func (s *Sleeper) nextWaker(block bool) *Waker {
+ // Attempt to replenish the local list if it's currently empty.
+ if s.localList == nil {
+ for atomic.LoadPointer(&s.sharedList) == nil {
+ // Fail request if caller requested that we
+ // don't block.
+ if !block {
+ return nil
+ }
+
+ // Indicate to wakers that we're about to sleep,
+ // this allows them to abort the wait by setting
+ // waitingG back to zero (which we'll notice
+ // before committing the sleep).
+ atomic.StoreUintptr(&s.waitingG, preparingG)
+
+ // Check if something was queued while we were
+ // preparing to sleep. We need this interleaving
+ // to avoid missing wake ups.
+ if atomic.LoadPointer(&s.sharedList) != nil {
+ atomic.StoreUintptr(&s.waitingG, 0)
+ break
+ }
+
+ // Try to commit the sleep and report it to the
+ // tracer as a select.
+ //
+ // gopark puts the caller to sleep and calls
+ // commitSleep to decide whether to immediately
+ // wake the caller up or to leave it sleeping.
+ const traceEvGoBlockSelect = 24
+ // See runtime2.go in the Go runtime package for
+ // the values to pass as the waitReason here.
+ const waitReasonSelect = 9
+ gopark(commitSleep, &s.waitingG, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+
+ // Pull the shared list out and reverse it in the local
+ // list. Given that wakers push themselves in reverse
+ // order, we fix things here.
+ v := (*Waker)(atomic.SwapPointer(&s.sharedList, nil))
+ for v != nil {
+ cur := v
+ v = v.next
+
+ cur.next = s.localList
+ s.localList = cur
+ }
+ }
+
+ // Remove the waker in the front of the list.
+ w := s.localList
+ s.localList = w.next
+
+ return w
+}
+
+// Fetch fetches the next wake-up notification. If a notification is immediately
+// available, it is returned right away. Otherwise, the behavior depends on the
+// value of 'block': if true, the current goroutine blocks until a notification
+// arrives, then returns it; if false, returns 'ok' as false.
+//
+// When 'ok' is true, the value of 'id' corresponds to the id associated with
+// the waker; when 'ok' is false, 'id' is undefined.
+//
+// N.B. This method is *not* thread-safe. Only one goroutine at a time is
+// allowed to call this method.
+func (s *Sleeper) Fetch(block bool) (id int, ok bool) {
+ for {
+ w := s.nextWaker(block)
+ if w == nil {
+ return -1, false
+ }
+
+ // Reassociate the waker with the sleeper. If the waker was
+ // still asserted we can return it, otherwise try the next one.
+ old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s)))
+ if old == &assertedSleeper {
+ return w.id, true
+ }
+ }
+}
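
To ground the Fetch contract, here is a small, runnable sketch of typical use against the API in this file (the ids and waker names are arbitrary):

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sleep"
)

func main() {
	var s sleep.Sleeper
	var wRead, wWrite sleep.Waker
	s.AddWaker(&wRead, 0)
	s.AddWaker(&wWrite, 1)

	wWrite.Assert() // typically called from another goroutine

	if id, ok := s.Fetch(false); ok {
		fmt.Println("ready:", id) // prints "ready: 1"
	}
	if _, ok := s.Fetch(false); !ok {
		fmt.Println("nothing else asserted") // Fetch(true) would block here
	}
	s.Done() // release associations before reusing the wakers elsewhere
}
```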
+
+// Done is used to indicate that the caller won't use this Sleeper anymore. It
+// removes the association with all wakers so that they can be safely reused
+// by another sleeper after Done() returns.
+func (s *Sleeper) Done() {
+ // Remove all associations that we can, and build a list of the ones
+ // we could not. An association can be removed right away from waker w
+ // if w.s has a pointer to the sleeper, that is, the waker is not
+ // asserted yet. By atomically switching w.s to nil, we guarantee that
+ // subsequent calls to Assert() on the waker will not result in it being
+ // queued to this sleeper.
+ var pending *Waker
+ w := s.allWakers
+ for w != nil {
+ next := w.allWakersNext
+ for {
+ t := atomic.LoadPointer(&w.s)
+ if t != usleeper(s) {
+ w.allWakersNext = pending
+ pending = w
+ break
+ }
+
+ if atomic.CompareAndSwapPointer(&w.s, t, nil) {
+ break
+ }
+ }
+ w = next
+ }
+
+ // The associations that we could not remove are either asserted, or in
+ // the process of being asserted, or have been asserted and cleared
+ // before being pulled from the sleeper lists. We must wait for them all
+ // to make it to the sleeper lists, so that we know that the wakers
+ // won't do any more work towards waking this sleeper up.
+ for pending != nil {
+ pulled := s.nextWaker(true)
+
+ // Remove the waker we just pulled from the list of associated
+ // wakers.
+ prev := &pending
+ for w := *prev; w != nil; w = *prev {
+ if pulled == w {
+ *prev = w.allWakersNext
+ break
+ }
+ prev = &w.allWakersNext
+ }
+ }
+ s.allWakers = nil
+}
+
+// enqueueAssertedWaker enqueues an asserted waker to the "ready" list of
+// wakers that want to notify the sleeper.
+func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
+ // Add the new waker to the front of the list.
+ for {
+ v := (*Waker)(atomic.LoadPointer(&s.sharedList))
+ w.next = v
+ if atomic.CompareAndSwapPointer(&s.sharedList, uwaker(v), uwaker(w)) {
+ break
+ }
+ }
+
+ // Nothing to do if there isn't a G waiting.
+ if atomic.LoadUintptr(&s.waitingG) == 0 {
+ return
+ }
+
+ // Signal to the sleeper that a waker has been asserted.
+ switch g := atomic.SwapUintptr(&s.waitingG, 0); g {
+ case 0, preparingG:
+ default:
+ // We managed to get a G. Wake it up.
+ goready(g, 0)
+ }
+}
+
+// Waker represents a source of wake-up notifications to be sent to sleepers. A
+// waker can be associated with at most one sleeper at a time, and at any given
+// time is either in asserted or non-asserted state.
+//
+// Once asserted, the waker remains so until it is manually cleared or a sleeper
+// consumes its assertion (i.e., a sleeper wakes up or is prevented from going
+// to sleep due to the waker).
+//
+// This struct is thread-safe, that is, its methods can be called concurrently
+// by multiple goroutines.
+type Waker struct {
+ // s is the sleeper that this waker can wake up. Only one sleeper at a
+ // time is allowed. This field can have three classes of values:
+ // nil -- the waker is not asserted: it either is not associated with
+ // a sleeper, or is queued to a sleeper due to being previously
+ // asserted. This is the zero value.
+ // &assertedSleeper -- the waker is asserted.
+ // otherwise -- the waker is not asserted, and is associated with the
+ // given sleeper. Once it transitions to asserted state, the
+ // associated sleeper will be woken.
+ s unsafe.Pointer
+
+ // next is used to form a linked list of asserted wakers in a sleeper.
+ next *Waker
+
+ // allWakersNext is used to form a linked list of all wakers associated
+ // to a given sleeper.
+ allWakersNext *Waker
+
+ // id is the value to be returned to sleepers when they wake up due to
+ // this waker being asserted.
+ id int
+}
+
+// Assert moves the waker to an asserted state, if it isn't asserted yet. When
+// asserted, the waker will cause its matching sleeper to wake up.
+func (w *Waker) Assert() {
+ // Nothing to do if the waker is already asserted. This check allows us
+ // to complete this case (already asserted) without any interlocked
+ // operations on x86.
+ if atomic.LoadPointer(&w.s) == usleeper(&assertedSleeper) {
+ return
+ }
+
+ // Mark the waker as asserted, and wake up a sleeper if there is one.
+ switch s := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(&assertedSleeper))); s {
+ case nil:
+ case &assertedSleeper:
+ default:
+ s.enqueueAssertedWaker(w)
+ }
+}
+
+// Clear moves the waker to the non-asserted state and returns whether it was
+// asserted before being cleared.
+//
+// N.B. The waker isn't removed from the "ready" list of a sleeper (if it
+// happens to be in one), but the sleeper will notice that it is not asserted
+// anymore and won't return it to the caller.
+func (w *Waker) Clear() bool {
+ // Nothing to do if the waker is not asserted. This check allows us to
+ // complete this case (already not asserted) without any interlocked
+ // operations on x86.
+ if atomic.LoadPointer(&w.s) != usleeper(&assertedSleeper) {
+ return false
+ }
+
+ // Try to store nil in the sleeper, which indicates that the waker is
+ // not asserted.
+ return atomic.CompareAndSwapPointer(&w.s, usleeper(&assertedSleeper), nil)
+}
+
+// IsAsserted returns whether the waker is currently asserted (i.e., if it's
+// currently in a state that would cause its matching sleeper to wake up).
+func (w *Waker) IsAsserted() bool {
+ return (*Sleeper)(atomic.LoadPointer(&w.s)) == &assertedSleeper
+}
+
+func usleeper(s *Sleeper) unsafe.Pointer {
+ return unsafe.Pointer(s)
+}
+
+func uwaker(w *Waker) unsafe.Pointer {
+ return unsafe.Pointer(w)
+}
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
new file mode 100644
index 000000000..089b3bbef
--- /dev/null
+++ b/pkg/state/BUILD
@@ -0,0 +1,100 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "pending_list",
+ out = "pending_list.go",
+ package = "state",
+ prefix = "pending",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "pendingMapper",
+ "Linker": "*pendingEntry",
+ },
+)
+
+go_template_instance(
+ name = "deferred_list",
+ out = "deferred_list.go",
+ package = "state",
+ prefix = "deferred",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "deferredMapper",
+ "Linker": "*deferredEntry",
+ },
+)
+
+go_template_instance(
+ name = "complete_list",
+ out = "complete_list.go",
+ package = "state",
+ prefix = "complete",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectDecodeState",
+ "Linker": "*objectDecodeState",
+ },
+)
+
+go_template_instance(
+ name = "addr_range",
+ out = "addr_range.go",
+ package = "state",
+ prefix = "addr",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uintptr",
+ },
+)
+
+go_template_instance(
+ name = "addr_set",
+ out = "addr_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "reflect": "reflect",
+ },
+ package = "state",
+ prefix = "addr",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uintptr",
+ "Range": "addrRange",
+ "Value": "*objectEncodeState",
+ "Functions": "addrSetFunctions",
+ },
+)
+
+go_library(
+ name = "state",
+ srcs = [
+ "addr_range.go",
+ "addr_set.go",
+ "complete_list.go",
+ "decode.go",
+ "decode_unsafe.go",
+ "deferred_list.go",
+ "encode.go",
+ "encode_unsafe.go",
+ "pending_list.go",
+ "state.go",
+ "state_norace.go",
+ "state_race.go",
+ "stats.go",
+ "types.go",
+ ],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/README.md b/pkg/state/README.md
new file mode 100644
index 000000000..1aa401193
--- /dev/null
+++ b/pkg/state/README.md
@@ -0,0 +1,158 @@
+# State Encoding and Decoding
+
+The state package implements the encoding and decoding of data structures for
+`go_stateify`. This package is designed for use cases not covered by the
+standard encoding packages, e.g. `gob` and `json`. Principally:
+
+* This package operates on complex object graphs and accurately serializes and
+ restores all relationships. That is, you can have things like: intrusive
+ pointers, cycles, and pointer chains of arbitrary depths. These are not
+ handled appropriately by existing encoders. This is not an implementation
+ flaw: the formats themselves are not capable of representing these graphs,
+ as they can only generate directed trees; see the sketch after this list.
+
+* This package allows installing order-dependent load callbacks and then
+ resolves that graph at load time, with cycle detection. Similarly, there is
+ no analogous feature possible in the standard encoders.
+
+* This package handles the resolution of interfaces, based on a registered
+ type name. For interface objects type information is saved in the serialized
+ format. This is generally true for `gob` as well, but it works differently.
+
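+As an illustration of the first point, consider a value graph with a cycle (a
+sketch; the type name is invented). Tree-producing encoders like `gob` cannot
+represent it, while this package records object identity and restores the
+mutual pointers intact:
+
+```go
+type node struct {
+	Next *node
+	Prev *node
+}
+
+a, b := &node{}, &node{}
+a.Next, b.Prev = b, a // a <-> b: a cycle that no directed tree can express
+```
+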
+Here's an overview of how encoding and decoding works.
+
+## Encoding
+
+Encoding produces a `statefile`, which contains a list of chunks of the form
+`(header, payload)`. The payload can either be some raw data, or a series of
+encoded wire objects representing some object graph. All encoded objects are
+defined in the `wire` subpackage.
+
+Encoding of an object graph begins with `encodeState.Save`.
+
+### 1. Memory Map & Encoding
+
+To discover relationships between potentially interdependent data structures
+(for example, a struct may contain pointers to members of other data
+structures), the encoder first walks the object graph and constructs a memory
+map of the objects in the input graph. As this walk progresses, objects are
+queued in the `pending` list and items are placed on the `deferred` list as they
+are discovered. No single object will be encoded multiple times, but the
+discovered relationships between objects may change as more parts of the overall
+object graph are discovered.
+
+The encoder starts at the root object and recursively visits all reachable
+objects, recording the address ranges containing the underlying data for each
+object. This is stored as a segment set (`addrSet`), mapping address ranges to
+the object occupying the range; see `encodeState.values`. Note that there
+is special handling for zero-sized types and map objects during this process.
+
+Additionally, the encoder assigns each object a unique identifier which is used
+to indicate relationships between objects in the statefile; see `objectID` in
+`encode.go`.
+
+### 2. Type Serialization
+
+The encoder will subsequently serialize all information about discovered types,
+including field names. These are used during decoding to reconcile these types
+with other internally registered types.
+
+### 3. Object Serialization
+
+With a full address map, and all objects correctly encoded, all object encodings
+are serialized. The assigned `objectID`s aren't explicitly encoded in the
+statefile; the order of object messages in the stream determines their IDs.
+
+### Example
+
+Given the following data structure definitions:
+
+```go
+type system struct {
+ o *outer
+ i *inner
+}
+
+type outer struct {
+ a int64
+ cn *container
+}
+
+type container struct {
+ n uint64
+ elem *inner
+}
+
+type inner struct {
+ c container
+ x, y uint64
+}
+```
+
+Initialized like this:
+
+```go
+o := outer{
+ a: 10,
+ cn: nil,
+}
+i := inner{
+ x: 20,
+ y: 30,
+ c: container{},
+}
+s := system{
+ o: &o,
+ i: &i,
+}
+
+o.cn = &i.c
+o.cn.elem = &i
+
+```
+
+Encoding will produce an object stream like this:
+
+```
+g0r1 = struct{
+ i: g0r3,
+ o: g0r2,
+}
+g0r2 = struct{
+ a: 10,
+ cn: g0r3.c,
+}
+g0r3 = struct{
+ c: struct{
+ elem: g0r3,
+ n: 0u,
+ },
+ x: 20u,
+ y: 30u,
+}
+```
+
+Note how `g0r3.c` is correctly encoded as the underlying `container` object for
+`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i`
+being discovered after the pointer to it in `system.o.cn`. Also note that
+decoding isn't strictly reliant on the order of the encoded object stream, as
+long as the relationships between objects are correctly encoded.
+
+## Decoding
+
+Decoding reads the statefile and reconstructs the object graph. Decoding begins
+in `decodeState.Load`. Decoding is performed in a single pass over the object
+stream in the statefile, followed by a pass over all deserialized objects to
+fire the load callbacks in the correct order. Note that callbacks may introduce
+cycles here, but these are detected and an error is returned.
+
+Decoding is relatively straightforward. For most primitive values, the decoder
+constructs an appropriate object and fills it with the values encoded in the
+statefile. Pointers need special handling, as they must point to a value
+allocated elsewhere. When values are constructed, the decoder indexes them by
+their `objectID`s in `decodeState.objectsByID`. The targets of pointers are
+resolved by searching this index by `objectID`; see `decodeState.register`. For
+pointers to values inside another value (fields in a struct, elements of an
+array), the decoder uses the accessor path to walk to the appropriate location;
+see `walkChild`.
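+
+Reference resolution can be sketched as follows (a simplified, hypothetical
+helper using numeric field paths; the real logic in `decodeState.register` and
+`walkChild` handles field names and array indices):
+
+```go
+// resolveRef looks up a target by its 1-based ID, then applies the accessor
+// path (stored in reverse order) to reach the referenced sub-value.
+func resolveRef(objectsByID []reflect.Value, root int, fieldPath []int) reflect.Value {
+	obj := objectsByID[root-1] // IDs are assigned by stream order, starting at 1.
+	for i := len(fieldPath) - 1; i >= 0; i-- {
+		obj = obj.Field(fieldPath[i])
+	}
+	return obj
+}
+```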
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
new file mode 100644
index 000000000..c9971cdf6
--- /dev/null
+++ b/pkg/state/decode.go
@@ -0,0 +1,725 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// internalCallback is an interface called on object completion.
+//
+// There are two implementations: objectDecodeState & userCallback.
+type internalCallback interface {
+ // source returns the dependent object. May be nil.
+ source() *objectDecodeState
+
+ // callbackRun executes the callback.
+ callbackRun()
+}
+
+// userCallback is an implementation of internalCallback.
+type userCallback func()
+
+// source implements internalCallback.source.
+func (userCallback) source() *objectDecodeState {
+ return nil
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (uc userCallback) callbackRun() {
+ uc()
+}
+
+// objectDecodeState represents an object that may be in the process of being
+// decoded. Specifically, it represents either a decoded object, or an
+// interest in a future object that will be decoded. When that interest is
+// registered (via register), the storage for the object will be created, but
+// it will not be decoded until the object is encountered in the stream.
+type objectDecodeState struct {
+ // id is the id for this object.
+ id objectID
+
+ // typ is the typeID for this object's type. This may be zero if this
+ // is not a type-registered structure.
+ typ typeID
+
+ // obj is the object. This may or may not be valid yet, depending on
+ // whether complete returns true. However, regardless of whether the
+ // object is valid, obj contains a final storage location for the
+ // object. This is immutable.
+ //
+ // Note that this must be addressable (obj.Addr() must not panic).
+ //
+ // The obj passed to the decode methods below will equal this obj only
+ // in the case of decoding the top-level object. However, the passed
+ // obj may represent individual fields, elements of a slice, etc. that
+ // are effectively embedded within the reflect.Value below but with
+ // distinct types.
+ obj reflect.Value
+
+ // blockedBy is the number of dependencies this object has.
+ blockedBy int
+
+ // callbacksInline is inline storage for callbacks.
+ callbacksInline [2]internalCallback
+
+ // callbacks is a set of callbacks to execute on load.
+ callbacks []internalCallback
+
+ completeEntry
+}
+
+// addCallback adds a callback to the objectDecodeState.
+func (ods *objectDecodeState) addCallback(ic internalCallback) {
+ if ods.callbacks == nil {
+ ods.callbacks = ods.callbacksInline[:0]
+ }
+ ods.callbacks = append(ods.callbacks, ic)
+}
+
+// findCycleFor walks the callback graph from ods and returns the chain of
+// objects leading to target, which must be part of a cycle.
+func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
+ for _, ic := range ods.callbacks {
+ other := ic.source()
+ if other == nil {
+ // User callbacks have no source object.
+ continue
+ }
+ if other == target {
+ return []*objectDecodeState{target}
+ } else if childList := other.findCycleFor(target); childList != nil {
+ return append(childList, other)
+ }
+ }
+
+ // This should not occur.
+ Failf("no cycle found?")
+ panic("unreachable")
+}
+
+// findCycle finds a dependency cycle.
+func (ods *objectDecodeState) findCycle() []*objectDecodeState {
+ return append(ods.findCycleFor(ods), ods)
+}
+
+// source implements internalCallback.source.
+func (ods *objectDecodeState) source() *objectDecodeState {
+ return ods
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() {
+ ods.blockedBy--
+}
+
+// decodeState is a graph of objects in the process of being decoded.
+//
+// The decode process involves loading the breadth-first graph generated by
+// encode. This graph is read in its entirety, ensuring that all object
+// storage is complete.
+//
+// As the graph is being deserialized, a set of completion callbacks are
+// executed. These completion callbacks should form a set of acyclic subgraphs
+// over the original one. After decoding is complete, the objects are scanned
+// to ensure that all callbacks are executed, otherwise the callback graph was
+// not acyclic.
+type decodeState struct {
+ // ctx is the decode context.
+ ctx context.Context
+
+ // r is the input stream.
+ r wire.Reader
+
+ // types is the type database.
+ types typeDecodeDatabase
+
+ // objectsByID is the set of objects in progress.
+ objectsByID []*objectDecodeState
+
+ // deferred are objects that have been read, but no interest has been
+ // registered yet. These will be decoded once interest is registered.
+ deferred map[objectID]wire.Object
+
+ // pending is the set of objects that are not yet complete.
+ pending completeList
+
+ // stats tracks time data.
+ stats Stats
+}
+
+// lookup looks up an object in decodeState or returns nil if no such object
+// has been previously registered.
+func (ds *decodeState) lookup(id objectID) *objectDecodeState {
+ if len(ds.objectsByID) < int(id) {
+ return nil
+ }
+ return ds.objectsByID[id-1]
+}
+
+// checkComplete checks for completion.
+func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
+ // Still blocked?
+ if ods.blockedBy > 0 {
+ return false
+ }
+
+ // Track stats if relevant.
+ if ods.callbacks != nil && ods.typ != 0 {
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ }
+
+ // Fire all callbacks.
+ for _, ic := range ods.callbacks {
+ ic.callbackRun()
+ }
+
+ // Mark completed.
+ cbs := ods.callbacks
+ ods.callbacks = nil
+ ds.pending.Remove(ods)
+
+ // Recursively check others.
+ for _, ic := range cbs {
+ if other := ic.source(); other != nil && other.blockedBy == 0 {
+ ds.checkComplete(other)
+ }
+ }
+
+ return true // All set.
+}
+
+// wait registers a dependency on an object.
+//
+// As a special case, we always allow _usable_ references back to the first
+// decoding object because it may have fields that are already decoded. We also
+// allow trivial self references, since they can be handled internally.
+func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
+ switch id {
+ case waiter.id:
+ // Trivial self reference.
+ fallthrough
+ case 1:
+ // Root object; see above.
+ if callback != nil {
+ callback()
+ }
+ return
+ }
+
+ // Mark as blocked.
+ waiter.blockedBy++
+
+ // Note that lookup cannot return nil here.
+ other := ds.lookup(id)
+ if callback != nil {
+ // Add the additional user callback.
+ other.addCallback(userCallback(callback))
+ }
+
+ // Register waiter to be unblocked when the other object completes.
+ other.addCallback(waiter)
+}
+
+// waitObject notes a blocking relationship.
+func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
+ if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
+ // Refs can encode pointers and maps.
+ ds.wait(ods, objectID(rv.Root), callback)
+ } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
+ // See decodeObject; we need to wait for the array (if non-nil).
+ ds.wait(ods, objectID(sv.Ref.Root), callback)
+ } else if iv, ok := encoded.(*wire.Interface); ok {
+ // It's an interface (wait recursively).
+ ds.waitObject(ods, iv.Value, callback)
+ } else if callback != nil {
+ // Nothing to wait for: execute the callback immediately.
+ callback()
+ }
+}
+
+// walkChild returns a child object from obj, given an accessor path. This is
+// the decode-side equivalent to traverse in encode.go.
+//
+// For the purposes of this function, a child object is either a field within a
+// struct or an array element, with one such indirection per element in
+// path. The returned value may be an unexported field, so it may not be
+// directly assignable. See unsafePointerTo.
+func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
+ // See wire.Ref.Dots. The path here is specified in reverse order.
+ for i := len(path) - 1; i >= 0; i-- {
+ switch pc := path[i].(type) {
+ case *wire.FieldName: // Must be a struct.
+ if obj.Kind() != reflect.Struct {
+ Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.FieldByName(string(*pc))
+ case wire.Index: // Must be an array.
+ if obj.Kind() != reflect.Array {
+ Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.Index(int(pc))
+ default:
+ panic("unreachable: switch should be exhaustive")
+ }
+ }
+ return obj
+}
+
+// register registers a reference to an object.
+//
+// The provided type is used to instantiate a new object only if the ID has
+// not been registered previously. If the reference carries an accessor path,
+// the type is instead resolved from the reference itself.
+func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
+ // Grow the objectsByID slice.
+ id := objectID(r.Root)
+ if len(ds.objectsByID) < int(id) {
+ ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
+ }
+
+ // Does this object already exist?
+ ods := ds.objectsByID[id-1]
+ if ods != nil {
+ return walkChild(r.Dots, ods.obj)
+ }
+
+ // Create the object.
+ if len(r.Dots) != 0 {
+ typ = ds.findType(r.Type)
+ }
+ v := reflect.New(typ)
+ ods = &objectDecodeState{
+ id: id,
+ obj: v.Elem(),
+ }
+ ds.objectsByID[id-1] = ods
+ ds.pending.PushBack(ods)
+
+ // Process any deferred objects & callbacks.
+ if encoded, ok := ds.deferred[id]; ok {
+ delete(ds.deferred, id)
+ ds.decodeObject(ods, ods.obj, encoded)
+ }
+
+ return walkChild(r.Dots, ods.obj)
+}
+
+// objectDecoder is for decoding structs.
+type objectDecoder struct {
+ // ds is decodeState.
+ ds *decodeState
+
+ // ods is current object being decoded.
+ ods *objectDecodeState
+
+ // reconciledTypeEntry is the reconciled type information.
+ rte *reconciledTypeEntry
+
+ // encoded is the encoded object state.
+ encoded *wire.Struct
+}
+
+// load is a helper for the public methods on Source.
+func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
+ // Note that we have reconciled the type and may remap the fields here
+ // to match what's expected by the decoder. The "slot" parameter here
+ // is in terms of the local type, where the fields in the encoded
+ // object are in terms of the wire object's type, which might be in a
+ // different order (but will have the same fields).
+ v := *od.encoded.Field(od.rte.FieldOrder[slot])
+ od.ds.decodeObject(od.ods, objPtr.Elem(), v)
+ if wait {
+ // Mark this individual object as a blocker.
+ od.ds.waitObject(od.ods, v, fn)
+ }
+}
+
+// afterLoad implements Source.AfterLoad.
+func (od *objectDecoder) afterLoad(fn func()) {
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ od.ods.addCallback(userCallback(fn))
+}
+
+// decodeStruct decodes a struct value.
+func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
+ if encoded.TypeID == 0 {
+ // Allow anonymous empty structs, but only if the encoded
+ // object also has no fields.
+ if encoded.Fields() == 0 && obj.NumField() == 0 {
+ return
+ }
+
+ // Propagate an error.
+ Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
+ }
+
+ // Lookup the object type.
+ rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
+ ods.typ = typeID(encoded.TypeID)
+
+ // Invoke the loader.
+ od := objectDecoder{
+ ds: ds,
+ ods: ods,
+ rte: rte,
+ encoded: encoded,
+ }
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateLoad(Source{internal: od})
+ }
+}
+
+// decodeMap decodes a map value.
+func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
+ if obj.IsNil() {
+ // See pointerTo.
+ obj.Set(reflect.MakeMap(obj.Type()))
+ }
+ for i := 0; i < len(encoded.Keys); i++ {
+ // Decode the objects.
+ kv := reflect.New(obj.Type().Key()).Elem()
+ vv := reflect.New(obj.Type().Elem()).Elem()
+ ds.decodeObject(ods, kv, encoded.Keys[i])
+ ds.decodeObject(ods, vv, encoded.Values[i])
+ ds.waitObject(ods, encoded.Keys[i], nil)
+ ds.waitObject(ods, encoded.Values[i], nil)
+
+ // Set in the map.
+ obj.SetMapIndex(kv, vv)
+ }
+}
+
+// decodeArray decodes an array value.
+func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
+ if len(encoded.Contents) != obj.Len() {
+ Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
+ }
+ // Decode the contents into the array.
+ for i := 0; i < len(encoded.Contents); i++ {
+ ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
+ ds.waitObject(ods, encoded.Contents[i], nil)
+ }
+}
+
+// findType finds the reflect.Type for the given wire.TypeSpec.
+func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
+ switch x := t.(type) {
+ case wire.TypeID:
+ typ := ds.types.LookupType(typeID(x))
+ rte := ds.types.Lookup(typeID(x), typ)
+ return rte.LocalType
+ case *wire.TypeSpecPointer:
+ return reflect.PtrTo(ds.findType(x.Type))
+ case *wire.TypeSpecArray:
+ return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
+ case *wire.TypeSpecSlice:
+ return reflect.SliceOf(ds.findType(x.Type))
+ case *wire.TypeSpecMap:
+ return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
+ default:
+ // Should not happen.
+ Failf("unknown type %#v", t)
+ }
+ panic("unreachable")
+}
+
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
+ if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
+ // Special case; the nil object. Just decode directly, which
+ // will read nil from the wire (if encoded correctly).
+ ds.decodeObject(ods, obj, encoded.Value)
+ return
+ }
+
+ // We now need to resolve the actual type.
+ typ := ds.findType(encoded.Type)
+
+ // We need to imbue type information here, then we can proceed to
+ // decode normally. In order to avoid issues with setting value-types,
+ // we create a new non-interface version of this object. We will then
+ // set the interface object to be equal to whatever we decode.
+ origObj := obj
+ obj = reflect.New(typ).Elem()
+ defer origObj.Set(obj)
+
+ // With the object now having sufficient type information to actually
+ // have Set called on it, we can proceed to decode the value.
+ ds.decodeObject(ods, obj, encoded.Value)
+}
+
+// isFloatEq determines if x and y represent the same value.
+func isFloatEq(x float64, y float64) bool {
+ switch {
+ case math.IsNaN(x):
+ return math.IsNaN(y)
+ case math.IsInf(x, 1):
+ return math.IsInf(y, 1)
+ case math.IsInf(x, -1):
+ return math.IsInf(y, -1)
+ default:
+ return x == y
+ }
+}
+
+// isComplexEq determines if x and y represent the same value.
+func isComplexEq(x complex128, y complex128) bool {
+ return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
+}
+
+// decodeObject decodes an object value.
+func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
+ switch x := encoded.(type) {
+ case wire.Nil: // Fast path: first.
+ // We leave obj alone here. That's because if obj represents an
+ // interface, it may have been imbued with type information in
+ // decodeInterface, and we don't want to destroy that.
+ case *wire.Ref:
+ // Nil pointers may be encoded in a "forceValue" context. For
+ // those we just leave it alone as the value will already be
+ // correct (nil).
+ if id := objectID(x.Root); id == 0 {
+ return
+ }
+
+ // Note that if this is a map type, we go through a level of
+ // indirection to allow for map aliasing.
+ if obj.Kind() == reflect.Map {
+ v := ds.register(x, obj.Type())
+ if v.IsNil() {
+ // Note that we don't want to clobber the map
+ // if it has already been decoded by decodeMap.
+ // We just make it so that we have a consistent
+ // reference when that eventually does happen.
+ v.Set(reflect.MakeMap(v.Type()))
+ }
+ obj.Set(v)
+ return
+ }
+
+ // Normal assignment: authoritative only if no dots.
+ v := ds.register(x, obj.Type().Elem())
+ if v.IsValid() {
+ obj.Set(unsafePointerTo(v))
+ }
+ case wire.Bool:
+ obj.SetBool(bool(x))
+ case wire.Int:
+ obj.SetInt(int64(x))
+ if obj.Int() != int64(x) {
+ Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
+ }
+ case wire.Uint:
+ obj.SetUint(uint64(x))
+ if obj.Uint() != uint64(x) {
+ Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
+ }
+ case wire.Float32:
+ obj.SetFloat(float64(x))
+ case wire.Float64:
+ obj.SetFloat(float64(x))
+ if !isFloatEq(obj.Float(), float64(x)) {
+ Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
+ }
+ case *wire.Complex64:
+ obj.SetComplex(complex128(*x))
+ case *wire.Complex128:
+ obj.SetComplex(complex128(*x))
+ if !isComplexEq(obj.Complex(), complex128(*x)) {
+ Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
+ }
+ case *wire.String:
+ obj.SetString(string(*x))
+ case *wire.Slice:
+ // See *wire.Ref above; same applies.
+ if id := objectID(x.Ref.Root); id == 0 {
+ return
+ }
+ // Note that it's fine to slice the array here and assume that
+ // contents will still be filled in later on.
+ typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
+ v := ds.register(&x.Ref, typ)
+ obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ case *wire.Array:
+ ds.decodeArray(ods, obj, x)
+ case *wire.Struct:
+ ds.decodeStruct(ods, obj, x)
+ case *wire.Map:
+ ds.decodeMap(ods, obj, x)
+ case *wire.Interface:
+ ds.decodeInterface(ods, obj, x)
+ default:
+ // Should not happen; not propagated as an error.
+ Failf("unknown object %#v for %q", encoded, obj.Type().Name())
+ }
+}
+
+// Load deserializes the object graph rooted at obj.
+//
+// This function may panic and should be run in safely().
+func (ds *decodeState) Load(obj reflect.Value) {
+ ds.stats.init()
+ defer ds.stats.fini(func(id typeID) string {
+ return ds.types.LookupName(id)
+ })
+
+ // Create the root object.
+ ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ id: 1,
+ obj: obj,
+ })
+
+ // Read the number of objects.
+ lastID, object, err := ReadHeader(ds.r)
+ if err != nil {
+ Failf("header error: %w", err)
+ }
+ if !object {
+ Failf("object missing")
+ }
+
+ // Decode all objects.
+ var (
+ encoded wire.Object
+ ods *objectDecodeState
+ id = objectID(1)
+ tid = typeID(1)
+ )
+ if err := safely(func() {
+ // Decode all objects in the stream.
+ //
+ // Note that the structure of this decoding loop should match
+ // the raw decoding loop in printer.go.
+ for id <= objectID(lastID) {
+ // Unmarshal the object.
+ encoded = wire.Load(ds.r)
+
+ // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok {
+ ds.types.Register(wt)
+ tid++
+ encoded = nil
+ continue
+ }
+
+ // Actually resolve the object.
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we defer
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
+
+ // For error handling.
+ ods = nil
+ encoded = nil
+ id++
+ }
+ }); err != nil {
+ // Include as much information as we can, taking into account
+ // the possible state transitions above.
+ if ods != nil {
+ Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
+ } else if encoded != nil {
+ Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
+ } else {
+ Failf("general decoding error: %w", err)
+ }
+ }
+
+ // Check if we have any deferred objects.
+ for id, encoded := range ds.deferred {
+ // Should never happen; the graph was bogus.
+ Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
+ }
+
+ // Scan and fire all callbacks. We iterate over the list of incomplete
+ // objects until all have been finished. We stop iterating if no
+ // objects become complete (there is a dependency cycle).
+ //
+ // Note that we iterate backwards here, because there will be a strong
+ // tendency for blocking relationships to go from earlier objects to
+ // later (deeper) objects in the graph. This will reduce the number of
+ // iterations required to finish all objects.
+ if err := safely(func() {
+ for ds.pending.Back() != nil {
+ thisCycle := false
+ for ods = ds.pending.Back(); ods != nil; {
+ if ds.checkComplete(ods) {
+ thisCycle = true
+ break
+ }
+ ods = ods.Prev()
+ }
+ if !thisCycle {
+ break
+ }
+ }
+ }); err != nil {
+ Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
+ }
+
+ // Check if we have any remaining dependency cycles. If there are any
+ // objects left in the pending list, then it must be due to a cycle.
+ if ods := ds.pending.Front(); ods != nil {
+ // This must be the result of a dependency cycle.
+ cycle := ods.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
+ }
+ buf.WriteString("}")
+ Failf("incomplete graph: %s", string(buf.Bytes()))
+ }
+}
+
+// ReadHeader reads an object header.
+//
+// Each object written to the statefile is prefixed with a header. See
+// WriteHeader for more information; these functions are exported to allow
+// non-state writes to the file to play nice with debugging tools.
+func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
+ // Read the header.
+ err = safely(func() {
+ length = wire.LoadUint(r)
+ })
+ if err != nil {
+ // On the header, pass raw I/O errors.
+ if sErr, ok := err.(*ErrState); ok {
+ return 0, false, sErr.Unwrap()
+ }
+ }
+
+ // Decode whether the object is valid.
+ object = length&objectFlag != 0
+ length &^= objectFlag
+ return
+}
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
new file mode 100644
index 000000000..d048f61a1
--- /dev/null
+++ b/pkg/state/decode_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
+// values representing unexported fields. This bypasses visibility, but not
+// type safety.
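+//
+// For example (illustrative): given an addressable struct value v with an
+// unexported field f, v.FieldByName("f") reports CanSet() == false, but
+// unsafePointerTo(v.FieldByName("f")) yields a pointer through which the
+// field may be set.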
+func unsafePointerTo(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
new file mode 100644
index 000000000..92fcad4e9
--- /dev/null
+++ b/pkg/state/encode.go
@@ -0,0 +1,841 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "context"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// objectEncodeState describes the type and identity of an object occupying a
+// memory address range. This is the value type for addrSet, and the intrusive
+// entry for the pending and deferred lists.
+type objectEncodeState struct {
+ // id is the assigned ID for this object.
+ id objectID
+
+ // obj is the object value. Note that this may be replaced if we
+ // encounter an object that contains this object. When this happens (in
+ // resolve), we will update existing references appropriately, below,
+ // and defer a re-encoding of the object.
+ obj reflect.Value
+
+ // encoded is the encoded value of this object. Note that this may not
+ // be up to date if this object is still in the deferred list.
+ encoded wire.Object
+
+ // how indicates whether this object should be encoded as a value. This
+ // is used only for deferred encoding.
+ how encodeStrategy
+
+ // refs are the list of reference objects used by other objects
+ // referring to this object. When the object is updated, these
+ // references may be updated directly and automatically.
+ refs []*wire.Ref
+
+ pendingEntry
+ deferredEntry
+}
+
+// encodeState is state used for encoding.
+//
+// The encoding process constructs a representation of the in-memory graph of
+// objects before a single object is serialized. This is done to ensure that
+// all references can be fully disambiguated. See resolve for more details.
+type encodeState struct {
+ // ctx is the encode context.
+ ctx context.Context
+
+ // w is the output stream.
+ w wire.Writer
+
+ // types is the type database.
+ types typeEncodeDatabase
+
+ // lastID is the last allocated object ID.
+ lastID objectID
+
+ // values tracks the address ranges occupied by objects, along with the
+ // types of these objects. This is used to locate pointer targets,
+ // including pointers to fields within another type.
+ //
+ // Multiple objects may overlap in memory iff the larger object fully
+ // contains the smaller one, and the type of the smaller object matches
+ // a field or array element's type at the appropriate offset. An
+ // arbitrary number of objects may be nested in this manner.
+ //
+ // Note that this does not track zero-sized objects, those are tracked
+ // by zeroValues below.
+ values addrSet
+
+ // zeroValues tracks zero-sized objects.
+ zeroValues map[reflect.Type]*objectEncodeState
+
+ // deferred is the list of objects to be encoded.
+ deferred deferredList
+
+ // pendingTypes is the list of types to be serialized. Serialization
+ // will occur when all objects have been encoded, but before pending is
+ // serialized.
+ pendingTypes []wire.Type
+
+ // pending is the list of objects to be serialized. Serialization does
+ // not actually occur until the full object graph is computed.
+ pending pendingList
+
+ // stats tracks time data.
+ stats Stats
+}
+
+// isSameSizeParent returns true if child is a field value or element within
+// parent. Only a struct or array can have a child value.
+//
+// isSameSizeParent deals with objects like this:
+//
+// struct child {
+// // fields..
+// }
+//
+// struct parent {
+// c child
+// }
+//
+// var p parent
+// record(&p.c)
+//
+// Here, &p and &p.c occupy the exact same address range.
+//
+// Or like this:
+//
+// struct child {
+// // fields
+// }
+//
+// var arr [1]parent
+// record(&arr[0])
+//
+// Similarly, &arr[0] and &arr[0].c have the exact same address range.
+//
+// Precondition: parent and child must occupy the same memory.
+func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
+ switch parent.Kind() {
+ case reflect.Struct:
+ for i := 0; i < parent.NumField(); i++ {
+ field := parent.Field(i)
+ if field.Type() == childType {
+ return true
+ }
+ // Recurse through any intermediate types.
+ if isSameSizeParent(field, childType) {
+ return true
+ }
+ // Does it make sense to keep going if the first field
+ // doesn't match? Yes, because there might be an
+ // arbitrary number of zero-sized fields before we get
+ // a match, and childType itself can be zero-sized.
+ }
+ return false
+ case reflect.Array:
+ // The only case where an array with more than one elements can
+ // return true is if childType is zero-sized. In such cases,
+ // it's ambiguous which element contains the match since a
+ // zero-sized child object fully fits in any of the zero-sized
+ // elements in an array... However since all elements are of
+ // the same type, we only need to check one element.
+ //
+ // For non-zero-sized childTypes, parent.Len() must be 1, but a
+ // combination of the precondition and an implicit comparison
+ // between the array element size and childType ensures this.
+ return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
+ default:
+ return false
+ }
+}
+
+// nextID returns the next valid ID.
+func (es *encodeState) nextID() objectID {
+ es.lastID++
+ return objectID(es.lastID)
+}
+
+// dummyAddr points to the dummy zero-sized address.
+var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
+
+// resolve records the address range occupied by an object.
+func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
+ addr := obj.Pointer()
+
+ // Is this a map pointer? Just record the single address. It is not
+ // possible to take any pointers into the map internals.
+ if obj.Kind() == reflect.Map {
+ if addr == 0 {
+ // Just leave the nil reference alone. This is
+ // fine; we may need to encode as a reference in
+ // this way, and a zero Root tells anyone
+ // depending on this value that nothing is there.
+ return
+ }
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ // Ensure the map types match.
+ existing := seg.Value()
+ if existing.obj.Type() != obj.Type() {
+ Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
+ }
+
+ // There's no sense in recording refs: maps may not be
+ // replaced by covering objects, as they are maximal.
+ ref.Root = wire.Uint(existing.id)
+ return
+ }
+
+ // Record the map.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ how: encodeMapAsValue,
+ }
+ es.values.Add(addrRange{addr, addr + 1}, oes)
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+
+ // See above: no ref recording.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+
+ // If not a map, then the object must be a pointer.
+ if obj.Kind() != reflect.Ptr {
+ Failf("attempt to record non-map and non-pointer object %#v", obj)
+ }
+
+ obj = obj.Elem() // Value from here.
+
+ // Is this a zero-sized type?
+ typ := obj.Type()
+ size := typ.Size()
+ if size == 0 {
+ if addr == dummyAddr {
+ // Zero-sized objects point to a dummy byte within the
+ // runtime. There's no sense recording this in the
+ // address map. We add this to the dedicated
+ // zeroValues.
+ //
+ // Note that zero-sized objects must be *true*
+ // zero-sized objects. They cannot be part of some
+ // larger object. In that case, they are assigned a
+ // 1-byte address at the end of the object.
+ oes, ok := es.zeroValues[typ]
+ if !ok {
+ oes = &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ es.zeroValues[typ] = oes
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ }
+
+ // There's also no sense tracking back references. We
+ // know that this is a true zero-sized object, and not
+ // part of a larger container, so it will not change.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+ size = 1 // See above.
+ }
+
+ // Calculate the container.
+ end := addr + size
+ r := addrRange{addr, end}
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ existing := seg.Value()
+ switch {
+ case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
+ // The object is a perfect match. Happy path. Avoid the
+ // traversal and just return directly. We don't need to
+ // encode the type information or any dots here.
+ ref.Root = wire.Uint(existing.id)
+ existing.refs = append(existing.refs, ref)
+ return
+
+ case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
+ // The previously registered object is larger than
+ // this, no need to update. But we expect some
+ // traversal below.
+
+ case seg.Start() == addr && seg.End() == end:
+ if !isSameSizeParent(obj, existing.obj.Type()) {
+ break // Needs traversal.
+ }
+ fallthrough // Needs update.
+
+ case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
+ // Update the object and redo the encoding.
+ old := existing.obj
+ existing.obj = obj
+ es.deferred.Remove(existing)
+ es.deferred.PushBack(existing)
+
+ // The previously registered object is superseded by
+ // this new object. We are guaranteed to not have any
+ // mergeable neighbours in this segment set.
+ if !raceEnabled {
+ seg.SetRangeUnchecked(r)
+ } else {
+ // Be extra paranoid. This will be statically
+ // removed at compile time unless this is a race build.
+ es.values.Remove(seg)
+ es.values.Add(r, existing)
+ seg = es.values.LowerBoundSegment(addr)
+ }
+
+ // Compute the traversal required & update references.
+ dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
+ wt := es.findType(obj.Type())
+ for _, ref := range existing.refs {
+ ref.Dots = append(ref.Dots, dots...)
+ ref.Type = wt
+ }
+ default:
+ // There is a nonsensical overlap.
+ Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ }
+
+ // Compute the new reference, record and return it.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
+ ref.Type = es.findType(obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
+
+ // The only remaining case is a pointer value that doesn't overlap with
+ // any registered addresses. Create a new entry for it, and start
+ // tracking the first reference we just created.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ if !raceEnabled {
+ es.values.AddWithoutMerging(r, oes)
+ } else {
+ // Merges should never happen. This just enables extra
+ // sanity checks, since the Merge function below will panic.
+ es.values.Add(r, oes)
+ }
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
+}
+
+// traverse searches for a target object within a root object, where the target
+// object is a struct field or array element within root, with potentially
+// multiple intervening types. traverse returns the set of field or element
+// traversals required to reach the target.
+//
+// Note that for efficiency, traverse returns the dots in the reverse order.
+// That is, the first traversal required will be the last element of the list.
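+//
+// For example (illustrative): given a root of type struct{ a [2]inner } and a
+// target of type inner at the address of a[1], traverse returns the dots
+// [Index(1), FieldName("a")], which walkChild later applies right-to-left.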
+//
+// Precondition: The target object must lie completely within the range defined
+// by [rootAddr, rootAddr + sizeof(rootType)].
+func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
+ // Recursion base case: the types actually match.
+ if targetType == rootType && targetAddr == rootAddr {
+ return nil
+ }
+
+ switch rootType.Kind() {
+ case reflect.Struct:
+ offset := targetAddr - rootAddr
+ for i := rootType.NumField(); i > 0; i-- {
+ field := rootType.Field(i - 1)
+ // The first field from the end with an offset that is
+ // smaller than or equal to our address offset is where
+ // the target is located. Traverse from there.
+ if field.Offset <= offset {
+ dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
+ fieldName := wire.FieldName(field.Name)
+ return append(dots, &fieldName)
+ }
+ }
+ // Should never happen; the target should be reachable.
+ Failf("no field in root type %v contains target type %v", rootType, targetType)
+
+ case reflect.Array:
+ // Since arrays have homogeneous types, all elements have the
+ // same size and we can compute where the target lives. This
+ // does not matter for the purpose of typing, but matters for
+ // the purpose of computing the address of the given index.
+ elemSize := int(rootType.Elem().Size())
+ n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
+ if rootType.Len() < n {
+ Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
+ targetType, targetAddr, rootType, rootAddr, rootType.Len())
+ }
+ dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
+ return append(dots, wire.Index(n))
+
+ default:
+ // For any other type, there's no possibility of aliasing so if
+ // the types didn't match earlier then we have an address
+ // collision which shouldn't be possible at this point.
+ Failf("traverse failed for root type %v and target type %v", rootType, targetType)
+ }
+ panic("unreachable")
+}
+
+// encodeMap encodes a map.
+func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
+ if obj.IsNil() {
+ // Because there is a difference between a nil map and an empty
+ // map, a truly nil map must be encoded as nil so that decoding
+ // does not materialize an empty map.
+ *dest = wire.Nil{}
+ return
+ }
+ l := obj.Len()
+ m := &wire.Map{
+ Keys: make([]wire.Object, l),
+ Values: make([]wire.Object, l),
+ }
+ *dest = m
+ for i, k := range obj.MapKeys() {
+ v := obj.MapIndex(k)
+ // Map keys must be encoded using the full value because the
+ // type will be omitted after the first key.
+ es.encodeObject(k, encodeAsValue, &m.Keys[i])
+ es.encodeObject(v, encodeAsValue, &m.Values[i])
+ }
+}
+
+// objectEncoder is for encoding structs.
+type objectEncoder struct {
+ // es is encodeState.
+ es *encodeState
+
+ // encoded is the encoded struct.
+ encoded *wire.Struct
+}
+
+// save is called by the public methods on Sink.
+func (oe *objectEncoder) save(slot int, obj reflect.Value) {
+ fieldValue := oe.encoded.Field(slot)
+ oe.es.encodeObject(obj, encodeDefault, fieldValue)
+}
+
+// encodeStruct encodes a composite object.
+func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ // Ensure that the obj is addressable. There are two cases when it is
+ // not. First, is when this is dispatched via SaveValue. Second, when
+ // this is a map key as a struct. Either way, we need to make a copy to
+ // obtain an addressable value.
+ if !obj.CanAddr() {
+ localObj := reflect.New(obj.Type())
+ localObj.Elem().Set(obj)
+ obj = localObj.Elem()
+ }
+
+ // Prepare the value.
+ s := &wire.Struct{}
+ *dest = s
+
+ // Look the type up in the database.
+ te, ok := es.types.Lookup(obj.Type())
+ if te == nil {
+ if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs. This
+ // will just return success without ever invoking the
+ // passed function. This uses the immutable EmptyStruct
+ // variable to prevent an allocation in this case.
+ //
+ // Note that this mechanism does *not* work for
+ // interfaces in general. So you can't dispatch
+ // non-registered empty structs via interfaces because
+ // then they can't be restored.
+ s.Alloc(0)
+ return
+ }
+ // We need a SaverLoader for struct types.
+ Failf("struct %T does not implement SaverLoader", obj.Interface())
+ }
+ if !ok {
+ // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+
+ // Invoke the provided saver.
+ s.TypeID = wire.TypeID(te.ID)
+ s.Alloc(len(te.Fields))
+ oe := objectEncoder{
+ es: es,
+ encoded: s,
+ }
+ es.stats.start(te.ID)
+ defer es.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateSave(Sink{internal: oe})
+ }
+}
+
+// encodeArray encodes an array.
+func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
+ l := obj.Len()
+ a := &wire.Array{
+ Contents: make([]wire.Object, l),
+ }
+ *dest = a
+ for i := 0; i < l; i++ {
+ // We need to encode the full value because arrays are encoded
+ // using the type information from only the first element.
+ es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
+ }
+}
+
+// findType recursively finds type information.
+func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
+ // First: check if this is a proper type. It's possible for pointers,
+ // slices, arrays, maps, etc to all have some different type.
+ te, ok := es.types.Lookup(typ)
+ if te != nil {
+ if !ok {
+ // See encodeStruct.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+ return wire.TypeID(te.ID)
+ }
+
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return &wire.TypeSpecPointer{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Slice:
+ return &wire.TypeSpecSlice{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Array:
+ return &wire.TypeSpecArray{
+ Count: wire.Uint(typ.Len()),
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Map:
+ return &wire.TypeSpecMap{
+ Key: es.findType(typ.Key()),
+ Value: es.findType(typ.Elem()),
+ }
+ default:
+ // After potentially chasing many pointers, the
+ // ultimate type of the object is not known.
+ Failf("type %q is not known", typ)
+ }
+ panic("unreachable")
+}
+
+// encodeInterface encodes an interface.
+func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
+ // Dereference the object.
+ obj = obj.Elem()
+ if !obj.IsValid() {
+ // Special case: the nil object.
+ *dest = &wire.Interface{
+ Type: wire.TypeSpecNil{},
+ Value: wire.Nil{},
+ }
+ return
+ }
+
+ // Encode underlying object.
+ i := &wire.Interface{
+ Type: es.findType(obj.Type()),
+ }
+ *dest = i
+ es.encodeObject(obj, encodeAsValue, &i.Value)
+}
+
+// isPrimitiveZero returns true if this is a primitive type, or a composite
+// type composed entirely of primitives. It is called only on zero values; see
+// encodeObject.
+func isPrimitiveZero(typ reflect.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ // Pointers are always treated as primitive types because we
+ // won't encode directly from here. Returning true here won't
+ // prevent the object from being encoded correctly.
+ return true
+ case reflect.Bool:
+ return true
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return true
+ case reflect.Float32, reflect.Float64:
+ return true
+ case reflect.Complex64, reflect.Complex128:
+ return true
+ case reflect.String:
+ return true
+ case reflect.Slice:
+ // The slice itself is a primitive, but not necessarily the
+ // array it points to. This is similar to a pointer.
+ return true
+ case reflect.Array:
+ // We cannot treat an array as a primitive, because it may be
+ // composed of structures or other things with side-effects.
+ return isPrimitiveZero(typ.Elem())
+ case reflect.Interface:
+ // Since we know that this type is the zero type, the interface
+ // value must be zero. Therefore this is primitive.
+ return true
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ // The isPrimitiveZero function is called only on zero-types to
+ // see if it's safe to serialize. Since a zero map has no
+ // elements, it is safe to treat as a primitive.
+ return true
+ default:
+ Failf("unknown type %q", typ.Name())
+ }
+ panic("unreachable")
+}
+
+// encodeStrategy is the strategy used for encodeObject.
+type encodeStrategy int
+
+const (
+ // encodeDefault means types are encoded normally as references.
+ encodeDefault encodeStrategy = iota
+
+ // encodeAsValue means that types will never be short-circuited and
+ // will always be encoded as a full value.
+ encodeAsValue
+
+ // encodeMapAsValue means that even maps will be fully encoded.
+ encodeMapAsValue
+)
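+
+// For example (illustrative), array elements and map keys/values are encoded
+// with encodeAsValue: type information is omitted after the first element, so
+// each entry must carry a full value; see encodeArray and encodeMap above.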
+
+// encodeObject encodes an object.
+func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
+ if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
+ *dest = wire.Nil{}
+ return
+ }
+ switch obj.Kind() {
+ case reflect.Ptr: // Fast path: first.
+ r := new(wire.Ref)
+ *dest = r
+ if obj.IsNil() {
+ // May be in an array or elsewhere such that a value is
+ // required. So we encode as a reference to the zero
+ // object, which does not exist. Note that this has to
+ // be handled correctly in the decode path as well.
+ return
+ }
+ es.resolve(obj, r)
+ case reflect.Bool:
+ *dest = wire.Bool(obj.Bool())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ *dest = wire.Int(obj.Int())
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ *dest = wire.Uint(obj.Uint())
+ case reflect.Float32:
+ *dest = wire.Float32(obj.Float())
+ case reflect.Float64:
+ *dest = wire.Float64(obj.Float())
+ case reflect.Complex64:
+ c := wire.Complex64(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.Complex128:
+ c := wire.Complex128(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.String:
+ s := wire.String(obj.String())
+ *dest = &s // Needs alloc.
+ case reflect.Array:
+ es.encodeArray(obj, dest)
+ case reflect.Slice:
+ s := &wire.Slice{
+ Capacity: wire.Uint(obj.Cap()),
+ Length: wire.Uint(obj.Len()),
+ }
+ *dest = s
+ // Note that we still need to provide a wire.Slice for a nil
+ // slice here: how is not encodeDefault in that case, since a
+ // nil slice under encodeDefault would have been caught by the
+ // IsZero check above and encoded as wire.Nil{}.
+ if obj.IsNil() {
+ return
+ }
+ // Slices need pointer resolution.
+ es.resolve(arrayFromSlice(obj), &s.Ref)
+ case reflect.Interface:
+ es.encodeInterface(obj, dest)
+ case reflect.Struct:
+ es.encodeStruct(obj, dest)
+ case reflect.Map:
+ if how == encodeMapAsValue {
+ es.encodeMap(obj, dest)
+ return
+ }
+ r := new(wire.Ref)
+ *dest = r
+ es.resolve(obj, r)
+ default:
+ Failf("unknown object %#v", obj.Interface())
+ panic("unreachable")
+ }
+}
+
+// Save serializes the object graph rooted at obj.
+func (es *encodeState) Save(obj reflect.Value) {
+ es.stats.init()
+ defer es.stats.fini(func(id typeID) string {
+ return es.pendingTypes[id-1].Name
+ })
+
+ // Resolve the first object, which should queue a pile of additional
+ // objects on the pending list. All queued objects should be fully
+ // resolved, and we should be able to serialize after this call.
+ var root wire.Ref
+ es.resolve(obj.Addr(), &root)
+
+ // Encode the graph.
+ var oes *objectEncodeState
+ if err := safely(func() {
+ for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
+ // Remove and encode the object. Note that as a result
+ // of this encoding, the object may be enqueued on the
+ // deferred list yet again. That's expected, and why it
+ // is removed first.
+ es.deferred.Remove(oes)
+ es.encodeObject(oes.obj, oes.how, &oes.encoded)
+ }
+ }); err != nil {
+ // Include the object in the error message.
+ Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
+ }
+
+ // Check that items are pending.
+ if es.pending.Front() == nil {
+ Failf("pending is empty?")
+ }
+
+ // Write the header with the number of objects. Note that there is no
+ // way that es.lastID could conflict with objectFlag, which would
+ // indicate an impossibly large encoding.
+ if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ Failf("error writing header: %w", err)
+ }
+
+ // Serialize all pending types and pending objects. Note that we don't
+ // bother removing from this list as we walk it because that just
+ // wastes time. It will not change after this point.
+ var id objectID
+ if err := safely(func() {
+ for _, wt := range es.pendingTypes {
+ // Encode the type.
+ wire.Save(es.w, &wt)
+ }
+ for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
+ id++ // First object is 1.
+ if oes.id != id {
+ Failf("expected id %d, got %d", id, oes.id)
+ }
+
+ // Marshal the object.
+ wire.Save(es.w, oes.encoded)
+ }
+ }); err != nil {
+ // Include the object and the error.
+ Failf("error serializing object %#v: %w", oes.encoded, err)
+ }
+
+ // Check what we wrote.
+ if id != es.lastID {
+ Failf("expected %d objects, wrote %d", es.lastID, id)
+ }
+}
+
+// objectFlag indicates that the length is a # of objects, rather than a raw
+// byte length. When this is set on a length header in the stream, it may be
+// decoded appropriately.
+const objectFlag uint64 = 1 << 63
+
+// WriteHeader writes a header.
+//
+// Each object written to the statefile should be prefixed with a header. In
+// order to generate statefiles that play nicely with debugging tools, raw
+// writes should be prefixed with a header with object set to false and the
+// appropriate length. This will allow tools to skip these regions.
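+//
+// For example (illustrative), a raw region of n bytes would be framed as:
+//
+//	WriteHeader(w, n, false) // Length n with the object bit clear.
+//
+// followed by the n raw bytes written directly to w.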
+func WriteHeader(w wire.Writer, length uint64, object bool) error {
+ // Sanity check the length.
+ if length&objectFlag != 0 {
+ Failf("impossibly huge length: %d", length)
+ }
+ if object {
+ length |= objectFlag
+ }
+
+ // Write a header.
+ return safely(func() {
+ wire.SaveUint(w, length)
+ })
+}
+
+// pendingMapper is for the pending list.
+type pendingMapper struct{}
+
+func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
+
+// deferredMapper is for the deferred list.
+type deferredMapper struct{}
+
+func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry }
+
+// addrSetFunctions is used by addrSet.
+type addrSetFunctions struct{}
+
+func (addrSetFunctions) MinKey() uintptr {
+ return 0
+}
+
+func (addrSetFunctions) MaxKey() uintptr {
+ return ^uintptr(0)
+}
+
+func (addrSetFunctions) ClearValue(val **objectEncodeState) {
+ *val = nil
+}
+
+func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
+ if val1.obj == val2.obj {
+ // This should never happen. It would indicate that the same
+ // object exists in two non-contiguous address ranges. Note
+ // that this assertion can only be triggered if the race
+ // detector is enabled.
+ Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
+ }
+ // Reject the merge.
+ return val1, false
+}
+
+func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
+ // A split should never happen: we don't remove ranges.
+ Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
+ panic("unreachable")
+}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
new file mode 100644
index 000000000..e0dad83b4
--- /dev/null
+++ b/pkg/state/encode_unsafe.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// arrayFromSlice constructs a new pointer to the slice data.
+//
+// It would be similar to the following:
+//
+// x := make([]Foo, l, c)
+// a := (*[c]Foo)(unsafe.Pointer(&x[0]))
+//
+func arrayFromSlice(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(
+ reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
+ unsafe.Pointer(obj.Pointer()))
+}
diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD
new file mode 100644
index 000000000..d053802f7
--- /dev/null
+++ b/pkg/state/pretty/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pretty",
+ srcs = ["pretty.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
new file mode 100644
index 000000000..cf37aaa49
--- /dev/null
+++ b/pkg/state/pretty/pretty.go
@@ -0,0 +1,273 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pretty is a pretty-printer for state streams.
+package pretty
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+func formatRef(x *wire.Ref, graph uint64, html bool) string {
+ baseRef := fmt.Sprintf("g%dr%d", graph, x.Root)
+ fullRef := baseRef
+ if len(x.Dots) > 0 {
+ // See wire.Ref; Type is only valid if Dots is non-empty.
+ typ, _ := formatType(x.Type, graph, html)
+ var buf strings.Builder
+ buf.WriteString("(*")
+ buf.WriteString(typ)
+ buf.WriteString(")(")
+ buf.WriteString(baseRef)
+ for _, component := range x.Dots {
+ switch v := component.(type) {
+ case *wire.FieldName:
+ buf.WriteString(".")
+ buf.WriteString(string(*v))
+ case wire.Index:
+ buf.WriteString(fmt.Sprintf("[%d]", v))
+ default:
+ panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
+ }
+ }
+ buf.WriteString(")")
+ fullRef = buf.String()
+ }
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef)
+ }
+ return fullRef
+}
+
+func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
+ switch x := t.(type) {
+ case wire.TypeID:
+ base := fmt.Sprintf("g%dt%d", graph, x)
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true
+ }
+ return base, true
+ case wire.TypeSpecNil:
+ return "", false // Only nil type.
+ case *wire.TypeSpecPointer:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("(*%s)", element), true
+ case *wire.TypeSpecArray:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("[%d](%s)", x.Count, element), true
+ case *wire.TypeSpecSlice:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("([]%s)", element), true
+ case *wire.TypeSpecMap:
+ key, _ := formatType(x.Key, graph, html)
+ value, _ := formatType(x.Value, graph, html)
+ return fmt.Sprintf("(map[%s]%s)", key, value), true
+ default:
+ panic(fmt.Sprintf("unreachable: unknown type %T", t))
+ }
+}
+
+// format formats a single object, for pretty-printing. It also returns whether
+// the value is a non-zero value.
+func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) {
+ switch x := encoded.(type) {
+ case wire.Nil:
+ return "nil", false
+ case *wire.String:
+ return fmt.Sprintf("%q", *x), *x != ""
+ case *wire.Complex64:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Complex128:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Ref:
+ return formatRef(x, graph, html), x.Root != 0
+ case *wire.Type:
+ tabs := "\n" + strings.Repeat("\t", depth)
+ items := make([]string, 0, len(x.Fields)+2)
+ items = append(items, fmt.Sprintf("type %s {", x.Name))
+ for i := 0; i < len(x.Fields); i++ {
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i]))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true // No zero value.
+ case *wire.Slice:
+ return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0
+ case *wire.Array:
+ if len(x.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Contents); i++ {
+ item, ok := format(graph, depth+1, x.Contents[i], html)
+ if !ok {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ continue
+ }
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
+ }
+ items = append(items, "]")
+ return strings.Join(items, tabs), len(zeros) < len(x.Contents)
+ case *wire.Struct:
+ typ, _ := formatType(x.TypeID, graph, html)
+ if x.Fields() == 0 {
+ return fmt.Sprintf("struct[%s]{}", typ), false
+ }
+ items := make([]string, 0, 2)
+ items = append(items, fmt.Sprintf("struct[%s]{", typ))
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for i := 0; i < x.Fields(); i++ {
+ element, ok := format(graph, depth+1, *x.Field(i), html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *wire.Map:
+ if len(x.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Keys); i++ {
+ key, _ := format(graph, depth+1, x.Keys[i], html)
+ value, _ := format(graph, depth+1, x.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *wire.Interface:
+ typ, typOk := formatType(x.Type, graph, html)
+ element, elementOk := format(graph, depth+1, x.Value, html)
+ return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
+ default:
+ // Must be a primitive; use reflection.
+ return fmt.Sprintf("%v", encoded), true
+ }
+}
+
+// printStream is the basic print implementation.
+func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+ // current graph ID.
+ var graph uint64
+
+ if html {
+ fmt.Fprintf(w, "<pre>")
+ defer fmt.Fprintf(w, "</pre>")
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if rErr, ok := r.(error); ok {
+ err = rErr // Override return.
+ return
+ }
+ panic(r) // Propagate.
+ }
+ }()
+
+ for {
+ // Find the first object to begin generation.
+ length, object, err := state.ReadHeader(r)
+ if err == io.EOF {
+ // Nothing else to do.
+ break
+ } else if err != nil {
+ return err
+ }
+ if !object {
+ graph++ // Increment the graph.
+ if length > 0 {
+ fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+ io.Copy(ioutil.Discard, &io.LimitedReader{
+ R: r,
+ N: int64(length),
+ })
+ }
+ continue
+ }
+
+ // Read & unmarshal the object.
+ //
+ // Note that this loop must match the general structure of the
+ // loop in decode.go. But we don't register type information,
+ // etc. and just print the raw structures.
+ var (
+ oid uint64 = 1
+ tid uint64 = 1
+ )
+ for oid <= length {
+ // Unmarshal the object.
+ encoded := wire.Load(r)
+
+ // Is this a type?
+ if _, ok := encoded.(*wire.Type); ok {
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dt%d", graph, tid)
+ if html {
+ // See below.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ tid++
+ continue
+ }
+
+ // Format the node.
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dr%d", graph, oid)
+ if html {
+ // Create a little tag with an anchor next to it for linking.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ oid++
+ }
+ }
+
+ return nil
+}
+
+// PrintText reads the stream from r and prints text to w.
+func PrintText(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, false /* html */)
+}
+
+// PrintHTML reads the stream from r and prints html to w.
+func PrintHTML(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, true /* html */)
+}
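
Note: a hedged usage sketch for the package above; the file name is hypothetical and the nil key is an assumption (it must match whatever key was used when the statefile was written):

package main

import (
	"log"
	"os"

	"gvisor.dev/gvisor/pkg/state/pretty"
	"gvisor.dev/gvisor/pkg/state/statefile"
)

func main() {
	f, err := os.Open("checkpoint.img") // hypothetical statefile
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Unwrap the statefile framing (HMAC + compression) to obtain a
	// wire.Reader positioned at the raw state stream.
	r, _, err := statefile.NewReader(f, nil)
	if err != nil {
		log.Fatal(err)
	}
	if err := pretty.PrintText(os.Stdout, r); err != nil {
		log.Fatal(err)
	}
}
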
diff --git a/pkg/state/state.go b/pkg/state/state.go
new file mode 100644
index 000000000..acb629969
--- /dev/null
+++ b/pkg/state/state.go
@@ -0,0 +1,321 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides functionality related to saving and loading object
+// graphs. For most types, it provides a set of default saving / loading logic
+// that will be invoked automatically if custom logic is not defined.
+//
+// Kind Support
+// ---- -------
+// Bool default
+// Int default
+// Int8 default
+// Int16 default
+// Int32 default
+// Int64 default
+// Uint default
+// Uint8 default
+// Uint16 default
+// Uint32 default
+// Uint64 default
+// Float32 default
+// Float64 default
+// Complex64 default
+// Complex128 default
+// Array default
+// Chan custom
+// Func custom
+// Interface default
+// Map default
+// Ptr default
+// Slice default
+// String default
+// Struct custom (*) Unless zero-sized.
+// UnsafePointer custom
+//
+// See README.md for an overview of how encoding and decoding works.
+package state
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// objectID is a unique identifier assigned to each object to be serialized.
+// Each instance of an object is considered separately, i.e. if there are two
+// objects of the same type in the object graph being serialized, they'll be
+// assigned unique objectIDs.
+type objectID uint32
+
+// typeID is the identifier for a type. Types are serialized and tracked
+// alongside objects in order to avoid the overhead of encoding field names in
+// all objects.
+type typeID uint32
+
+// ErrState is returned when an error is encountered during encode/decode.
+type ErrState struct {
+ // err is the underlying error.
+ err error
+
+ // trace is the stack trace.
+ trace string
+}
+
+// Error returns a sensible description of the state error.
+func (e *ErrState) Error() string {
+ return fmt.Sprintf("%v:\n%s", e.err, e.trace)
+}
+
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+ return e.err
+}
+
+// Save saves the given object state.
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
+ // Create the encoding state.
+ es := encodeState{
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
+ }
+
+ // Perform the encoding.
+ err := safely(func() {
+ es.Save(reflect.ValueOf(rootPtr).Elem())
+ })
+ return es.stats, err
+}
+
+// Load loads a checkpoint.
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
+ // Create the decoding state.
+ ds := decodeState{
+ ctx: ctx,
+ r: r,
+ types: makeTypeDecodeDatabase(),
+ deferred: make(map[objectID]wire.Object),
+ }
+
+ // Attempt our decode.
+ err := safely(func() {
+ ds.Load(reflect.ValueOf(rootPtr).Elem())
+ })
+ return ds.stats, err
+}
+
+// Sink is used for Type.StateSave.
+type Sink struct {
+ internal objectEncoder
+}
+
+// Save adds the given object to the map.
+//
+// You should always pass pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) StateTypeName() string {
+//	return "pkg.X"
+// }
+//
+// func (x *X) StateFields() []string {
+//	return []string{"A", "B"}
+// }
+//
+// func (x *X) StateSave(m Sink) {
+// m.Save(0, &x.A) // Field is A.
+// m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.Load(0, &x.A) // Field is A.
+// m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+ s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
+}
+
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+// m.SaveValue(0, int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.LoadValue(0, new(int64), func(v interface{}) {
+// x.A = P.Foo(v.(int64))
+// })
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+ s.internal.save(slot, reflect.ValueOf(obj))
+}
+
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+ return s.internal.es.ctx
+}
+
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing the runtime reflection required.
+//
+// All these methods can be automatically generated by the go_stateify tool.
+type Type interface {
+ // StateTypeName returns the type's name.
+ //
+ // This is used for matching type information during encoding and
+ // decoding, as well as dynamic interface dispatch. This should be
+ // globally unique.
+ StateTypeName() string
+
+ // StateFields returns information about the type.
+ //
+ // Fields is the set of fields for the object. Calls to Sink.Save and
+ // Source.Load must be made in-order with respect to these fields.
+ //
+ // This will be called at most once per serialization.
+ StateFields() []string
+}
+
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+ // StateSave saves the state of the object to the given Map.
+ StateSave(Sink)
+
+ // StateLoad loads the state of the object.
+ StateLoad(Source)
+}
+
+// Source is used for Type.StateLoad.
+type Source struct {
+ internal objectDecoder
+}
+
+// Load loads the given object, which must be passed as a pointer.
+//
+// See Sink.Save for an example.
+func (s Source) Load(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
+}
+
+// LoadWait loads the given object from the map, and marks it as requiring that
+// all AfterLoad executions complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
+}
+
+// LoadValue loads the given object value from the map.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
+}
+
+// AfterLoad schedules a function to be executed once all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not be executed until all of the current object's
+// dependencies' AfterLoad() logic, if any, has been executed.
+func (s Source) AfterLoad(fn func()) {
+ s.internal.afterLoad(fn)
+}
+
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+ return s.internal.ds.ctx
+}
+
+// IsZeroValue checks if the given value is the zero value.
+//
+// This function is used by the stateify tool.
+func IsZeroValue(val interface{}) bool {
+ return val == nil || reflect.ValueOf(val).Elem().IsZero()
+}
+
+// Failf is a wrapper around panic that should be used to generate errors that
+// can be caught during saving and loading.
+func Failf(fmtStr string, v ...interface{}) {
+ panic(fmt.Errorf(fmtStr, v...))
+}
+
+// safely executes the given function, catching a panic and unpacking as an
+// error.
+//
+// The error flow through the state package uses panic and recover. There are
+// two important reasons for this:
+//
+// 1) Many of the reflection methods will already panic with invalid data or
+// violated assumptions. We would want to recover anyways here.
+//
+// 2) It allows us to eliminate boilerplate within Save() and Load() functions.
+// In nearly all cases, when the low-level serialization functions fail, you
+// will want the checkpoint to fail anyways. Plumbing errors through every
+// method doesn't add a lot of value. If there are specific error conditions
+// that you'd like to handle, you should add appropriate functionality to
+// objects themselves prior to calling Save() and Load().
+func safely(fn func()) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if es, ok := r.(*ErrState); ok {
+ err = es // Propagate.
+ return
+ }
+
+ // Build a new state error.
+ es := new(ErrState)
+ if e, ok := r.(error); ok {
+ es.err = e
+ } else {
+ es.err = fmt.Errorf("%v", r)
+ }
+
+ // Make a stack. We don't know how big it will be ahead
+ // of time, but want to make sure we get the whole
+ // thing. So we just do a stupid brute force approach.
+ var stack []byte
+ for sz := 1024; ; sz *= 2 {
+ stack = make([]byte, sz)
+ n := runtime.Stack(stack, false)
+ if n < sz {
+ es.trace = string(stack[:n])
+ break
+ }
+ }
+
+ // Set the error.
+ err = es
+ }
+ }()
+
+ // Execute the function.
+ fn()
+ return nil
+}
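
Note: to make the Sink/Source contract above concrete, here is a minimal hand-written round trip through Save and Load. This is a sketch: in practice the State* methods are generated by stateify, and interface-typed fields additionally require type registration (see the register tests). bytes.Buffer happens to satisfy both wire.Writer and wire.Reader, as the benchmarks later in this change also rely on.

package main

import (
	"bytes"
	"context"
	"fmt"

	"gvisor.dev/gvisor/pkg/state"
)

// point implements state.Type and state.SaverLoader by hand.
type point struct {
	X int64
	Y int64
}

func (p *point) StateTypeName() string { return "example.point" }
func (p *point) StateFields() []string { return []string{"X", "Y"} }

func (p *point) StateSave(m state.Sink) {
	m.Save(0, &p.X)
	m.Save(1, &p.Y)
}

func (p *point) StateLoad(m state.Source) {
	m.Load(0, &p.X)
	m.Load(1, &p.Y)
}

func main() {
	var buf bytes.Buffer
	src := point{X: 1, Y: 2}
	if _, err := state.Save(context.Background(), &buf, &src); err != nil {
		panic(err)
	}
	var dst point
	if _, err := state.Load(context.Background(), &buf, &dst); err != nil {
		panic(err)
	}
	fmt.Println(dst) // {1 2}
}

Because ErrState implements Unwrap, a failed Save or Load can also be inspected with errors.As to recover the embedded stack trace.
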
diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go
new file mode 100644
index 000000000..4281aed6d
--- /dev/null
+++ b/pkg/state/state_norace.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package state
+
+var raceEnabled = false
diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go
new file mode 100644
index 000000000..8232981ce
--- /dev/null
+++ b/pkg/state/state_race.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package state
+
+var raceEnabled = true
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
new file mode 100644
index 000000000..d6c89c7e9
--- /dev/null
+++ b/pkg/state/statefile/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "statefile",
+ srcs = ["statefile.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/compressio",
+ "//pkg/state/wire",
+ ],
+)
+
+go_test(
+ name = "statefile_test",
+ size = "small",
+ srcs = ["statefile_test.go"],
+ library = ":statefile",
+ deps = ["//pkg/compressio"],
+)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
new file mode 100644
index 000000000..bdfb800fb
--- /dev/null
+++ b/pkg/state/statefile/statefile.go
@@ -0,0 +1,239 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statefile defines the state file data stream.
+//
+// This package currently does not include any details regarding the state
+// encoding itself, only details regarding state metadata and data layout.
+//
+// The file format is defined as follows.
+//
+// /------------------------------------------------------\
+// | header (8-bytes) |
+// +------------------------------------------------------+
+// | metadata length (8-bytes) |
+// +------------------------------------------------------+
+// | metadata |
+// +------------------------------------------------------+
+// | data |
+// \------------------------------------------------------/
+//
+// First, it includes an 8-byte magic header, which is the following
+// sequence of bytes: [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46].
+//
+// This header is followed by an 8-byte length N (big endian), and an
+// ASCII-encoded JSON map that is exactly N bytes long.
+//
+// This map includes only strings for keys and strings for values. Keys in the
+// map that begin with "_" are for internal use only. They may be read, but may
+// not be provided by the user. In the future, this metadata may contain some
+// information relating to the state encoding itself.
+//
+// After the map, the remainder of the file is the state data.
+package statefile
+
+import (
+ "bytes"
+ "compress/flate"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/json"
+ "fmt"
+ "hash"
+ "io"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/compressio"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// keySize is the AES-256 key length.
+const keySize = 32
+
+// compressionChunkSize is the chunk size for compression.
+const compressionChunkSize = 1024 * 1024
+
+// maxMetadataSize is the size limit of metadata section.
+const maxMetadataSize = 16 * 1024 * 1024
+
+// magicHeader is the byte sequence beginning each file.
+var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46")
+
+// ErrBadMagic is returned if the header does not match.
+var ErrBadMagic = fmt.Errorf("bad magic header")
+
+// ErrMetadataMissing is returned if the state file is missing mandatory metadata.
+var ErrMetadataMissing = fmt.Errorf("missing metadata")
+
+// ErrInvalidMetadataLength is returned if the metadata length is too large.
+var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize)
+
+// ErrMetadataInvalid is returned if passed metadata is invalid.
+var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+
+// WriteCloser is an io.Closer and wire.Writer.
+type WriteCloser interface {
+ wire.Writer
+ io.Closer
+}
+
+// NewWriter returns a state data writer for a statefile.
+//
+// Note that the returned WriteCloser must be closed.
+func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) {
+ if metadata == nil {
+ metadata = make(map[string]string)
+ }
+ for k := range metadata {
+ if strings.HasPrefix(k, "_") {
+ return nil, ErrMetadataInvalid
+ }
+ }
+
+ // Create our HMAC function.
+ h := hmac.New(sha256.New, key)
+ mw := io.MultiWriter(w, h)
+
+ // First, write the header.
+ if _, err := mw.Write(magicHeader); err != nil {
+ return nil, err
+ }
+
+ // Generate a timestamp, for convenience only.
+ metadata["_timestamp"] = time.Now().UTC().String()
+ defer delete(metadata, "_timestamp")
+
+ // Write the metadata.
+ b, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(b) > maxMetadataSize {
+ return nil, ErrInvalidMetadataLength
+ }
+
+ // Metadata length.
+ if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil {
+ return nil, err
+ }
+ // Metadata bytes; io.MultiWriter will return a short write error if
+ // any of the writers returns < n.
+ if _, err := mw.Write(b); err != nil {
+ return nil, err
+ }
+ // Write the current hash.
+ cur := h.Sum(nil)
+ for done := 0; done < len(cur); {
+ n, err := mw.Write(cur[done:])
+ done += n
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Wrap in compression. We always use "best speed" mode here. When using
+ // "best compression" mode, there is usually only a small gain in file
+ // size reduction, which translates to an even smaller gain in restore
+ // latency, while incurring much more CPU usage at save time.
+ return compressio.NewWriter(w, key, compressionChunkSize, flate.BestSpeed)
+}
+
+// MetadataUnsafe reads out the metadata from a state file without verifying any
+// HMAC. This function shouldn't be called for untrusted input files.
+func MetadataUnsafe(r io.Reader) (map[string]string, error) {
+ return metadata(r, nil)
+}
+
+// metadata validates the magic header and reads out the metadata from a state
+// data stream.
+func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
+ if h != nil {
+ r = io.TeeReader(r, h)
+ }
+
+ // Read and validate magic header.
+ b := make([]byte, len(magicHeader))
+ if _, err := io.ReadFull(r, b); err != nil {
+ return nil, err
+ }
+ if !bytes.Equal(b, magicHeader) {
+ return nil, ErrBadMagic
+ }
+
+ // Read and validate metadata.
+ b, err := func() (b []byte, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ b = nil
+ err = fmt.Errorf("%v", r)
+ }
+ }()
+
+ metadataLen, err := binary.ReadUint64(r, binary.BigEndian)
+ if err != nil {
+ return nil, err
+ }
+ if metadataLen > maxMetadataSize {
+ return nil, ErrInvalidMetadataLength
+ }
+ b = make([]byte, int(metadataLen))
+ if _, err := io.ReadFull(r, b); err != nil {
+ return nil, err
+ }
+ return b, nil
+ }()
+ if err != nil {
+ return nil, err
+ }
+
+ if h != nil {
+ // Check the hash prior to decoding.
+ cur := h.Sum(nil)
+ buf := make([]byte, len(cur))
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(cur, buf) {
+ return nil, compressio.ErrHashMismatch
+ }
+ }
+
+ // Decode the metadata.
+ metadata := make(map[string]string)
+ if err := json.Unmarshal(b, &metadata); err != nil {
+ return nil, err
+ }
+
+ return metadata, nil
+}
+
+// NewReader returns a reader for a statefile.
+func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) {
+ // Read the metadata with the hash.
+ h := hmac.New(sha256.New, key)
+ metadata, err := metadata(r, h)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Wrap in compression.
+ cr, err := compressio.NewReader(r, key)
+ if err != nil {
+ return nil, nil, err
+ }
+ return cr, metadata, nil
+}
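
Note: a round-trip sketch for the framing above. The key bytes and metadata are made up; any 32-byte key works (see keySize):

package main

import (
	"bytes"
	"fmt"
	"io/ioutil"

	"gvisor.dev/gvisor/pkg/state/statefile"
)

func main() {
	key := []byte("0123456789abcdef0123456789abcdef") // 32 bytes.
	var buf bytes.Buffer

	w, err := statefile.NewWriter(&buf, key, map[string]string{"app": "demo"})
	if err != nil {
		panic(err)
	}
	if _, err := w.Write([]byte("hello state")); err != nil {
		panic(err)
	}
	// Close flushes the compressed stream; the file is incomplete without it.
	if err := w.Close(); err != nil {
		panic(err)
	}

	r, md, err := statefile.NewReader(bytes.NewReader(buf.Bytes()), key)
	if err != nil {
		panic(err)
	}
	data, err := ioutil.ReadAll(r)
	if err != nil {
		panic(err)
	}
	fmt.Println(md["app"], string(data)) // demo hello state
}
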
diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go
new file mode 100644
index 000000000..0b470fdec
--- /dev/null
+++ b/pkg/state/statefile/statefile_test.go
@@ -0,0 +1,290 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statefile
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/base64"
+ "io"
+ "math/rand"
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/compressio"
+)
+
+func randomKey() ([]byte, error) {
+ r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize))
+ if _, err := io.ReadFull(crand.Reader, r); err != nil {
+ return nil, err
+ }
+ key := make([]byte, keySize)
+ base64.RawStdEncoding.Encode(key, r)
+ return key, nil
+}
+
+type testCase struct {
+ name string
+ data []byte
+ metadata map[string]string
+}
+
+func TestStatefile(t *testing.T) {
+ rand.Seed(time.Now().Unix())
+
+ cases := []testCase{
+ // Various data sizes.
+ {"nil", nil, nil},
+ {"empty", []byte(""), nil},
+ {"some", []byte("_"), nil},
+ {"one", []byte("0"), nil},
+ {"two", []byte("01"), nil},
+ {"three", []byte("012"), nil},
+ {"four", []byte("0123"), nil},
+ {"five", []byte("01234"), nil},
+ {"six", []byte("012356"), nil},
+ {"seven", []byte("0123567"), nil},
+ {"eight", []byte("01235678"), nil},
+
+ // Make sure we have one longer than the hash length.
+ {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
+
+ // Make sure we have one longer than the chunk size.
+ {"chunks", make([]byte, 3*compressionChunkSize), nil},
+ {"large", make([]byte, 30*compressionChunkSize), nil},
+
+ // Different metadata.
+ {"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
+ {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}},
+ }
+
+ for _, c := range cases {
+ // Generate a key.
+ integrityKey, err := randomKey()
+ if err != nil {
+ t.Errorf("can't generate key: got %v, excepted nil", err)
+ continue
+ }
+
+ t.Run(c.name, func(t *testing.T) {
+ for _, key := range [][]byte{nil, integrityKey} {
+ t.Run("key="+string(key), func(t *testing.T) {
+ // Encoding happens via a buffer.
+ var bufEncoded bytes.Buffer
+ var bufDecoded bytes.Buffer
+
+ // Do all the writing.
+ w, err := NewWriter(&bufEncoded, key, c.metadata)
+ if err != nil {
+ t.Fatalf("error creating writer: got %v, expected nil", err)
+ }
+ if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil {
+ t.Fatalf("error during write: got %v, expected nil", err)
+ }
+
+ // Finish the sum.
+ if err := w.Close(); err != nil {
+ t.Fatalf("error during close: got %v, expected nil", err)
+ }
+
+ t.Logf("original data: %d bytes, encoded: %d bytes.",
+ len(c.data), len(bufEncoded.Bytes()))
+
+ // Do all the reading.
+ r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
+ if err != nil {
+ t.Fatalf("error creating reader: got %v, expected nil", err)
+ }
+ if _, err := io.Copy(&bufDecoded, r); err != nil {
+ t.Fatalf("error during read: got %v, expected nil", err)
+ }
+
+ // Check that the data matches.
+ if !bytes.Equal(c.data, bufDecoded.Bytes()) {
+ t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data))
+ }
+
+ // Check that the metadata matches.
+ for k, v := range c.metadata {
+ nv, ok := metadata[k]
+ if !ok {
+ t.Fatalf("missing metadata: %s", k)
+ }
+ if v != nv {
+ t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v)
+ }
+ }
+
+ // Change the data and verify that it fails.
+ if key != nil {
+ b := append([]byte(nil), bufEncoded.Bytes()...)
+ b[rand.Intn(len(b))]++
+ bufDecoded.Reset()
+ r, _, err = NewReader(bytes.NewReader(b), key)
+ if err == nil {
+ _, err = io.Copy(&bufDecoded, r)
+ }
+ if err == nil {
+ t.Error("got no error: expected error on data corruption")
+ }
+ }
+
+ // Change the key and verify that it fails.
+ newKey := integrityKey
+ if len(key) > 0 {
+ newKey = append([]byte{}, key...)
+ newKey[rand.Intn(len(newKey))]++
+ }
+ bufDecoded.Reset()
+ r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey)
+ if err == nil {
+ _, err = io.Copy(&bufDecoded, r)
+ }
+ if err != compressio.ErrHashMismatch {
+ t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
+ }
+ })
+ }
+ })
+ }
+}
+
+const benchmarkDataSize = 100 * 1024 * 1024
+
+func benchmark(b *testing.B, size int, write bool, compressible bool) {
+ b.StopTimer()
+ b.SetBytes(benchmarkDataSize)
+
+ // Generate source data.
+ var source []byte
+ if compressible {
+ // For compressible data, we use essentially all zeros.
+ source = make([]byte, benchmarkDataSize)
+ } else {
+ // For non-compressible data, we use random base64 data, which
+ // is only marginally compressible (to roughly 75% of its size).
+ var sourceBuf bytes.Buffer
+ bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf)
+ bufR := rand.New(rand.NewSource(0))
+ if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil {
+ b.Fatalf("unable to seed random data: %v", err)
+ }
+ source = sourceBuf.Bytes()
+ }
+
+ // Generate a random key for integrity check.
+ key, err := randomKey()
+ if err != nil {
+ b.Fatalf("error generating key: %v", err)
+ }
+
+ // Define our benchmark functions. Prior to running the readState
+ // function here, you must execute the writeState function at least
+ // once (done below).
+ var stateBuf bytes.Buffer
+ writeState := func() {
+ stateBuf.Reset()
+ w, err := NewWriter(&stateBuf, key, nil)
+ if err != nil {
+ b.Fatalf("error creating writer: %v", err)
+ }
+ for done := 0; done < len(source); {
+ chunk := size // limit size.
+ if done+chunk > len(source) {
+ chunk = len(source) - done
+ }
+ n, err := w.Write(source[done : done+chunk])
+ done += n
+ if n == 0 && err != nil {
+ b.Fatalf("error during write: %v", err)
+ }
+ }
+ if err := w.Close(); err != nil {
+ b.Fatalf("error closing writer: %v", err)
+ }
+ }
+ readState := func() {
+ tmpBuf := bytes.NewBuffer(stateBuf.Bytes())
+ r, _, err := NewReader(tmpBuf, key)
+ if err != nil {
+ b.Fatalf("error creating reader: %v", err)
+ }
+ for done := 0; done < len(source); {
+ chunk := size // limit size.
+ if done+chunk > len(source) {
+ chunk = len(source) - done
+ }
+ n, err := r.Read(source[done : done+chunk])
+ done += n
+ if n == 0 && err != nil {
+ b.Fatalf("error during read: %v", err)
+ }
+ }
+ }
+ // Generate the state once without timing to ensure that buffers have
+ // been appropriately allocated.
+ writeState()
+ if write {
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ writeState()
+ }
+ b.StopTimer()
+ } else {
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ readState()
+ }
+ b.StopTimer()
+ }
+}
+
+func BenchmarkWrite4KCompressible(b *testing.B) {
+ benchmark(b, 4096, true, true)
+}
+
+func BenchmarkWrite4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, true, false)
+}
+
+func BenchmarkWrite1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, true)
+}
+
+func BenchmarkWrite1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, false)
+}
+
+func BenchmarkRead4KCompressible(b *testing.B) {
+ benchmark(b, 4096, false, true)
+}
+
+func BenchmarkRead4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, false, false)
+}
+
+func BenchmarkRead1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, true)
+}
+
+func BenchmarkRead1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, false)
+}
+
+func init() {
+ runtime.GOMAXPROCS(runtime.NumCPU())
+}
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
new file mode 100644
index 000000000..eaec664a1
--- /dev/null
+++ b/pkg/state/stats.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "time"
+)
+
+type statEntry struct {
+ count uint
+ total time.Duration
+}
+
+// Stats tracks encode / decode timing.
+//
+// This currently provides a meaningful String function and no other way to
+// extract stats about individual types.
+//
+// All exported receivers accept nil.
+type Stats struct {
+ // byType contains a breakdown of time spent by type.
+ //
+ // This is indexed *directly* by typeID, including zero.
+ byType []statEntry
+
+ // stack contains objects in progress.
+ stack []typeID
+
+ // names contains type names.
+ //
+ // This is also indexed *directly* by typeID, including zero, which we
+ // hard-code as "state.default". This is only resolved by calling fini
+ // on the stats object.
+ names []string
+
+ // last is the last start time.
+ last time.Time
+}
+
+// init initializes statistics.
+func (s *Stats) init() {
+ s.last = time.Now()
+ s.stack = append(s.stack, 0)
+}
+
+// fini finalizes statistics.
+func (s *Stats) fini(resolve func(id typeID) string) {
+ s.done()
+
+ // Resolve all type names.
+ s.names = make([]string, len(s.byType))
+ s.names[0] = "state.default" // See above.
+ for id := typeID(1); int(id) < len(s.names); id++ {
+ s.names[id] = resolve(id)
+ }
+}
+
+// sample charges the time elapsed since the last sample point to the given type.
+func (s *Stats) sample(id typeID) {
+ now := time.Now()
+ if len(s.byType) <= int(id) {
+ // Allocate all the missing entries in one fell swoop.
+ s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...)
+ }
+ s.byType[id].total += now.Sub(s.last)
+ s.last = now
+}
+
+// start starts a sample.
+func (s *Stats) start(id typeID) {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.stack = append(s.stack, id)
+}
+
+// done finishes the current sample.
+func (s *Stats) done() {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.byType[last].count++
+ s.stack = s.stack[:len(s.stack)-1]
+}
+
+type sliceEntry struct {
+ name string
+ entry *statEntry
+}
+
+// String returns a table representation of the stats.
+func (s *Stats) String() string {
+ // Build a list of stat entries.
+ ss := make([]sliceEntry, 0, len(s.byType))
+ for id := 0; id < len(s.names); id++ {
+ ss = append(ss, sliceEntry{
+ name: s.names[id],
+ entry: &s.byType[id],
+ })
+ }
+
+ // Sort by total time (descending).
+ sort.Slice(ss, func(i, j int) bool {
+ return ss[i].entry.total > ss[j].entry.total
+ })
+
+ // Print the stat results.
+ var (
+ buf bytes.Buffer
+ count uint
+ total time.Duration
+ )
+ buf.WriteString("\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type"))
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ for _, se := range ss {
+ if se.entry.count == 0 {
+ // Since we store all types linearly, we are not
+ // guaranteed that any entry actually has time.
+ continue
+ }
+ count += se.entry.count
+ total += se.entry.total
+ per := se.entry.total / time.Duration(se.entry.count)
+ buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n",
+ se.entry.total, se.entry.count, per, se.name))
+ }
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]",
+ total, count, total/time.Duration(count)))
+ return buf.String()
+}
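
Note: since Save and Load (in state.go above) return a Stats value, callers can surface this table directly. A small sketch; the wrapper name is illustrative:

package example

import (
	"context"
	"log"

	"gvisor.dev/gvisor/pkg/state"
	"gvisor.dev/gvisor/pkg/state/wire"
)

// saveWithStats saves root and logs the per-type timing table. Stats.String
// begins with a newline, so the table starts on its own line in the log.
func saveWithStats(ctx context.Context, w wire.Writer, root interface{}) error {
	stats, err := state.Save(ctx, w, root)
	if err != nil {
		return err
	}
	log.Printf("encode stats:%s", stats.String())
	return nil
}
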
diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD
new file mode 100644
index 000000000..9297cafbe
--- /dev/null
+++ b/pkg/state/tests/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tests",
+ srcs = [
+ "array.go",
+ "bench.go",
+ "integer.go",
+ "load.go",
+ "map.go",
+ "register.go",
+ "struct.go",
+ "tests.go",
+ ],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/pretty",
+ ],
+)
+
+go_test(
+ name = "tests_test",
+ size = "small",
+ srcs = [
+ "array_test.go",
+ "bench_test.go",
+ "bool_test.go",
+ "float_test.go",
+ "integer_test.go",
+ "load_test.go",
+ "map_test.go",
+ "register_test.go",
+ "string_test.go",
+ "struct_test.go",
+ ],
+ library = ":tests",
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go
new file mode 100644
index 000000000..0972a80e7
--- /dev/null
+++ b/pkg/state/tests/array.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type arrayContainer struct {
+ v [1]interface{}
+}
+
+// +stateify savable
+type arrayPtrContainer struct {
+ v *[1]interface{}
+}
+
+// +stateify savable
+type sliceContainer struct {
+ v []interface{}
+}
+
+// +stateify savable
+type slicePtrContainer struct {
+ v *[]interface{}
+}
diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go
new file mode 100644
index 000000000..a347b2947
--- /dev/null
+++ b/pkg/state/tests/array_test.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allArrayPrimitives = []interface{}{
+ [1]bool{},
+ [1]bool{true},
+ [2]bool{false, true},
+ [1]int{},
+ [1]int{1},
+ [2]int{0, 1},
+ [1]int8{},
+ [1]int8{1},
+ [2]int8{0, 1},
+ [1]int16{},
+ [1]int16{1},
+ [2]int16{0, 1},
+ [1]int32{},
+ [1]int32{1},
+ [2]int32{0, 1},
+ [1]int64{},
+ [1]int64{1},
+ [2]int64{0, 1},
+ [1]uint{},
+ [1]uint{1},
+ [2]uint{0, 1},
+ [1]uintptr{},
+ [1]uintptr{1},
+ [2]uintptr{0, 1},
+ [1]uint8{},
+ [1]uint8{1},
+ [2]uint8{0, 1},
+ [1]uint16{},
+ [1]uint16{1},
+ [2]uint16{0, 1},
+ [1]uint32{},
+ [1]uint32{1},
+ [2]uint32{0, 1},
+ [1]uint64{},
+ [1]uint64{1},
+ [2]uint64{0, 1},
+ [1]string{},
+ [1]string{""},
+ [1]string{nonEmptyString},
+ [2]string{"", nonEmptyString},
+}
+
+func TestArrayPrimitives(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allArrayPrimitives))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives))))
+}
+
+func TestSlices(t *testing.T) {
+ var allSlices = flatten(
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ return v.Slice(0, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the pure "nil" value for the slice.
+ return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true
+ }
+ return v.Slice(1, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the zero-valued slice.
+ return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true
+ }
+ return v.Slice(0, v.Len()-1).Interface(), true
+ }),
+ )
+ runTestCases(t, false, "plain", allSlices)
+ runTestCases(t, false, "pointers", pointersTo(allSlices))
+ runTestCases(t, false, "interfaces", interfacesTo(allSlices))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices)))
+}
+
+func TestArrayContainers(t *testing.T) {
+ var (
+ emptyArray [1]interface{}
+ fullArray [1]interface{}
+ )
+ fullArray[0] = &emptyArray
+ runTestCases(t, false, "", []interface{}{
+ arrayContainer{v: emptyArray},
+ arrayContainer{v: fullArray},
+ arrayPtrContainer{v: nil},
+ arrayPtrContainer{v: &emptyArray},
+ arrayPtrContainer{v: &fullArray},
+ })
+}
+
+func TestSliceContainers(t *testing.T) {
+ var (
+ nilSlice []interface{}
+ emptySlice = make([]interface{}, 0)
+ fullSlice = []interface{}{nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ sliceContainer{v: nilSlice},
+ sliceContainer{v: emptySlice},
+ sliceContainer{v: fullSlice},
+ slicePtrContainer{v: nil},
+ slicePtrContainer{v: &nilSlice},
+ slicePtrContainer{v: &emptySlice},
+ slicePtrContainer{v: &fullSlice},
+ })
+}
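
Note: the helpers used throughout these tests (flatten, pointersTo, interfacesTo, runTestCases) live in tests.go, which is not part of this hunk. For orientation only, pointersTo plausibly looks something like the sketch below; the real implementation may differ:

// pointersTo returns a new slice in which each element is a pointer to a
// copy of the corresponding input element (sketch; within package tests,
// which already imports reflect).
func pointersToSketch(v []interface{}) []interface{} {
	out := make([]interface{}, 0, len(v))
	for _, x := range v {
		p := reflect.New(reflect.TypeOf(x))
		p.Elem().Set(reflect.ValueOf(x))
		out = append(out, p.Interface())
	}
	return out
}
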
diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go
new file mode 100644
index 000000000..40869cdfb
--- /dev/null
+++ b/pkg/state/tests/bench.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type benchStruct struct {
+ B *benchStruct // Must be exported for gob.
+}
+
+func (b *benchStruct) afterLoad() {
+ // Do nothing, just force scheduling.
+}
diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go
new file mode 100644
index 000000000..7e102c907
--- /dev/null
+++ b/pkg/state/tests/bench_test.go
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "bytes"
+ "context"
+ "encoding/gob"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// buildPtrObject builds a benchmark object.
+func buildPtrObject(n int) interface{} {
+ b := new(benchStruct)
+ for i := 0; i < n; i++ {
+ b = &benchStruct{B: b}
+ }
+ return b
+}
+
+// buildMapObject builds a benchmark object.
+func buildMapObject(n int) interface{} {
+ b := new(benchStruct)
+ m := make(map[int]*benchStruct)
+ for i := 0; i < n; i++ {
+ m[i] = b
+ }
+ return &m
+}
+
+// buildSliceObject builds a benchmark object.
+func buildSliceObject(n int) interface{} {
+ b := new(benchStruct)
+ s := make([]*benchStruct, 0, n)
+ for i := 0; i < n; i++ {
+ s = append(s, b)
+ }
+ return &s
+}
+
+var allObjects = map[string]struct {
+ New func(int) interface{}
+}{
+ "ptr": {
+ New: buildPtrObject,
+ },
+ "map": {
+ New: buildMapObject,
+ },
+ "slice": {
+ New: buildSliceObject,
+ },
+}
+
+func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
+ // maxSize is the maximum size of an individual object below. For an N
+ // larger than this, we start to return multiple objects.
+ const maxSize = 1024
+ if n <= maxSize {
+ return 1, fn(n)
+ }
+ iters = (n + maxSize - 1) / maxSize
+ return iters, fn(maxSize)
+}
+
+// gobSave is a version of save using gob (no stats available).
+func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
+ enc := gob.NewEncoder(w)
+ err = enc.Encode(v)
+ return
+}
+
+// gobLoad is a version of load using gob (no stats available).
+func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
+ dec := gob.NewDecoder(r)
+ err = dec.Decode(v)
+ return
+}
+
+var allAlgos = map[string]struct {
+ Save func(context.Context, wire.Writer, interface{}) (state.Stats, error)
+ Load func(context.Context, wire.Reader, interface{}) (state.Stats, error)
+ MaxPtr int
+}{
+ "state": {
+ Save: state.Save,
+ Load: state.Load,
+ },
+ "gob": {
+ Save: gobSave,
+ Load: gobLoad,
+ },
+}
+
+func BenchmarkEncoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ b.ReportAllocs()
+ b.StartTimer()
+ for i := 0; i < n; i++ {
+ if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
+
+func BenchmarkDecoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ buf := new(bytes.Buffer)
+ if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ b.ReportAllocs()
+ b.StartTimer()
+ var r bytes.Reader
+ for i := 0; i < n; i++ {
+ r.Reset(buf.Bytes())
+ if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
+ b.Errorf("load failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
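
Note on running these: buildObjects caps each object graph at 1024 nodes and scales the iteration count instead, so large b.N values do not produce pathologically deep structures. The comparison against gob can be run with the standard tooling, e.g. go test -bench=. -benchmem in this directory (or the equivalent bazel go_test invocation).
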
diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go
new file mode 100644
index 000000000..e17cfacf9
--- /dev/null
+++ b/pkg/state/tests/bool_test.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+var allBools = []bool{
+ true,
+ false,
+}
+
+func TestBool(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allBools))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allBools)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools))))
+}
diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go
new file mode 100644
index 000000000..3e89edd9c
--- /dev/null
+++ b/pkg/state/tests/float_test.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var safeFloat32s = []float32{
+ float32(0.0),
+ float32(1.0),
+ float32(-1.0),
+ float32(math.Inf(1)),
+ float32(math.Inf(-1)),
+}
+
+var allFloat32s = append(safeFloat32s, float32(math.NaN()))
+
+var safeFloat64s = []float64{
+ float64(0.0),
+ float64(1.0),
+ float64(-1.0),
+ math.Inf(1),
+ math.Inf(-1),
+}
+
+var allFloat64s = append(safeFloat64s, math.NaN())
+
+func TestFloat(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allFloat32s,
+ allFloat64s,
+ ))
+ // See checkEqual for why NaNs are missing.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ ))))
+}
+
+const onlyDouble float64 = 1.0000000000000002
+
+func TestFloatTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingFloat32{save: onlyDouble},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingFloat32{save: 1.0},
+ })
+}
+
+var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+func TestComplex(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allComplex64s,
+ allComplex128s,
+ ))
+ // See TestFloat; same issue.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacse", interfacesTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ ))))
+}
+
+func TestComplexTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
+ truncatingComplex64{save: complex(1.0, onlyDouble)},
+ truncatingComplex64{save: complex(onlyDouble, 1.0)},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingComplex64{save: complex(1.0, 1.0)},
+ })
+}
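
Note: the reason the pointer and interface cases above stick to the "safe" lists is the usual IEEE-754 wrinkle: NaN never compares equal to itself, so an equality-based round-trip check would fail spuriously. A two-line illustration:

package main

import (
	"fmt"
	"math"
	"reflect"
)

func main() {
	nan := math.NaN()
	fmt.Println(nan == nan)                  // false
	fmt.Println(reflect.DeepEqual(nan, nan)) // false
}
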
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
new file mode 100644
index 000000000..ca403eed1
--- /dev/null
+++ b/pkg/state/tests/integer.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// +stateify type
+type truncatingUint8 struct {
+ save uint64
+ load uint8 `state:"nosave"`
+}
+
+func (t *truncatingUint8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint8)(nil)
+
+// +stateify type
+type truncatingUint16 struct {
+ save uint64
+ load uint16 `state:"nosave"`
+}
+
+func (t *truncatingUint16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint16)(nil)
+
+// +stateify type
+type truncatingUint32 struct {
+ save uint64
+ load uint32 `state:"nosave"`
+}
+
+func (t *truncatingUint32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint32)(nil)
+
+// +stateify type
+type truncatingInt8 struct {
+ save int64
+ load int8 `state:"nosave"`
+}
+
+func (t *truncatingInt8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt8)(nil)
+
+// +stateify type
+type truncatingInt16 struct {
+ save int64
+ load int16 `state:"nosave"`
+}
+
+func (t *truncatingInt16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt16)(nil)
+
+// +stateify type
+type truncatingInt32 struct {
+ save int64
+ load int32 `state:"nosave"`
+}
+
+func (t *truncatingInt32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt32)(nil)
+
+// +stateify type
+type truncatingFloat32 struct {
+ save float64
+ load float32 `state:"nosave"`
+}
+
+func (t *truncatingFloat32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingFloat32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = float64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingFloat32)(nil)
+
+// +stateify type
+type truncatingComplex64 struct {
+ save complex128
+ load complex64 `state:"nosave"`
+}
+
+func (t *truncatingComplex64) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingComplex64) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = complex128(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingComplex64)(nil)
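
Note: each truncating type above saves slot 0 through its wide field and loads the same slot back through the narrow `state:"nosave"` field, exercising the decoder's range checks: a saved value that does not fit the narrower load type is expected to make the load fail, while in-range values round-trip. A sketch of the two cases (within package tests, assuming math is imported):

func truncationExamples() (bad, good *truncatingUint8) {
	// 256 fits the uint64 save field but not the uint8 load field,
	// so loading bad is expected to fail.
	bad = &truncatingUint8{save: math.MaxUint8 + 1}
	// 1 fits in a uint8, so good is expected to load cleanly.
	good = &truncatingUint8{save: 1}
	return bad, good
}
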
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
new file mode 100644
index 000000000..d3931c952
--- /dev/null
+++ b/pkg/state/tests/integer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var (
+ allIntTs = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allUintTs = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
+)
+
+var allInts = flatten(
+ allIntTs,
+ allInt8s,
+ allInt16s,
+ allInt32s,
+ allInt64s,
+)
+
+var allUints = flatten(
+ allUintTs,
+ allUintptrs,
+ allUint8s,
+ allUint16s,
+ allUint32s,
+ allUint64s,
+)
+
+func TestInt(t *testing.T) {
+ runTestCases(t, false, "plain", allInts)
+ runTestCases(t, false, "pointers", pointersTo(allInts))
+ runTestCases(t, false, "interfaces", interfacesTo(allInts))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts)))
+}
+
+func TestIntTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingInt8{save: math.MinInt8 - 1},
+ truncatingInt16{save: math.MinInt16 - 1},
+ truncatingInt32{save: math.MinInt32 - 1},
+ truncatingInt8{save: math.MaxInt8 + 1},
+ truncatingInt16{save: math.MaxInt16 + 1},
+ truncatingInt32{save: math.MaxInt32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingInt8{save: 1},
+ truncatingInt16{save: 1},
+ truncatingInt32{save: 1},
+ })
+}
+
+func TestUint(t *testing.T) {
+ runTestCases(t, false, "plain", allUints)
+ runTestCases(t, false, "pointers", pointersTo(allUints))
+ runTestCases(t, false, "interfaces", interfacesTo(allUints))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints)))
+}
+
+func TestUintTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingUint8{save: math.MaxUint8 + 1},
+ truncatingUint16{save: math.MaxUint16 + 1},
+ truncatingUint32{save: math.MaxUint32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingUint8{save: 1},
+ truncatingUint16{save: 1},
+ truncatingUint32{save: 1},
+ })
+}
diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go
new file mode 100644
index 000000000..a8350c0f3
--- /dev/null
+++ b/pkg/state/tests/load.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type genericContainer struct {
+ v interface{}
+}
+
+// +stateify savable
+type afterLoadStruct struct {
+ v int `state:"nosave"`
+}
+
+func (a *afterLoadStruct) afterLoad() {
+ a.v++
+}
+
+// +stateify savable
+type valueLoadStruct struct {
+ v int `state:".(int64)"`
+}
+
+func (v *valueLoadStruct) saveV() int64 {
+ return int64(v.v) // Save as int64.
+}
+
+func (v *valueLoadStruct) loadV(value int64) {
+ v.v = int(value) // Load as int.
+}
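+
+// The `state:".(int64)"` tag above binds the field to save/load methods named
+// after it (saveV/loadV for field v). As a minimal illustrative sketch of the
+// same convention under a different field name (this type is hypothetical and
+// not exercised by any test), a field named "count" would pair with
+// saveCount/loadCount:
+
+// +stateify savable
+type valueLoadCount struct {
+ count uint `state:".(uint64)"`
+}
+
+func (v *valueLoadCount) saveCount() uint64 {
+ return uint64(v.count) // Save as uint64.
+}
+
+func (v *valueLoadCount) loadCount(value uint64) {
+ v.count = uint(value) // Load as uint.
+}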
+
+// +stateify savable
+type cycleStruct struct {
+ c *cycleStruct
+}
+
+// +stateify savable
+type badCycleStruct struct {
+ b *badCycleStruct `state:"wait"`
+}
+
+func (b *badCycleStruct) afterLoad() {
+ if b.b != b {
+ // This is not executable, since AfterLoad requires that the
+ // object and all dependencies are complete. This should cause
+ // a deadlock error during load.
+ panic("badCycleStruct.afterLoad called")
+ }
+}
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
new file mode 100644
index 000000000..1e9794296
--- /dev/null
+++ b/pkg/state/tests/load_test.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+func TestLoadHooks(t *testing.T) {
+ runTestCases(t, false, "load-hooks", []interface{}{
+ &afterLoadStruct{v: 1},
+ &valueLoadStruct{v: 1},
+ &genericContainer{v: &afterLoadStruct{v: 1}},
+ &genericContainer{v: &valueLoadStruct{v: 1}},
+ &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+ &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+ })
+}
+
+func TestCycles(t *testing.T) {
+ // cs is a single object cycle.
+ cs := cycleStruct{nil}
+ cs.c = &cs
+
+ // cs1 and cs2 are in a two object cycle.
+ cs1 := cycleStruct{nil}
+ cs2 := cycleStruct{nil}
+ cs1.c = &cs2
+ cs2.c = &cs1
+
+ runTestCases(t, false, "cycles", []interface{}{
+ cs,
+ cs1,
+ })
+}
+
+func TestDeadlock(t *testing.T) {
+ // bs is a single object cycle. This does not cause deadlock because an
+ // object cannot wait for itself.
+ bs := badCycleStruct{nil}
+ bs.b = &bs
+
+ runTestCases(t, false, "self", []interface{}{
+ &bs,
+ })
+
+ // bs1 and bs2 are in a deadlocking cycle.
+ bs1 := badCycleStruct{nil}
+ bs2 := badCycleStruct{nil}
+ bs1.b = &bs2
+ bs2.b = &bs1
+
+ runTestCases(t, true, "deadlock", []interface{}{
+ &bs1,
+ })
+}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
new file mode 100644
index 000000000..db4e548f1
--- /dev/null
+++ b/pkg/state/tests/map.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type mapContainer struct {
+ v map[int]interface{}
+}
+
+// +stateify savable
+type mapPtrContainer struct {
+ v *map[int]interface{}
+}
+
+// +stateify savable
+type registeredMapStruct struct{}
diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go
new file mode 100644
index 000000000..92bf0fc01
--- /dev/null
+++ b/pkg/state/tests/map_test.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allMapPrimitives = []interface{}{
+ bool(true),
+ int(1),
+ int8(1),
+ int16(1),
+ int32(1),
+ int64(1),
+ uint(1),
+ uintptr(1),
+ uint8(1),
+ uint16(1),
+ uint32(1),
+ uint64(1),
+ string(""),
+ registeredMapStruct{},
+}
+
+var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives))
+
+var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives))
+
+var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+func TestMapAliasing(t *testing.T) {
+ v := make(map[int]int)
+ ptrToV := &v
+ aliases := []map[int]int{v, v}
+ runTestCases(t, false, "", []interface{}{ptrToV, aliases})
+}
+
+func TestMapsEmpty(t *testing.T) {
+ runTestCases(t, false, "plain", emptyMaps)
+ runTestCases(t, false, "pointers", pointersTo(emptyMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(emptyMaps))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps)))
+}
+
+func TestMapsFull(t *testing.T) {
+ runTestCases(t, false, "plain", fullMaps)
+ runTestCases(t, false, "pointers", pointersTo(fullMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(fullMaps))
+ runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps)))
+}
+
+func TestMapContainers(t *testing.T) {
+ var (
+ nilMap map[int]interface{}
+ emptyMap = make(map[int]interface{})
+ fullMap = map[int]interface{}{0: nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ mapContainer{v: nilMap},
+ mapContainer{v: emptyMap},
+ mapContainer{v: fullMap},
+ mapPtrContainer{v: nil},
+ mapPtrContainer{v: &nilMap},
+ mapPtrContainer{v: &emptyMap},
+ mapPtrContainer{v: &fullMap},
+ })
+}
diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go
new file mode 100644
index 000000000..074d86315
--- /dev/null
+++ b/pkg/state/tests/register.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type alreadyRegisteredStruct struct{}
+
+// +stateify savable
+type alreadyRegisteredOther int
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
new file mode 100644
index 000000000..c829753cc
--- /dev/null
+++ b/pkg/state/tests/register_test.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// faker calls itself whatever is in the name field.
+type faker struct {
+ Name string
+ Fields []string
+}
+
+func (f *faker) StateTypeName() string {
+ return f.Name
+}
+
+func (f *faker) StateFields() []string {
+ return f.Fields
+}
+
+// fakerWithSaverLoader has all it needs.
+type fakerWithSaverLoader struct {
+ faker
+}
+
+func (f *fakerWithSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerWithSaverLoader) StateLoad(m state.Source) {}
+
+// fakerOther names itself after its own string value.
+type fakerOther string
+
+func (f *fakerOther) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOther) StateFields() []string {
+ return nil
+}
+
+func newFakerOther(name string) *fakerOther {
+ f := fakerOther(name)
+ return &f
+}
+
+// fakerOtherBadFields returns non-nil fields.
+type fakerOtherBadFields string
+
+func (f *fakerOtherBadFields) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherBadFields) StateFields() []string {
+ return []string{string(*f)}
+}
+
+func newFakerOtherBadFields(name string) *fakerOtherBadFields {
+ f := fakerOtherBadFields(name)
+ return &f
+}
+
+// fakerOtherSaverLoader implements SaverLoader methods.
+type fakerOtherSaverLoader string
+
+func (f *fakerOtherSaverLoader) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherSaverLoader) StateFields() []string {
+ return nil
+}
+
+func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {}
+
+func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader {
+ f := fakerOtherSaverLoader(name)
+ return &f
+}
+
+func TestRegisterPrimitives(t *testing.T) {
+ for _, typeName := range []string{
+ "int",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint",
+ "uintptr",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ "string",
+ } {
+ t.Run("struct/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(&faker{
+ Name: typeName,
+ })
+ })
+ t.Run("other/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(newFakerOther(typeName))
+ })
+ }
+}
+
+func TestRegisterBad(t *testing.T) {
+ const (
+ goodName = "foo"
+ firstField = "a"
+ secondField = "b"
+ )
+ for name, object := range map[string]state.Type{
+ "non-struct-with-fields": newFakerOtherBadFields(goodName),
+ "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName),
+ "struct-without-saverloader": &faker{Name: goodName},
+ "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()),
+ "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()),
+ "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}},
+ "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}},
+ "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}},
+ "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}},
+ "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}},
+ "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}},
+ } {
+ t.Run(name, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering object %#v did not panic", object)
+ }
+ }()
+ state.Register(object)
+ })
+
+ }
+}
diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go
new file mode 100644
index 000000000..44f5a562c
--- /dev/null
+++ b/pkg/state/tests/string_test.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+const nonEmptyString = "hello world"
+
+var allStrings = []string{
+ "",
+ nonEmptyString,
+ "\\0",
+}
+
+func TestString(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allStrings))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allStrings)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings))))
+}
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
new file mode 100644
index 000000000..bd2c2b399
--- /dev/null
+++ b/pkg/state/tests/struct.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type unregisteredEmptyStruct struct{}
+
+// typeOnlyEmptyStruct just implements the state.Type interface.
+type typeOnlyEmptyStruct struct{}
+
+func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" }
+
+func (*typeOnlyEmptyStruct) StateFields() []string { return nil }
+
+// +stateify savable
+type savableEmptyStruct struct{}
+
+// +stateify savable
+type emptyStructPointer struct {
+ nothing *struct{}
+}
+
+// +stateify savable
+type outerSame struct {
+ inner inner
+}
+
+// +stateify savable
+type outerFieldFirst struct {
+ inner inner
+ v int64
+}
+
+// +stateify savable
+type outerFieldSecond struct {
+ v int64
+ inner inner
+}
+
+// +stateify savable
+type outerArray struct {
+ inner [2]inner
+}
+
+// +stateify savable
+type inner struct {
+ v int64
+}
+
+// +stateify savable
+type system struct {
+ v1 interface{}
+ v2 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
new file mode 100644
index 000000000..de9d17aa7
--- /dev/null
+++ b/pkg/state/tests/struct_test.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func TestEmptyStruct(t *testing.T) {
+ runTestCases(t, false, "plain", []interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ })
+ runTestCases(t, false, "pointers", pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{
+ // Only registered types can be dispatched via interfaces. All
+ // other types should fail, even if it is the empty struct.
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
+ savableEmptyStruct{},
+ })))
+ runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ })))
+
+ // Ensuring empty struct aliasing works.
+ es := emptyStructPointer{new(struct{})}
+ runTestCases(t, false, "empty-struct-pointers", []interface{}{
+ emptyStructPointer{},
+ es,
+ []emptyStructPointer{es, es}, // Same pointer.
+ })
+}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
+
+func TestEmbeddedPointers(t *testing.T) {
+ var (
+ ofs outerSame
+ of1 outerFieldFirst
+ of2 outerFieldSecond
+ oa outerArray
+ )
+
+ runTestCases(t, false, "embedded-pointers", []interface{}{
+ system{&ofs, &ofs.inner},
+ system{&ofs.inner, &ofs},
+ system{&of1, &of1.inner},
+ system{&of1.inner, &of1},
+ system{&of2, &of2.inner},
+ system{&of2.inner, &of2},
+ system{&oa, &oa.inner[0]},
+ system{&oa, &oa.inner[1]},
+ system{&oa.inner[0], &oa},
+ system{&oa.inner[1], &oa},
+ })
+}
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
new file mode 100644
index 000000000..435a0e9db
--- /dev/null
+++ b/pkg/state/tests/tests.go
@@ -0,0 +1,215 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tests tests the state package and its subpackages.
+package tests
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+)
+
+// discard is an implementation of wire.Writer.
+type discard struct{}
+
+// Write implements wire.Writer.Write.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte implements wire.Writer.WriteByte.
+func (discard) WriteByte(byte) error { return nil }
+
+// checkEqual checks if two objects are equal.
+//
+// N.B. This only handles one level of dereference for NaN values. Handling
+// NaN nested inside composite types would require forking reflect.DeepEqual.
+func checkEqual(root, loadedValue interface{}) bool {
+ if reflect.DeepEqual(root, loadedValue) {
+ return true
+ }
+
+ // NaN is not equal to itself. We handle the case of raw floating point
+ // primitives here, but don't handle this case nested.
+ rf32, ok1 := root.(float32)
+ lf32, ok2 := loadedValue.(float32)
+ if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
+ return true
+ }
+ rf64, ok1 := root.(float64)
+ lf64, ok2 := loadedValue.(float64)
+ if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
+ return true
+ }
+
+ // Same handling for complex numbers: compare the components so that
+ // NaN parts are treated like the float cases above.
+ rc64, ok1 := root.(complex64)
+ lc64, ok2 := loadedValue.(complex64)
+ if ok1 && ok2 {
+ return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
+ }
+ rc128, ok1 := root.(complex128)
+ lc128, ok2 := loadedValue.(complex128)
+ if ok1 && ok2 {
+ return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
+ }
+
+ return false
+}
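+
+// For example (values, not code): math.NaN() != math.NaN() under ==, and
+// reflect.DeepEqual likewise reports false for two NaNs, so a NaN that
+// round-trips correctly would spuriously fail comparison without the special
+// cases above.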
+
+// runTestCases runs a test for each object in objects.
+func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
+ t.Helper()
+ for i, root := range objects {
+ t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
+ t.Logf("Original object:\n%#v", root)
+
+ // Save the passed object.
+ saveBuffer := &bytes.Buffer{}
+ saveObjectPtr := reflect.New(reflect.TypeOf(root))
+ saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+ saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Save failed unexpectedly: %v", err)
+ }
+
+ // Dump the serialized proto to aid with debugging.
+ var ppBuf bytes.Buffer
+ t.Logf("Raw state:\n%v", saveBuffer.Bytes())
+ if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // We don't count this as a test failure if we
+ // have shouldFail set, but we will count as a
+ // failure if we were not expecting to fail.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
+ }
+ }
+ if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // See above.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
+ }
+ }
+ t.Logf("Encoded state:\n%s", ppBuf.String())
+ t.Logf("Save stats:\n%s", saveStats.String())
+
+ // Load a new copy of the object.
+ loadObjectPtr := reflect.New(reflect.TypeOf(root))
+ loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Load failed unexpectedly: %v", err)
+ }
+
+ // Compare the values.
+ loadedValue := loadObjectPtr.Elem().Interface()
+ if !checkEqual(root, loadedValue) {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
+ }
+
+ // Everything went okay. That is itself an error if the test
+ // was expected to fail.
+ if shouldFail {
+ t.Fatalf("This test was expected to fail, but didn't.")
+ }
+ t.Logf("Load stats:\n%s", loadStats.String())
+
+ // Truncate half the bytes in the byte stream,
+ // and ensure that we can't restore. Then
+ // truncate only the final byte and ensure that
+ // we can't restore.
+ l := saveBuffer.Len()
+ halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
+ if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with half bytes succeeded unexpectedly.")
+ }
+ missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
+ if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with missing byte succeeded unexpectedly.")
+ }
+ })
+ }
+}
+
+// convert converts the slice to an []interface{}.
+func convert(v interface{}) (r []interface{}) {
+ s := reflect.ValueOf(v) // Must be slice.
+ for i := 0; i < s.Len(); i++ {
+ r = append(r, s.Index(i).Interface())
+ }
+ return r
+}
+
+// flatten flattens multiple slices.
+func flatten(vs ...interface{}) (r []interface{}) {
+ for _, v := range vs {
+ r = append(r, convert(v)...)
+ }
+ return r
+}
+
+// filter maps a slice to a new []interface{}, dropping any element for which
+// fn reports false.
+func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
+ s := reflect.ValueOf(vs)
+ for i := 0; i < s.Len(); i++ {
+ v, ok := fn(s.Index(i).Interface())
+ if ok {
+ r = append(r, v)
+ }
+ }
+ return r
+}
+
+// combine combines objects in two slices as specified.
+func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
+ s1 := reflect.ValueOf(v1)
+ s2 := reflect.ValueOf(v2)
+ for i := 0; i < s1.Len(); i++ {
+ for j := 0; j < s2.Len(); j++ {
+ // Combine using the given function.
+ r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
+ }
+ }
+ return r
+}
+
+// pointersTo is a filter function that returns pointers.
+func pointersTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o))
+ v.Elem().Set(reflect.ValueOf(o))
+ return v.Interface(), true
+ })
+}
+
+// interfacesTo is a filter function that returns interface objects.
+func interfacesTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ var v [1]interface{}
+ v[0] = o
+ return v, true
+ })
+}
diff --git a/pkg/state/types.go b/pkg/state/types.go
new file mode 100644
index 000000000..215ef80f8
--- /dev/null
+++ b/pkg/state/types.go
@@ -0,0 +1,361 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// assertValidType asserts that the type is valid.
+func assertValidType(name string, fields []string) {
+ if name == "" {
+ Failf("type has empty name")
+ }
+ fieldsCopy := make([]string, len(fields))
+ for i := 0; i < len(fields); i++ {
+ if fields[i] == "" {
+ Failf("field has empty name for type %q", name)
+ }
+ fieldsCopy[i] = fields[i]
+ }
+ sort.Slice(fieldsCopy, func(i, j int) bool {
+ return fieldsCopy[i] < fieldsCopy[j]
+ })
+ for i := range fieldsCopy {
+ if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] {
+ Failf("duplicate field %q for type %s", fieldsCopy[i], name)
+ }
+ }
+}
+
+// typeEntry is an entry in the typeDatabase.
+type typeEntry struct {
+ ID typeID
+ wire.Type
+}
+
+// reconciledTypeEntry is a reconciled entry in the typeDatabase.
+type reconciledTypeEntry struct {
+ wire.Type
+ LocalType reflect.Type
+ FieldOrder []int
+}
+
+// typeEncodeDatabase is an internal TypeInfo database for encoding.
+type typeEncodeDatabase struct {
+ // byType maps by type to the typeEntry.
+ byType map[reflect.Type]*typeEntry
+
+ // lastID is the last used ID.
+ lastID typeID
+}
+
+// makeTypeEncodeDatabase makes a typeDatabase.
+func makeTypeEncodeDatabase() typeEncodeDatabase {
+ return typeEncodeDatabase{
+ byType: make(map[reflect.Type]*typeEntry),
+ }
+}
+
+// typeDecodeDatabase is an internal TypeInfo database for decoding.
+type typeDecodeDatabase struct {
+ // byID maps by ID to type.
+ byID []*reconciledTypeEntry
+
+ // pending are entries that are pending validation by Lookup. These
+ // will be reconciled with actual objects. Note that these will also be
+ // used to lookup types by name, since they may not be reconciled and
+ // there's little value to deleting from this map.
+ pending []*wire.Type
+}
+
+// makeTypeDecodeDatabase makes a typeDatabase.
+func makeTypeDecodeDatabase() typeDecodeDatabase {
+ return typeDecodeDatabase{}
+}
+
+// lookupNameFields extracts the name and fields from an object.
+func lookupNameFields(typ reflect.Type) (string, []string, bool) {
+ v := reflect.Zero(reflect.PtrTo(typ)).Interface()
+ t, ok := v.(Type)
+ if !ok {
+ // Is this a primitive?
+ if typ.Kind() == reflect.Interface {
+ return interfaceType, nil, true
+ }
+ name := typ.Name()
+ if _, ok := primitiveTypeDatabase[name]; !ok {
+ // This is not a known type, and not a primitive. The
+ // encoder may proceed for anonymous empty structs, or
+ // it may dereference the type pointer and try again.
+ return "", nil, false
+ }
+ return name, nil, true
+ }
+ // Extract the name from the object.
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ return name, fields, true
+}
+
+// Lookup looks up or registers the given object.
+//
+// The bool indicates whether this is an existing entry: false means the entry
+// did not exist, and true means the entry did exist. If this bool is false and
+// the returned typeEntry is nil, then the object did not implement the Type
+// interface.
+func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
+ te, ok := tdb.byType[typ]
+ if !ok {
+ // Lookup the type information.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs may still be encoded, so let the
+ // caller decide what to do from here.
+ return nil, false
+ }
+
+ // Register the new type.
+ tdb.lastID++
+ te = &typeEntry{
+ ID: tdb.lastID,
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ }
+
+ // All done.
+ tdb.byType[typ] = te
+ return te, false
+ }
+ return te, true
+}
+
+// Register adds a typeID entry.
+func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
+ assertValidType(typ.Name, typ.Fields)
+ tbd.pending = append(tbd.pending, typ)
+}
+
+// LookupName looks up the type name by ID.
+func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
+ if len(tbd.pending) < int(id) {
+ // This is likely an encoder error?
+ Failf("type ID %d not available", id)
+ }
+ return tbd.pending[id-1].Name
+}
+
+// LookupType looks up the type by ID.
+func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
+ name := tbd.LookupName(id)
+ typ, ok := globalTypeDatabase[name]
+ if !ok {
+ // If not available, see if it's primitive.
+ typ, ok = primitiveTypeDatabase[name]
+ if !ok && name == interfaceType {
+ // Matches the built-in interface type.
+ var i interface{}
+ return reflect.TypeOf(&i).Elem()
+ }
+ if !ok {
+ // The type is perhaps not registered?
+ Failf("type name %q is not available", name)
+ }
+ return typ // Primitive type.
+ }
+ return typ // Registered type.
+}
+
+// singleFieldOrder defines the field order for a single field.
+var singleFieldOrder = []int{0}
+
+// Lookup looks up or registers the given object.
+//
+// First, the typeID is searched to see if this has already been appropriately
+// reconciled. If not, then a reconciliation takes place that may compute a
+// field ordering.
+//
+// This method never returns nil: if the object cannot be reconciled (for
+// example, it does not support the Type interface), it fails via Failf.
+func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
+ if len(tbd.byID) >= int(id) && tbd.byID[id-1] != nil {
+ // Already reconciled.
+ return tbd.byID[id-1]
+ }
+ // The ID has not been reconciled yet. That's fine. We need to make
+ // sure it aligns with the current provided object.
+ if len(tbd.pending) < int(id) {
+ // This id was never registered. Probably an encoder error?
+ Failf("typeDatabase does not contain id %d", id)
+ }
+ // Extract the pending info.
+ pending := tbd.pending[id-1]
+ // Grow the byID list.
+ if len(tbd.byID) < int(id) {
+ tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
+ }
+ // Reconcile the type.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // The local type implements neither the Type interface nor a
+ // known primitive, so it cannot be reconciled; fail loudly.
+ Failf("unsupported type %q during decode; can't reconcile", pending.Name)
+ }
+ if name != pending.Name {
+ // Are these the same type? Print a helpful message as this may
+ // actually happen in practice if types change.
+ Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
+ id, name, fields, pending.Name, pending.Fields)
+ }
+ rte := &reconciledTypeEntry{
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ LocalType: typ,
+ }
+ // If there are zero or one fields, then we skip allocating the field
+ // slice. There is special handling for decoding in this case. If the
+ // field name does not match, it will be caught in the general purpose
+ // code below.
+ if len(fields) != len(pending.Fields) {
+ Failf("type %q contains different fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ if len(fields) == 0 {
+ tbd.byID[id-1] = rte // Save.
+ return rte
+ }
+ if len(fields) == 1 && fields[0] == pending.Fields[0] {
+ tbd.byID[id-1] = rte // Save.
+ rte.FieldOrder = singleFieldOrder
+ return rte
+ }
+ // For each field in the current object's information, match it to a
+ // field in the destination object. We know from the validation done
+ // at registration time and on insertion into pending that neither
+ // field list contains any duplicates.
+ fieldOrder := make([]int, len(fields))
+ for i, name := range fields {
+ fieldOrder[i] = -1 // Sentinel.
+ // Is it an exact match?
+ if pending.Fields[i] == name {
+ fieldOrder[i] = i
+ continue
+ }
+ // Find the matching field.
+ for j, otherName := range pending.Fields {
+ if name == otherName {
+ fieldOrder[i] = j
+ break
+ }
+ }
+ if fieldOrder[i] == -1 {
+ // The type names match, but this local field has no counterpart
+ // among the encoded fields.
+ Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ }
+ // The type has been reconciled.
+ rte.FieldOrder = fieldOrder
+ tbd.byID[id-1] = rte
+ return rte
+}
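+
+// For example (values, not code): if the encoder recorded fields [a b c] and
+// the local type declares [c a b], the FieldOrder computed above is [2 0 1],
+// mapping each local field to its position in the encoded stream.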
+
+// interfaceType is the type name used to encode all interface objects.
+const interfaceType = "interface"
+
+// primitiveTypeDatabase is a set of fixed types.
+var primitiveTypeDatabase = func() map[string]reflect.Type {
+ r := make(map[string]reflect.Type)
+ for _, t := range []reflect.Type{
+ reflect.TypeOf(false),
+ reflect.TypeOf(int(0)),
+ reflect.TypeOf(int8(0)),
+ reflect.TypeOf(int16(0)),
+ reflect.TypeOf(int32(0)),
+ reflect.TypeOf(int64(0)),
+ reflect.TypeOf(uint(0)),
+ reflect.TypeOf(uintptr(0)),
+ reflect.TypeOf(uint8(0)),
+ reflect.TypeOf(uint16(0)),
+ reflect.TypeOf(uint32(0)),
+ reflect.TypeOf(uint64(0)),
+ reflect.TypeOf(""),
+ reflect.TypeOf(float32(0.0)),
+ reflect.TypeOf(float64(0.0)),
+ reflect.TypeOf(complex64(0.0)),
+ reflect.TypeOf(complex128(0.0)),
+ } {
+ r[t.Name()] = t
+ }
+ return r
+}()
+
+// globalTypeDatabase is used for dispatching interfaces on decode.
+var globalTypeDatabase = map[string]reflect.Type{}
+
+// Register registers a type.
+//
+// This must be called on init and only done once.
+func Register(t Type) {
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ // Register must always be called on pointers.
+ typ := reflect.TypeOf(t)
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
+ typ = typ.Elem()
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration of non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
+ }
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
+ }
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
+ }
+ globalTypeDatabase[name] = typ
+}
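+
+// As a minimal sketch, a registrable struct type looks like the following
+// (illustrative only; in practice the stateify tool generates the methods and
+// the init-time registration):
+//
+//	type counter struct{ n int64 }
+//
+//	func (*counter) StateTypeName() string { return "example.counter" }
+//	func (*counter) StateFields() []string { return []string{"n"} }
+//	func (c *counter) StateSave(m Sink)    { m.Save(0, &c.n) }
+//	func (c *counter) StateLoad(m Source)  { m.Load(0, &c.n) }
+//
+//	func init() { Register((*counter)(nil)) }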
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
new file mode 100644
index 000000000..311b93dcb
--- /dev/null
+++ b/pkg/state/wire/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "wire",
+ srcs = ["wire.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/gohacks"],
+)
diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go
new file mode 100644
index 000000000..93dee6740
--- /dev/null
+++ b/pkg/state/wire/wire.go
@@ -0,0 +1,970 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package wire contains a few basic types that can be composed to serialize
+// graph information for the state package. This package defines the wire
+// protocol.
+//
+// Note that these types are careful about how they implement the relevant
+// interfaces (either value receiver or pointer receiver), so that native-sized
+// types, such as integers and simple pointers, can fit inside the interface
+// object.
+//
+// This package also uses panic as control flow, so callers should be careful to
+// wrap calls in appropriate handlers.
+//
+// Testing for this package is driven by the state test package.
+package wire
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Reader is the required reader interface.
+type Reader interface {
+ io.Reader
+ ReadByte() (byte, error)
+}
+
+// Writer is the required writer interface.
+type Writer interface {
+ io.Writer
+ WriteByte(byte) error
+}
+
+// readFull is a utility. No equivalent is needed for Write: the io.Writer
+// contract already dictates that a writer must consume all bytes given or
+// return an error.
+func readFull(r io.Reader, p []byte) {
+ for done := 0; done < len(p); {
+ n, err := r.Read(p[done:])
+ done += n
+ if n == 0 && err != nil {
+ panic(err)
+ }
+ }
+}
+
+// Object is a generic object.
+type Object interface {
+ // save saves the given object.
+ //
+ // Panic is used for error control flow.
+ save(Writer)
+
+ // load loads a new object of the given type.
+ //
+ // Panic is used for error control flow.
+ load(Reader) Object
+}
+
+// Bool is a boolean.
+type Bool bool
+
+// loadBool loads an object of type Bool.
+func loadBool(r Reader) Bool {
+ b := loadUint(r)
+ return Bool(b == 1)
+}
+
+// save implements Object.save.
+func (b Bool) save(w Writer) {
+ var v Uint
+ if b {
+ v = 1
+ } else {
+ v = 0
+ }
+ v.save(w)
+}
+
+// load implements Object.load.
+func (Bool) load(r Reader) Object { return loadBool(r) }
+
+// Int is a signed integer.
+//
+// This uses zig-zag varint encoding.
+type Int int64
+
+// loadInt loads an object of type Int.
+func loadInt(r Reader) Int {
+ u := loadUint(r)
+ x := Int(u >> 1)
+ if u&1 != 0 {
+ x = ^x
+ }
+ return x
+}
+
+// save implements Object.save.
+func (i Int) save(w Writer) {
+ u := Uint(i) << 1
+ if i < 0 {
+ u = ^u
+ }
+ u.save(w)
+}
+
+// load implements Object.load.
+func (Int) load(r Reader) Object { return loadInt(r) }
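+
+// For example (values, not code): the zig-zag mapping above sends 0->0,
+// -1->1, 1->2, -2->3, 2->4, ..., so small magnitudes of either sign encode to
+// short varints.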
+
+// Uint is an unsigned integer.
+type Uint uint64
+
+// loadUint loads an object of type Uint.
+func loadUint(r Reader) Uint {
+ var (
+ u Uint
+ s uint
+ )
+ for i := 0; i <= 9; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ panic(err)
+ }
+ if b < 0x80 {
+ if i == 9 && b > 1 {
+ panic("overflow")
+ }
+ u |= Uint(b) << s
+ return u
+ }
+ u |= Uint(b&0x7f) << s
+ s += 7
+ }
+ panic("unreachable")
+}
+
+// save implements Object.save.
+func (u Uint) save(w Writer) {
+ for u >= 0x80 {
+ if err := w.WriteByte(byte(u) | 0x80); err != nil {
+ panic(err)
+ }
+ u >>= 7
+ }
+ if err := w.WriteByte(byte(u)); err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (Uint) load(r Reader) Object { return loadUint(r) }
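+
+// For example (values, not code): 300 (0b100101100) is emitted little-endian
+// in 7-bit groups as 0xAC 0x02, while any value below 0x80 occupies a single
+// byte.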
+
+// Float32 is a 32-bit floating point number.
+type Float32 float32
+
+// loadFloat32 loads an object of type Float32.
+func loadFloat32(r Reader) Float32 {
+ n := loadUint(r)
+ return Float32(math.Float32frombits(uint32(n)))
+}
+
+// save implements Object.save.
+func (f Float32) save(w Writer) {
+ n := Uint(math.Float32bits(float32(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float32) load(r Reader) Object { return loadFloat32(r) }
+
+// Float64 is a 64-bit floating point number.
+type Float64 float64
+
+// loadFloat64 loads an object of type Float64.
+func loadFloat64(r Reader) Float64 {
+ n := loadUint(r)
+ return Float64(math.Float64frombits(uint64(n)))
+}
+
+// save implements Object.save.
+func (f Float64) save(w Writer) {
+ n := Uint(math.Float64bits(float64(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float64) load(r Reader) Object { return loadFloat64(r) }
+
+// Complex64 is a 64-bit complex number. It is backed by a complex128, but is
+// encoded as two Float32 components.
+type Complex64 complex128
+
+// loadComplex64 loads an object of type Complex64.
+func loadComplex64(r Reader) Complex64 {
+ re := loadFloat32(r)
+ im := loadFloat32(r)
+ return Complex64(complex(float32(re), float32(im)))
+}
+
+// save implements Object.save.
+func (c *Complex64) save(w Writer) {
+ re := Float32(real(*c))
+ im := Float32(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex64) load(r Reader) Object {
+ c := loadComplex64(r)
+ return &c
+}
+
+// Complex128 is a 128-bit complex number.
+type Complex128 complex128
+
+// loadComplex128 loads an object of type Complex128.
+func loadComplex128(r Reader) Complex128 {
+ re := loadFloat64(r)
+ im := loadFloat64(r)
+ return Complex128(complex(float64(re), float64(im)))
+}
+
+// save implements Object.save.
+func (c *Complex128) save(w Writer) {
+ re := Float64(real(*c))
+ im := Float64(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex128) load(r Reader) Object {
+ c := loadComplex128(r)
+ return &c
+}
+
+// String is a string.
+type String string
+
+// loadString loads an object of type String.
+func loadString(r Reader) String {
+ l := loadUint(r)
+ p := make([]byte, l)
+ readFull(r, p)
+ return String(gohacks.StringFromImmutableBytes(p))
+}
+
+// save implements Object.save.
+func (s *String) save(w Writer) {
+ l := Uint(len(*s))
+ l.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*s))
+ _, err := w.Write(p) // Must write all bytes.
+ if err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (*String) load(r Reader) Object {
+ s := loadString(r)
+ return &s
+}
+
+// Dot is a kind of reference: either an Index or a FieldName.
+type Dot interface {
+ isDot()
+}
+
+// Index is a reference resolution.
+type Index uint32
+
+func (Index) isDot() {}
+
+// FieldName is a reference resolution.
+type FieldName string
+
+func (*FieldName) isDot() {}
+
+// Ref is a reference to an object.
+type Ref struct {
+ // Root is the root object.
+ Root Uint
+
+ // Dots is the set of traversals required from the Root object above.
+ // Note that this will be stored in reverse order for efficiency.
+ Dots []Dot
+
+ // Type is the base type for the root object. This is non-nil iff Dots
+ // is non-zero length (that is, this is a complex reference). This is
+ // not *strictly* necessary, but can be used to simplify decoding.
+ Type TypeSpec
+}
+
+// loadRef loads an object of type Ref (abstract).
+func loadRef(r Reader) Ref {
+ ref := Ref{
+ Root: loadUint(r),
+ }
+ l := loadUint(r)
+ ref.Dots = make([]Dot, l)
+ for i := 0; i < int(l); i++ {
+ // Disambiguate between an Index (non-negative) and a field
+ // name (negative). This saves some space and avoids a dedicated
+ // loadDot function. See Ref.save for the other side.
+ d := loadInt(r)
+ if d >= 0 {
+ ref.Dots[i] = Index(d)
+ continue
+ }
+ p := make([]byte, -d)
+ readFull(r, p)
+ fieldName := FieldName(gohacks.StringFromImmutableBytes(p))
+ ref.Dots[i] = &fieldName
+ }
+ if l != 0 {
+ // Only if dots is non-zero.
+ ref.Type = loadTypeSpec(r)
+ }
+ return ref
+}
+
+// save implements Object.save.
+func (r *Ref) save(w Writer) {
+ r.Root.save(w)
+ l := Uint(len(r.Dots))
+ l.save(w)
+ for _, d := range r.Dots {
+ // See loadRef. We use non-negative numbers to encode Index
+ // objects and negative numbers to encode field lengths.
+ switch x := d.(type) {
+ case Index:
+ i := Int(x)
+ i.save(w)
+ case *FieldName:
+ d := Int(-len(*x))
+ d.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*x))
+ if _, err := w.Write(p); err != nil {
+ panic(err)
+ }
+ default:
+ panic("unknown dot implementation")
+ }
+ }
+ if l != 0 {
+ // See above.
+ saveTypeSpec(w, r.Type)
+ }
+}
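+
+// For example (values, not code): Index(3) is saved as the signed integer +3,
+// while FieldName("foo") is saved as -3 followed by the three name bytes; the
+// sign alone lets loadRef disambiguate the two cases.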
+
+// load implements Object.load.
+func (*Ref) load(r Reader) Object {
+ ref := loadRef(r)
+ return &ref
+}
+
+// Nil is a primitive zero value of any type.
+type Nil struct{}
+
+// loadNil loads an object of type Nil.
+func loadNil(r Reader) Nil {
+ return Nil{}
+}
+
+// save implements Object.save.
+func (Nil) save(w Writer) {}
+
+// load implements Object.load.
+func (Nil) load(r Reader) Object { return loadNil(r) }
+
+// Slice is a slice value.
+type Slice struct {
+ Length Uint
+ Capacity Uint
+ Ref Ref
+}
+
+// loadSlice loads an object of type Slice.
+func loadSlice(r Reader) Slice {
+ return Slice{
+ Length: loadUint(r),
+ Capacity: loadUint(r),
+ Ref: loadRef(r),
+ }
+}
+
+// save implements Object.save.
+func (s *Slice) save(w Writer) {
+ s.Length.save(w)
+ s.Capacity.save(w)
+ s.Ref.save(w)
+}
+
+// load implements Object.load.
+func (*Slice) load(r Reader) Object {
+ s := loadSlice(r)
+ return &s
+}
+
+// Array is an array value.
+type Array struct {
+ Contents []Object
+}
+
+// loadArray loads an object of type Array.
+func loadArray(r Reader) Array {
+ l := loadUint(r)
+ if l == 0 {
+ // With zero elements there is no first object from which to
+ // derive the element type, so this branch is required.
+ return Array{}
+ }
+ // All the objects here have the same type, so use dynamic dispatch
+ // only once. All other objects will automatically take the same type
+ // as the first object.
+ contents := make([]Object, l)
+ v := Load(r)
+ contents[0] = v
+ for i := 1; i < int(l); i++ {
+ contents[i] = v.load(r)
+ }
+ return Array{
+ Contents: contents,
+ }
+}
+
+// save implements Object.save.
+func (a *Array) save(w Writer) {
+ l := Uint(len(a.Contents))
+ l.save(w)
+ if l == 0 {
+ // See loadArray.
+ return
+ }
+ // See above.
+ Save(w, a.Contents[0])
+ for i := 1; i < int(l); i++ {
+ a.Contents[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Array) load(r Reader) Object {
+ a := loadArray(r)
+ return &a
+}
+
+// Map is a map value.
+type Map struct {
+ Keys []Object
+ Values []Object
+}
+
+// loadMap loads an object of type Map.
+func loadMap(r Reader) Map {
+ l := loadUint(r)
+ if l == 0 {
+ // See loadArray.
+ return Map{}
+ }
+ // See type dispatch notes in Array.
+ keys := make([]Object, l)
+ values := make([]Object, l)
+ k := Load(r)
+ v := Load(r)
+ keys[0] = k
+ values[0] = v
+ for i := 1; i < int(l); i++ {
+ keys[i] = k.load(r)
+ values[i] = v.load(r)
+ }
+ return Map{
+ Keys: keys,
+ Values: values,
+ }
+}
+
+// save implements Object.save.
+func (m *Map) save(w Writer) {
+ l := Uint(len(m.Keys))
+ if int(l) != len(m.Values) {
+ panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
+ }
+ l.save(w)
+ if l == 0 {
+ // See loadArray.
+ return
+ }
+ // See above.
+ Save(w, m.Keys[0])
+ Save(w, m.Values[0])
+ for i := 1; i < int(l); i++ {
+ m.Keys[i].save(w)
+ m.Values[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Map) load(r Reader) Object {
+ m := loadMap(r)
+ return &m
+}
+
+// TypeSpec is a type specification.
+type TypeSpec interface {
+ isTypeSpec()
+}
+
+// TypeID is a concrete type ID.
+type TypeID Uint
+
+func (TypeID) isTypeSpec() {}
+
+// TypeSpecPointer is a pointer type.
+type TypeSpecPointer struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecPointer) isTypeSpec() {}
+
+// TypeSpecArray is an array type.
+type TypeSpecArray struct {
+ Count Uint
+ Type TypeSpec
+}
+
+func (*TypeSpecArray) isTypeSpec() {}
+
+// TypeSpecSlice is a slice type.
+type TypeSpecSlice struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecSlice) isTypeSpec() {}
+
+// TypeSpecMap is a map type.
+type TypeSpecMap struct {
+ Key TypeSpec
+ Value TypeSpec
+}
+
+func (*TypeSpecMap) isTypeSpec() {}
+
+// TypeSpecNil is an empty type.
+type TypeSpecNil struct{}
+
+func (TypeSpecNil) isTypeSpec() {}
+
+// TypeSpec types.
+//
+// These use a distinct encoding on the wire, as they are used only in the
+// interface object. They are encoded and decoded through the dedicated
+// saveTypeSpec and loadTypeSpec functions.
+const (
+ typeSpecTypeID Uint = iota
+ typeSpecPointer
+ typeSpecArray
+ typeSpecSlice
+ typeSpecMap
+ typeSpecNil
+)
+
+// loadTypeSpec loads TypeSpec values.
+func loadTypeSpec(r Reader) TypeSpec {
+ switch hdr := loadUint(r); hdr {
+ case typeSpecTypeID:
+ return TypeID(loadUint(r))
+ case typeSpecPointer:
+ return &TypeSpecPointer{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecArray:
+ return &TypeSpecArray{
+ Count: loadUint(r),
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecSlice:
+ return &TypeSpecSlice{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecMap:
+ return &TypeSpecMap{
+ Key: loadTypeSpec(r),
+ Value: loadTypeSpec(r),
+ }
+ case typeSpecNil:
+ return TypeSpecNil{}
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// saveTypeSpec saves TypeSpec values.
+func saveTypeSpec(w Writer, t TypeSpec) {
+ switch x := t.(type) {
+ case TypeID:
+ typeSpecTypeID.save(w)
+ Uint(x).save(w)
+ case *TypeSpecPointer:
+ typeSpecPointer.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecArray:
+ typeSpecArray.save(w)
+ x.Count.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecSlice:
+ typeSpecSlice.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecMap:
+ typeSpecMap.save(w)
+ saveTypeSpec(w, x.Key)
+ saveTypeSpec(w, x.Value)
+ case TypeSpecNil:
+ typeSpecNil.save(w)
+ default:
+ // This should not happen?
+ panic(fmt.Errorf("unknown type %T", t))
+ }
+}
+
+// Interface is an interface value.
+type Interface struct {
+ Type TypeSpec
+ Value Object
+}
+
+// loadInterface loads an object of type Interface.
+func loadInterface(r Reader) Interface {
+ return Interface{
+ Type: loadTypeSpec(r),
+ Value: Load(r),
+ }
+}
+
+// save implements Object.save.
+func (i *Interface) save(w Writer) {
+ saveTypeSpec(w, i.Type)
+ Save(w, i.Value)
+}
+
+// load implements Object.load.
+func (*Interface) load(r Reader) Object {
+ i := loadInterface(r)
+ return &i
+}
+
+// Type is type information.
+type Type struct {
+ Name string
+ Fields []string
+}
+
+// loadType loads an object of type Type.
+func loadType(r Reader) Type {
+ name := string(loadString(r))
+ l := loadUint(r)
+ fields := make([]string, l)
+ for i := 0; i < int(l); i++ {
+ fields[i] = string(loadString(r))
+ }
+ return Type{
+ Name: name,
+ Fields: fields,
+ }
+}
+
+// save implements Object.save.
+func (t *Type) save(w Writer) {
+ s := String(t.Name)
+ s.save(w)
+ l := Uint(len(t.Fields))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ s := String(t.Fields[i])
+ s.save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Type) load(r Reader) Object {
+ t := loadType(r)
+ return &t
+}
+
+// multipleObjects is a special type for serializing multiple objects.
+type multipleObjects []Object
+
+// loadMultipleObjects loads a series of objects.
+func loadMultipleObjects(r Reader) multipleObjects {
+ l := loadUint(r)
+ m := make(multipleObjects, l)
+ for i := 0; i < int(l); i++ {
+ m[i] = Load(r)
+ }
+ return m
+}
+
+// save implements Object.save.
+func (m *multipleObjects) save(w Writer) {
+ l := Uint(len(*m))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ Save(w, (*m)[i])
+ }
+}
+
+// load implements Object.load.
+func (*multipleObjects) load(r Reader) Object {
+ m := loadMultipleObjects(r)
+ return &m
+}
+
+// noObjects represents no objects.
+type noObjects struct{}
+
+// loadNoObjects loads a sentinel.
+func loadNoObjects(r Reader) noObjects { return noObjects{} }
+
+// save implements Object.save.
+func (noObjects) save(w Writer) {}
+
+// load implements Object.load.
+func (noObjects) load(r Reader) Object { return loadNoObjects(r) }
+
+// Struct is a basic composite value.
+type Struct struct {
+ TypeID TypeID
+ fields Object // Optionally noObjects or *multipleObjects.
+}
+
+// Field returns a pointer to the given field slot.
+//
+// This must be called after Alloc.
+func (s *Struct) Field(i int) *Object {
+ if fields, ok := s.fields.(*multipleObjects); ok {
+ return &((*fields)[i])
+ }
+ if _, ok := s.fields.(noObjects); ok {
+ // Alloc(0) was called: this object has no field slots, so
+ // Field must not be called.
+ panic("Field called inappropriately, wrong Alloc?")
+ }
+ return &s.fields
+}
+
+// Alloc allocates the given number of fields.
+//
+// This must be called before Field and save.
+//
+// Precondition: slots must be non-negative.
+func (s *Struct) Alloc(slots int) {
+ switch {
+ case slots == 0:
+ s.fields = noObjects{}
+ case slots == 1:
+ // Leave it alone.
+ case slots > 1:
+ fields := make(multipleObjects, slots)
+ s.fields = &fields
+ default:
+ // Violates precondition.
+ panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
+ }
+}
+
+// Fields returns the number of fields.
+func (s *Struct) Fields() int {
+ switch x := s.fields.(type) {
+ case *multipleObjects:
+ return len(*x)
+ case noObjects:
+ return 0
+ default:
+ return 1
+ }
+}
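+
+// A minimal usage sketch of the protocol above (illustrative only; the
+// encoder normally drives these calls):
+//
+//	var s Struct
+//	s.Alloc(2)               // Reserve two field slots.
+//	*s.Field(0) = Uint(42)   // Fill slot 0.
+//	*s.Field(1) = Bool(true) // Fill slot 1.
+//	// s can now be serialized via Save(w, &s).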
+
+// loadStruct loads an object of type Struct.
+func loadStruct(r Reader) Struct {
+ return Struct{
+ TypeID: TypeID(loadUint(r)),
+ fields: Load(r),
+ }
+}
+
+// save implements Object.save.
+//
+// Precondition: Alloc must have been called, and the fields all filled in
+// appropriately. See Alloc and Field for more details.
+func (s *Struct) save(w Writer) {
+ Uint(s.TypeID).save(w)
+ Save(w, s.fields)
+}
+
+// load implements Object.load.
+func (*Struct) load(r Reader) Object {
+ s := loadStruct(r)
+ return &s
+}
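+
+// Building and saving a Struct, as a minimal sketch (myTypeID stands for a
+// hypothetical TypeID assigned elsewhere):
+//
+//	var s Struct
+//	s.TypeID = myTypeID
+//	s.Alloc(2)
+//	*s.Field(0) = Uint(42)
+//	str := String("hello")
+//	*s.Field(1) = &str
+//	Save(w, &s)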
+
+// Object types.
+//
+// N.B. Be careful about changing the order or introducing new elements in the
+// middle here. This is part of the wire format and shouldn't change.
+const (
+ typeBool Uint = iota
+ typeInt
+ typeUint
+ typeFloat32
+ typeFloat64
+ typeNil
+ typeRef
+ typeString
+ typeSlice
+ typeArray
+ typeMap
+ typeStruct
+ typeNoObjects
+ typeMultipleObjects
+ typeInterface
+ typeComplex64
+ typeComplex128
+ typeType
+)
+
+// Save saves the given object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Save(w Writer, obj Object) {
+ switch x := obj.(type) {
+ case Bool:
+ typeBool.save(w)
+ x.save(w)
+ case Int:
+ typeInt.save(w)
+ x.save(w)
+ case Uint:
+ typeUint.save(w)
+ x.save(w)
+ case Float32:
+ typeFloat32.save(w)
+ x.save(w)
+ case Float64:
+ typeFloat64.save(w)
+ x.save(w)
+ case Nil:
+ typeNil.save(w)
+ x.save(w)
+ case *Ref:
+ typeRef.save(w)
+ x.save(w)
+ case *String:
+ typeString.save(w)
+ x.save(w)
+ case *Slice:
+ typeSlice.save(w)
+ x.save(w)
+ case *Array:
+ typeArray.save(w)
+ x.save(w)
+ case *Map:
+ typeMap.save(w)
+ x.save(w)
+ case *Struct:
+ typeStruct.save(w)
+ x.save(w)
+ case noObjects:
+ typeNoObjects.save(w)
+ x.save(w)
+ case *multipleObjects:
+ typeMultipleObjects.save(w)
+ x.save(w)
+ case *Interface:
+ typeInterface.save(w)
+ x.save(w)
+ case *Type:
+ typeType.save(w)
+ x.save(w)
+ case *Complex64:
+ typeComplex64.save(w)
+ x.save(w)
+ case *Complex128:
+ typeComplex128.save(w)
+ x.save(w)
+ default:
+ panic(fmt.Errorf("unknown type: %#v", obj))
+ }
+}
+
+// Load loads a new object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Load(r Reader) Object {
+ switch hdr := loadUint(r); hdr {
+ case typeBool:
+ return loadBool(r)
+ case typeInt:
+ return loadInt(r)
+ case typeUint:
+ return loadUint(r)
+ case typeFloat32:
+ return loadFloat32(r)
+ case typeFloat64:
+ return loadFloat64(r)
+ case typeNil:
+ return loadNil(r)
+ case typeRef:
+ return ((*Ref)(nil)).load(r) // Escapes.
+ case typeString:
+ return ((*String)(nil)).load(r) // Escapes.
+ case typeSlice:
+ return ((*Slice)(nil)).load(r) // Escapes.
+ case typeArray:
+ return ((*Array)(nil)).load(r) // Escapes.
+ case typeMap:
+ return ((*Map)(nil)).load(r) // Escapes.
+ case typeStruct:
+ return ((*Struct)(nil)).load(r) // Escapes.
+ case typeNoObjects: // Special for struct.
+ return loadNoObjects(r)
+ case typeMultipleObjects: // Special for struct.
+ return ((*multipleObjects)(nil)).load(r) // Escapes.
+ case typeInterface:
+ return ((*Interface)(nil)).load(r) // Escapes.
+ case typeComplex64:
+ return ((*Complex64)(nil)).load(r) // Escapes.
+ case typeComplex128:
+ return ((*Complex128)(nil)).load(r) // Escapes.
+ case typeType:
+ return ((*Type)(nil)).load(r) // Escapes.
+ default:
+ // Not a valid stream: the header is an unknown type tag.
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// LoadUint loads a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func LoadUint(r Reader) uint64 {
+ return uint64(loadUint(r))
+}
+
+// SaveUint saves a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func SaveUint(w Writer, v uint64) {
+ Uint(v).save(w)
+}
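+
+// The wire format pairs each value with a type tag: Save writes the tag and
+// then the value, and Load reads the tag back and dispatches to the matching
+// loader. A minimal sketch, assuming w and r are a matched Writer/Reader pair
+// over the same stream:
+//
+//	Save(w, Uint(42)) // writes typeUint, then the value
+//	obj := Load(r)    // reads typeUint, dispatches to loadUint
+//	u := obj.(Uint)   // u == 42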
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
new file mode 100644
index 000000000..d0d77e19c
--- /dev/null
+++ b/pkg/sync/BUILD
@@ -0,0 +1,55 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+go_template(
+ name = "generic_atomicptr",
+ srcs = ["atomicptr_unsafe.go"],
+ types = [
+ "Value",
+ ],
+)
+
+go_template(
+ name = "generic_seqatomic",
+ srcs = ["seqatomic_unsafe.go"],
+ types = [
+ "Value",
+ ],
+ deps = [
+ ":sync",
+ ],
+)
+
+go_library(
+ name = "sync",
+ srcs = [
+ "aliases.go",
+ "memmove_unsafe.go",
+ "mutex_unsafe.go",
+ "norace_unsafe.go",
+ "race_unsafe.go",
+ "rwmutex_unsafe.go",
+ "seqcount.go",
+ "sync.go",
+ ],
+ marshal = False,
+ stateify = False,
+)
+
+go_test(
+ name = "sync_test",
+ size = "small",
+ srcs = [
+ "mutex_test.go",
+ "rwmutex_test.go",
+ "seqcount_test.go",
+ ],
+ library = ":sync",
+)
diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/pkg/sync/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/pkg/sync/README.md b/pkg/sync/README.md
new file mode 100644
index 000000000..2183c4e20
--- /dev/null
+++ b/pkg/sync/README.md
@@ -0,0 +1,5 @@
+# sync
+
+This package provides additional synchronization primitives not provided by the
+Go stdlib 'sync' package. It is partially derived from the upstream 'sync'
+package from go1.10.
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
new file mode 100644
index 000000000..0d4316254
--- /dev/null
+++ b/pkg/sync/aliases.go
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "sync"
+)
+
+// Aliases of standard library types.
+type (
+ // Cond is an alias of sync.Cond.
+ Cond = sync.Cond
+
+ // Locker is an alias of sync.Locker.
+ Locker = sync.Locker
+
+ // Once is an alias of sync.Once.
+ Once = sync.Once
+
+ // Pool is an alias of sync.Pool.
+ Pool = sync.Pool
+
+ // WaitGroup is an alias of sync.WaitGroup.
+ WaitGroup = sync.WaitGroup
+
+ // Map is an alias of sync.Map.
+ Map = sync.Map
+)
+
+// NewCond is a wrapper around sync.NewCond.
+func NewCond(l Locker) *Cond {
+ return sync.NewCond(l)
+}
diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go
new file mode 100644
index 000000000..525c4beed
--- /dev/null
+++ b/pkg/sync/atomicptr_unsafe.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package template doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package template
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// Value is a required type parameter.
+type Value struct{}
+
+// An AtomicPtr is a pointer to a value of type Value that can be atomically
+// loaded and stored. The zero value of an AtomicPtr represents nil.
+//
+// Note that copying AtomicPtr by value performs a non-atomic read of the
+// stored pointer, which is unsafe if Store() can be called concurrently; in
+// this case, do `dst.Store(src.Load())` instead.
+//
+// +stateify savable
+type AtomicPtr struct {
+ ptr unsafe.Pointer `state:".(*Value)"`
+}
+
+func (p *AtomicPtr) savePtr() *Value {
+ return p.Load()
+}
+
+func (p *AtomicPtr) loadPtr(v *Value) {
+ p.Store(v)
+}
+
+// Load returns the value set by the most recent Store. It returns nil if there
+// has been no previous call to Store.
+func (p *AtomicPtr) Load() *Value {
+ return (*Value)(atomic.LoadPointer(&p.ptr))
+}
+
+// Store sets the value returned by Load to x.
+func (p *AtomicPtr) Store(x *Value) {
+ atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x))
+}
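+
+// Once instantiated (e.g. with Value = int and suffix "Int", as in the
+// atomicptrtest BUILD rule), usage looks like the following sketch:
+//
+//	var p AtomicPtrInt
+//	x := 42
+//	p.Store(&x)
+//	y := p.Load() // y == &x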
diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD
new file mode 100644
index 000000000..e97553254
--- /dev/null
+++ b/pkg/sync/atomicptrtest/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "atomicptr_int",
+ out = "atomicptr_int_unsafe.go",
+ package = "atomicptr",
+ suffix = "Int",
+ template = "//pkg/sync:generic_atomicptr",
+ types = {
+ "Value": "int",
+ },
+)
+
+go_library(
+ name = "atomicptr",
+ srcs = ["atomicptr_int_unsafe.go"],
+)
+
+go_test(
+ name = "atomicptr_test",
+ size = "small",
+ srcs = ["atomicptr_test.go"],
+ library = ":atomicptr",
+)
diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go
new file mode 100644
index 000000000..8fdc5112e
--- /dev/null
+++ b/pkg/sync/atomicptrtest/atomicptr_test.go
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package atomicptr
+
+import (
+ "testing"
+)
+
+func newInt(val int) *int {
+ return &val
+}
+
+func TestAtomicPtr(t *testing.T) {
+ var p AtomicPtrInt
+ if got := p.Load(); got != nil {
+ t.Errorf("initial value is %p (%v), wanted nil", got, got)
+ }
+ want := newInt(42)
+ p.Store(want)
+ if got := p.Load(); got != want {
+ t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want)
+ }
+ want = newInt(100)
+ p.Store(want)
+ if got := p.Load(); got != want {
+ t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want)
+ }
+}
diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
new file mode 100644
index 000000000..1d7780695
--- /dev/null
+++ b/pkg/sync/memmove_unsafe.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package sync
+
+import (
+ "unsafe"
+)
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
+
+// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't
+// define it because go_generics can't update the go:linkname annotation.
+// Furthermore, go:linkname silently doesn't work if the local name is exported
+// (this is of course undocumented), which is why this indirection is
+// necessary.
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
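+
+// For example, the seqatomic template (seqatomic_unsafe.go) uses Memmove to
+// copy a value without race-detector instrumentation. A minimal sketch, where
+// T stands for any pointer-free type:
+//
+//	var dst, src T
+//	Memmove(unsafe.Pointer(&dst), unsafe.Pointer(&src), unsafe.Sizeof(dst))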
diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go
new file mode 100644
index 000000000..0838248b4
--- /dev/null
+++ b/pkg/sync/mutex_test.go
@@ -0,0 +1,71 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "sync"
+ "testing"
+ "unsafe"
+)
+
+// TestStructSize verifies that syncMutex's size hasn't drifted from the
+// standard library's version.
+//
+// The correctness of this package relies on these remaining in sync.
+func TestStructSize(t *testing.T) {
+ const (
+ got = unsafe.Sizeof(syncMutex{})
+ want = unsafe.Sizeof(sync.Mutex{})
+ )
+ if got != want {
+ t.Errorf("got sizeof(syncMutex) = %d, want = sizeof(sync.Mutex) = %d", got, want)
+ }
+}
+
+// TestFieldValues verifies that the semantics of syncMutex.state match the
+// standard library's implementation.
+//
+// The correctness of this package relies on these remaining in sync.
+func TestFieldValues(t *testing.T) {
+ var m Mutex
+ m.Lock()
+ if got := *m.state(); got != mutexLocked {
+ t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked)
+ }
+ m.Unlock()
+ if got := *m.state(); got != mutexUnlocked {
+ t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked)
+ }
+}
+
+func TestDoubleTryLock(t *testing.T) {
+ var m Mutex
+ if !m.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ if m.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestTryLockAfterLock(t *testing.T) {
+ var m Mutex
+ m.Lock()
+ if m.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestTryLockUnlock(t *testing.T) {
+ var m Mutex
+ if !m.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ m.Unlock()
+ if !m.TryLock() {
+ t.Fatal("failed to aquire lock after unlock")
+ }
+}
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
new file mode 100644
index 000000000..dc034d561
--- /dev/null
+++ b/pkg/sync/mutex_unsafe.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.16
+
+// When updating the build constraint (above), check that syncMutex matches the
+// standard library sync.Mutex definition.
+
+package sync
+
+import (
+ "sync"
+ "sync/atomic"
+ "unsafe"
+)
+
+// Mutex is a try lock.
+type Mutex struct {
+ sync.Mutex
+}
+
+type syncMutex struct {
+ state int32
+ sema uint32
+}
+
+func (m *Mutex) state() *int32 {
+ return &(*syncMutex)(unsafe.Pointer(&m.Mutex)).state
+}
+
+const (
+ mutexUnlocked = 0
+ mutexLocked = 1
+)
+
+// TryLock tries to acquire the mutex. It returns true if it succeeds and false
+// otherwise. TryLock does not block.
+func (m *Mutex) TryLock() bool {
+ if atomic.CompareAndSwapInt32(m.state(), mutexUnlocked, mutexLocked) {
+ if RaceEnabled {
+ RaceAcquire(unsafe.Pointer(&m.Mutex))
+ }
+ return true
+ }
+ return false
+}
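+
+// A typical non-blocking usage sketch:
+//
+//	var m Mutex
+//	if m.TryLock() {
+//		// ... critical section ...
+//		m.Unlock()
+//	} else {
+//		// The lock is held elsewhere; fall back without blocking.
+//	}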
diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go
new file mode 100644
index 000000000..006055dd6
--- /dev/null
+++ b/pkg/sync/norace_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !race
+
+package sync
+
+import (
+ "unsafe"
+)
+
+// RaceEnabled is true if the Go data race detector is enabled.
+const RaceEnabled = false
+
+// RaceDisable has the same semantics as runtime.RaceDisable.
+func RaceDisable() {
+}
+
+// RaceEnable has the same semantics as runtime.RaceEnable.
+func RaceEnable() {
+}
+
+// RaceAcquire has the same semantics as runtime.RaceAcquire.
+func RaceAcquire(addr unsafe.Pointer) {
+}
+
+// RaceRelease has the same semantics as runtime.RaceRelease.
+func RaceRelease(addr unsafe.Pointer) {
+}
+
+// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
+func RaceReleaseMerge(addr unsafe.Pointer) {
+}
diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go
new file mode 100644
index 000000000..31d8fa9a6
--- /dev/null
+++ b/pkg/sync/race_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build race
+
+package sync
+
+import (
+ "runtime"
+ "unsafe"
+)
+
+// RaceEnabled is true if the Go data race detector is enabled.
+const RaceEnabled = true
+
+// RaceDisable has the same semantics as runtime.RaceDisable.
+func RaceDisable() {
+ runtime.RaceDisable()
+}
+
+// RaceEnable has the same semantics as runtime.RaceEnable.
+func RaceEnable() {
+ runtime.RaceEnable()
+}
+
+// RaceAcquire has the same semantics as runtime.RaceAcquire.
+func RaceAcquire(addr unsafe.Pointer) {
+ runtime.RaceAcquire(addr)
+}
+
+// RaceRelease has the same semantics as runtime.RaceRelease.
+func RaceRelease(addr unsafe.Pointer) {
+ runtime.RaceRelease(addr)
+}
+
+// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
+func RaceReleaseMerge(addr unsafe.Pointer) {
+ runtime.RaceReleaseMerge(addr)
+}
diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go
new file mode 100644
index 000000000..ce667e825
--- /dev/null
+++ b/pkg/sync/rwmutex_test.go
@@ -0,0 +1,205 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2019 The gVisor Authors.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// GOMAXPROCS=10 go test
+
+// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the
+// addition of downgradingWriter and the renaming of num_iterations to
+// numIterations to shut up Golint.
+
+package sync
+
+import (
+ "fmt"
+ "runtime"
+ "sync/atomic"
+ "testing"
+)
+
+func parallelReader(m *RWMutex, clocked, cunlock, cdone chan bool) {
+ m.RLock()
+ clocked <- true
+ <-cunlock
+ m.RUnlock()
+ cdone <- true
+}
+
+func doTestParallelReaders(numReaders, gomaxprocs int) {
+ runtime.GOMAXPROCS(gomaxprocs)
+ var m RWMutex
+ clocked := make(chan bool)
+ cunlock := make(chan bool)
+ cdone := make(chan bool)
+ for i := 0; i < numReaders; i++ {
+ go parallelReader(&m, clocked, cunlock, cdone)
+ }
+ // Wait for all parallel RLock()s to succeed.
+ for i := 0; i < numReaders; i++ {
+ <-clocked
+ }
+ for i := 0; i < numReaders; i++ {
+ cunlock <- true
+ }
+ // Wait for the goroutines to finish.
+ for i := 0; i < numReaders; i++ {
+ <-cdone
+ }
+}
+
+func TestParallelReaders(t *testing.T) {
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1))
+ doTestParallelReaders(1, 4)
+ doTestParallelReaders(3, 4)
+ doTestParallelReaders(4, 2)
+}
+
+func reader(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) {
+ for i := 0; i < numIterations; i++ {
+ rwm.RLock()
+ n := atomic.AddInt32(activity, 1)
+ if n < 1 || n >= 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ atomic.AddInt32(activity, -1)
+ rwm.RUnlock()
+ }
+ cdone <- true
+}
+
+func writer(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) {
+ for i := 0; i < numIterations; i++ {
+ rwm.Lock()
+ n := atomic.AddInt32(activity, 10000)
+ if n != 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ atomic.AddInt32(activity, -10000)
+ rwm.Unlock()
+ }
+ cdone <- true
+}
+
+func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) {
+ for i := 0; i < numIterations; i++ {
+ rwm.Lock()
+ n := atomic.AddInt32(activity, 10000)
+ if n != 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ atomic.AddInt32(activity, -10000)
+ rwm.DowngradeLock()
+ n = atomic.AddInt32(activity, 1)
+ if n < 1 || n >= 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ n = atomic.AddInt32(activity, -1)
+ rwm.RUnlock()
+ }
+ cdone <- true
+}
+
+func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) {
+ runtime.GOMAXPROCS(gomaxprocs)
+ // Number of active readers + 10000 * number of active writers.
+ var activity int32
+ var rwm RWMutex
+ cdone := make(chan bool)
+ go writer(&rwm, numIterations, &activity, cdone)
+ go downgradingWriter(&rwm, numIterations, &activity, cdone)
+ var i int
+ for i = 0; i < numReaders/2; i++ {
+ go reader(&rwm, numIterations, &activity, cdone)
+ }
+ go writer(&rwm, numIterations, &activity, cdone)
+ go downgradingWriter(&rwm, numIterations, &activity, cdone)
+ for ; i < numReaders; i++ {
+ go reader(&rwm, numIterations, &activity, cdone)
+ }
+ // Wait for the 4 writers and all readers to finish.
+ for i := 0; i < 4+numReaders; i++ {
+ <-cdone
+ }
+}
+
+func TestDowngradableRWMutex(t *testing.T) {
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1))
+ n := 1000
+ if testing.Short() {
+ n = 5
+ }
+ HammerDowngradableRWMutex(1, 1, n)
+ HammerDowngradableRWMutex(1, 3, n)
+ HammerDowngradableRWMutex(1, 10, n)
+ HammerDowngradableRWMutex(4, 1, n)
+ HammerDowngradableRWMutex(4, 3, n)
+ HammerDowngradableRWMutex(4, 10, n)
+ HammerDowngradableRWMutex(10, 1, n)
+ HammerDowngradableRWMutex(10, 3, n)
+ HammerDowngradableRWMutex(10, 10, n)
+ HammerDowngradableRWMutex(10, 5, n)
+}
+
+func TestRWDoubleTryLock(t *testing.T) {
+ var rwm RWMutex
+ if !rwm.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ if rwm.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestRWTryLockAfterLock(t *testing.T) {
+ var rwm RWMutex
+ rwm.Lock()
+ if rwm.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestRWTryLockUnlock(t *testing.T) {
+ var rwm RWMutex
+ if !rwm.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ rwm.Unlock()
+ if !rwm.TryLock() {
+ t.Fatal("failed to aquire lock after unlock")
+ }
+}
+
+func TestTryRLockAfterLock(t *testing.T) {
+ var rwm RWMutex
+ rwm.Lock()
+ if rwm.TryRLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestTryLockAfterRLock(t *testing.T) {
+ var rwm RWMutex
+ rwm.RLock()
+ if rwm.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestDoubleTryRLock(t *testing.T) {
+ var rwm RWMutex
+ if !rwm.TryRLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ if !rwm.TryRLock() {
+ t.Fatal("failed to read aquire read locked lock")
+ }
+}
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
new file mode 100644
index 000000000..995c0346e
--- /dev/null
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -0,0 +1,198 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2019 The gVisor Authors.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+// This is mostly copied from the standard library's sync/rwmutex.go.
+//
+// Happens-before relationships indicated to the race detector:
+// - Unlock -> Lock (via writerSem)
+// - Unlock -> RLock (via readerSem)
+// - RUnlock -> Lock (via writerSem)
+// - DowngradeLock -> RLock (via readerSem)
+
+package sync
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+//go:linkname runtimeSemacquire sync.runtime_Semacquire
+func runtimeSemacquire(s *uint32)
+
+//go:linkname runtimeSemrelease sync.runtime_Semrelease
+func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
+
+// RWMutex is identical to sync.RWMutex, but adds the DowngradeLock,
+// TryLock and TryRLock methods.
+type RWMutex struct {
+ w Mutex // held if there are pending writers
+ writerSem uint32 // semaphore for writers to wait for completing readers
+ readerSem uint32 // semaphore for readers to wait for completing writers
+ readerCount int32 // number of pending readers
+ readerWait int32 // number of departing readers
+}
+
+const rwmutexMaxReaders = 1 << 30
+
+// TryRLock locks rw for reading. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryRLock() bool {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ for {
+ rc := atomic.LoadInt32(&rw.readerCount)
+ if rc < 0 {
+ if RaceEnabled {
+ RaceEnable()
+ }
+ return false
+ }
+ if !atomic.CompareAndSwapInt32(&rw.readerCount, rc, rc+1) {
+ continue
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.readerSem))
+ }
+ return true
+ }
+}
+
+// RLock locks rw for reading.
+func (rw *RWMutex) RLock() {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ if atomic.AddInt32(&rw.readerCount, 1) < 0 {
+ // A writer is pending, wait for it.
+ runtimeSemacquire(&rw.readerSem)
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.readerSem))
+ }
+}
+
+// RUnlock undoes a single RLock call.
+func (rw *RWMutex) RUnlock() {
+ if RaceEnabled {
+ RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
+ RaceDisable()
+ }
+ if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 {
+ if r+1 == 0 || r+1 == -rwmutexMaxReaders {
+ panic("RUnlock of unlocked RWMutex")
+ }
+ // A writer is pending.
+ if atomic.AddInt32(&rw.readerWait, -1) == 0 {
+ // The last reader unblocks the writer.
+ runtimeSemrelease(&rw.writerSem, false, 0)
+ }
+ }
+ if RaceEnabled {
+ RaceEnable()
+ }
+}
+
+// TryLock locks rw for writing. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryLock() bool {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ // First, resolve competition with other writers.
+ if !rw.w.TryLock() {
+ if RaceEnabled {
+ RaceEnable()
+ }
+ return false
+ }
+ // Only proceed if there are no readers.
+ if !atomic.CompareAndSwapInt32(&rw.readerCount, 0, -rwmutexMaxReaders) {
+ rw.w.Unlock()
+ if RaceEnabled {
+ RaceEnable()
+ }
+ return false
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.writerSem))
+ }
+ return true
+}
+
+// Lock locks rw for writing.
+func (rw *RWMutex) Lock() {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ // First, resolve competition with other writers.
+ rw.w.Lock()
+ // Announce to readers there is a pending writer.
+ r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders
+ // Wait for active readers.
+ if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 {
+ runtimeSemacquire(&rw.writerSem)
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.writerSem))
+ }
+}
+
+// Unlock unlocks rw for writing.
+func (rw *RWMutex) Unlock() {
+ if RaceEnabled {
+ RaceRelease(unsafe.Pointer(&rw.writerSem))
+ RaceRelease(unsafe.Pointer(&rw.readerSem))
+ RaceDisable()
+ }
+ // Announce to readers there is no active writer.
+ r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders)
+ if r >= rwmutexMaxReaders {
+ panic("Unlock of unlocked RWMutex")
+ }
+ // Unblock blocked readers, if any.
+ for i := 0; i < int(r); i++ {
+ runtimeSemrelease(&rw.readerSem, false, 0)
+ }
+ // Allow other writers to proceed.
+ rw.w.Unlock()
+ if RaceEnabled {
+ RaceEnable()
+ }
+}
+
+// DowngradeLock atomically unlocks rw for writing and locks it for reading.
+func (rw *RWMutex) DowngradeLock() {
+ if RaceEnabled {
+ RaceRelease(unsafe.Pointer(&rw.readerSem))
+ RaceDisable()
+ }
+ // Announce to readers there is no active writer and one additional reader.
+ r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1)
+ if r >= rwmutexMaxReaders+1 {
+ panic("DowngradeLock of unlocked RWMutex")
+ }
+ // Unblock blocked readers, if any. Note that this loop starts at 1 since r
+ // includes this goroutine.
+ for i := 1; i < int(r); i++ {
+ runtimeSemrelease(&rw.readerSem, false, 0)
+ }
+ // Allow other writers to proceed to rw.w.Lock(). Note that they will still
+ // block on rw.writerSem since at least this reader exists, such that
+ // DowngradeLock() is atomic with the previous write lock.
+ rw.w.Unlock()
+ if RaceEnabled {
+ RaceEnable()
+ }
+}
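+
+// A typical downgrade usage sketch: mutate under the write lock, then keep
+// reading without allowing another writer to intervene:
+//
+//	var rw RWMutex
+//	rw.Lock()
+//	// ... mutate shared data ...
+//	rw.DowngradeLock()
+//	// ... read shared data; other readers may also proceed ...
+//	rw.RUnlock()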
diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go
new file mode 100644
index 000000000..eda6fb131
--- /dev/null
+++ b/pkg/sync/seqatomic_unsafe.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package template doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package template
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Value is a required type parameter.
+//
+// Value must not contain any pointers, including interface objects, function
+// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs
+// containing any of the above. An init() function will panic if this property
+// does not hold.
+type Value struct{}
+
+// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
+// with any writer critical sections in sc.
+func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value {
+ // This function doesn't use SeqAtomicTryLoad because doing so is
+ // measurably, significantly (~20%) slower; Go is awful at inlining.
+ var val Value
+ for {
+ epoch := sc.BeginRead()
+ if sync.RaceEnabled {
+ // runtime.RaceDisable() doesn't actually stop the race detector,
+ // so it can't help us here. Instead, call runtime.memmove
+ // directly, which is not instrumented by the race detector.
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+ // This is ~40% faster for short reads than going through memmove.
+ val = *ptr
+ }
+ if sc.ReadOk(epoch) {
+ break
+ }
+ }
+ return val
+}
+
+// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section
+// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
+// would race with a writer critical section, SeqAtomicTryLoad returns
+// (unspecified, false).
+func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) {
+ var val Value
+ if sync.RaceEnabled {
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+ val = *ptr
+ }
+ return val, sc.ReadOk(epoch)
+}
+
+func init() {
+ var val Value
+ typ := reflect.TypeOf(val)
+ name := typ.Name()
+ if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
+ panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
+ }
+}
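+
+// Once instantiated (e.g. with Value = int and suffix "Int", as in the
+// seqatomictest BUILD rule), a reader looks like the following sketch:
+//
+//	var seq sync.SeqCount
+//	var data int
+//	v := SeqAtomicLoadInt(&seq, &data) // retries if it races with a writer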
diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD
new file mode 100644
index 000000000..5c38c783e
--- /dev/null
+++ b/pkg/sync/seqatomictest/BUILD
@@ -0,0 +1,31 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "seqatomic_int",
+ out = "seqatomic_int_unsafe.go",
+ package = "seqatomic",
+ suffix = "Int",
+ template = "//pkg/sync:generic_seqatomic",
+ types = {
+ "Value": "int",
+ },
+)
+
+go_library(
+ name = "seqatomic",
+ srcs = ["seqatomic_int_unsafe.go"],
+ deps = [
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "seqatomic_test",
+ size = "small",
+ srcs = ["seqatomic_test.go"],
+ library = ":seqatomic",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go
new file mode 100644
index 000000000..2c4568b07
--- /dev/null
+++ b/pkg/sync/seqatomictest/seqatomic_test.go
@@ -0,0 +1,132 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seqatomic
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestSeqAtomicLoadUncontended(t *testing.T) {
+ var seq sync.SeqCount
+ const want = 1
+ data := want
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+}
+
+func TestSeqAtomicLoadAfterWrite(t *testing.T) {
+ var seq sync.SeqCount
+ var data int
+ const want = 1
+ seq.BeginWrite()
+ data = want
+ seq.EndWrite()
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+}
+
+func TestSeqAtomicLoadDuringWrite(t *testing.T) {
+ var seq sync.SeqCount
+ var data int
+ const want = 1
+ seq.BeginWrite()
+ go func() {
+ time.Sleep(time.Second)
+ data = want
+ seq.EndWrite()
+ }()
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+}
+
+func TestSeqAtomicTryLoadUncontended(t *testing.T) {
+ var seq sync.SeqCount
+ const want = 1
+ data := want
+ epoch := seq.BeginRead()
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want {
+ t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want)
+ }
+}
+
+func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
+ var seq sync.SeqCount
+ var data int
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok {
+ t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got)
+ }
+ seq.EndWrite()
+}
+
+func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
+ var seq sync.SeqCount
+ var data int
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ seq.EndWrite()
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok {
+ t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got)
+ }
+}
+
+func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
+ var seq sync.SeqCount
+ const want = 42
+ data := want
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+ }
+ })
+}
+
+func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) {
+ var seq sync.SeqCount
+ const want = 42
+ data := want
+ b.RunParallel(func(pb *testing.PB) {
+ epoch := seq.BeginRead()
+ for pb.Next() {
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want {
+ b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want)
+ }
+ }
+ })
+}
+
+// For comparison:
+func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) {
+ var a atomic.Value
+ const want = 42
+ a.Store(int(want))
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ if got := a.Load().(int); got != want {
+ b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want)
+ }
+ }
+ })
+}
diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go
new file mode 100644
index 000000000..a1e895352
--- /dev/null
+++ b/pkg/sync/seqcount.go
@@ -0,0 +1,149 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "fmt"
+ "reflect"
+ "runtime"
+ "sync/atomic"
+)
+
+// SeqCount is a synchronization primitive for optimistic reader/writer
+// synchronization in cases where readers can work with stale data and
+// therefore do not need to block writers.
+//
+// Compared to sync/atomic.Value:
+//
+// - Mutation of SeqCount-protected data does not require memory allocation,
+// whereas atomic.Value generally does. This is a significant advantage when
+// writes are common.
+//
+// - Atomic reads of SeqCount-protected data require copying. This is a
+// disadvantage when atomic reads are common.
+//
+// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other
+// operations to be made atomic with reads of SeqCount-protected data.
+//
+// - SeqCount may be less flexible: as of this writing, SeqCount-protected data
+// cannot include pointers.
+//
+// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected
+// data require instantiating function templates using go_generics (see
+// seqatomic.go).
+type SeqCount struct {
+ // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd
+ // if a writer critical section is active, and a read from data protected
+ // by this SeqCount is atomic iff epoch is the same even value before and
+ // after the read.
+ epoch uint32
+}
+
+// SeqCountEpoch tracks writer critical sections in a SeqCount.
+type SeqCountEpoch struct {
+ val uint32
+}
+
+// We assume that:
+//
+// - All functions in sync/atomic that perform a memory read are at least a
+// read fence: memory reads before calls to such functions cannot be reordered
+// after the call, and memory reads after calls to such functions cannot be
+// reordered before the call, even if those reads do not use sync/atomic.
+//
+// - All functions in sync/atomic that perform a memory write are at least a
+// write fence: memory writes before calls to such functions cannot be
+// reordered after the call, and memory writes after calls to such functions
+// cannot be reordered before the call, even if those writes do not use
+// sync/atomic.
+//
+// As of this writing, the Go memory model completely fails to describe
+// sync/atomic, but these properties are implied by
+// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8.
+
+// BeginRead indicates the beginning of a reader critical section. Reader
+// critical sections DO NOT BLOCK writer critical sections, so operations in a
+// reader critical section MAY RACE with writer critical sections. Races are
+// detected by ReadOk at the end of the reader critical section. Thus, the
+// low-level structure of readers is generally:
+//
+// for {
+// epoch := seq.BeginRead()
+// // do something idempotent with seq-protected data
+// if seq.ReadOk(epoch) {
+// break
+// }
+// }
+//
+// However, since reader critical sections may race with writer critical
+// sections, the Go race detector will (accurately) flag data races in readers
+// using this pattern. Most users of SeqCount will need to use the
+// SeqAtomicLoad function template in seqatomic.go.
+func (s *SeqCount) BeginRead() SeqCountEpoch {
+ epoch := atomic.LoadUint32(&s.epoch)
+ for epoch&1 != 0 {
+ runtime.Gosched()
+ epoch = atomic.LoadUint32(&s.epoch)
+ }
+ return SeqCountEpoch{epoch}
+}
+
+// ReadOk returns true if the reader critical section initiated by a previous
+// call to BeginRead() that returned epoch did not race with any writer critical
+// sections.
+//
+// ReadOk may be called any number of times during a reader critical section.
+// Reader critical sections do not need to be explicitly terminated; the last
+// call to ReadOk is implicitly the end of the reader critical section.
+func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool {
+ return atomic.LoadUint32(&s.epoch) == epoch.val
+}
+
+// BeginWrite indicates the beginning of a writer critical section.
+//
+// SeqCount does not support concurrent writer critical sections; clients with
+// concurrent writers must synchronize them using e.g. sync.Mutex.
+func (s *SeqCount) BeginWrite() {
+ if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 {
+ panic("SeqCount.BeginWrite during writer critical section")
+ }
+}
+
+// EndWrite ends the effect of a preceding BeginWrite.
+func (s *SeqCount) EndWrite() {
+ if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 {
+ panic("SeqCount.EndWrite outside writer critical section")
+ }
+}
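+
+// A writer sketch; concurrent writers must be serialized externally, e.g.
+// with a Mutex (mu is an assumed sync.Mutex here):
+//
+//	mu.Lock()
+//	seq.BeginWrite()
+//	// ... mutate seq-protected data ...
+//	seq.EndWrite()
+//	mu.Unlock()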
+
+// PointersInType returns a list of pointers reachable from values named
+// valName of the given type.
+//
+// PointersInType is not exhaustive, but it is guaranteed that if typ contains
+// at least one pointer, then PointersInType returns a non-empty list.
+func PointersInType(typ reflect.Type, valName string) []string {
+ switch kind := typ.Kind(); kind {
+ case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ return nil
+
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer:
+ return []string{valName}
+
+ case reflect.Array:
+ return PointersInType(typ.Elem(), valName+"[]")
+
+ case reflect.Struct:
+ var ptrs []string
+ for i, n := 0, typ.NumField(); i < n; i++ {
+ field := typ.Field(i)
+ ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...)
+ }
+ return ptrs
+
+ default:
+ return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)}
+ }
+}
diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go
new file mode 100644
index 000000000..6eb7b4b59
--- /dev/null
+++ b/pkg/sync/seqcount_test.go
@@ -0,0 +1,153 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "reflect"
+ "testing"
+ "time"
+)
+
+func TestSeqCountWriteUncontended(t *testing.T) {
+ var seq SeqCount
+ seq.BeginWrite()
+ seq.EndWrite()
+}
+
+func TestSeqCountReadUncontended(t *testing.T) {
+ var seq SeqCount
+ epoch := seq.BeginRead()
+ if !seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got false, wanted true")
+ }
+}
+
+func TestSeqCountBeginReadAfterWrite(t *testing.T) {
+ var seq SeqCount
+ var data int32
+ const want = 1
+ seq.BeginWrite()
+ data = want
+ seq.EndWrite()
+ epoch := seq.BeginRead()
+ if data != want {
+ t.Errorf("Reader: got %v, wanted %v", data, want)
+ }
+ if !seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got false, wanted true")
+ }
+}
+
+func TestSeqCountBeginReadDuringWrite(t *testing.T) {
+ var seq SeqCount
+ var data int
+ const want = 1
+ seq.BeginWrite()
+ go func() {
+ time.Sleep(time.Second)
+ data = want
+ seq.EndWrite()
+ }()
+ epoch := seq.BeginRead()
+ if data != want {
+ t.Errorf("Reader: got %v, wanted %v", data, want)
+ }
+ if !seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got false, wanted true")
+ }
+}
+
+func TestSeqCountReadOkAfterWrite(t *testing.T) {
+ var seq SeqCount
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ seq.EndWrite()
+ if seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got true, wanted false")
+ }
+}
+
+func TestSeqCountReadOkDuringWrite(t *testing.T) {
+ var seq SeqCount
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ if seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got true, wanted false")
+ }
+ seq.EndWrite()
+}
+
+func BenchmarkSeqCountWriteUncontended(b *testing.B) {
+ var seq SeqCount
+ for i := 0; i < b.N; i++ {
+ seq.BeginWrite()
+ seq.EndWrite()
+ }
+}
+
+func BenchmarkSeqCountReadUncontended(b *testing.B) {
+ var seq SeqCount
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ epoch := seq.BeginRead()
+ if !seq.ReadOk(epoch) {
+ b.Fatalf("ReadOk: got false, wanted true")
+ }
+ }
+ })
+}
+
+func TestPointersInType(t *testing.T) {
+ for _, test := range []struct {
+ name string // used for both test and value name
+ val interface{}
+ ptrs []string
+ }{
+ {
+ name: "EmptyStruct",
+ val: struct{}{},
+ },
+ {
+ name: "Int",
+ val: int(0),
+ },
+ {
+ name: "MixedStruct",
+ val: struct {
+ b bool
+ I int
+ ExportedPtr *struct{}
+ unexportedPtr *struct{}
+ arr [2]int
+ ptrArr [2]*int
+ nestedStruct struct {
+ nestedNonptr int
+ nestedPtr *int
+ }
+ structArr [1]struct {
+ nonptr int
+ ptr *int
+ }
+ }{},
+ ptrs: []string{
+ "MixedStruct.ExportedPtr",
+ "MixedStruct.unexportedPtr",
+ "MixedStruct.ptrArr[]",
+ "MixedStruct.nestedStruct.nestedPtr",
+ "MixedStruct.structArr[].ptr",
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ typ := reflect.TypeOf(test.val)
+ ptrs := PointersInType(typ, test.name)
+ t.Logf("Found pointers: %v", ptrs)
+ if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) {
+ t.Errorf("Got %v, wanted %v", ptrs, test.ptrs)
+ }
+ })
+ }
+}
diff --git a/pkg/sync/sync.go b/pkg/sync/sync.go
new file mode 100644
index 000000000..b16cf5333
--- /dev/null
+++ b/pkg/sync/sync.go
@@ -0,0 +1,7 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package sync provides synchronization primitives.
+package sync
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
new file mode 100644
index 000000000..0500a22cf
--- /dev/null
+++ b/pkg/syncevent/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "syncevent",
+ srcs = [
+ "broadcaster.go",
+ "receiver.go",
+ "source.go",
+ "syncevent.go",
+ "waiter_amd64.s",
+ "waiter_arm64.s",
+ "waiter_asm_unsafe.go",
+ "waiter_noasm_unsafe.go",
+ "waiter_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "syncevent_test",
+ size = "small",
+ srcs = [
+ "broadcaster_test.go",
+ "syncevent_example_test.go",
+ "waiter_test.go",
+ ],
+ library = ":syncevent",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go
new file mode 100644
index 000000000..4bff59e7d
--- /dev/null
+++ b/pkg/syncevent/broadcaster.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Broadcaster is an implementation of Source that supports any number of
+// subscribed Receivers.
+//
+// The zero value of Broadcaster is valid and has no subscribed Receivers.
+// Broadcaster is not copyable by value.
+//
+// All Broadcaster methods may be called concurrently from multiple goroutines.
+type Broadcaster struct {
+ // Broadcaster is implemented as a hash table where keys are assigned by
+ // the Broadcaster and returned as SubscriptionIDs, making it safe to use
+ // the identity function for hashing. The hash table resolves collisions
+ // using linear probing and features Robin Hood insertion and backward
+ // shift deletion in order to support a relatively high load factor
+ // efficiently, which matters since the cost of Broadcast is linear in the
+ // size of the table.
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // Invariants: len(table) is 0 or a power of 2.
+ table []broadcasterSlot
+
+ // load is the number of entries in table with receiver != nil.
+ load int
+
+ lastID SubscriptionID
+}
+
+type broadcasterSlot struct {
+ // Invariants: If receiver == nil, then filter == NoEvents and id == 0.
+ // Otherwise, id != 0.
+ receiver *Receiver
+ filter Set
+ id SubscriptionID
+}
+
+const (
+ broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1
+
+ broadcasterMaxLoadNum = 13
+ broadcasterMaxLoadDen = 16
+)
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID {
+ b.mu.Lock()
+
+ // Assign an ID for this subscription.
+ b.lastID++
+ id := b.lastID
+
+ // Expand the table if over the maximum load factor:
+ //
+ // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen
+ // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table)
+ b.load++
+ if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) {
+ // Double the number of slots in the new table.
+ newlen := broadcasterMinNonZeroTableSize
+ if len(b.table) != 0 {
+ newlen = 2 * len(b.table)
+ }
+ if newlen <= cap(b.table) {
+ // Reuse excess capacity in the current table, moving entries not
+ // already in their first-probed positions to better ones.
+ newtable := b.table[:newlen]
+ newmask := uint64(newlen - 1)
+ for i := range b.table {
+ if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) {
+ entry := b.table[i]
+ b.table[i] = broadcasterSlot{}
+ broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter)
+ }
+ }
+ b.table = newtable
+ } else {
+ newtable := make([]broadcasterSlot, newlen)
+ // Copy existing entries to the new table.
+ for i := range b.table {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ }
+ }
+ // Switch to the new table.
+ b.table = newtable
+ }
+ }
+
+ broadcasterTableInsert(b.table, id, r, filter)
+ b.mu.Unlock()
+ return id
+}
+
+// Preconditions: table must not be full. len(table) is a power of 2.
+func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) {
+ entry := broadcasterSlot{
+ receiver: r,
+ filter: filter,
+ id: id,
+ }
+ mask := uint64(len(table) - 1)
+ i := uint64(id) & mask
+ disp := uint64(0)
+ for {
+ if table[i].receiver == nil {
+ table[i] = entry
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotDisp := (i - uint64(table[i].id)) & mask
+ if disp > slotDisp {
+ table[i], entry = entry, table[i]
+ disp = slotDisp
+ }
+ i = (i + 1) & mask
+ disp++
+ }
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) {
+ b.mu.Lock()
+
+ mask := uint64(len(b.table) - 1)
+ i := uint64(id) & mask
+ for {
+ if b.table[i].id == id {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ next := (i + 1) & mask
+ if b.table[next].receiver == nil {
+ break
+ }
+ if uint64(b.table[next].id)&mask == next {
+ break
+ }
+ b.table[i] = b.table[next]
+ i = next
+ }
+ b.table[i] = broadcasterSlot{}
+ break
+ }
+ i = (i + 1) & mask
+ }
+
+ // If a table 1/4 of the current size would still be at or under the
+ // maximum load factor (i.e. the current table size is at least two
+ // expansions bigger than necessary), halve the size of the table to reduce
+ // the cost of Broadcast. Since we are concerned with iteration time and
+ // not memory usage, reuse the existing slice to reduce future allocations
+ // from table re-expansion.
+ b.load--
+ if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) {
+ newlen := len(b.table) / 2
+ newtable := b.table[:newlen]
+ for i := newlen; i < len(b.table); i++ {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ b.table[i] = broadcasterSlot{}
+ }
+ }
+ b.table = newtable
+ }
+
+ b.mu.Unlock()
+}
+
+// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset
+// of events to which they subscribed. The order in which Receivers are
+// notified is unspecified.
+func (b *Broadcaster) Broadcast(events Set) {
+ b.mu.Lock()
+ for i := range b.table {
+ if intersection := events & b.table[i].filter; intersection != 0 {
+ // We don't need to check if broadcasterSlot.receiver is nil, since
+ // if it is then broadcasterSlot.filter is 0.
+ b.table[i].receiver.Notify(intersection)
+ }
+ }
+ b.mu.Unlock()
+}
+
+// FilteredEvents returns the set of events for which Broadcast will notify at
+// least one Receiver, i.e. the union of filters for all subscribed Receivers.
+func (b *Broadcaster) FilteredEvents() Set {
+ var es Set
+ b.mu.Lock()
+ for i := range b.table {
+ es |= b.table[i].filter
+ }
+ b.mu.Unlock()
+ return es
+}
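+
+// A subscription/broadcast sketch, using the Waiter type from this package
+// (see waiter_unsafe.go and broadcaster_test.go):
+//
+//	var br Broadcaster
+//	var w Waiter
+//	w.Init()
+//	id := br.SubscribeEvents(w.Receiver(), 1)
+//	br.Broadcast(1) // w.Pending() now includes event 1
+//	w.Ack(1)
+//	br.UnsubscribeEvents(id)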
diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go
new file mode 100644
index 000000000..e88779e23
--- /dev/null
+++ b/pkg/syncevent/broadcaster_test.go
@@ -0,0 +1,376 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestBroadcasterFilter(t *testing.T) {
+ const numReceivers = 2 * MaxEvents
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents))
+ }
+ for ev := 0; ev < MaxEvents; ev++ {
+ br.Broadcast(1 << ev)
+ for i := range ws {
+ want := NoEvents
+ if i%MaxEvents == ev {
+ want = 1 << ev
+ }
+ if got := ws[i].Receiver().PendingAndAckAll(); got != want {
+ t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want)
+ }
+ }
+ }
+}
+
+// TestBroadcasterManySubscriptions tests that subscriptions are not lost by
+// table expansion/compaction.
+func TestBroadcasterManySubscriptions(t *testing.T) {
+ const numReceivers = 5000 // arbitrary
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ ids := make([]SubscriptionID, numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Subscribe receiver i.
+ ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1)
+ // Check that receivers [0, i] are subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[j].Pending() != 1 {
+ t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i)
+ }
+ ws[j].Ack(1)
+ }
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Unsubscribe receiver unsub[i].
+ br.UnsubscribeEvents(ids[unsub[i]])
+ // Check that receivers unsub[0] through unsub[i] are no longer subscribed,
+ // and that receivers unsub[i+1] through unsub[numReceivers-1] are still
+ // subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[unsub[j]].Pending() != 0 {
+ t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ }
+ for j := i + 1; j < numReceivers; j++ {
+ if ws[unsub[j]].Pending() != 1 {
+ t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ ws[unsub[j]].Ack(1)
+ }
+ }
+}
+
+var (
+ receiverCountsNonZero = []int{1, 4, 16, 64}
+ receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...)
+)
+
+// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage
+// pattern X (described in terms of Broadcaster) with Broadcaster, a
+// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively.
+
+// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe
+// cycle.
+
+func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) {
+ var br Broadcaster
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ id := br.SubscribeEvents(w.Receiver(), 1)
+ br.UnsubscribeEvents(id)
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribe(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m[w.Receiver()] = Set(1)
+ mu.Unlock()
+ mu.Lock()
+ delete(m, w.Receiver())
+ mu.Unlock()
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) {
+ var q waiter.Queue
+ e, _ := waiter.NewChannelEntry(nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.EventRegister(&e, 1)
+ q.EventUnregister(&e)
+ }
+}
+
+// BenchmarkXxxSubscribeUnsubscribeBatch is similar to
+// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large
+// number of Receivers at a time in order to measure the amortized overhead of
+// table expansion/compaction. (Since waiter.Queue is implemented using a
+// linked list, BenchmarkQueueSubscribeUnsubscribe and
+// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same
+// result.)
+
+const numBatchReceivers = 1000
+
+func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+ ids := make([]SubscriptionID, numBatchReceivers)
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ br.UnsubscribeEvents(ids[unsub[j]])
+ }
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ m[ws[j].Receiver()] = Set(1)
+ mu.Unlock()
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ delete(m, ws[unsub[j]].Receiver())
+ mu.Unlock()
+ }
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) {
+ var q waiter.Queue
+ es := make([]waiter.Entry, numBatchReceivers)
+ for i := range es {
+ es[i], _ = waiter.NewChannelEntry(nil)
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventRegister(&es[j], 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventUnregister(&es[unsub[j]])
+ }
+ }
+}
+
+// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast
+// already-pending events to multiple Receivers.
+
+func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+ br.Broadcast(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ for i := 0; i < n; i++ {
+ e, _ := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ }
+ q.Notify(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ }
+ })
+ }
+}
+
+// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to
+// multiple Receivers, check that all Receivers have received the event, and
+// clear the event from all Receivers.
+
+func BenchmarkBroadcasterBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ chs := make([]chan struct{}, n)
+ for i := range chs {
+ e, ch := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ chs[i] = ch
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ for _, ch := range chs {
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("channel did not receive event")
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go
new file mode 100644
index 000000000..5c86e5400
--- /dev/null
+++ b/pkg/syncevent/receiver.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+)
+
+// Receiver is an event sink that holds pending events and invokes a callback
+// whenever new events become pending. Receiver's methods may be called
+// concurrently from multiple goroutines.
+//
+// Receiver.Init() must be called before first use.
+type Receiver struct {
+ // pending is the set of pending events. pending is accessed using atomic
+ // memory operations.
+ pending uint64
+
+ // cb is notified when new events become pending. cb is immutable after
+ // Init().
+ cb ReceiverCallback
+}
+
+// ReceiverCallback receives callbacks from a Receiver.
+type ReceiverCallback interface {
+ // NotifyPending is called when the corresponding Receiver has new pending
+ // events.
+ //
+ // NotifyPending is called synchronously from Receiver.Notify(), so
+ // implementations must not take locks that may be held by callers of
+ // Receiver.Notify(). NotifyPending may be called concurrently from
+ // multiple goroutines.
+ NotifyPending()
+}
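+
+// chanCallback is a minimal sketch of a ReceiverCallback (illustrative only;
+// nothing in this package uses it). It coalesces notifications onto a
+// buffered channel and takes no locks, satisfying the constraints above.
+type chanCallback struct {
+ ch chan struct{} // must be buffered with capacity >= 1
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending.
+func (c *chanCallback) NotifyPending() {
+ select {
+ case c.ch <- struct{}{}:
+ default:
+ // A wakeup is already pending; coalesce with it.
+ }
+}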
+
+// Init must be called before first use of r.
+func (r *Receiver) Init(cb ReceiverCallback) {
+ r.cb = cb
+}
+
+// Pending returns the set of pending events.
+func (r *Receiver) Pending() Set {
+ return Set(atomic.LoadUint64(&r.pending))
+}
+
+// Notify sets the given events as pending.
+func (r *Receiver) Notify(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already pending.
+ if p&es == es {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-OR because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) {
+ // If the CAS fails, fall back to atomic-OR.
+ atomicbitops.OrUint64(&r.pending, uint64(es))
+ }
+ r.cb.NotifyPending()
+}
+
+// Ack unsets the given events as pending.
+func (r *Receiver) Ack(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already not pending.
+ if p&es == 0 {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-AND because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) {
+ // If the CAS fails, fall back to atomic-AND.
+ atomicbitops.AndUint64(&r.pending, ^uint64(es))
+ }
+}
+
+// PendingAndAckAll unsets all events as pending and returns the set of
+// previously-pending events.
+//
+// PendingAndAckAll should only be used in preference to a call to Pending
+// followed by a conditional call to Ack when the caller expects events to be
+// pending (e.g. after a call to ReceiverCallback.NotifyPending()).
+func (r *Receiver) PendingAndAckAll() Set {
+ return Set(atomic.SwapUint64(&r.pending, 0))
+}
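+
+// As a sketch of the intended pattern (drainingCallback and handle are
+// illustrative, not part of this package), a callback that consumes events
+// as they arrive might look like:
+//
+// func (c *drainingCallback) NotifyPending() {
+// if es := c.r.PendingAndAckAll(); es != NoEvents {
+// c.handle(es)
+// }
+// }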
diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go
new file mode 100644
index 000000000..ddffb171a
--- /dev/null
+++ b/pkg/syncevent/source.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+// Source represents an event source.
+type Source interface {
+ // SubscribeEvents causes the Source to notify the given Receiver of the
+ // given subset of events.
+ //
+ // Preconditions: r != nil. The ReceiverCallback for r must not take locks
+ // that are ordered prior to the Source; for example, it cannot call any
+ // Source methods.
+ SubscribeEvents(r *Receiver, filter Set) SubscriptionID
+
+ // UnsubscribeEvents causes the Source to stop notifying the Receiver
+ // subscribed by a previous call to SubscribeEvents that returned the given
+ // SubscriptionID.
+ //
+ // Preconditions: UnsubscribeEvents may be called at most once for any
+ // given SubscriptionID.
+ UnsubscribeEvents(id SubscriptionID)
+}
+
+// SubscriptionID identifies a call to Source.SubscribeEvents.
+type SubscriptionID uint64
+
+// UnsubscribeAndAck is a convenience function that unsubscribes r from the
+// given events on src and clears those events from r.
+func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) {
+ src.UnsubscribeEvents(id)
+ r.Ack(filter)
+}
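+
+// blockUntil is a sketch (illustrative only; nothing in this package uses
+// it) of the intended subscribe/wait/unsubscribe lifecycle against an
+// arbitrary Source.
+func blockUntil(src Source, es Set) {
+ w := GetWaiter()
+ // Defers run in reverse order: unsubscribe and ack first, then return
+ // the clean Waiter to the pool.
+ defer PutWaiter(w)
+ id := src.SubscribeEvents(w.Receiver(), es)
+ defer UnsubscribeAndAck(src, w.Receiver(), es, id)
+ w.WaitFor(es)
+}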
+
+// NoopSource implements Source by never sending events to subscribed
+// Receivers.
+type NoopSource struct{}
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID {
+ return 0
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (NoopSource) UnsubscribeEvents(SubscriptionID) {
+}
+
+// See Broadcaster for a non-noop implementation of Source.
diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go
new file mode 100644
index 000000000..9fb6a06de
--- /dev/null
+++ b/pkg/syncevent/syncevent.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syncevent provides efficient primitives for goroutine
+// synchronization based on event bitmasks.
+package syncevent
+
+// Set is a bitmask where each bit represents a distinct user-defined event.
+// The syncevent package does not treat any bits in Set specially.
+type Set uint64
+
+const (
+ // NoEvents is a Set containing no events.
+ NoEvents = Set(0)
+
+ // AllEvents is a Set containing all possible events.
+ AllEvents = ^Set(0)
+
+ // MaxEvents is the number of distinct events that can be represented by a Set.
+ MaxEvents = 64
+)
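+
+// Callers typically carve a Set into named events with shifted bits, as in
+// the package example tests. An illustrative declaration (the names are not
+// part of this package):
+//
+// const (
+// evReadable = Set(1 << iota)
+// evWritable
+// evHangup
+// )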
diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go
new file mode 100644
index 000000000..bfb18e2ea
--- /dev/null
+++ b/pkg/syncevent/syncevent_example_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+)
+
+func Example_ioReadinessInterruptible() {
+ const (
+ evReady = Set(1 << iota)
+ evInterrupt
+ )
+ errNotReady := fmt.Errorf("not ready for I/O")
+
+ // State of some I/O object.
+ var (
+ br Broadcaster
+ ready uint32
+ )
+ doIO := func() error {
+ if atomic.LoadUint32(&ready) == 0 {
+ return errNotReady
+ }
+ return nil
+ }
+ go func() {
+ // The I/O object eventually becomes ready for I/O.
+ time.Sleep(100 * time.Millisecond)
+ // When it does, it first ensures that future calls to doIO() succeed,
+ // then broadcasts the readiness event to Receivers.
+ atomic.StoreUint32(&ready, 1)
+ br.Broadcast(evReady)
+ }()
+
+ // Each user of the I/O object owns a Waiter.
+ var w Waiter
+ w.Init()
+ // The Waiter may be asynchronously interruptible, e.g. for signal
+ // handling in the sentry.
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ w.Receiver().Notify(evInterrupt)
+ }()
+
+ // To use the I/O object:
+ //
+ // Optionally, if the I/O object is likely to be ready, attempt I/O first.
+ err := doIO()
+ if err == nil {
+ // Success, we're done.
+ return /* nil */
+ }
+ if err != errNotReady {
+ // Failure, I/O failed for some reason other than readiness.
+ return /* err */
+ }
+ // Subscribe for readiness events from the I/O object.
+ id := br.SubscribeEvents(w.Receiver(), evReady)
+ // When we are finished blocking, unsubscribe from readiness events and
+ // remove readiness events from the pending event set.
+ defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id)
+ for {
+ // Attempt I/O again. This must be done after the call to SubscribeEvents,
+ // since the I/O object might have become ready between the previous call
+ // to doIO and the call to SubscribeEvents.
+ err = doIO()
+ if err == nil {
+ return /* nil */
+ }
+ if err != errNotReady {
+ return /* err */
+ }
+ // Block until either the I/O object indicates it is ready, or we are
+ // interrupted.
+ events := w.Wait()
+ if events&evInterrupt != 0 {
+ // In the specific case of sentry signal handling, signal delivery
+ // is handled by another system, so we aren't responsible for
+ // acknowledging evInterrupt.
+ return /* errInterrupted */
+ }
+ // Note that, in a concurrent context, the I/O object might become
+ // ready and then not ready again. To handle this:
+ //
+ // - evReady must be acknowledged before calling doIO() again (rather
+ // than after), so that if the I/O object becomes ready *again* after
+ // the call to doIO(), the readiness event is not lost.
+ //
+ // - We must loop instead of just calling doIO() once after receiving
+ // evReady.
+ w.Ack(evReady)
+ }
+}
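+
+// Example_waitForSketch is a minimal sketch (the event names are
+// illustrative) of WaitFor: it blocks only until one of the requested
+// events is pending, but returns the full pending set.
+func Example_waitForSketch() {
+ const (
+ evA = Set(1 << iota)
+ evB
+ )
+ var w Waiter
+ w.Init()
+ // evB becomes pending first, but WaitFor(evA) keeps blocking until evA
+ // is also pending.
+ w.Notify(evB)
+ go w.Notify(evA)
+ p := w.WaitFor(evA)
+ fmt.Println(p == evA|evB)
+ w.Ack(p)
+ // Output: true
+}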
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s
new file mode 100644
index 000000000..985b56ae5
--- /dev/null
+++ b/pkg/syncevent/waiter_amd64.s
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVQ g+0(FP), DI
+ MOVQ wg+8(FP), SI
+
+ MOVQ $·preparingG(SB), AX
+ LOCK
+ CMPXCHGQ DI, 0(SI)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
+
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s
new file mode 100644
index 000000000..20d7ac23b
--- /dev/null
+++ b/pkg/syncevent/waiter_arm64.s
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVD wg+8(FP), R0
+ MOVD $·preparingG(SB), R1
+ MOVD g+0(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
+
diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go
new file mode 100644
index 000000000..0995e9053
--- /dev/null
+++ b/pkg/syncevent/waiter_asm_unsafe.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 arm64
+
+package syncevent
+
+import (
+ "unsafe"
+)
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
new file mode 100644
index 000000000..1c4b0e39a
--- /dev/null
+++ b/pkg/syncevent/waiter_noasm_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// waiterUnlock is called from g0, so when the race detector is enabled,
+// waiterUnlock must be implemented in assembly since no race context is
+// available.
+//
+// +build !race
+// +build !amd64,!arm64
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// waiterUnlock is the "unlock function" passed to runtime.gopark by
+// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
+// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
+// should be aborted.
+//
+//go:nosplit
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g)
+}
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
new file mode 100644
index 000000000..3c8cbcdd8
--- /dev/null
+++ b/pkg/syncevent/waiter_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestWaiterAlreadyPending(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ w.Notify(want)
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterAsyncNotify(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ w.Notify(want)
+ }()
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterWaitFor(t *testing.T) {
+ var w Waiter
+ w.Init()
+ evWaited := Set(1)
+ evOther := Set(2)
+ w.Notify(evOther)
+ notifiedEvent := uint32(0)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ atomic.StoreUint32(&notifiedEvent, 1)
+ w.Notify(evWaited)
+ }()
+ if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want {
+ t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want)
+ }
+ if atomic.LoadUint32(&notifiedEvent) == 0 {
+ t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event")
+ }
+}
+
+func TestWaiterWaitAndAckAll(t *testing.T) {
+ var w Waiter
+ w.Init()
+ w.Notify(AllEvents)
+ if got := w.WaitAndAckAll(); got != AllEvents {
+ t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents)
+ }
+ if got := w.Pending(); got != NoEvents {
+ t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got)
+ }
+}
+
+// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage
+// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and
+// buffered chan struct{} respectively. When the maximum number of event
+// sources is relevant, we use 3 event sources because this is representative
+// of the kernel.Task.block() use case: an interrupt source, a timeout source,
+// and the actual event source being waited on.
+
+// Event set used by most benchmarks.
+const evBench Set = 1
+
+// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of
+// an event that is already pending.
+
+func BenchmarkWaiterNotifyRedundant(b *testing.B) {
+ var w Waiter
+ w.Init()
+ w.Notify(evBench)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyRedundant(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+ w.Assert()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+}
+
+func BenchmarkChannelNotifyRedundant(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter of
+// an event, return that event using a blocking check, and then unset the
+// event as pending.
+
+func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+
+ // wait + ack
+ <-ch
+ }
+}
+
+// BenchmarkSleeperMultiNotifyWaitAck is equivalent to
+// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a
+// temporary sleep.Waker. This is necessary when multiple goroutines may wait
+// for the same event, since each sleep.Waker can wake only a single
+// sleep.Sleeper.
+//
+// The syncevent package does not require a distinct object for each
+// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and
+// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state
+// for channels, runtime.sudog, is inescapably runtime-allocated, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would
+// also be identical.
+
+func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ // The sleep package doesn't provide sync.Pool allocation of Wakers;
+ // we do so here for a fairer comparison.
+ wakerPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Waker{}
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := wakerPool.Get().(*sleep.Waker)
+ s.AddWaker(w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ wakerPool.Put(w)
+ }
+}
+
+// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also
+// includes allocation of a temporary Waiter. This models the case where a
+// goroutine not already associated with a Waiter needs one in order to block.
+//
+// The analogous state for channels is built into runtime.g, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be
+// identical.
+
+func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := GetWaiter()
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ PutWaiter(w)
+ }
+}
+
+func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
+ // The sleep package doesn't provide sync.Pool allocation of Sleepers;
+ // we do so here for a fairer comparison.
+ sleeperPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Sleeper{}
+ },
+ }
+ var w sleep.Waker
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := sleeperPool.Get().(*sleep.Sleeper)
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ sleeperPool.Put(s)
+ }
+}
+
+// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows
+// for multiple event sources.
+
+func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ws[0].Assert()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, wanted 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+
+ // wait + clear
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an
+// event while another goroutine signals the event. This assumes that a new
+// goroutine doesn't run immediately (i.e. the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+
+func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but
+// allows for multiple event sources.
+
+func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ws[0].Assert()
+ }()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, expected 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+ }()
+
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
new file mode 100644
index 000000000..ad271e1a0
--- /dev/null
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g unsafe.Pointer, traceskip int)
+
+const (
+ waitReasonSelect = 9 // Go: src/runtime/runtime2.go
+ traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
+)
+
+// Waiter allows a goroutine to block on pending events received by a Receiver.
+//
+// Waiter.Init() must be called before first use.
+type Waiter struct {
+ r Receiver
+
+ // g is one of:
+ //
+ // - nil: No goroutine is blocking in Wait.
+ //
+ // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // completed waiterUnlock(). Thus the wait can only be interrupted by
+ // replacing the value of g with nil (the G may not be in state Gwaiting
+ // yet, so we can't call goready.)
+ //
+ // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
+ // goroutine blocked in Wait, which can only be woken by calling goready.
+ g unsafe.Pointer `state:"zerovalue"`
+}
+
+// Sentinel object for Waiter.g.
+var preparingG struct{}
+
+// Init must be called before first use of w.
+func (w *Waiter) Init() {
+ w.r.Init(w)
+}
+
+// Receiver returns the Receiver that receives events that unblock calls to
+// w.Wait().
+func (w *Waiter) Receiver() *Receiver {
+ return &w.r
+}
+
+// Pending returns the set of pending events.
+func (w *Waiter) Pending() Set {
+ return w.r.Pending()
+}
+
+// Wait blocks until at least one event is pending, then returns the set of
+// pending events. It does not affect the set of pending events; callers must
+// call w.Ack() to do so, or use w.WaitAndAckAll() instead.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) Wait() Set {
+ return w.WaitFor(AllEvents)
+}
+
+// WaitFor blocks until at least one event in es is pending, then returns the
+// set of pending events (including those not in es). It does not affect the
+// set of pending events; callers must call w.Ack() to do so.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitFor(es Set) Set {
+ for {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending.
+ if p := w.r.Pending(); p&es != NoEvents {
+ return p
+ }
+
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if p := w.r.Pending(); p&es != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+}
+
+// Ack marks the given events as not pending.
+func (w *Waiter) Ack(es Set) {
+ w.r.Ack(es)
+}
+
+// WaitAndAckAll blocks until at least one event is pending, then marks all
+// events as not pending and returns the set of previously-pending events.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitAndAckAll() Set {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending. Call Pending() first since, in the common case that events are
+ // not yet pending, this skips an atomic swap on w.r.pending.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+
+ for {
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+
+ // Check for pending events. We call PendingAndAckAll() directly now since
+ // we only expect to be woken after events become pending.
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+}
+
+// Notify marks the given events as pending, possibly unblocking concurrent
+// calls to w.Wait() or w.WaitFor().
+func (w *Waiter) Notify(es Set) {
+ w.r.Notify(es)
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter
+// should not call NotifyPending.
+func (w *Waiter) NotifyPending() {
+ // Optimization: Skip the atomic swap on w.g if there is no sleeping
+ // goroutine. NotifyPending is called after w.r.Pending() is updated, so
+ // concurrent and future calls to w.Wait() will observe pending events and
+ // abort sleeping.
+ if atomic.LoadPointer(&w.g) == nil {
+ return
+ }
+ // Wake a sleeping G, or prevent a G that is preparing to sleep from doing
+ // so. Swap is needed here to ensure that only one call to NotifyPending
+ // calls goready.
+ if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
+ goready(g, 0)
+ }
+}
+
+var waiterPool = sync.Pool{
+ New: func() interface{} {
+ w := &Waiter{}
+ w.Init()
+ return w
+ },
+}
+
+// GetWaiter returns an unused Waiter. PutWaiter should be called to release
+// the Waiter once it is no longer needed.
+//
+// Where possible, users should prefer to associate each goroutine that calls
+// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of
+// Waiters in hot paths.
+func GetWaiter() *Waiter {
+ return waiterPool.Get().(*Waiter)
+}
+
+// PutWaiter releases an unused Waiter previously returned by GetWaiter.
+func PutWaiter(w *Waiter) {
+ waiterPool.Put(w)
+}
diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD
new file mode 100644
index 000000000..7d760344a
--- /dev/null
+++ b/pkg/syserr/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "syserr",
+ srcs = [
+ "host_linux.go",
+ "netstack.go",
+ "syserr.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go
new file mode 100644
index 000000000..fc6ef60a1
--- /dev/null
+++ b/pkg/syserr/host_linux.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package syserr
+
+import (
+ "fmt"
+ "syscall"
+)
+
+const maxErrno = 134
+
+type linuxHostTranslation struct {
+ err *Error
+ ok bool
+}
+
+var linuxHostTranslations [maxErrno]linuxHostTranslation
+
+// FromHost translates a syscall.Errno to a corresponding Error value.
+func FromHost(err syscall.Errno) *Error {
+ if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
+ panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err))
+ }
+ return linuxHostTranslations[err].err
+}
+
+func addLinuxHostTranslation(host syscall.Errno, trans *Error) {
+ if linuxHostTranslations[host].ok {
+ panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+ }
+ linuxHostTranslations[host] = linuxHostTranslation{err: trans, ok: true}
+}
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
new file mode 100644
index 000000000..8ff922c69
--- /dev/null
+++ b/pkg/syserr/netstack.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syserr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// Mapping for tcpip.Error types.
+var (
+ ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
+ ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL)
+ ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
+ ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
+ ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
+ ErrDuplicateAddress = New(tcpip.ErrDuplicateAddress.String(), linux.EEXIST)
+ ErrBadLinkEndpoint = New(tcpip.ErrBadLinkEndpoint.String(), linux.EINVAL)
+ ErrAlreadyBound = New(tcpip.ErrAlreadyBound.String(), linux.EINVAL)
+ ErrInvalidEndpointState = New(tcpip.ErrInvalidEndpointState.String(), linux.EINVAL)
+ ErrAlreadyConnecting = New(tcpip.ErrAlreadyConnecting.String(), linux.EALREADY)
+ ErrNoPortAvailable = New(tcpip.ErrNoPortAvailable.String(), linux.EAGAIN)
+ ErrPortInUse = New(tcpip.ErrPortInUse.String(), linux.EADDRINUSE)
+ ErrBadLocalAddress = New(tcpip.ErrBadLocalAddress.String(), linux.EADDRNOTAVAIL)
+ ErrClosedForSend = New(tcpip.ErrClosedForSend.String(), linux.EPIPE)
+ ErrClosedForReceive = New(tcpip.ErrClosedForReceive.String(), nil)
+ ErrTimeout = New(tcpip.ErrTimeout.String(), linux.ETIMEDOUT)
+ ErrAborted = New(tcpip.ErrAborted.String(), linux.EPIPE)
+ ErrConnectStarted = New(tcpip.ErrConnectStarted.String(), linux.EINPROGRESS)
+ ErrDestinationRequired = New(tcpip.ErrDestinationRequired.String(), linux.EDESTADDRREQ)
+ ErrNotSupported = New(tcpip.ErrNotSupported.String(), linux.EOPNOTSUPP)
+ ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY)
+ ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT)
+ ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
+ ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
+ ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
+)
+
+var netstackErrorTranslations = map[*tcpip.Error]*Error{
+ tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID: ErrUnknownNICID,
+ tcpip.ErrUnknownDevice: ErrUnknownDevice,
+ tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
+ tcpip.ErrNoRoute: ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound: ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
+ tcpip.ErrPortInUse: ErrPortInUse,
+ tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
+ tcpip.ErrClosedForSend: ErrClosedForSend,
+ tcpip.ErrClosedForReceive: ErrClosedForReceive,
+ tcpip.ErrWouldBlock: ErrWouldBlock,
+ tcpip.ErrConnectionRefused: ErrConnectionRefused,
+ tcpip.ErrTimeout: ErrTimeout,
+ tcpip.ErrAborted: ErrAborted,
+ tcpip.ErrConnectStarted: ErrConnectStarted,
+ tcpip.ErrDestinationRequired: ErrDestinationRequired,
+ tcpip.ErrNotSupported: ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected: ErrNotConnected,
+ tcpip.ErrConnectionReset: ErrConnectionReset,
+ tcpip.ErrConnectionAborted: ErrConnectionAborted,
+ tcpip.ErrNoSuchFile: ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress: ErrHostDown,
+ tcpip.ErrBadAddress: ErrBadAddress,
+ tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong: ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
+ tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
+ tcpip.ErrNotPermitted: ErrNotPermittedNet,
+ tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported,
+}
+
+// TranslateNetstackError converts an error from the tcpip package to a sentry
+// internal error.
+func TranslateNetstackError(err *tcpip.Error) *Error {
+ if err == nil {
+ return nil
+ }
+ se, ok := netstackErrorTranslations[err]
+ if !ok {
+ panic("Unknown error: " + err.String())
+ }
+ return se
+}
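+
+// A sketch of a typical call site (ep and addr are illustrative): netstack
+// operations return a *tcpip.Error, which the sentry converts at the
+// syscall boundary:
+//
+// if terr := ep.Bind(addr); terr != nil {
+// return syserr.TranslateNetstackError(terr)
+// }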
diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go
new file mode 100644
index 000000000..ac4b799c3
--- /dev/null
+++ b/pkg/syserr/syserr.go
@@ -0,0 +1,293 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syserr contains sandbox-internal errors. These errors are distinct
+// from both the errors returned by host system calls and the errors returned
+// to sandboxed applications.
+package syserr
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Error represents an internal error.
+type Error struct {
+ // message is the human readable form of this Error.
+ message string
+
+ // noTranslation indicates that this Error cannot be translated to a
+ // linux.Errno.
+ noTranslation bool
+
+ // errno is the linux.Errno this Error should be translated to. nil means
+ // that this Error should be translated to a nil linux.Errno.
+ errno *linux.Errno
+}
+
+// New creates a new Error and adds a translation for it.
+//
+// New must only be called at init.
+func New(message string, linuxTranslation *linux.Errno) *Error {
+ err := &Error{message: message, errno: linuxTranslation}
+
+ if linuxTranslation == nil {
+ return err
+ }
+
+ // TODO(b/34162363): Remove this.
+ errno := linuxTranslation.Number()
+ if errno <= 0 || errno >= len(linuxBackwardsTranslations) {
+ panic(fmt.Sprint("invalid errno: ", errno))
+ }
+
+ e := error(syscall.Errno(errno))
+ // syserror.ErrWouldBlock gets translated to syserror.EWOULDBLOCK and
+ // enables proper blocking semantics. This should temporarily address the
+ // class of blocking bugs that keep popping up with the current state of
+ // the error space.
+ if e == syserror.EWOULDBLOCK {
+ e = syserror.ErrWouldBlock
+ }
+ linuxBackwardsTranslations[errno] = linuxBackwardsTranslation{err: e, ok: true}
+
+ return err
+}
+
+// NewDynamic creates a new error with a dynamic error message and an errno
+// translation.
+//
+// NewDynamic should only be used sparingly and not be used for static error
+// messages. Errors with static error messages should be declared with New as
+// global variables.
+func NewDynamic(message string, linuxTranslation *linux.Errno) *Error {
+ return &Error{message: message, errno: linuxTranslation}
+}
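+
+// An illustrative declaration of the intended New usage (the name and
+// message below are examples, not part of this package):
+//
+// var errExampleStale = syserr.New("example: stale state", linux.ESTALE)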
+
+// NewWithoutTranslation creates a new Error. If translation is attempted on
+// the error, translation will fail.
+//
+// NewWithoutTranslation may be called at any time, but static errors should
+// be declared as global variables and dynamic errors should be used sparingly.
+func NewWithoutTranslation(message string) *Error {
+ return &Error{message: message, noTranslation: true}
+}
+
+func newWithHost(message string, linuxTranslation *linux.Errno, hostErrno syscall.Errno) *Error {
+ e := New(message, linuxTranslation)
+ addLinuxHostTranslation(hostErrno, e)
+ return e
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ if e == nil {
+ return "<nil>"
+ }
+ return e.message
+}
+
+type linuxBackwardsTranslation struct {
+ err error
+ ok bool
+}
+
+// TODO(b/34162363): Remove this.
+var linuxBackwardsTranslations [maxErrno]linuxBackwardsTranslation
+
+// ToError translates an Error to a corresponding error value.
+//
+// TODO(b/34162363): Remove this.
+func (e *Error) ToError() error {
+ if e == nil {
+ return nil
+ }
+ if e.noTranslation {
+ panic(fmt.Sprintf("error %q does not support translation", e.message))
+ }
+ if e.errno == nil {
+ return nil
+ }
+ errno := e.errno.Number()
+ if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok {
+ panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno))
+ }
+ return linuxBackwardsTranslations[errno].err
+}
+
+// ToLinux converts the Error to a Linux ABI error that can be returned to the
+// application.
+func (e *Error) ToLinux() *linux.Errno {
+ if e.noTranslation {
+ panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message))
+ }
+ return e.errno
+}
+
+// TODO(b/34162363): Remove or replace most of these errors.
+//
+// Some of the errors should be replaced with package specific errors and
+// others should be removed entirely.
+var (
+ ErrNotPermitted = newWithHost("operation not permitted", linux.EPERM, syscall.EPERM)
+ ErrNoFileOrDir = newWithHost("no such file or directory", linux.ENOENT, syscall.ENOENT)
+ ErrNoProcess = newWithHost("no such process", linux.ESRCH, syscall.ESRCH)
+ ErrInterrupted = newWithHost("interrupted system call", linux.EINTR, syscall.EINTR)
+ ErrIO = newWithHost("I/O error", linux.EIO, syscall.EIO)
+ ErrDeviceOrAddress = newWithHost("no such device or address", linux.ENXIO, syscall.ENXIO)
+ ErrTooManyArgs = newWithHost("argument list too long", linux.E2BIG, syscall.E2BIG)
+ ErrExec = newWithHost("exec format error", linux.ENOEXEC, syscall.ENOEXEC)
+ ErrBadFD = newWithHost("bad file number", linux.EBADF, syscall.EBADF)
+ ErrNoChild = newWithHost("no child processes", linux.ECHILD, syscall.ECHILD)
+ ErrTryAgain = newWithHost("try again", linux.EAGAIN, syscall.EAGAIN)
+ ErrNoMemory = newWithHost("out of memory", linux.ENOMEM, syscall.ENOMEM)
+ ErrPermissionDenied = newWithHost("permission denied", linux.EACCES, syscall.EACCES)
+ ErrBadAddress = newWithHost("bad address", linux.EFAULT, syscall.EFAULT)
+ ErrNotBlockDevice = newWithHost("block device required", linux.ENOTBLK, syscall.ENOTBLK)
+ ErrBusy = newWithHost("device or resource busy", linux.EBUSY, syscall.EBUSY)
+ ErrExists = newWithHost("file exists", linux.EEXIST, syscall.EEXIST)
+ ErrCrossDeviceLink = newWithHost("cross-device link", linux.EXDEV, syscall.EXDEV)
+ ErrNoDevice = newWithHost("no such device", linux.ENODEV, syscall.ENODEV)
+ ErrNotDir = newWithHost("not a directory", linux.ENOTDIR, syscall.ENOTDIR)
+ ErrIsDir = newWithHost("is a directory", linux.EISDIR, syscall.EISDIR)
+ ErrInvalidArgument = newWithHost("invalid argument", linux.EINVAL, syscall.EINVAL)
+ ErrFileTableOverflow = newWithHost("file table overflow", linux.ENFILE, syscall.ENFILE)
+ ErrTooManyOpenFiles = newWithHost("too many open files", linux.EMFILE, syscall.EMFILE)
+ ErrNotTTY = newWithHost("not a typewriter", linux.ENOTTY, syscall.ENOTTY)
+ ErrTextFileBusy = newWithHost("text file busy", linux.ETXTBSY, syscall.ETXTBSY)
+ ErrFileTooBig = newWithHost("file too large", linux.EFBIG, syscall.EFBIG)
+ ErrNoSpace = newWithHost("no space left on device", linux.ENOSPC, syscall.ENOSPC)
+ ErrIllegalSeek = newWithHost("illegal seek", linux.ESPIPE, syscall.ESPIPE)
+ ErrReadOnlyFS = newWithHost("read-only file system", linux.EROFS, syscall.EROFS)
+ ErrTooManyLinks = newWithHost("too many links", linux.EMLINK, syscall.EMLINK)
+ ErrBrokenPipe = newWithHost("broken pipe", linux.EPIPE, syscall.EPIPE)
+ ErrDomain = newWithHost("math argument out of domain of func", linux.EDOM, syscall.EDOM)
+ ErrRange = newWithHost("math result not representable", linux.ERANGE, syscall.ERANGE)
+ ErrDeadlock = newWithHost("resource deadlock would occur", linux.EDEADLOCK, syscall.EDEADLOCK)
+ ErrNameTooLong = newWithHost("file name too long", linux.ENAMETOOLONG, syscall.ENAMETOOLONG)
+ ErrNoLocksAvailable = newWithHost("no record locks available", linux.ENOLCK, syscall.ENOLCK)
+ ErrInvalidSyscall = newWithHost("invalid system call number", linux.ENOSYS, syscall.ENOSYS)
+ ErrDirNotEmpty = newWithHost("directory not empty", linux.ENOTEMPTY, syscall.ENOTEMPTY)
+ ErrLinkLoop = newWithHost("too many symbolic links encountered", linux.ELOOP, syscall.ELOOP)
+ ErrNoMessage = newWithHost("no message of desired type", linux.ENOMSG, syscall.ENOMSG)
+ ErrIdentifierRemoved = newWithHost("identifier removed", linux.EIDRM, syscall.EIDRM)
+ ErrChannelOutOfRange = newWithHost("channel number out of range", linux.ECHRNG, syscall.ECHRNG)
+ ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", linux.EL2NSYNC, syscall.EL2NSYNC)
+ ErrLevelThreeHalted = newWithHost("level 3 halted", linux.EL3HLT, syscall.EL3HLT)
+ ErrLevelThreeReset = newWithHost("level 3 reset", linux.EL3RST, syscall.EL3RST)
+ ErrLinkNumberOutOfRange = newWithHost("link number out of range", linux.ELNRNG, syscall.ELNRNG)
+ ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", linux.EUNATCH, syscall.EUNATCH)
+ ErrNoCSIAvailable = newWithHost("no CSI structure available", linux.ENOCSI, syscall.ENOCSI)
+ ErrLevelTwoHalted = newWithHost("level 2 halted", linux.EL2HLT, syscall.EL2HLT)
+ ErrInvalidExchange = newWithHost("invalid exchange", linux.EBADE, syscall.EBADE)
+ ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", linux.EBADR, syscall.EBADR)
+ ErrExchangeFull = newWithHost("exchange full", linux.EXFULL, syscall.EXFULL)
+ ErrNoAnode = newWithHost("no anode", linux.ENOANO, syscall.ENOANO)
+ ErrInvalidRequestCode = newWithHost("invalid request code", linux.EBADRQC, syscall.EBADRQC)
+ ErrInvalidSlot = newWithHost("invalid slot", linux.EBADSLT, syscall.EBADSLT)
+ ErrBadFontFile = newWithHost("bad font file format", linux.EBFONT, syscall.EBFONT)
+ ErrNotStream = newWithHost("device not a stream", linux.ENOSTR, syscall.ENOSTR)
+ ErrNoDataAvailable = newWithHost("no data available", linux.ENODATA, syscall.ENODATA)
+ ErrTimerExpired = newWithHost("timer expired", linux.ETIME, syscall.ETIME)
+ ErrStreamsResourceDepleted = newWithHost("out of streams resources", linux.ENOSR, syscall.ENOSR)
+ ErrMachineNotOnNetwork = newWithHost("machine is not on the network", linux.ENONET, syscall.ENONET)
+ ErrPackageNotInstalled = newWithHost("package not installed", linux.ENOPKG, syscall.ENOPKG)
+ ErrIsRemote = newWithHost("object is remote", linux.EREMOTE, syscall.EREMOTE)
+ ErrNoLink = newWithHost("link has been severed", linux.ENOLINK, syscall.ENOLINK)
+ ErrAdvertise = newWithHost("advertise error", linux.EADV, syscall.EADV)
+ ErrSRMount = newWithHost("srmount error", linux.ESRMNT, syscall.ESRMNT)
+ ErrSendCommunication = newWithHost("communication error on send", linux.ECOMM, syscall.ECOMM)
+ ErrProtocol = newWithHost("protocol error", linux.EPROTO, syscall.EPROTO)
+ ErrMultihopAttempted = newWithHost("multihop attempted", linux.EMULTIHOP, syscall.EMULTIHOP)
+ ErrRFS = newWithHost("RFS specific error", linux.EDOTDOT, syscall.EDOTDOT)
+ ErrInvalidDataMessage = newWithHost("not a data message", linux.EBADMSG, syscall.EBADMSG)
+ ErrOverflow = newWithHost("value too large for defined data type", linux.EOVERFLOW, syscall.EOVERFLOW)
+ ErrNetworkNameNotUnique = newWithHost("name not unique on network", linux.ENOTUNIQ, syscall.ENOTUNIQ)
+ ErrFDInBadState = newWithHost("file descriptor in bad state", linux.EBADFD, syscall.EBADFD)
+ ErrRemoteAddressChanged = newWithHost("remote address changed", linux.EREMCHG, syscall.EREMCHG)
+ ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", linux.ELIBACC, syscall.ELIBACC)
+ ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", linux.ELIBBAD, syscall.ELIBBAD)
+ ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", linux.ELIBSCN, syscall.ELIBSCN)
+ ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", linux.ELIBMAX, syscall.ELIBMAX)
+ ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", linux.ELIBEXEC, syscall.ELIBEXEC)
+ ErrIllegalByteSequence = newWithHost("illegal byte sequence", linux.EILSEQ, syscall.EILSEQ)
+ ErrShouldRestart = newWithHost("interrupted system call should be restarted", linux.ERESTART, syscall.ERESTART)
+ ErrStreamPipe = newWithHost("streams pipe error", linux.ESTRPIPE, syscall.ESTRPIPE)
+ ErrTooManyUsers = newWithHost("too many users", linux.EUSERS, syscall.EUSERS)
+ ErrNotASocket = newWithHost("socket operation on non-socket", linux.ENOTSOCK, syscall.ENOTSOCK)
+ ErrDestinationAddressRequired = newWithHost("destination address required", linux.EDESTADDRREQ, syscall.EDESTADDRREQ)
+ ErrMessageTooLong = newWithHost("message too long", linux.EMSGSIZE, syscall.EMSGSIZE)
+ ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", linux.EPROTOTYPE, syscall.EPROTOTYPE)
+ ErrProtocolNotAvailable = newWithHost("protocol not available", linux.ENOPROTOOPT, syscall.ENOPROTOOPT)
+ ErrProtocolNotSupported = newWithHost("protocol not supported", linux.EPROTONOSUPPORT, syscall.EPROTONOSUPPORT)
+ ErrSocketNotSupported = newWithHost("socket type not supported", linux.ESOCKTNOSUPPORT, syscall.ESOCKTNOSUPPORT)
+ ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", linux.EOPNOTSUPP, syscall.EOPNOTSUPP)
+ ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", linux.EPFNOSUPPORT, syscall.EPFNOSUPPORT)
+ ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", linux.EAFNOSUPPORT, syscall.EAFNOSUPPORT)
+ ErrAddressInUse = newWithHost("address already in use", linux.EADDRINUSE, syscall.EADDRINUSE)
+ ErrAddressNotAvailable = newWithHost("cannot assign requested address", linux.EADDRNOTAVAIL, syscall.EADDRNOTAVAIL)
+ ErrNetworkDown = newWithHost("network is down", linux.ENETDOWN, syscall.ENETDOWN)
+ ErrNetworkUnreachable = newWithHost("network is unreachable", linux.ENETUNREACH, syscall.ENETUNREACH)
+ ErrNetworkReset = newWithHost("network dropped connection because of reset", linux.ENETRESET, syscall.ENETRESET)
+ ErrConnectionAborted = newWithHost("software caused connection abort", linux.ECONNABORTED, syscall.ECONNABORTED)
+ ErrConnectionReset = newWithHost("connection reset by peer", linux.ECONNRESET, syscall.ECONNRESET)
+ ErrNoBufferSpace = newWithHost("no buffer space available", linux.ENOBUFS, syscall.ENOBUFS)
+ ErrAlreadyConnected = newWithHost("transport endpoint is already connected", linux.EISCONN, syscall.EISCONN)
+ ErrNotConnected = newWithHost("transport endpoint is not connected", linux.ENOTCONN, syscall.ENOTCONN)
+ ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", linux.ESHUTDOWN, syscall.ESHUTDOWN)
+ ErrTooManyRefs = newWithHost("too many references: cannot splice", linux.ETOOMANYREFS, syscall.ETOOMANYREFS)
+ ErrTimedOut = newWithHost("connection timed out", linux.ETIMEDOUT, syscall.ETIMEDOUT)
+ ErrConnectionRefused = newWithHost("connection refused", linux.ECONNREFUSED, syscall.ECONNREFUSED)
+ ErrHostDown = newWithHost("host is down", linux.EHOSTDOWN, syscall.EHOSTDOWN)
+ ErrNoRoute = newWithHost("no route to host", linux.EHOSTUNREACH, syscall.EHOSTUNREACH)
+ ErrAlreadyInProgress = newWithHost("operation already in progress", linux.EALREADY, syscall.EALREADY)
+ ErrInProgress = newWithHost("operation now in progress", linux.EINPROGRESS, syscall.EINPROGRESS)
+ ErrStaleFileHandle = newWithHost("stale file handle", linux.ESTALE, syscall.ESTALE)
+ ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", linux.EUCLEAN, syscall.EUCLEAN)
+ ErrIsNamedFile = newWithHost("is a named type file", linux.ENOTNAM, syscall.ENOTNAM)
+ ErrRemoteIO = newWithHost("remote I/O error", linux.EREMOTEIO, syscall.EREMOTEIO)
+ ErrQuotaExceeded = newWithHost("quota exceeded", linux.EDQUOT, syscall.EDQUOT)
+ ErrNoMedium = newWithHost("no medium found", linux.ENOMEDIUM, syscall.ENOMEDIUM)
+ ErrWrongMediumType = newWithHost("wrong medium type", linux.EMEDIUMTYPE, syscall.EMEDIUMTYPE)
+ ErrCanceled = newWithHost("operation canceled", linux.ECANCELED, syscall.ECANCELED)
+ ErrNoKey = newWithHost("required key not available", linux.ENOKEY, syscall.ENOKEY)
+ ErrKeyExpired = newWithHost("key has expired", linux.EKEYEXPIRED, syscall.EKEYEXPIRED)
+ ErrKeyRevoked = newWithHost("key has been revoked", linux.EKEYREVOKED, syscall.EKEYREVOKED)
+ ErrKeyRejected = newWithHost("key was rejected by service", linux.EKEYREJECTED, syscall.EKEYREJECTED)
+ ErrOwnerDied = newWithHost("owner died", linux.EOWNERDEAD, syscall.EOWNERDEAD)
+ ErrNotRecoverable = newWithHost("state not recoverable", linux.ENOTRECOVERABLE, syscall.ENOTRECOVERABLE)
+
+ // ErrWouldBlock translates to EWOULDBLOCK which is the same as EAGAIN
+ // on Linux.
+ ErrWouldBlock = New("operation would block", linux.EWOULDBLOCK)
+)
+
+// FromError converts a generic error to an *Error.
+//
+// TODO(b/34162363): Remove this function.
+func FromError(err error) *Error {
+ if err == nil {
+ return nil
+ }
+ if errno, ok := err.(syscall.Errno); ok {
+ return FromHost(errno)
+ }
+ if errno, ok := syserror.TranslateError(err); ok {
+ return FromHost(errno)
+ }
+ panic("unknown error: " + err.Error())
+}
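
A short usage sketch of the conversion helpers above, assuming this file belongs to gVisor's pkg/syserr package (the import path and package name are inferred, not shown in this hunk):

```go
package main

import (
	"fmt"
	"syscall"

	"gvisor.dev/gvisor/pkg/syserr" // assumed import path for the file above
)

func main() {
	// A raw host errno, e.g. as surfaced by a syscall wrapper.
	var err error = syscall.ENOENT

	// FromError matches syscall.Errno values directly (via FromHost) and
	// falls back to syserror.TranslateError for registered error values;
	// any other error panics.
	serr := syserr.FromError(err)
	fmt.Printf("%v\n", serr)
}
```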
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
new file mode 100644
index 000000000..b13c15d9b
--- /dev/null
+++ b/pkg/syserror/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "syserror",
+ srcs = ["syserror.go"],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "syserror_test",
+ srcs = ["syserror_test.go"],
+ deps = [
+ ":syserror",
+ ],
+)
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
new file mode 100644
index 000000000..c73072c42
--- /dev/null
+++ b/pkg/syserror/syserror.go
@@ -0,0 +1,159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syserror contains syscall error codes exported as an error
+// interface instead of Errno. This allows for fast comparisons and returns
+// when the comparand or return value is of type error, because no conversion
+// from Errno to an interface (runtime.convT2I) is needed.
+package syserror
+
+import (
+ "errors"
+ "syscall"
+)
+
+// The following variables have the same meaning as their syscall equivalents.
+var (
+ E2BIG = error(syscall.E2BIG)
+ EACCES = error(syscall.EACCES)
+ EADDRINUSE = error(syscall.EADDRINUSE)
+ EAGAIN = error(syscall.EAGAIN)
+ EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
+ EBUSY = error(syscall.EBUSY)
+ ECHILD = error(syscall.ECHILD)
+ ECONNREFUSED = error(syscall.ECONNREFUSED)
+ ECONNRESET = error(syscall.ECONNRESET)
+ EDEADLK = error(syscall.EDEADLK)
+ EEXIST = error(syscall.EEXIST)
+ EFAULT = error(syscall.EFAULT)
+ EFBIG = error(syscall.EFBIG)
+ EIDRM = error(syscall.EIDRM)
+ EINTR = error(syscall.EINTR)
+ EINVAL = error(syscall.EINVAL)
+ EIO = error(syscall.EIO)
+ EISDIR = error(syscall.EISDIR)
+ ELIBBAD = error(syscall.ELIBBAD)
+ ELOOP = error(syscall.ELOOP)
+ EMFILE = error(syscall.EMFILE)
+ EMLINK = error(syscall.EMLINK)
+ EMSGSIZE = error(syscall.EMSGSIZE)
+ ENAMETOOLONG = error(syscall.ENAMETOOLONG)
+ ENOATTR = ENODATA
+ ENOBUFS = error(syscall.ENOBUFS)
+ ENODATA = error(syscall.ENODATA)
+ ENODEV = error(syscall.ENODEV)
+ ENOENT = error(syscall.ENOENT)
+ ENOEXEC = error(syscall.ENOEXEC)
+ ENOLCK = error(syscall.ENOLCK)
+ ENOLINK = error(syscall.ENOLINK)
+ ENOMEM = error(syscall.ENOMEM)
+ ENOSPC = error(syscall.ENOSPC)
+ ENOSYS = error(syscall.ENOSYS)
+ ENOTDIR = error(syscall.ENOTDIR)
+ ENOTEMPTY = error(syscall.ENOTEMPTY)
+ ENOTSOCK = error(syscall.ENOTSOCK)
+ ENOTSUP = error(syscall.ENOTSUP)
+ ENOTTY = error(syscall.ENOTTY)
+ ENXIO = error(syscall.ENXIO)
+ EOPNOTSUPP = error(syscall.EOPNOTSUPP)
+ EOVERFLOW = error(syscall.EOVERFLOW)
+ EPERM = error(syscall.EPERM)
+ EPIPE = error(syscall.EPIPE)
+ ERANGE = error(syscall.ERANGE)
+ EREMOTE = error(syscall.EREMOTE)
+ EROFS = error(syscall.EROFS)
+ ESPIPE = error(syscall.ESPIPE)
+ ESRCH = error(syscall.ESRCH)
+ ETIMEDOUT = error(syscall.ETIMEDOUT)
+ EUSERS = error(syscall.EUSERS)
+ EWOULDBLOCK = error(syscall.EWOULDBLOCK)
+ EXDEV = error(syscall.EXDEV)
+)
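
The pre-converted variables above pay off at call sites. A minimal sketch (the `validate` and `isInval` helpers are hypothetical, not part of this change):

```go
package example

import "gvisor.dev/gvisor/pkg/syserror"

// validate returns the pre-converted error value, avoiding the
// Errno-to-interface conversion (runtime.convT2I) that returning
// syscall.EINVAL through an error result would incur.
func validate(ok bool) error {
	if !ok {
		return syserror.EINVAL // already an error interface value
	}
	return nil
}

// isInval compares two error interface values directly; no conversion
// happens on either side. The benchmarks further below measure this.
func isInval(err error) bool {
	return err == syserror.EINVAL
}
```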
+
+var (
+ // ErrWouldBlock is an internal error used to indicate that an operation
+ // cannot be satisfied immediately, and should be retried at a later
+ // time, possibly when the caller has received a notification that the
+ // operation may be able to complete. It is used by implementations of
+ // the kio.File interface.
+ ErrWouldBlock = errors.New("request would block")
+
+ // ErrInterrupted is returned if a request is interrupted before it can
+ // complete.
+ ErrInterrupted = errors.New("request was interrupted")
+
+ // ErrExceedsFileSizeLimit is returned if a request would exceed the
+ // file's size limit.
+ ErrExceedsFileSizeLimit = errors.New("exceeds file size limit")
+)
+
+// errorMap is the map used to convert generic errors into errnos.
+var errorMap = map[error]syscall.Errno{}
+
+// errorUnwrappers is an array of unwrap functions to extract typed errors.
+var errorUnwrappers = []func(error) (syscall.Errno, bool){}
+
+// AddErrorTranslation allows modules to populate the error map by adding
+// their own translations during initialization. It returns whether the
+// translation was accepted; a pre-existing translation will not be
+// overwritten by a new one.
+func AddErrorTranslation(from error, to syscall.Errno) bool {
+ if _, ok := errorMap[from]; ok {
+ return false
+ }
+
+ errorMap[from] = to
+ return true
+}
+
+// AddErrorUnwrapper registers an unwrap function that can extract a concrete
+// syscall.Errno from an error whose type, rather than value, is known.
+func AddErrorUnwrapper(unwrap func(e error) (syscall.Errno, bool)) {
+ errorUnwrappers = append(errorUnwrappers, unwrap)
+}
+
+// TranslateError translates an error to an errno. It returns false if the
+// error was not registered.
+func TranslateError(from error) (syscall.Errno, bool) {
+ err, ok := errorMap[from]
+ if ok {
+ return err, ok
+ }
+ // Try to unwrap the error if we couldn't match an error
+ // exactly. This might mean that a package has its own
+ // error type.
+ for _, unwrap := range errorUnwrappers {
+ err, ok := unwrap(from)
+ if ok {
+ return err, ok
+ }
+ }
+ return 0, false
+}
+
+// ConvertIntr converts the provided error code (err) to another one (intr) if
+// the first error corresponds to an interrupted operation.
+func ConvertIntr(err, intr error) error {
+ if err == ErrInterrupted {
+ return intr
+ }
+ return err
+}
+
+func init() {
+ AddErrorTranslation(ErrWouldBlock, syscall.EWOULDBLOCK)
+ AddErrorTranslation(ErrInterrupted, syscall.EINTR)
+ AddErrorTranslation(ErrExceedsFileSizeLimit, syscall.EFBIG)
+}
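
A registration sketch for the API above (the error value and package are hypothetical):

```go
package example

import (
	"errors"
	"syscall"

	"gvisor.dev/gvisor/pkg/syserror"
)

// errDeviceGone is a module-private error that should surface as ENODEV.
var errDeviceGone = errors.New("device disappeared")

func init() {
	// The first registration wins; later registrations for the same error
	// are rejected and AddErrorTranslation returns false.
	syserror.AddErrorTranslation(errDeviceGone, syscall.ENODEV)
}

// toErrno converts module errors to errnos at the syscall boundary.
func toErrno(err error) (syscall.Errno, bool) {
	return syserror.TranslateError(err)
}
```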
diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go
new file mode 100644
index 000000000..29719752e
--- /dev/null
+++ b/pkg/syserror/syserror_test.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syserror_test
+
+import (
+ "errors"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var globalError error
+
+func returnErrnoAsError() error {
+ return syscall.EINVAL
+}
+
+func returnError() error {
+ return syserror.EINVAL
+}
+
+func BenchmarkReturnErrnoAsError(b *testing.B) {
+ for i := b.N; i > 0; i-- {
+ returnErrnoAsError()
+ }
+}
+
+func BenchmarkReturnError(b *testing.B) {
+ for i := b.N; i > 0; i-- {
+ returnError()
+ }
+}
+
+func BenchmarkCompareErrno(b *testing.B) {
+ j := 0
+ for i := b.N; i > 0; i-- {
+ if globalError == syscall.EINVAL {
+ j++
+ }
+ }
+}
+
+func BenchmarkCompareError(b *testing.B) {
+ j := 0
+ for i := b.N; i > 0; i-- {
+ if globalError == syserror.EINVAL {
+ j++
+ }
+ }
+}
+
+func BenchmarkSwitchErrno(b *testing.B) {
+ j := 0
+ for i := b.N; i > 0; i-- {
+ switch globalError {
+ case syscall.EINVAL:
+ j += 1
+ case syscall.EINTR:
+ j += 2
+ case syscall.EAGAIN:
+ j += 3
+ }
+ }
+}
+
+func BenchmarkSwitchError(b *testing.B) {
+ j := 0
+ for i := b.N; i > 0; i-- {
+ switch globalError {
+ case syserror.EINVAL:
+ j += 1
+ case syserror.EINTR:
+ j += 2
+ case syserror.EAGAIN:
+ j += 3
+ }
+ }
+}
+
+type translationTestTable struct {
+ fn string
+ errIn error
+ syscallErrorIn syscall.Errno
+ expectedBool bool
+ expectedTranslation syscall.Errno
+}
+
+func TestErrorTranslation(t *testing.T) {
+ myError := errors.New("My test error")
+ myError2 := errors.New("Another test error")
+ testTable := []translationTestTable{
+ {"TranslateError", myError, 0, false, 0},
+ {"TranslateError", myError2, 0, false, 0},
+ {"AddErrorTranslation", myError, syscall.EAGAIN, true, 0},
+ {"AddErrorTranslation", myError, syscall.EAGAIN, false, 0},
+ {"AddErrorTranslation", myError, syscall.EPERM, false, 0},
+ {"TranslateError", myError, 0, true, syscall.EAGAIN},
+ {"TranslateError", myError2, 0, false, 0},
+ {"AddErrorTranslation", myError2, syscall.EPERM, true, 0},
+ {"AddErrorTranslation", myError2, syscall.EPERM, false, 0},
+ {"AddErrorTranslation", myError2, syscall.EAGAIN, false, 0},
+ {"TranslateError", myError, 0, true, syscall.EAGAIN},
+ {"TranslateError", myError2, 0, true, syscall.EPERM},
+ }
+ for _, tt := range testTable {
+ switch tt.fn {
+ case "TranslateError":
+ err, ok := syserror.TranslateError(tt.errIn)
+ if ok != tt.expectedBool {
+ t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
+ } else if err != tt.expectedTranslation {
+ t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation)
+ }
+ case "AddErrorTranslation":
+ ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn)
+ if ok != tt.expectedBool {
+ t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
+ }
+ default:
+ t.Fatalf("Unknown function %v", tt.fn)
+ }
+ }
+}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
new file mode 100644
index 000000000..454e07662
--- /dev/null
+++ b/pkg/tcpip/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tcpip",
+ srcs = [
+ "tcpip.go",
+ "time_unsafe.go",
+ "timer.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip/buffer",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "tcpip_test",
+ size = "small",
+ srcs = ["tcpip_test.go"],
+ library = ":tcpip",
+)
+
+go_test(
+ name = "tcpip_x_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ deps = [":tcpip"],
+)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
new file mode 100644
index 000000000..a984f1712
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gonet",
+ srcs = ["gonet.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "gonet_test",
+ size = "small",
+ srcs = ["gonet_test.go"],
+ library = ":gonet",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@org_golang_x_net//nettest:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
new file mode 100644
index 000000000..d82ed5205
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -0,0 +1,738 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
+package gonet
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ errCanceled = errors.New("operation canceled")
+ errWouldBlock = errors.New("operation would block")
+)
+
+// timeoutError is how the net package reports timeouts.
+type timeoutError struct{}
+
+func (e *timeoutError) Error() string { return "i/o timeout" }
+func (e *timeoutError) Timeout() bool { return true }
+func (e *timeoutError) Temporary() bool { return true }
+
+// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
+// net.Listener.
+type TCPListener struct {
+ stack *stack.Stack
+ ep tcpip.Endpoint
+ wq *waiter.Queue
+ cancel chan struct{}
+}
+
+// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
+func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
+ return &TCPListener{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ cancel: make(chan struct{}),
+ }
+}
+
+// ListenTCP creates a new TCPListener.
+func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
+ // Create a TCP endpoint, bind it, then start listening.
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+ if err != nil {
+ return nil, errors.New(err.String())
+ }
+
+ if err := ep.Bind(addr); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ if err := ep.Listen(10); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "listen",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ return NewTCPListener(s, &wq, ep), nil
+}
+
+// Close implements net.Listener.Close.
+func (l *TCPListener) Close() error {
+ l.ep.Close()
+ return nil
+}
+
+// Shutdown stops the listener and unblocks any pending Accept calls.
+func (l *TCPListener) Shutdown() {
+ l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ close(l.cancel) // broadcast cancellation
+}
+
+// Addr implements net.Listener.Addr.
+func (l *TCPListener) Addr() net.Addr {
+ a, err := l.ep.GetLocalAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToTCPAddr(a)
+}
+
+type deadlineTimer struct {
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ readTimer *time.Timer
+ readCancelCh chan struct{}
+ writeTimer *time.Timer
+ writeCancelCh chan struct{}
+}
+
+func (d *deadlineTimer) init() {
+ d.readCancelCh = make(chan struct{})
+ d.writeCancelCh = make(chan struct{})
+}
+
+func (d *deadlineTimer) readCancel() <-chan struct{} {
+ d.mu.Lock()
+ c := d.readCancelCh
+ d.mu.Unlock()
+ return c
+}
+func (d *deadlineTimer) writeCancel() <-chan struct{} {
+ d.mu.Lock()
+ c := d.writeCancelCh
+ d.mu.Unlock()
+ return c
+}
+
+// setDeadline contains the shared logic for setting a deadline.
+//
+// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
+// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
+// deadlineTimer.writeTimer.
+//
+// setDeadline must only be called while holding d.mu.
+func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
+ if *timer != nil && !(*timer).Stop() {
+ *cancelCh = make(chan struct{})
+ }
+
+ // Create a new channel if we already closed it due to setting an already
+ // expired time. We won't race with the timer because we already handled
+ // that above.
+ select {
+ case <-*cancelCh:
+ *cancelCh = make(chan struct{})
+ default:
+ }
+
+ // "A zero value for t means I/O operations will not time out."
+ // - net.Conn.SetDeadline
+ if t.IsZero() {
+ return
+ }
+
+	timeout := time.Until(t)
+ if timeout <= 0 {
+ close(*cancelCh)
+ return
+ }
+
+ // Timer.Stop returns whether or not the AfterFunc has started, but
+ // does not indicate whether or not it has completed. Make a copy of
+ // the cancel channel to prevent this code from racing with the next
+ // call of setDeadline replacing *cancelCh.
+ ch := *cancelCh
+ *timer = time.AfterFunc(timeout, func() {
+ close(ch)
+ })
+}
+
+// SetReadDeadline implements net.Conn.SetReadDeadline and
+// net.PacketConn.SetReadDeadline.
+func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
+ d.mu.Lock()
+ d.setDeadline(&d.readCancelCh, &d.readTimer, t)
+ d.mu.Unlock()
+ return nil
+}
+
+// SetWriteDeadline implements net.Conn.SetWriteDeadline and
+// net.PacketConn.SetWriteDeadline.
+func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
+ d.mu.Lock()
+ d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
+ d.mu.Unlock()
+ return nil
+}
+
+// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
+func (d *deadlineTimer) SetDeadline(t time.Time) error {
+ d.mu.Lock()
+ d.setDeadline(&d.readCancelCh, &d.readTimer, t)
+ d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
+ d.mu.Unlock()
+ return nil
+}
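
The deadline methods above follow net.Conn semantics: a deadline in the past fails pending and future I/O with a timeout error, and a zero time clears the deadline. A sketch (`readWithBudget` is hypothetical):

```go
package example

import (
	"net"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)

// readWithBudget gives a single Read at most d to complete.
func readWithBudget(c *gonet.TCPConn, buf []byte, d time.Duration) (int, error) {
	c.SetReadDeadline(time.Now().Add(d))
	n, err := c.Read(buf)
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		// The deadline fired; the connection remains usable.
		c.SetReadDeadline(time.Time{}) // a zero time clears the deadline
	}
	return n, err
}
```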
+
+// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
+// interface.
+type TCPConn struct {
+ deadlineTimer
+
+ wq *waiter.Queue
+ ep tcpip.Endpoint
+
+ // readMu serializes reads and implicitly protects read.
+ //
+ // Lock ordering:
+ // If both readMu and deadlineTimer.mu are to be used in a single
+ // request, readMu must be acquired before deadlineTimer.mu.
+ readMu sync.Mutex
+
+ // read contains bytes that have been read from the endpoint,
+ // but haven't yet been returned.
+ read buffer.View
+}
+
+// NewTCPConn creates a new TCPConn.
+func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
+ c := &TCPConn{
+ wq: wq,
+ ep: ep,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
+// Accept implements net.Listener.Accept.
+func (l *TCPListener) Accept() (net.Conn, error) {
+ n, wq, err := l.ep.Accept()
+
+ if err == tcpip.ErrWouldBlock {
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ l.wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer l.wq.EventUnregister(&waitEntry)
+
+ for {
+ n, wq, err = l.ep.Accept()
+
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
+
+ select {
+ case <-l.cancel:
+ return nil, errCanceled
+ case <-notifyCh:
+ }
+ }
+ }
+
+ if err != nil {
+ return nil, &net.OpError{
+ Op: "accept",
+ Net: "tcp",
+ Addr: l.Addr(),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ return NewTCPConn(wq, n), nil
+}
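
Taken together, ListenTCP, Accept, and Shutdown support an ordinary net-style accept loop. A minimal echo-server sketch (stack construction elided; newLoopbackStack in the tests below shows one way to build it):

```go
package example

import (
	"io"
	"net"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// serveEcho accepts TCP connections on addr and echoes bytes back until
// the listener is closed or shut down.
func serveEcho(s *stack.Stack, addr tcpip.FullAddress) error {
	l, err := gonet.ListenTCP(s, addr, ipv4.ProtocolNumber)
	if err != nil {
		return err
	}
	defer l.Close()
	for {
		c, err := l.Accept() // blocks until a connection arrives or Shutdown runs
		if err != nil {
			return err
		}
		go func(c net.Conn) {
			defer c.Close()
			io.Copy(c, c) // echo until the peer half-closes
		}(c)
	}
}
```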
+
+type opErrorer interface {
+ newOpError(op string, err error) *net.OpError
+}
+
+// commonRead implements the common logic between net.Conn.Read and
+// net.PacketConn.ReadFrom.
+func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+ select {
+ case <-deadline:
+ return nil, errorer.newOpError("read", &timeoutError{})
+ default:
+ }
+
+ read, _, err := ep.Read(addr)
+
+ if err == tcpip.ErrWouldBlock {
+ if dontWait {
+ return nil, errWouldBlock
+ }
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+ for {
+ read, _, err = ep.Read(addr)
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
+ select {
+ case <-deadline:
+ return nil, errorer.newOpError("read", &timeoutError{})
+ case <-notifyCh:
+ }
+ }
+ }
+
+ if err == tcpip.ErrClosedForReceive {
+ return nil, io.EOF
+ }
+
+ if err != nil {
+ return nil, errorer.newOpError("read", errors.New(err.String()))
+ }
+
+ return read, nil
+}
+
+// Read implements net.Conn.Read.
+func (c *TCPConn) Read(b []byte) (int, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+
+ deadline := c.readCancel()
+
+ numRead := 0
+ defer func() {
+ if numRead != 0 {
+ c.ep.ModerateRecvBuf(numRead)
+ }
+ }()
+ for numRead != len(b) {
+ if len(c.read) == 0 {
+ var err error
+ c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
+ if err != nil {
+ if numRead != 0 {
+ return numRead, nil
+ }
+ return numRead, err
+ }
+ }
+ n := copy(b[numRead:], c.read)
+ c.read.TrimFront(n)
+ numRead += n
+ if len(c.read) == 0 {
+ c.read = nil
+ }
+ }
+ return numRead, nil
+}
+
+// Write implements net.Conn.Write.
+func (c *TCPConn) Write(b []byte) (int, error) {
+ deadline := c.writeCancel()
+
+ // Check if deadlineTimer has already expired.
+ select {
+ case <-deadline:
+ return 0, c.newOpError("write", &timeoutError{})
+ default:
+ }
+
+ v := buffer.NewViewFromBytes(b)
+
+ // We must handle two soft failure conditions simultaneously:
+ // 1. Write may write nothing and return tcpip.ErrWouldBlock.
+ // If this happens, we need to register for notifications if we have
+ // not already and wait to try again.
+ // 2. Write may write fewer than the full number of bytes and return
+ // without error. In this case we need to try writing the remaining
+	//    bytes again. We do not need to register for notifications.
+ //
+ // What is more, these two soft failure conditions can be interspersed.
+ // There is no guarantee that all of the condition #1s will occur before
+	// all of the condition #2s, or vice versa.
+ var (
+ err *tcpip.Error
+ nbytes int
+ reg bool
+ notifyCh chan struct{}
+ )
+ for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
+ if err == tcpip.ErrWouldBlock {
+ if !reg {
+ // Only register once.
+ reg = true
+
+ // Create wait queue entry that notifies a channel.
+ var waitEntry waiter.Entry
+ waitEntry, notifyCh = waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.wq.EventUnregister(&waitEntry)
+ } else {
+ // Don't wait immediately after registration in case more data
+ // became available between when we last checked and when we setup
+ // the notification.
+ select {
+ case <-deadline:
+ return nbytes, c.newOpError("write", &timeoutError{})
+ case <-notifyCh:
+ }
+ }
+ }
+
+ var n int64
+ var resCh <-chan struct{}
+ n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ nbytes += int(n)
+ v.TrimFront(int(n))
+
+ if resCh != nil {
+ select {
+ case <-deadline:
+ return nbytes, c.newOpError("write", &timeoutError{})
+ case <-resCh:
+ }
+
+ n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ nbytes += int(n)
+ v.TrimFront(int(n))
+ }
+ }
+
+ if err == nil {
+ return nbytes, nil
+ }
+
+ return nbytes, c.newOpError("write", errors.New(err.String()))
+}
+
+// Close implements net.Conn.Close.
+func (c *TCPConn) Close() error {
+ c.ep.Close()
+ return nil
+}
+
+// CloseRead shuts down the reading side of the TCP connection. Most callers
+// should just use Close.
+//
+// This performs a TCP half-close, the same as CloseRead on a *net.TCPConn.
+func (c *TCPConn) CloseRead() error {
+ if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
+ return c.newOpError("close", errors.New(terr.String()))
+ }
+ return nil
+}
+
+// CloseWrite shuts down the writing side of the TCP connection. Most callers
+// should just use Close.
+//
+// This performs a TCP half-close, the same as CloseWrite on a *net.TCPConn.
+func (c *TCPConn) CloseWrite() error {
+ if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
+ return c.newOpError("close", errors.New(terr.String()))
+ }
+ return nil
+}
+
+// LocalAddr implements net.Conn.LocalAddr.
+func (c *TCPConn) LocalAddr() net.Addr {
+ a, err := c.ep.GetLocalAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToTCPAddr(a)
+}
+
+// RemoteAddr implements net.Conn.RemoteAddr.
+func (c *TCPConn) RemoteAddr() net.Addr {
+ a, err := c.ep.GetRemoteAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToTCPAddr(a)
+}
+
+func (c *TCPConn) newOpError(op string, err error) *net.OpError {
+ return &net.OpError{
+ Op: op,
+ Net: "tcp",
+ Source: c.LocalAddr(),
+ Addr: c.RemoteAddr(),
+ Err: err,
+ }
+}
+
+func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
+ return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
+}
+
+func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
+ return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
+}
+
+// DialTCP creates a new TCPConn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+ return DialContextTCP(context.Background(), s, addr, network)
+}
+
+// DialContextTCP creates a new TCPConn connected to the specified address
+// with the option of adding cancellation and timeouts.
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+ // Create TCP endpoint, then connect.
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+ if err != nil {
+ return nil, errors.New(err.String())
+ }
+
+ // Create wait queue entry that notifies a channel.
+ //
+ // We do this unconditionally as Connect will always return an error.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer wq.EventUnregister(&waitEntry)
+
+ select {
+ case <-ctx.Done():
+		ep.Close()
+		return nil, ctx.Err()
+ default:
+ }
+
+ err = ep.Connect(addr)
+ if err == tcpip.ErrConnectStarted {
+ select {
+ case <-ctx.Done():
+ ep.Close()
+ return nil, ctx.Err()
+ case <-notifyCh:
+ }
+
+ err = ep.GetSockOpt(tcpip.ErrorOption{})
+ }
+ if err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ return NewTCPConn(&wq, ep), nil
+}
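
A dialing sketch with a context deadline (assuming s is an initialized *stack.Stack with a route to addr; TestDialContextTCPTimeout below exercises the same path):

```go
package example

import (
	"context"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// dialWithTimeout bounds the TCP handshake to the given duration.
func dialWithTimeout(s *stack.Stack, addr tcpip.FullAddress, d time.Duration) (*gonet.TCPConn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()
	// DialContextTCP returns ctx.Err() (context.DeadlineExceeded here) if
	// the handshake does not complete in time, closing the endpoint.
	return gonet.DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber)
}
```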
+
+// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
+// net.Conn and net.PacketConn.
+type UDPConn struct {
+ deadlineTimer
+
+ stack *stack.Stack
+ ep tcpip.Endpoint
+ wq *waiter.Queue
+}
+
+// NewUDPConn creates a new UDPConn.
+func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
+ c := &UDPConn{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
+// DialUDP creates a new UDPConn.
+//
+// If laddr is nil, a local address is automatically chosen.
+//
+// If raddr is nil, the UDPConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
+ if err != nil {
+ return nil, errors.New(err.String())
+ }
+
+ if laddr != nil {
+ if err := ep.Bind(*laddr); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "udp",
+ Addr: fullToUDPAddr(*laddr),
+ Err: errors.New(err.String()),
+ }
+ }
+ }
+
+ c := NewUDPConn(s, &wq, ep)
+
+ if raddr != nil {
+ if err := c.ep.Connect(*raddr); err != nil {
+ c.ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "udp",
+ Addr: fullToUDPAddr(*raddr),
+ Err: errors.New(err.String()),
+ }
+ }
+ }
+
+ return c, nil
+}
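
An unconnected-UDP sketch built on DialUDP, WriteTo, and ReadFrom (addresses and stack setup are assumed; the UDP tests below construct both):

```go
package example

import (
	"net"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// pingOnce binds laddr, sends one datagram to raddr, and waits for a reply.
func pingOnce(s *stack.Stack, laddr, raddr tcpip.FullAddress) ([]byte, error) {
	// Passing a nil raddr leaves the conn unconnected, so WriteTo and
	// ReadFrom carry explicit peer addresses.
	c, err := gonet.DialUDP(s, &laddr, nil, ipv4.ProtocolNumber)
	if err != nil {
		return nil, err
	}
	defer c.Close()

	peer := &net.UDPAddr{IP: net.IP(raddr.Addr), Port: int(raddr.Port)}
	if _, err := c.WriteTo([]byte("ping"), peer); err != nil {
		return nil, err
	}
	buf := make([]byte, 512)
	n, _, err := c.ReadFrom(buf)
	return buf[:n], err
}
```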
+
+func (c *UDPConn) newOpError(op string, err error) *net.OpError {
+ return c.newRemoteOpError(op, nil, err)
+}
+
+func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+ return &net.OpError{
+ Op: op,
+ Net: "udp",
+ Source: c.LocalAddr(),
+ Addr: remote,
+ Err: err,
+ }
+}
+
+// RemoteAddr implements net.Conn.RemoteAddr.
+func (c *UDPConn) RemoteAddr() net.Addr {
+ a, err := c.ep.GetRemoteAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToUDPAddr(a)
+}
+
+// Read implements net.Conn.Read.
+func (c *UDPConn) Read(b []byte) (int, error) {
+ bytesRead, _, err := c.ReadFrom(b)
+ return bytesRead, err
+}
+
+// ReadFrom implements net.PacketConn.ReadFrom.
+func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
+ deadline := c.readCancel()
+
+ var addr tcpip.FullAddress
+ read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return copy(b, read), fullToUDPAddr(addr), nil
+}
+
+// Write implements net.Conn.Write.
+func (c *UDPConn) Write(b []byte) (int, error) {
+ return c.WriteTo(b, nil)
+}
+
+// WriteTo implements net.PacketConn.WriteTo.
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+ deadline := c.writeCancel()
+
+ // Check if deadline has already expired.
+ select {
+ case <-deadline:
+ return 0, c.newRemoteOpError("write", addr, &timeoutError{})
+ default:
+ }
+
+	// If we're being called by Write, there is no addr.
+ wopts := tcpip.WriteOptions{}
+ if addr != nil {
+ ua := addr.(*net.UDPAddr)
+ wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+ }
+
+ v := buffer.NewView(len(b))
+ copy(v, b)
+
+ n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+ if resCh != nil {
+ select {
+ case <-deadline:
+ return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
+ case <-resCh:
+ }
+
+ n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.wq.EventUnregister(&waitEntry)
+ for {
+ select {
+ case <-deadline:
+ return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
+ case <-notifyCh:
+ }
+
+ n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
+ }
+ }
+
+ if err == nil {
+ return int(n), nil
+ }
+
+ return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
+}
+
+// Close implements net.PacketConn.Close.
+func (c *UDPConn) Close() error {
+ c.ep.Close()
+ return nil
+}
+
+// LocalAddr implements net.PacketConn.LocalAddr.
+func (c *UDPConn) LocalAddr() net.Addr {
+ a, err := c.ep.GetLocalAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToUDPAddr(a)
+}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
new file mode 100644
index 000000000..3c552988a
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -0,0 +1,716 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gonet
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/net/nettest"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ NICID = 1
+)
+
+func TestTimeouts(t *testing.T) {
+ nc := NewTCPConn(nil, nil)
+ dlfs := []struct {
+ name string
+ f func(time.Time) error
+ }{
+ {"SetDeadline", nc.SetDeadline},
+ {"SetReadDeadline", nc.SetReadDeadline},
+ {"SetWriteDeadline", nc.SetWriteDeadline},
+ }
+
+ for _, dlf := range dlfs {
+ if err := dlf.f(time.Time{}); err != nil {
+ t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil)
+ }
+ }
+}
+
+func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
+ // Create the stack and add a NIC.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ })
+
+ if err := s.CreateNIC(NICID, loopback.New()); err != nil {
+ return nil, err
+ }
+
+ // Add default route.
+ s.SetRouteTable([]tcpip.Route{
+ // IPv4
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: NICID,
+ },
+
+ // IPv6
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: NICID,
+ },
+ })
+
+ return s, nil
+}
+
+type testConnection struct {
+ wq *waiter.Queue
+ e *waiter.Entry
+ ch chan struct{}
+ ep tcpip.Endpoint
+}
+
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
+ wq := &waiter.Queue{}
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+
+ entry, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&entry, waiter.EventOut)
+
+ err = ep.Connect(addr)
+ if err == tcpip.ErrConnectStarted {
+ <-ch
+ err = ep.GetSockOpt(tcpip.ErrorOption{})
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ wq.EventUnregister(&entry)
+ wq.EventRegister(&entry, waiter.EventIn)
+
+ return &testConnection{wq, &entry, ch, ep}, nil
+}
+
+func (c *testConnection) close() {
+ c.wq.EventUnregister(c.e)
+ c.ep.Close()
+}
+
+// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock.
+func TestCloseReader(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
+ if e != nil {
+ t.Fatalf("NewListener() = %v", e)
+ }
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatalf("l.Accept() = %v", err)
+ }
+
+ // Give c.Read() a chance to block before closing the connection.
+ time.AfterFunc(time.Millisecond*50, func() {
+ c.Close()
+ })
+
+ buf := make([]byte, 256)
+ n, err := c.Read(buf)
+ if n != 0 || err != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err)
+ }
+ }()
+ sender, err := connect(s, addr)
+ if err != nil {
+ t.Fatalf("connect() = %v", err)
+ }
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ t.Errorf("c.Read() didn't unblock")
+ }
+ sender.close()
+}
+
+// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
+// using tcp.Forwarder.
+func TestCloseReaderWithForwarder(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ done := make(chan struct{})
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ defer close(done)
+
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+ r.Complete(false)
+
+ c := NewTCPConn(&wq, ep)
+
+ // Give c.Read() a chance to block before closing the connection.
+ time.AfterFunc(time.Millisecond*50, func() {
+ c.Close()
+ })
+
+ buf := make([]byte, 256)
+ n, e := c.Read(buf)
+ if n != 0 || e != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e)
+ }
+ })
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ sender, err := connect(s, addr)
+ if err != nil {
+ t.Fatalf("connect() = %v", err)
+ }
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ t.Errorf("c.Read() didn't unblock")
+ }
+ sender.close()
+}
+
+func TestCloseRead(t *testing.T) {
+ s, terr := newLoopbackStack()
+ if terr != nil {
+ t.Fatalf("newLoopbackStack() = %v", terr)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ var wq waiter.Queue
+ _, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ // Endpoint will be closed in deferred s.Close (above).
+ })
+
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ tc, terr := connect(s, addr)
+ if terr != nil {
+ t.Fatalf("connect() = %v", terr)
+ }
+ c := NewTCPConn(tc.wq, tc.ep)
+
+ if err := c.CloseRead(); err != nil {
+ t.Errorf("c.CloseRead() = %v", err)
+ }
+
+ buf := make([]byte, 256)
+ if n, err := c.Read(buf); err != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err)
+ }
+
+ if n, err := c.Write([]byte("abc123")); n != 6 || err != nil {
+ t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err)
+ }
+}
+
+func TestCloseWrite(t *testing.T) {
+ s, terr := newLoopbackStack()
+ if terr != nil {
+ t.Fatalf("newLoopbackStack() = %v", terr)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+ r.Complete(false)
+
+ c := NewTCPConn(&wq, ep)
+
+ n, e := c.Read(make([]byte, 256))
+ if n != 0 || e != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e)
+ }
+
+ if n, e = c.Write([]byte("abc123")); n != 6 || e != nil {
+ t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
+ }
+ })
+
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ tc, terr := connect(s, addr)
+ if terr != nil {
+ t.Fatalf("connect() = %v", terr)
+ }
+ c := NewTCPConn(tc.wq, tc.ep)
+
+ if err := c.CloseWrite(); err != nil {
+ t.Errorf("c.CloseWrite() = %v", err)
+ }
+
+ buf := make([]byte, 256)
+ n, err := c.Read(buf)
+ if err != nil || string(buf[:n]) != "abc123" {
+ t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err)
+ }
+
+ n, err = c.Write([]byte("abc123"))
+ got, ok := err.(*net.OpError)
+ want := "endpoint is closed for send"
+ if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) {
+ t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want)
+ }
+}
+
+func TestUDPForwarder(t *testing.T) {
+ s, terr := newLoopbackStack()
+ if terr != nil {
+ t.Fatalf("newLoopbackStack() = %v", terr)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr1 := tcpip.FullAddress{NICID, ip1, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
+ addr2 := tcpip.FullAddress{NICID, ip2, 11311}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+
+ done := make(chan struct{})
+ fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
+ defer close(done)
+
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+
+ c := NewTCPConn(&wq, ep)
+
+ buf := make([]byte, 256)
+ n, e := c.Read(buf)
+ if e != nil {
+ t.Errorf("c.Read() = %v", e)
+ }
+
+ if _, e := c.Write(buf[:n]); e != nil {
+ t.Errorf("c.Write() = %v", e)
+ }
+ })
+ s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
+
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ sent := "abc123"
+ sendAddr := fullToUDPAddr(addr1)
+ if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
+ t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
+ }
+
+ buf := make([]byte, 256)
+ n, recvAddr, err := c2.ReadFrom(buf)
+ if err != nil || recvAddr.String() != sendAddr.String() {
+ t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil)
+ }
+}
+
+// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
+func TestDeadlineChange(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
+ if e != nil {
+ t.Fatalf("NewListener() = %v", e)
+ }
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatalf("l.Accept() = %v", err)
+ }
+
+ c.SetDeadline(time.Now().Add(time.Minute))
+ // Give c.Read() a chance to block before closing the connection.
+ time.AfterFunc(time.Millisecond*50, func() {
+ c.SetDeadline(time.Now().Add(time.Millisecond * 10))
+ })
+
+ buf := make([]byte, 256)
+ n, err := c.Read(buf)
+ got, ok := err.(*net.OpError)
+ want := "i/o timeout"
+ if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
+ t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
+ }
+ }()
+ sender, err := connect(s, addr)
+ if err != nil {
+ t.Fatalf("connect() = %v", err)
+ }
+
+ select {
+ case <-done:
+ case <-time.After(time.Millisecond * 500):
+ t.Errorf("c.Read() didn't unblock")
+ }
+ sender.close()
+}
+
+func TestPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr1 := tcpip.FullAddress{NICID, ip1, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
+ addr2 := tcpip.FullAddress{NICID, ip2, 11311}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+
+ c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 4):", err)
+ }
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ sendAddr := fullToUDPAddr(addr2)
+ if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
+ t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, recvAddr, err := c2.ReadFrom(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) {
+ t.Errorf("got recvAddr = %v, want = %v", recvAddr, want)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
+func TestConnectedPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 4):", err)
+ }
+ c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, err := c1.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
+func makePipe() (c1, c2 net.Conn, stop func(), err error) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ }
+
+ c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ l.Close()
+ return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ }
+
+ c2, err = l.Accept()
+ if err != nil {
+ l.Close()
+ c1.Close()
+ return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ }
+
+ stop = func() {
+ c1.Close()
+ c2.Close()
+ s.Close()
+ s.Wait()
+ }
+
+ if err := l.Close(); err != nil {
+ stop()
+ return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ }
+
+ return c1, c2, stop, nil
+}
+
+func TestTCPConnTransfer(t *testing.T) {
+ c1, c2, _, err := makePipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+ }()
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ const sent = "abc123"
+
+ tests := []struct {
+ name string
+ c1 net.Conn
+ c2 net.Conn
+ }{
+ {"connected to accepted", c1, c2},
+ {"accepted to connected", c2, c1},
+ }
+
+ for _, test := range tests {
+ if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil)
+ continue
+ }
+
+ recv := make([]byte, len(sent))
+ n, err := test.c2.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil)
+ continue
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent)
+ }
+ }
+}
+
+func TestTCPDialError(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+
+ _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
+ got, ok := err.(*net.OpError)
+ want := tcpip.ErrNoRoute
+ if !ok || got.Err.Error() != want.String() {
+ t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
+ }
+}
+
+func TestDialContextTCPCanceled(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ cancel()
+
+ if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled {
+ t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled)
+ }
+}
+
+func TestDialContextTCPTimeout(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ time.Sleep(time.Second)
+ r.Complete(true)
+ })
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ ctx := context.Background()
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond))
+ defer cancel()
+
+ if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded {
+ t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded)
+ }
+}
+
+func TestNetTest(t *testing.T) {
+ nettest.TestConn(t, makePipe)
+}
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
new file mode 100644
index 000000000..563bc78ea
--- /dev/null
+++ b/pkg/tcpip/buffer/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "buffer",
+ srcs = [
+ "prependable.go",
+ "view.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "buffer_test",
+ size = "small",
+ srcs = ["view_test.go"],
+ library = ":buffer",
+)
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
new file mode 100644
index 000000000..ba21f4eca
--- /dev/null
+++ b/pkg/tcpip/buffer/prependable.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+// Prependable is a buffer that grows backwards, that is, more data can be
+// prepended to it. It is useful when building networking packets, where each
+// protocol adds its own headers to the front of the higher-level protocol
+// header and payload; for example, TCP would prepend its header to the payload,
+// then IP would prepend its own, then ethernet.
+type Prependable struct {
+ // Buf is the buffer backing the prependable buffer.
+ buf View
+
+ // usedIdx is the index where the used part of the buffer begins.
+ usedIdx int
+}
+
+// NewPrependable allocates a new prependable buffer with the given size.
+func NewPrependable(size int) Prependable {
+ return Prependable{buf: NewView(size), usedIdx: size}
+}
+
+// NewPrependableFromView creates an entirely-used Prependable from a View.
+//
+// NewPrependableFromView takes ownership of v. Note that since the entire
+// prependable is used, further calls to Prepend will find that size >
+// p.usedIdx and return nil.
+func NewPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: 0}
+}
+
+// NewEmptyPrependableFromView creates a new prependable buffer from a View.
+func NewEmptyPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: len(v)}
+}
+
+// View returns a View of the backing buffer that contains all prepended
+// data so far.
+func (p Prependable) View() View {
+ return p.buf[p.usedIdx:]
+}
+
+// UsedLength returns the number of bytes used so far.
+func (p Prependable) UsedLength() int {
+ return len(p.buf) - p.usedIdx
+}
+
+// AvailableLength returns the number of bytes still available for prepending.
+func (p Prependable) AvailableLength() int {
+ return p.usedIdx
+}
+
+// TrimBack removes size bytes from the end.
+func (p *Prependable) TrimBack(size int) {
+ p.buf = p.buf[:len(p.buf)-size]
+}
+
+// Prepend reserves the requested space in front of the buffer, returning a
+// slice that represents the reserved space.
+func (p *Prependable) Prepend(size int) []byte {
+ if size > p.usedIdx {
+ return nil
+ }
+
+ p.usedIdx -= size
+ return p.View()[:size:size]
+}
+
+// DeepCopy copies p and the bytes backing it.
+func (p Prependable) DeepCopy() Prependable {
+ p.buf = append(View(nil), p.buf...)
+ return p
+}
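
A packet-building sketch for the API above (header sizes are illustrative; real callers fill the reserved slices with the header package's encoders):

```go
package example

import "gvisor.dev/gvisor/pkg/tcpip/buffer"

// buildPacket reserves room for two headers plus a payload, then fills the
// buffer front-to-back the way protocol layers do: payload first, with each
// lower layer prepending its header in turn.
func buildPacket(payload []byte) buffer.View {
	const tcpHdrLen, ipHdrLen = 20, 20
	p := buffer.NewPrependable(ipHdrLen + tcpHdrLen + len(payload))

	copy(p.Prepend(len(payload)), payload) // innermost data goes in first
	tcpHdr := p.Prepend(tcpHdrLen)         // then the transport header...
	ipHdr := p.Prepend(ipHdrLen)           // ...then the network header
	_, _ = tcpHdr, ipHdr                   // would be filled by real encoders

	return p.View() // everything prepended so far, front to back
}
```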
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
new file mode 100644
index 000000000..9a3c5d6c3
--- /dev/null
+++ b/pkg/tcpip/buffer/view.go
@@ -0,0 +1,256 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+package buffer
+
+import (
+ "bytes"
+ "io"
+)
+
+// View is a slice of a buffer, with convenience methods.
+type View []byte
+
+// NewView allocates a new buffer and returns an initialized view that covers
+// the whole buffer.
+func NewView(size int) View {
+ return make(View, size)
+}
+
+// NewViewFromBytes allocates a new buffer and copies in the given bytes.
+func NewViewFromBytes(b []byte) View {
+ return append(View(nil), b...)
+}
+
+// TrimFront removes the first "count" bytes from the visible section of the
+// buffer.
+func (v *View) TrimFront(count int) {
+ *v = (*v)[count:]
+}
+
+// CapLength irreversibly reduces the length of the visible section of the
+// buffer to the value specified.
+func (v *View) CapLength(length int) {
+ // We also set the slice cap because if we don't, one would be able to
+ // expand the view back to include the region just excluded. We want to
+ // prevent that to avoid a potential data leak if there is uninitialized
+ // data in the excluded region.
+ *v = (*v)[:length:length]
+}
+
+// Reader returns a bytes.Reader for v.
+func (v *View) Reader() bytes.Reader {
+ var r bytes.Reader
+ r.Reset(*v)
+ return r
+}
+
+// ToVectorisedView returns a VectorisedView containing the receiver.
+func (v View) ToVectorisedView() VectorisedView {
+ if len(v) == 0 {
+ return VectorisedView{}
+ }
+ return NewVectorisedView(len(v), []View{v})
+}
+
+// VectorisedView is a vectorised version of View using non-contiguous memory.
+// It supports all the convenience methods supported by View.
+//
+// +stateify savable
+type VectorisedView struct {
+ views []View
+ size int
+}
+
+// NewVectorisedView creates a new vectorised view from an already-allocated slice
+// of View and sets its size.
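+//
+// For example (illustrative):
+//
+//	views := []View{View("foo"), View("bar")}
+//	vv := NewVectorisedView(6, views)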
+func NewVectorisedView(size int, views []View) VectorisedView {
+ return VectorisedView{views: views, size: size}
+}
+
+// TrimFront removes the first "count" bytes of the vectorised view. If count
+// exceeds the size, the vectorised view is emptied.
+func (vv *VectorisedView) TrimFront(count int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ vv.views[0].TrimFront(count)
+ return
+ }
+ count -= len(vv.views[0])
+ vv.removeFirst()
+ }
+}
+
+// Read implements io.Reader.
+func (vv *VectorisedView) Read(v View) (copied int, err error) {
+ count := len(v)
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ copy(v[copied:], vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return copied, nil
+ }
+ count -= len(vv.views[0])
+ copy(v[copied:], vv.views[0])
+ copied += len(vv.views[0])
+ vv.removeFirst()
+ }
+ if copied == 0 {
+ return 0, io.EOF
+ }
+ return copied, nil
+}
+
+// ReadToVV reads up to count bytes from vv to dstVV and removes them from vv. It
+// returns the number of bytes copied.
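+//
+// For example (a sketch; src is an assumed, already-populated VectorisedView):
+//
+//	var dst VectorisedView
+//	n := src.ReadToVV(&dst, 4) // Moves up to 4 bytes from src into dst.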
+func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ dstVV.AppendView(vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return
+ }
+ count -= len(vv.views[0])
+ dstVV.AppendView(vv.views[0])
+ copied += len(vv.views[0])
+ vv.removeFirst()
+ }
+ return copied
+}
+
+// CapLength irreversibly reduces the length of the vectorised view.
+func (vv *VectorisedView) CapLength(length int) {
+ if length < 0 {
+ length = 0
+ }
+ if vv.size < length {
+ return
+ }
+ vv.size = length
+ for i := range vv.views {
+ v := &vv.views[i]
+ if len(*v) >= length {
+ if length == 0 {
+ vv.views = vv.views[:i]
+ } else {
+ v.CapLength(length)
+ vv.views = vv.views[:i+1]
+ }
+ return
+ }
+ length -= len(*v)
+ }
+}
+
+// Clone returns a clone of this VectorisedView.
+// If the buffer argument is large enough to contain all the Views of this VectorisedView,
+// the method will avoid allocations and use the buffer to store the Views of the clone.
+func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
+ return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
+}
+
+// PullUp returns the first "count" bytes of the vectorised view. If those
+// bytes aren't already contiguous inside the vectorised view, PullUp will
+// reallocate as needed to make them contiguous. PullUp fails and returns false
+// when count > vv.Size().
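+//
+// For example (illustrative):
+//
+//	vv := NewVectorisedView(3, []View{View("1"), View("23")})
+//	first, ok := vv.PullUp(2) // first == View("12"), ok == true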
+func (vv *VectorisedView) PullUp(count int) (View, bool) {
+ if len(vv.views) == 0 {
+ return nil, count == 0
+ }
+ if count <= len(vv.views[0]) {
+ return vv.views[0][:count], true
+ }
+ if count > vv.size {
+ return nil, false
+ }
+
+ newFirst := NewView(count)
+ i := 0
+ for offset := 0; offset < count; i++ {
+ copy(newFirst[offset:], vv.views[i])
+ if count-offset < len(vv.views[i]) {
+ vv.views[i].TrimFront(count - offset)
+ break
+ }
+ offset += len(vv.views[i])
+ vv.views[i] = nil
+ }
+ // We're guaranteed that i > 0, since count is too large for the first
+ // view.
+ vv.views[i-1] = newFirst
+ vv.views = vv.views[i-1:]
+ return newFirst, true
+}
+
+// Size returns the size in bytes of the entire content stored in the vectorised view.
+func (vv *VectorisedView) Size() int {
+ return vv.size
+}
+
+// ToView returns a single view containing the content of the vectorised view.
+//
+// If the vectorised view contains a single view, that view will be returned
+// directly.
+func (vv *VectorisedView) ToView() View {
+ if len(vv.views) == 1 {
+ return vv.views[0]
+ }
+ u := make([]byte, 0, vv.size)
+ for _, v := range vv.views {
+ u = append(u, v...)
+ }
+ return u
+}
+
+// Views returns the slice containing all the views.
+func (vv *VectorisedView) Views() []View {
+ return vv.views
+}
+
+// Append appends the views in a vectorised view to this vectorised view.
+func (vv *VectorisedView) Append(vv2 VectorisedView) {
+ vv.views = append(vv.views, vv2.views...)
+ vv.size += vv2.size
+}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ if len(v) == 0 {
+ return
+ }
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
+
+// Readers returns a bytes.Reader for each of vv's views.
+func (vv *VectorisedView) Readers() []bytes.Reader {
+ readers := make([]bytes.Reader, 0, len(vv.views))
+ for _, v := range vv.views {
+ readers = append(readers, v.Reader())
+ }
+ return readers
+}
+
+// removeFirst removes the first view from the vectorised view. It panics when
+// len(vv.views) < 1.
+func (vv *VectorisedView) removeFirst() {
+ vv.size -= len(vv.views[0])
+ vv.views[0] = nil
+ vv.views = vv.views[1:]
+}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
new file mode 100644
index 000000000..726e54de9
--- /dev/null
+++ b/pkg/tcpip/buffer/view_test.go
@@ -0,0 +1,521 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains tests for the VectorisedView type.
+package buffer
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+)
+
+// copy returns a deep-copy of the vectorised view.
+func (vv VectorisedView) copy() VectorisedView {
+ uu := VectorisedView{
+ views: make([]View, 0, len(vv.views)),
+ size: vv.size,
+ }
+ for _, v := range vv.views {
+ uu.views = append(uu.views, append(View(nil), v...))
+ }
+ return uu
+}
+
+// vv is a helper to build a VectorisedView from strings.
+func vv(size int, pieces ...string) VectorisedView {
+ views := make([]View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ return NewVectorisedView(size, views)
+}
+
+var capLengthTestCases = []struct {
+ comment string
+ in VectorisedView
+ length int
+ want VectorisedView
+}{
+ {
+ comment: "Simple case",
+ in: vv(2, "12"),
+ length: 1,
+ want: vv(1, "1"),
+ },
+ {
+ comment: "Case spanning across two Views",
+ in: vv(4, "123", "4"),
+ length: 2,
+ want: vv(2, "12"),
+ },
+ {
+ comment: "Corner case with negative length",
+ in: vv(1, "1"),
+ length: -1,
+ want: vv(0),
+ },
+ {
+ comment: "Corner case with length = 0",
+ in: vv(3, "12", "3"),
+ length: 0,
+ want: vv(0),
+ },
+ {
+ comment: "Corner case with length = size",
+ in: vv(1, "1"),
+ length: 1,
+ want: vv(1, "1"),
+ },
+ {
+ comment: "Corner case with length > size",
+ in: vv(1, "1"),
+ length: 2,
+ want: vv(1, "1"),
+ },
+}
+
+func TestCapLength(t *testing.T) {
+ for _, c := range capLengthTestCases {
+ orig := c.in.copy()
+ c.in.CapLength(c.length)
+ if !reflect.DeepEqual(c.in, c.want) {
+ t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v",
+ c.comment, c.length, orig, c.in, c.want)
+ }
+ }
+}
+
+var trimFrontTestCases = []struct {
+ comment string
+ in VectorisedView
+ count int
+ want VectorisedView
+}{
+ {
+ comment: "Simple case",
+ in: vv(2, "12"),
+ count: 1,
+ want: vv(1, "2"),
+ },
+ {
+ comment: "Case where we trim an entire View",
+ in: vv(2, "1", "2"),
+ count: 1,
+ want: vv(1, "2"),
+ },
+ {
+ comment: "Case spanning across two Views",
+ in: vv(3, "1", "23"),
+ count: 2,
+ want: vv(1, "3"),
+ },
+ {
+ comment: "Corner case with negative count",
+ in: vv(1, "1"),
+ count: -1,
+ want: vv(1, "1"),
+ },
+ {
+ comment: " Corner case with count = 0",
+ in: vv(1, "1"),
+ count: 0,
+ want: vv(1, "1"),
+ },
+ {
+ comment: "Corner case with count = size",
+ in: vv(1, "1"),
+ count: 1,
+ want: vv(0),
+ },
+ {
+ comment: "Corner case with count > size",
+ in: vv(1, "1"),
+ count: 2,
+ want: vv(0),
+ },
+}
+
+func TestTrimFront(t *testing.T) {
+ for _, c := range trimFrontTestCases {
+ orig := c.in.copy()
+ c.in.TrimFront(c.count)
+ if !reflect.DeepEqual(c.in, c.want) {
+ t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v",
+ c.comment, c.count, orig, c.in, c.want)
+ }
+ }
+}
+
+var toViewCases = []struct {
+ comment string
+ in VectorisedView
+ want View
+}{
+ {
+ comment: "Simple case",
+ in: vv(2, "12"),
+ want: []byte("12"),
+ },
+ {
+ comment: "Case with multiple views",
+ in: vv(2, "1", "2"),
+ want: []byte("12"),
+ },
+ {
+ comment: "Empty case",
+ in: vv(0),
+ want: []byte(""),
+ },
+}
+
+func TestToView(t *testing.T) {
+ for _, c := range toViewCases {
+ got := c.in.ToView()
+ if !reflect.DeepEqual(got, c.want) {
+ t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v",
+ c.comment, c.in, got, c.want)
+ }
+ }
+}
+
+var toCloneCases = []struct {
+ comment string
+ inView VectorisedView
+ inBuffer []View
+}{
+ {
+ comment: "Simple case",
+ inView: vv(1, "1"),
+ inBuffer: make([]View, 1),
+ },
+ {
+ comment: "Case with multiple views",
+ inView: vv(2, "1", "2"),
+ inBuffer: make([]View, 2),
+ },
+ {
+ comment: "Case with buffer too small",
+ inView: vv(2, "1", "2"),
+ inBuffer: make([]View, 1),
+ },
+ {
+ comment: "Case with buffer larger than needed",
+ inView: vv(1, "1"),
+ inBuffer: make([]View, 2),
+ },
+ {
+ comment: "Case with nil buffer",
+ inView: vv(1, "1"),
+ inBuffer: nil,
+ },
+}
+
+func TestToClone(t *testing.T) {
+ for _, c := range toCloneCases {
+ t.Run(c.comment, func(t *testing.T) {
+ got := c.inView.Clone(c.inBuffer)
+ if !reflect.DeepEqual(got, c.inView) {
+ t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v",
+ c.inView, c.inBuffer, got, c.inView)
+ }
+ })
+ }
+}
+
+func TestVVReadToVV(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ wantBytes: "0123456789",
+ leftVV: vv(20, "01234567890123456789"),
+ },
+ {
+ comment: "largeVV, multiple views, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ wantBytes: "123345",
+ leftVV: vv(7, "567", "8910"),
+ },
+ {
+ comment: "smallVV (multiple views), large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ wantBytes: "123",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "smallVV (single view), large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ wantBytes: "1",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ wantBytes: "",
+ leftVV: vv(0, ""),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ var readTo VectorisedView
+ inSize := tc.vv.Size()
+ copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead)
+ if got, want := copied, len(tc.wantBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc)
+ }
+ if got, want := string(readTo.ToView()), tc.wantBytes; got != want {
+ t.Errorf("unexpected content in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want)
+ }
+ })
+ }
+}
+
+func TestVVRead(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ readBytes string
+ leftBytes string
+ wantError bool
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ readBytes: "0123456789",
+ leftBytes: "01234567890123456789",
+ },
+ {
+ comment: "largeVV, multiple buffers, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ readBytes: "123345",
+ leftBytes: "5678910",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ readBytes: "123",
+ leftBytes: "",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ readBytes: "1",
+ leftBytes: "",
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ readBytes: "",
+ wantError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ readTo := NewView(tc.bytesToRead)
+ inSize := tc.vv.Size()
+ copied, err := tc.vv.Read(readTo)
+ if tc.wantError && err == nil {
+ t.Fatalf("expected error in tc.vv.Read(..), got nil")
+ }
+ if !tc.wantError && err != nil {
+ t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err)
+ }
+ readTo = readTo[:copied]
+ if got, want := copied, len(tc.readBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(readTo), tc.readBytes; got != want {
+ t.Errorf("unexpected data in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want {
+ t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+var pullUpTestCases = []struct {
+ comment string
+ in VectorisedView
+ count int
+ want []byte
+ result VectorisedView
+ ok bool
+}{
+ {
+ comment: "simple case",
+ in: vv(2, "12"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(2, "12"),
+ ok: true,
+ },
+ {
+ comment: "entire View",
+ in: vv(2, "1", "2"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(2, "1", "2"),
+ ok: true,
+ },
+ {
+ comment: "spanning across two Views",
+ in: vv(3, "1", "23"),
+ count: 2,
+ want: []byte("12"),
+ result: vv(3, "12", "3"),
+ ok: true,
+ },
+ {
+ comment: "spanning across all Views",
+ in: vv(5, "1", "23", "45"),
+ count: 5,
+ want: []byte("12345"),
+ result: vv(5, "12345"),
+ ok: true,
+ },
+ {
+ comment: "count = 0",
+ in: vv(1, "1"),
+ count: 0,
+ want: []byte{},
+ result: vv(1, "1"),
+ ok: true,
+ },
+ {
+ comment: "count = size",
+ in: vv(1, "1"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(1, "1"),
+ ok: true,
+ },
+ {
+ comment: "count too large",
+ in: vv(3, "1", "23"),
+ count: 4,
+ want: nil,
+ result: vv(3, "1", "23"),
+ ok: false,
+ },
+ {
+ comment: "empty vv",
+ in: vv(0, ""),
+ count: 1,
+ want: nil,
+ result: vv(0, ""),
+ ok: false,
+ },
+ {
+ comment: "empty vv, count = 0",
+ in: vv(0, ""),
+ count: 0,
+ want: nil,
+ result: vv(0, ""),
+ ok: true,
+ },
+ {
+ comment: "empty views",
+ in: vv(3, "", "1", "", "23"),
+ count: 2,
+ want: []byte("12"),
+ result: vv(3, "12", "3"),
+ ok: true,
+ },
+}
+
+func TestPullUp(t *testing.T) {
+ for _, c := range pullUpTestCases {
+ got, ok := c.in.PullUp(c.count)
+
+ // Is the return value right?
+ if ok != c.ok {
+ t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
+ c.comment, c.count, c.in, ok, c.ok)
+ }
+ if bytes.Compare(got, View(c.want)) != 0 {
+ t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
+ c.comment, c.count, c.in, got, c.want)
+ }
+
+ // Is the underlying structure right?
+ if !reflect.DeepEqual(c.in, c.result) {
+ t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v",
+ c.comment, c.count, c.in, c.result)
+ }
+ }
+}
+
+func TestToVectorisedView(t *testing.T) {
+ testCases := []struct {
+ in View
+ want VectorisedView
+ }{
+ {nil, VectorisedView{}},
+ {View{}, VectorisedView{}},
+ {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}},
+ }
+ for _, tc := range testCases {
+ if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
+ }
+ }
+}
+
+func TestAppendView(t *testing.T) {
+ testCases := []struct {
+ vv VectorisedView
+ in View
+ want VectorisedView
+ }{
+ {VectorisedView{}, nil, VectorisedView{}},
+ {VectorisedView{}, View{}, VectorisedView{}},
+ {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}},
+ {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}},
+ {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}},
+ }
+ for _, tc := range testCases {
+ tc.vv.AppendView(tc.in)
+ if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
new file mode 100644
index 000000000..ed434807f
--- /dev/null
+++ b/pkg/tcpip/checker/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "checker",
+ testonly = 1,
+ srcs = ["checker.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ ],
+)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
new file mode 100644
index 000000000..ee264b726
--- /dev/null
+++ b/pkg/tcpip/checker/checker.go
@@ -0,0 +1,976 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package checker provides helper functions to check networking packets for
+// validity.
+package checker
+
+import (
+ "encoding/binary"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// NetworkChecker is a function to check a property of a network packet.
+type NetworkChecker func(*testing.T, []header.Network)
+
+// TransportChecker is a function to check a property of a transport packet.
+type TransportChecker func(*testing.T, header.Transport)
+
+// ControlMessagesChecker is a function to check a property of ancillary data.
+type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
+
+// IPv4 checks the validity and properties of the given IPv4 packet. It is
+// expected to be used in conjunction with other network checkers for specific
+// properties. For example, to check the source and destination address, one
+// would call:
+//
+// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
+func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv4 := header.IPv4(b)
+
+ if !ipv4.IsValid(len(b)) {
+ t.Error("Not a valid IPv4 packet")
+ }
+
+ xsum := ipv4.CalculateChecksum()
+ if xsum != 0 && xsum != 0xffff {
+ t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
+ }
+
+ for _, f := range checkers {
+ f(t, []header.Network{ipv4})
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+}
+
+// IPv6 checks the validity and properties of the given IPv6 packet. The usage
+// is similar to IPv4.
+func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv6 := header.IPv6(b)
+ if !ipv6.IsValid(len(b)) {
+ t.Error("Not a valid IPv6 packet")
+ }
+
+ for _, f := range checkers {
+ f(t, []header.Network{ipv6})
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+}
+
+// SrcAddr creates a checker that checks the source address.
+func SrcAddr(addr tcpip.Address) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ if a := h[0].SourceAddress(); a != addr {
+ t.Errorf("Bad source address, got %v, want %v", a, addr)
+ }
+ }
+}
+
+// DstAddr creates a checker that checks the destination address.
+func DstAddr(addr tcpip.Address) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ if a := h[0].DestinationAddress(); a != addr {
+ t.Errorf("Bad destination address, got %v, want %v", a, addr)
+ }
+ }
+}
+
+// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
+func TTL(ttl uint8) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint8
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TTL()
+ case header.IPv6:
+ v = ip.HopLimit()
+ }
+ if v != ttl {
+ t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ }
+ }
+}
+
+// PayloadLen creates a checker that checks the payload length.
+func PayloadLen(plen int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ if l := len(h[0].Payload()); l != plen {
+ t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ }
+ }
+}
+
+// FragmentOffset creates a checker that checks the FragmentOffset field.
+func FragmentOffset(offset uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // We only do this for IPv4 for now.
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if v := ip.FragmentOffset(); v != offset {
+ t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ }
+ }
+ }
+}
+
+// FragmentFlags creates a checker that checks the fragment flags field.
+func FragmentFlags(flags uint8) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // We only do this for IPv4 for now.
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if v := ip.Flags(); v != flags {
+ t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ }
+ }
+ }
+}
+
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
+ }
+ if got := cm.TClass; got != want {
+ t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
+// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
+func ReceiveTOS(want uint8) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTOS {
+ t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ }
+ if got := cm.TOS; got != want {
+ t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+ }
+}
+
+// TOS creates a checker that checks the TOS field.
+func TOS(tos uint8, label uint32) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ if v, l := h[0].TOS(); v != tos || l != label {
+ t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ }
+ }
+}
+
+// Raw creates a checker that checks the bytes of payload.
+// The checker always checks the payload of the last network header.
+// For instance, in case of IPv6 fragments, the payload that will be checked
+// is the one containing the actual data that the packet is carrying, without
+// the bytes added by the IPv6 fragmentation.
+func Raw(want []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
+ t.Errorf("Wrong payload, got %v, want %v", got, want)
+ }
+ }
+}
+
+// IPv6Fragment creates a checker that validates an IPv6 fragment.
+func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
+ t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ }
+
+ ipv6Frag := header.IPv6Fragment(h[0].Payload())
+ if !ipv6Frag.IsValid() {
+ t.Error("Not a valid IPv6 fragment")
+ }
+
+ for _, f := range checkers {
+ f(t, []header.Network{h[0], ipv6Frag})
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// TCP creates a checker that checks that the transport protocol is TCP and
+// potentially additional transport header fields.
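+//
+// For example (illustrative):
+//
+//	checker.IPv4(t, b, checker.TCP(checker.DstPort(80), checker.SeqNum(100)))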
+func TCP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ first := h[0]
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
+ t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ }
+
+ // Verify the checksum.
+ tcp := header.TCP(last.Payload())
+ l := uint16(len(tcp))
+
+ xsum := header.Checksum([]byte(first.SourceAddress()), 0)
+ xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
+ xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
+ xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
+ xsum = header.Checksum(tcp, xsum)
+
+ if xsum != 0 && xsum != 0xffff {
+ t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
+ }
+
+ // Run the transport checkers.
+ for _, f := range checkers {
+ f(t, tcp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// UDP creates a checker that checks that the transport protocol is UDP and
+// potentially additional transport header fields.
+func UDP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
+ t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ }
+
+ udp := header.UDP(last.Payload())
+ for _, f := range checkers {
+ f(t, udp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// SrcPort creates a checker that checks the source port.
+func SrcPort(port uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ if p := h.SourcePort(); p != port {
+ t.Errorf("Bad source port, got %v, want %v", p, port)
+ }
+ }
+}
+
+// DstPort creates a checker that checks the destination port.
+func DstPort(port uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ if p := h.DestinationPort(); p != port {
+ t.Errorf("Bad destination port, got %v, want %v", p, port)
+ }
+ }
+}
+
+// NoChecksum creates a checker that checks that the UDP checksum is zero if
+// noChecksum is true, and non-zero otherwise.
+func NoChecksum(noChecksum bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ udp, ok := h.(header.UDP)
+ if !ok {
+ return
+ }
+
+ if b := udp.Checksum() == 0; b != noChecksum {
+ t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
+ }
+ }
+}
+
+// SeqNum creates a checker that checks the sequence number.
+func SeqNum(seq uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if s := tcp.SequenceNumber(); s != seq {
+ t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ }
+ }
+}
+
+// AckNum creates a checker that checks the ack number.
+func AckNum(seq uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if s := tcp.AckNumber(); s != seq {
+ t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ }
+ }
+}
+
+// Window creates a checker that checks the tcp window.
+func Window(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if w := tcp.WindowSize(); w != window {
+ t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
+ }
+ }
+}
+
+// TCPFlags creates a checker that checks the tcp flags.
+func TCPFlags(flags uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if f := tcp.Flags(); f != flags {
+ t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
+ }
+ }
+}
+
+// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
+// given mask, match the supplied flags.
+func TCPFlagsMatch(flags, mask uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if f := tcp.Flags(); (f & mask) != (flags & mask) {
+ t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+ }
+ }
+}
+
+// TCPSynOptions creates a checker that checks the presence of TCP options in
+// SYN segments.
+//
+// If wantOpts.WS is negative, the window scale option must not be present.
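+//
+// For example (illustrative):
+//
+//	checker.IPv4(t, b, checker.TCP(
+//		checker.TCPSynOptions(header.TCPSynOptions{MSS: 1460, WS: -1})))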
+func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ opts := tcp.Options()
+ limit := len(opts)
+ foundMSS := false
+ foundWS := false
+ foundTS := false
+ foundSACKPermitted := false
+ tsVal := uint32(0)
+ tsEcr := uint32(0)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionMSS:
+ v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+ if wantOpts.MSS != v {
+ t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ }
+ foundMSS = true
+ i += 4
+ case header.TCPOptionWS:
+ if wantOpts.WS < 0 {
+ t.Error("WS present when it shouldn't be")
+ }
+ v := int(opts[i+2])
+ if v != wantOpts.WS {
+ t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ }
+ foundWS = true
+ i += 3
+ case header.TCPOptionTS:
+ if i+9 >= limit {
+ t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
+ }
+ if opts[i+1] != 10 {
+ t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
+ }
+ tsVal = binary.BigEndian.Uint32(opts[i+2:])
+ tsEcr = uint32(0)
+ if tcp.Flags()&header.TCPFlagAck != 0 {
+ // If the segment is a SYN-ACK then read
+ // the tsEcr value as well.
+ tsEcr = binary.BigEndian.Uint32(opts[i+6:])
+ }
+ foundTS = true
+ i += 10
+ case header.TCPOptionSACKPermitted:
+ if i+1 >= limit {
+ t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
+ }
+ if opts[i+1] != 2 {
+ t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
+ }
+ foundSACKPermitted = true
+ i += 2
+
+ default:
+ i += int(opts[i+1])
+ }
+ }
+
+ if !foundMSS {
+ t.Errorf("MSS option not found. Options: %x", opts)
+ }
+
+ if !foundWS && wantOpts.WS >= 0 {
+ t.Errorf("WS option not found. Options: %x", opts)
+ }
+ if wantOpts.TS && !foundTS {
+ t.Errorf("TS option not found. Options: %x", opts)
+ }
+ if foundTS && tsVal == 0 {
+ t.Error("TS option specified but the timestamp value is zero")
+ }
+ if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
+ t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ }
+ if wantOpts.SACKPermitted && !foundSACKPermitted {
+ t.Errorf("SACKPermitted option not found. Options: %x", opts)
+ }
+ }
+}
+
+// TCPTimestampChecker creates a checker that validates that a TCP segment has
+// a TCP Timestamp option if wantTS is true. It also compares the wantTSVal
+// and wantTSEcr values with those in the TCP segment (if present).
+//
+// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
+// skipped.
+func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ opts := []byte(tcp.Options())
+ limit := len(opts)
+ foundTS := false
+ tsVal := uint32(0)
+ tsEcr := uint32(0)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionTS:
+ if i+9 >= limit {
+ t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
+ }
+ if opts[i+1] != 10 {
+ t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ }
+ tsVal = binary.BigEndian.Uint32(opts[i+2:])
+ tsEcr = binary.BigEndian.Uint32(opts[i+6:])
+ foundTS = true
+ i += 10
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return
+ }
+ l := int(opts[i+1])
+ if l < 2 || i+l > limit {
+ return
+ }
+ i += l
+ }
+ }
+
+ if wantTS != foundTS {
+ t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ }
+ if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
+ t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ }
+ if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
+ t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ }
+ }
+}
+
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
+// contain any SACK blocks in the TCP options.
+func TCPNoSACKBlockChecker() TransportChecker {
+ return TCPSACKBlockChecker(nil)
+}
+
+// TCPSACKBlockChecker creates a checker that verifies that the segment does
+// contain the specified SACK blocks in the TCP options.
+func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ var gotSACKBlocks []header.SACKBlock
+
+ opts := []byte(tcp.Options())
+ limit := len(opts)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block.
+ t.Errorf("malformed SACK option in options: %v", opts)
+ }
+ sackOptionLen := int(opts[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block.
+ t.Errorf("malformed SACK option length in options: %v", opts)
+ }
+ numBlocks := sackOptionLen / 8
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(opts[i+2+j*8:])
+ end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
+ gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return
+ }
+ l := int(opts[i+1])
+ if l < 2 || i+l > limit {
+ return
+ }
+ i += l
+ }
+ }
+
+ if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
+ t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ }
+ }
+}
+
+// Payload creates a checker that checks the payload.
+func Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ if got := h.Payload(); !reflect.DeepEqual(got, want) {
+ t.Errorf("Wrong payload, got %v, want %v", got, want)
+ }
+ }
+}
+
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
+// potentially additional ICMPv4 header fields.
+func ICMPv4(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
+ }
+
+ icmp := header.ICMPv4(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
+func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
+func ICMPv4Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
+// potentially additional ICMPv6 header fields.
+//
+// ICMPv6 will validate the checksum field before calling checkers.
+func ICMPv6(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
+ }
+
+ icmp := header.ICMPv6(last.Payload())
+ if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want {
+ t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
+func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
+func ICMPv6Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// NDP creates a checker that checks that the packet contains a valid NDP
+// message of type msgType, with potentially additional checks specified by
+// checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDP message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.NDPPayload()); got < minSize {
+ t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// NDPNS creates a checker that checks that the packet contains a valid NDP
+// Neighbor Solicitation message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
+func NDPNS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
+}
+
+// NDPNSTargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborSolicit.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+
+ if got := ns.TargetAddress(); got != want {
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// NDPNA creates a checker that checks that the packet contains a valid NDP
+// Neighbor Advertisement message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNA message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
+func NDPNA(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
+}
+
+// NDPNATargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNATargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.TargetAddress(); got != want {
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
+ }
+ }
+}
+
+// NDPNASolicitedFlag creates a checker that checks the Solicited field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNASolicitedFlag(want bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.SolicitedFlag(); got != want {
+ t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
+ }
+ }
+}
+
+// ndpOptions checks that optsBuf only contains opts.
+func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
+ t.Helper()
+
+ it, err := optsBuf.Iter(true)
+ if err != nil {
+ t.Errorf("optsBuf.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ t.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
+ }
+
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ case header.NDPTargetLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
+ }
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+}
+
+// NDPNAOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Advertisement message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNAOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ ndpOptions(t, na.Options(), opts)
+ }
+}
+
+// NDPNSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ndpOptions(t, ns.Options(), opts)
+ }
+}
+
+// NDPRS creates a checker that checks that the packet contains a valid NDP
+// Router Solicitation message (as per the raw wire format).
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPRS as far as the size of the message is concerned. The values within the
+// message are up to checkers to validate.
+func NDPRS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
+}
+
+// NDPRSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Router Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPRS message as far as the size is concerned.
+func NDPRSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ ndpOptions(t, rs.Options(), opts)
+ }
+}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
new file mode 100644
index 000000000..ff2719291
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "jenkins",
+ srcs = ["jenkins.go"],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "jenkins_test",
+ size = "small",
+ srcs = [
+ "jenkins_test.go",
+ ],
+ library = ":jenkins",
+)
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go
new file mode 100644
index 000000000..52c22230e
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins.go
@@ -0,0 +1,80 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jenkins implements Jenkins's one_at_a_time non-cryptographic hash
+// function, created by Bob Jenkins.
+//
+// See https://en.wikipedia.org/wiki/Jenkins_hash_function#cite_note-dobbsx-1
+//
+package jenkins
+
+import (
+ "hash"
+)
+
+// Sum32 represents Jenkins's one_at_a_time hash.
+//
+// Use the Sum32 type directly (as opposed to New32 below)
+// to avoid allocations.
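+//
+// For example (illustrative):
+//
+//	var h Sum32
+//	h.Write([]byte("hello"))
+//	sum := h.Sum32()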
+type Sum32 uint32
+
+// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash.
+//
+// Its Sum method will lay the value out in big-endian byte order.
+func New32() hash.Hash32 {
+ var s Sum32
+ return &s
+}
+
+// Reset resets the hash to its initial state.
+func (s *Sum32) Reset() { *s = 0 }
+
+// Sum32 returns the hash value.
+func (s *Sum32) Sum32() uint32 {
+ hash := *s
+
+ hash += (hash << 3)
+ hash ^= hash >> 11
+ hash += hash << 15
+
+ return uint32(hash)
+}
+
+// Write adds more data to the running hash.
+//
+// It never returns an error.
+func (s *Sum32) Write(data []byte) (int, error) {
+ hash := *s
+ for _, b := range data {
+ hash += Sum32(b)
+ hash += hash << 10
+ hash ^= hash >> 6
+ }
+ *s = hash
+ return len(data), nil
+}
+
+// Size returns the number of bytes Sum will return.
+func (s *Sum32) Size() int { return 4 }
+
+// BlockSize returns the hash's underlying block size.
+func (s *Sum32) BlockSize() int { return 1 }
+
+// Sum appends the current hash to in and returns the resulting slice.
+//
+// It does not change the underlying hash state.
+func (s *Sum32) Sum(in []byte) []byte {
+ v := s.Sum32()
+ return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go
new file mode 100644
index 000000000..4c78b5808
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins_test.go
@@ -0,0 +1,176 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package jenkins
+
+import (
+ "bytes"
+ "encoding/binary"
+ "hash"
+ "hash/fnv"
+ "math"
+ "testing"
+)
+
+func TestGolden32(t *testing.T) {
+ var golden32 = []struct {
+ out []byte
+ in string
+ }{
+ {[]byte{0x00, 0x00, 0x00, 0x00}, ""},
+ {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"},
+ {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"},
+ {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"},
+ }
+
+ hash := New32()
+
+ for _, g := range golden32 {
+ hash.Reset()
+ done, err := hash.Write([]byte(g.in))
+ if err != nil {
+ t.Fatalf("write error: %s", err)
+ }
+ if done != len(g.in) {
+ t.Fatalf("wrote only %d out of %d bytes", done, len(g.in))
+ }
+ if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) {
+ t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out)
+ }
+ }
+}
+
+func TestIntegrity32(t *testing.T) {
+ data := []byte{'1', '2', 3, 4, 5}
+
+ h := New32()
+ h.Write(data)
+ sum := h.Sum(nil)
+
+ if size := h.Size(); size != len(sum) {
+ t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum))
+ }
+
+ if a := h.Sum(nil); !bytes.Equal(sum, a) {
+ t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a)
+ }
+
+ h.Reset()
+ h.Write(data)
+ if a := h.Sum(nil); !bytes.Equal(sum, a) {
+ t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a)
+ }
+
+ h.Reset()
+ h.Write(data[:2])
+ h.Write(data[2:])
+ if a := h.Sum(nil); !bytes.Equal(sum, a) {
+ t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a)
+ }
+
+ sum32 := h.(hash.Hash32).Sum32()
+ if sum32 != binary.BigEndian.Uint32(sum) {
+ t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32)
+ }
+}
+
+func BenchmarkJenkins32KB(b *testing.B) {
+ h := New32()
+
+ b.SetBytes(1024)
+ data := make([]byte, 1024)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ in := make([]byte, 0, h.Size())
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ h.Reset()
+ h.Write(data)
+ h.Sum(in)
+ }
+}
+
+func BenchmarkFnv32(b *testing.B) {
+ arr := make([]int64, 1000)
+ for i := 0; i < b.N; i++ {
+ var payload [8]byte
+ binary.BigEndian.PutUint32(payload[:4], uint32(i))
+ binary.BigEndian.PutUint32(payload[4:], uint32(i))
+
+ h := fnv.New32()
+ h.Write(payload[:])
+ idx := int(h.Sum32()) % len(arr)
+ arr[idx]++
+ }
+ b.StopTimer()
+ c := 0
+ if b.N > 1000000 {
+ for i := 0; i < len(arr)-1; i++ {
+ if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+ if c == 0 {
+ b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N)
+ }
+ c++
+ }
+ }
+ if c > 0 {
+ b.Logf("Unbalanced buckets: %d", c)
+ }
+ }
+}
+
+func BenchmarkSum32(b *testing.B) {
+ arr := make([]int64, 1000)
+ for i := 0; i < b.N; i++ {
+ var payload [8]byte
+ binary.BigEndian.PutUint32(payload[:4], uint32(i))
+ binary.BigEndian.PutUint32(payload[4:], uint32(i))
+ h := Sum32(0)
+ h.Write(payload[:])
+ idx := int(h.Sum32()) % len(arr)
+ arr[idx]++
+ }
+ b.StopTimer()
+ if b.N > 1000000 {
+ for i := 0; i < len(arr)-1; i++ {
+ if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+ b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
+ break
+ }
+ }
+ }
+}
+
+func BenchmarkNew32(b *testing.B) {
+ arr := make([]int64, 1000)
+ for i := 0; i < b.N; i++ {
+ var payload [8]byte
+ binary.BigEndian.PutUint32(payload[:4], uint32(i))
+ binary.BigEndian.PutUint32(payload[4:], uint32(i))
+ h := New32()
+ h.Write(payload[:])
+ idx := int(h.Sum32()) % len(arr)
+ arr[idx]++
+ }
+ b.StopTimer()
+ if b.N > 1000000 {
+ for i := 0; i < len(arr)-1; i++ {
+ if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+ b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
+ break
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
new file mode 100644
index 000000000..0cde694dc
--- /dev/null
+++ b/pkg/tcpip/header/BUILD
@@ -0,0 +1,69 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "header",
+ srcs = [
+ "arp.go",
+ "checksum.go",
+ "eth.go",
+ "gue.go",
+ "icmpv4.go",
+ "icmpv6.go",
+ "interfaces.go",
+ "ipv4.go",
+ "ipv6.go",
+ "ipv6_extension_headers.go",
+ "ipv6_fragment.go",
+ "ndp_neighbor_advert.go",
+ "ndp_neighbor_solicit.go",
+ "ndp_options.go",
+ "ndp_router_advert.go",
+ "ndp_router_solicit.go",
+ "ndpoptionidentifier_string.go",
+ "tcp.go",
+ "udp.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/seqnum",
+ "@com_github_google_btree//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "header_x_test",
+ size = "small",
+ srcs = [
+ "checksum_test.go",
+ "ipv6_test.go",
+ "ipversion_test.go",
+ "tcp_test.go",
+ ],
+ deps = [
+ ":header",
+ "//pkg/rand",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
+
+go_test(
+ name = "header_test",
+ size = "small",
+ srcs = [
+ "eth_test.go",
+ "ipv6_extension_headers_test.go",
+ "ndp_test.go",
+ ],
+ library = ":header",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
new file mode 100644
index 000000000..718a4720a
--- /dev/null
+++ b/pkg/tcpip/header/arp.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+const (
+ // ARPProtocolNumber is the ARP network protocol number.
+ ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
+
+ // ARPSize is the size of an IPv4-over-Ethernet ARP packet.
+ ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+)
+
+// ARPOp is an ARP opcode.
+type ARPOp uint16
+
+// Typical ARP opcodes defined in RFC 826.
+const (
+ ARPRequest ARPOp = 1
+ ARPReply ARPOp = 2
+)
+
+// ARP is an ARP packet stored in a byte array as described in RFC 826.
+type ARP []byte
+
+func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
+func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
+func (a ARP) hardwareAddressSize() int { return int(a[4]) }
+func (a ARP) protocolAddressSize() int { return int(a[5]) }
+
+// Op is the ARP opcode.
+func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+
+// SetOp sets the ARP opcode.
+func (a ARP) SetOp(op ARPOp) {
+ a[6] = uint8(op >> 8)
+ a[7] = uint8(op)
+}
+
+// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
+func (a ARP) SetIPv4OverEthernet() {
+ a[0], a[1] = 0, 1 // htypeEthernet
+ a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
+ a[4] = 6 // macSize
+ a[5] = uint8(IPv4AddressSize)
+}
+
+// HardwareAddressSender is the link address of the sender.
+// It is a view onto the ARP packet, so it can be used to set the value.
+func (a ARP) HardwareAddressSender() []byte {
+ const s = 8
+ return a[s : s+6]
+}
+
+// ProtocolAddressSender is the protocol address of the sender.
+// It is a view onto the ARP packet, so it can be used to set the value.
+func (a ARP) ProtocolAddressSender() []byte {
+ const s = 8 + 6
+ return a[s : s+4]
+}
+
+// HardwareAddressTarget is the link address of the target.
+// It is a view onto the ARP packet, so it can be used to set the value.
+func (a ARP) HardwareAddressTarget() []byte {
+ const s = 8 + 6 + 4
+ return a[s : s+6]
+}
+
+// ProtocolAddressTarget is the protocol address of the target.
+// It is a view onto the ARP packet, so it can be used to set the value.
+func (a ARP) ProtocolAddressTarget() []byte {
+ const s = 8 + 6 + 4 + 6
+ return a[s : s+4]
+}
+
+// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
+func (a ARP) IsValid() bool {
+ if len(a) < ARPSize {
+ return false
+ }
+ const htypeEthernet = 1
+ const macSize = 6
+ return a.hardwareAddressSpace() == htypeEthernet &&
+ a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
+ a.hardwareAddressSize() == macSize &&
+ a.protocolAddressSize() == IPv4AddressSize
+}
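
As a usage sketch (not part of this change), an IPv4-over-Ethernet ARP request can be assembled entirely with the accessors above; the buildARPRequest helper and its arguments are illustrative:

package arpsketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// buildARPRequest asks who owns targetIP. The hardware target field is
// left zeroed because it is the value being requested.
func buildARPRequest(srcMAC tcpip.LinkAddress, srcIP, targetIP tcpip.Address) header.ARP {
	a := header.ARP(make([]byte, header.ARPSize))
	a.SetIPv4OverEthernet()
	a.SetOp(header.ARPRequest)
	copy(a.HardwareAddressSender(), srcMAC)
	copy(a.ProtocolAddressSender(), srcIP)
	copy(a.ProtocolAddressTarget(), targetIP)
	return a
}
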
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
new file mode 100644
index 000000000..14a4b2b44
--- /dev/null
+++ b/pkg/tcpip/header/checksum.go
@@ -0,0 +1,249 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
+ v := initial
+
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
+ l := len(buf)
+ odd = l&1 != 0
+ if odd {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+
+ for i := 0; i < l; i += 2 {
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
+}
+
+func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
+ v := initial
+
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
+ l := len(buf)
+ odd = l&1 != 0
+ if odd {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+ for (l - 64) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[64:]
+ l = l - 64
+ }
+ if (l - 32) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[32:]
+ l = l - 32
+ }
+ if (l - 16) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[16:]
+ l = l - 16
+ }
+ if (l - 8) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ buf = buf[8:]
+ l = l - 8
+ }
+ if (l - 4) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ buf = buf[4:]
+ l = l - 4
+ }
+
+ // At this point, since l was even before we started unrolling, at most
+ // two bytes remain to be added.
+ if l != 0 {
+ v += (uint32(buf[0]) << 8) + uint32(buf[1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
+}
+
+// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given byte array. This function uses a non-optimized implementation. It
+// is only retained for reference and for use as a benchmark/test. Most code
+// should use the header.Checksum function.
+//
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumOld(buf []byte, initial uint16) uint16 {
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
+}
+
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array. This function uses an optimized unrolled version of the
+// checksum algorithm.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func Checksum(buf []byte, initial uint16) uint16 {
+ s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
+ return s
+}
+
+// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given VectorisedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
+ return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
+}
+
+// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
+// bytes in the given VectorisedView, starting at offset off and covering size
+// bytes.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
+ odd := false
+ sum := initial
+ for _, v := range vv.Views() {
+ if len(v) == 0 {
+ continue
+ }
+
+ if off >= len(v) {
+ off -= len(v)
+ continue
+ }
+ v = v[off:]
+
+ l := len(v)
+ if l > size {
+ l = size
+ }
+ v = v[:l]
+
+ sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
+
+ size -= len(v)
+ if size == 0 {
+ break
+ }
+ off = 0
+ }
+ return sum
+}
+
+// ChecksumCombine combines two uint16 values to form their checksum. This is
+// done by adding them together with any carry.
+//
+// Note that checksum a must have been computed on an even number of bytes.
+func ChecksumCombine(a, b uint16) uint16 {
+ v := uint32(a) + uint32(b)
+ return uint16(v + v>>16)
+}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
+// transport protocol and source/destination network addresses. Pseudo-headers
+// are needed by transport layers when calculating their own checksums.
+func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
+ xsum := Checksum([]byte(srcAddr), 0)
+ xsum = Checksum([]byte(dstAddr), xsum)
+
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ xsum = Checksum(tmp, xsum)
+
+ return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
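
To illustrate how these pieces compose (a sketch, not code from this change): a transport protocol folds the pseudo-header sum into the sum over its own header and payload, then transmits the one's complement. The transportChecksum helper is hypothetical and assumes the caller has zeroed the checksum field inside hdr:

package checksumsketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// transportChecksum combines the pseudo-header checksum with the sums
// over the transport header and its vectorised payload.
func transportChecksum(proto tcpip.TransportProtocolNumber, src, dst tcpip.Address, hdr []byte, payload buffer.VectorisedView) uint16 {
	xsum := header.PseudoHeaderChecksum(proto, src, dst, uint16(len(hdr)+payload.Size()))
	xsum = header.ChecksumVV(payload, xsum)
	return ^header.Checksum(hdr, xsum)
}
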
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
new file mode 100644
index 000000000..309403482
--- /dev/null
+++ b/pkg/tcpip/header/checksum_test.go
@@ -0,0 +1,171 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests and benchmarks for the checksum implementations in package header.
+package header_test
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestChecksumVVWithOffset(t *testing.T) {
+ testCases := []struct {
+ name string
+ vv buffer.VectorisedView
+ off, size int
+ initial uint16
+ want uint16
+ }{
+ {
+ name: "empty",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 0,
+ want: 0,
+ },
+ {
+ name: "OneView",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 5,
+ want: 1294,
+ },
+ {
+ name: "TwoViews",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 0,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "TwoViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 1,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 7,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithInitial",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
+ }),
+ initial: 77,
+ off: 7,
+ size: 11,
+ want: 33896,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
+ t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ v := tc.vv.ToView()
+ v.TrimFront(tc.off)
+ v.CapLength(tc.size)
+ if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
+ t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestChecksum(t *testing.T) {
+ var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
+ type testCase struct {
+ buf []byte
+ initial uint16
+ csumOrig uint16
+ csumNew uint16
+ }
+ testCases := make([]testCase, 100000)
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for i := range testCases {
+ testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
+ testCases[i].initial = uint16(rnd.Intn(65536))
+ rnd.Read(testCases[i].buf)
+ }
+
+ for i := range testCases {
+ testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
+ if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
+ t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
+ }
+ }
+}
+
+func BenchmarkChecksum(b *testing.B) {
+ var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
+
+ checkSumImpls := []struct {
+ fn func([]byte, uint16) uint16
+ name string
+ }{
+ {header.ChecksumOld, "checksum_old"},
+ {header.Checksum, "checksum"},
+ }
+
+ for _, csumImpl := range checkSumImpls {
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for _, bufSz := range bufSizes {
+ b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
+ tc := struct {
+ buf []byte
+ initial uint16
+ csum uint16
+ }{
+ buf: make([]byte, bufSz),
+ initial: uint16(rnd.Intn(65536)),
+ }
+ rnd.Read(tc.buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tc.csum = csumImpl.fn(tc.buf, tc.initial)
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
new file mode 100644
index 000000000..b1e92d2d7
--- /dev/null
+++ b/pkg/tcpip/header/eth.go
@@ -0,0 +1,177 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ dstMAC = 0
+ srcMAC = 6
+ ethType = 12
+)
+
+// EthernetFields contains the fields of an ethernet frame header. It is used to
+// describe the fields of a frame that needs to be encoded.
+type EthernetFields struct {
+ // SrcAddr is the "MAC source" field of an ethernet frame header.
+ SrcAddr tcpip.LinkAddress
+
+ // DstAddr is the "MAC destination" field of an ethernet frame header.
+ DstAddr tcpip.LinkAddress
+
+ // Type is the "ethertype" field of an ethernet frame header.
+ Type tcpip.NetworkProtocolNumber
+}
+
+// Ethernet represents an ethernet frame header stored in a byte array.
+type Ethernet []byte
+
+const (
+ // EthernetMinimumSize is the minimum size of a valid ethernet frame.
+ EthernetMinimumSize = 14
+
+ // EthernetAddressSize is the size, in bytes, of an ethernet address.
+ EthernetAddressSize = 6
+
+ // unspecifiedEthernetAddress is the unspecified ethernet address
+ // (all bits set to 0).
+ unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+
+ // unicastMulticastFlagMask is the mask of the least significant bit in
+ // the first octet (in network byte order) of an ethernet address that
+ // determines whether the ethernet address is unicast or multicast. If
+ // the masked bit is 1, the address is multicast; otherwise it is
+ // unicast.
+ //
+ // See the IEEE Std 802-2001 document for more details. Specifically,
+ // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
+ // "A 48-bit universal address consists of two parts. The first 24 bits
+ // correspond to the OUI as assigned by the IEEE, except that the
+ // assignee may set the LSB of the first octet to 1 for group addresses
+ // or set it to 0 for individual addresses."
+ unicastMulticastFlagMask = 1
+
+ // unicastMulticastFlagByteIdx is the byte that holds the
+ // unicast/multicast flag. See unicastMulticastFlagMask.
+ unicastMulticastFlagByteIdx = 0
+)
+
+const (
+ // EthernetProtocolAll is a catch-all for all protocols carried inside
+ // an ethernet frame. It is mainly used to create packet sockets that
+ // capture all traffic.
+ EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
+
+ // EthernetProtocolPUP is the PARC Universal Packet protocol ethertype.
+ EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
+)
+
+// Ethertypes holds the protocol numbers describing the payload of an ethernet
+// frame. These types aren't necessarily supported by netstack, but can be used
+// to catch all traffic of a type via packet endpoints.
+var Ethertypes = []tcpip.NetworkProtocolNumber{
+ EthernetProtocolAll,
+ EthernetProtocolPUP,
+}
+
+// SourceAddress returns the "MAC source" field of the ethernet frame header.
+func (b Ethernet) SourceAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
+}
+
+// DestinationAddress returns the "MAC destination" field of the ethernet frame
+// header.
+func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
+}
+
+// Type returns the "ethertype" field of the ethernet frame header.
+func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
+}
+
+// Encode encodes all the fields of the ethernet frame header.
+func (b Ethernet) Encode(e *EthernetFields) {
+ binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
+ copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
+ copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
+}
+
+// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// ethernet address.
+func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
+ // Must be of the right length.
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ // Must not be unspecified.
+ if addr == unspecifiedEthernetAddress {
+ return false
+ }
+
+ // Must not be a multicast.
+ if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
+ return false
+ }
+
+ // addr is a valid unicast ethernet address.
+ return true
+}
+
+// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
+// for a multicast IPv4 address.
+//
+// addr MUST be a multicast IPv4 address.
+func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
+ var linkAddrBytes [EthernetAddressSize]byte
+ // RFC 1112 Host Extensions for IP Multicasting
+ //
+ // 6.4. Extensions to an Ethernet Local Network Module:
+ //
+ // An IP host group address is mapped to an Ethernet multicast
+ // address by placing the low-order 23-bits of the IP address
+ // into the low-order 23 bits of the Ethernet multicast address
+ // 01-00-5E-00-00-00 (hex).
+ linkAddrBytes[0] = 0x1
+ linkAddrBytes[2] = 0x5e
+ linkAddrBytes[3] = addr[1] & 0x7F
+ copy(linkAddrBytes[4:], addr[IPv4AddressSize-2:])
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
+
+// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
+// for a multicast IPv6 address.
+//
+// addr MUST be a multicast IPv6 address.
+func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
+ // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
+ //
+ // 7. Address Mapping -- Multicast
+ //
+ // An IPv6 packet with a multicast destination address DST,
+ // consisting of the sixteen octets DST[1] through DST[16], is
+ // transmitted to the Ethernet multicast address whose first
+ // two octets are the value 3333 hexadecimal and whose last
+ // four octets are the last four octets of DST.
+ linkAddrBytes := []byte(addr[IPv6AddressSize-EthernetAddressSize:])
+ linkAddrBytes[0] = 0x33
+ linkAddrBytes[1] = 0x33
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
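
A brief usage sketch of the encoder above (the helper name is illustrative):

package ethsketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// encodeIPv4Frame writes an Ethernet header carrying the IPv4 ethertype.
func encodeIPv4Frame(src, dst tcpip.LinkAddress) header.Ethernet {
	b := header.Ethernet(make([]byte, header.EthernetMinimumSize))
	b.Encode(&header.EthernetFields{
		SrcAddr: src,
		DstAddr: dst,
		Type:    header.IPv4ProtocolNumber,
	})
	return b
}

For the RFC 1112 mapping implemented above, EthernetAddressFromMulticastIPv4Address("\xe0\x00\x00\xfb") (224.0.0.251, the mDNS group) yields 01:00:5e:00:00:fb.
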
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
new file mode 100644
index 000000000..14413f2ce
--- /dev/null
+++ b/pkg/tcpip/header/eth_test.go
@@ -0,0 +1,102 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestIsValidUnicastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Valid",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
+
+func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 Multicast without 24th bit set",
+ addr: "\xe0\x7e\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ {
+ name: "IPv4 Multicast with 24th bit set",
+ addr: "\xe0\xfe\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
+ t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr)
+ }
+ })
+ }
+}
+
+func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
+ addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
+ if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
+ t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/gue.go b/pkg/tcpip/header/gue.go
new file mode 100644
index 000000000..10d358c0e
--- /dev/null
+++ b/pkg/tcpip/header/gue.go
@@ -0,0 +1,73 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+const (
+ typeHLen = 0
+ encapProto = 1
+)
+
+// GUEFields contains the fields of a GUE packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type GUEFields struct {
+ // Type is the "type" field of the GUE header.
+ Type uint8
+
+ // Control is the "control" field of the GUE header.
+ Control bool
+
+ // HeaderLength is the "header length" field of the GUE header. It must
+ // be at least 4 octets, and a multiple of 4 as well.
+ HeaderLength uint8
+
+ // Protocol is the "protocol" field of the GUE header. This is one of
+ // the IPPROTO_* values.
+ Protocol uint8
+}
+
+// GUE represents a Generic UDP Encapsulation header stored in a byte array; the
+// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
+type GUE []byte
+
+const (
+ // GUEMinimumSize is the minimum size of a valid GUE packet.
+ GUEMinimumSize = 4
+)
+
+// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
+// which includes the control bit).
+func (b GUE) TypeAndControl() uint8 {
+ return b[typeHLen] >> 5
+}
+
+// HeaderLength returns the total length of the GUE header.
+func (b GUE) HeaderLength() uint8 {
+ return 4 + 4*(b[typeHLen]&0x1f)
+}
+
+// Protocol returns the protocol field of the GUE header.
+func (b GUE) Protocol() uint8 {
+ return b[encapProto]
+}
+
+// Encode encodes all the fields of the GUE header.
+func (b GUE) Encode(i *GUEFields) {
+ ctl := uint8(0)
+ if i.Control {
+ ctl = 1 << 5
+ }
+ b[typeHLen] = ctl | i.Type<<6 | (i.HeaderLength-4)/4
+ b[encapProto] = i.Protocol
+}
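
A round-trip sketch of the length encoding above: the header-length field stores (HeaderLength-4)/4, i.e. the number of 4-byte option words past the 4-byte base header. The values below are examples only:

package guesketch

import "gvisor.dev/gvisor/pkg/tcpip/header"

// roundTrip encodes an 8-byte control header and reads the fields back;
// it returns (1, 8, 17): control bit set with type 0, length 8, UDP.
func roundTrip() (typeAndControl, headerLength, protocol uint8) {
	b := header.GUE(make([]byte, 8))
	b.Encode(&header.GUEFields{
		Type:         0,
		Control:      true,
		HeaderLength: 8,  // base header plus one 4-byte option
		Protocol:     17, // IPPROTO_UDP, as an example
	})
	return b.TypeAndControl(), b.HeaderLength(), b.Protocol()
}
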
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
new file mode 100644
index 000000000..7908c5744
--- /dev/null
+++ b/pkg/tcpip/header/icmpv4.go
@@ -0,0 +1,170 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// ICMPv4 represents an ICMPv4 header stored in a byte array.
+type ICMPv4 []byte
+
+const (
+ // ICMPv4PayloadOffset defines the start of the ICMP payload.
+ ICMPv4PayloadOffset = 8
+
+ // ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv4MinimumSize = 8
+
+ // ICMPv4ProtocolNumber is the ICMP transport protocol number.
+ ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+
+ // icmpv4ChecksumOffset is the offset of the checksum field
+ // in an ICMPv4 message.
+ icmpv4ChecksumOffset = 2
+
+ // icmpv4MTUOffset is the offset of the MTU field
+ // in an ICMPv4FragmentationNeeded message.
+ icmpv4MTUOffset = 6
+
+ // icmpv4IdentOffset is the offset of the ident field
+ // in an ICMPv4EchoRequest/Reply message.
+ icmpv4IdentOffset = 4
+
+ // icmpv4SequenceOffset is the offset of the sequence field
+ // in an ICMPv4EchoRequest/Reply message.
+ icmpv4SequenceOffset = 6
+)
+
+// ICMPv4Type is the ICMP type field described in RFC 792.
+type ICMPv4Type byte
+
+// Typical values of ICMPv4Type defined in RFC 792.
+const (
+ ICMPv4EchoReply ICMPv4Type = 0
+ ICMPv4DstUnreachable ICMPv4Type = 3
+ ICMPv4SrcQuench ICMPv4Type = 4
+ ICMPv4Redirect ICMPv4Type = 5
+ ICMPv4Echo ICMPv4Type = 8
+ ICMPv4TimeExceeded ICMPv4Type = 11
+ ICMPv4ParamProblem ICMPv4Type = 12
+ ICMPv4Timestamp ICMPv4Type = 13
+ ICMPv4TimestampReply ICMPv4Type = 14
+ ICMPv4InfoRequest ICMPv4Type = 15
+ ICMPv4InfoReply ICMPv4Type = 16
+)
+
+// Values for ICMP code as defined in RFC 792.
+const (
+ ICMPv4TTLExceeded = 0
+ ICMPv4PortUnreachable = 3
+ ICMPv4FragmentationNeeded = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv4) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv4) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
+}
+
+// SetChecksum sets the ICMP checksum field.
+func (b ICMPv4) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv4) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv4) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv4) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv4) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv4) Payload() []byte {
+ return b[ICMPv4PayloadOffset:]
+}
+
+// MTU retrieves the MTU field from an ICMPv4 message.
+func (b ICMPv4) MTU() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4MTUOffset:])
+}
+
+// SetMTU sets the MTU field of an ICMPv4 message.
+func (b ICMPv4) SetMTU(mtu uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv4 message.
+func (b ICMPv4) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4IdentOffset:])
+}
+
+// SetIdent sets the Ident field of an ICMPv4 message.
+func (b ICMPv4) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv4 message.
+func (b ICMPv4) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field of an ICMPv4 message.
+func (b ICMPv4) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence)
+}
+
+// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header
+// and payload.
+func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
+ // Checksum the payload; ICMPv4, unlike ICMPv6, has no pseudo-header.
+ xsum := uint16(0)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself; set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
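
Putting the setters together (a sketch; buildEchoRequest and its arguments are illustrative):

package icmpv4sketch

import (
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// buildEchoRequest assembles an ICMPv4 echo request and stamps its
// checksum over the header and payload.
func buildEchoRequest(ident, seq uint16, payload buffer.VectorisedView) header.ICMPv4 {
	h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
	h.SetType(header.ICMPv4Echo)
	h.SetCode(0)
	h.SetIdent(ident)
	h.SetSequence(seq)
	h.SetChecksum(header.ICMPv4Checksum(h, payload))
	return h
}
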
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
new file mode 100644
index 000000000..c7ee2de57
--- /dev/null
+++ b/pkg/tcpip/header/icmpv6.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// ICMPv6 represents an ICMPv6 header stored in a byte array.
+type ICMPv6 []byte
+
+const (
+ // ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the
+ // sum of the size of the ICMPv6 Type, Code and Checksum fields, as
+ // per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6
+ // message body begins.
+ ICMPv6HeaderSize = 4
+
+ // ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv6MinimumSize = 8
+
+ // ICMPv6PayloadOffset is the offset of the payload in an
+ // ICMP packet.
+ ICMPv6PayloadOffset = 8
+
+ // ICMPv6ProtocolNumber is the ICMP transport protocol number.
+ ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
+
+ // ICMPv6NeighborSolicitMinimumSize is the minimum size of a
+ // neighbor solicitation packet.
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize
+
+ // ICMPv6NeighborAdvertMinimumSize is the minimum size of a
+ // neighbor advertisement packet.
+ ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
+
+ // ICMPv6NeighborAdvertSize is the size of a neighbor advertisement
+ // including the NDP Target Link Layer option for an Ethernet
+ // address.
+ ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
+
+ // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ ICMPv6EchoMinimumSize = 8
+
+ // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
+
+ // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
+ // packet-too-big packet.
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
+
+ // icmpv6ChecksumOffset is the offset of the checksum field
+ // in an ICMPv6 message.
+ icmpv6ChecksumOffset = 2
+
+ // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
+ // PacketTooBig message.
+ icmpv6MTUOffset = 4
+
+ // icmpv6IdentOffset is the offset of the ident field
+ // in an ICMPv6 Echo Request/Reply message.
+ icmpv6IdentOffset = 4
+
+ // icmpv6SequenceOffset is the offset of the sequence field
+ // in an ICMPv6 Echo Request/Reply message.
+ icmpv6SequenceOffset = 6
+
+ // NDPHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // its IP hop limit field.
+ NDPHopLimit = 255
+)
+
+// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+type ICMPv6Type byte
+
+// Typical values of ICMPv6Type defined in RFC 4443.
+const (
+ ICMPv6DstUnreachable ICMPv6Type = 1
+ ICMPv6PacketTooBig ICMPv6Type = 2
+ ICMPv6TimeExceeded ICMPv6Type = 3
+ ICMPv6ParamProblem ICMPv6Type = 4
+ ICMPv6EchoRequest ICMPv6Type = 128
+ ICMPv6EchoReply ICMPv6Type = 129
+
+ // Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
+
+ ICMPv6RouterSolicit ICMPv6Type = 133
+ ICMPv6RouterAdvert ICMPv6Type = 134
+ ICMPv6NeighborSolicit ICMPv6Type = 135
+ ICMPv6NeighborAdvert ICMPv6Type = 136
+ ICMPv6RedirectMsg ICMPv6Type = 137
+)
+
+// Values for ICMP code as defined in RFC 4443.
+const (
+ ICMPv6PortUnreachable = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv6) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv6) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv6) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
+}
+
+// SetChecksum sets the ICMP checksum field.
+func (b ICMPv6) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv6) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv6) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv6) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv6) SetDestinationPort(uint16) {
+}
+
+// MTU retrieves the MTU field from an ICMPv6 message.
+func (b ICMPv6) MTU() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
+}
+
+// SetMTU sets the MTU field of an ICMPv6 message.
+func (b ICMPv6) SetMTU(mtu uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv6 message.
+func (b ICMPv6) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
+}
+
+// SetIdent sets the Ident field of an ICMPv6 message.
+func (b ICMPv6) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv6 message.
+func (b ICMPv6) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field of an ICMPv6 message.
+func (b ICMPv6) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
+}
+
+// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
+// packet's message body as defined by RFC 4443 section 2.1; the portion of the
+// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) NDPPayload() []byte {
+ return b[ICMPv6HeaderSize:]
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv6) Payload() []byte {
+ return b[ICMPv6PayloadOffset:]
+}
+
+// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
+// IPv6 src/dst addresses and the payload.
+func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := Checksum([]byte(src), 0)
+ xsum = Checksum([]byte(dst), xsum)
+ var upperLayerLength [4]byte
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
+ xsum = Checksum(upperLayerLength[:], xsum)
+ xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself; set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
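
The v6 variant differs from ICMPv4 in that the checksum covers an IPv6 pseudo-header, so the source and destination addresses must be the ones the packet will actually carry. A sketch (stampEchoRequest is illustrative):

package icmpv6sketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// stampEchoRequest fills in an echo request; ICMPv6Checksum zeroes the
// checksum field internally while summing, so no manual reset is needed.
func stampEchoRequest(h header.ICMPv6, src, dst tcpip.Address, payload buffer.VectorisedView) {
	h.SetType(header.ICMPv6EchoRequest)
	h.SetCode(0)
	h.SetIdent(1)
	h.SetSequence(1)
	h.SetChecksum(header.ICMPv6Checksum(h, src, dst, payload))
}
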
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
new file mode 100644
index 000000000..861cbbb70
--- /dev/null
+++ b/pkg/tcpip/header/interfaces.go
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // MaxIPPacketSize is the maximum supported IP packet size, excluding
+ // jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
+ // in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
+ // 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
+ // m is the minimum IPv6 header size; we leave room for some potential
+ // IP options.
+ MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
+)
+
+// Transport offers generic methods to query and/or update the fields of the
+// header of a transport protocol buffer.
+type Transport interface {
+ // SourcePort returns the value of the "source port" field.
+ SourcePort() uint16
+
+ // DestinationPort returns the value of the "destination port" field.
+ DestinationPort() uint16
+
+ // Checksum returns the value of the "checksum" field.
+ Checksum() uint16
+
+ // SetSourcePort sets the value of the "source port" field.
+ SetSourcePort(uint16)
+
+ // SetDestinationPort sets the value of the "destination port" field.
+ SetDestinationPort(uint16)
+
+ // SetChecksum sets the value of the "checksum" field.
+ SetChecksum(uint16)
+
+ // Payload returns the data carried in the transport buffer.
+ Payload() []byte
+}
+
+// Network offers generic methods to query and/or update the fields of the
+// header of a network protocol buffer.
+type Network interface {
+ // SourceAddress returns the value of the "source address" field.
+ SourceAddress() tcpip.Address
+
+ // DestinationAddress returns the value of the "destination address"
+ // field.
+ DestinationAddress() tcpip.Address
+
+ // Checksum returns the value of the "checksum" field.
+ Checksum() uint16
+
+ // SetSourceAddress sets the value of the "source address" field.
+ SetSourceAddress(tcpip.Address)
+
+ // SetDestinationAddress sets the value of the "destination address"
+ // field.
+ SetDestinationAddress(tcpip.Address)
+
+ // SetChecksum sets the value of the "checksum" field.
+ SetChecksum(uint16)
+
+ // TransportProtocol returns the number of the transport protocol
+ // stored in the payload.
+ TransportProtocol() tcpip.TransportProtocolNumber
+
+ // Payload returns a byte slice containing the payload of the network
+ // packet.
+ Payload() []byte
+
+ // TOS returns the values of the "type of service" and "flow label" fields.
+ TOS() (uint8, uint32)
+
+ // SetTOS sets the values of the "type of service" and "flow label" fields.
+ SetTOS(t uint8, l uint32)
+}
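
Several concrete types added in this change satisfy these interfaces. A compile-time sketch of the pairings established by the files above:

package interfacesketch

import "gvisor.dev/gvisor/pkg/tcpip/header"

// Interface-satisfaction assertions; they compile to nothing.
var (
	_ header.Transport = header.ICMPv4(nil)
	_ header.Transport = header.ICMPv6(nil)
	_ header.Network   = header.IPv4(nil)
)
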
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
new file mode 100644
index 000000000..62ac932bb
--- /dev/null
+++ b/pkg/tcpip/header/ipv4.go
@@ -0,0 +1,312 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ versIHL = 0
+ tos = 1
+ // IPv4TotalLenOffset is the offset of the total length field in the
+ // IPv4 header.
+ IPv4TotalLenOffset = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
+)
+
+// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv4Fields struct {
+ // IHL is the "internet header length" field of an IPv4 packet. The value
+ // is in bytes.
+ IHL uint8
+
+ // TOS is the "type of service" field of an IPv4 packet.
+ TOS uint8
+
+ // TotalLength is the "total length" field of an IPv4 packet.
+ TotalLength uint16
+
+ // ID is the "identification" field of an IPv4 packet.
+ ID uint16
+
+ // Flags is the "flags" field of an IPv4 packet.
+ Flags uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv4 packet.
+ FragmentOffset uint16
+
+ // TTL is the "time to live" field of an IPv4 packet.
+ TTL uint8
+
+ // Protocol is the "protocol" field of an IPv4 packet.
+ Protocol uint8
+
+ // Checksum is the "checksum" field of an IPv4 packet.
+ Checksum uint16
+
+ // SrcAddr is the "source ip address" of an IPv4 packet.
+ SrcAddr tcpip.Address
+
+ // DstAddr is the "destination ip address" of an IPv4 packet.
+ DstAddr tcpip.Address
+}
+
+// IPv4 represents an ipv4 header stored in a byte array.
+// Most of the methods of IPv4 access the underlying slice without checking
+// the bounds and could panic with an 'index out of range' error.
+// Always call IsValid() to validate an instance of IPv4 before using other methods.
+type IPv4 []byte
+
+const (
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ IPv4MinimumSize = 20
+
+ // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+ // that there are only 4 bits to represent the header length in 32-bit
+ // units, the header cannot exceed 15*4 = 60 bytes.
+ IPv4MaximumHeaderSize = 60
+
+ // MinIPFragmentPayloadSize is the minimum number of payload bytes that
+ // the first fragment must carry when an IPv4 packet is fragmented.
+ MinIPFragmentPayloadSize = 8
+
+ // IPv4AddressSize is the size, in bytes, of an IPv4 address.
+ IPv4AddressSize = 4
+
+ // IPv4ProtocolNumber is IPv4's network protocol number.
+ IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
+
+ // IPv4Version is the version of the ipv4 protocol.
+ IPv4Version = 4
+
+ // IPv4Broadcast is the broadcast address of the IPv4 protocol.
+ IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
+
+ // IPv4Any is the non-routable IPv4 "any" meta address.
+ IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+
+ // IPv4MinimumProcessableDatagramSize is the minimum size of an IP
+ // packet that every IPv4 capable host must be able to
+ // process/reassemble.
+ IPv4MinimumProcessableDatagramSize = 576
+)
+
+// Flags that may be set in an IPv4 packet.
+const (
+ IPv4FlagMoreFragments = 1 << iota
+ IPv4FlagDontFragment
+)
+
+// IPv4EmptySubnet is the empty IPv4 subnet.
+var IPv4EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
+// IPVersion returns the version of IP used in the given packet. It returns -1
+// if the packet is not large enough to contain the version field.
+func IPVersion(b []byte) int {
+ // Length must be at least offset+length of version field.
+ if len(b) < versIHL+1 {
+ return -1
+ }
+ return int(b[versIHL] >> 4)
+}
+
+// HeaderLength returns the value of the "header length" field of the ipv4
+// header. The length returned is in bytes.
+func (b IPv4) HeaderLength() uint8 {
+ return (b[versIHL] & 0xf) * 4
+}
+
+// ID returns the value of the identifier field of the ipv4 header.
+func (b IPv4) ID() uint16 {
+ return binary.BigEndian.Uint16(b[id:])
+}
+
+// Protocol returns the value of the protocol field of the ipv4 header.
+func (b IPv4) Protocol() uint8 {
+ return b[protocol]
+}
+
+// Flags returns the "flags" field of the ipv4 header.
+func (b IPv4) Flags() uint8 {
+ return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
+}
+
+// More returns whether the more fragments flag is set.
+func (b IPv4) More() bool {
+ return b.Flags()&IPv4FlagMoreFragments != 0
+}
+
+// TTL returns the "TTL" field of the ipv4 header.
+func (b IPv4) TTL() uint8 {
+ return b[ttl]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+func (b IPv4) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[flagsFO:]) << 3
+}
+
+// TotalLength returns the "total length" field of the ipv4 header.
+func (b IPv4) TotalLength() uint16 {
+ return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
+}
+
+// Checksum returns the checksum field of the ipv4 header.
+func (b IPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[checksum:])
+}
+
+// SourceAddress returns the "source address" field of the ipv4 header.
+func (b IPv4) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv4
+// header.
+func (b IPv4) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.Protocol())
+}
+
+// Payload implements Network.Payload.
+func (b IPv4) Payload() []byte {
+ return b[b.HeaderLength():][:b.PayloadLength()]
+}
+
+// PayloadLength returns the length of the payload portion of the ipv4 packet.
+func (b IPv4) PayloadLength() uint16 {
+ return b.TotalLength() - uint16(b.HeaderLength())
+}
+
+// TOS returns the "type of service" field of the ipv4 header.
+func (b IPv4) TOS() (uint8, uint32) {
+ return b[tos], 0
+}
+
+// SetTOS sets the "type of service" field of the ipv4 header.
+func (b IPv4) SetTOS(v uint8, _ uint32) {
+ b[tos] = v
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+ binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksum:], v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+ v := (uint16(flags) << 13) | (offset >> 3)
+ binary.BigEndian.PutUint16(b[flagsFO:], v)
+}
+
+// SetID sets the identification field.
+func (b IPv4) SetID(v uint16) {
+ binary.BigEndian.PutUint16(b[id:], v)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr tcpip.Address) {
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b IPv4) CalculateChecksum() uint16 {
+ return Checksum(b[:b.HeaderLength()], 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+ b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b[tos] = i.TOS
+ b.SetTotalLength(i.TotalLength)
+ binary.BigEndian.PutUint16(b[id:], i.ID)
+ b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b[ttl] = i.TTL
+ b[protocol] = i.Protocol
+ b.SetChecksum(i.Checksum)
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+}
+
+// EncodePartial updates the total length and checksum fields of the IPv4
+// header, taking in the partial checksum, which is the checksum of the header
+// without the total length and checksum fields. It is useful when many
+// similar packets are produced.
+func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
+ b.SetTotalLength(totalLength)
+ checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum)
+ b.SetChecksum(^checksum)
+}
+
+// IsValid performs basic validation on the packet.
+func (b IPv4) IsValid(pktSize int) bool {
+ if len(b) < IPv4MinimumSize {
+ return false
+ }
+
+ hlen := int(b.HeaderLength())
+ tlen := int(b.TotalLength())
+ if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize {
+ return false
+ }
+
+ if IPVersion(b) != IPv4Version {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
+// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
+// will be 1110 = 0xe0.
+func IsV4MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return (addr[0] & 0xf0) == 0xe0
+}
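
The EncodePartial fast path above deserves a worked sketch: zero the two fields it will fill so they add nothing to the one's-complement sum, checksum the header once, then stamp any number of near-identical packets. The stampSimilarPackets helper is illustrative:

package ipv4sketch

import "gvisor.dev/gvisor/pkg/tcpip/header"

// stampSimilarPackets re-stamps total length and checksum for a series
// of packets that differ only in length, summing the header just once.
func stampSimilarPackets(b header.IPv4, lengths []uint16) {
	b.SetTotalLength(0)
	b.SetChecksum(0)
	partial := header.Checksum(b[:b.HeaderLength()], 0)
	for _, l := range lengths {
		b.EncodePartial(partial, l)
	}
}
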
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
new file mode 100644
index 000000000..4f367fe4c
--- /dev/null
+++ b/pkg/tcpip/header/ipv6.go
@@ -0,0 +1,499 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "crypto/sha256"
+ "encoding/binary"
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ versTCFL = 0
+ // IPv6PayloadLenOffset is the offset of the PayloadLength field in
+ // IPv6 header.
+ IPv6PayloadLenOffset = 4
+ // IPv6NextHeaderOffset is the offset of the NextHeader field in
+ // IPv6 header.
+ IPv6NextHeaderOffset = 6
+ hopLimit = 7
+ v6SrcAddr = 8
+ v6DstAddr = v6SrcAddr + IPv6AddressSize
+)
+
+// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6Fields struct {
+ // TrafficClass is the "traffic class" field of an IPv6 packet.
+ TrafficClass uint8
+
+ // FlowLabel is the "flow label" field of an IPv6 packet.
+ FlowLabel uint32
+
+ // PayloadLength is the "payload length" field of an IPv6 packet.
+ PayloadLength uint16
+
+ // NextHeader is the "next header" field of an IPv6 packet.
+ NextHeader uint8
+
+ // HopLimit is the "hop limit" field of an IPv6 packet.
+ HopLimit uint8
+
+ // SrcAddr is the "source ip address" of an IPv6 packet.
+ SrcAddr tcpip.Address
+
+ // DstAddr is the "destination ip address" of an IPv6 packet.
+ DstAddr tcpip.Address
+}
+
+// IPv6 represents an ipv6 header stored in a byte array.
+// Most of the methods of IPv6 access the underlying slice without checking
+// the bounds and could panic with an 'index out of range' error.
+// Always call IsValid() to validate an instance of IPv6 before using other methods.
+type IPv6 []byte
+
+const (
+ // IPv6MinimumSize is the minimum size of a valid IPv6 packet.
+ IPv6MinimumSize = 40
+
+ // IPv6AddressSize is the size, in bytes, of an IPv6 address.
+ IPv6AddressSize = 16
+
+ // IPv6ProtocolNumber is IPv6's network protocol number.
+ IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
+
+ // IPv6Version is the version of the ipv6 protocol.
+ IPv6Version = 6
+
+ // IPv6AllNodesMulticastAddress is a link-local multicast group that
+ // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all nodes on a link.
+ //
+ // The address is ff02::1.
+ IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
+ // IPv6AllRoutersMulticastAddress is a link-local multicast group that
+ // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers on a link.
+ //
+ // The address is ff02::2.
+ IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
+ // section 5.
+ IPv6MinimumMTU = 1280
+
+ // IPv6Any is the non-routable IPv6 "any" meta address. It is also
+ // known as the unspecified address.
+ IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+
+ // IIDSize is the size of an interface identifier (IID), in bytes, as
+ // defined by RFC 4291 section 2.5.1.
+ IIDSize = 8
+
+ // IIDOffsetInIPv6Address is the offset, in bytes, from the start
+ // of an IPv6 address to the beginning of the interface identifier
+ // (IID) for auto-generated addresses. That is, all bytes before
+ // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
+ // bytes including and after the IIDOffsetInIPv6Address-th byte are
+ // for the IID.
+ IIDOffsetInIPv6Address = 8
+
+ // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
+ // for the secret key used to generate an opaque interface identifier as
+ // outlined by RFC 7217.
+ OpaqueIIDSecretKeyMinBytes = 16
+
+ // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field
+ // is located within a multicast IPv6 address, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeByteIdx = 1
+
+ // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
+ // within the byte holding the field, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeMask = 0xF
+
+ // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
+ // a multicast IPv6 address that indicates the address has link-local scope,
+ // as per RFC 4291 section 2.7.
+ ipv6LinkLocalMulticastScope = 2
+)
+
+// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
+// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
+// be contained within this subnet.
+var IPv6EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
+// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
+// by RFC 4291 section 2.5.6.
+//
+// The prefix is fe80::/64
+var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
+ Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 64,
+}
+
+// PayloadLength returns the value of the "payload length" field of the ipv6
+// header.
+func (b IPv6) PayloadLength() uint16 {
+ return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
+}
+
+// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+func (b IPv6) HopLimit() uint8 {
+ return b[hopLimit]
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 header.
+func (b IPv6) NextHeader() uint8 {
+ return b[IPv6NextHeaderOffset]
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// Payload implements Network.Payload.
+func (b IPv6) Payload() []byte {
+ return b[IPv6MinimumSize:][:b.PayloadLength()]
+}
+
+// SourceAddress returns the "source address" field of the ipv6 header.
+func (b IPv6) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv6
+// header.
+func (b IPv6) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
+}
+
+// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
+// checksum, it just returns 0.
+func (IPv6) Checksum() uint16 {
+ return 0
+}
+
+// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) TOS() (uint8, uint32) {
+ v := binary.BigEndian.Uint32(b[versTCFL:])
+ return uint8(v >> 20), v & 0xfffff
+}
+
+// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) SetTOS(t uint8, l uint32) {
+ vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
+ binary.BigEndian.PutUint32(b[versTCFL:], vtf)
+}
+
+// SetPayloadLength sets the "payload length" field of the ipv6 header.
+func (b IPv6) SetPayloadLength(payloadLength uint16) {
+ binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv6 header.
+func (b IPv6) SetSourceAddress(addr tcpip.Address) {
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv6
+// header.
+func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
+}
+
+// SetNextHeader sets the value of the "next header" field of the ipv6 header.
+func (b IPv6) SetNextHeader(v uint8) {
+ b[IPv6NextHeaderOffset] = v
+}
+
+// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
+// checksum, it is empty.
+func (IPv6) SetChecksum(uint16) {
+}
+
+// Encode encodes all the fields of the ipv6 header.
+func (b IPv6) Encode(i *IPv6Fields) {
+ b.SetTOS(i.TrafficClass, i.FlowLabel)
+ b.SetPayloadLength(i.PayloadLength)
+ b[IPv6NextHeaderOffset] = i.NextHeader
+ b[hopLimit] = i.HopLimit
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
+}
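+
+// A minimal sketch of encoding and then validating a header (illustrative
+// only; the field values below are placeholders, not prescribed usage):
+//
+//	b := make(IPv6, IPv6MinimumSize)
+//	b.Encode(&IPv6Fields{
+//		PayloadLength: 0,
+//		NextHeader:    uint8(IPv6NoNextHeaderIdentifier),
+//		HopLimit:      64,
+//		SrcAddr:       IPv6Any,
+//		DstAddr:       IPv6AllNodesMulticastAddress,
+//	})
+//	if !b.IsValid(len(b)) {
+//		// Handle the malformed header.
+//	}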
+
+// IsValid performs basic validation on the packet.
+func (b IPv6) IsValid(pktSize int) bool {
+ if len(b) < IPv6MinimumSize {
+ return false
+ }
+
+ dlen := int(b.PayloadLength())
+ if dlen > pktSize-IPv6MinimumSize {
+ return false
+ }
+
+ if IPVersion(b) != IPv6Version {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MappedAddress determines if the provided address is an IPv4-mapped
+// address by checking if its prefix is ::ffff:0:0/96.
+func IsV4MappedAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+}
+
+// IsV6MulticastAddress determines if the provided address is an IPv6
+// multicast address (anything starting with FF).
+func IsV6MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xff
+}
+
+// IsV6UnicastAddress determines if the provided address is a valid IPv6
+// unicast (and specified) address. That is, IsV6UnicastAddress returns
+// true if addr contains IPv6AddressSize bytes, is not the unspecified
+// address and is not a multicast address.
+func IsV6UnicastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ // Must not be unspecified
+ if addr == IPv6Any {
+ return false
+ }
+
+ // Must not be a multicast address.
+ return addr[0] != 0xff
+}
+
+// SolicitedNodeAddr computes the solicited-node multicast address. This is
+// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
+// address.
+func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
+ const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+ return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
+}
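+
+// For illustration: an address whose last three bytes are aa:bb:cc yields
+// the solicited-node address ff02::1:ffaa:bbcc, i.e. the fixed 13-byte
+// prefix above followed by the last three bytes of addr.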
+
+// EthernetAddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
+// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
+//
+// buf MUST be at least 8 bytes.
+func EthernetAddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
+ buf[0] = linkAddr[0] ^ 2
+ buf[1] = linkAddr[1]
+ buf[2] = linkAddr[2]
+ buf[3] = 0xFF
+ buf[4] = 0xFE
+ buf[5] = linkAddr[3]
+ buf[6] = linkAddr[4]
+ buf[7] = linkAddr[5]
+}
+
+// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit
+// Ethernet/MAC address, as per RFC 4291 section 2.5.1.
+func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
+ var buf [IIDSize]byte
+ EthernetAddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
+ return buf
+}
+
+// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
+// (MAC) address.
+func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
+ // Convert a 48-bit MAC to a modified EUI-64 and then prepend the
+ // link-local header, FE80::.
+ //
+ // The conversion is very nearly:
+ // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
+ // Note the capital A. The conversion aa->Aa involves a bit flip.
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+ EthernetAddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:])
+ return tcpip.Address(lladdrb[:])
+}
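+
+// A worked example of the conversion above (illustrative): for the MAC
+// address 00:11:22:33:44:55, flipping the universal/local bit of the first
+// byte (00 -> 02) and inserting ff:fe in the middle yields the IID
+// 02:11:22:ff:fe:33:44:55, so the link-local address is
+// fe80::211:22ff:fe33:4455.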
+
+// IsV6LinkLocalAddress determines if the provided address is an IPv6
+// link-local address (fe80::/10).
+func IsV6LinkLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
+}
+
+// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
+// link-local multicast address.
+func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+}
+
+// IsV6UniqueLocalAddress determines if the provided address is an IPv6
+// unique-local address (within the prefix FC00::/7).
+func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ // According to RFC 4193 section 3.1, a unique local address has the prefix
+ // FC00::/7.
+ return (addr[0] & 0xfe) == 0xfc
+}
+
+// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
+// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
+//
+// The opaque IID is generated from the cryptographic hash of the concatenation
+// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret
+// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
+// MUST be initialized to a pseudo-random number. See RFC 4086 for randomness
+// requirements for security.
+//
+// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
+// array for the buffer will not be allocated.
+func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
+ // As per RFC 7217 section 5, the opaque identifier can be generated as a
+ // cryptographic hash of the concatenation of each of the function parameters.
+ // Note, we omit the optional Network_ID field.
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address]))
+ h.Write([]byte(nicName))
+ h.Write([]byte{dadCounter})
+ h.Write(secretKey)
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ return append(buf, sum[:IIDSize]...)
+}
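+
+// A minimal usage sketch (illustrative; the NIC name below is a placeholder
+// and the key must come from a cryptographically secure source in real use):
+//
+//	var secretKey [OpaqueIIDSecretKeyMinBytes]byte
+//	// Fill secretKey with random bytes before use.
+//	iid := AppendOpaqueInterfaceIdentifier(nil, IPv6LinkLocalPrefix.Subnet(), "eth0", 0, secretKey[:])
+//	_ = iid // IIDSize (8) opaque IID bytes.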
+
+// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
+// an opaque IID.
+func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))
+}
+
+// IPv6AddressScope is the scope of an IPv6 address.
+type IPv6AddressScope int
+
+const (
+ // LinkLocalScope indicates a link-local address.
+ LinkLocalScope IPv6AddressScope = iota
+
+ // UniqueLocalScope indicates a unique-local address.
+ UniqueLocalScope
+
+ // GlobalScope indicates a global address.
+ GlobalScope
+)
+
+// ScopeForIPv6Address returns the scope for an IPv6 address.
+func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
+ if len(addr) != IPv6AddressSize {
+ return GlobalScope, tcpip.ErrBadAddress
+ }
+
+ switch {
+ case IsV6LinkLocalMulticastAddress(addr):
+ return LinkLocalScope, nil
+
+ case IsV6LinkLocalAddress(addr):
+ return LinkLocalScope, nil
+
+ case IsV6UniqueLocalAddress(addr):
+ return UniqueLocalScope, nil
+
+ default:
+ return GlobalScope, nil
+ }
+}
+
+// InitialTempIID generates the initial temporary IID history value to generate
+// temporary SLAAC addresses with.
+//
+// Panics if initialTempIIDHistory is not at least IIDSize bytes.
+func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) {
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write(seed)
+ var nicIDBuf [4]byte
+ binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID))
+ h.Write(nicIDBuf[:])
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
+ }
+}
+
+// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an
+// associated stable/permanent SLAAC address.
+//
+// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be
+// used when generating a new temporary IID.
+//
+// Panics if tempIIDHistory is not at least IIDSize bytes.
+func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ addrBytes := []byte(stableAddr)
+ h := sha256.New()
+ h.Write(tempIIDHistory)
+ h.Write(addrBytes[IIDOffsetInIPv6Address:])
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ // The rightmost 64 bits of sum are saved for the next iteration.
+ if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
+ }
+
+ // The leftmost 64 bits of sum are used as the IID.
+ if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize {
+ panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: IIDOffsetInIPv6Address * 8,
+ }
+}
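+
+// A sketch of the intended flow (illustrative; seed, nicID and stableAddr
+// are assumed to be provided by the caller):
+//
+//	var history [IIDSize]byte
+//	InitialTempIID(history[:], seed, nicID)
+//	tempAddr := GenerateTempIPv6SLAACAddr(history[:], stableAddr)
+//	// Subsequent calls with the same history slice yield new temporary
+//	// addresses for the same stable address.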
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
new file mode 100644
index 000000000..3499d8399
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -0,0 +1,551 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier.
+type IPv6ExtensionHeaderIdentifier uint8
+
+const (
+ // IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by
+ // Hop Options extension header, as per RFC 8200 section 4.3.
+ IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0
+
+ // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension
+ // header, as per RFC 8200 section 4.4.
+ IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43
+
+ // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment
+ // extension header, as per RFC 8200 section 4.5.
+ IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44
+
+ // IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a
+ // Destination Options extension header, as per RFC 8200 section 4.6.
+ IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60
+
+ // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
+ // of an IPv6 payload, as per RFC 8200 section 4.7.
+ IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+)
+
+const (
+ // ipv6UnknownExtHdrOptionActionMask is the mask of the bits within an
+ // option identifier that encode the action to take when a node encounters
+ // an unrecognized option.
+ ipv6UnknownExtHdrOptionActionMask = 192
+
+ // ipv6UnknownExtHdrOptionActionShift is the number of least significant
+ // bits to discard from an option identifier to obtain the action value.
+ ipv6UnknownExtHdrOptionActionShift = 6
+
+ // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field
+ // within an IPv6RoutingExtHdr.
+ ipv6RoutingExtHdrSegmentsLeftIdx = 1
+
+ // IPv6FragmentExtHdrLength is the length of an IPv6 Fragment extension
+ // header, in bytes.
+ IPv6FragmentExtHdrLength = 8
+
+ // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
+ // Fragment Offset field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFragmentOffsetOffset = 0
+
+ // ipv6FragmentExtHdrFragmentOffsetShift is the number of least significant
+ // bits to discard to obtain the Fragment Offset value.
+ ipv6FragmentExtHdrFragmentOffsetShift = 3
+
+ // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
+ // IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFlagsIdx = 1
+
+ // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
+ // flags field of an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrMFlagMask = 1
+
+ // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification
+ // field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrIdentificationOffset = 2
+
+ // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length
+ // field. That is, given a Length field of 2, the extension header expects
+ // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for
+ // details about the first 8 bytes' exclusion from the Length field).
+ ipv6ExtHdrLenBytesPerUnit = 8
+
+ // ipv6ExtHdrLenBytesExcluded is the number of bytes following the Length
+ // field that are excluded from the extension header's Length field count.
+ //
+ // The Length field excludes the first 8 bytes, but the Next Header and Length
+ // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes
+ // after the Length field.
+ //
+ // This ensures that every extension header is at least 8 bytes.
+ ipv6ExtHdrLenBytesExcluded = 6
+
+ // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment
+ // extension header's Fragment Offset field. That is, given a Fragment Offset
+ // of 2, the extension header is indicating that the fragment's payload
+ // starts at the 16th byte in the reassembled packet.
+ IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
+)
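+
+// As a worked example of the length arithmetic above (illustrative): an
+// extension header with a Length field of 2 carries
+// 2*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded = 22 bytes after
+// its Length field, for a total size of 24 bytes including the Next Header
+// and Length bytes.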
+
+// IPv6PayloadHeader is implemented by the various headers that can be found
+// in an IPv6 payload.
+//
+// These headers include IPv6 extension headers or upper layer data.
+type IPv6PayloadHeader interface {
+ isIPv6PayloadHeader()
+}
+
+// IPv6RawPayloadHeader is the remainder of an IPv6 payload after an iterator
+// encounters a Next Header field it does not recognize as an IPv6 extension
+// header.
+type IPv6RawPayloadHeader struct {
+ Identifier IPv6ExtensionHeaderIdentifier
+ Buf buffer.VectorisedView
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
+
+// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
+type ipv6OptionsExtHdr []byte
+
+// Iter returns an iterator over the IPv6 extension header options held in b.
+func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
+ it := IPv6OptionsExtHdrOptionsIterator{}
+ it.reader.Reset(b)
+ return it
+}
+
+// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
+// options.
+//
+// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
+// used, no changes to the underlying buffer may happen. Doing so may cause
+// undefined and unexpected behaviour. It is fine to obtain an
+// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
+// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
+// obtained before modification is no longer used.
+type IPv6OptionsExtHdrOptionsIterator struct {
+ reader bytes.Reader
+}
+
+// IPv6OptionUnknownAction is the action that must be taken if the processing
+// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
+type IPv6OptionUnknownAction int
+
+const (
+ // IPv6OptionUnknownActionSkip indicates that the unrecognized option must
+ // be skipped and the node should continue processing the header.
+ IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0
+
+ // IPv6OptionUnknownActionDiscard indicates that the packet must be silently
+ // discarded.
+ IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1
+
+ // IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
+ // discarded and the node must send an ICMP Parameter Problem, Code 2, message
+ // to the packet's source, regardless of whether or not the packet's
+ // Destination was a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2
+
+ // IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
+ // packet must be discarded and the node must send an ICMP Parameter Problem,
+ // Code 2, message to the packet's source only if the packet's Destination was
+ // not a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3
+)
+
+// IPv6ExtHdrOption is implemented by the various IPv6 extension header options.
+type IPv6ExtHdrOption interface {
+ // UnknownAction returns the action to take in response to an unrecognized
+ // option.
+ UnknownAction() IPv6OptionUnknownAction
+
+ // isIPv6ExtHdrOption is used to "lock" this interface so it is not
+ // implemented by other packages.
+ isIPv6ExtHdrOption()
+}
+
+// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIdentifier uint8
+
+const (
+ // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides 1 byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
+
+ // ipv6PadNExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides variable length byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
+)
+
+// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
+// header option that is unknown by the parsing utilities.
+type IPv6UnknownExtHdrOption struct {
+ Identifier IPv6ExtHdrOptionIdentifier
+ Data []byte
+}
+
+// UnknownAction implements IPv6ExtHdrOption.UnknownAction.
+func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
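+
+// For illustration: an option identifier of 191 (0b10111111) has its two
+// most significant bits set to 0b10, so UnknownAction returns
+// IPv6OptionUnknownActionDiscardSendICMP.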
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
+func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
+
+// Next returns the next option in the options data.
+//
+// If the next item is not a known extension header option,
+// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
+//
+// The return is of the format (option, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the options data, or an error occurred.
+func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
+ for {
+ temp, err := i.reader.ReadByte()
+ if err != nil {
+ // If we can't read the first byte of a new option, then we know the
+ // options buffer has been exhausted and we are done iterating.
+ return nil, true, nil
+ }
+ id := IPv6ExtHdrOptionIdentifier(temp)
+
+ // If the option identifier indicates the option is a Pad1 option, then we
+ // know the option does not have Length and Data fields. End processing of
+ // the Pad1 option and continue processing the buffer as a new option.
+ if id == ipv6Pad1ExtHdrOptionIdentifier {
+ continue
+ }
+
+ length, err := i.reader.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the reader to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
+ }
+
+ // Special-case the variable length padding option to avoid a copy.
+ if id == ipv6PadNExtHdrOptionIdentifier {
+ // Do we have enough bytes in the reader for the PadN option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
+ panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
+ }
+
+ // End processing of the PadN option and continue processing the buffer as
+ // a new option.
+ continue
+ }
+
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ }
+
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
+ }
+}
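+
+// A minimal iteration sketch following the (option, done, error) contract
+// above (illustrative; extHdr is assumed to be a Hop By Hop or Destination
+// Options extension header):
+//
+//	it := extHdr.Iter()
+//	for {
+//		opt, done, err := it.Next()
+//		if err != nil || done {
+//			break
+//		}
+//		_ = opt // Process the option.
+//	}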
+
+// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
+// extension header.
+type IPv6HopByHopOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
+// extension header.
+type IPv6DestinationOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
+// data as outlined in RFC 8200 section 4.4.
+type IPv6RoutingExtHdr []byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
+
+// SegmentsLeft returns the Segments Left field.
+func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
+ return b[ipv6RoutingExtHdrSegmentsLeftIdx]
+}
+
+// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
+// data as outlined in RFC 8200 section 4.5.
+//
+// Note, the buffer does not include the Next Header and Reserved fields.
+type IPv6FragmentExtHdr [6]byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}
+
+// FragmentOffset returns the Fragment Offset field.
+//
+// This value indicates where the buffer following the Fragment extension header
+// starts in the target (reassembled) packet.
+func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift
+}
+
+// More returns the More (M) flag.
+//
+// This indicates whether any fragments are expected to follow b.
+func (b IPv6FragmentExtHdr) More() bool {
+ return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0
+}
+
+// ID returns the Identification field.
+//
+// This value is used to uniquely identify the packet between a
+// source and destination.
+func (b IPv6FragmentExtHdr) ID() uint32 {
+ return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
+}
+
+// IsAtomic returns whether the fragment header indicates an atomic fragment. An
+// atomic fragment is a fragment that contains all the data required to
+// reassemble a full packet.
+func (b IPv6FragmentExtHdr) IsAtomic() bool {
+ return !b.More() && b.FragmentOffset() == 0
+}
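+
+// A worked decoding example (illustrative, using the same bytes as this
+// change's tests): the fragment data {68, 9, 128, 4, 2, 1} decodes to
+// FragmentOffset() = (68<<8|9)>>3 = 2177, More() = true (bit 0 of byte 1 is
+// set) and ID() = 0x80040201 = 2147746305.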
+
+// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
+//
+// The IPv6 payload may contain IPv6 extension headers before any upper layer
+// data.
+//
+// Note, between when an IPv6PayloadIterator is obtained and last used, no
+// changes to the payload may happen. Doing so may cause undefined and
+// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
+// over the first few headers then modify the backing payload so long as the
+// IPv6PayloadIterator obtained before modification is no longer used.
+type IPv6PayloadIterator struct {
+ // The identifier of the next header to parse.
+ nextHdrIdentifier IPv6ExtensionHeaderIdentifier
+
+ // reader is an io.Reader over payload.
+ reader bufio.Reader
+ payload buffer.VectorisedView
+
+ // Indicates to the iterator that it should return the remaining payload as a
+ // raw payload on the next call to Next.
+ forceRaw bool
+}
+
+// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
+// extension headers, or a raw payload if the payload cannot be parsed.
+func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.VectorisedView) IPv6PayloadIterator {
+ readers := payload.Readers()
+ readerPs := make([]io.Reader, 0, len(readers))
+ for i := range readers {
+ readerPs = append(readerPs, &readers[i])
+ }
+
+ return IPv6PayloadIterator{
+ nextHdrIdentifier: nextHdrIdentifier,
+ payload: payload.Clone(nil),
+ // We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ }
+}
+
+// AsRawHeader returns the remaining payload of i as a raw header and
+// optionally consumes the iterator.
+//
+// If consume is true, calls to Next after calling AsRawHeader on i will
+// indicate that the iterator is done.
+func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
+ identifier := i.nextHdrIdentifier
+
+ var buf buffer.VectorisedView
+ if consume {
+ // Since we consume the iterator, we return the payload as is.
+ buf = i.payload
+
+ // Mark i as done.
+ *i = IPv6PayloadIterator{
+ nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ }
+ } else {
+ buf = i.payload.Clone(nil)
+ }
+
+ return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
+}
+
+// Next returns the next item in the payload.
+//
+// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
+// will be returned with the remaining bytes and next header identifier.
+//
+// The return is of the format (header, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the payload, or an error occurred.
+func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ // We could be forced to return i as a raw header when the previous header was
+ // a fragment extension header as the data following the fragment extension
+ // header may not be complete.
+ if i.forceRaw {
+ return i.AsRawHeader(true /* consume */), false, nil
+ }
+
+ // Is the header we are parsing a known extension header?
+ switch i.nextHdrIdentifier {
+ case IPv6HopByHopOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6RoutingExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6RoutingExtHdr(bytes), false, nil
+ case IPv6FragmentExtHdrIdentifier:
+ var data [6]byte
+ // We ignore the returned bytes because we know the fragment extension
+ // header specific data will fit in data.
+ nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
+ if err != nil {
+ return nil, true, err
+ }
+
+ fragmentExtHdr := IPv6FragmentExtHdr(data)
+
+ // If the packet is not the first fragment, do not attempt to parse anything
+ // after the fragment extension header as the payload following the fragment
+ // extension header should not contain any headers; the first fragment must
+ // hold all the headers up to and including any upper layer headers, as per
+ // RFC 8200 section 4.5.
+ if fragmentExtHdr.FragmentOffset() != 0 {
+ i.forceRaw = true
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return fragmentExtHdr, false, nil
+ case IPv6DestinationOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6NoNextHeaderIdentifier:
+ // This indicates the end of the IPv6 payload.
+ return nil, true, nil
+
+ default:
+ // The header we are parsing is not a known extension header. Return the
+ // raw payload.
+ return i.AsRawHeader(true /* consume */), false, nil
+ }
+}
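+
+// A minimal sketch of driving the iterator (illustrative; nextHdr and
+// payload are assumed to come from a parsed IPv6 header):
+//
+//	it := MakeIPv6PayloadIterator(nextHdr, payload)
+//	for {
+//		hdr, done, err := it.Next()
+//		if err != nil || done {
+//			break
+//		}
+//		switch hdr.(type) {
+//		case IPv6FragmentExtHdr:
+//			// Fragment-specific handling.
+//		case IPv6RawPayloadHeader:
+//			// Remaining payload (e.g. upper layer data).
+//		}
+//	}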
+
+// nextHeaderData returns the extension header's Next Header field and raw data.
+//
+// fragmentHdr indicates that the extension header being parsed is the Fragment
+// extension header so the Length field should be ignored as it is Reserved
+// for the Fragment extension header.
+//
+// If bytes is not nil, extension header specific data will be read into bytes
+// if it has enough capacity. If bytes is provided but does not have enough
+// capacity for the data, nextHeaderData will panic.
+func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, []byte, error) {
+ // ReadByte reads a single byte; if the underlying reader has been
+ // exhausted, it returns io.EOF to indicate that we have reached the end
+ // of the payload.
+ nextHdrIdentifier, err := i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ var length uint8
+ length, err = i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ if fragmentHdr {
+ return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+ if fragmentHdr {
+ length = 0
+ }
+
+ bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
+ if bytes == nil {
+ bytes = make([]byte, bytesLen)
+ } else if n := len(bytes); n < bytesLen {
+ panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
+ }
+
+ n, err := io.ReadFull(&i.reader, bytes)
+ i.payload.TrimFront(n)
+ if err != nil {
+ return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
+ }
+
+ return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
new file mode 100644
index 000000000..ab20c5f37
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -0,0 +1,992 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// Equal returns true if a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold the same Identifier value and
+// contain the same bytes in Buf, even if the bytes are split across views
+// differently.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool {
+ return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView())
+}
+
+// Equal returns true if a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6HopByHopOptionsExtHdr as it contains
+// unexported fields.
+func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+// Equal returns true if a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6DestinationOptionsExtHdr as it contains
+// unexported fields.
+func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+func TestIPv6UnknownExtHdrOption(t *testing.T) {
+ tests := []struct {
+ name string
+ identifier IPv6ExtHdrOptionIdentifier
+ expectedUnknownAction IPv6OptionUnknownAction
+ }{
+ {
+ name: "Skip with zero LSBs",
+ identifier: 0,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with zero LSBs",
+ identifier: 64,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with zero LSBs",
+ identifier: 128,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with zero LSBs",
+ identifier: 192,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ {
+ name: "Skip with non-zero LSBs",
+ identifier: 63,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with non-zero LSBs",
+ identifier: 127,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with non-zero LSBs",
+ identifier: 191,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with non-zero LSBs",
+ identifier: 255,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}}
+ if a := opt.UnknownAction(); a != test.expectedUnknownAction {
+ t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction)
+ }
+ })
+ }
+
+}
+
+func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ err error
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ },
+ {
+ name: "Two options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ },
+ },
+ {
+ name: "Three options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ 253, 4, 2, 3, 4, 5,
+ },
+ },
+ {
+ name: "Single unknown only identifier",
+ bytes: []byte{255},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 1",
+ bytes: []byte{255, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 2",
+ bytes: []byte{255, 2, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown only identifier",
+ bytes: []byte{
+ 255, 0,
+ 254,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown missing data",
+ bytes: []byte{
+ 255, 0,
+ 254, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown too small",
+ bytes: []byte{
+ 255, 0,
+ 254, 2, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "One Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Multiple Pad1",
+ bytes: []byte{0, 0, 0},
+ },
+ {
+ name: "Multiple PadN",
+ bytes: []byte{
+ // Pad3
+ 1, 1, 1,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Pad5 too small middle of data buffer",
+ bytes: []byte{1, 3, 1, 2},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Pad5 no data",
+ bytes: []byte{1, 3},
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // the same error as expectedErr.
+ if !errors.Is(err, expectedErr) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if expectedErr != nil {
+ t.Fatalf("expected error when iterating; want = %s", expectedErr)
+ }
+
+ return
+ }
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+ })
+ }
+}
+
+func TestIPv6OptionsExtHdrIter(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ expected []IPv6ExtHdrOption
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ },
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}},
+ },
+ },
+ {
+ name: "Single Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Two Pad1",
+ bytes: []byte{0, 0},
+ },
+ {
+ name: "Single Pad3",
+ bytes: []byte{1, 1, 1},
+ },
+ {
+ name: "Single Pad5",
+ bytes: []byte{1, 3, 1, 2, 3},
+ },
+ {
+ name: "Multiple Pad",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Pad2
+ 1, 0,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Pad4
+ 1, 2, 1, 2,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Multiple options",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Unknown
+ 255, 0,
+
+ // Pad2
+ 1, 0,
+
+ // Unknown
+ 254, 1, 1,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Unknown
+ 253, 4, 2, 3, 4, 5,
+
+ // Pad4
+ 1, 2, 1, 2,
+ },
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}},
+ &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}},
+ },
+ },
+ }
+
+ checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) {
+ for i, e := range expected {
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, opt); diff != "" {
+ t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if opt != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", opt)
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+ })
+ }
+}
+
+func TestIPv6RoutingExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ segmentsLeft uint8
+ }{
+ {
+ name: "Zeroes",
+ bytes: []byte{0, 0, 0, 0, 0, 0},
+ segmentsLeft: 0,
+ },
+ {
+ name: "Ones",
+ bytes: []byte{1, 1, 1, 1, 1, 1},
+ segmentsLeft: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: []byte{1, 2, 3, 4, 5, 6},
+ segmentsLeft: 2,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6RoutingExtHdr(test.bytes)
+ if got := extHdr.SegmentsLeft(); got != test.segmentsLeft {
+ t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft)
+ }
+ })
+ }
+}
+
+func TestIPv6FragmentExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes [6]byte
+ fragmentOffset uint16
+ more bool
+ id uint32
+ }{
+ {
+ name: "Zeroes",
+ bytes: [6]byte{0, 0, 0, 0, 0, 0},
+ fragmentOffset: 0,
+ more: false,
+ id: 0,
+ },
+ {
+ name: "Ones",
+ bytes: [6]byte{0, 9, 0, 0, 0, 1},
+ fragmentOffset: 1,
+ more: true,
+ id: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: [6]byte{68, 9, 128, 4, 2, 1},
+ fragmentOffset: 2177,
+ more: true,
+ id: 2147746305,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6FragmentExtHdr(test.bytes)
+ if got := extHdr.FragmentOffset(); got != test.fragmentOffset {
+ t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset)
+ }
+ if got := extHdr.More(); got != test.more {
+ t.Errorf("got More() = %t, want = %t", got, test.more)
+ }
+ if got := extHdr.ID(); got != test.id {
+ t.Errorf("got ID() = %d, want = %d", got, test.id)
+ }
+ })
+ }
+}
+
+func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView {
+ size := 0
+ var vs []buffer.View
+
+ for _, b := range bs {
+ vs = append(vs, buffer.View(b))
+ size += len(b)
+ }
+
+ return buffer.NewVectorisedView(size, vs)
+}
+
+func TestIPv6ExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ err error
+ }{
+ {
+ name: "Upper layer only without data",
+ firstNextHdr: 255,
+ },
+ {
+ name: "Upper layer only with data",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "No next header",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ },
+ {
+ name: "No next header with data",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "Valid single hop by hop",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Hop by hop too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single fragment",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}),
+ },
+ {
+ name: "Fragment too small",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single destination",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Destination too small",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single routing",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}),
+ },
+ {
+ name: "Valid single routing across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}),
+ },
+ {
+ name: "Routing too small with zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid routing with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Valid routing with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Routing too small with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Routing too small with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Mixed",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data but last ext hdr too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3,
+ }),
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // the same error as test.err.
+ if !errors.Is(err, test.err) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if test.err != nil {
+ t.Fatalf("expected error when iterating; want = %s", test.err)
+ }
+
+ return
+ }
+ }
+ })
+ }
+}
+
+func TestIPv6ExtHdrIter(t *testing.T) {
+ routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4})
+ upperLayerData := buffer.View([]byte{1, 2, 3, 4})
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ expected []IPv6PayloadHeader
+ }{
+ // With a non-atomic fragment that is not the first fragment, the payload
+ // after the fragment will not be parsed because the payload is expected to
+ // only hold upper layer data.
+ {
+ name: "hopbyhop - fragment (not first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 2177, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ //
+ // Even though we have a routing ext header here, it should be
+ // interpreted as raw bytes as only the first fragment is expected
+ // to hold headers.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "hopbyhop - fragment (first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 0, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // If we have an atomic fragment, the payload following the fragment
+ // extension header should be parsed normally.
+ {
+ name: "atomic fragment - routing - destination - upper",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 1, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2}, []byte{3, 4}),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - destination - no next header",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Res (Reserved) bits are 1 which should not affect anything.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Destination Options extension header.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1, which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header (across views)",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1, which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "hopbyhop - routing - fragment - no next header",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // More = 0, Fragment Offset = 32, Res = 3.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6NoNextHeaderIdentifier,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // Test the raw payload for common transport layer protocol numbers.
+ {
+ name: "TCP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "UDP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv4 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv6 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload (across views)",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ }},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i, e := range test.expected {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, extHdr); diff != "" {
+ t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if extHdr != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", extHdr)
+ }
+ })
+ }
+}
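
As a reading aid for the vectors above, here is a minimal, stdlib-only sketch (illustrative, not part of this patch) of how the two offset/flags bytes and the four Identification bytes of the fragment extension header unpack. The byte values are taken from the vector commented "More = 1, Fragment Offset = 0, ID = 2147746305".

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// Offset/flags bytes followed by the 4-byte Identification field.
	b := []byte{0, 1, 128, 4, 2, 1}

	offsetAndFlags := binary.BigEndian.Uint16(b[:2])
	fmt.Println(offsetAndFlags >> 3)            // Fragment Offset: 0
	fmt.Println(offsetAndFlags & 1)             // M (more) flag: 1
	fmt.Println(binary.BigEndian.Uint32(b[2:])) // ID: 2147746305
}
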
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
new file mode 100644
index 000000000..018555a26
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ nextHdrFrag = 0
+ fragOff = 2
+ more = 3
+ idV6 = 4
+)
+
+// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6FragmentFields struct {
+ // NextHeader is the "next header" field of an IPv6 fragment.
+ NextHeader uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv6 fragment.
+ FragmentOffset uint16
+
+ // M is the "more" field of an IPv6 fragment.
+ M bool
+
+ // Identification is the "identification" field of an IPv6 fragment.
+ Identification uint32
+}
+
+// IPv6Fragment represents an IPv6 fragment header stored in a byte array.
+// Most of the methods of IPv6Fragment access the underlying slice without
+// checking the boundaries and could panic with an 'index out of range' error.
+// Always call IsValid() to validate an instance of IPv6Fragment before using
+// other methods.
+type IPv6Fragment []byte
+
+const (
+ // IPv6FragmentHeader is the protocol number used to specify that the
+ // next header is a fragment header, per RFC 2460.
+ IPv6FragmentHeader = 44
+
+ // IPv6FragmentHeaderSize is the size of the fragment header.
+ IPv6FragmentHeaderSize = 8
+)
+
+// Encode encodes all the fields of the ipv6 fragment.
+func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
+ b[nextHdrFrag] = i.NextHeader
+ binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
+ if i.M {
+ b[more] |= 1
+ }
+ binary.BigEndian.PutUint32(b[idV6:], i.Identification)
+}
+
+// IsValid performs basic validation on the fragment header.
+func (b IPv6Fragment) IsValid() bool {
+ return len(b) >= IPv6FragmentHeaderSize
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 fragment.
+func (b IPv6Fragment) NextHeader() uint8 {
+ return b[nextHdrFrag]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
+func (b IPv6Fragment) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[fragOff:]) >> 3
+}
+
+// More returns the "more" field of the ipv6 fragment.
+func (b IPv6Fragment) More() bool {
+ return b[more]&1 > 0
+}
+
+// Payload implements Network.Payload.
+func (b IPv6Fragment) Payload() []byte {
+ return b[IPv6FragmentHeaderSize:]
+}
+
+// ID returns the value of the identifier field of the ipv6 fragment.
+func (b IPv6Fragment) ID() uint32 {
+ return binary.BigEndian.Uint32(b[idV6:])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// The functions below have been added only to satisfy the Network interface.
+
+// Checksum is not supported by IPv6Fragment.
+func (b IPv6Fragment) Checksum() uint16 {
+ panic("not supported")
+}
+
+// SourceAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SourceAddress() tcpip.Address {
+ panic("not supported")
+}
+
+// DestinationAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) DestinationAddress() tcpip.Address {
+ panic("not supported")
+}
+
+// SetSourceAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
+ panic("not supported")
+}
+
+// SetDestinationAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
+ panic("not supported")
+}
+
+// SetChecksum is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetChecksum(uint16) {
+ panic("not supported")
+}
+
+// TOS is not supported by IPv6Fragment.
+func (b IPv6Fragment) TOS() (uint8, uint32) {
+ panic("not supported")
+}
+
+// SetTOS is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
+ panic("not supported")
+}
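
A short usage sketch for the IPv6Fragment type added above: encode a fragment header, then read the fields back through the accessors. Illustrative only; it assumes the gvisor module is available on the import path.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	b := header.IPv6Fragment(make([]byte, header.IPv6FragmentHeaderSize))
	b.Encode(&header.IPv6FragmentFields{
		NextHeader:     6,    // TCP
		FragmentOffset: 100,  // in 8-byte units
		M:              true, // more fragments follow
		Identification: 0x80040201,
	})

	// Validate before using the accessors, as the type comment requires.
	fmt.Println(b.IsValid())                                  // true
	fmt.Println(b.NextHeader(), b.FragmentOffset(), b.More()) // 6 100 true
	fmt.Println(b.ID())                                       // 2147746305
}
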
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
new file mode 100644
index 000000000..426a873b1
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -0,0 +1,417 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+)
+
+func TestEthernetAddressToModifiedEUI64(t *testing.T) {
+ expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
+
+ if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" {
+ t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff)
+ }
+
+ var buf [header.IIDSize]byte
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
+ if diff := cmp.Diff(expectedIID, buf); diff != "" {
+ t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff)
+ }
+}
+
+func TestLinkLocalAddr(t *testing.T) {
+ if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want {
+ t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
+ }
+}
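
The two tests above pin down the modified EUI-64 transform from RFC 4291: flip the universal/local bit of the first MAC byte and splice ff:fe into the middle. A stdlib-only sketch of that transform; modifiedEUI64 is a hypothetical helper, shown only to make the expected IID explicit.

package main

import "fmt"

func modifiedEUI64(mac [6]byte) [8]byte {
	return [8]byte{
		mac[0] ^ 0x02, // flip the universal/local bit
		mac[1], mac[2],
		0xff, 0xfe, // fixed middle bytes
		mac[3], mac[4], mac[5],
	}
}

func main() {
	// 02:02:03:04:05:06 is the linkAddr used by the tests above.
	fmt.Println(modifiedEUI64([6]byte{0x02, 0x02, 0x03, 0x04, 0x05, 0x06}))
	// Prints [0 2 3 255 254 4 5 6], matching expectedIID in the test.
}
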
+
+func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ prefix: header.IPv6LinkLocalPrefix.Subnet(),
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ h := sha256.New()
+ h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
+ h.Write([]byte(test.nicName))
+ h.Write([]byte{test.dadCounter})
+ if k := test.secretKey; k != nil {
+ h.Write(k)
+ }
+ var hashSum [sha256.Size]byte
+ h.Sum(hashSum[:0])
+ want := hashSum[:header.IIDSize]
+
+ // Passing a nil buffer should result in a new buffer returned with the
+ // IID.
+ if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+
+ // Passing a buffer with sufficient capacity for the IID should populate
+ // the buffer provided.
+ var iidBuf [header.IIDSize]byte
+ if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ if got := iidBuf[:]; !bytes.Equal(got, want) {
+ t.Errorf("got iidBuf = %x, want = %x", got, want)
+ }
+ })
+ }
+}
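
The test recomputes the opaque IID the way RFC 7217 describes it: a SHA-256 digest over the leading bytes of the prefix, the NIC name, the DAD counter and an optional secret key, truncated to the 8-byte IID. A stdlib-only mirror of that computation; opaqueIID is a hypothetical helper and the field order matters.

package main

import (
	"crypto/sha256"
	"fmt"
)

func opaqueIID(prefix [8]byte, nicName string, dadCounter uint8, secretKey []byte) []byte {
	h := sha256.New()
	h.Write(prefix[:])
	h.Write([]byte(nicName))
	h.Write([]byte{dadCounter})
	h.Write(secretKey) // a nil key writes nothing, matching the test's nil case
	return h.Sum(nil)[:8]
}

func main() {
	fmt.Printf("%x\n", opaqueIID([8]byte{0xfe, 0x80}, "eth0", 0, nil))
}
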
+
+func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ prefix := header.IPv6LinkLocalPrefix.Subnet()
+
+ tests := []struct {
+ name string
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ addrBytes := [header.IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ test.nicName,
+ test.dadCounter,
+ test.secretKey,
+ ))
+
+ if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
+ t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ })
+ }
+}
+
+func TestIsV6UniqueLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Unique 1",
+ addr: uniqueLocalAddr1,
+ expected: true,
+ },
+ {
+ name: "Valid Unique 2",
+ addr: uniqueLocalAddr2,
+ expected: true,
+ },
+ {
+ name: "Link Local",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: true,
+ },
+ {
+ name: "Valid Link Local Multicast with flags",
+ addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ expected: true,
+ },
+ {
+ name: "Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Multicast",
+ addr: "\xe0\x00\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV6LinkLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: true,
+ },
+ {
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: false,
+ },
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Link Local",
+ addr: "\xa9\xfe\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestScopeForIPv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ scope header.IPv6AddressScope
+ err *tcpip.Error
+ }{
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ scope: header.UniqueLocalScope,
+ err: nil,
+ },
+ {
+ name: "Link Local Unicast",
+ addr: linkLocalAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ scope: header.GlobalScope,
+ err: nil,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ scope: header.GlobalScope,
+ err: tcpip.ErrBadAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := header.ScopeForIPv6Address(test.addr)
+ if err != test.err {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
+ }
+ if got != test.scope {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
+ }
+ })
+ }
+}
+
+func TestSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want tcpip.Address
+ }{
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.SolicitedNodeAddr(test.addr); got != test.want {
+ t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want)
+ }
+ })
+ }
+}
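
All three cases above exercise the same mapping from RFC 4291 section 2.7.1: the solicited-node multicast address is ff02::1:ffXX:XXXX, keeping only the last three bytes of the unicast address. A stdlib-only sketch; solicitedNodeAddr here is a hypothetical stand-in for header.SolicitedNodeAddr.

package main

import "fmt"

func solicitedNodeAddr(addr [16]byte) [16]byte {
	return [16]byte{
		0xff, 0x02, 0, 0, 0, 0, 0, 0,
		0, 0, 0, 0x01, 0xff,
		addr[13], addr[14], addr[15], // only the low 3 bytes survive
	}
}

func main() {
	addr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xa0}
	got := solicitedNodeAddr(addr)
	fmt.Printf("%x\n", got[:])
	// ff020000000000000000000001ff0e0fa0, matching the first test case.
}
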
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
new file mode 100644
index 000000000..b5540bf66
--- /dev/null
+++ b/pkg/tcpip/header/ipversion_test.go
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestIPv4(t *testing.T) {
+ b := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ b.Encode(&header.IPv4Fields{})
+
+ const want = header.IPv4Version
+ if v := header.IPVersion(b); v != want {
+ t.Fatalf("Bad version, want %v, got %v", want, v)
+ }
+}
+
+func TestIPv6(t *testing.T) {
+ b := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ b.Encode(&header.IPv6Fields{})
+
+ const want = header.IPv6Version
+ if v := header.IPVersion(b); v != want {
+ t.Fatalf("Bad version, want %v, got %v", want, v)
+ }
+}
+
+func TestOtherVersion(t *testing.T) {
+ const want = header.IPv4Version + header.IPv6Version
+ b := make([]byte, 1)
+ b[0] = want << 4
+
+ if v := header.IPVersion(b); v != want {
+ t.Fatalf("Bad version, want %v, got %v", want, v)
+ }
+}
+
+func TestTooShort(t *testing.T) {
+ b := make([]byte, 1)
+ b[0] = (header.IPv4Version + header.IPv6Version) << 4
+
+ // Get the version of a zero-length slice.
+ const want = -1
+ if v := header.IPVersion(b[:0]); v != want {
+ t.Fatalf("Bad version, want %v, got %v", want, v)
+ }
+
+ // Get the version of a nil slice.
+ if v := header.IPVersion(nil); v != want {
+ t.Fatalf("Bad version, want %v, got %v", want, v)
+ }
+}
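
For reference, the behaviour these tests pin down is simply the top nibble of the first byte, with -1 for a buffer shorter than one byte. A stdlib-only sketch; ipVersion is a hypothetical stand-in for header.IPVersion.

package main

import "fmt"

func ipVersion(b []byte) int {
	if len(b) < 1 {
		return -1 // too short to carry a version nibble
	}
	return int(b[0] >> 4)
}

func main() {
	fmt.Println(ipVersion([]byte{4 << 4})) // 4
	fmt.Println(ipVersion([]byte{6 << 4})) // 6
	fmt.Println(ipVersion(nil))            // -1
}
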
diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go
new file mode 100644
index 000000000..505c92668
--- /dev/null
+++ b/pkg/tcpip/header/ndp_neighbor_advert.go
@@ -0,0 +1,110 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
+// only contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.4 for more details.
+type NDPNeighborAdvert []byte
+
+const (
+ // NDPNAMinimumSize is the minimum size of a valid NDP Neighbor
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPNAMinimumSize = 20
+
+ // ndpNATargetAddressOffset is the start of the Target Address
+ // field within an NDPNeighborAdvert.
+ ndpNATargetAddressOffset = 4
+
+ // ndpNAOptionsOffset is the start of the NDP options in an
+ // NDPNeighborAdvert.
+ ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize
+
+ // ndpNAFlagsOffset is the offset of the flags within an
+ // NDPNeighborAdvert.
+ ndpNAFlagsOffset = 0
+
+ // ndpNARouterFlagMask is the mask of the Router Flag field in
+ // the flags byte within an NDPNeighborAdvert.
+ ndpNARouterFlagMask = (1 << 7)
+
+ // ndpNASolicitedFlagMask is the mask of the Solicited Flag field in
+ // the flags byte within an NDPNeighborAdvert.
+ ndpNASolicitedFlagMask = (1 << 6)
+
+ // ndpNAOverrideFlagMask is the mask of the Override Flag field in
+ // the flags byte within an NDPNeighborAdvert.
+ ndpNAOverrideFlagMask = (1 << 5)
+)
+
+// TargetAddress returns the value within the Target Address field.
+func (b NDPNeighborAdvert) TargetAddress() tcpip.Address {
+ return tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize])
+}
+
+// SetTargetAddress sets the value within the Target Address field.
+func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) {
+ copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr)
+}
+
+// RouterFlag returns the value of the Router Flag field.
+func (b NDPNeighborAdvert) RouterFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0
+}
+
+// SetRouterFlag sets the value in the Router Flag field.
+func (b NDPNeighborAdvert) SetRouterFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNARouterFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask
+ }
+}
+
+// SolicitedFlag returns the value of the Solicited Flag field.
+func (b NDPNeighborAdvert) SolicitedFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0
+}
+
+// SetSolicitedFlag sets the value in the Solicited Flag field.
+func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask
+ }
+}
+
+// OverrideFlag returns the value of the Override Flag field.
+func (b NDPNeighborAdvert) OverrideFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0
+}
+
+// SetOverrideFlag sets the value in the Override Flag field.
+func (b NDPNeighborAdvert) SetOverrideFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask
+ }
+}
+
+// Options returns an NDPOptions of the options body.
+func (b NDPNeighborAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpNAOptionsOffset:])
+}
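
A short usage sketch for the NDPNeighborAdvert type added above: build a minimal advertisement body, set the Solicited and Override flags, and read them back. Illustrative only; it assumes the gvisor module is on the import path.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	// Flags byte, 3 reserved bytes, then the 16-byte target address.
	na := header.NDPNeighborAdvert(make([]byte, header.NDPNAMinimumSize))
	na.SetTargetAddress(tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"))
	na.SetSolicitedFlag(true)
	na.SetOverrideFlag(true)

	// Only the bits we set are reported back: false true true.
	fmt.Println(na.RouterFlag(), na.SolicitedFlag(), na.OverrideFlag())
}
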
diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go
new file mode 100644
index 000000000..3a1b8e139
--- /dev/null
+++ b/pkg/tcpip/header/ndp_neighbor_solicit.go
@@ -0,0 +1,52 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
+// contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.3 for more details.
+type NDPNeighborSolicit []byte
+
+const (
+ // NDPNSMinimumSize is the minimum size of a valid NDP Neighbor
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPNSMinimumSize = 20
+
+ // ndpNSTargetAddressOffset is the start of the Target Address
+ // field within an NDPNeighborSolicit.
+ ndpNSTargetAddressOffset = 4
+
+ // ndpNSOptionsOffset is the start of the NDP options in an
+ // NDPNeighborSolicit.
+ ndpNSOptionsOffset = ndpNSTargetAddressOffset + IPv6AddressSize
+)
+
+// TargetAddress returns the value within the Target Address field.
+func (b NDPNeighborSolicit) TargetAddress() tcpip.Address {
+ return tcpip.Address(b[ndpNSTargetAddressOffset:][:IPv6AddressSize])
+}
+
+// SetTargetAddress sets the value within the Target Address field.
+func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) {
+ copy(b[ndpNSTargetAddressOffset:][:IPv6AddressSize], addr)
+}
+
+// Options returns an NDPOptions of the options body.
+func (b NDPNeighborSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpNSOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
new file mode 100644
index 000000000..5d3975c56
--- /dev/null
+++ b/pkg/tcpip/header/ndp_options.go
@@ -0,0 +1,899 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// NDPOptionIdentifier is an NDP option type identifier.
+type NDPOptionIdentifier uint8
+
+const (
+ // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
+ NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1
+
+ // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
+ NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2
+
+ // NDPPrefixInformationType is the type of the Prefix Information
+ // option, as per RFC 4861 section 4.6.2.
+ NDPPrefixInformationType NDPOptionIdentifier = 3
+
+ // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // Server option, as per RFC 8106 section 5.1.
+ NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+
+ // NDPDNSSearchListOptionType is the type of the DNS Search List option,
+ // as per RFC 8106 section 5.2.
+ NDPDNSSearchListOptionType NDPOptionIdentifier = 31
+)
+
+const (
+ // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer
+ // Address option for an Ethernet address.
+ NDPLinkLayerAddressSize = 8
+
+ // ndpPrefixInformationLength is the expected length, in bytes, of the
+ // body of an NDP Prefix Information option, as per RFC 4861 section
+ // 4.6.2 which specifies that the Length field is 4. Given this, the
+ // expected length, in bytes, is 30 because 4 * lengthByteUnits (8) - 2
+ // (Type & Length) = 30.
+ ndpPrefixInformationLength = 30
+
+ // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix
+ // Length field within an NDPPrefixInformation.
+ ndpPrefixInformationPrefixLengthOffset = 0
+
+ // ndpPrefixInformationFlagsOffset is the offset of the flags byte
+ // within an NDPPrefixInformation.
+ ndpPrefixInformationFlagsOffset = 1
+
+ // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag
+ // field in the flags byte within an NDPPrefixInformation.
+ ndpPrefixInformationOnLinkFlagMask = (1 << 7)
+
+ // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the
+ // Autonomous Address-Configuration flag field in the flags byte within
+ // an NDPPrefixInformation.
+ ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6)
+
+ // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1
+ // field in the flags byte within an NDPPrefixInformation.
+ ndpPrefixInformationReserved1FlagsMask = 63
+
+ // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte
+ // Valid Lifetime field within an NDPPrefixInformation.
+ ndpPrefixInformationValidLifetimeOffset = 2
+
+ // ndpPrefixInformationPreferredLifetimeOffset is the start of the
+ // 4-byte Preferred Lifetime field within an NDPPrefixInformation.
+ ndpPrefixInformationPreferredLifetimeOffset = 6
+
+ // ndpPrefixInformationReserved2Offset is the start of the 4-byte
+ // Reserved2 field within an NDPPrefixInformation.
+ ndpPrefixInformationReserved2Offset = 10
+
+ // ndpPrefixInformationReserved2Length is the length of the Reserved2
+ // field.
+ //
+ // It is 4 bytes.
+ ndpPrefixInformationReserved2Length = 4
+
+ // ndpPrefixInformationPrefixOffset is the start of the Prefix field
+ // within an NDPPrefixInformation.
+ ndpPrefixInformationPrefixOffset = 14
+
+ // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerLifetimeOffset = 2
+
+ // ndpRecursiveDNSServerAddressesOffset is the start of the addresses
+ // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerAddressesOffset = 6
+
+ // minNDPRecursiveDNSServerBodySize is the minimum NDP Recursive DNS Server
+ // option's body size when it contains at least one IPv6 address, as per
+ // RFC 8106 section 5.3.1.
+ minNDPRecursiveDNSServerBodySize = 22
+
+ // ndpDNSSearchListLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPDNSSearchList.
+ ndpDNSSearchListLifetimeOffset = 2
+
+ // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list
+ // domain names within an NDPDNSSearchList.
+ ndpDNSSearchListDomainNamesOffset = 6
+
+ // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's
+ // body size when it contains at least one domain name, as per RFC 8106
+ // section 5.3.1.
+ minNDPDNSSearchListBodySize = 14
+
+ // maxDomainNameLabelLength is the maximum length of a domain name
+ // label, as per RFC 1035 section 3.1.
+ maxDomainNameLabelLength = 63
+
+ // maxDomainNameLength is the maximum length of a domain name, including
+ // label AND label length octet, as per RFC 1035 section 3.1.
+ maxDomainNameLength = 255
+
+ // lengthByteUnits is the multiplier factor for the Length field of an
+ // NDP option. That is, the length field for NDP options is in units of
+ // 8 octets, as per RFC 4861 section 4.6.
+ lengthByteUnits = 8
+)
+
+var (
+ // NDPInfiniteLifetime is a value that represents infinity for the
+ // 4-byte lifetime fields found in various NDP options. Its value is
+ // (2^32 - 1)s = 4294967295s.
+ //
+ // This is a variable instead of a constant so that tests can change
+ // this value to a smaller value. It should only be modified by tests.
+ NDPInfiniteLifetime = time.Second * math.MaxUint32
+)
+
+// NDPOptionIterator is an iterator of NDPOption.
+//
+// Note, between when an NDPOptionIterator is obtained and last used, no changes
+// to the NDPOptions may happen. Doing so may cause undefined and unexpected
+// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first
+// few NDPOption then modify the backing NDPOptions so long as the
+// NDPOptionIterator obtained before modification is no longer used.
+type NDPOptionIterator struct {
+ opts *bytes.Buffer
+}
+
+// Potential errors when iterating over an NDPOptions.
+var (
+ ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+ ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header")
+)
+
+// Next returns the next element in the backing NDPOptions, whether iteration
+// is done, and any error that occurred.
+//
+// The return can be read as option, done, error. Note, option should only be
+// used if done is false and error is nil.
+func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
+ for {
+ // Do we still have elements to look at?
+ if i.opts.Len() == 0 {
+ return nil, true, nil
+ }
+
+ // Get the Type field.
+ temp, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the buffer to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF)
+ }
+ kind := NDPOptionIdentifier(temp)
+
+ // Get the Length field.
+ length, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err))
+ }
+
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF)
+ }
+
+ // This would indicate an erroneous NDP option as the Length field should
+ // never be 0.
+ if length == 0 {
+ return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader)
+ }
+
+ // Get the body.
+ numBytes := int(length) * lengthByteUnits
+ numBodyBytes := numBytes - 2
+ body := i.opts.Next(numBodyBytes)
+ if len(body) < numBodyBytes {
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF)
+ }
+
+ switch kind {
+ case NDPSourceLinkLayerAddressOptionType:
+ return NDPSourceLinkLayerAddressOption(body), false, nil
+
+ case NDPTargetLinkLayerAddressOptionType:
+ return NDPTargetLinkLayerAddressOption(body), false, nil
+
+ case NDPPrefixInformationType:
+ // Make sure the length of a Prefix Information option
+ // body is ndpPrefixInformationLength, as per RFC 4861
+ // section 4.6.2.
+ if numBodyBytes != ndpPrefixInformationLength {
+ return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody)
+ }
+
+ return NDPPrefixInformation(body), false, nil
+
+ case NDPRecursiveDNSServerOptionType:
+ opt := NDPRecursiveDNSServer(body)
+ if err := opt.checkAddresses(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
+ case NDPDNSSearchListOptionType:
+ opt := NDPDNSSearchList(body)
+ if err := opt.checkDomainNames(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
+ default:
+ // We do not yet recognize the option, just skip for
+ // now. This is okay because RFC 4861 allows us to
+ // skip/ignore any unrecognized options. However,
+ // we MUST recognize all the options in RFC 4861.
+ //
+ // TODO(b/141487990): Handle all NDP options as defined
+ // by RFC 4861.
+ }
+ }
+}
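
A short sketch of driving this iterator over a raw options buffer via Iter (defined just below). The buffer holds one Source Link-Layer Address option: Type = 1, Length = 1 (one 8-octet unit), then a 6-byte MAC. Illustrative only; it assumes the gvisor module is importable.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	opts := header.NDPOptions([]byte{1, 1, 2, 3, 4, 5, 6, 7})

	it, err := opts.Iter(true) // validate the whole buffer up front
	if err != nil {
		panic(err)
	}
	for {
		opt, done, err := it.Next()
		if err != nil {
			panic(err)
		}
		if done {
			break
		}
		fmt.Println(opt) // the source link-layer address option
	}
}
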
+
+// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6.
+type NDPOptions []byte
+
+// Iter returns an iterator of NDPOption.
+//
+// If check is true, Iter will do an integrity check on the options by
+// iterating over them, returning an error if one is detected.
+//
+// See NDPOptionIterator for more information.
+func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
+ it := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
+
+ if check {
+ it2 := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
+
+ for {
+ if _, done, err := it2.Next(); err != nil || done {
+ return it, err
+ }
+ }
+ }
+
+ return it, nil
+}
+
+// Serialize serializes the provided list of NDP options into b.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// NDPOptionsSerializer.Length for details on getting the total size
+// of a serialized NDPOptionsSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
+ done := 0
+
+ for _, o := range s {
+ l := paddedLength(o)
+
+ if l == 0 {
+ continue
+ }
+
+ b[0] = byte(o.Type())
+
+ // We know this is safe because paddedLength would have returned
+ // 0 if o had an invalid length (> 255 * lengthByteUnits).
+ b[1] = uint8(l / lengthByteUnits)
+
+ // Serialize NDP option body.
+ used := o.serializeInto(b[2:])
+
+ // Zero out remaining (padding) bytes, if any exist.
+ for i := used + 2; i < l; i++ {
+ b[i] = 0
+ }
+
+ b = b[l:]
+ done += l
+ }
+
+ return done
+}
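
Serialization is the mirror image of iteration: size the buffer with NDPOptionsSerializer.Length, then call Serialize. A minimal sketch, assuming the gvisor module is importable.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	s := header.NDPOptionsSerializer{
		header.NDPSourceLinkLayerAddressOption("\x02\x03\x04\x05\x06\x07"),
	}

	// Length reports the padded, on-the-wire size; allocate exactly that.
	b := header.NDPOptions(make([]byte, s.Length()))
	n := b.Serialize(s)
	fmt.Println(n)        // 8
	fmt.Printf("%x\n", b) // 0101020304050607: Type, Length, then the MAC
}
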
+
+// NDPOption is the set of functions to be implemented by all NDP option types.
+type NDPOption interface {
+ fmt.Stringer
+
+ // Type returns the type of the receiver.
+ Type() NDPOptionIdentifier
+
+ // Length returns the length of the body of the receiver, in bytes.
+ Length() int
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // Length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) int
+}
+
+// paddedLength returns the length of o, in bytes, including any padding
+// bytes, if required.
+func paddedLength(o NDPOption) int {
+ l := o.Length()
+
+ if l == 0 {
+ return 0
+ }
+
+ // Length excludes the 2 Type and Length bytes.
+ l += 2
+
+ // Add extra bytes if needed to make sure the option is
+ // lengthByteUnits-byte aligned. We do this by adding lengthByteUnits-1
+ // to l and then stripping off the last few LSBits from l. This will
+ // make sure that l is rounded up to the nearest unit of
+ // lengthByteUnits. This works since lengthByteUnits is a power of 2
+ // (= 8).
+ mask := lengthByteUnits - 1
+ l += mask
+ l &^= mask
+
+ if l/lengthByteUnits > 255 {
+ // Should never happen because an option can only have a max
+ // value of 255 for its Length field, so just return 0 so this
+ // option does not get serialized.
+ //
+ // Returning 0 here will make sure that this option does not get
+ // serialized when NDPOptions.Serialize is called with the
+ // NDPOptionsSerializer that holds this option, effectively
+ // skipping this option during serialization. Also note that
+ // a value of zero for the Length field in an NDP option is
+ // invalid so this is another sign to the caller that this NDP
+ // option is malformed, as per RFC 4861 section 4.6.
+ return 0
+ }
+
+ return l
+}
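
Worked numbers for the rounding above: a 6-byte link-layer address body gives l = 6 + 2 = 8, already a multiple of 8; the 30-byte prefix-information body gives l = 30 + 2 = 32, i.e. a Length field of 4. The same arithmetic as a stdlib-only check; padded is a hypothetical helper.

package main

import "fmt"

func padded(bodyLen int) int {
	l := bodyLen + 2    // account for the Type and Length bytes
	return (l + 7) &^ 7 // round up to the next multiple of 8
}

func main() {
	fmt.Println(padded(6), padded(30)) // 8 32
}
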
+
+// NDPOptionsSerializer is a serializer for NDP options.
+type NDPOptionsSerializer []NDPOption
+
+// Length returns the total number of bytes required to serialize.
+func (b NDPOptionsSerializer) Length() int {
+ l := 0
+
+ for _, o := range b {
+ l += paddedLength(o)
+ }
+
+ return l
+}
+
+// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value of the Length field multiplied by lengthByteUnits,
+// minus the 2 Type and Length bytes.
+type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements NDPOption.Type.
+func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier {
+ return NDPSourceLinkLayerAddressOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPSourceLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPSourceLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
+// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value of the Length field multiplied by lengthByteUnits,
+// minus the 2 Type and Length bytes.
+type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements NDPOption.Type.
+func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier {
+ return NDPTargetLinkLayerAddressOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPTargetLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPTargetLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
+// NDPPrefixInformation is the NDP Prefix Information option as defined by
+// RFC 4861 section 4.6.2.
+//
+// The length, in bytes, of a valid NDP Prefix Information option body MUST be
+// ndpPrefixInformationLength bytes.
+type NDPPrefixInformation []byte
+
+// Type implements NDPOption.Type.
+func (o NDPPrefixInformation) Type() NDPOptionIdentifier {
+ return NDPPrefixInformationType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPPrefixInformation) Length() int {
+ return ndpPrefixInformationLength
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPPrefixInformation) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the Reserved1 field.
+ b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask
+
+ // Zero out the Reserved2 field.
+ reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length]
+ for i := range reserved2 {
+ reserved2[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPPrefixInformation) String() string {
+ return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)",
+ o,
+ o.OnLinkFlag(),
+ o.AutonomousAddressConfigurationFlag(),
+ o.PreferredLifetime(),
+ o.ValidLifetime(),
+ o.Subnet())
+}
+
+// PrefixLength returns the number of leading bits in the Prefix that are
+// valid.
+//
+// Valid values are in the range [0, 128], but o may not always contain valid
+// values. It is up to the caller to validate the Prefix Information option.
+func (o NDPPrefixInformation) PrefixLength() uint8 {
+ return o[ndpPrefixInformationPrefixLengthOffset]
+}
+
+// OnLinkFlag returns true if the prefix is considered on-link. On-link means
+// that a forwarding node is not needed to send packets to other nodes on the
+// same prefix.
+//
+// Note, when this function returns false, no statement is made about the
+// on-link property of a prefix. That is, if OnLinkFlag returns false, the
+// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any
+// previously stored state for this prefix about its on-link status.
+func (o NDPPrefixInformation) OnLinkFlag() bool {
+ return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0
+}
+
+// AutonomousAddressConfigurationFlag returns true if the prefix can be used for
+// Stateless Address Auto-Configuration (as specified in RFC 4862).
+func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool {
+ return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0
+}
+
+// ValidLifetime returns the length of time that the prefix is valid for the
+// purpose of on-link determination. This value is relative to the send time of
+// the packet that the Prefix Information option was present in.
+//
+// Note, a value of 0 implies the prefix should not be considered as on-link,
+// and a value of infinity/forever is represented by
+// NDPInfiniteLifetime.
+func (o NDPPrefixInformation) ValidLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.6.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:]))
+}
+
+// PreferredLifetime returns the length of time that an address generated from
+// the prefix via Stateless Address Auto-Configuration remains preferred. This
+// value is relative to the send time of the packet that the Prefix Information
+// option was present in.
+//
+// Note, a value of 0 implies that addresses generated from the prefix should
+// no longer remain preferred, and a value of infinity is represented by
+// NDPInfiniteLifetime.
+//
+// Also note that the value of this field MUST NOT exceed the Valid Lifetime
+// field to avoid preferring addresses that are no longer valid, for the
+// purpose of Stateless Address Auto-Configuration.
+func (o NDPPrefixInformation) PreferredLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.6.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:]))
+}
+
+// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix
+// Length field (see NDPPrefixInformation.PrefixLength) contains the number
+// of valid leading bits in the prefix.
+//
+// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field
+// holds the link-local prefix (fe80::).
+func (o NDPPrefixInformation) Prefix() tcpip.Address {
+ return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize])
+}
+
+// Subnet returns the Prefix field and Prefix Length field represented in a
+// tcpip.Subnet.
+func (o NDPPrefixInformation) Subnet() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: o.Prefix(),
+ PrefixLen: int(o.PrefixLength()),
+ }
+ return addrWithPrefix.Subnet()
+}
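
Putting the accessors above together: hand-build a 30-byte Prefix Information body and decode it. A minimal sketch, assuming the gvisor module is importable; 2001:db8::/64 is a documentation prefix used only for illustration.

package main

import (
	"encoding/binary"
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	b := make([]byte, 30)
	b[0] = 64                               // Prefix Length
	b[1] = 1<<7 | 1<<6                      // On-Link and Autonomous flags
	binary.BigEndian.PutUint32(b[2:], 3600) // Valid Lifetime, seconds
	binary.BigEndian.PutUint32(b[6:], 1800) // Preferred Lifetime, seconds
	copy(b[14:], "\x20\x01\x0d\xb8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")

	pi := header.NDPPrefixInformation(b)
	fmt.Println(pi.PrefixLength(), pi.OnLinkFlag(), pi.AutonomousAddressConfigurationFlag()) // 64 true true
	fmt.Println(pi.ValidLifetime(), pi.PreferredLifetime())                                  // 1h0m0s 30m0s
	fmt.Println(pi.Subnet())                                                                 // e.g. 2001:db8::/64
}
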
+
+// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by
+// RFC 8106 section 5.1.
+//
+// To make sure that the option meets its minimum length and does not end in the
+// middle of a DNS server's IPv6 address, the length of a valid
+// NDPRecursiveDNSServer must meet the following constraint:
+// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0
+type NDPRecursiveDNSServer []byte
+
+// Type returns the type of an NDP Recursive DNS Server option.
+//
+// Type implements NDPOption.Type.
+func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier {
+ return NDPRecursiveDNSServerOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPRecursiveDNSServer) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPRecursiveDNSServer) String() string {
+ lt := o.Lifetime()
+ addrs, err := o.Addresses()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt)
+}
+
+// Lifetime returns the length of time that the DNS server addresses
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the addresses should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+//
+// Lifetime may panic if o does not have enough bytes to hold the Lifetime
+// field.
+func (o NDPRecursiveDNSServer) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:]))
+}
+
+// Addresses returns the recursive DNS server IPv6 addresses that may be
+// used for name resolution.
+//
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) {
+ var addrs []tcpip.Address
+ return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) })
+}
+
+// checkAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and returns any error it encounters.
+func (o NDPRecursiveDNSServer) checkAddresses() error {
+ return o.iterAddresses(nil)
+}
+
+// iterAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and calls a function with each valid unicast IPv6 address.
+//
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
+ if l := len(o); l < minNDPRecursiveDNSServerBodySize {
+ return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF)
+ }
+
+ o = o[ndpRecursiveDNSServerAddressesOffset:]
+ l := len(o)
+ if l%IPv6AddressSize != 0 {
+ return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody)
+ }
+
+ for i := 0; len(o) != 0; i++ {
+ addr := tcpip.Address(o[:IPv6AddressSize])
+ if !IsV6UnicastAddress(addr) {
+ return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody)
+ }
+
+ if fn != nil {
+ fn(addr)
+ }
+
+ o = o[IPv6AddressSize:]
+ }
+
+ return nil
+}
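
A short decode sketch for the RDNSS option: two reserved bytes, a 4-byte lifetime, then one 16-byte server address, the 22-byte minimum enforced above. Assumes the gvisor module is importable; 2001:db8::1 is a documentation address.

package main

import (
	"encoding/binary"
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	b := make([]byte, 22)
	binary.BigEndian.PutUint32(b[2:], 600) // Lifetime, seconds
	copy(b[6:], "\x20\x01\x0d\xb8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")

	rdnss := header.NDPRecursiveDNSServer(b)
	fmt.Println(rdnss.Lifetime()) // 10m0s
	addrs, err := rdnss.Addresses()
	fmt.Println(addrs, err) // [2001:db8::1] <nil>
}
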
+
+// NDPDNSSearchList is the NDP DNS Search List option, as defined by
+// RFC 8106 section 5.2.
+type NDPDNSSearchList []byte
+
+// Type implements NDPOption.Type.
+func (o NDPDNSSearchList) Type() NDPOptionIdentifier {
+ return NDPDNSSearchListOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPDNSSearchList) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPDNSSearchList) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPDNSSearchList) String() string {
+ lt := o.Lifetime()
+ domainNames, err := o.DomainNames()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt)
+}
+
+// Lifetime returns the length of time that the DNS search list of domain names
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the domain names should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPDNSSearchList) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:]))
+}
+
+// DomainNames returns a DNS search list of domain names.
+//
+// DomainNames will parse the backing buffer as outlined by RFC 1035 section
+// 3.1 and return a list of strings, with all domain names in lower case.
+func (o NDPDNSSearchList) DomainNames() ([]string, error) {
+ var domainNames []string
+ return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) })
+}
+
+// checkDomainNames iterates over the domain names in an NDP DNS Search List
+// option and returns any error it encounters.
+func (o NDPDNSSearchList) checkDomainNames() error {
+ return o.iterDomainNames(nil)
+}
+
+// iterDomainNames iterates over the domain names in an NDP DNS Search List
+// option and calls a function with each valid domain name.
+func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error {
+ if l := len(o); l < minNDPDNSSearchListBodySize {
+ return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF)
+ }
+
+ var searchList bytes.Reader
+ searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:])
+
+ var scratch [maxDomainNameLength]byte
+ domainName := bytes.NewBuffer(scratch[:])
+
+ // Parse the domain names, as per RFC 1035 section 3.1.
+ for searchList.Len() != 0 {
+ domainName.Reset()
+
+ // Parse a label within a domain name, as per RFC 1035 section 3.1.
+ for {
+ // The first byte is the label length.
+ labelLenByte, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected
+ // once we start parsing a domain name; we expect the buffer to contain
+ // enough bytes for the whole domain name.
+ return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF)
+ }
+ labelLen := int(labelLenByte)
+
+ // A zero-length label implies the end of a domain name.
+ if labelLen == 0 {
+ // If the domain name is empty or we have no callback function, do
+ // nothing further with the current domain name.
+ if domainName.Len() == 0 || fn == nil {
+ break
+ }
+
+ // Ignore the trailing period in the parsed domain name.
+ domainName.Truncate(domainName.Len() - 1)
+ fn(domainName.String())
+ break
+ }
+
+ // The label's length must not exceed the maximum length for a label.
+ if labelLen > maxDomainNameLabelLength {
+ return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody)
+ }
+
+ // The label (and trailing period) must not make the domain name too long.
+ if labelLen+1 > domainName.Cap()-domainName.Len() {
+ return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody)
+ }
+
+ // Copy the label and add a trailing period.
+ for i := 0; i < labelLen; i++ {
+ b, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err))
+ }
+
+ return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF)
+ }
+
+ // As per RFC 1035 section 2.3.1:
+			// 1) the label must only contain ASCII letters, digits and
+			//    hyphens
+			// 2) the first character in a label must be a letter
+			// 3) the last character in a label must be a letter or digit
+
+ if !isLetter(b) {
+ if i == 0 {
+ return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+
+ if b == '-' {
+ if i == labelLen-1 {
+ return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody)
+ }
+ } else if !isDigit(b) {
+ return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+ }
+
+ // If b is an upper case character, make it lower case.
+ if isUpperLetter(b) {
+ b = b - 'A' + 'a'
+ }
+
+ if err := domainName.WriteByte(b); err != nil {
+ panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err))
+ }
+ }
+ if err := domainName.WriteByte('.'); err != nil {
+ panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err))
+ }
+ }
+ }
+
+ return nil
+}
+
+func isLetter(b byte) bool {
+ return b >= 'a' && b <= 'z' || isUpperLetter(b)
+}
+
+func isUpperLetter(b byte) bool {
+ return b >= 'A' && b <= 'Z'
+}
+
+func isDigit(b byte) bool {
+ return b >= '0' && b <= '9'
+}
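As a quick illustration of the RDNSS and DNS Search List accessors above, here is a minimal, hypothetical decoding sketch. The byte values are invented for this example; only the header package types come from this change.

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/tcpip/header"
    )

    func main() {
        // RDNSS option body: 2 reserved bytes, a 4-byte lifetime (60s), then
        // one 16-byte IPv6 address (2001:db8::1).
        rdnss := header.NDPRecursiveDNSServer([]byte{
            0, 0, // Reserved.
            0, 0, 0, 60, // Lifetime = 60 seconds.
            0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        })
        addrs, err := rdnss.Addresses()
        // Prints the lifetime (1m0s), the one parsed address, and a nil error.
        fmt.Println(rdnss.Lifetime(), addrs, err)

        // DNSSL option body: 2 reserved bytes, a 4-byte lifetime, then
        // "abc.de" as length-prefixed labels ending in a zero-length label.
        // The trailing zero padding parses as empty domain names and is
        // skipped by iterDomainNames.
        dnssl := header.NDPDNSSearchList([]byte{
            0, 0, // Reserved.
            0, 0, 0, 60, // Lifetime = 60 seconds.
            3, 'a', 'b', 'c',
            2, 'd', 'e',
            0,                      // Terminating zero-length label.
            0, 0, 0, 0, 0, 0, 0, 0, // Padding.
        })
        names, err := dnssl.DomainNames()
        fmt.Println(dnssl.Lifetime(), names, err) // 1m0s [abc.de] <nil>
    }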
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
new file mode 100644
index 000000000..bf7610863
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -0,0 +1,112 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "time"
+)
+
+// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.2 for more details.
+type NDPRouterAdvert []byte
+
+const (
+ // NDPRAMinimumSize is the minimum size of a valid NDP Router
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPRAMinimumSize = 12
+
+ // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
+ // within an NDPRouterAdvert.
+ ndpRACurrHopLimitOffset = 0
+
+ // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
+ // within an NDPRouterAdvert.
+ ndpRAFlagsOffset = 1
+
+ // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
+ // Configuration flag within the bit-field/flags byte of an
+ // NDPRouterAdvert.
+ ndpRAManagedAddrConfFlagMask = (1 << 7)
+
+ // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
+ // within the bit-field/flags byte of an NDPRouterAdvert.
+ ndpRAOtherConfFlagMask = (1 << 6)
+
+ // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
+ // field within an NDPRouterAdvert.
+ ndpRARouterLifetimeOffset = 2
+
+ // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
+ // field within an NDPRouterAdvert.
+ ndpRAReachableTimeOffset = 4
+
+ // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
+ // field within an NDPRouterAdvert.
+ ndpRARetransTimerOffset = 8
+
+ // ndpRAOptionsOffset is the start of the NDP options in an
+ // NDPRouterAdvert.
+ ndpRAOptionsOffset = 12
+)
+
+// CurrHopLimit returns the value of the Curr Hop Limit field.
+func (b NDPRouterAdvert) CurrHopLimit() uint8 {
+ return b[ndpRACurrHopLimitOffset]
+}
+
+// ManagedAddrConfFlag returns the value of the Managed Address Configuration
+// flag.
+func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0
+}
+
+// OtherConfFlag returns the value of the Other Configuration flag.
+func (b NDPRouterAdvert) OtherConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
+}
+
+// RouterLifetime returns the lifetime associated with the default router. A
+// value of 0 means the source of the Router Advertisement is not a default
+// router and SHOULD NOT appear on the default router list. Note, a value of 0
+// only means that the router should not be used as a default router, it does
+// not apply to other information contained in the Router Advertisement.
+func (b NDPRouterAdvert) RouterLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
+}
+
+// ReachableTime returns the time that a node assumes a neighbor is reachable
+// after having received a reachability confirmation. A value of 0 means
+// that it is unspecified by the source of the Router Advertisement message.
+func (b NDPRouterAdvert) ReachableTime() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
+}
+
+// RetransTimer returns the time between retransmitted Neighbor Solicitation
+// messages. A value of 0 means that it is unspecified by the source of the
+// Router Advertisement message.
+func (b NDPRouterAdvert) RetransTimer() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
+}
+
+// Options returns an NDPOptions of the options body.
+func (b NDPRouterAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpRAOptionsOffset:])
+}
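A small hypothetical sketch of decoding a Router Advertisement body with the accessors above; the byte values are invented for the example:

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/tcpip/header"
    )

    func main() {
        ra := header.NDPRouterAdvert([]byte{
            64,     // Curr Hop Limit.
            1 << 7, // Flags: Managed Address Configuration set, Other clear.
            0x07, 0x08, // Router Lifetime = 1800 seconds.
            0x00, 0x00, 0x75, 0x30, // Reachable Time = 30000 ms.
            0x00, 0x00, 0x03, 0xe8, // Retrans Timer = 1000 ms.
        })
        fmt.Println(ra.CurrHopLimit())        // 64
        fmt.Println(ra.ManagedAddrConfFlag()) // true
        fmt.Println(ra.OtherConfFlag())       // false
        fmt.Println(ra.RouterLifetime())      // 30m0s
        fmt.Println(ra.ReachableTime())       // 30s
        fmt.Println(ra.RetransTimer())        // 1s
    }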
diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go
new file mode 100644
index 000000000..9e67ba95d
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_solicit.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.1 for more details.
+type NDPRouterSolicit []byte
+
+const (
+ // NDPRSMinimumSize is the minimum size of a valid NDP Router
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPRSMinimumSize = 4
+
+ // ndpRSOptionsOffset is the start of the NDP options in an
+ // NDPRouterSolicit.
+ ndpRSOptionsOffset = 4
+)
+
+// Options returns an NDPOptions of the options body.
+func (b NDPRouterSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpRSOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
new file mode 100644
index 000000000..dc4591253
--- /dev/null
+++ b/pkg/tcpip/header/ndp_test.go
@@ -0,0 +1,1521 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "regexp"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
+func TestNDPNeighborSolicit(t *testing.T) {
+ b := []byte{
+ 0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ }
+
+ // Test getting the Target Address.
+ ns := NDPNeighborSolicit(b)
+ addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ if got := ns.TargetAddress(); got != addr {
+ t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
+ }
+
+ // Test updating the Target Address.
+ addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ ns.SetTargetAddress(addr2)
+ if got := ns.TargetAddress(); got != addr2 {
+ t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
+ }
+ // Make sure the address got updated in the backing buffer.
+ if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
+ t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
+ }
+}
+
+// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
+func TestNDPNeighborAdvert(t *testing.T) {
+ b := []byte{
+ 160, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ }
+
+ // Test getting the Target Address.
+ na := NDPNeighborAdvert(b)
+ addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ if got := na.TargetAddress(); got != addr {
+ t.Errorf("got TargetAddress = %s, want %s", got, addr)
+ }
+
+ // Test getting the Router Flag.
+ if got := na.RouterFlag(); !got {
+ t.Errorf("got RouterFlag = false, want = true")
+ }
+
+ // Test getting the Solicited Flag.
+ if got := na.SolicitedFlag(); got {
+ t.Errorf("got SolicitedFlag = true, want = false")
+ }
+
+ // Test getting the Override Flag.
+ if got := na.OverrideFlag(); !got {
+ t.Errorf("got OverrideFlag = false, want = true")
+ }
+
+ // Test updating the Target Address.
+ addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ na.SetTargetAddress(addr2)
+ if got := na.TargetAddress(); got != addr2 {
+ t.Errorf("got TargetAddress = %s, want %s", got, addr2)
+ }
+ // Make sure the address got updated in the backing buffer.
+ if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
+ t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
+ }
+
+ // Test updating the Router Flag.
+ na.SetRouterFlag(false)
+ if got := na.RouterFlag(); got {
+ t.Errorf("got RouterFlag = true, want = false")
+ }
+
+ // Test updating the Solicited Flag.
+ na.SetSolicitedFlag(true)
+ if got := na.SolicitedFlag(); !got {
+ t.Errorf("got SolicitedFlag = false, want = true")
+ }
+
+ // Test updating the Override Flag.
+ na.SetOverrideFlag(false)
+ if got := na.OverrideFlag(); got {
+ t.Errorf("got OverrideFlag = true, want = false")
+ }
+
+ // Make sure flags got updated in the backing buffer.
+ if got := b[ndpNAFlagsOffset]; got != 64 {
+ t.Errorf("got flags byte = %d, want = 64", got)
+ }
+}
+
+func TestNDPRouterAdvert(t *testing.T) {
+ b := []byte{
+ 64, 128, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+
+ ra := NDPRouterAdvert(b)
+
+ if got := ra.CurrHopLimit(); got != 64 {
+ t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
+ }
+
+ if got := ra.ManagedAddrConfFlag(); !got {
+ t.Errorf("got ManagedAddrConfFlag = false, want = true")
+ }
+
+ if got := ra.OtherConfFlag(); got {
+ t.Errorf("got OtherConfFlag = true, want = false")
+ }
+
+ if got, want := ra.RouterLifetime(), time.Second*258; got != want {
+ t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
+ t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
+ t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ }
+}
+
+// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "SLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "SLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ sll := NDPSourceLinkLayerAddressOption(test.buf)
+ if got := sll.EthernetAddress(); got != test.expected {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
+ }
+}
+
+// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing an
+// NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ nil,
+ nil,
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPSourceLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+ sll := next.(NDPSourceLinkLayerAddressOption)
+ if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
+// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPTargetLinkLayerAddressOption.
+func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "TLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "TLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ tll := NDPTargetLinkLayerAddressOption(test.buf)
+ if got := tll.EthernetAddress(); got != test.expected {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
+ }
+}
+
+// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing an
+// NDPTargetLinkLayerAddressOption.
+func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ nil,
+ nil,
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPTargetLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ }
+ tll := next.(NDPTargetLinkLayerAddressOption)
+ if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
+// TestNDPPrefixInformationOption tests the field getters and serialization of an
+// NDPPrefixInformation.
+func TestNDPPrefixInformationOption(t *testing.T) {
+ b := []byte{
+ 43, 127,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 5, 5, 5, 5,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPPrefixInformation(b),
+ }
+ opts.Serialize(serializer)
+ expectedBuf := []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+ if !bytes.Equal(targetBuf, expectedBuf) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ }
+
+ pi := next.(NDPPrefixInformation)
+
+ if got := pi.Type(); got != 3 {
+ t.Errorf("got Type = %d, want = 3", got)
+ }
+
+ if got := pi.Length(); got != 30 {
+ t.Errorf("got Length = %d, want = 30", got)
+ }
+
+ if got := pi.PrefixLength(); got != 43 {
+ t.Errorf("got PrefixLength = %d, want = 43", got)
+ }
+
+ if pi.OnLinkFlag() {
+ t.Error("got OnLinkFlag = true, want = false")
+ }
+
+ if !pi.AutonomousAddressConfigurationFlag() {
+ t.Error("got AutonomousAddressConfigurationFlag = false, want = true")
+ }
+
+ if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want {
+ t.Errorf("got ValidLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want {
+ t.Errorf("got PreferredLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want {
+ t.Errorf("got Prefix = %s, want = %s", got, want)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 25, 3, 0, 0,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPRecursiveDNSServer(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Type(); got != 25 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ want := []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ }
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+	if diff := cmp.Diff(want, addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ addrs []tcpip.Address
+ }{
+ {
+ "Valid1Addr",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ },
+ },
+ {
+ "Valid2Addr",
+ []byte{
+ 25, 5, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ },
+ },
+ {
+ "Valid3Addr",
+ []byte{
+ 25, 7, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Iterator should get our option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+			if diff := cmp.Diff(test.addrs, addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
+// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList.
+func TestNDPDNSSearchListOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ domainNames []string
+ err error
+ }{
+ {
+ name: "Valid1Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "abc",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 5,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 0,
+ 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 5 * time.Second,
+ domainNames: []string{
+ "abc.abcd",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3Label",
+ buf: []byte{
+ 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ },
+ lifetime: 16777216 * time.Second,
+ domainNames: []string{
+ "abc.abcd.e",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Domains",
+ buf: []byte{
+ 0, 0,
+ 1, 2, 3, 4,
+ 3, 'a', 'b', 'c',
+ 0,
+ 2, 'd', 'e',
+ 3, 'x', 'y', 'z',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: 16909060 * time.Second,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3DomainsMixedCase",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ 1, 'J',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ "j",
+ },
+ err: nil,
+ },
+ {
+ name: "ValidDomainAfterNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0, 0, 0, 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid0Domains",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: nil,
+ },
+ {
+ name: "NoTrailingNull",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLength",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLengthWithNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "LabelOfLength63",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk",
+ },
+ err: nil,
+ },
+ {
+ name: "LabelOfLength64",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DomainNameOfLength255",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij",
+ },
+ err: nil,
+ },
+ {
+ name: "DomainNameOfLength256",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '9', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '-', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '-',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '9',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "ab9",
+ },
+ err: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := NDPDNSSearchList(test.buf)
+
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ domainNames, err := opt.DomainNames()
+ if !errors.Is(err, test.err) {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+			if diff := cmp.Diff(test.domainNames, domainNames); diff != "" {
+ t.Errorf("mismatched domain names (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
+ for r := rune(0); r <= 255; r++ {
+ t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) {
+ buf := []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 0 /* will be replaced */, 'c',
+ 0,
+ 0, 0, 0,
+ }
+ buf[8] = uint8(r)
+ opt := NDPDNSSearchList(buf)
+
+ // As per RFC 1035 section 2.3.1, the label must only include ASCII
+ // letters, digits and hyphens (a-z, A-Z, 0-9, -).
+ var expectedErr error
+ re := regexp.MustCompile(`[a-zA-Z0-9-]`)
+ if !re.Match([]byte{byte(r)}) {
+ expectedErr = ErrNDPOptMalformedBody
+ }
+
+ if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) {
+ t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody)
+ }
+ })
+ }
+}
+
+func TestNDPDNSSearchListOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 31, 3, 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPDNSSearchList(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPDNSSearchListOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType)
+ }
+
+ opt, ok := next.(NDPDNSSearchList)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next)
+ }
+ if got := opt.Type(); got != 31 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16777216*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ domainNames, err := opt.DomainNames()
+ if err != nil {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+	if diff := cmp.Diff([]string{"abc.abcd.e"}, domainNames); diff != "" {
+ t.Errorf("domain names mismatch (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+// TestNDPOptionsIterCheck tests that Iter returns an error if the NDPOptions
+// it is called on is malformed.
+func TestNDPOptionsIterCheck(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedErr error
+ }{
+ {
+ name: "ZeroLengthField",
+ buf: []byte{0, 0, 0, 0, 0, 0, 0, 0},
+ expectedErr: ErrNDPOptMalformedHeader,
+ },
+ {
+ name: "ValidSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
+ },
+ {
+ name: "TooSmallSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "ValidTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
+ },
+ {
+ name: "TooSmallTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "ValidPrefixInformation",
+ buf: []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "TooSmallPrefixInformation",
+ buf: []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "InvalidPrefixInformationLength",
+ buf: []byte{
+ 3, 3, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
+ buf: []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
+ // Target Link-Layer Address.
+ 2, 1, 7, 8, 9, 10, 11, 12,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ buf: []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
+ // Target Link-Layer Address.
+ 2, 1, 7, 8, 9, 10, 11, 12,
+
+ // 255 is an unrecognized type. If 255 ends up
+ // being the type for some recognized type,
+ // update 255 to some other unrecognized value.
+ 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "InvalidRecursiveDNSServerCutsOffAddress",
+ buf: []byte{
+ 25, 4, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 0, 1, 2, 3, 4, 5, 6, 7,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "InvalidRecursiveDNSServerInvalidLengthField",
+ buf: []byte{
+ 25, 2, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "RecursiveDNSServerTooSmall",
+ buf: []byte{
+ 25, 1, 0, 0,
+ 0, 0, 0,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "RecursiveDNSServerMulticast",
+ buf: []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "RecursiveDNSServerUnspecified",
+ buf: []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListLargeCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListNonCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListValidSmall",
+ buf: []byte{
+ 31, 2, 0, 0,
+ 0, 0, 0, 0,
+ 6, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListTooSmall",
+ buf: []byte{
+ 31, 1, 0, 0,
+ 0, 0, 0,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+
+ if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr)
+ }
+
+			// test.buf may be malformed, but we chose not to check, so Iter
+			// must not return an error.
+ if _, err := opts.Iter(false); err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ })
+ }
+}
+
+// TestNDPOptionsIter tests that we can iterate over a valid NDPOptions. Note,
+// this test does not actually check any of the option's getters; it simply
+// checks the option Type and Body. We have other tests that check the option
+// field getters given an option body, so we don't duplicate those tests here.
+func TestNDPOptionsIter(t *testing.T) {
+ buf := []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
+ // Target Link-Layer Address.
+ 2, 1, 7, 8, 9, 10, 11, 12,
+
+ // 255 is an unrecognized type. If 255 ends up being the type
+ // for some recognized type, update 255 to some other
+ // unrecognized value. Note, this option should be skipped when
+ // iterating.
+ 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+
+ opts := NDPOptions(buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Test the first (Source Link-Layer) option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Target Link-Layer) option.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Prefix Information) option.
+ // Note, the unrecognized option should be skipped.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go
new file mode 100644
index 000000000..6fe9a336b
--- /dev/null
+++ b/pkg/tcpip/header/ndpoptionidentifier_string.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT.
+
+package header
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[NDPSourceLinkLayerAddressOptionType-1]
+ _ = x[NDPTargetLinkLayerAddressOptionType-2]
+ _ = x[NDPPrefixInformationType-3]
+ _ = x[NDPRecursiveDNSServerOptionType-25]
+}
+
+const (
+ _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType"
+ _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType"
+)
+
+var (
+ _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
+)
+
+func (i NDPOptionIdentifier) String() string {
+ switch {
+ case 1 <= i && i <= 3:
+ i -= 1
+ return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]]
+ case i == 25:
+ return _NDPOptionIdentifier_name_1
+ default:
+ return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+}
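Because NDPOptionIdentifier now implements fmt.Stringer, recognized option types print by name while unknown values fall back to a numeric form. A brief hypothetical demonstration:

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/tcpip/header"
    )

    func main() {
        fmt.Println(header.NDPPrefixInformationType)        // NDPPrefixInformationType
        fmt.Println(header.NDPRecursiveDNSServerOptionType) // NDPRecursiveDNSServerOptionType
        fmt.Println(header.NDPOptionIdentifier(42))         // NDPOptionIdentifier(42)
    }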
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
new file mode 100644
index 000000000..4c6f808e5
--- /dev/null
+++ b/pkg/tcpip/header/tcp.go
@@ -0,0 +1,621 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "github.com/google/btree"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// These constants are the offsets of the respective fields in the TCP header.
+const (
+ TCPSrcPortOffset = 0
+ TCPDstPortOffset = 2
+ TCPSeqNumOffset = 4
+ TCPAckNumOffset = 8
+ TCPDataOffset = 12
+ TCPFlagsOffset = 13
+ TCPWinSizeOffset = 14
+ TCPChecksumOffset = 16
+ TCPUrgentPtrOffset = 18
+)
+
+const (
+	// MaxWndScale is the maximum allowed window scaling, as described in
+ // RFC 1323, section 2.3, page 11.
+ MaxWndScale = 14
+
+ // TCPMaxSACKBlocks is the maximum number of SACK blocks that can
+ // be encoded in a TCP option field.
+ TCPMaxSACKBlocks = 4
+)
+
+// Flags that may be set in a TCP segment.
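+// For example, a SYN-ACK segment carries TCPFlagSyn|TCPFlagAck (0x12).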
+const (
+ TCPFlagFin = 1 << iota
+ TCPFlagSyn
+ TCPFlagRst
+ TCPFlagPsh
+ TCPFlagAck
+ TCPFlagUrg
+)
+
+// Options that may be present in a TCP segment.
+const (
+ TCPOptionEOL = 0
+ TCPOptionNOP = 1
+ TCPOptionMSS = 2
+ TCPOptionWS = 3
+	TCPOptionSACKPermitted = 4
+	TCPOptionSACK          = 5
+	TCPOptionTS            = 8
+)
+
+// Option Lengths.
+const (
+ TCPOptionMSSLength = 4
+ TCPOptionTSLength = 10
+ TCPOptionWSLength = 3
+ TCPOptionSackPermittedLength = 2
+)
+
+// TCPFields contains the fields of a TCP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type TCPFields struct {
+ // SrcPort is the "source port" field of a TCP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a TCP packet.
+ DstPort uint16
+
+ // SeqNum is the "sequence number" field of a TCP packet.
+ SeqNum uint32
+
+ // AckNum is the "acknowledgement number" field of a TCP packet.
+ AckNum uint32
+
+ // DataOffset is the "data offset" field of a TCP packet. It is the length of
+ // the TCP header in bytes.
+ DataOffset uint8
+
+ // Flags is the "flags" field of a TCP packet.
+ Flags uint8
+
+ // WindowSize is the "window size" field of a TCP packet.
+ WindowSize uint16
+
+ // Checksum is the "checksum" field of a TCP packet.
+ Checksum uint16
+
+ // UrgentPointer is the "urgent pointer" field of a TCP packet.
+ UrgentPointer uint16
+}
+
+// TCPSynOptions is used to return the parsed TCP options in a SYN
+// segment.
+type TCPSynOptions struct {
+ // MSS is the maximum segment size provided by the peer in the SYN.
+ MSS uint16
+
+ // WS is the window scale option provided by the peer in the SYN.
+ //
+ // Set to -1 if no window scale option was provided.
+ WS int
+
+	// TS is true if the timestamp option was provided in the SYN/SYN-ACK.
+ TS bool
+
+ // TSVal is the value of the TSVal field in the timestamp option.
+ TSVal uint32
+
+ // TSEcr is the value of the TSEcr field in the timestamp option.
+ TSEcr uint32
+
+ // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
+ SACKPermitted bool
+}
+
+// SACKBlock represents a single contiguous SACK block.
+//
+// +stateify savable
+type SACKBlock struct {
+ // Start indicates the lowest sequence number in the block.
+ Start seqnum.Value
+
+ // End indicates the sequence number immediately following the last
+ // sequence number of this block.
+ End seqnum.Value
+}
+
+// Less returns true if r.Start < b.Start.
+func (r SACKBlock) Less(b btree.Item) bool {
+ return r.Start.LessThan(b.(SACKBlock).Start)
+}
+
+// Contains returns true if b is completely contained in r.
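+// For example, SACKBlock{100, 200}.Contains(SACKBlock{120, 150}) is true.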
+func (r SACKBlock) Contains(b SACKBlock) bool {
+ return r.Start.LessThanEq(b.Start) && b.End.LessThanEq(r.End)
+}
+
+// TCPOptions are used to parse and cache the TCP segment options for a
+// non-SYN/SYN-ACK segment.
+//
+// +stateify savable
+type TCPOptions struct {
+ // TS is true if the TimeStamp option is enabled.
+ TS bool
+
+ // TSVal is the value in the TSVal field of the segment.
+ TSVal uint32
+
+ // TSEcr is the value in the TSEcr field of the segment.
+ TSEcr uint32
+
+ // SACKBlocks are the SACK blocks specified in the segment.
+ SACKBlocks []SACKBlock
+}
+
+// TCP represents a TCP header stored in a byte array.
+type TCP []byte
+
+const (
+ // TCPMinimumSize is the minimum size of a valid TCP packet.
+ TCPMinimumSize = 20
+
+ // TCPOptionsMaximumSize is the maximum size of TCP options.
+ TCPOptionsMaximumSize = 40
+
+ // TCPHeaderMaximumSize is the maximum header size of a TCP packet.
+ TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize
+
+ // TCPProtocolNumber is TCP's transport protocol number.
+ TCPProtocolNumber tcpip.TransportProtocolNumber = 6
+
+ // TCPMinimumMSS is the minimum acceptable value for MSS. This is the
+// same as the value TCP_MIN_MSS defined in net/tcp.h.
+ TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize
+
+ // TCPMaximumMSS is the maximum acceptable value for MSS.
+ TCPMaximumMSS = 0xffff
+
+ // TCPDefaultMSS is the MSS value that should be used if an MSS option
+ // is not received from the peer. It's also the value returned by
+ // TCP_MAXSEG option for a socket in an unconnected state.
+ //
+ // Per RFC 1122, page 85: "If an MSS option is not received at
+ // connection setup, TCP MUST assume a default send MSS of 536."
+ TCPDefaultMSS = 536
+)
+
+// SourcePort returns the "source port" field of the tcp header.
+func (b TCP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
+}
+
+// DestinationPort returns the "destination port" field of the tcp header.
+func (b TCP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
+}
+
+// SequenceNumber returns the "sequence number" field of the tcp header.
+func (b TCP) SequenceNumber() uint32 {
+ return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
+}
+
+// AckNumber returns the "ack number" field of the tcp header.
+func (b TCP) AckNumber() uint32 {
+ return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
+}
+
+// DataOffset returns the "data offset" field of the tcp header. The return
+// value is the length of the TCP header in bytes.
+func (b TCP) DataOffset() uint8 {
+ return (b[TCPDataOffset] >> 4) * 4
+}
+
+// Payload returns the data in the tcp packet.
+func (b TCP) Payload() []byte {
+ return b[b.DataOffset():]
+}
+
+// Flags returns the flags field of the tcp header.
+func (b TCP) Flags() uint8 {
+ return b[TCPFlagsOffset]
+}
+
+// WindowSize returns the "window size" field of the tcp header.
+func (b TCP) WindowSize() uint16 {
+ return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
+}
+
+// Checksum returns the "checksum" field of the tcp header.
+func (b TCP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
+}
+
+// UrgentPointer returns the "urgent pointer" field of the tcp header.
+func (b TCP) UrgentPointer() uint16 {
+ return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
+}
+
+// SetSourcePort sets the "source port" field of the tcp header.
+func (b TCP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the tcp header.
+func (b TCP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
+}
+
+// SetChecksum sets the checksum field of the tcp header.
+func (b TCP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
+}
+
+// SetDataOffset sets the data offset field of the tcp header. headerLen should
+// be the length of the TCP header in bytes.
+func (b TCP) SetDataOffset(headerLen uint8) {
+ b[TCPDataOffset] = (headerLen / 4) << 4
+}
+
+// SetSequenceNumber sets the sequence number field of the tcp header.
+func (b TCP) SetSequenceNumber(seqNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
+}
+
+// SetAckNumber sets the ack number field of the tcp header.
+func (b TCP) SetAckNumber(ackNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
+}
+
+// SetFlags sets the flags field of the tcp header.
+func (b TCP) SetFlags(flags uint8) {
+ b[TCPFlagsOffset] = flags
+}
+
+// SetWindowSize sets the window size field of the tcp header.
+func (b TCP) SetWindowSize(rcvwnd uint16) {
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// SetUrgentPoiner sets the urgent pointer field of the tcp header.
+func (b TCP) SetUrgentPoiner(urgentPointer uint16) {
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
+}
+
+// CalculateChecksum calculates the checksum of the tcp segment.
+// partialChecksum is the checksum of the network-layer pseudo-header
+// and the checksum of the segment data.
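+//
+// A sketch of computing the full checksum (assuming this package's
+// PseudoHeaderChecksum helper, a zeroed checksum field, and caller-supplied
+// srcAddr, dstAddr and totalLen):
+//
+//	xsum := PseudoHeaderChecksum(TCPProtocolNumber, srcAddr, dstAddr, totalLen)
+//	xsum = Checksum(b.Payload(), xsum)
+//	b.SetChecksum(^b.CalculateChecksum(xsum))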
+func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
+ // Calculate the rest of the checksum.
+ return Checksum(b[:b.DataOffset()], partialChecksum)
+}
+
+// Options returns a slice that holds the unparsed TCP options in the segment.
+func (b TCP) Options() []byte {
+ return b[TCPMinimumSize:b.DataOffset()]
+}
+
+// ParsedOptions parses the TCP options in the segment and returns them in a
+// TCPOptions structure. NOTE: Invoking this function repeatedly is expensive
+// as it reparses the options on each invocation.
+func (b TCP) ParsedOptions() TCPOptions {
+ return ParseTCPOptions(b.Options())
+}
+
+func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
+ b[TCPFlagsOffset] = flags
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// Encode encodes all the fields of the tcp header.
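+//
+// A minimal usage sketch (illustrative values; the checksum is computed and
+// filled in separately via SetChecksum):
+//
+//	b := make(TCP, TCPMinimumSize)
+//	b.Encode(&TCPFields{
+//		SrcPort:    12345,
+//		DstPort:    80,
+//		SeqNum:     1,
+//		AckNum:     0,
+//		DataOffset: TCPMinimumSize,
+//		Flags:      TCPFlagSyn,
+//		WindowSize: 65535,
+//	})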
+func (b TCP) Encode(t *TCPFields) {
+ b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
+ binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort)
+ binary.BigEndian.PutUint16(b[TCPDstPortOffset:], t.DstPort)
+ b[TCPDataOffset] = (t.DataOffset / 4) << 4
+ binary.BigEndian.PutUint16(b[TCPChecksumOffset:], t.Checksum)
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer)
+}
+
+// EncodePartial updates a subset of the fields of the tcp header. It is useful
+// in cases when similar segments are produced.
+func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
+ // Add the total length and "flags" field contributions to the checksum.
+ // We don't use the flags field directly from the header because it's a
+ // one-byte field with an odd offset, so it would be accounted for
+ // incorrectly by the Checksum routine.
+ tmp := make([]byte, 4)
+ binary.BigEndian.PutUint16(tmp, length)
+ binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Encode the passed-in fields.
+ b.encodeSubset(seqnum, acknum, flags, rcvwnd)
+
+ // Add the contributions of the passed-in fields to the checksum.
+ checksum = Checksum(b[TCPSeqNumOffset:TCPSeqNumOffset+8], checksum)
+ checksum = Checksum(b[TCPWinSizeOffset:TCPWinSizeOffset+2], checksum)
+
+ // Encode the checksum.
+ b.SetChecksum(^checksum)
+}
+
+// ParseSynOptions parses the options received in a SYN segment and returns the
+// relevant ones. opts should point to the option part of the TCP Header.
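+//
+// For example, given a received SYN or SYN-ACK header b, a sketch of a call
+// site (isAck is true only for SYN-ACKs):
+//
+//	synOpts := ParseSynOptions(b.Options(), b.Flags()&TCPFlagAck != 0)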
+func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
+ limit := len(opts)
+
+ synOpts := TCPSynOptions{
+ // Per RFC 1122, page 85: "If an MSS option is not received at
+ // connection setup, TCP MUST assume a default send MSS of 536."
+ MSS: TCPDefaultMSS,
+ // If no window scale option is specified, WS in options is
+ // returned as -1; this is because the absence of the option
+ // indicates that we cannot use window scaling on the
+ // receive end either.
+ WS: -1,
+ }
+
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case TCPOptionEOL:
+ i = limit
+ case TCPOptionNOP:
+ i++
+ case TCPOptionMSS:
+ if i+4 > limit || opts[i+1] != 4 {
+ return synOpts
+ }
+ mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+ if mss == 0 {
+ return synOpts
+ }
+ synOpts.MSS = mss
+ i += 4
+
+ case TCPOptionWS:
+ if i+3 > limit || opts[i+1] != 3 {
+ return synOpts
+ }
+ ws := int(opts[i+2])
+ if ws > MaxWndScale {
+ ws = MaxWndScale
+ }
+ synOpts.WS = ws
+ i += 3
+
+ case TCPOptionTS:
+ if i+10 > limit || opts[i+1] != 10 {
+ return synOpts
+ }
+ synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
+ if isAck {
+ // If the segment is a SYN-ACK then store the Timestamp Echo Reply
+ // in the segment.
+ synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
+ }
+ synOpts.TS = true
+ i += 10
+ case TCPOptionSACKPermitted:
+ if i+2 > limit || opts[i+1] != 2 {
+ return synOpts
+ }
+ synOpts.SACKPermitted = true
+ i += 2
+
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return synOpts
+ }
+ l := int(opts[i+1])
+ // If the length is invalid or i+l overruns the options
+ // length, stop parsing and return the options seen so far.
+ if l < 2 || i+l > limit {
+ return synOpts
+ }
+ i += l
+ }
+ }
+
+ return synOpts
+}
+
+// ParseTCPOptions extracts and stores all known options in the provided byte
+// slice in a TCPOptions structure.
+func ParseTCPOptions(b []byte) TCPOptions {
+ opts := TCPOptions{}
+ limit := len(b)
+ for i := 0; i < limit; {
+ switch b[i] {
+ case TCPOptionEOL:
+ i = limit
+ case TCPOptionNOP:
+ i++
+ case TCPOptionTS:
+ if i+10 > limit || (b[i+1] != 10) {
+ return opts
+ }
+ opts.TS = true
+ opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
+ opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
+ i += 10
+ case TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ sackOptionLen := int(b[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ numBlocks := (sackOptionLen - 2) / 8
+ opts.SACKBlocks = []SACKBlock{}
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(b[i+2+j*8:])
+ end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
+ opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return opts
+ }
+ l := int(b[i+1])
+ // If the length is invalid or i+l overruns the options
+ // length, stop parsing and return the options seen so far.
+ if l < 2 || i+l > limit {
+ return opts
+ }
+ i += l
+ }
+ }
+ return opts
+}
+
+// EncodeMSSOption encodes the MSS TCP option with the provided MSS value in
+// the supplied buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeMSSOption(mss uint32, b []byte) int {
+ if len(b) < TCPOptionMSSLength {
+ return 0
+ }
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss)
+ return TCPOptionMSSLength
+}
+
+// EncodeWSOption encodes the WS TCP option with the WS value in the
+// provided buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeWSOption(ws int, b []byte) int {
+ if len(b) < TCPOptionWSLength {
+ return 0
+ }
+ b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws)
+ return int(b[1])
+}
+
+// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
+// option into the provided buffer. If the buffer is smaller than expected it
+// just returns without encoding anything. It returns the number of bytes
+// written to the provided buffer.
+func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
+ if len(b) < TCPOptionTSLength {
+ return 0
+ }
+ b[0], b[1] = TCPOptionTS, TCPOptionTSLength
+ binary.BigEndian.PutUint32(b[2:], tsVal)
+ binary.BigEndian.PutUint32(b[6:], tsEcr)
+ return int(b[1])
+}
+
+// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
+// buffer. If the buffer is smaller than required it just returns without
+// encoding anything. It returns the number of bytes written to the provided
+// buffer.
+func EncodeSACKPermittedOption(b []byte) int {
+ if len(b) < TCPOptionSackPermittedLength {
+ return 0
+ }
+
+ b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength
+ return int(b[1])
+}
+
+// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
+// in the provided slice. It tries to fit in as many blocks as possible based on
+// number of bytes available in the provided buffer. It returns the number of
+// bytes written to the provided buffer.
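+//
+// For example, a full 40-byte options buffer fits (40-2)/8 = 4 blocks, which
+// matches TCPMaxSACKBlocks; any blocks beyond that limit are dropped.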
+func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
+ if len(sackBlocks) == 0 {
+ return 0
+ }
+ l := len(sackBlocks)
+ if l > TCPMaxSACKBlocks {
+ l = TCPMaxSACKBlocks
+ }
+ if ll := (len(b) - 2) / 8; ll < l {
+ l = ll
+ }
+ if l == 0 {
+ // There is not enough space in the provided buffer to add
+ // any SACK blocks.
+ return 0
+ }
+ b[0] = TCPOptionSACK
+ b[1] = byte(l*8 + 2)
+ for i := 0; i < l; i++ {
+ binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
+ binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
+ }
+ return int(b[1])
+}
+
+// EncodeNOP adds an explicit NOP to the option list.
+func EncodeNOP(b []byte) int {
+ if len(b) == 0 {
+ return 0
+ }
+ b[0] = TCPOptionNOP
+ return 1
+}
+
+// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
+// the option buffer. It adds padding bytes after the offset specified and
+// returns the number of padding bytes added. The passed in options slice
+// must have space for the padding bytes.
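+//
+// A sketch of building an options buffer and padding it to a multiple of 4
+// bytes (illustrative values):
+//
+//	opts := make([]byte, TCPOptionsMaximumSize)
+//	off := EncodeMSSOption(1460, opts)
+//	off += EncodeWSOption(7, opts[off:])
+//	off += AddTCPOptionPadding(opts, off) // off is now 8 (4 + 3 + 1 NOP)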
+func AddTCPOptionPadding(options []byte, offset int) int {
+ paddingToAdd := -offset & 3
+ // Now add any padding bytes that might be required to quad align the
+ // options.
+ for i := offset; i < offset+paddingToAdd; i++ {
+ options[i] = TCPOptionNOP
+ }
+ return paddingToAdd
+}
+
+// Acceptable checks if a segment that starts at segSeq and has length segLen is
+// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
+// before rcvAcc, according to the tables on pages 26 and 69 of RFC 793.
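+//
+// For example, with rcvNxt=100 and rcvAcc=200: a zero-length segment is
+// acceptable for any segSeq in [100, 201), and a segment with segSeq=90 and
+// segLen=20 is acceptable because it overlaps the window (90+20 > 100 and
+// 90 <= 200).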
+func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
+ if rcvNxt == rcvAcc {
+ return segLen == 0 && segSeq == rcvNxt
+ }
+ if segLen == 0 {
+ // The window's right edge is extended by 1 (rcvAcc.Add(1)) because
+ // that is Linux's behavior, despite the RFC.
+ return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
+ }
+ // Page 70 of RFC 793 allows packets that can be made "acceptable" by
+ // trimming the payload, so we'll accept any payload that overlaps the
+ // receive window. segSeq < rcvAcc would be more correct per the RFC;
+ // however, Linux uses segSeq <= rcvAcc, and we keep the same behavior
+ // as Linux.
+ return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc)
+}
diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go
new file mode 100644
index 000000000..72563837b
--- /dev/null
+++ b/pkg/tcpip/header/tcp_test.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestEncodeSACKBlocks(t *testing.T) {
+ testCases := []struct {
+ sackBlocks []header.SACKBlock
+ want []header.SACKBlock
+ bufSize int
+ }{
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
+ 40,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}},
+ 30,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}},
+ 20,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}},
+ 10,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ nil,
+ 8,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
+ 60,
+ },
+ }
+ for _, tc := range testCases {
+ b := make([]byte, tc.bufSize)
+ t.Logf("testing: %v", tc)
+ header.EncodeSACKBlocks(tc.sackBlocks, b)
+ opts := header.ParseTCPOptions(b)
+ if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want)
+ }
+ }
+}
+
+func TestTCPParseOptions(t *testing.T) {
+ type tsOption struct {
+ tsVal uint32
+ tsEcr uint32
+ }
+
+ generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte {
+ l := 0
+ if tsOpt != nil {
+ l += 10
+ }
+ if len(sackBlocks) != 0 {
+ l += len(sackBlocks)*8 + 2
+ }
+ b := make([]byte, l)
+ offset := 0
+ if tsOpt != nil {
+ offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b)
+ }
+ header.EncodeSACKBlocks(sackBlocks, b[offset:])
+ return b
+ }
+
+ testCases := []struct {
+ b []byte
+ want header.TCPOptions
+ }{
+ // Trivial cases.
+ {nil, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
+
+ // Test timestamp parsing.
+ {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
+
+ // Test malformed timestamp option.
+ {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
+
+ // Test SACKBlock parsing.
+ {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}},
+ {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}},
+
+ // Test malformed SACK option.
+ {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}},
+
+ // Test Timestamp + SACK block parsing.
+ {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}},
+ {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}},
+ {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}},
+
+ // Test valid timestamp + malformed SACK block parsing.
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}},
+ {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}},
+ {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
+ }
+ for _, tc := range testCases {
+ if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
new file mode 100644
index 000000000..9339d637f
--- /dev/null
+++ b/pkg/tcpip/header/udp.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ udpSrcPort = 0
+ udpDstPort = 2
+ udpLength = 4
+ udpChecksum = 6
+)
+
+const (
+ // UDPMaximumPacketSize is the largest possible UDP packet.
+ UDPMaximumPacketSize = 0xffff
+)
+
+// UDPFields contains the fields of a UDP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type UDPFields struct {
+ // SrcPort is the "source port" field of a UDP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a UDP packet.
+ DstPort uint16
+
+ // Length is the "length" field of a UDP packet.
+ Length uint16
+
+ // Checksum is the "checksum" field of a UDP packet.
+ Checksum uint16
+}
+
+// UDP represents a UDP header stored in a byte array.
+type UDP []byte
+
+const (
+ // UDPMinimumSize is the minimum size of a valid UDP packet.
+ UDPMinimumSize = 8
+
+ // UDPProtocolNumber is UDP's transport protocol number.
+ UDPProtocolNumber tcpip.TransportProtocolNumber = 17
+)
+
+// SourcePort returns the "source port" field of the udp header.
+func (b UDP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[udpSrcPort:])
+}
+
+// DestinationPort returns the "destination port" field of the udp header.
+func (b UDP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[udpDstPort:])
+}
+
+// Length returns the "length" field of the udp header.
+func (b UDP) Length() uint16 {
+ return binary.BigEndian.Uint16(b[udpLength:])
+}
+
+// Payload returns the data contained in the UDP datagram.
+func (b UDP) Payload() []byte {
+ return b[UDPMinimumSize:]
+}
+
+// Checksum returns the "checksum" field of the udp header.
+func (b UDP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[udpChecksum:])
+}
+
+// SetSourcePort sets the "source port" field of the udp header.
+func (b UDP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the udp header.
+func (b UDP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpDstPort:], port)
+}
+
+// SetChecksum sets the "checksum" field of the udp header.
+func (b UDP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
+}
+
+// SetLength sets the "length" field of the udp header.
+func (b UDP) SetLength(length uint16) {
+ binary.BigEndian.PutUint16(b[udpLength:], length)
+}
+
+// CalculateChecksum calculates the checksum of the udp packet, given the
+// checksum of the network-layer pseudo-header and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
+ // Calculate the rest of the checksum.
+ return Checksum(b[:UDPMinimumSize], partialChecksum)
+}
+
+// Encode encodes all the fields of the udp header.
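+//
+// A minimal sketch of building a UDP header (assuming this package's
+// PseudoHeaderChecksum helper and caller-supplied srcAddr, dstAddr and
+// payload):
+//
+//	u := make(UDP, UDPMinimumSize)
+//	u.Encode(&UDPFields{
+//		SrcPort: 12345,
+//		DstPort: 53,
+//		Length:  uint16(UDPMinimumSize + len(payload)),
+//	})
+//	xsum := PseudoHeaderChecksum(UDPProtocolNumber, srcAddr, dstAddr, u.Length())
+//	xsum = Checksum(payload, xsum)
+//	u.SetChecksum(^u.CalculateChecksum(xsum))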
+func (b UDP) Encode(u *UDPFields) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
+ binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
+ binary.BigEndian.PutUint16(b[udpLength:], u.Length)
+ binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
new file mode 100644
index 000000000..b8b93e78e
--- /dev/null
+++ b/pkg/tcpip/link/channel/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "channel",
+ srcs = ["channel.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
new file mode 100644
index 000000000..20b183da0
--- /dev/null
+++ b/pkg/tcpip/link/channel/channel.go
@@ -0,0 +1,298 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package channel provides the implementation of channel-based data-link layer
+// endpoints. Such endpoints allow injection of inbound packets and store
+// outbound packets in a channel.
+package channel
+
+import (
+ "context"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// PacketInfo holds all the information about an outbound packet.
+type PacketInfo struct {
+ Pkt *stack.PacketBuffer
+ Proto tcpip.NetworkProtocolNumber
+ GSO *stack.GSO
+ Route stack.Route
+}
+
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // c is the outbound packet channel.
+ c chan PacketInfo
+ // mu protects fields below.
+ mu sync.RWMutex
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ select {
+ case p := <-q.c:
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ select {
+ case pkt := <-q.c:
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+ select {
+ case q.c <- p:
+ wrote = true
+ default:
+ }
+ q.mu.Lock()
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ return len(q.c)
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we read the slice outside of the lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
+}
+
+// Endpoint is a link-layer endpoint that stores outbound packets in a channel
+// and allows injection of inbound packets.
+type Endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+ LinkEPCapabilities stack.LinkEndpointCapabilities
+
+ // Outbound packet queue.
+ q *queue
+}
+
+// New creates a new channel endpoint.
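+//
+// A typical test-harness sketch (illustrative; s, proto and pkt are supplied
+// by the caller):
+//
+//	e := New(16 /* size */, 1500 /* mtu */, "" /* linkAddr */)
+//	s.CreateNIC(1, e)           // hand the endpoint to an existing stack
+//	e.InjectInbound(proto, pkt) // feed a packet into the stack
+//	p, ok := e.Read()           // inspect what the stack sent out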
+func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
+ return &Endpoint{
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
+ mtu: mtu,
+ linkAddr: linkAddr,
+ }
+}
+
+// Close closes e. Further writes to the outbound queue will panic. Reads
+// continue to succeed until all queued packets are read.
+func (e *Endpoint) Close() {
+ e.q.Close()
+}
+
+// Read does a non-blocking read of one packet from the outbound packet queue.
+func (e *Endpoint) Read() (PacketInfo, bool) {
+ return e.q.Read()
+}
+
+// ReadContext does a blocking read of one packet from the outbound packet
+// queue. It can be cancelled by ctx, in which case it returns false.
+func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ return e.q.ReadContext(ctx)
+}
+
+// Drain removes all outbound packets from the channel and counts them.
+func (e *Endpoint) Drain() int {
+ c := 0
+ for {
+ if _, ok := e.Read(); !ok {
+ return c
+ }
+ c++
+ }
+}
+
+// NumQueued returns the number of packets queued for outbound transmission.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
+// InjectInbound injects an inbound packet.
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
+}
+
+// InjectLinkAddr injects an inbound packet with a remote link address.
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *Endpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.LinkEPCapabilities
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (*Endpoint) GSOMaxSize() uint32 {
+ return 1 << 15
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header. Since
+// this endpoint has no link-layer header, it returns 0.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// WritePacket stores outbound packets into the channel.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
+ p := PacketInfo{
+ Pkt: pkt,
+ Proto: protocol,
+ GSO: gso,
+ Route: route,
+ }
+
+ e.q.Write(p)
+
+ return nil
+}
+
+// WritePackets stores outbound packets into the channel.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ p := PacketInfo{
+ Pkt: pkt,
+ Proto: protocol,
+ GSO: gso,
+ Route: route,
+ }
+
+ if !e.q.Write(p) {
+ break
+ }
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ p := PacketInfo{
+ Pkt: &stack.PacketBuffer{Data: vv},
+ Proto: 0,
+ GSO: nil,
+ }
+
+ e.q.Write(p)
+
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving events about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
new file mode 100644
index 000000000..aa6db9aea
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fdbased",
+ srcs = [
+ "endpoint.go",
+ "endpoint_unsafe.go",
+ "mmap.go",
+ "mmap_stub.go",
+ "mmap_unsafe.go",
+ "packet_dispatchers.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/stack",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "fdbased_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ library = ":fdbased",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
new file mode 100644
index 000000000..f34082e1a
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -0,0 +1,657 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package fdbased provides the implementation of data-link layer endpoints
+// backed by boundary-preserving file descriptors (e.g., TUN devices,
+// seqpacket/datagram sockets).
+//
+// FD based endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+//
+// FD based endpoints can use more than one file descriptor to read incoming
+// packets. If more than one FD is specified and the underlying FDs are
+// AF_PACKET sockets, then the endpoint will enable FANOUT mode on the sockets
+// so that the host kernel consistently hashes packets to the sockets. This
+// ensures that packets for the same TCP stream are not reordered.
+//
+// Similarly, if more than one FD is specified and the underlying FDs are not
+// AF_PACKET sockets, then it's the caller's responsibility to ensure that all
+// inbound packets on the descriptors are consistently 5-tuple hashed to one
+// of the descriptors to prevent TCP reordering.
+//
+// Since netstack today does not compute 5-tuple hashes for outgoing packets,
+// we only use the first FD to write outbound packets. Once 5-tuple hashes for
+// all outbound packets are available, we will make use of all underlying FDs
+// to write outbound packets.
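+//
+// A minimal construction sketch (illustrative values):
+//
+//	ep, err := fdbased.New(&fdbased.Options{
+//		FDs:            []int{fd},
+//		MTU:            1500,
+//		EthernetHeader: true,
+//		Address:        linkAddr,
+//	})
+//	if err != nil {
+//		// handle error
+//	}
+//	// ep can then be passed to Stack.CreateNIC.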
+package fdbased
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// linkDispatcher reads packets from the link FD and dispatches them to the
+// NetworkDispatcher.
+type linkDispatcher interface {
+ dispatch() (bool, *tcpip.Error)
+}
+
+// PacketDispatchMode are the various supported methods of receiving and
+// dispatching packets from the underlying FD.
+type PacketDispatchMode int
+
+const (
+ // Readv is the default dispatch mode and is the least performant of the
+ // dispatch options but the one that is supported by all underlying FD
+ // types.
+ Readv PacketDispatchMode = iota
+ // RecvMMsg enables use of recvmmsg() syscall instead of readv() to
+ // read inbound packets. This reduces the number of syscalls needed to
+ // process packets.
+ //
+ // NOTE: recvmmsg() is only supported for sockets, so if the underlying
+ // FD is not a socket then the code will still fall back to the readv()
+ // path.
+ RecvMMsg
+ // PacketMMap enables use of PACKET_RX_RING to receive packets from
+ // the NIC. PacketMMap requires that the underlying FD be an AF_PACKET
+ // socket. The primary use-case for this is runsc, which uses an
+ // AF_PACKET FD to receive packets from the veth device.
+ PacketMMap
+)
+
+func (p PacketDispatchMode) String() string {
+ switch p {
+ case Readv:
+ return "Readv"
+ case RecvMMsg:
+ return "RecvMMsg"
+ case PacketMMap:
+ return "PacketMMap"
+ default:
+ return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
+ }
+}
+
+type endpoint struct {
+ // fds is the set of file descriptors each identifying one inbound/outbound
+ // channel. The endpoint will dispatch from all inbound channels as well as
+ // hash outbound packets to specific channels based on the packet hash.
+ fds []int
+
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ mtu uint32
+
+ // hdrSize specifies the link-layer header size. If set to 0, no header
+ // is added/removed; otherwise an ethernet header is used.
+ hdrSize int
+
+ // addr is the address of the endpoint.
+ addr tcpip.LinkAddress
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // closed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ closed func(*tcpip.Error)
+
+ inboundDispatchers []linkDispatcher
+ dispatcher stack.NetworkDispatcher
+
+ // packetDispatchMode controls the packet dispatcher used by this
+ // endpoint.
+ packetDispatchMode PacketDispatchMode
+
+ // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
+ // disabled.
+ gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
+}
+
+// Options specifies the details about the fd-based endpoint to be created.
+type Options struct {
+ // FDs is a set of FDs used to read/write packets.
+ FDs []int
+
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // EthernetHeader if true, indicates that the endpoint should read/write
+ // ethernet frames instead of IP packets.
+ EthernetHeader bool
+
+ // ClosedFunc is a function to be called when an endpoint's peer (if
+ // any) closes its end of the communication pipe.
+ ClosedFunc func(*tcpip.Error)
+
+ // Address is the link address for this endpoint. Only used if
+ // EthernetHeader is true.
+ Address tcpip.LinkAddress
+
+ // SaveRestore if true, indicates that this NIC capability set should
+ // include CapabilitySaveRestore.
+ SaveRestore bool
+
+ // DisconnectOk if true, indicates that this NIC capability set should
+ // include CapabilityDisconnectOk.
+ DisconnectOk bool
+
+ // GSOMaxSize is the maximum GSO packet size. It is zero if GSO is
+ // disabled.
+ GSOMaxSize uint32
+
+ // SoftwareGSOEnabled indicates whether software GSO is enabled or not.
+ SoftwareGSOEnabled bool
+
+ // PacketDispatchMode specifies the type of inbound dispatcher to be
+ // used for this endpoint.
+ PacketDispatchMode PacketDispatchMode
+
+ // TXChecksumOffload if true, indicates that this endpoint's capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload if true, indicates that this endpoint's capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
+}
+
+// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT
+// support in the host kernel. This allows us to use multiple FDs to receive
+// from the same underlying NIC. The fanoutID needs to be the same for a given
+// set of FDs that point to the same NIC. Trying to set the PACKET_FANOUT
+// option for an FD with a fanoutID already in use by another FD for a
+// different NIC will return EINVAL.
+var fanoutID = 1
+
+// New creates a new fd-based endpoint.
+//
+// Makes the FDs non-blocking, but does not take ownership of them; they must
+// remain open for the lifetime of the returned endpoint (until after the
+// endpoint has stopped being used and Wait returns).
+func New(opts *Options) (stack.LinkEndpoint, error) {
+ caps := stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ hdrSize := 0
+ if opts.EthernetHeader {
+ hdrSize = header.EthernetMinimumSize
+ caps |= stack.CapabilityResolutionRequired
+ }
+
+ if opts.SaveRestore {
+ caps |= stack.CapabilitySaveRestore
+ }
+
+ if opts.DisconnectOk {
+ caps |= stack.CapabilityDisconnectOk
+ }
+
+ if len(opts.FDs) == 0 {
+ return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ }
+
+ e := &endpoint{
+ fds: opts.FDs,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ }
+
+ // Create per channel dispatchers.
+ for i := 0; i < len(e.fds); i++ {
+ fd := e.fds[i]
+ if err := syscall.SetNonblock(fd, true); err != nil {
+ return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
+ }
+
+ isSocket, err := isSocketFD(fd)
+ if err != nil {
+ return nil, err
+ }
+ if isSocket {
+ if opts.GSOMaxSize != 0 {
+ if opts.SoftwareGSOEnabled {
+ e.caps |= stack.CapabilitySoftwareGSO
+ } else {
+ e.caps |= stack.CapabilityHardwareGSO
+ }
+ e.gsoMaxSize = opts.GSOMaxSize
+ }
+ }
+ inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
+ if err != nil {
+ return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ }
+ e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
+ }
+
+ // Increment fanoutID to ensure that we don't re-use the same fanoutID for
+ // the next endpoint.
+ fanoutID++
+
+ return e, nil
+}
+
+func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
+ // By default use the readv() dispatcher as it works with all kinds of
+ // FDs (tap/tun/unix domain sockets and af_packet).
+ inboundDispatcher, err := newReadVDispatcher(fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", fd, e, err)
+ }
+
+ if isSocket {
+ sa, err := unix.Getsockname(fd)
+ if err != nil {
+ return nil, fmt.Errorf("unix.Getsockname(%d) = %v", fd, err)
+ }
+ switch sa.(type) {
+ case *unix.SockaddrLinklayer:
+ // Enable PACKET_FANOUT mode if the underlying socket is
+ // of type AF_PACKET.
+ const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ fanoutArg := fanoutID | fanoutType<<16
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
+ return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
+ }
+ }
+
+ switch e.packetDispatchMode {
+ case PacketMMap:
+ inboundDispatcher, err = newPacketMMapDispatcher(fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", fd, e, err)
+ }
+ case RecvMMsg:
+ // If the provided FD is a socket then we optimize
+ // packet reads by using recvmmsg() instead of read() to
+ // read packets in a batch.
+ inboundDispatcher, err = newRecvMMsgDispatcher(fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", fd, e, err)
+ }
+ }
+ }
+ return inboundDispatcher, nil
+}
+
+func isSocketFD(fd int) (bool, error) {
+ var stat syscall.Stat_t
+ if err := syscall.Fstat(fd, &stat); err != nil {
+ return false, fmt.Errorf("syscall.Fstat(%v,...) failed: %v", fd, err)
+ }
+ return (stat.Mode & syscall.S_IFSOCK) == syscall.S_IFSOCK, nil
+}
+
+// Attach launches the goroutine that reads packets from the file descriptor and
+// dispatches them via the provided dispatcher.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transport endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ for i := range e.inboundDispatchers {
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
+ }
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *endpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
+// virtioNetHdr is declared in linux/virtio_net.h.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+// These constants are declared in linux/virtio_net.h.
+const (
+ _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
+
+ _VIRTIO_NET_HDR_GSO_TCPV4 = 1
+ _VIRTIO_NET_HDR_GSO_TCPV6 = 4
+)
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ }
+
+ fd := e.fds[pkt.Hash%uint32(len(e.fds))]
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ vnetHdr := virtioNetHdr{}
+ if gso != nil {
+ vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ if gso.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
+ vnetHdr.csumOffset = gso.CsumOffset
+ }
+ if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
+ switch gso.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", gso.Type))
+ }
+ vnetHdr.gsoSize = gso.MSS
+ }
+ }
+
+ vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ }
+
+ if pkt.Data.Size() == 0 {
+ return rawfile.NonBlockingWrite(fd, pkt.Header.View())
+ }
+ if pkt.Header.UsedLength() == 0 {
+ return rawfile.NonBlockingWrite(fd, pkt.Data.ToView())
+ }
+
+ return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil)
+}
+
+func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
+ // Send a batch of packets through batchFD.
+ mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
+ for _, pkt := range batch {
+ var ethHdrBuf []byte
+ iovLen := 0
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+ eth := header.Ethernet(ethHdrBuf)
+ ethHdr := &header.EthernetFields{
+ DstAddr: pkt.EgressRoute.RemoteLinkAddress,
+ Type: pkt.NetworkProtocolNumber,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if pkt.EgressRoute.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ iovLen++
+ }
+
+ vnetHdr := virtioNetHdr{}
+ var vnetHdrBuf []byte
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if pkt.GSOOptions != nil {
+ vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ if pkt.GSOOptions.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
+ }
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ }
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
+ }
+ }
+ vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ iovLen++
+ }
+
+ iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views()))
+ var mmsgHdr rawfile.MMsgHdr
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ iovecIdx := 0
+ if vnetHdrBuf != nil {
+ v := &iovecs[iovecIdx]
+ v.Base = &vnetHdrBuf[0]
+ v.Len = uint64(len(vnetHdrBuf))
+ iovecIdx++
+ }
+ if ethHdrBuf != nil {
+ v := &iovecs[iovecIdx]
+ v.Base = &ethHdrBuf[0]
+ v.Len = uint64(len(ethHdrBuf))
+ iovecIdx++
+ }
+ pktSize := uint64(0)
+ // Encode L3 Header
+ v := &iovecs[iovecIdx]
+ hdr := &pkt.Header
+ hdrView := hdr.View()
+ v.Base = &hdrView[0]
+ v.Len = uint64(len(hdrView))
+ pktSize += v.Len
+ iovecIdx++
+
+ // Now encode the Transport Payload.
+ pktViews := pkt.Data.Views()
+ for i := range pktViews {
+ vec := &iovecs[iovecIdx]
+ iovecIdx++
+ vec.Base = &pktViews[i][0]
+ vec.Len = uint64(len(pktViews[i]))
+ pktSize += vec.Len
+ }
+ mmsgHdr.Msg.Iovlen = uint64(iovecIdx)
+ mmsgHdrs = append(mmsgHdrs, mmsgHdr)
+ }
+
+ packets := 0
+ for len(mmsgHdrs) > 0 {
+ sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
+
+ return packets, nil
+}
+
+// WritePackets writes outbound packets to the underlying file descriptors. If
+// one is not currently writable, the packet is dropped.
+//
+// Being a batch API, each packet in pkts should have the following
+// fields populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Preallocate to avoid repeated reallocation as we append to batch.
+ // batchSz is 47 because when SWGSO is in use then a single 65KB TCP
+ // segment can get split into 46 segments of 1420 bytes and a single 216
+ // byte segment.
+ const batchSz = 47
+ batch := make([]*stack.PacketBuffer, 0, batchSz)
+ batchFD := -1
+ sentPackets := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if len(batch) == 0 {
+ batchFD = e.fds[pkt.Hash%uint32(len(e.fds))]
+ }
+ pktFD := e.fds[pkt.Hash%uint32(len(e.fds))]
+ if sendNow := pktFD != batchFD; !sendNow {
+ batch = append(batch, pkt)
+ continue
+ }
+ n, err := e.sendBatch(batchFD, batch)
+ sentPackets += n
+ if err != nil {
+ return sentPackets, err
+ }
+ batch = batch[:0]
+ batch = append(batch, pkt)
+ batchFD = pktFD
+ }
+
+ if len(batch) != 0 {
+ n, err := e.sendBatch(batchFD, batch)
+ sentPackets += n
+ if err != nil {
+ return sentPackets, err
+ }
+ }
+ return sentPackets, nil
+}
+
+// viewsEqual tests whether vs1 and vs2 refer to the same backing bytes.
+func viewsEqual(vs1, vs2 []buffer.View) bool {
+ return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
+}
+
+// InjectOutbound implements stack.InjectableEndpoint.InjectOutbound.
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fds[0], packet)
+}
+
+// dispatchLoop reads packets from the file descriptor in a loop and dispatches
+// them to the network stack.
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error {
+ for {
+ cont, err := inboundDispatcher.dispatch()
+ if err != nil || !cont {
+ if e.closed != nil {
+ e.closed(err)
+ }
+ return err
+ }
+ }
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ return e.gsoMaxSize
+}
+
+// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
+// to the FD, but does not read from it. All reads come from injected packets.
+type InjectableEndpoint struct {
+ endpoint
+
+ dispatcher stack.NetworkDispatcher
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// InjectInbound injects an inbound packet.
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
+}
+
+// NewInjectable creates a new fd-based InjectableEndpoint.
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint {
+ syscall.SetNonblock(fd, true)
+
+ return &InjectableEndpoint{endpoint: endpoint{
+ fds: []int{fd},
+ mtu: mtu,
+ caps: capabilities,
+ }}
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
new file mode 100644
index 000000000..eaee7e5d7
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -0,0 +1,502 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "bytes"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "syscall"
+ "testing"
+ "time"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ mtu = 1500
+ laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
+ raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
+ proto = 10
+ csumOffset = 48
+ gsoMSS = 500
+)
+
+type packetInfo struct {
+ raddr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ contents *stack.PacketBuffer
+}
+
+type context struct {
+ t *testing.T
+ readFDs []int
+ writeFDs []int
+ ep stack.LinkEndpoint
+ ch chan packetInfo
+ done chan struct{}
+}
+
+func newContext(t *testing.T, opt *Options) *context {
+ firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+ secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+
+ done := make(chan struct{}, 2)
+ opt.ClosedFunc = func(*tcpip.Error) {
+ done <- struct{}{}
+ }
+
+ opt.FDs = []int{firstFDPair[1], secondFDPair[1]}
+ ep, err := New(opt)
+ if err != nil {
+ t.Fatalf("Failed to create FD endpoint: %v", err)
+ }
+
+ c := &context{
+ t: t,
+ readFDs: []int{firstFDPair[0], secondFDPair[0]},
+ writeFDs: opt.FDs,
+ ep: ep,
+ ch: make(chan packetInfo, 100),
+ done: done,
+ }
+
+ ep.Attach(c)
+
+ return c
+}
+
+func (c *context) cleanup() {
+ for _, fd := range c.readFDs {
+ syscall.Close(fd)
+ }
+ <-c.done
+ <-c.done
+ for _, fd := range c.writeFDs {
+ syscall.Close(fd)
+ }
+}
+
+func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ c.ch <- packetInfo{remote, protocol, pkt}
+}
+
+func TestNoEthernetProperties(t *testing.T) {
+ c := newContext(t, &Options{MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v {
+ t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
+ }
+
+ if want, v := uint32(mtu), c.ep.MTU(); want != v {
+ t.Fatalf("MTU() = %v, want %v", v, want)
+ }
+}
+
+func TestEthernetProperties(t *testing.T) {
+ c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
+ t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
+ }
+
+ if want, v := uint32(mtu), c.ep.MTU(); want != v {
+ t.Fatalf("MTU() = %v, want %v", v, want)
+ }
+}
+
+func TestAddress(t *testing.T) {
+ addrs := []tcpip.LinkAddress{"", "abc", "def"}
+ for _, a := range addrs {
+ t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
+ c := newContext(t, &Options{Address: a, MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := a, c.ep.LinkAddress(); want != v {
+ t.Fatalf("LinkAddress() = %v, want %v", v, want)
+ }
+ })
+ }
+}
+
+func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) {
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
+ defer c.cleanup()
+
+ r := &stack.Route{
+ RemoteLinkAddress: raddr,
+ }
+
+ // Build header.
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
+ b := hdr.Prepend(100)
+ for i := range b {
+ b[i] = uint8(rand.Intn(256))
+ }
+
+ // Build payload and write.
+ payload := make(buffer.View, plen)
+ for i := range payload {
+ payload[i] = uint8(rand.Intn(256))
+ }
+ want := append(hdr.View(), payload...)
+ var gso *stack.GSO
+ if gsoMaxSize != 0 {
+ gso = &stack.GSO{
+ Type: stack.GSOTCPv6,
+ NeedsCsum: true,
+ CsumOffset: csumOffset,
+ MSS: gsoMSS,
+ MaxSize: gsoMaxSize,
+ L3HdrLen: header.IPv4MaximumHeaderSize,
+ }
+ }
+ if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ Hash: hash,
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Read from the corresponding FD, then compare with what we wrote.
+ b = make([]byte, mtu)
+ fd := c.readFDs[hash%uint32(len(c.readFDs))]
+ n, err := syscall.Read(fd, b)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ b = b[:n]
+ if gsoMaxSize != 0 {
+ vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0]))
+ if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 {
+ t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM)
+ }
+ csumStart := header.EthernetMinimumSize + gso.L3HdrLen
+ if vnetHdr.csumStart != csumStart {
+ t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart)
+ }
+ if vnetHdr.csumOffset != csumOffset {
+ t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset)
+ }
+ gsoType := uint8(0)
+ if int(gso.MSS) < plen {
+ gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ }
+ if vnetHdr.gsoType != gsoType {
+ t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType)
+ }
+ b = b[virtioNetHdrSize:]
+ }
+ if eth {
+ h := header.Ethernet(b)
+ b = b[header.EthernetMinimumSize:]
+
+ if a := h.SourceAddress(); a != laddr {
+ t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
+ }
+
+ if a := h.DestinationAddress(); a != raddr {
+ t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
+ }
+
+ if et := h.Type(); et != proto {
+ t.Fatalf("Type() = %v, want %v", et, proto)
+ }
+ }
+ if len(b) != len(want) {
+ t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
+ }
+ if !bytes.Equal(b, want) {
+ t.Fatalf("Read returned %x, want %x", b, want)
+ }
+}
+
+func TestWritePacket(t *testing.T) {
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+ gsos := []uint32{0, 32768}
+
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ for _, gso := range gsos {
+ t.Run(
+ fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
+ func(t *testing.T) {
+ testWritePacket(t, plen, eth, gso, 0)
+ },
+ )
+ }
+ }
+ }
+}
+
+func TestHashedWritePacket(t *testing.T) {
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+ gsos := []uint32{0, 32768}
+ hashes := []uint32{0, 1}
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ for _, gso := range gsos {
+ for _, hash := range hashes {
+ t.Run(
+ fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash),
+ func(t *testing.T) {
+ testWritePacket(t, plen, eth, gso, hash)
+ },
+ )
+ }
+ }
+ }
+ }
+}
+
+func TestPreserveSrcAddress(t *testing.T) {
+ baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
+
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true})
+ defer c.cleanup()
+
+ // Set LocalLinkAddress in route to the value of the bridged address.
+ r := &stack.Route{
+ RemoteLinkAddress: raddr,
+ LocalLinkAddress: baddr,
+ }
+
+	// WritePacket panics if the prependable is smaller than the minimum
+	// size of the ethernet header.
+ hdr := buffer.NewPrependable(header.EthernetMinimumSize)
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buffer.VectorisedView{},
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Read from the FD, then compare with what we wrote.
+ b := make([]byte, mtu)
+ n, err := syscall.Read(c.readFDs[0], b)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ b = b[:n]
+ h := header.Ethernet(b)
+
+ if a := h.SourceAddress(); a != baddr {
+ t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
+ }
+}
+
+func TestDeliverPacket(t *testing.T) {
+ lengths := []int{100, 1000}
+ eths := []bool{true, false}
+
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
+ defer c.cleanup()
+
+ // Build packet.
+ b := make([]byte, plen)
+ all := b
+ for i := range b {
+ b[i] = uint8(rand.Intn(256))
+ }
+
+ var hdr header.Ethernet
+ if !eth {
+ // So that it looks like an IPv4 packet.
+ b[0] = 0x40
+ } else {
+ hdr = make(header.Ethernet, header.EthernetMinimumSize)
+ hdr.Encode(&header.EthernetFields{
+ SrcAddr: raddr,
+ DstAddr: laddr,
+ Type: proto,
+ })
+ all = append(hdr, b...)
+ }
+
+ // Write packet via the file descriptor.
+ if _, err := syscall.Write(c.readFDs[0], all); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Receive packet through the endpoint.
+ select {
+ case pi := <-c.ch:
+ want := packetInfo{
+ raddr: raddr,
+ proto: proto,
+ contents: &stack.PacketBuffer{
+ Data: buffer.View(b).ToVectorisedView(),
+ LinkHeader: buffer.View(hdr),
+ },
+ }
+ if !eth {
+ want.proto = header.IPv4ProtocolNumber
+ want.raddr = ""
+ }
+ // want.contents.Data will be a single
+ // view, so make pi do the same for the
+ // DeepEqual check.
+ pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView()
+ if !reflect.DeepEqual(want, pi) {
+ t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for packet")
+ }
+ })
+ }
+ }
+}
+
+func TestBufConfigMaxLength(t *testing.T) {
+ got := 0
+ for _, i := range BufConfig {
+ got += i
+ }
+ want := header.MaxIPPacketSize // maximum TCP packet size
+ if got < want {
+ t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want)
+ }
+}
+
+func TestBufConfigFirst(t *testing.T) {
+	// The stack assumes that the TCP/IP header is entirely contained in the first view.
+ // Therefore, the first view needs to be large enough to contain the maximum TCP/IP
+ // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
+ want := 120
+ got := BufConfig[0]
+ if got < want {
+ t.Errorf("first view has an invalid size: got %d, want >= %d", got, want)
+ }
+}
+
+var capLengthTestCases = []struct {
+ comment string
+ config []int
+ n int
+ wantUsed int
+ wantLengths []int
+}{
+ {
+ comment: "Single slice",
+ config: []int{2},
+ n: 1,
+ wantUsed: 1,
+ wantLengths: []int{1},
+ },
+ {
+ comment: "Multiple slices",
+ config: []int{1, 2},
+ n: 2,
+ wantUsed: 2,
+ wantLengths: []int{1, 1},
+ },
+ {
+ comment: "Entire buffer",
+ config: []int{1, 2},
+ n: 3,
+ wantUsed: 2,
+ wantLengths: []int{1, 2},
+ },
+ {
+ comment: "Entire buffer but not on the last slice",
+ config: []int{1, 2, 3},
+ n: 3,
+ wantUsed: 2,
+ wantLengths: []int{1, 2, 3},
+ },
+}
+
+func TestReadVDispatcherCapLength(t *testing.T) {
+ for _, c := range capLengthTestCases {
+ // fd does not matter for this test.
+ d := readVDispatcher{fd: -1, e: &endpoint{}}
+ d.views = make([]buffer.View, len(c.config))
+ d.iovecs = make([]syscall.Iovec, len(c.config))
+ d.allocateViews(c.config)
+
+ used := d.capViews(c.n, c.config)
+ if used != c.wantUsed {
+ t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
+ }
+ lengths := make([]int, len(d.views))
+ for i, v := range d.views {
+ lengths[i] = len(v)
+ }
+ if !reflect.DeepEqual(lengths, c.wantLengths) {
+ t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
+ }
+ }
+}
+
+func TestRecvMMsgDispatcherCapLength(t *testing.T) {
+ for _, c := range capLengthTestCases {
+ d := recvMMsgDispatcher{
+ fd: -1, // fd does not matter for this test.
+ e: &endpoint{},
+ views: make([][]buffer.View, 1),
+ iovecs: make([][]syscall.Iovec, 1),
+ msgHdrs: make([]rawfile.MMsgHdr, 1),
+ }
+
+		for i := range d.views {
+ d.views[i] = make([]buffer.View, len(c.config))
+ }
+ for i := range d.iovecs {
+ d.iovecs[i] = make([]syscall.Iovec, len(c.config))
+ }
+		for k := range d.msgHdrs {
+			// Assign through the slice index: ranging with a value
+			// copy would discard these writes.
+			d.msgHdrs[k].Msg.Iov = &d.iovecs[k][0]
+			d.msgHdrs[k].Msg.Iovlen = uint64(len(c.config))
+		}
+
+ d.allocateViews(c.config)
+
+ used := d.capViews(0, c.n, c.config)
+ if used != c.wantUsed {
+ t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
+ }
+ lengths := make([]int, len(d.views[0]))
+ for i, v := range d.views[0] {
+ lengths[i] = len(v)
+ }
+ if !reflect.DeepEqual(lengths, c.wantLengths) {
+ t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
+ }
+	}
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
new file mode 100644
index 000000000..df14eaad1
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "unsafe"
+)
+
+const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{}))
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
new file mode 100644
index 000000000..2dfd29aa9
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -0,0 +1,199 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64 linux,arm64
+
+package fdbased
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
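+
+// A sketch of the resulting arithmetic: tpFrameSize = 65536 + 128 = 65664
+// bytes (a multiple of 16, satisfying the frame alignment constraint),
+// tpBlockSize = 65664 * 32 = 2101248 bytes = 513 4-KiB pages (hence page
+// aligned), and tpFrameNR = 2101248 / 65664 = 32 frames in the single block.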
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
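+
+// For example, tPacketAlign(1) = 16, tPacketAlign(16) = 16, and
+// tPacketAlign(17) = 32: values are rounded up to the next multiple of 16.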
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
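+
+// Taken together, the accessors above decode the TPACKET_V1 tpacket_hdr
+// frame layout (on 64-bit): the status word at offset 0 is followed by the
+// len/snaplen words, tpMac() gives the offset of the link-layer (MAC)
+// header within the frame, and Payload() returns the tpSnapLen() bytes
+// starting at that offset.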
+
+// packetMMapDispatcher uses a PACKET_RX_RING to read and dispatch inbound
+// packets. See mmap_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+	// ringBuffer points to the start of the mmapped PACKET_RX_RING
+	// buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
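+
+// The status word above implements a simple ownership handoff: the kernel
+// flips a frame to tpStatusUser once it has filled the frame, and
+// readMMappedPacket hands the frame back by storing tpStatusKernel after
+// copying the payload out, then advances ringOffset to the next frame.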
+
+// dispatch reads one packet from the mmapped ring buffer and dispatches it
+// to the network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ eth header.Ethernet
+ )
+ if d.e.hdrSize > 0 {
+ eth = header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{
+ Data: buffer.View(pkt).ToVectorisedView(),
+ LinkHeader: buffer.View(eth),
+ })
+ return true, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go
new file mode 100644
index 000000000..67be52d67
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for platforms other than linux/amd64 and linux/arm64.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
new file mode 100644
index 000000000..3894185ae
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -0,0 +1,84 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64 linux,arm64
+
+package fdbased
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// tPacketHdrlen is the TPACKET_HDRLEN variable defined in <linux/if_packet.h>.
+var tPacketHdrlen = tPacketAlign(unsafe.Sizeof(tPacketHdr{}) + unsafe.Sizeof(syscall.RawSockaddrLinklayer{}))
+
+// tpStatus returns the frame status field.
+// The status is concurrently updated by the kernel, so we must use atomic
+// operations to prevent races.
+func (t tPacketHdr) tpStatus() uint32 {
+ hdr := unsafe.Pointer(&t[0])
+ statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset))
+ return atomic.LoadUint32((*uint32)(statusPtr))
+}
+
+// setTPStatus sets the frame status to the provided status.
+// The status is concurrently updated by the kernel, so we must use atomic
+// operations to prevent races.
+func (t tPacketHdr) setTPStatus(status uint32) {
+ hdr := unsafe.Pointer(&t[0])
+ statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset))
+ atomic.StoreUint32((*uint32)(statusPtr), status)
+}
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &packetMMapDispatcher{
+ fd: fd,
+ e: e,
+ }
+ pageSize := unix.Getpagesize()
+ if tpBlockSize%pageSize != 0 {
+ return nil, fmt.Errorf("tpBlockSize: %d is not page aligned, pagesize: %d", tpBlockSize, pageSize)
+ }
+ tReq := tPacketReq{
+ tpBlockSize: uint32(tpBlockSize),
+ tpBlockNR: uint32(tpBlockNR),
+ tpFrameSize: uint32(tpFrameSize),
+ tpFrameNR: uint32(tpFrameNR),
+ }
+	// Set up the PACKET_RX_RING.
+ if err := setsockopt(d.fd, syscall.SOL_PACKET, syscall.PACKET_RX_RING, unsafe.Pointer(&tReq), unsafe.Sizeof(tReq)); err != nil {
+ return nil, fmt.Errorf("failed to enable PACKET_RX_RING: %v", err)
+ }
+	// Mmap the blocks.
+ sz := tpBlockSize * tpBlockNR
+ buf, err := syscall.Mmap(d.fd, 0, sz, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ return nil, fmt.Errorf("syscall.Mmap(...,0, %v, ...) failed = %v", sz, err)
+ }
+ d.ringBuffer = buf
+ return d, nil
+}
+
+func setsockopt(fd, level, name int, val unsafe.Pointer, vallen uintptr) error {
+ if _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(val), vallen, 0); errno != 0 {
+ return error(errno)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
new file mode 100644
index 000000000..f04738cfb
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -0,0 +1,317 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
+var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
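+
+// The sizes above sum to 65664 bytes
+// (128+256+256+512+1024+2048+4096+8192+16384+32768), which is at least
+// header.MaxIPPacketSize so that a maximum-size IP packet always fits; the
+// doubling sizes mean short packets occupy only the small leading views.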
+
+// readVDispatcher uses the readv() system call to read inbound packets and
+// dispatches them.
+type readVDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // views are the actual buffers that hold the packet contents.
+ views []buffer.View
+
+	// iovecs are initialized with the base pointer/length of the
+	// corresponding entries in the views defined above, except when GSO is
+	// enabled, in which case the first iovec points to a buffer for the
+	// vnet header which is stripped before the views are passed up the
+	// stack for further processing.
+ iovecs []syscall.Iovec
+}
+
+func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &readVDispatcher{fd: fd, e: e}
+ d.views = make([]buffer.View, len(BufConfig))
+ iovLen := len(BufConfig)
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ iovLen++
+ }
+ d.iovecs = make([]syscall.Iovec, iovLen)
+ return d, nil
+}
+
+func (d *readVDispatcher) allocateViews(bufConfig []int) {
+ var vnetHdr [virtioNetHdrSize]byte
+ vnetHdrOff := 0
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ // The kernel adds virtioNetHdr before each packet, but
+		// we don't use it, so we allocate a buffer for it and
+		// include it in iovecs, but don't add it to a view.
+ d.iovecs[0] = syscall.Iovec{
+ Base: &vnetHdr[0],
+ Len: uint64(virtioNetHdrSize),
+ }
+ vnetHdrOff++
+ }
+ for i := 0; i < len(bufConfig); i++ {
+ if d.views[i] != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ d.views[i] = b
+ d.iovecs[i+vnetHdrOff] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+}
+
+func (d *readVDispatcher) capViews(n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ d.views[i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
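+
+// For example, capViews(300, BufConfig) leaves the first 128-byte view
+// intact (cumulative 128 < 300), caps the second view to 172 bytes
+// (cumulative 384 >= 300, so 256 - (384 - 300) = 172), and returns 2 used
+// views.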
+
+// dispatch reads one packet from the file descriptor and dispatches it.
+func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
+ d.allocateViews(BufConfig)
+
+ n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
+ if err != nil {
+ return false, err
+ }
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ // Skip virtioNetHdr which is added before each packet, it
+ // isn't used and it isn't in a view.
+ n -= virtioNetHdrSize
+ }
+ if n <= d.e.hdrSize {
+ return false, nil
+ }
+
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ eth header.Ethernet
+ )
+ if d.e.hdrSize > 0 {
+ eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(d.views[0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ used := d.capViews(n, BufConfig)
+ pkt := &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ LinkHeader: buffer.View(eth),
+ }
+ pkt.Data.TrimFront(d.e.hdrSize)
+
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
+
+	// Prepare d.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ d.views[i] = nil
+ }
+
+ return true, nil
+}
+
+// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
+// dispatches them.
+type recvMMsgDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+	// views is an array of arrays of buffers that contain packet contents.
+ views [][]buffer.View
+
+	// iovecs is an array of arrays of iovec records where each iovec base
+	// pointer and length are initialized to the corresponding view above,
+	// except when GSO is enabled, in which case the first iovec in each
+	// array points to a buffer for the vnet header which is stripped
+	// before the views are passed up the stack for further processing.
+ iovecs [][]syscall.Iovec
+
+	// msgHdrs is an array of MMsgHdr objects where each MMsgHdr is used to
+	// reference an array of iovecs in the iovecs field defined above. This
+	// array is passed as the parameter to the recvmmsg call to retrieve
+	// potentially more than one packet per syscall.
+ msgHdrs []rawfile.MMsgHdr
+}
+
+const (
+ // MaxMsgsPerRecv is the maximum number of packets we want to retrieve
+ // in a single RecvMMsg call.
+ MaxMsgsPerRecv = 8
+)
+
+func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &recvMMsgDispatcher{
+ fd: fd,
+ e: e,
+ }
+ d.views = make([][]buffer.View, MaxMsgsPerRecv)
+ for i := range d.views {
+ d.views[i] = make([]buffer.View, len(BufConfig))
+ }
+ d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
+ iovLen := len(BufConfig)
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ // virtioNetHdr is prepended before each packet.
+ iovLen++
+ }
+ for i := range d.iovecs {
+ d.iovecs[i] = make([]syscall.Iovec, iovLen)
+ }
+ d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv)
+ for i := range d.msgHdrs {
+ d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0]
+ d.msgHdrs[i].Msg.Iovlen = uint64(iovLen)
+ }
+ return d, nil
+}
+
+func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ d.views[k][i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
+
+func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
+ for k := 0; k < len(d.views); k++ {
+ var vnetHdr [virtioNetHdrSize]byte
+ vnetHdrOff := 0
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ // The kernel adds virtioNetHdr before each packet, but
+			// we don't use it, so we allocate a buffer for it and
+			// include it in iovecs, but don't add it to a view.
+ d.iovecs[k][0] = syscall.Iovec{
+ Base: &vnetHdr[0],
+ Len: uint64(virtioNetHdrSize),
+ }
+ vnetHdrOff++
+ }
+ for i := 0; i < len(bufConfig); i++ {
+ if d.views[k][i] != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ d.views[k][i] = b
+ d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+ }
+}
+
+// dispatch reads up to MaxMsgsPerRecv packets at a time from the file
+// descriptor and dispatches them.
+func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
+ d.allocateViews(BufConfig)
+
+ nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
+ if err != nil {
+ return false, err
+ }
+	// Process each of the received packets.
+ for k := 0; k < nMsgs; k++ {
+ n := int(d.msgHdrs[k].Len)
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ n -= virtioNetHdrSize
+ }
+ if n <= d.e.hdrSize {
+ return false, nil
+ }
+
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ eth header.Ethernet
+ )
+ if d.e.hdrSize > 0 {
+ eth = header.Ethernet(d.views[k][0])
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(d.views[k][0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+		used := d.capViews(k, n, BufConfig)
+ pkt := &stack.PacketBuffer{
+			Data:       buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[k][:used]...)),
+ LinkHeader: buffer.View(eth),
+ }
+ pkt.Data.TrimFront(d.e.hdrSize)
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
+
+		// Prepare d.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ d.views[k][i] = nil
+ }
+ }
+
+ for k := 0; k < nMsgs; k++ {
+ d.msgHdrs[k].Len = 0
+ }
+
+ return true, nil
+}
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
new file mode 100644
index 000000000..6bf3805b7
--- /dev/null
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "loopback",
+ srcs = ["loopback.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
new file mode 100644
index 000000000..568c6874f
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -0,0 +1,115 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package loopback provides the implementation of loopback data-link layer
+// endpoints. Such endpoints just turn outbound packets into inbound ones.
+//
+// Loopback endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package loopback
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+}
+
+// New creates a new loopback endpoint. This link-layer endpoint just turns
+// outbound packets into inbound packets.
+func New() stack.LinkEndpoint {
+ return &endpoint{}
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
+// layer dispatcher for later use when packets need to be dispatched.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the
+// linux loopback interface.
+func (*endpoint) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises
+// itself as supporting checksum offload, but in reality checksumming is
+// simply skipped, since loopback packets never leave the stack.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityRXChecksumOffload | stack.CapabilityTXChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the
+// loopback interface doesn't have a header, it just returns 0.
+func (*endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*endpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
+// packets to the network-layer dispatcher.
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // There should be an ethernet header at the beginning of vv.
+ hdr, ok := vv.PullUp(header.EthernetMinimumSize)
+ if !ok {
+ // Reject the packet if it's shorter than an ethernet header.
+ return tcpip.ErrBadAddress
+ }
+ linkHeader := header.Ethernet(hdr)
+ vv.TrimFront(len(linkHeader))
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{
+ Data: vv,
+ LinkHeader: buffer.View(linkHeader),
+ })
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
new file mode 100644
index 000000000..82b441b79
--- /dev/null
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "muxed",
+ srcs = ["injectable.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "muxed_test",
+ size = "small",
+ srcs = ["injectable_test.go"],
+ library = ":muxed",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
new file mode 100644
index 000000000..c69d6b7e9
--- /dev/null
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -0,0 +1,137 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package muxed provides a muxed link endpoint.
+package muxed
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// InjectableEndpoint is an injectable endpoint that multiplexes over a set
+// of underlying endpoints. It has trivial routing rules that determine which
+// underlying InjectableLinkEndpoint a given packet will be written to. Note
+// that HandleLocal works differently for this endpoint (see WritePacket).
+type InjectableEndpoint struct {
+ routes map[tcpip.Address]stack.InjectableLinkEndpoint
+ dispatcher stack.NetworkDispatcher
+}
+
+// MTU implements stack.LinkEndpoint.
+func (m *InjectableEndpoint) MTU() uint32 {
+ minMTU := ^uint32(0)
+ for _, endpoint := range m.routes {
+ if endpointMTU := endpoint.MTU(); endpointMTU < minMTU {
+ minMTU = endpointMTU
+ }
+ }
+ return minMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (m *InjectableEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ minCapabilities := stack.LinkEndpointCapabilities(^uint(0))
+ for _, endpoint := range m.routes {
+ minCapabilities &= endpoint.Capabilities()
+ }
+ return minCapabilities
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (m *InjectableEndpoint) MaxHeaderLength() uint16 {
+ minHeaderLen := ^uint16(0)
+ for _, endpoint := range m.routes {
+ if headerLen := endpoint.MaxHeaderLength(); headerLen < minHeaderLen {
+ minHeaderLen = headerLen
+ }
+ }
+ return minHeaderLen
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (m *InjectableEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Attach implements stack.LinkEndpoint.
+func (m *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ for _, endpoint := range m.routes {
+ endpoint.Attach(dispatcher)
+ }
+ m.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (m *InjectableEndpoint) IsAttached() bool {
+ return m.dispatcher != nil
+}
+
+// InjectInbound implements stack.InjectableLinkEndpoint.
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
+}
+
+// WritePackets writes outbound packets to the appropriate
+// InjectableLinkEndpoint based on the RemoteAddress. HandleLocal only works if
+// r.RemoteAddress has a route registered in this endpoint.
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ endpoint, ok := m.routes[r.RemoteAddress]
+ if !ok {
+ return 0, tcpip.ErrNoRoute
+ }
+ return endpoint.WritePackets(r, gso, pkts, protocol)
+}
+
+// WritePacket writes outbound packets to the appropriate InjectableLinkEndpoint
+// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
+// route registered in this endpoint.
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if endpoint, ok := m.routes[r.RemoteAddress]; ok {
+ return endpoint.WritePacket(r, gso, protocol, pkt)
+ }
+ return tcpip.ErrNoRoute
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // WriteRawPacket doesn't get a route or network address, so there's
+ // nowhere to write this.
+ return tcpip.ErrNoRoute
+}
+
+// InjectOutbound writes outbound packets to the appropriate
+// InjectableLinkEndpoint based on the dest address.
+func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
+ endpoint, ok := m.routes[dest]
+ if !ok {
+ return tcpip.ErrNoRoute
+ }
+ return endpoint.InjectOutbound(dest, packet)
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
+// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
+func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
+ return &InjectableEndpoint{
+ routes: routes,
+ }
+}
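+
+// A minimal construction sketch (the FDs and addresses are hypothetical; any
+// stack.InjectableLinkEndpoint works as a route target):
+//
+//	epA := fdbased.NewInjectable(fdA, 1500, stack.CapabilityNone)
+//	epB := fdbased.NewInjectable(fdB, 1500, stack.CapabilityNone)
+//	mux := NewInjectableEndpoint(map[tcpip.Address]stack.InjectableLinkEndpoint{
+//		"\x0a\x00\x00\x01": epA, // 10.0.0.1
+//		"\x0a\x00\x00\x02": epB, // 10.0.0.2
+//	})
+//	// mux then routes each WritePacket/InjectOutbound by destination address.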
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
new file mode 100644
index 000000000..0744f66d6
--- /dev/null
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -0,0 +1,98 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package muxed
+
+import (
+ "bytes"
+ "net"
+ "os"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+func TestInjectableEndpointRawDispatch(t *testing.T) {
+ endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
+
+ endpoint.InjectOutbound(dstIP, []byte{0xFA})
+
+ buf := make([]byte, ipv4.MaxTotalSize)
+ bytesRead, err := sock.Read(buf)
+ if err != nil {
+ t.Fatalf("Unable to read from socketpair: %v", err)
+ }
+ if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) {
+ t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
+ }
+}
+
+func TestInjectableEndpointDispatch(t *testing.T) {
+ endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
+
+ hdr := buffer.NewPrependable(1)
+ hdr.Prepend(1)[0] = 0xFA
+ packetRoute := stack.Route{RemoteAddress: dstIP}
+
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
+ })
+
+ buf := make([]byte, 6500)
+ bytesRead, err := sock.Read(buf)
+ if err != nil {
+ t.Fatalf("Unable to read from socketpair: %v", err)
+ }
+ if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) {
+ t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
+ }
+}
+
+func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
+ endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
+ hdr := buffer.NewPrependable(1)
+ hdr.Prepend(1)[0] = 0xFA
+ packetRoute := stack.Route{RemoteAddress: dstIP}
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ buf := make([]byte, 6500)
+ bytesRead, err := sock.Read(buf)
+ if err != nil {
+ t.Fatalf("Unable to read from socketpair: %v", err)
+ }
+ if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) {
+ t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
+ }
+}
+
+func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) {
+ dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4())
+ pair, err := syscall.Socketpair(syscall.AF_UNIX,
+ syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ t.Fatal("Failed to create socket pair:", err)
+ }
+ underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
+ routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
+ endpoint := NewInjectableEndpoint(routes)
+ return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
+}
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
new file mode 100644
index 000000000..bdd5276ad
--- /dev/null
+++ b/pkg/tcpip/link/nested/BUILD
@@ -0,0 +1,31 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nested",
+ srcs = [
+ "nested.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "nested_test",
+ size = "small",
+ srcs = [
+ "nested_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
new file mode 100644
index 000000000..2998f9c4f
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nested provides helpers to implement the pattern of nested
+// stack.LinkEndpoints.
+package nested
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher
+// that can be used to implement nesting safely by providing lifecycle
+// concurrency guards.
+//
+// See the tests in this package for example usage.
+type Endpoint struct {
+ child stack.LinkEndpoint
+ embedder stack.NetworkDispatcher
+
+ // mu protects dispatcher.
+ mu sync.RWMutex
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+
+// Init initializes a nested.Endpoint that uses embedder as the dispatcher for
+// child on Attach.
+//
+// See the tests in this package for example usage.
+func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) {
+ e.child = child
+ e.embedder = embedder
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverNetworkPacket(remote, local, protocol, pkt)
+ }
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ e.dispatcher = dispatcher
+ e.mu.Unlock()
+ // If we're attaching to a valid dispatcher, pass embedder as the dispatcher
+ // to our child, otherwise detach the child by giving it a nil dispatcher.
+ var pass stack.NetworkDispatcher
+ if dispatcher != nil {
+ pass = e.embedder
+ }
+ e.child.Attach(pass)
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ e.mu.RLock()
+ isAttached := e.dispatcher != nil
+ e.mu.RUnlock()
+ return isAttached
+}
+
+// MTU implements stack.LinkEndpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.child.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.child.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.child.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.child.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ return e.child.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ return e.child.WritePackets(r, gso, pkts, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return e.child.WriteRawPacket(vv)
+}
+
+// Wait implements stack.LinkEndpoint.
+func (e *Endpoint) Wait() {
+ e.child.Wait()
+}
+
+// GSOMaxSize implements stack.GSOEndpoint.
+func (e *Endpoint) GSOMaxSize() uint32 {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.GSOMaxSize()
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
new file mode 100644
index 000000000..c1a219f02
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nested_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type parentEndpoint struct {
+ nested.Endpoint
+}
+
+var _ stack.LinkEndpoint = (*parentEndpoint)(nil)
+var _ stack.NetworkDispatcher = (*parentEndpoint)(nil)
+
+type childEndpoint struct {
+ stack.LinkEndpoint
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.LinkEndpoint = (*childEndpoint)(nil)
+
+func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ c.dispatcher = dispatcher
+}
+
+func (c *childEndpoint) IsAttached() bool {
+ return c.dispatcher != nil
+}
+
+type counterDispatcher struct {
+ count int
+}
+
+var _ stack.NetworkDispatcher = (*counterDispatcher)(nil)
+
+func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ d.count++
+}
+
+func TestNestedLinkEndpoint(t *testing.T) {
+ const emptyAddress = tcpip.LinkAddress("")
+
+ var (
+ childEP childEndpoint
+ nestedEP parentEndpoint
+ disp counterDispatcher
+ )
+ nestedEP.Endpoint.Init(&childEP, &nestedEP)
+
+ if childEP.IsAttached() {
+ t.Error("On init, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("On init, nestedEP.IsAttached() = true, want = false")
+ }
+
+ nestedEP.Attach(&disp)
+ if disp.count != 0 {
+ t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count)
+ }
+ if !childEP.IsAttached() {
+ t.Error("After attach, childEP.IsAttached() = false, want = true")
+ }
+ if !nestedEP.IsAttached() {
+ t.Error("After attach, nestedEP.IsAttached() = false, want = true")
+ }
+
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ if disp.count != 1 {
+ t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count)
+ }
+
+ nestedEP.Attach(nil)
+ if childEP.IsAttached() {
+ t.Error("After detach, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("After detach, nestedEP.IsAttached() = true, want = false")
+ }
+
+ disp.count = 0
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ if disp.count != 0 {
+ t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count)
+ }
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
new file mode 100644
index 000000000..054c213bc
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fifo",
+ srcs = [
+ "endpoint.go",
+ "packet_buffer_queue.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
new file mode 100644
index 000000000..b5dfb7850
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -0,0 +1,209 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fifo provides the implementation of a data-link layer endpoint
+// that wraps a lower endpoint, queues all outbound packets, and
+// asynchronously dispatches them to that lower endpoint.
+package fifo
+
+import (
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// endpoint represents a LinkEndpoint which implements a FIFO queue for all
+// outgoing packets. An endpoint can have one or more underlying
+// queueDispatchers. All outgoing packets are consistently hashed to a single
+// underlying queue using PacketBuffer.Hash if set; otherwise all packets are
+// queued to the first queue to avoid reordering when no hash is available.
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ lower stack.LinkEndpoint
+ wg sync.WaitGroup
+ dispatchers []*queueDispatcher
+}
+
+// queueDispatcher is responsible for dispatching all outbound packets in its
+// queue. It opportunistically batches packets (up to batchSize at a time) and
+// writes them through the lower LinkEndpoint.
+type queueDispatcher struct {
+ lower stack.LinkEndpoint
+ q *packetBufferQueue
+ newPacketWaker sleep.Waker
+ closeWaker sleep.Waker
+}
+
+// New creates a new fifo link endpoint with n queues, each with a maximum
+// capacity of queueLen packets.
+func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint {
+ e := &endpoint{
+ lower: lower,
+ }
+ // Create the required dispatchers.
+ for i := 0; i < n; i++ {
+ qd := &queueDispatcher{
+ q: &packetBufferQueue{limit: queueLen},
+ lower: lower,
+ }
+ e.dispatchers = append(e.dispatchers, qd)
+ e.wg.Add(1)
+ go func() {
+ defer e.wg.Done()
+ qd.dispatchLoop()
+ }()
+ }
+ return e
+}
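+
+// Editor's note: an illustrative sketch (not part of the original change) of
+// wrapping a device endpoint with this qdisc; lowerEP and s are hypothetical.
+//
+// qdiscEP := fifo.New(lowerEP, 4 /* queues */, 1000 /* packets per queue */)
+// if err := s.CreateNIC(1, qdiscEP); err != nil {
+// // Handle the *tcpip.Error.
+// }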
+
+func (q *queueDispatcher) dispatchLoop() {
+ const newPacketWakerID = 1
+ const closeWakerID = 2
+ s := sleep.Sleeper{}
+ s.AddWaker(&q.newPacketWaker, newPacketWakerID)
+ s.AddWaker(&q.closeWaker, closeWakerID)
+ defer s.Done()
+
+ const batchSize = 32
+ var batch stack.PacketBufferList
+ for {
+ id, ok := s.Fetch(true)
+ if ok && id == closeWakerID {
+ return
+ }
+ for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() {
+ batch.PushBack(pkt)
+ if batch.Len() < batchSize && !q.q.empty() {
+ continue
+ }
+ // We pass a protocol of zero here because each packet carries its
+ // NetworkProtocolNumber.
+ q.lower.WritePackets(nil /* route */, nil /* gso */, batch, 0 /* protocol */)
+ for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() {
+ pkt.EgressRoute.Release()
+ batch.Remove(pkt)
+ }
+ batch.Reset()
+ }
+ }
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU.
+func (e *endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ // WritePacket callers do not set the following fields in PacketBuffer,
+ // so we populate them here.
+ newRoute := r.Clone()
+ pkt.EgressRoute = &newRoute
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
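+ // Consistently hash to one dispatcher so packets of a flow stay ordered.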
+ d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
+ if !d.q.enqueue(pkt) {
+ return tcpip.ErrNoBufferSpace
+ }
+ d.newPacketWaker.Assert()
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+//
+// As a batch API, WritePackets expects each packet in pkts to have the
+// following fields populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ enqueued := 0
+ for pkt := pkts.Front(); pkt != nil; {
+ d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
+ nxt := pkt.Next()
+ // Since the qdisc can hold onto a packet for a long time, we clone
+ // the route here to ensure it doesn't get released while the
+ // packet is still in our queue.
+ newRoute := pkt.EgressRoute.Clone()
+ pkt.EgressRoute = &newRoute
+ if !d.q.enqueue(pkt) {
+ if enqueued > 0 {
+ d.newPacketWaker.Assert()
+ }
+ return enqueued, tcpip.ErrNoBufferSpace
+ }
+ pkt = nxt
+ enqueued++
+ d.newPacketWaker.Assert()
+ }
+ return enqueued, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return e.lower.WriteRawPacket(vv)
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *endpoint) Wait() {
+ e.lower.Wait()
+
+ // The linkEP is gone. Tear down the outbound dispatcher goroutines.
+ for i := range e.dispatchers {
+ e.dispatchers[i].closeWaker.Assert()
+ }
+
+ e.wg.Wait()
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
new file mode 100644
index 000000000..eb5abb906
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fifo
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// packetBufferQueue is a bounded, thread-safe queue of PacketBuffers.
+type packetBufferQueue struct {
+ mu sync.Mutex
+ list stack.PacketBufferList
+ limit int
+ used int
+}
+
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *packetBufferQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
+// empty determines if the queue is empty.
+func (q *packetBufferQueue) empty() bool {
+ q.mu.Lock()
+ r := q.emptyLocked()
+ q.mu.Unlock()
+
+ return r
+}
+
+// setLimit updates the limit. No PacketBuffers are immediately dropped, even
+// if the new limit makes the queue over-full.
+func (q *packetBufferQueue) setLimit(limit int) {
+ q.mu.Lock()
+ q.limit = limit
+ q.mu.Unlock()
+}
+
+// enqueue adds the given packet to the queue.
+//
+// It returns true when the PacketBuffer is successfully added to the queue,
+// in which case ownership of the reference is transferred to the queue. It
+// returns false if the queue is full, in which case ownership is retained by
+// the caller.
+func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
+ q.mu.Lock()
+ r := q.used < q.limit
+ if r {
+ q.list.PushBack(s)
+ q.used++
+ }
+ q.mu.Unlock()
+
+ return r
+}
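+
+// Editor's note, illustrative: a producer checks the return value and keeps
+// ownership of the packet on failure, e.g.
+//
+// if !q.enqueue(pkt) {
+// return tcpip.ErrNoBufferSpace // The caller still owns pkt.
+// }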
+
+// dequeue removes and returns the next PacketBuffer from queue, if one exists.
+// Ownership is transferred to the caller.
+func (q *packetBufferQueue) dequeue() *stack.PacketBuffer {
+ q.mu.Lock()
+ s := q.list.Front()
+ if s != nil {
+ q.list.Remove(s)
+ q.used--
+ }
+ q.mu.Unlock()
+
+ return s
+}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
new file mode 100644
index 000000000..14b527bc2
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "rawfile",
+ srcs = [
+ "blockingpoll_amd64.s",
+ "blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
+ "errors.go",
+ "rawfile_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
new file mode 100644
index 000000000..298bad55d
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ CALL ·callEntersyscallblock(SB)
+ MOVQ fds+0(FP), DI
+ MOVQ nfds+8(FP), SI
+ MOVQ timeout+16(FP), DX
+ MOVQ $0x0, R10 // sigmask parameter which isn't used here
+ MOVQ $0x10f, AX // SYS_PPOLL
+ SYSCALL
+ CMPQ AX, $0xfffffffffffff001
+ JLS ok
+ MOVQ $-1, n+24(FP)
+ NEGQ AX
+ MOVQ AX, err+32(FP)
+ CALL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVQ AX, n+24(FP)
+ MOVQ $0, err+32(FP)
+ CALL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 000000000..b62888b93
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ BL ·callEntersyscallblock(SB)
+ MOVD fds+0(FP), R0
+ MOVD nfds+8(FP), R1
+ MOVD timeout+16(FP), R2
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC
+ CMP $0xfffffffffffff001, R0
+ BLS ok
+ MOVD $-1, R1
+ MOVD R1, n+24(FP)
+ NEG R0, R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
new file mode 100644
index 000000000..621ab8d29
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64,!arm64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
+// on non-amd64 and non-arm64 platforms.
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
+
+ return int(n), e
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
new file mode 100644
index 000000000..99313ee25
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64 linux,arm64
+// +build go1.12
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package rawfile
+
+import (
+ "syscall"
+ _ "unsafe" // for go:linkname
+)
+
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
+//go:noescape
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
+
+// Use go:linkname to call into the runtime. As of Go 1.12 this has to
+// be done from Go code so that we make an ABIInternal call to an
+// ABIInternal function; see https://golang.org/issue/27539.
+
+// We need to call both entersyscallblock and exitsyscall this way so
+// that the runtime's check on the stack pointer lines up.
+
+// Note that calling an unexported function in the runtime package is
+// unsafe and this hack is likely to break in future Go releases.
+
+//go:linkname entersyscallblock runtime.entersyscallblock
+func entersyscallblock()
+
+//go:linkname exitsyscall runtime.exitsyscall
+func exitsyscall()
+
+// These forwarding functions must be nosplit because 1) we must
+// disallow preemption between entersyscallblock and exitsyscall, and
+// 2) we have an untyped assembly frame on the stack which can not be
+// grown or moved.
+
+//go:nosplit
+func callEntersyscallblock() {
+ entersyscallblock()
+}
+
+//go:nosplit
+func callExitsyscall() {
+ exitsyscall()
+}
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
new file mode 100644
index 000000000..a0a873c84
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const maxErrno = 134
+
+var translations [maxErrno]*tcpip.Error
+
+// TranslateErrno translates an errno from the syscall package into a
+// *tcpip.Error.
+//
+// Valid but unrecognized errnos are translated to
+// tcpip.ErrInvalidEndpointState (EINVAL). It panics on invalid errnos.
+func TranslateErrno(e syscall.Errno) *tcpip.Error {
+ if err := translations[e]; err != nil {
+ return err
+ }
+ return tcpip.ErrInvalidEndpointState
+}
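+
+// Editor's note, illustrative: a raw syscall wrapper would typically convert
+// its errno like so (fd, ptr and n are hypothetical):
+//
+// if _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), ptr, n); e != 0 {
+// return TranslateErrno(e) // e.g. ENOBUFS -> tcpip.ErrNoBufferSpace
+// }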
+
+func addTranslation(host syscall.Errno, trans *tcpip.Error) {
+ if translations[host] != nil {
+ panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+ }
+ translations[host] = trans
+}
+
+func init() {
+ addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress)
+ addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute)
+ addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState)
+ addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting)
+ addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected)
+ addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse)
+ addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress)
+ addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend)
+ addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock)
+ addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused)
+ addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout)
+ addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted)
+ addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired)
+ addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported)
+ addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported)
+ addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected)
+ addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset)
+ addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted)
+ addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong)
+ addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace)
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
new file mode 100644
index 000000000..69de6eb3e
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -0,0 +1,192 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package rawfile contains utilities for using the netstack with raw host
+// files on Linux hosts.
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// GetMTU determines the MTU of a network interface device.
+func GetMTU(name string) (uint32, error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return 0, err
+ }
+
+ defer syscall.Close(fd)
+
+ var ifreq struct {
+ name [16]byte
+ mtu int32
+ _ [20]byte
+ }
+
+ copy(ifreq.name[:], name)
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
+ if errno != 0 {
+ return 0, errno
+ }
+
+ return uint32(ifreq.mtu), nil
+}
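+
+// Editor's note, illustrative (the device name is hypothetical):
+//
+// mtu, err := GetMTU("eth0")
+// if err != nil {
+// // Handle the error.
+// }
+// _ = mtu // Typically 1500 for Ethernet.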
+
+// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
+// partial data is written.
+func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
+ var ptr unsafe.Pointer
+ if len(buf) > 0 {
+ ptr = unsafe.Pointer(&buf[0])
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
+// single syscall. It fails if partial data is written.
+func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
+ // If there is no second and third buffer, issue a regular write.
+ if len(b2) == 0 && len(b3) == 0 {
+ return NonBlockingWrite(fd, b1)
+ }
+
+ // Build the iovec that represents them and issue a writev syscall.
+ iovec := [3]syscall.Iovec{
+ {
+ Base: &b1[0],
+ Len: uint64(len(b1)),
+ },
+ {
+ Base: &b2[0],
+ Len: uint64(len(b2)),
+ },
+ }
+ iovecLen := uintptr(2)
+
+ if len(b3) > 0 {
+ iovecLen++
+ iovec[2].Base = &b3[0]
+ iovec[2].Len = uint64(len(b3))
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
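+
+// Editor's note, illustrative: a caller would typically split a packet into
+// header and payload views (hdr and payload are hypothetical []byte slices):
+//
+// if err := NonBlockingWrite3(fd, hdr, payload, nil); err != nil {
+// // Handle the *tcpip.Error.
+// }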
+
+// NonBlockingSendMMsg sends multiple messages on a socket.
+func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e != 0 {
+ return 0, TranslateErrno(e)
+ }
+
+ return int(n), nil
+}
+
+// PollEvent represents the pollfd structure passed to a poll() system call.
+type PollEvent struct {
+ FD int32
+ Events int16
+ Revents int16
+}
+
+// BlockingRead reads from a file descriptor that is set up as non-blocking. If
+// no data is available, it will block in a poll() syscall until the file
+// descriptor becomes readable.
+func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ _, e = BlockingPoll(&event, 1, nil)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// BlockingReadv reads from a file descriptor that is set up as non-blocking and
+// stores the data in a list of iovecs buffers. If no data is available, it will
+// block in a poll() syscall until the file descriptor becomes readable.
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ _, e = BlockingPoll(&event, 1, nil)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg syscall.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
+// and stores the received messages in a slice of MMsgHdr structures. If no data
+// is available, it will block in a poll() syscall until the file descriptor
+// becomes readable.
+func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
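+
+// Editor's note, illustrative sketch of a batched receive; iovec setup is
+// abbreviated and all names are hypothetical:
+//
+// msgHdrs := make([]MMsgHdr, 8)
+// // ... point each msgHdrs[i].Msg.Iov at per-packet buffers ...
+// n, err := BlockingRecvMMsg(fd, msgHdrs)
+// // On success, msgHdrs[:n] carry per-message byte counts in Len.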
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
new file mode 100644
index 000000000..13243ebbb
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sharedmem",
+ srcs = [
+ "rx.go",
+ "sharedmem.go",
+ "sharedmem_unsafe.go",
+ "tx.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/queue",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "sharedmem_test",
+ srcs = [
+ "sharedmem_test.go",
+ ],
+ library = ":sharedmem",
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sharedmem/pipe",
+ "//pkg/tcpip/link/sharedmem/queue",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
new file mode 100644
index 000000000..87020ec08
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = [
+ "pipe.go",
+ "pipe_unsafe.go",
+ "rx.go",
+ "tx.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "pipe_test",
+ srcs = [
+ "pipe_test.go",
+ ],
+ library = ":pipe",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go
new file mode 100644
index 000000000..74c9f0311
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe implements a shared memory ring buffer on which a single reader
+// and a single writer can operate (read/write) concurrently. The ring buffer
+// allows for data of different sizes to be written, and preserves the boundary
+// of the written data.
+//
+// Example usage is as follows:
+//
+// wb := t.Push(20)
+// // Write data to wb.
+// t.Flush()
+//
+// rb := r.Pull()
+// // Do something with data in rb.
+// r.Flush()
+package pipe
+
+import (
+ "math"
+)
+
+const (
+ jump uint64 = math.MaxUint32 + 1
+ offsetMask uint64 = math.MaxUint32
+ revolutionMask uint64 = ^offsetMask
+
+ sizeOfSlotHeader = 8 // sizeof(uint64)
+ slotFree uint64 = 1 << 63
+ slotSizeMask uint64 = math.MaxUint32
+)
+
+// payloadToSlotSize calculates the total size of a slot based on its payload
+// size. The total size is the header size, plus the payload size, plus padding
+// if necessary to make the total size a multiple of sizeOfSlotHeader.
+func payloadToSlotSize(payloadSize uint64) uint64 {
+ s := sizeOfSlotHeader + payloadSize
+ return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1)
+}
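+
+// For example (editor's note): payloadToSlotSize(10) = (8 + 10 + 7) &^ 7 = 24,
+// i.e. an 8-byte header plus 10 payload bytes, padded up to the next multiple
+// of sizeOfSlotHeader.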
+
+// slotToPayloadSize calculates the payload size of a slot based on the total
+// size of the slot. This is only meant to be used when creating slots that
+// don't carry information (e.g., free slots or wrap slots).
+func slotToPayloadSize(offset uint64) uint64 {
+ return offset - sizeOfSlotHeader
+}
+
+// pipe is a basic data structure used by both (transmit & receive) ends of a
+// pipe. Indices into this pipe are split into two fields: offset, which counts
+// the number of bytes from the beginning of the buffer, and revolution, which
+// counts the number of times the index has wrapped around.
+type pipe struct {
+ buffer []byte
+}
+
+// init initializes the pipe buffer such that its size is a multiple of the size
+// of the slot header.
+func (p *pipe) init(b []byte) {
+ p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)]
+}
+
+// data returns a section of the buffer starting at the given index (which may
+// include revolution information) and with the given size.
+func (p *pipe) data(idx uint64, size uint64) []byte {
+ return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size]
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
new file mode 100644
index 000000000..dc239a0d0
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -0,0 +1,518 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "math/rand"
+ "reflect"
+ "runtime"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestSimpleReadWrite(t *testing.T) {
+ // Check that a simple write can be properly read from the rx side.
+ tr := rand.New(rand.NewSource(99))
+ rr := rand.New(rand.NewSource(99))
+
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ wb := tx.Push(10)
+ if wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ for i := range wb {
+ wb[i] = byte(tr.Intn(256))
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ rb := rx.Pull()
+ if len(rb) != 10 {
+ t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10)
+ }
+
+ for i := range rb {
+ if v := byte(rr.Intn(256)); v != rb[i] {
+ t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v)
+ }
+ }
+ rx.Flush()
+}
+
+func TestEmptyRead(t *testing.T) {
+ // Check that pulling from an empty pipe fails.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestTooLargeWrite(t *testing.T) {
+ // Check that writes that are too large are properly rejected.
+ b := make([]byte, 96)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(96); wb != nil {
+ t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe")
+ }
+
+ if wb := tx.Push(88); wb != nil {
+ t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe")
+ }
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+}
+
+func TestFullWrite(t *testing.T) {
+ // Check that writes fail when the pipe is full.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+
+ if wb := tx.Push(1); wb != nil {
+ t.Fatalf("Write succeeded on full pipe")
+ }
+}
+
+func TestFullAndFlushedWrite(t *testing.T) {
+ // Check that writes fail when the pipe is full and has already been
+ // flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+
+ tx.Flush()
+
+ if wb := tx.Push(1); wb != nil {
+ t.Fatalf("Write succeeded on full pipe")
+ }
+}
+
+func TestTxFlushTwice(t *testing.T) {
+ // Checks that a second consecutive tx flush is a no-op.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ // Make copy of original tx queue, flush it, then check that it didn't
+ // change.
+ orig := tx
+ tx.Flush()
+
+ if !reflect.DeepEqual(orig, tx) {
+ t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig)
+ }
+}
+
+func TestRxFlushTwice(t *testing.T) {
+ // Checks that a second consecutive rx flush is a no-op.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // Make copy of original rx queue, flush it, then check that it didn't
+ // change.
+ orig := rx
+ rx.Flush()
+
+ if !reflect.DeepEqual(orig, rx) {
+ t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig)
+ }
+}
+
+func TestWrapInMiddleOfTransaction(t *testing.T) {
+ // Check that writes are not flushed when we need to wrap the buffer
+ // around.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ // We haven't flushed yet, so pull must return nil.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ tx.Flush()
+
+ // The two buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+}
+
+func TestWriteAbort(t *testing.T) {
+ // Check that a read fails on a pipe that has had data pushed to it but
+ // has aborted the push.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Write failed on empty pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+
+ tx.Abort()
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestWrappedWriteAbort(t *testing.T) {
+ // Check that writes are properly aborted even if the writes wrap
+ // around.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ // We haven't flushed yet, so pull must return nil.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ tx.Abort()
+
+ // The pushes were aborted, so no data should be readable.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ // Try the same transactions again, but flush this time.
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ tx.Flush()
+
+ // The two buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+}
+
+func TestEmptyReadOnNonFlushedWrite(t *testing.T) {
+ // Check that a read fails on a pipe that has had data pushed to it
+ // but not yet flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Write failed on empty pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+
+ tx.Flush()
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull on failed on non-empty pipe")
+ }
+}
+
+func TestPullAfterPullingEntirePipe(t *testing.T) {
+ // Check that Pull fails when the pipe is full, but all of it has
+ // already been pulled but not yet flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3
+ // buffers that will fill the pipe.
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(20); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ if wb := tx.Push(24); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ tx.Flush()
+
+ // The three buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ // Fourth pull must fail.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestNoRoomToWrapOnPush(t *testing.T) {
+ // Check that Push fails when it tries to allocate room to add a wrap
+ // message.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20,
+ // which won't fit (64 + 8 + 20 + padding = 96, i.e. the whole buffer),
+ // so it wraps around.
+ if wb := tx.Push(20); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ tx.Flush()
+
+ // Buffer offset is at 28. Try to write 70, which would require a wrap
+ // slot which cannot be created now.
+ if wb := tx.Push(70); wb != nil {
+ t.Fatalf("Push succeeded on pipe with no room for wrap message")
+ }
+}
+
+func TestRxImplicitFlushOfWrapMessage(t *testing.T) {
+ // Check that if the first read is of a wrapping message, it gets
+ // immediately flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ // This will cause a wrapping message to be written.
+ if wb := tx.Push(60); wb != nil {
+ t.Fatalf("Push succeeded when there is no room in pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+
+ // Read the first message.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // This should fail because the wrapping message is taking up space.
+ if wb := tx.Push(60); wb != nil {
+ t.Fatalf("Push succeeded when there is no room in pipe")
+ }
+
+ // Try to read the next one. This should consume the wrapping message.
+ rx.Pull()
+
+ // This must now succeed.
+ if wb := tx.Push(60); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+}
+
+func TestConcurrentReaderWriter(t *testing.T) {
+ // Push a million buffers of random sizes and random contents. Check
+ // that buffers read match what was written.
+ tr := rand.New(rand.NewSource(99))
+ rr := rand.New(rand.NewSource(99))
+
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ var rx Rx
+ rx.Init(b)
+
+ const count = 1000000
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ runtime.Gosched()
+ for i := 0; i < count; i++ {
+ n := 1 + tr.Intn(80)
+ wb := tx.Push(uint64(n))
+ for wb == nil {
+ wb = tx.Push(uint64(n))
+ }
+
+ for j := range wb {
+ wb[j] = byte(tr.Intn(256))
+ }
+
+ tx.Flush()
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ runtime.Gosched()
+ for i := 0; i < count; i++ {
+ n := 1 + rr.Intn(80)
+ rb := rx.Pull()
+ for rb == nil {
+ rb = rx.Pull()
+ }
+
+ if n != len(rb) {
+ t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+ }
+
+ for j := range rb {
+ if v := byte(rr.Intn(256)); v != rb[j] {
+ t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
+ }
+ }
+
+ rx.Flush()
+ }
+ }()
+
+ wg.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
new file mode 100644
index 000000000..62d17029e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+func (p *pipe) write(idx uint64, v uint64) {
+ ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+ *ptr = v
+}
+
+func (p *pipe) writeAtomic(idx uint64, v uint64) {
+ ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+ atomic.StoreUint64(ptr, v)
+}
+
+func (p *pipe) readAtomic(idx uint64) uint64 {
+ ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+ return atomic.LoadUint64(ptr)
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go
new file mode 100644
index 000000000..f22e533ac
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/rx.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+// Rx is the receive side of the shared memory ring buffer.
+type Rx struct {
+ p pipe
+
+ tail uint64
+ head uint64
+}
+
+// Init initializes the receive end of the pipe. In the initial state, the next
+// slot to be inspected is the very first one.
+func (r *Rx) Init(b []byte) {
+ r.p.init(b)
+ r.tail = 0xfffffffe * jump
+ r.head = r.tail
+}
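+
+// Editor's note: the initial index starts the revolution counter at
+// 0xfffffffe, two revolutions short of overflow, which appears intended to
+// exercise wraparound arithmetic early rather than only after ~4 billion
+// revolutions.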
+
+// Pull reads the next buffer from the pipe, returning nil if there isn't one
+// currently available.
+//
+// The returned slice is available until Flush() is next called. After that, it
+// must not be touched.
+func (r *Rx) Pull() []byte {
+ if r.head == r.tail+jump {
+ // We've already pulled the whole pipe.
+ return nil
+ }
+
+ header := r.p.readAtomic(r.head)
+ if header&slotFree != 0 {
+ // The next slot is free, we can't pull it yet.
+ return nil
+ }
+
+ payloadSize := header & slotSizeMask
+ newHead := r.head + payloadToSlotSize(payloadSize)
+ headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer))
+
+ // Check if this is a wrapping slot. If that's the case, it carries no
+ // data, so we just skip it and try again from the first slot.
+ if int64(newHead-headWrap) >= 0 {
+ if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 {
+ return nil
+ }
+
+ if r.tail == r.head {
+ // If this is the first pull since the last Flush()
+ // call, we flush the state so that the sender can use
+ // this space if it needs to.
+ r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head))
+ r.tail = newHead
+ }
+
+ r.head = newHead
+ return r.Pull()
+ }
+
+ // Grab the buffer before updating r.head.
+ b := r.p.data(r.head, payloadSize)
+ r.head = newHead
+ return b
+}
+
+// Flush tells the transmitter that all buffers pulled since the last Flush()
+// have been used, so the transmitter is free to reuse their slots for further
+// transmission.
+func (r *Rx) Flush() {
+ if r.head == r.tail {
+ return
+ }
+ r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail))
+ r.tail = r.head
+}
+
+// Bytes returns the byte slice on which the pipe operates.
+func (r *Rx) Bytes() []byte {
+ return r.p.buffer
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go
new file mode 100644
index 000000000..9841eb231
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/tx.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+// Tx is the transmit side of the shared memory ring buffer.
+type Tx struct {
+ p pipe
+ maxPayloadSize uint64
+
+ head uint64
+ tail uint64
+ next uint64
+
+ tailHeader uint64
+}
+
+// Init initializes the transmit end of the pipe. In the initial state, the next
+// slot to be written is the very first one, and the transmitter has the whole
+// ring buffer available to it.
+func (t *Tx) Init(b []byte) {
+ t.p.init(b)
+ // maxPayloadSize excludes the header of the payload, and the header
+ // of the wrapping message.
+ t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader
+ t.tail = 0xfffffffe * jump
+ t.next = t.tail
+ t.head = t.tail + jump
+ t.p.write(t.tail, slotFree)
+}
+
+// Capacity determines how many records of the given size can be written to the
+// pipe before it fills up.
+func (t *Tx) Capacity(recordSize uint64) uint64 {
+ available := uint64(len(t.p.buffer)) - sizeOfSlotHeader
+ entryLen := payloadToSlotSize(recordSize)
+ return available / entryLen
+}
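+
+// For example (editor's note): with a 100-byte backing slice, init trims the
+// buffer to 96 bytes (a multiple of the 8-byte slot header), so
+// Capacity(10) = (96 - 8) / payloadToSlotSize(10) = 88 / 24 = 3 records.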
+
+// Push reserves "payloadSize" bytes for transmission in the pipe. The caller
+// populates the returned slice with the data to be transferred and eventually
+// calls Flush() to make the data visible to the reader, or Abort() to make the
+// pipe forget all Push() calls since the last Flush().
+//
+// The returned slice is available until Flush() or Abort() is next called.
+// After that, it must not be touched.
+func (t *Tx) Push(payloadSize uint64) []byte {
+ // Fail request if we know we will never have enough room.
+ if payloadSize > t.maxPayloadSize {
+ return nil
+ }
+
+ totalLen := payloadToSlotSize(payloadSize)
+ newNext := t.next + totalLen
+ nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer))
+ if int64(newNext-nextWrap) >= 0 {
+ // The new buffer would overflow the pipe, so we push a wrapping
+ // slot, then try to add the actual slot to the front of the
+ // pipe.
+ newNext = (newNext & revolutionMask) + jump
+ wrappingPayloadSize := slotToPayloadSize(newNext - t.next)
+ if !t.reclaim(newNext) {
+ return nil
+ }
+
+ oldNext := t.next
+ t.next = newNext
+ if oldNext != t.tail {
+ t.p.write(oldNext, wrappingPayloadSize)
+ } else {
+ t.tailHeader = wrappingPayloadSize
+ t.Flush()
+ }
+
+ newNext += totalLen
+ }
+
+ // Check that we have enough room for the buffer.
+ if !t.reclaim(newNext) {
+ return nil
+ }
+
+ if t.next != t.tail {
+ t.p.write(t.next, payloadSize)
+ } else {
+ t.tailHeader = payloadSize
+ }
+
+ // Grab the buffer before updating t.next.
+ b := t.p.data(t.next, payloadSize)
+ t.next = newNext
+
+ return b
+}
+
+// reclaim attempts to advance the head until at least newNext. If the head is
+// already at or beyond newNext, nothing happens and true is returned; otherwise
+// it tries to reclaim slots that have already been consumed by the receive end
+// of the pipe (they will be marked as free) and returns a boolean indicating
+// whether it was successful in reclaiming enough slots.
+func (t *Tx) reclaim(newNext uint64) bool {
+ for int64(newNext-t.head) > 0 {
+ // Can't reclaim if slot is not free.
+ header := t.p.readAtomic(t.head)
+ if header&slotFree == 0 {
+ return false
+ }
+
+ payloadSize := header & slotSizeMask
+ newHead := t.head + payloadToSlotSize(payloadSize)
+
+ // Check newHead is within bounds and valid.
+ if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) {
+ return false
+ }
+
+ t.head = newHead
+ }
+
+ return true
+}
+
+// Abort causes all Push() calls since the last Flush() to be forgotten and
+// therefore they will not be made visible to the receiver.
+func (t *Tx) Abort() {
+ t.next = t.tail
+}
+
+// Flush causes all buffers pushed since the last Flush() or Abort(),
+// whichever is more recent, to be made visible to the receiver.
+func (t *Tx) Flush() {
+ if t.next == t.tail {
+ // Nothing to do if there are no pushed buffers.
+ return
+ }
+
+ if t.next != t.head {
+ // The receiver will spin on t.next, so we must make sure that
+ // the slotFree bit is set.
+ t.p.write(t.next, slotFree)
+ }
+
+ t.p.writeAtomic(t.tail, t.tailHeader)
+ t.tail = t.next
+}
+
+// Bytes returns the byte slice on which the pipe operates.
+func (t *Tx) Bytes() []byte {
+ return t.p.buffer
+}
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
new file mode 100644
index 000000000..3ba06af73
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "queue",
+ srcs = [
+ "rx.go",
+ "tx.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/tcpip/link/sharedmem/pipe",
+ ],
+)
+
+go_test(
+ name = "queue_test",
+ srcs = [
+ "queue_test.go",
+ ],
+ library = ":queue",
+ deps = [
+ "//pkg/tcpip/link/sharedmem/pipe",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go
new file mode 100644
index 000000000..9a0aad5d7
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/queue_test.go
@@ -0,0 +1,517 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package queue
+
+import (
+ "encoding/binary"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+)
+
+func TestBasicTxQueue(t *testing.T) {
+ // Tests that a basic transmit on a queue works, and that completion
+ // gets properly reported as well.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Enqueue two buffers.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+ if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue failed on empty queue")
+ }
+
+ // Check the contents of the pipe.
+ d := rxp.Pull()
+ if d == nil {
+ t.Fatalf("Tx pipe is empty after Enqueue")
+ }
+
+ want := []byte{
+ 234, 3, 0, 0, 0, 0, 0, 0, // id
+ 100, 0, 0, 0, // total size
+ 0, 0, 0, 0, // reserved
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ }
+
+ if !reflect.DeepEqual(want, d) {
+ t.Fatalf("Bad posted packet: got %v, want %v", d, want)
+ }
+
+ rxp.Flush()
+
+ // Check that there are no completions yet.
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Packet reported as completed too soon")
+ }
+
+ // Post a completion.
+ d = txp.Push(8)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ binary.LittleEndian.PutUint64(d, usedID)
+ txp.Flush()
+
+ // Check that completion is properly reported.
+ id, ok := q.CompletedPacket()
+ if !ok {
+ t.Fatalf("Completion not reported")
+ }
+
+ if id != usedID {
+ t.Fatalf("Bad completion id: got %v, want %v", id, usedID)
+ }
+}
+
+func TestBasicRxQueue(t *testing.T) {
+ // Tests that a basic receive on a queue works.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post two buffers.
+ b := []RxBuffer{
+ {100, 60, 1077, 0},
+ {200, 40, 2123, 0},
+ }
+
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on empty queue")
+ }
+
+ // Check the contents of the pipe.
+ want := [][]byte{
+ {
+ 100, 0, 0, 0, 0, 0, 0, 0, // Offset1
+ 60, 0, 0, 0, // Size1
+ 0, 0, 0, 0, // Remaining in group 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // User data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+ },
+ {
+ 200, 0, 0, 0, 0, 0, 0, 0, // Offset2
+ 40, 0, 0, 0, // Size2
+ 0, 0, 0, 0, // Remaining in group 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // User data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ },
+ }
+
+ for i := range b {
+ d := rxp.Pull()
+ if d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+
+ if !reflect.DeepEqual(want[i], d) {
+ t.Fatalf("Bad posted packet: got %v, want %v", d, want[i])
+ }
+
+ rxp.Flush()
+ }
+
+ // Check that there are no completions.
+ if _, n := q.Dequeue(nil); n != 0 {
+ t.Fatalf("Packet reported as received too soon")
+ }
+
+ // Post a completion.
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+
+ // Check that completion is properly reported.
+ bufs, n := q.Dequeue(nil)
+ if n != 100 {
+ t.Fatalf("Bad packet size: got %v, want %v", n, 100)
+ }
+
+ if !reflect.DeepEqual(bufs, b) {
+ t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b)
+ }
+}
+
+func TestBadTxCompletion(t *testing.T) {
+ // Check that tx completions with bad sizes are properly ignored.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Post a completion that is too short, and check that it is ignored.
+ if d := txp.Push(7); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion that is too long, and check that it is ignored.
+ if d := txp.Push(10); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Bad completion not ignored")
+ }
+}
+
+func TestBadRxCompletion(t *testing.T) {
+ // Check that bad rx completions are properly ignored.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post a completion that is too short, and check that it is ignored.
+ if d := txp.Push(7); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion whose buffer sizes add up to less than the total
+ // size.
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 10, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 10, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion whose buffer sizes will cause a 32-bit overflow,
+ // but adds up to the right number.
+ d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 255, 255, 255, 255, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 101, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+}
+
+func TestFillTxPipe(t *testing.T) {
+ // Check that transmitting a new buffer when the buffer pipe is full
+ // fails gracefully.
+ pb1 := make([]byte, 104)
+ pb2 := make([]byte, 104)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Transmit twice, which should fill the tx pipe.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+ for i := uint64(0); i < 2; i++ {
+ if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Failed to transmit buffer")
+ }
+ }
+
+ // Transmit another packet now that the tx pipe is full.
+ if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue succeeded when tx pipe is full")
+ }
+}
+
+func TestFillRxPipe(t *testing.T) {
+ // Check that posting a new buffer when the buffer pipe is full fails
+ // gracefully.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+	// Post a buffer twice; this should fill the tx pipe.
+ b := []RxBuffer{
+ {100, 60, 1077, 0},
+ }
+
+ for i := 0; i < 2; i++ {
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on non-full queue")
+ }
+ }
+
+ // Post another buffer now that the tx pipe is full.
+ if q.PostBuffers(b) {
+ t.Fatalf("PostBuffers succeeded on full queue")
+ }
+}
+
+func TestLotsOfTransmissions(t *testing.T) {
+ // Make sure pipes are being properly flushed when transmitting packets.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Prepare packet with two buffers.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+
+ // Post 100000 packets and completions.
+ for i := 100000; i > 0; i-- {
+ if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue failed on non-full queue")
+ }
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after Enqueue")
+ }
+ rxp.Flush()
+
+ d := txp.Push(8)
+ if d == nil {
+ t.Fatalf("Unable to write to rx pipe")
+ }
+ binary.LittleEndian.PutUint64(d, usedID)
+ txp.Flush()
+ if _, ok := q.CompletedPacket(); !ok {
+ t.Fatalf("Completion not returned")
+ }
+ }
+}
+
+func TestLotsOfReceptions(t *testing.T) {
+ // Make sure pipes are being properly flushed when receiving packets.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Prepare for posting two buffers.
+ b := []RxBuffer{
+ {100, 60, 1077, 0},
+ {200, 40, 2123, 0},
+ }
+
+ // Post 100000 buffers and completions.
+ for i := 100000; i > 0; i-- {
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on non-full queue")
+ }
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+ rxp.Flush()
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+ rxp.Flush()
+
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+
+ if _, n := q.Dequeue(nil); n == 0 {
+ t.Fatalf("Dequeue failed when there is a completion")
+ }
+ }
+}
+
+func TestRxEnableNotification(t *testing.T) {
+	// Check that enabling notifications results in properly updated state.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var state uint32
+ var q Rx
+ q.Init(pb1, pb2, &state)
+
+ q.EnableNotification()
+ if state != eventFDEnabled {
+ t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled)
+ }
+}
+
+func TestRxDisableNotification(t *testing.T) {
+	// Check that disabling notifications results in properly updated state.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var state uint32
+ var q Rx
+ q.Init(pb1, pb2, &state)
+
+ q.DisableNotification()
+ if state != eventFDDisabled {
+ t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled)
+ }
+}
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
new file mode 100644
index 000000000..696e6c9e5
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package queue provides the implementation of transmit and receive queues
+// based on shared memory ring buffers.
+package queue
+
+import (
+ "encoding/binary"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+)
+
+const (
+ // Offsets within a posted buffer.
+ postedOffset = 0
+ postedSize = 8
+ postedRemainingInGroup = 12
+ postedUserData = 16
+ postedID = 24
+
+ sizeOfPostedBuffer = 32
+
+ // Offsets within a received packet header.
+ consumedPacketSize = 0
+ consumedPacketReserved = 4
+
+ sizeOfConsumedPacketHeader = 8
+
+ // Offsets within a consumed buffer.
+ consumedOffset = 0
+ consumedSize = 8
+ consumedUserData = 12
+ consumedID = 20
+
+ sizeOfConsumedBuffer = 28
+
+ // The following are the allowed states of the shared data area.
+ eventFDUninitialized = 0
+ eventFDDisabled = 1
+ eventFDEnabled = 2
+)
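+
+// For illustration, a posted buffer is laid out on the wire as follows (all
+// fields little-endian, per the offsets above):
+//
+//	[0:8)   buffer offset within the data region
+//	[8:12)  buffer size
+//	[12:16) remaining buffers in group (currently always 0)
+//	[16:24) opaque user data
+//	[24:32) buffer ID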
+
+// RxBuffer is the descriptor of a receive buffer.
+type RxBuffer struct {
+ Offset uint64
+ Size uint32
+ ID uint64
+ UserData uint64
+}
+
+// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx
+// pipe is used to "post" buffers, while the rx pipe is used to receive packets
+// whose contents have been written to previously posted buffers.
+//
+// This struct is thread-compatible.
+type Rx struct {
+ tx pipe.Tx
+ rx pipe.Rx
+ sharedEventFDState *uint32
+}
+
+// Init initializes the receive queue with the given pipes, and shared state
+// pointer -- the latter is used to enable/disable eventfd notifications.
+func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) {
+ r.sharedEventFDState = sharedEventFDState
+ r.tx.Init(tx)
+ r.rx.Init(rx)
+}
+
+// EnableNotification updates the shared state such that the peer will notify
+// the eventfd when there are packets to be dequeued.
+func (r *Rx) EnableNotification() {
+ atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled)
+}
+
+// DisableNotification updates the shared state such that the peer will not
+// notify the eventfd.
+func (r *Rx) DisableNotification() {
+ atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled)
+}
+
+// PostedBuffersLimit returns the maximum number of buffers that can be posted
+// before the tx queue fills up.
+func (r *Rx) PostedBuffersLimit() uint64 {
+ return r.tx.Capacity(sizeOfPostedBuffer)
+}
+
+// PostBuffers makes the given buffers available for receiving data from the
+// peer. Once they are posted, the peer is free to write to them and will
+// eventually post them back for consumption.
+func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
+ for i := range buffers {
+ b := r.tx.Push(sizeOfPostedBuffer)
+ if b == nil {
+ r.tx.Abort()
+ return false
+ }
+
+ pb := &buffers[i]
+ binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset)
+ binary.LittleEndian.PutUint32(b[postedSize:], pb.Size)
+ binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0)
+ binary.LittleEndian.PutUint64(b[postedUserData:], pb.UserData)
+ binary.LittleEndian.PutUint64(b[postedID:], pb.ID)
+ }
+
+ r.tx.Flush()
+
+ return true
+}
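+
+// A minimal posting sketch (illustrative only; the pipe slices and values
+// are placeholders, not part of this change): post one 1500-byte buffer at
+// offset 0 of the data region.
+//
+//	var sharedState uint32
+//	var q Rx
+//	q.Init(txPipeMem, rxPipeMem, &sharedState)
+//	if !q.PostBuffers([]RxBuffer{{Offset: 0, Size: 1500, ID: 0}}) {
+//		// The tx pipe is full; retry after the peer consumes entries.
+//	}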
+
+// Dequeue receives buffers that have been previously posted by PostBuffers()
+// and that have been filled by the peer and posted back.
+//
+// This is similar to append() in that new buffers are appended to "bufs", with
+// reallocation only if "bufs" doesn't have enough capacity.
+func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
+ for {
+ outBufs := bufs
+
+ // Pull the next descriptor from the rx pipe.
+ b := r.rx.Pull()
+ if b == nil {
+ return bufs, 0
+ }
+
+ if len(b) < sizeOfConsumedPacketHeader {
+ log.Warningf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader)
+ r.rx.Flush()
+ continue
+ }
+
+ totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:])
+
+ // Calculate the number of buffer descriptors and copy them
+ // over to the output.
+ count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer
+ offset := sizeOfConsumedPacketHeader
+ buffersSize := uint32(0)
+ for i := count; i > 0; i-- {
+ s := binary.LittleEndian.Uint32(b[offset+consumedSize:])
+ buffersSize += s
+ if buffersSize < s {
+ // The buffer size overflows an unsigned 32-bit
+ // integer, so break out and force it to be
+ // ignored.
+ totalDataSize = 1
+ buffersSize = 0
+ break
+ }
+
+ outBufs = append(outBufs, RxBuffer{
+ Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]),
+ Size: s,
+ ID: binary.LittleEndian.Uint64(b[offset+consumedID:]),
+ })
+
+ offset += sizeOfConsumedBuffer
+ }
+
+ r.rx.Flush()
+
+ if buffersSize < totalDataSize {
+ // The descriptor is corrupted, ignore it.
+ log.Warningf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize)
+ continue
+ }
+
+ return outBufs, totalDataSize
+ }
+}
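+
+// An illustrative receive loop (sketch): reusing the same backing slice
+// across calls avoids per-packet allocations, which is what the sharedmem
+// endpoint's dispatch loop does.
+//
+//	var bufs []RxBuffer
+//	for {
+//		var n uint32
+//		bufs, n = q.Dequeue(bufs[:0])
+//		if n == 0 {
+//			break // nothing pending right now
+//		}
+//		// bufs now describes the buffers holding one packet of n bytes.
+//	}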
+
+// Bytes returns the byte slices on which the queue operates.
+func (r *Rx) Bytes() (tx, rx []byte) {
+ return r.tx.Bytes(), r.rx.Bytes()
+}
+
+// DecodeRxBufferHeader decodes the header of a buffer posted on an rx queue.
+func DecodeRxBufferHeader(b []byte) RxBuffer {
+ return RxBuffer{
+ Offset: binary.LittleEndian.Uint64(b[postedOffset:]),
+ Size: binary.LittleEndian.Uint32(b[postedSize:]),
+ ID: binary.LittleEndian.Uint64(b[postedID:]),
+ UserData: binary.LittleEndian.Uint64(b[postedUserData:]),
+ }
+}
+
+// RxCompletionSize returns the number of bytes needed to encode an rx
+// completion containing "count" buffers.
+func RxCompletionSize(count int) uint64 {
+ return sizeOfConsumedPacketHeader + uint64(count)*sizeOfConsumedBuffer
+}
+
+// EncodeRxCompletion encodes an rx completion header.
+func EncodeRxCompletion(b []byte, size, reserved uint32) {
+ binary.LittleEndian.PutUint32(b[consumedPacketSize:], size)
+ binary.LittleEndian.PutUint32(b[consumedPacketReserved:], reserved)
+}
+
+// EncodeRxCompletionBuffer encodes the i-th rx completion buffer header.
+func EncodeRxCompletionBuffer(b []byte, i int, rxb RxBuffer) {
+ b = b[RxCompletionSize(i):]
+ binary.LittleEndian.PutUint64(b[consumedOffset:], rxb.Offset)
+ binary.LittleEndian.PutUint32(b[consumedSize:], rxb.Size)
+ binary.LittleEndian.PutUint64(b[consumedUserData:], rxb.UserData)
+ binary.LittleEndian.PutUint64(b[consumedID:], rxb.ID)
+}
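+
+// Putting the helpers together (sketch; "tx" stands for the peer's pipe.Tx
+// that backs this queue's rx pipe): a peer that filled the posted buffers in
+// "bufs" with a single 100-byte packet could complete them like so, checking
+// in real code that Push did not return nil.
+//
+//	b := tx.Push(RxCompletionSize(len(bufs)))
+//	EncodeRxCompletion(b, 100, 0)
+//	for i, rxb := range bufs {
+//		EncodeRxCompletionBuffer(b, i, rxb)
+//	}
+//	tx.Flush()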
diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go
new file mode 100644
index 000000000..beffe807b
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/tx.go
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package queue
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+)
+
+const (
+ // Offsets within a packet header.
+ packetID = 0
+ packetSize = 8
+ packetReserved = 12
+
+ sizeOfPacketHeader = 16
+
+	// Offsets within a buffer descriptor.
+ bufferOffset = 0
+ bufferSize = 8
+
+ sizeOfBufferDescriptor = 12
+)
+
+// TxBuffer is the descriptor of a transmit buffer.
+type TxBuffer struct {
+ Next *TxBuffer
+ Offset uint64
+ Size uint32
+}
+
+// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the
+// tx pipe is used to request the transmission of packets, while the rx pipe
+// is used to receive which transmissions have completed.
+//
+// This struct is thread-compatible.
+type Tx struct {
+ tx pipe.Tx
+ rx pipe.Rx
+}
+
+// Init initializes the transmit queue with the given pipes.
+func (t *Tx) Init(tx, rx []byte) {
+ t.tx.Init(tx)
+ t.rx.Init(rx)
+}
+
+// Enqueue queues the given linked list of buffers for transmission as one
+// packet. While the packet is queued, the caller must not modify the buffers.
+func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool {
+ // Reserve room in the tx pipe.
+ totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor
+
+ b := t.tx.Push(totalLen)
+ if b == nil {
+ return false
+ }
+
+ // Initialize the packet and buffer descriptors.
+ binary.LittleEndian.PutUint64(b[packetID:], id)
+ binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen)
+ binary.LittleEndian.PutUint32(b[packetReserved:], 0)
+
+ offset := sizeOfPacketHeader
+ for i := bufferCount; i != 0; i-- {
+ binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset)
+ binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size)
+ offset += sizeOfBufferDescriptor
+ buffer = buffer.Next
+ }
+	// Each packet uses no more than 40 bytes, so write that many packets
+	// until the tx queue is full.
+
+ return true
+}
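+
+// An illustrative two-buffer transmission (sketch; the offsets point into
+// the shared data region and the id is caller-chosen):
+//
+//	bufs := []TxBuffer{{Offset: 0, Size: 60}, {Offset: 4096, Size: 40}}
+//	bufs[0].Next = &bufs[1]
+//	if !q.Enqueue(1, 100, 2, &bufs[0]) {
+//		// The tx pipe is full; retry after reaping completions.
+//	}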
+
+// CompletedPacket returns the id of the next completed transmission, if any.
+// The returned id refers to a value passed on a previous call to Enqueue().
+func (t *Tx) CompletedPacket() (id uint64, ok bool) {
+ for {
+ b := t.rx.Pull()
+ if b == nil {
+ return 0, false
+ }
+
+ if len(b) != 8 {
+ t.rx.Flush()
+			log.Warningf("Ignoring completed packet: size (%v) is different from expected (%v)", len(b), 8)
+ continue
+ }
+
+ v := binary.LittleEndian.Uint64(b)
+
+ t.rx.Flush()
+
+ return v, true
+ }
+}
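+
+// Completions are typically reaped in a loop before allocating new buffers,
+// along these lines (sketch, mirroring the sharedmem endpoint's tx path;
+// releaseBuffers is a hypothetical helper returning buffers to a pool):
+//
+//	for {
+//		id, ok := q.CompletedPacket()
+//		if !ok {
+//			break
+//		}
+//		releaseBuffers(id)
+//	}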
+
+// Bytes returns the byte slices on which the queue operates.
+func (t *Tx) Bytes() (tx, rx []byte) {
+ return t.tx.Bytes(), t.rx.Bytes()
+}
+
+// TxPacketInfo holds information about a packet sent on a tx queue.
+type TxPacketInfo struct {
+ ID uint64
+ Size uint32
+ Reserved uint32
+ BufferCount int
+}
+
+// DecodeTxPacketHeader decodes the header of a packet sent over a tx queue.
+func DecodeTxPacketHeader(b []byte) TxPacketInfo {
+ return TxPacketInfo{
+ ID: binary.LittleEndian.Uint64(b[packetID:]),
+ Size: binary.LittleEndian.Uint32(b[packetSize:]),
+ Reserved: binary.LittleEndian.Uint32(b[packetReserved:]),
+ BufferCount: (len(b) - sizeOfPacketHeader) / sizeOfBufferDescriptor,
+ }
+}
+
+// DecodeTxBufferHeader decodes the header of the i-th buffer of a packet sent
+// over a tx queue.
+func DecodeTxBufferHeader(b []byte, i int) TxBuffer {
+ b = b[sizeOfPacketHeader+i*sizeOfBufferDescriptor:]
+ return TxBuffer{
+ Offset: binary.LittleEndian.Uint64(b[bufferOffset:]),
+ Size: binary.LittleEndian.Uint32(b[bufferSize:]),
+ }
+}
+
+// EncodeTxCompletion encodes a tx completion header.
+func EncodeTxCompletion(b []byte, id uint64) {
+ binary.LittleEndian.PutUint64(b, id)
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
new file mode 100644
index 000000000..eec11e4cb
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -0,0 +1,159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// rx holds all state associated with an rx queue.
+type rx struct {
+ data []byte
+ sharedData []byte
+ q queue.Rx
+ eventFD int
+}
+
+// init initializes all state needed by the rx queue based on the information
+// provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (r *rx) init(mtu uint32, c *QueueConfig) error {
+ // Map in all buffers.
+ txPipe, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+
+ rxPipe, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ syscall.Munmap(txPipe)
+ return err
+ }
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ syscall.Munmap(txPipe)
+ syscall.Munmap(rxPipe)
+ return err
+ }
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ syscall.Munmap(txPipe)
+ syscall.Munmap(rxPipe)
+ syscall.Munmap(data)
+ return err
+ }
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := syscall.Dup(c.EventFD)
+ if err != nil {
+ syscall.Munmap(txPipe)
+ syscall.Munmap(rxPipe)
+ syscall.Munmap(data)
+ syscall.Munmap(sharedData)
+ return err
+ }
+
+ // Set the eventfd as non-blocking.
+ if err := syscall.SetNonblock(efd, true); err != nil {
+ syscall.Munmap(txPipe)
+ syscall.Munmap(rxPipe)
+ syscall.Munmap(data)
+ syscall.Munmap(sharedData)
+ syscall.Close(efd)
+ return err
+ }
+
+ // Initialize state based on buffers.
+ r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
+ r.data = data
+ r.eventFD = efd
+ r.sharedData = sharedData
+
+ return nil
+}
+
+// cleanup releases all resources allocated during init(). It must only be
+// called if init() has previously succeeded.
+func (r *rx) cleanup() {
+ a, b := r.q.Bytes()
+ syscall.Munmap(a)
+ syscall.Munmap(b)
+
+ syscall.Munmap(r.data)
+ syscall.Munmap(r.sharedData)
+ syscall.Close(r.eventFD)
+}
+
+// postAndReceive posts the provided buffers (if any), and then tries to read
+// from the receive queue.
+//
+// Capacity permitting, it reuses the posted buffer slice to store the buffers
+// that were read as well.
+//
+// This function will block if there aren't any available packets.
+func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.RxBuffer, uint32) {
+ // Post the buffers first. If we cannot post, sleep until we can. We
+ // never post more than will fit concurrently, so it's safe to wait
+ // until enough room is available.
+ if len(b) != 0 && !r.q.PostBuffers(b) {
+ r.q.EnableNotification()
+ for !r.q.PostBuffers(b) {
+ var tmp [8]byte
+ rawfile.BlockingRead(r.eventFD, tmp[:])
+ if atomic.LoadUint32(stopRequested) != 0 {
+ r.q.DisableNotification()
+ return nil, 0
+ }
+ }
+ r.q.DisableNotification()
+ }
+
+ // Read the next set of descriptors.
+ b, n := r.q.Dequeue(b[:0])
+ if len(b) != 0 {
+ return b, n
+ }
+
+ // Data isn't immediately available. Enable eventfd notifications.
+ r.q.EnableNotification()
+ for {
+ b, n = r.q.Dequeue(b)
+ if len(b) != 0 {
+ break
+ }
+
+ // Wait for notification.
+ var tmp [8]byte
+ rawfile.BlockingRead(r.eventFD, tmp[:])
+ if atomic.LoadUint32(stopRequested) != 0 {
+ r.q.DisableNotification()
+ return nil, 0
+ }
+ }
+ r.q.DisableNotification()
+
+ return b, n
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
new file mode 100644
index 000000000..0374a2441
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package sharedmem provides the implementation of data-link layer endpoints
+// backed by shared memory.
+//
+// Shared memory endpoints can be used in the networking stack by calling New()
+// to create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package sharedmem
+
+import (
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// QueueConfig holds all the file descriptors needed to describe a tx or rx
+// queue over shared memory. It is used when creating new shared memory
+// endpoints to describe tx and rx queues.
+type QueueConfig struct {
+ // DataFD is a file descriptor for the file that contains the data to
+ // be transmitted via this queue. Descriptors contain offsets within
+ // this file.
+ DataFD int
+
+ // EventFD is a file descriptor for the event that is signaled when
+	// data becomes available in this queue.
+ EventFD int
+
+ // TxPipeFD is a file descriptor for the tx pipe associated with the
+ // queue.
+ TxPipeFD int
+
+ // RxPipeFD is a file descriptor for the rx pipe associated with the
+ // queue.
+ RxPipeFD int
+
+ // SharedDataFD is a file descriptor for the file that contains shared
+ // state between the two ends of the queue. This data specifies, for
+ // example, whether EventFD signaling is enabled or disabled.
+ SharedDataFD int
+}
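+
+// An illustrative configuration (sketch): the five descriptors would
+// normally come from whatever control process created the shared files and
+// the eventfd; the variables below are placeholders only.
+//
+//	cfg := QueueConfig{
+//		DataFD:       dataFD,
+//		EventFD:      eventFD,
+//		TxPipeFD:     txPipeFD,
+//		RxPipeFD:     rxPipeFD,
+//		SharedDataFD: sharedFD,
+//	}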
+
+type endpoint struct {
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ mtu uint32
+
+ // bufferSize is the size of each individual buffer.
+ bufferSize uint32
+
+ // addr is the local address of this endpoint.
+ addr tcpip.LinkAddress
+
+ // rx is the receive queue.
+ rx rx
+
+ // stopRequested is to be accessed atomically only, and determines if
+ // the worker goroutines should stop.
+ stopRequested uint32
+
+ // Wait group used to indicate that all workers have stopped.
+ completed sync.WaitGroup
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // tx is the transmit queue.
+ tx tx
+
+ // workerStarted specifies whether the worker goroutine was started.
+ workerStarted bool
+}
+
+// New creates a new shared-memory-based endpoint. Packets are broken up into
+// buffers of "bufferSize" bytes.
+func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+ e := &endpoint{
+ mtu: mtu,
+ bufferSize: bufferSize,
+ addr: addr,
+ }
+
+ if err := e.tx.init(bufferSize, &tx); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(bufferSize, &rx); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ return e, nil
+}
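+
+// A minimal construction sketch (the MTU, buffer size, link address and
+// queue configs below are assumptions; the package test shows one way to
+// build the descriptors):
+//
+//	ep, err := New(1500, 1500, tcpip.LinkAddress("\x02\x00\x00\x00\x00\x01"), txCfg, rxCfg)
+//	if err != nil {
+//		return err
+//	}
+//	// The endpoint is then typically passed to Stack.CreateNIC, which calls
+//	// Attach to start the dispatch loop.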
+
+// Close frees all resources associated with the endpoint.
+func (e *endpoint) Close() {
+ // Tell dispatch goroutine to stop, then write to the eventfd so that
+ // it wakes up in case it's sleeping.
+ atomic.StoreUint32(&e.stopRequested, 1)
+ syscall.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+	// Clean up the queues inline if the worker hasn't started yet; we also
+ // know it won't start from now on because stopRequested is set to 1.
+ e.mu.Lock()
+ workerPresent := e.workerStarted
+ e.mu.Unlock()
+
+ if !workerPresent {
+ e.tx.cleanup()
+ e.rx.cleanup()
+ }
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
+func (e *endpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
+ e.workerStarted = true
+ e.completed.Add(1)
+ // Link endpoints are not savable. When transportation endpoints
+ // are saved, they stop sending outgoing packets and all
+ // incoming packets are rejected.
+ go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+ }
+ e.mu.Unlock()
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *endpoint) MTU() uint32 {
+ return e.mtu - header.EthernetMinimumSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (*endpoint) MaxHeaderLength() uint16 {
+ return header.EthernetMinimumSize
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// WritePacket writes an outbound packet to the tx queue. If the queue is
+// full, it returns tcpip.ErrWouldBlock.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ // Add the ethernet header here.
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+
+ v := pkt.Data.ToView()
+ // Transmit the packet.
+ e.mu.Lock()
+ ok := e.tx.transmit(pkt.Header.View(), v)
+ e.mu.Unlock()
+
+ if !ok {
+ return tcpip.ErrWouldBlock
+ }
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ v := vv.ToView()
+ // Transmit the packet.
+ e.mu.Lock()
+ ok := e.tx.transmit(v, buffer.View{})
+ e.mu.Unlock()
+
+ if !ok {
+ return tcpip.ErrWouldBlock
+ }
+
+ return nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack.
+func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
+ // Post initial set of buffers.
+ limit := e.rx.q.PostedBuffersLimit()
+ if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l {
+ limit = l
+ }
+ for i := uint64(0); i < limit; i++ {
+ b := queue.RxBuffer{
+ Offset: i * uint64(e.bufferSize),
+ Size: e.bufferSize,
+ ID: i,
+ }
+ if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) {
+ log.Warningf("Unable to post %v-th buffer", i)
+ }
+ }
+
+ // Read in a loop until a stop is requested.
+ var rxb []queue.RxBuffer
+ for atomic.LoadUint32(&e.stopRequested) == 0 {
+ var n uint32
+ rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested)
+
+ // Copy data from the shared area to its own buffer, then
+ // prepare to repost the buffer.
+ b := make([]byte, n)
+ offset := uint32(0)
+ for i := range rxb {
+ copy(b[offset:], e.rx.data[rxb[i].Offset:][:rxb[i].Size])
+ offset += rxb[i].Size
+
+ rxb[i].Size = e.bufferSize
+ }
+
+ if n < header.EthernetMinimumSize {
+ continue
+ }
+
+ // Send packet up the stack.
+ eth := header.Ethernet(b[:header.EthernetMinimumSize])
+ d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{
+ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
+ LinkHeader: buffer.View(eth),
+ })
+ }
+
+	// Clean up state.
+ e.tx.cleanup()
+ e.rx.cleanup()
+
+ e.completed.Done()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
new file mode 100644
index 000000000..28a2e88ba
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -0,0 +1,812 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package sharedmem
+
+import (
+ "bytes"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+
+ queueDataSize = 1024 * 1024
+ queuePipeSize = 4096
+)
+
+type queueBuffers struct {
+ data []byte
+ rx pipe.Tx
+ tx pipe.Rx
+}
+
+func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) {
+ // Prepare tx pipe.
+ b, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ t.Fatalf("getBuffer failed: %v", err)
+ }
+ q.tx.Init(b)
+
+ // Prepare rx pipe.
+ b, err = getBuffer(c.RxPipeFD)
+ if err != nil {
+ t.Fatalf("getBuffer failed: %v", err)
+ }
+ q.rx.Init(b)
+
+ // Get data slice.
+ q.data, err = getBuffer(c.DataFD)
+ if err != nil {
+ t.Fatalf("getBuffer failed: %v", err)
+ }
+}
+
+func (q *queueBuffers) cleanup() {
+ syscall.Munmap(q.tx.Bytes())
+ syscall.Munmap(q.rx.Bytes())
+ syscall.Munmap(q.data)
+}
+
+type packetInfo struct {
+ addr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ vv buffer.VectorisedView
+ linkHeader buffer.View
+}
+
+type testContext struct {
+ t *testing.T
+ ep *endpoint
+ txCfg QueueConfig
+ rxCfg QueueConfig
+ txq queueBuffers
+ rxq queueBuffers
+
+ packetCh chan struct{}
+ mu sync.Mutex
+ packets []packetInfo
+}
+
+func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext {
+ var err error
+ c := &testContext{
+ t: t,
+ packetCh: make(chan struct{}, 1000000),
+ }
+ c.txCfg = createQueueFDs(t, queueSizes{
+ dataSize: queueDataSize,
+ txPipeSize: queuePipeSize,
+ rxPipeSize: queuePipeSize,
+ sharedDataSize: 4096,
+ })
+
+ c.rxCfg = createQueueFDs(t, queueSizes{
+ dataSize: queueDataSize,
+ txPipeSize: queuePipeSize,
+ rxPipeSize: queuePipeSize,
+ sharedDataSize: 4096,
+ })
+
+ initQueue(t, &c.txq, &c.txCfg)
+ initQueue(t, &c.rxq, &c.rxCfg)
+
+ ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ if err != nil {
+ t.Fatalf("New failed: %v", err)
+ }
+
+ c.ep = ep.(*endpoint)
+ c.ep.Attach(c)
+
+ return c
+}
+
+func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ c.mu.Lock()
+ c.packets = append(c.packets, packetInfo{
+ addr: remoteLinkAddr,
+ proto: proto,
+ vv: pkt.Data.Clone(nil),
+ })
+ c.mu.Unlock()
+
+ c.packetCh <- struct{}{}
+}
+
+func (c *testContext) cleanup() {
+ c.ep.Close()
+ closeFDs(&c.txCfg)
+ closeFDs(&c.rxCfg)
+ c.txq.cleanup()
+ c.rxq.cleanup()
+}
+
+func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) {
+ for i := 0; i < n; i++ {
+ select {
+ case <-c.packetCh:
+ case <-to:
+ c.t.Fatalf(errorStr)
+ }
+ }
+}
+
+func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) {
+ b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs)))
+ queue.EncodeRxCompletion(b, size, 0)
+ for i := range bs {
+ queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{
+ Offset: bs[i].Offset,
+ Size: bs[i].Size,
+ ID: bs[i].ID,
+ })
+ }
+}
+
+func randomFill(b []byte) {
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+}
+
+func shuffle(b []int) {
+ for i := len(b) - 1; i >= 0; i-- {
+ j := rand.Intn(i + 1)
+ b[i], b[j] = b[j], b[i]
+ }
+}
+
+func createFile(t *testing.T, size int64, initQueue bool) int {
+ tmpDir := os.Getenv("TEST_TMPDIR")
+ if tmpDir == "" {
+ tmpDir = os.Getenv("TMPDIR")
+ }
+ f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+ if err != nil {
+ t.Fatalf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ syscall.Unlink(f.Name())
+
+ if initQueue {
+ // Write the "slot-free" flag in the initial queue.
+ _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
+ if err != nil {
+ t.Fatalf("WriteAt failed: %v", err)
+ }
+ }
+
+ fd, err := syscall.Dup(int(f.Fd()))
+ if err != nil {
+ t.Fatalf("Dup failed: %v", err)
+ }
+
+ if err := syscall.Ftruncate(fd, size); err != nil {
+ syscall.Close(fd)
+ t.Fatalf("Ftruncate failed: %v", err)
+ }
+
+ return fd
+}
+
+func closeFDs(c *QueueConfig) {
+ syscall.Close(c.DataFD)
+ syscall.Close(c.EventFD)
+ syscall.Close(c.TxPipeFD)
+ syscall.Close(c.RxPipeFD)
+ syscall.Close(c.SharedDataFD)
+}
+
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
+ fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0)
+ if err != 0 {
+ t.Fatalf("eventfd failed: %v", error(err))
+ }
+
+ return QueueConfig{
+ EventFD: int(fd),
+ DataFD: createFile(t, s.dataSize, false),
+ TxPipeFD: createFile(t, s.txPipeSize, true),
+ RxPipeFD: createFile(t, s.rxPipeSize, true),
+ SharedDataFD: createFile(t, s.sharedDataSize, false),
+ }
+}
+
+// TestSimpleSend sends 1000 packets with random header and payload sizes,
+// then checks that the right payload is received on the shared memory queues.
+func TestSimpleSend(t *testing.T) {
+ c := newTestContext(t, 20000, 1500, localLinkAddr)
+ defer c.cleanup()
+
+ // Prepare route.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+
+ for iters := 1000; iters > 0; iters-- {
+ func() {
+ // Prepare and send packet.
+ n := rand.Intn(10000)
+ hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
+ hdrBuf := hdr.Prepend(n)
+ randomFill(hdrBuf)
+
+ n = rand.Intn(10000)
+ buf := buffer.NewView(n)
+ randomFill(buf)
+
+ proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Receive packet.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if pi.Reserved != 0 {
+ t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
+ }
+ contents := make([]byte, 0, pi.Size)
+ for i := 0; i < pi.BufferCount; i++ {
+ bi := queue.DecodeTxBufferHeader(desc, i)
+ contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
+ }
+ c.txq.tx.Flush()
+
+ defer func() {
+ // Tell the endpoint about the completion of the write.
+ b := c.txq.rx.Push(8)
+ queue.EncodeTxCompletion(b, pi.ID)
+ c.txq.rx.Flush()
+ }()
+
+ // Check the ethernet header.
+ ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
+ ethTemplate.Encode(&header.EthernetFields{
+ SrcAddr: localLinkAddr,
+ DstAddr: remoteLinkAddr,
+ Type: proto,
+ })
+ if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
+ t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
+ }
+
+ // Compare contents skipping the ethernet header added by the
+ // endpoint.
+ merged := append(hdrBuf, buf...)
+ if uint32(len(contents)) < pi.Size {
+ t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
+ }
+ contents = contents[:pi.Size][header.EthernetMinimumSize:]
+
+ if !bytes.Equal(contents, merged) {
+ t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged))
+ }
+ }()
+ }
+}
+
+// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress
+// set in Route (using much of the same code as TestSimpleSend), then checks
+// that the encoded ethernet header received includes the correct SrcAddr.
+func TestPreserveSrcAddressInSend(t *testing.T) {
+ c := newTestContext(t, 20000, 1500, localLinkAddr)
+ defer c.cleanup()
+
+	newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("\xFE", 6))
+ // Set both remote and local link address in route.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ LocalLinkAddress: newLocalLinkAddress,
+ }
+
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ hdr := buffer.NewPrependable(header.EthernetMinimumSize)
+
+ proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Receive packet.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if pi.Reserved != 0 {
+ t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
+ }
+ contents := make([]byte, 0, pi.Size)
+ for i := 0; i < pi.BufferCount; i++ {
+ bi := queue.DecodeTxBufferHeader(desc, i)
+ contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
+ }
+ c.txq.tx.Flush()
+
+ defer func() {
+ // Tell the endpoint about the completion of the write.
+ b := c.txq.rx.Push(8)
+ queue.EncodeTxCompletion(b, pi.ID)
+ c.txq.rx.Flush()
+ }()
+
+ // Check that the ethernet header contains the expected SrcAddr.
+ ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
+ ethTemplate.Encode(&header.EthernetFields{
+ SrcAddr: newLocalLinkAddress,
+ DstAddr: remoteLinkAddr,
+ Type: proto,
+ })
+ if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
+ t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
+ }
+}
+
+// TestFillTxQueue sends packets until the queue is full.
+func TestFillTxQueue(t *testing.T) {
+ c := newTestContext(t, 20000, 1500, localLinkAddr)
+ defer c.cleanup()
+
+ // Prepare to send a packet.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+
+ buf := buffer.NewView(100)
+
+	// Each packet uses no more than 40 bytes, so write that many packets
+	// until the tx queue is full.
+ ids := make(map[uint64]struct{})
+ for i := queuePipeSize / 40; i > 0; i-- {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
+
+ // Check that they have different IDs.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if _, ok := ids[pi.ID]; ok {
+ t.Fatalf("ID (%v) reused", pi.ID)
+ }
+ ids[pi.ID] = struct{}{}
+ }
+
+ // Next attempt to write must fail.
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != want {
+ t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ }
+}
+
+// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets
+// until the queue is full.
+func TestFillTxQueueAfterBadCompletion(t *testing.T) {
+ c := newTestContext(t, 20000, 1500, localLinkAddr)
+ defer c.cleanup()
+
+ // Send a bad completion.
+ queue.EncodeTxCompletion(c.txq.rx.Push(8), 1)
+ c.txq.rx.Flush()
+
+ // Prepare to send a packet.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+
+ buf := buffer.NewView(100)
+
+ // Send two packets so that the id slice has at least two slots.
+ for i := 2; i > 0; i-- {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
+ }
+
+ // Complete the two writes twice.
+ for i := 2; i > 0; i-- {
+ pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull())
+
+ queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
+ queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
+ c.txq.rx.Flush()
+ }
+ c.txq.tx.Flush()
+
+ // Each packet is uses no more than 40 bytes, so write that many packets
+ // until the tx queue if full.
+ ids := make(map[uint64]struct{})
+ for i := queuePipeSize / 40; i > 0; i-- {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
+
+ // Check that they have different IDs.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if _, ok := ids[pi.ID]; ok {
+ t.Fatalf("ID (%v) reused", pi.ID)
+ }
+ ids[pi.ID] = struct{}{}
+ }
+
+ // Next attempt to write must fail.
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != want {
+ t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ }
+}
+
+// TestFillTxMemory sends packets until we run out of shared memory.
+func TestFillTxMemory(t *testing.T) {
+ const bufferSize = 1500
+ c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+ defer c.cleanup()
+
+ // Prepare to send a packet.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+
+ buf := buffer.NewView(100)
+
+	// Each packet uses up one buffer, so write as many as possible until
+ // we fill the memory.
+ ids := make(map[uint64]struct{})
+ for i := queueDataSize / bufferSize; i > 0; i-- {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
+
+ // Check that they have different IDs.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if _, ok := ids[pi.ID]; ok {
+ t.Fatalf("ID (%v) reused", pi.ID)
+ }
+ ids[pi.ID] = struct{}{}
+ c.txq.tx.Flush()
+ }
+
+ // Next attempt to write must fail.
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ })
+ if want := tcpip.ErrWouldBlock; err != want {
+ t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ }
+}
+
+// TestFillTxMemoryWithMultiBuffer sends packets until we run out of
+// shared memory for a 2-buffer packet, but still with room for a 1-buffer
+// packet.
+func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
+ const bufferSize = 1500
+ c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+ defer c.cleanup()
+
+ // Prepare to send a packet.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+
+ buf := buffer.NewView(100)
+
+	// Each packet uses up one buffer, so write as many as possible
+ // until there is only one buffer left.
+ for i := queueDataSize/bufferSize - 1; i > 0; i-- {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
+
+ // Pull the posted buffer.
+ c.txq.tx.Pull()
+ c.txq.tx.Flush()
+ }
+
+ // Attempt to write a two-buffer packet. It must fail.
+ {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ uu := buffer.NewView(bufferSize).ToVectorisedView()
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: uu,
+ }); err != want {
+ t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ }
+ }
+
+ // Attempt to write the one-buffer packet again. It must succeed.
+ {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
+ }
+}
+
+func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte {
+ t.Helper()
+
+ for {
+ b := p.Pull()
+ if b != nil {
+ return b
+ }
+
+ select {
+ case <-time.After(10 * time.Millisecond):
+ case <-to:
+ t.Fatal(errStr)
+ }
+ }
+}
+
+// TestSimpleReceive completes 1000 different receives with random payload and
+// random number of buffers. It checks that the contents match the expected
+// values.
+func TestSimpleReceive(t *testing.T) {
+ const bufferSize = 1500
+ c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+ defer c.cleanup()
+
+ // Check that buffers have been posted.
+ limit := c.ep.rx.q.PostedBuffersLimit()
+ for i := uint64(0); i < limit; i++ {
+ timeout := time.After(2 * time.Second)
+ bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted"))
+
+ if want := i * bufferSize; want != bi.Offset {
+ t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want)
+ }
+
+ if want := i; want != bi.ID {
+ t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want)
+ }
+
+ if bufferSize != bi.Size {
+ t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize)
+ }
+ }
+ c.rxq.tx.Flush()
+
+ // Create a slice with the indices 0..limit-1.
+ idx := make([]int, limit)
+ for i := range idx {
+ idx[i] = i
+ }
+
+ // Complete random packets 1000 times.
+ for iters := 1000; iters > 0; iters-- {
+ timeout := time.After(2 * time.Second)
+ // Prepare a random packet.
+ shuffle(idx)
+ n := 1 + rand.Intn(10)
+ bufs := make([]queue.RxBuffer, n)
+ contents := make([]byte, bufferSize*n-rand.Intn(500))
+ randomFill(contents)
+ for i := range bufs {
+ j := idx[i]
+ bufs[i].Size = bufferSize
+ bufs[i].Offset = uint64(bufferSize * j)
+ bufs[i].ID = uint64(j)
+
+ copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:])
+ }
+
+ // Push completion.
+ c.pushRxCompletion(uint32(len(contents)), bufs)
+ c.rxq.rx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Wait for packet to be received, then check it.
+ c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
+ c.mu.Lock()
+ rcvd := []byte(c.packets[0].vv.ToView())
+ c.packets = c.packets[:0]
+ c.mu.Unlock()
+
+ if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) {
+ t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents)
+ }
+
+ // Check that buffers have been reposted.
+ for i := range bufs {
+ bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted"))
+ if bi != bufs[i] {
+ t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i])
+ }
+ }
+ c.rxq.tx.Flush()
+ }
+}
+
+// TestRxBuffersReposted tests that rx buffers get reposted after they have been
+// completed.
+func TestRxBuffersReposted(t *testing.T) {
+ const bufferSize = 1500
+ c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+ defer c.cleanup()
+
+ // Receive all posted buffers.
+ limit := c.ep.rx.q.PostedBuffersLimit()
+ buffers := make([]queue.RxBuffer, 0, limit)
+ for i := limit; i > 0; i-- {
+ timeout := time.After(2 * time.Second)
+ buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers")))
+ }
+ c.rxq.tx.Flush()
+
+ // Check that all buffers are reposted when individually completed.
+ for i := range buffers {
+ timeout := time.After(2 * time.Second)
+ // Complete the buffer.
+ c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
+ c.rxq.rx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Wait for it to be reposted.
+ bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
+ if bi != buffers[i] {
+ t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i])
+ }
+ }
+ c.rxq.tx.Flush()
+
+ // Check that all buffers are reposted when completed in pairs.
+ for i := 0; i < len(buffers)/2; i++ {
+ timeout := time.After(2 * time.Second)
+ // Complete with two buffers.
+ c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
+ c.rxq.rx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Wait for them to be reposted.
+ for j := 0; j < 2; j++ {
+ bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
+ if bi != buffers[2*i+j] {
+ t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j])
+ }
+ }
+ }
+ c.rxq.tx.Flush()
+}
+
+// TestReceivePostingIsFull checks that the endpoint will properly handle the
+// case when a received buffer cannot be immediately reposted because it hasn't
+// been pulled from the tx pipe yet.
+func TestReceivePostingIsFull(t *testing.T) {
+ const bufferSize = 1500
+ c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+ defer c.cleanup()
+
+ // Complete first posted buffer before flushing it from the tx pipe.
+ first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
+ c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
+ c.rxq.rx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Check that packet is received.
+ c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
+
+ // Complete another buffer.
+ second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
+ c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
+ c.rxq.rx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Check that no packet is received yet, as the worker is blocked trying
+ // to repost.
+ select {
+ case <-time.After(500 * time.Millisecond):
+ case <-c.packetCh:
+ t.Fatalf("Unexpected packet received")
+ }
+
+ // Flush tx queue, which will allow the first buffer to be reposted,
+ // and the second completion to be pulled.
+ c.rxq.tx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Check that second packet completes.
+ c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
+}
+
+// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to
+// repost a buffer, and checks that the worker backs out cleanly.
+func TestCloseWhileWaitingToPost(t *testing.T) {
+ const bufferSize = 1500
+ c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+ cleaned := false
+ defer func() {
+ if !cleaned {
+ c.cleanup()
+ }
+ }()
+
+ // Complete first posted buffer before flushing it from the tx pipe.
+ bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
+ c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
+ c.rxq.rx.Flush()
+ syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Wait for packet to be indicated.
+ c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
+
+ // Cleanup and wait for worker to complete.
+ c.cleanup()
+ cleaned = true
+ c.ep.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
new file mode 100644
index 000000000..f7e816a41
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sharedmem
+
+import (
+ "unsafe"
+)
+
+// sharedDataPointer converts the shared data slice into a pointer so that it
+// can be used in atomic operations.
+func sharedDataPointer(sharedData []byte) *uint32 {
+ return (*uint32)(unsafe.Pointer(&sharedData[0:4][0]))
+}
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
new file mode 100644
index 000000000..6b8d7859d
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -0,0 +1,272 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sharedmem
+
+import (
+ "math"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+const (
+ nilID = math.MaxUint64
+)
+
+// tx holds all state associated with a tx queue.
+type tx struct {
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+}
+
+// init initializes all state needed by the tx queue based on the information
+// provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (t *tx) init(mtu uint32, c *QueueConfig) error {
+ // Map in all buffers.
+ txPipe, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+
+ rxPipe, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ syscall.Munmap(txPipe)
+ return err
+ }
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ syscall.Munmap(txPipe)
+ syscall.Munmap(rxPipe)
+ return err
+ }
+
+ // Initialize state based on buffers.
+ t.q.Init(txPipe, rxPipe)
+ t.ids.init()
+ t.bufs.init(0, len(data), int(mtu))
+ t.data = data
+
+ return nil
+}
+
+// cleanup releases all resources allocated during init(). It must only be
+// called if init() has previously succeeded.
+func (t *tx) cleanup() {
+ a, b := t.q.Bytes()
+ syscall.Munmap(a)
+ syscall.Munmap(b)
+ syscall.Munmap(t.data)
+}
+
+// transmit sends a packet made up of up to two buffers. Returns a boolean that
+// specifies whether the packet was successfully transmitted.
+func (t *tx) transmit(a, b []byte) bool {
+ // Pull completions from the tx queue and add their buffers back to the
+ // pool so that we can reuse them.
+ for {
+ id, ok := t.q.CompletedPacket()
+ if !ok {
+ break
+ }
+
+ if buf := t.ids.remove(id); buf != nil {
+ t.bufs.free(buf)
+ }
+ }
+
+ bSize := t.bufs.entrySize
+ total := uint32(len(a) + len(b))
+ bufCount := (total + bSize - 1) / bSize
+
+ // Allocate enough buffers to hold all the data.
+ var buf *queue.TxBuffer
+ for i := bufCount; i != 0; i-- {
+ b := t.bufs.alloc()
+ if b == nil {
+ // Failed to get all buffers. Return to the pool
+ // whatever we had managed to get.
+ if buf != nil {
+ t.bufs.free(buf)
+ }
+ return false
+ }
+ b.Next = buf
+ buf = b
+ }
+
+ // Copy data into allocated buffers.
+ nBuf := buf
+ var dBuf []byte
+ for _, data := range [][]byte{a, b} {
+ for len(data) > 0 {
+ if len(dBuf) == 0 {
+ dBuf = t.data[nBuf.Offset:][:nBuf.Size]
+ nBuf = nBuf.Next
+ }
+ n := copy(dBuf, data)
+ data = data[n:]
+ dBuf = dBuf[n:]
+ }
+ }
+
+ // Get an id for this packet and send it out.
+ id := t.ids.add(buf)
+ if !t.q.Enqueue(id, total, bufCount, buf) {
+ t.ids.remove(id)
+ t.bufs.free(buf)
+ return false
+ }
+
+ return true
+}
+
+// getBuffer returns a memory region mapped to the full contents of the given
+// file descriptor.
+func getBuffer(fd int) ([]byte, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(fd, &s); err != nil {
+ return nil, err
+ }
+
+ // Check that size doesn't overflow an int.
+ if s.Size > int64(^uint(0)>>1) {
+ return nil, syscall.EDOM
+ }
+
+ return syscall.Mmap(fd, 0, int(s.Size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED|syscall.MAP_FILE)
+}
+
+// idDescriptor is used by idManager to either point to a tx buffer (if the
+// ID is assigned) or to the next free element (if the ID is not assigned).
+type idDescriptor struct {
+ buf *queue.TxBuffer
+ nextFree uint64
+}
+
+// idManager is a manager of tx buffer identifiers. It assigns unique IDs to
+// tx buffers that are added to it; the IDs can only be reused after they have
+// been removed.
+//
+// The ID assignments are stored so that the tx buffers can be retrieved from
+// the IDs previously assigned to them.
+type idManager struct {
+ // ids is a slice containing all tx buffers. The ID is the index into
+ // this slice.
+ ids []idDescriptor
+
+ // freeList is a list of free IDs.
+ freeList uint64
+}
+
+// init initializes the id manager.
+func (m *idManager) init() {
+ m.freeList = nilID
+}
+
+// add assigns an ID to the given tx buffer.
+func (m *idManager) add(b *queue.TxBuffer) uint64 {
+ if i := m.freeList; i != nilID {
+ // There is an id available in the free list, just use it.
+ m.ids[i].buf = b
+ m.freeList = m.ids[i].nextFree
+ return i
+ }
+
+ // There are no free IDs, so expand the descriptor slice.
+ m.ids = append(m.ids, idDescriptor{buf: b})
+ return uint64(len(m.ids) - 1)
+}
+
+// remove retrieves the tx buffer associated with the given ID, and removes the
+// ID from the assigned table so that it can be reused in the future.
+func (m *idManager) remove(i uint64) *queue.TxBuffer {
+ if i >= uint64(len(m.ids)) {
+ return nil
+ }
+
+ desc := &m.ids[i]
+ b := desc.buf
+ if b == nil {
+ // The provided id is not currently assigned.
+ return nil
+ }
+
+ desc.buf = nil
+ desc.nextFree = m.freeList
+ m.freeList = i
+
+ return b
+}
+
+// bufferManager manages a buffer region broken up into smaller, equally sized
+// buffers. Smaller buffers can be allocated and freed.
+type bufferManager struct {
+ freeList *queue.TxBuffer
+ curOffset uint64
+ limit uint64
+ entrySize uint32
+}
+
+// init initializes the buffer manager.
+func (b *bufferManager) init(initialOffset, size, entrySize int) {
+ b.freeList = nil
+ b.curOffset = uint64(initialOffset)
+ b.limit = uint64(initialOffset + size/entrySize*entrySize)
+ b.entrySize = uint32(entrySize)
+}
+
+// alloc allocates a buffer from the manager, if one is available.
+func (b *bufferManager) alloc() *queue.TxBuffer {
+ if b.freeList != nil {
+ // There is a descriptor ready for reuse in the free list.
+ d := b.freeList
+ b.freeList = d.Next
+ d.Next = nil
+ return d
+ }
+
+ if b.curOffset < b.limit {
+ // There is room available in the never-used range, so create
+ // a new descriptor for it.
+ d := &queue.TxBuffer{
+ Offset: b.curOffset,
+ Size: b.entrySize,
+ }
+ b.curOffset += uint64(b.entrySize)
+ return d
+ }
+
+ return nil
+}
+
+// free returns all buffers in the list to the buffer manager so that they can
+// be reused.
+func (b *bufferManager) free(d *queue.TxBuffer) {
+ // Find the last buffer in the list.
+ last := d
+ for last.Next != nil {
+ last = last.Next
+ }
+
+ // Push list onto free list.
+ last.Next = b.freeList
+ b.freeList = d
+}
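The buffer accounting in transmit is a ceiling division: a payload of total bytes needs ceil(total/entrySize) fixed-size buffers. A standalone sketch of the same arithmetic, with an illustrative entry size:

    package main

    import "fmt"

    func main() {
    	const entrySize uint32 = 1500 // per-buffer size chosen at init()
    	for _, total := range []uint32{1, 1500, 1501, 3000} {
    		bufCount := (total + entrySize - 1) / entrySize // ceil(total/entrySize)
    		fmt.Printf("%4d payload bytes -> %d buffer(s)\n", total, bufCount)
    	}
    }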
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
new file mode 100644
index 000000000..7cbc305e7
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sniffer",
+ srcs = [
+ "pcap.go",
+ "sniffer.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
new file mode 100644
index 000000000..c16c19647
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sniffer
+
+import "time"
+
+type pcapHeader struct {
+ // MagicNumber is the file magic number.
+ MagicNumber uint32
+
+ // VersionMajor is the major version number.
+ VersionMajor uint16
+
+ // VersionMinor is the minor version number.
+ VersionMinor uint16
+
+ // Thiszone is the GMT to local correction.
+ Thiszone int32
+
+ // Sigfigs is the accuracy of timestamps.
+ Sigfigs uint32
+
+ // Snaplen is the max length of captured packets, in octets.
+ Snaplen uint32
+
+ // Network is the data link type.
+ Network uint32
+}
+
+const pcapPacketHeaderLen = 16
+
+type pcapPacketHeader struct {
+ // Seconds is the timestamp seconds.
+ Seconds uint32
+
+ // Microseconds is the timestamp microseconds.
+ Microseconds uint32
+
+ // IncludedLength is the number of octets of packet saved in file.
+ IncludedLength uint32
+
+ // OriginalLength is the actual length of packet.
+ OriginalLength uint32
+}
+
+func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
+ now := time.Now()
+ return pcapPacketHeader{
+ Seconds: uint32(now.Unix()),
+ Microseconds: uint32(now.Nanosecond() / 1000),
+ IncludedLength: incLen,
+ OriginalLength: orgLen,
+ }
+}
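Both structs are fixed-layout, so binary.Write serializes them directly: the global pcapHeader is written once per file, and each captured packet is preceded by pcapPacketHeaderLen (16) bytes. A sketch internal to this package (writeRecord is hypothetical; assumes "encoding/binary" and "io" are imported):

    // Sketch only: emit one capture record for an untruncated packet.
    func writeRecord(w io.Writer, payload []byte) error {
    	hdr := newPCAPPacketHeader(uint32(len(payload)), uint32(len(payload)))
    	if err := binary.Write(w, binary.BigEndian, hdr); err != nil {
    		return err
    	}
    	_, err := w.Write(payload)
    	return err
    }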
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
new file mode 100644
index 000000000..d9cd4e83a
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -0,0 +1,394 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sniffer provides the implementation of data-link layer endpoints that
+// wrap another endpoint and log inbound and outbound packets.
+//
+// Sniffer endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package sniffer
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// LogPackets is a flag used to enable or disable packet logging via the log
+// package. Valid values are 0 or 1.
+//
+// LogPackets must be accessed atomically.
+var LogPackets uint32 = 1
+
+// LogPacketsToPCAP is a flag used to enable or disable logging packets to a
+// pcap writer. Valid values are 0 or 1. A writer must have been specified when the
+// sniffer was created for this flag to have effect.
+//
+// LogPacketsToPCAP must be accessed atomically.
+var LogPacketsToPCAP uint32 = 1
+
+type endpoint struct {
+ nested.Endpoint
+ writer io.Writer
+ maxPCAPLen uint32
+}
+
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.NetworkDispatcher = (*endpoint)(nil)
+
+// New creates a new sniffer link-layer endpoint. It wraps around another
+// endpoint and logs packets as they traverse the endpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ sniffer := &endpoint{}
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer
+}
+
+func zoneOffset() (int32, error) {
+ loc, err := time.LoadLocation("Local")
+ if err != nil {
+ return 0, err
+ }
+ date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+ _, offset := date.Zone()
+ return int32(offset), nil
+}
+
+func writePCAPHeader(w io.Writer, maxLen uint32) error {
+ offset, err := zoneOffset()
+ if err != nil {
+ return err
+ }
+ return binary.Write(w, binary.BigEndian, pcapHeader{
+ // From https://wiki.wireshark.org/Development/LibpcapFileFormat
+ MagicNumber: 0xa1b2c3d4,
+
+ VersionMajor: 2,
+ VersionMinor: 4,
+ Thiszone: offset,
+ Sigfigs: 0,
+ Snaplen: maxLen,
+ Network: 101, // LINKTYPE_RAW
+ })
+}
+
+// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets as they traverse the endpoint.
+//
+// Packets are logged to writer in the pcap format. A sniffer created with this
+// function will not emit packets using the standard log package.
+//
+// snapLen is the maximum amount of a packet to be saved. Packets with a length
+// less than or equal to snapLen will be saved in their entirety. Longer
+// packets will be truncated to snapLen.
+func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (stack.LinkEndpoint, error) {
+ if err := writePCAPHeader(writer, snapLen); err != nil {
+ return nil, err
+ }
+ sniffer := &endpoint{
+ writer: writer,
+ maxPCAPLen: snapLen,
+ }
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer, nil
+}
+
+// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
+// called by the link-layer endpoint being wrapped when a packet arrives, and
+// logs the packet before forwarding to the actual dispatcher.
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dumpPacket("recv", nil, protocol, pkt)
+ e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
+}
+
+func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ writer := e.writer
+ if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
+ logPacket(prefix, protocol, pkt, gso)
+ }
+ if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
+ totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
+ length := totalLength
+ if max := int(e.maxPCAPLen); length > max {
+ length = max
+ }
+ if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil {
+ panic(err)
+ }
+ write := func(b []byte) {
+ if len(b) > length {
+ b = b[:length]
+ }
+ for len(b) != 0 {
+ n, err := writer.Write(b)
+ if err != nil {
+ panic(err)
+ }
+ b = b[n:]
+ length -= n
+ }
+ }
+ write(pkt.Header.View())
+ for _, view := range pkt.Data.Views() {
+ if length == 0 {
+ break
+ }
+ write(view)
+ }
+ }
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.dumpPacket("send", gso, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.dumpPacket("send", gso, protocol, pkt)
+ }
+ return e.Endpoint.WritePackets(r, gso, pkts, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ e.dumpPacket("send", nil, 0, &stack.PacketBuffer{
+ Data: vv,
+ })
+ return e.Endpoint.WriteRawPacket(vv)
+}
+
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+ // Figure out the network layer info.
+ var transProto uint8
+ src := tcpip.Address("unknown")
+ dst := tcpip.Address("unknown")
+ id := 0
+ size := uint16(0)
+ var fragmentOffset uint16
+ var moreFragments bool
+
+ // Create a clone of pkt, including any headers if present. Avoid allocating
+ // backing memory for the clone.
+ views := [8]buffer.View{}
+ vv := buffer.NewVectorisedView(0, views[:0])
+ vv.AppendView(pkt.Header.View())
+ vv.Append(pkt.Data)
+
+ switch protocol {
+ case header.IPv4ProtocolNumber:
+ hdr, ok := vv.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ ipv4 := header.IPv4(hdr)
+ fragmentOffset = ipv4.FragmentOffset()
+ moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
+ src = ipv4.SourceAddress()
+ dst = ipv4.DestinationAddress()
+ transProto = ipv4.Protocol()
+ size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
+ vv.TrimFront(int(ipv4.HeaderLength()))
+ id = int(ipv4.ID())
+
+ case header.IPv6ProtocolNumber:
+ hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ ipv6 := header.IPv6(hdr)
+ src = ipv6.SourceAddress()
+ dst = ipv6.DestinationAddress()
+ transProto = ipv6.NextHeader()
+ size = ipv6.PayloadLength()
+ vv.TrimFront(header.IPv6MinimumSize)
+
+ case header.ARPProtocolNumber:
+ hdr, ok := vv.PullUp(header.ARPSize)
+ if !ok {
+ return
+ }
+ vv.TrimFront(header.ARPSize)
+ arp := header.ARP(hdr)
+ log.Infof(
+ "%s arp %s (%s) -> %s (%s) valid:%t",
+ prefix,
+ tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
+ tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
+ arp.IsValid(),
+ )
+ return
+ default:
+ log.Infof("%s unknown network protocol", prefix)
+ return
+ }
+
+ // Figure out the transport layer info.
+ transName := "unknown"
+ srcPort := uint16(0)
+ dstPort := uint16(0)
+ details := ""
+ switch tcpip.TransportProtocolNumber(transProto) {
+ case header.ICMPv4ProtocolNumber:
+ transName = "icmp"
+ hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv4(hdr)
+ icmpType := "unknown"
+ if fragmentOffset == 0 {
+ switch icmp.Type() {
+ case header.ICMPv4EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv4DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv4SrcQuench:
+ icmpType = "source quench"
+ case header.ICMPv4Redirect:
+ icmpType = "redirect"
+ case header.ICMPv4Echo:
+ icmpType = "echo"
+ case header.ICMPv4TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv4ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv4Timestamp:
+ icmpType = "timestamp"
+ case header.ICMPv4TimestampReply:
+ icmpType = "timestamp reply"
+ case header.ICMPv4InfoRequest:
+ icmpType = "info request"
+ case header.ICMPv4InfoReply:
+ icmpType = "info reply"
+ }
+ }
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
+ case header.ICMPv6ProtocolNumber:
+ transName = "icmp"
+ hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv6(hdr)
+ icmpType := "unknown"
+ switch icmp.Type() {
+ case header.ICMPv6DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv6PacketTooBig:
+ icmpType = "packet too big"
+ case header.ICMPv6TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv6ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv6EchoRequest:
+ icmpType = "echo request"
+ case header.ICMPv6EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv6RouterSolicit:
+ icmpType = "router solicit"
+ case header.ICMPv6RouterAdvert:
+ icmpType = "router advert"
+ case header.ICMPv6NeighborSolicit:
+ icmpType = "neighbor solicit"
+ case header.ICMPv6NeighborAdvert:
+ icmpType = "neighbor advert"
+ case header.ICMPv6RedirectMsg:
+ icmpType = "redirect message"
+ }
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
+ case header.UDPProtocolNumber:
+ transName = "udp"
+ hdr, ok := vv.PullUp(header.UDPMinimumSize)
+ if !ok {
+ break
+ }
+ udp := header.UDP(hdr)
+ if fragmentOffset == 0 {
+ srcPort = udp.SourcePort()
+ dstPort = udp.DestinationPort()
+ details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+ size -= header.UDPMinimumSize
+ }
+
+ case header.TCPProtocolNumber:
+ transName = "tcp"
+ hdr, ok := vv.PullUp(header.TCPMinimumSize)
+ if !ok {
+ break
+ }
+ tcp := header.TCP(hdr)
+ if fragmentOffset == 0 {
+ offset := int(tcp.DataOffset())
+ if offset < header.TCPMinimumSize {
+ details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
+ break
+ }
+ if offset > vv.Size() && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size())
+ break
+ }
+
+ srcPort = tcp.SourcePort()
+ dstPort = tcp.DestinationPort()
+ size -= uint16(offset)
+
+ // Initialize the TCP flags.
+ flags := tcp.Flags()
+ flagsStr := []byte("FSRPAU")
+ for i := range flagsStr {
+ if flags&(1<<uint(i)) == 0 {
+ flagsStr[i] = ' '
+ }
+ }
+ details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ if flags&header.TCPFlagSyn != 0 {
+ details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
+ } else {
+ details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
+ }
+ }
+
+ default:
+ log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
+ return
+ }
+
+ if gso != nil {
+ details += fmt.Sprintf(" gso: %+v", gso)
+ }
+
+ log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+}
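A minimal usage sketch, assuming a lower endpoint already exists; the capture path and snap length are illustrative:

    // Sketch only: wrap lower so all traffic is captured in pcap format.
    func wrapWithCapture(lower stack.LinkEndpoint) (stack.LinkEndpoint, error) {
    	f, err := os.Create("/tmp/capture.pcap") // illustrative path
    	if err != nil {
    		return nil, err
    	}
    	// Packets longer than 2048 bytes are truncated in the capture.
    	return sniffer.NewWithWriter(lower, f, 2048)
    }

For log-based output instead of pcap, sniffer.New(lower) suffices; it can be silenced at runtime by atomically storing 0 in LogPackets.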
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
new file mode 100644
index 000000000..e0db6cf54
--- /dev/null
+++ b/pkg/tcpip/link/tun/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tun",
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..6bc9033d0
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,358 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packets, arriving at the fd side for read.
+ // Overflow causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef()
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ linkCaps := stack.CapabilityNone
+ if isTap {
+ linkCaps |= stack.CapabilityResolutionRequired
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Create a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ }
+ endpoint.Endpoint.LinkEPCapabilities = linkCaps
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write injects one inbound packet into the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := &stack.PacketBuffer{
+ Data: buffer.View(data).ToVectorisedView(),
+ }
+ if ethHdr != nil {
+ pkt.LinkHeader = buffer.View(ethHdr)
+ }
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have a link-layer header and the
+ // route carries no remote link address, we can't construct one. This
+ // is possibly a raw packet, which the tun device doesn't support yet.
+ if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader == nil {
+ hdr := &header.EthernetFields{
+ SrcAddr: info.Route.LocalLinkAddress,
+ DstAddr: info.Route.RemoteLinkAddress,
+ Type: info.Proto,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = d.endpoint.LinkAddress()
+ }
+
+ eth := make(header.Ethernet, header.EthernetMinimumSize)
+ eth.Encode(hdr)
+ vv.AppendView(buffer.View(eth))
+ } else {
+ vv.AppendView(info.Pkt.LinkHeader)
+ }
+ }
+
+ // Append upper headers.
+ vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted, as multiple open files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef() {
+ e.DecRefWithDestructor(func() {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
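A usage sketch, assuming d is an open /dev/net/tun Device, s is a running stack, and frame is a complete Ethernet frame; the interface name is illustrative:

    // Sketch only: attach a TAP interface and inject one inbound frame.
    func attachAndInject(d *tun.Device, s *stack.Stack, frame []byte) error {
    	if err := d.SetIff(s, "tap0", linux.IFF_TAP|linux.IFF_NO_PI); err != nil {
    		return err
    	}
    	// With IFF_TAP set, Write expects the data to start with an
    	// Ethernet header; with IFF_NO_PI set, no packet information
    	// header precedes it.
    	_, err := d.Write(frame)
    	return err
    }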
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
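A round-trip sketch of the 4-byte header; 0x0800 (the IPv4 EtherType) is an illustrative protocol value:

    // Sketch only: encode then decode a packet information header.
    func roundTrip() tcpip.NetworkProtocolNumber {
    	h := make(tun.PacketInfoHeader, tun.PacketInfoHeaderSize)
    	h.Encode(&tun.PacketInfoFields{Protocol: 0x0800}) // IPv4 EtherType
    	return h.Protocol()                               // yields 0x0800
    }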
diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go
new file mode 100644
index 000000000..09ca9b527
--- /dev/null
+++ b/pkg/tcpip/link/tun/tun_unsafe.go
@@ -0,0 +1,63 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package tun contains methods to open TAP and TUN devices.
+package tun
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// Open opens the specified TUN device, sets it to non-blocking mode, and
+// returns its file descriptor.
+func Open(name string) (int, error) {
+ return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI)
+}
+
+// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and
+// returns its file descriptor.
+func OpenTAP(name string) (int, error) {
+ return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI)
+}
+
+func open(name string, flags uint16) (int, error) {
+ fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0)
+ if err != nil {
+ return -1, err
+ }
+
+ var ifr struct {
+ name [16]byte
+ flags uint16
+ _ [22]byte
+ }
+
+ copy(ifr.name[:], name)
+ ifr.flags = flags
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
+ if errno != 0 {
+ syscall.Close(fd)
+ return -1, errno
+ }
+
+ if err = syscall.SetNonblock(fd, true); err != nil {
+ syscall.Close(fd)
+ return -1, err
+ }
+
+ return fd, nil
+}
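A usage sketch; it assumes a persistent "tun0" device (or CAP_NET_ADMIN to create one) and shows the non-blocking behavior set up by open:

    // Sketch only: read one packet from a TUN interface (Linux only).
    func readOnePacket() ([]byte, error) {
    	fd, err := tun.Open("tun0") // illustrative device name
    	if err != nil {
    		return nil, err
    	}
    	defer syscall.Close(fd)

    	buf := make([]byte, 65536)
    	n, err := syscall.Read(fd, buf) // may fail with EAGAIN: fd is non-blocking
    	if err != nil {
    		return nil, err
    	}
    	return buf[:n], nil
    }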
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
new file mode 100644
index 000000000..0956d2c65
--- /dev/null
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -0,0 +1,30 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "waitable",
+ srcs = [
+ "waitable.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/gate",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "waitable_test",
+ srcs = [
+ "waitable_test.go",
+ ],
+ library = ":waitable",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
new file mode 100644
index 000000000..949b3f2b2
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -0,0 +1,149 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package waitable provides the implementation of data-link layer endpoints
+// that wrap other endpoints, and can wait for inflight calls to WritePacket or
+// DeliverNetworkPacket to finish (and new ones to be prevented).
+//
+// Waitable endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package waitable
+
+import (
+ "gvisor.dev/gvisor/pkg/gate"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a waitable link-layer endpoint.
+type Endpoint struct {
+ dispatchGate gate.Gate
+ dispatcher stack.NetworkDispatcher
+
+ writeGate gate.Gate
+ lower stack.LinkEndpoint
+}
+
+// New creates a new waitable link-layer endpoint. It wraps around another
+// endpoint and allows the caller to block new write/dispatch calls and wait for
+// the inflight ones to finish before returning.
+func New(lower stack.LinkEndpoint) *Endpoint {
+ return &Endpoint{
+ lower: lower,
+ }
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
+// It is called by the link-layer endpoint being wrapped when a packet arrives,
+// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
+// been called.
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
+// registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just
+// forwards the request to the lower endpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WritePacket(r, gso, protocol, pkt)
+ e.writeGate.Leave()
+ return err
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if !e.writeGate.Enter() {
+ return pkts.Len(), nil
+ }
+
+ n, err := e.lower.WritePackets(r, gso, pkts, protocol)
+ e.writeGate.Leave()
+ return n, err
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WriteRawPacket(vv)
+ e.writeGate.Leave()
+ return err
+}
+
+// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
+// and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitWrite() {
+ e.writeGate.Close()
+}
+
+// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the
+// actual dispatcher, and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitDispatch() {
+ e.dispatchGate.Close()
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
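A teardown sketch, assuming ep wraps a NIC's real link endpoint:

    // Sketch only: quiesce a waitable endpoint before removing the NIC.
    func quiesce(ep *waitable.Endpoint) {
    	ep.WaitWrite()    // no new writes; inflight WritePacket calls drain
    	ep.WaitDispatch() // no new deliveries; inflight dispatches drain
    }

After both calls return, no packet can cross the endpoint in either direction, which is what makes the subsequent teardown race-free.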
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
new file mode 100644
index 000000000..63bf40562
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -0,0 +1,173 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package waitable
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type countedEndpoint struct {
+ dispatchCount int
+ writeCount int
+ attachCount int
+
+ mtu uint32
+ capabilities stack.LinkEndpointCapabilities
+ hdrLen uint16
+ linkAddr tcpip.LinkAddress
+
+ dispatcher stack.NetworkDispatcher
+}
+
+func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatchCount++
+}
+
+func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.attachCount++
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *countedEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+func (e *countedEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+func (e *countedEndpoint) MaxHeaderLength() uint16 {
+ return e.hdrLen
+}
+
+func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += pkts.Len()
+ return pkts.Len(), nil
+}
+
+func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
+func TestWaitWrite(t *testing.T) {
+ ep := &countedEndpoint{}
+ wep := New(ep)
+
+ // Write and check that it goes through.
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
+ if want := 1; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on dispatches, then try to write. It must go through.
+ wep.WaitDispatch()
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on writes, then try to write. It must not go through.
+ wep.WaitWrite()
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+}
+
+func TestWaitDispatch(t *testing.T) {
+ ep := &countedEndpoint{}
+ wep := New(ep)
+
+ // Check that attach happens.
+ wep.Attach(ep)
+ if want := 1; ep.attachCount != want {
+ t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want)
+ }
+
+ // Dispatch and check that it goes through.
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
+ if want := 1; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on writes, then try to dispatch. It must go through.
+ wep.WaitWrite()
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on dispatches, then try to dispatch. It must not go through.
+ wep.WaitDispatch()
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+}
+
+func TestOtherMethods(t *testing.T) {
+ const (
+ mtu = 0xdead
+ capabilities = 0xbeef
+ hdrLen = 0x1234
+ linkAddr = "test address"
+ )
+ ep := &countedEndpoint{
+ mtu: mtu,
+ capabilities: capabilities,
+ hdrLen: hdrLen,
+ linkAddr: linkAddr,
+ }
+ wep := New(ep)
+
+ if v := wep.MTU(); v != mtu {
+ t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
+ }
+
+ if v := wep.Capabilities(); v != capabilities {
+ t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities)
+ }
+
+ if v := wep.MaxHeaderLength(); v != hdrLen {
+ t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen)
+ }
+
+ if v := wep.LinkAddress(); v != linkAddr {
+ t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr)
+ }
+}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
new file mode 100644
index 000000000..6a4839fb8
--- /dev/null
+++ b/pkg/tcpip/network/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = [
+ "ip_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
new file mode 100644
index 000000000..eddf7b725
--- /dev/null
+++ b/pkg/tcpip/network/arp/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "arp",
+ srcs = ["arp.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "arp_test",
+ size = "small",
+ srcs = ["arp_test.go"],
+ deps = [
+ ":arp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
new file mode 100644
index 000000000..7f27a840d
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package arp implements the ARP network protocol. It is used to resolve
+// IPv4 addresses into link-layer MAC addresses, and to advertise the IPv4
+// addresses of its stack to the local network.
+//
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
+//
+// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
+// // handle err
+// }
+package arp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolNumber is the ARP protocol number.
+ ProtocolNumber = header.ARPProtocolNumber
+
+ // ProtocolAddress is the address expected by the ARP endpoint.
+ ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+type endpoint struct {
+ protocol *protocol
+ nicID tcpip.NICID
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+}
+
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+func (e *endpoint) PrefixLen() int {
+ return 0
+}
+
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+func (e *endpoint) Close() {}
+
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.ARP(pkt.NetworkHeader)
+ if !h.IsValid() {
+ return
+ }
+
+ switch h.Op() {
+ case header.ARPRequest:
+ localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+ packet := header.ARP(hdr.Prepend(header.ARPSize))
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ })
+ fallthrough // also fill the cache from requests
+ case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ }
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+func (p *protocol) DefaultPrefixLen() int { return 0 }
+
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.ARP(v)
+ return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addrWithPrefix.Address != ProtocolAddress {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return &endpoint{
+ protocol: p,
+ nicID: nicID,
+ linkEP: sender,
+ linkAddrCache: linkAddrCache,
+ }, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ RemoteLinkAddress: broadcastMAC,
+ }
+
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+ h := header.ARP(hdr.Prepend(header.ARPSize))
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), linkEP.LinkAddress())
+ copy(h.ProtocolAddressSender(), localAddr)
+ copy(h.ProtocolAddressTarget(), addr)
+
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ })
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == header.IPv4Broadcast {
+ return broadcastMAC, true
+ }
+ if header.IsV4MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv4Address(addr), true
+ }
+ return tcpip.LinkAddress([]byte(nil)), false
+}
+
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.ARPSize)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(header.ARPSize)
+ return 0, false, true
+}
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
new file mode 100644
index 000000000..66e67429c
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arp_test
+
+import (
+ "context"
+ "strconv"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+const (
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+)
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+}
+
+func newTestContext(t *testing.T) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
+
+ const defaultMTU = 65536
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNIC(1, wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("AddAddress for arp failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ }})
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: ep,
+ }
+}
+
+func (c *testContext) cleanup() {
+ c.linkEP.Close()
+}
+
+func TestDirectRequest(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+
+ const senderMAC = "\x01\x02\x03\x04\x05\x06"
+ const senderIPv4 = "\x0a\x00\x00\x02"
+
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), senderMAC)
+ copy(h.ProtocolAddressSender(), senderIPv4)
+
+ inject := func(addr tcpip.Address) {
+ copy(h.ProtocolAddressTarget(), addr)
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
+ }
+
+ for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ inject(address)
+ pi, _ := c.linkEP.ReadContext(context.Background())
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
+ }
+ rep := header.ARP(pi.Pkt.Header.View())
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
+ }
+ })
+ }
+
+ inject(stackAddrBad)
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
+ t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
new file mode 100644
index 000000000..d1c728ccf
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "reassembler_list",
+ out = "reassembler_list.go",
+ package = "fragmentation",
+ prefix = "reassembler",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*reassembler",
+ "Linker": "*reassembler",
+ },
+)
+
+go_library(
+ name = "fragmentation",
+ srcs = [
+ "frag_heap.go",
+ "fragmentation.go",
+ "reassembler.go",
+ "reassembler_list.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ ],
+)
+
+go_test(
+ name = "fragmentation_test",
+ size = "small",
+ srcs = [
+ "frag_heap_test.go",
+ "fragmentation_test.go",
+ "reassembler_test.go",
+ ],
+ library = ":fragmentation",
+ deps = ["//pkg/tcpip/buffer"],
+)
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
new file mode 100644
index 000000000..0b570d25a
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+type fragment struct {
+ offset uint16
+ vv buffer.VectorisedView
+}
+
+type fragHeap []fragment
+
+func (h *fragHeap) Len() int {
+ return len(*h)
+}
+
+func (h *fragHeap) Less(i, j int) bool {
+ return (*h)[i].offset < (*h)[j].offset
+}
+
+func (h *fragHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *fragHeap) Push(x interface{}) {
+ *h = append(*h, x.(fragment))
+}
+
+func (h *fragHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
+
+// reassemble empties the heap and returns a VectorisedView
+// containing a reassembled version of the fragments inside the heap.
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+ curr := heap.Pop(h).(fragment)
+ views := curr.vv.Views()
+ size := curr.vv.Size()
+
+ if curr.offset != 0 {
+ return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ }
+
+ for h.Len() > 0 {
+ curr := heap.Pop(h).(fragment)
+ if int(curr.offset) < size {
+ curr.vv.TrimFront(size - int(curr.offset))
+ } else if int(curr.offset) > size {
+ return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ }
+ size += curr.vv.Size()
+ views = append(views, curr.vv.Views()...)
+ }
+ return buffer.NewVectorisedView(size, views), nil
+}
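+
+// An illustrative trace (mirroring the "Overlapping in-order" case in the
+// tests, where vv is the test helper that builds a VectorisedView): a later
+// fragment's overlapping prefix is trimmed before its views are appended.
+//
+//    h := make(fragHeap, 0, 8)
+//    heap.Init(&h)
+//    heap.Push(&h, fragment{offset: 0, vv: vv(2, "01")})
+//    heap.Push(&h, fragment{offset: 1, vv: vv(2, "12")})
+//    out, _ := h.reassemble() // out holds "01" + "2", size 3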
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
new file mode 100644
index 000000000..9ececcb9f
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+var reassembleTestCases = []struct {
+ comment string
+ in []fragment
+ want buffer.VectorisedView
+}{
+ {
+ comment: "Non-overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Non-overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Duplicated packets",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(1, "0"),
+ },
+ {
+ comment: "Overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(2, "01")},
+ {offset: 1, vv: vv(2, "12")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(2, "12")},
+ {offset: 0, vv: vv(2, "01")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping subset in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(3, "012")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(3, "012"),
+ },
+ {
+ comment: "Overlapping subset out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(3, "012")},
+ },
+ want: vv(3, "012"),
+ },
+}
+
+func TestReassemble(t *testing.T) {
+ for _, c := range reassembleTestCases {
+ t.Run(c.comment, func(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ for _, f := range c.in {
+ heap.Push(&h, f)
+ }
+ got, err := h.reassemble()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, c.want) {
+ t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
+ }
+ })
+ }
+}
+
+func TestReassembleFailsForNonZeroOffset(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when the first packet had offset != 0")
+ }
+}
+
+func TestReassembleFailsForHoles(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
+ heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when there was a hole in the packet")
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
new file mode 100644
index 000000000..2982450f8
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -0,0 +1,144 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791 and RFC 815.
+package fragmentation
+
+import (
+ "fmt"
+ "log"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// DefaultReassembleTimeout is based on the Linux stack: net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
+
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the level to which memory use is reduced, by dropping
+// older fragmented packets, once HighFragThreshold is exceeded. It's
+// important to keep enough room for newer packets to be reassembled, so this
+// needs to be sufficiently lower than HighFragThreshold. Linux uses a
+// default value of 3 MB. See net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[uint32]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal
+// data structures is not accounted for). Fragments are dropped when the
+// limit is reached.
+//
+// lowMemoryLimit specifies the level to which memory consumption is
+// reduced, by dropping fragments, after highMemoryLimit has been reached.
+//
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
+ reassemblers: make(map[uint32]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ }
+}
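+
+// A construction sketch (illustrative only), using the exported defaults
+// defined above:
+//
+//    f := NewFragmentation(HighFragThreshold, LowFragThreshold,
+//        DefaultReassembleTimeout)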
+
+// Process processes an incoming fragment belonging to an ID and returns a
+// complete packet when all the packets belonging to that ID have been received.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if ok && r.tooOld(f.timeout) {
+ // This is very likely to be an id-collision or someone performing a slow-rate attack.
+ f.release(r)
+ ok = false
+ }
+ if !ok {
+ r = newReassembler(id)
+ f.reassemblers[id] = r
+ f.rList.PushFront(r)
+ }
+ f.mu.Unlock()
+
+ res, done, consumed, err := r.process(first, last, more, vv)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r)
+ f.mu.Unlock()
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ }
+ f.mu.Lock()
+ f.size += consumed
+ if done {
+ f.release(r)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.size > f.highLimit {
+ for f.size > f.lowLimit {
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
+ f.release(tail)
+ }
+ }
+ f.mu.Unlock()
+ return res, done, nil
+}
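+
+// An illustrative two-fragment exchange (a sketch mirroring the "One ID" case
+// in the tests; assumes f is a *Fragmentation and vv1, vv2 hold 2-byte
+// fragment payloads; first and last are byte offsets, and more flags whether
+// further fragments follow):
+//
+//    _, done, _ := f.Process(id, 0, 1, true, vv1) // done == false, fragment buffered
+//    out, done, err := f.Process(id, 2, 3, false, vv2)
+//    // done == true; out now spans both fragments ("01" + "23" in the test).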
+
+func (f *Fragmentation) release(r *reassembler) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.size -= r.size
+ if f.size < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+ f.size = 0
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
new file mode 100644
index 000000000..72c0f53be
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -0,0 +1,165 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// vv is a helper to build VectorisedView from different strings.
+func vv(size int, pieces ...string) buffer.VectorisedView {
+ views := make([]buffer.View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ return buffer.NewVectorisedView(size, views)
+}
+
+type processInput struct {
+ id uint32
+ first uint16
+ last uint16
+ more bool
+ vv buffer.VectorisedView
+}
+
+type processOutput struct {
+ vv buffer.VectorisedView
+ done bool
+}
+
+var processTestCases = []struct {
+ comment string
+ in []processInput
+ out []processOutput
+}{
+ {
+ comment: "One ID",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+ {
+ comment: "Two IDs",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "ab", "cd"), done: true},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+}
+
+func TestFragmentationProcess(t *testing.T) {
+ for _, c := range processTestCases {
+ t.Run(c.comment, func(t *testing.T) {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ for i, in := range c.in {
+ vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ }
+ if !reflect.DeepEqual(vv, c.out[i].vv) {
+ t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
+ }
+ if done != c.out[i].done {
+ t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
+ }
+ if c.out[i].done {
+ if _, ok := f.reassemblers[in.id]; ok {
+ t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
+ }
+ for n := f.rList.Front(); n != nil; n = n.Next() {
+ if n.id == in.id {
+ t.Errorf("Process(%d) did not remove buffer from rList", i)
+ }
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestReassemblingTimeout(t *testing.T) {
+ timeout := time.Millisecond
+ f := NewFragmentation(1024, 512, timeout)
+ // Send first fragment with id = 0, first = 0, last = 0, and more = true.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Sleep more than the timeout.
+ time.Sleep(2 * timeout)
+ // Send another fragment that completes a packet.
+ // However, no packet should be reassembled because the fragment arrived after the timeout.
+ _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ if err != nil {
+ t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ }
+ if done {
+ t.Errorf("Fragmentation does not respect the reassembling timeout.")
+ }
+}
+
+func TestMemoryLimits(t *testing.T) {
+ f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send first fragment with id = 1.
+ f.Process(1, 0, 0, true, vv(1, "1"))
+ // Send first fragment with id = 2.
+ f.Process(2, 0, 0, true, vv(1, "2"))
+
+ // Send first fragment with id = 3. This should cause id = 0 and id = 1 to be
+ // evicted.
+ f.Process(3, 0, 0, true, vv(1, "3"))
+
+ if _, ok := f.reassemblers[0]; ok {
+ t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[1]; ok {
+ t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[3]; !ok {
+ t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
+ }
+}
+
+func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
+ f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send the same packet again.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+
+ got := f.size
+ want := 1
+ if got != want {
+ t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
new file mode 100644
index 000000000..0a83d81f2
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+type hole struct {
+ first uint16
+ last uint16
+ deleted bool
+}
+
+type reassembler struct {
+ reassemblerEntry
+ id uint32
+ size int
+ mu sync.Mutex
+ holes []hole
+ deleted int
+ heap fragHeap
+ done bool
+ creationTime time.Time
+}
+
+func newReassembler(id uint32) *reassembler {
+ r := &reassembler{
+ id: id,
+ holes: make([]hole, 0, 16),
+ deleted: 0,
+ heap: make(fragHeap, 0, 8),
+ creationTime: time.Now(),
+ }
+ r.holes = append(r.holes, hole{
+ first: 0,
+ last: math.MaxUint16,
+ deleted: false})
+ return r
+}
+
+// updateHoles updates the list of holes for an incoming fragment and
+// returns true iff the fragment filled at least part of an existing hole.
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+ used := false
+ for i := range r.holes {
+ if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+ continue
+ }
+ used = true
+ r.deleted++
+ r.holes[i].deleted = true
+ if first > r.holes[i].first {
+ r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+ }
+ if last < r.holes[i].last && more {
+ r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+ }
+ }
+ return used
+}
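+
+// An illustrative trace (mirroring the reassembler tests): starting from the
+// single hole [0, MaxUint16], a fragment in the middle splits it in two.
+//
+//    r := newReassembler(0)
+//    r.updateHoles(1, 2, true /* more */)
+//    // r.holes: {[0, MaxUint16] deleted, [0, 0], [3, MaxUint16]}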
+
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ consumed := 0
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return buffer.VectorisedView{}, false, consumed, nil
+ }
+ if r.updateHoles(first, last, more) {
+ // We store the incoming packet only if it filled some holes.
+ heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
+ consumed = vv.Size()
+ r.size += consumed
+ }
+ // Check if all the holes have been deleted and we are ready to reassemble.
+ if r.deleted < len(r.holes) {
+ return buffer.VectorisedView{}, false, consumed, nil
+ }
+ res, err := r.heap.reassemble()
+ if err != nil {
+ return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
+ }
+ return res, true, consumed, nil
+}
+
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+ return time.Since(r.creationTime) > timeout
+}
+
+func (r *reassembler) checkDoneOrMark() bool {
+ r.mu.Lock()
+ prev := r.done
+ r.done = true
+ r.mu.Unlock()
+ return prev
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
new file mode 100644
index 000000000..7eee0710d
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "math"
+ "reflect"
+ "testing"
+)
+
+type updateHolesInput struct {
+ first uint16
+ last uint16
+ more bool
+}
+
+var holesTestCases = []struct {
+ comment string
+ in []updateHolesInput
+ want []hole
+}{
+ {
+ comment: "No fragments. Expected holes: {[0 -> inf]}.",
+ in: []updateHolesInput{},
+ want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
+ },
+ {
+ comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ {first: 3, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment at the end. Expected holes: {[0, 0]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 1, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 2, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 3, last: math.MaxUint16, deleted: true},
+ },
+ },
+}
+
+func TestUpdateHoles(t *testing.T) {
+ for _, c := range holesTestCases {
+ r := newReassembler(0)
+ for _, i := range c.in {
+ r.updateHoles(i.first, i.last, i.more)
+ }
+ if !reflect.DeepEqual(r.holes, c.want) {
+ t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
new file mode 100644
index 000000000..872165866
--- /dev/null
+++ b/pkg/tcpip/network/hash/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hash",
+ srcs = ["hash.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/rand",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
new file mode 100644
index 000000000..8f65713c5
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hash contains utility functions for hashing.
+package hash
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+var hashIV = RandN32(1)[0]
+
+// RandN32 generates a slice of n cryptographic random 32-bit numbers.
+func RandN32(n int) []uint32 {
+ b := make([]byte, 4*n)
+ if _, err := rand.Read(b); err != nil {
+ panic("unable to get random numbers: " + err.Error())
+ }
+ r := make([]uint32, n)
+ for i := range r {
+ r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
+ }
+ return r
+}
+
+// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
+// from Linux.
+func Hash3Words(a, b, c, initval uint32) uint32 {
+ const iv = 0xdeadbeef + (3 << 2)
+ initval += iv
+
+ a += initval
+ b += initval
+ c += initval
+
+ c ^= b
+ c -= rol32(b, 14)
+ a ^= c
+ a -= rol32(c, 11)
+ b ^= a
+ b -= rol32(a, 25)
+ c ^= b
+ c -= rol32(b, 16)
+ a ^= c
+ a -= rol32(c, 4)
+ b ^= a
+ b -= rol32(a, 14)
+ c ^= b
+ c -= rol32(b, 24)
+
+ return c
+}
+
+// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
+func IPv4FragmentHash(h header.IPv4) uint32 {
+ x := uint32(h.ID())<<16 | uint32(h.Protocol())
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(x, y, z, hashIV)
+}
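+
+// Illustratively, for a parsed IPv4 fragment h the three words hashed above
+// are (ID<<16 | Protocol, source address, destination address), so all
+// fragments of one datagram (same ID, protocol, and endpoints) fall into the
+// same reassembly bucket.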
+
+// IPv6FragmentHash computes the hash of the ipv6 fragment.
+// Unlike IPv4, the protocol is not used to compute the hash; the id comes
+// from the IPv6 Fragment extension header rather than the base header.
+// RFC 2460 (sec 4.5) is not very explicit on this aspect. For reference,
+// Linux also ignores the protocol when computing the hash (inet6_hash_frag).
+func IPv6FragmentHash(h header.IPv6, id uint32) uint32 {
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(id, y, z, hashIV)
+}
+
+func rol32(v, shift uint32) uint32 {
+ return (v << shift) | (v >> ((-shift) & 31))
+}
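+
+// For example, rol32(0x80000000, 1) == 1: the high bit wraps around to bit 0.
+// (-shift) & 31 evaluates to (32 - shift) mod 32, so the bits shifted out on
+// the left reappear on the right.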
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
new file mode 100644
index 000000000..7c8fb3e0a
--- /dev/null
+++ b/pkg/tcpip/network/ip_test.go
@@ -0,0 +1,673 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localIpv4Addr = "\x0a\x00\x00\x01"
+ localIpv4PrefixLen = 24
+ remoteIpv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ localIpv6PrefixLen = 120
+ remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+)
+
+// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
+// The former is used to pretend that it's a link endpoint so that we can
+// inspect packets written by the network endpoints. The latter is used to
+// pretend that it's the network stack so that it can inspect incoming packets
+// that have been handled by the network endpoints.
+//
+// Packets are checked by comparing their fields/values against the expected
+// values stored in the test object itself.
+type testObject struct {
+ t *testing.T
+ protocol tcpip.TransportProtocolNumber
+ contents []byte
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ v4 bool
+ typ stack.ControlType
+ extra uint32
+
+ dataCalls int
+ controlCalls int
+}
+
+// checkValues verifies that the transport protocol, data contents, src & dst
+// addresses of a packet match what's expected. If any field doesn't match, the
+// test fails.
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
+ v := vv.ToView()
+ if protocol != t.protocol {
+ t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
+ }
+
+ if srcAddr != t.srcAddr {
+ t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
+ }
+
+ if dstAddr != t.dstAddr {
+ t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
+ }
+
+ if len(v) != len(t.contents) {
+ t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
+ }
+
+ for i := range t.contents {
+ if t.contents[i] != v[i] {
+ t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
+ }
+ }
+}
+
+// DeliverTransportPacket is called by network endpoints after parsing incoming
+// packets. This is used by the test object to verify that the results of the
+// parsing are expected.
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) {
+ t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
+ t.dataCalls++
+}
+
+// DeliverTransportControlPacket is called by network endpoints after parsing
+// incoming control (ICMP) packets. This is used by the test object to verify
+// that the results of the parsing are expected.
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ t.checkValues(trans, pkt.Data, remote, local)
+ if typ != t.typ {
+ t.t.Errorf("typ = %v, want %v", typ, t.typ)
+ }
+ if extra != t.extra {
+ t.t.Errorf("extra = %v, want %v", extra, t.extra)
+ }
+ t.controlCalls++
+}
+
+// Attach is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (*testObject) IsAttached() bool {
+ return true
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
+// matches the Linux loopback MTU.
+func (*testObject) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testObject) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
+// WritePacket is called by network endpoints after producing a packet and
+// writing it to the link endpoint. This is used by the test object to verify
+// that the produced packet is as expected.
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ var prot tcpip.TransportProtocolNumber
+ var srcAddr tcpip.Address
+ var dstAddr tcpip.Address
+
+ if t.v4 {
+ h := header.IPv4(pkt.Header.View())
+ prot = tcpip.TransportProtocolNumber(h.Protocol())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+
+ } else {
+ h := header.IPv6(pkt.Header.View())
+ prot = tcpip.TransportProtocolNumber(h.NextHeader())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+ }
+ t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+ s.CreateNIC(1, loopback.New())
+ s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ Gateway: ipv4Gateway,
+ NIC: 1,
+ }})
+
+ return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
+}
+
+func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+ s.CreateNIC(1, loopback.New())
+ s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: ipv6Gateway,
+ NIC: 1,
+ }})
+
+ return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+}
+
+func buildDummyStack() *stack.Stack {
+ return stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+}
+
+func TestIPv4Send(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = localIpv4Addr
+ o.dstAddr = remoteIpv4Addr
+ o.contents = payload
+
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv4Receive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give the packet to the ipv4 endpoint; the dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = view[header.IPv4MinimumSize:totalLen]
+
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+ pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv4ReceiveControl(t *testing.T) {
+ const mtu = 0xbeef - header.IPv4MinimumSize
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset uint16
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
+ {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8},
+ {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
+ }
+ r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv4 header.
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(view) - c.trunc),
+ TTL: 20,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: "\x0a\x00\x00\xbb",
+ DstAddr: localIpv4Addr,
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
+ icmp.SetType(header.ICMPv4DstUnreachable)
+ icmp.SetCode(c.code)
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
+
+ // Create the inner IPv4 header.
+ ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: 100,
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: c.fragmentOffset,
+ SrcAddr: localIpv4Addr,
+ DstAddr: remoteIpv4Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give the packet to the IPv4 endpoint; the dispatcher will validate
+ // that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
+
+func TestIPv4FragmentationReceive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 24
+
+ frag1 := buffer.NewView(totalLen)
+ ip1 := header.IPv4(frag1)
+ ip1.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 0,
+ Flags: header.IPv4FlagMoreFragments,
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag1[i] = uint8(i)
+ }
+
+ frag2 := buffer.NewView(totalLen)
+ ip2 := header.IPv4(frag2)
+ ip2.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 24,
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag2[i] = uint8(i)
+ }
+
+ // Give the packet to the ipv4 endpoint; the dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+
+ // Send first segment.
+ pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
+ if o.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ }
+
+ // Send second segment.
+ pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6Send(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = localIpv6Addr
+ o.dstAddr = remoteIpv6Addr
+ o.contents = payload
+
+ r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv6Receive(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv6MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: remoteIpv6Addr,
+ DstAddr: localIpv6Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give the packet to the ipv6 endpoint; the dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv6Addr
+ o.dstAddr = localIpv6Addr
+ o.contents = view[header.IPv6MinimumSize:totalLen]
+
+ r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+
+ pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6ReceiveControl(t *testing.T) {
+ newUint16 := func(v uint16) *uint16 { return &v }
+
+ const mtu = 0xffff
+ const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset *uint16
+ typ header.ICMPv6Type
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
+ {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
+ {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
+ {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
+ }
+ r, err := buildIPv6Route(
+ localIpv6Addr,
+ "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ defer ep.Close()
+
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
+ if c.fragmentOffset != nil {
+ dataOffset += header.IPv6FragmentHeaderSize
+ }
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv6 header.
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIpv6Addr,
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
+ icmp.SetType(c.typ)
+ icmp.SetCode(c.code)
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
+
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: localIpv6Addr,
+ DstAddr: remoteIpv6Addr,
+ })
+
+ // Build the fragmentation header if needed.
+ if c.fragmentOffset != nil {
+ ip.SetNextHeader(header.IPv6FragmentHeader)
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
+ frag.Encode(&header.IPv6FragmentFields{
+ NextHeader: 10,
+ FragmentOffset: *c.fragmentOffset,
+ M: true,
+ Identification: 0x12345678,
+ })
+ }
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give the packet to the IPv6 endpoint; the dispatcher will validate
+ // that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv6Addr
+ o.dstAddr = localIpv6Addr
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ // Set ICMPv6 checksum.
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
+
+// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
+// after truncation, is large enough to hold a network header, it makes part of
+// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
+// becomes Data.
+func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
+ v := view[:len(view)-trunc]
+ if len(v) < netHdrLen {
+ return &stack.PacketBuffer{Data: v.ToVectorisedView()}
+ }
+ return &stack.PacketBuffer{
+ NetworkHeader: v[:netHdrLen],
+ Data: v[netHdrLen:].ToVectorisedView(),
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
new file mode 100644
index 000000000..78420d6e6
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ipv4",
+ srcs = [
+ "icmp.go",
+ "ipv4.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "ipv4_test",
+ size = "small",
+ srcs = ["ipv4_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
new file mode 100644
index 000000000..1b67aa066
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv4(h)
+
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if hdr.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ hlen := int(hdr.HeaderLength())
+ if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ pkt.Data.TrimFront(hlen)
+ p := hdr.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+}
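+
+// As an illustrative trace (assuming a UDP datagram was previously sent from
+// this endpoint): an incoming ICMPv4 Port Unreachable embeds the original
+// IPv4 header plus at least the first 8 bytes of the datagram, which covers
+// the UDP header. handleControl trims the embedded IPv4 header from pkt.Data
+// and DeliverTransportControlPacket then routes ControlPortUnreachable to
+// the matching UDP endpoint.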
+
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
+ stats := r.Stats()
+ received := stats.ICMP.V4PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv4(v)
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ received.Echo.Increment()
+
+ // Only send a reply if the checksum is valid.
+ wantChecksum := h.Checksum()
+ // Reset the checksum field to 0 so we can calculate the proper
+ // checksum. We'll have to restore it before we hand the packet
+ // off.
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ if gotChecksum != wantChecksum {
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ received.Invalid.Increment()
+ return
+ }
+
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{
+ Data: pkt.Data.Clone(nil),
+ NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
+ })
+
+ vv := pkt.Data.Clone(nil)
+ vv.TrimFront(header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
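+ // Note: pkt below shadows the incoming *stack.PacketBuffer; from here
+ // on it refers to the ICMPv4 header of the echo reply being built.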
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv4EchoReply)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
+ sent := stats.ICMP.V4PacketsSent
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
+ Header: hdr,
+ Data: vv,
+ TransportHeader: buffer.View(pkt),
+ }); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv4EchoReply:
+ received.EchoReply.Increment()
+
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+
+ case header.ICMPv4DstUnreachable:
+ received.DstUnreachable.Increment()
+
+ pkt.Data.TrimFront(header.ICMPv4MinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(h.MTU())
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ }
+
+ case header.ICMPv4SrcQuench:
+ received.SrcQuench.Increment()
+
+ case header.ICMPv4Redirect:
+ received.Redirect.Increment()
+
+ case header.ICMPv4TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv4ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv4Timestamp:
+ received.Timestamp.Increment()
+
+ case header.ICMPv4TimestampReply:
+ received.TimestampReply.Increment()
+
+ case header.ICMPv4InfoRequest:
+ received.InfoRequest.Increment()
+
+ case header.ICMPv4InfoReply:
+ received.InfoReply.Increment()
+
+ default:
+ received.Invalid.Increment()
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
new file mode 100644
index 000000000..b1776e5ee
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -0,0 +1,594 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv4 contains the implementation of the ipv4 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
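+//
+// A minimal sketch of that wiring (the transport protocol shown is
+// illustrative):
+//
+//   s := stack.New(stack.Options{
+//       NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
+//       TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+//   })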
+package ipv4
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolNumber is the ipv4 protocol number.
+ ProtocolNumber = header.IPv4ProtocolNumber
+
+ // MaxTotalSize is the maximum size that can be encoded in the 16-bit
+ // TotalLength field of the IPv4 header.
+ MaxTotalSize = 0xffff
+
+ // DefaultTTL is the default time-to-live value for this endpoint.
+ DefaultTTL = 64
+
+ // buckets is the number of identifier buckets.
+ buckets = 2048
+)
+
+type endpoint struct {
+ nicID tcpip.NICID
+ id stack.NetworkEndpointID
+ prefixLen int
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
+ protocol *protocol
+ stack *stack.Stack
+}
+
+// NewEndpoint creates a new ipv4 endpoint.
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ e := &endpoint{
+ nicID: nicID,
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
+ stack: st,
+ }
+
+ return e, nil
+}
+
+// DefaultTTL returns the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU
+// minus the IPv4 minimum header size.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// ID returns the ipv4 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
+func (e *endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
+// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
+// write. It assumes that the IP header is entirely in pkt.Header but does not
+// assume that only the IP header is in pkt.Header. It assumes that the input
+// packet's stated length matches the length of the header+payload. mtu
+// includes the IP header and options. This does not support the DontFragment
+// IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
+ // This packet is too big, it needs to be fragmented.
+ ip := header.IPv4(pkt.Header.View())
+ flags := ip.Flags()
+
+ // Update mtu to take into account the header, which will exist in all
+ // fragments anyway.
+ innerMTU := mtu - int(ip.HeaderLength())
+
+ // Round the MTU down to align to 8 bytes: fragment offsets are counted
+ // in 8-octet units, so every fragment except the last must carry a
+ // payload that is a multiple of 8 bytes. Then calculate the number of
+ // fragments, sizing them as in RFC 791.
+ innerMTU &^= 7
+ n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
+
+ outerMTU := innerMTU + int(ip.HeaderLength())
+ offset := ip.FragmentOffset()
+ originalAvailableLength := pkt.Header.AvailableLength()
+ for i := 0; i < n; i++ {
+ // Where possible, the first fragment that is sent has the same
+ // pkt.Header.UsedLength() as the input packet. The link-layer
+ // endpoint may depend on this for looking at, e.g., L4 headers.
+ h := ip
+ if i > 0 {
+ pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
+ copy(h, ip[:ip.HeaderLength()])
+ }
+ if i != n-1 {
+ h.SetTotalLength(uint16(outerMTU))
+ h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
+ } else {
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
+ h.SetFlagsFragmentOffset(flags, offset)
+ }
+ h.SetChecksum(0)
+ h.SetChecksum(^h.CalculateChecksum())
+ offset += uint16(innerMTU)
+ if i > 0 {
+ newPayload := pkt.Data.Clone(nil)
+ newPayload.CapLength(innerMTU)
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ pkt.Data.TrimFront(newPayload.Size())
+ continue
+ }
+ // Special handling for the first fragment because it comes
+ // from the header.
+ if outerMTU >= pkt.Header.UsedLength() {
+ // This fragment can fit all of pkt.Header and possibly
+ // some of pkt.Data, too.
+ newPayload := pkt.Data.Clone(nil)
+ newPayloadLength := outerMTU - pkt.Header.UsedLength()
+ newPayload.CapLength(newPayloadLength)
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ pkt.Data.TrimFront(newPayloadLength)
+ } else {
+ // The fragment is too small to fit all of pkt.Header.
+ startOfHdr := pkt.Header
+ startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
+ emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
+ Header: startOfHdr,
+ Data: emptyVV,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ // Add the unused bytes of pkt.Header into the pkt.Data
+ // that remains to be sent.
+ restOfHdr := pkt.Header.View()[outerMTU:]
+ tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
+ tmp.Append(pkt.Data)
+ pkt.Data = tmp
+ }
+ }
+ return nil
+}
+
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ length := uint16(hdr.UsedLength() + payloadSize)
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
+ // datagrams. Since the DF bit is never being set here, all datagrams
+ // are non-atomic and need an ID.
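+ // IDs are spread across a set of buckets selected by a keyed hash of
+ // the route, which reduces contention between concurrent senders and
+ // makes per-flow ID sequences harder to predict off-path.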
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: length,
+ ID: uint16(id),
+ TTL: params.TTL,
+ TOS: params.TOS,
+ Protocol: uint8(params.Protocol),
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle the
+ // packet based on the destination address and do not send it to the
+ // link layer.
+ // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
+ // only NATted packets, but removing this check short circuits broadcasts
+ // before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ if err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
+
+ if r.Loop&stack.PacketLoop != 0 {
+ loopedR := r.MakeLoopedRoute()
+ e.HandlePacket(&loopedR, pkt)
+ loopedR.Release()
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
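+ // Fragment in software only when the packet exceeds the link MTU and
+ // there is no GSO support to segment it further down the stack.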
+ if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ return nil
+}
+
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
+ panic("multiple packets in local loop")
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return pkts.Len(), nil
+ }
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+ }
+
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: if no packets are to be dropped or NATed, we can
+ // invoke the faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+
+ // Slow path: since we are dropping some packets in the batch, degrade
+ // to emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ n++
+ }
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, nil
+}
+
+// WriteHeaderIncludedPacket writes a packet already containing a network
+// header through the given route.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required
+ // checks.
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ ip := header.IPv4(h)
+ if !ip.IsValid(pkt.Data.Size()) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ // Always set the total length.
+ ip.SetTotalLength(uint16(pkt.Data.Size()))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination,
+ // it will be part of the route.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Set the packet ID when zero.
+ if ip.ID() == 0 {
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for
+ // non-atomic datagrams, so assign an ID to all such datagrams
+ // according to the definition given in RFC 6864 section 4.
+ if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ }
+ }
+
+ // Always set the checksum.
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ if r.Loop&stack.PacketLoop != 0 {
+ e.HandlePacket(r, pkt.Clone())
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
+
+ ip = ip[:ip.HeaderLength()]
+ pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
+ pkt.Data.TrimFront(int(ip.HeaderLength()))
+ return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+}
+
+// HandlePacket is called by the link layer when new ipv4 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv4(pkt.NetworkHeader)
+ if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+
+ if h.More() || h.FragmentOffset() != 0 {
+ if pkt.Data.Size()+len(pkt.TransportHeader) == 0 {
+ // Drop the packet as it's marked as a fragment but has
+ // no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ // The packet is a fragment, let's try to reassemble it.
+ last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
+ // Drop the packet if the fragmentOffset is incorrect, i.e. the
+ // combination of fragmentOffset and pkt.Data.Size() causes a
+ // wrap-around, resulting in last being less than the offset.
+ if last < h.FragmentOffset() {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ var ready bool
+ var err error
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ if !ready {
+ return
+ }
+ }
+ p := h.TransportProtocol()
+ if p == header.ICMPv4ProtocolNumber {
+ pkt.NetworkHeader.CapLength(int(h.HeaderLength()))
+ e.handleICMP(r, pkt)
+ return
+ }
+ r.Stats().IP.PacketsDelivered.Increment()
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (e *endpoint) Close() {}
+
+type protocol struct {
+ ids []uint32
+ hashIV uint32
+
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
+}
+
+// Number returns the ipv4 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv4 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv4MinimumSize
+}
+
+// DefaultPrefixLen returns the IPv4 default prefix length.
+func (p *protocol) DefaultPrefixLen() int {
+ return header.IPv4AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv4(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// SetOption implements NetworkProtocol.SetOption.
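+// This is typically reached via Stack.SetNetworkProtocolOption, e.g.
+// s.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(32)).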
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
+}
+
+// Close implements stack.NetworkProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.NetworkProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // If there are options, pull those into hdr as well.
+ if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() {
+ hdr, ok = pkt.Data.PullUp(headerLen)
+ if !ok {
+ panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen))
+ }
+ ipHdr = header.IPv4(hdr)
+ }
+
+ // If this is a fragment, don't bother parsing the transport header.
+ parseTransportHeader := true
+ if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
+ parseTransportHeader = false
+ }
+
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
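+ // Cap the payload at the length declared in the IPv4 header; any bytes
+ // beyond TotalLength are link-layer padding.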
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return ipHdr.TransportProtocol(), parseTransportHeader, true
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the
+// link-layer payload MTU.
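+// For example, a 1500-byte link MTU leaves 1480 bytes for the IPv4 payload
+// (1500 minus the 20-byte minimum IPv4 header).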
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address, the transport protocol number, and a random initial
+// value (generated once on initialization) to generate the hash.
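+// Each 4-byte IPv4 address is packed into a uint32 (low byte first) before
+// being mixed with hash.Hash3Words.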
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ t := r.LocalAddress
+ a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = r.RemoteAddress
+ b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return hash.Hash3Words(a, b, uint32(protocol), hashIV)
+}
+
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
+
+ // Randomly initialize hashIV and the ids.
+ r := hash.RandN32(1 + buckets)
+ for i := range ids {
+ ids[i] = r[i]
+ }
+ hashIV := r[buckets]
+
+ return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
new file mode 100644
index 000000000..11e579c4b
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -0,0 +1,745 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "bytes"
+ "encoding/hex"
+ "math/rand"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestExcludeBroadcast(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+
+ const defaultMTU = 65536
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ }})
+
+ randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
+
+ var wq waiter.Queue
+ t.Run("WithoutPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Cannot connect using a broadcast address as the source.
+ if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ }
+
+ // However, we can bind to a broadcast address to listen.
+ if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil {
+ t.Errorf("Bind failed: %v", err)
+ }
+ })
+
+ t.Run("WithPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Add a valid primary endpoint address; now we can connect.
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ if err := ep.Connect(randomAddr); err != nil {
+ t.Errorf("Connect failed: %v", err)
+ }
+ })
+}
+
+// makeHdrAndPayload generates a randomized packet. hdrLength indicates how much
+// data should already be in the header before WritePacket. extraLength
+// indicates how much extra space should be in the header. The payload is made
+// from many Views of the sizes listed in viewSizes.
+func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
+ hdr := buffer.NewPrependable(hdrLength + extraLength)
+ hdr.Prepend(hdrLength)
+ rand.Read(hdr.View())
+
+ var views []buffer.View
+ totalLength := 0
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ rand.Read(newView)
+ views = append(views, newView)
+ totalLength += s
+ }
+ payload := buffer.NewVectorisedView(totalLength, views)
+ return hdr, payload
+}
+
+// compareFragments compares the contents of all the fragment packets against
+// the contents of the source packet.
+func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
+ t.Helper()
+ // Make a complete array of the sourcePacketInfo packet.
+ source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
+ source = append(source, sourcePacketInfo.Header.View()...)
+ source = append(source, sourcePacketInfo.Data.ToView()...)
+
+ // Make a copy of the IP header, which will be modified in some fields to make
+ // an expected header.
+ sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetFlagsFragmentOffset(0, 0)
+ sourceCopy.SetTotalLength(0)
+ var offset uint16
+ // Build up an array of the bytes sent.
+ var reassembledPayload []byte
+ for i, packet := range packets {
+ // Confirm that the packet is valid.
+ allBytes := packet.Header.View().ToVectorisedView()
+ allBytes.Append(packet.Data)
+ ip := header.IPv4(allBytes.ToView())
+ if !ip.IsValid(len(ip)) {
+ t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ }
+ if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
+ t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ }
+ if got, want := len(ip), int(mtu); got > want {
+ t.Errorf("fragment is too large, got %d want %d", got, want)
+ }
+ if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
+ t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ }
+ if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
+ t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ }
+ if i < len(packets)-1 {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ } else {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ }
+ reassembledPayload = append(reassembledPayload, ip.Payload()...)
+ offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ // Update the length and checksum in the copy to match this
+ // fragment, since those fields legitimately differ per fragment.
+ sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+ if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
+ t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ }
+ }
+ expected := source[source.HeaderLength():]
+ if !bytes.Equal(reassembledPayload, expected) {
+ t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+ }
+}
+
+type errorChannel struct {
+ *channel.Endpoint
+ Ch chan *stack.PacketBuffer
+ packetCollectorErrors []*tcpip.Error
+}
+
+// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
+// will return successive errors from packetCollectorErrors until the list is
+// empty and then return nil each time.
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
+ Ch: make(chan *stack.PacketBuffer, size),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+}
+
+// Drain removes all outbound packets from the channel and counts them.
+func (e *errorChannel) Drain() int {
+ c := 0
+ for {
+ select {
+ case <-e.Ch:
+ c++
+ default:
+ return c
+ }
+ }
+}
+
+// WritePacket stores outbound packets into the channel.
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ select {
+ case e.Ch <- pkt:
+ default:
+ }
+
+ nextError := (*tcpip.Error)(nil)
+ if len(e.packetCollectorErrors) > 0 {
+ nextError = e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ }
+ return nextError
+}
+
+type context struct {
+ stack.Route
+ linkEP *errorChannel
+}
+
+func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+ // Build a stack with a NIC, an address, and a route, backed by an
+ // error-injecting channel endpoint.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return context{
+ Route: r,
+ linkEP: ep,
+ }
+}
+
+func TestFragmentation(t *testing.T) {
+ var manyPayloadViewsSizes [1000]int
+ for i := range manyPayloadViewsSizes {
+ manyPayloadViewsSizes[i] = 7
+ }
+ fragTests := []struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ hdrLength int
+ extraLength int
+ payloadViewsSizes []int
+ expectedFrags int
+ }{
+ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ source := &stack.PacketBuffer{
+ Header: hdr,
+ // Save the source payload because WritePacket will modify it.
+ Data: payload.Clone(nil),
+ }
+ c := buildContext(t, nil, ft.mtu)
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
+ if err != nil {
+ t.Errorf("err got %v, want %v", err, nil)
+ }
+
+ var results []*stack.PacketBuffer
+ L:
+ for {
+ select {
+ case pi := <-c.linkEP.Ch:
+ results = append(results, pi)
+ default:
+ break L
+ }
+ }
+
+ if got, want := len(results), ft.expectedFrags; got != want {
+ t.Errorf("len(result) got %d, want %d", got, want)
+ }
+ if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ }
+ compareFragments(t, results, source, ft.mtu)
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ fragTests := []struct {
+ description string
+ mtu uint32
+ hdrLength int
+ payloadViewsSizes []int
+ packetCollectorErrors []*tcpip.Error
+ }{
+ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
+ c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
+ for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
+ if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
+ t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
+ }
+ }
+ // We only need to check the last error because all the ones before it
+ // are nil.
+ if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
+ t.Errorf("err got %v, want %v", got, want)
+ }
+ if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ }
+ })
+ }
+}
+
+func TestInvalidFragments(t *testing.T) {
+ // These packets carry hand-crafted malformed IPv4 headers (zero or
+ // undersized IHL, bad total lengths, and bad fragment offsets).
+ testCases := []struct {
+ name string
+ packets [][]byte
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ "ihl_totallen_zero_valid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ "ihl_totallen_zero_invalid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ // Total Length of 37 (20 bytes IP header + 17 bytes of
+ // payload).
+ // Frag Offset of 0x1ffe = 8190*8 = 65520,
+ // placing the end of the fragment past 65535.
+ "ihl_totallen_valid_invalid_frag_offset_1",
+ [][]byte{
+ {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ // The following 3 tests were found by running a fuzzer and were
+ // triggering a panic in the IPv4 reassembler code.
+ {
+ "ihl_less_than_ipv4_minimum_size_1",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_2",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_3",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "fragment_with_short_total_len_extra_payload",
+ [][]byte{
+ {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ {
+ "multiple_fragments_with_more_fragments_set_to_false",
+ [][]byte{
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ },
+ 1,
+ 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ const nicID tcpip.NICID = 42
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ },
+ })
+
+ var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
+ var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
+ ep := channel.New(10, 1500, linkAddr)
+ s.CreateNIC(nicID, sniffer.New(ep))
+
+ for _, pkt := range tc.packets {
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ })
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
+ }
+ })
+ }
+}
+
+// TestReceiveFragments feeds fragments in through the incoming packet path to
+// test reassembly.
+func TestReceiveFragments(t *testing.T) {
+ const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
+ const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
+ const nicID = 1
+
+ // udpGen builds and returns a UDP datagram (header plus payload).
+ udpGen := func(payloadLen int, multiplier uint8) buffer.View {
+ payload := buffer.NewView(payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + len(payload)
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
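+ // The UDP checksum (RFC 768) covers a pseudo-header (source and
+ // destination addresses, protocol, and UDP length) in addition to
+ // the UDP header and payload.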
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
+
+ // UDP header plus a 256-byte payload holding byte values 0..255.
+ ipv4Payload1 := udpGen(256, 1)
+ udpPayload1 := ipv4Payload1[header.UDPMinimumSize:]
+ // UDP header plus a 128-byte payload holding even byte values 0..254.
+ ipv4Payload2 := udpGen(128, 2)
+ udpPayload2 := ipv4Payload2[header.UDPMinimumSize:]
+
+ type fragmentData struct {
+ id uint16
+ flags uint8
+ fragmentOffset uint16
+ payload buffer.View
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: 0,
+ fragmentOffset: 0,
+ payload: ipv4Payload1,
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "More fragments without payload",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1,
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Non-zero fragment offset without payload",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: 0,
+ fragmentOffset: 8,
+ payload: ipv4Payload1,
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1[:64],
+ },
+ {
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Second fragment has MoreFlags set",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1[:64],
+ },
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 64,
+ payload: ipv4Payload1[64:],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1[:64],
+ },
+ {
+ id: 2,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1[64:],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1[:64],
+ },
+ {
+ id: 2,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload2[:64],
+ },
+ {
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1[64:],
+ },
+ {
+ id: 2,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ },
+ {
+ name: "Fragment without followup",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1[:64],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Set up a stack and endpoint.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ // Prepare and send the fragments.
+ for _, frag := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize)
+
+ // Serialize IPv4 fixed header.
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)),
+ ID: frag.id,
+ Flags: frag.flags,
+ FragmentOffset: frag.fragmentOffset,
+ TTL: 64,
+ Protocol: uint8(header.UDPProtocolNumber),
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(frag.payload)
+
+ e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, expectedPayload := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
new file mode 100644
index 000000000..3f71fc520
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -0,0 +1,44 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ipv6",
+ srcs = [
+ "icmp.go",
+ "ipv6.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "ipv6_test",
+ size = "small",
+ srcs = [
+ "icmp_test.go",
+ "ipv6_test.go",
+ "ndp_test.go",
+ ],
+ library = ":ipv6",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
new file mode 100644
index 000000000..2ff7eedf4
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,549 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv6(h)
+
+ // We don't use IsValid() here because ICMP only requires that up to
+ // 1280 bytes of the original packet be included. So it's likely that it
+ // is truncated, which would cause IsValid to return false.
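+ // (1280 bytes being the IPv6 minimum link MTU.)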
+ //
+ // Drop packet if it doesn't have the basic IPv6 header or if the
+ // original source address doesn't match the endpoint's address.
+ if hdr.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ // Skip the IP header, then handle the fragmentation header if there
+ // is one.
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
+ p := hdr.TransportProtocol()
+ if p == header.IPv6FragmentHeader {
+ f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
+ if !ok {
+ return
+ }
+ fragHdr := header.IPv6Fragment(f)
+ if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 {
+ // We can't handle fragments that aren't at offset 0
+ // because they don't have the transport headers.
+ return
+ }
+
+ // Skip fragmentation header and find out the actual protocol
+ // number.
+ pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
+ p = fragHdr.TransportProtocol()
+ }
+
+ // Deliver the control packet to the transport endpoint.
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ received := stats.V6PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv6(v)
+ iph := header.IPv6(pkt.NetworkHeader)
+
+ // Validate ICMPv6 checksum before processing the packet.
+ //
+ // This copy is used as extra payload during the checksum calculation.
+ payload := pkt.Data.Clone(nil)
+ payload.TrimFront(len(h))
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ received.Invalid.Increment()
+ return
+ }
+
+ isNDPValid := func() bool {
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
+ //
+ // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the
+ // packet includes a fragmentation header.
+ return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0
+ }
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv6PacketTooBig:
+ received.PacketTooBig.Increment()
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := header.ICMPv6(hdr).MTU()
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+
+ case header.ICMPv6DstUnreachable:
+ received.DstUnreachable.Increment()
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
+ }
+
+ case header.ICMPv6NeighborSolicit:
+ received.NeighborSolicit.Increment()
+ if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the neighbor solicitation, so
+ // payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
+ // NDP messages cannot be fragmented. Also note that in the common case NDP
+ // datagrams are very small and ToView() will not incur allocations.
+ ns := header.NDPNeighborSolicit(payload.ToView())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NS option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ targetAddr := ns.TargetAddress()
+ s := r.Stack()
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now, drop this packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ } else if isTentative {
+ // If the target address is tentative and the source of the packet is a
+ // unicast (specified) address, then the source of the packet is
+ // attempting to perform address resolution on the target. In this case,
+ // the solicitation is silently ignored, as per RFC 4862 section 5.4.3.
+ //
+ // If the target address is tentative and the source of the packet is the
+ // unspecified address (::), then we know another node is also performing
+ // DAD for the same address (since the target address is tentative for us,
+ // we know we are also performing DAD on it). In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
+ // the NS.
+ if r.RemoteAddress == header.IPv6Any {
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ }
+
+ // Do not handle neighbor solicitations targeted to an address that is
+ // tentative on the NIC any further.
+ return
+ }
+
+ // At this point we know that the target address is not tentative on the NIC
+ // so the packet is processed as defined in RFC 4861, as per RFC 4862
+ // section 5.4.3.
+
+ // Is the NS targeting us?
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ return
+ }
+
+ // If the NS message contains the Source Link-Layer Address option, update
+ // the link address cache with the value of the option.
+ //
+ // TODO(b/148429853): Properly process the NS message and do Neighbor
+ // Unreachability Detection.
+ var sourceLinkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
+ }
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ // No RFCs define what to do when an NS message has multiple Source
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(sourceLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ sourceLinkAddr = opt.EthernetAddress()
+ }
+ }
+
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
+
+ // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, on link layers that have addresses this option MUST be
+ // included in multicast solicitations and SHOULD be included in unicast
+ // solicitations.
+ if len(sourceLinkAddr) == 0 {
+ if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ }
+ } else if unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ } else {
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
+ }
+
+ // ICMPv6 Neighbor Solicit messages are always sent to
+ // specially crafted IPv6 multicast addresses. As a result, the
+ // route we end up with here has as its LocalAddress such a
+ // multicast address. It would be nonsense to claim that our
+ // source address is a multicast address, so we manually set
+ // the source address to the target address requested in the
+ // solicit message. Since that requires mutating the route, we
+ // must first clone it.
+ r := r.Clone()
+ defer r.Release()
+ r.LocalAddress = targetAddr
+
+ // As per RFC 4861 section 7.2.4, if the source of the solicitation is
+ // the unspecified address, the node MUST set the Solicited flag to zero and
+ // multicast the advertisement to the all-nodes address.
+ solicited := true
+ if unspecifiedSource {
+ solicited = false
+ r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ }
+
+ // If the NS has a source link-layer option, use the link address it
+ // specifies as the remote link address for the response instead of the
+ // source link address of the packet.
+ //
+ // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
+ // address cache for the right destination link address instead of manually
+ // patching the route with the remote link address if one is specified in a
+ // Source Link-Layer Address option.
+ if len(sourceLinkAddr) != 0 {
+ r.RemoteLinkAddress = sourceLinkAddr
+ }
+
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na.SetSolicitedFlag(solicited)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ // RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+ //
+ // 7.1.2. Validation of Neighbor Advertisements
+ //
+ // The IP Hop Limit field has a value of 255, i.e., the packet
+ // could not possibly have been forwarded by a router.
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.NeighborAdvert.Increment()
+
+ case header.ICMPv6NeighborAdvert:
+ received.NeighborAdvert.Increment()
+ if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the neighbor advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ na := header.NDPNeighborAdvert(payload.ToView())
+ it, err := na.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ targetAddr := na.TargetAddress()
+ stack := r.Stack()
+
+ if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now, short-circuit this packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ } else if isTentative {
+ // We just got an NA from a node that owns an address we are performing
+ // DAD on, implying the address is not unique. In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
+ // the NDP NA.
+ stack.DupTentativeAddrDetected(e.nicID, targetAddr)
+ return
+ }
+
+ // At this point we know that the target address is not tentative on the
+ // NIC. However, the target address may still be assigned to the NIC but not
+ // tentative (it could be permanent). Such a scenario is beyond the scope of
+ // RFC 4862. As such, we simply ignore such a scenario for now and proceed
+ // as normal.
+ //
+ // TODO(b/143147598): Handle the scenario described above. Also inform the
+ // netstack integration that a duplicate address was detected outside of
+ // DAD.
+
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
+ //
+ // TODO(b/148429853): Properly process the NA message and do Neighbor
+ // Unreachability Detection.
+ var targetLinkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
+ }
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPTargetLinkLayerAddressOption:
+ // No RFCs define what to do when an NA message has multiple Target
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(targetLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ targetLinkAddr = opt.EthernetAddress()
+ }
+ }
+
+ if len(targetLinkAddr) != 0 {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ }
+
+ case header.ICMPv6EchoRequest:
+ received.EchoRequest.Increment()
+ icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(packet, icmpHdr)
+ packet.SetType(header.ICMPv6EchoReply)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ Data: pkt.Data,
+ }); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv6EchoReply:
+ received.EchoReply.Increment()
+ if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
+
+ case header.ICMPv6TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv6ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv6RouterSolicit:
+ received.RouterSolicit.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
+
+ // Is the NDP payload of sufficient size to hold a Router
+ // Advertisement?
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
+ //
+ // Validate the RA as per RFC 4861 section 6.1.2.
+ //
+
+ // Is the IP Source Address a link-local address?
+ if !header.IsV6LinkLocalAddress(routerAddr) {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the router advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ ra := header.NDPRouterAdvert(payload.ToView())
+ opts := ra.Options()
+
+ // Are options valid as per the wire format?
+ if _, err := opts.Iter(true); err != nil {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ //
+ // At this point, we have a valid Router Advertisement, as far
+ // as RFC 4861 section 6.1.2 is concerned.
+ //
+
+ // Tell the NIC to handle the RA.
+ stack := r.Stack()
+ rxNICID := r.NICID()
+ stack.HandleNDPRA(rxNICID, routerAddr, ra)
+
+ case header.ICMPv6RedirectMsg:
+ received.RedirectMsg.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ default:
+ received.Invalid.Increment()
+ }
+}
+
+const (
+ ndpSolicitedFlag = 1 << 6
+ ndpOverrideFlag = 1 << 5
+
+ ndpOptSrcLinkAddr = 1
+ ndpOptDstLinkAddr = 2
+
+ icmpV6FlagOffset = 4
+ icmpV6OptOffset = 24
+ icmpV6LengthOffset = 25
+)
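+
+// For reference, these flag constants correspond to the Neighbor
+// Advertisement flags byte at icmpV6FlagOffset, whose bits RFC 4861
+// section 4.4 defines as Router (1<<7), Solicited (1<<6) and Override
+// (1<<5). A sketch of the wire encoding, not used directly below:
+//
+//	flags := byte(ndpSolicitedFlag | ndpOverrideFlag) // 0x60: solicited, override set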
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ snaddr := header.SolicitedNodeAddr(addr)
+
+ // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
+ // route here. Note, we would need the nicID to do this properly so the right
+ // NIC (associated to linkEP) is used to send the NDP NS message.
+ r := &stack.Route{
+ LocalAddress: localAddr,
+ RemoteAddress: snaddr,
+ RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ }
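+
+ // Note: header.ICMPv6NeighborAdvertSize (32 bytes) is also exactly the
+ // size of a Neighbor Solicitation that carries a Source Link-Layer
+ // Address option (4-byte ICMP header, 4 reserved bytes, 16-byte target
+ // address, 8-byte option), which is why it is used to size this NS.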
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ copy(pkt[icmpV6OptOffset-len(addr):], addr)
+ pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ length := uint16(hdr.UsedLength())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ // TODO(stijlist): count this in ICMP stats.
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
+ Header: hdr,
+ })
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return tcpip.LinkAddress([]byte(nil)), false
+}
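+
+// As an illustrative sketch (hypothetical usage, not part of the
+// stack.LinkAddressResolver contract): statically resolving the all-nodes
+// multicast address ff02::1 yields the Ethernet address 33:33:00:00:00:01,
+// per RFC 2464 section 7, while unicast addresses return false and go
+// through full NDP resolution instead:
+//
+//	var p protocol
+//	linkAddr, ok := p.ResolveStaticAddress(header.IPv6AllNodesMulticastAddress)
+//	// ok == true, linkAddr == "\x33\x33\x00\x00\x00\x01"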
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
new file mode 100644
index 000000000..52a01b44e
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -0,0 +1,953 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "context"
+ "reflect"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+)
+
+var (
+ lladdr0 = header.LinkLocalAddr(linkAddr0)
+ lladdr1 = header.LinkLocalAddr(linkAddr1)
+)
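+
+// Note: header.LinkLocalAddr derives an fe80::/64 address from the MAC via
+// modified EUI-64 (RFC 4291 appendix A): the universal/local bit is flipped
+// and ff:fe is inserted in the middle, so linkAddr0 above should map to
+// fe80::2:3ff:fe04:506. This is a derivation sketch for the reader, not
+// something these tests assert.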
+
+type stubLinkEndpoint struct {
+ stack.LinkEndpoint
+}
+
+func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+type stubDispatcher struct {
+ stack.TransportDispatcher
+}
+
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+}
+
+type stubLinkAddressCache struct {
+ stack.LinkAddressCache
+}
+
+func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID {
+ return 0
+}
+
+func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
+}
+
+func TestICMPCounts(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+ {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+ defer r.Release()
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ }{
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(&r, &stack.PacketBuffer{
+ NetworkHeader: buffer.View(ip),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ }
+
+ for _, typ := range types {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
+ }
+
+ // Construct an empty ICMP packet so that
+ // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+
+ icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
+ if got, want := s.Value(), uint64(1); got != want {
+ t.Errorf("got %s = %d, want = %d", name, got, want)
+ }
+ })
+ if t.Failed() {
+ t.Logf("stats:\n%+v", s.Stats())
+ }
+}
+
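+// visitStats walks a struct whose leaves are *tcpip.StatCounter fields,
+// recursing into nested structs, and calls f with each counter's field name
+// and value.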
+func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
+ t := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ v := v.Field(i)
+ if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+ f(t.Field(i).Name, s)
+ } else {
+ visitStats(v, f)
+ }
+ }
+}
+
+type testContext struct {
+ s0 *stack.Stack
+ s1 *stack.Stack
+
+ linkEP0 *channel.Endpoint
+ linkEP1 *channel.Endpoint
+}
+
+type endpointWithResolutionCapability struct {
+ stack.LinkEndpoint
+}
+
+func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
+ return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
+}
+
+func newTestContext(t *testing.T) *testContext {
+ c := &testContext{
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ }
+
+ const defaultMTU = 65536
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
+ if testing.Verbose() {
+ wrappedEP0 = sniffer.New(wrappedEP0)
+ }
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
+ t.Fatalf("CreateNIC s0: %v", err)
+ }
+ if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress lladdr0: %v", err)
+ }
+
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ t.Fatalf("AddAddress lladdr1: %v", err)
+ }
+
+ subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.s0.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet0,
+ NIC: 1,
+ }},
+ )
+ subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.s1.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet1,
+ NIC: 1,
+ }},
+ )
+
+ return c
+}
+
+func (c *testContext) cleanup() {
+ c.linkEP0.Close()
+ c.linkEP1.Close()
+}
+
+type routeArgs struct {
+ src, dst *channel.Endpoint
+ typ header.ICMPv6Type
+ remoteLinkAddr tcpip.LinkAddress
+}
+
+func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
+ t.Helper()
+
+ pi, _ := args.src.ReadContext(context.Background())
+
+ {
+ views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
+ size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
+ vv := buffer.NewVectorisedView(size, views)
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if pi.Proto != ProtocolNumber {
+ t.Errorf("unexpected protocol number %d", pi.Proto)
+ return
+ }
+
+ if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ }
+
+ ipv6 := header.IPv6(pi.Pkt.Header.View())
+ transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
+ if transProto != header.ICMPv6ProtocolNumber {
+ t.Errorf("unexpected transport protocol number %d", transProto)
+ return
+ }
+ icmpv6 := header.ICMPv6(ipv6.Payload())
+ if got, want := icmpv6.Type(), args.typ; got != want {
+ t.Errorf("got ICMPv6 type = %d, want = %d", got, want)
+ return
+ }
+ if fn != nil {
+ fn(t, icmpv6)
+ }
+}
+
+func TestLinkResolution(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+
+ r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+ defer r.Release()
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ payload := tcpip.SlicePayload(hdr.View())
+
+ // We can't send our payload directly over the route because that
+ // doesn't provoke NDP discovery.
+ var wq waiter.Queue
+ ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ for {
+ _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
+ if resCh != nil {
+ if err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
+ }
+ for _, args := range []routeArgs{
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
+ {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
+ } {
+ routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
+ if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
+ t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
+ }
+ })
+ }
+ <-resCh
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Write(_) = _, _, %s", err)
+ }
+ break
+ }
+
+ for _, args := range []routeArgs{
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
+ {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
+ } {
+ routeICMPv6Packet(t, args, nil)
+ }
+}
+
+func TestICMPChecksumValidationSimple(t *testing.T) {
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ name: "DstUnreachable",
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ },
+ {
+ name: "PacketTooBig",
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ },
+ {
+ name: "TimeExceeded",
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ },
+ {
+ name: "ParamProblem",
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ },
+ {
+ name: "EchoRequest",
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ },
+ {
+ name: "EchoReply",
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ },
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ if checksum {
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ }
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayload(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ icmpSize := size + payloadSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(typ)
+ payloadFn(pkt.Payload())
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
+ pkt := header.ICMPv6(hdr.Prepend(size))
+ pkt.SetType(typ)
+
+ payload := buffer.NewView(payloadSize)
+ payloadFn(payload)
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(size + payloadSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
new file mode 100644
index 000000000..95fbcf2d1
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -0,0 +1,599 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv6 contains the implementation of the ipv6 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv6
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolNumber is the ipv6 protocol number.
+ ProtocolNumber = header.IPv6ProtocolNumber
+
+ // maxPayloadSize is the maximum size that can be encoded in the 16-bit
+ // PayloadLength field of the ipv6 header.
+ maxPayloadSize = 0xffff
+
+ // DefaultTTL is the default hop limit for IPv6 Packets egressed by
+ // Netstack.
+ DefaultTTL = 64
+)
+
+type endpoint struct {
+ nicID tcpip.NICID
+ id stack.NetworkEndpointID
+ prefixLen int
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+ dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
+ protocol *protocol
+}
+
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// ID returns the ipv6 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
+func (e *endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 {
+ length := uint16(hdr.UsedLength() + payloadSize)
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(params.Protocol),
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ return ip
+}
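+
+// A sizing sketch for addIPHeader (hypothetical numbers): with an 8-byte
+// transport header already prepended, hdr.UsedLength() == 8, so for a
+// 12-byte payload the encoded PayloadLength is 20, i.e. everything after
+// the fixed 40-byte IPv6 header (RFC 8200 section 3).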
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+
+ if r.Loop&stack.PacketLoop != 0 {
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ loopedR := r.MakeLoopedRoute()
+
+ e.HandlePacket(&loopedR, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
+ loopedR.Release()
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
+ return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
+ panic("not implemented")
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return pkts.Len(), nil
+ }
+
+ for pb := pkts.Front(); pb != nil; pb = pb.Next() {
+ ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
+ pb.NetworkHeader = buffer.View(ip)
+ }
+
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. It is not yet
+// supported by IPv6.
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(b/146666412): Support IPv6 header-included packets.
+ return tcpip.ErrNotSupported
+}
+
+// HandlePacket is called by the link layer when new ipv6 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv6(pkt.NetworkHeader)
+ if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // vv consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView()
+ vv.AppendView(pkt.TransportHeader)
+ vv.Append(pkt.Data)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
+ hasFragmentHeader := false
+
+ for firstHeader := true; ; firstHeader = false {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6HopByHopOptionsExtHdr:
+ // As per RFC 8200 section 4.1, the Hop By Hop extension header is
+ // restricted to appear immediately after an IPv6 fixed header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
+ // (unrecognized next header) error in response to an extension header's
+ // Next Header field with the Hop By Hop extension header identifier.
+ if !firstHeader {
+ return
+ }
+
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Hop By Hop extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RoutingExtHdr:
+ // As per RFC 8200 section 4.4, if a node encounters a routing header with
+ // an unrecognized routing type value, with a non-zero Segments Left
+ // value, the node must discard the packet and send an ICMP Parameter
+ // Problem, Code 0. If the Segments Left is 0, the node must ignore the
+ // Routing extension header and process the next header in the packet.
+ //
+ // Note, the stack does not yet handle any type of routing extension
+ // header, so we just make sure Segments Left is zero before processing
+ // the next extension header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
+ // unrecognized routing types with a non-zero Segments Left value.
+ if extHdr.SegmentsLeft() != 0 {
+ return
+ }
+
+ case header.IPv6FragmentExtHdr:
+ hasFragmentHeader = true
+
+ if extHdr.IsAtomic() {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
+ // Don't consume the iterator if we have the first fragment because we
+ // will use it to validate that the first fragment holds the upper layer
+ // header.
+ rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */)
+
+ if extHdr.FragmentOffset() == 0 {
+ // Check that the iterator ends with a raw payload as the first fragment
+ // should include all headers up to and including any upper layer
+ // headers, as per RFC 8200 section 4.5; only upper layer data
+ // (non-headers) should follow the fragment extension header.
+ var lastHdr header.IPv6PayloadHeader
+
+ for {
+ it, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ lastHdr = it
+ }
+
+ // If the last header is a raw header, then the last portion of the IPv6
+ // payload is not a known IPv6 extension header. Note, this does not
+ // mean that the last portion is an upper layer header or not an
+ // extension header because:
+ // 1) we do not yet support all extension headers
+ // 2) we do not validate the upper layer header before reassembling.
+ //
+ // This check makes sure that a known IPv6 extension header is not
+ // present after the Fragment extension header in a non-initial
+ // fragment.
+ //
+ // TODO(#2196): Support IPv6 Authentication and Encapsulated
+ // Security Payload extension headers.
+ // TODO(#2333): Validate that the upper layer header is valid.
+ switch lastHdr.(type) {
+ case header.IPv6RawPayloadHeader:
+ default:
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ }
+
+ fragmentPayloadLen := rawPayload.Buf.Size()
+ if fragmentPayloadLen == 0 {
+ // Drop the packet as it's marked as a fragment but has no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // The packet is a fragment, let's try to reassemble it.
+ start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ last := start + uint16(fragmentPayloadLen) - 1
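+
+ // Worked example with hypothetical numbers: a fragment whose Fragment
+ // Offset field is 100 and whose payload is 64 bytes long covers bytes
+ // 800..863 of the reassembled payload, since start = 100*8 = 800 and
+ // last = 800+64-1 = 863.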
+
+ // Drop the packet if the fragment offset is incorrect, i.e. the
+ // combination of the fragment offset and the payload size causes a
+ // wrap-around, resulting in last being less than start.
+ if last < start {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ var ready bool
+ // Note that pkt doesn't have its transport header set after reassembly,
+ // and won't until DeliverNetworkPacket sets it.
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ if ready {
+ // We create a new iterator with the reassembled packet because we could
+ // have more extension headers in the reassembled payload, as per RFC
+ // 8200 section 4.5.
+ it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ }
+
+ case header.IPv6DestinationOptionsExtHdr:
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Destination extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // If the last header in the payload isn't a known IPv6 extension header,
+ // handle it as if it is transport layer data.
+
+ // For unfragmented packets, extHdr still contains the transport header.
+ // Get rid of it.
+ //
+ // For reassembled fragments, pkt.TransportHeader is unset, so this is a
+ // no-op and pkt.Data begins with the transport header.
+ extHdr.Buf.TrimFront(len(pkt.TransportHeader))
+ pkt.Data = extHdr.Buf
+
+ if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, pkt, hasFragmentHeader)
+ } else {
+ r.Stats().IP.PacketsDelivered.Increment()
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ }
+
+ default:
+ // If we receive a packet for an extension header we do not yet handle,
+ // drop the packet for now.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ r.Stats().UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+ }
+}
+
+// Close cleans up resources associated with the endpoint.
+func (*endpoint) Close() {}
+
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
+type protocol struct {
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
+}
+
+// Number returns the ipv6 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv6 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen returns the IPv6 default prefix length.
+func (p *protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint creates a new ipv6 endpoint.
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ return &endpoint{
+ nicID: nicID,
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ linkAddrCache: linkAddrCache,
+ dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
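+
+// A usage sketch (hypothetical; assumes a *stack.Stack named s and its
+// SetNetworkProtocolOption plumbing): the protocol-wide hop limit can be
+// changed with
+//
+//	var opt tcpip.DefaultTTLOption = 32
+//	_ = s.SetNetworkProtocolOption(ProtocolNumber, opt)
+//
+// which lands in SetOption above and ultimately in SetDefaultTTL.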
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
+}
+
+// Close implements stack.NetworkProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.NetworkProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ //
+ // Parsing occurs again in HandlePacket because we don't track the
+ // extensions in PacketBuffer. Unfortunately, that means HandlePacket
+ // has to do the parsing work again.
+ var nextHdr tcpip.TransportProtocolNumber
+ foundNext := true
+ extensionsSize := 0
+traverseExtensions:
+ for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
+ if err != nil {
+ break
+ }
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ foundNext = false
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ // If this is an atomic fragment, we don't have to treat it specially.
+ if !extHdr.More() && extHdr.FragmentOffset() == 0 {
+ continue
+ }
+ // This is a non-atomic fragment and has to be re-assembled before we can
+ // examine the payload for a transport header.
+ foundNext = false
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader.
+ hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+
+ return nextHdr, foundNext, true
+}
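+
+// For illustration (a sketch, not exercised in this file): for a plain
+// UDP-in-IPv6 packet with no extension headers, Parse places the 40 fixed
+// header bytes in pkt.NetworkHeader and returns
+// (header.UDPProtocolNumber, true, true); for a non-atomic fragment it
+// returns hasTransportHdr = false, since the transport header is only
+// known once the fragments are reassembled.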
+
+// calculateMTU calculates the network-layer payload MTU based on the
+// link-layer payload MTU.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
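+
+// For example, a 1500-byte Ethernet link MTU yields a network-layer payload
+// MTU of 1500 - 40 = 1460 bytes, assuming maxPayloadSize (defined elsewhere
+// in this package) is at least that large.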
+
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{defaultTTL: DefaultTTL}
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..213ff64f2
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,1265 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+ addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+
+ // Tests use the extension header identifier values as uint8 instead of
+ // header.IPv6ExtensionHeaderIdentifier.
+ hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
+ routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
+ fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
+ destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
+ noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+)
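+
+// For reference: an IPv6 solicited-node multicast address is formed by
+// appending the last 24 bits of a unicast address to the prefix
+// ff02::1:ff00:0/104 (RFC 4291). addr2 and addr3 share their last 3 bytes
+// (0x000002), so both map to the group ff02::1:ff00:2.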
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
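+ // (The result is the one's complement of the checksum over the
+ // pseudo-header and the UDP header itself; the payload is empty here.)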
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as we haven't added those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(nicID, addr2); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
+ }
+
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 after removing addr2, since addr3 remains assigned.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ // Make sure addr3's endpoint does not get removed from the NIC by
+ // incrementing its reference count with a route.
+ r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ if err := s.RemoveAddress(nicID, addr3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
+ }
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 anymore as both addresses have been removed, even though a
+ // route using addr3 still exists.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests adding IPv6 addresses.
+func TestAddIpv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ }{
+ // This test is in response to b/140943433.
+ {
+ "Nil",
+ tcpip.Address([]byte(nil)),
+ },
+ {
+ "ValidUnicast",
+ addr1,
+ },
+ {
+ "ValidLinkLocalUnicast",
+ lladdr0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ }
+ })
+ }
+}
+
+func TestReceiveIPv6ExtHdrs(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8)
+ shouldAccept bool
+ }{
+ {
+ name: "None",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "destination with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing - atomic fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ nextHdr, 0, 0, 0, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Fragment extension header.
+ routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, fragmentExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hop by hop (with skippable unknown) - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "routing - hop by hop (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Hop By Hop extension header with skippable unknown option.
+ nextHdr, 0, 62, 4, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with discard action for unknown option.
+ routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with discard action for unknown
+ // option.
+ nextHdr, 0, 65, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+ udpLength := header.UDPMinimumSize + len(udpPayload)
+ extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
+ extHdrLen := len(extHdrBytes)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
+
+ // Serialize UDP message.
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), udpPayload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(udpPayload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ // Copy extension header bytes between the UDP message and the IPv6
+ // fixed header.
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+
+ // Serialize IPv6 fixed header.
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: ipv6NextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().UDP.PacketsReceived
+
+ if !test.shouldAccept {
+ if got := stats.Value(); got != 0 {
+ t.Errorf("got UDP Rx Packets = %d, want = 0", got)
+ }
+
+ return
+ }
+
+ // Expect a UDP packet.
+ if got := stats.Value(); got != 1 {
+ t.Errorf("got UDP Rx Packets = %d, want = 1", got)
+ }
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have any more UDP packets.
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
+
+// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
+type fragmentData struct {
+ nextHdr uint8
+ data buffer.VectorisedView
+}
+
+func TestReceiveIPv6Fragments(t *testing.T) {
+ const nicID = 1
+ const udpPayload1Length = 256
+ const udpPayload2Length = 128
+ const fragmentExtHdrLen = 8
+ // Note: not all routing extension headers are 8 bytes long, but the
+ // subtests below mostly use 8-byte routing extension headers.
+ const routingExtHdrLen = 8
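+
+ // For reference when reading the raw fragment extension headers below
+ // (RFC 8200 section 4.5): byte 0 is the next header, byte 1 is reserved,
+ // bytes 2-3 hold the fragment offset (in 8-byte units) in their upper 13
+ // bits with the M (more fragments) flag in the lowest bit, and bytes 4-7
+ // hold the identification. So bytes {0, 1} encode offset = 0 with M = 1,
+ // and {0, 64} encode offset = 8 units (64 bytes) with M = 0.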
+
+ udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ payloadLen := len(payload)
+ for i := 0; i < payloadLen; i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + payloadLen
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
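+
+ // Note that udpGen fills payload in place (payload[i] = i * multiplier),
+ // so the two UDP payloads generated below are distinguishable after
+ // reassembly.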
+
+ var udpPayload1Buf [udpPayload1Length]byte
+ udpPayload1 := udpPayload1Buf[:]
+ ipv6Payload1 := udpGen(udpPayload1, 1)
+
+ var udpPayload2Buf [udpPayload2Length]byte
+ udpPayload2 := udpPayload2Buf[:]
+ ipv6Payload2 := udpGen(udpPayload2, 2)
+
+ tests := []struct {
+ name string
+ expectedPayload []byte
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ nextHdr: uint8(header.UDPProtocolNumber),
+ data: ipv6Payload1.ToVectorisedView(),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Atomic fragment",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with per-fragment routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with per-fragment routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16-byte routing extension header are in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16-byte routing extension header are in this fragment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16-byte routing extension header are in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16-byte routing extension header are in this fragment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
+ // fragmented traffic.
+ {
+ name: "Two fragments with atomic",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ // This fragment has the same ID as the other fragments but is an atomic
+ // fragment. It should not interfere with the other fragments.
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+
+ ipv6Payload2,
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+
+ ipv6Payload2[:32],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+
+ ipv6Payload2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, p := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
new file mode 100644
index 000000000..64239ce9a
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -0,0 +1,907 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+// setupStackAndEndpoint creates a stack with a single NIC that has the
+// link-local address llladdr, and an IPv6 endpoint to a remote with the
+// link-local address rlladdr.
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep
+}
+
+// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
+// valid NDP NS message with the Source Link Layer Address option results in a
+// new entry in the link address cache for the sender of the message.
+func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
+ }
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNeighborSolicitationResponse(t *testing.T) {
+ const nicID = 1
+ nicAddr := lladdr0
+ remoteAddr := lladdr1
+ nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
+ nicLinkAddr := linkAddr0
+ remoteLinkAddr0 := linkAddr1
+ remoteLinkAddr1 := linkAddr2
+
+ tests := []struct {
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ }{
+ {
+ name: "Unspecified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Unspecified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source with 1 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Specified source with 2 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 2 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+ }
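+
+ // The invalid cases above follow RFC 4861: a solicitation whose IP source
+ // is the unspecified address must not carry a source link-layer address
+ // option, and a multicast-destined solicitation from a specified source
+ // must include one. This implementation also rejects solicitations that
+ // carry more than one such option.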
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
+
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
+
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
+
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
+ }
+}
+
+// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option results in a
+// new entry in the link address cache for the target of the message.
+func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
+ },
+ {
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetTargetAddress(lladdr1)
+ opts := na.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
+ }
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
+
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep, r
+ }
+
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
+ if atomicFragment {
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
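+ // buffer.NewView returns zero-filled bytes, so the fragment header
+ // built here has offset = 0 and M = 0 -- an atomic fragment in the
+ // sense of RFC 6946.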
+ extensions[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
+
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions)))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload) + len(extensions)),
+ NextHeader: nextHdr,
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, &stack.PacketBuffer{
+ NetworkHeader: buffer.View(ip),
+ Data: payload.ToVectorisedView(),
+ })
+ }
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
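+
+ // (The serialized target link-layer address option is appended to the
+ // NeighborAdvert messages below so that they carry a well-formed option
+ // block.)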
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code uint8
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestRouterAdvertValidation tests that when the NIC is configured to handle
+// NDP Router Advertisement packets, it validates them properly before
+// handling them.
+func TestRouterAdvertValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ src tcpip.Address
+ hopLimit uint8
+ code uint8
+ ndpPayload []byte
+ expectedSuccess bool
+ }{
+ {
+ "OK",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "NonLinkLocalSourceAddr",
+ addr1,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "HopLimitNot255",
+ lladdr0,
+ 254,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NonZeroCode",
+ lladdr0,
+ 255,
+ 1,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NDPPayloadTooSmall",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "OKWithOptions",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ 2, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "OptionWithZeroLength",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ // Invalid as it has 0 length.
+ 2, 0, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
+
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
new file mode 100644
index 000000000..2bad05a2e
--- /dev/null
+++ b/pkg/tcpip/ports/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ports",
+ srcs = ["ports.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ ],
+)
+
+go_test(
+ name = "ports_test",
+ srcs = ["ports_test.go"],
+ library = ":ports",
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
new file mode 100644
index 000000000..f6d592eb5
--- /dev/null
+++ b/pkg/tcpip/ports/ports.go
@@ -0,0 +1,554 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ports provides a PortManager that manages allocating, reserving
+// and releasing ports.
+package ports
+
+import (
+ "math"
+ "math/rand"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // FirstEphemeral is the first ephemeral port.
+ FirstEphemeral = 16000
+
+ // numEphemeralPorts is the number of ephemeral ports available to
+ // Netstack.
+ numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+
+ anyIPAddress tcpip.Address = ""
+)
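+
+// For reference: with FirstEphemeral = 16000 the ephemeral range is
+// [16000, 65535], so numEphemeralPorts = 65535 - 16000 + 1 = 49536.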
+
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
+}
+
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precedence over MostRecent.
+ LoadBalanced bool
+
+ // TupleOnly represents TCP SO_REUSEADDR.
+ TupleOnly bool
+}
+
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
+ if f.MostRecent {
+ rf |= MostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= LoadBalancedFlag
+ }
+ if f.TupleOnly {
+ rf |= TupleOnlyFlag
+ }
+ return rf
+}
+
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
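+
+// An illustrative sketch (not part of the original change) of how Flags,
+// Bits, and Effective interact. With the BitFlags values defined below,
+// MostRecentFlag == 1 and LoadBalancedFlag == 2:
+//
+//	f := Flags{MostRecent: true, LoadBalanced: true}
+//	f.Bits()             // MostRecentFlag|LoadBalancedFlag == 3
+//	f.Effective()        // Flags{LoadBalanced: true}: LoadBalanced wins
+//	f.Effective().Bits() // LoadBalancedFlag == 2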
+
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ mu sync.RWMutex
+ allocatedPorts map[portDescriptor]bindAddresses
+
+ // hint is used to pick ephemeral ports in a stable order for a given
+ // port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
+}
+
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
+
+const (
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // TupleOnlyFlag represents Flags.TupleOnly.
+ TupleOnlyFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
+ nextFlag
+
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
+
+ // MultiBindFlagMask contains the flags that allow binding the same
+ // tuple multiple times.
+ MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
+)
+
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ TupleOnly: f&TupleOnlyFlag != 0,
+ }
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination (0 through
+ // FlagMask).
+ refs [nextFlag]int
+}
+
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
+ var total int
+ for _, r := range c.refs {
+ total += r
+ }
+ return total
+}
+
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
+ var total int
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// AllRefsHave returns whether all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// IntersectionRefs returns the set of flags shared by all references.
+func (c FlagCounter) IntersectionRefs() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
+ if r > 0 {
+ intersection &= BitFlags(i)
+ }
+ }
+ return intersection
+}
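+
+// An illustrative sketch (not part of the original change): a FlagCounter
+// holding one reference with LoadBalancedFlag|TupleOnlyFlag and one with
+// LoadBalancedFlag alone behaves as follows:
+//
+//	var c FlagCounter
+//	c.AddRef(LoadBalancedFlag | TupleOnlyFlag)
+//	c.AddRef(LoadBalancedFlag)
+//	c.TotalRefs()                   // 2
+//	c.FlagRefs(TupleOnlyFlag)       // 1
+//	c.AllRefsHave(LoadBalancedFlag) // true
+//	c.AllRefsHave(TupleOnlyFlag)    // false
+//	c.IntersectionRefs()            // LoadBalancedFlag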
+
+type destination struct {
+ addr tcpip.Address
+ port uint16
+}
+
+func makeDestination(a tcpip.FullAddress) destination {
+ return destination{
+ a.Addr,
+ a.Port,
+ }
+}
+
+// portNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type portNode map[destination]FlagCounter
+
+// intersectionRefs calculates the intersection of flag bit values which affect
+// the specified destination.
+//
+// If no destinations are present, all flag values are returned as there are no
+// entries to limit possible flag values of a new entry.
+//
+// In addition to the intersection, the number of intersecting refs is
+// returned.
+func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
+ intersection := FlagMask
+ var count int
+
+ for d, f := range p {
+ if d == dst {
+ intersection &= f.IntersectionRefs()
+ count++
+ continue
+ }
+ // Wildcard destinations affect all destinations for TupleOnly.
+ if d.addr == anyIPAddress || dst.addr == anyIPAddress {
+ // Only bitwise and the TupleOnlyFlag.
+ intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs())
+ count++
+ }
+ }
+
+ return intersection, count
+}
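+
+// Worked example (illustrative, not part of the original change): if p holds
+// a wildcard-destination entry whose refs all carry TupleOnlyFlag, then for
+// any other dst that entry contributes (^TupleOnlyFlag)|TupleOnlyFlag, i.e.
+// all bits set, and so imposes no constraint. Had the wildcard entry been
+// bound without TupleOnly, it would contribute ^TupleOnlyFlag and clear only
+// the TupleOnly bit: a wildcard binding can only ever constrain the
+// TupleOnly bit of entries for other destinations.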
+
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
+
+// isAvailable checks whether binding is possible by device. If not binding to a
+// device, check against all FlagCounters. If binding to a specific device, check
+// against the unspecified device and the provided device.
+//
+// If either of the port reuse flags is enabled on any of the nodes, all nodes
+// sharing a port must share at least one reuse flag. This matches Linux's
+// behavior.
+func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ flagBits := flags.Bits()
+ if bindToDevice == 0 {
+ intersection := FlagMask
+ for _, p := range d {
+ i, c := p.intersectionRefs(dst)
+ if c == 0 {
+ continue
+ }
+ intersection &= i
+ if intersection&flagBits == 0 {
+ // Can't bind because the (addr,port) was
+ // previously bound without reuse.
+ return false
+ }
+ }
+ return true
+ }
+
+ intersection := FlagMask
+
+ if p, ok := d[0]; ok {
+ var c int
+ intersection, c = p.intersectionRefs(dst)
+ if c > 0 && intersection&flagBits == 0 {
+ return false
+ }
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ i, c := p.intersectionRefs(dst)
+ intersection &= i
+ if c > 0 && intersection&flagBits == 0 {
+ return false
+ }
+ }
+
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(flags, bindToDevice, dst) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(flags, bindToDevice, dst) {
+ return false
+ }
+ }
+
+ // Check that there is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(flags, bindToDevice, dst) {
+ return false
+ }
+ }
+
+ return true
+}
+
+// NewPortManager creates a new PortManager.
+func NewPortManager() *PortManager {
+ return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
+}
+
+// PickEphemeralPort randomly chooses a starting point and iterates over all
+// possible ephemeral ports, allowing the caller to decide whether a given port
+// is suitable for its needs, and stopping when a port is found or an error
+// occurs.
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ offset := uint32(rand.Int31n(numEphemeralPorts))
+ return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+}
+
+// portHint atomically reads and returns the s.hint value.
+func (s *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&s.hint)
+}
+
+// incPortHint atomically increments s.hint by 1.
+func (s *PortManager) incPortHint() {
+ atomic.AddUint32(&s.hint, 1)
+}
+
+// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// iterates over all ephemeral ports, allowing the caller to decide whether a
+// given port is suitable for its needs and stopping when a port is found or an
+// error occurs.
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+ if err == nil {
+ s.incPortHint()
+ }
+ return p, err
+}
+
+// pickEphemeralPort starts at the given offset from FirstEphemeral and
+// iterates over count ports, allowing the caller to decide whether a given
+// port is suitable for its needs, and stopping when a port is found or an
+// error occurs.
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ for i := uint32(0); i < count; i++ {
+ port = uint16(FirstEphemeral + (offset+i)%count)
+ ok, err := testPort(port)
+ if err != nil {
+ return 0, err
+ }
+
+ if ok {
+ return port, nil
+ }
+ }
+
+ return 0, tcpip.ErrNoPortAvailable
+}
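+
+// For example (illustrative, not part of the original change), with
+// offset = 49530 and count = numEphemeralPorts (49536), the probe order
+// wraps around the top of the range:
+//
+//	i = 0 -> 16000 + 49530%49536 = 65530
+//	i = 5 -> 16000 + 49535%49536 = 65535
+//	i = 6 -> 16000 + 49536%49536 = 16000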
+
+// IsPortAvailable tests if the given port is available on all given protocols.
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest))
+}
+
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if addrs, ok := s.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(addr, flags, bindToDevice, dst) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// ReservePort marks a port/IP combination as reserved so that it cannot be
+// reserved by another endpoint. If port is zero, ReservePort will search for
+// an unreserved ephemeral port and reserve it, returning its value in the
+// "port" return value.
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ dst := makeDestination(dest)
+
+ // If a port is specified, just try to reserve it for all network
+ // protocols.
+ if port != 0 {
+ if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
+ return 0, tcpip.ErrPortInUse
+ }
+ return port, nil
+ }
+
+ // A port wasn't specified, so try to find one.
+ return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil
+ })
+}
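+
+// An illustrative usage sketch (not part of the original change); the
+// protocol numbers and addr are placeholders for whatever the caller uses:
+//
+//	pm := NewPortManager()
+//	networks := []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}
+//	// Reserve a specific port...
+//	port, err := pm.ReservePort(networks, tcp.ProtocolNumber, addr, 8080, Flags{}, 0, tcpip.FullAddress{})
+//	// ...or pass port 0 to have an ephemeral port picked and reserved.
+//	port, err = pm.ReservePort(networks, tcp.ProtocolNumber, addr, 0, Flags{}, 0, tcpip.FullAddress{})
+//	// Release with the same arguments when done.
+//	pm.ReleasePort(networks, tcp.ProtocolNumber, addr, port, Flags{}, 0, tcpip.FullAddress{})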
+
+// reserveSpecificPort tries to reserve the given port on all given protocols.
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) {
+ return false
+ }
+
+ flagBits := flags.Bits()
+
+ // Reserve port on all network protocols.
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ p := d[bindToDevice]
+ if p == nil {
+ p = make(portNode)
+ }
+ n := p[dst]
+ n.AddRef(flagBits)
+ p[dst] = n
+ d[bindToDevice] = p
+ }
+
+ return true
+}
+
+// ReserveTuple adds a port reservation for the tuple on all given protocols.
+func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
+ flagBits := flags.Bits()
+ dst := makeDestination(dest)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // It is easier to undo the entire reservation, so if we find that the
+ // tuple can't be fully added, finish and undo the whole thing.
+ undo := false
+
+ // Reserve port on all network protocols.
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ p := d[bindToDevice]
+ if p == nil {
+ p = make(portNode)
+ }
+
+ n := p[dst]
+ if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 {
+ // Tuple already exists.
+ undo = true
+ }
+ n.AddRef(flagBits)
+ p[dst] = n
+ d[bindToDevice] = p
+ }
+
+ if undo {
+ // releasePortLocked decrements the counts (rather than setting
+ // them to zero), so it will undo the incorrect incrementing
+ // above.
+ s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst)
+ return false
+ }
+
+ return true
+}
+
+// ReleasePort releases the reservation on a port/IP combination so that it can
+// be reserved by other endpoints.
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest))
+}
+
+func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) {
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if m, ok := s.allocatedPorts[desc]; ok {
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ p, ok := d[bindToDevice]
+ if !ok {
+ continue
+ }
+ n, ok := p[dst]
+ if !ok {
+ continue
+ }
+ n.DropRef(flags)
+ if n.TotalRefs() > 0 {
+ p[dst] = n
+ continue
+ }
+ delete(p, dst)
+ if len(p) > 0 {
+ continue
+ }
+ delete(d, bindToDevice)
+ if len(d) > 0 {
+ continue
+ }
+ delete(m, addr)
+ if len(m) > 0 {
+ continue
+ }
+ delete(s.allocatedPorts, desc)
+ }
+ }
+}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
new file mode 100644
index 000000000..58db5868c
--- /dev/null
+++ b/pkg/tcpip/ports/ports_test.go
@@ -0,0 +1,450 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ports
+
+import (
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ fakeTransNumber tcpip.TransportProtocolNumber = 1
+ fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
+
+ fakeIPAddress = tcpip.Address("\x08\x08\x08\x08")
+ fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
+)
+
+type portReserveTestAction struct {
+ port uint16
+ ip tcpip.Address
+ want *tcpip.Error
+ flags Flags
+ release bool
+ device tcpip.NICID
+ dest tcpip.FullAddress
+}
+
+func TestPortReservation(t *testing.T) {
+ for _, test := range []struct {
+ tname string
+ actions []portReserveTestAction
+ }{
+ {
+ tname: "bind to ip",
+ actions: []portReserveTestAction{
+ {port: 80, ip: fakeIPAddress, want: nil},
+ {port: 80, ip: fakeIPAddress1, want: nil},
+ /* N.B. Order of tests matters! */
+ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
+ {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
+ },
+ },
+ {
+ tname: "bind to inaddr any",
+ actions: []portReserveTestAction{
+ {port: 22, ip: anyIPAddress, want: nil},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ /* release fakeIPAddress, but anyIPAddress is still in use */
+ {port: 22, ip: fakeIPAddress, release: true},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
+ /* Release port 22 from any IP address, then try to reserve the fake IP address on port 22 */
+ {port: 22, ip: anyIPAddress, want: nil, release: true},
+ {port: 22, ip: fakeIPAddress, want: nil},
+ },
+ }, {
+ tname: "bind to zero port",
+ actions: []portReserveTestAction{
+ {port: 00, ip: fakeIPAddress, want: nil},
+ {port: 00, ip: fakeIPAddress, want: nil},
+ {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind to ip with reuseport",
+ actions: []portReserveTestAction{
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+
+ {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+
+ {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind to inaddr any with reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
+
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
+ },
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "binding with reuseport and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuseport and not reuseport by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuseport and not reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuseport once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true},
+ },
+ }, {
+ tname: "bind with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuseaddr once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseport, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind two tuples with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
+ },
+ }, {
+ tname: "bind two tuples",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
+ },
+ }, {
+ tname: "bind wildcard, and then tuple with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind wildcard twice with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
+ },
+ },
+ } {
+ t.Run(test.tname, func(t *testing.T) {
+ pm := NewPortManager()
+ net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+
+ for _, test := range test.actions {
+ if test.release {
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
+ continue
+ }
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
+ if err != test.want {
+ t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want)
+ }
+ if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
+ t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+ }
+ }
+ })
+ }
+}
+
+func TestPickEphemeralPort(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
+
+func TestPickEphemeralPortStable(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
new file mode 100644
index 000000000..cf0a5fefe
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "tun_tcp_connect",
+ srcs = ["main.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
new file mode 100644
index 000000000..0ab089208
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -0,0 +1,225 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
+// device, and connects to a peer. Similar to "nc <address> <port>". While the
+// sample is running, attempts to connect to its IPv4 address will result in
+// an RST segment.
+//
+// As an example of how to run it, a TUN device can be created and enabled on
+// a linux host as follows (this only needs to be done once per boot):
+//
+// [sudo] ip tuntap add user <username> mode tun <device-name>
+// [sudo] ip link set <device-name> up
+// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name>
+//
+// A concrete example:
+//
+// $ sudo ip tuntap add user wedsonaf mode tun tun0
+// $ sudo ip link set tun0 up
+// $ sudo ip addr add 192.168.1.1/24 dev tun0
+//
+// Then one can run tun_tcp_connect as such:
+//
+// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234
+//
+// This will attempt to connect to the linux host's stack. One can run nc in
+// listen mode to accept a connection from tun_tcp_connect and exchange data.
+package main
+
+import (
+ "bufio"
+ "fmt"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "strconv"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// writer reads from standard input and writes to the endpoint until standard
+// input is closed. It signals that it's done by closing the provided channel.
+func writer(ch chan struct{}, ep tcpip.Endpoint) {
+ defer func() {
+ ep.Shutdown(tcpip.ShutdownWrite)
+ close(ch)
+ }()
+
+ r := bufio.NewReader(os.Stdin)
+ for {
+ v := buffer.NewView(1024)
+ n, err := r.Read(v)
+ if err != nil {
+ return
+ }
+
+ v.CapLength(n)
+ for len(v) > 0 {
+ n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ if err != nil {
+ fmt.Println("Write failed:", err)
+ return
+ }
+
+ v.TrimFront(int(n))
+ }
+ }
+}
+
+func main() {
+ if len(os.Args) != 6 {
+ log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>")
+ }
+
+ tunName := os.Args[1]
+ addrName := os.Args[2]
+ portName := os.Args[3]
+ remoteAddrName := os.Args[4]
+ remotePortName := os.Args[5]
+
+ rand.Seed(time.Now().UnixNano())
+
+ addr := tcpip.Address(net.ParseIP(addrName).To4())
+ remote := tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()),
+ }
+
+ var localPort uint16
+ if v, err := strconv.Atoi(portName); err != nil {
+ log.Fatalf("Unable to convert port %v: %v", portName, err)
+ } else {
+ localPort = uint16(v)
+ }
+
+ if v, err := strconv.Atoi(remotePortName); err != nil {
+ log.Fatalf("Unable to convert port %v: %v", remotePortName, err)
+ } else {
+ remote.Port = uint16(v)
+ }
+
+ // Create the stack with ipv4 and tcp protocols, then add a tun-based
+ // NIC and ipv4 address.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
+
+ mtu, err := rawfile.GetMTU(tunName)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fd, err := tun.Open(tunName)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
+ if err != nil {
+ log.Fatal(err)
+ }
+ if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
+ log.Fatal(err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
+ log.Fatal(err)
+ }
+
+ // Add default route.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ // Create TCP endpoint.
+ var wq waiter.Queue
+ ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if e != nil {
+ log.Fatal(e)
+ }
+
+ // Bind if a port is specified.
+ if localPort != 0 {
+ if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil {
+ log.Fatal("Bind failed: ", err)
+ }
+ }
+
+ // Issue connect request and wait for it to complete.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ terr := ep.Connect(remote)
+ if terr == tcpip.ErrConnectStarted {
+ fmt.Println("Connect is pending...")
+ <-notifyCh
+ terr = ep.GetSockOpt(tcpip.ErrorOption{})
+ }
+ wq.EventUnregister(&waitEntry)
+
+ if terr != nil {
+ log.Fatal("Unable to connect: ", terr)
+ }
+
+ fmt.Println("Connected")
+
+ // Start the writer in its own goroutine.
+ writerCompletedCh := make(chan struct{})
+ go writer(writerCompletedCh, ep) // S/R-SAFE: sample code.
+
+ // Read data and write to standard output until the peer closes the
+ // connection from its side.
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ for {
+ v, _, err := ep.Read(nil)
+ if err != nil {
+ if err == tcpip.ErrClosedForReceive {
+ break
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ <-notifyCh
+ continue
+ }
+
+ log.Fatal("Read() failed:", err)
+ }
+
+ os.Stdout.Write(v)
+ }
+ wq.EventUnregister(&waitEntry)
+
+ // The reader has completed. Now wait for the writer as well.
+ <-writerCompletedCh
+
+ ep.Close()
+}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
new file mode 100644
index 000000000..43264b76d
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "tun_tcp_echo",
+ srcs = ["main.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
new file mode 100644
index 000000000..9e37cab18
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
+// device, and listens on a port. Data received by the server in the accepted
+// connections is echoed back to the clients.
+package main
+
+import (
+ "flag"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var tap = flag.Bool("tap", false, "use tap instead of tun")
+var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
+
+func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
+ defer ep.Close()
+
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+
+ for {
+ v, _, err := ep.Read(nil)
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ <-notifyCh
+ continue
+ }
+
+ return
+ }
+
+ ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ }
+}
+
+func main() {
+ flag.Parse()
+ if len(flag.Args()) != 3 {
+ log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>")
+ }
+
+ tunName := flag.Arg(0)
+ addrName := flag.Arg(1)
+ portName := flag.Arg(2)
+
+ rand.Seed(time.Now().UnixNano())
+
+ // Parse the mac address.
+ maddr, err := net.ParseMAC(*mac)
+ if err != nil {
+ log.Fatalf("Bad MAC address: %v", *mac)
+ }
+
+ // Parse the IP address. Support both ipv4 and ipv6.
+ parsedAddr := net.ParseIP(addrName)
+ if parsedAddr == nil {
+ log.Fatalf("Bad IP address: %v", addrName)
+ }
+
+ var addr tcpip.Address
+ var proto tcpip.NetworkProtocolNumber
+ if parsedAddr.To4() != nil {
+ addr = tcpip.Address(parsedAddr.To4())
+ proto = ipv4.ProtocolNumber
+ } else if parsedAddr.To16() != nil {
+ addr = tcpip.Address(parsedAddr.To16())
+ proto = ipv6.ProtocolNumber
+ } else {
+ log.Fatalf("Unknown IP type: %v", addrName)
+ }
+
+ localPort, err := strconv.Atoi(portName)
+ if err != nil {
+ log.Fatalf("Unable to convert port %v: %v", portName, err)
+ }
+
+ // Create the stack with ip and tcp protocols, then add a tun-based
+ // NIC and address.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
+
+ mtu, err := rawfile.GetMTU(tunName)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ var fd int
+ if *tap {
+ fd, err = tun.OpenTAP(tunName)
+ } else {
+ fd, err = tun.Open(tunName)
+ }
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ linkEP, err := fdbased.New(&fdbased.Options{
+ FDs: []int{fd},
+ MTU: mtu,
+ EthernetHeader: *tap,
+ Address: tcpip.LinkAddress(maddr),
+ })
+ if err != nil {
+ log.Fatal(err)
+ }
+ if err := s.CreateNIC(1, linkEP); err != nil {
+ log.Fatal(err)
+ }
+
+ if err := s.AddAddress(1, proto, addr); err != nil {
+ log.Fatal(err)
+ }
+
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ log.Fatal(err)
+ }
+
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Add default route.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: 1,
+ },
+ })
+
+ // Create TCP endpoint, bind it, then start listening.
+ var wq waiter.Queue
+ ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
+ if e != nil {
+ log.Fatal(e)
+ }
+
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil {
+ log.Fatal("Bind failed: ", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ log.Fatal("Listen failed: ", err)
+ }
+
+ // Wait for connections to appear.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+
+ for {
+ n, wq, err := ep.Accept()
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ <-notifyCh
+ continue
+ }
+
+ log.Fatal("Accept() failed:", err)
+ }
+
+ go echo(wq, n) // S/R-SAFE: sample code.
+ }
+}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
new file mode 100644
index 000000000..45f503845
--- /dev/null
+++ b/pkg/tcpip/seqnum/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "seqnum",
+ srcs = ["seqnum.go"],
+ visibility = ["//visibility:public"],
+)
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
new file mode 100644
index 000000000..d3bea7de4
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seqnum defines the types and methods for TCP sequence numbers such
+// that they fit in 32-bit words and work properly when overflows occur.
+package seqnum
+
+// Value represents the value of a sequence number.
+type Value uint32
+
+// Size represents the size (length) of a sequence number window.
+type Size uint32
+
+// LessThan checks if v is before w, i.e., v < w.
+func (v Value) LessThan(w Value) bool {
+ return int32(v-w) < 0
+}
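+
+// For example (illustrative, not part of the original change), comparisons
+// are modular, so values that wrapped past zero still compare correctly as
+// long as they are within 2^31 of each other:
+//
+//	Value(10).LessThan(20)        // true
+//	Value(0xfffffff0).LessThan(5) // true: int32(0xfffffff0-5) < 0
+//	Value(5).LessThan(0xfffffff0) // false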
+
+// LessThanEq returns true if v == w or v is before w, i.e., v <= w.
+func (v Value) LessThanEq(w Value) bool {
+ if v == w {
+ return true
+ }
+ return v.LessThan(w)
+}
+
+// InRange checks if v is in the range [a,b), i.e., a <= v < b.
+func (v Value) InRange(a, b Value) bool {
+ return v-a < b-a
+}
+
+// InWindow checks if v is in the window that starts at 'first' and spans 'size'
+// sequence numbers.
+func (v Value) InWindow(first Value, size Size) bool {
+ return v.InRange(first, first.Add(size))
+}
+
+// Add calculates the sequence number following the [v, v+s) window.
+func (v Value) Add(s Size) Value {
+ return v + Value(s)
+}
+
+// Size calculates the size of the window defined by [v, w).
+func (v Value) Size(w Value) Size {
+ return Size(w - v)
+}
+
+// UpdateForward updates v such that it becomes v + s.
+func (v *Value) UpdateForward(s Size) {
+ *v += Value(s)
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
new file mode 100644
index 000000000..e65c731c2
--- /dev/null
+++ b/pkg/tcpip/stack/BUILD
@@ -0,0 +1,118 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "linkaddrentry_list",
+ out = "linkaddrentry_list.go",
+ package = "stack",
+ prefix = "linkAddrEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*linkAddrEntry",
+ "Linker": "*linkAddrEntry",
+ },
+)
+
+go_template_instance(
+ name = "packet_buffer_list",
+ out = "packet_buffer_list.go",
+ package = "stack",
+ prefix = "PacketBuffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*PacketBuffer",
+ "Linker": "*PacketBuffer",
+ },
+)
+
+go_library(
+ name = "stack",
+ srcs = [
+ "conntrack.go",
+ "dhcpv6configurationfromndpra_string.go",
+ "forwarder.go",
+ "icmp_rate_limit.go",
+ "iptables.go",
+ "iptables_targets.go",
+ "iptables_types.go",
+ "linkaddrcache.go",
+ "linkaddrentry_list.go",
+ "ndp.go",
+ "nic.go",
+ "packet_buffer.go",
+ "packet_buffer_list.go",
+ "rand.go",
+ "registration.go",
+ "route.go",
+ "stack.go",
+ "stack_global_state.go",
+ "stack_options.go",
+ "transport_demuxer.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/ilist",
+ "//pkg/log",
+ "//pkg/rand",
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/seqnum",
+ "//pkg/tcpip/transport/tcpconntrack",
+ "//pkg/waiter",
+ "@org_golang_x_time//rate:go_default_library",
+ ],
+)
+
+go_test(
+ name = "stack_x_test",
+ size = "medium",
+ srcs = [
+ "ndp_test.go",
+ "stack_test.go",
+ "transport_demuxer_test.go",
+ "transport_test.go",
+ ],
+ shard_count = 20,
+ deps = [
+ ":stack",
+ "//pkg/rand",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
+
+go_test(
+ name = "stack_test",
+ size = "small",
+ srcs = [
+ "forwarder_test.go",
+ "linkaddrcache_test.go",
+ "nic_test.go",
+ ],
+ library = ":stack",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
new file mode 100644
index 000000000..af9c325ca
--- /dev/null
+++ b/pkg/tcpip/stack/conntrack.go
@@ -0,0 +1,331 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
+)
+
+// Connection tracking is used to track and manipulate packets for NAT rules.
+// A connection is created for a packet if one does not already exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule, and the packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
+
+// Direction of the tuple.
+type direction int
+
+const (
+ dirOriginal direction = iota
+ dirReply
+)
+
+// Manipulation type for the connection.
+type manipType int
+
+const (
+ manipDstPrerouting manipType = iota
+ manipDstOutput
+)
+
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+type tuple struct {
+ tupleID
+
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
+
+ // direction is the direction of the tuple.
+ direction direction
+}
+
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
+}
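+
+// For example (hypothetical addresses): a tuple identifying a connection
+// from 10.0.0.1:1234 to 10.0.0.2:80 has a reply tuple going the other way.
+//
+//	tid := tupleID{
+//		srcAddr: "\x0a\x00\x00\x01", srcPort: 1234,
+//		dstAddr: "\x0a\x00\x00\x02", dstPort: 80,
+//	}
+//	rid := tid.reply() // rid.srcAddr == "\x0a\x00\x00\x02", rid.srcPort == 80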
+
+// conn is a tracked connection.
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
+
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
+
+ // manip indicates whether the packet should be manipulated. It is immutable.
+ manip manipType
+
+ // tcbHook indicates whether the packet is inbound or outbound so
+ // that the state of tcb can be updated accordingly. It is protected
+ // by mu.
+ tcbHook Hook
+
+ // mu protects tcb.
+ mu sync.Mutex
+
+ // tcb is the TCP control block. It is used to keep track of the
+ // state of the TCP connection and is protected by mu.
+ tcb tcpconntrack.TCB
+}
+
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket and createConnFor.
+type ConnTrack struct {
+ // mu protects conns.
+ mu sync.RWMutex
+
+ // conns maintains a map of tuples needed for connection tracking for
+ // iptables NAT rules. It is protected by mu.
+ conns map[tupleID]tuple
+}
+
+// packetToTupleID converts a packet to a tuple ID. It fails when pkt lacks a
+// valid TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/170): Need to support other protocols as
+ // well.
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return tupleID{}, tcpip.ErrUnknownProtocol
+ }
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return tupleID{}, tcpip.ErrUnknownProtocol
+ }
+
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
+}
+
+// newConn creates a new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
+}
+
+// connFor gets the conn for pkt if it exists, or returns nil if it does not.
+// It also returns nil when pkt does not contain a valid TCP header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil, dirOriginal
+ }
+
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+
+ tuple, ok := ct.conns[tid]
+ if !ok {
+ return nil, dirOriginal
+ }
+ return tuple.conn, tuple.direction
+}
+
+// createConnFor creates a new conn for pkt.
+func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
+ }
+
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
+ }
+ conn := newConn(tid, replyTID, manip, hook)
+
+ // Add the changed tuple to the map.
+ // TODO(gvisor.dev/issue/170): Need to support collisions using linked
+ // list.
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+ ct.conns[tid] = conn.original
+ ct.conns[replyTID] = conn.reply
+
+ return conn
+}
+
+// handlePacketPrerouting manipulates ports for packets in the Prerouting hook.
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ tcpHeader := header.TCP(pkt.TransportHeader)
+
+ // For prerouting redirection, packets going in the original direction
+ // have their destinations modified and replies have their sources
+ // modified.
+ switch dir {
+ case dirOriginal:
+ port := conn.reply.srcPort
+ tcpHeader.SetDestinationPort(port)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ case dirReply:
+ port := conn.original.dstPort
+ tcpHeader.SetSourcePort(port)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+}
+
+// handlePacketOutput manipulates ports for packets in the Output hook.
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ tcpHeader := header.TCP(pkt.TransportHeader)
+
+ // For output redirection, packets going in the original direction
+ // have their destinations modified and replies have their sources
+ // modified. For prerouting redirection, we only reach this point
+ // when replying, so packet sources are modified.
+ if conn.manip == manipDstOutput && dir == dirOriginal {
+ port := conn.reply.srcPort
+ tcpHeader.SetDestinationPort(port)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ } else {
+ port := conn.original.dstPort
+ tcpHeader.SetSourcePort(port)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ hdr := &pkt.Header
+ length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ if gso != nil && gso.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+}
+
+// handlePacket will manipulate the port and address of the packet if the
+// connection exists.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+ if pkt.NatDone {
+ return
+ }
+
+ if hook != Prerouting && hook != Output {
+ return
+ }
+
+ conn, dir := ct.connFor(pkt)
+ if conn == nil {
+ // Connection not found for the packet or the packet is invalid.
+ return
+ }
+
+ switch hook {
+ case Prerouting:
+ handlePacketPrerouting(pkt, conn, dir)
+ case Output:
+ handlePacketOutput(pkt, conn, gso, r, dir)
+ }
+ pkt.NatDone = true
+
+ // Update the state of tcb.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other TCP states.
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+ var st tcpconntrack.Result
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if conn.tcb.IsEmpty() {
+ conn.tcb.Init(tcpHeader)
+ conn.tcbHook = hook
+ } else {
+ switch hook {
+ case conn.tcbHook:
+ st = conn.tcb.UpdateStateOutbound(tcpHeader)
+ default:
+ st = conn.tcb.UpdateStateInbound(tcpHeader)
+ }
+ }
+
+ // Delete conn if tcp connection is closed.
+ if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
+ ct.deleteConn(conn)
+ }
+}
+
+// deleteConn deletes the connection.
+func (ct *ConnTrack) deleteConn(conn *conn) {
+ if conn == nil {
+ return
+ }
+
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+
+ delete(ct.conns, conn.original.tupleID)
+ delete(ct.conns, conn.reply.tupleID)
+}
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
new file mode 100644
index 000000000..d199ded6a
--- /dev/null
+++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[DHCPv6NoConfiguration-1]
+ _ = x[DHCPv6ManagedAddress-2]
+ _ = x[DHCPv6OtherConfigurations-3]
+}
+
+const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
+
+var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
+
+func (i DHCPv6ConfigurationFromNDPRA) String() string {
+ i -= 1
+ if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) {
+ return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")"
+ }
+ return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
new file mode 100644
index 000000000..3eff141e6
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // maxPendingResolutions is the maximum number of pending link-address
+ // resolutions.
+ maxPendingResolutions = 64
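+
+ // maxPendingPacketsPerResolution is the maximum number of packets
+ // queued while waiting for a single link-address resolution.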
+ maxPendingPacketsPerResolution = 256
+)
+
+type pendingPacket struct {
+ nic *NIC
+ route *Route
+ proto tcpip.NetworkProtocolNumber
+ pkt *PacketBuffer
+}
+
+type forwardQueue struct {
+ sync.Mutex
+
+ // The packets to send once the resolver completes.
+ packets map[<-chan struct{}][]*pendingPacket
+
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ cancelChans []chan struct{}
+}
+
+func newForwardQueue() *forwardQueue {
+ return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+}
+
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ shouldWait := false
+
+ f.Lock()
+ packets, ok := f.packets[ch]
+ if !ok {
+ shouldWait = true
+ }
+ for len(packets) == maxPendingPacketsPerResolution {
+ p := packets[0]
+ packets = packets[1:]
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Release()
+ }
+ if l := len(packets); l >= maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+ }
+ f.packets[ch] = append(packets, &pendingPacket{
+ nic: n,
+ route: r,
+ proto: protocol,
+ pkt: pkt,
+ })
+ f.Unlock()
+
+ if !shouldWait {
+ return
+ }
+
+ // Wait for the link-address resolution to complete.
+ // Start a goroutine with a forwarding-cancel channel so that we can
+ // limit the maximum number of goroutines running concurrently.
+ cancel := f.newCancelChannel()
+ go func() {
+ cancelled := false
+ select {
+ case <-ch:
+ case <-cancel:
+ cancelled = true
+ }
+
+ f.Lock()
+ packets := f.packets[ch]
+ delete(f.packets, ch)
+ f.Unlock()
+
+ for _, p := range packets {
+ if cancelled {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else if _, err := p.route.Resolve(nil); err != nil {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else {
+ p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ }
+ p.route.Release()
+ }
+ }()
+}
+
+// newCancelChannel creates a channel that can cancel a pending forwarding
+// activity. The oldest channel is closed if the number of open channels would
+// exceed maxPendingResolutions.
+func (f *forwardQueue) newCancelChannel() chan struct{} {
+ f.Lock()
+ defer f.Unlock()
+
+ if len(f.cancelChans) == maxPendingResolutions {
+ ch := f.cancelChans[0]
+ f.cancelChans = f.cancelChans[1:]
+ close(ch)
+ }
+ if l := len(f.cancelChans); l >= maxPendingResolutions {
+ panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
+ }
+
+ ch := make(chan struct{})
+ f.cancelChans = append(f.cancelChans, ch)
+ return ch
+}
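+
+// A rough sketch of the eviction behavior (illustrative only): once
+// maxPendingResolutions channels are outstanding, each additional call
+// closes the oldest channel, cancelling its waiter.
+//
+//	f := newForwardQueue()
+//	for i := 0; i < maxPendingResolutions+1; i++ {
+//		_ = f.newCancelChannel() // the 65th call closes the 1st channel
+//	}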
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
new file mode 100644
index 000000000..a6546cef0
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -0,0 +1,650 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "encoding/binary"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+const (
+ fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fwdTestNetHeaderLen = 12
+ fwdTestNetDefaultPrefixLen = 8
+
+ // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match
+ // the MTU of loopback interfaces on Linux systems.
+ fwdTestNetDefaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
+)
+
+// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
+// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one-byte fields to simplify parsing.
+type fwdTestNetworkEndpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ proto *fwdTestNetworkProtocol
+ dispatcher TransportDispatcher
+ ep LinkEndpoint
+}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
+}
+
+func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
+ return f.nicID
+}
+
+func (f *fwdTestNetworkEndpoint) PrefixLen() int {
+ return f.prefixLen
+}
+
+func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
+func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
+ return &f.id
+}
+
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
+}
+
+func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
+ return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+ return 0
+}
+
+func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return f.ep.Capabilities()
+}
+
+func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return f.proto.Number()
+}
+
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+ // Add the protocol's header to the packet and send it to the link
+ // endpoint.
+ b := pkt.Header.Prepend(fwdTestNetHeaderLen)
+ b[dstAddrOffset] = r.RemoteAddress[0]
+ b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[protocolNumberOffset] = byte(params.Protocol)
+
+ return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (*fwdTestNetworkEndpoint) Close() {}
+
+// fwdTestNetworkProtocol is a network-layer protocol that implements address
+// resolution.
+type fwdTestNetworkProtocol struct {
+ addrCache *linkAddrCache
+ addrResolveDelay time.Duration
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+}
+
+func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
+}
+
+func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
+ return fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
+ return fwdTestNetDefaultPrefixLen
+}
+
+func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = netHeader
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+ return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
+}
+
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &fwdTestNetworkEndpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ proto: f,
+ dispatcher: dispatcher,
+ ep: ep,
+ }, nil
+}
+
+func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Close() {}
+
+func (f *fwdTestNetworkProtocol) Wait() {}
+
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+ if f.addrCache != nil && f.onLinkAddressResolved != nil {
+ time.AfterFunc(f.addrResolveDelay, func() {
+ f.onLinkAddressResolved(f.addrCache, addr)
+ })
+ }
+ return nil
+}
+
+func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if f.onResolveStaticAddress != nil {
+ return f.onResolveStaticAddress(addr)
+ }
+ return "", false
+}
+
+func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
+}
+
+// fwdTestPacketInfo holds all the information about an outbound packet.
+type fwdTestPacketInfo struct {
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalLinkAddress tcpip.LinkAddress
+ Pkt *PacketBuffer
+}
+
+type fwdTestLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+
+ // C is where outbound packets are queued.
+ C chan fwdTestPacketInfo
+}
+
+// InjectInbound injects an inbound packet.
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
+}
+
+// InjectLinkAddr injects an inbound packet with a remote link address.
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *fwdTestLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *fwdTestLinkEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ caps := LinkEndpointCapabilities(0)
+ return caps | CapabilityResolutionRequired
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
+ return 1 << 15
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header. Since
+// this endpoint has no header, it returns 0.
+func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ RemoteLinkAddress: r.RemoteLinkAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ Pkt: pkt,
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// WritePackets stores outbound packets into the channel.
+func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.WritePacket(r, gso, protocol, pkt)
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ Pkt: &PacketBuffer{Data: vv},
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*fwdTestLinkEndpoint) Wait() {}
+
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+ // Create a stack with the network protocol and two NICs.
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{proto},
+ })
+
+ proto.addrCache = s.linkAddrCache
+
+ // Enable forwarding.
+ s.SetForwarding(true)
+
+ // NIC 1 has the link address "a" and the network address 1.
+ ep1 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "a",
+ }
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC #1 failed:", err)
+ }
+ if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress #1 failed:", err)
+ }
+
+ // NIC 2 has the link address "b" and the network address 2.
+ ep2 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "b",
+ }
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC #2 failed:", err)
+ }
+ if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress #2 failed:", err)
+ }
+
+ // Route all packets to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
+ }
+
+ return ep1, ep2
+}
+
+func TestForwardingWithStaticResolver(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ // The network address 3 is resolved to the link address "c".
+ onResolveStaticAddress: func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolver(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithNoResolver(t *testing.T) {
+ // Create a network protocol without a resolver.
+ proto := &fwdTestNetworkProtocol{}
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ select {
+ case <-ep2.C:
+ t.Fatal("Packet should not be forwarded")
+ case <-time.After(time.Second):
+ }
+}
+
+func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
+ if !ok {
+ t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
+ }
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
new file mode 100644
index 000000000..3a20839da
--- /dev/null
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "golang.org/x/time/rate"
+)
+
+const (
+ // icmpLimit is the default maximum number of ICMP messages permitted by this
+ // rate limiter.
+ icmpLimit = 1000
+
+ // icmpBurst is the default number of ICMP messages that can be sent in a single
+ // burst.
+ icmpBurst = 50
+)
+
+// ICMPRateLimiter is a global rate limiter that controls the generation of
+// ICMP messages generated by the stack.
+type ICMPRateLimiter struct {
+ *rate.Limiter
+}
+
+// NewICMPRateLimiter returns a global rate limiter for controlling the rate
+// at which ICMP messages are generated by the stack.
+func NewICMPRateLimiter() *ICMPRateLimiter {
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+}
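+
+// A minimal usage sketch (hypothetical caller): consult the limiter before
+// generating an ICMP message, so bursts are capped at icmpBurst and the
+// sustained rate at icmpLimit messages per second.
+//
+//	limiter := NewICMPRateLimiter()
+//	if limiter.Allow() {
+//		// Generate and send the ICMP message.
+//	}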
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
new file mode 100644
index 000000000..974d77c36
--- /dev/null
+++ b/pkg/tcpip/stack/iptables.go
@@ -0,0 +1,367 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Table names.
+const (
+ TablenameNat = "nat"
+ TablenameMangle = "mangle"
+ TablenameFilter = "filter"
+)
+
+// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+const (
+ ChainNamePrerouting = "PREROUTING"
+ ChainNameInput = "INPUT"
+ ChainNameForward = "FORWARD"
+ ChainNameOutput = "OUTPUT"
+ ChainNamePostrouting = "POSTROUTING"
+)
+
+// HookUnset indicates that there is no hook set for an entrypoint or
+// underflow.
+const HookUnset = -1
+
+// DefaultTables returns a default set of tables. Each chain is set to accept
+// all packets.
+func DefaultTables() *IPTables {
+ // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
+ // iotas.
+ return &IPTables{
+ tables: map[string]Table{
+ TablenameNat: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
+ },
+ UserChains: map[string]int{},
+ },
+ TablenameMangle: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ UserChains: map[string]int{},
+ },
+ TablenameFilter: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ Underflows: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ UserChains: map[string]int{},
+ },
+ },
+ priorities: map[Hook][]string{
+ Input: []string{TablenameNat, TablenameFilter},
+ Prerouting: []string{TablenameMangle, TablenameNat},
+ Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ },
+ connections: ConnTrack{
+ conns: make(map[tupleID]tuple),
+ },
+ }
+}
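+
+// For instance, with the defaults above the Output hook traverses the
+// mangle, nat, and filter tables in that order:
+//
+//	it := DefaultTables()
+//	names := it.GetPriorities(Output) // ["mangle", "nat", "filter"]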
+
+// EmptyFilterTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyFilterTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
+ },
+ UserChains: map[string]int{},
+ }
+}
+
+// EmptyNatTable returns a Table with no rules and the nat table chains
+// mapped to HookUnset.
+func EmptyNatTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ UserChains: map[string]int{},
+ }
+}
+
+// GetTable returns the table with the given name.
+func (it *IPTables) GetTable(name string) (Table, bool) {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ t, ok := it.tables[name]
+ return t, ok
+}
+
+// ReplaceTable replaces or inserts the table with the given name.
+func (it *IPTables) ReplaceTable(name string, table Table) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ it.modified = true
+ it.tables[name] = table
+}
+
+// GetPriorities returns the list of table names associated with hook, in the
+// order they should be traversed.
+func (it *IPTables) GetPriorities(hook Hook) []string {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ return it.priorities[hook]
+}
+
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainDrop indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
+// Check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: pkt.NetworkHeader is set.
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ it.mu.RLock()
+ if !it.modified {
+ it.mu.RUnlock()
+ return true
+ }
+ it.mu.RUnlock()
+
+ // Packets are manipulated only if a connection and a matching
+ // NAT rule exist.
+ it.connections.handlePacket(pkt, hook, gso, r)
+
+ // Go through each table containing the hook.
+ for _, tablename := range it.GetPriorities(hook) {
+ table, _ := it.GetTable(tablename)
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ // If the table returns Accept, move on to the next table.
+ case chainAccept:
+ continue
+ // The Drop verdict is final.
+ case chainDrop:
+ return false
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ }
+ }
+
+ // Every table returned Accept.
+ return true
+}
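+
+// A minimal sketch of a caller (hypothetical; a real caller supplies a GSO,
+// Route, and address where relevant):
+//
+//	if !it.Check(Input, pkt, nil, nil, "", nicName) {
+//		// The packet was dropped by a rule; stop processing it.
+//	}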
+
+// CheckPackets runs pkts through the rules for hook and returns a map of packets that
+// should not go forward.
+//
+// Precondition: every packet in pkts is an IPv4 packet of at least length
+// header.IPv4MinimumSize.
+//
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
+// precondition.
+//
+// NOTE: unlike the Check API, the returned map contains packets that should
+// be dropped.
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if !pkt.NatDone {
+ if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
+ }
+ drop[pkt] = struct{}{}
+ }
+ if pkt.NatDone {
+ if natPkts == nil {
+ natPkts = make(map[*PacketBuffer]struct{})
+ }
+ natPkts[pkt] = struct{}{}
+ }
+ }
+ }
+ return drop, natPkts
+}
+
+// Precondition: pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
+// precondition.
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
+ // Start from ruleIdx and walk the list of rules until a rule gives us
+ // a verdict.
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ case RuleAccept:
+ return chainAccept
+
+ case RuleDrop:
+ return chainDrop
+
+ case RuleReturn:
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
+ }
+
+ // We got through the entire table without a decision. Default to DROP
+ // for safety.
+ return chainDrop
+}
+
+// Precondition: pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
+// precondition.
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
+ rule := table.Rules[ruleIdx]
+
+ // If pkt.NetworkHeader hasn't been set yet, it will be contained in
+ // pkt.Data.
+ if pkt.NetworkHeader == nil {
+ var ok bool
+ pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ // Precondition has been violated.
+ panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
+ }
+ }
+
+ // Check whether the packet matches the IP header filter.
+ if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
+ }
+
+ // Go through each rule matcher. If they all match, run
+ // the rule target.
+ for _, matcher := range rule.Matchers {
+ matches, hotdrop := matcher.Match(hook, pkt, "")
+ if hotdrop {
+ return RuleDrop, 0
+ }
+ if !matches {
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
+ }
+ }
+
+ // All the matchers matched, so run the target.
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
new file mode 100644
index 000000000..d43f60c67
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -0,0 +1,164 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
+
+// Action implements Target.Action.
+func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ return RuleAccept, 0
+}
+
+// DropTarget drops packets.
+type DropTarget struct{}
+
+// Action implements Target.Action.
+func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ return RuleDrop, 0
+}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ return RuleReturn, 0
+}
+
+// RedirectTarget redirects the packet by modifying the destination port/IP.
+// The Min and Max values for the IP address and ports in the struct indicate
+// the range of values that can be used for redirection.
+type RedirectTarget struct {
+ // TODO(gvisor.dev/issue/170): Other flags need to be added after
+ // we support them.
+ // RangeProtoSpecified indicates that a single port is specified for
+ // redirection.
+ RangeProtoSpecified bool
+
+ // MinIP indicates address used to redirect.
+ MinIP tcpip.Address
+
+ // MaxIP indicates address used to redirect.
+ MaxIP tcpip.Address
+
+ // MinPort indicates port used to redirect.
+ MinPort uint16
+
+ // MaxPort indicates port used to redirect.
+ MaxPort uint16
+}
+
+// Action implements Target.Action.
+// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
+// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// of which should be the case.
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
+
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
+ return RuleDrop, 0
+ }
+
+ // Change the address to localhost (127.0.0.1) in Output and
+ // to primary address of the incoming interface in Prerouting.
+ switch hook {
+ case Output:
+ rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
+ rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ case Prerouting:
+ rt.MinIP = address
+ rt.MaxIP = address
+ default:
+ panic("redirect target is supported only on output and prerouting hooks")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
+ // we need to change dest address (for OUTPUT chain) or ports.
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ switch protocol := netHeader.TransportProtocol(); protocol {
+ case header.UDPProtocolNumber:
+ udpHeader := header.UDP(pkt.TransportHeader)
+ udpHeader.SetDestinationPort(rt.MinPort)
+
+ // Calculate UDP checksum and set it.
+ if hook == Output {
+ udpHeader.SetChecksum(0)
+ hdr := &pkt.Header
+ length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(protocol, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHeader.SetChecksum(0)
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+ }
+ }
+ // Change destination address.
+ netHeader.SetDestinationAddress(rt.MinIP)
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ pkt.NatDone = true
+ case header.TCPProtocolNumber:
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up a connection for the matching NAT rule. Only the first
+ // packet of the connection reaches this point; subsequent packets
+ // are manipulated by connection tracking.
+ if conn := ct.createConnFor(pkt, hook, rt); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
+ }
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
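+
+// A hypothetical rule using this target (sketch only): redirect matching TCP
+// traffic to port 80.
+//
+//	rule := Rule{
+//		Filter: IPHeaderFilter{Protocol: header.TCPProtocolNumber},
+//		Target: RedirectTarget{MinPort: 80, MaxPort: 80},
+//	}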
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
new file mode 100644
index 000000000..c528ec381
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -0,0 +1,253 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// A Hook specifies one of the hooks built into the network stack.
+//
+// Userspace app Userspace app
+// ^ |
+// | v
+// [Input] [Output]
+// ^ |
+// | v
+// | routing
+// | |
+// | v
+// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
+type Hook uint
+
+// These values correspond to values in include/uapi/linux/netfilter.h.
+const (
+ // Prerouting happens before a packet is routed to applications or
+ // forwarded.
+ Prerouting Hook = iota
+
+ // Input happens before a packet reaches an application.
+ Input
+
+ // Forward happens once it's decided that a packet should be forwarded
+ // to another host.
+ Forward
+
+ // Output happens after a packet is written by an application to be
+ // sent out.
+ Output
+
+ // Postrouting happens just before a packet goes out on the wire.
+ Postrouting
+
+ // The total number of hooks.
+ NumHooks
+)
+
+// A RuleVerdict is what a rule decides should be done with a packet.
+type RuleVerdict int
+
+const (
+ // RuleAccept indicates the packet should continue through netstack.
+ RuleAccept RuleVerdict = iota
+
+ // RuleDrop indicates the packet should be dropped.
+ RuleDrop
+
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
+
+ // RuleReturn indicates the packet should return to the previous chain.
+ RuleReturn
+)
+
+// IPTables holds all the tables for a netstack.
+type IPTables struct {
+ // mu protects tables, priorities, and modified.
+ mu sync.RWMutex
+
+ // tables maps table names to tables. User tables have arbitrary names.
+ // mu needs to be locked for accessing.
+ tables map[string]Table
+
+ // priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook. mu needs to be locked for accessing.
+ priorities map[Hook][]string
+
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
+
+ connections ConnTrack
+}
+
+// A Table defines a set of chains and hooks into the network stack. It is
+// really just a list of rules.
+type Table struct {
+ // Rules holds the rules that make up the table.
+ Rules []Rule
+
+ // BuiltinChains maps builtin chains to their entrypoint rule in Rules.
+ BuiltinChains map[Hook]int
+
+ // Underflows maps builtin chains to their underflow rule in Rules
+ // (i.e. the rule to execute if the chain returns without a verdict).
+ Underflows map[Hook]int
+
+ // UserChains holds user-defined chains keyed by name. Users
+ // can give their chains arbitrary names.
+ UserChains map[string]int
+}
+
+// ValidHooks returns a bitmap of the builtin hooks for the given table.
+func (table *Table) ValidHooks() uint32 {
+ hooks := uint32(0)
+ for hook := range table.BuiltinChains {
+ hooks |= 1 << hook
+ }
+ return hooks
+}
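+
+// For example, the filter table's builtin chains are Input, Forward, and
+// Output, so its bitmap is 1<<Input | 1<<Forward | 1<<Output:
+//
+//	t := EmptyFilterTable()
+//	hooks := t.ValidHooks() // 0b1110 == 14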
+
+// A Rule is a packet processing rule. It consists of two pieces: zero or
+// more matchers, each of which specifies which packets the rule applies to
+// (a rule with no matchers applies to any packet), and a target that
+// decides what to do with matching packets.
+type Rule struct {
+ // Filter holds basic IP filtering fields common to every rule.
+ Filter IPHeaderFilter
+
+ // Matchers is the list of matchers for this rule.
+ Matchers []Matcher
+
+ // Target is the action to invoke if all the matchers match the packet.
+ Target Target
+}
+
+// IPHeaderFilter holds basic IP filtering data common to every rule.
+type IPHeaderFilter struct {
+ // Protocol matches the transport protocol.
+ Protocol tcpip.TransportProtocolNumber
+
+ // Dst matches the destination IP address.
+ Dst tcpip.Address
+
+ // DstMask masks bits of the destination IP address when comparing with
+ // Dst.
+ DstMask tcpip.Address
+
+ // DstInvert inverts the meaning of the destination IP check, i.e. when
+ // true the filter will match packets that fail the destination
+ // comparison.
+ DstInvert bool
+
+ // Src matches the source IP address.
+ Src tcpip.Address
+
+ // SrcMask masks bits of the source IP address when comparing with Src.
+ SrcMask tcpip.Address
+
+ // SrcInvert inverts the meaning of the source IP check, i.e. when true the
+ // filter will match packets that fail the source comparison.
+ SrcInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the
+ // packet.
+ OutputInterface string
+
+ // OutputInterfaceMask masks the characters of the interface name when
+ // comparing with OutputInterface.
+ OutputInterfaceMask string
+
+ // OutputInterfaceInvert inverts the meaning of outgoing interface check,
+ // i.e. when true the filter will match packets that fail the outgoing
+ // interface comparison.
+ OutputInterfaceInvert bool
+}
+
+// match returns whether hdr matches the filter.
+func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the source and destination IPs.
+ if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ return false
+ }
+
+ // Check the output interface.
+ // TODO(gvisor.dev/issue/170): Add the check for the FORWARD and
+ // POSTROUTING hooks once they are supported.
+ if hook == Output {
+ n := len(fl.OutputInterface)
+ if n == 0 {
+ return true
+ }
+
+ // If the interface name ends with '+', any interface which begins
+ // with the name should be matched.
+ ifName := fl.OutputInterface
+ matches := true
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return fl.OutputInterfaceInvert != matches
+ }
+
+ return true
+}
+
+// filterAddress returns whether addr matches the filter.
+func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
+ matches := true
+ for i := range filterAddr {
+ if addr[i]&mask[i] != filterAddr[i] {
+ matches = false
+ break
+ }
+ }
+ return matches != invert
+}
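+
+// exampleFilterAddress is a minimal illustrative sketch (hypothetical, not
+// part of this change): with a /24-style mask only the first three octets
+// are compared, so 192.168.1.5 matches a 192.168.1.0 filter.
+func exampleFilterAddress() bool {
+ addr := tcpip.Address("\xc0\xa8\x01\x05") // 192.168.1.5
+ mask := tcpip.Address("\xff\xff\xff\x00") // 255.255.255.0
+ filterAddr := tcpip.Address("\xc0\xa8\x01\x00") // 192.168.1.0
+ return filterAddress(addr, mask, filterAddr, false /* invert */) // true
+}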
+
+// A Matcher is the interface for matching packets.
+type Matcher interface {
+ // Name returns the name of the Matcher.
+ Name() string
+
+ // Match returns whether the packet matches and whether the packet
+ // should be "hotdropped", i.e. dropped immediately. This is usually
+ // used for suspicious packets.
+ //
+ // Precondition: packet.NetworkHeader is set.
+ Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+}
+
+// A Target is the interface for taking an action for a packet.
+type Target interface {
+ // Action takes an action on the packet and returns a verdict on how
+ // traversal should (or should not) continue. If the return value is
+ // Jump, it also returns the index of the rule to jump to.
+ Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+}
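+
+// exampleDropTarget is a minimal illustrative sketch (hypothetical, not part
+// of this change) of a Target implementation: it unconditionally drops the
+// packet, regardless of hook or rule context.
+type exampleDropTarget struct{}
+
+var _ Target = exampleDropTarget{}
+
+// Action implements Target.Action.
+func (exampleDropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ // The returned index is only meaningful for RuleJump verdicts.
+ return RuleDrop, 0
+}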
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
new file mode 100644
index 000000000..403557fd7
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -0,0 +1,295 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const linkAddrCacheSize = 512 // max cache entries
+
+// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
+//
+// Entries are stored in a map with an LRU list; when the cache is full, the
+// least-recently used entry is evicted first.
+//
+// This struct is safe for concurrent use.
+type linkAddrCache struct {
+ // ageLimit is how long a cache entry is valid for.
+ ageLimit time.Duration
+
+ // resolutionTimeout is the amount of time to wait for a link request to
+ // resolve an address.
+ resolutionTimeout time.Duration
+
+ // resolutionAttempts is the number of times an address is attempted to be
+ // resolved before failing.
+ resolutionAttempts int
+
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
+}
+
+// entryState controls the state of a single entry in the cache.
+type entryState int
+
+const (
+ // incomplete means that there is an outstanding request to resolve the
+ // address. This is the initial state.
+ incomplete entryState = iota
+ // ready means that the address has been resolved and can be used.
+ ready
+ // failed means that address resolution timed out and the address
+ // could not be resolved.
+ failed
+)
+
+// String implements Stringer.
+func (s entryState) String() string {
+ switch s {
+ case incomplete:
+ return "incomplete"
+ case ready:
+ return "ready"
+ case failed:
+ return "failed"
+ default:
+ return fmt.Sprintf("unknown(%d)", s)
+ }
+}
+
+// A linkAddrEntry is an entry in the linkAddrCache.
+// This struct is thread-compatible.
+type linkAddrEntry struct {
+ linkAddrEntryEntry
+
+ addr tcpip.FullAddress
+ linkAddr tcpip.LinkAddress
+ expiration time.Time
+ s entryState
+
+ // wakers is a set of waiters for address resolution result. Anytime
+ // state transitions out of incomplete these waiters are notified.
+ wakers map[*sleep.Waker]struct{}
+
+ // done is used to allow callers to wait on address resolution. It is nil iff
+ // s is incomplete and resolution is not yet in progress.
+ done chan struct{}
+}
+
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally; this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
+ // Notify whoever is waiting on address resolution when transitioning
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ }
+ e.done = nil
+ }
+
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
+ }
+ e.s = ns
+}
+
+func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
+ delete(e.wakers, w)
+}
+
+// add adds a k -> v mapping to the cache.
+func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
+
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
+
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
+}
+
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
+ }
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
+
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
+
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Note
+ // that the state passed doesn't matter when the zero time is passed.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
+ }
+
+ *entry = linkAddrEntry{
+ addr: k,
+ s: incomplete,
+ }
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
+ return entry
+}
+
+// get reports any known link address for k.
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+ if linkRes != nil {
+ if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ return addr, nil, nil
+ }
+ }
+
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+ }
+
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
+
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
+
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ }
+
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+}
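+
+// exampleGetAndWait is a minimal illustrative sketch (hypothetical, not part
+// of this change) of the intended calling pattern: when get returns
+// ErrWouldBlock, wait on the returned channel, which is closed once the
+// entry leaves the incomplete state, then retry.
+func exampleGetAndWait(c *linkAddrCache, k tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+ for {
+ addr, ch, err := c.get(k, linkRes, "" /* localAddr */, nil /* linkEP */, nil /* waker */)
+ if err != tcpip.ErrWouldBlock {
+ return addr, err
+ }
+ <-ch // Resolution finished (successfully or not); retry.
+ }
+}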
+
+// removeWaker removes a waker previously added through get().
+func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+
+ if entry, ok := c.cache.table[k]; ok {
+ entry.removeWaker(waker)
+ }
+}
+
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+ for i := 0; ; i++ {
+ // Send link request, then wait for the timeout limit and check
+ // whether the request succeeded.
+ linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+
+ select {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
+ return
+ }
+ case <-done:
+ return
+ }
+ }
+}
+
+// checkLinkRequest checks whether a previous attempt to resolve the address
+// has succeeded and marks the entry accordingly (e.g. ready, failed). It
+// returns true if the request can stop, false if another request should be
+// sent.
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
+ if !ok {
+ // Entry was evicted from the cache.
+ return true
+ }
+ switch s := entry.s; s {
+ case ready, failed:
+ // Entry was made ready by resolver or failed. Either way we're done.
+ case incomplete:
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
+ }
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+ return true
+}
+
+func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+ c := &linkAddrCache{
+ ageLimit: ageLimit,
+ resolutionTimeout: resolutionTimeout,
+ resolutionAttempts: resolutionAttempts,
+ }
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
+}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
new file mode 100644
index 000000000..1baa498d0
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -0,0 +1,277 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+type testaddr struct {
+ addr tcpip.FullAddress
+ linkAddr tcpip.LinkAddress
+}
+
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
+
+type testLinkAddressResolver struct {
+ cache *linkAddrCache
+ delay time.Duration
+ onLinkAddressRequest func()
+}
+
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
+ return nil
+}
+
+func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
+ for _, ta := range testAddrs {
+ if ta.addr.Addr == addr {
+ r.cache.add(ta.addr, ta.linkAddr)
+ break
+ }
+ }
+}
+
+func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "broadcast" {
+ return "mac_broadcast", true
+ }
+ return "", false
+}
+
+func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 1
+}
+
+func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ for {
+ if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ return got, err
+ }
+ s.Fetch(true)
+ }
+}
+
+func TestCacheOverflow(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
+ c.add(e.addr, e.linkAddr)
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ }
+ }
+ // Expect to find at least half of the most recent entries.
+ for i := 0; i < linkAddrCacheSize/2; i++ {
+ e := testAddrs[i]
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ }
+ }
+ // The earliest entries should no longer be in the cache.
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
+ if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+ }
+ }
+}
+
+func TestCacheConcurrent(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+ var wg sync.WaitGroup
+ for r := 0; r < 16; r++ {
+ wg.Add(1)
+ go func() {
+ for _, e := range testAddrs {
+ c.add(e.addr, e.linkAddr)
+ c.get(e.addr, nil, "", nil, nil) // make work for gotsan
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ // All goroutines add in the same order and add more values than
+ // can fit in the cache, so our eviction strategy requires that
+ // the last entry be present and the first be missing.
+ e := testAddrs[len(testAddrs)-1]
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
+ e = testAddrs[0]
+ if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+func TestCacheAgeLimit(t *testing.T) {
+ c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ e := testAddrs[0]
+ c.add(e.addr, e.linkAddr)
+ time.Sleep(50 * time.Millisecond)
+ if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+func TestCacheReplace(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ e := testAddrs[0]
+ l2 := e.linkAddr + "2"
+ c.add(e.addr, e.linkAddr)
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
+ c.add(e.addr, l2)
+ got, _, err = c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != l2 {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
+ }
+}
+
+func TestCacheResolution(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+ linkRes := &testLinkAddressResolver{cache: c}
+ for i, ta := range testAddrs {
+ got, err := getBlocking(c, ta.addr, linkRes)
+ if err != nil {
+ t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+ }
+ if got != ta.linkAddr {
+ t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+ }
+ }
+
+ // Check that after resolved, address stays in the cache and never returns WouldBlock.
+ for i := 0; i < 10; i++ {
+ e := testAddrs[len(testAddrs)-1]
+ got, _, err := c.get(e.addr, linkRes, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+ }
+}
+
+func TestCacheResolutionFailed(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ linkRes := &testLinkAddressResolver{cache: c}
+
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
+ // First, sanity check that resolution is working...
+ e := testAddrs[0]
+ got, err := getBlocking(c, e.addr, linkRes)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
+ before := atomic.LoadUint32(&requestCount)
+
+ e.addr.Addr += "2"
+ if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
+}
+
+func TestCacheResolutionTimeout(t *testing.T) {
+ resolverDelay := 500 * time.Millisecond
+ expiration := resolverDelay / 10
+ c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
+
+ e := testAddrs[0]
+ if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+// TestStaticResolution checks that static link addresses are resolved immediately and don't
+// send resolution requests.
+func TestStaticResolution(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
+ linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
+
+ addr := tcpip.Address("broadcast")
+ want := tcpip.LinkAddress("mac_broadcast")
+ got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
+ }
+ if got != want {
+ t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
+ }
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
new file mode 100644
index 000000000..e28c23d66
--- /dev/null
+++ b/pkg/tcpip/stack/ndp.go
@@ -0,0 +1,1981 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "log"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending NDP Neighbor solicitation messages.
+ //
+ // Default = 1s (from RFC 4861 section 10).
+ defaultRetransmitTimer = time.Second
+
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when a NIC becomes enabled.
+ //
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
+
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
+ //
+ // Default = 4s (from RFC 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from RFC 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
+ // defaultHandleRAs is the default configuration for whether or not to
+ // handle incoming Router Advertisements as a host.
+ defaultHandleRAs = true
+
+ // defaultDiscoverDefaultRouters is the default configuration for
+ // whether or not to discover default routers from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverDefaultRouters = true
+
+ // defaultDiscoverOnLinkPrefixes is the default configuration for
+ // whether or not to discover on-link prefixes from incoming Router
+ // Advertisements' Prefix Information option, as a host.
+ defaultDiscoverOnLinkPrefixes = true
+
+ // defaultAutoGenGlobalAddresses is the default configuration for
+ // whether or not to generate global IPv6 addresses in response to
+ // receiving a new Prefix Information option with its Autonomous
+ // Address AutoConfiguration flag set, as a host.
+ //
+ // Default = true.
+ defaultAutoGenGlobalAddresses = true
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
+ // not impose a minimum Retransmit Timer, but we do here to make sure
+ // the messages are not sent all at once. We also come to this value
+ // because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1.
+ // Note, the unit of the RetransmitTimer field in the Router
+ // Advertisement is milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
+ // MaxDiscoveredDefaultRouters is the maximum number of discovered
+ // default routers. The stack should stop discovering new routers after
+ // discovering MaxDiscoveredDefaultRouters routers.
+ //
+ // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
+ // SHOULD be more.
+ MaxDiscoveredDefaultRouters = 10
+
+ // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
+ // on-link prefixes. The stack should stop discovering new on-link
+ // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
+ // prefixes.
+ MaxDiscoveredOnLinkPrefixes = 10
+
+ // validPrefixLenForAutoGen is the expected prefix length that an
+ // address can be generated for. Must be 64 bits as the interface
+ // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
+ // 128 - 64 = 64.
+ validPrefixLenForAutoGen = 64
+
+ // defaultAutoGenTempGlobalAddresses is the default configuration for whether
+ // or not to generate temporary SLAAC addresses.
+ defaultAutoGenTempGlobalAddresses = true
+
+ // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 7 days (from RFC 4941 section 5).
+ defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour
+
+ // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 1 day (from RFC 4941 section 5).
+ defaultMaxTempAddrPreferredLifetime = 24 * time.Hour
+
+ // defaultRegenAdvanceDuration is the default duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ //
+ // Default = 5s (from RFC 4941 section 5).
+ defaultRegenAdvanceDuration = 5 * time.Second
+
+ // minRegenAdvanceDuration is the minimum duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ minRegenAdvanceDuration = time.Duration(0)
+
+ // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
+ // SLAAC address regenerations in response to a NIC-local conflict.
+ maxSLAACAddrLocalRegenAttempts = 10
+)
+
+var (
+ // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
+ // Lifetime to update the valid lifetime of a generated address by
+ // SLAAC.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Min = 2hrs.
+ MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
+
+ // MaxDesyncFactor is the upper bound for the preferred lifetime's desync
+ // factor for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Must be greater than 0.
+ //
+ // Max = 10m (from RFC 4941 section 5).
+ MaxDesyncFactor = 10 * time.Minute
+
+ // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
+ // maximum preferred lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address will be preferred for at
+ // least 1hr if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
+
+ // MinMaxTempAddrValidLifetime is the minimum value allowed for the
+ // maximum valid lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address will be valid for at least
+ // 2hrs if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrValidLifetime = 2 * time.Hour
+)
+
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ _ DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // will return all available configuration information.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
+)
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus will be called when the DAD process
+ // for an address (addr) on a NIC (with ID nicID) completes. resolved
+ // will be set to true if DAD completed successfully (no duplicate addr
+ // detected); false otherwise (addr was detected to be a duplicate on
+ // the link the NIC is a part of, or it was stopped for some other
+ // reason, such as the address being removed). If an error occurred
+ // during DAD, err will be set and resolved must be ignored.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+
+ // OnDefaultRouterDiscovered will be called when a new default router is
+ // discovered. Implementations must return true if the newly discovered
+ // router should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+
+ // OnDefaultRouterInvalidated will be called when a discovered default
+ // router that was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+
+ // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
+ // discovered. Implementations must return true if the newly discovered
+ // on-link prefix should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+
+ // OnOnLinkPrefixInvalidated will be called when a discovered on-link
+ // prefix that was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+
+ // OnAutoGenAddress will be called when a new prefix with its
+ // autonomous address-configuration flag set has been received and SLAAC
+ // has been performed. Implementations may prevent the stack from
+ // assigning the address to the NIC by returning false.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+
+ // OnAutoGenAddressDeprecated will be called when an auto-generated
+ // address (as part of SLAAC) has been deprecated, but is still
+ // considered valid. Note, if an address is invalidated at the same
+ // time it is deprecated, the deprecation event MAY be omitted.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnAutoGenAddressInvalidated will be called when an auto-generated
+ // address (as part of SLAAC) has been invalidated.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnRecursiveDNSServerOption will be called when an NDP option with
+ // recursive DNS servers has been received. Note, addrs may contain
+ // link-local addresses.
+ //
+ // It is up to the caller to use the DNS Servers only for their valid
+ // lifetime. OnRecursiveDNSServerOption may be called for new or
+ // already known DNS servers. If called with known DNS servers, their
+ // valid lifetimes must be refreshed to lifetime (it may be increased,
+ // decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDNSSearchListOption will be called when an NDP option with a DNS
+ // search list has been received.
+ //
+ // It is up to the caller to use the domain names in the search list
+ // for only their valid lifetime. OnDNSSearchListOption may be called
+ // with new or already known domain names. If called with known domain
+ // names, their valid lifetimes must be refreshed to lifetime (it may
+ // be increased, decreased, or completely invalidated when lifetime = 0).
+ OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+
+ // OnDHCPv6Configuration will be called with an updated configuration that is
+ // available via DHCPv6 for a specified NIC.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
+}
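+
+// exampleNDPDispatcher is a minimal illustrative sketch (hypothetical, not
+// part of this change) of an NDPDispatcher implementation: it remembers
+// every discovered router and prefix and accepts every auto-generated
+// address. Real integrators would typically log or propagate these events.
+type exampleNDPDispatcher struct{}
+
+var _ NDPDispatcher = exampleNDPDispatcher{}
+
+func (exampleNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+func (exampleNDPDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { return true }
+func (exampleNDPDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+func (exampleNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { return true }
+func (exampleNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
+func (exampleNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { return true }
+func (exampleNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+func (exampleNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+func (exampleNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
+}
+func (exampleNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {}
+func (exampleNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {}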
+
+// NDPConfigurations is the NDP configurations for the netstack.
+type NDPConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor solicitation
+ // messages.
+ //
+ // Must be greater than or equal to 1ms.
+ RetransmitTimer time.Duration
+
+ // The number of Router Solicitation messages to send when the NIC
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
+ // HandleRAs determines whether or not Router Advertisements will be
+ // processed.
+ HandleRAs bool
+
+ // DiscoverDefaultRouters determines whether or not default routers will
+ // be discovered from Router Advertisements. This configuration is
+ // ignored if HandleRAs is false.
+ DiscoverDefaultRouters bool
+
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes
+ // will be discovered from Router Advertisements' Prefix Information
+ // option. This configuration is ignored if HandleRAs is false.
+ DiscoverOnLinkPrefixes bool
+
+ // AutoGenGlobalAddresses determines whether or not global IPv6
+ // addresses will be generated for a NIC in response to receiving a new
+ // Prefix Information option with its Autonomous Address
+ // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC).
+ //
+ // Note, if an address was already generated for some unique prefix, as
+ // part of SLAAC, this option does not affect whether or not the
+ // lifetime(s) of the generated address changes; this option only
+ // affects the generation of new addresses as part of SLAAC.
+ AutoGenGlobalAddresses bool
+
+ // AutoGenAddressConflictRetries determines how many times to attempt to retry
+ // generation of a permanent auto-generated address in response to DAD
+ // conflicts.
+ //
+ // If the method used to generate the address does not support creating
+ // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
+ // MAC address), then no attempt will be made to resolve the conflict.
+ AutoGenAddressConflictRetries uint8
+
+ // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
+ // addresses will be generated for a NIC as part of SLAAC privacy extensions,
+ // RFC 4941.
+ //
+ // Ignored if AutoGenGlobalAddresses is false.
+ AutoGenTempGlobalAddresses bool
+
+ // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary
+ // SLAAC addresses.
+ MaxTempAddrValidLifetime time.Duration
+
+ // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for
+ // temporary SLAAC addresses.
+ MaxTempAddrPreferredLifetime time.Duration
+
+ // RegenAdvanceDuration is the duration before the deprecation of a temporary
+ // address when a new address will be generated.
+ RegenAdvanceDuration time.Duration
+}
+
+// DefaultNDPConfigurations returns an NDPConfigurations populated with
+// default values.
+func DefaultNDPConfigurations() NDPConfigurations {
+ return NDPConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
+ MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime,
+ RegenAdvanceDuration: defaultRegenAdvanceDuration,
+ }
+}
+
+// validate modifies c in place so that it only holds valid values. If invalid
+// values are present, the corresponding default or minimum values are used
+// instead.
+func (c *NDPConfigurations) validate() {
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
+
+ if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime {
+ c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime
+ }
+
+ if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime {
+ c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime
+ }
+
+ if c.RegenAdvanceDuration < minRegenAdvanceDuration {
+ c.RegenAdvanceDuration = minRegenAdvanceDuration
+ }
+}
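+
+// exampleValidate is a minimal illustrative sketch (hypothetical, not part
+// of this change): a zero RetransmitTimer is below the 1ms minimum, so
+// validate replaces it with the 1s default.
+func exampleValidate() time.Duration {
+ c := NDPConfigurations{RetransmitTimer: 0}
+ c.validate()
+ return c.RetransmitTimer // == defaultRetransmitTimer (1s)
+}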
+
+// ndpState is the per-interface NDP state.
+type ndpState struct {
+ // The NIC this ndpState is for.
+ nic *NIC
+
+ // configs is the per-interface NDP configurations.
+ configs NDPConfigurations
+
+ // The DAD state for each tentative address, used to send the next NS
+ // message or resolve the address.
+ dad map[tcpip.Address]dadState
+
+ // The default routers discovered through Router Advertisements.
+ defaultRouters map[tcpip.Address]defaultRouterState
+
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer *time.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this ndpState is associated with. MUST be set when the timer is
+ // set.
+ done *bool
+ }
+
+ // The on-link prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
+
+ // temporaryIIDHistory is the history value used to generate a new temporary
+ // IID.
+ temporaryIIDHistory [header.IIDSize]byte
+
+ // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for
+ // temporary SLAAC addresses.
+ temporaryAddressDesyncFactor time.Duration
+}
+
+// dadState holds the Duplicate Address Detection timer and channel to signal
+// to the DAD goroutine that DAD should stop.
+type dadState struct {
+ // The DAD timer to send the next NS message, or resolve the address.
+ timer *time.Timer
+
+ // Used to let the DAD timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this dadState is associated with.
+ done *bool
+}
+
+// defaultRouterState holds data associated with a default router discovered by
+// a Router Advertisement (RA).
+type defaultRouterState struct {
+ // Timer to invalidate the default router.
+ //
+ // Must not be nil.
+ invalidationTimer *tcpip.CancellableTimer
+}
+
+// onLinkPrefixState holds data associated with an on-link prefix discovered by
+// a Router Advertisement's Prefix Information option (PI) when the NDP
+// configuration is set to discover on-link prefixes.
+type onLinkPrefixState struct {
+ // Timer to invalidate the on-link prefix.
+ //
+ // Must not be nil.
+ invalidationTimer *tcpip.CancellableTimer
+}
+
+// tempSLAACAddrState holds state associated with a temporary SLAAC address.
+type tempSLAACAddrState struct {
+ // Timer to deprecate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ deprecationTimer *tcpip.CancellableTimer
+
+ // Timer to invalidate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ invalidationTimer *tcpip.CancellableTimer
+
+ // Timer to regenerate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ regenTimer *tcpip.CancellableTimer
+
+ createdAt time.Time
+
+ // The address's endpoint.
+ //
+ // Must not be nil.
+ ref *referencedNetworkEndpoint
+
+ // Has a new temporary SLAAC address already been regenerated?
+ regenerated bool
+}
+
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
+ // Timer to deprecate the prefix.
+ //
+ // Must not be nil.
+ deprecationTimer *tcpip.CancellableTimer
+
+ // Timer to invalidate the prefix.
+ //
+ // Must not be nil.
+ invalidationTimer *tcpip.CancellableTimer
+
+ // Nonzero only when the address is not valid forever.
+ validUntil time.Time
+
+ // Nonzero only when the address is not preferred forever.
+ preferredUntil time.Time
+
+ // State associated with the stable address generated for the prefix.
+ stableAddr struct {
+ // The address's endpoint.
+ //
+ // May only be nil when the address is being (re-)generated. Otherwise,
+ // must not be nil as all SLAAC prefixes must have a stable address.
+ ref *referencedNetworkEndpoint
+
+ // The number of times an address has been generated locally where the NIC
+ // already had the generated address.
+ localGenerationFailures uint8
+ }
+
+ // The temporary (short-lived) addresses generated for the SLAAC prefix.
+ tempAddrs map[tcpip.Address]tempSLAACAddrState
+
+ // The next two fields are used by both stable and temporary addresses
+ // generated for a SLAAC prefix. This is safe as only 1 address will be
+ // in the generation and DAD process at any time. That is, no two addresses
+ // will be generated at the same time for a given SLAAC prefix.
+
+ // The number of times an address has been generated and added to the NIC.
+ //
+ // Addresses may be regenerated in response to DAD conflicts.
+ generationAttempts uint8
+
+ // The maximum number of times to attempt regeneration of a SLAAC address
+ // in response to DAD conflicts.
+ maxGenerationAttempts uint8
+}
+
+// startDuplicateAddressDetection performs Duplicate Address Detection.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+ // addr must be a valid unicast IPv6 address.
+ if !header.IsV6UnicastAddress(addr) {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
+ if ref.getKind() != permanentTentative {
+ // The endpoint should be marked as tentative since we are starting DAD.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in the
+ // DAD process.
+ if _, ok := ndp.dad[addr]; ok {
+ // Should never happen because we should only ever call this function for
+ // newly created addresses. If we attempted to "add" an address that already
+ // existed, we would get an error since we attempted to add a duplicate
+ // address, or its reference count would have been increased without doing
+ // the work that would have been done for an address that was brand new.
+ // See NIC.addAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ remaining := ndp.configs.DupAddrDetectTransmits
+ if remaining == 0 {
+ ref.setKind(permanent)
+
+ // Consider DAD to have resolved even if no DAD messages were actually
+ // transmitted.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil)
+ }
+
+ return nil
+ }
+
+ var done bool
+ var timer *time.Timer
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the NIC's lock. This is effectively the same
+ // as starting a goroutine but we use a timer that fires immediately so we can
+ // reset it for the next DAD iteration.
+ timer = time.AfterFunc(0, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD timer fired after
+ // another goroutine already obtained the NIC lock and stopped DAD
+ // before this function obtained the NIC lock. Simply return here and do
+ // nothing further.
+ return
+ }
+
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ dadDone := remaining == 0
+
+ var err *tcpip.Error
+ if !dadDone {
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the NIC. Note, DAD would be
+ // stopped if the NIC was disabled or removed, or if the address was
+ // removed.
+ ndp.nic.mu.Unlock()
+ err = ndp.sendDADPacket(addr, ref)
+ ndp.nic.mu.Lock()
+ }
+
+ if done {
+ // If we reach this point, it means that DAD was stopped after we released
+ // the NIC's lock and before we re-acquired it.
+ return
+ }
+
+ if dadDone {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ return
+ }
+
+ // At this point we know that either DAD is done or we hit an error sending
+ // the last NDP NS. Either way, clean up addr's DAD state and let the
+ // integrator know DAD has completed.
+ delete(ndp.dad, addr)
+
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
+ }
+
+ // If DAD resolved for a stable SLAAC address, attempt generation of a
+ // temporary SLAAC address.
+ if dadDone && ref.configType == slaac {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
+ })
+
+ ndp.dad[addr] = dadState{
+ timer: timer,
+ done: &done,
+ }
+
+ return nil
+}
+
+// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
+// addr.
+//
+// addr must be a tentative IPv6 address on ndp's NIC.
+//
+// The NIC ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+ snmc := header.SolicitedNodeAddr(addr)
+
+ r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
+
+ // The route should resolve immediately since snmc is a multicast
+ // address, so a remote link address can be calculated without a
+ // resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
+ }
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(addr)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, &PacketBuffer{Header: hdr},
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ sent.NeighborSolicit.Increment()
+
+ return nil
+}
+
+// stopDuplicateAddressDetection ends a running Duplicate Address Detection
+// process. Note, this may leave the DAD process for a tentative address in
+// such a state forever, unless some other external event resolves the DAD
+// process (receiving an NA from the true owner of addr, or an NS for addr
+// (implying another node is attempting to use addr)). It is up to the caller
+// of this function to handle such a scenario. Normally, addr will be removed
+// from the NIC right after this function returns, or the address will have
+// been successfully resolved.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
+ dad, ok := ndp.dad[addr]
+ if !ok {
+ // Not currently performing DAD on addr, just return.
+ return
+ }
+
+ if dad.timer != nil {
+ dad.timer.Stop()
+ dad.timer = nil
+
+ *dad.done = true
+ dad.done = nil
+ }
+
+ delete(ndp.dad, addr)
+
+ // Let the integrator know DAD did not resolve.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
+}
+
+// handleRA handles a Router Advertisement message that arrived on the NIC
+// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ // Is the NIC configured to handle RAs at all?
+ //
+ // Currently, the stack does not determine router interface status on a
+ // per-interface basis; it is a stack-wide configuration, so we check
+ // the stack's forwarding flag to determine if the NIC is a routing
+ // interface.
+ if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ return
+ }
+
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ }
+ }
+
+ // Is the NIC configured to discover default routers?
+ if ndp.configs.DiscoverDefaultRouters {
+ rtr, ok := ndp.defaultRouters[ip]
+ rl := ra.RouterLifetime()
+ switch {
+ case !ok && rl != 0:
+ // This is a new default router we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredDefaultRouters routers.
+ if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
+ ndp.rememberDefaultRouter(ip, rl)
+ }
+
+ case ok && rl != 0:
+ // This is an already discovered default router. Update
+ // the invalidation timer.
+ rtr.invalidationTimer.StopLocked()
+ rtr.invalidationTimer.Reset(rl)
+ ndp.defaultRouters[ip] = rtr
+
+ case ok && rl == 0:
+ // We know about the router but it is no longer to be
+ // used as a default router so invalidate it.
+ ndp.invalidateDefaultRouter(ip)
+ }
+ }
+
+ // TODO(b/141556115): Do (RetransTimer, ReachableTime) Parameter
+ // Discovery.
+
+ // We know the options are valid as far as the wire format is concerned
+ // since we got the Router Advertisement, as documented by this function.
+ // Given this, we do not check the iterator for errors on calls to Next.
+ it, _ := ra.Options().Iter(false)
+ for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
+ switch opt := opt.(type) {
+ case header.NDPRecursiveDNSServer:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ addrs, _ := opt.Addresses()
+ ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+
+ case header.NDPDNSSearchList:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ domainNames, _ := opt.DomainNames()
+ ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
+
+ case header.NDPPrefixInformation:
+ prefix := opt.Subnet()
+
+ // Is the prefix a link-local?
+ if header.IsV6LinkLocalAddress(prefix.ID()) {
+ // ...Yes, skip as per RFC 4861 section 6.3.4,
+ // and RFC 4862 section 5.5.3.b (for SLAAC).
+ continue
+ }
+
+ // Is the Prefix Length 0?
+ if prefix.Prefix() == 0 {
+ // ...Yes, skip, as a zero-length prefix is
+ // invalid; it would make every IPv6 address
+ // on-link.
+ continue
+ }
+
+ if opt.OnLinkFlag() {
+ ndp.handleOnLinkPrefixInformation(opt)
+ }
+
+ if opt.AutonomousAddressConfigurationFlag() {
+ ndp.handleAutonomousPrefixInformation(opt)
+ }
+ }
+
+ // TODO(b/141556115): Do (MTU) Parameter Discovery.
+ }
+}
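+
+// exampleDHCPv6ConfigFromRAFlags is an illustrative sketch (not part of this
+// change and not used by the stack) of the RA flag mapping applied in
+// handleRA above: the Managed Address Configuration flag is checked before
+// the Other Configuration flag since, as per RFC 4861 section 4.2, the O flag
+// is redundant when the M flag is set.
+func exampleDHCPv6ConfigFromRAFlags(managed, other bool) DHCPv6ConfigurationFromNDPRA {
+ switch {
+ case managed:
+ return DHCPv6ManagedAddress
+ case other:
+ return DHCPv6OtherConfigurations
+ default:
+ return DHCPv6NoConfiguration
+ }
+}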
+
+// invalidateDefaultRouter invalidates a discovered default router.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
+ rtr, ok := ndp.defaultRouters[ip]
+
+ // Is the router still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ rtr.invalidationTimer.StopLocked()
+ delete(ndp.defaultRouters, ip)
+
+ // Let the integrator know a discovered default router is invalidated.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ }
+}
+
+// rememberDefaultRouter remembers a newly discovered default router with IPv6
+// link-local address ip with lifetime rl.
+//
+// The router identified by ip MUST NOT already be known by the NIC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered a default router.
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
+ // Informed by the integrator to not remember the router, do
+ // nothing further.
+ return
+ }
+
+ state := defaultRouterState{
+ invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateDefaultRouter(ip)
+ }),
+ }
+
+ state.invalidationTimer.Reset(rl)
+
+ ndp.defaultRouters[ip] = state
+}
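+
+// Note the dispatcher-veto pattern above (also used for on-link prefixes and
+// auto-generated addresses): a discovery is only remembered if the
+// integrator's On*Discovered callback returns true, letting the integrator
+// decline state it does not want the stack to keep.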
+
+// rememberOnLinkPrefix remembers a newly discovered on-link prefix prefix
+// with lifetime l.
+//
+// The prefix identified by prefix MUST NOT already be known.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered an on-link prefix.
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
+ // Informed by the integrator to not remember the prefix, do
+ // nothing further.
+ return
+ }
+
+ state := onLinkPrefixState{
+ invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
+ }
+
+ if l < header.NDPInfiniteLifetime {
+ state.invalidationTimer.Reset(l)
+ }
+
+ ndp.onLinkPrefixes[prefix] = state
+}
+
+// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
+ s, ok := ndp.onLinkPrefixes[prefix]
+
+ // Is the on-link prefix still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ s.invalidationTimer.StopLocked()
+ delete(ndp.onLinkPrefixes, prefix)
+
+ // Let the integrator know a discovered on-link prefix is invalidated.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ }
+}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering.
+ //
+ // Only remember it if we currently know about fewer than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ //
+ // Update the invalidation timer.
+
+ prefixState.invalidationTimer.StopLocked()
+
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // the new valid lifetime.
+ prefixState.invalidationTimer.Reset(vl)
+ }
+
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
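+
+// In summary, handleOnLinkPrefixInformation implements the RFC 4861 section
+// 6.3.4 transitions: an unknown prefix with a zero valid lifetime is ignored;
+// an unknown prefix with a non-zero lifetime is discovered; a known prefix
+// with a zero lifetime is invalidated; a known prefix with a non-zero
+// lifetime has its invalidation timer refreshed (or disarmed for an infinite
+// lifetime).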
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already maintain SLAAC state for prefix.
+ if state, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl)
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ ndp.doSLAAC(prefix, pl, vl)
+}
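+
+// For illustration: an RA advertising the (hypothetical) prefix 2001:db8::/64
+// with the autonomous flag set, a preferred lifetime of 1h and a valid
+// lifetime of 2h causes doSLAAC below to generate an address, arm a
+// deprecation timer for 1h and an invalidation timer for 2h; a later RA for
+// the same prefix refreshes those lifetimes instead, as per RFC 4862 section
+// 5.5.3.e.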
+
+// doSLAAC generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 section 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ state := slaacPrefixState{
+ deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
+ }
+
+ ndp.deprecateSLAACAddress(state.stableAddr.ref)
+ }),
+ invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }),
+ tempAddrs: make(map[tcpip.Address]tempSLAACAddrState),
+ maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
+ }
+
+ now := time.Now()
+
+ // The time until which the address is preferred is needed to properly
+ // generate the address.
+ if pl < header.NDPInfiniteLifetime {
+ state.preferredUntil = now.Add(pl)
+ }
+
+ if !ndp.generateSLAACAddr(prefix, &state) {
+ // We were unable to generate an address for the prefix, so do nothing
+ // further as there is no reason to maintain state or timers for a prefix we
+ // do not have an address for.
+ return
+ }
+
+ // Set up the initial timers to deprecate and invalidate prefix.
+
+ if pl < header.NDPInfiniteLifetime && pl != 0 {
+ state.deprecationTimer.Reset(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = now.Add(vl)
+ }
+
+ // If the address is assigned (DAD resolved), generate a temporary address.
+ if state.stableAddr.ref.getKind() == permanent {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// addSLAACAddr adds a SLAAC address to the NIC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint {
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return nil
+ }
+
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) {
+ // Informed by the integrator not to add the address.
+ return nil
+ }
+
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+
+ ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated)
+ if err != nil {
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err))
+ }
+
+ return ref
+}
+
+// generateSLAACAddr generates a SLAAC address for prefix.
+//
+// Returns true if an address was successfully generated.
+//
+// Panics if the prefix is not a SLAAC prefix or it already has an address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
+ if r := state.stableAddr.ref; r != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if state.generationAttempts == state.maxGenerationAttempts {
+ return false
+ }
+
+ var generatedAddr tcpip.AddressWithPrefix
+ addrBytes := []byte(prefix.ID())
+
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
+ if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ dadCounter,
+ oIID.SecretKey,
+ )
+ } else if dadCounter == 0 {
+ // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if
+ // the DAD counter is non-zero, we cannot use this method.
+ //
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return false
+ }
+
+ // Generate an address within prefix from the modified EUI-64 of ndp's
+ // NIC's Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ } else {
+ // We have no way to regenerate an address in response to an address
+ // conflict when addresses are not generated with opaque IIDs.
+ return false
+ }
+
+ generatedAddr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ break
+ }
+
+ state.stableAddr.localGenerationFailures++
+ }
+
+ if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil {
+ state.stableAddr.ref = ref
+ state.generationAttempts++
+ return true
+ }
+
+ return false
+}
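+
+// For illustration, without opaque IIDs the generated address is the /64
+// prefix with a modified EUI-64 interface identifier: the MAC is split in
+// half, 0xfffe is inserted in the middle, and the universal/local bit is
+// flipped. E.g. for the (hypothetical) MAC 02:02:03:04:05:06 and prefix
+// 2001:db8::/64, the address is 2001:db8::2:3ff:fe04:506/64.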
+
+// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
+//
+// If generating a new address for the prefix fails, the prefix will be
+// invalidated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix))
+ }
+
+ if ndp.generateSLAACAddr(prefix, &state) {
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // We were unable to generate a permanent address for the SLAAC prefix so
+ // invalidate the prefix as there is no reason to maintain state for a
+ // SLAAC prefix we do not have an address for.
+ ndp.invalidateSLAACPrefix(prefix, state)
+}
+
+// generateTempSLAACAddr generates a new temporary SLAAC address.
+//
+// If resetGenAttempts is true, the prefix's generation counter will be reset.
+//
+// Returns true if a new address was generated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
+ // Are we configured to auto-generate new temporary global addresses for the
+ // prefix?
+ if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() {
+ return false
+ }
+
+ if resetGenAttempts {
+ prefixState.generationAttempts = 0
+ prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if prefixState.generationAttempts == prefixState.maxGenerationAttempts {
+ return false
+ }
+
+ stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress
+ now := time.Now()
+
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime.
+ vl := ndp.configs.MaxTempAddrValidLifetime
+ if prefixState.validUntil != (time.Time{}) {
+ if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
+ vl = prefixVL
+ }
+ }
+
+ if vl <= 0 {
+ // Cannot create an address without a valid lifetime.
+ return false
+ }
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or the
+ // maximum temporary address preferred lifetime - the temporary address desync
+ // factor.
+ pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
+ if prefixState.preferredUntil != (time.Time{}) {
+ if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
+ // Respect the preferred lifetime of the prefix, as per RFC 4941 section
+ // 3.3 step 4.
+ pl = prefixPL
+ }
+ }
+
+ // As per RFC 4941 section 3.3 step 5, a temporary address is created only if
+ // the calculated preferred lifetime is greater than the advance regeneration
+ // duration. In particular, we MUST NOT create a temporary address with a zero
+ // Preferred Lifetime.
+ if pl <= ndp.configs.RegenAdvanceDuration {
+ return false
+ }
+
+ // Attempt to generate a new address that is not already assigned to the NIC.
+ var generatedAddr tcpip.AddressWithPrefix
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
+ if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ break
+ }
+ }
+
+ // As per RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
+ // address with a zero preferred lifetime. The checks above ensure this,
+ // so we know the address is not deprecated.
+ ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */)
+ if ref == nil {
+ return false
+ }
+
+ state := tempSLAACAddrState{
+ deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
+ }
+
+ ndp.deprecateSLAACAddress(tempAddrState.ref)
+ }),
+ invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr))
+ }
+
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
+ }),
+ regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr))
+ }
+
+ // If an address has already been regenerated for this address, don't
+ // regenerate another address.
+ if tempAddrState.regenerated {
+ return
+ }
+
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */)
+ prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
+ ndp.slaacPrefixes[prefix] = prefixState
+ }),
+ createdAt: now,
+ ref: ref,
+ }
+
+ state.deprecationTimer.Reset(pl)
+ state.invalidationTimer.Reset(vl)
+ state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+
+ prefixState.generationAttempts++
+ prefixState.tempAddrs[generatedAddr.Address] = state
+
+ return true
+}
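+
+// exampleTempAddrLifetimes is an illustrative sketch (not part of this change
+// and not used by the stack) of the RFC 4941 section 3.3 step 4 computation
+// performed in generateTempSLAACAddr above: each lifetime is the lower of the
+// prefix's remaining lifetime and the configured maximum, with the desync
+// factor subtracted from the maximum preferred lifetime.
+func exampleTempAddrLifetimes(maxVL, maxPL, desync, prefixVL, prefixPL time.Duration) (vl, pl time.Duration) {
+ vl = maxVL
+ if prefixVL < vl {
+ vl = prefixVL
+ }
+ pl = maxPL - desync
+ if prefixPL < pl {
+ pl = prefixPL
+ }
+ return vl, pl
+}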
+
+// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix))
+ }
+
+ ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts)
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
+ deprecated := pl == 0
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.ref)
+ } else {
+ prefixState.stableAddr.ref.deprecated = false
+ }
+
+ // If prefix was preferred for some finite lifetime before, stop the
+ // deprecation timer so it can be reset.
+ prefixState.deprecationTimer.StopLocked()
+
+ now := time.Now()
+
+ // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ if pl < header.NDPInfiniteLifetime {
+ if !deprecated {
+ prefixState.deprecationTimer.Reset(pl)
+ }
+ prefixState.preferredUntil = now.Add(pl)
+ } else {
+ prefixState.preferredUntil = time.Time{}
+ }
+
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the prefix to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Handle the infinite valid lifetime separately as we do not keep a timer
+ // in this case.
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.validUntil = time.Time{}
+ } else {
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the prefix was originally set to be valid forever, assume the
+ // remaining time to be the maximum possible value.
+ if prefixState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(prefixState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl > MinPrefixInformationValidLifetimeForUpdate {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if effectiveVl != 0 {
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
+ }
+ }
+
+ // If DAD is not yet complete on the stable address, there is no need to do
+ // work with temporary addresses.
+ if prefixState.stableAddr.ref.getKind() != permanent {
+ return
+ }
+
+ // Note, we do not need to update the entries in the temporary address map
+ // after updating the timers because the timers are held as pointers.
+ var regenForAddr tcpip.Address
+ allAddressesRegenerated := true
+ for tempAddr, tempAddrState := range prefixState.tempAddrs {
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime. Note, the valid lifetime of a
+ // temporary address is relative to the address's creation time.
+ validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
+ if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
+ validUntil = prefixState.validUntil
+ }
+
+ // If the address is no longer valid, invalidate it immediately. Otherwise,
+ // reset the invalidation timer.
+ newValidLifetime := validUntil.Sub(now)
+ if newValidLifetime <= 0 {
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
+ continue
+ }
+ tempAddrState.invalidationTimer.StopLocked()
+ tempAddrState.invalidationTimer.Reset(newValidLifetime)
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or
+ // the maximum temporary address preferred lifetime - the temporary address
+ // desync factor. Note, the preferred lifetime of a temporary address is
+ // relative to the address's creation time.
+ preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
+ if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
+ preferredUntil = prefixState.preferredUntil
+ }
+
+ // If the address is no longer preferred, deprecate it immediately.
+ // Otherwise, reset the deprecation timer.
+ newPreferredLifetime := preferredUntil.Sub(now)
+ tempAddrState.deprecationTimer.StopLocked()
+ if newPreferredLifetime <= 0 {
+ ndp.deprecateSLAACAddress(tempAddrState.ref)
+ } else {
+ tempAddrState.ref.deprecated = false
+ tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ }
+
+ tempAddrState.regenTimer.StopLocked()
+ if !tempAddrState.regenerated {
+ allAddressesRegenerated = false
+
+ if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration {
+ // The new preferred lifetime is less than the advance regeneration
+ // duration so regenerate an address for this temporary address
+ // immediately after we finish iterating over the temporary addresses.
+ regenForAddr = tempAddr
+ } else {
+ tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ }
+ }
+ }
+
+ // Generate a new temporary address if all of the existing temporary
+ // addresses have been regenerated, or if we need to immediately regenerate
+ // an address due to an update in preferred lifetime.
+ //
+ // If each temporary address has already been regenerated, no new temporary
+ // address would otherwise be generated. To ensure continuation of temporary
+ // SLAAC addresses, we manually try to regenerate an address here.
+ if len(regenForAddr) != 0 || allAddressesRegenerated {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok {
+ state.regenerated = true
+ prefixState.tempAddrs[regenForAddr] = state
+ }
+ }
+}
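+
+// exampleEffectiveValidLifetime is an illustrative sketch (not part of this
+// change and not used by the stack) of the RFC 4862 section 5.5.3.e rules
+// applied in refreshSLAACPrefixLifetimes above. A zero return value means the
+// advertised Valid Lifetime is ignored and the prefix's remaining lifetime is
+// kept as-is.
+func exampleEffectiveValidLifetime(vl, remaining time.Duration) time.Duration {
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > remaining {
+ // Rule 1: honor the advertised Valid Lifetime.
+ return vl
+ }
+ if remaining > MinPrefixInformationValidLifetimeForUpdate {
+ // Rule 3: reset the lifetime to the minimum.
+ return MinPrefixInformationValidLifetimeForUpdate
+ }
+ // Rule 2: ignore the advertised Valid Lifetime.
+ return 0
+}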
+
+// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
+// dispatcher that ref has been deprecated.
+//
+// deprecateSLAACAddress does nothing if ref is already deprecated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
+ if ref.deprecated {
+ return
+ }
+
+ ref.deprecated = true
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
+ }
+}
+
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
+ if r := state.stableAddr.ref; r != nil {
+ // Since we are already invalidating the prefix, do not invalidate the
+ // prefix when removing the address.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err))
+ }
+ }
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
+// resources.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress {
+ return
+ }
+
+ if !invalidatePrefix {
+ // If the prefix is not being invalidated, disassociate the address from the
+ // prefix and do nothing further.
+ state.stableAddr.ref = nil
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's timers and entry.
+//
+// Panics if the SLAAC prefix is not known.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ // Invalidate all temporary addresses.
+ for tempAddr, tempAddrState := range state.tempAddrs {
+ ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
+ }
+
+ state.stableAddr.ref = nil
+ state.deprecationTimer.StopLocked()
+ state.invalidationTimer.StopLocked()
+ delete(ndp.slaacPrefixes, prefix)
+}
+
+// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ // Since we are already invalidating the address, do not invalidate it
+ // again when removing it.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
+// SLAAC address's resources from ndp.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ }
+
+ if !invalidateAddr {
+ return
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr))
+ }
+
+ tempAddrState, ok := state.tempAddrs[addr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResources cleans up a temporary SLAAC address's timers
+// and entry.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.invalidationTimer.StopLocked()
+ tempAddrState.regenTimer.StopLocked()
+ delete(tempAddrs, tempAddr)
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
+ linkLocalPrefixes := 0
+ for prefix, state := range ndp.slaacPrefixes {
+ // RFC 4862 section 5 states that routers are also expected to generate a
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
+ continue
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
+
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
+ }
+
+ for prefix := range ndp.onLinkPrefixes {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }
+
+ if got := len(ndp.onLinkPrefixes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
+ }
+
+ for router := range ndp.defaultRouters {
+ ndp.invalidateDefaultRouter(router)
+ }
+
+ if got := len(ndp.defaultRouters); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ }
+
+ ndp.dhcpv6Configuration = 0
+}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicit.timer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the NIC lock and stopped solicitations.
+ // Simply return here and do nothing further.
+ ndp.nic.mu.Unlock()
+ return
+ }
+
+ // As per RFC 4861 section 4.1, the source of the RS is an address assigned
+ // to the sending interface, or the unspecified address if no address is
+ // assigned to the sending interface.
+ ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
+ if ref == nil {
+ ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ }
+ ndp.nic.mu.Unlock()
+
+ localAddr := ref.ep.ID().LocalAddress
+ r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
+
+ // The route should resolve immediately since
+ // header.IPv6AllRoutersMulticastAddress is a multicast address, so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
+ }
+
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
+ pkt := header.ICMPv6(hdr.Prepend(payloadSize))
+ pkt.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(pkt.NDPPayload())
+ rs.Options().Serialize(optsSerializer)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, &PacketBuffer{Header: hdr},
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.nic.mu.Lock()
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
+ // Note, we explicitly check that the timer field is non-nil:
+ // if it is nil at this point, the NIC was requested to stop
+ // soliciting routers, so we must not schedule the next Router
+ // Solicitation message.
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ ndp.nic.mu.Unlock()
+ })
+}
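+
+// For illustration, with the RFC 4861 defaults (MaxRtrSolicitations = 3,
+// RtrSolicitationInterval = 4s, MaxRtrSolicitationDelay = 1s), solicitations
+// are sent at t = d, d + 4s and d + 8s for some random d in [0s, 1s), after
+// which the timer and done fields are cleared.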
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicit.timer == nil {
+ // Nothing to do.
+ return
+ }
+
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+}
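+
+// The shared done flag handles the race where the solicitation timer has
+// already fired and its goroutine is blocked acquiring the NIC lock:
+// stopSolicitingRouters sets *done while holding the lock, so the timer
+// goroutine observes it and returns without sending another solicitation.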
+
+// initializeTempAddrState initializes state related to temporary SLAAC
+// addresses.
+func (ndp *ndpState) initializeTempAddrState() {
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID())
+
+ if MaxDesyncFactor != 0 {
+ ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
new file mode 100644
index 000000000..6f86abc98
--- /dev/null
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -0,0 +1,5363 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "context"
+ "encoding/binary"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 10 * time.Second
+
+ // Extra time to use when waiting for an async event to not occur.
+ //
+ // Since a negative check is used to make sure an event did not happen, it is
+ // okay to use a smaller timeout compared to the positive case since an
+ // execution stall with regard to the monotonic clock will not affect the
+ // expected outcome.
+ defaultAsyncNegativeEventTimeout = time.Second
+)
+
+var (
+ llAddr1 = header.LinkLocalAddr(linkAddr1)
+ llAddr2 = header.LinkLocalAddr(linkAddr2)
+ llAddr3 = header.LinkLocalAddr(linkAddr3)
+ llAddr4 = header.LinkLocalAddr(linkAddr4)
+ dstAddr = tcpip.FullAddress{
+ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ Port: 25,
+ }
+)
+
+func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addrBytes := []byte(subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+}
+
+// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
+// tcpip.Subnet, and an address where the lower half of the address is composed
+// of the EUI-64 of linkAddr if it is a valid unicast ethernet address.
+func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
+ prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixBytes),
+ PrefixLen: 64,
+ }
+
+ subnet := prefix.Subnet()
+
+ return prefix, subnet, addrForSubnet(subnet, linkAddr)
+}
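+
+// For illustration, prefixSubnetAddr(0, linkAddr1) returns the prefix
+// 102:304:506:708::/64, its subnet, and the address
+// 102:304:506:708:2:3ff:fe04:506/64, whose IID is the modified EUI-64 of
+// linkAddr1 (02:02:03:04:05:06).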
+
+// ndpDADEvent is a set of parameters that was passed to
+// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+type ndpDADEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ resolved bool
+ err *tcpip.Error
+}
+
+type ndpRouterEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ // true if router was discovered, false if invalidated.
+ discovered bool
+}
+
+type ndpPrefixEvent struct {
+ nicID tcpip.NICID
+ prefix tcpip.Subnet
+ // true if prefix was discovered, false if invalidated.
+ discovered bool
+}
+
+type ndpAutoGenAddrEventType int
+
+const (
+ newAddr ndpAutoGenAddrEventType = iota
+ deprecatedAddr
+ invalidatedAddr
+)
+
+type ndpAutoGenAddrEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.AddressWithPrefix
+ eventType ndpAutoGenAddrEventType
+}
+
+type ndpRDNSS struct {
+ addrs []tcpip.Address
+ lifetime time.Duration
+}
+
+type ndpRDNSSEvent struct {
+ nicID tcpip.NICID
+ rdnss ndpRDNSS
+}
+
+type ndpDNSSLEvent struct {
+ nicID tcpip.NICID
+ domainNames []string
+ lifetime time.Duration
+}
+
+type ndpDHCPv6Event struct {
+ nicID tcpip.NICID
+ configuration stack.DHCPv6ConfigurationFromNDPRA
+}
+
+var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+
+// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
+// related events happen for test purposes.
+type ndpDispatcher struct {
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ prefixC chan ndpPrefixEvent
+ rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
+ dnsslC chan ndpDNSSLEvent
+ routeTable []tcpip.Route
+ dhcpv6ConfigurationC chan ndpDHCPv6Event
+}
+
+// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ if n.dadC != nil {
+ n.dadC <- ndpDADEvent{
+ nicID,
+ addr,
+ resolved,
+ err,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
+ nicID,
+ addr,
+ true,
+ }
+ }
+
+ return n.rememberRouter
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
+ nicID,
+ addr,
+ false,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
+ nicID,
+ prefix,
+ true,
+ }
+ }
+
+ return n.rememberPrefix
+}
+
+// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
+func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
+ nicID,
+ prefix,
+ false,
+ }
+ }
+}
+
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ newAddr,
+ }
+ }
+ return true
+}
+
+func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ deprecatedAddr,
+ }
+ }
+}
+
+func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ invalidatedAddr,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
+ if c := n.rdnssC; c != nil {
+ c <- ndpRDNSSEvent{
+ nicID,
+ ndpRDNSS{
+ addrs,
+ lifetime,
+ },
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
+ if n.dnsslC != nil {
+ n.dnsslC <- ndpDNSSLEvent{
+ nicID,
+ domainNames,
+ lifetime,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ if c := n.dhcpv6ConfigurationC; c != nil {
+ c <- ndpDHCPv6Event{
+ nicID,
+ configuration,
+ }
+ }
+}
+
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
+// checkDADEvent checks that e is a DAD event for addr on the NIC with ID
+// nicID, with the resolved flag and error set to the specified values. It
+// returns an empty string if e matches.
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
+ return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
+}
+
+// TestDADDisabled tests that an address successfully resolves immediately
+// when DAD is not enabled (the default for an empty stack.Options).
+func TestDADDisabled(t *testing.T) {
+ const nicID = 1
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Should get the address immediately since we should not have performed
+ // DAD on it.
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ }
+
+ // We should not have sent any NDP NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ t.Fatalf("got NeighborSolicit = %d, want = 0", got)
+ }
+}
+
+// TestDADResolve tests that an address successfully resolves after performing
+// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
+// Included in the subtests is a test to make sure that invalid
+// RetransmitTimer values (<1ms) get fixed to the default RetransmitTimer of
+// 1s. This test also validates the NDP NS packets that are transmitted.
+func TestDADResolve(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ linkHeaderLen uint16
+ dupAddrDetectTransmits uint8
+ retransTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
+ // 0s is an invalid RetransmitTimer timer and will be fixed to
+ // the default RetransmitTimer value of 1s.
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = test.retransTimer
+ opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
+
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // We add a default route so the call to FindRoute below will succeed
+ // once we have an assigned address.
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr3,
+ NIC: nicID,
+ }})
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Make sure the address does not resolve before the resolution time has
+ // passed.
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+ // Should not get a route even if we specify the local address as the
+ // tentative address.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
+ }
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if addr.Address != addr1 {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ }
+ // Should get a route using the address now that it is resolved.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should not have sent any more NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
+ }
+
+ // Validate the sent Neighbor Solicitation messages.
+ for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
+ p, _ := e.ReadContext(context.Background())
+
+ // Make sure it's an IPv6 packet.
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ // Make sure the right remote link address is used.
+ snmc := header.SolicitedNodeAddr(addr1)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ // Check NDP NS packet.
+ //
+ // As per RFC 4861 section 4.3, a possible option is the Source Link
+ // Layer option, but this option MUST NOT be included when the source
+ // address of the packet is the unspecified address.
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr1),
+ checker.NDPNSOptions(nil),
+ ))
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
+ }
+ })
+ }
+}
+
+// TestDADFail tests to make sure that the DAD process fails if another node is
+// detected to be performing DAD on the same address (receive an NS message from
+// a node doing DAD for the same address), or if another node is detected to own
+// the address already (receive an NA message for the tentative address).
+func TestDADFail(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ makeBuf func(tgt tcpip.Address) buffer.Prependable
+ getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ "RxSolicit",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+
+ return hdr
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborSolicit
+ },
+ },
+ {
+ "RxAdvert",
+ func(tgt tcpip.Address) buffer.Prependable {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(tgt)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return hdr
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborAdvert
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = time.Second * 2
+
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Receive a packet to simulate multiple nodes owning or
+ // attempting to own the same address.
+ hdr := test.makeBuf(addr1)
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got stat = %d, want = 1", got)
+ }
+
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Attempting to add the address again should not fail if the address's
+ // state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+ })
+ }
+}
+
+func TestDADStop(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ skipFinalAddrCheck bool
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
+
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
+
+ // Tests to make sure that DAD stops when the NIC is removed.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ // The NIC is removed so we can't check its addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
+
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ if !test.skipFinalAddrCheck {
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+ }
+
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
+ }
+}
+
+// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NDP configurations using an invalid NICID.
+func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestSetNDPConfigurations tests that we can update and use per-interface NDP
+// configurations without affecting the default NDP configurations or other
+// interfaces' configurations.
+func TestSetNDPConfigurations(t *testing.T) {
+ const nicID1 = 1
+ const nicID2 = 2
+ const nicID3 = 3
+
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransmitTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {
+ "OK",
+ 1,
+ time.Second,
+ time.Second,
+ },
+ {
+ "Invalid Retransmit Timer",
+ 1,
+ 0,
+ time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ })
+
+ expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatalf("expected DAD event for %s", addr)
+ }
+ }
+
+ // NIC(1)'s NDP configurations will be updated to
+ // differ from the default.
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+
+ // Created before updating NIC(1)'s NDP configurations
+ // but updating NIC(1)'s NDP configurations should not
+ // affect other existing NICs.
+ if err := s.CreateNIC(nicID2, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+
+ // Update the NDP configurations on NIC(1) to use DAD.
+ configs := stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ }
+ if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
+ }
+
+ // Created after updating NIC(1)'s NDP configurations
+ // but the stack's default NDP configurations should not
+ // have been updated.
+ if err := s.CreateNIC(nicID3, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err)
+ }
+
+ // Add addresses for each NIC.
+ if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
+ }
+ if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
+ }
+ expectDADEvent(nicID2, addr2)
+ if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
+ }
+ expectDADEvent(nicID3, addr3)
+
+ // Address should not be considered bound to NIC(1) yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Should get the addresses on NIC(2) and NIC(3)
+ // immediately since we should not have performed DAD on
+ // them as the stack was configured to not do DAD by
+ // default and we only updated the NDP configurations on
+ // NIC(1).
+ addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr2 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
+ }
+ addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr3 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
+ }
+
+ // Sleep until just before (500ms shy of) the expected
+ // resolution time to make sure the address didn't resolve
+ // on NIC(1) yet.
+ const delta = 500 * time.Millisecond
+ time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
+ }
+ })
+ }
+}
+
+// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
+// and DHCPv6 configurations specified.
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(0)
+ raPayload := pkt.NDPPayload()
+ ra := header.NDPRouterAdvert(raPayload)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(raPayload[2:], rl)
+ // Populate the Managed Address flag field.
+ if managedAddress {
+ // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
+ // of the RA payload.
+ raPayload[1] |= (1 << 7)
+ }
+ // Populate the Other Configurations flag field.
+ if otherConfigurations {
+ // The Other Configurations flag field is the 6th bit of byte #1
+ // (0-indexing) of the RA payload.
+ raPayload[1] |= (1 << 6)
+ }
+ opts := ra.Options()
+ opts.Serialize(optSer)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ iph.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+}
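+
+// For reference, the RA body written above follows RFC 4861 section 4.2:
+//
+//	raPayload[0]   Cur Hop Limit
+//	raPayload[1]   M flag (bit 7) | O flag (bit 6) | reserved
+//	raPayload[2:4] Router Lifetime in seconds (big endian)
+//
+// so setting both DHCPv6-related flags is equivalent to raPayload[1] |= 0xc0.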
+
+// raBufWithOpts returns a valid NDP Router Advertisement with options.
+//
+// Note, raBufWithOpts does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+}
+
+// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
+// fields set.
+//
+// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
+// DHCPv6 related ones.
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfigurations, header.NDPOptionsSerializer{})
+}
+
+// raBuf returns a valid NDP Router Advertisement.
+//
+// Note, raBuf does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
+}
+
+// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
+// Information option.
+//
+// Note, raBufWithPI does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer {
+ flags := uint8(0)
+ if onLink {
+ // The OnLink flag is the 7th bit in the flags byte.
+ flags |= 1 << 7
+ }
+ if auto {
+ // The Address Auto-Configuration flag is the 6th bit in the
+ // flags byte.
+ flags |= 1 << 6
+ }
+
+ // A valid header.NDPPrefixInformation must be 30 bytes.
+ buf := [30]byte{}
+ // The first byte in a header.NDPPrefixInformation is the Prefix Length
+ // field.
+ buf[0] = uint8(prefix.PrefixLen)
+ // The 2nd byte within a header.NDPPrefixInformation is the Flags field.
+ buf[1] = flags
+ // The Valid Lifetime field starts after the 2nd byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[2:], vl)
+ // The Preferred Lifetime field starts after the 6th byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[6:], pl)
+ // The Prefix Address field starts after the 14th byte within a
+ // header.NDPPrefixInformation.
+ copy(buf[14:], prefix.Address)
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
+ header.NDPPrefixInformation(buf[:]),
+ })
+}
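+
+// For illustration only (an assumption about typical use, not part of any
+// test here), these helpers compose with a channel endpoint e as:
+//
+//	// RA from llAddr2 with a zero router lifetime and a prefix that is
+//	// on-link and autonomous, valid for 100s and preferred for 50s.
+//	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true /* onLink */, true /* auto */, 100 /* vl */, 50 /* pl */))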
+
+// TestNoRouterDiscovery tests that router discovery will not be performed if
+// configured not to.
+func TestNoRouterDiscovery(t *testing.T) {
+ // Being configured to discover routers means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // router discovery) - that will be done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkRouterEvent checks that e is an event for addr on the NIC with ID 1,
+// with the discovered flag set to discovered.
+func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
+ return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered router when the dispatcher asks it not to.
+func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA for a router we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the router in the first place.
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have received any router events")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+}
+
+func TestRouterDiscovery(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+ }
+
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for router discovery event")
+ }
+ }
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ default:
+ }
+
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
+ expectRouterEvent(llAddr3, true)
+
+ // Rx an RA from lladdr2 with a lesser lifetime.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ default:
+ }
+
+ // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ expectRouterEvent(llAddr2, false)
+
+ // Wait for lladdr3's router invalidation timer to fire.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+}
+
+// TestRouterDiscoveryMaxRouters tests that only
+// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive RAs from 2 more routers than the max number of discovered
+ // default routers.
+ for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ linkAddr := []byte{2, 2, 3, 4, 5, 0}
+ linkAddr[5] = byte(i)
+ llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+
+ if i <= stack.MaxDiscoveredDefaultRouters {
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ } else {
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
+ default:
+ }
+ }
+ }
+}
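+
+// For reference: header.LinkLocalAddr derives an fe80::/64 address from a
+// link address using the modified EUI-64 transform (RFC 4291 appendix A).
+// For example, the MAC 02:02:03:04:05:01 used above yields the interface
+// identifier 00:02:03:ff:fe:04:05:01 (ff:fe inserted in the middle and the
+// universal/local bit flipped).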
+
+// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
+// configured not to.
+func TestNoPrefixDiscovery(t *testing.T) {
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+
+ // Being configured to discover prefixes means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // prefix discovery) - that will be done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
+
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkPrefixEvent checks that e is an event for prefix on the NIC with ID 1,
+// with the discovered flag set to discovered.
+func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
+ return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered on-link prefix when the dispatcher asks it not to.
+func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
+ prefix, subnet, _ := prefixSubnetAddr(0, "")
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix that we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the prefix in the first place.
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("should not have received any prefix events")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+}
+
+func TestPrefixDiscovery(t *testing.T) {
+ prefix1, subnet1, _ := prefixSubnetAddr(0, "")
+ prefix2, subnet2, _ := prefixSubnetAddr(1, "")
+ prefix3, subnet3, _ := prefixSubnetAddr(2, "")
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
+ expectPrefixEvent(subnet1, true)
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
+ expectPrefixEvent(subnet2, true)
+
+ // Receive an RA with prefix3 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
+ expectPrefixEvent(subnet3, true)
+
+ // Receive an RA with prefix1 in a PI with lifetime = 0.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ expectPrefixEvent(subnet1, false)
+
+ // Receive an RA with prefix2 in a PI with lesser lifetime.
+ lifetime := uint32(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly received prefix event when updating lifetime")
+ default:
+ }
+
+ // Wait for prefix2's most recent invalidation timer to fire, plus some
+ // buffer.
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive RA to invalidate prefix3.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
+ expectPrefixEvent(subnet3, false)
+}
+
+func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
+ // Update the infinite lifetime value to a smaller value so we can test
+ // that when we receive a PI with such a lifetime value, we do not
+ // invalidate the prefix.
+ const testInfiniteLifetimeSeconds = 2
+ const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
+ saved := header.NDPInfiniteLifetime
+ header.NDPInfiniteLifetime = testInfiniteLifetime
+ defer func() {
+ header.NDPInfiniteLifetime = saved
+ }()
+
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+ subnet := prefix.Subnet()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ }
+
+ // Receive an RA with prefix in an NDP Prefix Information option (PI)
+ // with infinite valid lifetime which should not get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ expectPrefixEvent(subnet, true)
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Receive an RA with finite lifetime.
+ // The prefix should get invalidated after 1s.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(testInfiniteLifetime):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive an RA with finite lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ expectPrefixEvent(subnet, true)
+
+ // Receive an RA with prefix with an infinite lifetime.
+ // The prefix should not be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Receive an RA with a prefix with a lifetime value greater than the
+ // set infinite lifetime value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Receive an RA with 0 lifetime.
+ // The prefix should get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
+ expectPrefixEvent(subnet, false)
+}
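+
+// For reference: per RFC 4861 section 4.6.2, a Valid Lifetime of all one bits
+// (0xffffffff seconds, roughly 136 years) denotes infinity. The test above
+// shrinks header.NDPInfiniteLifetime so this behavior can be exercised
+// without waiting anywhere near that long.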
+
+// TestPrefixDiscoveryMaxOnLinkPrefixes tests that only
+// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are
+// remembered.
+func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+
+ // Receive an RA carrying 2 more prefixes than the max number of
+ // discovered on-link prefixes.
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefixAddr[7] = byte(i)
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixAddr[:]),
+ PrefixLen: 64,
+ }
+ prefixes[i] = prefix.Subnet()
+ buf := [30]byte{}
+ buf[0] = uint8(prefix.PrefixLen)
+ buf[1] = 128
+ binary.BigEndian.PutUint32(buf[2:], 10)
+ copy(buf[14:], prefix.Address)
+
+ optSer[i] = header.NDPPrefixInformation(buf[:])
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ if i < stack.MaxDiscoveredOnLinkPrefixes {
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ } else {
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes")
+ default:
+ }
+ }
+ }
+}
+
+// containsV6Addr returns true if list contains the IPv6 protocol address
+// item.
+func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: item,
+ }
+
+ for _, i := range list {
+ if i == protocolAddress {
+ return true
+ }
+ }
+
+ return false
+}
+
+// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
+func TestNoAutoGenAddr(t *testing.T) {
+ prefix, _, _ := prefixSubnetAddr(0, "")
+
+ // Being configured to auto-generate addresses means handle and
+ // autogen are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, autogen =
+ // true and forwarding = false (the required configuration to do
+ // SLAAC) - that will be done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ autogen := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
+
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkAutoGenAddrEvent checks that e is an event for addr on the NIC with
+// ID 1, with the event type set to eventType.
+func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
+ return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
+}
+
+// TestAutoGenAddr tests that an address is properly generated and invalidated
+// when configured to do so.
+func TestAutoGenAddr(t *testing.T) {
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+}
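+
+// For reference (an assumption about the prefixSubnetAddr helper): the SLAAC
+// addresses checked above are formed per RFC 4862 by appending an interface
+// identifier derived from linkAddr1 to the advertised /64 prefix.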
+
+func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
+ ret := ""
+ for _, c := range containList {
+ if !containsV6Addr(addrs, c) {
+ ret += fmt.Sprintf("should have %s in the list of addresses\n", c)
+ }
+ }
+ for _, c := range notContainList {
+ if containsV6Addr(addrs, c) {
+ ret += fmt.Sprintf("should not have %s in the list of addresses\n", c)
+ }
+ }
+ return ret
+}
+
+// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
+// configured to do so as part of IPv6 Privacy Extensions.
+func TestAutoGenTempAddr(t *testing.T) {
+ const (
+ nicID = 1
+ newMinVL = 5
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate
+ savedMaxDesync := stack.MaxDesyncFactor
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
+ stack.MaxDesyncFactor = savedMaxDesync
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ tests := []struct {
+ name string
+ dupAddrTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD disabled",
+ },
+ {
+ name: "DAD enabled",
+ dupAddrTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for i, test := range tests {
+ i := i
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ seed := []byte{uint8(i)}
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], seed, nicID)
+ newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
+ }
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 2),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectDADEventAsync(addr1.Address)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ tempAddr1 := newTempAddr(addr1.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ expectDADEventAsync(tempAddr1.Address)
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred
+ // lifetimes.
+ tempAddr2 := newTempAddr(addr2.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ expectDADEventAsync(addr2.Address)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr)
+ expectDADEventAsync(tempAddr2.Address)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Deprecate prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Refresh lifetimes for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Reduce valid lifetime and deprecate addresses of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for addrs of prefix1 to be invalidated. They should be
+ // invalidated at the same time.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ var nextAddr tcpip.AddressWithPrefix
+ if e.addr == addr1 {
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = tempAddr1
+ } else {
+ if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = addr1
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ })
+ }
+ })
+}
+
+// TestNoAutoGenTempAddrForLinkLocal tests that temporary SLAAC addresses are
+// not generated for auto-generated link-local addresses.
+func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
+ const nicID = 1
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ }()
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ tests := []struct {
+ name string
+ dupAddrTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD disabled",
+ },
+ {
+ name: "DAD enabled",
+ dupAddrTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenIPv6LinkLocal: true,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // The stable link-local address should be auto-generated and its DAD
+ // should resolve.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // No new addresses should be generated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unxpected auto gen addr event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ })
+ }
+ })
+}
+
+// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
+// will not be generated until after DAD completes, even if a new Router
+// Advertisement is received to refresh lifetimes.
+func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
+ const (
+ nicID = 1
+ dadTransmits = 1
+ retransmitTimer = 2 * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ }()
+ stack.MaxDesyncFactor = 0
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Receive an RA to trigger SLAAC for prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // DAD on the stable address for prefix has not yet completed. Receiving a new
+ // RA that would refresh lifetimes should not generate a temporary SLAAC
+ // address for the prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ default:
+ }
+
+ // Wait for DAD to complete for the stable address then expect the temporary
+ // address to be generated.
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+}
+
+// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are
+// regenerated.
+func TestAutoGenTempAddrRegen(t *testing.T) {
+ const (
+ nicID = 1
+ regenAfter = 2 * time.Second
+ newMinVL = 10
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ }()
+ stack.MaxDesyncFactor = 0
+ stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
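+ // Each call to GenerateTempIPv6SLAACAddr advances tempIIDHistory (it is
+ // assumed to follow the RFC 4941 s3.2.1 scheme: hash the history value
+ // with the stable IID, take half the digest as the temporary IID, and fold
+ // the rest back into the history buffer), so tempAddr1..3 predict the
+ // exact sequence of addresses the stack should regenerate below.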
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr, newAddr)
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for regeneration
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for regeneration
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Stop generating temporary addresses
+ ndpConfigs.AutoGenTempGlobalAddresses = false
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+
+ // Wait for all the temporary addresses to get invalidated.
+ tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3}
+ invalidateAfter := newMinVLDuration - 2*regenAfter
+ for _, addr := range tempAddrs {
+ // Wait for a deprecation then invalidation event, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation timers could fire in any
+ // order.
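+ // Both orders are possible here because MinMaxTempAddrPreferredLifetime
+ // and MinMaxTempAddrValidLifetime were clamped to the same value above, so
+ // the deprecation and invalidation timers are armed for the same instant.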
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we shouldn't get a deprecation
+ // event after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event = %+v", e)
+ }
+ case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+
+ invalidateAfter = regenAfter
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+}
+
+// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
+// regeneration timer gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+ const (
+ nicID = 1
+ regenAfter = 2 * time.Second
+ newMinVL = 10
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ }()
+ stack.MaxDesyncFactor = 0
+ stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr, newAddr)
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Deprecate the prefix.
+ //
+ // A new temporary address should be generated after the regeneration
+ // time has passed since the prefix is deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Prefer the prefix again.
+ //
+ // A new temporary address should immediately be generated since the
+ // regeneration time has already passed since the last address was generated
+ // - this regeneration does not depend on a timer.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr2, newAddr)
+
+ // Increase the maximum lifetimes for temporary addresses to large values
+ // then refresh the lifetimes of the prefix.
+ //
+ // A new address should not be generated within the regeneration time that
+ // applied to the previous check. This is because the preferred lifetime for
+ // the temporary addresses has increased, so it will take more time to
+ // regenerate a new temporary address. Note, new addresses are only
+ // regenerated once the preferred lifetime minus the regen advance duration
+ // has passed.
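+ // As a sketch (with MaxDesyncFactor zeroed above, and assuming the
+ // effective preferred lifetime is the advertised value capped at
+ // MaxTempAddrPreferredLifetime), regeneration fires at roughly
+ //
+ //   regenAt = generatedAt + preferredLifetime - RegenAdvanceDuration
+ //
+ // so raising the cap from newMinVLDuration (10s) to 100s moves regenAt
+ // from regenAfter (2s) out to ~92s after generation, well past the window
+ // checked below.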
+ ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
+ ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Set the maximum lifetimes for temporary addresses such that on the next
+ // RA, the regeneration timer gets reset.
+ //
+ // The maximum lifetime is the sum of the minimum lifetimes for temporary
+ // addresses + the time that has already passed since the last address was
+ // generated so that the regeneration timer is needed to generate the next
+ // address.
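+ // Concretely (a sketch), the next regeneration lands at
+ //
+ //   newLifetimes - RegenAdvanceDuration
+ //     = (newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout) -
+ //       (newMinVLDuration - regenAfter)
+ //     = 2*regenAfter + defaultAsyncNegativeEventTimeout
+ //
+ // after tempAddr2 was generated. Since roughly regenAfter +
+ // defaultAsyncNegativeEventTimeout of that has already elapsed in the
+ // previous check, the timer should fire within about regenAfter from now.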
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
+ ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
+ ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
+}
+
+// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
+// to a mix of DAD conflicts and NIC-local conflicts.
+func TestMixedSLAACAddrConflictRegen(t *testing.T) {
+ const (
+ nicID = 1
+ nicName = "nic"
+ lifetimeSeconds = 9999
+ // From stack.maxSLAACAddrLocalRegenAttempts
+ maxSLAACAddrLocalRegenAttempts = 10
+ // We use 2 more addresses than the maximum local regeneration attempts
+ // because we want to also trigger regeneration in response to a DAD
+ // conflict in this test.
+ maxAddrs = maxSLAACAddrLocalRegenAttempts + 2
+ dupAddrTransmits = 1
+ retransmitTimer = time.Second
+ )
+
+ var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID)
+
+ var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID)
+
+ prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1)
+ var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix
+ var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix
+ var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix
+ addrBytes := []byte(subnet.ID())
+ for i := 0; i < maxAddrs; i++ {
+ stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)),
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ // When generating temporary addresses, the resolved stable address for the
+ // SLAAC prefix will be the first stable address generated for the prefix,
+ // as we do not simulate address conflicts for the stable addresses in
+ // tests involving temporary addresses. Address conflicts for stable
+ // addresses are covered by their own tests.
+ tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address)
+ tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address)
+ }
+
+ tests := []struct {
+ name string
+ addrs []tcpip.AddressWithPrefix
+ tempAddrs bool
+ initialExpect tcpip.AddressWithPrefix
+ nicNameFromID func(tcpip.NICID, string) string
+ }{
+ {
+ name: "Stable addresses with opaque IIDs",
+ addrs: stableAddrsWithOpaqueIID[:],
+ nicNameFromID: func(tcpip.NICID, string) string {
+ return nicName
+ },
+ },
+ {
+ name: "Temporary addresses with opaque IIDs",
+ addrs: tempAddrsWithOpaqueIID[:],
+ tempAddrs: true,
+ initialExpect: stableAddrsWithOpaqueIID[0],
+ nicNameFromID: func(tcpip.NICID, string) string {
+ return nicName
+ },
+ },
+ {
+ name: "Temporary addresses with modified EUI64",
+ addrs: tempAddrsWithModifiedEUI64[:],
+ tempAddrs: true,
+ initialExpect: stableAddrWithModifiedEUI64,
+ },
+ }
+
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: test.tempAddrs,
+ AutoGenAddressConflictRetries: 1,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: test.nicNameFromID,
+ },
+ })
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr2,
+ NIC: nicID,
+ }})
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ for j := 0; j < len(test.addrs)-1; j++ {
+ // The NIC will not attempt to generate an address in response to a
+ // NIC-local conflict after some maximum number of attempts. We skip
+ // creating a conflict for the address that would be generated as part
+ // of the last attempt so we can simulate a DAD conflict for this
+ // address and restart the NIC-local generation process.
+ if j == maxSLAACAddrLocalRegenAttempts-1 {
+ continue
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ }
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ // Enable DAD.
+ ndpDisp.dadC = make(chan ndpDADEvent, 2)
+ ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits
+ ndpConfigs.RetransmitTimer = retransmitTimer
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+
+ // Do SLAAC for prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ if test.initialExpect != (tcpip.AddressWithPrefix{}) {
+ expectAutoGenAddrEvent(test.initialExpect, newAddr)
+ expectDADEventAsync(test.initialExpect.Address)
+ }
+
+ // The last local generation attempt should succeed, but we introduce a
+ // DAD failure to restart the local generation process.
+ addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1]
+ expectAutoGenAddrAsyncEvent(addr, newAddr)
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+
+ // The last address generated should resolve DAD.
+ addr = test.addrs[len(test.addrs)-1]
+ expectAutoGenAddrAsyncEvent(addr, newAddr)
+ expectDADEventAsync(addr.Address)
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ })
+ }
+}
+
+// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
+// channel.Endpoint and stack.Stack.
+//
+// stack.Stack will have a default route through the router (llAddr3) installed
+// and a static link-address (linkAddr3) added to the link address cache for the
+// router.
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+ t.Helper()
+ ndpDisp := &ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ return ndpDisp, e, s
+}
+
+// addrForNewConnectionTo returns the local address used when creating a new
+// connection to addr.
+func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Connect(addr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", addr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
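+
+// Note: the addrForNewConnection* helpers exercise source address selection.
+// The deprecation tests below rely on preferred (non-deprecated) addresses
+// being chosen over deprecated ones, per RFC 6724 rule 3 (assumed to be the
+// rule the stack follows when picking a source address).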
+
+// addrForNewConnection returns the local address used when creating a new
+// connection.
+func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+ t.Helper()
+
+ return addrForNewConnectionTo(t, s, dstAddr)
+}
+
+// addrForNewConnectionWithAddr returns the local address used when creating a
+// new connection with a specific local address.
+func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Bind(addr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", addr, err)
+ }
+ if err := ep.Connect(dstAddr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
+// receiving a PI with 0 preferred lifetime.
+func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Deprecate addr for prefix1 immediately.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Deprecate addr for prefix2 immediately.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
+
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
+
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+}
+
+// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// when its preferred lifetime expires.
+func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+ const nicID = 1
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
+
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
+
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+
+ // Wait for a deprecation then invalidation event, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // because the deprecation and invalidation handlers may run in either
+ // order: deprecation then invalidation, or invalidation then deprecation
+ // (in which case the deprecation is cancelled by the invalidation handler).
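+ // Both orders are possible here because the RA above set the valid and
+ // preferred lifetimes to the same value (newMinVL), so the two timers are
+ // armed for the same instant.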
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+}
+
+// Tests transitioning a SLAAC address's valid lifetime between finite and
+// infinite values.
+func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
+ const infiniteVLSeconds = 2
+ const minVLSeconds = 1
+ savedIL := header.NDPInfiniteLifetime
+ savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ header.NDPInfiniteLifetime = savedIL
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ infiniteVL uint32
+ }{
+ {
+ name: "EqualToInfiniteVL",
+ infiniteVL: infiniteVLSeconds,
+ },
+ // Our implementation supports changing header.NDPInfiniteLifetime for tests
+ // such that a packet can be received where the lifetime field has a value
+ // greater than header.NDPInfiniteLifetime. Because of this, we test to make
+ // sure that receiving a value greater than header.NDPInfiniteLifetime is
+ // handled the same as when receiving a value equal to
+ // header.NDPInfiniteLifetime.
+ {
+ name: "MoreThanInfiniteVL",
+ infiniteVL: infiniteVLSeconds + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive a new RA for the prefix with an infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+
+ // Receive a new RA for the prefix with a finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
+// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
+// auto-generated address only gets updated when required to, as specified in
+// RFC 4862 section 5.5.3.e.
+func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
+ const infiniteVL = 4294967295
+ const newMinVL = 4
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
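+ // A sketch of the RFC 4862 s5.5.3.e rule the cases below exercise, with
+ // the RFC's two-hour constant replaced by the overridden minimum (the
+ // actual handler lives in the stack package; the names here are only
+ // illustrative):
+ //
+ //   if newVL > minVLForUpdate || newVL > remainingVL {
+ //       effectiveVL = newVL
+ //   } else if remainingVL <= minVLForUpdate {
+ //       // Ignore the advertised valid lifetime.
+ //   } else {
+ //       effectiveVL = minVLForUpdate
+ //   }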
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ ovl uint32
+ nvl uint32
+ evl uint32
+ }{
+ // Should update the VL to the minimum VL for updating if the
+ // new VL is less than newMinVL but was originally greater than
+ // it.
+ {
+ "LargeVLToVLLessThanMinVLForUpdate",
+ 9999,
+ 1,
+ newMinVL,
+ },
+ {
+ "LargeVLTo0",
+ 9999,
+ 0,
+ newMinVL,
+ },
+ {
+ "InfiniteVLToVLLessThanMinVLForUpdate",
+ infiniteVL,
+ 1,
+ newMinVL,
+ },
+ {
+ "InfiniteVLTo0",
+ infiniteVL,
+ 0,
+ newMinVL,
+ },
+
+ // Should not update VL if original VL was less than newMinVL
+ // and the new VL is also less than newMinVL.
+ {
+ "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
+ newMinVL - 1,
+ newMinVL - 3,
+ newMinVL - 1,
+ },
+
+ // Should take the new VL if the new VL is greater than the
+ // remaining time or is greater than newMinVL.
+ {
+ "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
+ newMinVL + 5,
+ newMinVL + 3,
+ newMinVL + 3,
+ },
+ {
+ "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL - 1,
+ newMinVL - 1,
+ },
+ {
+ "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL + 1,
+ newMinVL + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA for the prefix with the initial VL, test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive a new RA for the prefix with the new VL, test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
+
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
+ }
+
+ // Wait for the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
+// TestAutoGenAddrRemoval tests that when an auto-generated address is removed
+// by the user, its resources are cleaned up and an invalidation event is sent
+// to the integrator.
+func TestAutoGenAddrRemoval(t *testing.T) {
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive a PI to auto-generate an address.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Removing the address should result in an invalidation event
+ // immediately.
+ if err := s.RemoveAddress(1, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+
+ // Wait for the original valid lifetime to make sure the original timer
+ // got stopped/cleaned up.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+}
+
+// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
+// assigned to the NIC but is in the permanentExpired state.
+func TestAutoGenAddrAfterRemoval(t *testing.T) {
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
+
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
+
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
+
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+}
+
+// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
+// is already assigned to the NIC, the static address remains.
+func TestAutoGenAddrStaticConflict(t *testing.T) {
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Add the address as a static address before SLAAC tries to add it.
+ if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive a PI where the generated address will be the same as the one
+ // that we already have assigned statically.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
+ default:
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Should not get an invalidation event after the PI's invalidation
+ // time.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+}
+
+// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
+// opaque interface identifiers when configured to do so.
+func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic1"
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
+ // addr1 and addr2 are the addresses that are expected to be generated when
+ // stack.Stack is configured to generate opaque interface identifiers as
+ // defined by RFC 7217.
+ addrBytes := []byte(subnet1.ID())
+ addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ addrBytes = []byte(subnet2.ID())
+ addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
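+ // RFC 7217 (s5) computes the interface identifier roughly as
+ //
+ //   RID = F(Prefix, Net_Iface, Network_ID, DAD_Counter, Secret_Key)
+ //
+ // AppendOpaqueInterfaceIdentifier is assumed to implement F with the
+ // subnet as the prefix, nicName as Net_Iface, a DAD counter of 0, and
+ // secretKey, so the addresses expected here are fully deterministic.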
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in a PI.
+ const validLifetimeSecondPrefix1 = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in a PI with a large valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+}
+
+func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxMaxRetries = 3
+ const lifetimeSeconds = 10
+
+ // Needed for the temporary address sub test.
+ savedMaxDesync := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesync
+ }()
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix {
+ addrBytes := []byte(subnet.ID())
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)),
+ PrefixLen: 64,
+ }
+ }
+
+ expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ }
+
+ expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ stableAddrForTempAddrTest := addrForSubnet(subnet, 0)
+
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
+ addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ return nil
+ },
+ addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
+ return addrForSubnet(subnet, dadCounter)
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ autoGenLinkLocal: true,
+ prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
+ return nil
+ },
+ addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
+ return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter)
+ },
+ },
+ {
+ name: "Temporary address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
+ header.InitialTempIID(tempIIDHistory, nil, nicID)
+
+ // Generate a stable SLAAC address so temporary addresses will be
+ // generated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
+ expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true)
+
+ // The stable address will be assigned throughout the test.
+ return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
+ },
+ addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address)
+ },
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because some teardown work must happen after the parallel
+ // tests complete, and because limiting the number of parallel tests running
+ // at the same time reduces flakes.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run(addrType.name, func(t *testing.T) {
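+ // Try every combination of retry limit and conflict count. An address
+ // should only end up assigned when the number of simulated DAD
+ // conflicts is less than the number of allowed generation attempts
+ // (maxRetries + 1, counting the initial attempt).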
+ for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
+ for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
+ maxRetries := maxRetries
+ numFailures := numFailures
+ addrType := addrType
+
+ t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := addrType.ndpConfigs
+ ndpConfigs.AutoGenAddressConflictRetries = maxRetries
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ var tempIIDHistory [header.IIDSize]byte
+ stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:])
+
+ // Simulate DAD conflicts so the address is regenerated.
+ for i := uint8(0); i < numFailures; i++ {
+ addr := addrType.addrGenFn(i, tempIIDHistory[:])
+ expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+
+ // Should not have any new addresses assigned to the NIC.
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
+ expectDADEvent(t, &ndpDisp, addr.Address, false)
+
+ // Attempting to add the address manually should not fail if the
+ // address's state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if err := s.RemoveAddress(nicID, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
+ }
+ expectDADEvent(t, &ndpDisp, addr.Address, false)
+ }
+
+ // Should not have any new addresses assigned to the NIC.
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // If we had fewer failures than generation attempts, we should have
+ // an address after DAD resolves.
+ if maxRetries+1 > numFailures {
+ addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
+ expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+ expectDADEventAsync(t, &ndpDisp, addr.Address, true)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ }
+
+ // Should not attempt address generation again.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ })
+ }
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is
+// not made for SLAAC addresses generated with an IID based on the NIC's link
+// address.
+func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
+ const nicID = 1
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxRetries = 3
+ const lifetimeSeconds = 10
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ subnet tcpip.Subnet
+ triggerSLAACFn func(e *channel.Endpoint)
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ subnet: subnet,
+ triggerSLAACFn: func(e *channel.Endpoint) {
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ autoGenLinkLocal: true,
+ subnet: header.IPv6LinkLocalPrefix.Subnet(),
+ triggerSLAACFn: func(e *channel.Endpoint) {},
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ addrType := addrType
+
+ t.Run(addrType.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ addrType.triggerSLAACFn(e)
+
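+ // The auto-generated address is the subnet prefix with an IID derived
+ // from the NIC's link address using the modified EUI-64 scheme
+ // (RFC 4291 appendix A).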
+ addrBytes := []byte(addrType.subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Should not attempt address regeneration.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address
+// generation in response to DAD conflicts does not refresh the lifetimes.
+func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = 2 * time.Second
+ const failureTimer = time.Second
+ const maxRetries = 1
+ const lifetimeSeconds = 5
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ addrBytes := []byte(subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict after some time has passed.
+ time.Sleep(failureTimer)
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Let the next address resolve.
+ addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
+ expectAutoGenAddrEvent(addr, newAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // Address should be deprecated/invalidated after the lifetime expires.
+ //
+ // Note, the remaining lifetime is calculated from when the PI was first
+ // processed. Since we wait for some time before simulating a DAD conflict
+ // and more time for the new address to resolve, the new address is only
+ // expected to be valid for the remaining time. The DAD conflict should
+ // not have reset the lifetimes.
+ //
+ // We expect either just the invalidation event or the deprecation event
+ // followed by the invalidation event.
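+ //
+ // With the constants above: of the lifetimeSeconds (5s) valid lifetime,
+ // roughly failureTimer (1s) elapsed before the conflict and
+ // dadTransmits*retransmitTimer (2s) elapsed while the replacement
+ // resolved, leaving about 2s before invalidation, which is the timeout
+ // used below.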
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if e.eventType == deprecatedAddr {
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
+ }
+ } else {
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for auto gen addr event")
+ }
+}
+
+// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
+// to the integrator when an RA is received with the NDP Recursive DNS Server
+// option with at least one valid address.
+func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
+ tests := []struct {
+ name string
+ opt header.NDPRecursiveDNSServer
+ expected *ndpRDNSS
+ }{
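+ // Each opt below is an RDNSS option body without the type and length
+ // bytes: 2 reserved bytes, a 4-byte lifetime in seconds, then zero or
+ // more 16-byte IPv6 addresses (RFC 8106 section 5.1).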
+ {
+ "Unspecified",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }),
+ nil,
+ },
+ {
+ "Multicast",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ nil,
+ },
+ {
+ "OptionTooSmall",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ }),
+ nil,
+ },
+ {
+ "0Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ }),
+ nil,
+ },
+ {
+ "Valid1Address",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ },
+ 2 * time.Second,
+ },
+ },
+ {
+ "Valid2Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ },
+ time.Second,
+ },
+ },
+ {
+ "Valid3Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03",
+ },
+ 0,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ // We do not expect more than a single RDNSS
+ // event at any time for this test.
+ rdnssC: make(chan ndpRDNSSEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt}))
+
+ if test.expected != nil {
+ select {
+ case e := <-ndpDisp.rdnssC:
+ if e.nicID != 1 {
+ t.Errorf("got rdnss nicID = %d, want = 1", e.nicID)
+ }
+ if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" {
+ t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff)
+ }
+ if e.rdnss.lifetime != test.expected.lifetime {
+ t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime)
+ }
+ default:
+ t.Fatal("expected an RDNSS option event")
+ }
+ }
+
+ // Should have no more RDNSS options.
+ select {
+ case e := <-ndpDisp.rdnssC:
+ t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e)
+ default:
+ }
+ })
+ }
+}
+
+// TestNDPDNSSearchListDispatch tests that the integrator is informed when an
+// NDP DNS Search List option is received with at least one domain name in the
+// search list.
+func TestNDPDNSSearchListDispatch(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dnsslC: make(chan ndpDNSSLEvent, 3),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
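+ // Each DNS Search List option body below holds 2 reserved bytes, a
+ // 4-byte lifetime in seconds, then domain names encoded as sequences of
+ // length-prefixed labels terminated by a zero byte (RFC 8106 section
+ // 5.2). For example, the last option's lifetime bytes {0, 0, 1, 0}
+ // encode 256 seconds.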
+ optSer := header.NDPOptionsSerializer{
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 2, 'h', 'i',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 'i',
+ 0,
+ 2, 'a', 'm',
+ 2, 'm', 'e',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 1, 0,
+ 3, 'x', 'y', 'z',
+ 0,
+ 5, 'h', 'e', 'l', 'l', 'o',
+ 5, 'w', 'o', 'r', 'l', 'd',
+ 0,
+ 4, 't', 'h', 'i', 's',
+ 2, 'i', 's',
+ 1, 'a',
+ 4, 't', 'e', 's', 't',
+ 0,
+ }),
+ }
+ expected := []struct {
+ domainNames []string
+ lifetime time.Duration
+ }{
+ {
+ domainNames: []string{
+ "hi",
+ },
+ lifetime: 0,
+ },
+ {
+ domainNames: []string{
+ "i",
+ "am.me",
+ },
+ lifetime: time.Second,
+ },
+ {
+ domainNames: []string{
+ "xyz",
+ "hello.world",
+ "this.is.a.test",
+ },
+ lifetime: 256 * time.Second,
+ },
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+
+ for i, expected := range expected {
+ select {
+ case dnssl := <-ndpDisp.dnsslC:
+ if dnssl.nicID != nicID {
+ t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID)
+ }
+ if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" {
+ t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff)
+ }
+ if dnssl.lifetime != expected.lifetime {
+ t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime)
+ }
+ default:
+ t.Fatal("expected a DNSSL event")
+ }
+ }
+
+ // Should have no more DNSSL options.
+ select {
+ case <-ndpDisp.dnsslC:
+ t.Fatal("unexpectedly got a DNSSL event")
+ default:
+ }
+}
+
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
+ const (
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1)
+ e2Addr1 := addrForSubnet(subnet1, linkAddr2)
+ e2Addr2 := addrForSubnet(subnet2, linkAddr2)
+ llAddrWithPrefix1 := tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 64,
+ }
+ llAddrWithPrefix2 := tcpip.AddressWithPrefix{
+ Address: llAddr2,
+ PrefixLen: 64,
+ }
+
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ skipFinalAddrCheck bool
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Enable forwarding",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
+ },
+
+ // A NIC should clean up all NDP state when it is disabled.
+ {
+ name: "Disable NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
+
+ // A NIC should clean up all NDP state when it is removed.
+ {
+ name: "Remove NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.RemoveNIC(nicID1); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ // The NICs are removed so we can't check their addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
+
+ return false, ndpRouterEvent{}
+ }
+
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
+
+ return false, ndpPrefixEvent{}
+ }
+
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
+
+ return false, ndpAutoGenAddrEvent{}
+ }
+
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
+
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
+
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
+
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
+
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
+
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
+
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
+
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ test.cleanupFn(t, s)
+
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
+
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ if !test.skipFinalAddrCheck {
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
+ }
+}
+
+// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly
+// informed when new information about what configurations are available via
+// DHCPv6 is learned.
+func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
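+ // raBufWithDHCPv6's two boolean arguments set the RA's Managed Address
+ // (M) and Other Configurations (O) flags, respectively.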
+ expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ t.Helper()
+ select {
+ case e := <-ndpDisp.dhcpv6ConfigurationC:
+ if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
+ t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DHCPv6 configuration event")
+ }
+ }
+
+ expectNoDHCPv6Event := func() {
+ t.Helper()
+ select {
+ case <-ndpDisp.dhcpv6ConfigurationC:
+ t.Fatal("unexpected DHCPv6 configuration event")
+ default:
+ }
+ }
+
+ // Even if the first RA reports no DHCPv6 configurations are available, the
+ // dispatcher should get an event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ // Receiving the same update again should not result in an event to the
+ // dispatcher.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to none.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ //
+ // Note, when the M flag is set, the O flag is redundant.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+ // Even though the DHCPv6 flags are different, the effective configuration is
+ // the same so we should not receive a new event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Cycling the NIC should cause the last DHCPv6 configuration to be cleared.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+}
+
+// TestRouterSolicitation tests the initial Router Solicitations that are sent
+// when a NIC newly becomes enabled.
+func TestRouterSolicitation(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ linkHeaderLen uint16
+ linkAddr tcpip.LinkAddress
+ nicAddr tcpip.Address
+ expectedSrcAddr tcpip.Address
+ expectedNDPOpts []header.NDPOption
+ maxRtrSolicit uint8
+ rtrSolicitInt time.Duration
+ effectiveRtrSolicitInt time.Duration
+ maxRtrSolicitDelay time.Duration
+ effectiveMaxRtrSolicitDelay time.Duration
+ }{
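+ // The effective* fields hold the values the stack should actually use
+ // after invalid configurations are replaced with defaults: as the
+ // invalid cases below show, a zero interval is replaced with a 4s
+ // default and a negative delay with a 1s default.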
+ {
+ name: "Single RS with 2s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 2 * time.Second,
+ effectiveMaxRtrSolicitDelay: 2 * time.Second,
+ },
+ {
+ name: "Single RS with 4s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 4 * time.Second,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 4 * time.Second,
+ effectiveMaxRtrSolicitDelay: 4 * time.Second,
+ },
+ {
+ name: "Two RS with delay",
+ linkHeaderLen: 1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 500 * time.Millisecond,
+ effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
+ },
+ {
+ name: "Single RS without delay",
+ linkHeaderLen: 2,
+ linkAddr: linkAddr1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
+ expectedNDPOpts: []header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ },
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 0,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Three RS without delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 3,
+ rtrSolicitInt: 500 * time.Millisecond,
+ effectiveRtrSolicitInt: 500 * time.Millisecond,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS with invalid negative delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: -3 * time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because some teardown work must happen after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet")
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
+
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
+ remaining--
+ }
+
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
+ waitForPkt(defaultAsyncPositiveEventTimeout)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
+ }
+ }
+
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ }
+
+ // Make sure the counter got properly incremented.
+ if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+}
+
+func TestStopStartSolicitingRouters(t *testing.T) {
+ const nicID = 1
+ const delay = 0
+ const interval = 500 * time.Millisecond
+ const maxRtrSolicitations = 3
+
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ // first is used to tell stopFn that it is being called for the first time
+ // after router solicitations were last enabled.
+ stopFn func(t *testing.T, s *stack.Stack, first bool)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Enable and disable forwarding",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ },
+
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "Enable and disable NIC",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
+ t.Helper()
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
+
+ // Tests that when a NIC is removed, router solicitations are stopped. We
+ // cannot start router solicitations on a removed NIC.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack, first bool) {
+ t.Helper()
+
+ // Only try to remove the NIC the first time stopFn is called since it's
+ // impossible to remove an already removed NIC.
+ if !first {
+ return
+ }
+
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s, true /* first */)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before solicitations were stopped.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s, false /* first */)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // If test.startFn is nil, there is no way to restart router solicitations.
+ if test.startFn == nil {
+ return
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
new file mode 100644
index 000000000..7b80534e6
--- /dev/null
+++ b/pkg/tcpip/stack/nic.go
@@ -0,0 +1,1743 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
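+// ipv4BroadcastAddr is the IPv4 limited broadcast address
+// (255.255.255.255/32). It is added to every NIC with IPv4 enabled so the
+// stack can receive broadcast packets on that NIC.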
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
+// NIC represents a "network interface card" to which the networking stack is
+// attached.
+type NIC struct {
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ context NICContext
+
+ stats NICStats
+
+ mu struct {
+ sync.RWMutex
+ enabled bool
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ addressRanges []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]uint32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ ndp ndpState
+ }
+}
+
+// NICStats includes transmitted and received stats.
+type NICStats struct {
+ Tx DirectionStats
+ Rx DirectionStats
+
+ DisabledRx DirectionStats
+}
+
+func makeNICStats() NICStats {
+ var s NICStats
+ tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
+}
+
+// DirectionStats includes packet and byte counts.
+type DirectionStats struct {
+ Packets *tcpip.StatCounter
+ Bytes *tcpip.StatCounter
+}
+
+// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, the most recently-added one will be first.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
+// newNIC returns a new NIC using the default NDP configurations from stack.
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
+ // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
+ // example, make sure that the link address it provides is a valid
+ // unicast ethernet address.
+
+ // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints
+ // observe an MTU of at least 1280 bytes. Ensure that this requirement
+ // of IPv6 is supported on this endpoint's LinkEndpoint.
+
+ nic := &NIC{
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ context: ctx,
+ stats: makeNICStats(),
+ }
+ nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
+ nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.ndp = ndpState{
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ nic.mu.ndp.initializeTempAddrState()
+
+ // Register supported packet endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ }
+ for _, netProto := range stack.networkProtocols {
+ nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ }
+
+ nic.linkEP.Attach(nic)
+
+ return nic
+}
+
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ err := n.disableLocked()
+ n.mu.Unlock()
+ return err
+}
+
+// disableLocked disables n.
+//
+// It undoes the work done by enable.
+//
+// n MUST be locked.
+func (n *NIC) disableLocked() *tcpip.Error {
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(b/147015577): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes multicast
+// group (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
+func (n *NIC) enable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if n.mu.enabled {
+ return nil
+ }
+
+ n.mu.enabled = true
+
+ // Create an endpoint to receive broadcast packets on this interface.
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ return err
+ }
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ //
+ // Also auto-generate an IPv6 link-local address based on the NIC's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have an IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+ _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
+ if !ok {
+ return nil
+ }
+
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ // Perform DAD on all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have already completed DAD but in the time since the NIC was
+ // last enabled, other devices may have acquired the same addresses.
+ for _, r := range n.mu.endpoints {
+ addr := r.ep.ID().LocalAddress
+ if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
+ continue
+ }
+
+ r.setKind(permanentTentative)
+ if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ return err
+ }
+ }
+
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !n.stack.forwarding {
+ n.mu.ndp.startSolicitingRouters()
+ }
+
+ return nil
+}
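+
+// An illustrative sketch (not part of this change) of the lifecycle that
+// enable and disable implement, from the perspective of a hypothetical
+// caller inside the stack holding a *NIC n:
+//
+//	if err := n.enable(); err != nil { // joins ff02::1, starts DAD and RS
+//		return err
+//	}
+//	// ... NIC receives and delivers packets ...
+//	if err := n.disable(); err != nil { // undoes the work done by enable
+//		return err
+//	}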
+
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.disableLocked()
+
+ // TODO(b/151378115): come up with a better way to pick an error than the
+ // first one.
+ var err *tcpip.Error
+
+ // Forcefully leave multicast groups.
+ for nid := range n.mu.mcastJoins {
+ if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
+ }
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ return err
+}
+
+// becomeIPv6Router transitions n into an IPv6 router.
+//
+// When transitioning into an IPv6 router, host-only state (NDP discovered
+// routers, discovered on-link prefixes, and auto-generated addresses) will
+// be cleaned up/invalidated and NDP router solicitations will be stopped.
+func (n *NIC) becomeIPv6Router() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.cleanupState(true /* hostOnly */)
+ n.mu.ndp.stopSolicitingRouters()
+}
+
+// becomeIPv6Host transitions n into an IPv6 host.
+//
+// When transitioning into an IPv6 host, NDP router solicitations will be
+// started.
+func (n *NIC) becomeIPv6Host() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.startSolicitingRouters()
+}
+
+// setPromiscuousMode enables or disables promiscuous mode.
+func (n *NIC) setPromiscuousMode(enable bool) {
+ n.mu.Lock()
+ n.mu.promiscuous = enable
+ n.mu.Unlock()
+}
+
+func (n *NIC) isPromiscuousMode() bool {
+ n.mu.RLock()
+ rv := n.mu.promiscuous
+ n.mu.RUnlock()
+ return rv
+}
+
+func (n *NIC) isLoopback() bool {
+ return n.linkEP.Capabilities()&CapabilityLoopback != 0
+}
+
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+ n.mu.Lock()
+ n.mu.spoofing = enable
+ n.mu.Unlock()
+}
+
+// primaryEndpoint will return the first non-deprecated endpoint if such an
+// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
+// endpoint exists, the first deprecated endpoint will be returned.
+//
+// If an IPv6 primary endpoint is requested, Source Address Selection (as
+// defined by RFC 6724 section 5) will be performed.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
+ return n.primaryIPv6Endpoint(remoteAddr)
+ }
+
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, r := range n.mu.primary[protocol] {
+ if !r.isValidForOutgoingRLocked() {
+ continue
+ }
+
+ if !r.deprecated {
+ if r.tryIncRef() {
+ // r is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ deprecatedEndpoint.decRefLocked()
+ deprecatedEndpoint = nil
+ }
+
+ return r
+ }
+ } else if deprecatedEndpoint == nil && r.tryIncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of r in
+ // case n doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, r's reference count
+ // will be decremented when such an endpoint is found.
+ deprecatedEndpoint = r
+ }
+ }
+
+ // n doesn't have any valid non-deprecated endpoints, so return
+ // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
+ // endpoints either).
+ return deprecatedEndpoint
+}
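+
+// A hedged usage sketch for primaryEndpoint; note that the returned
+// endpoint, if any, has had its reference count incremented via tryIncRef,
+// so the (hypothetical) caller must release it:
+//
+//	ref := n.primaryEndpoint(header.IPv4ProtocolNumber, "" /* remoteAddr */)
+//	if ref == nil {
+//		// No valid endpoint, deprecated or otherwise, exists for the protocol.
+//		return
+//	}
+//	defer ref.decRef()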
+
+// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
+// 6724 section 5).
+type ipv6AddrCandidate struct {
+ ref *referencedNetworkEndpoint
+ scope header.IPv6AddressScope
+}
+
+// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ n.mu.RUnlock()
+ return ref
+}
+
+ // primaryIPv6EndpointRLocked returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+//
+// n.mu MUST be read locked.
+func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
+
+ if len(primaryAddrs) == 0 {
+ return nil
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
+ for _, r := range primaryAddrs {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !r.isValidForOutgoingRLocked() {
+ continue
+ }
+
+ addr := r.ep.ID().LocalAddress
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, ipv6AddrCandidate{
+ ref: r,
+ scope: scope,
+ })
+ }
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.ref.ep.ID().LocalAddress == remoteAddr {
+ return true
+ }
+ if sb.ref.ep.ID().LocalAddress == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if r := c.ref; r.tryIncRef() {
+ return r
+ }
+ }
+
+ return nil
+}
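+
+// A worked example of the rules above (a sketch, with hypothetical
+// addresses): for remoteAddr = 2001:db8::1 (global scope) and candidates
+//
+//	fe80::1      (link-local scope, not deprecated)
+//	2001:db8::2  (global scope, deprecated)
+//	2001:db8::3  (global scope, not deprecated, SLAAC temporary)
+//
+// rule 2 (appropriate scope) ranks the global candidates above fe80::1,
+// and rule 3 (avoid deprecated) ranks 2001:db8::3 above 2001:db8::2, so
+// 2001:db8::3 is selected.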
+
+// hasPermanentAddrLocked returns true if n has a permanent (including currently
+// tentative) address, addr.
+func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+
+ if !ok {
+ return false
+ }
+
+ kind := ref.getKind()
+
+ return kind == permanent || kind == permanentTentative
+}
+
+type getRefBehaviour int
+
+const (
+ // spoofing indicates that the NIC's spoofing flag should be observed when
+ // getting a NIC's referenced network endpoint.
+ spoofing getRefBehaviour = iota
+
+ // promiscuous indicates that the NIC's promiscuous flag should be observed
+ // when getting a NIC's referenced network endpoint.
+ promiscuous
+)
+
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
+}
+
+// findEndpoint finds the endpoint, if any, with the given address.
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
+}
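+
+// An illustrative note (a sketch, not new behavior) on the two entry points
+// above: getRef is used on the receive path, so an unknown destination only
+// yields a temporary endpoint when the NIC is promiscuous, while findEndpoint
+// is used on the send path, so an unknown local address only yields one when
+// spoofing is enabled:
+//
+//	rxRef := n.getRef(proto, dstAddr)                             // honors n.mu.promiscuous
+//	txRef := n.findEndpoint(proto, srcAddr, CanBePrimaryEndpoint) // honors n.mu.spoofing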
+
+ // getRefOrCreateTemp returns the referenced network endpoint for the given
+// protocol and address.
+//
+// If none exists a temporary one may be created if we are in promiscuous mode
+// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
+// Similarly, spoofing will only be checked if spoofing is true.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
+ n.mu.RLock()
+
+ var spoofingOrPromiscuous bool
+ switch tempRef {
+ case spoofing:
+ spoofingOrPromiscuous = n.mu.spoofing
+ case promiscuous:
+ spoofingOrPromiscuous = n.mu.promiscuous
+ }
+
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
+ n.mu.RUnlock()
+ return nil
+ }
+
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
+
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
+ if !createTempEP {
+ for _, sn := range n.mu.addressRanges {
+ // Skip the subnet address.
+ if address == sn.ID() {
+ continue
+ }
+ // For now just skip the broadcast address, until we support it.
+ // FIXME(b/137608825): Add support for sending/receiving directed
+ // (subnet) broadcast.
+ if address == sn.Broadcast() {
+ continue
+ }
+ if sn.Contains(address) {
+ createTempEP = true
+ break
+ }
+ }
+ }
+
+ n.mu.RUnlock()
+
+ if !createTempEP {
+ return nil
+ }
+
+ // Try again with the lock in exclusive mode. If we still can't get the
+ // endpoint, create a new "temporary" endpoint. It will only exist while
+ // there's a route through it.
+ n.mu.Lock()
+ ref := n.getRefOrCreateTempLocked(protocol, address, peb)
+ n.mu.Unlock()
+ return ref
+}
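+
+// getRefOrCreateTemp follows a common optimistic-locking shape; a minimal
+// sketch of the pattern, with hypothetical lookup helpers, independent of
+// this file:
+//
+//	mu.RLock()
+//	v, ok := lookup()
+//	mu.RUnlock()
+//	if !ok {
+//		mu.Lock() // re-check under the exclusive lock before creating
+//		v = lookupOrCreate()
+//		mu.Unlock()
+//	}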
+
+ // getRefOrCreateTempLocked returns an existing endpoint for address or creates
+ // and returns a temporary endpoint.
+func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ return ref
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
+ }
+
+ // Add a new temporary endpoint.
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ return nil
+ }
+ ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb, temporary, static, false)
+ return ref
+}
+
+// addAddressLocked adds a new protocolAddress to n.
+//
+// If n already has the address in a non-permanent state, and the kind given is
+// permanent, that address will be promoted in place and its properties set to
+// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress.
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP addresses before adding them.
+
+ // Sanity check.
+ id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.mu.endpoints[id]; ok {
+ // Endpoint already exists.
+ if kind != permanent {
+ return nil, tcpip.ErrDuplicateAddress
+ }
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ // The NIC already has a permanent endpoint with that address.
+ return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent and respect the new peb,
+ // configType and deprecated status.
+ if ref.tryIncRef() {
+ // TODO(b/147748385): Perform Duplicate Address Detection when promoting
+ // an IPv6 endpoint to permanent.
+ ref.setKind(permanent)
+ ref.deprecated = deprecated
+ ref.configType = configType
+
+ refs := n.mu.primary[ref.protocol]
+ for i, r := range refs {
+ if r == ref {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return ref, nil
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ return ref, nil
+ }
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ case NeverPrimaryEndpoint:
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ return ref, nil
+ }
+ }
+ }
+
+ n.insertPrimaryEndpointLocked(ref, peb)
+
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
+ }
+ }
+
+ netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack)
+ if err != nil {
+ return nil, err
+ }
+
+ isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
+
+ // If the address is an IPv6 address and it is a permanent address,
+ // mark it as tentative so it goes through the DAD process if the NIC is
+ // enabled. If the NIC is not enabled, DAD will be started when the NIC is
+ // enabled.
+ if isIPv6Unicast && kind == permanent {
+ kind = permanentTentative
+ }
+
+ ref := &referencedNetworkEndpoint{
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
+ configType: configType,
+ deprecated: deprecated,
+ }
+
+ // Set up cache if link address resolution exists for this protocol.
+ if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
+ ref.linkCache = n.stack
+ }
+ }
+
+ // If we are adding an IPv6 unicast address, join the solicited-node
+ // multicast address.
+ if isIPv6Unicast {
+ snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
+ if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
+ return nil, err
+ }
+ }
+
+ n.mu.endpoints[id] = ref
+
+ n.insertPrimaryEndpointLocked(ref, peb)
+
+ // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
+ if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
+ if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ return nil, err
+ }
+ }
+
+ return ref, nil
+}
+
+// AddAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+ // Add the endpoint.
+ n.mu.Lock()
+ _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */)
+ n.mu.Unlock()
+
+ return err
+}
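+
+// A hedged usage sketch for AddAddress (the address is hypothetical; the
+// behavior constant is unqualified because this sketch sits inside the
+// stack package):
+//
+//	err := nic.AddAddress(tcpip.ProtocolAddress{
+//		Protocol: header.IPv4ProtocolNumber,
+//		AddressWithPrefix: tcpip.AddressWithPrefix{
+//			Address:   tcpip.Address("\xc0\xa8\x01\x02"), // 192.168.1.2
+//			PrefixLen: 24,
+//		},
+//	}, CanBePrimaryEndpoint)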
+
+// AllAddresses returns all addresses (primary and non-primary) associated with
+// this NIC.
+func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
+ for nid, ref := range n.mu.endpoints {
+ // Don't include expired or temporary endpoints to avoid confusion
+ // and prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: ref.protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nid.LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ return addrs
+}
+
+// PrimaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var addrs []tcpip.ProtocolAddress
+ for proto, list := range n.mu.primary {
+ for _, ref := range list {
+ // Don't include tentative, expired or temporary endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
+ switch ref.getKind() {
+ case permanentTentative, permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ }
+ return addrs
+}
+
+// primaryAddress returns the primary address associated with this NIC.
+//
+// primaryAddress will return the first non-deprecated address if such an
+// address exists. If no non-deprecated address exists, the first deprecated
+// address will be returned.
+func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list, ok := n.mu.primary[proto]
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, ref := range list {
+ // Don't include tentative, expired or temporary endpoints to avoid confusion
+ // and prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentTentative, permanentExpired, temporary:
+ continue
+ }
+
+ if !ref.deprecated {
+ return tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ }
+ }
+
+ if deprecatedEndpoint == nil {
+ deprecatedEndpoint = ref
+ }
+ }
+
+ if deprecatedEndpoint != nil {
+ return tcpip.AddressWithPrefix{
+ Address: deprecatedEndpoint.ep.ID().LocalAddress,
+ PrefixLen: deprecatedEndpoint.ep.PrefixLen(),
+ }
+ }
+
+ return tcpip.AddressWithPrefix{}
+}
+
+// AddAddressRange adds a range of addresses to n, so that it starts accepting
+// packets targeted at the given addresses and network protocol. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+ n.mu.Lock()
+ n.mu.addressRanges = append(n.mu.addressRanges, subnet)
+ n.mu.Unlock()
+}
+
+// RemoveAddressRange removes the given address range from n.
+func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
+ n.mu.Lock()
+
+ // Use the same underlying array.
+ tmp := n.mu.addressRanges[:0]
+ for _, sub := range n.mu.addressRanges {
+ if sub != subnet {
+ tmp = append(tmp, sub)
+ }
+ }
+ n.mu.addressRanges = tmp
+
+ n.mu.Unlock()
+}
+
+// AddressRanges returns the Subnets associated with this NIC.
+func (n *NIC) AddressRanges() []tcpip.Subnet {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints))
+ for nid := range n.mu.endpoints {
+ sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
+ if err != nil {
+ // This should never happen as the mask has been carefully crafted to
+ // match the address.
+ panic("Invalid endpoint subnet: " + err.Error())
+ }
+ sns = append(sns, sn)
+ }
+ return append(sns, n.mu.addressRanges...)
+}
+
+// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
+// by peb.
+//
+// n MUST be locked.
+func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
+ case FirstPrimaryEndpoint:
+ n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
+ }
+}
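+
+// A small sketch of the resulting ordering (refs a, b, c hypothetical):
+//
+//	insertPrimaryEndpointLocked(a, CanBePrimaryEndpoint) // primary: [a]
+//	insertPrimaryEndpointLocked(b, FirstPrimaryEndpoint) // primary: [b, a]
+//	insertPrimaryEndpointLocked(c, NeverPrimaryEndpoint) // primary: [b, a] (c is not added)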
+
+func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
+ id := *r.ep.ID()
+
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
+ if n.mu.endpoints[id] != r {
+ return
+ }
+
+ if r.getKind() == permanent {
+ panic("Reference count dropped to zero before being removed")
+ }
+
+ delete(n.mu.endpoints, id)
+ refs := n.mu.primary[r.protocol]
+ for i, ref := range refs {
+ if ref == r {
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ refs[len(refs)-1] = nil
+ break
+ }
+ }
+
+ r.ep.Close()
+}
+
+func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
+ n.mu.Lock()
+ n.removeEndpointLocked(r)
+ n.mu.Unlock()
+}
+
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ kind := r.getKind()
+ if kind != permanent && kind != permanentTentative {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ switch r.protocol {
+ case header.IPv6ProtocolNumber:
+ return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */)
+ default:
+ r.expireLocked()
+ return nil
+ }
+}
+
+func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := r.addrWithPrefix()
+
+ isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
+
+ if isIPv6Unicast {
+ n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch r.configType {
+ case slaac:
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case slaacTemp:
+ n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
+ }
+
+ r.expireLocked()
+
+ // At this point the endpoint is deleted.
+
+ // If we are removing an IPv6 unicast address, leave the solicited-node
+ // multicast address.
+ //
+ // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
+ // multicast group may be left by user action.
+ if isIPv6Unicast {
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removePermanentAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ return n.joinGroupLocked(protocol, addr)
+}
+
+// joinGroupLocked adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count. n MUST be locked before
+// joinGroupLocked is called.
+func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // TODO(b/143102137): When implementing MLD, make sure MLD packets are
+ // not sent unless a valid link-local address is available for use on n
+ // as an MLD packet's source address must be a link-local address as
+ // outlined in RFC 3810 section 5.
+
+ id := NetworkEndpointID{addr}
+ joins := n.mu.mcastJoins[id]
+ if joins == 0 {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ return err
+ }
+ }
+ n.mu.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ return n.leaveGroupLocked(addr, false /* force */)
+}
+
+// leaveGroupLocked decrements the count for the given multicast address, and
+// when it reaches zero removes the endpoint for this address. n MUST be locked
+// before leaveGroupLocked is called.
+//
+ // If force is true, then the count for the multicast address is ignored and the
+// endpoint will be removed immediately.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
+ id := NetworkEndpointID{addr}
+ joins, ok := n.mu.mcastJoins[id]
+ if !ok {
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ }
+
+ joins--
+ if force || joins == 0 {
+ // There are no outstanding joins or we are forced to leave, clean up.
+ delete(n.mu.mcastJoins, id)
+ return n.removePermanentAddressLocked(addr)
+ }
+
+ n.mu.mcastJoins[id] = joins
+ return nil
+}
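+
+// A sketch of the reference-counted join semantics above (mcastAddr is
+// hypothetical):
+//
+//	_ = n.joinGroup(header.IPv6ProtocolNumber, mcastAddr) // count: 1, endpoint added
+//	_ = n.joinGroup(header.IPv6ProtocolNumber, mcastAddr) // count: 2
+//	_ = n.leaveGroup(mcastAddr)                           // count: 1, endpoint kept
+//	_ = n.leaveGroup(mcastAddr)                           // count: 0, endpoint removed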
+
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+
+ ref.ep.HandlePacket(&r, pkt)
+ ref.decRef()
+}
+
+// DeliverNetworkPacket finds the appropriate network protocol endpoint and
+// hands the packet over for further processing. This function is called when
+// the NIC receives a packet from the link endpoint.
+// Note that the ownership of the slice backing vv is retained by the caller.
+// This rule applies only to the slice itself, not to the items of the slice;
+// the ownership of the items is not retained by the caller.
+func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ // If the NIC is not yet enabled, don't receive any packets.
+ if !enabled {
+ n.mu.RUnlock()
+
+ n.stats.DisabledRx.Packets.Increment()
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ return
+ }
+
+ n.stats.Rx.Packets.Increment()
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ n.mu.RUnlock()
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+
+ // If no local link layer address is provided, assume it was sent
+ // directly to this NIC.
+ if local == "" {
+ local = n.linkEP.LinkAddress()
+ }
+
+ // Are any packet sockets listening for this network protocol?
+ packetEPs := n.mu.packetEPs[protocol]
+ // Packet sockets registered for every protocol (EthernetProtocolAll)
+ // should also receive this packet, unless the packet itself arrived
+ // with protocol EthernetProtocolAll, in which case the lookup above
+ // already included them.
+ if protocol != header.EthernetProtocolAll {
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ }
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ }
+
+ if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
+ n.stack.stats.IP.PacketsReceived.Increment()
+ }
+
+ // Parse headers.
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ // The packet is too small to contain a network header.
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ if hasTransportHdr {
+ // Parse the transport header if present.
+ if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
+ state.proto.Parse(pkt)
+ }
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
+
+ if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
+ // The source address is one of our own, so we should never have received
+ // a packet like this unless handleLocal is false. Loopback also calls this
+ // function even though its packets didn't come from a physical interface,
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
+ // Loopback traffic skips the prerouting chain.
+ if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
+ // iptables filtering.
+ ipt := n.stack.IPTables()
+ address := n.primaryAddress(protocol)
+ if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+ }
+
+ if ref := n.getRef(protocol, dst); ref != nil {
+ handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
+ return
+ }
+
+ // This NIC doesn't care about the packet. Find a NIC that does and
+ // forward the packet to it.
+ //
+ // TODO: Should we be forwarding the packet even if promiscuous?
+ if n.stack.Forwarding() {
+ r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
+ if err != nil {
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ // Found a NIC.
+ n := r.ref.nic
+ n.mu.RLock()
+ ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
+ n.mu.RUnlock()
+ if ok {
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
+ r.RemoteAddress = src
+ // TODO(b/123449044): Update the source NIC as well.
+ ref.ep.HandlePacket(&r, pkt)
+ ref.decRef()
+ r.Release()
+ return
+ }
+
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ // TODO(b/128629022): move this logic to route.WritePacket.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
+ // forwarder will release route.
+ return
+ }
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ r.Release()
+ return
+ }
+
+ // The link-address resolution finished immediately.
+ n.forwardPacket(&r, protocol, pkt)
+ r.Release()
+ return
+ }
+
+ // If a packet socket handled the packet, don't treat it as invalid.
+ if len(packetEPs) == 0 {
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ }
+}
+
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
+ // by having lower layers explicitly write each header instead of just
+ // pkt.Header.
+
+ // pkt may have set its NetworkHeader and TransportHeader. If we're
+ // forwarding, we'll have to copy them into pkt.Header.
+ pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
+ if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
+ }
+ if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
+ }
+
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+}
+
+// DeliverTransportPacket delivers the packets to the appropriate transport
+// protocol endpoint.
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
+ state, ok := n.stack.transportProtocols[protocol]
+ if !ok {
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+
+ transProto := state.proto
+
+ // Raw socket packets are delivered based solely on the transport
+ // protocol number. We do not inspect the payload to ensure it's
+ // validly formed.
+ n.stack.demux.deliverRawPacket(r, protocol, pkt)
+
+ // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // from fragments.
+ if pkt.TransportHeader == nil {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ // ICMP packets may be longer, but until icmp.Parse is implemented, we
+ // parse them here using the minimum size.
+ transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
+ if !ok {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ pkt.TransportHeader = transHeader
+ pkt.Data.TrimFront(len(pkt.TransportHeader))
+ } else {
+ // This is either a bad packet or was re-assembled from fragments.
+ transProto.Parse(pkt)
+ }
+ }
+
+ if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
+ if err != nil {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
+ if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
+ return
+ }
+
+ // Try to deliver to per-stack default handler.
+ if state.defaultHandler != nil {
+ if state.defaultHandler(r, id, pkt) {
+ return
+ }
+ }
+
+ // We could not find an appropriate destination for this packet, so
+ // deliver it to the global handler.
+ if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ }
+}
+
+// DeliverTransportControlPacket delivers control packets to the appropriate
+// transport protocol endpoint.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) {
+ state, ok := n.stack.transportProtocols[trans]
+ if !ok {
+ return
+ }
+
+ transProto := state.proto
+
+ // ICMPv4 only guarantees that 8 bytes of the transport protocol will
+ // be present in the payload. We know that the ports are within the
+ // first 8 bytes for all known transport protocols.
+ transHeader, ok := pkt.Data.PullUp(8)
+ if !ok {
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ if err != nil {
+ return
+ }
+
+ id := TransportEndpointID{srcPort, local, dstPort, remote}
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) {
+ return
+ }
+}
+
+// ID returns the identifier of n.
+func (n *NIC) ID() tcpip.NICID {
+ return n.id
+}
+
+// Name returns the name of n.
+func (n *NIC) Name() string {
+ return n.name
+}
+
+// Stack returns the instance of the Stack that owns this NIC.
+func (n *NIC) Stack() *Stack {
+ return n.stack
+}
+
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
+// isAddrTentative returns true if addr is tentative on n.
+//
+// Note that if addr is not associated with n, then this function will return
+// false. It will only return true if the address is associated with the NIC
+// AND it is tentative.
+func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return false
+ }
+
+ return ref.getKind() == permanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
+// duplicate on a link.
+//
+// dupTentativeAddrDetected will remove the tentative address if it exists. If
+// the address was generated via SLAAC, an attempt will be made to generate a
+// new address.
+func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return tcpip.ErrBadAddress
+ }
+
+ if ref.getKind() != permanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
+ // new address will be generated for it.
+ if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := ref.addrWithPrefix().Subnet()
+
+ switch ref.configType {
+ case slaac:
+ n.mu.ndp.regenerateSLAACAddr(prefix)
+ case slaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ }
+
+ return nil
+}
+
+// setNDPConfigs sets the NDP configurations for n.
+//
+ // Note, if c contains invalid NDP configuration values, the erroneous
+ // values will be replaced with defaults.
+func (n *NIC) setNDPConfigs(c NDPConfigurations) {
+ c.validate()
+
+ n.mu.Lock()
+ n.mu.ndp.configs = c
+ n.mu.Unlock()
+}
+
+// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
+func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.handleRA(ip, ra)
+}
+
+type networkEndpointKind int32
+
+const (
+ // A permanentTentative endpoint is a permanent address that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses are of this kind until NDP's
+ // Duplicate Address Detection has resolved; they are deleted if the
+ // process detects a duplicate address.
+ permanentTentative networkEndpointKind = iota
+
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent
+
+ // An expired permanent endpoint is a permanent endpoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no more routes hold a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
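+
+// A sketch of the kind transitions described above:
+//
+//	permanentTentative --(DAD succeeds)-----> permanent
+//	permanentTentative --(DAD fails)--------> removed
+//	permanent ----------(address removed)---> permanentExpired --(refs == 0)--> removed
+//	temporary ----------(added permanently)-> permanent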
+
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.mu.packetEPs[netProto]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+ n.mu.packetEPs[netProto] = append(eps, ep)
+
+ return nil
+}
+
+func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.mu.packetEPs[netProto]
+ if !ok {
+ return
+ }
+
+ for i, epOther := range eps {
+ if epOther == ep {
+ n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ return
+ }
+ }
+}
+
+type networkEndpointConfigType int32
+
+const (
+ // A statically configured endpoint is an address that was added by
+ // some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ static networkEndpointConfigType = iota
+
+ // A SLAAC configured endpoint is an IPv6 endpoint that was added by
+ // SLAAC as per RFC 4862 section 5.5.3.
+ slaac
+
+ // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by
+ // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are
+ // not expected to be valid (or preferred) forever; hence the term temporary.
+ slaacTemp
+)
+
+type referencedNetworkEndpoint struct {
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkCache is set if link address resolution is enabled for this
+ // protocol. Set to nil otherwise.
+ linkCache LinkAddressCache
+
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ // kind must only be accessed using {get,set}Kind().
+ kind networkEndpointKind
+
+ // configType is the method that was used to configure this endpoint.
+ // This must never change except during endpoint creation and promotion to
+ // permanent.
+ configType networkEndpointConfigType
+
+ // deprecated indicates whether or not the endpoint should be considered
+ // deprecated. That is, when deprecated is true, other endpoints that are not
+ // deprecated should be preferred.
+ deprecated bool
+}
+
+func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
+ return tcpip.AddressWithPrefix{
+ Address: r.ep.ID().LocalAddress,
+ PrefixLen: r.ep.PrefixLen(),
+ }
+}
+
+func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
+ return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
+}
+
+func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
+ atomic.StoreInt32((*int32)(&r.kind), int32(kind))
+}
+
+ // isValidForOutgoing returns true if the endpoint can be used to send out a
+ // packet. It requires the endpoint to not be marked expired (i.e., its
+ // address has been removed), unless the NIC is in spoofing mode or the
+ // endpoint is temporary.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ r.nic.mu.RLock()
+ defer r.nic.mu.RUnlock()
+
+ return r.isValidForOutgoingRLocked()
+}
+
+// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
+// r.nic.mu to be read locked.
+func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
+ if !r.nic.mu.enabled {
+ return false
+ }
+
+ return r.isAssignedRLocked(r.nic.mu.spoofing)
+}
+
+// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
+//
+// r.nic.mu must be read locked.
+func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
+ switch r.getKind() {
+ case permanentTentative:
+ return false
+ case permanentExpired:
+ return spoofingOrPromiscuous
+ default:
+ return true
+ }
+}
+
+// expireLocked decrements the reference count and marks the permanent endpoint
+// as expired.
+func (r *referencedNetworkEndpoint) expireLocked() {
+ r.setKind(permanentExpired)
+ r.decRefLocked()
+}
+
+// decRef decrements the ref count and cleans up the endpoint once it reaches
+// zero.
+func (r *referencedNetworkEndpoint) decRef() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpoint(r)
+ }
+}
+
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
+// incRef increments the ref count. It must only be called when the caller is
+// known to be holding a reference to the endpoint, otherwise tryIncRef should
+// be used.
+func (r *referencedNetworkEndpoint) incRef() {
+ atomic.AddInt32(&r.refs, 1)
+}
+
+// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
+// not zero. That is, it will increment the count if the endpoint is still
+ // alive, and do nothing if it has already been cleaned up.
+func (r *referencedNetworkEndpoint) tryIncRef() bool {
+ for {
+ v := atomic.LoadInt32(&r.refs)
+ if v == 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
+ return true
+ }
+ }
+}
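+
+// An illustrative (hypothetical) use of tryIncRef by a caller that must not
+// resurrect an endpoint whose reference count has already reached zero:
+//
+//	if !ref.tryIncRef() {
+//		return nil // endpoint is being removed; treat as not found
+//	}
+//	defer ref.decRef()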
+
+// stack returns the Stack instance that owns the underlying endpoint.
+func (r *referencedNetworkEndpoint) stack() *Stack {
+ return r.nic.stack
+}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
new file mode 100644
index 000000000..31f865260
--- /dev/null
+++ b/pkg/tcpip/stack/nic_test.go
@@ -0,0 +1,318 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+
+// A LinkEndpoint that throws away outgoing packets.
+//
+// We use this instead of the channel endpoint as the channel package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (e *testLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*testLinkEndpoint) MTU() uint32 {
+ return math.MaxUint16
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return CapabilityResolutionRequired
+}
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*testLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements LinkEndpoint.Wait.
+func (*testLinkEndpoint) Wait() {}
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
+//
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Endpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
+}
+
+// DefaultTTL implements NetworkEndpoint.DefaultTTL.
+func (*testIPv6Endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+// MTU implements NetworkEndpoint.MTU.
+func (e *testIPv6Endpoint) MTU() uint32 {
+ return e.linkEP.MTU() - header.IPv6MinimumSize
+}
+
+// Capabilities implements NetworkEndpoint.Capabilities.
+func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
+func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket implements NetworkEndpoint.WritePacket.
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements NetworkEndpoint.WritePackets.
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteHeaderIncludedPacket implements
+// NetworkEndpoint.WriteHeaderIncludedPacket.
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ID implements NetworkEndpoint.ID.
+func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen implements NetworkEndpoint.PrefixLen.
+func (e *testIPv6Endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// NICID implements NetworkEndpoint.NICID.
+func (e *testIPv6Endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// HandlePacket implements NetworkEndpoint.HandlePacket.
+func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
+}
+
+// Close implements NetworkEndpoint.Close.
+func (*testIPv6Endpoint) Close() {}
+
+// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
+func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+
+// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
+// believe it supports IPv6.
+//
+// We use this instead of ipv6.protocol because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Protocol struct{}
+
+// Number implements NetworkProtocol.Number.
+func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
+func (*testIPv6Protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
+func (*testIPv6Protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint implements NetworkProtocol.NewEndpoint.
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &testIPv6Endpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Option implements NetworkProtocol.Option.
+func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
+func (*testIPv6Protocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*testIPv6Protocol) Wait() {}
+
+// Parse implements NetworkProtocol.Parse.
+func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ return 0, false, false
+}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+ return nil
+}
+
+// ResolveStaticAddress implements LinkAddressResolver.
+func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return "", false
+}
+
+// Test the race condition where a NIC is removed and an RS timer fires at the
+// same time.
+func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
+ const (
+ nicID = 1
+
+ maxRtrSolicitations = 5
+ )
+
+ e := testLinkEndpoint{}
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
+ NDPConfigs: NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: minimumRtrSolicitationInterval,
+ },
+ })
+
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.mu.Lock()
+ // Wait for the router solicitation timer to fire and block trying to obtain
+ // the stack lock when doing link address resolution.
+ time.Sleep(minimumRtrSolicitationInterval * 2)
+ if err := s.removeNICLocked(nicID); err != nil {
+ t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
+ }
+ s.mu.Unlock()
+}
+
+func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
+ // When the NIC is disabled, the only field that matters is the stats field.
+ // This test is limited to stats counter checks.
+ nic := NIC{
+ stats: makeNICStats(),
+ }
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
new file mode 100644
index 000000000..1b5da6017
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -0,0 +1,115 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// A PacketBuffer contains all the data of a network packet.
+//
+// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
+// multiple endpoints. Clone() should be called in such cases so that
+// modifications to the Data field do not affect other copies.
+type PacketBuffer struct {
+ _ noCopy
+
+ // PacketBufferEntry is used to build an intrusive list of
+ // PacketBuffers.
+ PacketBufferEntry
+
+ // Data holds the payload of the packet. For inbound packets, it also
+ // holds the headers, which are consumed as the packet moves up the
+ // stack. Headers are guaranteed not to be split across views.
+ //
+ // The bytes backing Data are immutable, but Data itself may be trimmed
+ // or otherwise modified.
+ Data buffer.VectorisedView
+
+ // Header holds the headers of outbound packets. As a packet is passed
+ // down the stack, each layer adds to Header. Note that forwarded
+ // packets don't populate Headers on their way out -- their headers and
+ // payload are never parsed out and remain in Data.
+ //
+ // TODO(gvisor.dev/issue/170): Forwarded packets don't currently
+ // populate Header, but should. This will be doable once early parsing
+ // (https://github.com/google/gvisor/pull/1995) is supported.
+ Header buffer.Prependable
+
+ // These fields are used by both inbound and outbound packets. They
+ // typically overlap with the Data and Header fields.
+ //
+ // The bytes backing these views are immutable. Each field may be nil
+ // if either it has not been set yet or no such header exists (e.g.
+ // packets sent via loopback may not have a link header).
+ //
+ // These fields may be Views into other slices (either Data or Header).
+ // SR doesn't support this, so deep copies are necessary in some cases.
+ LinkHeader buffer.View
+ NetworkHeader buffer.View
+ TransportHeader buffer.View
+
+ // Hash is the transport layer hash of this packet. A value of zero
+ // indicates no valid hash has been set.
+ Hash uint32
+
+ // Owner is implemented by the task to get the uid and gid.
+ // Only set for locally generated packets.
+ Owner tcpip.PacketOwner
+
+ // The following fields are only set by the qdisc layer when the packet
+ // is added to a queue.
+ EgressRoute *Route
+ GSOOptions *GSO
+ NetworkProtocolNumber tcpip.NetworkProtocolNumber
+
+ // NatDone indicates if the packet has been manipulated as per NAT
+ // iptables rule.
+ NatDone bool
+}
+
+// Clone makes a copy of pk. It clones the Data field, which creates a new
+// VectorisedView but does not deep copy the underlying bytes.
+//
+// Clone also does not deep copy any of its other fields.
+//
+// FIXME(b/153685824): Data gets copied but not other header references.
+func (pk *PacketBuffer) Clone() *PacketBuffer {
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ Header: pk.Header,
+ LinkHeader: pk.LinkHeader,
+ NetworkHeader: pk.NetworkHeader,
+ TransportHeader: pk.TransportHeader,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ }
+}
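+
+// An illustrative sketch (not part of the API): when one inbound packet must
+// be handed to several endpoints, each delivery can get its own clone so that
+// trimming Data in one endpoint does not affect the others. eps, r and id are
+// assumed to be in scope:
+//
+//	for _, ep := range eps {
+//		ep.HandlePacket(r, id, pkt.Clone())
+//	}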
+
+// noCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://golang.org/issues/8005#issuecomment-190753527
+// for details.
+type noCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*noCopy) Lock() {}
+func (*noCopy) Unlock() {}
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
new file mode 100644
index 000000000..421fb5c15
--- /dev/null
+++ b/pkg/tcpip/stack/rand.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ mathrand "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// lockedRandomSource provides a threadsafe rand.Source.
+type lockedRandomSource struct {
+ mu sync.Mutex
+ src mathrand.Source
+}
+
+func (r *lockedRandomSource) Int63() (n int64) {
+ r.mu.Lock()
+ n = r.src.Int63()
+ r.mu.Unlock()
+ return n
+}
+
+func (r *lockedRandomSource) Seed(seed int64) {
+ r.mu.Lock()
+ r.src.Seed(seed)
+ r.mu.Unlock()
+}
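+
+// Usage sketch (illustrative; seed is an assumed caller-provided value): a
+// lockedRandomSource is intended to be wrapped in a *rand.Rand so the stack
+// can draw random numbers from multiple goroutines:
+//
+//	rng := mathrand.New(&lockedRandomSource{src: mathrand.NewSource(seed)})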
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
new file mode 100644
index 000000000..5cbc946b6
--- /dev/null
+++ b/pkg/tcpip/stack/registration.go
@@ -0,0 +1,560 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// NetworkEndpointID is the identifier of a network layer protocol endpoint.
+// Currently the local address is sufficient because all supported protocols
+// (i.e., IPv4 and IPv6) have different sizes for their addresses.
+type NetworkEndpointID struct {
+ LocalAddress tcpip.Address
+}
+
+// TransportEndpointID is the identifier of a transport layer protocol endpoint.
+//
+// +stateify savable
+type TransportEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress is the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// ControlType is the type of network control message.
+type ControlType int
+
+// The following are the allowed values for ControlType values.
+const (
+ ControlPacketTooBig ControlType = iota
+ ControlPortUnreachable
+ ControlUnknown
+)
+
+// TransportEndpoint is the interface that needs to be implemented by transport
+// protocol (e.g., tcp, udp) endpoints that can handle packets.
+type TransportEndpoint interface {
+ // UniqueID returns a unique ID for this transport endpoint.
+ UniqueID() uint64
+
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint. It sets pkt.TransportHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
+
+ // HandleControlPacket is called by the stack when new control (e.g.
+ // ICMP) packets arrive to this transport endpoint.
+ // HandleControlPacket takes ownership of pkt.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
+
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // An endpoint can be requested to stop its worker goroutines by calling
+ // its Close method.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
+}
+
+// RawTransportEndpoint is the interface that needs to be implemented by raw
+// transport protocol endpoints. RawTransportEndpoints receive the entire
+// packet - including the network and transport headers - as delivered to
+// netstack.
+type RawTransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint. The packet contains all data from the link
+ // layer up.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt *PacketBuffer)
+}
+
+// PacketEndpoint is the interface that needs to be implemented by packet
+// transport protocol endpoints. These endpoints receive link layer headers in
+// addition to whatever they contain (usually network and transport layer
+// headers and a payload).
+type PacketEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive that
+ // match the endpoint.
+ //
+ // Implementers should treat the packet as immutable and should copy it
+ // before modification.
+ //
+ // pkt.LinkHeader may have a length of 0, in which case the PacketEndpoint
+ // should construct its own ethernet header for applications.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+}
+
+// TransportProtocol is the interface that needs to be implemented by transport
+// protocols (e.g., tcp, udp) that want to be part of the networking stack.
+type TransportProtocol interface {
+ // Number returns the transport protocol number.
+ Number() tcpip.TransportProtocolNumber
+
+ // NewEndpoint creates a new endpoint of the transport protocol.
+ NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewRawEndpoint creates a new raw endpoint of the transport protocol.
+ NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // transport protocol. The stack automatically drops any packet smaller
+ // than this that is targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParsePorts returns the source and destination ports stored in a
+ // packet of this protocol.
+ ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+
+ // HandleUnknownDestinationPacket handles packets targeted at this
+ // protocol but that don't match any existing endpoint. For example,
+ // a packet targeted at a port that has no listeners.
+ //
+ // The return value indicates whether the packet was well-formed (for
+ // stats purposes only).
+ //
+ // HandleUnknownDestinationPacket takes ownership of pkt.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
+
+ // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
+ // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
+ // MinimumPacketSize().
+ Parse(pkt *PacketBuffer) (ok bool)
+}
+
+// TransportDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate transport endpoint after it has been handled by
+// the network layer.
+type TransportDispatcher interface {
+ // DeliverTransportPacket delivers packets to the appropriate
+ // transport protocol endpoint.
+ //
+ // pkt.NetworkHeader must be set before calling DeliverTransportPacket.
+ //
+ // DeliverTransportPacket takes ownership of pkt.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverTransportControlPacket delivers control packets to the
+ // appropriate transport protocol endpoint.
+ //
+ // pkt.NetworkHeader must be set before calling
+ // DeliverTransportControlPacket.
+ //
+ // DeliverTransportControlPacket takes ownership of pkt.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer)
+}
+
+// PacketLooping specifies where an outbound packet should be sent.
+type PacketLooping byte
+
+const (
+ // PacketOut indicates that the packet should be passed to the link
+ // endpoint.
+ PacketOut PacketLooping = 1 << iota
+
+ // PacketLoop indicates that the packet should be handled locally.
+ PacketLoop
+)
+
+// NetworkHeaderParams are the header parameters given as input by the
+// transport endpoint to the network.
+type NetworkHeaderParams struct {
+ // Protocol refers to the transport protocol number.
+ Protocol tcpip.TransportProtocolNumber
+
+ // TTL refers to the Time To Live field of the IP header.
+ TTL uint8
+
+ // TOS refers to the TypeOfService or TrafficClass field of the IP header.
+ TOS uint8
+}
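+
+// An illustrative sketch (assuming a TCP segment, and that r, gso and pkt are
+// in scope): a transport endpoint might fill these parameters as follows when
+// handing a packet to the network layer:
+//
+//	params := NetworkHeaderParams{
+//		Protocol: header.TCPProtocolNumber,
+//		TTL:      r.DefaultTTL(),
+//		TOS:      DefaultTOS,
+//	}
+//	err := r.WritePacket(gso, params, pkt)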
+
+// NetworkEndpoint is the interface that needs to be implemented by endpoints
+// of network layer protocols (e.g., ipv4, ipv6).
+type NetworkEndpoint interface {
+ // DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
+ // for this endpoint.
+ DefaultTTL() uint8
+
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // generally calculated as the MTU of the underlying data link endpoint
+ // minus the network endpoint max header length.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // underlying link-layer endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the network (and lower
+ // level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // WritePacket writes a packet to the given destination address and
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have already
+ // been set.
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
+
+ // WritePackets writes packets to the given destination address and
+ // protocol. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
+
+ // WriteHeaderIncludedPacket writes a packet that includes a network
+ // header to the given destination address. It takes ownership of pkt.
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
+
+ // ID returns the network protocol endpoint ID.
+ ID() *NetworkEndpointID
+
+ // PrefixLen returns the network endpoint's subnet prefix length in bits.
+ PrefixLen() int
+
+ // NICID returns the id of the NIC this endpoint belongs to.
+ NICID() tcpip.NICID
+
+ // HandlePacket is called by the link layer when new packets arrive to
+ // this network endpoint. It sets pkt.NetworkHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt *PacketBuffer)
+
+ // Close is called when the endpoint is removed from a stack.
+ Close()
+
+ // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for
+ // this endpoint.
+ NetworkProtocolNumber() tcpip.NetworkProtocolNumber
+}
+
+// NetworkProtocol is the interface that needs to be implemented by network
+// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
+type NetworkProtocol interface {
+ // Number returns the network protocol number.
+ Number() tcpip.NetworkProtocolNumber
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // network protocol. The stack automatically drops any packet smaller
+ // than this that is targeted at this protocol.
+ MinimumPacketSize() int
+
+ // DefaultPrefixLen returns the protocol's default prefix length.
+ DefaultPrefixLen() int
+
+ // ParseAddresses returns the source and destination addresses stored in a
+ // packet of this protocol.
+ ParseAddresses(v buffer.View) (src, dst tcpip.Address)
+
+ // NewEndpoint creates a new endpoint of this protocol.
+ NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error)
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
+
+ // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
+ // returns:
+ // - The encapsulated protocol, if present.
+ // - Whether there is an encapsulated transport protocol payload (e.g. ARP
+ // does not encapsulate anything).
+ // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
+ Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
+}
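+
+// An illustrative sketch of consuming Parse's three results (netProto is
+// assumed in scope; the dispatch shown is an assumption, not the stack's
+// actual code path):
+//
+//	transProto, hasTransportHdr, ok := netProto.Parse(pkt)
+//	if !ok {
+//		return // pkt.Data was too small to parse; drop the packet.
+//	}
+//	if hasTransportHdr {
+//		// transProto identifies the encapsulated transport payload.
+//	}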
+
+// NetworkDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate network endpoint after it has been handled by
+// the data link layer.
+type NetworkDispatcher interface {
+ // DeliverNetworkPacket finds the appropriate network protocol endpoint
+ // and hands the packet over for further processing.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverNetworkPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverNetworkPacket takes ownership of pkt.
+ DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+}
+
+// LinkEndpointCapabilities is the type associated with the capabilities
+// supported by a link-layer endpoint. It is a set of bitfields.
+type LinkEndpointCapabilities uint
+
+// The following are the supported link endpoint capabilities.
+const (
+ CapabilityNone LinkEndpointCapabilities = 0
+ // CapabilityTXChecksumOffload indicates that the link endpoint supports
+ // checksum computation for outgoing packets and the stack can skip
+ // computing checksums when sending packets.
+ CapabilityTXChecksumOffload LinkEndpointCapabilities = 1 << iota
+ // CapabilityRXChecksumOffload indicates that the link endpoint supports
+ // checksum verification on received packets and that it's safe for the
+ // stack to skip checksum verification.
+ CapabilityRXChecksumOffload
+ CapabilityResolutionRequired
+ CapabilitySaveRestore
+ CapabilityDisconnectOk
+ CapabilityLoopback
+ CapabilityHardwareGSO
+
+ // CapabilitySoftwareGSO indicates that the link endpoint supports sending
+ // multiple packets using a single call (LinkEndpoint.WritePackets).
+ CapabilitySoftwareGSO
+)
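+
+// Because LinkEndpointCapabilities is a set of bitfields, individual
+// capabilities are tested with a mask, e.g. (illustrative):
+//
+//	if r.Capabilities()&CapabilityRXChecksumOffload != 0 {
+//		// The link layer verified the checksum; skip re-verifying it.
+//	}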
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // usually dictated by the backing physical network; when such a
+ // physical network doesn't exist, the limit is generally 64k, which
+ // includes the maximum size of an IP packet.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the data link (and
+ // lower level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // LinkAddress returns the link address (typically a MAC) of the
+ // link endpoint.
+ LinkAddress() tcpip.LinkAddress
+
+ // WritePacket writes a packet with the given protocol through the
+ // given route. It takes ownership of pkt. pkt.NetworkHeader and
+ // pkt.TransportHeader must have already been set.
+ //
+ // To participate in transparent bridging, a LinkEndpoint implementation
+ // should call eth.Encode with header.EthernetFields.SrcAddr set to
+ // r.LocalLinkAddress if it is provided.
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error
+
+ // WritePackets writes packets with the given protocol through the
+ // given route. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
+ //
+ // Right now, WritePackets is used only when the software segmentation
+ // offload is enabled. If it is ever used for anything else, it may
+ // require changes to the syscall filters.
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link. The packet
+ // should already have an ethernet header. It takes ownership of vv.
+ WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
+
+ // Attach attaches the data link layer endpoint to the network-layer
+ // dispatcher of the stack.
+ //
+ // Attach will be called with a nil dispatcher if the receiver's associated
+ // NIC is being removed.
+ Attach(dispatcher NetworkDispatcher)
+
+ // IsAttached returns whether a NetworkDispatcher is attached to the
+ // endpoint.
+ IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
+}
+
+// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
+// delivered via the Inject method.
+type InjectableLinkEndpoint interface {
+ LinkEndpoint
+
+ // InjectInbound injects an inbound packet.
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // InjectOutbound writes a fully formed outbound packet directly to the
+ // link.
+ //
+ // dest is used by endpoints with multiple raw destinations.
+ InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
+}
+
+// A LinkAddressResolver is an extension to a NetworkProtocol that
+// can resolve link addresses.
+type LinkAddressResolver interface {
+ // LinkAddressRequest sends a request for the LinkAddress of addr.
+ // The request is sent on linkEP with localAddr as the source.
+ //
+ // A valid response will cause the discovery protocol's network
+ // endpoint to call AddLinkAddress.
+ LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+
+ // ResolveStaticAddress attempts to resolve the address without sending
+ // requests. It either resolves the name immediately or returns the
+ // empty LinkAddress.
+ //
+ // It can be used to resolve broadcast addresses for example.
+ ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
+
+ // LinkAddressProtocol returns the network protocol of the
+ // addresses this resolver can resolve.
+ LinkAddressProtocol() tcpip.NetworkProtocolNumber
+}
+
+// A LinkAddressCache caches link addresses.
+type LinkAddressCache interface {
+ // CheckLocalAddress determines if the given local address exists on any
+ // NIC, returning the ID of the owning NIC, or 0 if it does not exist.
+ CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+
+ // AddLinkAddress adds a link address to the cache.
+ AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+
+ // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
+ // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
+ // registered with the network protocol, the cache attempts to resolve the address
+ // and returns ErrWouldBlock. The waker is notified when address resolution
+ // is complete (success or not).
+ //
+ // If address resolution is required, ErrNoLinkAddress and a notification channel are
+ // returned for the top-level caller to block on. The channel is closed once address
+ // resolution is complete (success or not).
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
+
+ // RemoveWaker removes a waker that has been added in GetLinkAddress().
+ RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+}
+
+// RawFactory produces endpoints for writing various types of raw packets.
+type RawFactory interface {
+ // NewUnassociatedEndpoint produces endpoints for writing packets not
+ // associated with a particular transport protocol. Such endpoints can
+ // be used to write arbitrary packets that include the network header.
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewPacketEndpoint produces endpoints for reading and writing packets
+ // that include network and (when cooked is false) link layer headers.
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+}
+
+// GSOType is the type of GSO segments.
+//
+// +stateify savable
+type GSOType int
+
+// Types of GSO segments.
+const (
+ GSONone GSOType = iota
+
+ // Hardware GSO types:
+ GSOTCPv4
+ GSOTCPv6
+
+ // GSOSW is used for software GSO segments which have to be sent by
+ // endpoint.WritePackets.
+ GSOSW
+)
+
+// GSO contains generic segmentation offload properties.
+//
+// +stateify savable
+type GSO struct {
+ // Type is one of GSONone, GSOTCPv4, etc.
+ Type GSOType
+ // NeedsCsum is set if the checksum offload is enabled.
+ NeedsCsum bool
+ // CsumOffset is the offset after the L3 header at which to place the checksum.
+ CsumOffset uint16
+
+ // MSS is the maximum segment size.
+ MSS uint16
+ // L3HdrLen is the L3 (IP) header length.
+ L3HdrLen uint16
+
+ // MaxSize is the maximum GSO packet size.
+ MaxSize uint32
+}
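+
+// An illustrative sketch of filling GSO for a TCP/IPv4 segment. The literal
+// 16 is the checksum field's offset within a TCP header; mss and maxSize are
+// assumed caller-provided values:
+//
+//	gso := &GSO{
+//		Type:       GSOTCPv4,
+//		NeedsCsum:  true,
+//		CsumOffset: 16,
+//		MSS:        mss,
+//		L3HdrLen:   header.IPv4MinimumSize,
+//		MaxSize:    maxSize,
+//	}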
+
+// GSOEndpoint provides access to GSO properties.
+type GSOEndpoint interface {
+ // GSOMaxSize returns the maximum GSO packet size.
+ GSOMaxSize() uint32
+}
+
+// SoftwareGSOMaxSize is the maximum allowed size of a software GSO segment.
+// This isn't a hard limit, because it is never set into packet headers.
+const SoftwareGSOMaxSize = (1 << 16)
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
new file mode 100644
index 000000000..d65f8049e
--- /dev/null
+++ b/pkg/tcpip/stack/route.go
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Route represents a route through the networking stack to a given destination.
+type Route struct {
+ // RemoteAddress is the final destination of the route.
+ RemoteAddress tcpip.Address
+
+ // RemoteLinkAddress is the link-layer (MAC) address of the
+ // final destination of the route.
+ RemoteLinkAddress tcpip.LinkAddress
+
+ // LocalAddress is the local address where the route starts.
+ LocalAddress tcpip.Address
+
+ // LocalLinkAddress is the link-layer (MAC) address of the node
+ // where the route starts.
+ LocalLinkAddress tcpip.LinkAddress
+
+ // NextHop is the next node in the path to the destination.
+ NextHop tcpip.Address
+
+ // NetProto is the network-layer protocol.
+ NetProto tcpip.NetworkProtocolNumber
+
+ // ref is a reference to the network endpoint through which the route
+ // starts.
+ ref *referencedNetworkEndpoint
+
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
+}
+
+// makeRoute initializes a new route. It takes ownership of the provided
+// reference to a network endpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route {
+ loop := PacketOut
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
+ }
+
+ return Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: localLinkAddr,
+ RemoteAddress: remoteAddr,
+ ref: ref,
+ Loop: loop,
+ }
+}
+
+// NICID returns the id of the NIC from which this route originates.
+func (r *Route) NICID() tcpip.NICID {
+ return r.ref.ep.NICID()
+}
+
+// MaxHeaderLength forwards the call to the network endpoint's implementation.
+func (r *Route) MaxHeaderLength() uint16 {
+ return r.ref.ep.MaxHeaderLength()
+}
+
+// Stats returns a mutable copy of current stats.
+func (r *Route) Stats() tcpip.Stats {
+ return r.ref.nic.stack.Stats()
+}
+
+// PseudoHeaderChecksum computes the pseudo-header checksum using the route's
+// local and remote addresses.
+func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, totalLen uint16) uint16 {
+ return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
+}
+
+// Capabilities returns the link-layer capabilities of the route.
+func (r *Route) Capabilities() LinkEndpointCapabilities {
+ return r.ref.ep.Capabilities()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (r *Route) GSOMaxSize() uint32 {
+ if gso, ok := r.ref.ep.(GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// Resolve attempts to resolve the link address if necessary. It returns
+// ErrWouldBlock if address resolution requires blocking, e.g. waiting for an
+// ARP reply. The waker is notified when address resolution is complete
+// (success or not).
+//
+// If address resolution is required, ErrNoLinkAddress and a notification channel are
+// returned for the top-level caller to block on. The channel is closed once address
+// resolution is complete (success or not).
+//
+// The NIC r uses must not be locked.
+func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
+ if !r.IsResolutionRequired() {
+ // Nothing to do if there is no cache (which does the resolution on cache miss) or
+ // the link address is already known.
+ return nil, nil
+ }
+
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ // Local link address is already known.
+ if r.RemoteAddress == r.LocalAddress {
+ r.RemoteLinkAddress = r.LocalLinkAddress
+ return nil, nil
+ }
+ nextAddr = r.RemoteAddress
+ }
+ linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ if err != nil {
+ return ch, err
+ }
+ r.RemoteLinkAddress = linkAddr
+ return nil, nil
+}
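+
+// An illustrative caller-side sketch (waker handling elided; passing a nil
+// waker and blocking on the channel instead is assumed to be acceptable):
+//
+//	if ch, err := r.Resolve(nil); err == tcpip.ErrWouldBlock {
+//		<-ch // Blocks until resolution completes (success or not).
+//	} else if err != nil {
+//		return err
+//	}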
+
+// RemoveWaker removes a waker that has been added in Resolve().
+func (r *Route) RemoveWaker(waker *sleep.Waker) {
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+ r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+}
+
+// IsResolutionRequired returns true if Resolve() must be called to resolve
+// the link address before this route can be written to.
+//
+// The NIC r uses must not be locked.
+func (r *Route) IsResolutionRequired() bool {
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+}
+
+// WritePacket writes the packet through the given route.
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // WritePacket takes ownership of pkt, so calculate numBytes first.
+ numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+
+ err := r.ref.ep.WritePacket(r, gso, params, pkt)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ } else {
+ r.ref.nic.stats.Tx.Packets.Increment()
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ }
+ return err
+}
+
+// WritePackets writes a list of n packets through the given route and returns
+// the number of packets written.
+func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+ if !r.ref.isValidForOutgoing() {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ // WritePackets takes ownership of pkts, so calculate the length first.
+ numPkts := pkts.Len()
+
+ n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
+ }
+ r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
+
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Header.UsedLength()
+ writtenBytes += pb.Data.Size()
+ }
+
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return n, err
+}
+
+// WriteHeaderIncludedPacket writes a packet already containing a network
+// header through the given route.
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // WriteHeaderIncludedPacket takes ownership of pkt, so calculate numBytes first.
+ numBytes := pkt.Data.Size()
+
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+ r.ref.nic.stats.Tx.Packets.Increment()
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
+}
+
+// DefaultTTL returns the default TTL of the underlying network endpoint.
+func (r *Route) DefaultTTL() uint8 {
+ return r.ref.ep.DefaultTTL()
+}
+
+// MTU returns the MTU of the underlying network endpoint.
+func (r *Route) MTU() uint32 {
+ return r.ref.ep.MTU()
+}
+
+// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
+// network endpoint.
+func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return r.ref.ep.NetworkProtocolNumber()
+}
+
+// Release frees all resources associated with the route.
+func (r *Route) Release() {
+ if r.ref != nil {
+ r.ref.decRef()
+ r.ref = nil
+ }
+}
+
+// Clone clones a route such that the original one can be released and the new
+// one will remain valid.
+func (r *Route) Clone() Route {
+ if r.ref != nil {
+ r.ref.incRef()
+ }
+ return *r
+}
+
+// MakeLoopedRoute duplicates the given route with special handling for routes
+// used for sending multicast or broadcast packets. In those cases the
+// multicast/broadcast address is the remote address when sending out, but for
+// incoming (looped) packets it becomes the local address. Similarly, the local
+// interface address that was the local address going out becomes the remote
+// address coming in. This is different to unicast routes where local and
+// remote addresses remain the same as they identify location (local vs remote)
+// not direction (source vs destination).
+func (r *Route) MakeLoopedRoute() Route {
+ l := r.Clone()
+ if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
+ l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
+ l.RemoteLinkAddress = l.LocalLinkAddress
+ }
+ return l
+}
+
+// Stack returns the instance of the Stack that owns this route.
+func (r *Route) Stack() *Stack {
+ return r.ref.stack()
+}
+
+// ReverseRoute returns a new route with the given source and destination addresses.
+func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
+ return Route{
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ ref: r.ref,
+ Loop: r.Loop,
+ }
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
new file mode 100644
index 000000000..cdcfb8321
--- /dev/null
+++ b/pkg/tcpip/stack/stack.go
@@ -0,0 +1,1938 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package stack provides the glue between networking protocols and the
+// consumers of the networking stack.
+//
+// For consumers, the only function of interest is New(); everything else is
+// provided by the tcpip/public package.
+package stack
+
+import (
+ "bytes"
+ "encoding/binary"
+ mathrand "math/rand"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // ageLimit is set to the same cache stale time used in Linux.
+ ageLimit = 1 * time.Minute
+ // resolutionTimeout is set to the same ARP timeout used in Linux.
+ resolutionTimeout = 1 * time.Second
+ // resolutionAttempts is set to the same ARP retries used in Linux.
+ resolutionAttempts = 3
+
+ // DefaultTOS is the default type of service value for network endpoints.
+ DefaultTOS = 0
+)
+
+type transportProtocolState struct {
+ proto TransportProtocol
+ defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+}
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+type TCPCubicState struct {
+ WLastMax float64
+ WMax float64
+ T time.Time
+ TimeSinceLastCongestion time.Duration
+ C float64
+ K float64
+ Beta float64
+ WC float64
+ WEst float64
+}
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value the congestion window may grow to while
+ // in recovery. It is set at the time we enter recovery.
+ MaxCwnd int
+
+ // HighRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ HighRxt seqnum.Value
+
+ // RescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ RescueRxt seqnum.Value
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for
+// a given TCP endpoint.
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is the TCP variable RCV.ACC.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive
+ // queue.
+ PendingBufUsed seqnum.Size
+
+ // PendingBufSize is the size of the socket receive buffer.
+ PendingBufSize seqnum.Size
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for
+// a given TCP Endpoint.
+type TCPSenderState struct {
+ // LastSendTime is the time at which we sent the last segment.
+ LastSendTime time.Time
+
+ // DupAckCount is the number of duplicate ACKs received.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the slow start threshold in packets.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets consumed in congestion
+ // avoidance mode.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets in flight.
+ Outstanding int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time
+
+ // Closed indicates that the caller has closed the endpoint for sending.
+ Closed bool
+
+ // SRTT is the smoothed round-trip time as defined in section 2 of
+ // RFC 6298.
+ SRTT time.Duration
+
+ // RTO is the retransmit timeout as defined in section 2 of RFC 6298.
+ RTO time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates that a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+
+ // MaxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent so far.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK Blocks that identify the out of order segments
+ // held by a given TCP endpoint.
+ Blocks []header.SACKBlock
+
+ // ReceivedBlocks are the SACK blocks received by this endpoint
+ // from the peer endpoint.
+ ReceivedBlocks []header.SACKBlock
+
+ // MaxSACKED is the highest sequence number that has been SACKED
+ // by the peer.
+ MaxSACKED seqnum.Value
+}
+
+// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
+type RcvBufAutoTuneParams struct {
+ // MeasureTime is the time at which the current measurement
+ // was started.
+ MeasureTime time.Time
+
+ // CopiedBytes is the number of bytes copied to userspace since
+ // this measurement began.
+ CopiedBytes int
+
+ // PrevCopiedBytes is the number of bytes copied to userspace in
+ // the previous RTT period.
+ PrevCopiedBytes int
+
+ // RcvBufSize is the auto tuned receive buffer size.
+ RcvBufSize int
+
+ // RTT is the smoothed RTT as measured by observing the time between
+ // when a byte is first acknowledged and the receipt of data that is at
+ // least one window beyond the sequence number that was acknowledged.
+ RTT time.Duration
+
+ // RTTVar is the "round-trip time variation" as defined in section 2
+ // of RFC6298.
+ RTTVar time.Duration
+
+ // RTTMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ RTTMeasureSeqNumber seqnum.Value
+
+ // RTTMeasureTime is the absolute time at which the current RTT
+ // measurement period began.
+ RTTMeasureTime time.Time
+
+ // Disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ Disabled bool
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+type TCPEndpointState struct {
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time
+
+ // RcvBufSize is the size of the receive socket buffer for the endpoint.
+ RcvBufSize int
+
+ // RcvBufUsed is the number of bytes actually held in the receive socket
+ // buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvBufAutoTuneParams is used to hold state variables to compute
+ // the auto tuned receive buffer size.
+ RcvAutoParams RcvBufAutoTuneParams
+
+ // RcvClosed if true, indicates the endpoint has been closed for reading.
+ RcvClosed bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When SendTSOk is true, every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ RecentTS uint32
+
+ // TSOffset is a randomized offset added to the value of the TSVal field
+ // in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+
+ // Receiver holds variables related to the TCP receiver for the endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
+
+// ResumableEndpoint is an endpoint that needs to be resumed after restore.
+type ResumableEndpoint interface {
+ // Resume resumes an endpoint after restore. This can be used to restart
+ // background workers such as protocol goroutines. This must be called after
+ // all indirect dependencies of the endpoint have been restored, which
+ // generally means at the end of the restore process.
+ Resume(*Stack)
+}
+
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it will be passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on different NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface identifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
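+
+// An illustrative sketch of constructing these options; the naming scheme and
+// key generation below are assumptions, not requirements:
+//
+//	secret := make([]byte, header.OpaqueIIDSecretKeyMinBytes)
+//	if _, err := rand.Read(secret); err != nil {
+//		panic(err)
+//	}
+//	opts := OpaqueInterfaceIdentifierOptions{
+//		NICNameFromID: func(id tcpip.NICID, _ string) string {
+//			return fmt.Sprintf("nic%d", id)
+//		},
+//		SecretKey: secret,
+//	}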
+
+// Stack is a networking stack, with all supported protocols, NICs, and route
+// table.
+type Stack struct {
+ transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
+ networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+
+ // rawFactory creates raw endpoints. If nil, raw endpoints are
+ // disabled. It is set during Stack creation and is immutable.
+ rawFactory RawFactory
+
+ demux *transportDemuxer
+
+ stats tcpip.Stats
+
+ linkAddrCache *linkAddrCache
+
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+ cleanupEndpoints map[TransportEndpoint]struct{}
+
+ // routeTable is the route table passed in by the user via SetRouteTable();
+ // it is used by FindRoute() to build a route for a specific
+ // destination.
+ routeTable []tcpip.Route
+
+ *ports.PortManager
+
+ // If not nil, then any new endpoints will have this probe function
+ // invoked every time they receive a TCP segment.
+ tcpProbeFunc TCPProbeFunc
+
+ // clock is used to generate user-visible times.
+ clock tcpip.Clock
+
+ // handleLocal allows non-loopback interfaces to loop packets.
+ handleLocal bool
+
+ // tables are the iptables packet filtering and manipulation rules.
+ tables *IPTables
+
+ // resumableEndpoints is a list of endpoints that need to be resumed if the
+ // stack is being restored.
+ resumableEndpoints []ResumableEndpoint
+
+ // icmpRateLimiter is a global rate limiter for all ICMP messages generated
+ // by the stack.
+ icmpRateLimiter *ICMPRateLimiter
+
+ // seed is a one-time random value initialized at stack startup
+ // and is used to seed the TCP port picking on active connections.
+ //
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ seed uint32
+
+ // ndpConfigs is the default NDP configurations used by interfaces.
+ ndpConfigs NDPConfigurations
+
+ // autoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // forwarder holds the packets that wait for their link-address resolutions
+ // to complete, and forwards them when each resolution is done.
+ forwarder *forwardQueue
+
+ // randomGenerator is an injectable pseudo random generator that can be
+ // used when a random number is required.
+ randomGenerator *mathrand.Rand
+
+ // sendBufferSize holds the min/default/max send buffer sizes for
+ // endpoints other than TCP.
+ sendBufferSize SendBufferSizeOption
+
+ // receiveBufferSize holds the min/default/max receive buffer sizes for
+ // endpoints other than TCP.
+ receiveBufferSize ReceiveBufferSizeOption
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
+}
+
+// Options contains optional Stack configuration.
+type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
+ // Clock is an optional clock source used for timestamping packets.
+ //
+ // If no Clock is specified, the clock source will be time.Now.
+ Clock tcpip.Clock
+
+ // Stats are optional statistic counters.
+ Stats tcpip.Stats
+
+ // HandleLocal indicates whether packets destined to their source
+ // should be handled by the stack internally (true) or outside the
+ // stack (false).
+ HandleLocal bool
+
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
+
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ //
+ // By default, NDPConfigs will have a zero value for its
+ // DupAddrDetectTransmits field, implying that DAD will not be performed
+ // before assigning an address to a NIC.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
+ // Note, setting this to true does not mean that a link-local address
+ // will be assigned right away, or at all. If Duplicate Address Detection
+ // is enabled, an address will only be assigned if it successfully resolves.
+ // If it fails, no further attempt will be made to auto-generate an IPv6
+ // link-local address.
+ //
+ // The generated link-local address will follow RFC 4291 Appendix A
+ // guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // RawFactory produces raw endpoints. Raw endpoints are enabled only if
+ // this is non-nil.
+ RawFactory RawFactory
+
+ // OpaqueIIDOpts holds the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // RandSource is an optional source to use to generate random
+ // numbers. If omitted it defaults to a Source seeded by the data
+ // returned by rand.Read().
+ //
+ // RandSource must be thread-safe.
+ RandSource mathrand.Source
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC addresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
+}
+
+// TransportEndpointInfo holds useful information about a transport endpoint
+// which can be queried by monitoring tools.
+//
+// +stateify savable
+type TransportEndpointInfo struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+
+ NetProto tcpip.NetworkProtocolNumber
+ TransProto tcpip.TransportProtocolNumber
+
+ // The following fields are protected by endpoint mu.
+
+ ID TransportEndpointID
+ // BindNICID and BindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address.
+ BindNICID tcpip.NICID
+ BindAddr tcpip.Address
+ // RegisterNICID is the default NICID registered as a side-effect of
+ // connect or datagram write.
+ RegisterNICID tcpip.NICID
+}
+
+// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6
+// address and returns the network protocol number to be used to communicate
+// with the specified address. It returns an error if the passed address is
+// incompatible with the receiver.
+//
+// Precondition: the parent endpoint mu must be held while calling this method.
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.NetProto
+ switch len(addr.Addr) {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ if header.IsV4MappedAddress(addr.Addr) {
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == header.IPv4Any {
+ addr.Addr = ""
+ }
+ }
+ }
+
+ switch len(e.ID.LocalAddress) {
+ case header.IPv4AddressSize:
+ if len(addr.Addr) == header.IPv6AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+ case header.IPv6AddressSize:
+ if len(addr.Addr) == header.IPv4AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ switch {
+ case netProto == e.NetProto:
+ case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ if v6only {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
+ }
+ default:
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return addr, netProto, nil
+}
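+
+// For example, passing the V4-mapped address ::ffff:10.0.0.1 (16 bytes) to a
+// dual-stack (non-v6only) endpoint unwraps it to 10.0.0.1 (4 bytes) and
+// returns IPv4 as the network protocol to use.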
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*TransportEndpointInfo) IsEndpointInfo() {}
+
+// New allocates a new networking stack with only the requested networking and
+// transport protocols configured with default options.
+//
+// Note, NDPConfigurations will be fixed before being used by the Stack. That
+// is, if an invalid value was provided, it will be reset to the default value.
+//
+// Protocol options can be changed by calling the
+// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
+// stack. Please refer to individual protocol implementations as to what options
+// are supported.
+func New(opts Options) *Stack {
+ clock := opts.Clock
+ if clock == nil {
+ clock = &tcpip.StdClock{}
+ }
+
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
+ randSrc := opts.RandSource
+ if randSrc == nil {
+ // Source provided by mathrand.NewSource is not thread-safe so
+ // we wrap it in a simple thread-safe version.
+ randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ }
+
+ // Make sure opts.NDPConfigs contains valid values only.
+ opts.NDPConfigs.validate()
+
+ s := &Stack{
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ ndpConfigs: opts.NDPConfigs,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
+ ndpDisp: opts.NDPDisp,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
+ receiveBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
+ }
+
+ // Add specified network protocols.
+ for _, netProto := range opts.NetworkProtocols {
+ s.networkProtocols[netProto.Number()] = netProto
+ if r, ok := netProto.(LinkAddressResolver); ok {
+ s.linkAddrResolvers[r.LinkAddressProtocol()] = r
+ }
+ }
+
+ // Add specified transport protocols.
+ for _, transProto := range opts.TransportProtocols {
+ s.transportProtocols[transProto.Number()] = &transportProtocolState{
+ proto: transProto,
+ }
+ }
+
+ // Add the factory for raw endpoints, if present.
+ s.rawFactory = opts.RawFactory
+
+ // Create the global transport demuxer.
+ s.demux = newTransportDemuxer(s)
+
+ return s
+}
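+
+// A typical invocation might look like the following (a sketch; it assumes
+// the ipv4, ipv6, tcp and udp subpackages are imported):
+// e.g.
+// s := stack.New(stack.Options{
+//   NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+//   TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+// })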
+
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
+// SetNetworkProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported, the
+// option is not supported by the protocol implementation, or the provided
+// value is incorrect.
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.SetOption(option)
+}
+
+// NetworkProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// the option is not supported by the protocol implementation.
+// e.g.
+// var v ipv4.MyOption
+// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
+// if err != nil {
+// ...
+// }
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.Option(option)
+}
+
+// SetTransportProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported, the
+// option is not supported by the protocol implementation, or the provided
+// value is incorrect.
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.SetOption(option)
+}
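+
+// For example, enabling SACK stack-wide for TCP might look like the following
+// (a sketch; it assumes the tcp subpackage is imported):
+// e.g.
+// if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
+//   ...
+// }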
+
+// TransportProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// the option is not supported by the protocol implementation.
+// e.g.
+// var v tcp.SACKEnabled
+// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
+// ...
+// }
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.Option(option)
+}
+
+// SetTransportProtocolHandler sets the per-stack default handler for the given
+// protocol.
+//
+// It must be called only during initialization of the stack. Changing it
+// while the stack is operating is not supported.
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+ state := s.transportProtocols[p]
+ if state != nil {
+ state.defaultHandler = h
+ }
+}
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (s *Stack) NowNanoseconds() int64 {
+ return s.clock.NowNanoseconds()
+}
+
+// Stats returns a mutable copy of the current stats.
+//
+// This is not generally exported via the public interface, but is available
+// internally.
+func (s *Stack) Stats() tcpip.Stats {
+ return s.stats
+}
+
+// SetForwarding enables or disables packet forwarding between NICs.
+//
+// When forwarding becomes enabled, any host-only state on all NICs will be
+// cleaned up and, if IPv6 is enabled, NDP Router Solicitations will be
+// stopped. When forwarding becomes disabled and IPv6 is enabled, NDP Router
+// Solicitations will be started.
+func (s *Stack) SetForwarding(enable bool) {
+ // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If forwarding status didn't change, do nothing further.
+ if s.forwarding == enable {
+ return
+ }
+
+ s.forwarding = enable
+
+ // If this stack does not support IPv6, do nothing further.
+ if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok {
+ return
+ }
+
+ if enable {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Router()
+ }
+ } else {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
+ }
+}
+
+// Forwarding returns whether packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding() bool {
+ // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.forwarding
+}
+
+// SetRouteTable assigns the route table to be used by this stack. It
+// specifies which NIC to use for given destination address ranges.
+//
+// This method takes ownership of the table.
+func (s *Stack) SetRouteTable(table []tcpip.Route) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.routeTable = table
+}
+
+// GetRouteTable returns the route table which is currently in use.
+func (s *Stack) GetRouteTable() []tcpip.Route {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return append([]tcpip.Route(nil), s.routeTable...)
+}
+
+// AddRoute appends a route to the route table.
+func (s *Stack) AddRoute(route tcpip.Route) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.routeTable = append(s.routeTable, route)
+}
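+
+// A common pattern is to install a single default route once a NIC exists
+// (a sketch; header.IPv4EmptySubnet is the all-zeroes /0 subnet):
+// e.g.
+// s.SetRouteTable([]tcpip.Route{
+//   {Destination: header.IPv4EmptySubnet, NIC: 1},
+// })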
+
+// NewEndpoint creates a new transport layer endpoint of the given protocol.
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewEndpoint(s, network, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw transport layer endpoint of the given
+// protocol. Raw endpoints receive all traffic for a given protocol regardless
+// of address.
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+ if s.rawFactory == nil {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ if !associated {
+ return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue)
+ }
+
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewRawEndpoint(s, network, waiterQueue)
+}
+
+// NewPacketEndpoint creates a new packet endpoint listening for the given
+// netProto.
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if s.rawFactory == nil {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
+}
+
+// NICContext is an opaque value used to store client-supplied NIC metadata.
+type NICContext interface{}
+
+// NICOptions specifies the configuration of a NIC as it is being created.
+// The zero value creates an enabled, unnamed NIC.
+type NICOptions struct {
+ // Name specifies the name of the NIC.
+ Name string
+
+ // Disabled specifies whether to avoid calling Attach on the passed
+ // LinkEndpoint.
+ Disabled bool
+
+ // Context specifies user-defined data that will be returned in stack.NICInfo
+ // for the NIC. Clients of this library can use it to add metadata that
+ // should be tracked alongside a NIC, to avoid having to keep a
+ // map[tcpip.NICID]metadata mirroring stack.Stack's nic map.
+ Context NICContext
+}
+
+// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
+// NICOptions. See the documentation on type NICOptions for details on how
+// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Make sure id is unique.
+ if _, ok := s.nics[id]; ok {
+ return tcpip.ErrDuplicateNICID
+ }
+
+ // Make sure name is unique, unless unnamed.
+ if opts.Name != "" {
+ for _, n := range s.nics {
+ if n.Name() == opts.Name {
+ return tcpip.ErrDuplicateNICID
+ }
+ }
+ }
+
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
+ s.nics[id] = n
+ if !opts.Disabled {
+ return n.enable()
+ }
+
+ return nil
+}
+
+// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.CreateNICWithOptions(id, ep, NICOptions{})
+}
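+
+// For illustration, creating a named NIC that starts disabled and is enabled
+// after configuration might look like (a sketch):
+// e.g.
+// if err := s.CreateNICWithOptions(1, ep, stack.NICOptions{Name: "eth0", Disabled: true}); err != nil {
+//   ...
+// }
+// // ... once the NIC is configured:
+// if err := s.EnableNIC(1); err != nil {
+//   ...
+// }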
+
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
+// EnableNIC enables the given NIC so that the link-layer endpoint can start
+// delivering packets to it.
+func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.enable()
+}
+
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.disable()
+}
+
+// CheckNIC checks if a NIC is usable.
+func (s *Stack) CheckNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return false
+ }
+
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return s.removeNICLocked(id)
+}
+
+// removeNICLocked removes NIC and all related routes from the network stack.
+//
+// s.mu must be locked.
+func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ s.routeTable[i] = tcpip.Route{}
+ if r.NIC != id {
+ // Keep this route.
+ s.routeTable[n] = r
+ n++
+ }
+ }
+
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
+}
+
+// NICAddressRanges returns a map of NICIDs to their associated subnets.
+func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := map[tcpip.NICID][]tcpip.Subnet{}
+
+ for id, nic := range s.nics {
+ nics[id] = append(nics[id], nic.AddressRanges()...)
+ }
+ return nics
+}
+
+// NICInfo captures the name and addresses assigned to a NIC.
+type NICInfo struct {
+ Name string
+ LinkAddress tcpip.LinkAddress
+ ProtocolAddresses []tcpip.ProtocolAddress
+
+ // Flags indicate the state of the NIC.
+ Flags NICStateFlags
+
+ // MTU is the maximum transmission unit.
+ MTU uint32
+
+ Stats NICStats
+
+ // Context is user-supplied data optionally supplied in CreateNICWithOptions.
+ // See type NICOptions for more details.
+ Context NICContext
+}
+
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
+}
+
+// NICInfo returns a map of NICIDs to their associated information.
+func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID]NICInfo)
+ for id, nic := range s.nics {
+ flags := NICStateFlags{
+ Up: true, // Netstack interfaces are always up.
+ Running: nic.enabled(),
+ Promiscuous: nic.isPromiscuousMode(),
+ Loopback: nic.isLoopback(),
+ }
+ nics[id] = NICInfo{
+ Name: nic.name,
+ LinkAddress: nic.linkEP.LinkAddress(),
+ ProtocolAddresses: nic.PrimaryAddresses(),
+ Flags: flags,
+ MTU: nic.linkEP.MTU(),
+ Stats: nic.stats,
+ Context: nic.context,
+ }
+ }
+ return nics
+}
+
+// NICStateFlags holds information about the state of a NIC.
+type NICStateFlags struct {
+ // Up indicates whether the interface is running.
+ Up bool
+
+ // Running indicates whether resources are allocated.
+ Running bool
+
+ // Promiscuous indicates whether the interface is in promiscuous mode.
+ Promiscuous bool
+
+ // Loopback indicates whether the interface is a loopback.
+ Loopback bool
+}
+
+// AddAddress adds a new network-layer address to the specified NIC.
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddProtocolAddress adds a new network-layer protocol address to the
+// specified NIC.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
+ return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb)
+}
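+
+// For example, assigning the IPv4 address 10.0.0.1 to NIC 1 (a sketch; an
+// IPv4 tcpip.Address is its 4 raw bytes, and the ipv4 subpackage is assumed
+// to be imported):
+// e.g.
+// if err := s.AddAddress(1, ipv4.ProtocolNumber, tcpip.Address("\x0a\x00\x00\x01")); err != nil {
+//   ...
+// }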
+
+// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
+// you to specify whether the new endpoint can be primary or not.
+func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.AddAddress(protocolAddress, peb)
+}
+
+// AddAddressRange adds a range of addresses to the specified NIC. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.AddAddressRange(protocol, subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// RemoveAddressRange removes the range of addresses from the specified NIC.
+func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.RemoveAddressRange(subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.RemoveAddress(addr)
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// AllAddresses returns a map of NICIDs to their protocol addresses (primary
+// and non-primary).
+func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
+ for id, nic := range s.nics {
+ nics[id] = nic.AllAddresses()
+ }
+ return nics
+}
+
+// GetMainNICAddress returns the first non-deprecated primary address and prefix
+// for the given NIC and protocol. If no non-deprecated primary address exists,
+// a deprecated primary address and prefix will be returned. Returns an error if
+// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
+// address for the given protocol.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ }
+
+ return nic.primaryAddress(protocol), nil
+}
+
+func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+ if len(localAddr) == 0 {
+ return nic.primaryEndpoint(netProto, remoteAddr)
+ }
+ return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
+}
+
+// FindRoute creates a route to the given destination address, leaving through
+// the given NIC and local address (if provided).
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ isBroadcast := remoteAddr == header.IPv4Broadcast
+ isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
+ needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ if id != 0 && !needRoute {
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
+ }
+ }
+ } else {
+ for _, route := range s.routeTable {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
+ continue
+ }
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ if len(remoteAddr) == 0 {
+ // If no remote address was provided, then the returned
+ // route will refer to the link-local address.
+ remoteAddr = ref.ep.ID().LocalAddress
+ }
+
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ return r, nil
+ }
+ }
+ }
+ }
+
+ if !needRoute {
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+}
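+
+// A sketch of resolving a route and releasing it when done (passing id 0 and
+// an empty localAddr lets the stack choose both):
+// e.g.
+// r, err := s.FindRoute(0, "", remoteAddr, ipv4.ProtocolNumber, false /* multicastLoop */)
+// if err != nil {
+//   ...
+// }
+// defer r.Release()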
+
+// CheckNetworkProtocol checks if a given network protocol is enabled in the
+// stack.
+func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
+ _, ok := s.networkProtocols[protocol]
+ return ok
+}
+
+// CheckLocalAddress determines if the given local address exists, and if it
+// does, returns the id of the NIC it's bound to. Returns 0 if the address
+// does not exist.
+func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ // If a NIC is specified, we try to find the address there only.
+ if nicID != 0 {
+ nic := s.nics[nicID]
+ if nic == nil {
+ return 0
+ }
+
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if ref == nil {
+ return 0
+ }
+
+ ref.decRef()
+
+ return nic.id
+ }
+
+ // Go through all the NICs.
+ for _, nic := range s.nics {
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if ref != nil {
+ ref.decRef()
+ return nic.id
+ }
+ }
+
+ return 0
+}
+
+// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setPromiscuousMode(enable)
+
+ return nil
+}
+
+// SetSpoofing enables or disables address spoofing in the given NIC, allowing
+// endpoints to bind to any address in the NIC.
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setSpoofing(enable)
+
+ return nil
+}
+
+// AddLinkAddress adds a link address to the stack link cache.
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
+ s.linkAddrCache.add(fullAddr, linkAddr)
+ // TODO: provide a way for a transport endpoint to receive a signal
+ // that AddLinkAddress for a particular address has been called.
+}
+
+// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+ s.mu.RLock()
+ nic := s.nics[nicID]
+ if nic == nil {
+ s.mu.RUnlock()
+ return "", nil, tcpip.ErrUnknownNICID
+ }
+ s.mu.RUnlock()
+
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
+ linkRes := s.linkAddrResolvers[protocol]
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+}
+
+// RemoveWaker implements LinkAddressCache.RemoveWaker.
+func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ // Wakers are only registered for NICs that exist, so only remove the
+ // waker if the NIC is still around.
+ if nic := s.nics[nicID]; nic != nil {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
+ s.linkAddrCache.removeWaker(fullAddr, waker)
+ }
+}
+
+// RegisterTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided id will be
+// delivered to the given endpoint; specifying a NIC is optional, but
+// NIC-specific IDs have precedence over global ones.
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
+// the stack transport dispatcher.
+func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
+}
+
+// UnregisterTransportEndpoint removes the endpoint with the given id from the
+// stack transport dispatcher.
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.cleanupEndpoints[ep] = struct{}{}
+
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+ s.mu.Lock()
+ delete(s.cleanupEndpoints, ep)
+ s.mu.Unlock()
+}
+
+// FindTransportEndpoint finds an endpoint that most closely matches the provided
+// id. If no endpoint is found it returns nil.
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+}
+
+// RegisterRawTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided transport
+// protocol will be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
+}
+
+// UnregisterRawTransportEndpoint removes the endpoint for the transport
+// protocol from the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
+}
+
+// RegisterRestoredEndpoint records e as an endpoint that has been restored on
+// this stack.
+func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
+ s.mu.Lock()
+ s.resumableEndpoints = append(s.resumableEndpoints, e)
+ s.mu.Unlock()
+}
+
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var es []TransportEndpoint
+ for _, e := range s.demux.protocol {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+ for e := range s.cleanupEndpoints {
+ es = append(es, e)
+ }
+ s.mu.Unlock()
+ return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+ s.mu.Lock()
+ for _, e := range es {
+ s.cleanupEndpoints[e] = struct{}{}
+ }
+ s.mu.Unlock()
+}
+
+// Close closes all currently registered transport endpoints.
+//
+// Endpoints created or modified during this call may not get closed.
+func (s *Stack) Close() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
+ }
+}
+
+// Wait waits for all transport and link endpoints to halt their worker
+// goroutines.
+//
+// Endpoints created or modified during this call may not get waited on.
+//
+// Note that link endpoints must be stopped via an implementation-specific
+// mechanism.
+func (s *Stack) Wait() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Wait()
+ }
+ for _, e := range s.CleanupEndpoints() {
+ e.Wait()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, n := range s.nics {
+ n.linkEP.Wait()
+ }
+}
+
+// Resume restarts the stack after a restore. This must be called after the
+// entire system has been restored.
+func (s *Stack) Resume() {
+ // ResumableEndpoint.Resume() may call other methods on s, so we can't hold
+ // s.mu while resuming the endpoints.
+ s.mu.Lock()
+ eps := s.resumableEndpoints
+ s.resumableEndpoints = nil
+ s.mu.Unlock()
+ for _, e := range eps {
+ e.Resume(s)
+ }
+}
+
+// RegisterPacketEndpoint registers ep with the stack, causing it to receive
+// all traffic of the specified netProto on the given NIC. If nicID is 0, it
+// receives traffic from every NIC.
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If no NIC is specified, capture on all devices.
+ if nicID == 0 {
+ // Register with each NIC.
+ for _, nic := range s.nics {
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ s.unregisterPacketEndpointLocked(0, netProto, ep)
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Capture on a specific device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnregisterPacketEndpoint unregisters ep for packets of the specified
+// netProto from the specified NIC. If nicID is 0, ep is unregistered from all
+// NICs.
+func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.unregisterPacketEndpointLocked(nicID, netProto, ep)
+}
+
+func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ // If no NIC is specified, unregister on all devices.
+ if nicID == 0 {
+ // Unregister with each NIC.
+ for _, nic := range s.nics {
+ nic.unregisterPacketEndpoint(netProto, ep)
+ }
+ return
+ }
+
+ // Unregister in a single device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return
+ }
+ nic.unregisterPacketEndpoint(netProto, ep)
+}
+
+// WritePacket writes data directly to the specified NIC. It adds an Ethernet
+// header based on the arguments.
+func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicID]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ // Add our own fake ethernet header.
+ ethFields := header.EthernetFields{
+ SrcAddr: nic.linkEP.LinkAddress(),
+ DstAddr: dst,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ vv := buffer.View(fakeHeader).ToVectorisedView()
+ vv.Append(payload)
+
+ if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicID]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// NetworkProtocolInstance returns the protocol instance in the stack for the
+// specified network protocol. This method is public for protocol implementers
+// and tests to use.
+func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol {
+ if p, ok := s.networkProtocols[num]; ok {
+ return p
+ }
+ return nil
+}
+
+// TransportProtocolInstance returns the protocol instance in the stack for the
+// specified transport protocol. This method is public for protocol implementers
+// and tests to use.
+func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol {
+ if pState, ok := s.transportProtocols[num]; ok {
+ return pState.proto
+ }
+ return nil
+}
+
+// AddTCPProbe installs a probe function that will be invoked on every segment
+// received by a given TCP endpoint. The probe function is passed a copy of the
+// TCP endpoint state before and after processing of the segment.
+//
+// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
+// created prior to this call will not call the probe function.
+//
+// Further, installing two different probes back to back can result in some
+// endpoints calling the first one and some the second one. There is no
+// guarantee provided on which probe will be invoked. Ideally this should only
+// be called once per stack.
+func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
+ s.mu.Lock()
+ s.tcpProbeFunc = probe
+ s.mu.Unlock()
+}
+
+// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
+// otherwise.
+func (s *Stack) GetTCPProbe() TCPProbeFunc {
+ s.mu.Lock()
+ p := s.tcpProbeFunc
+ s.mu.Unlock()
+ return p
+}
+
+// RemoveTCPProbe removes an installed TCP probe.
+//
+// NOTE: This only ensures that endpoints created after this call do not
+// have a probe attached. Endpoints already created will continue to invoke the
+// TCP probe.
+func (s *Stack) RemoveTCPProbe() {
+ s.mu.Lock()
+ s.tcpProbeFunc = nil
+ s.mu.Unlock()
+}
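+
+// For illustration, a probe that merely counts processed segments might look
+// like (a sketch; it assumes sync/atomic is imported):
+// e.g.
+// var segments uint64
+// s.AddTCPProbe(func(stack.TCPEndpointState) {
+//   atomic.AddUint64(&segments, 1)
+// })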
+
+// JoinGroup joins the given multicast group on the given NIC.
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ // TODO: notify network of subscription via igmp protocol.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.joinGroup(protocol, multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
+}
+
+// LeaveGroup leaves the given multicast group on the given NIC.
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.leaveGroup(multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
+}
+
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
+}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() *IPTables {
+ return s.tables
+}
+
+// ICMPLimit returns the maximum number of ICMP messages that can be sent
+// in one second.
+func (s *Stack) ICMPLimit() rate.Limit {
+ return s.icmpRateLimiter.Limit()
+}
+
+// SetICMPLimit sets the maximum number of ICMP messages that can be sent
+// in one second.
+func (s *Stack) SetICMPLimit(newLimit rate.Limit) {
+ s.icmpRateLimiter.SetLimit(newLimit)
+}
+
+// ICMPBurst returns the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) ICMPBurst() int {
+ return s.icmpRateLimiter.Burst()
+}
+
+// SetICMPBurst sets the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) SetICMPBurst(burst int) {
+ s.icmpRateLimiter.SetBurst(burst)
+}
+
+// AllowICMPMessage returns true if the rate limiter allows at least one
+// ICMP message to be sent at this instant.
+func (s *Stack) AllowICMPMessage() bool {
+ return s.icmpRateLimiter.Allow()
+}
+
+// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
+//
+// Note that if addr is not associated with the NIC with ID id, then this
+// function will return false. It will only return true if the address is
+// associated with the NIC AND it is tentative.
+func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return false, tcpip.ErrUnknownNICID
+ }
+
+ return nic.isAddrTentative(addr), nil
+}
+
+// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
+// tentative addr on it is a duplicate on a link.
+func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.dupTentativeAddrDetected(addr)
+}
+
+// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
+// with ID id to c.
+//
+// Note, if c contains invalid NDP configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setNDPConfigs(c)
+
+ return nil
+}
+
+// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
+// message that it needs to handle.
+func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.handleNDPRA(ip, ra)
+
+ return nil
+}
+
+// Seed returns a 32-bit value that can be used as a seed value for port
+// picking, ISN generation, etc.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) Seed() uint32 {
+ return s.seed
+}
+
+// Rand returns a reference to a pseudo random generator that can be used
+// to generate random numbers as required.
+func (s *Stack) Rand() *mathrand.Rand {
+ return s.randomGenerator
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
+
+func generateRandInt64() int64 {
+ b := make([]byte, 8)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ buf := bytes.NewReader(b)
+ var v int64
+ if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// FindNetworkEndpoint returns the network endpoint for the given address.
+func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, nic := range s.nics {
+ id := NetworkEndpointID{address}
+
+ // nic.mu protects the endpoints map, so hold it while looking up
+ // the address rather than reading the map first.
+ nic.mu.RLock()
+ ref, ok := nic.mu.endpoints[id]
+ nic.mu.RUnlock()
+ if !ok {
+ continue
+ }
+
+ // An endpoint with this id exists, check if it can be
+ // used and return it.
+ return ref.ep, nil
+ }
+ return nil, tcpip.ErrBadAddress
+}
+
+// FindNICNameFromID returns the name of the nic for the given NICID.
+func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return ""
+ }
+
+ return nic.Name()
+}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
new file mode 100644
index 000000000..dfec4258a
--- /dev/null
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+// StackFromEnv is the global stack created during a restore run.
+// FIXME(b/36201077)
+var StackFromEnv *Stack
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
new file mode 100644
index 000000000..0b093e6c5
--- /dev/null
+++ b/pkg/tcpip/stack/stack_options.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+const (
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4 KiB
+
+ // DefaultBufferSize is the default size of the send/recv buffer for a
+ // transport endpoint.
+ DefaultBufferSize = 212 << 10 // 212 KiB
+
+ // DefaultMaxBufferSize is the default maximum permitted size of a
+ // send/receive buffer.
+ DefaultMaxBufferSize = 4 << 20 // 4 MiB
+)
+
+// SendBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// SetOption allows setting stack wide options.
+func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SendBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below the minimum
+ // required for the stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.sendBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below the minimum
+ // required for the stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.receiveBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option allows retrieving stack wide options.
+func (s *Stack) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SendBufferSizeOption:
+ s.mu.RLock()
+ *v = s.sendBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ s.mu.RLock()
+ *v = s.receiveBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
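+
+// For example, doubling the stack-wide default send buffer size (a sketch):
+// e.g.
+// var opt stack.SendBufferSizeOption
+// if err := s.Option(&opt); err != nil {
+//   ...
+// }
+// opt.Default *= 2
+// if err := s.SetOption(opt); err != nil {
+//   ...
+// }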
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
new file mode 100644
index 000000000..7657a4101
--- /dev/null
+++ b/pkg/tcpip/stack/stack_test.go
@@ -0,0 +1,3420 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package stack_test contains tests for the stack. It is in its own package so
+// that the tests can also validate that all definitions needed to implement
+// transport and network protocols are properly exported by the stack package.
+package stack_test
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fakeNetHeaderLen = 12
+ fakeDefaultPrefixLen = 8
+
+ // fakeControlProtocol is used for control packets that represent
+ // destination port unreachable.
+ fakeControlProtocol tcpip.TransportProtocolNumber = 2
+
+ // defaultMTU is the MTU, in bytes, used throughout the tests, except
+ // where another value is explicitly used. It is chosen to match the MTU
+ // of loopback interfaces on Linux systems.
+ defaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
+)
+
+// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
+// received packets; the counts of all endpoints are aggregated in the protocol
+// descriptor.
+//
+// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one-byte fields to simplify parsing.
+type fakeNetworkEndpoint struct {
+ nicID tcpip.NICID
+ id stack.NetworkEndpointID
+ prefixLen int
+ proto *fakeNetworkProtocol
+ dispatcher stack.TransportDispatcher
+ ep stack.LinkEndpoint
+}
+
+func (f *fakeNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
+}
+
+func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
+ return f.nicID
+}
+
+func (f *fakeNetworkEndpoint) PrefixLen() int {
+ return f.prefixLen
+}
+
+func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
+func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
+ return &f.id
+}
+
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ // Increment the received packet count in the protocol descriptor.
+ f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
+
+ // Handle control packets.
+ if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return
+ }
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ f.dispatcher.DeliverTransportControlPacket(
+ tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ fakeNetNumber,
+ tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ stack.ControlPortUnreachable, 0, pkt)
+ return
+ }
+
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
+}
+
+func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
+ return f.ep.MaxHeaderLength() + fakeNetHeaderLen
+}
+
+func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+ return 0
+}
+
+func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return f.ep.Capabilities()
+}
+
+func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return f.proto.Number()
+}
+
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ // Increment the sent packet count in the protocol descriptor.
+ f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
+
+ // Add the protocol's header to the packet and send it to the link
+ // endpoint.
+ pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
+ pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
+ pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
+ pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
+
+ if r.Loop&stack.PacketLoop != 0 {
+ f.HandlePacket(r, pkt)
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+}
+
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (*fakeNetworkEndpoint) Close() {}
+
+type fakeNetGoodOption bool
+
+type fakeNetBadOption bool
+
+type fakeNetInvalidValueOption int
+
+type fakeNetOptions struct {
+ good bool
+}
+
+// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
+// number of packets sent and received via endpoints of this protocol. The index
+// where packets are added is given by the packet's destination address MOD 10.
+type fakeNetworkProtocol struct {
+ packetCount [10]int
+ sendPacketCount [10]int
+ opts fakeNetOptions
+}
+
+func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+ return fakeNetNumber
+}
+
+func (f *fakeNetworkProtocol) MinimumPacketSize() int {
+ return fakeNetHeaderLen
+}
+
+func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
+ return fakeDefaultPrefixLen
+}
+
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
+func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ return &fakeNetworkEndpoint{
+ nicID: nicID,
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ proto: f,
+ dispatcher: dispatcher,
+ ep: ep,
+ }, nil
+}
+
+func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case fakeNetGoodOption:
+ f.opts.good = bool(v)
+ return nil
+ case fakeNetInvalidValueOption:
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *fakeNetGoodOption:
+ *v = fakeNetGoodOption(f.opts.good)
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Close implements NetworkProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
+// Parse implements stack.NetworkProtocol.Parse.
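+// On success it returns the transport protocol number carried in the fake
+// header along with (true, true); a packet too short to hold the header
+// yields (0, false, false).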
+func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
+}
+
+func fakeNetFactory() stack.NetworkProtocol {
+ return &fakeNetworkProtocol{}
+}
+
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
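+	// The stack detaches a link endpoint by attaching a nil dispatcher, so
+	// a nil d here records the endpoint as detached.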
+ l.LinkEndpoint.Attach(d)
+ l.attached = d != nil
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
+func TestNetworkReceive(t *testing.T) {
+ // Create a stack with the fake network protocol, one nic, and two
+ // addresses attached to it: 1 & 2.
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ buf := buffer.NewView(30)
+
+ // Make sure packet with wrong address is not delivered.
+ buf[dstAddrOffset] = 3
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeNet.packetCount[1] != 0 {
+ t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+ }
+ if fakeNet.packetCount[2] != 0 {
+ t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
+ }
+
+ // Make sure packet is delivered to first endpoint.
+ buf[dstAddrOffset] = 1
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeNet.packetCount[1] != 1 {
+ t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ }
+ if fakeNet.packetCount[2] != 0 {
+ t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
+ }
+
+ // Make sure packet is delivered to second endpoint.
+ buf[dstAddrOffset] = 2
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeNet.packetCount[1] != 1 {
+ t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ }
+ if fakeNet.packetCount[2] != 1 {
+ t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
+ }
+
+ // Make sure packet is not delivered if protocol number is wrong.
+ ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeNet.packetCount[1] != 1 {
+ t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ }
+ if fakeNet.packetCount[2] != 1 {
+ t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
+ }
+
+ // Make sure packet that is too small is dropped.
+ buf.CapLength(2)
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeNet.packetCount[1] != 1 {
+ t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ }
+ if fakeNet.packetCount[2] != 1 {
+ t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
+ }
+}
+
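+// sendTo looks up a route to addr in the stack's routing table and writes a
+// single packet along it.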
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
+ r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+ return send(r, payload)
+}
+
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ })
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
+ }
+}
+
+func TestNetworkSend(t *testing.T) {
+ // Create a stack with the fake network protocol, one nic, and one
+ // address: 1. The route table sends all packets through the only
+ // existing nic.
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("NewNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ // Make sure that the link-layer endpoint received the outbound packet.
+ testSendTo(t, s, "\x03", ep, nil)
+}
+
+func TestNetworkSendMultiRoute(t *testing.T) {
+	// Create a stack with the fake network protocol, two NICs, and two
+	// addresses per NIC; the first NIC has odd addresses and the second
+	// has even addresses.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+	// Set a route table that sends all packets with odd destination
+	// addresses through the first NIC, and all even destination
+	// addresses through the second one.
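+	// A subnet with ID "\x01" and mask "\x01" contains exactly the
+	// addresses whose low bit is set (the odd ones); ID "\x00" with the
+	// same mask contains the even ones.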
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
+
+ // Send a packet to an odd destination.
+ testSendTo(t, s, "\x05", ep1, nil)
+
+ // Send a packet to an even destination.
+ testSendTo(t, s, "\x06", ep2, nil)
+}
+
+func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
+ r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+
+ defer r.Release()
+
+ if r.LocalAddress != expectedSrcAddr {
+ t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress)
+ }
+
+ if r.RemoteAddress != dstAddr {
+ t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress)
+ }
+}
+
+func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
+ _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != tcpip.ErrNoRoute {
+ t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute)
+ }
+}
+
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRemoveUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestRemoveNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
+
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
+ }
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+}
+
+func TestRouteWithDownNIC(t *testing.T) {
+ tests := []struct {
+ name string
+ downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ }{
+ {
+ name: "Disabled NIC",
+ downFn: (*stack.Stack).DisableNIC,
+ upFn: (*stack.Stack).EnableNIC,
+ },
+
+ // Once a NIC is removed, it cannot be brought up.
+ {
+ name: "Removed NIC",
+ downFn: (*stack.Stack).RemoveNIC,
+ },
+ }
+
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+ const nic1Dst = tcpip.Address("\x05")
+ const nic2Dst = tcpip.Address("\x06")
+
+ setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+		// Set a route table that sends all packets with odd destination
+		// addresses through the first NIC, and all even destination
+		// addresses through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ return s, ep1, ep2
+ }
+
+ // Tests that routes through a down NIC are not used when looking up a route
+ // for a destination.
+ t.Run("Find", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, _, _ := setup(t)
+
+				// Test routes to the odd destination.
+				testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+				testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+				testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+
+				// Test routes to the even destination.
+				testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+				testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+				testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Bringing NIC1 down should result in no routes to odd addresses. Routes to
+ // even addresses should continue to be available as NIC2 is still up.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Bringing NIC2 down should result in no routes to even addresses. No
+ // route should be available to any address as routes to odd addresses
+ // were made unavailable by bringing NIC1 down above.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ if upFn := test.upFn; upFn != nil {
+ // Bringing NIC1 up should make routes to odd addresses available
+ // again. Routes to even addresses should continue to be unavailable
+ // as NIC2 is still down.
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+ }
+ })
+ }
+ })
+
+ // Tests that writing a packet using a Route through a down NIC fails.
+ t.Run("WritePacket", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep1, ep2 := setup(t)
+
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use NIC1 after being brought down should fail.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use NIC2 after being brought down should fail.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ if upFn := test.upFn; upFn != nil {
+ // Writes with Routes that use NIC1 after being brought up should
+ // succeed.
+ //
+ // TODO(b/147015577): Should we instead completely invalidate all
+ // Routes that were bound to a NIC that was brought down at some
+ // point?
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ }
+ })
+ }
+ })
+}
+
+func TestRoutes(t *testing.T) {
+	// Create a stack with the fake network protocol, two NICs, and two
+	// addresses per NIC; the first NIC has odd addresses and the second
+	// has even addresses.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+	// Set a route table that sends all packets with odd destination
+	// addresses through the first NIC, and all even destination
+	// addresses through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
+
+ // Test routes to odd address.
+ testRoute(t, s, 0, "", "\x05", "\x01")
+ testRoute(t, s, 0, "\x01", "\x05", "\x01")
+ testRoute(t, s, 1, "\x01", "\x05", "\x01")
+ testRoute(t, s, 0, "\x03", "\x05", "\x03")
+ testRoute(t, s, 1, "\x03", "\x05", "\x03")
+
+ // Test routes to even address.
+ testRoute(t, s, 0, "", "\x06", "\x02")
+ testRoute(t, s, 0, "\x02", "\x06", "\x02")
+ testRoute(t, s, 2, "\x02", "\x06", "\x02")
+ testRoute(t, s, 0, "\x04", "\x06", "\x04")
+ testRoute(t, s, 2, "\x04", "\x06", "\x04")
+
+ // Try to send to odd numbered address from even numbered ones, then
+ // vice-versa.
+ testNoRoute(t, s, 0, "\x02", "\x05")
+ testNoRoute(t, s, 2, "\x02", "\x05")
+ testNoRoute(t, s, 0, "\x04", "\x05")
+ testNoRoute(t, s, 2, "\x04", "\x05")
+
+ testNoRoute(t, s, 0, "\x01", "\x06")
+ testNoRoute(t, s, 1, "\x01", "\x06")
+ testNoRoute(t, s, 0, "\x03", "\x06")
+ testNoRoute(t, s, 1, "\x03", "\x06")
+}
+
+func TestAddressRemoval(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ buf := buffer.NewView(30)
+
+ // Send and receive packets, and verify they are received.
+ buf[dstAddrOffset] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+
+ // Check that removing the same address fails.
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
+ }
+}
+
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+
+ // Send and receive packets, and verify they are received.
+ buf[dstAddrOffset] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+
+ // Check that removing the same address fails.
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
+ }
+}
+
+func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) {
+ t.Helper()
+ info, ok := s.NICInfo()[nicID]
+ if !ok {
+ t.Fatalf("NICInfo() failed to find nicID=%d", nicID)
+ }
+ if len(addr) == 0 {
+ // No address given, verify that there is no address assigned to the NIC.
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ }
+ }
+ return
+ }
+ // Address given, verify the address is assigned to the NIC and no other
+ // address is.
+ found := false
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber {
+ if a.AddressWithPrefix.Address == addr {
+ found = true
+ } else {
+ t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
+ }
+ }
+ }
+ if !found {
+ t.Errorf("verify address: couldn't find %s on the NIC", addr)
+ }
+}
+
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicID tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicID, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicID, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+					// FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicID, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicID, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+					// FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicID, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicID, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicID, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+				// 8. Release the route; sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicID, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicID, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+					// FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
+ }
+}
+
+func TestPromiscuousMode(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ buf := buffer.NewView(30)
+
+	// Inject an inbound packet and check that it is not delivered, since
+	// there is no matching endpoint.
+ const localAddrByte byte = 0x01
+ buf[dstAddrOffset] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+
+ // Set promiscuous mode, then check that packet is delivered.
+ if err := s.SetPromiscuousMode(1, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+
+ // Check that we can't get a route as there is no local address.
+ _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ if err != tcpip.ErrNoRoute {
+ t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute)
+ }
+
+ // Set promiscuous mode to false, then check that packet can't be
+ // delivered anymore.
+ if err := s.SetPromiscuousMode(1, false); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+}
+
+func TestSpoofingWithAddress(t *testing.T) {
+ localAddr := tcpip.Address("\x01")
+ nonExistentLocalAddr := tcpip.Address("\x02")
+ dstAddr := tcpip.Address("\x03")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ testSendTo(t, s, dstAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // FindRoute should also work with a local address that exists on the NIC.
+ r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != localAddr {
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet using the route works.
+ testSend(t, r, ep, nil)
+}
+
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
+ dstAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+	// FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+	// testSendTo(t, s, dstAddr, ep, nil)
+}
+
+func verifyRoute(gotRoute, wantRoute stack.Route) error {
+ if gotRoute.LocalAddress != wantRoute.LocalAddress {
+ return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
+ }
+ if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
+ return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
+ }
+ if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ }
+ if gotRoute.NextHop != wantRoute.NextHop {
+ return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
+ }
+ return nil
+}
+
+func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+ s.SetRouteTable([]tcpip.Route{})
+
+	// With no address (and hence no network endpoint) on the NIC, FindRoute fails.
+ if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
+ t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ }
+
+	protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any, PrefixLen: 0}}
+ if err := s.AddProtocolAddress(1, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ }
+ r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+	// FindRoute also fails if the requested NIC does not exist.
+ if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
+ t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ }
+}
+
+func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
+	defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any, PrefixLen: 0}
+	// Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
+	nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24}
+	nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+	// Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
+	nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24}
+	nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
+
+ // Create a new stack with two NICs.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+ if err := s.CreateNIC(2, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+	nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ }
+
+	nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
+	if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
+		t.Fatalf("AddProtocolAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ }
+
+ // Set the initial route table.
+ rt := []tcpip.Route{
+ {Destination: nic1Addr.Subnet(), NIC: 1},
+ {Destination: nic2Addr.Subnet(), NIC: 2},
+ {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2},
+ {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1},
+ }
+ s.SetRouteTable(rt)
+
+ // When an interface is given, the route for a broadcast goes through it.
+ r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // When an interface is not given, it consults the route table.
+ // 1. Case: Using the default route.
+ r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // 2. Case: Having an explicit route for broadcast will select that one.
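+	// The route table is matched in order, so the /32 broadcast route is
+	// prepended to be matched before the default routes.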
+ rt = append(
+ []tcpip.Route{
+			{Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ },
+ rt...,
+ )
+ s.SetRouteTable(rt)
+ r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+}
+
+func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ routeNeeded bool
+ address tcpip.Address
+ }{
+ // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255
+ // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff
+ {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"},
+ {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"},
+ {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"},
+ {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"},
+ {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"},
+
+ // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112]
+ {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
+
+ // IPv6 link-local address starts with fe80::/10.
+ {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
+ {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"},
+ {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
+
+ // IPv6 addresses that are neither multicast nor link-local.
+ {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
+ {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{})
+
+ var anyAddr tcpip.Address
+ if len(tc.address) == header.IPv4AddressSize {
+ anyAddr = header.IPv4Any
+ } else {
+ anyAddr = header.IPv6Any
+ }
+
+ want := tcpip.ErrNetworkUnreachable
+ if tc.routeNeeded {
+ want = tcpip.ErrNoRoute
+ }
+
+			// With no address (and hence no endpoint) on the NIC, route lookup fails.
+ if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
+ t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
+ t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ }
+
+ if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
+ // Route table is empty but we need a route, this should cause an error.
+ if err != tcpip.ErrNoRoute {
+ t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err)
+ }
+ if r.LocalAddress != anyAddr {
+ t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr)
+ }
+ if r.RemoteAddress != tc.address {
+ t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address)
+ }
+ }
+			// FindRoute also fails if the requested NIC does not exist.
+ if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
+ t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
+ }
+ })
+ }
+}
+
+// Add a range of addresses, then check that a packet is delivered.
+func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ buf := buffer.NewView(30)
+
+ const localAddrByte byte = 0x01
+ buf[dstAddrOffset] = localAddrByte
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
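+	// The range "\x00"/"\xF0" covers addresses 0x00 through 0x0f, which
+	// includes localAddrByte.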
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
+ }
+
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+}
+
+func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
+ t.Helper()
+
+ // Loop over all addresses and check them.
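+	// For these one-byte addresses, a /p subnet spans 1 << (8-p) addresses.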
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+
+ addrBytes := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ addr := tcpip.Address(addrBytes)
+ wantNicID := nicID
+ // The subnet and broadcast addresses are skipped.
+ if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
+ wantNicID = 0
+ }
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
+ }
+ addrBytes[0]++
+ }
+
+ // Trying the next address should always fail since it is outside the range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
+ }
+}
+
+// Set a range of addresses, then remove it again, and check at each step that
+// CheckLocalAddress returns the correct NIC for each address, or zero if the
+// address is not assigned.
+func TestCheckLocalAddressForSubnet(t *testing.T) {
+ const nicID tcpip.NICID = 1
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
+ }
+
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
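+	// The subnet "\xa0"/"\xf0" covers the 16 addresses 0xa0 through 0xaf.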
+
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
+
+ if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
+ }
+
+ testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
+
+ if err := s.RemoveAddressRange(nicID, subnet); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
+ }
+
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
+}
+
+// Set a range of addresses, then inject a packet whose destination is outside
+// that range, and check that it is not delivered.
+func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ buf := buffer.NewView(30)
+
+ const localAddrByte byte = 0x01
+ buf[dstAddrOffset] = localAddrByte
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
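+	// The range "\x10"/"\xF0" covers 0x10 through 0x1f, so localAddrByte
+	// (0x01) falls outside it.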
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
+ }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+}
+
+func TestNetworkOptions(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{},
+ })
+
+ // Try an unsupported network protocol.
+ if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
+ t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+ }
+
+ testCases := []struct {
+ option interface{}
+ wantErr *tcpip.Error
+ verifier func(t *testing.T, p stack.NetworkProtocol)
+ }{
+ {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
+ t.Helper()
+ fakeNet := p.(*fakeNetworkProtocol)
+			if !fakeNet.opts.good {
+ t.Fatalf("fakeNet.opts.good = false, want = true")
+ }
+ var v fakeNetGoodOption
+ if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
+ t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
+ }
+			if !v {
+ t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
+ }
+ }},
+ {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
+ {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+ }
+ for _, tc := range testCases {
+ if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
+ t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
+ }
+ if tc.verifier != nil {
+ tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
+ }
+ }
+}
+
+func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
+ ranges, ok := s.NICAddressRanges()[id]
+ if !ok {
+ return false
+ }
+ for _, r := range ranges {
+ if r == addrRange {
+ return true
+ }
+ }
+ return false
+}
+
+func TestAddressRangeAddRemove(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ addr := tcpip.Address("\x01\x01\x01\x01")
+ mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
+ addrRange, err := tcpip.NewSubnet(addr, mask)
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
+
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
+ }
+
+ if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
+ }
+
+ if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
+ }
+
+ if err := s.RemoveAddressRange(1, addrRange); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
+ }
+
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
+ }
+}
+
+func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
+ for _, addrLen := range []int{4, 16} {
+ t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
+ for canBe := 0; canBe < 3; canBe++ {
+ t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
+ for never := 0; never < 3; never++ {
+ t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+ // Insert <canBe> primary and <never> never-primary addresses.
+ // Each one will add a network endpoint to the NIC.
+ primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{})
+ for i := 0; i < canBe+never; i++ {
+ var behavior stack.PrimaryEndpointBehavior
+ if i < canBe {
+ behavior = stack.CanBePrimaryEndpoint
+ } else {
+ behavior = stack.NeverPrimaryEndpoint
+ }
+							// Add an address; for a primary one, also include a
+							// prefixLen.
+ address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ if behavior == stack.CanBePrimaryEndpoint {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: addrLen * 8,
+ },
+ }
+ if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil {
+ t.Fatal("AddProtocolAddressWithOptions failed:", err)
+ }
+ // Remember the address/prefix.
+ primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
+ } else {
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ }
+ }
+ // Check that GetMainNICAddress returns an address if at least
+ // one primary address was added. In that case make sure the
+ // address/prefixLen matches what we added.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if len(primaryAddrAdded) == 0 {
+ // No primary addresses present.
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
+ }
+ } else {
+ // At least one primary address was added, verify the returned
+ // address is in the list of primary addresses we added.
+ if _, ok := primaryAddrAdded[gotAddr]; !ok {
+ t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
+ }
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestGetMainNICAddressAddRemove(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ address tcpip.Address
+ prefixLen int
+ }{
+ {"IPv4", "\x01\x01\x01\x01", 24},
+ {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tc.address,
+ PrefixLen: tc.prefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
+ t.Fatal("AddProtocolAddress failed:", err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
+ t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ }
+
+ if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+
+ // Check that we get no address after removal.
+ gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ }
+ })
+ }
+}
+
+// addressGenerator is a simple sequential network address generator, good for
+// up to 255 addresses.
+type addressGenerator struct{ cnt byte }
+
+func (g *addressGenerator) next(addrLen int) tcpip.Address {
+ g.cnt++
+ return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen))
+}
+
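+// verifyAddresses checks that gotAddresses matches expectedAddresses, ignoring
+// order. Both slices are sorted in place by address before comparison.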
+func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ t.Helper()
+
+ if len(gotAddresses) != len(expectedAddresses) {
+ t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
+ }
+
+ sort.Slice(gotAddresses, func(i, j int) bool {
+ return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address
+ })
+ sort.Slice(expectedAddresses, func(i, j int) bool {
+ return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address
+ })
+
+ for i, gotAddr := range gotAddresses {
+ expectedAddr := expectedAddresses[i]
+ if gotAddr != expectedAddr {
+ t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr)
+ }
+ }
+}
+
+func TestAddAddress(t *testing.T) {
+ const nicID = 1
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ var addrGen addressGenerator
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
+ for _, addrLen := range []int{4, 16} {
+ address := addrGen.next(addrLen)
+ if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
+ t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
+ }
+ expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ })
+ }
+
+ gotAddresses := s.AllAddresses()[nicID]
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestAddProtocolAddress(t *testing.T) {
+ const nicID = 1
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ var addrGen addressGenerator
+ addrLenRange := []int{4, 16}
+ prefixLenRange := []int{8, 13, 20, 32}
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
+ for _, addrLen := range addrLenRange {
+ for _, prefixLen := range prefixLenRange {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addrGen.next(addrLen),
+ PrefixLen: prefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
+ }
+ expectedAddresses = append(expectedAddresses, protocolAddress)
+ }
+ }
+
+ gotAddresses := s.AllAddresses()[nicID]
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestAddAddressWithOptions(t *testing.T) {
+ const nicID = 1
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ addrLenRange := []int{4, 16}
+ behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ var addrGen addressGenerator
+ for _, addrLen := range addrLenRange {
+ for _, behavior := range behaviorRange {
+ address := addrGen.next(addrLen)
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
+ t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
+ }
+ expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ })
+ }
+ }
+
+ gotAddresses := s.AllAddresses()[nicID]
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestAddProtocolAddressWithOptions(t *testing.T) {
+ const nicID = 1
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ addrLenRange := []int{4, 16}
+ prefixLenRange := []int{8, 13, 20, 32}
+ behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
+ var addrGen addressGenerator
+ for _, addrLen := range addrLenRange {
+ for _, prefixLen := range prefixLenRange {
+ for _, behavior := range behaviorRange {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addrGen.next(addrLen),
+ PrefixLen: prefixLen,
+ },
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ }
+ expectedAddresses = append(expectedAddresses, protocolAddress)
+ }
+ }
+ }
+
+ gotAddresses := s.AllAddresses()[nicID]
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestCreateNICWithOptions(t *testing.T) {
+ type callArgsAndExpect struct {
+ nicID tcpip.NICID
+ opts stack.NICOptions
+ err *tcpip.Error
+ }
+
+ tests := []struct {
+ desc string
+ calls []callArgsAndExpect
+ }{
+ {
+ desc: "DuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth1"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth2"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "DuplicateName",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "lo"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{Name: "lo"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "Unnamed",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ },
+ },
+ {
+ desc: "UnnamedDuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ for _, call := range test.calls {
+ if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
+ t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestNICStats(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed: ", err)
+ }
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ // Route all packets for address \x01 to NIC 1.
+ {
+ subnet, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Send a packet to address 1.
+ buf := buffer.NewView(30)
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
+ }
+
+ payload := buffer.NewView(10)
+ // Write a packet out via the address for NIC 1
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
+ want := uint64(ep1.Drain())
+ if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+}
+
+func TestNICForwarding(t *testing.T) {
+ const nicID1 = 1
+ const nicID2 = 2
+ const dstAddr = tcpip.Address("\x03")
+
+ tests := []struct {
+ name string
+ headerLen uint16
+ }{
+ {
+ name: "Zero header length",
+ },
+ {
+ name: "Non-zero header length",
+ headerLen: 16,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(true)
+
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
+ }
+
+ ep2 := channelLinkWithHeaderLength{
+ Endpoint: channel.New(10, defaultMTU, ""),
+ headerLength: test.headerLen,
+ }
+ if err := s.CreateNIC(nicID2, &ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
+ }
+
+ // Route all packets to dstAddr to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
+ }
+
+ // Send a packet to dstAddr.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the link's MaxHeaderLength is honoured.
+ if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
+ t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ }
+
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
+// TestNICContextPreservation tests that the context data passed via
+// NICOptions.Context to stack.CreateNICWithOptions can be read back via
+// stack.NICInfo.
+func TestNICContextPreservation(t *testing.T) {
+ var ctx *int
+ tests := []struct {
+ name string
+ opts stack.NICOptions
+ want stack.NICContext
+ }{
+ {
+ "context_set",
+ stack.NICOptions{Context: ctx},
+ ctx,
+ },
+ {
+ "context_not_set",
+ stack.NICOptions{},
+ nil,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ id := tcpip.NICID(1)
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
+ t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
+ }
+ nicinfos := s.NICInfo()
+ nicinfo, ok := nicinfos[id]
+ if !ok {
+ t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
+ }
+ if got, want := nicinfo.Context == test.want, true; got != want {
+ t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ }
+ })
+ }
+}
+
+// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local
+// addresses.
+func TestNICAutoGenLinkLocalAddr(t *testing.T) {
+ const nicID = 1
+
+ var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
+ n, err := rand.Read(secretKey[:])
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
+ }
+
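+ // nicNameFunc simply echoes the NIC name; it is used below as the
+ // NICNameFromID hook for the opaque IID test cases.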
+ nicNameFunc := func(_ tcpip.NICID, name string) string {
+ return name
+ }
+
+ tests := []struct {
+ name string
+ nicName string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ iidOpts stack.OpaqueInterfaceIdentifierOptions
+ shouldGen bool
+ expectedAddr tcpip.Address
+ }{
+ {
+ name: "Disabled",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ shouldGen: false,
+ },
+ {
+ name: "Disabled without OIID options",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:],
+ },
+ shouldGen: false,
+ },
+
+ // Tests for EUI64 based addresses.
+ {
+ name: "EUI64 Enabled",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddr(linkAddr1),
+ },
+ {
+ name: "EUI64 Empty MAC",
+ autoGen: true,
+ shouldGen: false,
+ },
+ {
+ name: "EUI64 Invalid MAC",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03",
+ shouldGen: false,
+ },
+ {
+ name: "EUI64 Multicast MAC",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03\x04\x05\x06",
+ shouldGen: false,
+ },
+ {
+ name: "EUI64 Unspecified MAC",
+ autoGen: true,
+ linkAddr: "\x00\x00\x00\x00\x00\x00",
+ shouldGen: false,
+ },
+
+ // Tests for Opaque IID based addresses.
+ {
+ name: "OIID Enabled",
+ nicName: "nic1",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]),
+ },
+ // These are all cases where we would not have generated a
+ // link-local address if opaque IIDs were disabled.
+ {
+ name: "OIID Empty MAC and empty nicName",
+ autoGen: true,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:1],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]),
+ },
+ {
+ name: "OIID Invalid MAC",
+ nicName: "test",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:2],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]),
+ },
+ {
+ name: "OIID Multicast MAC",
+ nicName: "test2",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03\x04\x05\x06",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:3],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]),
+ },
+ {
+ name: "OIID Unspecified MAC and nil SecretKey",
+ nicName: "test3",
+ autoGen: true,
+ linkAddr: "\x00\x00\x00\x00\x00\x00",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
+ }
+
+ e := channel.New(0, 1280, test.linkAddr)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ // A new disabled NIC should not have any address, even if auto-generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ var expectedMainAddr tcpip.AddressWithPrefix
+ if test.shouldGen {
+ expectedMainAddr = tcpip.AddressWithPrefix{
+ Address: test.expectedAddr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ }
+
+ // Should have auto-generated an address and resolved immediately (DAD
+ // is disabled).
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ } else {
+ // Should not have auto-generated an address.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address")
+ default:
+ }
+ }
+
+ gotMainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if gotMainAddr != expectedMainAddr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
+ }
+ })
+ }
+}
+
+// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
+// not auto-generated for loopback NICs.
+func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
+ const nicID = 1
+ const nicName = "nicName"
+
+ tests := []struct {
+ name string
+ opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ }{
+ {
+ name: "IID From MAC",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ },
+ {
+ name: "Opaque IID",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ }
+
+ e := loopback.New()
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrDoesDAD tests that an auto-generated IPv6 link-local
+// address is only assigned after the DAD process resolves successfully.
+func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ }
+
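+ // Size the link endpoint's queue to hold every DAD Neighbor Solicitation so
+ // none are dropped.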
+ e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+}
+
+// TestNewPEBOnPromotionToPermanent tests that a new PrimaryEndpointBehavior
+// value (peb) is respected when an address's kind gets "promoted" to permanent
+// from permanentExpired.
+func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ pebs := []stack.PrimaryEndpointBehavior{
+ stack.NeverPrimaryEndpoint,
+ stack.CanBePrimaryEndpoint,
+ stack.FirstPrimaryEndpoint,
+ }
+
+ for _, pi := range pebs {
+ for _, ps := range pebs {
+ t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ // Add a permanent address with initial
+ // PrimaryEndpointBehavior (peb), pi. If pi is
+ // NeverPrimaryEndpoint, the address should not
+ // be returned by a call to GetMainNICAddress;
+ // else, it should.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ addr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if pi == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatalf("NewSubnet failed: %v", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Take a route through the address so its ref
+ // count gets incremented and does not actually
+ // get deleted when RemoveAddress is called
+ // below. This is because we want to test that a
+ // new peb is respected when an address gets
+ // "promoted" to permanent from a
+ // permanentExpired kind.
+ r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute failed: %v", err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(1, "\x01"); err != nil {
+ t.Fatalf("RemoveAddress failed: %v", err)
+ }
+
+ //
+ // At this point, the address should still be
+ // known by the NIC, but have its
+ // kind = permanentExpired.
+ //
+
+ // Add some other address with peb set to
+ // FirstPrimaryEndpoint.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
+
+ }
+
+ // Add back the address we removed earlier and
+ // make sure the new peb was respected.
+ // (The address should just be promoted now).
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
+ }
+ var primaryAddrs []tcpip.Address
+ for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
+ }
+ var expectedList []tcpip.Address
+ switch ps {
+ case stack.FirstPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x01",
+ "\x03",
+ }
+ case stack.CanBePrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ "\x01",
+ }
+ case stack.NeverPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ }
+ }
+ if !cmp.Equal(primaryAddrs, expectedList) {
+ t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
+ }
+
+ // Once we remove the other address, if the new
+ // peb, ps, was NeverPrimaryEndpoint, no address
+ // should be returned by a call to
+ // GetMainNICAddress; else, our original address
+ // should be returned.
+ if err := s.RemoveAddress(1, "\x03"); err != nil {
+ t.Fatalf("RemoveAddress failed: %v", err)
+ }
+ addr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("s.GetMainNICAddress failed: %v", err)
+ }
+ if ps == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else {
+ if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+ }
+ })
+ }
+ }
+}
+
+func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
+ const (
+ linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ nicID = 1
+ lifetimeSeconds = 9999
+ )
+
+ prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1)
+
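+ // Precompute the temporary SLAAC addresses the stack is expected to generate
+ // for each prefix, mirroring the stack's (nil-seeded) temp IID history.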
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address
+ tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address
+
+ // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
+ tests := []struct {
+ name string
+ slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
+ nicAddrs []tcpip.Address
+ slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
+ connectAddr tcpip.Address
+ expectedLocalAddr tcpip.Address
+ }{
+ // Test Rule 1 of RFC 6724 section 5.
+ {
+ name: "Same Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test Rule 2 of RFC 6724 section 5.
+ {
+ name: "Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred for link local multicast (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred for link local multicast (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test Rule 7 of RFC 6724 section 5.
+ {
+ name: "Temp Global most preferred (last address)",
+ slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: tempGlobalAddr1,
+ },
+ {
+ name: "Temp Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
+ connectAddr: globalAddr2,
+ expectedLocalAddr: tempGlobalAddr1,
+ },
+
+ // Test returning the endpoint that is closest to the front when
+ // candidate addresses are "equal" from the perspective of RFC 6724
+ // section 5.
+ {
+ name: "Unique Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Link Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local for Unique Local",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Temp Global for Global",
+ slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
+ slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
+ connectAddr: globalAddr1,
+ expectedLocalAddr: tempGlobalAddr2,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDispatcher{},
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
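+ // Seed the link-address cache for the gateway so connections do not block on
+ // link-address resolution.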
+
+ if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
+ }
+
+ for _, a := range test.nicAddrs {
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
+ t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ }
+ }
+
+ if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) {
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ })
+ }
+}
+
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 1 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
+ }
+ want := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+ if allNICAddrs[0] != want {
+ t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+}
+
+// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
+// address after leaving its solicited node multicast address does not result in
+// an error.
+func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+
+ // The NIC should have joined addr1's solicited node multicast address.
+ snmc := header.SolicitedNodeAddr(addr1)
+ in, err := s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if !in {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
+ }
+
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
+ }
+ in, err = s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if in {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
+ }
+
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+}
+
+func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // The NIC should not be in the IPv6 all-nodes multicast group yet because it
+ // has not been enabled.
+ isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+}
+
+// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
+// was disabled have DAD performed on them when the NIC is enabled.
+func TestDoDADWhenNICEnabled(t *testing.T) {
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(dadTransmits, 1280, linkAddr1)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 128,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ // Address should be in the list of all addresses.
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should be tentative so it should not be a main address.
+ got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Enabling the NIC should start DAD for the address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+
+ // Enabling the NIC again should be a no-op.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+}
+
+func TestStackReceiveBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ rs stack.ReceiveBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid Configurations
+ {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.rs); err != tc.err {
+ t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
+ }
+ var rs stack.ReceiveBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&rs); err != nil {
+ t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
+ }
+ if got, want := rs, tc.rs; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestStackSendBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ ss stack.SendBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid Configurations
+ {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.ss); err != tc.err {
+ t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ }
+ var ss stack.SendBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&ss); err != nil {
+ t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
+ }
+ if got, want := ss, tc.ss; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
new file mode 100644
index 000000000..b902c6ca9
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -0,0 +1,686 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+)
+
+type protocolIDs struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+}
+
+// transportEndpoints manages all endpoints of a given protocol. It has its own
+// mutex so as to reduce interference between protocols.
+type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
+ mu sync.RWMutex
+ endpoints map[TransportEndpointID]*endpointsByNIC
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []RawTransportEndpoint
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
+ return
+ }
+ delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+ es := make([]TransportEndpoint, 0, len(eps.endpoints))
+ for _, e := range eps.endpoints {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+}
+
+// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
+// descending order of match quality.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
+ var matchedEPs []*endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given id.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
+ var matchedEP *endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
+}
+
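+// endpointsByNIC groups the endpoints sharing a TransportEndpointID, keyed by
+// the NIC they are bound to. The key 0 holds endpoints that are not bound to
+// any particular device.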
+type endpointsByNIC struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+ var eps []TransportEndpoint
+ for _, ep := range epsByNIC.endpoints {
+ eps = append(eps, ep.transportEndpoints()...)
+ }
+ return eps
+}
+
+// handlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+ epsByNIC.mu.RLock()
+
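+ // Prefer endpoints bound to the receiving NIC; fall back to endpoints that
+ // are not bound to any device (key 0).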
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if isMulticastOrBroadcast(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, pkt)
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ // multiPortEndpoints are guaranteed to have at least one element.
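+ // selectEndpoint hashes the id with this group's random seed to pick one of
+ // the registered endpoints, so a given flow consistently maps to the same
+ // endpoint.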
+ transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
+ queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ epsByNIC.mu.RUnlock()
+ return
+ }
+
+ transEP.HandlePacket(r, id, pkt)
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// handleControlPacket delivers a control packet to the transport endpoint that
+// most closely matches id on NIC n.
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+
+ mpep, ok := epsByNIC.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNIC.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
+}
+
+// registerEndpoint registers t with the multiPortEndpoint for bindToDevice,
+// creating one if needed. It returns an error if t cannot be registered, e.g.
+// because a conflicting endpoint is already bound with incompatible flags.
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
+
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ multiPortEp = &multiPortEndpoint{
+ demux: d,
+ netProto: netProto,
+ transProto: transProto,
+ }
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
+ }
+
+ return multiPortEp.singleRegisterEndpoint(t, flags)
+}
+
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ return nil
+ }
+
+ return multiPortEp.singleCheckEndpoint(flags)
+}
+
+// unregisterEndpoint returns true if epsByNIC is left empty and should itself
+// be unregistered.
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t, flags) {
+ delete(epsByNIC.endpoints, bindToDevice)
+ }
+ return len(epsByNIC.endpoints) == 0
+}
+
+// transportDemuxer demultiplexes packets targeted at a transport endpoint
+// (i.e., after they've been parsed by the network layer). It does two levels
+// of demultiplexing: first based on the network and transport protocols, then
+// based on endpoint IDs. It should only be instantiated via
+// newTransportDemuxer.
+type transportDemuxer struct {
+ // protocol is immutable.
+ protocol map[protocolIDs]*transportEndpoints
+ queuedProtocols map[protocolIDs]queuedTransportProtocol
+}
+
+// queuedTransportProtocol, if implemented by a transport protocol, causes the
+// dispatcher to deliver packets to the protocol's QueuePacket method instead
+// of calling HandlePacket directly on the endpoint.
+type queuedTransportProtocol interface {
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+}
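+
+// A transport protocol opts in to queued delivery simply by implementing
+// queuedTransportProtocol; newTransportDemuxer below detects this with a type
+// assertion and records the protocol in queuedProtocols.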
+
+func newTransportDemuxer(stack *Stack) *transportDemuxer {
+ d := &transportDemuxer{
+ protocol: make(map[protocolIDs]*transportEndpoints),
+ queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
+ }
+
+ // Add each network and transport pair to the demuxer.
+ for netProto := range stack.networkProtocols {
+ for proto := range stack.transportProtocols {
+ protoIDs := protocolIDs{netProto, proto}
+ d.protocol[protoIDs] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]*endpointsByNIC),
+ }
+ qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
+ if isQueued {
+ d.queuedProtocols[protoIDs] = qTransProto
+ }
+ }
+ }
+
+ return d
+}
+
+// registerEndpoint registers the given endpoint with the dispatcher such that
+// packets that match the endpoint ID are delivered to it.
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// checkEndpoint checks if an endpoint can be registered with the dispatcher.
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ for _, n := range netProtos {
+ if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port. endpoints always has at least one
+// element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but do not
+// use the restored copy.
+//
+// +stateify savable
+type multiPortEndpoint struct {
+ mu sync.RWMutex `state:"nosave"`
+ demux *transportDemuxer
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ // endpoints stores the transport endpoints in the order in which they
+ // were bound. This is required for UDP SO_REUSEADDR.
+ endpoints []TransportEndpoint
+ flags ports.FlagCounter
+}
+
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpoints...)
+ ep.mu.RUnlock()
+ return eps
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+ return uint32((uint64(val) * uint64(n)) >> 32)
+}
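+
+// As a worked example (illustrative only): with n == 4 buckets,
+//
+//	reciprocalScale(0x80000000, 4) // == 2, the midpoint of the hash space
+//	reciprocalScale(0xffffffff, 4) // == 3, the top of the hash space
+//
+// so hash values are spread proportionally across [0, 4) with one multiply
+// and one shift instead of a modulo.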
+
+// selectEndpoint calculates a hash of the destination and source addresses
+// and ports, then uses it to select an endpoint. As a result, all packets
+// from a given source are delivered to the same endpoint.
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpoints) == 1 {
+ return mpep.endpoints[0]
+ }
+
+ if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ return mpep.endpoints[len(mpep.endpoints)-1]
+ }
+
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8),
+ }
+
+ h := jenkins.Sum32(seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+ hash := h.Sum32()
+
+ idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
+ return mpep.endpoints[idx]
+}
+
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+ ep.mu.RLock()
+ queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
+ // HandlePacket takes ownership of pkt, so each endpoint needs
+ // its own copy except for the final one.
+ for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ } else {
+ endpoint.HandlePacket(r, id, pkt.Clone())
+ }
+ }
+ if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ } else {
+ endpoint.HandlePacket(r, id, pkt)
+ }
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list may still be empty when this is called.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ ep.endpoints = append(ep.endpoints, t)
+ ep.flags.AddRef(bits)
+
+ return nil
+}
+
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ return nil
+}
+
+// unregisterEndpoint returns true if the multiPortEndpoint is now empty and
+// should be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ for i, endpoint := range ep.endpoints {
+ if endpoint == t {
+ copy(ep.endpoints[i:], ep.endpoints[i+1:])
+ ep.endpoints[len(ep.endpoints)-1] = nil
+ ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
+
+ ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
+ break
+ }
+ }
+ return len(ep.endpoints) == 0
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ epsByNIC = &endpointsByNIC{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
+ }
+ eps.endpoints[id] = epsByNIC
+ }
+
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+}
+
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return nil
+ }
+
+ return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.unregisterEndpoint(id, ep, flags, bindToDevice)
+ }
+ }
+}
+
+// deliverPacket attempts to find one or more matching transport endpoints, and
+// then, if matches are found, delivers the packet to them. Returns true if
+// the packet no longer needs to be handled.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // If the packet is a UDP broadcast or multicast, then find all matching
+ // transport endpoints.
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ eps.mu.RLock()
+ destEPs := eps.findAllEndpointsLocked(id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
+ }
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
+ }
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
+ }
+
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
+ }
+
+ eps.mu.RLock()
+ ep := eps.findEndpointLocked(id)
+ eps.mu.RUnlock()
+ if ep == nil {
+ if protocol == header.UDPProtocolNumber {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ }
+ return false
+ }
+ ep.handlePacket(r, id, pkt)
+ return true
+}
+
+// deliverRawPacket attempts to deliver the given packet and returns whether it
+// was delivered successfully.
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multiple raw endpoints, they all
+ // receive the packet.
+ foundRaw := false
+ eps.mu.RLock()
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, pkt)
+ foundRaw = true
+ }
+ eps.mu.RUnlock()
+
+ return foundRaw
+}
+
+// deliverControlPacket attempts to deliver the given control packet. Returns
+// true if it found an endpoint, false otherwise.
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{net, trans}]
+ if !ok {
+ return false
+ }
+
+ eps.mu.RLock()
+ ep := eps.findEndpointLocked(id)
+ eps.mu.RUnlock()
+ if ep == nil {
+ return false
+ }
+
+ ep.handleControlPacket(n, id, typ, extra, pkt)
+ return true
+}
+
+// findTransportEndpoint finds the single endpoint that most closely matches
+// the provided id.
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.RLock()
+ epsByNIC := eps.findEndpointLocked(id)
+ if epsByNIC == nil {
+ eps.mu.RUnlock()
+ return nil
+ }
+
+ epsByNIC.mu.RLock()
+ eps.mu.RUnlock()
+
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return nil
+ }
+ }
+
+ ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ epsByNIC.mu.RUnlock()
+ return ep
+}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ eps.mu.Lock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+ eps.mu.Unlock()
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given transport
+// protocol such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
+ }
+
+ eps.mu.Lock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ lastIdx := len(eps.rawEndpoints) - 1
+ eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
+ eps.rawEndpoints[lastIdx] = nil
+ eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
+ break
+ }
+ }
+ eps.mu.Unlock()
+}
+
+func isMulticastOrBroadcast(addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+}
+
+func isUnicast(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+}
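+
+// For example (illustrative only), isUnicast returns false for the IPv4
+// broadcast address 255.255.255.255 and for any multicast address, and true
+// for an ordinary host address such as 10.0.0.1.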
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..73dada928
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,390 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ testSrcAddrV4 = "\x0a\x00\x00\x01"
+ testDstAddrV4 = "\x0a\x00\x00\x02"
+
+ testDstPort = 1234
+ testSrcPort = 4096
+)
+
+type testContext struct {
+ linkEps map[tcpip.NICID]*channel.Endpoint
+ s *stack.Stack
+ wq waiter.Queue
+}
+
+// newDualTestContextMultiNIC creates a test context with one NIC for each ID
+// in linkEpIDs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+ linkEps[linkEpID] = channelEp
+
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %s", err)
+ }
+
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %s", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: header.IPv4EmptySubnet, NIC: 1},
+ {Destination: header.IPv6EmptySubnet, NIC: 1},
+ })
+
+ return &testContext{
+ s: s,
+ linkEps: linkEps,
+ }
+}
+
+type headers struct {
+ srcPort, dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: 0x80,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: testSrcAddrV4,
+ DstAddr: testDstAddrV4,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
+ }
+ })
+ }
+}
+
+// TestBindToDeviceDistribution injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestBindToDeviceDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse bool
+ bindToDevice tcpip.NICID
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will receive the injected packets.
+ endpoints []endpointSockopts
+ // wantDistributions is the expected ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantDistributions map[tcpip.NICID][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ },
+ map[tcpip.NICID][]float64{
+ // Injected packets on NIC 1 get distributed evenly.
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ {reuse: false, bindToDevice: 1},
+ {reuse: false, bindToDevice: 2},
+ {reuse: false, bindToDevice: 3},
+ },
+ map[tcpip.NICID][]float64{
+ // Injected packets on NIC 1 go only to the endpoint bound to NIC 1.
+ 1: {1, 0, 0},
+ // Injected packets on NIC 2 go only to the endpoint bound to NIC 2.
+ 2: {0, 1, 0},
+ // Injected packets on NIC 3 go only to the endpoint bound to NIC 3.
+ 3: {0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 0},
+ },
+ map[tcpip.NICID][]float64{
+ // Injected packets on NIC 1 get distributed among the endpoints
+ // bound to NIC 1.
+ 1: {0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on NIC 2 get distributed among the endpoints
+ // bound to NIC 2; the unbound endpoint receives nothing, since
+ // bound endpoints take precedence.
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on NIC 1000, where nothing is bound, go only to
+ // the unbound endpoint.
+ 1000: {0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ } {
+ for device, wantDistribution := range test.wantDistributions {
+ t.Run(fmt.Sprintf("%s%s%d", test.name, protoName, device), func(t *testing.T) {
+ var devices []tcpip.NICID
+ for d := range test.wantDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
+ t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
+ }
+
+ var dstAddr tcpip.Address
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ dstAddr = testDstAddrV4
+ case ipv6.ProtocolNumber:
+ dstAddr = testDstAddrV6
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ hdrs := &headers{
+ srcPort: testSrcPort + port,
+ dstPort: testDstPort,
+ }
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ c.sendV4Packet(payload, hdrs, device)
+ case ipv6.ProtocolNumber:
+ c.sendV6Packet(payload, hdrs, device)
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
+
+ ep := <-pollChannel
+ if _, _, err := ep.Read(nil); err != nil {
+ t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantRatio := wantDistribution[i]
+ wantRecv := wantRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // Allow the actual ratio to deviate from the expected ratio by
+ // at most 5 percentage points.
+ if math.Abs(actualRatio-wantRatio) > 0.05 {
+ t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
new file mode 100644
index 000000000..7e8b84867
--- /dev/null
+++ b/pkg/tcpip/stack/transport_test.go
@@ -0,0 +1,664 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ fakeTransNumber tcpip.TransportProtocolNumber = 1
+ fakeTransHeaderLen = 3
+)
+
+// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
+// received packets; the counts of all endpoints are aggregated in the protocol
+// descriptor.
+//
+// Headers of this protocol are fakeTransHeaderLen bytes, but we currently
+// don't use their contents.
+type fakeTransportEndpoint struct {
+ stack.TransportEndpointInfo
+ stack *stack.Stack
+ proto *fakeTransportProtocol
+ peerAddr tcpip.Address
+ route stack.Route
+ uniqueID uint64
+
+ // acceptQueue is non-nil iff bound.
+ acceptQueue []fakeTransportEndpoint
+}
+
+func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
+ return &f.TransportEndpointInfo
+}
+
+func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
+ return nil
+}
+
+func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+}
+
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
+func (f *fakeTransportEndpoint) Close() {
+ f.route.Release()
+}
+
+func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mask
+}
+
+func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return buffer.View{}, tcpip.ControlMessages{}, nil
+}
+
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ if len(f.route.RemoteAddress) == 0 {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
+ hdr.Prepend(fakeTransHeaderLen)
+ v, err := p.FullPayload()
+ if err != nil {
+ return 0, nil, err
+ }
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buffer.View(v).ToVectorisedView(),
+ }); err != nil {
+ return 0, nil, err
+ }
+
+ return int64(len(v)), nil, nil
+}
+
+func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// SetSockOptBool sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ return -1, tcpip.ErrUnknownProtocolOption
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+ }
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ f.peerAddr = addr.Addr
+
+ // Find the route.
+ r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ return tcpip.ErrNoRoute
+ }
+ defer r.Release()
+
+ // Try to register so that we can start receiving packets.
+ f.ID.RemoteAddress = addr.Addr
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ if err != nil {
+ return err
+ }
+
+ f.route = r.Clone()
+
+ return nil
+}
+
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
+func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+ return nil
+}
+
+func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
+ return nil
+}
+
+func (*fakeTransportEndpoint) Reset() {
+}
+
+func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
+ return nil
+}
+
+func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ if len(f.acceptQueue) == 0 {
+ return nil, nil, nil
+ }
+ a := f.acceptQueue[0]
+ f.acceptQueue = f.acceptQueue[1:]
+ return &a, nil, nil
+}
+
+func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
+ if err := f.stack.RegisterTransportEndpoint(
+ a.NIC,
+ []tcpip.NetworkProtocolNumber{fakeNetNumber},
+ fakeTransNumber,
+ stack.TransportEndpointID{LocalAddress: a.Addr},
+ f,
+ ports.Flags{},
+ 0, /* bindToDevice */
+ ); err != nil {
+ return err
+ }
+ f.acceptQueue = []fakeTransportEndpoint{}
+ return nil
+}
+
+func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
+ // Increment the number of received packets.
+ f.proto.packetCount++
+ if f.acceptQueue != nil {
+ f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ stack: f.stack,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
+ proto: f.proto,
+ peerAddr: r.RemoteAddress,
+ route: r.Clone(),
+ })
+ }
+}
+
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
+ // Increment the number of received control packets.
+ f.proto.controlCount++
+}
+
+func (f *fakeTransportEndpoint) State() uint32 {
+ return 0
+}
+
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
+
+func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
+ return stack.IPTables{}, nil
+}
+
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+
+func (f *fakeTransportEndpoint) Wait() {}
+
+type fakeTransportGoodOption bool
+
+type fakeTransportBadOption bool
+
+type fakeTransportInvalidValueOption int
+
+type fakeTransportProtocolOptions struct {
+ good bool
+}
+
+// fakeTransportProtocol is a transport-layer protocol descriptor. It
+// aggregates the number of packets received via endpoints of this protocol.
+type fakeTransportProtocol struct {
+ packetCount int
+ controlCount int
+ opts fakeTransportProtocolOptions
+}
+
+func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
+ return fakeTransNumber
+}
+
+func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
+}
+
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
+func (*fakeTransportProtocol) MinimumPacketSize() int {
+ return fakeTransHeaderLen
+}
+
+func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
+ return 0, 0, nil
+}
+
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
+ return true
+}
+
+func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case fakeTransportGoodOption:
+ f.opts.good = bool(v)
+ return nil
+ case fakeTransportInvalidValueOption:
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *fakeTransportGoodOption:
+ *v = fakeTransportGoodOption(f.opts.good)
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
+// Parse implements TransportProtocol.Parse.
+func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
+ if !ok {
+ return false
+ }
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(fakeTransHeaderLen)
+ return true
+}
+
+func fakeTransFactory() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+}
+
+func TestTransportReceive(t *testing.T) {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ // Create endpoint and connect to remote address.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+ // Create buffer that will hold the packet.
+ buf := buffer.NewView(30)
+
+ // Make sure packet with wrong protocol is not delivered.
+ buf[0] = 1
+ buf[2] = 0
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeTrans.packetCount != 0 {
+ t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
+ }
+
+ // Make sure packet from the wrong source is not delivered.
+ buf[0] = 1
+ buf[1] = 3
+ buf[2] = byte(fakeTransNumber)
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeTrans.packetCount != 0 {
+ t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
+ }
+
+ // Make sure packet is delivered.
+ buf[0] = 1
+ buf[1] = 2
+ buf[2] = byte(fakeTransNumber)
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeTrans.packetCount != 1 {
+ t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
+ }
+}
+
+func TestTransportControlReceive(t *testing.T) {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ // Create endpoint and connect to remote address.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+ // Create buffer that will hold the control packet.
+ buf := buffer.NewView(2*fakeNetHeaderLen + 30)
+
+ // Outer packet contains the control protocol number.
+ buf[0] = 1
+ buf[1] = 0xfe
+ buf[2] = uint8(fakeControlProtocol)
+
+ // Make sure packet with wrong protocol is not delivered.
+ buf[fakeNetHeaderLen+0] = 0
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = 0
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeTrans.controlCount != 0 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
+ }
+
+ // Make sure packet from the wrong source is not delivered.
+ buf[fakeNetHeaderLen+0] = 3
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeTrans.controlCount != 0 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
+ }
+
+ // Make sure packet is delivered.
+ buf[fakeNetHeaderLen+0] = 2
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ if fakeTrans.controlCount != 1 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
+ }
+}
+
+func TestTransportSend(t *testing.T) {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Create endpoint and bind it.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Create buffer that will hold the payload.
+ view := buffer.NewView(30)
+ _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ if fakeNet.sendPacketCount[2] != 1 {
+ t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
+ }
+}
+
+func TestTransportOptions(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+
+ // Try an unsupported transport protocol.
+ if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
+ t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+ }
+
+ testCases := []struct {
+ option interface{}
+ wantErr *tcpip.Error
+ verifier func(t *testing.T, p stack.TransportProtocol)
+ }{
+ {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
+ t.Helper()
+ fakeTrans := p.(*fakeTransportProtocol)
+ if fakeTrans.opts.good != true {
+ t.Fatalf("fakeTrans.opts.good = false, want = true")
+ }
+ var v fakeTransportGoodOption
+ if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
+ }
+ if v != true {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
+ }
+
+ }},
+ {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
+ {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+ }
+ for _, tc := range testCases {
+ if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
+ t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
+ }
+ if tc.verifier != nil {
+ tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
+ }
+ }
+}
+
+func TestTransportForwarding(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ s.SetForwarding(true)
+
+ // TODO(b/123449044): Change this to a channel NIC.
+ ep1 := loopback.New()
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatalf("CreateNIC #1 failed: %v", err)
+ }
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress #1 failed: %v", err)
+ }
+
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatalf("CreateNIC #2 failed: %v", err)
+ }
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress #2 failed: %v", err)
+ }
+
+ // Route all packets to address 3 to NIC 2 and all packets to address
+ // 1 to NIC 1.
+ {
+ subnet0, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ })
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Send a packet to address 1 from address 3.
+ req := buffer.NewView(30)
+ req[0] = 1
+ req[1] = 3
+ req[2] = byte(fakeTransNumber)
+ ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ Data: req.ToVectorisedView(),
+ })
+
+ aep, _, err := ep.Accept()
+ if err != nil || aep == nil {
+ t.Fatalf("Accept failed: %v, %v", aep, err)
+ }
+
+ resp := buffer.NewView(30)
+ if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ p, ok := ep2.Read()
+ if !ok {
+ t.Fatal("Response packet not forwarded")
+ }
+
+ if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
+ t.Errorf("Response packet has incorrect destination address: got = %d, want = 3", dst)
+ }
+ if src := p.Pkt.NetworkHeader[1]; src != 1 {
+ t.Errorf("Response packet has incorrect source address: got = %d, want = 1", src)
+ }
+ }
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
new file mode 100644
index 000000000..25534a10d
--- /dev/null
+++ b/pkg/tcpip/tcpip.go
@@ -0,0 +1,1616 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpip provides the interfaces and related types that users of the
+// tcpip stack will use in order to create endpoints used to send and receive
+// data over the network stack.
+//
+// The starting point is the creation and configuration of a stack. A stack can
+// be created by calling the New() function of the tcpip/stack/stack package;
+// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
+// adding network addresses (via calls to Stack.AddAddress()), and
+// setting a route table (via a call to Stack.SetRouteTable()).
+//
+// Once a stack is configured, endpoints can be created by calling
+// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect
+// to peers, listen for connections, accept connections, etc., depending on the
+// transport protocol selected.
+package tcpip
+
+import (
+ "errors"
+ "fmt"
+ "math/bits"
+ "reflect"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Error represents an error in the netstack error space. Using a special type
+// ensures that errors outside of this space are not accidentally introduced.
+//
+// Note: to support save / restore, it is important that all tcpip errors have
+// distinct error messages.
+type Error struct {
+ msg string
+
+ ignoreStats bool
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ if e == nil {
+ return "<nil>"
+ }
+ return e.msg
+}
+
+// IgnoreStats indicates whether this error should be omitted from failure
+// counts in tcpip.Stats structs.
+func (e *Error) IgnoreStats() bool {
+ return e.ignoreStats
+}
+
+// Errors that can be returned by the network stack.
+var (
+ ErrUnknownProtocol = &Error{msg: "unknown protocol"}
+ ErrUnknownNICID = &Error{msg: "unknown nic id"}
+ ErrUnknownDevice = &Error{msg: "unknown device"}
+ ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
+ ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
+ ErrDuplicateAddress = &Error{msg: "duplicate address"}
+ ErrNoRoute = &Error{msg: "no route"}
+ ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
+ ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
+ ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
+ ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
+ ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
+ ErrNoPortAvailable = &Error{msg: "no ports are available"}
+ ErrPortInUse = &Error{msg: "port is in use"}
+ ErrBadLocalAddress = &Error{msg: "bad local address"}
+ ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
+ ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
+ ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
+ ErrConnectionRefused = &Error{msg: "connection was refused"}
+ ErrTimeout = &Error{msg: "operation timed out"}
+ ErrAborted = &Error{msg: "operation aborted"}
+ ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
+ ErrDestinationRequired = &Error{msg: "destination address is required"}
+ ErrNotSupported = &Error{msg: "operation not supported"}
+ ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
+ ErrNotConnected = &Error{msg: "endpoint not connected"}
+ ErrConnectionReset = &Error{msg: "connection reset by peer"}
+ ErrConnectionAborted = &Error{msg: "connection aborted"}
+ ErrNoSuchFile = &Error{msg: "no such file"}
+ ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
+ ErrNoLinkAddress = &Error{msg: "no remote link address"}
+ ErrBadAddress = &Error{msg: "bad address"}
+ ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
+ ErrMessageTooLong = &Error{msg: "message too long"}
+ ErrNoBufferSpace = &Error{msg: "no buffer space available"}
+ ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
+ ErrNotPermitted = &Error{msg: "operation not permitted"}
+ ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
+)
+
+var messageToError map[string]*Error
+
+var populate sync.Once
+
+// StringToError converts an error message to the error.
+func StringToError(s string) *Error {
+ populate.Do(func() {
+ var errors = []*Error{
+ ErrUnknownProtocol,
+ ErrUnknownNICID,
+ ErrUnknownDevice,
+ ErrUnknownProtocolOption,
+ ErrDuplicateNICID,
+ ErrDuplicateAddress,
+ ErrNoRoute,
+ ErrBadLinkEndpoint,
+ ErrAlreadyBound,
+ ErrInvalidEndpointState,
+ ErrAlreadyConnecting,
+ ErrAlreadyConnected,
+ ErrNoPortAvailable,
+ ErrPortInUse,
+ ErrBadLocalAddress,
+ ErrClosedForSend,
+ ErrClosedForReceive,
+ ErrWouldBlock,
+ ErrConnectionRefused,
+ ErrTimeout,
+ ErrAborted,
+ ErrConnectStarted,
+ ErrDestinationRequired,
+ ErrNotSupported,
+ ErrQueueSizeNotSupported,
+ ErrNotConnected,
+ ErrConnectionReset,
+ ErrConnectionAborted,
+ ErrNoSuchFile,
+ ErrInvalidOptionValue,
+ ErrNoLinkAddress,
+ ErrBadAddress,
+ ErrNetworkUnreachable,
+ ErrMessageTooLong,
+ ErrNoBufferSpace,
+ ErrBroadcastDisabled,
+ ErrNotPermitted,
+ ErrAddressFamilyNotSupported,
+ }
+
+ messageToError = make(map[string]*Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}
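+
+// For example (illustrative only), StringToError("port is in use") returns
+// ErrPortInUse, while a message that does not correspond to a known error
+// panics.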
+
+// Errors related to Subnet
+var (
+ errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
+ errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
+)
+
+// ErrSaveRejection indicates a failed save due to unsupported networking state.
+// This type of error is only used by the save logic.
+type ErrSaveRejection struct {
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported networking state: " + e.Err.Error()
+}
+
+// A Clock provides the current time.
+//
+// Times returned by a Clock should always be used for application-visible
+// time. Only monotonic times should be used for netstack internal timekeeping.
+type Clock interface {
+ // NowNanoseconds returns the current real time as a number of
+ // nanoseconds since the Unix epoch.
+ NowNanoseconds() int64
+
+ // NowMonotonic returns a monotonic time value.
+ NowMonotonic() int64
+}
+
+// Address is a byte slice cast as a string that represents the address of a
+// network node. Or, in the case of unix endpoints, it may represent a path.
+type Address string
+
+// AddressMask is a bitmask for an address.
+type AddressMask string
+
+// String implements Stringer.
+func (m AddressMask) String() string {
+ return Address(m).String()
+}
+
+// Prefix returns the number of bits before the first host bit.
+func (m AddressMask) Prefix() int {
+ p := 0
+ for _, b := range []byte(m) {
+ p += bits.LeadingZeros8(^b)
+ }
+ return p
+}
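+
+// For example (illustrative only), a standard IPv4 /24 mask:
+//
+//	AddressMask("\xff\xff\xff\x00").Prefix() // == 24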
+
+// Subnet is a subnet defined by its address and mask.
+type Subnet struct {
+ address Address
+ mask AddressMask
+}
+
+// NewSubnet creates a new Subnet, checking that the address and mask are the
+// same length and that the address has no bits set outside the mask.
+func NewSubnet(a Address, m AddressMask) (Subnet, error) {
+ if len(a) != len(m) {
+ return Subnet{}, errSubnetLengthMismatch
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i]&^m[i] != 0 {
+ return Subnet{}, errSubnetAddressMasked
+ }
+ }
+ return Subnet{a, m}, nil
+}
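+
+// For example (illustrative only), the IPv4 subnet 10.0.0.0/24:
+//
+//	s, err := NewSubnet("\x0a\x00\x00\x00", "\xff\xff\xff\x00")
+//	// err == nil; s.Contains("\x0a\x00\x00\x07") == true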
+
+// String implements Stringer.
+func (s Subnet) String() string {
+ return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
+}
+
+// Contains returns true iff the address is of the same length and matches the
+// subnet address and mask.
+func (s *Subnet) Contains(a Address) bool {
+ if len(a) != len(s.address) {
+ return false
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i]&s.mask[i] != s.address[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// ID returns the subnet ID.
+func (s *Subnet) ID() Address {
+ return s.address
+}
+
+// Bits returns the number of ones (network bits) and zeros (host bits) in the
+// subnet mask.
+func (s *Subnet) Bits() (ones int, zeros int) {
+ ones = s.mask.Prefix()
+ return ones, len(s.mask)*8 - ones
+}
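+
+// For example (illustrative only), a /24 IPv4 subnet yields (24, 8): 24
+// network bits and 8 host bits.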
+
+// Prefix returns the number of bits before the first host bit.
+func (s *Subnet) Prefix() int {
+ return s.mask.Prefix()
+}
+
+// Mask returns the subnet mask.
+func (s *Subnet) Mask() AddressMask {
+ return s.mask
+}
+
+// Broadcast returns the subnet's broadcast address.
+func (s *Subnet) Broadcast() Address {
+ addr := []byte(s.address)
+ for i := range addr {
+ addr[i] |= ^s.mask[i]
+ }
+ return Address(addr)
+}
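+
+// For example (illustrative only), the broadcast address of 10.0.0.0/24
+// ("\x0a\x00\x00\x00" with mask "\xff\xff\xff\x00") is 10.0.0.255
+// ("\x0a\x00\x00\xff").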
+
+// Equal returns true if s equals o.
+//
+// Needed to use cmp.Equal on Subnet as its fields are unexported.
+func (s Subnet) Equal(o Subnet) bool {
+ return s == o
+}
+
+// NICID is a number that uniquely identifies a NIC.
+type NICID int32
+
+// ShutdownFlags represents flags that can be passed to the Shutdown() method
+// of the Endpoint interface.
+type ShutdownFlags int
+
+// Values of the flags that can be passed to the Shutdown() method. They can
+// be OR'ed together.
+const (
+ ShutdownRead ShutdownFlags = 1 << iota
+ ShutdownWrite
+)
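+
+// For example, shutting down both directions of a hypothetical endpoint ep
+// would OR the flags together:
+//
+//	_ = ep.Shutdown(ShutdownRead | ShutdownWrite)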
+
+// FullAddress represents a full transport node address, as required by the
+// Connect() and Bind() methods.
+//
+// +stateify savable
+type FullAddress struct {
+ // NIC is the ID of the NIC this address refers to.
+ //
+ // This may not be used by all endpoint types.
+ NIC NICID
+
+ // Addr is the network or link layer address.
+ Addr Address
+
+ // Port is the transport port.
+ //
+ // This may not be used by all endpoint types.
+ Port uint16
+}
+
+// Payloader is an interface that provides data.
+//
+// This interface allows the endpoint to request the amount of data it needs
+// based on internal buffers without exposing them.
+type Payloader interface {
+ // FullPayload returns all available bytes.
+ FullPayload() ([]byte, *Error)
+
+ // Payload returns a slice containing at most size bytes.
+ Payload(size int) ([]byte, *Error)
+}
+
+// SlicePayload implements Payloader for slices.
+//
+// This is typically used for tests.
+type SlicePayload []byte
+
+// FullPayload implements Payloader.FullPayload.
+func (s SlicePayload) FullPayload() ([]byte, *Error) {
+ return s, nil
+}
+
+// Payload implements Payloader.Payload.
+func (s SlicePayload) Payload(size int) ([]byte, *Error) {
+ if size > len(s) {
+ size = len(s)
+ }
+ return s[:size], nil
+}
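+
+// As a sketch (ep is a hypothetical Endpoint), a test can write a fixed
+// buffer through the Payloader interface via SlicePayload:
+//
+//	n, _, err := ep.Write(SlicePayload("hello"), WriteOptions{})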
+
+// A ControlMessages contains socket control messages for IP sockets.
+//
+// +stateify savable
+type ControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether TOS is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo IPPacketInfo
+}
+
+// PacketOwner is used to get UID and GID of the packet.
+type PacketOwner interface {
+ // UID returns UID of the packet.
+ UID() uint32
+
+ // GID returns GID of the packet.
+ GID() uint32
+}
+
+// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
+// that exposes functionality like read, write, connect, etc. to users of the
+// networking stack.
+type Endpoint interface {
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it. Close initiates the teardown process; the
+ // Endpoint may not be fully closed when Close returns.
+ Close()
+
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
+ // Read reads data from the endpoint and optionally returns the sender.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+
+ // Write writes data to the endpoint's peer. This method does not block if
+ // the data cannot be written.
+ //
+ // Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes
+ // successfully written to the Endpoint. That is, if a call to
+ // Write(SlicePayload{data}) returns (n, err), it may retain data[:n], and
+ // the caller should not use data[:n] after Write returns.
+ //
+ // Note that unlike io.Writer.Write, it is not an error for Write to
+ // perform a partial write (if n > 0, no error is returned). Only
+ // stream (TCP) Endpoints may return partial writes, and even then only
+ // in the case where writing additional data would block. Other Endpoints
+ // will either write the entire message or return an error.
+ //
+ // For UDP and Ping sockets, if address resolution is required,
+ // ErrNoLinkAddress and a notification channel are returned for the caller
+ // to block on. The channel is closed once address resolution is complete
+ // (success or not); it is only non-nil in this case.
+ Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
+
+ // Peek reads data without consuming it from the endpoint.
+ //
+ // This method does not block if there is no data pending.
+ Peek([][]byte) (int64, ControlMessages, *Error)
+
+ // Connect connects the endpoint to its peer. Specifying a NIC is
+ // optional.
+ //
+ // There are three classes of return values:
+ // nil -- the attempt to connect succeeded.
+ // ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started
+ // but hasn't completed yet. In this case, the caller must call Connect
+ // or GetSockOpt(ErrorOption) when the endpoint becomes writable to
+ // get the actual result. The first call to Connect after the socket has
+ // connected returns nil. Calling Connect again results in ErrAlreadyConnected.
+ // Anything else -- the attempt to connect failed.
+ //
+ // If address.Addr is empty, the endpoint must be disconnected if
+ // disconnection is supported; otherwise,
+ // ErrAddressFamilyNotSupported must be returned.
+ Connect(address FullAddress) *Error
+
+ // Disconnect disconnects the endpoint from its peer.
+ Disconnect() *Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags ShutdownFlags) *Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *waiter.Queue, *Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ Bind(address FullAddress) *Error
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (FullAddress, *Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (FullAddress, *Error)
+
+ // Readiness returns the current readiness of the endpoint. For example,
+ // if waiter.EventIn is set, the endpoint is immediately readable.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt sets a socket option. opt should be one of the *Option types.
+ SetSockOpt(opt interface{}) *Error
+
+ // SetSockOptBool sets a socket option, for simple cases where a value
+ // has the bool type.
+ SetSockOptBool(opt SockOptBool, v bool) *Error
+
+ // SetSockOptInt sets a socket option, for simple cases where a value
+ // has the int type.
+ SetSockOptInt(opt SockOptInt, v int) *Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // *Option types.
+ GetSockOpt(opt interface{}) *Error
+
+ // GetSockOptBool gets a socket option for simple cases where a return
+ // value has the bool type.
+ GetSockOptBool(SockOptBool) (bool, *Error)
+
+ // GetSockOptInt gets a socket option for simple cases where a return
+ // value has the int type.
+ GetSockOptInt(SockOptInt) (int, *Error)
+
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
+ // ModerateRecvBuf should be called every time data is copied to the user
+ // space. This allows for dynamic tuning of recv buffer space for a
+ // given socket.
+ //
+ // NOTE: This method is a no-op for sockets other than TCP.
+ ModerateRecvBuf(copied int)
+
+ // Info returns a copy of the transport endpoint info.
+ Info() EndpointInfo
+
+ // Stats returns a reference to the endpoint stats.
+ Stats() EndpointStats
+
+ // SetOwner sets the task owner to the endpoint owner.
+ SetOwner(owner PacketOwner)
+}
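+
+// A non-authoritative sketch of the non-blocking connect contract described
+// by Connect above (ep, addr, and notifyCh are assumptions; notifyCh would
+// come from a waiter.Entry registered on the endpoint's queue for
+// waiter.EventOut):
+//
+//	if err := ep.Connect(addr); err == ErrConnectStarted {
+//		<-notifyCh // Wait for the endpoint to become writable.
+//		err = ep.GetSockOpt(ErrorOption{}) // Fetch the actual result.
+//	}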
+
+// EndpointInfo is the interface implemented by each endpoint info struct.
+type EndpointInfo interface {
+ // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+ // marker interface.
+ IsEndpointInfo()
+}
+
+// EndpointStats is the interface implemented by each endpoint stats struct.
+type EndpointStats interface {
+ // IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+ // marker interface.
+ IsEndpointStats()
+}
+
+// WriteOptions contains options for Endpoint.Write.
+type WriteOptions struct {
+ // If To is not nil, write to the given address instead of the endpoint's
+ // peer.
+ To *FullAddress
+
+ // More has the same semantics as Linux's MSG_MORE.
+ More bool
+
+ // EndOfRecord has the same semantics as Linux's MSG_EOR.
+ EndOfRecord bool
+
+ // Atomic means that all data fetched from Payloader must be written to the
+ // endpoint. If Atomic is false, then data fetched from the Payloader may be
+ // discarded if available endpoint buffer space is insufficient.
+ Atomic bool
+}
+
+// SockOptBool represents socket options whose values have the bool type.
+type SockOptBool int
+
+const (
+ // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether datagram sockets are allowed to send packets to a broadcast
+ // address.
+ BroadcastOption SockOptBool = iota
+
+ // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be held until segments are full by the TCP transport
+ // protocol.
+ CorkOption
+
+ // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be sent out immediately by the transport protocol. For
+ // TCP, it determines if the Nagle algorithm is on or off.
+ DelayOption
+
+ // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether TCP keepalive is enabled for this socket.
+ KeepaliveEnabledOption
+
+ // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether multicast packets sent over a non-loopback interface
+ // will be looped back.
+ MulticastLoopOption
+
+ // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether UDP checksum is disabled for this socket.
+ NoChecksumOption
+
+ // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether SCM_CREDENTIALS socket control messages are enabled.
+ //
+ // Only supported on Unix sockets.
+ PasscredOption
+
+ // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
+ QuickAckOption
+
+ // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if the IPV6_TCLASS ancillary message is passed with incoming
+ // packets.
+ ReceiveTClassOption
+
+ // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
+ // if the TOS ancillary message is passed with incoming packets.
+ ReceiveTOSOption
+
+ // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if more information, such as the interface index and address,
+ // is provided with incoming packets.
+ ReceiveIPPacketInfoOption
+
+ // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether Bind() should allow reuse of local address.
+ ReuseAddressOption
+
+ // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
+ // multiple sockets to be bound to an identical socket address.
+ ReusePortOption
+
+ // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether an IPv6 socket is to be restricted to sending and receiving
+ // IPv6 packets only.
+ V6OnlyOption
+
+ // IPHdrIncludedOption is used by SetSockOptBool/GetSockOptBool to indicate
+ // for a raw endpoint that all packets being written have an IP header and the
+ // endpoint should not attach an IP header.
+ IPHdrIncludedOption
+)
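+
+// For instance (ep is a hypothetical Endpoint), restricting an IPv6 socket to
+// IPv6 traffic only maps to one of the options above:
+//
+//	if err := ep.SetSockOptBool(V6OnlyOption, true); err != nil {
+//		// Handle the error.
+//	}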
+
+// SockOptInt represents socket options whose values have the int type.
+type SockOptInt int
+
+const (
+ // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the number of un-ACKed TCP keepalives that will be sent
+ // before the connection is closed.
+ KeepaliveCountOption SockOptInt = iota
+
+ // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS
+ // for all subsequent outgoing IPv4 packets from the endpoint.
+ IPv4TOSOption
+
+ // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to
+ // specify TOS for all subsequent outgoing IPv6 packets from the
+ // endpoint.
+ IPv6TrafficClassOption
+
+ // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the
+ // current Maximum Segment Size (MSS) value as specified using the
+ // TCP_MAXSEG option.
+ MaxSegOption
+
+ // MTUDiscoverOption is used to set/get the path MTU discovery setting.
+ //
+ // NOTE: Setting this option to any other value than PMTUDiscoveryDont
+ // is not supported and will fail as such, and getting this option will
+ // always return PMTUDiscoveryDont.
+ MTUDiscoverOption
+
+ // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control
+ // the default TTL value for multicast messages. The default is 1.
+ MulticastTTLOption
+
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
+ ReceiveQueueSizeOption
+
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
+ // TTLOption is used by SetSockOptInt/GetSockOptInt to control the
+ // default TTL/hop limit value for unicast messages. The default is
+ // protocol specific.
+ //
+ // A zero value indicates the default.
+ TTLOption
+
+ // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify
+ // the number of SYN retransmits that TCP should send before aborting
+ // the attempt to connect. It cannot exceed 255.
+ //
+ // NOTE: This option is currently only stubbed out and is a no-op.
+ TCPSynCountOption
+
+ // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound
+ // the size of the advertised window to this value.
+ //
+ // NOTE: This option is currently only stubbed out and is a no-op.
+ TCPWindowClampOption
+)
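+
+// For instance (ep is a hypothetical Endpoint), querying the number of unread
+// bytes pending in the receive buffer:
+//
+//	pending, err := ep.GetSockOptInt(ReceiveQueueSizeOption)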
+
+const (
+ // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use
+ // per-route settings.
+ PMTUDiscoveryWant int = iota
+
+ // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable
+ // path MTU discovery.
+ PMTUDiscoveryDont
+
+ // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do
+ // path MTU discovery.
+ PMTUDiscoveryDo
+
+ // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF
+ // but ignore path MTU.
+ PMTUDiscoveryProbe
+)
+
+// ErrorOption is used in GetSockOpt to specify that the last error reported by
+// the endpoint should be cleared and returned.
+type ErrorOption struct{}
+
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption NICID
+
+// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
+//
+// TODO(b/64800844): Add and populate stat fields.
+type TCPInfoOption struct {
+ RTT time.Duration
+ RTTVar time.Duration
+}
+
+// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
+// connection must remain idle before the first TCP keepalive packet is sent.
+// Once this time is reached, KeepaliveIntervalOption is used instead.
+type KeepaliveIdleOption time.Duration
+
+// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the
+// interval between sending TCP keepalive packets.
+type KeepaliveIntervalOption time.Duration
+
+// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user
+// timeout for a given TCP connection.
+// See RFC 5482 for details.
+type TCPUserTimeoutOption time.Duration
+
+// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
+// the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
+// ModerateReceiveBufferOption is used by SetSockOpt/GetSockOpt to allow the
+// caller to enable or disable TCP receive buffer moderation.
+type ModerateReceiveBufferOption bool
+
+// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
+// before being marked closed.
+type TCPLingerTimeoutOption time.Duration
+
+// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum duration for which a socket lingers in the TIME_WAIT state
+// before being marked closed.
+type TCPTimeWaitTimeoutOption time.Duration
+
+// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow an
+// accept to return a completed connection only when there is data to be
+// read. This usually means the listening socket will drop the final ACK
+// of the handshake until a segment with data arrives or the specified
+// timeout expires.
+type TCPDeferAcceptOption time.Duration
+
+// TCPMinRTOOption is used by SetSockOpt/GetSockOpt to allow overriding the
+// default MinRTO used by the Stack.
+type TCPMinRTOOption time.Duration
+
+// TCPMaxRTOOption is used by SetSockOpt/GetSockOpt to allow overriding the
+// default MaxRTO used by the Stack.
+type TCPMaxRTOOption time.Duration
+
+// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum number of retransmits after which we time out the connection.
+type TCPMaxRetriesOption uint64
+
+// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
+// the number of endpoints that can be in SYN-RCVD state before the stack
+// switches to using SYN cookies.
+type TCPSynRcvdCountThresholdOption uint64
+
+// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
+// default for number of times SYN is retransmitted before aborting a connect.
+type TCPSynRetriesOption uint8
+
+// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
+// default interface for multicast.
+type MulticastInterfaceOption struct {
+ NIC NICID
+ InterfaceAddr Address
+}
+
+// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
+// AddMembershipOption and RemoveMembershipOption.
+type MembershipOption struct {
+ NIC NICID
+ InterfaceAddr Address
+ MulticastAddr Address
+}
+
+// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type AddMembershipOption MembershipOption
+
+// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type RemoveMembershipOption MembershipOption
+
+// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
+// TCP out-of-band data is delivered along with the normal in-band data.
+type OutOfBandInlineOption int
+
+// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
+// a default TTL.
+type DefaultTTLOption uint8
+
+// IPPacketInfo is the message structure for IP_PKTINFO.
+//
+// +stateify savable
+type IPPacketInfo struct {
+ // NIC is the ID of the NIC to be used.
+ NIC NICID
+
+ // LocalAddr is the local address.
+ LocalAddr Address
+
+ // DestinationAddr is the destination address.
+ DestinationAddr Address
+}
+
+// Route is a row in the routing table. It specifies through which NIC (and
+// gateway) sets of packets should be routed. A row is considered viable if the
+// masked target address matches the destination address in the row.
+type Route struct {
+ // Destination must contain the target address for this row to be viable.
+ Destination Subnet
+
+ // Gateway is the gateway to be used if this row is viable.
+ Gateway Address
+
+ // NIC is the ID of the NIC to be used if this row is viable.
+ NIC NICID
+}
+
+// String implements the fmt.Stringer interface.
+func (r Route) String() string {
+ var out strings.Builder
+ fmt.Fprintf(&out, "%s", r.Destination)
+ if len(r.Gateway) > 0 {
+ fmt.Fprintf(&out, " via %s", r.Gateway)
+ }
+ fmt.Fprintf(&out, " nic %d", r.NIC)
+ return out.String()
+}
+
+// TransportProtocolNumber is the number of a transport protocol.
+type TransportProtocolNumber uint32
+
+// NetworkProtocolNumber is the number of a network protocol.
+type NetworkProtocolNumber uint32
+
+// A StatCounter keeps track of a statistic.
+type StatCounter struct {
+ count uint64
+}
+
+// Increment adds one to the counter.
+func (s *StatCounter) Increment() {
+ s.IncrementBy(1)
+}
+
+// Decrement subtracts one from the counter.
+func (s *StatCounter) Decrement() {
+ s.IncrementBy(^uint64(0))
+}
+
+// Value returns the current value of the counter.
+func (s *StatCounter) Value() uint64 {
+ return atomic.LoadUint64(&s.count)
+}
+
+// IncrementBy increments the counter by v.
+func (s *StatCounter) IncrementBy(v uint64) {
+ atomic.AddUint64(&s.count, v)
+}
+
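+// String implements the fmt.Stringer interface.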
+func (s *StatCounter) String() string {
+ return strconv.FormatUint(s.Value(), 10)
+}
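+
+// A minimal sketch of StatCounter arithmetic; updates are atomic, so
+// concurrent use is safe:
+//
+//	var c StatCounter
+//	c.IncrementBy(3)
+//	c.Decrement()
+//	_ = c.Value() // 2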
+
+// ICMPv4PacketStats enumerates counts for all ICMPv4 packet types.
+type ICMPv4PacketStats struct {
+ // Echo is the total number of ICMPv4 echo packets counted.
+ Echo *StatCounter
+
+ // EchoReply is the total number of ICMPv4 echo reply packets counted.
+ EchoReply *StatCounter
+
+ // DstUnreachable is the total number of ICMPv4 destination unreachable
+ // packets counted.
+ DstUnreachable *StatCounter
+
+ // SrcQuench is the total number of ICMPv4 source quench packets
+ // counted.
+ SrcQuench *StatCounter
+
+ // Redirect is the total number of ICMPv4 redirect packets counted.
+ Redirect *StatCounter
+
+ // TimeExceeded is the total number of ICMPv4 time exceeded packets
+ // counted.
+ TimeExceeded *StatCounter
+
+ // ParamProblem is the total number of ICMPv4 parameter problem packets
+ // counted.
+ ParamProblem *StatCounter
+
+ // Timestamp is the total number of ICMPv4 timestamp packets counted.
+ Timestamp *StatCounter
+
+ // TimestampReply is the total number of ICMPv4 timestamp reply packets
+ // counted.
+ TimestampReply *StatCounter
+
+ // InfoRequest is the total number of ICMPv4 information request
+ // packets counted.
+ InfoRequest *StatCounter
+
+ // InfoReply is the total number of ICMPv4 information reply packets
+ // counted.
+ InfoReply *StatCounter
+}
+
+// ICMPv6PacketStats enumerates counts for all ICMPv6 packet types.
+type ICMPv6PacketStats struct {
+ // EchoRequest is the total number of ICMPv6 echo request packets
+ // counted.
+ EchoRequest *StatCounter
+
+ // EchoReply is the total number of ICMPv6 echo reply packets counted.
+ EchoReply *StatCounter
+
+ // DstUnreachable is the total number of ICMPv6 destination unreachable
+ // packets counted.
+ DstUnreachable *StatCounter
+
+ // PacketTooBig is the total number of ICMPv6 packet too big packets
+ // counted.
+ PacketTooBig *StatCounter
+
+ // TimeExceeded is the total number of ICMPv6 time exceeded packets
+ // counted.
+ TimeExceeded *StatCounter
+
+ // ParamProblem is the total number of ICMPv6 parameter problem packets
+ // counted.
+ ParamProblem *StatCounter
+
+ // RouterSolicit is the total number of ICMPv6 router solicit packets
+ // counted.
+ RouterSolicit *StatCounter
+
+ // RouterAdvert is the total number of ICMPv6 router advert packets
+ // counted.
+ RouterAdvert *StatCounter
+
+ // NeighborSolicit is the total number of ICMPv6 neighbor solicit
+ // packets counted.
+ NeighborSolicit *StatCounter
+
+ // NeighborAdvert is the total number of ICMPv6 neighbor advert packets
+ // counted.
+ NeighborAdvert *StatCounter
+
+ // RedirectMsg is the total number of ICMPv6 redirect message packets
+ // counted.
+ RedirectMsg *StatCounter
+}
+
+// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
+type ICMPv4SentPacketStats struct {
+ ICMPv4PacketStats
+
+ // Dropped is the total number of ICMPv4 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv4 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
+}
+
+// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
+type ICMPv4ReceivedPacketStats struct {
+ ICMPv4PacketStats
+
+ // Invalid is the total number of ICMPv4 packets received that the
+ // transport layer could not parse.
+ Invalid *StatCounter
+}
+
+// ICMPv6SentPacketStats collects outbound ICMPv6-specific stats.
+type ICMPv6SentPacketStats struct {
+ ICMPv6PacketStats
+
+ // Dropped is the total number of ICMPv6 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv6 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
+}
+
+// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
+type ICMPv6ReceivedPacketStats struct {
+ ICMPv6PacketStats
+
+ // Invalid is the total number of ICMPv6 packets received that the
+ // transport layer could not parse.
+ Invalid *StatCounter
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // V4PacketsSent contains counts of sent ICMPv4 packets by packet type
+ // and a single count of packets which failed to write to the link
+ // layer.
+ V4PacketsSent ICMPv4SentPacketStats
+
+ // V4PacketsReceived contains counts of received ICMPv4 packets by
+ // packet type and a single count of invalid packets received.
+ V4PacketsReceived ICMPv4ReceivedPacketStats
+
+ // V6PacketsSent contains counts of sent ICMPv6 packets by packet type
+ // and a single count of packets which failed to write to the link
+ // layer.
+ V6PacketsSent ICMPv6SentPacketStats
+
+ // V6PacketsReceived contains counts of received ICMPv6 packets by
+ // packet type and a single count of invalid packets received.
+ V6PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// IPStats collects IP-specific stats (both v4 and v6).
+type IPStats struct {
+ // PacketsReceived is the total number of IP packets received from the
+ // link layer in nic.DeliverNetworkPacket.
+ PacketsReceived *StatCounter
+
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived *StatCounter
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived *StatCounter
+
+ // PacketsDelivered is the total number of incoming IP packets that
+ // are successfully delivered to the transport layer via HandlePacket.
+ PacketsDelivered *StatCounter
+
+ // PacketsSent is the total number of IP packets sent via WritePacket.
+ PacketsSent *StatCounter
+
+ // OutgoingPacketErrors is the total number of IP packets which failed
+ // to write to a link-layer endpoint.
+ OutgoingPacketErrors *StatCounter
+
+ // MalformedPacketsReceived is the total number of IP packets that were
+ // dropped due to the IP packet header failing validation checks.
+ MalformedPacketsReceived *StatCounter
+
+ // MalformedFragmentsReceived is the total number of IP fragments that were
+ // dropped due to the fragment failing validation checks.
+ MalformedFragmentsReceived *StatCounter
+}
+
+// TCPStats collects TCP-specific stats.
+type TCPStats struct {
+ // ActiveConnectionOpenings is the number of connections opened
+ // successfully via Connect.
+ ActiveConnectionOpenings *StatCounter
+
+ // PassiveConnectionOpenings is the number of connections opened
+ // successfully via Listen.
+ PassiveConnectionOpenings *StatCounter
+
+ // CurrentEstablished is the number of TCP connections for which the
+ // current state is ESTABLISHED.
+ CurrentEstablished *StatCounter
+
+ // CurrentConnected is the number of TCP connections that
+ // are in connected state.
+ CurrentConnected *StatCounter
+
+ // EstablishedResets is the number of times TCP connections have made
+ // a direct transition to the CLOSED state from either the
+ // ESTABLISHED state or the CLOSE-WAIT state.
+ EstablishedResets *StatCounter
+
+ // EstablishedClosed is the number of times established TCP connections
+ // made a transition to CLOSED state.
+ EstablishedClosed *StatCounter
+
+ // EstablishedTimedout is the number of times an established connection
+ // was reset because of a keepalive timeout.
+ EstablishedTimedout *StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop *StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop *StatCounter
+
+ // ListenOverflowSynCookieSent is the number of times a SYN cookie was sent.
+ ListenOverflowSynCookieSent *StatCounter
+
+ // ListenOverflowSynCookieRcvd is the number of times a valid SYN
+ // cookie was received.
+ ListenOverflowSynCookieRcvd *StatCounter
+
+ // ListenOverflowInvalidSynCookieRcvd is the number of times an invalid SYN cookie
+ // was received.
+ ListenOverflowInvalidSynCookieRcvd *StatCounter
+
+ // FailedConnectionAttempts is the number of calls to Connect or Listen
+ // (active and passive openings, respectively) that end in an error.
+ FailedConnectionAttempts *StatCounter
+
+ // ValidSegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ ValidSegmentsReceived *StatCounter
+
+ // InvalidSegmentsReceived is the number of TCP segments received that
+ // the transport layer could not parse.
+ InvalidSegmentsReceived *StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent *StatCounter
+
+ // SegmentSendErrors is the number of TCP segments that failed to be sent.
+ SegmentSendErrors *StatCounter
+
+ // ResetsSent is the number of TCP resets sent.
+ ResetsSent *StatCounter
+
+ // ResetsReceived is the number of TCP resets received.
+ ResetsReceived *StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits *StatCounter
+
+ // FastRecovery is the number of times Fast Recovery was used to
+ // recover from packet loss.
+ FastRecovery *StatCounter
+
+ // SACKRecovery is the number of times SACK Recovery was used to
+ // recover from packet loss.
+ SACKRecovery *StatCounter
+
+ // SlowStartRetransmits is the number of segments retransmitted in slow
+ // start.
+ SlowStartRetransmits *StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit *StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts *StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors *StatCounter
+}
+
+// UDPStats collects UDP-specific stats.
+type UDPStats struct {
+ // PacketsReceived is the number of UDP datagrams received via
+ // HandlePacket.
+ PacketsReceived *StatCounter
+
+ // UnknownPortErrors is the number of incoming UDP datagrams dropped
+ // because they did not have a known destination port.
+ UnknownPortErrors *StatCounter
+
+ // ReceiveBufferErrors is the number of incoming UDP datagrams dropped
+ // due to the receiving buffer being in an invalid state.
+ ReceiveBufferErrors *StatCounter
+
+ // MalformedPacketsReceived is the number of incoming UDP datagrams
+ // dropped due to the UDP header being in a malformed state.
+ MalformedPacketsReceived *StatCounter
+
+ // PacketsSent is the number of UDP datagrams sent via sendUDP.
+ PacketsSent *StatCounter
+
+ // PacketSendErrors is the number of datagrams that failed to be sent.
+ PacketSendErrors *StatCounter
+
+ // ChecksumErrors is the number of datagrams dropped due to bad checksums.
+ ChecksumErrors *StatCounter
+}
+
+// Stats holds statistics about the networking stack.
+//
+// All fields are optional.
+type Stats struct {
+ // UnknownProtocolRcvdPackets is the number of packets received by the
+ // stack that were for an unknown or unsupported protocol.
+ UnknownProtocolRcvdPackets *StatCounter
+
+ // MalformedRcvdPackets is the number of packets received by the stack
+ // that were deemed malformed.
+ MalformedRcvdPackets *StatCounter
+
+ // DroppedPackets is the number of packets dropped due to full queues.
+ DroppedPackets *StatCounter
+
+ // ICMP breaks out ICMP-specific stats (both v4 and v6).
+ ICMP ICMPStats
+
+ // IP breaks out IP-specific stats (both v4 and v6).
+ IP IPStats
+
+ // TCP breaks out TCP-specific stats.
+ TCP TCPStats
+
+ // UDP breaks out UDP-specific stats.
+ UDP UDPStats
+}
+
+// ReceiveErrors collects packet receive errors within the transport endpoint.
+type ReceiveErrors struct {
+ // ReceiveBufferOverflow is the number of received packets dropped
+ // due to the receive buffer being full.
+ ReceiveBufferOverflow StatCounter
+
+ // MalformedPacketsReceived is the number of incoming packets
+ // dropped due to the packet header being in a malformed state.
+ MalformedPacketsReceived StatCounter
+
+ // ClosedReceiver is the number of received packets dropped because
+ // the receiving endpoint was closed.
+ ClosedReceiver StatCounter
+
+ // ChecksumErrors is the number of packets dropped due to bad checksums.
+ ChecksumErrors StatCounter
+}
+
+// SendErrors collects packet send errors within the transport layer for
+// an endpoint.
+type SendErrors struct {
+ // SendToNetworkFailed is the number of packets that failed to be written to
+ // the network endpoint.
+ SendToNetworkFailed StatCounter
+
+ // NoRoute is the number of times we failed to resolve an IP route.
+ NoRoute StatCounter
+
+ // NoLinkAddr is the number of times we failed to resolve ARP.
+ NoLinkAddr StatCounter
+}
+
+// ReadErrors collects segment read errors from an endpoint read call.
+type ReadErrors struct {
+ // ReadClosed is the number of received packet drops because the endpoint
+ // was shut down for read.
+ ReadClosed StatCounter
+
+ // InvalidEndpointState is the number of times we found the endpoint state
+ // to be unexpected.
+ InvalidEndpointState StatCounter
+
+ // NotConnected is the number of times we tried to read but found that the
+ // endpoint was not connected.
+ NotConnected StatCounter
+}
+
+// WriteErrors collects packet write errors from an endpoint write call.
+type WriteErrors struct {
+ // WriteClosed is the number of packet drops because the endpoint
+ // was shut down for write.
+ WriteClosed StatCounter
+
+ // InvalidEndpointState is the number of times we found the endpoint state
+ // to be unexpected.
+ InvalidEndpointState StatCounter
+
+ // InvalidArgs is the number of times invalid input arguments were
+ // provided for endpoint Write call.
+ InvalidArgs StatCounter
+}
+
+// TransportEndpointStats collects statistics about the endpoint.
+type TransportEndpointStats struct {
+ // PacketsReceived is the number of successful packet receives.
+ PacketsReceived StatCounter
+
+ // PacketsSent is the number of successful packet sends.
+ PacketsSent StatCounter
+
+ // ReceiveErrors collects packet receive errors within the transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects packet read errors from an endpoint read call.
+ ReadErrors ReadErrors
+
+ // SendErrors collects packet send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects packet write errors from an endpoint write call.
+ WriteErrors WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*TransportEndpointStats) IsEndpointStats() {}
+
+// InitStatCounters initializes v's nil StatCounter fields to new
+// StatCounters.
+func InitStatCounters(v reflect.Value) {
+ for i := 0; i < v.NumField(); i++ {
+ v := v.Field(i)
+ if s, ok := v.Addr().Interface().(**StatCounter); ok {
+ if *s == nil {
+ *s = new(StatCounter)
+ }
+ } else {
+ InitStatCounters(v)
+ }
+ }
+}
+
+// FillIn returns a copy of s with nil fields initialized to new StatCounters.
+func (s Stats) FillIn() Stats {
+ InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
+}
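+
+// A minimal sketch of why FillIn matters: it allocates every nil counter so
+// callers can update stats without nil checks:
+//
+//	s := Stats{}.FillIn()
+//	s.TCP.SegmentsSent.Increment() // Safe: FillIn allocated the counter.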
+
+// Clone returns a copy of the TransportEndpointStats by atomically reading
+// each field.
+func (src *TransportEndpointStats) Clone() TransportEndpointStats {
+ var dst TransportEndpointStats
+ clone(reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src).Elem())
+ return dst
+}
+
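+// clone recursively copies src's fields into dst, atomically reading the
+// value of each StatCounter.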
+func clone(dst reflect.Value, src reflect.Value) {
+ for i := 0; i < dst.NumField(); i++ {
+ d := dst.Field(i)
+ s := src.Field(i)
+ if c, ok := s.Addr().Interface().(*StatCounter); ok {
+ d.Addr().Interface().(*StatCounter).IncrementBy(c.Value())
+ } else {
+ clone(d, s)
+ }
+ }
+}
+
+// String implements the fmt.Stringer interface.
+func (a Address) String() string {
+ switch len(a) {
+ case 4:
+ return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
+ case 16:
+ // Find the longest subsequence of hexadecimal zeros.
+ start, end := -1, -1
+ for i := 0; i < len(a); i += 2 {
+ j := i
+ for j < len(a) && a[j] == 0 && a[j+1] == 0 {
+ j += 2
+ }
+ if j > i+2 && j-i > end-start {
+ start, end = i, j
+ }
+ }
+
+ var b strings.Builder
+ for i := 0; i < len(a); i += 2 {
+ if i == start {
+ b.WriteString("::")
+ i = end
+ if end >= len(a) {
+ break
+ }
+ } else if i > 0 {
+ b.WriteByte(':')
+ }
+ v := uint16(a[i+0])<<8 | uint16(a[i+1])
+ if v == 0 {
+ b.WriteByte('0')
+ } else {
+ const digits = "0123456789abcdef"
+ for i := uint(3); i < 4; i-- {
+ if v := v >> (i * 4); v != 0 {
+ b.WriteByte(digits[v&0xf])
+ }
+ }
+ }
+ }
+ return b.String()
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// To4 converts the IPv4 address to a 4-byte representation.
+// If the address is not an IPv4 address, To4 returns "".
+func (a Address) To4() Address {
+ const (
+ ipv4len = 4
+ ipv6len = 16
+ )
+ if len(a) == ipv4len {
+ return a
+ }
+ if len(a) == ipv6len &&
+ isZeros(a[0:10]) &&
+ a[10] == 0xff &&
+ a[11] == 0xff {
+ return a[12:16]
+ }
+ return ""
+}
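+
+// For example (the literal is an illustrative IPv4-mapped IPv6 address), To4
+// collapses the mapped form to its 4-byte representation:
+//
+//	a := Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xc0\x00\x02\x01")
+//	_ = a.To4() // "\xc0\x00\x02\x01", i.e. 192.0.2.1.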
+
+// isZeros reports whether a is all zeros.
+func isZeros(a Address) bool {
+ for i := 0; i < len(a); i++ {
+ if a[i] != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// LinkAddress is a byte slice cast as a string that represents a link address.
+// It is typically a 6-byte MAC address.
+type LinkAddress string
+
+// String implements the fmt.Stringer interface.
+func (a LinkAddress) String() string {
+ switch len(a) {
+ case 6:
+ return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// ParseMACAddress parses an IEEE 802 address.
+//
+// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func ParseMACAddress(s string) (LinkAddress, error) {
+ parts := strings.FieldsFunc(s, func(c rune) bool {
+ return c == ':' || c == '-'
+ })
+ if len(parts) != 6 {
+ return "", fmt.Errorf("inconsistent parts: %s", s)
+ }
+ addr := make([]byte, 0, len(parts))
+ for _, part := range parts {
+ u, err := strconv.ParseUint(part, 16, 8)
+ if err != nil {
+ return "", fmt.Errorf("invalid hex digits: %s", s)
+ }
+ addr = append(addr, byte(u))
+ }
+ return LinkAddress(addr), nil
+}
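+
+// A usage sketch (the address literal is illustrative):
+//
+//	addr, err := ParseMACAddress("aa:bb:cc:dd:ee:ff")
+//	if err == nil {
+//		_ = addr.String() // "aa:bb:cc:dd:ee:ff"
+//	}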
+
+// AddressWithPrefix is an address with its subnet prefix length.
+type AddressWithPrefix struct {
+ // Address is a network address.
+ Address Address
+
+ // PrefixLen is the subnet prefix length.
+ PrefixLen int
+}
+
+// String implements the fmt.Stringer interface.
+func (a AddressWithPrefix) String() string {
+ return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
+}
+
+// Subnet converts the address and prefix into a Subnet value and returns it.
+func (a AddressWithPrefix) Subnet() Subnet {
+ addrLen := len(a.Address)
+ if a.PrefixLen <= 0 {
+ return Subnet{
+ address: Address(strings.Repeat("\x00", addrLen)),
+ mask: AddressMask(strings.Repeat("\x00", addrLen)),
+ }
+ }
+ if a.PrefixLen >= addrLen*8 {
+ return Subnet{
+ address: a.Address,
+ mask: AddressMask(strings.Repeat("\xff", addrLen)),
+ }
+ }
+
+ sa := make([]byte, addrLen)
+ sm := make([]byte, addrLen)
+ n := uint(a.PrefixLen)
+ for i := 0; i < addrLen; i++ {
+ if n >= 8 {
+ sa[i] = a.Address[i]
+ sm[i] = 0xff
+ n -= 8
+ continue
+ }
+ sm[i] = ^byte(0xff >> n)
+ sa[i] = a.Address[i] & sm[i]
+ n = 0
+ }
+
+ // For extra caution, call NewSubnet rather than directly creating the Subnet
+ // value. If that fails it indicates a serious bug in this code, so panic is
+ // in order.
+ s, err := NewSubnet(Address(sa), AddressMask(sm))
+ if err != nil {
+ panic("invalid subnet: " + err.Error())
+ }
+ return s
+}
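+
+// For example (a made-up 4-byte address), a /24 prefix masks off the final
+// byte:
+//
+//	ap := AddressWithPrefix{Address: Address("\xc0\xa8\x01\x2a"), PrefixLen: 24}
+//	_ = ap.Subnet().ID() // "\xc0\xa8\x01\x00", i.e. 192.168.1.0.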
+
+// ProtocolAddress is an address and the network protocol it is associated
+// with.
+type ProtocolAddress struct {
+ // Protocol is the protocol of the address.
+ Protocol NetworkProtocolNumber
+
+ // AddressWithPrefix is a network address with its subnet prefix length.
+ AddressWithPrefix AddressWithPrefix
+}
+
+var (
+ // danglingEndpointsMu protects access to danglingEndpoints.
+ danglingEndpointsMu sync.Mutex
+
+ // danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+ danglingEndpoints = make(map[Endpoint]struct{})
+)
+
+// GetDanglingEndpoints returns all dangling endpoints.
+func GetDanglingEndpoints() []Endpoint {
+ danglingEndpointsMu.Lock()
+ es := make([]Endpoint, 0, len(danglingEndpoints))
+ for e := range danglingEndpoints {
+ es = append(es, e)
+ }
+ danglingEndpointsMu.Unlock()
+ return es
+}
+
+// AddDanglingEndpoint adds a dangling endpoint.
+func AddDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ danglingEndpoints[e] = struct{}{}
+ danglingEndpointsMu.Unlock()
+}
+
+// DeleteDanglingEndpoint removes a dangling endpoint.
+func DeleteDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ delete(danglingEndpoints, e)
+ danglingEndpointsMu.Unlock()
+}
+
+// AsyncLoading is the global barrier for asynchronous endpoint loading
+// activities.
+var AsyncLoading sync.WaitGroup
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
new file mode 100644
index 000000000..1c8e2bc34
--- /dev/null
+++ b/pkg/tcpip/tcpip_test.go
@@ -0,0 +1,228 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "fmt"
+ "net"
+ "strings"
+ "testing"
+)
+
+func TestSubnetContains(t *testing.T) {
+ tests := []struct {
+ s Address
+ m AddressMask
+ a Address
+ want bool
+ }{
+ {"\xa0", "\xf0", "\x90", false},
+ {"\xa0", "\xf0", "\xa0", true},
+ {"\xa0", "\xf0", "\xa5", true},
+ {"\xa0", "\xf0", "\xaf", true},
+ {"\xa0", "\xf0", "\xb0", false},
+ {"\xa0", "\xf0", "", false},
+ {"\xa0", "\xf0", "\xa0\x00", false},
+ {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
+ {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
+ {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
+ {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
+ }
+ for _, tt := range tests {
+ s, err := NewSubnet(tt.s, tt.m)
+ if err != nil {
+ t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err)
+ continue
+ }
+ if got := s.Contains(tt.a); got != tt.want {
+ t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestSubnetBits(t *testing.T) {
+ tests := []struct {
+ a AddressMask
+ want1 int
+ want0 int
+ }{
+ {"\x00", 0, 8},
+ {"\x00\x00", 0, 16},
+ {"\x36", 0, 8},
+ {"\x5c", 0, 8},
+ {"\x5c\x5c", 0, 16},
+ {"\x5c\x36", 0, 16},
+ {"\x36\x5c", 0, 16},
+ {"\x36\x36", 0, 16},
+ {"\xff", 8, 0},
+ {"\xff\xff", 16, 0},
+ }
+ for _, tt := range tests {
+ s := &Subnet{mask: tt.a}
+ got1, got0 := s.Bits()
+ if got1 != tt.want1 || got0 != tt.want0 {
+ t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0)
+ }
+ }
+}
+
+func TestSubnetPrefix(t *testing.T) {
+ tests := []struct {
+ a AddressMask
+ want int
+ }{
+ {"\x00", 0},
+ {"\x00\x00", 0},
+ {"\x36", 0},
+ {"\x86", 1},
+ {"\xc5", 2},
+ {"\xff\x00", 8},
+ {"\xff\x36", 8},
+ {"\xff\x8c", 9},
+ {"\xff\xc8", 10},
+ {"\xff", 8},
+ {"\xff\xff", 16},
+ }
+ for _, tt := range tests {
+ s := &Subnet{mask: tt.a}
+ got := s.Prefix()
+ if got != tt.want {
+ t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestSubnetCreation(t *testing.T) {
+ tests := []struct {
+ a Address
+ m AddressMask
+ want error
+ }{
+ {"\xa0", "\xf0", nil},
+ {"\xa0\xa0", "\xf0", errSubnetLengthMismatch},
+ {"\xaa", "\xf0", errSubnetAddressMasked},
+ {"", "", nil},
+ }
+ for _, tt := range tests {
+ if _, err := NewSubnet(tt.a, tt.m); err != tt.want {
+ t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want)
+ }
+ }
+}
+
+func TestAddressString(t *testing.T) {
+ for _, want := range []string{
+ // Taken from stdlib.
+ "2001:db8::123:12:1",
+ "2001:db8::1",
+ "2001:db8:0:1:0:1:0:1",
+ "2001:db8:1:0:1:0:1:0",
+ "2001::1:0:0:1",
+ "2001:db8:0:0:1::",
+ "2001:db8::1:0:0:1",
+ "2001:db8::a:b:c:d",
+
+ // Leading zeros.
+ "::1",
+ // Trailing zeros.
+ "8::",
+ // No zeros.
+ "1:1:1:1:1:1:1:1",
+ // Longer sequence is after other zeros, but not at the end.
+ "1:0:0:1::1",
+ // Longer sequence is at the beginning, shorter sequence is at
+ // the end.
+ "::1:1:1:0:0",
+ // Longer sequence is not at the beginning, shorter sequence is
+ // at the end.
+ "1::1:1:0:0",
+ // Longer sequence is at the beginning, shorter sequence is not
+ // at the end.
+ "::1:1:0:0:1",
+ // Neither sequence is at an end, longer is after shorter.
+ "1:0:0:1::1",
+ // Shorter sequence is at the beginning, longer sequence is not
+ // at the end.
+ "0:0:1:1::1",
+ // Shorter sequence is at the beginning, longer sequence is at
+ // the end.
+ "0:0:1:1:1::",
+ // Short sequences at both ends, longer one in the middle.
+ "0:1:1::1:1:0",
+ // Short sequences at both ends, longer one in the middle.
+ "0:1::1:0:0",
+ // Short sequences at both ends, longer one in the middle.
+ "0:0:1::1:0",
+ // Longer sequence surrounded by shorter sequences, but none at
+ // the end.
+ "1:0:1::1:0:1",
+ } {
+ addr := Address(net.ParseIP(want))
+ if got := addr.String(); got != want {
+ t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want)
+ }
+ }
+}
+
+func TestStatsString(t *testing.T) {
+ got := fmt.Sprintf("%+v", Stats{}.FillIn())
+
+ matchers := []string{
+ // Print root-level stats correctly.
+ "UnknownProtocolRcvdPackets:0",
+ // Print protocol-specific stats correctly.
+ "TCP:{ActiveConnectionOpenings:0",
+ }
+
+ for _, m := range matchers {
+ if !strings.Contains(got, m) {
+ t.Errorf("string.Contains(got, %q) = false", m)
+ }
+ }
+ if t.Failed() {
+ t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got)
+ }
+}
+
+func TestAddressWithPrefixSubnet(t *testing.T) {
+ tests := []struct {
+ addr Address
+ prefixLen int
+ subnetAddr Address
+ subnetMask AddressMask
+ }{
+ {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"},
+ {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"},
+ {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
+ {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
+ }
+ for _, tt := range tests {
+ ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen}
+ gotSubnet := ap.Subnet()
+ wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
+ if err != nil {
+ t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
+ continue
+ }
+ if gotSubnet != wantSubnet {
+ t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet)
+ }
+ }
+}
diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s
new file mode 100644
index 000000000..fb37360ac
--- /dev/null
+++ b/pkg/tcpip/time.s
@@ -0,0 +1,15 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Empty assembly file so empty func definitions work.
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
new file mode 100644
index 000000000..7f172f978
--- /dev/null
+++ b/pkg/tcpip/time_unsafe.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.9
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package tcpip
+
+import (
+ _ "time" // Used with go:linkname.
+ _ "unsafe" // Required for go:linkname.
+)
+
+// StdClock implements Clock with the time package.
+//
+// +stateify savable
+type StdClock struct{}
+
+var _ Clock = (*StdClock)(nil)
+
+//go:linkname now time.now
+func now() (sec int64, nsec int32, mono int64)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*StdClock) NowNanoseconds() int64 {
+ sec, nsec, _ := now()
+ return sec*1e9 + int64(nsec)
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (*StdClock) NowMonotonic() int64 {
+ _, _, mono := now()
+ return mono
+}
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
new file mode 100644
index 000000000..59f3b391f
--- /dev/null
+++ b/pkg/tcpip/timer.go
@@ -0,0 +1,184 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "sync"
+ "time"
+)
+
+// cancellableTimerInstance is a specific instance of CancellableTimer.
+//
+// Different instances are created each time CancellableTimer is Reset so that
+// each timer has its own earlyReturn signal. This addresses a bug where a
+// CancellableTimer that is stopped and reset in quick succession could have
+// one timer instance's earlyReturn signal affected or seen by another timer
+// instance.
+//
+// Consider the following scenario where timer instances share a common
+// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
+// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
+// (B), third (C), and fourth (D) instance of the timer firing, respectively):
+// T1: Obtain L
+// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T2: instance A fires, blocked trying to obtain L.
+// T1: Attempt to stop instance A (set earlyReturn = true)
+// T1: Reset timer (create instance B)
+// T3: instance B fires, blocked trying to obtain L.
+// T1: Attempt to stop instance B (set earlyReturn = true)
+// T1: Reset timer (create instance C)
+// T4: instance C fires, blocked trying to obtain L.
+// T1: Attempt to stop instance C (set earlyReturn = true)
+// T1: Reset timer (create instance D)
+// T5: instance D fires, blocked trying to obtain L.
+// T1: Release L
+//
+// Now that T1 has released L, any of the 4 timer instances can take L and check
+// earlyReturn. If the timers simply check earlyReturn and then do nothing
+// further, then instance D will early return even though it was not
+// requested to stop. If the timers reset earlyReturn before early returning,
+// then all but one of the timers will do work when only one was expected to.
+// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// will fire (again, when only one was expected to).
+//
+// To address the above concerns the simplest solution was to give each timer
+// its own earlyReturn signal.
+type cancellableTimerInstance struct {
+ timer *time.Timer
+
+ // Used to inform the timer to early return when it gets stopped while the
+ // lock the timer tries to obtain when fired is held (T1 is a goroutine that
+ // tries to cancel the timer and T2 is the goroutine that handles the timer
+ // firing):
+ // T1: Obtain the lock, then call StopLocked()
+ // T2: timer fires, and gets blocked on obtaining the lock
+ // T1: Releases lock
+ // T2: Obtains lock and does unintended work
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using earlyReturn to return early so that once T2 obtains
+ // the lock, it will see that it is set to true and do nothing further.
+ earlyReturn *bool
+}
+
+// stop stops the timer instance t from firing if it hasn't fired already. If it
+// has fired and is blocked at obtaining the lock, earlyReturn will be set to
+// true so that it will early return when it obtains the lock.
+func (t *cancellableTimerInstance) stop() {
+ if t.timer != nil {
+ t.timer.Stop()
+ *t.earlyReturn = true
+ }
+}
+
+// CancellableTimer is a timer that does some work and can be safely cancelled
+// when it fires at the same time some "related work" is being done.
+//
+// The term "related work" is defined as some work that needs to be done while
+// holding some lock that the timer must also hold while doing some work.
+//
+// Note, it is not safe to copy a CancellableTimer as its timer instance creates
+// a closure over the address of the CancellableTimer.
+type CancellableTimer struct {
+ // The active instance of a cancellable timer.
+ instance cancellableTimerInstance
+
+ // locker is the lock taken by the timer immediately after it fires and must
+ // be held when attempting to stop the timer.
+ //
+ // Must never change after being assigned.
+ locker sync.Locker
+
+ // fn is the function that will be called when a timer fires and has not been
+ // signaled to early return.
+ //
+ // fn MUST NOT attempt to lock locker.
+ //
+ // Must never change after being assigned.
+ fn func()
+}
+
+// StopLocked prevents the Timer from firing if it has not fired already.
+//
+// If the timer is blocked on obtaining the t.locker lock when StopLocked is
+// called, it will early return instead of calling t.fn.
+//
+// Note, t will be modified.
+//
+// t.locker MUST be locked.
+func (t *CancellableTimer) StopLocked() {
+ t.instance.stop()
+
+ // Nothing to do with the stopped instance anymore.
+ t.instance = cancellableTimerInstance{}
+}
+
+// Reset changes the timer to expire after duration d.
+//
+// Note, t will be modified.
+//
+// Reset should only be called on stopped or expired timers. To be safe, callers
+// should always call StopLocked before calling Reset.
+func (t *CancellableTimer) Reset(d time.Duration) {
+ // Create a new instance.
+ earlyReturn := false
+
+ // Capture the locker so that updating the timer does not cause a data race
+ // when a timer fires and tries to obtain the lock (read the timer's locker).
+ locker := t.locker
+ t.instance = cancellableTimerInstance{
+ timer: time.AfterFunc(d, func() {
+ locker.Lock()
+ defer locker.Unlock()
+
+ if earlyReturn {
+				// If we reach this point, the timer fired, but another
+				// goroutine called StopLocked while it held the lock. Simply
+				// return here and do nothing further.
+ earlyReturn = false
+ return
+ }
+
+ t.fn()
+ }),
+ earlyReturn: &earlyReturn,
+ }
+}
+
+// Lock is a no-op used by the copylocks checker from go vet.
+//
+// See CancellableTimer for details about why it shouldn't be copied.
+//
+// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
+// details about the copylocks checker.
+func (*CancellableTimer) Lock() {}
+
+// Unlock is a no-op used by the copylocks checker from go vet.
+//
+// See CancellableTimer for details about why it shouldn't be copied.
+//
+// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
+// details about the copylocks checker.
+func (*CancellableTimer) Unlock() {}
+
+// NewCancellableTimer returns an unscheduled CancellableTimer with the given
+// locker and fn.
+//
+// fn MUST NOT attempt to lock locker.
+//
+// Callers must call Reset to schedule the timer to fire.
+func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
+ return &CancellableTimer{locker: locker, fn: fn}
+}
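+
+// exampleUsage is a minimal usage sketch (illustrative only; it is not
+// referenced by anything else in this file). It shows the intended pattern:
+// the caller shares a single lock between the timer's work and its own
+// "related work", and holds that lock around StopLocked so that an instance
+// that already fired returns early instead of doing stale work.
+func exampleUsage() {
+	var mu sync.Mutex
+	t := NewCancellableTimer(&mu, func() {
+		// Timer work goes here; mu is already held when fn is called.
+	})
+	t.Reset(10 * time.Millisecond)
+
+	mu.Lock()
+	// Even if the timer already fired and is blocked on mu, StopLocked
+	// guarantees it will return early once it acquires the lock.
+	t.StopLocked()
+	mu.Unlock()
+}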
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
new file mode 100644
index 000000000..b4940e397
--- /dev/null
+++ b/pkg/tcpip/timer_test.go
@@ -0,0 +1,261 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
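+// These durations are chosen to keep the tests fast while avoiding flakes:
+// timers under test are scheduled with shortDuration so they fire almost
+// immediately, while waits use middleDuration to give slower machines time
+// to deliver the callback.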
+const (
+ shortDuration = 1 * time.Nanosecond
+ middleDuration = 100 * time.Millisecond
+ longDuration = 1 * time.Second
+)
+
+func TestCancellableTimerReassignment(t *testing.T) {
+ var timer tcpip.CancellableTimer
+ var wg sync.WaitGroup
+ var lock sync.Mutex
+
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+
+ go func() {
+ lock.Lock()
+			// Assigning a new timer value updates the timer's locker and
+			// function. This test makes sure there is no data race when
+			// reassigning a timer that has an active timer instance, even a
+			// stopped one: a stopped timer may still be blocked on a lock
+			// before it can check whether it was stopped while another
+			// goroutine holds the same lock.
+ timer = *tcpip.NewCancellableTimer(&lock, func() {
+ wg.Done()
+ })
+ timer.Reset(shortDuration)
+ lock.Unlock()
+ }()
+ }
+ wg.Wait()
+}
+
+func TestCancellableTimerFire(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.NewCancellableTimer(&lock, func() {
+ ch <- struct{}{}
+ })
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromLongDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ timer.StopLocked()
+ lock.Unlock()
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerImmediatelyStop(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ for i := 0; i < 1000; i++ {
+ lock.Lock()
+ timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ for i := 0; i < 10; i++ {
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration * 2)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for double the duration so timers that weren't correctly stopped can
+ // fire.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration * 2):
+ }
+}
+
+func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration)
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+	// Wait long enough for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+	// Wait long enough for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
new file mode 100644
index 000000000..7e5c79776
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "icmp_packet_list",
+ out = "icmp_packet_list.go",
+ package = "icmp",
+ prefix = "icmpPacket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*icmpPacket",
+ "Linker": "*icmpPacket",
+ },
+)
+
+go_library(
+ name = "icmp",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "icmp_packet_list.go",
+ "protocol.go",
+ ],
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/raw",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
new file mode 100644
index 000000000..62d1acad4
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -0,0 +1,831 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type icmpPacket struct {
+ icmpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal
+// to have concurrent goroutines make calls into the endpoint, as they are
+// properly synchronized.
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ waiterQueue *waiter.Queue
+ uniqueID uint64
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ state endpointState
+ route stack.Route `state:"manual"`
+ ttl uint8
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
+}
+
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ state: stateInitial,
+ uniqueID: s.UniqueID(),
+ }, nil
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
+ }
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
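+		// Binding has not happened yet; fall past the switch to bind below.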
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
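+	// Upgrade from the shared lock to the exclusive lock. Because deferred
+	// calls run in reverse order, on return the write lock is released first
+	// and the read lock is then re-acquired for the caller.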
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+	// The state changed when we released the shared lock and re-acquired
+	// it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+	// If we've shut down with SHUT_WR, we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ if to == nil {
+ route = &e.route
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route,
+			// given that it may need to change in the Route.Resolve()
+			// call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := to.NIC
+ if e.BindNICID != 0 {
+ if nicID != 0 && nicID != e.BindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicID = e.BindNICID
+ }
+
+ dst, netProto, err := e.checkV4MappedLocked(*to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+		// Find a route to the destination.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.FullPayload()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch e.NetProto {
+ case header.IPv4ProtocolNumber:
+ err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+
+ case header.IPv6ProtocolNumber:
+ err = send6(route, e.ID.LocalPort, v, e.ttl)
+ }
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return int64(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// SetSockOptBool sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return nil
+}
+
+// SetSockOptInt sets a socket option. Only TTLOption is currently supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+	}
+ return nil
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ v = p.data.Size()
+ }
+ e.rcvMu.Unlock()
+ return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
+	case tcpip.TTLOption:
+		// ttl is protected by mu, not rcvMu.
+		e.mu.RLock()
+		v := int(e.ttl)
+		e.mu.RUnlock()
+		return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
+ if len(data) < header.ICMPv4MinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ copy(icmpv4, data)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ icmpv4.SetIdent(ident)
+ data = data[header.ICMPv4MinimumSize:]
+
+ // Linux performs these basic checks.
+ if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
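+	// The ICMPv4 checksum is the one's complement of the checksum computed
+	// over the ICMP header (with its checksum field zeroed) and the payload.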
+ icmpv4.SetChecksum(0)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ Data: data.ToVectorisedView(),
+ TransportHeader: buffer.View(icmpv4),
+ Owner: owner,
+ })
+}
+
+func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
+ if len(data) < header.ICMPv6EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ copy(icmpv6, data)
+ // Set the ident. Sequence number is provided by the user.
+ icmpv6.SetIdent(ident)
+ data = data[header.ICMPv6MinimumSize:]
+
+ if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
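+	// Unlike ICMPv4, the ICMPv6 checksum also covers a pseudo-header
+	// containing the source and destination addresses.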
+ dataVV := data.ToVectorisedView()
+ icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
+
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ Data: dataVV,
+ TransportHeader: buffer.View(icmpv6),
+ })
+}
+
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicID := addr.NIC
+ localPort := uint16(0)
+ switch e.state {
+ case stateInitial:
+ case stateBound, stateConnected:
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != e.BindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicID = e.BindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+	// The endpoint is registered for a single network protocol: the
+	// effective protocol determined above.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ id, err = e.registerWithStack(nicID, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.ID = id
+ e.route = r.Clone()
+ e.RegisterNICID = nicID
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ if e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by ICMP endpoints; it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by ICMP endpoints; it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if id.LocalPort != 0 {
+		// The endpoint already has a local port; just attempt to
+ // register it.
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ return id, err
+ }
+
+ // We need to find a port for the endpoint.
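+	// PickEphemeralPort calls the closure with candidate ports until it
+	// reports success (true) or returns a hard error.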
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ id.LocalPort = p
+		err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+	// Binding is only allowed while the endpoint is in its initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+	// The endpoint is registered for a single network protocol: the
+	// effective protocol determined above.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.ID = id
+ e.RegisterNICID = addr.NIC
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ e.BindNICID = addr.NIC
+ e.BindAddr = addr.Addr
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.RegisterNICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Only accept echo replies.
+ switch e.NetProto {
+ case header.IPv4ProtocolNumber:
+ h := header.ICMPv4(pkt.TransportHeader)
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+ case header.IPv6ProtocolNumber:
+ h := header.ICMPv6(pkt.TransportHeader)
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+ }
+
+ e.rcvMu.Lock()
+
+	// Drop the packet if the endpoint is not ready to receive or the
+	// receive side has been closed.
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &icmpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ },
+ }
+
+	// An ICMP socket's data includes the ICMP header.
+ packet.data = pkt.TransportHeader.ToVectorisedView()
+ packet.data.Append(pkt.Data)
+
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+
+ packet.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+ e.stats.PacketsReceived.Increment()
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+}
+
+// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
+// expose internal socket state.
+func (e *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
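+
+// An illustrative sketch of sending an echo request through this endpoint
+// (icmpPayload is a hypothetical, already-serialized ICMPv4 echo request
+// header and payload, and remote is the destination address):
+//
+//	if err := ep.Connect(tcpip.FullAddress{Addr: remote}); err != nil {
+//		// Handle the error.
+//	}
+//	if _, _, err := ep.Write(tcpip.SlicePayload(icmpPayload), tcpip.WriteOptions{}); err != nil {
+//		// Handle the error.
+//	}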
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
new file mode 100644
index 000000000..9d263c0ec
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -0,0 +1,95 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves icmpPacket.data field.
+func (p *icmpPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads icmpPacket.data field.
+func (p *icmpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+	// in utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+	// The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ var err *tcpip.Error
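+	// Failures below are unrecoverable in the middle of a restore, so panic
+	// rather than continue with a broken endpoint.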
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+ if err != nil {
+ panic(err)
+ }
+
+ e.ID.LocalAddress = e.route.LocalAddress
+ } else if len(e.ID.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
new file mode 100644
index 000000000..74ef6541e
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
+// protocols for use in ping. To use it in the networking stack, this package
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
+// when calling Stack.NewEndpoint().
+package icmp
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolNumber4 is the ICMP protocol number.
+ ProtocolNumber4 = header.ICMPv4ProtocolNumber
+
+ // ProtocolNumber6 is the IPv6-ICMP protocol number.
+ ProtocolNumber6 = header.ICMPv6ProtocolNumber
+)
+
+// protocol implements stack.TransportProtocol.
+type protocol struct {
+ number tcpip.TransportProtocolNumber
+}
+
+// Number returns the ICMP protocol number.
+func (p *protocol) Number() tcpip.TransportProtocolNumber {
+ return p.number
+}
+
+func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.IPv4ProtocolNumber
+ case ProtocolNumber6:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid icmp packet size.
+func (p *protocol) MinimumPacketSize() int {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.ICMPv4MinimumSize
+ case ProtocolNumber6:
+ return header.ICMPv6MinimumSize
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// ParsePorts returns the "ports" of an ICMP packet: src is always 0, dst is
+// the ICMP ident, and err is always nil.
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ switch p.number {
+ case ProtocolNumber4:
+ hdr := header.ICMPv4(v)
+ return 0, hdr.Ident(), nil
+ case ProtocolNumber6:
+ hdr := header.ICMPv6(v)
+ return 0, hdr.Ident(), nil
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
+ return true
+}
+
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
+ //
+ // Right now, the Parse() method is tied to enabled protocols passed into
+ // stack.New. This works for UDP and TCP, but we handle ICMP traffic even
+ // when netstack users don't pass ICMP as a supported protocol.
+ return false
+}
+
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
+
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
+}
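+
+// An illustrative sketch of the wiring described in the package comment
+// (assumes the ipv4 network protocol package and an already-configured link
+// endpoint; the exact stack configuration is an application concern):
+//
+//	s := stack.New(stack.Options{
+//		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
+//		TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+//	})
+//	var wq waiter.Queue
+//	ep, err := s.NewEndpoint(icmp.ProtocolNumber4, header.IPv4ProtocolNumber, &wq)
+//	if err != nil {
+//		// Handle the error.
+//	}
+//	defer ep.Close()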
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
new file mode 100644
index 000000000..b989b1209
--- /dev/null
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "packet_list",
+ out = "packet_list.go",
+ package = "packet",
+ prefix = "packet",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*packet",
+ "Linker": "*packet",
+ },
+)
+
+go_library(
+ name = "packet",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "packet_list.go",
+ ],
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
new file mode 100644
index 000000000..a8f8454dd
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -0,0 +1,469 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packet provides the implementation of packet sockets (see
+// packet(7)). Packet sockets allow applications to:
+//
+// * manually write and inspect link, network, and transport headers
+// * receive all traffic of a given network protocol, or all protocols
+//
+// Packet sockets are similar to raw sockets, but provide even more power to
+// users, letting them effectively talk directly to the network device.
+//
+// Packet sockets skip the input and output iptables chains.
+package packet
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
+// to have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+ cooked bool
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+}
+
+// NewEndpoint returns a new packet endpoint.
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ },
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
+ if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
+ return nil, err
+ }
+ return ep, nil
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (ep *endpoint) Abort() {
+ ep.Close()
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ ep.closed = true
+ ep.bound = false
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ ep.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(b/129292371): Implement.
+ return 0, nil, tcpip.ErrInvalidOptionValue
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
+// disconnected, and this function always returns tcpip.ErrNotSupported.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
+// connected, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
+// with Shutdown, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
+// Listen, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
+// Accept, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO(gvisor.dev/issue/173): Add Bind support.
+
+ // "By default, all packets of the specified protocol type are passed
+ // to a packet socket. To get packets only from a specific interface
+ // use bind(2) specifying an address in a struct sockaddr_ll to bind
+ // the packet socket to an interface. Fields used for binding are
+ // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
+ // - packet(7).
+
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.bound {
+ return tcpip.ErrAlreadyBound
+ }
+
+	// Unregister the endpoint from all NICs.
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ // Bind endpoint to receive packets from specific interface.
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ return err
+ }
+
+ ep.bound = true
+
+ return nil
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
+// used with SetSockOpt, and this function always returns
+// tcpip.ErrUnknownProtocolOption.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrNotSupported
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// HandlePacket implements stack.PacketEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ ep.rcvMu.Lock()
+
+	// Drop the packet if the receive side has been closed.
+ if ep.rcvClosed {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ var packet packet
+ // TODO(b/129292371): Return network protocol.
+ if len(pkt.LinkHeader) > 0 {
+ // Get info directly from the ethernet header.
+ hdr := header.Ethernet(pkt.LinkHeader)
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(hdr.SourceAddress()),
+ }
+ } else {
+ // Guess the would-be ethernet header.
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(localAddr),
+ }
+ }
+
+ if ep.cooked {
+ // Cooked packets can simply be queued.
+ packet.data = pkt.Data
+ } else {
+ // Raw packets need their ethernet headers prepended before
+ // queueing.
+ var linkHeader buffer.View
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ }
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ }
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(&packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+ ep.stats.PacketsReceived.Increment()
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (ep *endpoint) Info() tcpip.EndpointInfo {
+ ep.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := ep.TransportEndpointInfo
+ ep.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (ep *endpoint) Stats() tcpip.EndpointStats {
+ return &ep.stats
+}
+
+func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
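+
+// exampleBind is an illustrative sketch (unused by this change; the stack s
+// and NIC 1 are assumptions) of creating a cooked packet endpoint bound to a
+// single interface, per the packet(7) semantics quoted in Bind above.
+func exampleBind(s *stack.Stack) (tcpip.Endpoint, *tcpip.Error) {
+	var wq waiter.Queue
+	ep, err := NewEndpoint(s, true /* cooked */, header.IPv4ProtocolNumber, &wq)
+	if err != nil {
+		return nil, err
+	}
+	// Restrict delivery to packets arriving on NIC 1.
+	if err := ep.Bind(tcpip.FullAddress{NIC: 1}); err != nil {
+		ep.Close()
+		return nil, err
+	}
+	return ep, nil
+}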
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
new file mode 100644
index 000000000..9b88f17e4
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+	// in utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
+ if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
new file mode 100644
index 000000000..2eab09088
--- /dev/null
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "raw_packet_list",
+ out = "raw_packet_list.go",
+ package = "raw",
+ prefix = "rawPacket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*rawPacket",
+ "Linker": "*rawPacket",
+ },
+)
+
+go_library(
+ name = "raw",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "protocol.go",
+ "raw_packet_list.go",
+ ],
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/packet",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
new file mode 100644
index 000000000..5b6e7d102
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -0,0 +1,729 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package raw provides the implementation of raw sockets (see raw(7)). Raw
+// sockets allow applications to:
+//
+// * manually write and inspect transport layer headers and payloads
+// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
+// * optionally write and inspect network layer headers of packets
+//
+// Raw sockets don't have any notion of ports, and incoming packets are
+// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
+// receive every UDP packet received by netstack. bind(2) and connect(2) can be
+// used to filter incoming packets by source and destination.
+package raw
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type rawPacket struct {
+ rawPacketEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
+// have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ waiterQueue *waiter.Queue
+ associated bool
+ hdrIncluded bool
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList rawPacketList
+ rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
+ // route is the route to a remote network endpoint. It is set via
+ // Connect(), and is valid only when connected is true.
+ route stack.Route `state:"manual"`
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
+}
+
+// NewEndpoint returns a raw endpoint for the given protocols.
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
+}
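+
+// A rough usage sketch (the stack s and the waiter queue wq are assumed to be
+// created elsewhere):
+//
+//	var wq waiter.Queue
+//	ep, err := raw.NewEndpoint(s, header.IPv4ProtocolNumber,
+//		header.ICMPv4ProtocolNumber, &wq)
+//
+// The resulting associated endpoint receives every ICMPv4 packet seen by the
+// stack until Bind or Connect narrows the traffic down.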
+
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ e := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
+ associated: associated,
+ hdrIncluded: !associated,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ // Unassociated endpoints are write-only and users call Write() with IP
+ // headers included. Because they're write-only, we don't need to
+ // register with the stack.
+ if !associated {
+ e.rcvBufSizeMax = 0
+ e.waiterQueue = nil
+ return e, nil
+ }
+
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
+ return nil, err
+ }
+
+ return e, nil
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.closed || !e.associated {
+ return
+ }
+
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+
+ // Clear the receive list.
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ e.rcvList.Remove(e.rcvList.Front())
+ }
+
+ if e.connected {
+ e.route.Release()
+ e.connected = false
+ }
+
+ e.closed = true
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ pkt := e.rcvList.Front()
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = pkt.senderAddr
+ }
+
+ return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+}
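+
+// A sketch of the expected consumption pattern: a caller typically registers
+// a waiter.Entry for waiter.EventIn on the endpoint's queue, calls Read until
+// it returns tcpip.ErrWouldBlock, and then blocks on the entry's channel
+// before retrying.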
+
+// Write implements tcpip.Endpoint.Write.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ e.mu.RLock()
+
+ if e.closed {
+ e.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ payloadBytes, err := p.FullPayload()
+ if err != nil {
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // If IP headers are included in the payload, validate them and route
+ // using the destination address they carry unless the caller provided
+ // an explicit destination.
+ if e.hdrIncluded {
+ ip := header.IPv4(payloadBytes)
+ if !ip.IsValid(len(payloadBytes)) {
+ e.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ dstAddr := ip.DestinationAddress()
+ // Use the destination address from the IP header as the
+ // packet's destination, unless opts.To is already set (e.g.
+ // if sendto specified an address).
+ if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
+ opts.To = &tcpip.FullAddress{
+ NIC: 0, // NIC is unset.
+ Addr: dstAddr, // The address from the payload.
+ Port: 0, // There are no ports here.
+ }
+ }
+ }
+
+ // Did the caller provide a destination? If not, use the connected
+ // destination.
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if !e.connected {
+ e.mu.RUnlock()
+ return 0, nil, tcpip.ErrDestinationRequired
+ }
+
+ if e.route.IsResolutionRequired() {
+ savedRoute := &e.route
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in finishWrite.
+ e.mu.RUnlock()
+ e.mu.Lock()
+
+ // Make sure that the route didn't change during the
+ // time we didn't hold the lock.
+ if !e.connected || savedRoute != &e.route {
+ e.mu.Unlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ n, ch, err := e.finishWrite(payloadBytes, savedRoute)
+ e.mu.Unlock()
+ return n, ch, err
+ }
+
+ n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ e.mu.RUnlock()
+ return n, ch, err
+ }
+
+ // The caller provided a destination. Reject the destination address
+ // if it goes through a different NIC than the one the endpoint was
+ // bound to.
+ nic := opts.To.NIC
+ if e.bound && nic != 0 && nic != e.BindNICID {
+ e.mu.RUnlock()
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ // Find the route to the destination. If BindAddr is empty,
+ // FindRoute will choose an appropriate source address.
+ route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
+ if err != nil {
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ n, ch, err := e.finishWrite(payloadBytes, &route)
+ route.Release()
+ e.mu.RUnlock()
+ return n, ch, err
+}
+
+// finishWrite writes the payload to a route. It resolves the route if
+// necessary. It's really just a helper to make defer unnecessary in Write.
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
+ // We may need to resolve the route (match a link layer address to the
+ // network address). If that requires blocking (e.g. to use ARP),
+ // return a channel on which the caller can wait.
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ if e.hdrIncluded {
+ if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
+ return 0, nil, err
+ }
+ } else {
+ hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ Owner: e.owner,
+ }); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return int64(len(payloadBytes)), nil, nil
+}
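+
+// Note that when resolution would block, the error surfaced to the caller is
+// tcpip.ErrNoLinkAddress together with a non-nil channel; the caller is
+// expected to wait on the channel and retry the write once resolution
+// completes.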
+
+// Peek implements tcpip.Endpoint.Peek.
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.closed {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nic := addr.NIC
+ if e.bound {
+ if e.BindNICID == 0 {
+ // If we're bound, but not to a specific NIC, the NIC
+ // in addr will be used. Nothing to do here.
+ } else if addr.NIC == 0 {
+ // If we're bound to a specific NIC, but addr doesn't
+ // specify a NIC, use the bound NIC.
+ nic = e.BindNICID
+ } else if addr.NIC != e.BindNICID {
+ // We're bound and addr specifies a NIC. They must be
+ // the same.
+ return tcpip.ErrInvalidEndpointState
+ }
+ }
+
+ // Find a route to the destination.
+ route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.RegisterNICID = nic
+ }
+
+ // Save the route we've connected via.
+ e.route = route.Clone()
+ e.connected = true
+
+ return nil
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. It's a no-op for raw sockets.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.connected {
+ return tcpip.ErrNotConnected
+ }
+ return nil
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If a local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.RegisterNICID = addr.NIC
+ e.BindNICID = addr.NIC
+ }
+
+ e.BindAddr = addr.Addr
+ e.bound = true
+
+ return nil
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ v = p.data.Size()
+ }
+ e.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSizeMax
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
+ e.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full or if this is an
+ // unassociated endpoint (i.e., an endpoint created with IPPROTO_RAW).
+ // Such endpoints are send-only. See:
+ // https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ if e.bound {
+ // If bound to a NIC, only accept data for that NIC.
+ if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+ e.rcvMu.Unlock()
+ return
+ }
+ // If bound to an address, only accept data for that address.
+ if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+ e.rcvMu.Unlock()
+ return
+ }
+ }
+
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &rawPacket{
+ senderAddr: tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ },
+ }
+
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
+ headers = append(headers, pkt.NetworkHeader...)
+ headers = append(headers, pkt.TransportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ packet.timestampNS = e.stack.NowNanoseconds()
+
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.rcvMu.Unlock()
+ e.stats.PacketsReceived.Increment()
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// State implements socket.Socket.State.
+func (e *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
new file mode 100644
index 000000000..33bfb56cd
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves rawPacket.data field.
+func (p *rawPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads rawPacket.data field.
+func (p *rawPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway, so there is really little point
+ // in utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and from mutating endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (ep *endpoint) Resume(s *stack.Stack) {
+ ep.stack = s
+
+ // If the endpoint is connected, re-connect.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
+ if err != nil {
+ panic(err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if ep.associated {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ panic(err)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
new file mode 100644
index 000000000..f30aa2a4a
--- /dev/null
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/packet"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EndpointFactory implements stack.RawFactory.
+type EndpointFactory struct{}
+
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
+}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
new file mode 100644
index 000000000..18ff89ffc
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -0,0 +1,126 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "tcp_segment_list",
+ out = "tcp_segment_list.go",
+ package = "tcp",
+ prefix = "segment",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*segment",
+ "Linker": "*segment",
+ },
+)
+
+go_template_instance(
+ name = "tcp_endpoint_list",
+ out = "tcp_endpoint_list.go",
+ package = "tcp",
+ prefix = "endpoint",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*endpoint",
+ "Linker": "*endpoint",
+ },
+)
+
+go_library(
+ name = "tcp",
+ srcs = [
+ "accept.go",
+ "connect.go",
+ "connect_unsafe.go",
+ "cubic.go",
+ "cubic_state.go",
+ "dispatcher.go",
+ "endpoint.go",
+ "endpoint_state.go",
+ "forwarder.go",
+ "protocol.go",
+ "rcv.go",
+ "rcv_state.go",
+ "reno.go",
+ "sack.go",
+ "sack_scoreboard.go",
+ "segment.go",
+ "segment_heap.go",
+ "segment_queue.go",
+ "segment_state.go",
+ "snd.go",
+ "snd_state.go",
+ "tcp_endpoint_list.go",
+ "tcp_segment_list.go",
+ "timer.go",
+ ],
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/rand",
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/seqnum",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/raw",
+ "//pkg/waiter",
+ "@com_github_google_btree//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "tcp_x_test",
+ size = "medium",
+ srcs = [
+ "dual_stack_test.go",
+ "sack_scoreboard_test.go",
+ "tcp_noracedetector_test.go",
+ "tcp_sack_test.go",
+ "tcp_test.go",
+ "tcp_timestamp_test.go",
+ ],
+ shard_count = 10,
+ deps = [
+ ":tcp",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/seqnum",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp/testing/context",
+ "//pkg/test/testutil",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "rcv_test",
+ size = "small",
+ srcs = ["rcv_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ ],
+)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
new file mode 100644
index 000000000..6e00e5526
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -0,0 +1,752 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "crypto/sha1"
+ "encoding/binary"
+ "fmt"
+ "hash"
+ "io"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // tsLen is the length, in bits, of the timestamp in the SYN cookie.
+ tsLen = 8
+
+ // tsMask is a mask for timestamp values (i.e., tsLen bits).
+ tsMask = (1 << tsLen) - 1
+
+ // tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
+ tsOffset = 24
+
+ // hashMask is the mask for hash values (i.e., tsOffset bits).
+ hashMask = (1 << tsOffset) - 1
+
+ // maxTSDiff is the maximum allowed difference between a received cookie
+ // timestamp and the current timestamp. If the difference is greater
+ // than maxTSDiff, the cookie is expired.
+ maxTSDiff = 2
+
+ // SynRcvdCountThreshold is the default global maximum number of
+ // connections that are allowed to be in SYN-RCVD state before TCP
+ // starts using SYN cookies to accept connections.
+ SynRcvdCountThreshold uint64 = 1000
+)
+
+var (
+ // mssTable is a slice containing the possible MSS values that we
+ // encode in the SYN cookie with two bits.
+ mssTable = []uint16{536, 1300, 1440, 1460}
+)
+
+func encodeMSS(mss uint16) uint32 {
+ for i := len(mssTable) - 1; i > 0; i-- {
+ if mss >= mssTable[i] {
+ return uint32(i)
+ }
+ }
+ return 0
+}
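+
+// For illustration, with the table above: encodeMSS(1460) is 3,
+// encodeMSS(1400) is 1 (the index of the largest entry not exceeding it), and
+// anything below 536 encodes to 0, so the MSS later recovered from a cookie
+// is always a conservative lower bound on the peer's advertised MSS.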
+
+// listenContext is used by a listening endpoint to store state used while
+// listening for connections. This struct is allocated by the listen goroutine
+// and must not be accessed or have its methods called concurrently as they
+// may mutate the stored objects.
+type listenContext struct {
+ stack *stack.Stack
+
+ // synRcvdCount is a reference to the stack level synRcvdCount.
+ synRcvdCount *synRcvdCounter
+
+ // rcvWnd is the receive window that is sent by this listening context
+ // in the initial SYN-ACK.
+ rcvWnd seqnum.Size
+
+ // nonce holds random bytes that are initialized once when the context
+ // is created and used to seed the hash function when generating
+ // the SYN cookie.
+ nonce [2][sha1.BlockSize]byte
+
+ // listenEP is a reference to the listening endpoint associated with
+ // this context. Can be nil if the context is created by the forwarder.
+ listenEP *endpoint
+
+ // hasherMu protects hasher.
+ hasherMu sync.Mutex
+ // hasher is the hash function used to generate a SYN cookie.
+ hasher hash.Hash
+
+ // v6Only is true if listenEP is a dual stack socket and has the
+ // IPV6_V6ONLY option set.
+ v6Only bool
+
+ // netProto indicates the network protocol (IPv4/v6) for the listening
+ // endpoint.
+ netProto tcpip.NetworkProtocolNumber
+
+ // pendingMu protects pendingEndpoints. This should only be accessed
+ // by the listening endpoint's worker goroutine.
+ //
+ // Lock Ordering: listenEP.workerMu -> pendingMu
+ pendingMu sync.Mutex
+ // pending is used to wait for all pendingEndpoints to finish when
+ // a socket is closed.
+ pending sync.WaitGroup
+ // pendingEndpoints is a map of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[stack.TransportEndpointID]*endpoint
+}
+
+// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
+func timeStamp() uint32 {
+ return uint32(time.Now().Unix()>>6) & tsMask
+}
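+
+// Since the timestamp is masked to 8 bits, it wraps every 256*64 = 16384
+// seconds (roughly 4.5 hours); with maxTSDiff of 2, a cookie stays acceptable
+// for roughly 128-192 seconds after it is minted, depending on where in the
+// 64-second window it was created.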
+
+// newListenContext creates a new listen context.
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+ l := &listenContext{
+ stack: stk,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
+ pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+ }
+ p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
+ if !ok {
+ panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
+ }
+ l.synRcvdCount = p.SynRcvdCounter()
+
+ rand.Read(l.nonce[0][:])
+ rand.Read(l.nonce[1][:])
+
+ return l
+}
+
+// cookieHash calculates the cookieHash for the given id, timestamp and nonce
+// index. The hash is used to create and validate cookies.
+func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
+ // Initialize the block with fixed-size data: both ports and the
+ // timestamp.
+ var payload [8]byte
+ binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
+ binary.BigEndian.PutUint32(payload[4:], ts)
+
+ // Feed everything to the hasher.
+ l.hasherMu.Lock()
+ l.hasher.Reset()
+ l.hasher.Write(payload[:])
+ l.hasher.Write(l.nonce[nonceIndex][:])
+ io.WriteString(l.hasher, string(id.LocalAddress))
+ io.WriteString(l.hasher, string(id.RemoteAddress))
+
+ // Finalize the calculation of the hash and return the first 4 bytes.
+ h := make([]byte, 0, sha1.Size)
+ h = l.hasher.Sum(h)
+ l.hasherMu.Unlock()
+
+ return binary.BigEndian.Uint32(h[:])
+}
+
+// createCookie creates a SYN cookie for the given id and incoming sequence
+// number.
+func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
+ ts := timeStamp()
+ v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
+ v += (l.cookieHash(id, ts, 1) + data) & hashMask
+ return seqnum.Value(v)
+}
+
+// isCookieValid checks if the supplied cookie is valid for the given id and
+// sequence number. If it is, it also returns the data originally encoded in the
+// cookie when createCookie was called.
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
+ ts := timeStamp()
+ v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
+ cookieTS := v >> tsOffset
+ if ((ts - cookieTS) & tsMask) > maxTSDiff {
+ return 0, false
+ }
+
+ return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
+}
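+
+// Taken together, createCookie and isCookieValid form a round trip: given
+// c := l.createCookie(id, seq, data), a later l.isCookieValid(id, c, seq)
+// yields (data, true) as long as data fits within hashMask and no more than
+// maxTSDiff timestamp ticks have elapsed.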
+
+// createConnectingEndpoint creates a new endpoint in a connecting state, with
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
+ // Create a new endpoint.
+ netProto := l.netProto
+ if netProto == 0 {
+ netProto = s.route.NetProto
+ }
+ n := newEndpoint(l.stack, netProto, queue)
+ n.v6only = l.v6Only
+ n.ID = s.id
+ n.boundNICID = s.route.NICID()
+ n.route = s.route.Clone()
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.rcvBufSize = int(l.rcvWnd)
+ n.amss = mssForRoute(&n.route)
+ n.setEndpointState(StateConnecting)
+
+ n.maybeEnableTimestamp(rcvdSynOpts)
+ n.maybeEnableSACKPermitted(rcvdSynOpts)
+
+ n.initGSO()
+
+ // Bootstrap the auto tuning algorithm. Starting at zero will result in
+ // a large step function on the first window adjustment causing the
+ // window to grow to a really large value.
+ n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+
+ return n
+}
+
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+//
+// The new endpoint is returned with e.mu held.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+ // Create new endpoint.
+ irs := s.sequenceNumber
+ isn := generateSecureISN(s.id, l.stack.Seed())
+ ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+ // Lock the endpoint before registering to ensure that no out-of-band
+ // changes are possible due to incoming packets etc. until
+ // the endpoint is done initializing.
+ ep.mu.Lock()
+ ep.owner = owner
+
+ // listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
+ if l.listenEP != nil {
+ l.listenEP.mu.Lock()
+ if l.listenEP.EndpointState() != StateListen {
+ l.listenEP.mu.Unlock()
+ // Ensure we release any registrations done by the newly
+ // created endpoint.
+ ep.mu.Unlock()
+ ep.Close()
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+ l.addPendingEndpoint(ep)
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ l.listenEP.propagateInheritableOptionsLocked(ep)
+
+ if !ep.reserveTupleLocked() {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+
+ deferAccept = l.listenEP.deferAccept
+ l.listenEP.mu.Unlock()
+ }
+
+ // Register new endpoint so that packets are routed to it.
+ if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
+
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+
+ ep.isRegistered = true
+
+ // Perform the 3-way handshake.
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
+ if err := h.execute(); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+ ep.notifyAborted()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
+
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+ ep.isConnectNotified = true
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+
+ return ep, nil
+}
+
+func (l *listenContext) addPendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ l.pendingEndpoints[n.ID] = n
+ l.pending.Add(1)
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) removePendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ delete(l.pendingEndpoints, n.ID)
+ l.pending.Done()
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) closeAllPendingEndpoints() {
+ l.pendingMu.Lock()
+ for _, n := range l.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ l.pendingMu.Unlock()
+ l.pending.Wait()
+}
+
+// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
+// endpoint has transitioned out of the listen state (acceptedChan is nil),
+// the new endpoint is closed instead.
+func (e *endpoint) deliverAccepted(n *endpoint) {
+ e.mu.Lock()
+ e.pendingAccepted.Add(1)
+ e.mu.Unlock()
+ defer e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ for {
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
+ return
+ }
+ select {
+ case e.acceptedChan <- n:
+ e.acceptMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+ return
+ default:
+ e.acceptCond.Wait()
+ }
+ }
+}
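+
+// The loop above applies backpressure rather than dropping connections: when
+// acceptedChan is full, the handshake goroutine waits on acceptCond
+// (presumably signaled once Accept drains an endpoint) instead of discarding
+// the completed connection.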
+
+// propagateInheritableOptionsLocked propagates any options set on the listening
+// endpoint to the newly created endpoint.
+//
+// Precondition: e.mu and n.mu must be held.
+func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
+ n.userTimeout = e.userTimeout
+ n.portFlags = e.portFlags
+ n.boundBindToDevice = e.boundBindToDevice
+ n.boundPortFlags = e.boundPortFlags
+}
+
+// reserveTupleLocked reserves an accepted endpoint's tuple.
+//
+// Preconditions:
+// * propagateInheritableOptionsLocked has been called.
+// * e.mu is held.
+func (e *endpoint) reserveTupleLocked() bool {
+ dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ if !e.stack.ReserveTuple(
+ e.effectiveNetProtos,
+ ProtocolNumber,
+ e.ID.LocalAddress,
+ e.ID.LocalPort,
+ e.boundPortFlags,
+ e.boundBindToDevice,
+ dest,
+ ) {
+ return false
+ }
+
+ e.isPortReserved = true
+ e.boundDest = dest
+ return true
+}
+
+// notifyAborted wakes up any waiters on registered, but not accepted
+// endpoints.
+//
+// This is not strictly required normally, as a socket that was never accepted
+// can't really have any registered waiters, except when stack.Wait() is
+// called, which waits for all registered endpoints to stop and expects an
+// EventHUp.
+func (e *endpoint) notifyAborted() {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// handleSynSegment is called in its own goroutine once the listening endpoint
+// receives a SYN segment. It is responsible for completing the handshake and
+// queueing the new endpoint for acceptance.
+//
+// A limited number of these goroutines are allowed before TCP starts using SYN
+// cookies to accept connections.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
+ defer ctx.synRcvdCount.dec()
+ defer func() {
+ e.mu.Lock()
+ e.decSynRcvdCount()
+ e.mu.Unlock()
+ }()
+ defer s.decRef()
+
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+ ctx.removePendingEndpoint(n)
+ n.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ e.deliverAccepted(n)
+}
+
+func (e *endpoint) incSynRcvdCount() bool {
+ e.acceptMu.Lock()
+ canInc := e.synRcvdCount < cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ if canInc {
+ e.synRcvdCount++
+ }
+ return canInc
+}
+
+func (e *endpoint) decSynRcvdCount() {
+ e.synRcvdCount--
+}
+
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.acceptMu.Lock()
+ full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ return full
+}
+
+// handleListenSegment is called when a listening endpoint receives a segment
+// and needs to handle it.
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
+ if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ // If the endpoint is shutdown, reply with reset.
+ //
+ // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
+ // must be sent in response to a SYN-ACK while in the listen
+ // state to prevent completing a handshake from an old SYN.
+ replyWithReset(s, e.sendTOS, e.ttl)
+ return
+ }
+
+ // TODO(b/143300739): Use the userMSS of the listening socket
+ // for accepted sockets.
+
+ switch {
+ case s.flags == header.TCPFlagSyn:
+ opts := parseSynSegmentOptions(s)
+ if ctx.synRcvdCount.inc() {
+ // Only handle the syn if the following conditions hold:
+ // - accept queue is not full.
+ // - number of connections in synRcvd state is less than the
+ // backlog.
+ if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
+ s.incRef()
+ go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ return
+ }
+ ctx.synRcvdCount.dec()
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ } else {
+ // If cookies are in use but the endpoint accept queue
+ // is full then drop the syn.
+ if e.acceptQueueIsFull() {
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+
+ // Send the SYN-ACK without window scaling because we
+ // currently don't encode this information in the cookie.
+ //
+ // Enable Timestamp option if the original syn did have
+ // the timestamp option specified.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(timeStampOffset()),
+ TSEcr: opts.TSVal,
+ MSS: mssForRoute(&s.route),
+ }
+ e.sendSynTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }, synOpts)
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ }
+
+ case (s.flags & header.TCPFlagAck) != 0:
+ if e.acceptQueueIsFull() {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+
+ if !ctx.synRcvdCount.synCookiesInUse() {
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
+ // Send a reset as this is an ACK for which there is no
+ // half-open connection and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here is when a
+ // connection was opened and closed really quickly and
+ // a delayed ACK was received from the sender.
+ replyWithReset(s, e.sendTOS, e.ttl)
+ return
+ }
+
+ iss := s.ackNumber - 1
+ irs := s.sequenceNumber - 1
+
+ // Since SYN cookies are in use, this is potentially an ACK to a
+ // SYN-ACK we sent but have no half-open connection state for,
+ // as cookies are being used to protect against a potential SYN
+ // flood. In such cases, validate the cookie and, if valid, create
+ // a fully connected endpoint and deliver it to the accept queue.
+ //
+ // If not, silently drop the ACK to avoid leaking information
+ // when under a potential syn flood attack.
+ //
+ // Validate the cookie.
+ data, ok := ctx.isCookieValid(s.id, iss, irs)
+ if !ok || int(data) >= len(mssTable) {
+ e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
+ // Create newly accepted endpoint and deliver it.
+ rcvdSynOptions := &header.TCPSynOptions{
+ MSS: mssTable[data],
+ // Disable Window scaling as original SYN is
+ // lost.
+ WS: -1,
+ }
+
+ // When SYN cookies are in use, we enable the timestamp
+ // option only if the ACK specifies it, assuming that the
+ // other end did in fact negotiate the timestamp option
+ // in the original SYN.
+ if s.parsedOptions.TS {
+ rcvdSynOptions.TS = true
+ rcvdSynOptions.TSVal = s.parsedOptions.TSVal
+ rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
+ }
+
+ n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+ n.mu.Lock()
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
+ if !n.reserveTupleLocked() {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ n.isRegistered = true
+
+ // Clear the tsOffset for the newly created
+ // endpoint, as the timestamp was already
+ // randomly offset when the original SYN-ACK was
+ // sent above.
+ n.tsOffset = 0
+
+ // Switch state to connected.
+ n.isConnectNotified = true
+ n.transitionToStateEstablishedLocked(&handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ })
+
+ // Do the delivery in a separate goroutine so
+ // that we don't block the listen loop in case
+ // the application is slow to accept or stops
+ // accepting.
+ //
+ // NOTE: This won't result in an unbounded
+ // number of goroutines as we do check before
+ // entering here that there was at least some
+ // space available in the backlog.
+
+ // Start the protocol goroutine.
+ n.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+ go e.deliverAccepted(n)
+ }
+}
+
+// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
+// its own goroutine and is responsible for handling connection requests.
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ e.mu.Lock()
+ v6Only := e.v6only
+ ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+
+ defer func() {
+ // Mark endpoint as closed. This will prevent goroutines running
+ // handleSynSegment() from attempting to queue new connections
+ // to the endpoint.
+ e.setEndpointState(StateClose)
+
+ // Close any endpoints in SYN-RCVD state.
+ ctx.closeAllPendingEndpoints()
+
+ // Do cleanup if needed.
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+ e.mu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
+ }()
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.notificationWaker, wakerForNotification)
+ s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+ for {
+ e.mu.Unlock()
+ index, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch index {
+ case wakerForNotification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ }
+
+ case wakerForNewSegment:
+ // Process at most maxSegmentsPerWake segments.
+ mayRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ mayRequeue = false
+ break
+ }
+
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up
+ // in the next iteration.
+ if mayRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
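+
+// Processing at most maxSegmentsPerWake segments per wake-up, and re-asserting
+// newSegmentWaker whenever the queue is still non-empty, keeps the listen loop
+// from starving the notification waker under a steady stream of incoming
+// segments.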
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
new file mode 100644
index 000000000..81b740115
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -0,0 +1,1713 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "encoding/binary"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// maxSegmentsPerWake is the maximum number of segments to process in the main
+// protocol goroutine per wake-up. Yielding after this number of segments are
+// processed allows other events to be processed as well (e.g., timeouts,
+// resets, etc.).
+const maxSegmentsPerWake = 100
+
+type handshakeState int
+
+// The following are the possible states of the TCP connection during a 3-way
+// handshake. A depiction of the states and transitions can be found in RFC 793,
+// page 23.
+const (
+ handshakeSynSent handshakeState = iota
+ handshakeSynRcvd
+ handshakeCompleted
+)
+
+// The following are used to set up sleepers.
+const (
+ wakerForNotification = iota
+ wakerForNewSegment
+ wakerForResend
+ wakerForResolution
+)
+
+const (
+ // Maximum space available for options.
+ maxOptionSize = 40
+)
+
+// handshake holds the state used during a TCP 3-way handshake.
+//
+// NOTE: handshake.ep.mu is held during handshake processing. It is released if
+// we are going to block and reacquired when we start processing an event.
+type handshake struct {
+ ep *endpoint
+ state handshakeState
+ active bool
+ flags uint8
+ ackNum seqnum.Value
+
+ // iss is the initial send sequence number, as defined in RFC 793.
+ iss seqnum.Value
+
+ // rcvWnd is the receive window, as defined in RFC 793.
+ rcvWnd seqnum.Size
+
+ // sndWnd is the send window, as defined in RFC 793.
+ sndWnd seqnum.Size
+
+ // mss is the maximum segment size received from the peer.
+ mss uint16
+
+ // sndWndScale is the send window scale, as defined in RFC 1323. A
+ // negative value means no scaling is supported by the peer.
+ sndWndScale int
+
+ // rcvWndScale is the receive window scale, as defined in RFC 1323.
+ rcvWndScale int
+
+ // startTime is the time at which the first SYN/SYN-ACK was sent.
+ startTime time.Time
+
+ // deferAccept, if non-zero, drops the final ACK for a passive
+ // handshake until an ACK segment with data is received or the timeout
+ // is hit.
+ deferAccept time.Duration
+
+ // acked is true if the final ACK for a 3-way handshake has
+ // been received. This is required to stop retransmitting the
+ // original SYN-ACK when deferAccept is enabled.
+ acked bool
+}
+
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+ h := handshake{
+ ep: ep,
+ active: true,
+ rcvWnd: rcvWnd,
+ rcvWndScale: ep.rcvWndScaleForHandshake(),
+ }
+ h.resetState()
+ return h
+}
+
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+ h := newHandshake(ep, rcvWnd)
+ h.resetToSynRcvd(isn, irs, opts, deferAccept)
+ return h
+}
+
+// FindWndScale determines the window scale to use for the given maximum window
+// size.
+func FindWndScale(wnd seqnum.Size) int {
+ if wnd < 0x10000 {
+ return 0
+ }
+
+ max := seqnum.Size(0xffff)
+ s := 0
+ for wnd > max && s < header.MaxWndScale {
+ s++
+ max <<= 1
+ }
+
+ return s
+}
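+
+// For example, a 1 MiB window needs a scale of 5, because
+// 0xffff<<4 = 1048560 < 1<<20 <= 0xffff<<5 = 2097120:
+//
+//  FindWndScale(seqnum.Size(0xffff))  // == 0 (fits without scaling)
+//  FindWndScale(seqnum.Size(1 << 20)) // == 5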
+
+// resetState resets the state of the handshake object such that it becomes
+// ready for a new 3-way handshake.
+func (h *handshake) resetState() {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+
+ h.state = handshakeSynSent
+ h.flags = header.TCPFlagSyn
+ h.ackNum = 0
+ h.mss = 0
+ h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+ isnHasher := jenkins.Sum32(seed)
+ isnHasher.Write([]byte(id.LocalAddress))
+ isnHasher.Write([]byte(id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+ isnHasher.Write(portBuf)
+ binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+ isnHasher.Write(portBuf)
+ // The time period here is 64ns. This is similar to what Linux uses to
+ // generate a sequence number that overlaps less than once
+ // per MSL (2 minutes).
+ //
+ // A 64ns clock ticks 10^9/64 = 15625000 times in a second.
+ // Wrapping the whole 32-bit space would require
+ // 2^32/15625000 ~ 274 seconds.
+ //
+ // This more or less guarantees that we won't reuse the ISN for a new
+ // connection for the same tuple for at least 274s.
+ isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ return seqnum.Value(isn)
+}
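+
+// In RFC 6528 notation, the above computes
+//
+//  ISN = M + F(localip, localport, remoteip, remoteport, secretkey)
+//
+// where F is the keyed Jenkins hash of the connection 4-tuple and M is a
+// 64ns-granularity timer derived from time.Now().UnixNano() >> 6.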
+
+// effectiveRcvWndScale returns the effective receive window scale to be used.
+// If the peer doesn't support window scaling, the effective rcv wnd scale is
+// zero; otherwise it's the value calculated based on the initial rcv wnd.
+func (h *handshake) effectiveRcvWndScale() uint8 {
+ if h.sndWndScale < 0 {
+ return 0
+ }
+ return uint8(h.rcvWndScale)
+}
+
+// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
+// state.
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
+ h.active = false
+ h.state = handshakeSynRcvd
+ h.flags = header.TCPFlagSyn | header.TCPFlagAck
+ h.iss = iss
+ h.ackNum = irs + 1
+ h.mss = opts.MSS
+ h.sndWndScale = opts.WS
+ h.deferAccept = deferAccept
+ h.ep.setEndpointState(StateSynRecv)
+}
+
+// checkAck checks if the ACK number, if present, of a segment received during
+// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
+// response.
+func (h *handshake) checkAck(s *segment) bool {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+ // RFC 793, page 36, states that a reset must be generated when
+ // the connection is in any non-synchronized state and an
+ // incoming segment acknowledges something not yet sent. The
+ // connection remains in the same state.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, s.ackNumber, ack, 0)
+ return false
+ }
+
+ return true
+}
+
+// synSentState handles a segment received when the TCP 3-way handshake is in
+// the SYN-SENT state.
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
+ // RFC 793, page 37, states that in the SYN-SENT state, a reset is
+ // acceptable if the ack field acknowledges the SYN.
+ if s.flagIsSet(header.TCPFlagRst) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
+ // was acceptable then signal the user "error: connection reset", drop
+ // the segment, enter CLOSED state, delete TCB, and return."
+ h.ep.workerCleanup = true
+ // Although the RFC above calls out ECONNRESET, Linux actually returns
+ // ECONNREFUSED here so we do as well.
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ // We are in the SYN-SENT state. We only care about segments that have
+ // the SYN flag.
+ if !s.flagIsSet(header.TCPFlagSyn) {
+ return nil
+ }
+
+ // Parse the SYN options.
+ rcvSynOpts := parseSynSegmentOptions(s)
+
+ // Remember if the Timestamp option was negotiated.
+ h.ep.maybeEnableTimestamp(&rcvSynOpts)
+
+ // Remember if the SACKPermitted option was negotiated.
+ h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
+ // Remember the sequence we'll ack from now on.
+ h.ackNum = s.sequenceNumber + 1
+ h.flags |= header.TCPFlagAck
+ h.mss = rcvSynOpts.MSS
+ h.sndWndScale = rcvSynOpts.WS
+
+ // If this is a SYN ACK response, we only need to acknowledge the SYN
+ // and the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ h.state = handshakeCompleted
+
+ h.ep.transitionToStateEstablishedLocked(h)
+
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+ return nil
+ }
+
+ // A SYN segment was received, but no ACK in it. We acknowledge the SYN
+ // but resend our own SYN and wait for it to be acknowledged in the
+ // SYN-RCVD state.
+ h.state = handshakeSynRcvd
+ ttl := h.ep.ttl
+ amss := h.ep.amss
+ h.ep.setEndpointState(StateSynRecv)
+ synOpts := header.TCPSynOptions{
+ WS: int(h.effectiveRcvWndScale()),
+ TS: rcvSynOpts.TS,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTimestamp(),
+
+ // We only send SACKPermitted if the other side indicated it
+ // permits SACK. This is not explicitly defined in the RFC but
+ // this is the behaviour implemented by Linux.
+ SACKPermitted: rcvSynOpts.SACKPermitted,
+ MSS: amss,
+ }
+ if ttl == 0 {
+ ttl = s.route.DefaultTTL()
+ }
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+ return nil
+}
+
+// synRcvdState handles a segment received when the TCP 3-way handshake is in
+// the SYN-RCVD state.
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+ if s.flagIsSet(header.TCPFlagRst) {
+ // RFC 793, page 37, states that in the SYN-RCVD state, a reset
+ // is acceptable if the sequence number is in the window.
+ if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a
+ // sequence number outside of the window causes an ACK with the proper seq
+ // number and "After sending the acknowledgment, drop the unacceptable
+ // segment and return."
+ if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd)
+ return nil
+ }
+
+ if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
+ // We received two SYN segments with different sequence
+ // numbers, so we reset this and restart the whole
+ // process, except that we don't reset the timer.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
+
+ if !h.active {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ h.resetState()
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: h.ep.sendTSOk,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTimestamp(),
+ SACKPermitted: h.ep.sackPermitted,
+ MSS: h.ep.amss,
+ }
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+ return nil
+ }
+
+ // We have previously received (and acknowledged) the peer's SYN. If the
+ // peer acknowledges our SYN, the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ // If deferAccept is non-zero, this is a bare ACK, and the
+ // timeout has not been hit, then drop the ACK.
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ h.acked = true
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
+ // If the timestamp option is negotiated and the segment does
+ // not carry a timestamp option then the segment must be dropped
+ // as per https://tools.ietf.org/html/rfc7323#section-3.2.
+ if h.ep.sendTSOk && !s.parsedOptions.TS {
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
+ // Update timestamp if required. See RFC7323, section-4.3.
+ if h.ep.sendTSOk && s.parsedOptions.TS {
+ h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
+ }
+ h.state = handshakeCompleted
+
+ h.ep.transitionToStateEstablishedLocked(h)
+
+ // If the segment has data then requeue it for the receiver
+ // to process it again once main loop is started.
+ if s.data.Size() > 0 {
+ s.incRef()
+ h.ep.enqueueSegment(s)
+ }
+ return nil
+ }
+
+ return nil
+}
+
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
+// processSegments goes through the segment queue and processes up to
+// maxSegmentsPerWake (if they're available).
+func (h *handshake) processSegments() *tcpip.Error {
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := h.ep.segmentQueue.dequeue()
+ if s == nil {
+ return nil
+ }
+
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+
+ // We stop processing packets once the handshake is completed,
+ // otherwise we may process packets meant to be processed by
+ // the main protocol goroutine.
+ if h.state == handshakeCompleted {
+ break
+ }
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if !h.ep.segmentQueue.empty() {
+ h.ep.newSegmentWaker.Assert()
+ }
+
+ return nil
+}
+
+func (h *handshake) resolveRoute() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resolutionWaker := &sleep.Waker{}
+ s.AddWaker(resolutionWaker, wakerForResolution)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ defer s.Done()
+
+ // Initial action is to resolve route.
+ index := wakerForResolution
+ for {
+ switch index {
+ case wakerForResolution:
+ if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ if err == tcpip.ErrNoLinkAddress {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ } else if err != nil {
+ h.ep.stats.SendErrors.NoRoute.Increment()
+ }
+ // Either success (err == nil) or failure.
+ return err
+ }
+ // Resolution not completed. Keep trying...
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ h.ep.route.RemoveWaker(resolutionWaker)
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ h.ep.mu.Unlock()
+ <-h.ep.undrain
+ h.ep.mu.Lock()
+ }
+ }
+
+ // Wait for notification.
+ index, _ = s.Fetch(true)
+ }
+}
+
+// execute executes the TCP 3-way handshake.
+func (h *handshake) execute() *tcpip.Error {
+ if h.ep.route.IsResolutionRequired() {
+ if err := h.resolveRoute(); err != nil {
+ return err
+ }
+ }
+
+ h.startTime = time.Now()
+ // Initialize the resend timer.
+ resendWaker := sleep.Waker{}
+ timeOut := time.Duration(time.Second)
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
+ defer rt.Stop()
+
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ var sackEnabled SACKEnabled
+ if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+ // If stack returned an error when checking for SACKEnabled
+ // status then just default to switching off SACK negotiation.
+ sackEnabled = false
+ }
+
+ // Send the initial SYN segment and loop until the handshake is
+ // completed.
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
+
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: true,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTimestamp(),
+ SACKPermitted: bool(sackEnabled),
+ MSS: h.ep.amss,
+ }
+
+ // Execute is also called in a listen context so we want to make sure we
+ // only send the TS/SACK option when we received the TS/SACK in the
+ // initial SYN.
+ if h.state == handshakeSynRcvd {
+ synOpts.TS = h.ep.sendTSOk
+ synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ if h.sndWndScale < 0 {
+ // Disable window scaling if the peer did not send us
+ // the window scaling option.
+ synOpts.WS = -1
+ }
+ }
+
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+
+ for h.state != handshakeCompleted {
+ h.ep.mu.Unlock()
+ index, _ := s.Fetch(true)
+ h.ep.mu.Lock()
+ switch index {
+
+ case wakerForResend:
+ timeOut *= 2
+ if timeOut > MaxRTO {
+ return tcpip.ErrTimeout
+ }
+ rt.Reset(timeOut)
+ // Resend the SYN/SYN-ACK only if the following conditions hold.
+ // - It's an active handshake (deferAccept does not apply)
+ // - It's a passive handshake and we have not yet got the final-ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+ }
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if (n&notifyClose)|(n&notifyAbort) != 0 {
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ h.ep.mu.Unlock()
+ <-h.ep.undrain
+ h.ep.mu.Lock()
+ }
+
+ case wakerForNewSegment:
+ if err := h.processSegments(); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
+ synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ if synOpts.TS {
+ s.parsedOptions.TSVal = synOpts.TSVal
+ s.parsedOptions.TSEcr = synOpts.TSEcr
+ }
+ return synOpts
+}
+
+var optionPool = sync.Pool{
+ New: func() interface{} {
+ return &[maxOptionSize]byte{}
+ },
+}
+
+func getOptions() []byte {
+ return (*optionPool.Get().(*[maxOptionSize]byte))[:]
+}
+
+func putOptions(options []byte) {
+ // Reslice to full capacity.
+ optionPool.Put(optionsToArray(options))
+}
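+
+// getOptions and putOptions are expected to be paired around each send so
+// that the fixed-size backing array is recycled; sendSynTCP below is a
+// minimal sketch of the pattern:
+//
+//  tf.opts = makeSynOptions(opts)                 // backed by getOptions()
+//  e.sendTCP(r, tf, buffer.VectorisedView{}, nil) // segment sent
+//  putOptions(tf.opts)                            // array returned to pool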
+
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+ // Emulate Linux option order. This is as follows:
+ //
+ // if md5: NOP NOP MD5SIG 18 md5sig(16)
+ // if mss: MSS 4 mss(2)
+ // if ts and sack_advertise:
+ // SACK 2 TIMESTAMP 2 timestamp(8)
+ // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+ // elif sack: NOP NOP SACK 2
+ // if wscale: NOP WINDOW 3 ws(1)
+ // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
+ // [for each block] start_seq(4) end_seq(4)
+ // if fastopen_cookie:
+ // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+ // else: FASTOPEN (2 + len(cookie))
+ // cookie(variable) [padding to four bytes]
+ //
+ options := getOptions()
+
+ // Always encode the mss.
+ offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+ // Special ordering is required here. If both TS and SACK are enabled,
+ // then the SACK option precedes TS, with no padding. If they are
+ // enabled individually, then we see padding before the option.
+ if opts.TS && opts.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.TS {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.SACKPermitted {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ }
+
+ // Initialize the WS option.
+ if opts.WS >= 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeWSOption(opts.WS, options[offset:])
+ }
+
+ // Padding to the end; note that this never applies unless we add a
+ // fastopen option, so we always expect the offset to remain the same.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
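+
+// For a typical SYN carrying MSS, SACKPermitted, TS and WS, the encoding
+// above yields a 20-byte, 4-byte-aligned layout with no trailing padding:
+//
+//  MSS(4) | SACK_PERMITTED(2) | TIMESTAMP(10) | NOP(1) WS(3)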
+
+// tcpFields is a struct to carry different parameters required by the
+// send*TCP variant functions below.
+type tcpFields struct {
+ id stack.TransportEndpointID
+ ttl uint8
+ tos uint8
+ flags byte
+ seq seqnum.Value
+ ack seqnum.Value
+ rcvWnd seqnum.Size
+ opts []byte
+ txHash uint32
+}
+
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+ tf.opts = makeSynOptions(opts)
+ // We ignore SYN send errors and let the callers re-attempt send.
+ if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
+ e.stats.SendErrors.SynSendToNetworkFailed.Increment()
+ }
+ putOptions(tf.opts)
+ return nil
+}
+
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ tf.txHash = e.txHash
+ if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
+ e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
+ return err
+ }
+ e.stats.SegmentsSent.Increment()
+ return nil
+}
+
+func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
+ optLen := len(tf.opts)
+ hdr := &pkt.Header
+ packetSize := pkt.Data.Size()
+ // Initialize the header.
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ pkt.TransportHeader = buffer.View(tcp)
+ tcp.Encode(&header.TCPFields{
+ SrcPort: tf.id.LocalPort,
+ DstPort: tf.id.RemotePort,
+ SeqNum: uint32(tf.seq),
+ AckNum: uint32(tf.ack),
+ DataOffset: uint8(header.TCPMinimumSize + optLen),
+ Flags: tf.flags,
+ WindowSize: uint16(tf.rcvWnd),
+ })
+ copy(tcp[header.TCPMinimumSize:], tf.opts)
+
+ length := uint16(hdr.UsedLength() + packetSize)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ // Only calculate the checksum if offloading isn't supported.
+ if gso != nil && gso.NeedsCsum {
+ // This is called CHECKSUM_PARTIAL in the Linux kernel. We
+ // calculate a checksum of the pseudo-header and save it in the
+ // TCP header, then the kernel calculates a checksum of the
+ // header and data to get the right sum of the TCP packet.
+ tcp.SetChecksum(xsum)
+ } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
+ }
+}
+
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ // We need to shallow clone the VectorisedView here as ReadToView will
+ // split the VectorisedView and Trim underlying views as it splits. Not
+ // doing the clone here will cause the underlying views of data itself
+ // to be altered.
+ data = data.Clone(nil)
+
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
+ }
+
+ mss := int(gso.MSS)
+ n := (data.Size() + mss - 1) / mss
+
+ size := data.Size()
+ hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
+ var pkts stack.PacketBufferList
+ for i := 0; i < n; i++ {
+ packetSize := mss
+ if packetSize > size {
+ packetSize = size
+ }
+ size -= packetSize
+ var pkt stack.PacketBuffer
+ pkt.Header = buffer.NewPrependable(hdrSize)
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ data.ReadToVV(&pkt.Data, packetSize)
+ buildTCPHdr(r, tf, &pkt, gso)
+ tf.seq = tf.seq.Add(seqnum.Size(packetSize))
+ pkts.PushBack(&pkt)
+ }
+
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos})
+ if err != nil {
+ r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
+ }
+ r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent))
+ return err
+}
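+
+// As a concrete example of the batching above: 3000 bytes of data with a
+// gso.MSS of 1460 yields n = (3000+1459)/1460 = 3 packet buffers carrying
+// 1460, 1460 and 80 bytes, all flushed via a single WritePackets call.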
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
+ }
+
+ if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ return sendTCPBatch(r, tf, data, gso, owner)
+ }
+
+ pkt := &stack.PacketBuffer{
+ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ Data: data,
+ Hash: tf.txHash,
+ Owner: owner,
+ }
+ buildTCPHdr(r, tf, pkt, gso)
+
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
+ r.Stats().TCP.SegmentSendErrors.Increment()
+ return err
+ }
+ r.Stats().TCP.SegmentsSent.Increment()
+ if (tf.flags & header.TCPFlagRst) != 0 {
+ r.Stats().TCP.ResetsSent.Increment()
+ }
+ return nil
+}
+
+// makeOptions makes an options slice.
+func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
+ options := getOptions()
+ offset := 0
+
+ // N.B. the ordering here matches the ordering used by Linux internally
+ // and described in the raw makeOptions function. We don't include
+ // unnecessary cases here (post-connection).
+ if e.sendTSOk {
+ // Embed the timestamp if timestamp has been enabled.
+ //
+ // We only use the lower 32 bits of the unix time in
+ // milliseconds. This is similar to what Linux does where it
+ // uses the lower 32 bits of the jiffies value in the tsVal
+ // field of the timestamp option.
+ //
+ // Further, RFC7323 section-5.4 recommends millisecond
+ // resolution as the lowest recommended resolution for the
+ // timestamp clock.
+ //
+ // Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
+ }
+ if e.sackPermitted && len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ // We expect the above to produce an aligned offset.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+// sendRaw sends a TCP segment to the endpoint's peer.
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ var sackBlocks []header.SACKBlock
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+ }
+ options := e.makeOptions(sackBlocks)
+ err := e.sendTCP(&e.route, tcpFields{
+ id: e.ID,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: rcvWnd,
+ opts: options,
+ }, data, e.gso)
+ putOptions(options)
+ return err
+}
+
+func (e *endpoint) handleWrite() *tcpip.Error {
+ // Move packets from send queue to send list. The queue is accessible
+ // from other goroutines and protected by the send mutex, while the send
+ // list is only accessible from the handler goroutine, so it needs no
+ // mutexes.
+ e.sndBufMu.Lock()
+
+ first := e.sndQueue.Front()
+ if first != nil {
+ e.snd.writeList.PushBackList(&e.sndQueue)
+ e.sndBufInQueue = 0
+ }
+
+ e.sndBufMu.Unlock()
+
+ // Initialize the next segment to write if it's currently nil.
+ if e.snd.writeNext == nil {
+ e.snd.writeNext = first
+ }
+
+ // Push out any new packets.
+ e.snd.sendData()
+
+ return nil
+}
+
+func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
+ // Drain the send queue.
+ e.handleWrite()
+
+ // Mark send side as closed.
+ e.snd.closed = true
+
+ return nil
+}
+
+// resetConnectionLocked puts the endpoint in an error state with the given
+// error code and sends a RST if and only if the error is not ErrConnectionReset
+// indicating that the connection is being reset due to receiving a RST. This
+// method must only be called from the protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+ // Only send a reset if the connection is being aborted for a reason
+ // other than receiving a reset.
+ e.setEndpointState(StateError)
+ e.HardError = err
+ if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of the window being
+ // shrunk, which can cause sndNxt to fall outside the acceptable window
+ // on the receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
+ }
+}
+
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit.
+func (e *endpoint) completeWorkerLocked() {
+ // Worker is terminating (either due to moving to
+ // the CLOSED or ERROR state), so ensure we release all
+ // registrations and port reservations even if the socket
+ // itself is not yet closed by the application.
+ e.workerRunning = false
+ if e.workerCleanup {
+ e.cleanupLocked()
+ }
+}
+
+// transitionToStateEstablishedLocked transitions a given endpoint
+// to an established state using the handshake parameters provided.
+// It also initializes sender/receiver.
+func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
+ e.setEndpointState(StateEstablished)
+}
+
+// transitionToStateCloseLocked ensures that the endpoint is
+// cleaned up from the transport demuxer, "before" moving to
+// StateClose. This will ensure that no packet will be
+// delivered to this endpoint from the demuxer when the endpoint
+// is transitioned to StateClose.
+func (e *endpoint) transitionToStateCloseLocked() {
+ if e.EndpointState() == StateClose {
+ return
+ }
+ // Mark the endpoint as fully closed for reads/writes.
+ e.cleanupLocked()
+ e.setEndpointState(StateClose)
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+}
+
+// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
+// segment to any other endpoint other than the current one. This is called
+// only when the endpoint is in StateClose and we want to deliver the segment
+// to any other listening endpoint. We reply with RST if we cannot find one.
+func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ // Dual-stack socket, try IPv4.
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ }
+ if ep == nil {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ s.decRef()
+ return
+ }
+
+ if e == ep {
+ panic("current endpoint not removed from demuxer, enqueing segments to itself")
+ }
+
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
+ }
+}
+
+// Drain segment queue from the endpoint and try to re-match the segment to a
+// different endpoint. This is used when the current endpoint is transitioned to
+// StateClose and has been unregistered from the transport demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ switch e.EndpointState() {
+ // In case of a RST in CLOSE-WAIT linux moves
+ // the socket to closed state with an error set
+ // to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.transitionToStateCloseLocked()
+ e.HardError = tcpip.ErrAborted
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ return false, nil
+ default:
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+
+ // Notify protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState().closed() {
+ return nil
+ }
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
+ }
+ if !cont {
+ s.decRef()
+ return nil
+ }
+ }
+
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ // Send an ACK for all processed packets if needed.
+ if e.rcv.rcvNxt != e.snd.maxSentAck {
+ e.snd.sendAck()
+ }
+
+ e.resetKeepaliveTimer(true /* receivedData */)
+
+ return nil
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ return false, err
+ }
+ if drop {
+ return true, nil
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ state := e.state
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return false, nil
+ }
+
+ e.snd.handleRcvdSegment(s)
+ }
+
+ return true, nil
+}
+
+// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
+// keepalive packets periodically when the connection is idle. If we don't hear
+// from the other side after a number of tries, we terminate the connection.
+func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ userTimeout := e.userTimeout
+
+ e.keepalive.Lock()
+ if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ e.keepalive.Unlock()
+ return nil
+ }
+
+ // If a userTimeout is set then abort the connection if it is
+ // exceeded.
+ if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+
+ if e.keepalive.unacked >= e.keepalive.count {
+ e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+
+ // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
+ // seg.seq = snd.nxt-1.
+ e.keepalive.unacked++
+ e.keepalive.Unlock()
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.resetKeepaliveTimer(false)
+ return nil
+}
+
+// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
+// whether it is enabled for this endpoint.
+func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
+ e.keepalive.Lock()
+ if receivedData {
+ e.keepalive.unacked = 0
+ }
+ // Start the keepalive timer IFF it's enabled and there is no pending
+ // data to send.
+ if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ e.keepalive.timer.disable()
+ e.keepalive.Unlock()
+ return
+ }
+ if e.keepalive.unacked > 0 {
+ e.keepalive.timer.enable(e.keepalive.interval)
+ } else {
+ e.keepalive.timer.enable(e.keepalive.idle)
+ }
+ e.keepalive.Unlock()
+}
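+
+// Together with keepaliveTimerExpired above, this yields Linux-style
+// keepalive scheduling: the first probe fires after keepalive.idle of
+// silence, subsequent probes fire every keepalive.interval, and the
+// connection is timed out once keepalive.count probes go unacknowledged.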
+
+// disableKeepaliveTimer stops the keepalive timer.
+func (e *endpoint) disableKeepaliveTimer() {
+ e.keepalive.Lock()
+ e.keepalive.timer.disable()
+ e.keepalive.Unlock()
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+ e.mu.Lock()
+ var closeTimer *time.Timer
+ var closeWaker sleep.Waker
+
+ epilogue := func() {
+ // e.mu is expected to be held upon entering this section.
+
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ }
+
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+
+ if handshake {
+ // This is an active connection, so we must initiate the 3-way
+ // handshake, and then inform potential waiters about its
+ // completion.
+ initialRcvWnd := e.initialReceiveWindow()
+ h := newHandshake(e, seqnum.Size(initialRcvWnd))
+ h.ep.setEndpointState(StateSynSent)
+
+ if err := h.execute(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.setEndpointState(StateError)
+ e.HardError = err
+
+ e.workerCleanup = true
+ // Lock released below.
+ epilogue()
+ return err
+ }
+ }
+
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
+
+ drained := e.drainDone != nil
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ // Set up the functions that will be called when the main protocol loop
+ // wakes up.
+ funcs := []struct {
+ w *sleep.Waker
+ f func() *tcpip.Error
+ }{
+ {
+ w: &e.sndWaker,
+ f: e.handleWrite,
+ },
+ {
+ w: &e.sndCloseWaker,
+ f: e.handleClose,
+ },
+ {
+ w: &closeWaker,
+ f: func() *tcpip.Error {
+ // This means the socket is being closed because
+ // the TCP FIN-WAIT2 timeout was hit. Just
+ // mark the socket as closed.
+ e.transitionToStateCloseLocked()
+ e.workerCleanup = true
+ return nil
+ },
+ },
+ {
+ w: &e.snd.resendWaker,
+ f: func() *tcpip.Error {
+ if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+ return nil
+ },
+ },
+ {
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
+ w: &e.keepalive.waker,
+ f: e.keepaliveTimerExpired,
+ },
+ {
+ w: &e.notificationWaker,
+ f: func() *tcpip.Error {
+ n := e.fetchNotifications()
+ if n&notifyNonZeroReceiveWindow != 0 {
+ e.rcv.nonZeroWindow()
+ }
+
+ if n&notifyReceiveWindowChanged != 0 {
+ e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
+ }
+
+ if n&notifyMTUChanged != 0 {
+ e.sndBufMu.Lock()
+ count := e.packetTooBigCount
+ e.packetTooBigCount = 0
+ mtu := e.sndMTU
+ e.sndBufMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
+ }
+
+ if n&notifyClose != 0 && closeTimer == nil {
+ if e.EndpointState() == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+ }
+
+ if n&notifyKeepaliveChanged != 0 {
+ // The timer could fire in the background
+ // when the endpoint is drained. That's
+ // OK. See above.
+ e.resetKeepaliveTimer(true)
+ }
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
+ return err
+ }
+ }
+ if !e.EndpointState().closed() {
+ // Only block the worker if the endpoint
+ // is not in closed state or error state.
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ }
+ }
+
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
+ return nil
+ },
+ },
+ }
+
+ // Initialize the sleeper based on the wakers in funcs.
+ s := sleep.Sleeper{}
+ for i := range funcs {
+ s.AddWaker(funcs[i].w, i)
+ }
+
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
+ // The following assertions and notifications are needed for restored
+ // endpoints. Fresh newly created endpoints have empty states and should
+ // not invoke any.
+ if !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ // Main loop. Handle segments until both send and receive ends of the
+ // connection have completed.
+ cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.workerCleanup = true
+ if err != nil {
+ e.resetConnectionLocked(err)
+ }
+ // Lock released below.
+ epilogue()
+ }
+
+loop:
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
+ e.mu.Unlock()
+ v, _ := s.Fetch(true)
+ e.mu.Lock()
+
+ // We need to double check here because the notification may be
+ // stale by the time we got around to processing it.
+ switch e.EndpointState() {
+ case StateError:
+ // If the endpoint has already transitioned to an ERROR
+ // state just pass nil here as any reset that may need
+ // to be sent etc should already have been done and we
+ // just want to terminate the loop and cleanup the
+ // endpoint.
+ cleanupOnError(nil)
+ return nil
+ case StateTimeWait:
+ fallthrough
+ case StateClose:
+ break loop
+ default:
+ if err := funcs[v].f(); err != nil {
+ cleanupOnError(err)
+ return nil
+ }
+ }
+ }
+
+ var reuseTW func()
+ if e.EndpointState() == StateTimeWait {
+ // Disable close timer as we are now entering real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.workerCleanup = true
+ reuseTW = e.doTimeWait()
+ }
+
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
+ }
+
+ e.transitionToStateCloseLocked()
+
+ // Lock released below.
+ epilogue()
+
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // the timewait and redirect the segment to the listener queue.
+ if reuseTW != nil {
+ reuseTW()
+ }
+
+ return nil
+}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than the one seen on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ defer s.Done()
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.mu.Unlock()
+ v, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go
new file mode 100644
index 000000000..cfc304616
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect_unsafe.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// optionsToArray converts a slice of capacity >= maxOptionSize to an array.
+//
+// optionsToArray panics if the capacity of options is smaller than
+// maxOptionSize.
+func optionsToArray(options []byte) *[maxOptionSize]byte {
+ // Reslice to full capacity.
+ options = options[0:maxOptionSize]
+ return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data))
+}
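+
+// On Go 1.17 and later the same conversion could be written without
+// reflect via a slice-to-array-pointer conversion, e.g.:
+//
+//  func optionsToArray(options []byte) *[maxOptionSize]byte {
+//      return (*[maxOptionSize]byte)(options[:maxOptionSize])
+//  }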
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
new file mode 100644
index 000000000..7b1f5e763
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time `state:".(unixTime)"`
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
+ // per RFC.
+ c float64
+
+ // k is the time period that the above function takes to increase the
+ // current window size to W_max if there are no further congestion
+ // events and is calculated using the following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplicative decrease factor. That is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0)=W_max*beta_cubic.
+ beta float64
+
+ // wC is window computed by CUBIC at time t. It's calculated using the
+ // formula:
+ //
+ // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ wC float64
+
+ // wEst is the window computed by CUBIC at time t+RTT, i.e.,
+ // W_cubic(t+RTT).
+ wEst float64
+
+ s *sender
+}
+
+// newCubicCC returns a partially initialized cubic state with the constants
+// beta and c set and t set to current time.
+func newCubicCC(s *sender) *cubicState {
+ return &cubicState{
+ t: time.Now(),
+ beta: 0.7,
+ c: 0.4,
+ s: s,
+ }
+}
+
+// enterCongestionAvoidance is used to initialize cubic in cases where we exit
+// SlowStart without a real congestion event taking place. This can happen when
+// a connection goes back to slow start due to a retransmit and we exceed the
+// previously lowered ssThresh without experiencing packet loss.
+//
+// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
+func (c *cubicState) enterCongestionAvoidance() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.7 &
+ // https://tools.ietf.org/html/rfc8312#section-4.8
+ if c.numCongestionEvents == 0 {
+ c.k = 0
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+ }
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If after adjusting the congestion window we cross
+// the ssThresh then it will return the number of packets that must be consumed
+// in congestion avoidance mode.
+func (c *cubicState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := c.s.sndCwnd + packetsAcked
+ enterCA := false
+ if newcwnd >= c.s.sndSsthresh {
+ newcwnd = c.s.sndSsthresh
+ c.s.sndCAAckCount = 0
+ enterCA = true
+ }
+
+ packetsAcked -= newcwnd - c.s.sndCwnd
+ c.s.sndCwnd = newcwnd
+ if enterCA {
+ c.enterCongestionAvoidance()
+ }
+ return packetsAcked
+}
+
+// Update updates cubic's internal state variables. It must be called on every
+// ACK received.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) Update(packetsAcked int) {
+ if c.s.sndCwnd < c.s.sndSsthresh {
+ packetsAcked = c.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ } else {
+ c.s.rtt.Lock()
+ srtt := c.s.rtt.srtt
+ c.s.rtt.Unlock()
+ c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ }
+}
+
+// cubicCwnd computes the CUBIC congestion window after t seconds from last
+// congestion event.
+func (c *cubicState) cubicCwnd(t float64) float64 {
+ return c.c*math.Pow(t, 3.0) + c.wMax
+}
+
+// getCwnd returns the current congestion window as computed by CUBIC.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
+ elapsed := time.Since(c.t).Seconds()
+
+ // Compute the window as per Cubic after 'elapsed' time
+ // since last congestion event.
+ c.wC = c.cubicCwnd(elapsed - c.k)
+
+ // Compute the TCP friendly estimate of the congestion window.
+ c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+
+ // Make sure in the TCP friendly region CUBIC performs at least
+ // as well as Reno.
+ if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ // TCP Friendly region of cubic.
+ return int(c.wEst)
+ }
+
+ // In Concave/Convex region of CUBIC, calculate what CUBIC window
+ // will be after 1 RTT and use that to grow congestion window
+ // for every ack.
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per Section 4.3, for each received ACK cwnd must be
+ // incremented by (W_cubic(t+RTT) - cwnd)/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
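+
+// For instance, one second after a congestion event that set wMax = 100
+// packets (so k = cbrt(75) ~ 4.22, recomputed in fastConvergence below),
+// the cubic term gives W_cubic(1) = 0.4*(1-4.22)^3 + 100 ~ 86.7 packets.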
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
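+
+// For example, with beta = 0.7 a flow that failed to regain its previous
+// maximum has wMax shrunk to wMax*(1+0.7)/2 = 0.85*wMax, and with
+// wMax = 100 packets and c = 0.4 this gives k = cbrt(100*0.3/0.4) =
+// cbrt(75) ~ 4.22 seconds to grow back to wMax.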
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold reduces ssThresh as described in
+// https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go
new file mode 100644
index 000000000..d0f58cfaf
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+ return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+ c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100644
index 000000000..98aecab9e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// epQueue is a queue of endpoints.
+type epQueue struct {
+ mu sync.Mutex
+ list endpointList
+}
+
+// enqueue adds e to the queue if the endpoint is not already on the queue.
+func (q *epQueue) enqueue(e *endpoint) {
+ q.mu.Lock()
+ if e.pendingProcessing {
+ q.mu.Unlock()
+ return
+ }
+ q.list.PushBack(e)
+ e.pendingProcessing = true
+ q.mu.Unlock()
+}
+
+// dequeue removes and returns the first element from the queue if available;
+// it returns nil otherwise.
+func (q *epQueue) dequeue() *endpoint {
+ q.mu.Lock()
+ if e := q.list.Front(); e != nil {
+ q.list.Remove(e)
+ e.pendingProcessing = false
+ q.mu.Unlock()
+ return e
+ }
+ q.mu.Unlock()
+ return nil
+}
+
+// empty returns true if the queue is empty, false otherwise.
+func (q *epQueue) empty() bool {
+ q.mu.Lock()
+ v := q.list.Empty()
+ q.mu.Unlock()
+ return v
+}
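+
+// A sketch of the resulting queue semantics (hypothetical usage, not part of
+// the interfaces above): the pendingProcessing flag makes enqueue idempotent,
+// so an endpoint is never queued twice.
+//
+//   var q epQueue
+//   q.enqueue(e)
+//   q.enqueue(e) // No-op: e is already queued.
+//   q.dequeue()  // Returns e and clears e.pendingProcessing.
+//   q.dequeue()  // Returns nil: the queue is empty.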
+
+// processor is responsible for processing packets queued to a tcp endpoint.
+type processor struct {
+ epQ epQueue
+ sleeper sleep.Sleeper
+ newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
+}
+
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) queueEndpoint(ep *endpoint) {
+ // Queue an endpoint for processing by the processor goroutine.
+ p.epQ.enqueue(ep)
+ p.newEndpointWaker.Assert()
+}
+
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
+ for {
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
+ }
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
+ if ep.segmentQueue.empty() {
+ continue
+ }
+
+ // If the socket has transitioned out of the connected state then just
+ // let the worker handle the packet.
+ //
+ // NOTE: We read the state without holding e.mu, which means that by
+ // the time we get to handleSegments the endpoint may no longer be in
+ // ESTABLISHED. But this is fine, as all normal shutdown states are
+ // handled by handleSegments, and if the endpoint moves to a
+ // CLOSED/ERROR state then handleSegments is a no-op.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
+ ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
+ }
+ ep.mu.Unlock()
+ } else {
+ ep.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for processing inbound segments. This fixed pool of processor goroutines
+// does full TCP processing. The processor is selected based on the hash of
+// the endpoint ID to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []processor
+ seed uint32
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
+ }
+}
+
+func (d *dispatcher) close() {
+ for i := range d.processors {
+ d.processors[i].close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ d.wg.Wait()
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in the established state, let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
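+
+// To summarize the dispatch decision above: malformed segments are dropped and
+// counted; segments for endpoints not in ESTABLISHED (e.g. LISTEN or a closing
+// state) wake the per-endpoint worker; segments for ESTABLISHED endpoints are
+// handed to a hash-selected processor for fast-path handling.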
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload[:])
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
+}
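+
+// Because the hash above is keyed only by the seed and the connection 4-tuple,
+// processor selection is stable for the lifetime of a connection. A sketch of
+// the invariant (hypothetical usage, not part of this change):
+//
+//   p1 := d.selectProcessor(id)
+//   p2 := d.selectProcessor(id)
+//   // p1 == p2: segments of one connection always land on the same processor
+//   // goroutine, which is what preserves in-order delivery.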
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
new file mode 100644
index 000000000..804e95aea
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -0,0 +1,651 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestV4MappedConnectOnV6Only(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Start connection attempt, it must fail.
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
+ if err != tcpip.ErrNoRoute {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+}
+
+func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
+ // Start connection attempt.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventOut)
+ defer c.WQ.EventUnregister(&we)
+
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
+ if err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv4(t, b, synCheckers...)
+
+ tcp := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: tcp.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK packet.
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv4(t, c.GetPacket(), ackCheckers...)
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+ if err != nil {
+ t.Fatalf("Unexpected error when connecting: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+}
+
+func TestV4MappedConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Test the connection request.
+ testV4Connect(t, c)
+}
+
+func TestV4ConnectWhenBoundToWildcard(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV4Connect(t, c)
+}
+
+func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to v4 mapped wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV4Connect(t, c)
+}
+
+func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to v4 mapped address.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV4Connect(t, c)
+}
+
+func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
+ // Start connection attempt to IPv6 address.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventOut)
+ defer c.WQ.EventUnregister(&we)
+
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
+ if err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetV6Packet()
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv6(t, b, synCheckers...)
+
+ tcp := header.TCP(header.IPv6(b).Payload())
+ c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: tcp.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK packet.
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+ if err != nil {
+ t.Fatalf("Unexpected error when connecting: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+}
+
+func TestV6Connect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
+func TestV6ConnectV6Only(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
+func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
+func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to local address.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
+func TestV4RefuseOnV6Only(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the RST reply.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.AckNum(uint32(irs)+1),
+ ),
+ )
+}
+
+func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind and listen.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the RST reply.
+ checker.IPv6(t, c.GetV6Packet(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.AckNum(uint32(irs)+1),
+ ),
+ )
+}
+
+func testV4Accept(t *testing.T, c *context.Context) {
+ c.SetGSOEnabled(true)
+ defer c.SetGSOEnabled(false)
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1),
+ ),
+ )
+
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Make sure we get the same error when calling the original ep and the
+ // new one. This validates that v4-mapped endpoints are still able to
+ // query the V6Only flag, whereas pure v4 endpoints are not.
+ _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
+ t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
+ }
+
+ // Check the peer address.
+ addr, err := nep.GetRemoteAddress()
+ if err != nil {
+ t.Fatalf("GetRemoteAddress failed failed: %v", err)
+ }
+
+ if addr.Addr != context.TestAddr {
+ t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
+ }
+
+ data := "Don't panic"
+ nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ b = c.GetPacket()
+ tcp = header.TCP(header.IPv4(b).Payload())
+ if string(tcp.Payload()) != data {
+ t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ }
+}
+
+func TestV4AcceptOnV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to v4 mapped wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind and listen.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func TestV6AcceptOnV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind and listen.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetV6Packet()
+ tcp := header.TCP(header.IPv6(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ checker.IPv6(t, b,
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1),
+ ),
+ )
+
+ // Send ACK.
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Make sure we can still query the v6 only status of the new endpoint,
+ // that is, that it is in fact a v6 socket.
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
+ t.Fatalf("GetSockOpt failed failed: %v", err)
+ }
+
+ // Check the peer address.
+ addr, err := nep.GetRemoteAddress()
+ if err != nil {
+ t.Fatalf("GetRemoteAddress failed failed: %v", err)
+ }
+
+ if addr.Addr != context.TestV6Addr {
+ t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+ }
+}
+
+func TestV4AcceptOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func testV4ListenClose(t *testing.T, c *context.Context) {
+ // Set the SynRcvd threshold to zero to force a syn cookie based accept
+ // to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ }
+
+ const n = uint16(32)
+
+ // Start listening.
+ if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ irs := seqnum.Value(789)
+ for i := uint16(0); i < n; i++ {
+ // Send a SYN request.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + i,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Each of these ACK's will cause a syn-cookie based connection to be
+ // accepted and delivered to the listening endpoint.
+ for i := uint16(0); i < n; i++ {
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ nep.Close()
+ c.EP.Close()
+}
+
+func TestV4ListenCloseOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4ListenClose(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
new file mode 100644
index 000000000..caac6ef57
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -0,0 +1,2888 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "runtime"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EndpointState represents the state of a TCP endpoint.
+type EndpointState uint32
+
+// Endpoint states. Note that they are represented in a netstack-specific
+// manner and may not be meaningful externally. Specifically, they need to be
+// translated to Linux's representation for these states if presented to
+// userspace.
+const (
+ // Endpoint states internal to netstack. These map to the TCP state CLOSED.
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+ StateError
+
+ // TCP protocol states.
+ StateEstablished
+ StateSynSent
+ StateSynRecv
+ StateFinWait1
+ StateFinWait2
+ StateTimeWait
+ StateClose
+ StateCloseWait
+ StateLastAck
+ StateListen
+ StateClosing
+)
+
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
+func (s EndpointState) connected() bool {
+ switch s {
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ return true
+ default:
+ return false
+ }
+}
+
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnecting:
+ return "CONNECTING"
+ case StateError:
+ return "ERROR"
+ case StateEstablished:
+ return "ESTABLISHED"
+ case StateSynSent:
+ return "SYN-SENT"
+ case StateSynRecv:
+ return "SYN-RCVD"
+ case StateFinWait1:
+ return "FIN-WAIT1"
+ case StateFinWait2:
+ return "FIN-WAIT2"
+ case StateTimeWait:
+ return "TIME-WAIT"
+ case StateClose:
+ return "CLOSED"
+ case StateCloseWait:
+ return "CLOSE-WAIT"
+ case StateLastAck:
+ return "LAST-ACK"
+ case StateListen:
+ return "LISTEN"
+ case StateClosing:
+ return "CLOSING"
+ default:
+ panic("unreachable")
+ }
+}
+
+// Reasons for notifying the protocol goroutine.
+const (
+ notifyNonZeroReceiveWindow = 1 << iota
+ notifyReceiveWindowChanged
+ notifyClose
+ notifyMTUChanged
+ notifyDrain
+ notifyReset
+ notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
+ notifyKeepaliveChanged
+ notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is,
+ // say, TIME_WAIT.
+ notifyTickleWorker
+ notifyError
+)
+
+// SACKInfo holds TCP SACK related information for a given endpoint.
+//
+// +stateify savable
+type SACKInfo struct {
+ // Blocks is the maximum number of SACK blocks we track
+ // per endpoint.
+ Blocks [MaxSACKBlocks]header.SACKBlock
+
+ // NumBlocks is the number of valid SACK blocks stored in the
+ // blocks array above.
+ NumBlocks int
+}
+
+// rcvBufAutoTuneParams holds the state variables used to compute the
+// auto-tuned receive buffer size.
+//
+// +stateify savable
+type rcvBufAutoTuneParams struct {
+ // measureTime is the time at which the current measurement
+ // was started.
+ measureTime time.Time `state:".(unixTime)"`
+
+ // copied is the number of bytes copied out of the receive
+ // buffers since this measure began.
+ copied int
+
+ // prevCopied is the number of bytes copied out of the receive
+ // buffers in the previous RTT period.
+ prevCopied int
+
+ // rtt is the non-smoothed minimum RTT as measured by observing the time
+ // between when a byte is first acknowledged and the receipt of data
+ // that is at least one window beyond the sequence number that was
+ // acknowledged.
+ rtt time.Duration
+
+ // rttMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ rttMeasureSeqNumber seqnum.Value
+
+ // rttMeasureTime is the absolute time at which the current rtt
+ // measurement period began.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ // disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ disabled bool
+}
+
+// ReceiveErrors collects segment receive errors within the transport layer.
+type ReceiveErrors struct {
+ tcpip.ReceiveErrors
+
+ // SegmentQueueDropped is the number of segments dropped due to
+ // a full segment queue.
+ SegmentQueueDropped tcpip.StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors tcpip.StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop tcpip.StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop tcpip.StatCounter
+
+ // ZeroRcvWindowState is the number of times we advertised
+ // a zero receive window when rcvList is full.
+ ZeroRcvWindowState tcpip.StatCounter
+}
+
+// SendErrors collects segment send errors within the transport layer.
+type SendErrors struct {
+ tcpip.SendErrors
+
+ // SegmentSendToNetworkFailed is the number of TCP segments that failed
+ // to be sent to the network endpoint.
+ SegmentSendToNetworkFailed tcpip.StatCounter
+
+ // SynSendToNetworkFailed is the number of TCP SYNs that failed to be
+ // sent to the network endpoint.
+ SynSendToNetworkFailed tcpip.StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits tcpip.StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit tcpip.StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts tcpip.StatCounter
+}
+
+// Stats holds statistics about the endpoint.
+type Stats struct {
+ // SegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ SegmentsReceived tcpip.StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent tcpip.StatCounter
+
+ // FailedConnectionAttempts is the number of times we saw Connect and
+ // Accept errors.
+ FailedConnectionAttempts tcpip.StatCounter
+
+ // ReceiveErrors collects segment receive errors within the
+ // transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects segment read errors from an endpoint read call.
+ ReadErrors tcpip.ReadErrors
+
+ // SendErrors collects segment send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects segment write errors from an endpoint write call.
+ WriteErrors tcpip.WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*Stats) IsEndpointStats() {}
+
+// EndpointInfo holds useful information about a transport endpoint which
+// can be queried by monitoring tools.
+//
+// +stateify savable
+type EndpointInfo struct {
+ stack.TransportEndpointInfo
+
+ // HardError is meaningful only when state is StateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. HardError is protected by endpoint mu.
+ HardError *tcpip.Error `state:".(string)"`
+}
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*EndpointInfo) IsEndpointInfo() {}
+
+// endpoint represents a TCP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal
+// to have concurrent goroutines make calls into the endpoint, as they are
+// properly synchronized. The protocol implementation, however, runs in a
+// single goroutine.
+//
+// Each endpoint has a few mutexes:
+//
+// e.mu -> Primary mutex for an endpoint; it must be held for all operations
+// except in e.Readiness, where acquiring it would result in a deadlock in
+// the epoll implementation.
+//
+// The following mutexes may be acquired independently of e.mu, but if
+// acquired together with e.mu then e.mu must be acquired first.
+//
+// e.acceptMu -> Protects acceptedChan.
+// e.rcvListMu -> Protects the rcvList and associated fields.
+// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.lastErrorMu -> Protects the lastError field.
+//
+// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
+// based on the context in which the lock is acquired. In the syscall context
+// e.LockUser/e.UnlockUser should be used and when doing background processing
+// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below
+// in brief.
+//
+// The reason for this locking behaviour is to avoid wakeups to handle packets.
+// In cases where the endpoint is already locked, the background processor can
+// queue the packet up and go its merry way, and the lock owner will eventually
+// process the backlog when releasing the lock. Similarly, when acquiring the
+// lock from say a syscall goroutine, we can implement a bit of spinning if we
+// know that the lock is not held by another syscall goroutine. Background
+// processors should never hold the lock for long and we can avoid an expensive
+// sleep/wakeup by spinning for a short while.
+//
+// For more details please see the detailed documentation on
+// e.LockUser/e.UnlockUser methods.
+//
+// +stateify savable
+type endpoint struct {
+ EndpointInfo
+
+ // endpointEntry is used to queue endpoints for processing to a given
+ // tcp processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field.
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field.
+ pendingProcessing bool `state:"nosave"`
+
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
+
+ // lastError represents the last error that the endpoint reported;
+ // access to it is protected by the following mutex.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // The following fields are used to manage the receive queue. The
+ // protocol goroutine adds ready-for-delivery segments to rcvList,
+ // which are returned by Read() calls to users.
+ //
+ // Once the peer has closed its send side, rcvClosed is set to true
+ // to indicate to users that no more data is coming.
+ //
+ // rcvListMu can be taken after the endpoint mu below.
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+ rcvAutoParams rcvBufAutoTuneParams
+
+ // mu protects all endpoint fields unless documented otherwise. mu must
+ // be acquired before interacting with the endpoint fields.
+ mu sync.Mutex `state:"nosave"`
+ ownedByUser uint32
+
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
+ state EndpointState `state:".(EndpointState)"`
+
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to its correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
+ isPortReserved bool `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ ttl uint8
+ v6only bool
+ isConnectNotified bool
+ // TCP should never broadcast but Linux nevertheless supports enabling/
+ // disabling SO_BROADCAST, albeit as a NOOP.
+ broadcast bool
+
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
+ // Values used to reserve a port or register a transport endpoint
+ // (which ever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+ boundDest tcpip.FullAddress
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+ // workerRunning specifies if a worker goroutine is running.
+ workerRunning bool
+
+ // workerCleanup specifies if the worker goroutine must perform cleanup
+ // before exiting. This can only be set to true when workerRunning is
+ // also true, and they're both protected by the mutex.
+ workerCleanup bool
+
+ // sendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1
+ sendTSOk bool
+
+ // recentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ //
+ // recentTS must be read/written atomically.
+ recentTS uint32
+
+ // tsOffset is a randomized offset added to the value of the
+ // TSVal field in the timestamp option.
+ tsOffset uint32
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // sackPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ sackPermitted bool
+
+ // sack holds TCP SACK related information for this endpoint.
+ sack SACKInfo
+
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
+ // delay enables Nagle's algorithm.
+ //
+ // delay is a boolean (0 is false) and must be accessed atomically.
+ delay uint32
+
+ // cork holds back segments until full.
+ //
+ // cork is a boolean (0 is false) and must be accessed atomically.
+ cork uint32
+
+ // scoreboard holds TCP SACK Scoreboard information for this endpoint.
+ scoreboard *SACKScoreboard
+
+ // The options below aren't implemented, but we remember the user
+ // settings because applications expect to be able to set/query these
+ // options.
+
+ // slowAck holds the negated state of quick ack. It is stubbed out and
+ // does nothing.
+ //
+ // slowAck is a boolean (0 is false) and must be accessed atomically.
+ slowAck uint32
+
+ // segmentQueue is used to hand received segments to the protocol
+ // goroutine. Segments are queued as long as the queue is not full,
+ // and dropped when it is.
+ segmentQueue segmentQueue `state:"wait"`
+
+ // synRcvdCount is the number of connections for this endpoint that are
+ // in SYN-RCVD state.
+ synRcvdCount int
+
+ // userMSS if non-zero is the MSS value explicitly set by the user
+ // for this endpoint using the TCP_MAXSEG setsockopt.
+ userMSS uint16
+
+ // maxSynRetries is the maximum number of SYN retransmits that TCP should
+ // send before aborting the attempt to connect. It cannot exceed 255.
+ //
+ // NOTE: This is currently a no-op and does not change the SYN
+ // retransmissions.
+ maxSynRetries uint8
+
+ // windowClamp is used to bound the size of the advertised window to
+ // this value.
+ windowClamp uint32
+
+ // The following fields are used to manage the send buffer. When
+ // segments are ready to be sent, they are added to sndQueue and the
+ // protocol goroutine is signaled via sndWaker.
+ //
+ // When the send side is closed, the protocol goroutine is notified via
+ // sndCloseWaker, and sndClosed is set to true.
+ sndBufMu sync.Mutex `state:"nosave"`
+ sndBufSize int
+ sndBufUsed int
+ sndClosed bool
+ sndBufInQueue seqnum.Size
+ sndQueue segmentList `state:"wait"`
+ sndWaker sleep.Waker `state:"manual"`
+ sndCloseWaker sleep.Waker `state:"manual"`
+
+ // cc stores the name of the Congestion Control algorithm to use for
+ // this endpoint.
+ cc tcpip.CongestionControlOption
+
+ // The following are used when a "packet too big" control packet is
+ // received. They are protected by sndBufMu. They are used to
+ // communicate to the main protocol goroutine how many such control
+ // messages have been received since the last notification was processed
+ // and what was the smallest MTU seen.
+ packetTooBigCount int
+ sndMTU int
+
+ // newSegmentWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and handle new segments queued to it.
+ newSegmentWaker sleep.Waker `state:"manual"`
+
+ // notificationWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and check for notifications.
+ notificationWaker sleep.Waker `state:"manual"`
+
+ // notifyFlags is a bitmask of flags used to indicate to the protocol
+ // goroutine what it has been notified of; it is only accessed atomically.
+ notifyFlags uint32 `state:"nosave"`
+
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepaliveIdle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // without hearing a response, the connection is closed.
+ keepalive keepalive
+
+ // userTimeout, if non-zero, specifies a user-specified timeout for
+ // a connection with pending data to send. A connection that has pending
+ // unacked data will be forcibly aborted if the timeout is reached
+ // without any data being acked.
+ userTimeout time.Duration
+
+ // deferAccept if non-zero specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
+ // pendingAccepted is a synchronization primitive used to track the
+ // number of connections that are queued up to be delivered to the accepted
+ // channel. We use this to ensure that all goroutines blocked on writing
+ // to the acceptedChan below terminate before we close acceptedChan.
+ pendingAccepted sync.WaitGroup `state:"nosave"`
+
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because accept backlog was
+ // full (see endpoint.deliverAccepted).
+ acceptCond *sync.Cond `state:"nosave"`
+
+ // acceptedChan is used by a listening endpoint protocol goroutine to
+ // send newly accepted connections to the endpoint so that they can be
+ // read by Accept() calls.
+ acceptedChan chan *endpoint `state:".([]*endpoint)"`
+
+ // The following are only used from the protocol goroutine, and
+ // therefore don't need locks to protect them.
+ rcv *receiver `state:"wait"`
+ snd *sender `state:"wait"`
+
+ // The goroutine drain completion notification channel.
+ drainDone chan struct{} `state:"nosave"`
+
+ // The goroutine undrain notification channel. This is currently used as
+ // a way to block the worker goroutines. Today nothing closes/writes
+ // this channel and this causes any goroutines waiting on this to just
+ // block. This is used during save/restore to prevent worker goroutines
+ // from mutating state as it's being saved.
+ undrain chan struct{} `state:"nosave"`
+
+ // probe if not nil is invoked on every received segment. It is passed
+ // a copy of the current state of the endpoint.
+ probe stack.TCPProbeFunc `state:"nosave"`
+
+ // The following are only used to assist the restore run to re-connect.
+ connectingAddress tcpip.Address
+
+ // amss is the advertised MSS to the peer by this endpoint.
+ amss uint16
+
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
+ gso *stack.GSO
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of time a socket stays in
+ // TIME_WAIT state before being marked closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called Close on the endpoint
+ // and at this point the endpoint is only around to complete the TCP
+ // shutdown.
+ closed bool
+
+ // txHash is the transport layer hash to be set on outbound packets
+ // emitted by this endpoint.
+ txHash uint32
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is less than the maximum possible MSS for r,
+// it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ maxMSS := mssForRoute(&r)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
+}
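+
+// For example, assuming a plain IPv4 route with a 1500 byte MTU (so a maximum
+// possible MSS of 1500 - 20 - 20 = 1460 bytes): a userMSS of 1400 is honored,
+// while a userMSS of 9000 falls back to 1460.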
+
+// LockUser tries to lock e.mu; if that fails, it checks whether the lock is
+// held by another syscall goroutine. If so, it goes to sleep waiting for the
+// lock to be released; if not, it spins until it either acquires the lock or
+// another syscall goroutine acquires it, in which case it goes to sleep as
+// described above.
+//
+// The assumption behind spinning here is that background packet processing
+// should not hold the lock for long, so spinning reduces latency by avoiding
+// an expensive sleep/wakeup of the syscall goroutine.
+func (e *endpoint) LockUser() {
+ for {
+ // Try the lock first; if the socket is locked, check whether it is
+ // owned by another user goroutine. If it is, go to sleep on Lock()
+ // and wait; otherwise, spin.
+ if !e.mu.TryLock() {
+ // If the socket is owned by the user then just go to sleep
+ // as the lock could be held for a reasonably long time.
+ if atomic.LoadUint32(&e.ownedByUser) == 1 {
+ e.mu.Lock()
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+ // Spin but yield the processor since the lower half
+ // should yield the lock soon.
+ runtime.Gosched()
+ continue
+ }
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+}
+
+// UnlockUser will check if there are any segments already queued for processing
+// and process any such segments before unlocking e.mu. This is required because
+// when packets arrive while the endpoint lock is already held, they are queued
+// up to be processed. If the lock is held by the endpoint goroutine then it
+// will process these packets, but if the lock is instead held by the syscall
+// goroutine then we can have the syscall goroutine process the backlog before
+// unlocking.
+//
+// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the
+// endpoint. It's also required eventually when we get rid of the endpoint
+// protocol goroutine altogether.
+//
+// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
+func (e *endpoint) UnlockUser() {
+ // Lock segment queue before checking so that we avoid a race where
+ // segments can be queued between the time we check if queue is empty
+ // and actually unlock the endpoint mutex.
+ for {
+ e.segmentQueue.mu.Lock()
+ if e.segmentQueue.emptyLocked() {
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ e.segmentQueue.mu.Unlock()
+ return
+ }
+ e.segmentQueue.mu.Unlock()
+
+ switch e.EndpointState() {
+ case StateEstablished:
+ if err := e.handleSegments(true /* fastPath */); err != nil {
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ default:
+ // Since we are waking the endpoint goroutine here just unlock
+ // and let it process the queued segments.
+ e.newSegmentWaker.Assert()
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ return
+ }
+ }
+}
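+
+// A typical syscall-context usage sketch of the pair above (hypothetical, not
+// part of this change):
+//
+//   e.LockUser()
+//   defer e.UnlockUser() // Drains any segments queued while e.mu was held.
+//   // ... read or mutate endpoint state ...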
+
+// StopWork halts packet processing. Only to be used in tests.
+func (e *endpoint) StopWork() {
+ e.mu.Lock()
+}
+
+// ResumeWork resumes packet processing. Only to be used in tests.
+func (e *endpoint) ResumeWork() {
+ e.mu.Unlock()
+}
+
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.stack.Stats().TCP.CurrentConnected.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
+// setRecentTimestamp atomically sets the recentTS field to the
+// provided value.
+func (e *endpoint) setRecentTimestamp(recentTS uint32) {
+ atomic.StoreUint32(&e.recentTS, recentTS)
+}
+
+// recentTimestamp atomically reads and returns the value of the recentTS field.
+func (e *endpoint) recentTimestamp() uint32 {
+ return atomic.LoadUint32(&e.recentTS)
+}
+
+// keepalive is a synchronization wrapper used to appease stateify. See the
+// comment in endpoint, where it is used.
+//
+// +stateify savable
+type keepalive struct {
+ sync.Mutex `state:"nosave"`
+ enabled bool
+ idle time.Duration
+ interval time.Duration
+ count int
+ unacked int
+ timer timer `state:"nosave"`
+ waker sleep.Waker `state:"nosave"`
+}
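+
+// With the Linux defaults set in newEndpoint below (idle = 2h, interval = 75s,
+// count = 9), an unresponsive peer is declared dead roughly
+// 2h + 9*75s = 2h11m15s after the connection last carried data.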
+
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ e := &endpoint{
+ stack: s,
+ EndpointInfo: EndpointInfo{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
+ },
+ waiterQueue: waiterQueue,
+ state: StateInitial,
+ rcvBufSize: DefaultReceiveBufferSize,
+ sndBufSize: DefaultSendBufferSize,
+ sndMTU: int(math.MaxInt32),
+ keepalive: keepalive{
+ // Linux defaults.
+ idle: 2 * time.Hour,
+ interval: 75 * time.Second,
+ count: 9,
+ },
+ uniqueID: s.UniqueID(),
+ txHash: s.Rand().Uint32(),
+ windowClamp: DefaultReceiveBufferSize,
+ maxSynRetries: DefaultSynRetries,
+ }
+
+ var ss SendBufferSizeOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ e.sndBufSize = ss.Default
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ e.rcvBufSize = rs.Default
+ }
+
+ var cs tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ e.cc = cs
+ }
+
+ var mrb tcpip.ModerateReceiveBufferOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
+ e.rcvAutoParams.disabled = !bool(mrb)
+ }
+
+ var de DelayEnabled
+ if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
+ e.SetSockOptBool(tcpip.DelayOption, true)
+ }
+
+ var tcpLT tcpip.TCPLingerTimeoutOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil {
+ e.tcpLingerTimeout = time.Duration(tcpLT)
+ }
+
+ var synRetries tcpip.TCPSynRetriesOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil {
+ e.maxSynRetries = uint8(synRetries)
+ }
+
+ if p := s.GetTCPProbe(); p != nil {
+ e.probe = p
+ }
+
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.tsOffset = timeStampOffset()
+ e.acceptCond = sync.NewCond(&e.acceptMu)
+
+ return e
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ result := waiter.EventMask(0)
+
+ switch e.EndpointState() {
+ case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
+ // Ready for nothing.
+
+ case StateClose, StateError:
+ // Ready for anything.
+ result = mask
+
+ case StateListen:
+ // Check if there's anything in the accepted channel.
+ if (mask & waiter.EventIn) != 0 {
+ e.acceptMu.Lock()
+ if len(e.acceptedChan) > 0 {
+ result |= waiter.EventIn
+ }
+ e.acceptMu.Unlock()
+ }
+ }
+ if e.EndpointState().connected() {
+ // Determine if the endpoint is writable if requested.
+ if (mask & waiter.EventOut) != 0 {
+ e.sndBufMu.Lock()
+ if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+ result |= waiter.EventOut
+ }
+ e.sndBufMu.Unlock()
+ }
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvListMu.Lock()
+ if e.rcvBufUsed > 0 || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvListMu.Unlock()
+ }
+ }
+
+ return result
+}
+
+func (e *endpoint) fetchNotifications() uint32 {
+ return atomic.SwapUint32(&e.notifyFlags, 0)
+}
+
+func (e *endpoint) notifyProtocolGoroutine(n uint32) {
+ for {
+ v := atomic.LoadUint32(&e.notifyFlags)
+ if v&n == n {
+ // The flags are already set.
+ return
+ }
+
+ if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) {
+ if v == 0 {
+ // We are causing a transition from no flags to
+ // at least one flag set, so we must cause the
+ // protocol goroutine to wake up.
+ e.notificationWaker.Assert()
+ }
+ return
+ }
+ }
+}
+
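+// An illustrative sketch of the flag-passing protocol between user-facing
+// calls and the protocol goroutine, using the identifiers defined above:
+//
+//	// Caller side: accumulate a flag; the worker is woken only on the
+//	// first zero to non-zero transition.
+//	e.notifyProtocolGoroutine(notifyClose)
+//
+//	// Protocol goroutine side: drain all pending flags atomically.
+//	flags := e.fetchNotifications()
+//	if flags&notifyClose != 0 {
+//		// Handle the close request.
+//	}
+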
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ //
+ // If the endpoint connecting finishes after the check, the endpoint
+ // is either in a connected state (where we would notifyAbort anyway),
+ // SYN-RECV (where we would also notifyAbort anyway), or in an error
+ // state where nothing is required and the notification can be safely
+ // ignored.
+ //
+ // Endpoints where a Close during connecting or SYN-RECV state would be
+ // problematic are set to state connecting before being registered (and
+ // thus possible to be Aborted). They are never available in initial
+ // state.
+ //
+ // Endpoints transitioning from initial to connecting state may be
+ // safely either closed or sent notifyAbort.
+ if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it. It must be called only once and with no other concurrent calls to
+// the endpoint.
+func (e *endpoint) Close() {
+ e.LockUser()
+ defer e.UnlockUser()
+ if e.closed {
+ return
+ }
+
+ // Issue a shutdown so that the peer knows we won't send any more data
+ // if we're connected, or stop accepting if we're listening.
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdownLocked()
+}
+
+// closeNoShutdownLocked closes the endpoint without doing a full shutdown.
+// This is used when a connection needs to be aborted with a RST and we want
+// to skip a full 4-way TCP shutdown.
+func (e *endpoint) closeNoShutdownLocked() {
+ // For listening sockets, we always release ports inline so that they
+ // are immediately available for reuse after Close() is called. If the
+ // endpoint is also registered, we unregister it as well; otherwise the
+ // next user would fail in Listen() when trying to register.
+ if e.EndpointState() == StateListen && e.isPortReserved {
+ if e.isRegistered {
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.isRegistered = false
+ }
+
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
+ e.isPortReserved = false
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
+ }
+
+ // Mark endpoint as closed.
+ e.closed = true
+
+ switch e.EndpointState() {
+ case StateClose, StateError:
+ return
+ }
+
+ // Either perform the local cleanup or kick the worker to make sure it
+ // knows it needs to cleanup.
+ if e.workerRunning {
+ e.workerCleanup = true
+ tcpip.AddDanglingEndpoint(e)
+ // Worker will remove the dangling endpoint when the endpoint
+ // goroutine terminates.
+ e.notifyProtocolGoroutine(notifyClose)
+ } else {
+ e.transitionToStateCloseLocked()
+ }
+}
+
+// closePendingAcceptableConnectionsLocked closes all connections that have
+// completed the handshake but have not yet been delivered to the application.
+func (e *endpoint) closePendingAcceptableConnectionsLocked() {
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ return
+ }
+ close(e.acceptedChan)
+ ch := e.acceptedChan
+ e.acceptedChan = nil
+ e.acceptCond.Broadcast()
+ e.acceptMu.Unlock()
+
+ // Reset all connections that are waiting to be accepted.
+ for n := range ch {
+ n.notifyProtocolGoroutine(notifyReset)
+ }
+ // Wait for reset of all endpoints that are still waiting to be delivered to
+ // the now closed acceptedChan.
+ e.pendingAccepted.Wait()
+}
+
+// cleanupLocked frees all resources associated with the endpoint. It is called
+// after Close() is called and the worker goroutine (if any) is done with its
+// work.
+func (e *endpoint) cleanupLocked() {
+ // Close all endpoints that might have been accepted by TCP but not by
+ // the client.
+ e.closePendingAcceptableConnectionsLocked()
+
+ e.workerCleanup = false
+
+ if e.isRegistered {
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.isRegistered = false
+ }
+
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
+ e.isPortReserved = false
+ }
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
+
+ e.route.Release()
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
+}
+
+// initialReceiveWindow returns the initial receive window to advertise in the
+// SYN/SYN-ACK.
+func (e *endpoint) initialReceiveWindow() int {
+ rcvWnd := e.receiveBufferAvailable()
+ if rcvWnd > math.MaxUint16 {
+ rcvWnd = math.MaxUint16
+ }
+
+ // Use the user-supplied MSS, if available.
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2
+ if rcvWnd > routeWnd {
+ rcvWnd = routeWnd
+ }
+ rcvWndScale := e.rcvWndScaleForHandshake()
+
+ // Round down the rcvWnd to a multiple of wndScale. This ensures that
+ // the window offered in the SYN won't be reduced due to the loss of
+ // precision if window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
+ return rcvWnd
+}
+
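+// A worked example of the rounding above: with rcvWnd = 65535 and a window
+// scale of 3, the SYN advertises (65535 >> 3) << 3 = 65528, which remains
+// exactly representable as 8191 << 3 once window scaling takes effect after
+// the handshake.
+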
+// ModerateRecvBuf adjusts the receive buffer and the advertised window
+// based on the number of bytes copied to userspace.
+func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ e.rcvListMu.Lock()
+ if e.rcvAutoParams.disabled {
+ e.rcvListMu.Unlock()
+ return
+ }
+ now := time.Now()
+ if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
+ e.rcvAutoParams.copied += copied
+ e.rcvListMu.Unlock()
+ return
+ }
+ prevRTTCopied := e.rcvAutoParams.copied + copied
+ prevCopied := e.rcvAutoParams.prevCopied
+ rcvWnd := 0
+ if prevRTTCopied > prevCopied {
+ // The minimal receive window is based on what was copied by
+ // the app in the immediately preceding RTT, plus some extra
+ // buffer for 16 segments to account for variations. We
+ // multiply by 2 to account for packet losses.
+ rcvWnd = prevRTTCopied*2 + 16*int(e.amss)
+
+ // Scale for slow start based on bytes copied in this RTT vs previous.
+ grow := (rcvWnd * (prevRTTCopied - prevCopied)) / prevCopied
+
+ // Multiply the growth factor by 2 again to account for the
+ // sender being in slow-start, where the sender grows its
+ // congestion window by 100% per RTT.
+ rcvWnd += grow * 2
+
+ // Make sure the auto-tuned buffer size can always receive up
+ // to 2x the initial window of 10 segments.
+ if minRcvWnd := int(e.amss) * InitialCwnd * 2; rcvWnd < minRcvWnd {
+ rcvWnd = minRcvWnd
+ }
+
+ // Cap the auto-tuned buffer size at the maximum permissible
+ // receive buffer size.
+ if max := e.maxReceiveBufferSize(); rcvWnd > max {
+ rcvWnd = max
+ }
+
+ // We do not adjust downwards as that can cause the receiver to
+ // reject valid data that might already be in flight as the
+ // acceptable window will shrink.
+ if rcvWnd > e.rcvBufSize {
+ availBefore := e.receiveBufferAvailableLocked()
+ e.rcvBufSize = rcvWnd
+ availAfter := e.receiveBufferAvailableLocked()
+ mask := uint32(notifyReceiveWindowChanged)
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.notifyProtocolGoroutine(mask)
+ }
+
+ // We only update prevCopied when we grow the buffer because in cases
+ // where prevCopied > prevRTTCopied the existing buffer is already big
+ // enough to handle the current rate and we don't need to do any
+ // adjustments.
+ e.rcvAutoParams.prevCopied = prevRTTCopied
+ }
+ e.rcvAutoParams.measureTime = now
+ e.rcvAutoParams.copied = 0
+ e.rcvListMu.Unlock()
+}
+
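+// A worked example of the growth computation above (the values are
+// assumptions for illustration): with amss = 1460, prevCopied = 100000 and
+// prevRTTCopied = 150000 bytes copied in the last RTT:
+//
+//	rcvWnd = 150000*2 + 16*1460        // = 323360 (base + 16-segment slack)
+//	grow   = (323360 * 50000) / 100000 // = 161680 (proportional growth)
+//	rcvWnd += 161680 * 2               // = 646720
+//
+// The result is then raised to at least 2*InitialCwnd*amss and capped at the
+// maximum permissible receive buffer size before being applied.
+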
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
+// Read reads data from the endpoint.
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data. Also note that an RST being
+ // received would cause the state to become StateError, so we should
+ // allow reads to proceed before returning an ECONNRESET.
+ e.rcvListMu.Lock()
+ bufUsed := e.rcvBufUsed
+ if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
+ e.rcvListMu.Unlock()
+ he := e.HardError
+ if s == StateError {
+ return buffer.View{}, tcpip.ControlMessages{}, he
+ }
+ e.stats.ReadErrors.NotConnected.Increment()
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
+ }
+
+ v, err := e.readLocked()
+ e.rcvListMu.Unlock()
+
+ if err == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return v, tcpip.ControlMessages{}, err
+}
+
+func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || !e.EndpointState().connected() {
+ return buffer.View{}, tcpip.ErrClosedForReceive
+ }
+ return buffer.View{}, tcpip.ErrWouldBlock
+ }
+
+ s := e.rcvList.Front()
+ views := s.data.Views()
+ v := views[s.viewToDeliver]
+ s.viewToDeliver++
+
+ if s.viewToDeliver >= len(views) {
+ e.rcvList.Remove(s)
+ s.decRef()
+ }
+
+ e.rcvBufUsed -= len(v)
+
+ // If the window was small before this read and the read freed up
+ // enough buffer space to fit either an aMSS or half a receive buffer
+ // (whichever is smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ return v, nil
+}
+
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+ // The endpoint cannot be written to if it's not connected.
+ if !e.EndpointState().connected() {
+ switch e.EndpointState() {
+ case StateError:
+ return 0, e.HardError
+ default:
+ return 0, tcpip.ErrClosedForSend
+ }
+ }
+
+ // Check if the connection has already been closed for sends.
+ if e.sndClosed {
+ return 0, tcpip.ErrClosedForSend
+ }
+
+ avail := e.sndBufSize - e.sndBufUsed
+ if avail <= 0 {
+ return 0, tcpip.ErrWouldBlock
+ }
+ return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.LockUser()
+ e.sndBufMu.Lock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ }
+
+ // Fetch data.
+ v, perr := p.Payload(avail)
+ if perr != nil || len(v) == 0 {
+ // Note that perr may be nil if len(v) == 0.
+ if opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ }
+ return 0, nil, perr
+ }
+
+ queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+
+ // Do the work inline.
+ e.handleWrite()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
+ }
+
+ if opts.Atomic {
+ // Locks released in queueAndSend()
+ return queueAndSend()
+ }
+
+ // Since we released the locks in between, it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR state, so make sure the
+ // endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in if avail was reduced by a
+ // simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+
+ // Locks released in queueAndSend()
+ return queueAndSend()
+}
+
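+// The non-atomic path above follows a drop-locks/copy/re-validate pattern. An
+// illustrative sketch using the identifiers from Write, eliding error
+// handling:
+//
+//	e.LockUser(); e.sndBufMu.Lock()
+//	avail, _ := e.isEndpointWritableLocked()
+//	e.sndBufMu.Unlock(); e.UnlockUser()
+//	v, _ := p.Payload(avail) // slow copy, done without locks held
+//	e.LockUser(); e.sndBufMu.Lock()
+//	avail, _ = e.isEndpointWritableLocked() // state may have changed
+//	if avail < len(v) {
+//		v = v[:avail] // a concurrent writer consumed buffer space
+//	}
+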
+// Peek reads data without consuming it from the endpoint.
+//
+// This method does not block if there is no data pending.
+func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
+ if s == StateError {
+ return 0, tcpip.ControlMessages{}, e.HardError
+ }
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
+ return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || !e.EndpointState().connected() {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // Make a copy of vec so we can modify the slice headers.
+ vec = append([][]byte(nil), vec...)
+
+ var num int64
+ for s := e.rcvList.Front(); s != nil; s = s.Next() {
+ views := s.data.Views()
+
+ for i := s.viewToDeliver; i < len(views); i++ {
+ v := views[i]
+
+ for len(v) > 0 {
+ if len(vec) == 0 {
+ return num, tcpip.ControlMessages{}, nil
+ }
+ if len(vec[0]) == 0 {
+ vec = vec[1:]
+ continue
+ }
+
+ n := copy(vec[0], v)
+ v = v[n:]
+ vec[0] = vec[0][n:]
+ num += int64(n)
+ }
+ }
+ }
+
+ return num, tcpip.ControlMessages{}, nil
+}
+
+// windowCrossedACKThresholdLocked checks if the receive window to be
+// announced now would be under aMSS or under half the receive buffer,
+// whichever is smaller. This is useful as a receive-side silly window
+// syndrome prevention mechanism. If the window grows to a reasonable value,
+// we should send an ACK to inform the sender that the rx space is now large.
+// We also want to ensure that a series of small read()'s won't trigger a
+// flood of spurious tiny ACK's.
+//
+// For large receive buffers, the threshold is aMSS - once the reader reads
+// more than aMSS we'll send an ACK. For tiny receive buffers, the threshold
+// is half of the receive buffer size, which is chosen somewhat arbitrarily.
+//
+// crossed will be true if the window size crossed the ACK threshold.
+// above will be true if the new window is >= ACK threshold and false
+// otherwise.
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
+ newAvail := e.receiveBufferAvailableLocked()
+ oldAvail := newAvail - deltaBefore
+ if oldAvail < 0 {
+ oldAvail = 0
+ }
+
+ threshold := int(e.amss)
+ if threshold > e.rcvBufSize/2 {
+ threshold = e.rcvBufSize / 2
+ }
+
+ switch {
+ case oldAvail < threshold && newAvail >= threshold:
+ return true, true
+ case oldAvail >= threshold && newAvail < threshold:
+ return true, false
+ }
+ return false, false
+}
+
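+// A worked example (the values are assumptions for illustration): with
+// amss = 1460 and rcvBufSize = 4096, the threshold is min(1460, 4096/2) =
+// 1460. A read that moves available space from 1000 to 2000 bytes returns
+// (crossed=true, above=true) and prompts a window-update ACK; shrinking from
+// 2000 to 1000 bytes returns (crossed=true, above=false).
+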
+// SetSockOptBool sets a socket option.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ e.broadcast = v
+ e.UnlockUser()
+
+ case tcpip.CorkOption:
+ e.LockUser()
+ if !v {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ e.UnlockUser()
+
+ case tcpip.DelayOption:
+ if v {
+ atomic.StoreUint32(&e.delay, 1)
+ } else {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ }
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.QuickAckOption:
+ o := uint32(1)
+ if v {
+ o = 0
+ }
+ atomic.StoreUint32(&e.slowAck, o)
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We only allow this to be set when we're in the initial state.
+ if e.EndpointState() != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.LockUser()
+ e.v6only = v
+ e.UnlockUser()
+ }
+
+ return nil
+}
+
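+// Note that tcpip.QuickAckOption is stored inverted in e.slowAck, so the
+// mapping is:
+//
+//	e.SetSockOptBool(tcpip.QuickAckOption, true)  // slowAck = 0
+//	e.SetSockOptBool(tcpip.QuickAckOption, false) // slowAck = 1
+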
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ // The lower 2 bits represent the ECN bits. RFC 3168, section 23.1.
+ const inetECNMask = 3
+
+ switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.userMSS = uint16(userMSS)
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the caller attempts to set this
+ // option to anything other than disabling path MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ }
+
+ mask := uint32(notifyReceiveWindowChanged)
+
+ e.LockUser()
+ e.rcvListMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.rcvWndScale
+ }
+ if v>>scale == 0 {
+ v = 1 << scale
+ }
+
+ // Make sure 2*size doesn't overflow.
+ if v > math.MaxInt32/2 {
+ v = math.MaxInt32 / 2
+ }
+
+ availBefore := e.receiveBufferAvailableLocked()
+ e.rcvBufSize = v
+ availAfter := e.receiveBufferAvailableLocked()
+
+ e.rcvAutoParams.disabled = true
+
+ // Immediately send an ACK to uncork the sender's silly window
+ // syndrome prevention when our available space grows above
+ // aMSS or half the receive buffer, whichever is smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+
+ e.rcvListMu.Unlock()
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(mask)
+
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ }
+
+ e.sndBufMu.Lock()
+ e.sndBufSize = v
+ e.sndBufMu.Unlock()
+
+ case tcpip.TTLOption:
+ e.LockUser()
+ e.ttl = uint8(v)
+ e.UnlockUser()
+
+ case tcpip.TCPSynCountOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.maxSynRetries = uint8(v)
+ e.UnlockUser()
+
+ case tcpip.TCPWindowClampOption:
+ if v == 0 {
+ e.LockUser()
+ switch e.EndpointState() {
+ case StateClose, StateInitial:
+ e.windowClamp = 0
+ e.UnlockUser()
+ return nil
+ default:
+ e.UnlockUser()
+ return tcpip.ErrInvalidOptionValue
+ }
+ }
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if v < rs.Min/2 {
+ v = rs.Min / 2
+ }
+ }
+ e.LockUser()
+ e.windowClamp = uint32(v)
+ e.UnlockUser()
+ }
+ return nil
+}
+
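+// A worked example of the tcpip.ReceiveBufferSizeOption handling above (the
+// scale is an assumption for illustration): with an established receive
+// window scale of 14 and a requested size v = 10000, 10000 >> 14 == 0, so v
+// is raised to 1 << 14 = 16384 to guarantee that a non-zero window can still
+// be advertised.
+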
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.BindToDeviceOption:
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
+ }
+ e.LockUser()
+ e.bindToDevice = id
+ e.UnlockUser()
+
+ case tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
+
+ case tcpip.TCPUserTimeoutOption:
+ e.LockUser()
+ e.userTimeout = time.Duration(v)
+ e.UnlockUser()
+
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
+ var avail tcpip.AvailableCongestionControlOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+ return err
+ }
+ availCC := strings.Split(string(avail), " ")
+ for _, cc := range availCC {
+ if v == tcpip.CongestionControlOption(cc) {
+ e.LockUser()
+ state := e.EndpointState()
+ e.cc = v
+ switch state {
+ case StateEstablished:
+ if e.EndpointState() == state {
+ e.snd.cc = e.snd.initCongestionControl(e.cc)
+ }
+ }
+ e.UnlockUser()
+ return nil
+ }
+ }
+
+ // Linux returns ENOENT when an invalid congestion
+ // control algorithm is specified.
+ return tcpip.ErrNoSuchFile
+
+ case tcpip.TCPLingerTimeoutOption:
+ e.LockUser()
+ if v < 0 {
+ // Same as effectively disabling TCPLinger timeout.
+ v = 0
+ }
+ var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
+ // We were unable to retrieve a stack config; just use
+ // the DefaultTCPLingerTimeout.
+ if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
+ stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ }
+ }
+ // Cap it to the stack wide TCPLinger timeout.
+ if v > stkTCPLingerTimeout {
+ v = stkTCPLingerTimeout
+ }
+ e.tcpLingerTimeout = time.Duration(v)
+ e.UnlockUser()
+
+ case tcpip.TCPDeferAcceptOption:
+ e.LockUser()
+ if time.Duration(v) > MaxRTO {
+ v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ }
+ e.deferAccept = time.Duration(v)
+ e.UnlockUser()
+
+ default:
+ return nil
+ }
+ return nil
+}
+
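+// For example, assuming the stack advertises "reno cubic" via
+// tcpip.AvailableCongestionControlOption, setting
+// tcpip.CongestionControlOption("cubic") succeeds, while an unsupported name
+// fails with tcpip.ErrNoSuchFile, mirroring Linux's ENOENT behavior.
+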
+// readyReceiveSize returns the number of bytes ready to be received.
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ // The endpoint cannot be in listen state.
+ if e.EndpointState() == StateListen {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ return e.rcvBufUsed, nil
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ v := e.broadcast
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.CorkOption:
+ return atomic.LoadUint32(&e.cork) != 0, nil
+
+ case tcpip.DelayOption:
+ return atomic.LoadUint32(&e.delay) != 0, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ return v, nil
+
+ case tcpip.QuickAckOption:
+ v := atomic.LoadUint32(&e.slowAck) == 0
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ v := e.portFlags.TupleOnly
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ v := e.portFlags.LoadBalanced
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.LockUser()
+ v := e.v6only
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.MulticastLoopOption:
+ return true, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ v := e.keepalive.count
+ e.keepalive.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.MaxSegOption:
+ // This is just stubbed out. Linux never returns the user_mss
+ // value as it either returns the defaultMSS or the actual
+ // current MSS. Netstack always returns the defaultMSS for now.
+ v := header.TCPDefaultMSS
+ return v, nil
+
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.ReceiveQueueSizeOption:
+ return e.readyReceiveSize()
+
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
+ case tcpip.TTLOption:
+ e.LockUser()
+ v := int(e.ttl)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.TCPSynCountOption:
+ e.LockUser()
+ v := int(e.maxSynRetries)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.TCPWindowClampOption:
+ e.LockUser()
+ v := int(e.windowClamp)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.MulticastTTLOption:
+ return 1, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ e.lastErrorMu.Lock()
+ err := e.lastError
+ e.lastError = nil
+ e.lastErrorMu.Unlock()
+ return err
+
+ case *tcpip.BindToDeviceOption:
+ e.LockUser()
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.UnlockUser()
+
+ case *tcpip.TCPInfoOption:
+ *o = tcpip.TCPInfoOption{}
+ e.LockUser()
+ snd := e.snd
+ e.UnlockUser()
+ if snd != nil {
+ snd.rtt.Lock()
+ o.RTT = snd.rtt.srtt
+ o.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+ }
+
+ case *tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+ e.keepalive.Unlock()
+
+ case *tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+ e.keepalive.Unlock()
+
+ case *tcpip.TCPUserTimeoutOption:
+ e.LockUser()
+ *o = tcpip.TCPUserTimeoutOption(e.userTimeout)
+ e.UnlockUser()
+
+ case *tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
+ *o = 1
+
+ case *tcpip.CongestionControlOption:
+ e.LockUser()
+ *o = e.cc
+ e.UnlockUser()
+
+ case *tcpip.TCPLingerTimeoutOption:
+ e.LockUser()
+ *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
+ e.UnlockUser()
+
+ case *tcpip.TCPDeferAcceptOption:
+ e.LockUser()
+ *o = tcpip.TCPDeferAcceptOption(e.deferAccept)
+ e.UnlockUser()
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+ return nil
+}
+
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect connects the endpoint to its peer.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ err := e.connect(addr, true, true)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform the
+// handshake. When restoring previously connected endpoints, both ends are
+// passively created (so no new handshaking is done); stack-accepted
+// connections not yet accepted by the app are restored without running the
+// main goroutine here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ connectingAddr := addr.Addr
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ if e.EndpointState().connected() {
+ // The endpoint is already connected. If caller hasn't been
+ // notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
+ return tcpip.ErrAlreadyConnected
+ }
+
+ nicID := addr.NIC
+ switch e.EndpointState() {
+ case StateBound:
+ // If we're already bound to a NIC but the caller is requesting
+ // that we use a different one now, we cannot proceed.
+ if e.boundNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != e.boundNICID {
+ return tcpip.ErrNoRoute
+ }
+
+ nicID = e.boundNICID
+
+ case StateInitial:
+ // Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
+ // when we find a route.
+
+ case StateConnecting, StateSynSent, StateSynRecv:
+ // A connection request has already been issued but hasn't completed
+ // yet.
+ return tcpip.ErrAlreadyConnecting
+
+ case StateError:
+ return e.HardError
+
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ e.ID.LocalAddress = r.LocalAddress
+ e.ID.RemoteAddress = r.RemoteAddress
+ e.ID.RemotePort = addr.Port
+
+ if e.ID.LocalPort != 0 {
+ // The endpoint is bound to a port, attempt to register it.
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ if err != nil {
+ return err
+ }
+ } else {
+ // The endpoint doesn't have a local port yet, so try to get
+ // one. Make sure that it isn't one that will result in the same
+ // address/port for both local and remote (otherwise this
+ // endpoint would be trying to connect to itself).
+ sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress
+
+ // Calculate a port offset based on the destination IP/port
+ // and src IP so that for a given tuple (srcIP, destIP,
+ // destPort) the starting offset is always the same, allowing
+ // us to cycle through the port space effectively.
+ h := jenkins.Sum32(e.stack.Seed())
+ h.Write([]byte(e.ID.LocalAddress))
+ h.Write([]byte(e.ID.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.ID.RemotePort {
+ return false, nil
+ }
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ return false, nil
+ }
+
+ id := e.ID
+ id.LocalPort = p
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err == tcpip.ErrPortInUse {
+ return false, nil
+ }
+ return false, err
+ }
+
+ // Port picking successful. Save the details of
+ // the selected port.
+ e.ID = id
+ e.isPortReserved = true
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.boundDest = addr
+ return true, nil
+ }); err != nil {
+ return err
+ }
+ }
+
+ e.isRegistered = true
+ e.setEndpointState(StateConnecting)
+ e.route = r.Clone()
+ e.boundNICID = nicID
+ e.effectiveNetProtos = netProtos
+ e.connectingAddress = connectingAddr
+
+ e.initGSO()
+
+ // Connect in the restore phase does not perform a handshake. Restore
+ // the connection's settings here.
+ if !handshake {
+ e.segmentQueue.mu.Lock()
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for s := l.Front(); s != nil; s = s.Next() {
+ s.id = e.ID
+ s.route = r.Clone()
+ e.sndWaker.Assert()
+ }
+ }
+ e.segmentQueue.mu.Unlock()
+ e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+ e.setEndpointState(StateEstablished)
+ }
+
+ if run {
+ e.workerRunning = true
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ }
+
+ return tcpip.ErrConnectStarted
+}
+
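+// An illustrative sketch of the stable ephemeral port offset derivation used
+// in connect above, with placeholder variable names:
+//
+//	h := jenkins.Sum32(stackSeed)
+//	h.Write([]byte(localAddr))
+//	h.Write([]byte(remoteAddr))
+//	buf := make([]byte, 2)
+//	binary.LittleEndian.PutUint16(buf, remotePort)
+//	h.Write(buf)
+//	portOffset := h.Sum32() // same (srcIP, dstIP, dstPort) -> same offset
+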
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+ return e.shutdownLocked(flags)
+}
+
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.shutdownFlags |= flags
+ switch {
+ case e.EndpointState().connected():
+ // Close for read.
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Mark read side as closed.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // If we're fully closed and we have unread data we need to abort
+ // the connection with a RST.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ // Wake up worker to terminate loop.
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ return nil
+ }
+ }
+
+ // Close for write.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ e.sndBufMu.Lock()
+ if e.sndClosed {
+ // Already closed.
+ e.sndBufMu.Unlock()
+ if e.EndpointState() == StateTimeWait {
+ return tcpip.ErrNotConnected
+ }
+ return nil
+ }
+
+ // Queue fin segment.
+ s := newSegmentFromView(&e.route, e.ID, nil)
+ e.sndQueue.PushBack(s)
+ e.sndBufInQueue++
+ // Mark endpoint as closed.
+ e.sndClosed = true
+ e.sndBufMu.Unlock()
+ e.handleClose()
+ }
+
+ return nil
+ case e.EndpointState() == StateListen:
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Reset all connections from the accept queue and keep the
+ // worker running so that it can continue handling incoming
+ // segments by replying with RST.
+ //
+ // By not removing this endpoint from the demuxer mapping, we
+ // ensure that any other bind to the same port fails, as on Linux.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ e.rcvListMu.Unlock()
+ e.closePendingAcceptableConnectionsLocked()
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
+ }
+ return nil
+ default:
+ return tcpip.ErrNotConnected
+ }
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections.
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ err := e.listen(backlog)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+func (e *endpoint) listen(backlog int) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ if e.EndpointState() == StateListen && !e.closed {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ if e.acceptedChan == nil {
+ // listen is called after shutdown.
+ e.acceptedChan = make(chan *endpoint, backlog)
+ e.shutdownFlags = 0
+ e.rcvListMu.Lock()
+ e.rcvClosed = false
+ e.rcvListMu.Unlock()
+ } else {
+ // Adjust the size of the channel only if we can fit the
+ // existing pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ }
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
+ return nil
+ }
+
+ if e.EndpointState() == StateInitial {
+ // Listen is called on an unbound socket, so the socket is
+ // automatically bound to a random free port with the local
+ // address set to INADDR_ANY.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return err
+ }
+ }
+
+ // Endpoint must be bound before it can transition to listen mode.
+ if e.EndpointState() != StateBound {
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Register the endpoint.
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
+ e.isRegistered = true
+ e.setEndpointState(StateListen)
+
+ // The channel may be non-nil when we're restoring the endpoint, and it
+ // may be pre-populated with endpoints that were previously accepted by
+ // the stack but not yet delivered to the application via Accept().
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptedChan = make(chan *endpoint, backlog)
+ }
+ e.acceptMu.Unlock()
+
+ e.workerRunning = true
+ go e.protocolListenLoop( // S/R-SAFE: drained on save.
+ seqnum.Size(e.receiveBufferAvailable()))
+ return nil
+}
+
+// startAcceptedLoop sets up required state and starts a goroutine with the
+// main loop for accepted connections.
+func (e *endpoint) startAcceptedLoop() {
+ e.workerRunning = true
+ e.mu.Unlock()
+ wakerInitDone := make(chan struct{})
+ go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save.
+ <-wakerInitDone
+}
+
+// Accept returns a new endpoint if a peer has established a connection
+// to an endpoint previously set to listen mode.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
+ // Endpoint must be in listen state before it can accept connections.
+ if rcvClosed || e.EndpointState() != StateListen {
+ return nil, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Get the new accepted endpoint.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ var n *endpoint
+ select {
+ case n = <-e.acceptedChan:
+ e.acceptCond.Signal()
+ default:
+ return nil, nil, tcpip.ErrWouldBlock
+ }
+ return n, n.waiterQueue, nil
+}
+
+// Bind binds the endpoint to a specific local port and optionally address.
+func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return e.bindLocked(addr)
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore. This is because once the endpoint goes into a connected or
+ // listen state, it is already bound.
+ if e.EndpointState() != StateInitial {
+ return tcpip.ErrAlreadyBound
+ }
+
+ e.BindAddr = addr.Addr
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
+ if err != nil {
+ return err
+ }
+
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.ID.LocalPort = port
+
+ // Any failures beyond this point must remove the port registration.
+ defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
+ e.isPortReserved = false
+ e.effectiveNetProtos = nil
+ e.ID.LocalPort = 0
+ e.ID.LocalAddress = ""
+ e.boundNICID = 0
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
+ }
+ }(e.boundPortFlags, e.boundBindToDevice)
+
+ // If an address is specified, we must ensure that it's one of our
+ // local addresses.
+ if len(addr.Addr) != 0 {
+ nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ e.boundNICID = nic
+ e.ID.LocalAddress = addr.Addr
+ }
+
+ if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
+
+ return nil
+}
+
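+// For example, binding an IPv6 endpoint with v6only == false to the wildcard
+// address (an illustrative usage sketch):
+//
+//	ep.Bind(tcpip.FullAddress{Port: 80}) // reserves the port for both
+//	                                     // IPv6 and IPv4
+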
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return tcpip.FullAddress{
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ if !e.EndpointState().connected() {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // TCP HandlePacket is not required anymore as inbound packets first
+ // land at the Dispatcher, which can then either deliver them using the
+ // worker goroutine or invoke the TCP processing inline, based on the
+ // state of the endpoint.
+}
+
+func (e *endpoint) enqueueSegment(s *segment) bool {
+ // Send packet to worker goroutine.
+ if !e.segmentQueue.enqueue(s) {
+ // The queue is full, so we drop the segment.
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
+ return false
+ }
+ return true
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ switch typ {
+ case stack.ControlPacketTooBig:
+ e.sndBufMu.Lock()
+ e.packetTooBigCount++
+ if v := int(extra); v < e.sndMTU {
+ e.sndMTU = v
+ }
+ e.sndBufMu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyMTUChanged)
+ }
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v.
+func (e *endpoint) updateSndBufferUsage(v int) {
+ e.sndBufMu.Lock()
+ notify := e.sndBufUsed >= e.sndBufSize>>1
+ e.sndBufUsed -= v
+ // We only notify when there is half the sndBufSize available after
+ // a full buffer event occurs. This ensures that we don't wake up
+ // writers to queue just 1-2 segments and go back to sleep.
+ notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ e.sndBufMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.EventOut)
+ }
+}
+
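+// A worked example of the notification hysteresis above (the buffer size is
+// an assumption for illustration): with sndBufSize = 4096 the threshold is
+// 2048. A writer that filled the buffer past 2048 bytes is only woken once
+// acked data brings sndBufUsed back below 2048, rather than on every small
+// ACK.
+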
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+ e.rcvListMu.Lock()
+ if s != nil {
+ s.incRef()
+ e.rcvBufUsed += s.data.Size()
+ // Increase the counter if the receive window falls below MSS
+ // or half the receive buffer size, whichever is smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+ e.rcvList.PushBack(s)
+ } else {
+ e.rcvClosed = true
+ }
+ e.rcvListMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+}
+
+// receiveBufferAvailableLocked calculates how many bytes are still available
+// in the receive buffer.
+// rcvListMu must be held when this function is called.
+func (e *endpoint) receiveBufferAvailableLocked() int {
+ // We may use more bytes than the buffer size when the receive buffer
+ // shrinks.
+ if e.rcvBufUsed >= e.rcvBufSize {
+ return 0
+ }
+
+ return e.rcvBufSize - e.rcvBufUsed
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer.
+func (e *endpoint) receiveBufferAvailable() int {
+ e.rcvListMu.Lock()
+ available := e.receiveBufferAvailableLocked()
+ e.rcvListMu.Unlock()
+ return available
+}
+
+func (e *endpoint) receiveBufferSize() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ e.rcvListMu.Unlock()
+
+ return size
+}
+
+func (e *endpoint) maxReceiveBufferSize() int {
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ // As a fallback return the hardcoded max buffer size.
+ return MaxBufferSize
+ }
+ return rs.Max
+}
+
+// rcvWndScaleForHandshake computes the receive window scale to offer to the
+// peer when window scaling is enabled (true by default). If auto-tuning is
+// disabled then the window scaling factor is based on the size of the
+// receive buffer; otherwise we use the max permissible receive buffer size
+// to compute the scale.
+func (e *endpoint) rcvWndScaleForHandshake() int {
+ bufSizeForScale := e.receiveBufferSize()
+
+ e.rcvListMu.Lock()
+ autoTuningDisabled := e.rcvAutoParams.disabled
+ e.rcvListMu.Unlock()
+ if autoTuningDisabled {
+ return FindWndScale(seqnum.Size(bufSizeForScale))
+ }
+
+ return FindWndScale(seqnum.Size(e.maxReceiveBufferSize()))
+}
+
+// updateRecentTimestamp updates the recent timestamp using the algorithm
+// described in https://tools.ietf.org/html/rfc7323#section-4.3
+func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
+ if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.setRecentTimestamp(tsVal)
+ }
+}
+
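+// A worked example of the update rule above: with recentTimestamp() == 100, a
+// segment carrying tsVal = 105 whose sequence number is covered by maxSentAck
+// advances the recent timestamp to 105, while a segment with tsVal = 95
+// leaves it unchanged, per RFC 7323 section 4.3.
+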
+// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
+// the SYN options indicate that timestamp option was negotiated. It also
+// initializes the recentTS with the value provided in synOpts.TSval.
+func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+ if synOpts.TS {
+ e.sendTSOk = true
+ e.setRecentTimestamp(synOpts.TSVal)
+ }
+}
+
+// timestamp returns the timestamp value to be used in the TSVal field of the
+// timestamp option for outgoing TCP segments for a given endpoint.
+func (e *endpoint) timestamp() uint32 {
+ return tcpTimeStamp(e.tsOffset)
+}
+
+// tcpTimeStamp returns a timestamp offset by the provided offset. This is
+// not inlined above as it's used when SYN cookies are in use and the
+// endpoint is not created at the time the SYN cookie is sent.
+func tcpTimeStamp(offset uint32) uint32 {
+ now := time.Now()
+ return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+}
+
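+// A worked example of the arithmetic above (the values are assumptions for
+// illustration): at Unix time 1000000 seconds plus 500ms, with offset = 7:
+//
+//	uint32(1000000*1000 + 500) + 7 // = 1000000507 (millisecond granularity)
+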
+// timeStampOffset returns a randomized timestamp offset to be used when sending
+// timestamp values in a timestamp option for a TCP segment.
+func timeStampOffset() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ // Initialize a random tsOffset that will be added to the recentTS
+ // every time the timestamp is sent when the Timestamp option is
+ // enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // NOTE: This is not completely to spec as normally this should be
+ // initialized in a manner analogous to how sequence numbers are
+ // randomized on a per-connection basis. But for now this is
+ // sufficient.
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Stack doesn't support SACK. So just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
+// maxOptionSize returns the maximum size of TCP options.
+func (e *endpoint) maxOptionSize() (size int) {
+ var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+ options := e.makeOptions(maxSackBlocks[:])
+ size = len(options)
+ putOptions(options)
+
+ return size
+}
+
+// completeState makes a full copy of the endpoint and returns it. This is used
+// before invoking the probe. The state returned may not be fully consistent if
+// there are intervening syscalls when the state is being copied.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+ var s stack.TCPEndpointState
+ s.SegTime = time.Now()
+
+ // Copy EndpointID.
+ s.ID = stack.TCPEndpointID(e.ID)
+
+ // Copy endpoint rcv state.
+ e.rcvListMu.Lock()
+ s.RcvBufSize = e.rcvBufSize
+ s.RcvBufUsed = e.rcvBufUsed
+ s.RcvClosed = e.rcvClosed
+ s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime
+ s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied
+ s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied
+ s.RcvAutoParams.RTT = e.rcvAutoParams.rtt
+ s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber
+ s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime
+ s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled
+ e.rcvListMu.Unlock()
+
+ // Endpoint TCP Option state.
+ s.SendTSOk = e.sendTSOk
+ s.RecentTS = e.recentTimestamp()
+ s.TSOffset = e.tsOffset
+ s.SACKPermitted = e.sackPermitted
+ s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+ copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+ s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
+
+ // Copy endpoint send state.
+ e.sndBufMu.Lock()
+ s.SndBufSize = e.sndBufSize
+ s.SndBufUsed = e.sndBufUsed
+ s.SndClosed = e.sndClosed
+ s.SndBufInQueue = e.sndBufInQueue
+ s.PacketTooBigCount = e.packetTooBigCount
+ s.SndMTU = e.sndMTU
+ e.sndBufMu.Unlock()
+
+ // Copy receiver state.
+ s.Receiver = stack.TCPReceiverState{
+ RcvNxt: e.rcv.rcvNxt,
+ RcvAcc: e.rcv.rcvAcc,
+ RcvWndScale: e.rcv.rcvWndScale,
+ PendingBufUsed: e.rcv.pendingBufUsed,
+ PendingBufSize: e.rcv.pendingBufSize,
+ }
+
+ // Copy sender state.
+ s.Sender = stack.TCPSenderState{
+ LastSendTime: e.snd.lastSendTime,
+ DupAckCount: e.snd.dupAckCount,
+ FastRecovery: stack.TCPFastRecoveryState{
+ Active: e.snd.fr.active,
+ First: e.snd.fr.first,
+ Last: e.snd.fr.last,
+ MaxCwnd: e.snd.fr.maxCwnd,
+ HighRxt: e.snd.fr.highRxt,
+ RescueRxt: e.snd.fr.rescueRxt,
+ },
+ SndCwnd: e.snd.sndCwnd,
+ Ssthresh: e.snd.sndSsthresh,
+ SndCAAckCount: e.snd.sndCAAckCount,
+ Outstanding: e.snd.outstanding,
+ SndWnd: e.snd.sndWnd,
+ SndUna: e.snd.sndUna,
+ SndNxt: e.snd.sndNxt,
+ RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
+ RTTMeasureTime: e.snd.rttMeasureTime,
+ Closed: e.snd.closed,
+ RTO: e.snd.rto,
+ MaxPayloadSize: e.snd.maxPayloadSize,
+ SndWndScale: e.snd.sndWndScale,
+ MaxSentAck: e.snd.maxSentAck,
+ }
+ e.snd.rtt.Lock()
+ s.Sender.SRTT = e.snd.rtt.srtt
+ s.Sender.SRTTInited = e.snd.rtt.srttInited
+ e.snd.rtt.Unlock()
+
+ if cubic, ok := e.snd.cc.(*cubicState); ok {
+ s.Sender.Cubic = stack.TCPCubicState{
+ WMax: cubic.wMax,
+ WLastMax: cubic.wLastMax,
+ T: cubic.t,
+ TimeSinceLastCongestion: time.Since(cubic.t),
+ C: cubic.c,
+ K: cubic.k,
+ Beta: cubic.beta,
+ WC: cubic.wC,
+ WEst: cubic.wEst,
+ }
+ }
+ return s
+}
+
+func (e *endpoint) initHardwareGSO() {
+ gso := &stack.GSO{}
+ switch e.route.NetProto {
+ case header.IPv4ProtocolNumber:
+ gso.Type = stack.GSOTCPv4
+ gso.L3HdrLen = header.IPv4MinimumSize
+ case header.IPv6ProtocolNumber:
+ gso.Type = stack.GSOTCPv6
+ gso.L3HdrLen = header.IPv6MinimumSize
+ default:
+ panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto))
+ }
+ gso.NeedsCsum = true
+ gso.CsumOffset = header.TCPChecksumOffset
+ gso.MaxSize = e.route.GSOMaxSize()
+ e.gso = gso
+}
+
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ e.initHardwareGSO()
+ } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ e.gso = &stack.GSO{
+ MaxSize: e.route.GSOMaxSize(),
+ Type: stack.GSOSW,
+ NeedsCsum: false,
+ }
+ }
+}
+
+// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
+// state for diagnostics.
+func (e *endpoint) State() uint32 {
+ return uint32(e.EndpointState())
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.LockUser()
+ // Make a copy of the endpoint info.
+ ret := e.EndpointInfo
+ e.UnlockUser()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (e *endpoint) Wait() {
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
+ defer e.waiterQueue.EventUnregister(&waitEntry)
+ for {
+ e.LockUser()
+ running := e.workerRunning
+ e.UnlockUser()
+ if !running {
+ break
+ }
+ <-notifyCh
+ }
+}
+
+func mssForRoute(r *stack.Route) uint16 {
+ // TODO(b/143359391): Respect TCP Min and Max size.
+ return uint16(r.MTU() - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
new file mode 100644
index 000000000..abf1ac5c9
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+func (e *endpoint) drainSegmentLocked() {
+ // Drain at most once.
+ if e.drainDone != nil {
+ return
+ }
+
+ e.drainDone = make(chan struct{})
+ e.undrain = make(chan struct{})
+ e.mu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyDrain)
+ <-e.drainDone
+
+ e.mu.Lock()
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets.
+ e.segmentQueue.setLimit(0)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
+ if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
+ if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
+ }
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ e.Close()
+ e.mu.Lock()
+ }
+ if !e.workerRunning {
+			// The endpoint must be in the accepted channel or have
+			// just been disconnected and closed.
+ break
+ }
+ fallthrough
+ case epState == StateListen || epState == StateConnecting:
+ e.drainSegmentLocked()
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
+ if !e.workerRunning {
+ panic("endpoint has no worker running in listen, connecting, or connected state")
+ }
+ }
+ case epState.closed():
+ for e.workerRunning {
+ e.mu.Unlock()
+ time.Sleep(100 * time.Millisecond)
+ e.mu.Lock()
+ }
+ if e.workerRunning {
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
+ }
+ default:
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
+ }
+
+ if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
+ panic("endpoint still has waiters upon save")
+ }
+}
+
+// saveAcceptedChan is invoked by stateify.
+func (e *endpoint) saveAcceptedChan() []*endpoint {
+ if e.acceptedChan == nil {
+ return nil
+ }
+ acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case ep := <-e.acceptedChan:
+ acceptedEndpoints[i] = ep
+ default:
+ panic("endpoint acceptedChan buffer got consumed by background context")
+ }
+ }
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case e.acceptedChan <- acceptedEndpoints[i]:
+ default:
+ panic("endpoint acceptedChan buffer got populated by background context")
+ }
+ }
+ return acceptedEndpoints
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
+ if cap(acceptedEndpoints) > 0 {
+ e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
+ for _, ep := range acceptedEndpoints {
+ e.acceptedChan <- ep
+ }
+ }
+}
+
+// saveState is invoked by stateify.
+func (e *endpoint) saveState() EndpointState {
+ return e.EndpointState()
+}
+
+// Endpoints must be loaded in the following order by state, to avoid a
+// dangling connecting endpoint without a listening peer, and to avoid
+// conflicts in port reservation.
+var connectedLoading sync.WaitGroup
+var listenLoading sync.WaitGroup
+var connectingLoading sync.WaitGroup
+
+// Bound endpoint loading happens last.
+
+// loadState is invoked by stateify.
+func (e *endpoint) loadState(epState EndpointState) {
+ // This is to ensure that the loading wait groups include all applicable
+ // endpoints before any asynchronous calls to the Wait() methods.
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if epState.connected() || epState == StateTimeWait {
+ connectedLoading.Add(1)
+ }
+ switch {
+ case epState == StateListen:
+ listenLoading.Add(1)
+ case epState.connecting():
+ connectingLoading.Add(1)
+ }
+ // Directly update the state here rather than using e.setEndpointState
+ // as the endpoint is still being loaded and the stack reference is not
+ // yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.origEndpointState = e.state
+ // Restore the endpoint to InitialState as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
+	// Condition variables and mutexes are not S/R'ed, so reinitialize
+ // acceptCond with e.acceptMu.
+ e.acceptCond = sync.NewCond(&e.acceptMu)
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ epState := e.origEndpointState
+ switch epState {
+ case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
+ }
+ }
+ }
+
+ bind := func() {
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
+ }
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
+ }
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
+ }
+
+ switch {
+ case epState.connected():
+ bind()
+ if len(e.connectingAddress) == 0 {
+ e.connectingAddress = e.ID.RemoteAddress
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
+ }
+ }
+ // Reset the scoreboard to reinitialize the sack information as
+ // we do not restore SACK information.
+ e.scoreboard.Reset()
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ e.mu.Lock()
+ e.state = e.origEndpointState
+ closed := e.closed
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ if epState == StateFinWait2 && closed {
+ // If the endpoint has been closed then make sure we notify so
+ // that the FIN_WAIT2 timer is started after a restore.
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+ connectedLoading.Done()
+ case epState == StateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ e.LockUser()
+ if e.shutdownFlags != 0 {
+ e.shutdownLocked(e.shutdownFlags)
+ }
+ e.UnlockUser()
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case epState.connecting():
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case epState == StateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case epState == StateClose:
+ e.isPortReserved = false
+ e.state = StateClose
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
+ case epState == StateError:
+ e.state = StateError
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
+ }
+}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = tcpip.StringToError(s)
+}
+
+// saveHardError is invoked by stateify.
+func (e *EndpointInfo) saveHardError() string {
+ if e.HardError == nil {
+ return ""
+ }
+
+ return e.HardError.String()
+}
+
+// loadHardError is invoked by stateify.
+func (e *EndpointInfo) loadHardError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.HardError = tcpip.StringToError(s)
+}
+
+// saveMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
+ return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
+}
+
+// loadMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
+ r.measureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime {
+ return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) {
+ r.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
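+
+// The unixTime save/load pairs above are how this file persists time.Time
+// fields, which stateify does not serialize directly. As an illustrative
+// sketch (not part of this change), a field declared as
+//
+//	measureTime time.Time `state:".(unixTime)"`
+//
+// is paired with saveMeasureTime/loadMeasureTime hooks that round-trip the
+// value through seconds and nanoseconds:
+//
+//	saved := unixTime{t.Unix(), t.UnixNano()}
+//	restored := time.Unix(saved.second, saved.nano)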
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 000000000..070b634b4
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,169 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Forwarder is a connection request forwarder, which allows clients to decide
+// what to do with a connection request, for example: ignore it, send a RST, or
+// attempt to complete the 3-way handshake.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ maxInFlight int
+ handler func(*ForwarderRequest)
+
+ mu sync.Mutex
+ inFlight map[stack.TransportEndpointID]struct{}
+ listen *listenContext
+}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is reached
+// new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+ if rcvWnd == 0 {
+ rcvWnd = DefaultReceiveBufferSize
+ }
+ return &Forwarder{
+ maxInFlight: maxInFlight,
+ handler: handler,
+ inFlight: make(map[stack.TransportEndpointID]struct{}),
+ listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ }
+}
+
+// HandlePacket handles a packet if it is of interest to the forwarder (i.e.,
+// if it's a SYN packet), returning true in that case. Otherwise the packet is
+// not handled and false is returned.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
+ defer s.decRef()
+
+ // We only care about well-formed SYN packets.
+ if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ return false
+ }
+
+ opts := parseSynSegmentOptions(s)
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+	// We already have an in-flight request for this ID; ignore this one
+	// for now.
+ if _, ok := f.inFlight[id]; ok {
+ return true
+ }
+
+ // Ignore the segment if we're beyond the limit.
+ if len(f.inFlight) >= f.maxInFlight {
+ return true
+ }
+
+ // Launch a new goroutine to handle the request.
+ f.inFlight[id] = struct{}{}
+ s.incRef()
+ go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry.
+ forwarder: f,
+ segment: s,
+ synOptions: opts,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a connection request received by the forwarder
+// and passed to the client. Clients must eventually call Complete() on it, and
+// may optionally create an endpoint to represent it via CreateEndpoint.
+type ForwarderRequest struct {
+ mu sync.Mutex
+ forwarder *Forwarder
+ segment *segment
+ synOptions header.TCPSynOptions
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the connection request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.segment.id
+}
+
+// Complete completes the request, and optionally sends a RST segment back to the
+// sender.
+func (r *ForwarderRequest) Complete(sendReset bool) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ panic("Completing already completed forwarder request")
+ }
+
+ // Remove request from the forwarder.
+ r.forwarder.mu.Lock()
+ delete(r.forwarder.inFlight, r.segment.id)
+ r.forwarder.mu.Unlock()
+
+ // If the caller requested, send a reset.
+ if sendReset {
+ replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
+ }
+
+ // Release all resources.
+ r.segment.decRef()
+ r.segment = nil
+ r.forwarder = nil
+}
+
+// CreateEndpoint creates a TCP endpoint for the connection request, performing
+// the 3-way handshake in the process.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ f := r.forwarder
+ ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ MSS: r.synOptions.MSS,
+ WS: r.synOptions.WS,
+ TS: r.synOptions.TS,
+ TSVal: r.synOptions.TSVal,
+ TSEcr: r.synOptions.TSEcr,
+ SACKPermitted: r.synOptions.SACKPermitted,
+ }, queue, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Start the protocol goroutine.
+ ep.startAcceptedLoop()
+
+ return ep, nil
+}
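+
+// A minimal usage sketch (illustrative only; the enclosing stack setup is
+// assumed and not part of this change). The forwarder is registered with the
+// stack, and each request is either completed with a RST or turned into a
+// full endpoint:
+//
+//	fwd := tcp.NewForwarder(s, 0 /* rcvWnd: use default */, 10 /* maxInFlight */, func(r *tcp.ForwarderRequest) {
+//		var wq waiter.Queue
+//		ep, err := r.CreateEndpoint(&wq)
+//		if err != nil {
+//			r.Complete(true /* sendReset */)
+//			return
+//		}
+//		r.Complete(false /* sendReset */)
+//		// Hand ep off to the application, e.g. via gonet.
+//		_ = ep
+//	})
+//	s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)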
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
new file mode 100644
index 000000000..b34e47bbd
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -0,0 +1,541 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains the implementation of the TCP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing tcp.NewProtocol() as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package tcp
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolNumber is the tcp protocol number.
+ ProtocolNumber = header.TCPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4096 bytes.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 1 << 20 // 1MB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 1 << 20 // 1MB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MB
+
+ // MaxUnprocessedSegments is the maximum number of unprocessed segments
+ // that can be queued for a given endpoint.
+ MaxUnprocessedSegments = 300
+
+ // DefaultTCPLingerTimeout is the amount of time that sockets linger in
+ // FIN_WAIT_2 state before being marked closed.
+ DefaultTCPLingerTimeout = 60 * time.Second
+
+ // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+ // in TIME_WAIT state before being marked closed.
+ DefaultTCPTimeWaitTimeout = 60 * time.Second
+
+ // DefaultSynRetries is the default value for the number of SYN retransmits
+ // before a connect is aborted.
+ DefaultSynRetries = 6
+)
+
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// DelayEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
+type DelayEnabled bool
+
+// SendBufferSizeOption is used by stack.(*Stack).TransportProtocolOption
+// to get/set the default, min and max TCP send buffer sizes.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption is used by
+// stack.(*Stack).TransportProtocolOption to get/set the default, min and max
+// TCP receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// synRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+ threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+ s.Lock()
+ defer s.Unlock()
+ if s.value >= s.threshold {
+ return false
+ }
+
+ s.pending.Add(1)
+ s.value++
+
+ return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+ s.Lock()
+ defer s.Unlock()
+ s.value--
+ s.pending.Done()
+}
+
+// synCookiesInUse returns true if the number of endpoints in the SYN-RCVD
+// state has reached the threshold, i.e., SYN cookies are in use.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+ s.Lock()
+ defer s.Unlock()
+ return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.threshold to the new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+ s.Lock()
+ defer s.Unlock()
+ s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.threshold.
+func (s *synRcvdCounter) Threshold() uint64 {
+ s.Lock()
+ defer s.Unlock()
+ return s.threshold
+}
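+
+// Illustrative usage sketch (an assumption about the caller, not part of
+// this change): an endpoint is admitted into the SYN-RCVD state only while
+// the counter is below its threshold, and every successful inc must be
+// paired with exactly one dec:
+//
+//	if p.synRcvdCount.inc() {
+//		// ... track the connection through the SYN-RCVD state ...
+//		p.synRcvdCount.dec()
+//	} else {
+//		// Threshold reached; serve the SYN with a SYN cookie instead.
+//	}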
+
+type protocol struct {
+ mu sync.RWMutex
+ sackEnabled bool
+ delayEnabled bool
+ sendBufferSize SendBufferSizeOption
+ recvBufferSize ReceiveBufferSizeOption
+ congestionControl string
+ availableCongestionControl []string
+ moderateReceiveBuffer bool
+ tcpLingerTimeout time.Duration
+ tcpTimeWaitTimeout time.Duration
+ minRTO time.Duration
+ maxRTO time.Duration
+ maxRetries uint32
+ synRcvdCount synRcvdCounter
+ synRetries uint8
+ dispatcher dispatcher
+}
+
+// Number returns the tcp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new tcp endpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid tcp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.TCPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given tcp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.TCP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
+ defer s.decRef()
+
+ if !s.parse() || !s.csumValid {
+ return false
+ }
+
+ // There's nothing to do if this is already a reset packet.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return true
+ }
+
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment, tos, ttl uint8) {
+ // Get the seqnum from the packet if the ack flag is set.
+ seq := seqnum.Value(0)
+ ack := seqnum.Value(0)
+ flags := byte(header.TCPFlagRst)
+ // As per RFC 793 page 35 (Reset Generation)
+ // 1. If the connection does not exist (CLOSED) then a reset is sent
+ // in response to any incoming segment except another reset. In
+ // particular, SYNs addressed to a non-existent connection are rejected
+ // by this means.
+
+ // If the incoming segment has an ACK field, the reset takes its
+ // sequence number from the ACK field of the segment, otherwise the
+ // reset has sequence number zero and the ACK field is set to the sum
+ // of the sequence number and segment length of the incoming segment.
+ // The connection remains in the CLOSED state.
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ } else {
+ flags |= header.TCPFlagAck
+ ack = s.sequenceNumber.Add(s.logicalLen())
+ }
+ sendTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: ttl,
+ tos: tos,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: 0,
+ }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */)
+}
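+
+// Worked example (illustrative): an incoming SYN with sequence number 100
+// and no ACK flag has a logical length of 1 (the SYN itself), so the reply
+// is a RST|ACK with seq=0 and ack=101. An incoming segment that instead
+// carries an ACK of 500 produces a bare RST with seq=500.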
+
+// SetOption implements stack.TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case DelayEnabled:
+ p.mu.Lock()
+ p.delayEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case SendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.CongestionControlOption:
+ for _, c := range p.availableCongestionControl {
+ if string(v) == c {
+ p.mu.Lock()
+ p.congestionControl = string(v)
+ p.mu.Unlock()
+ return nil
+ }
+ }
+		// Linux returns ENOENT when an invalid congestion control
+		// is specified.
+ return tcpip.ErrNoSuchFile
+
+ case tcpip.ModerateReceiveBufferOption:
+ p.mu.Lock()
+ p.moderateReceiveBuffer = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPLingerTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpLingerTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpTimeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMinRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMinRTOOption(MinRTO)
+ }
+ p.mu.Lock()
+ p.minRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMaxRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMaxRTOOption(MaxRTO)
+ }
+ p.mu.Lock()
+ p.maxRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMaxRetriesOption:
+ p.mu.Lock()
+ p.maxRetries = uint32(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.Lock()
+ p.synRcvdCount.SetThreshold(uint64(v))
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPSynRetriesOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.synRetries = uint8(v)
+ p.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements stack.TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SACKEnabled:
+ p.mu.RLock()
+ *v = SACKEnabled(p.sackEnabled)
+ p.mu.RUnlock()
+ return nil
+
+ case *DelayEnabled:
+ p.mu.RLock()
+ *v = DelayEnabled(p.delayEnabled)
+ p.mu.RUnlock()
+ return nil
+
+ case *SendBufferSizeOption:
+ p.mu.RLock()
+ *v = p.sendBufferSize
+ p.mu.RUnlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ p.mu.RLock()
+ *v = p.recvBufferSize
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.CongestionControlOption:
+ p.mu.RLock()
+ *v = tcpip.CongestionControlOption(p.congestionControl)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.AvailableCongestionControlOption:
+ p.mu.RLock()
+ *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.ModerateReceiveBufferOption:
+ p.mu.RLock()
+ *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPLingerTimeoutOption:
+ p.mu.RLock()
+ *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitTimeoutOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMinRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMinRTOOption(p.minRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMaxRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMaxRTOOption(p.maxRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMaxRetriesOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMaxRetriesOption(p.maxRetries)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSynRetriesOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRetriesOption(p.synRetries)
+ p.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
+// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
+// instance.
+func (p *protocol) SynRcvdCounter() *synRcvdCounter {
+ return &p.synRcvdCount
+}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ hdr, ok = pkt.Data.PullUp(offset)
+ if !ok {
+ panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset))
+ }
+ }
+
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ return true
+}
+
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ p := protocol{
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ tcpLingerTimeout: DefaultTCPLingerTimeout,
+ tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
+ synRetries: DefaultSynRetries,
+ minRTO: MinRTO,
+ maxRTO: MaxRTO,
+ maxRetries: MaxRetries,
+ }
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
+}
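+
+// A minimal activation sketch (illustrative only; the ipv4 network protocol
+// and the option values are assumptions, not part of this change):
+//
+//	s := stack.New(stack.Options{
+//		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
+//		TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+//	})
+//	// Enable SACK and switch to CUBIC via transport protocol options.
+//	s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true))
+//	s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic"))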
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
new file mode 100644
index 000000000..dd89a292a
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -0,0 +1,475 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "container/heap"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// receiver holds the state necessary to receive TCP segments and turn them
+// into a stream of bytes.
+//
+// +stateify savable
+type receiver struct {
+ ep *endpoint
+
+ rcvNxt seqnum.Value
+
+ // rcvAcc is one beyond the last acceptable sequence number. That is,
+	// the "largest" sequence value that the receiver has announced to
+	// its peer that it's willing to accept. This may be different from
+ // rcvNxt + rcvWnd if the receive window is reduced; in that case we
+ // have to reduce the window as we receive more data instead of
+ // shrinking it.
+ rcvAcc seqnum.Value
+
+ // rcvWnd is the non-scaled receive window last advertised to the peer.
+ rcvWnd seqnum.Size
+
+ rcvWndScale uint8
+
+ closed bool
+
+ pendingRcvdSegments segmentHeap
+ pendingBufUsed seqnum.Size
+ pendingBufSize seqnum.Size
+
+ // Time when the last ack was received.
+ lastRcvdAckTime time.Time `state:".(unixTime)"`
+}
+
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
+ return &receiver{
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: pendingBufSize,
+ lastRcvdAckTime: time.Now(),
+ }
+}
+
+// acceptable checks if the segment sequence number range is acceptable
+// according to the table on page 26 of RFC 793.
+func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+	// r.rcvWnd could be much larger than the window size we advertised in
+	// our outgoing packets; we should use what we actually advertised for
+	// the acceptability test.
+ scaledWindowSize := r.rcvWnd >> r.rcvWndScale
+ if scaledWindowSize > 0xffff {
+ // This is what we actually put in the Window field.
+ scaledWindowSize = 0xffff
+ }
+ advertisedWindowSize := scaledWindowSize << r.rcvWndScale
+ return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
+}
+
+// getSendParams returns the parameters needed by the sender when building
+// segments to send.
+func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
+ // Calculate the window size based on the available buffer space.
+ receiveBufferAvailable := r.ep.receiveBufferAvailable()
+ acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
+ if r.rcvAcc.LessThan(acc) {
+ r.rcvAcc = acc
+ }
+	// Stash away the non-scaled receive window as we use it for measuring
+	// the receiver's estimated RTT.
+ r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
+}
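+
+// Worked example (illustrative): with rcvWndScale = 2 and 65536 bytes of
+// receive buffer available, the 16-bit window field sent on the wire is
+// 65536 >> 2 = 16384, which the peer scales back up to 65536 bytes.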
+
+// nonZeroWindow is called when the receive window grows from zero to nonzero;
+// in such cases we may need to send an ack to indicate to our peer that it can
+// resume sending data.
+func (r *receiver) nonZeroWindow() {
+ // Immediately send an ack.
+ r.ep.snd.sendAck()
+}
+
+// consumeSegment attempts to consume a segment that was received by r. The
+// segment may have just been received or may have been received earlier but
+// wasn't ready to be consumed then.
+//
+// Returns true if the segment was consumed, false if it cannot be consumed
+// yet because of a missing segment.
+func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
+ if segLen > 0 {
+ // If the segment doesn't include the seqnum we're expecting to
+ // consume now, we're missing a segment. We cannot proceed until
+ // we receive that segment though.
+ if !r.rcvNxt.InWindow(segSeq, segLen) {
+ return false
+ }
+
+ // Trim segment to eliminate already acknowledged data.
+ if segSeq.LessThan(r.rcvNxt) {
+ diff := segSeq.Size(r.rcvNxt)
+ segLen -= diff
+ segSeq.UpdateForward(diff)
+ s.sequenceNumber.UpdateForward(diff)
+ s.data.TrimFront(int(diff))
+ }
+
+ // Move segment to ready-to-deliver list. Wakeup any waiters.
+ r.ep.readyToRead(s)
+
+ } else if segSeq != r.rcvNxt {
+ return false
+ }
+
+ // Update the segment that we're expecting to consume.
+ r.rcvNxt = segSeq.Add(segLen)
+
+ // In cases of a misbehaving sender which could send more than the
+ // advertised window, we could end up in a situation where we get a
+ // segment that exceeds the window advertised. Instead of partially
+ // accepting the segment and discarding bytes beyond the advertised
+ // window, we accept the whole segment and make sure r.rcvAcc is moved
+ // forward to match r.rcvNxt to indicate that the window is now closed.
+ //
+ // In absence of this check the r.acceptable() check fails and accepts
+ // segments that should be dropped because rcvWnd is calculated as
+ // the size of the interval (rcvNxt, rcvAcc] which becomes extremely
+ // large if rcvAcc is ever less than rcvNxt.
+ if r.rcvAcc.LessThan(r.rcvNxt) {
+ r.rcvAcc = r.rcvNxt
+ }
+
+ // Trim SACK Blocks to remove any SACK information that covers
+ // sequence numbers that have been consumed.
+ TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+
+ // Handle FIN or FIN-ACK.
+ if s.flagIsSet(header.TCPFlagFin) {
+ r.rcvNxt++
+
+ // Send ACK immediately.
+ r.ep.snd.sendAck()
+
+ // Tell any readers that no more data will come.
+ r.closed = true
+ r.ep.readyToRead(nil)
+
+ // We just received a FIN, our next state depends on whether we sent a
+ // FIN already or not.
+ switch r.ep.EndpointState() {
+ case StateEstablished:
+ r.ep.setEndpointState(StateCloseWait)
+ case StateFinWait1:
+ if s.flagIsSet(header.TCPFlagAck) {
+ // FIN-ACK, transition to TIME-WAIT.
+ r.ep.setEndpointState(StateTimeWait)
+ } else {
+ // Simultaneous close, expecting a final ACK.
+ r.ep.setEndpointState(StateClosing)
+ }
+ case StateFinWait2:
+ r.ep.setEndpointState(StateTimeWait)
+ }
+
+ // Flush out any pending segments, except the very first one if
+ // it happens to be the one we're handling now because the
+ // caller is using it.
+ first := 0
+ if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
+ first = 1
+ }
+
+ for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingRcvdSegments[i].decRef()
+ // Note that slice truncation does not allow garbage collection of
+ // truncated items, thus truncated items must be set to nil to avoid
+ // memory leaks.
+ r.pendingRcvdSegments[i] = nil
+ }
+ r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+
+ return true
+ }
+
+ // Handle ACK (not FIN-ACK, which we handled above) during one of the
+ // shutdown states.
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
+ switch r.ep.EndpointState() {
+ case StateFinWait1:
+ r.ep.setEndpointState(StateFinWait2)
+ // Notify protocol goroutine that we have received an
+ // ACK to our FIN so that it can start the FIN_WAIT2
+ // timer to abort connection if the other side does
+ // not close within 2MSL.
+ r.ep.notifyProtocolGoroutine(notifyClose)
+ case StateClosing:
+ r.ep.setEndpointState(StateTimeWait)
+ case StateLastAck:
+ r.ep.transitionToStateCloseLocked()
+ }
+ }
+
+ return true
+}
+
+// updateRTT updates the receiver RTT measurement based on the sequence number
+// of the received segment.
+func (r *receiver) updateRTT() {
+ // From: https://public.lanl.gov/radiant/pubs/drs/sc2001-poster.pdf
+ //
+ // A system that is only transmitting acknowledgements can still
+ // estimate the round-trip time by observing the time between when a byte
+ // is first acknowledged and the receipt of data that is at least one
+ // window beyond the sequence number that was acknowledged.
+ r.ep.rcvListMu.Lock()
+ if r.ep.rcvAutoParams.rttMeasureTime.IsZero() {
+ // New measurement.
+ r.ep.rcvAutoParams.rttMeasureTime = time.Now()
+ r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
+ r.ep.rcvListMu.Unlock()
+ return
+ }
+ if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) {
+ r.ep.rcvListMu.Unlock()
+ return
+ }
+ rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime)
+ // We only store the minimum observed RTT here as this is only used in
+ // absence of a SRTT available from either timestamps or a sender
+ // measurement of RTT.
+ if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt {
+ r.ep.rcvAutoParams.rtt = rtt
+ }
+ r.ep.rcvAutoParams.rttMeasureTime = time.Now()
+ r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
+ r.ep.rcvListMu.Unlock()
+}
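+
+// Worked example (illustrative): if the first measurement is armed at
+// rcvNxt = 1000 with a 64 KiB window, the target sequence is 1000 + 65536.
+// When data at or beyond that sequence arrives, the elapsed time is one
+// round trip as observed purely from the receive side; only the minimum
+// such sample is retained.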
+
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+ r.ep.rcvListMu.Lock()
+ rcvClosed := r.ep.rcvClosed || r.closed
+ r.ep.rcvListMu.Unlock()
+
+ // If we are in one of the shutdown states then we need to do
+ // additional checks before we try and process the segment.
+ switch state {
+ case StateCloseWait:
+ // If the ACK acks something not yet sent then we send an ACK.
+ if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ r.ep.snd.sendAck()
+ return true, nil
+ }
+ fallthrough
+ case StateClosing, StateLastAck:
+ if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ // Just drop the segment as we have
+ // already received a FIN and this
+ // segment is after the sequence number
+ // for the FIN.
+ return true, nil
+ }
+ fallthrough
+ case StateFinWait1:
+ fallthrough
+ case StateFinWait2:
+		// If we are closed for reads (either due to an
+		// incoming FIN or the user calling shutdown(...,
+		// SHUT_RD)), then any data past rcvNxt should
+		// trigger a RST.
+ endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ return true, tcpip.ErrConnectionAborted
+ }
+ if state == StateFinWait1 {
+ break
+ }
+
+ // If it's a retransmission of an old data segment
+ // or a pure ACK then allow it.
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ s.logicalLen() == 0 {
+ break
+ }
+
+		// In FIN-WAIT2, if the socket is fully closed (not
+		// owned by the application on our end), then the only
+		// acceptable segment is a FIN. Since a FIN can
+		// technically also carry data, we verify that the
+		// segment carrying a FIN ends at exactly r.rcvNxt+1.
+ //
+ // From RFC793 page 25.
+ //
+ // For sequence number purposes, the SYN is
+ // considered to occur before the first actual
+ // data octet of the segment in which it occurs,
+ // while the FIN is considered to occur after
+ // the last actual data octet in a segment in
+ // which it occurs.
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ return true, tcpip.ErrConnectionAborted
+ }
+ }
+
+ // We don't care about receive processing anymore if the receive side
+ // is closed.
+ //
+ // NOTE: We still want to permit a FIN as it's possible only our
+ // end has closed and the peer is yet to send a FIN. Hence we
+ // compare only the payload.
+ segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ return true, nil
+ }
+ return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+ state := r.ep.EndpointState()
+ closed := r.ep.closed
+
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // If the sequence number range is outside the acceptable range, just
+ // send an ACK and stop further processing of the segment.
+ // This is according to RFC 793, page 68.
+ if !r.acceptable(segSeq, segLen) {
+ r.ep.snd.sendAck()
+ return true, nil
+ }
+
+ if state != StateEstablished {
+ drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+ if drop || err != nil {
+ return drop, err
+ }
+ }
+
+ // Store the time of the last ack.
+ r.lastRcvdAckTime = time.Now()
+
+ // Defer segment processing if it can't be consumed now.
+ if !r.consumeSegment(s, segSeq, segLen) {
+ if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+ // We only store the segment if it's within our buffer
+ // size limit.
+ if r.pendingBufUsed < r.pendingBufSize {
+ r.pendingBufUsed += s.logicalLen()
+ s.incRef()
+ heap.Push(&r.pendingRcvdSegments, s)
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+ }
+
+ // Immediately send an ack so that the peer knows it may
+ // have to retransmit.
+ r.ep.snd.sendAck()
+ }
+ return false, nil
+ }
+
+ // Since we consumed a segment update the receiver's RTT estimate
+ // if required.
+ if segLen > 0 {
+ r.updateRTT()
+ }
+
+ // By consuming the current segment, we may have filled a gap in the
+ // sequence number domain that allows pending segments to be consumed
+ // now. So try to do it.
+ for !r.closed && r.pendingRcvdSegments.Len() > 0 {
+ s := r.pendingRcvdSegments[0]
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // Skip segment altogether if it has already been acknowledged.
+ if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+ !r.consumeSegment(s, segSeq, segLen) {
+ break
+ }
+
+ heap.Pop(&r.pendingRcvdSegments)
+ r.pendingBufUsed -= s.logicalLen()
+ s.decRef()
+ }
+ return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+ segSeq := s.sequenceNumber
+ segLen := seqnum.Size(s.data.Size())
+
+ // Just silently drop any RST packets in TIME_WAIT. We do not support
+	// TIME_WAIT assassination; as a result we conform to fix 1 as described
+ // in https://tools.ietf.org/html/rfc1337#section-3.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return false, false
+ }
+
+ // If it's a SYN and the sequence number is higher than any seen before
+	// for this connection, then try to redirect it to a listening endpoint
+	// if one is available.
+ //
+ // RFC 1122:
+ // "When a connection is [...] on TIME-WAIT state [...]
+ // [a TCP] MAY accept a new SYN from the remote TCP to
+ // reopen the connection directly, if it:
+
+ // (1) assigns its initial sequence number for the new
+ // connection to be larger than the largest sequence
+ // number it used on the previous connection incarnation,
+ // and
+
+ // (2) returns to TIME-WAIT state if the SYN turns out
+ // to be an old duplicate".
+	if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+		return false, true
+ }
+
+ // Drop the segment if it does not contain an ACK.
+ if !s.flagIsSet(header.TCPFlagAck) {
+ return false, false
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ }
+
+ if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ // If it's a FIN-ACK then resetTimeWait and send an ACK, as it
+ // indicates our final ACK could have been lost.
+ r.ep.snd.sendAck()
+ return true, false
+ }
+
+ // If the sequence number range is outside the acceptable range or
+ // carries data then just send an ACK. This is according to RFC 793,
+ // page 37.
+ //
+ // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+ if segSeq != r.rcvNxt || segLen != 0 {
+ r.ep.snd.sendAck()
+ }
+ return false, false
+}
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
new file mode 100644
index 000000000..2bf21a2e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveLastRcvdAckTime is invoked by stateify.
+func (r *receiver) saveLastRcvdAckTime() unixTime {
+ return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
+}
+
+// loadLastRcvdAckTime is invoked by stateify.
+func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
+ r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
new file mode 100644
index 000000000..8a026ec46
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rcv_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+func TestAcceptable(t *testing.T) {
+ for _, tt := range []struct {
+ segSeq seqnum.Value
+ segLen seqnum.Size
+ rcvNxt, rcvAcc seqnum.Value
+ want bool
+ }{
+ // The segment is smaller than the window.
+ {105, 2, 100, 104, false},
+ {105, 2, 101, 105, true},
+ {105, 2, 102, 106, true},
+ {105, 2, 103, 107, true},
+ {105, 2, 104, 108, true},
+ {105, 2, 105, 109, true},
+ {105, 2, 106, 110, true},
+ {105, 2, 107, 111, false},
+
+ // The segment is larger than the window.
+ {105, 4, 103, 105, true},
+ {105, 4, 104, 106, true},
+ {105, 4, 105, 107, true},
+ {105, 4, 106, 108, true},
+ {105, 4, 107, 109, true},
+ {105, 4, 108, 110, true},
+ {105, 4, 109, 111, false},
+ {105, 4, 110, 112, false},
+
+ // The segment has no width.
+ {105, 0, 100, 102, false},
+ {105, 0, 101, 103, false},
+ {105, 0, 102, 104, false},
+ {105, 0, 103, 105, true},
+ {105, 0, 104, 106, true},
+ {105, 0, 105, 107, true},
+ {105, 0, 106, 108, false},
+ {105, 0, 107, 109, false},
+
+ // The receive window has no width.
+ {105, 2, 103, 103, false},
+ {105, 2, 104, 104, false},
+ {105, 2, 105, 105, false},
+ {105, 2, 106, 106, false},
+ {105, 2, 107, 107, false},
+ {105, 2, 108, 108, false},
+ {105, 2, 109, 109, false},
+ } {
+ if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
+ t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
new file mode 100644
index 000000000..f83ebc717
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoState stores the variables related to TCP New Reno congestion
+// control algorithm.
+//
+// +stateify savable
+type renoState struct {
+ s *sender
+}
+
+// newRenoCC initializes the state for the NewReno congestion control algorithm.
+func newRenoCC(s *sender) *renoState {
+ return &renoState{s: s}
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If adjusting the congestion window crosses the
+// slow-start threshold, it returns the number of packets that must be
+// consumed in congestion avoidance mode.
+func (r *renoState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := r.s.sndCwnd + packetsAcked
+ if newcwnd >= r.s.sndSsthresh {
+ newcwnd = r.s.sndSsthresh
+ r.s.sndCAAckCount = 0
+ }
+
+ packetsAcked -= newcwnd - r.s.sndCwnd
+ r.s.sndCwnd = newcwnd
+ return packetsAcked
+}
+
+// updateCongestionAvoidance will update congestion window in congestion
+// avoidance mode as described in RFC5681 section 3.1
+func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
+ // Consume the packets in congestion avoidance mode.
+ r.s.sndCAAckCount += packetsAcked
+ if r.s.sndCAAckCount >= r.s.sndCwnd {
+ r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
+ r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ }
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
+// page 6, eq. 4. It is called when we detect congestion in the network.
+func (r *renoState) reduceSlowStartThreshold() {
+ r.s.sndSsthresh = r.s.outstanding / 2
+ if r.s.sndSsthresh < 2 {
+ r.s.sndSsthresh = 2
+ }
+}
+
+// Update implements congestionControl.Update. It updates the congestion
+// state based on the number of packets that were acknowledged.
+func (r *renoState) Update(packetsAcked int) {
+ if r.s.sndCwnd < r.s.sndSsthresh {
+ packetsAcked = r.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ }
+ r.updateCongestionAvoidance(packetsAcked)
+}
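+
+// Worked example (illustrative): with sndCwnd = 8 and sndSsthresh = 10, an
+// ACK covering 4 packets grows the window in slow start from 8 to the
+// threshold of 10 (consuming 2 of the 4 packets); the remaining 2 feed
+// congestion avoidance, where the window grows by one only after a full
+// window's worth of packets (sndCAAckCount >= sndCwnd) has been acked.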
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (r *renoState) HandleNDupAcks() {
+ // A retransmit was triggered due to nDupAckThreshold
+ // being hit. Reduce our slow start threshold.
+ r.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (r *renoState) HandleRTOExpired() {
+ // We lost a packet, so reduce ssthresh.
+ r.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ r.s.sndCwnd = 1
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (r *renoState) PostRecovery() {
+ // noop.
+}
diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go
new file mode 100644
index 000000000..7be86d68e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // MaxSACKBlocks is the maximum number of SACK blocks stored
+ // at receiver side.
+ MaxSACKBlocks = 6
+)
+
+// UpdateSACKBlocks updates the list of SACK blocks to include the segment
+// specified by segStart->segEnd. If the segment happens to be an out-of-order
+// delivery, then the first block in sack.Blocks always includes the
+// segment identified by segStart->segEnd.
+func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
+ newSB := header.SACKBlock{Start: segStart, End: segEnd}
+
+ // Ignore any invalid SACK blocks or blocks that are before rcvNxt as
+ // those bytes have already been acked.
+ if newSB.End.LessThanEq(newSB.Start) || newSB.End.LessThan(rcvNxt) {
+ return
+ }
+
+ if sack.NumBlocks == 0 {
+ sack.Blocks[0] = newSB
+ sack.NumBlocks = 1
+ return
+ }
+	n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ start, end := sack.Blocks[i].Start, sack.Blocks[i].End
+ if end.LessThanEq(rcvNxt) {
+ // Discard any sack blocks that are before rcvNxt as
+ // those have already been acked.
+ continue
+ }
+ if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
+ // Merge this SACK block into newSB and discard this SACK
+ // block.
+ if start.LessThan(newSB.Start) {
+ newSB.Start = start
+ }
+ if newSB.End.LessThan(end) {
+ newSB.End = end
+ }
+ } else {
+ // Save this block.
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ }
+ if rcvNxt.LessThan(newSB.Start) {
+ // If this was an out of order segment then make sure that the
+ // first SACK block is the one that includes the segment.
+ //
+ // See the first bullet point in
+ // https://tools.ietf.org/html/rfc2018#section-4
+ if n == MaxSACKBlocks {
+ // If the number of SACK blocks is equal to
+ // MaxSACKBlocks then discard the last SACK block.
+ n--
+ }
+ for i := n - 1; i >= 0; i-- {
+ sack.Blocks[i+1] = sack.Blocks[i]
+ }
+ sack.Blocks[0] = newSB
+ n++
+ }
+ sack.NumBlocks = n
+}
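+
+// Worked example (illustrative): with rcvNxt = 100 and an existing block
+// [108, 112), receiving the out-of-order segment [105, 110) merges the two
+// into a single block [105, 112), which is placed first in the list because
+// it includes the segment that arrived most recently.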
+
+// TrimSACKBlockList updates the sack block list by removing/modifying any block
+// where start is < rcvNxt.
+func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
+ n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ if sack.Blocks[i].End.LessThanEq(rcvNxt) {
+ continue
+ }
+ if sack.Blocks[i].Start.LessThan(rcvNxt) {
+ // Shrink this SACK block.
+ sack.Blocks[i].Start = rcvNxt
+ }
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ sack.NumBlocks = n
+}
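+
+// trimSACKBlockListExample is an illustrative sketch (assumed sequence
+// numbers, not used by the stack) of how advancing rcvNxt first shrinks and
+// then drops a block.
+func trimSACKBlockListExample() {
+ var info SACKInfo
+ UpdateSACKBlocks(&info, 150, 200, 100)
+ // rcvNxt advances to 175: the block shrinks to [175, 200) since the
+ // first 25 bytes are now cumulatively acked.
+ TrimSACKBlockList(&info, 175)
+ // Advancing to the end of the block removes it entirely.
+ TrimSACKBlockList(&info, 200)
+ // info.NumBlocks == 0.
+}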
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
new file mode 100644
index 000000000..7ef2df377
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/google/btree"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // maxSACKBlocks is the maximum number of distinct SACKBlocks the
+ // scoreboard will track. Once there are 100 distinct blocks, new
+ // insertions will fail.
+ maxSACKBlocks = 100
+
+ // defaultBtreeDegree is set to 2 as btree.New(2) results in a 2-3-4
+ // tree.
+ defaultBtreeDegree = 2
+)
+
+// SACKScoreboard stores a set of disjoint SACK ranges.
+//
+// +stateify savable
+type SACKScoreboard struct {
+ // smss is defined in RFC 5681 as follows:
+ //
+ // The SMSS is the size of the largest segment that the sender can
+ // transmit. This value can be based on the maximum transmission unit
+ // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm,
+ // RMSS (see next item), or other factors. The size does not include
+ // the TCP/IP headers and options.
+ smss uint16
+ maxSACKED seqnum.Value
+ sacked seqnum.Size `state:"nosave"`
+ ranges *btree.BTree `state:"nosave"`
+}
+
+// NewSACKScoreboard returns a new SACK Scoreboard.
+func NewSACKScoreboard(smss uint16, iss seqnum.Value) *SACKScoreboard {
+ return &SACKScoreboard{
+ smss: smss,
+ ranges: btree.New(defaultBtreeDegree),
+ maxSACKED: iss,
+ }
+}
+
+// Reset erases all known range information from the SACK scoreboard.
+func (s *SACKScoreboard) Reset() {
+ s.ranges = btree.New(defaultBtreeDegree)
+ s.sacked = 0
+}
+
+// Insert inserts/merges the provided SACKBlock into the scoreboard.
+func (s *SACKScoreboard) Insert(r header.SACKBlock) {
+ if s.ranges.Len() >= maxSACKBlocks {
+ return
+ }
+
+ // Check if we can merge the new range with a range before or after it.
+ var toDelete []btree.Item
+ if s.maxSACKED.LessThan(r.End - 1) {
+ s.maxSACKED = r.End - 1
+ }
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // There is a hole between these two SACK blocks, so we can't
+ // merge anymore.
+ if r.End.LessThan(sacked.Start) {
+ return false
+ }
+ // There is some overlap at this point, merge the blocks and
+ // delete the other one.
+ //
+ // ----sS--------sE
+ // r.S---------------rE
+ // -------sE
+ if sacked.End.LessThan(r.End) {
+ // sacked is contained in the newly inserted range.
+ // Delete this block.
+ toDelete = append(toDelete, i)
+ return true
+ }
+ // sacked covers a range past end of the newly inserted
+ // block.
+ r.End = sacked.End
+ toDelete = append(toDelete, i)
+ return true
+ })
+
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // sA------sE
+ // rA----rE
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ // The previous range extends into the current block. Merge it
+ // into the newly inserted range and delete the other one.
+ //
+ // sA---------------sE
+ // ------rA---------------rE
+ r.Start = sacked.Start
+ // Extend r to cover sacked if sacked extends past r.
+ if r.End.LessThan(sacked.End) {
+ r.End = sacked.End
+ }
+ toDelete = append(toDelete, i)
+ return true
+ })
+ for _, i := range toDelete {
+ if s.ranges.Delete(i) != nil {
+ sb := i.(header.SACKBlock)
+ s.sacked -= sb.Start.Size(sb.End)
+ }
+ }
+
+ replaced := s.ranges.ReplaceOrInsert(r)
+ if replaced == nil {
+ s.sacked += r.Start.Size(r.End)
+ }
+}
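+
+// scoreboardInsertExample is an illustrative sketch (assumed sequence
+// numbers, not used by the stack) showing how overlapping blocks are
+// coalesced into a single disjoint range.
+func scoreboardInsertExample() {
+ s := NewSACKScoreboard(1500 /* smss */, 0 /* iss */)
+ s.Insert(header.SACKBlock{10, 20})
+ s.Insert(header.SACKBlock{15, 30})
+ // The blocks overlap, so the scoreboard now holds the single range
+ // [10, 30) and Sacked() reports 20 bytes.
+ _ = s.Sacked()
+}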
+
+// IsSACKED returns true if the given range of sequence numbers denoted by r
+// is already covered by SACK information in the scoreboard.
+func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+
+ found := false
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ if sacked.Contains(r) {
+ found = true
+ return false
+ }
+ return true
+ })
+ return found
+}
+
+// String returns a human-readable dump of the scoreboard state.
+func (s *SACKScoreboard) String() string {
+ var str strings.Builder
+ str.WriteString("SACKScoreboard: {")
+ s.ranges.Ascend(func(i btree.Item) bool {
+ str.WriteString(fmt.Sprintf("%v,", i))
+ return true
+ })
+ str.WriteString("}\n")
+ return str.String()
+}
+
+// Delete removes all SACK information prior to seq.
+func (s *SACKScoreboard) Delete(seq seqnum.Value) {
+ if s.Empty() {
+ return
+ }
+ toDelete := []btree.Item{}
+ toInsert := []btree.Item{}
+ r := header.SACKBlock{seq, seq.Add(1)}
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sb := i.(header.SACKBlock)
+ toDelete = append(toDelete, i)
+ if sb.End.LessThanEq(seq) {
+ s.sacked -= sb.Start.Size(sb.End)
+ } else {
+ newSB := header.SACKBlock{seq, sb.End}
+ toInsert = append(toInsert, newSB)
+ s.sacked -= sb.Start.Size(seq)
+ }
+ return true
+ })
+ for _, sb := range toDelete {
+ s.ranges.Delete(sb)
+ }
+ for _, sb := range toInsert {
+ s.ranges.ReplaceOrInsert(sb)
+ }
+}
+
+// Copy provides a copy of the SACK scoreboard.
+func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum.Value) {
+ s.ranges.Ascend(func(i btree.Item) bool {
+ sackBlocks = append(sackBlocks, i.(header.SACKBlock))
+ return true
+ })
+ return sackBlocks, s.maxSACKED
+}
+
+// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675
+// section 4 but operates on a range of sequence numbers and returns true if
+// there are at least nDupAckThreshold SACK blocks greater than the range being
+// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED
+// with sequence numbers greater than the block being checked.
+func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+ nDupSACK := 0
+ nDupSACKBytes := seqnum.Size(0)
+ isLost := false
+
+ // We need to check if the immediately lower (if any) SACKed
+ // range contains or partially overlaps with r.
+ searchMore := true
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ searchMore = false
+ return false
+ }
+ if sacked.End.LessThanEq(r.Start) {
+ // all sequence numbers covered by sacked are below
+ // r so we continue searching.
+ return false
+ }
+ // There is a partial overlap. In this case r.Start is
+ // between sacked.Start & sacked.End and r.End extends beyond
+ // sacked.End.
+ // Move r.Start to sacked.End and continue searching blocks
+ // above r.Start.
+ r.Start = sacked.End
+ return false
+ })
+
+ if !searchMore {
+ return isLost
+ }
+
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ return false
+ }
+ nDupSACKBytes += sacked.Start.Size(sacked.End)
+ nDupSACK++
+ if nDupSACK >= nDupAckThreshold || nDupSACKBytes >= seqnum.Size((nDupAckThreshold-1)*s.smss) {
+ isLost = true
+ return false
+ }
+ return true
+ })
+ return isLost
+}
+
+// IsLost implements the IsLost(SeqNum) operation defined in RFC 3517 section
+// 4.
+//
+// This routine returns whether the given sequence number is considered to be
+// lost. The routine returns true when either nDupAckThreshold discontiguous
+// SACKed sequences have arrived above 'SeqNum' or ((nDupAckThreshold-1) *
+// SMSS) bytes with sequence numbers greater than 'SeqNum' have been SACKed.
+// Otherwise, the routine returns false.
+func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool {
+ return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)})
+}
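+
+// isLostExample is an illustrative sketch (assumed values, not used by the
+// stack): with smss = 10 and nDupAckThreshold = 3, a hole is deemed lost
+// once at least 3 SACKed blocks, or at least (3-1)*10 = 20 SACKed bytes,
+// lie above it.
+func isLostExample() {
+ s := NewSACKScoreboard(10 /* smss */, 0 /* iss */)
+ s.Insert(header.SACKBlock{11, 21})
+ s.Insert(header.SACKBlock{31, 41})
+ s.Insert(header.SACKBlock{51, 61})
+ // Sequence number 1 has 30 SACKed bytes above it, so it is lost.
+ _ = s.IsLost(1)
+}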
+
+// Empty returns true if the SACK scoreboard has no entries, false otherwise.
+func (s *SACKScoreboard) Empty() bool {
+ return s.ranges.Len() == 0
+}
+
+// Sacked returns the current number of bytes held in the SACK scoreboard.
+func (s *SACKScoreboard) Sacked() seqnum.Size {
+ return s.sacked
+}
+
+// MaxSACKED returns the highest sequence number ever inserted in the SACK
+// scoreboard.
+func (s *SACKScoreboard) MaxSACKED() seqnum.Value {
+ return s.maxSACKED
+}
+
+// SMSS returns the sender's MSS as held by the SACK scoreboard.
+func (s *SACKScoreboard) SMSS() uint16 {
+ return s.smss
+}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
new file mode 100644
index 000000000..b4e5ba0df
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
@@ -0,0 +1,249 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+const smss = 1500
+
+func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard {
+ s := tcp.NewSACKScoreboard(smss, iss)
+ for _, blk := range blocks {
+ s.Insert(blk)
+ }
+ return s
+}
+
+func TestSACKScoreboardIsSACKED(t *testing.T) {
+ type blockTest struct {
+ block header.SACKBlock
+ sacked bool
+ }
+ testCases := []struct {
+ comment string
+ scoreboardBlocks []header.SACKBlock
+ blockTests []blockTest
+ iss seqnum.Value
+ }{
+ {
+ "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks",
+ []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}},
+ []blockTest{
+ {header.SACKBlock{15, 21}, true},
+ {header.SACKBlock{200, 201}, false},
+ {header.SACKBlock{50, 51}, false},
+ {header.SACKBlock{53, 120}, true},
+ },
+ 0,
+ },
+ {
+ "Test disjoint SACKBlocks",
+ []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}},
+ []blockTest{
+ {header.SACKBlock{2288624809, 2288810057}, true},
+ {header.SACKBlock{2288811477, 2288838565}, true},
+ {header.SACKBlock{2288810057, 2288811477}, false},
+ },
+ 2288624809,
+ },
+ {
+ "Test sequence number wrap around",
+ []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}},
+ []blockTest{
+ {header.SACKBlock{4294254144, 4294254145}, true},
+ {header.SACKBlock{4294254143, 4294254144}, false},
+ {header.SACKBlock{4294254144, 1}, true},
+ {header.SACKBlock{225652, 5350509}, false},
+ {header.SACKBlock{5340409, 5350509}, true},
+ {header.SACKBlock{5350509, 5350609}, false},
+ },
+ 4294254144,
+ },
+ {
+ "Test disjoint SACKBlocks out of order",
+ []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}},
+ []blockTest{
+ {header.SACKBlock{827426028, 827428867}, true},
+ {header.SACKBlock{827450168, 827450275}, false},
+ },
+ 827426000,
+ },
+ }
+ for _, tc := range testCases {
+ sb := initScoreboard(tc.scoreboardBlocks, tc.iss)
+ for _, blkTest := range tc.blockTests {
+ if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want {
+ t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want)
+ }
+ }
+ }
+}
+
+func TestSACKScoreboardIsRangeLost(t *testing.T) {
+ s := tcp.NewSACKScoreboard(10, 0)
+ s.Insert(header.SACKBlock{1, 25})
+ s.Insert(header.SACKBlock{25, 50})
+ s.Insert(header.SACKBlock{51, 100})
+ s.Insert(header.SACKBlock{111, 120})
+ s.Insert(header.SACKBlock{101, 110})
+ s.Insert(header.SACKBlock{121, 141})
+ s.Insert(header.SACKBlock{145, 146})
+ s.Insert(header.SACKBlock{147, 148})
+ s.Insert(header.SACKBlock{149, 150})
+ s.Insert(header.SACKBlock{153, 154})
+ s.Insert(header.SACKBlock{155, 156})
+ testCases := []struct {
+ block header.SACKBlock
+ lost bool
+ }{
+ // Block not covered by SACK block and has more than
+ // nDupAckThreshold discontiguous SACK blocks after it as well
+ // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
+ // SACKED above the sequence number covered by this block.
+ {block: header.SACKBlock{0, 1}, lost: true},
+
+ // These blocks have all been SACKed and should not be
+ // considered lost.
+ {block: header.SACKBlock{1, 2}, lost: false},
+ {block: header.SACKBlock{25, 26}, lost: false},
+ {block: header.SACKBlock{1, 45}, lost: false},
+
+ // Same as the first case above.
+ {block: header.SACKBlock{50, 51}, lost: true},
+
+ // This block has been SACKed and should not be considered lost.
+ {block: header.SACKBlock{119, 120}, lost: false},
+
+ // This one should return true because there are >
+ // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
+ // sacked above this sequence number.
+ {block: header.SACKBlock{120, 121}, lost: true},
+
+ // This block has been SACKed and should not be considered lost.
+ {block: header.SACKBlock{125, 126}, lost: false},
+
+ // This block has not been SACKed and there are nDupAckThreshold
+ // number of SACKed blocks after it.
+ {block: header.SACKBlock{141, 145}, lost: true},
+
+ // This block has not been SACKed and there are less than
+ // nDupAckThreshold SACKed sequences after it.
+ {block: header.SACKBlock{151, 152}, lost: false},
+ }
+ for _, tc := range testCases {
+ if want, got := tc.lost, s.IsRangeLost(tc.block); got != want {
+ t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want)
+ }
+ }
+}
+
+func TestSACKScoreboardIsLost(t *testing.T) {
+ s := tcp.NewSACKScoreboard(10, 0)
+ s.Insert(header.SACKBlock{1, 25})
+ s.Insert(header.SACKBlock{25, 50})
+ s.Insert(header.SACKBlock{51, 100})
+ s.Insert(header.SACKBlock{111, 120})
+ s.Insert(header.SACKBlock{101, 110})
+ s.Insert(header.SACKBlock{121, 141})
+ s.Insert(header.SACKBlock{121, 141})
+ s.Insert(header.SACKBlock{145, 146})
+ s.Insert(header.SACKBlock{147, 148})
+ s.Insert(header.SACKBlock{149, 150})
+ s.Insert(header.SACKBlock{153, 154})
+ s.Insert(header.SACKBlock{155, 156})
+ testCases := []struct {
+ seq seqnum.Value
+ lost bool
+ }{
+ // Sequence number not covered by SACK block and has more than
+ // nDupAckThreshold discontiguous SACK blocks after it as well
+ // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
+ // SACKED above the sequence number.
+ {seq: 0, lost: true},
+
+ // These sequence numbers have all been SACKed and should not be
+ // considered lost.
+ {seq: 1, lost: false},
+ {seq: 25, lost: false},
+ {seq: 45, lost: false},
+
+ // Same as first case above.
+ {seq: 50, lost: true},
+
+ // This block has been SACKed and should not be considered lost.
+ {seq: 119, lost: false},
+
+ // This one should return true because there are >
+ // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
+ // sacked above this sequence number.
+ {seq: 120, lost: true},
+
+ // This sequence number has been SACKed and should not be
+ // considered lost.
+ {seq: 125, lost: false},
+
+ // This sequence number has not been SACKed and there are
+ // nDupAckThreshold number of SACKed blocks after it.
+ {seq: 141, lost: true},
+
+ // This sequence number has not been SACKed and there are less
+ // than nDupAckThreshold SACKed sequences after it.
+ {seq: 151, lost: false},
+ }
+ for _, tc := range testCases {
+ if want, got := tc.lost, s.IsLost(tc.seq); got != want {
+ t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want)
+ }
+ }
+}
+
+func TestSACKScoreboardDelete(t *testing.T) {
+ blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}
+ s := initScoreboard(blocks, 4294254143)
+ s.Delete(5340408)
+ if s.Empty() {
+ t.Fatalf("s.Empty() = true, want false")
+ }
+ if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want {
+ t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
+ }
+ s.Delete(5340410)
+ if s.Empty() {
+ t.Fatal("s.Empty() = true, want false")
+ }
+ newSB := header.SACKBlock{5340410, 5350509}
+ if !s.IsSACKED(newSB) {
+ t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s)
+ }
+ s.Delete(5350509)
+ lastOctet := header.SACKBlock{5350508, 5350509}
+ if s.IsSACKED(lastOctet) {
+ t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet)
+ }
+
+ s.Delete(5350510)
+ if !s.Empty() {
+ t.Fatal("s.Empty() = false, want true")
+ }
+ if got, want := s.Sacked(), seqnum.Size(0); got != want {
+ t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
new file mode 100644
index 000000000..0280892a8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -0,0 +1,194 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// segment represents a TCP segment. It holds the payload and parsed TCP segment
+// information, and can be added to intrusive lists.
+// segment is mostly immutable; the only field allowed to change is viewToDeliver.
+//
+// +stateify savable
+type segment struct {
+ segmentEntry
+ refCnt int32
+ id stack.TransportEndpointID `state:"manual"`
+ route stack.Route `state:"manual"`
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ hdr header.TCP
+ // views is used as the backing store for data's views when they
+ // fit, so that small segments avoid a separate slice allocation.
+ views [8]buffer.View `state:"nosave"`
+ // viewToDeliver keeps track of the next View that should be
+ // delivered by the Read endpoint.
+ viewToDeliver int
+ sequenceNumber seqnum.Value
+ ackNumber seqnum.Value
+ flags uint8
+ window seqnum.Size
+ // csum is only populated for received segments.
+ csum uint16
+ // csumValid is true if the csum in the received segment is valid.
+ csumValid bool
+
+ // parsedOptions stores the parsed values from the options in the segment.
+ parsedOptions header.TCPOptions
+ options []byte `state:".([]byte)"`
+ hasNewSACKInfo bool
+ rcvdTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment.
+ xmitTime time.Time `state:".(unixTime)"`
+ xmitCount uint32
+}
+
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.data = pkt.Data.Clone(s.views[:])
+ s.hdr = header.TCP(pkt.TransportHeader)
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.rcvdTime = time.Now()
+ if len(v) != 0 {
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ }
+ return s
+}
+
+func (s *segment) clone() *segment {
+ t := &segment{
+ refCnt: 1,
+ id: s.id,
+ sequenceNumber: s.sequenceNumber,
+ ackNumber: s.ackNumber,
+ flags: s.flags,
+ window: s.window,
+ route: s.route.Clone(),
+ viewToDeliver: s.viewToDeliver,
+ rcvdTime: s.rcvdTime,
+ xmitTime: s.xmitTime,
+ xmitCount: s.xmitCount,
+ }
+ t.data = s.data.Clone(t.views[:])
+ return t
+}
+
+// flagIsSet checks if at least one flag in flags is set in s.flags.
+func (s *segment) flagIsSet(flags uint8) bool {
+ return s.flags&flags != 0
+}
+
+// flagsAreSet checks if all flags in flags are set in s.flags.
+func (s *segment) flagsAreSet(flags uint8) bool {
+ return s.flags&flags == flags
+}
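+
+// For example (a sketch with assumed flags): for a bare SYN segment,
+// flagIsSet(header.TCPFlagSyn|header.TCPFlagAck) is true, while
+// flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) is true only for a
+// SYN-ACK.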
+
+func (s *segment) decRef() {
+ if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ s.route.Release()
+ }
+}
+
+func (s *segment) incRef() {
+ atomic.AddInt32(&s.refCnt, 1)
+}
+
+// logicalLen is the segment length in the sequence number space. It's defined
+// as the data length plus one for each of the SYN and FIN bits set.
+func (s *segment) logicalLen() seqnum.Size {
+ l := seqnum.Size(s.data.Size())
+ if s.flagIsSet(header.TCPFlagSyn) {
+ l++
+ }
+ if s.flagIsSet(header.TCPFlagFin) {
+ l++
+ }
+ return l
+}
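+
+// For example (assumed flags and sizes): a SYN carrying no data has
+// logicalLen 1, a pure ACK has logicalLen 0, and a FIN carrying 10 bytes of
+// data has logicalLen 11, since SYN and FIN each consume one sequence
+// number.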
+
+// parse populates the sequence & ack numbers, flags, and window fields of the
+// segment from the TCP header stored in the data. It then updates the view to
+// skip the header.
+//
+// It returns a boolean indicating whether the parsing was successful.
+//
+// If checksum verification is not offloaded then parse also verifies the
+// TCP checksum and stores the checksum and result of checksum verification in
+// the csum and csumValid fields of the segment.
+func (s *segment) parse() bool {
+ // h is the header followed by the payload. We check that the offset to
+ // the data respects the following constraints:
+ // 1. That it's at least the minimum header size; if we don't do this
+ // then part of the header would be delivered to the user.
+ // 2. That the header fits within the buffer; if we don't do this, we
+ // would panic when we tried to access data beyond the buffer.
+ //
+ // N.B. The segment has already been validated as having at least the
+ // minimum TCP size before reaching here, so it's safe to read the
+ // fields.
+ offset := int(s.hdr.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(s.hdr) {
+ return false
+ }
+
+ s.options = []byte(s.hdr[header.TCPMinimumSize:])
+ s.parsedOptions = header.ParseTCPOptions(s.options)
+
+ // Query the link capabilities to decide if checksum validation is
+ // required.
+ verifyChecksum := true
+ if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ s.csumValid = true
+ verifyChecksum = false
+ }
+ if verifyChecksum {
+ s.csum = s.hdr.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum = s.hdr.CalculateChecksum(xsum)
+ xsum = header.ChecksumVV(s.data, xsum)
+ s.csumValid = xsum == 0xffff
+ }
+
+ s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(s.hdr.AckNumber())
+ s.flags = s.hdr.Flags()
+ s.window = seqnum.Size(s.hdr.WindowSize())
+ return true
+}
+
+// sackBlock returns a header.SACKBlock that represents this segment.
+func (s *segment) sackBlock() header.SACKBlock {
+ return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
new file mode 100644
index 000000000..8d3ddce4b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -0,0 +1,51 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import "container/heap"
+
+type segmentHeap []*segment
+
+var _ heap.Interface = (*segmentHeap)(nil)
+
+// Len returns the length of h.
+func (h *segmentHeap) Len() int {
+ return len(*h)
+}
+
+// Less determines whether the i-th element of h is less than the j-th element.
+func (h *segmentHeap) Less(i, j int) bool {
+ return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber)
+}
+
+// Swap swaps the i-th and j-th elements of h.
+func (h *segmentHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+// Push adds x as the last element of h.
+func (h *segmentHeap) Push(x interface{}) {
+ *h = append(*h, x.(*segment))
+}
+
+// Pop removes the last element of h and returns it.
+func (h *segmentHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ old[n-1] = nil
+ *h = old[:n-1]
+ return x
+}
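+
+// A segmentHeap must be driven through the container/heap package, which
+// maintains the heap invariant. A minimal usage sketch (h and seg are
+// assumed to exist):
+//
+// heap.Init(&h)
+// heap.Push(&h, seg)
+// first := heap.Pop(&h).(*segment) // segment with the lowest sequenceNumber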
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
new file mode 100644
index 000000000..48a257137
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// segmentQueue is a bounded, thread-safe queue of TCP segments.
+//
+// +stateify savable
+type segmentQueue struct {
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ limit int
+ used int
+}
+
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *segmentQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
+// empty determines if the queue is empty.
+func (q *segmentQueue) empty() bool {
+ q.mu.Lock()
+ r := q.emptyLocked()
+ q.mu.Unlock()
+
+ return r
+}
+
+// setLimit updates the limit. No segments are dropped immediately, even if
+// the new limit leaves the queue over its capacity.
+func (q *segmentQueue) setLimit(limit int) {
+ q.mu.Lock()
+ q.limit = limit
+ q.mu.Unlock()
+}
+
+// enqueue adds the given segment to the queue.
+//
+// It returns true when the segment is successfully added to the queue, in
+// which case ownership of the reference is transferred to the queue. It
+// returns false if the queue is full, in which case ownership is retained
+// by the caller.
+func (q *segmentQueue) enqueue(s *segment) bool {
+ q.mu.Lock()
+ r := q.used < q.limit
+ if r {
+ q.list.PushBack(s)
+ q.used++
+ }
+ q.mu.Unlock()
+
+ return r
+}
+
+// dequeue removes and returns the next segment from queue, if one exists.
+// Ownership is transferred to the caller, who is responsible for decrementing
+// the ref count when done.
+func (q *segmentQueue) dequeue() *segment {
+ q.mu.Lock()
+ s := q.list.Front()
+ if s != nil {
+ q.list.Remove(s)
+ q.used--
+ }
+ q.mu.Unlock()
+
+ return s
+}
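+
+// exampleDrain is an illustrative sketch (not used by the stack) of the
+// ownership contract: dequeued segments belong to the caller, who must
+// release the queue's reference via decRef.
+func (q *segmentQueue) exampleDrain() {
+ for s := q.dequeue(); s != nil; s = q.dequeue() {
+ s.decRef()
+ }
+}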
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
new file mode 100644
index 000000000..7dc2741a6
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// saveData is invoked by stateify.
+func (s *segment) saveData() buffer.VectorisedView {
+ // We cannot save s.data directly as s.data.views may alias to s.views,
+ // which is not allowed by the state framework (in-struct pointer).
+ v := make([]buffer.View, len(s.data.Views()))
+ // For views already delivered, we cannot save them directly as they may
+ // have already been sliced and saved elsewhere (e.g., readViews).
+ for i := 0; i < s.viewToDeliver; i++ {
+ v[i] = append([]byte(nil), s.data.Views()[i]...)
+ }
+ for i := s.viewToDeliver; i < len(v); i++ {
+ v[i] = s.data.Views()[i]
+ }
+ return buffer.NewVectorisedView(s.data.Size(), v)
+}
+
+// loadData is invoked by stateify.
+func (s *segment) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing s.views for data.views.
+ s.data = data
+}
+
+// saveOptions is invoked by stateify.
+func (s *segment) saveOptions() []byte {
+ // We cannot save s.options directly as it may point to s.data's trimmed
+ // tail, which is not allowed by the state framework (in-struct pointer).
+ b := make([]byte, 0, cap(s.options))
+ return append(b, s.options...)
+}
+
+// loadOptions is invoked by stateify.
+func (s *segment) loadOptions(options []byte) {
+ // NOTE: We cannot point s.options back into s.data's trimmed tail. But
+ // it is OK as they do not need to be aliased. Plus, options is already
+ // allocated so there is no cost here.
+ s.options = options
+}
+
+// saveRcvdTime is invoked by stateify.
+func (s *segment) saveRcvdTime() unixTime {
+ return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadRcvdTime is invoked by stateify.
+func (s *segment) loadRcvdTime(unix unixTime) {
+ s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (s *segment) saveXmitTime() unixTime {
+ return unixTime{s.xmitTime.Unix(), s.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (s *segment) loadXmitTime(unix unixTime) {
+ s.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
new file mode 100644
index 000000000..5862c32f2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -0,0 +1,1487 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "math"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // MinRTO is the minimum allowed value for the retransmit timeout.
+ MinRTO = 200 * time.Millisecond
+
+ // MaxRTO is the maximum allowed value for the retransmit timeout.
+ MaxRTO = 120 * time.Second
+
+ // InitialCwnd is the initial congestion window.
+ InitialCwnd = 10
+
+ // nDupAckThreshold is the number of duplicate ACKs required
+ // before fast-retransmit is entered.
+ nDupAckThreshold = 3
+
+ // MaxRetries is the maximum number of probe retries the sender does
+ // before timing out the connection.
+ // Linux default TCP_RETR2, net.ipv4.tcp_retries2.
+ MaxRetries = 15
+)
+
+// ccState indicates the current congestion control state for this sender.
+type ccState int
+
+const (
+ // Open indicates that the sender is receiving acks in order and
+ // no loss or dupACKs etc. have been detected.
+ Open ccState = iota
+ // RTORecovery indicates that an RTO has occurred and the sender
+ // has entered an RTO based recovery phase.
+ RTORecovery
+ // FastRecovery indicates that the sender has entered FastRecovery
+ // based on receiving nDupAckThreshold duplicate ACKs. This state
+ // is entered only when SACK is not in use.
+ FastRecovery
+ // SACKRecovery indicates that the sender has entered SACK based
+ // recovery.
+ SACKRecovery
+ // Disorder indicates the sender either received some SACK blocks
+ // or dupACKs.
+ Disorder
+)
+
+// congestionControl is an interface that must be implemented by any supported
+// congestion control algorithm.
+type congestionControl interface {
+ // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
+ // just before entering fast retransmit.
+ HandleNDupAcks()
+
+ // HandleRTOExpired is invoked when the retransmit timer expires.
+ HandleRTOExpired()
+
+ // Update is invoked when processing inbound acks. It's passed the
+ // number of packets that were acked by the most recent cumulative
+ // acknowledgement.
+ Update(packetsAcked int)
+
+ // PostRecovery is invoked when the sender is exiting a fast retransmit/
+ // recovery phase. This provides congestion control algorithms a way
+ // to adjust their state when exiting recovery.
+ PostRecovery()
+}
+
+// sender holds the state necessary to send TCP segments.
+//
+// +stateify savable
+type sender struct {
+ ep *endpoint
+
+ // lastSendTime is the timestamp when the last packet was sent.
+ lastSendTime time.Time `state:".(unixTime)"`
+
+ // dupAckCount is the number of duplicated acks received. It is used for
+ // fast retransmit.
+ dupAckCount int
+
+ // fr holds state related to fast recovery.
+ fr fastRecovery
+
+ // sndCwnd is the congestion window, in packets.
+ sndCwnd int
+
+ // sndSsthresh is the threshold between slow start and congestion
+ // avoidance.
+ sndSsthresh int
+
+ // sndCAAckCount is the number of packets acknowledged during congestion
+ // avoidance. When enough packets have been ack'd (typically cwnd
+ // packets), the congestion window is incremented by one.
+ sndCAAckCount int
+
+ // outstanding is the number of outstanding packets, that is, packets
+ // that have been sent but not yet acknowledged.
+ outstanding int
+
+ // sndWnd is the send window size.
+ sndWnd seqnum.Size
+
+ // sndUna is the next unacknowledged sequence number.
+ sndUna seqnum.Value
+
+ // sndNxt is the sequence number of the next segment to be sent.
+ sndNxt seqnum.Value
+
+ // rttMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ rttMeasureSeqNum seqnum.Value
+
+ // rttMeasureTime is the time when the rttMeasureSeqNum was sent.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ // firstRetransmittedSegXmitTime is the original transmit time of
+ // the first segment that was retransmitted due to RTO expiration.
+ firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
+ // zeroWindowProbing is set if the sender is currently probing
+ // for zero receive window.
+ zeroWindowProbing bool `state:"nosave"`
+
+ // unackZeroWindowProbes is the number of unacknowledged zero
+ // window probes.
+ unackZeroWindowProbes uint32 `state:"nosave"`
+
+ closed bool
+ writeNext *segment
+ writeList segmentList
+ resendTimer timer `state:"nosave"`
+ resendWaker sleep.Waker `state:"nosave"`
+
+ // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
+ // "round-trip time variation" and "retransmit timeout", as defined in
+ // section 2 of RFC 6298.
+ rtt rtt
+ rto time.Duration
+
+ // minRTO is the minimum permitted value for sender.rto.
+ minRTO time.Duration
+
+ // maxRTO is the maximum permitted value for sender.rto.
+ maxRTO time.Duration
+
+ // maxRetries is the maximum permitted retransmissions.
+ maxRetries uint32
+
+ // maxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ maxPayloadSize int
+
+ // gso is set if generic segmentation offload is enabled.
+ gso bool
+
+ // sndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ sndWndScale uint8
+
+ // maxSentAck is the maximum acknowledgement actually sent.
+ maxSentAck seqnum.Value
+
+ // state is the current state of congestion control for this endpoint.
+ state ccState
+
+ // cc is the congestion control algorithm in use for this sender.
+ cc congestionControl
+}
+
+// rtt is a synchronization wrapper used to appease stateify. See the comment
+// in sender, where it is used.
+//
+// +stateify savable
+type rtt struct {
+ sync.Mutex `state:"nosave"`
+
+ srtt time.Duration
+ rttvar time.Duration
+ srttInited bool
+}
+
+// fastRecovery holds information related to fast recovery from a packet loss.
+//
+// +stateify savable
+type fastRecovery struct {
+ // active indicates whether the endpoint is in fast recovery. The
+ // following fields are only meaningful when active is true.
+ active bool
+
+ // first and last represent the inclusive sequence number range being
+ // recovered.
+ first seqnum.Value
+ last seqnum.Value
+
+ // maxCwnd is the maximum value the congestion window may be inflated to
+ // due to duplicate acks. This exists to avoid attacks where the
+ // receiver intentionally sends duplicate acks to artificially inflate
+ // the sender's cwnd.
+ maxCwnd int
+
+ // highRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ highRxt seqnum.Value
+
+ // rescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ rescueRxt seqnum.Value
+}
+
+func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
+ // The sender MUST reduce the TCP data length to account for any IP or
+ // TCP options that it is including in the packets that it sends.
+ // See: https://tools.ietf.org/html/rfc6691#section-2
+ maxPayloadSize := int(mss) - ep.maxOptionSize()
+
+ s := &sender{
+ ep: ep,
+ sndWnd: sndWnd,
+ sndUna: iss + 1,
+ sndNxt: iss + 1,
+ rto: 1 * time.Second,
+ rttMeasureSeqNum: iss + 1,
+ lastSendTime: time.Now(),
+ maxPayloadSize: maxPayloadSize,
+ maxSentAck: irs + 1,
+ fr: fastRecovery{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ last: iss,
+ highRxt: iss,
+ rescueRxt: iss,
+ },
+ gso: ep.gso != nil,
+ }
+
+ if s.gso {
+ s.ep.gso.MSS = uint16(maxPayloadSize)
+ }
+
+ s.cc = s.initCongestionControl(ep.cc)
+
+ // A non-positive sndWndScale means that no scaling is in use;
+ // otherwise we store the scaling value.
+ if sndWndScale > 0 {
+ s.sndWndScale = uint8(sndWndScale)
+ }
+
+ s.resendTimer.init(&s.resendWaker)
+
+ s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+
+ // Initialize SACK Scoreboard after updating max payload size as we use
+ // the maxPayloadSize as the smss when determining if a segment is lost
+ // etc.
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+
+ // Get Stack wide config.
+ var minRTO tcpip.TCPMinRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &minRTO); err != nil {
+ panic(fmt.Sprintf("unable to get minRTO from stack: %s", err))
+ }
+ s.minRTO = time.Duration(minRTO)
+
+ var maxRTO tcpip.TCPMaxRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRTO); err != nil {
+ panic(fmt.Sprintf("unable to get maxRTO from stack: %s", err))
+ }
+ s.maxRTO = time.Duration(maxRTO)
+
+ var maxRetries tcpip.TCPMaxRetriesOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRetries); err != nil {
+ panic(fmt.Sprintf("unable to get maxRetries from stack: %s", err))
+ }
+ s.maxRetries = uint32(maxRetries)
+
+ return s
+}
+
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsthresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+ s.sndCwnd = InitialCwnd
+ s.sndSsthresh = math.MaxInt64
+
+ switch congestionControlName {
+ case ccCubic:
+ return newCubicCC(s)
+ case ccReno:
+ fallthrough
+ default:
+ return newRenoCC(s)
+ }
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+ m := mtu - header.TCPMinimumSize
+
+ m -= s.ep.maxOptionSize()
+
+ // We don't adjust up for now.
+ if m >= s.maxPayloadSize {
+ return
+ }
+
+ // Make sure we can transmit at least one byte.
+ if m <= 0 {
+ m = 1
+ }
+
+ s.maxPayloadSize = m
+ if s.gso {
+ s.ep.gso.MSS = uint16(m)
+ }
+
+ if count == 0 {
+ // updateMaxPayloadSize is also called when the sender is created,
+ // and there is no data to send in such cases. Return immediately.
+ return
+ }
+
+ // Update the scoreboard's smss to reflect the new lowered
+ // maxPayloadSize.
+ s.ep.scoreboard.smss = uint16(m)
+
+ s.outstanding -= count
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ // Rewind writeNext to the first segment exceeding the MTU. Do nothing
+ // if it is already before such a packet.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if seg == s.writeNext {
+ // We got to writeNext before we could find a segment
+ // exceeding the MTU.
+ break
+ }
+
+ if seg.data.Size() > m {
+ // We found a segment exceeding the MTU. Rewind
+ // writeNext and try to retransmit it.
+ s.writeNext = seg
+ break
+ }
+ }
+
+ // Since we likely reduced the number of outstanding packets, we may be
+ // ready to send some more.
+ s.sendData()
+}
+
+// sendAck sends an ACK segment.
+func (s *sender) sendAck() {
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+}
+
+// updateRTO updates the retransmit timeout when a new round-trip time is
+// available. This is done in accordance with section 2 of RFC 6298.
+func (s *sender) updateRTO(rtt time.Duration) {
+ s.rtt.Lock()
+ if !s.rtt.srttInited {
+ s.rtt.rttvar = rtt / 2
+ s.rtt.srtt = rtt
+ s.rtt.srttInited = true
+ } else {
+ diff := s.rtt.srtt - rtt
+ if diff < 0 {
+ diff = -diff
+ }
+ // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // no timestamps are available.
+ if !s.ep.sendTSOk {
+ s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
+ s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ } else {
+ // When we are taking RTT measurements of every ACK then
+ // we need to use a modified method as specified in
+ // https://tools.ietf.org/html/rfc7323#appendix-G
+ if s.outstanding == 0 {
+ s.rtt.Unlock()
+ return
+ }
+ // Netstack measures congestion window/inflight all in
+ // terms of packets and not bytes. This is similar to
+ // how Linux tracks cwnd and inflight. In practice
+ // this approximation works as expected.
+ expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+
+ // alpha & beta values are the original values as recommended in
+ // https://tools.ietf.org/html/rfc6298#section-2.3.
+ const alpha = 0.125
+ const beta = 0.25
+
+ alphaPrime := alpha / expectedSamples
+ betaPrime := beta / expectedSamples
+ rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ }
+ }
+
+ s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.rtt.Unlock()
+ if s.rto < s.minRTO {
+ s.rto = s.minRTO
+ }
+}
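+
+// As a worked example (assumed values, no timestamps in use): with
+// srtt = 100ms, rttvar = 25ms and a new sample rtt = 60ms:
+//
+// rttvar = (3*25ms + |100ms - 60ms|) / 4 = 28.75ms
+// srtt = (7*100ms + 60ms) / 8 = 95ms
+// rto = 95ms + 4*28.75ms = 210ms
+//
+// which already exceeds the 200ms default minRTO and is used as-is.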
+
+// resendSegment resends the first unacknowledged segment.
+func (s *sender) resendSegment() {
+ // Don't use any segments we already sent to measure RTT as they may
+ // have been affected by packets being lost.
+ s.rttMeasureSeqNum = s.sndNxt
+
+ // Resend the segment.
+ if seg := s.writeList.Front(); seg != nil {
+ if seg.data.Size() > s.maxPayloadSize {
+ s.splitSeg(seg, s.maxPayloadSize)
+ }
+
+ // See: RFC 6675 section 5 Step 4.3
+ //
+ // To prevent retransmission, set both the HighRXT and RescueRXT
+ // to the highest sequence number in the retransmitted segment.
+ s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.sendSegment(seg)
+ s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+ s.ep.stats.SendErrors.FastRetransmit.Increment()
+
+ // Run SetPipe() as per RFC 6675 section 5 Step 4.4
+ s.SetPipe()
+ }
+}
+
+// retransmitTimerExpired is called when the retransmit timer expires;
+// unacknowledged segments are assumed lost and thus need to be resent.
+// Returns true if the connection is still usable, or false if the connection
+// is deemed lost.
+func (s *sender) retransmitTimerExpired() bool {
+ // Check if the timer actually expired or if it's a spurious wake due
+ // to a previously orphaned runtime timer.
+ if !s.resendTimer.checkExpiration() {
+ return true
+ }
+
+ // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
+ // when writeList is empty. Remove this once we have a proper fix for this
+ // issue.
+ if s.writeList.Front() == nil {
+ return true
+ }
+
+ s.ep.stack.Stats().TCP.Timeouts.Increment()
+ s.ep.stats.SendErrors.Timeouts.Increment()
+
+ // Give up if a user timeout is set and we have exceeded the
+ // user-specified timeout since the first retransmission, or if the
+ // zero window probe limit has been exceeded.
+ uto := s.ep.userTimeout
+
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ // We store the original xmitTime of the segment that we are
+ // about to retransmit as the retransmission time. This is
+ // required as by the time the retransmitTimer has expired the
+ // segment has already been sent and unacked for the RTO at the
+ // time the segment was sent.
+ s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
+ }
+
+ elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ remaining := s.maxRTO
+ if uto != 0 {
+ // Cap to the user specified timeout if one is specified.
+ remaining = uto - elapsed
+ }
+
+ // Always honor the user-timeout irrespective of whether the zero
+ // window probes were acknowledged.
+ // net/ipv4/tcp_timer.c::tcp_probe_timer()
+ if remaining <= 0 || s.unackZeroWindowProbes >= s.maxRetries {
+ return false
+ }
+
+ // Set new timeout. The timer will be restarted by the call to sendData
+ // below.
+ s.rto *= 2
+ // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5
+ if s.rto > s.maxRTO {
+ s.rto = s.maxRTO
+ }
+
+ // Cap RTO to remaining time.
+ if s.rto > remaining {
+ s.rto = remaining
+ }
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+ //
+ // Retransmit timeouts:
+ // After a retransmit timeout, record the highest sequence number
+ // transmitted in the variable recover, and exit the fast recovery
+ // procedure if applicable.
+ s.fr.last = s.sndNxt - 1
+
+ if s.fr.active {
+ // We were attempting fast recovery but were not successful.
+ // Leave the state. We don't need to update ssthresh because it
+ // has already been updated when entered fast-recovery.
+ s.leaveFastRecovery()
+ }
+
+ s.state = RTORecovery
+ s.cc.HandleRTOExpired()
+
+ // Mark the next segment to be sent as the first unacknowledged one and
+ // start sending again. Set the number of outstanding packets to 0 so
+ // that we'll be able to retransmit.
+ //
+ // We'll keep on transmitting (or retransmitting) as we get acks for
+ // the data we transmit.
+ s.outstanding = 0
+
+ // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1
+ //
+ // In order to avoid memory deadlocks, the TCP receiver is allowed to
+ // discard data that has already been selectively acknowledged. As a
+ // result, [RFC2018] suggests that a TCP sender SHOULD expunge the SACK
+ // information gathered from a receiver upon a retransmission timeout
+ // (RTO) "since the timeout might indicate that the data receiver has
+ // reneged." Additionally, a TCP sender MUST "ignore prior SACK
+ // information in determining which data to retransmit."
+ //
+ // NOTE: We take the stricter interpretation and just expunge all
+ // information as we lack more rigorous checks to validate if the SACK
+ // information is usable after an RTO.
+ s.ep.scoreboard.Reset()
+ s.writeNext = s.writeList.Front()
+
+ // RFC 1122 4.2.2.17: Start sending zero window probes when we still see a
+ // zero receive window after retransmission interval and we have data to
+ // send.
+ if s.zeroWindowProbing {
+ s.sendZeroWindowProbe()
+ // RFC 1122 4.2.2.17: A TCP MAY keep its offered receive window closed
+ // indefinitely. As long as the receiving TCP continues to send
+ // acknowledgments in response to the probe segments, the sending TCP
+ // MUST allow the connection to stay open.
+ return true
+ }
+
+ seg := s.writeNext
+ // RFC 1122 4.2.3.5: Close the connection when the number of
+ // retransmissions for this segment is beyond a limit.
+ if seg != nil && seg.xmitCount > s.maxRetries {
+ return false
+ }
+
+ s.sendData()
+
+ return true
+}
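+
+// For example (assumed defaults): starting from rto = 200ms, successive
+// expirations double the timeout to 400ms, 800ms, 1.6s, ... until it is
+// capped at maxRTO (120s by default) or at the remaining user timeout.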
+
+// pCount returns the number of packets in the segment. Due to GSO, a segment
+// can be composed of multiple packets.
+func (s *sender) pCount(seg *segment) int {
+ size := seg.data.Size()
+ if size == 0 {
+ return 1
+ }
+
+ return (size-1)/s.maxPayloadSize + 1
+}
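+
+// For example, with an assumed maxPayloadSize of 1000 bytes: an empty
+// segment (e.g. a bare ACK) counts as 1 packet, a 1000-byte segment as 1,
+// and a 1001-byte segment as 2.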
+
+// splitSeg splits a given segment at the size specified and inserts the
+// remainder as a new segment after the current one in the write list.
+func (s *sender) splitSeg(seg *segment, size int) {
+ if seg.data.Size() <= size {
+ return
+ }
+ // Split this segment up.
+ nSeg := seg.clone()
+ nSeg.data.TrimFront(size)
+ nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
+ s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
+ seg.data.CapLength(size)
+}
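+
+// For example (assumed numbers): splitting a 1400-byte segment with
+// sequence number seq at size 1000 caps the original to bytes
+// [seq, seq+1000) and inserts a new 400-byte segment starting at seq+1000
+// immediately after it in the write list.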
+
+// NextSeg implements the RFC6675 NextSeg() operation.
+//
+// NextSeg starts scanning the writeList starting from nextSegHint and returns
+// the hint to be passed on the next call to NextSeg. This is required to avoid
+// iterating the write list repeatedly when NextSeg is invoked in a loop during
+// recovery. The returned hint will be nil if there are no more segments that
+// can match rules defined by NextSeg operation in RFC6675.
+//
+// rescueRtx will be true only if nextSeg is a rescue retransmission as
+// described by Step 4) of the NextSeg algorithm.
+func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRtx bool) {
+ var s3 *segment
+ var s4 *segment
+ // Step 1.
+ for seg := nextSegHint; seg != nil; seg = seg.Next() {
+ // Stop iteration if we hit a segment that has never been
+ // transmitted (i.e. either it has no assigned sequence number
+ // or if it does have one, it's >= the next sequence number
+ // to be sent [i.e. >= s.sndNxt]).
+ if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) {
+ hint = nil
+ break
+ }
+ segSeq := seg.sequenceNumber
+ if smss := s.ep.scoreboard.SMSS(); seg.data.Size() > int(smss) {
+ s.splitSeg(seg, int(smss))
+ }
+
+ // See RFC 6675 Section 4
+ //
+ // 1. If there exists a smallest unSACKED sequence number
+ // 'S2' that meets the following 3 criteria for determining
+ // loss, the sequence range of one segment of up to SMSS
+ // octets starting with S2 MUST be returned.
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ // NextSeg():
+ //
+ // (1.a) S2 is greater than HighRxt
+ // (1.b) S2 is less than the highest octet covered by
+ // any received SACK.
+ if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ // NextSeg():
+ // (1.c) IsLost(S2) returns true.
+ if s.ep.scoreboard.IsLost(segSeq) {
+ return seg, seg.Next(), false
+ }
+
+ // NextSeg():
+ //
+ // (3): If the conditions for rules (1) and (2)
+ // fail, but there exists an unSACKed sequence
+ // number S3 that meets the criteria for
+ // detecting loss given in steps 1.a and 1.b
+ // above (specifically excluding (1.c)) then one
+ // segment of up to SMSS octets starting with S3
+ // SHOULD be returned.
+ if s3 == nil {
+ s3 = seg
+ hint = seg.Next()
+ }
+ }
+ // NextSeg():
+ //
+ // (4) If the conditions for (1), (2) and (3) fail,
+ // but there exists outstanding unSACKED data, we
+ // provide the opportunity for a single "rescue"
+ // retransmission per entry into loss recovery. If
+ // HighACK is greater than RescueRxt (or RescueRxt
+ // is undefined), then one segment of up to SMSS
+ // octets that MUST include the highest outstanding
+ // unSACKed sequence number SHOULD be returned, and
+ // RescueRxt set to RecoveryPoint. HighRxt MUST NOT
+ // be updated.
+ if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s4 != nil {
+ if s4.sequenceNumber.LessThan(segSeq) {
+ s4 = seg
+ }
+ } else {
+ s4 = seg
+ }
+ }
+ }
+ }
+
+ // If we got here then no segment matched step (1).
+ // Step (2): "If no sequence number 'S2' per rule (1)
+ // exists but there exists available unsent data and the
+ // receiver's advertised window allows, the sequence
+ // range of one segment of up to SMSS octets of
+ // previously unsent data starting with sequence number
+ // HighData+1 MUST be returned."
+ for seg := s.writeNext; seg != nil; seg = seg.Next() {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ continue
+ }
+ // We do not split the segment here to <= smss as it has
+ // potentially not been assigned a sequence number yet.
+ return seg, nil, false
+ }
+
+ if s3 != nil {
+ return s3, hint, false
+ }
+
+ return s4, nil, true
+}
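+
+// In short, as a reading aid for the RFC 6675 rules above: rule (1) prefers
+// the first unSACKed segment above highRxt that IsLost reports as lost;
+// rule (2) falls back to previously unsent data; rule (3) to an unSACKed
+// segment above highRxt and below the highest SACKed octet, even if not yet
+// lost; and rule (4) to a single rescue retransmission of the highest
+// outstanding unSACKed sequence number.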
+
+// maybeSendSegment tries to send the specified segment and either coalesces
+// other segments into this one or splits the specified segment based on the
+// lower of the specified limit value or the receivers window size specified by
+// end.
+func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) {
+ // We abuse the flags field to determine if we have already
+ // assigned a sequence number to this segment.
+ if !s.isAssignedSequenceNumber(seg) {
+ // Merge segments if allowed.
+ if seg.data.Size() != 0 {
+ available := int(s.sndNxt.Size(end))
+ if available > limit {
+ available = limit
+ }
+
+ // nextTooBig indicates that the next segment was too
+ // large to entirely fit in the current segment. It
+ // would be possible to split the next segment and merge
+ // the portion that fits, but unexpectedly splitting
+ // segments can have user visible side-effects which can
+ // break applications. For example, RFC 7766 section 8
+ // says that the length and data of a DNS response
+ // should be sent in the same TCP segment to avoid
+ // triggering bugs in poorly written DNS
+ // implementations.
+ var nextTooBig bool
+ for seg.Next() != nil && seg.Next().data.Size() != 0 {
+ if seg.data.Size()+seg.Next().data.Size() > available {
+ nextTooBig = true
+ break
+ }
+ seg.data.Append(seg.Next().data)
+
+ // Consume the segment that we just merged in.
+ s.writeList.Remove(seg.Next())
+ }
+ if !nextTooBig && seg.data.Size() < available {
+ // Segment is not full.
+ if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ // Nagle's algorithm. From Wikipedia:
+ // Nagle's algorithm works by
+ // combining a number of small
+ // outgoing messages and sending them
+ // all at once. Specifically, as long
+ // as there is a sent packet for which
+ // the sender has received no
+ // acknowledgment, the sender should
+ // keep buffering its output until it
+ // has a full packet's worth of
+ // output, thus allowing output to be
+ // sent all at once.
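+ // For example (values assumed for illustration): with an
+ // MSS of 1460 and only 100 bytes queued while earlier
+ // data is still unacknowledged, the 100 bytes are held
+ // until an ack arrives or a full MSS of data accumulates.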
+ return false
+ }
+ // With TCP_CORK, hold the segment back until it fills the
+ // minimum of the available send space and the MSS.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
+ return false
+ }
+ }
+ }
+
+ // Assign flags. We don't do it above so that we can merge
+ // additional data if Nagle holds the segment.
+ seg.sequenceNumber = s.sndNxt
+ seg.flags = header.TCPFlagAck | header.TCPFlagPsh
+ }
+
+ var segEnd seqnum.Value
+ if seg.data.Size() == 0 {
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
+ }
+ seg.flags = header.TCPFlagAck | header.TCPFlagFin
+ segEnd = seg.sequenceNumber.Add(1)
+ // Update the state to reflect that we have now
+ // queued a FIN.
+ switch s.ep.EndpointState() {
+ case StateCloseWait:
+ s.ep.setEndpointState(StateLastAck)
+ default:
+ s.ep.setEndpointState(StateFinWait1)
+ }
+ } else {
+ // We're sending a non-FIN segment.
+ if seg.flags&header.TCPFlagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
+ }
+
+ if !seg.sequenceNumber.LessThan(end) {
+ return false
+ }
+
+ available := int(seg.sequenceNumber.Size(end))
+ if available == 0 {
+ return false
+ }
+
+ // If the whole segment or at least 1MSS sized segment cannot
+ // be accommodated in the receiver's advertised window, skip
+ // splitting and sending of the segment. ref:
+ // net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered by
+ // a probe timer. On this condition, it defers the segment split
+ // and transmit to a short probe timer.
+ //
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split
+ // the segment right here if there are no pending segments. If
+ // there are pending segments, segment transmits are deferred to
+ // the retransmit timer handler.
+ if s.sndUna != s.sndNxt {
+ switch {
+ case available >= seg.data.Size():
+ // OK to send, the whole segment fits in the
+ // receiver's advertised window.
+ case available >= s.maxPayloadSize:
+ // OK to send, at least 1 MSS sized segment fits
+ // in the receiver's advertised window.
+ default:
+ return false
+ }
+ }
+
+ // The segment size limit is computed as a function of sender
+ // congestion window and MSS. When sender congestion window is >
+ // 1, this limit can be larger than MSS. Ensure that the
+ // currently available send space is not greater than the
+ // minimum of this limit and MSS.
+ if available > limit {
+ available = limit
+ }
+
+ // If GSO is not in use then cap available to
+ // maxPayloadSize. When GSO is in use the gVisor GSO logic or
+ // the host GSO logic will cap the segment to the correct size.
+ if s.ep.gso == nil && available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
+
+ if seg.data.Size() > available {
+ s.splitSeg(seg, available)
+ }
+
+ segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ }
+
+ s.sendSegment(seg)
+
+ // Update sndNxt if we actually sent new data (as opposed to
+ // retransmitting some previously sent data).
+ if s.sndNxt.LessThan(segEnd) {
+ s.sndNxt = segEnd
+ }
+
+ return true
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ s.SetPipe()
+
+ if smss := int(s.ep.scoreboard.SMSS()); limit > smss {
+ // Cap the segment size limit to SMSS, as SACK recovery requires
+ // that all retransmissions or new segments sent during recovery
+ // be of size <= SMSS.
+ limit = smss
+ }
+
+ nextSegHint := s.writeList.Front()
+ for s.outstanding < s.sndCwnd {
+ var nextSeg *segment
+ var rescueRtx bool
+ nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint)
+ if nextSeg == nil {
+ return dataSent
+ }
+ if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ // New data being sent.
+
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ //
+ // We pass the SMSS-capped limit, as step (2) requires that
+ // new data sent be of size SMSS or less.
+ if sent := s.maybeSendSegment(nextSeg, limit, end); !sent {
+ return dataSent
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = nextSeg.Next()
+ continue
+ }
+
+ // Now handle the retransmission case where we matched either step (1), (3) or (4)
+ // of the NextSeg algorithm.
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ s.outstanding++
+ dataSent = true
+ s.sendSegment(nextSeg)
+
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if rescueRtx {
+ // We do the last part of rule (4) of NextSeg here to update
+ // RescueRxt as until this point we don't know if we are going
+ // to use the rescue transmission.
+ s.fr.rescueRxt = s.fr.last
+ } else {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ s.fr.highRxt = segEnd - 1
+ }
+ }
+ return dataSent
+}
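+
+// Steps C.1 and C.5 of RFC 6675 map onto the loop above: transmitting the
+// segment chosen by NextSeg implements (C.1), and the loop condition
+// "s.outstanding < s.sndCwnd" implements (C.5)'s "If cwnd - pipe >= 1 SMSS"
+// check, with pipe maintained in packets rather than octets.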
+
+func (s *sender) sendZeroWindowProbe() {
+ ack, win := s.ep.rcv.getSendParams()
+ s.unackZeroWindowProbes++
+ // Send a zero window probe with sequence number pointing to
+ // the last acknowledged byte.
+ s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win)
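+ // The probe's sequence number is already acknowledged, so the
+ // peer discards the payload-less segment but replies with an
+ // ACK advertising its current receive window.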
+ // Rearm the timer to continue probing.
+ s.resendTimer.enable(s.rto)
+}
+
+func (s *sender) enableZeroWindowProbing() {
+ s.zeroWindowProbing = true
+ // We piggyback the probing on the retransmit timer with the
+ // current retransmission interval, as we may start probing while
+ // segment retransmissions are in progress.
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ s.firstRetransmittedSegXmitTime = time.Now()
+ }
+ s.resendTimer.enable(s.rto)
+}
+
+func (s *sender) disableZeroWindowProbing() {
+ s.zeroWindowProbing = false
+ s.unackZeroWindowProbes = 0
+ s.firstRetransmittedSegXmitTime = time.Time{}
+ s.resendTimer.disable()
+}
+
+// sendData sends new data segments. It is called when data becomes available or
+// when the send window opens up.
+func (s *sender) sendData() {
+ limit := s.maxPayloadSize
+ if s.gso {
+ limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
+ }
+ end := s.sndUna.Add(s.sndWnd)
+
+ // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
+ // "A TCP SHOULD set cwnd to no more than RW before beginning
+ // transmission if the TCP has not sent data in an interval exceeding
+ // the retransmission timeout."
+ if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
+ if s.sndCwnd > InitialCwnd {
+ s.sndCwnd = InitialCwnd
+ }
+ }
+
+ var dataSent bool
+
+ // RFC 6675 recovery algorithm, steps C.1 through C.5.
+ if s.fr.active && s.ep.sackPermitted {
+ dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
+ } else {
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Move writeNext along so that we don't try and scan data that
+ // has already been SACKED.
+ s.writeNext = seg.Next()
+ continue
+ }
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding += s.pCount(seg)
+ s.writeNext = seg.Next()
+ }
+ }
+
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // If the peer has advertised a zero receive window and we have
+ // data to send out, start zero window probing to query the
+ // remote for its receive window size.
+ if s.writeNext != nil && s.sndWnd == 0 {
+ s.enableZeroWindowProbing()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
+func (s *sender) enterFastRecovery() {
+ s.fr.active = true
+ // Save state to reflect we're now in fast recovery.
+ //
+ // See: https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
+ // We inflate the cwnd by 3 to account for the 3 packets which triggered
+ // the 3 duplicate ACKs and are now not in flight.
+ s.sndCwnd = s.sndSsthresh + 3
+ s.fr.first = s.sndUna
+ s.fr.last = s.sndNxt - 1
+ s.fr.maxCwnd = s.sndCwnd + s.outstanding
+ s.fr.highRxt = s.sndUna
+ s.fr.rescueRxt = s.sndUna
+ if s.ep.sackPermitted {
+ s.state = SACKRecovery
+ s.ep.stack.Stats().TCP.SACKRecovery.Increment()
+ return
+ }
+ s.state = FastRecovery
+ s.ep.stack.Stats().TCP.FastRecovery.Increment()
+}
+
+func (s *sender) leaveFastRecovery() {
+ s.fr.active = false
+ s.fr.maxCwnd = 0
+ s.dupAckCount = 0
+
+ // Deflate cwnd. It had been artificially inflated when new dups arrived.
+ s.sndCwnd = s.sndSsthresh
+
+ s.cc.PostRecovery()
+}
+
+func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ // We are in fast recovery mode. Ignore the ack if it's out of
+ // range.
+ if !ack.InRange(s.sndUna, s.sndNxt+1) {
+ return false
+ }
+
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if s.fr.last.LessThan(ack) {
+ s.leaveFastRecovery()
+ return false
+ }
+
+ if s.ep.sackPermitted {
+ // When SACK is enabled we let retransmission be governed by
+ // the SACK logic.
+ return false
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+ return false
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if ack == s.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if s.sndCwnd < s.fr.maxCwnd {
+ s.sndCwnd++
+ }
+ return false
+ }
+
+ // A partial ack was received. Retransmit this packet and
+ // remember it so that we don't retransmit it again. We don't
+ // inflate the window because we're putting the same packet back
+ // onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ s.fr.first = ack
+ s.dupAckCount = 0
+ return true
+}
+
+// isAssignedSequenceNumber relies on the fact that we only set flags once a
+// sequence number is assigned and that is only done right before we send the
+// segment. As a result any segment that has a non-zero flag has a valid
+// sequence number assigned to it.
+func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
+ return seg.flags != 0
+}
+
+// SetPipe implements the SetPipe() function described in RFC6675. Netstack
+// maintains the congestion window in number of packets and not bytes, so
+// SetPipe() here measures number of outstanding packets rather than actual
+// outstanding bytes in the network.
+func (s *sender) SetPipe() {
+ // If SACK isn't permitted or it is permitted but recovery is not active
+ // then ignore pipe calculations.
+ if !s.ep.sackPermitted || !s.fr.active {
+ return
+ }
+ pipe := 0
+ smss := seqnum.Size(s.ep.scoreboard.SMSS())
+ for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() {
+ // With GSO each segment can be much larger than SMSS. So check the segment
+ // in SMSS sized ranges.
+ segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size()))
+ for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) {
+ endSeq := startSeq.Add(smss)
+ if segEnd.LessThan(endSeq) {
+ endSeq = segEnd
+ }
+ sb := header.SACKBlock{startSeq, endSeq}
+ // SetPipe():
+ //
+ // After initializing pipe to zero, the following steps are
+ // taken for each octet 'S1' in the sequence space between
+ // HighACK and HighData that has not been SACKed:
+ if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ break
+ }
+ if s.ep.scoreboard.IsSACKED(sb) {
+ continue
+ }
+
+ // SetPipe():
+ //
+ // (a) If IsLost(S1) returns false, Pipe is incremented by 1.
+ //
+ // NOTE: here we mark the whole segment as lost. We do not try
+ // to test every byte in our write buffer, as we maintain our
+ // pipe in terms of outstanding packets and not bytes.
+ if !s.ep.scoreboard.IsRangeLost(sb) {
+ pipe++
+ }
+ // SetPipe():
+ // (b) If S1 <= HighRxt, Pipe is incremented by 1.
+ if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ pipe++
+ }
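+ // Note that (a) and (b) are not mutually exclusive: as in
+ // RFC 6675, a range that was retransmitted but is not yet
+ // marked lost contributes to pipe twice.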
+ }
+ }
+ s.outstanding = pipe
+}
+
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno).
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ if s.fr.active {
+ return s.handleFastRecovery(seg)
+ }
+
+ // We're not in fast recovery yet. A segment is considered a duplicate
+ // only if it doesn't carry any data and doesn't update the send window,
+ // because if it does, it wasn't sent in response to an out-of-order
+ // segment. If SACK is enabled then we have an additional check to see
+ // if the segment carries new SACK information. If it does then it is
+ // considered a duplicate ACK as per RFC6675.
+ if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
+ if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
+ s.dupAckCount = 0
+ return false
+ }
+ }
+
+ s.dupAckCount++
+
+ // Do not enter fast recovery until we reach nDupAckThreshold or the
+ // first unacknowledged byte is considered lost as per SACK scoreboard.
+ if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+ // RFC 6675 Step 3.
+ s.fr.highRxt = s.sndUna - 1
+ // Run SetPipe() to recalculate the number of outstanding segments.
+ s.SetPipe()
+ s.state = Disorder
+ return false
+ }
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+ //
+ // We only do the check here; updating last to the highest
+ // sequence number transmitted so far is done when enterFastRecovery
+ // is invoked.
+ if !s.fr.last.LessThan(seg.ackNumber) {
+ s.dupAckCount = 0
+ return false
+ }
+ s.cc.HandleNDupAcks()
+ s.enterFastRecovery()
+ s.dupAckCount = 0
+ return true
+}
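+
+// In summary: acks that carry data, change the window, or acknowledge new
+// data reset dupAckCount unless they carry new SACK information; the third
+// duplicate ack (and, when SACK is in use, a scoreboard loss indication for
+// sndUna) triggers fast retransmit and enters fast recovery.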
+
+// handleRcvdSegment is called when a segment is received; it is responsible for
+// updating the send-related state.
+func (s *sender) handleRcvdSegment(seg *segment) {
+ // Check if we can extract an RTT measurement from this ack.
+ if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ s.updateRTO(time.Now().Sub(s.rttMeasureTime))
+ s.rttMeasureSeqNum = s.sndNxt
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if s.ep.sendTSOk && seg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ }
+
+ // Insert SACKBlock information into our scoreboard.
+ if s.ep.sackPermitted {
+ for _, sb := range seg.parsedOptions.SACKBlocks {
+ // Only insert the SACK block if the following holds
+ // true:
+ // * SACK block acks data after the ack number in the
+ // current segment.
+ // * SACK block represents a sequence range
+ // between sndUna and sndNxt (i.e. data that is
+ // currently unacked and in-flight).
+ // * SACK block that has not been SACKed already.
+ //
+ // NOTE: This check specifically excludes DSACK blocks
+ // which have start/end before sndUna and are used to
+ // indicate spurious retransmissions.
+ if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ s.ep.scoreboard.Insert(sb)
+ seg.hasNewSACKInfo = true
+ }
+ }
+ s.SetPipe()
+ }
+
+ // Count the duplicates and do the fast retransmit if needed.
+ rtx := s.checkDuplicateAck(seg)
+
+ // Stash away the current window size.
+ s.sndWnd = seg.window
+
+ ack := seg.ackNumber
+
+ // Disable zero window probing if the remote advertises a non-zero receive
+ // window. This can be an ACK to the zero window probe (where the
+ // ack number refers to the already acknowledged byte) or an ACK to any
+ // previously unacknowledged segment.
+ if s.zeroWindowProbing && seg.window > 0 &&
+ (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
+ s.disableZeroWindowProbing()
+ }
+
+ // On receiving the ACK for the zero window probe, account for it and
+ // skip trying to send any segment, as we are still probing for
+ // the receive window to become non-zero.
+ if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna {
+ s.unackZeroWindowProbes--
+ return
+ }
+
+ // Ignore ack if it doesn't acknowledge any new data.
+ if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+ s.dupAckCount = 0
+
+ // See: https://tools.ietf.org/html/rfc1323#section-3.3.
+ // Specifically we should only update the RTO using TSEcr if the
+ // following condition holds:
+ //
+ // A TSecr value received in a segment is used to update the
+ // averaged RTT measurement only if the segment acknowledges
+ // some new data, i.e., only if it advances the left edge of
+ // the send window.
+ if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ // TSVal/Ecr values sent by Netstack are at a millisecond
+ // granularity.
+ elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ s.updateRTO(elapsed)
+ }
+
+ // When an ack is received we must rearm the timer.
+ // RFC 6298 5.3
+ s.resendTimer.enable(s.rto)
+
+ // Remove all acknowledged data from the write list.
+ acked := s.sndUna.Size(ack)
+ s.sndUna = ack
+
+ ackLeft := acked
+ originalOutstanding := s.outstanding
+ for ackLeft > 0 {
+ // We use logicalLen here because we can have FIN
+ // segments (which are always at the end of list) that
+ // have no data, but do consume a sequence number.
+ seg := s.writeList.Front()
+ datalen := seg.logicalLen()
+
+ if datalen > ackLeft {
+ prevCount := s.pCount(seg)
+ seg.data.TrimFront(int(ackLeft))
+ seg.sequenceNumber.UpdateForward(ackLeft)
+ s.outstanding -= prevCount - s.pCount(seg)
+ break
+ }
+
+ if s.writeNext == seg {
+ s.writeNext = seg.Next()
+ }
+
+ s.writeList.Remove(seg)
+
+ // If SACK is enabled then only reduce outstanding if
+ // the segment was not previously SACKed, as such segments
+ // have already been accounted for in SetPipe().
+ if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.outstanding -= s.pCount(seg)
+ }
+ seg.decRef()
+ ackLeft -= datalen
+ }
+
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
+ // Clear SACK information for all acked data.
+ s.ep.scoreboard.Delete(s.sndUna)
+
+ // If we are not in fast recovery then update the congestion
+ // window based on the number of acknowledged packets.
+ if !s.fr.active {
+ s.cc.Update(originalOutstanding - s.outstanding)
+ if s.fr.last.LessThan(s.sndUna) {
+ s.state = Open
+ }
+ }
+
+ // It is possible for s.outstanding to drop below zero if we get
+ // a retransmit timeout, reset outstanding to zero, but later
+ // get an ack that covers previously sent data.
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ s.SetPipe()
+
+ // If all outstanding data has been acknowledged, disable the timer.
+ // RFC 6298 Rule 5.3
+ if s.sndUna == s.sndNxt {
+ s.outstanding = 0
+ // Reset firstRetransmittedSegXmitTime to the zero value.
+ s.firstRetransmittedSegXmitTime = time.Time{}
+ s.resendTimer.disable()
+ }
+ }
+ // Now that we've popped all acknowledged data from the retransmit
+ // queue, retransmit if needed.
+ if rtx {
+ s.resendSegment()
+ }
+
+ // Send more data now that some of the pending data has been ack'd, or
+ // that the window opened up, or the congestion window was inflated due
+ // to a duplicate ack during fast recovery. This will also re-enable
+ // the retransmit timer if needed.
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ s.sendData()
+ }
+}
+
+// sendSegment sends the specified segment.
+func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+ if seg.xmitCount > 0 {
+ s.ep.stack.Stats().TCP.Retransmits.Increment()
+ s.ep.stats.SendErrors.Retransmits.Increment()
+ if s.sndCwnd < s.sndSsthresh {
+ s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+ }
+ }
+ seg.xmitTime = time.Now()
+ seg.xmitCount++
+ err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+
+ // Every time a packet containing data is sent (including a
+ // retransmission), if SACK is enabled and we are retransmitting data
+ // then use the conservative timer described in RFC6675 Section 6.0,
+ // otherwise follow the standard timer described in RFC6298 Section 5.1.
+ if err != nil && seg.data.Size() != 0 {
+ if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted {
+ s.resendTimer.enable(s.rto)
+ } else {
+ if !s.resendTimer.enabled() {
+ s.resendTimer.enable(s.rto)
+ }
+ }
+ }
+
+ return err
+}
+
+// sendSegmentFromView sends a new segment containing the given payload, flags
+// and sequence number.
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+ s.lastSendTime = time.Now()
+ if seq == s.rttMeasureSeqNum {
+ s.rttMeasureTime = s.lastSendTime
+ }
+
+ rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
+
+ // Remember the max sent ack.
+ s.maxSentAck = rcvNxt
+
+ return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
+}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
new file mode 100644
index 000000000..8b20c3455
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
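+
+// time.Time is not directly serializable by stateify, so the sender's
+// timestamps are saved as seconds and nanoseconds since the Unix epoch and
+// rebuilt with time.Unix on load.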
+
+// saveLastSendTime is invoked by stateify.
+func (s *sender) saveLastSendTime() unixTime {
+ return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (s *sender) loadLastSendTime(unix unixTime) {
+ s.lastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (s *sender) saveRttMeasureTime() unixTime {
+ return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (s *sender) loadRttMeasureTime(unix unixTime) {
+ s.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// afterLoad is invoked by stateify.
+func (s *sender) afterLoad() {
+ s.resendTimer.init(&s.resendWaker)
+}
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+ return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+ s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
new file mode 100644
index 000000000..b9993ce1a
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -0,0 +1,550 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// These tests are flaky when run under the go race detector due to some
+// iterations taking long enough that the retransmit timer can kick in,
+// causing the congestion window measurements to fail due to extra packets.
+//
+// +build !race
+
+package tcp_test
+
+import (
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+func TestFastRecovery(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ const iterations = 3
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Send 3 duplicate acks. This should force an immediate retransmit of
+ // the pending packet and put the sender into fast recovery.
+ rtxOffset := bytesRead - maxPayload*expected
+ for i := 0; i < 3; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ // Receive the retransmitted packet.
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Wait before checking metrics.
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+
+ // Now send 7 more duplicate acks. Each of these should cause a window
+ // inflation by 1 and cause the sender to send an extra packet.
+ for i := 0; i < 7; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ recover := bytesRead
+
+ // Ensure no new packets arrive.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge half of the pending data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+ c.SendAck(790, rtxOffset)
+
+ // Receive the retransmit due to partial ack.
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Wait before checking metrics.
+ metricPollFn = func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+
+ // Receive the 10 extra packets that should have been released due to
+ // the congestion window inflation in recovery.
+ for i := 0; i < 10; i++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // A partial ACK during recovery should reduce the congestion window by
+ // the number acked. Since we had "expected" packets outstanding before
+ // sending the partial ack and we acked expected/2, the cwnd and
+ // outstanding should be expected/2 + 10 (7 dupAcks + 3 for the original
+ // 3 dupacks that triggered fast recovery), which means the sender
+ // should not send any more packets till we ack this one.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge all pending data to recover point.
+ c.SendAck(790, recover)
+
+ // At this point, the cwnd should reset to expected/2 and there are 10
+ // packets outstanding.
+ //
+ // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
+ // the same segment that takes us out of recovery. But because of that,
+ // the actual cwnd at exit of recovery will be expected/2 + 1, as we
+ // acked a cwnd's worth of packets, which increases the cwnd further by
+ // 1 in congestion avoidance.
+ //
+ // Now, in the first iteration, since there are 10 packets outstanding,
+ // we would expect to get expected/2 + 1 - 10 packets. But subsequent
+ // iterations will send us expected/2 + 1 + 1 (per iteration).
+ expected = expected/2 + 1 - 10
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In congestion avoidance, the packet trains increase by 1 in
+ // each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 10
+ }
+ expected++
+ }
+}
+
+func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ const iterations = 3
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // Double the number of expected packets for the next iteration.
+ expected *= 2
+ }
+}
+
+func TestCongestionAvoidance(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ const iterations = 3
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd/2.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
+ // acknowledgements), then the congestion avoidance part will consume
+ // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
+ // remains in the "ack count" (which will cause cwnd to be incremented
+ // once it reaches cwnd acks).
+ //
+ // So we're straight into congestion avoidance with cwnd set to
+ // expected/2 + 1.
+ //
+ // Check that packets trains of cwnd packets are sent, and that cwnd is
+ // incremented by 1 after we acknowledge each packet.
+ expected = expected/2 + 1
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In congestion avoidance, the packet trains increase by 1 in
+ // each iteration.
+ expected++
+ }
+}
+
+// cubicCwnd returns an estimate of a cubic window given the
+// originalCwnd, wMax, last congestion event time and sRTT.
+func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
+ cwnd := float64(origCwnd)
+ // We wait 50ms between each iteration so sRTT as computed by cubic
+ // should be close to 50ms.
+ elapsed := (time.Since(congEventTime) + sRTT).Seconds()
+ k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
+ wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
+ cwnd += (wtRTT - cwnd) / cwnd
+ return int(cwnd)
+}
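+
+// The estimate mirrors CUBIC's window growth function
+// W(t) = C*(t-K)^3 + wMax; the constants 0.4 and 0.3/0.7 used here are
+// assumed to match the ones in the sender's cubic implementation, which
+// this test only approximates.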
+
+func TestCubicCongestionAvoidance(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ enableCUBIC(t, c)
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ const iterations = 3
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd * 0.7.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all pending data.
+ c.SendAck(790, bytesRead)
+
+ // Store away the time we sent the ACK. Assuming a 200ms RTO,
+ // we estimate that the sender will have an RTO 200ms from now
+ // and go back into slow start.
+ packetDropTime := time.Now().Add(200 * time.Millisecond)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2, essentially putting the connection
+ // straight into congestion avoidance.
+ wMax := expected
+ // Lower expected as per cubic spec after a congestion event.
+ expected = int(float64(expected) * 0.7)
+ cwnd := expected
+ for i := 0; i < iterations; i++ {
+ // Cubic grows window independent of ACKs. Cubic Window growth
+ // is a function of time elapsed since last congestion event.
+ // As a result the congestion window does not grow
+ // deterministically in response to ACKs.
+ //
+ // We need to roughly estimate what the cwnd of the sender is
+ // based on when we sent the dupacks.
+ cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
+
+ packetsExpected := cwnd
+ for j := 0; j < packetsExpected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+ t.Logf("expected packets received, next trying to receive any extra packets that may come")
+
+ // If our estimate was correct there should be no more pending packets.
+ // We attempt to read a packet a few times with a short sleep in between
+ // to ensure that we don't see the sender send any unexpected packets.
+ unexpectedPackets := 0
+ for {
+ gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
+ if !gotPacket {
+ break
+ }
+ bytesRead += maxPayload
+ unexpectedPackets++
+ time.Sleep(1 * time.Millisecond)
+ }
+ if unexpectedPackets != 0 {
+ t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+ }
+}
+
+func TestRetransmit(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ const iterations = 3
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in two shots. Packets will only be written at the
+ // MTU size though.
+ half := data[:len(data)/2]
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ half = data[len(data)/2:]
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or the sender's retransmit timer will fire.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Wait for a timeout and retransmit.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
+ }
+
+ return nil
+ }
+
+ // Poll when checking metrics.
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+
+ // Acknowledge half of the pending data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+ c.SendAck(790, rtxOffset)
+
+ // Receive the remaining data, making sure that acknowledged data is not
+ // retransmitted.
+ for offset := rtxOffset; offset < len(data); offset += maxPayload {
+ c.ReceiveAndCheckPacket(data, offset, maxPayload)
+ c.SendAck(790, offset+maxPayload)
+ }
+
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
new file mode 100644
index 000000000..99521f0c1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -0,0 +1,589 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "fmt"
+ "log"
+ "reflect"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// createConnectedWithSACKPermittedOption creates and connects c.ep with the
+// SACKPermitted option enabled if the stack in the context has the SACK support
+// enabled.
+func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
+ return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+}
+
+// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
+// option enabled if the stack in the context has SACK and TS enabled.
+func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
+ return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+}
+
+func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
+ t.Helper()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
+ }
+}
+
+// TestSackPermittedConnect establishes a connection with the SACK option
+// enabled.
+func TestSackPermittedConnect(t *testing.T) {
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ setStackSACKPermitted(t, c, sackEnabled)
+ rep := createConnectedWithSACKPermittedOption(c)
+ data := []byte{1, 2, 3}
+
+ rep.SendPacket(data, nil)
+ savedSeqNum := rep.NextSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Make an out of order packet and send it.
+ rep.NextSeqNum += 3
+ sackBlocks := []header.SACKBlock{
+ {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+ }
+ rep.SendPacket(data, nil)
+
+ // Restore the saved sequence number so that the
+ // VerifyXXX calls use the right sequence number for
+ // checking ACK numbers.
+ rep.NextSeqNum = savedSeqNum
+ if sackEnabled {
+ rep.VerifyACKHasSACK(sackBlocks)
+ } else {
+ rep.VerifyACKNoSACK()
+ }
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative ACK for all 9
+ // bytes sent and no SACK blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+}
+
+// TestSackDisabledConnect establishes a connection with the SACK option
+// disabled and verifies that no SACKs are sent for out of order segments.
+func TestSackDisabledConnect(t *testing.T) {
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ setStackSACKPermitted(t, c, sackEnabled)
+
+ rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+
+ data := []byte{1, 2, 3}
+
+ rep.SendPacket(data, nil)
+ savedSeqNum := rep.NextSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Make an out of order packet and send it.
+ rep.NextSeqNum += 3
+ rep.SendPacket(data, nil)
+
+ // The ACK should contain the older sequence number and
+ // no SACK blocks.
+ rep.NextSeqNum = savedSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative ACK for all 9
+ // bytes sent and no SACK blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+}
+
+// TestSackPermittedAccept accepts and establishes a connection with the
+// SACKPermitted option enabled if the connection request specifies the
+// SACKPermitted option. With SYN cookies, SACK should be disabled, as we
+// don't encode the SACK information in the cookie.
+func TestSackPermittedAccept(t *testing.T) {
+ type testCase struct {
+ cookieEnabled bool
+ sackPermitted bool
+ wndScale int
+ wndSize uint16
+ }
+
+ testCases := []testCase{
+ {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+ setStackSACKPermitted(t, c, sackEnabled)
+
+ rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ // Now verify no SACK blocks are
+ // received when sack is disabled.
+ data := []byte{1, 2, 3}
+ rep.SendPacket(data, nil)
+ rep.VerifyACKNoSACK()
+
+ savedSeqNum := rep.NextSeqNum
+
+ // Make an out of order packet and send
+ // it.
+ rep.NextSeqNum += 3
+ sackBlocks := []header.SACKBlock{
+ {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+ }
+ rep.SendPacket(data, nil)
+
+ // The ACK should contain the older
+ // sequence number.
+ rep.NextSeqNum = savedSeqNum
+ if sackEnabled && tc.sackPermitted {
+ rep.VerifyACKHasSACK(sackBlocks)
+ } else {
+ rep.VerifyACKNoSACK()
+ }
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative
+ // ACK for all 9 bytes sent and no SACK
+ // blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned
+ // in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+ })
+ }
+}
+
+// TestSackDisabledAccept accepts and establishes a connection with
+// the SACKPermitted option disabled and verifies that no SACKs are
+// sent for out of order packets.
+func TestSackDisabledAccept(t *testing.T) {
+ type testCase struct {
+ cookieEnabled bool
+ wndScale int
+ wndSize uint16
+ }
+
+ testCases := []testCase{
+ {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
+ setStackSACKPermitted(t, c, sackEnabled)
+
+ rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Now verify no SACK blocks are
+ // received when sack is disabled.
+ data := []byte{1, 2, 3}
+ rep.SendPacket(data, nil)
+ rep.VerifyACKNoSACK()
+ savedSeqNum := rep.NextSeqNum
+
+ // Make an out of order packet and send
+ // it.
+ rep.NextSeqNum += 3
+ rep.SendPacket(data, nil)
+
+ // The ACK should contain the older
+ // sequence number and no SACK blocks.
+ rep.NextSeqNum = savedSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative
+ // ACK for all 9 bytes sent and no SACK
+ // blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned
+ // in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+ })
+ }
+}
+
+func TestUpdateSACKBlocks(t *testing.T) {
+ testCases := []struct {
+ segStart seqnum.Value
+ segEnd seqnum.Value
+ rcvNxt seqnum.Value
+ sackBlocks []header.SACKBlock
+ updated []header.SACKBlock
+ }{
+ // Trivial cases where current SACK block list is empty and we
+ // have an out of order delivery.
+ {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
+ {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
+ {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
+
+ // Cases where current SACK block list is not empty and we have
+ // an out of order delivery. Tests that the updated SACK block
+ // list has the first block as the one that contains the new
+ // SACK block representing the segment that was just delivered.
+ {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
+ {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
+ {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
+
+ // Ensure that we only retain header.MaxSACKBlocks and drop the
+ // oldest one if adding a new block exceeds
+ // header.MaxSACKBlocks.
+ {24, 30, 9,
+ []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
+ []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
+
+ // Cases where segment extends an existing SACK block.
+ {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
+ {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
+ {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
+ {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
+ {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
+ {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
+ {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
+ {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
+
+ // Cases where segment contains rcvNxt.
+ {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
+ }
+
+ for _, tc := range testCases {
+ var sack tcp.SACKInfo
+ copy(sack.Blocks[:], tc.sackBlocks)
+ sack.NumBlocks = len(tc.sackBlocks)
+ tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
+ if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
+ t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
+ }
+
+ }
+}
+
+func TestTrimSackBlockList(t *testing.T) {
+ testCases := []struct {
+ rcvNxt seqnum.Value
+ sackBlocks []header.SACKBlock
+ trimmed []header.SACKBlock
+ }{
+ // Simple cases where we trim whole entries.
+ {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
+ {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
+ {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
+ {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
+ // Cases where we need to update a block.
+ {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
+ {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
+ {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
+ {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
+ }
+ for _, tc := range testCases {
+ var sack tcp.SACKInfo
+ copy(sack.Blocks[:], tc.sackBlocks)
+ sack.NumBlocks = len(tc.sackBlocks)
+ tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
+ if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
+			t.Errorf("TrimSACKBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
+ }
+ }
+}
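+
+// sackReceiverSketch is a minimal sketch, under the assumption of a greatly
+// simplified receive path, of how the two helpers exercised above compose:
+// every arriving segment is folded into the SACK block list, and whenever
+// the cumulative ACK point advances, the list is trimmed of blocks the ACK
+// has made redundant. It is a reading aid, not netstack's actual receive
+// path; newRcvNxt stands for whatever rcvNxt becomes once in-order data
+// (and any queued segments it releases) has been consumed.
+func sackReceiverSketch(sack *tcp.SACKInfo, segStart, segEnd, rcvNxt, newRcvNxt seqnum.Value) {
+	// Record the segment; its block is placed first so that the peer
+	// always sees the most recent information (RFC 2018).
+	tcp.UpdateSACKBlocks(sack, segStart, segEnd, rcvNxt)
+	if rcvNxt.LessThan(newRcvNxt) {
+		// The cumulative ACK advanced: drop blocks below it and clip
+		// any block it lands inside.
+		tcp.TrimSACKBlockList(sack, newRcvNxt)
+	}
+}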
+
+func TestSACKRecovery(t *testing.T) {
+ const maxPayload = 10
+ // See: tcp.makeOptions for why tsOptionSize is set to 12 here.
+ const tsOptionSize = 12
+ // Enabling SACK means the payload size is reduced to account
+ // for the extra space required for the TCP options.
+ //
+ // We increase the MTU by 40 bytes to account for SACK and Timestamp
+ // options.
+ const maxTCPOptionSize = 40
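+	// (40 bytes is also the most option space a TCP header can carry: a
+	// 60-byte maximum header minus the 20-byte fixed header.)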
+
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ // We use log.Printf instead of t.Logf here because this probe
+ // can fire even when the test function has finished. This is
+ // because closing the endpoint in cleanup() does not mean the
+ // actual worker loop terminates immediately as it still has to
+ // do a full TCP shutdown. But this test can finish running
+ // before the shutdown is done. Using t.Logf in such a case
+		// causes the test to panic due to logging after the test has
+		// finished.
+ log.Printf("state: %+v\n", s)
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+	const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+	// Write all the data in one shot. It will still be sent out in
+	// MTU-sized segments.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+		// The wait can't be too long or we'll trigger the sender's
+		// retransmission timer.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Send 3 duplicate acks. This should force an immediate retransmit of
+ // the pending packet and put the sender into fast recovery.
+ rtxOffset := bytesRead - maxPayload*expected
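+	// The SACK block below starts 3 segments (30 bytes) past rtxOffset,
+	// leaving the first 3 segments unSACKed; the +1 accounts for the
+	// sequence number consumed by the SYN.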
+ start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
+ end := start.Add(10)
+ for i := 0; i < 3; i++ {
+ c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+ end = end.Add(10)
+ }
+
+ // Receive the retransmitted packet.
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
+
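+	// The TCP stats are updated asynchronously by the endpoint's protocol
+	// goroutine, so poll for the expected values instead of asserting
+	// once.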
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+
+	// Now send 7 more duplicate ACKs. In SACK TCP dupAcks do not cause
+ // window inflation and sending of packets is completely handled by the
+ // SACK Recovery algorithm. We should see no packets being released, as
+ // the cwnd at this point after entering recovery should be half of the
+ // outstanding number of packets in flight.
+ for i := 0; i < 7; i++ {
+ c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+ end = end.Add(10)
+ }
+
+ recover := bytesRead
+
+ // Ensure no new packets arrive.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge half of the pending data. This along with the 10 sacked
+ // segments above should reduce the outstanding below the current
+ // congestion window allowing the sender to transmit data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+
+ // Now send a partial ACK w/ a SACK block that indicates that the next 3
+ // segments are lost and we have received 6 segments after the lost
+ // segments. This should cause the sender to immediately transmit all 3
+ // segments in response to this ACK unlike in FastRecovery where only 1
+ // segment is retransmitted per ACK.
+ start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
+ end = start.Add(60)
+ c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+
+ // At this point, we acked expected/2 packets and we SACKED 6 packets and
+ // 3 segments were considered lost due to the SACK block we sent.
+ //
+ // So total packets outstanding can be calculated as follows after 7
+ // iterations of slow start -> 10/20/40/80/160/320/640. So expected
+ // should be 640 at start, then we went to recover at which point the
+ // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the
+ // network).
+ // Outstanding at this point after acking half the window
+ // (320 packets) will be:
+ // outstanding = 640-320-6(due to SACK block)-3 = 311
+ //
+ // The last 3 is due to the fact that the first 3 packets after
+ // rtxOffset will be considered lost due to the SACK blocks sent.
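+	//
+	// Restated as a ledger (same assumptions as the comments above):
+	//   cwnd after 7 slow-start rounds:   10 << 6           = 640
+	//   cwnd on entering recovery:        640/2 + 3 dupAcks = 323
+	//   outstanding after partial ACK:    640-320-6-3       = 311
+	//   after retransmitting the 3 lost:  311 + 3           = 314
+	//   new segments that can be sent:    323 - 314         = 9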
+
+	// Receive the retransmit due to the partial ACK.
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
+ // Receive the 2 extra packets that should have been retransmitted as
+ // those should be considered lost and immediately retransmitted based
+ // on the SACK information in the previous ACK sent above.
+ for i := 0; i < 2; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize)
+ }
+
+	// Now we should get 9 more new unsent packets: the cwnd is 323 and
+	// outstanding, counting the 3 retransmits above, is 314.
+ for i := 0; i < 9; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ metricPollFn = func() error {
+ // In SACK recovery only the first segment is fast retransmitted when
+ // entering recovery.
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
+
+ // Acknowledge all pending data to recover point.
+ c.SendAck(790, recover)
+
+ // At this point, the cwnd should reset to expected/2 and there are 9
+ // packets outstanding.
+ //
+	// In the first iteration, since there are 9 packets outstanding, we
+	// expect to get expected/2 - 9 packets. Subsequent iterations will
+	// send us expected/2 + 1 (per iteration).
+ expected = expected/2 - 9
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+ // Check we don't receive any more packets on this iteration.
+		// The wait can't be too long or we'll trigger the sender's
+		// retransmission timer.
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+		// In congestion avoidance, the packet train grows by 1 in
+		// each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 9
+ }
+ expected++
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
new file mode 100644
index 000000000..e67ec42b1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -0,0 +1,7258 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // defaultMTU is the MTU, in bytes, used throughout the tests, except
+ // where another value is explicitly used. It is chosen to match the MTU
+	// of loopback interfaces on Linux systems.
+ defaultMTU = 65535
+
+ // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
+ // IPv4 endpoint when the MTU is set to defaultMTU in the test.
+ defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
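+	// For example, with defaultMTU = 65535 this evaluates to
+	// 65535 - 20 (IPv4 header) - 20 (TCP header) = 65495 bytes.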
+)
+
+func TestGiveUpConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Register for notification, then start connection attempt.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer wq.EventUnregister(&waitEntry)
+
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ }
+
+ // Close the connection, wait for completion.
+ ep.Close()
+
+ // Wait for ep to become writable.
+ <-notifyCh
+ if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+
+	// Call Connect again to retrieve the handshake failure status
+ // and stats updates.
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+
+ if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+}
+
+func TestConnectIncrementActiveConnection(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ActiveConnectionOpenings.Value() + 1
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
+		t.Errorf("got stats.TCP.ActiveConnectionOpenings.Value() = %d, want = %d", got, want)
+ }
+}
+
+func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.FailedConnectionAttempts.Value()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
+ }
+}
+
+func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+ want := stats.TCP.FailedConnectionAttempts.Value() + 1
+
+ if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
+ t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
+ }
+
+ if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
+ }
+}
+
+func TestTCPSegmentsSentIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ // SYN and ACK
+ want := stats.TCP.SegmentsSent.Value() + 2
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ if got := stats.TCP.SegmentsSent.Value(); got != want {
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
+ t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
+ }
+}
+
+func TestTCPResetsSentIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ stats := c.Stack().Stats()
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+	want := stats.TCP.ResetsSent.Value() + 1
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+		// An AckNum beyond what was actually sent (IRS+2 instead of
+		// IRS+1) causes a RST segment to be sent in response.
+ AckNum: c.IRS + 2,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ c.GetPacket()
+
+ metricPollFn := func() error {
+ if got := stats.TCP.ResetsSent.Value(); got != want {
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+
+// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
+// a RST if an ACK is received on the listening socket for which there is no
+// active handshake in progress and we are not using SYN cookies.
+func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Lower stackwide TIME_WAIT timeout so that the reservations
+ // are released instantly on Close.
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+		t.Fatalf("c.Stack().SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ c.GetPacket()
+
+	// Since an active close was done we need to wait for a little more
+	// than the TIME-WAIT timeout set above for the port reservations to
+	// be released and the socket to move to a CLOSED state.
+ time.Sleep(20 * time.Millisecond)
+
+ // Now resend the same ACK, this ACK should generate a RST as there
+ // should be no endpoint in SYN-RCVD state and we are not using
+ // syn-cookies yet. The reason we send the same ACK is we need a valid
+ // cookie(IRS) generated by the netstack without which the ACK will be
+ // rejected.
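+	//
+	// Per RFC 793, the RST sent for a stray ACK takes its sequence number
+	// from the incoming segment's ACK field and carries no ACK of its
+	// own, which is what the checker below expects.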
+ c.SendPacket(nil, ackHeaders)
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPResetsReceivedIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
+ }
+}
+
+func TestTCPResetsDoNotGenerateResets(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
+ }
+ c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
+}
+
+func TestActiveHandshake(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+}
+
+func TestNonBlockingClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ // Close the endpoint and measure how long it takes.
+ t0 := time.Now()
+ ep.Close()
+	if diff := time.Since(t0); diff > 3*time.Second {
+ t.Fatalf("Took too long to close: %s", diff)
+ }
+}
+
+func TestConnectResetAfterClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+	// Set TCPLingerTimeout to 3 seconds so that sockets are marked closed
+	// after 3 seconds in FIN_WAIT2 state.
+ tcpLingerTimeout := 3 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+		t.Fatalf("c.Stack().SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s)) failed: %s", tcpLingerTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ // Close the endpoint, make sure we get a FIN segment, then acknowledge
+ // to complete closure of sender, but don't send our own FIN.
+ ep.Close()
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Wait for the ep to give up waiting for a FIN.
+ time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+ // Now send an ACK and it should trigger a RST as the endpoint should
+ // not exist anymore.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ for {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
+ // This is a retransmit of the FIN, ignore it.
+ continue
+ }
+
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence number
+ // of the FIN itself.
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+ break
+ }
+}
+
+// TestCurrentConnectedIncrement tests increment of the current
+// established and connected counters.
+func TestCurrentConnectedIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+	// Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+		t.Fatalf("c.Stack().SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeoutOption(%s)) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
+ }
+ gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
+ if gotConnected != 1 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
+ }
+
+ ep.Close()
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
+ }
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments are
+// re-enqueued to any listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Send a FIN for ESTABLISHED --> CLOSED-WAIT
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Get the ACK for the FIN we sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack a few ms to transition the endpoint out of ESTABLISHED
+ // state.
+ time.Sleep(10 * time.Millisecond)
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
+ ep.Close()
+
+ // Get the FIN
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+	// Pause the endpoint's protocolMainLoop.
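+	// (The anonymous interface assertion keeps the test decoupled from
+	// the concrete endpoint type.)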
+ ep.(interface{ StopWork() }).StopWork()
+
+ // Enqueue last ACK followed by an ACK matching the endpoint
+ //
+ // Send Last ACK for LAST_ACK --> CLOSED
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+	// Send a packet with the ACK flag set; this would generate a RST
+	// when not using SYN cookies, as in this test.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 792,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+	// Unpause the endpoint's protocolMainLoop.
+ ep.(interface{ ResumeWork() }).ResumeWork()
+
+ // Wait for the protocolMainLoop to resume and update state.
+ time.Sleep(10 * time.Millisecond)
+
+ // Expect the endpoint to be closed.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+
+	// Check that the endpoint was moved to CLOSED and that netstack sent
+	// a reset in response to the ACK packet that we sent after the last
+	// ACK.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+}
+
+func TestSimpleReceive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Receive data.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
+ }
+
+ // Check that ACK is received.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
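+
+// blockingRead is a sketch of the read pattern the tests in this file
+// repeat inline: retry Read until data arrives, using the waiter queue for
+// readability notifications. It is a reading aid written against the same
+// test context the tests use, not a helper the tests actually call; the
+// name and timeout handling are this sketch's own.
+func blockingRead(c *context.Context, timeout time.Duration) (buffer.View, *tcpip.Error) {
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	for {
+		v, _, err := c.EP.Read(nil)
+		if err != tcpip.ErrWouldBlock {
+			// Either data was read or a hard error occurred.
+			return v, err
+		}
+		select {
+		case <-ch:
+			// The endpoint became readable; retry the read.
+		case <-time.After(timeout):
+			return nil, tcpip.ErrTimeout
+		}
+	}
+}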
+
+// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
+// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS int
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ // Set the MSS socket option.
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv4 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
+// creating a new active IPv6 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Set the MSS socket option.
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv6 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv4.
+func TestTCPAckBeforeAcceptV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
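+	// The expected ACK number is irs+5: irs+1 for the SYN plus the 4
+	// payload bytes sent above.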
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv6.
+func TestTCPAckBeforeAcceptV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestSendRstOnListenerRxAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true /* v6Only */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestListenShutdown tests that a listening endpoint replies with a RST
+// after a read shutdown.
+func TestListenShutdown(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatal("Shutdown failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+// TestListenCloseWhileConnect tests that a listening endpoint drains its
+// accept queue when closed. This should reset all of the pending
+// connections that are waiting to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created because of handshake to be delivered
+ // to the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+func TestTOSV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
+ }
+
+ v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+		t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
+ }
+
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
+ }
+
+ testV4Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestTrafficClassV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
+		t.Errorf("SetSockOptInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
+ }
+
+ v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+		t.Fatalf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
+ }
+
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetV6Packet()
+ checker.IPv6(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device tcpip.NICID
+ want tcp.EndpointState
+ }{
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ })
+ }
+}
+
+func TestRstOnSynSent(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
+
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ // Ensure that we've reached SynSent state
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	// Send a packet with a proper ACK and a RST flag to cause the socket
+	// to error out and close.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+}
+
+func TestOutOfOrderReceive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send second half of data first, with seqnum 3 ahead of expected.
+ data := []byte{1, 2, 3, 4, 5, 6}
+ c.SendPacket(data[3:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 793,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Check that we get an ACK specifying which seqnum is expected.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Wait 200ms and check that no data has been received.
+ time.Sleep(200 * time.Millisecond)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send the first 3 bytes now.
+ c.SendPacket(data[:3], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive data.
+ read := make([]byte, 0, 6)
+ for len(read) < len(data) {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+ continue
+ }
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ read = append(read, v...)
+ }
+
+ // Check that we received the data in proper order.
+ if !bytes.Equal(data, read) {
+ t.Fatalf("got data = %v, want = %v", read, data)
+ }
+
+ // Check that the whole data is acknowledged.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestOutOfOrderFlood(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create a new connection with initial window size of 10.
+ c.CreateConnected(789, 30000, 10)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send 100 packets before the actual one that is expected.
+ data := []byte{1, 2, 3, 4, 5, 6}
+ for i := 0; i < 100; i++ {
+ c.SendPacket(data[3:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 796,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Send packet with seqnum 793. It must be discarded because the
+ // out-of-order buffer was filled by the previous packets.
+ c.SendPacket(data[3:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 793,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Now send the expected packet, seqnum 790.
+ c.SendPacket(data[:3], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+	// Check that only the in-order packet (seqnum 790) is acknowledged,
+	// i.e. the cumulative ACK advances to 793 only.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(793),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestRstOnCloseWithUnreadData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that ACK is received, this happens regardless of the read.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Now that we know we have unread data, let's just close the connection
+ // and verify that netstack sends an RST rather than a FIN.
+ c.EP.Close()
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // We shouldn't consume a sequence number on RST.
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // This final ACK should be ignored because an ACK on a reset doesn't mean
+ // anything.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
+func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that ACK is received, this happens regardless of the read.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Cause a FIN to be generated.
+ c.EP.Shutdown(tcpip.ShutdownWrite)
+
+	// Make sure we get the FIN but DON'T ACK it.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Cause a RST to be generated by closing the read end now since we have
+ // unread data.
+ c.EP.Shutdown(tcpip.ShutdownRead)
+
+ // Make sure we get the RST.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // The RST is always generated with sndNxt, which, if a FIN has
+ // been sent, is 1 higher than the sequence number of the FIN
+ // itself.
+ checker.SeqNum(uint32(c.IRS)+2),
+ ))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // The ACK to the FIN should now be rejected since the connection has been
+ // closed by an RST.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
+func TestShutdownRead(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ }
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
+ }
+}
+
+func TestFullWindowReceive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, 10)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Fill up the window.
+ data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that data is acknowledged, and window goes to zero.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(0),
+ ),
+ )
+
+ // Receive data and check it.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
+ }
+
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
+ }
+
+ // Check that we get an ACK for the newly non-zero window.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(10),
+ ),
+ )
+}
+
+func TestNoWindowShrinking(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Start off with a window size of 10, then shrink it to 5.
+ c.CreateConnected(789, 30000, 10)
+
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
+ }
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send 3 bytes, check that the endpoint acknowledges them.
+ data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ c.SendPacket(data[:3], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that the data is acknowledged and that the window doesn't go to
+ // zero just yet: the previously advertised window of 10 cannot shrink, so
+ // with 3 bytes unread it must now be 10 - 3 = 7.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(793),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(7),
+ ),
+ )
+
+ // Send 7 more bytes, check that the window fills up.
+ c.SendPacket(data[3:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 793,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(0),
+ ),
+ )
+
+ // Receive data and check it.
+ read := make([]byte, 0, 10)
+ for len(read) < len(data) {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ read = append(read, v...)
+ }
+
+ if !bytes.Equal(data, read) {
+ t.Fatalf("got data = %v, want = %v", read, data)
+ }
+
+ // Check that we get an ACK for the newly non-zero window, which now
+ // reflects the shrunk buffer size.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(5),
+ ),
+ )
+}
+
+func TestSimpleSend(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
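+ // (A note on the flag check below: TCPFlagsMatch compares the flags
+ // with the PSH bit masked out, so the segment passes whether or not PSH
+ // is set, as long as the remaining flags are exactly ACK. The same
+ // idiom recurs throughout these tests.)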
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Fatalf("got data = %v, want = %v", p, data)
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
+ RcvWnd: 30000,
+ })
+}
+
+func TestZeroWindowSend(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that we get a zero-window probe.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Open up the window. Data should be received now.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Check that data is received.
+ b = c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Fatalf("got data = %v, want = %v", p, data)
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
+ RcvWnd: 30000,
+ })
+}
+
+func TestScaledWindowConnect(t *testing.T) {
+ // This test ensures that window scaling is used when the peer
+ // does advertise it and the connection is established with Connect().
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set the window size greater than the maximum non-scaled window.
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
+ header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+ })
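+ // A sanity check on the numbers below: the raw bytes above advertise
+ // the peer's window scale option (kind 3, length 3, shift 0), which
+ // enables scaling on the connection. Assuming the endpoint picks a
+ // scale of 2 for its 65535*3 = 196605 byte buffer, the advertised
+ // window should be 196605 >> 2 = 49151 = 0xbfff, which is what the
+ // checker below expects.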
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received, and that advertised window is 0xbfff,
+ // that is, that it is scaled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(0xbfff),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+}
+
+func TestNonScaledWindowConnect(t *testing.T) {
+ // This test ensures that window scaling is not used when the peer
+ // doesn't advertise it and the connection is established with Connect().
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set the window size greater than the maximum non-scaled window.
+ c.CreateConnected(789, 30000, 65535*3)
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received, and that advertised window is 0xffff,
+ // that is, that it's not scaled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(0xffff),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+}
+
+func TestScaledWindowAccept(t *testing.T) {
+ // This test ensures that window scaling is used when the peer
+ // does advertise it and the connection is established with Accept().
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ // Set the window size greater than the maximum non-scaled window.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake.
+ c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received, and that advertised window is 0xbfff,
+ // that is, that it is scaled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(0xbfff),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+}
+
+func TestNonScaledWindowAccept(t *testing.T) {
+ // This test ensures that window scaling is not used when the peer
+ // doesn't advertise it and the connection is established with Accept().
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ // Set the window size greater than the maximum non-scaled window.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
+ // should not carry the window scaling option.
+ c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received, and that advertised window is 0xffff,
+ // that is, that it's not scaled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(0xffff),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+}
+
+func TestZeroScaledWindowReceive(t *testing.T) {
+ // This test ensures that the endpoint sends a non-zero window size
+ // advertisement when the scaled window transitions from 0 to non-zero,
+ // while the actual (unscaled) window hasn't reached zero.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set the window size such that a window scale of 4 will be used.
+ const wnd = 65535 * 10
+ const ws = uint32(4)
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
+ header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+ })
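+ // Why a scale of 4: wnd = 65535*10 = 655350 does not fit in the 16-bit
+ // window field until it is shifted right by 4 (655350 >> 4 = 40959,
+ // while 655350 >> 3 = 81918 still overflows 65535).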
+
+ // Send chunks of 50000 bytes.
+ remain := wnd
+ sent := 0
+ data := make([]byte, 50000)
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(remain>>ws)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Make the window non-zero, but the scaled window zero.
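+ // With ws = 4, any window below 1<<4 = 16 bytes is advertised as zero
+ // (remain >> 4 == 0), so leaving exactly 15 bytes of window keeps the
+ // real window non-zero while the scaled advertisement drops to 0.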
+ if remain >= 16 {
+ data = data[:remain-15]
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(0),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Read at least 1 MSS of data. An ACK should be sent in response.
+ sz := 0
+ for sz < defaultMTU {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ sz += len(v)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(sz>>ws)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestSegmentMerging(t *testing.T) {
+ tests := []struct {
+ name string
+ stop func(tcpip.Endpoint)
+ resume func(tcpip.Endpoint)
+ }{
+ {
+ "stop work",
+ func(ep tcpip.Endpoint) {
+ ep.(interface{ StopWork() }).StopWork()
+ },
+ func(ep tcpip.Endpoint) {
+ ep.(interface{ ResumeWork() }).ResumeWork()
+ },
+ },
+ {
+ "cork",
+ func(ep tcpip.Endpoint) {
+ ep.SetSockOptBool(tcpip.CorkOption, true)
+ },
+ func(ep tcpip.Endpoint) {
+ ep.SetSockOptBool(tcpip.CorkOption, false)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Send tcp.InitialCwnd number of segments to fill up InitialWindow
+ // but don't ACK them. That should prevent any more packets from
+ // going out.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ view := buffer.NewViewFromBytes([]byte{0})
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+
+ // Now write the segments that should get merged, as the congestion
+ // window is full and no more packets can go out.
+ var allData []byte
+ for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
+ allData = append(allData, data...)
+ view := buffer.NewViewFromBytes(data)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+
+ // Check that we get tcp.InitialCwnd packets.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize+1),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
+ RcvWnd: 30000,
+ })
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(allData)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+11),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
+ t.Fatalf("got data = %v, want = %v", got, allData)
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
+ RcvWnd: 30000,
+ })
+ })
+ }
+}
+
+func TestDelay(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
+
+ var allData []byte
+ for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
+ allData = append(allData, data...)
+ view := buffer.NewViewFromBytes(data)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+
+ seq := c.IRS.Add(1)
+ for _, want := range [][]byte{allData[:1], allData[1:]} {
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(want)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(seq)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) {
+ t.Fatalf("got data = %v, want = %v", got, want)
+ }
+
+ seq = seq.Add(seqnum.Size(len(want)))
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seq,
+ RcvWnd: 30000,
+ })
+ }
+}
+
+func TestUndelay(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
+
+ allData := [][]byte{{0}, {1, 2, 3}}
+ for i, data := range allData {
+ view := buffer.NewViewFromBytes(data)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+
+ seq := c.IRS.Add(1)
+
+ // Check that data is received.
+ first := c.GetPacket()
+ checker.IPv4(t, first,
+ checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(seq)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) {
+ t.Fatalf("got first packet's data = %v, want = %v", got, want)
+ }
+
+ seq = seq.Add(seqnum.Size(len(allData[0])))
+
+ // Check that we don't get the second packet yet.
+ c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
+
+ c.EP.SetSockOptBool(tcpip.DelayOption, false)
+
+ // Check that data is received.
+ second := c.GetPacket()
+ checker.IPv4(t, second,
+ checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(seq)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) {
+ t.Fatalf("got second packet's data = %v, want = %v", got, want)
+ }
+
+ seq = seq.Add(seqnum.Size(len(allData[1])))
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seq,
+ RcvWnd: 30000,
+ })
+}
+
+func TestMSSNotDelayed(t *testing.T) {
+ tests := []struct {
+ name string
+ fn func(tcpip.Endpoint)
+ }{
+ {"no-op", func(tcpip.Endpoint) {}},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ })
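+ // The raw bytes above encode the MSS option: kind 2 (TCPOptionMSS),
+ // length 4, then the 16-bit MSS in network byte order, i.e. bytes
+ // (0, 100) for maxPayload = 100. This caps each segment's payload at
+ // 100 bytes.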
+
+ test.fn(c.EP)
+
+ allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
+ for i, data := range allData {
+ view := buffer.NewViewFromBytes(data)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
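+ // The point of the test: even with delay or cork enabled, full-MSS
+ // segments should not be held back indefinitely. The loop below checks
+ // that every write above makes it onto the wire without the peer
+ // ACKing anything.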
+
+ seq := c.IRS.Add(1)
+
+ for i, data := range allData {
+ // Check that data is received.
+ packet := c.GetPacket()
+ checker.IPv4(t, packet,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(seq)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) {
+ t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want)
+ }
+
+ seq = seq.Add(seqnum.Size(len(data)))
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seq,
+ RcvWnd: 30000,
+ })
+ })
+ }
+}
+
+func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
+ payloadMultiplier := 10
+ dataLen := payloadMultiplier * maxPayload
+ data := make([]byte, dataLen)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received in chunks.
+ bytesReceived := 0
+ numPackets := 0
+ for bytesReceived != dataLen {
+ b := c.GetPacket()
+ numPackets++
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ payloadLen := len(tcpHdr.Payload())
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ pdata := data[bytesReceived : bytesReceived+payloadLen]
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
+ }
+ bytesReceived += payloadLen
+ var options []byte
+ if c.TimeStampEnabled {
+ // If timestamp option is enabled, echo back the timestamp and increment
+ // the TSEcr value included in the packet and send that back as the TSVal.
+ parsedOpts := tcpHdr.ParsedOptions()
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
+ options = tsOpt[:]
+ }
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
+ RcvWnd: 30000,
+ TCPOpts: options,
+ })
+ }
+ if numPackets == 1 {
+ t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
+ }
+}
+
+func TestSendGreaterThanMTU(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSetTTL(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := context.New(t, 65535)
+ defer c.Cleanup()
+
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
+ }
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+
+ checker.IPv4(t, b, checker.TTL(wantTTL))
+ })
+ }
+}
+
+func TestActiveSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, 65535)
+ defer c.Cleanup()
+
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ })
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestPassiveSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 100
+ const mtu = 1200
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ // Set the buffer size to a deterministic value so that we can check
+ // the window scaling option.
+ const rcvBufferSize = 0x20000
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake.
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Check that data gets properly segmented.
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 536
+ const mtu = 2000
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Set the SynRcvd threshold to zero to force a SYN cookie based
+ // accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake.
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Check that data gets properly segmented.
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestForwarderSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 100
+ const mtu = 1200
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ ch := make(chan *tcpip.Error, 1)
+ f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
+ var err *tcpip.Error
+ c.EP, err = r.CreateEndpoint(&c.WQ)
+ ch <- err
+ })
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
+
+ // Do 3-way handshake.
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+ // Wait for connection to be available.
+ select {
+ case err := <-ch:
+ if err != nil {
+ t.Fatalf("Error creating endpoint: %s", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+
+ // Check that data gets properly segmented.
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSynOptionsOnActiveConnect(t *testing.T) {
+ const mtu = 1400
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Set the buffer size to a deterministic value so that we can check
+ // the window scaling option.
+ const rcvBufferSize = 0x20000
+ const wndScale = 2
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
+ }
+
+ // Start connection attempt.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventOut)
+ defer c.WQ.EventUnregister(&we)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize)
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
+ ),
+ )
+
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ // Wait for retransmit.
+ time.Sleep(1 * time.Second)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.SrcPort(tcpHdr.SourcePort()),
+ checker.SeqNum(tcpHdr.SequenceNumber()),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
+ ),
+ )
+
+ // Send SYN-ACK.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK packet.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ),
+ )
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ t.Fatalf("GetSockOpt failed: %s", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+}
+
+func TestCloseListener(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create listener.
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Close the listener and measure how long it takes.
+ t0 := time.Now()
+ ep.Close()
+ if diff := time.Since(t0); diff > 3*time.Second {
+ t.Fatalf("Took too long to close: %s", diff)
+ }
+}
+
+func TestReceiveOnResetConnection(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Send RST segment.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagRst,
+ SeqNum: 790,
+ RcvWnd: 30000,
+ })
+
+ // Try to read.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+loop:
+ for {
+ switch _, _, err := c.EP.Read(nil); err {
+ case tcpip.ErrWouldBlock:
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for reset to arrive")
+ }
+ case tcpip.ErrConnectionReset:
+ break loop
+ default:
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ }
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got EP state = %s, want = %s", got, want)
+ }
+
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+func TestSendOnResetConnection(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Send RST segment.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagRst,
+ SeqNum: 790,
+ RcvWnd: 30000,
+ })
+
+ // Wait for the RST to be received.
+ time.Sleep(1 * time.Second)
+
+ // Try to write.
+ view := buffer.NewView(10)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+}
+
+// TestMaxRetransmitsTimeout tests that the connection times out after
+// a segment has been retransmitted MaxRetries times.
+func TestMaxRetransmitsTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const numRetries = 2
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil {
+ t.Fatalf("could not set protocol option MaxRetries.\n")
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Expect first transmit and MaxRetries retransmits.
+ for i := 0; i < numRetries+1; i++ {
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
+ ),
+ )
+ }
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := 1 * time.Second
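+ // The retransmit interval doubles on each attempt; assuming the stack's
+ // initial RTO matches initRTO above, the cumulative wait before the
+ // connection is declared dead is roughly initRTO * (1 + 2 + 4) = 7s for
+ // numRetries = 2, so (2 << numRetries) * initRTO = 8s bounds it with a
+ // little headroom.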
+ select {
+ case <-notifyCh:
+ case <-time.After((2 << numRetries) * initRTO):
+ t.Fatalf("connection still alive after maximum retransmits.\n")
+ }
+
+ // Send an ACK and expect a RST as the connection would have been closed.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+// TestMaxRTO tests that the retransmit interval is capped at MaxRTO.
+func TestMaxRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ rto := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err)
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ const numRetransmits = 2
+ for i := 0; i < numRetransmits; i++ {
+ start := time.Now()
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ }
+ }
+}
+
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
+func TestFinImmediately(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Shutdown immediately, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestFinRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Shutdown immediately, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Don't acknowledge yet. We should get a retransmit of the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestFinWithNoPendingData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Write something out, and have it acknowledged.
+ view := buffer.NewView(10)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ // Shutdown, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ next++
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestFinWithPendingDataCwndFull(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Write enough segments to fill the congestion window before ACK'ing
+ // any of them.
+ view := buffer.NewView(10)
+ for i := tcp.InitialCwnd; i > 0; i-- {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ }
+
+ next := uint32(c.IRS) + 1
+ for i := tcp.InitialCwnd; i > 0; i-- {
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+ }
+
+ // Shutdown the connection, check that the FIN segment isn't sent
+ // because the congestion window doesn't allow it. Wait until a
+ // retransmit is received.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Send the ACK that will allow the FIN to be sent as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ next++
+
+ // Send a FIN that acknowledges everything. Get an ACK back.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestFinWithPendingData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Write something out, and acknowledge it to get cwnd to 2.
+ view := buffer.NewView(10)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ // Write new data, but don't acknowledge it.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+
+ // Shutdown the connection, check that we do get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ next++
+
+ // Send a FIN that acknowledges everything. Get an ACK back.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestFinWithPartialAck(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Write something out, and acknowledge it to get cwnd to 2. Also send
+ // FIN from the test side.
+ view := buffer.NewView(10)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ // Check that we get an ACK for the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Write new data, but don't acknowledge it.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+
+ // Shutdown the connection, check that we do get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ next++
+
+ // Send an ACK for the data, but not for the FIN yet.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: seqnum.Value(next - 1),
+ RcvWnd: 30000,
+ })
+
+ // Check that we don't get a retransmit of the FIN.
+ c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond)
+
+ // Ack the FIN.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 791,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+}
+
+func TestUpdateListenBacklog(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create listener.
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Update the backlog with another Listen() on the same endpoint.
+ if err := ep.Listen(20); err != nil {
+ t.Fatalf("Listen failed to update backlog: %s", err)
+ }
+
+ ep.Close()
+}
+
+func scaledSendWindow(t *testing.T, scale uint8) {
+ // This test verifies that the endpoint uses the correct send window
+ // scaling: it writes a buffer larger than the window and checks that
+ // the endpoint doesn't send more than the scaled window allows.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
+ })
+
+ // Open up the window with a scaled value.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 1,
+ })
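+ // With the window scale negotiated above, the advertised RcvWnd of 1
+ // really grants 1 << scale bytes of send window, which is why the
+ // checker below expects a payload of exactly 1 << scale bytes.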
+
+ // Send some data. Check that it's capped by the window size.
+ view := buffer.NewView(65535)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that only data that fits in the scaled window is sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Reset the connection to free resources.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagRst,
+ SeqNum: 790,
+ })
+}
+
+func TestScaledSendWindow(t *testing.T) {
+ for scale := uint8(0); scale <= 14; scale++ {
+ scaledSendWindow(t, scale)
+ }
+}
+
+func TestReceivedValidSegmentCountIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ stats := c.Stack().Stats()
+ want := stats.TCP.ValidSegmentsReceived.Value() + 1
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
+ t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
+ }
+ // Ensure there were no errors during the handshake. Had these stats
+ // incremented, the connection would not have been established.
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
+ }
+}
+
+func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ stats := c.Stack().Stats()
+ want := stats.TCP.InvalidSegmentsReceived.Value() + 1
+ vv := c.BuildSegment(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
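+ // The data offset lives in the upper nibble and is counted in 4-byte
+ // words; claiming a header shorter than the 20-byte minimum makes the
+ // segment malformed.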
+ tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
+
+ c.SendSegment(vv)
+
+ if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+}
+
+func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ stats := c.Stack().Stats()
+ want := stats.TCP.ChecksumErrors.Value() + 1
+ vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
+ // Overwrite the first payload byte (the data offset field gives the
+ // header length in 4-byte words), which should cause checksum
+ // verification to fail.
+ tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
+
+ c.SendSegment(vv)
+
+ if got := stats.TCP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
+ }
+}
+
+func TestReceivedSegmentQueuing(t *testing.T) {
+ // This test sends 200 segments containing a few bytes each to an
+ // endpoint and checks that they're all received and acknowledged by
+ // the endpoint, that is, that none of the segments are dropped by
+ // internal queues.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Send 200 segments.
+ data := []byte{1, 2, 3}
+ for i := 0; i < 200; i++ {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + i*len(data)),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive ACKs for all segments.
+ last := seqnum.Value(790 + 200*len(data))
+ for {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ ack := seqnum.Value(tcpHdr.AckNumber())
+ if ack == last {
+ break
+ }
+
+ if last.LessThan(ack) {
+ t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
+ }
+ }
+}
+
+func TestReadAfterClosedState(t *testing.T) {
+ // This test ensures that calling Read() or Peek() after the endpoint
+ // has transitioned to the CLOSED state still works if there is pending
+ // data. To transition to CLOSED without calling Close(), we must
+ // shutdown the send path and the peer must send its own FIN.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Shutdown immediately for write, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Send some data and acknowledge the FIN.
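+ // AckNum of c.IRS.Add(2) acknowledges both the endpoint's SYN and its
+ // FIN, each of which consumes one sequence number.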
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that ACK is received.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(uint32(791+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack the chance to transition to closed state from
+ // TIME_WAIT.
+ time.Sleep(tcpTimeWaitTimeout * 2)
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that peek works.
+ peekBuf := make([]byte, 10)
+ n, _, err := c.EP.Peek([][]byte{peekBuf})
+ if err != nil {
+ t.Fatalf("Peek failed: %s", err)
+ }
+
+ peekBuf = peekBuf[:n]
+ if !bytes.Equal(data, peekBuf) {
+ t.Fatalf("got data = %v, want = %v", peekBuf, data)
+ }
+
+ // Receive data.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
+ }
+
+ // Now that we drained the queue, check that functions fail with the
+ // right error code.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ }
+
+ if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ }
+}
+
+func TestReusePort(t *testing.T) {
+ // This test ensures that ports are immediately available for reuse
+ // after Close on the endpoints using them returns.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // First case, just an endpoint that was bound.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ c.EP.Close()
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ c.EP.Close()
+
+ // Second case, an endpoint that was bound and is connecting.
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ }
+ c.EP.Close()
+
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ c.EP.Close()
+
+ // Third case, an endpoint that was bound and is listening.
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+ c.EP.Close()
+
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+}
+
+func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
+ t.Helper()
+
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt failed: %s", err)
+ }
+
+ if int(s) != v {
+ t.Fatalf("got receive buffer size = %d, want = %d", s, v)
+ }
+}
+
+func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
+ t.Helper()
+
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt failed: %s", err)
+ }
+
+ if int(s) != v {
+ t.Fatalf("got send buffer size = %d, want = %d", s, v)
+ }
+}
+
+func TestDefaultBufferSizes(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
+
+ // Check the default values.
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ defer func() {
+ if ep != nil {
+ ep.Close()
+ }
+ }()
+
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
+
+ // Change the default send buffer size.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ ep.Close()
+ ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
+
+ // Change the default receive buffer size.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ ep.Close()
+ ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
+}
+
+func TestMinMaxBufferSizes(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
+
+ // Check the default values.
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ defer ep.Close()
+
+ // Change the min/max values for send/receive
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Set values below the min.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
+ }
+
+ checkRecvBufferSize(t, ep, 200)
+
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
+ }
+
+ checkSendBufferSize(t, ep, 300)
+
+ // Set values above the max.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
+ }
+
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
+
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
+}
+
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %s", err)
+ }
+
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
+ // a compiler error.
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *tcpip.NICID
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ }
+ }
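+ // Start from a sentinel value that GetSockOpt must overwrite for the
+ // check below to be meaningful.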
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %s, want %v", err, nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %d, want %d", got, want)
+ }
+ })
+ }
+}
+
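+// makeStack builds a dual-stack (IPv4/IPv6) stack with a single loopback
+// NIC (wrapped in a sniffer when verbose), one address per protocol, and
+// default routes for both families.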
+func makeStack() (*stack.Stack, *tcpip.Error) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
+
+ id := loopback.New()
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+
+ if err := s.CreateNIC(1, id); err != nil {
+ return nil, err
+ }
+
+ for _, ct := range []struct {
+ number tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ }{
+ {ipv4.ProtocolNumber, context.StackAddr},
+ {ipv6.ProtocolNumber, context.StackV6Addr},
+ } {
+ if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ return nil, err
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return s, nil
+}
+
+func TestSelfConnect(t *testing.T) {
+ // This test ensures that intentional self-connects work. In particular,
+ // it checks that if an endpoint binds to say 127.0.0.1:1000 then
+ // connects to 127.0.0.1:1000, then it will be connected to itself, and
+ // is able to send and receive data through the same endpoint.
+ s, err := makeStack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Register for notification, then start connection attempt.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer wq.EventUnregister(&waitEntry)
+
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ }
+
+ <-notifyCh
+ if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ t.Fatalf("Connect failed: %s", err)
+ }
+
+ // Write something.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+ if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Read back what was written.
+ wq.EventUnregister(&waitEntry)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ rd, _, err := ep.Read(nil)
+ if err != nil {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("Read failed: %s", err)
+ }
+ <-notifyCh
+ rd, _, err = ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ }
+
+ if !bytes.Equal(data, rd) {
+ t.Fatalf("got data = %v, want = %v", rd, data)
+ }
+}
+
+func TestConnectAvoidsBoundPorts(t *testing.T) {
+ addressTypes := func(t *testing.T, network string) []string {
+ switch network {
+ case "ipv4":
+ return []string{"v4"}
+ case "ipv6":
+ return []string{"v6"}
+ case "dual":
+ return []string{"v6", "mapped"}
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+
+ panic("unreachable")
+ }
+
+ address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
+ switch addressType {
+ case "v4":
+ if isAny {
+ return ""
+ }
+ return context.StackAddr
+ case "v6":
+ if isAny {
+ return ""
+ }
+ return context.StackV6Addr
+ case "mapped":
+ if isAny {
+ return context.V4MappedWildcardAddr
+ }
+ return context.StackV4MappedAddr
+ default:
+ t.Fatalf("unknown address type: '%s'", addressType)
+ }
+
+ panic("unreachable")
+ }
+ // This test ensures that Endpoint.Connect doesn't select already-bound ports.
+ networks := []string{"ipv4", "ipv6", "dual"}
+ for _, exhaustedNetwork := range networks {
+ t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
+ for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
+ t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
+ for _, isAny := range []bool{false, true} {
+ t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
+ for _, candidateNetwork := range networks {
+ t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
+ for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
+ t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
+ s, err := makeStack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var wq waiter.Queue
+ var eps []tcpip.Endpoint
+ defer func() {
+ for _, ep := range eps {
+ ep.Close()
+ }
+ }()
+ makeEP := func(network string) tcpip.Endpoint {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch network {
+ case "ipv4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "ipv6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ eps = append(eps, ep)
+ switch network {
+ case "ipv4":
+ case "ipv6":
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
+ }
+ case "dual":
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
+ t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
+ }
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+ return ep
+ }
+
+ var v4reserved, v6reserved bool
+ switch exhaustedAddressType {
+ case "v4", "mapped":
+ v4reserved = true
+ case "v6":
+ v6reserved = true
+ // Dual stack sockets bound to v6 any reserve on v4 as
+ // well.
+ if isAny {
+ switch exhaustedNetwork {
+ case "ipv6":
+ case "dual":
+ v4reserved = true
+ default:
+ t.Fatalf("unknown address type: '%s'", exhaustedNetwork)
+ }
+ }
+ default:
+ t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
+ }
+ var collides bool
+ switch candidateAddressType {
+ case "v4", "mapped":
+ collides = v4reserved
+ case "v6":
+ collides = v6reserved
+ default:
+ t.Fatalf("unknown address type: '%s'", candidateAddressType)
+ }
+
+ for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
+ if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
+ t.Fatalf("Bind(%d) failed: %s", i, err)
+ }
+ }
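+ // Every ephemeral port is now bound on the exhausted network, so a
+ // colliding Connect has no local port left to pick.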
+ want := tcpip.ErrConnectStarted
+ if collides {
+ want = tcpip.ErrNoPortAvailable
+ }
+ if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
+ t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestPathMTUDiscovery(t *testing.T) {
+ // This test verifies the stack retransmits packets after it receives an
+ // ICMP packet indicating that the path MTU has been exceeded.
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create new connection with MSS of 1460.
+ const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ })
+
+ // Send 3200 bytes of data.
+ const writeSize = 3200
+ data := buffer.NewView(writeSize)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
+ var ret []byte
+ for i, size := range sizes {
+ p := c.GetPacket()
+ if i == which {
+ ret = p
+ }
+ checker.IPv4(t, p,
+ checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(seqNum),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ seqNum += uint32(size)
+ }
+ return ret
+ }
+
+ // Receive three packets.
+ sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
+ first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
+
+ // Send "packet too big" messages back to netstack.
+ const newMTU = 1200
+ const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
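+ // In an ICMPv4 fragmentation-needed message, the next-hop MTU occupies
+ // the last two bytes of the otherwise-unused field (RFC 1191).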
+ mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
+ c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
+
+ // Receive the retransmitted packets; none should exceed the new max.
+ sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
+ receivePackets(c, sizes, -1, uint32(c.IRS)+1)
+}
+
+func TestTCPEndpointProbe(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ invoked := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint ID is what we expect.
+ //
+ // We don't do an extensive validation of every field, just a
+ // basic sanity check.
+ if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
+ t.Fatalf("got LocalAddress: %q, want: %q", got, want)
+ }
+ if got, want := state.ID.LocalPort, c.Port; got != want {
+ t.Fatalf("got LocalPort: %d, want: %d", got, want)
+ }
+ if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
+ t.Fatalf("got RemoteAddress: %q, want: %q", got, want)
+ }
+ if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
+ t.Fatalf("got RemotePort: %d, want: %d", got, want)
+ }
+
+ invoked <- struct{}{}
+ })
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ select {
+ case <-invoked:
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("TCP Probe function was not called")
+ }
+}
+
+func TestStackSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ var oldCC tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetTransportProtocolOption is expected to succeed
+ // then the returned value for congestion control should
+ // match the one specified in the
+ // SetTransportProtocolOption call above, else it should
+ // be what it was before the call to
+ // SetTransportProtocolOption.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
+}
+
+func TestStackAvailableCongestionControl(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ // Query permitted congestion control algorithms.
+ var aCC tcpip.AvailableCongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
+ }
+ if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestStackSetAvailableCongestionControl(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ // Setting AvailableCongestionControlOption should fail.
+ aCC := tcpip.AvailableCongestionControlOption("xyz")
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
+ }
+
+ // Verify that we still get the expected list of congestion control options.
+ var cc tcpip.AvailableCongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ }
+ if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err)
+ }
+
+ if connected {
+ c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+ t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
+ }
+}
+
+func enableCUBIC(t *testing.T, c *context.Context) {
+ t.Helper()
+ opt := tcpip.CongestionControlOption("cubic")
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
+ }
+}
+
+func TestKeepalive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ const keepAliveInterval = 3 * time.Second
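+ // Configure keepalive: first probe after 100ms of idle time, subsequent
+ // probes every keepAliveInterval, reset after 5 unacked probes.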
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+ // The endpoint is configured to reset the connection after 5 unacked
+ // keepalives. ACK each probe and check that the connection stays
+ // alive well past that limit.
+ for i := 0; i < 10; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Acknowledge the keepalive.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send some data and wait before ACKing it. Keepalives should be disabled
+ // during this period.
+ view := buffer.NewView(3)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for the packet to be retransmitted. Verify that no keepalives
+ // were sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
+ ),
+ )
+ c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
+
+ next += uint32(len(view))
+
+ // Send ACK. Keepalives should start sending again.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ // Now receive 5 keepalives, but don't ACK them. The connection
+ // should be reset after 5.
+ for i := 0; i < 5; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next-1)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Sleep for a little over the keepalive interval to make sure the
+ // timer has time to fire after the last ACK and close the socket.
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
+
+ // The connection should be terminated after 5 unacked keepalives.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
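+// executeHandshake drives a complete three-way handshake over IPv4 from the
+// given source port against the listening endpoint and returns the initial
+// sequence numbers of the remote test side (irs) and of netstack (iss).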
+func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ // Send a SYN request.
+ irs = seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss = seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+
+ if synCookieInUse {
+ // When cookies are in use, window scaling is disabled.
+ tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+ WS: -1,
+ MSS: c.MSSWithoutOptions(),
+ }))
+ }
+
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ return irs, iss
+}
+
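+// executeV6Handshake is the IPv6 counterpart of executeHandshake.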
+func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ // Send a SYN request.
+ irs = seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetV6Packet()
+ tcp := header.TCP(header.IPv6(b).Payload())
+ iss = seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+
+ if synCookieInUse {
+ // When cookies are in use, window scaling is disabled.
+ tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+ WS: -1,
+ MSS: c.MSSWithoutOptionsV6(),
+ }))
+ }
+
+ checker.IPv6(t, b, checker.TCP(tcpCheckers...))
+
+ // Send ACK.
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ return irs, iss
+}
+
+// TestListenBacklogFull tests that netstack does not complete handshakes if the
+// listen backlog for the endpoint is full.
+func TestListenBacklogFull(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ // Start listening.
+ listenBacklog := 2
+ if err := c.EP.Listen(listenBacklog); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ for i := 0; i < listenBacklog; i++ {
+ executeHandshake(t, c, context.TestPort+uint16(i), false /* synCookieInUse */)
+ }
+
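+ // Allow the final ACKs to be processed so that both connections are
+ // delivered to the accept queue.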
+ time.Sleep(50 * time.Millisecond)
+
+ // Now send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 2,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ for i := 0; i < listenBacklog; i++ {
+ _, _, err = c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ _, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ }
+
+ // Now verify that there are no more connections that can be accepted.
+ _, _, err = c.EP.Accept()
+ if err != tcpip.ErrWouldBlock {
+ select {
+ case <-ch:
+ t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
+ case <-time.After(1 * time.Second):
+ }
+ }
+
+ // Now a new handshake must succeed.
+ executeHandshake(t, c, context.TestPort+2, false /* synCookieInUse */)
+
+ newEP, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ if string(tcp.Payload()) != data {
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
+ }
+}
+
+// TestListenNoAcceptNonUnicastV4 makes sure that TCP segments with a
+// non-unicast IPv4 source or destination address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+ multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+ otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv4Any,
+ context.StackAddr,
+ },
+ {
+ "SourceBroadcast",
+ header.IPv4Broadcast,
+ context.StackAddr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackAddr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackAddr,
+ },
+ {
+ "DestUnspecified",
+ context.TestAddr,
+ header.IPv4Any,
+ },
+ {
+ "DestBroadcast",
+ context.TestAddr,
+ header.IPv4Broadcast,
+ },
+ {
+ "DestOurMulticast",
+ context.TestAddr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestAddr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ test := test // capture range variable
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestAddr, context.StackAddr)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
+// TestListenNoAcceptNonUnicastV6 makes sure that TCP segments with a
+// non-unicast IPv6 source or destination address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) {
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv6Any,
+ context.StackV6Addr,
+ },
+ {
+ "SourceAllNodes",
+ header.IPv6AllNodesMulticastAddress,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "DestUnspecified",
+ context.TestV6Addr,
+ header.IPv6Any,
+ },
+ {
+ "DestAllNodes",
+ context.TestV6Addr,
+ header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ "DestOurMulticast",
+ context.TestV6Addr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestV6Addr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ test := test // capture range variable
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestV6Addr, context.StackV6Addr)
+ checker.IPv6(t, c.GetV6Packet(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
+func TestListenSynRcvdQueueFull(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ // Start listening.
+ listenBacklog := 1
+ if err := c.EP.Listen(listenBacklog); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send two SYNs. The first one should get a SYN-ACK; the second
+ // should get no response and is dropped, as the synRcvd count will
+ // be equal to the backlog.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ //
+ // NOTE: we did not complete the handshake for the previous one so the
+ // accept backlog should be empty and there should be one connection in
+ // synRcvd state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+ // Now complete the previous connection and verify that there is a connection
+ // to accept.
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ newEP, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ pkt := c.GetPacket()
+ tcp = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcp.Payload()) != data {
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
+ }
+}
+
+func TestListenBacklogFullSynCookieInUse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ }
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ // Start listening.
+ listenBacklog := 1
+ portOffset := uint16(0)
+ if err := c.EP.Listen(listenBacklog); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ executeHandshake(t, c, context.TestPort+portOffset, false)
+ portOffset++
+ // Wait for this to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ // Pick a different src port for the new SYN.
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ // The SYN should be dropped as the endpoint's backlog is full.
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+ // Verify that there is only one acceptable connection at this point.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ _, _, err = c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ _, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that there are no more connections that can be accepted.
+ _, _, err = c.EP.Accept()
+ if err != tcpip.ErrWouldBlock {
+ select {
+ case <-ch:
+ t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
+ case <-time.After(1 * time.Second):
+ }
+ }
+}
+
+func TestSynRcvdBadSeqNumber(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcpHdr.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now send a packet with an out-of-window sequence number
+ largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: largeSeqnum,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Should receive an ACK with the expected SEQ number
+ b = c.GetPacket()
+ tcpCheckers = []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.AckNum(uint32(irs) + 1),
+ checker.SeqNum(uint32(iss + 1)),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now that the socket replied appropriately with the ACK,
+ // complete the connection to test that the large SEQ num
+ // did not change the state from SYN-RCVD.
+
+ // Send ACK to move to ESTABLISHED state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ newEP, _, err := c.EP.Accept()
+
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ pkt := c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcpHdr.Payload()) != data {
+ t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ }
+}
+
+func TestPassiveConnectionAttemptIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+ if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.PassiveConnectionOpenings.Value() + 1
+
+ srcPort := uint16(context.TestPort)
+ executeHandshake(t, c, srcPort+1, false)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // Verify that there is only one acceptable connection at this point.
+ _, _, err = c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ _, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
+ }
+}
+
+func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ srcPort := uint16(context.TestPort)
+ // Now attempt a handshake; it will fill up the accept backlog.
+ executeHandshake(t, c, srcPort, false)
+
+ // Give time for the final ACK to be processed; otherwise the next handshake
+ // could be accepted before the previous one, depending on goroutine scheduling.
+ time.Sleep(50 * time.Millisecond)
+
+ want := stats.TCP.ListenOverflowSynDrop.Value() + 1
+
+ // Now send one more SYN; this one should get dropped.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: srcPort + 2,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
+
+ time.Sleep(50 * time.Millisecond)
+ if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
+ t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
+ }
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // Now check that there is one acceptable connection.
+ _, _, err = c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ _, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+}
+
+func TestEndpointBindListenAcceptState(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+ t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
+ }
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ aep, _, err := ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ aep, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
+ }
+ // Listening endpoint remains in listen state.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ ep.Close()
+ // Give worker goroutines time to receive the close notification.
+ time.Sleep(1 * time.Second)
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ // Accepted endpoint remains open when the listen endpoint is closed.
+ if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+}
+
+// This test verifies that the auto tuning does not grow the receive buffer if
+// the application is not reading the data actively.
+func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
+ const mtu = 1500
+ const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ stk := c.Stack()
+ // Set lower limits for auto-tuning tests. This is required because the
+ // test stops the worker, which can cause packets to be dropped since the
+ // segment queue holding unprocessed packets is limited to 500.
+ const receiveBufferSize = 80 << 10 // 80KB.
+ const maxReceiveBufferSize = receiveBufferSize * 10
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Enable auto-tuning.
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+ // Change the expected window scale to match the value needed for the
+ // maximum buffer size defined above.
+ c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
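+ // A quick sanity check of the scale expected here (assuming that
+ // FindWndScale picks the smallest shift that makes the maximum buffer
+ // fit in the 16-bit window field): maxReceiveBufferSize is 819200
+ // bytes, and 819200>>4 = 51200 <= 65535 while 819200>>3 = 102400 >
+ // 65535, so the scale should come out to 4.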
+
+ rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+ // NOTE: The timestamp values in the sent packets are meaningless to the
+ // peer, so we just increment the timestamp value by 1 every batch as we
+ // are not really using them for anything.
+ //
+ // Send a single byte to verify the advertised window.
+ tsVal := rawEP.TSVal + 1
+
+ // Introduce a 25ms latency by delaying the first byte.
+ latency := 25 * time.Millisecond
+ time.Sleep(latency)
+ rawEP.SendPacketWithTS([]byte{1}, tsVal)
+
+ // Verify that the ACK has the expected window.
+ wantRcvWnd := receiveBufferSize
+ wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
+ rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
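+ // The expected window is one unit short of the full buffer because the
+ // single byte sent above is buffered but unread: (receiveBufferSize-1)
+ // right-shifted by the window scale is exactly wantRcvWnd-1.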
+ time.Sleep(25 * time.Millisecond)
+
+ // Allocate a large enough payload for the test.
+ b := make([]byte, int(receiveBufferSize)*2)
+ offset := 0
+ payloadSize := receiveBufferSize - 1
+ worker := (c.EP).(interface {
+ StopWork()
+ ResumeWork()
+ })
+ tsVal++
+
+ // Stop the worker goroutine.
+ worker.StopWork()
+ start := offset
+ end := offset + payloadSize
+ packetsSent := 0
+ for ; start < end; start += mss {
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetsSent++
+ }
+
+ // Resume the worker so that it only sees the packets once all of them
+ // are waiting to be read.
+ worker.ResumeWork()
+
+ // Since we read no bytes, the window should go to zero until the
+ // application reads some of the data.
+ // Discard all intermediate ACKs except the last one.
+ if packetsSent > 100 {
+ for i := 0; i < (packetsSent / 100); i++ {
+ _ = c.GetPacket()
+ }
+ }
+ rawEP.VerifyACKRcvWnd(0)
+
+ time.Sleep(25 * time.Millisecond)
+ // Verify that sending more data when window is closed is dropped and
+ // not acked.
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+
+ // Verify that the stack sends us back an ACK with the sequence number
+ // of the last packet sent indicating it was dropped.
+ p := c.GetPacket()
+ checker.IPv4(t, p, checker.TCP(
+ checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.Window(0),
+ ))
+
+ // Now read all the data from the endpoint and verify that advertised
+ // window increases to the full available buffer size.
+ for {
+ _, _, err := c.EP.Read(nil)
+ if err == tcpip.ErrWouldBlock {
+ break
+ }
+ }
+
+ // Verify that we receive a non-zero window update ACK. When running
+ // under the thread sanitizer this test can end up seeing more than one
+ // ACK, one of them carrying the non-zero window update.
+ p = c.GetPacket()
+ checker.IPv4(t, p, checker.TCP(
+ checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
+ t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
+ }
+ },
+ ))
+}
+
+// This test verifies that auto-tuning grows the receive buffer as the
+// application actively reads the incoming data.
+func TestReceiveBufferAutoTuning(t *testing.T) {
+ const mtu = 1500
+ const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Enable Auto-tuning.
+ stk := c.Stack()
+ // Set lower limits for auto-tuning tests. This is required because the
+ // test stops the worker, which can cause packets to be dropped since the
+ // segment queue holding unprocessed packets is limited to 300.
+ const receiveBufferSize = 80 << 10 // 80KB.
+ const maxReceiveBufferSize = receiveBufferSize * 10
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Enable auto-tuning.
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+ // Change the expected window scale to match the value needed for the
+ // maximum buffer size used by stack.
+ c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+ rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+ wantRcvWnd := receiveBufferSize
+ scaleRcvWnd := func(rcvWnd int) uint16 {
+ return uint16(rcvWnd >> uint16(c.WindowScale))
+ }
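+ // scaleRcvWnd mirrors the wire format: the 16-bit window field carries
+ // the receive window right-shifted by the negotiated window scale
+ // (RFC 7323 window scaling).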
+ // Allocate a large array to send to the endpoint.
+ b := make([]byte, receiveBufferSize*48)
+
+ // In every iteration we will send double the number of bytes sent in
+ // the previous iteration and read the same from the app. The received
+ // window should grow by at least 2x of bytes read by the app in every
+ // RTT.
+ offset := 0
+ payloadSize := receiveBufferSize / 8
+ worker := (c.EP).(interface {
+ StopWork()
+ ResumeWork()
+ })
+ tsVal := rawEP.TSVal
+ // We are going to do our own computation of what the moderated receive
+ // buffer should be, based on sent/copied data per RTT, and verify that
+ // the window advertised by the stack matches our calculations.
+ prevCopied := 0
+ done := false
+ latency := 1 * time.Millisecond
+ for i := 0; !done; i++ {
+ tsVal++
+
+ // Stop the worker goroutine.
+ worker.StopWork()
+ start := offset
+ end := offset + payloadSize
+ totalSent := 0
+ packetsSent := 0
+ for ; start < end; start += mss {
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ totalSent += mss
+ packetsSent++
+ }
+
+ // Resume it so that it only sees the packets once all of them
+ // are waiting to be read.
+ worker.ResumeWork()
+
+ // Give 1ms for the worker to process the packets.
+ time.Sleep(1 * time.Millisecond)
+
+ // Verify that the advertised window on the ACK is reduced by
+ // the total bytes sent.
+ expectedWnd := wantRcvWnd - totalSent
+ if packetsSent > 100 {
+ for i := 0; i < (packetsSent / 100); i++ {
+ _ = c.GetPacket()
+ }
+ }
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
+
+ // Now read all the data from the endpoint and invoke the
+ // moderation API to allow for receive buffer auto-tuning
+ // to happen before we measure the new window.
+ totalCopied := 0
+ for {
+ b, _, err := c.EP.Read(nil)
+ if err == tcpip.ErrWouldBlock {
+ break
+ }
+ totalCopied += len(b)
+ }
+
+ // Invoke the moderation API. This is required for auto-tuning
+ // to happen. This method is normally expected to be invoked
+ // from a higher layer than tcpip.Endpoint. So we simulate
+ // copying to userspace by invoking it explicitly here.
+ c.EP.ModerateRecvBuf(totalCopied)
+
+ // Now send a keep-alive packet to trigger an ACK so that we can
+ // measure the new window.
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
+
+ if i == 0 {
+ // In the first iteration the receiver-based RTT is not
+ // yet known; as a result, the moderation code should not
+ // increase the advertised window.
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
+ prevCopied = totalCopied
+ } else {
+ rttCopied := totalCopied
+ if i == 1 {
+ // The moderation code accumulates copied bytes till
+ // RTT is established. So add in the bytes sent in
+ // the first iteration to the total bytes for this
+ // RTT.
+ rttCopied += prevCopied
+ // Now reset it to the initial value used by the
+ // auto tuning logic.
+ prevCopied = tcp.InitialCwnd * mss * 2
+ }
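+ // Our model of the moderation logic, inferred from this test
+ // rather than from a spec: start from twice the bytes copied
+ // in the last RTT plus 16 MSS of slack, then grow further in
+ // proportion to how much the copy rate increased relative to
+ // the previous RTT.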
+ newWnd := rttCopied<<1 + 16*mss
+ grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
+ newWnd += (grow << 1)
+ if newWnd > maxReceiveBufferSize {
+ newWnd = maxReceiveBufferSize
+ done = true
+ }
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
+ wantRcvWnd = newWnd
+ prevCopied = rttCopied
+ // Increase the latency after the first two iterations to
+ // establish a low RTT value in the receiver, since it
+ // only tracks the lowest value. This ensures that when
+ // ModerateRcvBuf is called the elapsed time is always >
+ // rtt. Without this the test is flaky due to
+ // scheduling/wakeup delays.
+ latency += 50 * time.Millisecond
+ }
+ time.Sleep(latency)
+ offset += payloadSize
+ payloadSize *= 2
+ }
+}
+
+func TestDelayEnabled(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ checkDelayOption(t, c, false, false) // Delay is disabled by default.
+
+ for _, v := range []struct {
+ delayEnabled tcp.DelayEnabled
+ wantDelayOption bool
+ }{
+ {delayEnabled: false, wantDelayOption: false},
+ {delayEnabled: true, wantDelayOption: true},
+ } {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
+ }
+ checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ }
+}
+
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
+ t.Helper()
+
+ var gotDelayEnabled tcp.DelayEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
+ }
+ if gotDelayEnabled != wantDelayEnabled {
+ t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
+ }
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
+ if err != nil {
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
+ }
+ gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
+ if err != nil {
+ t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
+ }
+ if gotDelayOption != wantDelayOption {
+ t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
+ }
+}
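+
+// For context: tcpip.DelayOption corresponds to Nagle's algorithm, so an
+// application holding a connected endpoint ep would disable delays (the
+// moral equivalent of TCP_NODELAY) with a sketch like:
+//
+//	if err := ep.SetSockOptBool(tcpip.DelayOption, false); err != nil {
+//		// Handle the error.
+//	}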
+
+func TestTCPLingerTimeout(t *testing.T) {
+ c := context.New(t, 1500 /* mtu */)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ testCases := []struct {
+ name string
+ tcpLingerTimeout time.Duration
+ want time.Duration
+ }{
+ {"NegativeLingerTimeout", -123123, 0},
+ {"ZeroLingerTimeout", 0, 0},
+ {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
+ // Values > stack's TCPLingerTimeout are capped to the stack's
+ // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
+ {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
+ t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ }
+ var v tcpip.TCPLingerTimeoutOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ }
+ if got, want := time.Duration(v), tc.want; got != want {
+ t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestTCPTimeWaitRSTIgnored(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now send a RST and this should be ignored and not
+ // generate an ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
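+ // (Silently dropping RSTs in TIME_WAIT is the mitigation for the
+ // TIME-WAIT assassination hazards described in RFC 1337.)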
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitOutOfOrder(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitNewSyn(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Send a SYN request w/ sequence number lower than
+ // the highest sequence number sent. We just reuse
+ // the same number.
+ iss = seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+
+ // Send a SYN request w/ sequence number higher than
+ // the highest sequence number sent.
+ iss = seqnum.Value(792)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b = c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+}
+
+func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ time.Sleep(2 * time.Second)
+
+ // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+ // by another 5 seconds and also cause the stack to send us a duplicate
+ // ACK, indicating that the final ACK was potentially lost.
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Sleep for 4 seconds so at this point we are 1 second past the
+ // original tcpTimeWaitTimeout of 5 seconds.
+ time.Sleep(4 * time.Second)
+
+ // Send an ACK and it should not generate any packet as the socket
+ // should still be in TIME_WAIT for another 5 seconds due to
+ // the duplicate FIN we sent earlier.
+ *ackHeaders = *finHeaders
+ ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+ ackHeaders.Flags = header.TCPFlagAck
+ c.SendPacket(nil, ackHeaders)
+
+ c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+ // Now sleep for another 2 seconds so that we are past the
+ // extended TIME_WAIT of 7 seconds (2 + 5).
+ time.Sleep(2 * time.Second)
+
+ // Resend the same ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Receive the RST that should be generated as there is no valid
+ // endpoint.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
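+ // (Per RFC 793 reset generation, a RST sent in response to an ACK takes
+ // its sequence number from the incoming segment's ACK field and carries
+ // no ACK of its own, hence the AckNum(0) above.)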
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+}
+
+func TestTCPCloseWithData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ RcvWnd: 30000,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now trigger a passive close by sending a FIN.
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ RcvWnd: 30000,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now write a few bytes and then close the endpoint.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b = c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+2), // AckNum is iss+2 as both the SYN and the FIN have been acked.
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+
+ c.EP.Close()
+ // Check the FIN.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.AckNum(uint32(iss+2)),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ // First send a partial ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
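+ // The partial ACK leaves the last data byte and the FIN unacknowledged,
+ // so the endpoint should stay in its closing sequence until the full
+ // ACKs below arrive.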
+
+ // Now send a full ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now ACK the FIN.
+ ackHeaders.AckNum++
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send an ACK and we should get a RST back as the endpoint should
+ // be in CLOSED state.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Check the RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ // Ensure that on the next retransmit timer fire, the user timeout has
+ // expired.
+ initRTO := 1 * time.Second
+ userTimeout := initRTO / 2
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Send some data and wait before ACKing it.
+ view := buffer.NewView(3)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for the retransmit timer to be fired and the user timeout to cause
+ // close of the connection.
+ select {
+ case <-notifyCh:
+ case <-time.After(2 * initRTO):
+ t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ }
+
+ // No packet should be received as the connection should be silently
+ // closed due to timeout.
+ c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+ next += uint32(len(view))
+
+ // The connection should be terminated after userTimeout has expired.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+ }
+
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+ // Set userTimeout to the duration of 1 keepalive interval.
+ // This means that after the first probe is sent, the second
+ // one should cause the connection to be closed due to
+ // userTimeout being hit.
+ userTimeout := 1 * keepAliveInterval
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Now receive 1 keepalive, but don't ACK it.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Sleep for a little over the keepalive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // socket.
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
+
+ // The connection should be closed with a timeout.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(c.IRS + 1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+ }
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+func TestIncreaseWindowOnReceive(t *testing.T) {
+ // This test ensures that the endpoint sends an ACK after recv()
+ // once the window grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two payloads
+ // together are equal to or longer than the MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // We now have < 1 MSS of space left in the receive buffer. Read the
+ // data! An ACK should be sent in response, since the window, while it
+ // never reached zero, grows back past one MSS.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestIncreaseWindowOnBufferResize(t *testing.T) {
+ // This test ensures that the endpoint sends an ACK after the
+ // available receive buffer grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two payloads
+ // together are equal to or longer than the MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // Increasing the buffer size should generate an ACK, since the window
+ // grew from a small value to one larger than or equal to the MSS.
+ c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
+
+ // After the buffer resize, the window has surely crossed the MSS. See the ACK:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestTCPDeferAccept(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
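+ // (The ACK number is irs+5: one for the SYN plus the four data bytes.)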
+
+ // Give a bit of time for the socket to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Sleep for a little over the tcpDeferAccept timeout.
+ time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+ // On timeout expiry we should get a SYN-ACK retransmission.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give some time for the endpoint to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestResetDuringClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ iss := seqnum.Value(789)
+ c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ // Send some data to make sure there is some unread
+ // data to trigger a reset on c.Close.
+ irs := c.IRS
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss.Add(1),
+ AckNum: irs.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(irs.Add(1))),
+ checker.AckNum(uint32(iss.Add(5)))))
+
+ // Close in a separate goroutine so that we can trigger
+ // a race with the RST we send below. This should not
+ // panic due to the route being released, regardless of
+ // whether Close() sends an active RST or the RST sent
+ // below is processed by the worker first.
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(5),
+ AckNum: c.IRS.Add(5),
+ RcvWnd: 30000,
+ Flags: header.TCPFlagRst,
+ })
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.EP.Close()
+ }()
+
+ wg.Wait()
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
new file mode 100644
index 000000000..8edbff964
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -0,0 +1,291 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// createConnectedWithTimestampOption creates and connects c.ep with the
+// timestamp option enabled.
+func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
+ return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+}
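+
+// For reference, the option exercised throughout this file is the RFC 7323
+// TCP timestamp option: kind 8, length 10, carrying a 4-byte TSVal and a
+// 4-byte TSEcr, typically preceded by two NOP bytes for 4-byte alignment,
+// for 12 bytes on the wire. That is why payload-length checks below add 12.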
+
+// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
+// an active connect and sets the TS Echo Reply fields correctly when the
+// SYN-ACK also indicates support for the TS option and provides a TSVal.
+func TestTimeStampEnabledConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ rep := createConnectedWithTimestampOption(c)
+
+ // Register for read and validate that we have data to read.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // The following tests ensure that TS option once enabled behaves
+ // correctly as described in
+ // https://tools.ietf.org/html/rfc7323#section-4.3.
+ //
+ // We are not testing delayed ACKs here, but we do test out of order
+ // packet delivery and filling the sequence number hole created due to
+ // the out of order packet.
+ //
+ // The test also verifies that the sequence numbers and timestamps are
+ // as expected.
+ data := []byte{1, 2, 3}
+
+ // First we increment tsVal by a small amount.
+ tsVal := rep.TSVal + 100
+ rep.SendPacketWithTS(data, tsVal)
+ rep.VerifyACKWithTS(tsVal)
+
+ // Next we send an out of order packet.
+ rep.NextSeqNum += 3
+ tsVal += 200
+ rep.SendPacketWithTS(data, tsVal)
+
+ // The ACK should contain the original sequence number (the start of
+ // the hole) and an older TS. Rewind NextSeqNum by 6: 3 for the skipped
+ // hole and 3 for the out-of-order payload just sent.
+ rep.NextSeqNum -= 6
+ rep.VerifyACKWithTS(tsVal - 200)
+
+ // Next we fill the hole and the returned ACK should contain the
+ // cumulative sequence number acking all data sent till now and have the
+ // latest timestamp sent below in its TSEcr field.
+ tsVal -= 100
+ rep.SendPacketWithTS(data, tsVal)
+ rep.NextSeqNum += 3
+ rep.VerifyACKWithTS(tsVal)
+
+ // Increment tsVal by a large value that doesn't result in a wrap around.
+ tsVal += 0x7fffffff
+ rep.SendPacketWithTS(data, tsVal)
+ rep.VerifyACKWithTS(tsVal)
+
+ // Increment tsVal again by a large value which should cause the
+ // timestamp value to wrap around. The returned ACK should contain the
+ // wrapped around timestamp in its tsEcr field and not the tsVal from
+ // the previous packet sent above.
+ tsVal += 0x7fffffff
+ rep.SendPacketWithTS(data, tsVal)
+ rep.VerifyACKWithTS(tsVal)
+
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // There should be 5 views to read and each of them should
+ // contain the same data.
+ for i := 0; i < 5; i++ {
+ got, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+ if want := data; !bytes.Equal(got, want) {
+ t.Fatalf("Data is different: got: %v, want: %v", got, want)
+ }
+ }
+}
+
+// TestTimeStampDisabledConnect tests that netstack sends the timestamp option
+// on an active connect, but that if the SYN-ACK doesn't specify the TS option
+// then the timestamp option is not enabled and future packets do not contain
+// a timestamp.
+func TestTimeStampDisabledConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnectedWithOptions(header.TCPSynOptions{})
+}
+
+func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ if cookieEnabled {
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
+ t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
+ tsVal := rand.Uint32()
+ c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+
+ // Now send some data and validate that timestamp is echoed correctly in the ACK.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Unexpected error from Write: %s", err)
+ }
+
+ // Check that data is received and that the timestamp option TSEcr field
+ // matches the expected value.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ // Add 12 bytes for the 10-byte timestamp option plus 2 NOPs to
+ // align it on a 4-byte boundary.
+ checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(wndSize),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPTimestampChecker(true, 0, tsVal+1),
+ ),
+ )
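+ // (The expected TSEcr of tsVal+1 is an inference from the handshake:
+ // the final ACK sent by AcceptWithOptions carries TSVal tsVal+1, the
+ // most recent timestamp the stack has seen.)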
+}
+
+// TestTimeStampEnabledAccept tests that if the SYN on a passive connect
+// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK
+// and echoes the TSVal field of the original SYN in the TSEcr field of the
+// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify
+// that Timestamp option is enabled in both cases if requested in the original
+// SYN.
+func TestTimeStampEnabledAccept(t *testing.T) {
+ testCases := []struct {
+ cookieEnabled bool
+ wndScale int
+ wndSize uint16
+ }{
+ {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ }
+ for _, tc := range testCases {
+ timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
+ }
+}
+
+func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ if cookieEnabled {
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
+ t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
+ c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Now send some data with the accepted connection endpoint and validate
+ // that no timestamp option is sent in the TCP segment.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Unexpected error from Write: %s", err)
+ }
+
+ // Check that data is received and that the timestamp option is disabled
+ // when SYN cookies are enabled/disabled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(wndSize),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPTimestampChecker(false, 0, 0),
+ ),
+ )
+}
+
+// TestTimeStampDisabledAccept tests that Timestamp option is not used when the
+// peer doesn't advertise it and connection is established with Accept().
+func TestTimeStampDisabledAccept(t *testing.T) {
+ testCases := []struct {
+ cookieEnabled bool
+ wndScale int
+ wndSize uint16
+ }{
+ {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ }
+ for _, tc := range testCases {
+ timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
+ }
+}
+
+func TestSendGreaterThanMTUWithOptions(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ createConnectedWithTimestampOption(c)
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ rep := createConnectedWithTimestampOption(c)
+
+ // Register for read.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ droppedPacketsStat := c.Stack().Stats().DroppedPackets
+ droppedPackets := droppedPacketsStat.Value()
+ data := []byte{1, 2, 3}
+ // Send a packet with no TCP options/timestamp.
+ rep.SendPacket(data, nil)
+
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Assert that DroppedPackets was not incremented.
+ if got, want := droppedPacketsStat.Value(), droppedPackets; got != want {
+ t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
+ }
+
+	// Issue a read; we should get the data back.
+ got, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+ if want := data; bytes.Compare(got, want) != 0 {
+ t.Fatalf("Data is different: got: %v, want: %v", got, want)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
new file mode 100644
index 000000000..ce6a2c31d
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "context",
+ testonly = 1,
+ srcs = ["context.go"],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/seqnum",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
new file mode 100644
index 000000000..06fde2a79
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -0,0 +1,1121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package context provides a test context for use in tcp tests. It also
+// provides helper methods to assert/check certain behaviours.
+package context
+
+import (
+ "bytes"
+ "context"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // StackAddr is the IPv4 address assigned to the stack.
+ StackAddr = "\x0a\x00\x00\x01"
+
+ // StackPort is used as the listening port in tests for passive
+ // connects.
+ StackPort = 1234
+
+ // TestAddr is the source address for packets sent to the stack via the
+ // link layer endpoint.
+ TestAddr = "\x0a\x00\x00\x02"
+
+ // TestPort is the TCP port used for packets sent to the stack
+ // via the link layer endpoint.
+ TestPort = 4096
+
+ // StackV6Addr is the IPv6 address assigned to the stack.
+ StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
+ // TestV6Addr is the source address for packets sent to the stack via
+ // the link layer endpoint.
+ TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // StackV4MappedAddr is StackAddr as a mapped v6 address.
+ StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
+
+ // TestV4MappedAddr is TestAddr as a mapped v6 address.
+ TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr
+
+ // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+	// testInitialSequenceNumber is the initial sequence number used in the
+	// SYN sent to the stack and in packets sent in response to a SYN.
+ testInitialSequenceNumber = 789
+)
+
+// Headers is used to represent the TCP header fields when building a
+// new packet.
+type Headers struct {
+ // SrcPort holds the src port value to be used in the packet.
+ SrcPort uint16
+
+ // DstPort holds the destination port value to be used in the packet.
+ DstPort uint16
+
+ // SeqNum is the value of the sequence number field in the TCP header.
+ SeqNum seqnum.Value
+
+ // AckNum represents the acknowledgement number field in the TCP header.
+ AckNum seqnum.Value
+
+ // Flags are the TCP flags in the TCP header.
+ Flags int
+
+ // RcvWnd is the window to be advertised in the ReceiveWindow field of
+ // the TCP header.
+ RcvWnd seqnum.Size
+
+ // TCPOpts holds the options to be sent in the option field of the TCP
+ // header.
+ TCPOpts []byte
+}
+
+// Context provides an initialized Network stack and a link layer endpoint
+// for use in TCP tests.
+type Context struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+
+ // IRS holds the initial sequence number in the SYN sent by endpoint in
+ // case of an active connect or the sequence number sent by the endpoint
+ // in the SYN-ACK sent in response to a SYN when listening in passive
+ // mode.
+ IRS seqnum.Value
+
+ // Port holds the port bound by EP below in case of an active connect or
+ // the listening port number in case of a passive connect.
+ Port uint16
+
+ // EP is the test endpoint in the stack owned by this context. This endpoint
+ // is used in various tests to either initiate an active connect or is used
+ // as a passive listening endpoint to accept inbound connections.
+ EP tcpip.Endpoint
+
+	// WQ is the wait queue associated with EP and is used to block for events
+ // on EP.
+ WQ waiter.Queue
+
+ // TimeStampEnabled is true if ep is connected with the timestamp option
+ // enabled.
+ TimeStampEnabled bool
+
+ // WindowScale is the expected window scale in SYN packets sent by
+ // the stack.
+ WindowScale uint8
+}
+
+// New allocates and initializes a test context containing a new
+// stack and a link-layer endpoint.
+func New(t *testing.T, mtu uint32) *Context {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
+
+ // Allow minimum send/receive buffer sizes to be 1 during tests.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Increase minimum RTO in tests to avoid test flakes due to early
+ // retransmit in case the test executors are overloaded and cause timers
+ // to fire earlier than expected.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
+ t.Fatalf("failed to set stack-wide minRTO: %s", err)
+ }
+
+	// Some of the congestion control tests send up to 640 packets, so we
+	// set the channel size to 1000.
+ ep := channel.New(1000, mtu, "")
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ opts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ if testing.Verbose() {
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ }
+ opts2 := stack.NICOptions{Name: "nic2"}
+ if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &Context{
+ t: t,
+ s: s,
+ linkEP: ep,
+ WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ }
+}
+
+// Cleanup closes the context endpoint if required.
+func (c *Context) Cleanup() {
+ if c.EP != nil {
+ c.EP.Close()
+ }
+ c.Stack().Close()
+}
+
+// Stack returns a reference to the stack in the Context.
+func (c *Context) Stack() *stack.Stack {
+ return c.s
+}
+
+// CheckNoPacketTimeout verifies that no packet is received during the time
+// specified by wait.
+func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
+ c.t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), wait)
+ defer cancel()
+ if _, ok := c.linkEP.ReadContext(ctx); ok {
+ c.t.Fatal(errMsg)
+ }
+}
+
+// CheckNoPacket verifies that no packet is received for 1 second.
+func (c *Context) CheckNoPacket(errMsg string) {
+ c.CheckNoPacketTimeout(errMsg, 1*time.Second)
+}
+
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses. It will fail with an error if no packet is received for
+// 5 seconds.
+func (c *Context) GetPacket() []byte {
+ c.t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
+}
+
+// GetPacketNonBlocking reads a packet from the link layer endpoint
+// and verifies that it is an IPv4 packet with the expected source
+// and destination address. If no packet is available it will return
+// nil immediately.
+func (c *Context) GetPacketNonBlocking() []byte {
+ c.t.Helper()
+
+ p, ok := c.linkEP.Read()
+ if !ok {
+ return nil
+ }
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
+}
+
+// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
+func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
+	// Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
+ if len(buf) > maxTotalSize {
+ buf = buf[:maxTotalSize]
+ }
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: TestAddr,
+ DstAddr: StackAddr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
+ icmp.SetType(typ)
+ icmp.SetCode(code)
+ const icmpv4VariableHeaderOffset = 4
+ copy(icmp[icmpv4VariableHeaderOffset:], p1)
+ copy(icmp[header.ICMPv4PayloadOffset:], p2)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// BuildSegment builds a TCP segment based on the given Headers and payload.
+func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
+ return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
+}
+
+// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
+// payload, and source and destination IPv4 addresses.
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+ copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(tcp.ProtocolNumber),
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the TCP header.
+ t := header.TCP(buf[header.IPv4MinimumSize:])
+ t.Encode(&header.TCPFields{
+ SrcPort: h.SrcPort,
+ DstPort: h.DstPort,
+ SeqNum: uint32(h.SeqNum),
+ AckNum: uint32(h.AckNum),
+ DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
+ Flags: uint8(h.Flags),
+ WindowSize: uint16(h.RcvWnd),
+ })
+
+ // Calculate the TCP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
+
+ // Calculate the TCP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ t.SetChecksum(^t.CalculateChecksum(xsum))
+
+	// Return the fully built and checksummed segment.
+ return buf.ToVectorisedView()
+}
+
+// SendSegment sends a TCP segment that has already been built and written to a
+// buffer.VectorisedView.
+func (c *Context) SendSegment(s buffer.VectorisedView) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: s,
+ })
+}
+
+// SendPacket builds and sends a TCP segment (with the provided payload and
+// TCP headers) in an IPv4 packet via the link layer endpoint.
+func (c *Context) SendPacket(payload []byte, h *Headers) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: c.BuildSegment(payload, h),
+ })
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment (with the provided
+// payload and TCP headers) in an IPv4 packet via the link layer endpoint
+// using the provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+ })
+}
+
+// SendAck sends an ACK packet.
+func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
+ c.SendAckWithSACK(seq, bytesReceived, nil)
+}
+
+// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
+func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
+ options := make([]byte, 40)
+ offset := 0
+ if len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ c.SendPacket(nil, &Headers{
+ SrcPort: TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
+ RcvWnd: 30000,
+ TCPOpts: options[:offset],
+ })
+}
+
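+// An illustrative (hypothetical) use of SendAckWithSACK: acknowledge 3
+// in-order bytes while also SACKing an out-of-order block. Here seq and iss
+// are placeholders for the test's current sequence numbers.
+//
+//	c.SendAckWithSACK(seq, 3, []header.SACKBlock{
+//		{Start: iss.Add(10), End: iss.Add(20)},
+//	})
+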
+// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
+// verifies that the packet's payload matches the slice of data indicated
+// by offset and size.
+func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+ c.t.Helper()
+
+ c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
+}
+
+// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
+// and verifies that the packet's payload matches the slice of data indicated
+// by offset and size, skipping optlen bytes in addition to the IP and TCP
+// headers when comparing the data.
+func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
+ c.t.Helper()
+
+ b := c.GetPacket()
+ checker.IPv4(c.t, b,
+ checker.PayloadLen(size+header.TCPMinimumSize+optlen),
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ pdata := data[offset:][:size]
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
+ c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+ }
+}
+
+// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
+// and verifies that the packet's payload matches the slice of data indicated
+// by offset and size. It returns true if a packet was received and
+// processed.
+func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+ c.t.Helper()
+
+ b := c.GetPacketNonBlocking()
+ if b == nil {
+ return false
+ }
+ checker.IPv4(c.t, b,
+ checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ pdata := data[offset:][:size]
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
+ c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+ }
+ return true
+}
+
+// CreateV6Endpoint creates and initializes c.EP as an IPv6 Endpoint. If v6only
+// is true then it sets the IP_V6ONLY option on the socket to make it an
+// IPv6-only endpoint instead of a default dual-stack socket.
+func (c *Context) CreateV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
+		c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// GetV6Packet reads a single packet from the link layer endpoint of the context
+// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
+func (c *Context) GetV6Packet() []byte {
+ c.t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+ }
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
+
+ checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+ return b
+}
+
+// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
+// the context.
+func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+ c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+ NextHeader: uint8(tcp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ // Initialize the TCP header.
+ t := header.TCP(buf[header.IPv6MinimumSize:])
+ t.Encode(&header.TCPFields{
+ SrcPort: h.SrcPort,
+ DstPort: h.DstPort,
+ SeqNum: uint32(h.SeqNum),
+ AckNum: uint32(h.AckNum),
+ DataOffset: header.TCPMinimumSize,
+ Flags: uint8(h.Flags),
+ WindowSize: uint16(h.RcvWnd),
+ })
+
+ // Calculate the TCP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
+
+ // Calculate the TCP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ t.SetChecksum(^t.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// CreateConnected creates a connected TCP endpoint.
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
+ c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
+}
+
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss), receive window (rcvWnd), and any options if
+// specified.
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
+ c.t.Helper()
+
+ // Start connection attempt.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
+ c.t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(c.t, b,
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ c.SendPacket(nil, &Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Receive ACK packet.
+ checker.IPv4(c.t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ),
+ )
+
+ // Wait for connection to be established.
+ select {
+ case <-notifyCh:
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ c.t.Fatalf("Unexpected error when connecting: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ c.t.Fatalf("Timed out waiting for connection")
+ }
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ c.Port = tcpHdr.SourcePort()
+}
+
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
+			c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ }
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
+ c.Connect(iss, rcvWnd, options)
+}
+
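+// A minimal sketch (hypothetical test, not part of this change) showing how
+// the helpers above are typically combined:
+//
+//	func TestExample(t *testing.T) {
+//		c := context.New(t, 1500 /* mtu */)
+//		defer c.Cleanup()
+//		// Establish a connection with iss = 789, a 30000 byte receive
+//		// window and the default endpoint receive buffer size.
+//		c.CreateConnected(789, 30000, -1)
+//		// c.EP is now established; c.SendPacket/c.GetPacket drive the
+//		// remote end of the connection.
+//	}
+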
+// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
+// sending data and ACK packets easy while being able to manipulate the sequence
+// numbers and timestamp values as needed.
+type RawEndpoint struct {
+ C *Context
+ SrcPort uint16
+ DstPort uint16
+ Flags int
+ NextSeqNum seqnum.Value
+ AckNum seqnum.Value
+ WndSize seqnum.Size
+ RecentTS uint32 // Stores the latest timestamp to echo back.
+ TSVal uint32 // TSVal stores the last timestamp sent by this endpoint.
+
+	// SACKPermitted is true if the SACKPermitted option was negotiated for this endpoint.
+ SACKPermitted bool
+}
+
+// SendPacketWithTS embeds the provided tsVal in the Timestamp option
+// for the packet to be sent out.
+func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
+ r.TSVal = tsVal
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
+ r.SendPacket(payload, tsOpt[:])
+}
+
+// SendPacket is a small wrapper function to build and send packets.
+func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
+ packetHeaders := &Headers{
+ SrcPort: r.SrcPort,
+ DstPort: r.DstPort,
+ Flags: r.Flags,
+ SeqNum: r.NextSeqNum,
+ AckNum: r.AckNum,
+ RcvWnd: r.WndSize,
+ TCPOpts: opts,
+ }
+ r.C.SendPacket(payload, packetHeaders)
+ r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+	// Read the ACK and verify that its tsEcr field matches the provided
+	// tsVal.
+ ackPacket := r.C.GetPacket()
+ checker.IPv4(r.C.t, ackPacket,
+ checker.TCP(
+ checker.DstPort(r.SrcPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(r.AckNum)),
+ checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPTimestampChecker(true, 0, tsVal),
+ ),
+ )
+ // Store the parsed TSVal from the ack as recentTS.
+ tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
+ opts := tcpSeg.ParsedOptions()
+ r.RecentTS = opts.TSVal
+}
+
+// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
+// matches the provided rcvWnd.
+func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ ackPacket := r.C.GetPacket()
+ checker.IPv4(r.C.t, ackPacket,
+ checker.TCP(
+ checker.DstPort(r.SrcPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(r.AckNum)),
+ checker.AckNum(uint32(r.NextSeqNum)),
+ checker.Window(rcvWnd),
+ ),
+ )
+}
+
+// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
+func (r *RawEndpoint) VerifyACKNoSACK() {
+ r.VerifyACKHasSACK(nil)
+}
+
+// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
+func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
+	// Read the ACK and verify that the TCP options in the segment
+	// contain the expected SACK blocks.
+ ackPacket := r.C.GetPacket()
+ checker.IPv4(r.C.t, ackPacket,
+ checker.TCP(
+ checker.DstPort(r.SrcPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(r.AckNum)),
+ checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSACKBlockChecker(sackBlocks),
+ ),
+ )
+}
+
+// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
+// options enabled and returns a RawEndpoint which represents the other end of
+// the connection.
+//
+// It also verifies, where required (e.g., Timestamp), that the ACK to the
+// SYN-ACK does not carry an option that was not requested.
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
+ }
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Start connection attempt.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
+ err = c.EP.Connect(testFullAddr)
+ if err != tcpip.ErrConnectStarted {
+ c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
+ }
+ // Receive SYN packet.
+ b := c.GetPacket()
+	// Validate that the SYN has the timestamp option and a valid
+	// TS value.
+ mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
+
+ checker.IPv4(c.t, b,
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{
+ MSS: mss,
+ TS: true,
+ WS: int(c.WindowScale),
+ SACKPermitted: c.SACKEnabled(),
+ }),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ tcpSeg := header.TCP(header.IPv4(b).Payload())
+ synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
+
+ // Build options w/ tsVal to be sent in the SYN-ACK.
+ synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
+ offset := 0
+ if wantOptions.WS != -1 {
+ offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
+ }
+ if wantOptions.TS {
+ offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
+ }
+ if wantOptions.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
+ }
+
+ offset += header.AddTCPOptionPadding(synAckOptions, offset)
+
+ // Build SYN-ACK.
+ c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
+ iss := seqnum.Value(testInitialSequenceNumber)
+ c.SendPacket(nil, &Headers{
+ SrcPort: tcpSeg.DestinationPort(),
+ DstPort: tcpSeg.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ TCPOpts: synAckOptions[:offset],
+ })
+
+ // Read ACK.
+ ackPacket := c.GetPacket()
+
+ // Verify TCP header fields.
+ tcpCheckers := []checker.TransportChecker{
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS) + 1),
+ checker.AckNum(uint32(iss) + 1),
+ }
+
+	// Verify that the tsEcr field of the ACK packet is wantOptions.TSVal
+	// if the timestamp option was enabled; if not, verify that there is no
+	// timestamp in the ACK packet.
+ if wantOptions.TS {
+ tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
+ } else {
+ tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+ }
+
+ checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
+
+ ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
+ ackOptions := ackSeg.ParsedOptions()
+
+ // Wait for connection to be established.
+ select {
+ case <-notifyCh:
+ err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+ if err != nil {
+ c.t.Fatalf("Unexpected error when connecting: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ c.t.Fatalf("Timed out waiting for connection")
+ }
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Store the source port in use by the endpoint.
+ c.Port = tcpSeg.SourcePort()
+
+ // Mark in context that timestamp option is enabled for this endpoint.
+ c.TimeStampEnabled = true
+
+ return &RawEndpoint{
+ C: c,
+ SrcPort: tcpSeg.DestinationPort(),
+ DstPort: tcpSeg.SourcePort(),
+ Flags: header.TCPFlagAck | header.TCPFlagPsh,
+ NextSeqNum: iss + 1,
+ AckNum: c.IRS.Add(1),
+ WndSize: 30000,
+ RecentTS: ackOptions.TSVal,
+ TSVal: wantOptions.TSVal,
+ SACKPermitted: wantOptions.SACKPermitted,
+ }
+}
+
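+// An illustrative (hypothetical) use, enabling only the timestamp option and
+// exchanging a timestamped segment through the returned RawEndpoint:
+//
+//	rep := c.CreateConnectedWithOptions(header.TCPSynOptions{WS: -1, TS: true, TSVal: 42})
+//	rep.SendPacketWithTS([]byte{1, 2, 3}, 43)
+//	rep.VerifyACKWithTS(43)
+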
+// AcceptWithOptions initializes a listening endpoint and connects to it with the
+// provided options enabled. It also verifies that the SYN-ACK has the expected
+// values for the provided options.
+//
+// The function returns a RawEndpoint representing the other end of the accepted
+// endpoint.
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ c.t.Fatalf("Listen failed: %v", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ c.t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ c.t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ return rep
+}
+
+// PassiveConnect just disables WindowScaling and delegates the call to
+// PassiveConnectWithOptions.
+func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
+ synOptions.WS = -1
+ c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+}
+
+// PassiveConnectWithOptions initiates a new connection (with the specified TCP
+// options enabled) to the port on which the Context.EP is listening for new
+// connections. It also validates that the SYN-ACK has the expected values for
+// the enabled options.
+//
+// NOTE: MSS is not a negotiated option and it can be asymmetric
+// in each direction. This function uses the maxPayload to set the MSS to be
+// sent to the peer on a connect and validates that the MSS in the SYN-ACK
+// response is equal to the MTU - (tcphdr len + iphdr len).
+//
+// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
+// value of the window scaling option to be sent in the SYN. If synOptions.WS >=
+// 0 then we send the WindowScale option.
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ opts := make([]byte, header.TCPOptionsMaximumSize)
+ offset := 0
+ offset += header.EncodeMSSOption(uint32(maxPayload), opts)
+
+ if synOptions.WS >= 0 {
+ offset += header.EncodeWSOption(3, opts[offset:])
+ }
+ if synOptions.TS {
+ offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
+ }
+
+ if synOptions.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(opts[offset:])
+ }
+
+ paddingToAdd := 4 - offset%4
+ // Now add any padding bytes that might be required to quad align the
+ // options.
+ for i := offset; i < offset+paddingToAdd; i++ {
+ opts[i] = header.TCPOptionNOP
+ }
+ offset += paddingToAdd
+
+ // Send a SYN request.
+ iss := seqnum.Value(testInitialSequenceNumber)
+ c.SendPacket(nil, &Headers{
+ SrcPort: TestPort,
+ DstPort: StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ TCPOpts: opts[:offset],
+ })
+
+ // Receive the SYN-ACK reply. Make sure MSS and other expected options
+ // are present.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(StackPort),
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(iss) + 1),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
+ }
+
+ // If TS option was enabled in the original SYN then add a checker to
+ // validate the Timestamp option in the SYN-ACK.
+ if synOptions.TS {
+ tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
+ } else {
+ tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+ }
+
+ checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
+ rcvWnd := seqnum.Size(30000)
+ ackHeaders := &Headers{
+ SrcPort: TestPort,
+ DstPort: StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ RcvWnd: rcvWnd,
+ }
+
+ // If WS was expected to be in effect then scale the advertised window
+ // correspondingly.
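+	// For example, with rcvWnd = 30000 and synOptions.WS = 3, the ACK
+	// advertises 30000 >> 3 = 3750.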
+ if synOptions.WS > 0 {
+ ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
+ }
+
+ parsedOpts := tcp.ParsedOptions()
+ if synOptions.TS {
+ // Echo the tsVal back to the peer in the tsEcr field of the
+ // timestamp option.
+ // Increment TSVal by 1 from the value sent in the SYN and echo
+ // the TSVal in the SYN-ACK in the TSEcr field.
+ opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
+ ackHeaders.TCPOpts = opts[:]
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ c.Port = StackPort
+
+ return &RawEndpoint{
+ C: c,
+ SrcPort: TestPort,
+ DstPort: StackPort,
+ Flags: header.TCPFlagPsh | header.TCPFlagAck,
+ NextSeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ WndSize: rcvWnd,
+ SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
+ RecentTS: parsedOpts.TSVal,
+ TSVal: synOptions.TSVal + 1,
+ }
+}
+
+// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
+// for the Stack in the context.
+func (c *Context) SACKEnabled() bool {
+ var v tcp.SACKEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
+		// Stack doesn't support SACK. So just return false.
+ return false
+ }
+ return bool(v)
+}
+
+// SetGSOEnabled enables or disables generic segmentation offload.
+func (c *Context) SetGSOEnabled(enable bool) {
+ if enable {
+ c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ } else {
+ c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ }
+}
+
+// MSSWithoutOptions returns the value for the MSS used by the stack when no
+// options are in use.
+func (c *Context) MSSWithoutOptions() uint16 {
+ return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
+}
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+ return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
new file mode 100644
index 000000000..7981d469b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -0,0 +1,142 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+type timerState int
+
+const (
+ timerStateDisabled timerState = iota
+ timerStateEnabled
+ timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because their timeouts are long
+// (currently at least 200ms): they get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
+type timer struct {
+ // state is the current state of the timer, it can be one of the
+ // following values:
+ // disabled - the timer is disabled.
+ // orphaned - the timer is disabled, but the runtime timer is
+	//                enabled, which means that it will eventually cause a
+ // spurious wake (unless it gets enabled again before
+ // then).
+ // enabled - the timer is enabled, but the runtime timer may be set
+ // to an earlier expiration time due to a previous
+ // orphaned state.
+ state timerState
+
+ // target is the expiration time of the current timer. It is only
+ // meaningful in the enabled state.
+ target time.Time
+
+ // runtimeTarget is the expiration time of the runtime timer. It is
+ // meaningful in the enabled and orphaned states.
+ runtimeTarget time.Time
+
+ // timer is the runtime timer used to wait on.
+ timer *time.Timer
+}
+
+// init initializes the timer. Once it expires, the given waker will be
+// asserted.
+func (t *timer) init(w *sleep.Waker) {
+ t.state = timerStateDisabled
+
+ // Initialize a runtime timer that will assert the waker, then
+ // immediately stop it.
+ t.timer = time.AfterFunc(time.Hour, func() {
+ w.Assert()
+ })
+ t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer.
+func (t *timer) cleanup() {
+ t.timer.Stop()
+ *t = timer{}
+}
+
+// checkExpiration checks if the timer has actually expired. It should be
+// called whenever a sleeper wakes up due to the waker being asserted, and is
+// used to check whether it's a spurious wake (due to a previously orphaned
+// timer) or a legitimate one.
+func (t *timer) checkExpiration() bool {
+ // Transition to fully disabled state if we're just consuming an
+ // orphaned timer.
+ if t.state == timerStateOrphaned {
+ t.state = timerStateDisabled
+ return false
+ }
+
+ // The timer is enabled, but it may have expired early. Check if that's
+ // the case, and if so, reset the runtime timer to the correct time.
+ now := time.Now()
+ if now.Before(t.target) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(t.target.Sub(now))
+ return false
+ }
+
+ // The timer has actually expired, disable it for now and inform the
+ // caller.
+ t.state = timerStateDisabled
+ return true
+}
+
+// disable disables the timer, leaving it in an orphaned state if it wasn't
+// already disabled.
+func (t *timer) disable() {
+ if t.state != timerStateDisabled {
+ t.state = timerStateOrphaned
+ }
+}
+
+// enabled returns true if the timer is currently enabled, false otherwise.
+func (t *timer) enabled() bool {
+ return t.state == timerStateEnabled
+}
+
+// enable enables the timer, programming the runtime timer if necessary.
+func (t *timer) enable(d time.Duration) {
+ t.target = time.Now().Add(d)
+
+ // Check if we need to set the runtime timer.
+ if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(d)
+ }
+
+ t.state = timerStateEnabled
+}
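+
+// An illustrative (hypothetical) caller pairs the timer with a sleep.Sleeper
+// and re-checks expiration on every wake so that spurious wakes from
+// orphaned runtime timers are filtered out:
+//
+//	var w sleep.Waker
+//	var tmr timer
+//	tmr.init(&w)
+//	tmr.enable(200 * time.Millisecond)
+//
+//	var s sleep.Sleeper
+//	s.AddWaker(&w, 0)
+//	for {
+//		s.Fetch(true /* block */)
+//		if tmr.checkExpiration() {
+//			break // A real expiration; spurious wakes just loop.
+//		}
+//	}
+//	tmr.cleanup()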
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
new file mode 100644
index 000000000..3ad6994a7
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tcpconntrack",
+ srcs = ["tcp_conntrack.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ ],
+)
+
+go_test(
+ name = "tcpconntrack_test",
+ size = "small",
+ srcs = ["tcp_conntrack_test.go"],
+ deps = [
+ ":tcpconntrack",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
new file mode 100644
index 000000000..12bc1b5b5
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpconntrack implements a TCP connection tracking object. It allows
+// users with access to a segment stream to figure out when a connection is
+// established, reset, and closed (and in the last case, who closed first).
+package tcpconntrack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// Result is returned when the state of a TCB is updated in response to an
+// inbound or outbound segment.
+type Result int
+
+const (
+ // ResultDrop indicates that the segment should be dropped.
+ ResultDrop Result = iota
+
+ // ResultConnecting indicates that the connection remains in a
+ // connecting state.
+ ResultConnecting
+
+ // ResultAlive indicates that the connection remains alive (connected).
+ ResultAlive
+
+ // ResultReset indicates that the connection was reset.
+ ResultReset
+
+ // ResultClosedByPeer indicates that the connection was gracefully
+ // closed, and the inbound stream was closed first.
+ ResultClosedByPeer
+
+ // ResultClosedBySelf indicates that the connection was gracefully
+ // closed, and the outbound stream was closed first.
+ ResultClosedBySelf
+)
+
+// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP
+// connection and inform the caller when the connection has been closed.
+type TCB struct {
+ inbound stream
+ outbound stream
+
+ // State handlers.
+ handlerInbound func(*TCB, header.TCP) Result
+ handlerOutbound func(*TCB, header.TCP) Result
+
+ // firstFin holds a pointer to the first stream to send a FIN.
+ firstFin *stream
+
+ // state is the current state of the stream.
+ state Result
+}
+
+// Init initializes the state of the TCB according to the initial SYN.
+func (t *TCB) Init(initialSyn header.TCP) Result {
+ t.handlerInbound = synSentStateInbound
+ t.handlerOutbound = synSentStateOutbound
+
+ iss := seqnum.Value(initialSyn.SequenceNumber())
+ t.outbound.una = iss
+ t.outbound.nxt = iss.Add(logicalLen(initialSyn))
+ t.outbound.end = t.outbound.nxt
+
+ // Even though "end" is a sequence number, we don't know the initial
+ // receive sequence number yet, so we store the window size until we get
+ // a SYN from the peer.
+ t.inbound.una = 0
+ t.inbound.nxt = 0
+ t.inbound.end = seqnum.Value(initialSyn.WindowSize())
+ t.state = ResultConnecting
+ return t.state
+}
+
+// UpdateStateInbound updates the state of the TCB based on the supplied inbound
+// segment.
+func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
+ st := t.handlerInbound(t, tcp)
+ if st != ResultDrop {
+ t.state = st
+ }
+ return st
+}
+
+// UpdateStateOutbound updates the state of the TCB based on the supplied
+// outbound segment.
+func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
+ st := t.handlerOutbound(t, tcp)
+ if st != ResultDrop {
+ t.state = st
+ }
+ return st
+}
+
+// IsAlive returns true as long as the connection is in the established (Alive)
+// or connecting state.
+func (t *TCB) IsAlive() bool {
+ return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed())
+}
+
+// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream.
+func (t *TCB) OutboundSendSequenceNumber() seqnum.Value {
+ return t.outbound.nxt
+}
+
+// InboundSendSequenceNumber returns the snd.NXT for the inbound stream.
+func (t *TCB) InboundSendSequenceNumber() seqnum.Value {
+ return t.inbound.nxt
+}
+
+// adaptResult modifies the supplied "Result" according to the state of the TCB;
+// if r is anything other than "Alive", or if one of the streams isn't closed
+// yet, it is returned unmodified. Otherwise it's converted to either
+// ClosedBySelf or ClosedByPeer depending on which stream was closed first.
+func (t *TCB) adaptResult(r Result) Result {
+ // Check the unmodified case.
+ if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() {
+ return r
+ }
+
+ // Find out which was closed first.
+ if t.firstFin == &t.outbound {
+ return ResultClosedBySelf
+ }
+
+ return ResultClosedByPeer
+}
+
+// synSentStateInbound is the state handler for inbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateInbound(t *TCB, tcp header.TCP) Result {
+ flags := tcp.Flags()
+ ackPresent := flags&header.TCPFlagAck != 0
+ ack := seqnum.Value(tcp.AckNumber())
+
+ // Ignore segment if ack is present but not acceptable.
+ if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) {
+ return ResultConnecting
+ }
+
+ // If reset is specified, we will let the packet through no matter what
+ // but we will also destroy the connection if the ACK is present (and
+ // implicitly acceptable).
+ if flags&header.TCPFlagRst != 0 {
+ if ackPresent {
+ t.inbound.rstSeen = true
+ return ResultReset
+ }
+ return ResultConnecting
+ }
+
+ // Ignore segment if SYN is not set.
+ if flags&header.TCPFlagSyn == 0 {
+ return ResultConnecting
+ }
+
+ // Update state informed by this SYN.
+ irs := seqnum.Value(tcp.SequenceNumber())
+ t.inbound.una = irs
+ t.inbound.nxt = irs.Add(logicalLen(tcp))
+ t.inbound.end += irs
+
+ t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize()))
+
+	// If the ACK was set (and is therefore acceptable), update our
+	// acknowledgment tracking.
+ if ackPresent {
+ // Advance the "una" and "end" indices of the outbound stream.
+ if t.outbound.una.LessThan(ack) {
+ t.outbound.una = ack
+ }
+
+ if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) {
+ t.outbound.end = end
+ }
+ }
+
+ // Update handlers so that new calls will be handled by new state.
+ t.handlerInbound = allOtherInbound
+ t.handlerOutbound = allOtherOutbound
+
+ return ResultAlive
+}
+
+// synSentStateOutbound is the state handler for outbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateOutbound(t *TCB, tcp header.TCP) Result {
+ // Drop outbound segments that aren't retransmits of the original one.
+ if tcp.Flags() != header.TCPFlagSyn ||
+ tcp.SequenceNumber() != uint32(t.outbound.una) {
+ return ResultDrop
+ }
+
+ // Update the receive window. We only remember the largest value seen.
+ if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end {
+ t.inbound.end = wnd
+ }
+
+ return ResultConnecting
+}
+
+// update updates the state of inbound and outbound streams, given the supplied
+// inbound segment. For outbound segments, this same function can be called with
+// swapped inbound/outbound streams.
+func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result {
+ // Ignore segments out of the window.
+ s := seqnum.Value(tcp.SequenceNumber())
+ if !inbound.acceptable(s, dataLen(tcp)) {
+ return ResultAlive
+ }
+
+ flags := tcp.Flags()
+ if flags&header.TCPFlagRst != 0 {
+ inbound.rstSeen = true
+ return ResultReset
+ }
+
+ // Ignore segments that don't have the ACK flag, and those with the SYN
+ // flag.
+ if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 {
+ return ResultAlive
+ }
+
+	// Ignore segments that acknowledge data that has not yet been sent.
+ ack := seqnum.Value(tcp.AckNumber())
+ if outbound.nxt.LessThan(ack) {
+ return ResultAlive
+ }
+
+ // Advance the "una" and "end" indices of the outbound stream.
+ if outbound.una.LessThan(ack) {
+ outbound.una = ack
+ }
+
+ if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) {
+ outbound.end = end
+ }
+
+ // Advance the "nxt" index of the inbound stream.
+ end := s.Add(logicalLen(tcp))
+ if inbound.nxt.LessThan(end) {
+ inbound.nxt = end
+ }
+
+	// Note the index of the FIN segment, and stash away a pointer to the
+	// first stream to see a FIN.
+ if flags&header.TCPFlagFin != 0 && !inbound.finSeen {
+ inbound.finSeen = true
+ inbound.fin = end - 1
+
+ if *firstFin == nil {
+ *firstFin = inbound
+ }
+ }
+
+ return ResultAlive
+}
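All of the bookkeeping above happens in modular 32-bit sequence space. For orientation, here is a minimal, self-contained sketch of the wraparound-safe arithmetic that the seqnum package's InRange/LessThan comparisons reduce to; the names and values are illustrative only, not part of this change:

```go
package main

import "fmt"

// inWindow reports whether seq lies in [una, end). Unsigned 32-bit
// subtraction makes the comparison survive sequence-number wraparound,
// the same trick behind seqnum.Value.InRange.
func inWindow(seq, una, end uint32) bool {
	return seq-una < end-una
}

func main() {
	// A window that straddles the 2^32 wraparound point.
	una, end := uint32(0xfffffff0), uint32(0x00000010)

	fmt.Println(inWindow(0xffffffff, una, end)) // true: in window, before wrap
	fmt.Println(inWindow(0x00000005, una, end)) // true: in window, after wrap
	fmt.Println(inWindow(0x00000020, una, end)) // false: beyond the window

	// An acceptable ACK advances una, mirroring the
	// "outbound.una = ack" step in update() above.
	if ack := uint32(0x00000002); inWindow(ack, una, end) {
		una = ack
	}
	fmt.Printf("una=%#x\n", una) // una=0x2
}
```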
+
+// allOtherInbound is the state handler for inbound segments in all states
+// except SYN-SENT.
+func allOtherInbound(t *TCB, tcp header.TCP) Result {
+ return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin))
+}
+
+// allOtherOutbound is the state handler for outbound segments in all states
+// except SYN-SENT.
+func allOtherOutbound(t *TCB, tcp header.TCP) Result {
+ return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin))
+}
+
+// stream holds the state of a TCP unidirectional stream.
+type stream struct {
+ // The interval [una, end) is the allowed interval as defined by the
+ // receiver, i.e., anything less than una has already been acknowledged
+ // and anything greater than or equal to end is beyond the receiver
+ // window. The interval [una, nxt) is the acknowledgable range, whose
+ // right edge indicates the sequence number of the next byte to be sent
+ // by the sender, i.e., anything greater than or equal to nxt hasn't
+ // been sent yet.
+ una seqnum.Value
+ nxt seqnum.Value
+ end seqnum.Value
+
+ // finSeen indicates if a FIN has already been sent on this stream.
+ finSeen bool
+
+ // fin is the sequence number of the FIN. It is only valid after finSeen
+ // is set to true.
+ fin seqnum.Value
+
+ // rstSeen indicates if a RST has already been sent on this stream.
+ rstSeen bool
+}
+
+// acceptable determines if the segment with the given sequence number and data
+// length is acceptable, i.e., if it's within the [una, end) window or, in case
+// the window is zero, if it's a packet with no payload and sequence number
+// equal to una.
+func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+ return header.Acceptable(segSeq, segLen, s.una, s.end)
+}
+
+// closed determines if the stream has already been closed. This happens when
+// a FIN has been set by the sender and acknowledged by the receiver.
+func (s *stream) closed() bool {
+ return s.finSeen && s.fin.LessThan(s.una)
+}
+
+// dataLen returns the length of the TCP segment payload.
+func dataLen(tcp header.TCP) seqnum.Size {
+ return seqnum.Size(len(tcp) - int(tcp.DataOffset()))
+}
+
+// logicalLen calculates the logical length of the TCP segment.
+func logicalLen(tcp header.TCP) seqnum.Size {
+ l := dataLen(tcp)
+ flags := tcp.Flags()
+ if flags&header.TCPFlagSyn != 0 {
+ l++
+ }
+ if flags&header.TCPFlagFin != 0 {
+ l++
+ }
+ return l
+}
+
+// IsEmpty returns true if tcb is not initialized.
+func (t *TCB) IsEmpty() bool {
+ if t.inbound != (stream{}) || t.outbound != (stream{}) {
+ return false
+ }
+
+ if t.firstFin != nil || t.state != ResultDrop {
+ return false
+ }
+
+ return true
+}
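Before the diff moves on to the tests, here is a sketch of how a hypothetical caller, such as a NAT or conntrack table, might drive this API per observed segment. Packet parsing and flow lookup are elided, and trackSegment is an illustrative name, not part of this change:

```go
package tracker

import (
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)

// trackSegment feeds one parsed TCP segment into the tracker. The
// first segment seen for a flow is assumed to be the outbound SYN,
// which initializes the TCB; later segments go to the handler for
// their direction, and the Result tells the caller whether the flow
// is still connecting, alive, reset, or closed.
func trackSegment(tcb *tcpconntrack.TCB, seg header.TCP, outbound bool) tcpconntrack.Result {
	if tcb.IsEmpty() {
		tcb.Init(seg)
		return tcpconntrack.ResultConnecting
	}
	if outbound {
		return tcb.UpdateStateOutbound(seg)
	}
	return tcb.UpdateStateInbound(seg)
}
```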
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
new file mode 100644
index 000000000..5e271b7ca
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -0,0 +1,511 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpconntrack_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
+)
+
+// connected creates a connection tracker TCB and sets it to a connected state
+// by performing a 3-way handshake.
+func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: iss,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: irw,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN-ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: irs,
+ AckNum: iss + 1,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ WindowSize: isw,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: iss + 1,
+ AckNum: irs + 1,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: irw,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ return &tcb
+}
+
+func TestConnectionRefused(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive RST.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 1235,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+ }
+}
+
+func TestConnectionRefusedInSynRcvd(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive RST with no ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 790,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagRst,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+ }
+}
+
+func TestConnectionResetInSynRcvd(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send RST with no ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagRst,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+ }
+}
+
+func TestRetransmitOnSynSent(t *testing.T) {
+ // Send initial SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Retransmit the same SYN.
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
+ }
+}
+
+func TestRetransmitOnSynRcvd(t *testing.T) {
+ // Send initial SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN. This will cause the state to go to SYN-RCVD.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Retransmit the original SYN.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Transmit a SYN-ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 790,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+}
+
+func TestClosedBySelf(t *testing.T) {
+ tcb := connected(t, 1234, 789, 30000, 50000)
+
+ // Send FIN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 790,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive FIN/ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 790,
+ AckNum: 1236,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1236,
+ AckNum: 791,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+ }
+}
+
+func TestClosedByPeer(t *testing.T) {
+ tcb := connected(t, 1234, 789, 30000, 50000)
+
+ // Receive FIN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 790,
+ AckNum: 1235,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send FIN/ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 791,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 791,
+ AckNum: 1236,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
+ }
+}
+
+func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
+ sseq := uint32(1234)
+ rseq := uint32(789)
+ tcb := connected(t, sseq, rseq, 30000, 50000)
+ sseq++
+ rseq++
+
+ // Send some data.
+ tcp := make(header.TCP, header.TCPMinimumSize+1024)
+
+ for i := uint32(0); i < 10; i++ {
+ // Send some data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+ sseq += uint32(len(tcp)) - header.TCPMinimumSize
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive ack for data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: rseq,
+ AckNum: sseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+ }
+
+ for i := uint32(0); i < 10; i++ {
+ // Receive some data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: rseq,
+ AckNum: sseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+ rseq += uint32(len(tcp)) - header.TCPMinimumSize
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ack for data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+ }
+
+ // Send FIN.
+ tcp = tcp[:header.TCPMinimumSize]
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 30000,
+ })
+ sseq++
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive FIN/ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: rseq,
+ AckNum: sseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 50000,
+ })
+ rseq++
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+ }
+}
+
+func TestIgnoreBadResetOnSynSent(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive RSTs with bad ACKs; they should not cause the connection
+ // to be reset.
+ acks := []uint32{1234, 1236, 1000, 5000}
+ flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
+ for _, a := range acks {
+ for _, f := range flags {
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: a,
+ DataOffset: header.TCPMinimumSize,
+ Flags: f,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
+ }
+ }
+ }
+
+ // Complete the handshake.
+ // Receive SYN-ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 1235,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 790,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
new file mode 100644
index 000000000..b5d2d0ba6
--- /dev/null
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -0,0 +1,60 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "udp_packet_list",
+ out = "udp_packet_list.go",
+ package = "udp",
+ prefix = "udpPacket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*udpPacket",
+ "Linker": "*udpPacket",
+ },
+)
+
+go_library(
+ name = "udp",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "forwarder.go",
+ "protocol.go",
+ "udp_packet_list.go",
+ ],
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/raw",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "udp_x_test",
+ size = "small",
+ srcs = ["udp_test.go"],
+ deps = [
+ ":udp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
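The go_template_instance rule above stamps out an intrusive doubly-linked list specialized for *udpPacket: because udpPacket embeds udpPacketEntry, queueing a packet requires no per-node allocation. A generic, self-contained sketch of the pattern the generated udp_packet_list.go follows (illustrative names, not the literal generated code):

```go
package sketch

// entry carries the intrusive links; embedding it makes the element
// itself the list node, as udpPacket does with udpPacketEntry.
type entry struct{ next, prev *node }

type node struct {
	entry
	data []byte
}

type list struct{ head, tail *node }

func (l *list) Empty() bool  { return l.head == nil }
func (l *list) Front() *node { return l.head }

func (l *list) PushBack(n *node) {
	n.prev, n.next = l.tail, nil
	if l.tail != nil {
		l.tail.next = n
	} else {
		l.head = n
	}
	l.tail = n
}

func (l *list) Remove(n *node) {
	if n.prev != nil {
		n.prev.next = n.next
	} else {
		l.head = n.next
	}
	if n.next != nil {
		n.next.prev = n.prev
	} else {
		l.tail = n.prev
	}
	n.next, n.prev = nil, nil
}
```

The endpoint below uses exactly this surface: rcvList.PushBack in HandlePacket, and Front/Remove/Empty in Read and Close.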
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..0584ec8dc
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,1497 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type udpPacket struct {
+ udpPacketEntry
+ senderAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+ // tos stores either the receiveTOS or receiveTClass value.
+ tos uint8
+}
+
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+
+// Endpoint states. Note that they are represented in a netstack-specific
+// manner and may not be meaningful externally. Specifically, they need to be
+// translated to Linux's representation of these states if presented to
+// userspace.
+const (
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnected
+ StateClosed
+)
+
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnected:
+ return "CONNECTED"
+ case StateClosed:
+ return "CLOSED"
+ default:
+ return "UNKNOWN"
+ }
+}
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal
+// to have concurrent goroutines make calls into the endpoint, as they are
+// properly synchronized.
+//
+// It implements tcpip.Endpoint.
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ waiterQueue *waiter.Queue
+ uniqueID uint64
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ state EndpointState
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ ttl uint8
+ multicastTTL uint8
+ multicastAddr tcpip.Address
+ multicastNICID tcpip.NICID
+ multicastLoop bool
+ portFlags ports.Flags
+ bindToDevice tcpip.NICID
+ broadcast bool
+ noChecksum bool
+
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // Values used to reserve a port or register a transport endpoint
+ // (whichever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
+ // receiveTClass determines if the incoming IPv6 TClass header field is
+ // passed as ancillary data to ControlMessages on Read.
+ receiveTClass bool
+
+ // receiveIPPacketInfo determines if the packet info is returned by Read.
+ receiveIPPacketInfo bool
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // multicastMemberships that need to be removed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ e := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.UDPProtocolNumber,
+ },
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
+ state: StateInitial,
+ uniqueID: s.UniqueID(),
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ return e
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+
+ switch e.state {
+ case StateBound, StateConnected:
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
+ }
+
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = StateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if err := e.takeLastError(); err != nil {
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ cm := tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ }
+ e.mu.RLock()
+ receiveTOS := e.receiveTOS
+ receiveTClass := e.receiveTClass
+ receiveIPPacketInfo := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ if receiveTOS {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ if receiveTClass {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+ if receiveIPPacketInfo {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
+ return p.data.ToView(), cm, nil
+}
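A short caller-side sketch of the contract above: Read never blocks, so tcpip.ErrWouldBlock is the signal to register for waiter.EventIn and retry. Here readOne is a hypothetical helper; ep can be any endpoint obtained from this package via the stack:

```go
package example

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// readOne drains at most one datagram from a UDP endpoint.
func readOne(ep tcpip.Endpoint) {
	var sender tcpip.FullAddress
	v, cm, err := ep.Read(&sender)
	switch err {
	case nil:
		fmt.Printf("%d bytes from %v (timestamped=%v)\n", len(v), sender, cm.HasTimestamp)
	case tcpip.ErrWouldBlock:
		// Nothing queued: wait for waiter.EventIn, then retry.
	default:
		// ErrClosedForReceive, or a stashed asynchronous error.
	}
}
```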
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case StateInitial:
+ case StateConnected:
+ return false, nil
+
+ case StateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state may have changed while we released the shared lock and
+ // re-acquired it in exclusive mode; if so, try again.
+ if e.state != StateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
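The RUnlock/Lock/Unlock/RLock sequence above is the usual upgrade-and-recheck pattern: sync.RWMutex cannot atomically upgrade a read lock to a write lock, so the state must be revalidated after reacquisition and the caller loops on retry. A generic, self-contained sketch of the same shape (guarded and prepare are hypothetical names):

```go
package sketch

import "sync"

type guarded struct {
	mu    sync.RWMutex
	ready bool // stands in for the endpoint state
}

// prepare must be entered with mu read-locked and returns with it
// read-locked, like prepareForWrite above.
func (g *guarded) prepare() (retry bool) {
	if g.ready { // fast path under the read lock
		return false
	}

	// Upgrade: release the read lock, take the write lock, and make
	// sure the read lock is restored for the caller on return.
	g.mu.RUnlock()
	defer g.mu.RLock()
	g.mu.Lock()
	defer g.mu.Unlock()

	// Revalidate: another goroutine may have won the race in the
	// unlocked window. Either way, ask the caller to re-run the
	// fast path under the read lock.
	if !g.ready {
		g.ready = true // one-time initialization
	}
	return true
}

// Callers loop until no retry is needed:
//
//	g.mu.RLock()
//	for g.prepare() {
//	}
//	// ... use g under the read lock ...
//	g.mu.RUnlock()
```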
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+ localAddr := e.ID.LocalAddress
+ if isBroadcastOrMulticast(localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicID == 0 {
+ nicID = e.multicastNICID
+ }
+ if localAddr == "" && nicID == 0 {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+ if err != nil {
+ return stack.Route{}, 0, err
+ }
+ return r, nicID, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ if err := e.takeLastError(); err != nil {
+ return 0, nil, err
+ }
+
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
+ var dstPort uint16
+ if to == nil {
+ route = &e.route
+ dstPort = e.dstPort
+ resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
+ e.mu.RUnlock()
+ e.mu.Lock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
+ }
+ if err == nil && route.IsResolutionRequired() {
+ ch, err = route.Resolve(waker)
+ }
+
+ e.mu.Unlock()
+ e.mu.RLock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
+ }
+ return
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := to.NIC
+ if e.BindNICID != 0 {
+ if nicID != 0 && nicID != e.BindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicID = e.BindNICID
+ }
+
+ if to.Addr == header.IPv4Broadcast && !e.broadcast {
+ return 0, nil, tcpip.ErrBroadcastDisabled
+ }
+
+ dst, netProto, err := e.checkV4MappedLocked(*to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ r, _, err := e.connectRoute(nicID, dst, netProto)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ dstPort = dst.Port
+ resolve = route.Resolve
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.FullPayload()
+ if err != nil {
+ return 0, nil, err
+ }
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
+ ttl := e.ttl
+ useDefaultTTL := ttl == 0
+
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ // Multicast allows a 0 TTL.
+ useDefaultTTL = false
+ }
+
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
+ return 0, nil, err
+ }
+ return int64(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v
+ e.mu.Unlock()
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = v
+ e.mu.Unlock()
+
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = v
+ e.mu.Unlock()
+
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrNotSupported
+ }
+
+ e.mu.Lock()
+ e.receiveTClass = v
+ e.mu.Unlock()
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.portFlags.LoadBalanced = v
+ e.mu.Unlock()
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v
+ }
+
+ return nil
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+ }
+
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+ }
+
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ fa, netProto, err := e.checkV4MappedLocked(fa)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return tcpip.ErrBadLocalAddress
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ if e.BindNICID != 0 && e.BindNICID != nic {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+
+ // The interface address is considered not-set if it is empty or contains
+ // all-zeros. The former represents the zero value in Go; the latter is the
+ // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+ allZeros := header.IPv4Any
+ if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+
+ case tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for i, mem := range e.multicastMemberships {
+ if mem == memToRemove {
+ memToRemoveIndex = i
+ break
+ }
+ }
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+
+ case tcpip.BindToDeviceOption:
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
+ }
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ }
+ return nil
+}
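For the membership path above, a caller-side sketch, the netstack analogue of setsockopt(IP_ADD_MEMBERSHIP). The joinMDNS name and the group address are illustrative only; tcpip addresses are raw byte strings, so 224.0.0.251 appears as four bytes:

```go
package example

import "gvisor.dev/gvisor/pkg/tcpip"

// joinMDNS subscribes a UDP endpoint to the mDNS multicast group.
func joinMDNS(ep tcpip.Endpoint) *tcpip.Error {
	return ep.SetSockOpt(tcpip.AddMembershipOption{
		NIC:           0,  // 0: let the route lookup above pick the NIC
		InterfaceAddr: "", // unspecified: resolved via FindRoute, as handled above
		MulticastAddr: tcpip.Address("\xe0\x00\x00\xfb"), // 224.0.0.251
	})
}
```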
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ v := e.receiveTOS
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrNotSupported
+ }
+
+ e.mu.RLock()
+ v := e.receiveTClass
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.RLock()
+ v := e.portFlags.MostRecent
+ e.mu.RUnlock()
+
+ return v, nil
+
+ case tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.portFlags.LoadBalanced
+ e.mu.RUnlock()
+
+ return v, nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.RLock()
+ v := e.v6only
+ e.mu.RUnlock()
+
+ return v, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ v = p.data.Size()
+ }
+ e.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSizeMax
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return e.takeLastError()
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ e.multicastNICID,
+ e.multicastAddr,
+ }
+ e.mu.Unlock()
+
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+ return nil
+}
+
+// sendUDP sends a UDP datagram via the provided network endpoint and under
+// the provided identity.
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
+ // Allocate a buffer for the UDP header.
+ hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+
+ // Initialize the header.
+ udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: localPort,
+ DstPort: remotePort,
+ Length: length,
+ })
+
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udp.SetChecksum(^udp.CalculateChecksum(xsum))
+ }
+
+ if useDefaultTTL {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: ProtocolNumber,
+ TTL: ttl,
+ TOS: tos,
+ }, &stack.PacketBuffer{
+ Header: hdr,
+ Data: data,
+ TransportHeader: buffer.View(udp),
+ Owner: owner,
+ }); err != nil {
+ r.Stats().UDP.PacketSendErrors.Increment()
+ return err
+ }
+
+ // Track count of packets sent.
+ r.Stats().UDP.PacketsSent.Increment()
+ return nil
+}
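For reference, a self-contained sketch of the RFC 1071 ones'-complement sum that header.Checksum and CalculateChecksum implement. The pseudo-header (addresses, protocol, length) is just more input folded into the same sum, and the final UDP checksum is the inverted fold; on IPv4 a transmitted zero means the sender skipped the checksum (RFC 768), which is why the receive path below only verifies nonzero IPv4 checksums:

```go
package sketch

// internetChecksum folds data into a running ones'-complement sum.
// Callers seed initial with the pseudo-header sum and invert the
// final result before writing it into the header.
func internetChecksum(data []byte, initial uint32) uint16 {
	sum := initial
	for i := 0; i+1 < len(data); i += 2 {
		sum += uint32(data[i])<<8 | uint32(data[i+1])
	}
	if len(data)%2 == 1 {
		sum += uint32(data[len(data)-1]) << 8 // odd byte padded with zero
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16) // fold carries back in
	}
	return uint16(sum)
}
```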
+
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (e *endpoint) Disconnect() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.state != StateConnected {
+ return nil
+ }
+ var (
+ id stack.TransportEndpointID
+ btd tcpip.NICID
+ )
+
+ // We change this value below and we need the old value to unregister
+ // the endpoint.
+ boundPortFlags := e.boundPortFlags
+
+ // Exclude ephemerally bound endpoints.
+ if e.BindNICID != 0 || e.ID.LocalAddress == "" {
+ var err *tcpip.Error
+ id = stack.TransportEndpointID{
+ LocalPort: e.ID.LocalPort,
+ LocalAddress: e.ID.LocalAddress,
+ }
+ id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ return err
+ }
+ e.state = StateBound
+ boundPortFlags = e.boundPortFlags
+ } else {
+ if e.ID.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ e.boundPortFlags = ports.Flags{}
+ }
+ e.state = StateInitial
+ }
+
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
+ e.ID = id
+ e.boundBindToDevice = btd
+ e.route.Release()
+ e.route = stack.Route{}
+ e.dstPort = 0
+
+ return nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Port == 0 {
+ // We don't support connecting to port zero.
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicID := addr.NIC
+ var localPort uint16
+ switch e.state {
+ case StateInitial:
+ case StateBound, StateConnected:
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != e.BindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicID = e.BindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: e.ID.LocalAddress,
+ LocalPort: localPort,
+ RemotePort: addr.Port,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ if e.state == StateInitial {
+ id.LocalAddress = r.LocalAddress
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so when v6only is false on
+ // an ipv6 endpoint we register both network protocols.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
+ }
+
+ oldPortFlags := e.boundPortFlags
+
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ // Remove the old registration.
+ if e.ID.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
+ }
+
+ e.ID = id
+ e.boundBindToDevice = btd
+ e.route = r.Clone()
+ e.dstPort = addr.Port
+ e.RegisterNICID = nicID
+ e.effectiveNetProtos = netProtos
+
+ e.state = StateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != StateBound && e.state != StateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ e.shutdownFlags |= flags
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP; it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP; it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ if e.ID.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
+ if err != nil {
+ return id, e.bindToDevice, err
+ }
+ id.LocalPort = port
+ }
+ e.boundPortFlags = e.portFlags
+
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.boundPortFlags = ports.Flags{}
+ }
+ return id, e.bindToDevice, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ nicID := addr.NIC
+ if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ // A local unicast address was specified, verify that it's valid.
+ nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicID == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.ID = id
+ e.boundBindToDevice = btd
+ e.RegisterNICID = nicID
+ e.effectiveNetProtos = netProtos
+
+ // Mark endpoint as bound.
+ e.state = StateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Save the effective NICID generated by bindLocked.
+ e.BindNICID = e.RegisterNICID
+
+ return nil
+}
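Tying Bind, Connect, and Write together, a caller-side sketch of the state machine implemented above: INITIAL, then BOUND once Bind reserves the port, then CONNECTED once Connect resolves a route, after which Write with no To address uses the cached route. The dial helper is an illustrative name; the endpoint is assumed to come from this package via the stack:

```go
package example

import "gvisor.dev/gvisor/pkg/tcpip"

// dial walks an endpoint through INITIAL -> BOUND -> CONNECTED and
// sends one datagram on the connected route.
func dial(ep tcpip.Endpoint, local, remote tcpip.FullAddress, payload []byte) *tcpip.Error {
	if err := ep.Bind(local); err != nil {
		return err // e.g. tcpip.ErrPortInUse
	}
	if err := ep.Connect(remote); err != nil {
		return err // e.g. tcpip.ErrNoRoute
	}
	// opts.To == nil: use the route and port cached by Connect.
	_, _, err := ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
	return err
}
```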
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ addr := e.ID.LocalAddress
+ if e.state == StateConnected {
+ addr = e.route.LocalAddress
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.RegisterNICID,
+ Addr: addr,
+ Port: e.ID.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != StateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(pkt.TransportHeader)
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
+ // Malformed packet.
+ e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value means
+ // the transmitter omitted the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
+
+ e.rcvMu.Lock()
+ e.stack.Stats().UDP.PacketsReceived.Increment()
+ e.stats.PacketsReceived.Increment()
+
+ // Drop the packet if the endpoint is not ready to receive or the
+ // receive side has been closed.
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
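+ // Drop the packet if the receive buffer is currently full.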
+ if e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &udpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ Port: header.UDP(hdr).SourcePort(),
+ },
+ }
+ packet.data = pkt.Data
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += pkt.Data.Size()
+
+ // Save any useful information from the network header to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.RemoteAddress
+ packet.packetInfo.NIC = r.NICID()
+ case header.IPv6ProtocolNumber:
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
+ }
+
+ packet.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ if typ == stack.ControlPortUnreachable {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state == StateConnected {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+
+ e.lastError = tcpip.ErrConnectionRefused
+ }
+ }
+}
+
+// State implements tcpip.Endpoint.State.
+func (e *endpoint) State() uint32 {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
+// Wait implements tcpip.Endpoint.Wait.
+func (*endpoint) Wait() {}
+
+func isBroadcastOrMulticast(a tcpip.Address) bool {
+ return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
new file mode 100644
index 000000000..851e6b635
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves udpPacket.data field.
+func (u *udpPacket) saveData() buffer.VectorisedView {
+ // We cannot save u.data directly as u.data.views may alias u.views,
+ // which is not allowed by the state framework (in-struct pointers).
+ return u.data.Clone(nil)
+}
+
+// loadData loads udpPacket.data field.
+func (u *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now.
+ // Moreover, data.views will be allocated anyway, so there is little
+ // point in reusing u.views for it.
+ u.data = data
+}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = tcpip.StringToError(s)
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and from mutating
+ // endpoint state). The lock is released by saveRcvBufSizeMax(), which
+ // saves e.rcvBufSizeMax and sets it to 0 so that incoming packets
+ // remain blocked.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.stack = s
+
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
+ if e.state != StateBound && e.state != StateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+ if err != nil {
+ panic(err)
+ }
+ } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // StateBound
+ // A local unicast address is specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ id := e.ID
+ e.ID.LocalPort = 0
+ e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ panic(err)
+ }
+}
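The string-based hooks above are how the endpoint's *tcpip.Error survives serialization, since stateify cannot persist the pointer directly. A minimal sketch of the round trip, assuming a *tcpip.Error named err:

    // Save: a nil error is encoded as the empty string.
    s := ""
    if err != nil {
    	s = err.String()
    }

    // Load: the empty string restores nil; anything else resolves back
    // to the canonical *tcpip.Error value.
    var restored *tcpip.Error
    if s != "" {
    	restored = tcpip.StringToError(s)
    }
    _ = restored // would be stored into e.lastError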
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
new file mode 100644
index 000000000..c67e0ba95
--- /dev/null
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -0,0 +1,96 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Forwarder is a session request forwarder, which allows clients to decide
+// what to do with a session request, for example: ignore it, or process it.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ handler func(*ForwarderRequest)
+
+ stack *stack.Stack
+}
+
+// NewForwarder allocates and initializes a new forwarder.
+func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
+ return &Forwarder{
+ stack: s,
+ handler: handler,
+ }
+}
+
+// HandlePacket handles all packets.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ f.handler(&ForwarderRequest{
+ stack: f.stack,
+ route: r,
+ id: id,
+ pkt: pkt,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a session request received by the forwarder and
+// passed to the client. Clients may optionally create an endpoint to represent
+// it via CreateEndpoint.
+type ForwarderRequest struct {
+ stack *stack.Stack
+ route *stack.Route
+ id stack.TransportEndpointID
+ pkt *stack.PacketBuffer
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the session request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.id
+}
+
+// CreateEndpoint creates a connected UDP endpoint for the session request.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := newEndpoint(r.stack, r.route.NetProto, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ ep.ID = r.id
+ ep.route = r.route.Clone()
+ ep.dstPort = r.id.RemotePort
+ ep.RegisterNICID = r.route.NICID()
+ ep.boundPortFlags = ep.portFlags
+
+ ep.state = StateConnected
+
+ ep.rcvMu.Lock()
+ ep.rcvReady = true
+ ep.rcvMu.Unlock()
+
+ ep.HandlePacket(r.route, r.id, r.pkt)
+
+ return ep, nil
+}
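A minimal sketch of the canonical usage described in the Forwarder doc comment, assuming a configured *stack.Stack named s (the handler body and the serve consumer are illustrative):

    fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
    	// Ignoring the session request is as simple as returning early.
    	var wq waiter.Queue
    	ep, err := r.CreateEndpoint(&wq)
    	if err != nil {
    		return
    	}
    	// Hand ep off; it behaves like any connected UDP endpoint and
    	// must eventually be Close()d by its consumer.
    	go serve(ep, &wq) // serve is a hypothetical consumer.
    })
    s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)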
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
new file mode 100644
index 000000000..0e7464e3a
--- /dev/null
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -0,0 +1,231 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package udp contains the implementation of the UDP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing udp.NewProtocol() as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing udp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolNumber is the UDP protocol number.
+ ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4KiB.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32KiB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32KiB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MiB
+)
+
+type protocol struct {
+}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw UDP endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.UDP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ hdr := header.UDP(pkt.TransportHeader)
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+ // Only send an ICMP error if the destination address is not a
+ // multicast/broadcast V4/V6 address and the source is not the
+ // unspecified address.
+ //
+ // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+ return true
+ }
+
+ // As per RFC 1122 Section 3.2.2.1, a host SHOULD generate Destination
+ // Unreachable messages with code:
+ //
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported; or
+ //
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ switch len(id.LocalAddress) {
+ case header.IPv4AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+ return true
+ }
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes
+ //
+ // NOTE: The RFC referenced above differs from the original
+ // recommendation in RFC 1122, which required that at least 8
+ // bytes of the payload be included. Today Linux and other
+ // systems implement the RFC 1812 definition rather than the
+ // original RFC 1122 requirement.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := mtu - headerLen
+ payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, a raw or packet socket may use what UDP
+ // considers an unreachable destination. Thus we deep copy pkt
+ // to prevent multiple ownership and save/restore (S/R) errors.
+ newHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ newHeader = append(newHeader, pkt.TransportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4DstUnreachable)
+ pkt.SetCode(header.ICMPv4PortUnreachable)
+ pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ TransportHeader: buffer.View(pkt),
+ Data: payload,
+ })
+
+ case header.IPv6AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+ return true
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
+ available := mtu - headerLen
+ payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader})
+ payload.Append(pkt.Data)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
+ pkt.SetType(header.ICMPv6DstUnreachable)
+ pkt.SetCode(header.ICMPv6PortUnreachable)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ TransportHeader: buffer.View(pkt),
+ Data: payload,
+ })
+ }
+ return true
+}
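Both arms above share the same truncation arithmetic; a sketch of the clamp, with minSize standing in for header.IPv4MinimumProcessableDatagramSize or header.IPv6MinimumMTU (the function name is illustrative):

    // clampICMPPayload caps the quoted payload so that the ICMP error,
    // including its headers, never exceeds minSize bytes on the wire.
    func clampICMPPayload(mtu, headerLen, payloadLen, minSize int) int {
    	if mtu > minSize {
    		mtu = minSize
    	}
    	if available := mtu - headerLen; payloadLen > available {
    		payloadLen = available
    	}
    	return payloadLen
    }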
+
+// SetOption implements stack.TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements stack.TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
+ // Packet is too small.
+ return false
+ }
+ pkt.TransportHeader = h
+ pkt.Data.TrimFront(header.UDPMinimumSize)
+ return true
+}
+
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
+}
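Tying the file together, a minimal sketch of activating the protocol and opening an endpoint, per the package doc comment (IPv4-only for brevity; error handling elided):

    s := stack.New(stack.Options{
    	NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
    	TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
    })
    var wq waiter.Queue
    ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
    if err != nil {
    	// handle the *tcpip.Error
    }
    defer ep.Close()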
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
new file mode 100644
index 000000000..91ba031fa
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -0,0 +1,2072 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_test
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Addresses and ports used for testing. It is recommended that tests stick to
+// using these addresses as it allows using the testFlow helper.
+// Naming rules: 'stack*' denotes local addresses and ports, while 'test*'
+// denotes the remote endpoint.
+const (
+ v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
+ testV4MappedAddr = v4MappedAddrPrefix + testAddr
+ multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
+ broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
+ v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+ multicastAddr = "\xe8\x2b\xd3\xea"
+ multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
+
+ // defaultMTU is the MTU, in bytes, used throughout the tests, except
+ // where another value is explicitly used. It is chosen to match the MTU
+ // of loopback interfaces on Linux systems.
+ defaultMTU = 65536
+)
+
+// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
+// a packet header. These values are used to populate a header or verify one.
+// Note that because they are used in packet headers, the addresses are never in
+// a V4-mapped format.
+type header4Tuple struct {
+ srcAddr tcpip.FullAddress
+ dstAddr tcpip.FullAddress
+}
+
+// testFlow implements a helper type used for sending and receiving test
+// packets. A given test flow value defines 1) the socket endpoint used for the
+// test and 2) the type of packet sent or received on the endpoint. E.g., a
+// multicastV6Only flow is a V6 multicast packet passing through a V6-only
+// endpoint. The type provides helper methods to characterize the flow (e.g.,
+// isV4) as well as return a proper header4Tuple for it.
+type testFlow int
+
+const (
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+)
+
+func (flow testFlow) String() string {
+ switch flow {
+ case unicastV4:
+ return "unicastV4"
+ case unicastV6:
+ return "unicastV6"
+ case unicastV6Only:
+ return "unicastV6Only"
+ case unicastV4in6:
+ return "unicastV4in6"
+ case multicastV4:
+ return "multicastV4"
+ case multicastV6:
+ return "multicastV6"
+ case multicastV6Only:
+ return "multicastV6Only"
+ case multicastV4in6:
+ return "multicastV4in6"
+ case broadcast:
+ return "broadcast"
+ case broadcastIn6:
+ return "broadcastIn6"
+ default:
+ return "unknown"
+ }
+}
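A single flow value answers several questions at once via the helpers defined below; for instance, a sketch:

    flow := multicastV4in6
    _ = flow.sockProto()   // ipv6.ProtocolNumber: the socket is dual-stack.
    _ = flow.netProto()    // ipv4.ProtocolNumber: the wire packet is V4.
    _ = flow.isMapped()    // true: addresses appear in V4-mapped form.
    _ = flow.isMulticast() // true: the destination is a multicast group.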
+
+// packetDirection explains if a flow is incoming (read) or outgoing (write).
+type packetDirection int
+
+const (
+ incoming packetDirection = iota
+ outgoing
+)
+
+// header4Tuple returns the header4Tuple for the given flow and direction. Note
+// that the tuple contains no mapped addresses as those only exist at the socket
+// level but not at the packet header level.
+func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
+ var h header4Tuple
+ if flow.isV4() {
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastAddr
+ } else if flow.isBroadcast() {
+ h.dstAddr.Addr = broadcastAddr
+ }
+ } else { // IPv6
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastV6Addr
+ }
+ }
+ return h
+}
+
+func (flow testFlow) getMcastAddr() tcpip.Address {
+ if flow.isV4() {
+ return multicastAddr
+ }
+ return multicastV6Addr
+}
+
+// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
+// if it is applicable to the flow.
+func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
+ if flow.isMapped() {
+ return v4MappedAddrPrefix + v4Addr
+ }
+ return v4Addr
+}
+
+// netProto returns the protocol number used for the network packet.
+func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
+ if flow.isV4() {
+ return ipv4.ProtocolNumber
+ }
+ return ipv6.ProtocolNumber
+}
+
+// sockProto returns the protocol number used when creating the socket
+// endpoint for this flow.
+func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
+ switch flow {
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ return ipv6.ProtocolNumber
+ case unicastV4, multicastV4, broadcast:
+ return ipv4.ProtocolNumber
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
+ if flow.isV4() {
+ return checker.IPv4
+ }
+ return checker.IPv6
+}
+
+func (flow testFlow) isV6() bool { return !flow.isV4() }
+func (flow testFlow) isV4() bool {
+ return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
+}
+
+func (flow testFlow) isV6Only() bool {
+ switch flow {
+ case unicastV6Only, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMulticast() bool {
+ switch flow {
+ case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isBroadcast() bool {
+ switch flow {
+ case broadcast, broadcastIn6:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMapped() bool {
+ switch flow {
+ case unicastV4in6, multicastV4in6, broadcastIn6:
+ return true
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+ t.Helper()
+ return newDualTestContextWithOptions(t, mtu, stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+}
+
+func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+ t.Helper()
+
+ s := stack.New(options)
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
+
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNIC(1, wep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress failed: %s", err)
+ }
+
+ if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress failed: %s", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: ep,
+ }
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
+ c.t.Helper()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
+ if err != nil {
+ c.t.Fatal("NewEndpoint failed: ", err)
+ }
+}
+
+func (c *testContext) createEndpointForFlow(flow testFlow) {
+ c.t.Helper()
+
+ c.createEndpoint(flow.sockProto())
+ if flow.isV6Only() {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ } else if flow.isBroadcast() {
+ if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ }
+}
+
+// getPacketAndVerify reads a packet from the link endpoint and verifies the
+// header against expected values from the given test flow. In addition, it
+// calls any extra checker functions provided.
+func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
+ c.t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ h := flow.header4Tuple(outgoing)
+ checkers = append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
+ return b
+}
+
+// injectPacket creates a packet of the given flow and with the given payload,
+// and injects it into the link endpoint.
+func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+ c.t.Helper()
+
+ h := flow.header4Tuple(incoming)
+ if flow.isV4() {
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ } else {
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+}
+
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ return buf
+}
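Both build helpers rely on the one's-complement identity: writing ^CalculateChecksum(xsum) into the header guarantees that the receiver's fold over the same bytes sums to 0xffff. A sketch of the matching receive-side check, where u, payload, srcAddr, and dstAddr are the same values used in the generation above:

    rx := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, uint16(len(u)))
    rx = header.Checksum(payload, rx)
    if u.CalculateChecksum(rx) != 0xffff {
    	// The stack would count this as a checksum error.
    }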
+
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ return buf
+}
+
+func newPayload() []byte {
+ return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+ b := make([]byte, minSize+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ defer ep.Close()
+
+ opts := stack.NICOptions{Name: "my_device"}
+ if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ }
+
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
+ // a compiler error.
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *tcpip.NICID
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %d, want %d", got, want)
+ }
+ })
+ }
+}
+
+// testReadInternal sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then attempts to read it from the
+// UDP endpoint and, if the read is expected to succeed, verifies its
+// correctness, including running any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
+ c.t.Helper()
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ // Try to receive the data.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&we, waiter.EventIn)
+ defer c.wq.EventUnregister(&we)
+
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
+ var addr tcpip.FullAddress
+ v, cm, err := c.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ // Wait for data to become available.
+ select {
+ case <-ch:
+ v, cm, err = c.ep.Read(&addr)
+
+ case <-time.After(300 * time.Millisecond):
+ if packetShouldBeDropped {
+ return // expected to time out
+ }
+ c.t.Fatal("timed out waiting for data")
+ }
+ }
+
+ if expectReadError && err != nil {
+ c.checkEndpointReadStats(1, epstats, err)
+ return
+ }
+
+ if err != nil {
+ c.t.Fatal("Read failed:", err)
+ }
+
+ if packetShouldBeDropped {
+ c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ }
+
+ // Check the peer address.
+ h := flow.header4Tuple(incoming)
+ if addr.Addr != h.srcAddr.Addr {
+ c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
+ }
+
+ // Check the payload.
+ if !bytes.Equal(payload, v) {
+ c.t.Fatalf("bad payload: got %x, want %x", v, payload)
+ }
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
+ c.checkEndpointReadStats(1, epstats, err)
+}
+
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
+ c.t.Helper()
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
+}
+
+// testFailingRead sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then tries to read it from the UDP
+// endpoint and expects this to fail.
+func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
+ c.t.Helper()
+ testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
+}
+
+func TestBindEphemeralPort(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %s", err)
+ }
+}
+
+func TestBindReservedPort(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+
+ addr, err := c.ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress failed: %s", err)
+ }
+
+ // We can't bind the address reserved by the connected endpoint above.
+ {
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+ if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
+ t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ }
+ }
+
+ func() {
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+ // We can't bind ipv4-any on the port reserved by the connected endpoint
+ // above, since the endpoint is dual-stack.
+ if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
+ t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ }
+ // We can bind an ipv4 address on this port, though.
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %s", err)
+ }
+ }()
+
+ // Once the connected endpoint releases its port reservation, we are able to
+ // bind ipv4-any once again.
+ c.ep.Close()
+ func() {
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+ if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %s", err)
+ }
+ }()
+}
+
+func TestV4ReadOnV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4in6)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, unicastV4in6)
+}
+
+func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4in6)
+
+ // Bind to v4 mapped wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, unicastV4in6)
+}
+
+func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4in6)
+
+ // Bind to local address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, unicastV4in6)
+}
+
+func TestV6ReadOnV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV6)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, unicastV6)
+}
+
+// TestV4ReadSelfSource checks that packets coming from a local IP address are
+// correctly dropped when handleLocal is true and not otherwise.
+func TestV4ReadSelfSource(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ handleLocal bool
+ wantErr *tcpip.Error
+ wantInvalidSource uint64
+ }{
+ {"HandleLocal", false, nil, 0},
+ {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ HandleLocal: tt.handleLocal,
+ })
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ h.srcAddr = h.dstAddr
+
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
+ t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
+ }
+
+ if _, _, err := c.ep.Read(nil); err != tt.wantErr {
+ t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestV4ReadOnV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, unicastV4)
+}
+
+// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
+// address and receive data sent to that address.
+func TestReadOnBoundToMulticast(t *testing.T) {
+ // FIXME(b/128189410): multicastV4in6 currently doesn't work as
+ // AddMembershipOption doesn't handle V4in6 addresses.
+ for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to multicast address.
+ mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ // Join multicast group.
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
+
+ // Check that we receive multicast packets but not unicast or broadcast
+ // ones.
+ testRead(c, flow)
+ testFailingRead(c, broadcast, false /* expectReadError */)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
+ })
+ }
+}
+
+// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
+// address and can receive only broadcast data.
+func TestV4ReadOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to broadcast address.
+ bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Check that we receive broadcast packets but not unicast ones.
+ testRead(c, flow)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
+ })
+ }
+}
+
+// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
+// and receive broadcast and unicast data.
+func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s (", err)
+ }
+
+ // Check that we receive both broadcast and unicast packets.
+ testRead(c, flow)
+ testRead(c, unicastV4)
+ })
+ }
+}
+
+// testFailingWrite sends a packet of the given test flow into the UDP endpoint
+// and verifies it fails with the provided error code.
+func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+ c.t.Helper()
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+
+ payload := buffer.View(newPayload())
+ _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ })
+ c.checkEndpointWriteStats(1, epstats, gotErr)
+ if gotErr != wantErr {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
+ }
+}
+
+// testWrite sends a packet of the given test flow from the UDP endpoint to the
+// flow's destination address:port. It then receives it from the link endpoint
+// and verifies its correctness including any additional checker functions
+// provided.
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, true, checkers...)
+}
+
+// testWriteWithoutDestination sends a packet of the given test flow from the
+// UDP endpoint without giving a destination address:port. It then receives it
+// from the link endpoint and verifies its correctness including any additional
+// checker functions provided.
+func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, false, checkers...)
+}
+
+func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
+ writeOpts := tcpip.WriteOptions{}
+ if setDest {
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+ writeOpts = tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ }
+ }
+ payload := buffer.View(newPayload())
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ if err != nil {
+ c.t.Fatalf("Write failed: %s", err)
+ }
+ if n != int64(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+ c.checkEndpointWriteStats(1, epstats, err)
+ // Receive the packet and check the payload.
+ b := c.getPacketAndVerify(flow, checkers...)
+ var udp header.UDP
+ if flow.isV4() {
+ udp = header.UDP(header.IPv4(b).Payload())
+ } else {
+ udp = header.UDP(header.IPv6(b).Payload())
+ }
+ if !bytes.Equal(payload, udp.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ }
+
+ return udp.SourcePort()
+}
+
+func testDualWrite(c *testContext) uint16 {
+ c.t.Helper()
+
+ v4Port := testWrite(c, unicastV4in6)
+ v6Port := testWrite(c, unicastV6)
+ if v4Port != v6Port {
+ c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
+ }
+
+ return v4Port
+}
+
+func TestDualWriteUnbound(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ testDualWrite(c)
+}
+
+func TestDualWriteBoundToWildcard(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ p := testDualWrite(c)
+ if p != stackPort {
+ c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
+ }
+}
+
+func TestDualWriteConnectedToV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Connect to v6 address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, unicastV6)
+
+ // Write to V4 mapped address.
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ const want = 1
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
+ c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
+ }
+}
+
+func TestDualWriteConnectedToV4Mapped(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Connect to v4 mapped address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, unicastV4in6)
+
+ // Write to v6 address.
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+}
+
+func TestV4WriteOnV6Only(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV6Only)
+
+ // Write to V4 mapped address.
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
+}
+
+func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to v4 mapped address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Write to v6 address.
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+}
+
+func TestV6WriteOnConnected(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Connect to v6 address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+
+ testWriteWithoutDestination(c, unicastV6)
+}
+
+func TestV4WriteOnConnected(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Connect to v4 mapped address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+
+ testWriteWithoutDestination(c, unicastV4)
+}
+
+// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
+// that is bound to a V4 multicast address.
+func TestWriteOnBoundToV4Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
+// socket that is bound to a V4-mapped multicast address.
+func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6, multicastV6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV6OnlyMulticast checks that we can send packets out of a
+// V6-only socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to the V4-mapped broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+func TestReadIncrementsPacketsReceived(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ // Create a dual-stack UDP endpoint.
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testRead(c, unicastV4)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
+ c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
+ }
+}
+
+func TestWriteIncrementsPacketsSent(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ testDualWrite(c)
+
+ var want uint64 = 2
+ if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
+ c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
+ }
+}
+
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
+func TestTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const multicastTTL = 42
+ if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
+ c.t.Fatalf("SetSockOptInt failed: %s", err)
+ }
+
+ var wantTTL uint8
+ if flow.isMulticast() {
+ wantTTL = multicastTTL
+ } else {
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
+ }
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ }
+
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+}
+
+func TestSetTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
+ }
+
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
+ }
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
+ if err != nil {
+ t.Fatal(err)
+ }
+ ep.Close()
+
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+ })
+ }
+}
+
+func TestSetTOS(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = testTOS
+ v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
+ }
+
+ if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
+ }
+
+ v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
+ }
+
+ if v != tos {
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
+func TestSetTClass(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tClass = testTOS
+ v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
+ }
+
+ if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
+ }
+
+ v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
+ }
+
+ if v != tClass {
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
+ }
+
+ // The header getter for TClass is called TOS, so use that checker.
+ testWrite(c, flow, checker.TOS(tClass, 0))
+ })
+ }
+}
+
+func TestReceiveTosTClass(t *testing.T) {
+ testCases := []struct {
+ name string
+ getReceiveOption tcpip.SockOptBool
+ tests []testFlow
+ }{
+ {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+ {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ }
+ for _, testCase := range testCases {
+ for _, flow := range testCase.tests {
+ t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+ option := testCase.getReceiveOption
+ name := testCase.name
+
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ }
+ // Test for expected default value.
+ if v {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+ }
+
+ want := true
+ if err := c.ep.SetSockOptBool(option, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+ }
+
+ got, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ }
+
+ if got != want {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+ }
+
+ // Verify that the correct received TOS or TClass is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+ switch option {
+ case tcpip.ReceiveTClassOption:
+ testRead(c, flow, checker.ReceiveTClass(testTOS))
+ case tcpip.ReceiveTOSOption:
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ default:
+ t.Fatalf("unknown test variant: %s", name)
+ }
+ })
+ }
+ }
+}
+
+func TestMulticastInterfaceOption(t *testing.T) {
+ for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, bindTyp := range []string{"bound", "unbound"} {
+ t.Run(bindTyp, func(t *testing.T) {
+ for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+ t.Run(optTyp, func(t *testing.T) {
+ h := flow.header4Tuple(outgoing)
+ mcastAddr := h.dstAddr.Addr
+ localIfAddr := h.srcAddr.Addr
+
+ var ifoptSet tcpip.MulticastInterfaceOption
+ switch optTyp {
+ case "use local-addr":
+ ifoptSet.InterfaceAddr = localIfAddr
+ case "use NICID":
+ ifoptSet.NIC = 1
+ case "use local-addr and NIC":
+ ifoptSet.InterfaceAddr = localIfAddr
+ ifoptSet.NIC = 1
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(flow.sockProto())
+
+ if bindTyp == "bound" {
+ // Bind the socket by connecting to the multicast address.
+ // This may have an influence on how the multicast interface
+ // is set.
+ addr := tcpip.FullAddress{
+ Addr: flow.mapAddrIfApplicable(mcastAddr),
+ Port: stackPort,
+ }
+ if err := c.ep.Connect(addr); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+ }
+
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %s", err)
+ }
+
+ // Verify multicast interface addr and NIC were set correctly.
+ // Note that NIC must be 1 since this is our outgoing interface.
+ ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+ var ifoptGot tcpip.MulticastInterfaceOption
+ if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+ c.t.Fatalf("GetSockOpt failed: %s", err)
+ }
+ if ifoptGot != ifoptWant {
+ c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a UDP datagram is received on ports for which there
+// is no bound UDP socket.
+func TestV4UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload, if true, results in a payload large enough that
+ // the final generated IPv4 packet is larger than
+ // header.IPv4MinimumProcessableDatagramSize.
+ largePayload bool
+ }{
+ {unicastV4, true, false},
+ {unicastV4, true, true},
+ {multicastV4, false, false},
+ {multicastV4, false, true},
+ {broadcast, false, false},
+ {broadcast, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(576)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
+ t.Fatalf("unexpected packet received: %+v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
+
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ }
+ })
+ }
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a UDP datagram is received on ports for which there
+// is no bound UDP socket.
+func TestV6UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload, if true, results in a payload large enough to
+ // create an IPv6 packet larger than header.IPv6MinimumMTU bytes.
+ largePayload bool
+ }{
+ {unicastV6, true, false},
+ {unicastV6, true, true},
+ {multicastV6, false, false},
+ {multicastV6, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(1280)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
+ t.Fatalf("unexpected packet received: %+v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ }
+ })
+ }
+}
+
+// TestIncrementMalformedPacketsReceived verifies that the global and endpoint
+// malformed-packets-received stats are incremented when a malformed packet arrives.
+func TestIncrementMalformedPacketsReceived(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+
+ // Invalidate the UDP header length field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+}
+
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ h := unicastV6.header4Tuple(incoming)
+
+ // Allocate a buffer for an IPv6 and too-short UDP header.
+ const udpSize = header.UDPMinimumSize - 1
+ buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: header.UDPMinimumSize,
+ })
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+ // Copy all but the last byte of the UDP header into the packet.
+ copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV4 verifies that, if a checksum error is
+// detected, the global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies that, if a checksum error is
+// detected, the global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies that a modified payload is detected as a
+// checksum error and that the global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies that a modified payload is detected as a
+// checksum error and that the global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies that, if the checksum value is zero, the global
+// and endpoint stats are *not* incremented (the UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies that, if the checksum value is zero, the global
+// and endpoint stats are incremented (the UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestShutdownRead verifies endpoint read shutdown and error
+// stats increment on packet receive.
+func TestShutdownRead(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ testFailingRead(c, unicastV6, true /* expectReadError */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownWrite verifies endpoint write shutdown and error
+// stats increment on packet write.
+func TestShutdownWrite(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+}
+
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil:
+ want.PacketsSent.IncrementBy(incr)
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ want.WriteErrors.InvalidArgs.IncrementBy(incr)
+ case tcpip.ErrClosedForSend:
+ want.WriteErrors.WriteClosed.IncrementBy(incr)
+ case tcpip.ErrInvalidEndpointState:
+ want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
+ case tcpip.ErrNoLinkAddress:
+ want.SendErrors.NoLinkAddr.IncrementBy(incr)
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ want.SendErrors.NoRoute.IncrementBy(incr)
+ default:
+ want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
+
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil, tcpip.ErrWouldBlock:
+ case tcpip.ErrClosedForReceive:
+ want.ReadErrors.ReadClosed.IncrementBy(incr)
+ default:
+ c.t.Errorf("Endpoint error missing stats update err %v", err)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
diff --git a/pkg/test/criutil/BUILD b/pkg/test/criutil/BUILD
new file mode 100644
index 000000000..a7b082cee
--- /dev/null
+++ b/pkg/test/criutil/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "criutil",
+ testonly = 1,
+ srcs = ["criutil.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
new file mode 100644
index 000000000..8fed29ff5
--- /dev/null
+++ b/pkg/test/criutil/criutil.go
@@ -0,0 +1,317 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package criutil contains utility functions for interacting with the
+// Container Runtime Interface (CRI), principally via the crictl command line
+// tool. This requires critools to be installed on the local system.
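+//
+// A minimal usage sketch follows; the endpoint path and image name are
+// illustrative assumptions, not prescriptive values:
+//
+//	cc := criutil.NewCrictl(logger, "/run/containerd/containerd.sock")
+//	defer cc.CleanUp()
+//	podID, contID, err := cc.StartPodAndContainer("basic/alpine", sbSpec, contSpec)
+//	if err != nil { /* handle error */ }
+//	defer cc.StopPodAndContainer(podID, contID)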
+package criutil
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "os/exec"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Crictl contains information required to run the crictl utility.
+type Crictl struct {
+ logger testutil.Logger
+ endpoint string
+ cleanup []func()
+}
+
+// resolvePath attempts to find the path of the given binary. It may return an
+// invalid path, in which case execution will fail with a sensible error.
+func resolvePath(executable string) string {
+ guess, err := exec.LookPath(executable)
+ if err != nil {
+ guess = fmt.Sprintf("/usr/local/bin/%s", executable)
+ }
+ return guess
+}
+
+// NewCrictl returns a Crictl configured with an endpoint over which it will
+// talk to containerd.
+func NewCrictl(logger testutil.Logger, endpoint string) *Crictl {
+ // Attempt to find the executable, but don't bother propagating the
+ // error at this point. The first command executed will return with a
+ // binary not found error.
+ return &Crictl{
+ logger: logger,
+ endpoint: endpoint,
+ }
+}
+
+// CleanUp executes cleanup functions.
+func (cc *Crictl) CleanUp() {
+ for _, c := range cc.cleanup {
+ c()
+ }
+ cc.cleanup = nil
+}
+
+// RunPod creates a sandbox. It corresponds to `crictl runp`.
+func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
+ podID, err := cc.run("runp", sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("runp failed: %v", err)
+ }
+ // Strip the trailing newline from crictl output.
+ return strings.TrimSpace(podID), nil
+}
+
+// Create creates a container within a sandbox. It corresponds to `crictl
+// create`.
+func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) {
+ podID, err := cc.run("create", podID, contSpecFile, sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("create failed: %v", err)
+ }
+ // Strip the trailing newline from crictl output.
+ return strings.TrimSpace(podID), nil
+}
+
+// Start starts a container. It corresponds to `crictl start`.
+func (cc *Crictl) Start(contID string) (string, error) {
+ output, err := cc.run("start", contID)
+ if err != nil {
+ return "", fmt.Errorf("start failed: %v", err)
+ }
+ return output, nil
+}
+
+// Stop stops a container. It corresponds to `crictl stop`.
+func (cc *Crictl) Stop(contID string) error {
+ _, err := cc.run("stop", contID)
+ return err
+}
+
+// Exec execs a program inside a container. It corresponds to `crictl exec`.
+func (cc *Crictl) Exec(contID string, args ...string) (string, error) {
+ a := []string{"exec", contID}
+ a = append(a, args...)
+ output, err := cc.run(a...)
+ if err != nil {
+ return "", fmt.Errorf("exec failed: %v", err)
+ }
+ return output, nil
+}
+
+// Logs retrieves the container logs. It corresponds to `crictl logs`.
+func (cc *Crictl) Logs(contID string, args ...string) (string, error) {
+ a := []string{"logs", contID}
+ a = append(a, args...)
+ output, err := cc.run(a...)
+ if err != nil {
+ return "", fmt.Errorf("logs failed: %v", err)
+ }
+ return output, nil
+}
+
+// Rm removes a container. It corresponds to `crictl rm`.
+func (cc *Crictl) Rm(contID string) error {
+ _, err := cc.run("rm", contID)
+ return err
+}
+
+// StopPod stops a pod. It corresponds to `crictl stopp`.
+func (cc *Crictl) StopPod(podID string) error {
+ _, err := cc.run("stopp", podID)
+ return err
+}
+
+// containerConfig is a minimal copy of
+// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/apis/cri/runtime/v1alpha2/api.proto
+// It only contains fields needed for testing.
+type containerConfig struct {
+ Status containerStatus
+}
+
+type containerStatus struct {
+ Network containerNetwork
+}
+
+type containerNetwork struct {
+ IP string
+}
+
+// PodIP returns a pod's IP address.
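+//
+// It parses `crictl inspectp` output, which is assumed to contain, roughly,
+// the following shape (the address shown is illustrative):
+//
+//	{"status": {"network": {"ip": "10.0.0.2"}}}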
+func (cc *Crictl) PodIP(podID string) (string, error) {
+ output, err := cc.run("inspectp", podID)
+ if err != nil {
+ return "", err
+ }
+ conf := &containerConfig{}
+ if err := json.Unmarshal([]byte(output), conf); err != nil {
+ return "", fmt.Errorf("failed to unmarshal JSON: %v, %s", err, output)
+ }
+ if conf.Status.Network.IP == "" {
+ return "", fmt.Errorf("no IP found in config: %s", output)
+ }
+ return conf.Status.Network.IP, nil
+}
+
+// RmPod removes a pod. It corresponds to `crictl rmp`.
+func (cc *Crictl) RmPod(podID string) error {
+ _, err := cc.run("rmp", podID)
+ return err
+}
+
+// Import imports the given container from the local Docker instance.
+func (cc *Crictl) Import(image string) error {
+ // Note that importing may push a lot of bytes, so it can take a while;
+ // only the connection itself is bounded by an explicit connect timeout.
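+ // The effect is roughly equivalent to this shell pipeline (illustrative):
+ //
+ //   docker save <image> | ctr --address=<endpoint> -n k8s.io images import -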
+ cmd := testutil.Command(cc.logger,
+ resolvePath("ctr"),
+ fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
+ fmt.Sprintf("--address=%s", cc.endpoint),
+ "-n", "k8s.io", "images", "import", "-")
+ cmd.Stderr = os.Stderr // Pass through errors.
+
+ // Create a pipe and start the program.
+ w, err := cmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ // Save the image on the other end.
+ if err := dockerutil.Save(cc.logger, image, w); err != nil {
+ cmd.Wait()
+ return err
+ }
+
+ // Close our pipe reference & see if it was loaded.
+ if err := w.Close(); err != nil {
+ return err
+ }
+
+ return cmd.Wait()
+}
+
+// StartContainer pulls the given image and starts the container in the
+// sandbox with the given podID.
+//
+// Note that the image will always be imported from the local docker daemon.
+func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) {
+ if err := cc.Import(image); err != nil {
+ return "", err
+ }
+
+ // Write the specs to files that can be read by crictl.
+ sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ if err != nil {
+ return "", fmt.Errorf("failed to write sandbox spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+ contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec)
+ if err != nil {
+ return "", fmt.Errorf("failed to write container spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+
+ return cc.startContainer(podID, image, sbSpecFile, contSpecFile)
+}
+
+func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) {
+ contID, err := cc.Create(podID, contSpecFile, sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err)
+ }
+
+ if _, err := cc.Start(contID); err != nil {
+ return "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err)
+ }
+
+ return contID, nil
+}
+
+// StopContainer stops and deletes the container with the given container ID.
+func (cc *Crictl) StopContainer(contID string) error {
+ if err := cc.Stop(contID); err != nil {
+ return fmt.Errorf("failed to stop container %q: %v", contID, err)
+ }
+
+ if err := cc.Rm(contID); err != nil {
+ return fmt.Errorf("failed to remove container %q: %v", contID, err)
+ }
+
+ return nil
+}
+
+// StartPodAndContainer starts a sandbox and container in that sandbox. It
+// returns the pod ID and container ID.
+func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+ if err := cc.Import(image); err != nil {
+ return "", "", err
+ }
+
+ // Write the specs to files that can be read by crictl.
+ sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to write sandbox spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+ contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to write container spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+
+ podID, err := cc.RunPod(sbSpecFile)
+ if err != nil {
+ return "", "", err
+ }
+
+ contID, err := cc.startContainer(podID, image, sbSpecFile, contSpecFile)
+
+ return podID, contID, err
+}
+
+// StopPodAndContainer stops a container and pod.
+func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
+ if err := cc.StopContainer(contID); err != nil {
+ return fmt.Errorf("failed to stop container %q in pod %q: %v", contID, podID, err)
+ }
+
+ if err := cc.StopPod(podID); err != nil {
+ return fmt.Errorf("failed to stop pod %q: %v", podID, err)
+ }
+
+ if err := cc.RmPod(podID); err != nil {
+ return fmt.Errorf("failed to remove pod %q: %v", podID, err)
+ }
+
+ return nil
+}
+
+// run runs crictl with the given args.
+func (cc *Crictl) run(args ...string) (string, error) {
+ defaultArgs := []string{
+ resolvePath("crictl"),
+ "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
+ "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
+ }
+ fullArgs := append(defaultArgs, args...)
+ out, err := testutil.Command(cc.logger, fullArgs...).CombinedOutput()
+ return string(out), err
+}
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
new file mode 100644
index 000000000..83b80c8bc
--- /dev/null
+++ b/pkg/test/dockerutil/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "dockerutil",
+ testonly = 1,
+ srcs = [
+ "container.go",
+ "dockerutil.go",
+ "exec.go",
+ "network.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/testutil",
+ "@com_github_docker_docker//api/types:go_default_library",
+ "@com_github_docker_docker//api/types/container:go_default_library",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ "@com_github_docker_docker//api/types/network:go_default_library",
+ "@com_github_docker_docker//client:go_default_library",
+ "@com_github_docker_docker//pkg/stdcopy:go_default_library",
+ "@com_github_docker_go_connections//nat:go_default_library",
+ ],
+)
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
new file mode 100644
index 000000000..17acdaf6f
--- /dev/null
+++ b/pkg/test/dockerutil/container.go
@@ -0,0 +1,501 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/container"
+ "github.com/docker/docker/api/types/mount"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "github.com/docker/docker/pkg/stdcopy"
+ "github.com/docker/go-connections/nat"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Container represents a Docker container, allowing the user to configure
+// and control it as one would with the 'docker' client. Container is backed
+// by the official Go Docker API.
+// See: https://pkg.go.dev/github.com/docker/docker.
+type Container struct {
+ Name string
+ Runtime string
+
+ logger testutil.Logger
+ client *client.Client
+ id string
+ mounts []mount.Mount
+ links []string
+ cleanups []func()
+ copyErr error
+
+ // Stores streams attached to the container. Used by WaitForOutputSubmatch.
+ streams types.HijackedResponse
+
+ // stores previously read data from the attached streams.
+ streamBuf bytes.Buffer
+}
+
+// RunOpts are options for running a container.
+type RunOpts struct {
+ // Image is the image relative to images/. This will be mangled
+ // appropriately, to ensure that only first-party images are used.
+ Image string
+
+ // Memory is the memory limit in bytes.
+ Memory int
+
+ // CpusetCpus is the set of CPUs in which to allow execution ("0", "1", "0-2").
+ CpusetCpus string
+
+ // Ports are the ports to be allocated.
+ Ports []int
+
+ // WorkDir sets the working directory.
+ WorkDir string
+
+ // ReadOnly sets the read-only flag.
+ ReadOnly bool
+
+ // Env are additional environment variables.
+ Env []string
+
+ // User is the user to use.
+ User string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // CapAdd are the extra set of capabilities to add.
+ CapAdd []string
+
+ // CapDrop are the extra set of capabilities to drop.
+ CapDrop []string
+
+ // Mounts is the list of directories/files to be mounted inside the container.
+ Mounts []mount.Mount
+
+ // Links is the list of containers to be connected to the container.
+ Links []string
+}
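+
+// A typical use of Container, as a rough sketch (the image name and
+// arguments are illustrative only):
+//
+//	ctx := context.Background()
+//	c := dockerutil.MakeContainer(ctx, logger)
+//	defer c.CleanUp(ctx)
+//	out, err := c.Run(ctx, dockerutil.RunOpts{Image: "basic/alpine"}, "echo", "hello")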
+
+// MakeContainer sets up the struct for a Docker container.
+//
+// Names of containers will be unique.
+func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ return nil
+ }
+
+ client.NegotiateAPIVersion(ctx)
+
+ return &Container{
+ logger: logger,
+ Name: name,
+ Runtime: *runtime,
+ client: client,
+ }
+}
+
+// Spawn is analogous to 'docker run -d'.
+func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
+ if err := c.create(ctx, r, args); err != nil {
+ return err
+ }
+ return c.Start(ctx)
+}
+
+// SpawnProcess is analogous to 'docker run -it'. It returns a process
+// which represents the root process.
+func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) {
+ config, hostconf, netconf := c.ConfigsFrom(r, args...)
+ config.Tty = true
+ config.OpenStdin = true
+
+ if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil {
+ return Process{}, err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return Process{}, err
+ }
+
+ return Process{container: c, conn: c.streams}, nil
+}
+
+// Run is analogous to 'docker run'.
+func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
+ if err := c.create(ctx, r, args); err != nil {
+ return "", err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return "", err
+ }
+
+ if err := c.Wait(ctx); err != nil {
+ return "", err
+ }
+
+ return c.Logs(ctx)
+}
+
+// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom'
+// and Start.
+func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) {
+ return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{}
+}
+
+// MakeLink formats a link to add to a RunOpts.
+func (c *Container) MakeLink(target string) string {
+ return fmt.Sprintf("%s:%s", c.Name, target)
+}
+
+// CreateFrom creates a container from the given configs.
+func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name)
+ if err != nil {
+ return err
+ }
+ c.id = cont.ID
+ return nil
+}
+
+// Create is analogous to 'docker create'.
+func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
+ return c.create(ctx, r, args)
+}
+
+func (c *Container) create(ctx context.Context, r RunOpts, args []string) error {
+ conf := c.config(r, args)
+ hostconf := c.hostConfig(r)
+ cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
+ if err != nil {
+ return err
+ }
+ c.id = cont.ID
+ return nil
+}
+
+func (c *Container) config(r RunOpts, args []string) *container.Config {
+ ports := nat.PortSet{}
+ for _, p := range r.Ports {
+ port := nat.Port(fmt.Sprintf("%d", p))
+ ports[port] = struct{}{}
+ }
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+
+ return &container.Config{
+ Image: testutil.ImageByName(r.Image),
+ Cmd: args,
+ ExposedPorts: ports,
+ Env: env,
+ WorkingDir: r.WorkDir,
+ User: r.User,
+ }
+}
+
+func (c *Container) hostConfig(r RunOpts) *container.HostConfig {
+ c.mounts = append(c.mounts, r.Mounts...)
+
+ return &container.HostConfig{
+ Runtime: c.Runtime,
+ Mounts: c.mounts,
+ PublishAllPorts: true,
+ Links: r.Links,
+ CapAdd: r.CapAdd,
+ CapDrop: r.CapDrop,
+ Privileged: r.Privileged,
+ ReadonlyRootfs: r.ReadOnly,
+ Resources: container.Resources{
+ Memory: int64(r.Memory), // In bytes.
+ CpusetCpus: r.CpusetCpus,
+ },
+ }
+}
+
+// Start is analogous to 'docker start'.
+func (c *Container) Start(ctx context.Context) error {
+ // Open a connection to the container for parsing logs and for TTY.
+ streams, err := c.client.ContainerAttach(ctx, c.id,
+ types.ContainerAttachOptions{
+ Stream: true,
+ Stdin: true,
+ Stdout: true,
+ Stderr: true,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to connect to container: %v", err)
+ }
+
+ c.streams = streams
+ c.cleanups = append(c.cleanups, func() {
+ c.streams.Close()
+ })
+
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{})
+}
+
+// Stop is analogous to 'docker stop'.
+func (c *Container) Stop(ctx context.Context) error {
+ return c.client.ContainerStop(ctx, c.id, nil)
+}
+
+// Pause is analogous to 'docker pause'.
+func (c *Container) Pause(ctx context.Context) error {
+ return c.client.ContainerPause(ctx, c.id)
+}
+
+// Unpause is analogous to 'docker unpause'.
+func (c *Container) Unpause(ctx context.Context) error {
+ return c.client.ContainerUnpause(ctx, c.id)
+}
+
+// Checkpoint is analogous to 'docker checkpoint'.
+func (c *Container) Checkpoint(ctx context.Context, name string) error {
+ return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true})
+}
+
+// Restore is analogous to 'docker start --checkpoint [name]'.
+func (c *Container) Restore(ctx context.Context, name string) error {
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name})
+}
+
+// Logs is analogous to 'docker logs'.
+func (c *Container) Logs(ctx context.Context) (string, error) {
+ var out bytes.Buffer
+ err := c.logs(ctx, &out, &out)
+ return out.String(), err
+}
+
+func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error {
+ opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true}
+ writer, err := c.client.ContainerLogs(ctx, c.id, opts)
+ if err != nil {
+ return err
+ }
+ defer writer.Close()
+ _, err = stdcopy.StdCopy(stdout, stderr, writer)
+
+ return err
+}
+
+// ID returns the container id.
+func (c *Container) ID() string {
+ return c.id
+}
+
+// SandboxPid returns the container's pid.
+func (c *Container) SandboxPid(ctx context.Context) (int, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, err
+ }
+ return resp.ContainerJSONBase.State.Pid, nil
+}
+
+// FindIP returns the IP address of the container.
+func (c *Container) FindIP(ctx context.Context) (net.IP, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return nil, err
+ }
+
+ ip := net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress)
+ if ip == nil {
+ return net.IP{}, fmt.Errorf("invalid IP: %q", ip)
+ }
+ return ip, nil
+}
+
+// FindPort returns the host port that is mapped to 'sandboxPort'.
+func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) {
+ desc, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+ }
+
+ format := fmt.Sprintf("%d/tcp", sandboxPort)
+ ports, ok := desc.NetworkSettings.Ports[nat.Port(format)]
+ if !ok {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+
+ }
+
+ port, err := strconv.Atoi(ports[0].HostPort)
+ if err != nil {
+ return -1, fmt.Errorf("error parsing port %q: %v", port, err)
+ }
+ return port, nil
+}
+
+// CopyFiles copies the given files into a temporary directory and bind-mounts
+// that directory at the given target inside the container.
+func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) {
+ dir, err := ioutil.TempDir("", c.Name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
+ return
+ }
+ c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) })
+ if err := os.Chmod(dir, 0755); err != nil {
+ c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
+ return
+ }
+ for _, name := range sources {
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
+ return
+ }
+ dst := path.Join(dir, path.Base(name))
+ if err := testutil.Copy(src, dst); err != nil {
+ c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
+ return
+ }
+ c.logger.Logf("copy: %s -> %s", src, dst)
+ }
+ opts.Mounts = append(opts.Mounts, mount.Mount{
+ Type: mount.TypeBind,
+ Source: dir,
+ Target: target,
+ ReadOnly: false,
+ })
+}
+
+// Status inspects the container and returns its status.
+func (c *Container) Status(ctx context.Context) (types.ContainerState, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return types.ContainerState{}, err
+ }
+ return *resp.State, err
+}
+
+// Wait waits for the container to exit.
+func (c *Container) Wait(ctx context.Context) error {
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ }
+}
+
+// WaitTimeout waits for the container to exit with a timeout.
+func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error {
+ timeoutChan := time.After(timeout)
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ case <-timeoutChan:
+ return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds())
+ }
+}
+
+// WaitForOutput searches the container output for the given pattern and
+// returns the first match, or times out.
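+//
+// A minimal usage sketch (pattern and timeout are illustrative):
+//
+//	out, err := c.WaitForOutput(ctx, "server started", 30*time.Second)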
+func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) {
+ matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout)
+ if err != nil {
+ return "", err
+ }
+ if len(matches) == 0 {
+ return "", fmt.Errorf("didn't find pattern %s logs", pattern)
+ }
+ return matches[0], nil
+}
+
+// WaitForOutputSubmatch searches container logs for the given
+// pattern or times out. It returns any regexp submatches as well.
+func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) {
+ re := regexp.MustCompile(pattern)
+ if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+ return matches, nil
+ }
+
+ for exp := time.Now().Add(timeout); time.Now().Before(exp); {
+ c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
+ _, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader)
+
+ if err != nil {
+ // Check that the error wasn't a read-deadline timeout.
+ if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() {
+ return nil, err
+ }
+ }
+
+ if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+ return matches, nil
+ }
+ }
+
+ return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String())
+}
+
+// Kill kills the container.
+func (c *Container) Kill(ctx context.Context) error {
+ return c.client.ContainerKill(ctx, c.id, "")
+}
+
+// Remove is analogous to 'docker rm'.
+func (c *Container) Remove(ctx context.Context) error {
+ // Remove the container, along with any associated volumes and links.
+ remove := types.ContainerRemoveOptions{
+ RemoveVolumes: c.mounts != nil,
+ RemoveLinks: c.links != nil,
+ Force: true,
+ }
+ return c.client.ContainerRemove(ctx, c.Name, remove)
+}
+
+// CleanUp kills and deletes the container (best effort).
+func (c *Container) CleanUp(ctx context.Context) {
+ // Kill the container.
+ if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
+ // Just log; can't do anything here.
+ c.logger.Logf("error killing container %q: %v", c.Name, err)
+ }
+ // Remove the container.
+ if err := c.Remove(ctx); err != nil {
+ c.logger.Logf("error removing container %q: %v", c.Name, err)
+ }
+ // Forget all mounts.
+ c.mounts = nil
+ // Execute all cleanups.
+ for _, c := range c.cleanups {
+ c()
+ }
+ c.cleanups = nil
+}
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
new file mode 100644
index 000000000..f95ae3cd1
--- /dev/null
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dockerutil is a collection of utility functions.
+package dockerutil
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "os/exec"
+ "regexp"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+var (
+ // runtime is the runtime to use for tests. This will be applied to all
+ // containers. Note that the default here ("runsc") corresponds to the
+ // default used by the installations. This is important, because the
+ // default installer for vm_tests (in tools/installers:head, invoked
+ // via tools/vm:defs.bzl) will install with this name. So without
+ // changing anything, tests should have a runsc runtime available to
+ // them. Otherwise installers should update the existing runtime
+ // instead of installing a new one.
+ runtime = flag.String("runtime", "runsc", "specify which runtime to use")
+
+ // config is the default Docker daemon configuration path.
+ config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+)
+
+// EnsureSupportedDockerVersion checks if correct docker is installed.
+//
+// This logs directly to stderr, as it is typically called from a Main wrapper.
+func EnsureSupportedDockerVersion() {
+ cmd := exec.Command("docker", "version")
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ log.Fatalf("error running %q: %v", "docker version", err)
+ }
+ re := regexp.MustCompile(`Version:\s+(\d+)\.(\d+)\.\d.*`)
+ matches := re.FindStringSubmatch(string(out))
+ if len(matches) != 3 {
+ log.Fatalf("Invalid docker output: %s", out)
+ }
+ major, _ := strconv.Atoi(matches[1])
+ minor, _ := strconv.Atoi(matches[2])
+ if major < 17 || (major == 17 && minor < 9) {
+ log.Fatalf("Docker version 17.09.0 or greater is required, found: %02d.%02d", major, minor)
+ }
+}
+
+// RuntimePath returns the binary path for the current runtime.
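+//
+// The daemon configuration is expected to contain, roughly, the following
+// shape (the path below is an illustrative assumption):
+//
+//	{
+//	  "runtimes": {
+//	    "runsc": {"path": "/usr/local/bin/runsc"}
+//	  }
+//	}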
+func RuntimePath() (string, error) {
+ // Read the configuration data; the file must exist.
+ configBytes, err := ioutil.ReadFile(*config)
+ if err != nil {
+ return "", err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return "", err
+ }
+
+ // Decode the expected configuration.
+ r, ok := c["runtimes"]
+ if !ok {
+ return "", fmt.Errorf("no runtimes declared: %v", c)
+ }
+ rs, ok := r.(map[string]interface{})
+ if !ok {
+ // The runtimes are not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ r, ok = rs[*runtime]
+ if !ok {
+ // The expected runtime is not declared.
+ return "", fmt.Errorf("runtime %q not found: %v", *runtime, c)
+ }
+ rs, ok = r.(map[string]interface{})
+ if !ok {
+ // The runtime is not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ p, ok := rs["path"].(string)
+ if !ok {
+ // The runtime does not declare a path.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ return p, nil
+}
+
+// Save exports a container image to the given Writer.
+//
+// Note that the writer should be actively consuming the output, otherwise it
+// is not guaranteed that the Save will make any progress and the call may
+// stall indefinitely.
+//
+// This is called by criutil in order to import images.
+func Save(logger testutil.Logger, image string, w io.Writer) error {
+ cmd := testutil.Command(logger, "docker", "save", testutil.ImageByName(image))
+ cmd.Stdout = w // Send directly to the writer.
+ return cmd.Run()
+}
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
new file mode 100644
index 000000000..921d1da9e
--- /dev/null
+++ b/pkg/test/dockerutil/exec.go
@@ -0,0 +1,194 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/pkg/stdcopy"
+)
+
+// ExecOpts holds arguments for Exec calls.
+type ExecOpts struct {
+ // Env are additional environment variables.
+ Env []string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // User is the user to use.
+ User string
+
+ // Enables Tty and stdin for the created process.
+ UseTTY bool
+
+ // WorkDir is the working directory of the process.
+ WorkDir string
+}
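+
+// A minimal Exec usage sketch (the command is illustrative):
+//
+//	out, err := c.Exec(ctx, dockerutil.ExecOpts{}, "echo", "hello")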
+
+// Exec creates a process inside the container.
+func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) {
+ p, err := c.doExec(ctx, opts, args)
+ if err != nil {
+ return "", err
+ }
+
+ if exitStatus, err := p.WaitExitStatus(ctx); err != nil {
+ return "", err
+ } else if exitStatus != 0 {
+ out, _ := p.Logs()
+ return out, fmt.Errorf("process terminated with status: %d", exitStatus)
+ }
+
+ return p.Logs()
+}
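+
+// Example (illustrative sketch, not part of the original change; assumes a
+// started Container c and a context ctx): run a command to completion and
+// inspect its combined output.
+//
+//	out, err := c.Exec(ctx, dockerutil.ExecOpts{WorkDir: "/tmp"}, "ls", "-l")
+//	if err != nil {
+//		t.Fatalf("exec failed: %v", err)
+//	}
+//	t.Logf("output: %s", out)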
+
+// ExecProcess creates a process inside the container and returns a process struct
+// for the caller to use.
+func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) {
+ return c.doExec(ctx, opts, args)
+}
+
+func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) {
+ config := c.execConfig(r, args)
+ resp, err := c.client.ContainerExecCreate(ctx, c.id, config)
+ if err != nil {
+ return Process{}, fmt.Errorf("exec create failed with err: %v", err)
+ }
+
+ hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{})
+ if err != nil {
+ return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
+ }
+
+ if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
+ hijack.Close()
+ return Process{}, fmt.Errorf("exec start failed with err: %v", err)
+ }
+
+ return Process{
+ container: c,
+ execid: resp.ID,
+ conn: hijack,
+ }, nil
+}
+
+func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig {
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+ return types.ExecConfig{
+ AttachStdin: r.UseTTY,
+ AttachStderr: true,
+ AttachStdout: true,
+ Cmd: cmd,
+ Privileged: r.Privileged,
+ WorkingDir: r.WorkDir,
+ Env: env,
+ Tty: r.UseTTY,
+ User: r.User,
+ }
+}
+
+// Process represents a containerized process.
+type Process struct {
+ container *Container
+ execid string
+ conn types.HijackedResponse
+}
+
+// Write writes buf to the process's stdin.
+func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) {
+ p.conn.Conn.SetDeadline(time.Now().Add(timeout))
+ return p.conn.Conn.Write(buf)
+}
+
+// Read returns the process's stdout and stderr.
+func (p *Process) Read() (string, string, error) {
+ var stdout, stderr bytes.Buffer
+ if err := p.read(&stdout, &stderr); err != nil {
+ return "", "", err
+ }
+ return stdout.String(), stderr.String(), nil
+}
+
+// Logs returns combined stdout/stderr from the process.
+func (p *Process) Logs() (string, error) {
+ var out bytes.Buffer
+ if err := p.read(&out, &out); err != nil {
+ return "", err
+ }
+ return out.String(), nil
+}
+
+func (p *Process) read(stdout, stderr *bytes.Buffer) error {
+ _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader)
+ return err
+}
+
+// ExitCode returns the process's exit code.
+func (p *Process) ExitCode(ctx context.Context) (int, error) {
+ _, exitCode, err := p.runningExitCode(ctx)
+ return exitCode, err
+}
+
+// IsRunning checks if the process is running.
+func (p *Process) IsRunning(ctx context.Context) (bool, error) {
+ running, _, err := p.runningExitCode(ctx)
+ return running, err
+}
+
+// WaitExitStatus waits until the process completes and returns its exit
+// status.
+func (p *Process) WaitExitStatus(ctx context.Context) (int, error) {
+ waitChan := make(chan int)
+ errChan := make(chan error)
+
+ go func() {
+ for {
+ running, exitcode, err := p.runningExitCode(ctx)
+ if err != nil {
+ errChan <- fmt.Errorf("error waiting for process %s in container %s: %v", p.execid, p.container.Name, err)
+ return
+ }
+ if !running {
+ waitChan <- exitcode
+ return
+ }
+ time.Sleep(time.Millisecond * 500)
+ }
+ }()
+
+ select {
+ case ws := <-waitChan:
+ return ws, nil
+ case err := <-errChan:
+ return -1, err
+ }
+}
+
+// runningExitCode reports whether the process is running and its exit code.
+// The exit code is only valid if the process has exited.
+func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) {
+ // If execid is not empty, this is an exec'd process.
+ if p.execid != "" {
+ status, err := p.container.client.ContainerExecInspect(ctx, p.execid)
+ return status.Running, status.ExitCode, err
+ }
+ // Otherwise this is the root process.
+ status, err := p.container.Status(ctx)
+ return status.Running, status.ExitCode, err
+}
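+
+// Example (illustrative sketch, not part of the original change; assumes a
+// started Container c): start a long-running command via ExecProcess, then
+// block on its exit status through the returned Process handle.
+//
+//	p, err := c.ExecProcess(ctx, dockerutil.ExecOpts{}, "sleep", "5")
+//	if err != nil {
+//		t.Fatalf("exec failed: %v", err)
+//	}
+//	status, err := p.WaitExitStatus(ctx)
+//	if err != nil || status != 0 {
+//		t.Fatalf("process failed: status=%d err=%v", status, err)
+//	}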
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
new file mode 100644
index 000000000..047091e75
--- /dev/null
+++ b/pkg/test/dockerutil/network.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "net"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Network is a docker network.
+type Network struct {
+ client *client.Client
+ id string
+ logger testutil.Logger
+ Name string
+ containers []*Container
+ Subnet *net.IPNet
+}
+
+// NewNetwork sets up the struct for a Docker network. Network names are
+// randomized to guarantee uniqueness.
+func NewNetwork(ctx context.Context, logger testutil.Logger) *Network {
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ logger.Logf("create client failed with: %v", err)
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+
+ return &Network{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ client: client,
+ }
+}
+
+func (n *Network) networkCreate() types.NetworkCreate {
+ var subnet string
+ if n.Subnet != nil {
+ subnet = n.Subnet.String()
+ }
+
+ ipam := network.IPAM{
+ Config: []network.IPAMConfig{{
+ Subnet: subnet,
+ }},
+ }
+
+ return types.NetworkCreate{
+ CheckDuplicate: true,
+ IPAM: &ipam,
+ }
+}
+
+// Create is analogous to 'docker network create'.
+func (n *Network) Create(ctx context.Context) error {
+ opts := n.networkCreate()
+ resp, err := n.client.NetworkCreate(ctx, n.Name, opts)
+ if err != nil {
+ return err
+ }
+ n.id = resp.ID
+ return nil
+}
+
+// Connect is analogous to 'docker network connect' with the arguments provided.
+func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error {
+ settings := network.EndpointSettings{
+ IPAMConfig: &network.EndpointIPAMConfig{
+ IPv4Address: ipv4,
+ IPv6Address: ipv6,
+ },
+ }
+ err := n.client.NetworkConnect(ctx, n.id, container.id, &settings)
+ if err == nil {
+ n.containers = append(n.containers, container)
+ }
+ return err
+}
+
+// Inspect returns this network's info.
+func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
+ return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *Network) Cleanup(ctx context.Context) error {
+ for _, c := range n.containers {
+ c.CleanUp(ctx)
+ }
+ n.containers = nil
+
+ return n.client.NetworkRemove(ctx, n.id)
+}
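+
+// Example (illustrative sketch, not part of the original change; assumes a
+// test logger t and two started containers c1 and c2): create a network,
+// attach both containers with auto-assigned addresses, and tear everything
+// down afterwards.
+//
+//	n := dockerutil.NewNetwork(ctx, t)
+//	if err := n.Create(ctx); err != nil {
+//		t.Fatalf("network create failed: %v", err)
+//	}
+//	defer n.Cleanup(ctx)
+//	for _, c := range []*dockerutil.Container{c1, c2} {
+//		if err := n.Connect(ctx, c, "", ""); err != nil {
+//			t.Fatalf("network connect failed: %v", err)
+//		}
+//	}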
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
new file mode 100644
index 000000000..03b1b4677
--- /dev/null
+++ b/pkg/test/testutil/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = 1,
+ srcs = [
+ "testutil.go",
+ "testutil_runfiles.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/sync",
+ "//runsc/boot",
+ "//runsc/specutils",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ ],
+)
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
new file mode 100644
index 000000000..64c292698
--- /dev/null
+++ b/pkg/test/testutil/testutil.go
@@ -0,0 +1,536 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil contains utility functions for runsc tests.
+package testutil
+
+import (
+ "bufio"
+ "context"
+ "debug/elf"
+ "encoding/base32"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math"
+ "math/rand"
+ "net/http"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+)
+
+// IsCheckpointSupported returns the relevant command line flag.
+func IsCheckpointSupported() bool {
+ return *checkpoint
+}
+
+// ImageByName mangles the image name used locally. This depends on the image
+// build infrastructure in images/ and tools/vm.
+func ImageByName(name string) string {
+ return fmt.Sprintf("gvisor.dev/images/%s", name)
+}
+
+// ConfigureExePath configures the executable for runsc in the test environment.
+func ConfigureExePath() error {
+ path, err := FindFile("runsc/runsc")
+ if err != nil {
+ return err
+ }
+ specutils.ExePath = path
+ return nil
+}
+
+// TmpDir returns the absolute path to a writable directory that can be used as
+// scratch by the test.
+func TmpDir() string {
+ dir := os.Getenv("TEST_TMPDIR")
+ if dir == "" {
+ dir = "/tmp"
+ }
+ return dir
+}
+
+// Logger is a simple logging wrapper.
+//
+// This is designed to be implemented by *testing.T.
+type Logger interface {
+ Name() string
+ Logf(fmt string, args ...interface{})
+}
+
+// DefaultLogger logs using the log package.
+type DefaultLogger string
+
+// Name implements Logger.Name.
+func (d DefaultLogger) Name() string {
+ return string(d)
+}
+
+// Logf implements Logger.Logf.
+func (d DefaultLogger) Logf(fmt string, args ...interface{}) {
+ log.Printf(fmt, args...)
+}
+
+// Cmd is a simple wrapper around exec.Cmd that logs its output.
+type Cmd struct {
+ logger Logger
+ *exec.Cmd
+}
+
+// CombinedOutput returns the combined output, logging it and any error.
+func (c *Cmd) CombinedOutput() ([]byte, error) {
+ out, err := c.Cmd.CombinedOutput()
+ if len(out) > 0 {
+ c.logger.Logf("output: %s", string(out))
+ }
+ if err != nil {
+ c.logger.Logf("error: %v", err)
+ }
+ return out, err
+}
+
+// Command is a simple wrapper around exec.Command that logs.
+func Command(logger Logger, args ...string) *Cmd {
+ logger.Logf("command: %s", strings.Join(args, " "))
+ return &Cmd{
+ logger: logger,
+ Cmd: exec.Command(args[0], args[1:]...),
+ }
+}
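+
+// Example (illustrative sketch, not part of the original change): run a
+// command with its output and errors logged through the test.
+//
+//	cmd := testutil.Command(t, "docker", "version")
+//	if _, err := cmd.CombinedOutput(); err != nil {
+//		t.Fatalf("docker version failed: %v", err)
+//	}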
+
+// TestConfig returns the default configuration to use in tests. Note that
+// 'RootDir' must be set by the caller if required.
+func TestConfig(t *testing.T) *boot.Config {
+ logDir := os.TempDir()
+ if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
+ logDir = dir + "/"
+ }
+ return &boot.Config{
+ Debug: true,
+ DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"),
+ LogFormat: "text",
+ DebugLogFormat: "text",
+ LogPackets: true,
+ Network: boot.NetworkNone,
+ Strace: true,
+ Platform: "ptrace",
+ FileAccess: boot.FileAccessExclusive,
+ NumNetworkChannels: 1,
+
+ TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
+ }
+}
+
+// NewSpecWithArgs creates a simple spec with the given args suitable for use
+// in tests.
+func NewSpecWithArgs(args ...string) *specs.Spec {
+ return &specs.Spec{
+ // The host filesystem root is the container root.
+ Root: &specs.Root{
+ Path: "/",
+ Readonly: true,
+ },
+ Process: &specs.Process{
+ Args: args,
+ Env: []string{
+ "PATH=" + os.Getenv("PATH"),
+ },
+ Capabilities: specutils.AllCapabilities(),
+ },
+ Mounts: []specs.Mount{
+ // Hide the host /etc to avoid any side-effects.
+ // For example, bash reads /etc/passwd and if it is
+ // very big, tests can fail by timeout.
+ {
+ Type: "tmpfs",
+ Destination: "/etc",
+ },
+ // Root is readonly, but many tests want to write to tmpdir.
+ // This creates a writable mount inside the root. Also, when tmpdir points
+ // to "/tmp", it causes the actual host /tmp to be mounted rather than a
+ // tmpfs inside the sentry.
+ {
+ Type: "bind",
+ Destination: TmpDir(),
+ Source: TmpDir(),
+ },
+ },
+ Hostname: "runsc-test-hostname",
+ }
+}
+
+// SetupRootDir creates a root directory for containers.
+func SetupRootDir() (string, func(), error) {
+ rootDir, err := ioutil.TempDir(TmpDir(), "containers")
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating root dir: %v", err)
+ }
+ return rootDir, func() { os.RemoveAll(rootDir) }, nil
+}
+
+// SetupContainer creates a bundle and root dir for the container, generates a
+// test config, and writes the spec to config.json in the bundle dir.
+func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, cleanup func(), err error) {
+ rootDir, rootCleanup, err := SetupRootDir()
+ if err != nil {
+ return "", "", nil, err
+ }
+ conf.RootDir = rootDir
+ bundleDir, bundleCleanup, err := SetupBundleDir(spec)
+ if err != nil {
+ rootCleanup()
+ return "", "", nil, err
+ }
+ return rootDir, bundleDir, func() {
+ bundleCleanup()
+ rootCleanup()
+ }, nil
+}
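+
+// Example (illustrative sketch, not part of the original change): the typical
+// setup flow for a container test built on the helpers above; rootDir and
+// bundleDir are then handed to the container under test.
+//
+//	spec := testutil.NewSpecWithArgs("sleep", "1000")
+//	conf := testutil.TestConfig(t)
+//	rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+//	if err != nil {
+//		t.Fatalf("setup failed: %v", err)
+//	}
+//	defer cleanup()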
+
+// SetupBundleDir creates a bundle dir and writes the spec to config.json.
+func SetupBundleDir(spec *specs.Spec) (string, func(), error) {
+ bundleDir, err := ioutil.TempDir(TmpDir(), "bundle")
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating bundle dir: %v", err)
+ }
+ cleanup := func() { os.RemoveAll(bundleDir) }
+ if err := writeSpec(bundleDir, spec); err != nil {
+ cleanup()
+ return "", nil, fmt.Errorf("error writing spec: %v", err)
+ }
+ return bundleDir, cleanup, nil
+}
+
+// writeSpec writes the spec to disk in the given directory.
+func writeSpec(dir string, spec *specs.Spec) error {
+ b, err := json.Marshal(spec)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755)
+}
+
+// RandomID returns the given prefix followed by 20 random bytes,
+// base32-encoded.
+func RandomID(prefix string) string {
+ // Read 20 random bytes.
+ b := make([]byte, 20)
+ // "[Read] always returns len(p) and a nil error." --godoc
+ if _, err := rand.Read(b); err != nil {
+ panic("rand.Read failed: " + err.Error())
+ }
+ if prefix != "" {
+ prefix = prefix + "-"
+ }
+ return fmt.Sprintf("%s%s", prefix, base32.StdEncoding.EncodeToString(b))
+}
+
+// RandomContainerID generates a random container id for each test.
+//
+// The container id is used to create an abstract unix domain socket, which
+// must be unique. While the runtime forbids creating two containers with the
+// same name, sometimes between test runs the socket does not get cleaned up
+// quickly enough, causing container creation to fail.
+func RandomContainerID() string {
+ // RandomID appends a "-" to a non-empty prefix.
+ return RandomID("test-container")
+}
+
+// Copy copies file from src to dst.
+func Copy(src, dst string) error {
+ in, err := os.Open(src)
+ if err != nil {
+ return err
+ }
+ defer in.Close()
+
+ st, err := in.Stat()
+ if err != nil {
+ return err
+ }
+
+ out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, st.Mode().Perm())
+ if err != nil {
+ return err
+ }
+ defer out.Close()
+
+ // Mirror the local user's permissions across all users. This is
+ // because as we inject things into the container, the UID/GID will
+ // change. Also, the build system may generate artifacts with different
+ // modes. At the top-level (volume mapping) we have a big read-only
+ // knob that can be applied to prevent modifications.
+ //
+ // Note that this must be done via a separate Chmod call, otherwise the
+ // current process's umask will get in the way.
+ var mode os.FileMode
+ if st.Mode()&0100 != 0 {
+ mode |= 0111
+ }
+ if st.Mode()&0200 != 0 {
+ mode |= 0222
+ }
+ if st.Mode()&0400 != 0 {
+ mode |= 0444
+ }
+ if err := os.Chmod(dst, mode); err != nil {
+ return err
+ }
+
+ _, err = io.Copy(out, in)
+ return err
+}
+
+// Poll is a shorthand function to poll for something with the given timeout.
+func Poll(cb func() error, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
+ return backoff.Retry(cb, b)
+}
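+
+// Example (illustrative sketch, not part of the original change; the helper
+// checkServerUp is hypothetical): retry a check every 100ms until it passes
+// or ten seconds elapse.
+//
+//	if err := testutil.Poll(checkServerUp, 10*time.Second); err != nil {
+//		t.Fatalf("server never came up: %v", err)
+//	}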
+
+// WaitForHTTP tries GET requests on a port until the call succeeds or the
+// timeout expires.
+func WaitForHTTP(port int, timeout time.Duration) error {
+ cb := func() error {
+ c := &http.Client{
+ // Set the timeout to allow a minimum of 5 attempts.
+ Timeout: timeout / 5,
+ }
+ url := fmt.Sprintf("http://localhost:%d/", port)
+ resp, err := c.Get(url)
+ if err != nil {
+ log.Printf("Waiting %s: %v", url, err)
+ return err
+ }
+ resp.Body.Close()
+ return nil
+ }
+ return Poll(cb, timeout)
+}
+
+// Reaper reaps child processes.
+type Reaper struct {
+ // mu protects ch, which will be nil if the reaper is not running.
+ mu sync.Mutex
+ ch chan os.Signal
+}
+
+// Start starts reaping child processes.
+func (r *Reaper) Start() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.ch != nil {
+ panic("reaper.Start called on a running reaper")
+ }
+
+ r.ch = make(chan os.Signal, 1)
+ signal.Notify(r.ch, syscall.SIGCHLD)
+
+ go func() {
+ for {
+ r.mu.Lock()
+ ch := r.ch
+ r.mu.Unlock()
+ if ch == nil {
+ return
+ }
+
+ _, ok := <-ch
+ if !ok {
+ // Channel closed.
+ return
+ }
+ for {
+ cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil)
+ if cpid < 1 {
+ break
+ }
+ }
+ }
+ }()
+}
+
+// Stop stops reaping child processes.
+func (r *Reaper) Stop() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.ch == nil {
+ panic("reaper.Stop called on a stopped reaper")
+ }
+
+ signal.Stop(r.ch)
+ close(r.ch)
+ r.ch = nil
+}
+
+// StartReaper is a helper that starts a new Reaper and returns a function to
+// stop it.
+func StartReaper() func() {
+ r := &Reaper{}
+ r.Start()
+ return r.Stop
+}
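+
+// Example (illustrative sketch, not part of the original change): reap any
+// children spawned during a test for its entire duration.
+//
+//	defer testutil.StartReaper()()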
+
+// WaitUntilRead reads from the given reader until the wanted string is found
+// or until timeout.
+func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error {
+ sc := bufio.NewScanner(r)
+ if split != nil {
+ sc.Split(split)
+ }
+ // done must be accessed atomically. A value greater than 0 indicates
+ // that the read loop can exit.
+ var done uint32
+ doneCh := make(chan struct{})
+ go func() {
+ for sc.Scan() {
+ t := sc.Text()
+ if strings.Contains(t, want) {
+ atomic.StoreUint32(&done, 1)
+ close(doneCh)
+ break
+ }
+ if atomic.LoadUint32(&done) > 0 {
+ break
+ }
+ }
+ }()
+ select {
+ case <-time.After(timeout):
+ atomic.StoreUint32(&done, 1)
+ return fmt.Errorf("timeout waiting to read %q", want)
+ case <-doneCh:
+ return nil
+ }
+}
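+
+// Example (illustrative sketch, not part of the original change; assumes r is
+// an io.Reader over a container's log stream): wait for a readiness message.
+//
+//	if err := testutil.WaitUntilRead(r, "server is ready", nil, 10*time.Second); err != nil {
+//		t.Fatalf("server never became ready: %v", err)
+//	}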
+
+// KillCommand kills the process running cmd unless it hasn't been started. It
+// returns an error if it cannot kill the process unless the reason is that the
+// process has already exited.
+//
+// KillCommand will also reap the process.
+func KillCommand(cmd *exec.Cmd) error {
+ if cmd.Process == nil {
+ return nil
+ }
+ if err := cmd.Process.Kill(); err != nil {
+ if !strings.Contains(err.Error(), "process already finished") {
+ return fmt.Errorf("failed to kill process %v: %v", cmd, err)
+ }
+ }
+ return cmd.Wait()
+}
+
+// WriteTmpFile writes text to a temporary file, closes the file, and returns
+// the name of the file. A cleanup function is also returned.
+func WriteTmpFile(pattern, text string) (string, func(), error) {
+ file, err := ioutil.TempFile(TmpDir(), pattern)
+ if err != nil {
+ return "", nil, err
+ }
+ defer file.Close()
+ if _, err := file.Write([]byte(text)); err != nil {
+ return "", nil, err
+ }
+ return file.Name(), func() { os.RemoveAll(file.Name()) }, nil
+}
+
+// IsStatic returns true iff the given file is a static binary.
+func IsStatic(filename string) (bool, error) {
+ f, err := elf.Open(filename)
+ if err != nil {
+ return false, err
+ }
+ for _, prog := range f.Progs {
+ if prog.Type == elf.PT_INTERP {
+ return false, nil // Has interpreter.
+ }
+ }
+ return true, nil
+}
+
+// TouchShardStatusFile indicates to Bazel that the test runner supports
+// sharding by creating or updating the last modified date of the file
+// specified by TEST_SHARD_STATUS_FILE.
+//
+// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
+func TouchShardStatusFile() error {
+ if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" {
+ cmd := exec.Command("touch", statusFile)
+ if b, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error())
+ }
+ }
+ return nil
+}
+
+// TestIndicesForShard returns indices for this test shard based on the
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+//
+// If either of the env vars are not present, then the function will return all
+// tests. If there are more shards than there are tests, then the returned list
+// may be empty.
+func TestIndicesForShard(numTests int) ([]int, error) {
+ var (
+ shardIndex = 0
+ shardTotal = 1
+ )
+
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr != "" && totalStr != "" {
+ // Parse index and total to ints.
+ var err error
+ shardIndex, err = strconv.Atoi(indexStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err = strconv.Atoi(totalStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
+ }
+
+ // Calculate!
+ var indices []int
+ numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ for i := 0; i < numBlocks; i++ {
+ pick := i*shardTotal + shardIndex
+ if pick < numTests {
+ indices = append(indices, pick)
+ }
+ }
+ return indices, nil
+}
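+
+// Example (illustrative sketch, not part of the original change; allTests and
+// runTest are hypothetical): run only the tests assigned to this shard.
+//
+//	indices, err := testutil.TestIndicesForShard(len(allTests))
+//	if err != nil {
+//		t.Fatalf("sharding failed: %v", err)
+//	}
+//	for _, i := range indices {
+//		runTest(allTests[i])
+//	}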
diff --git a/pkg/test/testutil/testutil_runfiles.go b/pkg/test/testutil/testutil_runfiles.go
new file mode 100644
index 000000000..ece9ea9a1
--- /dev/null
+++ b/pkg/test/testutil/testutil_runfiles.go
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// FindFile searches for a file inside the test run environment. It returns
+// the full path to the file. It fails if no file, or more than one, is found.
+func FindFile(path string) (string, error) {
+ wd, err := os.Getwd()
+ if err != nil {
+ return "", err
+ }
+
+ // The test root is demarcated by a path element called "__main__". Search for
+ // it backwards from the working directory.
+ root := wd
+ for {
+ dir, name := filepath.Split(root)
+ if name == "__main__" {
+ break
+ }
+ if len(dir) == 0 {
+ return "", fmt.Errorf("directory __main__ not found in %q", wd)
+ }
+ // Remove ending slash to loop around.
+ root = dir[:len(dir)-1]
+ }
+
+ // Annoyingly, bazel adds the build type to the directory path for go
+ // binaries, but not for c++ binaries. We use two different patterns to
+ // find our file.
+ patterns := []string{
+ // Try the obvious path first.
+ filepath.Join(root, path),
+ // If it was a go binary, use a wildcard to match the build
+ // type. The pattern is: /test-path/__main__/directories/*/file.
+ filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
+ }
+
+ for _, p := range patterns {
+ matches, err := filepath.Glob(p)
+ if err != nil {
+ // "The only possible returned error is ErrBadPattern,
+ // when pattern is malformed." -godoc
+ return "", fmt.Errorf("error globbing %q: %v", p, err)
+ }
+ switch len(matches) {
+ case 0:
+ // Try the next pattern.
+ case 1:
+ // We found it.
+ return matches[0], nil
+ default:
+ return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
+ }
+ }
+ return "", fmt.Errorf("file %q not found", path)
+}
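+
+// Example (illustrative sketch, not part of the original change): locate the
+// runsc binary inside the bazel test environment.
+//
+//	path, err := testutil.FindFile("runsc/runsc")
+//	if err != nil {
+//		t.Fatalf("runsc not found: %v", err)
+//	}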
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
new file mode 100644
index 000000000..a86501fa2
--- /dev/null
+++ b/pkg/unet/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "unet",
+ srcs = [
+ "unet.go",
+ "unet_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/gate",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "unet_test",
+ size = "small",
+ srcs = [
+ "unet_test.go",
+ ],
+ library = ":unet",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go
new file mode 100644
index 000000000..d843f19cf
--- /dev/null
+++ b/pkg/unet/unet.go
@@ -0,0 +1,569 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unet provides a minimal net package based on Unix Domain Sockets.
+//
+// This does no pooling, and should only be used for a limited number of
+// connections in a Go process. Don't use this package for arbitrary servers.
+package unet
+
+import (
+ "errors"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/gate"
+)
+
+// backlog is used for the listen request.
+const backlog = 16
+
+// errClosing is returned by wait if the Socket is in the process of closing.
+var errClosing = errors.New("Socket is closing")
+
+// errMessageTruncated indicates that data was lost because the provided buffer
+// was too small.
+var errMessageTruncated = errors.New("message truncated")
+
+// socketType returns the appropriate socket type.
+func socketType(packet bool) int {
+ if packet {
+ return syscall.SOCK_SEQPACKET
+ }
+ return syscall.SOCK_STREAM
+}
+
+// socket creates a new host socket.
+func socket(packet bool) (int, error) {
+ // Make a new socket.
+ fd, err := syscall.Socket(syscall.AF_UNIX, socketType(packet), 0)
+ if err != nil {
+ return 0, err
+ }
+
+ return fd, nil
+}
+
+// eventFD returns a new event FD with initial value 0.
+func eventFD() (int, error) {
+ f, _, e := syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0)
+ if e != 0 {
+ return -1, e
+ }
+ return int(f), nil
+}
+
+// Socket is a connected unix domain socket.
+type Socket struct {
+ // gate protects use of fd.
+ gate gate.Gate
+
+ // fd is the bound socket.
+ //
+ // fd must be read atomically, and only remains valid if read while
+ // within gate.
+ fd int32
+
+ // efd is an event FD that is signaled when the socket is closing.
+ //
+ // efd is immutable and remains valid until Close/Release.
+ efd int
+
+ // race is an atomic variable used to avoid triggering the race
+ // detector. See comment in SocketPair below.
+ race *int32
+}
+
+// NewSocket returns a socket from an existing FD.
+//
+// NewSocket takes ownership of fd.
+func NewSocket(fd int) (*Socket, error) {
+ // fd must be non-blocking for non-blocking syscall.Accept in
+ // ServerSocket.Accept.
+ if err := syscall.SetNonblock(fd, true); err != nil {
+ return nil, err
+ }
+
+ efd, err := eventFD()
+ if err != nil {
+ return nil, err
+ }
+
+ return &Socket{
+ fd: int32(fd),
+ efd: efd,
+ }, nil
+}
+
+// finish completes use of s.fd by evicting any waiters, closing the gate, and
+// closing the event FD.
+func (s *Socket) finish() error {
+ // Signal any blocked or future polls.
+ //
+ // N.B. eventfd writes must be 8 bytes.
+ if _, err := syscall.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil {
+ return err
+ }
+
+ // Close the gate, blocking until all FD users leave.
+ s.gate.Close()
+
+ return syscall.Close(s.efd)
+}
+
+// Close closes the socket.
+func (s *Socket) Close() error {
+ // Set the FD in the socket to -1, to ensure that all future calls to
+ // FD/Release get nothing and Close calls return immediately.
+ fd := int(atomic.SwapInt32(&s.fd, -1))
+ if fd < 0 {
+ // Already closed or closing.
+ return syscall.EBADF
+ }
+
+ // Shutdown the socket to cancel any pending accepts.
+ s.shutdown(fd)
+
+ if err := s.finish(); err != nil {
+ return err
+ }
+
+ return syscall.Close(fd)
+}
+
+// Release releases ownership of the socket FD.
+//
+// The returned FD is non-blocking.
+//
+// Any concurrent or future callers of Socket methods will receive EBADF.
+func (s *Socket) Release() (int, error) {
+ // Set the FD in the socket to -1, to ensure that all future calls to
+ // FD/Release get nothing and Close calls return immediately.
+ fd := int(atomic.SwapInt32(&s.fd, -1))
+ if fd < 0 {
+ // Already closed or closing.
+ return -1, syscall.EBADF
+ }
+
+ if err := s.finish(); err != nil {
+ return -1, err
+ }
+
+ return fd, nil
+}
+
+// FD returns the FD for this Socket.
+//
+// The FD is non-blocking and must not be made blocking.
+//
+// N.B. os.File.Fd makes the FD blocking. Use of Release instead of FD is
+// strongly preferred.
+//
+// The returned FD cannot be used safely if there may be concurrent callers to
+// Close or Release.
+//
+// Use Release to take ownership of the FD.
+func (s *Socket) FD() int {
+ return int(atomic.LoadInt32(&s.fd))
+}
+
+// enterFD enters the FD gate and returns the FD value.
+//
+// If enterFD returns ok, s.gate.Leave must be called when done with the FD.
+// Callers may only block while within the gate using s.wait.
+//
+// The returned FD is guaranteed to remain valid until s.gate.Leave.
+func (s *Socket) enterFD() (int, bool) {
+ if !s.gate.Enter() {
+ return -1, false
+ }
+
+ fd := int(atomic.LoadInt32(&s.fd))
+ if fd < 0 {
+ s.gate.Leave()
+ return -1, false
+ }
+
+ return fd, true
+}
+
+// SocketPair creates a pair of connected sockets.
+func SocketPair(packet bool) (*Socket, *Socket, error) {
+ // Make a new pair.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, socketType(packet)|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // race is an atomic variable used to avoid triggering the race
+ // detector. We have to fool TSAN into thinking there is a race
+ // variable between our two sockets. We only use SocketPair in tests
+ // anyway.
+ //
+ // NOTE(b/27107811): This is purely due to the fact that the raw
+ // syscall does not serve as a boundary for the sanitizer.
+ var race int32
+ a, err := NewSocket(fds[0])
+ if err != nil {
+ syscall.Close(fds[0])
+ syscall.Close(fds[1])
+ return nil, nil, err
+ }
+ a.race = &race
+ b, err := NewSocket(fds[1])
+ if err != nil {
+ a.Close()
+ syscall.Close(fds[1])
+ return nil, nil, err
+ }
+ b.race = &race
+ return a, b, nil
+}
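+
+// Example (illustrative sketch, not part of the original change): a connected
+// stream pair, suitable for tests only per the note above.
+//
+//	a, b, err := unet.SocketPair(false)
+//	if err != nil {
+//		t.Fatalf("SocketPair failed: %v", err)
+//	}
+//	defer a.Close()
+//	defer b.Close()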
+
+// Connect connects to a server.
+func Connect(addr string, packet bool) (*Socket, error) {
+ fd, err := socket(packet)
+ if err != nil {
+ return nil, err
+ }
+
+ // Connect the socket.
+ usa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Connect(fd, usa); err != nil {
+ syscall.Close(fd)
+ return nil, err
+ }
+
+ return NewSocket(fd)
+}
+
+// ControlMessage wraps a byte slice and provides functions for parsing it as
+// a Unix Domain Socket control message.
+type ControlMessage []byte
+
+// EnableFDs enables receiving FDs via control message.
+//
+// This guarantees only a MINIMUM number of FDs received. You may receive MORE
+// than this due to the way FDs are packed. To be specific, the number of
+// receivable buffers will be rounded up to the nearest even number.
+//
+// This must be called prior to ReadVec if you want to receive FDs.
+func (c *ControlMessage) EnableFDs(count int) {
+ *c = make([]byte, syscall.CmsgSpace(count*4))
+}
+
+// ExtractFDs returns the list of FDs in the control message.
+//
+// Either this or CloseFDs should be used after EnableFDs.
+func (c *ControlMessage) ExtractFDs() ([]int, error) {
+ msgs, err := syscall.ParseSocketControlMessage(*c)
+ if err != nil {
+ return nil, err
+ }
+ var fds []int
+ for _, msg := range msgs {
+ thisFds, err := syscall.ParseUnixRights(&msg)
+ if err != nil {
+ // Different control message.
+ return nil, err
+ }
+ for _, fd := range thisFds {
+ if fd >= 0 {
+ fds = append(fds, fd)
+ }
+ }
+ }
+ return fds, nil
+}
+
+// CloseFDs closes the list of FDs in the control message.
+//
+// Either this or ExtractFDs should be used after EnableFDs.
+func (c *ControlMessage) CloseFDs() {
+ fds, _ := c.ExtractFDs()
+ for _, fd := range fds {
+ if fd >= 0 {
+ syscall.Close(fd)
+ }
+ }
+}
+
+// PackFDs packs the given list of FDs in the control message.
+//
+// This must be used prior to WriteVec.
+func (c *ControlMessage) PackFDs(fds ...int) {
+ *c = ControlMessage(syscall.UnixRights(fds...))
+}
+
+// UnpackFDs clears the control message.
+func (c *ControlMessage) UnpackFDs() {
+ *c = nil
+}
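+
+// Example (illustrative sketch, not part of the original change; assumes
+// connected Sockets a and b and an open file descriptor fd): pass an FD from
+// one end to the other.
+//
+//	w := a.Writer(true)
+//	w.PackFDs(fd)
+//	if _, err := w.WriteVec([][]byte{{0}}); err != nil {
+//		return err
+//	}
+//	r := b.Reader(true)
+//	r.EnableFDs(1)
+//	if _, err := r.ReadVec([][]byte{{0}}); err != nil {
+//		return err
+//	}
+//	fds, err := r.ExtractFDs()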
+
+// SocketWriter wraps an individual send operation.
+//
+// The normal entrypoint is WriteVec.
+type SocketWriter struct {
+ socket *Socket
+ to []byte
+ blocking bool
+ race *int32
+
+ ControlMessage
+}
+
+// Writer returns a writer for this socket.
+func (s *Socket) Writer(blocking bool) SocketWriter {
+ return SocketWriter{socket: s, blocking: blocking, race: s.race}
+}
+
+// Write implements io.Writer.Write.
+func (s *Socket) Write(p []byte) (int, error) {
+ r := s.Writer(true)
+ return r.WriteVec([][]byte{p})
+}
+
+// GetSockOpt gets the given socket option.
+func (s *Socket) GetSockOpt(level int, name int, b []byte) (uint32, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return 0, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return getsockopt(fd, level, name, b)
+}
+
+// SetSockOpt sets the given socket option.
+func (s *Socket) SetSockOpt(level, name int, b []byte) error {
+ fd, ok := s.enterFD()
+ if !ok {
+ return syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return setsockopt(fd, level, name, b)
+}
+
+// GetSockName returns the socket name.
+func (s *Socket) GetSockName() ([]byte, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ var buf []byte
+ l := syscall.SizeofSockaddrAny
+
+ for {
+ // If the buffer is not large enough, allocate a new one with the hint.
+ buf = make([]byte, l)
+ l, err := getsockname(fd, buf)
+ if err != nil {
+ return nil, err
+ }
+
+ if l <= uint32(len(buf)) {
+ return buf[:l], nil
+ }
+ }
+}
+
+// GetPeerName returns the peer name.
+func (s *Socket) GetPeerName() ([]byte, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ var buf []byte
+ l := syscall.SizeofSockaddrAny
+
+ for {
+ // See above.
+ buf = make([]byte, l)
+ l, err := getpeername(fd, buf)
+ if err != nil {
+ return nil, err
+ }
+
+ if l <= uint32(len(buf)) {
+ return buf[:l], nil
+ }
+ }
+}
+
+// GetPeerCred returns the peer's unix credentials.
+func (s *Socket) GetPeerCred() (*syscall.Ucred, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return syscall.GetsockoptUcred(fd, syscall.SOL_SOCKET, syscall.SO_PEERCRED)
+}
+
+// SocketReader wraps an individual receive operation.
+//
+// This may be used for doing vectorized reads and/or receiving additional
+// control messages (e.g. FDs). The normal entrypoint is ReadVec.
+//
+// One of ExtractFDs or DisposeFDs must be called if EnableFDs is used.
+type SocketReader struct {
+ socket *Socket
+ source []byte
+ blocking bool
+ race *int32
+
+ ControlMessage
+}
+
+// Reader returns a reader for this socket.
+func (s *Socket) Reader(blocking bool) SocketReader {
+ return SocketReader{socket: s, blocking: blocking, race: s.race}
+}
+
+// Read implements io.Reader.Read.
+func (s *Socket) Read(p []byte) (int, error) {
+ r := s.Reader(true)
+ return r.ReadVec([][]byte{p})
+}
+
+func (s *Socket) shutdown(fd int) error {
+ // Shutdown the socket to cancel any pending accepts.
+ return syscall.Shutdown(fd, syscall.SHUT_RDWR)
+}
+
+// Shutdown closes the socket for read and write.
+func (s *Socket) Shutdown() error {
+ fd, ok := s.enterFD()
+ if !ok {
+ return syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return s.shutdown(fd)
+}
+
+// ServerSocket is a bound unix domain socket.
+type ServerSocket struct {
+ socket *Socket
+}
+
+// NewServerSocket returns a socket from an existing FD.
+func NewServerSocket(fd int) (*ServerSocket, error) {
+ s, err := NewSocket(fd)
+ if err != nil {
+ return nil, err
+ }
+ return &ServerSocket{socket: s}, nil
+}
+
+// Bind creates and binds a new socket.
+func Bind(addr string, packet bool) (*ServerSocket, error) {
+ fd, err := socket(packet)
+ if err != nil {
+ return nil, err
+ }
+
+ // Do the bind.
+ usa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Bind(fd, usa); err != nil {
+ syscall.Close(fd)
+ return nil, err
+ }
+
+ return NewServerSocket(fd)
+}
+
+// BindAndListen creates, binds and listens on a new socket.
+func BindAndListen(addr string, packet bool) (*ServerSocket, error) {
+ s, err := Bind(addr, packet)
+ if err != nil {
+ return nil, err
+ }
+
+ // Start listening.
+ if err := s.Listen(); err != nil {
+ s.Close()
+ return nil, err
+ }
+
+ return s, nil
+}
+
+// Listen starts listening on the socket.
+func (s *ServerSocket) Listen() error {
+ fd, ok := s.socket.enterFD()
+ if !ok {
+ return syscall.EBADF
+ }
+ defer s.socket.gate.Leave()
+
+ return syscall.Listen(fd, backlog)
+}
+
+// Accept accepts a new connection.
+//
+// This is always blocking.
+//
+// Preconditions:
+// * ServerSocket is listening (Listen called).
+func (s *ServerSocket) Accept() (*Socket, error) {
+ fd, ok := s.socket.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.socket.gate.Leave()
+
+ for {
+ nfd, _, err := syscall.Accept(fd)
+ switch err {
+ case nil:
+ return NewSocket(nfd)
+ case syscall.EAGAIN:
+ err = s.socket.wait(false)
+ if err == errClosing {
+ err = syscall.EBADF
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ }
+}
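+
+// Example (illustrative sketch, not part of the original change; the handle
+// function is hypothetical): a minimal accept loop over a bound socket.
+//
+//	ss, err := unet.BindAndListen("/tmp/example.sock", false)
+//	if err != nil {
+//		return err
+//	}
+//	defer ss.Close()
+//	for {
+//		s, err := ss.Accept()
+//		if err != nil {
+//			return err
+//		}
+//		go handle(s)
+//	}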
+
+// Close closes the server socket.
+//
+// This must only be called once.
+func (s *ServerSocket) Close() error {
+ return s.socket.Close()
+}
+
+// FD returns the socket's file descriptor.
+//
+// See Socket.FD.
+func (s *ServerSocket) FD() int {
+ return s.socket.FD()
+}
+
+// Release releases ownership of the socket's file descriptor.
+//
+// See Socket.Release.
+func (s *ServerSocket) Release() (int, error) {
+ return s.socket.Release()
+}
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
new file mode 100644
index 000000000..5c4b9e8e9
--- /dev/null
+++ b/pkg/unet/unet_test.go
@@ -0,0 +1,736 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unet
+
+import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "reflect"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func randomFilename() (string, error) {
+ // Return a randomly generated file in the test dir.
+ f, err := ioutil.TempFile("", "unet-test")
+ if err != nil {
+ return "", err
+ }
+ file := f.Name()
+ os.Remove(file)
+ f.Close()
+
+ cwd, err := os.Getwd()
+ if err != nil {
+ return "", err
+ }
+
+ // NOTE(b/26918832): We try to use a relative path if possible. This
+ // helps conform to the unix socket path length limit.
+ if rel, err := filepath.Rel(cwd, file); err == nil {
+ return rel, nil
+ }
+
+ return file, nil
+}
+
+func TestConnectFailure(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ if _, err := Connect(name, false); err == nil {
+ t.Fatalf("connect was successful, expected err")
+ }
+}
+
+func TestBindFailure(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ ss, err := BindAndListen(name, false)
+ if err != nil {
+ t.Fatalf("first bind failed, got err %v expected nil", err)
+ }
+ defer ss.Close()
+
+ if _, err = BindAndListen(name, false); err == nil {
+ t.Fatalf("second bind succeeded, expected non-nil err")
+ }
+}
+
+func TestMultipleAccept(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ ss, err := BindAndListen(name, false)
+ if err != nil {
+ t.Fatalf("first bind failed, got err %v expected nil", err)
+ }
+ defer ss.Close()
+
+ // Connect backlog times asynchronously.
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ for i := 0; i < backlog; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s, err := Connect(name, false)
+ if err != nil {
+ // Fatalf must not be called from a non-test goroutine.
+ t.Errorf("connect failed, got err %v expected nil", err)
+ return
+ }
+ s.Close()
+ }()
+ }
+
+ // Accept backlog times.
+ for i := 0; i < backlog; i++ {
+ s, err := ss.Accept()
+ if err != nil {
+ t.Errorf("accept failed, got err %v expected nil", err)
+ continue
+ }
+ s.Close()
+ }
+}
+
+func TestServerClose(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ ss, err := BindAndListen(name, false)
+ if err != nil {
+ t.Fatalf("first bind failed, got err %v expected nil", err)
+ }
+
+ // Make sure the first close succeeds.
+ if err := ss.Close(); err != nil {
+ t.Fatalf("first close failed, got err %v expected nil", err)
+ }
+
+ // The second one should fail.
+ if err := ss.Close(); err == nil {
+ t.Fatalf("second close succeeded, expected non-nil err")
+ }
+}
+
+func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ // Bind a server.
+ ss, err := BindAndListen(name, packet)
+ if err != nil {
+ t.Fatalf("error binding, got %v expected nil", err)
+ }
+ defer ss.Close()
+
+ // Accept a client.
+ acceptSocket := make(chan *Socket)
+ acceptErr := make(chan error)
+ go func() {
+ server, err := ss.Accept()
+ if err != nil {
+ acceptErr <- err
+ }
+ acceptSocket <- server
+ }()
+
+ // Connect the client.
+ client, err := Connect(name, packet)
+ if err != nil {
+ t.Fatalf("error connecting, got %v expected nil", err)
+ }
+
+ // Grab the server handle.
+ select {
+ case server := <-acceptSocket:
+ return server, client
+ case err := <-acceptErr:
+ t.Fatalf("accept error: %v", err)
+ }
+ panic("unreachable")
+}
+
+func TestSendRecv(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Write on the client.
+ w := client.Writer(true)
+ if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
+ t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Read on the server.
+ b := [][]byte{{'b'}}
+ r := server.Reader(true)
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+ if b[0][0] != 'a' {
+ t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ }
+}
+
+// TestSymmetric exists to assert that the two sockets received from socketPair
+// are interchangeable. They should be; this just provides a basic sanity check
+// by running TestSendRecv "backwards".
+func TestSymmetric(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Write on the server.
+ w := server.Writer(true)
+ if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
+ t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Read on the client.
+ b := [][]byte{{'b'}}
+ r := client.Reader(true)
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+ if b[0][0] != 'a' {
+ t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ }
+}
+
+func TestPacket(t *testing.T) {
+ server, client := socketPair(t, true)
+ defer server.Close()
+ defer client.Close()
+
+ // Write on the client.
+ w := client.Writer(true)
+ if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
+ t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Write on the client again.
+ w = client.Writer(true)
+ if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
+ t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Read on the server.
+ //
+ // This should only get back a single byte, despite the buffer
+ // being size two. This is because it's a _packet_ buffer.
+ b := [][]byte{{'b', 'b'}}
+ r := server.Reader(true)
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+ if b[0][0] != 'a' {
+ t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ }
+
+ // Do it again.
+ r = server.Reader(true)
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+ if b[0][0] != 'a' {
+ t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ }
+}
+
+func TestClose(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+
+ // Make sure the first close succeeds.
+ if err := client.Close(); err != nil {
+ t.Fatalf("first close failed, got err %v expected nil", err)
+ }
+
+ // The second one should fail.
+ if err := client.Close(); err == nil {
+ t.Fatalf("second close succeeded, expected non-nil err")
+ }
+}
+
+func TestNonBlockingSend(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Try up to 1000 writes, of 1000 bytes.
+ blockCount := 0
+ for i := 0; i < 1000; i++ {
+ w := client.Writer(false)
+ if n, err := w.WriteVec([][]byte{make([]byte, 1000)}); n != 1000 || err != nil {
+ if err == syscall.EWOULDBLOCK || err == syscall.EAGAIN {
+ // We're good. That's what we wanted.
+ blockCount++
+ } else {
+ t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
+ }
+ }
+ }
+
+ if blockCount == 1000 {
+ // Shouldn't have _always_ blocked.
+ t.Fatalf("socket always blocked!")
+ } else if blockCount == 0 {
+ // Should have started blocking eventually.
+ t.Fatalf("socket never blocked!")
+ }
+}
+
+func TestNonBlockingRecv(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ b := [][]byte{{'b'}}
+ r := client.Reader(false)
+
+ // Expected to block immediately.
+ _, err := r.ReadVec(b)
+ if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
+ t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ }
+
+ // Put some data in the pipe.
+ w := server.Writer(false)
+ if n, err := w.WriteVec(b); n != 1 || err != nil {
+ t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Expect it not to block.
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Expect it to return a block error again.
+ r = client.Reader(false)
+ _, err = r.ReadVec(b)
+ if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
+ t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ }
+}
+
+func TestRecvVectors(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Write on the client.
+ w := client.Writer(true)
+ if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil {
+ t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ }
+
+ // Read on the server.
+ b := [][]byte{{'c'}, {'c'}}
+ r := server.Reader(true)
+ if n, err := r.ReadVec(b); n != 2 || err != nil {
+ t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ }
+ if b[0][0] != 'a' || b[1][0] != 'b' {
+ t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
+ }
+}
+
+func TestSendVectors(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Write on the client.
+ w := client.Writer(true)
+ if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil {
+ t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ }
+
+ // Read on the server.
+ b := [][]byte{{'c', 'c'}}
+ r := server.Reader(true)
+ if n, err := r.ReadVec(b); n != 2 || err != nil {
+ t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ }
+ if b[0][0] != 'a' || b[0][1] != 'b' {
+ t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
+ }
+}
+
+func TestSendFDsNotEnabled(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Write on the server.
+ w := server.Writer(true)
+ w.PackFDs(0, 1, 2)
+ if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
+ t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+
+ // Read on the client, without enabling FDs.
+ b := [][]byte{{'b'}}
+ r := client.Reader(true)
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+ if b[0][0] != 'a' {
+ t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ }
+
+ // Make sure the FDs are not received.
+ fds, err := r.ExtractFDs()
+ if len(fds) != 0 || err != nil {
+ t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
+ }
+}
+
+func sendFDs(t *testing.T, s *Socket, fds []int) {
+ w := s.Writer(true)
+ w.PackFDs(fds...)
+ if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
+ t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+}
+
+func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
+ expected := len(origFDs)
+
+ // Count the number of FDs.
+ preEntries, err := ioutil.ReadDir("/proc/self/fd")
+ if err != nil {
+ t.Fatalf("can't readdir, got err %v expected nil", err)
+ }
+
+ // Read on the client.
+ b := [][]byte{{'b'}}
+ r := s.Reader(true)
+ if enableSize >= 0 {
+ r.EnableFDs(enableSize)
+ }
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ }
+ if b[0][0] != 'a' {
+ t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ }
+
+ // Count the new number of FDs.
+ postEntries, err := ioutil.ReadDir("/proc/self/fd")
+ if err != nil {
+ t.Fatalf("can't readdir, got err %v expected nil", err)
+ }
+ if len(preEntries)+expected != len(postEntries) {
+ t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
+ }
+
+ // Make sure the FDs are there.
+ fds, err := r.ExtractFDs()
+ if len(fds) != expected || err != nil {
+ t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
+ }
+
+ // Make sure they are different from the originals.
+ for i := 0; i < len(fds); i++ {
+ if fds[i] == origFDs[i] {
+ t.Errorf("got original fd for index %d, expected different", i)
+ }
+ }
+
+ // Make sure they can be accessed as expected.
+ for i := 0; i < len(fds); i++ {
+ var st syscall.Stat_t
+ if err := syscall.Fstat(fds[i], &st); err != nil {
+ t.Errorf("fds[%d] can't be stated, got err %v expected nil", i, err)
+ }
+ }
+
+ // Close them off.
+ r.CloseFDs()
+
+ // Make sure the count is back to normal.
+ finalEntries, err := ioutil.ReadDir("/proc/self/fd")
+ if err != nil {
+ t.Fatalf("can't readdir, got err %v expected nil", err)
+ }
+ if len(finalEntries) != len(preEntries) {
+ t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
+ }
+}
+
+func TestFDsSingle(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ sendFDs(t, server, []int{0})
+ recvFDs(t, client, 1, []int{0})
+}
+
+func TestFDsMultiple(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ // Basic case, multiple FDs.
+ sendFDs(t, server, []int{0, 1, 2})
+ recvFDs(t, client, 3, []int{0, 1, 2})
+}
+
+// See TestSymmetric above.
+func TestFDsSymmetric(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ sendFDs(t, server, []int{0, 1, 2})
+ recvFDs(t, client, 3, []int{0, 1, 2})
+}
+
+func TestFDsReceiveLargeBuffer(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ sendFDs(t, server, []int{0})
+ recvFDs(t, client, 3, []int{0})
+}
+
+func TestFDsReceiveSmallBuffer(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ sendFDs(t, server, []int{0, 1, 2})
+
+ // Per the spec, we may still receive more than the buffer. In fact,
+ // it'll be rounded up and we can receive two with a size one buffer.
+ recvFDs(t, client, 1, []int{0, 1})
+}
+
+func TestFDsReceiveNotEnabled(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ sendFDs(t, server, []int{0})
+ recvFDs(t, client, -1, []int{})
+}
+
+func TestFDsReceiveSizeZero(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ sendFDs(t, server, []int{0})
+ recvFDs(t, client, 0, []int{})
+}
+
+func TestGetPeerCred(t *testing.T) {
+ server, client := socketPair(t, false)
+ defer server.Close()
+ defer client.Close()
+
+ want := &syscall.Ucred{
+ Pid: int32(os.Getpid()),
+ Uid: uint32(os.Getuid()),
+ Gid: uint32(os.Getgid()),
+ }
+
+ if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) {
+ t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
+ }
+}
+
+func newClosedSocket() (*Socket, error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewSocket(fd)
+ if err != nil {
+ syscall.Close(fd)
+ return nil, err
+ }
+
+ return s, s.Close()
+}
+
+func TestGetPeerCredFailure(t *testing.T) {
+ s, err := newClosedSocket()
+ if err != nil {
+ t.Fatalf("newClosedSocket got error %v want nil", err)
+ }
+
+ want := "bad file descriptor"
+ if _, err := s.GetPeerCred(); err == nil || err.Error() != want {
+ t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want)
+ }
+}
+
+func TestAcceptClosed(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ ss, err := BindAndListen(name, false)
+ if err != nil {
+ t.Fatalf("bind failed, got err %v expected nil", err)
+ }
+
+ if err := ss.Close(); err != nil {
+ t.Fatalf("close failed, got err %v expected nil", err)
+ }
+
+ if _, err := ss.Accept(); err == nil {
+ t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ }
+}
+
+func TestCloseAfterAcceptStart(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ ss, err := BindAndListen(name, false)
+ if err != nil {
+ t.Fatalf("bind failed, got err %v expected nil", err)
+ }
+
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ time.Sleep(50 * time.Millisecond)
+ if err := ss.Close(); err != nil {
+ t.Fatalf("close failed, got err %v expected nil", err)
+ }
+ wg.Done()
+ }()
+
+ if _, err := ss.Accept(); err == nil {
+ t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ }
+
+ wg.Wait()
+}
+
+func TestReleaseAfterAcceptStart(t *testing.T) {
+ name, err := randomFilename()
+ if err != nil {
+ t.Fatalf("unable to generate file, got err %v expected nil", err)
+ }
+
+ ss, err := BindAndListen(name, false)
+ if err != nil {
+ t.Fatalf("bind failed, got err %v expected nil", err)
+ }
+
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ time.Sleep(50 * time.Millisecond)
+	fd, err := ss.Release()
+	if err != nil {
+	// Errorf, not Fatalf: Fatalf must not be called from a
+	// goroutine other than the test's own.
+	t.Errorf("Release failed, got err %v expected nil", err)
+	} else {
+	syscall.Close(fd)
+	}
+ wg.Done()
+ }()
+
+ if _, err := ss.Accept(); err == nil {
+ t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ }
+
+ wg.Wait()
+}
+
+func TestControlMessage(t *testing.T) {
+ for i := 0; i <= 10; i++ {
+ var want []int
+ for j := 0; j < i; j++ {
+ want = append(want, i+j+1)
+ }
+
+ var cm ControlMessage
+ cm.EnableFDs(i)
+ cm.PackFDs(want...)
+ got, err := cm.ExtractFDs()
+ if err != nil || !reflect.DeepEqual(got, want) {
+ t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
+ }
+ }
+}
+
+func benchmarkSendRecv(b *testing.B, packet bool) {
+ server, client, err := SocketPair(packet)
+ if err != nil {
+ b.Fatalf("SocketPair: got %v, wanted nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+ go func() {
+ buf := make([]byte, 1)
+ for i := 0; i < b.N; i++ {
+ n, err := server.Read(buf)
+ if n != 1 || err != nil {
+ b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ }
+ n, err = server.Write(buf)
+ if n != 1 || err != nil {
+ b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ }
+ }
+ }()
+ buf := make([]byte, 1)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ n, err := client.Write(buf)
+ if n != 1 || err != nil {
+ b.Fatalf("client.Write: got (%d, %v), wanted (1, nil)", n, err)
+ }
+ n, err = client.Read(buf)
+ if n != 1 || err != nil {
+ b.Fatalf("client.Read: got (%d, %v), wanted (1, nil)", n, err)
+ }
+ }
+}
+
+func BenchmarkSendRecvStream(b *testing.B) {
+ benchmarkSendRecv(b, false)
+}
+
+func BenchmarkSendRecvPacket(b *testing.B) {
+ benchmarkSendRecv(b, true)
+}
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
new file mode 100644
index 000000000..85ef46edf
--- /dev/null
+++ b/pkg/unet/unet_unsafe.go
@@ -0,0 +1,288 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unet
+
+import (
+ "io"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// wait blocks until the socket FD is ready for reading or writing, depending
+// on the value of write.
+//
+// Returns errClosing if the Socket is in the process of closing.
+func (s *Socket) wait(write bool) error {
+ for {
+	// Checking the FD on each loop is not strictly necessary; it
+	// just avoids an extra poll call.
+ fd := atomic.LoadInt32(&s.fd)
+ if fd < 0 {
+ return errClosing
+ }
+
+ events := []unix.PollFd{
+ {
+ // The actual socket FD.
+ Fd: fd,
+ Events: unix.POLLIN,
+ },
+ {
+ // The eventfd, signaled when we are closing.
+ Fd: int32(s.efd),
+ Events: unix.POLLIN,
+ },
+ }
+ if write {
+ events[0].Events = unix.POLLOUT
+ }
+
+ _, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(&events[0])), 2, 0, 0, 0, 0)
+ if e == syscall.EINTR {
+ continue
+ }
+ if e != 0 {
+ return e
+ }
+
+ if events[1].Revents&unix.POLLIN == unix.POLLIN {
+ // eventfd signaled, we're closing.
+ return errClosing
+ }
+
+ return nil
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// iovecs is used as an initial slice, to avoid excessive allocations.
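+//
+// For example (sketch): bufs of [][]byte{{1, 2}, nil, {3}} yields two
+// iovecs (zero-length buffers are skipped) and a total length of 3.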
+func buildIovec(bufs [][]byte, iovecs []syscall.Iovec) ([]syscall.Iovec, int) {
+ var length int
+ for i := range bufs {
+ if l := len(bufs[i]); l > 0 {
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(l),
+ })
+ length += l
+ }
+ }
+ return iovecs, length
+}
+
+// ReadVec reads into the pre-allocated bufs. Returns bytes read.
+//
+// The pre-allocated space used by ReadVec is based upon slice lengths.
+//
+// This function is not guaranteed to read all available data; it
+// returns as soon as a single recvmsg call succeeds.
+func (r *SocketReader) ReadVec(bufs [][]byte) (int, error) {
+ iovecs, length := buildIovec(bufs, make([]syscall.Iovec, 0, 2))
+
+ var msg syscall.Msghdr
+ if len(r.source) != 0 {
+ msg.Name = &r.source[0]
+ msg.Namelen = uint32(len(r.source))
+ }
+
+ if len(r.ControlMessage) != 0 {
+ msg.Control = &r.ControlMessage[0]
+ msg.Controllen = uint64(len(r.ControlMessage))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ // n is the bytes received.
+ var n uintptr
+
+ fd, ok := r.socket.enterFD()
+ if !ok {
+ return 0, syscall.EBADF
+ }
+ // Leave on returns below.
+ for {
+ var e syscall.Errno
+
+ // Try a non-blocking recv first, so we don't give up the go runtime M.
+ n, _, e = syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_TRUNC)
+ if e == 0 {
+ break
+ }
+ if e == syscall.EINTR {
+ continue
+ }
+ if !r.blocking {
+ r.socket.gate.Leave()
+ return 0, e
+ }
+ if e != syscall.EAGAIN && e != syscall.EWOULDBLOCK {
+ r.socket.gate.Leave()
+ return 0, e
+ }
+
+ // Wait for the socket to become readable.
+ err := r.socket.wait(false)
+ if err == errClosing {
+ err = syscall.EBADF
+ }
+ if err != nil {
+ r.socket.gate.Leave()
+ return 0, err
+ }
+ }
+
+ r.socket.gate.Leave()
+
+ if msg.Controllen < uint64(len(r.ControlMessage)) {
+ r.ControlMessage = r.ControlMessage[:msg.Controllen]
+ }
+
+ if msg.Namelen < uint32(len(r.source)) {
+ r.source = r.source[:msg.Namelen]
+ }
+
+ // All unet sockets are SOCK_STREAM or SOCK_SEQPACKET, both of which
+ // indicate that the other end is closed by returning a 0 length read
+ // with no error.
+ if n == 0 {
+ return 0, io.EOF
+ }
+
+ if r.race != nil {
+ // See comments on Socket.race.
+ atomic.AddInt32(r.race, 1)
+ }
+
+ if int(n) > length {
+ return length, errMessageTruncated
+ }
+
+ return int(n), nil
+}
+
+// WriteVec writes the bufs to the socket. Returns bytes written.
+//
+// This function is not guaranteed to send all data; it returns
+// as soon as a single sendmsg call succeeds.
+func (w *SocketWriter) WriteVec(bufs [][]byte) (int, error) {
+ iovecs, _ := buildIovec(bufs, make([]syscall.Iovec, 0, 2))
+
+ if w.race != nil {
+ // See comments on Socket.race.
+ atomic.AddInt32(w.race, 1)
+ }
+
+ var msg syscall.Msghdr
+ if len(w.to) != 0 {
+ msg.Name = &w.to[0]
+ msg.Namelen = uint32(len(w.to))
+ }
+
+ if len(w.ControlMessage) != 0 {
+ msg.Control = &w.ControlMessage[0]
+ msg.Controllen = uint64(len(w.ControlMessage))
+ }
+
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ fd, ok := w.socket.enterFD()
+ if !ok {
+ return 0, syscall.EBADF
+ }
+ // Leave on returns below.
+ for {
+ // Try a non-blocking send first, so we don't give up the go runtime M.
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e == 0 {
+ w.socket.gate.Leave()
+ return int(n), nil
+ }
+ if e == syscall.EINTR {
+ continue
+ }
+ if !w.blocking {
+ w.socket.gate.Leave()
+ return 0, e
+ }
+ if e != syscall.EAGAIN && e != syscall.EWOULDBLOCK {
+ w.socket.gate.Leave()
+ return 0, e
+ }
+
+ // Wait for the socket to become writeable.
+ err := w.socket.wait(true)
+ if err == errClosing {
+ err = syscall.EBADF
+ }
+ if err != nil {
+ w.socket.gate.Leave()
+ return 0, err
+ }
+ }
+ // Unreachable, no s.gate.Leave needed.
+}
+
+// getsockopt issues a getsockopt syscall.
+func getsockopt(fd int, level int, optname int, buf []byte) (uint32, error) {
+ l := uint32(len(buf))
+ _, _, e := syscall.RawSyscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(optname), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)), 0)
+ if e != 0 {
+ return 0, e
+ }
+
+ return l, nil
+}
+
+// setsockopt issues a setsockopt syscall.
+func setsockopt(fd int, level int, optname int, buf []byte) error {
+ _, _, e := syscall.RawSyscall6(syscall.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(optname), uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)), 0)
+ if e != 0 {
+ return e
+ }
+
+ return nil
+}
+
+// getsockname issues a getsockname syscall.
+func getsockname(fd int, buf []byte) (uint32, error) {
+ l := uint32(len(buf))
+ _, _, e := syscall.RawSyscall(syscall.SYS_GETSOCKNAME, uintptr(fd), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)))
+ if e != 0 {
+ return 0, e
+ }
+
+ return l, nil
+}
+
+// getpeername issues a getpeername syscall.
+func getpeername(fd int, buf []byte) (uint32, error) {
+ l := uint32(len(buf))
+ _, _, e := syscall.RawSyscall(syscall.SYS_GETPEERNAME, uintptr(fd), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)))
+ if e != 0 {
+ return 0, e
+ }
+
+ return l, nil
+}
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
new file mode 100644
index 000000000..850c34ed0
--- /dev/null
+++ b/pkg/urpc/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "urpc",
+ srcs = ["urpc.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/unet",
+ ],
+)
+
+go_test(
+ name = "urpc_test",
+ size = "small",
+ srcs = ["urpc_test.go"],
+ library = ":urpc",
+ deps = ["//pkg/unet"],
+)
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
new file mode 100644
index 000000000..13b2ea314
--- /dev/null
+++ b/pkg/urpc/urpc.go
@@ -0,0 +1,636 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package urpc provides a minimal RPC package based on unet.
+//
+// RPC requests are _not_ concurrent and methods must be explicitly
+// registered. However, files may be sent as part of the payload.
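+//
+// A minimal usage sketch (the Echo service is hypothetical, for illustration
+// only):
+//
+//	type Echo struct{}
+//
+//	type EchoArgs struct{ Msg string }
+//	type EchoResult struct{ Msg string }
+//
+//	func (e *Echo) Echo(args *EchoArgs, result *EchoResult) error {
+//		result.Msg = args.Msg
+//		return nil
+//	}
+//
+//	// Server side:
+//	//   s := urpc.NewServer(); s.Register(&Echo{}); s.StartHandling(serverSock)
+//	// Client side:
+//	//   c := urpc.NewClient(clientSock)
+//	//   var res EchoResult
+//	//   err := c.Call("Echo.Echo", &EchoArgs{Msg: "hi"}, &res)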
+package urpc
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// maxFiles is the maximum number of files allowed in a single payload.
+const maxFiles = 32
+
+// ErrTooManyFiles is returned when too many file descriptors are mapped.
+var ErrTooManyFiles = errors.New("too many files")
+
+// ErrUnknownMethod is returned when a method is not known.
+var ErrUnknownMethod = errors.New("unknown method")
+
+// errStopped is an internal error indicating the server has been stopped.
+var errStopped = errors.New("stopped")
+
+// RemoteError is an error returned by the remote invocation.
+//
+// This indicates that the RPC transport was correct, but that the called
+// function itself returned an error.
+type RemoteError struct {
+ // Message is the result of calling Error() on the remote error.
+ Message string
+}
+
+// Error returns the remote error string.
+func (r RemoteError) Error() string {
+ return r.Message
+}
+
+// FilePayload may be _embedded_ in another type in order to send or receive a
+// file as a result of an RPC. These are not actually serialized; rather, they
+// are sent via an accompanying SCM_RIGHTS message (plumbed through the unet
+// package).
+//
+// When embedding a FilePayload in an argument struct, the argument type _must_
+// be a pointer to the struct rather than the struct type itself. This is
+// because the urpc package defines pointer methods on FilePayload.
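+//
+// For example (sketch, hypothetical argument type):
+//
+//	type mountArgs struct {
+//		Target string
+//		FilePayload // pass &mountArgs{...}, never mountArgs{...}
+//	}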
+type FilePayload struct {
+ Files []*os.File `json:"-"`
+}
+
+// ReleaseFD releases the FD at the specified index.
+func (f *FilePayload) ReleaseFD(index int) (*fd.FD, error) {
+ return fd.NewFromFile(f.Files[index])
+}
+
+// filePayload returns the files. The result may be nil.
+func (f *FilePayload) filePayload() []*os.File {
+ return f.Files
+}
+
+// setFilePayload sets the payload.
+func (f *FilePayload) setFilePayload(fs []*os.File) {
+ f.Files = fs
+}
+
+// closeAll closes a slice of files.
+func closeAll(files []*os.File) {
+ for _, f := range files {
+ f.Close()
+ }
+}
+
+// filePayloader is implemented only by FilePayload and will be implicitly
+// implemented by types that have the FilePayload embedded. Note that there is
+// no way to implement these methods other than by embedding FilePayload, due
+// to the way unexported method names are mangled.
+type filePayloader interface {
+ filePayload() []*os.File
+ setFilePayload([]*os.File)
+}
+
+// clientCall is the client=>server method call on the client side.
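+//
+// On the wire, a call is a single JSON object, e.g. (sketch, hypothetical
+// method):
+//
+//	{"method": "Echo.Echo", "arg": {"Msg": "hi"}}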
+type clientCall struct {
+ Method string `json:"method"`
+ Arg interface{} `json:"arg"`
+}
+
+// serverCall is the client=>server method call on the server side.
+type serverCall struct {
+ Method string `json:"method"`
+ Arg json.RawMessage `json:"arg"`
+}
+
+// callResult is the server=>client method call result.
+type callResult struct {
+ Success bool `json:"success"`
+ Err string `json:"err"`
+ Result interface{} `json:"result"`
+}
+
+// registeredMethod is a method registered with the server.
+type registeredMethod struct {
+ // fn is the underlying function.
+ fn reflect.Value
+
+ // rcvr is the receiver value.
+ rcvr reflect.Value
+
+ // argType is a typed argument.
+ argType reflect.Type
+
+	// resultType is the typed result.
+ resultType reflect.Type
+}
+
+// clientState is client metadata.
+//
+// The following are valid states:
+//
+// idle - not processing any requests, no close request.
+// processing - actively processing, no close request.
+// closeRequested - actively processing, pending close.
+// closed - client connection has been closed.
+//
+// The following transitions are possible:
+//
+// idle -> processing, closed
+// processing -> idle, closeRequested
+// closeRequested -> closed
+//
+type clientState int
+
+// See clientState.
+const (
+ idle clientState = iota
+ processing
+ closeRequested
+ closed
+)
+
+// Server is an RPC server.
+type Server struct {
+ // mu protects all fields, except wg.
+ mu sync.Mutex
+
+ // methods is the set of server methods.
+ methods map[string]registeredMethod
+
+ // clients is a map of clients.
+ clients map[*unet.Socket]clientState
+
+ // wg is a wait group for all outstanding clients.
+ wg sync.WaitGroup
+
+ // afterRPCCallback is called after each RPC is successfully completed.
+ afterRPCCallback func()
+}
+
+// NewServer returns a new server.
+func NewServer() *Server {
+ return NewServerWithCallback(nil)
+}
+
+// NewServerWithCallback returns a new server, who upon completion of each RPC
+// calls the given function.
+func NewServerWithCallback(afterRPCCallback func()) *Server {
+ return &Server{
+ methods: make(map[string]registeredMethod),
+ clients: make(map[*unet.Socket]clientState),
+ afterRPCCallback: afterRPCCallback,
+ }
+}
+
+// Register registers the given object as an RPC receiver.
+//
+// This function works the same way as the built-in RPC package, but it does
+// not tolerate any object with non-conforming methods. Any non-conforming
+// method will lead to an immediate panic instead of being skipped or
+// returning an error. Panics will also be generated by anonymous objects and
+// duplicate entries.
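+//
+// A conforming method has the following shape (sketch, hypothetical names):
+//
+//	func (t *T) Method(arg *ArgType, result *ResultType) error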
+func (s *Server) Register(obj interface{}) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ typ := reflect.TypeOf(obj)
+
+ // If we got a pointer, deref it to the underlying object. We need this to
+ // obtain the name of the underlying type.
+ typDeref := typ
+ if typ.Kind() == reflect.Ptr {
+ typDeref = typ.Elem()
+ }
+
+ for m := 0; m < typ.NumMethod(); m++ {
+ method := typ.Method(m)
+
+ if typDeref.Name() == "" {
+ // Can't be anonymous.
+ panic("type not named.")
+ }
+
+ prettyName := typDeref.Name() + "." + method.Name
+ if _, ok := s.methods[prettyName]; ok {
+ // Duplicate entry.
+ panic(fmt.Sprintf("method %s is duplicated.", prettyName))
+ }
+
+ if method.PkgPath != "" {
+ // Must be exported.
+ panic(fmt.Sprintf("method %s is not exported.", prettyName))
+ }
+ mtype := method.Type
+ if mtype.NumIn() != 3 {
+ // Need exactly two arguments (+ receiver).
+ panic(fmt.Sprintf("method %s has wrong number of arguments.", prettyName))
+ }
+ argType := mtype.In(1)
+ if argType.Kind() != reflect.Ptr {
+ // Need arg pointer.
+ panic(fmt.Sprintf("method %s has non-pointer first argument.", prettyName))
+ }
+ resultType := mtype.In(2)
+ if resultType.Kind() != reflect.Ptr {
+ // Need result pointer.
+ panic(fmt.Sprintf("method %s has non-pointer second argument.", prettyName))
+ }
+ if mtype.NumOut() != 1 {
+ // Need single return.
+ panic(fmt.Sprintf("method %s has wrong number of returns.", prettyName))
+ }
+ if returnType := mtype.Out(0); returnType != reflect.TypeOf((*error)(nil)).Elem() {
+ // Need error return.
+ panic(fmt.Sprintf("method %s has non-error return value.", prettyName))
+ }
+
+ // Register the method.
+ s.methods[prettyName] = registeredMethod{
+ fn: method.Func,
+ rcvr: reflect.ValueOf(obj),
+ argType: argType,
+ resultType: resultType,
+ }
+ }
+}
+
+// lookup looks up the given method.
+func (s *Server) lookup(method string) (registeredMethod, bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ rm, ok := s.methods[method]
+ return rm, ok
+}
+
+// handleOne handles a single call.
+func (s *Server) handleOne(client *unet.Socket) error {
+ // Unmarshal the call.
+ var c serverCall
+ newFs, err := unmarshal(client, &c)
+ if err != nil {
+ // Client is dead.
+ return err
+ }
+
+ defer func() {
+ if s.afterRPCCallback != nil {
+ s.afterRPCCallback()
+ }
+ }()
+ // Explicitly close all these files after the call.
+ //
+	// This also explicitly holds a reference to the files until after
+	// the call, which keeps them open for the call's duration.
+ defer closeAll(newFs)
+
+ // Start the request.
+ if !s.clientBeginRequest(client) {
+ // Client is dead; don't process this call.
+ return errStopped
+ }
+ defer s.clientEndRequest(client)
+
+ // Lookup the method.
+ rm, ok := s.lookup(c.Method)
+ if !ok {
+ // Try to serialize the error.
+ return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil)
+ }
+
+ // Unmarshal the arguments now that we know the type.
+ na := reflect.New(rm.argType.Elem())
+ if err := json.Unmarshal(c.Arg, na.Interface()); err != nil {
+ return marshal(client, &callResult{Err: err.Error()}, nil)
+ }
+
+ // Set the file payload as an argument.
+ if fp, ok := na.Interface().(filePayloader); ok {
+ fp.setFilePayload(newFs)
+ }
+
+ // Call the method.
+ re := reflect.New(rm.resultType.Elem())
+ rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re})
+ if errVal := rValues[0].Interface(); errVal != nil {
+ return marshal(client, &callResult{Err: errVal.(error).Error()}, nil)
+ }
+
+ // Set the resulting payload.
+ var fs []*os.File
+ if fp, ok := re.Interface().(filePayloader); ok {
+ fs = fp.filePayload()
+ if len(fs) > maxFiles {
+ // Ugh. Send an error to the client, despite success.
+ return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil)
+ }
+ }
+
+ // Marshal the result.
+ return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs)
+}
+
+// clientBeginRequest begins a request.
+//
+// If true is returned, the request may be processed. If false is returned,
+// then the server has been stopped and the request should be skipped.
+func (s *Server) clientBeginRequest(client *unet.Socket) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch state := s.clients[client]; state {
+ case idle:
+ // Mark as processing.
+ s.clients[client] = processing
+ return true
+ case closed:
+ // Whoops, how did this happen? Must have closed immediately
+ // following the deserialization. Don't let the RPC actually go
+ // through, since we won't be able to serialize a proper
+ // response.
+ return false
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("expected idle or closed, got %d", state))
+ }
+}
+
+// clientEndRequest ends a request.
+func (s *Server) clientEndRequest(client *unet.Socket) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch state := s.clients[client]; state {
+ case processing:
+ // Return to idle.
+ s.clients[client] = idle
+ case closeRequested:
+ // Close the connection.
+ client.Close()
+ s.clients[client] = closed
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("expected processing or requestClose, got %d", state))
+ }
+}
+
+// clientRegister registers a connection.
+//
+// See Stop for more context.
+func (s *Server) clientRegister(client *unet.Socket) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.clients[client] = idle
+ s.wg.Add(1)
+}
+
+// clientUnregister unregisters and closes a connection if necessary.
+//
+// See Stop for more context.
+func (s *Server) clientUnregister(client *unet.Socket) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch state := s.clients[client]; state {
+ case idle:
+ // Close the connection.
+ client.Close()
+ case closed:
+ // Already done.
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("expected idle or closed, got %d", state))
+ }
+ delete(s.clients, client)
+ s.wg.Done()
+}
+
+// handleRegistered handles calls from a registered client.
+func (s *Server) handleRegistered(client *unet.Socket) error {
+ for {
+ // Handle one call.
+ if err := s.handleOne(client); err != nil {
+ // Client is dead.
+ return err
+ }
+ }
+}
+
+// Handle synchronously handles a single client over a connection.
+func (s *Server) Handle(client *unet.Socket) error {
+ s.clientRegister(client)
+ defer s.clientUnregister(client)
+ return s.handleRegistered(client)
+}
+
+// StartHandling creates a goroutine that handles a single client over a
+// connection.
+func (s *Server) StartHandling(client *unet.Socket) {
+ s.clientRegister(client)
+ go func() { // S/R-SAFE: out of scope
+ defer s.clientUnregister(client)
+ s.handleRegistered(client)
+ }()
+}
+
+// Stop safely terminates outstanding clients.
+//
+// No new requests should be initiated after calling Stop. Existing clients
+// will be closed after completing any pending RPCs. This method will block
+// until all clients have disconnected.
+func (s *Server) Stop() {
+ // Wait for all outstanding requests.
+ defer s.wg.Wait()
+
+ // Close all known clients.
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for client, state := range s.clients {
+ switch state {
+ case idle:
+ // Close connection now.
+ client.Close()
+ s.clients[client] = closed
+ case processing:
+ // Request close when done.
+ s.clients[client] = closeRequested
+ }
+ }
+}
+
+// Client is a urpc client.
+type Client struct {
+ // mu protects all members.
+ //
+ // It also enforces single-call semantics.
+ mu sync.Mutex
+
+ // Socket is the underlying socket for this client.
+ //
+ // This _must_ be provided and must be closed manually by calling
+ // Close.
+ Socket *unet.Socket
+}
+
+// NewClient returns a new client.
+func NewClient(socket *unet.Socket) *Client {
+ return &Client{
+ Socket: socket,
+ }
+}
+
+// marshal sends the given FD and json struct.
+func marshal(s *unet.Socket, v interface{}, fs []*os.File) error {
+ // Marshal to a buffer.
+ data, err := json.Marshal(v)
+ if err != nil {
+ log.Warningf("urpc: error marshalling %s: %s", fmt.Sprintf("%v", v), err.Error())
+ return err
+ }
+
+ // Write to the socket.
+ w := s.Writer(true)
+ if fs != nil {
+ var fds []int
+ for _, f := range fs {
+ fds = append(fds, int(f.Fd()))
+ }
+ w.PackFDs(fds...)
+ }
+
+ // Send.
+ for n := 0; n < len(data); {
+ cur, err := w.WriteVec([][]byte{data[n:]})
+ if n == 0 && cur < len(data) {
+	// Don't send the FDs again. This branch is only taken
+	// after the first successful WriteVec, when cur was not
+	// sufficient to send the entire buffer.
+ w.PackFDs()
+ }
+ n += cur
+ if err != nil {
+ log.Warningf("urpc: error writing %v: %s", data[n:], err.Error())
+ return err
+ }
+ }
+
+ // We're done sending the fds to the client. Explicitly prevent fs from
+	// being GCed until here. urpc RPCs often unlink the file to send, relying
+ // on the kernel to automatically delete it once the last reference is
+ // dropped. Until we successfully call sendmsg(2), fs may contain the last
+ // references to these files. Without this explicit reference to fs here,
+ // the go runtime is free to assume we're done with fs after the fd
+ // collection loop above, since it just sees us copying ints.
+ runtime.KeepAlive(fs)
+
+ log.Debugf("urpc: successfully marshalled %d bytes.", len(data))
+ return nil
+}
+
+// unmarshal receives FDs (optional) and unmarshals the given struct.
+func unmarshal(s *unet.Socket, v interface{}) ([]*os.File, error) {
+ // Receive a single byte.
+ r := s.Reader(true)
+ r.EnableFDs(maxFiles)
+ firstByte := make([]byte, 1)
+
+ // Extract any FDs that may be there.
+ if _, err := r.ReadVec([][]byte{firstByte}); err != nil {
+ return nil, err
+ }
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ log.Warningf("urpc: error extracting fds: %s", err.Error())
+ return nil, err
+ }
+ var fs []*os.File
+ for _, fd := range fds {
+ fs = append(fs, os.NewFile(uintptr(fd), "urpc"))
+ }
+
+ // Read the rest.
+ d := json.NewDecoder(io.MultiReader(bytes.NewBuffer(firstByte), s))
+ // urpc internally decodes / re-encodes the data with interface{} as the
+ // intermediate type. We have to unmarshal integers to json.Number type
+ // instead of the default float type for those intermediate values, such
+ // that when they get re-encoded, their values are not printed out in
+	// floating-point formats such as 1e9, which could not be decoded into
+	// explicitly typed integers later.
+ d.UseNumber()
+ if err := d.Decode(v); err != nil {
+ log.Warningf("urpc: error decoding: %s", err.Error())
+ for _, f := range fs {
+ f.Close()
+ }
+ return nil, err
+ }
+
+ // All set.
+ log.Debugf("urpc: unmarshal success.")
+ return fs, nil
+}
+
+// Call calls a function.
+func (c *Client) Call(method string, arg interface{}, result interface{}) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // If arg is a FilePayload, not a *FilePayload, files won't actually be
+ // sent, so error out.
+ if _, ok := arg.(FilePayload); ok {
+ return fmt.Errorf("argument is a FilePayload, but should be a *FilePayload")
+ }
+
+ // Are there files to send?
+ var fs []*os.File
+ if fp, ok := arg.(filePayloader); ok {
+ fs = fp.filePayload()
+ if len(fs) > maxFiles {
+ return ErrTooManyFiles
+ }
+ }
+
+ // Marshal the data.
+ if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil {
+ return err
+ }
+
+ // Wait for the response.
+ callR := callResult{Result: result}
+ newFs, err := unmarshal(c.Socket, &callR)
+ if err != nil {
+ return fmt.Errorf("urpc method %q failed: %v", method, err)
+ }
+
+ // Set the file payload.
+ if fp, ok := result.(filePayloader); ok {
+ fp.setFilePayload(newFs)
+ } else {
+ closeAll(newFs)
+ }
+
+ // Did an error occur?
+ if !callR.Success {
+ return RemoteError{Message: callR.Err}
+ }
+
+ // All set.
+ return nil
+}
+
+// Close closes the underlying socket.
+//
+// Further calls to the client may result in undefined behavior.
+func (c *Client) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.Socket.Close()
+}
diff --git a/pkg/urpc/urpc_test.go b/pkg/urpc/urpc_test.go
new file mode 100644
index 000000000..c6c7ce9d4
--- /dev/null
+++ b/pkg/urpc/urpc_test.go
@@ -0,0 +1,210 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package urpc
+
+import (
+ "errors"
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+type test struct {
+}
+
+type testArg struct {
+ StringArg string
+ IntArg int
+ FilePayload
+}
+
+type testResult struct {
+ StringResult string
+ IntResult int
+ FilePayload
+}
+
+func (t test) Func(a *testArg, r *testResult) error {
+ r.StringResult = a.StringArg
+ r.IntResult = a.IntArg
+ return nil
+}
+
+func (t test) Err(a *testArg, r *testResult) error {
+ return errors.New("test error")
+}
+
+func (t test) FailNoFile(a *testArg, r *testResult) error {
+ if a.Files == nil {
+ return errors.New("no file found")
+ }
+
+ return nil
+}
+
+func (t test) SendFile(a *testArg, r *testResult) error {
+ r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr}
+ return nil
+}
+
+func (t test) TooManyFiles(a *testArg, r *testResult) error {
+ for i := 0; i <= maxFiles; i++ {
+ r.Files = append(r.Files, os.Stdin)
+ }
+ return nil
+}
+
+func startServer(socket *unet.Socket) {
+ s := NewServer()
+ s.Register(test{})
+ s.StartHandling(socket)
+}
+
+func testClient() (*Client, error) {
+ serverSock, clientSock, err := unet.SocketPair(false)
+ if err != nil {
+ return nil, err
+ }
+ startServer(serverSock)
+
+ return NewClient(clientSock), nil
+}
+
+func TestCall(t *testing.T) {
+ c, err := testClient()
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ defer c.Close()
+
+ var r testResult
+ if err := c.Call("test.Func", &testArg{}, &r); err != nil {
+ t.Errorf("basic call failed: %v", err)
+ } else if r.StringResult != "" || r.IntResult != 0 {
+ t.Errorf("unexpected result, got %v expected zero value", r)
+ }
+ if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil {
+ t.Errorf("basic call failed: %v", err)
+ } else if r.StringResult != "hello" {
+ t.Errorf("unexpected result, got %v expected hello", r.StringResult)
+ }
+ if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil {
+ t.Errorf("basic call failed: %v", err)
+ } else if r.IntResult != 1 {
+ t.Errorf("unexpected result, got %v expected 1", r.IntResult)
+ }
+}
+
+func TestUnknownMethod(t *testing.T) {
+ c, err := testClient()
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ defer c.Close()
+
+ var r testResult
+ if err := c.Call("test.Unknown", &testArg{}, &r); err == nil {
+ t.Errorf("expected non-nil err, got nil")
+ } else if err.Error() != ErrUnknownMethod.Error() {
+ t.Errorf("expected test error, got %v", err)
+ }
+}
+
+func TestErr(t *testing.T) {
+ c, err := testClient()
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ defer c.Close()
+
+ var r testResult
+ if err := c.Call("test.Err", &testArg{}, &r); err == nil {
+ t.Errorf("expected non-nil err, got nil")
+ } else if err.Error() != "test error" {
+ t.Errorf("expected test error, got %v", err)
+ }
+}
+
+func TestSendFile(t *testing.T) {
+ c, err := testClient()
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ defer c.Close()
+
+ var r testResult
+ if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil {
+ t.Errorf("expected non-nil err, got nil")
+ }
+ if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil {
+ t.Errorf("expected nil err, got %v", err)
+ }
+}
+
+func TestRecvFile(t *testing.T) {
+ c, err := testClient()
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ defer c.Close()
+
+ var r testResult
+ if err := c.Call("test.SendFile", &testArg{}, &r); err != nil {
+ t.Errorf("expected nil err, got %v", err)
+ }
+ if r.Files == nil {
+ t.Errorf("expected file, got nil")
+ }
+}
+
+func TestShutdown(t *testing.T) {
+ serverSock, clientSock, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ clientSock.Close()
+
+ s := NewServer()
+ if err := s.Handle(serverSock); err == nil {
+ t.Errorf("expected non-nil err, got nil")
+ }
+}
+
+func TestTooManyFiles(t *testing.T) {
+ c, err := testClient()
+ if err != nil {
+ t.Fatalf("error creating test client: %v", err)
+ }
+ defer c.Close()
+
+ var r testResult
+ var a testArg
+ for i := 0; i <= maxFiles; i++ {
+ a.Files = append(a.Files, os.Stdin)
+ }
+
+ // Client-side error.
+ if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles {
+ t.Errorf("expected ErrTooManyFiles, got %v", err)
+ }
+
+ // Server-side error.
+ if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil {
+ t.Errorf("expected non-nil err, got nil")
+ } else if err.Error() != "too many files" {
+ t.Errorf("expected too many files, got %v", err.Error())
+ }
+}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
new file mode 100644
index 000000000..6c9ada9c7
--- /dev/null
+++ b/pkg/usermem/BUILD
@@ -0,0 +1,55 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "addr_range",
+ out = "addr_range.go",
+ package = "usermem",
+ prefix = "Addr",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "Addr",
+ },
+)
+
+go_library(
+ name = "usermem",
+ srcs = [
+ "access_type.go",
+ "addr.go",
+ "addr_range.go",
+ "addr_range_seq_unsafe.go",
+ "bytes_io.go",
+ "bytes_io_unsafe.go",
+ "usermem.go",
+ "usermem_arm64.go",
+ "usermem_x86.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/gohacks",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "usermem_test",
+ size = "small",
+ srcs = [
+ "addr_range_seq_test.go",
+ "usermem_test.go",
+ ],
+ library = ":usermem",
+ deps = [
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/usermem/README.md b/pkg/usermem/README.md
new file mode 100644
index 000000000..f6d2137eb
--- /dev/null
+++ b/pkg/usermem/README.md
@@ -0,0 +1,31 @@
+This package defines primitives for sentry access to application memory.
+
+Major types:
+
+- The `IO` interface represents a virtual address space and provides I/O
+ methods on that address space. `IO` is the lowest-level primitive. The
+  primary implementation of the `IO` interface is `mm.MemoryManager`; a
+  minimal sketch using this package's `BytesIO` appears at the end of this
+  file.
+
+- `IOSequence` represents a collection of individually-contiguous address
+  ranges in an `IO` that is operated on sequentially, analogous to Linux's
+ `struct iov_iter`.
+
+Major usage patterns:
+
+- Access to a task's virtual memory, subject to the application's memory
+ protections and while running on that task's goroutine, from a context that
+ is at or above the level of the `kernel` package (e.g. most syscall
+ implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers
+ defined in `kernel/task_usermem.go`.
+
+- Access to a task's virtual memory, from a context that is at or above the
+ level of the `kernel` package, but where any of the above constraints does
+ not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory
+ protections); obtain the task's `mm.MemoryManager` by calling
+ `kernel.Task.MemoryManager`, and call its `IO` methods directly.
+
+- Access to a task's virtual memory, from a context that is below the level of
+ the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments
+ from higher layers, usually in the form of an `IOSequence`. The
+ `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions
+ in `kernel/task_usermem.go` are convenience functions for doing so.
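+
+A minimal sketch of driving the `IO` interface directly, using this package's
+in-memory `BytesIO` implementation (the `ctx` value is assumed to come from
+`pkg/context`; for illustration only):
+
+    var dst [2]byte
+    io := &usermem.BytesIO{Bytes: []byte("hello")}
+    // Addresses are offsets into the backing slice.
+    n, err := io.CopyIn(ctx, usermem.Addr(1), dst[:], usermem.IOOpts{})
+    // On success, n == 2 and dst contains "el".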
diff --git a/pkg/usermem/access_type.go b/pkg/usermem/access_type.go
new file mode 100644
index 000000000..9c1742a59
--- /dev/null
+++ b/pkg/usermem/access_type.go
@@ -0,0 +1,128 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "syscall"
+)
+
+// AccessType specifies memory access types. This is used for
+// setting mapping permissions, as well as communicating faults.
+//
+// +stateify savable
+type AccessType struct {
+ // Read is read access.
+ Read bool
+
+ // Write is write access.
+ Write bool
+
+ // Execute is executable access.
+ Execute bool
+}
+
+// String returns a pretty representation of access. This looks like the
+// familiar r-x, rw-, etc. and can be relied on as such.
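+//
+// For example, ReadWrite.String() == "rw-".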
+func (a AccessType) String() string {
+ bits := [3]byte{'-', '-', '-'}
+ if a.Read {
+ bits[0] = 'r'
+ }
+ if a.Write {
+ bits[1] = 'w'
+ }
+ if a.Execute {
+ bits[2] = 'x'
+ }
+ return string(bits[:])
+}
+
+// Any returns true iff at least one of Read, Write or Execute is true.
+func (a AccessType) Any() bool {
+ return a.Read || a.Write || a.Execute
+}
+
+// Prot returns the system prot (syscall.PROT_READ, etc.) for this access.
+func (a AccessType) Prot() int {
+ var prot int
+ if a.Read {
+ prot |= syscall.PROT_READ
+ }
+ if a.Write {
+ prot |= syscall.PROT_WRITE
+ }
+ if a.Execute {
+ prot |= syscall.PROT_EXEC
+ }
+ return prot
+}
+
+// SupersetOf returns true iff the access types in a are a superset of the
+// access types in other.
+func (a AccessType) SupersetOf(other AccessType) bool {
+ if !a.Read && other.Read {
+ return false
+ }
+ if !a.Write && other.Write {
+ return false
+ }
+ if !a.Execute && other.Execute {
+ return false
+ }
+ return true
+}
+
+// Intersect returns the access types set in both a and other.
+func (a AccessType) Intersect(other AccessType) AccessType {
+ return AccessType{
+ Read: a.Read && other.Read,
+ Write: a.Write && other.Write,
+ Execute: a.Execute && other.Execute,
+ }
+}
+
+// Union returns the access types set in either a or other.
+func (a AccessType) Union(other AccessType) AccessType {
+ return AccessType{
+ Read: a.Read || other.Read,
+ Write: a.Write || other.Write,
+ Execute: a.Execute || other.Execute,
+ }
+}
+
+// Effective returns the set of effective access types allowed by a, even if
+// some types are not explicitly allowed.
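+//
+// For example, Write.Effective() == ReadWrite, since writable mappings are
+// also readable under Linux's protection_map.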
+func (a AccessType) Effective() AccessType {
+ // In Linux, Write and Execute access generally imply Read access. See
+ // mm/mmap.c:protection_map.
+ //
+ // The notable exception is get_user_pages, which only checks against
+ // the original vma flags. That said, most user memory accesses do not
+ // use GUP.
+ if a.Write || a.Execute {
+ a.Read = true
+ }
+ return a
+}
+
+// Convenient access types.
+var (
+ NoAccess = AccessType{}
+ Read = AccessType{Read: true}
+ Write = AccessType{Write: true}
+ Execute = AccessType{Execute: true}
+ ReadWrite = AccessType{Read: true, Write: true}
+ AnyAccess = AccessType{Read: true, Write: true, Execute: true}
+)
diff --git a/pkg/usermem/addr.go b/pkg/usermem/addr.go
new file mode 100644
index 000000000..c4100481e
--- /dev/null
+++ b/pkg/usermem/addr.go
@@ -0,0 +1,125 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "fmt"
+)
+
+// Addr represents a generic virtual address.
+//
+// +stateify savable
+type Addr uintptr
+
+// AddLength adds the given length to start and returns the result. ok is true
+// iff adding the length did not overflow the range of Addr.
+//
+// Note: This function is usually used to get the end of an address range
+// defined by its start address and length. Since the resulting end is
+// exclusive, end == 0 is technically valid, and corresponds to a range that
+// extends to the end of the address space, but ok will be false. This isn't
+// expected to ever come up in practice.
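+//
+// For example (sketch): on a 64-bit build, Addr(^uintptr(0)).AddLength(1)
+// wraps to end == 0 and returns ok == false.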
+func (v Addr) AddLength(length uint64) (end Addr, ok bool) {
+ end = v + Addr(length)
+ // The second half of the following check is needed in case uintptr is
+ // smaller than 64 bits.
+ ok = end >= v && length <= uint64(^Addr(0))
+ return
+}
+
+// RoundDown returns the address rounded down to the nearest page boundary.
+func (v Addr) RoundDown() Addr {
+ return v & ^Addr(PageSize-1)
+}
+
+// RoundUp returns the address rounded up to the nearest page boundary. ok is
+// true iff rounding up did not wrap around.
+func (v Addr) RoundUp() (addr Addr, ok bool) {
+ addr = Addr(v + PageSize - 1).RoundDown()
+ ok = addr >= v
+ return
+}
+
+// MustRoundUp is equivalent to RoundUp, but panics if rounding up wraps
+// around.
+func (v Addr) MustRoundUp() Addr {
+ addr, ok := v.RoundUp()
+ if !ok {
+ panic(fmt.Sprintf("usermem.Addr(%d).RoundUp() wraps", v))
+ }
+ return addr
+}
+
+// HugeRoundDown returns the address rounded down to the nearest huge page
+// boundary.
+func (v Addr) HugeRoundDown() Addr {
+ return v & ^Addr(HugePageSize-1)
+}
+
+// HugeRoundUp returns the address rounded up to the nearest huge page boundary.
+// ok is true iff rounding up did not wrap around.
+func (v Addr) HugeRoundUp() (addr Addr, ok bool) {
+ addr = Addr(v + HugePageSize - 1).HugeRoundDown()
+ ok = addr >= v
+ return
+}
+
+// PageOffset returns the offset of v into the current page.
+func (v Addr) PageOffset() uint64 {
+ return uint64(v & Addr(PageSize-1))
+}
+
+// IsPageAligned returns true if v.PageOffset() == 0.
+func (v Addr) IsPageAligned() bool {
+ return v.PageOffset() == 0
+}
+
+// AddrRange is a range of Addrs.
+//
+// type AddrRange <generated by go_generics>
+
+// ToRange returns [v, v+length).
+func (v Addr) ToRange(length uint64) (AddrRange, bool) {
+ end, ok := v.AddLength(length)
+ return AddrRange{v, end}, ok
+}
+
+// IsPageAligned returns true if ar.Start.IsPageAligned() and
+// ar.End.IsPageAligned().
+func (ar AddrRange) IsPageAligned() bool {
+ return ar.Start.IsPageAligned() && ar.End.IsPageAligned()
+}
+
+// String implements fmt.Stringer.String.
+func (ar AddrRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End)
+}
+
+// PageRoundDown/Up are equivalent to Addr.RoundDown/Up, but without the
+// potentially truncating conversion from uint64 to Addr. This is necessary
+// because there is no way to define generic "PageRoundDown/Up" functions in Go.
+
+// PageRoundDown returns x rounded down to the nearest page boundary.
+func PageRoundDown(x uint64) uint64 {
+ return x &^ (PageSize - 1)
+}
+
+// PageRoundUp returns x rounded up to the nearest page boundary.
+// ok is true iff rounding up did not wrap around.
+func PageRoundUp(x uint64) (addr uint64, ok bool) {
+ addr = PageRoundDown(x + PageSize - 1)
+ ok = addr >= x
+ return
+}
diff --git a/pkg/usermem/addr_range_seq_test.go b/pkg/usermem/addr_range_seq_test.go
new file mode 100644
index 000000000..82f735026
--- /dev/null
+++ b/pkg/usermem/addr_range_seq_test.go
@@ -0,0 +1,197 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "testing"
+)
+
+var addrRangeSeqTests = []struct {
+ desc string
+ ranges []AddrRange
+}{
+ {
+ desc: "Empty sequence",
+ },
+ {
+ desc: "Single empty AddrRange",
+ ranges: []AddrRange{
+ {0x10, 0x10},
+ },
+ },
+ {
+ desc: "Single non-empty AddrRange of length 1",
+ ranges: []AddrRange{
+ {0x10, 0x11},
+ },
+ },
+ {
+ desc: "Single non-empty AddrRange of length 2",
+ ranges: []AddrRange{
+ {0x10, 0x12},
+ },
+ },
+ {
+ desc: "Multiple non-empty AddrRanges",
+ ranges: []AddrRange{
+ {0x10, 0x11},
+ {0x20, 0x22},
+ },
+ },
+ {
+ desc: "Multiple AddrRanges including empty AddrRanges",
+ ranges: []AddrRange{
+ {0x10, 0x10},
+ {0x20, 0x20},
+ {0x30, 0x33},
+ {0x40, 0x44},
+ {0x50, 0x50},
+ {0x60, 0x60},
+ {0x70, 0x77},
+ {0x80, 0x88},
+ {0x90, 0x90},
+ {0xa0, 0xa0},
+ },
+ },
+}
+
+func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) {
+ var wantLen int64
+ for _, ar := range wantRanges {
+ wantLen += int64(ar.Length())
+ }
+
+ var i int
+ for !ars.IsEmpty() {
+ if gotLen := ars.NumBytes(); gotLen != wantLen {
+ t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen)
+ }
+ if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN {
+ t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN)
+ }
+ got := ars.Head()
+ if i >= len(wantRanges) {
+ t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got)
+ } else if want := wantRanges[i]; got != want {
+ t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want)
+ }
+ ars = ars.Tail()
+ wantLen -= int64(got.Length())
+ i++
+ }
+ if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 {
+ t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen)
+ }
+ if gotN := ars.NumRanges(); gotN != 0 {
+ t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN)
+ }
+}
+
+func TestAddrRangeSeqTailIteration(t *testing.T) {
+ for _, test := range addrRangeSeqTests {
+ t.Run(test.desc, func(t *testing.T) {
+ testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges)
+ })
+ }
+}
+
+func TestAddrRangeSeqDropFirstEmpty(t *testing.T) {
+ var ars AddrRangeSeq
+ if got, want := ars.DropFirst(1), ars; got != want {
+ t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want)
+ }
+}
+
+func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) {
+ // Tests AddrRangeSeq iteration using Head/DropFirst, simulating
+ // I/O-per-AddrRange.
+ for _, test := range addrRangeSeqTests {
+ t.Run(test.desc, func(t *testing.T) {
+ // Figure out what AddrRanges we expect to see.
+ var wantLen int64
+ var wantRanges []AddrRange
+ for _, ar := range test.ranges {
+ wantLen += int64(ar.Length())
+ wantRanges = append(wantRanges, ar)
+ if ar.Length() == 0 {
+ // We "do" 0 bytes of I/O and then call DropFirst(0),
+ // advancing to the next AddrRange.
+ continue
+ }
+ // Otherwise we "do" 1 byte of I/O and then call DropFirst(1),
+ // advancing the AddrRange by 1 byte, or to the next AddrRange
+ // if this one is exhausted.
+ for ar.Start++; ar.Length() != 0; ar.Start++ {
+ wantRanges = append(wantRanges, ar)
+ }
+ }
+ t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen)
+
+ ars := AddrRangeSeqFromSlice(test.ranges)
+ var i int
+ for !ars.IsEmpty() {
+ if gotLen := ars.NumBytes(); gotLen != wantLen {
+ t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen)
+ }
+ got := ars.Head()
+ if i >= len(wantRanges) {
+ t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got)
+ } else if want := wantRanges[i]; got != want {
+ t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want)
+ }
+ if got.Length() == 0 {
+ ars = ars.DropFirst(0)
+ } else {
+ ars = ars.DropFirst(1)
+ wantLen--
+ }
+ i++
+ }
+ if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 {
+ t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen)
+ }
+ })
+ }
+}
+
+func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) {
+ var ars AddrRangeSeq
+ if got, want := ars.TakeFirst(1), ars; got != want {
+ t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want)
+ }
+}
+
+func TestAddrRangeSeqTakeFirst(t *testing.T) {
+ ranges := []AddrRange{
+ {0x10, 0x11},
+ {0x20, 0x22},
+ {0x30, 0x30},
+ {0x40, 0x44},
+ {0x50, 0x55},
+ {0x60, 0x60},
+ {0x70, 0x77},
+ }
+ ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5)
+ want := []AddrRange{
+ {0x10, 0x11}, // +1 byte (total 1 byte), not truncated
+ {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated
+ {0x30, 0x30}, // +0 bytes (total 3 bytes), no change
+ {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated
+ {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated
+ {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change)
+ {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated
+ }
+ testAddrRangeSeqEqualityWithTailIteration(t, ars, want)
+}
diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/usermem/addr_range_seq_unsafe.go
new file mode 100644
index 000000000..c09337c15
--- /dev/null
+++ b/pkg/usermem/addr_range_seq_unsafe.go
@@ -0,0 +1,277 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// An AddrRangeSeq represents a sequence of AddrRanges.
+//
+// AddrRangeSeqs are immutable and may be copied by value. The zero value of
+// AddrRangeSeq represents an empty sequence.
+//
+// An AddrRangeSeq may contain AddrRanges with a length of 0. This is necessary
+// since zero-length AddrRanges are significant to MM bounds checks.
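+//
+// Typical iteration (sketch):
+//
+//	for !ars.IsEmpty() {
+//		ar := ars.Head()
+//		// ... operate on ar ...
+//		ars = ars.Tail()
+//	}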
+type AddrRangeSeq struct {
+ // If length is 0, then the AddrRangeSeq represents no AddrRanges.
+ // Invariants: data == 0; offset == 0; limit == 0.
+ //
+ // If length is 1, then the AddrRangeSeq represents the single
+ // AddrRange{offset, offset+limit}. Invariants: data == 0.
+ //
+ // Otherwise, length >= 2, and the AddrRangeSeq represents the `length`
+ // AddrRanges in the array of AddrRanges starting at address `data`,
+ // starting at `offset` bytes into the first AddrRange and limited to the
+ // following `limit` bytes. (AddrRanges after `limit` are still iterated,
+ // but are truncated to a length of 0.) Invariants: data != 0; offset <=
+ // data[0].Length(); limit > 0; offset+limit <= the combined length of all
+ // AddrRanges in the array.
+ data unsafe.Pointer
+ length int
+ offset Addr
+ limit Addr
+}
+
+// AddrRangeSeqOf returns an AddrRangeSeq representing the single AddrRange ar.
+func AddrRangeSeqOf(ar AddrRange) AddrRangeSeq {
+ return AddrRangeSeq{
+ length: 1,
+ offset: ar.Start,
+ limit: ar.Length(),
+ }
+}
+
+// AddrRangeSeqFromSlice returns an AddrRangeSeq representing all AddrRanges in
+// slice.
+//
+// Whether the returned AddrRangeSeq shares memory with slice is unspecified;
+// clients should avoid mutating slices passed to AddrRangeSeqFromSlice.
+//
+// Preconditions: The combined length of all AddrRanges in slice <=
+// math.MaxInt64.
+func AddrRangeSeqFromSlice(slice []AddrRange) AddrRangeSeq {
+ var limit int64
+ for _, ar := range slice {
+ len64 := int64(ar.Length())
+ if len64 < 0 {
+ panic(fmt.Sprintf("Length of AddrRange %v overflows int64", ar))
+ }
+ sum := limit + len64
+ if sum < limit {
+ panic(fmt.Sprintf("Total length of AddrRanges %v overflows int64", slice))
+ }
+ limit = sum
+ }
+ return addrRangeSeqFromSliceLimited(slice, limit)
+}
+
+// Preconditions: The combined length of all AddrRanges in slice <= limit.
+// limit >= 0. If len(slice) != 0, then limit > 0.
+func addrRangeSeqFromSliceLimited(slice []AddrRange, limit int64) AddrRangeSeq {
+ switch len(slice) {
+ case 0:
+ return AddrRangeSeq{}
+ case 1:
+ return AddrRangeSeq{
+ length: 1,
+ offset: slice[0].Start,
+ limit: Addr(limit),
+ }
+ default:
+ return AddrRangeSeq{
+ data: unsafe.Pointer(&slice[0]),
+ length: len(slice),
+ limit: Addr(limit),
+ }
+ }
+}
+
+// IsEmpty returns true if ars.NumRanges() == 0.
+//
+// Note that since AddrRangeSeq may contain AddrRanges with a length of zero,
+// an AddrRangeSeq representing 0 bytes (AddrRangeSeq.NumBytes() == 0) is not
+// necessarily empty.
+func (ars AddrRangeSeq) IsEmpty() bool {
+ return ars.length == 0
+}
+
+// NumRanges returns the number of AddrRanges in ars.
+func (ars AddrRangeSeq) NumRanges() int {
+ return ars.length
+}
+
+// NumBytes returns the number of bytes represented by ars.
+func (ars AddrRangeSeq) NumBytes() int64 {
+ return int64(ars.limit)
+}
+
+// Head returns the first AddrRange in ars.
+//
+// Preconditions: !ars.IsEmpty().
+func (ars AddrRangeSeq) Head() AddrRange {
+ if ars.length == 0 {
+ panic("empty AddrRangeSeq")
+ }
+ if ars.length == 1 {
+ return AddrRange{ars.offset, ars.offset + ars.limit}
+ }
+ ar := *(*AddrRange)(ars.data)
+ ar.Start += ars.offset
+ if ar.Length() > ars.limit {
+ ar.End = ar.Start + ars.limit
+ }
+ return ar
+}
+
+// Tail returns an AddrRangeSeq consisting of all AddrRanges in ars after the
+// first.
+//
+// Preconditions: !ars.IsEmpty().
+func (ars AddrRangeSeq) Tail() AddrRangeSeq {
+ if ars.length == 0 {
+ panic("empty AddrRangeSeq")
+ }
+ if ars.length == 1 {
+ return AddrRangeSeq{}
+ }
+ return ars.externalTail()
+}
+
+// Preconditions: ars.length >= 2.
+func (ars AddrRangeSeq) externalTail() AddrRangeSeq {
+ headLen := (*AddrRange)(ars.data).Length() - ars.offset
+ var tailLimit int64
+ if ars.limit > headLen {
+ tailLimit = int64(ars.limit - headLen)
+ }
+ var extSlice []AddrRange
+ extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice))
+ extSliceHdr.Data = uintptr(ars.data)
+ extSliceHdr.Len = ars.length
+ extSliceHdr.Cap = ars.length
+ return addrRangeSeqFromSliceLimited(extSlice[1:], tailLimit)
+}
+
+// DropFirst returns an AddrRangeSeq equivalent to ars, but with the first n
+// bytes omitted. If n > ars.NumBytes(), DropFirst returns an empty
+// AddrRangeSeq.
+//
+// If !ars.IsEmpty() and ars.Head().Length() == 0, DropFirst will always omit
+// at least ars.Head(), even if n == 0. This guarantees that the basic pattern
+// of:
+//
+// for !ars.IsEmpty() {
+// n, err = doIOWith(ars.Head())
+// if err != nil {
+// return err
+// }
+// ars = ars.DropFirst(n)
+// }
+//
+// works even in the presence of zero-length AddrRanges.
+//
+// Preconditions: n >= 0.
+func (ars AddrRangeSeq) DropFirst(n int) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return ars.DropFirst64(int64(n))
+}
+
+// DropFirst64 is equivalent to DropFirst but takes an int64.
+func (ars AddrRangeSeq) DropFirst64(n int64) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ if Addr(n) > ars.limit {
+ return AddrRangeSeq{}
+ }
+ // Handle initial empty AddrRange.
+ switch ars.length {
+ case 0:
+ return AddrRangeSeq{}
+ case 1:
+ if ars.limit == 0 {
+ return AddrRangeSeq{}
+ }
+ default:
+ if rawHeadLen := (*AddrRange)(ars.data).Length(); ars.offset == rawHeadLen {
+ ars = ars.externalTail()
+ }
+ }
+ for n != 0 {
+ // Calling ars.Head() here is surprisingly expensive, so inline getting
+ // the head's length.
+ var headLen Addr
+ if ars.length == 1 {
+ headLen = ars.limit
+ } else {
+ headLen = (*AddrRange)(ars.data).Length() - ars.offset
+ }
+ if Addr(n) < headLen {
+ // Dropping ends partway through the head AddrRange.
+ ars.offset += Addr(n)
+ ars.limit -= Addr(n)
+ return ars
+ }
+ n -= int64(headLen)
+ ars = ars.Tail()
+ }
+ return ars
+}
+
+// TakeFirst returns an AddrRangeSeq equivalent to ars, but iterating at most n
+// bytes. TakeFirst never removes AddrRanges from ars; AddrRanges beyond the
+// first n bytes are reduced to a length of zero, but will still be iterated.
+//
+// Preconditions: n >= 0.
+func (ars AddrRangeSeq) TakeFirst(n int) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return ars.TakeFirst64(int64(n))
+}
+
+// TakeFirst64 is equivalent to TakeFirst but takes an int64.
+func (ars AddrRangeSeq) TakeFirst64(n int64) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ if ars.limit > Addr(n) {
+ ars.limit = Addr(n)
+ }
+ return ars
+}
+
+// String implements fmt.Stringer.String.
+func (ars AddrRangeSeq) String() string {
+ // This is deliberately chosen to be the same as fmt's automatic stringer
+ // for []AddrRange.
+ var buf bytes.Buffer
+ buf.WriteByte('[')
+ var sep string
+ for !ars.IsEmpty() {
+ buf.WriteString(sep)
+ sep = " "
+ buf.WriteString(ars.Head().String())
+ ars = ars.Tail()
+ }
+ buf.WriteByte(']')
+ return buf.String()
+}
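+
+// For illustration, a minimal sketch of iterating a sequence (the addresses
+// are arbitrary and doSomethingWith is a stand-in for caller code):
+//
+// ars := AddrRangeSeqFromSlice([]AddrRange{
+//     {Start: 0x1000, End: 0x1003},
+//     {Start: 0x2000, End: 0x2002},
+// })
+// for !ars.IsEmpty() {
+//     ar := ars.Head() // {0x1000, 0x1003}, then {0x2000, 0x2002}
+//     doSomethingWith(ar)
+//     ars = ars.Tail()
+// }
+// // Alternatively, DropFirst64(4) would consume all 3 bytes of the first
+// // range plus 1 byte of the second, leaving {0x2001, 0x2002}.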
diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go
new file mode 100644
index 000000000..e177d30eb
--- /dev/null
+++ b/pkg/usermem/bytes_io.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const maxInt = int(^uint(0) >> 1)
+
+// BytesIO implements IO using a byte slice. Addresses are interpreted as
+// offsets into the slice. Reads and writes beyond the end of the slice return
+// EFAULT.
+type BytesIO struct {
+ Bytes []byte
+}
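+
+// For illustration, a minimal sketch (ctx stands in for any
+// context.Context):
+//
+// b := &BytesIO{Bytes: []byte("ABCDE")}
+// n, err := b.CopyOut(ctx, 1, []byte("xy"), IOOpts{})
+// // n == 2, err == nil, and b.Bytes is now "AxyDE"; a copy extending past
+// // len(b.Bytes) would return a partial count and EFAULT.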
+
+// CopyOut implements IO.CopyOut.
+func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) {
+ rngN, rngErr := b.rangeCheck(addr, len(src))
+ if rngN == 0 {
+ return 0, rngErr
+ }
+ return copy(b.Bytes[int(addr):], src[:rngN]), rngErr
+}
+
+// CopyIn implements IO.CopyIn.
+func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) {
+ rngN, rngErr := b.rangeCheck(addr, len(dst))
+ if rngN == 0 {
+ return 0, rngErr
+ }
+ return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr
+}
+
+// ZeroOut implements IO.ZeroOut.
+func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) {
+ if toZero > int64(maxInt) {
+ return 0, syserror.EINVAL
+ }
+ rngN, rngErr := b.rangeCheck(addr, int(toZero))
+ if rngN == 0 {
+ return 0, rngErr
+ }
+ zeroSlice := b.Bytes[int(addr) : int(addr)+rngN]
+ for i := range zeroSlice {
+ zeroSlice[i] = 0
+ }
+ return int64(rngN), rngErr
+}
+
+// CopyOutFrom implements IO.CopyOutFrom.
+func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
+ dsts, rngErr := b.blocksFromAddrRanges(ars)
+ n, err := src.ReadToBlocks(dsts)
+ if err != nil {
+ return int64(n), err
+ }
+ return int64(n), rngErr
+}
+
+// CopyInTo implements IO.CopyInTo.
+func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
+ srcs, rngErr := b.blocksFromAddrRanges(ars)
+ n, err := dst.WriteFromBlocks(srcs)
+ if err != nil {
+ return int64(n), err
+ }
+ return int64(n), rngErr
+}
+
+func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) {
+ if length == 0 {
+ return 0, nil
+ }
+ if length < 0 {
+ return 0, syserror.EINVAL
+ }
+ max := Addr(len(b.Bytes))
+ if addr >= max {
+ return 0, syserror.EFAULT
+ }
+ end, ok := addr.AddLength(uint64(length))
+ if !ok || end > max {
+ return int(max - addr), syserror.EFAULT
+ }
+ return length, nil
+}
+
+func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) {
+ switch ars.NumRanges() {
+ case 0:
+ return safemem.BlockSeq{}, nil
+ case 1:
+ block, err := b.blockFromAddrRange(ars.Head())
+ return safemem.BlockSeqOf(block), err
+ default:
+ blocks := make([]safemem.Block, 0, ars.NumRanges())
+ for !ars.IsEmpty() {
+ block, err := b.blockFromAddrRange(ars.Head())
+ if block.Len() != 0 {
+ blocks = append(blocks, block)
+ }
+ if err != nil {
+ return safemem.BlockSeqFromSlice(blocks), err
+ }
+ ars = ars.Tail()
+ }
+ return safemem.BlockSeqFromSlice(blocks), nil
+ }
+}
+
+func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) {
+ n, err := b.rangeCheck(ar.Start, int(ar.Length()))
+ if n == 0 {
+ return safemem.Block{}, err
+ }
+ return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err
+}
+
+// BytesIOSequence returns an IOSequence representing the given byte slice.
+func BytesIOSequence(buf []byte) IOSequence {
+ return IOSequence{
+ IO: &BytesIO{buf},
+ Addrs: AddrRangeSeqOf(AddrRange{0, Addr(len(buf))}),
+ }
+}
diff --git a/pkg/usermem/bytes_io_unsafe.go b/pkg/usermem/bytes_io_unsafe.go
new file mode 100644
index 000000000..20de5037d
--- /dev/null
+++ b/pkg/usermem/bytes_io_unsafe.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// SwapUint32 implements IO.SwapUint32.
+func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) {
+ if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil {
+ return 0, rngErr
+ }
+ return atomic.SwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), new), nil
+}
+
+// CompareAndSwapUint32 implements IO.CompareAndSwapUint32.
+func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) {
+ if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil {
+ return 0, rngErr
+ }
+ return atomicbitops.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), old, new), nil
+}
+
+// LoadUint32 implements IO.LoadUint32.
+func (b *BytesIO) LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) {
+ if _, err := b.rangeCheck(addr, 4); err != nil {
+ return 0, err
+ }
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)]))), nil
+}
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
new file mode 100644
index 000000000..cd6a0ea6b
--- /dev/null
+++ b/pkg/usermem/usermem.go
@@ -0,0 +1,595 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package usermem governs access to user memory.
+package usermem
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// IO provides access to the contents of a virtual memory space.
+type IO interface {
+ // CopyOut copies len(src) bytes from src to the memory mapped at addr. It
+ // returns the number of bytes copied. If the number of bytes copied is <
+ // len(src), it returns a non-nil error explaining why.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order.
+ //
+ // Postconditions: CopyOut does not retain src.
+ CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error)
+
+ // CopyIn copies len(dst) bytes from the memory mapped at addr to dst.
+ // It returns the number of bytes copied. If the number of bytes copied is
+ // < len(dst), it returns a non-nil error explaining why.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order.
+ //
+ // Postconditions: CopyIn does not retain dst.
+ CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error)
+
+ // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number
+ // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a
+ // non-nil error explaining why.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. toZero >= 0.
+ ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error)
+
+ // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at
+ // ars. It returns the number of bytes copied, which may be less than the
+ // number of bytes read from src if copying fails. CopyOutFrom may return a
+ // partial copy without an error iff src.ReadToBlocks returns a partial
+ // read without an error.
+ //
+ // CopyOutFrom calls src.ReadToBlocks at most once.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. src.ReadToBlocks must not block
+ // on mm.MemoryManager.activeMu or any preceding locks in the lock order.
+ CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error)
+
+ // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to
+ // dst. It returns the number of bytes copied. CopyInTo may return a
+ // partial copy without an error iff dst.WriteFromBlocks returns a partial
+ // write without an error.
+ //
+ // CopyInTo calls dst.WriteFromBlocks at most once.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. dst.WriteFromBlocks must not
+ // block on mm.MemoryManager.activeMu or any preceding locks in the lock
+ // order.
+ CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error)
+
+ // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst
+ // at most once, which is unnecessary in most cases, forces implementations
+ // to gather safemem.Blocks into a single slice to pass to src/dst. Add
+ // CopyOutFromIter/CopyInToIter, which relaxes this restriction, to avoid
+ // this allocation.
+
+ // SwapUint32 atomically sets the uint32 value at addr to new and
+ // returns the previous value.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. addr must be aligned to a 4-byte
+ // boundary.
+ SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error)
+
+ // CompareAndSwapUint32 atomically compares the uint32 value at addr to
+ // old; if they are equal, the value in memory is replaced by new. In
+ // either case, the previous value stored in memory is returned.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. addr must be aligned to a 4-byte
+ // boundary.
+ CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error)
+
+ // LoadUint32 atomically loads the uint32 value at addr and returns it.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. addr must be aligned to a 4-byte
+ // boundary.
+ LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error)
+}
+
+// IOOpts contains options applicable to all IO methods.
+type IOOpts struct {
+ // If IgnorePermissions is true, application-defined memory protections set
+ // by mmap(2) or mprotect(2) will be ignored. (Memory protections required
+ // by the target of the mapping are never ignored.)
+ IgnorePermissions bool
+
+ // If AddressSpaceActive is true, the IO implementation may assume that it
+ // has an active AddressSpace and can therefore use AddressSpace copying
+ // without performing activation. See mm/io.go for details.
+ AddressSpaceActive bool
+}
+
+// IOReadWriter is an io.ReadWriter that reads from / writes to addresses
+// starting at addr in IO. The preconditions that apply to IO.CopyIn and
+// IO.CopyOut also apply to IOReadWriter.Read and IOReadWriter.Write
+// respectively.
+type IOReadWriter struct {
+ Ctx context.Context
+ IO IO
+ Addr Addr
+ Opts IOOpts
+}
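+
+// For illustration, a sketch of adapting user memory to an io.Reader
+// consumer (ctx, uio, and addr stand in for caller-provided values):
+//
+// rw := &IOReadWriter{Ctx: ctx, IO: uio, Addr: addr, Opts: IOOpts{}}
+// buf := make([]byte, 16)
+// n, err := io.ReadFull(rw, buf) // rw.Addr advances by the bytes copied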
+
+// Read implements io.Reader.Read.
+//
+// Note that an address space does not have an "end of file", so Read can only
+// return io.EOF if IO.CopyIn returns io.EOF. Attempts to read unmapped or
+// unreadable memory, or beyond the end of the address space, should return
+// EFAULT.
+func (rw *IOReadWriter) Read(dst []byte) (int, error) {
+ n, err := rw.IO.CopyIn(rw.Ctx, rw.Addr, dst, rw.Opts)
+ end, ok := rw.Addr.AddLength(uint64(n))
+ if ok {
+ rw.Addr = end
+ } else {
+ // Disallow wraparound.
+ rw.Addr = ^Addr(0)
+ if err != nil {
+ err = syserror.EFAULT
+ }
+ }
+ return n, err
+}
+
+// Write implements io.Writer.Write.
+func (rw *IOReadWriter) Write(src []byte) (int, error) {
+ n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts)
+ end, ok := rw.Addr.AddLength(uint64(n))
+ if ok {
+ rw.Addr = end
+ } else {
+ // Disallow wraparound.
+ rw.Addr = ^Addr(0)
+ if err != nil {
+ err = syserror.EFAULT
+ }
+ }
+ return n, err
+}
+
+// CopyObjectOut copies a fixed-size value or slice of fixed-size values from
+// src to the memory mapped at addr in uio. It returns the number of bytes
+// copied.
+//
+// CopyObjectOut must use reflection to encode src; performance-sensitive
+// clients should do encoding manually and use uio.CopyOut directly.
+//
+// Preconditions: As for IO.CopyOut.
+func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) {
+ w := &IOReadWriter{
+ Ctx: ctx,
+ IO: uio,
+ Addr: addr,
+ Opts: opts,
+ }
+ // Allocate a byte slice the size of the object being marshaled. This
+ // adds an extra reflection call, but avoids needing to grow the slice
+ // during encoding, which can result in many heap-allocated slices.
+ b := make([]byte, 0, binary.Size(src))
+ return w.Write(binary.Marshal(b, ByteOrder, src))
+}
+
+// CopyObjectIn copies a fixed-size value or slice of fixed-size values from
+// the memory mapped at addr in uio to dst. It returns the number of bytes
+// copied.
+//
+// CopyObjectIn must use reflection to decode dst; performance-sensitive
+// clients should use uio.CopyIn directly and do decoding manually.
+//
+// Preconditions: As for IO.CopyIn.
+func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) {
+ r := &IOReadWriter{
+ Ctx: ctx,
+ IO: uio,
+ Addr: addr,
+ Opts: opts,
+ }
+ buf := make([]byte, binary.Size(dst))
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ binary.Unmarshal(buf, ByteOrder, dst)
+ return int(r.Addr - addr), nil
+}
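+
+// For illustration, a round-trip sketch (ctx, uio, and addr stand in for
+// caller-provided values):
+//
+// type point struct{ X, Y int32 }
+// src := point{1, 2}
+// n, err := CopyObjectOut(ctx, uio, addr, &src, IOOpts{}) // n == 8
+// var dst point
+// _, err = CopyObjectIn(ctx, uio, addr, &dst, IOOpts{}) // dst == {1, 2}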
+
+// Tuning parameters for CopyStringIn, defined outside the function so that
+// tests can reference them.
+const (
+ copyStringIncrement = 64
+ copyStringMaxInitBufLen = 256
+)
+
+// CopyStringIn copies a NUL-terminated string of unknown length from the
+// memory mapped at addr in uio and returns it as a string (not including the
+// trailing NUL). If the length of the string, including the terminating NUL,
+// would exceed maxlen, CopyStringIn returns the string truncated to maxlen and
+// ENAMETOOLONG.
+//
+// Preconditions: As for IO.CopyIn. maxlen >= 0.
+func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) {
+ initLen := maxlen
+ if initLen > copyStringMaxInitBufLen {
+ initLen = copyStringMaxInitBufLen
+ }
+ buf := make([]byte, initLen)
+ var done int
+ for done < maxlen {
+ // Read up to copyStringIncrement bytes at a time.
+ readlen := copyStringIncrement
+ if readlen > maxlen-done {
+ readlen = maxlen - done
+ }
+ end, ok := addr.AddLength(uint64(readlen))
+ if !ok {
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ }
+ // Shorten the read to avoid crossing page boundaries, since faulting
+ // in a page unnecessarily is expensive. This also ensures that partial
+ // copies up to the end of application-mappable memory succeed.
+ if addr.RoundDown() != end.RoundDown() {
+ end = end.RoundDown()
+ readlen = int(end - addr)
+ }
+ // Ensure that our buffer is large enough to accommodate the read.
+ if done+readlen > len(buf) {
+ newBufLen := len(buf) * 2
+ if newBufLen > maxlen {
+ newBufLen = maxlen
+ }
+ buf = append(buf, make([]byte, newBufLen-len(buf))...)
+ }
+ n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
+ // Look for the terminating zero byte, which may have occurred before
+ // hitting err.
+ if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
+ }
+
+ done += n
+ if err != nil {
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
+ }
+ addr = end
+ }
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+}
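+
+// For illustration, a sketch using a BytesIO backing store (ctx stands in
+// for any context.Context):
+//
+// b := &BytesIO{Bytes: []byte("hello\x00world")}
+// s, err := CopyStringIn(ctx, b, 0, 64, IOOpts{})
+// // s == "hello", err == nil. With maxlen == 3, it would instead return
+// // ("hel", ENAMETOOLONG).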
+
+// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The
+// maximum number of bytes copied is ars.NumBytes() or len(src), whichever is
+// less. CopyOutVec returns the number of bytes copied; if this is less than
+// the maximum, it returns a non-nil error explaining why.
+//
+// Preconditions: As for IO.CopyOut.
+func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) {
+ var done int
+ for !ars.IsEmpty() && done < len(src) {
+ ar := ars.Head()
+ cplen := len(src) - done
+ if Addr(cplen) >= ar.Length() {
+ cplen = int(ar.Length())
+ }
+ n, err := uio.CopyOut(ctx, ar.Start, src[done:done+cplen], opts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ ars = ars.DropFirst(n)
+ }
+ return done, nil
+}
+
+// CopyInVec copies bytes from the memory mapped at ars in uio to dst. The
+// maximum number of bytes copied is ars.NumBytes() or len(dst), whichever is
+// less. CopyInVec returns the number of bytes copied; if this is less than the
+// maximum, it returns a non-nil error explaining why.
+//
+// Preconditions: As for IO.CopyIn.
+func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) {
+ var done int
+ for !ars.IsEmpty() && done < len(dst) {
+ ar := ars.Head()
+ cplen := len(dst) - done
+ if Addr(cplen) >= ar.Length() {
+ cplen = int(ar.Length())
+ }
+ n, err := uio.CopyIn(ctx, ar.Start, dst[done:done+cplen], opts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ ars = ars.DropFirst(n)
+ }
+ return done, nil
+}
+
+// ZeroOutVec writes zeroes to the memory mapped at ars in uio. The maximum
+// number of bytes written is ars.NumBytes() or toZero, whichever is less.
+// ZeroOutVec returns the number of bytes written; if this is less than the
+// maximum, it returns a non-nil error explaining why.
+//
+// Preconditions: As for IO.ZeroOut.
+func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) {
+ var done int64
+ for !ars.IsEmpty() && done < toZero {
+ ar := ars.Head()
+ cplen := toZero - done
+ if Addr(cplen) >= ar.Length() {
+ cplen = int64(ar.Length())
+ }
+ n, err := uio.ZeroOut(ctx, ar.Start, cplen, opts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ ars = ars.DropFirst64(n)
+ }
+ return done, nil
+}
+
+func isASCIIWhitespace(b byte) bool {
+ // Compare Linux include/linux/ctype.h, lib/ctype.c.
+ // 9 => horizontal tab '\t'
+ // 10 => line feed '\n'
+ // 11 => vertical tab '\v'
+// 12 => form feed '\f'
+ // 13 => carriage return '\r'
+ return b == ' ' || (b >= 9 && b <= 13)
+}
+
+// CopyInt32StringsInVec copies up to len(dsts) whitespace-separated decimal
+// strings from the memory mapped at ars in uio and converts them to int32
+// values in dsts. It returns the number of bytes read.
+//
+// CopyInt32StringsInVec shares the following properties with Linux's
+// kernel/sysctl.c:proc_dointvec(write=1):
+//
+// - If any read value overflows the range of int32, or any invalid characters
+// are encountered during the read, CopyInt32StringsInVec returns EINVAL.
+//
+// - If, upon reaching the end of ars, fewer than len(dsts) values have been
+// read, CopyInt32StringsInVec returns no error if at least 1 value was read
+// and EINVAL otherwise.
+//
+// - Trailing whitespace after the last successfully read value is counted in
+// the number of bytes read.
+//
+// Unlike proc_dointvec():
+//
+// - CopyInt32StringsInVec does not implicitly limit ars.NumBytes() to
+// PageSize-1; callers that require this must do so explicitly.
+//
+// - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0.
+//
+// Preconditions: As for CopyInVec.
+func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) {
+ if len(dsts) == 0 {
+ return 0, nil
+ }
+
+ buf := make([]byte, ars.NumBytes())
+ n, cperr := CopyInVec(ctx, uio, ars, buf, opts)
+ buf = buf[:n]
+
+ var i, j int
+ for ; j < len(dsts); j++ {
+ // Skip leading whitespace.
+ for i < len(buf) && isASCIIWhitespace(buf[i]) {
+ i++
+ }
+ if i == len(buf) {
+ break
+ }
+
+ // Find the end of the value to be parsed (next whitespace or end of string).
+ nextI := i + 1
+ for nextI < len(buf) && !isASCIIWhitespace(buf[nextI]) {
+ nextI++
+ }
+
+ // Parse a single value.
+ val, err := strconv.ParseInt(string(buf[i:nextI]), 10, 32)
+ if err != nil {
+ return int64(i), syserror.EINVAL
+ }
+ dsts[j] = int32(val)
+
+ i = nextI
+ }
+
+ // Skip trailing whitespace.
+ for i < len(buf) && isASCIIWhitespace(buf[i]) {
+ i++
+ }
+
+ if cperr != nil {
+ return int64(i), cperr
+ }
+ if j == 0 {
+ return int64(i), syserror.EINVAL
+ }
+ return int64(i), nil
+}
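+
+// For illustration, a sketch (ctx stands in for any context.Context):
+//
+// src := BytesIOSequence([]byte(" 10\t20\n"))
+// dsts := []int32{0, 0}
+// n, err := CopyInt32StringsInVec(ctx, src.IO, src.Addrs, dsts, src.Opts)
+// // n == 7 (trailing whitespace is counted), err == nil, and
+// // dsts == []int32{10, 20}.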
+
+// CopyInt32StringInVec is equivalent to CopyInt32StringsInVec, but copies at
+// most one int32.
+func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) {
+ dsts := [1]int32{*dst}
+ n, err := CopyInt32StringsInVec(ctx, uio, ars, dsts[:], opts)
+ *dst = dsts[0]
+ return n, err
+}
+
+// IOSequence holds arguments to IO methods.
+type IOSequence struct {
+ IO IO
+ Addrs AddrRangeSeq
+ Opts IOOpts
+}
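+
+// For illustration, a sketch of incremental consumption (ctx stands in for
+// any context.Context):
+//
+// s := BytesIOSequence(make([]byte, 8))
+// n, err := s.CopyOut(ctx, []byte("abc")) // n == 3
+// s = s.DropFirst(n)                      // 5 bytes remain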
+
+// NumBytes returns s.Addrs.NumBytes().
+//
+// Note that NumBytes() may return 0 even if !s.Addrs.IsEmpty(), since s.Addrs
+// may contain a non-zero number of zero-length AddrRanges. Many clients of
+// IOSequence currently do something like:
+//
+// if ioseq.NumBytes() == 0 {
+//     return 0, nil
+// }
+// if f.availableBytes == 0 {
+//     return 0, syserror.ErrWouldBlock
+// }
+// return ioseq.CopyOutFrom(..., reader)
+//
+// In such cases, using s.Addrs.IsEmpty() will cause them to have the wrong
+// behavior for zero-length I/O. However, using s.NumBytes() == 0 instead means
+// that we will return success for zero-length I/O in cases where Linux would
+// return EFAULT due to a failed access_ok() check, so in the long term we
+// should move checks for ErrWouldBlock etc. into the body of
+// reader.ReadToBlocks and use s.Addrs.IsEmpty() instead.
+func (s IOSequence) NumBytes() int64 {
+ return s.Addrs.NumBytes()
+}
+
+// DropFirst returns a copy of s with s.Addrs.DropFirst(n).
+//
+// Preconditions: As for AddrRangeSeq.DropFirst.
+func (s IOSequence) DropFirst(n int) IOSequence {
+ return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts}
+}
+
+// DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n).
+//
+// Preconditions: As for AddrRangeSeq.DropFirst64.
+func (s IOSequence) DropFirst64(n int64) IOSequence {
+ return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts}
+}
+
+// TakeFirst returns a copy of s with s.Addrs.TakeFirst(n).
+//
+// Preconditions: As for AddrRangeSeq.TakeFirst.
+func (s IOSequence) TakeFirst(n int) IOSequence {
+ return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts}
+}
+
+// TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n).
+//
+// Preconditions: As for AddrRangeSeq.TakeFirst64.
+func (s IOSequence) TakeFirst64(n int64) IOSequence {
+ return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts}
+}
+
+// CopyOut invokes CopyOutVec over s.Addrs.
+//
+// As with CopyOutVec, if s.NumBytes() < len(src), the copy will be truncated
+// to s.NumBytes(), and a nil error will be returned.
+//
+// Preconditions: As for CopyOutVec.
+func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) {
+ return CopyOutVec(ctx, s.IO, s.Addrs, src, s.Opts)
+}
+
+// CopyIn invokes CopyInVec over s.Addrs.
+//
+// As with CopyInVec, if s.NumBytes() < len(dst), the copy will be truncated to
+// s.NumBytes(), and a nil error will be returned.
+//
+// Preconditions: As for CopyInVec.
+func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) {
+ return CopyInVec(ctx, s.IO, s.Addrs, dst, s.Opts)
+}
+
+// ZeroOut invokes ZeroOutVec over s.Addrs.
+//
+// As with ZeroOutVec, if s.NumBytes() < toZero, the write will be truncated
+// to s.NumBytes(), and a nil error will be returned.
+//
+// Preconditions: As for ZeroOutVec.
+func (s IOSequence) ZeroOut(ctx context.Context, toZero int64) (int64, error) {
+ return ZeroOutVec(ctx, s.IO, s.Addrs, toZero, s.Opts)
+}
+
+// CopyOutFrom invokes s.IO.CopyOutFrom over s.Addrs.
+//
+// Preconditions: As for IO.CopyOutFrom.
+func (s IOSequence) CopyOutFrom(ctx context.Context, src safemem.Reader) (int64, error) {
+ return s.IO.CopyOutFrom(ctx, s.Addrs, src, s.Opts)
+}
+
+// CopyInTo invokes s.IO.CopyInTo over s.Addrs.
+//
+// Preconditions: As for IO.CopyInTo.
+func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, error) {
+ return s.IO.CopyInTo(ctx, s.Addrs, dst, s.Opts)
+}
+
+// Reader returns an io.Reader that reads from s. Reads beyond the end of s
+// return io.EOF. The preconditions that apply to s.CopyIn also apply to the
+// returned io.Reader.Read.
+func (s IOSequence) Reader(ctx context.Context) io.Reader {
+ return &ioSequenceReadWriter{ctx, s}
+}
+
+// Writer returns an io.Writer that writes to s. Writes beyond the end of s
+// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also
+// apply to the returned io.Writer.Write.
+func (s IOSequence) Writer(ctx context.Context) io.Writer {
+ return &ioSequenceReadWriter{ctx, s}
+}
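+
+// For illustration, a sketch of composing with the standard library (ctx
+// stands in for any context.Context; strings is the standard package):
+//
+// _, err := io.Copy(s.Writer(ctx), strings.NewReader("data"))
+// // err == ErrEndOfIOSequence iff "data" is longer than s.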
+
+// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when
+// attempting to write beyond the end of the IOSequence.
+var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence")
+
+type ioSequenceReadWriter struct {
+ ctx context.Context
+ s IOSequence
+}
+
+// Read implements io.Reader.Read.
+func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
+ n, err := rw.s.CopyIn(rw.ctx, dst)
+ rw.s = rw.s.DropFirst(n)
+ if err == nil && rw.s.NumBytes() == 0 {
+ err = io.EOF
+ }
+ return n, err
+}
+
+// Write implements io.Writer.Write.
+func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) {
+ n, err := rw.s.CopyOut(rw.ctx, src)
+ rw.s = rw.s.DropFirst(n)
+ if err == nil && n < len(src) {
+ err = ErrEndOfIOSequence
+ }
+ return n, err
+}
diff --git a/pkg/usermem/usermem_arm64.go b/pkg/usermem/usermem_arm64.go
new file mode 100644
index 000000000..fdfc30a66
--- /dev/null
+++ b/pkg/usermem/usermem_arm64.go
@@ -0,0 +1,53 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package usermem
+
+import (
+ "encoding/binary"
+ "syscall"
+)
+
+const (
+ // PageSize is the system page size.
+ // arm64 support 4K/16K/64K page size,
+ // which can be get by syscall.Getpagesize().
+ // Currently, only 4K page size is supported.
+ PageSize = 1 << PageShift
+
+ // HugePageSize is the system huge page size.
+ HugePageSize = 1 << HugePageShift
+
+ // PageShift is the binary log of the system page size.
+ PageShift = 12
+
+ // HugePageShift is the binary log of the system huge page size.
+ // Should be calculated by "PageShift + (PageShift - 3)"
+ // when multiple page size support is ready.
+ HugePageShift = 21
+)
+
+var (
+ // ByteOrder is the native byte order (little endian).
+ ByteOrder = binary.LittleEndian
+)
+
+func init() {
+ // Make sure the page size is 4K on arm64 platform.
+ if size := syscall.Getpagesize(); size != PageSize {
+ panic("Only 4K page size is supported on arm64!")
+ }
+}
diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go
new file mode 100644
index 000000000..bf3c5df2b
--- /dev/null
+++ b/pkg/usermem/usermem_test.go
@@ -0,0 +1,424 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// newContext returns a context.Context that we can use in these tests (we
+// can't use contexttest because it depends on usermem).
+func newContext() context.Context {
+ return context.Background()
+}
+
+func newBytesIOString(s string) *BytesIO {
+ return &BytesIO{[]byte(s)}
+}
+
+func TestBytesIOCopyOutSuccess(t *testing.T) {
+ b := newBytesIOString("ABCDE")
+ n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
+ if wantN := 3; n != wantN || err != nil {
+ t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) {
+ t.Errorf("Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyOutFailure(t *testing.T) {
+ b := newBytesIOString("ABC")
+ n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
+ if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr {
+ t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) {
+ t.Errorf("Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyInSuccess(t *testing.T) {
+ b := newBytesIOString("AfooE")
+ var dst [3]byte
+ n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
+ if wantN := 3; n != wantN || err != nil {
+ t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyInFailure(t *testing.T) {
+ b := newBytesIOString("Afo")
+ var dst [3]byte
+ n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
+ if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr {
+ t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOZeroOutSuccess(t *testing.T) {
+ b := newBytesIOString("ABCD")
+ n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{})
+ if wantN := int64(2); n != wantN || err != nil {
+ t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) {
+ t.Errorf("Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOZeroOutFailure(t *testing.T) {
+ b := newBytesIOString("ABC")
+ n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{})
+ if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr {
+ t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) {
+ t.Errorf("Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyOutFromSuccess(t *testing.T) {
+ b := newBytesIOString("ABCDEFGH")
+ n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ {Start: 4, End: 7},
+ {Start: 1, End: 4},
+ }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{})
+ if wantN := int64(6); n != wantN || err != nil {
+ t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) {
+ t.Errorf("Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyOutFromFailure(t *testing.T) {
+ b := newBytesIOString("ABCDE")
+ n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ {Start: 1, End: 4},
+ {Start: 4, End: 7},
+ }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{})
+ if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr {
+ t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) {
+ t.Errorf("Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyInToSuccess(t *testing.T) {
+ b := newBytesIOString("AfoobarH")
+ var dst bytes.Buffer
+ n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ {Start: 4, End: 7},
+ {Start: 1, End: 4},
+ }), safemem.FromIOWriter{&dst}, IOOpts{})
+ if wantN := int64(6); n != wantN || err != nil {
+ t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) {
+ t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
+ }
+}
+
+func TestBytesIOCopyInToFailure(t *testing.T) {
+ b := newBytesIOString("Afoob")
+ var dst bytes.Buffer
+ n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ {Start: 1, End: 4},
+ {Start: 4, End: 7},
+ }), safemem.FromIOWriter{&dst}, IOOpts{})
+ if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr {
+ t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
+ t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
+ }
+}
+
+type testStruct struct {
+ Int8 int8
+ Uint8 uint8
+ Int16 int16
+ Uint16 uint16
+ Int32 int32
+ Uint32 uint32
+ Int64 int64
+ Uint64 uint64
+}
+
+func TestCopyObject(t *testing.T) {
+ wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8}
+ wantN := binary.Size(wantObj)
+ b := &BytesIO{make([]byte, wantN)}
+ ctx := newContext()
+ if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil {
+ t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ var gotObj testStruct
+ if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil {
+ t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if gotObj != wantObj {
+ t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj)
+ }
+}
+
+func TestCopyStringInShort(t *testing.T) {
+ // Tests for string length <= copyStringIncrement.
+ want := strings.Repeat("A", copyStringIncrement-2)
+ mem := want + "\x00"
+ if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
+ }
+}
+
+func TestCopyStringInLong(t *testing.T) {
+ // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen
+ // (requiring multiple calls to IO.CopyIn()).
+ want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4)
+ mem := want + "\x00"
+ if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
+ }
+}
+
+func TestCopyStringInVeryLong(t *testing.T) {
+ // Tests for string length > copyStringMaxInitBufLen (requiring buffer
+ // reallocation).
+ want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4)
+ mem := want + "\x00"
+ if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
+ }
+}
+
+func TestCopyStringInNoTerminatingZeroByte(t *testing.T) {
+ want := strings.Repeat("A", copyStringIncrement-1)
+ got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{})
+ if wantErr := syserror.EFAULT; got != want || err != wantErr {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
+ }
+}
+
+func TestCopyStringInTruncatedByMaxlen(t *testing.T) {
+ got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{})
+ if want, wantErr := strings.Repeat("A", 5), syserror.ENAMETOOLONG; got != want || err != wantErr {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
+ }
+}
+
+func TestCopyInt32StringsInVec(t *testing.T) {
+ for _, test := range []struct {
+ str string
+ n int
+ initial []int32
+ final []int32
+ }{
+ {
+ str: "100 200",
+ n: len("100 200"),
+ initial: []int32{1, 2},
+ final: []int32{100, 200},
+ },
+ {
+ // Fewer values ok
+ str: "100",
+ n: len("100"),
+ initial: []int32{1, 2},
+ final: []int32{100, 2},
+ },
+ {
+ // Extra values ok
+ str: "100 200 300",
+ n: len("100 200 "),
+ initial: []int32{1, 2},
+ final: []int32{100, 200},
+ },
+ {
+ // Leading and trailing whitespace ok
+ str: " 100\t200\n",
+ n: len(" 100\t200\n"),
+ initial: []int32{1, 2},
+ final: []int32{100, 200},
+ },
+ } {
+ t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) {
+ src := BytesIOSequence([]byte(test.str))
+ dsts := append([]int32(nil), test.initial...)
+ if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil {
+ t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n)
+ }
+ if !reflect.DeepEqual(dsts, test.final) {
+ t.Errorf("dsts: got %v, wanted %v", dsts, test.final)
+ }
+ })
+ }
+}
+
+func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) {
+ for _, s := range []string{"", "\n", "a123"} {
+ t.Run(fmt.Sprintf("%q", s), func(t *testing.T) {
+ src := BytesIOSequence([]byte(s))
+ initial := []int32{1, 2}
+ dsts := append([]int32(nil), initial...)
+ if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL {
+ t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL)
+ }
+ if !reflect.DeepEqual(dsts, initial) {
+ t.Errorf("dsts: got %v, wanted %v", dsts, initial)
+ }
+ })
+ }
+}
+
+func TestIOSequenceCopyOut(t *testing.T) {
+ buf := []byte("ABCD")
+ s := BytesIOSequence(buf)
+
+ // CopyOut limited by len(src).
+ n, err := s.CopyOut(newContext(), []byte("fo"))
+ if wantN := 2; n != wantN || err != nil {
+ t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if want := []byte("foCD"); !bytes.Equal(buf, want) {
+ t.Errorf("buf: got %q, wanted %q", buf, want)
+ }
+ s = s.DropFirst(2)
+ if got, want := s.NumBytes(), int64(2); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+
+ // CopyOut limited by s.NumBytes().
+ n, err = s.CopyOut(newContext(), []byte("obar"))
+ if wantN := 2; n != wantN || err != nil {
+ t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if want := []byte("foob"); !bytes.Equal(buf, want) {
+ t.Errorf("buf: got %q, wanted %q", buf, want)
+ }
+ s = s.DropFirst(2)
+ if got, want := s.NumBytes(), int64(0); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+}
+
+func TestIOSequenceCopyIn(t *testing.T) {
+ s := BytesIOSequence([]byte("foob"))
+ dst := []byte("ABCDEF")
+
+ // CopyIn limited by len(dst).
+ n, err := s.CopyIn(newContext(), dst[:2])
+ if wantN := 2; n != wantN || err != nil {
+ t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if want := []byte("foCDEF"); !bytes.Equal(dst, want) {
+ t.Errorf("dst: got %q, wanted %q", dst, want)
+ }
+ s = s.DropFirst(2)
+ if got, want := s.NumBytes(), int64(2); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+
+ // CopyIn limited by s.NumBytes().
+ n, err = s.CopyIn(newContext(), dst[2:])
+ if wantN := 2; n != wantN || err != nil {
+ t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if want := []byte("foobEF"); !bytes.Equal(dst, want) {
+ t.Errorf("dst: got %q, wanted %q", dst, want)
+ }
+ s = s.DropFirst(2)
+ if got, want := s.NumBytes(), int64(0); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+}
+
+func TestIOSequenceZeroOut(t *testing.T) {
+ buf := []byte("ABCD")
+ s := BytesIOSequence(buf)
+
+ // ZeroOut limited by toZero.
+ n, err := s.ZeroOut(newContext(), 2)
+ if wantN := int64(2); n != wantN || err != nil {
+ t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) {
+ t.Errorf("buf: got %q, wanted %q", buf, want)
+ }
+ s = s.DropFirst(2)
+ if got, want := s.NumBytes(), int64(2); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+
+ // ZeroOut limited by s.NumBytes().
+ n, err = s.ZeroOut(newContext(), 4)
+ if wantN := int64(2); n != wantN || err != nil {
+ t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) {
+ t.Errorf("buf: got %q, wanted %q", buf, want)
+ }
+ s = s.DropFirst(2)
+ if got, want := s.NumBytes(), int64(0); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+}
+
+func TestIOSequenceTakeFirst(t *testing.T) {
+ s := BytesIOSequence([]byte("foobar"))
+ if got, want := s.NumBytes(), int64(6); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+
+ s = s.TakeFirst(3)
+ if got, want := s.NumBytes(), int64(3); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+
+ // TakeFirst(n) where n > s.NumBytes() is a no-op.
+ s = s.TakeFirst(9)
+ if got, want := s.NumBytes(), int64(3); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+
+ var dst [3]byte
+ n, err := s.CopyIn(newContext(), dst[:])
+ if wantN := 3; n != wantN || err != nil {
+ t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+ s = s.DropFirst(3)
+ if got, want := s.NumBytes(), int64(0); got != want {
+ t.Errorf("NumBytes: got %v, wanted %v", got, want)
+ }
+}
diff --git a/pkg/usermem/usermem_x86.go b/pkg/usermem/usermem_x86.go
new file mode 100644
index 000000000..d96f829fb
--- /dev/null
+++ b/pkg/usermem/usermem_x86.go
@@ -0,0 +1,38 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 386
+
+package usermem
+
+import "encoding/binary"
+
+const (
+ // PageSize is the system page size.
+ PageSize = 1 << PageShift
+
+ // HugePageSize is the system huge page size.
+ HugePageSize = 1 << HugePageShift
+
+ // PageShift is the binary log of the system page size.
+ PageShift = 12
+
+ // HugePageShift is the binary log of the system huge page size.
+ HugePageShift = 21
+)
+
+var (
+ // ByteOrder is the native byte order (little endian).
+ ByteOrder = binary.LittleEndian
+)
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
new file mode 100644
index 000000000..852480a09
--- /dev/null
+++ b/pkg/waiter/BUILD
@@ -0,0 +1,35 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "waiter_list",
+ out = "waiter_list.go",
+ package = "waiter",
+ prefix = "waiter",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Entry",
+ "Linker": "*Entry",
+ },
+)
+
+go_library(
+ name = "waiter",
+ srcs = [
+ "waiter.go",
+ "waiter_list.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
+)
+
+go_test(
+ name = "waiter_test",
+ size = "small",
+ srcs = [
+ "waiter_test.go",
+ ],
+ library = ":waiter",
+)
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
new file mode 100644
index 000000000..67a950444
--- /dev/null
+++ b/pkg/waiter/waiter.go
@@ -0,0 +1,244 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package waiter provides the implementation of a wait queue, where waiters can
+// be enqueued to be notified when an event of interest happens.
+//
+// Becoming readable and/or writable are examples of events. Waiters are
+// expected to use a pattern similar to this to make a blocking function out of
+// a non-blocking one:
+//
+// func (o *object) blockingRead(...) error {
+//     err := o.nonBlockingRead(...)
+//     if err != ErrAgain {
+//         // Completed with no need to wait!
+//         return err
+//     }
+//
+//     e := createOrGetWaiterEntry(...)
+//     o.EventRegister(&e, waiter.EventIn)
+//     defer o.EventUnregister(&e)
+//
+//     // We need to try to read again after registration because the
+//     // object may have become readable between the last attempt to
+//     // read and read registration.
+//     err = o.nonBlockingRead(...)
+//     for err == ErrAgain {
+//         wait()
+//         err = o.nonBlockingRead(...)
+//     }
+//
+//     return err
+// }
+//
+// Another goroutine needs to notify waiters when events happen. For example:
+//
+// func (o *object) Write(...) ... {
+//     // Do write work.
+//     [...]
+//
+//     if oldDataAvailableSize == 0 && dataAvailableSize > 0 {
+//         // If no data was available and now some data is
+//         // available, the object became readable, so notify
+//         // potential waiters about this.
+//         o.Notify(waiter.EventIn)
+//     }
+// }
+package waiter
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// EventMask represents I/O events as used in the poll() syscall.
+type EventMask uint64
+
+// Events that waiters can wait on. The meanings are the same as in the
+// poll() syscall.
+const (
+ EventIn EventMask = 0x01 // POLLIN
+ EventPri EventMask = 0x02 // POLLPRI
+ EventOut EventMask = 0x04 // POLLOUT
+ EventErr EventMask = 0x08 // POLLERR
+ EventHUp EventMask = 0x10 // POLLHUP
+
+ allEvents EventMask = 0x1f
+)
+
+// EventMaskFromLinux returns an EventMask representing the supported events
+// from the Linux events e, which is in the format used by poll(2).
+func EventMaskFromLinux(e uint32) EventMask {
+ // Our flag definitions are currently identical to Linux.
+ return EventMask(e) & allEvents
+}
+
+// ToLinux returns e in the format used by Linux poll(2).
+func (e EventMask) ToLinux() uint32 {
+ // Our flag definitions are currently identical to Linux.
+ return uint32(e)
+}
+
+// Waitable contains the methods that need to be implemented by waitable
+// objects.
+type Waitable interface {
+ // Readiness returns what the object is currently ready for. If it's
+ // not ready for a desired purpose, the caller may use EventRegister and
+ // EventUnregister to get notifications once the object becomes ready.
+ //
+ // Implementations should allow for events like EventHUp and EventErr
+ // to be returned regardless of whether they are in the input EventMask.
+ Readiness(mask EventMask) EventMask
+
+ // EventRegister registers the given waiter entry to receive
+ // notifications when an event occurs that makes the object ready for
+ // at least one of the events in mask.
+ EventRegister(e *Entry, mask EventMask)
+
+ // EventUnregister unregisters a waiter entry previously registered with
+ // EventRegister().
+ EventUnregister(e *Entry)
+}
+
+// EntryCallback provides a notify callback.
+type EntryCallback interface {
+ // Callback is the function to be called when the waiter entry is
+ // notified. It is responsible for doing whatever is needed to wake up
+ // the waiter.
+ //
+ // The callback is supposed to perform minimal work, and cannot call
+ // any method on the queue itself because it will be locked while the
+ // callback is running.
+ Callback(e *Entry)
+}
+
+// Entry represents a waiter that can be added to a wait queue. It can
+// only be in one queue at a time, and is added "intrusively" to the queue with
+// no extra memory allocations.
+//
+// +stateify savable
+type Entry struct {
+ Callback EntryCallback
+
+ // The following fields are protected by the queue lock.
+ mask EventMask
+ waiterEntry
+}
+
+type channelCallback struct {
+ ch chan struct{}
+}
+
+// Callback implements EntryCallback.Callback.
+func (c *channelCallback) Callback(*Entry) {
+ select {
+ case c.ch <- struct{}{}:
+ default:
+ }
+}
+
+// NewChannelEntry initializes a new Entry that does a non-blocking write to a
+// struct{} channel when the callback is called. It returns the new Entry
+// instance and the channel being used.
+//
+// If a channel isn't specified (i.e., if "c" is nil), then NewChannelEntry
+// allocates a new channel.
+func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
+ if c == nil {
+ c = make(chan struct{}, 1)
+ }
+
+ return Entry{Callback: &channelCallback{ch: c}}, c
+}
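+
+// For illustration, the typical blocking pattern (obj stands in for any
+// Waitable):
+//
+// e, ch := NewChannelEntry(nil)
+// obj.EventRegister(&e, EventIn)
+// defer obj.EventUnregister(&e)
+// <-ch // unblocked by the next obj.Notify(EventIn)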
+
+// Queue represents the wait queue where waiters can be added and
+// notifiers can notify them when events happen.
+//
+// The zero value for waiter.Queue is an empty queue ready for use.
+//
+// +stateify savable
+type Queue struct {
+ list waiterList `state:"zerovalue"`
+ mu sync.RWMutex `state:"nosave"`
+}
+
+// EventRegister adds a waiter to the wait queue; the waiter will be notified
+// when at least one of the events specified in mask happens.
+func (q *Queue) EventRegister(e *Entry, mask EventMask) {
+ q.mu.Lock()
+ e.mask = mask
+ q.list.PushBack(e)
+ q.mu.Unlock()
+}
+
+// EventUnregister removes the given waiter entry from the wait queue.
+func (q *Queue) EventUnregister(e *Entry) {
+ q.mu.Lock()
+ q.list.Remove(e)
+ q.mu.Unlock()
+}
+
+// Notify notifies all waiters in the queue whose masks have at least one bit
+// in common with the notification mask.
+func (q *Queue) Notify(mask EventMask) {
+ q.mu.RLock()
+ for e := q.list.Front(); e != nil; e = e.Next() {
+ if mask&e.mask != 0 {
+ e.Callback.Callback(e)
+ }
+ }
+ q.mu.RUnlock()
+}
+
+// Events returns the set of events being waited on. It is the union of the
+// masks of all registered entries.
+func (q *Queue) Events() EventMask {
+ ret := EventMask(0)
+
+ q.mu.RLock()
+ for e := q.list.Front(); e != nil; e = e.Next() {
+ ret |= e.mask
+ }
+ q.mu.RUnlock()
+
+ return ret
+}
+
+// IsEmpty returns whether the wait queue is empty.
+func (q *Queue) IsEmpty() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.list.Front() == nil
+}
+
+// AlwaysReady implements the Waitable interface but is always ready. Embedding
+// this struct into another struct makes it implement the boilerplate empty
+// functions automatically.
+type AlwaysReady struct {
+}
+
+// Readiness always returns the input mask because this object is always ready.
+func (*AlwaysReady) Readiness(mask EventMask) EventMask {
+ return mask
+}
+
+// EventRegister is a no-op: this object's readiness never changes, so it
+// never needs to issue notifications.
+func (*AlwaysReady) EventRegister(*Entry, EventMask) {
+}
+
+// EventUnregister is a no-op: this object's readiness never changes, so it
+// never needs to issue notifications.
+func (*AlwaysReady) EventUnregister(e *Entry) {
+}
diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go
new file mode 100644
index 000000000..c1b94a4f3
--- /dev/null
+++ b/pkg/waiter/waiter_test.go
@@ -0,0 +1,192 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package waiter
+
+import (
+ "sync/atomic"
+ "testing"
+)
+
+type callbackStub struct {
+ f func(e *Entry)
+}
+
+// Callback implements EntryCallback.Callback.
+func (c *callbackStub) Callback(e *Entry) {
+ c.f(e)
+}
+
+func TestEmptyQueue(t *testing.T) {
+ var q Queue
+
+ // Notify the zero-value of a queue.
+ q.Notify(EventIn)
+
+ // Register then unregister a waiter, then notify the queue.
+ cnt := 0
+ e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ q.EventRegister(&e, EventIn)
+ q.EventUnregister(&e)
+ q.Notify(EventIn)
+ if cnt != 0 {
+ t.Errorf("Callback was called when it shouldn't have been")
+ }
+}
+
+func TestMask(t *testing.T) {
+ // Register a waiter.
+ var q Queue
+ var cnt int
+ e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ q.EventRegister(&e, EventIn|EventErr)
+
+ // Notify with an overlapping mask.
+ cnt = 0
+ q.Notify(EventIn | EventOut)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with a subset mask.
+ cnt = 0
+ q.Notify(EventIn)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with a superset mask.
+ cnt = 0
+ q.Notify(EventIn | EventErr | EventOut)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with the exact same mask.
+ cnt = 0
+ q.Notify(EventIn | EventErr)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with a disjoint mask.
+ cnt = 0
+ q.Notify(EventOut | EventHUp)
+ if cnt != 0 {
+ t.Errorf("Callback was called when it shouldn't have been")
+ }
+}
+
+func TestConcurrentRegistration(t *testing.T) {
+ var q Queue
+ var cnt int
+ const concurrency = 1000
+
+ ch1 := make(chan struct{})
+ ch2 := make(chan struct{})
+ ch3 := make(chan struct{})
+
+ // Create goroutines that will all register/unregister concurrently.
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ var e Entry
+ e.Callback = &callbackStub{func(entry *Entry) {
+ cnt++
+ if entry != &e {
+ t.Errorf("entry = %p, want %p", entry, &e)
+ }
+ }}
+
+ // Wait for notification, then register.
+ <-ch1
+ q.EventRegister(&e, EventIn|EventErr)
+
+ // Tell main goroutine that we're done registering.
+ ch2 <- struct{}{}
+
+ // Wait for notification, then unregister.
+ <-ch3
+ q.EventUnregister(&e)
+
+ // Tell main goroutine that we're done unregistering.
+ ch2 <- struct{}{}
+ }()
+ }
+
+ // Let the goroutines register.
+ close(ch1)
+ for i := 0; i < concurrency; i++ {
+ <-ch2
+ }
+
+ // Issue a notification.
+ q.Notify(EventIn)
+ if cnt != concurrency {
+ t.Errorf("cnt = %d, want %d", cnt, concurrency)
+ }
+
+ // Let the goroutines unregister.
+ close(ch3)
+ for i := 0; i < concurrency; i++ {
+ <-ch2
+ }
+
+ // Issue a notification.
+ q.Notify(EventIn)
+ if cnt != concurrency {
+ t.Errorf("cnt = %d, want %d", cnt, concurrency)
+ }
+}
+
+func TestConcurrentNotification(t *testing.T) {
+ var q Queue
+ var cnt int32
+ const concurrency = 1000
+ const waiterCount = 1000
+
+ // Register waiters.
+ for i := 0; i < waiterCount; i++ {
+ var e Entry
+ e.Callback = &callbackStub{func(entry *Entry) {
+ atomic.AddInt32(&cnt, 1)
+ if entry != &e {
+ t.Errorf("entry = %p, want %p", entry, &e)
+ }
+ }}
+
+ q.EventRegister(&e, EventIn|EventErr)
+ }
+
+ // Launch notifiers.
+ ch1 := make(chan struct{})
+ ch2 := make(chan struct{})
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ <-ch1
+ q.Notify(EventIn)
+ ch2 <- struct{}{}
+ }()
+ }
+
+ // Let notifiers go.
+ close(ch1)
+ for i := 0; i < concurrency; i++ {
+ <-ch2
+ }
+
+ // Check the count.
+ if cnt != concurrency*waiterCount {
+ t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount)
+ }
+}
diff --git a/runsc/BUILD b/runsc/BUILD
new file mode 100644
index 000000000..757f6d44c
--- /dev/null
+++ b/runsc/BUILD
@@ -0,0 +1,123 @@
+load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runsc",
+ srcs = [
+ "main.go",
+ "version.go",
+ ],
+ pure = True,
+ visibility = [
+ "//visibility:public",
+ ],
+ x_defs = {"main.version": "{STABLE_VERSION}"},
+ deps = [
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/platform",
+ "//runsc/boot",
+ "//runsc/cmd",
+ "//runsc/flag",
+ "//runsc/specutils",
+ "@com_github_google_subcommands//:go_default_library",
+ ],
+)
+
+# The runsc-race target is a race-compatible BUILD target. This must be built
+# via: bazel build --features=race :runsc-race
+#
+# This is necessary because the race feature must apply to all dependencies
+# due to a bug in gazelle file selection. The pure attribute must be off because
+# the race detector requires linking with non-Go components, although we still
+# require a static binary.
+#
+# Note that in the future this might be convertible to a compatible target by
+# using the pure and static attributes within a select function, but select is
+# not currently compatible with string attributes [1].
+#
+# [1] https://github.com/bazelbuild/bazel/issues/1698
+go_binary(
+ name = "runsc-race",
+ srcs = [
+ "main.go",
+ "version.go",
+ ],
+ static = True,
+ visibility = [
+ "//visibility:public",
+ ],
+ x_defs = {"main.version": "{STABLE_VERSION}"},
+ deps = [
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/platform",
+ "//runsc/boot",
+ "//runsc/cmd",
+ "//runsc/flag",
+ "//runsc/specutils",
+ "@com_github_google_subcommands//:go_default_library",
+ ],
+)
+
+pkg_tar(
+ name = "runsc-bin",
+ srcs = [":runsc"],
+ mode = "0755",
+ package_dir = "/usr/bin",
+ strip_prefix = "/runsc/linux_amd64_pure_stripped",
+)
+
+pkg_tar(
+ name = "debian-data",
+ extension = "tar.gz",
+ deps = [
+ ":runsc-bin",
+ ],
+)
+
+genrule(
+ name = "deb-version",
+ # Note that runsc must appear in the srcs parameter and not the tools
+ # parameter, otherwise it will not be stamped. This is reasonable, as tools
+ # may be encoded differently in the build graph (cached more aggressively
+ # because they are assumed to be hermetic).
+ srcs = [":runsc"],
+ outs = ["version.txt"],
+ # Note that the little dance here is necessary because files in the $(SRCS)
+ # attribute are not executable by default, and we can't touch in place.
+ cmd = "cp $(location :runsc) $(@D)/runsc && \
+ chmod a+x $(@D)/runsc && \
+ $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \
+ rm -f $(@D)/runsc",
+ stamp = 1,
+)
+
+pkg_deb(
+ name = "runsc-debian",
+ architecture = "amd64",
+ data = ":debian-data",
+ # Note that the description_file will be flattened (all newlines removed),
+ # and therefore it is kept to a simple one-line description. The expected
+ # format for debian packages is "short summary\nLonger explanation of
+ # tool." and this is impossible with the flattening.
+ description_file = "debian/description",
+ homepage = "https://gvisor.dev/",
+ maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>",
+ package = "runsc",
+ postinst = "debian/postinst.sh",
+ version_file = ":version.txt",
+ visibility = [
+ "//visibility:public",
+ ],
+)
+
+sh_test(
+ name = "version_test",
+ size = "small",
+ srcs = ["version_test.sh"],
+ args = ["$(location :runsc)"],
+ data = [":runsc"],
+ tags = ["noguitar"],
+)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
new file mode 100644
index 000000000..aad2a41de
--- /dev/null
+++ b/runsc/boot/BUILD
@@ -0,0 +1,137 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "boot",
+ srcs = [
+ "compat.go",
+ "compat_amd64.go",
+ "compat_arm64.go",
+ "config.go",
+ "controller.go",
+ "debug.go",
+ "events.go",
+ "fs.go",
+ "limits.go",
+ "loader.go",
+ "network.go",
+ "strace.go",
+ "vfs.go",
+ ],
+ visibility = [
+ "//pkg/test:__subpackages__",
+ "//runsc:__subpackages__",
+ "//test:__subpackages__",
+ ],
+ deps = [
+ "//pkg/abi",
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/control/server",
+ "//pkg/cpuid",
+ "//pkg/eventchannel",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/memutil",
+ "//pkg/rand",
+ "//pkg/refs",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/arch:registers_go_proto",
+ "//pkg/sentry/control",
+ "//pkg/sentry/devices/memdev",
+ "//pkg/sentry/devices/ttydev",
+ "//pkg/sentry/devices/tundev",
+ "//pkg/sentry/fdimport",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/dev",
+ "//pkg/sentry/fs/gofer",
+ "//pkg/sentry/fs/host",
+ "//pkg/sentry/fs/proc",
+ "//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fs/sys",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/fs/tty",
+ "//pkg/sentry/fs/user",
+ "//pkg/sentry/fsimpl/devpts",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/fuse",
+ "//pkg/sentry/fsimpl/gofer",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/fsimpl/overlay",
+ "//pkg/sentry/fsimpl/proc",
+ "//pkg/sentry/fsimpl/sys",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel:uncaught_signal_go_proto",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/sighandling",
+ "//pkg/sentry/socket/hostinet",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/sentry/socket/netlink/route",
+ "//pkg/sentry/socket/netlink/uevent",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/state",
+ "//pkg/sentry/strace",
+ "//pkg/sentry/syscalls/linux/vfs2",
+ "//pkg/sentry/time",
+ "//pkg/sentry/unimpl:unimplemented_syscall_go_proto",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sentry/watchdog",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/qdisc/fifo",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/urpc",
+ "//runsc/boot/filter",
+ "//runsc/boot/platforms",
+ "//runsc/boot/pprof",
+ "//runsc/specutils",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "boot_test",
+ size = "small",
+ srcs = [
+ "compat_test.go",
+ "fs_test.go",
+ "loader_test.go",
+ ],
+ library = ":boot",
+ deps = [
+ "//pkg/control/server",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/unet",
+ "//runsc/fsgofer",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
new file mode 100644
index 000000000..84c67cbc2
--- /dev/null
+++ b/runsc/boot/compat.go
@@ -0,0 +1,202 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "os"
+ "syscall"
+
+ "github.com/golang/protobuf/proto"
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/log"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/strace"
+ spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func initCompatLogs(fd int) error {
+ ce, err := newCompatEmitter(fd)
+ if err != nil {
+ return err
+ }
+ eventchannel.AddEmitter(ce)
+ return nil
+}
+
+type compatEmitter struct {
+ sink *log.BasicLogger
+ nameMap strace.SyscallMap
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // trackers maps syscall numbers to their respective tracker instances.
+ // Protected by 'mu'.
+ trackers map[uint64]syscallTracker
+}
+
+func newCompatEmitter(logFD int) (*compatEmitter, error) {
+ nameMap, ok := getSyscallNameMap()
+ if !ok {
+ return nil, fmt.Errorf("Linux syscall table not found")
+ }
+
+ c := &compatEmitter{
+ // Always log to the default logger.
+ sink: log.Log(),
+ nameMap: nameMap,
+ trackers: make(map[uint64]syscallTracker),
+ }
+
+ if logFD > 0 {
+ f := os.NewFile(uintptr(logFD), "user log file")
+ target := &log.MultiEmitter{c.sink, log.K8sJSONEmitter{&log.Writer{Next: f}}}
+ c.sink = &log.BasicLogger{Level: log.Info, Emitter: target}
+ }
+ return c, nil
+}
+
+// Emit implements eventchannel.Emitter.
+func (c *compatEmitter) Emit(msg proto.Message) (bool, error) {
+ switch m := msg.(type) {
+ case *spb.UnimplementedSyscall:
+ c.emitUnimplementedSyscall(m)
+ case *ucspb.UncaughtSignal:
+ c.emitUncaughtSignal(m)
+ }
+
+ return false, nil
+}
+
+func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) {
+ regs := us.Registers
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ sysnr := syscallNum(regs)
+ tr := c.trackers[sysnr]
+ if tr == nil {
+ switch sysnr {
+ case syscall.SYS_PRCTL:
+ // args: cmd, ...
+ tr = newArgsTracker(0)
+
+ case syscall.SYS_IOCTL, syscall.SYS_EPOLL_CTL, syscall.SYS_SHMCTL, syscall.SYS_FUTEX, syscall.SYS_FALLOCATE:
+ // args: fd/addr, cmd, ...
+ tr = newArgsTracker(1)
+
+ case syscall.SYS_GETSOCKOPT, syscall.SYS_SETSOCKOPT:
+ // args: fd, level, name, ...
+ tr = newArgsTracker(1, 2)
+
+ case syscall.SYS_SEMCTL:
+ // args: semid, semnum, cmd, ...
+ tr = newArgsTracker(2)
+
+ default:
+ tr = newArchArgsTracker(sysnr)
+ if tr == nil {
+ tr = &onceTracker{}
+ }
+ }
+ c.trackers[sysnr] = tr
+ }
+
+ if tr.shouldReport(regs) {
+ name := c.nameMap.Name(uintptr(sysnr))
+ c.sink.Infof("Unsupported syscall %s(%#x,%#x,%#x,%#x,%#x,%#x). It is "+
+ "likely that you can safely ignore this message and that this is not "+
+ "the cause of any error. Please, refer to %s/%s for more information.",
+ name, argVal(0, regs), argVal(1, regs), argVal(2, regs), argVal(3, regs),
+ argVal(4, regs), argVal(5, regs), syscallLink, name)
+
+ tr.onReported(regs)
+ }
+}
+
+func (c *compatEmitter) emitUncaughtSignal(msg *ucspb.UncaughtSignal) {
+ sig := syscall.Signal(msg.SignalNumber)
+ c.sink.Infof(
+ "Uncaught signal: %q (%d), PID: %d, TID: %d, fault addr: %#x",
+ sig, msg.SignalNumber, msg.Pid, msg.Tid, msg.FaultAddr)
+}
+
+// Close implements eventchannel.Emitter.
+func (c *compatEmitter) Close() error {
+ c.sink = nil
+ return nil
+}
+
+// syscallTracker interface allows filters to apply differently depending on
+// the syscall and arguments.
+type syscallTracker interface {
+ // shouldReport returns true if the syscall should be reported.
+ shouldReport(regs *rpb.Registers) bool
+
+ // onReported marks the syscall as reported.
+ onReported(regs *rpb.Registers)
+}
+
+// onceTracker reports only a single time, used for most syscalls.
+type onceTracker struct {
+ reported bool
+}
+
+func (o *onceTracker) shouldReport(_ *rpb.Registers) bool {
+ return !o.reported
+}
+
+func (o *onceTracker) onReported(_ *rpb.Registers) {
+ o.reported = true
+}
+
+// argsTracker reports only once for each different combination of arguments.
+// It's used for generic syscalls like ioctl to report once per 'cmd'.
+type argsTracker struct {
+ // argsIdx holds the indices of the syscall arguments used as the unique ID.
+ argsIdx []int
+ reported map[string]struct{}
+ count int
+}
+
+func newArgsTracker(argIdx ...int) *argsTracker {
+ return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})}
+}
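+
+// For example, the ioctl case above creates newArgsTracker(1): reports are
+// keyed on the second argument (the request code), so each distinct ioctl
+// request is logged only once, no matter how many times or on how many file
+// descriptors it is issued.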
+
+// key returns a string key built from the tracked syscall arguments.
+func (a *argsTracker) key(regs *rpb.Registers) string {
+ var rv string
+ for _, idx := range a.argsIdx {
+ rv += fmt.Sprintf("%d|", argVal(idx, regs))
+ }
+ return rv
+}
+
+func (a *argsTracker) shouldReport(regs *rpb.Registers) bool {
+ if a.count >= reportLimit {
+ return false
+ }
+ _, ok := a.reported[a.key(regs)]
+ return !ok
+}
+
+func (a *argsTracker) onReported(regs *rpb.Registers) {
+ a.count++
+ a.reported[a.key(regs)] = struct{}{}
+}
diff --git a/runsc/boot/compat_amd64.go b/runsc/boot/compat_amd64.go
new file mode 100644
index 000000000..8eb76b2ba
--- /dev/null
+++ b/runsc/boot/compat_amd64.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/strace"
+)
+
+const (
+ // reportLimit is the max number of events that should be reported per
+ // tracker.
+ reportLimit = 100
+ syscallLink = "https://gvisor.dev/c/linux/amd64"
+)
+
+// newRegs creates an empty Registers instance.
+func newRegs() *rpb.Registers {
+ return &rpb.Registers{
+ Arch: &rpb.Registers_Amd64{
+ Amd64: &rpb.AMD64Registers{},
+ },
+ }
+}
+
+func argVal(argIdx int, regs *rpb.Registers) uint64 {
+ amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64
+
+ switch argIdx {
+ case 0:
+ return amd64Regs.Rdi
+ case 1:
+ return amd64Regs.Rsi
+ case 2:
+ return amd64Regs.Rdx
+ case 3:
+ return amd64Regs.R10
+ case 4:
+ return amd64Regs.R8
+ case 5:
+ return amd64Regs.R9
+ }
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
+}
+
+func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) {
+ amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64
+
+ switch argIdx {
+ case 0:
+ amd64Regs.Rdi = argVal
+ case 1:
+ amd64Regs.Rsi = argVal
+ case 2:
+ amd64Regs.Rdx = argVal
+ case 3:
+ amd64Regs.R10 = argVal
+ case 4:
+ amd64Regs.R8 = argVal
+ case 5:
+ amd64Regs.R9 = argVal
+ default:
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
+ }
+}
+
+func getSyscallNameMap() (strace.SyscallMap, bool) {
+ return strace.Lookup(abi.Linux, arch.AMD64)
+}
+
+func syscallNum(regs *rpb.Registers) uint64 {
+ amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64
+ return amd64Regs.OrigRax
+}
+
+func newArchArgsTracker(sysnr uint64) syscallTracker {
+ switch sysnr {
+ case syscall.SYS_ARCH_PRCTL:
+ // args: cmd, ...
+ return newArgsTracker(0)
+ }
+ return nil
+}
diff --git a/runsc/boot/compat_arm64.go b/runsc/boot/compat_arm64.go
new file mode 100644
index 000000000..bce9d95b3
--- /dev/null
+++ b/runsc/boot/compat_arm64.go
@@ -0,0 +1,95 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/strace"
+)
+
+const (
+ // reportLimit is the max number of events that should be reported per
+ // tracker.
+ reportLimit = 100
+ syscallLink = "https://gvisor.dev/c/linux/arm64"
+)
+
+// newRegs creates an empty Registers instance.
+func newRegs() *rpb.Registers {
+ return &rpb.Registers{
+ Arch: &rpb.Registers_Arm64{
+ Arm64: &rpb.ARM64Registers{},
+ },
+ }
+}
+
+func argVal(argIdx int, regs *rpb.Registers) uint64 {
+ arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64
+
+ switch argIdx {
+ case 0:
+ return arm64Regs.R0
+ case 1:
+ return arm64Regs.R1
+ case 2:
+ return arm64Regs.R2
+ case 3:
+ return arm64Regs.R3
+ case 4:
+ return arm64Regs.R4
+ case 5:
+ return arm64Regs.R5
+ }
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
+}
+
+func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) {
+ arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64
+
+ switch argIdx {
+ case 0:
+ arm64Regs.R0 = argVal
+ case 1:
+ arm64Regs.R1 = argVal
+ case 2:
+ arm64Regs.R2 = argVal
+ case 3:
+ arm64Regs.R3 = argVal
+ case 4:
+ arm64Regs.R4 = argVal
+ case 5:
+ arm64Regs.R5 = argVal
+ default:
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
+ }
+}
+
+func getSyscallNameMap() (strace.SyscallMap, bool) {
+ return strace.Lookup(abi.Linux, arch.ARM64)
+}
+
+func syscallNum(regs *rpb.Registers) uint64 {
+ arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64
+ return arm64Regs.R8
+}
+
+func newArchArgsTracker(sysnr uint64) syscallTracker {
+ // Currently, no arch-specific syscalls need to be handled here.
+ return nil
+}
diff --git a/runsc/boot/compat_test.go b/runsc/boot/compat_test.go
new file mode 100644
index 000000000..839c5303b
--- /dev/null
+++ b/runsc/boot/compat_test.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "testing"
+)
+
+func TestOnceTracker(t *testing.T) {
+ o := onceTracker{}
+ if !o.shouldReport(nil) {
+ t.Error("first call to checkAndMark, got: false, want: true")
+ }
+ o.onReported(nil)
+ for i := 0; i < 2; i++ {
+ if o.shouldReport(nil) {
+ t.Error("after first call to checkAndMark, got: true, want: false")
+ }
+ }
+}
+
+func TestArgsTracker(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ idx []int
+ arg1_1 uint64
+ arg1_2 uint64
+ arg2_1 uint64
+ arg2_2 uint64
+ want bool
+ }{
+ {name: "same arg1", idx: []int{0}, arg1_1: 123, arg1_2: 123, want: false},
+ {name: "same arg2", idx: []int{1}, arg2_1: 123, arg2_2: 123, want: false},
+ {name: "diff arg1", idx: []int{0}, arg1_1: 123, arg1_2: 321, want: true},
+ {name: "diff arg2", idx: []int{1}, arg2_1: 123, arg2_2: 321, want: true},
+ {name: "cmd is uint32", idx: []int{0}, arg2_1: 0xdead00000123, arg2_2: 0xbeef00000123, want: false},
+ {name: "same 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 123, arg1_2: 321, want: false},
+ {name: "diff 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 789, arg1_2: 987, want: true},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := newArgsTracker(tc.idx...)
+ regs := newRegs()
+ setArgVal(0, tc.arg1_1, regs)
+ setArgVal(1, tc.arg2_1, regs)
+ if !c.shouldReport(regs) {
+ t.Error("first call to shouldReport, got: false, want: true")
+ }
+ c.onReported(regs)
+
+ setArgVal(0, tc.arg1_2, regs)
+ setArgVal(1, tc.arg2_2, regs)
+ if got := c.shouldReport(regs); tc.want != got {
+ t.Errorf("second call to shouldReport, got: %t, want: %t", got, tc.want)
+ }
+ })
+ }
+}
+
+func TestArgsTrackerLimit(t *testing.T) {
+ c := newArgsTracker(0, 1)
+ for i := 0; i < reportLimit; i++ {
+ regs := newRegs()
+ setArgVal(0, 123, regs)
+ setArgVal(1, uint64(i), regs)
+ if !c.shouldReport(regs) {
+ t.Error("shouldReport before limit was reached, got: false, want: true")
+ }
+ c.onReported(regs)
+ }
+
+ // Should hit the count limit now.
+ regs := newRegs()
+ setArgVal(0, 123, regs)
+ setArgVal(1, 123456, regs)
+ if c.shouldReport(regs) {
+ t.Error("shouldReport after limit was reached, got: true, want: false")
+ }
+}
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
new file mode 100644
index 000000000..bb01b8fb5
--- /dev/null
+++ b/runsc/boot/config.go
@@ -0,0 +1,329 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
+)
+
+// FileAccessType tells how the filesystem is accessed.
+type FileAccessType int
+
+const (
+ // FileAccessShared sends IO requests to a Gofer process that validates the
+ // requests and forwards them to the host.
+ FileAccessShared FileAccessType = iota
+
+ // FileAccessExclusive is the same as FileAccessShared, but enables
+ // extra caching for improved performance. It should only be used if
+ // the sandbox has exclusive access to the filesystem.
+ FileAccessExclusive
+)
+
+// MakeFileAccessType converts type from string.
+func MakeFileAccessType(s string) (FileAccessType, error) {
+ switch s {
+ case "shared":
+ return FileAccessShared, nil
+ case "exclusive":
+ return FileAccessExclusive, nil
+ default:
+ return 0, fmt.Errorf("invalid file access type %q", s)
+ }
+}
+
+func (f FileAccessType) String() string {
+ switch f {
+ case FileAccessShared:
+ return "shared"
+ case FileAccessExclusive:
+ return "exclusive"
+ default:
+ return fmt.Sprintf("unknown(%d)", f)
+ }
+}
+
+// NetworkType tells which network stack to use.
+type NetworkType int
+
+const (
+ // NetworkSandbox uses internal network stack, isolated from the host.
+ NetworkSandbox NetworkType = iota
+
+ // NetworkHost redirects network related syscalls to the host network.
+ NetworkHost
+
+ // NetworkNone sets up just loopback using netstack.
+ NetworkNone
+)
+
+// MakeNetworkType converts type from string.
+func MakeNetworkType(s string) (NetworkType, error) {
+ switch s {
+ case "sandbox":
+ return NetworkSandbox, nil
+ case "host":
+ return NetworkHost, nil
+ case "none":
+ return NetworkNone, nil
+ default:
+ return 0, fmt.Errorf("invalid network type %q", s)
+ }
+}
+
+func (n NetworkType) String() string {
+ switch n {
+ case NetworkSandbox:
+ return "sandbox"
+ case NetworkHost:
+ return "host"
+ case NetworkNone:
+ return "none"
+ default:
+ return fmt.Sprintf("unknown(%d)", n)
+ }
+}
+
+// MakeWatchdogAction converts type from string.
+func MakeWatchdogAction(s string) (watchdog.Action, error) {
+ switch strings.ToLower(s) {
+ case "log", "logwarning":
+ return watchdog.LogWarning, nil
+ case "panic":
+ return watchdog.Panic, nil
+ default:
+ return 0, fmt.Errorf("invalid watchdog action %q", s)
+ }
+}
+
+// MakeRefsLeakMode converts type from string.
+func MakeRefsLeakMode(s string) (refs.LeakMode, error) {
+ switch strings.ToLower(s) {
+ case "disabled":
+ return refs.NoLeakChecking, nil
+ case "log-names":
+ return refs.LeaksLogWarning, nil
+ case "log-traces":
+ return refs.LeaksLogTraces, nil
+ default:
+ return 0, fmt.Errorf("invalid refs leakmode %q", s)
+ }
+}
+
+func refsLeakModeToString(mode refs.LeakMode) string {
+ switch mode {
+ // If not set, default it to disabled.
+ case refs.UninitializedLeakChecking, refs.NoLeakChecking:
+ return "disabled"
+ case refs.LeaksLogWarning:
+ return "log-names"
+ case refs.LeaksLogTraces:
+ return "log-traces"
+ default:
+ panic(fmt.Sprintf("Invalid leakmode: %d", mode))
+ }
+}
+
+// Config holds configuration that is not part of the runtime spec.
+type Config struct {
+ // RootDir is the runtime root directory.
+ RootDir string
+
+ // Debug indicates that debug logging should be enabled.
+ Debug bool
+
+ // LogFilename is the filename to log to, if not empty.
+ LogFilename string
+
+ // LogFormat is the log format.
+ LogFormat string
+
+ // DebugLog is the path to log debug information to, if not empty.
+ DebugLog string
+
+ // PanicLog is the path to log GO's runtime messages, if not empty.
+ PanicLog string
+
+ // DebugLogFormat is the log format for debug.
+ DebugLogFormat string
+
+ // FileAccess indicates how the filesystem is accessed.
+ FileAccess FileAccessType
+
+ // Overlay is whether to wrap the root filesystem in an overlay.
+ Overlay bool
+
+ // FSGoferHostUDS enables the gofer to mount a host UDS.
+ FSGoferHostUDS bool
+
+ // Network indicates what type of network to use.
+ Network NetworkType
+
+ // EnableRaw indicates whether raw sockets should be enabled. Raw
+ // sockets are disabled by stripping CAP_NET_RAW from the list of
+ // capabilities.
+ EnableRaw bool
+
+ // HardwareGSO indicates that hardware segmentation offload is enabled.
+ HardwareGSO bool
+
+ // SoftwareGSO indicates that software segmentation offload is enabled.
+ SoftwareGSO bool
+
+ // TXChecksumOffload indicates that TX Checksum Offload is enabled.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload indicates that RX Checksum Offload is enabled.
+ RXChecksumOffload bool
+
+ // QDisc indicates the type of queueing discipline to use by default
+ // for non-loopback interfaces.
+ QDisc QueueingDiscipline
+
+ // LogPackets indicates that all network packets should be logged.
+ LogPackets bool
+
+ // Platform is the platform to run on.
+ Platform string
+
+ // Strace indicates that strace should be enabled.
+ Strace bool
+
+ // StraceSyscalls is the set of syscalls to trace. If Strace is
+ // true and this list is empty, then all syscalls will be traced.
+ StraceSyscalls []string
+
+ // StraceLogSize is the max size of data blobs to display.
+ StraceLogSize uint
+
+ // DisableSeccomp indicates whether seccomp syscall filters should be
+ // disabled. Pardon the double negation, but defaulting to enabled is important.
+ DisableSeccomp bool
+
+ // WatchdogAction sets what action the watchdog takes when triggered.
+ WatchdogAction watchdog.Action
+
+ // PanicSignal registers signal handling that panics. Usually set to
+ // SIGUSR2(12) to troubleshoot hangs. -1 disables it.
+ PanicSignal int
+
+ // ProfileEnable is set to prepare the sandbox to be profiled.
+ ProfileEnable bool
+
+ // RestoreFile is the path to the saved container image.
+ RestoreFile string
+
+ // NumNetworkChannels controls the number of AF_PACKET sockets that map
+ // to the same underlying network device. This allows netstack to better
+ // scale for high throughput use cases.
+ NumNetworkChannels int
+
+ // Rootless allows the sandbox to be started with a user that is not root.
+ // Defense-in-depth measures are weaker with rootless. Specifically, the
+ // sandbox and Gofer process run as root inside a user namespace with root
+ // mapped to the caller's user.
+ Rootless bool
+
+ // AlsoLogToStderr allows sending log messages to stderr.
+ AlsoLogToStderr bool
+
+ // ReferenceLeakMode sets the reference leak check mode.
+ ReferenceLeakMode refs.LeakMode
+
+ // OverlayfsStaleRead instructs the sandbox to assume that the root mount
+ // is on a Linux overlayfs mount, which does not necessarily preserve
+ // coherence between read-only and subsequent writable file descriptors
+ // representing the "same" file.
+ OverlayfsStaleRead bool
+
+ // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
+ // tests. It allows runsc to start the sandbox process as the current
+ // user, and without chrooting the sandbox process. This can be
+ // necessary in test environments that have limited capabilities.
+ TestOnlyAllowRunAsCurrentUserWithoutChroot bool
+
+ // TestOnlyTestNameEnv should only be used in tests. It looks up the test
+ // name in the container environment variables and adds it to the debug
+ // log file name. This is done to help identify the log with the test when
+ // multiple tests are run in parallel, since there is no way to pass
+ // parameters to the runtime from Docker.
+ TestOnlyTestNameEnv string
+
+ // CPUNumFromQuota sets the CPU number count to the available CPU quota,
+ // using the least integer value greater than or equal to the quota.
+ //
+ // E.g. a 0.2 CPU quota will result in 1, and 1.9 in 2.
+ CPUNumFromQuota bool
+
+ // Enables VFS2 (not plumbed through yet).
+ VFS2 bool
+}
+
+// ToFlags returns a slice of flags that correspond to the given Config.
+func (c *Config) ToFlags() []string {
+ f := []string{
+ "--root=" + c.RootDir,
+ "--debug=" + strconv.FormatBool(c.Debug),
+ "--log=" + c.LogFilename,
+ "--log-format=" + c.LogFormat,
+ "--debug-log=" + c.DebugLog,
+ "--panic-log=" + c.PanicLog,
+ "--debug-log-format=" + c.DebugLogFormat,
+ "--file-access=" + c.FileAccess.String(),
+ "--overlay=" + strconv.FormatBool(c.Overlay),
+ "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS),
+ "--network=" + c.Network.String(),
+ "--log-packets=" + strconv.FormatBool(c.LogPackets),
+ "--platform=" + c.Platform,
+ "--strace=" + strconv.FormatBool(c.Strace),
+ "--strace-syscalls=" + strings.Join(c.StraceSyscalls, ","),
+ "--strace-log-size=" + strconv.Itoa(int(c.StraceLogSize)),
+ "--watchdog-action=" + c.WatchdogAction.String(),
+ "--panic-signal=" + strconv.Itoa(c.PanicSignal),
+ "--profile=" + strconv.FormatBool(c.ProfileEnable),
+ "--net-raw=" + strconv.FormatBool(c.EnableRaw),
+ "--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels),
+ "--rootless=" + strconv.FormatBool(c.Rootless),
+ "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr),
+ "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode),
+ "--gso=" + strconv.FormatBool(c.HardwareGSO),
+ "--software-gso=" + strconv.FormatBool(c.SoftwareGSO),
+ "--rx-checksum-offload=" + strconv.FormatBool(c.RXChecksumOffload),
+ "--tx-checksum-offload=" + strconv.FormatBool(c.TXChecksumOffload),
+ "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead),
+ "--qdisc=" + c.QDisc.String(),
+ }
+ if c.CPUNumFromQuota {
+ f = append(f, "--cpu-num-from-quota")
+ }
+ // Only include these if set, since they are not meant to be used by users.
+ if c.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ f = append(f, "--TESTONLY-unsafe-nonroot=true")
+ }
+ if len(c.TestOnlyTestNameEnv) != 0 {
+ f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
+ }
+
+ if c.VFS2 {
+ f = append(f, "--vfs2=true")
+ }
+
+ return f
+}
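+
+// For illustration, a Config{Network: NetworkSandbox, FileAccess:
+// FileAccessExclusive} contributes "--network=sandbox" and
+// "--file-access=exclusive" to the returned flags (example values only; the
+// remaining fields are serialized with their current values as well).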
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
new file mode 100644
index 000000000..8125d5061
--- /dev/null
+++ b/runsc/boot/controller.go
@@ -0,0 +1,506 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/state"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot/pprof"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+const (
+ // ContainerCheckpoint checkpoints a container.
+ ContainerCheckpoint = "containerManager.Checkpoint"
+
+ // ContainerCreate creates a container.
+ ContainerCreate = "containerManager.Create"
+
+ // ContainerDestroy is used to stop a non-root container and free all
+ // associated resources in the sandbox.
+ ContainerDestroy = "containerManager.Destroy"
+
+ // ContainerEvent is the URPC endpoint for getting stats about the
+ // container used by "runsc events".
+ ContainerEvent = "containerManager.Event"
+
+ // ContainerExecuteAsync is the URPC endpoint for executing a command in a
+ // container.
+ ContainerExecuteAsync = "containerManager.ExecuteAsync"
+
+ // ContainerPause pauses the container.
+ ContainerPause = "containerManager.Pause"
+
+ // ContainerProcesses is the URPC endpoint for getting the list of
+ // processes running in a container.
+ ContainerProcesses = "containerManager.Processes"
+
+ // ContainerRestore restores a container from a statefile.
+ ContainerRestore = "containerManager.Restore"
+
+ // ContainerResume unpauses the paused container.
+ ContainerResume = "containerManager.Resume"
+
+ // ContainerSignal is used to send a signal to a container.
+ ContainerSignal = "containerManager.Signal"
+
+ // ContainerSignalProcess is used to send a signal to a particular
+ // process in a container.
+ ContainerSignalProcess = "containerManager.SignalProcess"
+
+ // ContainerStart is the URPC endpoint for running a non-root container
+ // within a sandbox.
+ ContainerStart = "containerManager.Start"
+
+ // ContainerWait is used to wait on the init process of the container
+ // and return its ExitStatus.
+ ContainerWait = "containerManager.Wait"
+
+ // ContainerWaitPID is used to wait on a process with a certain PID in
+ // the sandbox and return its ExitStatus.
+ ContainerWaitPID = "containerManager.WaitPID"
+
+ // NetworkCreateLinksAndRoutes is the URPC endpoint for creating links
+ // and routes in a network stack.
+ NetworkCreateLinksAndRoutes = "Network.CreateLinksAndRoutes"
+
+ // RootContainerStart is the URPC endpoint for starting a new sandbox
+ // with root container.
+ RootContainerStart = "containerManager.StartRoot"
+
+ // SandboxStacks collects sandbox stacks for debugging.
+ SandboxStacks = "debug.Stacks"
+)
+
+// Profiling related commands (see pprof.go for more details).
+const (
+ StartCPUProfile = "Profile.StartCPUProfile"
+ StopCPUProfile = "Profile.StopCPUProfile"
+ HeapProfile = "Profile.HeapProfile"
+ GoroutineProfile = "Profile.GoroutineProfile"
+ BlockProfile = "Profile.BlockProfile"
+ MutexProfile = "Profile.MutexProfile"
+ StartTrace = "Profile.StartTrace"
+ StopTrace = "Profile.StopTrace"
+)
+
+// Logging related commands (see logging.go for more details).
+const (
+ ChangeLogging = "Logging.Change"
+)
+
+// ControlSocketAddr generates an abstract unix socket name for the given ID.
+func ControlSocketAddr(id string) string {
+ return fmt.Sprintf("\x00runsc-sandbox.%s", id)
+}
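+
+// For example, ControlSocketAddr("abc123") returns "\x00runsc-sandbox.abc123"
+// ("abc123" is a made-up ID). The leading NUL byte places the socket in the
+// Linux abstract socket namespace, so no filesystem entry is created.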
+
+// controller holds the control server, and is used for communication into the
+// sandbox.
+type controller struct {
+ // srv is the control server.
+ srv *server.Server
+
+ // manager holds the containerManager methods.
+ manager *containerManager
+}
+
+// newController creates a new controller. The caller must call
+// controller.srv.StartServing() to start the controller.
+func newController(fd int, l *Loader) (*controller, error) {
+ srv, err := server.CreateFromFD(fd)
+ if err != nil {
+ return nil, err
+ }
+
+ manager := &containerManager{
+ startChan: make(chan struct{}),
+ startResultChan: make(chan error),
+ l: l,
+ }
+ srv.Register(manager)
+
+ if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok {
+ net := &Network{
+ Stack: eps.Stack,
+ }
+ srv.Register(net)
+ }
+
+ srv.Register(&debug{})
+ srv.Register(&control.Logging{})
+ if l.conf.ProfileEnable {
+ srv.Register(&control.Profile{
+ Kernel: l.k,
+ })
+ }
+
+ return &controller{
+ srv: srv,
+ manager: manager,
+ }, nil
+}
+
+// containerManager manages sandbox containers.
+type containerManager struct {
+ // startChan is used to signal when the root container process should
+ // be started.
+ startChan chan struct{}
+
+ // startResultChan is used to signal when the root container has
+ // started. Any errors encountered during startup will be sent to the
+ // channel. A nil value indicates success.
+ startResultChan chan error
+
+ // l is the loader that creates containers and sandboxes.
+ l *Loader
+}
+
+// StartRoot will start the root container process.
+func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error {
+ log.Debugf("containerManager.StartRoot %q", *cid)
+ // Tell the root container to start and wait for the result.
+ cm.startChan <- struct{}{}
+ if err := <-cm.startResultChan; err != nil {
+ return fmt.Errorf("starting sandbox: %v", err)
+ }
+ return nil
+}
+
+// Processes retrieves information about processes running in the sandbox.
+func (cm *containerManager) Processes(cid *string, out *[]*control.Process) error {
+ log.Debugf("containerManager.Processes: %q", *cid)
+ return control.Processes(cm.l.k, *cid, out)
+}
+
+// Create creates a container within a sandbox.
+func (cm *containerManager) Create(cid *string, _ *struct{}) error {
+ log.Debugf("containerManager.Create: %q", *cid)
+ return cm.l.createContainer(*cid)
+}
+
+// StartArgs contains arguments to the Start method.
+type StartArgs struct {
+ // Spec is the spec of the container to start.
+ Spec *specs.Spec
+
+ // Config is the runsc-specific configuration for the sandbox.
+ Conf *Config
+
+ // CID is the ID of the container to start.
+ CID string
+
+ // FilePayload contains, in order:
+ // * stdin, stdout, and stderr.
+ // * the file descriptor over which the sandbox will
+ // request files from its root filesystem.
+ urpc.FilePayload
+}
+
+// Start runs a created container within a sandbox.
+func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
+ log.Debugf("containerManager.Start: %+v", args)
+
+ // Validate arguments.
+ if args == nil {
+ return errors.New("start missing arguments")
+ }
+ if args.Spec == nil {
+ return errors.New("start arguments missing spec")
+ }
+ if args.Conf == nil {
+ return errors.New("start arguments missing config")
+ }
+ if args.CID == "" {
+ return errors.New("start argument missing container ID")
+ }
+ if len(args.FilePayload.Files) < 4 {
+ return fmt.Errorf("start arguments must contain stdin, stderr, and stdout followed by at least one file for the container root gofer")
+ }
+
+ // All validation passed; log the spec for debugging.
+ specutils.LogSpec(args.Spec)
+
+ err := cm.l.startContainer(args.Spec, args.Conf, args.CID, args.FilePayload.Files)
+ if err != nil {
+ log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err)
+ return err
+ }
+ log.Debugf("Container %q started", args.CID)
+
+ return nil
+}
+
+// Destroy stops a container if it is still running and cleans up its
+// filesystem.
+func (cm *containerManager) Destroy(cid *string, _ *struct{}) error {
+ log.Debugf("containerManager.destroy %q", *cid)
+ return cm.l.destroyContainer(*cid)
+}
+
+// ExecuteAsync starts running a command on a created or running sandbox. It
+// returns the PID of the new process.
+func (cm *containerManager) ExecuteAsync(args *control.ExecArgs, pid *int32) error {
+ log.Debugf("containerManager.ExecuteAsync: %+v", args)
+ tgid, err := cm.l.executeAsync(args)
+ if err != nil {
+ log.Debugf("containerManager.ExecuteAsync failed: %+v: %v", args, err)
+ return err
+ }
+ *pid = int32(tgid)
+ return nil
+}
+
+// Checkpoint pauses a sandbox and saves its state.
+func (cm *containerManager) Checkpoint(o *control.SaveOpts, _ *struct{}) error {
+ log.Debugf("containerManager.Checkpoint")
+ state := control.State{
+ Kernel: cm.l.k,
+ Watchdog: cm.l.watchdog,
+ }
+ return state.Save(o, nil)
+}
+
+// Pause suspends a container.
+func (cm *containerManager) Pause(_, _ *struct{}) error {
+ log.Debugf("containerManager.Pause")
+ cm.l.k.Pause()
+ return nil
+}
+
+// RestoreOpts contains options related to restoring a container's file system.
+type RestoreOpts struct {
+ // FilePayload contains the state file to be restored, followed by the
+ // platform device file if necessary.
+ urpc.FilePayload
+
+ // SandboxID contains the ID of the sandbox.
+ SandboxID string
+}
+
+// Restore loads a container from a statefile.
+// The container's current kernel is destroyed, a restore environment is
+// created, and the kernel is recreated with the restore state file. The
+// container then sends the signal to start.
+func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
+ log.Debugf("containerManager.Restore")
+
+ var specFile, deviceFile *os.File
+ switch numFiles := len(o.FilePayload.Files); numFiles {
+ case 2:
+ // The device file is donated to the platform.
+ // Can't take ownership away from an os.File, so dup it to get a new FD.
+ fd, err := syscall.Dup(int(o.FilePayload.Files[1].Fd()))
+ if err != nil {
+ return fmt.Errorf("failed to dup file: %v", err)
+ }
+ deviceFile = os.NewFile(uintptr(fd), "platform device")
+ fallthrough
+ case 1:
+ specFile = o.FilePayload.Files[0]
+ case 0:
+ return fmt.Errorf("at least one file must be passed to Restore")
+ default:
+ return fmt.Errorf("at most two files may be passed to Restore")
+ }
+
+ // Pause the kernel while we build a new one.
+ cm.l.k.Pause()
+
+ p, err := createPlatform(cm.l.conf, deviceFile)
+ if err != nil {
+ return fmt.Errorf("creating platform: %v", err)
+ }
+ k := &kernel.Kernel{
+ Platform: p,
+ }
+ mf, err := createMemoryFile()
+ if err != nil {
+ return fmt.Errorf("creating memory file: %v", err)
+ }
+ k.SetMemoryFile(mf)
+ networkStack := cm.l.k.RootNetworkNamespace().Stack()
+ cm.l.k = k
+
+ // Set up the restore environment.
+ mntr := newContainerMounter(cm.l.spec, cm.l.goferFDs, cm.l.k, cm.l.mountHints)
+ renv, err := mntr.createRestoreEnvironment(cm.l.conf)
+ if err != nil {
+ return fmt.Errorf("creating RestoreEnvironment: %v", err)
+ }
+ fs.SetRestoreEnvironment(*renv)
+
+ // Prepare to load from the state file.
+ if eps, ok := networkStack.(*netstack.Stack); ok {
+ stack.StackFromEnv = eps.Stack // FIXME(b/36201077)
+ }
+ info, err := specFile.Stat()
+ if err != nil {
+ return err
+ }
+ if info.Size() == 0 {
+ return fmt.Errorf("file cannot be empty")
+ }
+
+ if cm.l.conf.ProfileEnable {
+ // pprof.Initialize opens /proc/self/maps, so it has to be called before
+ // installing seccomp filters.
+ pprof.Initialize()
+ }
+
+ // Seccomp filters have to be applied before parsing the state file.
+ if err := cm.l.installSeccompFilters(); err != nil {
+ return err
+ }
+
+ // Load the state.
+ loadOpts := state.LoadOpts{Source: specFile}
+ if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil {
+ return err
+ }
+
+ // Since we have a new kernel, we must also make a new watchdog.
+ dogOpts := watchdog.DefaultOpts
+ dogOpts.TaskTimeoutAction = cm.l.conf.WatchdogAction
+ dog := watchdog.New(k, dogOpts)
+
+ // Change the loader fields to reflect the changes made when restoring.
+ cm.l.k = k
+ cm.l.watchdog = dog
+ cm.l.rootProcArgs = kernel.CreateProcessArgs{}
+ cm.l.restore = true
+
+ // Reinitialize the sandbox ID and processes map. Note that this doesn't
+ // restore the state of multiple containers or of exec'd processes.
+ cm.l.sandboxID = o.SandboxID
+ cm.l.mu.Lock()
+ eid := execID{cid: o.SandboxID}
+ cm.l.processes = map[execID]*execProcess{
+ eid: {
+ tg: cm.l.k.GlobalInit(),
+ },
+ }
+ cm.l.mu.Unlock()
+
+ // Tell the root container to start and wait for the result.
+ cm.startChan <- struct{}{}
+ if err := <-cm.startResultChan; err != nil {
+ return fmt.Errorf("starting sandbox: %v", err)
+ }
+
+ return nil
+}
+
+// Resume unpauses a container.
+func (cm *containerManager) Resume(_, _ *struct{}) error {
+ log.Debugf("containerManager.Resume")
+ cm.l.k.Unpause()
+ return nil
+}
+
+// Wait waits for the init process in the given container.
+func (cm *containerManager) Wait(cid *string, waitStatus *uint32) error {
+ log.Debugf("containerManager.Wait")
+ err := cm.l.waitContainer(*cid, waitStatus)
+ log.Debugf("containerManager.Wait returned, waitStatus: %v: %v", waitStatus, err)
+ return err
+}
+
+// WaitPIDArgs are arguments to the WaitPID method.
+type WaitPIDArgs struct {
+ // PID is the PID in the container's PID namespace.
+ PID int32
+
+ // CID is the container ID.
+ CID string
+}
+
+// WaitPID waits for the process with PID 'pid' in the sandbox.
+func (cm *containerManager) WaitPID(args *WaitPIDArgs, waitStatus *uint32) error {
+ log.Debugf("containerManager.Wait")
+ return cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus)
+}
+
+// SignalDeliveryMode enumerates different signal delivery modes.
+type SignalDeliveryMode int
+
+const (
+ // DeliverToProcess delivers the signal to the container process with
+ // the specified PID. If PID is 0, then the container init process is
+ // signaled.
+ DeliverToProcess SignalDeliveryMode = iota
+
+ // DeliverToAllProcesses delivers the signal to all processes in the
+ // container. PID must be 0.
+ DeliverToAllProcesses
+
+ // DeliverToForegroundProcessGroup delivers the signal to the
+ // foreground process group in the same TTY session as the specified
+ // process. If PID is 0, then the signal is delivered to the foreground
+ // process group for the TTY for the init process.
+ DeliverToForegroundProcessGroup
+)
+
+func (s SignalDeliveryMode) String() string {
+ switch s {
+ case DeliverToProcess:
+ return "Process"
+ case DeliverToAllProcesses:
+ return "All"
+ case DeliverToForegroundProcessGroup:
+ return "Foreground Process Group"
+ }
+ return fmt.Sprintf("unknown signal delivery mode: %d", s)
+}
+
+// SignalArgs are arguments to the Signal method.
+type SignalArgs struct {
+ // CID is the container ID.
+ CID string
+
+ // Signo is the signal to send to the process.
+ Signo int32
+
+ // PID is the process ID in the given container that will be signaled.
+ // If 0, the root container will be signaled.
+ PID int32
+
+ // Mode is the signal delivery mode.
+ Mode SignalDeliveryMode
+}
+
+// Signal sends a signal to one or more processes in a container. If args.PID
+// is 0, then the container init process is used. Depending on the
+// args.Mode option, the signal may be sent directly to the
+// indicated process, to all processes in the container, or to the foreground
+// process group.
+func (cm *containerManager) Signal(args *SignalArgs, _ *struct{}) error {
+ log.Debugf("containerManager.Signal %+v", args)
+ return cm.l.signal(args.CID, args.PID, args.Signo, args.Mode)
+}
diff --git a/runsc/boot/debug.go b/runsc/boot/debug.go
new file mode 100644
index 000000000..1fb32c527
--- /dev/null
+++ b/runsc/boot/debug.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+type debug struct {
+}
+
+// Stacks collects all sandbox stacks and copies them to 'stacks'.
+func (*debug) Stacks(_ *struct{}, stacks *string) error {
+ buf := log.Stacks(true)
+ *stacks = string(buf)
+ return nil
+}
diff --git a/runsc/boot/events.go b/runsc/boot/events.go
new file mode 100644
index 000000000..422f4da00
--- /dev/null
+++ b/runsc/boot/events.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// Event is the struct used to encode event data to JSON. It corresponds to
+// runc's main.event struct.
+type Event struct {
+ Type string `json:"type"`
+ ID string `json:"id"`
+ Data interface{} `json:"data,omitempty"`
+}
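+
+// A stats event marshals to JSON roughly as follows (illustrative sketch;
+// actual field contents depend on the sandbox):
+//
+//	{"type":"stats","id":"...","data":{"memory":{...},"pids":{...}}}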
+
+// Stats is the runc-specific stats structure for stability when encoding and
+// decoding stats.
+type Stats struct {
+ Memory Memory `json:"memory"`
+ Pids Pids `json:"pids"`
+}
+
+// Pids contains stats on processes.
+type Pids struct {
+ Current uint64 `json:"current,omitempty"`
+ Limit uint64 `json:"limit,omitempty"`
+}
+
+// MemoryEntry contains stats on a kind of memory.
+type MemoryEntry struct {
+ Limit uint64 `json:"limit"`
+ Usage uint64 `json:"usage,omitempty"`
+ Max uint64 `json:"max,omitempty"`
+ Failcnt uint64 `json:"failcnt"`
+}
+
+// Memory contains stats on memory.
+type Memory struct {
+ Cache uint64 `json:"cache,omitempty"`
+ Usage MemoryEntry `json:"usage,omitempty"`
+ Swap MemoryEntry `json:"swap,omitempty"`
+ Kernel MemoryEntry `json:"kernel,omitempty"`
+ KernelTCP MemoryEntry `json:"kernelTCP,omitempty"`
+ Raw map[string]uint64 `json:"raw,omitempty"`
+}
+
+// Event gets the events from the container.
+func (cm *containerManager) Event(_ *struct{}, out *Event) error {
+ stats := &Stats{}
+ stats.populateMemory(cm.l.k)
+ stats.populatePIDs(cm.l.k)
+ *out = Event{Type: "stats", Data: stats}
+ return nil
+}
+
+func (s *Stats) populateMemory(k *kernel.Kernel) {
+ mem := k.MemoryFile()
+ mem.UpdateUsage()
+ _, totalUsage := usage.MemoryAccounting.Copy()
+ s.Memory.Usage = MemoryEntry{
+ Usage: totalUsage,
+ }
+}
+
+func (s *Stats) populatePIDs(k *kernel.Kernel) {
+ s.Pids.Current = uint64(len(k.TaskSet().Root.ThreadGroups()))
+}
diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD
new file mode 100644
index 000000000..ed18f0047
--- /dev/null
+++ b/runsc/boot/filter/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "filter",
+ srcs = [
+ "config.go",
+ "config_amd64.go",
+ "config_arm64.go",
+ "config_profile.go",
+ "extra_filters.go",
+ "extra_filters_msan.go",
+ "extra_filters_race.go",
+ "filter.go",
+ ],
+ visibility = [
+ "//runsc/boot:__subpackages__",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/log",
+ "//pkg/seccomp",
+ "//pkg/sentry/platform",
+ "//pkg/tcpip/link/fdbased",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
new file mode 100644
index 000000000..60e33425f
--- /dev/null
+++ b/runsc/boot/filter/config.go
@@ -0,0 +1,559 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package filter
+
+import (
+ "os"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+)
+
+// allowedSyscalls is the set of syscalls that the Sentry may execute on the host OS.
+var allowedSyscalls = seccomp.SyscallRules{
+ syscall.SYS_CLOCK_GETTIME: {},
+ syscall.SYS_CLONE: []seccomp.Rule{
+ {
+ seccomp.AllowValue(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ },
+ },
+ syscall.SYS_CLOSE: {},
+ syscall.SYS_DUP: {},
+ syscall.SYS_DUP3: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.O_CLOEXEC),
+ },
+ },
+ syscall.SYS_EPOLL_CREATE1: {},
+ syscall.SYS_EPOLL_CTL: {},
+ syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_EVENTFD2: []seccomp.Rule{
+ {
+ seccomp.AllowValue(0),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_EXIT: {},
+ syscall.SYS_EXIT_GROUP: {},
+ syscall.SYS_FALLOCATE: {},
+ syscall.SYS_FCHMOD: {},
+ syscall.SYS_FCNTL: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.F_GETFL),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.F_SETFL),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.F_GETFD),
+ },
+ },
+ syscall.SYS_FSTAT: {},
+ syscall.SYS_FSYNC: {},
+ syscall.SYS_FTRUNCATE: {},
+ syscall.SYS_FUTEX: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.AllowAny{},
+ },
+ // Non-private variants are included for flipcall support. They are otherwise
+ // unnecessary, as the sentry will use only private futexes internally.
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
+ seccomp.AllowAny{},
+ },
+ },
+ syscall.SYS_GETPID: {},
+ unix.SYS_GETRANDOM: {},
+ syscall.SYS_GETSOCKOPT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_DOMAIN),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_TYPE),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_ERROR),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_SNDBUF),
+ },
+ },
+ syscall.SYS_GETTID: {},
+ syscall.SYS_GETTIMEOFDAY: {},
+ // SYS_IOCTL is needed for terminal support, but we only allow
+ // setting/getting termios and winsize.
+ syscall.SYS_IOCTL: []seccomp.Rule{
+ {
+ seccomp.AllowAny{}, /* fd */
+ seccomp.AllowValue(linux.TCGETS),
+ seccomp.AllowAny{}, /* termios struct */
+ },
+ {
+ seccomp.AllowAny{}, /* fd */
+ seccomp.AllowValue(linux.TCSETS),
+ seccomp.AllowAny{}, /* termios struct */
+ },
+ {
+ seccomp.AllowAny{}, /* fd */
+ seccomp.AllowValue(linux.TCSETSF),
+ seccomp.AllowAny{}, /* termios struct */
+ },
+ {
+ seccomp.AllowAny{}, /* fd */
+ seccomp.AllowValue(linux.TCSETSW),
+ seccomp.AllowAny{}, /* termios struct */
+ },
+ {
+ seccomp.AllowAny{}, /* fd */
+ seccomp.AllowValue(linux.TIOCSWINSZ),
+ seccomp.AllowAny{}, /* winsize struct */
+ },
+ {
+ seccomp.AllowAny{}, /* fd */
+ seccomp.AllowValue(linux.TIOCGWINSZ),
+ seccomp.AllowAny{}, /* winsize struct */
+ },
+ },
+ syscall.SYS_LSEEK: {},
+ syscall.SYS_MADVISE: {},
+ syscall.SYS_MINCORE: {},
+	// Used by the Go runtime as a temporary workaround for a Linux
+	// 5.2-5.4 bug.
+ //
+ // See src/runtime/os_linux_x86.go.
+ //
+ // TODO(b/148688965): Remove once this is gone from Go.
+ syscall.SYS_MLOCK: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4096),
+ },
+ },
+ syscall.SYS_MMAP: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_SHARED),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_PRIVATE),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.PROT_WRITE | syscall.PROT_READ),
+ seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
+ },
+ },
+ syscall.SYS_MPROTECT: {},
+ syscall.SYS_MUNMAP: {},
+ syscall.SYS_NANOSLEEP: {},
+ syscall.SYS_PPOLL: {},
+ syscall.SYS_PREAD64: {},
+ syscall.SYS_PREADV: {},
+ unix.SYS_PREADV2: {},
+ syscall.SYS_PWRITE64: {},
+ syscall.SYS_PWRITEV: {},
+ unix.SYS_PWRITEV2: {},
+ syscall.SYS_READ: {},
+ syscall.SYS_RECVMSG: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
+ },
+ },
+ syscall.SYS_RECVMMSG: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(fdbased.MaxMsgsPerRecv),
+ seccomp.AllowValue(syscall.MSG_DONTWAIT),
+ seccomp.AllowValue(0),
+ },
+ },
+ unix.SYS_SENDMMSG: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_RESTART_SYSCALL: {},
+ syscall.SYS_RT_SIGACTION: {},
+ syscall.SYS_RT_SIGPROCMASK: {},
+ syscall.SYS_RT_SIGRETURN: {},
+ syscall.SYS_SCHED_YIELD: {},
+ syscall.SYS_SENDMSG: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
+ },
+ },
+ syscall.SYS_SETITIMER: {},
+ syscall.SYS_SHUTDOWN: []seccomp.Rule{
+		// Used by fs/host to shut down host sockets.
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RD)},
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_WR)},
+		// Used by unet to shut down connections.
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
+ },
+ syscall.SYS_SIGALTSTACK: {},
+ unix.SYS_STATX: {},
+ syscall.SYS_SYNC_FILE_RANGE: {},
+ syscall.SYS_TEE: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(1), /* len */
+ seccomp.AllowValue(unix.SPLICE_F_NONBLOCK), /* flags */
+ },
+ },
+ syscall.SYS_TGKILL: []seccomp.Rule{
+ {
+ seccomp.AllowValue(uint64(os.Getpid())),
+ },
+ },
+ syscall.SYS_UTIMENSAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* null pathname */
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* flags */
+ },
+ },
+ syscall.SYS_WRITE: {},
+	// The only user, rawfile.NonBlockingWrite3, always passes iovcnt with
+	// value 2 or 3. Three iovecs are passed when the PACKET_VNET_HDR option
+	// is enabled for a packet socket.
+ syscall.SYS_WRITEV: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(2),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(3),
+ },
+ },
+}
+
+// hostInetFilters returns the syscalls needed by sentry/socket/hostinet.
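+//
+// Rules that pin the optlen argument of setsockopt to 4 (the size of a C
+// int) ensure that only fixed-size integer socket options can be set from
+// the sandbox.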
+func hostInetFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_ACCEPT4: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ },
+ },
+ syscall.SYS_BIND: {},
+ syscall.SYS_CONNECT: {},
+ syscall.SYS_GETPEERNAME: {},
+ syscall.SYS_GETSOCKNAME: {},
+ syscall.SYS_GETSOCKOPT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_TOS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_RECVTOS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_TCLASS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_RECVTCLASS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_V6ONLY),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_ERROR),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_KEEPALIVE),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_SNDBUF),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_RCVBUF),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_REUSEADDR),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_TYPE),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_LINGER),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_TCP),
+ seccomp.AllowValue(syscall.TCP_NODELAY),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_TCP),
+ seccomp.AllowValue(syscall.TCP_INFO),
+ },
+ },
+ syscall.SYS_IOCTL: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.TIOCOUTQ),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.TIOCINQ),
+ },
+ },
+ syscall.SYS_LISTEN: {},
+ syscall.SYS_READV: {},
+ syscall.SYS_RECVFROM: {},
+ syscall.SYS_RECVMSG: {},
+ syscall.SYS_SENDMSG: {},
+ syscall.SYS_SENDTO: {},
+ syscall.SYS_SETSOCKOPT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_V6ONLY),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_SNDBUF),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_RCVBUF),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_REUSEADDR),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_TCP),
+ seccomp.AllowValue(syscall.TCP_NODELAY),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_TOS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_RECVTOS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_TCLASS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_RECVTCLASS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ },
+ syscall.SYS_SHUTDOWN: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SHUT_RD),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SHUT_WR),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SHUT_RDWR),
+ },
+ },
+ syscall.SYS_SOCKET: []seccomp.Rule{
+ {
+ seccomp.AllowValue(syscall.AF_INET),
+ seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_INET),
+ seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_INET6),
+ seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_INET6),
+ seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_WRITEV: {},
+ }
+}
+
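+// controlServerFilters returns filters scoped to the control server socket:
+// accept and listen are permitted only on the given FD, and getsockopt only
+// for SO_PEERCRED, which the server uses to identify connecting peers.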
+func controlServerFilters(fd int) seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_ACCEPT: []seccomp.Rule{
+ {
+ seccomp.AllowValue(fd),
+ },
+ },
+ syscall.SYS_LISTEN: []seccomp.Rule{
+ {
+ seccomp.AllowValue(fd),
+ seccomp.AllowValue(16 /* unet.backlog */),
+ },
+ },
+ syscall.SYS_GETSOCKOPT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_SOCKET),
+ seccomp.AllowValue(syscall.SO_PEERCRED),
+ },
+ },
+ }
+}
diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go
new file mode 100644
index 000000000..5335ff82c
--- /dev/null
+++ b/runsc/boot/filter/config_amd64.go
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
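+// init allows the two arch_prctl operations that manipulate the FS segment
+// base, which the Sentry and the Go runtime use for TLS on amd64; all other
+// arch_prctl operations remain blocked.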
+func init() {
+ allowedSyscalls[syscall.SYS_ARCH_PRCTL] = append(allowedSyscalls[syscall.SYS_ARCH_PRCTL],
+ seccomp.Rule{seccomp.AllowValue(linux.ARCH_GET_FS)},
+ seccomp.Rule{seccomp.AllowValue(linux.ARCH_SET_FS)},
+ )
+}
diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go
new file mode 100644
index 000000000..7fa9bbda3
--- /dev/null
+++ b/runsc/boot/filter/config_arm64.go
@@ -0,0 +1,21 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package filter
+
+// init is reserved for future customization.
+func init() {
+}
diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go
new file mode 100644
index 000000000..194952a7b
--- /dev/null
+++ b/runsc/boot/filter/config_profile.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// profileFilters returns extra filters for syscalls made by the
+// runtime/pprof package.
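+// For illustration: runtime/pprof reads files such as /proc/self/maps while
+// building profiles, so a read-only openat variant is allowed, and only when
+// profiling is enabled (see Install).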
+func profileFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_OPENAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
+ },
+ },
+ }
+}
diff --git a/runsc/boot/filter/extra_filters.go b/runsc/boot/filter/extra_filters.go
new file mode 100644
index 000000000..e28d4b8d6
--- /dev/null
+++ b/runsc/boot/filter/extra_filters.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !msan,!race
+
+package filter
+
+import (
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// instrumentationFilters returns additional filters for syscalls used by
+// Go instrumentation tools, e.g. -race and -msan. It returns nil when no
+// instrumentation is enabled.
+func instrumentationFilters() seccomp.SyscallRules {
+ return nil
+}
diff --git a/runsc/boot/filter/extra_filters_msan.go b/runsc/boot/filter/extra_filters_msan.go
new file mode 100644
index 000000000..209e646a7
--- /dev/null
+++ b/runsc/boot/filter/extra_filters_msan.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build msan
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// instrumentationFilters returns additional filters for syscalls used by MSAN.
+func instrumentationFilters() seccomp.SyscallRules {
+ Report("MSAN is enabled: syscall filters less restrictive!")
+ return seccomp.SyscallRules{
+ syscall.SYS_CLONE: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_SCHED_GETAFFINITY: {},
+ syscall.SYS_SET_ROBUST_LIST: {},
+ }
+}
diff --git a/runsc/boot/filter/extra_filters_race.go b/runsc/boot/filter/extra_filters_race.go
new file mode 100644
index 000000000..9ff80276a
--- /dev/null
+++ b/runsc/boot/filter/extra_filters_race.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// instrumentationFilters returns additional filters for syscalls used by TSAN.
+func instrumentationFilters() seccomp.SyscallRules {
+ Report("TSAN is enabled: syscall filters less restrictive!")
+ return seccomp.SyscallRules{
+ syscall.SYS_BRK: {},
+ syscall.SYS_CLONE: {},
+ syscall.SYS_FUTEX: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_MUNLOCK: {},
+ syscall.SYS_NANOSLEEP: {},
+ syscall.SYS_OPEN: {},
+ syscall.SYS_OPENAT: {},
+ syscall.SYS_SET_ROBUST_LIST: {},
+ // Used within glibc's malloc.
+ syscall.SYS_TIME: {},
+ }
+}
diff --git a/runsc/boot/filter/filter.go b/runsc/boot/filter/filter.go
new file mode 100644
index 000000000..e80c171b3
--- /dev/null
+++ b/runsc/boot/filter/filter.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package filter defines all syscalls the sandbox is allowed to make
+// to the host, and installs seccomp filters to prevent prohibited
+// syscalls in case the sandbox is compromised.
+package filter
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+)
+
+// Options are seccomp filter related options.
+type Options struct {
+ Platform platform.Platform
+ HostNetwork bool
+ ProfileEnable bool
+ ControllerFD int
+}
+
+// Install installs seccomp filters based on the given options.
+func Install(opt Options) error {
+ s := allowedSyscalls
+ s.Merge(controlServerFilters(opt.ControllerFD))
+
+	// Merge the additional filters used by -race and -msan. The set is
+	// empty when instrumentation is not enabled.
+ s.Merge(instrumentationFilters())
+
+ if opt.HostNetwork {
+ Report("host networking enabled: syscall filters less restrictive!")
+ s.Merge(hostInetFilters())
+ }
+ if opt.ProfileEnable {
+ Report("profile enabled: syscall filters less restrictive!")
+ s.Merge(profileFilters())
+ }
+
+ s.Merge(opt.Platform.SyscallFilters())
+
+ return seccomp.Install(s)
+}
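+
+// A minimal usage sketch (illustrative; the real caller lives in the sandbox
+// boot path, and the option values here are hypothetical):
+//
+//	opts := filter.Options{
+//		Platform:      p, // e.g. the configured ptrace or KVM platform
+//		HostNetwork:   false,
+//		ProfileEnable: false,
+//		ControllerFD:  controllerFD,
+//	}
+//	if err := filter.Install(opts); err != nil {
+//		return fmt.Errorf("installing seccomp filters: %v", err)
+//	}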
+
+// Report writes a warning message to the log.
+func Report(msg string) {
+ log.Warningf("*** SECCOMP WARNING: %s", msg)
+}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
new file mode 100644
index 000000000..59639ba19
--- /dev/null
+++ b/runsc/boot/fs.go
@@ -0,0 +1,1034 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "path/filepath"
+ "sort"
+ "strconv"
+ "strings"
+ "syscall"
+
+	// Include filesystem types that the OCI spec might mount.
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/dev"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/proc"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/sys"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/tty"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/gofer"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ gofervfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
+ procvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
+ sysvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ tmpfsvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+const (
+ // Device name for root mount.
+ rootDevice = "9pfs-/"
+
+ // MountPrefix is the annotation prefix for mount hints.
+ MountPrefix = "dev.gvisor.spec.mount."
+
+	// Supported filesystems that map to a different internal filesystem.
+ bind = "bind"
+ nonefs = "none"
+)
+
+// tmpfs has some extra supported options that we must pass through.
+var tmpfsAllowedData = []string{"mode", "uid", "gid"}
+
+func addOverlay(ctx context.Context, conf *Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) {
+ // Upper layer uses the same flags as lower, but it must be read-write.
+ upperFlags := lowerFlags
+ upperFlags.ReadOnly = false
+
+ tmpFS := mustFindFilesystem("tmpfs")
+ if !fs.IsDir(lower.StableAttr) {
+ // Create overlay on top of mount file, e.g. /etc/hostname.
+ msrc := fs.NewCachingMountSource(ctx, tmpFS, upperFlags)
+ return fs.NewOverlayRootFile(ctx, msrc, lower, upperFlags)
+ }
+
+ // Create overlay on top of mount dir.
+ upper, err := tmpFS.Mount(ctx, name+"-upper", upperFlags, "", nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating tmpfs overlay: %v", err)
+ }
+
+ // Replicate permissions and owner from lower to upper mount point.
+ attr, err := lower.UnstableAttr(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("reading attributes from lower mount point: %v", err)
+ }
+ if !upper.InodeOperations.SetPermissions(ctx, upper, attr.Perms) {
+ return nil, fmt.Errorf("error setting permission to upper mount point")
+ }
+ if err := upper.InodeOperations.SetOwner(ctx, upper, attr.Owner); err != nil {
+ return nil, fmt.Errorf("setting owner to upper mount point: %v", err)
+ }
+
+ return fs.NewOverlayRoot(ctx, upper, lower, upperFlags)
+}
+
+// compileMounts returns the supported mounts from the mount spec, adding any
+// mandatory mounts that are required by the OCI specification.
+func compileMounts(spec *specs.Spec) []specs.Mount {
+ // Keep track of whether proc and sys were mounted.
+ var procMounted, sysMounted bool
+ var mounts []specs.Mount
+
+ // Always mount /dev.
+ mounts = append(mounts, specs.Mount{
+ Type: devtmpfs.Name,
+ Destination: "/dev",
+ })
+
+ mounts = append(mounts, specs.Mount{
+ Type: devpts.Name,
+ Destination: "/dev/pts",
+ })
+
+ // Mount all submounts from the spec.
+ for _, m := range spec.Mounts {
+ if !specutils.IsSupportedDevMount(m) {
+ log.Warningf("ignoring dev mount at %q", m.Destination)
+ continue
+ }
+ mounts = append(mounts, m)
+ switch filepath.Clean(m.Destination) {
+ case "/proc":
+ procMounted = true
+ case "/sys":
+ sysMounted = true
+ }
+ }
+
+ // Mount proc and sys even if the user did not ask for it, as the spec
+ // says we SHOULD.
+ var mandatoryMounts []specs.Mount
+ if !procMounted {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: procvfs2.Name,
+ Destination: "/proc",
+ })
+ }
+ if !sysMounted {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: sysvfs2.Name,
+ Destination: "/sys",
+ })
+ }
+
+ // The mandatory mounts should be ordered right after the root, in case
+ // there are submounts of these mandatory mounts already in the spec.
+ mounts = append(mounts[:0], append(mandatoryMounts, mounts[0:]...)...)
+
+ return mounts
+}
+
+// p9MountData creates a slice of p9 mount data.
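+// For example, with fd=5 and FileAccessShared under VFS1 this produces
+// ["trans=fd", "rfdno=5", "wfdno=5", "privateunixsocket=true",
+// "cache=remote_revalidating"], which callers join with "," into the mount
+// data string.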
+func p9MountData(fd int, fa FileAccessType, vfs2 bool) []string {
+ opts := []string{
+ "trans=fd",
+ "rfdno=" + strconv.Itoa(fd),
+ "wfdno=" + strconv.Itoa(fd),
+ }
+ if !vfs2 {
+ // privateunixsocket is always enabled in VFS2. VFS1 requires explicit
+ // enablement.
+ opts = append(opts, "privateunixsocket=true")
+ }
+ if fa == FileAccessShared {
+ opts = append(opts, "cache=remote_revalidating")
+ }
+ return opts
+}
+
+// parseAndFilterOptions parses a slice of mount options and filters out any
+// option whose key is not in allowedKeys.
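+// For example, parseAndFilterOptions([]string{"mode=01777", "size=1m"},
+// tmpfsAllowedData...) returns ["mode=01777"]: "size" is silently dropped
+// because it is not an allowed key.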
+func parseAndFilterOptions(opts []string, allowedKeys ...string) ([]string, error) {
+ var out []string
+ for _, o := range opts {
+ ok, err := parseMountOption(o, allowedKeys...)
+ if err != nil {
+ return nil, err
+ }
+ if ok {
+ out = append(out, o)
+ }
+ }
+ return out, nil
+}
+
+func parseMountOption(opt string, allowedKeys ...string) (bool, error) {
+ kv := strings.SplitN(opt, "=", 3)
+ if len(kv) > 2 {
+ return false, fmt.Errorf("invalid option %q", opt)
+ }
+ return specutils.ContainsStr(allowedKeys, kv[0]), nil
+}
+
+// mountDevice returns a device string based on the fs type and target
+// of the mount.
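+// For example, a bind mount at "/home/user" yields "9pfs-/home/user", while
+// a tmpfs mount yields "none".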
+func mountDevice(m specs.Mount) string {
+ if m.Type == bind {
+ // Make a device string that includes the target, which is consistent across
+ // S/R and uniquely identifies the connection.
+ return "9pfs-" + m.Destination
+ }
+ // All other fs types use device "none".
+ return "none"
+}
+
+func mountFlags(opts []string) fs.MountSourceFlags {
+ mf := fs.MountSourceFlags{}
+ // Note: changes to supported options must be reflected in
+ // isSupportedMountFlag() as well.
+ for _, o := range opts {
+ switch o {
+ case "rw":
+ mf.ReadOnly = false
+ case "ro":
+ mf.ReadOnly = true
+ case "noatime":
+ mf.NoAtime = true
+ case "noexec":
+ mf.NoExec = true
+ default:
+ log.Warningf("ignoring unknown mount option %q", o)
+ }
+ }
+ return mf
+}
+
+func isSupportedMountFlag(fstype, opt string) bool {
+ switch opt {
+ case "rw", "ro", "noatime", "noexec":
+ return true
+ }
+ if fstype == tmpfsvfs2.Name {
+ ok, err := parseMountOption(opt, tmpfsAllowedData...)
+ return ok && err == nil
+ }
+ return false
+}
+
+func mustFindFilesystem(name string) fs.Filesystem {
+ fs, ok := fs.FindFilesystem(name)
+ if !ok {
+ panic(fmt.Sprintf("could not find filesystem %q", name))
+ }
+ return fs
+}
+
+// addSubmountOverlay overlays the inode over a ramfs tree containing the given
+// paths.
+func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string) (*fs.Inode, error) {
+ // Construct a ramfs tree of mount points. The contents never
+ // change, so this can be fully caching. There's no real
+ // filesystem backing this tree, so we set the filesystem to
+ // nil.
+ msrc := fs.NewCachingMountSource(ctx, nil, fs.MountSourceFlags{})
+ mountTree, err := ramfs.MakeDirectoryTree(ctx, msrc, submounts)
+ if err != nil {
+ return nil, fmt.Errorf("creating mount tree: %v", err)
+ }
+ overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, fs.MountSourceFlags{})
+ if err != nil {
+ return nil, fmt.Errorf("adding mount overlay: %v", err)
+ }
+ return overlayInode, err
+}
+
+// subtargets takes a set of Mounts and returns only the targets that are
+// children of the given root. The returned paths are relative to the root.
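+// For example, with mounts at "/a/b" and "/c", subtargets("/a", mnts)
+// returns ["b"].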
+func subtargets(root string, mnts []specs.Mount) []string {
+ var targets []string
+ for _, mnt := range mnts {
+ if relPath, isSubpath := fs.IsSubpath(mnt.Destination, root); isSubpath {
+ targets = append(targets, relPath)
+ }
+ }
+ return targets
+}
+
+func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ if conf.VFS2 {
+ return setupContainerVFS2(ctx, conf, mntr, procArgs)
+ }
+ mns, err := mntr.setupFS(conf, procArgs)
+ if err != nil {
+ return err
+ }
+
+ // Set namespace here so that it can be found in ctx.
+ procArgs.MountNamespace = mns
+
+ // Resolve the executable path from working dir and environment.
+ resolved, err := user.ResolveExecutablePath(ctx, procArgs)
+ if err != nil {
+ return err
+ }
+ procArgs.Filename = resolved
+ return nil
+}
+
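+// adjustDirentCache caps the gofer dirent cache based on the host's
+// RLIMIT_NOFILE: when the limit is finite, the cache is limited to half of
+// it (the apparent intent being to keep host FDs pinned by cached dirents
+// from exhausting the descriptor budget).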
+func adjustDirentCache(k *kernel.Kernel) error {
+ var hl syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &hl); err != nil {
+ return fmt.Errorf("getting RLIMIT_NOFILE: %v", err)
+ }
+ if int64(hl.Cur) != syscall.RLIM_INFINITY {
+ newSize := hl.Cur / 2
+ if newSize < gofer.DefaultDirentCacheSize {
+ log.Infof("Setting gofer dirent cache size to %d", newSize)
+ gofer.DefaultDirentCacheSize = newSize
+ k.DirentCacheLimiter = fs.NewDirentCacheLimiter(newSize)
+ }
+ }
+ return nil
+}
+
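+// fdDispenser hands out gofer FDs in the order they were provided; remove()
+// panics if called when the list is exhausted (see checkDispenser).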
+type fdDispenser struct {
+ fds []int
+}
+
+func (f *fdDispenser) remove() int {
+ if f.empty() {
+ panic("fdDispenser out of fds")
+ }
+ rv := f.fds[0]
+ f.fds = f.fds[1:]
+ return rv
+}
+
+func (f *fdDispenser) empty() bool {
+ return len(f.fds) == 0
+}
+
+type shareType int
+
+const (
+ invalid shareType = iota
+
+ // container shareType indicates that the mount is used by a single container.
+ container
+
+ // pod shareType indicates that the mount is used by more than one container
+ // inside the pod.
+ pod
+
+ // shared shareType indicates that the mount can also be shared with a process
+ // outside the pod, e.g. NFS.
+ shared
+)
+
+func parseShare(val string) (shareType, error) {
+ switch val {
+ case "container":
+ return container, nil
+ case "pod":
+ return pod, nil
+ case "shared":
+ return shared, nil
+ default:
+ return 0, fmt.Errorf("invalid share value %q", val)
+ }
+}
+
+func (s shareType) String() string {
+ switch s {
+ case invalid:
+ return "invalid"
+ case container:
+ return "container"
+ case pod:
+ return "pod"
+ case shared:
+ return "shared"
+ default:
+ return fmt.Sprintf("invalid share value %d", s)
+ }
+}
+
+// mountHint represents extra information about mounts that are provided via
+// annotations. They can override the mount type and provide sharing
+// information so that mounts can be correctly shared inside the pod.
+type mountHint struct {
+ name string
+ share shareType
+ mount specs.Mount
+
+ // root is the inode where the volume is mounted. For mounts with 'pod' share
+ // the volume is mounted once and then bind mounted inside the containers.
+ root *fs.Inode
+
+ // vfsMount is the master mount for the volume. For mounts with 'pod' share
+ // the master volume is bind mounted inside the containers.
+ vfsMount *vfs.Mount
+}
+
+func (m *mountHint) setField(key, val string) error {
+ switch key {
+ case "source":
+ if len(val) == 0 {
+ return fmt.Errorf("source cannot be empty")
+ }
+ m.mount.Source = val
+ case "type":
+ return m.setType(val)
+ case "share":
+ share, err := parseShare(val)
+ if err != nil {
+ return err
+ }
+ m.share = share
+ case "options":
+ return m.setOptions(val)
+ default:
+ return fmt.Errorf("invalid mount annotation: %s=%s", key, val)
+ }
+ return nil
+}
+
+func (m *mountHint) setType(val string) error {
+ switch val {
+ case "tmpfs", "bind":
+ m.mount.Type = val
+ default:
+ return fmt.Errorf("invalid type %q", val)
+ }
+ return nil
+}
+
+func (m *mountHint) setOptions(val string) error {
+ opts := strings.Split(val, ",")
+ if err := specutils.ValidateMountOptions(opts); err != nil {
+ return err
+ }
+ // Sort options so it can be compared with container mount options later on.
+ sort.Strings(opts)
+ m.mount.Options = opts
+ return nil
+}
+
+func (m *mountHint) isSupported() bool {
+ return m.mount.Type == tmpfsvfs2.Name && m.share == pod
+}
+
+// checkCompatible verifies that the shared mount is compatible with the
+// master. For now, enforce that all options are the same. Once bind mounts
+// are properly supported, we should ensure the master is less restrictive
+// than the container, e.g. the master can be 'rw' while the container
+// mounts it as 'ro'.
+func (m *mountHint) checkCompatible(mount specs.Mount) error {
+	// Remove options that don't affect the mount's behavior.
+ masterOpts := filterUnsupportedOptions(m.mount)
+ slaveOpts := filterUnsupportedOptions(mount)
+
+ if len(masterOpts) != len(slaveOpts) {
+ return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts)
+ }
+
+ sort.Strings(masterOpts)
+ sort.Strings(slaveOpts)
+ for i, opt := range masterOpts {
+ if opt != slaveOpts[i] {
+ return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts)
+ }
+ }
+ return nil
+}
+
+func (m *mountHint) fileAccessType() FileAccessType {
+ if m.share == container {
+ return FileAccessExclusive
+ }
+ return FileAccessShared
+}
+
+func filterUnsupportedOptions(mount specs.Mount) []string {
+ rv := make([]string, 0, len(mount.Options))
+ for _, o := range mount.Options {
+ if isSupportedMountFlag(mount.Type, o) {
+ rv = append(rv, o)
+ }
+ }
+ return rv
+}
+
+// podMountHints contains a collection of mountHints for the pod.
+type podMountHints struct {
+ mounts map[string]*mountHint
+}
+
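+// newPodMountHints parses annotations of the form
+// "dev.gvisor.spec.mount.<name>.<field>" into a set of validated mountHints.
+// For example (illustrative names, not required values):
+//
+//	dev.gvisor.spec.mount.shared-tmp.source: sharedvol
+//	dev.gvisor.spec.mount.shared-tmp.type:   tmpfs
+//	dev.gvisor.spec.mount.shared-tmp.share:  pod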
+func newPodMountHints(spec *specs.Spec) (*podMountHints, error) {
+ mnts := make(map[string]*mountHint)
+ for k, v := range spec.Annotations {
+ // Look for 'dev.gvisor.spec.mount' annotations and parse them.
+ if strings.HasPrefix(k, MountPrefix) {
+ // Remove the prefix and split the rest.
+ parts := strings.Split(k[len(MountPrefix):], ".")
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid mount annotation: %s=%s", k, v)
+ }
+ name := parts[0]
+ if len(name) == 0 {
+ return nil, fmt.Errorf("invalid mount name: %s", name)
+ }
+ mnt := mnts[name]
+ if mnt == nil {
+ mnt = &mountHint{name: name}
+ mnts[name] = mnt
+ }
+ if err := mnt.setField(parts[1], v); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ // Validate all hints after done parsing.
+ for name, m := range mnts {
+ log.Infof("Mount annotation found, name: %s, source: %q, type: %s, share: %v", name, m.mount.Source, m.mount.Type, m.share)
+ if m.share == invalid {
+ return nil, fmt.Errorf("share field for %q has not been set", m.name)
+ }
+ if len(m.mount.Source) == 0 {
+ return nil, fmt.Errorf("source field for %q has not been set", m.name)
+ }
+ if len(m.mount.Type) == 0 {
+ return nil, fmt.Errorf("type field for %q has not been set", m.name)
+ }
+
+ // Check for duplicate mount sources.
+ for name2, m2 := range mnts {
+ if name != name2 && m.mount.Source == m2.mount.Source {
+ return nil, fmt.Errorf("mounts %q and %q have the same mount source %q", m.name, m2.name, m.mount.Source)
+ }
+ }
+ }
+
+ return &podMountHints{mounts: mnts}, nil
+}
+
+func (p *podMountHints) findMount(mount specs.Mount) *mountHint {
+ for _, m := range p.mounts {
+ if m.mount.Source == mount.Source {
+ return m
+ }
+ }
+ return nil
+}
+
+type containerMounter struct {
+ root *specs.Root
+
+ // mounts is the set of submounts for the container. It's a copy from the spec
+ // that may be freely modified without affecting the original spec.
+ mounts []specs.Mount
+
+ // fds is the list of FDs to be dispensed for mounts that require it.
+ fds fdDispenser
+
+ k *kernel.Kernel
+
+ hints *podMountHints
+}
+
+func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hints *podMountHints) *containerMounter {
+ return &containerMounter{
+ root: spec.Root,
+ mounts: compileMounts(spec),
+ fds: fdDispenser{fds: goferFDs},
+ k: k,
+ hints: hints,
+ }
+}
+
+// processHints processes annotations that contain hints about how volumes
+// should be mounted (e.g. a volume shared between containers). It must be
+// called for the root container only.
+func (c *containerMounter) processHints(conf *Config, creds *auth.Credentials) error {
+ if conf.VFS2 {
+ return c.processHintsVFS2(conf, creds)
+ }
+ ctx := c.k.SupervisorContext()
+ for _, hint := range c.hints.mounts {
+ // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a
+ // common gofer to mount all shared volumes.
+ if hint.mount.Type != tmpfsvfs2.Name {
+ continue
+ }
+ log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
+ inode, err := c.mountSharedMaster(ctx, conf, hint)
+ if err != nil {
+ return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+ }
+ hint.root = inode
+ }
+ return nil
+}
+
+// setupFS is used to set up the file system for all containers. This is the
+// main entry point method, with most of the others being internal only. It
+// returns the mount namespace that is created for the container.
+func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) {
+ log.Infof("Configuring container's file system")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := rootProcArgs.NewContext(c.k)
+
+ mns, err := c.createMountNamespace(rootCtx, conf)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set namespace here so that it can be found in rootCtx.
+ rootProcArgs.MountNamespace = mns
+
+ if err := c.mountSubmounts(rootCtx, conf, mns); err != nil {
+ return nil, err
+ }
+ return mns, nil
+}
+
+func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Config) (*fs.MountNamespace, error) {
+ rootInode, err := c.createRootMount(ctx, conf)
+ if err != nil {
+ return nil, fmt.Errorf("creating filesystem for container: %v", err)
+ }
+ mns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ return nil, fmt.Errorf("creating new mount namespace for container: %v", err)
+ }
+ return mns, nil
+}
+
+func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error {
+ root := mns.Root()
+ defer root.DecRef()
+
+ for _, m := range c.mounts {
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options)
+ if hint := c.hints.findMount(m); hint != nil && hint.isSupported() {
+ if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil {
+ return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, m.Destination, err)
+ }
+ } else {
+ if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil {
+ return fmt.Errorf("mount submount %q: %v", m.Destination, err)
+ }
+ }
+ }
+
+ if err := c.mountTmp(ctx, conf, mns, root); err != nil {
+ return fmt.Errorf("mount submount %q: %v", "tmp", err)
+ }
+
+ if err := c.checkDispenser(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *containerMounter) checkDispenser() error {
+ if !c.fds.empty() {
+ return fmt.Errorf("not all gofer FDs were consumed, remaining: %v", c.fds)
+ }
+ return nil
+}
+
+// mountSharedMaster mounts the master of a volume that is shared among
+// containers in a pod. It returns the root mount's inode.
+func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *Config, hint *mountHint) (*fs.Inode, error) {
+ // Map mount type to filesystem name, and parse out the options that we are
+ // capable of dealing with.
+ fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, hint.mount)
+ if err != nil {
+ return nil, err
+ }
+ if len(fsName) == 0 {
+ return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type)
+ }
+
+ // Mount with revalidate because it's shared among containers.
+ opts = append(opts, "cache=revalidate")
+
+ // All filesystem names should have been mapped to something we know.
+ filesystem := mustFindFilesystem(fsName)
+
+ mf := mountFlags(hint.mount.Options)
+ if useOverlay {
+		// All writes go to the upper layer; be paranoid and make the lower read-only.
+ mf.ReadOnly = true
+ }
+
+ inode, err := filesystem.Mount(ctx, mountDevice(hint.mount), mf, strings.Join(opts, ","), nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating mount %q: %v", hint.name, err)
+ }
+
+ if useOverlay {
+ log.Debugf("Adding overlay on top of shared mount %q", hint.name)
+ inode, err = addOverlay(ctx, conf, inode, hint.mount.Type, mf)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return inode, nil
+}
+
+// createRootMount creates the root filesystem.
+func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (*fs.Inode, error) {
+ // First construct the filesystem from the spec.Root.
+ mf := fs.MountSourceFlags{ReadOnly: c.root.Readonly || conf.Overlay}
+
+ fd := c.fds.remove()
+ log.Infof("Mounting root over 9P, ioFD: %d", fd)
+ p9FS := mustFindFilesystem("9p")
+ opts := p9MountData(fd, conf.FileAccess, false /* vfs2 */)
+
+ if conf.OverlayfsStaleRead {
+		// We can't check for overlayfs here because the sandbox is chroot'ed
+		// and the gofer can only send mount options for specs.Mounts
+		// (specs.Root is missing an Options field). So assume the root is
+		// always on top of overlayfs.
+ opts = append(opts, "overlayfs_stale_read")
+ }
+
+ rootInode, err := p9FS.Mount(ctx, rootDevice, mf, strings.Join(opts, ","), nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating root mount point: %v", err)
+ }
+
+ // We need to overlay the root on top of a ramfs with stub directories
+	// for submount paths. "/dev", "/sys", "/proc", and "/tmp" are always
+ // mounted even if they are not in the spec.
+ submounts := append(subtargets("/", c.mounts), "/dev", "/sys", "/proc", "/tmp")
+ rootInode, err = addSubmountOverlay(ctx, rootInode, submounts)
+ if err != nil {
+ return nil, fmt.Errorf("adding submount overlay: %v", err)
+ }
+
+ if conf.Overlay && !c.root.Readonly {
+ log.Debugf("Adding overlay on top of root mount")
+ // Overlay a tmpfs filesystem on top of the root.
+ rootInode, err = addOverlay(ctx, conf, rootInode, "root-overlay-upper", mf)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ log.Infof("Mounted %q to %q type root", c.root.Path, "/")
+ return rootInode, nil
+}
+
+// getMountNameAndOptions retrieves the fsName, opts, and useOverlay values
+// used for mounts.
+func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (string, []string, bool, error) {
+ var (
+ fsName string
+ opts []string
+ useOverlay bool
+ )
+
+ switch m.Type {
+ case devpts.Name, devtmpfs.Name, procvfs2.Name, sysvfs2.Name:
+ fsName = m.Type
+ case nonefs:
+ fsName = sysvfs2.Name
+ case tmpfsvfs2.Name:
+ fsName = m.Type
+
+ var err error
+ opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...)
+ if err != nil {
+ return "", nil, false, err
+ }
+
+ case bind:
+ fd := c.fds.remove()
+ fsName = gofervfs2.Name
+ opts = p9MountData(fd, c.getMountAccessType(m), conf.VFS2)
+ // If configured, add overlay to all writable mounts.
+ useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
+
+ default:
+ log.Warningf("ignoring unknown filesystem type %q", m.Type)
+ }
+ return fsName, opts, useOverlay, nil
+}
+
+func (c *containerMounter) getMountAccessType(mount specs.Mount) FileAccessType {
+ if hint := c.hints.findMount(mount); hint != nil {
+ return hint.fileAccessType()
+ }
+ // Non-root bind mounts are always shared if no hints were provided.
+ return FileAccessShared
+}
+
+// mountSubmount mounts volumes inside the container's root. Because mounts may
+// be readonly, a lower ramfs overlay is added to create the mount point dir.
+// Another overlay is added with tmpfs on top if Config.Overlay is true.
+// 'm.Destination' must be an absolute path with '..' and symlinks resolved.
+func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error {
+ // Map mount type to filesystem name, and parse out the options that we are
+ // capable of dealing with.
+ fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m)
+ if err != nil {
+ return err
+ }
+ if fsName == "" {
+ // Filesystem is not supported (e.g. cgroup), just skip it.
+ return nil
+ }
+
+ // All filesystem names should have been mapped to something we know.
+ filesystem := mustFindFilesystem(fsName)
+
+ mf := mountFlags(m.Options)
+ if useOverlay {
+		// All writes go to the upper layer; be paranoid and make the lower read-only.
+ mf.ReadOnly = true
+ }
+
+ inode, err := filesystem.Mount(ctx, mountDevice(m), mf, strings.Join(opts, ","), nil)
+ if err != nil {
+ err := fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ // Check to see if this is a common error due to a Linux bug.
+ // This error is generated here in order to cause it to be
+ // printed to the user using Docker via 'runsc create' etc. rather
+ // than simply printed to the logs for the 'runsc boot' command.
+ //
+		// We check the error message string rather than the type because the
+		// actual error types (syscall.EIO, syscall.EPIPE) are lost by the
+		// file system implementation (e.g. p9).
+ // TODO(gvisor.dev/issue/1765): Remove message when bug is resolved.
+ if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) {
+ return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug"))
+ }
+ return err
+ }
+
+ // If there are submounts, we need to overlay the mount on top of a ramfs
+ // with stub directories for submount paths.
+ submounts := subtargets(m.Destination, c.mounts)
+ if len(submounts) > 0 {
+ log.Infof("Adding submount overlay over %q", m.Destination)
+ inode, err = addSubmountOverlay(ctx, inode, submounts)
+ if err != nil {
+ return fmt.Errorf("adding submount overlay: %v", err)
+ }
+ }
+
+ if useOverlay {
+ log.Debugf("Adding overlay on top of mount %q", m.Destination)
+ inode, err = addOverlay(ctx, conf, inode, m.Type, mf)
+ if err != nil {
+ return err
+ }
+ }
+
+ maxTraversals := uint(0)
+ dirent, err := mns.FindInode(ctx, root, root, m.Destination, &maxTraversals)
+ if err != nil {
+ return fmt.Errorf("can't find mount destination %q: %v", m.Destination, err)
+ }
+ defer dirent.DecRef()
+ if err := mns.Mount(ctx, dirent, inode); err != nil {
+ return fmt.Errorf("mount %q error: %v", m.Destination, err)
+ }
+
+ log.Infof("Mounted %q to %q type: %s, internal-options: %q", m.Source, m.Destination, m.Type, opts)
+ return nil
+}
+
+// mountSharedSubmount bind-mounts to a previously mounted volume that is
+// shared among containers in the same pod.
+func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount specs.Mount, source *mountHint) error {
+ if err := source.checkCompatible(mount); err != nil {
+ return err
+ }
+
+ maxTraversals := uint(0)
+ target, err := mns.FindInode(ctx, root, root, mount.Destination, &maxTraversals)
+ if err != nil {
+ return fmt.Errorf("can't find mount destination %q: %v", mount.Destination, err)
+ }
+ defer target.DecRef()
+
+ // Take a ref on the inode that is about to be (re)-mounted.
+ source.root.IncRef()
+ if err := mns.Mount(ctx, target, source.root); err != nil {
+ source.root.DecRef()
+ return fmt.Errorf("bind mount %q error: %v", mount.Destination, err)
+ }
+
+ log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name)
+ return nil
+}
+
+// addRestoreMount adds a mount to the MountSources map used for restoring a
+// checkpointed container.
+func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnvironment, m specs.Mount) error {
+ fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m)
+ if err != nil {
+ return err
+ }
+ if fsName == "" {
+ // Filesystem is not supported (e.g. cgroup), just skip it.
+ return nil
+ }
+
+ newMount := fs.MountArgs{
+ Dev: mountDevice(m),
+ Flags: mountFlags(m.Options),
+ DataString: strings.Join(opts, ","),
+ }
+ if useOverlay {
+ newMount.Flags.ReadOnly = true
+ }
+ renv.MountSources[fsName] = append(renv.MountSources[fsName], newMount)
+ log.Infof("Added mount at %q: %+v", fsName, newMount)
+ return nil
+}
+
+// createRestoreEnvironment builds a fs.RestoreEnvironment by adding the
+// container's mounts to it.
+func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEnvironment, error) {
+ renv := &fs.RestoreEnvironment{
+ MountSources: make(map[string][]fs.MountArgs),
+ }
+
+ // Add root mount.
+ fd := c.fds.remove()
+ opts := p9MountData(fd, conf.FileAccess, false /* vfs2 */)
+
+ mf := fs.MountSourceFlags{}
+ if c.root.Readonly || conf.Overlay {
+ mf.ReadOnly = true
+ }
+
+ rootMount := fs.MountArgs{
+ Dev: rootDevice,
+ Flags: mf,
+ DataString: strings.Join(opts, ","),
+ }
+ renv.MountSources[gofervfs2.Name] = append(renv.MountSources[gofervfs2.Name], rootMount)
+
+ // Add submounts.
+ var tmpMounted bool
+ for _, m := range c.mounts {
+ if err := c.addRestoreMount(conf, renv, m); err != nil {
+ return nil, err
+ }
+ if filepath.Clean(m.Destination) == "/tmp" {
+ tmpMounted = true
+ }
+ }
+
+ // TODO(b/67958150): handle '/tmp' properly (see mountTmp()).
+ if !tmpMounted {
+ tmpMount := specs.Mount{
+ Type: tmpfsvfs2.Name,
+ Destination: "/tmp",
+ }
+ if err := c.addRestoreMount(conf, renv, tmpMount); err != nil {
+ return nil, err
+ }
+ }
+
+ return renv, nil
+}
+
+// mountTmp mounts an internal tmpfs at '/tmp' if it's safe to do so.
+// Technically we don't have to mount tmpfs at /tmp, as we could just rely on
+// the host /tmp, but this is a nice optimization, and fixes some apps that call
+// mknod in /tmp. It's unsafe to mount tmpfs if:
+// 1. /tmp is mounted explicitly: we should not override user's wish
+// 2. /tmp is not empty: mounting tmpfs would hide existing files in /tmp
+//
+// Note that when there are submounts inside of '/tmp', directories for the
+// mount points must be present, making '/tmp' not empty anymore.
+func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent) error {
+ for _, m := range c.mounts {
+ if filepath.Clean(m.Destination) == "/tmp" {
+			log.Debugf("Explicit %q mount found, skipping internal tmpfs, mount: %+v", "/tmp", m)
+ return nil
+ }
+ }
+
+ maxTraversals := uint(0)
+ tmp, err := mns.FindInode(ctx, root, root, "tmp", &maxTraversals)
+ switch err {
+ case nil:
+ // Found '/tmp' in filesystem, check if it's empty.
+ defer tmp.DecRef()
+ f, err := tmp.Inode.GetFile(ctx, tmp, fs.FileFlags{Read: true, Directory: true})
+ if err != nil {
+ return err
+ }
+ defer f.DecRef()
+ serializer := &fs.CollectEntriesSerializer{}
+ if err := f.Readdir(ctx, serializer); err != nil {
+ return err
+ }
+ // If more than "." and ".." is found, skip internal tmpfs to prevent hiding
+ // existing files.
+ if len(serializer.Order) > 2 {
+ log.Infof("Skipping internal tmpfs on top %q, because it's not empty", "/tmp")
+ return nil
+ }
+ log.Infof("Mounting internal tmpfs on top of empty %q", "/tmp")
+ fallthrough
+
+ case syserror.ENOENT:
+ // No '/tmp' found (or fallthrough from above). Safe to mount internal
+ // tmpfs.
+ tmpMount := specs.Mount{
+ Type: tmpfsvfs2.Name,
+ Destination: "/tmp",
+ // Sticky bit is added to prevent accidental deletion of files from
+ // another user. This is normally done for /tmp.
+ Options: []string{"mode=01777"},
+ }
+ return c.mountSubmount(ctx, conf, mns, root, tmpMount)
+
+ default:
+ return err
+ }
+}
diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go
new file mode 100644
index 000000000..912037075
--- /dev/null
+++ b/runsc/boot/fs_test.go
@@ -0,0 +1,250 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+func TestPodMountHintsHappy(t *testing.T) {
+ spec := &specs.Spec{
+ Annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
+
+ MountPrefix + "mount2.source": "bar",
+ MountPrefix + "mount2.type": "bind",
+ MountPrefix + "mount2.share": "container",
+ MountPrefix + "mount2.options": "rw,private",
+ },
+ }
+ podHints, err := newPodMountHints(spec)
+ if err != nil {
+ t.Fatalf("newPodMountHints failed: %v", err)
+ }
+
+ // Check that fields were set correctly.
+ mount1 := podHints.mounts["mount1"]
+ if want := "mount1"; want != mount1.name {
+ t.Errorf("mount1 name, want: %q, got: %q", want, mount1.name)
+ }
+ if want := "foo"; want != mount1.mount.Source {
+ t.Errorf("mount1 source, want: %q, got: %q", want, mount1.mount.Source)
+ }
+ if want := "tmpfs"; want != mount1.mount.Type {
+ t.Errorf("mount1 type, want: %q, got: %q", want, mount1.mount.Type)
+ }
+	if want := pod; want != mount1.share {
+		t.Errorf("mount1 share, want: %v, got: %v", want, mount1.share)
+	}
+	if want := []string(nil); !reflect.DeepEqual(want, mount1.mount.Options) {
+		t.Errorf("mount1 options, want: %q, got: %q", want, mount1.mount.Options)
+	}
+
+ mount2 := podHints.mounts["mount2"]
+ if want := "mount2"; want != mount2.name {
+ t.Errorf("mount2 name, want: %q, got: %q", want, mount2.name)
+ }
+ if want := "bar"; want != mount2.mount.Source {
+ t.Errorf("mount2 source, want: %q, got: %q", want, mount2.mount.Source)
+ }
+ if want := "bind"; want != mount2.mount.Type {
+ t.Errorf("mount2 type, want: %q, got: %q", want, mount2.mount.Type)
+ }
+	if want := container; want != mount2.share {
+		t.Errorf("mount2 share, want: %v, got: %v", want, mount2.share)
+	}
+	if want := []string{"private", "rw"}; !reflect.DeepEqual(want, mount2.mount.Options) {
+		t.Errorf("mount2 options, want: %q, got: %q", want, mount2.mount.Options)
+	}
+}
+
+func TestPodMountHintsErrors(t *testing.T) {
+ for _, tst := range []struct {
+ name string
+ annotations map[string]string
+ error string
+ }{
+ {
+ name: "too short",
+ annotations: map[string]string{
+ MountPrefix + "mount1": "foo",
+ },
+ error: "invalid mount annotation",
+ },
+ {
+ name: "no name",
+ annotations: map[string]string{
+ MountPrefix + ".source": "foo",
+ },
+ error: "invalid mount name",
+ },
+ {
+ name: "missing source",
+ annotations: map[string]string{
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
+ },
+ error: "source field",
+ },
+ {
+ name: "missing type",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.share": "pod",
+ },
+ error: "type field",
+ },
+ {
+ name: "missing share",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ },
+ error: "share field",
+ },
+ {
+ name: "invalid field name",
+ annotations: map[string]string{
+ MountPrefix + "mount1.invalid": "foo",
+ },
+ error: "invalid mount annotation",
+ },
+ {
+ name: "invalid source",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
+ },
+ error: "source cannot be empty",
+ },
+ {
+ name: "invalid type",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "invalid-type",
+ MountPrefix + "mount1.share": "pod",
+ },
+ error: "invalid type",
+ },
+ {
+ name: "invalid share",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "invalid-share",
+ },
+ error: "invalid share",
+ },
+ {
+ name: "invalid options",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
+ MountPrefix + "mount1.options": "invalid-option",
+ },
+ error: "unknown mount option",
+ },
+ {
+ name: "duplicate source",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
+
+ MountPrefix + "mount2.source": "foo",
+ MountPrefix + "mount2.type": "bind",
+ MountPrefix + "mount2.share": "container",
+ },
+ error: "have the same mount source",
+ },
+ } {
+ t.Run(tst.name, func(t *testing.T) {
+ spec := &specs.Spec{Annotations: tst.annotations}
+ podHints, err := newPodMountHints(spec)
+ if err == nil || !strings.Contains(err.Error(), tst.error) {
+ t.Errorf("newPodMountHints invalid error, want: .*%s.*, got: %v", tst.error, err)
+ }
+ if podHints != nil {
+ t.Errorf("newPodMountHints must return nil on failure: %+v", podHints)
+ }
+ })
+ }
+}
+
+func TestGetMountAccessType(t *testing.T) {
+ const source = "foo"
+ for _, tst := range []struct {
+ name string
+ annotations map[string]string
+ want FileAccessType
+ }{
+ {
+ name: "container=exclusive",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source,
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "container",
+ },
+ want: FileAccessExclusive,
+ },
+ {
+ name: "pod=shared",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source,
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "pod",
+ },
+ want: FileAccessShared,
+ },
+ {
+ name: "shared=shared",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source,
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "shared",
+ },
+ want: FileAccessShared,
+ },
+ {
+ name: "default=shared",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source + "mismatch",
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "container",
+ },
+ want: FileAccessShared,
+ },
+ } {
+ t.Run(tst.name, func(t *testing.T) {
+ spec := &specs.Spec{Annotations: tst.annotations}
+ podHints, err := newPodMountHints(spec)
+ if err != nil {
+ t.Fatalf("newPodMountHints failed: %v", err)
+ }
+ mounter := containerMounter{hints: podHints}
+ if got := mounter.getMountAccessType(specs.Mount{Source: source}); got != tst.want {
+ t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got)
+ }
+ })
+ }
+}
diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go
new file mode 100644
index 000000000..ce62236e5
--- /dev/null
+++ b/runsc/boot/limits.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Mapping from linux resource names to limits.LimitType.
+var fromLinuxResource = map[string]limits.LimitType{
+ "RLIMIT_AS": limits.AS,
+ "RLIMIT_CORE": limits.Core,
+ "RLIMIT_CPU": limits.CPU,
+ "RLIMIT_DATA": limits.Data,
+ "RLIMIT_FSIZE": limits.FileSize,
+ "RLIMIT_LOCKS": limits.Locks,
+ "RLIMIT_MEMLOCK": limits.MemoryLocked,
+ "RLIMIT_MSGQUEUE": limits.MessageQueueBytes,
+ "RLIMIT_NICE": limits.Nice,
+ "RLIMIT_NOFILE": limits.NumberOfFiles,
+ "RLIMIT_NPROC": limits.ProcessCount,
+ "RLIMIT_RSS": limits.Rss,
+ "RLIMIT_RTPRIO": limits.RealTimePriority,
+ "RLIMIT_RTTIME": limits.Rttime,
+ "RLIMIT_SIGPENDING": limits.SignalsPending,
+ "RLIMIT_STACK": limits.Stack,
+}
+
+func findName(lt limits.LimitType) string {
+ for k, v := range fromLinuxResource {
+ if v == lt {
+ return k
+ }
+ }
+ return "unknown"
+}
+
+var defaults defs
+
+type defs struct {
+ mu sync.Mutex
+ set *limits.LimitSet
+ err error
+}
+
+func (d *defs) get() (*limits.LimitSet, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.err != nil {
+ return nil, d.err
+ }
+ if d.set == nil {
+ if err := d.initDefaults(); err != nil {
+ d.err = err
+ return nil, err
+ }
+ }
+ return d.set, nil
+}
+
+func (d *defs) initDefaults() error {
+ ls, err := limits.NewLinuxLimitSet()
+ if err != nil {
+ return err
+ }
+
+ // Set default limits based on what containers get by default, e.g.:
+ // $ docker run --rm debian prlimit
+ ls.SetUnchecked(limits.AS, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.Core, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.CPU, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.Data, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.FileSize, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.Locks, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.MemoryLocked, limits.Limit{Cur: 65536, Max: 65536})
+ ls.SetUnchecked(limits.MessageQueueBytes, limits.Limit{Cur: 819200, Max: 819200})
+ ls.SetUnchecked(limits.Nice, limits.Limit{Cur: 0, Max: 0})
+ ls.SetUnchecked(limits.NumberOfFiles, limits.Limit{Cur: 1048576, Max: 1048576})
+ ls.SetUnchecked(limits.ProcessCount, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.Rss, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.RealTimePriority, limits.Limit{Cur: 0, Max: 0})
+ ls.SetUnchecked(limits.Rttime, limits.Limit{Cur: limits.Infinity, Max: limits.Infinity})
+ ls.SetUnchecked(limits.SignalsPending, limits.Limit{Cur: 0, Max: 0})
+ ls.SetUnchecked(limits.Stack, limits.Limit{Cur: 8388608, Max: limits.Infinity})
+
+ // Read host limits that directly affect the sandbox and adjust the defaults
+ // based on them.
+ for _, res := range []int{syscall.RLIMIT_FSIZE, syscall.RLIMIT_NOFILE} {
+ var hl syscall.Rlimit
+ if err := syscall.Getrlimit(res, &hl); err != nil {
+ return err
+ }
+
+ lt, ok := limits.FromLinuxResource[res]
+ if !ok {
+ return fmt.Errorf("unknown rlimit type %v", res)
+ }
+ hostLimit := limits.Limit{
+ Cur: limits.FromLinux(hl.Cur),
+ Max: limits.FromLinux(hl.Max),
+ }
+
+ defaultLimit := ls.Get(lt)
+ if hostLimit.Cur != limits.Infinity && hostLimit.Cur < defaultLimit.Cur {
+ log.Warningf("Host limit is lower than recommended, resource: %q, host: %d, recommended: %d", findName(lt), hostLimit.Cur, defaultLimit.Cur)
+ }
+ if hostLimit.Cur != defaultLimit.Cur || hostLimit.Max != defaultLimit.Max {
+ log.Infof("Setting limit from host, resource: %q {soft: %d, hard: %d}", findName(lt), hostLimit.Cur, hostLimit.Max)
+ ls.SetUnchecked(lt, hostLimit)
+ }
+ }
+
+ d.set = ls
+ return nil
+}
+
+func createLimitSet(spec *specs.Spec) (*limits.LimitSet, error) {
+ ls, err := defaults.get()
+ if err != nil {
+ return nil, err
+ }
+
+ // Then apply overwrites on top of defaults.
+ for _, rl := range spec.Process.Rlimits {
+ lt, ok := fromLinuxResource[rl.Type]
+ if !ok {
+ return nil, fmt.Errorf("unknown resource %q", rl.Type)
+ }
+ ls.SetUnchecked(lt, limits.Limit{
+ Cur: rl.Soft,
+ Max: rl.Hard,
+ })
+ }
+ return ls, nil
+}
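+
+// An illustrative sketch (not part of this file's API surface; the values are
+// hypothetical): a spec that overrides the default RLIMIT_NOFILE would be
+// translated by createLimitSet roughly as follows.
+//
+//	spec := &specs.Spec{
+//		Process: &specs.Process{
+//			Rlimits: []specs.POSIXRlimit{
+//				{Type: "RLIMIT_NOFILE", Soft: 1024, Hard: 4096},
+//			},
+//		},
+//	}
+//	ls, _ := createLimitSet(spec)
+//	// ls.Get(limits.NumberOfFiles) now reports {Cur: 1024, Max: 4096},
+//	// overriding the 1048576 default set in initDefaults.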
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
new file mode 100644
index 000000000..0c0423ab2
--- /dev/null
+++ b/runsc/boot/loader.go
@@ -0,0 +1,1284 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package boot loads the kernel and runs a container.
+package boot
+
+import (
+ "fmt"
+ mrand "math/rand"
+ "os"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+ gtime "time"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/fdimport"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/sighandling"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux/vfs2"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/runsc/boot/filter"
+ _ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms.
+ "gvisor.dev/gvisor/runsc/boot/pprof"
+ "gvisor.dev/gvisor/runsc/specutils"
+
+ // Include supported socket providers.
+ "gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
+ _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route"
+ _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ _ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+)
+
+// Loader keeps state needed to start the kernel and run the container.
+type Loader struct {
+ // k is the kernel.
+ k *kernel.Kernel
+
+ // ctrl is the control server.
+ ctrl *controller
+
+ conf *Config
+
+ // console is set to true if terminal is enabled.
+ console bool
+
+ watchdog *watchdog.Watchdog
+
+ // stdioFDs contains stdin, stdout, and stderr.
+ stdioFDs []int
+
+ // goferFDs are the FDs that attach the sandbox to the gofers.
+ goferFDs []int
+
+ // spec is the base configuration for the root container.
+ spec *specs.Spec
+
+ // stopSignalForwarding disables forwarding of signals to the sandboxed
+ // container. It should be called when a sandbox is destroyed.
+ stopSignalForwarding func()
+
+ // restore is set to true if we are restoring a container.
+ restore bool
+
+ // rootProcArgs refers to the root sandbox init task.
+ rootProcArgs kernel.CreateProcessArgs
+
+ // sandboxID is the ID for the whole sandbox.
+ sandboxID string
+
+ // mu guards processes.
+ mu sync.Mutex
+
+ // processes maps each container's init process and each invocation of
+ // exec to the corresponding sentry process. Root processes are keyed
+ // with the container ID and pid=0, while exec invocations have the
+ // corresponding pid set.
+ //
+ // processes is guarded by mu.
+ processes map[execID]*execProcess
+
+ // mountHints provides extra information about mounts for containers that
+ // apply to the entire pod.
+ mountHints *podMountHints
+}
+
+// execID uniquely identifies a sentry process that is executed in a container.
+type execID struct {
+ cid string
+ pid kernel.ThreadID
+}
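+
+// For example (IDs hypothetical), the init process of container "abc" is
+// keyed as execID{cid: "abc"} (pid left at zero), while a process created in
+// that container via exec with thread group ID 2 is keyed as
+// execID{cid: "abc", pid: 2}.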
+
+// execProcess contains the thread group and host TTY of a sentry process.
+type execProcess struct {
+ // tg will be nil for containers that haven't started yet.
+ tg *kernel.ThreadGroup
+
+ // tty will be nil if the process is not attached to a terminal.
+ tty *host.TTYFileOperations
+
+ // ttyVFS2 will be nil if the process is not attached to a terminal.
+ ttyVFS2 *hostvfs2.TTYFileDescription
+
+ // pidnsPath is the PID namespace path in the spec.
+ pidnsPath string
+}
+
+func init() {
+ // Initialize the random number generator.
+ mrand.Seed(gtime.Now().UnixNano())
+}
+
+// Args are the arguments for New().
+type Args struct {
+ // ID is the sandbox ID.
+ ID string
+ // Spec is the sandbox specification.
+ Spec *specs.Spec
+ // Conf is the system configuration.
+ Conf *Config
+ // ControllerFD is the FD to the URPC controller. The Loader takes ownership
+ // of this FD and may close it at any time.
+ ControllerFD int
+ // Device is an optional argument that is passed to the platform. The Loader
+ // takes ownership of this file and may close it at any time.
+ Device *os.File
+ // GoferFDs is an array of FDs used to connect with the Gofer. The Loader
+ // takes ownership of these FDs and may close them at any time.
+ GoferFDs []int
+ // StdioFDs is the stdio for the application. The Loader takes ownership of
+ // these FDs and may close them at any time.
+ StdioFDs []int
+ // Console is set to true if using TTY.
+ Console bool
+ // NumCPU is the number of CPUs to create inside the sandbox.
+ NumCPU int
+ // TotalMem is the initial amount of total memory to report back to the
+ // container.
+ TotalMem uint64
+ // UserLogFD is the file descriptor to write user logs to.
+ UserLogFD int
+}
+
+// Make sure stdioFDs are always the same on initial start and on restore.
+const startingStdioFD = 64
+
+// New initializes a new kernel loader configured by spec.
+// New also handles setting up a kernel for restoring a container.
+func New(args Args) (*Loader, error) {
+ // We initialize the rand package now to make sure /dev/urandom is pre-opened
+ // on kernels that do not support getrandom(2).
+ if err := rand.Init(); err != nil {
+ return nil, fmt.Errorf("setting up rand: %v", err)
+ }
+
+ if err := usage.Init(); err != nil {
+ return nil, fmt.Errorf("setting up memory usage: %v", err)
+ }
+
+ // Is this a VFSv2 kernel?
+ if args.Conf.VFS2 {
+ kernel.VFS2Enabled = true
+ vfs2.Override()
+ }
+
+ // Create kernel and platform.
+ p, err := createPlatform(args.Conf, args.Device)
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+ k := &kernel.Kernel{
+ Platform: p,
+ }
+
+ // Create memory file.
+ mf, err := createMemoryFile()
+ if err != nil {
+ return nil, fmt.Errorf("creating memory file: %v", err)
+ }
+ k.SetMemoryFile(mf)
+
+ // Create VDSO.
+ //
+ // Pass k as the platform since it is savable, unlike the actual platform.
+ vdso, err := loader.PrepareVDSO(k)
+ if err != nil {
+ return nil, fmt.Errorf("creating vdso: %v", err)
+ }
+
+ // Create timekeeper.
+ tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
+ if err != nil {
+ return nil, fmt.Errorf("creating timekeeper: %v", err)
+ }
+ tk.SetClocks(time.NewCalibratedClocks())
+
+ if err := enableStrace(args.Conf); err != nil {
+ return nil, fmt.Errorf("enabling strace: %v", err)
+ }
+
+ // Create root network namespace/stack.
+ netns, err := newRootNetworkNamespace(args.Conf, k, k)
+ if err != nil {
+ return nil, fmt.Errorf("creating network: %v", err)
+ }
+
+ // Create capabilities.
+ caps, err := specutils.Capabilities(args.Conf.EnableRaw, args.Spec.Process.Capabilities)
+ if err != nil {
+ return nil, fmt.Errorf("converting capabilities: %v", err)
+ }
+
+ // Convert the spec's additional GIDs to KGIDs.
+ extraKGIDs := make([]auth.KGID, 0, len(args.Spec.Process.User.AdditionalGids))
+ for _, GID := range args.Spec.Process.User.AdditionalGids {
+ extraKGIDs = append(extraKGIDs, auth.KGID(GID))
+ }
+
+ // Create credentials.
+ creds := auth.NewUserCredentials(
+ auth.KUID(args.Spec.Process.User.UID),
+ auth.KGID(args.Spec.Process.User.GID),
+ extraKGIDs,
+ caps,
+ auth.NewRootUserNamespace())
+
+ if args.NumCPU == 0 {
+ args.NumCPU = runtime.NumCPU()
+ }
+ log.Infof("CPUs: %d", args.NumCPU)
+
+ if args.TotalMem > 0 {
+ // Adjust the total memory returned by the Sentry so that applications that
+ // use /proc/meminfo can make allocations based on this limit.
+ usage.MinimumTotalMemoryBytes = args.TotalMem
+ log.Infof("Setting total memory to %.2f GB", float64(args.TotalMem)/(1<<30))
+ }
+
+ // Initialize the Kernel object, which is required by the Context passed
+ // to createVFS in order to mount (among other things) procfs.
+ if err = k.Init(kernel.InitKernelArgs{
+ FeatureSet: cpuid.HostFeatureSet(),
+ Timekeeper: tk,
+ RootUserNamespace: creds.UserNamespace,
+ RootNetworkNamespace: netns,
+ ApplicationCores: uint(args.NumCPU),
+ Vdso: vdso,
+ RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace),
+ RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
+ RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
+ }); err != nil {
+ return nil, fmt.Errorf("initializing kernel: %v", err)
+ }
+
+ if kernel.VFS2Enabled {
+ if err := registerFilesystems(k); err != nil {
+ return nil, fmt.Errorf("registering filesystems: %w", err)
+ }
+ }
+
+ if err := adjustDirentCache(k); err != nil {
+ return nil, err
+ }
+
+ // Turn on packet logging if enabled.
+ if args.Conf.LogPackets {
+ log.Infof("Packet logging enabled")
+ atomic.StoreUint32(&sniffer.LogPackets, 1)
+ } else {
+ log.Infof("Packet logging disabled")
+ atomic.StoreUint32(&sniffer.LogPackets, 0)
+ }
+
+ // Create a watchdog.
+ dogOpts := watchdog.DefaultOpts
+ dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction
+ dog := watchdog.New(k, dogOpts)
+
+ procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace())
+ if err != nil {
+ return nil, fmt.Errorf("creating init process for root container: %v", err)
+ }
+
+ if err := initCompatLogs(args.UserLogFD); err != nil {
+ return nil, fmt.Errorf("initializing compat logs: %v", err)
+ }
+
+ mountHints, err := newPodMountHints(args.Spec)
+ if err != nil {
+ return nil, fmt.Errorf("creating pod mount hints: %v", err)
+ }
+
+ if kernel.VFS2Enabled {
+ // Set up host mount that will be used for imported fds.
+ hostFilesystem, err := hostvfs2.NewFilesystem(k.VFS())
+ if err != nil {
+ return nil, fmt.Errorf("failed to create hostfs filesystem: %v", err)
+ }
+ defer hostFilesystem.DecRef()
+ hostMount, err := k.VFS().NewDisconnectedMount(hostFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create hostfs mount: %v", err)
+ }
+ k.SetHostMount(hostMount)
+ }
+
+ // Make host FDs stable between invocations. Host FDs must map to the exact
+ // same number when the sandbox is restored. Otherwise the wrong FD will be
+ // used.
+ var stdioFDs []int
+ newfd := startingStdioFD
+ for _, fd := range args.StdioFDs {
+ err := syscall.Dup3(fd, newfd, syscall.O_CLOEXEC)
+ if err != nil {
+ return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err)
+ }
+ stdioFDs = append(stdioFDs, newfd)
+ err = syscall.Close(fd)
+ if err != nil {
+ return nil, fmt.Errorf("close original stdioFDs failed: %v", err)
+ }
+ newfd++
+ }
+
+ eid := execID{cid: args.ID}
+ l := &Loader{
+ k: k,
+ conf: args.Conf,
+ console: args.Console,
+ watchdog: dog,
+ spec: args.Spec,
+ goferFDs: args.GoferFDs,
+ stdioFDs: stdioFDs,
+ rootProcArgs: procArgs,
+ sandboxID: args.ID,
+ processes: map[execID]*execProcess{eid: {}},
+ mountHints: mountHints,
+ }
+
+ // We don't care about child signals; some platforms can generate a
+ // tremendous number of useless ones (I'm looking at you, ptrace).
+ if err := sighandling.IgnoreChildStop(); err != nil {
+ return nil, fmt.Errorf("ignore child stop signals failed: %v", err)
+ }
+
+ // Create the control server using the provided FD.
+ //
+ // This must be done *after* we have initialized the kernel since the
+ // controller is used to configure the kernel's network stack.
+ ctrl, err := newController(args.ControllerFD, l)
+ if err != nil {
+ return nil, fmt.Errorf("creating control server: %v", err)
+ }
+ l.ctrl = ctrl
+
+ // Only start serving after the Loader and the controller have been wired
+ // to each other, because both are used in the urpc methods.
+ if err := ctrl.srv.StartServing(); err != nil {
+ return nil, fmt.Errorf("starting control server: %v", err)
+ }
+
+ return l, nil
+}
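+
+// A minimal usage sketch, assuming conf and spec were built by the caller and
+// that the FD numbers below were set up for us (all values are hypothetical;
+// in practice they are passed in by the runsc boot command):
+//
+//	l, err := New(Args{
+//		ID:           "sandbox-id",
+//		Spec:         spec,
+//		Conf:         conf,
+//		ControllerFD: 3,
+//		GoferFDs:     []int{4},
+//		StdioFDs:     []int{0, 1, 2},
+//	})
+//	if err != nil {
+//		// Handle the error. Note that Destroy should not be deferred;
+//		// see its doc comment below.
+//	}
+//	l.WaitForStartSignal()
+//	if err := l.Run(); err != nil {
+//		// ...
+//	}
+//	ws := l.WaitExit()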
+
+// newProcess creates a process that can be run with kernel.CreateProcess.
+func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) {
+ // Create initial limits.
+ ls, err := createLimitSet(spec)
+ if err != nil {
+ return kernel.CreateProcessArgs{}, fmt.Errorf("creating limits: %v", err)
+ }
+
+ wd := spec.Process.Cwd
+ if wd == "" {
+ wd = "/"
+ }
+
+ // Create the process arguments.
+ procArgs := kernel.CreateProcessArgs{
+ Argv: spec.Process.Args,
+ Envv: spec.Process.Env,
+ WorkingDirectory: wd,
+ Credentials: creds,
+ Umask: 0022,
+ Limits: ls,
+ MaxSymlinkTraversals: linux.MaxSymlinkTraversals,
+ UTSNamespace: k.RootUTSNamespace(),
+ IPCNamespace: k.RootIPCNamespace(),
+ AbstractSocketNamespace: k.RootAbstractSocketNamespace(),
+ ContainerID: id,
+ PIDNamespace: pidns,
+ }
+
+ return procArgs, nil
+}
+
+// Destroy cleans up all resources used by the loader.
+//
+// Note that this will block until all open control server connections have
+// been closed. For that reason, this should NOT be called in a defer, because
+// a panic in a control server rpc would then hang forever.
+func (l *Loader) Destroy() {
+ if l.ctrl != nil {
+ l.ctrl.srv.Stop()
+ }
+ if l.stopSignalForwarding != nil {
+ l.stopSignalForwarding()
+ }
+ l.watchdog.Stop()
+}
+
+func createPlatform(conf *Config, deviceFile *os.File) (platform.Platform, error) {
+ p, err := platform.Lookup(conf.Platform)
+ if err != nil {
+ panic(fmt.Sprintf("invalid platform %v: %v", conf.Platform, err))
+ }
+ log.Infof("Platform: %s", conf.Platform)
+ return p.New(deviceFile)
+}
+
+func createMemoryFile() (*pgalloc.MemoryFile, error) {
+ const memfileName = "runsc-memory"
+ memfd, err := memutil.CreateMemFD(memfileName, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating memfd: %v", err)
+ }
+ memfile := os.NewFile(uintptr(memfd), memfileName)
+ // We can't enable pgalloc.MemoryFileOpts.UseHostMemcgPressure even if
+ // there are memory cgroups specified, because at this point we're already
+ // in a mount namespace in which the relevant cgroupfs is not visible.
+ mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
+ if err != nil {
+ memfile.Close()
+ return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
+ }
+ return mf, nil
+}
+
+func (l *Loader) installSeccompFilters() error {
+ if l.conf.DisableSeccomp {
+ filter.Report("syscall filter is DISABLED. Running in less secure mode.")
+ } else {
+ opts := filter.Options{
+ Platform: l.k.Platform,
+ HostNetwork: l.conf.Network == NetworkHost,
+ ProfileEnable: l.conf.ProfileEnable,
+ ControllerFD: l.ctrl.srv.FD(),
+ }
+ if err := filter.Install(opts); err != nil {
+ return fmt.Errorf("installing seccomp filters: %v", err)
+ }
+ }
+ return nil
+}
+
+// Run runs the root container.
+func (l *Loader) Run() error {
+ err := l.run()
+ l.ctrl.manager.startResultChan <- err
+ if err != nil {
+ // Give the controller some time to send the error to the
+ // runtime. If we return too quickly here the process will exit
+ // and the control connection will be closed before the error
+ // is returned.
+ gtime.Sleep(2 * gtime.Second)
+ return err
+ }
+ return nil
+}
+
+func (l *Loader) run() error {
+ if l.conf.Network == NetworkHost {
+ // Delay host network configuration to this point because network namespace
+ // is configured after the loader is created and before Run() is called.
+ log.Debugf("Configuring host network")
+ stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack)
+ if err := stack.Configure(); err != nil {
+ return err
+ }
+ }
+
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ eid := execID{cid: l.sandboxID}
+ ep, ok := l.processes[eid]
+ if !ok {
+ return fmt.Errorf("trying to start deleted container %q", l.sandboxID)
+ }
+
+ // If we are restoring, we do not want to create a process.
+ // l.restore is set by the container manager when a restore call is made.
+ var ttyFile *host.TTYFileOperations
+ var ttyFileVFS2 *hostvfs2.TTYFileDescription
+ if !l.restore {
+ if l.conf.ProfileEnable {
+ pprof.Initialize()
+ }
+
+ // Finally done with all configuration. Set up filters before user code
+ // is loaded.
+ if err := l.installSeccompFilters(); err != nil {
+ return err
+ }
+
+ // Create the FD map, which will set stdin, stdout, and stderr. If console
+ // is true, then ioctl calls will be passed through to the host fd.
+ ctx := l.rootProcArgs.NewContext(l.k)
+ var err error
+
+ // CreateProcess takes a reference on the FDTable if successful. We won't
+ // need ours either way.
+ l.rootProcArgs.FDTable, ttyFile, ttyFileVFS2, err = createFDTable(ctx, l.console, l.stdioFDs)
+ if err != nil {
+ return fmt.Errorf("importing fds: %v", err)
+ }
+
+ // Setup the root container file system.
+ l.startGoferMonitor(l.sandboxID, l.goferFDs)
+
+ mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints)
+ if err := mntr.processHints(l.conf, l.rootProcArgs.Credentials); err != nil {
+ return err
+ }
+ if err := setupContainerFS(ctx, l.conf, mntr, &l.rootProcArgs); err != nil {
+ return err
+ }
+
+ // Add the HOME environment variable if it is not already set.
+ var envv []string
+ if kernel.VFS2Enabled {
+ envv, err = user.MaybeAddExecUserHomeVFS2(ctx, l.rootProcArgs.MountNamespaceVFS2,
+ l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+
+ } else {
+ envv, err = user.MaybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace,
+ l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ }
+ if err != nil {
+ return err
+ }
+ l.rootProcArgs.Envv = envv
+
+ // Create the root container init task. It will begin running
+ // when the kernel is started.
+ if _, _, err := l.k.CreateProcess(l.rootProcArgs); err != nil {
+ return fmt.Errorf("creating init process: %v", err)
+ }
+
+ // CreateProcess takes a reference on FDTable if successful.
+ l.rootProcArgs.FDTable.DecRef()
+ }
+
+ ep.tg = l.k.GlobalInit()
+ if ns, ok := specutils.GetNS(specs.PIDNamespace, l.spec); ok {
+ ep.pidnsPath = ns.Path
+ }
+ if l.console {
+ // Set the foreground process group on the TTY to the global init process
+ // group, since that is what we are about to start running.
+ switch {
+ case ttyFileVFS2 != nil:
+ ep.ttyVFS2 = ttyFileVFS2
+ ttyFileVFS2.InitForegroundProcessGroup(ep.tg.ProcessGroup())
+ case ttyFile != nil:
+ ep.tty = ttyFile
+ ttyFile.InitForegroundProcessGroup(ep.tg.ProcessGroup())
+ }
+ }
+
+ // Handle signals by forwarding them to the root container process
+ // (except for panic signal, which should cause a panic).
+ l.stopSignalForwarding = sighandling.StartSignalForwarding(func(sig linux.Signal) {
+ // Panic signal should cause a panic.
+ if l.conf.PanicSignal != -1 && sig == linux.Signal(l.conf.PanicSignal) {
+ panic("Signal-induced panic")
+ }
+
+ // Otherwise forward to root container.
+ deliveryMode := DeliverToProcess
+ if l.console {
+ // Since we are running with a console, we should forward the signal to
+ // the foreground process group so that job control signals like ^C can
+ // be handled properly.
+ deliveryMode = DeliverToForegroundProcessGroup
+ }
+ log.Infof("Received external signal %d, mode: %v", sig, deliveryMode)
+ if err := l.signal(l.sandboxID, 0, int32(sig), deliveryMode); err != nil {
+ log.Warningf("error sending signal %v to container %q: %v", sig, l.sandboxID, err)
+ }
+ })
+
+ // l.stdioFDs are derived from dup() in boot.New() and are now dup()ed again
+ // either in createFDTable() during initial start or in descriptor.initAfterLoad()
+ // during restore, so we can release l.stdioFDs now. VFS2 takes ownership of
+ // the passed FDs, so only close them for VFS1.
+ if !kernel.VFS2Enabled {
+ for _, fd := range l.stdioFDs {
+ err := syscall.Close(fd)
+ if err != nil {
+ return fmt.Errorf("close dup()ed stdioFDs: %v", err)
+ }
+ }
+ }
+
+ log.Infof("Process should have started...")
+ l.watchdog.Start()
+ return l.k.Start()
+}
+
+// createContainer creates a new container inside the sandbox.
+func (l *Loader) createContainer(cid string) error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ eid := execID{cid: cid}
+ if _, ok := l.processes[eid]; ok {
+ return fmt.Errorf("container %q already exists", cid)
+ }
+ l.processes[eid] = &execProcess{}
+ return nil
+}
+
+// startContainer starts a child container. Caller owns 'files' and may close
+// them after this method returns.
+func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, files []*os.File) error {
+ // Create capabilities.
+ caps, err := specutils.Capabilities(conf.EnableRaw, spec.Process.Capabilities)
+ if err != nil {
+ return fmt.Errorf("creating capabilities: %v", err)
+ }
+
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ eid := execID{cid: cid}
+ if _, ok := l.processes[eid]; !ok {
+ return fmt.Errorf("trying to start a deleted container %q", cid)
+ }
+
+ // Convert the spec's additional GIDs to KGIDs.
+ extraKGIDs := make([]auth.KGID, 0, len(spec.Process.User.AdditionalGids))
+ for _, GID := range spec.Process.User.AdditionalGids {
+ extraKGIDs = append(extraKGIDs, auth.KGID(GID))
+ }
+
+ // Create credentials. We reuse the root user namespace because the
+ // sentry currently supports only 1 mount namespace, which is tied to a
+ // single user namespace. Thus we must run in the same user namespace
+ // to access mounts.
+ creds := auth.NewUserCredentials(
+ auth.KUID(spec.Process.User.UID),
+ auth.KGID(spec.Process.User.GID),
+ extraKGIDs,
+ caps,
+ l.k.RootUserNamespace())
+
+ var pidns *kernel.PIDNamespace
+ if ns, ok := specutils.GetNS(specs.PIDNamespace, spec); ok {
+ if ns.Path != "" {
+ for _, p := range l.processes {
+ if ns.Path == p.pidnsPath {
+ pidns = p.tg.PIDNamespace()
+ break
+ }
+ }
+ }
+ if pidns == nil {
+ pidns = l.k.RootPIDNamespace().NewChild(l.k.RootUserNamespace())
+ }
+ l.processes[eid].pidnsPath = ns.Path
+ } else {
+ pidns = l.k.RootPIDNamespace()
+ }
+ procArgs, err := newProcess(cid, spec, creds, l.k, pidns)
+ if err != nil {
+ return fmt.Errorf("creating new process: %v", err)
+ }
+
+ // setupContainerFS() dups stdioFDs, so we don't need to dup them here.
+ var stdioFDs []int
+ for _, f := range files[:3] {
+ stdioFDs = append(stdioFDs, int(f.Fd()))
+ }
+
+ // Create the FD map, which will set stdin, stdout, and stderr.
+ ctx := procArgs.NewContext(l.k)
+ fdTable, _, _, err := createFDTable(ctx, false, stdioFDs)
+ if err != nil {
+ return fmt.Errorf("importing fds: %v", err)
+ }
+ // CreateProcess takes a reference on fdTable if successful. We won't
+ // need ours either way.
+ procArgs.FDTable = fdTable
+
+ // Can't take ownership away from os.File, so dup them to get new FDs.
+ var goferFDs []int
+ for _, f := range files[3:] {
+ fd, err := syscall.Dup(int(f.Fd()))
+ if err != nil {
+ return fmt.Errorf("failed to dup file: %v", err)
+ }
+ goferFDs = append(goferFDs, fd)
+ }
+
+ // Setup the child container file system.
+ l.startGoferMonitor(cid, goferFDs)
+
+ mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints)
+ if err := setupContainerFS(ctx, conf, mntr, &procArgs); err != nil {
+ return err
+ }
+
+ // Add the HOME environment variable if it is not already set.
+ var envv []string
+ if kernel.VFS2Enabled {
+ envv, err = user.MaybeAddExecUserHomeVFS2(ctx, procArgs.MountNamespaceVFS2,
+ procArgs.Credentials.RealKUID, procArgs.Envv)
+
+ } else {
+ envv, err = user.MaybeAddExecUserHome(ctx, procArgs.MountNamespace,
+ procArgs.Credentials.RealKUID, procArgs.Envv)
+ }
+ if err != nil {
+ return err
+ }
+ procArgs.Envv = envv
+
+ // Create and start the new process.
+ tg, _, err := l.k.CreateProcess(procArgs)
+ if err != nil {
+ return fmt.Errorf("creating process: %v", err)
+ }
+ l.k.StartProcess(tg)
+
+ // CreateProcess takes a reference on FDTable if successful.
+ procArgs.FDTable.DecRef()
+
+ l.processes[eid].tg = tg
+ return nil
+}
+
+// startGoferMonitor runs a goroutine to monitor gofer's health. It polls on
+// the gofer FDs looking for disconnects, and destroys the container if a
+// disconnect occurs in any of the gofer FDs.
+func (l *Loader) startGoferMonitor(cid string, goferFDs []int) {
+ go func() {
+ log.Debugf("Monitoring gofer health for container %q", cid)
+ var events []unix.PollFd
+ for _, fd := range goferFDs {
+ events = append(events, unix.PollFd{
+ Fd: int32(fd),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ })
+ }
+ _, _, err := specutils.RetryEintr(func() (uintptr, uintptr, error) {
+ // Use ppoll instead of poll because it's already whitelisted in seccomp.
+ n, err := unix.Ppoll(events, nil, nil)
+ return uintptr(n), 0, err
+ })
+ if err != nil {
+ panic(fmt.Sprintf("Error monitoring gofer FDs: %v", err))
+ }
+
+ // Check if the gofer has stopped as part of normal container destruction.
+ // This is done just to avoid sending an annoying error message to the log.
+ // Note that there is a small race window in between mu.Unlock() and the
+ // lock being reacquired in destroyContainer(), but it's harmless to call
+ // destroyContainer() multiple times.
+ l.mu.Lock()
+ _, ok := l.processes[execID{cid: cid}]
+ l.mu.Unlock()
+ if ok {
+ log.Infof("Gofer socket disconnected, destroying container %q", cid)
+ if err := l.destroyContainer(cid); err != nil {
+ log.Warningf("Error destroying container %q after gofer stopped: %v", cid, err)
+ }
+ }
+ }()
+}
+
+// destroyContainer stops a container if it is still running and cleans up its
+// filesystem.
+func (l *Loader) destroyContainer(cid string) error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ tg, err := l.tryThreadGroupFromIDLocked(execID{cid: cid})
+ if err != nil {
+ // Container doesn't exist.
+ return err
+ }
+
+ // The container exists, but has it been started?
+ if tg != nil {
+ if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil {
+ return fmt.Errorf("sending SIGKILL to all container processes: %v", err)
+ }
+ // Wait for all processes that belong to the container to exit (including
+ // exec'd processes).
+ for _, t := range l.k.TaskSet().Root.Tasks() {
+ if t.ContainerID() == cid {
+ t.ThreadGroup().WaitExited()
+ }
+ }
+
+ // At this point, all processes inside of the container have exited,
+ // releasing all references to the container's MountNamespace and
+ // causing all submounts and overlays to be unmounted.
+ //
+ // Since the container's MountNamespace has been released,
+ // MountNamespace.destroy() will have executed, but that function may
+ // trigger async close operations. We must wait for those to complete
+ // before returning, otherwise the caller may kill the gofer before
+ // they complete, causing a cascade of failing RPCs.
+ fs.AsyncBarrier()
+ }
+
+ // No more failure from this point on. Remove all container thread groups
+ // from the map.
+ for key := range l.processes {
+ if key.cid == cid {
+ delete(l.processes, key)
+ }
+ }
+
+ log.Debugf("Container destroyed %q", cid)
+ return nil
+}
+
+func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
+ // Hold the lock for the entire operation to ensure that exec'd process is
+ // added to 'processes' in case it races with destroyContainer().
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ tg, err := l.tryThreadGroupFromIDLocked(execID{cid: args.ContainerID})
+ if err != nil {
+ return 0, err
+ }
+ if tg == nil {
+ return 0, fmt.Errorf("container %q not started", args.ContainerID)
+ }
+
+ // Get the container MountNamespace from the Task.
+ if kernel.VFS2Enabled {
+ // task.MountNamespace() does not take a ref, so we must do so ourselves.
+ args.MountNamespaceVFS2 = tg.Leader().MountNamespaceVFS2()
+ args.MountNamespaceVFS2.IncRef()
+ } else {
+ tg.Leader().WithMuLocked(func(t *kernel.Task) {
+ // task.MountNamespace() does not take a ref, so we must do so ourselves.
+ args.MountNamespace = t.MountNamespace()
+ args.MountNamespace.IncRef()
+ })
+ }
+
+ // Add the HOME environment variable if it is not already set.
+ if kernel.VFS2Enabled {
+ defer args.MountNamespaceVFS2.DecRef()
+
+ root := args.MountNamespaceVFS2.Root()
+ defer root.DecRef()
+ ctx := vfs.WithRoot(l.k.SupervisorContext(), root)
+ envv, err := user.MaybeAddExecUserHomeVFS2(ctx, args.MountNamespaceVFS2, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+ } else {
+ defer args.MountNamespace.DecRef()
+
+ root := args.MountNamespace.Root()
+ defer root.DecRef()
+ ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ envv, err := user.MaybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+ }
+
+ // Start the process.
+ proc := control.Proc{Kernel: l.k}
+ args.PIDNamespace = tg.PIDNamespace()
+ newTG, tgid, ttyFile, ttyFileVFS2, err := control.ExecAsync(&proc, args)
+ if err != nil {
+ return 0, err
+ }
+
+ eid := execID{cid: args.ContainerID, pid: tgid}
+ l.processes[eid] = &execProcess{
+ tg: newTG,
+ tty: ttyFile,
+ ttyVFS2: ttyFileVFS2,
+ }
+ log.Debugf("updated processes: %v", l.processes)
+
+ return tgid, nil
+}
+
+// waitContainer waits for the init process of a container to exit.
+func (l *Loader) waitContainer(cid string, waitStatus *uint32) error {
+ // Don't hold l.mu while waiting, as doing so would make it impossible
+ // for multiple clients to wait on the same container.
+ tg, err := l.threadGroupFromID(execID{cid: cid})
+ if err != nil {
+ return fmt.Errorf("can't wait for container %q: %v", cid, err)
+ }
+
+ // If the thread either has already exited or exits during waiting,
+ // consider the container exited.
+ ws := l.wait(tg)
+ *waitStatus = ws
+ return nil
+}
+
+func (l *Loader) waitPID(tgid kernel.ThreadID, cid string, waitStatus *uint32) error {
+ if tgid <= 0 {
+ return fmt.Errorf("PID (%d) must be positive", tgid)
+ }
+
+ // Try to find a process that was exec'd
+ eid := execID{cid: cid, pid: tgid}
+ execTG, err := l.threadGroupFromID(eid)
+ if err == nil {
+ ws := l.wait(execTG)
+ *waitStatus = ws
+
+ l.mu.Lock()
+ delete(l.processes, eid)
+ log.Debugf("updated processes (removal): %v", l.processes)
+ l.mu.Unlock()
+ return nil
+ }
+
+ // The caller may be waiting on a process not started directly via exec.
+ // In this case, find the process in the container's PID namespace.
+ initTG, err := l.threadGroupFromID(execID{cid: cid})
+ if err != nil {
+ return fmt.Errorf("waiting for PID %d: %v", tgid, err)
+ }
+ tg := initTG.PIDNamespace().ThreadGroupWithID(tgid)
+ if tg == nil {
+ return fmt.Errorf("waiting for PID %d: no such process", tgid)
+ }
+ if tg.Leader().ContainerID() != cid {
+ return fmt.Errorf("process %d is part of a different container: %q", tgid, tg.Leader().ContainerID())
+ }
+ ws := l.wait(tg)
+ *waitStatus = ws
+ return nil
+}
+
+// wait waits for the given thread group to exit and returns its exit status.
+func (l *Loader) wait(tg *kernel.ThreadGroup) uint32 {
+ tg.WaitExited()
+ return tg.ExitStatus().Status()
+}
+
+// WaitForStartSignal waits for a start signal from the control server.
+func (l *Loader) WaitForStartSignal() {
+ <-l.ctrl.manager.startChan
+}
+
+// WaitExit waits for the root container to exit, and returns its exit status.
+func (l *Loader) WaitExit() kernel.ExitStatus {
+ // Wait for container.
+ l.k.WaitExited()
+
+ return l.k.GlobalInit().ExitStatus()
+}
+
+func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) {
+ // Create an empty network stack because the network namespace may be empty at
+ // this point. Netns is configured before Run() is called. Netstack is
+ // configured using a control uRPC message. Host network is configured inside
+ // Run().
+ switch conf.Network {
+ case NetworkHost:
+ // No network namespacing support for hostinet yet, hence creator is nil.
+ return inet.NewRootNamespace(hostinet.NewStack(), nil), nil
+
+ case NetworkNone, NetworkSandbox:
+ s, err := newEmptySandboxNetworkStack(clock, uniqueID)
+ if err != nil {
+ return nil, err
+ }
+ creator := &sandboxNetstackCreator{
+ clock: clock,
+ uniqueID: uniqueID,
+ }
+ return inet.NewRootNamespace(s, creator), nil
+
+ default:
+ panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ }
+
+}
+
+func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := netstack.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: netstack.Metrics,
+ HandleLocal: true,
+ // Enable raw sockets for users with sufficient
+ // privileges.
+ RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
+ })}
+
+ // Enable SACK Recovery.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
+ return nil, fmt.Errorf("failed to enable SACK: %s", err)
+ }
+
+ // Set default TTLs as required by socket/netstack.
+ s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ return &s, nil
+}
+
+// sandboxNetstackCreator implements kernel.NetworkStackCreator.
+//
+// +stateify savable
+type sandboxNetstackCreator struct {
+ clock tcpip.Clock
+ uniqueID stack.UniqueID
+}
+
+// CreateStack implements kernel.NetworkStackCreator.CreateStack.
+func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) {
+ s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Setup loopback.
+ n := &Network{Stack: s.(*netstack.Stack).Stack}
+ nicID := tcpip.NICID(f.uniqueID.UniqueID())
+ link := DefaultLoopbackLink
+ linkEP := loopback.New()
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return nil, err
+ }
+
+ return s, nil
+}
+
+// signal sends a signal to one or more processes in a container. If PID is 0,
+// then the container init process is used. Depending on the SignalDeliveryMode
+// option, the signal may be sent directly to the indicated process, to all
+// processes in the container, or to the foreground process group.
+func (l *Loader) signal(cid string, pid, signo int32, mode SignalDeliveryMode) error {
+ if pid < 0 {
+ return fmt.Errorf("PID (%d) must be positive", pid)
+ }
+
+ switch mode {
+ case DeliverToProcess:
+ if err := l.signalProcess(cid, kernel.ThreadID(pid), signo); err != nil {
+ return fmt.Errorf("signaling process in container %q PID %d: %v", cid, pid, err)
+ }
+ return nil
+
+ case DeliverToForegroundProcessGroup:
+ if err := l.signalForegroundProcessGroup(cid, kernel.ThreadID(pid), signo); err != nil {
+ return fmt.Errorf("signaling foreground process group in container %q PID %d: %v", cid, pid, err)
+ }
+ return nil
+
+ case DeliverToAllProcesses:
+ if pid != 0 {
+ return fmt.Errorf("PID (%d) cannot be set when signaling all processes", pid)
+ }
+ // Check that the container has actually started before signaling it.
+ if _, err := l.threadGroupFromID(execID{cid: cid}); err != nil {
+ return err
+ }
+ if err := l.signalAllProcesses(cid, signo); err != nil {
+ return fmt.Errorf("signaling all processes in container %q: %v", cid, err)
+ }
+ return nil
+
+ default:
+ panic(fmt.Sprintf("unknown signal delivery mode %v", mode))
+ }
+}
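+
+// For instance (container ID hypothetical), the container manager can deliver
+// an external SIGTERM to every process in container "abc" with:
+//
+//	err := l.signal("abc", 0 /* pid */, int32(linux.SIGTERM), DeliverToAllProcesses)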
+
+func (l *Loader) signalProcess(cid string, tgid kernel.ThreadID, signo int32) error {
+ execTG, err := l.threadGroupFromID(execID{cid: cid, pid: tgid})
+ if err == nil {
+ // Send signal directly to the identified process.
+ return l.k.SendExternalSignalThreadGroup(execTG, &arch.SignalInfo{Signo: signo})
+ }
+
+ // The caller may be signaling a process not started directly via exec.
+ // In this case, find the process in the container's PID namespace and
+ // signal it.
+ initTG, err := l.threadGroupFromID(execID{cid: cid})
+ if err != nil {
+ return fmt.Errorf("no thread group found: %v", err)
+ }
+ tg := initTG.PIDNamespace().ThreadGroupWithID(tgid)
+ if tg == nil {
+ return fmt.Errorf("no such process with PID %d", tgid)
+ }
+ if tg.Leader().ContainerID() != cid {
+ return fmt.Errorf("process %d is part of a different container: %q", tgid, tg.Leader().ContainerID())
+ }
+ return l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo})
+}
+
+// signalForegroundProcessGroup looks up the foreground process group from the
+// TTY for the given "tgid" inside container "cid", and sends the signal to it.
+func (l *Loader) signalForegroundProcessGroup(cid string, tgid kernel.ThreadID, signo int32) error {
+ l.mu.Lock()
+ tg, err := l.tryThreadGroupFromIDLocked(execID{cid: cid, pid: tgid})
+ if err != nil {
+ l.mu.Unlock()
+ return fmt.Errorf("no thread group found: %v", err)
+ }
+ if tg == nil {
+ l.mu.Unlock()
+ return fmt.Errorf("container %q not started", cid)
+ }
+
+ tty, ttyVFS2, err := l.ttyFromIDLocked(execID{cid: cid, pid: tgid})
+ l.mu.Unlock()
+ if err != nil {
+ return fmt.Errorf("no thread group found: %v", err)
+ }
+
+ var pg *kernel.ProcessGroup
+ switch {
+ case ttyVFS2 != nil:
+ pg = ttyVFS2.ForegroundProcessGroup()
+ case tty != nil:
+ pg = tty.ForegroundProcessGroup()
+ default:
+ return fmt.Errorf("no TTY attached")
+ }
+ if pg == nil {
+ // No foreground process group has been set. Signal the
+ // original thread group.
+ log.Warningf("No foreground process group for container %q and PID %d. Sending signal directly to PID %d.", cid, tgid, tgid)
+ return l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo})
+ }
+ // Send the signal to all processes in the process group.
+ var lastErr error
+ for _, tg := range l.k.TaskSet().Root.ThreadGroups() {
+ if tg.ProcessGroup() != pg {
+ continue
+ }
+ if err := l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo}); err != nil {
+ lastErr = err
+ }
+ }
+ return lastErr
+}
+
+// signalAllProcesses signals all processes that belong to the specified
+// container. It's a noop if the container hasn't started or has exited.
+func (l *Loader) signalAllProcesses(cid string, signo int32) error {
+ // Pause the kernel to prevent new processes from being created while
+ // the signal is delivered. This prevents process leaks when SIGKILL is
+ // sent to the entire container.
+ l.k.Pause()
+ defer l.k.Unpause()
+ return l.k.SendContainerSignal(cid, &arch.SignalInfo{Signo: signo})
+}
+
+// threadGroupFromID is similar to tryThreadGroupFromIDLocked except that it
+// acquires the mutex before calling it and fails if the container hasn't
+// started yet.
+func (l *Loader) threadGroupFromID(key execID) (*kernel.ThreadGroup, error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ tg, err := l.tryThreadGroupFromIDLocked(key)
+ if err != nil {
+ return nil, err
+ }
+ if tg == nil {
+ return nil, fmt.Errorf("container %q not started", key.cid)
+ }
+ return tg, nil
+}
+
+// tryThreadGroupFromIDLocked returns the thread group for the given execution
+// ID. It may return nil in case the container has not started yet. Returns
+// error if execution ID is invalid or if the container cannot be found (maybe
+// it has been deleted). Caller must hold 'mu'.
+func (l *Loader) tryThreadGroupFromIDLocked(key execID) (*kernel.ThreadGroup, error) {
+ ep := l.processes[key]
+ if ep == nil {
+ return nil, fmt.Errorf("container %q not found", key.cid)
+ }
+ return ep.tg, nil
+}
+
+// ttyFromIDLocked returns the TTY files for the given execution ID. It may
+// return nil in case the container has not started yet. Returns error if
+// execution ID is invalid or if the container cannot be found (maybe it has
+// been deleted). Caller must hold 'mu'.
+func (l *Loader) ttyFromIDLocked(key execID) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ ep := l.processes[key]
+ if ep == nil {
+ return nil, nil, fmt.Errorf("container %q not found", key.cid)
+ }
+ return ep.tty, ep.ttyVFS2, nil
+}
+
+func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ if len(stdioFDs) != 3 {
+ return nil, nil, nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs))
+ }
+
+ k := kernel.KernelFromContext(ctx)
+ fdTable := k.NewFDTable()
+ ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, console, stdioFDs)
+ if err != nil {
+ fdTable.DecRef()
+ return nil, nil, nil, err
+ }
+ return fdTable, ttyFile, ttyFileVFS2, nil
+}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
new file mode 100644
index 000000000..b723e4335
--- /dev/null
+++ b/runsc/boot/loader_test.go
@@ -0,0 +1,715 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "math/rand"
+ "os"
+ "reflect"
+ "syscall"
+ "testing"
+ "time"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/runsc/fsgofer"
+)
+
+func init() {
+ log.SetLevel(log.Debug)
+ rand.Seed(time.Now().UnixNano())
+ if err := fsgofer.OpenProcSelfFD(); err != nil {
+ panic(err)
+ }
+}
+
+func testConfig() *Config {
+ return &Config{
+ RootDir: "unused_root_dir",
+ Network: NetworkNone,
+ DisableSeccomp: true,
+ Platform: "ptrace",
+ }
+}
+
+// testSpec returns a simple spec that can be used in tests.
+func testSpec() *specs.Spec {
+ return &specs.Spec{
+ // The host filesystem root is the sandbox root.
+ Root: &specs.Root{
+ Path: "/",
+ Readonly: true,
+ },
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ }
+}
+
+// startGofer starts a new gofer routine serving 'root' path. It returns the
+// sandbox side of the connection, and a function that, when called, stops the
+// gofer.
+func startGofer(root string) (int, func(), error) {
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return 0, nil, err
+ }
+ sandboxEnd, goferEnd := fds[0], fds[1]
+
+ socket, err := unet.NewSocket(goferEnd)
+ if err != nil {
+ syscall.Close(sandboxEnd)
+ syscall.Close(goferEnd)
+ return 0, nil, fmt.Errorf("error creating server on FD %d: %v", goferEnd, err)
+ }
+ at, err := fsgofer.NewAttachPoint(root, fsgofer.Config{ROMount: true})
+ if err != nil {
+ return 0, nil, err
+ }
+ go func() {
+ s := p9.NewServer(at)
+ if err := s.Handle(socket); err != nil {
+ log.Infof("Gofer is stopping. FD: %d, err: %v\n", goferEnd, err)
+ }
+ }()
+ // Closing the gofer socket will stop the gofer and exit the goroutine above.
+ cleanup := func() {
+ if err := socket.Close(); err != nil {
+ log.Warningf("Error closing gofer socket: %v", err)
+ }
+ }
+ return sandboxEnd, cleanup, nil
+}
+
+func createLoader(vfsEnabled bool, spec *specs.Spec) (*Loader, func(), error) {
+ fd, err := server.CreateSocket(ControlSocketAddr(fmt.Sprintf("%010d", rand.Int())[:10]))
+ if err != nil {
+ return nil, nil, err
+ }
+ conf := testConfig()
+ conf.VFS2 = vfsEnabled
+
+ sandEnd, cleanup, err := startGofer(spec.Root.Path)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Loader takes ownership of stdio.
+ var stdio []int
+ for _, f := range []*os.File{os.Stdin, os.Stdout, os.Stderr} {
+ newFd, err := unix.Dup(int(f.Fd()))
+ if err != nil {
+ return nil, nil, err
+ }
+ stdio = append(stdio, newFd)
+ }
+
+ args := Args{
+ ID: "foo",
+ Spec: spec,
+ Conf: conf,
+ ControllerFD: fd,
+ GoferFDs: []int{sandEnd},
+ StdioFDs: stdio,
+ }
+ l, err := New(args)
+ if err != nil {
+ cleanup()
+ return nil, nil, err
+ }
+ return l, cleanup, nil
+}
+
+// TestRun runs a simple application in a sandbox and checks that it succeeds.
+func TestRun(t *testing.T) {
+ doRun(t, false)
+}
+
+// TestRunVFS2 runs TestRun in VFSv2.
+func TestRunVFS2(t *testing.T) {
+ doRun(t, true)
+}
+
+func doRun(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled, testSpec())
+ if err != nil {
+ t.Fatalf("error creating loader: %v", err)
+ }
+
+ defer l.Destroy()
+ defer cleanup()
+
+ // Start a goroutine to read the start chan result, otherwise Run will
+ // block forever.
+ var resultChanErr error
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ resultChanErr = <-l.ctrl.manager.startResultChan
+ wg.Done()
+ }()
+
+ // Run the container.
+ if err := l.Run(); err != nil {
+ t.Errorf("error running container: %v", err)
+ }
+
+ // We should not have gotten an error on the startResultChan.
+ wg.Wait()
+ if resultChanErr != nil {
+ t.Errorf("error on startResultChan: %v", resultChanErr)
+ }
+
+ // Wait for the application to exit. It should succeed.
+ if status := l.WaitExit(); status.Code != 0 || status.Signo != 0 {
+ t.Errorf("application exited with status %+v, want 0", status)
+ }
+}
+
+// TestStartSignal tests that the controller Start message will cause
+// WaitForStartSignal to return.
+func TestStartSignal(t *testing.T) {
+ doStartSignal(t, false)
+}
+
+// TestStartSignalVFS2 does TestStartSignal with VFS2.
+func TestStartSignalVFS2(t *testing.T) {
+ doStartSignal(t, true)
+}
+
+func doStartSignal(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled, testSpec())
+ if err != nil {
+ t.Fatalf("error creating loader: %v", err)
+ }
+ defer l.Destroy()
+ defer cleanup()
+
+ // We aren't going to wait on this application, so the control server
+ // needs to be shut down manually.
+ defer l.ctrl.srv.Stop()
+
+ // Start a goroutine that calls WaitForStartSignal and writes to a
+ // channel when it returns.
+ waitFinished := make(chan struct{})
+ go func() {
+ l.WaitForStartSignal()
+ // Pretend that Run() executed and returned no error.
+ l.ctrl.manager.startResultChan <- nil
+ waitFinished <- struct{}{}
+ }()
+
+ // Nothing has been written to the channel, so waitFinished should not
+ // return. Give it a little bit of time to make sure the goroutine has
+ // started.
+ select {
+ case <-waitFinished:
+ t.Errorf("WaitForStartSignal completed but it should not have")
+ case <-time.After(50 * time.Millisecond):
+ // OK.
+ }
+
+ // Trigger the control server StartRoot method.
+ cid := "foo"
+ if err := l.ctrl.manager.StartRoot(&cid, nil); err != nil {
+ t.Errorf("error calling StartRoot: %v", err)
+ }
+
+ // Now WaitForStartSignal should return (within a short amount of
+ // time).
+ select {
+ case <-waitFinished:
+ // OK.
+ case <-time.After(50 * time.Millisecond):
+ t.Errorf("WaitForStartSignal did not complete but it should have")
+ }
+
+}
+
+type CreateMountTestcase struct {
+ name string
+ // Spec that will be used to create the mount manager. Note
+ // that we can't mount procfs without a kernel, so each spec
+ // MUST contain something other than procfs mounted at /proc.
+ spec specs.Spec
+ // Paths that are expected to exist in the resulting fs.
+ expectedPaths []string
+}
+
+func createMountTestcases(vfs2 bool) []*CreateMountTestcase {
+ testCases := []*CreateMountTestcase{
+ &CreateMountTestcase{
+ // Only proc.
+ name: "only proc mount",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ },
+ },
+ // /proc, /dev, and /sys should always be mounted.
+ expectedPaths: []string{"/proc", "/dev", "/sys"},
+ },
+ {
+ // Mount at a deep path, with many components that do
+ // not exist in the root.
+ name: "deep mount path",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/some/very/very/deep/path",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ },
+ },
+ // /some/very/very/deep/path should be mounted, along with
+ // /proc, /dev, and /sys.
+ expectedPaths: []string{"/some/very/very/deep/path", "/proc", "/dev", "/sys"},
+ },
+ {
+ // Mounts are nested inside each other.
+ name: "nested mounts",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/foo",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/foo/qux",
+ Type: "tmpfs",
+ },
+ {
+ // File mounts with the same prefix.
+ Destination: "/foo/qux-quz",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/foo/bar",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/foo/bar/baz",
+ Type: "tmpfs",
+ },
+ {
+ // A deep path that is in foo but not the other mounts.
+ Destination: "/foo/some/very/very/deep/path",
+ Type: "tmpfs",
+ },
+ },
+ },
+ expectedPaths: []string{"/foo", "/foo/bar", "/foo/bar/baz", "/foo/qux",
+ "/foo/qux-quz", "/foo/some/very/very/deep/path", "/proc", "/dev", "/sys"},
+ },
+ {
+ name: "mount inside /dev",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/dev",
+ Type: "tmpfs",
+ },
+ {
+ // Mounted by runsc by default.
+ Destination: "/dev/fd",
+ Type: "tmpfs",
+ },
+ {
+ // Mount with the same prefix.
+ Destination: "/dev/fd-foo",
+ Type: "tmpfs",
+ },
+ {
+ // Unsupported fs type.
+ Destination: "/dev/mqueue",
+ Type: "mqueue",
+ },
+ {
+ Destination: "/dev/foo",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/dev/bar",
+ Type: "tmpfs",
+ },
+ },
+ },
+ expectedPaths: []string{"/proc", "/dev", "/dev/fd-foo", "/dev/foo", "/dev/bar", "/sys"},
+ },
+ }
+
+ vfsCase := &CreateMountTestcase{
+ name: "mounts inside mandatory mounts",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ // TODO (gvisor.dev/issue/1487): Re-add this case when sysfs supports
+ // MkDirAt in VFS2 (and remove the redundant append).
+ // {
+ // Destination: "/sys/bar",
+ // Type: "tmpfs",
+ // },
+ //
+ {
+ Destination: "/tmp/baz",
+ Type: "tmpfs",
+ },
+ },
+ },
+ expectedPaths: []string{"/proc", "/sys" /* "/sys/bar" ,*/, "/tmp", "/tmp/baz"},
+ }
+
+ if !vfs2 {
+ vfsCase.spec.Mounts = append(vfsCase.spec.Mounts, specs.Mount{Destination: "/sys/bar", Type: "tmpfs"})
+ vfsCase.expectedPaths = append(vfsCase.expectedPaths, "/sys/bar")
+ }
+ return append(testCases, vfsCase)
+}
+
+// TestCreateMountNamespace tests that a MountNamespace can be created with
+// various specs.
+func TestCreateMountNamespace(t *testing.T) {
+ for _, tc := range createMountTestcases(false /* vfs2 */) {
+ t.Run(tc.name, func(t *testing.T) {
+ conf := testConfig()
+ ctx := contexttest.Context(t)
+
+ sandEnd, cleanup, err := startGofer(tc.spec.Root.Path)
+ if err != nil {
+ t.Fatalf("failed to create gofer: %v", err)
+ }
+ defer cleanup()
+
+ mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{})
+ mns, err := mntr.createMountNamespace(ctx, conf)
+ if err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
+ }
+ ctx = fs.WithRoot(ctx, mns.Root())
+ if err := mntr.mountSubmounts(ctx, conf, mns); err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
+ }
+
+ root := mns.Root()
+ defer root.DecRef()
+ for _, p := range tc.expectedPaths {
+ maxTraversals := uint(0)
+ if d, err := mns.FindInode(ctx, root, root, p, &maxTraversals); err != nil {
+ t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err)
+ } else {
+ d.DecRef()
+ }
+ }
+ })
+ }
+}
+
+// TestCreateMountNamespaceVFS2 tests that a MountNamespace can be created with
+// various specs in VFS2.
+func TestCreateMountNamespaceVFS2(t *testing.T) {
+ for _, tc := range createMountTestcases(true /* vfs2 */) {
+ t.Run(tc.name, func(t *testing.T) {
+ spec := testSpec()
+ spec.Mounts = tc.spec.Mounts
+ spec.Root = tc.spec.Root
+
+ t.Logf("Using root: %q", spec.Root.Path)
+ l, loaderCleanup, err := createLoader(true /* VFS2 Enabled */, spec)
+ if err != nil {
+ t.Fatalf("failed to create loader: %v", err)
+ }
+ defer l.Destroy()
+ defer loaderCleanup()
+
+ mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints)
+ if err := mntr.processHints(l.conf, l.rootProcArgs.Credentials); err != nil {
+ t.Fatalf("failed process hints: %v", err)
+ }
+
+ ctx := l.k.SupervisorContext()
+ mns, err := mntr.setupVFS2(ctx, l.conf, &l.rootProcArgs)
+ if err != nil {
+ t.Fatalf("failed to setupVFS2: %v", err)
+ }
+
+ root := mns.Root()
+ defer root.DecRef()
+ for _, p := range tc.expectedPaths {
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(p),
+ }
+
+ if d, err := l.k.VFS().GetDentryAt(ctx, l.rootProcArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil {
+ t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err)
+ } else {
+ d.DecRef()
+ }
+ }
+ })
+ }
+}
+
+// TestRestoreEnvironment tests that the correct mounts are collected from the spec and config
+// in order to build the environment for restoring.
+func TestRestoreEnvironment(t *testing.T) {
+ testCases := []struct {
+ name string
+ spec *specs.Spec
+ ioFDs []int
+ errorExpected bool
+ expectedRenv fs.RestoreEnvironment
+ }{
+ {
+ name: "basic spec test",
+ spec: &specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/some/very/very/deep/path",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ },
+ },
+ ioFDs: []int{0},
+ errorExpected: false,
+ expectedRenv: fs.RestoreEnvironment{
+ MountSources: map[string][]fs.MountArgs{
+ "9p": {
+ {
+ Dev: "9pfs-/",
+ Flags: fs.MountSourceFlags{ReadOnly: true},
+ DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating",
+ },
+ },
+ "tmpfs": {
+ {
+ Dev: "none",
+ },
+ {
+ Dev: "none",
+ },
+ {
+ Dev: "none",
+ },
+ },
+ "devtmpfs": {
+ {
+ Dev: "none",
+ },
+ },
+ "devpts": {
+ {
+ Dev: "none",
+ },
+ },
+ "sysfs": {
+ {
+ Dev: "none",
+ },
+ },
+ },
+ },
+ },
+ {
+ name: "bind type test",
+ spec: &specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/dev/fd-foo",
+ Type: "bind",
+ },
+ },
+ },
+ ioFDs: []int{0, 1},
+ errorExpected: false,
+ expectedRenv: fs.RestoreEnvironment{
+ MountSources: map[string][]fs.MountArgs{
+ "9p": {
+ {
+ Dev: "9pfs-/",
+ Flags: fs.MountSourceFlags{ReadOnly: true},
+ DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating",
+ },
+ {
+ Dev: "9pfs-/dev/fd-foo",
+ DataString: "trans=fd,rfdno=1,wfdno=1,privateunixsocket=true,cache=remote_revalidating",
+ },
+ },
+ "tmpfs": {
+ {
+ Dev: "none",
+ },
+ },
+ "devtmpfs": {
+ {
+ Dev: "none",
+ },
+ },
+ "devpts": {
+ {
+ Dev: "none",
+ },
+ },
+ "proc": {
+ {
+ Dev: "none",
+ },
+ },
+ "sysfs": {
+ {
+ Dev: "none",
+ },
+ },
+ },
+ },
+ },
+ {
+ name: "options test",
+ spec: &specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/dev/fd-foo",
+ Type: "tmpfs",
+ Options: []string{"uid=1022", "noatime"},
+ },
+ },
+ },
+ ioFDs: []int{0},
+ errorExpected: false,
+ expectedRenv: fs.RestoreEnvironment{
+ MountSources: map[string][]fs.MountArgs{
+ "9p": {
+ {
+ Dev: "9pfs-/",
+ Flags: fs.MountSourceFlags{ReadOnly: true},
+ DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating",
+ },
+ },
+ "tmpfs": {
+ {
+ Dev: "none",
+ Flags: fs.MountSourceFlags{NoAtime: true},
+ DataString: "uid=1022",
+ },
+ {
+ Dev: "none",
+ },
+ },
+ "devtmpfs": {
+ {
+ Dev: "none",
+ },
+ },
+ "devpts": {
+ {
+ Dev: "none",
+ },
+ },
+ "proc": {
+ {
+ Dev: "none",
+ },
+ },
+ "sysfs": {
+ {
+ Dev: "none",
+ },
+ },
+ },
+ },
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ conf := testConfig()
+ mntr := newContainerMounter(tc.spec, tc.ioFDs, nil, &podMountHints{})
+ actualRenv, err := mntr.createRestoreEnvironment(conf)
+ if !tc.errorExpected && err != nil {
+ t.Fatalf("could not create restore environment for test:%s", tc.name)
+ } else if tc.errorExpected {
+ if err == nil {
+ t.Errorf("expected an error, but no error occurred.")
+ }
+ } else {
+ if !reflect.DeepEqual(*actualRenv, tc.expectedRenv) {
+ t.Errorf("restore environments did not match for test:%s\ngot:%+v\nwant:%+v\n", tc.name, *actualRenv, tc.expectedRenv)
+ }
+ }
+ })
+ }
+}
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
new file mode 100644
index 000000000..14d2f56a5
--- /dev/null
+++ b/runsc/boot/network.go
@@ -0,0 +1,341 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "net"
+ "runtime"
+ "strings"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+var (
+ // DefaultLoopbackLink contains the IP addresses and routes of "127.0.0.1/8"
+ // and "::1/128" on the "lo" interface.
+ DefaultLoopbackLink = LoopbackLink{
+ Name: "lo",
+ Addresses: []net.IP{
+ net.IP("\x7f\x00\x00\x01"),
+ net.IPv6loopback,
+ },
+ Routes: []Route{
+ {
+ Destination: net.IPNet{
+ IP: net.IPv4(0x7f, 0, 0, 0),
+ Mask: net.IPv4Mask(0xff, 0, 0, 0),
+ },
+ },
+ {
+ Destination: net.IPNet{
+ IP: net.IPv6loopback,
+ Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
+ },
+ },
+ },
+ }
+)
+
+// Network exposes methods that can be used to configure a network stack.
+type Network struct {
+ Stack *stack.Stack
+}
+
+// Route represents a route in the network stack.
+type Route struct {
+ Destination net.IPNet
+ Gateway net.IP
+}
+
+// DefaultRoute represents a catch all route to the default gateway.
+type DefaultRoute struct {
+ Route Route
+ Name string
+}
+
+// QueueingDiscipline is used to specify the kind of queueing discipline to
+// apply for a given FDBasedLink.
+type QueueingDiscipline int
+
+const (
+ // QDiscNone disables any queueing for the underlying FD.
+ QDiscNone QueueingDiscipline = iota
+
+ // QDiscFIFO applies a simple fifo based queue to the underlying
+ // FD.
+ QDiscFIFO
+)
+
+// MakeQueueingDiscipline returns the QueueingDiscipline corresponding to s,
+// or an error if s does not name a supported discipline.
+func MakeQueueingDiscipline(s string) (QueueingDiscipline, error) {
+ switch s {
+ case "none":
+ return QDiscNone, nil
+ case "fifo":
+ return QDiscFIFO, nil
+ default:
+ return 0, fmt.Errorf("unsupported qdisc specified: %q", s)
+ }
+}
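+
+// For illustration, a caller mapping a hypothetical --qdisc flag value might
+// do (the flag and link names here are assumed, not part of this package):
+//
+//	qdisc, err := MakeQueueingDiscipline("fifo")
+//	if err != nil {
+//		return err
+//	}
+//	link := FDBasedLink{Name: "eth0", QDisc: qdisc} // qdisc == QDiscFIFO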
+
+// String implements fmt.Stringer.
+func (q QueueingDiscipline) String() string {
+ switch q {
+ case QDiscNone:
+ return "none"
+ case QDiscFIFO:
+ return "fifo"
+ default:
+ panic(fmt.Sprintf("Invalid queueing discipline: %d", q))
+ }
+}
+
+// FDBasedLink configures an fd-based link.
+type FDBasedLink struct {
+ Name string
+ MTU int
+ Addresses []net.IP
+ Routes []Route
+ GSOMaxSize uint32
+ SoftwareGSOEnabled bool
+ TXChecksumOffload bool
+ RXChecksumOffload bool
+ LinkAddress net.HardwareAddr
+ QDisc QueueingDiscipline
+
+ // NumChannels controls how many underlying FDs are to be used to
+ // create this endpoint.
+ NumChannels int
+}
+
+// LoopbackLink configures a loopback link.
+type LoopbackLink struct {
+ Name string
+ Addresses []net.IP
+ Routes []Route
+}
+
+// CreateLinksAndRoutesArgs are arguments to CreateLinksAndRoutes.
+type CreateLinksAndRoutesArgs struct {
+ // FilePayload contains the FDs associated with the FDBasedLinks. The
+ // number of FDs should match the sum of the NumChannels fields of the
+ // FDBasedLink entries below.
+ urpc.FilePayload
+
+ LoopbackLinks []LoopbackLink
+ FDBasedLinks []FDBasedLink
+
+ Defaultv4Gateway DefaultRoute
+ Defaultv6Gateway DefaultRoute
+}
+
+// Empty returns true if the route hasn't been set.
+func (r *Route) Empty() bool {
+ return r.Destination.IP == nil && r.Destination.Mask == nil && r.Gateway == nil
+}
+
+func (r *Route) toTcpipRoute(id tcpip.NICID) (tcpip.Route, error) {
+ subnet, err := tcpip.NewSubnet(ipToAddress(r.Destination.IP), ipMaskToAddressMask(r.Destination.Mask))
+ if err != nil {
+ return tcpip.Route{}, err
+ }
+ return tcpip.Route{
+ Destination: subnet,
+ Gateway: ipToAddress(r.Gateway),
+ NIC: id,
+ }, nil
+}
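+
+// For example, a Route for 10.0.0.0/8 via gateway 10.0.0.1 on NIC 1 converts
+// to a tcpip.Route whose Destination is the subnet "\x0a\x00\x00\x00" with
+// mask "\xff\x00\x00\x00", whose Gateway is "\x0a\x00\x00\x01", and whose NIC
+// is 1.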
+
+// CreateLinksAndRoutes creates links and routes in a network stack. It should
+// only be called once.
+func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct{}) error {
+ wantFDs := 0
+ for _, l := range args.FDBasedLinks {
+ wantFDs += l.NumChannels
+ }
+ if got := len(args.FilePayload.Files); got != wantFDs {
+ return fmt.Errorf("args.FilePayload.Files has %d FD's but we need %d entries based on FDBasedLinks", got, wantFDs)
+ }
+
+ var nicID tcpip.NICID
+ nicids := make(map[string]tcpip.NICID)
+
+ // Collect routes from all links.
+ var routes []tcpip.Route
+
+ // Loopback interfaces normally appear before other interfaces.
+ for _, link := range args.LoopbackLinks {
+ nicID++
+ nicids[link.Name] = nicID
+
+ linkEP := loopback.New()
+
+ log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return err
+ }
+
+ // Collect the routes from this link.
+ for _, r := range link.Routes {
+ route, err := r.toTcpipRoute(nicID)
+ if err != nil {
+ return err
+ }
+ routes = append(routes, route)
+ }
+ }
+
+ fdOffset := 0
+ for _, link := range args.FDBasedLinks {
+ nicID++
+ nicids[link.Name] = nicID
+
+ FDs := []int{}
+ for j := 0; j < link.NumChannels; j++ {
+ // Copy the underlying FD.
+ oldFD := args.FilePayload.Files[fdOffset].Fd()
+ newFD, err := syscall.Dup(int(oldFD))
+ if err != nil {
+ return fmt.Errorf("failed to dup FD %v: %v", oldFD, err)
+ }
+ FDs = append(FDs, newFD)
+ fdOffset++
+ }
+
+ mac := tcpip.LinkAddress(link.LinkAddress)
+ log.Infof("gso max size is: %d", link.GSOMaxSize)
+
+ linkEP, err := fdbased.New(&fdbased.Options{
+ FDs: FDs,
+ MTU: uint32(link.MTU),
+ EthernetHeader: true,
+ Address: mac,
+ PacketDispatchMode: fdbased.RecvMMsg,
+ GSOMaxSize: link.GSOMaxSize,
+ SoftwareGSOEnabled: link.SoftwareGSOEnabled,
+ TXChecksumOffload: link.TXChecksumOffload,
+ RXChecksumOffload: link.RXChecksumOffload,
+ })
+ if err != nil {
+ return err
+ }
+
+ switch link.QDisc {
+ case QDiscNone:
+ case QDiscFIFO:
+ log.Infof("Enabling FIFO QDisc on %q", link.Name)
+ linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000)
+ }
+
+ log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return err
+ }
+
+ // Collect the routes from this link.
+ for _, r := range link.Routes {
+ route, err := r.toTcpipRoute(nicID)
+ if err != nil {
+ return err
+ }
+ routes = append(routes, route)
+ }
+ }
+
+ if !args.Defaultv4Gateway.Route.Empty() {
+ nicID, ok := nicids[args.Defaultv4Gateway.Name]
+ if !ok {
+ return fmt.Errorf("invalid interface name %q for default route", args.Defaultv4Gateway.Name)
+ }
+ route, err := args.Defaultv4Gateway.Route.toTcpipRoute(nicID)
+ if err != nil {
+ return err
+ }
+ routes = append(routes, route)
+ }
+
+ if !args.Defaultv6Gateway.Route.Empty() {
+ nicID, ok := nicids[args.Defaultv6Gateway.Name]
+ if !ok {
+ return fmt.Errorf("invalid interface name %q for default route", args.Defaultv6Gateway.Name)
+ }
+ route, err := args.Defaultv6Gateway.Route.toTcpipRoute(nicID)
+ if err != nil {
+ return err
+ }
+ routes = append(routes, route)
+ }
+
+ log.Infof("Setting routes %+v", routes)
+ n.Stack.SetRouteTable(routes)
+ return nil
+}
+
+// createNICWithAddrs creates a NIC in the network stack and adds the given
+// addresses.
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error {
+ opts := stack.NICOptions{Name: name}
+ if err := n.Stack.CreateNICWithOptions(id, sniffer.New(ep), opts); err != nil {
+ return fmt.Errorf("CreateNICWithOptions(%d, _, %+v) failed: %v", id, opts, err)
+ }
+
+ // Always start with an arp address for the NIC.
+ if err := n.Stack.AddAddress(id, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ return fmt.Errorf("AddAddress(%v, %v, %v) failed: %v", id, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ for _, addr := range addrs {
+ proto, tcpipAddr := ipToAddressAndProto(addr)
+ if err := n.Stack.AddAddress(id, proto, tcpipAddr); err != nil {
+ return fmt.Errorf("AddAddress(%v, %v, %v) failed: %v", id, proto, tcpipAddr, err)
+ }
+ }
+ return nil
+}
+
+// ipToAddressAndProto converts IP to tcpip.Address and a protocol number.
+//
+// Note: don't use 'len(ip)' to determine the IP version; net.IP often stores
+// IPv4 addresses as 16-byte slices, so To4() must be used instead.
+func ipToAddressAndProto(ip net.IP) (tcpip.NetworkProtocolNumber, tcpip.Address) {
+ if i4 := ip.To4(); i4 != nil {
+ return ipv4.ProtocolNumber, tcpip.Address(i4)
+ }
+ return ipv6.ProtocolNumber, tcpip.Address(ip)
+}
+
+// ipToAddress converts IP to tcpip.Address, ignoring the protocol.
+func ipToAddress(ip net.IP) tcpip.Address {
+ _, addr := ipToAddressAndProto(ip)
+ return addr
+}
+
+// ipMaskToAddressMask converts IPMask to tcpip.AddressMask, ignoring the
+// protocol.
+func ipMaskToAddressMask(ipMask net.IPMask) tcpip.AddressMask {
+ return tcpip.AddressMask(ipToAddress(net.IP(ipMask)))
+}
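+
+// For example, ipToAddressAndProto(net.ParseIP("192.168.0.1")) returns
+// (ipv4.ProtocolNumber, "\xc0\xa8\x00\x01") because To4() succeeds, while
+// net.ParseIP("2001:db8::1") yields ipv6.ProtocolNumber and the full 16-byte
+// address.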
diff --git a/runsc/boot/platforms/BUILD b/runsc/boot/platforms/BUILD
new file mode 100644
index 000000000..77774f43c
--- /dev/null
+++ b/runsc/boot/platforms/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "platforms",
+ srcs = ["platforms.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+ deps = [
+ "//pkg/sentry/platform/kvm",
+ "//pkg/sentry/platform/ptrace",
+ ],
+)
diff --git a/runsc/boot/platforms/platforms.go b/runsc/boot/platforms/platforms.go
new file mode 100644
index 000000000..056b46ad5
--- /dev/null
+++ b/runsc/boot/platforms/platforms.go
@@ -0,0 +1,30 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package platforms imports all available platform packages.
+package platforms
+
+import (
+ // Import platforms that runsc might use.
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace"
+)
+
+const (
+ // Ptrace runs the sandbox with the ptrace platform.
+ Ptrace = "ptrace"
+
+ // KVM runs the sandbox with the KVM platform.
+ KVM = "kvm"
+)
diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD
new file mode 100644
index 000000000..29cb42b2f
--- /dev/null
+++ b/runsc/boot/pprof/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pprof",
+ srcs = ["pprof.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+)
diff --git a/runsc/boot/pprof/pprof.go b/runsc/boot/pprof/pprof.go
new file mode 100644
index 000000000..1ded20dee
--- /dev/null
+++ b/runsc/boot/pprof/pprof.go
@@ -0,0 +1,20 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pprof provides a stub to initialize custom profilers.
+package pprof
+
+// Initialize will be called at boot for initializing custom profilers.
+func Initialize() {
+}
diff --git a/runsc/boot/strace.go b/runsc/boot/strace.go
new file mode 100644
index 000000000..fbfd3b07c
--- /dev/null
+++ b/runsc/boot/strace.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/strace"
+)
+
+func enableStrace(conf *Config) error {
+ // We must initialize even if strace is not enabled.
+ strace.Initialize()
+
+ if !conf.Strace {
+ return nil
+ }
+
+ max := conf.StraceLogSize
+ if max == 0 {
+ max = 1024
+ }
+ strace.LogMaximumSize = max
+
+ if len(conf.StraceSyscalls) == 0 {
+ strace.EnableAll(strace.SinkTypeLog)
+ return nil
+ }
+ return strace.Enable(conf.StraceSyscalls, strace.SinkTypeLog)
+}
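+
+// As a sketch of the policy above, a hypothetical config with Strace == true,
+// StraceLogSize == 0, and StraceSyscalls == []string{"openat", "read"} caps
+// logged strings at the 1024-byte default and traces only the two named
+// syscalls via strace.Enable(..., strace.SinkTypeLog).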
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
new file mode 100644
index 000000000..6ee6fae04
--- /dev/null
+++ b/runsc/boot/vfs.go
@@ -0,0 +1,482 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "path"
+ "sort"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/devices/memdev"
+ "gvisor.dev/gvisor/pkg/sentry/devices/ttydev"
+ "gvisor.dev/gvisor/pkg/sentry/devices/tundev"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/overlay"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func registerFilesystems(k *kernel.Kernel) error {
+ ctx := k.SupervisorContext()
+ creds := auth.NewRootCredentials(k.RootUserNamespace())
+ vfsObj := k.VFS()
+
+ vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ // TODO(b/29356795): Users may mount this once the terminals are in a
+ // usable state.
+ AllowUserMount: false,
+ })
+ vfsObj.MustRegisterFilesystemType(devtmpfs.Name, &devtmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(gofer.Name, &gofer.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(overlay.Name, &overlay.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(proc.Name, &proc.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(sys.Name, &sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ // Setup files in devtmpfs.
+ if err := memdev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering memdev: %w", err)
+ }
+ if err := ttydev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering ttydev: %w", err)
+ }
+
+ if err := fuse.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering fusedev: %w", err)
+ }
+ if err := tundev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering tundev: %v", err)
+ }
+ a, err := devtmpfs.NewAccessor(ctx, vfsObj, creds, devtmpfs.Name)
+ if err != nil {
+ return fmt.Errorf("creating devtmpfs accessor: %w", err)
+ }
+ defer a.Release()
+
+ if err := a.UserspaceInit(ctx); err != nil {
+ return fmt.Errorf("initializing userspace: %w", err)
+ }
+ if err := memdev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating memdev devtmpfs files: %w", err)
+ }
+ if err := ttydev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating ttydev devtmpfs files: %w", err)
+ }
+ if err := tundev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating tundev devtmpfs files: %v", err)
+ }
+ if err := fuse.CreateDevtmpfsFile(ctx, a); err != nil {
+ return fmt.Errorf("creating fusedev devtmpfs files: %w", err)
+ }
+ return nil
+}
+
+func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ mns, err := mntr.setupVFS2(ctx, conf, procArgs)
+ if err != nil {
+ return fmt.Errorf("failed to setupFS: %w", err)
+ }
+ procArgs.MountNamespaceVFS2 = mns
+
+ // Resolve the executable path from working dir and environment.
+ resolved, err := user.ResolveExecutablePath(ctx, procArgs)
+ if err != nil {
+ return err
+ }
+ procArgs.Filename = resolved
+ return nil
+}
+
+func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) {
+ log.Infof("Configuring container's file system with VFS2")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootCreds := auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = rootCreds
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := procArgs.NewContext(c.k)
+
+ mns, err := c.createMountNamespaceVFS2(rootCtx, conf, rootCreds)
+ if err != nil {
+ return nil, fmt.Errorf("creating mount namespace: %w", err)
+ }
+ rootProcArgs.MountNamespaceVFS2 = mns
+
+ // Mount submounts.
+ if err := c.mountSubmountsVFS2(rootCtx, conf, mns, rootCreds); err != nil {
+ return nil, fmt.Errorf("mounting submounts vfs2: %w", err)
+ }
+ return mns, nil
+}
+
+func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *Config, creds *auth.Credentials) (*vfs.MountNamespace, error) {
+ fd := c.fds.remove()
+ opts := strings.Join(p9MountData(fd, conf.FileAccess, true /* vfs2 */), ",")
+
+ log.Infof("Mounting root over 9P, ioFD: %d", fd)
+ mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{Data: opts})
+ if err != nil {
+ return nil, fmt.Errorf("setting up mount namespace: %w", err)
+ }
+ return mns, nil
+}
+
+func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error {
+ mounts, err := c.prepareMountsVFS2()
+ if err != nil {
+ return err
+ }
+
+ for i := range mounts {
+ submount := &mounts[i]
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options)
+ if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() {
+ if err := c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint); err != nil {
+ return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err)
+ }
+ } else {
+ if err := c.mountSubmountVFS2(ctx, conf, mns, creds, submount); err != nil {
+ return fmt.Errorf("mount submount %q: %w", submount.Destination, err)
+ }
+ }
+ }
+
+ if err := c.mountTmpVFS2(ctx, conf, creds, mns); err != nil {
+ return fmt.Errorf(`mount submount "/tmp": %w`, err)
+ }
+ return nil
+}
+
+type mountAndFD struct {
+ specs.Mount
+ fd int
+}
+
+func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) {
+ // Associate bind mounts with their FDs before sorting since there is an
+ // undocumented assumption that FDs are dispensed in the order in which
+ // they are required by mounts.
+ var mounts []mountAndFD
+ for _, m := range c.mounts {
+ fd := -1
+ // Only bind mounts use host FDs; see
+ // containerMounter.getMountNameAndOptionsVFS2.
+ if m.Type == bind {
+ fd = c.fds.remove()
+ }
+ mounts = append(mounts, mountAndFD{
+ Mount: m,
+ fd: fd,
+ })
+ }
+ if err := c.checkDispenser(); err != nil {
+ return nil, err
+ }
+
+ // Sort the mounts so that we don't place children before parents.
+ sort.Slice(mounts, func(i, j int) bool {
+ return len(mounts[i].Destination) < len(mounts[j].Destination)
+ })
+
+ return mounts, nil
+}
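+
+// Sorting by destination length is sufficient for the parent-before-child
+// guarantee above: a parent such as "/foo" is a strict prefix of, and hence
+// strictly shorter than, any of its children such as "/foo/bar".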
+
+func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) error {
+ root := mns.Root()
+ defer root.DecRef()
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(submount.Destination),
+ }
+ fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, submount)
+ if err != nil {
+ return fmt.Errorf("mountOptions failed: %w", err)
+ }
+ if len(fsName) == 0 {
+ // Filesystem is not supported (e.g. cgroup), just skip it.
+ return nil
+ }
+
+ if err := c.makeSyntheticMount(ctx, submount.Destination, root, creds); err != nil {
+ return err
+ }
+ if err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts); err != nil {
+ return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts)
+ }
+ log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts.GetFilesystemOptions.Data)
+ return nil
+}
+
+// getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values
+// used for mounts.
+func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndFD) (string, *vfs.MountOptions, error) {
+ fsName := m.Type
+ var data []string
+
+ // Find filesystem name and FS specific data field.
+ switch m.Type {
+ case devpts.Name, devtmpfs.Name, proc.Name, sys.Name:
+ // Nothing to do.
+
+ case nonefs:
+ fsName = sys.Name
+
+ case tmpfs.Name:
+ var err error
+ data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...)
+ if err != nil {
+ return "", nil, err
+ }
+
+ case bind:
+ fsName = gofer.Name
+ if m.fd == 0 {
+ // Check that an FD was provided so we fail fast. Technically FD=0 is
+ // valid, but unlikely to be correct in this context.
+ return "", nil, fmt.Errorf("9P mount requires a connection FD")
+ }
+ data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */)
+
+ default:
+ log.Warningf("ignoring unknown filesystem type %q", m.Type)
+ return "", nil, nil
+ }
+
+ opts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: strings.Join(data, ","),
+ },
+ InternalMount: true,
+ }
+
+ for _, o := range m.Options {
+ switch o {
+ case "rw":
+ opts.ReadOnly = false
+ case "ro":
+ opts.ReadOnly = true
+ case "noatime":
+ opts.Flags.NoATime = true
+ case "noexec":
+ opts.Flags.NoExec = true
+ default:
+ log.Warningf("ignoring unknown mount option %q", o)
+ }
+ }
+
+ if conf.Overlay {
+ // All writes go to the upper layer; be paranoid and make the lower
+ // layer read-only.
+ opts.ReadOnly = true
+ }
+ return fsName, opts, nil
+}
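+
+// For illustration: a bind mount with fd 5 and Options ["ro", "noexec"] maps
+// to fsName gofer.Name with opts.ReadOnly and opts.Flags.NoExec set, and with
+// Data carrying the 9P connection parameters from p9MountData(5, ...).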
+
+func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath string, root vfs.VirtualDentry, creds *auth.Credentials) error {
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(currentPath),
+ }
+ _, err := c.k.VFS().StatAt(ctx, creds, target, &vfs.StatOptions{})
+ if err == nil {
+ log.Debugf("Mount point %q already exists", currentPath)
+ return nil
+ }
+ if err != syserror.ENOENT {
+ return fmt.Errorf("stat failed for %q during mount point creation: %w", currentPath, err)
+ }
+
+ // Recurse to ensure parent is created and then create the mount point.
+ if err := c.makeSyntheticMount(ctx, path.Dir(currentPath), root, creds); err != nil {
+ return err
+ }
+ log.Debugf("Creating dir %q for mount point", currentPath)
+ mkdirOpts := &vfs.MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true}
+ if err := c.k.VFS().MkdirAt(ctx, creds, target, mkdirOpts); err != nil {
+ return fmt.Errorf("failed to create directory %q for mount: %w", currentPath, err)
+ }
+ return nil
+}
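+
+// For example, makeSyntheticMount(ctx, "/a/b/c", ...) on a root containing
+// none of those directories stats "/a/b/c" (ENOENT), recurses until the stat
+// of "/" succeeds, then creates "/a", "/a/b", and "/a/b/c" in turn, each with
+// mode 0777 and ForSyntheticMountpoint set.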
+
+// mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so.
+// Technically we don't have to mount tmpfs at /tmp, as we could just rely on
+// the host /tmp, but this is a nice optimization, and fixes some apps that call
+// mknod in /tmp. It's unsafe to mount tmpfs if:
+// 1. /tmp is mounted explicitly: we should not override the user's wish
+// 2. /tmp is not empty: mounting tmpfs would hide existing files in /tmp
+//
+// Note that when there are submounts inside of '/tmp', directories for the
+// mount points must be present, making '/tmp' not empty anymore.
+func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds *auth.Credentials, mns *vfs.MountNamespace) error {
+ for _, m := range c.mounts {
+ // m.Destination has been cleaned, so it's safe to use equality here.
+ if m.Destination == "/tmp" {
+ log.Debugf(`Explicit "/tmp" mount found, skipping internal tmpfs, mount: %+v`, m)
+ return nil
+ }
+ }
+
+ root := mns.Root()
+ defer root.DecRef()
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/tmp"),
+ }
+ // TODO(gvisor.dev/issue/2782): Use O_PATH when available.
+ statx, err := c.k.VFS().StatAt(ctx, creds, &pop, &vfs.StatOptions{})
+ switch err {
+ case nil:
+ // Found '/tmp' in filesystem, check if it's empty.
+ if linux.FileMode(statx.Mode).FileType() != linux.ModeDirectory {
+ // Not a dir?! Leave it be.
+ return nil
+ }
+ if statx.Nlink > 2 {
+ // If more than "." and ".." is found, skip internal tmpfs to prevent
+ // hiding existing files.
+ log.Infof(`Skipping internal tmpfs mount for "/tmp" because it's not empty`)
+ return nil
+ }
+ log.Infof(`Mounting internal tmpfs on top of empty "/tmp"`)
+ fallthrough
+
+ case syserror.ENOENT:
+ // No '/tmp' found (or fallthrough from above). It's safe to mount internal
+ // tmpfs.
+ tmpMount := specs.Mount{
+ Type: tmpfs.Name,
+ Destination: "/tmp",
+ // The sticky bit is added to prevent accidental deletion of another
+ // user's files. This is normally done for /tmp.
+ Options: []string{"mode=01777"},
+ }
+ return c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount})
+
+ default:
+ return fmt.Errorf(`stating "/tmp" inside container: %w`, err)
+ }
+}
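+
+// The Nlink > 2 heuristic above relies on traditional directory link counts:
+// an empty directory has exactly two links ("." and its entry in the parent),
+// and each subdirectory adds one more via its ".." entry. Regular files do
+// not raise the count, so a "/tmp" containing only files may still be treated
+// as empty here.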
+
+// processHintsVFS2 processes annotations that contain hints about how volumes
+// should be mounted (e.g. a volume shared between containers). It must be
+// called for the root container only.
+func (c *containerMounter) processHintsVFS2(conf *Config, creds *auth.Credentials) error {
+ ctx := c.k.SupervisorContext()
+ for _, hint := range c.hints.mounts {
+ // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a
+ // common gofer to mount all shared volumes.
+ if hint.mount.Type != tmpfs.Name {
+ continue
+ }
+
+ log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
+ mnt, err := c.mountSharedMasterVFS2(ctx, conf, hint, creds)
+ if err != nil {
+ return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+ }
+ hint.vfsMount = mnt
+ }
+ return nil
+}
+
+// mountSharedMasterVFS2 mounts the master of a volume that is shared among
+// containers in a pod.
+func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) {
+ // Map mount type to filesystem name, and parse out the options that we are
+ // capable of dealing with.
+ mntFD := &mountAndFD{Mount: hint.mount}
+ fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, mntFD)
+ if err != nil {
+ return nil, err
+ }
+ if len(fsName) == 0 {
+ return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type)
+ }
+ return c.k.VFS().MountDisconnected(ctx, creds, "", fsName, opts)
+}
+
+// mountSharedSubmountVFS2 bind mounts to a previously mounted volume that is
+// shared among containers in the same pod.
+func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) error {
+ if err := source.checkCompatible(mount); err != nil {
+ return err
+ }
+
+ _, opts, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount})
+ if err != nil {
+ return err
+ }
+ newMnt, err := c.k.VFS().NewDisconnectedMount(source.vfsMount.Filesystem(), source.vfsMount.Root(), opts)
+ if err != nil {
+ return err
+ }
+ defer newMnt.DecRef()
+
+ root := mns.Root()
+ defer root.DecRef()
+ if err := c.makeSyntheticMount(ctx, mount.Destination, root, creds); err != nil {
+ return err
+ }
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(mount.Destination),
+ }
+ if err := c.k.VFS().ConnectMountAt(ctx, creds, newMnt, target); err != nil {
+ return err
+ }
+ log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name)
+ return nil
+}
diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD
new file mode 100644
index 000000000..7e34a284a
--- /dev/null
+++ b/runsc/cgroup/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cgroup",
+ srcs = ["cgroup.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/cleanup",
+ "//pkg/log",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ ],
+)
+
+go_test(
+ name = "cgroup_test",
+ size = "small",
+ srcs = ["cgroup_test.go"],
+ library = ":cgroup",
+ tags = ["local"],
+ deps = [
+ "//pkg/test/testutil",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ ],
+)
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
new file mode 100644
index 000000000..e5cc9d622
--- /dev/null
+++ b/runsc/cgroup/cgroup.go
@@ -0,0 +1,576 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cgroup provides an interface to read and write configuration to
+// cgroup.
+package cgroup
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+const (
+ cgroupRoot = "/sys/fs/cgroup"
+)
+
+var controllers = map[string]config{
+ "blkio": config{ctrlr: &blockIO{}},
+ "cpu": config{ctrlr: &cpu{}},
+ "cpuset": config{ctrlr: &cpuSet{}},
+ "hugetlb": config{ctrlr: &hugeTLB{}, optional: true},
+ "memory": config{ctrlr: &memory{}},
+ "net_cls": config{ctrlr: &networkClass{}},
+ "net_prio": config{ctrlr: &networkPrio{}},
+ "pids": config{ctrlr: &pids{}},
+
+ // These controllers either don't have anything in the OCI spec or are
+ // irrelevant for a sandbox.
+ "devices": config{ctrlr: &noop{}},
+ "freezer": config{ctrlr: &noop{}},
+ "perf_event": config{ctrlr: &noop{}},
+ "rdma": config{ctrlr: &noop{}, optional: true},
+ "systemd": config{ctrlr: &noop{}},
+}
+
+func setOptionalValueInt(path, name string, val *int64) error {
+ if val == nil || *val == 0 {
+ return nil
+ }
+ str := strconv.FormatInt(*val, 10)
+ return setValue(path, name, str)
+}
+
+func setOptionalValueUint(path, name string, val *uint64) error {
+ if val == nil || *val == 0 {
+ return nil
+ }
+ str := strconv.FormatUint(*val, 10)
+ return setValue(path, name, str)
+}
+
+func setOptionalValueUint32(path, name string, val *uint32) error {
+ if val == nil || *val == 0 {
+ return nil
+ }
+ str := strconv.FormatUint(uint64(*val), 10)
+ return setValue(path, name, str)
+}
+
+func setOptionalValueUint16(path, name string, val *uint16) error {
+ if val == nil || *val == 0 {
+ return nil
+ }
+ str := strconv.FormatUint(uint64(*val), 10)
+ return setValue(path, name, str)
+}
+
+func setValue(path, name, data string) error {
+ fullpath := filepath.Join(path, name)
+ return ioutil.WriteFile(fullpath, []byte(data), 0700)
+}
+
+func getValue(path, name string) (string, error) {
+ fullpath := filepath.Join(path, name)
+ out, err := ioutil.ReadFile(fullpath)
+ if err != nil {
+ return "", err
+ }
+ return string(out), nil
+}
+
+func getInt(path, name string) (int, error) {
+ s, err := getValue(path, name)
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(strings.TrimSpace(s))
+}
+
+// fillFromAncestor sets the value of a cgroup file from the first ancestor
+// that has content. It does nothing if the file at 'path' has already been set.
+func fillFromAncestor(path string) (string, error) {
+ out, err := ioutil.ReadFile(path)
+ if err != nil {
+ return "", err
+ }
+ val := strings.TrimSpace(string(out))
+ if val != "" {
+ // File is set, stop here.
+ return val, nil
+ }
+
+ // File is not set, recurse to parent and then set here.
+ name := filepath.Base(path)
+ parent := filepath.Dir(filepath.Dir(path))
+ val, err = fillFromAncestor(filepath.Join(parent, name))
+ if err != nil {
+ return "", err
+ }
+ if err := ioutil.WriteFile(path, []byte(val), 0700); err != nil {
+ return "", err
+ }
+ return val, nil
+}
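+
+// For example, if "/sys/fs/cgroup/cpuset/foo/bar/cpuset.cpus" is empty,
+// fillFromAncestor reads "/sys/fs/cgroup/cpuset/foo/cpuset.cpus" (recursing
+// further up if that is empty too) and writes the inherited value back to the
+// original file.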
+
+// countCpuset returns the number of CPUs in a cpuset string formatted like:
+// "0-2,7,12-14 # bits 0, 1, 2, 7, 12, 13, and 14 set" - man 7 cpuset
+func countCpuset(cpuset string) (int, error) {
+ var count int
+ for _, p := range strings.Split(cpuset, ",") {
+ interval := strings.Split(p, "-")
+ switch len(interval) {
+ case 1:
+ if _, err := strconv.Atoi(interval[0]); err != nil {
+ return 0, err
+ }
+ count++
+
+ case 2:
+ start, err := strconv.Atoi(interval[0])
+ if err != nil {
+ return 0, err
+ }
+ end, err := strconv.Atoi(interval[1])
+ if err != nil {
+ return 0, err
+ }
+ if start < 0 || end < 0 || start > end {
+ return 0, fmt.Errorf("invalid cpuset: %q", p)
+ }
+ count += end - start + 1
+
+ default:
+ return 0, fmt.Errorf("invalid cpuset: %q", p)
+ }
+ }
+ return count, nil
+}
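+
+// For example, countCpuset("0-2,7,12-14") returns 7: three CPUs from "0-2",
+// one from "7", and three from "12-14".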
+
+// LoadPaths loads cgroup paths for the given 'pid', which may be set to 'self'.
+func LoadPaths(pid string) (map[string]string, error) {
+ f, err := os.Open(filepath.Join("/proc", pid, "cgroup"))
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ paths := make(map[string]string)
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ // Format: ID:controller1,controller2:path
+ // Example: 2:cpu,cpuacct:/user.slice
+ tokens := strings.Split(scanner.Text(), ":")
+ if len(tokens) != 3 {
+ return nil, fmt.Errorf("invalid cgroups file, line: %q", scanner.Text())
+ }
+ for _, ctrlr := range strings.Split(tokens[1], ",") {
+ paths[ctrlr] = tokens[2]
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ return paths, nil
+}
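+
+// For example, the /proc/self/cgroup line "2:cpu,cpuacct:/user.slice" yields
+// the map entries {"cpu": "/user.slice", "cpuacct": "/user.slice"}.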
+
+// Cgroup represents a group inside all controllers. For example:
+// Name='/foo/bar' maps to /sys/fs/cgroup/<controller>/foo/bar on
+// all controllers.
+type Cgroup struct {
+ Name string `json:"name"`
+ Parents map[string]string `json:"parents"`
+ Own bool `json:"own"`
+}
+
+// New creates a new Cgroup instance if the spec includes a cgroup path.
+// Returns nil otherwise.
+func New(spec *specs.Spec) (*Cgroup, error) {
+ if spec.Linux == nil || spec.Linux.CgroupsPath == "" {
+ return nil, nil
+ }
+ var parents map[string]string
+ if !filepath.IsAbs(spec.Linux.CgroupsPath) {
+ var err error
+ parents, err = LoadPaths("self")
+ if err != nil {
+ return nil, fmt.Errorf("finding current cgroups: %v", err)
+ }
+ }
+ return &Cgroup{
+ Name: spec.Linux.CgroupsPath,
+ Parents: parents,
+ }, nil
+}
+
+// Install creates and configures cgroups according to 'res'. If the cgroup
+// path already exists, the caller has already provided a pre-configured
+// cgroup, and 'res' is ignored.
+func (c *Cgroup) Install(res *specs.LinuxResources) error {
+ if _, err := os.Stat(c.makePath("memory")); err == nil {
+ // If the cgroup has already been created, it was set up by the caller. Don't
+ // make any changes to the configuration, just join when sandbox/gofer starts.
+ log.Debugf("Using pre-created cgroup %q", c.Name)
+ return nil
+ }
+
+ log.Debugf("Creating cgroup %q", c.Name)
+
+ // Mark that cgroup resources are owned by me.
+ c.Own = true
+
+ // The Cleanup object cleans up partially created cgroups when an error occurs.
+ // Errors occurring during cleanup itself are ignored.
+ clean := cleanup.Make(func() { _ = c.Uninstall() })
+ defer clean.Clean()
+
+ for key, cfg := range controllers {
+ path := c.makePath(key)
+ if err := os.MkdirAll(path, 0755); err != nil {
+ if cfg.optional && errors.Is(err, syscall.EROFS) {
+ log.Infof("Skipping cgroup %q", key)
+ continue
+ }
+ return err
+ }
+ if res != nil {
+ if err := cfg.ctrlr.set(res, path); err != nil {
+ return err
+ }
+ }
+ }
+ clean.Release()
+ return nil
+}
+
+// Uninstall removes the settings done in Install(). If cgroup path already
+// existed when Install() was called, Uninstall is a noop.
+func (c *Cgroup) Uninstall() error {
+ if !c.Own {
+ // cgroup is managed by caller, don't touch it.
+ return nil
+ }
+ log.Debugf("Deleting cgroup %q", c.Name)
+ for key := range controllers {
+ path := c.makePath(key)
+ log.Debugf("Removing cgroup controller for key=%q path=%q", key, path)
+
+ // If we try to remove the cgroup too soon after killing the
+ // sandbox we might get EBUSY, so we retry for a few seconds
+ // until it succeeds.
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
+ if err := backoff.Retry(func() error {
+ err := syscall.Rmdir(path)
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return err
+ }, b); err != nil {
+ return fmt.Errorf("removing cgroup path %q: %v", path, err)
+ }
+ }
+ return nil
+}
+
+// Join adds the current process to all controllers. It returns a function
+// that restores the cgroups to their original state.
+func (c *Cgroup) Join() (func(), error) {
+ // First save the current state so it can be restored.
+ undo := func() {}
+ paths, err := LoadPaths("self")
+ if err != nil {
+ return undo, err
+ }
+ var undoPaths []string
+ for ctrlr, path := range paths {
+ // Skip controllers we don't handle, and record the path of every
+ // controller we do handle so that all of them can be restored.
+ if _, ok := controllers[ctrlr]; ok {
+ fullPath := filepath.Join(cgroupRoot, ctrlr, path)
+ undoPaths = append(undoPaths, fullPath)
+ }
+ }
+
+ // Replace empty undo with the real thing before changes are made to cgroups.
+ undo = func() {
+ for _, path := range undoPaths {
+ log.Debugf("Restoring cgroup %q", path)
+ if err := setValue(path, "cgroup.procs", "0"); err != nil {
+ log.Warningf("Error restoring cgroup %q: %v", path, err)
+ }
+ }
+ }
+
+ // Now join the cgroups.
+ for key, cfg := range controllers {
+ path := c.makePath(key)
+ log.Debugf("Joining cgroup %q", path)
+ if err := setValue(path, "cgroup.procs", "0"); err != nil {
+ if cfg.optional && os.IsNotExist(err) {
+ continue
+ }
+ return undo, err
+ }
+ }
+ return undo, nil
+}
+
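+// CPUQuota returns the CFS quota as a number of CPUs (cfs_quota_us divided by
+// cfs_period_us), or -1 with a nil error when no quota is set.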
+func (c *Cgroup) CPUQuota() (float64, error) {
+ path := c.makePath("cpu")
+ quota, err := getInt(path, "cpu.cfs_quota_us")
+ if err != nil {
+ return -1, err
+ }
+ period, err := getInt(path, "cpu.cfs_period_us")
+ if err != nil {
+ return -1, err
+ }
+ if quota <= 0 || period <= 0 {
+ // No CPU limit is set (cfs_quota_us is -1 when unlimited).
+ return -1, nil
+ }
+ return float64(quota) / float64(period), nil
+}
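+
+// For example, cpu.cfs_quota_us == 150000 with cpu.cfs_period_us == 100000
+// yields a CPUQuota of 1.5, i.e. one and a half CPUs worth of runtime.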
+
+// NumCPU returns the number of CPUs configured in 'cpuset/cpuset.cpus'.
+func (c *Cgroup) NumCPU() (int, error) {
+ path := c.makePath("cpuset")
+ cpuset, err := getValue(path, "cpuset.cpus")
+ if err != nil {
+ return 0, err
+ }
+ return countCpuset(strings.TrimSpace(cpuset))
+}
+
+// MemoryLimit returns the memory limit.
+func (c *Cgroup) MemoryLimit() (uint64, error) {
+ path := c.makePath("memory")
+ limStr, err := getValue(path, "memory.limit_in_bytes")
+ if err != nil {
+ return 0, err
+ }
+ return strconv.ParseUint(strings.TrimSpace(limStr), 10, 64)
+}
+
+func (c *Cgroup) makePath(controllerName string) string {
+ path := c.Name
+ if parent, ok := c.Parents[controllerName]; ok {
+ path = filepath.Join(parent, c.Name)
+ }
+ return filepath.Join(cgroupRoot, controllerName, path)
+}
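+
+// For example, with the absolute Name "/docker/abc" (Parents is empty in that
+// case), makePath("memory") returns "/sys/fs/cgroup/memory/docker/abc"; with
+// the relative Name "abc" and Parents["memory"] == "/user.slice", it returns
+// "/sys/fs/cgroup/memory/user.slice/abc".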
+
+type config struct {
+ ctrlr controller
+ optional bool
+}
+
+type controller interface {
+ set(*specs.LinuxResources, string) error
+}
+
+type noop struct{}
+
+func (*noop) set(*specs.LinuxResources, string) error {
+ return nil
+}
+
+type memory struct{}
+
+func (*memory) set(spec *specs.LinuxResources, path string) error {
+ if spec.Memory == nil {
+ return nil
+ }
+ if err := setOptionalValueInt(path, "memory.limit_in_bytes", spec.Memory.Limit); err != nil {
+ return err
+ }
+ if err := setOptionalValueInt(path, "memory.soft_limit_in_bytes", spec.Memory.Reservation); err != nil {
+ return err
+ }
+ if err := setOptionalValueInt(path, "memory.memsw.limit_in_bytes", spec.Memory.Swap); err != nil {
+ return err
+ }
+ if err := setOptionalValueInt(path, "memory.kmem.limit_in_bytes", spec.Memory.Kernel); err != nil {
+ return err
+ }
+ if err := setOptionalValueInt(path, "memory.kmem.tcp.limit_in_bytes", spec.Memory.KernelTCP); err != nil {
+ return err
+ }
+ if err := setOptionalValueUint(path, "memory.swappiness", spec.Memory.Swappiness); err != nil {
+ return err
+ }
+
+ if spec.Memory.DisableOOMKiller != nil && *spec.Memory.DisableOOMKiller {
+ if err := setValue(path, "memory.oom_control", "1"); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+type cpu struct{}
+
+func (*cpu) set(spec *specs.LinuxResources, path string) error {
+ if spec.CPU == nil {
+ return nil
+ }
+ if err := setOptionalValueUint(path, "cpu.shares", spec.CPU.Shares); err != nil {
+ return err
+ }
+ if err := setOptionalValueInt(path, "cpu.cfs_quota_us", spec.CPU.Quota); err != nil {
+ return err
+ }
+ if err := setOptionalValueUint(path, "cpu.cfs_period_us", spec.CPU.Period); err != nil {
+ return err
+ }
+ if err := setOptionalValueUint(path, "cpu.rt_period_us", spec.CPU.RealtimePeriod); err != nil {
+ return err
+ }
+ return setOptionalValueInt(path, "cpu.rt_runtime_us", spec.CPU.RealtimeRuntime)
+}
+
+type cpuSet struct{}
+
+func (*cpuSet) set(spec *specs.LinuxResources, path string) error {
+ // cpuset.cpus and mems are required fields, but are not set on a new cgroup.
+ // If not set in the spec, get them from one of the ancestor cgroups.
+ if spec.CPU == nil || spec.CPU.Cpus == "" {
+ if _, err := fillFromAncestor(filepath.Join(path, "cpuset.cpus")); err != nil {
+ return err
+ }
+ } else {
+ if err := setValue(path, "cpuset.cpus", spec.CPU.Cpus); err != nil {
+ return err
+ }
+ }
+
+ if spec.CPU == nil || spec.CPU.Mems == "" {
+ _, err := fillFromAncestor(filepath.Join(path, "cpuset.mems"))
+ return err
+ }
+ mems := spec.CPU.Mems
+ return setValue(path, "cpuset.mems", mems)
+}
+
+type blockIO struct{}
+
+func (*blockIO) set(spec *specs.LinuxResources, path string) error {
+ if spec.BlockIO == nil {
+ return nil
+ }
+
+ if err := setOptionalValueUint16(path, "blkio.weight", spec.BlockIO.Weight); err != nil {
+ return err
+ }
+ if err := setOptionalValueUint16(path, "blkio.leaf_weight", spec.BlockIO.LeafWeight); err != nil {
+ return err
+ }
+
+ for _, dev := range spec.BlockIO.WeightDevice {
+ if dev.Weight != nil {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, *dev.Weight)
+ if err := setValue(path, "blkio.weight_device", val); err != nil {
+ return err
+ }
+ }
+ if dev.LeafWeight != nil {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, *dev.LeafWeight)
+ if err := setValue(path, "blkio.leaf_weight_device", val); err != nil {
+ return err
+ }
+ }
+ }
+ if err := setThrottle(path, "blkio.throttle.read_bps_device", spec.BlockIO.ThrottleReadBpsDevice); err != nil {
+ return err
+ }
+ if err := setThrottle(path, "blkio.throttle.write_bps_device", spec.BlockIO.ThrottleWriteBpsDevice); err != nil {
+ return err
+ }
+ if err := setThrottle(path, "blkio.throttle.read_iops_device", spec.BlockIO.ThrottleReadIOPSDevice); err != nil {
+ return err
+ }
+ return setThrottle(path, "blkio.throttle.write_iops_device", spec.BlockIO.ThrottleWriteIOPSDevice)
+}
+
+func setThrottle(path, name string, devs []specs.LinuxThrottleDevice) error {
+ for _, dev := range devs {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, dev.Rate)
+ if err := setValue(path, name, val); err != nil {
+ return err
+ }
+ }
+ return nil
+}
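+
+// For example, a throttle device {Major: 8, Minor: 0, Rate: 1048576} is
+// written as the line "8:0 1048576", i.e. the kernel's "major:minor rate"
+// format.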
+
+type networkClass struct{}
+
+func (*networkClass) set(spec *specs.LinuxResources, path string) error {
+ if spec.Network == nil {
+ return nil
+ }
+ return setOptionalValueUint32(path, "net_cls.classid", spec.Network.ClassID)
+}
+
+type networkPrio struct{}
+
+func (*networkPrio) set(spec *specs.LinuxResources, path string) error {
+ if spec.Network == nil {
+ return nil
+ }
+ for _, prio := range spec.Network.Priorities {
+ val := fmt.Sprintf("%s %d", prio.Name, prio.Priority)
+ if err := setValue(path, "net_prio.ifpriomap", val); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+type pids struct{}
+
+func (*pids) set(spec *specs.LinuxResources, path string) error {
+ if spec.Pids == nil || spec.Pids.Limit <= 0 {
+ return nil
+ }
+ val := strconv.FormatInt(spec.Pids.Limit, 10)
+ return setValue(path, "pids.max", val)
+}
+
+type hugeTLB struct{}
+
+func (*hugeTLB) set(spec *specs.LinuxResources, path string) error {
+ for _, limit := range spec.HugepageLimits {
+ name := fmt.Sprintf("hugetlb.%s.limit_in_bytes", limit.Pagesize)
+ val := strconv.FormatUint(limit.Limit, 10)
+ if err := setValue(path, name, val); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
new file mode 100644
index 000000000..4db5ee5c3
--- /dev/null
+++ b/runsc/cgroup/cgroup_test.go
@@ -0,0 +1,649 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroup
+
+import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+func TestUninstallEnoent(t *testing.T) {
+ c := Cgroup{
+ // Use a non-existent name.
+ Name: "runsc-test-uninstall-656e6f656e740a",
+ Own: true,
+ }
+ if err := c.Uninstall(); err != nil {
+ t.Errorf("Uninstall() failed: %v", err)
+ }
+}
+
+func TestCountCpuset(t *testing.T) {
+ for _, tc := range []struct {
+ str string
+ want int
+ error bool
+ }{
+ {str: "0", want: 1},
+ {str: "0,1,2,8,9,10", want: 6},
+ {str: "0-1", want: 2},
+ {str: "0-7", want: 8},
+ {str: "0-7,16,32-39,64,65", want: 19},
+ {str: "a", error: true},
+ {str: "5-a", error: true},
+ {str: "a-5", error: true},
+ {str: "-10", error: true},
+ {str: "15-", error: true},
+ {str: "-", error: true},
+ {str: "--", error: true},
+ } {
+ t.Run(tc.str, func(t *testing.T) {
+ got, err := countCpuset(tc.str)
+ if tc.error {
+ if err == nil {
+ t.Errorf("countCpuset(%q) should have failed", tc.str)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("countCpuset(%q) failed: %v", tc.str, err)
+ }
+ if tc.want != got {
+ t.Errorf("countCpuset(%q) want: %d, got: %d", tc.str, tc.want, got)
+ }
+ }
+ })
+ }
+}
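+
+// countCpusetSketch is an illustrative re-implementation, consistent with the
+// cases above, of the countCpuset function under test (the real one lives in
+// cgroup.go and may differ in detail). It counts the CPUs named by a Linux
+// cpuset list such as "0-7,16,32-39". It assumes "fmt" and "strconv" are
+// added to the imports.
+func countCpusetSketch(cpuset string) (int, error) {
+ var count int
+ for _, p := range strings.Split(cpuset, ",") {
+ interval := strings.Split(p, "-")
+ switch len(interval) {
+ case 1:
+ // Single CPU, e.g. "16".
+ if _, err := strconv.Atoi(interval[0]); err != nil {
+ return 0, fmt.Errorf("invalid cpuset %q: %v", p, err)
+ }
+ count++
+ case 2:
+ // Inclusive range, e.g. "0-7".
+ start, err := strconv.Atoi(interval[0])
+ if err != nil {
+ return 0, fmt.Errorf("invalid cpuset %q: %v", p, err)
+ }
+ end, err := strconv.Atoi(interval[1])
+ if err != nil || end < start {
+ return 0, fmt.Errorf("invalid cpuset %q", p)
+ }
+ count += end - start + 1
+ default:
+ return 0, fmt.Errorf("invalid cpuset %q", p)
+ }
+ }
+ return count, nil
+}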
+
+func uint16Ptr(v uint16) *uint16 {
+ return &v
+}
+
+func uint32Ptr(v uint32) *uint32 {
+ return &v
+}
+
+func int64Ptr(v int64) *int64 {
+ return &v
+}
+
+func uint64Ptr(v uint64) *uint64 {
+ return &v
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+func checkDir(t *testing.T, dir string, contents map[string]string) {
+ all, err := ioutil.ReadDir(dir)
+ if err != nil {
+ t.Fatalf("ReadDir(%q): %v", dir, err)
+ }
+ fileCount := 0
+ for _, file := range all {
+ if file.IsDir() {
+ // Only want to compare files.
+ continue
+ }
+ fileCount++
+
+ want, ok := contents[file.Name()]
+ if !ok {
+ t.Errorf("file not expected: %q", file.Name())
+ continue
+ }
+ gotBytes, err := ioutil.ReadFile(filepath.Join(dir, file.Name()))
+ if err != nil {
+ t.Fatal(err.Error())
+ }
+ got := strings.TrimSuffix(string(gotBytes), "\n")
+ if got != want {
+ t.Errorf("wrong file content, file: %q, want: %q, got: %q", file.Name(), want, got)
+ }
+ }
+ if fileCount != len(contents) {
+ t.Errorf("file is missing, want: %v, got: %v", contents, all)
+ }
+}
+
+func makeLinuxWeightDevice(major, minor int64, weight, leafWeight *uint16) specs.LinuxWeightDevice {
+ rv := specs.LinuxWeightDevice{
+ Weight: weight,
+ LeafWeight: leafWeight,
+ }
+ rv.Major = major
+ rv.Minor = minor
+ return rv
+}
+
+func makeLinuxThrottleDevice(major, minor int64, rate uint64) specs.LinuxThrottleDevice {
+ rv := specs.LinuxThrottleDevice{
+ Rate: rate,
+ }
+ rv.Major = major
+ rv.Minor = minor
+ return rv
+}
+
+func TestBlockIO(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxBlockIO
+ wants map[string]string
+ }{
+ {
+ name: "simple",
+ spec: &specs.LinuxBlockIO{
+ Weight: uint16Ptr(1),
+ LeafWeight: uint16Ptr(2),
+ },
+ wants: map[string]string{
+ "blkio.weight": "1",
+ "blkio.leaf_weight": "2",
+ },
+ },
+ {
+ name: "weight_device",
+ spec: &specs.LinuxBlockIO{
+ WeightDevice: []specs.LinuxWeightDevice{
+ makeLinuxWeightDevice(1, 2, uint16Ptr(3), uint16Ptr(4)),
+ },
+ },
+ wants: map[string]string{
+ "blkio.weight_device": "1:2 3",
+ "blkio.leaf_weight_device": "1:2 4",
+ },
+ },
+ {
+ name: "weight_device_nil_values",
+ spec: &specs.LinuxBlockIO{
+ WeightDevice: []specs.LinuxWeightDevice{
+ makeLinuxWeightDevice(1, 2, nil, nil),
+ },
+ },
+ },
+ {
+ name: "throttle",
+ spec: &specs.LinuxBlockIO{
+ ThrottleReadBpsDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(1, 2, 3),
+ },
+ ThrottleReadIOPSDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(4, 5, 6),
+ },
+ ThrottleWriteBpsDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(7, 8, 9),
+ },
+ ThrottleWriteIOPSDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(10, 11, 12),
+ },
+ },
+ wants: map[string]string{
+ "blkio.throttle.read_bps_device": "1:2 3",
+ "blkio.throttle.read_iops_device": "4:5 6",
+ "blkio.throttle.write_bps_device": "7:8 9",
+ "blkio.throttle.write_iops_device": "10:11 12",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxBlockIO{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ BlockIO: tc.spec,
+ }
+ ctrlr := blockIO{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestCPU(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxCPU{
+ Shares: uint64Ptr(1),
+ Quota: int64Ptr(2),
+ Period: uint64Ptr(3),
+ RealtimeRuntime: int64Ptr(4),
+ RealtimePeriod: uint64Ptr(5),
+ },
+ wants: map[string]string{
+ "cpu.shares": "1",
+ "cpu.cfs_quota_us": "2",
+ "cpu.cfs_period_us": "3",
+ "cpu.rt_runtime_us": "4",
+ "cpu.rt_period_us": "5",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxCPU{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpu{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestCPUSet(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxCPU{
+ Cpus: "foo",
+ Mems: "bar",
+ },
+ wants: map[string]string{
+ "cpuset.cpus": "foo",
+ "cpuset.mems": "bar",
+ },
+ },
+ // Don't test nil values because they are copied from the parent.
+ // See TestCPUSetAncestor().
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpuSet{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+// TestCPUSetAncestor checks that, when a value is not set in the spec, it is
+// read from the parent directory.
+func TestCPUSetAncestor(t *testing.T) {
+ // Prepare master directory with cgroup files that will be propagated to
+ // children.
+ grandpa, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(grandpa)
+
+ if err := ioutil.WriteFile(filepath.Join(grandpa, "cpuset.cpus"), []byte("parent-cpus"), 0666); err != nil {
+ t.Fatalf("ioutil.WriteFile(): %v", err)
+ }
+ if err := ioutil.WriteFile(filepath.Join(grandpa, "cpuset.mems"), []byte("parent-mems"), 0666); err != nil {
+ t.Fatalf("ioutil.WriteFile(): %v", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ }{
+ {
+ name: "nil_values",
+ spec: &specs.LinuxCPU{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // Create empty files in intermediate directory. They should be ignored
+ // when reading, and then populated from parent.
+ parent, err := ioutil.TempDir(grandpa, "parent")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(parent)
+ if _, err := os.Create(filepath.Join(parent, "cpuset.cpus")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ if _, err := os.Create(filepath.Join(parent, "cpuset.mems")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ // cgroup files must exist.
+ dir, err := ioutil.TempDir(parent, "child")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ if _, err := os.Create(filepath.Join(dir, "cpuset.cpus")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ if _, err := os.Create(filepath.Join(dir, "cpuset.mems")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpuSet{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ want := map[string]string{
+ "cpuset.cpus": "parent-cpus",
+ "cpuset.mems": "parent-mems",
+ }
+ // Both path and dir must have been populated from grandpa.
+ checkDir(t, parent, want)
+ checkDir(t, dir, want)
+ })
+ }
+}
+
+func TestHugeTlb(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec []specs.LinuxHugepageLimit
+ wants map[string]string
+ }{
+ {
+ name: "single",
+ spec: []specs.LinuxHugepageLimit{
+ {
+ Pagesize: "1G",
+ Limit: 123,
+ },
+ },
+ wants: map[string]string{
+ "hugetlb.1G.limit_in_bytes": "123",
+ },
+ },
+ {
+ name: "multiple",
+ spec: []specs.LinuxHugepageLimit{
+ {
+ Pagesize: "1G",
+ Limit: 123,
+ },
+ {
+ Pagesize: "2G",
+ Limit: 456,
+ },
+ {
+ Pagesize: "1P",
+ Limit: 789,
+ },
+ },
+ wants: map[string]string{
+ "hugetlb.1G.limit_in_bytes": "123",
+ "hugetlb.2G.limit_in_bytes": "456",
+ "hugetlb.1P.limit_in_bytes": "789",
+ },
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ HugepageLimits: tc.spec,
+ }
+ ctrlr := hugeTLB{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestMemory(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxMemory
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxMemory{
+ Limit: int64Ptr(1),
+ Reservation: int64Ptr(2),
+ Swap: int64Ptr(3),
+ Kernel: int64Ptr(4),
+ KernelTCP: int64Ptr(5),
+ Swappiness: uint64Ptr(6),
+ DisableOOMKiller: boolPtr(true),
+ },
+ wants: map[string]string{
+ "memory.limit_in_bytes": "1",
+ "memory.soft_limit_in_bytes": "2",
+ "memory.memsw.limit_in_bytes": "3",
+ "memory.kmem.limit_in_bytes": "4",
+ "memory.kmem.tcp.limit_in_bytes": "5",
+ "memory.swappiness": "6",
+ "memory.oom_control": "1",
+ },
+ },
+ {
+ // DisableOOMKiller should only be written when set to true.
+ name: "oomkiller",
+ spec: &specs.LinuxMemory{
+ DisableOOMKiller: boolPtr(false),
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxMemory{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Memory: tc.spec,
+ }
+ ctrlr := memory{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestNetworkClass(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxNetwork
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxNetwork{
+ ClassID: uint32Ptr(1),
+ },
+ wants: map[string]string{
+ "net_cls.classid": "1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxNetwork{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Network: tc.spec,
+ }
+ ctrlr := networkClass{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestNetworkPriority(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxNetwork
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxNetwork{
+ Priorities: []specs.LinuxInterfacePriority{
+ {
+ Name: "foo",
+ Priority: 1,
+ },
+ },
+ },
+ wants: map[string]string{
+ "net_prio.ifpriomap": "foo 1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxNetwork{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Network: tc.spec,
+ }
+ ctrlr := networkPrio{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestPids(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxPids
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxPids{Limit: 1},
+ wants: map[string]string{
+ "pids.max": "1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxPids{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Pids: tc.spec,
+ }
+ ctrlr := pids{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
new file mode 100644
index 000000000..dae9b3b3e
--- /dev/null
+++ b/runsc/cmd/BUILD
@@ -0,0 +1,95 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cmd",
+ srcs = [
+ "boot.go",
+ "capability.go",
+ "checkpoint.go",
+ "chroot.go",
+ "cmd.go",
+ "create.go",
+ "debug.go",
+ "delete.go",
+ "do.go",
+ "error.go",
+ "events.go",
+ "exec.go",
+ "gofer.go",
+ "help.go",
+ "install.go",
+ "kill.go",
+ "list.go",
+ "path.go",
+ "pause.go",
+ "ps.go",
+ "restore.go",
+ "resume.go",
+ "run.go",
+ "spec.go",
+ "start.go",
+ "state.go",
+ "statefile.go",
+ "syscalls.go",
+ "wait.go",
+ ],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+ deps = [
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/sentry/control",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/platform",
+ "//pkg/state/pretty",
+ "//pkg/state/statefile",
+ "//pkg/sync",
+ "//pkg/unet",
+ "//pkg/urpc",
+ "//runsc/boot",
+ "//runsc/console",
+ "//runsc/container",
+ "//runsc/flag",
+ "//runsc/fsgofer",
+ "//runsc/fsgofer/filter",
+ "//runsc/specutils",
+ "@com_github_google_subcommands//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "cmd_test",
+ size = "small",
+ srcs = [
+ "capability_test.go",
+ "delete_test.go",
+ "exec_test.go",
+ "gofer_test.go",
+ ],
+ data = [
+ "//runsc",
+ ],
+ library = ":cmd",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/log",
+ "//pkg/sentry/control",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/test/testutil",
+ "//pkg/urpc",
+ "//runsc/boot",
+ "//runsc/container",
+ "//runsc/specutils",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ ],
+)
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
new file mode 100644
index 000000000..01204ab4d
--- /dev/null
+++ b/runsc/cmd/boot.go
@@ -0,0 +1,290 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "os"
+ "runtime/debug"
+ "strings"
+ "syscall"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Boot implements subcommands.Command for the "boot" command which starts a
+// new sandbox. It should not be called directly.
+type Boot struct {
+ // bundleDir is the directory containing the OCI spec.
+ bundleDir string
+
+ // specFD is the file descriptor that the spec will be read from.
+ specFD int
+
+ // controllerFD is the file descriptor of a stream socket for the
+ // control server that is donated to this process.
+ controllerFD int
+
+ // deviceFD is the file descriptor for the platform device file.
+ deviceFD int
+
+ // ioFDs is the list of FDs used to connect to FS gofers.
+ ioFDs intFlags
+
+ // stdioFDs are the fds for stdin, stdout, and stderr. They must be
+ // provided in that order.
+ stdioFDs intFlags
+
+ // console is set to true if the sandbox should allow terminal ioctl(2)
+ // syscalls.
+ console bool
+
+ // applyCaps determines if capabilities defined in the spec should be applied
+ // to the process.
+ applyCaps bool
+
+ // setUpRoot is set to true if the sandbox should be started in an empty root.
+ setUpRoot bool
+
+ // cpuNum is the number of CPUs to create inside the sandbox.
+ cpuNum int
+
+ // totalMem sets the initial amount of total memory to report back to the
+ // container.
+ totalMem uint64
+
+ // userLogFD is the file descriptor to write user logs to.
+ userLogFD int
+
+ // startSyncFD is the file descriptor to synchronize runsc and sandbox.
+ startSyncFD int
+
+ // mountsFD is the file descriptor to read list of mounts after they have
+ // been resolved (direct paths, no symlinks). They are resolved outside the
+ // sandbox (e.g. gofer) and sent through this FD.
+ mountsFD int
+
+ // pidns is set if the sandbox is in its own pid namespace.
+ pidns bool
+
+ // attached is set to true to kill the sandbox process when the parent process
+ // terminates. This flag is set when the command execve's itself because
+ // parent death signal doesn't propagate through execve when uid/gid changes.
+ attached bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Boot) Name() string {
+ return "boot"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Boot) Synopsis() string {
+ return "launch a sandbox process (internal use only)"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Boot) Usage() string {
+ return `boot [flags] <container id>`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (b *Boot) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&b.bundleDir, "bundle", "", "required path to the root of the bundle directory")
+ f.IntVar(&b.specFD, "spec-fd", -1, "required fd with the container spec")
+ f.IntVar(&b.controllerFD, "controller-fd", -1, "required FD of a stream socket for the control server that must be donated to this process")
+ f.IntVar(&b.deviceFD, "device-fd", -1, "FD for the platform device file")
+ f.Var(&b.ioFDs, "io-fds", "list of FDs to connect 9P clients. They must follow this order: root first, then mounts as defined in the spec")
+ f.Var(&b.stdioFDs, "stdio-fds", "list of FDs containing sandbox stdin, stdout, and stderr in that order")
+ f.BoolVar(&b.console, "console", false, "set to true if the sandbox should allow terminal ioctl(2) syscalls")
+ f.BoolVar(&b.applyCaps, "apply-caps", false, "if true, apply capabilities defined in the spec to the process")
+ f.BoolVar(&b.setUpRoot, "setup-root", false, "if true, set up an empty root for the process")
+ f.BoolVar(&b.pidns, "pidns", false, "if true, the sandbox is in its own PID namespace")
+ f.IntVar(&b.cpuNum, "cpu-num", 0, "number of CPUs to create inside the sandbox")
+ f.Uint64Var(&b.totalMem, "total-memory", 0, "sets the initial amount of total memory to report back to the container")
+ f.IntVar(&b.userLogFD, "user-log-fd", 0, "file descriptor to write user logs to. 0 means no logging.")
+ f.IntVar(&b.startSyncFD, "start-sync-fd", -1, "required FD to used to synchronize sandbox startup")
+ f.IntVar(&b.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to read list of mounts after they have been resolved (direct paths, no symlinks).")
+ f.BoolVar(&b.attached, "attached", false, "if attached is true, kills the sandbox process when the parent process terminates")
+}
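+
+// An illustrative invocation (hand-written, not captured from a real run;
+// runsc constructs this command line internally when starting a sandbox):
+//
+//   runsc boot --bundle=/path/to/bundle --spec-fd=3 --controller-fd=4
+//       --start-sync-fd=5 --io-fds=6 --io-fds=7 --stdio-fds=0 --stdio-fds=1
+//       --stdio-fds=2 <container id>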
+
+// Execute implements subcommands.Command.Execute. It starts a sandbox in a
+// waiting state.
+func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if b.specFD == -1 || b.controllerFD == -1 || b.startSyncFD == -1 || f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ // Ensure that if there is a panic, all goroutine stacks are printed.
+ debug.SetTraceback("system")
+
+ conf := args[0].(*boot.Config)
+
+ if b.attached {
+ // Ensure this process is killed after parent process terminates when
+ // attached mode is enabled. In the unfortunate event that the parent
+ // terminates before this point, this process leaks.
+ if err := unix.Prctl(unix.PR_SET_PDEATHSIG, uintptr(unix.SIGKILL), 0, 0, 0); err != nil {
+ Fatalf("error setting parent death signal: %v", err)
+ }
+ }
+
+ if b.setUpRoot {
+ if err := setUpChroot(b.pidns); err != nil {
+ Fatalf("error setting up chroot: %v", err)
+ }
+
+ if !b.applyCaps && !conf.Rootless {
+ // Remove --setup-root arg to call myself. It has already been done.
+ args := prepareArgs(b.attached, "setup-root")
+
+ // Note that we've already read the spec from the spec FD, and
+ // we will read it again after the exec call. This works
+ // because the ReadSpecFromFile function seeks to the beginning
+ // of the file before reading.
+ if err := callSelfAsNobody(args); err != nil {
+ Fatalf("%v", err)
+ }
+ panic("callSelfAsNobody must never return success")
+ }
+ }
+
+ // Get the spec from the specFD.
+ specFile := os.NewFile(uintptr(b.specFD), "spec file")
+ defer specFile.Close()
+ spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile)
+ if err != nil {
+ Fatalf("reading spec: %v", err)
+ }
+ specutils.LogSpec(spec)
+
+ if b.applyCaps {
+ caps := spec.Process.Capabilities
+ if caps == nil {
+ caps = &specs.LinuxCapabilities{}
+ }
+
+ gPlatform, err := platform.Lookup(conf.Platform)
+ if err != nil {
+ Fatalf("loading platform: %v", err)
+ }
+ if gPlatform.Requirements().RequiresCapSysPtrace {
+ // Ptrace platform requires extra capabilities.
+ const c = "CAP_SYS_PTRACE"
+ caps.Bounding = append(caps.Bounding, c)
+ caps.Effective = append(caps.Effective, c)
+ caps.Permitted = append(caps.Permitted, c)
+ }
+
+ // Remove --apply-caps and --setup-root arg to call myself. Both have
+ // already been done.
+ args := prepareArgs(b.attached, "setup-root", "apply-caps")
+
+ // Note that we've already read the spec from the spec FD, and
+ // we will read it again after the exec call. This works
+ // because the ReadSpecFromFile function seeks to the beginning
+ // of the file before reading.
+ if err := setCapsAndCallSelf(args, caps); err != nil {
+ Fatalf("%v", err)
+ }
+ panic("setCapsAndCallSelf must never return success")
+ }
+
+ // Read resolved mount list and replace the original one from the spec.
+ mountsFile := os.NewFile(uintptr(b.mountsFD), "mounts file")
+ cleanMounts, err := specutils.ReadMounts(mountsFile)
+ if err != nil {
+ mountsFile.Close()
+ Fatalf("Error reading mounts file: %v", err)
+ }
+ mountsFile.Close()
+ spec.Mounts = cleanMounts
+
+ // Create the loader.
+ bootArgs := boot.Args{
+ ID: f.Arg(0),
+ Spec: spec,
+ Conf: conf,
+ ControllerFD: b.controllerFD,
+ Device: os.NewFile(uintptr(b.deviceFD), "platform device"),
+ GoferFDs: b.ioFDs.GetArray(),
+ StdioFDs: b.stdioFDs.GetArray(),
+ Console: b.console,
+ NumCPU: b.cpuNum,
+ TotalMem: b.totalMem,
+ UserLogFD: b.userLogFD,
+ }
+ l, err := boot.New(bootArgs)
+ if err != nil {
+ Fatalf("creating loader: %v", err)
+ }
+
+ // Fatalf exits the process and doesn't run defers.
+ // 'l' must be destroyed explicitly after this point!
+
+ // Notify the parent process the sandbox has booted (and that the controller
+ // is up).
+ startSyncFile := os.NewFile(uintptr(b.startSyncFD), "start-sync file")
+ buf := make([]byte, 1)
+ if w, err := startSyncFile.Write(buf); err != nil || w != 1 {
+ l.Destroy()
+ Fatalf("unable to write into the start-sync descriptor: %v", err)
+ }
+ // Close startSyncFile; 'l.Run()' only returns when the sandbox exits.
+ startSyncFile.Close()
+
+ // Wait for the start signal from runsc.
+ l.WaitForStartSignal()
+
+ // Run the application and wait for it to finish.
+ if err := l.Run(); err != nil {
+ l.Destroy()
+ Fatalf("running sandbox: %v", err)
+ }
+
+ ws := l.WaitExit()
+ log.Infof("application exiting with %+v", ws)
+ waitStatus := args[1].(*syscall.WaitStatus)
+ *waitStatus = syscall.WaitStatus(ws.Status())
+ l.Destroy()
+ return subcommands.ExitSuccess
+}
+
+func prepareArgs(attached bool, exclude ...string) []string {
+ var args []string
+ for _, arg := range os.Args {
+ for _, excl := range exclude {
+ if strings.Contains(arg, excl) {
+ goto skip
+ }
+ }
+ args = append(args, arg)
+ if attached && arg == "boot" {
+ // Strategically place "--attached" after the command. This is needed
+ // to ensure the new process is killed when the parent process terminates.
+ args = append(args, "--attached")
+ }
+ skip:
+ }
+ return args
+}
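+
+// For example (illustrative), with os.Args of
+//
+//   ["runsc", "--root=/run/runsc", "boot", "--setup-root", "--bundle=/b", "abc"]
+//
+// prepareArgs(true, "setup-root") drops the matching flag and inserts
+// "--attached" right after the command, returning
+//
+//   ["runsc", "--root=/run/runsc", "boot", "--attached", "--bundle=/b", "abc"]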
diff --git a/runsc/cmd/capability.go b/runsc/cmd/capability.go
new file mode 100644
index 000000000..abfbb7cfc
--- /dev/null
+++ b/runsc/cmd/capability.go
@@ -0,0 +1,157 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "fmt"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var allCapTypes = []capability.CapType{
+ capability.BOUNDS,
+ capability.EFFECTIVE,
+ capability.PERMITTED,
+ capability.INHERITABLE,
+ capability.AMBIENT,
+}
+
+// applyCaps applies the capabilities in the spec to the current thread.
+//
+// Note that it must be called with the current OS thread locked.
+func applyCaps(caps *specs.LinuxCapabilities) error {
+ // Load current capabilities to trim the ones not permitted.
+ curCaps, err := capability.NewPid2(0)
+ if err != nil {
+ return err
+ }
+ if err := curCaps.Load(); err != nil {
+ return err
+ }
+
+ // Create an empty capability set to populate.
+ newCaps, err := capability.NewPid2(0)
+ if err != nil {
+ return err
+ }
+
+ for _, c := range allCapTypes {
+ if !newCaps.Empty(c) {
+ panic("unloaded capabilities must be empty")
+ }
+ set, err := trimCaps(getCaps(c, caps), curCaps)
+ if err != nil {
+ return err
+ }
+ newCaps.Set(c, set...)
+ }
+
+ if err := newCaps.Apply(capability.CAPS | capability.BOUNDS | capability.AMBS); err != nil {
+ return err
+ }
+ log.Infof("Capabilities applied: %+v", newCaps)
+ return nil
+}
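+
+// Illustrative use (a minimal sketch; callers such as setCapsAndCallSelf in
+// cmd.go lock the OS thread first, since capabilities apply per thread):
+//
+//   runtime.LockOSThread()
+//   defer runtime.UnlockOSThread()
+//   caps := &specs.LinuxCapabilities{
+//   	Effective: []string{"CAP_SYS_CHROOT"},
+//   	Permitted: []string{"CAP_SYS_CHROOT"},
+//   }
+//   if err := applyCaps(caps); err != nil {
+//   	return err
+//   }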
+
+func getCaps(which capability.CapType, caps *specs.LinuxCapabilities) []string {
+ switch which {
+ case capability.BOUNDS:
+ return caps.Bounding
+ case capability.EFFECTIVE:
+ return caps.Effective
+ case capability.PERMITTED:
+ return caps.Permitted
+ case capability.INHERITABLE:
+ return caps.Inheritable
+ case capability.AMBIENT:
+ return caps.Ambient
+ }
+ panic(fmt.Sprint("invalid capability type:", which))
+}
+
+func trimCaps(names []string, setter capability.Capabilities) ([]capability.Cap, error) {
+ wantedCaps, err := capsFromNames(names)
+ if err != nil {
+ return nil, err
+ }
+
+ // Trim down capabilities that aren't possible to acquire.
+ var caps []capability.Cap
+ for _, c := range wantedCaps {
+ // Capability rules are more complicated than this, but this catches most
+ // problems with tests running as a non-privileged user.
+ if setter.Get(capability.PERMITTED, c) {
+ caps = append(caps, c)
+ } else {
+ log.Warningf("Capability %q is not permitted, dropping it.", c)
+ }
+ }
+ return caps, nil
+}
+
+func capsFromNames(names []string) ([]capability.Cap, error) {
+ var caps []capability.Cap
+ for _, name := range names {
+ cap, ok := capFromName[name]
+ if !ok {
+ return nil, fmt.Errorf("invalid capability %q", name)
+ }
+ caps = append(caps, cap)
+ }
+ return caps, nil
+}
+
+var capFromName = map[string]capability.Cap{
+ "CAP_CHOWN": capability.CAP_CHOWN,
+ "CAP_DAC_OVERRIDE": capability.CAP_DAC_OVERRIDE,
+ "CAP_DAC_READ_SEARCH": capability.CAP_DAC_READ_SEARCH,
+ "CAP_FOWNER": capability.CAP_FOWNER,
+ "CAP_FSETID": capability.CAP_FSETID,
+ "CAP_KILL": capability.CAP_KILL,
+ "CAP_SETGID": capability.CAP_SETGID,
+ "CAP_SETUID": capability.CAP_SETUID,
+ "CAP_SETPCAP": capability.CAP_SETPCAP,
+ "CAP_LINUX_IMMUTABLE": capability.CAP_LINUX_IMMUTABLE,
+ "CAP_NET_BIND_SERVICE": capability.CAP_NET_BIND_SERVICE,
+ "CAP_NET_BROADCAST": capability.CAP_NET_BROADCAST,
+ "CAP_NET_ADMIN": capability.CAP_NET_ADMIN,
+ "CAP_NET_RAW": capability.CAP_NET_RAW,
+ "CAP_IPC_LOCK": capability.CAP_IPC_LOCK,
+ "CAP_IPC_OWNER": capability.CAP_IPC_OWNER,
+ "CAP_SYS_MODULE": capability.CAP_SYS_MODULE,
+ "CAP_SYS_RAWIO": capability.CAP_SYS_RAWIO,
+ "CAP_SYS_CHROOT": capability.CAP_SYS_CHROOT,
+ "CAP_SYS_PTRACE": capability.CAP_SYS_PTRACE,
+ "CAP_SYS_PACCT": capability.CAP_SYS_PACCT,
+ "CAP_SYS_ADMIN": capability.CAP_SYS_ADMIN,
+ "CAP_SYS_BOOT": capability.CAP_SYS_BOOT,
+ "CAP_SYS_NICE": capability.CAP_SYS_NICE,
+ "CAP_SYS_RESOURCE": capability.CAP_SYS_RESOURCE,
+ "CAP_SYS_TIME": capability.CAP_SYS_TIME,
+ "CAP_SYS_TTY_CONFIG": capability.CAP_SYS_TTY_CONFIG,
+ "CAP_MKNOD": capability.CAP_MKNOD,
+ "CAP_LEASE": capability.CAP_LEASE,
+ "CAP_AUDIT_WRITE": capability.CAP_AUDIT_WRITE,
+ "CAP_AUDIT_CONTROL": capability.CAP_AUDIT_CONTROL,
+ "CAP_SETFCAP": capability.CAP_SETFCAP,
+ "CAP_MAC_OVERRIDE": capability.CAP_MAC_OVERRIDE,
+ "CAP_MAC_ADMIN": capability.CAP_MAC_ADMIN,
+ "CAP_SYSLOG": capability.CAP_SYSLOG,
+ "CAP_WAKE_ALARM": capability.CAP_WAKE_ALARM,
+ "CAP_BLOCK_SUSPEND": capability.CAP_BLOCK_SUSPEND,
+ "CAP_AUDIT_READ": capability.CAP_AUDIT_READ,
+}
diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go
new file mode 100644
index 000000000..a84067112
--- /dev/null
+++ b/runsc/cmd/capability_test.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+func init() {
+ log.SetLevel(log.Debug)
+ if err := testutil.ConfigureExePath(); err != nil {
+ panic(err.Error())
+ }
+}
+
+func checkProcessCaps(pid int, wantCaps *specs.LinuxCapabilities) error {
+ curCaps, err := capability.NewPid2(pid)
+ if err != nil {
+ return fmt.Errorf("capability.NewPid2(%d) failed: %v", pid, err)
+ }
+ if err := curCaps.Load(); err != nil {
+ return fmt.Errorf("unable to load capabilities: %v", err)
+ }
+ fmt.Printf("Capabilities (PID: %d): %v\n", pid, curCaps)
+
+ for _, c := range allCapTypes {
+ if err := checkCaps(c, curCaps, wantCaps); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func checkCaps(which capability.CapType, curCaps capability.Capabilities, wantCaps *specs.LinuxCapabilities) error {
+ wantNames := getCaps(which, wantCaps)
+ for name, c := range capFromName {
+ want := specutils.ContainsStr(wantNames, name)
+ got := curCaps.Get(which, c)
+ if want != got {
+ if want {
+ return fmt.Errorf("capability %v:%s should be set", which, name)
+ }
+ return fmt.Errorf("capability %v:%s should NOT be set", which, name)
+ }
+ }
+ return nil
+}
+
+func TestCapabilities(t *testing.T) {
+ stop := testutil.StartReaper()
+ defer stop()
+
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
+ caps := []string{
+ "CAP_CHOWN",
+ "CAP_SYS_PTRACE", // ptrace is added due to the platform choice.
+ }
+ spec.Process.Capabilities = &specs.LinuxCapabilities{
+ Permitted: caps,
+ Bounding: caps,
+ Effective: caps,
+ Inheritable: caps,
+ }
+
+ conf := testutil.TestConfig(t)
+
+ // Use --network=host to make sandbox use spec's capabilities.
+ conf.Network = boot.NetworkHost
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := container.Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := container.New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Check that sandbox and gofer have the proper capabilities.
+ if err := checkProcessCaps(c.Sandbox.Pid, spec.Process.Capabilities); err != nil {
+ t.Error(err)
+ }
+ if err := checkProcessCaps(c.GoferPid, goferCaps); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestMain(m *testing.M) {
+ flag.Parse()
+ specutils.MaybeRunAsRoot()
+ os.Exit(m.Run())
+}
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
new file mode 100644
index 000000000..8a29e521e
--- /dev/null
+++ b/runsc/cmd/checkpoint.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "syscall"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// File containing the container's saved image/state within the given image-path's directory.
+const checkpointFileName = "checkpoint.img"
+
+// Checkpoint implements subcommands.Command for the "checkpoint" command.
+type Checkpoint struct {
+ imagePath string
+ leaveRunning bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Checkpoint) Name() string {
+ return "checkpoint"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Checkpoint) Synopsis() string {
+ return "checkpoint current state of container (experimental)"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Checkpoint) Usage() string {
+ return `checkpoint [flags] <container id> - save current state of container.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *Checkpoint) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&c.imagePath, "image-path", "", "directory path to saved container image")
+ f.BoolVar(&c.leaveRunning, "leave-running", false, "restart the container after checkpointing")
+
+ // Unimplemented flags necessary for compatibility with docker.
+ var wp string
+ f.StringVar(&wp, "work-path", "", "ignored")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+ waitStatus := args[1].(*syscall.WaitStatus)
+
+ cont, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+
+ if c.imagePath == "" {
+ Fatalf("image-path flag must be provided")
+ }
+
+ if err := os.MkdirAll(c.imagePath, 0755); err != nil {
+ Fatalf("making directories at path provided: %v", err)
+ }
+
+ fullImagePath := filepath.Join(c.imagePath, checkpointFileName)
+
+ // Create the image file and open for writing.
+ file, err := os.OpenFile(fullImagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ if err != nil {
+ Fatalf("os.OpenFile(%q) failed: %v", fullImagePath, err)
+ }
+ defer file.Close()
+
+ if err := cont.Checkpoint(file); err != nil {
+ Fatalf("checkpoint failed: %v", err)
+ }
+
+ if !c.leaveRunning {
+ return subcommands.ExitSuccess
+ }
+
+ // TODO(b/110843694): Make it possible to restore into same container.
+ // For now, we can fake it by destroying the container and making a
+ // new container with the same ID. This hack does not work with docker
+ // which uses the container pid to ensure that the restore-container is
+ // actually the same as the checkpoint-container. By restoring into
+ // the same container, we will solve the docker incompatibility.
+
+ // Restore into new container with same ID.
+ bundleDir := cont.BundleDir
+ if bundleDir == "" {
+ Fatalf("setting bundleDir")
+ }
+
+ spec, err := specutils.ReadSpec(bundleDir)
+ if err != nil {
+ Fatalf("reading spec: %v", err)
+ }
+
+ specutils.LogSpec(spec)
+
+ if cont.ConsoleSocket != "" {
+ log.Warningf("ignoring console socket since it cannot be restored")
+ }
+
+ if err := cont.Destroy(); err != nil {
+ Fatalf("destroying container: %v", err)
+ }
+
+ contArgs := container.Args{
+ ID: id,
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err = container.New(conf, contArgs)
+ if err != nil {
+ Fatalf("restoring container: %v", err)
+ }
+ defer cont.Destroy()
+
+ if err := cont.Restore(spec, conf, fullImagePath); err != nil {
+ Fatalf("starting container: %v", err)
+ }
+
+ ws, err := cont.Wait()
+ if err != nil {
+ Fatalf("waiting for container: %v", err)
+ }
+ *waitStatus = ws
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go
new file mode 100644
index 000000000..189244765
--- /dev/null
+++ b/runsc/cmd/chroot.go
@@ -0,0 +1,97 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// mountInChroot creates the destination mount point in the given chroot and
+// mounts the source.
+func mountInChroot(chroot, src, dst, typ string, flags uint32) error {
+ chrootDst := filepath.Join(chroot, dst)
+ log.Infof("Mounting %q at %q", src, chrootDst)
+
+ if err := specutils.Mount(src, chrootDst, typ, flags); err != nil {
+ return fmt.Errorf("error mounting %q at %q: %v", src, chrootDst, err)
+ }
+ return nil
+}
+
+func pivotRoot(root string) error {
+ if err := os.Chdir(root); err != nil {
+ return fmt.Errorf("error changing working directory: %v", err)
+ }
+ // pivot_root(new_root, put_old) moves the root filesystem (old_root)
+ // of the calling process to the directory put_old and makes new_root
+ // the new root filesystem of the calling process.
+ //
+ // pivot_root(".", ".") makes a mount of the working directory the new
+ // root filesystem, so it will be moved in "/" and then the old_root
+ // will be moved to "/" too. The parent mount of the old_root will be
+ // new_root, so after umounting the old_root, we will see only
+ // the new_root in "/".
+ if err := syscall.PivotRoot(".", "."); err != nil {
+ return fmt.Errorf("pivot_root failed, make sure that the root mount has a parent: %v", err)
+ }
+
+ if err := syscall.Unmount(".", syscall.MNT_DETACH); err != nil {
+ return fmt.Errorf("error umounting the old root file system: %v", err)
+ }
+ return nil
+}
+
+// setUpChroot creates an empty, read-only root directory with only proc
+// mounted at /proc, and pivots into it.
+func setUpChroot(pidns bool) error {
+ // We are in a new mount namespace, so we can use /tmp as a directory to
+ // construct a new root.
+ chroot := os.TempDir()
+
+ log.Infof("Setting up sandbox chroot in %q", chroot)
+
+ // Convert all shared mounts into slave mounts to be sure that nothing is
+ // propagated outside of our namespace.
+ if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil {
+ return fmt.Errorf("error converting mounts: %v", err)
+ }
+
+ if err := syscall.Mount("runsc-root", chroot, "tmpfs", syscall.MS_NOSUID|syscall.MS_NODEV|syscall.MS_NOEXEC, ""); err != nil {
+ return fmt.Errorf("error mounting tmpfs in choot: %v", err)
+ }
+
+ if pidns {
+ flags := uint32(syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC | syscall.MS_RDONLY)
+ if err := mountInChroot(chroot, "proc", "/proc", "proc", flags); err != nil {
+ return fmt.Errorf("error mounting proc in chroot: %v", err)
+ }
+ } else {
+ if err := mountInChroot(chroot, "/proc", "/proc", "bind", syscall.MS_BIND|syscall.MS_RDONLY|syscall.MS_REC); err != nil {
+ return fmt.Errorf("error mounting proc in chroot: %v", err)
+ }
+ }
+
+ if err := syscall.Mount("", chroot, "", syscall.MS_REMOUNT|syscall.MS_RDONLY|syscall.MS_BIND, ""); err != nil {
+ return fmt.Errorf("error remounting chroot in read-only: %v", err)
+ }
+
+ return pivotRoot(chroot)
+}
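+
+// The resulting root is intentionally minimal (illustrative layout):
+//
+//   /       tmpfs, remounted read-only
+//   /proc   a fresh proc mount (own pid namespace) or a read-only bind of /proc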
diff --git a/runsc/cmd/cmd.go b/runsc/cmd/cmd.go
new file mode 100644
index 000000000..f1a4887ef
--- /dev/null
+++ b/runsc/cmd/cmd.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cmd holds implementations of the runsc commands.
+package cmd
+
+import (
+ "fmt"
+ "runtime"
+ "strconv"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// intFlags can be used with int flags that appear multiple times.
+type intFlags []int
+
+// String implements flag.Value.
+func (i *intFlags) String() string {
+ return fmt.Sprintf("%v", *i)
+}
+
+// Get implements flag.Value.
+func (i *intFlags) Get() interface{} {
+ return i
+}
+
+// GetArray returns array of FDs.
+func (i *intFlags) GetArray() []int {
+ return *i
+}
+
+// Set implements flag.Value.
+func (i *intFlags) Set(s string) error {
+ fd, err := strconv.Atoi(s)
+ if err != nil {
+ return fmt.Errorf("invalid flag value: %v", err)
+ }
+ if fd < 0 {
+ return fmt.Errorf("flag value must be greater than 0: %d", fd)
+ }
+ *i = append(*i, fd)
+ return nil
+}
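+
+// Illustrative use: registering an intFlags value with flag.Var lets a flag be
+// repeated, so "--io-fds=6 --io-fds=7" accumulates both values:
+//
+//   var ioFDs intFlags
+//   f.Var(&ioFDs, "io-fds", "list of FDs")
+//   // After parsing: ioFDs.GetArray() == []int{6, 7}.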
+
+// setCapsAndCallSelf sets capabilities on the current thread and then execve's
+// itself again with the arguments specified in 'args' to restart the process
+// with the desired capabilities.
+func setCapsAndCallSelf(args []string, caps *specs.LinuxCapabilities) error {
+ // Keep thread locked while capabilities are changed.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ if err := applyCaps(caps); err != nil {
+ return fmt.Errorf("applyCaps() failed: %v", err)
+ }
+ binPath := specutils.ExePath
+
+ log.Infof("Execve %q again, bye!", binPath)
+ err := syscall.Exec(binPath, args, []string{})
+ return fmt.Errorf("error executing %s: %v", binPath, err)
+}
+
+// callSelfAsNobody sets UID and GID to nobody and then execve's itself again.
+func callSelfAsNobody(args []string) error {
+ // Keep thread locked while user/group are changed.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ const nobody = 65534
+
+ if _, _, err := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(nobody), 0, 0); err != 0 {
+ return fmt.Errorf("error setting uid: %v", err)
+ }
+ if _, _, err := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(nobody), 0, 0); err != 0 {
+ return fmt.Errorf("error setting gid: %v", err)
+ }
+
+ binPath := specutils.ExePath
+
+ log.Infof("Execve %q again, bye!", binPath)
+ err := syscall.Exec(binPath, args, []string{})
+ return fmt.Errorf("error executing %s: %v", binPath, err)
+}
diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go
new file mode 100644
index 000000000..910e97577
--- /dev/null
+++ b/runsc/cmd/create.go
@@ -0,0 +1,115 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Create implements subcommands.Command for the "create" command.
+type Create struct {
+ // bundleDir is the path to the bundle directory (defaults to the
+ // current working directory).
+ bundleDir string
+
+ // pidFile is the filename that the sandbox pid will be written to.
+ // This file should only be created once the container process inside
+ // the sandbox is ready to use.
+ pidFile string
+
+ // consoleSocket is the path to an AF_UNIX socket which will receive a
+ // file descriptor referencing the master end of the console's
+ // pseudoterminal. This is ignored unless spec.Process.Terminal is
+ // true.
+ consoleSocket string
+
+ // userLog is the path to send user-visible logs to. This log is different
+ // from debug logs. The former is meant to be consumed by the users and should
+ // contain only information that is relevant to the person running the
+ // container, e.g. unsupported syscalls, while the latter is more verbose and
+ // consumed by developers.
+ userLog string
+}
+
+// Name implements subcommands.Command.Name.
+func (*Create) Name() string {
+ return "create"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Create) Synopsis() string {
+ return "create a secure container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Create) Usage() string {
+ return `create [flags] <container id> - create a secure container
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *Create) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&c.bundleDir, "bundle", "", "path to the root of the bundle directory, defaults to the current directory")
+ f.StringVar(&c.consoleSocket, "console-socket", "", "path to an AF_UNIX socket which will receive a file descriptor referencing the master end of the console's pseudoterminal")
+ f.StringVar(&c.pidFile, "pid-file", "", "filename that the container pid will be written to")
+ f.StringVar(&c.userLog, "user-log", "", "filename to send user-visible logs to. Empty means no logging.")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *Create) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ if conf.Rootless {
+ return Errorf("Rootless mode not supported with %q", c.Name())
+ }
+
+ bundleDir := c.bundleDir
+ if bundleDir == "" {
+ bundleDir = getwdOrDie()
+ }
+ spec, err := specutils.ReadSpec(bundleDir)
+ if err != nil {
+ return Errorf("reading spec: %v", err)
+ }
+ specutils.LogSpec(spec)
+
+ // Create the container. A new sandbox will be created for the
+ // container unless the metadata specifies that it should be run in an
+ // existing sandbox.
+ contArgs := container.Args{
+ ID: id,
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: c.consoleSocket,
+ PIDFile: c.pidFile,
+ UserLog: c.userLog,
+ }
+ if _, err := container.New(conf, contArgs); err != nil {
+ return Errorf("creating container: %v", err)
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
new file mode 100644
index 000000000..b5de2588b
--- /dev/null
+++ b/runsc/cmd/debug.go
@@ -0,0 +1,304 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "os"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Debug implements subcommands.Command for the "debug" command.
+type Debug struct {
+ pid int
+ stacks bool
+ signal int
+ profileHeap string
+ profileCPU string
+ profileGoroutine string
+ profileBlock string
+ profileMutex string
+ trace string
+ strace string
+ logLevel string
+ logPackets string
+ duration time.Duration
+ ps bool
+}
+
+// Name implements subcommands.Command.
+func (*Debug) Name() string {
+ return "debug"
+}
+
+// Synopsis implements subcommands.Command.
+func (*Debug) Synopsis() string {
+ return "shows a variety of debug information"
+}
+
+// Usage implements subcommands.Command.
+func (*Debug) Usage() string {
+ return `debug [flags] <container id>`
+}
+
+// SetFlags implements subcommands.Command.
+func (d *Debug) SetFlags(f *flag.FlagSet) {
+ f.IntVar(&d.pid, "pid", 0, "sandbox process ID. Container ID is not necessary if this is set")
+ f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log")
+ f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.")
+ f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
+ f.StringVar(&d.profileGoroutine, "profile-goroutine", "", "writes goroutine profile to the given file.")
+ f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.")
+ f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.")
+ f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles")
+ f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.")
+ f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox")
+ f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all`)
+ f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).")
+ f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.")
+ f.BoolVar(&d.ps, "ps", false, "lists processes")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ var c *container.Container
+ conf := args[0].(*boot.Config)
+
+ if d.pid == 0 {
+ // No pid, container ID must have been provided.
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ var err error
+ c, err = container.Load(conf.RootDir, f.Arg(0))
+ if err != nil {
+ return Errorf("loading container %q: %v", f.Arg(0), err)
+ }
+ } else {
+ if f.NArg() != 0 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ // Go over all sandboxes and find the one that matches PID.
+ ids, err := container.List(conf.RootDir)
+ if err != nil {
+ return Errorf("listing containers: %v", err)
+ }
+ for _, id := range ids {
+ candidate, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ return Errorf("loading container %q: %v", id, err)
+ }
+ if candidate.SandboxPid() == d.pid {
+ c = candidate
+ break
+ }
+ }
+ if c == nil {
+ return Errorf("container with PID %d not found", d.pid)
+ }
+ }
+
+ if c.Sandbox == nil || !c.Sandbox.IsRunning() {
+ return Errorf("container sandbox is not running")
+ }
+ log.Infof("Found sandbox %q, PID: %d", c.Sandbox.ID, c.Sandbox.Pid)
+
+ if d.signal > 0 {
+ log.Infof("Sending signal %d to process: %d", d.signal, c.Sandbox.Pid)
+ if err := syscall.Kill(c.Sandbox.Pid, syscall.Signal(d.signal)); err != nil {
+ return Errorf("failed to send signal %d to processs %d", d.signal, c.Sandbox.Pid)
+ }
+ }
+ if d.stacks {
+ log.Infof("Retrieving sandbox stacks")
+ stacks, err := c.Sandbox.Stacks()
+ if err != nil {
+ return Errorf("retrieving stacks: %v", err)
+ }
+ log.Infof(" *** Stack dump ***\n%s", stacks)
+ }
+ if d.profileHeap != "" {
+ f, err := os.Create(d.profileHeap)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.HeapProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Heap profile written to %q", d.profileHeap)
+ }
+ if d.profileGoroutine != "" {
+ f, err := os.Create(d.profileGoroutine)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.GoroutineProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Goroutine profile written to %q", d.profileGoroutine)
+ }
+ if d.profileBlock != "" {
+ f, err := os.Create(d.profileBlock)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.BlockProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Block profile written to %q", d.profileBlock)
+ }
+ if d.profileMutex != "" {
+ f, err := os.Create(d.profileMutex)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.MutexProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Mutex profile written to %q", d.profileMutex)
+ }
+
+ delay := false
+ if d.profileCPU != "" {
+ delay = true
+ f, err := os.Create(d.profileCPU)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer func() {
+ f.Close()
+ if err := c.Sandbox.StopCPUProfile(); err != nil {
+ Fatalf(err.Error())
+ }
+ log.Infof("CPU profile written to %q", d.profileCPU)
+ }()
+ if err := c.Sandbox.StartCPUProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("CPU profile started for %v, writing to %q", d.duration, d.profileCPU)
+ }
+ if d.trace != "" {
+ delay = true
+ f, err := os.Create(d.trace)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer func() {
+ f.Close()
+ if err := c.Sandbox.StopTrace(); err != nil {
+ Fatalf(err.Error())
+ }
+ log.Infof("Trace written to %q", d.trace)
+ }()
+ if err := c.Sandbox.StartTrace(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Tracing started for %v, writing to %q", d.duration, d.trace)
+ }
+
+ if d.strace != "" || len(d.logLevel) != 0 || len(d.logPackets) != 0 {
+ args := control.LoggingArgs{}
+ switch strings.ToLower(d.strace) {
+ case "":
+ // strace not set, nothing to do here.
+
+ case "off":
+ log.Infof("Disabling strace")
+ args.SetStrace = true
+
+ case "all":
+ log.Infof("Enabling all straces")
+ args.SetStrace = true
+ args.EnableStrace = true
+
+ default:
+ log.Infof("Enabling strace for syscalls: %s", d.strace)
+ args.SetStrace = true
+ args.EnableStrace = true
+ args.StraceWhitelist = strings.Split(d.strace, ",")
+ }
+
+ if len(d.logLevel) != 0 {
+ args.SetLevel = true
+ switch strings.ToLower(d.logLevel) {
+ case "warning", "0":
+ args.Level = log.Warning
+ case "info", "1":
+ args.Level = log.Info
+ case "debug", "2":
+ args.Level = log.Debug
+ default:
+ return Errorf("invalid log level %q", d.logLevel)
+ }
+ log.Infof("Setting log level %v", args.Level)
+ }
+
+ if len(d.logPackets) != 0 {
+ args.SetLogPackets = true
+ lp, err := strconv.ParseBool(d.logPackets)
+ if err != nil {
+ return Errorf("invalid value for log_packets %q", d.logPackets)
+ }
+ args.LogPackets = lp
+ if args.LogPackets {
+ log.Infof("Enabling packet logging")
+ } else {
+ log.Infof("Disabling packet logging")
+ }
+ }
+
+ if err := c.Sandbox.ChangeLogging(args); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Logging options changed")
+ }
+ if d.ps {
+ pList, err := c.Processes()
+ if err != nil {
+ Fatalf("getting processes for container: %v", err)
+ }
+ o, err := control.ProcessListToJSON(pList)
+ if err != nil {
+ Fatalf("generating JSON: %v", err)
+ }
+ log.Infof("%s", o)
+ }
+
+ if delay {
+ time.Sleep(d.duration)
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go
new file mode 100644
index 000000000..0e4863f50
--- /dev/null
+++ b/runsc/cmd/delete.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Delete implements subcommands.Command for the "delete" command.
+type Delete struct {
+ // force indicates that the container should be terminated if running.
+ force bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Delete) Name() string {
+ return "delete"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Delete) Synopsis() string {
+ return "delete resources held by a container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Delete) Usage() string {
+ return `delete [flags] <container ids>`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (d *Delete) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&d.force, "force", false, "terminate container if running")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() == 0 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ conf := args[0].(*boot.Config)
+ if err := d.execute(f.Args(), conf); err != nil {
+ Fatalf("%v", err)
+ }
+ return subcommands.ExitSuccess
+}
+
+func (d *Delete) execute(ids []string, conf *boot.Config) error {
+ for _, id := range ids {
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ if os.IsNotExist(err) && d.force {
+ log.Warningf("couldn't find container %q: %v", id, err)
+ // Keep going so the remaining ids are still deleted.
+ continue
+ }
+ return fmt.Errorf("loading container %q: %v", id, err)
+ }
+ if !d.force && c.Status != container.Created && c.Status != container.Stopped {
+ return fmt.Errorf("cannot delete container that is not stopped without --force flag")
+ }
+ if err := c.Destroy(); err != nil {
+ return fmt.Errorf("destroying container: %v", err)
+ }
+ }
+ return nil
+}
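+
+// Illustrative example: "runsc delete --force <id1> <id2>" destroys both
+// containers, terminating them first if they are still running.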
diff --git a/runsc/cmd/delete_test.go b/runsc/cmd/delete_test.go
new file mode 100644
index 000000000..cb59516a3
--- /dev/null
+++ b/runsc/cmd/delete_test.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "io/ioutil"
+ "testing"
+
+ "gvisor.dev/gvisor/runsc/boot"
+)
+
+func TestNotFound(t *testing.T) {
+ ids := []string{"123"}
+ dir, err := ioutil.TempDir("", "metadata")
+ if err != nil {
+ t.Fatalf("error creating dir: %v", err)
+ }
+ conf := &boot.Config{RootDir: dir}
+
+ d := Delete{}
+ if err := d.execute(ids, conf); err == nil {
+ t.Error("Deleting non-existent container should have failed")
+ }
+
+ d = Delete{force: true}
+ if err := d.execute(ids, conf); err != nil {
+ t.Errorf("Deleting non-existent container with --force should NOT have failed: %v", err)
+ }
+}
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
new file mode 100644
index 000000000..7d1310c96
--- /dev/null
+++ b/runsc/cmd/do.go
@@ -0,0 +1,385 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Do implements subcommands.Command for the "do" command. It sets up a simple
+// sandbox and executes the command inside it. See Usage() for more details.
+type Do struct {
+ root string
+ cwd string
+ ip string
+ quiet bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Do) Name() string {
+ return "do"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Do) Synopsis() string {
+ return "Simplistic way to execute a command inside the sandbox. It's to be used for testing only."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Do) Usage() string {
+ return `do [flags] <cmd> - runs a command.
+
+This command starts a sandbox with the host filesystem mounted read-only
+inside it and a writable tmpfs overlay on top. The given command is executed
+inside the sandbox. It's meant for quickly testing applications without
+having to install or run Docker; it offers far fewer options and is intended
+for testing only.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *Do) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&c.root, "root", "/", `path to the root directory, defaults to "/"`)
+ f.StringVar(&c.cwd, "cwd", ".", "path to the working directory, defaults to the current directory")
+ f.StringVar(&c.ip, "ip", "192.168.10.2", "IPv4 address for the sandbox")
+ f.BoolVar(&c.quiet, "quiet", false, "suppress runsc messages to stdout. Application output is still sent to stdout and stderr")
+}
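+
+// Illustrative invocations using the flags defined above:
+//
+//	runsc do ls -l
+//	runsc do --ip=192.168.10.4 --quiet /bin/echo hello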
+
+// Execute implements subcommands.Command.Execute.
+func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if len(f.Args()) == 0 {
+ c.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ conf := args[0].(*boot.Config)
+ waitStatus := args[1].(*syscall.WaitStatus)
+
+ if conf.Rootless {
+ if err := specutils.MaybeRunAsRoot(); err != nil {
+ return Errorf("Error executing inside namespace: %v", err)
+ }
+ // Execution will continue here if no more capabilities are needed...
+ }
+
+ hostname, err := os.Hostname()
+ if err != nil {
+ return Errorf("Error to retrieve hostname: %v", err)
+ }
+
+ // Map the entire host file system, but make it readonly with a writable
+ // overlay on top (ignore --overlay option).
+ conf.Overlay = true
+ absRoot, err := resolvePath(c.root)
+ if err != nil {
+ return Errorf("Error resolving root: %v", err)
+ }
+ absCwd, err := resolvePath(c.cwd)
+ if err != nil {
+ return Errorf("Error resolving current directory: %v", err)
+ }
+
+ spec := &specs.Spec{
+ Root: &specs.Root{
+ Path: absRoot,
+ },
+ Process: &specs.Process{
+ Cwd: absCwd,
+ Args: f.Args(),
+ Env: os.Environ(),
+ Capabilities: specutils.AllCapabilities(),
+ },
+ Hostname: hostname,
+ }
+
+ specutils.LogSpec(spec)
+
+ cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000))
+ if conf.Network == boot.NetworkNone {
+ netns := specs.LinuxNamespace{
+ Type: specs.NetworkNamespace,
+ }
+ if spec.Linux != nil {
+ panic("spec.Linux is not nil")
+ }
+ spec.Linux = &specs.Linux{Namespaces: []specs.LinuxNamespace{netns}}
+
+ } else if conf.Rootless {
+ if conf.Network == boot.NetworkSandbox {
+ c.notifyUser("*** Warning: using host network due to --rootless ***")
+ conf.Network = boot.NetworkHost
+ }
+
+ } else {
+ clean, err := c.setupNet(cid, spec)
+ if err != nil {
+ return Errorf("Error setting up network: %v", err)
+ }
+ defer clean()
+ }
+
+ out, err := json.Marshal(spec)
+ if err != nil {
+ return Errorf("Error to marshal spec: %v", err)
+ }
+ tmpDir, err := ioutil.TempDir("", "runsc-do")
+ if err != nil {
+ return Errorf("Error to create tmp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ log.Infof("Changing configuration RootDir to %q", tmpDir)
+ conf.RootDir = tmpDir
+
+ cfgPath := filepath.Join(tmpDir, "config.json")
+ if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil {
+ return Errorf("Error write spec: %v", err)
+ }
+
+ containerArgs := container.Args{
+ ID: cid,
+ Spec: spec,
+ BundleDir: tmpDir,
+ Attached: true,
+ }
+ ct, err := container.New(conf, containerArgs)
+ if err != nil {
+ return Errorf("creating container: %v", err)
+ }
+ defer ct.Destroy()
+
+ if err := ct.Start(conf); err != nil {
+ return Errorf("starting container: %v", err)
+ }
+
+ // Forward signals to init in the container. Thus if we get SIGINT from
+ // ^C, the container exits gracefully, and we can clean up.
+ //
+ // N.B. There is still a window before this where a signal may kill
+ // this process, skipping cleanup.
+ stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */)
+ defer stopForwarding()
+
+ ws, err := ct.Wait()
+ if err != nil {
+ return Errorf("waiting for container: %v", err)
+ }
+
+ *waitStatus = ws
+ return subcommands.ExitSuccess
+}
+
+func (c *Do) notifyUser(format string, v ...interface{}) {
+ if !c.quiet {
+ fmt.Printf(format+"\n", v...)
+ }
+ log.Warningf(format, v...)
+}
+
+func resolvePath(path string) (string, error) {
+ var err error
+ path, err = filepath.Abs(path)
+ if err != nil {
+ return "", fmt.Errorf("resolving %q: %v", path, err)
+ }
+ path = filepath.Clean(path)
+ if err := syscall.Access(path, 0); err != nil {
+ return "", fmt.Errorf("unable to access %q: %v", path, err)
+ }
+ return path, nil
+}
+
+func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) {
+ dev, err := defaultDevice()
+ if err != nil {
+ return nil, err
+ }
+ peerIP, err := calculatePeerIP(c.ip)
+ if err != nil {
+ return nil, err
+ }
+ veth, peer := deviceNames(cid)
+
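+ // The commands below build the following topology (illustrative): the host
+ // routes and masquerades traffic via <dev>; <peer> (peerIP) stays on the
+ // host, while <veth> (c.ip) is moved into network namespace <cid>.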
+ cmds := []string{
+ fmt.Sprintf("ip link add %s type veth peer name %s", veth, peer),
+
+ // Setup device outside the namespace.
+ fmt.Sprintf("ip addr add %s/24 dev %s", peerIP, peer),
+ fmt.Sprintf("ip link set %s up", peer),
+
+ // Setup device inside the namespace.
+ fmt.Sprintf("ip netns add %s", cid),
+ fmt.Sprintf("ip link set %s netns %s", veth, cid),
+ fmt.Sprintf("ip netns exec %s ip addr add %s/24 dev %s", cid, c.ip, veth),
+ fmt.Sprintf("ip netns exec %s ip link set %s up", cid, veth),
+ fmt.Sprintf("ip netns exec %s ip link set lo up", cid),
+ fmt.Sprintf("ip netns exec %s ip route add default via %s", cid, peerIP),
+
+ // Enable network access.
+ "sysctl -w net.ipv4.ip_forward=1",
+ fmt.Sprintf("iptables -t nat -A POSTROUTING -s %s -o %s -j MASQUERADE", c.ip, dev),
+ fmt.Sprintf("iptables -A FORWARD -i %s -o %s -j ACCEPT", dev, peer),
+ fmt.Sprintf("iptables -A FORWARD -o %s -i %s -j ACCEPT", dev, peer),
+ }
+
+ for _, cmd := range cmds {
+ log.Debugf("Run %q", cmd)
+ args := strings.Split(cmd, " ")
+ cmd := exec.Command(args[0], args[1:]...)
+ if err := cmd.Run(); err != nil {
+ c.cleanupNet(cid, dev, "", "", "")
+ return nil, fmt.Errorf("failed to run %q: %v", cmd, err)
+ }
+ }
+
+ resolvPath, err := makeFile("/etc/resolv.conf", "nameserver 8.8.8.8\n", spec)
+ if err != nil {
+ c.cleanupNet(cid, dev, "", "", "")
+ return nil, err
+ }
+ hostnamePath, err := makeFile("/etc/hostname", cid+"\n", spec)
+ if err != nil {
+ c.cleanupNet(cid, dev, resolvPath, "", "")
+ return nil, err
+ }
+ hosts := fmt.Sprintf("127.0.0.1\tlocalhost\n%s\t%s\n", c.ip, cid)
+ hostsPath, err := makeFile("/etc/hosts", hosts, spec)
+ if err != nil {
+ c.cleanupNet(cid, dev, resolvPath, hostnamePath, "")
+ return nil, err
+ }
+
+ if spec.Linux == nil {
+ spec.Linux = &specs.Linux{}
+ }
+ netns := specs.LinuxNamespace{
+ Type: specs.NetworkNamespace,
+ Path: filepath.Join("/var/run/netns", cid),
+ }
+ spec.Linux.Namespaces = append(spec.Linux.Namespaces, netns)
+
+ return func() { c.cleanupNet(cid, dev, resolvPath, hostnamePath, hostsPath) }, nil
+}
+
+// cleanupNet tries to cleanup the network setup in setupNet.
+//
+// It may be called when setupNet is only partially complete, in which case it
+// will cleanup as much as possible, logging warnings for the rest.
+//
+// Unfortunately none of this can be automatically cleaned up on process exit,
+// so we must do it explicitly.
+func (c *Do) cleanupNet(cid, dev, resolvPath, hostnamePath, hostsPath string) {
+ _, peer := deviceNames(cid)
+
+ cmds := []string{
+ fmt.Sprintf("ip link delete %s", peer),
+ fmt.Sprintf("ip netns delete %s", cid),
+ }
+
+ for _, cmd := range cmds {
+ log.Debugf("Run %q", cmd)
+ args := strings.Split(cmd, " ")
+ c := exec.Command(args[0], args[1:]...)
+ if err := c.Run(); err != nil {
+ log.Warningf("Failed to run %q: %v", cmd, err)
+ }
+ }
+
+ tryRemove(resolvPath)
+ tryRemove(hostnamePath)
+ tryRemove(hostsPath)
+}
+
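+// deviceNames returns the veth pair names derived from the container ID,
+// e.g. (illustrative) cid "runsc-123456" yields ("ve-runsc-123456",
+// "vp-runsc-123456"), exactly 15 characters each.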
+func deviceNames(cid string) (string, string) {
+ // Device name is limited to 15 letters.
+ return "ve-" + cid, "vp-" + cid
+}
+
+func defaultDevice() (string, error) {
+ out, err := exec.Command("ip", "route", "list", "default").CombinedOutput()
+ if err != nil {
+ return "", err
+ }
+ parts := strings.Split(string(out), " ")
+ if len(parts) < 5 {
+ return "", fmt.Errorf("malformed %q output: %q", "ip route list default", string(out))
+ }
+ return parts[4], nil
+}
+
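+// makeFile writes content to a temp file and appends a read-only bind mount
+// of it at dest to the spec. It returns the temp file's path so callers can
+// remove it later.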
+func makeFile(dest, content string, spec *specs.Spec) (string, error) {
+ tmpFile, err := ioutil.TempFile("", filepath.Base(dest))
+ if err != nil {
+ return "", err
+ }
+ if _, err := tmpFile.WriteString(content); err != nil {
+ if err := os.Remove(tmpFile.Name()); err != nil {
+ log.Warningf("Failed to remove %q: %v", tmpFile, err)
+ }
+ return "", err
+ }
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Source: tmpFile.Name(),
+ Destination: dest,
+ Type: "bind",
+ Options: []string{"ro"},
+ })
+ return tmpFile.Name(), nil
+}
+
+func tryRemove(path string) {
+ if path == "" {
+ return
+ }
+
+ if err := os.Remove(path); err != nil {
+ log.Warningf("Failed to remove %q: %v", path, err)
+ }
+}
+
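+// calculatePeerIP returns the next IPv4 address after ip, wrapping from 255
+// back to 1, e.g. "192.168.10.2" becomes "192.168.10.3".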
+func calculatePeerIP(ip string) (string, error) {
+ parts := strings.Split(ip, ".")
+ if len(parts) != 4 {
+ return "", fmt.Errorf("invalid IP format %q", ip)
+ }
+ n, err := strconv.Atoi(parts[3])
+ if err != nil {
+ return "", fmt.Errorf("invalid IP format %q: %v", ip, err)
+ }
+ n++
+ if n > 255 {
+ n = 1
+ }
+ return fmt.Sprintf("%s.%s.%s.%d", parts[0], parts[1], parts[2], n), nil
+}
diff --git a/runsc/cmd/error.go b/runsc/cmd/error.go
new file mode 100644
index 000000000..3585b5448
--- /dev/null
+++ b/runsc/cmd/error.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "time"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// ErrorLogger is where error messages should be written. These messages are
+// consumed by containerd and shown to users of command-line tools such as
+// docker/kubectl.
+var ErrorLogger io.Writer
+
+type jsonError struct {
+ Msg string `json:"msg"`
+ Level string `json:"level"`
+ Time time.Time `json:"time"`
+}
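+
+// When marshalled, a jsonError looks like (illustrative):
+//
+//	{"msg":"loading container: not found","level":"error","time":"2019-08-01T10:00:00Z"}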
+
+// Errorf logs the error to the containerd log (--log), to stderr, and to the
+// debug logs. It returns subcommands.ExitFailure for convenience with
+// subcommands.Command.Execute() methods:
+// return Errorf("Danger! Danger!")
+//
+func Errorf(format string, args ...interface{}) subcommands.ExitStatus {
+ // If runsc is being invoked by docker or cri-o, then we might not have
+ // access to stderr, so we log a serious-looking warning in addition to
+ // writing to stderr.
+ log.Warningf("FATAL ERROR: "+format, args...)
+ fmt.Fprintf(os.Stderr, format+"\n", args...)
+
+ j := jsonError{
+ Msg: fmt.Sprintf(format, args...),
+ Level: "error",
+ Time: time.Now(),
+ }
+ b, err := json.Marshal(j)
+ if err != nil {
+ panic(err)
+ }
+ if ErrorLogger != nil {
+ ErrorLogger.Write(b)
+ }
+
+ return subcommands.ExitFailure
+}
+
+// Fatalf logs the same way as Errorf() does, plus *exits* the process.
+func Fatalf(format string, args ...interface{}) {
+ Errorf(format, args...)
+ // Exit with a code that is unlikely to be used by the application.
+ os.Exit(128)
+}
diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go
new file mode 100644
index 000000000..51f6a98ed
--- /dev/null
+++ b/runsc/cmd/events.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "time"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Events implements subcommands.Command for the "events" command.
+type Events struct {
+ // The interval between stats reporting.
+ intervalSec int
+ // If true, events will print a single group of stats and exit.
+ stats bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Events) Name() string {
+ return "events"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Events) Synopsis() string {
+ return "display container events such as OOM notifications, cpu, memory, and IO usage statistics"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Events) Usage() string {
+ return `<container-id>
+
+Where "<container-id>" is the name for the instance of the container.
+
+The events command displays information about the container. By default the
+information is displayed once every 5 seconds.
+
+OPTIONS:
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (evs *Events) SetFlags(f *flag.FlagSet) {
+ f.IntVar(&evs.intervalSec, "interval", 5, "set the stats collection interval, in seconds")
+ f.BoolVar(&evs.stats, "stats", false, "display the container's stats then exit")
+}
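+
+// Illustrative example: "runsc events --stats <container-id>" prints a single
+// JSON-encoded event to stdout and exits.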
+
+// Execute implements subcommands.Command.Execute.
+func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading sandbox: %v", err)
+ }
+
+ // Repeatedly get stats from the container.
+ for {
+ // Get the event and print it as JSON.
+ ev, err := c.Event()
+ if err != nil {
+ log.Warningf("Error getting events for container: %v", err)
+ }
+ // err must be preserved because it is used below when breaking
+ // out of the loop.
+ b, err := json.Marshal(ev)
+ if err != nil {
+ log.Warningf("Error while marshalling event %v: %v", ev, err)
+ } else {
+ os.Stdout.Write(b)
+ }
+
+ // If we're only running once, break. If we're only running
+ // once and there was an error, the command failed.
+ if evs.stats {
+ if err != nil {
+ return subcommands.ExitFailure
+ }
+ break
+ }
+
+ time.Sleep(time.Duration(evs.intervalSec) * time.Second)
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
new file mode 100644
index 000000000..d9a94903e
--- /dev/null
+++ b/runsc/cmd/exec.go
@@ -0,0 +1,481 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/console"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Exec implements subcommands.Command for the "exec" command.
+type Exec struct {
+ cwd string
+ env stringSlice
+ // user contains the UID and GID with which to run the new process.
+ user user
+ extraKGIDs stringSlice
+ caps stringSlice
+ detach bool
+ processPath string
+ pidFile string
+ internalPidFile string
+
+ // consoleSocket is the path to an AF_UNIX socket which will receive a
+ // file descriptor referencing the master end of the console's
+ // pseudoterminal.
+ consoleSocket string
+}
+
+// Name implements subcommands.Command.Name.
+func (*Exec) Name() string {
+ return "exec"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Exec) Synopsis() string {
+ return "execute new process inside the container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Exec) Usage() string {
+ return `exec [command options] <container-id> <command> [command options] || --process process.json <container-id>
+
+
+Where "<container-id>" is the name for the instance of the container and
+"<command>" is the command to be executed in the container.
+"<command>" can't be empty unless a "-process" flag provided.
+
+EXAMPLE:
+If the container is configured to run /bin/ps the following will
+output a list of processes running in the container:
+
+ # runsc exec <container-id> ps
+
+OPTIONS:
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (ex *Exec) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&ex.cwd, "cwd", "", "current working directory")
+ f.Var(&ex.env, "env", "set environment variables (e.g. '-env PATH=/bin -env TERM=xterm')")
+ f.Var(&ex.user, "user", "UID (format: <uid>[:<gid>])")
+ f.Var(&ex.extraKGIDs, "additional-gids", "additional gids")
+ f.Var(&ex.caps, "cap", "add a capability to the bounding set for the process")
+ f.BoolVar(&ex.detach, "detach", false, "detach from the container's process")
+ f.StringVar(&ex.processPath, "process", "", "path to the process.json")
+ f.StringVar(&ex.pidFile, "pid-file", "", "filename that the container pid will be written to")
+ f.StringVar(&ex.internalPidFile, "internal-pid-file", "", "filename that the container-internal pid will be written to")
+ f.StringVar(&ex.consoleSocket, "console-socket", "", "path to an AF_UNIX socket which will receive a file descriptor referencing the master end of the console's pseudoterminal")
+}
+
+// Execute implements subcommands.Command.Execute. It starts a process in an
+// already created container.
+func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ conf := args[0].(*boot.Config)
+ e, id, err := ex.parseArgs(f, conf.EnableRaw)
+ if err != nil {
+ Fatalf("parsing process spec: %v", err)
+ }
+ waitStatus := args[1].(*syscall.WaitStatus)
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading sandbox: %v", err)
+ }
+
+ log.Debugf("Exec arguments: %+v", e)
+ log.Debugf("Exec capablities: %+v", e.Capabilities)
+
+ // Replace empty settings with defaults from container.
+ if e.WorkingDirectory == "" {
+ e.WorkingDirectory = c.Spec.Process.Cwd
+ }
+ if e.Envv == nil {
+ e.Envv, err = resolveEnvs(c.Spec.Process.Env, ex.env)
+ if err != nil {
+ Fatalf("getting environment variables: %v", err)
+ }
+ }
+
+ if e.Capabilities == nil {
+ e.Capabilities, err = specutils.Capabilities(conf.EnableRaw, c.Spec.Process.Capabilities)
+ if err != nil {
+ Fatalf("creating capabilities: %v", err)
+ }
+ log.Infof("Using exec capabilities from container: %+v", e.Capabilities)
+ }
+
+ // containerd expects an actual process to represent the container being
+ // executed. If detach was specified, start a child in non-detach mode and
+ // write the child's PID to the pid file. When the container exits, the
+ // child process will also exit and signal containerd.
+ if ex.detach {
+ return ex.execChildAndWait(waitStatus)
+ }
+ return ex.exec(c, e, waitStatus)
+}
+
+func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *syscall.WaitStatus) subcommands.ExitStatus {
+ // Start the new process and get its pid.
+ pid, err := c.Execute(e)
+ if err != nil {
+ return Errorf("executing processes for container: %v", err)
+ }
+
+ if e.StdioIsPty {
+ // Forward signals sent to this process to the foreground
+ // process in the sandbox.
+ stopForwarding := c.ForwardSignals(pid, true /* fgProcess */)
+ defer stopForwarding()
+ }
+
+ // Write the sandbox-internal pid if required.
+ if ex.internalPidFile != "" {
+ pidStr := []byte(strconv.Itoa(int(pid)))
+ if err := ioutil.WriteFile(ex.internalPidFile, pidStr, 0644); err != nil {
+ return Errorf("writing internal pid file %q: %v", ex.internalPidFile, err)
+ }
+ }
+
+ // Generate the pid file after the internal pid file is generated, so that
+ // users can safely assume that the internal pid file is ready after
+ // `runsc exec -d` returns.
+ if ex.pidFile != "" {
+ if err := ioutil.WriteFile(ex.pidFile, []byte(strconv.Itoa(os.Getpid())), 0644); err != nil {
+ return Errorf("writing pid file: %v", err)
+ }
+ }
+
+ // Wait for the process to exit.
+ ws, err := c.WaitPID(pid)
+ if err != nil {
+ return Errorf("waiting on pid %d: %v", pid, err)
+ }
+ *waitStatus = ws
+ return subcommands.ExitSuccess
+}
+
+func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStatus {
+ var args []string
+ for _, a := range os.Args[1:] {
+ if !strings.Contains(a, "detach") {
+ args = append(args, a)
+ }
+ }
+
+ // The command needs to write a pid file so that execChildAndWait can tell
+ // when it has started. If no pid-file was provided, we should use a
+ // filename in a temp directory.
+ pidFile := ex.pidFile
+ if pidFile == "" {
+ tmpDir, err := ioutil.TempDir("", "exec-pid-")
+ if err != nil {
+ Fatalf("creating TempDir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+ pidFile = filepath.Join(tmpDir, "pid")
+ args = append(args, "--pid-file="+pidFile)
+ }
+
+ cmd := exec.Command(specutils.ExePath, args...)
+ cmd.Args[0] = "runsc-exec"
+
+ // Exec stdio defaults to current process stdio.
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ // If the console control socket file is provided, then create a new
+ // pty master/slave pair and set the TTY on the sandbox process.
+ if ex.consoleSocket != "" {
+ // Create a new TTY pair and send the master on the provided socket.
+ tty, err := console.NewWithSocket(ex.consoleSocket)
+ if err != nil {
+ Fatalf("setting up console with socket %q: %v", ex.consoleSocket, err)
+ }
+ defer tty.Close()
+
+ // Set stdio to the new TTY slave.
+ cmd.Stdin = tty
+ cmd.Stdout = tty
+ cmd.Stderr = tty
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setsid: true,
+ Setctty: true,
+ // The Ctty FD must be the FD in the child process's FD
+ // table. Since we set cmd.Stdin/Stdout/Stderr to the
+ // tty FD, we can use any of 0, 1, or 2 here.
+ // See https://github.com/golang/go/issues/29458.
+ Ctty: 0,
+ }
+ }
+
+ if err := cmd.Start(); err != nil {
+ Fatalf("failure to start child exec process, err: %v", err)
+ }
+
+ log.Infof("Started child (PID: %d) to exec and wait: %s %s", cmd.Process.Pid, specutils.ExePath, args)
+
+ // Wait for PID file to ensure that child process has started. Otherwise,
+ // '--process' file is deleted as soon as this process returns and the child
+ // may fail to read it.
+ ready := func() (bool, error) {
+ pidb, err := ioutil.ReadFile(pidFile)
+ if err == nil {
+ // File appeared, check whether pid is fully written.
+ pid, err := strconv.Atoi(string(pidb))
+ if err != nil {
+ return false, nil
+ }
+ return pid == cmd.Process.Pid, nil
+ }
+ if pe, ok := err.(*os.PathError); !ok || pe.Err != syscall.ENOENT {
+ return false, err
+ }
+ // No file yet, continue to wait...
+ return false, nil
+ }
+ if err := specutils.WaitForReady(cmd.Process.Pid, 10*time.Second, ready); err != nil {
+ // Don't log fatal error here, otherwise it will override the error logged
+ // by the child process that has failed to start.
+ log.Warningf("Unexpected error waiting for PID file, err: %v", err)
+ return subcommands.ExitFailure
+ }
+
+ *waitStatus = 0
+ return subcommands.ExitSuccess
+}
+
+// parseArgs parses exec information from the command line or a JSON file
+// depending on whether the --process flag was used. Returns an ExecArgs and
+// the ID of the container to be used.
+func (ex *Exec) parseArgs(f *flag.FlagSet, enableRaw bool) (*control.ExecArgs, string, error) {
+ if ex.processPath == "" {
+ // Requires at least a container ID and command.
+ if f.NArg() < 2 {
+ f.Usage()
+ return nil, "", fmt.Errorf("both a container-id and command are required")
+ }
+ e, err := ex.argsFromCLI(f.Args()[1:], enableRaw)
+ return e, f.Arg(0), err
+ }
+ // Requires only the container ID.
+ if f.NArg() != 1 {
+ f.Usage()
+ return nil, "", fmt.Errorf("a container-id is required")
+ }
+ e, err := ex.argsFromProcessFile(enableRaw)
+ return e, f.Arg(0), err
+}
+
+func (ex *Exec) argsFromCLI(argv []string, enableRaw bool) (*control.ExecArgs, error) {
+ extraKGIDs := make([]auth.KGID, 0, len(ex.extraKGIDs))
+ for _, s := range ex.extraKGIDs {
+ kgid, err := strconv.Atoi(s)
+ if err != nil {
+ Fatalf("parsing GID: %s, %v", s, err)
+ }
+ extraKGIDs = append(extraKGIDs, auth.KGID(kgid))
+ }
+
+ var caps *auth.TaskCapabilities
+ if len(ex.caps) > 0 {
+ var err error
+ caps, err = capabilities(ex.caps, enableRaw)
+ if err != nil {
+ return nil, fmt.Errorf("capabilities error: %v", err)
+ }
+ }
+
+ return &control.ExecArgs{
+ Argv: argv,
+ WorkingDirectory: ex.cwd,
+ KUID: ex.user.kuid,
+ KGID: ex.user.kgid,
+ ExtraKGIDs: extraKGIDs,
+ Capabilities: caps,
+ StdioIsPty: ex.consoleSocket != "",
+ FilePayload: urpc.FilePayload{[]*os.File{os.Stdin, os.Stdout, os.Stderr}},
+ }, nil
+}
+
+func (ex *Exec) argsFromProcessFile(enableRaw bool) (*control.ExecArgs, error) {
+ f, err := os.Open(ex.processPath)
+ if err != nil {
+ return nil, fmt.Errorf("error opening process file: %s, %v", ex.processPath, err)
+ }
+ defer f.Close()
+ var p specs.Process
+ if err := json.NewDecoder(f).Decode(&p); err != nil {
+ return nil, fmt.Errorf("error parsing process file: %s, %v", ex.processPath, err)
+ }
+ return argsFromProcess(&p, enableRaw)
+}
+
+// argsFromProcess performs all the non-IO conversion from the Process struct
+// to ExecArgs.
+func argsFromProcess(p *specs.Process, enableRaw bool) (*control.ExecArgs, error) {
+ // Create capabilities.
+ var caps *auth.TaskCapabilities
+ if p.Capabilities != nil {
+ var err error
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ caps, err = specutils.Capabilities(enableRaw, p.Capabilities)
+ if err != nil {
+ return nil, fmt.Errorf("error creating capabilities: %v", err)
+ }
+ }
+
+ // Convert the spec's additional GIDs to KGIDs.
+ extraKGIDs := make([]auth.KGID, 0, len(p.User.AdditionalGids))
+ for _, GID := range p.User.AdditionalGids {
+ extraKGIDs = append(extraKGIDs, auth.KGID(GID))
+ }
+
+ return &control.ExecArgs{
+ Argv: p.Args,
+ Envv: p.Env,
+ WorkingDirectory: p.Cwd,
+ KUID: auth.KUID(p.User.UID),
+ KGID: auth.KGID(p.User.GID),
+ ExtraKGIDs: extraKGIDs,
+ Capabilities: caps,
+ StdioIsPty: p.Terminal,
+ FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}},
+ }, nil
+}
+
+// resolveEnvs transforms lists of environment variables into a single list of
+// environment variables. If a variable is defined multiple times, the last
+// value is used.
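+//
+// For example (illustrative), resolveEnvs([]string{"A=1", "B=2"}, []string{"B=3"})
+// yields {"A=1", "B=3"}, in no particular order.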
+func resolveEnvs(envs ...[]string) ([]string, error) {
+ // First create a map of variable names to values. This removes any
+ // duplicates.
+ envMap := make(map[string]string)
+ for _, env := range envs {
+ for _, str := range env {
+ parts := strings.SplitN(str, "=", 2)
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid variable: %s", str)
+ }
+ envMap[parts[0]] = parts[1]
+ }
+ }
+ // Reassemble envMap into a list of environment variables of the form
+ // NAME=VALUE.
+ env := make([]string, 0, len(envMap))
+ for k, v := range envMap {
+ env = append(env, fmt.Sprintf("%s=%s", k, v))
+ }
+ return env, nil
+}
+
+// capabilities takes a list of capabilities as strings and returns an
+// auth.TaskCapabilities struct with those capabilities in every capability set.
+// This mimics runc's behavior.
+func capabilities(cs []string, enableRaw bool) (*auth.TaskCapabilities, error) {
+ var specCaps specs.LinuxCapabilities
+ for _, cap := range cs {
+ specCaps.Ambient = append(specCaps.Ambient, cap)
+ specCaps.Bounding = append(specCaps.Bounding, cap)
+ specCaps.Effective = append(specCaps.Effective, cap)
+ specCaps.Inheritable = append(specCaps.Inheritable, cap)
+ specCaps.Permitted = append(specCaps.Permitted, cap)
+ }
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ return specutils.Capabilities(enableRaw, &specCaps)
+}
+
+// stringSlice allows a flag to be used multiple times, where each occurrence
+// adds a value to the flag. For example, a flag called "x" could be invoked
+// via "runsc exec -x foo -x bar", and the corresponding stringSlice would be
+// {"x", "y"}.
+type stringSlice []string
+
+// String implements flag.Value.String.
+func (ss *stringSlice) String() string {
+ return fmt.Sprintf("%v", *ss)
+}
+
+// Get implements flag.Value.Get.
+func (ss *stringSlice) Get() interface{} {
+ return ss
+}
+
+// Set implements flag.Value.Set.
+func (ss *stringSlice) Set(s string) error {
+ *ss = append(*ss, s)
+ return nil
+}
+
+// user allows -user to convey a UID and, optionally, a GID separated by a
+// colon.
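+// For example (illustrative), "-user 1000:100" sets kuid to 1000 and kgid to
+// 100; "-user 1000" leaves kgid as 0.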
+type user struct {
+ kuid auth.KUID
+ kgid auth.KGID
+}
+
+func (u *user) String() string {
+ return fmt.Sprintf("%+v", *u)
+}
+
+func (u *user) Get() interface{} {
+ return u
+}
+
+func (u *user) Set(s string) error {
+ parts := strings.SplitN(s, ":", 2)
+ kuid, err := strconv.Atoi(parts[0])
+ if err != nil {
+ return fmt.Errorf("couldn't parse UID: %s", parts[0])
+ }
+ u.kuid = auth.KUID(kuid)
+ if len(parts) > 1 {
+ kgid, err := strconv.Atoi(parts[1])
+ if err != nil {
+ return fmt.Errorf("couldn't parse GID: %s", parts[1])
+ }
+ u.kgid = auth.KGID(kgid)
+ }
+ return nil
+}
diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go
new file mode 100644
index 000000000..a1e980d08
--- /dev/null
+++ b/runsc/cmd/exec_test.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "os"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+func TestUser(t *testing.T) {
+ testCases := []struct {
+ input string
+ want user
+ wantErr bool
+ }{
+ {input: "0", want: user{kuid: 0, kgid: 0}},
+ {input: "7", want: user{kuid: 7, kgid: 0}},
+ {input: "49:343", want: user{kuid: 49, kgid: 343}},
+ {input: "0:2401", want: user{kuid: 0, kgid: 2401}},
+ {input: "", wantErr: true},
+ {input: "foo", wantErr: true},
+ {input: ":123", wantErr: true},
+ {input: "1:2:3", wantErr: true},
+ }
+
+ for _, tc := range testCases {
+ var u user
+ if err := u.Set(tc.input); err != nil && tc.wantErr {
+ // We got an error and wanted one.
+ continue
+ } else if err == nil && tc.wantErr {
+ t.Errorf("user.Set(%s): got no error, but wanted one", tc.input)
+ } else if err != nil && !tc.wantErr {
+ t.Errorf("user.Set(%s): got error %v, but wanted none", tc.input, err)
+ } else if u != tc.want {
+ t.Errorf("user.Set(%s): got %+v, but wanted %+v", tc.input, u, tc.want)
+ }
+ }
+}
+
+func TestCLIArgs(t *testing.T) {
+ testCases := []struct {
+ ex Exec
+ argv []string
+ expected control.ExecArgs
+ }{
+ {
+ ex: Exec{
+ cwd: "/foo/bar",
+ user: user{kuid: 0, kgid: 0},
+ extraKGIDs: []string{"1", "2", "3"},
+ caps: []string{"CAP_DAC_OVERRIDE"},
+ processPath: "",
+ },
+ argv: []string{"ls", "/"},
+ expected: control.ExecArgs{
+ Argv: []string{"ls", "/"},
+ WorkingDirectory: "/foo/bar",
+ FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}},
+ KUID: 0,
+ KGID: 0,
+ ExtraKGIDs: []auth.KGID{1, 2, 3},
+ Capabilities: &auth.TaskCapabilities{
+ BoundingCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ InheritableCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ PermittedCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ },
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ e, err := tc.ex.argsFromCLI(tc.argv, true)
+ if err != nil {
+ t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err)
+ } else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
+ t.Errorf("argsFromCLI(%+v): got %+v, but expected %+v", tc.ex, *e, tc.expected)
+ }
+ }
+}
+
+func TestJSONArgs(t *testing.T) {
+ testCases := []struct {
+ // ex is provided to make sure it is overridden by p.
+ ex Exec
+ p specs.Process
+ expected control.ExecArgs
+ }{
+ {
+ ex: Exec{
+ cwd: "/baz/quux",
+ user: user{kuid: 1, kgid: 1},
+ extraKGIDs: []string{"4", "5", "6"},
+ caps: []string{"CAP_SETGID"},
+ processPath: "/bin/foo",
+ },
+ p: specs.Process{
+ User: specs.User{UID: 0, GID: 0, AdditionalGids: []uint32{1, 2, 3}},
+ Args: []string{"ls", "/"},
+ Cwd: "/foo/bar",
+ Capabilities: &specs.LinuxCapabilities{
+ Bounding: []string{"CAP_DAC_OVERRIDE"},
+ Effective: []string{"CAP_DAC_OVERRIDE"},
+ Inheritable: []string{"CAP_DAC_OVERRIDE"},
+ Permitted: []string{"CAP_DAC_OVERRIDE"},
+ },
+ },
+ expected: control.ExecArgs{
+ Argv: []string{"ls", "/"},
+ WorkingDirectory: "/foo/bar",
+ FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}},
+ KUID: 0,
+ KGID: 0,
+ ExtraKGIDs: []auth.KGID{1, 2, 3},
+ Capabilities: &auth.TaskCapabilities{
+ BoundingCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ InheritableCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ PermittedCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ },
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ e, err := argsFromProcess(&tc.p, true)
+ if err != nil {
+ t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err)
+ } else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
+ t.Errorf("argsFromProcess(%+v): got %+v, but expected %+v", tc.p, *e, tc.expected)
+ }
+ }
+}
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
new file mode 100644
index 000000000..3966e2d21
--- /dev/null
+++ b/runsc/cmd/gofer.go
@@ -0,0 +1,484 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "syscall"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/fsgofer"
+ "gvisor.dev/gvisor/runsc/fsgofer/filter"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var caps = []string{
+ "CAP_CHOWN",
+ "CAP_DAC_OVERRIDE",
+ "CAP_DAC_READ_SEARCH",
+ "CAP_FOWNER",
+ "CAP_FSETID",
+ "CAP_SYS_CHROOT",
+}
+
+// goferCaps is the minimal set of capabilities needed by the Gofer to operate
+// on files.
+var goferCaps = &specs.LinuxCapabilities{
+ Bounding: caps,
+ Effective: caps,
+ Permitted: caps,
+}
+
+// Gofer implements subcommands.Command for the "gofer" command, which starts a
+// filesystem gofer. This command should not be called directly.
+type Gofer struct {
+ bundleDir string
+ ioFDs intFlags
+ applyCaps bool
+ setUpRoot bool
+
+ panicOnWrite bool
+ specFD int
+ mountsFD int
+}
+
+// Name implements subcommands.Command.
+func (*Gofer) Name() string {
+ return "gofer"
+}
+
+// Synopsis implements subcommands.Command.
+func (*Gofer) Synopsis() string {
+ return "launch a gofer process that serves files over 9P protocol (internal use only)"
+}
+
+// Usage implements subcommands.Command.
+func (*Gofer) Usage() string {
+ return `gofer [flags]`
+}
+
+// SetFlags implements subcommands.Command.
+func (g *Gofer) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&g.bundleDir, "bundle", "", "path to the root of the bundle directory, defaults to the current directory")
+ f.Var(&g.ioFDs, "io-fds", "list of FDs to connect 9P servers to. They must follow this order: root first, then mounts as defined in the spec")
+ f.BoolVar(&g.applyCaps, "apply-caps", true, "if true, apply capabilities to restrict what the Gofer process can do")
+ f.BoolVar(&g.panicOnWrite, "panic-on-write", false, "if true, panics on attempts to write to RO mounts. RW mounts are unaffected")
+ f.BoolVar(&g.setUpRoot, "setup-root", true, "if true, set up an empty root for the process")
+ f.IntVar(&g.specFD, "spec-fd", -1, "required fd with the container spec")
+ f.IntVar(&g.mountsFD, "mounts-fd", -1, "file descriptor to write the list of mounts to after they have been resolved (direct paths, no symlinks)")
+}
+
+// Execute implements subcommands.Command.
+func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if g.bundleDir == "" || len(g.ioFDs) < 1 || g.specFD < 0 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ specFile := os.NewFile(uintptr(g.specFD), "spec file")
+ defer specFile.Close()
+ spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile)
+ if err != nil {
+ Fatalf("reading spec: %v", err)
+ }
+
+ conf := args[0].(*boot.Config)
+
+ if g.setUpRoot {
+ if err := setupRootFS(spec, conf); err != nil {
+ Fatalf("Error setting up root FS: %v", err)
+ }
+ }
+ if g.applyCaps {
+ // Disable caps when calling myself again.
+ // Note: minimal argument handling for the default case to keep it simple.
+ args := os.Args
+ args = append(args, "--apply-caps=false", "--setup-root=false")
+ if err := setCapsAndCallSelf(args, goferCaps); err != nil {
+ Fatalf("Unable to apply caps: %v", err)
+ }
+ panic("unreachable")
+ }
+
+ // Find what path is going to be served by this gofer.
+ root := spec.Root.Path
+ if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ root = "/root"
+ }
+
+ // Resolve mount points paths, then replace mounts from our spec and send the
+ // mount list over to the sandbox, so they are both in sync.
+ //
+ // Note that all mount points have been mounted in the proper location in
+ // setupRootFS().
+ cleanMounts, err := resolveMounts(conf, spec.Mounts, root)
+ if err != nil {
+ Fatalf("Failure to resolve mounts: %v", err)
+ }
+ spec.Mounts = cleanMounts
+ go func() {
+ if err := g.writeMounts(cleanMounts); err != nil {
+ panic(fmt.Sprintf("Failed to write mounts: %v", err))
+ }
+ }()
+
+ specutils.LogSpec(spec)
+
+ // fsgofer should run with a umask of 0, because we want to preserve file
+ // modes exactly as sent by the sandbox, which will have applied its own umask.
+ syscall.Umask(0)
+
+ if err := fsgofer.OpenProcSelfFD(); err != nil {
+ Fatalf("failed to open /proc/self/fd: %v", err)
+ }
+
+ if err := syscall.Chroot(root); err != nil {
+ Fatalf("failed to chroot to %q: %v", root, err)
+ }
+ if err := syscall.Chdir("/"); err != nil {
+ Fatalf("changing working dir: %v", err)
+ }
+ log.Infof("Process chroot'd to %q", root)
+
+ // Start with root mount, then add any other additional mount as needed.
+ ats := make([]p9.Attacher, 0, len(spec.Mounts)+1)
+ ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{
+ ROMount: spec.Root.Readonly || conf.Overlay,
+ PanicOnWrite: g.panicOnWrite,
+ })
+ if err != nil {
+ Fatalf("creating attach point: %v", err)
+ }
+ ats = append(ats, ap)
+ log.Infof("Serving %q mapped to %q on FD %d (ro: %t)", "/", root, g.ioFDs[0], spec.Root.Readonly)
+
+ mountIdx := 1 // first one is the root
+ for _, m := range spec.Mounts {
+ if specutils.Is9PMount(m) {
+ cfg := fsgofer.Config{
+ ROMount: isReadonlyMount(m.Options) || conf.Overlay,
+ PanicOnWrite: g.panicOnWrite,
+ HostUDS: conf.FSGoferHostUDS,
+ }
+ ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
+ if err != nil {
+ Fatalf("creating attach point: %v", err)
+ }
+ ats = append(ats, ap)
+
+ if mountIdx >= len(g.ioFDs) {
+ Fatalf("no FD found for mount. Did you forget --io-fd? mount: %d, %v", len(g.ioFDs), m)
+ }
+ log.Infof("Serving %q mapped on FD %d (ro: %t)", m.Destination, g.ioFDs[mountIdx], cfg.ROMount)
+ mountIdx++
+ }
+ }
+ if mountIdx != len(g.ioFDs) {
+ Fatalf("too many FDs passed for mounts. mounts: %d, FDs: %d", mountIdx, len(g.ioFDs))
+ }
+
+ if conf.FSGoferHostUDS {
+ filter.InstallUDSFilters()
+ }
+
+ if err := filter.Install(); err != nil {
+ Fatalf("installing seccomp filters: %v", err)
+ }
+
+ runServers(ats, g.ioFDs)
+ return subcommands.ExitSuccess
+}
+
+func runServers(ats []p9.Attacher, ioFDs []int) {
+ // Run the loops and wait for all to exit.
+ var wg sync.WaitGroup
+ for i, ioFD := range ioFDs {
+ wg.Add(1)
+ go func(ioFD int, at p9.Attacher) {
+ socket, err := unet.NewSocket(ioFD)
+ if err != nil {
+ Fatalf("creating server on FD %d: %v", ioFD, err)
+ }
+ s := p9.NewServer(at)
+ if err := s.Handle(socket); err != nil {
+ Fatalf("P9 server returned error. Gofer is shutting down. FD: %d, err: %v", ioFD, err)
+ }
+ wg.Done()
+ }(ioFD, ats[i])
+ }
+ wg.Wait()
+ log.Infof("All 9P servers exited.")
+}
+
+func (g *Gofer) writeMounts(mounts []specs.Mount) error {
+ bytes, err := json.Marshal(mounts)
+ if err != nil {
+ return err
+ }
+
+ f := os.NewFile(uintptr(g.mountsFD), "mounts file")
+ defer f.Close()
+
+ for written := 0; written < len(bytes); {
+ w, err := f.Write(bytes[written:])
+ if err != nil {
+ return err
+ }
+ written += w
+ }
+ return nil
+}
+
+func isReadonlyMount(opts []string) bool {
+ for _, o := range opts {
+ if o == "ro" {
+ return true
+ }
+ }
+ return false
+}
+
+func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
+ // Convert all shared mounts into slaves to be sure that nothing will be
+ // propagated outside of our namespace.
+ if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil {
+ Fatalf("error converting mounts: %v", err)
+ }
+
+ root := spec.Root.Path
+ if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ // runsc can't be re-executed without /proc, so we create a tmpfs mount,
+ // mount ./proc and ./root there, then move this mount to the root and after
+ // setCapsAndCallSelf, runsc will chroot into /root.
+ //
+ // We need a directory to construct a new root and we know that
+ // runsc can't start without /proc, so we can use it for this.
+ flags := uintptr(syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC)
+ if err := syscall.Mount("runsc-root", "/proc", "tmpfs", flags, ""); err != nil {
+ Fatalf("error mounting tmpfs: %v", err)
+ }
+
+ // Prepare tree structure for pivot_root(2).
+ os.Mkdir("/proc/proc", 0755)
+ os.Mkdir("/proc/root", 0755)
+ if err := syscall.Mount("runsc-proc", "/proc/proc", "proc", flags|syscall.MS_RDONLY, ""); err != nil {
+ Fatalf("error mounting proc: %v", err)
+ }
+ root = "/proc/root"
+ }
+
+ // Mount root path followed by submounts.
+ if err := syscall.Mount(spec.Root.Path, root, "bind", syscall.MS_BIND|syscall.MS_REC, ""); err != nil {
+ return fmt.Errorf("mounting root on root (%q) err: %v", root, err)
+ }
+
+ flags := uint32(syscall.MS_SLAVE | syscall.MS_REC)
+ if spec.Linux != nil && spec.Linux.RootfsPropagation != "" {
+ flags = specutils.PropOptionsToFlags([]string{spec.Linux.RootfsPropagation})
+ }
+ if err := syscall.Mount("", root, "", uintptr(flags), ""); err != nil {
+ return fmt.Errorf("mounting root (%q) with flags: %#x, err: %v", root, flags, err)
+ }
+
+ // Replace the current spec, with the clean spec with symlinks resolved.
+ if err := setupMounts(conf, spec.Mounts, root); err != nil {
+ Fatalf("error setting up FS: %v", err)
+ }
+
+ // Create working directory if needed.
+ if spec.Process.Cwd != "" {
+ dst, err := resolveSymlinks(root, spec.Process.Cwd)
+ if err != nil {
+ return fmt.Errorf("resolving symlinks to %q: %v", spec.Process.Cwd, err)
+ }
+ if err := os.MkdirAll(dst, 0755); err != nil {
+ return fmt.Errorf("creating working directory %q: %v", spec.Process.Cwd, err)
+ }
+ }
+
+ // Check if root needs to be remounted as readonly.
+ if spec.Root.Readonly || conf.Overlay {
+ // If root is a mount point but not read-only, we can change mount options
+ // to make it read-only for extra safety.
+ log.Infof("Remounting root as readonly: %q", root)
+ flags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT | syscall.MS_RDONLY | syscall.MS_REC)
+ if err := syscall.Mount(root, root, "bind", flags, ""); err != nil {
+ return fmt.Errorf("remounting root as read-only with source: %q, target: %q, flags: %#x, err: %v", root, root, flags, err)
+ }
+ }
+
+ if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ if err := pivotRoot("/proc"); err != nil {
+ Fatalf("failed to change the root file system: %v", err)
+ }
+ if err := os.Chdir("/"); err != nil {
+ Fatalf("failed to change working directory")
+ }
+ }
+ return nil
+}
+
+// setupMounts bind mounts all mounts specified in the spec at their correct
+// locations inside root. It will resolve relative paths and symlinks. It also
+// creates directories as needed.
+func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error {
+ for _, m := range mounts {
+ if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
+ continue
+ }
+
+ dst, err := resolveSymlinks(root, m.Destination)
+ if err != nil {
+ return fmt.Errorf("resolving symlinks to %q: %v", m.Destination, err)
+ }
+
+ flags := specutils.OptionsToFlags(m.Options) | syscall.MS_BIND
+ if conf.Overlay {
+ // Force mount read-only if writes are not going to be sent to it.
+ flags |= syscall.MS_RDONLY
+ }
+
+ log.Infof("Mounting src: %q, dst: %q, flags: %#x", m.Source, dst, flags)
+ if err := specutils.Mount(m.Source, dst, m.Type, flags); err != nil {
+ return fmt.Errorf("mounting %v: %v", m, err)
+ }
+
+ // Set propagation options that cannot be set together with other options.
+ flags = specutils.PropOptionsToFlags(m.Options)
+ if flags != 0 {
+ if err := syscall.Mount("", dst, "", uintptr(flags), ""); err != nil {
+ return fmt.Errorf("mount dst: %q, flags: %#x, err: %v", dst, flags, err)
+ }
+ }
+ }
+ return nil
+}
+
+// resolveMounts resolves relative paths and symlinks to mount points.
+//
+// Note: mount points must already be in place for resolution to work.
+// Otherwise, it may follow symlinks to locations that would be overwritten
+// with another mount point and return the wrong location. In short, make sure
+// setupMounts() has been called before.
+func resolveMounts(conf *boot.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) {
+ cleanMounts := make([]specs.Mount, 0, len(mounts))
+ for _, m := range mounts {
+ if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
+ cleanMounts = append(cleanMounts, m)
+ continue
+ }
+ dst, err := resolveSymlinks(root, m.Destination)
+ if err != nil {
+ return nil, fmt.Errorf("resolving symlinks to %q: %v", m.Destination, err)
+ }
+ relDst, err := filepath.Rel(root, dst)
+ if err != nil {
+ panic(fmt.Sprintf("%q could not be made relative to %q: %v", dst, root, err))
+ }
+
+ opts, err := adjustMountOptions(conf, filepath.Join(root, relDst), m.Options)
+ if err != nil {
+ return nil, err
+ }
+
+ cpy := m
+ cpy.Destination = filepath.Join("/", relDst)
+ cpy.Options = opts
+ cleanMounts = append(cleanMounts, cpy)
+ }
+ return cleanMounts, nil
+}
+
+// resolveSymlinks walks 'rel' having 'root' as the root directory. If there are
+// symlinks, they are evaluated relative to 'root' to ensure the end result is
+// the same as if the process was running inside the container.
+func resolveSymlinks(root, rel string) (string, error) {
+ return resolveSymlinksImpl(root, root, rel, 255)
+}
+
+func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, error) {
+ if followCount == 0 {
+ return "", fmt.Errorf("too many symlinks to follow, path: %q", filepath.Join(base, rel))
+ }
+
+ rel = filepath.Clean(rel)
+ for _, name := range strings.Split(rel, string(filepath.Separator)) {
+ if name == "" {
+ continue
+ }
+ // Note that Join() resolves things like ".." and returns a clean path.
+ path := filepath.Join(base, name)
+ if !strings.HasPrefix(path, root) {
+ // One cannot '..' their way out of root.
+ base = root
+ continue
+ }
+ fi, err := os.Lstat(path)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ return "", err
+ }
+ // Not found means there is no symlink to check. Just keep walking dirs.
+ base = path
+ continue
+ }
+ if fi.Mode()&os.ModeSymlink != 0 {
+ link, err := os.Readlink(path)
+ if err != nil {
+ return "", err
+ }
+ if filepath.IsAbs(link) {
+ base = root
+ }
+ base, err = resolveSymlinksImpl(root, base, link, followCount-1)
+ if err != nil {
+ return "", err
+ }
+ continue
+ }
+ base = path
+ }
+ return base, nil
+}
+
+// adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs.
+func adjustMountOptions(conf *boot.Config, path string, opts []string) ([]string, error) {
+ rv := make([]string, len(opts))
+ copy(rv, opts)
+
+ if conf.OverlayfsStaleRead {
+ statfs := syscall.Statfs_t{}
+ if err := syscall.Statfs(path, &statfs); err != nil {
+ return nil, err
+ }
+ if statfs.Type == unix.OVERLAYFS_SUPER_MAGIC {
+ rv = append(rv, "overlayfs_stale_read")
+ }
+ }
+ return rv, nil
+}
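
The symlink walker above is what confines every spec-supplied path to the
sandbox root: absolute links are re-anchored at root, and ".." can never
climb above it. A minimal sketch of a caller, using a hypothetical helper
that would sit next to the code above in package cmd:

    package cmd

    import "fmt"

    // confine resolves a container path against the sandbox root the way
    // setupRootFS and setupMounts do. The helper itself is illustrative.
    func confine(root, rel string) (string, error) {
    	dst, err := resolveSymlinks(root, rel)
    	if err != nil {
    		return "", fmt.Errorf("resolving symlinks to %q: %v", rel, err)
    	}
    	// dst always has root as a prefix, even if a component is a
    	// symlink to an absolute path inside the container image.
    	return dst, nil
    }
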
diff --git a/runsc/cmd/gofer_test.go b/runsc/cmd/gofer_test.go
new file mode 100644
index 000000000..cbea7f127
--- /dev/null
+++ b/runsc/cmd/gofer_test.go
@@ -0,0 +1,164 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+)
+
+func tmpDir() string {
+ dir := os.Getenv("TEST_TMPDIR")
+ if dir == "" {
+ dir = "/tmp"
+ }
+ return dir
+}
+
+type dir struct {
+ rel string
+ link string
+}
+
+func construct(root string, dirs []dir) error {
+ for _, d := range dirs {
+ p := path.Join(root, d.rel)
+ if d.link == "" {
+ if err := os.MkdirAll(p, 0755); err != nil {
+ return fmt.Errorf("error creating dir: %v", err)
+ }
+ } else {
+ if err := os.MkdirAll(path.Dir(p), 0755); err != nil {
+ return fmt.Errorf("error creating dir: %v", err)
+ }
+ if err := os.Symlink(d.link, p); err != nil {
+ return fmt.Errorf("error creating symlink: %v", err)
+ }
+ }
+ }
+ return nil
+}
+
+func TestResolveSymlinks(t *testing.T) {
+ root, err := ioutil.TempDir(tmpDir(), "root")
+ if err != nil {
+ t.Fatal("ioutil.TempDir() failed:", err)
+ }
+ dirs := []dir{
+ {"dir1/dir11/dir111/dir1111", ""}, // Just a boring dir
+ {"dir1/lnk12", "dir11"}, // Link to sibling
+ {"dir1/lnk13", "./dir11"}, // Link to sibling through self
+ {"dir1/lnk14", "../dir1/dir11"}, // Link to sibling through parent
+ {"dir1/dir15/lnk151", ".."}, // Link to parent
+ {"dir1/lnk16", "dir11/dir111"}, // Link to child
+ {"dir1/lnk17", "."}, // Link to self
+ {"dir1/lnk18", "lnk13"}, // Link to link
+ {"lnk2", "dir1/lnk13"}, // Link to link to link
+ {"dir3/dir21/lnk211", "../.."}, // Link to root relative
+ {"dir3/lnk22", "/"}, // Link to root absolute
+ {"dir3/lnk23", "/dir1"}, // Link to dir absolute
+ {"dir3/lnk24", "/dir1/lnk12"}, // Link to link absolute
+ {"lnk5", "../../.."}, // Link outside root
+ }
+ if err := construct(root, dirs); err != nil {
+ t.Fatal("construct failed:", err)
+ }
+
+ tests := []struct {
+ name string
+ rel string
+ want string
+ compareHost bool
+ }{
+ {name: "root", rel: "/", want: "/", compareHost: true},
+ {name: "basic dir", rel: "/dir1/dir11/dir111", want: "/dir1/dir11/dir111", compareHost: true},
+ {name: "dot 1", rel: "/dir1/dir11/./dir111", want: "/dir1/dir11/dir111", compareHost: true},
+ {name: "dot 2", rel: "/dir1/././dir11/./././././dir111/.", want: "/dir1/dir11/dir111", compareHost: true},
+ {name: "dotdot 1", rel: "/dir1/dir11/../dir15", want: "/dir1/dir15", compareHost: true},
+ {name: "dotdot 2", rel: "/dir1/dir11/dir1111/../..", want: "/dir1", compareHost: true},
+
+ {name: "link sibling", rel: "/dir1/lnk12", want: "/dir1/dir11", compareHost: true},
+ {name: "link sibling + dir", rel: "/dir1/lnk12/dir111", want: "/dir1/dir11/dir111", compareHost: true},
+ {name: "link sibling through self", rel: "/dir1/lnk13", want: "/dir1/dir11", compareHost: true},
+ {name: "link sibling through parent", rel: "/dir1/lnk14", want: "/dir1/dir11", compareHost: true},
+
+ {name: "link parent", rel: "/dir1/dir15/lnk151", want: "/dir1", compareHost: true},
+ {name: "link parent + dir", rel: "/dir1/dir15/lnk151/dir11", want: "/dir1/dir11", compareHost: true},
+ {name: "link child", rel: "/dir1/lnk16", want: "/dir1/dir11/dir111", compareHost: true},
+ {name: "link child + dir", rel: "/dir1/lnk16/dir1111", want: "/dir1/dir11/dir111/dir1111", compareHost: true},
+ {name: "link self", rel: "/dir1/lnk17", want: "/dir1", compareHost: true},
+ {name: "link self + dir", rel: "/dir1/lnk17/dir11", want: "/dir1/dir11", compareHost: true},
+
+ {name: "link^2", rel: "/dir1/lnk18", want: "/dir1/dir11", compareHost: true},
+ {name: "link^2 + dir", rel: "/dir1/lnk18/dir111", want: "/dir1/dir11/dir111", compareHost: true},
+ {name: "link^3", rel: "/lnk2", want: "/dir1/dir11", compareHost: true},
+ {name: "link^3 + dir", rel: "/lnk2/dir111", want: "/dir1/dir11/dir111", compareHost: true},
+
+ {name: "link abs", rel: "/dir3/lnk23", want: "/dir1"},
+ {name: "link abs + dir", rel: "/dir3/lnk23/dir11", want: "/dir1/dir11"},
+ {name: "link^2 abs", rel: "/dir3/lnk24", want: "/dir1/dir11"},
+ {name: "link^2 abs + dir", rel: "/dir3/lnk24/dir111", want: "/dir1/dir11/dir111"},
+
+ {name: "root link rel", rel: "/dir3/dir21/lnk211", want: "/", compareHost: true},
+ {name: "root link abs", rel: "/dir3/lnk22", want: "/"},
+ {name: "root contain link", rel: "/lnk5/dir1", want: "/dir1"},
+ {name: "root contain dotdot", rel: "/dir1/dir11/../../../../../../../..", want: "/"},
+
+ {name: "crazy", rel: "/dir3/dir21/lnk211/dir3/lnk22/dir1/dir11/../../lnk5/dir3/../dir3/lnk24/dir111/dir1111/..", want: "/dir1/dir11/dir111"},
+ }
+ for _, tst := range tests {
+ t.Run(tst.name, func(t *testing.T) {
+ got, err := resolveSymlinks(root, tst.rel)
+ if err != nil {
+ t.Errorf("resolveSymlinks(root, %q) failed: %v", tst.rel, err)
+ }
+ want := path.Join(root, tst.want)
+ if got != want {
+ t.Errorf("resolveSymlinks(root, %q) got: %q, want: %q", tst.rel, got, want)
+ }
+ if tst.compareHost {
+ // Check that host got to the same end result.
+ host, err := filepath.EvalSymlinks(path.Join(root, tst.rel))
+ if err != nil {
+ t.Errorf("path.EvalSymlinks(root, %q) failed: %v", tst.rel, err)
+ }
+ if host != got {
+ t.Errorf("resolveSymlinks(root, %q) got: %q, want: %q", tst.rel, host, got)
+ }
+ }
+ })
+ }
+}
+
+func TestResolveSymlinksLoop(t *testing.T) {
+ root, err := ioutil.TempDir(tmpDir(), "root")
+ if err != nil {
+ t.Fatal("ioutil.TempDir() failed:", err)
+ }
+ dirs := []dir{
+ {"loop1", "loop2"},
+ {"loop2", "loop1"},
+ }
+ if err := construct(root, dirs); err != nil {
+ t.Fatal("construct failed:", err)
+ }
+ if _, err := resolveSymlinks(root, "loop1"); err == nil {
+ t.Errorf("resolveSymlinks() should have failed")
+ }
+}
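
The followCount bound mirrors the kernel's own ELOOP behavior: the same
loop fixture fails a plain stat on the host. A host-side check, assuming
Linux (sketch only, with a hypothetical helper):

    package cmd

    import (
    	"errors"
    	"os"
    	"path/filepath"
    	"syscall"
    )

    // hostLoopCheck reports whether the kernel also rejects the
    // loop1 -> loop2 -> loop1 cycle, which it does with ELOOP.
    func hostLoopCheck(root string) bool {
    	_, err := os.Stat(filepath.Join(root, "loop1"))
    	return errors.Is(err, syscall.ELOOP)
    }
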
diff --git a/runsc/cmd/help.go b/runsc/cmd/help.go
new file mode 100644
index 000000000..cd85dabbb
--- /dev/null
+++ b/runsc/cmd/help.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// NewHelp returns a help command for the given commander.
+func NewHelp(cdr *subcommands.Commander) *Help {
+ return &Help{
+ cdr: cdr,
+ }
+}
+
+// Help implements subcommands.Command for the "help" command. It prints
+// help for the commands registered to a Commander and also allows
+// registering additional help commands that print other documentation.
+type Help struct {
+ cdr *subcommands.Commander
+ commands []subcommands.Command
+ help bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Help) Name() string {
+ return "help"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Help) Synopsis() string {
+ return "Print help documentation."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Help) Usage() string {
+ return `help [<subcommand>]:
+ With an argument, prints detailed information on the use of
+ the specified topic or subcommand. With no argument, prints a list of
+ all commands and a brief description of each.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (h *Help) SetFlags(f *flag.FlagSet) {}
+
+// Execute implements subcommands.Command.Execute.
+func (h *Help) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ switch f.NArg() {
+ case 0:
+ fmt.Fprintf(h.cdr.Output, "Usage: %s <flags> <subcommand> <subcommand args>\n\n", h.cdr.Name())
+ fmt.Fprintf(h.cdr.Output, `runsc is the gVisor container runtime.
+
+Functionality is provided by subcommands. For help with a specific subcommand,
+use "%s %s <subcommand>".
+
+`, h.cdr.Name(), h.Name())
+ h.cdr.VisitGroups(func(g *subcommands.CommandGroup) {
+ h.cdr.ExplainGroup(h.cdr.Output, g)
+ })
+
+ fmt.Fprintf(h.cdr.Output, "Additional help topics (Use \"%s %s <topic>\" to see help on the topic):\n", h.cdr.Name(), h.Name())
+ for _, cmd := range h.commands {
+ fmt.Fprintf(h.cdr.Output, "\t%-15s %s\n", cmd.Name(), cmd.Synopsis())
+ }
+ fmt.Fprintf(h.cdr.Output, "\nUse \"%s flags\" for a list of top-level flags\n", h.cdr.Name())
+ return subcommands.ExitSuccess
+ default:
+ // Look for commands registered to the commander and print help explanation if found.
+ found := false
+ h.cdr.VisitCommands(func(g *subcommands.CommandGroup, cmd subcommands.Command) {
+ if f.Arg(0) == cmd.Name() {
+ h.cdr.ExplainCommand(h.cdr.Output, cmd)
+ found = true
+ }
+ })
+ if found {
+ return subcommands.ExitSuccess
+ }
+
+ // Next check commands registered to the help command.
+ for _, cmd := range h.commands {
+ if f.Arg(0) == cmd.Name() {
+ fs := flag.NewFlagSet(f.Arg(0), flag.ContinueOnError)
+ fs.Usage = func() { h.cdr.ExplainCommand(h.cdr.Error, cmd) }
+ cmd.SetFlags(fs)
+ if fs.Parse(f.Args()[1:]) != nil {
+ return subcommands.ExitUsageError
+ }
+ return cmd.Execute(ctx, f, args...)
+ }
+ }
+
+ fmt.Fprintf(h.cdr.Error, "Subcommand %s not understood\n", f.Arg(0))
+ }
+
+ f.Usage()
+ return subcommands.ExitUsageError
+}
+
+// Register registers a new help command.
+func (h *Help) Register(cmd subcommands.Command) {
+ h.commands = append(h.commands, cmd)
+}
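
The commands slice is what distinguishes this from the stock subcommands
help: documentation-only topics can be attached to the help command
itself. The wiring in main would look roughly as follows (a sketch; the
actual registration code is not part of this diff, and the choice of
Syscalls as a topic is illustrative):

    package main

    import (
    	"github.com/google/subcommands"
    	"gvisor.dev/gvisor/runsc/cmd"
    )

    func registerHelp(cdr *subcommands.Commander) {
    	help := cmd.NewHelp(cdr)
    	// A documentation-only topic registered on the help command
    	// rather than on the commander.
    	help.Register(new(cmd.Syscalls))
    	cdr.Register(help, "")
    }
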
diff --git a/runsc/cmd/install.go b/runsc/cmd/install.go
new file mode 100644
index 000000000..2e223e3be
--- /dev/null
+++ b/runsc/cmd/install.go
@@ -0,0 +1,210 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "path"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Install implements subcommands.Command.
+type Install struct {
+ ConfigFile string
+ Runtime string
+ Experimental bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Install) Name() string {
+ return "install"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Install) Synopsis() string {
+ return "adds a runtime to docker daemon configuration"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Install) Usage() string {
+ return `install [flags] [-- args...] -- if provided, args are passed to the runtime
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (i *Install) SetFlags(fs *flag.FlagSet) {
+ fs.StringVar(&i.ConfigFile, "config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
+ fs.StringVar(&i.Runtime, "runtime", "runsc", "runtime name")
+ fs.BoolVar(&i.Experimental, "experimental", false, "enable experimental features")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (i *Install) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Grab the runtime arguments.
+ runtimeArgs := f.Args()
+
+ // Extract the executable.
+ path, err := os.Executable()
+ if err != nil {
+ log.Fatalf("Error reading current exectuable: %v", err)
+ }
+
+ // Load the configuration file.
+ c, err := readConfig(i.ConfigFile)
+ if err != nil {
+ log.Fatalf("Error reading config file %q: %v", i.ConfigFile, err)
+ }
+
+ // Add the given runtime.
+ var rts map[string]interface{}
+ if i, ok := c["runtimes"]; ok {
+ rts = i.(map[string]interface{})
+ } else {
+ rts = make(map[string]interface{})
+ c["runtimes"] = rts
+ }
+ rts[i.Runtime] = struct {
+ Path string `json:"path,omitempty"`
+ RuntimeArgs []string `json:"runtimeArgs,omitempty"`
+ }{
+ Path: path,
+ RuntimeArgs: runtimeArgs,
+ }
+
+ // Set experimental if required.
+ if i.Experimental {
+ c["experimental"] = true
+ }
+
+ // Write out the runtime.
+ if err := writeConfig(c, i.ConfigFile); err != nil {
+ log.Fatalf("Error writing config file %q: %v", i.ConfigFile, err)
+ }
+
+ // Success.
+ log.Printf("Added runtime %q with arguments %v to %q.", i.Runtime, runtimeArgs, i.ConfigFile)
+ return subcommands.ExitSuccess
+}
+
+// Uninstall implements subcommands.Command.
+type Uninstall struct {
+ ConfigFile string
+ Runtime string
+}
+
+// Name implements subcommands.Command.Name.
+func (*Uninstall) Name() string {
+ return "uninstall"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Uninstall) Synopsis() string {
+ return "removes a runtime from docker daemon configuration"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Uninstall) Usage() string {
+ return `uninstall [flags]
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (u *Uninstall) SetFlags(fs *flag.FlagSet) {
+ fs.StringVar(&u.ConfigFile, "config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
+ fs.StringVar(&u.Runtime, "runtime", "runsc", "runtime name")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (u *Uninstall) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ log.Printf("Removing runtime %q from %q.", u.Runtime, u.ConfigFile)
+
+ c, err := readConfig(u.ConfigFile)
+ if err != nil {
+ log.Fatalf("Error reading config file %q: %v", u.ConfigFile, err)
+ }
+
+ var rts map[string]interface{}
+ if i, ok := c["runtimes"]; ok {
+ rts = i.(map[string]interface{})
+ } else {
+ log.Fatalf("runtime %q not found", u.Runtime)
+ }
+ if _, ok := rts[u.Runtime]; !ok {
+ log.Fatalf("runtime %q not found", u.Runtime)
+ }
+ delete(rts, u.Runtime)
+
+ if err := writeConfig(c, u.ConfigFile); err != nil {
+ log.Fatalf("Error writing config file %q: %v", u.ConfigFile, err)
+ }
+ return subcommands.ExitSuccess
+}
+
+func readConfig(path string) (map[string]interface{}, error) {
+ // Read the configuration data.
+ configBytes, err := ioutil.ReadFile(path)
+ if err != nil && !os.IsNotExist(err) {
+ return nil, err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if len(configBytes) > 0 {
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return nil, err
+ }
+ }
+
+ return c, nil
+}
+
+func writeConfig(c map[string]interface{}, filename string) error {
+ // Marshal the configuration.
+ b, err := json.MarshalIndent(c, "", " ")
+ if err != nil {
+ return err
+ }
+
+ // Copy the old configuration.
+ old, err := ioutil.ReadFile(filename)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("error reading config file %q: %v", filename, err)
+ }
+ } else {
+ if err := ioutil.WriteFile(filename+"~", old, 0644); err != nil {
+ return fmt.Errorf("error backing up config file %q: %v", filename, err)
+ }
+ }
+
+ // Make the necessary directories.
+ if err := os.MkdirAll(path.Dir(filename), 0755); err != nil {
+ return fmt.Errorf("error creating config directory for %q: %v", filename, err)
+ }
+
+ // Write the new configuration.
+ if err := ioutil.WriteFile(filename, b, 0644); err != nil {
+ return fmt.Errorf("error writing config file %q: %v", filename, err)
+ }
+
+ return nil
+}
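
The runtimes entry written by install uses the anonymous struct shown
above, so the resulting daemon.json shape can be reproduced directly
(illustrative path and arguments):

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    func main() {
    	entry := struct {
    		Path        string   `json:"path,omitempty"`
    		RuntimeArgs []string `json:"runtimeArgs,omitempty"`
    	}{
    		Path:        "/usr/local/bin/runsc", // hypothetical install location
    		RuntimeArgs: []string{"--debug-log=/tmp/runsc/"},
    	}
    	b, _ := json.MarshalIndent(map[string]interface{}{
    		"runtimes": map[string]interface{}{"runsc": entry},
    	}, "", "    ")
    	fmt.Println(string(b))
    	// {
    	//     "runtimes": {
    	//         "runsc": {
    	//             "path": "/usr/local/bin/runsc",
    	//             "runtimeArgs": [
    	//                 "--debug-log=/tmp/runsc/"
    	//             ]
    	//         }
    	//     }
    	// }
    }
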
diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go
new file mode 100644
index 000000000..8282ea0e0
--- /dev/null
+++ b/runsc/cmd/kill.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "syscall"
+
+ "github.com/google/subcommands"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Kill implements subcommands.Command for the "kill" command.
+type Kill struct {
+ all bool
+ pid int
+}
+
+// Name implements subcommands.Command.Name.
+func (*Kill) Name() string {
+ return "kill"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Kill) Synopsis() string {
+ return "sends a signal to the container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Kill) Usage() string {
+ return `kill <container id> [signal]`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (k *Kill) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&k.all, "all", false, "send the specified signal to all processes inside the container")
+ f.IntVar(&k.pid, "pid", 0, "send the specified signal to a specific process")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() == 0 || f.NArg() > 2 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ if k.pid != 0 && k.all {
+ Fatalf("it is invalid to specify both --all and --pid")
+ }
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+
+ // The OCI command-line spec says that the signal should be specified
+ // via a flag, but runc (and things that call runc) pass it as an
+ // argument.
+ signal := f.Arg(1)
+ if signal == "" {
+ signal = "TERM"
+ }
+
+ sig, err := parseSignal(signal)
+ if err != nil {
+ Fatalf("%v", err)
+ }
+
+ if k.pid != 0 {
+ if err := c.SignalProcess(sig, int32(k.pid)); err != nil {
+ Fatalf("failed to signal pid %d: %v", k.pid, err)
+ }
+ } else {
+ if err := c.SignalContainer(sig, k.all); err != nil {
+ Fatalf("%v", err)
+ }
+ }
+ return subcommands.ExitSuccess
+}
+
+func parseSignal(s string) (syscall.Signal, error) {
+ n, err := strconv.Atoi(s)
+ if err == nil {
+ sig := syscall.Signal(n)
+ for _, msig := range signalMap {
+ if sig == msig {
+ return sig, nil
+ }
+ }
+ return -1, fmt.Errorf("unknown signal %q", s)
+ }
+ if sig, ok := signalMap[strings.TrimPrefix(strings.ToUpper(s), "SIG")]; ok {
+ return sig, nil
+ }
+ return -1, fmt.Errorf("unknown signal %q", s)
+}
+
+var signalMap = map[string]syscall.Signal{
+ "ABRT": unix.SIGABRT,
+ "ALRM": unix.SIGALRM,
+ "BUS": unix.SIGBUS,
+ "CHLD": unix.SIGCHLD,
+ "CLD": unix.SIGCLD,
+ "CONT": unix.SIGCONT,
+ "FPE": unix.SIGFPE,
+ "HUP": unix.SIGHUP,
+ "ILL": unix.SIGILL,
+ "INT": unix.SIGINT,
+ "IO": unix.SIGIO,
+ "IOT": unix.SIGIOT,
+ "KILL": unix.SIGKILL,
+ "PIPE": unix.SIGPIPE,
+ "POLL": unix.SIGPOLL,
+ "PROF": unix.SIGPROF,
+ "PWR": unix.SIGPWR,
+ "QUIT": unix.SIGQUIT,
+ "SEGV": unix.SIGSEGV,
+ "STKFLT": unix.SIGSTKFLT,
+ "STOP": unix.SIGSTOP,
+ "SYS": unix.SIGSYS,
+ "TERM": unix.SIGTERM,
+ "TRAP": unix.SIGTRAP,
+ "TSTP": unix.SIGTSTP,
+ "TTIN": unix.SIGTTIN,
+ "TTOU": unix.SIGTTOU,
+ "URG": unix.SIGURG,
+ "USR1": unix.SIGUSR1,
+ "USR2": unix.SIGUSR2,
+ "VTALRM": unix.SIGVTALRM,
+ "WINCH": unix.SIGWINCH,
+ "XCPU": unix.SIGXCPU,
+ "XFSZ": unix.SIGXFSZ,
+}
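
parseSignal accepts numeric, bare, and SIG-prefixed spellings, and
rejects numbers that have no entry in signalMap. A quick demonstration
(sketch only; demoParseSignal is a hypothetical function in package cmd):

    package cmd

    import "fmt"

    // demoParseSignal exercises the accepted spellings.
    func demoParseSignal() {
    	for _, s := range []string{"TERM", "SIGTERM", "9", "42", "BOGUS"} {
    		sig, err := parseSignal(s)
    		fmt.Printf("%-8s -> %d, %v\n", s, sig, err)
    	}
    	// TERM     -> 15, <nil>
    	// SIGTERM  -> 15, <nil>
    	// 9        -> 9, <nil>
    	// 42       -> -1, unknown signal "42"   (numeric but not in signalMap)
    	// BOGUS    -> -1, unknown signal "BOGUS"
    }
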
diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go
new file mode 100644
index 000000000..d8d906fe3
--- /dev/null
+++ b/runsc/cmd/list.go
@@ -0,0 +1,117 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "text/tabwriter"
+ "time"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// List implements subcommands.Command for the "list" command.
+type List struct {
+ quiet bool
+ format string
+}
+
+// Name implements subcommands.Command.Name.
+func (*List) Name() string {
+ return "list"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*List) Synopsis() string {
+ return "list containers started by runsc with the given root"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*List) Usage() string {
+ return `list [flags]`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (l *List) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&l.quiet, "quiet", false, "only list container ids")
+ f.StringVar(&l.format, "format", "text", "output format: 'text' (default) or 'json'")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 0 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ conf := args[0].(*boot.Config)
+ ids, err := container.List(conf.RootDir)
+ if err != nil {
+ Fatalf("%v", err)
+ }
+
+ if l.quiet {
+ for _, id := range ids {
+ fmt.Println(id)
+ }
+ return subcommands.ExitSuccess
+ }
+
+ // Collect the containers.
+ var containers []*container.Container
+ for _, id := range ids {
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container %q: %v", id, err)
+ }
+ containers = append(containers, c)
+ }
+
+ switch l.format {
+ case "text":
+ // Print a nice table.
+ w := tabwriter.NewWriter(os.Stdout, 12, 1, 3, ' ', 0)
+ fmt.Fprint(w, "ID\tPID\tSTATUS\tBUNDLE\tCREATED\tOWNER\n")
+ for _, c := range containers {
+ fmt.Fprintf(w, "%s\t%d\t%s\t%s\t%s\t%s\n",
+ c.ID,
+ c.SandboxPid(),
+ c.Status,
+ c.BundleDir,
+ c.CreatedAt.Format(time.RFC3339Nano),
+ c.Owner)
+ }
+ w.Flush()
+ case "json":
+ // Print just the states.
+ var states []specs.State
+ for _, c := range containers {
+ states = append(states, c.State())
+ }
+ if err := json.NewEncoder(os.Stdout).Encode(states); err != nil {
+ Fatalf("marshaling container state: %v", err)
+ }
+ default:
+ Fatalf("unknown list format %q", l.format)
+ }
+ return subcommands.ExitSuccess
+}
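
The json format emits a plain array of OCI specs.State values, so other
tooling can consume it with the runtime-spec types directly. A
consumer-side sketch that reads `runsc list --format=json` output from
stdin:

    package main

    import (
    	"encoding/json"
    	"fmt"
    	"os"

    	specs "github.com/opencontainers/runtime-spec/specs-go"
    )

    func main() {
    	var states []specs.State
    	if err := json.NewDecoder(os.Stdin).Decode(&states); err != nil {
    		fmt.Fprintln(os.Stderr, "decoding states:", err)
    		os.Exit(1)
    	}
    	for _, s := range states {
    		fmt.Printf("%s\t%s\n", s.ID, s.Status)
    	}
    }
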
diff --git a/runsc/cmd/path.go b/runsc/cmd/path.go
new file mode 100644
index 000000000..0e9ef7fa5
--- /dev/null
+++ b/runsc/cmd/path.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "os"
+)
+
+// getwdOrDie returns the current working directory and dies if it cannot.
+func getwdOrDie() string {
+ wd, err := os.Getwd()
+ if err != nil {
+ Fatalf("getting current working directory: %v", err)
+ }
+ return wd
+}
diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go
new file mode 100644
index 000000000..6f95a9837
--- /dev/null
+++ b/runsc/cmd/pause.go
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Pause implements subcommands.Command for the "pause" command.
+type Pause struct{}
+
+// Name implements subcommands.Command.Name.
+func (*Pause) Name() string {
+ return "pause"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Pause) Synopsis() string {
+ return "pause suspends all processes in a container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Pause) Usage() string {
+ return `pause <container id> - pause all processes in the container.`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (*Pause) SetFlags(f *flag.FlagSet) {
+}
+
+// Execute implements subcommands.Command.Execute.
+func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ cont, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+
+ if err := cont.Pause(); err != nil {
+ Fatalf("pause failed: %v", err)
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go
new file mode 100644
index 000000000..7fb8041af
--- /dev/null
+++ b/runsc/cmd/ps.go
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// PS implements subcommands.Command for the "ps" command.
+type PS struct {
+ format string
+}
+
+// Name implements subcommands.Command.Name.
+func (*PS) Name() string {
+ return "ps"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*PS) Synopsis() string {
+ return "ps displays the processes running inside a container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*PS) Usage() string {
+ return "<container-id> [ps options]"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (ps *PS) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&ps.format, "format", "table", "output format. Select one of: table or json (default: table)")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading sandbox: %v", err)
+ }
+ pList, err := c.Processes()
+ if err != nil {
+ Fatalf("getting processes for container: %v", err)
+ }
+
+ switch ps.format {
+ case "table":
+ fmt.Println(control.ProcessListToTable(pList))
+ case "json":
+ o, err := control.PrintPIDsJSON(pList)
+ if err != nil {
+ Fatalf("generating JSON: %v", err)
+ }
+ fmt.Println(o)
+ default:
+ Fatalf("unsupported format: %s", ps.format)
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go
new file mode 100644
index 000000000..72584b326
--- /dev/null
+++ b/runsc/cmd/restore.go
@@ -0,0 +1,119 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "path/filepath"
+ "syscall"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Restore implements subcommands.Command for the "restore" command.
+type Restore struct {
+ // Restore flags are a super-set of those for Create.
+ Create
+
+ // imagePath is the path to the saved container image
+ imagePath string
+
+ // detach indicates that runsc has to start a process and exit without waiting for it.
+ detach bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Restore) Name() string {
+ return "restore"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Restore) Synopsis() string {
+ return "restore a saved state of container (experimental)"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Restore) Usage() string {
+ return `restore [flags] <container id> - restore a container from a saved state.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (r *Restore) SetFlags(f *flag.FlagSet) {
+ r.Create.SetFlags(f)
+ f.StringVar(&r.imagePath, "image-path", "", "directory path to saved container image")
+ f.BoolVar(&r.detach, "detach", false, "detach from the container's process")
+
+ // Unimplemented flags necessary for compatibility with docker.
+
+ var nsr bool
+ f.BoolVar(&nsr, "no-subreaper", false, "ignored")
+
+ var wp string
+ f.StringVar(&wp, "work-path", "", "ignored")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+ waitStatus := args[1].(*syscall.WaitStatus)
+
+ if conf.Rootless {
+ return Errorf("Rootless mode not supported with %q", r.Name())
+ }
+
+ bundleDir := r.bundleDir
+ if bundleDir == "" {
+ bundleDir = getwdOrDie()
+ }
+ spec, err := specutils.ReadSpec(bundleDir)
+ if err != nil {
+ return Errorf("reading spec: %v", err)
+ }
+ specutils.LogSpec(spec)
+
+ if r.imagePath == "" {
+ return Errorf("image-path flag must be provided")
+ }
+
+ conf.RestoreFile = filepath.Join(r.imagePath, checkpointFileName)
+
+ runArgs := container.Args{
+ ID: id,
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: r.consoleSocket,
+ PIDFile: r.pidFile,
+ UserLog: r.userLog,
+ Attached: !r.detach,
+ }
+ ws, err := container.Run(conf, runArgs)
+ if err != nil {
+ return Errorf("running container: %v", err)
+ }
+ *waitStatus = ws
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go
new file mode 100644
index 000000000..61a55a554
--- /dev/null
+++ b/runsc/cmd/resume.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Resume implements subcommands.Command for the "resume" command.
+type Resume struct{}
+
+// Name implements subcommands.Command.Name.
+func (*Resume) Name() string {
+ return "resume"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Resume) Synopsis() string {
+ return "Resume unpauses a paused container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Resume) Usage() string {
+ return `resume <container id> - resume a paused container.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (r *Resume) SetFlags(f *flag.FlagSet) {
+}
+
+// Execute implements subcommands.Command.Execute.
+func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ cont, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+
+ if err := cont.Resume(); err != nil {
+ Fatalf("resume failed: %v", err)
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
new file mode 100644
index 000000000..cf41581ad
--- /dev/null
+++ b/runsc/cmd/run.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "syscall"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Run implements subcommands.Command for the "run" command.
+type Run struct {
+ // Run flags are a super-set of those for Create.
+ Create
+
+ // detach indicates that runsc has to start a process and exit without waiting for it.
+ detach bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Run) Name() string {
+ return "run"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Run) Synopsis() string {
+ return "create and run a secure container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Run) Usage() string {
+ return `run [flags] <container id> - create and run a secure container.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (r *Run) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&r.detach, "detach", false, "detach from the container's process")
+ r.Create.SetFlags(f)
+}
+
+// Execute implements subcommands.Command.Execute.
+func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+ waitStatus := args[1].(*syscall.WaitStatus)
+
+ if conf.Rootless {
+ return Errorf("Rootless mode not supported with %q", r.Name())
+ }
+
+ bundleDir := r.bundleDir
+ if bundleDir == "" {
+ bundleDir = getwdOrDie()
+ }
+ spec, err := specutils.ReadSpec(bundleDir)
+ if err != nil {
+ return Errorf("reading spec: %v", err)
+ }
+ specutils.LogSpec(spec)
+
+ runArgs := container.Args{
+ ID: id,
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: r.consoleSocket,
+ PIDFile: r.pidFile,
+ UserLog: r.userLog,
+ Attached: !r.detach,
+ }
+ ws, err := container.Run(conf, runArgs)
+ if err != nil {
+ return Errorf("running container: %v", err)
+ }
+
+ *waitStatus = ws
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/spec.go b/runsc/cmd/spec.go
new file mode 100644
index 000000000..55194e641
--- /dev/null
+++ b/runsc/cmd/spec.go
@@ -0,0 +1,206 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "os"
+ "path/filepath"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+func writeSpec(w io.Writer, cwd string, netns string, args []string) error {
+ spec := &specs.Spec{
+ Version: "1.0.0",
+ Process: &specs.Process{
+ Terminal: true,
+ User: specs.User{
+ UID: 0,
+ GID: 0,
+ },
+ Args: args,
+ Env: []string{
+ "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
+ "TERM=xterm",
+ },
+ Cwd: cwd,
+ Capabilities: &specs.LinuxCapabilities{
+ Bounding: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ Effective: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ Inheritable: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ Permitted: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ // TODO(gvisor.dev/issue/3166): support ambient capabilities
+ },
+ Rlimits: []specs.POSIXRlimit{
+ {
+ Type: "RLIMIT_NOFILE",
+ Hard: 1024,
+ Soft: 1024,
+ },
+ },
+ },
+ Root: &specs.Root{
+ Path: "rootfs",
+ Readonly: true,
+ },
+ Hostname: "runsc",
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "proc",
+ Source: "proc",
+ },
+ {
+ Destination: "/dev",
+ Type: "tmpfs",
+ Source: "tmpfs",
+ },
+ {
+ Destination: "/sys",
+ Type: "sysfs",
+ Source: "sysfs",
+ Options: []string{
+ "nosuid",
+ "noexec",
+ "nodev",
+ "ro",
+ },
+ },
+ },
+ Linux: &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ },
+ {
+ Type: "network",
+ Path: netns,
+ },
+ {
+ Type: "ipc",
+ },
+ {
+ Type: "uts",
+ },
+ {
+ Type: "mount",
+ },
+ },
+ },
+ }
+
+ e := json.NewEncoder(w)
+ e.SetIndent("", " ")
+ return e.Encode(spec)
+}
+
+// Spec implements subcommands.Command for the "spec" command.
+type Spec struct {
+ bundle string
+ cwd string
+ netns string
+}
+
+// Name implements subcommands.Command.Name.
+func (*Spec) Name() string {
+ return "spec"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Spec) Synopsis() string {
+ return "create a new OCI bundle specification file"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Spec) Usage() string {
+ return `spec [options] [-- args...] - create a new OCI bundle specification file.
+
+The spec command creates a new specification file (config.json) for a new OCI
+bundle.
+
+The specification file is a starter file that runs the command specified by
+'args' in the container. If 'args' is not specified the default is to run the
+'sh' program.
+
+While a number of flags are provided to change values in the specification, you
+can examine the file and edit it to suit your needs after this command runs.
+You can find out more about the format of the specification file by visiting
+the OCI runtime spec repository:
+https://github.com/opencontainers/runtime-spec/
+
+EXAMPLE:
+ $ mkdir -p bundle/rootfs
+ $ cd bundle
+ $ runsc spec -- /hello
+ $ docker export $(docker create hello-world) | tar -xf - -C rootfs
+ $ sudo runsc run hello
+
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (s *Spec) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&s.bundle, "bundle", ".", "path to the root of the OCI bundle")
+ f.StringVar(&s.cwd, "cwd", "/", "working directory that will be set for the executable; "+
+ "this value MUST be an absolute path")
+ f.StringVar(&s.netns, "netns", "", "network namespace path")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (s *Spec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Grab the arguments.
+ containerArgs := f.Args()
+ if len(containerArgs) == 0 {
+ containerArgs = []string{"sh"}
+ }
+
+ confPath := filepath.Join(s.bundle, "config.json")
+ if _, err := os.Stat(confPath); !os.IsNotExist(err) {
+ Fatalf("file %q already exists", confPath)
+ }
+
+ configFile, err := os.OpenFile(confPath, os.O_WRONLY|os.O_CREATE, 0664)
+ if err != nil {
+ Fatalf("opening file %q: %v", confPath, err)
+ }
+
+ err = writeSpec(configFile, s.cwd, s.netns, containerArgs)
+ if err != nil {
+ Fatalf("writing to %q: %v", confPath, err)
+ }
+
+ return subcommands.ExitSuccess
+}
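
Since writeSpec takes any io.Writer, the generated config.json can be
previewed without creating a bundle. A hypothetical helper in package
cmd:

    package cmd

    import "os"

    // previewSpec dumps the default spec for args to stdout instead of
    // writing bundle/config.json.
    func previewSpec(args []string) error {
    	if len(args) == 0 {
    		args = []string{"sh"} // same default as the spec command
    	}
    	return writeSpec(os.Stdout, "/", "", args)
    }
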
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
new file mode 100644
index 000000000..0205fd9f7
--- /dev/null
+++ b/runsc/cmd/start.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Start implements subcommands.Command for the "start" command.
+type Start struct{}
+
+// Name implements subcommands.Command.Name.
+func (*Start) Name() string {
+ return "start"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Start) Synopsis() string {
+ return "start a secure container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Start) Usage() string {
+ return `start <container id> - start a secure container.`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (*Start) SetFlags(f *flag.FlagSet) {}
+
+// Execute implements subcommands.Command.Execute.
+func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+ if err := c.Start(conf); err != nil {
+ Fatalf("starting container: %v", err)
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go
new file mode 100644
index 000000000..cf2413deb
--- /dev/null
+++ b/runsc/cmd/state.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// State implements subcommands.Command for the "state" command.
+type State struct{}
+
+// Name implements subcommands.Command.Name.
+func (*State) Name() string {
+ return "state"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*State) Synopsis() string {
+ return "get the state of a container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*State) Usage() string {
+ return `state [flags] <container id> - get the state of a container`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (*State) SetFlags(f *flag.FlagSet) {}
+
+// Execute implements subcommands.Command.Execute.
+func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+ log.Debugf("Returning state for container %+v", c)
+
+ state := c.State()
+ log.Debugf("State: %+v", state)
+
+ // Write json-encoded state directly to stdout.
+ b, err := json.MarshalIndent(state, "", " ")
+ if err != nil {
+ Fatalf("marshaling container state: %v", err)
+ }
+ os.Stdout.Write(b)
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go
new file mode 100644
index 000000000..daed9e728
--- /dev/null
+++ b/runsc/cmd/statefile.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+ "gvisor.dev/gvisor/pkg/state/statefile"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Statefile implements subcommands.Command for the "statefile" command.
+type Statefile struct {
+ list bool
+ get string
+ key string
+ output string
+ html bool
+}
+
+// Name implements subcommands.Command.
+func (*Statefile) Name() string {
+ return "state"
+}
+
+// Synopsis implements subcommands.Command.
+func (*Statefile) Synopsis() string {
+ return "shows information about a statefile"
+}
+
+// Usage implements subcommands.Command.
+func (*Statefile) Usage() string {
+ return `statefile [flags] <statefile>`
+}
+
+// SetFlags implements subcommands.Command.
+func (s *Statefile) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&s.list, "list", false, "lists the metadata in the statefile.")
+ f.StringVar(&s.get, "get", "", "extracts the given metadata key.")
+ f.StringVar(&s.key, "key", "", "the integrity key for the file.")
+ f.StringVar(&s.output, "output", "", "target to write the result.")
+ f.BoolVar(&s.html, "html", false, "outputs in HTML format.")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Check arguments.
+ if s.list && s.get != "" {
+ Fatalf("error: can't specify -list and -get simultaneously.")
+ }
+
+ // Setup output.
+ var output = os.Stdout // Default.
+ if s.output != "" {
+ f, err := os.OpenFile(s.output, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
+ if err != nil {
+ Fatalf("error opening output: %v", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ Fatalf("error flushing output: %v", err)
+ }
+ }()
+ output = f
+ }
+
+ // Open the file.
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ input, err := os.Open(f.Arg(0))
+ if err != nil {
+ Fatalf("error opening input: %v\n", err)
+ }
+
+ if s.html {
+ fmt.Fprintf(output, "<html><body>\n")
+ defer fmt.Fprintf(output, "</body></html>\n")
+ }
+
+ // Dump the full file?
+ if !s.list && s.get == "" {
+ var key []byte
+ if s.key != "" {
+ key = []byte(s.key)
+ }
+ rc, _, err := statefile.NewReader(input, key)
+ if err != nil {
+ Fatalf("error parsing statefile: %v", err)
+ }
+ if s.html {
+ if err := pretty.PrintHTML(output, rc); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ } else {
+ if err := pretty.PrintText(output, rc); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ }
+ return subcommands.ExitSuccess
+ }
+
+ // Load just the metadata.
+ metadata, err := statefile.MetadataUnsafe(input)
+ if err != nil {
+ Fatalf("error reading metadata: %v", err)
+ }
+
+ // Is it a single key?
+ if s.get != "" {
+ val, ok := metadata[s.get]
+ if !ok {
+ Fatalf("metadata key %s: not found", s.get)
+ }
+ fmt.Fprintf(output, "%s\n", val)
+ return subcommands.ExitSuccess
+ }
+
+ // List all keys.
+ if s.html {
+ fmt.Fprintf(output, " <ul>\n")
+ defer fmt.Fprintf(output, " </ul>\n")
+ }
+ for key := range metadata {
+ if s.html {
+ fmt.Fprintf(output, " <li>%s</li>\n", key)
+ } else {
+ fmt.Fprintf(output, "%s\n", key)
+ }
+ }
+ return subcommands.ExitSuccess
+}
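
The full-dump path above composes statefile.NewReader with the pretty
printers; the same pair works programmatically. A minimal sketch using
only calls that appear in this file (dumpStatefile itself is a
hypothetical helper):

    package cmd

    import (
    	"os"

    	"gvisor.dev/gvisor/pkg/state/pretty"
    	"gvisor.dev/gvisor/pkg/state/statefile"
    )

    // dumpStatefile prints the decoded contents of the statefile at
    // path; key may be nil for files without an integrity key.
    func dumpStatefile(path string, key []byte) error {
    	input, err := os.Open(path)
    	if err != nil {
    		return err
    	}
    	defer input.Close()
    	rc, _, err := statefile.NewReader(input, key)
    	if err != nil {
    		return err
    	}
    	return pretty.PrintText(os.Stdout, rc)
    }
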
diff --git a/runsc/cmd/syscalls.go b/runsc/cmd/syscalls.go
new file mode 100644
index 000000000..a37d66139
--- /dev/null
+++ b/runsc/cmd/syscalls.go
@@ -0,0 +1,356 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/csv"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "sort"
+ "strconv"
+ "text/tabwriter"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Syscalls implements subcommands.Command for the "syscalls" command.
+type Syscalls struct {
+ format string
+ os string
+ arch string
+ filename string
+}
+
+// CompatibilityInfo maps operating system name to architecture name to
+// ArchInfo, i.e. the compatibility doc for each OS/arch pair.
+type CompatibilityInfo map[string]map[string]ArchInfo
+
+// ArchInfo is compatibility doc for an architecture.
+type ArchInfo struct {
+ // Syscalls maps syscall number for the architecture to the doc.
+ Syscalls map[uintptr]SyscallDoc `json:"syscalls"`
+}
+
+// SyscallDoc represents a single item of syscall documentation.
+type SyscallDoc struct {
+ Name string `json:"name"`
+ num uintptr
+
+ Support string `json:"support"`
+ Note string `json:"note,omitempty"`
+ URLs []string `json:"urls,omitempty"`
+}
+
+type outputFunc func(io.Writer, CompatibilityInfo) error
+
+var (
+ // The string name to use for printing compatibility for all OSes.
+ osAll = "all"
+
+ // The string name to use for printing compatibility for all architectures.
+ archAll = "all"
+
+ // A map of OS name to map of architecture name to syscall table.
+ syscallTableMap = make(map[string]map[string]*kernel.SyscallTable)
+
+ // A map of output type names to output functions.
+ outputMap = map[string]outputFunc{
+ "table": outputTable,
+ "json": outputJSON,
+ "csv": outputCSV,
+ }
+)
+
+// Name implements subcommands.Command.Name.
+func (*Syscalls) Name() string {
+ return "syscalls"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Syscalls) Synopsis() string {
+ return "Print compatibility information for syscalls."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Syscalls) Usage() string {
+ return `syscalls [options] - Print compatibility information for syscalls.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (s *Syscalls) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&s.format, "format", "table", "Output format (table, csv, json).")
+ f.StringVar(&s.os, "os", osAll, "The OS (e.g. linux).")
+ f.StringVar(&s.arch, "arch", archAll, "The CPU architecture (e.g. amd64).")
+ f.StringVar(&s.filename, "filename", "", "Output filename (otherwise stdout).")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (s *Syscalls) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ out, ok := outputMap[s.format]
+ if !ok {
+ Fatalf("Unsupported output format %q", s.format)
+ }
+
+ // Build map of all supported architectures.
+ tables := kernel.SyscallTables()
+ for _, t := range tables {
+ osMap, ok := syscallTableMap[t.OS.String()]
+ if !ok {
+ osMap = make(map[string]*kernel.SyscallTable)
+ syscallTableMap[t.OS.String()] = osMap
+ }
+ osMap[t.Arch.String()] = t
+ }
+
+ // Build a map of the architectures we want to output.
+ info, err := getCompatibilityInfo(s.os, s.arch)
+ if err != nil {
+ Fatalf("%v", err)
+ }
+
+ w := os.Stdout // Default.
+ if s.filename != "" {
+ w, err = os.OpenFile(s.filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
+ if err != nil {
+ Fatalf("Error opening %q: %v", s.filename, err)
+ }
+ }
+ if err := out(w, info); err != nil {
+ Fatalf("Error writing output: %v", err)
+ }
+
+ return subcommands.ExitSuccess
+}
+
+// getCompatibilityInfo returns compatibility info for the given OS name and
+// architecture name. Supports the special name 'all' for OS and architecture,
+// which specifies that all supported OSes or architectures should be included.
+func getCompatibilityInfo(osName string, archName string) (CompatibilityInfo, error) {
+ info := CompatibilityInfo(make(map[string]map[string]ArchInfo))
+ if osName == osAll {
+ // Special processing for the 'all' OS name.
+ for osName := range syscallTableMap {
+ info[osName] = make(map[string]ArchInfo)
+ if err := addToCompatibilityInfo(info, osName, archName); err != nil {
+ return info, err
+ }
+ }
+ } else {
+ // osName is a specific OS name.
+ info[osName] = make(map[string]ArchInfo)
+ if err := addToCompatibilityInfo(info, osName, archName); err != nil {
+ return info, err
+ }
+ }
+
+ return info, nil
+}
+
+// addToCompatibilityInfo adds ArchInfo for the given specific OS name and
+// architecture name. Supports the special architecture name 'all' to specify
+// that all supported architectures for the OS should be included.
+func addToCompatibilityInfo(info CompatibilityInfo, osName string, archName string) error {
+ if archName == archAll {
+ // Special processing for the 'all' architecture name.
+ for archName := range syscallTableMap[osName] {
+ archInfo, err := getArchInfo(osName, archName)
+ if err != nil {
+ return err
+ }
+ info[osName][archName] = archInfo
+ }
+ } else {
+ // archName is a specific architecture name.
+ archInfo, err := getArchInfo(osName, archName)
+ if err != nil {
+ return err
+ }
+ info[osName][archName] = archInfo
+ }
+
+ return nil
+}
+
+// getArchInfo returns compatibility info for a specific OS and architecture.
+func getArchInfo(osName string, archName string) (ArchInfo, error) {
+ info := ArchInfo{}
+ info.Syscalls = make(map[uintptr]SyscallDoc)
+
+ t, ok := syscallTableMap[osName][archName]
+ if !ok {
+ return info, fmt.Errorf("syscall table for %s/%s not found", osName, archName)
+ }
+
+ for num, sc := range t.Table {
+ info.Syscalls[num] = SyscallDoc{
+ Name: sc.Name,
+ num: num,
+ Support: sc.SupportLevel.String(),
+ Note: sc.Note,
+ URLs: sc.URLs,
+ }
+ }
+
+ return info, nil
+}
+
+// outputTable outputs the syscall info in tabular format.
+func outputTable(w io.Writer, info CompatibilityInfo) error {
+ tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0)
+
+ // Emit one table per OS/arch pair.
+ for osName, osInfo := range info {
+ for archName, archInfo := range osInfo {
+ // Print the OS/arch
+ fmt.Fprintf(w, "%s/%s:\n\n", osName, archName)
+
+ // Sort the syscalls for output in the table.
+ sortedCalls := []SyscallDoc{}
+ for _, sc := range archInfo.Syscalls {
+ sortedCalls = append(sortedCalls, sc)
+ }
+ sort.Slice(sortedCalls, func(i, j int) bool {
+ return sortedCalls[i].num < sortedCalls[j].num
+ })
+
+ // Write the header
+ _, err := fmt.Fprintf(tw, "%s\t%s\t%s\t%s\n",
+ "NUM",
+ "NAME",
+ "SUPPORT",
+ "NOTE",
+ )
+ if err != nil {
+ return err
+ }
+
+ // Write each syscall entry
+ for _, sc := range sortedCalls {
+ _, err = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\n",
+ strconv.FormatInt(int64(sc.num), 10),
+ sc.Name,
+ sc.Support,
+ sc.Note,
+ )
+ if err != nil {
+ return err
+ }
+ // Add issue urls to note.
+ for _, url := range sc.URLs {
+ _, err = fmt.Fprintf(tw, "%s\t%s\t%s\tSee: %s\t\n",
+ "",
+ "",
+ "",
+ url,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ err = tw.Flush()
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// outputJSON outputs the syscall info in JSON format.
+func outputJSON(w io.Writer, info CompatibilityInfo) error {
+ e := json.NewEncoder(w)
+ e.SetIndent("", " ")
+ return e.Encode(info)
+}
+
+// numberedRow is a CSV row annotated with the syscall number (used for sorting).
+type numberedRow struct {
+ num uintptr
+ row []string
+}
+
+// outputCSV outputs the syscall info in CSV format.
+func outputCSV(w io.Writer, info CompatibilityInfo) error {
+ csvWriter := csv.NewWriter(w)
+
+ // Emit one set of rows per OS/arch pair.
+ for osName, osInfo := range info {
+ for archName, archInfo := range osInfo {
+ // Sort the syscalls for output in the table.
+ sortedCalls := []numberedRow{}
+ for _, sc := range archInfo.Syscalls {
+ // Add issue urls to note.
+ note := sc.Note
+ for _, url := range sc.URLs {
+ note = fmt.Sprintf("%s\nSee: %s", note, url)
+ }
+
+ sortedCalls = append(sortedCalls, numberedRow{
+ num: sc.num,
+ row: []string{
+ osName,
+ archName,
+ strconv.FormatInt(int64(sc.num), 10),
+ sc.Name,
+ sc.Support,
+ note,
+ },
+ })
+ }
+ sort.Slice(sortedCalls, func(i, j int) bool {
+ return sortedCalls[i].num < sortedCalls[j].num
+ })
+
+ // Write the header
+ err := csvWriter.Write([]string{
+ "OS",
+ "Arch",
+ "Num",
+ "Name",
+ "Support",
+ "Note",
+ })
+ if err != nil {
+ return err
+ }
+
+ // Write each syscall entry
+ for _, sc := range sortedCalls {
+ err = csvWriter.Write(sc.row)
+ if err != nil {
+ return err
+ }
+ }
+
+ csvWriter.Flush()
+ err = csvWriter.Error()
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go
new file mode 100644
index 000000000..29c0a15f0
--- /dev/null
+++ b/runsc/cmd/wait.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "syscall"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+const (
+ unsetPID = -1
+)
+
+// Wait implements subcommands.Command for the "wait" command.
+type Wait struct {
+ rootPID int
+ pid int
+}
+
+// Name implements subcommands.Command.Name.
+func (*Wait) Name() string {
+ return "wait"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Wait) Synopsis() string {
+ return "wait on a process inside a container"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Wait) Usage() string {
+ return `wait [flags] <container id>`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (wt *Wait) SetFlags(f *flag.FlagSet) {
+ f.IntVar(&wt.rootPID, "rootpid", unsetPID, "select a PID in the sandbox root PID namespace to wait on instead of the container's root process")
+ f.IntVar(&wt.pid, "pid", unsetPID, "select a PID in the container's PID namespace to wait on instead of the container's root process")
+}
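+
+// Example invocations (a usage sketch): the result is written to stdout as
+// JSON, e.g. {"id":"<container id>","exitStatus":0}.
+//
+//   runsc wait <container id>               # wait on the container's root process
+//   runsc wait -rootpid 123 <container id>  # wait on PID 123 in the sandbox root PID namespace
+//   runsc wait -pid 2 <container id>        # wait on PID 2 in the container's PID namespace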
+
+// Execute implements subcommands.Command.Execute. It waits for a process in a
+// container to exit before returning.
+func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ // You can't specify both -pid and -rootpid.
+ if wt.rootPID != unsetPID && wt.pid != unsetPID {
+ Fatalf("only one of -pid and -rootPid can be set")
+ }
+
+ id := f.Arg(0)
+ conf := args[0].(*boot.Config)
+
+ c, err := container.Load(conf.RootDir, id)
+ if err != nil {
+ Fatalf("loading container: %v", err)
+ }
+
+ var waitStatus syscall.WaitStatus
+ switch {
+ // Wait on the whole container.
+ case wt.rootPID == unsetPID && wt.pid == unsetPID:
+ ws, err := c.Wait()
+ if err != nil {
+ Fatalf("waiting on container %q: %v", c.ID, err)
+ }
+ waitStatus = ws
+ // Wait on a PID in the root PID namespace.
+ case wt.rootPID != unsetPID:
+ ws, err := c.WaitRootPID(int32(wt.rootPID))
+ if err != nil {
+ Fatalf("waiting on PID in root PID namespace %d in container %q: %v", wt.rootPID, c.ID, err)
+ }
+ waitStatus = ws
+ // Wait on a PID in the container's PID namespace.
+ case wt.pid != unsetPID:
+ ws, err := c.WaitPID(int32(wt.pid))
+ if err != nil {
+ Fatalf("waiting on PID %d in container %q: %v", wt.pid, c.ID, err)
+ }
+ waitStatus = ws
+ }
+ result := waitResult{
+ ID: id,
+ ExitStatus: exitStatus(waitStatus),
+ }
+ // Write json-encoded wait result directly to stdout.
+ if err := json.NewEncoder(os.Stdout).Encode(result); err != nil {
+ Fatalf("marshaling wait result: %v", err)
+ }
+ return subcommands.ExitSuccess
+}
+
+type waitResult struct {
+ ID string `json:"id"`
+ ExitStatus int `json:"exitStatus"`
+}
+
+// exitStatus returns the correct exit status for a process based on if it
+// was signaled or exited cleanly.
+func exitStatus(status syscall.WaitStatus) int {
+ if status.Signaled() {
+ return 128 + int(status.Signal())
+ }
+ return status.ExitStatus()
+}
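+
+// For example, a process that exits cleanly with code 1 yields 1, while a
+// process killed by SIGKILL (signal 9) yields 137 (128 + 9), matching the
+// convention used by POSIX shells.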
diff --git a/runsc/console/BUILD b/runsc/console/BUILD
new file mode 100644
index 000000000..06924bccd
--- /dev/null
+++ b/runsc/console/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "console",
+ srcs = [
+ "console.go",
+ ],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+ deps = [
+ "@com_github_kr_pty//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/runsc/console/console.go b/runsc/console/console.go
new file mode 100644
index 000000000..64b23639a
--- /dev/null
+++ b/runsc/console/console.go
@@ -0,0 +1,63 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package console contains utilities for working with pty consoles in runsc.
+package console
+
+import (
+ "fmt"
+ "net"
+ "os"
+
+ "github.com/kr/pty"
+ "golang.org/x/sys/unix"
+)
+
+// NewWithSocket creates a pty master/slave pair, sends the master FD over the given
+// socket, and returns the slave.
+func NewWithSocket(socketPath string) (*os.File, error) {
+ // Create a new pty master and slave.
+ ptyMaster, ptySlave, err := pty.Open()
+ if err != nil {
+ return nil, fmt.Errorf("opening pty: %v", err)
+ }
+ defer ptyMaster.Close()
+
+ // Get a connection to the socket path.
+ conn, err := net.Dial("unix", socketPath)
+ if err != nil {
+ ptySlave.Close()
+ return nil, fmt.Errorf("dialing socket %q: %v", socketPath, err)
+ }
+ defer conn.Close()
+ uc, ok := conn.(*net.UnixConn)
+ if !ok {
+ ptySlave.Close()
+ return nil, fmt.Errorf("connection is not a UnixConn: %T", conn)
+ }
+ socket, err := uc.File()
+ if err != nil {
+ ptySlave.Close()
+ return nil, fmt.Errorf("getting file for unix socket %v: %v", uc, err)
+ }
+ defer socket.Close()
+
+ // Send the master FD over the connection.
+ msg := unix.UnixRights(int(ptyMaster.Fd()))
+ if err := unix.Sendmsg(int(socket.Fd()), []byte("pty-master"), msg, nil, 0); err != nil {
+ ptySlave.Close()
+ return nil, fmt.Errorf("sending console over unix socket %q: %v", socketPath, err)
+ }
+ return ptySlave, nil
+}
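+
+// The receiving end of socketPath can recover the master FD from the control
+// message (e.g. via unix.ParseSocketControlMessage and unix.ParseUnixRights);
+// receiveConsolePTY in runsc/container/console_test.go shows one way to do
+// this.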
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
new file mode 100644
index 000000000..49cfb0837
--- /dev/null
+++ b/runsc/container/BUILD
@@ -0,0 +1,74 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "container",
+ srcs = [
+ "container.go",
+ "hook.go",
+ "state_file.go",
+ "status.go",
+ ],
+ visibility = [
+ "//runsc:__subpackages__",
+ "//test:__subpackages__",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
+ "//pkg/log",
+ "//pkg/sentry/control",
+ "//pkg/sentry/sighandling",
+ "//pkg/sync",
+ "//runsc/boot",
+ "//runsc/cgroup",
+ "//runsc/sandbox",
+ "//runsc/specutils",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_gofrs_flock//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ ],
+)
+
+go_test(
+ name = "container_test",
+ size = "large",
+ srcs = [
+ "console_test.go",
+ "container_norace_test.go",
+ "container_race_test.go",
+ "container_test.go",
+ "multi_container_test.go",
+ "shared_volume_test.go",
+ ],
+ data = [
+ "//runsc",
+ "//test/cmd/test_app",
+ ],
+ library = ":container",
+ shard_count = 10,
+ tags = [
+ "requires-kvm",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/bits",
+ "//pkg/cleanup",
+ "//pkg/log",
+ "//pkg/sentry/control",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sync",
+ "//pkg/test/testutil",
+ "//pkg/unet",
+ "//pkg/urpc",
+ "//runsc/boot",
+ "//runsc/boot/platforms",
+ "//runsc/specutils",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_kr_pty//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
new file mode 100644
index 000000000..3813c6b93
--- /dev/null
+++ b/runsc/container/console_test.go
@@ -0,0 +1,480 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/kr/pty"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// socketPath creates a path inside bundleDir and ensures that the returned
+// path is under 108 characters (the unix socket path length limit),
+// relativizing the path if necessary.
+func socketPath(bundleDir string) (string, error) {
+ path := filepath.Join(bundleDir, "socket")
+ cwd, err := os.Getwd()
+ if err != nil {
+ return "", fmt.Errorf("error getting cwd: %v", err)
+ }
+ relPath, err := filepath.Rel(cwd, path)
+ if err != nil {
+ return "", fmt.Errorf("error getting relative path for %q from cwd %q: %v", path, cwd, err)
+ }
+ if len(path) > len(relPath) {
+ path = relPath
+ }
+ const maxPathLen = 108
+ if len(path) > maxPathLen {
+ return "", fmt.Errorf("could not get socket path under length limit %d: %s", maxPathLen, path)
+ }
+ return path, nil
+}
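+
+// For example, with bundleDir "/tmp/bundle" the candidate path is
+// "/tmp/bundle/socket"; the shorter of that path and its form relative to the
+// cwd is used, and an error is returned only if both exceed 108 characters.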
+
+// createConsoleSocket creates a socket at the given path that will receive a
+// console fd from the sandbox. If an error occurs, t.Fatalf will be called.
+// The returned cleanup function should be deferred.
+func createConsoleSocket(t *testing.T, path string) (*unet.ServerSocket, func()) {
+ t.Helper()
+ srv, err := unet.BindAndListen(path, false)
+ if err != nil {
+ t.Fatalf("error binding and listening to socket %q: %v", path, err)
+ }
+
+ cleanup := func() {
+ // Log errors; nothing can be done.
+ if err := srv.Close(); err != nil {
+ t.Logf("error closing socket %q: %v", path, err)
+ }
+ if err := os.Remove(path); err != nil {
+ t.Logf("error removing socket %q: %v", path, err)
+ }
+ }
+
+ return srv, cleanup
+}
+
+// receiveConsolePTY accepts a connection on the server socket and reads fds.
+// It fails if more than one FD is received, or if the FD is not a PTY. It
+// returns the PTY master file.
+func receiveConsolePTY(srv *unet.ServerSocket) (*os.File, error) {
+ sock, err := srv.Accept()
+ if err != nil {
+ return nil, fmt.Errorf("error accepting socket connection: %v", err)
+ }
+
+ // Allow 1 FD to be received; we only expect one.
+ r := sock.Reader(true /* blocking */)
+ r.EnableFDs(1)
+
+ // The socket is closed right after sending the FD, so EOF is
+ // an allowed error.
+ b := [][]byte{{}}
+ if _, err := r.ReadVec(b); err != nil && err != io.EOF {
+ return nil, fmt.Errorf("error reading from socket connection: %v", err)
+ }
+
+ // We should have gotten a control message.
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ return nil, fmt.Errorf("error extracting fds from socket connection: %v", err)
+ }
+ if len(fds) != 1 {
+ return nil, fmt.Errorf("got %d fds from socket, wanted 1", len(fds))
+ }
+
+ // Verify that the fd is a terminal.
+ if _, err := unix.IoctlGetTermios(fds[0], unix.TCGETS); err != nil {
+ return nil, fmt.Errorf("fd is not a terminal (ioctl TGGETS got %v)", err)
+ }
+
+ return os.NewFile(uintptr(fds[0]), "pty_master"), nil
+}
+
+// Test that a pty FD is sent over the console socket if one is provided.
+func TestConsoleSocket(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("true")
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ sock, err := socketPath(bundleDir)
+ if err != nil {
+ t.Fatalf("error getting socket path: %v", err)
+ }
+ srv, cleanup := createConsoleSocket(t, sock)
+ defer cleanup()
+
+ // Create the container and pass the socket name.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: sock,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+
+ // Make sure we get a console PTY.
+ ptyMaster, err := receiveConsolePTY(srv)
+ if err != nil {
+ t.Fatalf("error receiving console FD: %v", err)
+ }
+ ptyMaster.Close()
+ })
+ }
+}
+
+// Test that job control signals work on a console created with "exec -ti".
+func TestJobControlSignalExec(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
+ conf := testutil.TestConfig(t)
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Create a pty master/slave. The slave will be passed to the exec
+ // process.
+ ptyMaster, ptySlave, err := pty.Open()
+ if err != nil {
+ t.Fatalf("error opening pty: %v", err)
+ }
+ defer ptyMaster.Close()
+ defer ptySlave.Close()
+
+ // Exec bash and attach a terminal. Note that occasionally /bin/sh
+ // may be a different shell or have a different configuration (such
+ // as disabling interactive mode and job control). Since we want to
+ // explicitly test interactive mode, use /bin/bash. See b/116981926.
+ execArgs := &control.ExecArgs{
+ Filename: "/bin/bash",
+ // Don't let bash execute from profile or rc files, otherwise
+ // our PID counts get messed up.
+ Argv: []string{"/bin/bash", "--noprofile", "--norc"},
+ // Pass the pty slave as FD 0, 1, and 2.
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{ptySlave, ptySlave, ptySlave},
+ },
+ StdioIsPty: true,
+ }
+
+ pid, err := c.Execute(execArgs)
+ if err != nil {
+ t.Fatalf("error executing: %v", err)
+ }
+ if pid != 2 {
+ t.Fatalf("exec got pid %d, wanted %d", pid, 2)
+ }
+
+ // Make sure all the processes are running.
+ expectedPL := []*control.Process{
+ // Root container process.
+ {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ // Bash from exec process.
+ {PID: 2, Cmd: "bash", Threads: []kernel.ThreadID{2}},
+ }
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
+
+ // Execute sleep.
+ ptyMaster.Write([]byte("sleep 100\n"))
+
+ // Wait for it to start. Sleep's PPID is bash's PID.
+ expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}})
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
+
+ // Send a SIGTERM to the foreground process for the exec PID. Note that
+ // although we pass in the PID of "bash", it should actually terminate
+ // "sleep", since that is the foreground process.
+ if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGTERM, true /* fgProcess */); err != nil {
+ t.Fatalf("error signaling container: %v", err)
+ }
+
+ // Sleep process should be gone.
+ expectedPL = expectedPL[:len(expectedPL)-1]
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
+
+ // Sleep is dead, but it may take more time for bash to notice and
+ // change the foreground process back to itself. We know it is done
+ // when bash writes "Terminated" to the pty.
+ if err := testutil.WaitUntilRead(ptyMaster, "Terminated", nil, 5*time.Second); err != nil {
+ t.Fatalf("bash did not take over pty: %v", err)
+ }
+
+ // Send a SIGKILL to the foreground process again. This time "bash"
+ // should be killed. We use SIGKILL instead of SIGTERM or SIGINT
+ // because bash ignores those.
+ if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGKILL, true /* fgProcess */); err != nil {
+ t.Fatalf("error signaling container: %v", err)
+ }
+ expectedPL = expectedPL[:1]
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
+
+ // Make sure the process indicates it was killed by a SIGKILL.
+ ws, err := c.WaitPID(pid)
+ if err != nil {
+ t.Errorf("waiting on container failed: %v", err)
+ }
+ if !ws.Signaled() {
+ t.Error("ws.Signaled() got false, want true")
+ }
+ if got, want := ws.Signal(), syscall.SIGKILL; got != want {
+ t.Errorf("ws.Signal() got %v, want %v", got, want)
+ }
+}
+
+// Test that job control signals work on a console created with "run -ti".
+func TestJobControlSignalRootContainer(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ // Don't let bash execute from profile or rc files, otherwise our PID
+ // counts get messed up.
+ spec := testutil.NewSpecWithArgs("/bin/bash", "--noprofile", "--norc")
+ spec.Process.Terminal = true
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ sock, err := socketPath(bundleDir)
+ if err != nil {
+ t.Fatalf("error getting socket path: %v", err)
+ }
+ srv, cleanup := createConsoleSocket(t, sock)
+ defer cleanup()
+
+ // Create the container and pass the socket name.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: sock,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+
+ // Get the PTY master.
+ ptyMaster, err := receiveConsolePTY(srv)
+ if err != nil {
+ t.Fatalf("error receiving console FD: %v", err)
+ }
+ defer ptyMaster.Close()
+
+ // Bash output as well as sandbox output will be written to the PTY
+ // file. Writes after a certain point will block unless we drain the
+ // PTY, so we must continually copy from it.
+ //
+ // We log the output to stderr for debuggability, and also to a buffer,
+ // since we wait on particular output from bash below. We use a custom
+ // blockingBuffer which is thread-safe and also blocks on Read calls,
+ // which makes this a suitable Reader for WaitUntilRead.
+ ptyBuf := newBlockingBuffer()
+ tee := io.TeeReader(ptyMaster, ptyBuf)
+ go io.Copy(os.Stderr, tee)
+
+ // Start the container.
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Start waiting for the container to exit in a goroutine. We do this
+ // very early, otherwise it might exit before we have a chance to call
+ // Wait.
+ var (
+ ws syscall.WaitStatus
+ wg sync.WaitGroup
+ )
+ wg.Add(1)
+ go func() {
+ var err error
+ ws, err = c.Wait()
+ if err != nil {
+ t.Errorf("error waiting on container: %v", err)
+ }
+ wg.Done()
+ }()
+
+ // Wait for bash to start.
+ expectedPL := []*control.Process{
+ {PID: 1, Cmd: "bash", Threads: []kernel.ThreadID{1}},
+ }
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ // Execute sleep via the terminal.
+ ptyMaster.Write([]byte("sleep 100\n"))
+
+ // Wait for sleep to start.
+ expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{2}})
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ // Reset the pty buffer, so there is less output for us to scan later.
+ ptyBuf.Reset()
+
+ // Send a SIGTERM to the foreground process. We pass PID=0, indicating
+ // that the root process should be killed. However, by setting
+ // fgProcess=true, the signal should actually be sent to sleep.
+ if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGTERM, true /* fgProcess */); err != nil {
+ t.Fatalf("error signaling container: %v", err)
+ }
+
+ // Sleep process should be gone.
+ expectedPL = expectedPL[:len(expectedPL)-1]
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
+
+ // Sleep is dead, but it may take more time for bash to notice and
+ // change the foreground process back to itself. We know it is done
+ // when bash writes "Terminated" to the pty.
+ if err := testutil.WaitUntilRead(ptyBuf, "Terminated", nil, 5*time.Second); err != nil {
+ t.Fatalf("bash did not take over pty: %v", err)
+ }
+
+ // Send a SIGKILL to the foreground process again. This time "bash"
+ // should be killed. We use SIGKILL instead of SIGTERM or SIGINT
+ // because bash ignores those.
+ if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGKILL, true /* fgProcess */); err != nil {
+ t.Fatalf("error signaling container: %v", err)
+ }
+
+ // Wait for the sandbox to exit. It should exit with a SIGKILL status.
+ wg.Wait()
+ if !ws.Signaled() {
+ t.Error("ws.Signaled() got false, want true")
+ }
+ if got, want := ws.Signal(), syscall.SIGKILL; got != want {
+ t.Errorf("ws.Signal() got %v, want %v", got, want)
+ }
+}
+
+// blockingBuffer is a thread-safe buffer that blocks when reading if the
+// buffer is empty. It implements io.ReadWriter.
+type blockingBuffer struct {
+ // A send to readCh indicates that a previously empty buffer now has
+ // data for reading.
+ readCh chan struct{}
+
+ // mu protects buf.
+ mu sync.Mutex
+ buf bytes.Buffer
+}
+
+func newBlockingBuffer() *blockingBuffer {
+ return &blockingBuffer{
+ readCh: make(chan struct{}, 1),
+ }
+}
+
+// Write implements Writer.Write.
+func (bb *blockingBuffer) Write(p []byte) (int, error) {
+ bb.mu.Lock()
+ defer bb.mu.Unlock()
+ l := bb.buf.Len()
+ n, err := bb.buf.Write(p)
+ if l == 0 && n > 0 {
+ // New data!
+ bb.readCh <- struct{}{}
+ }
+ return n, err
+}
+
+// Read implements Reader.Read. It will block until data is available.
+func (bb *blockingBuffer) Read(p []byte) (int, error) {
+ for {
+ bb.mu.Lock()
+ n, err := bb.buf.Read(p)
+ if n > 0 || err != io.EOF {
+ if bb.buf.Len() == 0 {
+ // Reset the readCh.
+ select {
+ case <-bb.readCh:
+ default:
+ }
+ }
+ bb.mu.Unlock()
+ return n, err
+ }
+ bb.mu.Unlock()
+
+ // Wait for new data.
+ <-bb.readCh
+ }
+}
+
+// Reset resets the buffer.
+func (bb *blockingBuffer) Reset() {
+ bb.mu.Lock()
+ defer bb.mu.Unlock()
+ bb.buf.Reset()
+ // Reset the readCh.
+ select {
+ case <-bb.readCh:
+ default:
+ }
+}
diff --git a/runsc/container/container.go b/runsc/container/container.go
new file mode 100644
index 000000000..6d297d0df
--- /dev/null
+++ b/runsc/container/container.go
@@ -0,0 +1,1171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package container creates and manipulates containers.
+package container
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "regexp"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/sighandling"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/cgroup"
+ "gvisor.dev/gvisor/runsc/sandbox"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// validateID validates the container id.
+func validateID(id string) error {
+ // See libcontainer/factory_linux.go.
+ idRegex := regexp.MustCompile(`^[\w+-\.]+$`)
+ if !idRegex.MatchString(id) {
+ return fmt.Errorf("invalid container id: %v", id)
+ }
+ return nil
+}
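+
+// For example, ids like "my-container_1" and "abc.123" are accepted, while
+// ids containing spaces or slashes are rejected.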
+
+// Container represents a containerized application. When running, the
+// container is associated with a single Sandbox.
+//
+// Container metadata can be saved and loaded to disk. Within a root directory,
+// we maintain subdirectories for each container named with the container id.
+// The container metadata is stored as a json within the container directory
+// in a file named "meta.json". This metadata format is defined by us and is
+// not part of the OCI spec.
+//
+// Containers must write their metadata files after any change to their internal
+// states. The entire container directory is deleted when the container is
+// destroyed.
+//
+// When the container is stopped, all processes that belong to the container
+// must be stopped before Destroy() returns. containerd makes roughly the
+// following calls to stop a container:
+// - First it attempts to kill the container process with
+// 'runsc kill SIGTERM'. After some time, it escalates to SIGKILL. In a
+// separate thread, it's waiting on the container. As soon as the wait
+// returns, it moves on to the next step:
+// - It calls 'runsc kill --all SIGKILL' to stop every process that belongs to
+// the container. 'kill --all SIGKILL' waits for all processes before
+// returning.
+// - Containerd waits for stdin, stdout and stderr to drain and be closed.
+// - It calls 'runsc delete'. The runc implementation sends 'kill --all SIGKILL'
+// once again just to be sure, waits, and then proceeds with the remaining
+// teardown.
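+//
+// In CLI terms the sequence is roughly (a sketch, not verbatim containerd
+// invocations):
+//
+//   runsc kill <id> TERM        # graceful stop; escalates to KILL
+//   runsc kill --all <id> KILL  # stop every remaining container process
+//   runsc delete <id>           # kill --all once more, then remove state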
+//
+type Container struct {
+ // ID is the container ID.
+ ID string `json:"id"`
+
+ // Spec is the OCI runtime spec that configures this container.
+ Spec *specs.Spec `json:"spec"`
+
+ // BundleDir is the directory containing the container bundle.
+ BundleDir string `json:"bundleDir"`
+
+ // CreatedAt is the time the container was created.
+ CreatedAt time.Time `json:"createdAt"`
+
+ // Owner is the container owner.
+ Owner string `json:"owner"`
+
+ // ConsoleSocket is the path to a unix domain socket that will receive
+ // the console FD.
+ ConsoleSocket string `json:"consoleSocket"`
+
+ // Status is the current container Status.
+ Status Status `json:"status"`
+
+ // GoferPid is the PID of the gofer running alongside the sandbox. May
+ // be 0 if the gofer has been killed.
+ GoferPid int `json:"goferPid"`
+
+ // Sandbox is the sandbox this container is running in. It's set when the
+ // container is created and reset when the sandbox is destroyed.
+ Sandbox *sandbox.Sandbox `json:"sandbox"`
+
+ // Saver handles load from/save to the state file safely from multiple
+ // processes.
+ Saver StateFile `json:"saver"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
+ // goferIsChild is set if a gofer process is a child of the current process.
+ //
+ // This field isn't saved to json, because only a creator of a gofer
+ // process will have it as a child process.
+ goferIsChild bool
+}
+
+// loadSandbox loads all containers that belong to the sandbox with the given
+// ID.
+func loadSandbox(rootDir, id string) ([]*Container, error) {
+ cids, err := List(rootDir)
+ if err != nil {
+ return nil, err
+ }
+
+ // Load the container metadata.
+ var containers []*Container
+ for _, cid := range cids {
+ container, err := Load(rootDir, cid)
+ if err != nil {
+ // Container file may not exist if it raced with creation/deletion or
+ // directory was left behind. Load provides a snapshot in time, so it's
+ // fine to skip it.
+ if os.IsNotExist(err) {
+ continue
+ }
+ return nil, fmt.Errorf("loading container %q: %v", id, err)
+ }
+ if container.Sandbox.ID == id {
+ containers = append(containers, container)
+ }
+ }
+ return containers, nil
+}
+
+// Load loads a container with the given id from a metadata file. partialID may
+// be an abbreviation of the full container id, in which case Load loads the
+// container to which the id unambiguously refers. Returns os.ErrNotExist if
+// the container doesn't exist.
+func Load(rootDir, partialID string) (*Container, error) {
+ log.Debugf("Load container %q %q", rootDir, partialID)
+ if err := validateID(partialID); err != nil {
+ return nil, fmt.Errorf("validating id: %v", err)
+ }
+
+ id, err := findContainerID(rootDir, partialID)
+ if err != nil {
+ // Preserve error so that callers can distinguish 'not found' errors.
+ return nil, err
+ }
+
+ state := StateFile{
+ RootDir: rootDir,
+ ID: id,
+ }
+ defer state.close()
+
+ c := &Container{}
+ if err := state.load(c); err != nil {
+ if os.IsNotExist(err) {
+ // Preserve error so that callers can distinguish 'not found' errors.
+ return nil, err
+ }
+ return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err)
+ }
+
+ // If the status is "Running" or "Created", check that the sandbox
+ // process still exists, and set it to Stopped if it does not.
+ //
+ // This is inherently racy.
+ if c.Status == Running || c.Status == Created {
+ // Check if the sandbox process is still running.
+ if !c.isSandboxRunning() {
+ // The sandbox process no longer exists, so the container cannot be running.
+ c.changeStatus(Stopped)
+ } else if c.Status == Running {
+ // Container state should reflect the actual state of the application, so
+ // we don't consider gofer process here.
+ if err := c.SignalContainer(syscall.Signal(0), false); err != nil {
+ c.changeStatus(Stopped)
+ }
+ }
+ }
+
+ return c, nil
+}
+
+func findContainerID(rootDir, partialID string) (string, error) {
+ // Check whether the id fully specifies an existing container.
+ stateFile := buildStatePath(rootDir, partialID)
+ if _, err := os.Stat(stateFile); err == nil {
+ return partialID, nil
+ }
+
+ // Now see whether id could be an abbreviation of exactly 1 of the
+ // container ids. If id is ambiguous (it could match more than 1
+ // container), it is an error.
+ ids, err := List(rootDir)
+ if err != nil {
+ return "", err
+ }
+ rv := ""
+ for _, id := range ids {
+ if strings.HasPrefix(id, partialID) {
+ if rv != "" {
+ return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id)
+ }
+ rv = id
+ }
+ }
+ if rv == "" {
+ return "", os.ErrNotExist
+ }
+ log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv)
+ return rv, nil
+}
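+
+// For example, with containers "abc123" and "abd456" on disk, partialID
+// "abc" resolves to "abc123", while "ab" is rejected as ambiguous.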
+
+// Args is used to configure a new container.
+type Args struct {
+ // ID is the container unique identifier.
+ ID string
+
+ // Spec is the OCI spec that describes the container.
+ Spec *specs.Spec
+
+ // BundleDir is the directory containing the container bundle.
+ BundleDir string
+
+ // ConsoleSocket is the path to a unix domain socket that will receive
+ // the console FD. It may be empty.
+ ConsoleSocket string
+
+ // PIDFile is the filename where the container's root process PID will be
+ // written to. It may be empty.
+ PIDFile string
+
+ // UserLog is the filename to send user-visible logs to. It may be empty.
+ //
+ // It only applies for the init container.
+ UserLog string
+
+ // Attached indicates that the sandbox lifecycle is attached with the caller.
+ // If the caller exits, the sandbox should exit too.
+ //
+ // It only applies for the init container.
+ Attached bool
+}
+
+// New creates the container in a new Sandbox process, unless the metadata
+// indicates that an existing Sandbox should be used. The caller must call
+// Destroy() on the container.
+func New(conf *boot.Config, args Args) (*Container, error) {
+ log.Debugf("Create container %q in root dir: %s", args.ID, conf.RootDir)
+ if err := validateID(args.ID); err != nil {
+ return nil, err
+ }
+
+ if err := os.MkdirAll(conf.RootDir, 0711); err != nil {
+ return nil, fmt.Errorf("creating container root directory %q: %v", conf.RootDir, err)
+ }
+
+ c := &Container{
+ ID: args.ID,
+ Spec: args.Spec,
+ ConsoleSocket: args.ConsoleSocket,
+ BundleDir: args.BundleDir,
+ Status: Creating,
+ CreatedAt: time.Now(),
+ Owner: os.Getenv("USER"),
+ Saver: StateFile{
+ RootDir: conf.RootDir,
+ ID: args.ID,
+ },
+ }
+ // The Cleanup object cleans up partially created containers when an error
+ // occurs. Any errors occurring during cleanup itself are ignored.
+ cu := cleanup.Make(func() { _ = c.Destroy() })
+ defer cu.Clean()
+
+ // Lock the container metadata file to prevent concurrent creations of
+ // containers with the same id.
+ if err := c.Saver.lockForNew(); err != nil {
+ return nil, err
+ }
+ defer c.Saver.unlock()
+
+ // If the metadata annotations indicate that this container should be
+ // started in an existing sandbox, we must do so. The metadata will
+ // indicate the ID of the sandbox, which is the same as the ID of the
+ // init container in the sandbox.
+ if isRoot(args.Spec) {
+ log.Debugf("Creating new sandbox for container %q", args.ID)
+
+ // Create and join cgroup before processes are created to ensure they are
+ // part of the cgroup from the start (and all their children processes).
+ cg, err := cgroup.New(args.Spec)
+ if err != nil {
+ return nil, err
+ }
+ if cg != nil {
+ // If there is cgroup config, install it before creating sandbox process.
+ if err := cg.Install(args.Spec.Linux.Resources); err != nil {
+ return nil, fmt.Errorf("configuring cgroup: %v", err)
+ }
+ }
+ if err := runInCgroup(cg, func() error {
+ ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir)
+ if err != nil {
+ return err
+ }
+
+ // Start a new sandbox for this container. Any errors after this point
+ // must destroy the container.
+ sandArgs := &sandbox.Args{
+ ID: args.ID,
+ Spec: args.Spec,
+ BundleDir: args.BundleDir,
+ ConsoleSocket: args.ConsoleSocket,
+ UserLog: args.UserLog,
+ IOFiles: ioFiles,
+ MountsFile: specFile,
+ Cgroup: cg,
+ Attached: args.Attached,
+ }
+ sand, err := sandbox.New(conf, sandArgs)
+ if err != nil {
+ return err
+ }
+ c.Sandbox = sand
+ return nil
+
+ }); err != nil {
+ return nil, err
+ }
+ } else {
+ // This is sort of confusing. For a sandbox with a root
+ // container and a child container in it, runsc sees:
+ // * A container struct whose sandbox ID is equal to the
+ // container ID. This is the root container that is tied to
+ // the creation of the sandbox.
+ // * A container struct whose sandbox ID is equal to the above
+ // container/sandbox ID, but that has a different container
+ // ID. This is the child container.
+ sbid, ok := specutils.SandboxID(args.Spec)
+ if !ok {
+ return nil, fmt.Errorf("no sandbox ID found when creating container")
+ }
+ log.Debugf("Creating new container %q in sandbox %q", c.ID, sbid)
+
+ // Find the sandbox associated with this ID.
+ sb, err := Load(conf.RootDir, sbid)
+ if err != nil {
+ return nil, err
+ }
+ c.Sandbox = sb.Sandbox
+ if err := c.Sandbox.CreateContainer(c.ID); err != nil {
+ return nil, err
+ }
+ }
+ c.changeStatus(Created)
+
+ // Save the metadata file.
+ if err := c.saveLocked(); err != nil {
+ return nil, err
+ }
+
+ // Write the PID file. Containerd considers the create complete after
+ // this file is created, so it must be the last thing we do.
+ if args.PIDFile != "" {
+ if err := ioutil.WriteFile(args.PIDFile, []byte(strconv.Itoa(c.SandboxPid())), 0644); err != nil {
+ return nil, fmt.Errorf("error writing PID file: %v", err)
+ }
+ }
+
+ cu.Release()
+ return c, nil
+}
+
+// Start starts running the containerized process inside the sandbox.
+func (c *Container) Start(conf *boot.Config) error {
+ log.Debugf("Start container %q", c.ID)
+
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ unlock := cleanup.Make(func() { c.Saver.unlock() })
+ defer unlock.Clean()
+
+ if err := c.requireStatus("start", Created); err != nil {
+ return err
+ }
+
+ // "If any prestart hook fails, the runtime MUST generate an error,
+ // stop and destroy the container" -OCI spec.
+ if c.Spec.Hooks != nil {
+ if err := executeHooks(c.Spec.Hooks.Prestart, c.State()); err != nil {
+ return err
+ }
+ }
+
+ if isRoot(c.Spec) {
+ if err := c.Sandbox.StartRoot(c.Spec, conf); err != nil {
+ return err
+ }
+ } else {
+ // Join cgroup to start gofer process to ensure it's part of the cgroup from
+ // the start (and all their children processes).
+ if err := runInCgroup(c.Sandbox.Cgroup, func() error {
+ // Create the gofer process.
+ ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir)
+ if err != nil {
+ return err
+ }
+ defer mountsFile.Close()
+
+ cleanMounts, err := specutils.ReadMounts(mountsFile)
+ if err != nil {
+ return fmt.Errorf("reading mounts file: %v", err)
+ }
+ c.Spec.Mounts = cleanMounts
+
+ return c.Sandbox.StartContainer(c.Spec, conf, c.ID, ioFiles)
+ }); err != nil {
+ return err
+ }
+ }
+
+ // "If any poststart hook fails, the runtime MUST log a warning, but
+ // the remaining hooks and lifecycle continue as if the hook had
+ // succeeded" -OCI spec.
+ if c.Spec.Hooks != nil {
+ executeHooksBestEffort(c.Spec.Hooks.Poststart, c.State())
+ }
+
+ c.changeStatus(Running)
+ if err := c.saveLocked(); err != nil {
+ return err
+ }
+
+ // Release lock before adjusting OOM score because the lock is acquired there.
+ unlock.Clean()
+
+ // Adjust the oom_score_adj for sandbox. This must be done after saveLocked().
+ if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil {
+ return err
+ }
+
+ // Set container's oom_score_adj to the gofer since it is dedicated to
+ // the container, in case the gofer uses up too much memory.
+ return c.adjustGoferOOMScoreAdj()
+}
+
+// Restore takes a container and replaces its kernel and file system
+// to restore the container from its state file.
+func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error {
+ log.Debugf("Restore container %q", c.ID)
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ defer c.Saver.unlock()
+
+ if err := c.requireStatus("restore", Created); err != nil {
+ return err
+ }
+
+ // "If any prestart hook fails, the runtime MUST generate an error,
+ // stop and destroy the container" -OCI spec.
+ if c.Spec.Hooks != nil {
+ if err := executeHooks(c.Spec.Hooks.Prestart, c.State()); err != nil {
+ return err
+ }
+ }
+
+ if err := c.Sandbox.Restore(c.ID, spec, conf, restoreFile); err != nil {
+ return err
+ }
+ c.changeStatus(Running)
+ return c.saveLocked()
+}
+
+// Run is a helper that calls Create + Start + Wait.
+func Run(conf *boot.Config, args Args) (syscall.WaitStatus, error) {
+ log.Debugf("Run container %q in root dir: %s", args.ID, conf.RootDir)
+ c, err := New(conf, args)
+ if err != nil {
+ return 0, fmt.Errorf("creating container: %v", err)
+ }
+ // Clean up partially created container if an error occurs.
+ // Any errors returned by Destroy() itself are ignored.
+ cu := cleanup.Make(func() {
+ c.Destroy()
+ })
+ defer cu.Clean()
+
+ if conf.RestoreFile != "" {
+ log.Debugf("Restore: %v", conf.RestoreFile)
+ if err := c.Restore(args.Spec, conf, conf.RestoreFile); err != nil {
+ return 0, fmt.Errorf("starting container: %v", err)
+ }
+ } else {
+ if err := c.Start(conf); err != nil {
+ return 0, fmt.Errorf("starting container: %v", err)
+ }
+ }
+ if args.Attached {
+ return c.Wait()
+ }
+ cu.Release()
+ return 0, nil
+}
+
+// Execute runs the specified command in the container. It returns the PID of
+// the newly created process.
+func (c *Container) Execute(args *control.ExecArgs) (int32, error) {
+ log.Debugf("Execute in container %q, args: %+v", c.ID, args)
+ if err := c.requireStatus("execute in", Created, Running); err != nil {
+ return 0, err
+ }
+ args.ContainerID = c.ID
+ return c.Sandbox.Execute(args)
+}
+
+// Event returns events for the container.
+func (c *Container) Event() (*boot.Event, error) {
+ log.Debugf("Getting events for container %q", c.ID)
+ if err := c.requireStatus("get events for", Created, Running, Paused); err != nil {
+ return nil, err
+ }
+ return c.Sandbox.Event(c.ID)
+}
+
+// SandboxPid returns the Pid of the sandbox the container is running in, or -1 if the
+// container is not running.
+func (c *Container) SandboxPid() int {
+ if err := c.requireStatus("get PID", Created, Running, Paused); err != nil {
+ return -1
+ }
+ return c.Sandbox.Pid
+}
+
+// Wait waits for the container to exit, and returns its WaitStatus.
+// Waiting on a stopped container is still needed to retrieve its exit status;
+// in that case Wait returns immediately.
+func (c *Container) Wait() (syscall.WaitStatus, error) {
+ log.Debugf("Wait on container %q", c.ID)
+ return c.Sandbox.Wait(c.ID)
+}
+
+// WaitRootPID waits for process 'pid' in the sandbox's PID namespace and
+// returns its WaitStatus.
+func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
+ log.Debugf("Wait on PID %d in sandbox %q", pid, c.Sandbox.ID)
+ if !c.isSandboxRunning() {
+ return 0, fmt.Errorf("sandbox is not running")
+ }
+ return c.Sandbox.WaitPID(c.Sandbox.ID, pid)
+}
+
+// WaitPID waits for process 'pid' in the container's PID namespace and returns
+// its WaitStatus.
+func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) {
+ log.Debugf("Wait on PID %d in container %q", pid, c.ID)
+ if !c.isSandboxRunning() {
+ return 0, fmt.Errorf("sandbox is not running")
+ }
+ return c.Sandbox.WaitPID(c.ID, pid)
+}
+
+// SignalContainer sends the signal to the container. If all is true and signal
+// is SIGKILL, then waits for all processes to exit before returning.
+// SignalContainer returns an error if the container is already stopped.
+// TODO(b/113680494): Distinguish different error types.
+func (c *Container) SignalContainer(sig syscall.Signal, all bool) error {
+ log.Debugf("Signal container %q: %v", c.ID, sig)
+ // Signaling container in Stopped state is allowed. When all=false,
+ // an error will be returned anyway; when all=true, this allows
+ // sending signal to other processes inside the container even
+ // after the init process exits. This is especially useful for
+ // container cleanup.
+ if err := c.requireStatus("signal", Running, Stopped); err != nil {
+ return err
+ }
+ if !c.isSandboxRunning() {
+ return fmt.Errorf("sandbox is not running")
+ }
+ return c.Sandbox.SignalContainer(c.ID, sig, all)
+}
+
+// SignalProcess sends sig to a specific process in the container.
+func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error {
+ log.Debugf("Signal process %d in container %q: %v", pid, c.ID, sig)
+ if err := c.requireStatus("signal a process inside", Running); err != nil {
+ return err
+ }
+ if !c.isSandboxRunning() {
+ return fmt.Errorf("sandbox is not running")
+ }
+ return c.Sandbox.SignalProcess(c.ID, pid, sig, false)
+}
+
+// ForwardSignals forwards all signals received by the current process to the
+// container process inside the sandbox. It returns a function that will stop
+// forwarding signals.
+func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
+ log.Debugf("Forwarding all signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+ stop := sighandling.StartSignalForwarding(func(sig linux.Signal) {
+ log.Debugf("Forwarding signal %d to container %q PID %d fgProcess=%t", sig, c.ID, pid, fgProcess)
+ if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil {
+ log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err)
+ }
+ })
+ return func() {
+ log.Debugf("Done forwarding signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+ stop()
+ }
+}
+
+// Checkpoint sends the checkpoint call to the container.
+// The statefile will be written to f, the file at the specified image-path.
+func (c *Container) Checkpoint(f *os.File) error {
+ log.Debugf("Checkpoint container %q", c.ID)
+ if err := c.requireStatus("checkpoint", Created, Running, Paused); err != nil {
+ return err
+ }
+ return c.Sandbox.Checkpoint(c.ID, f)
+}
+
+// Pause suspends the container and its kernel.
+// The call only succeeds if the container's status is created or running.
+func (c *Container) Pause() error {
+ log.Debugf("Pausing container %q", c.ID)
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ defer c.Saver.unlock()
+
+ if c.Status != Created && c.Status != Running {
+ return fmt.Errorf("cannot pause container %q in state %v", c.ID, c.Status)
+ }
+
+ if err := c.Sandbox.Pause(c.ID); err != nil {
+ return fmt.Errorf("pausing container: %v", err)
+ }
+ c.changeStatus(Paused)
+ return c.saveLocked()
+}
+
+// Resume unpauses the container and its kernel.
+// The call only succeeds if the container's status is paused.
+func (c *Container) Resume() error {
+ log.Debugf("Resuming container %q", c.ID)
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ defer c.Saver.unlock()
+
+ if c.Status != Paused {
+ return fmt.Errorf("cannot resume container %q in state %v", c.ID, c.Status)
+ }
+ if err := c.Sandbox.Resume(c.ID); err != nil {
+ return fmt.Errorf("resuming container: %v", err)
+ }
+ c.changeStatus(Running)
+ return c.saveLocked()
+}
+
+// State returns the metadata of the container.
+func (c *Container) State() specs.State {
+ return specs.State{
+ Version: specs.Version,
+ ID: c.ID,
+ Status: c.Status.String(),
+ Pid: c.SandboxPid(),
+ Bundle: c.BundleDir,
+ }
+}
+
+// Processes retrieves the list of processes and associated metadata inside a
+// container.
+func (c *Container) Processes() ([]*control.Process, error) {
+ if err := c.requireStatus("get processes of", Running, Paused); err != nil {
+ return nil, err
+ }
+ return c.Sandbox.Processes(c.ID)
+}
+
+// Destroy stops all processes and frees all resources associated with the
+// container.
+func (c *Container) Destroy() error {
+ log.Debugf("Destroy container %q", c.ID)
+
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ defer func() {
+ c.Saver.unlock()
+ c.Saver.close()
+ }()
+
+ // Stored for later use as stop() sets c.Sandbox to nil.
+ sb := c.Sandbox
+
+ // We must perform the following cleanup steps:
+ // * stop the container and gofer processes,
+ // * remove the container filesystem on the host, and
+ // * delete the container metadata directory.
+ //
+ // It's possible for one or more of these steps to fail, but we should
+ // do our best to perform all of the cleanups. Hence, we keep a slice
+ // of errors and return their concatenation.
+ var errs []string
+ if err := c.stop(); err != nil {
+ err = fmt.Errorf("stopping container: %v", err)
+ log.Warningf("%v", err)
+ errs = append(errs, err.Error())
+ }
+
+ if err := c.Saver.destroy(); err != nil {
+ err = fmt.Errorf("deleting container state files: %v", err)
+ log.Warningf("%v", err)
+ errs = append(errs, err.Error())
+ }
+
+ c.changeStatus(Stopped)
+
+ // Adjust oom_score_adj for the sandbox. This must be done after the container
+ // is stopped and its state directory is removed. Adjustment can be
+ // skipped if the root container is exiting, because it brings down the entire
+ // sandbox.
+ //
+ // Use 'sb' to tell whether it has been executed before because Destroy must
+ // be idempotent.
+ if sb != nil && !isRoot(c.Spec) {
+ if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil {
+ errs = append(errs, err.Error())
+ }
+ }
+
+ // "If any poststop hook fails, the runtime MUST log a warning, but the
+ // remaining hooks and lifecycle continue as if the hook had
+ // succeeded" - OCI spec.
+ //
+ // Per the OCI spec, "The post-stop hooks MUST be called after the container
+ // is deleted but before the delete operation returns"
+ // Run it here to:
+ // 1) Conform to the OCI.
+ // 2) Make sure it only runs once: once the root has been deleted, the
+ // container can't be loaded again.
+ if c.Spec.Hooks != nil {
+ executeHooksBestEffort(c.Spec.Hooks.Poststop, c.State())
+ }
+
+ if len(errs) == 0 {
+ return nil
+ }
+ return errors.New(strings.Join(errs, "\n"))
+}
+
+// saveLocked saves the container metadata to a file.
+//
+// Precondition: container must be locked with container.lock().
+func (c *Container) saveLocked() error {
+ log.Debugf("Save container %q", c.ID)
+ if err := c.Saver.saveLocked(c); err != nil {
+ return fmt.Errorf("saving container metadata: %v", err)
+ }
+ return nil
+}
+
+// stop stops the container (for regular containers) or the sandbox (for
+// root containers), and waits for the container or sandbox and the gofer
+// to stop. If any of them doesn't stop before timeout, an error is returned.
+func (c *Container) stop() error {
+ var cgroup *cgroup.Cgroup
+
+ if c.Sandbox != nil {
+ log.Debugf("Destroying container %q", c.ID)
+ if err := c.Sandbox.DestroyContainer(c.ID); err != nil {
+ return fmt.Errorf("destroying container %q: %v", c.ID, err)
+ }
+ // Only uninstall cgroup for sandbox stop.
+ if c.Sandbox.IsRootContainer(c.ID) {
+ cgroup = c.Sandbox.Cgroup
+ }
+ // Only set sandbox to nil after it has been told to destroy the container.
+ c.Sandbox = nil
+ }
+
+ // Try killing the gofer if it does not exit with the container.
+ if c.GoferPid != 0 {
+ log.Debugf("Killing gofer for container %q, PID: %d", c.ID, c.GoferPid)
+ if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
+ // The gofer may already be stopped, log the error.
+ log.Warningf("Error sending signal %d to gofer %d: %v", syscall.SIGKILL, c.GoferPid, err)
+ }
+ }
+
+ if err := c.waitForStopped(); err != nil {
+ return err
+ }
+
+ // The gofer runs inside the cgroup, so Cgroup.Uninstall has to be called after it stops.
+ if cgroup != nil {
+ if err := cgroup.Uninstall(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Container) waitForStopped() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
+ op := func() error {
+ if c.isSandboxRunning() {
+ if err := c.SignalContainer(syscall.Signal(0), false); err == nil {
+ return fmt.Errorf("container is still running")
+ }
+ }
+ if c.GoferPid == 0 {
+ return nil
+ }
+ if c.goferIsChild {
+ // The gofer process is a child of the current process,
+ // so we can wait on it and collect its zombie.
+ wpid, err := syscall.Wait4(int(c.GoferPid), nil, syscall.WNOHANG, nil)
+ if err != nil {
+ return fmt.Errorf("error waiting the gofer process: %v", err)
+ }
+ if wpid == 0 {
+ return fmt.Errorf("gofer is still running")
+ }
+
+ } else if err := syscall.Kill(c.GoferPid, 0); err == nil {
+ return fmt.Errorf("gofer is still running")
+ }
+ c.GoferPid = 0
+ return nil
+ }
+ return backoff.Retry(op, b)
+}
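+
+// For illustration: the constant-backoff-with-deadline pattern used by
+// waitForStopped generalizes to any polled condition. A minimal sketch,
+// using only the backoff, context, and time packages already imported by
+// this file:
+//
+//	func pollUntil(cond func() error, timeout time.Duration) error {
+//		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+//		defer cancel()
+//		// Retry cond every 100ms until it returns nil, or fail with the
+//		// last error once the context deadline expires.
+//		b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
+//		return backoff.Retry(cond, b)
+//	}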
+
+func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string) ([]*os.File, *os.File, error) {
+ // Start with the general config flags.
+ args := conf.ToFlags()
+
+ var goferEnds []*os.File
+
+ // nextFD is the next available file descriptor for the gofer process.
+ // It starts at 3 because 0-2 are used by stdin/stdout/stderr.
+ nextFD := 3
+
+ if conf.LogFilename != "" {
+ logFile, err := os.OpenFile(conf.LogFilename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return nil, nil, fmt.Errorf("opening log file %q: %v", conf.LogFilename, err)
+ }
+ defer logFile.Close()
+ goferEnds = append(goferEnds, logFile)
+ args = append(args, "--log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.DebugLog != "" {
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer", test)
+ if err != nil {
+ return nil, nil, fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
+ }
+ defer debugLogFile.Close()
+ goferEnds = append(goferEnds, debugLogFile)
+ args = append(args, "--debug-log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ args = append(args, "gofer", "--bundle", bundleDir)
+ if conf.Overlay {
+ args = append(args, "--panic-on-write=true")
+ }
+
+ // Open the spec file to donate to the sandbox.
+ specFile, err := specutils.OpenSpec(bundleDir)
+ if err != nil {
+ return nil, nil, fmt.Errorf("opening spec file: %v", err)
+ }
+ defer specFile.Close()
+ goferEnds = append(goferEnds, specFile)
+ args = append(args, "--spec-fd="+strconv.Itoa(nextFD))
+ nextFD++
+
+	// Create a pipe that allows the gofer to send the mount list to the sandbox
+	// after all paths have been resolved.
+ mountsSand, mountsGofer, err := os.Pipe()
+ if err != nil {
+ return nil, nil, err
+ }
+ defer mountsGofer.Close()
+ goferEnds = append(goferEnds, mountsGofer)
+ args = append(args, fmt.Sprintf("--mounts-fd=%d", nextFD))
+ nextFD++
+
+ // Add root mount and then add any other additional mounts.
+ mountCount := 1
+ for _, m := range spec.Mounts {
+ if specutils.Is9PMount(m) {
+ mountCount++
+ }
+ }
+
+ sandEnds := make([]*os.File, 0, mountCount)
+ for i := 0; i < mountCount; i++ {
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, nil, err
+ }
+ sandEnds = append(sandEnds, os.NewFile(uintptr(fds[0]), "sandbox IO FD"))
+
+ goferEnd := os.NewFile(uintptr(fds[1]), "gofer IO FD")
+ defer goferEnd.Close()
+ goferEnds = append(goferEnds, goferEnd)
+
+ args = append(args, fmt.Sprintf("--io-fds=%d", nextFD))
+ nextFD++
+ }
+
+ binPath := specutils.ExePath
+ cmd := exec.Command(binPath, args...)
+ cmd.ExtraFiles = goferEnds
+ cmd.Args[0] = "runsc-gofer"
+
+ // Enter new namespaces to isolate from the rest of the system. Don't unshare
+ // cgroup because gofer is added to a cgroup in the caller's namespace.
+ nss := []specs.LinuxNamespace{
+ {Type: specs.IPCNamespace},
+ {Type: specs.MountNamespace},
+ {Type: specs.NetworkNamespace},
+ {Type: specs.PIDNamespace},
+ {Type: specs.UTSNamespace},
+ }
+
+	// Set up any uid/gid mappings, and create or join the configured user
+	// namespace so the gofer's view of the filesystem aligns with the
+	// users in the sandbox.
+ userNS := specutils.FilterNS([]specs.LinuxNamespaceType{specs.UserNamespace}, spec)
+ nss = append(nss, userNS...)
+ specutils.SetUIDGIDMappings(cmd, spec)
+ if len(userNS) != 0 {
+ // We need to set UID and GID to have capabilities in a new user namespace.
+ cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: 0}
+ }
+
+ // Start the gofer in the given namespace.
+ log.Debugf("Starting gofer: %s %v", binPath, args)
+ if err := specutils.StartInNS(cmd, nss); err != nil {
+ return nil, nil, fmt.Errorf("Gofer: %v", err)
+ }
+ log.Infof("Gofer started, PID: %d", cmd.Process.Pid)
+ c.GoferPid = cmd.Process.Pid
+ c.goferIsChild = true
+ return sandEnds, mountsSand, nil
+}
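+
+// For illustration: the --*-fd flags above rely on the os/exec FD-donation
+// convention: cmd.ExtraFiles[i] becomes file descriptor 3+i in the child,
+// which is why nextFD starts at 3 and is incremented once per donated file.
+// A minimal sketch (the "child" binary and its --log-fd flag are hypothetical):
+//
+//	func donateFD(f *os.File) *exec.Cmd {
+//		cmd := exec.Command("child", "--log-fd="+strconv.Itoa(3))
+//		cmd.ExtraFiles = []*os.File{f} // f becomes FD 3 in the child
+//		return cmd
+//	}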
+
+// changeStatus transitions from one status to another ensuring that the
+// transition is valid.
+func (c *Container) changeStatus(s Status) {
+ switch s {
+ case Creating:
+ // Initial state, never transitions to it.
+ panic(fmt.Sprintf("invalid state transition: %v => %v", c.Status, s))
+
+ case Created:
+ if c.Status != Creating {
+ panic(fmt.Sprintf("invalid state transition: %v => %v", c.Status, s))
+ }
+ if c.Sandbox == nil {
+ panic("sandbox cannot be nil")
+ }
+
+ case Paused:
+ if c.Status != Running {
+ panic(fmt.Sprintf("invalid state transition: %v => %v", c.Status, s))
+ }
+ if c.Sandbox == nil {
+ panic("sandbox cannot be nil")
+ }
+
+ case Running:
+ if c.Status != Created && c.Status != Paused {
+ panic(fmt.Sprintf("invalid state transition: %v => %v", c.Status, s))
+ }
+ if c.Sandbox == nil {
+ panic("sandbox cannot be nil")
+ }
+
+ case Stopped:
+ if c.Status != Creating && c.Status != Created && c.Status != Running && c.Status != Stopped {
+ panic(fmt.Sprintf("invalid state transition: %v => %v", c.Status, s))
+ }
+
+ default:
+ panic(fmt.Sprintf("invalid new state: %v", s))
+ }
+ c.Status = s
+}
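+
+// For reference, the transitions enforced above, written out as a table (an
+// illustrative sketch, not used by this file):
+//
+//	var validTransitions = map[Status][]Status{
+//		Creating: {Created, Stopped},
+//		Created:  {Running, Stopped},
+//		Running:  {Paused, Stopped},
+//		Paused:   {Running},
+//		Stopped:  {Stopped},
+//	}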
+
+func (c *Container) isSandboxRunning() bool {
+ return c.Sandbox != nil && c.Sandbox.IsRunning()
+}
+
+func (c *Container) requireStatus(action string, statuses ...Status) error {
+ for _, s := range statuses {
+ if c.Status == s {
+ return nil
+ }
+ }
+ return fmt.Errorf("cannot %s container %q in state %s", action, c.ID, c.Status)
+}
+
+func isRoot(spec *specs.Spec) bool {
+ return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer
+}
+
+// runInCgroup executes fn inside the specified cgroup. If cg is nil, execute
+// it in the current context.
+func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
+ if cg == nil {
+ return fn()
+ }
+ restore, err := cg.Join()
+	// Join may partially succeed before failing; defer restore() before
+	// checking the error so any partial join is reverted.
+	defer restore()
+ if err != nil {
+ return err
+ }
+ return fn()
+}
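+
+// A typical runInCgroup usage sketch: cg may be nil, in which case fn runs in
+// the caller's current cgroup. Processes started by fn inherit cg's cgroup
+// membership:
+//
+//	err := runInCgroup(cg, func() error {
+//		return exec.Command("/bin/true").Run()
+//	})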
+
+// adjustGoferOOMScoreAdj sets the oom_score_adj for the container's gofer.
+func (c *Container) adjustGoferOOMScoreAdj() error {
+ if c.GoferPid == 0 || c.Spec.Process.OOMScoreAdj == nil {
+ return nil
+ }
+ return setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj)
+}
+
+// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
+// oom_score_adj is set to the lowest oom_score_adj among the containers
+// running in the sandbox.
+//
+// TODO(gvisor.dev/issue/238): This call could race with other containers being
+// created at the same time and end up setting the wrong oom_score_adj to the
+// sandbox. Use rpc client to synchronize.
+func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
+ containers, err := loadSandbox(rootDir, s.ID)
+ if err != nil {
+ return fmt.Errorf("loading sandbox containers: %v", err)
+ }
+
+ // Do nothing if the sandbox has been terminated.
+ if len(containers) == 0 {
+ return nil
+ }
+
+ // Get the lowest score for all containers.
+ var lowScore int
+ scoreFound := false
+ if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified {
+ // This is a single-container sandbox. Set the oom_score_adj to
+ // the value specified in the OCI bundle.
+ if containers[0].Spec.Process.OOMScoreAdj != nil {
+ scoreFound = true
+ lowScore = *containers[0].Spec.Process.OOMScoreAdj
+ }
+ } else {
+ for _, container := range containers {
+ // Special multi-container support for CRI. Ignore the root
+ // container when calculating oom_score_adj for the sandbox because
+ // it is the infrastructure (pause) container and always has a very
+ // low oom_score_adj.
+ //
+ // We will use OOMScoreAdj in the single-container case where the
+ // containerd container-type annotation is not present.
+ if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox {
+ continue
+ }
+
+ if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ scoreFound = true
+ lowScore = *container.Spec.Process.OOMScoreAdj
+ }
+ }
+ }
+
+ // If the container is destroyed and remaining containers have no
+ // oomScoreAdj specified then we must revert to the oom_score_adj of the
+ // parent process.
+ if !scoreFound && destroy {
+ ppid, err := specutils.GetParentPid(s.Pid)
+ if err != nil {
+ return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err)
+ }
+ pScore, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err)
+ }
+
+ scoreFound = true
+ lowScore = pScore
+ }
+
+ // Only set oom_score_adj if one of the containers has oom_score_adj set
+ // in the OCI bundle. If not, we need to inherit the parent process's
+ // oom_score_adj.
+ // See: https://github.com/opencontainers/runtime-spec/blob/master/config.md#linux-process
+ if !scoreFound {
+ return nil
+ }
+
+ // Set the lowest of all containers oom_score_adj to the sandbox.
+ return setOOMScoreAdj(s.Pid, lowScore)
+}
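+
+// The score selection above is a minimum over optional values. Extracted as a
+// standalone sketch for clarity (not used by this file):
+//
+//	func lowestScore(scores []*int) (low int, found bool) {
+//		for _, s := range scores {
+//			if s != nil && (!found || *s < low) {
+//				low, found = *s, true
+//			}
+//		}
+//		return low, found
+//	}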
+
+// setOOMScoreAdj sets oom_score_adj to the given value for the given PID.
+// /proc must be available and mounted read-write. scoreAdj should be between
+// -1000 and 1000. It's a no-op if the process has already exited.
+func setOOMScoreAdj(pid int, scoreAdj int) error {
+ f, err := os.OpenFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid), os.O_WRONLY, 0644)
+ if err != nil {
+ // Ignore NotExist errors because it can race with process exit.
+ if os.IsNotExist(err) {
+ log.Warningf("Process (%d) not found setting oom_score_adj", pid)
+ return nil
+ }
+ return err
+ }
+ defer f.Close()
+ if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil {
+ if errors.Is(err, syscall.ESRCH) {
+ log.Warningf("Process (%d) exited while setting oom_score_adj", pid)
+ return nil
+ }
+ return fmt.Errorf("setting oom_score_adj to %q: %v", scoreAdj, err)
+ }
+ return nil
+}
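+
+// Usage sketch: raising a process's own oom_score_adj makes it a likelier
+// OOM-kill target; lowering it below the current value requires privilege:
+//
+//	if err := setOOMScoreAdj(os.Getpid(), 500); err != nil {
+//		log.Warningf("setting oom_score_adj: %v", err)
+//	}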
diff --git a/runsc/container/container_norace_test.go b/runsc/container/container_norace_test.go
new file mode 100644
index 000000000..838c1e20a
--- /dev/null
+++ b/runsc/container/container_norace_test.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package container
+
+// Allow both kvm and ptrace for non-race builds.
+var platformOptions = []configOption{ptrace, kvm}
diff --git a/runsc/container/container_race_test.go b/runsc/container/container_race_test.go
new file mode 100644
index 000000000..9fb4c4fc0
--- /dev/null
+++ b/runsc/container/container_race_test.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package container
+
+// Only enable ptrace with race builds.
+var platformOptions = []configOption{ptrace}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
new file mode 100644
index 000000000..cd76645bd
--- /dev/null
+++ b/runsc/container/container_test.go
@@ -0,0 +1,2348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math"
+ "os"
+ "path"
+ "path/filepath"
+ "reflect"
+ "strconv"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/boot/platforms"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// waitForProcessList waits for the given process list to show up in the container.
+func waitForProcessList(cont *Container, want []*control.Process) error {
+ cb := func() error {
+ got, err := cont.Processes()
+ if err != nil {
+ err = fmt.Errorf("error getting process data from container: %v", err)
+ return &backoff.PermanentError{Err: err}
+ }
+ if !procListsEqual(got, want) {
+ return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want))
+ }
+ return nil
+ }
+	// Give plenty of time, as tests can run slowly under --race.
+ return testutil.Poll(cb, 30*time.Second)
+}
+
+func waitForProcessCount(cont *Container, want int) error {
+ cb := func() error {
+ pss, err := cont.Processes()
+ if err != nil {
+ err = fmt.Errorf("error getting process data from container: %v", err)
+ return &backoff.PermanentError{Err: err}
+ }
+ if got := len(pss); got != want {
+ log.Infof("Waiting for process count to reach %d. Current: %d", want, got)
+ return fmt.Errorf("wrong process count, got: %d, want: %d", got, want)
+ }
+ return nil
+ }
+	// Give plenty of time, as tests can run slowly under --race.
+ return testutil.Poll(cb, 30*time.Second)
+}
+
+func blockUntilWaitable(pid int) error {
+ _, _, err := specutils.RetryEintr(func() (uintptr, uintptr, error) {
+ var err error
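+		// waitid(P_PID, pid, NULL, WEXITED|WNOWAIT) blocks until the child
+		// exits but, because of WNOWAIT, leaves it waitable so a later wait
+		// call can still reap it. The first syscall argument, 1, is P_PID.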
+ _, _, err1 := syscall.Syscall6(syscall.SYS_WAITID, 1, uintptr(pid), 0, syscall.WEXITED|syscall.WNOWAIT, 0, 0)
+ if err1 != 0 {
+ err = err1
+ }
+ return 0, 0, err
+ })
+ return err
+}
+
+// procListsEqual checks whether two Process lists are equal. Fields set to
+// -1 (or math.MaxUint32 for UID) in wants are ignored. Timestamp and threads
+// fields are always ignored.
+func procListsEqual(gots, wants []*control.Process) bool {
+ if len(gots) != len(wants) {
+ return false
+ }
+ for i := range gots {
+ got := gots[i]
+ want := wants[i]
+
+ if want.UID != math.MaxUint32 && want.UID != got.UID {
+ return false
+ }
+ if want.PID != -1 && want.PID != got.PID {
+ return false
+ }
+ if want.PPID != -1 && want.PPID != got.PPID {
+ return false
+ }
+ if len(want.TTY) != 0 && want.TTY != got.TTY {
+ return false
+ }
+ if len(want.Cmd) != 0 && want.Cmd != got.Cmd {
+ return false
+ }
+ }
+ return true
+}
+
+type processBuilder struct {
+ process control.Process
+}
+
+func newProcessBuilder() *processBuilder {
+ return &processBuilder{
+ process: control.Process{
+ UID: math.MaxUint32,
+ PID: -1,
+ PPID: -1,
+ },
+ }
+}
+
+func (p *processBuilder) Cmd(cmd string) *processBuilder {
+ p.process.Cmd = cmd
+ return p
+}
+
+func (p *processBuilder) PID(pid kernel.ThreadID) *processBuilder {
+ p.process.PID = pid
+ return p
+}
+
+func (p *processBuilder) PPID(ppid kernel.ThreadID) *processBuilder {
+ p.process.PPID = ppid
+ return p
+}
+
+func (p *processBuilder) UID(uid auth.KUID) *processBuilder {
+ p.process.UID = uid
+ return p
+}
+
+func (p *processBuilder) Process() *control.Process {
+ return &p.process
+}
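+
+// Usage sketch: unset builder fields keep their wildcard defaults (UID
+// math.MaxUint32, PID and PPID -1), which procListsEqual ignores:
+//
+//	want := []*control.Process{
+//		newProcessBuilder().PID(1).Cmd("sleep").Process(),
+//	}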
+
+func procListToString(pl []*control.Process) string {
+ strs := make([]string, 0, len(pl))
+ for _, p := range pl {
+ strs = append(strs, fmt.Sprintf("%+v", p))
+ }
+ return fmt.Sprintf("[%s]", strings.Join(strs, ","))
+}
+
+// createWriteableOutputFile creates an output file that can be read and
+// written to in the sandbox.
+func createWriteableOutputFile(path string) (*os.File, error) {
+ outputFile, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
+ if err != nil {
+ return nil, fmt.Errorf("error creating file: %q, %v", path, err)
+ }
+
+ // Chmod to allow writing after umask.
+ if err := outputFile.Chmod(0666); err != nil {
+ return nil, fmt.Errorf("error chmoding file: %q, %v", path, err)
+ }
+ return outputFile, nil
+}
+
+func waitForFileNotEmpty(f *os.File) error {
+ op := func() error {
+ fi, err := f.Stat()
+ if err != nil {
+ return err
+ }
+ if fi.Size() == 0 {
+ return fmt.Errorf("file %q is empty", f.Name())
+ }
+ return nil
+ }
+
+ return testutil.Poll(op, 30*time.Second)
+}
+
+func waitForFileExist(path string) error {
+ op := func() error {
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ return err
+ }
+ return nil
+ }
+
+ return testutil.Poll(op, 30*time.Second)
+}
+
+// readOutputNum reads a file at given filepath and returns the int at the
+// requested position.
+func readOutputNum(file string, position int) (int, error) {
+ f, err := os.Open(file)
+ if err != nil {
+ return 0, fmt.Errorf("error opening file: %q, %v", file, err)
+ }
+
+ // Ensure that there is content in output file.
+ if err := waitForFileNotEmpty(f); err != nil {
+ return 0, fmt.Errorf("error waiting for output file: %v", err)
+ }
+
+ b, err := ioutil.ReadAll(f)
+ if err != nil {
+ return 0, fmt.Errorf("error reading file: %v", err)
+ }
+ if len(b) == 0 {
+ return 0, fmt.Errorf("error no content was read")
+ }
+
+ // Strip leading null bytes caused by file offset not being 0 upon restore.
+ b = bytes.Trim(b, "\x00")
+ nums := strings.Split(string(b), "\n")
+
+ if position >= len(nums) {
+ return 0, fmt.Errorf("position %v is not within the length of content %v", position, nums)
+ }
+ if position == -1 {
+		// The file ends with a newline, so the last split element is
+		// empty; use the second-to-last entry.
+ position = len(nums) - 2
+ }
+ num, err := strconv.Atoi(nums[position])
+ if err != nil {
+ return 0, fmt.Errorf("error getting number from file: %v", err)
+ }
+ return num, nil
+}
+
+// run starts the sandbox and waits for it to exit, checking that the
+// application succeeded.
+func run(spec *specs.Spec, conf *boot.Config) error {
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ return fmt.Errorf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ Attached: true,
+ }
+ ws, err := Run(conf, args)
+ if err != nil {
+ return fmt.Errorf("running container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ return fmt.Errorf("container failed, waitStatus: %v", ws)
+ }
+ return nil
+}
+
+type configOption int
+
+const (
+ overlay configOption = iota
+ ptrace
+ kvm
+ nonExclusiveFS
+)
+
+var (
+ noOverlay = append(platformOptions, nonExclusiveFS)
+ all = append(noOverlay, overlay)
+)
+
+// configs generates different configurations to run tests.
+func configs(t *testing.T, opts ...configOption) map[string]*boot.Config {
+ // Always load the default config.
+ cs := make(map[string]*boot.Config)
+ for _, o := range opts {
+ switch o {
+ case overlay:
+ c := testutil.TestConfig(t)
+ c.Overlay = true
+ cs["overlay"] = c
+ case ptrace:
+ c := testutil.TestConfig(t)
+ c.Platform = platforms.Ptrace
+ cs["ptrace"] = c
+ case kvm:
+ c := testutil.TestConfig(t)
+ c.Platform = platforms.KVM
+ cs["kvm"] = c
+ case nonExclusiveFS:
+ c := testutil.TestConfig(t)
+ c.FileAccess = boot.FileAccessShared
+ cs["non-exclusive"] = c
+ default:
+ panic(fmt.Sprintf("unknown config option %v", o))
+ }
+ }
+ return cs
+}
+
+func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*boot.Config {
+ vfs1 := configs(t, opts...)
+
+ var optsVFS2 []configOption
+ for _, opt := range opts {
+ // TODO(gvisor.dev/issue/1487): Enable overlay tests.
+ if opt != overlay {
+ optsVFS2 = append(optsVFS2, opt)
+ }
+ }
+
+ for key, value := range configs(t, optsVFS2...) {
+ value.VFS2 = true
+ vfs1[key+"VFS2"] = value
+ }
+
+ return vfs1
+}
+
+// TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle.
+// It verifies after each step that the container can be loaded from disk, and
+// has the correct status.
+func TestLifecycle(t *testing.T) {
+ // Start the child reaper.
+ childReaper := &testutil.Reaper{}
+ childReaper.Start()
+ defer childReaper.Stop()
+
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ // The container will just sleep for a long time. We will kill it before
+ // it finishes sleeping.
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ // Create the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+
+ // Load the container from disk and check the status.
+ c, err = Load(rootDir, args.ID)
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ if got, want := c.Status, Created; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // List should return the container id.
+ ids, err := List(rootDir)
+ if err != nil {
+ t.Fatalf("error listing containers: %v", err)
+ }
+ if got, want := ids, []string{args.ID}; !reflect.DeepEqual(got, want) {
+ t.Errorf("container list got %v, want %v", got, want)
+ }
+
+ // Start the container.
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Load the container from disk and check the status.
+ c, err = Load(rootDir, args.ID)
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ if got, want := c.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Verify that "sleep 100" is running.
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
+
+ // Wait on the container.
+ ch := make(chan error)
+ go func() {
+ ws, err := c.Wait()
+ if err != nil {
+ ch <- err
+ }
+ if got, want := ws.Signal(), syscall.SIGTERM; got != want {
+ ch <- fmt.Errorf("got signal %v, want %v", got, want)
+ }
+ ch <- nil
+ }()
+
+ // Wait a bit to ensure that we've started waiting on
+ // the container before we signal.
+ time.Sleep(time.Second)
+
+ // Send the container a SIGTERM which will cause it to stop.
+ if err := c.SignalContainer(syscall.SIGTERM, false); err != nil {
+ t.Fatalf("error sending signal %v to container: %v", syscall.SIGTERM, err)
+ }
+
+ // Wait for it to die.
+ if err := <-ch; err != nil {
+ t.Fatalf("error waiting for container: %v", err)
+ }
+
+ // Load the container from disk and check the status.
+ c, err = Load(rootDir, args.ID)
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ if got, want := c.Status, Stopped; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Destroy the container.
+ if err := c.Destroy(); err != nil {
+ t.Fatalf("error destroying container: %v", err)
+ }
+
+ // List should not return the container id.
+ ids, err = List(rootDir)
+ if err != nil {
+ t.Fatalf("error listing containers: %v", err)
+ }
+ if len(ids) != 0 {
+ t.Errorf("expected container list to be empty, but got %v", ids)
+ }
+
+ // Loading the container by id should fail.
+ if _, err = Load(rootDir, args.ID); err == nil {
+ t.Errorf("expected loading destroyed container to fail, but it did not")
+ }
+ })
+ }
+}
+
+// Test that we can execute the application with different path formats.
+func TestExePath(t *testing.T) {
+ // Create two directories that will be prepended to PATH.
+ firstPath, err := ioutil.TempDir(testutil.TmpDir(), "first")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(firstPath)
+ secondPath, err := ioutil.TempDir(testutil.TmpDir(), "second")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(secondPath)
+
+	// Create three minimal executables in the second path, two of which
+	// will be masked by files in the first path.
+ for _, p := range []string{"unmasked", "masked1", "masked2"} {
+ path := filepath.Join(secondPath, p)
+ f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0777)
+ if err != nil {
+ t.Fatalf("error opening path: %v", err)
+ }
+ defer f.Close()
+ if _, err := io.WriteString(f, "#!/bin/true\n"); err != nil {
+ t.Fatalf("error writing contents: %v", err)
+ }
+ }
+
+ // Create a non-executable file in the first path which masks a healthy
+ // executable in the second.
+ nonExecutable := filepath.Join(firstPath, "masked1")
+ f2, err := os.OpenFile(nonExecutable, os.O_CREATE|os.O_EXCL, 0666)
+ if err != nil {
+ t.Fatalf("error opening file: %v", err)
+ }
+ f2.Close()
+
+ // Create a non-regular file in the first path which masks a healthy
+ // executable in the second.
+ nonRegular := filepath.Join(firstPath, "masked2")
+ if err := os.Mkdir(nonRegular, 0777); err != nil {
+ t.Fatalf("error making directory: %v", err)
+ }
+
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ for _, test := range []struct {
+ path string
+ success bool
+ }{
+ {path: "true", success: true},
+ {path: "bin/true", success: true},
+ {path: "/bin/true", success: true},
+ {path: "thisfiledoesntexit", success: false},
+ {path: "bin/thisfiledoesntexit", success: false},
+ {path: "/bin/thisfiledoesntexit", success: false},
+
+ {path: "unmasked", success: true},
+ {path: filepath.Join(firstPath, "unmasked"), success: false},
+ {path: filepath.Join(secondPath, "unmasked"), success: true},
+
+ {path: "masked1", success: true},
+ {path: filepath.Join(firstPath, "masked1"), success: false},
+ {path: filepath.Join(secondPath, "masked1"), success: true},
+
+ {path: "masked2", success: true},
+ {path: filepath.Join(firstPath, "masked2"), success: false},
+ {path: filepath.Join(secondPath, "masked2"), success: true},
+ } {
+ t.Run(fmt.Sprintf("path=%s,success=%t", test.path, test.success), func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs(test.path)
+ spec.Process.Env = []string{
+ fmt.Sprintf("PATH=%s:%s:%s", firstPath, secondPath, os.Getenv("PATH")),
+ }
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("exec: error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ Attached: true,
+ }
+ ws, err := Run(conf, args)
+
+ if test.success {
+ if err != nil {
+ t.Errorf("exec: error running container: %v", err)
+ }
+ if ws.ExitStatus() != 0 {
+ t.Errorf("exec: got exit status %v want %v", ws.ExitStatus(), 0)
+ }
+ } else {
+ if err == nil {
+ t.Errorf("exec: got: no error, want: error")
+ }
+ }
+ })
+ }
+ })
+ }
+}
+
+// Test that we can retrieve the application exit status from the container.
+func TestAppExitStatus(t *testing.T) {
+ doAppExitStatus(t, false)
+}
+
+// This is TestAppExitStatus for VFSv2.
+func TestAppExitStatusVFS2(t *testing.T) {
+ doAppExitStatus(t, true)
+}
+
+func doAppExitStatus(t *testing.T, vfs2 bool) {
+ // First container will succeed.
+ succSpec := testutil.NewSpecWithArgs("true")
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ _, bundleDir, cleanup, err := testutil.SetupContainer(succSpec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: succSpec,
+ BundleDir: bundleDir,
+ Attached: true,
+ }
+ ws, err := Run(conf, args)
+ if err != nil {
+ t.Fatalf("error running container: %v", err)
+ }
+ if ws.ExitStatus() != 0 {
+ t.Errorf("got exit status %v want %v", ws.ExitStatus(), 0)
+ }
+
+ // Second container exits with non-zero status.
+ wantStatus := 123
+ errSpec := testutil.NewSpecWithArgs("bash", "-c", fmt.Sprintf("exit %d", wantStatus))
+
+ _, bundleDir2, cleanup2, err := testutil.SetupContainer(errSpec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup2()
+
+ args2 := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: errSpec,
+ BundleDir: bundleDir2,
+ Attached: true,
+ }
+ ws, err = Run(conf, args2)
+ if err != nil {
+ t.Fatalf("error running container: %v", err)
+ }
+ if ws.ExitStatus() != wantStatus {
+ t.Errorf("got exit status %v want %v", ws.ExitStatus(), wantStatus)
+ }
+}
+
+// TestExec verifies that a container can exec a new program.
+func TestExec(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "exec-test")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100", dir)
+ spec := testutil.NewSpecWithArgs("sh", "-c", cmd)
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Wait until sleep is running to ensure the symlink was created.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sh").Process(),
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("waitForProcessList: %v", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ args control.ExecArgs
+ }{
+ {
+ name: "complete",
+ args: control.ExecArgs{
+ Filename: "/bin/true",
+ Argv: []string{"/bin/true"},
+ },
+ },
+ {
+ name: "filename",
+ args: control.ExecArgs{
+ Filename: "/bin/true",
+ },
+ },
+ {
+ name: "argv",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/true"},
+ },
+ },
+ {
+ name: "filename resolution",
+ args: control.ExecArgs{
+ Filename: "true",
+ Envv: []string{"PATH=/bin"},
+ },
+ },
+ {
+ name: "argv resolution",
+ args: control.ExecArgs{
+ Argv: []string{"true"},
+ Envv: []string{"PATH=/bin"},
+ },
+ },
+ {
+ name: "argv symlink",
+ args: control.ExecArgs{
+ Argv: []string{filepath.Join(dir, "symlink")},
+ },
+ },
+ {
+ name: "working dir",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "${PWD}" != "/tmp" ]]; then exit 1; fi`},
+ WorkingDirectory: "/tmp",
+ },
+ },
+ {
+ name: "user",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "$(id -u)" != "343" ]]; then exit 1; fi`},
+ KUID: 343,
+ },
+ },
+ {
+ name: "group",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "$(id -g)" != "343" ]]; then exit 1; fi`},
+ KGID: 343,
+ },
+ },
+ {
+ name: "env",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "${FOO}" != "123" ]]; then exit 1; fi`},
+ Envv: []string{"FOO=123"},
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // t.Parallel()
+ if ws, err := cont.executeSync(&tc.args); err != nil {
+ t.Fatalf("executeAsync(%+v): %v", tc.args, err)
+ } else if ws != 0 {
+ t.Fatalf("executeAsync(%+v) failed with exit: %v", tc.args, ws)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestExecProcList verifies that a container can exec a new program and that
+// it shows up correctly in the process list.
+func TestExecProcList(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ const uid = 343
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ execArgs := &control.ExecArgs{
+ Filename: "/bin/sleep",
+ Argv: []string{"/bin/sleep", "5"},
+ WorkingDirectory: "/",
+ KUID: uid,
+ }
+
+ // Verify that "sleep 100" and "sleep 5" are running after exec. First,
+ // start running exec (which blocks).
+ ch := make(chan error)
+ go func() {
+ exitStatus, err := cont.executeSync(execArgs)
+ if err != nil {
+ ch <- err
+ } else if exitStatus != 0 {
+ ch <- fmt.Errorf("failed with exit status: %v", exitStatus)
+ } else {
+ ch <- nil
+ }
+ }()
+
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").UID(0).Process(),
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").UID(uid).Process(),
+ }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ // Ensure that exec finished without error.
+ select {
+ case <-time.After(10 * time.Second):
+ t.Fatalf("container timed out waiting for exec to finish.")
+ case err := <-ch:
+ if err != nil {
+ t.Errorf("container failed to exec %v: %v", args, err)
+ }
+ }
+ })
+ }
+}
+
+// TestKillPid verifies that we can signal individual exec'd processes.
+func TestKillPid(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ const nProcs = 4
+ spec := testutil.NewSpecWithArgs(app, "task-tree", "--depth", strconv.Itoa(nProcs-1), "--width=1", "--pause=true")
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Verify that all processes are running.
+ if err := waitForProcessCount(cont, nProcs); err != nil {
+ t.Fatalf("timed out waiting for processes to start: %v", err)
+ }
+
+ // Kill the child process with the largest PID.
+ procs, err := cont.Processes()
+ if err != nil {
+ t.Fatalf("failed to get process list: %v", err)
+ }
+ var pid int32
+ for _, p := range procs {
+ if pid < int32(p.PID) {
+ pid = int32(p.PID)
+ }
+ }
+ if err := cont.SignalProcess(syscall.SIGKILL, pid); err != nil {
+ t.Fatalf("failed to signal process %d: %v", pid, err)
+ }
+
+ // Verify that one process is gone.
+ if err := waitForProcessCount(cont, nProcs-1); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ procs, err = cont.Processes()
+ if err != nil {
+ t.Fatalf("failed to get process list: %v", err)
+ }
+ for _, p := range procs {
+ if pid == int32(p.PID) {
+ t.Fatalf("pid %d is still alive, which should be killed", pid)
+ }
+ }
+ })
+ }
+}
+
+// TestCheckpointRestore creates a container that continuously writes successive integers
+// to a file. To test checkpoint and restore functionality, the container is
+// checkpointed and the last number printed to the file is recorded. Then, it is restored in two
+// new containers and the first number printed from these containers is checked. Both should
+// be the next consecutive number after the last number from the checkpointed container.
+func TestCheckpointRestore(t *testing.T) {
+ // Skip overlay because test requires writing to host file.
+ for name, conf := range configs(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir failed: %v", err)
+ }
+ defer os.RemoveAll(dir)
+ if err := os.Chmod(dir, 0777); err != nil {
+ t.Fatalf("error chmoding file: %q, %v", dir, err)
+ }
+
+ outputPath := filepath.Join(dir, "output")
+ outputFile, err := createWriteableOutputFile(outputPath)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile.Close()
+
+ script := fmt.Sprintf("for ((i=0; ;i++)); do echo $i >> %q; sleep 1; done", outputPath)
+ spec := testutil.NewSpecWithArgs("bash", "-c", script)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Set the image path, which is where the checkpoint image will be saved.
+ imagePath := filepath.Join(dir, "test-image-file")
+
+ // Create the image file and open for writing.
+ file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ if err != nil {
+ t.Fatalf("error opening new file at imagePath: %v", err)
+ }
+ defer file.Close()
+
+			// Wait until the application has run.
+ if err := waitForFileNotEmpty(outputFile); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
+
+ // Checkpoint running container; save state into new file.
+ if err := cont.Checkpoint(file); err != nil {
+ t.Fatalf("error checkpointing container to empty file: %v", err)
+ }
+ defer os.RemoveAll(imagePath)
+
+ lastNum, err := readOutputNum(outputPath, -1)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
+
+ // Delete and recreate file before restoring.
+ if err := os.Remove(outputPath); err != nil {
+ t.Fatalf("error removing file")
+ }
+ outputFile2, err := createWriteableOutputFile(outputPath)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile2.Close()
+
+ // Restore into a new container.
+ args2 := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont2, err := New(conf, args2)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont2.Destroy()
+
+ if err := cont2.Restore(spec, conf, imagePath); err != nil {
+ t.Fatalf("error restoring container: %v", err)
+ }
+
+			// Wait until the application has run.
+ if err := waitForFileNotEmpty(outputFile2); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
+
+ firstNum, err := readOutputNum(outputPath, 0)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
+
+ // Check that lastNum is one less than firstNum and that the container picks
+ // up from where it left off.
+ if lastNum+1 != firstNum {
+ t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum)
+ }
+ cont2.Destroy()
+
+ // Restore into another container!
+ // Delete and recreate file before restoring.
+ if err := os.Remove(outputPath); err != nil {
+ t.Fatalf("error removing file")
+ }
+ outputFile3, err := createWriteableOutputFile(outputPath)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile3.Close()
+
+ // Restore into a new container.
+ args3 := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont3, err := New(conf, args3)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont3.Destroy()
+
+ if err := cont3.Restore(spec, conf, imagePath); err != nil {
+ t.Fatalf("error restoring container: %v", err)
+ }
+
+			// Wait until the application has run.
+ if err := waitForFileNotEmpty(outputFile3); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
+
+ firstNum2, err := readOutputNum(outputPath, 0)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
+
+ // Check that lastNum is one less than firstNum and that the container picks
+ // up from where it left off.
+ if lastNum+1 != firstNum2 {
+ t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum2)
+ }
+ cont3.Destroy()
+ })
+ }
+}
+
+// TestUnixDomainSockets checks that Checkpoint/Restore works in cases
+// with filesystem Unix Domain Socket use.
+func TestUnixDomainSockets(t *testing.T) {
+ // Skip overlay because test requires writing to host file.
+ for name, conf := range configs(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ // UDS path is limited to 108 chars for compatibility with older systems.
+ // Use '/tmp' (instead of testutil.TmpDir) to ensure the size limit is
+ // not exceeded. Assumes '/tmp' exists in the system.
+ dir, err := ioutil.TempDir("/tmp", "uds-test")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir failed: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ outputPath := filepath.Join(dir, "uds_output")
+ outputFile, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile.Close()
+
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ socketPath := filepath.Join(dir, "uds_socket")
+ defer os.Remove(socketPath)
+
+ spec := testutil.NewSpecWithArgs(app, "uds", "--file", outputPath, "--socket", socketPath)
+ spec.Process.User = specs.User{
+ UID: uint32(os.Getuid()),
+ GID: uint32(os.Getgid()),
+ }
+ spec.Mounts = []specs.Mount{{
+ Type: "bind",
+ Destination: dir,
+ Source: dir,
+ }}
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Set the image path, the location where the checkpoint image will be saved.
+ imagePath := filepath.Join(dir, "test-image-file")
+
+ // Create the image file and open for writing.
+ file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ if err != nil {
+ t.Fatalf("error opening new file at imagePath: %v", err)
+ }
+ defer file.Close()
+ defer os.RemoveAll(imagePath)
+
+			// Wait until the application has run.
+ if err := waitForFileNotEmpty(outputFile); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
+
+ // Checkpoint running container; save state into new file.
+ if err := cont.Checkpoint(file); err != nil {
+ t.Fatalf("error checkpointing container to empty file: %v", err)
+ }
+
+ // Read last number outputted before checkpoint.
+ lastNum, err := readOutputNum(outputPath, -1)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
+
+ // Delete and recreate file before restoring.
+ if err := os.Remove(outputPath); err != nil {
+ t.Fatalf("error removing file")
+ }
+ outputFile2, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile2.Close()
+
+ // Restore into a new container.
+ argsRestore := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ contRestore, err := New(conf, argsRestore)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer contRestore.Destroy()
+
+ if err := contRestore.Restore(spec, conf, imagePath); err != nil {
+ t.Fatalf("error restoring container: %v", err)
+ }
+
+			// Wait until the application has run.
+ if err := waitForFileNotEmpty(outputFile2); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
+
+ // Read first number outputted after restore.
+ firstNum, err := readOutputNum(outputPath, 0)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
+
+ // Check that lastNum is one less than firstNum.
+ if lastNum+1 != firstNum {
+ t.Errorf("error numbers not consecutive, previous: %d, next: %d", lastNum, firstNum)
+ }
+ contRestore.Destroy()
+ })
+ }
+}
+
+// TestPauseResume tests that we can successfully pause and resume a container.
+// The container will keep touching a file to indicate it's running. The test
+// pauses the container, removes the file, and checks that it doesn't get
+// recreated. Then it resumes the container and verifies that the file gets
+// created again.
+func TestPauseResume(t *testing.T) {
+ for name, conf := range configs(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock")
+ if err != nil {
+ t.Fatalf("error creating temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ running := path.Join(tmpDir, "running")
+ script := fmt.Sprintf("while [[ true ]]; do touch %q; sleep 0.1; done", running)
+ spec := testutil.NewSpecWithArgs("/bin/bash", "-c", script)
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Wait until container starts running, observed by the existence of running
+ // file.
+ if err := waitForFileExist(running); err != nil {
+ t.Errorf("error waiting for container to start: %v", err)
+ }
+
+ // Pause the running container.
+ if err := cont.Pause(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Paused; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ if err := os.Remove(running); err != nil {
+ t.Fatalf("os.Remove(%q) failed: %v", running, err)
+ }
+			// The script touches the file every 100ms. Give it a bit of time to run,
+			// to catch the case where pause didn't work.
+ time.Sleep(200 * time.Millisecond)
+ if _, err := os.Stat(running); !os.IsNotExist(err) {
+ t.Fatalf("container did not pause: file exist check: %v", err)
+ }
+
+ // Resume the running container.
+ if err := cont.Resume(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Verify that the file is once again created by container.
+ if err := waitForFileExist(running); err != nil {
+ t.Fatalf("error resuming container: file exist check: %v", err)
+ }
+ })
+ }
+}
+
+// TestPauseResumeStatus makes sure that the statuses are set correctly
+// with calls to pause and resume and that pausing and resuming only
+// occurs given the correct state.
+func TestPauseResumeStatus(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("sleep", "20")
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Pause the running container.
+ if err := cont.Pause(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Paused; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Try to Pause again. Should cause error.
+ if err := cont.Pause(); err == nil {
+ t.Errorf("error pausing container that was already paused: %v", err)
+ }
+ if got, want := cont.Status, Paused; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Resume the running container.
+ if err := cont.Resume(); err != nil {
+ t.Errorf("error resuming container: %v", err)
+ }
+ if got, want := cont.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Try to resume again. Should cause error.
+ if err := cont.Resume(); err == nil {
+ t.Errorf("error resuming container already running: %v", err)
+ }
+ if got, want := cont.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+}
+
+// TestCapabilities verifies that:
+// - Running exec as non-root UID and GID will result in an error (because the
+// executable file can't be read).
+// - Running exec as non-root with CAP_DAC_OVERRIDE succeeds because it skips
+// this check.
+func TestCapabilities(t *testing.T) {
+ // Pick uid/gid different than ours.
+ uid := auth.KUID(os.Getuid() + 1)
+ gid := auth.KGID(os.Getgid() + 1)
+
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("Failed to wait for sleep to start, err: %v", err)
+ }
+
+ // Create an executable that can't be run with the specified UID:GID.
+ // This shouldn't be callable within the container until we add the
+ // CAP_DAC_OVERRIDE capability to skip the access check.
+ exePath := filepath.Join(rootDir, "exe")
+ if err := ioutil.WriteFile(exePath, []byte("#!/bin/sh\necho hello"), 0770); err != nil {
+ t.Fatalf("couldn't create executable: %v", err)
+ }
+ defer os.Remove(exePath)
+
+ // Need to traverse the intermediate directory.
+ os.Chmod(rootDir, 0755)
+
+ execArgs := &control.ExecArgs{
+ Filename: exePath,
+ Argv: []string{exePath},
+ WorkingDirectory: "/",
+ KUID: uid,
+ KGID: gid,
+ Capabilities: &auth.TaskCapabilities{},
+ }
+
+ // "exe" should fail because we don't have the necessary permissions.
+ if _, err := cont.executeSync(execArgs); err == nil {
+ t.Fatalf("container executed without error, but an error was expected")
+ }
+
+ // Now we run with the capability enabled and should succeed.
+ execArgs.Capabilities = &auth.TaskCapabilities{
+ EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ }
+ // "exe" should not fail this time.
+ if _, err := cont.executeSync(execArgs); err != nil {
+ t.Fatalf("container failed to exec %v: %v", args, err)
+ }
+ })
+ }
+}
+
+// TestRunNonRoot checks that the sandbox can be configured when running as
+// a non-privileged user.
+func TestRunNonRoot(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("/bin/true")
+
+ // Set a random user/group with no access to "blocked" dir.
+ spec.Process.User.UID = 343
+ spec.Process.User.GID = 2401
+ spec.Process.Capabilities = nil
+
+			// The user running inside the container can't list '$TMP/blocked' and
+			// would fail to mount it.
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ if err := os.Chmod(dir, 0700); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
+ }
+ dir = path.Join(dir, "test")
+ if err := os.Mkdir(dir, 0755); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
+ }
+
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
+
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("error running sandbox: %v", err)
+ }
+ })
+ }
+}
+
+// TestMountNewDir checks that runsc will create the destination directory if
+// it doesn't exist.
+func TestMountNewDir(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ root, err := ioutil.TempDir(testutil.TmpDir(), "root")
+ if err != nil {
+ t.Fatal("ioutil.TempDir() failed:", err)
+ }
+
+ srcDir := path.Join(root, "src", "dir", "anotherdir")
+ if err := os.MkdirAll(srcDir, 0755); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", srcDir, err)
+ }
+
+ mountDir := path.Join(root, "dir", "anotherdir")
+
+ spec := testutil.NewSpecWithArgs("/bin/ls", mountDir)
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: mountDir,
+ Source: srcDir,
+ Type: "bind",
+ })
+
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("error running sandbox: %v", err)
+ }
+ })
+ }
+}
+
+func TestReadonlyRoot(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("/bin/touch", "/foo")
+ spec.Root.Readonly = true
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+ })
+ }
+}
+
+func TestUIDMap(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ defer os.RemoveAll(testDir)
+ testFile := path.Join(testDir, "testfile")
+
+ spec := testutil.NewSpecWithArgs("touch", "/tmp/testfile")
+ uid := os.Getuid()
+ gid := os.Getgid()
+ spec.Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {Type: specs.UserNamespace},
+ {Type: specs.PIDNamespace},
+ {Type: specs.MountNamespace},
+ },
+ UIDMappings: []specs.LinuxIDMapping{
+ {
+ ContainerID: 0,
+ HostID: uint32(uid),
+ Size: 1,
+ },
+ },
+ GIDMappings: []specs.LinuxIDMapping{
+ {
+ ContainerID: 0,
+ HostID: uint32(gid),
+ Size: 1,
+ },
+ },
+ }
+
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp",
+ Source: testDir,
+ Type: "bind",
+ })
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+ st := syscall.Stat_t{}
+ if err := syscall.Stat(testFile, &st); err != nil {
+ t.Fatalf("error stat /testfile: %v", err)
+ }
+
+ if st.Uid != uint32(uid) || st.Gid != uint32(gid) {
+ t.Fatalf("UID: %d (%d) GID: %d (%d)", st.Uid, uid, st.Gid, gid)
+ }
+ })
+ }
+}
+
+func TestReadonlyMount(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount")
+ spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: dir,
+ Type: "bind",
+ Options: []string{"ro"},
+ })
+ spec.Root.Readonly = false
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+ })
+ }
+}
+
+// TestAbbreviatedIDs checks that runsc supports using abbreviated container
+// IDs in place of full IDs.
+func TestAbbreviatedIDs(t *testing.T) {
+ doAbbreviatedIDsTest(t, false)
+}
+
+func TestAbbreviatedIDsVFS2(t *testing.T) {
+ doAbbreviatedIDsTest(t, true)
+}
+
+func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+ conf.VFS2 = vfs2
+
+ cids := []string{
+ "foo-" + testutil.RandomContainerID(),
+ "bar-" + testutil.RandomContainerID(),
+ "baz-" + testutil.RandomContainerID(),
+ }
+ for _, cid := range cids {
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: cid,
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ }
+
+ // These should all be unambiguous.
+ unambiguous := map[string]string{
+ "f": cids[0],
+ cids[0]: cids[0],
+ "bar": cids[1],
+ cids[1]: cids[1],
+ "baz": cids[2],
+ cids[2]: cids[2],
+ }
+ for shortid, longid := range unambiguous {
+ if _, err := Load(rootDir, shortid); err != nil {
+ t.Errorf("%q should resolve to %q: %v", shortid, longid, err)
+ }
+ }
+
+ // These should be ambiguous.
+ ambiguous := []string{
+ "b",
+ "ba",
+ }
+ for _, shortid := range ambiguous {
+ if s, err := Load(rootDir, shortid); err == nil {
+ t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID)
+ }
+ }
+}
+
+func TestGoferExits(t *testing.T) {
+ doGoferExitTest(t, false)
+}
+
+func TestGoferExitsVFS2(t *testing.T) {
+ doGoferExitTest(t, true)
+}
+
+func doGoferExitTest(t *testing.T, vfs2 bool) {
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Kill sandbox and expect gofer to exit on its own.
+ sandboxProc, err := os.FindProcess(c.Sandbox.Pid)
+ if err != nil {
+ t.Fatalf("error finding sandbox process: %v", err)
+ }
+ if err := sandboxProc.Kill(); err != nil {
+ t.Fatalf("error killing sandbox process: %v", err)
+ }
+
+ err = blockUntilWaitable(c.GoferPid)
+ if err != nil && err != syscall.ECHILD {
+ t.Errorf("error waiting for gofer to exit: %v", err)
+ }
+}
+
+func TestRootNotMount(t *testing.T) {
+ appSym, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ app, err := filepath.EvalSymlinks(appSym)
+ if err != nil {
+ t.Fatalf("error resolving %q symlink: %v", appSym, err)
+ }
+ log.Infof("App path %q is a symlink to %q", appSym, app)
+
+ static, err := testutil.IsStatic(app)
+ if err != nil {
+ t.Fatalf("error reading application binary: %v", err)
+ }
+ if !static {
+ // This happens during race builds; we cannot map in shared
+ // libraries either, so we must skip the test.
+ t.Skip()
+ }
+
+ root := filepath.Dir(app)
+ exe := "/" + filepath.Base(app)
+ log.Infof("Executing %q in %q", exe, root)
+
+ spec := testutil.NewSpecWithArgs(exe, "help")
+ spec.Root.Path = root
+ spec.Root.Readonly = true
+ spec.Mounts = nil
+
+ conf := testutil.TestConfig(t)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("error running sandbox: %v", err)
+ }
+}
+
+func TestUserLog(t *testing.T) {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ // sched_rr_get_interval (syscall number 148) is not implemented in gVisor.
+ spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148")
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "user_log_test")
+ if err != nil {
+ t.Fatalf("error creating tmp dir: %v", err)
+ }
+ userLog := filepath.Join(dir, "user.log")
+
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ UserLog: userLog,
+ Attached: true,
+ }
+ ws, err := Run(conf, args)
+ if err != nil {
+ t.Fatalf("error running container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+
+ out, err := ioutil.ReadFile(userLog)
+ if err != nil {
+ t.Fatalf("error opening user log file %q: %v", userLog, err)
+ }
+ if want := "Unsupported syscall sched_rr_get_interval("; !strings.Contains(string(out), want) {
+ t.Errorf("user log file doesn't contain %q, out: %s", want, string(out))
+ }
+}
+
+func TestWaitOnExitedSandbox(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ // Run a shell that sleeps for 1 second and then exits with a
+ // non-zero code.
+ const wantExit = 17
+ cmd := fmt.Sprintf("sleep 1; exit %d", wantExit)
+ spec := testutil.NewSpecWithArgs("/bin/sh", "-c", cmd)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and Start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Wait on the sandbox. This will make an RPC to the sandbox
+ // and get the actual exit status of the application.
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if got := ws.ExitStatus(); got != wantExit {
+ t.Errorf("got exit status %d, want %d", got, wantExit)
+ }
+
+ // Now the sandbox has exited, but the zombie sandbox process
+ // still exists. Calling Wait() now will return the sandbox
+ // exit status.
+ ws, err = c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if got := ws.ExitStatus(); got != wantExit {
+ t.Errorf("got exit status %d, want %d", got, wantExit)
+ }
+ })
+ }
+}
+
+func TestDestroyNotStarted(t *testing.T) {
+ doDestroyNotStartedTest(t, false)
+}
+
+func TestDestroyNotStartedVFS2(t *testing.T) {
+ doDestroyNotStartedTest(t, true)
+}
+
+func doDestroyNotStartedTest(t *testing.T, vfs2 bool) {
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "100")
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create the container and check that it can be destroyed.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ if err := c.Destroy(); err != nil {
+ t.Fatalf("deleting non-started container failed: %v", err)
+ }
+}
+
+// TestDestroyStarting attempts to force a race between start and destroy.
+func TestDestroyStarting(t *testing.T) {
+ doDestroyStartingTest(t, false)
+}
+
+func TestDestroyStartingVFS2(t *testing.T) {
+ doDestroyStartingTest(t, true)
+}
+
+func doDestroyStartingTest(t *testing.T, vfs2 bool) {
+ for i := 0; i < 10; i++ {
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "100")
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create the container and check that it can be destroyed.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+
+ // Container is not thread safe, so load another instance to run
+ // concurrently.
+ startCont, err := Load(rootDir, args.ID)
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ // Ignore failures, start can fail if destroy runs first.
+ startCont.Start(conf)
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := c.Destroy(); err != nil {
+ t.Errorf("deleting non-started container failed: %v", err)
+ }
+ }()
+ wg.Wait()
+ }
+}
+
+func TestCreateWorkingDir(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ dir := path.Join(tmpDir, "new/working/dir")
+
+ // touch will fail if the directory doesn't exist.
+ spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
+ spec.Process.Cwd = dir
+ spec.Root.Readonly = true
+
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ })
+ }
+}
+
+// TestMountPropagation verifies that mounts propagate to slave mounts but not
+// to private mounts.
+func TestMountPropagation(t *testing.T) {
+ // Set up the directory structure:
+ // - src: mounted as shared and used as the source for both the private
+ // and slave mounts
+ // - dir: bind mounted inside src; it should propagate to the slave
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "mount")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ src := filepath.Join(tmpDir, "src")
+ srcMnt := filepath.Join(src, "mnt")
+ dir := filepath.Join(tmpDir, "dir")
+ for _, path := range []string{src, srcMnt, dir} {
+ if err := os.MkdirAll(path, 0777); err != nil {
+ t.Fatalf("MkdirAll(%q): %v", path, err)
+ }
+ }
+ dirFile := filepath.Join(dir, "file")
+ f, err := os.Create(dirFile)
+ if err != nil {
+ t.Fatalf("os.Create(%q): %v", dirFile, err)
+ }
+ f.Close()
+
+ // Set up src as a shared mount.
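+ // The two mount calls below are roughly equivalent to:
+ // mount --bind src src && mount --make-shared src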
+ if err := syscall.Mount(src, src, "bind", syscall.MS_BIND, ""); err != nil {
+ t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err)
+ }
+ if err := syscall.Mount("", src, "", syscall.MS_SHARED, ""); err != nil {
+ t.Fatalf("mount(%q, MS_SHARED): %v", srcMnt, err)
+ }
+
+ spec := testutil.NewSpecWithArgs("sleep", "1000")
+
+ priv := filepath.Join(tmpDir, "priv")
+ slave := filepath.Join(tmpDir, "slave")
+ spec.Mounts = []specs.Mount{
+ {
+ Source: src,
+ Destination: priv,
+ Type: "bind",
+ Options: []string{"private"},
+ },
+ {
+ Source: src,
+ Destination: slave,
+ Type: "bind",
+ Options: []string{"slave"},
+ },
+ }
+
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("creating container: %v", err)
+ }
+ defer cont.Destroy()
+
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("starting container: %v", err)
+ }
+
+ // After the container is started, mount dir inside source and check what
+ // happens to both destinations.
+ if err := syscall.Mount(dir, srcMnt, "bind", syscall.MS_BIND, ""); err != nil {
+ t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err)
+ }
+
+ // Check that the mount did not propagate to the private mount.
+ privFile := filepath.Join(priv, "mnt", "file")
+ execArgs := &control.ExecArgs{
+ Filename: "/usr/bin/test",
+ Argv: []string{"test", "!", "-f", privFile},
+ }
+ if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ t.Fatalf("exec: test ! -f %q, ws: %v, err: %v", privFile, ws, err)
+ }
+
+ // Check that the mount propagated to the slave mount.
+ slaveFile := filepath.Join(slave, "mnt", "file")
+ execArgs = &control.ExecArgs{
+ Filename: "/usr/bin/test",
+ Argv: []string{"test", "-f", slaveFile},
+ }
+ if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ t.Fatalf("exec: test -f %q, ws: %v, err: %v", privFile, ws, err)
+ }
+}
+
+func TestMountSymlink(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ source := path.Join(dir, "source")
+ target := path.Join(dir, "target")
+ for _, path := range []string{source, target} {
+ if err := os.MkdirAll(path, 0777); err != nil {
+ t.Fatalf("os.MkdirAll(): %v", err)
+ }
+ }
+ f, err := os.Create(path.Join(source, "file"))
+ if err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ f.Close()
+
+ link := path.Join(dir, "link")
+ if err := os.Symlink(target, link); err != nil {
+ t.Fatalf("os.Symlink(%q, %q): %v", target, link, err)
+ }
+
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "1000")
+
+ // Mount to a symlink to ensure the mount code will follow it and mount
+ // at the symlink target.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: link,
+ Source: source,
+ })
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("creating container: %v", err)
+ }
+ defer cont.Destroy()
+
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("starting container: %v", err)
+ }
+
+ // Check that symlink was resolved and mount was created where the symlink
+ // is pointing to.
+ file := path.Join(target, "file")
+ execArgs := &control.ExecArgs{
+ Filename: "/usr/bin/test",
+ Argv: []string{"test", "-f", file},
+ }
+ if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err)
+ }
+ })
+ }
+}
+
+// TestNetRaw checks that the --net-raw flag controls whether the CAP_NET_RAW
+// capability is granted.
+func TestNetRaw(t *testing.T) {
+ capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10)
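+ // capNetRaw is the CAP_NET_RAW bitmask (presumably 1<<CAP_NET_RAW from
+ // bits.MaskOf64) rendered as a decimal string for test_app's capability
+ // check.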
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ for _, enableRaw := range []bool{true, false} {
+ conf := testutil.TestConfig(t)
+ conf.EnableRaw = enableRaw
+
+ test := "--enabled"
+ if !enableRaw {
+ test = "--disabled"
+ }
+
+ spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ }
+}
+
+// TestTTYField checks TTY field returned by container.Processes().
+func TestTTYField(t *testing.T) {
+ stop := testutil.StartReaper()
+ defer stop()
+
+ testApp, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ testCases := []struct {
+ name string
+ useTTY bool
+ wantTTYField string
+ }{
+ {
+ name: "no tty",
+ useTTY: false,
+ wantTTYField: "?",
+ },
+ {
+ name: "tty used",
+ useTTY: true,
+ wantTTYField: "pts/0",
+ },
+ }
+
+ for _, test := range testCases {
+ for _, vfs2 := range []bool{false, true} {
+ name := test.name
+ if vfs2 {
+ name += "-vfs2"
+ }
+ t.Run(name, func(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+
+ // We will run /bin/sleep, possibly with an open TTY.
+ cmd := []string{"/bin/sleep", "10000"}
+ if test.useTTY {
+ // Run inside the "pty-runner".
+ cmd = append([]string{testApp, "pty-runner"}, cmd...)
+ }
+
+ spec := testutil.NewSpecWithArgs(cmd...)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Wait for sleep to be running, and check the TTY
+ // field.
+ var gotTTYField string
+ cb := func() error {
+ ps, err := c.Processes()
+ if err != nil {
+ err = fmt.Errorf("error getting process data from container: %v", err)
+ return &backoff.PermanentError{Err: err}
+ }
+ for _, p := range ps {
+ if strings.Contains(p.Cmd, "sleep") {
+ gotTTYField = p.TTY
+ return nil
+ }
+ }
+ return fmt.Errorf("sleep not running")
+ }
+ if err := testutil.Poll(cb, 30*time.Second); err != nil {
+ t.Fatalf("error waiting for sleep process: %v", err)
+ }
+
+ if gotTTYField != test.wantTTYField {
+ t.Errorf("tty field got %q, want %q", gotTTYField, test.wantTTYField)
+ }
+ })
+ }
+ }
+}
+
+// executeSync synchronously executes a new process.
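+// Callers typically assert on the returned wait status, e.g.:
+//
+// if ws, err := cont.executeSync(args); err != nil || ws != 0 { ... }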
+func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
+ pid, err := cont.Execute(args)
+ if err != nil {
+ return 0, fmt.Errorf("error executing: %v", err)
+ }
+ ws, err := cont.WaitPID(pid)
+ if err != nil {
+ return 0, fmt.Errorf("error waiting: %v", err)
+ }
+ return ws, nil
+}
+
+func TestMain(m *testing.M) {
+ log.SetLevel(log.Debug)
+ flag.Parse()
+ if err := testutil.ConfigureExePath(); err != nil {
+ panic(err.Error())
+ }
+ specutils.MaybeRunAsRoot()
+ os.Exit(m.Run())
+}
diff --git a/runsc/container/hook.go b/runsc/container/hook.go
new file mode 100644
index 000000000..901607aee
--- /dev/null
+++ b/runsc/container/hook.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "time"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// This file implements hooks as defined in the OCI spec:
+// https://github.com/opencontainers/runtime-spec/blob/master/config.md#toc22
+//
+// "hooks":{
+// "prestart":[{
+// "path":"/usr/bin/dockerd",
+// "args":[
+// "libnetwork-setkey", "arg2",
+// ]
+// }]
+// },
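+//
+// Per the spec, the hook binary receives the container state as JSON on its
+// stdin; executeHook below marshals specs.State and pipes it to Stdin. An
+// illustrative state payload (fields vary by spec version):
+//
+// {"ociVersion":"1.0.0","id":"<cid>","status":"created","pid":123,
+// "bundle":"/bundle"}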
+
+// executeHooksBestEffort executes hooks and logs a warning when one fails.
+// All hooks are run, regardless of failures.
+func executeHooksBestEffort(hooks []specs.Hook, s specs.State) {
+ for _, h := range hooks {
+ if err := executeHook(h, s); err != nil {
+ log.Warningf("Failure to execute hook %+v, err: %v", h, err)
+ }
+ }
+}
+
+// executeHooks executes hooks until the first one fails or they all execute.
+func executeHooks(hooks []specs.Hook, s specs.State) error {
+ for _, h := range hooks {
+ if err := executeHook(h, s); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func executeHook(h specs.Hook, s specs.State) error {
+ log.Debugf("Executing hook %+v, state: %+v", h, s)
+
+ if strings.TrimSpace(h.Path) == "" {
+ return fmt.Errorf("empty path for hook")
+ }
+ if !filepath.IsAbs(h.Path) {
+ return fmt.Errorf("path for hook is not absolute: %q", h.Path)
+ }
+
+ b, err := json.Marshal(s)
+ if err != nil {
+ return err
+ }
+ var stdout, stderr bytes.Buffer
+ cmd := exec.Cmd{
+ Path: h.Path,
+ Args: h.Args,
+ Env: h.Env,
+ Stdin: bytes.NewReader(b),
+ Stdout: &stdout,
+ Stderr: &stderr,
+ }
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ c := make(chan error, 1)
+ go func() {
+ c <- cmd.Wait()
+ }()
+
+ var timer <-chan time.Time
+ if h.Timeout != nil {
+ timer = time.After(time.Duration(*h.Timeout) * time.Second)
+ }
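+ // If no timeout is set, timer stays a nil channel, which never fires,
+ // so the select below simply waits for the command to finish.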
+ select {
+ case err := <-c:
+ if err != nil {
+ return fmt.Errorf("failure executing hook %q, err: %v\nstdout: %s\nstderr: %s", h.Path, err, stdout.String(), stderr.String())
+ }
+ case <-timer:
+ cmd.Process.Kill()
+ cmd.Wait()
+ return fmt.Errorf("timeout executing hook %q\nstdout: %s\nstderr: %s", h.Path, stdout.String(), stderr.String())
+ }
+
+ log.Debugf("Execute hook %q success!", h.Path)
+ return nil
+}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
new file mode 100644
index 000000000..e189648f4
--- /dev/null
+++ b/runsc/container/multi_container_test.go
@@ -0,0 +1,1774 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "fmt"
+ "io/ioutil"
+ "math"
+ "os"
+ "path"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
+ var specs []*specs.Spec
+ var ids []string
+ rootID := testutil.RandomContainerID()
+
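+ // The first spec is annotated as the sandbox (root) container; all other
+ // specs are annotated as sub-containers that join rootID's sandbox.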
+ for i, cmd := range cmds {
+ spec := testutil.NewSpecWithArgs(cmd...)
+ if i == 0 {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox,
+ }
+ ids = append(ids, rootID)
+ } else {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
+ specutils.ContainerdSandboxIDAnnotation: rootID,
+ }
+ ids = append(ids, testutil.RandomContainerID())
+ }
+ specs = append(specs, spec)
+ }
+ return specs, ids
+}
+
+func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) {
+ if len(conf.RootDir) == 0 {
+ panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.")
+ }
+
+ cu := cleanup.Cleanup{}
+ defer cu.Clean()
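+ // cu accumulates cleanups for everything created below. On success,
+ // cu.Release() hands ownership to the caller; otherwise the deferred
+ // Clean() tears down whatever was already set up.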
+
+ var containers []*Container
+ for i, spec := range specs {
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error setting up container: %v", err)
+ }
+ cu.Add(cleanup)
+
+ args := Args{
+ ID: ids[i],
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating container: %v", err)
+ }
+ cu.Add(func() { cont.Destroy() })
+ containers = append(containers, cont)
+
+ if err := cont.Start(conf); err != nil {
+ return nil, nil, fmt.Errorf("error starting container: %v", err)
+ }
+ }
+
+ return containers, cu.Release(), nil
+}
+
+type execDesc struct {
+ c *Container
+ cmd []string
+ want int
+ name string
+}
+
+func execMany(t *testing.T, execs []execDesc) {
+ for _, exec := range execs {
+ t.Run(exec.name, func(t *testing.T) {
+ args := &control.ExecArgs{Argv: exec.cmd}
+ if ws, err := exec.c.executeSync(args); err != nil {
+ t.Errorf("error executing %+v: %v", args, err)
+ } else if ws.ExitStatus() != exec.want {
+ t.Errorf("%q: exec %q got exit status: %d, want: %d", exec.name, exec.cmd, ws.ExitStatus(), exec.want)
+ }
+ })
+ }
+}
+
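+// createSharedMount marks a mount as shared across all containers in the pod
+// by annotating every spec. Assuming boot.MountPrefix expands to something
+// like "dev.gvisor.spec.mount.", a mount named "test-mount" produces
+// annotations such as "dev.gvisor.spec.mount.test-mount.share" = "pod".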
+func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) {
+ for _, spec := range pod {
+ spec.Annotations[boot.MountPrefix+name+".source"] = mount.Source
+ spec.Annotations[boot.MountPrefix+name+".type"] = mount.Type
+ spec.Annotations[boot.MountPrefix+name+".share"] = "pod"
+ if len(mount.Options) > 0 {
+ spec.Annotations[boot.MountPrefix+name+".options"] = strings.Join(mount.Options, ",")
+ }
+ }
+}
+
+// TestMultiContainerSanity checks that it is possible to run 2 dead-simple
+// containers in the same sandbox.
+func TestMultiContainerSanity(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Set up the containers.
+ sleep := []string{"sleep", "100"}
+ specs, ids := createSpecs(sleep, sleep)
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check via ps that multiple processes are running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ })
+ }
+}
+
+// TestMultiPIDNS checks that it is possible to run 2 dead-simple
+// containers in the same sandbox with different PID namespaces.
+func TestMultiPIDNS(t *testing.T) {
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Set up the containers.
+ sleep := []string{"sleep", "100"}
+ testSpecs, ids := createSpecs(sleep, sleep)
+ testSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ },
+ },
+ }
+
+ containers, cleanup, err := startContainers(conf, testSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check via ps that multiple processes are running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ })
+ }
+}
+
+// TestMultiPIDNSPath checks that a PID namespace can be specified by path.
+func TestMultiPIDNSPath(t *testing.T) {
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Set up the containers.
+ sleep := []string{"sleep", "100"}
+ testSpecs, ids := createSpecs(sleep, sleep, sleep)
+ testSpecs[0].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ Path: "/proc/1/ns/pid",
+ },
+ },
+ }
+ testSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ Path: "/proc/1/ns/pid",
+ },
+ },
+ }
+ testSpecs[2].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ Path: "/proc/2/ns/pid",
+ },
+ },
+ }
+
+ containers, cleanup, err := startContainers(conf, testSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check via ps that multiple processes are running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ if err := waitForProcessList(containers[2], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ })
+ }
+}
+
+func TestMultiContainerWait(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ // The first container should run the entire duration of the test.
+ cmd1 := []string{"sleep", "100"}
+ // We'll wait on the second container, which is much shorter lived.
+ cmd2 := []string{"sleep", "1"}
+ specs, ids := createSpecs(cmd1, cmd2)
+
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check via ps that the second container's process is running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Wait on the short lived container from multiple goroutines.
+ wg := sync.WaitGroup{}
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func(c *Container) {
+ defer wg.Done()
+ if ws, err := c.Wait(); err != nil {
+ t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es)
+ }
+ if _, err := c.Wait(); err != nil {
+ t.Errorf("wait for stopped container %s shouldn't fail: %v", c.Spec.Process.Args, err)
+ }
+ }(containers[1])
+ }
+
+ // Also wait via PID.
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func(c *Container) {
+ defer wg.Done()
+ const pid = 2
+ if ws, err := c.WaitPID(pid); err != nil {
+ t.Errorf("failed to wait for PID %d: %v", pid, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Errorf("PID %d exited with non-zero status %d", pid, es)
+ }
+ if _, err := c.WaitPID(pid); err == nil {
+ t.Errorf("wait for stopped PID %d should fail", pid)
+ }
+ }(containers[1])
+ }
+
+ wg.Wait()
+
+ // After Wait returns, ensure that the root container is running and
+ // the child has finished.
+ expectedPL = []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err)
+ }
+}
+
+// TestExecWait ensures that we can wait on containers and on individual
+// processes in the sandbox that have already exited.
+func TestExecWait(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ // The first container should run the entire duration of the test.
+ cmd1 := []string{"sleep", "100"}
+ // We'll wait on the second container, which is much shorter lived.
+ cmd2 := []string{"sleep", "1"}
+ specs, ids := createSpecs(cmd1, cmd2)
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check via ps that the process is running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Fatalf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Wait for the second container to finish.
+ if err := waitForProcessCount(containers[1], 0); err != nil {
+ t.Fatalf("failed to wait for second container to stop: %v", err)
+ }
+
+ // Get the second container exit status.
+ if ws, err := containers[1].Wait(); err != nil {
+ t.Fatalf("failed to wait for process %s: %v", containers[1].Spec.Process.Args, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Fatalf("process %s exited with non-zero status %d", containers[1].Spec.Process.Args, es)
+ }
+ if _, err := containers[1].Wait(); err != nil {
+ t.Fatalf("wait for stopped container %s shouldn't fail: %v", containers[1].Spec.Process.Args, err)
+ }
+
+ // Execute another process in the first container.
+ args := &control.ExecArgs{
+ Filename: "/bin/sleep",
+ Argv: []string{"/bin/sleep", "1"},
+ WorkingDirectory: "/",
+ KUID: 0,
+ }
+ pid, err := containers[0].Execute(args)
+ if err != nil {
+ t.Fatalf("error executing: %v", err)
+ }
+
+ // Wait for the exec'd process to exit.
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Fatalf("failed to wait for second container to stop: %v", err)
+ }
+
+ // Get the exit status from the exec'd process.
+ if ws, err := containers[0].WaitPID(pid); err != nil {
+ t.Fatalf("failed to wait for process %+v with pid %d: %v", args, pid, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Fatalf("process %+v exited with non-zero status %d", args, es)
+ }
+ if _, err := containers[0].WaitPID(pid); err == nil {
+ t.Fatalf("wait for stopped process %+v should fail", args)
+ }
+}
+
+// TestMultiContainerMount tests that bind mounts can be used with multiple
+// containers.
+func TestMultiContainerMount(t *testing.T) {
+ cmd1 := []string{"sleep", "100"}
+
+ // 'src != dst' ensures that 'dst' doesn't exist on the host and thus must
+ // be properly mapped inside the container for the command to work.
+ src, err := ioutil.TempDir(testutil.TmpDir(), "container")
+ if err != nil {
+ t.Fatal("ioutil.TempDir failed:", err)
+ }
+ dst := src + ".dst"
+ cmd2 := []string{"touch", filepath.Join(dst, "file")}
+
+ sps, ids := createSpecs(cmd1, cmd2)
+ sps[1].Mounts = append(sps[1].Mounts, specs.Mount{
+ Source: src,
+ Destination: dst,
+ Type: "bind",
+ })
+
+ // Set up the containers.
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ containers, cleanup, err := startContainers(conf, sps, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ ws, err := containers[1].Wait()
+ if err != nil {
+ t.Error("error waiting on container:", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Error("container failed, waitStatus:", ws)
+ }
+}
+
+// TestMultiContainerSignal checks that it is possible to signal individual
+// containers without killing the entire sandbox.
+func TestMultiContainerSignal(t *testing.T) {
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Set up the containers.
+ sleep := []string{"sleep", "100"}
+ specs, ids := createSpecs(sleep, sleep)
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check via ps that container 1 process is running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Kill process 2.
+ if err := containers[1].SignalContainer(syscall.SIGKILL, false); err != nil {
+ t.Errorf("failed to kill process 2: %v", err)
+ }
+
+ // Make sure process 1 is still running.
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // goferPid is reset when container is destroyed.
+ goferPid := containers[1].GoferPid
+
+ // Destroy container and ensure container's gofer process has exited.
+ if err := containers[1].Destroy(); err != nil {
+ t.Errorf("failed to destroy container: %v", err)
+ }
+ _, _, err = specutils.RetryEintr(func() (uintptr, uintptr, error) {
+ cpid, err := syscall.Wait4(goferPid, nil, 0, nil)
+ return uintptr(cpid), 0, err
+ })
+ if err != syscall.ECHILD {
+ t.Errorf("error waiting for gofer to exit: %v", err)
+ }
+ // Make sure process 1 is still running.
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Now that process 2 is gone, ensure we get an error trying to
+ // signal it again.
+ if err := containers[1].SignalContainer(syscall.SIGKILL, false); err == nil {
+ t.Errorf("container %q shouldn't exist, but we were able to signal it", containers[1].ID)
+ }
+
+ // Kill process 1.
+ if err := containers[0].SignalContainer(syscall.SIGKILL, false); err != nil {
+ t.Errorf("failed to kill process 1: %v", err)
+ }
+
+ // Ensure that container's gofer and sandbox process are no more.
+ err = blockUntilWaitable(containers[0].GoferPid)
+ if err != nil && err != syscall.ECHILD {
+ t.Errorf("error waiting for gofer to exit: %v", err)
+ }
+
+ err = blockUntilWaitable(containers[0].Sandbox.Pid)
+ if err != nil && err != syscall.ECHILD {
+ t.Errorf("error waiting for sandbox to exit: %v", err)
+ }
+
+ // The sentry should be gone, so signaling should yield an error.
+ if err := containers[0].SignalContainer(syscall.SIGKILL, false); err == nil {
+ t.Errorf("sandbox %q shouldn't exist, but we were able to signal it", containers[0].Sandbox.ID)
+ }
+
+ if err := containers[0].Destroy(); err != nil {
+ t.Errorf("failed to destroy container: %v", err)
+ }
+ })
+ }
+}
+
+// TestMultiContainerDestroy checks that containers are properly cleaned up
+// when they are destroyed.
+func TestMultiContainerDestroy(t *testing.T) {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // First container will remain intact while the second container is killed.
+ podSpecs, ids := createSpecs(
+ []string{"sleep", "100"},
+ []string{app, "fork-bomb"})
+
+ // Run the fork bomb in a PID namespace to prevent processes from being
+ // re-parented to PID=1 in the root container.
+ podSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{{Type: "pid"}},
+ }
+ containers, cleanup, err := startContainers(conf, podSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Exec more processes to ensure signal all works for exec'd processes too.
+ args := &control.ExecArgs{
+ Filename: app,
+ Argv: []string{app, "fork-bomb"},
+ }
+ if _, err := containers[1].Execute(args); err != nil {
+ t.Fatalf("error exec'ing: %v", err)
+ }
+
+ // Let it brew...
+ time.Sleep(500 * time.Millisecond)
+
+ if err := containers[1].Destroy(); err != nil {
+ t.Fatalf("error destroying container: %v", err)
+ }
+
+ // Check that destroy killed all processes belonging to the container and
+ // waited for them to exit before returning.
+ pss, err := containers[0].Sandbox.Processes("")
+ if err != nil {
+ t.Fatalf("error getting process data from sandbox: %v", err)
+ }
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if !procListsEqual(pss, expectedPL) {
+ t.Errorf("container got process list: %s, want: %s: error: %v",
+ procListToString(pss), procListToString(expectedPL), err)
+ }
+
+ // Check that cont.Destroy is safe to call multiple times.
+ if err := containers[1].Destroy(); err != nil {
+ t.Errorf("error destroying container: %v", err)
+ }
+ })
+ }
+}
+
+func TestMultiContainerProcesses(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ // Note: use curly braces to keep the 'sh' process around. Otherwise, the
+ // shell will just execve into 'sleep' and both containers will look the
+ // same.
+ specs, ids := createSpecs(
+ []string{"sleep", "100"},
+ []string{"sh", "-c", "{ sleep 100; }"})
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Check root's container process list doesn't include other containers.
+ expectedPL0 := []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL0); err != nil {
+ t.Errorf("failed to wait for process to start: %v", err)
+ }
+
+ // Same for the other container.
+ expectedPL1 := []*control.Process{
+ newProcessBuilder().PID(2).Cmd("sh").Process(),
+ newProcessBuilder().PID(3).PPID(2).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL1); err != nil {
+ t.Errorf("failed to wait for process to start: %v", err)
+ }
+
+ // Now exec into the second container and verify it shows up in the container.
+ args := &control.ExecArgs{
+ Filename: "/bin/sleep",
+ Argv: []string{"/bin/sleep", "100"},
+ }
+ if _, err := containers[1].Execute(args); err != nil {
+ t.Fatalf("error exec'ing: %v", err)
+ }
+ expectedPL1 = append(expectedPL1, newProcessBuilder().PID(4).Cmd("sleep").Process())
+ if err := waitForProcessList(containers[1], expectedPL1); err != nil {
+ t.Errorf("failed to wait for process to start: %v", err)
+ }
+ // Root container should remain unchanged.
+ if err := waitForProcessList(containers[0], expectedPL0); err != nil {
+ t.Errorf("failed to wait for process to start: %v", err)
+ }
+}
+
+// TestMultiContainerKillAll checks that all processes that belong to a
+// container are killed when SIGKILL is sent to *all* processes in that
+// container.
+func TestMultiContainerKillAll(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ for _, tc := range []struct {
+ killContainer bool
+ }{
+ {killContainer: true},
+ {killContainer: false},
+ } {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ // First container will remain intact while the second container is killed.
+ specs, ids := createSpecs(
+ []string{app, "task-tree", "--depth=2", "--width=2"},
+ []string{app, "task-tree", "--depth=4", "--width=2"})
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Wait until all processes are created.
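+ // A task tree of depth d and width 2 has 2^(d+1)-1 processes: 7 for
+ // the root container (depth=2) and 31 for the child (depth=4).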
+ rootProcCount := int(math.Pow(2, 3) - 1)
+ if err := waitForProcessCount(containers[0], rootProcCount); err != nil {
+ t.Fatalf("error waitting for processes: %v", err)
+ }
+ procCount := int(math.Pow(2, 5) - 1)
+ if err := waitForProcessCount(containers[1], procCount); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ // Exec more processes to ensure signal works for exec'd processes too.
+ args := &control.ExecArgs{
+ Filename: app,
+ Argv: []string{app, "task-tree", "--depth=2", "--width=2"},
+ }
+ if _, err := containers[1].Execute(args); err != nil {
+ t.Fatalf("error exec'ing: %v", err)
+ }
+ // Wait for these new processes to start.
+ procCount += int(math.Pow(2, 3) - 1)
+ if err := waitForProcessCount(containers[1], procCount); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ if tc.killContainer {
+ // First kill the init process to make the container be stopped with
+ // processes still running inside.
+ containers[1].SignalContainer(syscall.SIGKILL, false)
+ op := func() error {
+ c, err := Load(conf.RootDir, ids[1])
+ if err != nil {
+ return err
+ }
+ if c.Status != Stopped {
+ return fmt.Errorf("container is not stopped")
+ }
+ return nil
+ }
+ if err := testutil.Poll(op, 5*time.Second); err != nil {
+ t.Fatalf("container did not stop %q: %v", containers[1].ID, err)
+ }
+ }
+
+ c, err := Load(conf.RootDir, ids[1])
+ if err != nil {
+ t.Fatalf("failed to load child container %q: %v", c.ID, err)
+ }
+ // Kill'Em All
+ if err := c.SignalContainer(syscall.SIGKILL, true); err != nil {
+ t.Fatalf("failed to send SIGKILL to container %q: %v", c.ID, err)
+ }
+
+ // Check that all processes are gone.
+ if err := waitForProcessCount(containers[1], 0); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+ // Check that root container was not affected.
+ if err := waitForProcessCount(containers[0], rootProcCount); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+ }
+}
+
+func TestMultiContainerDestroyNotStarted(t *testing.T) {
+ specs, ids := createSpecs(
+ []string{"/bin/sleep", "100"},
+ []string{"/bin/sleep", "100"})
+
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(specs[0], conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ rootArgs := Args{
+ ID: ids[0],
+ Spec: specs[0],
+ BundleDir: bundleDir,
+ }
+ root, err := New(conf, rootArgs)
+ if err != nil {
+ t.Fatalf("error creating root container: %v", err)
+ }
+ defer root.Destroy()
+ if err := root.Start(conf); err != nil {
+ t.Fatalf("error starting root container: %v", err)
+ }
+
+ // Create and destroy sub-container.
+ bundleDir, cleanupSub, err := testutil.SetupBundleDir(specs[1])
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanupSub()
+
+ args := Args{
+ ID: ids[1],
+ Spec: specs[1],
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+
+ // Check that container can be destroyed.
+ if err := cont.Destroy(); err != nil {
+ t.Fatalf("deleting non-started container failed: %v", err)
+ }
+}
+
+// TestMultiContainerDestroyStarting attempts to force a race between start
+// and destroy.
+func TestMultiContainerDestroyStarting(t *testing.T) {
+ cmds := make([][]string, 10)
+ for i := range cmds {
+ cmds[i] = []string{"/bin/sleep", "100"}
+ }
+ specs, ids := createSpecs(cmds...)
+
+ conf := testutil.TestConfig(t)
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(specs[0], conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ rootArgs := Args{
+ ID: ids[0],
+ Spec: specs[0],
+ BundleDir: bundleDir,
+ }
+ root, err := New(conf, rootArgs)
+ if err != nil {
+ t.Fatalf("error creating root container: %v", err)
+ }
+ defer root.Destroy()
+ if err := root.Start(conf); err != nil {
+ t.Fatalf("error starting root container: %v", err)
+ }
+
+ wg := sync.WaitGroup{}
+ for i := range cmds {
+ if i == 0 {
+ continue // skip root container
+ }
+
+ bundleDir, cleanup, err := testutil.SetupBundleDir(specs[i])
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ rootArgs := Args{
+ ID: ids[i],
+ Spec: specs[i],
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, rootArgs)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+
+ // Container is not thread safe, so load another instance to run
+ // concurrently.
+ startCont, err := Load(rootDir, ids[i])
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ startCont.Start(conf) // ignore failures, start can fail if destroy runs first.
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := cont.Destroy(); err != nil {
+ t.Errorf("deleting non-started container failed: %v", err)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// TestMultiContainerDifferentFilesystems tests that different containers have
+// different root filesystems.
+func TestMultiContainerDifferentFilesystems(t *testing.T) {
+ filename := "/foo"
+ // Root container will create file and then sleep.
+ cmdRoot := []string{"sh", "-c", fmt.Sprintf("touch %q && sleep 100", filename)}
+
+ // Child containers will assert that the file does not exist, and will
+ // then create it.
+ script := fmt.Sprintf("if [ -f %q ]; then exit 1; else touch %q; fi", filename, filename)
+ cmd := []string{"sh", "-c", script}
+
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ // Make sure overlay is enabled, and none of the root filesystems are
+ // read-only, otherwise we won't be able to create the file.
+ conf.Overlay = true
+ specs, ids := createSpecs(cmdRoot, cmd, cmd)
+ for _, s := range specs {
+ s.Root.Readonly = false
+ }
+
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Both child containers should exit successfully.
+ for i, c := range containers {
+ if i == 0 {
+ // Don't wait on the root.
+ continue
+ }
+ if ws, err := c.Wait(); err != nil {
+ t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es)
+ }
+ }
+}
+
+// TestMultiContainerContainerDestroyStress tests that I/O operations continue
+// to work after containers have been stopped and their gofers killed.
+func TestMultiContainerContainerDestroyStress(t *testing.T) {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ // Set up the containers. The root container just reaps children, while
+ // the others perform some I/O. Children are executed in 3 batches of 10.
+ // Within a batch, containers starting and being destroyed overlap. In
+ // between batches, all containers stop before another batch starts.
+ cmds := [][]string{{app, "reaper"}}
+ const batchSize = 10
+ for i := 0; i < 3*batchSize; i++ {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "gofer-stop-test")
+ if err != nil {
+ t.Fatal("ioutil.TempDir failed:", err)
+ }
+ defer os.RemoveAll(dir)
+
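+ // Each child copies a handful of files from /bin into its own temp dir
+ // ('head' defaults to 10 lines), exercising gofer-backed I/O.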
+ cmd := "find /bin -type f | head | xargs -I SRC cp SRC " + dir
+ cmds = append(cmds, []string{"sh", "-c", cmd})
+ }
+ allSpecs, allIDs := createSpecs(cmds...)
+
+ // Split up the specs and IDs.
+ rootSpec := allSpecs[0]
+ rootID := allIDs[0]
+ childrenSpecs := allSpecs[1:]
+ childrenIDs := allIDs[1:]
+
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(rootSpec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Start root container.
+ rootArgs := Args{
+ ID: rootID,
+ Spec: rootSpec,
+ BundleDir: bundleDir,
+ }
+ root, err := New(conf, rootArgs)
+ if err != nil {
+ t.Fatalf("error creating root container: %v", err)
+ }
+ if err := root.Start(conf); err != nil {
+ t.Fatalf("error starting root container: %v", err)
+ }
+ defer root.Destroy()
+
+ // Run batches. Each batch starts containers in parallel, then waits for
+ // and destroys them before starting another batch.
+ for i := 0; i < len(childrenSpecs); i += batchSize {
+ t.Logf("Starting batch from %d to %d", i, i+batchSize)
+ specs := childrenSpecs[i : i+batchSize]
+ ids := childrenIDs[i : i+batchSize]
+
+ var children []*Container
+ for j, spec := range specs {
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: ids[j],
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ child, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ children = append(children, child)
+
+ if err := child.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Give a small gap between containers.
+ time.Sleep(50 * time.Millisecond)
+ }
+ for _, child := range children {
+ ws, err := child.Wait()
+ if err != nil {
+ t.Fatalf("waiting for container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("container failed, waitStatus: %x (%d)", ws, ws.ExitStatus())
+ }
+ if err := child.Destroy(); err != nil {
+ t.Fatalf("error destroying container: %v", err)
+ }
+ }
+ }
+}
+
+// Test that pod shared mounts are properly mounted in 2 containers and that
+// changes from one container are reflected in the other.
+func TestMultiContainerSharedMount(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Set up the containers.
+ sleep := []string{"sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: nil,
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+
+ createSharedMount(mnt0, "test-mount", podSpec...)
+
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ file0 := path.Join(mnt0.Destination, "abc")
+ file1 := path.Join(mnt1.Destination, "abc")
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ name: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ name: "directory is mounted in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/bin/touch", file0},
+ name: "create file in container0",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-f", file0},
+ name: "file appears in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-f", file1},
+ name: "file appears in container1",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/rm", file1},
+ name: "remove file from container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "!", "-f", file0},
+ name: "file removed from container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "!", "-f", file1},
+ name: "file removed from container1",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/mkdir", file1},
+ name: "create directory in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", file0},
+ name: "dir appears in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", file1},
+ name: "dir appears in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/bin/rmdir", file0},
+ name: "remove directory from container0",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "!", "-d", file0},
+ name: "dir removed from container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "!", "-d", file1},
+ name: "dir removed from container1",
+ },
+ }
+ execMany(t, execs)
+ })
+ }
+}
+
+// Test that pod mounts are mounted read-only when requested.
+func TestMultiContainerSharedMountReadonly(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: []string{"ro"},
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+
+ createSharedMount(mnt0, "test-mount", podSpec...)
+
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ file0 := path.Join(mnt0.Destination, "abc")
+ file1 := path.Join(mnt1.Destination, "abc")
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ name: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ name: "directory is mounted in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/bin/touch", file0},
+ want: 1,
+ name: "fails to write to container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/touch", file1},
+ want: 1,
+ name: "fails to write to container1",
+ },
+ }
+ execMany(t, execs)
+ })
+ }
+}
+
+// Test that shared pod mounts continue to work after container is restarted.
+func TestMultiContainerSharedMountRestart(t *testing.T) {
+ // TODO(gvisor.dev/issue/1487): This test fails with VFS2, so VFS2 configs
+ // are excluded here.
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: nil,
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+
+ createSharedMount(mnt0, "test-mount", podSpec...)
+
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ file0 := path.Join(mnt0.Destination, "abc")
+ file1 := path.Join(mnt1.Destination, "abc")
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/bin/touch", file0},
+ name: "create file in container0",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-f", file0},
+ name: "file appears in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-f", file1},
+ name: "file appears in container1",
+ },
+ }
+ execMany(t, execs)
+
+ containers[1].Destroy()
+
+ bundleDir, cleanup, err := testutil.SetupBundleDir(podSpec[1])
+ if err != nil {
+ t.Fatalf("error restarting container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: ids[1],
+ Spec: podSpec[1],
+ BundleDir: bundleDir,
+ }
+ containers[1], err = New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ if err := containers[1].Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ execs = []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-f", file0},
+ name: "file is still in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-f", file1},
+ name: "file is still in container1",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/rm", file1},
+ name: "file removed from container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "!", "-f", file0},
+ name: "file removed from container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "!", "-f", file1},
+ name: "file removed from container1",
+ },
+ }
+ execMany(t, execs)
+ })
+ }
+}
+
+// Test that unsupported pod mount options are ignored when matching master and
+// slave mounts.
+func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"/bin/sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: []string{"rw", "rbind", "relatime"},
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ mnt1.Options = []string{"rw", "nosuid"}
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+
+ createSharedMount(mnt0, "test-mount", podSpec...)
+
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ name: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ name: "directory is mounted in container1",
+ },
+ }
+ execMany(t, execs)
+ })
+ }
+}
+
+// Test that one container can send an FD to another container, even though
+// they have distinct MountNamespaces.
+func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ // We set up two containers with one shared mount that is used for a
+ // shared socket. The first container will send an FD over the socket
+ // to the second container. The FD corresponds to a file in the first
+ // container's mount namespace that is not part of the second
+ // container's mount namespace. However, the second container should still
+ // be able to read the FD.
+
+ // Create a shared mount where we will put the socket.
+ sharedMnt := specs.Mount{
+ Destination: "/mydir/test",
+ Type: "tmpfs",
+ // Shared mounts need a Source, even for tmpfs. It is only used
+ // to match up different shared mounts inside the pod.
+ Source: "/some/dir",
+ }
+ socketPath := filepath.Join(sharedMnt.Destination, "socket")
+
+ // Create a writeable tmpfs mount where the FD sender app will create
+ // files to send. This will only be mounted in the FD sender.
+ writeableMnt := specs.Mount{
+ Destination: "/tmp",
+ Type: "tmpfs",
+ }
+
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ // Create the specs.
+ specs, ids := createSpecs(
+ []string{"sleep", "1000"},
+ []string{app, "fd_sender", "--socket", socketPath},
+ []string{app, "fd_receiver", "--socket", socketPath},
+ )
+ createSharedMount(sharedMnt, "shared-mount", specs...)
+ specs[1].Mounts = append(specs[1].Mounts, sharedMnt, writeableMnt)
+ specs[2].Mounts = append(specs[2].Mounts, sharedMnt)
+
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Both containers should exit successfully.
+ for _, c := range containers[1:] {
+ if ws, err := c.Wait(); err != nil {
+ t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es)
+ }
+ }
+}
+
+// Test that the container is destroyed when its gofer is killed.
+func TestMultiContainerGoferKilled(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ sleep := []string{"sleep", "100"}
+ specs, ids := createSpecs(sleep, sleep, sleep)
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Ensure the container is running.
+ c := containers[2]
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(3).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Kill container's gofer.
+ if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
+ t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err)
+ }
+
+ // Wait until container stops.
+ if err := waitForProcessList(c, nil); err != nil {
+ t.Errorf("Container %q was not stopped after gofer death: %v", c.ID, err)
+ }
+
+ // Check that container isn't running anymore.
+ args := &control.ExecArgs{Argv: []string{"/bin/true"}}
+ if _, err := c.executeSync(args); err == nil {
+ t.Fatalf("Container %q was not stopped after gofer death", c.ID)
+ }
+
+ // Check that other containers are unaffected.
+ for i, c := range containers {
+ if i == 2 {
+ continue // container[2] has been killed.
+ }
+ pl := []*control.Process{
+ newProcessBuilder().PID(kernel.ThreadID(i + 1)).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(c, pl); err != nil {
+ t.Errorf("Container %q was affected by another container: %v", c.ID, err)
+ }
+ args := &control.ExecArgs{Argv: []string{"/bin/true"}}
+ if _, err := c.executeSync(args); err != nil {
+ t.Fatalf("Container %q was affected by another container: %v", c.ID, err)
+ }
+ }
+
+ // Kill root container's gofer to bring entire sandbox down.
+ c = containers[0]
+ if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
+ t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err)
+ }
+
+ // Wait until sandbox stops. waitForProcessList will loop until sandbox exits
+ // and RPC errors out.
+ impossiblePL := []*control.Process{
+ newProcessBuilder().Cmd("non-existent-process").Process(),
+ }
+ if err := waitForProcessList(c, impossiblePL); err == nil {
+ t.Fatalf("Sandbox was not killed after gofer death")
+ }
+
+ // Check that entire sandbox isn't running anymore.
+ for _, c := range containers {
+ args := &control.ExecArgs{Argv: []string{"/bin/true"}}
+ if _, err := c.executeSync(args); err == nil {
+ t.Fatalf("Container %q was not stopped after gofer death", c.ID)
+ }
+ }
+}
+
+func TestMultiContainerLoadSandbox(t *testing.T) {
+ sleep := []string{"sleep", "100"}
+ specs, ids := createSpecs(sleep, sleep, sleep)
+
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ // Create containers for the sandbox.
+ wants, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Then create unrelated containers.
+ for i := 0; i < 3; i++ {
+ specs, ids = createSpecs(sleep, sleep, sleep)
+ _, cleanup, err = startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+ }
+
+ // Create an unrelated directory under root.
+ dir := filepath.Join(conf.RootDir, "not-a-container")
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ t.Fatalf("os.MkdirAll(%q)=%v", dir, err)
+ }
+
+ // Create a valid but empty container directory.
+ randomCID := testutil.RandomContainerID()
+ dir = filepath.Join(conf.RootDir, randomCID)
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ t.Fatalf("os.MkdirAll(%q)=%v", dir, err)
+ }
+
+ // Load the sandbox and check that the correct containers were returned.
+ id := wants[0].Sandbox.ID
+ gots, err := loadSandbox(conf.RootDir, id)
+ if err != nil {
+ t.Fatalf("loadSandbox()=%v", err)
+ }
+ wantIDs := make(map[string]struct{})
+ for _, want := range wants {
+ wantIDs[want.ID] = struct{}{}
+ }
+ for _, got := range gots {
+ if got.Sandbox.ID != id {
+ t.Errorf("wrong sandbox ID, got: %v, want: %v", got.Sandbox.ID, id)
+ }
+ if _, ok := wantIDs[got.ID]; !ok {
+ t.Errorf("wrong container ID, got: %v, wants: %v", got.ID, wantIDs)
+ }
+ delete(wantIDs, got.ID)
+ }
+ if len(wantIDs) != 0 {
+ t.Errorf("containers not found: %v", wantIDs)
+ }
+}
+
+// TestMultiContainerRunNonRoot checks that a child container can be configured
+// when running as a non-privileged user.
+func TestMultiContainerRunNonRoot(t *testing.T) {
+ cmdRoot := []string{"/bin/sleep", "100"}
+ cmdSub := []string{"/bin/true"}
+ podSpecs, ids := createSpecs(cmdRoot, cmdSub)
+
+ // The user running inside the container can't list '$TMP/blocked' and would
+ // fail to mount it.
+ blocked, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ if err := os.Chmod(blocked, 0700); err != nil {
+ t.Fatalf("os.Chmod(%q) failed: %v", blocked, err)
+ }
+ dir := path.Join(blocked, "test")
+ if err := os.Mkdir(dir, 0755); err != nil {
+ t.Fatalf("os.Mkdir(%q) failed: %v", dir, err)
+ }
+
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+
+ // Set a random user/group with no access to "blocked" dir.
+ podSpecs[1].Process.User.UID = 343
+ podSpecs[1].Process.User.GID = 2401
+ podSpecs[1].Process.Capabilities = nil
+
+ podSpecs[1].Mounts = append(podSpecs[1].Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
+
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ pod, cleanup, err := startContainers(conf, podSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Once all containers are started, wait for the child container to exit.
+ // A clean exit means the volume was mounted properly.
+ ws, err := pod[1].Wait()
+ if err != nil {
+ t.Fatalf("running child container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("child container failed, waitStatus: %v", ws)
+ }
+}
+
+// TestMultiContainerHomeEnvDir tests that the HOME environment variable is set
+// for root containers, sub-containers, and execed processes.
+func TestMultiContainerHomeEnvDir(t *testing.T) {
+ // TODO(gvisor.dev/issue/1487): VFSv2 configs failing.
+ // NOTE: Don't use overlay since we need changes to persist to the temp dir
+ // outside the sandbox.
+ for testName, conf := range configs(t, noOverlay...) {
+ t.Run(testName, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Create temp files we can write the value of $HOME to.
+ homeDirs := map[string]*os.File{}
+ for _, name := range []string{"root", "sub", "exec"} {
+ homeFile, err := ioutil.TempFile(testutil.TmpDir(), name)
+ if err != nil {
+ t.Fatalf("creating temp file: %v", err)
+ }
+ homeDirs[name] = homeFile
+ }
+
+ // Sleep in the root container to ensure that it doesn't terminate before
+ // the sub containers can be created.
+ rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())}
+ subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())}
+ execCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name())}
+
+ // Setup the containers, a root container and sub container.
+ specConfig, ids := createSpecs(rootCmd, subCmd)
+ containers, cleanup, err := startContainers(conf, specConfig, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Exec into the root container synchronously.
+ args := &control.ExecArgs{Argv: execCmd}
+ if _, err := containers[0].executeSync(args); err != nil {
+ t.Errorf("error executing %+v: %v", args, err)
+ }
+
+ // Wait for the subcontainer to finish.
+ _, err = containers[1].Wait()
+ if err != nil {
+ t.Errorf("wait on child container: %v", err)
+ }
+
+ // Wait for the root container to run.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sh").Process(),
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Check the written files.
+ for name, tmpFile := range homeDirs {
+ dirBytes, err := ioutil.ReadAll(tmpFile)
+ if err != nil {
+ t.Fatalf("reading %s temp file: %v", name, err)
+ }
+ got := string(dirBytes)
+
+ want := "/"
+ if got != want {
+ t.Errorf("%s $HOME incorrect: got: %q, want: %q", name, got, want)
+ }
+ }
+ })
+ }
+}
diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go
new file mode 100644
index 000000000..bac177a88
--- /dev/null
+++ b/runsc/container/shared_volume_test.go
@@ -0,0 +1,273 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/boot"
+)
+
+// TestSharedVolume checks that modifications to a volume mount are propagated
+// into and out of the sandbox.
+func TestSharedVolume(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ conf.FileAccess = boot.FileAccessShared
+
+ // Main process just sleeps. We will use "exec" to probe the state of
+ // the filesystem.
+ spec := testutil.NewSpecWithArgs("sleep", "1000")
+
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "shared-volume-test")
+ if err != nil {
+ t.Fatalf("TempDir failed: %v", err)
+ }
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // File that will be used to check consistency inside/outside sandbox.
+ filename := filepath.Join(dir, "file")
+
+ // File does not exist yet. Reading from the sandbox should fail.
+ argsTestFile := &control.ExecArgs{
+ Filename: "/usr/bin/test",
+ Argv: []string{"test", "-f", filename},
+ }
+ if ws, err := c.executeSync(argsTestFile); err != nil {
+ t.Fatalf("unexpected error testing file %q: %v", filename, err)
+ } else if ws.ExitStatus() == 0 {
+ t.Errorf("test %q exited with code %v, wanted not zero", filename, ws.ExitStatus())
+ }
+
+ // Create the file from outside of the sandbox.
+ if err := ioutil.WriteFile(filename, []byte("foobar"), 0777); err != nil {
+ t.Fatalf("error writing to file %q: %v", filename, err)
+ }
+
+ // Now we should be able to test the file from within the sandbox.
+ if ws, err := c.executeSync(argsTestFile); err != nil {
+ t.Fatalf("unexpected error testing file %q: %v", filename, err)
+ } else if ws.ExitStatus() != 0 {
+ t.Errorf("test %q exited with code %v, wanted zero", filename, ws.ExitStatus())
+ }
+
+ // Rename the file from outside of the sandbox.
+ newFilename := filepath.Join(dir, "newfile")
+ if err := os.Rename(filename, newFilename); err != nil {
+ t.Fatalf("os.Rename(%q, %q) failed: %v", filename, newFilename, err)
+ }
+
+ // File should no longer exist at the old path within the sandbox.
+ if ws, err := c.executeSync(argsTestFile); err != nil {
+ t.Fatalf("unexpected error testing file %q: %v", filename, err)
+ } else if ws.ExitStatus() == 0 {
+ t.Errorf("test %q exited with code %v, wanted not zero", filename, ws.ExitStatus())
+ }
+
+ // We should be able to test the new filename from within the sandbox.
+ argsTestNewFile := &control.ExecArgs{
+ Filename: "/usr/bin/test",
+ Argv: []string{"test", "-f", newFilename},
+ }
+ if ws, err := c.executeSync(argsTestNewFile); err != nil {
+ t.Fatalf("unexpected error testing file %q: %v", newFilename, err)
+ } else if ws.ExitStatus() != 0 {
+ t.Errorf("test %q exited with code %v, wanted zero", newFilename, ws.ExitStatus())
+ }
+
+ // Delete the renamed file from outside of the sandbox.
+ if err := os.Remove(newFilename); err != nil {
+ t.Fatalf("error removing file %q: %v", filename, err)
+ }
+
+ // Renamed file should no longer exist at the old path within the sandbox.
+ if ws, err := c.executeSync(argsTestNewFile); err != nil {
+ t.Fatalf("unexpected error testing file %q: %v", newFilename, err)
+ } else if ws.ExitStatus() == 0 {
+ t.Errorf("test %q exited with code %v, wanted not zero", newFilename, ws.ExitStatus())
+ }
+
+ // Now create the file from WITHIN the sandbox.
+ argsTouch := &control.ExecArgs{
+ Filename: "/usr/bin/touch",
+ Argv: []string{"touch", filename},
+ KUID: auth.KUID(os.Getuid()),
+ KGID: auth.KGID(os.Getgid()),
+ }
+ if ws, err := c.executeSync(argsTouch); err != nil {
+ t.Fatalf("unexpected error touching file %q: %v", filename, err)
+ } else if ws.ExitStatus() != 0 {
+ t.Errorf("touch %q exited with code %v, wanted zero", filename, ws.ExitStatus())
+ }
+
+ // File should exist outside the sandbox.
+ if _, err := os.Stat(filename); err != nil {
+ t.Errorf("stat %q got error %v, wanted nil", filename, err)
+ }
+
+ // Delete the file from within the sandbox.
+ argsRemove := &control.ExecArgs{
+ Filename: "/bin/rm",
+ Argv: []string{"rm", filename},
+ }
+ if ws, err := c.executeSync(argsRemove); err != nil {
+ t.Fatalf("unexpected error removing file %q: %v", filename, err)
+ } else if ws.ExitStatus() != 0 {
+ t.Errorf("remove %q exited with code %v, wanted zero", filename, ws.ExitStatus())
+ }
+
+ // File should not exist outside the sandbox.
+ if _, err := os.Stat(filename); !os.IsNotExist(err) {
+ t.Errorf("stat %q got error %v, wanted ErrNotExist", filename, err)
+ }
+}
+
+func checkFile(c *Container, filename string, want []byte) error {
+ cpy := filename + ".copy"
+ argsCp := &control.ExecArgs{
+ Filename: "/bin/cp",
+ Argv: []string{"cp", "-f", filename, cpy},
+ }
+ if _, err := c.executeSync(argsCp); err != nil {
+ return fmt.Errorf("unexpected error copying file %q to %q: %v", filename, cpy, err)
+ }
+ got, err := ioutil.ReadFile(cpy)
+ if err != nil {
+ return fmt.Errorf("reading file %q: %v", filename, err)
+ }
+ if !bytes.Equal(got, want) {
+ return fmt.Errorf("file content inside the sandbox is wrong, got: %q, want: %q", got, want)
+ }
+ return nil
+}
+
+// TestSharedVolumeFile tests that changes to file content outside the sandbox
+// are reflected inside.
+func TestSharedVolumeFile(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ conf.FileAccess = boot.FileAccessShared
+
+ // Main process just sleeps. We will use "exec" to probe the state of
+ // the filesystem.
+ spec := testutil.NewSpecWithArgs("sleep", "1000")
+
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "shared-volume-test")
+ if err != nil {
+ t.Fatalf("TempDir failed: %v", err)
+ }
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // File that will be used to check consistency inside/outside sandbox.
+ filename := filepath.Join(dir, "file")
+
+ // Write file from outside the container and check that the same content is
+ // read inside.
+ want := []byte("host-")
+ if err := ioutil.WriteFile(filename, []byte(want), 0666); err != nil {
+ t.Fatalf("Error writing to %q: %v", filename, err)
+ }
+ if err := checkFile(c, filename, want); err != nil {
+ t.Fatal(err.Error())
+ }
+
+ // Append to file inside the container and check that content is not lost.
+ argsAppend := &control.ExecArgs{
+ Filename: "/bin/bash",
+ Argv: []string{"bash", "-c", "echo -n sandbox- >> " + filename},
+ }
+ if _, err := c.executeSync(argsAppend); err != nil {
+ t.Fatalf("unexpected error appending file %q: %v", filename, err)
+ }
+ want = []byte("host-sandbox-")
+ if err := checkFile(c, filename, want); err != nil {
+ t.Fatal(err.Error())
+ }
+
+ // Write again from outside the container and check that the same content is
+ // read inside.
+ f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY, 0)
+ if err != nil {
+ t.Fatalf("Error opening file %q: %v", filename, err)
+ }
+ defer f.Close()
+ if _, err := f.Write([]byte("host")); err != nil {
+ t.Fatalf("Error writing to file %q: %v", filename, err)
+ }
+ want = []byte("host-sandbox-host")
+ if err := checkFile(c, filename, want); err != nil {
+ t.Fatal(err.Error())
+ }
+
+ // Shrink file outside and check that the same content is read inside.
+ if err := f.Truncate(5); err != nil {
+ t.Fatalf("Error truncating file %q: %v", filename, err)
+ }
+ want = want[:5]
+ if err := checkFile(c, filename, want); err != nil {
+ t.Fatal(err.Error())
+ }
+}
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
new file mode 100644
index 000000000..17a251530
--- /dev/null
+++ b/runsc/container/state_file.go
@@ -0,0 +1,185 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ "github.com/gofrs/flock"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+const stateFileExtension = ".state"
+
+// StateFile handles loading and saving container state safely across multiple
+// processes. It uses a lock file to synchronize operations.
+//
+// The lock file is located at: "${s.RootDir}/${s.ID}.lock".
+// The state file is located at: "${s.RootDir}/${s.ID}.state".
+type StateFile struct {
+ // RootDir is the directory containing the container metadata file.
+ RootDir string `json:"rootDir"`
+
+ // ID is the container ID.
+ ID string `json:"id"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
+ once sync.Once
+ flock *flock.Flock
+}
+
+// List returns all container ids in the given root directory.
+func List(rootDir string) ([]string, error) {
+ log.Debugf("List containers %q", rootDir)
+ list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension))
+ if err != nil {
+ return nil, err
+ }
+ var out []string
+ for _, path := range list {
+ // Filter out files that do not belong to a container.
+ fileName := filepath.Base(path)
+ if len(fileName) < len(stateFileExtension) {
+ panic(fmt.Sprintf("invalid file match %q", path))
+ }
+ // Remove the extension.
+ cid := fileName[:len(fileName)-len(stateFileExtension)]
+ if validateID(cid) == nil {
+ out = append(out, cid)
+ }
+ }
+ return out, nil
+}
+
+// lock globally locks all locking operations for the container.
+func (s *StateFile) lock() error {
+ s.once.Do(func() {
+ s.flock = flock.NewFlock(s.lockPath())
+ })
+
+ if err := s.flock.Lock(); err != nil {
+ return fmt.Errorf("acquiring lock on %q: %v", s.flock, err)
+ }
+ return nil
+}
+
+// lockForNew acquires the lock and checks that the state file doesn't exist.
+// This ensures that concurrent creations don't race to create containers with
+// the same ID.
+func (s *StateFile) lockForNew() error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+
+ // Checks if the container already exists by looking for the metadata file.
+ if _, err := os.Stat(s.statePath()); err == nil {
+ s.unlock()
+ return fmt.Errorf("container already exists")
+ } else if !os.IsNotExist(err) {
+ s.unlock()
+ return fmt.Errorf("looking for existing container: %v", err)
+ }
+ return nil
+}
+
+// unlock globally unlocks all locking operations for the container.
+func (s *StateFile) unlock() error {
+ if !s.flock.Locked() {
+ panic("unlock called without lock held")
+ }
+
+ if err := s.flock.Unlock(); err != nil {
+ log.Warningf("Error releasing lock on %q: %v", s.flock, err)
+ return fmt.Errorf("releasing lock on %q: %v", s.flock, err)
+ }
+ return nil
+}
+
+// saveLocked saves 'v' to the state file.
+//
+// Preconditions: lock() must have been called before.
+func (s *StateFile) saveLocked(v interface{}) error {
+ if !s.flock.Locked() {
+ panic("saveLocked called without lock held")
+ }
+
+ meta, err := json.Marshal(v)
+ if err != nil {
+ return err
+ }
+ if err := ioutil.WriteFile(s.statePath(), meta, 0640); err != nil {
+ return fmt.Errorf("writing json file: %v", err)
+ }
+ return nil
+}
+
+func (s *StateFile) load(v interface{}) error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+ defer s.unlock()
+
+ metaBytes, err := ioutil.ReadFile(s.statePath())
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(metaBytes, &v)
+}
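+
+// exampleSave is an illustrative sketch added for exposition (not part of the
+// original file). It shows how the helpers above compose: take the flock,
+// write the JSON blob, then release the lock; load() performs the read side
+// internally.
+func (s *StateFile) exampleSave(v interface{}) error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+ defer s.unlock()
+ return s.saveLocked(v)
+}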
+
+func (s *StateFile) close() error {
+ if s.flock == nil {
+ return nil
+ }
+ if s.flock.Locked() {
+ panic("Closing locked file")
+ }
+ return s.flock.Close()
+}
+
+func buildStatePath(rootDir, id string) string {
+ return filepath.Join(rootDir, id+stateFileExtension)
+}
+
+// statePath is the full path to the state file.
+func (s *StateFile) statePath() string {
+ return buildStatePath(s.RootDir, s.ID)
+}
+
+// lockPath is the full path to the lock file.
+func (s *StateFile) lockPath() string {
+ return filepath.Join(s.RootDir, s.ID+".lock")
+}
+
+// destroy deletes all state created by the stateFile. It may be called with the
+// lock file held. In that case, the lock file must still be unlocked and
+// properly closed after destroy returns.
+func (s *StateFile) destroy() error {
+ if err := os.Remove(s.statePath()); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ if err := os.Remove(s.lockPath()); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ return nil
+}
diff --git a/runsc/container/status.go b/runsc/container/status.go
new file mode 100644
index 000000000..91d9112f1
--- /dev/null
+++ b/runsc/container/status.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+// Status enumerates container statuses. The statuses and their semantics are
+// part of the runtime CLI spec.
+type Status int
+
+const (
+ // Created indicates "the runtime has finished the create operation and
+ // the container process has neither exited nor executed the
+ // user-specified program".
+ Created Status = iota
+
+ // Creating indicates "the container is being created".
+ Creating
+
+ // Paused indicates that the process within the container has been
+ // suspended.
+ Paused
+
+ // Running indicates "the container process has executed the
+ // user-specified program but has not exited".
+ Running
+
+ // Stopped indicates "the container process has exited".
+ Stopped
+)
+
+// String converts a Status to a string. These strings are part of the runtime
+// CLI spec and should not be changed.
+func (s Status) String() string {
+ switch s {
+ case Created:
+ return "created"
+ case Creating:
+ return "creating"
+ case Paused:
+ return "paused"
+ case Running:
+ return "running"
+ case Stopped:
+ return "stopped"
+ default:
+ return "unknown"
+ }
+}
diff --git a/runsc/debian/description b/runsc/debian/description
new file mode 100644
index 000000000..9e8e08805
--- /dev/null
+++ b/runsc/debian/description
@@ -0,0 +1 @@
+gVisor container sandbox runtime
diff --git a/runsc/debian/postinst.sh b/runsc/debian/postinst.sh
new file mode 100755
index 000000000..dc7aeee87
--- /dev/null
+++ b/runsc/debian/postinst.sh
@@ -0,0 +1,24 @@
+#!/bin/sh -e
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ "$1" != configure ]; then
+ exit 0
+fi
+
+if [ -f /etc/docker/daemon.json ]; then
+ runsc install
+ systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2
+fi
diff --git a/runsc/flag/BUILD b/runsc/flag/BUILD
new file mode 100644
index 000000000..5cb7604a8
--- /dev/null
+++ b/runsc/flag/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "flag",
+ srcs = ["flag.go"],
+ visibility = ["//:sandbox"],
+)
diff --git a/runsc/flag/flag.go b/runsc/flag/flag.go
new file mode 100644
index 000000000..0ca4829d7
--- /dev/null
+++ b/runsc/flag/flag.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flag
+
+import (
+ "flag"
+)
+
+type FlagSet = flag.FlagSet
+
+var (
+ NewFlagSet = flag.NewFlagSet
+ String = flag.String
+ Bool = flag.Bool
+ Int = flag.Int
+ Uint = flag.Uint
+ CommandLine = flag.CommandLine
+ Parse = flag.Parse
+)
+
+const ContinueOnError = flag.ContinueOnError
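+
+// exampleParse is a hedged usage sketch added for exposition (not part of the
+// original file): because the names above alias the standard library's flag
+// package, callers use the usual FlagSet API unchanged.
+func exampleParse(args []string) (bool, error) {
+ fs := NewFlagSet("example", ContinueOnError)
+ debug := fs.Bool("debug", false, "enable debug logging")
+ err := fs.Parse(args)
+ return *debug, err
+}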
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
new file mode 100644
index 000000000..1036b0630
--- /dev/null
+++ b/runsc/fsgofer/BUILD
@@ -0,0 +1,35 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fsgofer",
+ srcs = [
+ "fsgofer.go",
+ "fsgofer_amd64_unsafe.go",
+ "fsgofer_arm64_unsafe.go",
+ "fsgofer_unsafe.go",
+ ],
+ visibility = ["//runsc:__subpackages__"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "fsgofer_test",
+ size = "small",
+ srcs = ["fsgofer_test.go"],
+ library = ":fsgofer",
+ deps = [
+ "//pkg/log",
+ "//pkg/p9",
+ ],
+)
diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD
new file mode 100644
index 000000000..82b48ef32
--- /dev/null
+++ b/runsc/fsgofer/filter/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "filter",
+ srcs = [
+ "config.go",
+ "config_amd64.go",
+ "config_arm64.go",
+ "extra_filters.go",
+ "extra_filters_msan.go",
+ "extra_filters_race.go",
+ "filter.go",
+ ],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/flipcall",
+ "//pkg/log",
+ "//pkg/seccomp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
new file mode 100644
index 000000000..88814b83c
--- /dev/null
+++ b/runsc/fsgofer/filter/config.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package filter
+
+import (
+ "os"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// allowedSyscalls is the set of syscalls executed by the gofer.
+var allowedSyscalls = seccomp.SyscallRules{
+ syscall.SYS_ACCEPT: {},
+ syscall.SYS_CLOCK_GETTIME: {},
+ syscall.SYS_CLONE: []seccomp.Rule{
+ {
+ seccomp.AllowValue(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ },
+ },
+ syscall.SYS_CLOSE: {},
+ syscall.SYS_DUP: {},
+ syscall.SYS_EPOLL_CTL: {},
+ syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_EVENTFD2: []seccomp.Rule{
+ {
+ seccomp.AllowValue(0),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_EXIT: {},
+ syscall.SYS_EXIT_GROUP: {},
+ syscall.SYS_FALLOCATE: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_FCHMOD: {},
+ syscall.SYS_FCHOWNAT: {},
+ syscall.SYS_FCNTL: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.F_GETFL),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.F_SETFL),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.F_GETFD),
+ },
+ // Used by flipcall.PacketWindowAllocator.Init().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(unix.F_ADD_SEALS),
+ },
+ },
+ syscall.SYS_FSTAT: {},
+ syscall.SYS_FSTATFS: {},
+ syscall.SYS_FSYNC: {},
+ syscall.SYS_FTRUNCATE: {},
+ syscall.SYS_FUTEX: {
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ // Non-private futex used for flipcall.
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ },
+ syscall.SYS_GETDENTS64: {},
+ syscall.SYS_GETPID: {},
+ unix.SYS_GETRANDOM: {},
+ syscall.SYS_GETTID: {},
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_LINKAT: {},
+ syscall.SYS_LSEEK: {},
+ syscall.SYS_MADVISE: {},
+ unix.SYS_MEMFD_CREATE: {}, // Used by flipcall.PacketWindowAllocator.Init().
+ syscall.SYS_MKDIRAT: {},
+ syscall.SYS_MKNODAT: {},
+ // Used by the Go runtime as a temporary workaround for a Linux
+ // 5.2-5.4 bug.
+ //
+ // See src/runtime/os_linux_x86.go.
+ //
+ // TODO(b/148688965): Remove once this is gone from Go.
+ syscall.SYS_MLOCK: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4096),
+ },
+ },
+ syscall.SYS_MMAP: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_SHARED),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
+ },
+ },
+ syscall.SYS_MPROTECT: {},
+ syscall.SYS_MUNMAP: {},
+ syscall.SYS_NANOSLEEP: {},
+ syscall.SYS_OPENAT: {},
+ syscall.SYS_PPOLL: {},
+ syscall.SYS_PREAD64: {},
+ syscall.SYS_PWRITE64: {},
+ syscall.SYS_READ: {},
+ syscall.SYS_READLINKAT: {},
+ syscall.SYS_RECVMSG: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
+ },
+ },
+ syscall.SYS_RENAMEAT: {},
+ syscall.SYS_RESTART_SYSCALL: {},
+ syscall.SYS_RT_SIGPROCMASK: {},
+ syscall.SYS_RT_SIGRETURN: {},
+ syscall.SYS_SCHED_YIELD: {},
+ syscall.SYS_SENDMSG: []seccomp.Rule{
+ // Used by fdchannel.Endpoint.SendFD().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ // Used by unet.SocketWriter.WriteVec().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
+ },
+ },
+ syscall.SYS_SHUTDOWN: []seccomp.Rule{
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
+ },
+ syscall.SYS_SIGALTSTACK: {},
+ // Used by fdchannel.NewConnectedSockets().
+ syscall.SYS_SOCKETPAIR: {
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_SYMLINKAT: {},
+ syscall.SYS_TGKILL: []seccomp.Rule{
+ {
+ seccomp.AllowValue(uint64(os.Getpid())),
+ },
+ },
+ syscall.SYS_UNLINKAT: {},
+ syscall.SYS_UTIMENSAT: {},
+ syscall.SYS_WRITE: {},
+}
+
+var udsSyscalls = seccomp.SyscallRules{
+ syscall.SYS_SOCKET: []seccomp.Rule{
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_STREAM),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_DGRAM),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_SEQPACKET),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_CONNECT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ },
+ },
+}
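+
+// exampleRules is an illustrative sketch added for exposition (not part of the
+// original file). It shows the rule shape used above: an empty rule set ({})
+// allows a syscall with any arguments, while each seccomp.Rule lists one
+// matcher per argument. It is never installed.
+var exampleRules = seccomp.SyscallRules{
+ syscall.SYS_DUP: {}, // allowed unconditionally
+ syscall.SYS_SHUTDOWN: []seccomp.Rule{
+ // Allowed only when the second argument equals SHUT_RDWR.
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
+ },
+}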
diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go
new file mode 100644
index 000000000..a4b28cb8b
--- /dev/null
+++ b/runsc/fsgofer/filter/config_amd64.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+func init() {
+ allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_GET_FS)},
+ {seccomp.AllowValue(linux.ARCH_SET_FS)},
+ }
+
+ allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{}
+}
diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go
new file mode 100644
index 000000000..d2697deb7
--- /dev/null
+++ b/runsc/fsgofer/filter/config_arm64.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+func init() {
+ allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{}
+}
diff --git a/runsc/fsgofer/filter/extra_filters.go b/runsc/fsgofer/filter/extra_filters.go
new file mode 100644
index 000000000..e28d4b8d6
--- /dev/null
+++ b/runsc/fsgofer/filter/extra_filters.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !msan,!race
+
+package filter
+
+import (
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// instrumentationFilters returns additional filters for syscalls used by
+// Go instrumentation tools, e.g. -race, -msan.
+// It returns nil when neither is enabled.
+func instrumentationFilters() seccomp.SyscallRules {
+ return nil
+}
diff --git a/runsc/fsgofer/filter/extra_filters_msan.go b/runsc/fsgofer/filter/extra_filters_msan.go
new file mode 100644
index 000000000..8c6179c8f
--- /dev/null
+++ b/runsc/fsgofer/filter/extra_filters_msan.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build msan
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// instrumentationFilters returns additional filters for syscalls used by MSAN.
+func instrumentationFilters() seccomp.SyscallRules {
+ log.Warningf("*** SECCOMP WARNING: MSAN is enabled: syscall filters less restrictive!")
+ return seccomp.SyscallRules{
+ syscall.SYS_SCHED_GETAFFINITY: {},
+ syscall.SYS_SET_ROBUST_LIST: {},
+ }
+}
diff --git a/runsc/fsgofer/filter/extra_filters_race.go b/runsc/fsgofer/filter/extra_filters_race.go
new file mode 100644
index 000000000..885c92f7a
--- /dev/null
+++ b/runsc/fsgofer/filter/extra_filters_race.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// instrumentationFilters returns additional filters for syscalls used by TSAN.
+func instrumentationFilters() seccomp.SyscallRules {
+ log.Warningf("*** SECCOMP WARNING: TSAN is enabled: syscall filters less restrictive!")
+ return seccomp.SyscallRules{
+ syscall.SYS_BRK: {},
+ syscall.SYS_CLONE: {},
+ syscall.SYS_FUTEX: {},
+ syscall.SYS_MADVISE: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_MUNLOCK: {},
+ syscall.SYS_NANOSLEEP: {},
+ syscall.SYS_OPEN: {},
+ syscall.SYS_SET_ROBUST_LIST: {},
+ // Used within glibc's malloc.
+ syscall.SYS_TIME: {},
+ }
+}
diff --git a/runsc/fsgofer/filter/filter.go b/runsc/fsgofer/filter/filter.go
new file mode 100644
index 000000000..289886720
--- /dev/null
+++ b/runsc/fsgofer/filter/filter.go
@@ -0,0 +1,38 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package filter defines all syscalls the gofer is allowed to make, and
+// installs seccomp filters to prevent prohibited syscalls in case it's
+// compromised.
+package filter
+
+import (
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// Install installs seccomp filters.
+func Install() error {
+ // Merge in the additional filters used by -race and -msan. The set is
+ // empty when neither is enabled.
+ allowedSyscalls.Merge(instrumentationFilters())
+
+ return seccomp.Install(allowedSyscalls)
+}
+
+// InstallUDSFilters extends the allowed syscalls to include those necessary for
+// connecting to a host UDS.
+func InstallUDSFilters() {
+ // Add additional filters required for connecting to the host's sockets.
+ allowedSyscalls.Merge(udsSyscalls)
+}
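+
+// exampleSetup is a hedged usage sketch added for exposition (not part of the
+// original file): the UDS rules must be merged via InstallUDSFilters before
+// Install applies the final set, since installed seccomp filters cannot be
+// relaxed afterwards.
+func exampleSetup(hostUDS bool) error {
+ if hostUDS {
+ InstallUDSFilters()
+ }
+ return Install()
+}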
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
new file mode 100644
index 000000000..b7521bda7
--- /dev/null
+++ b/runsc/fsgofer/fsgofer.go
@@ -0,0 +1,1181 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsgofer implements p9.File giving access to local files using
+// a simple mapping from a path prefix that is added to the path requested
+// by the sandbox. Ex:
+//
+// prefix: "/docker/imgs/alpine"
+// app path: /bin/ls => /docker/imgs/alpine/bin/ls
+package fsgofer
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "os"
+ "path"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+const (
+ // invalidMode is set to a value that doesn't match any other valid
+ // modes to ensure an unopened/closed file fails all mode checks.
+ invalidMode = p9.OpenFlags(math.MaxUint32)
+
+ openFlags = syscall.O_NOFOLLOW | syscall.O_CLOEXEC
+)
+
+type fileType int
+
+const (
+ regular fileType = iota
+ directory
+ symlink
+ socket
+ unknown
+)
+
+// String implements fmt.Stringer.
+func (f fileType) String() string {
+ switch f {
+ case regular:
+ return "regular"
+ case directory:
+ return "directory"
+ case symlink:
+ return "symlink"
+ case socket:
+ return "socket"
+ }
+ return "unknown"
+}
+
+// ControlSocketAddr generates an abstract unix socket name for the given id.
+func ControlSocketAddr(id string) string {
+ return fmt.Sprintf("\x00runsc-gofer.%s", id)
+}
+
+// Config sets configuration options for each attach point.
+type Config struct {
+ // ROMount is set to true if this is a readonly mount.
+ ROMount bool
+
+ // PanicOnWrite panics on attempts to write to RO mounts.
+ PanicOnWrite bool
+
+ // HostUDS signals whether the gofer can mount a host's UDS.
+ HostUDS bool
+}
+
+type attachPoint struct {
+ prefix string
+ conf Config
+
+ // attachedMu protects attached.
+ attachedMu sync.Mutex
+ attached bool
+
+ // deviceMu protects devices and nextDevice.
+ deviceMu sync.Mutex
+
+ // nextDevice is the next device id that will be allocated.
+ nextDevice uint8
+
+ // devices is a map from actual host devices to "small" integers that
+ // can be combined with host inode to form a unique virtual inode id.
+ devices map[uint64]uint8
+}
+
+// NewAttachPoint creates a new attacher that gives local file
+// access to all files under 'prefix'. 'prefix' must be an absolute path.
+func NewAttachPoint(prefix string, c Config) (p9.Attacher, error) {
+ // Sanity check the prefix.
+ if !filepath.IsAbs(prefix) {
+ return nil, fmt.Errorf("attach point prefix must be absolute %q", prefix)
+ }
+ return &attachPoint{
+ prefix: prefix,
+ conf: c,
+ devices: make(map[uint64]uint8),
+ }, nil
+}
+
+// Attach implements p9.Attacher.
+func (a *attachPoint) Attach() (p9.File, error) {
+ a.attachedMu.Lock()
+ defer a.attachedMu.Unlock()
+
+ if a.attached {
+ return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
+ }
+
+ f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) {
+ return fd.Open(a.prefix, openFlags|mode, 0)
+ })
+ if err != nil {
+ return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err)
+ }
+
+ stat, err := fstat(f.FD())
+ if err != nil {
+ return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err)
+ }
+
+ lf, err := newLocalFile(a, f, a.prefix, stat)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err)
+ }
+ a.attached = true
+ return lf, nil
+}
+
+// makeQID returns a unique QID for the given stat buffer.
+func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID {
+ a.deviceMu.Lock()
+ defer a.deviceMu.Unlock()
+
+ // First map the host device id to a unique 8-bit integer.
+ dev, ok := a.devices[stat.Dev]
+ if !ok {
+ a.devices[stat.Dev] = a.nextDevice
+ dev = a.nextDevice
+ a.nextDevice++
+ if a.nextDevice < dev {
+ panic(fmt.Sprintf("device id overflow! map: %+v", a.devices))
+ }
+ }
+
+ // Construct a "virtual" inode id with the uint8 device number in the
+ // first 8 bits, and the rest of the bits from the host inode id.
+ maskedIno := stat.Ino & 0x00ffffffffffffff
+ if maskedIno != stat.Ino {
+ log.Warningf("first 8 bytes of host inode id %x will be truncated to construct virtual inode id", stat.Ino)
+ }
+ ino := uint64(dev)<<56 | maskedIno
+ return p9.QID{
+ Type: p9.FileMode(stat.Mode).QIDType(),
+ Path: ino,
+ }
+}
+
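+// An illustrative walkthrough of the inode construction above (values are
+// assumed, not from the original change): if stat.Dev was assigned dev 0x01
+// and the host inode is 0x00000000deadbeef, then:
+//
+//	ino = uint64(0x01)<<56 | 0x00000000deadbeef = 0x01000000deadbeef
+//
+// Host inodes that already use their top byte lose those bits, as warned
+// above.
+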
+// localFile implements p9.File wrapping a local file. The underlying file
+// is opened during Walk() and stored in 'file' to be used with other
+// operations. The file is opened read-only, unless it's a symlink or read
+// access is denied, in which case it's opened with O_PATH. 'file' is dup'ed
+// when Walk(nil) is called to clone the file, which reduces the number of
+// walks the host file system must perform when files are reused.
+//
+// The file may be reopened if the requested mode in Open() is not a subset of
+// current mode. Consequently, 'file' could have a mode wider than requested and
+// must be verified before read/write operations. Before the file is opened and
+// after it's closed, 'mode' is set to an invalid value to prevent an unopened
+// file from being used.
+//
+// The file is not opened as read-write initially for better performance with
+// the 'overlay2' storage driver. overlay2 eagerly copies up the entire file
+// when it's opened in write mode, which performs badly when many files are
+// opened only for read (especially during startup).
+type localFile struct {
+ p9.DefaultWalkGetAttr
+
+ // attachPoint is the attachPoint that serves this localFile.
+ attachPoint *attachPoint
+
+ // hostPath will be safely updated by the Renamed hook.
+ hostPath string
+
+ // file is opened when localFile is created and it's never nil. It may be
+ // reopened if the Open() mode is wider than the mode the file was originally
+ // opened with.
+ file *fd.FD
+
+ // mode is the mode in which the file was opened. Set to invalidMode
+ // if localFile isn't opened.
+ mode p9.OpenFlags
+
+ // ft is the fileType for this file.
+ ft fileType
+
+ // readDirMu protects against concurrent Readdir calls.
+ readDirMu sync.Mutex
+
+ // lastDirentOffset is the last offset returned by Readdir(). If another call
+ // to Readdir is made at the same offset, the file doesn't need to be
+ // repositioned. This is an important optimization because the caller must
+ // always make one extra call to detect EOF (empty result, no error).
+ lastDirentOffset uint64
+}
+
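+// A sketch of the mode lifecycle described above (flags assumed for
+// illustration):
+//
+//	Walk():          'file' opened O_RDONLY (or O_PATH), mode = invalidMode
+//	Open(WriteOnly): 'file' reopened via /proc/self/fd, mode = WriteOnly
+//	WriteAt():       allowed, since mode is WriteOnly or ReadWrite
+//	ReadAt():        fails with EBADF, since mode lacks read access
+//	Close():         mode = invalidMode again
+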
+var procSelfFD *fd.FD
+
+// OpenProcSelfFD opens the /proc/self/fd directory, which will be used to
+// reopen file descriptors.
+func OpenProcSelfFD() error {
+ d, err := syscall.Open("/proc/self/fd", syscall.O_RDONLY|syscall.O_DIRECTORY, 0)
+ if err != nil {
+ return fmt.Errorf("error opening /proc/self/fd: %v", err)
+ }
+ procSelfFD = fd.New(d)
+ return nil
+}
+
+// reopenProcFd reopens the file referred to by 'f' with the given mode by
+// opening its entry under /proc/self/fd, without re-resolving the host path.
+func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) {
+ d, err := syscall.Openat(int(procSelfFD.FD()), strconv.Itoa(f.FD()), mode&^syscall.O_NOFOLLOW, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ return fd.New(d), nil
+}
+
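+// For illustration (FD number assumed): if f.FD() is 7 and mode is O_RDWR,
+// reopenProcFd above performs openat(procSelfFD, "7", O_RDWR), equivalent to
+// open("/proc/self/fd/7", O_RDWR). This yields a fresh FD for the same
+// underlying file without re-resolving the (possibly stale) host path.
+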
+func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, error) {
+ path := path.Join(parent.hostPath, name)
+ f, err := openAnyFile(path, func(mode int) (*fd.FD, error) {
+ return fd.OpenAt(parent.file, name, openFlags|mode, 0)
+ })
+ return f, path, err
+}
+
+// openAnyFile attempts to open the file with O_RDONLY and, if that fails,
+// falls back to O_PATH. 'path' is used for logging messages only. 'fn'
+// performs the actual file open and is customizable by the caller.
+func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) {
+ // Attempt to open the file in the following modes, in order:
+ // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs.
+ // Use non-blocking to prevent getting stuck inside open(2) for
+ // FIFOs. This option has no effect on regular files.
+ // 2. PATH: for symlinks, sockets.
+ modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH}
+
+ var err error
+ var file *fd.FD
+ for i, mode := range modes {
+ file, err = fn(mode)
+ if err == nil {
+ // openat succeeded, we're done.
+ break
+ }
+ switch e := extractErrno(err); e {
+ case syscall.ENOENT:
+ // File doesn't exist, no point in retrying.
+ return nil, e
+ }
+ // openat failed. Try again with next mode, preserving 'err' in case this
+ // was the last attempt.
+ log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|mode, path, err)
+ }
+ if err != nil {
+ // All attempts to open file have failed, return the last error.
+ log.Debugf("Failed to open file, path: %q, err: %v", path, err)
+ return nil, extractErrno(err)
+ }
+
+ return file, nil
+}
+
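+// A worked example of the fallback above: opening a symlink with attempt 0
+// (O_RDONLY|O_NONBLOCK) fails with ELOOP because openFlags includes
+// O_NOFOLLOW, so attempt 1 retries with O_PATH. That succeeds and returns an
+// FD usable with *at syscalls, though not for reading or writing.
+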
+func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) {
+ var ft fileType
+ switch stat.Mode & syscall.S_IFMT {
+ case syscall.S_IFREG:
+ ft = regular
+ case syscall.S_IFDIR:
+ ft = directory
+ case syscall.S_IFLNK:
+ ft = symlink
+ case syscall.S_IFSOCK:
+ if !permitSocket {
+ return unknown, syscall.EPERM
+ }
+ ft = socket
+ default:
+ return unknown, syscall.EPERM
+ }
+ return ft, nil
+}
+
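+// For example (illustrative): a FIFO's stat.Mode&S_IFMT is S_IFIFO, which
+// falls through to the default case above and is rejected with EPERM; a unix
+// socket is accepted only when the gofer was configured with HostUDS.
+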
+func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) {
+ ft, err := getSupportedFileType(stat, a.conf.HostUDS)
+ if err != nil {
+ return nil, err
+ }
+
+ return &localFile{
+ attachPoint: a,
+ hostPath: path,
+ file: file,
+ mode: invalidMode,
+ ft: ft,
+ }, nil
+}
+
+// newFDMaybe creates a fd.FD from a file, dup'ing the FD and setting it as
+// non-blocking. If anything fails, it returns nil: it's better to have a
+// file without a host FD than to fail the operation.
+func newFDMaybe(file *fd.FD) *fd.FD {
+ dupFD, err := syscall.Dup(file.FD())
+ // Technically, the runtime may call the finalizer on file as soon as
+ // FD() returns.
+ runtime.KeepAlive(file)
+ if err != nil {
+ return nil
+ }
+ dup := fd.New(dupFD)
+
+ // fd is blocking; non-blocking is required.
+ if err := syscall.SetNonblock(dup.FD(), true); err != nil {
+ dup.Close()
+ return nil
+ }
+ return dup
+}
+
+func fstat(fd int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ if err := syscall.Fstat(fd, &stat); err != nil {
+ return syscall.Stat_t{}, err
+ }
+ return stat, nil
+}
+
+func stat(path string) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ if err := syscall.Stat(path, &stat); err != nil {
+ return syscall.Stat_t{}, err
+ }
+ return stat, nil
+}
+
+func fchown(fd int, uid p9.UID, gid p9.GID) error {
+ return syscall.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW)
+}
+
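+// Note the ""+AT_EMPTY_PATH combination above: it makes fchownat(2) act on
+// the file referred to by 'fd' itself, like fchown(2), but it also works on
+// O_PATH descriptors, where a plain fchown(2) would fail with EBADF.
+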
+// Open implements p9.File.
+func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+ if l.isOpen() {
+ panic(fmt.Sprintf("attempting to open already opened file: %q", l.hostPath))
+ }
+
+ // Check if control file can be used or if a new open must be created.
+ var newFile *fd.FD
+ if flags == p9.ReadOnly {
+ log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath)
+ newFile = l.file
+ } else {
+ // Ideally reopen would call name_to_handle_at (with empty name) and
+ // open_by_handle_at to reopen the file without using 'hostPath'. However,
+ // name_to_handle_at and open_by_handle_at aren't supported by overlay2.
+ log.Debugf("Open reopening file, flags: %v, %q", flags, l.hostPath)
+ var err error
+ // Constrain open flags to the open mode and O_TRUNC.
+ newFile, err = reopenProcFd(l.file, openFlags|(flags.OSFlags()&(syscall.O_ACCMODE|syscall.O_TRUNC)))
+ if err != nil {
+ return nil, p9.QID{}, 0, extractErrno(err)
+ }
+ }
+
+ stat, err := fstat(newFile.FD())
+ if err != nil {
+ if newFile != l.file {
+ newFile.Close()
+ }
+ return nil, p9.QID{}, 0, extractErrno(err)
+ }
+
+ var fd *fd.FD
+ if stat.Mode&syscall.S_IFMT == syscall.S_IFREG {
+ // Donate FD for regular files only.
+ fd = newFDMaybe(newFile)
+ }
+
+ // Close old file in case a new one was created.
+ if newFile != l.file {
+ if err := l.file.Close(); err != nil {
+ log.Warningf("Error closing file %q: %v", l.hostPath, err)
+ }
+ l.file = newFile
+ }
+ l.mode = flags & p9.OpenFlagsModeMask
+ return fd, l.attachPoint.makeQID(stat), 0, nil
+}
+
+// Create implements p9.File.
+func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9.File, p9.QID, uint32, error) {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return nil, nil, p9.QID{}, 0, syscall.EBADF
+ }
+
+ // 'file' may be used for other operations (e.g. Walk), so read access is
+ // always added to flags. Note that the resulting file might have a wider
+ // mode than needed for each particular case.
+ flags := openFlags | syscall.O_CREAT | syscall.O_EXCL
+ if mode == p9.WriteOnly {
+ flags |= syscall.O_RDWR
+ } else {
+ flags |= mode.OSFlags()
+ }
+
+ child, err := fd.OpenAt(l.file, name, flags, uint32(perm.Permissions()))
+ if err != nil {
+ return nil, nil, p9.QID{}, 0, extractErrno(err)
+ }
+ cu := cleanup.Make(func() {
+ child.Close()
+ // Best effort attempt to remove the file in case of failure.
+ if err := syscall.Unlinkat(l.file.FD(), name); err != nil {
+ log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err)
+ }
+ })
+ defer cu.Clean()
+
+ if err := fchown(child.FD(), uid, gid); err != nil {
+ return nil, nil, p9.QID{}, 0, extractErrno(err)
+ }
+ stat, err := fstat(child.FD())
+ if err != nil {
+ return nil, nil, p9.QID{}, 0, extractErrno(err)
+ }
+
+ c := &localFile{
+ attachPoint: l.attachPoint,
+ hostPath: path.Join(l.hostPath, name),
+ file: child,
+ mode: mode,
+ }
+
+ cu.Release()
+ return newFDMaybe(c.file), c, l.attachPoint.makeQID(stat), 0, nil
+}
+
+// Mkdir implements p9.File.
+func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return p9.QID{}, syscall.EBADF
+ }
+
+ if err := syscall.Mkdirat(l.file.FD(), name, uint32(perm.Permissions())); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ cu := cleanup.Make(func() {
+ // Best effort attempt to remove the dir in case of failure.
+ if err := unix.Unlinkat(l.file.FD(), name, unix.AT_REMOVEDIR); err != nil {
+ log.Warningf("error unlinking dir %q after failure: %v", path.Join(l.hostPath, name), err)
+ }
+ })
+ defer cu.Clean()
+
+ // Open directory to change ownership and stat it.
+ flags := syscall.O_DIRECTORY | syscall.O_RDONLY | openFlags
+ f, err := fd.OpenAt(l.file, name, flags, 0)
+ if err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ defer f.Close()
+
+ if err := fchown(f.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ stat, err := fstat(f.FD())
+ if err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+
+ cu.Release()
+ return l.attachPoint.makeQID(stat), nil
+}
+
+// Walk implements p9.File.
+func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) {
+ // Duplicate current file if 'names' is empty.
+ if len(names) == 0 {
+ newFile, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) {
+ return reopenProcFd(l.file, openFlags|mode)
+ })
+ if err != nil {
+ return nil, nil, extractErrno(err)
+ }
+
+ stat, err := fstat(newFile.FD())
+ if err != nil {
+ newFile.Close()
+ return nil, nil, extractErrno(err)
+ }
+
+ c := &localFile{
+ attachPoint: l.attachPoint,
+ hostPath: l.hostPath,
+ file: newFile,
+ mode: invalidMode,
+ }
+ return []p9.QID{l.attachPoint.makeQID(stat)}, c, nil
+ }
+
+ var qids []p9.QID
+ last := l
+ for _, name := range names {
+ f, path, err := openAnyFileFromParent(last, name)
+ if last != l {
+ last.Close()
+ }
+ if err != nil {
+ return nil, nil, extractErrno(err)
+ }
+ stat, err := fstat(f.FD())
+ if err != nil {
+ f.Close()
+ return nil, nil, extractErrno(err)
+ }
+ c, err := newLocalFile(last.attachPoint, f, path, stat)
+ if err != nil {
+ f.Close()
+ return nil, nil, extractErrno(err)
+ }
+
+ qids = append(qids, l.attachPoint.makeQID(stat))
+ last = c
+ }
+ return qids, last, nil
+}
+
+// StatFS implements p9.File.
+func (l *localFile) StatFS() (p9.FSStat, error) {
+ var s syscall.Statfs_t
+ if err := syscall.Fstatfs(l.file.FD(), &s); err != nil {
+ return p9.FSStat{}, extractErrno(err)
+ }
+
+ // Populate with what's available.
+ return p9.FSStat{
+ Type: uint32(s.Type),
+ BlockSize: uint32(s.Bsize),
+ Blocks: s.Blocks,
+ BlocksFree: s.Bfree,
+ BlocksAvailable: s.Bavail,
+ Files: s.Files,
+ FilesFree: s.Ffree,
+ NameLength: uint32(s.Namelen),
+ }, nil
+}
+
+// FSync implements p9.File.
+func (l *localFile) FSync() error {
+ if !l.isOpen() {
+ return syscall.EBADF
+ }
+ if err := syscall.Fsync(l.file.FD()); err != nil {
+ return extractErrno(err)
+ }
+ return nil
+}
+
+// GetAttr implements p9.File.
+func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ stat, err := fstat(l.file.FD())
+ if err != nil {
+ return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err)
+ }
+
+ attr := p9.Attr{
+ Mode: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.Uid),
+ GID: p9.GID(stat.Gid),
+ NLink: uint64(stat.Nlink),
+ RDev: stat.Rdev,
+ Size: uint64(stat.Size),
+ BlockSize: uint64(stat.Blksize),
+ Blocks: uint64(stat.Blocks),
+ ATimeSeconds: uint64(stat.Atim.Sec),
+ ATimeNanoSeconds: uint64(stat.Atim.Nsec),
+ MTimeSeconds: uint64(stat.Mtim.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtim.Nsec),
+ CTimeSeconds: uint64(stat.Ctim.Sec),
+ CTimeNanoSeconds: uint64(stat.Ctim.Nsec),
+ }
+ valid := p9.AttrMask{
+ Mode: true,
+ UID: true,
+ GID: true,
+ NLink: true,
+ RDev: true,
+ Size: true,
+ Blocks: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ }
+
+ return l.attachPoint.makeQID(stat), valid, attr, nil
+}
+
+// SetAttr implements p9.File. Due to a mismatch in the file API, options
+// cannot be changed atomically and the user may see partial changes when
+// an error happens.
+func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return syscall.EBADF
+ }
+
+ allowed := p9.SetAttrMask{
+ Permissions: true,
+ UID: true,
+ GID: true,
+ Size: true,
+ ATime: true,
+ MTime: true,
+ ATimeNotSystemTime: true,
+ MTimeNotSystemTime: true,
+ }
+
+ if valid.Empty() {
+ // Nothing to do.
+ return nil
+ }
+
+ // Handle all the sanity checks up front so that the client gets a
+ // consistent result that is not attribute dependent.
+ if !valid.IsSubsetOf(allowed) {
+ log.Warningf("SetAttr() failed for %q, mask: %v", l.hostPath, valid)
+ return syscall.EPERM
+ }
+
+ // Check if it's possible to use the cached file, or if another one needs
+ // to be opened for write.
+ f := l.file
+ if l.ft == regular && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite {
+ var err error
+ f, err = reopenProcFd(l.file, openFlags|os.O_WRONLY)
+ if err != nil {
+ return extractErrno(err)
+ }
+ defer f.Close()
+ }
+
+ // The semantics are to either return an error if no changes were made,
+ // or no error if *all* changes were made. However, this can be impossible
+ // to satisfy if the filesystem rejects at least one of the changes,
+ // especially since some operations are not easy to undo atomically.
+ //
+ // This could be made better if SetAttr actually returned the changes it
+ // did make, so the client could at least know what has changed. Instead,
+ // we attempt to make all of the changes and return a generic error if
+ // any of them fails, which at least doesn't bias any change over
+ // another.
+ var err error
+ if valid.Permissions {
+ if cerr := syscall.Fchmod(f.FD(), uint32(attr.Permissions)); cerr != nil {
+ log.Debugf("SetAttr fchmod failed %q, err: %v", l.hostPath, cerr)
+ err = extractErrno(cerr)
+ }
+ }
+
+ if valid.Size {
+ if terr := syscall.Ftruncate(f.FD(), int64(attr.Size)); terr != nil {
+ log.Debugf("SetAttr ftruncate failed %q, err: %v", l.hostPath, terr)
+ err = extractErrno(terr)
+ }
+ }
+
+ if valid.ATime || valid.MTime {
+ utimes := [2]syscall.Timespec{
+ {Sec: 0, Nsec: linux.UTIME_OMIT},
+ {Sec: 0, Nsec: linux.UTIME_OMIT},
+ }
+ if valid.ATime {
+ if valid.ATimeNotSystemTime {
+ utimes[0].Sec = int64(attr.ATimeSeconds)
+ utimes[0].Nsec = int64(attr.ATimeNanoSeconds)
+ } else {
+ utimes[0].Nsec = linux.UTIME_NOW
+ }
+ }
+ if valid.MTime {
+ if valid.MTimeNotSystemTime {
+ utimes[1].Sec = int64(attr.MTimeSeconds)
+ utimes[1].Nsec = int64(attr.MTimeNanoSeconds)
+ } else {
+ utimes[1].Nsec = linux.UTIME_NOW
+ }
+ }
+
+ if l.ft == symlink {
+ // utimensat operates differently from other syscalls. To operate on a
+ // symlink it *requires* AT_SYMLINK_NOFOLLOW with a dirFD and a
+ // non-empty name.
+ parent, err := syscall.Open(path.Dir(l.hostPath), openFlags|unix.O_PATH, 0)
+ if err != nil {
+ return extractErrno(err)
+ }
+ defer syscall.Close(parent)
+
+ if terr := utimensat(parent, path.Base(l.hostPath), utimes, linux.AT_SYMLINK_NOFOLLOW); terr != nil {
+ log.Debugf("SetAttr utimens failed %q, err: %v", l.hostPath, terr)
+ err = extractErrno(terr)
+ }
+ } else {
+ // Directories and regular files can operate directly on the fd
+ // using an empty name.
+ if terr := utimensat(f.FD(), "", utimes, 0); terr != nil {
+ log.Debugf("SetAttr utimens failed %q, err: %v", l.hostPath, terr)
+ err = extractErrno(terr)
+ }
+ }
+ }
+
+ if valid.UID || valid.GID {
+ uid := -1
+ if valid.UID {
+ uid = int(attr.UID)
+ }
+ gid := -1
+ if valid.GID {
+ gid = int(attr.GID)
+ }
+ if oerr := syscall.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oerr != nil {
+ log.Debugf("SetAttr fchownat failed %q, err: %v", l.hostPath, oerr)
+ err = extractErrno(oerr)
+ }
+ }
+
+ return err
+}
+
+func (*localFile) GetXattr(string, uint64) (string, error) {
+ return "", syscall.EOPNOTSUPP
+}
+
+func (*localFile) SetXattr(string, string, uint32) error {
+ return syscall.EOPNOTSUPP
+}
+
+func (*localFile) ListXattr(uint64) (map[string]struct{}, error) {
+ return nil, syscall.EOPNOTSUPP
+}
+
+func (*localFile) RemoveXattr(string) error {
+ return syscall.EOPNOTSUPP
+}
+
+// Allocate implements p9.File.
+func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error {
+ if !l.isOpen() {
+ return syscall.EBADF
+ }
+
+ if err := syscall.Fallocate(l.file.FD(), mode.ToLinux(), int64(offset), int64(length)); err != nil {
+ return extractErrno(err)
+ }
+ return nil
+}
+
+// Rename implements p9.File; this should never be called.
+func (*localFile) Rename(p9.File, string) error {
+ panic("rename called directly")
+}
+
+// RenameAt implements p9.File.RenameAt.
+func (l *localFile) RenameAt(oldName string, directory p9.File, newName string) error {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return syscall.EBADF
+ }
+
+ newParent := directory.(*localFile)
+ if err := renameat(l.file.FD(), oldName, newParent.file.FD(), newName); err != nil {
+ return extractErrno(err)
+ }
+ return nil
+}
+
+// ReadAt implements p9.File.
+func (l *localFile) ReadAt(p []byte, offset uint64) (int, error) {
+ if l.mode != p9.ReadOnly && l.mode != p9.ReadWrite {
+ return 0, syscall.EBADF
+ }
+ if !l.isOpen() {
+ return 0, syscall.EBADF
+ }
+
+ r, err := l.file.ReadAt(p, int64(offset))
+ switch err {
+ case nil, io.EOF:
+ return r, nil
+ default:
+ return r, extractErrno(err)
+ }
+}
+
+// WriteAt implements p9.File.
+func (l *localFile) WriteAt(p []byte, offset uint64) (int, error) {
+ if l.mode != p9.WriteOnly && l.mode != p9.ReadWrite {
+ return 0, syscall.EBADF
+ }
+ if !l.isOpen() {
+ return 0, syscall.EBADF
+ }
+
+ w, err := l.file.WriteAt(p, int64(offset))
+ if err != nil {
+ return w, extractErrno(err)
+ }
+ return w, nil
+}
+
+// Symlink implements p9.File.
+func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return p9.QID{}, syscall.EBADF
+ }
+
+ if err := unix.Symlinkat(target, l.file.FD(), newName); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ cu := cleanup.Make(func() {
+ // Best effort attempt to remove the symlink in case of failure.
+ if err := syscall.Unlinkat(l.file.FD(), newName); err != nil {
+ log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, newName), err)
+ }
+ })
+ defer cu.Clean()
+
+ // Open symlink to change ownership and stat it.
+ f, err := fd.OpenAt(l.file, newName, unix.O_PATH|openFlags, 0)
+ if err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ defer f.Close()
+
+ if err := fchown(f.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ stat, err := fstat(f.FD())
+ if err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+
+ cu.Release()
+ return l.attachPoint.makeQID(stat), nil
+}
+
+// Link implements p9.File.
+func (l *localFile) Link(target p9.File, newName string) error {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return syscall.EBADF
+ }
+
+ targetFile := target.(*localFile)
+ if err := unix.Linkat(targetFile.file.FD(), "", l.file.FD(), newName, linux.AT_EMPTY_PATH); err != nil {
+ return extractErrno(err)
+ }
+ return nil
+}
+
+// Mknod implements p9.File.
+func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return p9.QID{}, syscall.EROFS
+ }
+
+ hostPath := path.Join(l.hostPath, name)
+
+ // Return EEXIST if the file already exists.
+ if _, err := stat(hostPath); err == nil {
+ return p9.QID{}, syscall.EEXIST
+ }
+
+ // From mknod(2) man page:
+ // "EPERM: [...] if the filesystem containing pathname does not support
+ // the type of node requested."
+ if mode.FileType() != p9.ModeRegular {
+ return p9.QID{}, syscall.EPERM
+ }
+
+ // Allow Mknod to create regular files.
+ if err := syscall.Mknod(hostPath, uint32(mode), 0); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+
+ stat, err := stat(hostPath)
+ if err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ return l.attachPoint.makeQID(stat), nil
+}
+
+// UnlinkAt implements p9.File.
+func (l *localFile) UnlinkAt(name string, flags uint32) error {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return syscall.EBADF
+ }
+
+ if err := unix.Unlinkat(l.file.FD(), name, int(flags)); err != nil {
+ return extractErrno(err)
+ }
+ return nil
+}
+
+// Readdir implements p9.File.
+func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) {
+ if l.mode != p9.ReadOnly && l.mode != p9.ReadWrite {
+ return nil, syscall.EBADF
+ }
+ if !l.isOpen() {
+ return nil, syscall.EBADF
+ }
+
+ // ReadDirent is a cursor over the directory, so seek back to 0 to ensure
+ // it reads all directory contents. Take a lock because this operation is
+ // stateful.
+ l.readDirMu.Lock()
+ defer l.readDirMu.Unlock()
+
+ skip := uint64(0)
+
+ // Check if the file is at the correct position already. If not, seek to the
+ // beginning and read the entire directory again.
+ if l.lastDirentOffset != offset {
+ if _, err := syscall.Seek(l.file.FD(), 0, 0); err != nil {
+ return nil, extractErrno(err)
+ }
+ skip = offset
+ }
+
+ dirents, err := l.readDirent(l.file.FD(), offset, count, skip)
+ if err == nil {
+ // On success, remember the offset that the next sequential Readdir
+ // call will use.
+ l.lastDirentOffset = offset + uint64(len(dirents))
+ } else {
+ // On failure, the state is unknown; force a seek() on the next call.
+ l.lastDirentOffset = math.MaxUint64
+ }
+ return dirents, err
+}
+
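+// The offset cache above, with assumed numbers: a client reading a 20-entry
+// directory with Readdir(0, 10) receives entries 1..10 and lastDirentOffset
+// becomes 10; the follow-up Readdir(10, 10) matches the cached offset, so the
+// host FD keeps its position and no seek or re-read is needed.
+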
+func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) ([]p9.Dirent, error) {
+ var dirents []p9.Dirent
+
+ // Limit 'count' to cap the slice size that is returned.
+ const maxCount = 100000
+ if count > maxCount {
+ count = maxCount
+ }
+
+ // Pre-allocate buffers that will be reused to get partial results.
+ direntsBuf := make([]byte, 8192)
+ names := make([]string, 0, 100)
+
+ end := offset + uint64(count)
+ for offset < end {
+ dirSize, err := syscall.ReadDirent(f, direntsBuf)
+ if err != nil {
+ return dirents, err
+ }
+ if dirSize <= 0 {
+ return dirents, nil
+ }
+
+ // Reset the slice, reusing the same backing array across iterations.
+ names = names[:0]
+ _, _, names = syscall.ParseDirent(direntsBuf[:dirSize], -1, names)
+
+ // Skip over entries that the caller is not interested in.
+ if skip > 0 {
+ if skip > uint64(len(names)) {
+ skip -= uint64(len(names))
+ names = names[:0]
+ } else {
+ names = names[skip:]
+ skip = 0
+ }
+ }
+ for _, name := range names {
+ stat, err := statAt(l.file.FD(), name)
+ if err != nil {
+ log.Warningf("Readdir is skipping file with failed stat %q, err: %v", l.hostPath, err)
+ continue
+ }
+ qid := l.attachPoint.makeQID(stat)
+ offset++
+ dirents = append(dirents, p9.Dirent{
+ QID: qid,
+ Type: qid.Type,
+ Name: name,
+ Offset: offset,
+ })
+ }
+ }
+ return dirents, nil
+}
+
+// Readlink implements p9.File.
+func (l *localFile) Readlink() (string, error) {
+ // Shamelessly stolen from os.Readlink (added an upper bound to the
+ // buffer size).
+ const limit = 1024 * 1024
+ for size := 128; size < limit; size *= 2 {
+ b := make([]byte, size)
+ n, err := unix.Readlinkat(l.file.FD(), "", b)
+ if err != nil {
+ return "", extractErrno(err)
+ }
+ if n < size {
+ return string(b[:n]), nil
+ }
+ }
+ return "", syscall.ENOMEM
+}
+
+// Flush implements p9.File.
+func (l *localFile) Flush() error {
+ return nil
+}
+
+// Connect implements p9.File.
+func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) {
+ if !l.attachPoint.conf.HostUDS {
+ return nil, syscall.ECONNREFUSED
+ }
+
+ // TODO(gvisor.dev/issue/1003): Due to different app vs replacement
+ // mappings, the app path may have fit in the sockaddr, but we can't
+ // fit l.hostPath in our sockaddr. We'd need to redirect through a
+ // shorter path in order to actually connect to this socket.
+ if len(l.hostPath) > linux.UnixPathMax {
+ return nil, syscall.ECONNREFUSED
+ }
+
+ var stype int
+ switch flags {
+ case p9.StreamSocket:
+ stype = syscall.SOCK_STREAM
+ case p9.DgramSocket:
+ stype = syscall.SOCK_DGRAM
+ case p9.SeqpacketSocket:
+ stype = syscall.SOCK_SEQPACKET
+ default:
+ return nil, syscall.ENXIO
+ }
+
+ f, err := syscall.Socket(syscall.AF_UNIX, stype, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := syscall.SetNonblock(f, true); err != nil {
+ syscall.Close(f)
+ return nil, err
+ }
+
+ sa := syscall.SockaddrUnix{Name: l.hostPath}
+ if err := syscall.Connect(f, &sa); err != nil {
+ syscall.Close(f)
+ return nil, err
+ }
+
+ return fd.New(f), nil
+}
+
+// Close implements p9.File.
+func (l *localFile) Close() error {
+ l.mode = invalidMode
+ err := l.file.Close()
+ l.file = nil
+ return err
+}
+
+func (l *localFile) isOpen() bool {
+ return l.mode != invalidMode
+}
+
+// Renamed implements p9.Renamed.
+func (l *localFile) Renamed(newDir p9.File, newName string) {
+ l.hostPath = path.Join(newDir.(*localFile).hostPath, newName)
+}
+
+// extractErrno tries to determine the errno.
+func extractErrno(err error) syscall.Errno {
+ if err == nil {
+ // This should never happen. The likely result will be that
+ // some user gets the frustrating "error: SUCCESS" message.
+ log.Warningf("extractErrno called with nil error!")
+ return 0
+ }
+
+ switch err {
+ case os.ErrNotExist:
+ return syscall.ENOENT
+ case os.ErrExist:
+ return syscall.EEXIST
+ case os.ErrPermission:
+ return syscall.EACCES
+ case os.ErrInvalid:
+ return syscall.EINVAL
+ }
+
+ // See if it's an errno or a common wrapped error.
+ switch e := err.(type) {
+ case syscall.Errno:
+ return e
+ case *os.PathError:
+ return extractErrno(e.Err)
+ case *os.LinkError:
+ return extractErrno(e.Err)
+ case *os.SyscallError:
+ return extractErrno(e.Err)
+ }
+
+ // Fall back to EIO.
+ log.Debugf("Unknown error: %v, defaulting to EIO", err)
+ return syscall.EIO
+}
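+
+// Illustrative mappings (not part of the original change): extractErrno
+// returns a raw syscall.ENOENT as-is, maps os.ErrPermission to EACCES, and
+// recursively unwraps &os.PathError{Err: syscall.ENOTDIR} to ENOTDIR;
+// anything unrecognized falls back to EIO.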
diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go
new file mode 100644
index 000000000..5d4aab597
--- /dev/null
+++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package fsgofer
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+func statAt(dirFd int, name string) (syscall.Stat_t, error) {
+ nameBytes, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return syscall.Stat_t{}, err
+ }
+ namePtr := unsafe.Pointer(nameBytes)
+
+ var stat syscall.Stat_t
+ statPtr := unsafe.Pointer(&stat)
+
+ if _, _, errno := syscall.Syscall6(
+ syscall.SYS_NEWFSTATAT,
+ uintptr(dirFd),
+ uintptr(namePtr),
+ uintptr(statPtr),
+ linux.AT_SYMLINK_NOFOLLOW,
+ 0,
+ 0); errno != 0 {
+
+ return syscall.Stat_t{}, syserr.FromHost(errno).ToError()
+ }
+ return stat, nil
+}
diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go
new file mode 100644
index 000000000..8041fd352
--- /dev/null
+++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package fsgofer
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+func statAt(dirFd int, name string) (syscall.Stat_t, error) {
+ nameBytes, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return syscall.Stat_t{}, err
+ }
+ namePtr := unsafe.Pointer(nameBytes)
+
+ var stat syscall.Stat_t
+ statPtr := unsafe.Pointer(&stat)
+
+ if _, _, errno := syscall.Syscall6(
+ syscall.SYS_FSTATAT,
+ uintptr(dirFd),
+ uintptr(namePtr),
+ uintptr(statPtr),
+ linux.AT_SYMLINK_NOFOLLOW,
+ 0,
+ 0); errno != 0 {
+
+ return syscall.Stat_t{}, syserr.FromHost(errno).ToError()
+ }
+ return stat, nil
+}
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
new file mode 100644
index 000000000..05af7e397
--- /dev/null
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -0,0 +1,692 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsgofer
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "path/filepath"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+)
+
+func init() {
+ log.SetLevel(log.Debug)
+
+ allConfs = append(allConfs, rwConfs...)
+ allConfs = append(allConfs, roConfs...)
+
+ if err := OpenProcSelfFD(); err != nil {
+ panic(err)
+ }
+}
+
+func assertPanic(t *testing.T, f func()) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("function did not panic")
+ }
+ }()
+ f()
+}
+
+func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error {
+ want := make([]byte, len(content))
+ copy(want, content)
+
+ b := []byte("test-1-2-3")
+ w, err := f.WriteAt(b, uint64(len(content)))
+ if flags == p9.WriteOnly || flags == p9.ReadWrite {
+ if err != nil {
+ return fmt.Errorf("WriteAt(): %v", err)
+ }
+ if w != len(b) {
+ return fmt.Errorf("WriteAt() was partial, got: %d, want: %d", w, len(b))
+ }
+ want = append(want, b...)
+ } else {
+ if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF {
+ return fmt.Errorf("WriteAt() should have failed, got: %d, want: EBADFD", err)
+ }
+ }
+
+ rBuf := make([]byte, len(want))
+ r, err := f.ReadAt(rBuf, 0)
+ if flags == p9.ReadOnly || flags == p9.ReadWrite {
+ if err != nil {
+ return fmt.Errorf("ReadAt(): %v", err)
+ }
+ if r != len(rBuf) {
+ return fmt.Errorf("ReadAt() was partial, got: %d, want: %d", r, len(rBuf))
+ }
+ if string(rBuf) != string(want) {
+ return fmt.Errorf("ReadAt() wrong data, got: %s, want: %s", string(rBuf), want)
+ }
+ } else {
+ if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF {
+ return fmt.Errorf("ReadAt() should have failed, got: %d, want: EBADFD", err)
+ }
+ }
+ return nil
+}
+
+var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite}
+
+var (
+ allTypes = []fileType{regular, directory, symlink}
+
+ // allConfs is set in init() above.
+ allConfs []Config
+
+ rwConfs = []Config{{ROMount: false}}
+ roConfs = []Config{{ROMount: true}}
+)
+
+type state struct {
+ root *localFile
+ file *localFile
+ conf Config
+ ft fileType
+}
+
+func (s state) String() string {
+ return fmt.Sprintf("type(%v)", s.ft)
+}
+
+func runAll(t *testing.T, test func(*testing.T, state)) {
+ runCustom(t, allTypes, allConfs, test)
+}
+
+func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) {
+ for _, c := range confs {
+ t.Logf("Config: %+v", c)
+
+ for _, ft := range types {
+ t.Logf("File type: %v", ft)
+
+ path, name, err := setup(ft)
+ if err != nil {
+ t.Fatalf("%v", err)
+ }
+ defer os.RemoveAll(path)
+
+ a, err := NewAttachPoint(path, c)
+ if err != nil {
+ t.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ root, err := a.Attach()
+ if err != nil {
+ t.Fatalf("Attach failed, err: %v", err)
+ }
+
+ _, file, err := root.Walk([]string{name})
+ if err != nil {
+ root.Close()
+ t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err)
+ }
+
+ st := state{root: root.(*localFile), file: file.(*localFile), conf: c, ft: ft}
+ test(t, st)
+ file.Close()
+ root.Close()
+ }
+ }
+}
+
+func setup(ft fileType) (string, string, error) {
+ path, err := ioutil.TempDir("", "root-")
+ if err != nil {
+ return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err)
+ }
+
+ // First attach with a writable configuration to set up the tree.
+ a, err := NewAttachPoint(path, Config{})
+ if err != nil {
+ return "", "", err
+ }
+ root, err := a.Attach()
+ if err != nil {
+ return "", "", fmt.Errorf("Attach failed, err: %v", err)
+ }
+ defer root.Close()
+
+ var name string
+ switch ft {
+ case regular:
+ name = "file"
+ _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ if err != nil {
+ return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err)
+ }
+ defer f.Close()
+ case directory:
+ name = "dir"
+ if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err)
+ }
+ case symlink:
+ name = "symlink"
+ if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err)
+ }
+ default:
+ panic(fmt.Sprintf("unknown file type %v", ft))
+ }
+ return path, name, nil
+}
+
+func createFile(dir *localFile, name string) (*localFile, error) {
+ _, f, _, _, err := dir.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ if err != nil {
+ return nil, err
+ }
+ return f.(*localFile), nil
+}
+
+func TestReadWrite(t *testing.T) {
+ runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ child, err := createFile(s.file, "test")
+ if err != nil {
+ t.Fatalf("%v: createFile() failed, err: %v", s, err)
+ }
+ defer child.Close()
+ want := []byte("foobar")
+ w, err := child.WriteAt(want, 0)
+ if err != nil {
+ t.Fatalf("%v: Write() failed, err: %v", s, err)
+ }
+ if w != len(want) {
+ t.Fatalf("%v: Write() was partial, got: %d, expected: %d", s, w, len(want))
+ }
+ for _, flags := range allOpenFlags {
+ _, l, err := s.file.Walk([]string{"test"})
+ if err != nil {
+ t.Fatalf("%v: Walk(%s) failed, err: %v", s, "test", err)
+ }
+ if _, _, _, err := l.Open(flags); err != nil {
+ t.Fatalf("%v: Open(%v) failed, err: %v", s, flags, err)
+ }
+ if err := testReadWrite(l, flags, want); err != nil {
+ t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err)
+ }
+ }
+ })
+}
+
+func TestCreate(t *testing.T) {
+ runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ for i, flags := range allOpenFlags {
+ _, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ if err != nil {
+ t.Fatalf("%v, %v: WriteAt() failed, err: %v", s, flags, err)
+ }
+
+ if err := testReadWrite(l, flags, []byte{}); err != nil {
+ t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err)
+ }
+ }
+ })
+}
+
+// TestReadWriteDup tests that a file opened in any mode can be dup'ed and
+// reopened in any other mode.
+func TestReadWriteDup(t *testing.T) {
+ runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ child, err := createFile(s.file, "test")
+ if err != nil {
+ t.Fatalf("%v: createFile() failed, err: %v", s, err)
+ }
+ defer child.Close()
+ want := []byte("foobar")
+ w, err := child.WriteAt(want, 0)
+ if err != nil {
+ t.Fatalf("%v: Write() failed, err: %v", s, err)
+ }
+ if w != len(want) {
+ t.Fatalf("%v: Write() was partial, got: %d, expected: %d", s, w, len(want))
+ }
+ for _, flags := range allOpenFlags {
+ _, l, err := s.file.Walk([]string{"test"})
+ if err != nil {
+ t.Fatalf("%v: Walk(%s) failed, err: %v", s, "test", err)
+ }
+ defer l.Close()
+ if _, _, _, err := l.Open(flags); err != nil {
+ t.Fatalf("%v: Open(%v) failed, err: %v", s, flags, err)
+ }
+ for _, dupFlags := range allOpenFlags {
+ t.Logf("Original flags: %v, dup flags: %v", flags, dupFlags)
+ _, dup, err := l.Walk([]string{})
+ if err != nil {
+ t.Fatalf("%v: Walk(<empty>) failed: %v", s, err)
+ }
+ defer dup.Close()
+ if _, _, _, err := dup.Open(dupFlags); err != nil {
+ t.Fatalf("%v: Open(%v) failed: %v", s, flags, err)
+ }
+ if err := testReadWrite(dup, dupFlags, want); err != nil {
+ t.Fatalf("%v: testReadWrite(%v) failed: %v", s, dupFlags, err)
+ }
+ }
+ }
+ })
+}
+
+func TestUnopened(t *testing.T) {
+ runCustom(t, []fileType{regular}, allConfs, func(t *testing.T, s state) {
+ b := []byte("foobar")
+ if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF {
+ t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if _, err := s.file.ReadAt(b, 0); err != syscall.EBADF {
+ t.Errorf("%v: ReadAt() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if _, err := s.file.Readdir(0, 100); err != syscall.EBADF {
+ t.Errorf("%v: Readdir() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if err := s.file.FSync(); err != syscall.EBADF {
+ t.Errorf("%v: FSync() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ })
+}
+
+func SetGetAttr(l *localFile, valid p9.SetAttrMask, attr p9.SetAttr) (p9.Attr, error) {
+ if err := l.SetAttr(valid, attr); err != nil {
+ return p9.Attr{}, err
+ }
+ _, _, a, err := l.GetAttr(p9.AttrMask{})
+ if err != nil {
+ return p9.Attr{}, err
+ }
+ return a, nil
+}
+
+func TestSetAttrPerm(t *testing.T) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ valid := p9.SetAttrMask{Permissions: true}
+ attr := p9.SetAttr{Permissions: 0777}
+ got, err := SetGetAttr(s.file, valid, attr)
+ if s.ft == symlink {
+ if err == nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v) failed, err: %v", s, attr.Permissions, err)
+ }
+ if got.Mode.Permissions() != attr.Permissions {
+ t.Errorf("%v: wrong permission, got: %v, expected: %v", s, got.Mode.Permissions(), attr.Permissions)
+ }
+ }
+ })
+}
+
+func TestSetAttrSize(t *testing.T) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ for _, size := range []uint64{1024, 0, 1024 * 1024} {
+ valid := p9.SetAttrMask{Size: true}
+ attr := p9.SetAttr{Size: size}
+ got, err := SetGetAttr(s.file, valid, attr)
+ if s.ft == symlink || s.ft == directory {
+ if err == nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions)
+ }
+ // Run for one size only; they will all fail the same way.
+ return
+ }
+ if err != nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v) failed, err: %v", s, attr.Size, err)
+ }
+ if got.Size != size {
+ t.Errorf("%v: wrong size, got: %v, expected: %v", s, got.Size, size)
+ }
+ }
+ })
+}
+
+func TestSetAttrTime(t *testing.T) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ valid := p9.SetAttrMask{ATime: true, ATimeNotSystemTime: true}
+ attr := p9.SetAttr{ATimeSeconds: 123, ATimeNanoSeconds: 456}
+ got, err := SetGetAttr(s.file, valid, attr)
+ if err != nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v:%v) failed, err: %v", s, attr.ATimeSeconds, attr.ATimeNanoSeconds, err)
+ }
+ if got.ATimeSeconds != 123 {
+ t.Errorf("%v: wrong ATimeSeconds, got: %v, expected: %v", s, got.ATimeSeconds, 123)
+ }
+ if got.ATimeNanoSeconds != 456 {
+ t.Errorf("%v: wrong ATimeNanoSeconds, got: %v, expected: %v", s, got.ATimeNanoSeconds, 456)
+ }
+
+ valid = p9.SetAttrMask{MTime: true, MTimeNotSystemTime: true}
+ attr = p9.SetAttr{MTimeSeconds: 789, MTimeNanoSeconds: 10}
+ got, err = SetGetAttr(s.file, valid, attr)
+ if err != nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v:%v) failed, err: %v", s, attr.MTimeSeconds, attr.MTimeNanoSeconds, err)
+ }
+ if got.MTimeSeconds != 789 {
+ t.Errorf("%v: wrong MTimeSeconds, got: %v, expected: %v", s, got.MTimeSeconds, 789)
+ }
+ if got.MTimeNanoSeconds != 10 {
+ t.Errorf("%v: wrong MTimeNanoSeconds, got: %v, expected: %v", s, got.MTimeNanoSeconds, 10)
+ }
+ })
+}
+
+func TestSetAttrOwner(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skipf("SetAttr(owner) test requires CAP_CHOWN, running as %d", os.Getuid())
+ }
+
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ newUID := os.Getuid() + 1
+ valid := p9.SetAttrMask{UID: true}
+ attr := p9.SetAttr{UID: p9.UID(newUID)}
+ got, err := SetGetAttr(s.file, valid, attr)
+ if err != nil {
+ t.Fatalf("%v: SetGetAttr(valid, %v) failed, err: %v", s, attr.UID, err)
+ }
+ if got.UID != p9.UID(newUID) {
+ t.Errorf("%v: wrong uid, got: %v, expected: %v", s, got.UID, newUID)
+ }
+ })
+}
+
+func TestLink(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid())
+ }
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ const dirName = "linkdir"
+ const linkFile = "link"
+ if _, err := s.root.Mkdir(dirName, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("%v: MkDir(%s) failed, err: %v", s, dirName, err)
+ }
+ _, dir, err := s.root.Walk([]string{dirName})
+ if err != nil {
+ t.Fatalf("%v: Walk({%s}) failed, err: %v", s, dirName, err)
+ }
+
+ err = dir.Link(s.file, linkFile)
+ if s.ft == directory {
+ if err != syscall.EPERM {
+ t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err)
+ }
+ return
+ }
+ if err != nil {
+ t.Errorf("%v: Link(target, %s) failed, err: %v", s, linkFile, err)
+ }
+ })
+}
+
+func TestROMountChecks(t *testing.T) {
+ runCustom(t, allTypes, roConfs, func(t *testing.T, s state) {
+ if _, _, _, _, err := s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF {
+ t.Errorf("%v: Create() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if _, err := s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF {
+ t.Errorf("%v: MkDir() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if err := s.file.RenameAt("some_file", s.file, "other_file"); err != syscall.EBADF {
+ t.Errorf("%v: Rename() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if _, err := s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF {
+ t.Errorf("%v: Symlink() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if err := s.file.UnlinkAt("some_file", 0); err != syscall.EBADF {
+ t.Errorf("%v: UnlinkAt() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ if err := s.file.Link(s.file, "some_link"); err != syscall.EBADF {
+ t.Errorf("%v: Link() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+
+ valid := p9.SetAttrMask{Size: true}
+ attr := p9.SetAttr{Size: 0}
+ if err := s.file.SetAttr(valid, attr); err != syscall.EBADF {
+ t.Errorf("%v: SetAttr() should have failed, got: %v, expected: syscall.EBADF", s, err)
+ }
+ })
+}
+
+func TestROMountPanics(t *testing.T) {
+ conf := Config{ROMount: true, PanicOnWrite: true}
+ runCustom(t, allTypes, []Config{conf}, func(t *testing.T, s state) {
+ assertPanic(t, func() { s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) })
+ assertPanic(t, func() { s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) })
+ assertPanic(t, func() { s.file.RenameAt("some_file", s.file, "other_file") })
+ assertPanic(t, func() { s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())) })
+ assertPanic(t, func() { s.file.UnlinkAt("some_file", 0) })
+ assertPanic(t, func() { s.file.Link(s.file, "some_link") })
+
+ valid := p9.SetAttrMask{Size: true}
+ attr := p9.SetAttr{Size: 0}
+ assertPanic(t, func() { s.file.SetAttr(valid, attr) })
+ })
+}
+
+func TestWalkNotFound(t *testing.T) {
+ runCustom(t, []fileType{directory}, allConfs, func(t *testing.T, s state) {
+ if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT {
+ t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err)
+ }
+ })
+}
+
+func TestWalkDup(t *testing.T) {
+ runAll(t, func(t *testing.T, s state) {
+ _, dup, err := s.file.Walk([]string{})
+ if err != nil {
+ t.Fatalf("%v: Walk(nil) failed, err: %v", s, err)
+ }
+ // Check that 'dup' is usable.
+ if _, _, _, err := dup.GetAttr(p9.AttrMask{}); err != nil {
+ t.Errorf("%v: GetAttr() failed, err: %v", s, err)
+ }
+ })
+}
+
+func TestReaddir(t *testing.T) {
+ runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ name := "dir"
+ if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err)
+ }
+ name = "symlink"
+ if _, err := s.file.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("%v: Symlink(%q) failed, err: %v", s, name, err)
+ }
+ name = "file"
+ _, f, _, _, err := s.file.Create(name, p9.ReadWrite, 0555, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ if err != nil {
+ t.Fatalf("%v: createFile(root, %q) failed, err: %v", s, name, err)
+ }
+ f.Close()
+
+ if _, _, _, err := s.file.Open(p9.ReadOnly); err != nil {
+ t.Fatalf("%v: Open(ReadOnly) failed, err: %v", s, err)
+ }
+
+ dirents, err := s.file.Readdir(0, 10)
+ if err != nil {
+ t.Fatalf("%v: Readdir(0, 10) failed, err: %v", s, err)
+ }
+ if len(dirents) != 3 {
+ t.Fatalf("%v: Readdir(0, 10) wrong number of items, got: %v, expected: 3", s, len(dirents))
+ }
+ var dir, symlink, file bool
+ for _, d := range dirents {
+ switch d.Name {
+ case "dir":
+ if d.Type != p9.TypeDir {
+ t.Errorf("%v: dirent.Type got: %v, expected: %v", s, d.Type, p9.TypeDir)
+ }
+ dir = true
+ case "symlink":
+ if d.Type != p9.TypeSymlink {
+ t.Errorf("%v: dirent.Type got: %v, expected: %v", s, d.Type, p9.TypeSymlink)
+ }
+ symlink = true
+ case "file":
+ if d.Type != p9.TypeRegular {
+ t.Errorf("%v: dirent.Type got: %v, expected: %v", s, d.Type, p9.TypeRegular)
+ }
+ file = true
+ default:
+ t.Errorf("%v: dirent.Name got: %v", s, d.Name)
+ }
+
+ _, f, err := s.file.Walk([]string{d.Name})
+ if err != nil {
+ t.Fatalf("%v: Walk({%s}) failed, err: %v", s, d.Name, err)
+ }
+ _, _, a, err := f.GetAttr(p9.AttrMask{})
+ if err != nil {
+ t.Fatalf("%v: GetAttr() failed, err: %v", s, err)
+ }
+ if d.Type != a.Mode.QIDType() {
+ t.Errorf("%v: dirent.Type different than GetAttr().Mode.QIDType(), got: %v, expected: %v", s, d.Type, a.Mode.QIDType())
+ }
+ }
+ if !dir || !symlink || !file {
+ t.Errorf("%v: Readdir(0, 10) wrong files returned, dir: %v, symlink: %v, file: %v", s, dir, symlink, file)
+ }
+ })
+}
+
+// Test that an attach point can be written to when it points to a file, e.g.
+// /etc/hosts.
+func TestAttachFile(t *testing.T) {
+ conf := Config{ROMount: false}
+ dir, err := ioutil.TempDir("", "root-")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed, err: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ path := path.Join(dir, "test")
+ if _, err := os.Create(path); err != nil {
+ t.Fatalf("os.Create(%q) failed, err: %v", path, err)
+ }
+
+ a, err := NewAttachPoint(path, conf)
+ if err != nil {
+ t.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ root, err := a.Attach()
+ if err != nil {
+ t.Fatalf("Attach failed, err: %v", err)
+ }
+
+ if _, _, _, err := root.Open(p9.ReadWrite); err != nil {
+ t.Fatalf("Open(ReadWrite) failed, err: %v", err)
+ }
+ defer root.Close()
+
+ b := []byte("foobar")
+ w, err := root.WriteAt(b, 0)
+ if err != nil {
+ t.Fatalf("Write() failed, err: %v", err)
+ }
+ if w != len(b) {
+ t.Fatalf("Write() was partial, got: %d, expected: %d", w, len(b))
+ }
+ rBuf := make([]byte, len(b))
+ r, err := root.ReadAt(rBuf, 0)
+ if err != nil {
+ t.Fatalf("ReadAt() failed, err: %v", err)
+ }
+ if r != len(rBuf) {
+ t.Fatalf("ReadAt() was partial, got: %d, expected: %d", r, len(rBuf))
+ }
+ if string(rBuf) != "foobar" {
+ t.Fatalf("ReadAt() wrong data, got: %s, expected: %s", string(rBuf), "foobar")
+ }
+}
+
+func TestAttachInvalidType(t *testing.T) {
+ dir, err := ioutil.TempDir("", "attach-")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed, err: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ fifo := filepath.Join(dir, "fifo")
+ if err := syscall.Mkfifo(fifo, 0755); err != nil {
+ t.Fatalf("Mkfifo(%q): %v", fifo, err)
+ }
+
+ dirFile, err := os.Open(dir)
+ if err != nil {
+ t.Fatalf("Open(%s): %v", dir, err)
+ }
+ defer dirFile.Close()
+
+ // Bind a socket via /proc to be sure that the length of the socket path
+ // is less than UNIX_PATH_MAX.
+ socket := filepath.Join(fmt.Sprintf("/proc/self/fd/%d", dirFile.Fd()), "socket")
+ l, err := net.Listen("unix", socket)
+ if err != nil {
+ t.Fatalf("net.Listen(unix, %q): %v", socket, err)
+ }
+ defer l.Close()
+
+ for _, tc := range []struct {
+ name string
+ path string
+ }{
+ {name: "fifo", path: fifo},
+ {name: "socket", path: socket},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ conf := Config{ROMount: false}
+ a, err := NewAttachPoint(tc.path, conf)
+ if err != nil {
+ t.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ f, err := a.Attach()
+ if f != nil || err == nil {
+ t.Fatalf("Attach should have failed, got (%v, %v)", f, err)
+ }
+ })
+ }
+}
+
+func TestDoubleAttachError(t *testing.T) {
+ conf := Config{ROMount: false}
+ root, err := ioutil.TempDir("", "root-")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed, err: %v", err)
+ }
+ defer os.RemoveAll(root)
+ a, err := NewAttachPoint(root, conf)
+ if err != nil {
+ t.Fatalf("NewAttachPoint failed: %v", err)
+ }
+
+ if _, err := a.Attach(); err != nil {
+ t.Fatalf("Attach failed: %v", err)
+ }
+ if _, err := a.Attach(); err == nil {
+ t.Fatalf("Attach should have failed, got %v want non-nil", err)
+ }
+}
diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go
new file mode 100644
index 000000000..542b54365
--- /dev/null
+++ b/runsc/fsgofer/fsgofer_unsafe.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsgofer
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) error {
+ // utimensat(2) doesn't accept an empty name; unlike other *at syscalls,
+ // 'name' must be nil to make it operate directly on 'dirFd'.
+ var namePtr unsafe.Pointer
+ if name != "" {
+ nameBytes, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return err
+ }
+ namePtr = unsafe.Pointer(nameBytes)
+ }
+
+ timesPtr := unsafe.Pointer(&times[0])
+
+ if _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(dirFd),
+ uintptr(namePtr),
+ uintptr(timesPtr),
+ uintptr(flags),
+ 0,
+ 0); errno != 0 {
+
+ return syserr.FromHost(errno).ToError()
+ }
+ return nil
+}
+
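+// Usage sketch for the wrapper above, mirroring SetAttr's two call sites:
+// utimensat(f.FD(), "", utimes, 0) updates the open file itself, while
+// utimensat(parentFD, "link", utimes, linux.AT_SYMLINK_NOFOLLOW) updates a
+// symlink named "link" (name assumed) without following it.
+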
+func renameat(oldDirFD int, oldName string, newDirFD int, newName string) error {
+ var oldNamePtr unsafe.Pointer
+ if oldName != "" {
+ nameBytes, err := syscall.BytePtrFromString(oldName)
+ if err != nil {
+ return err
+ }
+ oldNamePtr = unsafe.Pointer(nameBytes)
+ }
+ var newNamePtr unsafe.Pointer
+ if newName != "" {
+ nameBytes, err := syscall.BytePtrFromString(newName)
+ if err != nil {
+ return err
+ }
+ newNamePtr = unsafe.Pointer(nameBytes)
+ }
+
+ if _, _, errno := syscall.Syscall6(
+ syscall.SYS_RENAMEAT,
+ uintptr(oldDirFD),
+ uintptr(oldNamePtr),
+ uintptr(newDirFD),
+ uintptr(newNamePtr),
+ 0,
+ 0); errno != 0 {
+
+ return syserr.FromHost(errno).ToError()
+ }
+ return nil
+}
diff --git a/runsc/main.go b/runsc/main.go
new file mode 100644
index 000000000..c9f47c579
--- /dev/null
+++ b/runsc/main.go
@@ -0,0 +1,372 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary runsc is an implementation of the Open Container Initiative Runtime
+// that runs applications inside a sandbox.
+package main
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/cmd"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ // Although these flags are not part of the OCI spec, they are used by
+ // Docker, and thus should not be changed.
+ rootDir = flag.String("root", "", "root directory for storage of container state.")
+ logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout.")
+ logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.")
+ debug = flag.Bool("debug", false, "enable debug logging.")
+ showVersion = flag.Bool("version", false, "show version and exit.")
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.")
+
+ // These flags are unique to runsc, and are used to configure parts of the
+ // system that are not covered by the runtime spec.
+
+ // Debugging flags.
+ debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")
+ panicLog = flag.String("panic-log", "", "file path where panic reports and other Go runtime messages are written.")
+ logPackets = flag.Bool("log-packets", false, "enable network packet logging.")
+ logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
+ debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
+ panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go runtime messages to.")
+ debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
+ alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.")
+
+ // Debugging flags: strace related
+ strace = flag.Bool("strace", false, "enable strace.")
+ straceSyscalls = flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.")
+ straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.")
+
+ // Flags that control sandbox runtime behavior.
+ platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.")
+ network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
+ hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.")
+ softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.")
+ txChecksumOffload = flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.")
+ rxChecksumOffload = flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.")
+ qDisc = flag.String("qdisc", "fifo", "specifies which queueing discipline to apply by default to the non-loopback NICs used by the sandbox.")
+ fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
+ fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
+ overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
+ overlayfsStaleRead = flag.Bool("overlayfs-stale-read", true, "assume the root mount is an overlay filesystem.")
+ watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.")
+ panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
+ profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).")
+ netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.")
+ numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels (FDs) to use for network link endpoints.")
+ rootless = flag.Bool("rootless", false, "allow the sandbox to be started by a user that is not root. Sandbox and Gofer processes may run with the same privileges as the current user.")
+ referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.")
+ cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)")
+ vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.")
+
+ // Test flags, not to be used outside tests, ever.
+ testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
+ testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.")
+)
+
+func main() {
+ // Help and flags commands are generated automatically.
+ help := cmd.NewHelp(subcommands.DefaultCommander)
+ help.Register(new(cmd.Syscalls))
+ subcommands.Register(help, "")
+ subcommands.Register(subcommands.FlagsCommand(), "")
+
+ // Installation helpers.
+ const helperGroup = "helpers"
+ subcommands.Register(new(cmd.Install), helperGroup)
+ subcommands.Register(new(cmd.Uninstall), helperGroup)
+
+ // Register user-facing runsc commands.
+ subcommands.Register(new(cmd.Checkpoint), "")
+ subcommands.Register(new(cmd.Create), "")
+ subcommands.Register(new(cmd.Delete), "")
+ subcommands.Register(new(cmd.Do), "")
+ subcommands.Register(new(cmd.Events), "")
+ subcommands.Register(new(cmd.Exec), "")
+ subcommands.Register(new(cmd.Gofer), "")
+ subcommands.Register(new(cmd.Kill), "")
+ subcommands.Register(new(cmd.List), "")
+ subcommands.Register(new(cmd.Pause), "")
+ subcommands.Register(new(cmd.PS), "")
+ subcommands.Register(new(cmd.Restore), "")
+ subcommands.Register(new(cmd.Resume), "")
+ subcommands.Register(new(cmd.Run), "")
+ subcommands.Register(new(cmd.Spec), "")
+ subcommands.Register(new(cmd.State), "")
+ subcommands.Register(new(cmd.Start), "")
+ subcommands.Register(new(cmd.Wait), "")
+
+ // Register internal commands with the internal group name. This causes
+ // them to be sorted below the user-facing commands with empty group.
+ // The string below will be printed above the commands.
+ const internalGroup = "internal use only"
+ subcommands.Register(new(cmd.Boot), internalGroup)
+ subcommands.Register(new(cmd.Debug), internalGroup)
+ subcommands.Register(new(cmd.Gofer), internalGroup)
+ subcommands.Register(new(cmd.Statefile), internalGroup)
+
+ // All subcommands must be registered before flag parsing.
+ flag.Parse()
+
+ // Are we showing the version?
+ if *showVersion {
+ // The format here is the same as runc.
+ fmt.Fprintf(os.Stdout, "runsc version %s\n", version)
+ fmt.Fprintf(os.Stdout, "spec: %s\n", specutils.Version)
+ os.Exit(0)
+ }
+
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ if *systemdCgroup {
+ fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193")
+ os.Exit(1)
+ }
+
+ var errorLogger io.Writer
+ if *logFD > -1 {
+ errorLogger = os.NewFile(uintptr(*logFD), "error log file")
+
+ } else if *logFilename != "" {
+ // We must set O_APPEND and not O_TRUNC because Docker passes
+ // the same log file for all commands (and also parses these
+ // log files), so we can't destroy them on each command.
+ var err error
+ errorLogger, err = os.OpenFile(*logFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
+ if err != nil {
+ cmd.Fatalf("error opening log file %q: %v", *logFilename, err)
+ }
+ }
+ cmd.ErrorLogger = errorLogger
+
+ platformType := *platformName
+ if _, err := platform.Lookup(platformType); err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ fsAccess, err := boot.MakeFileAccessType(*fileAccess)
+ if err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ if fsAccess == boot.FileAccessShared && *overlay {
+ cmd.Fatalf("overlay flag is incompatible with shared file access")
+ }
+
+ netType, err := boot.MakeNetworkType(*network)
+ if err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ wa, err := boot.MakeWatchdogAction(*watchdogAction)
+ if err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ if *numNetworkChannels <= 0 {
+ cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels)
+ }
+
+ refsLeakMode, err := boot.MakeRefsLeakMode(*referenceLeakMode)
+ if err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ queueingDiscipline, err := boot.MakeQueueingDiscipline(*qDisc)
+ if err != nil {
+ cmd.Fatalf("%s", err)
+ }
+
+ // Set the reference leak check mode. Also set it in the config below to
+ // propagate it to child processes.
+ refs.SetLeakMode(refsLeakMode)
+
+ // Create a new Config from the flags.
+ conf := &boot.Config{
+ RootDir: *rootDir,
+ Debug: *debug,
+ LogFilename: *logFilename,
+ LogFormat: *logFormat,
+ DebugLog: *debugLog,
+ PanicLog: *panicLog,
+ DebugLogFormat: *debugLogFormat,
+ FileAccess: fsAccess,
+ FSGoferHostUDS: *fsGoferHostUDS,
+ Overlay: *overlay,
+ Network: netType,
+ HardwareGSO: *hardwareGSO,
+ SoftwareGSO: *softwareGSO,
+ TXChecksumOffload: *txChecksumOffload,
+ RXChecksumOffload: *rxChecksumOffload,
+ LogPackets: *logPackets,
+ Platform: platformType,
+ Strace: *strace,
+ StraceLogSize: *straceLogSize,
+ WatchdogAction: wa,
+ PanicSignal: *panicSignal,
+ ProfileEnable: *profile,
+ EnableRaw: *netRaw,
+ NumNetworkChannels: *numNetworkChannels,
+ Rootless: *rootless,
+ AlsoLogToStderr: *alsoLogToStderr,
+ ReferenceLeakMode: refsLeakMode,
+ OverlayfsStaleRead: *overlayfsStaleRead,
+ CPUNumFromQuota: *cpuNumFromQuota,
+ VFS2: *vfs2Enabled,
+ QDisc: queueingDiscipline,
+ TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
+ TestOnlyTestNameEnv: *testOnlyTestNameEnv,
+ }
+ if len(*straceSyscalls) != 0 {
+ conf.StraceSyscalls = strings.Split(*straceSyscalls, ",")
+ }
+
+ // Set up logging.
+ if *debug {
+ log.SetLevel(log.Debug)
+ }
+
+ // Logging will include the local date and time via the time package.
+ //
+ // On first use, time.Local initializes the local time zone, which
+ // involves opening tzdata files on the host. Since this requires
+ // opening host files, it must be done before syscall filter
+ // installation.
+ //
+ // Generally there will be a log message before filter installation
+ // that will force initialization, but force initialization here in
+ // case that does not occur.
+ _ = time.Local.String()
+
+ subcommand := flag.CommandLine.Arg(0)
+
+ var e log.Emitter
+ if *debugLogFD > -1 {
+ f := os.NewFile(uintptr(*debugLogFD), "debug log file")
+
+ e = newEmitter(*debugLogFormat, f)
+
+ } else if *debugLog != "" {
+ f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */)
+ if err != nil {
+ cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err)
+ }
+ e = newEmitter(*debugLogFormat, f)
+
+ } else {
+ // Stderr is reserved for the application; just discard the logs if no
+ // debug log is specified.
+ e = newEmitter("text", ioutil.Discard)
+ }
+
+ if *panicLogFD > -1 || *debugLogFD > -1 {
+ fd := *panicLogFD
+ if fd < 0 {
+ fd = *debugLogFD
+ }
+ // Quick sanity check to make sure no other commands get passed
+ // a log fd (they should use log dir instead).
+ if subcommand != "boot" && subcommand != "gofer" {
+ cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
+ }
+
+ // If we are the boot process, then we own our stdio FDs and can do what we
+ // want with them. Since Docker and Containerd both eat boot's stderr, we
+ // dup our stderr to the provided log FD so that panics will appear in the
+ // logs, rather than just disappear.
+ if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
+ cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
+ }
+ } else if *alsoLogToStderr {
+ e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)}
+ }
+
+ log.SetTarget(e)
+
+ log.Infof("***************************")
+ log.Infof("Args: %s", os.Args)
+ log.Infof("Version %s", version)
+ log.Infof("PID: %d", os.Getpid())
+ log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid())
+ log.Infof("Configuration:")
+ log.Infof("\t\tRootDir: %s", conf.RootDir)
+ log.Infof("\t\tPlatform: %v", conf.Platform)
+ log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
+ log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
+ log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
+ log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
+ log.Infof("***************************")
+
+ if *testOnlyAllowRunAsCurrentUserWithoutChroot {
+ // SIGTERM is sent to all processes if a test exceeds its
+ // timeout and this case is handled by syscall_test_runner.
+ log.Warningf("Block the TERM signal. This is only safe in tests!")
+ signal.Ignore(syscall.SIGTERM)
+ }
+
+ // Call the subcommand and pass in the configuration.
+ var ws syscall.WaitStatus
+ subcmdCode := subcommands.Execute(context.Background(), conf, &ws)
+ if subcmdCode == subcommands.ExitSuccess {
+ log.Infof("Exiting with status: %v", ws)
+ if ws.Signaled() {
+ // No good way to return it; emulate what the shell does. Maybe raise
+ // the signal to self?
+ os.Exit(128 + int(ws.Signal()))
+ }
+ os.Exit(ws.ExitStatus())
+ }
+ // Return an error that is unlikely to be used by the application.
+ log.Warningf("Failure to execute command, err: %v", subcmdCode)
+ os.Exit(128)
+}
+
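+// newEmitter returns a log.Emitter that writes to logFile in the given
+// format ("text", "json", or "json-k8s"); it exits fatally on any other
+// format.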
+func newEmitter(format string, logFile io.Writer) log.Emitter {
+ switch format {
+ case "text":
+ return log.GoogleEmitter{&log.Writer{Next: logFile}}
+ case "json":
+ return log.JSONEmitter{&log.Writer{Next: logFile}}
+ case "json-k8s":
+ return log.K8sJSONEmitter{&log.Writer{Next: logFile}}
+ }
+ cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format)
+ panic("unreachable")
+}
+
+func init() {
+ // Set default root dir to something (hopefully) user-writeable.
+ *rootDir = "/var/run/runsc"
+ if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
+ *rootDir = filepath.Join(runtimeDir, "runsc")
+ }
+}
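+
+// Example invocation (illustrative; the flags are defined at the top of this
+// file):
+//
+//   runsc --root /var/run/runsc --debug --debug-log /tmp/runsc/ run <container-id>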
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
new file mode 100644
index 000000000..035dcd3e3
--- /dev/null
+++ b/runsc/sandbox/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "sandbox",
+ srcs = [
+ "network.go",
+ "network_unsafe.go",
+ "sandbox.go",
+ ],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+ deps = [
+ "//pkg/cleanup",
+ "//pkg/control/client",
+ "//pkg/control/server",
+ "//pkg/log",
+ "//pkg/sentry/control",
+ "//pkg/sentry/platform",
+ "//pkg/sync",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/urpc",
+ "//runsc/boot",
+ "//runsc/boot/platforms",
+ "//runsc/cgroup",
+ "//runsc/console",
+ "//runsc/specutils",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@com_github_vishvananda_netlink//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
new file mode 100644
index 000000000..817a923ad
--- /dev/null
+++ b/runsc/sandbox/network.go
@@ -0,0 +1,411 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sandbox
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/vishvananda/netlink"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// setupNetwork configures the network stack to mimic the local network
+// configuration. Docker uses network namespaces with vnets to configure the
+// network for the container. The untrusted app expects to see the same network
+// inside the sandbox. Routing and port mapping are handled directly by Docker,
+// with most of the network information not even available to the runtime.
+//
+// Netstack inside the sandbox speaks directly to the device using a raw socket.
+// All IP addresses assigned to the NIC are removed and passed on to netstack's
+// device.
+//
+// If 'conf.Network' is NoNetwork, skips local configuration and creates a
+// loopback interface only.
+//
+// Run the following container to test it:
+// docker run -di --runtime=runsc -p 8080:80 -v $PWD:/usr/local/apache2/htdocs/ httpd:2.4
+func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Config) error {
+ log.Infof("Setting up network")
+
+ switch conf.Network {
+ case boot.NetworkNone:
+ log.Infof("Network is disabled, create loopback interface only")
+ if err := createDefaultLoopbackInterface(conn); err != nil {
+ return fmt.Errorf("creating default loopback interface: %v", err)
+ }
+ case boot.NetworkSandbox:
+ // Build the path to the net namespace of the sandbox process.
+ // This is what we will copy.
+ nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net")
+ if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.TXChecksumOffload, conf.RXChecksumOffload, conf.NumNetworkChannels, conf.QDisc); err != nil {
+ return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err)
+ }
+ case boot.NetworkHost:
+ // Nothing to do here.
+ default:
+ return fmt.Errorf("invalid network type: %d", conf.Network)
+ }
+ return nil
+}
+
+func createDefaultLoopbackInterface(conn *urpc.Client) error {
+ if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{
+ LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink},
+ }, nil); err != nil {
+ return fmt.Errorf("creating loopback link and routes: %v", err)
+ }
+ return nil
+}
+
+func joinNetNS(nsPath string) (func(), error) {
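+ // Namespace changes via setns(2) apply only to the calling thread, so pin
+ // this goroutine to its OS thread until the original namespace is restored.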
+ runtime.LockOSThread()
+ restoreNS, err := specutils.ApplyNS(specs.LinuxNamespace{
+ Type: specs.NetworkNamespace,
+ Path: nsPath,
+ })
+ if err != nil {
+ runtime.UnlockOSThread()
+ return nil, fmt.Errorf("joining net namespace %q: %v", nsPath, err)
+ }
+ return func() {
+ restoreNS()
+ runtime.UnlockOSThread()
+ }, nil
+}
+
+// isRootNS determines whether we are running in the root net namespace.
+// /proc/sys/net/core/rmem_default only exists in the root network namespace.
+func isRootNS() (bool, error) {
+ err := syscall.Access("/proc/sys/net/core/rmem_default", syscall.F_OK)
+ switch err {
+ case nil:
+ return true, nil
+ case syscall.ENOENT:
+ return false, nil
+ default:
+ return false, fmt.Errorf("failed to access /proc/sys/net/core/rmem_default: %v", err)
+ }
+}
+
+// createInterfacesAndRoutesFromNS scrapes the interfaces and routes from the
+// net namespace with the given path, creates them in the sandbox, and removes
+// them from the host.
+func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, txChecksumOffload bool, rxChecksumOffload bool, numNetworkChannels int, qDisc boot.QueueingDiscipline) error {
+ // Join the network namespace that we will be copying.
+ restore, err := joinNetNS(nsPath)
+ if err != nil {
+ return err
+ }
+ defer restore()
+
+ // Get all interfaces in the namespace.
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return fmt.Errorf("querying interfaces: %v", err)
+ }
+
+ isRoot, err := isRootNS()
+ if err != nil {
+ return err
+ }
+ if isRoot {
+ return fmt.Errorf("cannot run with network enabled in root network namespace")
+ }
+
+ // Collect addresses and routes from the interfaces.
+ var args boot.CreateLinksAndRoutesArgs
+ for _, iface := range ifaces {
+ if iface.Flags&net.FlagUp == 0 {
+ log.Infof("Skipping down interface: %+v", iface)
+ continue
+ }
+
+ allAddrs, err := iface.Addrs()
+ if err != nil {
+ return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err)
+ }
+
+ // We build our own loopback device.
+ if iface.Flags&net.FlagLoopback != 0 {
+ link, err := loopbackLink(iface, allAddrs)
+ if err != nil {
+ return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err)
+ }
+ args.LoopbackLinks = append(args.LoopbackLinks, link)
+ continue
+ }
+
+ var ipAddrs []*net.IPNet
+ for _, ifaddr := range allAddrs {
+ ipNet, ok := ifaddr.(*net.IPNet)
+ if !ok {
+ return fmt.Errorf("address is not IPNet: %+v", ifaddr)
+ }
+ ipAddrs = append(ipAddrs, ipNet)
+ }
+ if len(ipAddrs) == 0 {
+ log.Warningf("No usable IP addresses found for interface %q, skipping", iface.Name)
+ continue
+ }
+
+ // Scrape the routes before removing the address, since that
+ // will remove the routes as well.
+ routes, defv4, defv6, err := routesForIface(iface)
+ if err != nil {
+ return fmt.Errorf("getting routes for interface %q: %v", iface.Name, err)
+ }
+ if defv4 != nil {
+ if !args.Defaultv4Gateway.Route.Empty() {
+ return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, defv4, args.Defaultv4Gateway)
+ }
+ args.Defaultv4Gateway.Route = *defv4
+ args.Defaultv4Gateway.Name = iface.Name
+ }
+
+ if defv6 != nil {
+ if !args.Defaultv6Gateway.Route.Empty() {
+ return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, defv6, args.Defaultv6Gateway)
+ }
+ args.Defaultv6Gateway.Route = *defv6
+ args.Defaultv6Gateway.Name = iface.Name
+ }
+
+ link := boot.FDBasedLink{
+ Name: iface.Name,
+ MTU: iface.MTU,
+ Routes: routes,
+ TXChecksumOffload: txChecksumOffload,
+ RXChecksumOffload: rxChecksumOffload,
+ NumChannels: numNetworkChannels,
+ QDisc: qDisc,
+ }
+
+ // Get the link for the interface.
+ ifaceLink, err := netlink.LinkByName(iface.Name)
+ if err != nil {
+ return fmt.Errorf("getting link for interface %q: %v", iface.Name, err)
+ }
+ link.LinkAddress = ifaceLink.Attrs().HardwareAddr
+
+ log.Debugf("Setting up network channels")
+ // Create the socket for the device.
+ for i := 0; i < link.NumChannels; i++ {
+ log.Debugf("Creating Channel %d", i)
+ socketEntry, err := createSocket(iface, ifaceLink, hardwareGSO)
+ if err != nil {
+ return fmt.Errorf("failed to createSocket for %s : %v", iface.Name, err)
+ }
+ if i == 0 {
+ link.GSOMaxSize = socketEntry.gsoMaxSize
+ } else {
+ if link.GSOMaxSize != socketEntry.gsoMaxSize {
+ return fmt.Errorf("inconsistent gsoMaxSize %d and %d when creating multiple channels for same interface: %s",
+ link.GSOMaxSize, socketEntry.gsoMaxSize, iface.Name)
+ }
+ }
+ args.FilePayload.Files = append(args.FilePayload.Files, socketEntry.deviceFile)
+ }
+
+ if link.GSOMaxSize == 0 && softwareGSO {
+ // Hardware GSO is disabled. Let's enable software GSO.
+ link.GSOMaxSize = stack.SoftwareGSOMaxSize
+ link.SoftwareGSOEnabled = true
+ }
+
+ // Collect the addresses for the interface, enable forwarding,
+ // and remove them from the host.
+ for _, addr := range ipAddrs {
+ link.Addresses = append(link.Addresses, addr.IP)
+
+ // Steal IP address from NIC.
+ if err := removeAddress(ifaceLink, addr.String()); err != nil {
+ return fmt.Errorf("removing address %v from device %q: %v", iface.Name, addr, err)
+ }
+ }
+
+ args.FDBasedLinks = append(args.FDBasedLinks, link)
+ }
+
+ log.Debugf("Setting up network, config: %+v", args)
+ if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &args, nil); err != nil {
+ return fmt.Errorf("creating links and routes: %v", err)
+ }
+ return nil
+}
+
+type socketEntry struct {
+ deviceFile *os.File
+ gsoMaxSize uint32
+}
+
+// createSocket creates an underlying AF_PACKET socket, configures it for use
+// by the sentry, and returns a socketEntry wrapping an *os.File for the socket
+// FD along with the interface's GSO maximum size.
+func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (*socketEntry, error) {
+ // Create the socket.
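+ // AF_PACKET sockets take the protocol in network byte order, hence the
+ // byte-swapped ETH_P_ALL (0x0003) below.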
+ const protocol = 0x0300 // htons(ETH_P_ALL)
+ fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create raw socket: %v", err)
+ }
+ deviceFile := os.NewFile(uintptr(fd), "raw-device-fd")
+ // Bind to the appropriate device.
+ ll := syscall.SockaddrLinklayer{
+ Protocol: protocol,
+ Ifindex: iface.Index,
+ Hatype: 0, // No ARP type.
+ Pkttype: syscall.PACKET_OTHERHOST,
+ }
+ if err := syscall.Bind(fd, &ll); err != nil {
+ return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
+ }
+
+ gsoMaxSize := uint32(0)
+ if enableGSO {
+ gso, err := isGSOEnabled(fd, iface.Name)
+ if err != nil {
+ return nil, fmt.Errorf("getting GSO for interface %q: %v", iface.Name, err)
+ }
+ if gso {
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ }
+ gsoMaxSize = ifaceLink.Attrs().GSOMaxSize
+ } else {
+ log.Infof("GSO not available in host.")
+ }
+ }
+
+ // Use SO_RCVBUFFORCE/SO_SNDBUFFORCE because on Linux the receive/send buffer
+ // for an AF_PACKET socket is capped by "net.core.rmem_max/wmem_max".
+ // wmem_max/rmem_max default to an unusually low value of 208KB. This is too low
+ // for gVisor to be able to receive packets at high throughputs without
+ // incurring packet drops.
+ const bufSize = 4 << 20 // 4MB.
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err)
+ }
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err)
+ }
+
+ return &socketEntry{deviceFile, gsoMaxSize}, nil
+}
+
+// loopbackLink returns the link with addresses and routes for a loopback
+// interface.
+func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, error) {
+ link := boot.LoopbackLink{
+ Name: iface.Name,
+ }
+ for _, addr := range addrs {
+ ipNet, ok := addr.(*net.IPNet)
+ if !ok {
+ return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr)
+ }
+ dst := *ipNet
+ dst.IP = dst.IP.Mask(dst.Mask)
+ link.Addresses = append(link.Addresses, ipNet.IP)
+ link.Routes = append(link.Routes, boot.Route{
+ Destination: dst,
+ })
+ }
+ return link, nil
+}
+
+// routesForIface iterates over all routes for the given interface and converts
+// them to boot.Routes. It also returns the default v4 and v6 routes, if found.
+func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, *boot.Route, error) {
+ link, err := netlink.LinkByIndex(iface.Index)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ rs, err := netlink.RouteList(link, netlink.FAMILY_ALL)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("getting routes from %q: %v", iface.Name, err)
+ }
+
+ var defv4, defv6 *boot.Route
+ var routes []boot.Route
+ for _, r := range rs {
+ // Is it a default route?
+ if r.Dst == nil {
+ if r.Gw == nil {
+ return nil, nil, nil, fmt.Errorf("default route with no gateway %q: %+v", iface.Name, r)
+ }
+ // Create a catch-all route to the gateway.
+ switch len(r.Gw) {
+ case header.IPv4AddressSize:
+ if defv4 != nil {
+ return nil, nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, defv4, r)
+ }
+ defv4 = &boot.Route{
+ Destination: net.IPNet{
+ IP: net.IPv4zero,
+ Mask: net.IPMask(net.IPv4zero),
+ },
+ Gateway: r.Gw,
+ }
+ case header.IPv6AddressSize:
+ if defv6 != nil {
+ return nil, nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, defv6, r)
+ }
+
+ defv6 = &boot.Route{
+ Destination: net.IPNet{
+ IP: net.IPv6zero,
+ Mask: net.IPMask(net.IPv6zero),
+ },
+ Gateway: r.Gw,
+ }
+ default:
+ return nil, nil, nil, fmt.Errorf("unexpected address size for gateway: %+v for route: %+v", r.Gw, r)
+ }
+ continue
+ }
+
+ dst := *r.Dst
+ dst.IP = dst.IP.Mask(dst.Mask)
+ routes = append(routes, boot.Route{
+ Destination: dst,
+ Gateway: r.Gw,
+ })
+ }
+ return routes, defv4, defv6, nil
+}
+
+// removeAddress removes an IP address from the network device. It's equivalent to:
+// ip addr del <ipAndMask> dev <name>
+func removeAddress(source netlink.Link, ipAndMask string) error {
+ addr, err := netlink.ParseAddr(ipAndMask)
+ if err != nil {
+ return err
+ }
+ return netlink.AddrDel(source, addr)
+}
diff --git a/runsc/sandbox/network_unsafe.go b/runsc/sandbox/network_unsafe.go
new file mode 100644
index 000000000..2a2a0fb7e
--- /dev/null
+++ b/runsc/sandbox/network_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sandbox
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
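+// ethtoolValue mirrors struct ethtool_value from linux/ethtool.h.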
+type ethtoolValue struct {
+ cmd uint32
+ val uint32
+}
+
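+// ifreq mirrors the layout of struct ifreq from <net/if.h>, as used by the
+// SIOCETHTOOL ioctl.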
+type ifreq struct {
+ ifrName [unix.IFNAMSIZ]byte
+ ifrData *ethtoolValue
+}
+
+const (
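+ // _ETHTOOL_GGSO is the ETHTOOL_GGSO command from linux/ethtool.h; it
+ // queries whether generic segmentation offload is enabled.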
+ _ETHTOOL_GGSO = 0x00000023
+)
+
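+// isGSOEnabled returns whether generic segmentation offload (GSO) is enabled
+// on the interface named intf, queried via the SIOCETHTOOL ioctl.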
+func isGSOEnabled(fd int, intf string) (bool, error) {
+ val := ethtoolValue{
+ cmd: _ETHTOOL_GGSO,
+ }
+
+ var name [unix.IFNAMSIZ]byte
+ copy(name[:], []byte(intf))
+
+ ifr := ifreq{
+ ifrName: name,
+ ifrData: &val,
+ }
+
+ if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), unix.SIOCETHTOOL, uintptr(unsafe.Pointer(&ifr))); err != 0 {
+ return false, err
+ }
+
+ return val.val != 0, nil
+}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
new file mode 100644
index 000000000..6e1a2af25
--- /dev/null
+++ b/runsc/sandbox/sandbox.go
@@ -0,0 +1,1228 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sandbox creates and manipulates sandboxes.
+package sandbox
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "math"
+ "os"
+ "os/exec"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/control/client"
+ "gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/boot/platforms"
+ "gvisor.dev/gvisor/runsc/cgroup"
+ "gvisor.dev/gvisor/runsc/console"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Sandbox wraps a sandbox process.
+//
+// It is used to start/stop the sandbox process (and associated processes like
+// gofers), as well as for running and manipulating containers inside a running
+// sandbox.
+//
+// Note: Sandbox must be immutable because a copy of it is saved for each
+// container and changes would not be synchronized to all of them.
+type Sandbox struct {
+ // ID is the id of the sandbox (immutable). By convention, this is the same
+ // ID as the first container run in the sandbox.
+ ID string `json:"id"`
+
+ // Pid is the pid of the running sandbox (immutable). May be 0 if the sandbox
+ // is not running.
+ Pid int `json:"pid"`
+
+ // Cgroup has the cgroup configuration for the sandbox.
+ Cgroup *cgroup.Cgroup `json:"cgroup"`
+
+ // child is set if the sandbox process is a child of the current process.
+ //
+ // This field isn't saved to json, because only the creator of the sandbox
+ // will have it as a child process.
+ child bool
+
+ // status is the exit status of the sandbox process.
+ status syscall.WaitStatus
+
+ // statusMu protects status.
+ statusMu sync.Mutex
+}
+
+// Args is used to configure a new sandbox.
+type Args struct {
+ // ID is the sandbox unique identifier.
+ ID string
+
+ // Spec is the OCI spec that describes the container.
+ Spec *specs.Spec
+
+ // BundleDir is the directory containing the container bundle.
+ BundleDir string
+
+ // ConsoleSocket is the path to a unix domain socket that will receive
+ // the console FD. It may be empty.
+ ConsoleSocket string
+
+ // UserLog is the filename to send user-visible logs to. It may be empty.
+ UserLog string
+
+ // IOFiles is the list of files that connect to a 9P endpoint for the mount
+ // points using Gofers. They must be in the same order as mounts appear in
+ // the spec.
+ IOFiles []*os.File
+
+ // MountsFile is a file containing mount information from the spec. It's
+ // equivalent to the mounts from the spec, except that all paths have been
+ // resolved to their final absolute location.
+ MountsFile *os.File
+
+ // Cgroup is the cgroup that the sandbox is part of.
+ Cgroup *cgroup.Cgroup
+
+ // Attached indicates that the sandbox lifecycle is attached to the caller.
+ // If the caller exits, the sandbox should exit too.
+ Attached bool
+}
+
+// New creates the sandbox process. The caller must call Destroy() on the
+// sandbox.
+func New(conf *boot.Config, args *Args) (*Sandbox, error) {
+ s := &Sandbox{ID: args.ID, Cgroup: args.Cgroup}
+ // The Cleanup object cleans up partially created sandboxes when an error
+ // occurs. Any errors occurring during cleanup itself are ignored.
+ c := cleanup.Make(func() {
+ if err := s.destroy(); err != nil {
+ log.Warningf("error destroying sandbox: %v", err)
+ }
+ })
+ defer c.Clean()
+
+ // Create pipe to synchronize when sandbox process has been booted.
+ clientSyncFile, sandboxSyncFile, err := os.Pipe()
+ if err != nil {
+ return nil, fmt.Errorf("creating pipe for sandbox %q: %v", s.ID, err)
+ }
+ defer clientSyncFile.Close()
+
+ // Create the sandbox process.
+ err = s.createSandboxProcess(conf, args, sandboxSyncFile)
+ // sandboxSyncFile has to be closed to be able to detect when the sandbox
+ // process exits unexpectedly.
+ sandboxSyncFile.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ // Wait until the sandbox has booted.
+ b := make([]byte, 1)
+ if l, err := clientSyncFile.Read(b); err != nil || l != 1 {
+ err := fmt.Errorf("waiting for sandbox to start: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), io.EOF.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return nil, fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return nil, err
+ }
+
+ c.Release()
+ return s, nil
+}
+
+// CreateContainer creates a non-root container inside the sandbox.
+func (s *Sandbox) CreateContainer(cid string) error {
+ log.Debugf("Create non-root container %q in sandbox %q, PID: %d", cid, s.ID, s.Pid)
+ sandboxConn, err := s.sandboxConnect()
+ if err != nil {
+ return fmt.Errorf("couldn't connect to sandbox: %v", err)
+ }
+ defer sandboxConn.Close()
+
+ if err := sandboxConn.Call(boot.ContainerCreate, &cid, nil); err != nil {
+ return fmt.Errorf("creating non-root container %q: %v", cid, err)
+ }
+ return nil
+}
+
+// StartRoot starts running the root container process inside the sandbox.
+func (s *Sandbox) StartRoot(spec *specs.Spec, conf *boot.Config) error {
+ log.Debugf("Start root sandbox %q, PID: %d", s.ID, s.Pid)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ // Configure the network.
+ if err := setupNetwork(conn, s.Pid, spec, conf); err != nil {
+ return fmt.Errorf("setting up network: %v", err)
+ }
+
+ // Send a message to the sandbox control server to start the root
+ // container.
+ if err := conn.Call(boot.RootContainerStart, &s.ID, nil); err != nil {
+ return fmt.Errorf("starting root container: %v", err)
+ }
+
+ return nil
+}
+
+// StartContainer starts running a non-root container inside the sandbox.
+func (s *Sandbox) StartContainer(spec *specs.Spec, conf *boot.Config, cid string, goferFiles []*os.File) error {
+ for _, f := range goferFiles {
+ defer f.Close()
+ }
+
+ log.Debugf("Start non-root container %q in sandbox %q, PID: %d", cid, s.ID, s.Pid)
+ sandboxConn, err := s.sandboxConnect()
+ if err != nil {
+ return fmt.Errorf("couldn't connect to sandbox: %v", err)
+ }
+ defer sandboxConn.Close()
+
+ // The payload must contain stdin/stdout/stderr followed by gofer
+ // files.
+ files := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, goferFiles...)
+ // Start running the container.
+ args := boot.StartArgs{
+ Spec: spec,
+ Conf: conf,
+ CID: cid,
+ FilePayload: urpc.FilePayload{Files: files},
+ }
+ if err := sandboxConn.Call(boot.ContainerStart, &args, nil); err != nil {
+ return fmt.Errorf("starting non-root container %v: %v", spec.Process.Args, err)
+ }
+ return nil
+}
+
+// Restore sends the restore call for a container in the sandbox.
+func (s *Sandbox) Restore(cid string, spec *specs.Spec, conf *boot.Config, filename string) error {
+ log.Debugf("Restore sandbox %q", s.ID)
+
+ rf, err := os.Open(filename)
+ if err != nil {
+ return fmt.Errorf("opening restore file %q failed: %v", filename, err)
+ }
+ defer rf.Close()
+
+ opt := boot.RestoreOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{rf},
+ },
+ SandboxID: s.ID,
+ }
+
+ // If the platform needs a device FD we must pass it in.
+ if deviceFile, err := deviceFileForPlatform(conf.Platform); err != nil {
+ return err
+ } else if deviceFile != nil {
+ defer deviceFile.Close()
+ opt.FilePayload.Files = append(opt.FilePayload.Files, deviceFile)
+ }
+
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ // Configure the network.
+ if err := setupNetwork(conn, s.Pid, spec, conf); err != nil {
+ return fmt.Errorf("setting up network: %v", err)
+ }
+
+ // Restore the container and start the root container.
+ if err := conn.Call(boot.ContainerRestore, &opt, nil); err != nil {
+ return fmt.Errorf("restoring container %q: %v", cid, err)
+ }
+
+ return nil
+}
+
+// Processes retrieves the list of processes and associated metadata for a
+// given container in this sandbox.
+func (s *Sandbox) Processes(cid string) ([]*control.Process, error) {
+ log.Debugf("Getting processes for container %q in sandbox %q", cid, s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ var pl []*control.Process
+ if err := conn.Call(boot.ContainerProcesses, &cid, &pl); err != nil {
+ return nil, fmt.Errorf("retrieving process data from sandbox: %v", err)
+ }
+ return pl, nil
+}
+
+// Execute runs the specified command in the container. It returns the PID of
+// the newly created process.
+func (s *Sandbox) Execute(args *control.ExecArgs) (int32, error) {
+ log.Debugf("Executing new process in container %q in sandbox %q", args.ContainerID, s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return 0, s.connError(err)
+ }
+ defer conn.Close()
+
+ // Send a message to the sandbox control server to start the container.
+ var pid int32
+ if err := conn.Call(boot.ContainerExecuteAsync, args, &pid); err != nil {
+ return 0, fmt.Errorf("executing command %q in sandbox: %v", args, err)
+ }
+ return pid, nil
+}
+
+// Event retrieves stats about the sandbox such as memory and CPU utilization.
+func (s *Sandbox) Event(cid string) (*boot.Event, error) {
+ log.Debugf("Getting events for container %q in sandbox %q", cid, s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ var e boot.Event
+ // TODO(b/129292330): Pass in the container id (cid) here. The sandbox
+ // should return events only for that container.
+ if err := conn.Call(boot.ContainerEvent, nil, &e); err != nil {
+ return nil, fmt.Errorf("retrieving event data from sandbox: %v", err)
+ }
+ e.ID = cid
+ return &e, nil
+}
+
+func (s *Sandbox) sandboxConnect() (*urpc.Client, error) {
+ log.Debugf("Connecting to sandbox %q", s.ID)
+ conn, err := client.ConnectTo(boot.ControlSocketAddr(s.ID))
+ if err != nil {
+ return nil, s.connError(err)
+ }
+ return conn, nil
+}
+
+func (s *Sandbox) connError(err error) error {
+ return fmt.Errorf("connecting to control server at PID %d: %v", s.Pid, err)
+}
+
+// createSandboxProcess starts the sandbox as a subprocess by running the "boot"
+// command, passing in the bundle dir.
+func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncFile *os.File) error {
+ // nextFD is used to get unused FDs that we can pass to the sandbox. It
+ // starts at 3 because 0, 1, and 2 are taken by stdin/out/err.
+ nextFD := 3
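+ // Note that entry i of cmd.ExtraFiles becomes FD 3+i in the child, which
+ // is why nextFD is incremented in lockstep with every append below.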
+
+ binPath := specutils.ExePath
+ cmd := exec.Command(binPath, conf.ToFlags()...)
+ cmd.SysProcAttr = &syscall.SysProcAttr{}
+
+ // Open the log files to pass to the sandbox as FDs.
+ //
+ // These flags must come BEFORE the "boot" command in cmd.Args.
+ if conf.LogFilename != "" {
+ logFile, err := os.OpenFile(conf.LogFilename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("opening log file %q: %v", conf.LogFilename, err)
+ }
+ defer logFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, logFile)
+ cmd.Args = append(cmd.Args, "--log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+ if conf.DebugLog != "" {
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test)
+ if err != nil {
+ return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
+ }
+ defer debugLogFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, debugLogFile)
+ cmd.Args = append(cmd.Args, "--debug-log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+ if conf.PanicLog != "" {
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+
+ panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test)
+ if err != nil {
+ return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err)
+ }
+ defer panicLogFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile)
+ cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ // Add the "boot" command to the args.
+ //
+ // All flags after this must be for the boot command.
+ cmd.Args = append(cmd.Args, "boot", "--bundle="+args.BundleDir)
+
+ // Create a socket for the control server and donate it to the sandbox.
+ addr := boot.ControlSocketAddr(s.ID)
+ sockFD, err := server.CreateSocket(addr)
+ log.Infof("Creating sandbox process with addr: %s", addr[1:]) // skip "\00".
+ if err != nil {
+ return fmt.Errorf("creating control server socket for sandbox %q: %v", s.ID, err)
+ }
+ controllerFile := os.NewFile(uintptr(sockFD), "control_server_socket")
+ defer controllerFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, controllerFile)
+ cmd.Args = append(cmd.Args, "--controller-fd="+strconv.Itoa(nextFD))
+ nextFD++
+
+ defer args.MountsFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, args.MountsFile)
+ cmd.Args = append(cmd.Args, "--mounts-fd="+strconv.Itoa(nextFD))
+ nextFD++
+
+ specFile, err := specutils.OpenSpec(args.BundleDir)
+ if err != nil {
+ return err
+ }
+ defer specFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, specFile)
+ cmd.Args = append(cmd.Args, "--spec-fd="+strconv.Itoa(nextFD))
+ nextFD++
+
+ cmd.ExtraFiles = append(cmd.ExtraFiles, startSyncFile)
+ cmd.Args = append(cmd.Args, "--start-sync-fd="+strconv.Itoa(nextFD))
+ nextFD++
+
+ // If there is a gofer, send all socket ends to the sandbox.
+ for _, f := range args.IOFiles {
+ defer f.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+ cmd.Args = append(cmd.Args, "--io-fds="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ gPlatform, err := platform.Lookup(conf.Platform)
+ if err != nil {
+ return err
+ }
+
+ if deviceFile, err := gPlatform.OpenDevice(); err != nil {
+ return fmt.Errorf("opening device file for platform %q: %v", gPlatform, err)
+ } else if deviceFile != nil {
+ defer deviceFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, deviceFile)
+ cmd.Args = append(cmd.Args, "--device-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ // TODO(b/151157106): syscall tests fail by timeout if asyncpreemptoff
+ // isn't set.
+ if conf.Platform == "kvm" {
+ cmd.Env = append(cmd.Env, "GODEBUG=asyncpreemptoff=1")
+ }
+
+ // The current process' stdio must be passed to the application via the
+ // --stdio-fds flag. The stdio of the sandbox process itself must not
+ // be connected to the same FDs, otherwise we risk leaking sandbox
+ // errors to the application, so we set the sandbox stdio to nil,
+ // causing them to read/write from the null device.
+ cmd.Stdin = nil
+ cmd.Stdout = nil
+ cmd.Stderr = nil
+
+ // If the console control socket file is provided, then create a new
+ // pty master/slave pair and set the TTY on the sandbox process.
+ if args.ConsoleSocket != "" {
+ cmd.Args = append(cmd.Args, "--console=true")
+
+ // console.NewWithSocket will send the master on the given
+ // socket, and return the slave.
+ tty, err := console.NewWithSocket(args.ConsoleSocket)
+ if err != nil {
+ return fmt.Errorf("setting up console with socket %q: %v", args.ConsoleSocket, err)
+ }
+ defer tty.Close()
+
+ // Set the TTY as a controlling TTY on the sandbox process.
+ cmd.SysProcAttr.Setctty = true
+ // The Ctty FD must be the FD in the child process's FD table,
+ // which will be nextFD in this case.
+ // See https://github.com/golang/go/issues/29458.
+ cmd.SysProcAttr.Ctty = nextFD
+
+ // Pass the tty as all stdio fds to sandbox.
+ for i := 0; i < 3; i++ {
+ cmd.ExtraFiles = append(cmd.ExtraFiles, tty)
+ cmd.Args = append(cmd.Args, "--stdio-fds="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.Debug {
+ // If debugging, send the boot process stdio to the
+ // TTY, so that it is easier to find.
+ cmd.Stdin = tty
+ cmd.Stdout = tty
+ cmd.Stderr = tty
+ }
+ } else {
+ // If not using a console, pass our current stdio as the
+ // container stdio via flags.
+ for _, f := range []*os.File{os.Stdin, os.Stdout, os.Stderr} {
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+ cmd.Args = append(cmd.Args, "--stdio-fds="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.Debug {
+ // If debugging, send the boot process stdio to this
+ // process' stdio, so that it is easier to find.
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ }
+ }
+
+ // Detach from this session, otherwise cmd will get SIGHUP and SIGCONT
+ // when re-parented.
+ cmd.SysProcAttr.Setsid = true
+
+ // nss is the set of namespaces to join or create before starting the sandbox
+ // process. Mount, IPC and UTS namespaces from the host are not used as they
+ // are virtualized inside the sandbox. Be paranoid and run inside an empty
+ // namespace for these. Don't unshare cgroup because sandbox is added to a
+ // cgroup in the caller's namespace.
+ log.Infof("Sandbox will be started in new mount, IPC and UTS namespaces")
+ nss := []specs.LinuxNamespace{
+ {Type: specs.IPCNamespace},
+ {Type: specs.MountNamespace},
+ {Type: specs.UTSNamespace},
+ }
+
+ if gPlatform.Requirements().RequiresCurrentPIDNS {
+ // TODO(b/75837838): Also set a new PID namespace so that we limit
+ // access to other host processes.
+ log.Infof("Sandbox will be started in the current PID namespace")
+ } else {
+ log.Infof("Sandbox will be started in a new PID namespace")
+ nss = append(nss, specs.LinuxNamespace{Type: specs.PIDNamespace})
+ cmd.Args = append(cmd.Args, "--pidns=true")
+ }
+
+ // Join the network namespace if network is enabled. The sandbox talks
+ // directly to the host network, which may have been configured in the
+ // namespace.
+ if ns, ok := specutils.GetNS(specs.NetworkNamespace, args.Spec); ok && conf.Network != boot.NetworkNone {
+ log.Infof("Sandbox will be started in the container's network namespace: %+v", ns)
+ nss = append(nss, ns)
+ } else if conf.Network == boot.NetworkHost {
+ log.Infof("Sandbox will be started in the host network namespace")
+ } else {
+ log.Infof("Sandbox will be started in new network namespace")
+ nss = append(nss, specs.LinuxNamespace{Type: specs.NetworkNamespace})
+ }
+
+ // The user namespace depends on the network type. Host network requires
+ // running inside the user namespace specified in the spec, or the current
+ // namespace if none is configured.
+ if conf.Network == boot.NetworkHost {
+ if userns, ok := specutils.GetNS(specs.UserNamespace, args.Spec); ok {
+ log.Infof("Sandbox will be started in container's user namespace: %+v", userns)
+ nss = append(nss, userns)
+ specutils.SetUIDGIDMappings(cmd, args.Spec)
+ } else {
+ log.Infof("Sandbox will be started in the current user namespace")
+ }
+ // When running in the caller's defined user namespace, apply the same
+ // capabilities to the sandbox process to ensure it abides by the same
+ // rules.
+ cmd.Args = append(cmd.Args, "--apply-caps=true")
+
+ // If we have CAP_SYS_ADMIN, we can create an empty chroot and
+ // bind-mount the executable inside it.
+ if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ log.Warningf("Running sandbox in test mode without chroot. This is only safe in tests!")
+
+ } else if specutils.HasCapabilities(capability.CAP_SYS_ADMIN) {
+ log.Infof("Sandbox will be started in minimal chroot")
+ cmd.Args = append(cmd.Args, "--setup-root")
+ } else {
+ return fmt.Errorf("can't run sandbox process in minimal chroot since we don't have CAP_SYS_ADMIN")
+ }
+ } else {
+ // If we have CAP_SETUID and CAP_SETGID, then we can also run
+ // as user nobody.
+ if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ log.Warningf("Running sandbox in test mode as current user (uid=%d gid=%d). This is only safe in tests!", os.Getuid(), os.Getgid())
+ log.Warningf("Running sandbox in test mode without chroot. This is only safe in tests!")
+ } else if specutils.HasCapabilities(capability.CAP_SETUID, capability.CAP_SETGID) {
+ log.Infof("Sandbox will be started in new user namespace")
+ nss = append(nss, specs.LinuxNamespace{Type: specs.UserNamespace})
+ cmd.Args = append(cmd.Args, "--setup-root")
+
+ const nobody = 65534
+ if conf.Rootless {
+ log.Infof("Rootless mode: sandbox will run as nobody inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
+ cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
+ {
+ ContainerID: nobody,
+ HostID: os.Getuid(),
+ Size: 1,
+ },
+ }
+ cmd.SysProcAttr.GidMappings = []syscall.SysProcIDMap{
+ {
+ ContainerID: nobody,
+ HostID: os.Getgid(),
+ Size: 1,
+ },
+ }
+
+ } else {
+ // Map nobody in the new namespace to nobody in the parent namespace.
+ //
+ // A sandbox process will construct an empty
+ // root for itself, so it has to have
+ // CAP_SYS_ADMIN and CAP_SYS_CHROOT capabilities.
+ cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
+ {
+ ContainerID: nobody,
+ HostID: nobody,
+ Size: 1,
+ },
+ }
+ cmd.SysProcAttr.GidMappings = []syscall.SysProcIDMap{
+ {
+ ContainerID: nobody,
+ HostID: nobody,
+ Size: 1,
+ },
+ }
+ }
+
+ // Set credentials to run as user and group nobody.
+ cmd.SysProcAttr.Credential = &syscall.Credential{Uid: nobody, Gid: nobody}
+ cmd.SysProcAttr.AmbientCaps = append(cmd.SysProcAttr.AmbientCaps, uintptr(capability.CAP_SYS_ADMIN), uintptr(capability.CAP_SYS_CHROOT))
+ } else {
+ return fmt.Errorf("can't run sandbox process as user nobody since we don't have CAP_SETUID or CAP_SETGID")
+ }
+ }
+
+ cmd.Args[0] = "runsc-sandbox"
+
+ if s.Cgroup != nil {
+ cpuNum, err := s.Cgroup.NumCPU()
+ if err != nil {
+ return fmt.Errorf("getting cpu count from cgroups: %v", err)
+ }
+ if conf.CPUNumFromQuota {
+ // Dropping below 2 CPUs can trigger applications to disable
+ // locks, which can lead to hard-to-debug errors, so just
+ // leave two cores as a reasonable default.
+ const minCPUs = 2
+
+ quota, err := s.Cgroup.CPUQuota()
+ if err != nil {
+ return fmt.Errorf("getting cpu qouta from cgroups: %v", err)
+ }
+ if n := int(math.Ceil(quota)); n > 0 {
+ if n < minCPUs {
+ n = minCPUs
+ }
+ if n < cpuNum {
+ // Only lower the cpu number.
+ cpuNum = n
+ }
+ }
+ }
+ cmd.Args = append(cmd.Args, "--cpu-num", strconv.Itoa(cpuNum))
+
+ mem, err := s.Cgroup.MemoryLimit()
+ if err != nil {
+ return fmt.Errorf("getting memory limit from cgroups: %v", err)
+ }
+ // When memory limit is unset, a "large" number is returned. In that case,
+ // just stick with the default.
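+ // (0x7ffffffffffff000 is the page-aligned maximum that cgroup v1
+ // reports when memory.limit_in_bytes is unset.)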
+ if mem < 0x7ffffffffffff000 {
+ cmd.Args = append(cmd.Args, "--total-memory", strconv.FormatUint(mem, 10))
+ }
+ }
+
+ if args.UserLog != "" {
+ f, err := os.OpenFile(args.UserLog, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0664)
+ if err != nil {
+ return fmt.Errorf("opening compat log file: %v", err)
+ }
+ defer f.Close()
+
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+ cmd.Args = append(cmd.Args, "--user-log-fd", strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if args.Attached {
+ // Kill sandbox if parent process exits in attached mode.
+ cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
+ // Tells boot that any process it creates must have pdeathsig set.
+ cmd.Args = append(cmd.Args, "--attached")
+ }
+
+ // Add container as the last argument.
+ cmd.Args = append(cmd.Args, s.ID)
+
+ // Log the FDs we are donating to the sandbox process.
+ for i, f := range cmd.ExtraFiles {
+ log.Debugf("Donating FD %d: %q", i+3, f.Name())
+ }
+
+ log.Debugf("Starting sandbox: %s %v", binPath, cmd.Args)
+ log.Debugf("SysProcAttr: %+v", cmd.SysProcAttr)
+ if err := specutils.StartInNS(cmd, nss); err != nil {
+ err := fmt.Errorf("starting sandbox: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), syscall.EACCES.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return err
+ }
+ s.child = true
+ s.Pid = cmd.Process.Pid
+ log.Infof("Sandbox started, PID: %d", s.Pid)
+
+ return nil
+}
+
+// Wait waits for the containerized process to exit, and returns its WaitStatus.
+func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) {
+ log.Debugf("Waiting for container %q in sandbox %q", cid, s.ID)
+ var ws syscall.WaitStatus
+
+ if conn, err := s.sandboxConnect(); err != nil {
+ // The sandbox may have exited before we had a chance to wait on
+ // it.
+ log.Warningf("Wait on container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err)
+ } else {
+ defer conn.Close()
+ // Try the Wait RPC to the sandbox.
+ err = conn.Call(boot.ContainerWait, &cid, &ws)
+ if err == nil {
+ // It worked!
+ return ws, nil
+ }
+ // The sandbox may have exited after we connected, but before
+ // or during the Wait RPC.
+ log.Warningf("Wait RPC to container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err)
+ }
+
+ // The sandbox may have already exited, or exited while handling the
+ // Wait RPC. The best we can do is ask Linux what the sandbox exit
+ // status was, since in most cases that will be the same as the
+ // container exit status.
+ if err := s.waitForStopped(); err != nil {
+ return ws, err
+ }
+ if !s.child {
+ return ws, fmt.Errorf("sandbox no longer running and its exit status is unavailable")
+ }
+ return s.status, nil
+}
+
+// WaitPID waits for process 'pid' in the container's sandbox and returns its
+// WaitStatus.
+func (s *Sandbox) WaitPID(cid string, pid int32) (syscall.WaitStatus, error) {
+ log.Debugf("Waiting for PID %d in sandbox %q", pid, s.ID)
+ var ws syscall.WaitStatus
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return ws, err
+ }
+ defer conn.Close()
+
+ args := &boot.WaitPIDArgs{
+ PID: pid,
+ CID: cid,
+ }
+ if err := conn.Call(boot.ContainerWaitPID, args, &ws); err != nil {
+ return ws, fmt.Errorf("waiting on PID %d in sandbox %q: %v", pid, s.ID, err)
+ }
+ return ws, nil
+}
+
+// IsRootContainer returns true if the specified container ID belongs to the
+// root container.
+func (s *Sandbox) IsRootContainer(cid string) bool {
+ return s.ID == cid
+}
+
+// destroy frees all resources associated with the sandbox. It fails fast and
+// is idempotent.
+func (s *Sandbox) destroy() error {
+ log.Debugf("Destroy sandbox %q", s.ID)
+ if s.Pid != 0 {
+ log.Debugf("Killing sandbox %q", s.ID)
+ if err := syscall.Kill(s.Pid, syscall.SIGKILL); err != nil && err != syscall.ESRCH {
+ return fmt.Errorf("killing sandbox %q PID %q: %v", s.ID, s.Pid, err)
+ }
+ if err := s.waitForStopped(); err != nil {
+ return fmt.Errorf("waiting sandbox %q stop: %v", s.ID, err)
+ }
+ }
+
+ return nil
+}
+
+// SignalContainer sends the signal to a container in the sandbox. If all is
+// true and signal is SIGKILL, then it waits for all processes to exit before
+// returning.
+func (s *Sandbox) SignalContainer(cid string, sig syscall.Signal, all bool) error {
+ log.Debugf("Signal sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ mode := boot.DeliverToProcess
+ if all {
+ mode = boot.DeliverToAllProcesses
+ }
+
+ args := boot.SignalArgs{
+ CID: cid,
+ Signo: int32(sig),
+ Mode: mode,
+ }
+ if err := conn.Call(boot.ContainerSignal, &args, nil); err != nil {
+ return fmt.Errorf("signaling container %q: %v", cid, err)
+ }
+ return nil
+}
+
+// SignalProcess sends the signal to a particular process in the container. If
+// fgProcess is true, then the signal is sent to the foreground process group
+// in the same session that PID belongs to. This is only valid if the process
+// is attached to a host TTY.
+func (s *Sandbox) SignalProcess(cid string, pid int32, sig syscall.Signal, fgProcess bool) error {
+ log.Debugf("Signal sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ mode := boot.DeliverToProcess
+ if fgProcess {
+ mode = boot.DeliverToForegroundProcessGroup
+ }
+
+ args := boot.SignalArgs{
+ CID: cid,
+ Signo: int32(sig),
+ PID: pid,
+ Mode: mode,
+ }
+ if err := conn.Call(boot.ContainerSignal, &args, nil); err != nil {
+ return fmt.Errorf("signaling container %q PID %d: %v", cid, pid, err)
+ }
+ return nil
+}
+
+// Checkpoint sends the checkpoint call for a container in the sandbox.
+// The statefile will be written to f.
+func (s *Sandbox) Checkpoint(cid string, f *os.File) error {
+ log.Debugf("Checkpoint sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opt := control.SaveOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+
+ if err := conn.Call(boot.ContainerCheckpoint, &opt, nil); err != nil {
+ return fmt.Errorf("checkpointing container %q: %v", cid, err)
+ }
+ return nil
+}
+
+// Pause sends the pause call for a container in the sandbox.
+func (s *Sandbox) Pause(cid string) error {
+ log.Debugf("Pause sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.Call(boot.ContainerPause, nil, nil); err != nil {
+ return fmt.Errorf("pausing container %q: %v", cid, err)
+ }
+ return nil
+}
+
+// Resume sends the resume call for a container in the sandbox.
+func (s *Sandbox) Resume(cid string) error {
+ log.Debugf("Resume sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.Call(boot.ContainerResume, nil, nil); err != nil {
+ return fmt.Errorf("resuming container %q: %v", cid, err)
+ }
+ return nil
+}
+
+// IsRunning returns true if the sandbox or gofer process is running.
+func (s *Sandbox) IsRunning() bool {
+ if s.Pid != 0 {
+ // Send a signal 0 to the sandbox process.
+ if err := syscall.Kill(s.Pid, 0); err == nil {
+ // Succeeded, process is running.
+ return true
+ }
+ }
+ return false
+}
+
+// Stacks collects and returns all stacks for the sandbox.
+func (s *Sandbox) Stacks() (string, error) {
+ log.Debugf("Stacks sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return "", err
+ }
+ defer conn.Close()
+
+ var stacks string
+ if err := conn.Call(boot.SandboxStacks, nil, &stacks); err != nil {
+ return "", fmt.Errorf("getting sandbox %q stacks: %v", s.ID, err)
+ }
+ return stacks, nil
+}
+
+// HeapProfile writes a heap profile to the given file.
+func (s *Sandbox) HeapProfile(f *os.File) error {
+ log.Debugf("Heap profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.HeapProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q heap profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// StartCPUProfile starts writing a CPU profile to the given file.
+func (s *Sandbox) StartCPUProfile(f *os.File) error {
+ log.Debugf("CPU profile start %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.StartCPUProfile, &opts, nil); err != nil {
+ return fmt.Errorf("starting sandbox %q CPU profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// StopCPUProfile stops a previously started CPU profile.
+func (s *Sandbox) StopCPUProfile() error {
+ log.Debugf("CPU profile stop %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.Call(boot.StopCPUProfile, nil, nil); err != nil {
+ return fmt.Errorf("stopping sandbox %q CPU profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// GoroutineProfile writes a goroutine profile to the given file.
+func (s *Sandbox) GoroutineProfile(f *os.File) error {
+ log.Debugf("Goroutine profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.GoroutineProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q goroutine profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// BlockProfile writes a block profile to the given file.
+func (s *Sandbox) BlockProfile(f *os.File) error {
+ log.Debugf("Block profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// MutexProfile writes a mutex profile to the given file.
+func (s *Sandbox) MutexProfile(f *os.File) error {
+ log.Debugf("Mutex profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// StartTrace starts writing an execution trace to the given file.
+func (s *Sandbox) StartTrace(f *os.File) error {
+ log.Debugf("Trace start %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.StartTrace, &opts, nil); err != nil {
+ return fmt.Errorf("starting sandbox %q trace: %v", s.ID, err)
+ }
+ return nil
+}
+
+// StopTrace stops a previously started trace.
+func (s *Sandbox) StopTrace() error {
+ log.Debugf("Trace stop %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.Call(boot.StopTrace, nil, nil); err != nil {
+ return fmt.Errorf("stopping sandbox %q trace: %v", s.ID, err)
+ }
+ return nil
+}
+
+// ChangeLogging changes logging options.
+func (s *Sandbox) ChangeLogging(args control.LoggingArgs) error {
+ log.Debugf("Change logging start %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.Call(boot.ChangeLogging, &args, nil); err != nil {
+ return fmt.Errorf("changing sandbox %q logging: %v", s.ID, err)
+ }
+ return nil
+}
+
+// DestroyContainer destroys the given container. If it is the root container,
+// then the entire sandbox is destroyed.
+func (s *Sandbox) DestroyContainer(cid string) error {
+ if err := s.destroyContainer(cid); err != nil {
+ // If the sandbox isn't running, the container has already been destroyed,
+ // ignore the error in this case.
+ if s.IsRunning() {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *Sandbox) destroyContainer(cid string) error {
+ if s.IsRootContainer(cid) {
+ log.Debugf("Destroying root container %q by destroying sandbox", cid)
+ return s.destroy()
+ }
+
+ log.Debugf("Destroying container %q in sandbox %q", cid, s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+ if err := conn.Call(boot.ContainerDestroy, &cid, nil); err != nil {
+ return fmt.Errorf("destroying container %q: %v", cid, err)
+ }
+ return nil
+}
+
+func (s *Sandbox) waitForStopped() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
+ op := func() error {
+ if s.child {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+ if s.Pid == 0 {
+ return nil
+ }
+ // The sandbox process is a child of the current process,
+ // so we can wait on it and collect its zombie.
+ wpid, err := syscall.Wait4(int(s.Pid), &s.status, syscall.WNOHANG, nil)
+ if err != nil {
+ return fmt.Errorf("error waiting the sandbox process: %v", err)
+ }
+ if wpid == 0 {
+ return fmt.Errorf("sandbox is still running")
+ }
+ s.Pid = 0
+ } else if s.IsRunning() {
+ return fmt.Errorf("sandbox is still running")
+ }
+ return nil
+ }
+ return backoff.Retry(op, b)
+}
+
+// deviceFileForPlatform opens the device file for the given platform. If the
+// platform does not need a device file, then nil is returned.
+func deviceFileForPlatform(name string) (*os.File, error) {
+ p, err := platform.Lookup(name)
+ if err != nil {
+ return nil, err
+ }
+
+ f, err := p.OpenDevice()
+ if err != nil {
+ return nil, fmt.Errorf("opening device file for platform %q: %v", p, err)
+ }
+ return f, nil
+}
+
+// checkBinaryPermissions verifies that the required binary bits are set on
+// the runsc executable.
+func checkBinaryPermissions(conf *boot.Config) error {
+ // All platforms need the 'other' execute bit.
+ neededBits := os.FileMode(0001)
+ if conf.Platform == platforms.Ptrace {
+ // Ptrace needs the 'other' read bit.
+ neededBits |= os.FileMode(0004)
+ }
+
+ exePath, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("getting exe path: %v", err)
+ }
+
+ // Check the permissions of the runsc binary and return a helpful error
+ // if they don't match expectations.
+ info, err := os.Stat(exePath)
+ if err != nil {
+ return fmt.Errorf("stat file: %v", err)
+ }
+
+ if info.Mode().Perm()&neededBits != neededBits {
+ return fmt.Errorf(specutils.FaqErrorMsg("runsc-perms", fmt.Sprintf("%s does not have the correct permissions", exePath)))
+ }
+ return nil
+}
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
new file mode 100644
index 000000000..62d4f5113
--- /dev/null
+++ b/runsc/specutils/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "specutils",
+ srcs = [
+ "cri.go",
+ "fs.go",
+ "namespace.go",
+ "specutils.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/bits",
+ "//pkg/log",
+ "//pkg/sentry/kernel/auth",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "specutils_test",
+ size = "small",
+ srcs = ["specutils_test.go"],
+ library = ":specutils",
+ deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"],
+)
diff --git a/runsc/specutils/cri.go b/runsc/specutils/cri.go
new file mode 100644
index 000000000..9c5877cd5
--- /dev/null
+++ b/runsc/specutils/cri.go
@@ -0,0 +1,110 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package specutils
+
+import (
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const (
+ // ContainerdContainerTypeAnnotation is the OCI annotation set by
+ // containerd to indicate whether the container to create should have
+ // its own sandbox or a container within an existing sandbox.
+ ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type"
+ // ContainerdContainerTypeContainer is the container type value
+ // indicating the container should be created in an existing sandbox.
+ ContainerdContainerTypeContainer = "container"
+ // ContainerdContainerTypeSandbox is the container type value
+ // indicating the container should be created in a new sandbox.
+ ContainerdContainerTypeSandbox = "sandbox"
+
+ // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate
+ // which sandbox the container should be created in when the container
+ // is not the first container in the sandbox.
+ ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id"
+
+ // CRIOContainerTypeAnnotation is the OCI annotation set by
+ // CRI-O to indicate whether the container to create should have
+ // its own sandbox or a container within an existing sandbox.
+ CRIOContainerTypeAnnotation = "io.kubernetes.cri-o.ContainerType"
+
+ // CRIOContainerTypeContainer is the container type value
+ // indicating the container should be created in an existing sandbox.
+ CRIOContainerTypeContainer = "container"
+ // CRIOContainerTypeSandbox is the container type value
+ // indicating the container should be created in a new sandbox.
+ CRIOContainerTypeSandbox = "sandbox"
+
+ // CRIOSandboxIDAnnotation is the OCI annotation set to indicate
+ // which sandbox the container should be created in when the container
+ // is not the first container in the sandbox.
+ CRIOSandboxIDAnnotation = "io.kubernetes.cri-o.SandboxID"
+)
+
+// ContainerType represents the type of container requested by the calling container manager.
+type ContainerType int
+
+const (
+ // ContainerTypeUnspecified indicates that no known container type
+ // annotation was found in the spec.
+ ContainerTypeUnspecified ContainerType = iota
+ // ContainerTypeUnknown indicates that a container type was specified
+ // but is unknown to us.
+ ContainerTypeUnknown
+ // ContainerTypeSandbox indicates that the container should be run in a
+ // new sandbox.
+ ContainerTypeSandbox
+ // ContainerTypeContainer indicates that the container should be run in
+ // an existing sandbox.
+ ContainerTypeContainer
+)
+
+// SpecContainerType tries to determine the type of container specified by the
+// container manager using well-known container annotations.
+func SpecContainerType(spec *specs.Spec) ContainerType {
+ if t, ok := spec.Annotations[ContainerdContainerTypeAnnotation]; ok {
+ switch t {
+ case ContainerdContainerTypeSandbox:
+ return ContainerTypeSandbox
+ case ContainerdContainerTypeContainer:
+ return ContainerTypeContainer
+ default:
+ return ContainerTypeUnknown
+ }
+ }
+ if t, ok := spec.Annotations[CRIOContainerTypeAnnotation]; ok {
+ switch t {
+ case CRIOContainerTypeSandbox:
+ return ContainerTypeSandbox
+ case CRIOContainerTypeContainer:
+ return ContainerTypeContainer
+ default:
+ return ContainerTypeUnknown
+ }
+ }
+ return ContainerTypeUnspecified
+}
+
+// SandboxID returns the ID of the sandbox to join and whether an ID was found
+// in the spec.
+func SandboxID(spec *specs.Spec) (string, bool) {
+ if id, ok := spec.Annotations[ContainerdSandboxIDAnnotation]; ok {
+ return id, true
+ }
+ if id, ok := spec.Annotations[CRIOSandboxIDAnnotation]; ok {
+ return id, true
+ }
+ return "", false
+}
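+// As a sketch (hypothetical values), a containerd-managed container joining
+// an existing pod would carry annotations like:
+//
+//	"io.kubernetes.cri.container-type": "container"
+//	"io.kubernetes.cri.sandbox-id":     "abc123"
+//
+// for which SpecContainerType returns ContainerTypeContainer and SandboxID
+// returns ("abc123", true).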
diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go
new file mode 100644
index 000000000..138aa4dd1
--- /dev/null
+++ b/runsc/specutils/fs.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package specutils
+
+import (
+ "fmt"
+ "math/bits"
+ "path"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+type mapping struct {
+ set bool
+ val uint32
+}
+
+// optionsMap maps generic (non-propagation) OCI filesystem mount options to
+// mount(2) syscall flags.
+var optionsMap = map[string]mapping{
+ "acl": {set: true, val: syscall.MS_POSIXACL},
+ "async": {set: false, val: syscall.MS_SYNCHRONOUS},
+ "atime": {set: false, val: syscall.MS_NOATIME},
+ "bind": {set: true, val: syscall.MS_BIND},
+ "defaults": {set: true, val: 0},
+ "dev": {set: false, val: syscall.MS_NODEV},
+ "diratime": {set: false, val: syscall.MS_NODIRATIME},
+ "dirsync": {set: true, val: syscall.MS_DIRSYNC},
+ "exec": {set: false, val: syscall.MS_NOEXEC},
+ "noexec": {set: true, val: syscall.MS_NOEXEC},
+ "iversion": {set: true, val: syscall.MS_I_VERSION},
+ "loud": {set: false, val: syscall.MS_SILENT},
+ "mand": {set: true, val: syscall.MS_MANDLOCK},
+ "noacl": {set: false, val: syscall.MS_POSIXACL},
+ "noatime": {set: true, val: syscall.MS_NOATIME},
+ "nodev": {set: true, val: syscall.MS_NODEV},
+ "nodiratime": {set: true, val: syscall.MS_NODIRATIME},
+ "noiversion": {set: false, val: syscall.MS_I_VERSION},
+ "nomand": {set: false, val: syscall.MS_MANDLOCK},
+ "norelatime": {set: false, val: syscall.MS_RELATIME},
+ "nostrictatime": {set: false, val: syscall.MS_STRICTATIME},
+ "nosuid": {set: true, val: syscall.MS_NOSUID},
+ "rbind": {set: true, val: syscall.MS_BIND | syscall.MS_REC},
+ "relatime": {set: true, val: syscall.MS_RELATIME},
+ "remount": {set: true, val: syscall.MS_REMOUNT},
+ "ro": {set: true, val: syscall.MS_RDONLY},
+ "rw": {set: false, val: syscall.MS_RDONLY},
+ "silent": {set: true, val: syscall.MS_SILENT},
+ "strictatime": {set: true, val: syscall.MS_STRICTATIME},
+ "suid": {set: false, val: syscall.MS_NOSUID},
+ "sync": {set: true, val: syscall.MS_SYNCHRONOUS},
+}
+
+// propOptionsMap is similar to optionsMap, but it lists propagation options
+// that cannot be used together with other flags.
+var propOptionsMap = map[string]mapping{
+ "private": {set: true, val: syscall.MS_PRIVATE},
+ "rprivate": {set: true, val: syscall.MS_PRIVATE | syscall.MS_REC},
+ "slave": {set: true, val: syscall.MS_SLAVE},
+ "rslave": {set: true, val: syscall.MS_SLAVE | syscall.MS_REC},
+ "unbindable": {set: true, val: syscall.MS_UNBINDABLE},
+ "runbindable": {set: true, val: syscall.MS_UNBINDABLE | syscall.MS_REC},
+}
+
+// invalidOptions lists options that are not allowed.
+// - shared: sandbox must be isolated from the host. Propagating mount changes
+// from the sandbox to the host breaks the isolation.
+var invalidOptions = []string{"shared", "rshared"}
+
+// OptionsToFlags converts mount options to syscall flags.
+func OptionsToFlags(opts []string) uint32 {
+ return optionsToFlags(opts, optionsMap)
+}
+
+// PropOptionsToFlags converts propagation mount options to syscall flags.
+// Propagation options cannot be combined with other options and must be
+// handled separately.
+func PropOptionsToFlags(opts []string) uint32 {
+ return optionsToFlags(opts, propOptionsMap)
+}
+
+func optionsToFlags(opts []string, source map[string]mapping) uint32 {
+ var rv uint32
+ for _, opt := range opts {
+ if m, ok := source[opt]; ok {
+ if m.set {
+ rv |= m.val
+ } else {
+ rv &^= m.val // Clear the flag; mappings with set == false unset this bit.
+ }
+ }
+ }
+ return rv
+}
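+// For example, OptionsToFlags([]string{"ro", "noatime", "bind"}) yields
+// MS_RDONLY|MS_NOATIME|MS_BIND, while a trailing "rw" would clear the
+// MS_RDONLY bit again.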
+
+// validateMount validates that a spec mount is correct.
+func validateMount(mnt *specs.Mount) error {
+ if !path.IsAbs(mnt.Destination) {
+ return fmt.Errorf("Mount.Destination must be an absolute path: %v", mnt)
+ }
+ if mnt.Type == "bind" {
+ return ValidateMountOptions(mnt.Options)
+ }
+ return nil
+}
+
+// ValidateMountOptions validates that mount options are correct.
+func ValidateMountOptions(opts []string) error {
+ for _, o := range opts {
+ if ContainsStr(invalidOptions, o) {
+ return fmt.Errorf("mount option %q is not supported", o)
+ }
+ _, ok1 := optionsMap[o]
+ _, ok2 := propOptionsMap[o]
+ if !ok1 && !ok2 {
+ return fmt.Errorf("unknown mount option %q", o)
+ }
+ if err := validatePropagation(o); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// validateRootfsPropagation validates that a rootfs propagation option is
+// correct.
+ flags := PropOptionsToFlags([]string{opt})
+ if flags&(syscall.MS_SLAVE|syscall.MS_PRIVATE) == 0 {
+ return fmt.Errorf("root mount propagation option must specify private or slave: %q", opt)
+ }
+ return validatePropagation(opt)
+}
+
+func validatePropagation(opt string) error {
+ flags := PropOptionsToFlags([]string{opt})
+ exclusive := flags & (syscall.MS_SLAVE | syscall.MS_PRIVATE | syscall.MS_SHARED | syscall.MS_UNBINDABLE)
+ if bits.OnesCount32(exclusive) > 1 {
+ return fmt.Errorf("mount propagation options are mutually exclusive: %q", opt)
+ }
+ return nil
+}
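+// For example, validateRootfsPropagation("rprivate") and ("rslave") pass,
+// while "shared" maps to neither a private nor a slave bit and is rejected.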
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
new file mode 100644
index 000000000..23001d67c
--- /dev/null
+++ b/runsc/specutils/namespace.go
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package specutils
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "runtime"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/syndtr/gocapability/capability"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// nsCloneFlag returns the clone flag that can be used to set a namespace of
+// the given type.
+func nsCloneFlag(nst specs.LinuxNamespaceType) uintptr {
+ switch nst {
+ case specs.IPCNamespace:
+ return unix.CLONE_NEWIPC
+ case specs.MountNamespace:
+ return unix.CLONE_NEWNS
+ case specs.NetworkNamespace:
+ return unix.CLONE_NEWNET
+ case specs.PIDNamespace:
+ return unix.CLONE_NEWPID
+ case specs.UTSNamespace:
+ return unix.CLONE_NEWUTS
+ case specs.UserNamespace:
+ return unix.CLONE_NEWUSER
+ case specs.CgroupNamespace:
+ return unix.CLONE_NEWCGROUP
+ default:
+ panic(fmt.Sprintf("unknown namespace %v", nst))
+ }
+}
+
+// nsPath returns the path of the namespace for the current process and the
+// given namespace.
+func nsPath(nst specs.LinuxNamespaceType) string {
+ base := "/proc/self/ns"
+ switch nst {
+ case specs.CgroupNamespace:
+ return filepath.Join(base, "cgroup")
+ case specs.IPCNamespace:
+ return filepath.Join(base, "ipc")
+ case specs.MountNamespace:
+ return filepath.Join(base, "mnt")
+ case specs.NetworkNamespace:
+ return filepath.Join(base, "net")
+ case specs.PIDNamespace:
+ return filepath.Join(base, "pid")
+ case specs.UserNamespace:
+ return filepath.Join(base, "user")
+ case specs.UTSNamespace:
+ return filepath.Join(base, "uts")
+ default:
+ panic(fmt.Sprintf("unknown namespace %v", nst))
+ }
+}
+
+// GetNS returns true and the namespace with the given type from the slice of
+// namespaces in the spec. It returns false if the slice does not contain a
+// namespace with the type.
+func GetNS(nst specs.LinuxNamespaceType, s *specs.Spec) (specs.LinuxNamespace, bool) {
+ if s.Linux == nil {
+ return specs.LinuxNamespace{}, false
+ }
+ for _, ns := range s.Linux.Namespaces {
+ if ns.Type == nst {
+ return ns, true
+ }
+ }
+ return specs.LinuxNamespace{}, false
+}
+
+// FilterNS returns a slice of namespaces from the spec with types that match
+// those in the `filter` slice.
+func FilterNS(filter []specs.LinuxNamespaceType, s *specs.Spec) []specs.LinuxNamespace {
+ if s.Linux == nil {
+ return nil
+ }
+ var out []specs.LinuxNamespace
+ for _, nst := range filter {
+ if ns, ok := GetNS(nst, s); ok {
+ out = append(out, ns)
+ }
+ }
+ return out
+}
+
+// setNS sets the namespace of the given type. It must be called with the OS
+// thread locked.
+func setNS(fd, nsType uintptr) error {
+ if _, _, err := syscall.RawSyscall(unix.SYS_SETNS, fd, nsType, 0); err != 0 {
+ return err
+ }
+ return nil
+}
+
+// ApplyNS applies the namespace on the current thread and returns a function
+// that will restore the namespace to the original value.
+//
+// Preconditions: Must be called with the OS thread locked.
+func ApplyNS(ns specs.LinuxNamespace) (func(), error) {
+ log.Infof("Applying namespace %v at path %q", ns.Type, ns.Path)
+ newNS, err := os.Open(ns.Path)
+ if err != nil {
+ return nil, fmt.Errorf("error opening %q: %v", ns.Path, err)
+ }
+ defer newNS.Close()
+
+ // Store current namespace to restore back.
+ curPath := nsPath(ns.Type)
+ oldNS, err := os.Open(curPath)
+ if err != nil {
+ return nil, fmt.Errorf("error opening %q: %v", curPath, err)
+ }
+
+ // Set the namespace to the one requested and set up a function to restore it.
+ flag := nsCloneFlag(ns.Type)
+ if err := setNS(newNS.Fd(), flag); err != nil {
+ oldNS.Close()
+ return nil, fmt.Errorf("error setting namespace of type %v and path %q: %v", ns.Type, ns.Path, err)
+ }
+ return func() {
+ log.Infof("Restoring namespace %v", ns.Type)
+ defer oldNS.Close()
+ if err := setNS(oldNS.Fd(), flag); err != nil {
+ panic(fmt.Sprintf("error restoring namespace: of type %v: %v", ns.Type, err))
+ }
+ }, nil
+}
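+// A minimal usage sketch (hypothetical namespace path):
+//
+//	runtime.LockOSThread()
+//	defer runtime.UnlockOSThread()
+//	restore, err := ApplyNS(specs.LinuxNamespace{
+//		Type: specs.NetworkNamespace,
+//		Path: "/proc/1234/ns/net",
+//	})
+//	if err != nil {
+//		return err
+//	}
+//	defer restore()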
+
+// StartInNS joins or creates the given namespaces and calls cmd.Start before
+// restoring the namespaces to the original values.
+func StartInNS(cmd *exec.Cmd, nss []specs.LinuxNamespace) error {
+ // We are about to set up namespaces, which requires the OS thread to be
+ // locked so that Go doesn't switch us to another thread out from under us.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ if cmd.SysProcAttr == nil {
+ cmd.SysProcAttr = &syscall.SysProcAttr{}
+ }
+
+ for _, ns := range nss {
+ if ns.Path == "" {
+ // No path. Just set a flag to create a new namespace.
+ cmd.SysProcAttr.Cloneflags |= nsCloneFlag(ns.Type)
+ continue
+ }
+ // Join the given namespace, and restore the current namespace
+ // before exiting.
+ restoreNS, err := ApplyNS(ns)
+ if err != nil {
+ return err
+ }
+ defer restoreNS()
+ }
+
+ return cmd.Start()
+}
+
+// SetUIDGIDMappings sets the given uid/gid mappings from the spec on the cmd.
+func SetUIDGIDMappings(cmd *exec.Cmd, s *specs.Spec) {
+ if s.Linux == nil {
+ return
+ }
+ if cmd.SysProcAttr == nil {
+ cmd.SysProcAttr = &syscall.SysProcAttr{}
+ }
+ for _, idMap := range s.Linux.UIDMappings {
+ log.Infof("Mapping host uid %d to container uid %d (size=%d)", idMap.HostID, idMap.ContainerID, idMap.Size)
+ cmd.SysProcAttr.UidMappings = append(cmd.SysProcAttr.UidMappings, syscall.SysProcIDMap{
+ ContainerID: int(idMap.ContainerID),
+ HostID: int(idMap.HostID),
+ Size: int(idMap.Size),
+ })
+ }
+ for _, idMap := range s.Linux.GIDMappings {
+ log.Infof("Mapping host gid %d to container gid %d (size=%d)", idMap.HostID, idMap.ContainerID, idMap.Size)
+ cmd.SysProcAttr.GidMappings = append(cmd.SysProcAttr.GidMappings, syscall.SysProcIDMap{
+ ContainerID: int(idMap.ContainerID),
+ HostID: int(idMap.HostID),
+ Size: int(idMap.Size),
+ })
+ }
+}
+
+// HasCapabilities returns true if the user has all capabilities in 'cs'.
+func HasCapabilities(cs ...capability.Cap) bool {
+ caps, err := capability.NewPid2(os.Getpid())
+ if err != nil {
+ return false
+ }
+ if err := caps.Load(); err != nil {
+ return false
+ }
+ for _, c := range cs {
+ if !caps.Get(capability.EFFECTIVE, c) {
+ return false
+ }
+ }
+ return true
+}
+
+// MaybeRunAsRoot ensures the process runs with capabilities needed to create a
+// sandbox, e.g. CAP_SYS_ADMIN, CAP_SYS_CHROOT, etc. If capabilities are needed,
+// it will create a new user namespace and re-execute the process as root
+// inside the namespace with the same arguments and environment.
+//
+// This function returns immediately when no new capability is needed. If the
+// process is re-executed, this function does not return; it exits with the
+// same exit code as the child.
+func MaybeRunAsRoot() error {
+ if HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_SYS_CHROOT, capability.CAP_SETUID, capability.CAP_SETGID) {
+ return nil
+ }
+
+ // Current process doesn't have required capabilities, create user namespace
+ // and run as root inside the namespace to acquire capabilities.
+ log.Infof("*** Re-running as root in new user namespace ***")
+
+ cmd := exec.Command("/proc/self/exe", os.Args[1:]...)
+
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS,
+ // Set current user/group as root inside the namespace. Since we may not
+ // have CAP_SETUID/CAP_SETGID, just map root to the current user/group.
+ UidMappings: []syscall.SysProcIDMap{
+ {ContainerID: 0, HostID: os.Getuid(), Size: 1},
+ },
+ GidMappings: []syscall.SysProcIDMap{
+ {ContainerID: 0, HostID: os.Getgid(), Size: 1},
+ },
+ Credential: &syscall.Credential{Uid: 0, Gid: 0},
+ GidMappingsEnableSetgroups: false,
+
+ // Make sure child is killed when the parent terminates.
+ Pdeathsig: syscall.SIGKILL,
+ }
+
+ cmd.Env = os.Environ()
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("re-executing self: %w", err)
+ }
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch)
+ go func() {
+ for {
+ // Forward all signals to child process.
+ cmd.Process.Signal(<-ch)
+ }
+ }()
+ if err := cmd.Wait(); err != nil {
+ if exit, ok := err.(*exec.ExitError); ok {
+ if ws, ok := exit.Sys().(syscall.WaitStatus); ok {
+ os.Exit(ws.ExitStatus())
+ }
+ log.Warningf("No wait status provided, exiting with -1: %v", err)
+ os.Exit(-1)
+ }
+ return err
+ }
+ // Child completed with success.
+ os.Exit(0)
+ panic("unreachable")
+}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
new file mode 100644
index 000000000..5015c3a84
--- /dev/null
+++ b/runsc/specutils/specutils.go
@@ -0,0 +1,523 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package specutils contains utility functions for working with OCI runtime
+// specs.
+package specutils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ "github.com/mohae/deepcopy"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// ExePath must point to the runsc binary, which is normally the current
+// binary. It's overridden in tests that aren't linked into the same binary.
+
+// Version is the supported spec version.
+var Version = specs.Version
+
+// LogSpec logs the spec in a human-friendly way.
+func LogSpec(orig *specs.Spec) {
+ if !log.IsLogging(log.Debug) {
+ return
+ }
+
+ // Strip down parts of the spec that are not interesting.
+ spec := deepcopy.Copy(orig).(*specs.Spec)
+ if spec.Process != nil {
+ spec.Process.Capabilities = nil
+ }
+ if spec.Linux != nil {
+ spec.Linux.Seccomp = nil
+ spec.Linux.MaskedPaths = nil
+ spec.Linux.ReadonlyPaths = nil
+ if spec.Linux.Resources != nil {
+ spec.Linux.Resources.Devices = nil
+ }
+ }
+
+ out, err := json.MarshalIndent(spec, "", " ")
+ if err != nil {
+ log.Debugf("Failed to marshal spec: %v", err)
+ return
+ }
+ log.Debugf("Spec:\n%s", out)
+}
+
+// ValidateSpec validates that the spec is compatible with runsc.
+func ValidateSpec(spec *specs.Spec) error {
+ // Mandatory fields.
+ if spec.Process == nil {
+ return fmt.Errorf("Spec.Process must be defined: %+v", spec)
+ }
+ if len(spec.Process.Args) == 0 {
+ return fmt.Errorf("Spec.Process.Arg must be defined: %+v", spec.Process)
+ }
+ if spec.Root == nil {
+ return fmt.Errorf("Spec.Root must be defined: %+v", spec)
+ }
+ if len(spec.Root.Path) == 0 {
+ return fmt.Errorf("Spec.Root.Path must be defined: %+v", spec.Root)
+ }
+
+ // Unsupported fields.
+ if spec.Solaris != nil {
+ return fmt.Errorf("Spec.Solaris is not supported: %+v", spec)
+ }
+ if spec.Windows != nil {
+ return fmt.Errorf("Spec.Windows is not supported: %+v", spec)
+ }
+ if len(spec.Process.SelinuxLabel) != 0 {
+ return fmt.Errorf("SELinux is not supported: %s", spec.Process.SelinuxLabel)
+ }
+
+ // Docker uses AppArmor by default, so just log that it's being ignored.
+ if spec.Process.ApparmorProfile != "" {
+ log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile)
+ }
+
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
+ if !spec.Process.NoNewPrivileges {
+ log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.")
+ }
+
+ // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox.
+ if spec.Linux != nil && spec.Linux.Seccomp != nil {
+ log.Warningf("Seccomp spec is being ignored")
+ }
+
+ if spec.Linux != nil && spec.Linux.RootfsPropagation != "" {
+ if err := validateRootfsPropagation(spec.Linux.RootfsPropagation); err != nil {
+ return err
+ }
+ }
+ for _, m := range spec.Mounts {
+ if err := validateMount(&m); err != nil {
+ return err
+ }
+ }
+
+ // CRI specifies whether a container should start a new sandbox, or run
+ // as another container in an existing sandbox.
+ switch SpecContainerType(spec) {
+ case ContainerTypeContainer:
+ // When starting a container in an existing sandbox, the
+ // sandbox ID must be set.
+ if _, ok := SandboxID(spec); !ok {
+ return fmt.Errorf("spec has container-type of container, but no sandbox ID set")
+ }
+ case ContainerTypeUnknown:
+ return fmt.Errorf("unknown container-type")
+ default:
+ }
+
+ return nil
+}
+
+// absPath turns the given path into an absolute path (if it is not already
+// absolute) by prepending the base path.
+func absPath(base, rel string) string {
+ if filepath.IsAbs(rel) {
+ return rel
+ }
+ return filepath.Join(base, rel)
+}
+
+// OpenSpec opens an OCI runtime spec from the given bundle directory.
+func OpenSpec(bundleDir string) (*os.File, error) {
+ // The spec file must be named "config.json" inside the bundle directory.
+ return os.Open(filepath.Join(bundleDir, "config.json"))
+}
+
+// ReadSpec reads an OCI runtime spec from the given bundle directory.
+// ReadSpec also normalizes all potential relative paths into absolute
+// paths, e.g. spec.Root.Path, mount.Source.
+func ReadSpec(bundleDir string) (*specs.Spec, error) {
+ specFile, err := OpenSpec(bundleDir)
+ if err != nil {
+ return nil, fmt.Errorf("error opening spec file %q: %v", filepath.Join(bundleDir, "config.json"), err)
+ }
+ defer specFile.Close()
+ return ReadSpecFromFile(bundleDir, specFile)
+}
+
+// ReadSpecFromFile reads an OCI runtime spec from the given File, and
+// normalizes all relative paths into absolute by prepending the bundle dir.
+func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error) {
+ if _, err := specFile.Seek(0, os.SEEK_SET); err != nil {
+ return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err)
+ }
+ specBytes, err := ioutil.ReadAll(specFile)
+ if err != nil {
+ return nil, fmt.Errorf("error reading spec from file %q: %v", specFile.Name(), err)
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(specBytes, &spec); err != nil {
+ return nil, fmt.Errorf("error unmarshaling spec from file %q: %v\n %s", specFile.Name(), err, string(specBytes))
+ }
+ if err := ValidateSpec(&spec); err != nil {
+ return nil, err
+ }
+ // Turn any relative paths in the spec to absolute by prepending the bundleDir.
+ spec.Root.Path = absPath(bundleDir, spec.Root.Path)
+ for i := range spec.Mounts {
+ m := &spec.Mounts[i]
+ if m.Source != "" {
+ m.Source = absPath(bundleDir, m.Source)
+ }
+ }
+ return &spec, nil
+}
+
+// ReadMounts reads mount list from a file.
+func ReadMounts(f *os.File) ([]specs.Mount, error) {
+ bytes, err := ioutil.ReadAll(f)
+ if err != nil {
+ return nil, fmt.Errorf("error reading mounts: %v", err)
+ }
+ var mounts []specs.Mount
+ if err := json.Unmarshal(bytes, &mounts); err != nil {
+ return nil, fmt.Errorf("error unmarshaling mounts: %v\n %s", err, string(bytes))
+ }
+ return mounts, nil
+}
+
+// Capabilities takes in spec and returns a TaskCapabilities corresponding to
+// the spec.
+func Capabilities(enableRaw bool, specCaps *specs.LinuxCapabilities) (*auth.TaskCapabilities, error) {
+ // Strip CAP_NET_RAW from all capability sets if necessary.
+ skipSet := map[linux.Capability]struct{}{}
+ if !enableRaw {
+ skipSet[linux.CAP_NET_RAW] = struct{}{}
+ }
+
+ var caps auth.TaskCapabilities
+ if specCaps != nil {
+ var err error
+ if caps.BoundingCaps, err = capsFromNames(specCaps.Bounding, skipSet); err != nil {
+ return nil, err
+ }
+ if caps.EffectiveCaps, err = capsFromNames(specCaps.Effective, skipSet); err != nil {
+ return nil, err
+ }
+ if caps.InheritableCaps, err = capsFromNames(specCaps.Inheritable, skipSet); err != nil {
+ return nil, err
+ }
+ if caps.PermittedCaps, err = capsFromNames(specCaps.Permitted, skipSet); err != nil {
+ return nil, err
+ }
+ // TODO(nlacasse): Support ambient capabilities.
+ }
+ return &caps, nil
+}
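+// For example, with enableRaw == false, a spec granting
+// []string{"CAP_NET_BIND_SERVICE", "CAP_NET_RAW"} in all sets results in
+// task capabilities containing only CAP_NET_BIND_SERVICE.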
+
+// AllCapabilities returns a LinuxCapabilities struct with all capabilities.
+func AllCapabilities() *specs.LinuxCapabilities {
+ var names []string
+ for n := range capFromName {
+ names = append(names, n)
+ }
+ return &specs.LinuxCapabilities{
+ Bounding: names,
+ Effective: names,
+ Inheritable: names,
+ Permitted: names,
+ Ambient: names,
+ }
+}
+
+// AllCapabilitiesUint64 returns a bitmask containing all capabilities set.
+func AllCapabilitiesUint64() uint64 {
+ var rv uint64
+ for _, cap := range capFromName {
+ rv |= bits.MaskOf64(int(cap))
+ }
+ return rv
+}
+
+var capFromName = map[string]linux.Capability{
+ "CAP_CHOWN": linux.CAP_CHOWN,
+ "CAP_DAC_OVERRIDE": linux.CAP_DAC_OVERRIDE,
+ "CAP_DAC_READ_SEARCH": linux.CAP_DAC_READ_SEARCH,
+ "CAP_FOWNER": linux.CAP_FOWNER,
+ "CAP_FSETID": linux.CAP_FSETID,
+ "CAP_KILL": linux.CAP_KILL,
+ "CAP_SETGID": linux.CAP_SETGID,
+ "CAP_SETUID": linux.CAP_SETUID,
+ "CAP_SETPCAP": linux.CAP_SETPCAP,
+ "CAP_LINUX_IMMUTABLE": linux.CAP_LINUX_IMMUTABLE,
+ "CAP_NET_BIND_SERVICE": linux.CAP_NET_BIND_SERVICE,
+ "CAP_NET_BROADCAST": linux.CAP_NET_BROADCAST,
+ "CAP_NET_ADMIN": linux.CAP_NET_ADMIN,
+ "CAP_NET_RAW": linux.CAP_NET_RAW,
+ "CAP_IPC_LOCK": linux.CAP_IPC_LOCK,
+ "CAP_IPC_OWNER": linux.CAP_IPC_OWNER,
+ "CAP_SYS_MODULE": linux.CAP_SYS_MODULE,
+ "CAP_SYS_RAWIO": linux.CAP_SYS_RAWIO,
+ "CAP_SYS_CHROOT": linux.CAP_SYS_CHROOT,
+ "CAP_SYS_PTRACE": linux.CAP_SYS_PTRACE,
+ "CAP_SYS_PACCT": linux.CAP_SYS_PACCT,
+ "CAP_SYS_ADMIN": linux.CAP_SYS_ADMIN,
+ "CAP_SYS_BOOT": linux.CAP_SYS_BOOT,
+ "CAP_SYS_NICE": linux.CAP_SYS_NICE,
+ "CAP_SYS_RESOURCE": linux.CAP_SYS_RESOURCE,
+ "CAP_SYS_TIME": linux.CAP_SYS_TIME,
+ "CAP_SYS_TTY_CONFIG": linux.CAP_SYS_TTY_CONFIG,
+ "CAP_MKNOD": linux.CAP_MKNOD,
+ "CAP_LEASE": linux.CAP_LEASE,
+ "CAP_AUDIT_WRITE": linux.CAP_AUDIT_WRITE,
+ "CAP_AUDIT_CONTROL": linux.CAP_AUDIT_CONTROL,
+ "CAP_SETFCAP": linux.CAP_SETFCAP,
+ "CAP_MAC_OVERRIDE": linux.CAP_MAC_OVERRIDE,
+ "CAP_MAC_ADMIN": linux.CAP_MAC_ADMIN,
+ "CAP_SYSLOG": linux.CAP_SYSLOG,
+ "CAP_WAKE_ALARM": linux.CAP_WAKE_ALARM,
+ "CAP_BLOCK_SUSPEND": linux.CAP_BLOCK_SUSPEND,
+ "CAP_AUDIT_READ": linux.CAP_AUDIT_READ,
+}
+
+func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth.CapabilitySet, error) {
+ var caps []linux.Capability
+ for _, n := range names {
+ c, ok := capFromName[n]
+ if !ok {
+ return 0, fmt.Errorf("unknown capability %q", n)
+ }
+ // Should we skip this capability?
+ if _, ok := skipSet[c]; ok {
+ continue
+ }
+ caps = append(caps, c)
+ }
+ return auth.CapabilitySetOfMany(caps), nil
+}
+
+// Is9PMount returns true if the given mount can be mounted as an external gofer.
+func Is9PMount(m specs.Mount) bool {
+ return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m)
+}
+
+// IsSupportedDevMount returns true if the mount is a supported /dev mount.
+// Only mounts that do not conflict with runsc's default /dev mount are
+// supported.
+func IsSupportedDevMount(m specs.Mount) bool {
+ // These are devices that exist inside the sentry. See pkg/sentry/fs/dev/dev.go.
+ var existingDevices = []string{
+ "/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr",
+ "/dev/null", "/dev/zero", "/dev/full", "/dev/random",
+ "/dev/urandom", "/dev/shm", "/dev/pts", "/dev/ptmx",
+ }
+ dst := filepath.Clean(m.Destination)
+ if dst == "/dev" {
+ // OCI spec uses many different mounts for the things inside of '/dev'. We
+ // have a single mount at '/dev' that is always mounted, regardless of
+ // whether it was asked for, as the spec says we SHOULD.
+ return false
+ }
+ for _, dev := range existingDevices {
+ if dst == dev || strings.HasPrefix(dst, dev+"/") {
+ return false
+ }
+ }
+ return true
+}
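+// For example, a bind mount with Destination "/dev/foo" is supported, while
+// "/dev", "/dev/null", or "/dev/pts/0" would conflict and return false.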
+
+// WaitForReady waits for a process to become ready. The process is ready when
+// the 'ready' function returns true. It continues to wait if 'ready' returns
+// false. It returns an error on timeout, if the process stops, or if 'ready' fails.
+func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) error {
+ b := backoff.NewExponentialBackOff()
+ b.InitialInterval = 1 * time.Millisecond
+ b.MaxInterval = 1 * time.Second
+ b.MaxElapsedTime = timeout
+
+ op := func() error {
+ if ok, err := ready(); err != nil {
+ return backoff.Permanent(err)
+ } else if ok {
+ return nil
+ }
+
+ // Check if the process is still running.
+ // If the process is alive, child is 0 because of the NOHANG option.
+ // If the process has terminated, child equals the process id.
+ var ws syscall.WaitStatus
+ var ru syscall.Rusage
+ child, err := syscall.Wait4(pid, &ws, syscall.WNOHANG, &ru)
+ if err != nil {
+ return backoff.Permanent(fmt.Errorf("error waiting for process: %v", err))
+ } else if child == pid {
+ return backoff.Permanent(fmt.Errorf("process %d has terminated", pid))
+ }
+ return fmt.Errorf("process %d not running yet", pid)
+ }
+ return backoff.Retry(op, b)
+}
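+// A minimal usage sketch (hypothetical socket path): wait up to 10s for a
+// server process to create its socket file.
+//
+//	err := WaitForReady(pid, 10*time.Second, func() (bool, error) {
+//		_, err := os.Stat("/run/app.sock")
+//		if os.IsNotExist(err) {
+//			return false, nil
+//		}
+//		return err == nil, err
+//	})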
+
+// DebugLogFile opens a log file using 'logPattern' as location. If 'logPattern'
+// ends with '/', it's used as a directory with default file name.
+// 'logPattern' can contain variables that are substituted:
+// - %TIMESTAMP%: is replaced with a timestamp using the following format:
+// <yyyymmdd-hhmmss.uuuuuu>
+// - %COMMAND%: is replaced with 'command'
+// - %TEST%: is replaced with 'test' (omitted by default)
+func DebugLogFile(logPattern, command, test string) (*os.File, error) {
+ if strings.HasSuffix(logPattern, "/") {
+ // Default format: <debug-log>/runsc.log.<yyyymmdd-hhmmss.uuuuuu>.<command>
+ logPattern += "runsc.log.%TIMESTAMP%.%COMMAND%"
+ }
+ logPattern = strings.Replace(logPattern, "%TIMESTAMP%", time.Now().Format("20060102-150405.000000"), -1)
+ logPattern = strings.Replace(logPattern, "%COMMAND%", command, -1)
+ logPattern = strings.Replace(logPattern, "%TEST%", test, -1)
+
+ dir := filepath.Dir(logPattern)
+ if err := os.MkdirAll(dir, 0775); err != nil {
+ return nil, fmt.Errorf("error creating dir %q: %v", dir, err)
+ }
+ return os.OpenFile(logPattern, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0664)
+}
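+// For example (hypothetical values), DebugLogFile("/tmp/runsc/", "boot", "")
+// opens something like "/tmp/runsc/runsc.log.20200102-150405.000000.boot".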
+
+// Mount creates the mount point and calls Mount with the given flags.
+func Mount(src, dst, typ string, flags uint32) error {
+ // Create the mount point. Its type (file or directory) must match the
+ // type of the source.
+ var isDir bool
+ if typ == "proc" {
+ // Special case, as there is no source directory for proc mounts.
+ isDir = true
+ } else if fi, err := os.Stat(src); err != nil {
+ return fmt.Errorf("Stat(%q) failed: %v", src, err)
+ } else {
+ isDir = fi.IsDir()
+ }
+
+ if isDir {
+ // Create the destination directory.
+ if err := os.MkdirAll(dst, 0777); err != nil {
+ return fmt.Errorf("Mkdir(%q) failed: %v", dst, err)
+ }
+ } else {
+ // Create the parent destination directory.
+ parent := path.Dir(dst)
+ if err := os.MkdirAll(parent, 0777); err != nil {
+ return fmt.Errorf("Mkdir(%q) failed: %v", parent, err)
+ }
+ // Create the destination file if it does not exist.
+ f, err := os.OpenFile(dst, syscall.O_CREAT, 0777)
+ if err != nil {
+ return fmt.Errorf("Open(%q) failed: %v", dst, err)
+ }
+ f.Close()
+ }
+
+ // Do the mount.
+ if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil {
+ return fmt.Errorf("Mount(%q, %q, %d) failed: %v", src, dst, flags, err)
+ }
+ return nil
+}
+
+// ContainsStr returns true if 'str' is inside 'strs'.
+func ContainsStr(strs []string, str string) bool {
+ for _, s := range strs {
+ if s == str {
+ return true
+ }
+ }
+ return false
+}
+
+// RetryEintr retries the function until an error different than EINTR is
+// returned.
+func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) {
+ for {
+ r1, r2, err := f()
+ if err != syscall.EINTR {
+ return r1, r2, err
+ }
+ }
+}
+
+// GetOOMScoreAdj reads the given process' oom_score_adj.
+func GetOOMScoreAdj(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid))
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(strings.TrimSpace(string(data)))
+}
+
+// GetParentPid gets the parent process ID of the specified PID.
+func GetParentPid(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid))
+ if err != nil {
+ return 0, err
+ }
+
+ var cpid string
+ var name string
+ var state string
+ var ppid int
+ // Parse the first four fields: pid, comm (binary name), state, and ppid.
+ _, err = fmt.Sscanf(string(data),
+ "%v %v %v %d",
+ // cpid is ignored.
+ &cpid,
+ // name is ignored.
+ &name,
+ // state is ignored.
+ &state,
+ &ppid)
+
+ if err != nil {
+ return 0, err
+ }
+
+ return ppid, nil
+}
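+// For example, a stat line beginning "1234 (sleep) S 4321 ..." yields
+// ppid == 4321. Note that the simple %v-based parse assumes the binary name
+// contains no spaces.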
+
+// EnvVar looks for a variable value in the env slice assuming the following
+// format: "NAME=VALUE".
+func EnvVar(env []string, name string) (string, bool) {
+ prefix := name + "="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.TrimPrefix(e, prefix), true
+ }
+ }
+ return "", false
+}
+
+// FaqErrorMsg returns an error message pointing to the FAQ.
+func FaqErrorMsg(anchor, msg string) string {
+ return fmt.Sprintf("%s; see https://gvisor.dev/faq#%s for more details", msg, anchor)
+}
diff --git a/runsc/specutils/specutils_test.go b/runsc/specutils/specutils_test.go
new file mode 100644
index 000000000..2c86fffe8
--- /dev/null
+++ b/runsc/specutils/specutils_test.go
@@ -0,0 +1,265 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package specutils
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+func TestWaitForReadyHappy(t *testing.T) {
+ cmd := exec.Command("/bin/sleep", "1000")
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("cmd.Start() failed, err: %v", err)
+ }
+ defer cmd.Wait()
+
+ var count int
+ err := WaitForReady(cmd.Process.Pid, 5*time.Second, func() (bool, error) {
+ if count < 3 {
+ count++
+ return false, nil
+ }
+ return true, nil
+ })
+ if err != nil {
+ t.Errorf("ProcessWaitReady got: %v, expected: nil", err)
+ }
+ cmd.Process.Kill()
+}
+
+func TestWaitForReadyFail(t *testing.T) {
+ cmd := exec.Command("/bin/sleep", "1000")
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("cmd.Start() failed, err: %v", err)
+ }
+ defer cmd.Wait()
+
+ var count int
+ err := WaitForReady(cmd.Process.Pid, 5*time.Second, func() (bool, error) {
+ if count < 3 {
+ count++
+ return false, nil
+ }
+ return false, fmt.Errorf("Fake error")
+ })
+ if err == nil {
+ t.Errorf("ProcessWaitReady got: nil, expected: error")
+ }
+ cmd.Process.Kill()
+}
+
+func TestWaitForReadyNotRunning(t *testing.T) {
+ cmd := exec.Command("/bin/true")
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("cmd.Start() failed, err: %v", err)
+ }
+ defer cmd.Wait()
+
+ err := WaitForReady(cmd.Process.Pid, 5*time.Second, func() (bool, error) {
+ return false, nil
+ })
+ if err != nil && !strings.Contains(err.Error(), "terminated") {
+ t.Errorf("ProcessWaitReady got: %v, expected: process terminated", err)
+ }
+ if err == nil {
+ t.Errorf("ProcessWaitReady incorrectly succeeded")
+ }
+}
+
+func TestWaitForReadyTimeout(t *testing.T) {
+ cmd := exec.Command("/bin/sleep", "1000")
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("cmd.Start() failed, err: %v", err)
+ }
+ defer cmd.Wait()
+
+ err := WaitForReady(cmd.Process.Pid, 50*time.Millisecond, func() (bool, error) {
+ return false, nil
+ })
+ if !strings.Contains(err.Error(), "not running yet") {
+ t.Errorf("ProcessWaitReady got: %v, expected: not running yet", err)
+ }
+ cmd.Process.Kill()
+}
+
+func TestSpecInvalid(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ spec specs.Spec
+ error string
+ }{
+ {
+ name: "valid",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ Mounts: []specs.Mount{
+ {
+ Source: "src",
+ Destination: "/dst",
+ },
+ },
+ },
+ error: "",
+ },
+ {
+ name: "valid+warning",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ // This is normally set by docker and will just cause warnings to be logged.
+ ApparmorProfile: "someprofile",
+ },
+ // This is normally set by docker and will just cause warnings to be logged.
+ Linux: &specs.Linux{Seccomp: &specs.LinuxSeccomp{}},
+ },
+ error: "",
+ },
+ {
+ name: "no root",
+ spec: specs.Spec{
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ },
+ error: "must be defined",
+ },
+ {
+ name: "empty root",
+ spec: specs.Spec{
+ Root: &specs.Root{},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ },
+ error: "must be defined",
+ },
+ {
+ name: "no process",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ },
+ error: "must be defined",
+ },
+ {
+ name: "empty args",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{},
+ },
+ error: "must be defined",
+ },
+ {
+ name: "selinux",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ SelinuxLabel: "somelabel",
+ },
+ },
+ error: "is not supported",
+ },
+ {
+ name: "solaris",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ Solaris: &specs.Solaris{},
+ },
+ error: "is not supported",
+ },
+ {
+ name: "windows",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ Windows: &specs.Windows{},
+ },
+ error: "is not supported",
+ },
+ {
+ name: "relative mount destination",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ Mounts: []specs.Mount{
+ {
+ Source: "src",
+ Destination: "dst",
+ },
+ },
+ },
+ error: "must be an absolute path",
+ },
+ {
+ name: "invalid mount option",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ Mounts: []specs.Mount{
+ {
+ Source: "/src",
+ Destination: "/dst",
+ Type: "bind",
+ Options: []string{"shared"},
+ },
+ },
+ },
+ error: "is not supported",
+ },
+ {
+ name: "invalid rootfs propagation",
+ spec: specs.Spec{
+ Root: &specs.Root{Path: "/"},
+ Process: &specs.Process{
+ Args: []string{"/bin/true"},
+ },
+ Linux: &specs.Linux{
+ RootfsPropagation: "foo",
+ },
+ },
+ error: "root mount propagation option must specify private or slave",
+ },
+ } {
+ err := ValidateSpec(&test.spec)
+ if len(test.error) == 0 {
+ if err != nil {
+ t.Errorf("ValidateSpec(%q) failed, err: %v", test.name, err)
+ }
+ } else {
+ if err == nil || !strings.Contains(err.Error(), test.error) {
+ t.Errorf("ValidateSpec(%q) wrong error, got: %v, want: .*%s.*", test.name, err, test.error)
+ }
+ }
+ }
+}
diff --git a/runsc/version.go b/runsc/version.go
new file mode 100644
index 000000000..ab9194b9d
--- /dev/null
+++ b/runsc/version.go
@@ -0,0 +1,18 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+// version is set during linking.
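+// It is typically stamped in at build time via the Go linker's -X flag, with
+// the value derived from workspace_status.sh (see version_test.sh).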
+var version = "VERSION_MISSING"
diff --git a/runsc/version_test.sh b/runsc/version_test.sh
new file mode 100755
index 000000000..747350654
--- /dev/null
+++ b/runsc/version_test.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euf -x -o pipefail
+
+readonly runsc="$1"
+readonly version=$($runsc --version)
+
+# Version should not match VERSION, which is the default and which will also
+# appear if something is wrong with the workspace_status.sh script.
+if [[ $version =~ "VERSION" ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+# Version should contain at least one number.
+if [[ ! $version =~ [0-9] ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+echo "PASS: Got OK version $version"
+exit 0
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
new file mode 100755
index 000000000..e0f6df438
--- /dev/null
+++ b/scripts/benchmark.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# gcloud may be installed as a "snap". If it is, include it in PATH.
+declare -r snap="/snap/bin"
+if [[ -d "${snap}" ]]; then
+ export PATH="${PATH}:${snap}"
+fi
+
+# Make sure we can find gcloud and exit if not.
+which gcloud
+
+# Export for subprocesses, as GCP APIs and tools check this environment
+# variable for authentication.
+export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}"
+
+gcloud auth activate-service-account \
+ --key-file "${GOOGLE_APPLICATION_CREDENTIALS}"
+
+gcloud config set project ${PROJECT}
+gcloud config set compute/zone ${ZONE}
+
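+# Run the startup and absl benchmark suites on GCP, comparing the runc and
+# runsc runtimes, with runsc installed from head.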
+bazel run //benchmarks:benchmarks -- \
+ --verbose \
+ run-gcp \
+ "(startup|absl)" \
+ --internal \
+ --runtime=runc \
+ --runtime=runsc \
+ --installers=head
diff --git a/scripts/common.sh b/scripts/common.sh
new file mode 100755
index 000000000..3ca699e4a
--- /dev/null
+++ b/scripts/common.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeou pipefail
+
+# Get the path to the directory this script lives in.
+# If this script is being called with `source`, $0 will be the path of the
+# *sourcing* script, so we can't use `dirname $0` to find scripts in this
+# directory.
+if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then
+ declare -r script_dir="$(dirname "$BASH_SOURCE")"
+else
+ declare -r script_dir="$(dirname "$0")"
+fi
+
+source "${script_dir}/common_build.sh"
+
+# Ensure that logs are collected in all cases, even on failure.
+trap collect_logs EXIT
+
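+# Sets RUNTIME and the runsc binary/log locations for the given runtime name.
+# The %TEST%, %TIMESTAMP%, and %COMMAND% tokens in the log path are expanded
+# per runsc invocation (%TEST% comes from the RUNSC_TEST_NAME environment
+# variable; see install_runsc_for_test below), giving one log file per command.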
+function set_runtime() {
+ RUNTIME=${1:-runsc}
+ RUNSC_BIN=/tmp/"${RUNTIME}"/runsc
+ RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs
+ RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+}
+
+function test_runsc() {
+ test --test_arg=--runtime=${RUNTIME} "$@"
+}
+
+function install_runsc_for_test() {
+ local -r test_name=$1
+ shift
+ if [[ -z "${test_name}" ]]; then
+ echo "Missing mandatory test name"
+ exit 1
+ fi
+
+ # Add test to the name, so it doesn't conflict with other runtimes.
+ set_runtime $(find_branch_name)_"${test_name}"
+
+ # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name
+ # down to the runtime.
+ install_runsc "${RUNTIME}" \
+ --TESTONLY-test-name-env=RUNSC_TEST_NAME \
+ --debug \
+ --strace \
+ --log-packets \
+ "$@"
+}
+
+# Installs runsc with the given runtime name. set_runtime must have been called
+# first to set the runtime and log locations.
+function install_runsc() {
+ local -r runtime=$1
+ shift
+
+ # Prepare the runtime binary.
+ local -r output=$(build //runsc)
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f "${output}" "${RUNSC_BIN}"
+ chmod 0755 "${RUNSC_BIN}"
+
+ # Install the runtime.
+ sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
+
+ # Clear old log files that may exist.
+ sudo rm -f "${RUNSC_LOGS_DIR}"/'*'
+
+ # Restart docker to pick up the new runtime configuration.
+ sudo systemctl restart docker
+}
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
new file mode 100755
index 000000000..0d9a191b5
--- /dev/null
+++ b/scripts/common_build.sh
@@ -0,0 +1,116 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+which bazel
+bazel version
+
+# Switch into the workspace; only necessary if run with kokoro.
+if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
+ cd git/repo
+elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
+ cd github/repo
+fi
+
+# Set the standard bazel flags.
+declare -a BAZEL_FLAGS=(
+ "--show_timestamps"
+ "--test_output=errors"
+ "--keep_going"
+ "--verbose_failures=true"
+)
+if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
+ BAZEL_FLAGS+=(
+ "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
+ "--config=remote"
+ )
+fi
+declare -r BAZEL_FLAGS
+
+# Wrap bazel.
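+# build() prints only the bazel-bin/ paths of the built artifacts on stdout;
+# tee duplicates the full build output to stderr so it stays visible in logs.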
+function build() {
+ bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \
+ | tee /dev/fd/2 \
+ | grep -E '^ bazel-bin/' \
+ | awk '{ print $1; }'
+}
+
+function test() {
+ bazel test "${BAZEL_FLAGS[@]}" "$@"
+}
+
+function run() {
+ local binary=$1
+ shift
+ bazel run "${binary}" -- "$@"
+}
+
+function run_as_root() {
+ local binary=$1
+ shift
+ bazel run --run_under="sudo" "${binary}" -- "$@"
+}
+
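+# Stores the result in the global QUERY_RESULT so that callers (e.g. the
+# packetdrill and packetimpact test scripts) can pass it on to test_runsc.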
+function query() {
+ QUERY_RESULT=$(bazel query "$@")
+}
+
+function collect_logs() {
+ # Zip out everything into a convenient form.
+ if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
+ # Merge results files of all shards for each test suite.
+ for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do
+ junitparser merge `find $d -name test.xml` $d/test.xml
+ cat $d/shard_*_of_*/test.log > $d/test.log
+ if ls -ld $d/shard_*_of_*/test.outputs 2>/dev/null; then
+ zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs
+ fi
+ done
+ find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf
+ # Move test logs to Kokoro directory. tar is used to conveniently perform
+ # renames while moving files.
+ find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
+ tar --create --files-from - --transform 's/test\./sponge_log./' |
+ tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
+
+ # Collect sentry logs, if any.
+ if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then
+ # Check if the directory is empty or not (only the first line is needed).
+ local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1)
+ if [[ "${logs}" ]]; then
+ local -r archive=runsc_logs_"${RUNTIME}".tar.gz
+ if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then
+ echo "runsc logs will be uploaded to:"
+ echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
+ echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
+ fi
+ time tar \
+ --verbose \
+ --create \
+ --gzip \
+ --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \
+ --directory "${RUNSC_LOGS_DIR}" \
+ .
+ fi
+ fi
+ fi
+}
+
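+# Returns a name for the current checkout: the branch name if one is checked
+# out, otherwise the commit hash, otherwise the basename of the workspace.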
+function find_branch_name() {
+ git branch --show-current \
+ || git rev-parse HEAD \
+ || bazel info workspace \
+ | xargs basename
+}
diff --git a/scripts/dev.sh b/scripts/dev.sh
new file mode 100755
index 000000000..a9107f33e
--- /dev/null
+++ b/scripts/dev.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# common.sh sets '-x', but it's annoying to see so much output.
+set +x
+
+# Defaults
+declare -i REFRESH=0
+declare NAME=$(find_branch_name)
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --refresh)
+ REFRESH=1
+ ;;
+ --help)
+ echo "Use this script to build and install runsc with Docker."
+ echo
+ echo "usage: $0 [--refresh] [runtime_name]"
+ exit 1
+ ;;
+ *)
+ NAME=$1
+ ;;
+ esac
+ shift
+done
+
+set_runtime "${NAME}"
+echo
+echo "Using runtime=${RUNTIME}"
+echo
+
+echo Building runsc...
+# Build first and fail on error. Failures inside $() are not caught by
+# "set -e", so build explicitly before capturing the output below.
+build //runsc
+declare OUTPUT="$(build //runsc)"
+
+if [[ ${REFRESH} -eq 0 ]]; then
+ install_runsc "${RUNTIME}" --net-raw
+ install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets
+ install_runsc "${RUNTIME}-p" --net-raw --profile
+
+ echo
+ echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed."
+ echo "Use --runtime="${RUNTIME}" with your Docker command."
+ echo " docker run --rm --runtime="${RUNTIME}" hello-world"
+ echo
+ echo "If you rebuild, use $0 --refresh."
+
+else
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f ${OUTPUT} "${RUNSC_BIN}"
+ chmod a+rx "${RUNSC_BIN}"
+
+ echo
+ echo "Runtime ${RUNTIME} refreshed."
+fi
+
+echo "Logs are in: ${RUNSC_LOGS_DIR}"
diff --git a/scripts/do_tests.sh b/scripts/do_tests.sh
new file mode 100755
index 000000000..a3a387c37
--- /dev/null
+++ b/scripts/do_tests.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Build runsc.
+build //runsc
+
+# Run `runsc do` without root privileges.
+run //runsc --rootless do true
+run //runsc --rootless --network=none do true
+
+# Run `runsc do` with root privileges.
+run_as_root //runsc do true
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
new file mode 100755
index 000000000..dce0a4085
--- /dev/null
+++ b/scripts/docker_tests.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-all-images
+
+install_runsc_for_test docker
+test_runsc //test/image:image_test //test/e2e:integration_test
+
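+# Also run a subset of the image tests with the experimental VFS2 layer
+# enabled.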
+install_runsc_for_test docker --vfs2
+test_runsc //test/image:image_test --test_filter=.*TestHelloWorld
diff --git a/scripts/go.sh b/scripts/go.sh
new file mode 100755
index 000000000..626ed8fa4
--- /dev/null
+++ b/scripts/go.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Build the go path.
+build :gopath
+
+# Build the synthetic branch.
+tools/go_branch.sh
+
+# Check out the new branch.
+git checkout go && git clean -f
+
+go version
+
+# Build everything.
+go build ./...
+
+# Push, if required.
+if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then
+ if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ git config --global credential.helper cache
+ git credential approve <<EOF
+protocol=https
+host=github.com
+username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}")
+password=x-oauth-basic
+EOF
+ fi
+ git push origin go:go
+fi
diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh
new file mode 100755
index 000000000..992db50dd
--- /dev/null
+++ b/scripts/hostnet_tests.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-all-images
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test hostnet --network=host
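+# Checkpoint tests are skipped, as checkpoint/restore is not supported
+# together with host networking.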
+test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
new file mode 100755
index 000000000..8299a7c8b
--- /dev/null
+++ b/scripts/iptables_tests.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-iptables
+
+# Needed by ip6tables.
+sudo modprobe ip6table_filter
+
+install_runsc_for_test iptables --net-raw
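+# Run the tests against runc first, as a native-Linux baseline, and then
+# against the installed runsc runtime.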
+test //test/iptables:iptables_test "--test_arg=--runtime=runc"
+test //test/iptables:iptables_test "--test_arg=--runtime=${RUNTIME}"
diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh
new file mode 100755
index 000000000..bac9b9192
--- /dev/null
+++ b/scripts/issue_reviver.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+DIR=$(dirname $0)
+source "${DIR}"/common.sh
+
+# Provide a credential file if available.
+export OAUTH_TOKEN_FILE=""
+if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}"
+fi
+
+REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd)
+run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}"
diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh
new file mode 100755
index 000000000..619571c74
--- /dev/null
+++ b/scripts/kvm_tests.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-all-images
+
+# Ensure that KVM is loaded and that we can use it.
+(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
+sudo chmod a+rw /dev/kvm
+
+# Run all KVM platform tests (locally).
+run_as_root //pkg/sentry/platform/kvm:kvm_test
+
+# Install the KVM runtime and run all integration tests.
+install_runsc_for_test kvm --platform=kvm
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh
new file mode 100755
index 000000000..dbf1bba77
--- /dev/null
+++ b/scripts/make_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make runsc
+make bazel-shutdown
diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh
new file mode 100755
index 000000000..448864953
--- /dev/null
+++ b/scripts/overlay_tests.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-all-images
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test overlay --overlay
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh
new file mode 100755
index 000000000..727503bce
--- /dev/null
+++ b/scripts/packetdrill_tests.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-packetdrill
+
+install_runsc_for_test runsc-d
+query "attr(tags, manual, tests(//test/packetdrill/...))"
+test_runsc $QUERY_RESULT
diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh
new file mode 100755
index 000000000..51c11f23f
--- /dev/null
+++ b/scripts/packetimpact_tests.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-packetimpact
+
+install_runsc_for_test runsc-d
+query "attr(tags, packetimpact, tests(//test/packetimpact/...))"
+test_runsc $QUERY_RESULT
diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh
new file mode 100755
index 000000000..d629bf2aa
--- /dev/null
+++ b/scripts/root_tests.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-all-images
+
+# Reinstall the latest containerd shim.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+wget --no-verbose "${base}"/latest -O ${latest}
+wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
+chmod +x ${shim_path}
+sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim
+
+# Run the tests that require root.
+install_runsc_for_test root
+run_as_root //test/root:root_test --runtime=${RUNTIME}
diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh
new file mode 100755
index 000000000..350a59f7c
--- /dev/null
+++ b/scripts/runtime_tests.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Check that a runtime is provided.
+if [ ! -v RUNTIME_TEST_NAME ]; then
+ echo "Must set $RUNTIME_TEST_NAME" >&2
+ exit 1
+fi
+
+install_runsc_for_test runtimes
+test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test"
diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh
new file mode 100755
index 000000000..3a15050c2
--- /dev/null
+++ b/scripts/simple_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all simple tests (locally).
+test //pkg/... //runsc/... //tools/... //benchmarks/... //benchmarks/runner:runner_test
diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh
new file mode 100755
index 000000000..c67f2fe5c
--- /dev/null
+++ b/scripts/swgso_tests.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-all-images
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test swgso --software-gso=true --gso=false
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh
new file mode 100755
index 000000000..0e5d86727
--- /dev/null
+++ b/scripts/syscall_kvm_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all KVM-variants of the system call tests.
+test --test_tag_filters=runsc_kvm //test/syscalls/...
diff --git a/scripts/syscall_tests.sh b/scripts/syscall_tests.sh
new file mode 100755
index 000000000..a131b2d50
--- /dev/null
+++ b/scripts/syscall_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all ptrace-variants of the system call tests.
+test --test_tag_filters=runsc_ptrace //test/syscalls/...
diff --git a/test/BUILD b/test/BUILD
new file mode 100644
index 000000000..34b950644
--- /dev/null
+++ b/test/BUILD
@@ -0,0 +1 @@
+package(licenses = ["notice"])
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 000000000..02bbf42ff
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,40 @@
+# Tests
+
+The tests defined under this path verify functionality beyond what unit tests
+can cover, e.g. integration and end-to-end tests. Due to their nature, they may
+need extra setup on the test machine and extra configuration to run.
+
+- **syscalls**: system call tests use a local runner, and do not require
+  additional configuration on the machine.
+- **integration:** defines integration tests that use `docker run` to test
+  functionality.
+- **image:** basic end-to-end tests for popular images. These require the same
+  setup as integration tests.
+- **root:** tests that must be run as root. These require the same setup as
+  integration tests.
+- **util:** utilities library to support the tests.
+
+For the cases noted above, the relevant runtime must be installed via `runsc
+install` before running. Note that the tests require specific configuration to
+work. This is handled automatically by the test scripts in the `scripts`
+directory; they can be used to run tests locally on your machine, and they are
+also used to run these tests in `kokoro`.
+
+**Example:**
+
+To run image and integration tests, run:
+
+`./scripts/docker_tests.sh`
+
+To run root tests, run:
+
+`./scripts/root_tests.sh`
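+
+The scripts install the runtime with `runsc install` before running the tests;
+a minimal manual equivalent (a sketch; the exact flags vary by test) is:
+
+`sudo runsc install --experimental=true --runtime=runsc`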
+
+There are a few other interesting variations for image and integration tests:
+
+* overlay: sets a writable overlay inside the sentry
+* hostnet: configures host network pass-through, instead of netstack
+* kvm: runs the tests using the KVM platform, instead of ptrace
+
+The scripts will build runsc, configure it with your local Docker, restart
+`dockerd`, and run the tests. The location of the runsc logs is printed to the
+output.
diff --git a/test/cmd/test_app/BUILD b/test/cmd/test_app/BUILD
new file mode 100644
index 000000000..98ba5a3d9
--- /dev/null
+++ b/test/cmd/test_app/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "test_app",
+ testonly = 1,
+ srcs = [
+ "fds.go",
+ "test_app.go",
+ ],
+ pure = True,
+ visibility = ["//runsc/container:__pkg__"],
+ deps = [
+ "//pkg/test/testutil",
+ "//pkg/unet",
+ "//runsc/flag",
+ "@com_github_google_subcommands//:go_default_library",
+ "@com_github_kr_pty//:go_default_library",
+ ],
+)
diff --git a/test/cmd/test_app/fds.go b/test/cmd/test_app/fds.go
new file mode 100644
index 000000000..a7658eefd
--- /dev/null
+++ b/test/cmd/test_app/fds.go
@@ -0,0 +1,185 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "io/ioutil"
+ "log"
+ "os"
+ "time"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+const fileContents = "foobarbaz"
+
+// fdSender will open a file and send the FD over a unix domain socket.
+type fdSender struct {
+ socketPath string
+}
+
+// Name implements subcommands.Command.Name.
+func (*fdSender) Name() string {
+ return "fd_sender"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*fdSender) Synopsis() string {
+ return "creates a file and sends the FD over the socket"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*fdSender) Usage() string {
+ return "fd_sender <flags>"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (fds *fdSender) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&fds.socketPath, "socket", "", "path to socket")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if fds.socketPath == "" {
+ log.Fatalf("socket flag must be set")
+ }
+
+ dir, err := ioutil.TempDir("", "")
+ if err != nil {
+ log.Fatalf("TempDir failed: %v", err)
+ }
+
+ fileToSend, err := ioutil.TempFile(dir, "")
+ if err != nil {
+ log.Fatalf("TempFile failed: %v", err)
+ }
+ defer fileToSend.Close()
+
+ if _, err := fileToSend.WriteString(fileContents); err != nil {
+ log.Fatalf("Write(%q) failed: %v", fileContents, err)
+ }
+
+ // Receiver may not be started yet, so try connecting in a poll loop.
+ var s *unet.Socket
+ if err := testutil.Poll(func() error {
+ var err error
+ s, err = unet.Connect(fds.socketPath, true /* SEQPACKET, so we can send empty message with FD */)
+ return err
+ }, 10*time.Second); err != nil {
+ log.Fatalf("Error connecting to socket %q: %v", fds.socketPath, err)
+ }
+ defer s.Close()
+
+ w := s.Writer(true)
+ w.ControlMessage.PackFDs(int(fileToSend.Fd()))
+ if _, err := w.WriteVec([][]byte{[]byte{'a'}}); err != nil {
+ log.Fatalf("Error sending FD %q over socket %q: %v", fileToSend.Fd(), fds.socketPath, err)
+ }
+
+ log.Print("FD SENDER exiting successfully")
+ return subcommands.ExitSuccess
+}
+
+// fdReceiver receives an FD over a unix domain socket and verifies the
+// contents of the file it refers to.
+type fdReceiver struct {
+ socketPath string
+}
+
+// Name implements subcommands.Command.Name.
+func (*fdReceiver) Name() string {
+ return "fd_receiver"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*fdReceiver) Synopsis() string {
+ return "reads an FD from a unix socket, and then does things to it"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*fdReceiver) Usage() string {
+ return "fd_receiver <flags>"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (fdr *fdReceiver) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&fdr.socketPath, "socket", "", "path to socket")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if fdr.socketPath == "" {
+ log.Fatalf("Flags cannot be empty, given: socket: %q", fdr.socketPath)
+ }
+
+ ss, err := unet.BindAndListen(fdr.socketPath, true /* packet */)
+ if err != nil {
+ log.Fatalf("BindAndListen(%q) failed: %v", fdr.socketPath, err)
+ }
+ defer ss.Close()
+
+ var s *unet.Socket
+ c := make(chan error, 1)
+ go func() {
+ var err error
+ s, err = ss.Accept()
+ c <- err
+ }()
+
+ select {
+ case err := <-c:
+ if err != nil {
+ log.Fatalf("Accept() failed: %v", err)
+ }
+ case <-time.After(10 * time.Second):
+ log.Fatalf("Timeout waiting for accept")
+ }
+
+ r := s.Reader(true)
+ r.EnableFDs(1)
+ b := [][]byte{{'a'}}
+ if n, err := r.ReadVec(b); n != 1 || err != nil {
+ log.Fatalf("ReadVec got n=%d err %v (wanted 0, nil)", n, err)
+ }
+
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ log.Fatalf("ExtractFD() got err %v", err)
+ }
+ if len(fds) != 1 {
+ log.Fatalf("ExtractFD() got %d FDs, wanted 1", len(fds))
+ }
+ fd := fds[0]
+
+ file := os.NewFile(uintptr(fd), "received file")
+ defer file.Close()
+ if _, err := file.Seek(0, os.SEEK_SET); err != nil {
+ log.Fatalf("Seek(0, 0) failed: %v", err)
+ }
+
+ got, err := ioutil.ReadAll(file)
+ if err != nil {
+ log.Fatalf("ReadAll failed: %v", err)
+ }
+ if string(got) != fileContents {
+ log.Fatalf("ReadAll got %q want %q", string(got), fileContents)
+ }
+
+ log.Print("FD RECEIVER exiting successfully")
+ return subcommands.ExitSuccess
+}
diff --git a/test/cmd/test_app/test_app.go b/test/cmd/test_app/test_app.go
new file mode 100644
index 000000000..3ba4f38f8
--- /dev/null
+++ b/test/cmd/test_app/test_app.go
@@ -0,0 +1,394 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary test_app is like a Swiss Army knife for tests that need to run
+// anything inside the sandbox. New functionality can be added with new
+// commands.
+package main
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "os"
+ "os/exec"
+ "regexp"
+ "strconv"
+ sys "syscall"
+ "time"
+
+ "github.com/google/subcommands"
+ "github.com/kr/pty"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+func main() {
+ subcommands.Register(subcommands.HelpCommand(), "")
+ subcommands.Register(subcommands.FlagsCommand(), "")
+ subcommands.Register(new(capability), "")
+ subcommands.Register(new(fdReceiver), "")
+ subcommands.Register(new(fdSender), "")
+ subcommands.Register(new(forkBomb), "")
+ subcommands.Register(new(ptyRunner), "")
+ subcommands.Register(new(reaper), "")
+ subcommands.Register(new(syscall), "")
+ subcommands.Register(new(taskTree), "")
+ subcommands.Register(new(uds), "")
+
+ flag.Parse()
+
+ exitCode := subcommands.Execute(context.Background())
+ os.Exit(int(exitCode))
+}
+
+type uds struct {
+ fileName string
+ socketPath string
+}
+
+// Name implements subcommands.Command.Name.
+func (*uds) Name() string {
+ return "uds"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*uds) Synopsis() string {
+ return "creates unix domain socket client and server. Client sends a contant flow of sequential numbers. Server prints them to --file"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*uds) Usage() string {
+ return "uds <flags>"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *uds) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&c.fileName, "file", "", "name of output file")
+ f.StringVar(&c.socketPath, "socket", "", "path to socket")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if c.fileName == "" || c.socketPath == "" {
+ log.Fatalf("Flags cannot be empty, given: fileName: %q, socketPath: %q", c.fileName, c.socketPath)
+ return subcommands.ExitFailure
+ }
+ outputFile, err := os.OpenFile(c.fileName, os.O_WRONLY|os.O_CREATE, 0666)
+ if err != nil {
+ log.Fatal("error opening output file:", err)
+ }
+
+ defer os.Remove(c.socketPath)
+
+ listener, err := net.Listen("unix", c.socketPath)
+ if err != nil {
+ log.Fatalf("error listening on socket %q: %v", c.socketPath, err)
+ }
+
+ go server(listener, outputFile)
+ for i := 0; ; i++ {
+ conn, err := net.Dial("unix", c.socketPath)
+ if err != nil {
+ log.Fatal("error dialing:", err)
+ }
+ if _, err := conn.Write([]byte(strconv.Itoa(i))); err != nil {
+ log.Fatal("error writing:", err)
+ }
+ conn.Close()
+ time.Sleep(100 * time.Millisecond)
+ }
+}
+
+func server(listener net.Listener, out *os.File) {
+ buf := make([]byte, 16)
+
+ for {
+ c, err := listener.Accept()
+ if err != nil {
+ log.Fatal("error accepting connection:", err)
+ }
+ nr, err := c.Read(buf)
+ if err != nil {
+ log.Fatal("error reading from buf:", err)
+ }
+ data := buf[0:nr]
+ fmt.Fprint(out, string(data)+"\n")
+ }
+}
+
+type taskTree struct {
+ depth int
+ width int
+ pause bool
+}
+
+// Name implements subcommands.Command.
+func (*taskTree) Name() string {
+ return "task-tree"
+}
+
+// Synopsis implements subcommands.Command.
+func (*taskTree) Synopsis() string {
+ return "creates a tree of tasks"
+}
+
+// Usage implements subcommands.Command.
+func (*taskTree) Usage() string {
+ return "task-tree <flags>"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *taskTree) SetFlags(f *flag.FlagSet) {
+ f.IntVar(&c.depth, "depth", 1, "number of levels to create")
+ f.IntVar(&c.width, "width", 1, "number of tasks at each level")
+ f.BoolVar(&c.pause, "pause", false, "whether the tasks should pause perpetually")
+}
+
+// Execute implements subcommands.Command.
+func (c *taskTree) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ stop := testutil.StartReaper()
+ defer stop()
+
+ if c.depth == 0 {
+ log.Printf("Child sleeping, PID: %d\n", os.Getpid())
+ select {}
+ }
+ log.Printf("Parent %d sleeping, PID: %d\n", c.depth, os.Getpid())
+
+ var cmds []*exec.Cmd
+ for i := 0; i < c.width; i++ {
+ cmd := exec.Command(
+ "/proc/self/exe", c.Name(),
+ "--depth", strconv.Itoa(c.depth-1),
+ "--width", strconv.Itoa(c.width),
+ "--pause", strconv.FormatBool(c.pause))
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Start(); err != nil {
+ log.Fatal("failed to call self:", err)
+ }
+ cmds = append(cmds, cmd)
+ }
+
+ for _, c := range cmds {
+ c.Wait()
+ }
+
+ if c.pause {
+ select {}
+ }
+
+ return subcommands.ExitSuccess
+}
+
+type forkBomb struct {
+ delay time.Duration
+}
+
+// Name implements subcommands.Command.
+func (*forkBomb) Name() string {
+ return "fork-bomb"
+}
+
+// Synopsis implements subcommands.Command.
+func (*forkBomb) Synopsis() string {
+ return "creates child process until the end of times"
+}
+
+// Usage implements subcommands.Command.
+func (*forkBomb) Usage() string {
+ return "fork-bomb <flags>"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *forkBomb) SetFlags(f *flag.FlagSet) {
+ f.DurationVar(&c.delay, "delay", 100*time.Millisecond, "amount of time to delay creation of child")
+}
+
+// Execute implements subcommands.Command.
+func (c *forkBomb) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ time.Sleep(c.delay)
+
+ cmd := exec.Command("/proc/self/exe", c.Name())
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ if err := cmd.Run(); err != nil {
+ log.Fatal("failed to call self:", err)
+ }
+ return subcommands.ExitSuccess
+}
+
+type reaper struct{}
+
+// Name implements subcommands.Command.
+func (*reaper) Name() string {
+ return "reaper"
+}
+
+// Synopsis implements subcommands.Command.
+func (*reaper) Synopsis() string {
+ return "reaps all children in a loop"
+}
+
+// Usage implements subcommands.Command.
+func (*reaper) Usage() string {
+ return "reaper <flags>"
+}
+
+// SetFlags implements subcommands.Command.
+func (*reaper) SetFlags(*flag.FlagSet) {}
+
+// Execute implements subcommands.Command.
+func (c *reaper) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ stop := testutil.StartReaper()
+ defer stop()
+ select {}
+}
+
+type syscall struct {
+ sysno uint64
+}
+
+// Name implements subcommands.Command.
+func (*syscall) Name() string {
+ return "syscall"
+}
+
+// Synopsis implements subcommands.Command.
+func (*syscall) Synopsis() string {
+ return "syscall makes a syscall"
+}
+
+// Usage implements subcommands.Command.
+func (*syscall) Usage() string {
+ return "syscall <flags>"
+}
+
+// SetFlags implements subcommands.Command.
+func (s *syscall) SetFlags(f *flag.FlagSet) {
+ f.Uint64Var(&s.sysno, "syscall", 0, "syscall to call")
+}
+
+// Execute implements subcommands.Command.
+func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if _, _, errno := sys.Syscall(uintptr(s.sysno), 0, 0, 0); errno != 0 {
+ fmt.Printf("syscall(%d, 0, 0...) failed: %v\n", s.sysno, errno)
+ } else {
+ fmt.Printf("syscall(%d, 0, 0...) success\n", s.sysno)
+ }
+ return subcommands.ExitSuccess
+}
+
+type capability struct {
+ enabled uint64
+ disabled uint64
+}
+
+// Name implements subcommands.Command.
+func (*capability) Name() string {
+ return "capability"
+}
+
+// Synopsis implements subcommands.Command.
+func (*capability) Synopsis() string {
+ return "checks if effective capabilities are set/unset"
+}
+
+// Usage implements subcommands.Command.
+func (*capability) Usage() string {
+ return "capability [--enabled=number] [--disabled=number]"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *capability) SetFlags(f *flag.FlagSet) {
+ f.Uint64Var(&c.enabled, "enabled", 0, "")
+ f.Uint64Var(&c.disabled, "disabled", 0, "")
+}
+
+// Execute implements subcommands.Command.
+func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if c.enabled == 0 && c.disabled == 0 {
+ fmt.Println("One of the flags must be set")
+ return subcommands.ExitUsageError
+ }
+
+ status, err := ioutil.ReadFile("/proc/self/status")
+ if err != nil {
+ fmt.Printf("Error reading %q: %v\n", "proc/self/status", err)
+ return subcommands.ExitFailure
+ }
+ re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n")
+ matches := re.FindStringSubmatch(string(status))
+ if matches == nil || len(matches) != 2 {
+ fmt.Printf("Effective capabilities not found in\n%s\n", status)
+ return subcommands.ExitFailure
+ }
+ caps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err)
+ return subcommands.ExitFailure
+ }
+
+ if c.enabled != 0 && (caps&c.enabled) != c.enabled {
+ fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps)
+ return subcommands.ExitFailure
+ }
+ if c.disabled != 0 && (caps&c.disabled) != 0 {
+ fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps)
+ return subcommands.ExitFailure
+ }
+
+ return subcommands.ExitSuccess
+}
+
+type ptyRunner struct{}
+
+// Name implements subcommands.Command.
+func (*ptyRunner) Name() string {
+ return "pty-runner"
+}
+
+// Synopsis implements subcommands.Command.
+func (*ptyRunner) Synopsis() string {
+ return "runs the given command with an open pty terminal"
+}
+
+// Usage implements subcommands.Command.
+func (*ptyRunner) Usage() string {
+ return "pty-runner [command]"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (*ptyRunner) SetFlags(f *flag.FlagSet) {}
+
+// Execute implements subcommands.Command.
+func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
+ c := exec.Command(fs.Args()[0], fs.Args()[1:]...)
+ f, err := pty.Start(c)
+ if err != nil {
+ fmt.Printf("pty.Start failed: %v", err)
+ return subcommands.ExitFailure
+ }
+ defer f.Close()
+
+ // Copy stdout from the command to keep this process alive until the
+ // subprocess exits.
+ io.Copy(os.Stdout, f)
+
+ return subcommands.ExitSuccess
+}
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
new file mode 100644
index 000000000..29a84f184
--- /dev/null
+++ b/test/e2e/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "integration_test",
+ size = "large",
+ srcs = [
+ "exec_test.go",
+ "integration_test.go",
+ "regression_test.go",
+ ],
+ library = ":integration",
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/bits",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ "//runsc/specutils",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ ],
+)
+
+go_library(
+ name = "integration",
+ srcs = ["integration.go"],
+)
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
new file mode 100644
index 000000000..b47df447c
--- /dev/null
+++ b/test/e2e/exec_test.go
@@ -0,0 +1,268 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package integration provides end-to-end integration tests for runsc. These
+// tests require docker and runsc to be installed on the machine.
+//
+// Each test calls docker commands to start up a container, and tests that it
+// is behaving properly, with various runsc commands. The container is killed
+// and deleted at the end.
+
+package integration
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Test that exec uses the exact same capability set as the container.
+func TestExecCapabilities(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Read the container's effective capabilities from its output.
+ matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ if err != nil {
+ t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
+ }
+ if len(matches) != 2 {
+ t.Fatalf("There should be a match for the whole line and the capability bitmask")
+ }
+ want := fmt.Sprintf("CapEff:\t%s\n", matches[1])
+ t.Log("Root capabilities:", want)
+
+ // Now check that exec'd process capabilities match the root.
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "grep", "CapEff:", "/proc/self/status")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ t.Logf("CapEff: %v", got)
+ if got != want {
+ t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
+ }
+}
+
+// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW
+// which is removed from the container when --net-raw=false.
+func TestExecPrivileged(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container with all capabilities dropped.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ CapDrop: []string{"all"},
+ }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Check that all capabilities were dropped from the container.
+ matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ if err != nil {
+ t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
+ }
+ if len(matches) != 2 {
+ t.Fatalf("There should be a match for the whole line and the capability bitmask")
+ }
+ containerCaps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ t.Fatalf("failed to convert capabilities %q: %v", matches[1], err)
+ }
+ t.Logf("Container capabilities: %#x", containerCaps)
+ if containerCaps != 0 {
+ t.Fatalf("Container should have no capabilities: %x", containerCaps)
+ }
+
+ // Check that 'exec --privileged' adds all capabilities, except for
+ // CAP_NET_RAW.
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{
+ Privileged: true,
+ }, "grep", "CapEff:", "/proc/self/status")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ t.Logf("Exec CapEff: %v", got)
+ want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW)))
+ if got != want {
+ t.Errorf("Wrong capabilities, got: %q, want: %q. Make sure runsc is not using '--net-raw'", got, want)
+ }
+}
+
+func TestExecJobControl(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ p, err := d.ExecProcess(ctx, dockerutil.ExecOpts{UseTTY: true}, "/bin/sh")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+
+ if _, err = p.Write(time.Second, []byte("sleep 100 | cat\n")); err != nil {
+ t.Fatalf("error exit: %v", err)
+ }
+ time.Sleep(time.Second)
+
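+ // Send Ctrl-C (ETX, 0x03) to the TTY to interrupt the foreground
+ // "sleep | cat" pipeline via job control.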
+ if _, err = p.Write(time.Second, []byte{0x03}); err != nil {
+ t.Fatalf("error exit: %v", err)
+ }
+
+ if _, err = p.Write(time.Second, []byte("exit $(expr $? + 10)\n")); err != nil {
+ t.Fatalf("error exit: %v", err)
+ }
+
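+ // The pipeline is killed by SIGINT, so the shell sees exit status 130
+ // (128 + SIGINT); "exit $(expr $? + 10)" then yields 140.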
+ want := 140
+ got, err := p.WaitExitStatus(ctx)
+ if err != nil {
+ t.Fatalf("wait for exit failed with: %v", err)
+ } else if got != want {
+ t.Fatalf("wait for exit returned: %d want: %d", got, want)
+ }
+}
+
+// Test that failure to exec returns proper error message.
+func TestExecError(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Attempt to exec a binary that doesn't exist.
+ out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "no_can_find")
+ if err == nil {
+ t.Fatalf("docker exec didn't fail")
+ }
+ if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(out, want) {
+ t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", out, want)
+ }
+}
+
+// Test that exec inherits environment from run.
+func TestExecEnv(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container with env FOO=BAR.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Env: []string{"FOO=BAR"},
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Exec "echo $FOO".
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $FOO")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if got, want := strings.TrimSpace(got), "BAR"; got != want {
+ t.Errorf("bad output from 'docker exec'. Got %q; Want %q.", got, want)
+ }
+}
+
+// TestRunEnvHasHome tests that run always has HOME environment set.
+func TestRunEnvHasHome(t *testing.T) {
+ // Base alpine image does not have any environment variables set.
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Exec "echo $HOME". The 'bin' user's home dir is '/bin'.
+ got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ User: "bin",
+ }, "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Check that the directory matches.
+ if got, want := strings.TrimSpace(got), "/bin"; got != want {
+ t.Errorf("bad output from 'docker run'. Got %q; Want %q.", got, want)
+ }
+}
+
+// Test that exec always has HOME environment set, even when not set in run.
+func TestExecEnvHasHome(t *testing.T) {
+ // Base alpine image does not have any environment variables set.
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Exec "echo $HOME", and expect to see "/root".
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "/root"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+
+ // Create a new user with a home directory.
+ newUID := 1234
+ newHome := "/foo/bar"
+ cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome)
+ if _, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", cmd); err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+
+ // Execute the same as the new user and expect newHome.
+ got, err = d.Exec(ctx, dockerutil.ExecOpts{
+ User: strconv.Itoa(newUID),
+ }, "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := newHome; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
diff --git a/test/e2e/integration.go b/test/e2e/integration.go
new file mode 100644
index 000000000..4cd5f6c24
--- /dev/null
+++ b/test/e2e/integration.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package integration is empty. See integration_test.go for description.
+package integration
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
new file mode 100644
index 000000000..5a9455b33
--- /dev/null
+++ b/test/e2e/integration_test.go
@@ -0,0 +1,441 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package integration provides end-to-end integration tests for runsc.
+//
+// Each test calls docker commands to start up a container and uses various
+// runsc commands to test that it behaves properly. The container is killed
+// and deleted at the end.
+//
+// Setup instructions are in test/README.md.
+package integration
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// httpRequestSucceeds sends a request to a given url and checks that the status is OK.
+func httpRequestSucceeds(client http.Client, server string, port int) error {
+ url := fmt.Sprintf("http://%s:%d", server, port)
+ // Ensure that content is being served.
+ resp, err := client.Get(url)
+ if err != nil {
+ return fmt.Errorf("error reaching http server: %v", err)
+ }
+ if want := http.StatusOK; resp.StatusCode != want {
+ return fmt.Errorf("wrong response code, got: %d, want: %d", resp.StatusCode, want)
+ }
+ return nil
+}
+
+// TestLifeCycle tests a basic Create/Start/Stop docker container life cycle.
+func TestLifeCycle(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Create(ctx, dockerutil.RunOpts{
+ Image: "basic/nginx",
+ Ports: []int{80},
+ }); err != nil {
+ t.Fatalf("docker create failed: %v", err)
+ }
+ if err := d.Start(ctx); err != nil {
+ t.Fatalf("docker start failed: %v", err)
+ }
+
+ // Test that container is working.
+ port, err := d.FindPort(ctx, 80)
+ if err != nil {
+ t.Fatalf("docker.FindPort(80) failed: %v", err)
+ }
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
+ }
+ client := http.Client{Timeout: 2 * time.Second}
+ if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ t.Errorf("http request failed: %v", err)
+ }
+
+ if err := d.Stop(ctx); err != nil {
+ t.Fatalf("docker stop failed: %v", err)
+ }
+ if err := d.Remove(ctx); err != nil {
+ t.Fatalf("docker rm failed: %v", err)
+ }
+}
+
+func TestPauseResume(t *testing.T) {
+ if !testutil.IsCheckpointSupported() {
+ t.Skip("Checkpoint is not supported.")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/python",
+ Ports: []int{8080}, // See Dockerfile.
+ }); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Find where port 8080 is mapped to.
+ port, err := d.FindPort(ctx, 8080)
+ if err != nil {
+ t.Fatalf("docker.FindPort(8080) failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
+ }
+
+ // Check that container is working.
+ client := http.Client{Timeout: 2 * time.Second}
+ if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ t.Error("http request failed:", err)
+ }
+
+ if err := d.Pause(ctx); err != nil {
+ t.Fatalf("docker pause failed: %v", err)
+ }
+
+ // Check if container is paused.
+ switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) {
+ case nil:
+ t.Errorf("http req expected to fail but it succeeded")
+ case net.Error:
+ if !v.Timeout() {
+ t.Errorf("http req got error %v, wanted timeout", v)
+ }
+ default:
+ t.Errorf("http req got unexpected error %v", v)
+ }
+
+ if err := d.Unpause(ctx); err != nil {
+ t.Fatalf("docker unpause failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
+ }
+
+ // Check if container is working again.
+ if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ t.Error("http request failed:", err)
+ }
+}
+
+func TestCheckpointRestore(t *testing.T) {
+ if !testutil.IsCheckpointSupported() {
+ t.Skip("Pause/resume is not supported.")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/python",
+ Ports: []int{8080}, // See Dockerfile.
+ }); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Create a snapshot.
+ if err := d.Checkpoint(ctx, "test"); err != nil {
+ t.Fatalf("docker checkpoint failed: %v", err)
+ }
+ if err := d.WaitTimeout(ctx, 30*time.Second); err != nil {
+ t.Fatalf("wait failed: %v", err)
+ }
+
+ // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed.
+ if err := testutil.Poll(func() error { return d.Restore(ctx, "test") }, 15*time.Second); err != nil {
+ t.Fatalf("docker restore failed: %v", err)
+ }
+
+ // Find where port 8080 is mapped to.
+ port, err := d.FindPort(ctx, 8080)
+ if err != nil {
+ t.Fatalf("docker.FindPort(8080) failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
+ }
+
+ // Check if container is working again.
+ client := http.Client{Timeout: 2 * time.Second}
+ if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ t.Error("http request failed:", err)
+ }
+}
+
+// Create client and server that talk to each other using the local IP.
+func TestConnectToSelf(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Creates a server that replies "server" and exits. Sleeps at the end because
+ // 'docker exec' gets killed if the init process exits before it can finish.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ }, "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Find the container's IP address from /etc/hosts.
+ ip, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ ip = strings.TrimRight(ip, "\n")
+
+ // Runs client that sends "client" to the server and exits.
+ reply, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip))
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+
+ // Ensure both client and server got the message from each other.
+ if want := "server\n"; reply != want {
+ t.Errorf("Error on server, want: %q, got: %q", want, reply)
+ }
+ if _, err := d.WaitForOutput(ctx, "^client\n$", 1*time.Second); err != nil {
+ t.Fatalf("docker.WaitForOutput(client) timeout: %v", err)
+ }
+}
+
+func TestMemLimit(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // N.B. Because the size of the memory file may grow in large chunks,
+ // there is a minimum threshold of 1GB for the MemTotal figure.
+ allocMemory := 1024 * 1024 // In kb.
+ out, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Memory: allocMemory * 1024, // In bytes.
+ }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Remove warning message that swap isn't present.
+ if strings.HasPrefix(out, "WARNING") {
+ lines := strings.Split(out, "\n")
+ if len(lines) != 3 {
+ t.Fatalf("invalid output: %s", out)
+ }
+ out = lines[1]
+ }
+
+ // Ensure the memory matches what we want.
+ got, err := strconv.ParseUint(strings.TrimSpace(out), 10, 64)
+ if err != nil {
+ t.Fatalf("failed to parse %q: %v", out, err)
+ }
+ if want := uint64(allocMemory); got != want {
+ t.Errorf("MemTotal got: %d, want: %d", got, want)
+ }
+}
+
+func TestNumCPU(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Read how many cores are in the container.
+ out, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ CpusetCpus: "0",
+ }, "sh", "-c", "cat /proc/cpuinfo | grep 'processor.*:' | wc -l")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Ensure it matches what we want.
+ got, err := strconv.Atoi(strings.TrimSpace(out))
+ if err != nil {
+ t.Fatalf("failed to parse %q: %v", out, err)
+ }
+ if want := 1; got != want {
+ t.Errorf("MemTotal got: %d, want: %d", got, want)
+ }
+}
+
+// TestJobControl tests that job control characters are handled properly.
+func TestJobControl(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container with an attached PTY.
+ p, err := d.SpawnProcess(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sh", "-c", "sleep 100 | cat")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ // Give shell a few seconds to start executing the sleep.
+ time.Sleep(2 * time.Second)
+
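+ // Send Ctrl-C (ETX, 0x03) to interrupt the pipeline; the shell then exits.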
+ if _, err := p.Write(time.Second, []byte{0x03}); err != nil {
+ t.Fatalf("error exit: %v", err)
+ }
+
+ if err := d.WaitTimeout(ctx, 3*time.Second); err != nil {
+ t.Fatalf("WaitTimeout failed: %v", err)
+ }
+
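+ // 130 = 128 + SIGINT(2), the shell's exit status after Ctrl-C.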
+ want := 130
+ got, err := p.WaitExitStatus(ctx)
+ if err != nil {
+ t.Fatalf("wait for exit failed with: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %d want: %d", got, want)
+ }
+}
+
+// TestWorkingDirCreation checks that the working dir is created if it doesn't exist.
+func TestWorkingDirCreation(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ workingDir string
+ }{
+ {name: "root", workingDir: "/foo"},
+ {name: "tmp", workingDir: "/tmp/foo"},
+ } {
+ for _, readonly := range []bool{true, false} {
+ name := tc.name
+ if readonly {
+ name += "-readonly"
+ }
+ t.Run(name, func(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ WorkDir: tc.workingDir,
+ ReadOnly: readonly,
+ }
+ got, err := d.Run(ctx, opts, "sh", "-c", "echo ${PWD}")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want := tc.workingDir + "\n"; want != got {
+ t.Errorf("invalid working dir, want: %q, got: %q", want, got)
+ }
+ })
+ }
+ }
+}
+
+// TestTmpFile checks that files inside '/tmp' are not overridden.
+func TestTmpFile(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{Image: "tmpfile"}
+ got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want := "123\n"; want != got {
+ t.Errorf("invalid file content, want: %q, got: %q", want, got)
+ }
+}
+
+// TestTmpMount checks that mounts inside '/tmp' are not overridden.
+func TestTmpMount(t *testing.T) {
+ ctx := context.Background()
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount")
+ if err != nil {
+ t.Fatalf("TempDir(): %v", err)
+ }
+ want := "123"
+ if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil {
+ t.Fatalf("WriteFile(): %v", err)
+ }
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Mounts: []mount.Mount{
+ {
+ Type: mount.TypeBind,
+ Source: dir,
+ Target: "/tmp/foo",
+ },
+ },
+ }
+ got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want != got {
+ t.Errorf("invalid file content, want: %q, got: %q", want, got)
+ }
+}
+
+// TestHostOverlayfsCopyUp tests that the --overlayfs-stale-read option causes
+// runsc to hide the incoherence of FDs opened before and after overlayfs
+// copy-up on the host.
+func TestHostOverlayfsCopyUp(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if _, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "hostoverlaytest",
+ WorkDir: "/root",
+ }, "./test"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+}
+
+func TestMain(m *testing.M) {
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
+ os.Exit(m.Run())
+}
diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go
new file mode 100644
index 000000000..70bbe5121
--- /dev/null
+++ b/test/e2e/regression_test.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+)
+
+// Test that UDS can be created using overlay when parent directory is in lower
+// layer only (b/134090485).
+//
+// Prerequisite: the directory where the socket file is created must not have
+// been open for write before bind(2) is called.
+func TestBindOverlay(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Run the container.
+ got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Check the output contains what we want.
+ if want := "foobar-asdf"; !strings.Contains(got, want) {
+ t.Fatalf("docker run output is missing %q: %s", want, got)
+ }
+}
diff --git a/test/image/BUILD b/test/image/BUILD
new file mode 100644
index 000000000..e749e47d4
--- /dev/null
+++ b/test/image/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "image_test",
+ size = "large",
+ srcs = [
+ "image_test.go",
+ ],
+ data = [
+ "latin10k.txt",
+ "mysql.sql",
+ "ruby.rb",
+ "ruby.sh",
+ ],
+ library = ":image",
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
+
+go_library(
+ name = "image",
+ srcs = ["image.go"],
+)
diff --git a/test/image/image.go b/test/image/image.go
new file mode 100644
index 000000000..297f1ab92
--- /dev/null
+++ b/test/image/image.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package image is empty. See image_test.go for description.
+package image
diff --git a/test/image/image_test.go b/test/image/image_test.go
new file mode 100644
index 000000000..8aa78035f
--- /dev/null
+++ b/test/image/image_test.go
@@ -0,0 +1,312 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package image provides end-to-end image tests for runsc.
+//
+// Each test calls docker commands to start up a container and tests that it
+// behaves properly, such as connecting to a port or checking the output.
+// The container is killed and deleted at the end.
+//
+// Setup instructions are in test/README.md.
+package image
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+func TestHelloWorld(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Run the basic container.
+ out, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "echo", "Hello world!")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Check the output.
+ if !strings.Contains(out, "Hello world!") {
+ t.Fatalf("docker didn't say hello: got %s", out)
+ }
+}
+
+func runHTTPRequest(port int) error {
+ url := fmt.Sprintf("http://localhost:%d/not-found", port)
+ resp, err := http.Get(url)
+ if err != nil {
+ return fmt.Errorf("error reaching http server: %v", err)
+ }
+ if want := http.StatusNotFound; resp.StatusCode != want {
+ return fmt.Errorf("wrong response code, got: %d, want: %d", resp.StatusCode, want)
+ }
+ resp.Body.Close()
+
+ url = fmt.Sprintf("http://localhost:%d/latin10k.txt", port)
+ resp, err = http.Get(url)
+ if err != nil {
+ return fmt.Errorf("error reaching http server: %v", err)
+ }
+ if want := http.StatusOK; resp.StatusCode != want {
+ return fmt.Errorf("wrong response code, got: %d, want: %d", resp.StatusCode, want)
+ }
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("error reading http response: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // READALL is the last word in the file; ensure everything was read.
+ if want := "READALL"; !strings.HasSuffix(string(body), want) {
+ return fmt.Errorf("response doesn't end with %q, resp: %q", want, body)
+ }
+ return nil
+}
+
+func testHTTPServer(t *testing.T, port int) {
+ const requests = 10
+ ch := make(chan error, requests)
+ for i := 0; i < requests; i++ {
+ go func() {
+ start := time.Now()
+ err := runHTTPRequest(port)
+ log.Printf("Response time %v: %v", time.Since(start).String(), err)
+ ch <- err
+ }()
+ }
+
+ for i := 0; i < requests; i++ {
+ err := <-ch
+ if err != nil {
+ t.Errorf("testHTTPServer(%d) failed: %v", port, err)
+ }
+ }
+}
+
+func TestHttpd(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ opts := dockerutil.RunOpts{
+ Image: "basic/httpd",
+ Ports: []int{80},
+ }
+ d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt")
+ if err := d.Spawn(ctx, opts); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Find where port 80 is mapped to.
+ port, err := d.FindPort(ctx, 80)
+ if err != nil {
+ t.Fatalf("FindPort(80) failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Errorf("WaitForHTTP() timeout: %v", err)
+ }
+
+ testHTTPServer(t, port)
+}
+
+func TestNginx(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ opts := dockerutil.RunOpts{
+ Image: "basic/nginx",
+ Ports: []int{80},
+ }
+ d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt")
+ if err := d.Spawn(ctx, opts); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Find where port 80 is mapped to.
+ port, err := d.FindPort(ctx, 80)
+ if err != nil {
+ t.Fatalf("FindPort(80) failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Errorf("WaitForHTTP() timeout: %v", err)
+ }
+
+ testHTTPServer(t, port)
+}
+
+func TestMysql(t *testing.T) {
+ ctx := context.Background()
+ server := dockerutil.MakeContainer(ctx, t)
+ defer server.CleanUp(ctx)
+
+ // Start the container.
+ if err := server.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/mysql",
+ Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"},
+ }); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if _, err := server.WaitForOutput(ctx, "port: 3306 MySQL Community Server", 3*time.Minute); err != nil {
+ t.Fatalf("WaitForOutput() timeout: %v", err)
+ }
+
+ // Generate the client and copy in the SQL payload.
+ client := dockerutil.MakeContainer(ctx, t)
+ defer client.CleanUp(ctx)
+
+ // Tell mysql client to connect to the server and execute the file in
+ // verbose mode to verify the output.
+ opts := dockerutil.RunOpts{
+ Image: "basic/mysql",
+ Links: []string{server.MakeLink("mysql")},
+ }
+ client.CopyFiles(&opts, "/sql", "test/image/mysql.sql")
+ if _, err := client.Run(ctx, opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Ensure the file executed to the end and shut down mysql.
+ if _, err := server.WaitForOutput(ctx, "mysqld: Shutdown complete", 30*time.Second); err != nil {
+ t.Fatalf("WaitForOutput() timeout: %v", err)
+ }
+}
+
+func TestTomcat(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the server.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/tomcat",
+ Ports: []int{8080},
+ }); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Find where port 8080 is mapped to.
+ port, err := d.FindPort(ctx, 8080)
+ if err != nil {
+ t.Fatalf("FindPort(8080) failed: %v", err)
+ }
+
+ // Wait until it's up and running.
+ if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
+ }
+
+ // Ensure that content is being served.
+ url := fmt.Sprintf("http://localhost:%d", port)
+ resp, err := http.Get(url)
+ if err != nil {
+ t.Errorf("Error reaching http server: %v", err)
+ }
+ if want := http.StatusOK; resp.StatusCode != want {
+ t.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want)
+ }
+}
+
+func TestRuby(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Execute the ruby workload.
+ opts := dockerutil.RunOpts{
+ Image: "basic/ruby",
+ Ports: []int{8080},
+ }
+ d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh")
+ if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Find where port 8080 is mapped to.
+ port, err := d.FindPort(ctx, 8080)
+ if err != nil {
+ t.Fatalf("FindPort(8080) failed: %v", err)
+ }
+
+ // Wait until it's up and running, 'gem install' can take some time.
+ if err := testutil.WaitForHTTP(port, 1*time.Minute); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
+ }
+
+ // Ensure that content is being served.
+ url := fmt.Sprintf("http://localhost:%d", port)
+ resp, err := http.Get(url)
+ if err != nil {
+ t.Errorf("error reaching http server: %v", err)
+ }
+ if want := http.StatusOK; resp.StatusCode != want {
+ t.Errorf("wrong response code, got: %d, want: %d", resp.StatusCode, want)
+ }
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("error reading body: %v", err)
+ }
+ if got, want := string(body), "Hello World"; !strings.Contains(got, want) {
+ t.Errorf("invalid body content, got: %q, want: %q", got, want)
+ }
+}
+
+func TestStdio(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ wantStdout := "hello stdout"
+ wantStderr := "bonjour stderr"
+ cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr)
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "/bin/sh", "-c", cmd); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ for _, want := range []string{wantStdout, wantStderr} {
+ if _, err := d.WaitForOutput(ctx, want, 5*time.Second); err != nil {
+ t.Fatalf("docker didn't get output %q : %v", want, err)
+ }
+ }
+}
+
+func TestMain(m *testing.M) {
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
+ os.Exit(m.Run())
+}
diff --git a/test/image/latin10k.txt b/test/image/latin10k.txt
new file mode 100644
index 000000000..61341e00b
--- /dev/null
+++ b/test/image/latin10k.txt
@@ -0,0 +1,33 @@
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Cras ut placerat felis. Maecenas urna est, auctor a efficitur sit amet, egestas et augue. Curabitur dignissim scelerisque nunc vel cursus. Ut vehicula est pretium, consectetur nunc non, pharetra ligula. Curabitur ut ultricies metus. Suspendisse pulvinar, orci sed fermentum vestibulum, eros turpis molestie lectus, nec elementum risus dolor mattis felis. Donec ultrices ipsum sem, at pretium lacus convallis at. Mauris nulla enim, tincidunt non bibendum at, vehicula pulvinar mauris.
+
+Duis in dapibus turpis. Pellentesque maximus magna odio, ac congue libero laoreet quis. Maecenas euismod risus in justo aliquam accumsan. Nunc quis ornare arcu, sit amet sodales elit. Phasellus nec scelerisque nisl, a tincidunt arcu. Proin ornare est nunc, sed suscipit orci interdum et. Suspendisse condimentum venenatis diam in tempor. Aliquam egestas lectus in rutrum tempus. Donec id egestas eros. Donec molestie consequat purus, sed posuere odio venenatis vitae. Nunc placerat augue id vehicula varius. In hac habitasse platea dictumst. Proin at est accumsan, venenatis quam a, fermentum risus. Phasellus posuere pellentesque enim, id suscipit magna consequat ut. Quisque ut tortor ante.
+
+Cras ut vulputate metus, a laoreet lectus. Vivamus ultrices molestie odio in tristique. Morbi faucibus mi eget sollicitudin fringilla. Fusce vitae lacinia ligula. Sed egestas sed diam eu posuere. Maecenas justo nisl, venenatis vel nibh vel, cursus aliquam velit. Praesent lacinia dui id erat venenatis rhoncus. Morbi gravida felis ante, sit amet vehicula orci rhoncus vitae.
+
+Sed finibus sagittis dictum. Proin auctor suscipit sem et mattis. Phasellus libero ligula, pellentesque ut felis porttitor, fermentum sollicitudin orci. Nulla eu nulla nibh. Fusce a eros risus. Proin vel magna risus. Donec nec elit eleifend, scelerisque sapien vitae, pharetra quam. Donec porttitor mauris scelerisque, tempus orci hendrerit, dapibus felis. Nullam libero elit, sollicitudin a aliquam at, ultrices in erat. Mauris eget ligula sodales, porta turpis et, scelerisque odio. Mauris mollis leo vitae purus gravida, in tempor nunc efficitur. Nulla facilisis posuere augue, nec pellentesque lectus eleifend ac. Vestibulum convallis est a feugiat tincidunt. Donec vitae enim volutpat, tincidunt eros eu, malesuada nibh.
+
+Quisque molestie, magna ornare elementum convallis, erat enim sagittis ipsum, eget porttitor sapien arcu id purus. Donec ut cursus diam. Nulla rutrum nulla et mi fermentum, vel tempus tellus posuere. Proin vitae pharetra nulla, nec ornare ex. Nulla consequat, augue a accumsan euismod, turpis leo ornare ligula, a pulvinar enim dolor ut augue. Quisque volutpat, lectus a varius mollis, nisl eros feugiat sem, at egestas lacus justo eu elit. Vestibulum scelerisque mauris est, sagittis interdum nunc accumsan sit amet. Maecenas aliquet ex ut lacus ornare, eu sagittis nibh imperdiet. Duis ultrices nisi velit, sed sodales risus sollicitudin et. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Etiam a accumsan augue, vitae pulvinar nulla. Pellentesque euismod sodales magna, nec luctus eros mattis eget. Sed lacinia suscipit lectus, eget consectetur dui pellentesque sed. Nullam nec mattis tellus.
+
+Aliquam erat volutpat. Praesent lobortis massa porttitor eros tincidunt, nec consequat diam pharetra. Duis efficitur non lorem sed mattis. Suspendisse justo nunc, pulvinar eu porttitor at, facilisis id eros. Suspendisse potenti. Cras molestie aliquet orci ut fermentum. In tempus aliquet eros nec suscipit. Suspendisse in mauris ut lectus ultrices blandit sit amet vitae est. Nam magna massa, porttitor ut semper id, feugiat vel quam. Suspendisse dignissim posuere scelerisque. Donec scelerisque lorem efficitur suscipit suscipit. Nunc luctus ligula et scelerisque lacinia.
+
+Suspendisse potenti. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Sed ultrices, sem in venenatis scelerisque, tellus ipsum porttitor urna, et iaculis lectus odio ac nisi. Integer luctus dui urna, at sollicitudin elit dapibus eu. Praesent nibh ante, porttitor a ante in, ullamcorper pretium felis. Aliquam vel tortor imperdiet, imperdiet lorem et, cursus mi. Proin tempus velit est, ut hendrerit metus gravida sed. Sed nibh sapien, faucibus quis ipsum in, scelerisque lacinia elit. In nec magna eu magna laoreet rhoncus. Donec vitae rutrum mauris. Integer urna felis, consequat at rhoncus vitae, auctor quis elit. Duis a pulvinar sem, nec gravida nisl. Nam non dapibus purus. Praesent vestibulum turpis nec erat porttitor, a scelerisque purus tincidunt.
+
+Nam fringilla leo nisi, nec placerat nisl luctus eget. Aenean malesuada nunc porta sapien sodales convallis. Suspendisse ut massa tempor, ullamcorper mi ut, faucibus turpis. Vivamus at sagittis metus. Donec varius ac mi eget sodales. Nulla feugiat, nulla eu fringilla fringilla, nunc lorem sollicitudin quam, vitae lacinia velit lorem eu orci. Mauris leo urna, pellentesque ac posuere non, pellentesque sit amet quam.
+
+Vestibulum porta diam urna, a aliquet nibh vestibulum et. Proin interdum bibendum nisl sed rhoncus. Sed vel diam hendrerit, faucibus ante et, hendrerit diam. Nunc dolor augue, mattis non dolor vel, luctus sodales neque. Cras malesuada fermentum dolor eu lobortis. Integer dapibus volutpat consequat. Maecenas posuere feugiat nunc. Donec vel mollis elit, volutpat consequat enim. Nulla id nisi finibus orci imperdiet elementum. Phasellus ultrices, elit vitae consequat rutrum, nisl est congue massa, quis condimentum justo nisi vitae turpis. Maecenas aliquet risus sit amet accumsan elementum. Proin non finibus elit, sit amet lobortis augue.
+
+Morbi pretium pulvinar sem vel sollicitudin. Proin imperdiet fringilla leo, non pellentesque lacus gravida nec. Vivamus ullamcorper consectetur ligula eu consectetur. Curabitur sit amet tempus purus. Curabitur quam quam, tincidunt eu tempus vel, volutpat at ipsum. Maecenas lobortis elit ac justo interdum, sit amet mattis ligula mollis. Sed posuere ligula et felis convallis tempor. Aliquam nec mollis velit. Donec varius sit amet erat at imperdiet. Nulla ipsum justo, tempor non sollicitudin gravida, dignissim vel orci. In hac habitasse platea dictumst. Cras cursus tellus id arcu aliquet accumsan. Phasellus ac erat dui.
+
+Duis mollis metus at mi luctus aliquam. Duis varius eget erat ac porttitor. Phasellus lobortis sagittis lacinia. Etiam sagittis eget erat in pulvinar. Phasellus sodales risus nec vulputate accumsan. Cras sit amet pellentesque dui. Praesent consequat felis mi, at vulputate diam convallis a. Donec hendrerit nibh vel justo consequat dictum. In euismod, dui sit amet malesuada suscipit, mauris ex rhoncus eros, sed ornare arcu nunc eu urna. Pellentesque eget erat augue. Integer rutrum mauris sem, nec sodales nulla cursus vel. Vivamus porta, urna vel varius vulputate, nulla arcu malesuada dui, a ultrices magna ante sed nibh.
+
+Morbi ultricies aliquam lorem id bibendum. Donec sit amet nunc vitae massa gravida eleifend hendrerit vel libero. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nulla vestibulum tempus condimentum. Aliquam dolor ipsum, condimentum in sapien et, tempor iaculis nulla. Aenean non pharetra augue. Maecenas mattis dignissim maximus. Fusce elementum tincidunt massa sit amet lobortis. Phasellus nec pharetra dui, et malesuada ante. Nullam commodo pretium tellus. Praesent sollicitudin, enim eget imperdiet scelerisque, odio felis vulputate dolor, eget auctor neque tellus ac lorem.
+
+In consectetur augue et sapien feugiat varius. Nam tortor mi, consectetur ac felis non, elementum venenatis augue. Suspendisse ut tellus in est sagittis cursus. Quisque faucibus, neque sit amet semper congue, nibh augue finibus odio, vitae interdum dolor arcu eget arcu. Curabitur dictum risus massa, non tincidunt urna molestie non. Maecenas eu quam purus. Donec vulputate, dui eu accumsan blandit, mauris tortor tristique mi, sed blandit leo quam id quam. Ut venenatis sagittis malesuada. Integer non auctor orci. Duis consectetur massa felis. Fusce euismod est sit amet bibendum finibus. Vestibulum dolor ex, tempor at elit in, iaculis cursus dui. Nunc sed neque ac risus rutrum tempus sit amet at ante. In hac habitasse platea dictumst.
+
+Donec rutrum, velit nec viverra tincidunt, est velit viverra neque, quis auctor leo ex at lectus. Morbi eget purus nisi. Aliquam lacus dui, interdum vitae elit at, venenatis dignissim est. Duis ac mollis lorem. Vivamus a vestibulum quam. Maecenas non metus dolor. Praesent tortor nunc, tristique at nisl molestie, vulputate eleifend diam. Integer ultrices lacus odio, vel imperdiet enim accumsan id. Sed ligula tortor, interdum eu velit eget, pharetra pulvinar magna. Sed non lacus in eros tincidunt sagittis ac vel justo. Donec vitae leo sagittis, accumsan ante sit amet, accumsan odio. Ut volutpat ultricies tortor. Vestibulum tempus purus et est tristique sagittis quis vitae turpis.
+
+Nam iaculis neque lacus, eget euismod turpis blandit eget. In hac habitasse platea dictumst. Phasellus justo neque, scelerisque sit amet risus ut, pretium commodo nisl. Phasellus auctor sapien sed ex bibendum fermentum. Proin maximus odio a ante ornare, a feugiat lorem egestas. Etiam efficitur tortor a ante tincidunt interdum. Nullam non est ac massa congue efficitur sit amet nec eros. Nullam at ipsum vel mauris tincidunt efficitur. Duis pulvinar nisl elit, id auctor risus laoreet ac. Sed nunc mauris, tristique id leo ut, condimentum congue nunc. Sed ultricies, mauris et convallis faucibus, justo ex faucibus est, at lobortis purus justo non arcu. Integer vel facilisis elit, dapibus imperdiet mauris.
+
+Pellentesque non mattis turpis, eget bibendum velit. Fusce sollicitudin ante ac tincidunt rhoncus. Praesent porta scelerisque consequat. Donec eleifend faucibus sollicitudin. Quisque vitae purus eget tortor tempor ultrices. Maecenas mauris diam, semper vitae est non, imperdiet tempor magna. Duis elit lacus, auctor vestibulum enim eget, rhoncus porttitor tortor.
+
+Donec non rhoncus nibh. Cras dapibus justo vitae nunc accumsan, id congue erat egestas. Aenean at ante ante. Duis eleifend imperdiet dREADALL
diff --git a/test/image/mysql.sql b/test/image/mysql.sql
new file mode 100644
index 000000000..51554b98d
--- /dev/null
+++ b/test/image/mysql.sql
@@ -0,0 +1,23 @@
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+SHOW databases;
+USE mysql;
+
+CREATE TABLE foo (id int);
+INSERT INTO foo VALUES(1);
+SELECT * FROM foo;
+DROP TABLE foo;
+
+shutdown;
diff --git a/test/image/ruby.rb b/test/image/ruby.rb
new file mode 100644
index 000000000..aced49c6d
--- /dev/null
+++ b/test/image/ruby.rb
@@ -0,0 +1,23 @@
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+require 'sinatra'
+
+set :bind, "0.0.0.0"
+set :port, 8080
+
+get '/' do
+ 'Hello World'
+end
+
diff --git a/test/image/ruby.sh b/test/image/ruby.sh
new file mode 100755
index 000000000..ebe8d5b0e
--- /dev/null
+++ b/test/image/ruby.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+gem install sinatra
+ruby /src/ruby.rb
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
new file mode 100644
index 000000000..3e29ca90d
--- /dev/null
+++ b/test/iptables/BUILD
@@ -0,0 +1,36 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "iptables",
+ testonly = 1,
+ srcs = [
+ "filter_input.go",
+ "filter_output.go",
+ "iptables.go",
+ "iptables_util.go",
+ "nat.go",
+ ],
+ visibility = ["//test/iptables:__subpackages__"],
+ deps = [
+ "//pkg/test/testutil",
+ ],
+)
+
+go_test(
+ name = "iptables_test",
+ srcs = [
+ "iptables_test.go",
+ ],
+ data = ["//test/iptables/runner"],
+ library = ":iptables",
+ tags = [
+ "local",
+ "manual",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
diff --git a/test/iptables/README.md b/test/iptables/README.md
new file mode 100644
index 000000000..b9f44bd40
--- /dev/null
+++ b/test/iptables/README.md
@@ -0,0 +1,54 @@
+# iptables Tests
+
+iptables tests are run via `scripts/iptables_test.sh`.
+
+iptables requires raw socket support, so you must pass the `--net-raw=true`
+flag to runsc via its `runtimeArgs` in `/etc/docker/daemon.json` in order to
+use it.
+
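+For reference, a typical runtime entry might look like the following sketch
+(the `runsc` path is an example; adjust it for your installation):
+
+```json
+{
+  "runtimes": {
+    "runsc": {
+      "path": "/usr/local/bin/runsc",
+      "runtimeArgs": ["--net-raw=true"]
+    }
+  }
+}
+```
+
+Restart the Docker daemon after editing the file.
+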
+## Test Structure
+
+Each test implements `TestCase`, providing (1) a function to run inside the
+container and (2) a function to run locally. Those processes are given each
+other's IP addresses. The test succeeds when both functions succeed.
+
+The function inside the container (`ContainerAction`) typically sets some
+iptables rules and then tries to send or receive packets. The local function
+(`LocalAction`) will typically just send or receive packets.
+
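+As a sketch, a hypothetical test case might look like the following (it reuses
+the `filterTable`, `listenUDP`, and `sendUDPLoop` helpers and the `acceptPort`
+and `sendloopDuration` constants defined in this package):
+
+```go
+// FilterInputExample is a hypothetical example (not part of this package's
+// test suite): it ACCEPTs UDP traffic and expects host packets to arrive.
+type FilterInputExample struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputExample) Name() string { return "FilterInputExample" }
+
+// ContainerAction implements TestCase.ContainerAction. It runs inside the
+// container: install an ACCEPT rule, then listen for UDP packets.
+func (FilterInputExample) ContainerAction(ip net.IP) error {
+	if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "ACCEPT"); err != nil {
+		return err
+	}
+	return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction. It runs on the host: send
+// UDP packets to the container for the duration of the test.
+func (FilterInputExample) LocalAction(ip net.IP) error {
+	return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+func init() {
+	RegisterTestCase(FilterInputExample{})
+}
+```
+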
+### Adding Tests
+
+1) Add your test to the `iptables` package.
+
+2) Register the test in an `init` function via `RegisterTestCase` (see
+`filter_input.go` as an example).
+
+3) Add it to `iptables_test.go` (see the other tests in that file).
+
+Your test is now runnable with bazel!
+
+## Run individual tests
+
+Build and install `runsc`. Re-run this when you modify gVisor:
+
+```bash
+$ bazel build //runsc && sudo cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc $(which runsc)
+```
+
+Build the testing Docker container. Re-run this when you modify the test code in
+this directory:
+
+```bash
+$ make load-iptables
+```
+
+Run an individual test via:
+
+```bash
+$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME>
+```
+
+To run an individual test with `runc`:
+
+```bash
+$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> --test_arg=--runtime=runc
+```
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
new file mode 100644
index 000000000..068f228bd
--- /dev/null
+++ b/test/iptables/filter_input.go
@@ -0,0 +1,729 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "time"
+)
+
+const (
+ dropPort = 2401
+ acceptPort = 2402
+ sendloopDuration = 2 * time.Second
+ network = "udp4"
+ chainName = "foochain"
+)
+
+func init() {
+ RegisterTestCase(FilterInputDropAll{})
+ RegisterTestCase(FilterInputDropDifferentUDPPort{})
+ RegisterTestCase(FilterInputDropOnlyUDP{})
+ RegisterTestCase(FilterInputDropTCPDestPort{})
+ RegisterTestCase(FilterInputDropTCPSrcPort{})
+ RegisterTestCase(FilterInputDropUDPPort{})
+ RegisterTestCase(FilterInputDropUDP{})
+ RegisterTestCase(FilterInputCreateUserChain{})
+ RegisterTestCase(FilterInputDefaultPolicyAccept{})
+ RegisterTestCase(FilterInputDefaultPolicyDrop{})
+ RegisterTestCase(FilterInputReturnUnderflow{})
+ RegisterTestCase(FilterInputSerializeJump{})
+ RegisterTestCase(FilterInputJumpBasic{})
+ RegisterTestCase(FilterInputJumpReturn{})
+ RegisterTestCase(FilterInputJumpReturnDrop{})
+ RegisterTestCase(FilterInputJumpBuiltin{})
+ RegisterTestCase(FilterInputJumpTwice{})
+ RegisterTestCase(FilterInputDestination{})
+ RegisterTestCase(FilterInputInvertDestination{})
+ RegisterTestCase(FilterInputSource{})
+ RegisterTestCase(FilterInputInvertSource{})
+}
+
+// FilterInputDropUDP tests that we can drop UDP traffic.
+type FilterInputDropUDP struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropUDP) Name() string {
+ return "FilterInputDropUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropUDP) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropUDP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic.
+type FilterInputDropOnlyUDP struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropOnlyUDP) Name() string {
+ return "FilterInputDropOnlyUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for a TCP connection, which should be allowed.
+ if err := listenTCP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("failed to establish a connection %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error {
+ // Try to establish a TCP connection with the container, which should
+ // succeed.
+ return connectTCP(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputDropUDPPort tests that we can drop UDP traffic by port.
+type FilterInputDropUDPPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropUDPPort) Name() string {
+ return "FilterInputDropUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropUDPPort) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port
+// doesn't drop packets on other ports.
+type FilterInputDropDifferentUDPPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropDifferentUDPPort) Name() string {
+ return "FilterInputDropDifferentUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on another port.
+ if err := listenUDP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputDropTCPDestPort tests that connections are not accepted on
+// specified destination ports.
+type FilterInputDropTCPDestPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropTCPDestPort) Name() string {
+ return "FilterInputDropTCPDestPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on drop port.
+ if err := listenTCP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error {
+ // Ensure we cannot connect to the container.
+ for start := time.Now(); time.Since(start) < sendloopDuration; {
+ if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort)
+ }
+ }
+
+ return nil
+}
+
+// FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports.
+type FilterInputDropTCPSrcPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropTCPSrcPort) Name() string {
+ return "FilterInputDropTCPSrcPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
+ // Drop anything from an ephemeral port.
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error {
+ // Ensure we cannot connect to the container.
+ for start := time.Now(); time.Since(start) < sendloopDuration; {
+ if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort)
+ }
+ }
+
+ return nil
+}
+
+// FilterInputDropAll tests that we can drop all traffic to the INPUT chain.
+type FilterInputDropAll struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropAll) Name() string {
+ return "FilterInputDropAll"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropAll) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort; the blanket DROP rule should drop them.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should have been dropped, but got a packet")
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropAll) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// FilterInputMultiUDPRules verifies that multiple UDP rules are applied
+// correctly. This has the added benefit of testing whether we're serializing
+// rules correctly -- if we do it incorrectly, the iptables tool will
+// misunderstand and save the wrong tables.
+type FilterInputMultiUDPRules struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputMultiUDPRules) Name() string {
+ return "FilterInputMultiUDPRules"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"},
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"},
+ {"-L"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputMultiUDPRules) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be
+// specified.
+type FilterInputRequireProtocolUDP struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputRequireProtocolUDP) Name() string {
+ return "FilterInputRequireProtocolUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil {
+ return errors.New("expected iptables to fail with out \"-p udp\", but succeeded")
+ }
+ return nil
+}
+
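+// LocalAction implements TestCase.LocalAction.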
+func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputCreateUserChain tests chain creation.
+type FilterInputCreateUserChain struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputCreateUserChain) Name() string {
+ return "FilterInputCreateUserChain"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ // Create a chain.
+ {"-N", chainName},
+ // Add a simple rule to the chain.
+ {"-A", chainName, "-j", "DROP"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputCreateUserChain) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputDefaultPolicyAccept tests the default ACCEPT policy.
+type FilterInputDefaultPolicyAccept struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDefaultPolicyAccept) Name() string {
+ return "FilterInputDefaultPolicyAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error {
+ // Set the default policy to accept, then receive a packet.
+ if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputDefaultPolicyDrop tests the default DROP policy.
+type FilterInputDefaultPolicyDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDefaultPolicyDrop) Name() string {
+ return "FilterInputDefaultPolicyDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error {
+ if err := filterTable("-P", "INPUT", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes
+// the underflow rule (i.e. default policy) to be executed.
+type FilterInputReturnUnderflow struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputReturnUnderflow) Name() string {
+ return "FilterInputReturnUnderflow"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error {
+ // Add a RETURN rule followed by an unconditional accept, and set the
+ // default policy to DROP.
+ rules := [][]string{
+ {"-A", "INPUT", "-j", "RETURN"},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-P", "INPUT", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // We should receive packets, as the RETURN rule will trigger the default
+ // ACCEPT policy.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputSerializeJump verifies that we can serialize jumps.
+type FilterInputSerializeJump struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputSerializeJump) Name() string {
+ return "FilterInputSerializeJump"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputSerializeJump) ContainerAction(ip net.IP) error {
+ // Write a JUMP rule, then serialize it with `-L`.
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-L"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputSerializeJump) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpBasic jumps to a chain and executes a rule there.
+type FilterInputJumpBasic struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBasic) Name() string {
+ return "FilterInputJumpBasic"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBasic) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBasic) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputJumpReturn jumps, returns, and executes a rule.
+type FilterInputJumpReturn struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturn) Name() string {
+ return "FilterInputJumpReturn"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturn) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-P", "INPUT", "ACCEPT"},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "RETURN"},
+ {"-A", chainName, "-j", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturn) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets.
+type FilterInputJumpReturnDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturnDrop) Name() string {
+ return "FilterInputJumpReturnDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-A", chainName, "-j", "RETURN"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// FilterInputJumpBuiltin verifies that jumping to a built-in chain is illegal.
+type FilterInputJumpBuiltin struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBuiltin) Name() string {
+ return "FilterInputJumpBuiltin"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil {
+ return fmt.Errorf("iptables should be unable to jump to a built-in chain")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpTwice jumps twice, then returns twice and executes a rule.
+type FilterInputJumpTwice struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpTwice) Name() string {
+ return "FilterInputJumpTwice"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpTwice) ContainerAction(ip net.IP) error {
+ const chainName2 = chainName + "2"
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-N", chainName2},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", chainName2},
+ {"-A", "INPUT", "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // UDP packets should jump and return twice, eventually hitting the
+ // ACCEPT rule.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpTwice) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputDestination verifies that we can filter packets via `-d
+// <ipaddr>`.
+type FilterInputDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDestination) Name() string {
+ return "FilterInputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDestination) ContainerAction(ip net.IP) error {
+ addrs, err := localAddrs(false)
+ if err != nil {
+ return err
+ }
+
+ // Make INPUT's default action DROP, then ACCEPT all packets bound for
+ // this machine.
+ rules := [][]string{{"-P", "INPUT", "DROP"}}
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"})
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDestination) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputInvertDestination verifies that we can filter packets via `! -d
+// <ipaddr>`.
+type FilterInputInvertDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputInvertDestination) Name() string {
+ return "FilterInputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInvertDestination) ContainerAction(ip net.IP) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets not bound
+ // for 127.0.0.1.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "!", "-d", localIP, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInvertDestination) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputSource verifies that we can filter packets via `-s
+// <ipaddr>`.
+type FilterInputSource struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputSource) Name() string {
+ return "FilterInputSource"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputSource) ContainerAction(ip net.IP) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets from this
+ // machine.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputSource) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputInvertSource verifies that we can filter packets via `! -s
+// <ipaddr>`.
+type FilterInputInvertSource struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputInvertSource) Name() string {
+ return "FilterInputInvertSource"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInvertSource) ContainerAction(ip net.IP) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets not
+ // originating from 127.0.0.1.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "!", "-s", localIP, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInvertSource) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
new file mode 100644
index 000000000..ba0d6fc29
--- /dev/null
+++ b/test/iptables/filter_output.go
@@ -0,0 +1,607 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "fmt"
+ "net"
+)
+
+func init() {
+ RegisterTestCase(FilterOutputDropTCPDestPort{})
+ RegisterTestCase(FilterOutputDropTCPSrcPort{})
+ RegisterTestCase(FilterOutputDestination{})
+ RegisterTestCase(FilterOutputInvertDestination{})
+ RegisterTestCase(FilterOutputAcceptTCPOwner{})
+ RegisterTestCase(FilterOutputDropTCPOwner{})
+ RegisterTestCase(FilterOutputAcceptUDPOwner{})
+ RegisterTestCase(FilterOutputDropUDPOwner{})
+ RegisterTestCase(FilterOutputOwnerFail{})
+ RegisterTestCase(FilterOutputAcceptGIDOwner{})
+ RegisterTestCase(FilterOutputDropGIDOwner{})
+ RegisterTestCase(FilterOutputInvertGIDOwner{})
+ RegisterTestCase(FilterOutputInvertUIDOwner{})
+ RegisterTestCase(FilterOutputInvertUIDAndGIDOwner{})
+ RegisterTestCase(FilterOutputInterfaceAccept{})
+ RegisterTestCase(FilterOutputInterfaceDrop{})
+ RegisterTestCase(FilterOutputInterface{})
+ RegisterTestCase(FilterOutputInterfaceBeginsWith{})
+ RegisterTestCase(FilterOutputInterfaceInvertDrop{})
+ RegisterTestCase(FilterOutputInterfaceInvertAccept{})
+}
+
+// FilterOutputDropTCPDestPort tests that connections to a specified
+// destination port are dropped.
+type FilterOutputDropTCPDestPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPDestPort) Name() string {
+ return "FilterOutputDropTCPDestPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDropTCPSrcPort tests that connections from a specified source
+// port are dropped.
+type FilterOutputDropTCPSrcPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPSrcPort) Name() string {
+ return "FilterOutputDropTCPSrcPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on drop port.
+ if err := listenTCP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// FilterOutputAcceptTCPOwner tests that TCP connections from the root uid are accepted.
+type FilterOutputAcceptTCPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptTCPOwner) Name() string {
+ return "FilterOutputAcceptTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error {
+ return connectTCP(ip, acceptPort, sendloopDuration)
+}
+
+// FilterOutputDropTCPOwner tests that TCP connections from the root uid are dropped.
+type FilterOutputDropTCPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPOwner) Name() string {
+ return "FilterOutputDropTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputAcceptUDPOwner tests that UDP packets from the root uid are accepted.
+type FilterOutputAcceptUDPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptUDPOwner) Name() string {
+ return "FilterOutputAcceptUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on acceptPort.
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error {
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputDropUDPOwner tests that UDP packets from the root uid are dropped.
+type FilterOutputDropUDPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropUDPOwner) Name() string {
+ return "FilterOutputDropUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on dropPort.
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error {
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received")
+ }
+
+ return nil
+}
+
+// FilterOutputOwnerFail tests that the owner match fails when neither a uid
+// nor a gid option is specified.
+type FilterOutputOwnerFail struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputOwnerFail) Name() string {
+ return "FilterOutputOwnerFail"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil {
+ return fmt.Errorf("Invalid argument")
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputOwnerFail) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterOutputAcceptGIDOwner tests that TCP connections from the root gid are accepted.
+type FilterOutputAcceptGIDOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptGIDOwner) Name() string {
+ return "FilterOutputAcceptGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptGIDOwner) LocalAction(ip net.IP) error {
+ return connectTCP(ip, acceptPort, sendloopDuration)
+}
+
+// FilterOutputDropGIDOwner tests that TCP connections from the root gid are dropped.
+type FilterOutputDropGIDOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropGIDOwner) Name() string {
+ return "FilterOutputDropGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropGIDOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInvertGIDOwner tests that TCP connections from the root gid are
+// dropped when only "! --gid-owner root" traffic is accepted.
+type FilterOutputInvertGIDOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertGIDOwner) Name() string {
+ return "FilterOutputInvertGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"},
+ {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInvertUIDOwner tests that TCP connections from the root uid are
+// accepted when only "! --uid-owner root" traffic is dropped.
+type FilterOutputInvertUIDOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertUIDOwner) Name() string {
+ return "FilterOutputInvertUIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"},
+ {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertUIDOwner) LocalAction(ip net.IP) error {
+ return connectTCP(ip, acceptPort, sendloopDuration)
+}
+
+// FilterOutputInvertUIDAndGIDOwner tests that TCP connections from the root
+// uid and gid are dropped when only traffic matching neither owner is
+// accepted.
+type FilterOutputInvertUIDAndGIDOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertUIDAndGIDOwner) Name() string {
+ return "FilterOutputInvertUIDAndGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"},
+ {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDestination tests that we can selectively allow packets to
+// certain destinations.
+type FilterOutputDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDestination) Name() string {
+ return "FilterOutputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDestination) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDestination) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInvertDestination tests that we can selectively allow packets
+// not headed for a particular destination.
+type FilterOutputInvertDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertDestination) Name() string {
+ return "FilterOutputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "!", "-d", localIP, "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertDestination) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInterfaceAccept tests that packets are sent via interface
+// matching the iptables rule.
+type FilterOutputInterfaceAccept struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceAccept) Name() string {
+ return "FilterOutputInterfaceAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceAccept) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInterfaceDrop tests that packets are not sent via interface
+// matching the iptables rule.
+type FilterOutputInterfaceDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceDrop) Name() string {
+ return "FilterOutputInterfaceDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error {
+ if err := listenUDP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterface tests that packets are sent via an interface that
+// does not match the interface name in the iptables rule.
+type FilterOutputInterface struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterface) Name() string {
+ return "FilterOutputInterface"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterface) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterface) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInterfaceBeginsWith tests that packets are not sent via
+// interfaces whose names begin with the given prefix.
+type FilterOutputInterfaceBeginsWith struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceBeginsWith) Name() string {
+ return "FilterOutputInterfaceBeginsWith"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error {
+ if err := listenUDP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterfaceInvertDrop tests that packets are dropped when sent
+// via any interface that does not match the inverted interface name.
+type FilterOutputInterfaceInvertDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceInvertDrop) Name() string {
+ return "FilterOutputInterfaceInvertDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterfaceInvertAccept tests that packets are accepted when sent
+// via any interface that does not match the inverted interface name.
+type FilterOutputInterfaceInvertAccept struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceInvertAccept) Name() string {
+ return "FilterOutputInterfaceInvertAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP) error {
+ return connectTCP(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go
new file mode 100644
index 000000000..16cb4f4da
--- /dev/null
+++ b/test/iptables/iptables.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package iptables contains a set of iptables tests implemented as TestCases.
+package iptables
+
+import (
+ "fmt"
+ "net"
+ "time"
+)
+
+// IPExchangePort is the port the container listens on to receive the IP
+// address of the local process.
+const IPExchangePort = 2349
+
+// TerminalStatement is the last statement in the test runner.
+const TerminalStatement = "Finished!"
+
+// TestTimeout is the timeout used for all tests.
+const TestTimeout = 10 * time.Minute
+
+// A TestCase contains one action to run in the container and one to run
+// locally. The actions run concurrently, and each must succeed for the test
+// to pass.
+type TestCase interface {
+ // Name returns the name of the test.
+ Name() string
+
+ // ContainerAction runs inside the container. It receives the IP of the
+ // local process.
+ ContainerAction(ip net.IP) error
+
+ // LocalAction runs locally. It receives the IP of the container.
+ LocalAction(ip net.IP) error
+}
+
+// Tests maps test names to TestCase.
+//
+// New TestCases are added by calling RegisterTestCase in an init function.
+var Tests = map[string]TestCase{}
+
+// RegisterTestCase registers tc so it can be run.
+func RegisterTestCase(tc TestCase) {
+ if _, ok := Tests[tc.Name()]; ok {
+ panic(fmt.Sprintf("TestCase %s already registered.", tc.Name()))
+ }
+ Tests[tc.Name()] = tc
+}
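+
+// As an illustrative sketch only (ExampleNoOp is hypothetical, not a
+// registered test), a new case would implement the interface and register
+// itself like so:
+//
+//	type ExampleNoOp struct{}
+//
+//	func (ExampleNoOp) Name() string                 { return "ExampleNoOp" }
+//	func (ExampleNoOp) ContainerAction(net.IP) error { return nil }
+//	func (ExampleNoOp) LocalAction(net.IP) error     { return nil }
+//
+//	func init() { RegisterTestCase(ExampleNoOp{}) }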
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
new file mode 100644
index 000000000..f5ac79370
--- /dev/null
+++ b/test/iptables/iptables_test.go
@@ -0,0 +1,345 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// singleTest runs a TestCase. Each test follows a pattern:
+// - Create a container.
+// - Get the container's IP.
+// - Send the container our IP.
+// - Start a new goroutine running the local action of the test.
+// - Wait for both the container and local actions to finish.
+//
+// Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or
+// to stderr.
+func singleTest(t *testing.T, test TestCase) {
+ if _, ok := Tests[test.Name()]; !ok {
+ t.Fatalf("no test found with name %q. Has it been registered?", test.Name())
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Create and start the container.
+ opts := dockerutil.RunOpts{
+ Image: "iptables",
+ CapAdd: []string{"NET_ADMIN"},
+ }
+ d.CopyFiles(&opts, "/runner", "test/iptables/runner/runner")
+ if err := d.Spawn(ctx, opts, "/runner/runner", "-name", test.Name()); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Get the container IP.
+ ip, err := d.FindIP(ctx)
+ if err != nil {
+ t.Fatalf("failed to get container IP: %v", err)
+ }
+
+ // Give the container our IP.
+ if err := sendIP(ip); err != nil {
+ t.Fatalf("failed to send IP to container: %v", err)
+ }
+
+ // Run our side of the test.
+ if err := test.LocalAction(ip); err != nil {
+ t.Fatalf("LocalAction failed: %v", err)
+ }
+
+ // Wait for the final statement. This structure has the side effect
+ // that all container logs will appear within the individual test
+ // context.
+ if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil {
+ t.Fatalf("test failed: %v", err)
+ }
+}
+
+func sendIP(ip net.IP) error {
+ contAddr := net.TCPAddr{
+ IP: ip,
+ Port: IPExchangePort,
+ }
+ var conn *net.TCPConn
+ // The container may not be listening when we first connect, so retry
+ // upon error.
+ cb := func() error {
+ c, err := net.DialTCP("tcp4", nil, &contAddr)
+ conn = c
+ return err
+ }
+ if err := testutil.Poll(cb, TestTimeout); err != nil {
+ return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err)
+ }
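+ // The payload is a single arbitrary byte; the container can learn our IP
+ // from the connection's source address, so the data itself is irrelevant.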
+ if _, err := conn.Write([]byte{0}); err != nil {
+ return fmt.Errorf("error writing to container: %v", err)
+ }
+ return nil
+}
+
+func TestFilterInputDropUDP(t *testing.T) {
+ singleTest(t, FilterInputDropUDP{})
+}
+
+func TestFilterInputDropUDPPort(t *testing.T) {
+ singleTest(t, FilterInputDropUDPPort{})
+}
+
+func TestFilterInputDropDifferentUDPPort(t *testing.T) {
+ singleTest(t, FilterInputDropDifferentUDPPort{})
+}
+
+func TestFilterInputDropAll(t *testing.T) {
+ singleTest(t, FilterInputDropAll{})
+}
+
+func TestFilterInputDropOnlyUDP(t *testing.T) {
+ singleTest(t, FilterInputDropOnlyUDP{})
+}
+
+func TestFilterInputDropTCPDestPort(t *testing.T) {
+ singleTest(t, FilterInputDropTCPDestPort{})
+}
+
+func TestFilterInputDropTCPSrcPort(t *testing.T) {
+ singleTest(t, FilterInputDropTCPSrcPort{})
+}
+
+func TestFilterInputCreateUserChain(t *testing.T) {
+ singleTest(t, FilterInputCreateUserChain{})
+}
+
+func TestFilterInputDefaultPolicyAccept(t *testing.T) {
+ singleTest(t, FilterInputDefaultPolicyAccept{})
+}
+
+func TestFilterInputDefaultPolicyDrop(t *testing.T) {
+ singleTest(t, FilterInputDefaultPolicyDrop{})
+}
+
+func TestFilterInputReturnUnderflow(t *testing.T) {
+ singleTest(t, FilterInputReturnUnderflow{})
+}
+
+func TestFilterOutputDropTCPDestPort(t *testing.T) {
+ singleTest(t, FilterOutputDropTCPDestPort{})
+}
+
+func TestFilterOutputDropTCPSrcPort(t *testing.T) {
+ singleTest(t, FilterOutputDropTCPSrcPort{})
+}
+
+func TestFilterOutputAcceptTCPOwner(t *testing.T) {
+ singleTest(t, FilterOutputAcceptTCPOwner{})
+}
+
+func TestFilterOutputDropTCPOwner(t *testing.T) {
+ singleTest(t, FilterOutputDropTCPOwner{})
+}
+
+func TestFilterOutputAcceptUDPOwner(t *testing.T) {
+ singleTest(t, FilterOutputAcceptUDPOwner{})
+}
+
+func TestFilterOutputDropUDPOwner(t *testing.T) {
+ singleTest(t, FilterOutputDropUDPOwner{})
+}
+
+func TestFilterOutputOwnerFail(t *testing.T) {
+ singleTest(t, FilterOutputOwnerFail{})
+}
+
+func TestFilterOutputAcceptGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputAcceptGIDOwner{})
+}
+
+func TestFilterOutputDropGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputDropGIDOwner{})
+}
+
+func TestFilterOutputInvertGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputInvertGIDOwner{})
+}
+
+func TestFilterOutputInvertUIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputInvertUIDOwner{})
+}
+
+func TestFilterOutputInvertUIDAndGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputInvertUIDAndGIDOwner{})
+}
+
+func TestFilterOutputInterfaceAccept(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceAccept{})
+}
+
+func TestFilterOutputInterfaceDrop(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceDrop{})
+}
+
+func TestFilterOutputInterface(t *testing.T) {
+ singleTest(t, FilterOutputInterface{})
+}
+
+func TestFilterOutputInterfaceBeginsWith(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceBeginsWith{})
+}
+
+func TestFilterOutputInterfaceInvertDrop(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceInvertDrop{})
+}
+
+func TestFilterOutputInterfaceInvertAccept(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceInvertAccept{})
+}
+
+func TestJumpSerialize(t *testing.T) {
+ singleTest(t, FilterInputSerializeJump{})
+}
+
+func TestJumpBasic(t *testing.T) {
+ singleTest(t, FilterInputJumpBasic{})
+}
+
+func TestJumpReturn(t *testing.T) {
+ singleTest(t, FilterInputJumpReturn{})
+}
+
+func TestJumpReturnDrop(t *testing.T) {
+ singleTest(t, FilterInputJumpReturnDrop{})
+}
+
+func TestJumpBuiltin(t *testing.T) {
+ singleTest(t, FilterInputJumpBuiltin{})
+}
+
+func TestJumpTwice(t *testing.T) {
+ singleTest(t, FilterInputJumpTwice{})
+}
+
+func TestInputDestination(t *testing.T) {
+ singleTest(t, FilterInputDestination{})
+}
+
+func TestInputInvertDestination(t *testing.T) {
+ singleTest(t, FilterInputInvertDestination{})
+}
+
+func TestOutputDestination(t *testing.T) {
+ singleTest(t, FilterOutputDestination{})
+}
+
+func TestOutputInvertDestination(t *testing.T) {
+ singleTest(t, FilterOutputInvertDestination{})
+}
+
+func TestNATPreRedirectUDPPort(t *testing.T) {
+ singleTest(t, NATPreRedirectUDPPort{})
+}
+
+func TestNATPreRedirectTCPPort(t *testing.T) {
+ singleTest(t, NATPreRedirectTCPPort{})
+}
+
+func TestNATOutRedirectUDPPort(t *testing.T) {
+ singleTest(t, NATOutRedirectUDPPort{})
+}
+
+func TestNATOutRedirectTCPPort(t *testing.T) {
+ singleTest(t, NATOutRedirectTCPPort{})
+}
+
+func TestNATDropUDP(t *testing.T) {
+ singleTest(t, NATDropUDP{})
+}
+
+func TestNATAcceptAll(t *testing.T) {
+ singleTest(t, NATAcceptAll{})
+}
+
+func TestNATOutRedirectIP(t *testing.T) {
+ singleTest(t, NATOutRedirectIP{})
+}
+
+func TestNATOutDontRedirectIP(t *testing.T) {
+ singleTest(t, NATOutDontRedirectIP{})
+}
+
+func TestNATOutRedirectInvert(t *testing.T) {
+ singleTest(t, NATOutRedirectInvert{})
+}
+
+func TestNATPreRedirectIP(t *testing.T) {
+ singleTest(t, NATPreRedirectIP{})
+}
+
+func TestNATPreDontRedirectIP(t *testing.T) {
+ singleTest(t, NATPreDontRedirectIP{})
+}
+
+func TestNATPreRedirectInvert(t *testing.T) {
+ singleTest(t, NATPreRedirectInvert{})
+}
+
+func TestNATRedirectRequiresProtocol(t *testing.T) {
+ singleTest(t, NATRedirectRequiresProtocol{})
+}
+
+func TestNATLoopbackSkipsPrerouting(t *testing.T) {
+ singleTest(t, NATLoopbackSkipsPrerouting{})
+}
+
+func TestInputSource(t *testing.T) {
+ singleTest(t, FilterInputSource{})
+}
+
+func TestInputInvertSource(t *testing.T) {
+ singleTest(t, FilterInputInvertSource{})
+}
+
+func TestFilterAddrs(t *testing.T) {
+ tcs := []struct {
+ ipv6 bool
+ addrs []string
+ want []string
+ }{
+ {
+ ipv6: false,
+ addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"},
+ want: []string{"192.168.0.1", "192.168.0.2"},
+ },
+ {
+ ipv6: true,
+ addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"},
+ want: []string{"::1", "::2"},
+ },
+ }
+
+ for _, tc := range tcs {
+ if got := filterAddrs(tc.addrs, tc.ipv6); !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("%v with IPv6 %t: got %v, but wanted %v", tc.addrs, tc.ipv6, got, tc.want)
+ }
+ }
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
new file mode 100644
index 000000000..d4bc55b24
--- /dev/null
+++ b/test/iptables/iptables_util.go
@@ -0,0 +1,201 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "fmt"
+ "net"
+ "os/exec"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+const iptablesBinary = "iptables"
+const localIP = "127.0.0.1"
+
+// filterTable calls `iptables -t filter` with the given args.
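+// For example, filterTable("-A", "INPUT", "-j", "DROP") runs
+// `iptables -t filter -A INPUT -j DROP`.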
+func filterTable(args ...string) error {
+ return tableCmd("filter", args)
+}
+
+// natTable calls `iptables -t nat` with the given args.
+func natTable(args ...string) error {
+ return tableCmd("nat", args)
+}
+
+func tableCmd(table string, args []string) error {
+ args = append([]string{"-t", table}, args...)
+ cmd := exec.Command(iptablesBinary, args...)
+ if out, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out))
+ }
+ return nil
+}
+
+// filterTableRules is like filterTable, but runs multiple iptables commands.
+func filterTableRules(argsList [][]string) error {
+ return tableRules("filter", argsList)
+}
+
+// natTableRules is like natTable, but runs multiple iptables commands.
+func natTableRules(argsList [][]string) error {
+ return tableRules("nat", argsList)
+}
+
+func tableRules(table string, argsList [][]string) error {
+ for _, args := range argsList {
+ if err := tableCmd(table, args); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
+// the first read on that port.
+func listenUDP(port int, timeout time.Duration) error {
+ localAddr := net.UDPAddr{
+ Port: port,
+ }
+ conn, err := net.ListenUDP(network, &localAddr)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+ conn.SetDeadline(time.Now().Add(timeout))
+ _, err = conn.Read([]byte{0})
+ return err
+}
+
+// sendUDPLoop repeatedly sends one-byte UDP packets to the given IP and port
+// for the given duration.
+func sendUDPLoop(ip net.IP, port int, duration time.Duration) error {
+ // Send packets until the given duration elapses.
+ remote := net.UDPAddr{
+ IP: ip,
+ Port: port,
+ }
+ conn, err := net.DialUDP(network, nil, &remote)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ to := time.After(duration)
+ for timedOut := false; !timedOut; {
+ // This may return an error (connection refused) if the remote
+ // hasn't started listening yet or they're dropping our
+ // packets. So we ignore Write errors and depend on the remote
+ // to report a failure if it doesn't get a packet it needs.
+ conn.Write([]byte{0})
+ select {
+ case <-to:
+ timedOut = true
+ default:
+ time.Sleep(200 * time.Millisecond)
+ }
+ }
+
+ return nil
+}
+
+// listenTCP listens for connections on a TCP port.
+func listenTCP(port int, timeout time.Duration) error {
+ localAddr := net.TCPAddr{
+ Port: port,
+ }
+
+ // Start listening on the port.
+ lConn, err := net.ListenTCP("tcp4", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer lConn.Close()
+
+ // Accept connections on port.
+ lConn.SetDeadline(time.Now().Add(timeout))
+ conn, err := lConn.AcceptTCP()
+ if err != nil {
+ return err
+ }
+ conn.Close()
+ return nil
+}
+
+// connectTCP connects to the given IP and port from an ephemeral local address.
+func connectTCP(ip net.IP, port int, timeout time.Duration) error {
+ contAddr := net.TCPAddr{
+ IP: ip,
+ Port: port,
+ }
+ // The container may not be listening when we first connect, so retry
+ // upon error.
+ callback := func() error {
+ conn, err := net.DialTimeout("tcp", contAddr.String(), timeout)
+ if conn != nil {
+ conn.Close()
+ }
+ return err
+ }
+ if err := testutil.Poll(callback, timeout); err != nil {
+ return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err)
+ }
+
+ return nil
+}
+
+// localAddrs returns a list of local network interface addresses. When ipv6 is
+// true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are
+// returned.
+func localAddrs(ipv6 bool) ([]string, error) {
+ addrs, err := net.InterfaceAddrs()
+ if err != nil {
+ return nil, err
+ }
+ addrStrs := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ addrStrs = append(addrStrs, addr.String())
+ }
+ return filterAddrs(addrStrs, ipv6), nil
+}
+
+func filterAddrs(addrs []string, ipv6 bool) []string {
+ addrStrs := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ // Add only IPv4 or only IPv6 addresses.
+ parts := strings.Split(addr, "/")
+ if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 {
+ addrStrs = append(addrStrs, parts[0])
+ }
+ }
+ return addrStrs
+}
+
+// getInterfaceName returns the name of the first non-loopback interface and
+// whether one was found.
+func getInterfaceName() (string, bool) {
+ var ifname string
+ if interfaces, err := net.Interfaces(); err == nil {
+ for _, intf := range interfaces {
+ if intf.Name != "lo" {
+ ifname = intf.Name
+ break
+ }
+ }
+ }
+
+ return ifname, ifname != ""
+}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
new file mode 100644
index 000000000..8562b0820
--- /dev/null
+++ b/test/iptables/nat.go
@@ -0,0 +1,439 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "time"
+)
+
+const (
+ redirectPort = 42
+)
+
+func init() {
+ RegisterTestCase(NATPreRedirectUDPPort{})
+ RegisterTestCase(NATPreRedirectTCPPort{})
+ RegisterTestCase(NATOutRedirectUDPPort{})
+ RegisterTestCase(NATOutRedirectTCPPort{})
+ RegisterTestCase(NATDropUDP{})
+ RegisterTestCase(NATAcceptAll{})
+ RegisterTestCase(NATPreRedirectIP{})
+ RegisterTestCase(NATPreDontRedirectIP{})
+ RegisterTestCase(NATPreRedirectInvert{})
+ RegisterTestCase(NATOutRedirectIP{})
+ RegisterTestCase(NATOutDontRedirectIP{})
+ RegisterTestCase(NATOutRedirectInvert{})
+ RegisterTestCase(NATRedirectRequiresProtocol{})
+ RegisterTestCase(NATLoopbackSkipsPrerouting{})
+}
+
+// NATPreRedirectUDPPort tests that packets are redirected to a different port.
+type NATPreRedirectUDPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectUDPPort) Name() string {
+ return "NATPreRedirectUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := listenUDP(redirectPort, sendloopDuration); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATPreRedirectTCPPort tests that connections are redirected on specified ports.
+type NATPreRedirectTCPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectTCPPort) Name() string {
+ return "NATPreRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on redirect port.
+ return listenTCP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error {
+ return connectTCP(ip, dropPort, sendloopDuration)
+}
+
+// NATOutRedirectUDPPort tests that packets are redirected to a different port.
+type NATOutRedirectUDPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectUDPPort) Name() string {
+ return "NATOutRedirectUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectUDPPort) ContainerAction(ip net.IP) error {
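+ // 200.0.0.1 is an arbitrary, unassigned address; the REDIRECT rule should
+ // deliver packets sent there to the local accept port instead.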
+ dest := []byte{200, 0, 0, 1}
+ return loopbackTest(dest, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectUDPPort) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATDropUDP tests that packets are not received on ports other than the
+// redirect port.
+type NATDropUDP struct{}
+
+// Name implements TestCase.Name.
+func (NATDropUDP) Name() string {
+ return "NATDropUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATDropUDP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := listenUDP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATDropUDP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATAcceptAll tests that all UDP packets are accepted.
+type NATAcceptAll struct{}
+
+// Name implements TestCase.Name.
+func (NATAcceptAll) Name() string {
+ return "NATAcceptAll"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATAcceptAll) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ if err := listenUDP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATAcceptAll) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATOutRedirectIP uses iptables to select packets based on destination IP and
+// redirects them.
+type NATOutRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectIP) Name() string {
+ return "NATOutRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectIP) ContainerAction(ip net.IP) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ dest := net.IP([]byte{200, 0, 0, 2})
+ return loopbackTest(dest, "-A", "OUTPUT", "-d", dest.String(), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectIP) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATOutDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATOutDontRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATOutDontRedirectIP) Name() string {
+ return "NATOutDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutDontRedirectIP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "OUTPUT", "-d", localIP, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutDontRedirectIP) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// NATOutRedirectInvert tests that iptables can match with "! -d".
+type NATOutRedirectInvert struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectInvert) Name() string {
+ return "NATOutRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectInvert) ContainerAction(ip net.IP) error {
+ // Redirect OUTPUT packets to a listening localhost port.
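+ // Note that dest (200.0.0.3) intentionally differs from destStr
+ // (200.0.0.2), so the packet matches the inverted "! -d" rule and is
+ // redirected.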
+ dest := []byte{200, 0, 0, 3}
+ destStr := "200.0.0.2"
+ return loopbackTest(dest, "-A", "OUTPUT", "!", "-d", destStr, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectInvert) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATPreRedirectIP tests that we can use iptables to select packets based on
+// destination IP and redirect them.
+type NATPreRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectIP) Name() string {
+ return "NATPreRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectIP) ContainerAction(ip net.IP) error {
+ addrs, err := localAddrs(false)
+ if err != nil {
+ return err
+ }
+
+ var rules [][]string
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)})
+ }
+ if err := natTableRules(rules); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectIP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// NATPreDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATPreDontRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATPreDontRedirectIP) Name() string {
+ return "NATPreDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreDontRedirectIP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATPreRedirectInvert tests that iptables can match with "! -d".
+type NATPreRedirectInvert struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectInvert) Name() string {
+ return "NATPreRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectInvert) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "!", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectInvert) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a
+// protocol to be specified with -p.
+type NATRedirectRequiresProtocol struct{}
+
+// Name implements TestCase.Name.
+func (NATRedirectRequiresProtocol) Name() string {
+ return "NATRedirectRequiresProtocol"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil {
+ return errors.New("expected an error using REDIRECT --to-ports without a protocol")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATOutRedirectTCPPort tests that connections are redirected on specified ports.
+type NATOutRedirectTCPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectTCPPort) Name() string {
+ return "NATOutRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ timeout := 20 * time.Second
+ dest := []byte{127, 0, 0, 1}
+ localAddr := net.TCPAddr{
+ IP: dest,
+ Port: acceptPort,
+ }
+
+	// Start listening on the accept port.
+ lConn, err := net.ListenTCP("tcp", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer lConn.Close()
+
+	// Set a deadline for Accept, then connect to dropPort; the rule above
+	// redirects the connection to acceptPort.
+ lConn.SetDeadline(time.Now().Add(timeout))
+ err = connectTCP(ip, dropPort, timeout)
+ if err != nil {
+ return err
+ }
+
+ conn, err := lConn.AcceptTCP()
+ if err != nil {
+ return err
+ }
+ conn.Close()
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error {
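+	// No-op.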
+ return nil
+}
+
+// NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't
+// affected by PREROUTING rules.
+type NATLoopbackSkipsPrerouting struct{}
+
+// Name implements TestCase.Name.
+func (NATLoopbackSkipsPrerouting) Name() string {
+ return "NATLoopbackSkipsPrerouting"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error {
+	// Redirect all incoming TCP traffic to an unused port.
+ dest := []byte{127, 0, 0, 1}
+ if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+
+ // Establish a connection via localhost. If the PREROUTING rule did apply to
+ // loopback traffic, the connection would fail.
+ sendCh := make(chan error)
+ go func() {
+ sendCh <- connectTCP(dest, acceptPort, sendloopDuration)
+ }()
+
+ if err := listenTCP(acceptPort, sendloopDuration); err != nil {
+ return err
+ }
+ return <-sendCh
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// loopbackTest installs an iptables rule and ensures that packets sent to
+// dest:dropPort are received by localhost:acceptPort.
+func loopbackTest(dest net.IP, args ...string) error {
+ if err := natTable(args...); err != nil {
+ return err
+ }
+ sendCh := make(chan error)
+ listenCh := make(chan error)
+ go func() {
+ sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
+ }()
+ go func() {
+ listenCh <- listenUDP(acceptPort, sendloopDuration)
+ }()
+ select {
+ case err := <-listenCh:
+ if err != nil {
+ return err
+ }
+ case <-time.After(sendloopDuration):
+ return errors.New("timed out")
+ }
+ // sendCh will always take the full sendloop time.
+ return <-sendCh
+}
diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD
new file mode 100644
index 000000000..24504a1b9
--- /dev/null
+++ b/test/iptables/runner/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["main.go"],
+ pure = True,
+ visibility = ["//test/iptables:__subpackages__"],
+ deps = ["//test/iptables"],
+)
diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go
new file mode 100644
index 000000000..6f77c0684
--- /dev/null
+++ b/test/iptables/runner/main.go
@@ -0,0 +1,73 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package main runs iptables tests from within a docker container.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net"
+
+ "gvisor.dev/gvisor/test/iptables"
+)
+
+var name = flag.String("name", "", "name of the test to run")
+
+func main() {
+ flag.Parse()
+
+ // Find out which test we're running.
+ test, ok := iptables.Tests[*name]
+ if !ok {
+ log.Fatalf("No test found named %q", *name)
+ }
+ log.Printf("Running test %q", *name)
+
+ // Get the IP of the local process.
+ ip, err := getIP()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Run the test.
+ if err := test.ContainerAction(ip); err != nil {
+ log.Fatalf("Failed running test %q: %v", *name, err)
+ }
+
+ // Emit the final line.
+ log.Printf("%s", iptables.TerminalStatement)
+}
+
+// getIP listens for a connection from the local process and returns the source
+// IP of that connection.
+func getIP() (net.IP, error) {
+ localAddr := net.TCPAddr{
+ Port: iptables.IPExchangePort,
+ }
+ listener, err := net.ListenTCP("tcp4", &localAddr)
+ if err != nil {
+ return net.IP{}, fmt.Errorf("failed listening for IP: %v", err)
+ }
+ defer listener.Close()
+ conn, err := listener.AcceptTCP()
+ if err != nil {
+ return net.IP{}, fmt.Errorf("failed accepting IP: %v", err)
+ }
+ defer conn.Close()
+ log.Printf("Connected to %v", conn.RemoteAddr())
+
+ return conn.RemoteAddr().(*net.TCPAddr).IP, nil
+}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
new file mode 100644
index 000000000..dfcd55f60
--- /dev/null
+++ b/test/packetdrill/BUILD
@@ -0,0 +1,38 @@
+load("defs.bzl", "packetdrill_test")
+
+package(licenses = ["notice"])
+
+packetdrill_test(
+ name = "packetdrill_sanity_test",
+ scripts = ["sanity_test.pkt"],
+)
+
+packetdrill_test(
+ name = "accept_ack_drop_test",
+ scripts = ["accept_ack_drop.pkt"],
+)
+
+packetdrill_test(
+ name = "fin_wait2_timeout_test",
+ scripts = ["fin_wait2_timeout.pkt"],
+)
+
+packetdrill_test(
+ name = "listen_close_before_handshake_complete_test",
+ scripts = ["listen_close_before_handshake_complete.pkt"],
+)
+
+packetdrill_test(
+ name = "no_rst_to_rst_test",
+ scripts = ["no_rst_to_rst.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_test",
+ scripts = ["tcp_defer_accept.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_timeout_test",
+ scripts = ["tcp_defer_accept_timeout.pkt"],
+)
diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt
new file mode 100644
index 000000000..76e638fd4
--- /dev/null
+++ b/test/packetdrill/accept_ack_drop.pkt
@@ -0,0 +1,27 @@
+// Test that accept works if the final ACK is dropped and an ACK with data
+// follows the dropped ACK.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
++0.0 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause the connection to transition to the connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
new file mode 100644
index 000000000..f499c177b
--- /dev/null
+++ b/test/packetdrill/defs.bzl
@@ -0,0 +1,87 @@
+"""Defines a rule for packetdrill test targets."""
+
+def _packetdrill_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+
+ script_paths = []
+ for script in ctx.files.scripts:
+ script_paths.append(script.short_path)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+        # This test will run in part in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -exec chmod a+rx {} \\;",
+ "find . -type d -exec chmod a+rx {} \\;",
+ "%s %s --init_script %s $@ -- %s\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files._init_script[0].short_path,
+ " ".join(script_paths),
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ transitive_files = depset()
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files = depset(ctx.attr._test_runner.data_runfiles.files)
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files._init_script + ctx.files.scripts,
+ transitive_files = transitive_files,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_packetdrill_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "host",
+ allow_files = True,
+ default = "packetdrill_test.sh",
+ ),
+ "_init_script": attr.label(
+ allow_single_file = True,
+ default = "packetdrill_setup.sh",
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ "scripts": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _packetdrill_test_impl,
+)
+
+_PACKETDRILL_TAGS = ["local", "manual"]
+
+def packetdrill_linux_test(name, **kwargs):
+ if "tags" not in kwargs:
+ kwargs["tags"] = _PACKETDRILL_TAGS
+ _packetdrill_test(
+ name = name,
+ flags = ["--dut_platform", "linux"],
+ **kwargs
+ )
+
+def packetdrill_netstack_test(name, **kwargs):
+ if "tags" not in kwargs:
+ kwargs["tags"] = _PACKETDRILL_TAGS
+ _packetdrill_test(
+ name = name,
+ # This is the default runtime unless
+ # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
+ flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
+ **kwargs
+ )
+
+def packetdrill_test(name, **kwargs):
+ packetdrill_linux_test(name + "_linux_test", **kwargs)
+ packetdrill_netstack_test(name + "_netstack_test", **kwargs)
diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt
new file mode 100644
index 000000000..93ab08575
--- /dev/null
+++ b/test/packetdrill/fin_wait2_timeout.pkt
@@ -0,0 +1,23 @@
+// Test that a socket in FIN_WAIT_2 eventually times out and a subsequent
+// packet generates a RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+// Set the FIN_WAIT2 timeout to 1 second.
++0.100 setsockopt(4, SOL_TCP, TCP_LINGER2, [1], 4) = 0
++0 close(4) = 0
+
++0 > F. 1:1(0) ack 1 <...>
++0 < . 1:1(0) ack 2 win 257
+
++2 < . 1:1(0) ack 2 win 257
++0 > R 2:2(0) win 0
diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt
new file mode 100644
index 000000000..51c3f1a32
--- /dev/null
+++ b/test/packetdrill/listen_close_before_handshake_complete.pkt
@@ -0,0 +1,31 @@
+// Test that closing a listening socket closes any connections in SYN-RCVD
+// state and any packets bound for these connections generate a RESET.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
+
++0.100 close(3) = 0
++0.1 < P. 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.0 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt
new file mode 100644
index 000000000..612747827
--- /dev/null
+++ b/test/packetdrill/no_rst_to_rst.pkt
@@ -0,0 +1,36 @@
+// Test that a RST is not generated in response to a RST, and that a RST is
+// correctly generated when an accepted endpoint is reset by an incoming RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+
++0.200 < R 1:1(0) win 0
+
++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer)
+
++0.00 < . 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.00 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/packetdrill_setup.sh b/test/packetdrill/packetdrill_setup.sh
new file mode 100755
index 000000000..b858072f0
--- /dev/null
+++ b/test/packetdrill/packetdrill_setup.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script runs both within the sentry context and natively. It should tweak
+# TCP parameters to match expectations found in the script files.
+sysctl -q net.ipv4.tcp_sack=1
+sysctl -q net.ipv4.tcp_rmem="4096 2097152 $((8*1024*1024))"
+sysctl -q net.ipv4.tcp_wmem="4096 2097152 $((8*1024*1024))"
+
+# There may be errors from the above, but they will show up in the test logs and
+# we always want to proceed from this point. It's possible that values were
+# already set correctly and the nodes were not available in the namespace.
+exit 0
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
new file mode 100755
index 000000000..922547d65
--- /dev/null
+++ b/test/packetdrill/packetdrill_test.sh
@@ -0,0 +1,226 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run a packetdrill test. Two docker containers are made, one for the
+# Device-Under-Test (DUT) and one for the test runner. Each is attached to
+# two networks: one for control packets that aid the test and one for test
+# packets, which are sent as part of the test and observed for correctness.
+
+set -euxo pipefail
+
+function failure() {
+ local lineno=$1
+ local msg=$2
+ local filename="$0"
+ echo "FAIL: $filename:$lineno: $msg"
+}
+trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
+
+declare -r LONGOPTS="dut_platform:,init_script:,runtime:"
+
+# Don't use declare below so that the error from getopt will end the script.
+PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
+
+eval set -- "$PARSED"
+
+while true; do
+ case "$1" in
+ --dut_platform)
+ # Either "linux" or "netstack".
+ declare -r DUT_PLATFORM="$2"
+ shift 2
+ ;;
+ --init_script)
+ declare -r INIT_SCRIPT="$2"
+ shift 2
+ ;;
+ --runtime)
+ # Not readonly because there might be multiple --runtime arguments and we
+ # want to use just the last one. Only used if --dut_platform is
+ # "netstack".
+ declare RUNTIME="$2"
+ shift 2
+ ;;
+ --)
+ shift
+ break
+ ;;
+ *)
+ echo "Programming error"
+ exit 3
+ esac
+done
+
+# All the other arguments are scripts.
+declare -r scripts="$@"
+
+# Check that the required flags are defined in a way that is safe for "set -u".
+if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then
+ if [[ -z "${RUNTIME-}" ]]; then
+ echo "FAIL: Missing --runtime argument: ${RUNTIME-}"
+ exit 2
+ fi
+ declare -r RUNTIME_ARG="--runtime ${RUNTIME}"
+elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then
+ declare -r RUNTIME_ARG=""
+else
+ echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}"
+ exit 2
+fi
+if [[ ! -x "${INIT_SCRIPT-}" ]]; then
+ echo "FAIL: Bad or missing --init_script: ${INIT_SCRIPT-}"
+ exit 2
+fi
+
+function new_net_prefix() {
+  # Class C, 192.0.0.0 to 223.255.255.255, traditionally has mask 24.
+ echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)"
+}
+
+# Variables specific to the control network and interface start with CTRL_.
+# Variables specific to the test network and interface start with TEST_.
+# Variables specific to the DUT start with DUT_.
+# Variables specific to the test runner start with TEST_RUNNER_.
+declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill"
+# Use random numbers so that test networks don't collide.
+declare CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)"
+declare CTRL_NET_PREFIX=$(new_net_prefix)
+declare TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)"
+declare TEST_NET_PREFIX=$(new_net_prefix)
+declare -r tolerance_usecs=100000
+# On both DUT and test runner, testing packets are on the eth2 interface.
+declare -r TEST_DEVICE="eth2"
+# Number of bits in the *_NET_PREFIX variables.
+declare -r NET_MASK="24"
+# Last bits of the DUT's IP address.
+declare -r DUT_NET_SUFFIX=".10"
+# Control port.
+declare -r CTRL_PORT="40000"
+# Last bits of the test runner's IP address.
+declare -r TEST_RUNNER_NET_SUFFIX=".20"
+declare -r TIMEOUT="60"
+declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetdrill"
+
+# Make sure that docker is installed.
+docker --version
+
+function finish {
+ local cleanup_success=1
+ for net in "${CTRL_NET}" "${TEST_NET}"; do
+ # Kill all processes attached to ${net}.
+ for docker_command in "kill" "rm"; do
+ (docker network inspect "${net}" \
+ --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \
+ | xargs -r docker "${docker_command}") || \
+ cleanup_success=0
+ done
+ # Remove the network.
+ docker network rm "${net}" || \
+ cleanup_success=0
+ done
+
+ if ((!$cleanup_success)); then
+ echo "FAIL: Cleanup command failed"
+ exit 4
+ fi
+}
+trap finish EXIT
+
+# Subnet for control packets between test runner and DUT.
+while ! docker network create \
+ "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do
+ sleep 0.1
+ CTRL_NET_PREFIX=$(new_net_prefix)
+ CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)"
+done
+
+# Subnet for the packets that are part of the test.
+while ! docker network create \
+ "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do
+ sleep 0.1
+ TEST_NET_PREFIX=$(new_net_prefix)
+ TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)"
+done
+
+# Create the DUT container and connect to network.
+DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker start "${DUT}"
+
+# Create the test runner container and connect to network.
+TEST_RUNNER=$(docker create --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \
+  || (docker kill ${TEST_RUNNER}; docker rm ${TEST_RUNNER}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \
+  || (docker kill ${TEST_RUNNER}; docker rm ${TEST_RUNNER}; false)
+docker start "${TEST_RUNNER}"
+
+# Run tcpdump in the test runner unbuffered, without dns resolution, just on the
+# interface with the test packets.
+docker exec -t ${TEST_RUNNER} tcpdump -U -n -i "${TEST_DEVICE}" &
+
+# Start a packetdrill server on the test_runner. The packetdrill server sends
+# packets and asserts that they are received.
+docker exec -d "${TEST_RUNNER}" \
+ ${PACKETDRILL} --wire_server --wire_server_dev="${TEST_DEVICE}" \
+ --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --wire_server_port="${CTRL_PORT}" \
+ --local_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --remote_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}"
+
+# Because the Linux kernel receives the SYN-ACK but didn't send the SYN, it
+# will issue a RST. To prevent this, iptables is used to filter those RSTs out.
+docker exec "${TEST_RUNNER}" \
+ iptables -A OUTPUT -p tcp --tcp-flags RST RST -j DROP
+
+# Wait for the packetdrill server on the test runner to come up. Attempt to
+# connect to it from the DUT every 100 milliseconds until success.
+while ! docker exec "${DUT}" \
+ nc -zv "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${CTRL_PORT}"; do
+ sleep 0.1
+done
+
+# Copy the packetdrill setup script to the DUT.
+docker cp -L "${INIT_SCRIPT}" "${DUT}:packetdrill_setup.sh"
+
+# Copy the packetdrill scripts to the DUT.
+declare -a dut_scripts
+for script in $scripts; do
+ docker cp -L "${script}" "${DUT}:$(basename ${script})"
+ dut_scripts+=("/$(basename ${script})")
+done
+
+# Start a packetdrill client on the DUT. The packetdrill client runs POSIX
+# socket commands and also sends instructions to the server.
+docker exec -t "${DUT}" \
+ ${PACKETDRILL} --wire_client --wire_client_dev="${TEST_DEVICE}" \
+ --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --wire_server_port="${CTRL_PORT}" \
+ --local_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" \
+ --remote_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --init_scripts=/packetdrill_setup.sh \
+ --tolerance_usecs="${tolerance_usecs}" "${dut_scripts[@]}"
+
+echo PASS: No errors.
diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
new file mode 100644
index 000000000..a86b90ce6
--- /dev/null
+++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
@@ -0,0 +1,9 @@
+// Test that a listening socket generates a RST when it receives an
+// ACK and SYN cookies are not in use.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
++0.1 < . 1:1(0) ack 1 win 32792
++0 > R 1:1(0) ack 0 win 0 \ No newline at end of file
diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt
new file mode 100644
index 000000000..b3b58c366
--- /dev/null
+++ b/test/packetdrill/sanity_test.pkt
@@ -0,0 +1,7 @@
+// Basic sanity test. One system call.
+//
+// All of the plumbing has to be working, however: the packetdrill wire
+// client needs to be able to connect to the wire server and send the script,
+// probe local interfaces, run through the test w/ timings, etc.
+
+0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt
new file mode 100644
index 000000000..a17f946db
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK does not complete a connection when the
+// TCP_DEFER_ACCEPT timeout has not been hit, but an ACK w/ data does complete
+// and deliver the connection to the accept queue.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK; this should not complete the connection since we set
+// TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK; it should still fail since we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Set the accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// Now send an ACK w/ data. This should complete the connection
+// and deliver the socket to the accept queue.
++0.1 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause the connection to transition to the connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt
new file mode 100644
index 000000000..201fdeb14
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept_timeout.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK is accepted after the TCP_DEFER_ACCEPT timeout
+// is hit and a connection is delivered.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK; this should not complete the connection since we set
+// TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK; it should still fail since we set TCP_DEFER_ACCEPT
+// to 3 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Set the accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// We should see one more retransmit of the SYN-ACK, as a last-ditch attempt
+// when the TCP_DEFER_ACCEPT timeout is hit, to trigger another ACK or a
+// packet with data.
++.35~+2.35 > S. 0:0(0) ack 1 <...>
+
+// Now send another bare ACK after the TCP_DEFER_ACCEPT time has passed.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The ACK above should cause the connection to transition to the connected
+// state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 1 <...>
++0.0 < F. 1:1(0) ack 2 win 257
++0.01 > . 2:2(0) ack 2 <...>
diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md
new file mode 100644
index 000000000..f46c67a0c
--- /dev/null
+++ b/test/packetimpact/README.md
@@ -0,0 +1,702 @@
+# Packetimpact
+
+## What is packetimpact?
+
+Packetimpact is a tool for platform-independent network testing. It is heavily
+inspired by [packetdrill](https://github.com/google/packetdrill). It creates two
+docker containers connected by a network. One is for the test bench, which
+operates the test. The other is for the device-under-test (DUT), which is the
+software being tested. The test bench communicates over the network with the
+DUT to check the correctness of its networking behavior.
+
+### Goals
+
+Packetimpact aims to provide:
+
+* A **multi-platform** solution that can test both Linux and gVisor.
+* **Conciseness** on par with packetdrill scripts.
+* **Control-flow** like for loops, conditionals, and variables.
+*   **Flexibility** to specify every byte in a packet or use multiple sockets.
+
+## How to run packetimpact tests?
+
+Build the test container image by running the following at the root of the
+repository:
+
+```bash
+$ make load-packetimpact
+```
+
+Run a test, e.g. `fin_wait2_timeout`, against Linux:
+
+```bash
+$ bazel test //test/packetimpact/tests:fin_wait2_timeout_linux_test
+```
+
+Run the same test, but against gVisor:
+
+```bash
+$ bazel test //test/packetimpact/tests:fin_wait2_timeout_netstack_test
+```
+
+## When to use packetimpact?
+
+There are a few ways to write networking tests for gVisor currently:
+
+* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip)
+* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux)
+* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill)
+* packetimpact tests
+
+The right choice depends on the needs of the test.
+
+Feature | Go unit test | syscall test | packetdrill | packetimpact
+-------------- | ------------ | ------------ | ----------- | ------------
+Multi-platform | no | **YES** | **YES** | **YES**
+Concise | no | somewhat | somewhat | **VERY**
+Control-flow | **YES** | **YES** | no | **YES**
+Flexible | **VERY** | no | somewhat | **VERY**
+
+### Go unit tests
+
+If the test depends on the internals of gVisor and doesn't need to run on Linux
+or other platforms for comparison purposes, a Go unit test can be appropriate.
+Such tests can observe the internals of gVisor networking. The downside is
+that they are **not concise** and **not multi-platform**. If you require
+insight into gVisor internals, this is the right choice.
+
+### Syscall tests
+
+Syscall tests are **multi-platform** but cannot examine the internals of gVisor
+networking. They are **concise**. They can use **control-flow** structures like
+conditionals, for loops, and variables. However, they are limited to only what
+the POSIX interface provides so they are **not flexible**. For example, you
+would have difficulty writing a syscall test that intentionally sends a bad IP
+checksum. Even if you did write that test with raw sockets, it would be very
+**verbose** to intentionally send wrong checksums, wrong protocols, wrong
+sequence numbers, etc.
+
+### Packetdrill tests
+
+Packetdrill tests are **multi-platform** and can run against both Linux and
+gVisor. They are **concise** and use a special packetdrill scripting language.
+They are **more flexible** than a syscall test in that they can send packets
+that a syscall test would have difficulty sending, like a packet with a
+calculated ACK number. But they are also somewhat limited in flexibility in that
+they can't do tests with multiple sockets. They have **no control-flow** ability
+like variables or conditionals. For example, it isn't possible to send a packet
+that depends on the window size of a previous packet because the packetdrill
+language can't express that. Nor could you branch based on whether or not the
+other side supports window scaling, for example.
+
+### Packetimpact tests
+
+Packetimpact tests are similar to Packetdrill tests except that they are written
+in Go instead of the packetdrill scripting language. That gives them all the
+**control-flow** abilities of Go (loops, functions, variables, etc). They are
+**multi-platform** in the same way as packetdrill tests but even more
+**flexible** because Go is more expressive than the scripting language of
+packetdrill. However, Go is **not as concise** as the packetdrill language. Many
+design decisions below are made to mitigate that.
+
+## How it works
+
+```
+ Testbench Device-Under-Test (DUT)
+ +-------------------+ +------------------------+
+ | | TEST NET | |
+ | rawsockets.go <-->| <===========> | <---+ |
+ | ^ | | | |
+ | | | | | |
+ | v | | | |
+ | unittest | | | |
+ | ^ | | | |
+ | | | | | |
+ | v | | v |
+ | dut.go <========gRPC========> posix server |
+ | | CONTROL NET | |
+ +-------------------+ +------------------------+
+```
+
+Two docker containers are created by a "runner" script, one for the testbench
+and the other for the device under test (DUT). The script connects the two
+containers with a control network and test network. It also does some other
+tasks like waiting until the DUT is ready before starting the test and disabling
+Linux networking that would interfere with the test bench.
+
+### DUT
+
+The DUT container runs a program called the "posix_server". The posix_server is
+written in C++ for maximum portability. It is compiled on the host. The script
+that starts the containers copies it into the DUT's container and runs it. Its
+job is to receive directions from the test bench on what actions to take. For
+this, the posix_server does three steps in a loop:
+
+1. Listen for a request from the test bench.
+2. Execute a command.
+3. Send the response back to the test bench.
+
+The requests and responses are
+[protobufs](https://developers.google.com/protocol-buffers) and the
+communication is done with [gRPC](https://grpc.io/). The commands run are
+[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions),
+with the inputs and outputs converted into protobuf requests and responses. All
+communication is on the control network, so that the test network is unaffected
+by extra packets.
+
+For example, this is the request and response pair to call
+[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html):
+
+```protocol-buffer
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2;
+}
+```
+
+##### Alternatives considered
+
+*   We could have used JSON for communication instead. It would have been
+    lighter-touch than protobuf, but protobuf handles all the data types and has
+    strict typing to prevent a class of errors. The test bench could be written
+ in other languages, too.
+* Instead of mimicking the POSIX interfaces, arguments could have had a more
+    natural form, like `bind()` taking a string IP address instead of bytes
+ in a `sockaddr_t`. However, conforming to the existing structures keeps more
+ of the complexity in Go and keeps the posix_server simpler and thus more
+ likely to compile everywhere.
+
+### Test Bench
+
+The test bench does most of the work in a test. It is a Go program that compiles
+on the host and is copied by the script into the test bench's container. It is a
+regular [go unit test](https://golang.org/pkg/testing/) that imports the test
+bench framework. The test bench framework is based on three basic utilities:
+
+* Commanding the DUT to run POSIX commands and return responses.
+* Sending raw packets to the DUT on the test network.
+* Listening for raw packets from the DUT on the test network.
+
+#### DUT commands
+
+To keep the interface to the DUT consistent and easy-to-use, each POSIX command
+supported by the posix_server is wrapped in functions with signatures similar to
+the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This
+way all the details of endianess and (un)marshalling of go structs such as
+[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in
+one place. This also makes it straight-forward to convert tests that use `unix.`
+or `syscall.` calls to `dut.` calls.
+
+For example, creating a connection to the DUT and commanding it to make a socket
+looks like this:
+
+```go
+dut := testbench.NewDut(t)
+fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+if fd < 0 {
+ t.Fatalf(...)
+}
+```
+
+Because the usual case is to fail the test when the DUT fails to create a
+socket, there is a concise version of each of the `...WithErrno` functions that
+does that:
+
+```go
+dut := testbench.NewDut(t)
+fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+```
+
+The DUT and other structs in the code store a `*testing.T` so that they can
+provide versions of functions that call `t.Fatalf(...)`. This helps keep tests
+concise.
+
+##### Alternatives considered
+
+* Instead of mimicking the `unix.` go interface, we could have invented a more
+ natural one, like using `float64` instead of `Timeval`. However, using the
+ same function signatures that `unix.` has makes it easier to convert code to
+ `dut.`. Also, using an existing interface ensures that we don't invent an
+ interface that isn't extensible. For example, if we invented a function for
+    `bind()` that didn't support IPv6, we might later have to add a second
+    `bind6()` function.
+
+#### Sending/Receiving Raw Packets
+
+The framework wraps POSIX sockets for sending and receiving raw frames. Both
+send and receive are synchronous commands.
+[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set
+a timeout on the receive commands. For ease of use, these are wrapped in an
+`Injector` and a `Sniffer`. They have functions:
+
+```go
+func (s *Sniffer) Recv(timeout time.Duration) []byte {...}
+func (i *Injector) Send(b []byte) {...}
+```
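+
+As a minimal sketch of how a test might use these (the `NewSniffer` and
+`NewInjector` constructor names and signatures here are assumptions for
+illustration, not part of the interface above):
+
+```go
+// Sketch only: constructors and error handling are assumed.
+sniffer, err := NewSniffer(t)
+if err != nil {
+	t.Fatalf("failed to create sniffer: %v", err)
+}
+injector, err := NewInjector(t)
+if err != nil {
+	t.Fatalf("failed to create injector: %v", err)
+}
+var frame []byte // Raw frame bytes, built elsewhere (e.g. with Layers.toBytes).
+injector.Send(frame) // Send the frame on the test network.
+if b := sniffer.Recv(time.Second); b == nil {
+	t.Fatalf("got no reply within the timeout") // Recv timed out via SO_RCVTIMEO.
+}
+```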
+
+##### Alternatives considered
+
+* [gopacket](https://github.com/google/gopacket) pcap has raw socket support
+ but requires cgo. cgo is not guaranteed to be portable from the host to the
+ container and in practice, the container doesn't recognize binaries built on
+ the host if they use cgo.
+* Both gVisor and gopacket have the ability to read and write pcap files
+ without cgo but that is insufficient here because we can't just replay pcap
+ files, we need a more dynamic solution.
+* The sniffer and injector can't share a socket because they need to be bound
+ differently.
+* Sniffing could have been done asynchronously with channels, obviating the
+    need for `SO_RCVTIMEO`. But that would introduce asynchronous complexity.
+ `SO_RCVTIMEO` is well supported on the test bench.
+
+#### `Layer` struct
+
+A large part of packetimpact tests is creating packets to send and comparing
+received packets against expectations. To keep tests concise, it is useful to be
+able to specify just the important parts of packets that need to be set. For
+example, sending a packet with default values except for TCP Flags. And for
+packets received, it's useful to be able to compare just the necessary parts of
+received packets and ignore the rest.
+
+To aid in both of those, Go structs with optional fields are created for each
+encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by
+[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the
+struct for Ethernet:
+
+```go
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+```
+
+Each struct has the same fields as those in the
+[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header)
+but with a pointer for each field that may be `nil`.
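+
+For example, a test that cares only about the EtherType can set just that field
+and leave the addresses as `nil` "don't-cares" (a small sketch;
+`header.IPv4ProtocolNumber` is the gVisor constant for EtherType 0x0800):
+
+```go
+etherType := header.IPv4ProtocolNumber
+partialEther := Ether{Type: &etherType} // SrcAddr and DstAddr stay nil.
+```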
+
+##### Alternatives considered
+
+* Just use []byte like gVisor headers do. The drawback is that it makes the
+ tests more verbose.
+ * For example, there would be no way to call `Send(myBytes)` concisely and
+ indicate if the checksum should be calculated automatically versus
+ overridden. The only way would be to add lines to the test to calculate
+ it before each Send, which is wordy. Or make multiple versions of Send:
+ one that checksums IP, one that doesn't, one that checksums TCP, one
+ that does both, etc. That would be many combinations.
+ * Filtering inputs would become verbose. Either:
+ * large conditionals that need to be repeated many places:
+ `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or
+ * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`,
+ `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`,
+ etc.
+ * Using pointers allows us to combine `Layer`s with reflection. So the
+ default `Layers` can be overridden by a `Layers` with just the TCP
+        connection's src/dst, which can be overridden by one with just a
+        test-specific TCP window size.
+ * It's a proven way to separate the details of a packet from the byte
+ format as shown by scapy's success.
+* Use packetgo. It's more general than parsing packets with gVisor. However:
+ * packetgo doesn't have optional fields so many of the above problems
+ still apply.
+ * It would be yet another dependency.
+ * It's not as well known to engineers that are already writing gVisor
+ code.
+ * It might be a good candidate for replacing the parsing of packets into
+ `Layer`s if all that parsing turns out to be more work than parsing by
+ packetgo and converting *that* to `Layer`. packetgo has easier to use
+ getters for the layers. This could be done later in a way that doesn't
+ break tests.
+
+#### `Layer` methods
+
+The `Layer` structs provide a way to partially specify an encapsulation. They
+also need methods for using those partially specified encapsulation, for example
+to marshal them to bytes or compare them. For those, each encapsulation
+implements the `Layer` interface:
+
+```go
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ // toBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ toBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+}
+```
+
+The `next` and `prev` pointers make up a linked list so that each layer can get
+at the information in the layers around it. This is necessary for some
+protocols, like TCP, which needs the layer before and the payload after to
+compute the checksum. Any sequence of `Layer` structs is valid so long as the
+parser and `toBytes` functions can map from type to protocol number and
+vice-versa. When the mapping fails, an error is emitted explaining what
+functionality is missing. The solution is either to fix the ordering or
+implement the missing protocol.
+
+For each `Layer` there is also a parsing function. For example, this one is for
+Ethernet:
+
+```go
+func ParseEther(b []byte) (Layers, error)
+```
+
+The parsing function converts bytes received on the wire into a `Layer`
+(actually `Layers`, see below) which has no `nil`s in it. By using
+`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it,
+the received bytes can be partially compared. The `nil`s behave as
+"don't-cares".
+
+##### Alternatives considered
+
+* Matching against `[]byte` instead of converting to `Layer` first.
+ * The downside is that it precludes the use of a `cmp.Equal` one-liner to
+ do comparisons.
+ * It creates confusion in the code to deal with both representations at
+ different times. For example, is the checksum calculated on `[]byte` or
+ `Layer` when sending? What about when checking received packets?
+
+#### `Layers`
+
+```go
+type Layers []Layer
+
+func (ls *Layers) match(other Layers) bool {...}
+func (ls *Layers) toBytes() ([]byte, error) {...}
+```
+
+`Layers` is an array of `Layer`. It represents a stack of encapsulations, such
+as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and
+`match(Layers)`, like `Layer`. The parse functions above actually return
+`Layers` and not `Layer` because they know about the headers below and
+sequentially call each parser on the remaining, encapsulated bytes.
+
+All this leads to the ability to write concise packet processing. For example:
+
+```go
+etherType := header.IPv4ProtocolNumber // EtherType 0x0800 is IPv4.
+flags := uint8(header.TCPFlagSyn | header.TCPFlagAck)
+toMatch := Layers{&Ether{Type: &etherType}, &IPv4{}, &TCP{Flags: &flags}}
+for {
+	recvBytes := sniffer.Recv(time.Second)
+	if recvBytes == nil {
+		println("Got no packet for 1 second")
+		continue
+	}
+ gotPacket, err := ParseEther(recvBytes)
+ if err == nil && toMatch.match(gotPacket) {
+ println("Got a TCP/IPv4/Eth packet with SYNACK")
+ }
+}
+```
+
+##### Alternatives considered
+
+* Don't use previous and next pointers.
+ * Each layer may need to be able to interrogate the layers around it, like
+ for computing the next protocol number or total length. So *some*
+ mechanism is needed for a `Layer` to see neighboring layers.
+ * We could pass the entire array `Layers` to the `toBytes()` function.
+        Passing an array to a method of one of the array's own elements seems
+        wrong.
+
+#### `layerState`
+
+`Layers` represents the different headers of a packet but a connection includes
+more state. For example, a TCP connection needs to keep track of the next
+expected sequence number and also the next sequence number to send. This is
+stored in a `layerState` struct. This is the `layerState` for TCP:
+
+```go
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+```
+
+The next sequence numbers for each side of the connection are stored. `out` and
+`in` have defaults for the TCP header, such as the expected source and
+destination ports for outgoing packets and incoming packets.
+
+##### `layerState` interface
+
+```go
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame.
+ outgoing() Layer
+
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+	// received updates the layerState based on a Layer that is received. The
+	// input is a Layer with all prev and next pointers populated so that the
+	// entire frame as it was received is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+```
+
+`outgoing` generates the default Layer for an outgoing packet. For TCP, this
+would be a `TCP` with the source and destination ports populated. Because they
+are static, they are stored inside the `out` member of `tcpState`. However, the
+sequence numbers change frequently, so the outgoing sequence number is stored
+in `localSeqNum` and copied into the output of `outgoing` on each call.
+
+`incoming` does the same for packets that arrive, but instead of generating a
+packet to send, it generates an expected packet for filtering the packets that
+arrive. For example, if a `TCP` header arrives with the wrong ports, it can
+be ignored as belonging to a different connection. `incoming` needs the received
+header itself as an input because the filter may depend on the input. For
+example, the expected sequence number depends on the flags in the TCP header.
+
+`sent` and `received` are run for each header that is actually sent or received
+and used to update the internal state. `incoming` and `outgoing` should *not* be
+used for this purpose. For example, `incoming` is called on every packet that
+arrives but only packets that match ought to actually update the state.
+`outgoing` is called to create outgoing packets and those packets are always
+sent, so unlike `incoming`/`received`, there is one `outgoing` call for each
+`sent` call.
+
+`close` cleans up after the layerState. For example, TCP and UDP need to keep a
+port reserved and then release it.
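+
+To make this concrete, `outgoing` for TCP might look roughly like the following
+(a sketch under the assumption that `TCP` carries a `SeqNum *uint32` field; the
+framework's actual code may differ):
+
+```go
+func (s *tcpState) outgoing() Layer {
+	newOutgoing := s.out          // Copy the static defaults (ports, etc.).
+	seq := uint32(*s.localSeqNum) // The sequence number changes per packet,
+	newOutgoing.SeqNum = &seq     // so it is filled in fresh on every call.
+	return &newOutgoing
+}
+```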
+
+#### Connections
+
+Using `layerState` above, we can create connections.
+
+```go
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+```
+
+The connection stores an array of `layerState` in the order that the headers
+should be present in the frame to send. For example, Ether then IPv4 then TCP.
+The injector and sniffer are for writing and reading frames. A `*testing.T` is
+stored so that internal errors can be reported directly without code in the unit
+test.
+
+The `Connection` has some useful functions:
+
+```go
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {...}
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {...}
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {...}
+// Send a packet with reasonable defaults. Potentially override the final layer
+// in the connection with the provided layer and add additionalLayers.
+func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {...}
+// Expect a frame with the final layerStates layer matching the provided Layer
+// within the timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {...}
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {...}
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {...}
+```
+
+`CreateFrame` uses the `[]layerState` to create a frame to send. The first
+argument is for overriding defaults in the last header of the frame, because
+this is the most common need. For a TCPIPv4 connection, this would be the TCP
+header. Optional additionalLayers can be specified to add to the frame being
+created, such as a `Payload` for `TCP`.
+
+`SendFrame` sends the frame to the DUT. It is combined with `CreateFrame` to
+make `Send`. For unittests with basic sending needs, `Send` can be used. If more
+control is needed over the frame, it can be made with `CreateFrame`, modified in
+the unit test, and then sent with `SendFrame`.
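+
+For example (a sketch; the `Payload` layer and its `Bytes` field are
+assumptions for illustration):
+
+```go
+flags := uint8(header.TCPFlagPsh | header.TCPFlagAck)
+data := []byte("sample payload")
+frame := conn.CreateFrame(&TCP{Flags: &flags}, &Payload{Bytes: data})
+// The test can inspect or tweak any layer of frame here before sending.
+conn.SendFrame(frame)
+```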
+
+On the receiving side, there is `Expect` and `ExpectFrame`. Like with the
+sending side, there are two forms of each function, one for just the last header
+and one for the whole frame. The expect functions use the `[]layerState` to
+create a template for the expected incoming frame. That frame is then overridden
+by the values in the first argument. Finally, a loop starts sniffing packets on
+the wire for frames. If a matching frame is found before the timeout, it is
+returned without error. If not, nil is returned and the error contains text of
+all the received frames that didn't match. Exactly one of the outputs will be
+non-nil, even if no frames are received at all.
+
+`Drain` sniffs and discards all the frames that have yet to be received. A
+common way to write a test is:
+
+```go
+conn.Drain() // Discard all outstanding frames.
+conn.Send(...) // Send a frame with overrides.
+// Now expect a frame with a certain header and fail if it doesn't arrive.
+if _, err := conn.Expect(...); err != nil { t.Fatal(...) }
+```
+
+Or for a test where we want to check that no frame arrives:
+
+```go
+if gotOne, _ := conn.Expect(...); gotOne != nil { t.Fatal(...) }
+```
+
+#### Specializing `Connection`
+
+Because there are some common combinations of `layerState` into `Connection`,
+they are defined:
+
+```go
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+```
+
+Each has a `NewXxx` function to create a new connection with reasonable
+defaults. They also have functions that call the underlying `Connection`
+functions but with specialization and tighter type-checking. For example:
+
+```go
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+```
+
+They may also have some accessors to get or set the internal state of the
+connection:
+
+```go
+func (conn *TCPIPv4) state() *tcpState {
+ state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ }
+ return state
+}
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.state().remoteSeqNum
+}
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.state().localSeqNum
+}
+```
+
+In practice, unit tests use these specialized functions rather than the ones
+on `Connection`. For example, a test calls `NewTCPIPv4()` and then `Send` on
+the result, rather than casting it to a `Connection` and calling `Send` on the
+cast result.
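+
+For example, after a handshake a test can read the tracked sequence numbers
+through the accessors (a sketch; `Close` frees the connection's ports and
+sockets):
+
+```go
+conn := testbench.NewTCPIPv4(t, dut, remotePort)
+defer conn.Close()
+conn.Handshake()
+// RemoteSeqNum and LocalSeqNum expose state tracked by the final tcpState.
+t.Logf("remote seq = %d, local seq = %d", *conn.RemoteSeqNum(), *conn.LocalSeqNum())
+```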
+
+##### Alternatives considered
+
+* Instead of storing `outgoing` and `incoming`, store values.
+ * There would be many more things to store instead, like `localMac`,
+ `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`.
+ * Construction of a packet would be many lines to copy each of these
+ values into a `[]byte`. And there would be slight variations needed for
+ each encapsulation stack, like TCPIPv6 and ARP.
+ * Filtering incoming packets would be a long sequence:
+ * Compare the MACs, then
+ * Parse the next header, then
+ * Compare the IPs, then
+ * Parse the next header, then
+        *   Compare the TCP ports. Instead, all of this is a single call to
+            `cmp.Equal(...)` for the whole sequence.
+    *   A TCPIPv6 connection could share most of the code. Only the types of
+        the IP addresses differ. The types of `outgoing` and `incoming` would
+        remain `Layers`.
+ * An ARP connection could share all the Ethernet parts. The IP `Layer`
+ could be factored out of `outgoing`. After that, the IPv4 and IPv6
+ connections could implement one interface and a single TCP struct could
+ have either network protocol through composition.
+
+## Putting it all together
+
+Here's what the start of a packetimpact unit test looks like. This test
+creates a TCP connection with the DUT. Comments are included for explanation
+in this document; a real test might omit them to stay concise.
+
+```go
+func TestMyTcpTest(t *testing.T) {
+ // Prepare a DUT for communication.
+ dut := testbench.NewDUT(t)
+
+ // This does:
+ // dut.Socket()
+ // dut.Bind()
+ // dut.Getsockname() to learn the new port number
+ // dut.Listen()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test.
+
+ // Monitor a new TCP connection with sniffer, injector, sequence number tracking,
+ // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers.
+ conn := testbench.NewTCPIPv4(t, dut, remotePort)
+
+ // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK.
+ conn.Handshake()
+
+ // Tell the DUT to accept the new connection.
+	acceptFD := dut.Accept(listenFD)
+	defer dut.Close(acceptFD) // Close the accepted connection too.
+}
+```
+
+## Other notes
+
+*   The time between receiving a SYN-ACK and replying with an ACK in
+    `Handshake` is about 3ms. This is much slower than the native Linux
+    response, which is about 0.3ms. packetdrill gets closer to 0.3ms, so for
+    tests where timing is crucial, packetdrill is faster and more precise.
diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD
new file mode 100644
index 000000000..3ce63c2c6
--- /dev/null
+++ b/test/packetimpact/dut/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "cc_binary", "grpcpp")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "posix_server",
+ srcs = ["posix_server.cc"],
+ linkstatic = 1,
+ static = True, # This is needed for running in a docker container.
+ deps = [
+ grpcpp,
+ "//test/packetimpact/proto:posix_server_cc_grpc_proto",
+ "//test/packetimpact/proto:posix_server_cc_proto",
+ ],
+)
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
new file mode 100644
index 000000000..a1a5c3612
--- /dev/null
+++ b/test/packetimpact/dut/posix_server.cc
@@ -0,0 +1,365 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <getopt.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <unordered_map>
+
+#include "include/grpcpp/security/server_credentials.h"
+#include "include/grpcpp/server_builder.h"
+#include "test/packetimpact/proto/posix_server.grpc.pb.h"
+#include "test/packetimpact/proto/posix_server.pb.h"
+
+// Converts a sockaddr_storage to a Sockaddr message.
+::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr,
+ socklen_t addrlen,
+ posix_server::Sockaddr *sockaddr_proto) {
+ switch (addr.ss_family) {
+ case AF_INET: {
+ auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr);
+ auto response_in = sockaddr_proto->mutable_in();
+ response_in->set_family(addr_in->sin_family);
+ response_in->set_port(ntohs(addr_in->sin_port));
+ response_in->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4);
+ return ::grpc::Status::OK;
+ }
+ case AF_INET6: {
+ auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr);
+ auto response_in6 = sockaddr_proto->mutable_in6();
+ response_in6->set_family(addr_in6->sin6_family);
+ response_in6->set_port(ntohs(addr_in6->sin6_port));
+ response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo));
+ response_in6->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id));
+ return ::grpc::Status::OK;
+ }
+ }
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr");
+}
+
+::grpc::Status proto_to_sockaddr(const posix_server::Sockaddr &sockaddr_proto,
+ sockaddr_storage *addr, socklen_t *addr_len) {
+ switch (sockaddr_proto.sockaddr_case()) {
+ case posix_server::Sockaddr::SockaddrCase::kIn: {
+ auto proto_in = sockaddr_proto.in();
+ if (proto_in.addr().size() != 4) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv4 address must be 4 bytes");
+ }
+ auto addr_in = reinterpret_cast<sockaddr_in *>(addr);
+ addr_in->sin_family = proto_in.family();
+ addr_in->sin_port = htons(proto_in.port());
+ proto_in.addr().copy(reinterpret_cast<char *>(&addr_in->sin_addr.s_addr),
+ 4);
+ *addr_len = sizeof(*addr_in);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::kIn6: {
+ auto proto_in6 = sockaddr_proto.in6();
+ if (proto_in6.addr().size() != 16) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv6 address must be 16 bytes");
+ }
+ auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(addr);
+ addr_in6->sin6_family = proto_in6.family();
+ addr_in6->sin6_port = htons(proto_in6.port());
+ addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo());
+ proto_in6.addr().copy(
+ reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ addr_in6->sin6_scope_id = htonl(proto_in6.scope_id());
+ *addr_len = sizeof(*addr_in6);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown Sockaddr");
+ }
+ return ::grpc::Status::OK;
+}
+
+class PosixImpl final : public posix_server::Posix::Service {
+ ::grpc::Status Accept(grpc_impl::ServerContext *context,
+ const ::posix_server::AcceptRequest *request,
+ ::posix_server::AcceptResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_fd(accept(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status Bind(grpc_impl::ServerContext *context,
+ const ::posix_server::BindRequest *request,
+ ::posix_server::BindResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+
+ sockaddr_storage addr;
+ socklen_t addr_len;
+ auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(
+ bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Close(grpc_impl::ServerContext *context,
+ const ::posix_server::CloseRequest *request,
+ ::posix_server::CloseResponse *response) override {
+ response->set_ret(close(request->fd()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Connect(grpc_impl::ServerContext *context,
+ const ::posix_server::ConnectRequest *request,
+ ::posix_server::ConnectResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+ socklen_t addr_len;
+ auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(connect(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), addr_len));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Fcntl(grpc_impl::ServerContext *context,
+ const ::posix_server::FcntlRequest *request,
+ ::posix_server::FcntlResponse *response) override {
+ response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status GetSockName(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockNameRequest *request,
+ ::posix_server::GetSockNameResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_ret(getsockname(
+ request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status GetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockOptRequest *request,
+ ::posix_server::GetSockOptResponse *response) override {
+ switch (request->type()) {
+ case ::posix_server::GetSockOptRequest::BYTES: {
+ socklen_t optlen = request->optlen();
+ std::vector<char> buf(optlen);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), buf.data(),
+ &optlen));
+        if (response->ret() >= 0) {
+ response->mutable_optval()->set_bytesval(buf.data(), optlen);
+ }
+ break;
+ }
+ case ::posix_server::GetSockOptRequest::INT: {
+ int intval = 0;
+ socklen_t optlen = sizeof(intval);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), &intval, &optlen));
+ response->mutable_optval()->set_intval(intval);
+ break;
+ }
+ case ::posix_server::GetSockOptRequest::TIME: {
+ timeval tv;
+ socklen_t optlen = sizeof(tv);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, &optlen));
+ response->mutable_optval()->mutable_timeval()->set_seconds(tv.tv_sec);
+ response->mutable_optval()->mutable_timeval()->set_microseconds(
+ tv.tv_usec);
+ break;
+ }
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown SockOpt Type");
+ }
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Listen(grpc_impl::ServerContext *context,
+ const ::posix_server::ListenRequest *request,
+ ::posix_server::ListenResponse *response) override {
+ response->set_ret(listen(request->sockfd(), request->backlog()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Send(::grpc::ServerContext *context,
+ const ::posix_server::SendRequest *request,
+ ::posix_server::SendResponse *response) override {
+ response->set_ret(::send(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SendTo(::grpc::ServerContext *context,
+ const ::posix_server::SendToRequest *request,
+ ::posix_server::SendToResponse *response) override {
+ if (!request->has_dest_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+ socklen_t addr_len;
+ auto err = proto_to_sockaddr(request->dest_addr(), &addr, &addr_len);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(::sendto(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags(),
+ reinterpret_cast<sockaddr *>(&addr), addr_len));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::SetSockOptRequest *request,
+ ::posix_server::SetSockOptResponse *response) override {
+ switch (request->optval().val_case()) {
+ case ::posix_server::SockOptVal::kBytesval:
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(),
+ request->optval().bytesval().c_str(),
+ request->optval().bytesval().size()));
+ break;
+ case ::posix_server::SockOptVal::kIntval: {
+ int opt = request->optval().intval();
+ response->set_ret(::setsockopt(request->sockfd(), request->level(),
+ request->optname(), &opt, sizeof(opt)));
+ break;
+ }
+ case ::posix_server::SockOptVal::kTimeval: {
+ timeval tv = {.tv_sec = static_cast<__time_t>(
+ request->optval().timeval().seconds()),
+ .tv_usec = static_cast<__suseconds_t>(
+ request->optval().timeval().microseconds())};
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, sizeof(tv)));
+ break;
+ }
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown SockOpt Type");
+ }
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Socket(grpc_impl::ServerContext *context,
+ const ::posix_server::SocketRequest *request,
+ ::posix_server::SocketResponse *response) override {
+ response->set_fd(
+ socket(request->domain(), request->type(), request->protocol()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Recv(::grpc::ServerContext *context,
+ const ::posix_server::RecvRequest *request,
+ ::posix_server::RecvResponse *response) override {
+ std::vector<char> buf(request->len());
+ response->set_ret(
+ recv(request->sockfd(), buf.data(), buf.size(), request->flags()));
+ if (response->ret() >= 0) {
+ response->set_buf(buf.data(), response->ret());
+ }
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+};
+
+// Parses command line options, storing the results in ip and port.
+void parse_command_line_options(int argc, char *argv[], std::string *ip,
+ int *port) {
+ static struct option options[] = {{"ip", required_argument, NULL, 1},
+ {"port", required_argument, NULL, 2},
+ {0, 0, 0, 0}};
+
+ // Parse the arguments.
+ int c;
+ while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) {
+ if (c == 1) {
+ *ip = optarg;
+ } else if (c == 2) {
+ *port = std::stoi(std::string(optarg));
+ }
+ }
+}
+
+void run_server(const std::string &ip, int port) {
+ PosixImpl posix_service;
+ grpc::ServerBuilder builder;
+ std::string server_address = ip + ":" + std::to_string(port);
+ // Set the authentication mechanism.
+ std::shared_ptr<grpc::ServerCredentials> creds =
+ grpc::InsecureServerCredentials();
+ builder.AddListeningPort(server_address, creds);
+ builder.RegisterService(&posix_service);
+
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ std::cerr << "Server listening on " << server_address << std::endl;
+ server->Wait();
+ std::cerr << "posix_server is finished." << std::endl;
+}
+
+int main(int argc, char *argv[]) {
+ std::cerr << "posix_server is starting." << std::endl;
+ std::string ip;
+  int port = 0;
+ parse_command_line_options(argc, argv, &ip, &port);
+
+ std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl;
+ run_server(ip, port);
+}
diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD
new file mode 100644
index 000000000..422bb9b0c
--- /dev/null
+++ b/test/packetimpact/netdevs/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "netdevs",
+ srcs = ["netdevs.go"],
+ visibility = ["//test/packetimpact:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go
new file mode 100644
index 000000000..d2c9cfeaf
--- /dev/null
+++ b/test/packetimpact/netdevs/netdevs.go
@@ -0,0 +1,104 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netdevs contains utilities for working with network devices.
+package netdevs
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// A DeviceInfo represents a network device.
+type DeviceInfo struct {
+ MAC net.HardwareAddr
+ IPv4Addr net.IP
+ IPv4Net *net.IPNet
+ IPv6Addr net.IP
+ IPv6Net *net.IPNet
+}
+
+var (
+ deviceLine = regexp.MustCompile(`^\s*\d+: (\w+)`)
+ linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`)
+ inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`)
+	inet6Line  = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-F:/]+)`)
+)
+
+// ParseDevices parses the output from `ip addr show` into a map from device
+// name to information about the device.
+func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) {
+ var currentDevice string
+ var currentInfo DeviceInfo
+ deviceInfos := make(map[string]DeviceInfo)
+ for _, line := range strings.Split(cmdOutput, "\n") {
+ if m := deviceLine.FindStringSubmatch(line); m != nil {
+ if currentDevice != "" {
+ deviceInfos[currentDevice] = currentInfo
+ }
+ currentInfo = DeviceInfo{}
+ currentDevice = m[1]
+ } else if m := linkLine.FindStringSubmatch(line); m != nil {
+ mac, err := net.ParseMAC(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.MAC = mac
+ } else if m := inetLine.FindStringSubmatch(line); m != nil {
+ ipv4Addr, ipv4Net, err := net.ParseCIDR(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.IPv4Addr = ipv4Addr
+ currentInfo.IPv4Net = ipv4Net
+ } else if m := inet6Line.FindStringSubmatch(line); m != nil {
+ ipv6Addr, ipv6Net, err := net.ParseCIDR(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.IPv6Addr = ipv6Addr
+ currentInfo.IPv6Net = ipv6Net
+ }
+ }
+ if currentDevice != "" {
+ deviceInfos[currentDevice] = currentInfo
+ }
+ return deviceInfos, nil
+}
+
+// MACToIP converts the MAC address to an IPv6 link local address as described
+// in RFC 4291 page 20: https://tools.ietf.org/html/rfc4291#page-20
+func MACToIP(mac net.HardwareAddr) net.IP {
+ addr := make([]byte, header.IPv6AddressSize)
+ addr[0] = 0xfe
+ addr[1] = 0x80
+ header.EthernetAdddressToModifiedEUI64IntoBuf(tcpip.LinkAddress(mac), addr[8:])
+ return net.IP(addr)
+}
+
+// FindDeviceByIP finds a DeviceInfo and device name from an IP address in the
+// output of ParseDevices.
+func FindDeviceByIP(ip net.IP, devices map[string]DeviceInfo) (string, DeviceInfo, error) {
+ for dev, info := range devices {
+ if info.IPv4Addr.Equal(ip) {
+ return dev, info, nil
+ }
+ }
+ return "", DeviceInfo{}, fmt.Errorf("can't find %s on any interface", ip)
+}
diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD
new file mode 100644
index 000000000..4a4370f42
--- /dev/null
+++ b/test/packetimpact/proto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "proto_library")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+proto_library(
+ name = "posix_server",
+ srcs = ["posix_server.proto"],
+ has_services = 1,
+)
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
new file mode 100644
index 000000000..ccd20b10d
--- /dev/null
+++ b/test/packetimpact/proto/posix_server.proto
@@ -0,0 +1,230 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package posix_server;
+
+message SockaddrIn {
+ int32 family = 1;
+ uint32 port = 2;
+ bytes addr = 3;
+}
+
+message SockaddrIn6 {
+ uint32 family = 1;
+ uint32 port = 2;
+ uint32 flowinfo = 3;
+ bytes addr = 4;
+ uint32 scope_id = 5;
+}
+
+message Sockaddr {
+ oneof sockaddr {
+ SockaddrIn in = 1;
+ SockaddrIn6 in6 = 2;
+ }
+}
+
+message Timeval {
+ int64 seconds = 1;
+ int64 microseconds = 2;
+}
+
+message SockOptVal {
+ oneof val {
+ bytes bytesval = 1;
+ int32 intval = 2;
+ Timeval timeval = 3;
+ }
+}
+
+// Request and Response pairs for each Posix service RPC call, sorted.
+
+message AcceptRequest {
+ int32 sockfd = 1;
+}
+
+message AcceptResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
+message BindRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message BindResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message CloseRequest {
+ int32 fd = 1;
+}
+
+message CloseResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message ConnectRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message ConnectResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message FcntlRequest {
+ int32 fd = 1;
+ int32 cmd = 2;
+ int32 arg = 3;
+}
+
+message FcntlResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message GetSockNameRequest {
+ int32 sockfd = 1;
+}
+
+message GetSockNameResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
+message GetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ int32 optlen = 4;
+ enum SockOptType {
+ UNSPECIFIED = 0;
+ BYTES = 1;
+ INT = 2;
+ TIME = 3;
+ }
+ SockOptType type = 5;
+}
+
+message GetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ SockOptVal optval = 3;
+}
+
+message ListenRequest {
+ int32 sockfd = 1;
+ int32 backlog = 2;
+}
+
+message ListenResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SendRequest {
+ int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
+}
+
+message SendResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SendToRequest {
+ int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
+ Sockaddr dest_addr = 4;
+}
+
+message SendToResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ SockOptVal optval = 4;
+}
+
+message SetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message RecvRequest {
+ int32 sockfd = 1;
+ int32 len = 2;
+ int32 flags = 3;
+}
+
+message RecvResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ bytes buf = 3;
+}
+
+service Posix {
+ // Call accept() on the DUT.
+ rpc Accept(AcceptRequest) returns (AcceptResponse);
+ // Call bind() on the DUT.
+ rpc Bind(BindRequest) returns (BindResponse);
+ // Call close() on the DUT.
+ rpc Close(CloseRequest) returns (CloseResponse);
+ // Call connect() on the DUT.
+ rpc Connect(ConnectRequest) returns (ConnectResponse);
+ // Call fcntl() on the DUT.
+ rpc Fcntl(FcntlRequest) returns (FcntlResponse);
+ // Call getsockname() on the DUT.
+ rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
+ // Call getsockopt() on the DUT.
+ rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse);
+ // Call listen() on the DUT.
+ rpc Listen(ListenRequest) returns (ListenResponse);
+ // Call send() on the DUT.
+ rpc Send(SendRequest) returns (SendResponse);
+ // Call sendto() on the DUT.
+ rpc SendTo(SendToRequest) returns (SendToResponse);
+ // Call setsockopt() on the DUT.
+ rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse);
+ // Call socket() on the DUT.
+ rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call recv() on the DUT.
+ rpc Recv(RecvRequest) returns (RecvResponse);
+}
diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD
new file mode 100644
index 000000000..bad4f0183
--- /dev/null
+++ b/test/packetimpact/runner/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_test(
+ name = "packetimpact_test",
+ srcs = ["packetimpact_test.go"],
+ tags = [
+ # Not intended to be run directly.
+ "local",
+ "manual",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/packetimpact/netdevs",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
new file mode 100644
index 000000000..77cdfea12
--- /dev/null
+++ b/test/packetimpact/runner/defs.bzl
@@ -0,0 +1,136 @@
+"""Defines rules for packetimpact test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def _packetimpact_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ bench = ctx.actions.declare_file("%s-bench" % ctx.label.name)
+ bench_content = "\n".join([
+ "#!/bin/bash",
+        # This test will run partly in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -or -type d -exec chmod a+rx {} \\;",
+ "%s %s --testbench_binary %s $@\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files.testbench_binary[0].short_path,
+ ),
+ ])
+ ctx.actions.write(bench, bench_content, is_executable = True)
+
+ transitive_files = []
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files.append(ctx.attr._test_runner.data_runfiles.files)
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary,
+ transitive_files = depset(transitive = transitive_files),
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = bench, runfiles = runfiles)]
+
+_packetimpact_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "target",
+ default = ":packetimpact_test",
+ ),
+ "_posix_server_binary": attr.label(
+ cfg = "target",
+ default = "//test/packetimpact/dut:posix_server",
+ ),
+ "testbench_binary": attr.label(
+ cfg = "target",
+ mandatory = True,
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ },
+ test = True,
+ implementation = _packetimpact_test_impl,
+)
+
+PACKETIMPACT_TAGS = ["local", "manual"]
+
+def packetimpact_linux_test(
+ name,
+ testbench_binary,
+ expect_failure = False,
+ **kwargs):
+ """Add a packetimpact test on linux.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ expect_failure: the test must fail
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ expect_failure_flag = ["--expect_failure"] if expect_failure else []
+ _packetimpact_test(
+ name = name + "_linux_test",
+ testbench_binary = testbench_binary,
+ flags = ["--dut_platform", "linux"] + expect_failure_flag,
+ tags = PACKETIMPACT_TAGS + ["packetimpact"],
+ **kwargs
+ )
+
+def packetimpact_netstack_test(
+ name,
+ testbench_binary,
+ expect_failure = False,
+ **kwargs):
+ """Add a packetimpact test on netstack.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ expect_failure: the test must fail
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ expect_failure_flag = []
+ if expect_failure:
+ expect_failure_flag = ["--expect_failure"]
+ _packetimpact_test(
+ name = name + "_netstack_test",
+ testbench_binary = testbench_binary,
+ # This is the default runtime unless
+ # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
+ flags = ["--dut_platform", "netstack", "--runtime=runsc-d"] + expect_failure_flag,
+ tags = PACKETIMPACT_TAGS + ["packetimpact"],
+ **kwargs
+ )
+
+def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure = False, expect_netstack_failure = False, **kwargs):
+ """Add packetimpact tests written in go.
+
+ Args:
+ name: name of the test
+ size: size of the test
+ pure: make a static go binary
+ expect_linux_failure: the test must fail for Linux
+ expect_netstack_failure: the test must fail for Netstack
+ **kwargs: all the other args, forwarded to go_test
+ """
+ testbench_binary = name + "_test"
+ go_test(
+ name = testbench_binary,
+ size = size,
+ pure = pure,
+ tags = PACKETIMPACT_TAGS,
+ **kwargs
+ )
+ packetimpact_linux_test(
+ name = name,
+ expect_failure = expect_linux_failure,
+ testbench_binary = testbench_binary,
+ )
+ packetimpact_netstack_test(
+ name = name,
+ expect_failure = expect_netstack_failure,
+ testbench_binary = testbench_binary,
+ )
diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go
new file mode 100644
index 000000000..9290d5112
--- /dev/null
+++ b/test/packetimpact/runner/packetimpact_test.go
@@ -0,0 +1,370 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// The runner starts docker containers and networking for a packetimpact test.
+package packetimpact_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "os/exec"
+ "path"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
+)
+
+// stringList implements flag.Value.
+type stringList []string
+
+// String implements flag.Value.String.
+func (l *stringList) String() string {
+ return strings.Join(*l, ",")
+}
+
+// Set implements flag.Value.Set.
+func (l *stringList) Set(value string) error {
+ *l = append(*l, value)
+ return nil
+}
+
+var (
+ dutPlatform = flag.String("dut_platform", "", "either \"linux\" or \"netstack\"")
+ testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary")
+ tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump")
+ extraTestArgs = stringList{}
+ expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run")
+
+ dutAddr = net.IPv4(0, 0, 0, 10)
+ testbenchAddr = net.IPv4(0, 0, 0, 20)
+)
+
+const ctrlPort = "40000"
+
+// logger implements testutil.Logger.
+//
+// Labels logs based on their source and formats multi-line logs.
+type logger string
+
+// Name implements testutil.Logger.Name.
+func (l logger) Name() string {
+ return string(l)
+}
+
+// Logf implements testutil.Logger.Logf.
+func (l logger) Logf(format string, args ...interface{}) {
+ lines := strings.Split(fmt.Sprintf(format, args...), "\n")
+ log.Printf("%s: %s", l, lines[0])
+ for _, line := range lines[1:] {
+ log.Printf("%*s %s", len(l), "", line)
+ }
+}
+
+func TestOne(t *testing.T) {
+ flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
+ flag.Parse()
+ if *dutPlatform != "linux" && *dutPlatform != "netstack" {
+ t.Fatal("--dut_platform should be either linux or netstack")
+ }
+ if *testbenchBinary == "" {
+ t.Fatal("--testbench_binary is missing")
+ }
+ if *dutPlatform == "netstack" {
+ if _, err := dockerutil.RuntimePath(); err != nil {
+ t.Fatal("--runtime is missing or invalid with --dut_platform=netstack:", err)
+ }
+ }
+ dockerutil.EnsureSupportedDockerVersion()
+ ctx := context.Background()
+
+ // Create the networks needed for the test. One control network is needed for
+ // the gRPC control packets and one test network on which to transmit the test
+ // packets.
+ ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet"))
+ testNet := dockerutil.NewNetwork(ctx, logger("testNet"))
+ for _, dn := range []*dockerutil.Network{ctrlNet, testNet} {
+ for {
+ if err := createDockerNetwork(ctx, dn); err != nil {
+ t.Log("creating docker network:", err)
+ const wait = 100 * time.Millisecond
+ t.Logf("sleeping %s and will try creating docker network again", wait)
+ // This can fail if another docker network claimed the same IP so we'll
+ // just try again.
+ time.Sleep(wait)
+ continue
+ }
+ break
+ }
+ defer func(dn *dockerutil.Network) {
+ if err := dn.Cleanup(ctx); err != nil {
+				t.Errorf("unable to cleanup network %s: %s", dn.Name, err)
+ }
+ }(dn)
+ // Sanity check.
+ inspect, err := dn.Inspect(ctx)
+ if err != nil {
+ t.Fatalf("failed to inspect network %s: %v", dn.Name, err)
+ } else if inspect.Name != dn.Name {
+ t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
+ }
+
+ }
+
+ tmpDir, err := ioutil.TempDir("", "container-output")
+ if err != nil {
+ t.Fatal("creating temp dir:", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ const testOutputDir = "/tmp/testoutput"
+
+ // Create the Docker container for the DUT.
+ dut := dockerutil.MakeContainer(ctx, logger("dut"))
+ if *dutPlatform == "linux" {
+ dut.Runtime = ""
+ }
+
+ runOpts := dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ Mounts: []mount.Mount{mount.Mount{
+ Type: mount.TypeBind,
+ Source: tmpDir,
+ Target: testOutputDir,
+ ReadOnly: false,
+ }},
+ }
+
+ const containerPosixServerBinary = "/packetimpact/posix_server"
+ dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server")
+
+ conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort)
+ hostconf.AutoRemove = true
+ hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
+
+ if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("unable to create container %s: %v", dut.Name, err)
+ }
+
+ defer dut.CleanUp(ctx)
+
+ // Add ctrlNet as eth1 and testNet as eth2.
+ const testNetDev = "eth2"
+ if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := dut.Start(ctx); err != nil {
+ t.Fatalf("unable to start container %s: %s", dut.Name, err)
+ }
+
+ if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil {
+ t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err)
+ }
+
+ dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ remoteMAC := dutDeviceInfo.MAC
+ remoteIPv6 := dutDeviceInfo.IPv6Addr
+ // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
+ // needed.
+ if remoteIPv6 == nil {
+ if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
+ t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err)
+ }
+ // Now try again, to make sure that it worked.
+ _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+ remoteIPv6 = dutDeviceInfo.IPv6Addr
+ if remoteIPv6 == nil {
+ t.Fatal("unable to set IPv6 address on container", dut.Name)
+ }
+ }
+
+ // Create the Docker container for the testbench.
+ testbench := dockerutil.MakeContainer(ctx, logger("testbench"))
+ testbench.Runtime = "" // The testbench always runs on Linux.
+
+ tbb := path.Base(*testbenchBinary)
+ containerTestbenchBinary := "/packetimpact/" + tbb
+ runOpts = dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ Mounts: []mount.Mount{mount.Mount{
+ Type: mount.TypeBind,
+ Source: tmpDir,
+ Target: testOutputDir,
+ ReadOnly: false,
+ }},
+ }
+ testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb)
+
+ // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs := []string{
+ "tcpdump",
+ "-S", "-vvv", "-U", "-n",
+ "-i", testNetDev,
+ "-w", testOutputDir + "/dump.pcap",
+ }
+ snifferRegex := "tcpdump: listening.*\n"
+ if *tshark {
+ // Run tshark in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs = []string{
+ "tshark", "-V", "-l", "-n", "-i", testNetDev,
+ "-o", "tcp.check_checksum:TRUE",
+ "-o", "udp.check_checksum:TRUE",
+ }
+ snifferRegex = "Capturing on.*\n"
+ }
+
+ defer func() {
+ if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
+ t.Error("unable to copy container output files:", err)
+ }
+ }()
+
+ conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...)
+ hostconf.AutoRemove = true
+ hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
+
+ if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("unable to create container %s: %s", testbench.Name, err)
+ }
+ defer testbench.CleanUp(ctx)
+
+ // Add ctrlNet as eth1 and testNet as eth2.
+ if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := testbench.Start(ctx); err != nil {
+ t.Fatalf("unable to start container %s: %s", testbench.Name, err)
+ }
+
+ // Kill so that it will flush output.
+ defer func() {
+ time.Sleep(1 * time.Second)
+ testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0])
+ }()
+
+ if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil {
+		t.Fatalf("sniffer on %s never listened: %s", testbench.Name, err)
+ }
+
+	// Because the Linux kernel receives the SYN-ACK but never sent the SYN, it
+	// will issue a RST. To prevent this, iptables is used to filter out all
+	// incoming packets. The raw socket that packetimpact tests use will still see
+ // everything.
+ if _, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, "iptables", "-A", "INPUT", "-i", testNetDev, "-j", "DROP"); err != nil {
+ t.Fatalf("unable to Exec iptables on container %s: %s", testbench.Name, err)
+ }
+
+ // FIXME(b/156449515): Some piece of the system has a race. The old
+ // bash script version had a sleep, so we have one too. The race should
+ // be fixed and this sleep removed.
+ time.Sleep(time.Second)
+
+ // Start a packetimpact test on the test bench. The packetimpact test sends
+ // and receives packets and also sends POSIX socket commands to the
+ // posix_server to be executed on the DUT.
+ testArgs := []string{containerTestbenchBinary}
+ testArgs = append(testArgs, extraTestArgs...)
+ testArgs = append(testArgs,
+ "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(),
+ "--posix_server_port", ctrlPort,
+ "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(),
+ "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(),
+ "--remote_ipv6", remoteIPv6.String(),
+ "--remote_mac", remoteMAC.String(),
+ "--device", testNetDev,
+ "--dut_type", *dutPlatform,
+ )
+ _, err = testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
+ if !*expectFailure && err != nil {
+ t.Fatal("test failed:", err)
+ }
+ if *expectFailure && err == nil {
+ t.Fatal("test failure expected but the test succeeded, enable the test and mark the corresponding bug as fixed")
+ }
+}
+
+func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error {
+ for _, dn := range networks {
+ ip := addressInSubnet(addr, *dn.Subnet)
+ // Connect to the network with the specified IP address.
+ if err := dn.Connect(ctx, d, ip.String(), ""); err != nil {
+ return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err)
+ }
+ }
+ return nil
+}
+
+// addressInSubnet combines the subnet provided with the address and returns a
+// new address. The return address bits come from the subnet where the mask is 1
+// and from the ip address where the mask is 0.
+func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
+ var octets []byte
+ for i := 0; i < 4; i++ {
+ octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
+ }
+ return net.IP(octets)
+}
+
+// createDockerNetwork makes a randomly-named network that will start with the
+// namePrefix. The network will be a random /24 subnet.
+func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
+ randSource := rand.NewSource(time.Now().UnixNano())
+ r1 := rand.New(randSource)
+	// Class C, 192.0.0.0 to 223.255.255.255, traditionally has mask 24.
+ ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0)
+ n.Subnet = &net.IPNet{
+ IP: ip,
+ Mask: ip.DefaultMask(),
+ }
+ return n.Create(ctx)
+}
+
+// deviceByIP finds a deviceInfo and device name from an IP address.
+func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+ out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show")
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err)
+ }
+ devs, err := netdevs.ParseDevices(out)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err)
+ }
+ testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err)
+ }
+ return testDevice, deviceInfo, nil
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
new file mode 100644
index 000000000..d19ec07d4
--- /dev/null
+++ b/test/packetimpact/testbench/BUILD
@@ -0,0 +1,46 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testbench",
+ srcs = [
+ "connections.go",
+ "dut.go",
+ "dut_client.go",
+ "layers.go",
+ "rawsockets.go",
+ "testbench.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//pkg/usermem",
+ "//test/packetimpact/netdevs",
+ "//test/packetimpact/proto:posix_server_go_proto",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ "@org_golang_google_grpc//keepalive:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ "@org_uber_go_multierr//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "testbench_test",
+ size = "small",
+ srcs = ["layers_test.go"],
+ library = ":testbench",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
new file mode 100644
index 000000000..8b4a4d905
--- /dev/null
+++ b/test/packetimpact/testbench/connections.go
@@ -0,0 +1,950 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testbench has utilities to send and receive packets and also command
+// the DUT to run POSIX functions.
+package testbench
+
+import (
+ "fmt"
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/mohae/deepcopy"
+ "go.uber.org/multierr"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
+ switch sa := sa.(type) {
+ case *unix.SockaddrInet4:
+ return uint16(sa.Port), nil
+ case *unix.SockaddrInet6:
+ return uint16(sa.Port), nil
+ }
+ return 0, fmt.Errorf("sockaddr type %T does not contain port", sa)
+}
+
+// pickPort makes a new socket and returns the socket FD and port. The domain
+// should be AF_INET or AF_INET6. The caller must close the FD when done with
+// the port if there is no error.
+func pickPort(domain, typ int) (int, uint16, error) {
+ fd, err := unix.Socket(domain, typ, 0)
+ if err != nil {
+ return -1, 0, err
+ }
+ defer func() {
+ if err != nil {
+ err = multierr.Append(err, unix.Close(fd))
+ }
+ }()
+ var sa unix.Sockaddr
+ switch domain {
+ case unix.AF_INET:
+ var sa4 unix.SockaddrInet4
+ copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4())
+ sa = &sa4
+ case unix.AF_INET6:
+ var sa6 unix.SockaddrInet6
+ copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
+ sa = &sa6
+ default:
+ return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ }
+ if err = unix.Bind(fd, sa); err != nil {
+ return -1, 0, err
+ }
+ sa, err = unix.Getsockname(fd)
+ if err != nil {
+ return -1, 0, err
+ }
+ port, err := portFromSockaddr(sa)
+ if err != nil {
+ return -1, 0, err
+ }
+ return fd, port, nil
+}
+
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+	// outgoing returns an outgoing layer to be sent in a frame. It should not
+	// update layerState; that is done in layerState.sent.
+ outgoing() Layer
+
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet. It
+	// should not update layerState; that is done in layerState.received. The
+ // caller takes ownership of the returned Layer.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+	// received updates the layerState based on a Layer that is received. The
+	// input is a Layer with all prev and next pointers populated so that the
+	// entire frame as it was received is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+
+// etherState maintains state about an Ethernet connection.
+type etherState struct {
+ out, in Ether
+}
+
+var _ layerState = (*etherState)(nil)
+
+// newEtherState creates a new etherState.
+func newEtherState(out, in Ether) (*etherState, error) {
+ lMAC, err := tcpip.ParseMACAddress(LocalMAC)
+ if err != nil {
+ return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err)
+ }
+
+ rMAC, err := tcpip.ParseMACAddress(RemoteMAC)
+ if err != nil {
+ return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err)
+ }
+ s := etherState{
+ out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *etherState) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+// incoming implements layerState.incoming.
+func (s *etherState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*etherState) sent(Layer) error {
+ return nil
+}
+
+func (*etherState) received(Layer) error {
+ return nil
+}
+
+func (*etherState) close() error {
+ return nil
+}
+
+// ipv4State maintains state about an IPv4 connection.
+type ipv4State struct {
+ out, in IPv4
+}
+
+var _ layerState = (*ipv4State)(nil)
+
+// newIPv4State creates a new ipv4State.
+func newIPv4State(out, in IPv4) (*ipv4State, error) {
+ lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4())
+ rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4())
+ s := ipv4State{
+ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *ipv4State) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+// incoming implements layerState.incoming.
+func (s *ipv4State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*ipv4State) sent(Layer) error {
+ return nil
+}
+
+func (*ipv4State) received(Layer) error {
+ return nil
+}
+
+func (*ipv4State) close() error {
+ return nil
+}
+
+// ipv6State maintains state about an IPv6 connection.
+type ipv6State struct {
+ out, in IPv6
+}
+
+var _ layerState = (*ipv6State)(nil)
+
+// newIPv6State creates a new ipv6State.
+func newIPv6State(out, in IPv6) (*ipv6State, error) {
+ lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16())
+ rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16())
+ s := ipv6State{
+ out: IPv6{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv6{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+// outgoing returns an outgoing layer to be sent in a frame.
+func (s *ipv6State) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+func (s *ipv6State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (s *ipv6State) sent(Layer) error {
+ // Nothing to do.
+ return nil
+}
+
+func (s *ipv6State) received(Layer) error {
+ // Nothing to do.
+ return nil
+}
+
+// close cleans up any resources held.
+func (s *ipv6State) close() error {
+ return nil
+}
+
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+
+var _ layerState = (*tcpState)(nil)
+
+// SeqNumValue is a helper routine that allocates a new seqnum.Value value to
+// store v and returns a pointer to it.
+func SeqNumValue(v seqnum.Value) *seqnum.Value {
+ return &v
+}
+
+// newTCPState creates a new TCPState.
+func newTCPState(domain int, out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
+ if err != nil {
+ return nil, err
+ }
+ s := tcpState{
+ out: TCP{SrcPort: &localPort},
+ in: TCP{DstPort: &localPort},
+ localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())),
+ portPickerFD: portPickerFD,
+ finSent: false,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *tcpState) outgoing() Layer {
+ newOutgoing := deepcopy.Copy(s.out).(TCP)
+ if s.localSeqNum != nil {
+ newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
+ }
+ if s.remoteSeqNum != nil {
+ newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ return &newOutgoing
+}
+
+// incoming implements layerState.incoming.
+func (s *tcpState) incoming(received Layer) Layer {
+ tcpReceived, ok := received.(*TCP)
+ if !ok {
+ return nil
+ }
+ newIn := deepcopy.Copy(s.in).(TCP)
+ if s.remoteSeqNum != nil {
+ newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ // The caller didn't specify an AckNum so we'll expect the calculated one,
+ // but only if the ACK flag is set because the AckNum is not valid in a
+ // header if ACK is not set.
+ newIn.AckNum = Uint32(uint32(*s.localSeqNum))
+ }
+ return &newIn
+}
+
+func (s *tcpState) sent(sent Layer) error {
+ tcp, ok := sent.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", sent)
+ }
+ if !s.finSent {
+ // update localSeqNum by the payload only when FIN is not yet sent by us
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ }
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.localSeqNum.UpdateForward(1)
+ }
+ if *tcp.Flags&(header.TCPFlagFin) != 0 {
+ s.finSent = true
+ }
+ return nil
+}
+
+func (s *tcpState) received(l Layer) error {
+ tcp, ok := l.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", l)
+ }
+ s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
+ if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.remoteSeqNum.UpdateForward(1)
+ }
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *tcpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// udpState maintains state about a UDP connection.
+type udpState struct {
+ out, in UDP
+ portPickerFD int
+}
+
+var _ layerState = (*udpState)(nil)
+
+// newUDPState creates a new udpState.
+func newUDPState(domain int, out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
+ if err != nil {
+ return nil, err
+ }
+ s := udpState{
+ out: UDP{SrcPort: &localPort},
+ in: UDP{DstPort: &localPort},
+ portPickerFD: portPickerFD,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *udpState) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+// incoming implements layerState.incoming.
+func (s *udpState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*udpState) sent(l Layer) error {
+ return nil
+}
+
+func (*udpState) received(l Layer) error {
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *udpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+
+// incoming returns the default incoming frame against which to match. If
+// received is longer than layerStates then that may still count as a match.
+// The reverse is never a match and nil is returned.
+func (conn *Connection) incoming(received Layers) Layers {
+ if len(received) < len(conn.layerStates) {
+ return nil
+ }
+ in := Layers{}
+ for i, s := range conn.layerStates {
+ toMatch := s.incoming(received[i])
+ if toMatch == nil {
+ return nil
+ }
+ in = append(in, toMatch)
+ }
+ return in
+}
+
+func (conn *Connection) match(override, received Layers) bool {
+ toMatch := conn.incoming(received)
+ if toMatch == nil {
+ return false // Not enough layers in gotLayers for matching.
+ }
+ if err := toMatch.merge(override); err != nil {
+ return false // Failing to merge is not matching.
+ }
+ return toMatch.match(received)
+}
+
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {
+ errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
+ for _, s := range conn.layerStates {
+ if err := s.close(); err != nil {
+ errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
+ }
+ }
+ if errs != nil {
+ conn.t.Fatalf("unable to close %+v: %s", conn, errs)
+ }
+}
+
+// CreateFrame builds a frame for the connection with defaults overridden
+// from the innermost layer out, and additionalLayers added after it.
+//
+// Note that overrideLayers can have a length that is less than the number
+// of layers in this connection, and in such cases the innermost layers are
+// overridden first. As an example, valid values of overrideLayers for a TCP-
+// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
+// [Ethernet, IPv4, TCP].
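+//
+// A hedged sketch of a typical call (Payload is the payload Layer defined in
+// this package; data is a hypothetical []byte):
+//
+//	frame := conn.CreateFrame(Layers{&TCP{Flags: Uint8(header.TCPFlagAck)}}, &Payload{Bytes: data})
+//	conn.SendFrame(frame)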
+func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers {
+ var layersToSend Layers
+ for i, s := range conn.layerStates {
+ layer := s.outgoing()
+ // overrideLayers and conn.layerStates have their tails aligned, so
+ // to find the index we move backwards by the distance i is to the
+ // end.
+ if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
+ if err := layer.merge(overrideLayers[j]); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
+ }
+ }
+ layersToSend = append(layersToSend, layer)
+ }
+ layersToSend = append(layersToSend, additionalLayers...)
+ return layersToSend
+}
+
+// SendFrameStateless sends a frame without updating any of the layer states.
+//
+// This method is useful for sending out-of-band control messages such as
+// ICMP packets, where it would not make sense to update the transport layer's
+// state using the ICMP header.
+func (conn *Connection) SendFrameStateless(frame Layers) {
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+
+ // frame might have nil values where the caller wanted to use default values.
+ // sentFrame will have no nil values in it because it comes from parsing the
+ // bytes that were actually sent.
+ sentFrame := parse(parseEther, outBytes)
+ // Update the state of each layer based on what was sent.
+ for i, s := range conn.layerStates {
+ if err := s.sent(sentFrame[i]); err != nil {
+ conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ }
+ }
+}
+
+// send sends a packet, possibly with layers of this connection overridden and
+// additional layers added.
+//
+// Types defined with Connection as the underlying type should expose
+// type-safe versions of this method.
+func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...))
+}
+
+// recvFrame gets the next successfully parsed frame (of type Layers) within the
+// timeout provided. If no parsable frame arrives before the timeout, it returns
+// nil.
+func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+ if timeout <= 0 {
+ return nil
+ }
+ b := conn.sniffer.Recv(timeout)
+ if b == nil {
+ return nil
+ }
+ return parse(parseEther, b)
+}
+
+// layersError stores the Layers that we got and the Layers that we wanted to
+// match.
+type layersError struct {
+ got, want Layers
+}
+
+func (e *layersError) Error() string {
+ return e.got.diff(e.want)
+}
+
+// Expect expects a frame with the final layerStates layer matching the
+// provided Layer within the timeout specified. If it doesn't arrive in time,
+// an error is returned.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+ // Make a frame that will ignore all but the final layer.
+ layers := make([]Layer, len(conn.layerStates))
+ layers[len(layers)-1] = layer
+
+ gotFrame, err := conn.ExpectFrame(layers, timeout)
+ if err != nil {
+ return nil, err
+ }
+ if len(conn.layerStates)-1 < len(gotFrame) {
+ return gotFrame[len(conn.layerStates)-1], nil
+ }
+ conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ panic("unreachable")
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If one arrives in time, the Layers are returned without
+// an error. If not, it returns nil and a non-nil error.
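+//
+// For example, a sketch that waits for any TCP segment on a TCP/IPv4
+// connection, leaving the Ethernet and IPv4 layers as wildcards:
+//
+//	gotLayers, err := conn.ExpectFrame(Layers{nil, nil, &TCP{}}, time.Second)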
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
+ deadline := time.Now().Add(timeout)
+ var errs error
+ for {
+ var gotLayers Layers
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotLayers = conn.recvFrame(timeout)
+ }
+ if gotLayers == nil {
+ if errs == nil {
+ return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout)
+ }
+ return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs)
+ }
+ if conn.match(layers, gotLayers) {
+ for i, s := range conn.layerStates {
+ if err := s.received(gotLayers[i]); err != nil {
+ conn.t.Fatal(err)
+ }
+ }
+ return gotLayers, nil
+ }
+ errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)})
+ }
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {
+ conn.sniffer.Drain()
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+// Connect performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
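+//
+// A minimal usage sketch (remotePort is a hypothetical uint16):
+//
+//	conn := NewTCPIPv4(t, TCP{DstPort: &remotePort}, TCP{SrcPort: &remotePort})
+//	defer conn.Close()
+//	conn.Connect()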
+func (conn *TCPIPv4) Connect() {
+ conn.t.Helper()
+
+ // Send the SYN.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
+// The input Connection should have a final TCP Layer.
+func (conn *TCPIPv4) ConnectWithOptions(options []byte) {
+ conn.t.Helper()
+
+ // Send the SYN.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ExpectData is a convenience method that expects a Layer and the Layer after
+// it. If the frame doesn't arrive in time, an error is returned.
+func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// ExpectNextData attempts to receive the next incoming segment for the
+// connection and expects that to match the given layers.
+//
+// It differs from ExpectData() in that here we are only interested in the next
+// received segment, while ExpectData() can receive multiple segments for the
+// connection until there is a match with given layers or a timeout.
+func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ // Receive the first incoming TCP segment for this connection.
+ got, err := conn.ExpectData(&TCP{}, nil, timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length()))
+ }
+ if !(*Connection)(conn).match(expected, got) {
+ return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
+ }
+ return got, nil
+}
+
+// Send sends a packet with reasonable defaults, potentially overriding the TCP
+// layer in the connection with the provided layer and adding additionalLayers.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&tcp}, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Expect expects a frame with the TCP layer matching the provided TCP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
+ layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotTCP, ok := layer.(*TCP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be TCP", layer)
+ }
+ return gotTCP, err
+}
+
+func (conn *TCPIPv4) tcpState() *tcpState {
+ state, ok := conn.layerStates[2].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *TCPIPv4) ipv4State() *ipv4State {
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
+}
+
+// RemoteSeqNum returns the next expected sequence number from the DUT.
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.tcpState().remoteSeqNum
+}
+
+// LocalSeqNum returns the next sequence number to send from the testbench.
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.tcpState().localSeqNum
+}
+
+// SynAck returns the SynAck that was part of the handshake.
+func (conn *TCPIPv4) SynAck() *TCP {
+ return conn.tcpState().synAck
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 {
+ sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+ return sa
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+
+// IPv6Conn maintains the state for all the layers in an IPv6 connection.
+type IPv6Conn Connection
+
+// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults.
+func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make EtherState: %s", err)
+ }
+ ipv6State, err := newIPv6State(outgoingIPv6, incomingIPv6)
+ if err != nil {
+ t.Fatalf("can't make IPv6State: %s", err)
+ }
+
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return IPv6Conn{
+ layerStates: []layerState{etherState, ipv6State},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+// Send sends a frame with ipv6 overriding the IPv6 layer defaults and
+// additionalLayers added after it.
+func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...)
+}
+
+// Close to clean up any resources held.
+func (conn *IPv6Conn) Close() {
+ (*Connection)(conn).Close()
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) {
+ return (*Connection)(conn).ExpectFrame(frame, timeout)
+}
+
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+
+// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
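+//
+// A minimal usage sketch (remotePort is a hypothetical uint16):
+//
+//	conn := NewUDPIPv4(t, UDP{DstPort: &remotePort}, UDP{SrcPort: &remotePort})
+//	defer conn.Close()
+//	conn.Send(UDP{})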
+func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return UDPIPv4{
+ layerStates: []layerState{etherState, ipv4State, udpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+func (conn *UDPIPv4) udpState() *udpState {
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *UDPIPv4) ipv4State() *ipv4State {
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 {
+ sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+ return sa
+}
+
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionalLayers.
+func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&udp}, additionalLayers...)
+}
+
+// SendIP sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv4 headers and adding additionalLayers.
+func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...)
+}
+
+// Expect expects a frame with the UDP layer matching the provided UDP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
+ conn.t.Helper()
+ layer, err := (*Connection)(conn).Expect(&udp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotUDP, ok := layer.(*UDP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be UDP", layer)
+ }
+ return gotUDP, err
+}
+
+// ExpectData is a convenience method that expects a Layer and the Layer after
+// it. If the frame doesn't arrive in time, an error is returned.
+func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ conn.t.Helper()
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = &udp
+ if payload.length() != 0 {
+ expected = append(expected, &payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
new file mode 100644
index 000000000..2a2afecb5
--- /dev/null
+++ b/test/packetimpact/testbench/dut.go
@@ -0,0 +1,658 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "context"
+ "flag"
+ "net"
+ "strconv"
+ "syscall"
+ "testing"
+
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+
+ "golang.org/x/sys/unix"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/keepalive"
+)
+
+// DUT communicates with the DUT to force it to make POSIX calls.
+type DUT struct {
+ t *testing.T
+ conn *grpc.ClientConn
+ posixServer POSIXClient
+}
+
+// NewDUT creates a new connection with the DUT over gRPC.
+func NewDUT(t *testing.T) DUT {
+ flag.Parse()
+ if err := genPseudoFlags(); err != nil {
+ t.Fatal("generating psuedo flags:", err)
+ }
+
+ posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort)
+ conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
+ if err != nil {
+ t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err)
+ }
+ posixServer := NewPOSIXClient(conn)
+ return DUT{
+ t: t,
+ conn: conn,
+ posixServer: posixServer,
+ }
+}
+
+// TearDown closes the underlying connection.
+func (dut *DUT) TearDown() {
+ dut.conn.Close()
+}
+
+func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
+ dut.t.Helper()
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In{
+ In: &pb.SockaddrIn{
+ Family: unix.AF_INET,
+ Port: uint32(s.Port),
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ case *unix.SockaddrInet6:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In6{
+ In6: &pb.SockaddrIn6{
+ Family: unix.AF_INET6,
+ Port: uint32(s.Port),
+ Flowinfo: 0,
+ ScopeId: s.ZoneId,
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ }
+ dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ return nil
+}
+
+func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
+ dut.t.Helper()
+ switch s := sa.Sockaddr.(type) {
+ case *pb.Sockaddr_In:
+ ret := unix.SockaddrInet4{
+ Port: int(s.In.GetPort()),
+ }
+ copy(ret.Addr[:], s.In.GetAddr())
+ return &ret
+ case *pb.Sockaddr_In6:
+ ret := unix.SockaddrInet6{
+ Port: int(s.In6.GetPort()),
+ ZoneId: s.In6.GetScopeId(),
+ }
+  copy(ret.Addr[:], s.In6.GetAddr())
+  return &ret
+ }
+ dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ return nil
+}
+
+// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
+// proto, and bound to the IP address addr. Returns the new file descriptor and
+// the port that was selected on the DUT.
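+//
+// For example (a sketch; RemoteIPv4 is the testbench address flag used by
+// CreateListener below):
+//
+//	fd, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(RemoteIPv4))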
+func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
+ dut.t.Helper()
+ var fd int32
+ if addr.To4() != nil {
+ fd = dut.Socket(unix.AF_INET, typ, proto)
+ sa := unix.SockaddrInet4{}
+ copy(sa.Addr[:], addr.To4())
+ dut.Bind(fd, &sa)
+ } else if addr.To16() != nil {
+ fd = dut.Socket(unix.AF_INET6, typ, proto)
+ sa := unix.SockaddrInet6{}
+ copy(sa.Addr[:], addr.To16())
+ dut.Bind(fd, &sa)
+ } else {
+ dut.t.Fatalf("unknown ip addr type for remoteIP")
+ }
+ sa := dut.GetSockName(fd)
+ var port int
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ port = s.Port
+ case *unix.SockaddrInet6:
+ port = s.Port
+ default:
+ dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ }
+ return fd, uint16(port)
+}
+
+// CreateListener makes a new listening socket on the DUT. If any step fails,
+// the test ends.
+func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+ fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4))
+ dut.Listen(fd, backlog)
+ return fd, remotePort
+}
+
+// All the functions that make gRPC calls to the POSIX service are below, sorted
+// alphabetically.
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// AcceptWithErrno.
+func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ if fd < 0 {
+ dut.t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is
+// needed, use BindWithErrno.
+func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.BindWithErrno(ctx, fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
+// BindWithErrno calls bind on the DUT.
+func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.BindRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(sa),
+ }
+ resp, err := dut.posixServer.Bind(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// CloseWithErrno.
+func (dut *DUT) Close(fd int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.CloseWithErrno(ctx, fd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to close: %s", err)
+ }
+}
+
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.CloseRequest{
+ Fd: fd,
+ }
+ resp, err := dut.posixServer.Close(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Close: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Connect calls connect on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use ConnectWithErrno.
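+//
+// A sketch of the WithErrno variant with a caller-owned deadline (fd and sa
+// are hypothetical):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+//	defer cancel()
+//	ret, err := dut.ConnectWithErrno(ctx, fd, sa)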
+func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.ConnectWithErrno(ctx, fd, sa)
+ // Ignore 'operation in progress' error that can be returned when the socket
+ // is non-blocking.
+ if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 {
+ dut.t.Fatalf("failed to connect socket: %s", err)
+ }
+}
+
+// ConnectWithErrno calls connect on the DUT.
+func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.ConnectRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(sa),
+ }
+ resp, err := dut.posixServer.Connect(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Connect: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Fcntl calls fcntl on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use FcntlWithErrno.
+func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg)
+ if ret == -1 {
+ dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err)
+ }
+ return ret
+}
+
+// FcntlWithErrno calls fcntl on the DUT.
+func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.FcntlRequest{
+ Fd: fd,
+ Cmd: cmd,
+ Arg: arg,
+ }
+ resp, err := dut.posixServer.Fcntl(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Fcntl: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockName calls getsockname on the DUT and causes a fatal test failure if
+// it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockNameWithErrno.
+func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to getsockname: %s", err)
+ }
+ return sa
+}
+
+// GetSockNameWithErrno calls getsockname on the DUT.
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.GetSockNameRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.GetSockName(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
+ dut.t.Helper()
+ req := pb.GetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optlen: optlen,
+ Type: typ,
+ }
+ resp, err := dut.posixServer.GetSockOpt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call GetSockOpt: %s", err)
+ }
+ optval := resp.GetOptval()
+ if optval == nil {
+ dut.t.Fatalf("GetSockOpt response does not contain a value")
+ }
+ return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockOptWithErrno. Because endianness and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific GetSockOptXxx function.
+func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen)
+ if ret != 0 {
+ dut.t.Fatalf("failed to GetSockOpt: %s", err)
+ }
+ return optval
+}
+
+// GetSockOptWithErrno calls getsockopt on the DUT. Because endianness and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific GetSockOptXxxWithErrno function.
+func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) {
+ dut.t.Helper()
+ ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES)
+ bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval)
+ if !ok {
+ dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval)
+ }
+ return ret, bytesval.Bytesval, errno
+}
+
+// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use GetSockOptIntWithErrno.
+func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname)
+ if ret != 0 {
+ dut.t.Fatalf("failed to GetSockOptInt: %s", err)
+ }
+ return intval
+}
+
+// GetSockOptIntWithErrno calls getsockopt with an integer optval.
+func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) {
+ dut.t.Helper()
+ ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT)
+ intval, ok := optval.Val.(*pb.SockOptVal_Intval)
+ if !ok {
+ dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval)
+ }
+ return ret, intval.Intval, errno
+}
+
+// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockOptTimevalWithErrno.
+func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname)
+ if ret != 0 {
+ dut.t.Fatalf("failed to GetSockOptTimeval: %s", err)
+ }
+ return timeval
+}
+
+// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval.
+func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) {
+ dut.t.Helper()
+ ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME)
+ tv, ok := optval.Val.(*pb.SockOptVal_Timeval)
+ if !ok {
+ dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval)
+ }
+ timeval := unix.Timeval{
+ Sec: tv.Timeval.Seconds,
+ Usec: tv.Timeval.Microseconds,
+ }
+ return ret, timeval, errno
+}
+
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// ListenWithErrno.
+func (dut *DUT) Listen(sockfd, backlog int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
+ if ret != 0 {
+ dut.t.Fatalf("failed to listen: %s", err)
+ }
+}
+
+// ListenWithErrno calls listen on the DUT.
+func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.ListenRequest{
+ Sockfd: sockfd,
+ Backlog: backlog,
+ }
+ resp, err := dut.posixServer.Listen(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Listen: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Send calls send on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendWithErrno.
+func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to send: %s", err)
+ }
+ return ret
+}
+
+// SendWithErrno calls send on the DUT.
+func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SendRequest{
+ Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
+ }
+ resp, err := dut.posixServer.Send(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Send: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendToWithErrno.
+func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr)
+ if ret == -1 {
+ dut.t.Fatalf("failed to sendto: %s", err)
+ }
+ return ret
+}
+
+// SendToWithErrno calls sendto on the DUT.
+func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.SendToRequest{
+ Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
+ DestAddr: dut.sockaddrToProto(destAddr),
+ }
+ resp, err := dut.posixServer.SendTo(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("faled to call SendTo: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetNonBlocking sets the O_NONBLOCK flag on fd if nonblocking is true;
+// otherwise it clears the flag.
+func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) {
+ dut.t.Helper()
+ flags := dut.Fcntl(fd, unix.F_GETFL, 0)
+ if nonblocking {
+ flags |= unix.O_NONBLOCK
+ } else {
+ flags &= ^unix.O_NONBLOCK
+ }
+ dut.Fcntl(fd, unix.F_SETFL, flags)
+}
+
+func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
+ dut.t.Helper()
+ req := pb.SetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optval: optval,
+ }
+ resp, err := dut.posixServer.SetSockOpt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOpt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptWithErrno. Because endianness and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific SetSockOptXxx function.
+func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ }
+}
+
+// SetSockOptWithErrno calls setsockopt on the DUT. Because endianness and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific SetSockOptXxxWithErrno function.
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) {
+ dut.t.Helper()
+ return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}})
+}
+
+// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use SetSockOptIntWithErrno.
+func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptInt: %s", err)
+ }
+}
+
+// SetSockOptIntWithErrno calls setsockopt with an integer optval.
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
+ dut.t.Helper()
+ return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}})
+}
+
+// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptTimevalWithErrno.
+func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ }
+}
+
+// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
+// bytes.
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+ dut.t.Helper()
+ timeval := pb.Timeval{
+ Seconds: int64(tv.Sec),
+ Microseconds: int64(tv.Usec),
+ }
+ return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}})
+}
+
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(domain, typ, proto int32) int32 {
+ dut.t.Helper()
+ fd, err := dut.SocketWithErrno(domain, typ, proto)
+ if fd < 0 {
+ dut.t.Fatalf("failed to create socket: %s", err)
+ }
+ return fd
+}
+
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
+ }
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Socket: %s", err)
+ }
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
+}
+
+// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// RecvWithErrno.
+func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to recv: %s", err)
+ }
+ return buf
+}
+
+// RecvWithErrno calls recv on the DUT.
+func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
+ dut.t.Helper()
+ req := pb.RecvRequest{
+ Sockfd: sockfd,
+ Len: len,
+ Flags: flags,
+ }
+ resp, err := dut.posixServer.Recv(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Recv: %s", err)
+ }
+ return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
+}
diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go
new file mode 100644
index 000000000..d0e68c5da
--- /dev/null
+++ b/test/packetimpact/testbench/dut_client.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "google.golang.org/grpc"
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+)
+
+// POSIXClient is a gRPC client for the POSIX service.
+type POSIXClient pb.PosixClient
+
+// NewPOSIXClient makes a new gRPC client for the POSIX service.
+func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient {
+ return pb.NewPosixClient(c)
+}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
new file mode 100644
index 000000000..a8121b0da
--- /dev/null
+++ b/test/packetimpact/testbench/layers.go
@@ -0,0 +1,1384 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/hex"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "go.uber.org/multierr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ fmt.Stringer
+
+ // ToBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ ToBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // Prev gets a pointer to the Layer encapsulating this one.
+ Prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+
+ // merge overrides the values in the interface with the provided values.
+ merge(Layer) error
+}
+
+// LayerBase is the common elements of all layers.
+type LayerBase struct {
+ nextLayer Layer
+ prevLayer Layer
+}
+
+func (lb *LayerBase) next() Layer {
+ return lb.nextLayer
+}
+
+// Prev returns the previous layer.
+func (lb *LayerBase) Prev() Layer {
+ return lb.prevLayer
+}
+
+func (lb *LayerBase) setNext(l Layer) {
+ lb.nextLayer = l
+}
+
+func (lb *LayerBase) setPrev(l Layer) {
+ lb.prevLayer = l
+}
+
+// equalLayer reports whether two Layer structs match, ignoring fields in which
+// either input has a nil and also ignoring the LayerBase of the inputs.
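+//
+// For example (a sketch), a TCP layer specifying only the SYN flag matches any
+// TCP layer that also has SYN set, regardless of ports or sequence numbers:
+//
+//	equalLayer(&TCP{Flags: Uint8(header.TCPFlagSyn)}, &TCP{SrcPort: Uint16(1234), Flags: Uint8(header.TCPFlagSyn)}) // true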
+func equalLayer(x, y Layer) bool {
+ if x == nil || y == nil {
+ return true
+ }
+ // opt ignores comparison pairs where either of the inputs is a nil.
+ opt := cmp.FilterValues(func(x, y interface{}) bool {
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
+ }
+ return false
+ }, cmp.Ignore())
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+// mergeLayer merges y into x. For any field where y has a non-nil value, that
+// value overwrites the corresponding field in x.
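+//
+// For example (a sketch), merging &TCP{DstPort: Uint16(80)} into
+// &TCP{SrcPort: Uint16(1234)} leaves SrcPort set to 1234 and sets DstPort
+// to 80.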
+func mergeLayer(x, y Layer) error {
+ if y == nil {
+ return nil
+ }
+ if reflect.TypeOf(x) != reflect.TypeOf(y) {
+ return fmt.Errorf("can't merge %T into %T", y, x)
+ }
+ vx := reflect.ValueOf(x).Elem()
+ vy := reflect.ValueOf(y).Elem()
+ t := vy.Type()
+ for i := 0; i < vy.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := vy.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ vx.Field(i).Set(v)
+ }
+ return nil
+}
+
+func stringLayer(l Layer) string {
+ v := reflect.ValueOf(l).Elem()
+ t := v.Type()
+ var ret []string
+ for i := 0; i < v.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := v.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ v = reflect.Indirect(v)
+ if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
+ ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
+ } else {
+ ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
+ }
+ }
+ return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
+}
+
+// Ether can construct and match an ethernet encapsulation.
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *Ether) ToBytes() ([]byte, error) {
+ b := make([]byte, header.EthernetMinimumSize)
+ h := header.Ethernet(b)
+ fields := &header.EthernetFields{}
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Type != nil {
+ fields.Type = *l.Type
+ } else {
+ switch n := l.next().(type) {
+ case *IPv4:
+ fields.Type = header.IPv4ProtocolNumber
+ case *IPv6:
+ fields.Type = header.IPv6ProtocolNumber
+ default:
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
+ }
+ }
+ h.Encode(fields)
+ return h, nil
+}
+
+// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value
+// to store v and returns a pointer to it.
+func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress {
+ return &v
+}
+
+// NetworkProtocolNumber is a helper routine that allocates a new
+// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it.
+func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber {
+ return &v
+}
+
+// layerParser parses the input bytes and returns a Layer along with the next
+// layerParser to run. If there is no more parsing to do, the returned
+// layerParser is nil.
+type layerParser func([]byte) (Layer, layerParser)
+
+// parse parses bytes starting with the first layerParser and using successive
+// layerParsers until all the bytes are parsed.
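+//
+// For example, parse(parseEther, b) parses b as an Ethernet frame followed by
+// whatever encapsulations its headers indicate, e.g. Ether/IPv4/TCP.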
+func parse(parser layerParser, b []byte) Layers {
+ var layers Layers
+ for {
+ var layer Layer
+ layer, parser = parser(b)
+ layers = append(layers, layer)
+ if parser == nil {
+ break
+ }
+ b = b[layer.length():]
+ }
+ layers.linkLayers()
+ return layers
+}
+
+// parseEther parses the bytes assuming that they start with an ethernet header
+// and continues parsing further encapsulations.
+func parseEther(b []byte) (Layer, layerParser) {
+ h := header.Ethernet(b)
+ ether := Ether{
+ SrcAddr: LinkAddress(h.SourceAddress()),
+ DstAddr: LinkAddress(h.DestinationAddress()),
+ Type: NetworkProtocolNumber(h.Type()),
+ }
+ var nextParser layerParser
+ switch h.Type() {
+ case header.IPv4ProtocolNumber:
+ nextParser = parseIPv4
+ case header.IPv6ProtocolNumber:
+ nextParser = parseIPv6
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
+ }
+ return &ether, nextParser
+}
+
+func (l *Ether) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Ether) length() int {
+ return header.EthernetMinimumSize
+}
+
+// merge implements Layer.merge.
+func (l *Ether) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv4 can construct and match an IPv4 encapsulation.
+type IPv4 struct {
+ LayerBase
+ IHL *uint8
+ TOS *uint8
+ TotalLength *uint16
+ ID *uint16
+ Flags *uint8
+ FragmentOffset *uint16
+ TTL *uint8
+ Protocol *uint8
+ Checksum *uint16
+ SrcAddr *tcpip.Address
+ DstAddr *tcpip.Address
+}
+
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv4) ToBytes() ([]byte, error) {
+ b := make([]byte, header.IPv4MinimumSize)
+ h := header.IPv4(b)
+ fields := &header.IPv4Fields{
+ IHL: 20,
+ TOS: 0,
+ TotalLength: 0,
+ ID: 0,
+ Flags: 0,
+ FragmentOffset: 0,
+ TTL: 64,
+ Protocol: 0,
+ Checksum: 0,
+ SrcAddr: tcpip.Address(""),
+ DstAddr: tcpip.Address(""),
+ }
+ if l.TOS != nil {
+ fields.TOS = *l.TOS
+ }
+ if l.TotalLength != nil {
+ fields.TotalLength = *l.TotalLength
+ } else {
+ fields.TotalLength = uint16(l.length())
+ current := l.next()
+ for current != nil {
+ fields.TotalLength += uint16(current.length())
+ current = current.next()
+ }
+ }
+ if l.ID != nil {
+ fields.ID = *l.ID
+ }
+ if l.Flags != nil {
+ fields.Flags = *l.Flags
+ }
+ if l.FragmentOffset != nil {
+ fields.FragmentOffset = *l.FragmentOffset
+ }
+ if l.TTL != nil {
+ fields.TTL = *l.TTL
+ }
+ if l.Protocol != nil {
+ fields.Protocol = *l.Protocol
+ } else {
+ switch n := l.next().(type) {
+ case *TCP:
+ fields.Protocol = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.Protocol = uint8(header.UDPProtocolNumber)
+ case *ICMPv4:
+ fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
+ default:
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
+ }
+ }
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Checksum != nil {
+ fields.Checksum = *l.Checksum
+ }
+ h.Encode(fields)
+ if l.Checksum == nil {
+ h.SetChecksum(^h.CalculateChecksum())
+ }
+ return h, nil
+}
+
+// Uint16 is a helper routine that allocates a new
+// uint16 value to store v and returns a pointer to it.
+func Uint16(v uint16) *uint16 {
+ return &v
+}
+
+// Uint8 is a helper routine that allocates a new
+// uint8 value to store v and returns a pointer to it.
+func Uint8(v uint8) *uint8 {
+ return &v
+}
+
+// Address is a helper routine that allocates a new tcpip.Address value to store
+// v and returns a pointer to it.
+func Address(v tcpip.Address) *tcpip.Address {
+ return &v
+}
+
+// parseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// continues parsing further encapsulations.
+func parseIPv4(b []byte) (Layer, layerParser) {
+ h := header.IPv4(b)
+ tos, _ := h.TOS()
+ ipv4 := IPv4{
+ IHL: Uint8(h.HeaderLength()),
+ TOS: &tos,
+ TotalLength: Uint16(h.TotalLength()),
+ ID: Uint16(h.ID()),
+ Flags: Uint8(h.Flags()),
+ FragmentOffset: Uint16(h.FragmentOffset()),
+ TTL: Uint8(h.TTL()),
+ Protocol: Uint8(h.Protocol()),
+ Checksum: Uint16(h.Checksum()),
+ SrcAddr: Address(h.SourceAddress()),
+ DstAddr: Address(h.DestinationAddress()),
+ }
+ var nextParser layerParser
+ switch h.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ nextParser = parseTCP
+ case header.UDPProtocolNumber:
+ nextParser = parseUDP
+ case header.ICMPv4ProtocolNumber:
+ nextParser = parseICMPv4
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
+ }
+ return &ipv4, nextParser
+}
+
+func (l *IPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *IPv4) length() int {
+ if l.IHL == nil {
+ return header.IPv4MinimumSize
+ }
+ return int(*l.IHL)
+}
+
+// merge implements Layer.merge.
+func (l *IPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv6 can construct and match an IPv6 encapsulation.
+type IPv6 struct {
+ LayerBase
+ TrafficClass *uint8
+ FlowLabel *uint32
+ PayloadLength *uint16
+ NextHeader *uint8
+ HopLimit *uint8
+ SrcAddr *tcpip.Address
+ DstAddr *tcpip.Address
+}
+
+func (l *IPv6) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6) ToBytes() ([]byte, error) {
+ b := make([]byte, header.IPv6MinimumSize)
+ h := header.IPv6(b)
+ fields := &header.IPv6Fields{
+ HopLimit: 64,
+ }
+ if l.TrafficClass != nil {
+ fields.TrafficClass = *l.TrafficClass
+ }
+ if l.FlowLabel != nil {
+ fields.FlowLabel = *l.FlowLabel
+ }
+ if l.PayloadLength != nil {
+ fields.PayloadLength = *l.PayloadLength
+ } else {
+ for current := l.next(); current != nil; current = current.next() {
+ fields.PayloadLength += uint16(current.length())
+ }
+ }
+ if l.NextHeader != nil {
+ fields.NextHeader = *l.NextHeader
+ } else {
+ switch n := l.next().(type) {
+ case *TCP:
+ fields.NextHeader = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.NextHeader = uint8(header.UDPProtocolNumber)
+ case *ICMPv6:
+ fields.NextHeader = uint8(header.ICMPv6ProtocolNumber)
+ case *IPv6HopByHopOptionsExtHdr:
+ fields.NextHeader = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
+ case *IPv6DestinationOptionsExtHdr:
+ fields.NextHeader = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
+ default:
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ToBytes can't deduce the IPv6 header's next protocol: %#v", n)
+ }
+ }
+ if l.HopLimit != nil {
+ fields.HopLimit = *l.HopLimit
+ }
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ h.Encode(fields)
+ return h, nil
+}
+
+// nextIPv6PayloadParser finds the corresponding parser for nextHeader.
+func nextIPv6PayloadParser(nextHeader uint8) layerParser {
+ switch tcpip.TransportProtocolNumber(nextHeader) {
+ case header.TCPProtocolNumber:
+ return parseTCP
+ case header.UDPProtocolNumber:
+ return parseUDP
+ case header.ICMPv6ProtocolNumber:
+ return parseICMPv6
+ }
+ switch header.IPv6ExtensionHeaderIdentifier(nextHeader) {
+ case header.IPv6HopByHopOptionsExtHdrIdentifier:
+ return parseIPv6HopByHopOptionsExtHdr
+ case header.IPv6DestinationOptionsExtHdrIdentifier:
+ return parseIPv6DestinationOptionsExtHdr
+ }
+ return parsePayload
+}
+
+// parseIPv6 parses the bytes assuming that they start with an ipv6 header and
+// continues parsing further encapsulations.
+func parseIPv6(b []byte) (Layer, layerParser) {
+ h := header.IPv6(b)
+ tos, flowLabel := h.TOS()
+ ipv6 := IPv6{
+ TrafficClass: &tos,
+ FlowLabel: &flowLabel,
+ PayloadLength: Uint16(h.PayloadLength()),
+ NextHeader: Uint8(h.NextHeader()),
+ HopLimit: Uint8(h.HopLimit()),
+ SrcAddr: Address(h.SourceAddress()),
+ DstAddr: Address(h.DestinationAddress()),
+ }
+ nextParser := nextIPv6PayloadParser(h.NextHeader())
+ return &ipv6, nextParser
+}
+
+func (l *IPv6) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *IPv6) length() int {
+ return header.IPv6MinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions
+// Extension Header.
+type IPv6HopByHopOptionsExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ Options []byte
+}
+
+// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions
+// Extension Header.
+type IPv6DestinationOptionsExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ Options []byte
+}
+
+// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes.
+func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, options []byte) []byte {
+ length := len(options) + 2
+ bytes := make([]byte, length)
+ if nextHeader == nil {
+ bytes[0] = byte(header.IPv6NoNextHeaderIdentifier)
+ } else {
+ bytes[0] = byte(*nextHeader)
+ }
+ // The ExtHdrLen field is the length of the extension header
+ // in 8-octet units, not including the first 8 octets.
+ // https://tools.ietf.org/html/rfc2460#section-4.3
+ // https://tools.ietf.org/html/rfc2460#section-4.6
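+ // For example, 6 bytes of options make length 8, so bytes[1] is 0.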
+ bytes[1] = uint8((length - 8) / 8)
+ copy(bytes[2:], options)
+ return bytes
+}
+
+// IPv6ExtHdrIdent is a helper routine that allocates a new
+// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer
+// to it.
+func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier {
+ return &id
+}
+
+// ToBytes implements Layer.ToBytes
+func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) {
+ return ipv6OptionsExtHdrToBytes(l.NextHeader, l.Options), nil
+}
+
+// ToBytes implements Layer.ToBytes
+func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) {
+ return ipv6OptionsExtHdrToBytes(l.NextHeader, l.Options), nil
+}
+
+// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader
+// field, the options bytes of the header, and a parser function for the
+// encapsulation identified by NextHeader.
+func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) {
+ nextHeader := b[0]
+ // For HopByHop and Destination options extension headers,
+ // this field is the length of the extension header in
+ // 8-octet units, not including the first 8 octets.
+ // https://tools.ietf.org/html/rfc2460#section-4.3
+ // https://tools.ietf.org/html/rfc2460#section-4.6
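+ // This is the inverse of the encoding in ipv6OptionsExtHdrToBytes: an
+ // encoded value of 0 means the header is 8 octets long in total.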
+ length := b[1]*8 + 8
+ data := b[2:length]
+ nextParser := nextIPv6PayloadParser(nextHeader)
+ return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser
+}
+
+// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start
+// with an IPv6 HopByHop Options Extension Header.
+func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader, options, nextParser := parseIPv6ExtHdr(b)
+ return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
+}
+
+// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start
+// with an IPv6 Destination Options Extension Header.
+func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader, options, nextParser := parseIPv6ExtHdr(b)
+ return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) length() int {
+ return len(l.Options) + 2
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) String() string {
+ return stringLayer(l)
+}
+
+func (l *IPv6DestinationOptionsExtHdr) length() int {
+ return len(l.Options) + 2
+}
+
+func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6DestinationOptionsExtHdr) String() string {
+ return stringLayer(l)
+}
+
+// ICMPv6 can construct and match an ICMPv6 encapsulation.
+type ICMPv6 struct {
+ LayerBase
+ Type *header.ICMPv6Type
+ Code *byte
+ Checksum *uint16
+ NDPPayload []byte
+}
+
+func (l *ICMPv6) String() string {
+ // TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem?
+ // We could parse the contents of the Payload as if it were an IPv6 packet.
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *ICMPv6) ToBytes() ([]byte, error) {
+ b := make([]byte, header.ICMPv6HeaderSize+len(l.NDPPayload))
+ h := header.ICMPv6(b)
+ if l.Type != nil {
+ h.SetType(*l.Type)
+ }
+ if l.Code != nil {
+ h.SetCode(*l.Code)
+ }
+ copy(h.NDPPayload(), l.NDPPayload)
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ } else {
+ // It is possible that the ICMPv6 header does not immediately follow the
+ // IPv6 header; there could be one or more extension headers in between.
+ // Search backwards through the preceding layers to find the IPv6 header.
+ for prev := l.Prev(); prev != nil; prev = prev.Prev() {
+ if ipv6, ok := prev.(*IPv6); ok {
+ h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{}))
+ break
+ }
+ }
+ }
+ return h, nil
+}
+
+// ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store
+// v and returns a pointer to it.
+func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type {
+ return &v
+}
+
+// Byte is a helper routine that allocates a new byte value to store
+// v and returns a pointer to it.
+func Byte(v byte) *byte {
+ return &v
+}
+
+// parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header.
+func parseICMPv6(b []byte) (Layer, layerParser) {
+ h := header.ICMPv6(b)
+ icmpv6 := ICMPv6{
+ Type: ICMPv6Type(h.Type()),
+ Code: Byte(h.Code()),
+ Checksum: Uint16(h.Checksum()),
+ NDPPayload: h.NDPPayload(),
+ }
+ return &icmpv6, nil
+}
+
+func (l *ICMPv6) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *ICMPv6) length() int {
+ return header.ICMPv6HeaderSize + len(l.NDPPayload)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *ICMPv6) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value
+// to store t and returns a pointer to it.
+func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type {
+ return &t
+}
+
+// ICMPv4 can construct and match an ICMPv4 encapsulation.
+type ICMPv4 struct {
+ LayerBase
+ Type *header.ICMPv4Type
+ Code *uint8
+ Checksum *uint16
+}
+
+func (l *ICMPv4) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *ICMPv4) ToBytes() ([]byte, error) {
+ b := make([]byte, header.ICMPv4MinimumSize)
+ h := header.ICMPv4(b)
+ if l.Type != nil {
+ h.SetType(*l.Type)
+ }
+ if l.Code != nil {
+ h.SetCode(byte(*l.Code))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ payload, err := payload(l)
+ if err != nil {
+ return nil, err
+ }
+ h.SetChecksum(header.ICMPv4Checksum(h, payload))
+ return h, nil
+}
+
+// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a
+// parser for the encapsulated payload.
+func parseICMPv4(b []byte) (Layer, layerParser) {
+ h := header.ICMPv4(b)
+ icmpv4 := ICMPv4{
+ Type: ICMPv4Type(h.Type()),
+ Code: Uint8(h.Code()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &icmpv4, parsePayload
+}
+
+func (l *ICMPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *ICMPv4) length() int {
+ return header.ICMPv4MinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *ICMPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// TCP can construct and match a TCP encapsulation.
+type TCP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ SeqNum *uint32
+ AckNum *uint32
+ DataOffset *uint8
+ Flags *uint8
+ WindowSize *uint16
+ Checksum *uint16
+ UrgentPointer *uint16
+ Options []byte
+}
+
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *TCP) ToBytes() ([]byte, error) {
+ b := make([]byte, l.length())
+ h := header.TCP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.SeqNum != nil {
+ h.SetSequenceNumber(*l.SeqNum)
+ }
+ if l.AckNum != nil {
+ h.SetAckNumber(*l.AckNum)
+ }
+ if l.DataOffset != nil {
+ h.SetDataOffset(*l.DataOffset)
+ } else {
+ h.SetDataOffset(uint8(l.length()))
+ }
+ if l.Flags != nil {
+ h.SetFlags(*l.Flags)
+ }
+ if l.WindowSize != nil {
+ h.SetWindowSize(*l.WindowSize)
+ } else {
+ h.SetWindowSize(32768)
+ }
+ if l.UrgentPointer != nil {
+ h.SetUrgentPoiner(*l.UrgentPointer)
+ }
+ copy(b[header.TCPMinimumSize:], l.Options)
+ header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options))
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setTCPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// totalLength returns the length of the provided layer and all following
+// layers.
+func totalLength(l Layer) int {
+ var totalLength int
+ for ; l != nil; l = l.next() {
+ totalLength += l.length()
+ }
+ return totalLength
+}
+
+// payload returns a buffer.VectorisedView of l's payload.
+func payload(l Layer) (buffer.VectorisedView, error) {
+ var payloadBytes buffer.VectorisedView
+ for current := l.next(); current != nil; current = current.next() {
+ payload, err := current.ToBytes()
+ if err != nil {
+ return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", err)
+ }
+ payloadBytes.AppendView(payload)
+ }
+ return payloadBytes, nil
+}
+
+// layerChecksum calculates the checksum of the Layer header, including the
+// pseudo-header checksum of the layer before it and all the bytes after it.
+func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
+ totalLength := uint16(totalLength(l))
+ var xsum uint16
+ switch s := l.Prev().(type) {
+ case *IPv4:
+ xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
+ default:
+ // TODO(b/150301488): Support more protocols, like IPv6.
+ return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
+ }
+ payloadBytes, err := payload(l)
+ if err != nil {
+ return 0, err
+ }
+ xsum = header.ChecksumVV(payloadBytes, xsum)
+ return xsum, nil
+}
+
+// setTCPChecksum calculates the checksum of the TCP header and sets it in h.
+func setTCPChecksum(h *header.TCP, tcp *TCP) error {
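+ // Zero the checksum field first so it does not contribute to the sum; the
+ // final checksum is the one's complement of the computed sum.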
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// Uint32 is a helper routine that allocates a new
+// uint32 value to store v and returns a pointer to it.
+func Uint32(v uint32) *uint32 {
+ return &v
+}
+
+// parseTCP parses the bytes assuming that they start with a tcp header and
+// continues parsing further encapsulations.
+func parseTCP(b []byte) (Layer, layerParser) {
+ h := header.TCP(b)
+ tcp := TCP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ SeqNum: Uint32(h.SequenceNumber()),
+ AckNum: Uint32(h.AckNumber()),
+ DataOffset: Uint8(h.DataOffset()),
+ Flags: Uint8(h.Flags()),
+ WindowSize: Uint16(h.WindowSize()),
+ Checksum: Uint16(h.Checksum()),
+ UrgentPointer: Uint16(h.UrgentPointer()),
+ Options: b[header.TCPMinimumSize:h.DataOffset()],
+ }
+ return &tcp, parsePayload
+}
+
+func (l *TCP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *TCP) length() int {
+ if l.DataOffset == nil {
+ // TCP header including the options must end on a 32-bit
+ // boundary; the user could potentially give us a slice
+ // whose length is not a multiple of 4 bytes, so we have
+ // to do the alignment here.
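+ // For example, 5 option bytes are padded to 8, giving a
+ // total header length of 28 bytes.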
+ optlen := (len(l.Options) + 3) & ^3
+ return header.TCPMinimumSize + optlen
+ }
+ return int(*l.DataOffset)
+}
+
+// merge implements Layer.merge.
+func (l *TCP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// UDP can construct and match a UDP encapsulation.
+type UDP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ Length *uint16
+ Checksum *uint16
+}
+
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *UDP) ToBytes() ([]byte, error) {
+ b := make([]byte, header.UDPMinimumSize)
+ h := header.UDP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.Length != nil {
+ h.SetLength(*l.Length)
+ } else {
+ h.SetLength(uint16(totalLength(l)))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setUDPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setUDPChecksum calculates the checksum of the UDP header and sets it in h.
+func setUDPChecksum(h *header.UDP, udp *UDP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
+ h := header.UDP(b)
+ udp := UDP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ Length: Uint16(h.Length()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &udp, parsePayload
+}
+
+func (l *UDP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *UDP) length() int {
+ return header.UDPMinimumSize
+}
+
+// merge implements Layer.merge.
+func (l *UDP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// Payload has bytes beyond OSI layer 4.
+type Payload struct {
+ LayerBase
+ Bytes []byte
+}
+
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
+// parsePayload parses the bytes assuming that they start with a payload and
+// continue to the end. There can be no further encapsulations.
+func parsePayload(b []byte) (Layer, layerParser) {
+ payload := Payload{
+ Bytes: b,
+ }
+ return &payload, nil
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *Payload) ToBytes() ([]byte, error) {
+ return l.Bytes, nil
+}
+
+// Length returns payload byte length.
+func (l *Payload) Length() int {
+ return l.length()
+}
+
+func (l *Payload) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Payload) length() int {
+ return len(l.Bytes)
+}
+
+// merge implements Layer.merge.
+func (l *Payload) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// Layers is an array of Layer and supports similar functions to Layer.
+type Layers []Layer
+
+// linkLayers sets the linked-list pointers in ls.
+func (ls *Layers) linkLayers() {
+ for i, l := range *ls {
+ if i > 0 {
+ l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
+ }
+ if i+1 < len(*ls) {
+ l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
+ }
+ }
+}
+
+// ToBytes converts the Layers into bytes. It creates a linked list of the Layer
+// structs and then concatenates the output of ToBytes on each Layer.
+func (ls *Layers) ToBytes() ([]byte, error) {
+ ls.linkLayers()
+ outBytes := []byte{}
+ for _, l := range *ls {
+ layerBytes, err := l.ToBytes()
+ if err != nil {
+ return nil, err
+ }
+ outBytes = append(outBytes, layerBytes...)
+ }
+ return outBytes, nil
+}
+
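+// match returns true if each Layer in ls matches the corresponding Layer in
+// other. other may be longer than ls; any extra Layers in other are ignored.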
+func (ls *Layers) match(other Layers) bool {
+ if len(*ls) > len(other) {
+ return false
+ }
+ for i, l := range *ls {
+ if !equalLayer(l, other[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// layerDiff stores the diffs for each field along with the label for the Layer.
+// If rows is nil, that means that there was no diff.
+type layerDiff struct {
+ label string
+ rows []layerDiffRow
+}
+
+// layerDiffRow stores a field name and the corresponding values from the got
+// and want layers. If a value was nil, the stored string is empty.
+type layerDiffRow struct {
+ field, got, want string
+}
+
+// diffLayer extracts all differing fields between two layers.
+func diffLayer(got, want Layer) []layerDiffRow {
+ vGot := reflect.ValueOf(got).Elem()
+ vWant := reflect.ValueOf(want).Elem()
+ if vGot.Type() != vWant.Type() {
+ return nil
+ }
+ t := vGot.Type()
+ var result []layerDiffRow
+ for i := 0; i < t.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ vGot := vGot.Field(i)
+ vWant := vWant.Field(i)
+ gotString := ""
+ if !vGot.IsNil() {
+ gotString = fmt.Sprint(reflect.Indirect(vGot))
+ }
+ wantString := ""
+ if !vWant.IsNil() {
+ wantString = fmt.Sprint(reflect.Indirect(vWant))
+ }
+ result = append(result, layerDiffRow{t.Name, gotString, wantString})
+ }
+ return result
+}
+
+// layerType returns a concise string describing the type of the Layer, like
+// "TCP", or "IPv6".
+func layerType(l Layer) string {
+ return reflect.TypeOf(l).Elem().Name()
+}
+
+// diff compares Layers and returns a representation of the difference. Each
+// Layer in the Layers is pairwise compared. If an element in either is nil, it
+// is considered a match with the other Layer. If two Layers have differing
+// types, they don't match regardless of the contents. If two Layers have the
+// same type then the fields in the Layer are pairwise compared. Fields that are
+// nil always match. Two non-nil fields only match if they point to equal
+// values. diff returns an empty string if and only if *ls and other match.
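+//
+// For example, diffing two Layers that differ only in the TCP SeqNum yields a
+// single "TCP: SeqNum:" row showing the got and want values.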
+func (ls *Layers) diff(other Layers) string {
+ var allDiffs []layerDiff
+ // Check the cases where one list is longer than the other, where one or both
+ // elements are nil, where the sides have different types, and where the sides
+ // have the same type.
+ for i := 0; i < len(*ls) || i < len(other); i++ {
+ if i >= len(*ls) {
+ // Matching ls against other where other is longer than ls. missing
+ // matches everything so we just include a label without any rows. Having
+ // no rows is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: "missing matches " + layerType(other[i]),
+ })
+ continue
+ }
+
+ if i >= len(other) {
+ // Matching ls against other where ls is longer than other. missing
+ // matches everything so we just include a label without any rows. Having
+ // no rows is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " matches missing",
+ })
+ continue
+ }
+
+ if (*ls)[i] == nil && other[i] == nil {
+ // Matching ls against other where both elements are nil. nil matches
+ // everything so we just include a label without any rows. Having no rows
+ // is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: "nil matches nil",
+ })
+ continue
+ }
+
+ if (*ls)[i] == nil {
+ // Matching ls against other where the element in ls is nil. nil matches
+ // everything so we just include a label without any rows. Having no rows
+ // is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: "nil matches " + layerType(other[i]),
+ })
+ continue
+ }
+
+ if other[i] == nil {
+ // Matching ls against other where the element in other is nil. nil
+ // matches everything so we just include a label without any rows. Having
+ // no rows is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " matches nil",
+ })
+ continue
+ }
+
+ if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) {
+ // Matching ls against other where both elements have the same type. Match
+ // each field pairwise and only report a diff if there is a mismatch,
+ // which is only when both sides are non-nil and have differing values.
+ diff := diffLayer((*ls)[i], other[i])
+ var layerDiffRows []layerDiffRow
+ for _, d := range diff {
+ if d.got == "" || d.want == "" || d.got == d.want {
+ continue
+ }
+ layerDiffRows = append(layerDiffRows, layerDiffRow{
+ d.field,
+ d.got,
+ d.want,
+ })
+ }
+ if len(layerDiffRows) > 0 {
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]),
+ rows: layerDiffRows,
+ })
+ } else {
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " matches " + layerType(other[i]),
+ // Having no rows is a sign that there was no diff.
+ })
+ }
+ continue
+ }
+ // Neither side is nil and the types are different, so we'll display one
+ // side then the other.
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]),
+ })
+ diff := diffLayer((*ls)[i], (*ls)[i])
+ layerDiffRows := []layerDiffRow{}
+ for _, d := range diff {
+ if len(d.got) == 0 {
+ continue
+ }
+ layerDiffRows = append(layerDiffRows, layerDiffRow{
+ d.field,
+ d.got,
+ "",
+ })
+ }
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]),
+ rows: layerDiffRows,
+ })
+
+ layerDiffRows = []layerDiffRow{}
+ diff = diffLayer(other[i], other[i])
+ for _, d := range diff {
+ if len(d.want) == 0 {
+ continue
+ }
+ layerDiffRows = append(layerDiffRows, layerDiffRow{
+ d.field,
+ "",
+ d.want,
+ })
+ }
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType(other[i]),
+ rows: layerDiffRows,
+ })
+ }
+
+ output := ""
+ // These are for output formatting.
+ maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0
+ foundOne := false
+ for _, l := range allDiffs {
+ if len(l.label) > maxLabelLen && len(l.rows) > 0 {
+ maxLabelLen = len(l.label)
+ }
+ if l.rows != nil {
+ foundOne = true
+ }
+ for _, r := range l.rows {
+ if len(r.field) > maxFieldLen {
+ maxFieldLen = len(r.field)
+ }
+ if l := len(r.got); l > maxGotLen {
+ maxGotLen = l
+ }
+ if l := len(r.want); l > maxWantLen {
+ maxWantLen = l
+ }
+ }
+ }
+ if !foundOne {
+ return ""
+ }
+ for _, l := range allDiffs {
+ if len(l.rows) == 0 {
+ output += "(" + l.label + ")\n"
+ continue
+ }
+ for i, r := range l.rows {
+ var label string
+ if i == 0 {
+ label = l.label + ":"
+ }
+ output += fmt.Sprintf(
+ "%*s %*s %*v %*v\n",
+ maxLabelLen+1, label,
+ maxFieldLen+1, r.field+":",
+ maxGotLen, r.got,
+ maxWantLen, r.want,
+ )
+ }
+ }
+ return output
+}
+
+// merge merges the other Layers into ls. If the other Layers is longer, those
+// additional Layer structs are added to ls. The errors from merging are
+// collected and returned.
+func (ls *Layers) merge(other Layers) error {
+ var errs error
+ for i, o := range other {
+ if i < len(*ls) {
+ errs = multierr.Combine(errs, (*ls)[i].merge(o))
+ } else {
+ *ls = append(*ls, o)
+ }
+ }
+ return errs
+}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
new file mode 100644
index 000000000..382a983a1
--- /dev/null
+++ b/test/packetimpact/testbench/layers_test.go
@@ -0,0 +1,618 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "bytes"
+ "net"
+ "testing"
+
+ "github.com/mohae/deepcopy"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestLayerMatch(t *testing.T) {
+ var nilPayload *Payload
+ noPayload := &Payload{}
+ emptyPayload := &Payload{Bytes: []byte{}}
+ fullPayload := &Payload{Bytes: []byte{1, 2, 3}}
+ emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}}
+ fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}}
+ for _, tt := range []struct {
+ a, b Layer
+ want bool
+ }{
+ {nilPayload, nilPayload, true},
+ {nilPayload, noPayload, true},
+ {nilPayload, emptyPayload, true},
+ {nilPayload, fullPayload, true},
+ {noPayload, noPayload, true},
+ {noPayload, emptyPayload, true},
+ {noPayload, fullPayload, true},
+ {emptyPayload, emptyPayload, true},
+ {emptyPayload, fullPayload, false},
+ {fullPayload, fullPayload, true},
+ {emptyTCP, fullTCP, true},
+ } {
+ if got := tt.a.match(tt.b); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want)
+ }
+ if got := tt.b.match(tt.a); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestLayerMergeMismatch(t *testing.T) {
+ tcp := &TCP{}
+ otherTCP := &TCP{}
+ ipv4 := &IPv4{}
+ ether := &Ether{}
+ for _, tt := range []struct {
+ a, b Layer
+ success bool
+ }{
+ {tcp, tcp, true},
+ {tcp, otherTCP, true},
+ {tcp, ipv4, false},
+ {tcp, ether, false},
+ {tcp, nil, true},
+
+ {otherTCP, otherTCP, true},
+ {otherTCP, ipv4, false},
+ {otherTCP, ether, false},
+ {otherTCP, nil, true},
+
+ {ipv4, ipv4, true},
+ {ipv4, ether, false},
+ {ipv4, nil, true},
+
+ {ether, ether, true},
+ {ether, nil, true},
+ } {
+ if err := tt.a.merge(tt.b); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err)
+ }
+ if tt.b != nil {
+ if err := tt.b.merge(tt.a); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err)
+ }
+ }
+ }
+}
+
+func TestLayerMerge(t *testing.T) {
+ zero := Uint32(0)
+ one := Uint32(1)
+ two := Uint32(2)
+ empty := []byte{}
+ foo := []byte("foo")
+ bar := []byte("bar")
+ for _, tt := range []struct {
+ a, b Layer
+ want Layer
+ }{
+ {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}},
+
+ {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}},
+
+ {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: one}, nil, &TCP{AckNum: one}},
+
+ {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, nil, &TCP{AckNum: two}},
+
+ {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}},
+
+ {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}},
+
+ {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}},
+
+ {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}},
+ } {
+ a := deepcopy.Copy(tt.a).(Layer)
+ if err := a.merge(tt.b); err != nil {
+ t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err)
+ continue
+ }
+ if a.String() != tt.want.String() {
+ t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want)
+ }
+ }
+}
+
+func TestLayerStringFormat(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ l Layer
+ want string
+ }{
+ {
+ name: "TCP",
+ l: &TCP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ SeqNum: Uint32(3452155723),
+ AckNum: Uint32(2596996163),
+ DataOffset: Uint8(5),
+ Flags: Uint8(20),
+ WindowSize: Uint16(64240),
+ Checksum: Uint16(0x2e2b),
+ },
+ want: "&testbench.TCP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "SeqNum:3452155723 " +
+ "AckNum:2596996163 " +
+ "DataOffset:5 " +
+ "Flags:20 " +
+ "WindowSize:64240 " +
+ "Checksum:11819" +
+ "}",
+ },
+ {
+ name: "UDP",
+ l: &UDP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ Length: Uint16(12),
+ },
+ want: "&testbench.UDP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "Length:12" +
+ "}",
+ },
+ {
+ name: "IPv4",
+ l: &IPv4{
+ IHL: Uint8(5),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(0),
+ Flags: Uint8(2),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(6),
+ Checksum: Uint16(0x2e2b),
+ SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})),
+ DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})),
+ },
+ want: "&testbench.IPv4{" +
+ "IHL:5 " +
+ "TOS:0 " +
+ "TotalLength:44 " +
+ "ID:0 " +
+ "Flags:2 " +
+ "FragmentOffset:0 " +
+ "TTL:64 " +
+ "Protocol:6 " +
+ "Checksum:11819 " +
+ "SrcAddr:197.34.63.10 " +
+ "DstAddr:197.34.63.20" +
+ "}",
+ },
+ {
+ name: "Ether",
+ l: &Ether{
+ SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})),
+ DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})),
+ Type: NetworkProtocolNumber(4),
+ },
+ want: "&testbench.Ether{" +
+ "SrcAddr:02:42:c5:22:3f:0a " +
+ "DstAddr:02:42:c5:22:3f:14 " +
+ "Type:4" +
+ "}",
+ },
+ {
+ name: "Payload",
+ l: &Payload{
+ Bytes: []byte("Hooray for packetimpact."),
+ },
+ want: "&testbench.Payload{Bytes:\n" +
+ "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" +
+ "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" +
+ "}",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.l.String(); got != tt.want {
+ t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestConnectionMatch(t *testing.T) {
+ conn := Connection{
+ layerStates: []layerState{&etherState{}},
+ }
+ protoNum0 := tcpip.NetworkProtocolNumber(0)
+ protoNum1 := tcpip.NetworkProtocolNumber(1)
+ for _, tt := range []struct {
+ description string
+ override, received Layers
+ wantMatch bool
+ }{
+ {
+ description: "shorter override",
+ override: []Layer{&Ether{}},
+ received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
+ wantMatch: true,
+ },
+ {
+ description: "longer override",
+ override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
+ received: []Layer{&Ether{}},
+ wantMatch: false,
+ },
+ {
+ description: "ether layer mismatch",
+ override: []Layer{&Ether{Type: &protoNum0}},
+ received: []Layer{&Ether{Type: &protoNum1}},
+ wantMatch: false,
+ },
+ {
+ description: "both nil",
+ override: nil,
+ received: nil,
+ wantMatch: false,
+ },
+ {
+ description: "nil override",
+ override: nil,
+ received: []Layer{&Ether{}},
+ wantMatch: true,
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch {
+ t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch)
+ }
+ })
+ }
+}
+
+func TestLayersDiff(t *testing.T) {
+ for _, tt := range []struct {
+ x, y Layers
+ want string
+ }{
+ {
+ Layers{&Ether{Type: NetworkProtocolNumber(12)}, &TCP{DataOffset: Uint8(5), SeqNum: Uint32(5)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "Ether: Type: 12 13\n" +
+ " TCP: SeqNum: 5 6\n" +
+ " DataOffset: 5 7\n",
+ },
+ {
+ Layers{&Ether{Type: NetworkProtocolNumber(12)}, &UDP{SrcPort: Uint16(123)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "Ether: Type: 12 13\n" +
+ "(UDP doesn't match TCP)\n" +
+ " UDP: SrcPort: 123 \n" +
+ " TCP: SeqNum: 6\n" +
+ " DataOffset: 7\n",
+ },
+ {
+ Layers{&UDP{SrcPort: Uint16(123)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "(UDP doesn't match Ether)\n" +
+ " UDP: SrcPort: 123 \n" +
+ "Ether: Type: 13\n" +
+ "(missing matches TCP)\n",
+ },
+ {
+ Layers{nil, &UDP{SrcPort: Uint16(123)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "(nil matches Ether)\n" +
+ "(UDP doesn't match TCP)\n" +
+ "UDP: SrcPort: 123 \n" +
+ "TCP: SeqNum: 6\n" +
+ " DataOffset: 7\n",
+ },
+ {
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(4)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(6)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "(Ether matches Ether)\n" +
+ "IPv4: IHL: 4 6\n" +
+ "(TCP matches TCP)\n",
+ },
+ {
+ Layers{&Payload{Bytes: []byte("foo")}},
+ Layers{&Payload{Bytes: []byte("bar")}},
+ "Payload: Bytes: [102 111 111] [98 97 114]\n",
+ },
+ {
+ Layers{&Payload{Bytes: []byte("")}},
+ Layers{&Payload{}},
+ "",
+ },
+ {
+ Layers{&Payload{Bytes: []byte("")}},
+ Layers{&Payload{Bytes: []byte("")}},
+ "",
+ },
+ {
+ Layers{&UDP{}},
+ Layers{&TCP{}},
+ "(UDP doesn't match TCP)\n" +
+ "(UDP)\n" +
+ "(TCP)\n",
+ },
+ } {
+ if got := tt.x.diff(tt.y); got != tt.want {
+ t.Errorf("%s.diff(%s) = %q, want %q", tt.x, tt.y, got, tt.want)
+ }
+ if tt.x.match(tt.y) != (tt.x.diff(tt.y) == "") {
+ t.Errorf("match and diff of %s and %s disagree", tt.x, tt.y)
+ }
+ if tt.y.match(tt.x) != (tt.y.diff(tt.x) == "") {
+ t.Errorf("match and diff of %s and %s disagree", tt.y, tt.x)
+ }
+ }
+}
+
+func TestTCPOptions(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ wantBytes []byte
+ wantLayers Layers
+ }{
+ {
+ description: "without payload",
+ wantBytes: []byte{
+ // IPv4 Header
+ 0x45, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06,
+ 0xf9, 0x77, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01,
+ // TCP Header
+ 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xf5, 0x1c, 0x00, 0x00,
+ // WindowScale Option
+ 0x03, 0x03, 0x02,
+ // NOP Option
+ 0x00,
+ },
+ wantLayers: []Layer{
+ &IPv4{
+ IHL: Uint8(20),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(1),
+ Flags: Uint8(0),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(uint8(header.TCPProtocolNumber)),
+ Checksum: Uint16(0xf977),
+ SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())),
+ DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())),
+ },
+ &TCP{
+ SrcPort: Uint16(12345),
+ DstPort: Uint16(54321),
+ SeqNum: Uint32(0),
+ AckNum: Uint32(0),
+ Flags: Uint8(header.TCPFlagSyn),
+ WindowSize: Uint16(8192),
+ Checksum: Uint16(0xf51c),
+ UrgentPointer: Uint16(0),
+ Options: []byte{3, 3, 2, 0},
+ },
+ &Payload{Bytes: nil},
+ },
+ },
+ {
+ description: "with payload",
+ wantBytes: []byte{
+ // IPv4 header
+ 0x45, 0x00, 0x00, 0x37, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06,
+ 0xf9, 0x6c, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01,
+ // TCP header
+ 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xe5, 0x21, 0x00, 0x00,
+ // WindowScale Option
+ 0x03, 0x03, 0x02,
+ // NOP Option
+ 0x00,
+ // Payload: "Sample Data"
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv4{
+ IHL: Uint8(20),
+ TOS: Uint8(0),
+ TotalLength: Uint16(55),
+ ID: Uint16(1),
+ Flags: Uint8(0),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(uint8(header.TCPProtocolNumber)),
+ Checksum: Uint16(0xf96c),
+ SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())),
+ DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())),
+ },
+ &TCP{
+ SrcPort: Uint16(12345),
+ DstPort: Uint16(54321),
+ SeqNum: Uint32(0),
+ AckNum: Uint32(0),
+ Flags: Uint8(header.TCPFlagSyn),
+ WindowSize: Uint16(8192),
+ Checksum: Uint16(0xe521),
+ UrgentPointer: Uint16(0),
+ Options: []byte{3, 3, 2, 0},
+ },
+ &Payload{Bytes: []byte("Sample Data")},
+ },
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ layers := parse(parseIPv4, tt.wantBytes)
+ if !layers.match(tt.wantLayers) {
+ t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers))
+ }
+ gotBytes, err := layers.ToBytes()
+ if err != nil {
+ t.Fatalf("ToBytes() failed on %s: %s", &layers, err)
+ }
+ if !bytes.Equal(tt.wantBytes, gotBytes) {
+ t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes)
+ }
+ })
+ }
+}
+
+func TestIPv6ExtHdrOptions(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ wantBytes []byte
+ wantLayers Layers
+ }{
+ {
+ description: "IPv6/HopByHop",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &Payload{
+ Bytes: nil,
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Destination/ICMPv6",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Destination Options
+ 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // ICMPv6 Param Problem
+ 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6DestinationOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &ICMPv6{
+ Type: ICMPv6Type(header.ICMPv6ParamProblem),
+ Code: Byte(0),
+ Checksum: Uint16(0x5f98),
+ NDPPayload: []byte{0x00, 0x00, 0x00, 0x06},
+ },
+ },
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ layers := parse(parseIPv6, tt.wantBytes)
+ if !layers.match(tt.wantLayers) {
+ t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers))
+ }
+ gotBytes, err := layers.ToBytes()
+ if err != nil {
+ t.Fatalf("ToBytes() failed on %s: %s", &layers, err)
+ }
+ if !bytes.Equal(tt.wantBytes, gotBytes) {
+ t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
new file mode 100644
index 000000000..278229b7e
--- /dev/null
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -0,0 +1,178 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Sniffer can sniff raw packets on the wire.
+type Sniffer struct {
+ t *testing.T
+ fd int
+}
+
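+// htons converts a uint16 from host byte order to network byte order (big
+// endian), as required for AF_PACKET socket protocol arguments.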
+func htons(x uint16) uint16 {
+ buf := [2]byte{}
+ binary.BigEndian.PutUint16(buf[:], x)
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+// NewSniffer creates a Sniffer that reads raw frames from the test network.
+func NewSniffer(t *testing.T) (Sniffer, error) {
+ snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Sniffer{}, err
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
+ t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
+ }
+ return Sniffer{
+ t: t,
+ fd: snifferFd,
+ }, nil
+}
+
+// maxReadSize should be large enough for the maximum frame size in bytes. If a
+// packet too large for the buffer arrives, the test will get a fatal error.
+const maxReadSize int = 65536
+
+// Recv tries to read one frame until the timeout is up.
+func (s *Sniffer) Recv(timeout time.Duration) []byte {
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = time.Until(deadline)
+ if timeout <= 0 {
+ return nil
+ }
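+ // Split the remaining time into whole and fractional seconds for the
+ // socket receive timeout, e.g. 1.5s becomes Timeval{Sec: 1, Usec: 500000}.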
+ whole, frac := math.Modf(timeout.Seconds())
+ tv := unix.Timeval{
+ Sec: int64(whole),
+ Usec: int64(frac * float64(time.Second/time.Microsecond)),
+ }
+
+ if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
+ s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
+ }
+
+ buf := make([]byte, maxReadSize)
+ nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ // There was a timeout.
+ continue
+ }
+ if err != nil {
+ s.t.Fatalf("can't read: %s", err)
+ }
+ if nread > maxReadSize {
+ s.t.Fatalf("received a truncated frame of %d bytes", nread)
+ }
+ return buf[:nread]
+ }
+}
+
+// Drain drains the Sniffer's socket receive buffer by receiving until there's
+// nothing else to receive.
+func (s *Sniffer) Drain() {
+ s.t.Helper()
+ flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
+ if err != nil {
+ s.t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil {
+ s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err)
+ }
+ for {
+ buf := make([]byte, maxReadSize)
+ _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK {
+ break
+ }
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
+ s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err)
+ }
+}
+
+// close the socket that Sniffer is using.
+func (s *Sniffer) close() error {
+ if err := unix.Close(s.fd); err != nil {
+ return fmt.Errorf("can't close sniffer socket: %w", err)
+ }
+ s.fd = -1
+ return nil
+}
+
+// Injector can inject raw frames.
+type Injector struct {
+ t *testing.T
+ fd int
+}
+
+// NewInjector creates a new Injector bound to the configured Device.
+func NewInjector(t *testing.T) (Injector, error) {
+ ifInfo, err := net.InterfaceByName(Device)
+ if err != nil {
+ return Injector{}, err
+ }
+
+ var haddr [8]byte
+ copy(haddr[:], ifInfo.HardwareAddr)
+ sa := unix.SockaddrLinklayer{
+ Protocol: unix.ETH_P_IP,
+ Ifindex: ifInfo.Index,
+ Halen: uint8(len(ifInfo.HardwareAddr)),
+ Addr: haddr,
+ }
+
+ injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Injector{}, err
+ }
+ if err := unix.Bind(injectFd, &sa); err != nil {
+ return Injector{}, err
+ }
+ return Injector{
+ t: t,
+ fd: injectFd,
+ }, nil
+}
+
+// Send sends a raw frame.
+func (i *Injector) Send(b []byte) {
+ if _, err := unix.Write(i.fd, b); err != nil {
+ i.t.Fatalf("can't write: %s of len %d", err, len(b))
+ }
+}
+
+// close the underlying socket.
+func (i *Injector) close() error {
+ if err := unix.Close(i.fd); err != nil {
+ return fmt.Errorf("can't close sniffer socket: %w", err)
+ }
+ i.fd = -1
+ return nil
+}
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
new file mode 100644
index 000000000..d64f32a5b
--- /dev/null
+++ b/test/packetimpact/testbench/testbench.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "flag"
+ "fmt"
+ "math/rand"
+ "net"
+ "os/exec"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
+)
+
+var (
+ // DUTType is the type of device under test.
+ DUTType = ""
+ // Device is the local device on the test network.
+ Device = ""
+ // LocalIPv4 is the local IPv4 address on the test network.
+ LocalIPv4 = ""
+ // LocalIPv6 is the local IPv6 address on the test network.
+ LocalIPv6 = ""
+ // LocalMAC is the local MAC address on the test network.
+ LocalMAC = ""
+ // POSIXServerIP is the POSIX server's IP address on the control network.
+ POSIXServerIP = ""
+ // POSIXServerPort is the UDP port the POSIX server is bound to on the
+ // control network.
+ POSIXServerPort = 40000
+ // RemoteIPv4 is the DUT's IPv4 address on the test network.
+ RemoteIPv4 = ""
+ // RemoteIPv6 is the DUT's IPv6 address on the test network.
+ RemoteIPv6 = ""
+ // RemoteMAC is the DUT's MAC address on the test network.
+ RemoteMAC = ""
+ // RPCKeepalive is the gRPC keepalive.
+ RPCKeepalive = 10 * time.Second
+ // RPCTimeout is the gRPC timeout.
+ RPCTimeout = 100 * time.Millisecond
+)
+
+// RegisterFlags defines flags and associates them with the package-level
+// exported variables above. It should be called by tests in their init
+// functions.
+func RegisterFlags(fs *flag.FlagSet) {
+ fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands")
+ fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands")
+ fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout")
+ fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
+ fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets")
+ fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
+ fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
+ fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
+ fs.StringVar(&Device, "device", Device, "local device for test packets")
+ fs.StringVar(&DUTType, "dut_type", DUTType, "type of device under test")
+}
+
+// genPseudoFlags populates flag-like global config based on real flags.
+//
+// genPseudoFlags must only be called after flag.Parse.
+func genPseudoFlags() error {
+ out, err := exec.Command("ip", "addr", "show").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("listing devices: %q: %w", string(out), err)
+ }
+ devs, err := netdevs.ParseDevices(string(out))
+ if err != nil {
+ return fmt.Errorf("parsing devices: %w", err)
+ }
+
+ _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs)
+ if err != nil {
+ return fmt.Errorf("can't find deviceInfo: %w", err)
+ }
+
+ LocalMAC = deviceInfo.MAC.String()
+ LocalIPv6 = deviceInfo.IPv6Addr.String()
+
+ return nil
+}
+
+// GenerateRandomPayload generates a random byte slice of the specified length,
+// causing a fatal test failure if it is unable to do so.
+func GenerateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
new file mode 100644
index 000000000..3ecbe83eb
--- /dev/null
+++ b/test/packetimpact/tests/BUILD
@@ -0,0 +1,264 @@
+load("//test/packetimpact/runner:defs.bzl", "packetimpact_go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+packetimpact_go_test(
+ name = "fin_wait2_timeout",
+ srcs = ["fin_wait2_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "ipv4_id_uniqueness",
+ srcs = ["ipv4_id_uniqueness_test.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_recv_multicast",
+ srcs = ["udp_recv_multicast_test.go"],
+ # TODO(b/152813495): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_icmp_error_propagation",
+ srcs = ["udp_icmp_error_propagation_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_reordering",
+ srcs = ["tcp_reordering_test.go"],
+ # TODO(b/139368047): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_window_shrink",
+ srcs = ["tcp_window_shrink_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_zero_window_probe",
+ srcs = ["tcp_zero_window_probe_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_zero_window_probe_retransmit",
+ srcs = ["tcp_zero_window_probe_retransmit_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_zero_window_probe_usertimeout",
+ srcs = ["tcp_zero_window_probe_usertimeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_retransmits",
+ srcs = ["tcp_retransmits_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_outside_the_window",
+ srcs = ["tcp_outside_the_window_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_noaccept_close_rst",
+ srcs = ["tcp_noaccept_close_rst_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_send_window_sizes_piggyback",
+ srcs = ["tcp_send_window_sizes_piggyback_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_close_wait_ack",
+ srcs = ["tcp_close_wait_ack_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_paws_mechanism",
+ srcs = ["tcp_paws_mechanism_test.go"],
+ # TODO(b/156682000): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_user_timeout",
+ srcs = ["tcp_user_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_queue_receive_in_syn_sent",
+ srcs = ["tcp_queue_receive_in_syn_sent_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_synsent_reset",
+ srcs = ["tcp_synsent_reset_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_synrcvd_reset",
+ srcs = ["tcp_synrcvd_reset_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_cork_mss",
+ srcs = ["tcp_cork_mss_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_handshake_window_size",
+ srcs = ["tcp_handshake_window_size_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "icmpv6_param_problem",
+ srcs = ["icmpv6_param_problem_test.go"],
+ # TODO(b/153485026): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "ipv6_unknown_options_action",
+ srcs = ["ipv6_unknown_options_action_test.go"],
+ # TODO(b/159928940): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_send_recv_dgram",
+ srcs = ["udp_send_recv_dgram_test.go"],
+ deps = [
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
new file mode 100644
index 000000000..407565078
--- /dev/null
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fin_wait2_timeout_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestFinWait2Timeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ linger2 bool
+ }{
+ {"WithLinger2", true},
+ {"WithoutLinger2", false},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Connect()
+
+ acceptFd, _ := dut.Accept(listenFd)
+ if tt.linger2 {
+ tv := unix.Timeval{Sec: 1, Usec: 0}
+ dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv)
+ }
+ dut.Close(acceptFd)
+
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ time.Sleep(5 * time.Second)
+ conn.Drain()
+
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if tt.linger2 {
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected a RST packet within a second but got none: %s", err)
+ }
+ } else {
+ if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
+ t.Fatalf("expected no RST packets within ten seconds but got one: %s", got)
+ }
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
new file mode 100644
index 000000000..4d1d9a7f5
--- /dev/null
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmpv6_param_problem_test
+
+import (
+ "encoding/binary"
+ "flag"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT
+// should respond with an ICMPv6 Parameter Problem message.
+func TestICMPv6ParamProblemTest(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ defer conn.Close()
+ ipv6 := testbench.IPv6{
+ // 254 is reserved and used for experimentation and testing. This should
+ // cause an error.
+ NextHeader: testbench.Uint8(254),
+ }
+ icmpv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
+ NDPPayload: []byte("hello world"),
+ }
+
+ toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6)
+ (*testbench.Connection)(&conn).SendFrame(toSend)
+
+	// Build the expected ICMPv6 payload, which includes a pointer to the
+	// problematic byte and the invoking packet, as described in
+	// https://tools.ietf.org/html/rfc4443#page-12.
+ ipv6Sent := toSend[1:]
+ expectedPayload, err := ipv6Sent.ToBytes()
+ if err != nil {
+ t.Fatalf("can't convert %s to bytes: %s", ipv6Sent, err)
+ }
+
+ // The problematic field is the NextHeader.
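+	// RFC 4443 encodes the pointer as a 4-byte offset into the invoking
+	// packet; NextHeader sits at a fixed offset within the IPv6 header.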
+ b := make([]byte, 4)
+ binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset)
+ expectedPayload = append(b, expectedPayload...)
+ expectedICMPv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem),
+ NDPPayload: expectedPayload,
+ }
+
+ paramProblem := testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &expectedICMPv6,
+ }
+ timeout := time.Second
+ if _, err := conn.ExpectFrame(paramProblem, timeout); err != nil {
+ t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err)
+ }
+}
diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
new file mode 100644
index 000000000..70f6df5e0
--- /dev/null
+++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_id_uniqueness_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) {
+ layers, err := conn.ExpectData(expect, expectPayload, time.Second)
+ if err != nil {
+ return 0, fmt.Errorf("failed to receive TCP segment: %s", err)
+ }
+ if len(layers) < 2 {
+ return 0, fmt.Errorf("got packet with layers: %v, expected to have at least 2 layers (link and network)", layers)
+ }
+ ipv4, ok := layers[1].(*testbench.IPv4)
+ if !ok {
+ return 0, fmt.Errorf("got network layer: %T, expected: *IPv4", layers[1])
+ }
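+	// A cleared DF bit marks the datagram as non-atomic, which is what makes
+	// the ID-uniqueness requirement of RFC 6864 apply to retransmits.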
+ if *ipv4.Flags&header.IPv4FlagDontFragment != 0 {
+ return 0, fmt.Errorf("got IPv4 DF=1, expected DF=0")
+ }
+ return *ipv4.ID, nil
+}
+
+// RFC 6864 section 4.2 states: "The IPv4 ID of non-atomic datagrams MUST NOT
+// be reused when sending a copy of an earlier non-atomic datagram."
+//
+// This test creates a TCP connection, uses the IP_MTU_DISCOVER socket option
+// to force the DF bit to be 0, and checks that a retransmitted segment has a
+// different IPv4 Identification value than the original segment.
+func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ payload []byte
+ }{
+ {"SmallPayload", []byte("sample data")},
+ // 512 bytes is chosen because sending more than this in a single segment
+ // causes the retransmission to send less than the original amount.
+ {"512BytePayload", testbench.GenerateRandomPayload(t, 512)},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ remoteFD, _ := dut.Accept(listenFD)
+ defer dut.Close(remoteFD)
+
+ dut.SetSockOptInt(remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ // TODO(b/129291778) The following socket option clears the DF bit on
+ // IP packets sent over the socket, and is currently not supported by
+			// gVisor. gVisor by default sends packets with DF=0 anyway, so the
+			// socket option not being supported does not affect the operation of
+			// this test. Once the socket option is supported, the following call
+ // can be changed to simply assert success.
+ ret, errno := dut.SetSockOptIntWithErrno(context.Background(), remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT)
+ if ret == -1 && errno != unix.ENOTSUP {
+ t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno)
+ }
+
+ samplePayload := &testbench.Payload{Bytes: tc.payload}
+
+ dut.Send(remoteFD, tc.payload, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err)
+ }
+ // Let the DUT estimate RTO with RTT from the DATA-ACK.
+ // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
+ // we can skip sending this ACK.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ dut.Send(remoteFD, tc.payload, 0)
+ expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))}
+ originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload)
+ if err != nil {
+ t.Fatalf("failed to receive TCP segment: %s", err)
+ }
+
+ retransmitID, err := recvTCPSegment(&conn, expectTCP, samplePayload)
+ if err != nil {
+ t.Fatalf("failed to receive retransmitted TCP segment: %s", err)
+ }
+ if originalID == retransmitID {
+ t.Fatalf("unexpectedly got retransmitted TCP segment with same IPv4 ID field=%d", originalID)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
new file mode 100644
index 000000000..d301d8829
--- /dev/null
+++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
@@ -0,0 +1,187 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_unknown_options_action_test
+
+import (
+ "encoding/binary"
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
+func mkHopByHopOptionsExtHdr(optType byte) tb.Layer {
+ return &tb.IPv6HopByHopOptionsExtHdr{
+ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00},
+ }
+}
+
+func mkDestinationOptionsExtHdr(optType byte) tb.Layer {
+ return &tb.IPv6DestinationOptionsExtHdr{
+ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00},
+ }
+}
+
+func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte {
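+	// Per RFC 8200 section 4.2, the two highest-order bits of an option type
+	// tell a receiver what to do with an unrecognized option, so shift the
+	// action into bits 7 and 6.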
+ return byte(action << 6)
+}
+
+func TestIPv6UnknownOptionAction(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ mkExtHdr func(optType byte) tb.Layer
+ action header.IPv6OptionUnknownAction
+ multicastDst bool
+ wantICMPv6 bool
+ }{
+ {
+ description: "0b00/hbh",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionSkip,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b01/hbh",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscard,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b10/hbh/unicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b10/hbh/multicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: true,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/hbh/unicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/hbh/multicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: true,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b00/destination",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionSkip,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b01/destination",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscard,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b10/destination/unicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b10/destination/multicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: true,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/destination/unicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/destination/multicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: true,
+ wantICMPv6: false,
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ ipv6Conn := tb.NewIPv6Conn(t, tb.IPv6{}, tb.IPv6{})
+ conn := (*tb.Connection)(&ipv6Conn)
+ defer ipv6Conn.Close()
+
+ outgoingOverride := tb.Layers{}
+ if tt.multicastDst {
+ outgoingOverride = tb.Layers{&tb.IPv6{
+ DstAddr: tb.Address(tcpip.Address(net.ParseIP("ff02::1"))),
+ }}
+ }
+
+ outgoing := conn.CreateFrame(outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action)))
+ conn.SendFrame(outgoing)
+ ipv6Sent := outgoing[1:]
+ invokingPacket, err := ipv6Sent.ToBytes()
+ if err != nil {
+ t.Fatalf("failed to serialize the outgoing packet: %s", err)
+ }
+ icmpv6Payload := make([]byte, 4)
+ // The pointer in the ICMPv6 parameter problem message should point to
+ // the option type of the unknown option. In our test case, it is the
+ // first option in the extension header whose option type is 2 bytes
+ // after the IPv6 header (after NextHeader and ExtHdrLen).
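+			// That is byte 42 of the invoking packet: the 40-byte fixed IPv6
+			// header plus the two-byte extension-header preamble.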
+ binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2)
+ icmpv6Payload = append(icmpv6Payload, invokingPacket...)
+ gotICMPv6, err := ipv6Conn.ExpectFrame(tb.Layers{
+ &tb.Ether{},
+ &tb.IPv6{},
+ &tb.ICMPv6{
+ Type: tb.ICMPv6Type(header.ICMPv6ParamProblem),
+ Code: tb.Byte(2),
+ NDPPayload: icmpv6Payload,
+ },
+ }, time.Second)
+ if tt.wantICMPv6 && err != nil {
+ t.Fatalf("expected ICMPv6 Parameter Problem but got none: %s", err)
+ }
+ if !tt.wantICMPv6 && gotICMPv6 != nil {
+ t.Fatalf("expected no ICMPv6 Parameter Problem but got one: %s", gotICMPv6)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
new file mode 100644
index 000000000..6e7ff41d7
--- /dev/null
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_close_wait_ack_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestCloseWaitAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {"OTW", GenerateOTWSeqSegment, 0, false},
+ {"OTW", GenerateOTWSeqSegment, 1, true},
+ {"OTW", GenerateOTWSeqSegment, 2, true},
+ {"ACK", GenerateUnaccACKSegment, 0, false},
+ {"ACK", GenerateUnaccACKSegment, 1, true},
+ {"ACK", GenerateUnaccACKSegment, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+
+			// Send a FIN to the DUT to initiate the active close.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
+ }
+ windowSize := seqnum.Size(*gotTCP.WindowSize)
+
+ // Send a segment with OTW Seq / unacc ACK and expect an ACK back
+ conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")})
+ gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Fatalf("expected an ack but got none: %s", err)
+ }
+ if !tt.expectAck && gotAck != nil {
+ t.Fatalf("expected no ack but got one: %s", gotAck)
+ }
+
+			// Now verify that the DUT is indeed in CLOSE_WAIT.
+ dut.Close(acceptFd)
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send a FIN: %s", err)
+ }
+ // Ack the FIN from DUT
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Send some extra data to DUT
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")})
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send an RST: %s", err)
+ }
+ })
+ }
+}
+
+// GenerateOTWSeqSegment generates a segment with seqnum = RCV.NXT + RCV.WND +
+// seqNumOffset. The generated segment is only acceptable when seqNumOffset is
+// 0; otherwise an ACK is expected from the receiver.
+func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.LocalSeqNum().Add(windowSize)
+ otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
+ return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
+}
+
+// GenerateUnaccACKSegment generates a segment with acknum = SND.NXT +
+// seqNumOffset. The generated segment is only acceptable when seqNumOffset is
+// 0; otherwise an ACK is expected from the receiver.
+func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.RemoteSeqNum()
+ unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
+ return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
+}
diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go
new file mode 100644
index 000000000..fb8f48629
--- /dev/null
+++ b/test/packetimpact/tests/tcp_cork_mss_test.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_cork_mss_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPCorkMSS tests segment coalescing and splitting per the MSS.
+func TestTCPCorkMSS(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ const mss = uint32(header.TCPDefaultMSS)
+ options := make([]byte, header.TCPOptionMSSLength)
+ header.EncodeMSSOption(mss, options)
+ conn.ConnectWithOptions(options)
+
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1)
+
+	// Let the dut application send two small segments, which should be held up
+	// and coalesced until the application sends a larger segment that pushes
+	// the buffered data beyond the MSS.
+ sampleData := []byte("Sample Data")
+ dut.Send(acceptFD, sampleData, 0)
+ dut.Send(acceptFD, sampleData, 0)
+
+ expectedData := sampleData
+ expectedData = append(expectedData, sampleData...)
+ largeData := make([]byte, mss+1)
+ expectedData = append(expectedData, largeData...)
+ dut.Send(acceptFD, largeData, 0)
+
+	// Expect the segments to be coalesced and sent, capped to the MSS.
+ expectedPayload := testbench.Payload{Bytes: expectedData[:mss]}
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the coalesced segment to be split and transmitted.
+ expectedPayload = testbench.Payload{Bytes: expectedData[mss:]}
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Check for segments to *not* be held up because of TCP_CORK when
+ // the current send window is less than MSS.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
+ dut.Send(acceptFD, sampleData, 0)
+ dut.Send(acceptFD, sampleData, 0)
+ expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)}
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+}
diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go
new file mode 100644
index 000000000..652b530d0
--- /dev/null
+++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_handshake_window_size_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPHandshakeWindowSize tests that the stack honors the window size
+// communicated during the handshake.
+func TestTCPHandshakeWindowSize(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ // Start handshake with zero window size.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK: %s", err)
+ }
+ // Update the advertised window size to a non-zero value with the ACK that
+ // completes the handshake.
+ //
+ // Set the window size with MSB set and expect the dut to treat it as
+ // an unsigned value.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))})
+
+ acceptFd, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFd)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+	// Since we advertised a zero window followed by a non-zero window, expect
+	// the dut to honor the most recently advertised non-zero window and
+	// actually send out the data instead of sending zero-window probes.
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
new file mode 100644
index 000000000..b9b3e91d3
--- /dev/null
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_noaccept_close_rst_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestTcpNoAcceptCloseReset(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn.Connect()
+ defer conn.Close()
+ dut.Close(listenFd)
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ t.Fatalf("expected a RST-ACK packet but got none: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
new file mode 100644
index 000000000..ad8c74234
--- /dev/null
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -0,0 +1,93 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_outside_the_window_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPOutsideTheWindow tests the behavior of the DUT when packets arrive
+// that are inside or outside the TCP window. Packets that are outside the
+// window should force an extra ACK, as described in RFC793 page 69:
+// https://tools.ietf.org/html/rfc793#page-69
+func TestTCPOutsideTheWindow(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ tcpFlags uint8
+ payload []testbench.Layer
+ seqNumOffset seqnum.Size
+ expectACK bool
+ }{
+ {"SYN", header.TCPFlagSyn, nil, 0, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true},
+ {"ACK", header.TCPFlagAck, nil, 0, false},
+ {"FIN", header.TCPFlagFin, nil, 0, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 0, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 1, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true},
+ {"ACK", header.TCPFlagAck, nil, 1, true},
+ {"FIN", header.TCPFlagFin, nil, 1, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 1, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 2, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true},
+ {"ACK", header.TCPFlagAck, nil, 2, true},
+ {"FIN", header.TCPFlagFin, nil, 2, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Connect()
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset
+ conn.Drain()
+			// Ignore any increment that this out-of-order packet might cause to
+			// the AckNum.
+ localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum()))
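+			// From the DUT's perspective the segment below begins at RCV.NXT +
+			// RCV.WND + seqNumOffset, i.e. at or beyond the right edge of its
+			// receive window.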
+ conn.Send(testbench.TCP{
+ Flags: testbench.Uint8(tt.tcpFlags),
+ SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))),
+ }, tt.payload...)
+ timeout := 3 * time.Second
+ gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ if tt.expectACK && err != nil {
+ t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
+ }
+ if !tt.expectACK && gotACK != nil {
+ t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go
new file mode 100644
index 000000000..55db4ece6
--- /dev/null
+++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_paws_mechanism_test
+
+import (
+ "encoding/hex"
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestPAWSMechanism(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ options := make([]byte, header.TCPOptionTSLength)
+ header.EncodeTSOption(currentTS(), 0, options)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options})
+ synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ parsedSynOpts := header.ParseSynOptions(synAck.Options, true)
+ if !parsedSynOpts.TS {
+ t.Fatalf("expected TSOpt from DUT, options we got:\n%s", hex.Dump(synAck.Options))
+ }
+ tsecr := parsedSynOpts.TSVal
+ header.EncodeTSOption(currentTS(), tsecr, options)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options})
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ sampleData := []byte("Sample Data")
+ sentTSVal := currentTS()
+ header.EncodeTSOption(sentTSVal, tsecr, options)
+	// 3ms here is chosen arbitrarily to make sure we have increasing timestamps
+	// every time we send one; it should not cause any flakiness because
+	// timestamps only need to be non-decreasing.
+ time.Sleep(3 * time.Millisecond)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+
+ gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected an ACK but got none: %s", err)
+ }
+
+ parsedOpts := header.ParseTCPOptions(gotTCP.Options)
+ if !parsedOpts.TS {
+ t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options))
+ }
+ if parsedOpts.TSVal < tsecr {
+ t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr)
+ }
+ if parsedOpts.TSEcr != sentTSVal {
+ t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal)
+ }
+ tsecr = parsedOpts.TSVal
+ lastAckNum := gotTCP.AckNum
+
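+	// A TSVal smaller than the most recently echoed one fails the PAWS check
+	// (RFC 7323), so the DUT should discard the payload and ACK the old
+	// sequence number.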
+ badTSVal := sentTSVal - 100
+ header.EncodeTSOption(badTSVal, tsecr, options)
+ // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness
+ // due to the exact same reasoning discussed above.
+ time.Sleep(3 * time.Millisecond)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+
+ gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err)
+ }
+ parsedOpts = header.ParseTCPOptions(gotTCP.Options)
+ if !parsedOpts.TS {
+ t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options))
+ }
+ if parsedOpts.TSVal < tsecr {
+ t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr)
+ }
+ if parsedOpts.TSEcr != sentTSVal {
+ t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal)
+ }
+}
+
+func currentTS() uint32 {
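+	// Milliseconds since the Unix epoch; RFC 7323 only requires the timestamp
+	// clock to tick somewhere between once per millisecond and once per second.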
+ return uint32(time.Now().UnixNano() / 1e6)
+}
diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
new file mode 100644
index 000000000..8fbec893b
--- /dev/null
+++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_queue_receive_in_syn_sent_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/hex"
+ "errors"
+ "flag"
+ "net"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestQueueReceiveInSynSent tests receive behavior when the TCP state
+// is SYN-SENT.
+// It tests two variants where the receive is blocked, and then:
+// (1) we complete the handshake and send sample data, or
+// (2) we send a TCP RST.
+func TestQueueReceiveInSynSent(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ reset bool
+ }{
+ {description: "Send DATA", reset: false},
+ {description: "Send RST", reset: true},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ sampleData := []byte("Sample Data")
+
+ dut.SetNonBlocking(socket, true)
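+			// A non-blocking connect returns EINPROGRESS while the SYN is still
+			// in flight, leaving the DUT parked in SYN-SENT.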
+ if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) {
+ t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
+ }
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ t.Fatalf("expected a SYN from DUT, but got none: %s", err)
+ }
+
+ if _, _, err := dut.RecvWithErrno(context.Background(), socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) {
+ t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err)
+ }
+
+ // Test blocking read.
+ dut.SetNonBlocking(socket, false)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ wg.Add(1)
+ var block sync.WaitGroup
+ block.Add(1)
+ go func() {
+ defer wg.Done()
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
+ defer cancel()
+
+ block.Done()
+				// Issue the RECEIVE call in SYN-SENT; it should be queued for
+				// processing until the connection is established.
+ n, buff, err := dut.RecvWithErrno(ctx, socket, int32(len(sampleData)), 0)
+ if tt.reset {
+ if err != syscall.Errno(unix.ECONNREFUSED) {
+ t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err)
+ }
+ if n != -1 {
+ t.Errorf("expected return value %d, got %d", -1, n)
+ }
+ return
+ }
+ if n == -1 {
+ t.Errorf("failed to recv on DUT: %s", err)
+ }
+ if got := buff[:n]; !bytes.Equal(got, sampleData) {
+ t.Errorf("received data doesn't match, got:\n%s, want:\n%s", hex.Dump(got), hex.Dump(sampleData))
+ }
+ }()
+
+			// Wait for the goroutine to be scheduled, just before it
+			// blocks on the endpoint receive.
+ block.Wait()
+ // The following sleep is used to prevent the connection
+ // from being established before we are blocked on Recv.
+ time.Sleep(100 * time.Millisecond)
+
+ if tt.reset {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ return
+ }
+
+ // Bring the connection to Established.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK from DUT, but got none: %s", err)
+ }
+
+ // Send sample payload and expect an ACK.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK from DUT, but got none: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go
new file mode 100644
index 000000000..a5378a9dd
--- /dev/null
+++ b/test/packetimpact/tests/tcp_reordering_test.go
@@ -0,0 +1,174 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reordering_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
+func TestReorderingWindow(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ // Enable SACK.
+ opts := make([]byte, 40)
+ optsOff := 0
+ optsOff += header.EncodeNOP(opts[optsOff:])
+ optsOff += header.EncodeNOP(opts[optsOff:])
+ optsOff += header.EncodeSACKPermittedOption(opts[optsOff:])
+
+ // Ethernet guarantees that the MTU is at least 1500 bytes.
+ const minMTU = 1500
+ const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize
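+	// With minimum 20-byte IPv4 and TCP headers, this works out to an MSS of
+	// 1460 bytes.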
+ optsOff += header.EncodeMSSOption(mss, opts[optsOff:])
+
+ conn.ConnectWithOptions(opts[:optsOff])
+
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ if tb.DUTType == "linux" {
+		// Linux has changed its handling of reordering; force the old behavior.
+ dut.SetSockOpt(acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno"))
+ }
+
+ pls := dut.GetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG)
+ if tb.DUTType == "netstack" {
+		// netstack does not implement TCP_MAXSEG correctly. Fake it
+ // here. Netstack uses the max SACK size which is 32. The MSS
+ // option is 8 bytes, making the total 36 bytes.
+ pls = mss - 36
+ }
+
+ payload := make([]byte, pls)
+
+ seqNum1 := *conn.RemoteSeqNum()
+ const numPkts = 10
+ // Send some packets, checking that we receive each.
+ for i, sn := 0, seqNum1; i < numPkts; i++ {
+ dut.Send(acceptFd, payload, 0)
+
+ gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ sn.UpdateForward(seqnum.Size(len(payload)))
+ if err != nil {
+ t.Errorf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ }
+ }
+
+ seqNum2 := *conn.RemoteSeqNum()
+
+ // SACK packets #2-4.
+ sackBlock := make([]byte, 40)
+ sbOff := 0
+ sbOff += header.EncodeNOP(sackBlock[sbOff:])
+ sbOff += header.EncodeNOP(sackBlock[sbOff:])
+ sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
+ seqNum1.Add(seqnum.Size(len(payload))),
+ seqNum1.Add(seqnum.Size(4 * len(payload))),
+ }}, sackBlock[sbOff:])
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+
+ // ACK first packet.
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))})
+
+ // Check for retransmit.
+ gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second)
+ if err != nil {
+ t.Error("Expect for retransmit:", err)
+ }
+ if gotOne == nil {
+ t.Error("expected a retransmitted packet within a second but got none")
+ }
+
+ // ACK all send packets with a DSACK block for packet #1. This tells
+ // the other end that we got both the original and retransmit for
+ // packet #1.
+ dsackBlock := make([]byte, 40)
+ dsbOff := 0
+ dsbOff += header.EncodeNOP(dsackBlock[dsbOff:])
+ dsbOff += header.EncodeNOP(dsackBlock[dsbOff:])
+ dsbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
+ seqNum1.Add(seqnum.Size(len(payload))),
+ seqNum1.Add(seqnum.Size(4 * len(payload))),
+ }}, dsackBlock[dsbOff:])
+
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]})
+
+ // Send half of the original window of packets, checking that we
+ // received each.
+ for i, sn := 0, seqNum2; i < numPkts/2; i++ {
+ dut.Send(acceptFd, payload, 0)
+
+ gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ sn.UpdateForward(seqnum.Size(len(payload)))
+ if err != nil {
+ t.Errorf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ }
+ }
+
+ if tb.DUTType == "netstack" {
+		// The window should now be halved, so we should not receive any
+		// more packets, even if we send them.
+ dut.Send(acceptFd, payload, 0)
+ if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
+ }
+ return
+ }
+
+ // Linux reduces the window by three. Check that we can receive the rest.
+ for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ {
+ dut.Send(acceptFd, payload, 0)
+
+ gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ sn.UpdateForward(seqnum.Size(len(payload)))
+ if err != nil {
+ t.Errorf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ }
+ }
+
+ // The window should now be full.
+ dut.Send(acceptFd, payload, 0)
+ if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
new file mode 100644
index 000000000..6940eb7fb
--- /dev/null
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_retransmits_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestRetransmits tests that retransmits occur at exponentially increasing
+// time intervals.
+func TestRetransmits(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
+ // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
+ // we can skip sending this ACK.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ startRTO := time.Second
+ current := startRTO
+ first := time.Now()
+ dut.Send(acceptFd, sampleData, 0)
+ seq := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ // Expect retransmits of the same segment.
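+	// Each retransmission interval should be roughly double the previous one
+	// (binary exponential backoff, RFC 6298).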
+ for i := 0; i < 5; i++ {
+ start := time.Now()
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil {
+ t.Fatalf("expected payload was not received: %s loop %d", err, i)
+ }
+ if i == 0 {
+ startRTO = time.Now().Sub(first)
+ current = 2 * startRTO
+ continue
+ }
+ // Check if the probes came at exponentially increasing intervals.
+ if p := time.Since(start); p < current-startRTO {
+ t.Fatalf("retransmit came sooner interval %d probe %d", p, i)
+ }
+ current *= 2
+ }
+}
diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
new file mode 100644
index 000000000..90ab85419
--- /dev/null
+++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_send_window_sizes_piggyback_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestSendWindowSizesPiggyback tests cases where segment sizes are close to
+// the sender window size and checks for ACK piggybacking in each of those
+// cases.
+func TestSendWindowSizesPiggyback(t *testing.T) {
+ sampleData := []byte("Sample Data")
+ segmentSize := uint16(len(sampleData))
+	// Advertise receive window sizes that are smaller than, equal to, or
+	// greater than the enqueued segment size, and check for segment transmits.
+	// The test attempts to enqueue a segment on the dut before acknowledging
+	// the previous segment and lets the dut piggyback any ACKs along with the
+	// enqueued segment.
+ for _, tt := range []struct {
+ description string
+ windowSize uint16
+ expectedPayload1 []byte
+ expectedPayload2 []byte
+ enqueue bool
+ }{
+		// Expect the first segment to be split as it cannot be accommodated in
+ // the sender window. This means we need not enqueue a new segment after
+ // the first segment.
+ {"WindowSmallerThanSegment", segmentSize - 1, sampleData[:(segmentSize - 1)], sampleData[(segmentSize - 1):], false /* enqueue */},
+
+ {"WindowEqualToSegment", segmentSize, sampleData, sampleData, true /* enqueue */},
+
+ // Expect the second segment to not be split as its size is greater than
+ // the available sender window size. The segments should not be split
+ // when there is pending unacknowledged data and the segment-size is
+ // greater than available sender window.
+ {"WindowGreaterThanSegment", segmentSize + 1, sampleData, sampleData, true /* enqueue */},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
+
+ dut.Send(acceptFd, sampleData, 0)
+ expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1}
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Expect any enqueued segment to be transmitted by the dut along with
+ // piggybacked ACK for our data.
+
+ if tt.enqueue {
+ // Enqueue a segment for the dut to transmit.
+ dut.Send(acceptFd, sampleData, 0)
+ }
+
+ // Send ACK for the previous segment along with data for the dut to
+ // receive and ACK back. Sending this ACK would make room for the dut
+ // to transmit any enqueued segment.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
+
+ // Expect the dut to piggyback the ACK for received data along with
+ // the segment enqueued for transmit.
+ expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2}
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
new file mode 100644
index 000000000..7d5deab01
--- /dev/null
+++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_syn_reset_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPSynRcvdReset tests the transition from SYN-RCVD to CLOSED.
+func TestTCPSynRcvdReset(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+	// Expect the dut connection to transition to the SYN-RCVD state.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+ // Expect the connection to have transitioned SYN-RCVD to CLOSED.
+ // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go
new file mode 100644
index 000000000..6898a2239
--- /dev/null
+++ b/test/packetimpact/tests/tcp_synsent_reset_test.go
@@ -0,0 +1,88 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_synsent_reset_test
+
+import (
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
+// dutSynSentState sets up the dut connection in SYN-SENT state.
+func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) {
+ dut := tb.NewDUT(t)
+
+ clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4))
+ port := uint16(9001)
+ conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port})
+
+ sa := unix.SockaddrInet4{Port: int(port)}
+ copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4())
+ // Bring the dut to SYN-SENT state with a non-blocking connect.
+ dut.Connect(clientFD, &sa)
+ if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN\n")
+ }
+
+ return &dut, &conn, port, clientPort
+}
+
+// TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition.
+func TestTCPSynSentReset(t *testing.T) {
+ dut, conn, _, _ := dutSynSentState(t)
+ defer conn.Close()
+ defer dut.TearDown()
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ // Expect the connection to have closed.
+ // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST")
+ }
+}
+
+// TestTCPSynSentRcvdReset tests RFC793, p70, SYN-SENT to SYN-RCVD to CLOSED
+// transitions.
+func TestTCPSynSentRcvdReset(t *testing.T) {
+ dut, c, remotePort, clientPort := dutSynSentState(t)
+ defer dut.TearDown()
+ defer c.Close()
+
+ conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort})
+ defer conn.Close()
+ // Initiate a new SYN connection on the same port pair (the simultaneous
+ // open case of RFC 793, p70) and expect the dut connection to move to the
+ // SYN-RCVD state.
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK %s\n", err)
+ }
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)})
+ // Expect the connection to have transitioned from SYN-RCVD to CLOSED.
+ // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST")
+ }
+}
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
new file mode 100644
index 000000000..87e45d765
--- /dev/null
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_user_timeout_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error {
+ sampleData := make([]byte, 100)
+ for i := range sampleData {
+ sampleData[i] = uint8(i)
+ }
+ conn.Drain()
+ dut.Send(fd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
+ return fmt.Errorf("expected data but got none: %w", err)
+ }
+ return nil
+}
+
+func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error {
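+ // Closing the dut socket queues a FIN, which the testbench deliberately
+ // never ACKs; the dut then retransmits the FIN until TCP_USER_TIMEOUT
+ // (when set) expires.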
+ dut.Close(fd)
+ return nil
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ userTimeout time.Duration
+ sendDelay time.Duration
+ }{
+ {"NoUserTimeout", 0, 3 * time.Second},
+ {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second},
+ {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second},
+ } {
+ for _, ttf := range []struct {
+ description string
+ f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error
+ }{
+ {"AfterPayload", sendPayload},
+ {"AfterFIN", sendFIN},
+ } {
+ t.Run(tt.description+ttf.description, func(t *testing.T) {
+ // Create a listening socket, complete the TCP handshake, and accept.
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Connect()
+ acceptFD, _ := dut.Accept(listenFD)
+
+ if tt.userTimeout != 0 {
+ dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds()))
+ }
+
+ if err := ttf.f(&conn, &dut, acceptFD); err != nil {
+ t.Fatal(err)
+ }
+
+ time.Sleep(tt.sendDelay)
+ conn.Drain()
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ // If TCP_USER_TIMEOUT was set and the above delay was longer than the
+ // TCP_USER_TIMEOUT then the DUT should send a RST in response to the
+ // testbench's packet.
+ expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout
+ expectTimeout := 5 * time.Second
+ got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout)
+ if expectRST && err != nil {
+ t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err)
+ }
+ if !expectRST && got != nil {
+ t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got)
+ }
+ })
+ }
+ }
+}
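+
+// shouldExpectRST is a hypothetical restatement of the check above: the dut
+// should abort the connection (and answer later segments with a RST) only
+// when TCP_USER_TIMEOUT was set and the testbench stayed silent for longer
+// than that timeout.
+func shouldExpectRST(userTimeout, sendDelay time.Duration) bool {
+ return userTimeout != 0 && sendDelay > userTimeout
+}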
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
new file mode 100644
index 000000000..e78d04756
--- /dev/null
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -0,0 +1,73 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_window_shrink_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestWindowShrink(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ dut.Send(acceptFd, sampleData, 0)
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ // We close our receiving window here
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+
+ dut.Send(acceptFd, sampleData, 0)
+ // Note: there is another kind of zero-window probing, used by Windows, that
+ // sends one new byte at `RemoteSeqNum`; if netstack ever adopts that style,
+ // the following lines will need to change. Both candidate probe styles are
+ // sketched at the end of this file.
+ expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err)
+ }
+}
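+
+// probeSeqCandidates is a hypothetical helper (the name is not part of the
+// testbench API) illustrating the two zero-window probe styles noted above:
+// netstack probes with a sequence number one below the receiver's next
+// expected byte, while a Windows-style probe would instead carry one new
+// byte at the expected sequence number itself.
+func probeSeqCandidates(conn *testbench.TCPIPv4) (netstackStyle, windowsStyle uint32) {
+ next := uint32(*conn.RemoteSeqNum())
+ return next - 1, next
+}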
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
new file mode 100644
index 000000000..8c89d57c9
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_window_probe_retransmit_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestZeroWindowProbeRetransmit tests that zero window probes are
+// retransmitted at exponentially increasing time intervals.
+func TestZeroWindowProbeRetransmit(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ // Send and receive sample data to the dut.
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+
+ // Check that the dut keeps the connection alive as long as the zero window
+ // probes are acknowledged, and that the probes are sent at exponentially
+ // increasing intervals. The timeout intervals are a function of the
+ // recorded first zero probe transmission duration.
+ //
+ // Advertise a zero receive window again.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
+
+ startProbeDuration := time.Second
+ current := startProbeDuration
+ first := time.Now()
+ // Ask the dut to send out data.
+ dut.Send(acceptFd, sampleData, 0)
+ // Expect the dut to keep the connection alive as long as the remote is
+ // acknowledging the zero-window probes.
+ for i := 0; i < 5; i++ {
+ start := time.Now()
+ // Expect a zero-window probe within a timeout that is a function of the
+ // measured first retransmission time. The retransmission intervals are
+ // expected to increase exponentially.
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
+ t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i)
+ }
+ if i == 0 {
+ startProbeDuration = time.Since(first)
+ current = 2 * startProbeDuration
+ continue
+ }
+ // Check if the probes came at exponentially increasing intervals.
+ if got, want := time.Since(start), current-startProbeDuration; got < want {
+ t.Errorf("got zero probe %d after %s, want >= %s", i, got, want)
+ }
+ // Acknowledge the zero-window probes from the dut.
+ conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ current *= 2
+ }
+ // Advertise a non-zero window.
+ conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the dut to recover and transmit data.
+ if _, err := conn.ExpectData(&testbench.
+ TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+}
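+
+// expectedProbeFloor is a hypothetical helper restating the backoff check in
+// the loop above: given the measured first-probe interval d, probe i (for
+// i >= 1) is expected no sooner than 2^i*d - d after the previous probe was
+// observed.
+func expectedProbeFloor(d time.Duration, i int) time.Duration {
+ return time.Duration(1<<uint(i))*d - d
+}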
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
new file mode 100644
index 000000000..649fd5699
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -0,0 +1,112 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_window_probe_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestZeroWindowProbe tests a few cases of zero window probing over the
+// same connection.
+func TestZeroWindowProbe(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ start := time.Now()
+ // Send and receive sample data to the dut.
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ sendTime := time.Since(start)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+
+ // Test 1: Check for receipt of a zero window probe and record how long it
+ // takes for the probe to be sent.
+ //
+ // Advertise a zero window to the dut.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+
+ // Expected sequence number of the zero window probe.
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ // Expected ack number of the ACK for the probe.
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
+
+ // Expect that no zero-window probes are sent until there is data to be
+ // sent out from the dut.
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil {
+ t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err)
+ }
+
+ start = time.Now()
+ // Ask the dut to send out data.
+ dut.Send(acceptFd, sampleData, 0)
+ // Expect zero-window probe from the dut.
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err)
+ }
+ // Expect the probe to be sent after some time. Compare against the previous
+ // time recorded when the dut immediately sends out data on receiving the
+ // send command.
+ if startProbeDuration := time.Since(start); startProbeDuration <= sendTime {
+ t.Fatalf("expected the first probe to be sent out after retransmission interval, got %s want > %s", startProbeDuration, sendTime)
+ }
+
+ // Test 2: Check that the dut recovers on advertising a non-zero receive
+ // window and sends out the sample payload after the send window opens.
+ //
+ // Advertise a non-zero window to the dut and ack the zero window probe.
+ conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the dut to recover and transmit data.
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Test 3: Sanity check the dut's processing of a probe like the one it
+ // sent, i.e. a segment whose sequence number is one byte behind the
+ // unacknowledged sequence number. Check that the dut responds to such a
+ // probe with an ACK, just as the testbench did above.
+ p := testbench.Uint32(uint32(*conn.LocalSeqNum()))
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with ack number: %d: %s", p, err)
+ }
+}
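+
+// zeroWindowProbe is a hypothetical helper capturing the probe shape used in
+// Test 3 above: a bare ACK whose sequence number sits one byte behind the
+// unacknowledged sequence number, which should elicit an ACK carrying the
+// peer's current expected sequence number.
+func zeroWindowProbe(conn *testbench.TCPIPv4) testbench.TCP {
+ return testbench.TCP{
+ Flags: testbench.Uint8(header.TCPFlagAck),
+ SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1)),
+ }
+}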
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
new file mode 100644
index 000000000..3c467b14f
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -0,0 +1,98 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_window_probe_usertimeout_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestZeroWindowProbeUserTimeout sanity-tests the user timeout while zero
+// window probes are being retransmitted.
+func TestZeroWindowProbeUserTimeout(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ // Send and receive sample data to the dut.
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+
+ // Test 1: Check for receipt of a zero window probe and record how long it
+ // takes for the probe to be sent.
+ //
+ // Advertise a zero window to the dut.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+
+ // Expected sequence number of the zero window probe.
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ start := time.Now()
+ // Ask the dut to send out data.
+ dut.Send(acceptFd, sampleData, 0)
+ // Expect zero-window probe from the dut.
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err)
+ }
+ // Record the duration of the first probe; the dut sends the zero window
+ // probe after a retransmission time interval.
+ startProbeDuration := time.Since(start)
+
+ // Test 2: Check that the dut times out the connection by honoring the
+ // user timeout while it is sending zero-window probes.
+ //
+ // Set the user timeout to the duration of the first probe.
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds()))
+ // Advertise a zero window again.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ // Ask the dut to send out data that would trigger zero window probe retransmissions.
+ dut.Send(acceptFd, sampleData, 0)
+
+ // Wait for the connection to time out after multiple zero-window probe
+ // retransmissions: with TCP_USER_TIMEOUT set to the first-probe interval,
+ // a couple of probe retransmissions should exceed it, and 8x that interval
+ // leaves comfortable slack.
+ time.Sleep(8 * startProbeDuration)
+
+ // Expect the connection to have timed out and closed which would cause the dut
+ // to reply with a RST to the ACK we send.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST")
+ }
+}
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
new file mode 100644
index 000000000..b754918f6
--- /dev/null
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -0,0 +1,365 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_icmp_error_propagation_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "net"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+type connectionMode bool
+
+func (c connectionMode) String() string {
+ if c {
+ return "Connected"
+ }
+ return "Connectionless"
+}
+
+type icmpError int
+
+const (
+ portUnreachable icmpError = iota
+ timeToLiveExceeded
+)
+
+func (e icmpError) String() string {
+ switch e {
+ case portUnreachable:
+ return "PortUnreachable"
+ case timeToLiveExceeded:
+ return "TimeToLiveExpired"
+ }
+ return "Unknown ICMP error"
+}
+
+func (e icmpError) ToICMPv4() *testbench.ICMPv4 {
+ switch e {
+ case portUnreachable:
+ return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4PortUnreachable)}
+ case timeToLiveExceeded:
+ return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), Code: testbench.Uint8(header.ICMPv4TTLExceeded)}
+ }
+ return nil
+}
+
+type errorDetection struct {
+ name string
+ useValidConn bool
+ f func(context.Context, testData) error
+}
+
+type testData struct {
+ dut *testbench.DUT
+ conn *testbench.UDPIPv4
+ remoteFD int32
+ remotePort uint16
+ cleanFD int32
+ cleanPort uint16
+ wantErrno syscall.Errno
+}
+
+// wantErrno computes the errno to expect given the connection mode of a UDP
+// socket and the ICMP error it will receive.
+func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno {
+ if c && icmpErr == portUnreachable {
+ return syscall.Errno(unix.ECONNREFUSED)
+ }
+ return syscall.Errno(0)
+}
+
+// sendICMPError sends an ICMP error message in response to a UDP datagram.
+func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error {
+ layers := (*testbench.Connection)(conn).CreateFrame(nil)
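+ // Drop the trailing UDP layer of the frame being built so that the ICMP
+ // error, followed by the offending IP and UDP headers, can take its
+ // place: RFC 792 requires an ICMP error to carry the IP header of the
+ // original datagram plus at least its first eight bytes.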
+ layers = layers[:len(layers)-1]
+ ip, ok := udp.Prev().(*testbench.IPv4)
+ if !ok {
+ return fmt.Errorf("expected %s to be IPv4", udp.Prev())
+ }
+ if icmpErr == timeToLiveExceeded {
+ *ip.TTL = 1
+ // Let serialization recalculate the checksum since we set the TTL
+ // to 1.
+ ip.Checksum = nil
+ }
+ // Note that the ICMP payload is valid in this case because the UDP
+ // payload is empty. If the UDP payload were not empty, the packet
+ // length during serialization may not be calculated correctly,
+ // resulting in a malformed packet.
+ layers = append(layers, icmpErr.ToICMPv4(), ip, udp)
+
+ (*testbench.Connection)(conn).SendFrameStateless(layers)
+ return nil
+}
+
+// testRecv tests observing the ICMP error through the recv syscall. A packet
+// is sent to the DUT, and if wantErrno is non-zero, then the first recv should
+// fail and the second should succeed. Otherwise, the first recv should
+// succeed immediately.
+func testRecv(ctx context.Context, d testData) error {
+ // Check that receiving on the clean socket works.
+ d.conn.Send(testbench.UDP{DstPort: &d.cleanPort})
+ d.dut.Recv(d.cleanFD, 100, 0)
+
+ d.conn.Send(testbench.UDP{})
+
+ if d.wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(ctx, time.Second)
+ defer cancel()
+ ret, _, err := d.dut.RecvWithErrno(ctx, d.remoteFD, 100, 0)
+ if ret != -1 {
+ return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
+ }
+ if err != d.wantErrno {
+ return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
+ }
+ }
+
+ d.dut.Recv(d.remoteFD, 100, 0)
+ return nil
+}
+
+// testSendTo tests observing the ICMP error through the send syscall. If
+// wantErrno is non-zero, the first send should fail and a subsequent send
+// should succeed; if wantErrno is zero, the first send should simply
+// succeed.
+func testSendTo(ctx context.Context, d testData) error {
+ // Check that sending on the clean socket works.
+ d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr())
+ if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil {
+ return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err)
+ }
+
+ if d.wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(ctx, time.Second)
+ defer cancel()
+ ret, err := d.dut.SendToWithErrno(ctx, d.remoteFD, nil, 0, d.conn.LocalAddr())
+
+ if ret != -1 {
+ return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
+ }
+ if err != d.wantErrno {
+ return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
+ }
+ }
+
+ d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr())
+ if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil {
+ return fmt.Errorf("did not receive UDP packet as expected: %s", err)
+ }
+ return nil
+}
+
+func testSockOpt(_ context.Context, d testData) error {
+ // Check that there's no pending error on the clean socket.
+ if errno := syscall.Errno(d.dut.GetSockOptInt(d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) {
+ return fmt.Errorf("unexpected error (%[1]d) %[1]v on clean socket", errno)
+ }
+
+ if errno := syscall.Errno(d.dut.GetSockOptInt(d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno {
+ return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno)
+ }
+
+ // Check that after clearing socket error, sending doesn't fail.
+ d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr())
+ if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil {
+ return fmt.Errorf("did not receive UDP packet as expected: %s", err)
+ }
+ return nil
+}
+
+// TestUDPICMPErrorPropagation tests that ICMP error messages in response to
+// UDP datagrams are processed correctly. RFC 1122 section 4.1.3.3 states that:
+// "UDP MUST pass to the application layer all ICMP error messages that it
+// receives from the IP layer."
+//
+// The test cases are parametrized in 3 dimensions: 1. the UDP socket is either
+// put into connection mode or left connectionless, 2. the ICMP message type
+// and code, and 3. the method by which the ICMP error is observed on the
+// socket: sendto, recv, or getsockopt(SO_ERROR).
+//
+// Linux's udp(7) man page states: "All fatal errors will be passed to the user
+// as an error return even when the socket is not connected. This includes
+// asynchronous errors received from the network." In practice, the only
+// combination of parameters to the test that causes an error to be observable
+// on the UDP socket is receiving a port unreachable message on a connected
+// socket.
+func TestUDPICMPErrorPropagation(t *testing.T) {
+ for _, connect := range []connectionMode{true, false} {
+ for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} {
+ wantErrno := wantErrno(connect, icmpErr)
+
+ for _, errDetect := range []errorDetection{
+ errorDetection{"SendTo", false, testSendTo},
+ // Send to an address that's different from the one that caused an ICMP
+ // error to be returned.
+ errorDetection{"SendToValid", true, testSendTo},
+ errorDetection{"Recv", false, testRecv},
+ errorDetection{"SockOpt", false, testSockOpt},
+ } {
+ t.Run(fmt.Sprintf("%s/%s/%s", connect, icmpErr, errDetect.name), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(remoteFD)
+
+ // Create a second, clean socket on the DUT to ensure that the ICMP
+ // error messages only affect the sockets they are intended for.
+ cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(cleanFD)
+
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ if connect {
+ dut.Connect(remoteFD, conn.LocalAddr())
+ dut.Connect(cleanFD, conn.LocalAddr())
+ }
+
+ dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
+ udp, err := conn.Expect(testbench.UDP{}, time.Second)
+ if err != nil {
+ t.Fatalf("did not receive message from DUT: %s", err)
+ }
+
+ if err := sendICMPError(&conn, icmpErr, udp); err != nil {
+ t.Fatal(err)
+ }
+
+ errDetectConn := &conn
+ if errDetect.useValidConn {
+ // connClean is a UDP socket on the test runner that was not
+ // involved in the generation of the ICMP error. As such,
+ // interactions between it and the DUT should be independent of
+ // the ICMP error at least at the port level.
+ connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer connClean.Close()
+
+ errDetectConn = &connClean
+ }
+
+ if err := errDetect.f(context.Background(), testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+ }
+ }
+}
+
+// TestICMPErrorDuringUDPRecv tests behavior when a UDP socket is in the middle
+// of a blocking recv and receives an ICMP error.
+func TestICMPErrorDuringUDPRecv(t *testing.T) {
+ for _, connect := range []connectionMode{true, false} {
+ for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} {
+ wantErrno := wantErrno(connect, icmpErr)
+
+ t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(remoteFD)
+
+ // Create a second, clean socket on the DUT to ensure that the ICMP
+ // error messages only affect the sockets they are intended for.
+ cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(cleanFD)
+
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ if connect {
+ dut.Connect(remoteFD, conn.LocalAddr())
+ dut.Connect(cleanFD, conn.LocalAddr())
+ }
+
+ dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
+ udp, err := conn.Expect(testbench.UDP{}, time.Second)
+ if err != nil {
+ t.Fatalf("did not receive message from DUT: %s", err)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ if wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0)
+ if ret != -1 {
+ t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
+ return
+ }
+ if err != wantErrno {
+ t.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
+ return
+ }
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 {
+ t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 {
+ t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err)
+ }
+ }()
+
+ // TODO(b/155684889) This sleep is to allow time for the DUT to
+ // actually call recv since we want the ICMP error to arrive during the
+ // blocking recv, and should be replaced when a better synchronization
+ // alternative is available.
+ time.Sleep(2 * time.Second)
+
+ if err := sendICMPError(&conn, icmpErr, udp); err != nil {
+ t.Fatal(err)
+ }
+
+ conn.Send(testbench.UDP{DstPort: &cleanPort})
+ conn.Send(testbench.UDP{})
+ wg.Wait()
+ })
+ }
+ }
+}
diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go
new file mode 100644
index 000000000..77a9bfa1d
--- /dev/null
+++ b/test/packetimpact/tests/udp_recv_multicast_test.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_recv_multicast_test
+
+import (
+ "flag"
+ "net"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestUDPRecvMulticast(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(boundFD)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close()
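+ // 224.0.0.1 is the all-hosts multicast group, which every multicast-capable
+ // host joins implicitly, so the datagram should reach the bound socket
+ // without an explicit IP_ADD_MEMBERSHIP.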
+ conn.SendIP(testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))}, testbench.UDP{})
+ dut.Recv(boundFD, 100, 0)
+}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
new file mode 100644
index 000000000..224feef85
--- /dev/null
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_send_recv_dgram_test
+
+import (
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestUDPRecv(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(boundFD)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ testCases := []struct {
+ name string
+ payload []byte
+ }{
+ {"emptypayload", nil},
+ {"small payload", []byte("hello world")},
+ {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)},
+ // Even though UDP allows larger datagrams, we don't test them here, as
+ // they would need to be fragmented and written out as individual
+ // frames.
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: tc.payload})
+ if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want {
+ t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestUDPSend(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(boundFD)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ testCases := []struct {
+ name string
+ payload []byte
+ }{
+ {"emptypayload", nil},
+ {"small payload", []byte("hello world")},
+ {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)},
+ // Even though UDP allows larger datagrams, we don't test them here, as
+ // they would need to be fragmented and written out as individual
+ // frames.
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ conn.Drain()
+ if got, want := int(dut.SendTo(boundFD, tc.payload, 0, conn.LocalAddr())), len(tc.payload); got != want {
+ t.Fatalf("short write got: %d, want: %d", got, want)
+ }
+ if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, 1*time.Second); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
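+
+// maxUnfragmentedUDPPayload illustrates, under the assumption of a typical
+// 1500-byte Ethernet MTU, why the cases above stop at 1k: an IPv4/UDP
+// datagram can carry at most 1500 - 20 (IPv4 header) - 8 (UDP header) =
+// 1472 payload bytes before it must be fragmented across multiple frames,
+// which the testbench would have to emit individually.
+const maxUnfragmentedUDPPayload = 1500 - 20 - 8 // 1472 bytes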
diff --git a/test/perf/BUILD b/test/perf/BUILD
new file mode 100644
index 000000000..471d8c2ab
--- /dev/null
+++ b/test/perf/BUILD
@@ -0,0 +1,117 @@
+load("//test/runner:defs.bzl", "syscall_test")
+
+package(licenses = ["notice"])
+
+syscall_test(
+ test = "//test/perf/linux:clock_getres_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:clock_gettime_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:death_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:epoll_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:fork_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:futex_benchmark",
+)
+
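+# The "nogotsan" tag presumably keeps these enormous benchmarks out of
+# sanitizer (TSAN) runs, where they would be prohibitively slow; see
+# //test/runner:defs.bzl for how syscall_test interprets tags.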
+syscall_test(
+ size = "enormous",
+ shard_count = 10,
+ tags = ["nogotsan"],
+ test = "//test/perf/linux:getdents_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:getpid_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ tags = ["nogotsan"],
+ test = "//test/perf/linux:gettid_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:mapping_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:open_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:pipe_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:randread_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:read_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:sched_yield_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:send_recv_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:seqwrite_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:signal_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:sleep_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:stat_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ add_overlay = True,
+ test = "//test/perf/linux:unlink_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:write_benchmark",
+)
diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD
new file mode 100644
index 000000000..b4e907826
--- /dev/null
+++ b/test/perf/linux/BUILD
@@ -0,0 +1,356 @@
+load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "getpid_benchmark",
+ testonly = 1,
+ srcs = [
+ "getpid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "send_recv_benchmark",
+ testonly = 1,
+ srcs = [
+ "send_recv_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/syscalls/linux:socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "gettid_benchmark",
+ testonly = 1,
+ srcs = [
+ "gettid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "sched_yield_benchmark",
+ testonly = 1,
+ srcs = [
+ "sched_yield_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "clock_getres_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_getres_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "clock_gettime_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_gettime_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "open_benchmark",
+ testonly = 1,
+ srcs = [
+ "open_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "read_benchmark",
+ testonly = 1,
+ srcs = [
+ "read_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "randread_benchmark",
+ testonly = 1,
+ srcs = [
+ "randread_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "write_benchmark",
+ testonly = 1,
+ srcs = [
+ "write_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "seqwrite_benchmark",
+ testonly = 1,
+ srcs = [
+ "seqwrite_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "pipe_benchmark",
+ testonly = 1,
+ srcs = [
+ "pipe_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "fork_benchmark",
+ testonly = 1,
+ srcs = [
+ "fork_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "futex_benchmark",
+ testonly = 1,
+ srcs = [
+ "futex_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "epoll_benchmark",
+ testonly = 1,
+ srcs = [
+ "epoll_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:epoll_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "death_benchmark",
+ testonly = 1,
+ srcs = [
+ "death_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "mapping_benchmark",
+ testonly = 1,
+ srcs = [
+ "mapping_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "signal_benchmark",
+ testonly = 1,
+ srcs = [
+ "signal_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getdents_benchmark",
+ testonly = 1,
+ srcs = [
+ "getdents_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sleep_benchmark",
+ testonly = 1,
+ srcs = [
+ "sleep_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "stat_benchmark",
+ testonly = 1,
+ srcs = [
+ "stat_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "unlink_benchmark",
+ testonly = 1,
+ srcs = [
+ "unlink_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc
new file mode 100644
index 000000000..b051293ad
--- /dev/null
+++ b/test/perf/linux/clock_getres_benchmark.cc
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// clock_getres(2) is very nearly a no-op syscall, but it does require copying
+// out to a userspace struct. It thus provides a nice small copy-out benchmark.
+void BM_ClockGetRes(benchmark::State& state) {
+ struct timespec ts;
+ for (auto _ : state) {
+ clock_getres(CLOCK_MONOTONIC, &ts);
+ }
+}
+
+BENCHMARK(BM_ClockGetRes);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc
new file mode 100644
index 000000000..6691bebd9
--- /dev/null
+++ b/test/perf/linux/clock_gettime_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <pthread.h>
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_ClockGettimeThreadCPUTime(benchmark::State& state) {
+ clockid_t clockid;
+ ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid));
+ struct timespec tp;
+
+ for (auto _ : state) {
+ clock_gettime(clockid, &tp);
+ }
+}
+
+BENCHMARK(BM_ClockGettimeThreadCPUTime);
+
+void BM_VDSOClockGettime(benchmark::State& state) {
+ const clockid_t clock = state.range(0);
+ struct timespec tp;
+ absl::Time start = absl::Now();
+
+ // Don't benchmark the calibration phase.
+ while (absl::Now() < start + absl::Milliseconds(2100)) {
+ clock_gettime(clock, &tp);
+ }
+
+ for (auto _ : state) {
+ clock_gettime(clock, &tp);
+ }
+}
+
+BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc
new file mode 100644
index 000000000..cb2b6fd07
--- /dev/null
+++ b/test/perf/linux/death_benchmark.cc
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing
+// the ability of gVisor (on whatever platform) to execute all the related
+// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH.
+TEST(DeathTest, ZeroEqualsOne) {
+ EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc
new file mode 100644
index 000000000..0b121338a
--- /dev/null
+++ b/test/perf/linux/epoll_benchmark.cc
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/epoll_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Returns a new eventfd.
+PosixErrorOr<FileDescriptor> NewEventFD() {
+ int fd = eventfd(0, /* flags = */ 0);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "eventfd");
+ }
+ return FileDescriptor(fd);
+}
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollTimeout(benchmark::State& state) {
+ constexpr int kFDsPerEpoll = 3;
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+ }
+
+ struct epoll_event result[kFDsPerEpoll];
+ int timeout_ms = state.range(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms));
+ }
+}
+
+BENCHMARK(BM_EpollTimeout)->Range(0, 8);
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollAllEvents(benchmark::State& state) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ const int fds_per_epoll = state.range(0);
+ constexpr uint64_t kEventVal = 5;
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < fds_per_epoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+
+ ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)),
+ SyscallSucceedsWithValue(sizeof(kEventVal)));
+ }
+
+ std::vector<struct epoll_event> result(fds_per_epoll);
+
+ for (auto _ : state) {
+ EXPECT_EQ(fds_per_epoll,
+ epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0));
+ }
+}
+
+BENCHMARK(BM_EpollAllEvents)->Range(2, 1024);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc
new file mode 100644
index 000000000..84fdbc8a0
--- /dev/null
+++ b/test/perf/linux/fork_benchmark.cc
@@ -0,0 +1,350 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/barrier.h"
+#include "benchmark/benchmark.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBusyMax = 250;
+
+// Do some CPU-bound busy-work.
+int busy(int max) {
+ // Prevent the compiler from optimizing this work away.
+ volatile int count = 0;
+
+ for (int i = 1; i < max; i++) {
+ for (int j = 2; j < i / 2; j++) {
+ if (i % j == 0) {
+ count++;
+ }
+ }
+ }
+
+ return count;
+}
+
+void BM_CPUBoundUniprocess(benchmark::State& state) {
+ for (auto _ : state) {
+ busy(kBusyMax);
+ }
+}
+
+BENCHMARK(BM_CPUBoundUniprocess);
+
+void BM_CPUBoundAsymmetric(benchmark::State& state) {
+ const size_t max = state.max_iterations;
+ pid_t child = fork();
+ if (child == 0) {
+ for (size_t i = 0; i < max; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
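+ // All iterations execute in the child; account for them as a single batch.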
+ ASSERT_TRUE(state.KeepRunningBatch(max));
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ ASSERT_FALSE(state.KeepRunning());
+}
+
+BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime();
+
+void BM_CPUBoundSymmetric(benchmark::State& state) {
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ ASSERT_FALSE(state.KeepRunning());
+ });
+
+ const int processes = state.range(0);
+ for (int i = 0; i < processes; i++) {
+ size_t cur = (state.max_iterations + (processes - 1)) / processes;
+ if ((state.iterations() + cur) >= state.max_iterations) {
+ cur = state.max_iterations - state.iterations();
+ }
+ pid_t child = fork();
+ if (child == 0) {
+ for (size_t j = 0; j < cur; j++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ // cur can be zero when the preceding children already account for all of
+ // max_iterations; KeepRunningBatch requires a positive batch size.
+ if (cur > 0) {
+ ASSERT_TRUE(state.KeepRunningBatch(cur));
+ }
+ children.push_back(child);
+ }
+}
+
+BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime();
+
+// Child routine for ProcessSwitch/ThreadSwitch.
+// Reads from readfd and writes the result to writefd.
+void SwitchChild(int readfd, int writefd) {
+ while (1) {
+ char buf;
+ int ret = ReadFd(readfd, &buf, 1);
+ if (ret == 0) {
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "read failed");
+
+ ret = WriteFd(writefd, &buf, 1);
+ if (ret == -1) {
+ TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure");
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "write failed");
+ }
+}
+
+// Send bytes in a loop through a series of pipes, each passing through a
+// different process.
+//
+// Proc 0 Proc 1
+// * ----------> *
+// ^ Pipe 1 |
+// | |
+// | Pipe 0 | Pipe 2
+// | |
+// | |
+// | Pipe 3 v
+// * <---------- *
+// Proc 3 Proc 2
+//
+// This exercises context switching through multiple processes.
+void BM_ProcessSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two processes.
+ const int num_processes = state.range(0);
+ ASSERT_GE(num_processes, 2);
+
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ });
+
+ // Must come after children, as the FDs must be closed before the children
+ // will exit.
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_processes; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This process is one of the processes in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_processes; i++) {
+ // Read from current pipe index, write to next.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = (i + 1) % num_processes;
+ const int write_fd = write_fds[write_index].get();
+
+ // std::vector isn't safe to use from the fork child.
+ FileDescriptor* read_array = read_fds.data();
+ FileDescriptor* write_array = write_fds.data();
+
+ pid_t child = fork();
+ if (!child) {
+ // Close all other FDs.
+ for (int j = 0; j < num_processes; j++) {
+ if (j != read_index) {
+ read_array[j].reset();
+ }
+ if (j != write_index) {
+ write_array[j].reset();
+ }
+ }
+
+ SwitchChild(read_fd, write_fd);
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ children.push_back(child);
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
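+ // Each iteration passes the byte once around the ring, incurring one
+ // context switch per process.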
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+}
+
+BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime();
+
+// Equivalent to BM_ProcessSwitch, using threads instead of processes.
+void BM_ThreadSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two threads.
+ const int num_threads = state.range(0);
+ ASSERT_GE(num_threads, 2);
+
+ // Must come after threads, as the FDs must be closed before the threads
+ // will exit.
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_threads; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This thread is one of the threads in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_threads; i++) {
+ // Read from current pipe index, write to next.
+ //
+ // Transfer ownership of the FDs to the thread.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].release();
+
+ const int write_index = (i + 1) % num_threads;
+ const int write_fd = write_fds[write_index].release();
+
+ threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] {
+ FileDescriptor read(read_fd);
+ FileDescriptor write(write_fd);
+ SwitchChild(read.get(), write.get());
+ }));
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
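+ // As in BM_ProcessSwitch, each iteration drives the byte through every
+ // thread in the ring.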
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+
+ // The two FDs still owned by this thread are closed, causing the next thread
+ // to exit its loop and close its FDs, and so on until all threads exit.
+}
+
+BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime();
+
+void BM_ThreadStart(benchmark::State& state) {
+ const int num_threads = state.range(0);
+
+ for (auto _ : state) {
+ state.PauseTiming();
+
+ auto barrier = new absl::Barrier(num_threads + 1);
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+
+ state.ResumeTiming();
+
+ for (int i = 0; i < num_threads; ++i) {
+ threads.emplace_back(std::make_unique<ScopedThread>([barrier] {
+ if (barrier->Block()) {
+ delete barrier;
+ }
+ }));
+ }
+
+ if (barrier->Block()) {
+ delete barrier;
+ }
+
+ state.PauseTiming();
+
+ for (const auto& thread : threads) {
+ thread->Join();
+ }
+
+ state.ResumeTiming();
+ }
+}
+
+BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime();
+
+// Benchmark the complete fork + exit + wait.
+void BM_ProcessLifecycle(benchmark::State& state) {
+ const int num_procs = state.range(0);
+
+ std::vector<pid_t> pids(num_procs);
+ for (auto _ : state) {
+ for (int i = 0; i < num_procs; ++i) {
+ int pid = fork();
+ if (pid == 0) {
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ pids[i] = pid;
+ }
+
+ for (const int pid : pids) {
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0),
+ SyscallSucceedsWithValue(pid));
+ }
+ }
+}
+
+BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc
new file mode 100644
index 000000000..e686041c9
--- /dev/null
+++ b/test/perf/linux/futex_benchmark.cc
@@ -0,0 +1,198 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/futex.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr);
+}
+
+inline int FutexWaitMonotonicTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* timeout) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, timeout);
+}
+
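+// Note that FUTEX_WAIT takes a relative timeout, while FUTEX_WAIT_BITSET
+// (used below) takes an absolute deadline.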
+inline int FutexWaitMonotonicDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE, val, deadline,
+ nullptr, FUTEX_BITSET_MATCH_ANY);
+}
+
+inline int FutexWaitRealtimeDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME,
+ val, deadline, nullptr, FUTEX_BITSET_MATCH_ANY);
+}
+
+inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
+ return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count);
+}
+
+// This uses FUTEX_WAKE on an address with no waiters, measuring the bare
+// syscall overhead.
+void BM_FutexWakeNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWake(&v, 1) == 0);
+ }
+}
+
+BENCHMARK(BM_FutexWakeNop)->MinTime(5);
+
+// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
+// syscall won't wait.
+void BM_FutexWaitNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWait(&v, 1) == -1 && errno == EAGAIN);
+ }
+}
+
+BENCHMARK(BM_FutexWaitNop)->MinTime(5);
+
+// This uses FUTEX_WAIT with a timeout on an address whose value never
+// changes, such that it always times out. Timeout overhead can be estimated by
+// timer overruns for short timeouts.
+void BM_FutexWaitMonotonicTimeout(benchmark::State& state) {
+ const absl::Duration timeout = absl::Nanoseconds(state.range(0));
+ std::atomic<int32_t> v(0);
+ auto ts = absl::ToTimespec(timeout);
+
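+ // v never changes and no wake is ever sent, so every wait expires with
+ // ETIMEDOUT.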
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitMonotonicTimeout)
+ ->MinTime(5)
+ ->UseRealTime()
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(100)
+ ->Arg(1000)
+ ->Arg(10000);
+
+// This uses FUTEX_WAIT_BITSET with a deadline that is in the past. This allows
+// estimation of the overhead of setting up a timer for a deadline (as opposed
+// to a timeout as specified for FUTEX_WAIT).
+void BM_FutexWaitMonotonicDeadline(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ struct timespec ts = {};
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitMonotonicDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitMonotonicDeadline)->MinTime(5);
+
+// This is equivalent to BM_FutexWaitMonotonicDeadline, but uses CLOCK_REALTIME
+// instead of CLOCK_MONOTONIC for the deadline.
+void BM_FutexWaitRealtimeDeadline(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ struct timespec ts = {};
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitRealtimeDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitRealtimeDeadline)->MinTime(5);
+
+int64_t GetCurrentMonotonicTimeNanos() {
+ struct timespec ts;
+ TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1);
+ return ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+}
+
+void SpinNanos(int64_t delay_ns) {
+ if (delay_ns <= 0) {
+ return;
+ }
+ const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns;
+ while (GetCurrentMonotonicTimeNanos() < end) {
+ // spin
+ }
+}
+
+// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
+// wakeup to another thread, which spins for delay_ns and then sends a futex
+// wakeup back. The time per iteration is approximately
+// 2 * (delay_ns + kBeforeWakeDelayNs + futex/scheduling overhead).
+void BM_FutexRoundtripDelayed(benchmark::State& state) {
+ const int delay_us = state.range(0);
+ const int64_t delay_ns = delay_us * 1000;
+ // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce
+ // the probability that the wakeup comes before the wait, preventing the wait
+ // from ever taking effect and causing the benchmark to underestimate the
+ // actual wakeup time.
+ constexpr int64_t kBeforeWakeDelayNs = 500;
+ std::atomic<int32_t> v(0);
+ ScopedThread t([&] {
+ for (int i = 0; i < state.max_iterations; i++) {
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 0) {
+ FutexWait(&v, 0);
+ }
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(0, std::memory_order_release);
+ FutexWake(&v, 1);
+ }
+ });
+ for (auto _ : state) {
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(1, std::memory_order_release);
+ FutexWake(&v, 1);
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 1) {
+ FutexWait(&v, 1);
+ }
+ }
+}
+
+BENCHMARK(BM_FutexRoundtripDelayed)
+ ->MinTime(5)
+ ->UseRealTime()
+ ->Arg(0)
+ ->Arg(10)
+ ->Arg(20)
+ ->Arg(50)
+ ->Arg(100);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
new file mode 100644
index 000000000..d8e81fa8c
--- /dev/null
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+#ifndef SYS_getdents64
+#if defined(__x86_64__)
+#define SYS_getdents64 217
+#elif defined(__aarch64__)
+// The generic arm64 syscall table numbers getdents64 as 61.
+#define SYS_getdents64 61
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_getdents64
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
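+// Buffer size for each getdents64 call, large enough to amortize the syscall
+// cost over many directory entries.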
+constexpr int kBufferSize = 65536;
+
+PosixErrorOr<TempPath> CreateDirectory(int count,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir());
+
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (int i = 0; i < count; i++) {
+ auto file = NewTempRelPath();
+ auto res = MknodAt(dfd, file, S_IFREG | 0644, 0);
+ RETURN_IF_ERRNO(res);
+ files->push_back(file);
+ }
+
+ return std::move(dir);
+}
+
+PosixError CleanupDirectory(const TempPath& dir,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (auto it = files->begin(); it != files->end(); ++it) {
+ auto res = UnlinkAt(dfd, *it, 0);
+ RETURN_IF_ERRNO(res);
+ }
+ return NoError();
+}
+
+// Creates a directory containing `count` files, and reads all the directory
+// entries from the directory using a single FD.
+void BM_GetdentsSameFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds());
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime();
+
+// Creates a directory containing `count` files, and reads all the directory
+// entries from the directory using a new FD each time.
+void BM_GetdentsNewFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc
new file mode 100644
index 000000000..db74cb264
--- /dev/null
+++ b/test/perf/linux/getpid_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
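+// getpid is among the cheapest syscalls, so this effectively measures raw
+// syscall dispatch overhead.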
+void BM_Getpid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_getpid);
+ }
+}
+
+BENCHMARK(BM_Getpid);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc
new file mode 100644
index 000000000..8f4961f5e
--- /dev/null
+++ b/test/perf/linux/gettid_benchmark.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Gettid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_gettid);
+ }
+}
+
+BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc
new file mode 100644
index 000000000..39c30fe69
--- /dev/null
+++ b/test/perf/linux/mapping_benchmark.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Conservative value for /proc/sys/vm/max_map_count, which limits the number of
+// VMAs, minus a safety margin for VMAs that already exist for the test binary.
+// The default value for max_map_count is
+// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530.
+constexpr size_t kMaxVMAs = 64001;
+
+// Map then unmap pages without touching them.
+void BM_MapUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map, touch, then unmap pages.
+void BM_MapTouchUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ char* end = c + pages * kPageSize;
+ while (c < end) {
+ *c = 42;
+ c += kPageSize;
+ }
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map and touch many pages, unmapping all at once.
+//
+// NOTE(b/111429208): This is a regression test to ensure performant mapping and
+// allocation even with tons of mappings.
+void BM_MapTouchMany(benchmark::State& state) {
+ // Number of pages to map.
+ const int page_count = state.range(0);
+
+ while (state.KeepRunning()) {
+ std::vector<void*> pages;
+
+ for (int i = 0; i < page_count; i++) {
+ void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ *c = 42;
+
+ pages.push_back(addr);
+ }
+
+ for (void* addr : pages) {
+ int ret = munmap(addr, kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+ }
+
+ state.SetBytesProcessed(kPageSize * page_count * state.iterations());
+}
+
+BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime();
+
+void BM_PageFault(benchmark::State& state) {
+ // Map the region in which we will take page faults. To ensure that each page
+ // fault maps only a single page, each page we touch must correspond to a
+ // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However,
+ // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that
+ // if there are background threads running, they can't inadvertently create
+ // mappings in our gaps that are unmapped when the test ends.
+ size_t test_pages = kMaxVMAs;
+ // Ensure that test_pages is odd, since we want the test region to both
+ // begin and end with a mapped page.
+ if (test_pages % 2 == 0) {
+ test_pages--;
+ }
+ const size_t test_region_bytes = test_pages * kPageSize;
+ // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on
+ // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE
+ // so that Linux pre-allocates the shmem file used to back the mapping.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE));
+ for (size_t i = 0; i < test_pages / 2; i++) {
+ ASSERT_THAT(
+ mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)),
+ kPageSize, PROT_NONE),
+ SyscallSucceeds());
+ }
+
+ const size_t mapped_pages = test_pages / 2 + 1;
+ // "Start" at the end of the mapped region to force the mapped region to be
+ // reset, since we mapped it with MAP_POPULATE.
+ size_t cur_page = mapped_pages;
+ for (auto _ : state) {
+ if (cur_page >= mapped_pages) {
+ // We've reached the end of our mapped region and have to reset it to
+ // incur page faults again.
+ state.PauseTiming();
+ ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED),
+ SyscallSucceeds());
+ cur_page = 0;
+ state.ResumeTiming();
+ }
+ const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize);
+ const char c = *reinterpret_cast<volatile char*>(addr);
+ benchmark::DoNotOptimize(c);
+ cur_page++;
+ }
+}
+
+BENCHMARK(BM_PageFault)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc
new file mode 100644
index 000000000..68008f6d5
--- /dev/null
+++ b/test/perf/linux/open_benchmark.cc
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Open(benchmark::State& state) {
+ const int size = state.range(0);
+ std::vector<TempPath> cache;
+ for (int i = 0; i < size; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ cache.emplace_back(std::move(path));
+ }
+
+ unsigned int seed = 1;
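+ // Open a file chosen pseudo-randomly from the cache; rand_r keeps the
+ // sequence deterministic across runs.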
+ for (auto _ : state) {
+ const int chosen = rand_r(&seed) % size;
+ int fd = open(cache[chosen].path().c_str(), O_RDONLY);
+ TEST_CHECK(fd != -1);
+ close(fd);
+ }
+}
+
+BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc
new file mode 100644
index 000000000..8f5f6a2a3
--- /dev/null
+++ b/test/perf/linux/pipe_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <cerrno>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Pipe(benchmark::State& state) {
+ int fds[2];
+ TEST_CHECK(pipe(fds) == 0);
+
+ const int size = state.range(0);
+ std::vector<char> wbuf(size);
+ std::vector<char> rbuf(size);
+ RandomizeBuffer(wbuf.data(), size);
+
+ ScopedThread t([&] {
+ auto const fd = fds[1];
+ for (int i = 0; i < state.max_iterations; i++) {
+ TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size);
+ }
+ });
+
+ for (auto _ : state) {
+ TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size);
+ }
+
+ t.Join();
+
+ close(fds[0]);
+ close(fds[1]);
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc
new file mode 100644
index 000000000..b0eb8c24e
--- /dev/null
+++ b/test/perf/linux/randread_benchmark.cc
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Create a 1GB file that will be read from at random positions. This should
+// invalidate any performance gains from caching.
+const uint64_t kFileSize = 1ULL << 30;
+
+// How many bytes to write at once to initialize the file used to read from.
+const uint32_t kWriteSize = 65536;
+
+// Largest benchmarked read unit.
+const uint32_t kMaxRead = 1UL << 26;
+
+TempPath CreateFile(uint64_t file_size) {
+ auto path = TempPath::CreateFile().ValueOrDie();
+ FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie();
+
+ // Try to minimize syscalls by using maximum size writev() requests.
+ std::vector<char> buffer(kWriteSize);
+ RandomizeBuffer(buffer.data(), buffer.size());
+ const std::vector<std::vector<struct iovec>> iovecs_list =
+ GenerateIovecs(file_size, buffer.data(), buffer.size());
+ for (const auto& iovecs : iovecs_list) {
+ TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0);
+ }
+
+ return path;
+}
+
+// Global test state, initialized once per process lifetime.
+struct GlobalState {
+ const TempPath tmpfile;
+ explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {}
+};
+
+GlobalState& GetGlobalState() {
+ // This gets created only once throughout the lifetime of the process.
+ // Use a dynamically allocated object (that is never deleted) to avoid order
+ // of destruction of static storage variables issues.
+ static GlobalState* const state =
+ // The actual file size is the maximum random seek range (kFileSize) + the
+ // maximum read size so we can read that number of bytes at the end of the
+ // file.
+ new GlobalState(CreateFile(kFileSize + kMaxRead));
+ return *state;
+}
+
+void BM_RandRead(benchmark::State& state) {
+ const int size = state.range(0);
+
+ GlobalState& global_state = GetGlobalState();
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY));
+ std::vector<char> buf(size);
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(),
+ rand_r(&seed) % kFileSize) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc
new file mode 100644
index 000000000..62445867d
--- /dev/null
+++ b/test/perf/linux/read_benchmark.cc
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Read(benchmark::State& state) {
+ const int size = state.range(0);
+ const std::string contents(size, 0);
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY));
+
+ std::vector<char> buf(size);
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc
new file mode 100644
index 000000000..6756b5575
--- /dev/null
+++ b/test/perf/linux/sched_yield_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Sched_yield(benchmark::State& state) {
+ for (auto ignored : state) {
+ TEST_CHECK(sched_yield() == 0);
+ }
+}
+
+BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc
new file mode 100644
index 000000000..d73e49523
--- /dev/null
+++ b/test/perf/linux/send_recv_benchmark.cc
@@ -0,0 +1,372 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "benchmark/benchmark.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr ssize_t kMessageSize = 1024;
+
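+// Bundles a msghdr with its backing data buffer and an optional
+// control-message buffer.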
+class Message {
+ public:
+ explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {}
+
+ explicit Message(int byte, int sz) : Message(byte, sz, 0) {}
+
+ explicit Message(int byte, int sz, int cmsg_sz)
+ : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) {
+ iov_.iov_base = buffer_.data();
+ iov_.iov_len = sz;
+ hdr_.msg_iov = &iov_;
+ hdr_.msg_iovlen = 1;
+ hdr_.msg_control = cmsg_buffer_.data();
+ hdr_.msg_controllen = cmsg_sz;
+ }
+
+ struct msghdr* header() {
+ return &hdr_;
+ }
+
+ private:
+ std::vector<char> buffer_;
+ std::vector<char> cmsg_buffer_;
+ struct iovec iov_ = {};
+ struct msghdr hdr_ = {};
+};
+
+void BM_Recvmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvmsg)->UseRealTime();
+
+void BM_Sendmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendmsg(send_socket.get(), send_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendmsg)->UseRealTime();
+
+void BM_Recvfrom(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&send_socket, &send_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ }
+ });
+
+ int bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvfrom)->UseRealTime();
+
+void BM_Sendto(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&recv_socket, &recv_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendto)->UseRealTime();
+
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate
+// space for control messages. Note that we do not expect to receive any.
+void BM_RecvmsgWithControlBuf(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ absl::Notification notification;
+ Message send_msg('a');
+ // Create a msghdr with a buffer allocated for control messages.
+ Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24);
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime();
+
+// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes.
+//
+// state.Args[0] selects a non-blocking (0) or blocking (1) send socket.
+// state.Args[1] is the size of the payload to be used per sendmsg call.
+void BM_SendmsgTCP(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ // Check if we want to run the test w/ a blocking send socket
+ // or non-blocking.
+ const int blocking = state.range(0);
+ if (!blocking) {
+ // Set the send FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds());
+ }
+
+ absl::Notification notification;
+
+ // Get the buffer size we should use for this iteration of the test.
+ const int buf_size = state.range(1);
+ Message send_msg('a', buf_size), recv_msg(0, buf_size);
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ int ncalls = 0;
+ for (auto ignored : state) {
+ int sent = 0;
+ while (true) {
+ struct msghdr hdr = {};
+ struct iovec iov = {};
+ struct msghdr* snd_header = send_msg.header();
+ iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent;
+ iov.iov_len = snd_header->msg_iov->iov_len - sent;
+ hdr.msg_iov = &iov;
+ hdr.msg_iovlen = 1;
+ int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0);
+ ncalls++;
+ if (n > 0) {
+ sent += n;
+ if (sent == buf_size) {
+ break;
+ }
+ // n can be positive yet less than the requested size, in which case
+ // we retry immediately without polling.
+ continue;
+ }
+ // Poll the fd until it becomes writable or the short timeout expires.
+ struct pollfd poll_fd = {send_socket.get(), POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceeds());
+ }
+ bytes_sent += static_cast<int64_t>(sent);
+ }
+
+ notification.Notify();
+ send_socket.reset();
+ state.SetBytesProcessed(bytes_sent);
+}
+
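+// Generates {blocking, buf_size} argument pairs, with payload sizes doubling
+// from 1 KiB to 256 MiB.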
+void Args(benchmark::internal::Benchmark* benchmark) {
+ for (int blocking = 0; blocking < 2; blocking++) {
+ for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) {
+ benchmark->Args({blocking, buf_size});
+ }
+ }
+}
+
+BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc
new file mode 100644
index 000000000..af49e4477
--- /dev/null
+++ b/test/perf/linux/seqwrite_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// The maximum size of the test file; once writes reach this point, they wrap
+// around to offset 0. This should be large enough to blow away caches.
+const uint64_t kMaxFile = 1 << 30;
+
+// Perform writes of various sizes sequentially to one file. Wraps around if it
+// goes above a certain maximum file size.
+void BM_SeqWrite(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ // Start writes at offset 0.
+ uint64_t offset = 0;
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) ==
+ buf.size());
+ offset += buf.size();
+ // Wrap around if going above the maximum file size.
+ if (offset >= kMaxFile) {
+ offset = 0;
+ }
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc
new file mode 100644
index 000000000..cec679191
--- /dev/null
+++ b/test/perf/linux/signal_benchmark.cc
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <string.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void FixupHandler(int sig, siginfo_t* si, void* void_ctx) {
+ static unsigned int dataval = 0;
+
+ // Point the faulting store at a valid address so the instruction succeeds
+ // when it is restarted.
+ ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx);
+ ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval);
+}
+
+void BM_FaultSignalFixup(benchmark::State& state) {
+ // Set up the signal handler.
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_sigaction = FixupHandler;
+ sa.sa_flags = SA_SIGINFO;
+ TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0);
+
+ // Fault, fault, fault.
+ for (auto _ : state) {
+ // Trigger the segfault.
+ asm volatile(
+ "movq $0, %%rax\n"
+ "movq $0x77777777, (%%rax)\n"
+ :
+ :
+ : "rax");
+ }
+}
+
+BENCHMARK(BM_FaultSignalFixup)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc
new file mode 100644
index 000000000..99ef05117
--- /dev/null
+++ b/test/perf/linux/sleep_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Sleep for state.range(0) nanoseconds.
+void BM_Sleep(benchmark::State& state) {
+ const int nanoseconds = state.range(0);
+
+ for (auto _ : state) {
+ struct timespec ts;
+ ts.tv_sec = 0;
+ ts.tv_nsec = nanoseconds;
+
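+ // nanosleep writes the remaining time back into ts, so a retry after EINTR
+ // sleeps only for the remainder.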
+ int ret;
+ do {
+ ret = syscall(SYS_nanosleep, &ts, &ts);
+ if (ret < 0) {
+ TEST_CHECK(errno == EINTR);
+ }
+ } while (ret < 0);
+ }
+}
+
+BENCHMARK(BM_Sleep)
+ ->Arg(0)
+ ->Arg(1)
+ ->Arg(1000) // 1us
+ ->Arg(1000 * 1000) // 1ms
+ ->Arg(10 * 1000 * 1000) // 10ms
+ ->Arg(50 * 1000 * 1000) // 50ms
+ ->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc
new file mode 100644
index 000000000..f15424482
--- /dev/null
+++ b/test/perf/linux/stat_benchmark.cc
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a file in a nested directory hierarchy at least `depth` directories
+// deep, and stats that file multiple times.
+void BM_Stat(benchmark::State& state) {
+ // Create nested directories with given depth.
+ int depth = state.range(0);
+ const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string dir_path = top_dir.path();
+
+ while (depth-- > 0) {
+ // Don't use TempPath because it will make paths too long to use.
+ //
+ // The top_dir destructor will clean up this whole tree.
+ dir_path = JoinPath(dir_path, absl::StrCat(depth));
+ ASSERT_NO_ERRNO(Mkdir(dir_path, 0755));
+ }
+
+ // Create the file that will be stat'd.
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path));
+
+ struct stat st;
+ for (auto _ : state) {
+ ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds());
+ }
+}
+
+BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
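
For comparison, the same measurement could be sketched against Go's testing package; the fixed depth and names below are illustrative only, not part of this change:

```go
package bench

import (
	"fmt"
	"os"
	"path/filepath"
	"testing"
)

// BenchmarkStat mirrors BM_Stat: build a directory chain `depth` levels
// deep, create one file at the bottom, and stat it repeatedly.
func BenchmarkStat(b *testing.B) {
	const depth = 10 // illustrative; BM_Stat sweeps 1..100 via Range.
	dir := b.TempDir()
	for i := 0; i < depth; i++ {
		dir = filepath.Join(dir, fmt.Sprint(i))
	}
	if err := os.MkdirAll(dir, 0755); err != nil {
		b.Fatal(err)
	}
	file := filepath.Join(dir, "target")
	if err := os.WriteFile(file, nil, 0644); err != nil {
		b.Fatal(err)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := os.Stat(file); err != nil {
			b.Fatal(err)
		}
	}
}
```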
diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc
new file mode 100644
index 000000000..92243a042
--- /dev/null
+++ b/test/perf/linux/unlink_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a directory containing `files` files, and unlinks all the files.
+void BM_Unlink(benchmark::State& state) {
+ // Create directory with given files.
+ const int file_count = state.range(0);
+
+ // We unlink all files on each iteration, but report this as a "batch"
+ // iteration so that reported times are per file.
+ TempPath dir;
+ while (state.KeepRunningBatch(file_count)) {
+ state.PauseTiming();
+ // N.B. dir is declared outside the loop so that destruction of the previous
+ // iteration's directory occurs here, inside of PauseTiming.
+ dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ std::vector<TempPath> files;
+ for (int i = 0; i < file_count; i++) {
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ files.push_back(std::move(file));
+ }
+ state.ResumeTiming();
+
+ while (!files.empty()) {
+ // Destructor unlinks.
+ files.pop_back();
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
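
Go's testing package has no direct KeepRunningBatch equivalent; the usual way to get the same per-file timing is to stop the timer while recreating the batch. A minimal sketch under that assumption:

```go
package bench

import (
	"os"
	"testing"
)

// BenchmarkUnlink recreates a batch of files with the timer stopped, then
// times only the unlinks, so reported times are per file.
func BenchmarkUnlink(b *testing.B) {
	const fileCount = 100 // illustrative batch size
	dir := b.TempDir()
	for done := 0; done < b.N; done += fileCount {
		b.StopTimer()
		paths := make([]string, 0, fileCount)
		for j := 0; j < fileCount; j++ {
			f, err := os.CreateTemp(dir, "unlink")
			if err != nil {
				b.Fatal(err)
			}
			f.Close()
			paths = append(paths, f.Name())
		}
		b.StartTimer()
		for _, p := range paths {
			if err := os.Remove(p); err != nil {
				b.Fatal(err)
			}
		}
	}
}
```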
diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc
new file mode 100644
index 000000000..7b060c70e
--- /dev/null
+++ b/test/perf/linux/write_benchmark.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Write(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), size);
+
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
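
The benchmark uses pwrite so every iteration writes at offset zero without a separate seek. An equivalent loop in Go, assuming golang.org/x/sys/unix and an illustrative fixed size:

```go
package bench

import (
	"crypto/rand"
	"os"
	"testing"

	"golang.org/x/sys/unix"
)

// BenchmarkWrite repeatedly pwrites a randomized buffer at offset 0,
// mirroring BM_Write above.
func BenchmarkWrite(b *testing.B) {
	const size = 4096 // illustrative; BM_Write sweeps 1..1<<26 via Range.
	f, err := os.CreateTemp(b.TempDir(), "write")
	if err != nil {
		b.Fatal(err)
	}
	defer f.Close()

	buf := make([]byte, size)
	if _, err := rand.Read(buf); err != nil {
		b.Fatal(err)
	}

	b.SetBytes(size)
	for i := 0; i < b.N; i++ {
		if n, err := unix.Pwrite(int(f.Fd()), buf, 0); err != nil || n != size {
			b.Fatalf("pwrite returned (%d, %v), want (%d, nil)", n, err, size)
		}
	}
}
```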
diff --git a/test/root/BUILD b/test/root/BUILD
new file mode 100644
index 000000000..a9e91ccd6
--- /dev/null
+++ b/test/root/BUILD
@@ -0,0 +1,58 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/vm:defs.bzl", "vm_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "root",
+ srcs = ["root.go"],
+)
+
+go_test(
+ name = "root_test",
+ size = "small",
+ srcs = [
+ "cgroup_test.go",
+ "chroot_test.go",
+ "crictl_test.go",
+ "main_test.go",
+ "oom_score_adj_test.go",
+ "runsc_test.go",
+ ],
+ data = [
+ "//runsc",
+ ],
+ library = ":root",
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ # Also, the test needs to be run as root. Note that below, the
+ # root_vm_test relies on the default runtime 'runsc' being installed by
+ # the default installer.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/cleanup",
+ "//pkg/test/criutil",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ "//runsc/cgroup",
+ "//runsc/container",
+ "//runsc/specutils",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+vm_test(
+ name = "root_vm_test",
+ size = "large",
+ shard_count = 1,
+ targets = [
+ "//tools/installers:shim",
+ ":root_test",
+ ],
+)
diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go
new file mode 100644
index 000000000..a26b83081
--- /dev/null
+++ b/test/root/cgroup_test.go
@@ -0,0 +1,359 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/cgroup"
+)
+
+func verifyPid(pid int, path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ var gots []int
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ got, err := strconv.Atoi(scanner.Text())
+ if err != nil {
+ return err
+ }
+ if got == pid {
+ return nil
+ }
+ gots = append(gots, got)
+ }
+ if scanner.Err() != nil {
+ return scanner.Err()
+ }
+ return fmt.Errorf("got: %v, want: %d", gots, pid)
+}
+
+func TestMemCgroup(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start a new container and allocate the specified amount of memory.
+ allocMemSize := 128 << 20
+ allocMemLimit := 2 * allocMemSize
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ Memory: allocMemLimit, // Must be in bytes.
+ }, "python3", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Extract the ID to lookup the cgroup.
+ gid := d.ID()
+ t.Logf("cgroup ID: %s", gid)
+
+ // Wait until the container has allocated the memory.
+ memUsage := 0
+ start := time.Now()
+ for time.Since(start) < 30*time.Second {
+ // Sleep for a brief period of time after spawning the
+ // container (so that Docker can create the cgroup, etc.),
+ // or after looping below (so the application can start).
+ time.Sleep(100 * time.Millisecond)
+
+ // Read the cgroup memory limit.
+ path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.limit_in_bytes")
+ outRaw, err := ioutil.ReadFile(path)
+ if err != nil {
+ // It's possible that the container does not exist yet.
+ continue
+ }
+ out := strings.TrimSpace(string(outRaw))
+ memLimit, err := strconv.Atoi(out)
+ if err != nil {
+ t.Fatalf("Atoi(%v): %v", out, err)
+ }
+ if memLimit != allocMemLimit {
+ // The group may not have had the correct limit set yet.
+ continue
+ }
+
+ // Read the cgroup memory usage.
+ path = filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.max_usage_in_bytes")
+ outRaw, err = ioutil.ReadFile(path)
+ if err != nil {
+ t.Fatalf("error reading usage: %v", err)
+ }
+ out = strings.TrimSpace(string(outRaw))
+ memUsage, err = strconv.Atoi(out)
+ if err != nil {
+ t.Fatalf("Atoi(%v): %v", out, err)
+ }
+ t.Logf("read usage: %v, wanted: %v", memUsage, allocMemSize)
+
+ // Are we done?
+ if memUsage >= allocMemSize {
+ return
+ }
+ }
+
+ t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20)
+}
+
+// TestCgroup sets cgroup options and checks that the cgroup was properly configured.
+func TestCgroup(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // This is not a comprehensive list of attributes.
+ //
+ // Note that we are specifically missing cpusets, which fail if specified.
+ // In any case, it's unclear if cpusets can be reliably tested here: these
+ // tests often run on a single-core virtual machine, where only a single
+ // CPU is available in our current set, and in every container's set.
+ attrs := []struct {
+ field string
+ value int64
+ ctrl string
+ file string
+ want string
+ skipIfNotFound bool
+ }{
+ {
+ field: "cpu-shares",
+ value: 1000,
+ ctrl: "cpu",
+ file: "cpu.shares",
+ want: "1000",
+ },
+ {
+ field: "cpu-period",
+ value: 2000,
+ ctrl: "cpu",
+ file: "cpu.cfs_period_us",
+ want: "2000",
+ },
+ {
+ field: "cpu-quota",
+ value: 3000,
+ ctrl: "cpu",
+ file: "cpu.cfs_quota_us",
+ want: "3000",
+ },
+ {
+ field: "kernel-memory",
+ value: 100 << 20,
+ ctrl: "memory",
+ file: "memory.kmem.limit_in_bytes",
+ want: "104857600",
+ },
+ {
+ field: "memory",
+ value: 1 << 30,
+ ctrl: "memory",
+ file: "memory.limit_in_bytes",
+ want: "1073741824",
+ },
+ {
+ field: "memory-reservation",
+ value: 500 << 20,
+ ctrl: "memory",
+ file: "memory.soft_limit_in_bytes",
+ want: "524288000",
+ },
+ {
+ field: "memory-swap",
+ value: 2 << 30,
+ ctrl: "memory",
+ file: "memory.memsw.limit_in_bytes",
+ want: "2147483648",
+ skipIfNotFound: true, // swap may be disabled on the machine.
+ },
+ {
+ field: "memory-swappiness",
+ value: 5,
+ ctrl: "memory",
+ file: "memory.swappiness",
+ want: "5",
+ },
+ {
+ field: "blkio-weight",
+ value: 750,
+ ctrl: "blkio",
+ file: "blkio.weight",
+ want: "750",
+ skipIfNotFound: true, // blkio groups may not be available.
+ },
+ {
+ field: "pids-limit",
+ value: 1000,
+ ctrl: "pids",
+ file: "pids.max",
+ want: "1000",
+ },
+ }
+
+ // Make configs.
+ conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000")
+
+ // Add Cgroup arguments to configs.
+ for _, attr := range attrs {
+ switch attr.field {
+ case "cpu-shares":
+ hostconf.Resources.CPUShares = attr.value
+ case "cpu-period":
+ hostconf.Resources.CPUPeriod = attr.value
+ case "cpu-quota":
+ hostconf.Resources.CPUQuota = attr.value
+ case "kernel-memory":
+ hostconf.Resources.KernelMemory = attr.value
+ case "memory":
+ hostconf.Resources.Memory = attr.value
+ case "memory-reservation":
+ hostconf.Resources.MemoryReservation = attr.value
+ case "memory-swap":
+ hostconf.Resources.MemorySwap = attr.value
+ case "memory-swappiness":
+ val := attr.value
+ hostconf.Resources.MemorySwappiness = &val
+ case "blkio-weight":
+ hostconf.Resources.BlkioWeight = uint16(attr.value)
+ case "pids-limit":
+ val := attr.value
+ hostconf.Resources.PidsLimit = &val
+
+ }
+ }
+
+ // Create container.
+ if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("create failed with: %v", err)
+ }
+
+ // Start container.
+ if err := d.Start(ctx); err != nil {
+ t.Fatalf("start failed with: %v", err)
+ }
+
+ // Lookup the relevant cgroup ID.
+ gid := d.ID()
+ t.Logf("cgroup ID: %s", gid)
+
+ // Check list of attributes defined above.
+ for _, attr := range attrs {
+ path := filepath.Join("/sys/fs/cgroup", attr.ctrl, "docker", gid, attr.file)
+ out, err := ioutil.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) && attr.skipIfNotFound {
+ t.Logf("skipped %s/%s", attr.ctrl, attr.file)
+ continue
+ }
+ t.Fatalf("failed to read %q: %v", path, err)
+ }
+ if got := strings.TrimSpace(string(out)); got != attr.want {
+ t.Errorf("field: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.field, attr.ctrl, attr.file, got, attr.want)
+ }
+ }
+
+ // Check that sandbox is inside cgroup.
+ controllers := []string{
+ "blkio",
+ "cpu",
+ "cpuset",
+ "memory",
+ "net_cls",
+ "net_prio",
+ "devices",
+ "freezer",
+ "perf_event",
+ "pids",
+ "systemd",
+ }
+ pid, err := d.SandboxPid(ctx)
+ if err != nil {
+ t.Fatalf("SandboxPid: %v", err)
+ }
+ for _, ctrl := range controllers {
+ path := filepath.Join("/sys/fs/cgroup", ctrl, "docker", gid, "cgroup.procs")
+ if err := verifyPid(pid, path); err != nil {
+ t.Errorf("cgroup control %q processes: %v", ctrl, err)
+ }
+ }
+}
+
+// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's
+// cgroups are created correctly relative to each other.
+func TestCgroupParent(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Construct a known cgroup name.
+ parent := testutil.RandomID("runsc-")
+ conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000")
+ hostconf.Resources.CgroupParent = parent
+
+ if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("create failed with: %v", err)
+ }
+
+ if err := d.Start(ctx); err != nil {
+ t.Fatalf("start failed with: %v", err)
+ }
+
+ // Extract the ID to look up the cgroup.
+ gid := d.ID()
+ t.Logf("cgroup ID: %s", gid)
+
+ // Check that sandbox is inside cgroup.
+ pid, err := d.SandboxPid(ctx)
+ if err != nil {
+ t.Fatalf("SandboxPid: %v", err)
+ }
+
+ // Finds cgroup for the sandbox's parent process to check that cgroup is
+ // created in the right location relative to the parent.
+ cmd := fmt.Sprintf("grep PPid: /proc/%d/status | sed 's/PPid:\\s//'", pid)
+ ppid, err := exec.Command("bash", "-c", cmd).CombinedOutput()
+ if err != nil {
+ t.Fatalf("Executing %q: %v", cmd, err)
+ }
+ cgroups, err := cgroup.LoadPaths(strings.TrimSpace(string(ppid)))
+ if err != nil {
+ t.Fatalf("cgroup.LoadPath(%s): %v", ppid, err)
+ }
+ path := filepath.Join("/sys/fs/cgroup/memory", cgroups["memory"], parent, gid, "cgroup.procs")
+ if err := verifyPid(pid, path); err != nil {
+ t.Errorf("cgroup control %q processes: %v", "memory", err)
+ }
+}
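
Both tests above repeat the same read-an-attribute pattern; a distilled helper, assuming the cgroup v1 layout Docker uses (/sys/fs/cgroup/&lt;ctrl&gt;/docker/&lt;id&gt;/&lt;file&gt;):

```go
package root

import (
	"io/ioutil"
	"path/filepath"
	"strconv"
	"strings"
)

// readCgroupInt reads a single integer cgroup v1 attribute for a
// Docker-managed container, e.g. ("memory", gid, "memory.limit_in_bytes").
func readCgroupInt(ctrl, gid, file string) (int64, error) {
	path := filepath.Join("/sys/fs/cgroup", ctrl, "docker", gid, file)
	out, err := ioutil.ReadFile(path)
	if err != nil {
		return 0, err
	}
	return strconv.ParseInt(strings.TrimSpace(string(out)), 10, 64)
}
```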
diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go
new file mode 100644
index 000000000..58fcd6f08
--- /dev/null
+++ b/test/root/chroot_test.go
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package root is used for tests that require sysadmin privileges to run.
+package root
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+)
+
+// TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned
+// up after the sandbox is destroyed.
+func TestChroot(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ pid, err := d.SandboxPid(ctx)
+ if err != nil {
+ t.Fatalf("Docker.SandboxPid(): %v", err)
+ }
+
+ // Check that sandbox is chroot'ed.
+ procRoot := filepath.Join("/proc", strconv.Itoa(pid), "root")
+ chroot, err := filepath.EvalSymlinks(procRoot)
+ if err != nil {
+ t.Fatalf("error resolving /proc/<pid>/root symlink: %v", err)
+ }
+ if chroot != "/" {
+ t.Errorf("sandbox is not chroot'd, it should be inside: /, got: %q", chroot)
+ }
+
+ path, err := filepath.EvalSymlinks(filepath.Join("/proc", strconv.Itoa(pid), "cwd"))
+ if err != nil {
+ t.Fatalf("error resolving /proc/<pid>/cwd symlink: %v", err)
+ }
+ if chroot != path {
+ t.Errorf("sandbox current dir is wrong, want: %q, got: %q", chroot, path)
+ }
+
+ fi, err := ioutil.ReadDir(procRoot)
+ if err != nil {
+ t.Fatalf("error listing %q: %v", chroot, err)
+ }
+ if want, got := 1, len(fi); want != got {
+ t.Fatalf("chroot dir got %d entries, want %d", got, want)
+ }
+
+ // chroot dir is prepared by runsc and should contain only /proc.
+ if fi[0].Name() != "proc" {
+ t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc")
+ }
+
+ d.CleanUp(ctx)
+}
+
+func TestChrootGofer(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // It's tricky to find gofers. Get the sandbox PID first, then find its
+ // parent. From the parent, get all immediate children; remove the sandbox,
+ // and everything else is a gofer.
+ sandPID, err := d.SandboxPid(ctx)
+ if err != nil {
+ t.Fatalf("Docker.SandboxPid(): %v", err)
+ }
+
+ // Find sandbox's parent PID.
+ cmd := fmt.Sprintf("grep PPid /proc/%d/status | awk '{print $2}'", sandPID)
+ parent, err := exec.Command("sh", "-c", cmd).CombinedOutput()
+ if err != nil {
+ t.Fatalf("failed to fetch runsc (%d) parent PID: %v, out:\n%s", sandPID, err, string(parent))
+ }
+ parentPID, err := strconv.Atoi(strings.TrimSpace(string(parent)))
+ if err != nil {
+ t.Fatalf("failed to parse PPID %q: %v", string(parent), err)
+ }
+
+ // Get all children from parent.
+ childrenOut, err := exec.Command("/usr/bin/pgrep", "-P", strconv.Itoa(parentPID)).CombinedOutput()
+ if err != nil {
+ t.Fatalf("failed to fetch containerd-shim children: %v", err)
+ }
+ children := strings.Split(strings.TrimSpace(string(childrenOut)), "\n")
+
+ // This is where the root directory is mapped on the host, and it is where
+ // the gofer must have chroot'd to.
+ root := "/root"
+
+ for _, child := range children {
+ childPID, err := strconv.Atoi(child)
+ if err != nil {
+ t.Fatalf("failed to parse child PID %q: %v", child, err)
+ }
+ if childPID == sandPID {
+ // Skip the sandbox, all other immediate children are gofers.
+ continue
+ }
+
+ // Check that gofer is chroot'ed.
+ chroot, err := filepath.EvalSymlinks(filepath.Join("/proc", child, "root"))
+ if err != nil {
+ t.Fatalf("error resolving /proc/<pid>/root symlink: %v", err)
+ }
+ if root != chroot {
+ t.Errorf("gofer chroot is wrong, want: %q, got: %q", root, chroot)
+ }
+
+ path, err := filepath.EvalSymlinks(filepath.Join("/proc", child, "cwd"))
+ if err != nil {
+ t.Fatalf("error resolving /proc/<pid>/cwd symlink: %v", err)
+ }
+ if root != path {
+ t.Errorf("gofer current dir is wrong, want: %q, got: %q", root, path)
+ }
+ }
+}
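
Both tests reduce their checks to resolving a /proc/&lt;pid&gt; symlink and comparing it with an expected directory; a sketch of that shared step (the helper name is hypothetical):

```go
package root

import (
	"fmt"
	"path/filepath"
	"strconv"
)

// checkProcLink resolves /proc/<pid>/<name> (e.g. "root" or "cwd") and
// verifies it points at the expected directory.
func checkProcLink(pid int, name, want string) error {
	got, err := filepath.EvalSymlinks(filepath.Join("/proc", strconv.Itoa(pid), name))
	if err != nil {
		return fmt.Errorf("resolving /proc/%d/%s: %v", pid, name, err)
	}
	if got != want {
		return fmt.Errorf("/proc/%d/%s resolves to %q, want %q", pid, name, got, want)
	}
	return nil
}
```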
diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go
new file mode 100644
index 000000000..732fae821
--- /dev/null
+++ b/test/root/crictl_test.go
@@ -0,0 +1,393 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/test/criutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Tests for crictl have to be run as root (rather than in a user namespace)
+// because crictl creates named network namespaces in /var/run/netns/.
+
+// Sandbox returns a JSON config for a simple sandbox. Sandbox names must be
+// unique so different names should be used when running tests on the same
+// containerd instance.
+func Sandbox(name string) string {
+ // Sandbox is a default JSON config for a sandbox.
+ s := map[string]interface{}{
+ "metadata": map[string]string{
+ "name": name,
+ "namespace": "default",
+ "uid": testutil.RandomID(""),
+ },
+ "linux": map[string]string{},
+ "log_directory": "/tmp",
+ }
+
+ v, err := json.Marshal(s)
+ if err != nil {
+ // This shouldn't happen.
+ panic(err)
+ }
+ return string(v)
+}
+
+// SimpleSpec returns a JSON config for a simple container that runs the
+// specified command in the specified image.
+func SimpleSpec(name, image string, cmd []string, extra map[string]interface{}) string {
+ s := map[string]interface{}{
+ "metadata": map[string]string{
+ "name": name,
+ },
+ "image": map[string]string{
+ "image": testutil.ImageByName(image),
+ },
+ // Log files are not deleted after root tests are run. Log to random
+ // paths to ensure logs are fresh.
+ "log_path": fmt.Sprintf("%s.log", testutil.RandomID(name)),
+ }
+ if len(cmd) > 0 { // Omit if empty.
+ s["command"] = cmd
+ }
+ for k, v := range extra {
+ s[k] = v // Extra settings.
+ }
+ v, err := json.Marshal(s)
+ if err != nil {
+ // This shouldn't happen.
+ panic(err)
+ }
+ return string(v)
+}
+
+// Httpd is a JSON config for an httpd container.
+var Httpd = SimpleSpec("httpd", "basic/httpd", nil, nil)
+
+// TestCrictlSanity refers to b/112433158.
+func TestCrictlSanity(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+ podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("default"), Httpd)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
+
+ // Look for the httpd page.
+ if err = httpGet(crictl, podID, "index.html"); err != nil {
+ t.Fatalf("failed to get page: %v", err)
+ }
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+}
+
+// HttpdMountPaths is a JSON config for an httpd container with additional
+// mounts.
+var HttpdMountPaths = SimpleSpec("httpd", "basic/httpd", nil, map[string]interface{}{
+ "mounts": []map[string]interface{}{
+ map[string]interface{}{
+ "container_path": "/var/run/secrets/kubernetes.io/serviceaccount",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx",
+ "readonly": true,
+ },
+ map[string]interface{}{
+ "container_path": "/etc/hosts",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts",
+ "readonly": false,
+ },
+ map[string]interface{}{
+ "container_path": "/dev/termination-log",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580",
+ "readonly": false,
+ },
+ map[string]interface{}{
+ "container_path": "/usr/local/apache2/htdocs/test",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064",
+ "readonly": true,
+ },
+ },
+ "linux": map[string]interface{}{},
+})
+
+// TestMountPaths refers to b/117635704.
+func TestMountPaths(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+ podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("default"), HttpdMountPaths)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
+
+ // Look for the directory available at /test.
+ if err = httpGet(crictl, podID, "test"); err != nil {
+ t.Fatalf("failed to get page: %v", err)
+ }
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+}
+
+// TestMountOverSymlinks refers to b/118728671.
+func TestMountOverSymlinks(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+
+ spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil)
+ podID, contID, err := crictl.StartPodAndContainer("basic/resolv", Sandbox("default"), spec)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
+
+ out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf")
+ if err != nil {
+ t.Fatalf("readlink failed: %v, out: %s", err, out)
+ }
+ if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) {
+ t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out))
+ }
+
+ etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf")
+ if err != nil {
+ t.Fatalf("cat failed: %v, out: %s", err, etc)
+ }
+ tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf")
+ if err != nil {
+ t.Fatalf("cat failed: %v, out: %s", err, out)
+ }
+ if tmp != etc {
+ t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp))
+ }
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+}
+
+// TestHomeDir tests that the HOME environment variable is set for
+// Pod containers.
+func TestHomeDir(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+
+ // Note that the container ID returned here is that of a sub-container. All
+ // Pod containers are sub-containers; the root container of the sandbox is
+ // the pause container.
+ t.Run("sub-container", func(t *testing.T) {
+ contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil)
+ podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox("subcont-sandbox"), contSpec)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
+
+ out, err := crictl.Logs(contID)
+ if err != nil {
+ t.Fatalf("failed retrieving container logs: %v, out: %s", err, out)
+ }
+ if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
+ t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want)
+ }
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+ })
+
+ // Tests that HOME is set for the exec process.
+ t.Run("exec", func(t *testing.T) {
+ contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil)
+ podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox("exec-sandbox"), contSpec)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
+
+ out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("failed retrieving container logs: %v, out: %s", err, out)
+ }
+ if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
+ t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want)
+ }
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+ })
+}
+
+// containerdConfigTemplate is a .toml config for containerd. It contains
+// formatting verbs so the runtime fields can be set via fmt.Sprintf.
+const containerdConfigTemplate = `
+disabled_plugins = ["restart"]
+[plugins.linux]
+ runtime = "%s"
+ runtime_root = "/tmp/test-containerd/runsc"
+ shim = "/usr/local/bin/gvisor-containerd-shim"
+ shim_debug = true
+
+[plugins.cri.containerd.runtimes.runsc]
+ runtime_type = "io.containerd.runtime.v1.linux"
+ runtime_engine = "%s"
+`
+
+// setup sets up before a test. Specifically it:
+// * Creates directories and a socket for containerd to utilize.
+// * Runs containerd and waits for it to reach a "ready" state for testing.
+// * Returns a cleanup function that should be called at the end of the test.
+func setup(t *testing.T) (*criutil.Crictl, func(), error) {
+ // Create temporary containerd root and state directories, and a socket
+ // via which crictl and containerd communicate.
+ containerdRoot, err := ioutil.TempDir(testutil.TmpDir(), "containerd-root")
+ if err != nil {
+ t.Fatalf("failed to create containerd root: %v", err)
+ }
+ cu := cleanup.Make(func() { os.RemoveAll(containerdRoot) })
+ defer cu.Clean()
+
+ containerdState, err := ioutil.TempDir(testutil.TmpDir(), "containerd-state")
+ if err != nil {
+ t.Fatalf("failed to create containerd state: %v", err)
+ }
+ cu.Add(func() { os.RemoveAll(containerdState) })
+ sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock")
+
+ // We rewrite a configuration. This is based on the current docker
+ // configuration for the runtime under test.
+ runtime, err := dockerutil.RuntimePath()
+ if err != nil {
+ t.Fatalf("error discovering runtime path: %v", err)
+ }
+ config, configCleanup, err := testutil.WriteTmpFile("containerd-config", fmt.Sprintf(containerdConfigTemplate, runtime, runtime))
+ if err != nil {
+ t.Fatalf("failed to write containerd config")
+ }
+ cu.Add(configCleanup)
+
+ // Start containerd.
+ cmd := exec.Command(getContainerd(),
+ "--config", config,
+ "--log-level", "debug",
+ "--root", containerdRoot,
+ "--state", containerdState,
+ "--address", sockAddr)
+ startupR, startupW := io.Pipe()
+ defer startupR.Close()
+ defer startupW.Close()
+ stderr := &bytes.Buffer{}
+ stdout := &bytes.Buffer{}
+ cmd.Stderr = io.MultiWriter(startupW, stderr)
+ cmd.Stdout = io.MultiWriter(startupW, stdout)
+ cu.Add(func() {
+ // Log output in case of failure.
+ t.Logf("containerd stdout: %s", stdout.String())
+ t.Logf("containerd stderr: %s", stderr.String())
+ })
+
+ // Start the process.
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("failed running containerd: %v", err)
+ }
+
+ // Wait for containerd to boot.
+ if err := testutil.WaitUntilRead(startupR, "Start streaming server", nil, 10*time.Second); err != nil {
+ t.Fatalf("failed to start containerd: %v", err)
+ }
+
+ // Kill must be the last cleanup (as it will be executed first).
+ cc := criutil.NewCrictl(t, sockAddr)
+ cu.Add(func() {
+ cc.CleanUp() // Remove tmp files, etc.
+ if err := testutil.KillCommand(cmd); err != nil {
+ log.Printf("error killing containerd: %v", err)
+ }
+ })
+
+ return cc, cu.Release(), nil
+}
+
+// httpGet GETs the contents of a file served from a pod on port 80.
+func httpGet(crictl *criutil.Crictl, podID, filePath string) error {
+ // Get the IP of the httpd server.
+ ip, err := crictl.PodIP(podID)
+ if err != nil {
+ return fmt.Errorf("failed to get IP from pod %q: %v", podID, err)
+ }
+
+ // GET the page. We may be waiting for the server to start, so retry
+ // with a timeout.
+ var resp *http.Response
+ cb := func() error {
+ r, err := http.Get(fmt.Sprintf("http://%s", path.Join(ip, filePath)))
+ resp = r
+ return err
+ }
+ if err := testutil.Poll(cb, 20*time.Second); err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 200 {
+ return fmt.Errorf("bad status returned: %d", resp.StatusCode)
+ }
+ return nil
+}
+
+func getContainerd() string {
+ // Use the local path if it exists, otherwise, use the system one.
+ if _, err := os.Stat("/usr/local/bin/containerd"); err == nil {
+ return "/usr/local/bin/containerd"
+ }
+ return "/usr/bin/containerd"
+}
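
A new test in this file composes the helpers above in a fixed shape; a skeletal example assuming it lives in this package (the names "example" and "example-sandbox" are placeholders):

```go
func TestExampleSkeleton(t *testing.T) {
	// Bring up containerd and a crictl wrapper; cleanup kills containerd last.
	crictl, cleanup, err := setup(t)
	if err != nil {
		t.Fatalf("failed to setup crictl: %v", err)
	}
	defer cleanup()

	// Start a pod plus one container from JSON specs.
	spec := SimpleSpec("example", "basic/httpd", nil, nil)
	podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("example-sandbox"), spec)
	if err != nil {
		t.Fatalf("start failed: %v", err)
	}

	// ... exercise the container, e.g. via crictl.Exec or httpGet ...

	// Stop everything.
	if err := crictl.StopPodAndContainer(podID, contID); err != nil {
		t.Fatalf("stop failed: %v", err)
	}
}
```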
diff --git a/test/root/main_test.go b/test/root/main_test.go
new file mode 100644
index 000000000..9fb17e0dd
--- /dev/null
+++ b/test/root/main_test.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "testing"
+
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestMain is the main function for root tests. This function checks the
+// supported docker version, required capabilities, and configures the executable
+// path for runsc.
+func TestMain(m *testing.M) {
+ flag.Parse()
+
+ if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
+ fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
+ os.Exit(1)
+ }
+
+ dockerutil.EnsureSupportedDockerVersion()
+
+ // Configure exe for tests.
+ path, err := dockerutil.RuntimePath()
+ if err != nil {
+ panic(err.Error())
+ }
+ specutils.ExePath = path
+
+ os.Exit(m.Run())
+}
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
new file mode 100644
index 000000000..4243eb59e
--- /dev/null
+++ b/test/root/oom_score_adj_test.go
@@ -0,0 +1,366 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "fmt"
+ "os"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ maxOOMScoreAdj = 1000
+ highOOMScoreAdj = 500
+ lowOOMScoreAdj = -500
+ minOOMScoreAdj = -1000
+)
+
+// Tests for oom_score_adj have to be run as root (rather than in a user
+// namespace) because we need to adjust oom_score_adj for PIDs other than our
+// own and test values below 0.
+
+// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a
+// single-container sandbox.
+func TestOOMScoreAdjSingle(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set.
+ OOMScoreAdj *int
+ }{
+ {
+ Name: "max",
+ OOMScoreAdj: &maxOOMScoreAdj,
+ },
+ {
+ Name: "high",
+ OOMScoreAdj: &highOOMScoreAdj,
+ },
+ {
+ Name: "low",
+ OOMScoreAdj: &lowOOMScoreAdj,
+ },
+ {
+ Name: "min",
+ OOMScoreAdj: &minOOMScoreAdj,
+ },
+ {
+ Name: "nil",
+ OOMScoreAdj: &parentOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ id := testutil.RandomContainerID()
+ s := testutil.NewSpecWithArgs("sleep", "1000")
+ s.Process.OOMScoreAdj = testCase.OOMScoreAdj
+
+ containers, cleanup, err := startContainers(t, []*specs.Spec{s}, []string{id})
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ c := containers[0]
+
+ // Verify the gofer's oom_score_adj
+ if testCase.OOMScoreAdj != nil {
+ goferScore, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if goferScore != *testCase.OOMScoreAdj {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", goferScore, *testCase.OOMScoreAdj)
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := c.Sandbox.Pid
+ sandboxScore, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if sandboxScore != *testCase.OOMScoreAdj {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", sandboxScore, *testCase.OOMScoreAdj)
+ }
+ }
+ })
+ }
+}
+
+// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a
+// multi-container sandbox.
+func TestOOMScoreAdjMulti(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set. One value for each container. The first value is the
+ // root container.
+ OOMScoreAdj []*int
+
+ // Expected is the expected oom_score_adj of the sandbox. If nil, then
+ // this value is ignored.
+ Expected *int
+
+ // Remove is a set of container indexes to remove from the sandbox.
+ Remove []int
+
+ // ExpectedAfterRemove is the expected oom_score_adj of the sandbox
+ // after containers are removed. Ignored if nil.
+ ExpectedAfterRemove *int
+ }{
+ // A single container CRI test case. This should not happen in
+ // practice as there should be at least one container besides the pause
+ // container. However, we include a test case to ensure sane behavior.
+ {
+ Name: "single",
+ OOMScoreAdj: []*int{&highOOMScoreAdj},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_no_value",
+ OOMScoreAdj: []*int{nil, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_non_nil_root",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &highOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_min_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_max_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ },
+ {
+ Name: "remove_adjusted",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove highOOMScoreAdj container.
+ Remove: []int{2},
+ ExpectedAfterRemove: &maxOOMScoreAdj,
+ },
+ {
+ // This test removes all non-root containers with a specified oomScoreAdj.
+ Name: "remove_to_nil",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, &lowOOMScoreAdj},
+ Expected: &lowOOMScoreAdj,
+ // Remove lowOOMScoreAdj container.
+ Remove: []int{2},
+ // The oom_score_adj expected after remove is that of the parent process.
+ ExpectedAfterRemove: &parentOOMScoreAdj,
+ },
+ {
+ Name: "remove_no_effect",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove the maxOOMScoreAdj container.
+ Remove: []int{1},
+ ExpectedAfterRemove: &highOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ var cmds [][]string
+ var oomScoreAdj []*int
+ var toRemove []string
+
+ for _, oomScore := range testCase.OOMScoreAdj {
+ oomScoreAdj = append(oomScoreAdj, oomScore)
+ cmds = append(cmds, []string{"sleep", "100"})
+ }
+
+ specs, ids := createSpecs(cmds...)
+ for i, spec := range specs {
+ // Ensure the correct value is set, including no value.
+ spec.Process.OOMScoreAdj = oomScoreAdj[i]
+
+ for _, j := range testCase.Remove {
+ if i == j {
+ toRemove = append(toRemove, ids[i])
+ }
+ }
+ }
+
+ containers, cleanup, err := startContainers(t, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ for i, c := range containers {
+ if oomScoreAdj[i] != nil {
+ // Verify the gofer's oom_score_adj
+ score, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if score != *oomScoreAdj[i] {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", score, *oomScoreAdj[i])
+ }
+ }
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := containers[0].Sandbox.Pid
+ if testCase.Expected != nil {
+ score, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if score != *testCase.Expected {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", score, *testCase.Expected)
+ }
+ }
+
+ if len(toRemove) == 0 {
+ return
+ }
+
+ // Remove containers.
+ for _, removeID := range toRemove {
+ for _, c := range containers {
+ if c.ID == removeID {
+ c.Destroy()
+ }
+ }
+ }
+
+ // Check the new adjusted oom_score_adj.
+ if testCase.ExpectedAfterRemove != nil {
+ scoreAfterRemove, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if scoreAfterRemove != *testCase.ExpectedAfterRemove {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", scoreAfterRemove, *testCase.ExpectedAfterRemove)
+ }
+ }
+ })
+ }
+}
+
+func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
+ var specs []*specs.Spec
+ var ids []string
+ rootID := testutil.RandomContainerID()
+
+ for i, cmd := range cmds {
+ spec := testutil.NewSpecWithArgs(cmd...)
+ if i == 0 {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox,
+ }
+ ids = append(ids, rootID)
+ } else {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
+ specutils.ContainerdSandboxIDAnnotation: rootID,
+ }
+ ids = append(ids, testutil.RandomContainerID())
+ }
+ specs = append(specs, spec)
+ }
+ return specs, ids
+}
+
+func startContainers(t *testing.T, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) {
+ var containers []*container.Container
+
+ // All containers must share the same root.
+ rootDir, clean, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ cu := cleanup.Make(clean)
+ defer cu.Clean()
+
+ // Point the test configuration at the shared root directory.
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
+ for i, spec := range specs {
+ bundleDir, clean, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error setting up bundle: %v", err)
+ }
+ cu.Add(clean)
+
+ args := container.Args{
+ ID: ids[i],
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := container.New(conf, args)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating container: %v", err)
+ }
+ containers = append(containers, cont)
+
+ if err := cont.Start(conf); err != nil {
+ return nil, nil, fmt.Errorf("error starting container: %v", err)
+ }
+ }
+
+ return containers, cu.Release(), nil
+}
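
specutils.GetOOMScoreAdj is used throughout but not shown in this change; reading the value reduces to parsing one procfs file. A minimal sketch (the real helper may differ):

```go
package root

import (
	"fmt"
	"io/ioutil"
	"strconv"
	"strings"
)

// getOOMScoreAdj reads /proc/<pid>/oom_score_adj, which holds a single
// integer in [-1000, 1000].
func getOOMScoreAdj(pid int) (int, error) {
	data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid))
	if err != nil {
		return 0, err
	}
	return strconv.Atoi(strings.TrimSpace(string(data)))
}
```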
diff --git a/test/root/root.go b/test/root/root.go
new file mode 100644
index 000000000..0f1d29faf
--- /dev/null
+++ b/test/root/root.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package root is used for tests that require sysadmin privileges to run.
+// First, follow the setup instructions in runsc/test/README.md. You should also have
+// docker, containerd, and crictl installed. To run these tests from the
+// project root directory:
+//
+// ./scripts/root_tests.sh
+package root
diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go
new file mode 100644
index 000000000..25204bebb
--- /dev/null
+++ b/test/root/runsc_test.go
@@ -0,0 +1,151 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestDoKill checks that when "runsc do..." is killed, the sandbox process is
+// also terminated. This ensures that the parent death signal is propagated to the
+// sandbox process correctly.
+func TestDoKill(t *testing.T) {
+ // Make the sandbox process be reparented here when it's killed, so we can
+ // wait for it.
+ if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
+ t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err)
+ }
+
+ cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000")
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ cmd.Stderr = buf
+ cmd.Start()
+
+ var pid int
+ findSandbox := func() error {
+ var err error
+ pid, err = sandboxPid(cmd.Process.Pid)
+ if err != nil {
+ return &backoff.PermanentError{Err: err}
+ }
+ if pid == 0 {
+ return fmt.Errorf("sandbox process not found")
+ }
+ return nil
+ }
+ if err := testutil.Poll(findSandbox, 10*time.Second); err != nil {
+ t.Fatalf("failed to find sandbox: %v", err)
+ }
+ t.Logf("Found sandbox, pid: %d", pid)
+
+ if err := cmd.Process.Kill(); err != nil {
+ t.Fatalf("failed to kill run process: %v", err)
+ }
+ cmd.Wait()
+ t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String())
+
+ ch := make(chan struct{})
+ go func() {
+ defer func() { ch <- struct{}{} }()
+ t.Logf("Waiting for sandbox process (%d) termination", pid)
+ if _, err := unix.Wait4(pid, nil, 0, nil); err != nil {
+ t.Errorf("error waiting for sandbox process (%d): %v", pid, err)
+ }
+ }()
+ select {
+ case <-ch:
+ // Done
+ case <-time.After(5 * time.Second):
+ t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid)
+ }
+}
+
+// sandboxPid looks for the sandbox process inside the process tree starting
+// from "pid". It returns 0 and no error if no sandbox process is found. It
+// returns error if anything failed.
+func sandboxPid(pid int) (int, error) {
+ cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid))
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ if err := cmd.Start(); err != nil {
+ return 0, err
+ }
+ ps, err := cmd.Process.Wait()
+ if err != nil {
+ return 0, err
+ }
+ if ps.ExitCode() == 1 {
+ // pgrep returns 1 when no process is found.
+ return 0, nil
+ }
+
+ var children []int
+ for _, line := range strings.Split(buf.String(), "\n") {
+ if len(line) == 0 {
+ continue
+ }
+ child, err := strconv.Atoi(line)
+ if err != nil {
+ return 0, err
+ }
+
+ cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline"))
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Raced with process exit.
+ continue
+ }
+ return 0, err
+ }
+ args := strings.SplitN(string(cmdline), "\x00", 2)
+ if len(args) == 0 {
+ return 0, fmt.Errorf("malformed cmdline file: %q", cmdline)
+ }
+ // The sandbox process has the first argument set to "runsc-sandbox".
+ if args[0] == "runsc-sandbox" {
+ return child, nil
+ }
+
+ children = append(children, child)
+ }
+
+ // Sandbox process wasn't found, try another level down.
+ for _, pid := range children {
+ sand, err := sandboxPid(pid)
+ if err != nil {
+ return 0, err
+ }
+ if sand != 0 {
+ return sand, nil
+ }
+ // Not found, continue the search.
+ }
+ return 0, nil
+}
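
sandboxPid shells out to pgrep for each level of the process tree; on kernels built with CONFIG_PROC_CHILDREN the same listing is available directly from procfs. A sketch of that alternative (an assumption, not what this change does):

```go
package root

import (
	"fmt"
	"io/ioutil"
	"strconv"
	"strings"
)

// childrenOf reads /proc/<pid>/task/<pid>/children, which lists the PIDs of
// a process's direct children, space-separated.
func childrenOf(pid int) ([]int, error) {
	data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/task/%d/children", pid, pid))
	if err != nil {
		return nil, err
	}
	var pids []int
	for _, field := range strings.Fields(string(data)) {
		p, err := strconv.Atoi(field)
		if err != nil {
			return nil, err
		}
		pids = append(pids, p)
	}
	return pids, nil
}
```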
diff --git a/test/runner/BUILD b/test/runner/BUILD
new file mode 100644
index 000000000..6833c9986
--- /dev/null
+++ b/test/runner/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ data = [
+ "//runsc",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/test/testutil",
+ "//runsc/specutils",
+ "//test/runner/gtest",
+ "//test/uds",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
new file mode 100644
index 000000000..921e499be
--- /dev/null
+++ b/test/runner/defs.bzl
@@ -0,0 +1,238 @@
+"""Defines a rule for syscall test targets."""
+
+load("//tools:defs.bzl", "default_platform", "loopback", "platforms")
+
+def _runner_test_impl(ctx):
+ # Generate a runner binary.
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "set -euf -x -o pipefail",
+ "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then",
+ " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ "fi",
+ "exec %s %s %s\n" % (
+ ctx.files.runner[0].short_path,
+ " ".join(ctx.attr.runner_args),
+ ctx.files.test[0].short_path,
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return with all transitive files.
+ runfiles = ctx.runfiles(
+ transitive_files = depset(transitive = [
+ target.data_runfiles.files
+ for target in (ctx.attr.runner, ctx.attr.test)
+ if hasattr(target, "data_runfiles")
+ ]),
+ files = ctx.files.runner + ctx.files.test,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_runner_test = rule(
+ attrs = {
+ "runner": attr.label(
+ default = "//test/runner:runner",
+ ),
+ "test": attr.label(
+ mandatory = True,
+ ),
+ "runner_args": attr.string_list(),
+ "data": attr.label_list(
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _runner_test_impl,
+)
+
+def _syscall_test(
+ test,
+ shard_count,
+ size,
+ platform,
+ use_tmpfs,
+ tags,
+ network = "none",
+ file_access = "exclusive",
+ overlay = False,
+ add_uds_tree = False,
+ vfs2 = False):
+ # Prepend "runsc" to non-native platform names.
+ full_platform = platform if platform == "native" else "runsc_" + platform
+
+ # Name the test appropriately.
+ name = test.split(":")[1] + "_" + full_platform
+ if file_access == "shared":
+ name += "_shared"
+ if overlay:
+ name += "_overlay"
+ if vfs2:
+ name += "_vfs2"
+ if network != "none":
+ name += "_" + network + "net"
+
+ # Apply all tags.
+ if tags == None:
+ tags = []
+
+ # Add the full_platform and file access in a tag to make it easier to run
+ # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
+ tags += [full_platform, "file_" + file_access]
+
+ # Hash this target into one of 15 buckets. This can be used to
+ # randomly split targets between different workflows.
+ hash15 = hash(native.package_name() + name) % 15
+ tags.append("hash15:" + str(hash15))
+
+ # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
+ # we figure out how to request ipv4 sockets on Guitar machines.
+ if network == "host":
+ tags.append("noguitar")
+ tags.append("block-network")
+
+ # Disable off-host networking.
+ tags.append("requires-net:loopback")
+
+ runner_args = [
+ # Arguments are passed directly to runner binary.
+ "--platform=" + platform,
+ "--network=" + network,
+ "--use-tmpfs=" + str(use_tmpfs),
+ "--file-access=" + file_access,
+ "--overlay=" + str(overlay),
+ "--add-uds-tree=" + str(add_uds_tree),
+ "--vfs2=" + str(vfs2),
+ ]
+
+ # Call the rule above.
+ _runner_test(
+ name = name,
+ test = test,
+ runner_args = runner_args,
+ data = [loopback],
+ size = size,
+ tags = tags,
+ shard_count = shard_count,
+ )
+
+def syscall_test(
+ test,
+ shard_count = 5,
+ size = "small",
+ use_tmpfs = False,
+ add_overlay = False,
+ add_uds_tree = False,
+ add_hostinet = False,
+ vfs2 = False,
+ tags = None):
+ """syscall_test is a macro that will create targets for all platforms.
+
+ Args:
+ test: the test target.
+ shard_count: shards for defined tests.
+ size: the defined test size.
+ use_tmpfs: use tmpfs in the defined tests.
+ add_overlay: add an overlay test.
+ add_uds_tree: add a UDS test.
+ add_hostinet: add a hostinet test.
+ vfs2: whether the vfs2 variant runs automatically (otherwise it is tagged manual).
+ tags: starting test tags.
+ """
+ if not tags:
+ tags = []
+
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = "native",
+ use_tmpfs = False,
+ add_uds_tree = add_uds_tree,
+ tags = tags,
+ )
+
+ for (platform, platform_tags) in platforms.items():
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platform_tags + tags,
+ )
+
+ vfs2_tags = list(tags)
+ if vfs2:
+ # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2
+ vfs2_tags.append("vfs2")
+
+ else:
+ # Don't automatically run tests that are not yet passing.
+ vfs2_tags.append("manual")
+ vfs2_tags.append("noguitar")
+ vfs2_tags.append("notap")
+
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + vfs2_tags,
+ vfs2 = True,
+ )
+
+ # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests.
+ if add_overlay:
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ overlay = True,
+ )
+
+ if add_hostinet:
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ network = "host",
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ )
+
+ if not use_tmpfs:
+ # Also test shared gofer access.
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ file_access = "shared",
+ )
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + vfs2_tags,
+ file_access = "shared",
+ vfs2 = True,
+ )
diff --git a/test/runner/gtest/BUILD b/test/runner/gtest/BUILD
new file mode 100644
index 000000000..de4b2727c
--- /dev/null
+++ b/test/runner/gtest/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gtest",
+ srcs = ["gtest.go"],
+ visibility = ["//:sandbox"],
+)
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
new file mode 100644
index 000000000..869169ad5
--- /dev/null
+++ b/test/runner/gtest/gtest.go
@@ -0,0 +1,168 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gtest contains helpers for running google-test tests from Go.
+package gtest
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+)
+
+var (
+ // listTestFlag is the flag that will list tests in gtest binaries.
+ listTestFlag = "--gtest_list_tests"
+
+ // filterTestFlag is the flag that will filter tests in gtest binaries.
+ filterTestFlag = "--gtest_filter"
+
+ // listBenchmarkFlag is the flag that will list benchmarks in gtest binaries.
+ listBenchmarkFlag = "--benchmark_list_tests"
+
+ // filterBenchmarkFlag is the flag that will run specified benchmarks.
+ filterBenchmarkFlag = "--benchmark_filter"
+)
+
+// TestCase is a single gtest test case.
+type TestCase struct {
+ // Suite is the suite for this test.
+ Suite string
+
+ // Name is the name of this individual test.
+ Name string
+
+ // all indicates that this will run without flags. This takes
+ // precedence over benchmark below.
+ all bool
+
+ // benchmark indicates that this is a benchmark. In this case, the
+ // suite will be empty, and we will use the appropriate test and
+ // benchmark flags.
+ benchmark bool
+}
+
+// FullName returns the name of the test including the suite. It is suitable
+// to pass to "--gtest_filter".
+func (tc TestCase) FullName() string {
+ return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
+}
+
+// Args returns arguments to be passed when invoking the test.
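+// For example, a regular test yields {"--gtest_filter=Suite.Name"}, while a
+// benchmark yields {"--benchmark_filter=^Name$", "--gtest_filter="}.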
+func (tc TestCase) Args() []string {
+ if tc.all {
+ return []string{} // No arguments.
+ }
+ if tc.benchmark {
+ return []string{
+ fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
+ fmt.Sprintf("%s=", filterTestFlag),
+ }
+ }
+ return []string{
+ fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()),
+ }
+}
+
+// ParseTestCases calls a gtest test binary to list its tests and returns a
+// slice with the name and suite of each test.
+//
+// If benchmarks is true, then benchmarks will be included in the list of test
+// cases provided. Note that this requires the binary to support the
+// --benchmark_list_tests flag.
+func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) {
+ // Run to extract test cases.
+ args := append([]string{listTestFlag}, extraArgs...)
+ cmd := exec.Command(testBin, args...)
+ out, err := cmd.Output()
+ if err != nil {
+ // We failed to list tests with the given flags. Just
+ // return something that will run the binary with no
+ // flags, which should execute all tests.
+ return []TestCase{
+ TestCase{
+ Suite: "Default",
+ Name: "All",
+ all: true,
+ },
+ }, nil
+ }
+
+ // Parse test output.
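+ // The output of --gtest_list_tests is expected to look like:
+ //
+ //   SuiteName1.
+ //     TestName1
+ //     TestName2  # optional trailing comment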
+ var t []TestCase
+ var suite string
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // New suite?
+ if !strings.HasPrefix(line, " ") {
+ suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
+ continue
+ }
+
+ // Individual test.
+ name := strings.TrimSpace(line)
+
+ // Do we have a suite yet?
+ if suite == "" {
+ return nil, fmt.Errorf("test without a suite: %v", name)
+ }
+
+ // Add this individual test.
+ t = append(t, TestCase{
+ Suite: suite,
+ Name: name,
+ })
+ }
+
+ // Finished?
+ if !benchmarks {
+ return t, nil
+ }
+
+ // Run again to extract benchmarks.
+ args = append([]string{listBenchmarkFlag}, extraArgs...)
+ cmd = exec.Command(testBin, args...)
+ out, err = cmd.Output()
+ if err != nil {
+ // We were able to enumerate tests above, but not benchmarks?
+ // We requested them, so we return an error in this case.
+ exitErr, ok := err.(*exec.ExitError)
+ if !ok {
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err)
+ }
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr)
+ }
+
+ out = []byte(strings.Trim(string(out), "\n"))
+
+ // Parse benchmark output.
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // Single benchmark.
+ name := strings.TrimSpace(line)
+
+ // Add the single benchmark.
+ t = append(t, TestCase{
+ Suite: "Benchmarks",
+ Name: name,
+ benchmark: true,
+ })
+ }
+
+ return t, nil
+}
diff --git a/test/runner/runner.go b/test/runner/runner.go
new file mode 100644
index 000000000..5456e46a6
--- /dev/null
+++ b/test/runner/runner.go
@@ -0,0 +1,497 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary syscall_test_runner runs the syscall test suites in gVisor
+// containers and on the host platform.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/test/runner/gtest"
+ "gvisor.dev/gvisor/test/uds"
+)
+
+var (
+ debug = flag.Bool("debug", false, "enable debug logs")
+ strace = flag.Bool("strace", false, "enable strace logs")
+ platform = flag.String("platform", "ptrace", "platform to run on")
+ network = flag.String("network", "none", "network stack to run on (sandbox, host, none)")
+ useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp")
+ fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode")
+ overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay")
+ vfs2 = flag.Bool("vfs2", false, "enable VFS2")
+ parallel = flag.Bool("parallel", false, "run tests in parallel")
+ runscPath = flag.String("runsc", "", "path to runsc binary")
+
+ addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests")
+)
+
+// runTestCaseNative runs the test case directly on the host machine.
+func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
+ // These tests might be running in parallel, so make sure they have a
+ // unique test temp dir.
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
+ if err != nil {
+ t.Fatalf("could not create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ // Replace TEST_TMPDIR in the current environment with something
+ // unique.
+ env := os.Environ()
+ newEnvVar := "TEST_TMPDIR=" + tmpDir
+ var found bool
+ for i, kv := range env {
+ if strings.HasPrefix(kv, "TEST_TMPDIR=") {
+ env[i] = newEnvVar
+ found = true
+ break
+ }
+ }
+ if !found {
+ env = append(env, newEnvVar)
+ }
+ // Remove env variables that cause the gunit binary to write output
+ // files, since they will stomp on each other, and on the output files
+ // from this go test.
+ env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"})
+
+ // Remove shard env variables so that the gunit binary does not try to
+ // interpret them.
+ env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
+
+ if *addUDSTree {
+ socketDir, cleanup, err := uds.CreateSocketTree("/tmp")
+ if err != nil {
+ t.Fatalf("failed to create socket tree: %v", err)
+ }
+ defer cleanup()
+
+ env = append(env, "TEST_UDS_TREE="+socketDir)
+ // On Linux, the concept of an "attach" location doesn't exist.
+ // Just pass the same path to make these tests identical.
+ env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir)
+ }
+
+ cmd := exec.Command(testBin, tc.Args()...)
+ cmd.Env = env
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ if err := cmd.Run(); err != nil {
+ ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus)
+ t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus())
+ }
+}
+
+// runRunsc runs spec in runsc in a standard test configuration.
+//
+// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR.
+//
+// Returns an error if the sandboxed application exits non-zero.
+func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ return fmt.Errorf("SetupBundleDir failed: %v", err)
+ }
+ defer cleanup()
+
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ return fmt.Errorf("SetupRootDir failed: %v", err)
+ }
+ defer cleanup()
+
+ name := tc.FullName()
+ id := testutil.RandomContainerID()
+ log.Infof("Running test %q in container %q", name, id)
+ specutils.LogSpec(spec)
+
+ args := []string{
+ "-root", rootDir,
+ "-network", *network,
+ "-log-format=text",
+ "-TESTONLY-unsafe-nonroot=true",
+ "-net-raw=true",
+ fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM),
+ "-watchdog-action=panic",
+ "-platform", *platform,
+ "-file-access", *fileAccess,
+ }
+ if *overlay {
+ args = append(args, "-overlay")
+ }
+ if *vfs2 {
+ args = append(args, "-vfs2")
+ }
+ if *debug {
+ args = append(args, "-debug", "-log-packets=true")
+ }
+ if *strace {
+ args = append(args, "-strace")
+ }
+ if *addUDSTree {
+ args = append(args, "-fsgofer-host-uds")
+ }
+
+ undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")
+ if ok {
+ tdir := filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1))
+ if err := os.MkdirAll(tdir, 0755); err != nil {
+ return fmt.Errorf("could not create test dir: %v", err)
+ }
+ debugLogDir, err := ioutil.TempDir(tdir, "runsc")
+ if err != nil {
+ return fmt.Errorf("could not create temp dir: %v", err)
+ }
+ debugLogDir += "/"
+ log.Infof("runsc logs: %s", debugLogDir)
+ args = append(args, "-debug-log", debugLogDir)
+
+ // The default -log sends messages to stderr, which makes the test log
+ // difficult to read. Instead, drop them when the debug log is enabled,
+ // since that is a better place for these messages.
+ args = append(args, "-log=/dev/null")
+ }
+
+ // The current process doesn't have CAP_SYS_ADMIN; create a user namespace
+ // and run as root inside it to gain the capability.
+ rArgs := append(args, "run", "--bundle", bundleDir, id)
+ cmd := exec.Command(*runscPath, rArgs...)
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS,
+ // Set current user/group as root inside the namespace.
+ UidMappings: []syscall.SysProcIDMap{
+ {ContainerID: 0, HostID: os.Getuid(), Size: 1},
+ },
+ GidMappings: []syscall.SysProcIDMap{
+ {ContainerID: 0, HostID: os.Getgid(), Size: 1},
+ },
+ GidMappingsEnableSetgroups: false,
+ Credential: &syscall.Credential{
+ Uid: 0,
+ Gid: 0,
+ },
+ }
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ sig := make(chan os.Signal, 1)
+ defer close(sig)
+ signal.Notify(sig, syscall.SIGTERM)
+ defer signal.Stop(sig)
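+ // If we receive SIGTERM (e.g. from a test timeout), dump the sandbox
+ // stacks with "runsc debug --stacks", then forward SIGTERM to the
+ // sandbox, which is configured above to panic on it.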
+ go func() {
+ s, ok := <-sig
+ if !ok {
+ return
+ }
+ log.Warningf("%s: Got signal: %v", name, s)
+ done := make(chan bool, 1)
+ dArgs := append([]string{}, args...)
+ dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id)
+ go func(dArgs []string) {
+ cmd := exec.Command(*runscPath, dArgs...)
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ cmd.Run()
+ done <- true
+ }(dArgs)
+
+ timeout := time.After(3 * time.Second)
+ select {
+ case <-timeout:
+ log.Infof("runsc debug --stacks is timeouted")
+ case <-done:
+ }
+
+ log.Warningf("Send SIGTERM to the sandbox process")
+ dArgs = append(args, "debug",
+ fmt.Sprintf("--signal=%d", syscall.SIGTERM),
+ id)
+ cmd := exec.Command(*runscPath, dArgs...)
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ cmd.Run()
+ }()
+
+ err = cmd.Run()
+ if err == nil {
+ // If the test passed, then we erase the log directory. This speeds up
+ // uploading logs in continuous integration & saves on disk space.
+ os.RemoveAll(undeclaredOutputsDir)
+ }
+
+ return err
+}
+
+// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing.
+func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
+ socketDir, cleanup, err := uds.CreateSocketTree("/tmp")
+ if err != nil {
+ return nil, fmt.Errorf("failed to create socket tree: %v", err)
+ }
+
+ // Standard access to entire tree.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets",
+ Source: socketDir,
+ Type: "bind",
+ })
+
+ // Individual attach points for each socket to test mounts that attach
+ // directly to the sockets.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/stream/echo",
+ Source: filepath.Join(socketDir, "stream/echo"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/stream/nonlistening",
+ Source: filepath.Join(socketDir, "stream/nonlistening"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/seqpacket/echo",
+ Source: filepath.Join(socketDir, "seqpacket/echo"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/seqpacket/nonlistening",
+ Source: filepath.Join(socketDir, "seqpacket/nonlistening"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/dgram/null",
+ Source: filepath.Join(socketDir, "dgram/null"),
+ Type: "bind",
+ })
+
+ spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets")
+ spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach")
+
+ return cleanup, nil
+}
+
+// runTestCaseRunsc runs the test case in runsc.
+func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
+ // Run a new container with the test executable and filter for the
+ // given test suite and name.
+ spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...)
+
+ // Mark the root as writable, as some tests attempt to
+ // write to the rootfs, and expect EACCES, not EROFS.
+ spec.Root.Readonly = false
+
+ // The test spec comes with pre-defined mounts that we don't want. Reset them.
+ spec.Mounts = nil
+ testTmpDir := "/tmp"
+ if *useTmpfs {
+ // Forces '/tmp' to be mounted as tmpfs, otherwise tests that rely on
+ // features only available in gVisor's internal tmpfs may fail.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp",
+ Type: "tmpfs",
+ })
+ } else {
+ // Use a gofer-backed directory as '/tmp'.
+ //
+ // Tests might be running in parallel, so make sure each has a
+ // unique test temp dir.
+ //
+ // Some tests (e.g., sticky) access this mount from other
+ // users, so make sure it is world-accessible.
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
+ if err != nil {
+ t.Fatalf("could not create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ if err := os.Chmod(tmpDir, 0777); err != nil {
+ t.Fatalf("could not chmod temp dir: %v", err)
+ }
+
+ // "/tmp" is not replaced with a tmpfs mount inside the sandbox
+ // when it's not empty. This ensures that testTmpDir uses gofer
+ // in exclusive mode.
+ testTmpDir = tmpDir
+ if *fileAccess == "shared" {
+ // All external mounts except the root mount are shared.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: "/tmp",
+ Source: tmpDir,
+ })
+ testTmpDir = "/tmp"
+ }
+ }
+
+ // Set environment variables that indicate we are running in gVisor with
+ // the given platform, network, and filesystem stack.
+ platformVar := "TEST_ON_GVISOR"
+ networkVar := "GVISOR_NETWORK"
+ env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network)
+ vfsVar := "GVISOR_VFS"
+ if *vfs2 {
+ env = append(env, vfsVar+"=VFS2")
+ } else {
+ env = append(env, vfsVar+"=VFS1")
+ }
+
+ // Remove env variables that cause the gunit binary to write output
+ // files, since they will stomp on eachother, and on the output files
+ // from this go test.
+ env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"})
+
+ // Remove shard env variables so that the gunit binary does not try to
+ // intepret them.
+ env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
+
+ // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
+ // be backed by tmpfs.
+ env = filterEnv(env, []string{"TEST_TMPDIR"})
+ env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir))
+
+ spec.Process.Env = env
+
+ if *addUDSTree {
+ cleanup, err := setupUDSTree(spec)
+ if err != nil {
+ t.Fatalf("error creating UDS tree: %v", err)
+ }
+ defer cleanup()
+ }
+
+ if err := runRunsc(tc, spec); err != nil {
+ t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err)
+ }
+}
+
+// filterEnv returns an environment with the excluded variables removed.
+func filterEnv(env, exclude []string) []string {
+ var out []string
+ for _, kv := range env {
+ ok := true
+ for _, k := range exclude {
+ if strings.HasPrefix(kv, k+"=") {
+ ok = false
+ break
+ }
+ }
+ if ok {
+ out = append(out, kv)
+ }
+ }
+ return out
+}
+
+func fatalf(s string, args ...interface{}) {
+ fmt.Fprintf(os.Stderr, s+"\n", args...)
+ os.Exit(1)
+}
+
+func matchString(a, b string) (bool, error) {
+ return a == b, nil
+}
+
+func main() {
+ flag.Parse()
+ if flag.NArg() != 1 {
+ fatalf("test must be provided")
+ }
+ testBin := flag.Args()[0] // Only argument.
+
+ log.SetLevel(log.Info)
+ if *debug {
+ log.SetLevel(log.Debug)
+ }
+
+ if *platform != "native" && *runscPath == "" {
+ if err := testutil.ConfigureExePath(); err != nil {
+ panic(err.Error())
+ }
+ *runscPath = specutils.ExePath
+ }
+
+ // Make sure stdout and stderr are opened with O_APPEND, otherwise logs
+ // from outside the sandbox can (and will) stomp on logs from inside
+ // the sandbox.
+ for _, f := range []*os.File{os.Stdout, os.Stderr} {
+ flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0)
+ if err != nil {
+ fatalf("error getting file flags for %v: %v", f, err)
+ }
+ if flags&unix.O_APPEND == 0 {
+ flags |= unix.O_APPEND
+ if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil {
+ fatalf("error setting file flags for %v: %v", f, err)
+ }
+ }
+ }
+
+ // Get all test cases in each binary.
+ testCases, err := gtest.ParseTestCases(testBin, true)
+ if err != nil {
+ fatalf("ParseTestCases(%q) failed: %v", testBin, err)
+ }
+
+ // Get subset of tests corresponding to shard.
+ indices, err := testutil.TestIndicesForShard(len(testCases))
+ if err != nil {
+ fatalf("TestsForShard() failed: %v", err)
+ }
+
+ // Resolve the absolute path for the binary.
+ testBin, err = filepath.Abs(testBin)
+ if err != nil {
+ fatalf("Abs() failed: %v", err)
+ }
+
+ // Run the tests.
+ var tests []testing.InternalTest
+ for _, tci := range indices {
+ // Capture tc.
+ tc := testCases[tci]
+ tests = append(tests, testing.InternalTest{
+ Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name),
+ F: func(t *testing.T) {
+ if *parallel {
+ t.Parallel()
+ }
+ if *platform == "native" {
+ // Run the test case on host.
+ runTestCaseNative(testBin, tc, t)
+ } else {
+ // Run the test case in runsc.
+ runTestCaseRunsc(testBin, tc, t)
+ }
+ },
+ })
+ }
+
+ testing.Main(matchString, tests, nil, nil)
+}
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
new file mode 100644
index 000000000..022de5ff7
--- /dev/null
+++ b/test/runtimes/BUILD
@@ -0,0 +1,33 @@
+load("//test/runtimes:defs.bzl", "runtime_test")
+
+package(licenses = ["notice"])
+
+runtime_test(
+ name = "go1.12",
+ exclude_file = "exclude_go1.12.csv",
+ lang = "go",
+)
+
+runtime_test(
+ name = "java11",
+ exclude_file = "exclude_java11.csv",
+ lang = "java",
+)
+
+runtime_test(
+ name = "nodejs12.4.0",
+ exclude_file = "exclude_nodejs12.4.0.csv",
+ lang = "nodejs",
+)
+
+runtime_test(
+ name = "php7.3.6",
+ exclude_file = "exclude_php7.3.6.csv",
+ lang = "php",
+)
+
+runtime_test(
+ name = "python3.7.3",
+ exclude_file = "exclude_python3.7.3.csv",
+ lang = "python",
+)
diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl
new file mode 100644
index 000000000..dc3667f05
--- /dev/null
+++ b/test/runtimes/defs.bzl
@@ -0,0 +1,79 @@
+"""Defines a rule for runtime test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def _runtime_test_impl(ctx):
+ # Construct arguments.
+ args = [
+ "--lang",
+ ctx.attr.lang,
+ "--image",
+ ctx.attr.image,
+ ]
+ if ctx.attr.exclude_file:
+ args += [
+ "--exclude_file",
+ ctx.files.exclude_file[0].short_path,
+ ]
+
+ # Build a runner.
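+ # The resulting script is a one-line wrapper; for example (values
+ # illustrative):
+ #   test/runtimes/runner/runner --lang go --image go1.12 --exclude_file test/runtimes/exclude_go1.12.csv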
+ runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "%s %s\n" % (ctx.files._runner[0].short_path, " ".join(args)),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return the runner.
+ return [DefaultInfo(
+ executable = runner,
+ runfiles = ctx.runfiles(
+ files = ctx.files._runner + ctx.files.exclude_file + ctx.files._proctor,
+ collect_default = True,
+ collect_data = True,
+ ),
+ )]
+
+_runtime_test = rule(
+ implementation = _runtime_test_impl,
+ attrs = {
+ "image": attr.string(
+ mandatory = False,
+ ),
+ "lang": attr.string(
+ mandatory = True,
+ ),
+ "exclude_file": attr.label(
+ mandatory = False,
+ allow_single_file = True,
+ ),
+ "_runner": attr.label(
+ default = "//test/runtimes/runner:runner",
+ ),
+ "_proctor": attr.label(
+ default = "//test/runtimes/proctor:proctor",
+ ),
+ },
+ test = True,
+)
+
+def runtime_test(name, **kwargs):
+ _runtime_test(
+ name = name,
+ image = name, # Resolved as images/runtimes/%s.
+ tags = [
+ "local",
+ "manual",
+ ],
+ **kwargs
+ )
+
+def exclude_test(name, exclude_file):
+ """Test that a exclude file parses correctly."""
+ go_test(
+ name = name + "_exclude_test",
+ library = ":runner",
+ srcs = ["exclude_test.go"],
+ args = ["--exclude_file", "test/runtimes/" + exclude_file],
+ data = [exclude_file],
+ )
diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude_go1.12.csv
new file mode 100644
index 000000000..8c8ae0c5d
--- /dev/null
+++ b/test/runtimes/exclude_go1.12.csv
@@ -0,0 +1,16 @@
+test name,bug id,comment
+cgo_errors,,FLAKY
+cgo_test,,FLAKY
+go_test:cmd/go,,FLAKY
+go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing
+go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug.
+go_test:os,b/118780122,we have a pollable filesystem but that's a surprise
+go_test:os/signal,b/118780860,/dev/pts not properly supported
+go_test:runtime,b/118782341,sigtrap not reported or caught or something
+go_test:syscall,b/118781998,bad bytes -- bad mem addr
+race,b/118782931,thread sanitizer. Works as intended: b/62219744.
+runtime:cpu124,b/118778254,segmentation fault
+test:0_1,,FLAKY
+testasan,,
+testcarchive,b/118782924,no sigpipe
+testshared,,FLAKY
diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude_java11.csv
new file mode 100644
index 000000000..c012e5a56
--- /dev/null
+++ b/test/runtimes/exclude_java11.csv
@@ -0,0 +1,126 @@
+test name,bug id,comment
+com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker
+com/sun/jdi/NashornPopFrameTest.java,,
+com/sun/jdi/ProcessAttachTest.java,,
+com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker
+com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,,
+com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,,
+com/sun/tools/attach/AttachSelf.java,,
+com/sun/tools/attach/BasicTests.java,,
+com/sun/tools/attach/PermissionTest.java,,
+com/sun/tools/attach/StartManagementAgent.java,,
+com/sun/tools/attach/TempDirTest.java,,
+com/sun/tools/attach/modules/Driver.java,,
+java/lang/Character/CheckScript.java,,Fails in Docker
+java/lang/Character/CheckUnicode.java,,Fails in Docker
+java/lang/Class/GetPackageBootLoaderChildLayer.java,,
+java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker
+java/lang/String/nativeEncoding/StringPlatformChars.java,,
+java/net/DatagramSocket/ReuseAddressTest.java,,
+java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345,
+java/net/Inet4Address/PingThis.java,,
+java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103,
+java/net/MulticastSocket/MulticastTTL.java,,
+java/net/MulticastSocket/Promiscuous.java,,
+java/net/MulticastSocket/SetLoopbackMode.java,,
+java/net/MulticastSocket/SetTTLAndGetTTL.java,,
+java/net/MulticastSocket/Test.java,,
+java/net/MulticastSocket/TestDefaults.java,,
+java/net/MulticastSocket/TimeToLive.java,,
+java/net/NetworkInterface/NetworkInterfaceStreamTest.java,,
+java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported
+java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor
+java/net/Socket/UrgentDataTest.java,b/111515323,
+java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default
+java/net/SocketOption/OptionsTest.java,,Fails in Docker
+java/net/SocketOption/TcpKeepAliveTest.java,,
+java/net/SocketPermission/SocketPermissionTest.java,,
+java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker
+java/net/httpclient/RequestBuilderTest.java,,Fails in Docker
+java/net/httpclient/ShortResponseBody.java,,
+java/net/httpclient/ShortResponseBodyWithRetry.java,,
+java/nio/channels/AsyncCloseAndInterrupt.java,,
+java/nio/channels/AsynchronousServerSocketChannel/Basic.java,,
+java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable
+java/nio/channels/DatagramChannel/BasicMulticastTests.java,,
+java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker
+java/nio/channels/DatagramChannel/UseDGWithIPv6.java,,
+java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker
+java/nio/channels/Selector/OutOfBand.java,,
+java/nio/channels/Selector/SelectWithConsumer.java,,Flaky
+java/nio/channels/ServerSocketChannel/SocketOptionTests.java,,
+java/nio/channels/SocketChannel/LingerOnClose.java,,
+java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901,
+java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker
+java/rmi/activation/Activatable/extLoadedImpl/ext.sh,,
+java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,,
+java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
+java/util/Calendar/JapaneseEraNameTest.java,,
+java/util/Currency/CurrencyTest.java,,Fails in Docker
+java/util/Currency/ValidateISO4217.java,,Fails in Docker
+java/util/Locale/LSRDataTest.java,,
+java/util/concurrent/locks/Lock/TimedAcquireLeak.java,,
+java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker
+java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,,
+java/util/logging/TestLoggerWeakRefLeak.java,,
+javax/imageio/AppletResourceTest.java,,
+javax/management/security/HashedPasswordFileTest.java,,
+javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker
+javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,,
+jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,,
+jdk/jfr/event/runtime/TestThreadParkEvent.java,,
+jdk/jfr/event/sampling/TestNative.java,,
+jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,,
+jdk/jfr/jcmd/TestJcmdConfigure.java,,
+jdk/jfr/jcmd/TestJcmdDump.java,,
+jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,,
+jdk/jfr/jcmd/TestJcmdDumpLimited.java,,
+jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,,
+jdk/jfr/jcmd/TestJcmdLegacy.java,,
+jdk/jfr/jcmd/TestJcmdSaveToFile.java,,
+jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,,
+jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,,
+jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,,
+jdk/jfr/jcmd/TestJcmdStartStopDefault.java,,
+jdk/jfr/jcmd/TestJcmdStartWithOptions.java,,
+jdk/jfr/jcmd/TestJcmdStartWithSettings.java,,
+jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,,
+jdk/jfr/jvm/TestJfrJavaBase.java,,
+jdk/jfr/startupargs/TestStartRecording.java,,
+jdk/modules/incubator/ImageModules.java,,
+jdk/net/Sockets/ExtOptionTest.java,,
+jdk/net/Sockets/QuickAckTest.java,,
+lib/security/cacerts/VerifyCACerts.java,,
+sun/management/jmxremote/bootstrap/CustomLauncherTest.java,,
+sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,,
+sun/management/jmxremote/bootstrap/LocalManagementTest.java,,
+sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,,
+sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,,
+sun/management/jmxremote/startstop/JMXStartStopTest.java,,
+sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,,
+sun/management/jmxremote/startstop/JMXStatusTest.java,,
+sun/text/resources/LocaleDataTest.java,,
+sun/tools/jcmd/TestJcmdSanity.java,,
+sun/tools/jhsdb/AlternateHashingTest.java,,
+sun/tools/jhsdb/BasicLauncherTest.java,,
+sun/tools/jhsdb/HeapDumpTest.java,,
+sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,,
+sun/tools/jinfo/BasicJInfoTest.java,,
+sun/tools/jinfo/JInfoTest.java,,
+sun/tools/jmap/BasicJMapTest.java,,
+sun/tools/jstack/BasicJStackTest.java,,
+sun/tools/jstack/DeadlockDetectionTest.java,,
+sun/tools/jstatd/TestJstatdExternalRegistry.java,,
+sun/tools/jstatd/TestJstatdPort.java,,Flaky
+sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky
+sun/util/calendar/zi/TestZoneInfo310.java,,
+tools/jar/modularJar/Basic.java,,
+tools/jar/multiRelease/Basic.java,,
+tools/jimage/JImageExtractTest.java,,
+tools/jimage/JImageTest.java,,
+tools/jlink/JLinkTest.java,,
+tools/jlink/plugins/IncludeLocalesPluginTest.java,,
+tools/jmod/hashes/HashesTest.java,,
+tools/launcher/BigJar.java,b/111611473,
+tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,,
diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude_nodejs12.4.0.csv
new file mode 100644
index 000000000..4ab4e2927
--- /dev/null
+++ b/test/runtimes/exclude_nodejs12.4.0.csv
@@ -0,0 +1,47 @@
+test name,bug id,comment
+benchmark/test-benchmark-fs.js,,
+benchmark/test-benchmark-module.js,,
+benchmark/test-benchmark-napi.js,,
+doctool/test-make-doc.js,b/68848110,Expected to fail.
+fixtures/test-error-first-line-offset.js,,
+fixtures/test-fs-readfile-error.js,,
+fixtures/test-fs-stat-sync-overflow.js,,
+internet/test-dgram-broadcast-multi-process.js,,
+internet/test-dgram-multicast-multi-process.js,,
+internet/test-dgram-multicast-set-interface-lo.js,,
+parallel/test-cluster-dgram-reuse.js,b/64024294,
+parallel/test-dgram-bind-fd.js,b/132447356,
+parallel/test-dgram-create-socket-handle-fd.js,b/132447238,
+parallel/test-dgram-createSocket-type.js,b/68847739,
+parallel/test-dgram-socket-buffer-size.js,b/68847921,
+parallel/test-fs-access.js,,
+parallel/test-fs-write-stream-double-close.js,,
+parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
+parallel/test-fs-write-stream.js,,
+parallel/test-http2-respond-file-error-pipe-offset.js,,
+parallel/test-os.js,,
+parallel/test-process-uid-gid.js,,
+pseudo-tty/test-assert-colors.js,,
+pseudo-tty/test-assert-no-color.js,,
+pseudo-tty/test-assert-position-indicator.js,,
+pseudo-tty/test-async-wrap-getasyncid-tty.js,,
+pseudo-tty/test-fatal-error.js,,
+pseudo-tty/test-handle-wrap-isrefed-tty.js,,
+pseudo-tty/test-readable-tty-keepalive.js,,
+pseudo-tty/test-set-raw-mode-reset-process-exit.js,,
+pseudo-tty/test-set-raw-mode-reset-signal.js,,
+pseudo-tty/test-set-raw-mode-reset.js,,
+pseudo-tty/test-stderr-stdout-handle-sigwinch.js,,
+pseudo-tty/test-stdout-read.js,,
+pseudo-tty/test-tty-color-support.js,,
+pseudo-tty/test-tty-isatty.js,,
+pseudo-tty/test-tty-stdin-call-end.js,,
+pseudo-tty/test-tty-stdin-end.js,,
+pseudo-tty/test-stdin-write.js,,
+pseudo-tty/test-tty-stdout-end.js,,
+pseudo-tty/test-tty-stdout-resize.js,,
+pseudo-tty/test-tty-stream-constructors.js,,
+pseudo-tty/test-tty-window-size.js,,
+pseudo-tty/test-tty-wrap.js,,
+pummel/test-net-pingpong.js,,
+pummel/test-vm-memleak.js,,
diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude_php7.3.6.csv
new file mode 100644
index 000000000..456bf7487
--- /dev/null
+++ b/test/runtimes/exclude_php7.3.6.csv
@@ -0,0 +1,29 @@
+test name,bug id,comment
+ext/intl/tests/bug77895.phpt,,
+ext/intl/tests/dateformat_bug65683_2.phpt,,
+ext/mbstring/tests/bug76319.phpt,,
+ext/mbstring/tests/bug76958.phpt,,
+ext/mbstring/tests/bug77025.phpt,,
+ext/mbstring/tests/bug77165.phpt,,
+ext/mbstring/tests/bug77454.phpt,,
+ext/mbstring/tests/mb_convert_encoding_leak.phpt,,
+ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,,
+ext/standard/tests/file/filetype_variation.phpt,,
+ext/standard/tests/file/fopen_variation19.phpt,,
+ext/standard/tests/file/php_fd_wrapper_01.phpt,,
+ext/standard/tests/file/php_fd_wrapper_02.phpt,,
+ext/standard/tests/file/php_fd_wrapper_03.phpt,,
+ext/standard/tests/file/php_fd_wrapper_04.phpt,,
+ext/standard/tests/file/realpath_bug77484.phpt,,
+ext/standard/tests/file/rename_variation.phpt,b/68717309,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,,
+ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,,
+ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,,
+ext/standard/tests/network/bug20134.phpt,,
+tests/output/stream_isatty_err.phpt,b/68720279,
+tests/output/stream_isatty_in-err.phpt,b/68720282,
+tests/output/stream_isatty_in-out-err.phpt,,
+tests/output/stream_isatty_in-out.phpt,b/68720299,
+tests/output/stream_isatty_out-err.phpt,b/68720311,
+tests/output/stream_isatty_out.phpt,b/68720325,
diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude_python3.7.3.csv
new file mode 100644
index 000000000..2b9947212
--- /dev/null
+++ b/test/runtimes/exclude_python3.7.3.csv
@@ -0,0 +1,27 @@
+test name,bug id,comment
+test_asynchat,b/76031995,SO_REUSEADDR
+test_asyncio,,Fails on Docker.
+test_asyncore,b/76031995,SO_REUSEADDR
+test_epoll,,
+test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode.
+test_ftplib,,Fails in Docker
+test_httplib,b/76031995,SO_REUSEADDR
+test_imaplib,,
+test_logging,,
+test_multiprocessing_fork,,Flaky. Sometimes times out.
+test_multiprocessing_forkserver,,Flaky. Sometimes times out.
+test_multiprocessing_main_handling,,Flaky. Sometimes times out.
+test_multiprocessing_spawn,,Flaky. Sometimes times out.
+test_nntplib,b/76031995,tests should not set SO_REUSEADDR
+test_poplib,,Fails on Docker
+test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted
+test_pty,b/76157709,out of pty devices
+test_readline,b/76157709,out of pty devices
+test_resource,b/76174079,
+test_selectors,b/76116849,OSError not raised with epoll
+test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets
+test_socket,b/75983380,
+test_ssl,b/76031995,SO_REUSEADDR
+test_subprocess,,
+test_support,b/76031995,SO_REUSEADDR
+test_telnetlib,b/76031995,SO_REUSEADDR
diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD
new file mode 100644
index 000000000..f76e2ddc0
--- /dev/null
+++ b/test/runtimes/proctor/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_binary", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "proctor",
+ srcs = [
+ "go.go",
+ "java.go",
+ "nodejs.go",
+ "php.go",
+ "proctor.go",
+ "python.go",
+ ],
+ pure = True,
+ visibility = ["//test/runtimes:__pkg__"],
+)
+
+go_test(
+ name = "proctor_test",
+ size = "small",
+ srcs = ["proctor_test.go"],
+ library = ":proctor",
+ pure = True,
+ deps = [
+ "//pkg/test/testutil",
+ ],
+)
diff --git a/test/runtimes/proctor/go.go b/test/runtimes/proctor/go.go
new file mode 100644
index 000000000..3e2d5d8db
--- /dev/null
+++ b/test/runtimes/proctor/go.go
@@ -0,0 +1,90 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "regexp"
+ "strings"
+)
+
+var (
+ goTestRegEx = regexp.MustCompile(`^.+\.go$`)
+
+ // Directories with .dir contain helper files for tests.
+ // Exclude benchmarks and stress tests.
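+ // For example, this drops "bench/garbage.go" and files under helper
+ // directories such as "fixedbugs/bug000.dir/" (paths illustrative).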
+ goDirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
+)
+
+// Location of Go tests on disk.
+const goTestDir = "/usr/local/go/test"
+
+// goRunner implements TestRunner for Go.
+//
+// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
+// "Go tool tests" are found and executed using `go tool dist test`. "Go tests
+// on disk" are found in the /usr/local/go/test directory and are executed
+// using `go run run.go`.
+type goRunner struct{}
+
+var _ TestRunner = goRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (goRunner) ListTests() ([]string, error) {
+ // Go tool dist test tests.
+ args := []string{"tool", "dist", "test", "-list"}
+ cmd := exec.Command("go", args...)
+ cmd.Stderr = os.Stderr
+ out, err := cmd.Output()
+ if err != nil {
+ return nil, fmt.Errorf("failed to list: %v", err)
+ }
+ var toolSlice []string
+ for _, test := range strings.Split(string(out), "\n") {
+ toolSlice = append(toolSlice, test)
+ }
+
+ // Go tests on disk.
+ diskSlice, err := search(goTestDir, goTestRegEx)
+ if err != nil {
+ return nil, err
+ }
+ // Remove items under bench/ and stress/, as well as files in .dir helper
+ // directories.
+ diskFiltered := diskSlice[:0]
+ for _, file := range diskSlice {
+ if !goDirFilter.MatchString(file) {
+ diskFiltered = append(diskFiltered, file)
+ }
+ }
+
+ return append(toolSlice, diskFiltered...), nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (goRunner) TestCmd(test string) *exec.Cmd {
+ // A ".go" suffix indicates a Go test on disk; anything else is a Go
+ // tool test.
+ if strings.HasSuffix(test, ".go") {
+ // Test has suffix ".go" which indicates a disk test, run it as such.
+ cmd := exec.Command("go", "run", "run.go", "-v", "--", test)
+ cmd.Dir = goTestDir
+ return cmd
+ }
+
+ // No ".go" suffix, run as a tool test.
+ return exec.Command("go", "tool", "dist", "test", "-run", test)
+}
diff --git a/test/runtimes/proctor/java.go b/test/runtimes/proctor/java.go
new file mode 100644
index 000000000..8b362029d
--- /dev/null
+++ b/test/runtimes/proctor/java.go
@@ -0,0 +1,71 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "regexp"
+ "strings"
+)
+
+// Directories to exclude from tests.
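+// Note the trailing " " alternative: it also drops any listed name that
+// contains a space, presumably stray non-test output from jtreg.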
+var javaExclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
+
+// Location of java tests.
+const javaTestDir = "/root/test/jdk"
+
+// javaRunner implements TestRunner for Java.
+type javaRunner struct{}
+
+var _ TestRunner = javaRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (javaRunner) ListTests() ([]string, error) {
+ args := []string{
+ "-dir:" + javaTestDir,
+ "-ignore:quiet",
+ "-a",
+ "-listtests",
+ ":jdk_core",
+ ":jdk_svc",
+ ":jdk_sound",
+ ":jdk_imageio",
+ }
+ cmd := exec.Command("jtreg", args...)
+ cmd.Stderr = os.Stderr
+ out, err := cmd.Output()
+ if err != nil {
+ return nil, fmt.Errorf("jtreg -listtests : %v", err)
+ }
+ var testSlice []string
+ for _, test := range strings.Split(string(out), "\n") {
+ if !javaExclDirs.MatchString(test) {
+ testSlice = append(testSlice, test)
+ }
+ }
+ return testSlice, nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (javaRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{
+ "-noreport",
+ "-dir:" + javaTestDir,
+ test,
+ }
+ return exec.Command("jtreg", args...)
+}
diff --git a/test/runtimes/proctor/nodejs.go b/test/runtimes/proctor/nodejs.go
new file mode 100644
index 000000000..bd57db444
--- /dev/null
+++ b/test/runtimes/proctor/nodejs.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "os/exec"
+ "path/filepath"
+ "regexp"
+)
+
+var nodejsTestRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
+
+// Location of nodejs tests relative to working dir.
+const nodejsTestDir = "test"
+
+// nodejsRunner implements TestRunner for NodeJS.
+type nodejsRunner struct{}
+
+var _ TestRunner = nodejsRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (nodejsRunner) ListTests() ([]string, error) {
+ testSlice, err := search(nodejsTestDir, nodejsTestRegEx)
+ if err != nil {
+ return nil, err
+ }
+ return testSlice, nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (nodejsRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{filepath.Join("tools", "test.py"), test}
+ return exec.Command("/usr/bin/python", args...)
+}
diff --git a/test/runtimes/proctor/php.go b/test/runtimes/proctor/php.go
new file mode 100644
index 000000000..9115040e1
--- /dev/null
+++ b/test/runtimes/proctor/php.go
@@ -0,0 +1,42 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "os/exec"
+ "regexp"
+)
+
+var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`)
+
+// phpRunner implements TestRunner for PHP.
+type phpRunner struct{}
+
+var _ TestRunner = phpRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (phpRunner) ListTests() ([]string, error) {
+ testSlice, err := search(".", phpTestRegEx)
+ if err != nil {
+ return nil, err
+ }
+ return testSlice, nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (phpRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{"test", "TESTS=" + test}
+ return exec.Command("make", args...)
+}
diff --git a/test/runtimes/proctor/proctor.go b/test/runtimes/proctor/proctor.go
new file mode 100644
index 000000000..b54abe434
--- /dev/null
+++ b/test/runtimes/proctor/proctor.go
@@ -0,0 +1,163 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary proctor runs tests for a particular runtime. It is meant to be
+// included in Docker images for all runtime tests.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "regexp"
+ "syscall"
+)
+
+// TestRunner is an interface that must be implemented for each runtime
+// integrated with proctor.
+type TestRunner interface {
+ // ListTests returns a string slice of tests available to run.
+ ListTests() ([]string, error)
+
+ // TestCmd returns an *exec.Cmd that will run the given test.
+ TestCmd(test string) *exec.Cmd
+}
+
+var (
+ runtime = flag.String("runtime", "", "name of runtime")
+ list = flag.Bool("list", false, "list all available tests")
+ testName = flag.String("test", "", "run a single test from the list of available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
+)
+
+func main() {
+ flag.Parse()
+
+ if *pause {
+ pauseAndReap()
+ panic("pauseAndReap should never return")
+ }
+
+ if *runtime == "" {
+ log.Fatalf("runtime flag must be provided")
+ }
+
+ tr, err := testRunnerForRuntime(*runtime)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
+
+ // List tests.
+ if *list {
+ tests, err := tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to list tests: %v", err)
+ }
+ for _, test := range tests {
+ fmt.Println(test)
+ }
+ return
+ }
+
+ var tests []string
+ if *testName == "" {
+ // Run every test.
+ tests, err = tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to get all tests: %v", err)
+ }
+ } else {
+ // Run a single test.
+ tests = []string{*testName}
+ }
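+ // Run each test in sequence; the first failure aborts the run.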
+ for _, test := range tests {
+ cmd := tr.TestCmd(test)
+ cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
+ if err := cmd.Run(); err != nil {
+ log.Fatalf("FAIL: %v", err)
+ }
+ }
+}
+
+// testRunnerForRuntime returns a new TestRunner for the given runtime.
+func testRunnerForRuntime(runtime string) (TestRunner, error) {
+ switch runtime {
+ case "go":
+ return goRunner{}, nil
+ case "java":
+ return javaRunner{}, nil
+ case "nodejs":
+ return nodejsRunner{}, nil
+ case "php":
+ return phpRunner{}, nil
+ case "python":
+ return pythonRunner{}, nil
+ }
+ return nil, fmt.Errorf("invalid runtime %q", runtime)
+}
+
+// pauseAndReap is like init. It runs forever and reaps any children.
+func pauseAndReap() {
+ // Get notified of any new children.
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, syscall.SIGCHLD)
+
+ for {
+ if _, ok := <-ch; !ok {
+ // Channel closed. This should not happen.
+ panic("signal channel closed")
+ }
+
+ // Reap the child.
+ for {
+ if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 {
+ break
+ }
+ }
+ }
+}
+
+// search is a helper function to find tests in the given directory that match
+// the regex.
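+//
+// Matching is applied to the base name only, and the returned paths are
+// relative to root (e.g. "dir/nest/test-ok.tc" in the accompanying unit
+// test).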
+func search(root string, testFilter *regexp.Regexp) ([]string, error) {
+ var testSlice []string
+
+ err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+
+ name := filepath.Base(path)
+
+ if info.IsDir() || !testFilter.MatchString(name) {
+ return nil
+ }
+
+ relPath, err := filepath.Rel(root, path)
+ if err != nil {
+ return err
+ }
+ testSlice = append(testSlice, relPath)
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("walking %q: %v", root, err)
+ }
+
+ return testSlice, nil
+}
diff --git a/test/runtimes/proctor/proctor_test.go b/test/runtimes/proctor/proctor_test.go
new file mode 100644
index 000000000..6ef2de085
--- /dev/null
+++ b/test/runtimes/proctor/proctor_test.go
@@ -0,0 +1,127 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+func touch(t *testing.T, name string) {
+ t.Helper()
+ f, err := os.Create(name)
+ if err != nil {
+ t.Fatalf("error creating file %q: %v", name, err)
+ }
+ if err := f.Close(); err != nil {
+ t.Fatalf("error closing file %q: %v", name, err)
+ }
+}
+
+func TestSearchEmptyDir(t *testing.T) {
+ td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest")
+ if err != nil {
+ t.Fatalf("error creating searchtest: %v", err)
+ }
+ defer os.RemoveAll(td)
+
+ var want []string
+
+ testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
+ got, err := search(td, testFilter)
+ if err != nil {
+ t.Errorf("search error: %v", err)
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Found %#v; want %#v", got, want)
+ }
+}
+
+func TestSearch(t *testing.T) {
+ td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest")
+ if err != nil {
+ t.Fatalf("error creating searchtest: %v", err)
+ }
+ defer os.RemoveAll(td)
+
+ // Create various files, only some of which match the test filter regex.
+ files := []string{
+ "emp/",
+ "tee/",
+ "test-foo.tc",
+ "test-foo.tc",
+ "test-bar.tc",
+ "test-sam.tc",
+ "Test-que.tc",
+ "test-brett",
+ "test--abc.tc",
+ "test---xyz.tc",
+ "test-bool.TC",
+ "--test-gvs.tc",
+ " test-pew.tc",
+ "dir/test_baz.tc",
+ "dir/testsnap.tc",
+ "dir/test-luk.tc",
+ "dir/nest/test-ok.tc",
+ "dir/dip/diz/goog/test-pack.tc",
+ "dir/dip/diz/wobble/thud/test-cas.e",
+ "dir/dip/diz/wobble/thud/test-cas.tc",
+ }
+ want := []string{
+ "dir/dip/diz/goog/test-pack.tc",
+ "dir/dip/diz/wobble/thud/test-cas.tc",
+ "dir/nest/test-ok.tc",
+ "dir/test-luk.tc",
+ "test-bar.tc",
+ "test-foo.tc",
+ "test-sam.tc",
+ }
+
+ for _, item := range files {
+ if strings.HasSuffix(item, "/") {
+ // This item is a directory, create it.
+ if err := os.MkdirAll(filepath.Join(td, item), 0755); err != nil {
+ t.Fatalf("error making directory: %v", err)
+ }
+ } else {
+ // This item is a file; create its parent directory first.
+ fullDirPath := filepath.Join(td, filepath.Dir(item))
+ if err := os.MkdirAll(fullDirPath, 0755); err != nil {
+ t.Fatalf("error making directory: %v", err)
+ }
+ // Then touch the file at its full path.
+ touch(t, filepath.Join(td, item))
+ }
+ }
+
+ testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
+ got, err := search(td, testFilter)
+ if err != nil {
+ t.Errorf("search error: %v", err)
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Found %#v; want %#v", got, want)
+ }
+}
diff --git a/test/runtimes/proctor/python.go b/test/runtimes/proctor/python.go
new file mode 100644
index 000000000..b9e0fbe6f
--- /dev/null
+++ b/test/runtimes/proctor/python.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "strings"
+)
+
+// pythonRunner implements TestRunner for Python.
+type pythonRunner struct{}
+
+var _ TestRunner = pythonRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (pythonRunner) ListTests() ([]string, error) {
+ args := []string{"-m", "test", "--list-tests"}
+ cmd := exec.Command("./python", args...)
+ cmd.Stderr = os.Stderr
+ out, err := cmd.Output()
+ if err != nil {
+ return nil, fmt.Errorf("failed to list: %v", err)
+ }
+ var toolSlice []string
+ for _, test := range strings.Split(string(out), "\n") {
+ toolSlice = append(toolSlice, test)
+ }
+ return toolSlice, nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (pythonRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{"-m", "test", test}
+ return exec.Command("./python", args...)
+}
diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD
new file mode 100644
index 000000000..3972244b9
--- /dev/null
+++ b/test/runtimes/runner/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_binary", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["main.go"],
+ visibility = ["//test/runtimes:__pkg__"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
+
+go_test(
+ name = "exclude_test",
+ size = "small",
+ srcs = ["exclude_test.go"],
+ library = ":runner",
+)
diff --git a/test/runtimes/runner/exclude_test.go b/test/runtimes/runner/exclude_test.go
new file mode 100644
index 000000000..c08755894
--- /dev/null
+++ b/test/runtimes/runner/exclude_test.go
@@ -0,0 +1,37 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "flag"
+ "os"
+ "testing"
+)
+
+func TestMain(m *testing.M) {
+ flag.Parse()
+ os.Exit(m.Run())
+}
+
+// TestExcludes checks that the exclude file parses without error.
+func TestExcludes(t *testing.T) {
+ ex, err := getExcludes()
+ if err != nil {
+ t.Fatalf("error parsing exclude file: %v", err)
+ }
+ if *excludeFile != "" && len(ex) == 0 {
+ t.Errorf("got empty excludes for file %q", *excludeFile)
+ }
+}
diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go
new file mode 100644
index 000000000..2a0f62c73
--- /dev/null
+++ b/test/runtimes/runner/main.go
@@ -0,0 +1,197 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary runner runs the runtime tests in a Docker container.
+package main
+
+import (
+ "context"
+ "encoding/csv"
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+var (
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
+)
+
+// timeout is the maximum amount of time each test is allowed to run.
+const timeout = 5 * time.Minute
+
+func main() {
+ flag.Parse()
+ if *lang == "" || *image == "" {
+ fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
+ os.Exit(1)
+ }
+ os.Exit(runTests())
+}
+
+// runTests is a helper that is called by main. It exists so that we can run
+// deferred functions before exiting. It returns an exit code that should be
+// passed to os.Exit.
+func runTests() int {
+ // Get tests to exclude.
+ excludes, err := getExcludes()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error())
+ return 1
+ }
+
+ // Construct the shared docker instance.
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang))
+ defer d.CleanUp(ctx)
+
+ if err := testutil.TouchShardStatusFile(); err != nil {
+ fmt.Fprintf(os.Stderr, "error touching shard status file: %v\n", err)
+ return 1
+ }
+
+ // Get a slice of tests to run. This will also start a single Docker
+ // container that will be used to run each test. The deferred CleanUp
+ // call above stops the container once all tests have finished.
+ tests, err := getTests(ctx, d, excludes)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", err.Error())
+ return 1
+ }
+
+ m := testing.MainStart(testDeps{}, tests, nil, nil)
+ return m.Run()
+}
+
+// getTests executes all tests as table tests.
+func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) {
+ // Start the container.
+ opts := dockerutil.RunOpts{
+ Image: fmt.Sprintf("runtimes/%s", *image),
+ }
+ d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor")
+ if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil {
+ return nil, fmt.Errorf("docker run failed: %v", err)
+ }
+
+ // Get a list of all tests in the image.
+ list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list")
+ if err != nil {
+ return nil, fmt.Errorf("docker exec failed: %v", err)
+ }
+
+ // Calculate a subset of tests to run corresponding to the current
+ // shard.
+ tests := strings.Fields(list)
+ sort.Strings(tests)
+ indices, err := testutil.TestIndicesForShard(len(tests))
+ if err != nil {
+ return nil, fmt.Errorf("TestsForShard() failed: %v", err)
+ }
+
+ var itests []testing.InternalTest
+ for _, tci := range indices {
+ // Capture tc in this scope.
+ tc := tests[tci]
+ itests = append(itests, testing.InternalTest{
+ Name: tc,
+ F: func(t *testing.T) {
+ // Is the test excluded?
+ if _, ok := excludes[tc]; ok {
+ t.Skipf("SKIP: excluded test %q", tc)
+ }
+
+ var (
+ now = time.Now()
+ done = make(chan struct{})
+ output string
+ err error
+ )
+
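+ // Run the test in a background goroutine so that the select below
+ // can enforce the per-test timeout.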
+ go func() {
+ fmt.Printf("RUNNING %s...\n", tc)
+ output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--test", tc)
+ close(done)
+ }()
+
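+ // Wait for the test to finish or the timeout to expire. Note that
+ // a timed-out exec is not killed; it is simply no longer waited on.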
+ select {
+ case <-done:
+ if err == nil {
+ fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now))
+ return
+ }
+ t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output)
+ case <-time.After(timeout):
+ t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output)
+ }
+ },
+ })
+ }
+
+ return itests, nil
+}
+
+// getExcludes reads the exclude file and returns a set of test names to
+// exclude.
+func getExcludes() (map[string]struct{}, error) {
+ excludes := make(map[string]struct{})
+ if *excludeFile == "" {
+ return excludes, nil
+ }
+ f, err := os.Open(*excludeFile)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ r := csv.NewReader(f)
+
+ // First line is header. Skip it.
+ if _, err := r.Read(); err != nil {
+ return nil, err
+ }
+
+ for {
+ record, err := r.Read()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, err
+ }
+ excludes[record[0]] = struct{}{}
+ }
+ return excludes, nil
+}
+
+// testDeps implements testing.testDeps (an unexported interface), and is
+// required to use testing.MainStart.
+type testDeps struct{}
+
+func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
+func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
+func (f testDeps) StopCPUProfile() {}
+func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
+func (f testDeps) ImportPath() string { return "" }
+func (f testDeps) StartTestLog(io.Writer) {}
+func (f testDeps) StopTestLog() error { return nil }
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
new file mode 100644
index 000000000..28ef55945
--- /dev/null
+++ b/test/syscalls/BUILD
@@ -0,0 +1,1121 @@
+load("//test/runner:defs.bzl", "syscall_test")
+
+package(licenses = ["notice"])
+
+syscall_test(
+ test = "//test/syscalls/linux:32bit_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:accept_bind_stream_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:accept_bind_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:access_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:affinity_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:aio_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 5,
+ test = "//test/syscalls/linux:alarm_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:arch_prctl_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:bad_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/syscalls/linux:bind_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:brk_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_capability_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ # Takes too long for TSAN. Since this is kind of a stress test that doesn't
+ # involve much concurrency, TSAN's usefulness here is limited anyway.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_stress_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:chdir_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:chmod_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:chown_test",
+ use_tmpfs = True, # chown tests require gofer to be running as root.
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:chroot_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:clock_getres_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:clock_gettime_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:clock_nanosleep_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:concurrency_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_uds_tree = True,
+ test = "//test/syscalls/linux:connect_external_test",
+ use_tmpfs = True,
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:creat_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:dev_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:dup_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:epoll_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:eventfd_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:exceptions_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:exec_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:exec_binary_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:exit_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:fadvise64_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:fallocate_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:fault_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:fchdir_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:fcntl_test",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:flock_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:fork_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:fpsig_fork_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:fpsig_nested_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:fsync_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 5,
+ test = "//test/syscalls/linux:futex_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:getcpu_host_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:getcpu_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:getdents_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:getrandom_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:getrusage_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed.
+ test = "//test/syscalls/linux:inotify_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:ioctl_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:iptables_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 5,
+ test = "//test/syscalls/linux:itimer_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:kill_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:link_test",
+ use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2)
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:lseek_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:madvise_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:memory_accounting_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:mempolicy_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:mincore_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:mkdir_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:mknod_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 5,
+ test = "//test/syscalls/linux:mmap_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:mount_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:mremap_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:msync_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:munmap_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:network_namespace_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:open_create_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:open_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:packet_socket_raw_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:packet_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:partial_bad_buffer_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:pause_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ # Takes too long under gotsan to run.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:ping_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ shard_count = 5,
+ test = "//test/syscalls/linux:pipe_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:poll_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:ppoll_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:prctl_setuid_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:prctl_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:pread64_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:preadv_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:preadv2_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:priority_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:proc_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_pid_oomscore_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_pid_smaps_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_pid_uid_gid_map_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:pselect_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:ptrace_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 5,
+ test = "//test/syscalls/linux:pty_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:pty_root_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:pwritev2_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:pwrite64_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:raw_socket_hdrincl_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:raw_socket_icmp_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:raw_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:read_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:readahead_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 5,
+ test = "//test/syscalls/linux:readv_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:readv_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:rename_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:rlimits_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:rseq_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:rtsignal_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:signalfd_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sched_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sched_yield_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:seccomp_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:select_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ shard_count = 20,
+ test = "//test/syscalls/linux:semaphore_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:sendfile_socket_test",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:sendfile_test",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:splice_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sigaction_test",
+ vfs2 = "True",
+)
+
+# TODO(b/119826902): Enable once the test passes in runsc.
+# syscall_test(test = "//test/syscalls/linux:sigaltstack_test", vfs2 = "True")
+
+syscall_test(
+ test = "//test/syscalls/linux:sigiret_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sigprocmask_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:sigstop_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sigtimedwait_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:shm_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_abstract_non_blocking_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_abstract_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_domain_non_blocking_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_domain_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_overlay = True,
+ test = "//test/syscalls/linux:socket_filesystem_non_blocking_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_filesystem_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_inet_loopback_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ # Takes too long for TSAN. Creates a lot of TCP sockets.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_ip_tcp_loopback_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_ip_udp_loopback_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_ip_unbound_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_netdevice_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_netlink_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_netlink_route_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_netlink_uevent_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_blocking_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_blocking_ip_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_non_stream_blocking_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/syscalls/linux:socket_stream_blocking_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/syscalls/linux:socket_stream_blocking_tcp_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_stream_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_stream_nonblock_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ # NOTE(b/116636318): Large sendmsg may stall a long time.
+ size = "enormous",
+ shard_count = 5,
+ test = "//test/syscalls/linux:socket_unix_dgram_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_unix_dgram_non_blocking_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_unix_pair_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ # NOTE(b/116636318): Large sendmsg may stall a long time.
+ size = "enormous",
+ shard_count = 5,
+ test = "//test/syscalls/linux:socket_unix_seqpacket_local_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_unix_stream_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_unix_unbound_abstract_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_unix_unbound_dgram_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:socket_unix_unbound_filesystem_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 10,
+ test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ test = "//test/syscalls/linux:socket_unix_unbound_stream_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:statfs_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:stat_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:stat_times_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:sticky_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:symlink_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:sync_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:sync_file_range_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sysinfo_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:syslog_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:sysret_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 10,
+ test = "//test/syscalls/linux:tcp_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:tgkill_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:timerfd_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:timers_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:time_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:tkill_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:truncate_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:tuntap_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_hostinet = True,
+ test = "//test/syscalls/linux:tuntap_hostinet_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:udp_bind_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ add_hostinet = True,
+ shard_count = 10,
+ test = "//test/syscalls/linux:udp_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:uidgid_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:uname_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:unlink_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:unshare_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:utimes_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ test = "//test/syscalls/linux:vdso_clock_gettime_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:vdso_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:vsyscall_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:vfork_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ size = "medium",
+ shard_count = 5,
+ test = "//test/syscalls/linux:wait_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:write_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_unix_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_tcp_test",
+ vfs2 = "True",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_udp_test",
+ vfs2 = "True",
+)
diff --git a/test/syscalls/README.md b/test/syscalls/README.md
new file mode 100644
index 000000000..9e0991940
--- /dev/null
+++ b/test/syscalls/README.md
@@ -0,0 +1,107 @@
+# gVisor system call test suite
+
+This is a test suite for Linux system calls. It runs under both gVisor and
+Linux, and ensures compatibility between the two.
+
+When adding support for a new syscall (or syscall argument) to gVisor, a
+corresponding syscall test should be added. It's usually recommended to write
+the test first and make sure that it passes on Linux before making changes to
+gVisor.
+
+This document outlines the general guidelines for tests and specific rules that
+must be followed for new tests.
+
+## Running the tests
+
+Each test file generates three different test targets that run in different
+environments:
+
+* a `native` target that runs directly on the host machine,
+* a `runsc_ptrace` target that runs inside runsc using the ptrace platform, and
+* a `runsc_kvm` target that runs inside runsc using the KVM platform.
+
+For example, the test in `access_test.cc` generates the following targets:
+
+* `//test/syscalls:access_test_native`
+* `//test/syscalls:access_test_runsc_ptrace`
+* `//test/syscalls:access_test_runsc_kvm`
+
+Any of these targets can be run directly via `bazel test`.
+
+```bash
+$ bazel test //test/syscalls:access_test_native
+$ bazel test //test/syscalls:access_test_runsc_ptrace
+$ bazel test //test/syscalls:access_test_runsc_kvm
+```
+
+To run all the tests on a particular platform, you can filter by the platform
+tag:
+
+```bash
+# Run all tests in native environment:
+$ bazel test --test_tag_filters=native //test/syscalls/...
+
+# Run all tests in runsc with ptrace:
+$ bazel test --test_tag_filters=runsc_ptrace //test/syscalls/...
+
+# Run all tests in runsc with kvm:
+$ bazel test --test_tag_filters=runsc_kvm //test/syscalls/...
+```
+
+You can also run all the tests on every platform. (Warning: this may take a
+while to run.)
+
+```bash
+# Run all tests on every platform:
+$ bazel test //test/syscalls/...
+```
+
+## Writing new tests
+
+Whenever we add support for a new syscall, or add support for a new argument or
+option for a syscall, we should always add a new test (perhaps many new tests).
+
+In general, it is best to write the test first and make sure it passes on Linux
+by running the test on the `native` platform on a Linux machine. This ensures
+that the gVisor implementation matches actual Linux behavior. Man pages
+sometimes contain errors, so always verify against what Linux actually does.
+
+gVisor uses the [Google Test][googletest] test framework, with a few custom
+matchers and guidelines, described below.
+
+### Syscall matchers
+
+When testing an individual system call, use the following syscall matchers,
+which will match the value returned by the syscall and the errno.
+
+```cc
+SyscallSucceeds()
+SyscallSucceedsWithValue(...)
+SyscallFails()
+SyscallFailsWithErrno(...)
+```
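+
+For example (a minimal sketch, assuming `fd` is an open file descriptor):
+
+```cc
+// Closing a valid file descriptor succeeds.
+EXPECT_THAT(close(fd), SyscallSucceeds());
+// Closing the same descriptor again fails with EBADF.
+EXPECT_THAT(close(fd), SyscallFailsWithErrno(EBADF));
+```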
+
+### Use test utilities (RAII classes)
+
+The test utilities are written as RAII classes. These utilities should be
+preferred over custom test harnesses.
+
+Local class instances should be preferred, wherever possible, over full test
+fixtures.
+
+A test utility should be created when more than one test requires the same
+functionality; otherwise, the class should be local to the test, as in the
+sketch below.
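+
+A hypothetical example of such a test-local helper (illustrative only; real
+tests should prefer the existing utilities in `//test/util`, such as
+`TempPath` and `FileDescriptor`):
+
+```cc
+#include <limits.h>  // PATH_MAX
+#include <unistd.h>  // chdir(), getcwd()
+
+#include "gtest/gtest.h"
+
+// ScopedChdir changes the working directory on construction and restores
+// the original directory on destruction.
+class ScopedChdir {
+ public:
+  explicit ScopedChdir(const char* dir) {
+    EXPECT_NE(getcwd(old_, sizeof(old_)), nullptr);
+    EXPECT_EQ(chdir(dir), 0);
+  }
+  ~ScopedChdir() { EXPECT_EQ(chdir(old_), 0); }
+
+ private:
+  char old_[PATH_MAX];
+};
+```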
+
+## Save/Restore support in tests
+
+gVisor supports save/restore, and our syscall tests are written in a way to
+enable saving/restoring at certain points. Hence, there are calls to
+`MaybeSave`, and certain tests that should not trigger saves are named with
+`NoSave`.
+
+However, the current open-source test runner does not yet support triggering
+save/restore, so these functions and annotations have no effect today. We plan
+to extend the test runner to trigger save/restore; until then, these functions
+and annotations can safely be ignored.
+
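+In the meantime, a save/restore-friendly test looks roughly like this (a
+sketch; `MaybeSave` is provided by `//test/util:save_util`):
+
+```cc
+TEST(Example, SaveRestoreFriendly) {
+  // Set up state that must survive a save/restore cycle.
+  const int fd = open("/dev/null", O_WRONLY);
+  ASSERT_THAT(fd, SyscallSucceeds());
+
+  // Give the test runner an opportunity to save/restore at this point.
+  MaybeSave();
+
+  // The descriptor should still be usable after a potential restore.
+  EXPECT_THAT(write(fd, "x", 1), SyscallSucceedsWithValue(1));
+  EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+```
+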
+[googletest]: https://github.com/abseil/googletest
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
new file mode 100644
index 000000000..3c825477c
--- /dev/null
+++ b/test/syscalls/linux/32bit.cc
@@ -0,0 +1,248 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string.h>
+#include <sys/mman.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "test/util/memory_util.h"
+#include "test/util/platform_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+#ifndef __x86_64__
+#error "This test is x86-64 specific."
+#endif
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr char kInt3 = '\xcc';
+constexpr char kInt80[2] = {'\xcd', '\x80'};
+constexpr char kSyscall[2] = {'\x0f', '\x05'};
+constexpr char kSysenter[2] = {'\x0f', '\x34'};
+
+void ExitGroup32(const char instruction[2], int code) {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE | PROT_EXEC,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_32BIT, -1, 0));
+
+ // Fill with INT 3 in case we execute too far.
+ memset(m.ptr(), kInt3, m.len());
+
+ // Copy in the actual instruction.
+ memcpy(m.ptr(), instruction, 2);
+
+ // We're playing *extremely* fast-and-loose with the various syscall ABIs
+ // here, which we can more-or-less get away with since exit_group doesn't
+ // return.
+ //
+ // SYSENTER expects the user stack in (%ebp) and arg6 in 0(%ebp). The kernel
+ // will unconditionally dereference %ebp for arg6, so we must pass a valid
+ // address or it will return EFAULT.
+ //
+ // SYSENTER also unconditionally returns to thread_info->sysenter_return which
+ // is ostensibly a stub in the 32-bit VDSO. But a 64-bit binary doesn't have
+ // the 32-bit VDSO mapped, so sysenter_return will simply be the value
+ // inherited from the most recent 32-bit ancestor, or NULL if there is none.
+// As a result, we cannot reliably return from SYSENTER.
+ asm volatile(
+ "movl $252, %%eax\n" // exit_group
+ "movl %[code], %%ebx\n" // code
+ "movl %%edx, %%ebp\n" // SYSENTER: user stack (use IP as a valid addr)
+ "leaq -20(%%rsp), %%rsp\n"
+ "movl $0x2b, 16(%%rsp)\n" // SS = CPL3 data segment
+ "movl $0,12(%%rsp)\n" // ESP = nullptr (unused)
+ "movl $0, 8(%%rsp)\n" // EFLAGS
+ "movl $0x23, 4(%%rsp)\n" // CS = CPL3 32-bit code segment
+ "movl %%edx, 0(%%rsp)\n" // EIP
+ "iretl\n"
+ "int $3\n"
+ :
+ : [ code ] "m"(code), [ ip ] "d"(m.ptr())
+ : "rax", "rbx");
+}
+
+constexpr int kExitCode = 42;
+
+TEST(Syscall32Bit, Int80) {
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(ExitGroup32(kInt80, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ // Since the call is ignored, we'll hit the int3 trap.
+ EXPECT_EXIT(ExitGroup32(kInt80, kExitCode),
+ ::testing::KilledBySignal(SIGTRAP), "");
+ break;
+
+ case PlatformSupport::Allowed:
+ EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), ::testing::ExitedWithCode(42),
+ "");
+ break;
+ }
+}
+
+TEST(Syscall32Bit, Sysenter) {
+ if ((PlatformSupport32Bit() == PlatformSupport::Allowed ||
+ PlatformSupport32Bit() == PlatformSupport::Ignored) &&
+ GetCPUVendor() == CPUVendor::kAMD) {
+ // SYSENTER is an illegal instruction in compatibility mode on AMD.
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::KilledBySignal(SIGILL), "");
+ return;
+ }
+
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
+
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ // See above, except the expected signal here is SIGSEGV.
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Allowed:
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::ExitedWithCode(42), "");
+ break;
+ }
+}
+
+TEST(Syscall32Bit, Syscall) {
+ if ((PlatformSupport32Bit() == PlatformSupport::Allowed ||
+ PlatformSupport32Bit() == PlatformSupport::Ignored) &&
+ GetCPUVendor() == CPUVendor::kIntel) {
+ // SYSCALL is an illegal instruction in compatibility mode on Intel.
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::KilledBySignal(SIGILL), "");
+ return;
+ }
+
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
+
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ // See above.
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Allowed:
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::ExitedWithCode(42), "");
+ break;
+ }
+}
+
+// Far call code called below.
+//
+// Input stack layout:
+//
+// %esp+12 lcall segment
+// %esp+8 lcall address offset
+// %esp+0 return address
+//
+// The lcall will enter compatibility mode and jump to the call address (the
+// address of the lret). The lret will return to 64-bit mode at the retq, which
+// will return to the external caller of this function.
+//
+// Since this enters compatibility mode, it must be mapped in a 32-bit region of
+// address space and have a 32-bit stack pointer.
+constexpr char kFarCall[] = {
+ '\x67', '\xff', '\x5c', '\x24', '\x08', // lcall *8(%esp)
+ '\xc3', // retq
+ '\xcb', // lret
+};
+
+void FarCall32() {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE | PROT_EXEC,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_32BIT, -1, 0));
+
+ // Fill with INT 3 in case we execute too far.
+ memset(m.ptr(), kInt3, m.len());
+
+ // 32-bit code.
+ memcpy(m.ptr(), kFarCall, sizeof(kFarCall));
+
+ // Use the end of the code page as its stack.
+ uintptr_t stack = m.endaddr();
+
+ uintptr_t lcall = m.addr();
+ uintptr_t lret = m.addr() + sizeof(kFarCall) - 1;
+
+ // N.B. We must save and restore RSP manually. GCC can do so automatically
+ // with an "rsp" clobber, but clang cannot.
+ asm volatile(
+ // Place the address of lret (%edx) and the 32-bit code segment (0x23) on
+ // the 32-bit stack for lcall.
+ "subl $0x8, %%ecx\n"
+ "movl $0x23, 4(%%ecx)\n"
+ "movl %%edx, 0(%%ecx)\n"
+
+ // Save the current stack and switch to 32-bit stack.
+ "pushq %%rbp\n"
+ "movq %%rsp, %%rbp\n"
+ "movq %%rcx, %%rsp\n"
+
+ // Run the lcall code.
+ "callq *%%rbx\n"
+
+ // Restore the old stack.
+ "leaveq\n"
+ : "+c"(stack)
+ : "b"(lcall), "d"(lret));
+}
+
+TEST(Call32Bit, Disallowed) {
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
+
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(FarCall32(), ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ ABSL_FALLTHROUGH_INTENDED;
+ case PlatformSupport::Allowed:
+ // Shouldn't crash.
+ FarCall32();
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
new file mode 100644
index 000000000..9e097c888
--- /dev/null
+++ b/test/syscalls/linux/BUILD
@@ -0,0 +1,3933 @@
+load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+exports_files(
+ [
+ "socket.cc",
+ "socket_inet_loopback.cc",
+ "socket_ip_loopback_blocking.cc",
+ "socket_ip_tcp_generic_loopback.cc",
+ "socket_ip_tcp_loopback.cc",
+ "socket_ip_tcp_loopback_blocking.cc",
+ "socket_ip_tcp_loopback_nonblock.cc",
+ "socket_ip_tcp_udp_generic.cc",
+ "socket_ip_udp_loopback.cc",
+ "socket_ip_udp_loopback_blocking.cc",
+ "socket_ip_udp_loopback_nonblock.cc",
+ "socket_ip_unbound.cc",
+ "socket_ipv4_tcp_unbound_external_networking_test.cc",
+ "socket_ipv4_udp_unbound_external_networking_test.cc",
+ "socket_ipv4_udp_unbound_loopback.cc",
+ "tcp_socket.cc",
+ "udp_bind.cc",
+ "udp_socket.cc",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+cc_binary(
+ name = "sigaltstack_check",
+ testonly = 1,
+ srcs = ["sigaltstack_check.cc"],
+ deps = ["//test/util:logging"],
+)
+
+cc_binary(
+ name = "exec_assert_closed_workload",
+ testonly = 1,
+ srcs = ["exec_assert_closed_workload.cc"],
+ deps = [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "exec_basic_workload",
+ testonly = 1,
+ srcs = [
+ "exec.h",
+ "exec_basic_workload.cc",
+ ],
+)
+
+cc_binary(
+ name = "exec_proc_exe_workload",
+ testonly = 1,
+ srcs = ["exec_proc_exe_workload.cc"],
+ deps = [
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ ],
+)
+
+cc_binary(
+ name = "exec_state_workload",
+ testonly = 1,
+ srcs = ["exec_state_workload.cc"],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+sh_binary(
+ name = "exit_script",
+ testonly = 1,
+ srcs = [
+ "exit_script.sh",
+ ],
+)
+
+cc_binary(
+ name = "priority_execve",
+ testonly = 1,
+ srcs = [
+ "priority_execve.cc",
+ ],
+)
+
+cc_library(
+ name = "base_poll_test",
+ testonly = 1,
+ srcs = ["base_poll_test.cc"],
+ hdrs = ["base_poll_test.h"],
+ deps = [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_library(
+ name = "file_base",
+ testonly = 1,
+ hdrs = ["file_base.h"],
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "socket_netlink_util",
+ testonly = 1,
+ srcs = ["socket_netlink_util.cc"],
+ hdrs = ["socket_netlink_util.h"],
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:posix_error",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "socket_netlink_route_util",
+ testonly = 1,
+ srcs = ["socket_netlink_route_util.cc"],
+ hdrs = ["socket_netlink_route_util.h"],
+ deps = [
+ ":socket_netlink_util",
+ ],
+)
+
+cc_library(
+ name = "socket_test_util",
+ testonly = 1,
+ srcs = [
+ "socket_test_util.cc",
+ "socket_test_util_impl.cc",
+ ],
+ hdrs = ["socket_test_util.h"],
+ defines = select_system(),
+ deps = default_net_util() + [
+ gtest,
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:optional",
+ "//test/util:file_descriptor",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_library(
+ name = "unix_domain_socket_test_util",
+ testonly = 1,
+ srcs = ["unix_domain_socket_test_util.cc"],
+ hdrs = ["unix_domain_socket_test_util.h"],
+ deps = [
+ ":socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "ip_socket_test_util",
+ testonly = 1,
+ srcs = ["ip_socket_test_util.cc"],
+ hdrs = ["ip_socket_test_util.h"],
+ deps = [
+ ":socket_test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "clock_nanosleep_test",
+ testonly = 1,
+ srcs = ["clock_nanosleep.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "32bit_test",
+ testonly = 1,
+ srcs = select_arch(
+ amd64 = ["32bit.cc"],
+ arm64 = [],
+ ),
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:platform_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "accept_bind_test",
+ testonly = 1,
+ srcs = ["accept_bind.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "accept_bind_stream_test",
+ testonly = 1,
+ srcs = ["accept_bind_stream.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "access_test",
+ testonly = 1,
+ srcs = ["access.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "affinity_test",
+ testonly = 1,
+ srcs = ["affinity.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "aio_test",
+ testonly = 1,
+ srcs = [
+ "aio.cc",
+ "file_base.h",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:proc_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "alarm_test",
+ testonly = 1,
+ srcs = ["alarm.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "bad_test",
+ testonly = 1,
+ srcs = ["bad.cc"],
+ linkstatic = 1,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "bind_test",
+ testonly = 1,
+ srcs = ["bind.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_test",
+ testonly = 1,
+ srcs = ["socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_capability_test",
+ testonly = 1,
+ srcs = ["socket_capability.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "brk_test",
+ testonly = 1,
+ srcs = ["brk.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "chdir_test",
+ testonly = 1,
+ srcs = ["chdir.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "chmod_test",
+ testonly = 1,
+ srcs = ["chmod.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "chown_test",
+ testonly = 1,
+ srcs = ["chown.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/synchronization",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sticky_test",
+ testonly = 1,
+ srcs = ["sticky.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "chroot_test",
+ testonly = 1,
+ srcs = ["chroot.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:mount_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "clock_getres_test",
+ testonly = 1,
+ srcs = ["clock_getres.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "clock_gettime_test",
+ testonly = 1,
+ srcs = ["clock_gettime.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "concurrency_test",
+ testonly = 1,
+ srcs = ["concurrency.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:platform_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "connect_external_test",
+ testonly = 1,
+ srcs = ["connect_external.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "creat_test",
+ testonly = 1,
+ srcs = ["creat.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "dev_test",
+ testonly = 1,
+ srcs = ["dev.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "dup_test",
+ testonly = 1,
+ srcs = ["dup.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:eventfd_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "epoll_test",
+ testonly = 1,
+ srcs = ["epoll.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:epoll_util",
+ "//test/util:eventfd_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "eventfd_test",
+ testonly = 1,
+ srcs = ["eventfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:epoll_util",
+ "//test/util:eventfd_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "exceptions_test",
+ testonly = 1,
+ srcs = select_arch(
+ amd64 = ["exceptions.cc"],
+ arm64 = [],
+ ),
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:logging",
+ "//test/util:platform_util",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getcpu_test",
+ testonly = 1,
+ srcs = ["getcpu.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getcpu_host_test",
+ testonly = 1,
+ srcs = ["getcpu.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getrusage_test",
+ testonly = 1,
+ srcs = ["getrusage.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "exec_binary_test",
+ testonly = 1,
+ srcs = ["exec_binary.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:proc_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "exec_test",
+ testonly = 1,
+ srcs = [
+ "exec.cc",
+ "exec.h",
+ ],
+ data = [
+ ":exec_assert_closed_workload",
+ ":exec_basic_workload",
+ ":exec_proc_exe_workload",
+ ":exec_state_workload",
+ ":exit_script",
+ ":priority_execve",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:optional",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "exit_test",
+ testonly = 1,
+ srcs = ["exit.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:time_util",
+ ],
+)
+
+cc_binary(
+ name = "fallocate_test",
+ testonly = 1,
+ srcs = ["fallocate.cc"],
+ linkstatic = 1,
+ deps = [
+ ":file_base",
+ ":socket_test_util",
+ "//test/util:cleanup",
+ "//test/util:eventfd_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "fault_test",
+ testonly = 1,
+ srcs = ["fault.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "fchdir_test",
+ testonly = 1,
+ srcs = ["fchdir.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "fcntl_test",
+ testonly = 1,
+ srcs = ["fcntl.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:cleanup",
+ "//test/util:epoll_util",
+ "//test/util:eventfd_util",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "flock_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "flock.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:epoll_util",
+ "//test/util:eventfd_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "fork_test",
+ testonly = 1,
+ srcs = ["fork.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "fpsig_fork_test",
+ testonly = 1,
+ srcs = ["fpsig_fork.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "fpsig_nested_test",
+ testonly = 1,
+ srcs = ["fpsig_nested.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sync_file_range_test",
+ testonly = 1,
+ srcs = ["sync_file_range.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "fsync_test",
+ testonly = 1,
+ srcs = ["fsync.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "futex_test",
+ testonly = 1,
+ srcs = ["futex.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:save_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:time_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "getdents_test",
+ testonly = 1,
+ srcs = ["getdents.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:eventfd_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getrandom_test",
+ testonly = 1,
+ srcs = ["getrandom.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "inotify_test",
+ testonly = 1,
+ srcs = ["inotify.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:epoll_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "ioctl_test",
+ testonly = 1,
+ srcs = ["ioctl.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "iptables_types",
+ testonly = 1,
+ hdrs = [
+ "iptables.h",
+ ],
+)
+
+cc_binary(
+ name = "iptables_test",
+ testonly = 1,
+ srcs = [
+ "iptables.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":iptables_types",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "itimer_test",
+ testonly = 1,
+ srcs = ["itimer.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "kill_test",
+ testonly = 1,
+ srcs = ["kill.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "link_test",
+ testonly = 1,
+ srcs = ["link.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "lseek_test",
+ testonly = 1,
+ srcs = ["lseek.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "madvise_test",
+ testonly = 1,
+ srcs = ["madvise.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mempolicy_test",
+ testonly = 1,
+ srcs = ["mempolicy.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "@com_google_absl//absl/memory",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "mincore_test",
+ testonly = 1,
+ srcs = ["mincore.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mkdir_test",
+ testonly = 1,
+ srcs = ["mkdir.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mknod_test",
+ testonly = 1,
+ srcs = ["mknod.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "mlock_test",
+ testonly = 1,
+ srcs = ["mlock.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:multiprocess_util",
+ "//test/util:rlimit_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mmap_test",
+ testonly = 1,
+ srcs = ["mmap.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:multiprocess_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mount_test",
+ testonly = 1,
+ srcs = ["mount.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:mount_util",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "mremap_test",
+ testonly = 1,
+ srcs = ["mremap.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "msync_test",
+ testonly = 1,
+ srcs = ["msync.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "munmap_test",
+ testonly = 1,
+ srcs = ["munmap.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "open_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "open.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "open_create_test",
+ testonly = 1,
+ srcs = ["open_create.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "packet_socket_raw_test",
+ testonly = 1,
+ srcs = ["packet_socket_raw.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/base:endian",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "packet_socket_test",
+ testonly = 1,
+ srcs = ["packet_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/base:endian",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "pty_test",
+ testonly = 1,
+ srcs = ["pty.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:pty_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "pty_root_test",
+ testonly = 1,
+ srcs = ["pty_root.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:pty_util",
+ "//test/util:test_main",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "partial_bad_buffer_test",
+ testonly = 1,
+ srcs = ["partial_bad_buffer.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "pause_test",
+ testonly = 1,
+ srcs = ["pause.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "ping_socket_test",
+ testonly = 1,
+ srcs = ["ping_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "pipe_test",
+ testonly = 1,
+ srcs = ["pipe.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "poll_test",
+ testonly = 1,
+ srcs = ["poll.cc"],
+ linkstatic = 1,
+ deps = [
+ ":base_poll_test",
+ "//test/util:eventfd_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "ppoll_test",
+ testonly = 1,
+ srcs = ["ppoll.cc"],
+ linkstatic = 1,
+ deps = [
+ ":base_poll_test",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "arch_prctl_test",
+ testonly = 1,
+ srcs = select_arch(
+ amd64 = ["arch_prctl.cc"],
+ arm64 = [],
+ ),
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
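
Note: select_arch is gVisor's own Starlark helper (defined outside this diff); it picks one list per target CPU. A minimal sketch of what such a macro could look like, assuming Bazel's standard @platforms constraint values (the real definition may differ):

    def select_arch(amd64, arm64):
        # Hypothetical expansion: dispatch on the target CPU using a
        # plain Bazel select() over standard platform constraints.
        return select({
            "@platforms//cpu:x86_64": amd64,
            "@platforms//cpu:aarch64": arm64,
        })

Here arch_prctl_test compiles arch_prctl.cc only on amd64, since arch_prctl(2) is an x86-specific syscall; on arm64 the binary is built from an empty source list.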
+
+cc_binary(
+ name = "prctl_test",
+ testonly = 1,
+ srcs = ["prctl.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "prctl_setuid_test",
+ testonly = 1,
+ srcs = ["prctl_setuid.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "pread64_test",
+ testonly = 1,
+ srcs = ["pread64.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "preadv_test",
+ testonly = 1,
+ srcs = ["preadv.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "preadv2_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "preadv2.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "priority_test",
+ testonly = 1,
+ srcs = ["priority.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_test",
+ testonly = 1,
+ srcs = ["proc.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:time_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_net_test",
+ testonly = 1,
+ srcs = ["proc_net.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_pid_oomscore_test",
+ testonly = 1,
+ srcs = ["proc_pid_oomscore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "proc_pid_smaps_test",
+ testonly = 1,
+ srcs = ["proc_pid_smaps.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:proc_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_pid_uid_gid_map_test",
+ testonly = 1,
+ srcs = ["proc_pid_uid_gid_map.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:time_util",
+ ],
+)
+
+cc_binary(
+ name = "pselect_test",
+ testonly = 1,
+ srcs = ["pselect.cc"],
+ linkstatic = 1,
+ deps = [
+ ":base_poll_test",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "ptrace_test",
+ testonly = 1,
+ srcs = ["ptrace.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:platform_util",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:time_util",
+ ],
+)
+
+cc_binary(
+ name = "pwrite64_test",
+ testonly = 1,
+ srcs = ["pwrite64.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "pwritev2_test",
+ testonly = 1,
+ srcs = [
+ "pwritev2.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":file_base",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "raw_socket_hdrincl_test",
+ testonly = 1,
+ srcs = ["raw_socket_hdrincl.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/base:endian",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "raw_socket_test",
+ testonly = 1,
+ srcs = ["raw_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "raw_socket_icmp_test",
+ testonly = 1,
+ srcs = ["raw_socket_icmp.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "read_test",
+ testonly = 1,
+ srcs = ["read.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "readahead_test",
+ testonly = 1,
+ srcs = ["readahead.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "readv_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "readv.cc",
+ "readv_common.cc",
+ "readv_common.h",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_binary(
+ name = "readv_socket_test",
+ testonly = 1,
+ srcs = [
+ "readv_common.cc",
+ "readv_common.h",
+ "readv_socket.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "rename_test",
+ testonly = 1,
+ srcs = ["rename.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "rlimits_test",
+ testonly = 1,
+ srcs = ["rlimits.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "rseq_test",
+ testonly = 1,
+ srcs = ["rseq.cc"],
+ data = ["//test/syscalls/linux/rseq"],
+ linkstatic = 1,
+ deps = [
+ "//test/syscalls/linux/rseq:lib",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "rtsignal_test",
+ testonly = 1,
+ srcs = ["rtsignal.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ gtest,
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sched_test",
+ testonly = 1,
+ srcs = ["sched.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sched_yield_test",
+ testonly = 1,
+ srcs = ["sched_yield.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "seccomp_test",
+ testonly = 1,
+ srcs = ["seccomp.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:proc_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "select_test",
+ testonly = 1,
+ srcs = ["select.cc"],
+ linkstatic = 1,
+ deps = [
+ ":base_poll_test",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:rlimit_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sendfile_test",
+ testonly = 1,
+ srcs = ["sendfile.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:eventfd_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sendfile_socket_test",
+ testonly = 1,
+ srcs = ["sendfile_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ ":ip_socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "splice_test",
+ testonly = 1,
+ srcs = ["splice.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sigaction_test",
+ testonly = 1,
+ srcs = ["sigaction.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sigaltstack_test",
+ testonly = 1,
+ srcs = ["sigaltstack.cc"],
+ data = [
+ ":sigaltstack_check",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sigiret_test",
+ testonly = 1,
+ srcs = select_arch(
+ amd64 = ["sigiret.cc"],
+ arm64 = [],
+ ),
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:logging",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:timer_util",
+ ] + select_arch(
+ amd64 = [],
+ arm64 = ["//test/util:test_main"],
+ ),
+)
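
The deps-side select_arch above mirrors the srcs-side one: on arm64 the source list is empty, so //test/util:test_main is linked in to give the otherwise empty binary a gtest main, while on amd64 sigiret.cc presumably supplies its own entry point (note it is not paired with test_main). This keeps the target buildable on both architectures without arch conditionals elsewhere in the build graph.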
+
+cc_binary(
+ name = "signalfd_test",
+ testonly = 1,
+ srcs = ["signalfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/synchronization",
+ gtest,
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sigprocmask_test",
+ testonly = 1,
+ srcs = ["sigprocmask.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sigstop_test",
+ testonly = 1,
+ srcs = ["sigstop.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "sigtimedwait_test",
+ testonly = 1,
+ srcs = ["sigtimedwait.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+)
+
+cc_library(
+ name = "socket_generic_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_generic.cc",
+ ],
+ hdrs = [
+ "socket_generic.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "socket_stress_test",
+ testonly = 1,
+ srcs = [
+ "socket_generic_stress.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "socket_unix_dgram_test_cases",
+ testonly = 1,
+ srcs = ["socket_unix_dgram.cc"],
+ hdrs = ["socket_unix_dgram.h"],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_unix_seqpacket_test_cases",
+ testonly = 1,
+ srcs = ["socket_unix_seqpacket.cc"],
+ hdrs = ["socket_unix_seqpacket.h"],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_ip_tcp_generic_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ip_tcp_generic.cc",
+ ],
+ hdrs = [
+ "socket_ip_tcp_generic.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_non_blocking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_non_blocking.cc",
+ ],
+ hdrs = [
+ "socket_non_blocking.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_unix_non_stream_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_unix_non_stream.cc",
+ ],
+ hdrs = [
+ "socket_unix_non_stream.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_non_stream_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_non_stream.cc",
+ ],
+ hdrs = [
+ "socket_non_stream.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_ip_udp_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ip_udp_generic.cc",
+ ],
+ hdrs = [
+ "socket_ip_udp_generic.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_ipv4_udp_unbound_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_udp_unbound.cc",
+ ],
+ hdrs = [
+ "socket_ipv4_udp_unbound.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "@com_google_absl//absl/memory",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_ipv4_udp_unbound_external_networking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_udp_unbound_external_networking.cc",
+ ],
+ hdrs = [
+ "socket_ipv4_udp_unbound_external_networking.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_ipv4_tcp_unbound_external_networking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_tcp_unbound_external_networking.cc",
+ ],
+ hdrs = [
+ "socket_ipv4_tcp_unbound_external_networking.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "socket_abstract_test",
+ testonly = 1,
+ srcs = [
+ "socket_abstract.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_generic_test_cases",
+ ":socket_test_util",
+ ":socket_unix_cmsg_test_cases",
+ ":socket_unix_test_cases",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_abstract_non_blocking_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_abstract_nonblock.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_non_blocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_dgram_local_test",
+ testonly = 1,
+ srcs = ["socket_unix_dgram_local.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_non_stream_test_cases",
+ ":socket_test_util",
+ ":socket_unix_dgram_test_cases",
+ ":socket_unix_non_stream_test_cases",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_dgram_non_blocking_test",
+ testonly = 1,
+ srcs = ["socket_unix_dgram_non_blocking.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_seqpacket_local_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_seqpacket_local.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_non_stream_test_cases",
+ ":socket_test_util",
+ ":socket_unix_non_stream_test_cases",
+ ":socket_unix_seqpacket_test_cases",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_stream_test",
+ testonly = 1,
+ srcs = ["socket_unix_stream.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_tcp_generic_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_tcp_generic_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ip_tcp_generic_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_tcp_udp_generic_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_tcp_udp_generic.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_tcp_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_tcp_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_generic_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_tcp_loopback_non_blocking_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_tcp_loopback_nonblock.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_non_blocking_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_udp_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_udp_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_generic_test_cases",
+ ":socket_ip_udp_test_cases",
+ ":socket_non_stream_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ipv4_udp_unbound_external_networking_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_udp_unbound_external_networking_test.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv4_udp_unbound_external_networking_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ipv4_tcp_unbound_external_networking_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_tcp_unbound_external_networking_test.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv4_tcp_unbound_external_networking_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_sequence_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_sequence.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "@com_google_absl//absl/container:node_hash_map",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_distribution_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_distribution.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_udp_loopback_non_blocking_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_udp_loopback_nonblock.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_non_blocking_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ipv4_udp_unbound_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_udp_unbound_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv4_udp_unbound_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ip_unbound_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_unbound.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_domain_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_domain.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_generic_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_domain_non_blocking_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_pair_nonblock.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_non_blocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_filesystem_test",
+ testonly = 1,
+ srcs = [
+ "socket_filesystem.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_generic_test_cases",
+ ":socket_test_util",
+ ":socket_unix_cmsg_test_cases",
+ ":socket_unix_test_cases",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_filesystem_non_blocking_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_filesystem_nonblock.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_non_blocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_inet_loopback_test",
+ testonly = 1,
+ srcs = ["socket_inet_loopback.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_inet_loopback_nogotsan_test",
+ testonly = 1,
+ srcs = ["socket_inet_loopback_nogotsan.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_netlink_test",
+ testonly = 1,
+ srcs = ["socket_netlink.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_netlink_route_test",
+ testonly = 1,
+ srcs = ["socket_netlink_route.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_netlink_route_util",
+ ":socket_netlink_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_netlink_uevent_test",
+ testonly = 1,
+ srcs = ["socket_netlink_uevent.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_netlink_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+# These socket tests are in a library because the test cases are shared
+# across several test build targets.
+cc_library(
+ name = "socket_stream_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_stream.cc",
+ ],
+ hdrs = [
+ "socket_stream.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
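
The comment above describes the pattern used throughout this section: parameterized gtest cases (TEST_P) live in a cc_library with alwayslink = 1, so the linker retains their self-registering test objects even though no symbol in the final binary references them, and each cc_binary combines one or more case libraries with a source file that instantiates the cases over a concrete socket type. A minimal sketch of the shape, with hypothetical names:

    cc_library(
        name = "example_test_cases",
        testonly = 1,
        srcs = ["example_cases.cc"],  # TEST_P definitions only
        hdrs = ["example_cases.h"],
        deps = [":socket_test_util", gtest, "//test/util:test_util"],
        alwayslink = 1,  # keep otherwise-unreferenced test registrations
    )

    cc_binary(
        name = "example_local_test",
        testonly = 1,
        # INSTANTIATE_TEST_SUITE_P over unix socket pairs lives here.
        srcs = ["example_local.cc"],
        linkstatic = 1,
        deps = [
            ":example_test_cases",
            ":unix_domain_socket_test_util",
            "//test/util:test_main",
        ],
    )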
+
+cc_library(
+ name = "socket_blocking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_blocking.cc",
+ ],
+ hdrs = [
+ "socket_blocking.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_unix_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_unix.cc",
+ ],
+ hdrs = [
+ "socket_unix.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_unix_cmsg_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_unix_cmsg.cc",
+ ],
+ hdrs = [
+ "socket_unix_cmsg.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_stream_blocking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_stream_blocking.cc",
+ ],
+ hdrs = [
+ "socket_stream_blocking.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_stream_nonblocking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_stream_nonblock.cc",
+ ],
+ hdrs = [
+ "socket_stream_nonblock.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_non_stream_blocking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_non_stream_blocking.cc",
+ ],
+ hdrs = [
+ "socket_non_stream_blocking.h",
+ ],
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_bind_to_device_util",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_util.cc",
+ ],
+ hdrs = [
+ "socket_bind_to_device_util.h",
+ ],
+ deps = [
+ "//test/util:test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "socket_stream_local_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_stream_local.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_stream_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_stream_blocking_local_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_stream_blocking_local.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_stream_blocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_stream_blocking_tcp_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_tcp_loopback_blocking.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_stream_blocking_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_stream_nonblock_local_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_stream_nonblock_local.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_stream_nonblocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_unbound_dgram_test",
+ testonly = 1,
+ srcs = ["socket_unix_unbound_dgram.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_unbound_abstract_test",
+ testonly = 1,
+ srcs = ["socket_unix_unbound_abstract.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_unbound_filesystem_test",
+ testonly = 1,
+ srcs = ["socket_unix_unbound_filesystem.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_blocking_local_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_blocking_local.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_blocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_blocking_ip_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_loopback_blocking.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_blocking_test_cases",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_non_stream_blocking_local_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_non_stream_blocking_local.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_non_stream_blocking_test_cases",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_non_stream_blocking_udp_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_udp_loopback_blocking.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_non_stream_blocking_test_cases",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_pair_test",
+ testonly = 1,
+ srcs = [
+ "socket_unix_pair.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":socket_unix_cmsg_test_cases",
+ ":socket_unix_test_cases",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_unbound_seqpacket_test",
+ testonly = 1,
+ srcs = ["socket_unix_unbound_seqpacket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_unix_unbound_stream_test",
+ testonly = 1,
+ srcs = ["socket_unix_unbound_stream.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_netdevice_test",
+ testonly = 1,
+ srcs = ["socket_netdevice.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_netlink_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/base:endian",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "stat_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "stat.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "stat_times_test",
+ testonly = 1,
+ srcs = ["stat_times.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "statfs_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "statfs.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "symlink_test",
+ testonly = 1,
+ srcs = ["symlink.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sync_test",
+ testonly = 1,
+ srcs = ["sync.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sysinfo_test",
+ testonly = 1,
+ srcs = ["sysinfo.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "syslog_test",
+ testonly = 1,
+ srcs = ["syslog.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sysret_test",
+ testonly = 1,
+ srcs = ["sysret.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "tcp_socket_test",
+ testonly = 1,
+ srcs = ["tcp_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "tgkill_test",
+ testonly = 1,
+ srcs = ["tgkill.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "time_test",
+ testonly = 1,
+ srcs = ["time.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:proc_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "timerfd_test",
+ testonly = 1,
+ srcs = ["timerfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "timers_test",
+ testonly = 1,
+ srcs = ["timers.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "tkill_test",
+ testonly = 1,
+ srcs = ["tkill.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "truncate_test",
+ testonly = 1,
+ srcs = ["truncate.cc"],
+ linkstatic = 1,
+ deps = [
+ ":file_base",
+ "//test/util:capability_util",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "tuntap_test",
+ testonly = 1,
+ srcs = ["tuntap.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ ":socket_netlink_route_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "tuntap_hostinet_test",
+ testonly = 1,
+ srcs = ["tuntap_hostinet.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "udp_socket_test_cases",
+ testonly = 1,
+ srcs = [
+ "udp_socket_errqueue_test_case.cc",
+ "udp_socket_test_cases.cc",
+ ],
+ hdrs = ["udp_socket_test_cases.h"],
+ defines = select_system(),
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ ":unix_domain_socket_test_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "udp_socket_test",
+ testonly = 1,
+ srcs = ["udp_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":udp_socket_test_cases",
+ ],
+)
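
udp_socket_test is the limiting case of the shared-cases pattern: its only dependency is the case library, and even its main() arrives transitively, since udp_socket_test_cases lists //test/util:test_main among its deps; udp_socket.cc merely instantiates the shared cases. Assuming this package is test/syscalls/linux, the binary could be exercised directly with, for example, "bazel run //test/syscalls/linux:udp_socket_test".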
+
+cc_binary(
+ name = "udp_bind_test",
+ testonly = 1,
+ srcs = ["udp_bind.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "uidgid_test",
+ testonly = 1,
+ srcs = ["uidgid.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:uid_util",
+ ],
+)
+
+cc_binary(
+ name = "uname_test",
+ testonly = 1,
+ srcs = ["uname.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "unlink_test",
+ testonly = 1,
+ srcs = ["unlink.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "unshare_test",
+ testonly = 1,
+ srcs = ["unshare.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/synchronization",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "utimes_test",
+ testonly = 1,
+ srcs = ["utimes.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "vdso_test",
+ testonly = 1,
+ srcs = ["vdso.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:proc_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "vfork_test",
+ testonly = 1,
+ srcs = ["vfork.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:test_util",
+ "//test/util:time_util",
+ ],
+)
+
+cc_binary(
+ name = "wait_test",
+ testonly = 1,
+ srcs = ["wait.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:time_util",
+ ],
+)
+
+cc_binary(
+ name = "write_test",
+ testonly = 1,
+ srcs = ["write.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:cleanup",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "memory_accounting_test",
+ testonly = 1,
+ srcs = ["memory_accounting.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "network_namespace_test",
+ testonly = 1,
+ srcs = ["network_namespace.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "semaphore_test",
+ testonly = 1,
+ srcs = ["semaphore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "shm_test",
+ testonly = 1,
+ srcs = ["shm.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "fadvise64_test",
+ testonly = 1,
+ srcs = ["fadvise64.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "vdso_clock_gettime_test",
+ testonly = 1,
+ srcs = ["vdso_clock_gettime.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "vsyscall_test",
+ testonly = 1,
+ srcs = ["vsyscall.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:proc_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_net_unix_test",
+ testonly = 1,
+ srcs = ["proc_net_unix.cc"],
+ linkstatic = 1,
+ deps = [
+ ":unix_domain_socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "memfd_test",
+ testonly = 1,
+ srcs = ["memfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:memory_util",
+ "//test/util:multiprocess_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_net_tcp_test",
+ testonly = 1,
+ srcs = ["proc_net_tcp.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_net_udp_test",
+ testonly = 1,
+ srcs = ["proc_net_udp.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "xattr_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "xattr.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
new file mode 100644
index 000000000..f65a14fb8
--- /dev/null
+++ b/test/syscalls/linux/accept_bind.cc
@@ -0,0 +1,641 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(AllSocketPairTest, Listen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 5),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, ListenIncreaseBacklog) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 5),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 10),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, ListenDecreaseBacklog) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 5),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(sockets->first_fd(), /* backlog = */ 1),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, ListenWithoutBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, DoubleBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, BindListenBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, DoubleListen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, DoubleConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EISCONN));
+}
+
+TEST_P(AllSocketPairTest, Connect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, ConnectWithWrongType) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int type;
+ socklen_t typelen = sizeof(type);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen),
+ SyscallSucceeds());
+ switch (type) {
+ case SOCK_STREAM:
+ type = SOCK_SEQPACKET;
+ break;
+ case SOCK_SEQPACKET:
+ type = SOCK_STREAM;
+ break;
+ }
+
+ const FileDescriptor another_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0));
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
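+ // A pathname address has a non-empty sun_path, while an abstract address
+ // begins with a NUL byte. Linux reports the type mismatch as EPROTOTYPE
+ // for pathname sockets but as ECONNREFUSED for abstract ones.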
+ if (sockets->first_addr()->sa_data[0] != 0) {
+ ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EPROTOTYPE));
+ } else {
+ ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ }
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, ConnectNonListening) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+TEST_P(AllSocketPairTest, ConnectToFilePath) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
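+ // "/tmp" exists but is not a socket, so connecting to it is refused.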
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "/tmp";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ ASSERT_THAT(
+ connect(sockets->second_fd(),
+ reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+TEST_P(AllSocketPairTest, ConnectToInvalidAbstractPath) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ ASSERT_THAT(
+ connect(sockets->second_fd(),
+ reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+TEST_P(AllSocketPairTest, SelfConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, ConnectWithoutListen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+TEST_P(AllSocketPairTest, Accept) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+ ASSERT_THAT(close(accepted), SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, AcceptValidAddrLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ struct sockaddr_un addr = {};
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(
+ accepted = accept(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(close(accepted), SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, AcceptNegativeAddrLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ // With a negative addr_len, accept returns EINVAL.
+ struct sockaddr_un addr = {};
+ socklen_t addr_len = -1;
+ ASSERT_THAT(accept(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addr_len),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, AcceptLargePositiveAddrLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ // With a large (positive) addr_len, accept does not return EINVAL.
+ int accepted = -1;
+ char addr_buf[200];
+ socklen_t addr_len = sizeof(addr_buf);
+ ASSERT_THAT(accepted = accept(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(addr_buf),
+ &addr_len),
+ SyscallSucceeds());
+ // addr_len should have been updated by accept().
+ EXPECT_LT(addr_len, sizeof(addr_buf));
+ ASSERT_THAT(close(accepted), SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, AcceptVeryLargePositiveAddrLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ // With a large (positive) addr_len, accept does not return EINVAL.
+ int accepted = -1;
+ char addr_buf[2000];
+ socklen_t addr_len = sizeof(addr_buf);
+ ASSERT_THAT(accepted = accept(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(addr_buf),
+ &addr_len),
+ SyscallSucceeds());
+ // addr_len should have been updated by accept().
+ EXPECT_LT(addr_len, sizeof(addr_buf));
+ ASSERT_THAT(close(accepted), SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, AcceptWithoutBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, AcceptWithoutListen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, GetRemoteAddress) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ socklen_t addr_len = sockets->first_addr_size();
+ struct sockaddr_storage addr = {};
+ ASSERT_THAT(
+ getpeername(sockets->second_fd(), (struct sockaddr*)(&addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, sockets->first_addr_len());
+ EXPECT_EQ(0, memcmp(&addr, sockets->first_addr(), sockets->first_addr_len()));
+}
+
+TEST_P(AllSocketPairTest, UnboundGetLocalAddress) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
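+ // The connecting socket was never bound, so it has no local name:
+ // getsockname reports only the sun_family field, i.e. addr_len == 2.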
+ socklen_t addr_len = sockets->first_addr_size();
+ struct sockaddr_storage addr = {};
+ ASSERT_THAT(
+ getsockname(sockets->second_fd(), (struct sockaddr*)(&addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, 2);
+ EXPECT_EQ(
+ memcmp(&addr, sockets->second_addr(),
+ std::min((size_t)addr_len, (size_t)sockets->second_addr_len())),
+ 0);
+}
+
+TEST_P(AllSocketPairTest, BoundGetLocalAddress) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ socklen_t addr_len = sockets->first_addr_size();
+ struct sockaddr_storage addr = {};
+ ASSERT_THAT(
+ getsockname(sockets->second_fd(), (struct sockaddr*)(&addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, sockets->second_addr_len());
+ EXPECT_EQ(
+ memcmp(&addr, sockets->second_addr(),
+ std::min((size_t)addr_len, (size_t)sockets->second_addr_len())),
+ 0);
+}
+
+TEST_P(AllSocketPairTest, BoundConnector) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, UnboundSenderAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+ FileDescriptor accepted_fd(accepted);
+
+ int i = 0;
+ ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
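+ // The sender was never bound, so the receiver sees a zero-length source
+ // address.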
+ struct sockaddr_storage addr;
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ EXPECT_EQ(addr_len, 0);
+}
+
+TEST_P(AllSocketPairTest, BoundSenderAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+ FileDescriptor accepted_fd(accepted);
+
+ int i = 0;
+ ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
+ struct sockaddr_storage addr;
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ EXPECT_EQ(addr_len, sockets->second_addr_len());
+ EXPECT_EQ(
+ memcmp(&addr, sockets->second_addr(),
+ std::min((size_t)addr_len, (size_t)sockets->second_addr_len())),
+ 0);
+}
+
+TEST_P(AllSocketPairTest, BindAfterConnectSenderAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+ FileDescriptor accepted_fd(accepted);
+
+ int i = 0;
+ ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
+ struct sockaddr_storage addr;
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ EXPECT_EQ(addr_len, sockets->second_addr_len());
+ EXPECT_EQ(
+ memcmp(&addr, sockets->second_addr(),
+ std::min((size_t)addr_len, (size_t)sockets->second_addr_len())),
+ 0);
+}
+
+TEST_P(AllSocketPairTest, BindAfterAcceptSenderAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+ FileDescriptor accepted_fd(accepted);
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallSucceeds());
+
+ int i = 0;
+ ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
+ struct sockaddr_storage addr;
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ EXPECT_EQ(addr_len, sockets->second_addr_len());
+ EXPECT_EQ(
+ memcmp(&addr, sockets->second_addr(),
+ std::min((size_t)addr_len, (size_t)sockets->second_addr_len())),
+ 0);
+}
+
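+// Run the tests above over filesystem and abstract unix domain sockets, in
+// stream and seqpacket flavors, both blocking and non-blocking.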
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, AllSocketPairTest,
+ ::testing::ValuesIn(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc
new file mode 100644
index 000000000..4857f160b
--- /dev/null
+++ b/test/syscalls/linux/accept_bind_stream.cc
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(AllSocketPairTest, BoundSenderAddrCoalesced) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int accepted = -1;
+ ASSERT_THAT(accepted = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+ FileDescriptor closer(accepted);
+
+ int i = 0;
+ ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->second_addr(),
+ sockets->second_addr_size()),
+ SyscallSucceeds());
+
+ i = 0;
+ ASSERT_THAT(RetryEINTR(send)(sockets->second_fd(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
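+ // On a stream socket the two one-int sends coalesce into a single read;
+ // the reported sender address is the one bound between the two sends.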
+ int ri[2] = {0, 0};
+ struct sockaddr_storage addr;
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(accepted, ri, sizeof(ri), 0,
+ reinterpret_cast<sockaddr*>(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(ri)));
+ EXPECT_EQ(addr_len, sockets->second_addr_len());
+
+ EXPECT_EQ(
+ memcmp(&addr, sockets->second_addr(),
+ std::min((size_t)addr_len, (size_t)sockets->second_addr_len())),
+ 0);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, AllSocketPairTest,
+ ::testing::ValuesIn(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK})))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/access.cc b/test/syscalls/linux/access.cc
new file mode 100644
index 000000000..bcc25cef4
--- /dev/null
+++ b/test/syscalls/linux/access.cc
@@ -0,0 +1,170 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::Ge;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class AccessTest : public ::testing::Test {
+ public:
+ std::string CreateTempFile(int perm) {
+ const std::string path = NewTempAbsPath();
+ const int fd = open(path.c_str(), O_CREAT | O_RDONLY, perm);
+ TEST_PCHECK(fd > 0);
+ TEST_PCHECK(close(fd) == 0);
+ return path;
+ }
+
+ protected:
+ // SetUp creates various configurations of files.
+ void SetUp() override {
+ // Move to the temporary directory. This allows us to reason more easily
+ // about absolute and relative paths.
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+
+ // Create an empty file, standard permissions.
+ relfile_ = NewTempRelPath();
+ int fd;
+ ASSERT_THAT(fd = open(relfile_.c_str(), O_CREAT | O_TRUNC, 0644),
+ SyscallSucceedsWithValue(Ge(0)));
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+ absfile_ = GetAbsoluteTestTmpdir() + "/" + relfile_;
+
+ // Create an empty directory, no writable permissions.
+ absdir_ = NewTempAbsPath();
+ reldir_ = JoinPath(Basename(absdir_), "");
+ ASSERT_THAT(mkdir(reldir_.c_str(), 0555), SyscallSucceeds());
+
+ // This file doesn't exist.
+ relnone_ = NewTempRelPath();
+ absnone_ = GetAbsoluteTestTmpdir() + "/" + relnone_;
+ }
+
+ // TearDown unlinks created files.
+ void TearDown() override {
+ ASSERT_THAT(unlink(absfile_.c_str()), SyscallSucceeds());
+ ASSERT_THAT(rmdir(absdir_.c_str()), SyscallSucceeds());
+ }
+
+ std::string relfile_;
+ std::string reldir_;
+
+ std::string absfile_;
+ std::string absdir_;
+
+ std::string relnone_;
+ std::string absnone_;
+};
+
+TEST_F(AccessTest, RelativeFile) {
+ EXPECT_THAT(access(relfile_.c_str(), R_OK), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, RelativeDir) {
+ EXPECT_THAT(access(reldir_.c_str(), R_OK | X_OK), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, AbsFile) {
+ EXPECT_THAT(access(absfile_.c_str(), R_OK), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, AbsDir) {
+ EXPECT_THAT(access(absdir_.c_str(), R_OK | X_OK), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, RelDoesNotExist) {
+ EXPECT_THAT(access(relnone_.c_str(), R_OK), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_F(AccessTest, AbsDoesNotExist) {
+ EXPECT_THAT(access(absnone_.c_str(), R_OK), SyscallFailsWithErrno(ENOENT));
+}
+
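+// Mode bits outside of F_OK | R_OK | W_OK | X_OK are rejected with EINVAL.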
+TEST_F(AccessTest, InvalidMode) {
+ EXPECT_THAT(access(relfile_.c_str(), 0xffffffff),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(AccessTest, NoPerms) {
+ // Drop capabilities that allow us to override permissions. We must drop
+ // PERMITTED because access() checks those instead of EFFECTIVE.
+ ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_OVERRIDE));
+ ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_READ_SEARCH));
+
+ EXPECT_THAT(access(absdir_.c_str(), W_OK), SyscallFailsWithErrno(EACCES));
+}
+
+TEST_F(AccessTest, InvalidName) {
+ EXPECT_THAT(access(reinterpret_cast<char*>(0x1234), W_OK),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(AccessTest, UsrReadOnly) {
+ // Drop capabilities that allow us to override permissions. We must drop
+ // PERMITTED because access() checks those instead of EFFECTIVE.
+ ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_OVERRIDE));
+ ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_READ_SEARCH));
+
+ const std::string filename = CreateTempFile(0400);
+ EXPECT_THAT(access(filename.c_str(), R_OK), SyscallSucceeds());
+ EXPECT_THAT(access(filename.c_str(), W_OK), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(access(filename.c_str(), X_OK), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, UsrReadExec) {
+ // Drop capabilities that allow us to override permissions. We must drop
+ // PERMITTED because access() checks those instead of EFFECTIVE.
+ ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_OVERRIDE));
+ ASSERT_NO_ERRNO(DropPermittedCapability(CAP_DAC_READ_SEARCH));
+
+ const std::string filename = CreateTempFile(0500);
+ EXPECT_THAT(access(filename.c_str(), R_OK | X_OK), SyscallSucceeds());
+ EXPECT_THAT(access(filename.c_str(), W_OK), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, UsrReadWrite) {
+ const std::string filename = CreateTempFile(0600);
+ EXPECT_THAT(access(filename.c_str(), R_OK | W_OK), SyscallSucceeds());
+ EXPECT_THAT(access(filename.c_str(), X_OK), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds());
+}
+
+TEST_F(AccessTest, UsrReadWriteExec) {
+ const std::string filename = CreateTempFile(0700);
+ EXPECT_THAT(access(filename.c_str(), R_OK | W_OK | X_OK), SyscallSucceeds());
+ EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/affinity.cc b/test/syscalls/linux/affinity.cc
new file mode 100644
index 000000000..128364c34
--- /dev/null
+++ b/test/syscalls/linux/affinity.cc
@@ -0,0 +1,242 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <limits.h>
+#include <sched.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_split.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+// These tests are for both the sched_getaffinity(2) and sched_setaffinity(2)
+// syscalls.
+class AffinityTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ EXPECT_THAT(
+ // Need to use the raw syscall to get the actual cpumask size.
+ cpuset_size_ = syscall(SYS_sched_getaffinity, /*pid=*/0,
+ sizeof(cpu_set_t), &mask_),
+ SyscallSucceeds());
+ // Lots of tests rely on having more than 1 logical processor available.
+ EXPECT_GT(CPU_COUNT(&mask_), 1);
+ }
+
+ static PosixError ClearLowestBit(cpu_set_t* mask, size_t cpus) {
+ const size_t mask_size = CPU_ALLOC_SIZE(cpus);
+ for (size_t n = 0; n < cpus; ++n) {
+ if (CPU_ISSET_S(n, mask_size, mask)) {
+ CPU_CLR_S(n, mask_size, mask);
+ return NoError();
+ }
+ }
+ return PosixError(EINVAL, "No bit to clear, mask is empty");
+ }
+
+ PosixError ClearLowestBit() { return ClearLowestBit(&mask_, CPU_SETSIZE); }
+
+ // Stores the initial cpu mask for this process.
+ cpu_set_t mask_ = {};
+ int cpuset_size_ = 0;
+};
+
+// sched_getaffinity(2) is implemented.
+TEST_F(AffinityTest, SchedGetAffinityImplemented) {
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_),
+ SyscallSucceeds());
+}
+
+// PID is not found.
+TEST_F(AffinityTest, SchedGetAffinityInvalidPID) {
+ // Flaky, but it's tough to avoid a race condition when finding an unused pid.
+ EXPECT_THAT(sched_getaffinity(/*pid=*/INT_MAX - 1, sizeof(cpu_set_t), &mask_),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+// PID is not found.
+TEST_F(AffinityTest, SchedSetAffinityInvalidPID) {
+ // Flaky, but it's tough to avoid a race condition when finding an unused pid.
+ EXPECT_THAT(sched_setaffinity(/*pid=*/INT_MAX - 1, sizeof(cpu_set_t), &mask_),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST_F(AffinityTest, SchedSetAffinityZeroMask) {
+ CPU_ZERO(&mask_);
+ EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// N.B. This test case relies on cpuset_size_ being larger than the actual
+// number of existing CPUs. Check your machine if the test fails.
+TEST_F(AffinityTest, SchedSetAffinityNonexistentCPUDropped) {
+ cpu_set_t mask = mask_;
+ // Add a nonexistent CPU.
+ //
+ // The CPU number needs to be larger than the number of CPUs actually
+ // available, but smaller than the number of CPUs the kernel claims to
+ // support -- the latter is implicitly returned by the raw
+ // sched_getaffinity syscall.
+ CPU_SET(cpuset_size_ * 8 - 1, &mask);
+ EXPECT_THAT(
+ // Use raw syscall because it will be rejected by the libc wrapper
+ // otherwise.
+ syscall(SYS_sched_setaffinity, /*pid=*/0, sizeof(cpu_set_t), &mask),
+ SyscallSucceeds())
+ << "failed with cpumask : " << CPUSetToString(mask)
+ << ", cpuset_size_ : " << cpuset_size_;
+ cpu_set_t newmask;
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &newmask),
+ SyscallSucceeds());
+ EXPECT_TRUE(CPU_EQUAL(&mask_, &newmask))
+ << "got: " << CPUSetToString(newmask)
+ << " != expected: " << CPUSetToString(mask_);
+}
+
+TEST_F(AffinityTest, SchedSetAffinityOnlyNonexistentCPUFails) {
+ // Make an empty cpu set.
+ CPU_ZERO(&mask_);
+ // Add a nonexistent CPU.
+ //
+ // The CPU number needs to be larger than the number of CPUs actually
+ // available, but smaller than the number of CPUs the kernel claims to
+ // support -- the latter is implicitly returned by the raw
+ // sched_getaffinity syscall.
+ int cpu = cpuset_size_ * 8 - 1;
+ if (cpu <= NumCPUs()) {
+ GTEST_SKIP() << "Skipping test: cpu " << cpu << " exists";
+ }
+ CPU_SET(cpu, &mask_);
+ EXPECT_THAT(
+ // Use raw syscall because it will be rejected by the libc wrapper
+ // otherwise.
+ syscall(SYS_sched_setaffinity, /*pid=*/0, sizeof(cpu_set_t), &mask_),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(AffinityTest, SchedGetAffinityInvalidSize) {
+ EXPECT_GT(cpuset_size_, 0);
+ // Not big enough.
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, cpuset_size_ - 1, &mask_),
+ SyscallFailsWithErrno(EINVAL));
+ // Not a multiple of word size.
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, cpuset_size_ + 1, &mask_),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(AffinityTest, Sanity) {
+ ASSERT_NO_ERRNO(ClearLowestBit());
+ EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_),
+ SyscallSucceeds());
+ cpu_set_t newmask;
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &newmask),
+ SyscallSucceeds());
+ EXPECT_TRUE(CPU_EQUAL(&mask_, &newmask))
+ << "got: " << CPUSetToString(newmask)
+ << " != expected: " << CPUSetToString(mask_);
+}
+
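+// A newly created thread inherits the affinity mask of its creator.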
+TEST_F(AffinityTest, NewThread) {
+ SKIP_IF(CPU_COUNT(&mask_) < 3);
+ ASSERT_NO_ERRNO(ClearLowestBit());
+ ASSERT_NO_ERRNO(ClearLowestBit());
+ EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &mask_),
+ SyscallSucceeds());
+ ScopedThread([this]() {
+ cpu_set_t child_mask;
+ ASSERT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &child_mask),
+ SyscallSucceeds());
+ ASSERT_TRUE(CPU_EQUAL(&child_mask, &mask_))
+ << "child cpu mask: " << CPUSetToString(child_mask)
+ << " != parent cpu mask: " << CPUSetToString(mask_);
+ });
+}
+
+TEST_F(AffinityTest, ConsistentWithProcCpuInfo) {
+ // Count how many cpus are shown in /proc/cpuinfo.
+ std::string cpuinfo = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/cpuinfo"));
+ int count = 0;
+ for (auto const& line : absl::StrSplit(cpuinfo, '\n')) {
+ if (absl::StartsWith(line, "processor")) {
+ count++;
+ }
+ }
+ EXPECT_GE(count, CPU_COUNT(&mask_));
+}
+
+TEST_F(AffinityTest, ConsistentWithProcStat) {
+ // Count how many cpus are shown in /proc/stat.
+ std::string stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat"));
+ int count = 0;
+ for (auto const& line : absl::StrSplit(stat, '\n')) {
+ if (absl::StartsWith(line, "cpu") && !absl::StartsWith(line, "cpu ")) {
+ count++;
+ }
+ }
+ EXPECT_GE(count, CPU_COUNT(&mask_));
+}
+
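+// A mask sized only for the CPUs actually present (possibly smaller than
+// sizeof(cpu_set_t)) must still be accepted by sched_getaffinity.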
+TEST_F(AffinityTest, SmallCpuMask) {
+ const int num_cpus = NumCPUs();
+ const size_t mask_size = CPU_ALLOC_SIZE(num_cpus);
+ cpu_set_t* mask = CPU_ALLOC(num_cpus);
+ ASSERT_NE(mask, nullptr);
+ const auto free_mask = Cleanup([&] { CPU_FREE(mask); });
+
+ CPU_ZERO_S(mask_size, mask);
+ ASSERT_THAT(sched_getaffinity(0, mask_size, mask), SyscallSucceeds());
+}
+
+TEST_F(AffinityTest, LargeCpuMask) {
+ // Allocate mask bigger than cpu_set_t normally allocates.
+ const size_t cpus = CPU_SETSIZE * 8;
+ const size_t mask_size = CPU_ALLOC_SIZE(cpus);
+
+ cpu_set_t* large_mask = CPU_ALLOC(cpus);
+ auto free_mask = Cleanup([large_mask] { CPU_FREE(large_mask); });
+ CPU_ZERO_S(mask_size, large_mask);
+
+ // Check that get affinity with large mask works as expected.
+ ASSERT_THAT(sched_getaffinity(/*pid=*/0, mask_size, large_mask),
+ SyscallSucceeds());
+ EXPECT_TRUE(CPU_EQUAL(&mask_, large_mask))
+ << "got: " << CPUSetToString(*large_mask, cpus)
+ << " != expected: " << CPUSetToString(mask_);
+
+ // Check that set affinity with large mask works as expected.
+ ASSERT_NO_ERRNO(ClearLowestBit(large_mask, cpus));
+ EXPECT_THAT(sched_setaffinity(/*pid=*/0, mask_size, large_mask),
+ SyscallSucceeds());
+
+ cpu_set_t* new_mask = CPU_ALLOC(cpus);
+ auto free_new_mask = Cleanup([new_mask] { CPU_FREE(new_mask); });
+ CPU_ZERO_S(mask_size, new_mask);
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, mask_size, new_mask),
+ SyscallSucceeds());
+
+ EXPECT_TRUE(CPU_EQUAL_S(mask_size, large_mask, new_mask))
+ << "got: " << CPUSetToString(*new_mask, cpus)
+ << " != expected: " << CPUSetToString(*large_mask, cpus);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
new file mode 100644
index 000000000..806d5729e
--- /dev/null
+++ b/test/syscalls/linux/aio.cc
@@ -0,0 +1,430 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <linux/aio_abi.h>
+#include <sched.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::_;
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+// Returns the size of the VMA containing the given address.
+PosixErrorOr<size_t> VmaSizeAt(uintptr_t addr) {
+ ASSIGN_OR_RETURN_ERRNO(std::string proc_self_maps,
+ GetContents("/proc/self/maps"));
+ ASSIGN_OR_RETURN_ERRNO(auto entries, ParseProcMaps(proc_self_maps));
+ // Use binary search to find the first VMA that might contain addr.
+ ProcMapsEntry target = {};
+ target.end = addr;
+ auto it =
+ std::upper_bound(entries.begin(), entries.end(), target,
+ [](const ProcMapsEntry& x, const ProcMapsEntry& y) {
+ return x.end < y.end;
+ });
+ // Check that it actually contains addr.
+ if (it == entries.end() || addr < it->start) {
+ return PosixError(ENOENT, absl::StrCat("no VMA contains address ", addr));
+ }
+ return it->end - it->start;
+}
+
+constexpr char kData[] = "hello world!";
+
+int SubmitCtx(aio_context_t ctx, long nr, struct iocb** iocbpp) {
+ return syscall(__NR_io_submit, ctx, nr, iocbpp);
+}
+
+class AIOTest : public FileTest {
+ public:
+ AIOTest() : ctx_(0) {}
+
+ int SetupContext(unsigned int nr) {
+ return syscall(__NR_io_setup, nr, &ctx_);
+ }
+
+ int Submit(long nr, struct iocb** iocbpp) {
+ return SubmitCtx(ctx_, nr, iocbpp);
+ }
+
+ int GetEvents(long min, long max, struct io_event* events,
+ struct timespec* timeout) {
+ return RetryEINTR(syscall)(__NR_io_getevents, ctx_, min, max, events,
+ timeout);
+ }
+
+ int DestroyContext() { return syscall(__NR_io_destroy, ctx_); }
+
+ void TearDown() override {
+ FileTest::TearDown();
+ if (ctx_ != 0) {
+ ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
+ }
+ }
+
+ struct iocb CreateCallback() {
+ struct iocb cb = {};
+ cb.aio_data = 0x123;
+ cb.aio_fildes = test_file_fd_.get();
+ cb.aio_lio_opcode = IOCB_CMD_PWRITE;
+ cb.aio_buf = reinterpret_cast<uint64_t>(kData);
+ cb.aio_offset = 0;
+ cb.aio_nbytes = strlen(kData);
+ return cb;
+ }
+
+ protected:
+ aio_context_t ctx_;
+};
+
+TEST_F(AIOTest, BasicWrite) {
+ // Copied from fs/aio.c.
+ constexpr unsigned AIO_RING_MAGIC = 0xa10a10a1;
+ struct aio_ring {
+ unsigned id;
+ unsigned nr;
+ unsigned head;
+ unsigned tail;
+ unsigned magic;
+ unsigned compat_features;
+ unsigned incompat_features;
+ unsigned header_length;
+ struct io_event io_events[0];
+ };
+
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ // Check that 'ctx_' points to a valid address. libaio uses it to check if
+ // aio implementation uses aio_ring. gVisor doesn't and returns all zeroes.
+ // Linux implements aio_ring, so skip the zeroes check.
+ //
+ // TODO(gvisor.dev/issue/204): Remove when gVisor implements aio_ring.
+ auto ring = reinterpret_cast<struct aio_ring*>(ctx_);
+ auto magic = IsRunningOnGvisor() ? 0 : AIO_RING_MAGIC;
+ EXPECT_EQ(ring->magic, magic);
+
+ struct iocb cb = CreateCallback();
+ struct iocb* cbs[1] = {&cb};
+
+ // Submit the request.
+ ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
+
+ // Get the reply.
+ struct io_event events[1];
+ ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
+
+ // Verify that it is as expected.
+ EXPECT_EQ(events[0].data, 0x123);
+ EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb));
+ EXPECT_EQ(events[0].res, strlen(kData));
+
+ // Verify that the file contains the contents.
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
+}
+
+TEST_F(AIOTest, BadWrite) {
+ // Create a pipe and immediately close the read end.
+ int pipefd[2];
+ ASSERT_THAT(pipe(pipefd), SyscallSucceeds());
+
+ FileDescriptor rfd(pipefd[0]);
+ FileDescriptor wfd(pipefd[1]);
+
+ rfd.reset(); // Close the read end.
+
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ struct iocb cb = CreateCallback();
+ // Try to write to the read end.
+ cb.aio_fildes = wfd.get();
+ struct iocb* cbs[1] = {&cb};
+
+ // Submit the request.
+ ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
+
+ // Get the reply.
+ struct io_event events[1];
+ ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
+
+ // Verify that it fails with the right error code.
+ EXPECT_EQ(events[0].data, 0x123);
+ EXPECT_EQ(events[0].obj, reinterpret_cast<uint64_t>(&cb));
+ EXPECT_LT(events[0].res, 0);
+}
+
+TEST_F(AIOTest, ExitWithPendingIo) {
+ // Set up a context that is 100 entries deep.
+ ASSERT_THAT(SetupContext(100), SyscallSucceeds());
+
+ struct iocb cb = CreateCallback();
+ struct iocb* cbs[] = {&cb};
+
+ // Submit requests without reaping their completions, leaving them pending.
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
+}
+
+int Submitter(void* arg) {
+ auto test = reinterpret_cast<AIOTest*>(arg);
+
+ struct iocb cb = test->CreateCallback();
+ struct iocb* cbs[1] = {&cb};
+
+ // Submit the request.
+ TEST_CHECK(test->Submit(1, cbs) == 1);
+ return 0;
+}
+
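+// Tests that a child sharing the address space via CLONE_VM can submit to
+// the parent's AIO context and that the parent can reap the completion.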
+TEST_F(AIOTest, CloneVm) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ const size_t kStackSize = 5 * kPageSize;
+ std::unique_ptr<char[]> stack(new char[kStackSize]);
+ char* bp = stack.get() + kStackSize;
+ pid_t child;
+ ASSERT_THAT(child = clone(Submitter, bp, CLONE_VM | SIGCHLD,
+ reinterpret_cast<void*>(this)),
+ SyscallSucceeds());
+
+ // Get the reply.
+ struct io_event events[1];
+ ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
+
+ // Verify that it is as expected.
+ EXPECT_EQ(events[0].data, 0x123);
+ EXPECT_EQ(events[0].res, strlen(kData));
+
+ // Verify that the file contains the contents.
+ char verify_buf[32] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
+ SyscallSucceeds());
+ EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+// Tests that AIO context can be remapped to a different address.
+TEST_F(AIOTest, Mremap) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
+
+ struct iocb cb = CreateCallback();
+ struct iocb* cbs[1] = {&cb};
+
+ // Reserve address space for the mremap target so we have something safe to
+ // map over.
+ Mapping dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
+
+ // Remap context 'handle' to a different address.
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
+ aio_context_t old_ctx = ctx_;
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ // io_destroy() will unmap dst now.
+ dst.release();
+
+ // Check that submitting the request with the old 'ctx_' fails.
+ ASSERT_THAT(SubmitCtx(old_ctx, 1, cbs), SyscallFailsWithErrno(EINVAL));
+
+ // Submit the request with the new 'ctx_'.
+ ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
+
+ // Remap again.
+ dst = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ dst.release();
+
+ // Get the reply with yet another 'ctx_' and verify it.
+ struct io_event events[1];
+ ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(events[0].data, 0x123);
+ EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb));
+ EXPECT_EQ(events[0].res, strlen(kData));
+
+ // Verify that the file contains the contents.
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
+}
+
+// Tests that AIO context cannot be expanded with mremap.
+TEST_F(AIOTest, MremapExpansion) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
+
+ // Reserve address space for the mremap target so we have something safe to
+ // map over.
+ Mapping dst = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(ctx_size + kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Test that remapping to a larger address range fails.
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ PosixErrorIs(EFAULT, _));
+
+ // mm/mremap.c:sys_mremap() => mremap_to() does do_munmap() of the destination
+ // before it hits the VM_DONTEXPAND check in vma_to_resize(), so we should no
+ // longer munmap it (another thread may have created a mapping there).
+ dst.release();
+}
+
+// Tests that AIO calls fail if context's address is inaccessible.
+TEST_F(AIOTest, Mprotect) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ struct iocb cb = CreateCallback();
+ struct iocb* cbs[1] = {&cb};
+
+ ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
+
+ // Make the context 'handle' inaccessible and check that all subsequent
+ // calls fail.
+ ASSERT_THAT(mprotect(reinterpret_cast<void*>(ctx_), kPageSize, PROT_NONE),
+ SyscallSucceeds());
+ struct io_event events[1];
+ EXPECT_THAT(GetEvents(1, 1, events, nullptr), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(Submit(1, cbs), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(DestroyContext(), SyscallFailsWithErrno(EINVAL));
+
+ // Prevent TearDown from attempting to destroy the context and fail.
+ ctx_ = 0;
+}
+
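+// io_getevents with no completions pending returns 0 once the timeout
+// expires rather than blocking indefinitely.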
+TEST_F(AIOTest, Timeout) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ struct timespec timeout;
+ timeout.tv_sec = 0;
+ timeout.tv_nsec = 10;
+ struct io_event events[1];
+ ASSERT_THAT(GetEvents(1, 1, events, &timeout), SyscallSucceedsWithValue(0));
+}
+
+class AIOReadWriteParamTest : public AIOTest,
+ public ::testing::WithParamInterface<int> {};
+
+TEST_P(AIOReadWriteParamTest, BadOffset) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ struct iocb cb = CreateCallback();
+ struct iocb* cbs[1] = {&cb};
+
+ // Create a buffer that we can write to.
+ char buf[] = "hello world!";
+ cb.aio_buf = reinterpret_cast<uint64_t>(buf);
+
+ // Set the operation on the callback and give a negative offset.
+ const int opcode = GetParam();
+ cb.aio_lio_opcode = opcode;
+
+ iovec iov = {};
+ if (opcode == IOCB_CMD_PREADV || opcode == IOCB_CMD_PWRITEV) {
+ // Create a valid iovec and set it in the callback.
+ iov.iov_base = reinterpret_cast<void*>(buf);
+ iov.iov_len = 1;
+ cb.aio_buf = reinterpret_cast<uint64_t>(&iov);
+ // aio_nbytes is the number of iovecs.
+ cb.aio_nbytes = 1;
+ }
+
+ // Pass a negative offset.
+ cb.aio_offset = -1;
+
+ // Should get error on submission.
+ ASSERT_THAT(Submit(1, cbs), SyscallFailsWithErrno(EINVAL));
+}
+
+INSTANTIATE_TEST_SUITE_P(BadOffset, AIOReadWriteParamTest,
+ ::testing::Values(IOCB_CMD_PREAD, IOCB_CMD_PWRITE,
+ IOCB_CMD_PREADV, IOCB_CMD_PWRITEV));
+
+class AIOVectorizedParamTest : public AIOTest,
+ public ::testing::WithParamInterface<int> {};
+
+TEST_P(AIOVectorizedParamTest, BadIOVecs) {
+ // Set up a context that is 128 entries deep.
+ ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+
+ struct iocb cb = CreateCallback();
+ struct iocb* cbs[1] = {&cb};
+
+ // Modify the callback to use the operation from the param.
+ cb.aio_lio_opcode = GetParam();
+
+ // Create an iovec with address in kernel range, and pass that as the buffer.
+ iovec iov = {};
+ iov.iov_base = reinterpret_cast<void*>(0xFFFFFFFF00000000);
+ iov.iov_len = 1;
+ cb.aio_buf = reinterpret_cast<uint64_t>(&iov);
+ // aio_nbytes is the number of iovecs.
+ cb.aio_nbytes = 1;
+
+ // Should get error on submission.
+ ASSERT_THAT(Submit(1, cbs), SyscallFailsWithErrno(EFAULT));
+}
+
+INSTANTIATE_TEST_SUITE_P(BadIOVecs, AIOVectorizedParamTest,
+ ::testing::Values(IOCB_CMD_PREADV, IOCB_CMD_PWRITEV));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc
new file mode 100644
index 000000000..940c97285
--- /dev/null
+++ b/test/syscalls/linux/alarm.cc
@@ -0,0 +1,192 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// N.B. Below, main blocks SIGALRM. Test cases must unblock it if they want
+// delivery.
+
+void do_nothing_handler(int sig, siginfo_t* siginfo, void* arg) {}
+
+// No random save as the test relies on alarm timing. Cooperative save tests
+// already cover the save between alarm and read.
+TEST(AlarmTest, Interrupt_NoRandomSave) {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ FileDescriptor read_fd(pipe_fds[0]);
+ FileDescriptor write_fd(pipe_fds[1]);
+
+ // Use a signal handler that interrupts but does nothing rather than using the
+ // default terminate action.
+ struct sigaction sa;
+ sa.sa_sigaction = do_nothing_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = 0;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Actually allow SIGALRM delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM));
+
+ // Alarm in 20 seconds, which should be well after read blocks below.
+ ASSERT_THAT(alarm(20), SyscallSucceeds());
+
+ char buf;
+ ASSERT_THAT(read(read_fd.get(), &buf, 1), SyscallFailsWithErrno(EINTR));
+}
+
+// Count of the number of SIGALRMs handled.
+static volatile int alarms_received = 0;
+
+void inc_alarms_handler(int sig, siginfo_t* siginfo, void* arg) {
+ alarms_received++;
+}
+
+// No random save as the test relies on alarm timing. Cooperative save tests
+// already cover the save between alarm and read.
+TEST(AlarmTest, Restart_NoRandomSave) {
+ alarms_received = 0;
+
+ int pipe_fds[2];
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ FileDescriptor read_fd(pipe_fds[0]);
+ // Write end closed by thread below.
+
+ struct sigaction sa;
+ sa.sa_sigaction = inc_alarms_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_RESTART;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Spawn a thread to eventually unblock the read below.
+ ScopedThread t([pipe_fds] {
+ absl::SleepFor(absl::Seconds(30));
+ EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+ });
+
+ // Actually allow SIGALRM delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM));
+
+ // Alarm in 20 seconds, which should be well after read blocks below, but
+ // before it returns.
+ ASSERT_THAT(alarm(20), SyscallSucceeds());
+
+ // Read and eventually get an EOF from the writer closing. If SA_RESTART
+ // didn't work, then the alarm would not have fired and we wouldn't increment
+ // our alarms_received count in our signal handler, or we would have not
+ // restarted the syscall gracefully, which we expect below in order to be
+ // able to get the final EOF on the pipe.
+ char buf;
+ ASSERT_THAT(read(read_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_EQ(alarms_received, 1);
+
+ t.Join();
+}
+
+// No random save as the test relies on alarm timing. Cooperative save tests
+// already cover the save between alarm and pause.
+TEST(AlarmTest, SaSiginfo_NoRandomSave) {
+ // Use a signal handler that interrupts but does nothing rather than using the
+ // default terminate action.
+ struct sigaction sa;
+ sa.sa_sigaction = do_nothing_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Actually allow SIGALRM delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM));
+
+ // Alarm in 20 seconds, which should be well after pause blocks below.
+ ASSERT_THAT(alarm(20), SyscallSucceeds());
+ ASSERT_THAT(pause(), SyscallFailsWithErrno(EINTR));
+}
+
+// No random save as the test relies on alarm timing. Cooperative save tests
+// already cover the save between alarm and pause.
+TEST(AlarmTest, SaInterrupt_NoRandomSave) {
+ // Use a signal handler that interrupts but does nothing rather than using the
+ // default terminate action.
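+ // SA_INTERRUPT is the obsolete inverse of SA_RESTART; setting it must not
+ // prevent delivery from interrupting the blocked pause() below.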
+ struct sigaction sa;
+ sa.sa_sigaction = do_nothing_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_INTERRUPT;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Actually allow SIGALRM delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM));
+
+ // Alarm in 20 seconds, which should be well after pause blocks below.
+ ASSERT_THAT(alarm(20), SyscallSucceeds());
+ ASSERT_THAT(pause(), SyscallFailsWithErrno(EINTR));
+}
+
+TEST(AlarmTest, UserModeSpinning) {
+ alarms_received = 0;
+
+ struct sigaction sa = {};
+ sa.sa_sigaction = inc_alarms_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Actually allow SIGALRM delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM));
+
+ // Alarm in 20 seconds, which should be well into the loop below.
+ ASSERT_THAT(alarm(20), SyscallSucceeds());
+ // Make sure that the signal gets delivered even if we are spinning in user
+ // mode when it arrives.
+ while (!alarms_received) {
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // These tests depend on delivering SIGALRM to the main thread. Block SIGALRM
+ // so that any other threads created by TestInit will also have SIGALRM
+ // blocked.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGALRM);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/arch_prctl.cc b/test/syscalls/linux/arch_prctl.cc
new file mode 100644
index 000000000..81bf5a775
--- /dev/null
+++ b/test/syscalls/linux/arch_prctl.cc
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <asm/prctl.h>
+#include <sys/prctl.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+// glibc does not provide a prototype for arch_prctl(), so declare it here.
+extern "C" int arch_prctl(int code, uintptr_t addr);
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ArchPrctlTest, GetSetFS) {
+ uintptr_t orig;
+ const uintptr_t kNonCanonicalFsbase = 0x4141414142424242;
+
+ // Get the original FS.base and then set it to the same value (this is
+ // intentional because FS.base is the TLS pointer so we cannot change it
+ // arbitrarily).
+ ASSERT_THAT(arch_prctl(ARCH_GET_FS, reinterpret_cast<uintptr_t>(&orig)),
+ SyscallSucceeds());
+ ASSERT_THAT(arch_prctl(ARCH_SET_FS, orig), SyscallSucceeds());
+
+ // Trying to set FS.base to a non-canonical value should return an error.
+ ASSERT_THAT(arch_prctl(ARCH_SET_FS, kNonCanonicalFsbase),
+ SyscallFailsWithErrno(EPERM));
+}
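+
+// For reference (a sketch, not used by the test): on 48-bit x86-64, an
+// address is canonical iff bits 63:47 are all zero or all one, e.g.:
+//
+//   bool IsCanonical(uintptr_t addr) {
+//     return addr < 0x0000800000000000ULL || addr >= 0xffff800000000000ULL;
+//   }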
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc
new file mode 100644
index 000000000..a26fc6af3
--- /dev/null
+++ b/test/syscalls/linux/bad.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+#ifdef __x86_64__
+// get_kernel_syms is not supported in Linux > 2.6, and not implemented in
+// gVisor.
+constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms;
+#elif __aarch64__
+// Use the last of the arch_specific_syscall slots, which are not implemented
+// on arm64.
+constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15;
+#endif
+
+TEST(BadSyscallTest, NotImplemented) {
+ EXPECT_THAT(syscall(kNotImplementedSyscall), SyscallFailsWithErrno(ENOSYS));
+}
+
+TEST(BadSyscallTest, NegativeOne) {
+ EXPECT_THAT(syscall(-1), SyscallFailsWithErrno(ENOSYS));
+}
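+
+// Both cases rely on the kernel convention that unknown or unimplemented
+// syscall numbers fail with ENOSYS rather than trapping the process.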
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/base_poll_test.cc b/test/syscalls/linux/base_poll_test.cc
new file mode 100644
index 000000000..ab7a19dd0
--- /dev/null
+++ b/test/syscalls/linux/base_poll_test.cc
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/base_poll_test.h"
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
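+// Set from the SIGALRM handler; volatile because it is written from an
+// asynchronous signal handler and read by the test thread.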
+static volatile int timer_fired = 0;
+static void SigAlarmHandler(int, siginfo_t*, void*) { timer_fired = 1; }
+
+BasePollTest::BasePollTest() {
+  // Register our SIGALRM handler, but save the original so we can restore it
+  // in the destructor.
+ struct sigaction sa = {};
+ sa.sa_sigaction = SigAlarmHandler;
+ sigfillset(&sa.sa_mask);
+ TEST_PCHECK(sigaction(SIGALRM, &sa, &original_alarm_sa_) == 0);
+}
+
+BasePollTest::~BasePollTest() {
+ ClearTimer();
+ TEST_PCHECK(sigaction(SIGALRM, &original_alarm_sa_, nullptr) == 0);
+}
+
+void BasePollTest::SetTimer(absl::Duration duration) {
+ pid_t tgid = getpid();
+ pid_t tid = gettid();
+ ClearTimer();
+
+ // Create a new timer thread.
+ timer_ = absl::make_unique<TimerThread>(absl::Now() + duration, tgid, tid);
+}
+
+bool BasePollTest::TimerFired() const { return timer_fired; }
+
+void BasePollTest::ClearTimer() {
+ timer_.reset();
+ timer_fired = 0;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/base_poll_test.h b/test/syscalls/linux/base_poll_test.h
new file mode 100644
index 000000000..0d4a6701e
--- /dev/null
+++ b/test/syscalls/linux/base_poll_test.h
@@ -0,0 +1,101 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_BASE_POLL_TEST_H_
+#define GVISOR_TEST_SYSCALLS_BASE_POLL_TEST_H_
+
+#include <signal.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "test/util/logging.h"
+#include "test/util/signal_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// TimerThread is a cancelable timer.
+class TimerThread {
+ public:
+ TimerThread(absl::Time deadline, pid_t tgid, pid_t tid)
+ : thread_([=] {
+ mu_.Lock();
+ mu_.AwaitWithDeadline(absl::Condition(&cancel_), deadline);
+ if (!cancel_) {
+ TEST_PCHECK(tgkill(tgid, tid, SIGALRM) == 0);
+ }
+ mu_.Unlock();
+ }) {}
+
+ ~TimerThread() { Cancel(); }
+
+ void Cancel() {
+ absl::MutexLock ml(&mu_);
+ cancel_ = true;
+ }
+
+ private:
+ mutable absl::Mutex mu_;
+ bool cancel_ ABSL_GUARDED_BY(mu_) = false;
+
+ // Must be last to ensure that the destructor for the thread is run before
+ // any other member of the object is destroyed.
+ ScopedThread thread_;
+};
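+
+// Usage sketch (illustrative only): arm a timer that will tgkill() the
+// calling thread with SIGALRM after 100ms unless canceled first:
+//
+//   TimerThread timer(absl::Now() + absl::Milliseconds(100), getpid(),
+//                     gettid());
+//   ...  // a blocking call expected to be interrupted by SIGALRM
+//   timer.Cancel();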
+
+// Base test fixture for poll, select, ppoll, and pselect tests.
+//
+// This fixture makes use of SIGALRM. The previous handler is saved in the
+// constructor and restored in the destructor.
+class BasePollTest : public ::testing::Test {
+ protected:
+ BasePollTest();
+ ~BasePollTest() override;
+
+ // Sets a timer that will send a signal to the calling thread after
+ // `duration`.
+ void SetTimer(absl::Duration duration);
+
+ // Returns true if the timer has fired.
+ bool TimerFired() const;
+
+  // Stops the pending timer (if any) and clears the "fired" state.
+ void ClearTimer();
+
+ private:
+ // Thread that implements the timer. If the timer is stopped, timer_ is null.
+ //
+ // We have to use a thread for this purpose because tests using this fixture
+ // expect to be interrupted by the timer signal, but itimers/alarm(2) send
+ // thread-group-directed signals, which may be handled by any thread in the
+ // test process.
+ std::unique_ptr<TimerThread> timer_;
+
+ // The original SIGALRM handler, to restore in destructor.
+ struct sigaction original_alarm_sa_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_BASE_POLL_TEST_H_
diff --git a/test/syscalls/linux/bind.cc b/test/syscalls/linux/bind.cc
new file mode 100644
index 000000000..9547c4ab2
--- /dev/null
+++ b/test/syscalls/linux/bind.cc
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(AllSocketPairTest, Bind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(AllSocketPairTest, BindTooLong) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ // first_addr is a sockaddr_storage being used as a sockaddr_un. Use the full
+ // length which is longer than expected for a Unix socket.
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sizeof(sockaddr_storage)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, DoubleBindSocket) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+      // Linux 4.9 returns EINVAL here, but some time before 4.19 it switched
+      // to EADDRINUSE.
+ AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST_P(AllSocketPairTest, GetLocalAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ socklen_t addressLength = sockets->first_addr_size();
+ struct sockaddr_storage address = {};
+ ASSERT_THAT(getsockname(sockets->first_fd(), (struct sockaddr*)(&address),
+ &addressLength),
+ SyscallSucceeds());
+ EXPECT_EQ(
+ 0, memcmp(&address, sockets->first_addr(), sockets->first_addr_size()));
+}
+
+TEST_P(AllSocketPairTest, GetLocalAddrWithoutBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ socklen_t addressLength = sockets->first_addr_size();
+ struct sockaddr_storage received_address = {};
+ ASSERT_THAT(
+ getsockname(sockets->first_fd(), (struct sockaddr*)(&received_address),
+ &addressLength),
+ SyscallSucceeds());
+ struct sockaddr_storage want_address = {};
+ want_address.ss_family = sockets->first_addr()->sa_family;
+ EXPECT_EQ(0, memcmp(&received_address, &want_address, addressLength));
+}
+
+TEST_P(AllSocketPairTest, GetRemoteAddressWithoutConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ socklen_t addressLength = sockets->first_addr_size();
+ struct sockaddr_storage address = {};
+ ASSERT_THAT(getpeername(sockets->second_fd(), (struct sockaddr*)(&address),
+ &addressLength),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(AllSocketPairTest, DoubleBindAddress) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ EXPECT_THAT(bind(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(AllSocketPairTest, Unbind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+
+ // Filesystem Unix sockets do not release their address when closed.
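+  // (Abstract socket addresses begin with a NUL byte in sun_path, so
+  // sa_data[0] != 0 indicates a filesystem path.)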
+ if (sockets->first_addr()->sa_data[0] != 0) {
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EADDRINUSE));
+ return;
+ }
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+}
+
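+// Instantiate the suite over every combination of Unix socket type (stream,
+// dgram, seqpacket) and flag (blocking, non-blocking), for both filesystem
+// and abstract namespace addresses.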
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, AllSocketPairTest,
+ ::testing::ValuesIn(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM,
+ SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM,
+ SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/brk.cc b/test/syscalls/linux/brk.cc
new file mode 100644
index 000000000..a03a44465
--- /dev/null
+++ b/test/syscalls/linux/brk.cc
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST(BrkTest, BrkSyscallReturnsOldBrkOnFailure) {
+ auto old_brk = sbrk(0);
+ EXPECT_THAT(syscall(SYS_brk, reinterpret_cast<void*>(-1)),
+ SyscallSucceedsWithValue(reinterpret_cast<uintptr_t>(old_brk)));
+}
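+
+// For reference (a sketch, not part of the test): the raw SYS_brk returns
+// the new program break on success and the unchanged break on failure:
+//
+//   uintptr_t cur = syscall(SYS_brk, 0);            // query current break
+//   uintptr_t next = syscall(SYS_brk, cur + 4096);  // try to grow one page
+//   // next == cur + 4096 on success; next == cur if the request failed.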
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/chdir.cc b/test/syscalls/linux/chdir.cc
new file mode 100644
index 000000000..3182c228b
--- /dev/null
+++ b/test/syscalls/linux/chdir.cc
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <linux/limits.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ChdirTest, Success) {
+ auto old_dir = GetAbsoluteTestTmpdir();
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chdir(temp_dir.path().c_str()), SyscallSucceeds());
+  // The TempPath destructor deletes the newly created tmp dir, and the Sentry
+  // rejects saving while its current dir still points at that path. Switch
+  // back to a permanent path here.
+ EXPECT_THAT(chdir(old_dir.c_str()), SyscallSucceeds());
+}
+
+TEST(ChdirTest, PermissionDenied) {
+ // Drop capabilities that allow us to override directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */));
+ EXPECT_THAT(chdir(temp_dir.path().c_str()), SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChdirTest, NotDir) {
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ EXPECT_THAT(chdir(temp_file.path().c_str()), SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(ChdirTest, NotExist) {
+ EXPECT_THAT(chdir("/foo/bar"), SyscallFailsWithErrno(ENOENT));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc
new file mode 100644
index 000000000..a06b5cfd6
--- /dev/null
+++ b/test/syscalls/linux/chmod.cc
@@ -0,0 +1,264 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ChmodTest, ChmodFileSucceeds) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ ASSERT_THAT(chmod(file.path().c_str(), 0466), SyscallSucceeds());
+ EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, ChmodDirSucceeds) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string fileInDir = NewTempAbsPathInDir(dir.path());
+
+ ASSERT_THAT(chmod(dir.path().c_str(), 0466), SyscallSucceeds());
+ EXPECT_THAT(open(fileInDir.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodFileSucceeds_NoRandomSave) {
+  // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666));
+ int fd;
+ ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR), SyscallSucceeds());
+
+ {
+ const DisableSave ds; // File permissions are reduced.
+ ASSERT_THAT(fchmod(fd, 0444), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodDirSucceeds_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ int fd;
+ ASSERT_THAT(fd = open(dir.path().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+
+ {
+ const DisableSave ds; // File permissions are reduced.
+ ASSERT_THAT(fchmod(fd, 0), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ EXPECT_THAT(open(dir.path().c_str(), O_RDONLY),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodBadF) {
+ ASSERT_THAT(fchmod(-1, 0444), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChmodTest, FchmodatBadF) {
+ ASSERT_THAT(fchmodat(-1, "foo", 0444, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChmodTest, FchmodatNotDir) {
+ ASSERT_THAT(fchmodat(-1, "", 0444, 0), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(ChmodTest, FchmodatFileAbsolutePath) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ ASSERT_THAT(fchmodat(-1, file.path().c_str(), 0444, 0), SyscallSucceeds());
+ EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodatDirAbsolutePath) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ int fd;
+ ASSERT_THAT(fd = open(dir.path().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ ASSERT_THAT(fchmodat(-1, dir.path().c_str(), 0, 0), SyscallSucceeds());
+ EXPECT_THAT(open(dir.path().c_str(), O_RDONLY),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodatFile) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ int parent_fd;
+ ASSERT_THAT(
+ parent_fd = open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ fchmodat(parent_fd, std::string(Basename(temp_file.path())).c_str(), 0444,
+ 0),
+ SyscallSucceeds());
+ EXPECT_THAT(close(parent_fd), SyscallSucceeds());
+
+ EXPECT_THAT(open(temp_file.path().c_str(), O_RDWR),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodatDir) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ int parent_fd;
+ ASSERT_THAT(
+ parent_fd = open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+
+ int fd;
+ ASSERT_THAT(fd = open(dir.path().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ ASSERT_THAT(
+ fchmodat(parent_fd, std::string(Basename(dir.path())).c_str(), 0, 0),
+ SyscallSucceeds());
+ EXPECT_THAT(close(parent_fd), SyscallSucceeds());
+
+ EXPECT_THAT(open(dir.path().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, ChmodDowngradeWritability_NoRandomSave) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666));
+
+ int fd;
+ ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR), SyscallSucceeds());
+
+ const DisableSave ds; // Permissions are dropped.
+ ASSERT_THAT(chmod(file.path().c_str(), 0444), SyscallSucceeds());
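+  // Linux checks permissions at open(2) time, so writes through the
+  // already-open fd still succeed after the mode change.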
+ EXPECT_THAT(write(fd, "hello", 5), SyscallSucceedsWithValue(5));
+
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(ChmodTest, ChmodFileToNoPermissionsSucceeds) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666));
+
+ ASSERT_THAT(chmod(file.path().c_str(), 0), SyscallSucceeds());
+
+ EXPECT_THAT(open(file.path().c_str(), O_RDONLY),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChmodTest, FchmodDowngradeWritability_NoRandomSave) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ int fd;
+ ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+
+ const DisableSave ds; // Permissions are dropped.
+ ASSERT_THAT(fchmod(fd, 0444), SyscallSucceeds());
+ EXPECT_THAT(write(fd, "hello", 5), SyscallSucceedsWithValue(5));
+
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds_NoRandomSave) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666));
+
+ int fd;
+ ASSERT_THAT(fd = open(file.path().c_str(), O_RDWR), SyscallSucceeds());
+
+ {
+ const DisableSave ds; // Permissions are dropped.
+ ASSERT_THAT(fchmod(fd, 0), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ EXPECT_THAT(open(file.path().c_str(), O_RDONLY),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// Verify that we can get a RW FD after chmod, even if a RO fd is left open.
+TEST(ChmodTest, ChmodWritableWithOpenFD) {
+ // FIXME(b/72455313): broken on hostfs.
+ if (IsRunningOnGvisor()) {
+ return;
+ }
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0444));
+
+ FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ ASSERT_THAT(fchmod(fd1.get(), 0644), SyscallSucceeds());
+
+ // This FD is writable, even though fd1 has a read-only reference to the file.
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ // fd1 is not writable, but fd2 is.
+ char c = 'a';
+ EXPECT_THAT(WriteFd(fd1.get(), &c, 1), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
new file mode 100644
index 000000000..7a28b674d
--- /dev/null
+++ b/test/syscalls/linux/chown.cc
@@ -0,0 +1,206 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <grp.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/synchronization/notification.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ChownTest, FchownBadF) {
+ ASSERT_THAT(fchown(-1, 0, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChownTest, FchownatBadF) {
+ ASSERT_THAT(fchownat(-1, "fff", 0, 0, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChownTest, FchownatEmptyPath) {
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_RDONLY));
+ ASSERT_THAT(fchownat(fd.get(), "", 0, 0, 0), SyscallFailsWithErrno(ENOENT));
+}
+
+using Chown =
+ std::function<PosixError(const std::string&, uid_t owner, gid_t group)>;
+
+class ChownParamTest : public ::testing::TestWithParam<Chown> {};
+
+TEST_P(ChownParamTest, ChownFileSucceeds) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_CHOWN))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_CHOWN, false));
+ }
+
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // At least *try* setting to a group other than the EGID.
+ gid_t gid;
+ EXPECT_THAT(gid = getegid(), SyscallSucceeds());
+ int num_groups;
+ EXPECT_THAT(num_groups = getgroups(0, nullptr), SyscallSucceeds());
+ if (num_groups > 0) {
+ std::vector<gid_t> list(num_groups);
+ EXPECT_THAT(getgroups(list.size(), list.data()), SyscallSucceeds());
+ gid = list[0];
+ }
+
+ EXPECT_NO_ERRNO(GetParam()(file.path(), geteuid(), gid));
+
+ struct stat s = {};
+ ASSERT_THAT(stat(file.path().c_str(), &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_uid, geteuid());
+ EXPECT_EQ(s.st_gid, gid);
+}
+
+TEST_P(ChownParamTest, ChownFilePermissionDenied) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0777));
+
+  // Drop privileges and change IDs only in a child thread, or else this
+  // parent thread won't be able to open some log files after the test ends.
+ ScopedThread([&] {
+ // Drop privileges.
+ if (HaveCapability(CAP_CHOWN).ValueOrDie()) {
+ EXPECT_NO_ERRNO(SetCapability(CAP_CHOWN, false));
+ }
+
+ // Change EUID and EGID.
+ //
+ // See note about POSIX below.
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid1), -1),
+ SyscallSucceeds());
+
+ EXPECT_THAT(GetParam()(file.path(), geteuid(), getegid()),
+ PosixErrorIs(EPERM, ::testing::ContainsRegex("chown")));
+ });
+}
+
+TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_CHOWN))));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_SETUID))));
+
+ const std::string filename = NewTempAbsPath();
+
+ absl::Notification fileCreated, fileChowned;
+  // Change the UID only in a child thread, or else this parent thread won't
+  // be able to open some log files after the test ends.
+ ScopedThread t([&] {
+ // POSIX requires that all threads in a process share the same UIDs, so
+ // the NPTL setresuid wrappers use signals to make all threads execute the
+ // setresuid syscall. However, we want this thread to have its own set of
+ // credentials different from the parent process, so we use the raw
+ // syscall.
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid2), -1),
+ SyscallSucceeds());
+
+ // Create file and immediately close it.
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT | O_RDWR, 0644));
+ fd.reset(); // Close the fd.
+
+ fileCreated.Notify();
+ fileChowned.WaitForNotification();
+
+ EXPECT_THAT(open(filename.c_str(), O_RDWR), SyscallFailsWithErrno(EACCES));
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_RDONLY));
+ });
+
+ fileCreated.WaitForNotification();
+
+ // Set file's owners to someone different.
+ EXPECT_NO_ERRNO(GetParam()(filename, absl::GetFlag(FLAGS_scratch_uid1),
+ absl::GetFlag(FLAGS_scratch_gid)));
+
+ struct stat s;
+ EXPECT_THAT(stat(filename.c_str(), &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_uid, absl::GetFlag(FLAGS_scratch_uid1));
+ EXPECT_EQ(s.st_gid, absl::GetFlag(FLAGS_scratch_gid));
+
+ fileChowned.Notify();
+}
+
+PosixError errorFromReturn(const std::string& name, int ret) {
+ if (ret == -1) {
+ return PosixError(errno, absl::StrCat(name, " failed"));
+ }
+ return NoError();
+}
+
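+// Each parameterization below exercises one chown variant: chown(2),
+// lchown(2), fchown(2), fchownat(2) with AT_EMPTY_PATH on the file fd, and
+// fchownat(2) relative to the parent directory fd.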
+INSTANTIATE_TEST_SUITE_P(
+ ChownKinds, ChownParamTest,
+ ::testing::Values(
+ [](const std::string& path, uid_t owner, gid_t group) -> PosixError {
+ int rc = chown(path.c_str(), owner, group);
+ MaybeSave();
+ return errorFromReturn("chown", rc);
+ },
+ [](const std::string& path, uid_t owner, gid_t group) -> PosixError {
+ int rc = lchown(path.c_str(), owner, group);
+ MaybeSave();
+ return errorFromReturn("lchown", rc);
+ },
+ [](const std::string& path, uid_t owner, gid_t group) -> PosixError {
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, O_RDWR));
+ int rc = fchown(fd.get(), owner, group);
+ MaybeSave();
+ return errorFromReturn("fchown", rc);
+ },
+ [](const std::string& path, uid_t owner, gid_t group) -> PosixError {
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, O_RDWR));
+ int rc = fchownat(fd.get(), "", owner, group, AT_EMPTY_PATH);
+ MaybeSave();
+ return errorFromReturn("fchownat-fd", rc);
+ },
+ [](const std::string& path, uid_t owner, gid_t group) -> PosixError {
+ ASSIGN_OR_RETURN_ERRNO(auto dirfd, Open(std::string(Dirname(path)),
+ O_DIRECTORY | O_RDONLY));
+ int rc = fchownat(dirfd.get(), std::string(Basename(path)).c_str(),
+ owner, group, 0);
+ MaybeSave();
+ return errorFromReturn("fchownat-dirfd", rc);
+ }));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc
new file mode 100644
index 000000000..85ec013d5
--- /dev/null
+++ b/test/syscalls/linux/chroot.cc
@@ -0,0 +1,366 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stddef.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ChrootTest, Success) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds());
+}
+
+TEST(ChrootTest, PermissionDenied) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+  // CAP_DAC_READ_SEARCH and CAP_DAC_OVERRIDE may override execute permission
+  // on directories.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */));
+ EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ChrootTest, NotDir) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ EXPECT_THAT(chroot(temp_file.path().c_str()), SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(ChrootTest, NotExist) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ EXPECT_THAT(chroot("/foo/bar"), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(ChrootTest, WithoutCapability) {
+ // Unset CAP_SYS_CHROOT.
+ ASSERT_NO_ERRNO(SetCapability(CAP_SYS_CHROOT, false));
+
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallFailsWithErrno(EPERM));
+}
+
+TEST(ChrootTest, CreatesNewRoot) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ // Grab the initial cwd.
+ char initial_cwd[1024];
+ ASSERT_THAT(syscall(__NR_getcwd, initial_cwd, sizeof(initial_cwd)),
+ SyscallSucceeds());
+
+ auto new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto file_in_new_root =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(new_root.path()));
+
+ // chroot into new_root.
+ ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds());
+
+ // getcwd should return "(unreachable)" followed by the initial_cwd.
+ char cwd[1024];
+ ASSERT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds());
+ std::string expected_cwd = "(unreachable)";
+ expected_cwd += initial_cwd;
+ EXPECT_STREQ(cwd, expected_cwd.c_str());
+
+ // Should not be able to stat file by its full path.
+ struct stat statbuf;
+ EXPECT_THAT(stat(file_in_new_root.path().c_str(), &statbuf),
+ SyscallFailsWithErrno(ENOENT));
+
+ // Should be able to stat file at new rooted path.
+ auto basename = std::string(Basename(file_in_new_root.path()));
+ auto rootedFile = "/" + basename;
+ ASSERT_THAT(stat(rootedFile.c_str(), &statbuf), SyscallSucceeds());
+
+ // Should be able to stat cwd at '.' even though it's outside root.
+ ASSERT_THAT(stat(".", &statbuf), SyscallSucceeds());
+
+ // chdir into new root.
+ ASSERT_THAT(chdir("/"), SyscallSucceeds());
+
+ // getcwd should return "/".
+ EXPECT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds());
+ EXPECT_STREQ(cwd, "/");
+
+ // Statting '.', '..', '/', and '/..' all return the same dev and inode.
+ struct stat statbuf_dot;
+ ASSERT_THAT(stat(".", &statbuf_dot), SyscallSucceeds());
+ struct stat statbuf_dotdot;
+ ASSERT_THAT(stat("..", &statbuf_dotdot), SyscallSucceeds());
+ EXPECT_EQ(statbuf_dot.st_dev, statbuf_dotdot.st_dev);
+ EXPECT_EQ(statbuf_dot.st_ino, statbuf_dotdot.st_ino);
+ struct stat statbuf_slash;
+ ASSERT_THAT(stat("/", &statbuf_slash), SyscallSucceeds());
+ EXPECT_EQ(statbuf_dot.st_dev, statbuf_slash.st_dev);
+ EXPECT_EQ(statbuf_dot.st_ino, statbuf_slash.st_ino);
+ struct stat statbuf_slashdotdot;
+ ASSERT_THAT(stat("/..", &statbuf_slashdotdot), SyscallSucceeds());
+ EXPECT_EQ(statbuf_dot.st_dev, statbuf_slashdotdot.st_dev);
+ EXPECT_EQ(statbuf_dot.st_ino, statbuf_slashdotdot.st_ino);
+}
+
+TEST(ChrootTest, DotDotFromOpenFD) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ auto dir_outside_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(dir_outside_root.path(), O_RDONLY | O_DIRECTORY));
+ auto new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // chroot into new_root.
+ ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds());
+
+ // openat on fd with path .. will succeed.
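+  // (The fd refers to a directory outside the new root, so ".." resolution
+  // is not clamped at the chroot directory.)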
+ int other_fd;
+ ASSERT_THAT(other_fd = openat(fd.get(), "..", O_RDONLY), SyscallSucceeds());
+ EXPECT_THAT(close(other_fd), SyscallSucceeds());
+
+ // getdents on fd should not error.
+ char buf[1024];
+ ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)),
+ SyscallSucceeds());
+}
+
+// Test that link resolution in a chroot can escape the root by following an
+// open proc fd. Regression test for b/32316719.
+TEST(ChrootTest, ProcFdLinkResolutionInChroot) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ const TempPath file_outside_chroot =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file_outside_chroot.path(), O_RDONLY));
+
+ const FileDescriptor proc_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc", O_DIRECTORY | O_RDONLY | O_CLOEXEC));
+
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds());
+
+ // Opening relative to an already open fd to a node outside the chroot works.
+ const FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenAt(proc_fd.get(), "self/fd", O_DIRECTORY | O_RDONLY | O_CLOEXEC));
+
+ // Proc fd symlinks can escape the chroot if the fd the symlink refers to
+ // refers to an object outside the chroot.
+ struct stat s = {};
+ EXPECT_THAT(
+ fstatat(proc_self_fd.get(), absl::StrCat(fd.get()).c_str(), &s, 0),
+ SyscallSucceeds());
+
+ // Try to stat the stdin fd. Internally, this is handled differently from a
+ // proc fd entry pointing to a file, since stdin is backed by a host fd, and
+ // isn't a walkable path on the filesystem inside the sandbox.
+ EXPECT_THAT(fstatat(proc_self_fd.get(), "0", &s, 0), SyscallSucceeds());
+}
+
+// This test verifies that if a fd to /proc is held before entering a chroot,
+// files inside the chroot still appear rooted at the chroot when examined
+// via /proc/self/fd/{num}.
+TEST(ChrootTest, ProcMemSelfFdsNoEscapeProcOpen) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ // Get a FD to /proc before we enter the chroot.
+ const FileDescriptor proc =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY));
+
+ // Create and enter a chroot directory.
+ const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds());
+
+ // Open a file inside the chroot at /foo.
+ const FileDescriptor foo =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644));
+
+  // Examine /proc/self/fd/{foo_fd} to see if it exposes the fact that we're
+  // inside a chroot; the path should be /foo and NOT {chroot_dir}/foo.
+ const std::string fd_path = absl::StrCat("self/fd/", foo.get());
+ char buf[1024] = {};
+ size_t bytes_read = 0;
+ ASSERT_THAT(bytes_read =
+ readlinkat(proc.get(), fd_path.c_str(), buf, sizeof(buf) - 1),
+ SyscallSucceeds());
+
+ // The link should resolve to something.
+ ASSERT_GT(bytes_read, 0);
+
+ // Assert that the link doesn't contain the chroot path and is only /foo.
+ EXPECT_STREQ(buf, "/foo");
+}
+
+// This test verifies that a file mmapped inside a chroot does not expose its
+// full path via /proc/self/maps, honoring the chroot instead.
+TEST(ChrootTest, ProcMemSelfMapsNoEscapeProcOpen) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ // Get a FD to /proc before we enter the chroot.
+ const FileDescriptor proc =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY));
+
+ // Create and enter a chroot directory.
+ const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds());
+
+ // Open a file inside the chroot at /foo.
+ const FileDescriptor foo =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644));
+
+ // Mmap the newly created file.
+ void* foo_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ foo.get(), 0);
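+  // mmap returns MAP_FAILED ((void*)-1) on error, so casting to int64_t lets
+  // the Syscall matcher check for failure.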
+ ASSERT_THAT(reinterpret_cast<int64_t>(foo_map), SyscallSucceeds());
+
+ // Always unmap.
+ auto cleanup_map = Cleanup(
+ [&] { EXPECT_THAT(munmap(foo_map, kPageSize), SyscallSucceeds()); });
+
+ // Examine /proc/self/maps to be sure that /foo doesn't appear to be
+ // mapped with the full chroot path.
+ const FileDescriptor maps =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), "self/maps", O_RDONLY));
+
+ size_t bytes_read = 0;
+ char buf[8 * 1024] = {};
+ ASSERT_THAT(bytes_read = ReadFd(maps.get(), buf, sizeof(buf)),
+ SyscallSucceeds());
+
+ // The maps file should have something.
+ ASSERT_GT(bytes_read, 0);
+
+  // Finally, make sure the maps don't contain the chroot path.
+ ASSERT_EQ(std::string(buf, bytes_read).find(temp_dir.path()),
+ std::string::npos);
+}
+
+// Test that mounts outside the chroot will not appear in /proc/self/mounts or
+// /proc/self/mountinfo.
+TEST(ChrootTest, ProcMountsMountinfoNoEscape) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
+
+ // We are going to create some mounts and then chroot. In order to be able to
+ // unmount the mounts after the test run, we must chdir to the root and use
+ // relative paths for all mounts. That way, as long as we never chdir into
+ // the new root, we can access the mounts via relative paths and unmount them.
+ ASSERT_THAT(chdir("/"), SyscallSucceeds());
+
+ // Create nested tmpfs mounts. Note the use of relative paths in Mount calls.
+ auto const outer_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const outer_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount(
+ "none", JoinPath(".", outer_dir.path()), "tmpfs", 0, "mode=0700", 0));
+
+ auto const inner_dir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(outer_dir.path()));
+ auto const inner_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount(
+ "none", JoinPath(".", inner_dir.path()), "tmpfs", 0, "mode=0700", 0));
+
+ // Filenames that will be checked for mounts, all relative to /proc dir.
+ std::string paths[3] = {"mounts", "self/mounts", "self/mountinfo"};
+
+ for (const std::string& path : paths) {
+ // We should have both inner and outer mounts.
+ const std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents(JoinPath("/proc", path)));
+ EXPECT_THAT(contents, AllOf(HasSubstr(outer_dir.path()),
+ HasSubstr(inner_dir.path())));
+    // There should be more than two mount entries: the two mounts we created
+    // plus the root.
+ std::vector<absl::string_view> submounts =
+ absl::StrSplit(contents, '\n', absl::SkipWhitespace());
+ EXPECT_GT(submounts.size(), 2);
+ }
+
+ // Get a FD to /proc before we enter the chroot.
+ const FileDescriptor proc =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY));
+
+ // Chroot to outer mount.
+ ASSERT_THAT(chroot(outer_dir.path().c_str()), SyscallSucceeds());
+
+ for (const std::string& path : paths) {
+ const FileDescriptor proc_file =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY));
+
+    // Only two mounts are visible from this chroot: the inner and outer. Both
+    // paths should be relative to the new chroot.
+ const std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get()));
+ EXPECT_THAT(contents,
+ AllOf(HasSubstr(absl::StrCat(Basename(inner_dir.path()))),
+ Not(HasSubstr(outer_dir.path())),
+ Not(HasSubstr(inner_dir.path()))));
+ std::vector<absl::string_view> submounts =
+ absl::StrSplit(contents, '\n', absl::SkipWhitespace());
+ EXPECT_EQ(submounts.size(), 2);
+ }
+
+ // Chroot to inner mount. We must use an absolute path accessible to our
+ // chroot.
+ const std::string inner_dir_basename =
+ absl::StrCat("/", Basename(inner_dir.path()));
+ ASSERT_THAT(chroot(inner_dir_basename.c_str()), SyscallSucceeds());
+
+ for (const std::string& path : paths) {
+ const FileDescriptor proc_file =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY));
+ const std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get()));
+
+    // Only the inner mount is visible from this chroot.
+ std::vector<absl::string_view> submounts =
+ absl::StrSplit(contents, '\n', absl::SkipWhitespace());
+ EXPECT_EQ(submounts.size(), 1);
+ }
+
+ // Chroot back to ".".
+ ASSERT_THAT(chroot("."), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/clock_getres.cc b/test/syscalls/linux/clock_getres.cc
new file mode 100644
index 000000000..c408b936c
--- /dev/null
+++ b/test/syscalls/linux/clock_getres.cc
@@ -0,0 +1,37 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/time.h>
+#include <time.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// clock_getres works regardless of whether or not a timespec is passed.
+TEST(ClockGetres, Timespec) {
+ struct timespec ts;
+ EXPECT_THAT(clock_getres(CLOCK_MONOTONIC, &ts), SyscallSucceeds());
+ EXPECT_THAT(clock_getres(CLOCK_MONOTONIC, nullptr), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/clock_gettime.cc b/test/syscalls/linux/clock_gettime.cc
new file mode 100644
index 000000000..7f6015049
--- /dev/null
+++ b/test/syscalls/linux/clock_gettime.cc
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <pthread.h>
+#include <sys/time.h>
+
+#include <cerrno>
+#include <cstdint>
+#include <ctime>
+#include <list>
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+int64_t clock_gettime_nsecs(clockid_t id) {
+ struct timespec ts;
+ TEST_PCHECK(clock_gettime(id, &ts) == 0);
+ return (ts.tv_sec * 1000000000 + ts.tv_nsec);
+}
+
+// Spin on the CPU for at least ns nanoseconds, based on
+// CLOCK_THREAD_CPUTIME_ID.
+void spin_ns(int64_t ns) {
+ int64_t start = clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID);
+ int64_t end = start + ns;
+
+ do {
+ constexpr int kLoopCount = 1000000; // large and arbitrary
+ // volatile to prevent the compiler from skipping this loop.
+ for (volatile int i = 0; i < kLoopCount; i++) {
+ }
+ } while (clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID) < end);
+}
+
+// Test that CLOCK_PROCESS_CPUTIME_ID is a superset of CLOCK_THREAD_CPUTIME_ID.
+TEST(ClockGettime, CputimeId) {
+ constexpr int kNumThreads = 13; // arbitrary
+
+ absl::Duration spin_time = absl::Seconds(1);
+
+ // Start off the worker threads and compute the aggregate time spent by
+ // the workers. Note that we test CLOCK_PROCESS_CPUTIME_ID by having the
+ // workers execute in parallel and verifying that CLOCK_PROCESS_CPUTIME_ID
+ // accumulates the runtime of all threads.
+ int64_t start = clock_gettime_nsecs(CLOCK_PROCESS_CPUTIME_ID);
+
+  // Create kNumThreads threads.
+ std::list<ScopedThread> threads;
+ for (int i = 0; i < kNumThreads; i++) {
+ threads.emplace_back(
+ [spin_time] { spin_ns(absl::ToInt64Nanoseconds(spin_time)); });
+ }
+ for (auto& t : threads) {
+ t.Join();
+ }
+
+ int64_t end = clock_gettime_nsecs(CLOCK_PROCESS_CPUTIME_ID);
+
+ // The aggregate time spent in the worker threads must be at least
+ // 'kNumThreads' times the time each thread spun.
+ ASSERT_GE(end - start, kNumThreads * absl::ToInt64Nanoseconds(spin_time));
+}
+
+TEST(ClockGettime, JavaThreadTime) {
+ clockid_t clockid;
+ ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid));
+ struct timespec tp;
+ ASSERT_THAT(clock_getres(clockid, &tp), SyscallSucceeds());
+ EXPECT_TRUE(tp.tv_sec > 0 || tp.tv_nsec > 0);
+  // Thread CPU time is only updated roughly every 10ms while a task runs,
+  // with no interpolation in between, so spin until it becomes nonzero.
+ do {
+ ASSERT_THAT(clock_gettime(clockid, &tp), SyscallSucceeds());
+ } while (tp.tv_sec == 0 && tp.tv_nsec == 0);
+ EXPECT_TRUE(tp.tv_sec > 0 || tp.tv_nsec > 0);
+}
+
+// There is not much to test here, since CLOCK_REALTIME may be discontinuous.
+TEST(ClockGettime, RealtimeWorks) {
+ struct timespec tp;
+ EXPECT_THAT(clock_gettime(CLOCK_REALTIME, &tp), SyscallSucceeds());
+}
+
+class MonotonicClockTest : public ::testing::TestWithParam<clockid_t> {};
+
+TEST_P(MonotonicClockTest, IsMonotonic) {
+ auto end = absl::Now() + absl::Seconds(5);
+
+ struct timespec tp;
+ EXPECT_THAT(clock_gettime(GetParam(), &tp), SyscallSucceeds());
+
+ auto prev = absl::TimeFromTimespec(tp);
+ while (absl::Now() < end) {
+ EXPECT_THAT(clock_gettime(GetParam(), &tp), SyscallSucceeds());
+ auto now = absl::TimeFromTimespec(tp);
+ EXPECT_GE(now, prev);
+ prev = now;
+ }
+}
+
+std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) {
+ switch (info.param) {
+ case CLOCK_MONOTONIC:
+ return "CLOCK_MONOTONIC";
+ case CLOCK_MONOTONIC_COARSE:
+ return "CLOCK_MONOTONIC_COARSE";
+ case CLOCK_MONOTONIC_RAW:
+ return "CLOCK_MONOTONIC_RAW";
+ case CLOCK_BOOTTIME:
+ // CLOCK_BOOTTIME is a monotonic clock.
+ return "CLOCK_BOOTTIME";
+ default:
+ return absl::StrCat(info.param);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(ClockGettime, MonotonicClockTest,
+ ::testing::Values(CLOCK_MONOTONIC,
+ CLOCK_MONOTONIC_COARSE,
+ CLOCK_MONOTONIC_RAW, CLOCK_BOOTTIME),
+ PrintClockId);
+
+TEST(ClockGettime, UnimplementedReturnsEINVAL) {
+ SKIP_IF(!IsRunningOnGvisor());
+
+ struct timespec tp;
+ EXPECT_THAT(clock_gettime(CLOCK_REALTIME_ALARM, &tp),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(clock_gettime(CLOCK_BOOTTIME_ALARM, &tp),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(ClockGettime, InvalidClockIDReturnsEINVAL) {
+ struct timespec tp;
+ EXPECT_THAT(clock_gettime(-1, &tp), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc
new file mode 100644
index 000000000..b55cddc52
--- /dev/null
+++ b/test/syscalls/linux/clock_nanosleep.cc
@@ -0,0 +1,179 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// sys_clock_nanosleep is defined because glibc's clock_nanosleep returns
+// error numbers directly and does not set errno, which would make our Syscall
+// matchers look weird when expecting failure:
+// "SyscallSucceedsWithValue(ERRNO)".
+int sys_clock_nanosleep(clockid_t clkid, int flags,
+ const struct timespec* request,
+ struct timespec* remain) {
+ return syscall(SYS_clock_nanosleep, clkid, flags, request, remain);
+}
+
+PosixErrorOr<absl::Time> GetTime(clockid_t clk) {
+ struct timespec ts = {};
+ const int rc = clock_gettime(clk, &ts);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "clock_gettime");
+ }
+ return absl::TimeFromTimespec(ts);
+}
+
+class WallClockNanosleepTest : public ::testing::TestWithParam<clockid_t> {};
+
+TEST_P(WallClockNanosleepTest, InvalidValues) {
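+ // A sleep timespec is invalid if tv_sec is negative or tv_nsec falls
+ // outside [0, 999999999]; each case below violates at least one of these.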
+ const struct timespec invalid[] = {
+ {.tv_sec = -1, .tv_nsec = -1}, {.tv_sec = 0, .tv_nsec = INT32_MIN},
+ {.tv_sec = 0, .tv_nsec = INT32_MAX}, {.tv_sec = 0, .tv_nsec = -1},
+ {.tv_sec = -1, .tv_nsec = 0},
+ };
+
+ for (auto const ts : invalid) {
+ EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &ts, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+ }
+}
+
+TEST_P(WallClockNanosleepTest, SleepOneSecond) {
+ constexpr absl::Duration kSleepDuration = absl::Seconds(1);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ EXPECT_THAT(
+ RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+
+ EXPECT_GE(after - before, kSleepDuration);
+}
+
+TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
+ constexpr absl::Duration kSleepDuration = absl::Seconds(60);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ // Install no-op signal handler for SIGALRM.
+ struct sigaction sa = {};
+ sigfillset(&sa.sa_mask);
+ sa.sa_handler = +[](int signo) {};
+ const auto cleanup_sa =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Measure time from just before setting the alarm, because the alarm will
+ // interrupt the sleep and hence determines how long we sleep.
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+
+ // Set an alarm to go off while sleeping.
+ struct itimerval timer = {};
+ timer.it_value.tv_sec = 1;
+ timer.it_value.tv_usec = 0;
+ timer.it_interval.tv_sec = 1;
+ timer.it_interval.tv_usec = 0;
+ const auto cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer));
+
+ EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration),
+ SyscallFailsWithErrno(EINTR));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+
+ // Remaining time updated.
+ const absl::Duration remaining = absl::DurationFromTimespec(duration);
+ EXPECT_GE(after - before + remaining, kSleepDuration);
+}
+
+// Remaining time is *not* updated if nanosleep completes uninterrupted.
+TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) {
+ constexpr absl::Duration kSleepDuration = absl::Milliseconds(10);
+ const struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ while (true) {
+ constexpr int kRemainingMagic = 42;
+ struct timespec remaining;
+ remaining.tv_sec = kRemainingMagic;
+ remaining.tv_nsec = kRemainingMagic;
+
+ int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining);
+ if (ret != 0 && errno == EINTR) {
+ // Retry from the beginning. We want a single uninterrupted call.
+ continue;
+ }
+
+ EXPECT_THAT(ret, SyscallSucceeds());
+ EXPECT_EQ(remaining.tv_sec, kRemainingMagic);
+ EXPECT_EQ(remaining.tv_nsec, kRemainingMagic);
+ break;
+ }
+}
+
+TEST_P(WallClockNanosleepTest, SleepUntil) {
+ const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time until = now + absl::Seconds(2);
+ const struct timespec ts = absl::ToTimespec(until);
+
+ EXPECT_THAT(
+ RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr),
+ SyscallSucceeds());
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+
+ EXPECT_GE(after, until);
+}
+
+INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest,
+ ::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC));
+
+TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
+ const absl::Duration kSleepDuration = absl::Seconds(5);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ // Ensure that CLOCK_PROCESS_CPUTIME_ID advances.
+ std::atomic<bool> done(false);
+ ScopedThread t([&] {
+ while (!done.load()) {
+ }
+ });
+ const auto cleanup_done = Cleanup([&] { done.store(true); });
+
+ const absl::Time before =
+ ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
+ EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0,
+ &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after =
+ ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
+ EXPECT_GE(after - before, kSleepDuration);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc
new file mode 100644
index 000000000..7cd6a75bd
--- /dev/null
+++ b/test/syscalls/linux/concurrency.cc
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <atomic>
+
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/platform_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+// Test that a thread that never yields to the OS does not prevent other threads
+// from running.
+TEST(ConcurrencyTest, SingleProcessMultithreaded) {
+ std::atomic<int> a(0);
+
+ ScopedThread t([&a]() {
+ while (!a.load()) {
+ }
+ });
+
+ absl::SleepFor(absl::Seconds(1));
+
+ // We are still able to execute code in this thread, so the spinning thread
+ // has not permanently hung execution for both threads.
+ a.store(1);
+}
+
+// Test that multiple threads in this process continue to execute in parallel,
+// even if an unrelated second process is spawned. Regression test for
+// b/32119508.
+TEST(ConcurrencyTest, MultiProcessMultithreaded) {
+ // In PID 1, start TIDs 1 and 2, and put both to sleep.
+ //
+ // Start PID 3, which spins for 5 seconds, then exits.
+ //
+ // TIDs 1 and 2 wake and attempt to Activate, which cannot occur until PID 3
+ // exits.
+ //
+ // Both TIDs 1 and 2 should be woken. If they are not both woken, the test
+ // hangs.
+ //
+ // This is all fundamentally racy. If we are failing to wake all threads, the
+ // expectation is that this test becomes flaky, rather than consistently
+ // failing.
+ //
+ // If additional background threads fail to block, we may never schedule the
+ // child, at which point this test effectively becomes
+ // MultiProcessConcurrency. That's not expected to occur.
+
+ std::atomic<int> a(0);
+ ScopedThread t([&a]() {
+ // Block so that PID 3 can execute and we can wait on its exit.
+ absl::SleepFor(absl::Seconds(1));
+ while (!a.load()) {
+ }
+ });
+
+ pid_t child_pid = fork();
+ if (child_pid == 0) {
+ // Busy wait without making any blocking syscalls.
+ auto end = absl::Now() + absl::Seconds(5);
+ while (absl::Now() < end) {
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ absl::SleepFor(absl::Seconds(1));
+
+ // If only TID 1 is woken, t.Join() will hang.
+ // If only TID 2 is woken, both will hang.
+ a.store(1);
+ t.Join();
+
+ int status = 0;
+ EXPECT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(WEXITSTATUS(status), 0);
+}
+
+// Test that multiple processes can execute concurrently, even if one process
+// never yields.
+TEST(ConcurrencyTest, MultiProcessConcurrency) {
+ SKIP_IF(PlatformSupportMultiProcess() == PlatformSupport::NotSupported);
+
+ pid_t child_pid = fork();
+ if (child_pid == 0) {
+ while (true) {
+ }
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ absl::SleepFor(absl::Seconds(5));
+
+ // We are still able to execute code in this process, so the spinning child
+ // has not permanently hung execution for both processes.
+ ASSERT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds());
+ int status = 0;
+
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ ASSERT_TRUE(WIFSIGNALED(status));
+ ASSERT_EQ(WTERMSIG(status), SIGKILL);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc
new file mode 100644
index 000000000..1edb50e47
--- /dev/null
+++ b/test/syscalls/linux/connect_external.cc
@@ -0,0 +1,163 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <string>
+#include <tuple>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+// This file contains tests specific to connecting to host UDS managed outside
+// the sandbox / test.
+//
+// A set of utility sockets will be created externally in $TEST_UDS_TREE and
+// $TEST_UDS_ATTACH_TREE for these tests to interact with.
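+//
+// For example, given the layout assumed by the JoinPath calls below, a stream
+// echo server is expected at $TEST_UDS_TREE/stream/echo and a datagram sink
+// at $TEST_UDS_TREE/dgram/null.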
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+struct ProtocolSocket {
+ int protocol;
+ std::string name;
+};
+
+// Parameter is (socket root dir, ProtocolSocket).
+using GoferStreamSeqpacketTest =
+ ::testing::TestWithParam<std::tuple<std::string, ProtocolSocket>>;
+
+// Connect to a socket and verify that write/read work.
+//
+// An "echo" socket doesn't work for dgram sockets because our socket is
+// unnamed. The server thus has no way to reply to us.
+TEST_P(GoferStreamSeqpacketTest, Echo) {
+ std::string env;
+ ProtocolSocket proto;
+ std::tie(env, proto) = GetParam();
+
+ char* val = getenv(env.c_str());
+ ASSERT_NE(val, nullptr);
+ std::string root(val);
+
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0));
+
+ std::string socket_path = JoinPath(root, proto.name, "echo");
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
+
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ constexpr int kBufferSize = 64;
+ char send_buffer[kBufferSize];
+ memset(send_buffer, 'a', sizeof(send_buffer));
+
+ ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)),
+ SyscallSucceedsWithValue(sizeof(send_buffer)));
+
+ char recv_buffer[kBufferSize];
+ ASSERT_THAT(ReadFd(sock.get(), recv_buffer, sizeof(recv_buffer)),
+ SyscallSucceedsWithValue(sizeof(recv_buffer)));
+ ASSERT_EQ(0, memcmp(send_buffer, recv_buffer, sizeof(send_buffer)));
+}
+
+// It is not possible to connect to a bound but non-listening socket.
+TEST_P(GoferStreamSeqpacketTest, NonListening) {
+ std::string env;
+ ProtocolSocket proto;
+ std::tie(env, proto) = GetParam();
+
+ char* val = getenv(env.c_str());
+ ASSERT_NE(val, nullptr);
+ std::string root(val);
+
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0));
+
+ std::string socket_path = JoinPath(root, proto.name, "nonlistening");
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
+
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ StreamSeqpacket, GoferStreamSeqpacketTest,
+ ::testing::Combine(
+ // Test access via standard path and attach point.
+ ::testing::Values("TEST_UDS_TREE", "TEST_UDS_ATTACH_TREE"),
+ ::testing::Values(ProtocolSocket{SOCK_STREAM, "stream"},
+ ProtocolSocket{SOCK_SEQPACKET, "seqpacket"})));
+
+// Parameter is socket root dir.
+using GoferDgramTest = ::testing::TestWithParam<std::string>;
+
+// Connect to a socket and verify that write works.
+//
+// An "echo" socket doesn't work for dgram sockets because our socket is
+// unnamed. The server thus has no way to reply to us.
+TEST_P(GoferDgramTest, Null) {
+ std::string env = GetParam();
+ char* val = getenv(env.c_str());
+ ASSERT_NE(val, nullptr);
+ std::string root(val);
+
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_DGRAM, 0));
+
+ std::string socket_path = JoinPath(root, "dgram/null");
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
+
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ constexpr int kBufferSize = 64;
+ char send_buffer[kBufferSize];
+ memset(send_buffer, 'a', sizeof(send_buffer));
+
+ ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)),
+ SyscallSucceedsWithValue(sizeof(send_buffer)));
+}
+
+INSTANTIATE_TEST_SUITE_P(Dgram, GoferDgramTest,
+ // Test access via standard path and attach point.
+ ::testing::Values("TEST_UDS_TREE",
+ "TEST_UDS_ATTACH_TREE"));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/creat.cc b/test/syscalls/linux/creat.cc
new file mode 100644
index 000000000..3c270d6da
--- /dev/null
+++ b/test/syscalls/linux/creat.cc
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <limits.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kMode = 0666;
+
+TEST(CreatTest, CreatCreatesNewFile) {
+ std::string const path = NewTempAbsPath();
+ struct stat buf;
+ int fd;
+ ASSERT_THAT(stat(path.c_str(), &buf), SyscallFailsWithErrno(ENOENT));
+ ASSERT_THAT(fd = creat(path.c_str(), kMode), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ EXPECT_THAT(stat(path.c_str(), &buf), SyscallSucceeds());
+}
+
+TEST(CreatTest, CreatTruncatesExistingFile) {
+ auto temp_path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ int fd;
+ ASSERT_NO_ERRNO(SetContents(temp_path.path(), "non-empty"));
+ ASSERT_THAT(fd = creat(temp_path.path().c_str(), kMode), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ std::string new_contents;
+ ASSERT_NO_ERRNO(GetContents(temp_path.path(), &new_contents));
+ EXPECT_EQ("", new_contents);
+}
+
+TEST(CreatTest, CreatWithNameTooLong) {
+ // Start with a unique name, and pad it to NAME_MAX + 1.
+ std::string name = NewTempRelPath();
+ int padding = (NAME_MAX + 1) - name.size();
+ name.append(padding, 'x');
+ const std::string path = JoinPath(GetAbsoluteTestTmpdir(), name);
+
+ // Creation should return ENAMETOOLONG.
+ ASSERT_THAT(creat(path.c_str(), kMode), SyscallFailsWithErrno(ENAMETOOLONG));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
new file mode 100644
index 000000000..3c88c4cbd
--- /dev/null
+++ b/test/syscalls/linux/dev.cc
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(DevTest, LseekDevUrandom) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/urandom", O_RDONLY));
+ EXPECT_THAT(lseek(fd.get(), -10, SEEK_CUR), SyscallSucceeds());
+ EXPECT_THAT(lseek(fd.get(), -10, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+}
+
+TEST(DevTest, LseekDevNull) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));
+ EXPECT_THAT(lseek(fd.get(), -10, SEEK_CUR), SyscallSucceeds());
+ EXPECT_THAT(lseek(fd.get(), -10, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds());
+}
+
+TEST(DevTest, LseekDevZero) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds());
+}
+
+TEST(DevTest, LseekDevFull) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/full", O_RDONLY));
+ EXPECT_THAT(lseek(fd.get(), 123, SEEK_SET), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lseek(fd.get(), 123, SEEK_CUR), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lseek(fd.get(), 123, SEEK_END), SyscallSucceedsWithValue(0));
+}
+
+TEST(DevTest, LseekDevNullFreshFile) {
+ // Seeks on /dev/null always return 0.
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));
+
+ EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lseek(fd1.get(), 1000, SEEK_CUR), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd3 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));
+ EXPECT_THAT(lseek(fd3.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+}
+
+TEST(DevTest, OpenTruncate) {
+ // Truncation is ignored on Linux and gVisor for device files.
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/dev/null", O_CREAT | O_TRUNC | O_WRONLY, 0644));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/dev/zero", O_CREAT | O_TRUNC | O_WRONLY, 0644));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/dev/full", O_CREAT | O_TRUNC | O_WRONLY, 0644));
+}
+
+TEST(DevTest, Pread64DevNull) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));
+ char buf[1];
+ EXPECT_THAT(pread64(fd.get(), buf, 1, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST(DevTest, Pread64DevZero) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
+ char buf[1];
+ EXPECT_THAT(pread64(fd.get(), buf, 1, 0), SyscallSucceedsWithValue(1));
+}
+
+TEST(DevTest, Pread64DevFull) {
+ // /dev/full behaves like /dev/zero with respect to reads.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/full", O_RDONLY));
+ char buf[1];
+ EXPECT_THAT(pread64(fd.get(), buf, 1, 0), SyscallSucceedsWithValue(1));
+}
+
+TEST(DevTest, ReadDevNull) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));
+ std::vector<char> buf(1);
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallSucceeds());
+}
+
+// Do not allow random save as it could lead to partial reads.
+TEST(DevTest, ReadDevZero_NoRandomSave) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
+
+ constexpr int kReadSize = 128 * 1024;
+ std::vector<char> buf(kReadSize, 1);
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), kReadSize),
+ SyscallSucceedsWithValue(kReadSize));
+ EXPECT_EQ(std::vector<char>(kReadSize, 0), buf);
+}
+
+TEST(DevTest, WriteDevNull) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_WRONLY));
+ EXPECT_THAT(WriteFd(fd.get(), "a", 1), SyscallSucceedsWithValue(1));
+}
+
+TEST(DevTest, WriteDevZero) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY));
+ EXPECT_THAT(WriteFd(fd.get(), "a", 1), SyscallSucceedsWithValue(1));
+}
+
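+// Unlike /dev/null and /dev/zero, /dev/full rejects writes with ENOSPC.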
+TEST(DevTest, WriteDevFull) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/full", O_WRONLY));
+ EXPECT_THAT(WriteFd(fd.get(), "a", 1), SyscallFailsWithErrno(ENOSPC));
+}
+
+TEST(DevTest, TTYExists) {
+ struct stat statbuf = {};
+ ASSERT_THAT(stat("/dev/tty", &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
+TEST(DevTest, OpenDevFuse) {
+ // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new
+ // device registration is complete.
+ SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor());
+
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc
new file mode 100644
index 000000000..4f773bc75
--- /dev/null
+++ b/test/syscalls/linux/dup.cc
@@ -0,0 +1,133 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<FileDescriptor> Dup2(const FileDescriptor& fd, int target_fd) {
+ int new_fd = dup2(fd.get(), target_fd);
+ if (new_fd < 0) {
+ return PosixError(errno, "Dup2");
+ }
+ return FileDescriptor(new_fd);
+}
+
+PosixErrorOr<FileDescriptor> Dup3(const FileDescriptor& fd, int target_fd,
+ int flags) {
+ int new_fd = dup3(fd.get(), target_fd, flags);
+ if (new_fd < 0) {
+ return PosixError(errno, "Dup2");
+ }
+ return FileDescriptor(new_fd);
+}
+
+void CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) {
+ struct stat stat_result1, stat_result2;
+ ASSERT_THAT(fstat(fd1.get(), &stat_result1), SyscallSucceeds());
+ ASSERT_THAT(fstat(fd2.get(), &stat_result2), SyscallSucceeds());
+ EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev);
+ EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino);
+}
+
+TEST(DupTest, Dup) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Dup the descriptor and make sure it's the same file.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+ ASSERT_NE(fd.get(), nfd.get());
+ CheckSameFile(fd, nfd);
+}
+
+TEST(DupTest, DupClearsCloExec) {
+ // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC));
+ EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+
+ // Duplicate the descriptor. Ensure that it doesn't have FD_CLOEXEC set.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+ ASSERT_NE(fd.get(), nfd.get());
+ CheckSameFile(fd, nfd);
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+}
+
+TEST(DupTest, Dup2) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Regular dup once.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+
+ ASSERT_NE(fd.get(), nfd.get());
+ CheckSameFile(fd, nfd);
+
+ // Dup over the file above.
+ int target_fd = nfd.release();
+ FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd));
+ EXPECT_EQ(target_fd, nfd2.get());
+ CheckSameFile(fd, nfd2);
+}
+
+TEST(DupTest, Dup2SameFD) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Should succeed.
+ ASSERT_THAT(dup2(fd.get(), fd.get()), SyscallSucceedsWithValue(fd.get()));
+}
+
+TEST(DupTest, Dup3) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Regular dup once.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+ ASSERT_NE(fd.get(), nfd.get());
+ CheckSameFile(fd, nfd);
+
+ // Dup over the file above; check that CLOEXEC is not set.
+ nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0));
+ CheckSameFile(fd, nfd);
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+
+ // Dup over the file again with O_CLOEXEC; check that CLOEXEC is now set.
+ nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC));
+ CheckSameFile(fd, nfd);
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+TEST(DupTest, Dup3FailsSameFD) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Unlike dup2, dup3 fails if the new and old fds are the same.
+ ASSERT_THAT(dup3(fd.get(), fd.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc
new file mode 100644
index 000000000..f57d38dc7
--- /dev/null
+++ b/test/syscalls/linux/epoll.cc
@@ -0,0 +1,428 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <limits.h>
+#include <pthread.h>
+#include <signal.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/epoll_util.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kFDsPerEpoll = 3;
+constexpr uint64_t kMagicConstant = 0x0102030405060708;
+
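+// Returns whole milliseconds elapsed between *begin and *end, truncating the
+// sub-millisecond remainder; e.g., begin = 1.000000000s and
+// end = 1.200500000s yield 200.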
+uint64_t ms_elapsed(const struct timespec* begin, const struct timespec* end) {
+ return (end->tv_sec - begin->tv_sec) * 1000 +
+ (end->tv_nsec - begin->tv_nsec) / 1000000;
+}
+
+TEST(EpollTest, AllWritable) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(),
+ EPOLLIN | EPOLLOUT, kMagicConstant + i));
+ }
+
+ struct epoll_event result[kFDsPerEpoll];
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(kFDsPerEpoll));
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ ASSERT_EQ(result[i].events, EPOLLOUT);
+ }
+}
+
+TEST(EpollTest, LastReadable) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(),
+ EPOLLIN | EPOLLOUT, kMagicConstant + i));
+ }
+
+ uint64_t tmp = 1;
+ ASSERT_THAT(WriteFd(eventfds[kFDsPerEpoll - 1].get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+
+ struct epoll_event result[kFDsPerEpoll];
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(kFDsPerEpoll));
+
+ int i;
+ for (i = 0; i < kFDsPerEpoll - 1; i++) {
+ EXPECT_EQ(result[i].events, EPOLLOUT);
+ }
+ EXPECT_EQ(result[i].events, EPOLLOUT | EPOLLIN);
+ EXPECT_EQ(result[i].data.u64, kMagicConstant + i);
+}
+
+TEST(EpollTest, LastNonWritable) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(),
+ EPOLLIN | EPOLLOUT, kMagicConstant + i));
+ }
+
+ // Write the maximum value to the event fd so that writing to it again would
+ // block.
+ uint64_t tmp = ULLONG_MAX - 1;
+ ASSERT_THAT(WriteFd(eventfds[kFDsPerEpoll - 1].get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+
+ struct epoll_event result[kFDsPerEpoll];
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(kFDsPerEpoll));
+
+ int i;
+ for (i = 0; i < kFDsPerEpoll - 1; i++) {
+ EXPECT_EQ(result[i].events, EPOLLOUT);
+ }
+ EXPECT_EQ(result[i].events, EPOLLIN);
+ EXPECT_THAT(ReadFd(eventfds[kFDsPerEpoll - 1].get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(kFDsPerEpoll));
+
+ for (i = 0; i < kFDsPerEpoll; i++) {
+ EXPECT_EQ(result[i].events, EPOLLOUT);
+ }
+}
+
+TEST(EpollTest, Timeout_NoRandomSave) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN,
+ kMagicConstant + i));
+ }
+
+ constexpr int kTimeoutMs = 200;
+ struct timespec begin;
+ struct timespec end;
+ struct epoll_event result[kFDsPerEpoll];
+
+ {
+ const DisableSave ds; // Timing-related.
+ EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &begin), SyscallSucceeds());
+
+ ASSERT_THAT(
+ RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, kTimeoutMs),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &end), SyscallSucceeds());
+ }
+
+ // Check the lower bound on the timeout. Checking for an upper bound is
+ // fragile because Linux can overrun the timeout due to scheduling delays.
+ EXPECT_GT(ms_elapsed(&begin, &end), kTimeoutMs - 1);
+}
+
+void* writer(void* arg) {
+ int fd = *reinterpret_cast<int*>(arg);
+ uint64_t tmp = 1;
+
+ usleep(200000);
+ if (WriteFd(fd, &tmp, sizeof(tmp)) != sizeof(tmp)) {
+ fprintf(stderr, "writer failed: errno %s\n", strerror(errno));
+ }
+
+ return nullptr;
+}
+
+TEST(EpollTest, WaitThenUnblock) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN,
+ kMagicConstant + i));
+ }
+
+ // Fire off a thread that will make at least one of the event fds readable.
+ pthread_t thread;
+ int make_readable = eventfds[0].get();
+ ASSERT_THAT(pthread_create(&thread, nullptr, writer, &make_readable),
+ SyscallSucceedsWithValue(0));
+
+ struct epoll_event result[kFDsPerEpoll];
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_THAT(pthread_detach(thread), SyscallSucceeds());
+}
+
+void sighandler(int s) {}
+
+void* signaler(void* arg) {
+ pthread_t* t = reinterpret_cast<pthread_t*>(arg);
+ // Repeatedly send the real-time signal until we are cancelled, because it's
+ // difficult to know exactly when epoll_wait on another thread (which this
+ // is intending to interrupt) has started blocking.
+ while (1) {
+ usleep(200000);
+ pthread_kill(*t, SIGRTMIN);
+ }
+ return nullptr;
+}
+
+TEST(EpollTest, UnblockWithSignal) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN,
+ kMagicConstant + i));
+ }
+
+ signal(SIGRTMIN, sighandler);
+ // Unblock the real-time signals that InitGoogle blocks :(
+ sigset_t unblock;
+ sigemptyset(&unblock);
+ sigaddset(&unblock, SIGRTMIN);
+ ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &unblock, nullptr), SyscallSucceeds());
+
+ pthread_t thread;
+ pthread_t cur = pthread_self();
+ ASSERT_THAT(pthread_create(&thread, nullptr, signaler, &cur),
+ SyscallSucceedsWithValue(0));
+
+ struct epoll_event result[kFDsPerEpoll];
+ EXPECT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_THAT(pthread_cancel(thread), SyscallSucceeds());
+ EXPECT_THAT(pthread_detach(thread), SyscallSucceeds());
+}
+
+TEST(EpollTest, TimeoutNoFds) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ struct epoll_event result[kFDsPerEpoll];
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+}
+
+struct addr_ctx {
+ int epollfd;
+ int eventfd;
+};
+
+void* fd_adder(void* arg) {
+ struct addr_ctx* actx = reinterpret_cast<struct addr_ctx*>(arg);
+ struct epoll_event event;
+ event.events = EPOLLIN | EPOLLOUT;
+ event.data.u64 = 0xdeadbeeffacefeed;
+
+ usleep(200000);
+ if (epoll_ctl(actx->epollfd, EPOLL_CTL_ADD, actx->eventfd, &event) == -1) {
+ fprintf(stderr, "epoll_ctl failed: %s\n", strerror(errno));
+ }
+
+ return nullptr;
+}
+
+TEST(EpollTest, UnblockWithNewFD) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+
+ pthread_t thread;
+ struct addr_ctx actx = {epollfd.get(), eventfd.get()};
+ ASSERT_THAT(pthread_create(&thread, nullptr, fd_adder, &actx),
+ SyscallSucceedsWithValue(0));
+
+ struct epoll_event result[kFDsPerEpoll];
+ // Wait while no FDs are ready, but after 200ms fd_adder will add a ready FD
+ // to epoll, which will wake us up.
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_THAT(pthread_detach(thread), SyscallSucceeds());
+ EXPECT_EQ(result[0].data.u64, 0xdeadbeeffacefeed);
+}
+
+TEST(EpollTest, Oneshot) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN,
+ kMagicConstant + i));
+ }
+
+ struct epoll_event event;
+ event.events = EPOLLOUT | EPOLLONESHOT;
+ event.data.u64 = kMagicConstant;
+ ASSERT_THAT(
+ epoll_ctl(epollfd.get(), EPOLL_CTL_MOD, eventfds[0].get(), &event),
+ SyscallSucceeds());
+
+ struct epoll_event result[kFDsPerEpoll];
+ // One-shot entry means that the first epoll_wait should succeed.
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(result[0].data.u64, kMagicConstant);
+
+ // One-shot entry means that the second epoll_wait should timeout.
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(EpollTest, EdgeTriggered_NoRandomSave) {
+ // Test edge-triggered entry: make it edge-triggered, first wait should
+ // return it, second one should time out, make it writable again, third wait
+ // should return it, fourth wait should timeout.
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfd.get(),
+ EPOLLOUT | EPOLLET, kMagicConstant));
+
+ struct epoll_event result[kFDsPerEpoll];
+
+ {
+ const DisableSave ds; // May trigger spurious event.
+
+ // Edge-triggered entry means that the first epoll_wait should return the
+ // event.
+ ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(result[0].data.u64, kMagicConstant);
+
+ // Edge-triggered entry means that the second epoll_wait should time out.
+ ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+ }
+
+ uint64_t tmp = ULLONG_MAX - 1;
+
+ // Make an fd non-writable.
+ ASSERT_THAT(WriteFd(eventfd.get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+
+ // Draining the eventfd makes the same fd writable again; the change triggers
+ // an edge-triggered event.
+ ASSERT_THAT(ReadFd(eventfd.get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+
+ {
+ const DisableSave ds; // May trigger spurious event.
+
+ // An edge-triggered event should now be returned.
+ ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(result[0].data.u64, kMagicConstant);
+
+ // The edge-triggered event was consumed above; we don't expect to get it
+ // again.
+ ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+ }
+}
+
+TEST(EpollTest, OneshotAndEdgeTriggered) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfd.get(),
+ EPOLLOUT | EPOLLET | EPOLLONESHOT,
+ kMagicConstant));
+
+ struct epoll_event result[kFDsPerEpoll];
+ // First time one shot edge-triggered entry means that epoll_wait should
+ // return the event.
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(result[0].data.u64, kMagicConstant);
+
+ // Edge-triggered entry means that the second epoll_wait should time out.
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+
+ uint64_t tmp = ULLONG_MAX - 1;
+ // Make an fd non-writable.
+ ASSERT_THAT(WriteFd(eventfd.get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+ // Draining the eventfd makes the same fd writable again, but this will not
+ // trigger an edge-triggered event because we've also included EPOLLONESHOT.
+ ASSERT_THAT(ReadFd(eventfd.get(), &tmp, sizeof(tmp)),
+ SyscallSucceedsWithValue(sizeof(tmp)));
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(EpollTest, CycleOfOneDisallowed) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+ struct epoll_event event;
+ event.events = EPOLLOUT;
+ event.data.u64 = kMagicConstant;
+
+ ASSERT_THAT(epoll_ctl(epollfd.get(), EPOLL_CTL_ADD, epollfd.get(), &event),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(EpollTest, CycleOfThreeDisallowed) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ auto epollfd1 = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ auto epollfd2 = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), epollfd1.get(), EPOLLIN, kMagicConstant));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd1.get(), epollfd2.get(), EPOLLIN, kMagicConstant));
+
+ struct epoll_event event;
+ event.events = EPOLLIN;
+ event.data.u64 = kMagicConstant;
+ EXPECT_THAT(epoll_ctl(epollfd2.get(), EPOLL_CTL_ADD, epollfd.get(), &event),
+ SyscallFailsWithErrno(ELOOP));
+}
+
+TEST(EpollTest, CloseFile) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfd.get(), EPOLLOUT, kMagicConstant));
+
+ struct epoll_event result[kFDsPerEpoll];
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(result[0].data.u64, kMagicConstant);
+
+ // Close the event fd early.
+ eventfd.reset();
+
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, 100),
+ SyscallSucceedsWithValue(0));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc
new file mode 100644
index 000000000..dc794415e
--- /dev/null
+++ b/test/syscalls/linux/eventfd.cc
@@ -0,0 +1,222 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <pthread.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/epoll.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/epoll_util.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
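+// In EFD_SEMAPHORE mode each successful read returns 1 and decrements the
+// counter; once the counter reaches zero, reads fail with EAGAIN when
+// EFD_NONBLOCK is set.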
+TEST(EventfdTest, Nonblock) {
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+
+ uint64_t l;
+ ASSERT_THAT(read(efd.get(), &l, sizeof(l)), SyscallFailsWithErrno(EAGAIN));
+
+ l = 1;
+ ASSERT_THAT(write(efd.get(), &l, sizeof(l)), SyscallSucceeds());
+
+ l = 0;
+ ASSERT_THAT(read(efd.get(), &l, sizeof(l)), SyscallSucceeds());
+ EXPECT_EQ(l, 1);
+
+ ASSERT_THAT(read(efd.get(), &l, sizeof(l)), SyscallFailsWithErrno(EAGAIN));
+}
+
+void* read_three_times(void* arg) {
+ int efd = *reinterpret_cast<int*>(arg);
+ uint64_t l;
+ EXPECT_THAT(read(efd, &l, sizeof(l)), SyscallSucceedsWithValue(sizeof(l)));
+ EXPECT_THAT(read(efd, &l, sizeof(l)), SyscallSucceedsWithValue(sizeof(l)));
+ EXPECT_THAT(read(efd, &l, sizeof(l)), SyscallSucceedsWithValue(sizeof(l)));
+ return nullptr;
+}
+
+TEST(EventfdTest, BlockingWrite) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_SEMAPHORE));
+ int efd = fd.get();
+
+ pthread_t p;
+ ASSERT_THAT(pthread_create(&p, nullptr, read_three_times,
+ reinterpret_cast<void*>(&efd)),
+ SyscallSucceeds());
+
+ uint64_t l = 1;
+ ASSERT_THAT(write(efd, &l, sizeof(l)), SyscallSucceeds());
+ EXPECT_EQ(l, 1);
+
+ ASSERT_THAT(write(efd, &l, sizeof(l)), SyscallSucceeds());
+ EXPECT_EQ(l, 1);
+
+ ASSERT_THAT(write(efd, &l, sizeof(l)), SyscallSucceeds());
+ EXPECT_EQ(l, 1);
+
+ ASSERT_THAT(pthread_join(p, nullptr), SyscallSucceeds());
+}
+
+TEST(EventfdTest, SmallWrite) {
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+
+ uint64_t l = 16;
+ ASSERT_THAT(write(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(EventfdTest, SmallRead) {
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+
+ uint64_t l = 1;
+ ASSERT_THAT(write(efd.get(), &l, sizeof(l)), SyscallSucceeds());
+
+ l = 0;
+ ASSERT_THAT(read(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(EventfdTest, IllegalSeek) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ EXPECT_THAT(lseek(efd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(EventfdTest, IllegalPread) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ int l;
+ EXPECT_THAT(pread(efd.get(), &l, sizeof(l), 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(EventfdTest, IllegalPwrite) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ EXPECT_THAT(pwrite(efd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(EventfdTest, BigWrite) {
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+
+ uint64_t big[16];
+ big[0] = 16;
+ ASSERT_THAT(write(efd.get(), big, sizeof(big)), SyscallSucceeds());
+}
+
+TEST(EventfdTest, BigRead) {
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+
+ uint64_t l = 1;
+ ASSERT_THAT(write(efd.get(), &l, sizeof(l)), SyscallSucceeds());
+
+ uint64_t big[16];
+ ASSERT_THAT(read(efd.get(), big, sizeof(big)), SyscallSucceeds());
+ EXPECT_EQ(big[0], 1);
+}
+
+TEST(EventfdTest, BigWriteBigRead) {
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+
+ uint64_t l[16];
+ l[0] = 16;
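+ // Only the first 8 bytes of each buffer are significant: the write adds
+ // l[0] to the counter, and the semaphore-mode read stores 1 in l[0] while
+ // decrementing the counter.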
+ ASSERT_THAT(write(efd.get(), l, sizeof(l)), SyscallSucceeds());
+ ASSERT_THAT(read(efd.get(), l, sizeof(l)), SyscallSucceeds());
+ EXPECT_EQ(l[0], 1);
+}
+
+TEST(EventfdTest, SpliceFromPipePartialSucceeds) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor pipe_rfd(pipes[0]);
+ const FileDescriptor pipe_wfd(pipes[1]);
+ constexpr uint64_t kVal{1};
+
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK));
+
+ uint64_t event_array[2];
+ event_array[0] = kVal;
+ event_array[1] = kVal;
+ ASSERT_THAT(write(pipe_wfd.get(), event_array, sizeof(event_array)),
+ SyscallSucceedsWithValue(sizeof(event_array)));
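+ // Request one byte more than a single 8-byte event; eventfds accept data
+ // only in 8-byte units, so the splice completes partially, transferring
+ // exactly one event.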
+ EXPECT_THAT(splice(pipe_rfd.get(), /*__offin=*/nullptr, efd.get(),
+ /*__offout=*/nullptr, sizeof(event_array[0]) + 1,
+ SPLICE_F_NONBLOCK),
+ SyscallSucceedsWithValue(sizeof(event_array[0])));
+
+ uint64_t val;
+ ASSERT_THAT(read(efd.get(), &val, sizeof(val)),
+ SyscallSucceedsWithValue(sizeof(val)));
+ EXPECT_EQ(val, kVal);
+}
+
+// NotifyNonZero is inherently racy, so random save is disabled.
+TEST(EventfdTest, NotifyNonZero_NoRandomSave) {
+ // Waits will time out at 10 seconds.
+ constexpr int kEpollTimeoutMs = 10000;
+ // Create an eventfd descriptor.
+ FileDescriptor efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(7, EFD_NONBLOCK | EFD_SEMAPHORE));
+ // Create an epoll fd to listen to efd.
+ FileDescriptor epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ // Add efd to epoll.
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), efd.get(), EPOLLIN | EPOLLET, efd.get()));
+
+ // Use epoll to get a value from efd.
+ struct epoll_event out_ev;
+ int wait_out = epoll_wait(epollfd.get(), &out_ev, 1, kEpollTimeoutMs);
+ EXPECT_EQ(wait_out, 1);
+ EXPECT_EQ(efd.get(), out_ev.data.fd);
+ uint64_t val = 0;
+ ASSERT_THAT(read(efd.get(), &val, sizeof(val)), SyscallSucceeds());
+ EXPECT_EQ(val, 1);
+
+ // Start a thread that, after this thread blocks on epoll_wait, will write to
+ // efd. This is racy -- it's possible that this write will happen after
+ // epoll_wait times out.
+ ScopedThread t([&efd] {
+ sleep(5);
+ uint64_t val = 1;
+ EXPECT_THAT(write(efd.get(), &val, sizeof(val)),
+ SyscallSucceedsWithValue(sizeof(val)));
+ });
+
+ // epoll_wait should return once the thread writes.
+ wait_out = epoll_wait(epollfd.get(), &out_ev, 1, kEpollTimeoutMs);
+ EXPECT_EQ(wait_out, 1);
+ EXPECT_EQ(efd.get(), out_ev.data.fd);
+
+ val = 0;
+ ASSERT_THAT(read(efd.get(), &val, sizeof(val)), SyscallSucceeds());
+ EXPECT_EQ(val, 1);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc
new file mode 100644
index 000000000..420b9543f
--- /dev/null
+++ b/test/syscalls/linux/exceptions.cc
@@ -0,0 +1,367 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <stdint.h>
+#include <sys/syscall.h>
+
+#include "gtest/gtest.h"
+#include "test/util/logging.h"
+#include "test/util/platform_util.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5
+// "x87 FPU Control Word".
+constexpr uint16_t kX87ControlWordDefault = 0x37f;
+
+// Mask for the divide-by-zero exception.
+constexpr uint16_t kX87ControlWordDiv0Mask = 1 << 2;
+
+// Default value for the SSE control register (MXCSR). See Intel SDM Vol 1, Ch
+// 11.6.4 "Initialization of SSE/SSE3 Extensions".
+constexpr uint32_t kMXCSRDefault = 0x1f80;
+
+// Mask for the divide-by-zero exception.
+constexpr uint32_t kMXCSRDiv0Mask = 1 << 9;
+
+// Flag for a pending divide-by-zero exception.
+constexpr uint32_t kMXCSRDiv0Flag = 1 << 2;
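+
+// For example, unmasking only the divide-by-zero exception yields a control
+// word of 0x37f & ~(1 << 2) = 0x37b and an MXCSR of 0x1f80 & ~(1 << 9) =
+// 0x1d80, which is what the unmasked tests below install.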
+
+void inline Halt() { asm("hlt\r\n"); }
+
+void inline SetAlignmentCheck() {
+ asm("subq $128, %%rsp\r\n" // Avoid potential red zone clobber
+ "pushf\r\n"
+ "pop %%rax\r\n"
+ "or $0x40000, %%rax\r\n"
+ "push %%rax\r\n"
+ "popf\r\n"
+ "addq $128, %%rsp\r\n"
+ :
+ :
+ : "ax");
+}
+
+void inline ClearAlignmentCheck() {
+ asm("subq $128, %%rsp\r\n" // Avoid potential red zone clobber
+ "pushf\r\n"
+ "pop %%rax\r\n"
+ "mov $0x40000, %%rbx\r\n"
+ "not %%rbx\r\n"
+ "and %%rbx, %%rax\r\n"
+ "push %%rax\r\n"
+ "popf\r\n"
+ "addq $128, %%rsp\r\n"
+ :
+ :
+ : "ax", "bx");
+}
+
+void inline Int3Normal() { asm(".byte 0xcd, 0x03\r\n"); }
+
+void inline Int3Compact() { asm(".byte 0xcc\r\n"); }
+
+void InIOHelper(int width, int value) {
+ EXPECT_EXIT(
+ {
+ switch (width) {
+ case 1:
+ asm volatile("inb %%dx, %%al" ::"d"(value) : "%eax");
+ break;
+ case 2:
+ asm volatile("inw %%dx, %%ax" ::"d"(value) : "%eax");
+ break;
+ case 4:
+ asm volatile("inl %%dx, %%eax" ::"d"(value) : "%eax");
+ break;
+ default:
+ FAIL() << "invalid input width, only 1, 2 or 4 is allowed";
+ }
+ },
+ ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+TEST(ExceptionTest, Halt) {
+ // In order to prevent the regular handler from messing with things (and
+ // perhaps refaulting until some other signal occurs), we reset the handler to
+ // the default action here and ensure that it dies correctly.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+
+ EXPECT_EXIT(Halt(), ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+TEST(ExceptionTest, DivideByZero) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ uint32_t remainder;
+ uint32_t quotient;
+ uint32_t divisor = 0;
+ uint64_t value = 1;
+ asm("divl 0(%2)\r\n"
+ : "=d"(remainder), "=a"(quotient)
+ : "r"(&divisor), "d"(value >> 32), "a"(value));
+ TEST_CHECK(quotient > 0); // Force dependency.
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// By default, x87 exceptions are masked and simply return a default value.
+TEST(ExceptionTest, X87DivideByZeroMasked) {
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm("fildl %[value]\r\n"
+ "fidivl %[divisor]\r\n"
+ "fistpl %[quotient]\r\n"
+ : [ quotient ] "=m"(quotient)
+ : [ value ] "m"(value), [ divisor ] "m"(divisor));
+
+ EXPECT_EQ(quotient, INT32_MIN);
+}
+
+// When unmasked, division by zero raises SIGFPE.
+TEST(ExceptionTest, X87DivideByZeroUnmasked) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ // Clear the divide by zero exception mask.
+ constexpr uint16_t kControlWord =
+ kX87ControlWordDefault & ~kX87ControlWordDiv0Mask;
+
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm volatile(
+ "fldcw %[cw]\r\n"
+ "fildl %[value]\r\n"
+ "fidivl %[divisor]\r\n"
+ "fistpl %[quotient]\r\n"
+ : [ quotient ] "=m"(quotient)
+ : [ cw ] "m"(kControlWord), [ value ] "m"(value),
+ [ divisor ] "m"(divisor));
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// Pending exceptions in the x87 status register are not clobbered by syscalls.
+TEST(ExceptionTest, X87StatusClobber) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ // Clear the divide by zero exception mask.
+ constexpr uint16_t kControlWord =
+ kX87ControlWordDefault & ~kX87ControlWordDiv0Mask;
+
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm volatile(
+ "fildl %[value]\r\n"
+ "fidivl %[divisor]\r\n"
+ // Exception is masked, so it does not occur here.
+ "fistpl %[quotient]\r\n"
+
+ // SYS_getpid placed in rax by constraint.
+ "syscall\r\n"
+
+ // Unmask exception. The syscall didn't clobber the pending
+ // exception, so now it can be raised.
+ //
+ // N.B. "a floating-point exception will be generated upon execution
+ // of the *next* floating-point instruction".
+ "fldcw %[cw]\r\n"
+ "fwait\r\n"
+ : [ quotient ] "=m"(quotient)
+ : [ value ] "m"(value), [ divisor ] "m"(divisor), "a"(SYS_getpid),
+ [ cw ] "m"(kControlWord)
+ : "rcx", "r11");
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// By default, SSE exceptions are masked and simply return a default value.
+TEST(ExceptionTest, SSEDivideByZeroMasked) {
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm("cvtsi2ssl %[value], %%xmm0\r\n"
+ "cvtsi2ssl %[divisor], %%xmm1\r\n"
+ "divss %%xmm1, %%xmm0\r\n"
+ "cvtss2sil %%xmm0, %[quotient]\r\n"
+ : [ quotient ] "=r"(quotient)
+ : [ value ] "r"(value), [ divisor ] "r"(divisor)
+ : "xmm0", "xmm1");
+
+ EXPECT_EQ(quotient, INT32_MIN);
+}
+
+// When unmasked, division by zero raises SIGFPE.
+TEST(ExceptionTest, SSEDivideByZeroUnmasked) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ // Clear the divide by zero exception mask.
+ constexpr uint32_t kMXCSR = kMXCSRDefault & ~kMXCSRDiv0Mask;
+
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm volatile(
+ "ldmxcsr %[mxcsr]\r\n"
+ "cvtsi2ssl %[value], %%xmm0\r\n"
+ "cvtsi2ssl %[divisor], %%xmm1\r\n"
+ "divss %%xmm1, %%xmm0\r\n"
+ "cvtss2sil %%xmm0, %[quotient]\r\n"
+ : [ quotient ] "=r"(quotient)
+ : [ mxcsr ] "m"(kMXCSR), [ value ] "r"(value),
+ [ divisor ] "r"(divisor)
+ : "xmm0", "xmm1");
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// Pending exceptions in the SSE status register are not clobbered by syscalls.
+TEST(ExceptionTest, SSEStatusClobber) {
+ uint32_t mxcsr;
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm("cvtsi2ssl %[value], %%xmm0\r\n"
+ "cvtsi2ssl %[divisor], %%xmm1\r\n"
+ "divss %%xmm1, %%xmm0\r\n"
+ // Exception is masked, so it does not occur here.
+ "cvtss2sil %%xmm0, %[quotient]\r\n"
+
+ // SYS_getpid placed in rax by constraint.
+ "syscall\r\n"
+
+ // Intel SDM Vol 1, Ch 10.2.3.1 "SIMD Floating-Point Mask and Flag Bits":
+ // "If LDMXCSR or FXRSTOR clears a mask bit and sets the corresponding
+ // exception flag bit, a SIMD floating-point exception will not be
+ // generated as a result of this change. The unmasked exception will be
+ // generated only upon the execution of the next SSE/SSE2/SSE3 instruction
+ // that detects the unmasked exception condition."
+ //
+ // Though ambiguous, empirical evidence indicates that this means that
+ // exception flags set in the status register will never cause an
+ // exception to be raised; only a new exception condition will do so.
+ //
+ // Thus here we just check for the flag itself rather than trying to raise
+ // the exception.
+ "stmxcsr %[mxcsr]\r\n"
+      : [ quotient ] "=r"(quotient), [ mxcsr ] "=m"(mxcsr)
+ : [ value ] "r"(value), [ divisor ] "r"(divisor), "a"(SYS_getpid)
+ : "xmm0", "xmm1", "rcx", "r11");
+
+ EXPECT_TRUE(mxcsr & kMXCSRDiv0Flag);
+}
+
+TEST(ExceptionTest, IOAccessFault) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+
+ InIOHelper(1, 0x0);
+ InIOHelper(2, 0x7);
+ InIOHelper(4, 0x6);
+ InIOHelper(1, 0xffff);
+ InIOHelper(2, 0xffff);
+ InIOHelper(4, 0xfffd);
+}
+
+TEST(ExceptionTest, Alignment) {
+ SetAlignmentCheck();
+ ClearAlignmentCheck();
+}
+
+TEST(ExceptionTest, AlignmentHalt) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+
+ // Reported upstream. We need to ensure that bad flags are cleared even in
+ // fault paths. Set the alignment flag and then generate an exception.
+ EXPECT_EXIT(
+ {
+ SetAlignmentCheck();
+ Halt();
+ },
+ ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+TEST(ExceptionTest, AlignmentCheck) {
+ SKIP_IF(PlatformSupportAlignmentCheck() != PlatformSupport::Allowed);
+
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGBUS, sa));
+
+ EXPECT_EXIT(
+ {
+ char array[16];
+ SetAlignmentCheck();
+ for (int i = 0; i < 8; i++) {
+          // At least 7 of the 8 offsets will be unaligned here.
+ uint64_t* ptr = reinterpret_cast<uint64_t*>(&array[i]);
+ asm("mov %0, 0(%0)\r\n" : : "r"(ptr) : "ax");
+ }
+ },
+ ::testing::KilledBySignal(SIGBUS), "");
+}
+
+TEST(ExceptionTest, Int3Normal) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGTRAP, sa));
+
+ EXPECT_EXIT(Int3Normal(), ::testing::KilledBySignal(SIGTRAP), "");
+}
+
+TEST(ExceptionTest, Int3Compact) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGTRAP, sa));
+
+ EXPECT_EXIT(Int3Compact(), ::testing::KilledBySignal(SIGTRAP), "");
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
new file mode 100644
index 000000000..c5acfc794
--- /dev/null
+++ b/test/syscalls/linux/exec.cc
@@ -0,0 +1,904 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/exec.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/eventfd.h>
+#include <sys/resource.h>
+#include <sys/time.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/optional.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr char kBasicWorkload[] = "test/syscalls/linux/exec_basic_workload";
+constexpr char kExitScript[] = "test/syscalls/linux/exit_script";
+constexpr char kStateWorkload[] = "test/syscalls/linux/exec_state_workload";
+constexpr char kProcExeWorkload[] =
+ "test/syscalls/linux/exec_proc_exe_workload";
+constexpr char kAssertClosedWorkload[] =
+ "test/syscalls/linux/exec_assert_closed_workload";
+constexpr char kPriorityWorkload[] = "test/syscalls/linux/priority_execve";
+
+constexpr char kExit42[] = "--exec_exit_42";
+constexpr char kExecWithThread[] = "--exec_exec_with_thread";
+constexpr char kExecFromThread[] = "--exec_exec_from_thread";
+
+// Runs the file specified by dirfd and pathname with argv and envv, and checks
+// that the exit status matches expect_status and that stderr contains
+// expect_stderr.
+void CheckExecHelper(const absl::optional<int32_t> dirfd,
+ const std::string& pathname, const ExecveArray& argv,
+ const ExecveArray& envv, const int flags,
+ int expect_status, const std::string& expect_stderr) {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds());
+
+ FileDescriptor read_fd(pipe_fds[0]);
+ FileDescriptor write_fd(pipe_fds[1]);
+
+ pid_t child;
+ int execve_errno;
+
+ const auto remap_stderr = [pipe_fds] {
+ // Remap stdin and stdout to /dev/null.
+ int fd = open("/dev/null", O_RDWR | O_CLOEXEC);
+ if (fd < 0) {
+ _exit(errno);
+ }
+
+ int ret = dup2(fd, 0);
+ if (ret < 0) {
+ _exit(errno);
+ }
+
+ ret = dup2(fd, 1);
+ if (ret < 0) {
+ _exit(errno);
+ }
+
+ // And stderr to the pipe.
+ ret = dup2(pipe_fds[1], 2);
+ if (ret < 0) {
+ _exit(errno);
+ }
+
+ // Here, we'd ideally close all other FDs inherited from the parent.
+ // However, that's not worth the effort and CloexecNormalFile and
+ // CloexecEventfd depend on that not happening.
+ };
+
+ Cleanup kill;
+ if (dirfd.has_value()) {
+ kill = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(*dirfd, pathname, argv,
+ envv, flags, remap_stderr,
+ &child, &execve_errno));
+ } else {
+ kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(pathname, argv, envv, remap_stderr, &child, &execve_errno));
+ }
+
+ ASSERT_EQ(0, execve_errno);
+
+ // Not needed anymore.
+ write_fd.reset();
+
+ // Read stderr until the child exits.
+ std::string output;
+ constexpr int kSize = 128;
+ char buf[kSize];
+ int n;
+ do {
+ ASSERT_THAT(n = ReadFd(read_fd.get(), buf, kSize), SyscallSucceeds());
+ if (n > 0) {
+ output.append(buf, n);
+ }
+ } while (n > 0);
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_EQ(status, expect_status);
+
+ // Process cleanup no longer needed.
+ kill.Release();
+
+ EXPECT_TRUE(absl::StrContains(output, expect_stderr)) << output;
+}
+
+void CheckExec(const std::string& filename, const ExecveArray& argv,
+ const ExecveArray& envv, int expect_status,
+ const std::string& expect_stderr) {
+ CheckExecHelper(/*dirfd=*/absl::optional<int32_t>(), filename, argv, envv,
+ /*flags=*/0, expect_status, expect_stderr);
+}
+
+void CheckExecveat(const int32_t dirfd, const std::string& pathname,
+ const ExecveArray& argv, const ExecveArray& envv,
+ const int flags, int expect_status,
+ const std::string& expect_stderr) {
+ CheckExecHelper(absl::optional<int32_t>(dirfd), pathname, argv, envv, flags,
+ expect_status, expect_stderr);
+}
+
+TEST(ExecTest, EmptyPath) {
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec("", {}, {}, nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ENOENT);
+}
+
+TEST(ExecTest, Basic) {
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {},
+ ArgEnvExitStatus(0, 0),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n"));
+}
+
+TEST(ExecTest, OneArg) {
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "1"}, {},
+ ArgEnvExitStatus(1, 0),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n"));
+}
+
+TEST(ExecTest, FiveArg) {
+ CheckExec(RunfilePath(kBasicWorkload),
+ {RunfilePath(kBasicWorkload), "1", "2", "3", "4", "5"}, {},
+ ArgEnvExitStatus(5, 0),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
+}
+
+TEST(ExecTest, OneEnv) {
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1"},
+ ArgEnvExitStatus(0, 1),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n"));
+}
+
+TEST(ExecTest, FiveEnv) {
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)},
+ {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
+}
+
+TEST(ExecTest, OneArgOneEnv) {
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "arg"},
+ {"env"}, ArgEnvExitStatus(1, 1),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\narg\nenv\n"));
+}
+
+TEST(ExecTest, InterpreterScript) {
+ CheckExec(RunfilePath(kExitScript), {RunfilePath(kExitScript), "25"}, {},
+ ArgEnvExitStatus(25, 0), "");
+}
+
+// Everything after the path in the interpreter script is a single argument.
+TEST(ExecTest, InterpreterScriptArgSplit) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"),
+ 0755));
+
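+  // The kernel passes everything after the interpreter path as one argument:
+  // argv becomes {link.path(), "foo bar", script.path()}, i.e. two arguments
+  // after argv[0], hence ArgEnvExitStatus(2, 0).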
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
+ absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n"));
+}
+
+// Original argv[0] is replaced with the script path.
+TEST(ExecTest, InterpreterScriptArgvZero) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
+
+ CheckExec(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script.path(), "\n"));
+}
+
+// Original argv[0] is replaced with the script path, exactly as passed to
+// execve.
+TEST(ExecTest, InterpreterScriptArgvZeroRelative) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
+
+ auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD());
+ auto script_relative =
+ ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path()));
+
+ CheckExec(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script_relative, "\n"));
+}
+
+// argv[0] is added as the script path, even if there was none.
+TEST(ExecTest, InterpreterScriptArgvZeroAdded) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
+
+ CheckExec(script.path(), {}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script.path(), "\n"));
+}
+
+// A NUL byte in the script line ends parsing.
+TEST(ExecTest, InterpreterScriptArgNUL) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(),
+ absl::StrCat("#!", link.path(), " foo", std::string(1, '\0'), "bar"),
+ 0755));
+
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
+ absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
+}
+
+// Trailing whitespace following interpreter path is ignored.
+TEST(ExecTest, InterpreterScriptTrailingWhitespace) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755));
+
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script.path(), "\n"));
+}
+
+// Multiple whitespace characters between the interpreter and its argument are
+// allowed.
+TEST(ExecTest, InterpreterScriptArgWhitespace) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755));
+
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
+ absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
+}
+
+TEST(ExecTest, InterpreterScriptNoPath) {
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "#!", 0755));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ENOEXEC);
+}
+
+// AT_EXECFN is the path passed to execve.
+TEST(ExecTest, ExecFn) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " PrintExecFn"),
+ 0755));
+
+ // Pass the script as a relative path and assert that is what appears in
+ // AT_EXECFN.
+ auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD());
+ auto script_relative =
+ ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path()));
+
+ CheckExec(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0),
+ absl::StrCat(script_relative, "\n"));
+}
+
+TEST(ExecTest, ExecName) {
+ std::string path = RunfilePath(kStateWorkload);
+
+ CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0),
+ absl::StrCat(Basename(path).substr(0, 15), "\n"));
+}
+
+TEST(ExecTest, ExecNameScript) {
+ // Symlink through /tmp to ensure the path is short enough.
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload)));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(),
+ absl::StrCat("#!", link.path(), " PrintExecName"), 0755));
+
+ std::string script_path = script.path();
+
+ CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0),
+ absl::StrCat(Basename(script_path).substr(0, 15), "\n"));
+}
+
+// execve may be called by a multithreaded process.
+TEST(ExecTest, WithSiblingThread) {
+ CheckExec("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {},
+ W_EXITCODE(42, 0), "");
+}
+
+// execve may be called from a thread other than the leader of a multithreaded
+// process.
+TEST(ExecTest, FromSiblingThread) {
+ CheckExec("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {},
+ W_EXITCODE(42, 0), "");
+}
+
+TEST(ExecTest, NotFound) {
+ char* const argv[] = {nullptr};
+ char* const envp[] = {nullptr};
+ EXPECT_THAT(execve("/file/does/not/exist", argv, envp),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(ExecTest, NoExecPerm) {
+ char* const argv[] = {nullptr};
+ char* const envp[] = {nullptr};
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ EXPECT_THAT(execve(f.path().c_str(), argv, envp),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// A signal handler we never expect to be called.
+void SignalHandler(int signo) {
+ std::cerr << "Signal " << signo << " raised." << std::endl;
+ exit(1);
+}
+
+// Signal handlers are reset on execve(2), unless they have default or ignored
+// disposition.
+TEST(ExecStateTest, HandlerReset) {
+ struct sigaction sa;
+ sa.sa_handler = SignalHandler;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ ExecveArray args = {
+ RunfilePath(kStateWorkload),
+ "CheckSigHandler",
+ absl::StrCat(SIGUSR1),
+ absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))),
+ };
+
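+  // CheckSigHandler asserts that the new image's disposition for SIGUSR1 is
+  // SIG_DFL, i.e. that execve reset the custom handler.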
+ CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+}
+
+// Ignored signal dispositions are not reset.
+TEST(ExecStateTest, IgnorePreserved) {
+ struct sigaction sa;
+ sa.sa_handler = SIG_IGN;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ ExecveArray args = {
+ RunfilePath(kStateWorkload),
+ "CheckSigHandler",
+ absl::StrCat(SIGUSR1),
+ absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))),
+ };
+
+ CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+}
+
+// Signal masks are not reset on execve(2).
+TEST(ExecStateTest, SignalMask) {
+ sigset_t s;
+ sigemptyset(&s);
+ sigaddset(&s, SIGUSR1);
+ ASSERT_THAT(sigprocmask(SIG_BLOCK, &s, nullptr), SyscallSucceeds());
+
+ ExecveArray args = {
+ RunfilePath(kStateWorkload),
+ "CheckSigBlocked",
+ absl::StrCat(SIGUSR1),
+ };
+
+ CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+}
+
+// itimers persist across execve.
+// N.B. Timers created with timer_create(2) should not be preserved!
+TEST(ExecStateTest, ItimerPreserved) {
+ // The fork in ForkAndExec clears itimers, so only set them up after fork.
+ auto setup_itimer = [] {
+ // Ignore SIGALRM, as we don't actually care about timer
+ // expirations.
+ struct sigaction sa;
+ sa.sa_handler = SIG_IGN;
+ int ret = sigaction(SIGALRM, &sa, nullptr);
+ if (ret < 0) {
+ _exit(errno);
+ }
+
+ struct itimerval itv;
+ itv.it_interval.tv_sec = 1;
+ itv.it_interval.tv_usec = 0;
+ itv.it_value.tv_sec = 1;
+ itv.it_value.tv_usec = 0;
+ ret = setitimer(ITIMER_REAL, &itv, nullptr);
+ if (ret < 0) {
+ _exit(errno);
+ }
+ };
+
+ std::string filename = RunfilePath(kStateWorkload);
+ ExecveArray argv = {
+ filename,
+ "CheckItimerEnabled",
+ absl::StrCat(ITIMER_REAL),
+ };
+
+ pid_t child;
+ int execve_errno;
+ auto kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(filename, argv, {}, setup_itimer, &child, &execve_errno));
+ ASSERT_EQ(0, execve_errno);
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_EQ(0, status);
+
+ // Process cleanup no longer needed.
+ kill.Release();
+}
+
+TEST(ProcSelfExe, ChangesAcrossExecve) {
+ // See exec_proc_exe_workload for more details. We simply
+ // assert that the /proc/self/exe link changes across execve.
+ CheckExec(RunfilePath(kProcExeWorkload),
+ {RunfilePath(kProcExeWorkload),
+ ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))},
+ {}, W_EXITCODE(0, 0), "");
+}
+
+TEST(ExecTest, CloexecNormalFile) {
+ TempPath tempFile = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "bar", 0755));
+ const FileDescriptor fd_closed_on_exec =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC));
+
+ CheckExec(RunfilePath(kAssertClosedWorkload),
+ {RunfilePath(kAssertClosedWorkload),
+ absl::StrCat(fd_closed_on_exec.get())},
+ {}, W_EXITCODE(0, 0), "");
+
+  // The assert closed workload exits with code 2 if the FD is still valid. We
+  // can use this to do a negative test.
+ const FileDescriptor fd_open_on_exec =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY));
+
+ CheckExec(
+ RunfilePath(kAssertClosedWorkload),
+ {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_open_on_exec.get())},
+ {}, W_EXITCODE(2, 0), "");
+}
+
+TEST(ExecTest, CloexecEventfd) {
+ int efd;
+ ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds());
+ FileDescriptor fd(efd);
+
+ CheckExec(RunfilePath(kAssertClosedWorkload),
+ {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {},
+ W_EXITCODE(0, 0), "");
+}
+
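+// Matches Linux's MAXSYMLINKS (include/linux/namei.h).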
+constexpr int kLinuxMaxSymlinks = 40;
+
+TEST(ExecTest, SymlinkLimitExceeded) {
+ std::string path = RunfilePath(kBasicWorkload);
+
+ // Hold onto TempPath objects so they are not destructed prematurely.
+ std::vector<TempPath> symlinks;
+ for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) {
+ symlinks.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path)));
+ path = symlinks[i].path();
+ }
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ELOOP);
+}
+
+TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) {
+ std::string tmp_dir = "/tmp";
+ std::string interpreter_path = "/bin/echo";
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ tmp_dir, absl::StrCat("#!", interpreter_path), 0755));
+ std::string script_path = script.path();
+
+ // Hold onto TempPath objects so they are not destructed prematurely.
+ std::vector<TempPath> interpreter_symlinks;
+ std::vector<TempPath> script_symlinks;
+ // Replace both the interpreter and script paths with symlink chains of just
+ // over half the symlink limit each; this is the minimum required to test that
+ // the symlink limit applies separately to each traversal, while tolerating
+ // some symlinks in the resolution of (the original) interpreter_path and
+ // script_path.
+ for (int i = 0; i < (kLinuxMaxSymlinks / 2) + 1; i++) {
+ interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(tmp_dir, interpreter_path)));
+ interpreter_path = interpreter_symlinks[i].path();
+ script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(tmp_dir, script_path)));
+ script_path = script_symlinks[i].path();
+ }
+
+ CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), "");
+}
+
+TEST(ExecveatTest, BasicWithFDCWD) {
+ std::string path = RunfilePath(kBasicWorkload);
+ CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0),
+ absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, Basic) {
+ std::string absolute_path = RunfilePath(kBasicWorkload);
+ std::string parent_dir = std::string(Dirname(absolute_path));
+ std::string base = std::string(Basename(absolute_path));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
+
+ CheckExecveat(dirfd.get(), base, {absolute_path}, {}, /*flags=*/0,
+ ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n"));
+}
+
+TEST(ExecveatTest, FDNotADirectory) {
+ std::string absolute_path = RunfilePath(kBasicWorkload);
+ std::string base = std::string(Basename(absolute_path));
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), base, {absolute_path}, {},
+ /*flags=*/0, /*child=*/nullptr,
+ &execve_errno));
+ EXPECT_EQ(execve_errno, ENOTDIR);
+}
+
+TEST(ExecveatTest, AbsolutePathWithFDCWD) {
+ std::string path = RunfilePath(kBasicWorkload);
+  CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0,
+                ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, AbsolutePath) {
+ std::string path = RunfilePath(kBasicWorkload);
+ // File descriptor should be ignored when an absolute path is given.
+ const int32_t badFD = -1;
+  CheckExecveat(badFD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0),
+                absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, EmptyPathBasic) {
+ std::string path = RunfilePath(kBasicWorkload);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH));
+
+ CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0),
+ absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, EmptyPathWithDirFD) {
+ std::string path = RunfilePath(kBasicWorkload);
+ std::string parent_dir = std::string(Dirname(path));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), "", {path}, {},
+ AT_EMPTY_PATH,
+ /*child=*/nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, EACCES);
+}
+
+TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) {
+ std::string path = RunfilePath(kBasicWorkload);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(
+ fd.get(), "", {path}, {}, /*flags=*/0, /*child=*/nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ENOENT);
+}
+
+TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) {
+ std::string path = RunfilePath(kBasicWorkload);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH));
+
+ CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH,
+ ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, RelativePathWithEmptyPathFlag) {
+ std::string absolute_path = RunfilePath(kBasicWorkload);
+ std::string parent_dir = std::string(Dirname(absolute_path));
+ std::string base = std::string(Basename(absolute_path));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
+
+ CheckExecveat(dirfd.get(), base, {absolute_path}, {}, AT_EMPTY_PATH,
+ ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n"));
+}
+
+TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) {
+ std::string parent_dir = "/tmp";
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload)));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
+ std::string base = std::string(Basename(link.path()));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {},
+ AT_SYMLINK_NOFOLLOW,
+ /*child=*/nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ELOOP);
+}
+
+TEST(ExecveatTest, UnshareFiles) {
+ TempPath tempFile = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "bar", 0755));
+ const FileDescriptor fd_closed_on_exec =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC));
+
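+  // CLONE_FILES shares the FD table with the child; execve must unshare it
+  // before applying close-on-exec, so the parent's copy of the fd must remain
+  // valid (checked by the fstat below).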
+ ExecveArray argv = {"test"};
+ ExecveArray envp;
+ std::string child_path = RunfilePath(kBasicWorkload);
+ pid_t child =
+ syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, 0, 0, 0, 0);
+ if (child == 0) {
+ execve(child_path.c_str(), argv.get(), envp.get());
+ _exit(1);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_EQ(status, 0);
+
+ struct stat st;
+ EXPECT_THAT(fstat(fd_closed_on_exec.get(), &st), SyscallSucceeds());
+}
+
+TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) {
+ std::string parent_dir = "/tmp";
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload)));
+ std::string path = link.path();
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(AT_FDCWD, path, {path}, {},
+ AT_SYMLINK_NOFOLLOW,
+ /*child=*/nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ELOOP);
+}
+
+TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) {
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
+ std::string path = link.path();
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0));
+
+ CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH | AT_SYMLINK_NOFOLLOW,
+ ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, SymlinkNoFollowIgnoreSymlinkAncestor) {
+ TempPath parent_link =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", "/bin"));
+ std::string path_with_symlink = JoinPath(parent_link.path(), "echo");
+
+ CheckExecveat(AT_FDCWD, path_with_symlink, {path_with_symlink}, {},
+ AT_SYMLINK_NOFOLLOW, ArgEnvExitStatus(0, 0), "");
+}
+
+TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) {
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/bin", O_DIRECTORY));
+
+ CheckExecveat(dirfd.get(), "echo", {"echo"}, {}, AT_SYMLINK_NOFOLLOW,
+ ArgEnvExitStatus(0, 0), "");
+}
+
+TEST(ExecveatTest, BasicWithCloexecFD) {
+ std::string path = RunfilePath(kBasicWorkload);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC));
+
+ CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH,
+ ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, InterpreterScriptWithCloexecFD) {
+ std::string path = RunfilePath(kExitScript);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {},
+ AT_EMPTY_PATH, /*child=*/nullptr,
+ &execve_errno));
+ EXPECT_EQ(execve_errno, ENOENT);
+}
+
+TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) {
+ std::string absolute_path = RunfilePath(kExitScript);
+ std::string parent_dir = std::string(Dirname(absolute_path));
+ std::string base = std::string(Basename(absolute_path));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {},
+ /*flags=*/0, /*child=*/nullptr,
+ &execve_errno));
+ EXPECT_EQ(execve_errno, ENOENT);
+}
+
+TEST(ExecveatTest, InvalidFlags) {
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(
+ /*dirfd=*/-1, "", {}, {}, /*flags=*/0xFFFF, /*child=*/nullptr,
+ &execve_errno));
+ EXPECT_EQ(execve_errno, EINVAL);
+}
+
+// Priority is consistent across calls to execve().
+TEST(GetpriorityTest, ExecveMaintainsPriority) {
+ int prio = 16;
+ ASSERT_THAT(setpriority(PRIO_PROCESS, getpid(), prio), SyscallSucceeds());
+
+  // To avoid trying to use negative exit values, check for 20 - prio. Since
+  // prio should always be in the range [-20, 19], this leaves
+  // expected_exit_code in the range [1, 40].
+ int expected_exit_code = 20 - prio;
+
+  // The workload (priority_execve) exits with 20 - getpriority(PRIO_PROCESS,
+  // 0); check that this matches expected_exit_code above.
+ CheckExec(RunfilePath(kPriorityWorkload), {RunfilePath(kPriorityWorkload)},
+ {}, W_EXITCODE(expected_exit_code, 0), "");
+}
+
+void ExecWithThread() {
+ // Used to ensure that the thread has actually started.
+ absl::Mutex mu;
+ bool started = false;
+
+ ScopedThread t([&] {
+ mu.Lock();
+ started = true;
+ mu.Unlock();
+
+ while (true) {
+ pause();
+ }
+ });
+
+ mu.LockWhen(absl::Condition(&started));
+ mu.Unlock();
+
+ const ExecveArray argv = {"/proc/self/exe", kExit42};
+ const ExecveArray envv;
+
+ execve("/proc/self/exe", argv.get(), envv.get());
+ exit(errno);
+}
+
+void ExecFromThread() {
+ ScopedThread t([] {
+ const ExecveArray argv = {"/proc/self/exe", kExit42};
+ const ExecveArray envv;
+
+ execve("/proc/self/exe", argv.get(), envv.get());
+ exit(errno);
+ });
+
+ while (true) {
+ pause();
+ }
+}
+
+bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) {
+ auto contents_or = GetContents("/proc/self/cmdline");
+ if (!contents_or.ok()) {
+ std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error()
+ << std::endl;
+ return false;
+ }
+ auto contents = contents_or.ValueOrDie();
+ if (contents.back() != '\0') {
+ std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl;
+ return false;
+ }
+ contents.pop_back();
+ std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0');
+
+ if (static_cast<int>(procfs_cmdline.size()) != argc) {
+ std::cerr << "argc = " << argc << " != " << procfs_cmdline.size()
+ << std::endl;
+ return false;
+ }
+
+ for (int i = 0; i < argc; ++i) {
+ if (procfs_cmdline[i] != argv[i]) {
+ std::cerr << "Procfs command line argument " << i << " mismatch "
+ << procfs_cmdline[i] << " != " << argv[i] << std::endl;
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // Start by validating that the stack argv is consistent with procfs.
+ if (!gvisor::testing::ValidateProcCmdlineVsArgv(argc, argv)) {
+ return 1;
+ }
+
+  // Some of these tests require that no background threads exist, so handle
+  // their flags before TestInit.
+ for (int i = 0; i < argc; i++) {
+ absl::string_view arg(argv[i]);
+
+ if (arg == gvisor::testing::kExit42) {
+ return 42;
+ }
+ if (arg == gvisor::testing::kExecWithThread) {
+ gvisor::testing::ExecWithThread();
+ return 1;
+ }
+ if (arg == gvisor::testing::kExecFromThread) {
+ gvisor::testing::ExecFromThread();
+ return 1;
+ }
+ }
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/exec.h b/test/syscalls/linux/exec.h
new file mode 100644
index 000000000..5c0f7e654
--- /dev/null
+++ b/test/syscalls/linux/exec.h
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_EXEC_H_
+#define GVISOR_TEST_SYSCALLS_EXEC_H_
+
+#include <sys/wait.h>
+
+namespace gvisor {
+namespace testing {
+
+// Returns the exit code used by exec_basic_workload.
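+// For example, 2 arguments and 1 environment variable yield 2 + 1 * 10 = 12.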
+inline int ArgEnvExitCode(int args, int envs) { return args + envs * 10; }
+
+// Returns the exit status used by exec_basic_workload.
+inline int ArgEnvExitStatus(int args, int envs) {
+ return W_EXITCODE(ArgEnvExitCode(args, envs), 0);
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_EXEC_H_
diff --git a/test/syscalls/linux/exec_assert_closed_workload.cc b/test/syscalls/linux/exec_assert_closed_workload.cc
new file mode 100644
index 000000000..95643618d
--- /dev/null
+++ b/test/syscalls/linux/exec_assert_closed_workload.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <iostream>
+
+#include "absl/strings/numbers.h"
+
+int main(int argc, char** argv) {
+ if (argc != 2) {
+    std::cerr << "need two arguments, got " << argc << std::endl;
+ exit(1);
+ }
+ int fd;
+ if (!absl::SimpleAtoi(argv[1], &fd)) {
+ std::cerr << "fd: " << argv[1] << " could not be parsed" << std::endl;
+ exit(1);
+ }
+ struct stat s;
+ if (fstat(fd, &s) == 0) {
+ std::cerr << "fd: " << argv[1] << " should not be valid" << std::endl;
+ exit(2);
+ }
+ if (errno != EBADF) {
+ std::cerr << "fstat fd: " << argv[1] << " got errno: " << errno
+ << " wanted: " << EBADF << std::endl;
+ exit(1);
+ }
+ return 0;
+}
diff --git a/test/syscalls/linux/exec_basic_workload.cc b/test/syscalls/linux/exec_basic_workload.cc
new file mode 100644
index 000000000..1bbd6437e
--- /dev/null
+++ b/test/syscalls/linux/exec_basic_workload.cc
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+
+#include <iostream>
+
+#include "test/syscalls/linux/exec.h"
+
+int main(int argc, char** argv, char** envp) {
+ int i;
+ for (i = 0; i < argc; i++) {
+ std::cerr << argv[i] << std::endl;
+ }
+ for (i = 0; envp[i] != nullptr; i++) {
+ std::cerr << envp[i] << std::endl;
+ }
+ exit(gvisor::testing::ArgEnvExitCode(argc - 1, i));
+ return 0;
+}
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
new file mode 100644
index 000000000..18d2f22c1
--- /dev/null
+++ b/test/syscalls/linux/exec_binary.cc
@@ -0,0 +1,1646 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <elf.h>
+#include <errno.h>
+#include <signal.h>
+#include <sys/ptrace.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/user.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+#if !defined(__x86_64__) && !defined(__aarch64__)
+// The assembly stub and ELF internal details must be ported to other arches.
+#error "Test only supported on x86-64/arm64"
+#endif // __x86_64__ || __aarch64__
+
+#if defined(__x86_64__)
+#define EM_TYPE EM_X86_64
+#define IP_REG(p) ((p).rip)
+#define RAX_REG(p) ((p).rax)
+#define RDI_REG(p) ((p).rdi)
+#define RETURN_REG(p) ((p).rax)
+
+// amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP.
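+// The trailing getpid/kill(SIGSTOP) sequence lets the parent synchronize on
+// the stop via waitpid (see WaitStopped below).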
+const char kPtraceCode[] = {
+ // movq $101, %rax /* ptrace */
+ '\x48',
+ '\xc7',
+ '\xc0',
+ '\x65',
+ '\x00',
+ '\x00',
+ '\x00',
+    // movq $0, %rsi /* pid */
+ '\x48',
+ '\xc7',
+ '\xc6',
+ '\x00',
+ '\x00',
+ '\x00',
+ '\x00',
+    // movq $0, %rdi /* PTRACE_TRACEME */
+ '\x48',
+ '\xc7',
+ '\xc7',
+ '\x00',
+ '\x00',
+ '\x00',
+ '\x00',
+    // movq $0, %rdx /* addr */
+ '\x48',
+ '\xc7',
+ '\xc2',
+ '\x00',
+ '\x00',
+ '\x00',
+ '\x00',
+    // movq $0, %r10 /* data */
+ '\x49',
+ '\xc7',
+ '\xc2',
+ '\x00',
+ '\x00',
+ '\x00',
+ '\x00',
+ // syscall
+ '\x0f',
+ '\x05',
+
+ // movq $39, %rax /* getpid */
+ '\x48',
+ '\xc7',
+ '\xc0',
+ '\x27',
+ '\x00',
+ '\x00',
+ '\x00',
+ // syscall
+ '\x0f',
+ '\x05',
+
+ // movq %rax, %rdi /* pid */
+ '\x48',
+ '\x89',
+ '\xc7',
+ // movq $62, %rax /* kill */
+ '\x48',
+ '\xc7',
+ '\xc0',
+ '\x3e',
+ '\x00',
+ '\x00',
+ '\x00',
+ // movq $19, %rsi /* SIGSTOP */
+ '\x48',
+ '\xc7',
+ '\xc6',
+ '\x13',
+ '\x00',
+ '\x00',
+ '\x00',
+ // syscall
+ '\x0f',
+ '\x05',
+};
+
+// Size of a syscall instruction.
+constexpr int kSyscallSize = 2;
+
+#elif defined(__aarch64__)
+#define EM_TYPE EM_AARCH64
+#define IP_REG(p) ((p).pc)
+#define RAX_REG(p) ((p).regs[8])
+#define RDI_REG(p) ((p).regs[0])
+#define RETURN_REG(p) ((p).regs[0])
+
+const char kPtraceCode[] = {
+ // MOVD $117, R8 /* ptrace */
+ '\xa8',
+ '\x0e',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R0 /* PTRACE_TRACEME */
+ '\x00',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R1 /* pid */
+ '\x01',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R2 /* addr */
+ '\x02',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R3 /* data */
+ '\x03',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $172, R8 /* getpid */
+ '\x88',
+ '\x15',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $129, R8 /* kill, R0=pid */
+ '\x28',
+ '\x10',
+ '\x80',
+ '\xd2',
+ // MOVD $19, R1 /* SIGSTOP */
+ '\x61',
+ '\x02',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+};
+// Size of a syscall instruction.
+constexpr int kSyscallSize = 4;
+#else
+#error "Unknown architecture"
+#endif
+
+// This test suite tests executable loading in the kernel (ELF and interpreter
+// scripts).
+
+// Parameterized ELF types for 64 and 32 bit.
+template <int Size>
+struct ElfTypes;
+
+template <>
+struct ElfTypes<64> {
+ typedef Elf64_Ehdr ElfEhdr;
+ typedef Elf64_Phdr ElfPhdr;
+};
+
+template <>
+struct ElfTypes<32> {
+ typedef Elf32_Ehdr ElfEhdr;
+ typedef Elf32_Phdr ElfPhdr;
+};
+
+template <int Size>
+struct ElfBinary {
+ using ElfEhdr = typename ElfTypes<Size>::ElfEhdr;
+ using ElfPhdr = typename ElfTypes<Size>::ElfPhdr;
+
+ ElfEhdr header = {};
+ std::vector<ElfPhdr> phdrs;
+ std::vector<char> data;
+
+ // UpdateOffsets updates p_offset, p_vaddr in all phdrs to account for the
+ // space taken by the header and phdrs.
+ //
+ // It also updates header.e_phnum and adds the offset to header.e_entry to
+ // account for the headers residing in the first PT_LOAD segment.
+ //
+ // Before calling UpdateOffsets each of those fields should be the appropriate
+ // offset into data.
+ void UpdateOffsets() {
+ size_t offset = sizeof(header) + phdrs.size() * sizeof(ElfPhdr);
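+    // e.g. for a 64-bit ELF with two phdrs: 64 + 2 * 56 = 176 bytes.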
+ header.e_entry += offset;
+ header.e_phnum = phdrs.size();
+ for (auto& p : phdrs) {
+ p.p_offset += offset;
+ p.p_vaddr += offset;
+ }
+ }
+
+ // AddInterpreter adds a PT_INTERP segment with the passed contents.
+ //
+ // A later call to UpdateOffsets is required to make the new phdr valid.
+ void AddInterpreter(std::vector<char> contents) {
+ const int start = data.size();
+ data.insert(data.end(), contents.begin(), contents.end());
+ const int size = data.size() - start;
+
+ ElfPhdr phdr = {};
+ phdr.p_type = PT_INTERP;
+ phdr.p_offset = start;
+ phdr.p_filesz = size;
+ phdr.p_memsz = size;
+ // "If [PT_INTERP] is present, it must precede any loadable segment entry."
+ phdrs.insert(phdrs.begin(), phdr);
+ }
+
+ // Writes the header, phdrs, and data to fd.
+ PosixError Write(int fd) const {
+ int ret = WriteFd(fd, &header, sizeof(header));
+ if (ret < 0) {
+ return PosixError(errno, "failed to write header");
+ } else if (ret != sizeof(header)) {
+ return PosixError(EIO, absl::StrCat("short write of header: ", ret));
+ }
+
+ for (auto const& p : phdrs) {
+ ret = WriteFd(fd, &p, sizeof(p));
+ if (ret < 0) {
+ return PosixError(errno, "failed to write phdr");
+ } else if (ret != sizeof(p)) {
+ return PosixError(EIO, absl::StrCat("short write of phdr: ", ret));
+ }
+ }
+
+ ret = WriteFd(fd, data.data(), data.size());
+ if (ret < 0) {
+ return PosixError(errno, "failed to write data");
+ } else if (ret != static_cast<int>(data.size())) {
+ return PosixError(EIO, absl::StrCat("short write of data: ", ret));
+ }
+
+ return NoError();
+ }
+};
+
+// Creates a new temporary executable ELF file in parent with elf as the
+// contents.
+template <int Size>
+PosixErrorOr<TempPath> CreateElfWith(absl::string_view parent,
+ ElfBinary<Size> const& elf) {
+ ASSIGN_OR_RETURN_ERRNO(
+ auto file, TempPath::CreateFileWith(parent, absl::string_view(), 0755));
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Open(file.path(), O_RDWR));
+ RETURN_IF_ERRNO(elf.Write(fd.get()));
+ return std::move(file);
+}
+
+// Creates a new temporary executable ELF file with elf as the contents.
+template <int Size>
+PosixErrorOr<TempPath> CreateElfWith(ElfBinary<Size> const& elf) {
+ return CreateElfWith(GetAbsoluteTestTmpdir(), elf);
+}
+
+// Wait for pid to stop, and assert that it stopped via SIGSTOP.
+PosixError WaitStopped(pid_t pid) {
+ int status;
+ int ret = RetryEINTR(waitpid)(pid, &status, 0);
+ MaybeSave();
+ if (ret < 0) {
+ return PosixError(errno, "wait failed");
+ } else if (ret != pid) {
+ return PosixError(ESRCH, absl::StrCat("wait got ", ret, " want ", pid));
+ }
+
+ if (!WIFSTOPPED(status) || WSTOPSIG(status) != SIGSTOP) {
+ return PosixError(EINVAL,
+ absl::StrCat("pid did not SIGSTOP; status = ", status));
+ }
+
+ return NoError();
+}
+
+// Returns a valid ELF that calls PTRACE_TRACEME and SIGSTOPs itself.
+//
+// UpdateOffsets must be called before writing this ELF.
+ElfBinary<64> StandardElf() {
+ ElfBinary<64> elf;
+ elf.header.e_ident[EI_MAG0] = ELFMAG0;
+ elf.header.e_ident[EI_MAG1] = ELFMAG1;
+ elf.header.e_ident[EI_MAG2] = ELFMAG2;
+ elf.header.e_ident[EI_MAG3] = ELFMAG3;
+ elf.header.e_ident[EI_CLASS] = ELFCLASS64;
+ elf.header.e_ident[EI_DATA] = ELFDATA2LSB;
+ elf.header.e_ident[EI_VERSION] = EV_CURRENT;
+ elf.header.e_type = ET_EXEC;
+ elf.header.e_machine = EM_TYPE;
+ elf.header.e_version = EV_CURRENT;
+ elf.header.e_phoff = sizeof(elf.header);
+ elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr);
+
+  // TODO(gvisor.dev/issue/153): Always include a PT_GNU_STACK segment to
+  // disable executable stacks. With it omitted, the stack (and all PROT_READ
+  // mappings) should be executable, but gVisor doesn't support that.
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_GNU_STACK;
+ phdr.p_flags = PF_R | PF_W;
+ elf.phdrs.push_back(phdr);
+
+ phdr = {};
+ phdr.p_type = PT_LOAD;
+ phdr.p_flags = PF_R | PF_X;
+ phdr.p_offset = 0;
+ phdr.p_vaddr = 0x40000;
+ phdr.p_filesz = sizeof(kPtraceCode);
+ phdr.p_memsz = phdr.p_filesz;
+ elf.phdrs.push_back(phdr);
+
+ elf.header.e_entry = phdr.p_vaddr;
+
+ elf.data.assign(kPtraceCode, kPtraceCode + sizeof(kPtraceCode));
+
+ return elf;
+}
+
+// Test that a trivial binary executes.
+TEST(ElfTest, Execute) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ // Ensure it made it to SIGSTOP.
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+ // RIP/PC is just beyond the final syscall instruction.
+ EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode));
+
+ EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ })));
+}
+
+// StandardElf without data completes execve, but faults once running.
+TEST(ElfTest, MissingText) {
+ ElfBinary<64> elf = StandardElf();
+ elf.data.clear();
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ // It runs off the end of the zeroes filling the end of the page.
+#if defined(__x86_64__)
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status;
+#elif defined(__aarch64__)
+ // 0 is an invalid instruction opcode on arm64.
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGILL) << status;
+#endif
+}
+
+// Typical ELF with a data + bss segment
+TEST(ElfTest, DataSegment) {
+ ElfBinary<64> elf = StandardElf();
+
+ // Create a standard ELF, but extend to 1.5 pages. The second page will be the
+ // beginning of a multi-page data + bss segment.
+ elf.data.resize(kPageSize + kPageSize / 2);
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ phdr.p_flags = PF_R | PF_W;
+ phdr.p_offset = kPageSize;
+ phdr.p_vaddr = 0x41000;
+ phdr.p_filesz = kPageSize / 2;
+ // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a
+ // bit less than 2 pages so this mapping doesn't extend beyond 0x43000.
+ phdr.p_memsz = 2 * kPageSize - kPageSize / 2;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // text page.
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ // data + bss page from file.
+ {0x41000, 0x42000, true, true, false, true, kPageSize, 0, 0, 0,
+ file.path().c_str()},
+ // bss page from anon.
+ {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""},
+ })));
+}
+
+// Additional pages beyond filesz honor (only) execute protections.
+//
+// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the
+// entire heap executable"). Previously, extra pages were always RW.
+TEST(ElfTest, ExtraMemPages) {
+ // gVisor has the newer behavior.
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11));
+ }
+
+ ElfBinary<64> elf = StandardElf();
+
+ // Create a standard ELF, but extend to 1.5 pages. The second page will be the
+ // beginning of a multi-page data + bss segment.
+ elf.data.resize(kPageSize + kPageSize / 2);
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ // RWX segment. The extra anon page will also be RWX.
+ //
+ // N.B. Linux uses clear_user to clear the end of the file-mapped page, which
+ // respects the mapping protections. Thus if we map this RO with memsz >
+ // (unaligned) filesz, then execve will fail with EFAULT. See padzero(elf_bss)
+ // in fs/binfmt_elf.c:load_elf_binary.
+ //
+ // N.N.B.B. The above only applies to the last segment. For earlier segments,
+ // the clear_user error is ignored.
+ phdr.p_flags = PF_R | PF_W | PF_X;
+ phdr.p_offset = kPageSize;
+ phdr.p_vaddr = 0x41000;
+ phdr.p_filesz = kPageSize / 2;
+ // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a
+ // bit less than 2 pages so this mapping doesn't extend beyond 0x43000.
+ phdr.p_memsz = 2 * kPageSize - kPageSize / 2;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ EXPECT_THAT(child,
+ ContainsMappings(std::vector<ProcMapsEntry>({
+ // text page.
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ // data + bss page from file.
+ {0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0,
+ file.path().c_str()},
+ // extra page from anon.
+ {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""},
+ })));
+}
+
+// An aligned segment with filesz == 0, memsz > 0 is anon-only.
+TEST(ElfTest, AnonOnlySegment) {
+ ElfBinary<64> elf = StandardElf();
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ // RO segment. The extra anon page will be RW anyways.
+ phdr.p_flags = PF_R;
+ phdr.p_offset = 0;
+ phdr.p_vaddr = 0x41000;
+ phdr.p_filesz = 0;
+ phdr.p_memsz = kPageSize;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ // UpdateOffsets adjusts p_vaddr and p_offset by the header size, but we need
+ // a page-aligned p_vaddr to get a truly anon-only page.
+ elf.phdrs[2].p_vaddr = 0x41000;
+ // N.B. p_offset is now unaligned, but Linux doesn't care since this is
+ // anon-only.
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ EXPECT_THAT(child,
+ ContainsMappings(std::vector<ProcMapsEntry>({
+ // text page.
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ // anon page.
+ {0x41000, 0x42000, true, true, false, true, 0, 0, 0, 0, ""},
+ })));
+}
+
+// p_offset must have the same alignment as p_vaddr.
+TEST(ElfTest, UnalignedOffset) {
+ ElfBinary<64> elf = StandardElf();
+
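+  // phdrs[0] is StandardElf's PT_GNU_STACK segment; phdrs[1] is the PT_LOAD
+  // text segment, whose p_offset must share p_vaddr's page alignment.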
+ // Unaligned offset.
+ elf.phdrs[1].p_offset += 1;
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+
+  // execve(2) returns EINVAL, but behavior varies between Linux and gVisor.
+ //
+ // On Linux, the new mm is committed before attempting to map into it. By the
+ // time we hit EINVAL in the segment mmap, the old mm is gone. Linux returns
+ // to an empty mm, which immediately segfaults.
+ //
+ // OTOH, gVisor maps into the new mm before committing it. Thus when it hits
+ // failure, the caller is still intact to receive the error.
+ if (IsRunningOnGvisor()) {
+ ASSERT_EQ(execve_errno, EINVAL);
+ } else {
+ ASSERT_EQ(execve_errno, 0);
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status;
+ }
+}
+
+// Linux allows PT_LOAD segments to overlap.
+TEST(ElfTest, DirectlyOverlappingSegments) {
+ // NOTE(b/37289926): see PIEOutOfOrderSegments.
+ SKIP_IF(IsRunningOnGvisor());
+
+ ElfBinary<64> elf = StandardElf();
+
+ // Same as the StandardElf mapping.
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ // Add PF_W so we can differentiate this mapping from the first.
+ phdr.p_flags = PF_R | PF_W | PF_X;
+ phdr.p_offset = 0;
+ phdr.p_vaddr = 0x40000;
+ phdr.p_filesz = sizeof(kPtraceCode);
+ phdr.p_memsz = phdr.p_filesz;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
+ {0x40000, 0x41000, true, true, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ })));
+}
+
+// Linux allows out-of-order PT_LOAD segments.
+TEST(ElfTest, OutOfOrderSegments) {
+ // NOTE(b/37289926): see PIEOutOfOrderSegments.
+ SKIP_IF(IsRunningOnGvisor());
+
+ ElfBinary<64> elf = StandardElf();
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ phdr.p_flags = PF_R | PF_X;
+ phdr.p_offset = 0;
+ phdr.p_vaddr = 0x20000;
+ phdr.p_filesz = sizeof(kPtraceCode);
+ phdr.p_memsz = phdr.p_filesz;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
+ {0x20000, 0x21000, true, false, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ file.path().c_str()},
+ })));
+}
+
+// header.e_phoff is bounded by the end of the file.
+TEST(ElfTest, OutOfBoundsPhdrs) {
+ ElfBinary<64> elf = StandardElf();
+ elf.header.e_phoff = 0x100000;
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ // On Linux 3.11, this caused EIO. On newer Linux, it causes ENOEXEC.
+ EXPECT_THAT(execve_errno, AnyOf(Eq(ENOEXEC), Eq(EIO)));
+}
+
+// Claim there is a phdr beyond the end of the file, but don't include it.
+TEST(ElfTest, MissingPhdr) {
+ ElfBinary<64> elf = StandardElf();
+
+ // Clear data so the file ends immediately after the phdrs.
+ // N.B. Per ElfTest.MissingData, StandardElf without data completes execve
+ // without error.
+ elf.data.clear();
+ elf.UpdateOffsets();
+
+ // Claim that there is another phdr just beyond the end of the file. Of
+ // course, it isn't accessible.
+ elf.header.e_phnum++;
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ // On Linux 3.11, this caused EIO. On newer Linux, it causes ENOEXEC.
+ EXPECT_THAT(execve_errno, AnyOf(Eq(ENOEXEC), Eq(EIO)));
+}
+
+// No headers at all, just the ELF magic.
+TEST(ElfTest, MissingHeader) {
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0755));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ const char kElfMagic[] = {0x7f, 'E', 'L', 'F'};
+
+ ASSERT_THAT(WriteFd(fd.get(), &kElfMagic, sizeof(kElfMagic)),
+ SyscallSucceeds());
+ fd.reset();
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, ENOEXEC);
+}
+
+// Load a PIE ELF with a data + bss segment.
+TEST(ElfTest, PIE) {
+ ElfBinary<64> elf = StandardElf();
+
+ elf.header.e_type = ET_DYN;
+
+ // Create a standard ELF, but extend it to 1.5 pages. The second page will be
+ // the beginning of a multi-page data + bss segment.
+ elf.data.resize(kPageSize + kPageSize / 2);
+
+ elf.header.e_entry = 0x0;
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ phdr.p_flags = PF_R | PF_W;
+ phdr.p_offset = kPageSize;
+ // Put the data segment at a bit of an offset.
+ phdr.p_vaddr = 0x20000;
+ phdr.p_filesz = kPageSize / 2;
+ // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a
+ // bit less than 2 pages so this mapping doesn't extend beyond the bss page
+ // checked below (ending at load_addr + 0x22000).
+ phdr.p_memsz = 2 * kPageSize - kPageSize / 2;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ // The first segment really needs to start at 0 for a normal PIE binary, and
+ // thus includes the headers.
+ const uint64_t offset = elf.phdrs[1].p_offset;
+ elf.phdrs[1].p_offset = 0x0;
+ elf.phdrs[1].p_vaddr = 0x0;
+ elf.phdrs[1].p_filesz += offset;
+ elf.phdrs[1].p_memsz += offset;
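+ // Resulting layout (illustrative): the first PT_LOAD now spans the headers at
+ // offset 0 / vaddr 0, with the data + bss PT_LOAD following at vaddr 0x20000.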
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ // RIP tells us which page the first segment was loaded into.
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
+
+ EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // text page.
+ {load_addr, load_addr + 0x1000, true, false, true,
+ true, 0, 0, 0, 0, file.path().c_str()},
+ // data + bss page from file.
+ {load_addr + 0x20000, load_addr + 0x21000, true, true,
+ false, true, kPageSize, 0, 0, 0, file.path().c_str()},
+ // bss page from anon.
+ {load_addr + 0x21000, load_addr + 0x22000, true, true,
+ false, true, 0, 0, 0, 0, ""},
+ })));
+}
+
+// PIE binary with a non-zero start address.
+//
+// This is non-standard for a PIE binary, but valid. The binary is still loaded
+// at an arbitrary address, not the first PT_LOAD vaddr.
+//
+// N.B. Linux changed this behavior in d1fd836dcf00d2028c700c7e44d2c23404062c90.
+// Previously, with "randomization" enabled, PIE binaries with a non-zero start
+// address would be loaded at the address they specified because mmap was
+// passed the load address, which wasn't 0 as expected.
+//
+// This change is present in kernel v4.1+.
+TEST(ElfTest, PIENonZeroStart) {
+ // gVisor has the newer behavior.
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 1));
+ }
+
+ ElfBinary<64> elf = StandardElf();
+
+ elf.header.e_type = ET_DYN;
+
+ // Create a standard ELF, but extend it to 1.5 pages. The second page will be
+ // the beginning of a multi-page data + bss segment.
+ elf.data.resize(kPageSize + kPageSize / 2);
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ phdr.p_flags = PF_R | PF_W;
+ phdr.p_offset = kPageSize;
+ // Put the data segment at a bit of an offset.
+ phdr.p_vaddr = 0x60000;
+ phdr.p_filesz = kPageSize / 2;
+ // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a
+ // bit less than 2 pages so this mapping doesn't extend beyond the bss page
+ // checked below (ending at load_addr + 0x22000).
+ phdr.p_memsz = 2 * kPageSize - kPageSize / 2;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ // RIP tells us which page the first segment was loaded into.
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
+
+ // The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr.
+ //
+ // N.B. this is technically flaky, but Linux is *extremely* unlikely to pick
+ // this as the start address, as it searches from the top down.
+ EXPECT_NE(load_addr, 0x40000);
+
+ EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // text page.
+ {load_addr, load_addr + 0x1000, true, false, true,
+ true, 0, 0, 0, 0, file.path().c_str()},
+ // data + bss page from file.
+ {load_addr + 0x20000, load_addr + 0x21000, true, true,
+ false, true, kPageSize, 0, 0, 0, file.path().c_str()},
+ // bss page from anon.
+ {load_addr + 0x21000, load_addr + 0x22000, true, true,
+ false, true, 0, 0, 0, 0, ""},
+ })));
+}
+
+TEST(ElfTest, PIEOutOfOrderSegments) {
+ // TODO(b/37289926): This triggers a bug in Linux where it computes the size
+ // of the binary as 0x20000 - 0x40000 = 0xfffffffffffe0000, which obviously
+ // fails to map.
+ //
+ // We test gVisor's behavior (of rejecting the binary) because I assert that
+ // Linux is wrong and needs to be fixed.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ ElfBinary<64> elf = StandardElf();
+
+ elf.header.e_type = ET_DYN;
+
+ // Create a standard ELF, but extend it to 1.5 pages. The second page will be
+ // the beginning of a multi-page data + bss segment.
+ elf.data.resize(kPageSize + kPageSize / 2);
+
+ decltype(elf)::ElfPhdr phdr = {};
+ phdr.p_type = PT_LOAD;
+ phdr.p_flags = PF_R | PF_W;
+ phdr.p_offset = kPageSize;
+ // Put the data segment *before* the first segment.
+ phdr.p_vaddr = 0x20000;
+ phdr.p_filesz = kPageSize / 2;
+ // The header is going to push vaddr up by a few hundred bytes. Keep p_memsz a
+ // bit less than 2 pages so the segment spans exactly two pages.
+ phdr.p_memsz = 2 * kPageSize - kPageSize / 2;
+ elf.phdrs.push_back(phdr);
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, ENOEXEC);
+}
+
+// Standard dynamically linked binary with an ELF interpreter.
+TEST(ElfTest, ELFInterpreter) {
+ ElfBinary<64> interpreter = StandardElf();
+ interpreter.header.e_type = ET_DYN;
+ interpreter.header.e_entry = 0x0;
+ interpreter.UpdateOffsets();
+
+ // The first segment really needs to start at 0 for a normal PIE binary, and
+ // thus includes the headers.
+ uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // N.B. Since Linux 4.10 (0036d1f7eb95b "binfmt_elf: fix calculations for bss
+ // padding"), Linux unconditionally zeroes the remainder of the highest mapped
+ // page in an interpreter, failing if the protections don't allow write. Thus
+ // we must mark this writeable.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
+ interpreter.phdrs[1].p_offset = 0x0;
+ interpreter.phdrs[1].p_vaddr = 0x0;
+ interpreter.phdrs[1].p_filesz += offset;
+ interpreter.phdrs[1].p_memsz += offset;
+
+ TempPath interpreter_file =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
+
+ ElfBinary<64> binary = StandardElf();
+
+ // Append the interpreter path.
+ int const interp_data_start = binary.data.size();
+ for (char const c : interpreter_file.path()) {
+ binary.data.push_back(c);
+ }
+ // NUL-terminate.
+ binary.data.push_back(0);
+ int const interp_data_size = binary.data.size() - interp_data_start;
+
+ decltype(binary)::ElfPhdr phdr = {};
+ phdr.p_type = PT_INTERP;
+ phdr.p_offset = interp_data_start;
+ phdr.p_filesz = interp_data_size;
+ phdr.p_memsz = interp_data_size;
+ // "If [PT_INTERP] is present, it must precede any loadable segment entry."
+ //
+ // However, Linux allows it anywhere, so we just stick it at the end to make
+ // sure out-of-order PT_INTERP is OK.
+ binary.phdrs.push_back(phdr);
+
+ binary.UpdateOffsets();
+
+ TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ // RIP tells us which page the first segment of the interpreter was loaded
+ // into.
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
+
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
+}
+
+// Test parameter to ElfInterpreterStaticTest cases. The first item is a suffix to
+// add to the end of the interpreter path in the PT_INTERP segment and the
+// second is the expected execve(2) errno.
+using ElfInterpreterStaticParam = std::tuple<std::vector<char>, int>;
+
+class ElfInterpreterStaticTest
+ : public ::testing::TestWithParam<ElfInterpreterStaticParam> {};
+
+// Statically linked ELF with a statically linked ELF interpreter.
+TEST_P(ElfInterpreterStaticTest, Test) {
+ const std::vector<char> segment_suffix = std::get<0>(GetParam());
+ const int expected_errno = std::get<1>(GetParam());
+
+ ElfBinary<64> interpreter = StandardElf();
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
+ interpreter.UpdateOffsets();
+ TempPath interpreter_file =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
+
+ ElfBinary<64> binary = StandardElf();
+ // The PT_LOAD segment conflicts with the interpreter's PT_LOAD segment. The
+ // interpreter's will be mapped directly over the binary's.
+
+ // Interpreter path plus the parameterized suffix in the PT_INTERP segment.
+ const std::string path = interpreter_file.path();
+ std::vector<char> segment(path.begin(), path.end());
+ segment.insert(segment.end(), segment_suffix.begin(), segment_suffix.end());
+ binary.AddInterpreter(segment);
+
+ binary.UpdateOffsets();
+
+ TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, expected_errno);
+
+ if (expected_errno == 0) {
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Interpreter.
+ {0x40000, 0x41000, true, true, true, true, 0, 0, 0,
+ 0, interpreter_file.path().c_str()},
+ })));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Cases, ElfInterpreterStaticTest,
+ ::testing::ValuesIn({
+ // Simple NUL-terminator to run the interpreter as normal.
+ std::make_tuple(std::vector<char>({'\0'}), 0),
+ // Add some garbage to the segment followed by a NUL-terminator. This is
+ // ignored.
+ std::make_tuple(std::vector<char>({'\0', 'b', '\0'}), 0),
+ // Add some garbage to the segment without a NUL-terminator. Linux will
+ // reject this.
+ std::make_tuple(std::vector<char>({'\0', 'b'}), ENOEXEC),
+ }));
+
+// Test parameter to ElfInterpreterBadPathTest cases. The first item is the
+// contents of the PT_INTERP segment and the second is the expected execve(2)
+// errno.
+using ElfInterpreterBadPathParam = std::tuple<std::vector<char>, int>;
+
+class ElfInterpreterBadPathTest
+ : public ::testing::TestWithParam<ElfInterpreterBadPathParam> {};
+
+TEST_P(ElfInterpreterBadPathTest, Test) {
+ const std::vector<char> segment = std::get<0>(GetParam());
+ const int expected_errno = std::get<1>(GetParam());
+
+ ElfBinary<64> binary = StandardElf();
+ binary.AddInterpreter(segment);
+ binary.UpdateOffsets();
+
+ TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary));
+
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ binary_file.path(), {binary_file.path()}, {}, nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, expected_errno);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Cases, ElfInterpreterBadPathTest,
+ ::testing::ValuesIn({
+ // NUL-terminated fake path in the PT_INTERP segment.
+ std::make_tuple(std::vector<char>({'/', 'f', '/', 'b', '\0'}), ENOENT),
+ // ELF interpreter not NUL-terminated.
+ std::make_tuple(std::vector<char>({'/', 'f', '/', 'b'}), ENOEXEC),
+ // ELF interpreter path omitted entirely.
+ //
+ // fs/binfmt_elf.c:load_elf_binary returns ENOEXEC if p_filesz is < 2
+ // bytes.
+ std::make_tuple(std::vector<char>({'\0'}), ENOEXEC),
+ // ELF interpreter path = "\0".
+ //
+ // fs/binfmt_elf.c:load_elf_binary returns ENOEXEC if p_filesz is < 2
+ // bytes, so add an extra byte to pass that check.
+ //
+ // load_elf_binary -> open_exec -> do_open_execat fails to check that
+ // name != '\0' before calling do_filp_open, which thus opens the
+ // working directory. do_open_execat returns EACCES because the
+ // directory is not a regular file.
+ std::make_tuple(std::vector<char>({'\0', '\0'}), EACCES),
+ }));
+
+// Relative path to ELF interpreter.
+TEST(ElfTest, ELFInterpreterRelative) {
+ ElfBinary<64> interpreter = StandardElf();
+ interpreter.header.e_type = ET_DYN;
+ interpreter.header.e_entry = 0x0;
+ interpreter.UpdateOffsets();
+
+ // The first segment really needs to start at 0 for a normal PIE binary, and
+ // thus includes the headers.
+ uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
+ interpreter.phdrs[1].p_offset = 0x0;
+ interpreter.phdrs[1].p_vaddr = 0x0;
+ interpreter.phdrs[1].p_filesz += offset;
+ interpreter.phdrs[1].p_memsz += offset;
+
+ TempPath interpreter_file =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
+ auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD());
+ auto interpreter_relative =
+ ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, interpreter_file.path()));
+
+ ElfBinary<64> binary = StandardElf();
+
+ // NUL-terminated path in the PT_INTERP segment.
+ std::vector<char> segment(interpreter_relative.begin(),
+ interpreter_relative.end());
+ segment.push_back(0);
+ binary.AddInterpreter(segment);
+
+ binary.UpdateOffsets();
+
+ TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ // RIP tells us which page the first segment of the interpreter was loaded
+ // into.
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
+
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
+}
+
+// ELF interpreter architecture doesn't match the binary.
+TEST(ElfTest, ELFInterpreterWrongArch) {
+ ElfBinary<64> interpreter = StandardElf();
+ interpreter.header.e_machine = EM_PPC64;
+ interpreter.header.e_type = ET_DYN;
+ interpreter.header.e_entry = 0x0;
+ interpreter.UpdateOffsets();
+
+ // The first segment really needs to start at 0 for a normal PIE binary, and
+ // thus includes the headers.
+ uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
+ interpreter.phdrs[1].p_offset = 0x0;
+ interpreter.phdrs[1].p_vaddr = 0x0;
+ interpreter.phdrs[1].p_filesz += offset;
+ interpreter.phdrs[1].p_memsz += offset;
+
+ TempPath interpreter_file =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
+
+ ElfBinary<64> binary = StandardElf();
+
+ // NUL-terminated path in the PT_INTERP segment.
+ const std::string path = interpreter_file.path();
+ std::vector<char> segment(path.begin(), path.end());
+ segment.push_back(0);
+ binary.AddInterpreter(segment);
+
+ binary.UpdateOffsets();
+
+ TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ binary_file.path(), {binary_file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, ELIBBAD);
+}
+
+// No execute permissions on the binary.
+TEST(ElfTest, NoExecute) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ ASSERT_THAT(chmod(file.path().c_str(), 0644), SyscallSucceeds());
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, EACCES);
+}
+
+// Execute permission without read permission on the binary works just fine.
+TEST(ElfTest, NoRead) {
+ // TODO(gvisor.dev/issue/160): gVisor's backing filesystem may prevent the
+ // sentry from reading the executable.
+ SKIP_IF(IsRunningOnGvisor());
+
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ ASSERT_THAT(chmod(file.path().c_str(), 0111), SyscallSucceeds());
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ // TODO(gvisor.dev/issue/160): A task with a non-readable executable is marked
+ // non-dumpable, preventing access to proc files. gVisor does not implement
+ // this behavior.
+}
+
+// No execute permissions on the ELF interpreter.
+TEST(ElfTest, ElfInterpreterNoExecute) {
+ ElfBinary<64> interpreter = StandardElf();
+ interpreter.header.e_type = ET_DYN;
+ interpreter.header.e_entry = 0x0;
+ interpreter.UpdateOffsets();
+
+ // The first segment really needs to start at 0 for a normal PIE binary, and
+ // thus includes the headers.
+ uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
+ interpreter.phdrs[1].p_offset = 0x0;
+ interpreter.phdrs[1].p_vaddr = 0x0;
+ interpreter.phdrs[1].p_filesz += offset;
+ interpreter.phdrs[1].p_memsz += offset;
+
+ TempPath interpreter_file =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
+
+ ElfBinary<64> binary = StandardElf();
+
+ // NUL-terminated path in the PT_INTERP segment.
+ const std::string path = interpreter_file.path();
+ std::vector<char> segment(path.begin(), path.end());
+ segment.push_back(0);
+ binary.AddInterpreter(segment);
+
+ binary.UpdateOffsets();
+
+ TempPath binary_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(binary));
+
+ ASSERT_THAT(chmod(interpreter_file.path().c_str(), 0644), SyscallSucceeds());
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(interpreter_file.path(), {interpreter_file.path()}, {},
+ &child, &execve_errno));
+ EXPECT_EQ(execve_errno, EACCES);
+}
+
+// Execute a basic interpreter script.
+TEST(InterpreterScriptTest, Execute) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ EXPECT_NO_ERRNO(WaitStopped(child));
+}
+
+// Whitespace after #!.
+TEST(InterpreterScriptTest, Whitespace) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#! \t \t", binary.path()), 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ EXPECT_NO_ERRNO(WaitStopped(child));
+}
+
+// Interpreter script is missing execute permission.
+TEST(InterpreterScriptTest, InterpreterScriptNoExecute) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0644));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, EACCES);
+}
+
+// Binary interpreter script refers to is missing execute permission.
+TEST(InterpreterScriptTest, BinaryNoExecute) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ ASSERT_THAT(chmod(binary.path().c_str(), 0644), SyscallSucceeds());
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, EACCES);
+}
+
+// Linux will load interpreter scripts five levels deep, but no more.
+TEST(InterpreterScriptTest, MaxRecursion) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ TempPath script1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ "/tmp", absl::StrCat("#!", binary.path()), 0755));
+ TempPath script2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ "/tmp", absl::StrCat("#!", script1.path()), 0755));
+ TempPath script3 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ "/tmp", absl::StrCat("#!", script2.path()), 0755));
+ TempPath script4 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ "/tmp", absl::StrCat("#!", script3.path()), 0755));
+ TempPath script5 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ "/tmp", absl::StrCat("#!", script4.path()), 0755));
+ TempPath script6 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ "/tmp", absl::StrCat("#!", script5.path()), 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script6.path(), {script6.path()}, {}, &child, &execve_errno));
+ // Too many levels of recursion.
+ EXPECT_EQ(execve_errno, ELOOP);
+
+ // The next level up is OK.
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script5.path(), {script5.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ EXPECT_NO_ERRNO(WaitStopped(child));
+}
+
+// Interpreter script with a relative path.
+TEST(InterpreterScriptTest, RelativePath) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ auto cwd = ASSERT_NO_ERRNO_AND_VALUE(GetCWD());
+ auto binary_relative =
+ ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, binary.path()));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary_relative), 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ EXPECT_NO_ERRNO(WaitStopped(child));
+}
+
+// Interpreter script with .. in a path component.
+TEST(InterpreterScriptTest, UncleanPath) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!/tmp/../", binary.path()),
+ 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ EXPECT_NO_ERRNO(WaitStopped(child));
+}
+
+// Passed interpreter script is a symlink.
+TEST(InterpreterScriptTest, Symlink) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ // Use /tmp explicitly to ensure the path is short enough.
+ TempPath binary = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith("/tmp", elf));
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", binary.path()), 0755));
+
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), script.path()));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(link.path(), {link.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ EXPECT_NO_ERRNO(WaitStopped(child));
+}
+
+// Interpreter script points to a symlink loop.
+TEST(InterpreterScriptTest, SymlinkLoop) {
+ std::string const link1 = NewTempAbsPathInDir("/tmp");
+ std::string const link2 = NewTempAbsPathInDir("/tmp");
+
+ ASSERT_THAT(symlink(link2.c_str(), link1.c_str()), SyscallSucceeds());
+ auto remove_link1 = Cleanup(
+ [&link1] { EXPECT_THAT(unlink(link1.c_str()), SyscallSucceeds()); });
+
+ ASSERT_THAT(symlink(link1.c_str(), link2.c_str()), SyscallSucceeds());
+ auto remove_link2 = Cleanup(
+ [&link2] { EXPECT_THAT(unlink(link2.c_str()), SyscallSucceeds()); });
+
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::StrCat("#!", link1), 0755));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(script.path(), {script.path()}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, ELOOP);
+}
+
+// Binary is a symlink loop.
+TEST(ExecveTest, SymlinkLoop) {
+ std::string const link1 = NewTempAbsPathInDir("/tmp");
+ std::string const link2 = NewTempAbsPathInDir("/tmp");
+
+ ASSERT_THAT(symlink(link2.c_str(), link1.c_str()), SyscallSucceeds());
+ auto remove_link = Cleanup(
+ [&link1] { EXPECT_THAT(unlink(link1.c_str()), SyscallSucceeds()); });
+
+ ASSERT_THAT(symlink(link1.c_str(), link2.c_str()), SyscallSucceeds());
+ auto remove_link2 = Cleanup(
+ [&link2] { EXPECT_THAT(unlink(link2.c_str()), SyscallSucceeds()); });
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(link1, {link1}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, ELOOP);
+}
+
+// Binary is a directory.
+TEST(ExecveTest, Directory) {
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/tmp", {"/tmp"}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, EACCES);
+}
+
+// Pass a valid binary as a directory (extra / on the end).
+TEST(ExecveTest, BinaryAsDirectory) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ std::string const path = absl::StrCat(file.path(), "/");
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(path, {path}, {}, &child, &execve_errno));
+ EXPECT_EQ(execve_errno, ENOTDIR);
+}
+
+// The initial brk value is after the page at the end of the binary.
+TEST(ExecveTest, BrkAfterBinary) {
+ ElfBinary<64> elf = StandardElf();
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ ASSERT_EQ(execve_errno, 0);
+
+ // Ensure it made it to SIGSTOP.
+ ASSERT_NO_ERRNO(WaitStopped(child));
+
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+ // RIP is just beyond the final syscall instruction. Rewind to execute a brk
+ // syscall.
+ IP_REG(regs) -= kSyscallSize;
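+ // N.B. This assumes kSyscallSize matches the length of the syscall
+ // instruction, which is 2 bytes (0f 05) on x86-64.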
+ RAX_REG(regs) = __NR_brk;
+ RDI_REG(regs) = 0;
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+
+ // Resume the child, waiting for syscall entry.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << "status = " << status;
+
+ // Execute the syscall.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << "status = " << status;
+
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+ // brk is after the text page.
+ //
+ // The kernel does brk randomization, so we can't be sure what the exact
+ // address will be, but it is always beyond the final page in the binary.
+ // i.e., it does not start immediately after memsz in the middle of a page.
+ // Userspace may expect to use that space.
+ EXPECT_GE(RETURN_REG(regs), 0x41000);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc
new file mode 100644
index 000000000..2989379b7
--- /dev/null
+++ b/test/syscalls/linux/exec_proc_exe_workload.cc
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <iostream>
+
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+
+int main(int argc, char** argv, char** envp) {
+ // This is annoying. Because remote build systems may put these binaries
+ // in a content-addressable store, you may wind up with /proc/self/exe
+ // pointing to some random path (but with a sensible argv[0]).
+ //
+ // Therefore, this test simply checks that /proc/self/exe is absolute and
+ // *doesn't* match argv[1].
+ std::string exe =
+ gvisor::testing::ProcessExePath(getpid()).ValueOrDie();
+ if (exe[0] != '/') {
+ std::cerr << "relative path: " << exe << std::endl;
+ exit(1);
+ }
+ if (exe.find(argv[1]) != std::string::npos) {
+ std::cerr << "matching path: " << exe << std::endl;
+ exit(1);
+ }
+
+ return 0;
+}
diff --git a/test/syscalls/linux/exec_state_workload.cc b/test/syscalls/linux/exec_state_workload.cc
new file mode 100644
index 000000000..028902b14
--- /dev/null
+++ b/test/syscalls/linux/exec_state_workload.cc
@@ -0,0 +1,202 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/auxv.h>
+#include <sys/prctl.h>
+#include <sys/time.h>
+
+#include <iostream>
+#include <ostream>
+#include <string>
+
+#include "absl/strings/numbers.h"
+
+// Pretty-print a sigset_t.
+std::ostream& operator<<(std::ostream& out, const sigset_t& s) {
+ out << "{ ";
+
+ for (int i = 0; i < NSIG; i++) {
+ if (sigismember(&s, i)) {
+ out << i << " ";
+ }
+ }
+
+ out << "}";
+ return out;
+}
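+
+ // For example (illustrative): on x86-64, a mask with SIGUSR1 (10) and
+ // SIGUSR2 (12) set prints as "{ 10 12 }".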
+
+// Verify that the signo handler is handler.
+int CheckSigHandler(uint32_t signo, uintptr_t handler) {
+ struct sigaction sa;
+ int ret = sigaction(signo, nullptr, &sa);
+ if (ret < 0) {
+ perror("sigaction");
+ return 1;
+ }
+
+ if (reinterpret_cast<void (*)(int)>(handler) != sa.sa_handler) {
+ std::cerr << "signo " << signo << " handler got: " << sa.sa_handler
+ << " expected: " << std::hex << handler;
+ return 1;
+ }
+ return 0;
+}
+
+// Verify that the signo is blocked.
+int CheckSigBlocked(uint32_t signo) {
+ sigset_t s;
+ int ret = sigprocmask(SIG_SETMASK, nullptr, &s);
+ if (ret < 0) {
+ perror("sigprocmask");
+ return 1;
+ }
+
+ if (!sigismember(&s, signo)) {
+ std::cerr << "signal " << signo << " not blocked in signal mask: " << s
+ << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+// Verify that the itimer is enabled.
+int CheckItimerEnabled(uint32_t timer) {
+ struct itimerval itv;
+ int ret = getitimer(timer, &itv);
+ if (ret < 0) {
+ perror("getitimer");
+ return 1;
+ }
+
+ if (!itv.it_value.tv_sec && !itv.it_value.tv_usec &&
+ !itv.it_interval.tv_sec && !itv.it_interval.tv_usec) {
+ std::cerr << "timer " << timer
+ << " not enabled. value sec: " << itv.it_value.tv_sec
+ << " usec: " << itv.it_value.tv_usec
+ << " interval sec: " << itv.it_interval.tv_sec
+ << " usec: " << itv.it_interval.tv_usec << std::endl;
+ return 1;
+ }
+ return 0;
+}
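+
+ // Typical use (illustrative): the parent arms a timer and then execs this
+ // workload, which verifies that the timer survived execve(2):
+ //
+ //   struct itimerval itv = {};
+ //   itv.it_value.tv_sec = 1000;
+ //   setitimer(ITIMER_REAL, &itv, nullptr);
+ //   // ITIMER_REAL == 0.
+ //   execl(path, path, "CheckItimerEnabled", "0",
+ //         static_cast<char*>(nullptr));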
+
+int PrintExecFn() {
+ unsigned long execfn = getauxval(AT_EXECFN);
+ if (!execfn) {
+ std::cerr << "AT_EXECFN missing" << std::endl;
+ return 1;
+ }
+
+ std::cerr << reinterpret_cast<const char*>(execfn) << std::endl;
+ return 0;
+}
+
+int PrintExecName() {
+ const size_t name_length = 20;
+ char name[name_length] = {0};
+ if (prctl(PR_GET_NAME, name) < 0) {
+ std::cerr << "prctl(PR_GET_NAME) failed" << std::endl;
+ return 1;
+ }
+
+ std::cerr << name << std::endl;
+ return 0;
+}
+
+void usage(const std::string& prog) {
+ std::cerr << "usage:\n"
+ << "\t" << prog << " CheckSigHandler <signo> <handler addr (hex)>\n"
+ << "\t" << prog << " CheckSigBlocked <signo>\n"
+ << "\t" << prog << " CheckTimerDisabled <timer>\n"
+ << "\t" << prog << " PrintExecFn\n"
+ << "\t" << prog << " PrintExecName" << std::endl;
+}
+
+int main(int argc, char** argv) {
+ if (argc < 2) {
+ usage(argv[0]);
+ return 1;
+ }
+
+ std::string func(argv[1]);
+
+ if (func == "CheckSigHandler") {
+ if (argc != 4) {
+ usage(argv[0]);
+ return 1;
+ }
+
+ uint32_t signo;
+ if (!absl::SimpleAtoi(argv[2], &signo)) {
+ std::cerr << "invalid signo: " << argv[2] << std::endl;
+ return 1;
+ }
+
+ uintptr_t handler;
+ if (!absl::numbers_internal::safe_strtoi_base(argv[3], &handler, 16)) {
+ std::cerr << "invalid handler: " << std::hex << argv[3] << std::endl;
+ return 1;
+ }
+
+ return CheckSigHandler(signo, handler);
+ }
+
+ if (func == "CheckSigBlocked") {
+ if (argc != 3) {
+ usage(argv[0]);
+ return 1;
+ }
+
+ uint32_t signo;
+ if (!absl::SimpleAtoi(argv[2], &signo)) {
+ std::cerr << "invalid signo: " << argv[2] << std::endl;
+ return 1;
+ }
+
+ return CheckSigBlocked(signo);
+ }
+
+ if (func == "CheckItimerEnabled") {
+ if (argc != 3) {
+ usage(argv[0]);
+ return 1;
+ }
+
+ uint32_t timer;
+ if (!absl::SimpleAtoi(argv[2], &timer)) {
+ std::cerr << "invalid signo: " << argv[2] << std::endl;
+ return 1;
+ }
+
+ return CheckItimerEnabled(timer);
+ }
+
+ if (func == "PrintExecFn") {
+ // N.B. This will be called as an interpreter script, with the script passed
+ // as the third argument. We don't care about that script.
+ return PrintExecFn();
+ }
+
+ if (func == "PrintExecName") {
+ // N.B. This may be called as an interpreter script like PrintExecFn.
+ return PrintExecName();
+ }
+
+ std::cerr << "Invalid function: " << func << std::endl;
+ return 1;
+}
diff --git a/test/syscalls/linux/exit.cc b/test/syscalls/linux/exit.cc
new file mode 100644
index 000000000..d52ea786b
--- /dev/null
+++ b/test/syscalls/linux/exit.cc
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+#include "test/util/time_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void TestExit(int code) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ _exit(code);
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == code) << status;
+}
+
+TEST(ExitTest, Success) { TestExit(0); }
+
+TEST(ExitTest, Failure) { TestExit(1); }
+
+// This test ensures that a process's file descriptors are closed when it calls
+// exit(). In order to test this, the parent tries to read from a pipe whose
+// write end is held by the child. While the read is blocking, the child exits,
+// which should cause the parent to read 0 bytes due to EOF.
+TEST(ExitTest, CloseFds) {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ FileDescriptor read_fd(pipe_fds[0]);
+ FileDescriptor write_fd(pipe_fds[1]);
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ read_fd.reset();
+
+ SleepSafe(absl::Seconds(10));
+
+ _exit(0);
+ }
+
+ EXPECT_THAT(pid, SyscallSucceeds());
+
+ write_fd.reset();
+
+ char buf[10];
+ EXPECT_THAT(ReadFd(read_fd.get(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/exit_script.sh b/test/syscalls/linux/exit_script.sh
new file mode 100755
index 000000000..527518e06
--- /dev/null
+++ b/test/syscalls/linux/exit_script.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# -ne 1 ]; then
+ echo "Usage: $0 exit_code"
+ exit 255
+fi
+
+exit $1
diff --git a/test/syscalls/linux/fadvise64.cc b/test/syscalls/linux/fadvise64.cc
new file mode 100644
index 000000000..2af7aa6d9
--- /dev/null
+++ b/test/syscalls/linux/fadvise64.cc
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+TEST(FAdvise64Test, Basic) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ // fadvise64 is noop in gVisor, so just test that it succeeds.
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL),
+ SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_RANDOM),
+ SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_SEQUENTIAL),
+ SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_WILLNEED),
+ SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_DONTNEED),
+ SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NOREUSE),
+ SyscallSucceeds());
+}
+
+TEST(FAdvise64Test, InvalidArgs) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ // Note: offset is allowed to be negative.
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, static_cast<off_t>(-1),
+ POSIX_FADV_NORMAL),
+ SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, 12345),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FAdvise64Test, NoPipes) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor read(fds[0]);
+ const FileDescriptor write(fds[1]);
+
+ ASSERT_THAT(syscall(__NR_fadvise64, read.get(), 0, 10, POSIX_FADV_NORMAL),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc
new file mode 100644
index 000000000..cabc2b751
--- /dev/null
+++ b/test/syscalls/linux/fallocate.cc
@@ -0,0 +1,186 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <signal.h>
+#include <sys/eventfd.h>
+#include <sys/resource.h>
+#include <sys/signalfd.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/timerfd.h>
+#include <syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <ctime>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+int fallocate(int fd, int mode, off_t offset, off_t len) {
+ return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len);
+}
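+
+ // For example (illustrative): fallocate(fd, 0, 0, 4096) preallocates the
+ // first 4 KiB of fd with the default mode (0), growing st_size to 4096 if the
+ // file was smaller.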
+
+class AllocateTest : public FileTest {
+ void SetUp() override { FileTest::SetUp(); }
+};
+
+TEST_F(AllocateTest, Fallocate) {
+ // Check that it starts at size zero.
+ struct stat buf;
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+
+ // Grow to ten bytes.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 10);
+
+ // Allocating to a smaller size should be a no-op.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 10);
+
+ // Grow again.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 20);
+
+ // Grow with offset.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 30);
+
+ // Grow with offset beyond EOF.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 40);
+
+ // A length of 0 should fail with EINVAL.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 50, 0),
+ SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 40);
+}
+
+TEST_F(AllocateTest, FallocateInvalid) {
+ // Invalid FD
+ EXPECT_THAT(fallocate(-1, 0, 0, 10), SyscallFailsWithErrno(EBADF));
+
+ // Negative offset and size.
+ EXPECT_THAT(fallocate(test_file_fd_.get(), 0, -1, 10),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, -1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(fallocate(test_file_fd_.get(), 0, -1, -1),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(AllocateTest, FallocateReadonly) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(AllocateTest, FallocatePipe) {
+ int pipes[2];
+ EXPECT_THAT(pipe(pipes), SyscallSucceeds());
+ auto cleanup = Cleanup([&pipes] {
+ EXPECT_THAT(close(pipes[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipes[1]), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(fallocate(pipes[1], 0, 0, 10), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST_F(AllocateTest, FallocateChar) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDWR));
+ EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+}
+
+TEST_F(AllocateTest, FallocateRlimit) {
+ // Get the current rlimit and restore after test run.
+ struct rlimit initial_lim;
+ ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ auto cleanup = Cleanup([&initial_lim] {
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ });
+
+ // Try growing past the file size limit.
+ sigset_t new_mask;
+ sigemptyset(&new_mask);
+ sigaddset(&new_mask, SIGXFSZ);
+ sigprocmask(SIG_BLOCK, &new_mask, nullptr);
+
+ struct rlimit setlim = {};
+ setlim.rlim_cur = 1024;
+ setlim.rlim_max = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds());
+
+ EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 1025),
+ SyscallFailsWithErrno(EFBIG));
+
+ struct timespec timelimit = {};
+ timelimit.tv_sec = 10;
+ EXPECT_EQ(sigtimedwait(&new_mask, nullptr, &timelimit), SIGXFSZ);
+ ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds());
+}
+
+TEST_F(AllocateTest, FallocateOtherFDs) {
+ int fd;
+ ASSERT_THAT(fd = timerfd_create(CLOCK_MONOTONIC, 0), SyscallSucceeds());
+ auto timer_fd = FileDescriptor(fd);
+ EXPECT_THAT(fallocate(timer_fd.get(), 0, 0, 10),
+ SyscallFailsWithErrno(ENODEV));
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ ASSERT_THAT(fd = signalfd(-1, &mask, 0), SyscallSucceeds());
+ auto sfd = FileDescriptor(fd);
+ EXPECT_THAT(fallocate(sfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ auto efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+ EXPECT_THAT(fallocate(efd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ auto sockfd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ EXPECT_THAT(fallocate(sockfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ int socks[2];
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks),
+ SyscallSucceeds());
+ auto sock0 = FileDescriptor(socks[0]);
+ auto sock1 = FileDescriptor(socks[1]);
+ EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fault.cc b/test/syscalls/linux/fault.cc
new file mode 100644
index 000000000..a85750382
--- /dev/null
+++ b/test/syscalls/linux/fault.cc
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#define _GNU_SOURCE 1
+#include <signal.h>
+#include <ucontext.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+__attribute__((noinline)) void Fault(void) {
+ volatile int* foo = nullptr;
+ *foo = 0;
+}
+
+int GetPcFromUcontext(ucontext_t* uc, uintptr_t* pc) {
+#if defined(__x86_64__)
+ *pc = uc->uc_mcontext.gregs[REG_RIP];
+ return 1;
+#elif defined(__i386__)
+ *pc = uc->uc_mcontext.gregs[REG_EIP];
+ return 1;
+#elif defined(__aarch64__)
+ *pc = uc->uc_mcontext.pc;
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+void sigact_handler(int sig, siginfo_t* siginfo, void* context) {
+ uintptr_t pc;
+ if (GetPcFromUcontext(reinterpret_cast<ucontext_t*>(context), &pc)) {
+ /* Expect Fault() to be at most 64 bytes in size. */
+ uintptr_t fault_addr = reinterpret_cast<uintptr_t>(&Fault);
+ EXPECT_GE(pc, fault_addr);
+ EXPECT_LT(pc, fault_addr + 64);
+ exit(0);
+ }
+}
+
+TEST(FaultTest, InRange) {
+ // Install a SIGSEGV handler that verifies the faulting PC lies within
+ // Fault() and exits cleanly, so the deliberate fault below doesn't kill
+ // the test.
+ struct sigaction sa = {};
+ sa.sa_sigaction = sigact_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ ASSERT_THAT(sigaction(SIGSEGV, &sa, nullptr), SyscallSucceeds());
+
+ Fault();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc
new file mode 100644
index 000000000..08bcae1e8
--- /dev/null
+++ b/test/syscalls/linux/fchdir.cc
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(FchdirTest, Success) {
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ int fd;
+ ASSERT_THAT(fd = open(temp_dir.path().c_str(), O_DIRECTORY | O_RDONLY),
+ SyscallSucceeds());
+
+ EXPECT_THAT(fchdir(fd), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ // Change CWD to a permanent location as temp dirs will be cleaned up.
+ EXPECT_THAT(chdir("/"), SyscallSucceeds());
+}
+
+TEST(FchdirTest, InvalidFD) {
+ EXPECT_THAT(fchdir(-1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FchdirTest, PermissionDenied) {
+ // Drop capabilities that allow us to override directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */));
+
+ int fd;
+ ASSERT_THAT(fd = open(temp_dir.path().c_str(), O_DIRECTORY | O_RDONLY),
+ SyscallSucceeds());
+
+ EXPECT_THAT(fchdir(fd), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(FchdirTest, NotDir) {
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ int fd;
+ ASSERT_THAT(fd = open(temp_file.path().c_str(), O_CREAT | O_RDONLY, 0777),
+ SyscallSucceeds());
+
+ EXPECT_THAT(fchdir(fd), SyscallFailsWithErrno(ENOTDIR));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
new file mode 100644
index 000000000..5467fa2c8
--- /dev/null
+++ b/test/syscalls/linux/fcntl.cc
@@ -0,0 +1,1353 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <signal.h>
+#include <sys/types.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <list>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/base/port.h"
+#include "absl/flags/flag.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+ABSL_FLAG(std::string, child_setlock_on, "",
+ "Contains the path to try to set a file lock on.");
+ABSL_FLAG(bool, child_setlock_write, false,
+ "Whether to set a writable lock (otherwise readable)");
+ABSL_FLAG(bool, blocking, false,
+ "Whether to set a blocking lock (otherwise non-blocking).");
+ABSL_FLAG(bool, retry_eintr, false,
+ "Whether to retry in the subprocess on EINTR.");
+ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start");
+ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len");
+ABSL_FLAG(int32_t, socket_fd, -1,
+ "A socket to use for communicating more state back "
+ "to the parent.");
+
+namespace gvisor {
+namespace testing {
+
+class FcntlLockTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+    // Create a socket pair for passing timing info back from the subprocess.
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, fds_), SyscallSucceeds());
+ }
+
+ void TearDown() override {
+ EXPECT_THAT(close(fds_[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds_[1]), SyscallSucceeds());
+ }
+
+ int64_t GetSubprocessFcntlTimeInUsec() {
+ int64_t ret = 0;
+ EXPECT_THAT(ReadFd(fds_[0], reinterpret_cast<void*>(&ret), sizeof(ret)),
+ SyscallSucceedsWithValue(sizeof(ret)));
+ return ret;
+ }
+
+ // The first fd will remain with the process creating the subprocess
+ // and the second will go to the subprocess.
+ int fds_[2] = {};
+};
+
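+// How the blocking tests below use this pair (sketch of the protocol; the
+// child side lives in main() at the bottom of this file):
+//
+//   int64_t usec = /* time measured around the child's fcntl() call */;
+//   WriteFd(socket_fd, &usec, sizeof(usec));  // Child writes on fds_[1].
+//   // The parent reads it back via GetSubprocessFcntlTimeInUsec().
+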
+namespace {
+
+PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write,
+ bool blocking, bool retry_eintr, int fd,
+ off_t start, off_t length, pid_t* child) {
+ std::vector<std::string> args = {
+ "/proc/self/exe", "--child_setlock_on", path,
+ "--child_setlock_start", absl::StrCat(start), "--child_setlock_len",
+ absl::StrCat(length), "--socket_fd", absl::StrCat(fd)};
+
+ if (for_write) {
+ args.push_back("--child_setlock_write");
+ }
+
+ if (blocking) {
+ args.push_back("--blocking");
+ }
+
+ if (retry_eintr) {
+ args.push_back("--retry_eintr");
+ }
+
+ int execve_errno = 0;
+ ASSIGN_OR_RETURN_ERRNO(
+ auto cleanup,
+ ForkAndExec("/proc/self/exe", ExecveArray(args.begin(), args.end()), {},
+ nullptr, child, &execve_errno));
+
+ if (execve_errno != 0) {
+ return PosixError(execve_errno, "execve");
+ }
+
+ return std::move(cleanup);
+}
+
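+// SubprocessLock above re-executes this test binary via /proc/self/exe; the
+// child_setlock_* flags are parsed in main() at the bottom of this file. A
+// typical invocation is roughly equivalent to (sketch):
+//
+//   /proc/self/exe --child_setlock_on /tmp/file --child_setlock_start 0 \
+//       --child_setlock_len 0 --socket_fd -1
+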
+TEST(FcntlTest, SetCloExecBadFD) {
+ // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set.
+ FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ auto fd = f.get();
+ f.reset();
+ ASSERT_THAT(fcntl(fd, F_GETFD), SyscallFailsWithErrno(EBADF));
+ ASSERT_THAT(fcntl(fd, F_SETFD, FD_CLOEXEC), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FcntlTest, SetCloExec) {
+ // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+
+ // Set the FD_CLOEXEC flag.
+ ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+TEST(FcntlTest, ClearCloExec) {
+ // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC));
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+
+ // Clear the FD_CLOEXEC flag.
+ ASSERT_THAT(fcntl(fd.get(), F_SETFD, 0), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, IndependentDescriptorFlags) {
+ // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+
+ // Duplicate the descriptor. Ensure that it also doesn't have FD_CLOEXEC.
+ FileDescriptor newfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+ ASSERT_THAT(fcntl(newfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+
+ // Set FD_CLOEXEC on the first FD.
+ ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+
+ // Ensure that the second FD is unaffected by the change on the first.
+ ASSERT_THAT(fcntl(newfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+}
+
+// All file description flags passed to open appear in F_GETFL.
+TEST(FcntlTest, GetAllFlags) {
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ int flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND;
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), flags));
+
+  // Linux forces O_LARGEFILE on all 64-bit kernels, and gVisor is 64-bit.
+ int expected = flags | kOLargeFile;
+
+ int rflags;
+ EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(rflags, expected);
+}
+
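+// kOLargeFile (from test/util/test_util.h) holds the raw O_LARGEFILE bit;
+// the O_LARGEFILE macro itself is typically 0 when userspace is built with
+// 64-bit file offsets, so a separate constant is needed to check F_GETFL
+// results.
+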
+TEST(FcntlTest, SetFlags) {
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), 0));
+
+ int const flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND;
+ EXPECT_THAT(fcntl(fd.get(), F_SETFL, flags), SyscallSucceeds());
+
+ // Can't set O_RDWR or O_SYNC.
+  // Linux forces O_LARGEFILE on all 64-bit kernels, and gVisor is 64-bit.
+ int expected = O_DIRECT | O_NONBLOCK | O_APPEND | kOLargeFile;
+
+ int rflags;
+ EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(rflags, expected);
+}
+
+void TestLock(int fd, short lock_type = F_RDLCK) { // NOLINT, type in flock
+ struct flock fl;
+ fl.l_type = lock_type;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+  // len 0 locks all bytes, no matter how large the file grows.
+ fl.l_len = 0;
+ EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallSucceeds());
+}
+
+void TestLockBadFD(int fd,
+ short lock_type = F_RDLCK) { // NOLINT, type in flock
+ struct flock fl;
+ fl.l_type = lock_type;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+  // len 0 locks all bytes, no matter how large the file grows.
+ fl.l_len = 0;
+ EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallFailsWithErrno(EBADF));
+}
+
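+// For reference (a sketch, not exercised by these tests): F_GETLK can be used
+// to probe for a conflicting lock. On conflict the kernel overwrites the
+// struct with one blocking lock and its owner's pid; otherwise it sets l_type
+// to F_UNLCK.
+//
+//   struct flock probe;
+//   probe.l_type = F_WRLCK;
+//   probe.l_whence = SEEK_SET;
+//   probe.l_start = 0;
+//   probe.l_len = 0;  // Whole file.
+//   if (fcntl(fd, F_GETLK, &probe) == 0 && probe.l_type != F_UNLCK) {
+//     // probe.l_pid identifies a process holding a conflicting lock.
+//   }
+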
+TEST_F(FcntlLockTest, SetLockBadFd) { TestLockBadFD(-1); }
+
+TEST_F(FcntlLockTest, SetLockDir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000));
+ TestLock(fd.get());
+}
+
+TEST_F(FcntlLockTest, SetLockSymlink) {
+ // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH
+ // is supported.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path()));
+
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000));
+ TestLockBadFD(fd.get());
+}
+
+TEST_F(FcntlLockTest, SetLockProc) {
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000));
+ TestLock(fd.get());
+}
+
+TEST_F(FcntlLockTest, SetLockPipe) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ TestLock(fds[0]);
+ TestLockBadFD(fds[0], F_WRLCK);
+
+ TestLock(fds[1], F_WRLCK);
+ TestLockBadFD(fds[1]);
+
+ EXPECT_THAT(close(fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+TEST_F(FcntlLockTest, SetLockSocket) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ int sock = socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX));
+ ASSERT_THAT(
+ bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ TestLock(sock);
+ EXPECT_THAT(close(sock), SyscallSucceeds());
+}
+
+TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY, 0666));
+
+ struct flock fl0;
+ fl0.l_type = F_WRLCK;
+ fl0.l_whence = SEEK_SET;
+ fl0.l_start = 0;
+  fl0.l_len = 0;  // Lock the whole file.
+
+ // Expect that setting a write lock using a read only file descriptor
+ // won't work.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(FcntlLockTest, SetLockBadOpenFlagsRead) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY, 0666));
+
+ struct flock fl1;
+ fl1.l_type = F_RDLCK;
+ fl1.l_whence = SEEK_SET;
+ fl1.l_start = 0;
+ // Same as SetLockBadFd.
+ fl1.l_len = 0;
+
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(FcntlLockTest, SetLockUnlockOnNothing) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_UNLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+}
+
+TEST_F(FcntlLockTest, SetWriteLockSingleProc) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd0 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+ EXPECT_THAT(fcntl(fd0.get(), F_SETLK, &fl), SyscallSucceeds());
+ // Expect to be able to take the same lock on the same fd no problem.
+ EXPECT_THAT(fcntl(fd0.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ // Expect to be able to take the same lock from a different fd but for
+ // the same process.
+ EXPECT_THAT(fcntl(fd1.get(), F_SETLK, &fl), SyscallSucceeds());
+}
+
+TEST_F(FcntlLockTest, SetReadLockMultiProc) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+  // Spawn a child process to take a read lock on the same file.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), false /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetReadThenWriteLockMultiProc) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+  // Assert that another process trying to lock the same file will fail
+  // with EAGAIN. It's important that we keep the fd above open so that
+  // the other process will contend with the lock.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), true /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN)
+ << "Exited with code: " << status;
+
+ // Close the fd: we want to test that another process can acquire the
+ // lock after this point.
+ fd.reset();
+ // Assert that another process can now acquire the lock.
+
+ child_pid = 0;
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), true /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetWriteThenReadLockMultiProc) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+ // Same as SetReadThenWriteLockMultiProc.
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+ // Same as SetReadThenWriteLockMultiProc.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ // Same as SetReadThenWriteLockMultiProc.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), false /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN)
+ << "Exited with code: " << status;
+
+ // Same as SetReadThenWriteLockMultiProc.
+ fd.reset(); // Close the fd.
+
+ // Same as SetReadThenWriteLockMultiProc.
+ child_pid = 0;
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), false /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetWriteLockMultiProc) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+ // Same as SetReadThenWriteLockMultiProc.
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+  // Same as SetReadThenWriteLockMultiProc.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+  // Same as SetReadThenWriteLockMultiProc.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), true /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN)
+ << "Exited with code: " << status;
+
+ fd.reset(); // Close the FD.
+  // Same as SetReadThenWriteLockMultiProc.
+ child_pid = 0;
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), true /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetLockIsRegional) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 4096;
+
+  // Same as SetReadThenWriteLockMultiProc.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+  // Same as SetReadThenWriteLockMultiProc.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), true /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_len, 0, &child_pid));
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetLockUpgradeDowngrade) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+  // Same as SetReadThenWriteLockMultiProc.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ // Upgrade to a write lock. This will prevent anyone else from taking
+ // the lock.
+ fl.l_type = F_WRLCK;
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+  // Same as SetReadThenWriteLockMultiProc.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), false /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN)
+ << "Exited with code: " << status;
+
+ // Downgrade back to a read lock.
+ fl.l_type = F_RDLCK;
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+  // Repeat the same attempt as before; this time it should succeed.
+ child_pid = 0;
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), false /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetLockDroppedOnClose) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ // While somewhat surprising, obtaining another fd to the same file and
+ // then closing it in this process drops *all* locks.
+ FileDescriptor other_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+ // Same as SetReadThenWriteLockMultiProc.
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+ // Same as SetReadWriteLockMultiProc.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ other_fd.reset(); // Close.
+
+ // Expect to be able to get the lock, given that the close above dropped it.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(file.path(), true /* write lock */,
+ false /* nonblocking */, false /* no eintr retry */,
+ -1 /* no socket fd */, fl.l_start, fl.l_len, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetLockUnlock) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+  // Set up two regional locks with different lock types.
+ struct flock fl0;
+ fl0.l_type = F_WRLCK;
+ fl0.l_whence = SEEK_SET;
+ fl0.l_start = 0;
+ fl0.l_len = 4096;
+
+ struct flock fl1;
+ fl1.l_type = F_RDLCK;
+ fl1.l_whence = SEEK_SET;
+ fl1.l_start = 4096;
+ // Same as SetLockBadFd.
+ fl1.l_len = 0;
+
+ // Set both region locks.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallSucceeds());
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallSucceeds());
+
+ // Another process should fail to take a read lock on the entire file
+ // due to the regional write lock.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock(
+ file.path(), false /* write lock */, false /* nonblocking */,
+ false /* no eintr retry */, -1 /* no socket fd */, 0, 0, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN)
+ << "Exited with code: " << status;
+
+  // Then unlock only the writable one. This should ensure that other
+  // processes can take any read lock they want.
+ fl0.l_type = F_UNLCK;
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallSucceeds());
+
+ // Another process should now succeed to get a read lock on the entire file.
+ child_pid = 0;
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock(
+ file.path(), false /* write lock */, false /* nonblocking */,
+ false /* no eintr retry */, -1 /* no socket fd */, 0, 0, &child_pid));
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST_F(FcntlLockTest, SetLockAcrossRename) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+  // Set up a write lock on the entire file.
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ // Same as SetLockBadFd.
+ fl.l_len = 0;
+
+ // Set the region lock.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ // Rename the file to someplace nearby.
+ std::string const newpath = NewTempAbsPath();
+ EXPECT_THAT(rename(file.path().c_str(), newpath.c_str()), SyscallSucceeds());
+
+ // Another process should fail to take a read lock on the renamed file
+ // since we still have an open handle to the inode.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ SubprocessLock(newpath, false /* write lock */, false /* nonblocking */,
+ false /* no eintr retry */, -1 /* no socket fd */,
+ fl.l_start, fl.l_len, &child_pid));
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == EAGAIN)
+ << "Exited with code: " << status;
+}
+
+// NOTE: The blocking tests below aren't perfect. It's hard to assert exactly
+// what the kernel did while handling a syscall. These tests are timing based
+// because there really isn't any other reasonable way to assert that correct
+// blocking behavior happened.
+
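+// The only difference between the two lock commands exercised below (sketch;
+// the child's retry loop is in main() at the bottom of this file):
+//
+//   fcntl(fd, F_SETLK, &fl);   // Non-blocking: fails with EAGAIN on conflict.
+//   fcntl(fd, F_SETLKW, &fl);  // Blocking: waits, but may fail with EINTR.
+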
+// This test verifies that blocking works as expected when a write lock is
+// requested while another process holds a write lock. It holds the lock for
+// some amount of time, then waits for the second process to send over the
+// socket_fd how long it was blocked before its lock succeeded.
+TEST_F(FcntlLockTest, SetWriteLockThenBlockingWriteLock) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 0;
+
+ // Take the write lock.
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+  // Attempt to take the write lock in a subprocess. This will block
+  // immediately, so we release our lock after some amount of time and then
+  // assert how long the other process was blocked.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock(
+ file.path(), true /* write lock */, true /* Blocking Lock */,
+ true /* Retry on EINTR */, fds_[1] /* Socket fd for timing information */,
+ fl.l_start, fl.l_len, &child_pid));
+
+  // We will wait kHoldLockFor before releasing our lock, allowing the
+  // subprocess to obtain it.
+ constexpr absl::Duration kHoldLockFor = absl::Seconds(5);
+ const int64_t kMinBlockTimeUsec = absl::ToInt64Microseconds(absl::Seconds(1));
+
+ absl::SleepFor(kHoldLockFor);
+
+ // Unlock our write lock.
+ fl.l_type = F_UNLCK;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+ // Read the blocked time from the subprocess socket.
+ int64_t subprocess_blocked_time_usec = GetSubprocessFcntlTimeInUsec();
+
+ // We must have been waiting at least kMinBlockTime.
+ EXPECT_GT(subprocess_blocked_time_usec, kMinBlockTimeUsec);
+
+  // The fcntl() write lock must eventually succeed, as F_SETLKW simply
+  // blocks until it can obtain the lock.
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+// This test verifies that blocking works as expected when a write lock is
+// requested while another process holds a read lock. It holds the lock for
+// some amount of time, then waits for the second process to send over the
+// socket_fd how long it was blocked before its lock succeeded.
+TEST_F(FcntlLockTest, SetReadLockThenBlockingWriteLock) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 0;
+
+  // Take the read lock.
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+  // Attempt to take the write lock in a subprocess. This will block
+  // immediately, so we release our lock after some amount of time and then
+  // assert how long the other process was blocked.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock(
+ file.path(), true /* write lock */, true /* Blocking Lock */,
+ true /* Retry on EINTR */, fds_[1] /* Socket fd for timing information */,
+ fl.l_start, fl.l_len, &child_pid));
+
+  // We will wait kHoldLockFor before releasing our lock, allowing the
+  // subprocess to obtain it.
+ constexpr absl::Duration kHoldLockFor = absl::Seconds(5);
+
+ const int64_t kMinBlockTimeUsec = absl::ToInt64Microseconds(absl::Seconds(1));
+
+ absl::SleepFor(kHoldLockFor);
+
+ // Unlock our READ lock.
+ fl.l_type = F_UNLCK;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+ // Read the blocked time from the subprocess socket.
+ int64_t subprocess_blocked_time_usec = GetSubprocessFcntlTimeInUsec();
+
+ // We must have been waiting at least kMinBlockTime.
+ EXPECT_GT(subprocess_blocked_time_usec, kMinBlockTimeUsec);
+
+  // The fcntl() write lock must eventually succeed, as F_SETLKW simply
+  // blocks until it can obtain the lock.
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+// This test verifies that blocking works as expected when a read lock is
+// requested while another process holds a write lock. It holds the lock for
+// some amount of time, then waits for the second process to send over the
+// socket_fd how long it was blocked before its lock succeeded.
+TEST_F(FcntlLockTest, SetWriteLockThenBlockingReadLock) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 0;
+
+ // Take the write lock.
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+  // Attempt to take the read lock in a subprocess. This will block
+  // immediately, so we release our lock after some amount of time and then
+  // assert how long the other process was blocked.
+ pid_t child_pid = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock(
+ file.path(), false /* read lock */, true /* Blocking Lock */,
+ true /* Retry on EINTR */, fds_[1] /* Socket fd for timing information */,
+ fl.l_start, fl.l_len, &child_pid));
+
+  // We will wait kHoldLockFor before releasing our lock, allowing the
+  // subprocess to obtain it.
+ constexpr absl::Duration kHoldLockFor = absl::Seconds(5);
+
+ const int64_t kMinBlockTimeUsec = absl::ToInt64Microseconds(absl::Seconds(1));
+
+ absl::SleepFor(kHoldLockFor);
+
+ // Unlock our write lock.
+ fl.l_type = F_UNLCK;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+ // Read the blocked time from the subprocess socket.
+ int64_t subprocess_blocked_time_usec = GetSubprocessFcntlTimeInUsec();
+
+ // We must have been waiting at least kMinBlockTime.
+ EXPECT_GT(subprocess_blocked_time_usec, kMinBlockTimeUsec);
+
+  // The fcntl() read lock must eventually succeed, as F_SETLKW simply
+  // blocks until it can obtain the lock.
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+// This test verifies that when one process holds only a read lock, another
+// process will not block when obtaining a read lock, even with F_SETLKW.
+TEST_F(FcntlLockTest, SetReadLockThenBlockingReadLock) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 0;
+
+ // Take the READ lock.
+ ASSERT_THAT(fcntl(fd.get(), F_SETLKW, &fl), SyscallSucceeds());
+
+  // Attempt to take the read lock in a subprocess. Since multiple processes
+  // can hold a read lock, this should return immediately without blocking,
+  // even though we used F_SETLKW in the subprocess.
+ pid_t child_pid = 0;
+ auto sp = ASSERT_NO_ERRNO_AND_VALUE(SubprocessLock(
+ file.path(), false /* read lock */, true /* Blocking Lock */,
+ true /* Retry on EINTR */, -1 /* No fd, should not block */, fl.l_start,
+ fl.l_len, &child_pid));
+
+ // We never release the lock and the subprocess should still obtain it without
+ // blocking for any period of time.
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+TEST(FcntlTest, GetO_ASYNC) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int flag_fl = -1;
+ ASSERT_THAT(flag_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(flag_fl & O_ASYNC, 0);
+
+ int flag_fd = -1;
+ ASSERT_THAT(flag_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds());
+ EXPECT_EQ(flag_fd & O_ASYNC, 0);
+}
+
+TEST(FcntlTest, SetFlO_ASYNC) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int before_fl = -1;
+ ASSERT_THAT(before_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+
+ int before_fd = -1;
+ ASSERT_THAT(before_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds());
+
+ ASSERT_THAT(fcntl(s.get(), F_SETFL, before_fl | O_ASYNC), SyscallSucceeds());
+
+ int after_fl = -1;
+ ASSERT_THAT(after_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(after_fl, before_fl | O_ASYNC);
+
+ int after_fd = -1;
+ ASSERT_THAT(after_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds());
+ EXPECT_EQ(after_fd, before_fd);
+}
+
+TEST(FcntlTest, SetFdO_ASYNC) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int before_fl = -1;
+ ASSERT_THAT(before_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+
+ int before_fd = -1;
+ ASSERT_THAT(before_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds());
+
+ ASSERT_THAT(fcntl(s.get(), F_SETFD, before_fd | O_ASYNC), SyscallSucceeds());
+
+ int after_fl = -1;
+ ASSERT_THAT(after_fl = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(after_fl, before_fl);
+
+ int after_fd = -1;
+ ASSERT_THAT(after_fd = fcntl(s.get(), F_GETFD), SyscallSucceeds());
+ EXPECT_EQ(after_fd, before_fd);
+}
+
+TEST(FcntlTest, DupAfterO_ASYNC) {
+ FileDescriptor s1 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int before = -1;
+ ASSERT_THAT(before = fcntl(s1.get(), F_GETFL), SyscallSucceeds());
+
+ ASSERT_THAT(fcntl(s1.get(), F_SETFL, before | O_ASYNC), SyscallSucceeds());
+
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(s1.Dup());
+
+ int after = -1;
+ ASSERT_THAT(after = fcntl(fd2.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(after & O_ASYNC, O_ASYNC);
+}
+
+TEST(FcntlTest, GetOwnNone) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Use the raw syscall because the glibc wrapper may convert F_{GET,SET}OWN
+ // into F_{GET,SET}OWN_EX.
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+ MaybeSave();
+}
+
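+// For reference, the f_owner_ex struct used with F_GETOWN_EX/F_SETOWN_EX is
+// declared in <fcntl.h> roughly as (sketch):
+//
+//   struct f_owner_ex {
+//     int type;   // F_OWNER_TID, F_OWNER_PID, or F_OWNER_PGRP.
+//     pid_t pid;  // Interpreted according to type.
+//   };
+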
+TEST(FcntlTest, GetOwnExNone) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetOwnInvalidPid) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 12345678),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnInvalidPgrp) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -12345678),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ pid_t pid;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(pid));
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ pid_t pgid;
+ EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), SyscallSucceeds());
+
+ // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the
+ // negative return value as an error, converting the return value to -1 and
+ // setting errno accordingly.
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, F_OWNER_PGRP);
+ EXPECT_EQ(got_owner.pid, pgid);
+ MaybeSave();
+}
+
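+// Why process-group owners are verified via F_GETOWN_EX: the raw F_GETOWN
+// result for a group is -pgid, and any return value in [-4095, -1] is
+// indistinguishable from an error at the syscall boundary, so it may be
+// reported as -1 with errno set.
+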
+TEST(FcntlTest, SetOwnUnset) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Set and unset pid.
+ pid_t pid;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+
+ // Set and unset pgid.
+ pid_t pgid;
+ EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+ MaybeSave();
+}
+
+// F_SETOWN negates negative argument values to obtain a process group ID,
+// an operation that must be guarded against integer overflow.
+TEST(FcntlTest, SetOwnOverflow) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, INT_MIN),
+ SyscallFailsWithErrno(EINVAL));
+}
+
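+// Concretely (illustrative arithmetic): negating INT_MIN is undefined for a
+// 32-bit int because +2147483648 is not representable, so the kernel rejects
+// the value outright rather than flipping its sign:
+//
+//   syscall(__NR_fcntl, fd, F_SETOWN, INT_MIN);  // Fails with EINVAL.
+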
+TEST(FcntlTest, SetOwnExInvalidType) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = __pid_type(-1);
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FcntlTest, SetOwnExInvalidTid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_TID;
+ owner.pid = -1;
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnExInvalidPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PID;
+ owner.pid = -1;
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnExInvalidPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PGRP;
+ owner.pid = -1;
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnExTid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_TID;
+ EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(owner.pid));
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnExPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PID;
+ EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(owner.pid));
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnExPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_PGRP;
+ EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceeds());
+
+ // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the
+ // negative return value as an error, converting the return value to -1 and
+ // setting errno accordingly.
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnExUnset) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Set and unset pid.
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PID;
+ EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceeds());
+ owner.pid = 0;
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+
+ // Set and unset pgid.
+ owner.type = F_OWNER_PGRP;
+ EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceeds());
+ owner.pid = 0;
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceeds());
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+ MaybeSave();
+}
+
+TEST(FcntlTest, GetOwnExTid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_TID;
+ EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceeds());
+
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+}
+
+TEST(FcntlTest, GetOwnExPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_PID;
+ EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceeds());
+
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+}
+
+TEST(FcntlTest, GetOwnExPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_PGRP;
+ EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceeds());
+
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+}
+
+// Make sure that making multiple concurrent changes to async signal generation
+// does not cause any race issues.
+TEST(FcntlTest, SetFlSetOwnDoNotRace) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ pid_t pid;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+
+ constexpr absl::Duration runtime = absl::Milliseconds(300);
+ auto setAsync = [&s, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, O_ASYNC),
+ SyscallSucceeds());
+ sched_yield();
+ }
+ };
+ auto resetAsync = [&s, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, 0), SyscallSucceeds());
+ sched_yield();
+ }
+ };
+ auto setOwn = [&s, &pid, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid),
+ SyscallSucceeds());
+ sched_yield();
+ }
+ };
+
+ std::list<ScopedThread> threads;
+ for (int i = 0; i < 10; i++) {
+ threads.emplace_back(setAsync);
+ threads.emplace_back(resetAsync);
+ threads.emplace_back(setOwn);
+ }
+}
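+
+// Note: ScopedThread (test/util/thread_util.h) joins in its destructor, so
+// the test above implicitly waits for all 30 workers when `threads` goes out
+// of scope.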
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on);
+ if (!setlock_on.empty()) {
+ int socket_fd = absl::GetFlag(FLAGS_socket_fd);
+ int fd = open(setlock_on.c_str(), O_RDWR, 0666);
+ if (fd == -1 && errno != 0) {
+ int err = errno;
+ std::cerr << "CHILD open " << setlock_on << " failed " << err
+ << std::endl;
+ exit(err);
+ }
+
+ struct flock fl;
+ if (absl::GetFlag(FLAGS_child_setlock_write)) {
+ fl.l_type = F_WRLCK;
+ } else {
+ fl.l_type = F_RDLCK;
+ }
+ fl.l_whence = SEEK_SET;
+ fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
+ fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
+
+ // Test the fcntl.
+ int err = 0;
+ int ret = 0;
+
+ gvisor::testing::MonotonicTimer timer;
+ timer.Start();
+ do {
+ ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
+ } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
+ auto usec = absl::ToInt64Microseconds(timer.Duration());
+
+ if (ret == -1 && errno != 0) {
+ err = errno;
+ std::cerr << "CHILD lock " << setlock_on << " failed " << err
+ << std::endl;
+ }
+
+    // If there is a socket fd, send back the time in microseconds it took
+    // to execute this syscall.
+ if (socket_fd != -1) {
+ gvisor::testing::WriteFd(socket_fd, reinterpret_cast<void*>(&usec),
+ sizeof(usec));
+ close(socket_fd);
+ }
+
+ close(fd);
+ exit(err);
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h
new file mode 100644
index 000000000..fb418e052
--- /dev/null
+++ b/test/syscalls/linux/file_base.h
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_FILE_BASE_H_
+#define GVISOR_TEST_SYSCALLS_FILE_BASE_H_
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+class FileTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ test_pipe_[0] = -1;
+ test_pipe_[1] = -1;
+
+ test_file_name_ = NewTempAbsPath();
+ test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR));
+
+ ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds());
+ }
+
+ // CloseFile will allow the test to manually close the file descriptor.
+ void CloseFile() { test_file_fd_.reset(); }
+
+ // UnlinkFile will allow the test to manually unlink the file.
+ void UnlinkFile() {
+ if (!test_file_name_.empty()) {
+ EXPECT_THAT(unlink(test_file_name_.c_str()), SyscallSucceeds());
+ test_file_name_.clear();
+ }
+ }
+
+ // ClosePipes will allow the test to manually close the pipes.
+ void ClosePipes() {
+ if (test_pipe_[0] > 0) {
+ EXPECT_THAT(close(test_pipe_[0]), SyscallSucceeds());
+ }
+
+ if (test_pipe_[1] > 0) {
+ EXPECT_THAT(close(test_pipe_[1]), SyscallSucceeds());
+ }
+
+ test_pipe_[0] = -1;
+ test_pipe_[1] = -1;
+ }
+
+ void TearDown() override {
+ CloseFile();
+ UnlinkFile();
+ ClosePipes();
+ }
+
+ protected:
+ std::string test_file_name_;
+ FileDescriptor test_file_fd_;
+
+ int test_pipe_[2];
+};
+
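+// Typical usage (sketch): derive a fixture from FileTest and the per-test
+// file, FD, and pipe are created in SetUp() and cleaned up in TearDown():
+//
+//   class MyFileTest : public FileTest {};  // Hypothetical fixture name.
+//   TEST_F(MyFileTest, Example) {
+//     EXPECT_THAT(ftruncate(test_file_fd_.get(), 0), SyscallSucceeds());
+//   }
+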
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_FILE_BASE_H_
diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc
new file mode 100644
index 000000000..638a93979
--- /dev/null
+++ b/test/syscalls/linux/flock.cc
@@ -0,0 +1,636 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/file.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class FlockTest : public FileTest {};
+
+TEST_F(FlockTest, InvalidOpCombinations) {
+ // The operation cannot be both exclusive and shared.
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_SH | LOCK_NB),
+ SyscallFailsWithErrno(EINVAL));
+
+  // Locking and unlocking together doesn't make sense.
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_UN | LOCK_NB),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_UN | LOCK_NB),
+ SyscallFailsWithErrno(EINVAL));
+}
+
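+// For reference, the flock(2) operations combined in these tests (the usual
+// <sys/file.h> constants):
+//
+//   LOCK_SH  // Shared lock.
+//   LOCK_EX  // Exclusive lock.
+//   LOCK_UN  // Unlock.
+//   LOCK_NB  // Don't block; fail with EWOULDBLOCK instead.
+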
+TEST_F(FlockTest, NoOperationSpecified) {
+ // Not specifying an operation is invalid.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(FlockTest, TestSimpleExLock) {
+ // Test that we can obtain an exclusive lock (no other holders)
+ // and that we can unlock it.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestSimpleShLock) {
+ // Test that we can obtain a shared lock (no other holders)
+ // and that we can unlock it.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestLockableAnyMode) {
+ // flock(2): A shared or exclusive lock can be placed on a file
+ // regardless of the mode in which the file was opened.
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+      Open(test_file_name_, O_RDONLY));  // Open read-only for this test.
+
+ // Mode shouldn't prevent us from taking an exclusive lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0));
+
+ // Unlock
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestUnlockWithNoHolders) {
+  // Test that unlocking when no one holds a lock succeeds.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestRepeatedExLockingBySameHolder) {
+ // Test that repeated locking by the same holder for the
+ // same type of lock works correctly.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestRepeatedExLockingSingleUnlock) {
+ // Test that repeated locking by the same holder for the
+ // same type of lock works correctly and that a single unlock is required.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+
+ // Should be unlocked at this point
+ ASSERT_THAT(flock(fd.get(), LOCK_NB | LOCK_EX), SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestRepeatedShLockingBySameHolder) {
+ // Test that repeated locking by the same holder for the
+ // same type of lock works correctly.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_SH),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_SH),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestSingleHolderUpgrade) {
+ // Test that a shared lock is upgradable when no one else holds a lock.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_SH),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_NB | LOCK_EX),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestSingleHolderDowngrade) {
+ // Test single holder lock downgrade case.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestMultipleShared) {
+ // This is a simple test to verify that multiple independent shared
+ // locks will be granted.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+  // A shared lock should be granted, as only other shared locks exist.
+ ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), SyscallSucceedsWithValue(0));
+
+ // Unlock both.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+/*
+ * flock(2): If a process uses open(2) (or similar) to obtain more than one
+ * descriptor for the same file, these descriptors are treated
+ * independently by flock(). An attempt to lock the file using one of
+ * these file descriptors may be denied by a lock that the calling process
+ * has already placed via another descriptor.
+ */
+TEST_F(FlockTest, TestMultipleHolderSharedExclusive) {
+  // This test will verify that an exclusive lock will not be granted
+  // while a shared lock is held.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+  // Verify we're unable to get an exclusive lock via the second FD
+  // because someone is holding a shared lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Unlock
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
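+// Background (flock(2) vs. fcntl(2), not specific to this test): flock()
+// locks attach to the open file description, so dup(2)ed descriptors share
+// one lock, while fcntl() POSIX record locks are owned by the process and
+// are released when *any* descriptor for the file is closed.
+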
+TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) {
+ // This test will verify that a shared lock is denied while
+ // someone holds an exclusive lock.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+  // Verify we're unable to get a shared lock via the second FD
+  // because someone is holding an exclusive lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Unlock
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) {
+ // This test will verify that an exclusive lock is denied while
+  // someone already holds an exclusive lock.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // Verify we're unable to get an exclusive lock via the second FD
+ // because someone is already holding an exclusive lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Unlock
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestMultipleHolderSharedExclusiveUpgrade) {
+  // This test will verify that we cannot obtain an exclusive lock while
+  // a shared lock is held by another descriptor, then verify that an upgrade
+  // is possible on a shared lock once all other shared locks have been
+  // released.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // Verify we're unable to get an exclusive lock via the second FD because
+ // a shared lock is held.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+  // Verify that we can get a shared lock via the second descriptor instead.
+ ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), SyscallSucceedsWithValue(0));
+
+ // Unlock the first and there will only be one shared lock remaining.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+
+ // Upgrade 2nd fd.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0));
+
+  // Finally unlock the second.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestMultipleHolderSharedExclusiveDowngrade) {
+  // This test will verify that a shared lock is not obtainable while an
+  // exclusive lock is held, but that once the first descriptor downgrades its
+  // lock, a second independent file descriptor can also get a shared lock.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+  // Verify we're unable to get a shared lock via the second FD because
+ // an exclusive lock is held.
+ ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Verify that we can downgrade the first.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ // Now verify that we can obtain a shared lock since the first was downgraded.
+ ASSERT_THAT(flock(fd.get(), LOCK_SH | LOCK_NB), SyscallSucceedsWithValue(0));
+
+ // Finally unlock both.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+/*
+ * flock(2): Locks created by flock() are associated with an open file table
+ * entry. This means that duplicate file descriptors (created by, for example,
+ * fork(2) or dup(2)) refer to the same lock, and this lock may be modified or
+ * released using any of these descriptors. Furthermore, the lock is released
+ * either by an explicit LOCK_UN operation on any of these duplicate descriptors
+ * or when all such descriptors have been closed.
+ */
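+//
+// An illustrative sketch of that rule (not part of the suite): duplicated
+// descriptors share one lock, unlike independently opened descriptors:
+//
+//   flock(fd, LOCK_SH | LOCK_NB);      // granted
+//   int dup_fd = dup(fd);              // same open file description
+//   flock(dup_fd, LOCK_EX | LOCK_NB);  // upgrades the same lock
+//   flock(dup_fd, LOCK_UN);            // releases it for both descriptors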
+TEST_F(FlockTest, TestDupFdUpgrade) {
+  // This test will verify that a shared lock is upgradeable via a dupped
+  // file descriptor; if the FD weren't dupped, this would fail.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup());
+
+ // Now we should be able to upgrade via the dupped fd.
+ ASSERT_THAT(flock(dup_fd.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ // Validate unlock via dupped fd.
+ ASSERT_THAT(flock(dup_fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestDupFdDowngrade) {
+  // This test will verify that an exclusive lock is downgradable via a dupped
+  // file descriptor; if the FD weren't dupped, this would fail.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup());
+
+ // Now we should be able to downgrade via the dupped fd.
+ ASSERT_THAT(flock(dup_fd.get(), LOCK_SH | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ // Validate unlock via dupped fd
+ ASSERT_THAT(flock(dup_fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestDupFdCloseRelease) {
+ // flock(2): Furthermore, the lock is released either by an explicit LOCK_UN
+ // operation on any of these duplicate descriptors, or when all such
+ // descriptors have been closed.
+ //
+  // This test will verify that closing a dupped fd will not release the
+  // underlying lock until all such dupped fds have been closed.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup());
+
+  // At this point we have ONE exclusive lock referenced by two different fds.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // Validate that we cannot get a lock on a new unrelated FD.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Closing the dupped fd shouldn't affect the lock until all are closed.
+  dup_fd.reset();  // Close the dupped fd.
+
+ // Validate that we still cannot get a lock on a new unrelated FD.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+  // Close the first fd.
+ CloseFile(); // Will validate the syscall succeeds.
+
+ // Now we should actually be able to get a lock since all fds related to
+ // the first lock are closed.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0));
+
+ // Unlock.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestDupFdUnlockRelease) {
+ /* flock(2): Furthermore, the lock is released either by an explicit LOCK_UN
+ * operation on any of these duplicate descriptors, or when all such
+ * descriptors have been closed.
+ */
+  // This test will verify that an explicit unlock on a dupped FD will release
+  // the underlying lock, unlike the previous case where closing one dup was
+  // not enough to release the lock.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB),
+ SyscallSucceedsWithValue(0));
+
+ const FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup());
+
+  // At this point we have ONE exclusive lock referenced by two different fds.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // Validate that we cannot get a lock on a new unrelated FD.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Explicitly unlock via the dupped descriptor.
+ ASSERT_THAT(flock(dup_fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+
+ // Validate that we can now get the lock since we explicitly unlocked.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceedsWithValue(0));
+
+ // Unlock
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(FlockTest, TestDupFdFollowedByLock) {
+  // This test will verify that taking a lock on a file descriptor that has
+  // already been dupped means that the lock is shared between both. This is
+  // slightly different from duping an already-locked FD.
+ FileDescriptor dup_fd = ASSERT_NO_ERRNO_AND_VALUE(test_file_fd_.Dup());
+
+ // Take a lock.
+ ASSERT_THAT(flock(dup_fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+
+  // Now dup_fd and test_file_fd_ should both reference the same lock.
+ // We shouldn't be able to obtain a lock until both are closed.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+  // Close the dupped fd first; test_file_fd_ still references the lock.
+  dup_fd.reset();
+
+  // Validate that we cannot get a lock yet, because the original descriptor
+  // still holds the lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+  // Close the second fd.
+ CloseFile(); // CloseFile() will validate the syscall succeeds.
+
+ // Now we should be able to get the lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+
+ // Unlock.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceedsWithValue(0));
+}
+
+// NOTE: These blocking tests are not perfect. Unfortunately it's very hard to
+// determine if a thread was actually blocked in the kernel, so we're forced
+// to rely on timing.
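+//
+// In outline, each blocking test below uses this pattern (MonotonicTimer is
+// the test/util timer helper already used throughout this file):
+//
+//   MonotonicTimer timer;
+//   timer.Start();
+//   flock(fd, LOCK_EX);      // possibly-blocking call
+//   timer.Duration();        // a large duration implies the call blocked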
+TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks_NoRandomSave) {
+ // This test will verify that although LOCK_NB isn't specified
+ // two different fds can obtain shared locks without blocking.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds());
+
+ // kHoldLockTime is the amount of time we will hold the lock before releasing.
+ constexpr absl::Duration kHoldLockTime = absl::Seconds(30);
+
+ const DisableSave ds; // Timing-related.
+
+ // We do this in another thread so we can determine if it was actually
+ // blocked by timing the amount of time it took for the syscall to complete.
+ ScopedThread t([&] {
+ MonotonicTimer timer;
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+    // Only a single shared lock is held, so this lock will be granted
+    // immediately without any blocking. Don't save here to avoid wild
+    // discrepancies in timing.
+ timer.Start();
+ ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallSucceeds());
+
+    // The main thread holds its lock for 30 seconds, but this thread should
+    // not have blocked at all, so we expect the syscall to complete quickly.
+    ASSERT_LT(timer.Duration(),
+              absl::Seconds(1));  // 1 second is much less than the 30s hold.
+
+    // We can release our second shared lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds());
+ });
+
+ // Sleep before unlocking.
+ absl::SleepFor(kHoldLockTime);
+
+  // Release the first shared lock. Don't save in this situation to avoid
+  // discrepancies in timing.
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
+}
+
+TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive_NoRandomSave) {
+ // This test will verify that if someone holds a shared lock any attempt to
+ // obtain an exclusive lock will result in blocking.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds());
+
+ // kHoldLockTime is the amount of time we will hold the lock before releasing.
+ constexpr absl::Duration kHoldLockTime = absl::Seconds(2);
+
+ const DisableSave ds; // Timing-related.
+
+ // We do this in another thread so we can determine if it was actually
+ // blocked by timing the amount of time it took for the syscall to complete.
+ ScopedThread t([&] {
+ MonotonicTimer timer;
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // This exclusive lock should block because someone is already holding a
+    // shared lock. We don't save here to avoid wild discrepancies in timing.
+ timer.Start();
+ ASSERT_THAT(RetryEINTR(flock)(fd.get(), LOCK_EX), SyscallSucceeds());
+
+    // We should have blocked; expect the syscall to have taken more than 1s.
+ ASSERT_GT(timer.Duration(), absl::Seconds(1));
+
+ // We can release our exclusive lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds());
+ });
+
+ // Sleep before unlocking.
+ absl::SleepFor(kHoldLockTime);
+
+ // Release the shared lock allowing the thread to proceed.
+  // We don't save here to avoid wild discrepancies in timing.
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
+}
+
+TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared_NoRandomSave) {
+ // This test will verify that if someone holds an exclusive lock any attempt
+ // to obtain a shared lock will result in blocking.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds());
+
+ // kHoldLockTime is the amount of time we will hold the lock before releasing.
+ constexpr absl::Duration kHoldLockTime = absl::Seconds(2);
+
+ const DisableSave ds; // Timing-related.
+
+ // We do this in another thread so we can determine if it was actually
+ // blocked by timing the amount of time it took for the syscall to complete.
+ ScopedThread t([&] {
+ MonotonicTimer timer;
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // This shared lock should block because someone is already holding an
+    // exclusive lock. We don't save here to avoid wild discrepancies in
+    // timing.
+ timer.Start();
+ ASSERT_THAT(RetryEINTR(flock)(fd.get(), LOCK_SH), SyscallSucceeds());
+
+    // We should have blocked; expect the syscall to have taken more than 1s.
+ ASSERT_GT(timer.Duration(), absl::Seconds(1));
+
+ // We can release our shared lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds());
+ });
+
+ // Sleep before unlocking.
+ absl::SleepFor(kHoldLockTime);
+
+ // Release the exclusive lock allowing the blocked thread to proceed.
+  // We don't save here to avoid wild discrepancies in timing.
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
+}
+
+TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) {
+ // This test will verify that if someone holds an exclusive lock any attempt
+ // to obtain another exclusive lock will result in blocking.
+ ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds());
+
+ // kHoldLockTime is the amount of time we will hold the lock before releasing.
+ constexpr absl::Duration kHoldLockTime = absl::Seconds(2);
+
+ const DisableSave ds; // Timing-related.
+
+ // We do this in another thread so we can determine if it was actually
+ // blocked by timing the amount of time it took for the syscall to complete.
+ ScopedThread t([&] {
+ MonotonicTimer timer;
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ // This exclusive lock should block because someone is already holding an
+ // exclusive lock.
+ timer.Start();
+ ASSERT_THAT(RetryEINTR(flock)(fd.get(), LOCK_EX), SyscallSucceeds());
+
+    // We should have blocked; expect the syscall to have taken more than 1s.
+ ASSERT_GT(timer.Duration(), absl::Seconds(1));
+
+ // We can release our exclusive lock.
+ ASSERT_THAT(flock(fd.get(), LOCK_UN), SyscallSucceeds());
+ });
+
+ // Sleep before unlocking.
+ absl::SleepFor(kHoldLockTime);
+
+ // Release the exclusive lock allowing the blocked thread to proceed.
+  // We don't save to avoid wild discrepancies in timing.
+ EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, BadFD) {
+ // EBADF: fd is not an open file descriptor.
+ ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FlockTestNoFixture, FlockDir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockSymlink) {
+ // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH
+ // is supported.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path()));
+
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FlockTestNoFixture, FlockProc) {
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockPipe) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds());
+ // Check that the pipe was locked above.
+ EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EAGAIN));
+
+ EXPECT_THAT(flock(fds[0], LOCK_UN), SyscallSucceeds());
+ EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallSucceeds());
+
+ EXPECT_THAT(close(fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockSocket) {
+ int sock = socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX));
+ ASSERT_THAT(
+ bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(flock(sock, LOCK_EX | LOCK_NB), SyscallSucceeds());
+ EXPECT_THAT(close(sock), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc
new file mode 100644
index 000000000..853f6231a
--- /dev/null
+++ b/test/syscalls/linux/fork.cc
@@ -0,0 +1,464 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sched.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <cstdlib>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/capability_util.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::Ge;
+
+class ForkTest : public ::testing::Test {
+ protected:
+ // SetUp creates a populated, open file.
+ void SetUp() override {
+ // Make a shared mapping.
+ shared_ = reinterpret_cast<char*>(mmap(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_SHARED | MAP_ANONYMOUS, -1, 0));
+ ASSERT_NE(reinterpret_cast<void*>(shared_), MAP_FAILED);
+
+ // Make a private mapping.
+ private_ =
+ reinterpret_cast<char*>(mmap(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
+ ASSERT_NE(reinterpret_cast<void*>(private_), MAP_FAILED);
+
+ // Make a pipe.
+ ASSERT_THAT(pipe(pipes_), SyscallSucceeds());
+ }
+
+ // TearDown frees associated resources.
+ void TearDown() override {
+ EXPECT_THAT(munmap(shared_, kPageSize), SyscallSucceeds());
+ EXPECT_THAT(munmap(private_, kPageSize), SyscallSucceeds());
+ EXPECT_THAT(close(pipes_[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipes_[1]), SyscallSucceeds());
+ }
+
+ // Fork executes a clone system call.
+ pid_t Fork() {
+ pid_t pid = fork();
+ MaybeSave();
+ TEST_PCHECK_MSG(pid >= 0, "fork failed");
+ return pid;
+ }
+
+  // Wait waits for the given pid and returns the exit status. If the child
+  // was killed by a signal, 256+signal is returned; if wait4 itself fails,
+  // its (negative) return value is passed through.
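+  // For example, a child exiting with code 5 yields 5, while a child killed
+  // by SIGKILL (signal 9) yields 256 + 9 = 265, which cannot collide with any
+  // 8-bit exit code.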
+ int Wait(pid_t pid) {
+ int status;
+ while (true) {
+ int rval = wait4(pid, &status, 0, NULL);
+ if (rval < 0) {
+ return rval;
+ }
+ if (rval != pid) {
+ continue;
+ }
+ if (WIFEXITED(status)) {
+ return WEXITSTATUS(status);
+ }
+ if (WIFSIGNALED(status)) {
+ return 256 + WTERMSIG(status);
+ }
+ }
+ }
+
+  // Exit exits the process.
+ void Exit(int code) {
+ _exit(code);
+
+    // Should never reach here. Since the exit above failed, we really don't
+    // have much in the way of options to indicate failure, so we just try to
+    // record an assertion failure in the logs. The parent process will likely
+    // fail anyway if exit is not working.
+ TEST_CHECK_MSG(false, "_exit returned");
+ }
+
+ // ReadByte reads a byte from the shared pipe.
+ char ReadByte() {
+ char val = -1;
+ TEST_PCHECK(ReadFd(pipes_[0], &val, 1) == 1);
+ MaybeSave();
+ return val;
+ }
+
+  // WriteByte writes a byte to the shared pipe.
+ void WriteByte(char val) {
+ TEST_PCHECK(WriteFd(pipes_[1], &val, 1) == 1);
+ MaybeSave();
+ }
+
+ // Shared pipe.
+ int pipes_[2];
+
+ // Shared mapping (one page).
+ char* shared_;
+
+ // Private mapping (one page).
+ char* private_;
+};
+
+TEST_F(ForkTest, Simple) {
+ pid_t child = Fork();
+ if (child == 0) {
+ Exit(0);
+ }
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(ForkTest, ExitCode) {
+ pid_t child = Fork();
+ if (child == 0) {
+ Exit(123);
+ }
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(123));
+ child = Fork();
+ if (child == 0) {
+ Exit(1);
+ }
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(1));
+}
+
+TEST_F(ForkTest, Multi) {
+ pid_t child1 = Fork();
+ if (child1 == 0) {
+ Exit(0);
+ }
+ pid_t child2 = Fork();
+ if (child2 == 0) {
+ Exit(1);
+ }
+ EXPECT_THAT(Wait(child1), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(Wait(child2), SyscallSucceedsWithValue(1));
+}
+
+TEST_F(ForkTest, Pipe) {
+ pid_t child = Fork();
+ if (child == 0) {
+ WriteByte(1);
+ Exit(0);
+ }
+ EXPECT_EQ(ReadByte(), 1);
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(ForkTest, SharedMapping) {
+ pid_t child = Fork();
+ if (child == 0) {
+ // Wait for the parent.
+ ReadByte();
+ if (shared_[0] == 1) {
+ Exit(0);
+ }
+ // Failed.
+ Exit(1);
+ }
+ // Change the mapping.
+ ASSERT_EQ(shared_[0], 0);
+ shared_[0] = 1;
+ // Unblock the child.
+ WriteByte(0);
+ // Did it work?
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(ForkTest, PrivateMapping) {
+ pid_t child = Fork();
+ if (child == 0) {
+ // Wait for the parent.
+ ReadByte();
+ if (private_[0] == 0) {
+ Exit(0);
+ }
+ // Failed.
+ Exit(1);
+ }
+ // Change the mapping.
+ ASSERT_EQ(private_[0], 0);
+ private_[0] = 1;
+ // Unblock the child.
+ WriteByte(0);
+ // Did it work?
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+// CPUID is x86 specific.
+#ifdef __x86_64__
+// Test that cpuid works after a fork.
+TEST_F(ForkTest, Cpuid) {
+ pid_t child = Fork();
+
+ // We should be able to determine the CPU vendor.
+ ASSERT_NE(GetCPUVendor(), CPUVendor::kUnknownVendor);
+
+ if (child == 0) {
+ Exit(0);
+ }
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+#endif
+
+TEST_F(ForkTest, Mmap) {
+ pid_t child = Fork();
+
+ if (child == 0) {
+ void* addr =
+ mmap(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ MaybeSave();
+ Exit(addr == MAP_FAILED);
+ }
+
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+static volatile int alarmed = 0;
+
+void AlarmHandler(int sig, siginfo_t* info, void* context) { alarmed = 1; }
+
+TEST_F(ForkTest, Alarm) {
+  // Set up an alarm handler.
+ struct sigaction sa;
+ sa.sa_sigaction = AlarmHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ EXPECT_THAT(sigaction(SIGALRM, &sa, nullptr), SyscallSucceeds());
+
+ pid_t child = Fork();
+
+ if (child == 0) {
+ alarm(1);
+ sleep(3);
+ if (!alarmed) {
+ Exit(1);
+ }
+ Exit(0);
+ }
+
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, alarmed);
+}
+
+// Child cannot affect parent private memory. Regression test for b/24137240.
+TEST_F(ForkTest, PrivateMemory) {
+ std::atomic<uint32_t> local(0);
+
+ pid_t child1 = Fork();
+ if (child1 == 0) {
+ local++;
+
+ pid_t child2 = Fork();
+ if (child2 == 0) {
+ local++;
+
+ TEST_CHECK(local.load() == 2);
+
+ Exit(0);
+ }
+
+ TEST_PCHECK(Wait(child2) == 0);
+ TEST_CHECK(local.load() == 1);
+ Exit(0);
+ }
+
+ EXPECT_THAT(Wait(child1), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, local.load());
+}
+
+// Kernel-accessed buffers should remain coherent across COW.
+//
+// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates
+// differently. Regression test for b/33811887.
+TEST_F(ForkTest, COWSegment) {
+ constexpr int kBufSize = 1024;
+ char* read_buf = private_;
+ char* touch = private_ + kPageSize / 2;
+
+ std::string contents(kBufSize, 'a');
+
+ ScopedThread t([&] {
+ // Wait to be sure the parent is blocked in read.
+ absl::SleepFor(absl::Seconds(3));
+
+ // Fork to mark private pages for COW.
+ //
+ // Use fork directly rather than the Fork wrapper to skip the multi-threaded
+ // check, and limit the child to async-signal-safe functions:
+ //
+ // "After a fork() in a multithreaded program, the child can safely call
+ // only async-signal-safe functions (see signal(7)) until such time as it
+ // calls execve(2)."
+ //
+ // Skip ASSERT in the child, as it isn't async-signal-safe.
+ pid_t child = fork();
+ if (child == 0) {
+ // Wait to be sure parent touched memory.
+ sleep(3);
+ Exit(0);
+ }
+
+ // Check success only in the parent.
+ ASSERT_THAT(child, SyscallSucceedsWithValue(Ge(0)));
+
+ // Trigger COW on private page.
+ *touch = 42;
+
+ // Write to pipe. Parent should still be able to read this.
+ EXPECT_THAT(WriteFd(pipes_[1], contents.c_str(), kBufSize),
+ SyscallSucceedsWithValue(kBufSize));
+
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+ });
+
+ EXPECT_THAT(ReadFd(pipes_[0], read_buf, kBufSize),
+ SyscallSucceedsWithValue(kBufSize));
+ EXPECT_STREQ(contents.c_str(), read_buf);
+}
+
+TEST_F(ForkTest, SigAltStack) {
+ std::vector<char> stack_mem(SIGSTKSZ);
+ stack_t stack = {};
+ stack.ss_size = SIGSTKSZ;
+ stack.ss_sp = stack_mem.data();
+ ASSERT_THAT(sigaltstack(&stack, nullptr), SyscallSucceeds());
+
+ pid_t child = Fork();
+
+ if (child == 0) {
+ stack_t oss = {};
+ TEST_PCHECK(sigaltstack(nullptr, &oss) == 0);
+ MaybeSave();
+
+ TEST_CHECK((oss.ss_flags & SS_DISABLE) == 0);
+ TEST_CHECK(oss.ss_size == SIGSTKSZ);
+ TEST_CHECK(oss.ss_sp == stack.ss_sp);
+
+ Exit(0);
+ }
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(ForkTest, Affinity) {
+ // Make a non-default cpumask.
+ cpu_set_t parent_mask;
+ EXPECT_THAT(sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &parent_mask),
+ SyscallSucceeds());
+ // Knock out the lowest bit.
+ for (unsigned int n = 0; n < CPU_SETSIZE; n++) {
+ if (CPU_ISSET(n, &parent_mask)) {
+ CPU_CLR(n, &parent_mask);
+ break;
+ }
+ }
+ EXPECT_THAT(sched_setaffinity(/*pid=*/0, sizeof(cpu_set_t), &parent_mask),
+ SyscallSucceeds());
+
+ pid_t child = Fork();
+ if (child == 0) {
+ cpu_set_t child_mask;
+
+ int ret = sched_getaffinity(/*pid=*/0, sizeof(cpu_set_t), &child_mask);
+ MaybeSave();
+ if (ret < 0) {
+ Exit(-ret);
+ }
+
+ TEST_CHECK(CPU_EQUAL(&child_mask, &parent_mask));
+
+ Exit(0);
+ }
+
+ EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
+}
+
+TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) {
+ // "If CLONE_NEWUSER is specified along with other CLONE_NEW* flags in a
+ // single clone(2) or unshare(2) call, the user namespace is guaranteed to be
+ // created first, giving the child (clone(2)) or caller (unshare(2))
+ // privileges over the remaining namespaces created by the call. Thus, it is
+ // possible for an unprivileged caller to specify this combination of flags."
+ // - user_namespaces(7)
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ Mapping child_stack = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ int child_pid;
+ // We only test with CLONE_NEWIPC, CLONE_NEWNET, and CLONE_NEWUTS since these
+ // namespaces were implemented in Linux before user namespaces.
+ ASSERT_THAT(
+ child_pid = clone(
+ +[](void*) { return 0; },
+ reinterpret_cast<void*>(child_stack.addr() + kPageSize),
+ CLONE_NEWUSER | CLONE_NEWIPC | CLONE_NEWNET | CLONE_NEWUTS | SIGCHLD,
+ /* arg = */ nullptr),
+ SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status = " << status;
+}
+
+// Clone with CLONE_SETTLS and a non-canonical TLS address is rejected.
+TEST(CloneTest, NonCanonicalTLS) {
+ constexpr uintptr_t kNonCanonical = 1ull << 48;
+
+ // We need a valid address for the stack pointer. We'll never actually execute
+ // on this.
+ char stack;
+
+ // The raw system call interface on x86-64 is:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, int *child_tid,
+ // unsigned long tls);
+ //
+ // While on arm64, the order of the last two arguments is reversed:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, unsigned long tls,
+ // int *child_tid);
+#if defined(__x86_64__)
+ EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
+ nullptr, kNonCanonical),
+ SyscallFailsWithErrno(EPERM));
+#elif defined(__aarch64__)
+ EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
+ kNonCanonical, nullptr),
+ SyscallFailsWithErrno(EPERM));
+#endif
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc
new file mode 100644
index 000000000..c47567b4e
--- /dev/null
+++ b/test/syscalls/linux/fpsig_fork.cc
@@ -0,0 +1,131 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This test verifies that calling fork(2) in a signal handler does not
+// prevent floating point state from being correctly restored, in both the
+// parent and the child, when the signal handler returns.
+#include <sys/time.h>
+
+#include "gtest/gtest.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifdef __x86_64__
+#define GET_XMM(__var, __xmm) \
+ asm volatile("movq %%" #__xmm ", %0" : "=r"(__var))
+#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var))
+#define GET_FP0(__var) GET_XMM(__var, xmm0)
+#define SET_FP0(__var) SET_XMM(__var, xmm0)
+#elif __aarch64__
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
+#define GET_FPREG(var, regname) \
+ asm volatile("str " __stringify(regname) ", %0" : "=m"(var))
+#define SET_FPREG(var, regname) \
+ asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var))
+#define GET_FP0(var) GET_FPREG(var, d0)
+#define SET_FP0(var) SET_FPREG(var, d0)
+#endif
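+
+// Note that movq (x86-64) and "str dN" (aarch64) move only the low 64 bits of
+// the vector register; a single clobber-detectable 64-bit lane is all these
+// tests need.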
+
+int parent, child;
+
+void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
+ // Fork and clobber %xmm0. The fpstate should be restored by sigreturn(2)
+ // in both parent and child.
+ child = fork();
+ TEST_CHECK_MSG(child >= 0, "fork failed");
+
+ uint64_t val = SIGUSR1;
+ SET_FP0(val);
+ uint64_t got;
+ GET_FP0(got);
+ TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()");
+}
+
+TEST(FPSigTest, Fork) {
+ parent = getpid();
+ pid_t parent_tid = gettid();
+
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ sa.sa_sigaction = sigusr1;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ // The amd64 ABI specifies that the XMM register set is caller-saved. This
+  // implies that if there is any function call between SET_XMM and GET_XMM,
+  // the compiler might save/restore xmm0 implicitly. This defeats the entire
+  // purpose of the test, which is to verify that fpstate is restored by
+  // sigreturn(2).
+ //
+ // This is the reason why 'tgkill(getpid(), gettid(), SIGUSR1)' is implemented
+ // in inline assembly below.
+ //
+ // If the OS is broken and registers are clobbered by the child, using tgkill
+ // to signal the current thread increases the likelihood that this thread will
+ // be the one clobbered.
+
+ uint64_t expected = 0xdeadbeeffacefeed;
+ SET_FP0(expected);
+
+#ifdef __x86_64__
+ asm volatile(
+ "movl %[killnr], %%eax;"
+ "movl %[parent], %%edi;"
+ "movl %[tid], %%esi;"
+ "movl %[sig], %%edx;"
+ "syscall;"
+ :
+ : [ killnr ] "i"(__NR_tgkill), [ parent ] "rm"(parent),
+ [ tid ] "rm"(parent_tid), [ sig ] "i"(SIGUSR1)
+ : "rax", "rdi", "rsi", "rdx",
+ // Clobbered by syscall.
+ "rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(parent), "r"(parent_tid), "r"(SIGUSR1));
+#endif
+
+ uint64_t got;
+ GET_FP0(got);
+
+ if (getpid() == parent) { // Parent.
+ int status;
+ ASSERT_THAT(waitpid(child, &status, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ }
+
+ // TEST_CHECK_MSG since this may run in the child.
+ TEST_CHECK_MSG(expected == got, "Bad xmm0 value");
+
+ if (getpid() != parent) { // Child.
+ _exit(0);
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fpsig_nested.cc b/test/syscalls/linux/fpsig_nested.cc
new file mode 100644
index 000000000..302d928d1
--- /dev/null
+++ b/test/syscalls/linux/fpsig_nested.cc
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This program verifies that application floating point state is restored
+// correctly after a signal handler returns. It also verifies that this works
+// with nested signals.
+#include <sys/time.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifdef __x86_64__
+#define GET_XMM(__var, __xmm) \
+ asm volatile("movq %%" #__xmm ", %0" : "=r"(__var))
+#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var))
+#define GET_FP0(__var) GET_XMM(__var, xmm0)
+#define SET_FP0(__var) SET_XMM(__var, xmm0)
+#elif __aarch64__
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
+#define GET_FPREG(var, regname) \
+ asm volatile("str " __stringify(regname) ", %0" : "=m"(var))
+#define SET_FPREG(var, regname) \
+ asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var))
+#define GET_FP0(var) GET_FPREG(var, d0)
+#define SET_FP0(var) SET_FPREG(var, d0)
+#endif
+
+int pid;
+int tid;
+
+volatile uint64_t entryxmm[2] = {~0UL, ~0UL};
+volatile uint64_t exitxmm[2];
+
+void sigusr2(int s, siginfo_t* siginfo, void* _uc) {
+ uint64_t val = SIGUSR2;
+
+ // Record the value of %xmm0 on entry and then clobber it.
+ GET_FP0(entryxmm[1]);
+ SET_FP0(val);
+ GET_FP0(exitxmm[1]);
+}
+
+void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
+ uint64_t val = SIGUSR1;
+
+ // Record the value of %xmm0 on entry and then clobber it.
+ GET_FP0(entryxmm[0]);
+ SET_FP0(val);
+
+  // Send a SIGUSR2 to ourselves. The signal mask is configured such that
+ // the SIGUSR2 handler will run before this handler returns.
+#ifdef __x86_64__
+ asm volatile(
+ "movl %[killnr], %%eax;"
+ "movl %[pid], %%edi;"
+ "movl %[tid], %%esi;"
+ "movl %[sig], %%edx;"
+ "syscall;"
+ :
+ : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid),
+ [ sig ] "i"(SIGUSR2)
+ : "rax", "rdi", "rsi", "rdx",
+ // Clobbered by syscall.
+ "rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(pid), "r"(tid), "r"(SIGUSR2));
+#endif
+
+ // Record value of %xmm0 again to verify that the nested signal handler
+ // does not clobber it.
+ GET_FP0(exitxmm[0]);
+}
+
+TEST(FPSigTest, NestedSignals) {
+ pid = getpid();
+ tid = gettid();
+
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ sa.sa_sigaction = sigusr1;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ sa.sa_sigaction = sigusr2;
+ ASSERT_THAT(sigaction(SIGUSR2, &sa, nullptr), SyscallSucceeds());
+
+ // The amd64 ABI specifies that the XMM register set is caller-saved. This
+  // implies that if there is any function call between SET_XMM and GET_XMM,
+  // the compiler might save/restore xmm0 implicitly. This defeats the entire
+  // purpose of the test, which is to verify that fpstate is restored by
+  // sigreturn(2).
+ //
+ // This is the reason why 'tgkill(getpid(), gettid(), SIGUSR1)' is implemented
+ // in inline assembly below.
+ //
+ // If the OS is broken and registers are clobbered by the signal, using tgkill
+ // to signal the current thread ensures that this is the clobbered thread.
+
+ uint64_t expected = 0xdeadbeeffacefeed;
+ SET_FP0(expected);
+
+#ifdef __x86_64__
+ asm volatile(
+ "movl %[killnr], %%eax;"
+ "movl %[pid], %%edi;"
+ "movl %[tid], %%esi;"
+ "movl %[sig], %%edx;"
+ "syscall;"
+ :
+ : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid),
+ [ sig ] "i"(SIGUSR1)
+ : "rax", "rdi", "rsi", "rdx",
+ // Clobbered by syscall.
+ "rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(pid), "r"(tid), "r"(SIGUSR1));
+#endif
+
+ uint64_t got;
+ GET_FP0(got);
+
+ //
+  // The checks below verify the following:
+  // - signal handlers must be called with a clean FPU state.
+  // - sigreturn(2) must restore the fpstate of the interrupted context.
+ //
+ EXPECT_EQ(expected, got);
+ EXPECT_EQ(entryxmm[0], 0);
+ EXPECT_EQ(entryxmm[1], 0);
+ EXPECT_EQ(exitxmm[0], SIGUSR1);
+ EXPECT_EQ(exitxmm[1], SIGUSR2);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fsync.cc b/test/syscalls/linux/fsync.cc
new file mode 100644
index 000000000..e7e057f06
--- /dev/null
+++ b/test/syscalls/linux/fsync.cc
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(FsyncTest, TempFileSucceeds) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+ const std::string data = "some data to sync";
+ EXPECT_THAT(write(fd.get(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(data.size()));
+ EXPECT_THAT(fsync(fd.get()), SyscallSucceeds());
+}
+
+TEST(FsyncTest, TempDirSucceeds) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+ EXPECT_THAT(fsync(fd.get()), SyscallSucceeds());
+}
+
+TEST(FsyncTest, CannotFsyncOnUnopenedFd) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ int fd;
+ ASSERT_THAT(fd = open(file.path().c_str(), O_RDONLY), SyscallSucceeds());
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+
+ // fd is now invalid.
+ EXPECT_THAT(fsync(fd), SyscallFailsWithErrno(EBADF));
+}
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc
new file mode 100644
index 000000000..40c80a6e1
--- /dev/null
+++ b/test/syscalls/linux/futex.cc
@@ -0,0 +1,742 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <linux/futex.h>
+#include <linux/types.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <atomic>
+#include <memory>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/memory_util.h"
+#include "test/util/save_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/time_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Amount of time we wait for threads doing futex_wait to start running before
+// doing futex_wake.
+constexpr auto kWaiterStartupDelay = absl::Seconds(3);
+
+// Default timeout for waiters in tests where we expect a futex_wake to be
+// ineffective.
+constexpr auto kIneffectiveWakeTimeout = absl::Seconds(6);
+
+static_assert(kWaiterStartupDelay < kIneffectiveWakeTimeout,
+ "futex_wait will time out before futex_wake is called");
+
+int futex_wait(bool priv, std::atomic<int>* uaddr, int val,
+ absl::Duration timeout = absl::InfiniteDuration()) {
+ int op = FUTEX_WAIT;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+
+ if (timeout == absl::InfiniteDuration()) {
+ return RetryEINTR(syscall)(SYS_futex, uaddr, op, val, nullptr);
+ }
+
+  // FUTEX_WAIT doesn't adjust the timeout if it returns EINTR, so we have to
+  // do so ourselves on each retry.
+ while (true) {
+ auto const timeout_ts = absl::ToTimespec(timeout);
+ MonotonicTimer timer;
+ timer.Start();
+ int const ret = syscall(SYS_futex, uaddr, op, val, &timeout_ts);
+ if (ret != -1 || errno != EINTR) {
+ return ret;
+ }
+ timeout = std::max(timeout - timer.Duration(), absl::ZeroDuration());
+ }
+}
+
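+// Unlike FUTEX_WAIT, FUTEX_WAIT_BITSET takes an absolute deadline rather than
+// a relative timeout (FUTEX_CLOCK_REALTIME measures it against the realtime
+// clock), so there is nothing to adjust when retrying after EINTR.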
+int futex_wait_bitset(bool priv, std::atomic<int>* uaddr, int val, int bitset,
+ absl::Time deadline = absl::InfiniteFuture()) {
+ int op = FUTEX_WAIT_BITSET | FUTEX_CLOCK_REALTIME;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+
+ auto const deadline_ts = absl::ToTimespec(deadline);
+ return RetryEINTR(syscall)(
+ SYS_futex, uaddr, op, val,
+ deadline == absl::InfiniteFuture() ? nullptr : &deadline_ts, nullptr,
+ bitset);
+}
+
+int futex_wake(bool priv, std::atomic<int>* uaddr, int count) {
+ int op = FUTEX_WAKE;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+ return syscall(SYS_futex, uaddr, op, count);
+}
+
+int futex_wake_bitset(bool priv, std::atomic<int>* uaddr, int count,
+ int bitset) {
+ int op = FUTEX_WAKE_BITSET;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+ return syscall(SYS_futex, uaddr, op, count, nullptr, nullptr, bitset);
+}
+
+int futex_wake_op(bool priv, std::atomic<int>* uaddr1, std::atomic<int>* uaddr2,
+ int nwake1, int nwake2, uint32_t sub_op) {
+ int op = FUTEX_WAKE_OP;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+ return syscall(SYS_futex, uaddr1, op, nwake1, nwake2, uaddr2, sub_op);
+}
+
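+// For PI futexes the kernel prescribes the futex word's format: 0 means
+// unlocked, and the TID of the owning thread means locked. The helpers below
+// mirror a PI mutex's fast path, attempting the uncontended transition with a
+// compare-and-swap in userspace and entering the kernel only on contention.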
+int futex_lock_pi(bool priv, std::atomic<int>* uaddr) {
+ int op = FUTEX_LOCK_PI;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+ int zero = 0;
+ if (uaddr->compare_exchange_strong(zero, gettid())) {
+ return 0;
+ }
+ return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
+}
+
+int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) {
+ int op = FUTEX_TRYLOCK_PI;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+ int zero = 0;
+ if (uaddr->compare_exchange_strong(zero, gettid())) {
+ return 0;
+ }
+ return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
+}
+
+int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) {
+ int op = FUTEX_UNLOCK_PI;
+ if (priv) {
+ op |= FUTEX_PRIVATE_FLAG;
+ }
+ int tid = gettid();
+ if (uaddr->compare_exchange_strong(tid, 0)) {
+ return 0;
+ }
+ return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
+}
+
+// Fixture for futex tests parameterized by whether to use private or shared
+// futexes.
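+//
+// FUTEX_PRIVATE_FLAG asserts that the futex word is not shared with another
+// process, which lets the kernel key the wait queue by (mm, virtual address)
+// instead of by the backing inode; wait/wake semantics are otherwise
+// identical, so every test runs both ways.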
+class PrivateAndSharedFutexTest : public ::testing::TestWithParam<bool> {
+ protected:
+ bool IsPrivate() const { return GetParam(); }
+ int PrivateFlag() const { return IsPrivate() ? FUTEX_PRIVATE_FLAG : 0; }
+};
+
+// FUTEX_WAIT with 0 timeout does not block.
+TEST_P(PrivateAndSharedFutexTest, Wait_ZeroTimeout) {
+ struct timespec timeout = {};
+
+ // Don't use the futex_wait helper because it adjusts timeout.
+ int a = 1;
+ EXPECT_THAT(syscall(SYS_futex, &a, FUTEX_WAIT | PrivateFlag(), a, &timeout),
+ SyscallFailsWithErrno(ETIMEDOUT));
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wait_Timeout) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+
+ MonotonicTimer timer;
+ timer.Start();
+ constexpr absl::Duration kTimeout = absl::Seconds(1);
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, a, kTimeout),
+ SyscallFailsWithErrno(ETIMEDOUT));
+ EXPECT_GE(timer.Duration(), kTimeout);
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wait_BitsetTimeout) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+
+ MonotonicTimer timer;
+ timer.Start();
+ constexpr absl::Duration kTimeout = absl::Seconds(1);
+ EXPECT_THAT(
+ futex_wait_bitset(IsPrivate(), &a, a, 0xffffffff, absl::Now() + kTimeout),
+ SyscallFailsWithErrno(ETIMEDOUT));
+ EXPECT_GE(timer.Duration(), kTimeout);
+}
+
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_NegativeTimeout) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+
+ MonotonicTimer timer;
+ timer.Start();
+ EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, a, 0xffffffff,
+ absl::Now() - absl::Seconds(1)),
+ SyscallFailsWithErrno(ETIMEDOUT));
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wait_WrongVal) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, a + 1),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wait_ZeroBitset) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+ EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, a, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+  // Prevent save/restore from interrupting futex_wait, which would cause it
+  // to return EAGAIN instead of the expected result if futex_wait is
+  // restarted after we change the value of a below.
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue),
+ SyscallSucceedsWithValue(0));
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ // Change a so that if futex_wake happens before futex_wait, the latter
+ // returns EAGAIN instead of hanging the test.
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1));
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+  // Prevent save/restore from interrupting futex_wait, which would cause it
+  // to return EAGAIN instead of the expected result if futex_wait is
+  // restarted after we change the value of a below.
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue),
+ SyscallSucceedsWithValue(0));
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ // Change a so that if futex_wake happens before futex_wait, the latter
+ // returns EAGAIN instead of hanging the test.
+ a.fetch_add(1);
+ // The Linux kernel wakes one waiter even if val is 0 or negative.
+ EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1));
+}
+
+TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ constexpr int kThreads = 5;
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+ threads.reserve(kThreads);
+ for (int i = 0; i < kThreads; i++) {
+ threads.push_back(absl::make_unique<ScopedThread>([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue),
+ SyscallSucceeds());
+ }));
+ }
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake(IsPrivate(), &a, kThreads),
+ SyscallSucceedsWithValue(kThreads));
+}
+
+TEST_P(PrivateAndSharedFutexTest, WakeSome_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ constexpr int kThreads = 5;
+ constexpr int kWokenThreads = 3;
+ static_assert(kWokenThreads < kThreads,
+ "can't wake more threads than are created");
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+ threads.reserve(kThreads);
+ std::vector<int> rets;
+ rets.reserve(kThreads);
+ std::vector<int> errs;
+ errs.reserve(kThreads);
+ for (int i = 0; i < kThreads; i++) {
+ rets.push_back(-1);
+ errs.push_back(0);
+ }
+ for (int i = 0; i < kThreads; i++) {
+ threads.push_back(absl::make_unique<ScopedThread>([&, i] {
+ rets[i] =
+ futex_wait(IsPrivate(), &a, kInitialValue, kIneffectiveWakeTimeout);
+ errs[i] = errno;
+ }));
+ }
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake(IsPrivate(), &a, kWokenThreads),
+ SyscallSucceedsWithValue(kWokenThreads));
+
+ int woken = 0;
+ int timedout = 0;
+ for (int i = 0; i < kThreads; i++) {
+ threads[i]->Join();
+ if (rets[i] == 0) {
+ woken++;
+ } else if (errs[i] == ETIMEDOUT) {
+ timedout++;
+ } else {
+ ADD_FAILURE() << " thread " << i << ": returned " << rets[i] << ", errno "
+ << errs[i];
+ }
+ }
+ EXPECT_EQ(woken, kWokenThreads);
+ EXPECT_EQ(timedout, kThreads - kWokenThreads);
+}
+
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, kInitialValue, 0b01001000),
+ SyscallSucceeds());
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1));
+}
+
+TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), SyscallSucceeds());
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake_bitset(IsPrivate(), &a, 1, 0b01001000),
+ SyscallSucceedsWithValue(1));
+}
+
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ constexpr int kBitset = 0b01001000;
+
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, kInitialValue, kBitset),
+ SyscallSucceeds());
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake_bitset(IsPrivate(), &a, 1, kBitset),
+ SyscallSucceedsWithValue(1));
+}
+
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ constexpr int kWaitBitset = 0b01000001;
+ constexpr int kWakeBitset = 0b00101000;
+ static_assert((kWaitBitset & kWakeBitset) == 0,
+ "futex_wake_bitset will wake waiter");
+
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait_bitset(IsPrivate(), &a, kInitialValue, kWaitBitset,
+ absl::Now() + kIneffectiveWakeTimeout),
+ SyscallFailsWithErrno(ETIMEDOUT));
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ EXPECT_THAT(futex_wake_bitset(IsPrivate(), &a, 1, kWakeBitset),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+ std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ ScopedThread thread_a([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), SyscallSucceeds());
+ });
+ ScopedThread thread_b([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &b, kInitialValue), SyscallSucceeds());
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ b.fetch_add(1);
+ // This futex_wake_op should:
+ // - Wake 1 waiter on a unconditionally.
+ // - Wake 1 waiter on b if b == kInitialValue + 1, which it is.
+ // - Do "b += 1".
+ EXPECT_THAT(futex_wake_op(IsPrivate(), &a, &b, 1, 1,
+ FUTEX_OP(FUTEX_OP_ADD, 1, FUTEX_OP_CMP_EQ,
+ (kInitialValue + 1))),
+ SyscallSucceedsWithValue(2));
+ EXPECT_EQ(b, kInitialValue + 2);
+}
+
+TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+ std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ ScopedThread thread_a([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), SyscallSucceeds());
+ });
+ ScopedThread thread_b([&] {
+ EXPECT_THAT(
+ futex_wait(IsPrivate(), &b, kInitialValue, kIneffectiveWakeTimeout),
+ SyscallFailsWithErrno(ETIMEDOUT));
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ b.fetch_add(1);
+ // This futex_wake_op should:
+ // - Wake 1 waiter on a unconditionally.
+ // - Wake 1 waiter on b if b == kInitialValue - 1, which it isn't.
+ // - Do "b += 1".
+ EXPECT_THAT(futex_wake_op(IsPrivate(), &a, &b, 1, 1,
+ FUTEX_OP(FUTEX_OP_ADD, 1, FUTEX_OP_CMP_EQ,
+ (kInitialValue - 1))),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(b, kInitialValue + 2);
+}
+
+TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon_NoRandomSave) {
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr());
+ constexpr int kInitialValue = 1;
+ ptr->store(kInitialValue);
+
+ DisableSave ds;
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(futex_wait(IsPrivate(), ptr, kInitialValue,
+ kIneffectiveWakeTimeout) == -1 &&
+ errno == ETIMEDOUT);
+ _exit(0);
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+ absl::SleepFor(kWaiterStartupDelay);
+
+ EXPECT_THAT(futex_wake(IsPrivate(), ptr, 1), SyscallSucceedsWithValue(0));
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak_NoRandomSave) {
+ // Use a futex on a non-stack mapping so we can be sure that the child process
+ // below isn't the one that breaks copy-on-write.
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr());
+ constexpr int kInitialValue = 1;
+ ptr->store(kInitialValue);
+
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), ptr, kInitialValue), SyscallSucceeds());
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Wait to be killed by the parent.
+ while (true) pause();
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+ auto cleanup_child = Cleanup([&] {
+ EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << " status " << status;
+ });
+
+ // In addition to preventing a late futex_wait from sleeping, this breaks
+ // copy-on-write on the mapped page.
+ ptr->fetch_add(1);
+ EXPECT_THAT(futex_wake(IsPrivate(), ptr, 1), SyscallSucceedsWithValue(1));
+}
+
+TEST_P(PrivateAndSharedFutexTest, WakeWrongKind_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(
+ futex_wait(IsPrivate(), &a, kInitialValue, kIneffectiveWakeTimeout),
+ SyscallFailsWithErrno(ETIMEDOUT));
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ a.fetch_add(1);
+ // The value of priv passed to futex_wake is the opposite of that passed to
+  // futex_wait; we expect this not to wake the waiter.
+ EXPECT_THAT(futex_wake(!IsPrivate(), &a, 1), SyscallSucceedsWithValue(0));
+}
+
+INSTANTIATE_TEST_SUITE_P(SharedPrivate, PrivateAndSharedFutexTest,
+ ::testing::Bool());
+
+// Passing null as the (first) address only works for private futexes, where
+// the address serves purely as a hash key and is never dereferenced by wake
+// operations.
+
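+// The WakeOp0* tests below request zero wakeups on both futexes; they exist
+// purely to observe the arithmetic side effect that FUTEX_WAKE_OP applies to
+// the second futex word.
+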
+TEST(PrivateFutexTest, WakeOp0Set) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+
+ int futex_op = FUTEX_OP(FUTEX_OP_SET, 2, 0, 0);
+ EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(a, 2);
+}
+
+TEST(PrivateFutexTest, WakeOp0Add) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(1);
+ int futex_op = FUTEX_OP(FUTEX_OP_ADD, 1, 0, 0);
+ EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(a, 2);
+}
+
+TEST(PrivateFutexTest, WakeOp0Or) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(0b01);
+ int futex_op = FUTEX_OP(FUTEX_OP_OR, 0b10, 0, 0);
+ EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(a, 0b11);
+}
+
+TEST(PrivateFutexTest, WakeOp0Andn) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(0b11);
+ int futex_op = FUTEX_OP(FUTEX_OP_ANDN, 0b10, 0, 0);
+ EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(a, 0b01);
+}
+
+TEST(PrivateFutexTest, WakeOp0Xor) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(0b1010);
+ int futex_op = FUTEX_OP(FUTEX_OP_XOR, 0b1100, 0, 0);
+ EXPECT_THAT(futex_wake_op(true, nullptr, &a, 0, 0, futex_op),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(a, 0b0110);
+}
+
+TEST(SharedFutexTest, WakeInterprocessSharedAnon_NoRandomSave) {
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED));
+ auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr());
+ constexpr int kInitialValue = 1;
+ ptr->store(kInitialValue);
+
+ DisableSave ds;
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(futex_wait(false, ptr, kInitialValue) == 0);
+ _exit(0);
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+ auto kill_child = Cleanup(
+ [&] { EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ ptr->fetch_add(1);
+ // This is an ASSERT so that if it fails, we immediately abort the test (and
+ // kill the subprocess).
+ ASSERT_THAT(futex_wake(false, ptr, 1), SyscallSucceedsWithValue(1));
+
+ kill_child.Release();
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST(SharedFutexTest, WakeInterprocessFile_NoRandomSave) {
+ auto const file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_THAT(truncate(file.path().c_str(), kPageSize), SyscallSucceeds());
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0));
+ auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr());
+ constexpr int kInitialValue = 1;
+ ptr->store(kInitialValue);
+
+ DisableSave ds;
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(futex_wait(false, ptr, kInitialValue) == 0);
+ _exit(0);
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+ auto kill_child = Cleanup(
+ [&] { EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds()); });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ ptr->fetch_add(1);
+ // This is an ASSERT so that if it fails, we immediately abort the test (and
+ // kill the subprocess).
+ ASSERT_THAT(futex_wake(false, ptr, 1), SyscallSucceedsWithValue(1));
+
+ kill_child.Release();
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST_P(PrivateAndSharedFutexTest, PIBasic) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(0);
+
+ ASSERT_THAT(futex_lock_pi(IsPrivate(), &a), SyscallSucceeds());
+ EXPECT_EQ(a.load(), gettid());
+ EXPECT_THAT(futex_lock_pi(IsPrivate(), &a), SyscallFailsWithErrno(EDEADLK));
+
+ ASSERT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallSucceeds());
+ EXPECT_EQ(a.load(), 0);
+ EXPECT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallFailsWithErrno(EPERM));
+}
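+
+// A note on the PI futex word layout relied on above and below (futex(2)): an
+// unlocked PI futex contains 0; a locked one contains the owner's TID
+// (FUTEX_TID_MASK = 0x3fffffff), with FUTEX_WAITERS (0x80000000) set once
+// another thread blocks on it. This is why PIBasic compares the word directly
+// against gettid(), and why PIWaiters polls for (FUTEX_WAITERS | gettid()).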
+
+TEST_P(PrivateAndSharedFutexTest, PIConcurrency_NoRandomSave) {
+ DisableSave ds; // Too many syscalls.
+
+ std::atomic<int> a = ATOMIC_VAR_INIT(0);
+ const bool is_priv = IsPrivate();
+
+ std::unique_ptr<ScopedThread> threads[100];
+ for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) {
+ threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] {
+ for (size_t j = 0; j < 10; ++j) {
+ ASSERT_THAT(futex_lock_pi(is_priv, &a), SyscallSucceeds());
+ EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid());
+ SleepSafe(absl::Milliseconds(5));
+ ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds());
+ }
+ });
+ }
+}
+
+TEST_P(PrivateAndSharedFutexTest, PIWaiters) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(0);
+ const bool is_priv = IsPrivate();
+
+ ASSERT_THAT(futex_lock_pi(is_priv, &a), SyscallSucceeds());
+ EXPECT_EQ(a.load(), gettid());
+
+ ScopedThread th([is_priv, &a] {
+ ASSERT_THAT(futex_lock_pi(is_priv, &a), SyscallSucceeds());
+ ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds());
+ });
+
+ // Wait until the thread blocks on the futex, setting the waiters bit.
+ auto start = absl::Now();
+ while (a.load() != (FUTEX_WAITERS | gettid())) {
+ ASSERT_LT(absl::Now() - start, absl::Seconds(5));
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+ ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds());
+}
+
+TEST_P(PrivateAndSharedFutexTest, PITryLock) {
+ std::atomic<int> a = ATOMIC_VAR_INIT(0);
+ const bool is_priv = IsPrivate();
+
+  ASSERT_THAT(futex_trylock_pi(is_priv, &a), SyscallSucceeds());
+ EXPECT_EQ(a.load(), gettid());
+
+ EXPECT_THAT(futex_trylock_pi(is_priv, &a), SyscallFailsWithErrno(EDEADLK));
+ ScopedThread th([is_priv, &a] {
+ EXPECT_THAT(futex_trylock_pi(is_priv, &a), SyscallFailsWithErrno(EAGAIN));
+ });
+ th.Join();
+
+  ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds());
+}
+
+TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) {
+ DisableSave ds; // Too many syscalls.
+
+ std::atomic<int> a = ATOMIC_VAR_INIT(0);
+ const bool is_priv = IsPrivate();
+
+ std::unique_ptr<ScopedThread> threads[10];
+ for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) {
+ threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] {
+ for (size_t j = 0; j < 10;) {
+ if (futex_trylock_pi(is_priv, &a) == 0) {
+ ++j;
+ EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid());
+ SleepSafe(absl::Milliseconds(5));
+ ASSERT_THAT(futex_unlock_pi(is_priv, &a), SyscallSucceeds());
+ }
+ }
+ });
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/getcpu.cc b/test/syscalls/linux/getcpu.cc
new file mode 100644
index 000000000..f4d94bd6a
--- /dev/null
+++ b/test/syscalls/linux/getcpu.cc
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(GetcpuTest, IsValidCpuStress) {
+ const int num_cpus = NumCPUs();
+ absl::Time deadline = absl::Now() + absl::Seconds(10);
+ while (absl::Now() < deadline) {
+ int cpu;
+ ASSERT_THAT(cpu = sched_getcpu(), SyscallSucceeds());
+ ASSERT_LT(cpu, num_cpus);
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc
new file mode 100644
index 000000000..b147d6181
--- /dev/null
+++ b/test/syscalls/linux/getdents.cc
@@ -0,0 +1,539 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <dirent.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::Contains;
+using ::testing::IsEmpty;
+using ::testing::IsSupersetOf;
+using ::testing::Not;
+using ::testing::NotNull;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// New Linux dirent format.
+struct linux_dirent64 {
+ uint64_t d_ino; // Inode number
+ int64_t d_off; // Offset to next linux_dirent64
+ unsigned short d_reclen; // NOLINT, Length of this linux_dirent64
+ unsigned char d_type; // NOLINT, File type
+ char d_name[0]; // Filename (null-terminated)
+};
+
+// Old Linux dirent format.
+struct linux_dirent {
+ unsigned long d_ino; // NOLINT
+ unsigned long d_off; // NOLINT
+ unsigned short d_reclen; // NOLINT
+ char d_name[0];
+};
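+
+// Note that the legacy getdents format has no named d_type field: the kernel
+// stores the type in the last byte of each record, at
+// (char*)dirent + d_reclen - 1 (see getdents(2)), which is why linux_dirent
+// above ends at d_name.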
+
+// Wraps a buffer to provide a set of dirents.
+// T is the underlying dirent type.
+template <typename T>
+class DirentBuffer {
+ public:
+ // DirentBuffer manages the buffer.
+ explicit DirentBuffer(size_t size)
+ : managed_(true), actual_size_(size), reported_size_(size) {
+ data_ = new char[actual_size_];
+ }
+
+ // The buffer is managed externally.
+ DirentBuffer(char* data, size_t actual_size, size_t reported_size)
+ : managed_(false),
+ data_(data),
+ actual_size_(actual_size),
+ reported_size_(reported_size) {}
+
+ ~DirentBuffer() {
+ if (managed_) {
+ delete[] data_;
+ }
+ }
+
+ T* Data() { return reinterpret_cast<T*>(data_); }
+
+ T* Start(size_t read) {
+ read_ = read;
+ if (read_) {
+ return Data();
+ } else {
+ return nullptr;
+ }
+ }
+
+ T* Current() { return reinterpret_cast<T*>(&data_[off_]); }
+
+ T* Next() {
+ size_t new_off = off_ + Current()->d_reclen;
+ if (new_off >= read_ || new_off >= actual_size_) {
+ return nullptr;
+ }
+
+ off_ = new_off;
+ return Current();
+ }
+
+ size_t Size() { return reported_size_; }
+
+ void Reset() {
+ off_ = 0;
+ read_ = 0;
+ memset(data_, 0, actual_size_);
+ }
+
+ private:
+ bool managed_;
+ char* data_;
+ size_t actual_size_;
+ size_t reported_size_;
+
+ size_t off_ = 0;
+
+ size_t read_ = 0;
+};
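+
+// A sketch of the intended iteration pattern (ReadDirents below does exactly
+// this): hand Data()/Size() to getdents, then walk the variable-length
+// records via Start()/Next(), which advance by d_reclen:
+//
+//   DirentBuffer<struct linux_dirent64> dirents(1024);
+//   int n = getdents_syscall(fd, dirents.Data(), dirents.Size());
+//   for (auto* d = dirents.Start(n); d; d = dirents.Next()) {
+//     // d->d_name is one directory entry.
+//   }
+//
+// (getdents_syscall stands in for the test fixture's Getdents() wrapper.)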
+
+// Test for getdents/getdents64.
+// T is the Linux dirent type.
+template <typename T>
+class GetdentsTest : public ::testing::Test {
+ public:
+ using LinuxDirentType = T;
+ using DirentBufferType = DirentBuffer<T>;
+
+ protected:
+ void SetUp() override {
+ dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ fd_ = ASSERT_NO_ERRNO_AND_VALUE(Open(dir_.path(), O_RDONLY | O_DIRECTORY));
+ }
+
+ // Must be overridden with explicit specialization. See below.
+ int SyscallNum();
+
+ int Getdents(LinuxDirentType* dirp, unsigned int count) {
+ return RetryEINTR(syscall)(SyscallNum(), fd_.get(), dirp, count);
+ }
+
+ // Fill directory with num files, named by number starting at 0.
+ void FillDirectory(size_t num) {
+ for (size_t i = 0; i < num; i++) {
+ auto name = JoinPath(dir_.path(), absl::StrCat(i));
+ TEST_CHECK(CreateWithContents(name, "").ok());
+ }
+ }
+
+ // Fill directory with a given list of filenames.
+ void FillDirectoryWithFiles(const std::vector<std::string>& filenames) {
+ for (const auto& filename : filenames) {
+ auto name = JoinPath(dir_.path(), filename);
+ TEST_CHECK(CreateWithContents(name, "").ok());
+ }
+ }
+
+ // Seek to the start of the directory.
+ PosixError SeekStart() {
+ constexpr off_t kStartOfFile = 0;
+ off_t offset = lseek(fd_.get(), kStartOfFile, SEEK_SET);
+ if (offset < 0) {
+ return PosixError(errno, absl::StrCat("error seeking to ", kStartOfFile));
+ }
+ if (offset != kStartOfFile) {
+ return PosixError(EINVAL, absl::StrCat("tried to seek to ", kStartOfFile,
+ " but got ", offset));
+ }
+ return NoError();
+ }
+
+ // Call getdents multiple times, reading all dirents and calling f on each.
+ // f has the type signature PosixError f(T*).
+ // If f returns a non-OK error, so does ReadDirents.
+ template <typename F>
+ PosixError ReadDirents(DirentBufferType* dirents, F const& f) {
+ int n;
+ do {
+ dirents->Reset();
+
+ n = Getdents(dirents->Data(), dirents->Size());
+ MaybeSave();
+ if (n < 0) {
+ return PosixError(errno, "getdents");
+ }
+
+ for (auto d = dirents->Start(n); d; d = dirents->Next()) {
+ RETURN_IF_ERRNO(f(d));
+ }
+ } while (n > 0);
+
+ return NoError();
+ }
+
+ // Call Getdents successively and count all entries.
+ int ReadAndCountAllEntries(DirentBufferType* dirents) {
+ int found = 0;
+
+ EXPECT_NO_ERRNO(ReadDirents(dirents, [&](LinuxDirentType* d) {
+ found++;
+ return NoError();
+ }));
+
+ return found;
+ }
+
+ private:
+ TempPath dir_;
+ FileDescriptor fd_;
+};
+
+// Multiple template parameters are not allowed, so we must use explicit
+// template specialization to set the syscall number.
+
+// SYS_getdents isn't defined on arm64.
+#ifdef __x86_64__
+template <>
+int GetdentsTest<struct linux_dirent>::SyscallNum() {
+ return SYS_getdents;
+}
+#endif
+
+template <>
+int GetdentsTest<struct linux_dirent64>::SyscallNum() {
+ return SYS_getdents64;
+}
+
+#ifdef __x86_64__
+// Test both legacy getdents and getdents64 on x86_64.
+typedef ::testing::Types<struct linux_dirent, struct linux_dirent64>
+ GetdentsTypes;
+#elif __aarch64__
+// Test only getdents64 on arm64.
+typedef ::testing::Types<struct linux_dirent64> GetdentsTypes;
+#endif
+TYPED_TEST_SUITE(GetdentsTest, GetdentsTypes);
+
+// N.B. TYPED_TESTs require explicitly using this-> to access members of
+// GetdentsTest, since we are inside of a derived class template.
+
+TYPED_TEST(GetdentsTest, VerifyEntries) {
+ typename TestFixture::DirentBufferType dirents(1024);
+
+ this->FillDirectory(2);
+
+ // Map of all the entries we expect to find.
+ std::map<std::string, bool> found;
+ found["."] = false;
+ found[".."] = false;
+ found["0"] = false;
+ found["1"] = false;
+
+ EXPECT_NO_ERRNO(this->ReadDirents(
+ &dirents, [&](typename TestFixture::LinuxDirentType* d) {
+ auto kv = found.find(d->d_name);
+ EXPECT_NE(kv, found.end()) << "Unexpected file: " << d->d_name;
+ if (kv != found.end()) {
+ EXPECT_FALSE(kv->second);
+ }
+ found[d->d_name] = true;
+ return NoError();
+ }));
+
+ for (auto& kv : found) {
+ EXPECT_TRUE(kv.second) << "File not found: " << kv.first;
+ }
+}
+
+TYPED_TEST(GetdentsTest, VerifyPadding) {
+ typename TestFixture::DirentBufferType dirents(1024);
+
+ // Create files with names of length 1 through 16.
+ std::vector<std::string> files;
+ std::string filename;
+ for (int i = 0; i < 16; ++i) {
+ absl::StrAppend(&filename, "a");
+ files.push_back(filename);
+ }
+ this->FillDirectoryWithFiles(files);
+
+ // We expect to find all the files, plus '.' and '..'.
+ const int expect_found = 2 + files.size();
+ int found = 0;
+
+ EXPECT_NO_ERRNO(this->ReadDirents(
+ &dirents, [&](typename TestFixture::LinuxDirentType* d) {
+        EXPECT_EQ(d->d_reclen % 8, 0)
+            << "Dirent " << d->d_name
+            << " had reclen that was not 8-byte aligned: " << d->d_reclen;
+ found++;
+ return NoError();
+ }));
+
+ // Make sure we found all the files.
+ EXPECT_EQ(found, expect_found);
+}
+
+// For a small directory, the provided buffer should be large enough
+// for all entries.
+TYPED_TEST(GetdentsTest, SmallDir) {
+ // . and .. should be in an otherwise empty directory.
+ int expect = 2;
+
+ // Add some actual files.
+ this->FillDirectory(2);
+ expect += 2;
+
+ typename TestFixture::DirentBufferType dirents(256);
+
+ EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents));
+}
+
+// A directory with lots of files requires calling getdents multiple times.
+TYPED_TEST(GetdentsTest, LargeDir) {
+ // . and .. should be in an otherwise empty directory.
+ int expect = 2;
+
+ // Add some actual files.
+ this->FillDirectory(100);
+ expect += 100;
+
+ typename TestFixture::DirentBufferType dirents(256);
+
+ EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents));
+}
+
+// If we lie about the size of the buffer, we should still be able to read the
+// entries with the available space.
+TYPED_TEST(GetdentsTest, PartialBuffer) {
+ // . and .. should be in an otherwise empty directory.
+ int expect = 2;
+
+ // Add some actual files.
+ this->FillDirectory(100);
+ expect += 100;
+
+ void* addr = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
+ ASSERT_NE(addr, MAP_FAILED);
+
+ char* buf = reinterpret_cast<char*>(addr);
+
+ // Guard page
+ EXPECT_THAT(
+ mprotect(reinterpret_cast<void*>(buf + kPageSize), kPageSize, PROT_NONE),
+ SyscallSucceeds());
+
+ // Limit space in buf to 256 bytes.
+ buf += kPageSize - 256;
+
+ // Lie about the buffer. Even though we claim the buffer is 1 page,
+ // we should still get all of the dirents in the first 256 bytes.
+ typename TestFixture::DirentBufferType dirents(buf, 256, kPageSize);
+
+ EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents));
+
+ EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds());
+}
+
+// Open many file descriptors, then scan through /proc/self/fd to find and close
+// them all. (The latter is commonly used to handle races between fork/execve
+// and the creation of unwanted non-O_CLOEXEC file descriptors.) This tests that
+// getdents iterates correctly despite mutation of /proc/self/fd.
+TYPED_TEST(GetdentsTest, ProcSelfFd) {
+ constexpr size_t kNfds = 10;
+ std::unordered_map<int, FileDescriptor> fds;
+ fds.reserve(kNfds);
+ for (size_t i = 0; i < kNfds; i++) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+ fds.emplace(fd.get(), std::move(fd));
+ }
+
+ const FileDescriptor proc_self_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/fd", O_RDONLY | O_DIRECTORY));
+
+ // Make the buffer very small since we want to iterate.
+ typename TestFixture::DirentBufferType dirents(
+ 2 * sizeof(typename TestFixture::LinuxDirentType));
+ std::unordered_set<int> prev_fds;
+ while (true) {
+ dirents.Reset();
+ int rv;
+ ASSERT_THAT(rv = RetryEINTR(syscall)(this->SyscallNum(), proc_self_fd.get(),
+ dirents.Data(), dirents.Size()),
+ SyscallSucceeds());
+ if (rv == 0) {
+ break;
+ }
+ for (auto* d = dirents.Start(rv); d; d = dirents.Next()) {
+ int dfd;
+ if (!absl::SimpleAtoi(d->d_name, &dfd)) continue;
+ EXPECT_TRUE(prev_fds.insert(dfd).second)
+ << "Repeated observation of /proc/self/fd/" << dfd;
+ fds.erase(dfd);
+ }
+ }
+
+ // Check that we closed every fd.
+ EXPECT_THAT(fds, ::testing::IsEmpty());
+}
+
+// Test that getdents returns ENOTDIR when called on a file.
+TYPED_TEST(GetdentsTest, NotDir) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ typename TestFixture::DirentBufferType dirents(256);
+ EXPECT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(),
+ dirents.Size()),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+// Test that SEEK_SET to 0 causes getdents to re-read the entries.
+TYPED_TEST(GetdentsTest, SeekResetsCursor) {
+ // . and .. should be in an otherwise empty directory.
+ int expect = 2;
+
+ // Add some files to the directory.
+ this->FillDirectory(10);
+ expect += 10;
+
+ typename TestFixture::DirentBufferType dirents(256);
+
+ // We should get all the expected entries.
+ EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents));
+
+ // Seek back to 0.
+ ASSERT_NO_ERRNO(this->SeekStart());
+
+ // We should get all the expected entries again.
+ EXPECT_EQ(expect, this->ReadAndCountAllEntries(&dirents));
+}
+
+// Test that getdents() after SEEK_END succeeds.
+// This is a regression test for #128.
+TYPED_TEST(GetdentsTest, Issue128ProcSeekEnd) {
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self", O_RDONLY | O_DIRECTORY));
+ typename TestFixture::DirentBufferType dirents(256);
+
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(),
+ dirents.Size()),
+ SyscallSucceeds());
+}
+
+// Some tests using the glibc readdir interface.
+TEST(ReaddirTest, OpenDir) {
+ DIR* dev;
+ ASSERT_THAT(dev = opendir("/dev"), NotNull());
+ EXPECT_THAT(closedir(dev), SyscallSucceeds());
+}
+
+TEST(ReaddirTest, RootContainsBasicDirectories) {
+ EXPECT_THAT(ListDir("/", true),
+ IsPosixErrorOkAndHolds(IsSupersetOf(
+ {"bin", "dev", "etc", "lib", "proc", "sbin", "usr"})));
+}
+
+TEST(ReaddirTest, Bug24096713Dev) {
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev", true));
+ EXPECT_THAT(contents, Not(IsEmpty()));
+}
+
+TEST(ReaddirTest, Bug24096713ProcTid) {
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(
+ ListDir(absl::StrCat("/proc/", syscall(SYS_gettid), "/"), true));
+ EXPECT_THAT(contents, Not(IsEmpty()));
+}
+
+TEST(ReaddirTest, Bug33429925Proc) {
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc", true));
+ EXPECT_THAT(contents, Not(IsEmpty()));
+}
+
+TEST(ReaddirTest, Bug35110122Root) {
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/", true));
+ EXPECT_THAT(contents, Not(IsEmpty()));
+}
+
+// Unlink should invalidate getdents cache.
+TEST(ReaddirTest, GoneAfterRemoveCache) {
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ std::string name = std::string(Basename(file.path()));
+
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), true));
+ EXPECT_THAT(contents, Contains(name));
+
+ file.reset();
+
+ contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), true));
+ EXPECT_THAT(contents, Not(Contains(name)));
+}
+
+// Regression test for b/137398511. Rename should invalidate getdents cache.
+TEST(ReaddirTest, GoneAfterRenameCache) {
+ TempPath src = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath dst = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(src.path()));
+ std::string name = std::string(Basename(file.path()));
+
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(src.path(), true));
+ EXPECT_THAT(contents, Contains(name));
+
+ ASSERT_THAT(rename(file.path().c_str(), JoinPath(dst.path(), name).c_str()),
+ SyscallSucceeds());
+ // Release file since it was renamed. dst cleanup will ultimately delete it.
+ file.release();
+
+ contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(src.path(), true));
+ EXPECT_THAT(contents, Not(Contains(name)));
+
+ contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dst.path(), true));
+ EXPECT_THAT(contents, Contains(name));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc
new file mode 100644
index 000000000..f87cdd7a1
--- /dev/null
+++ b/test/syscalls/linux/getrandom.cc
@@ -0,0 +1,63 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef SYS_getrandom
+#if defined(__x86_64__)
+#define SYS_getrandom 318
+#elif defined(__i386__)
+#define SYS_getrandom 355
+#elif defined(__aarch64__)
+#define SYS_getrandom 278
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_getrandom
+
+bool SomeByteIsNonZero(char* random_bytes, int length) {
+ for (int i = 0; i < length; i++) {
+ if (random_bytes[i] != 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+TEST(GetrandomTest, IsRandom) {
+  // This test calls getrandom and makes sure that the buffer is filled in with
+  // something non-zero. A genuinely random result could be all zero bytes, but
+  // for 64 bytes that happens with probability 2^-512, so we ignore the
+  // possibility.
+ char random_bytes[64] = {};
+ int n = syscall(SYS_getrandom, random_bytes, 64, 0);
+ SKIP_IF(!IsRunningOnGvisor() && n < 0 && errno == ENOSYS);
+ EXPECT_THAT(n, SyscallSucceeds());
+ EXPECT_GT(n, 0); // Some bytes should be returned.
+ EXPECT_TRUE(SomeByteIsNonZero(random_bytes, n));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/getrusage.cc b/test/syscalls/linux/getrusage.cc
new file mode 100644
index 000000000..0e51d42a8
--- /dev/null
+++ b/test/syscalls/linux/getrusage.cc
@@ -0,0 +1,177 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/mman.h>
+#include <sys/resource.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(GetrusageTest, BasicFork) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ struct rusage rusage_self;
+ TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0);
+ struct rusage rusage_children;
+ TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0);
+ // The child has consumed some memory.
+ TEST_CHECK(rusage_self.ru_maxrss != 0);
+ // The child has no children of its own.
+ TEST_CHECK(rusage_children.ru_maxrss == 0);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), SyscallSucceeds());
+ struct rusage rusage_self;
+ ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds());
+ struct rusage rusage_children;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds());
+ // The parent has consumed some memory.
+ EXPECT_GT(rusage_self.ru_maxrss, 0);
+ // The child has consumed some memory, and because it has exited we can get
+ // its max RSS.
+ EXPECT_GT(rusage_children.ru_maxrss, 0);
+}
+
+// Verifies that a process can get the max resident set size of its grandchild,
+// i.e. that maxrss propagates correctly from children to waiting parents.
+TEST(GetrusageTest, Grandchild) {
+ constexpr int kGrandchildSizeKb = 1024;
+ pid_t pid = fork();
+ if (pid == 0) {
+ pid = fork();
+ if (pid == 0) {
+ int flags = MAP_ANONYMOUS | MAP_POPULATE | MAP_PRIVATE;
+ void* addr =
+ mmap(nullptr, kGrandchildSizeKb * 1024, PROT_WRITE, flags, -1, 0);
+ TEST_PCHECK(addr != MAP_FAILED);
+ } else {
+ int status;
+ TEST_PCHECK(RetryEINTR(waitpid)(pid, &status, 0) == pid);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), SyscallSucceeds());
+ struct rusage rusage_self;
+ ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds());
+ struct rusage rusage_children;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds());
+ // The parent has consumed some memory.
+ EXPECT_GT(rusage_self.ru_maxrss, 0);
+ // The child should consume next to no memory, but the grandchild will
+ // consume at least 1MB. Verify that usage bubbles up to the grandparent.
+ EXPECT_GT(rusage_children.ru_maxrss, kGrandchildSizeKb);
+}
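+
+// Note: the comparison above works because Linux reports ru_maxrss in
+// kilobytes (see getrusage(2)), the same unit as kGrandchildSizeKb.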
+
+// Verifies that a process that ignores SIGCHLD (so that its children are
+// automatically reaped) does not have its children's maxrss propagated to it.
+TEST(GetrusageTest, IgnoreSIGCHLD) {
+  struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa));
+ pid_t pid = fork();
+ if (pid == 0) {
+ struct rusage rusage_self;
+ TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0);
+ // The child has consumed some memory.
+ TEST_CHECK(rusage_self.ru_maxrss != 0);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallFailsWithErrno(ECHILD));
+ struct rusage rusage_self;
+ ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds());
+ struct rusage rusage_children;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds());
+ // The parent has consumed some memory.
+ EXPECT_GT(rusage_self.ru_maxrss, 0);
+ // The child's maxrss should not have propagated up.
+ EXPECT_EQ(rusage_children.ru_maxrss, 0);
+}
+
+// Verifies that zombie processes do not update their parent's maxrss. Only
+// reaped processes should do this.
+TEST(GetrusageTest, IgnoreZombie) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ struct rusage rusage_self;
+ TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0);
+ struct rusage rusage_children;
+ TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0);
+ // The child has consumed some memory.
+ TEST_CHECK(rusage_self.ru_maxrss != 0);
+ // The child has no children of its own.
+ TEST_CHECK(rusage_children.ru_maxrss == 0);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ // Give the child time to exit. Because we don't call wait, the child should
+ // remain a zombie.
+ absl::SleepFor(absl::Seconds(5));
+ struct rusage rusage_self;
+ ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds());
+ struct rusage rusage_children;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds());
+ // The parent has consumed some memory.
+ EXPECT_GT(rusage_self.ru_maxrss, 0);
+ // The child has consumed some memory, but hasn't been reaped.
+ EXPECT_EQ(rusage_children.ru_maxrss, 0);
+}
+
+TEST(GetrusageTest, Wait4) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ struct rusage rusage_self;
+ TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0);
+ struct rusage rusage_children;
+ TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0);
+ // The child has consumed some memory.
+ TEST_CHECK(rusage_self.ru_maxrss != 0);
+ // The child has no children of its own.
+ TEST_CHECK(rusage_children.ru_maxrss == 0);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ struct rusage rusage_children;
+ int status;
+ ASSERT_THAT(RetryEINTR(wait4)(pid, &status, 0, &rusage_children),
+ SyscallSucceeds());
+ // The child has consumed some memory, and because it has exited we can get
+ // its max RSS.
+ EXPECT_GT(rusage_children.ru_maxrss, 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
new file mode 100644
index 000000000..220874aeb
--- /dev/null
+++ b/test/syscalls/linux/inotify.cc
@@ -0,0 +1,2380 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <libgen.h>
+#include <sched.h>
+#include <sys/epoll.h>
+#include <sys/inotify.h>
+#include <sys/ioctl.h>
+#include <sys/time.h>
+#include <sys/xattr.h>
+
+#include <atomic>
+#include <list>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/epoll_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using ::absl::StreamFormat;
+using ::absl::StrFormat;
+
+constexpr int kBufSize = 1024;
+
+// C++-friendly version of struct inotify_event.
+struct Event {
+ int32_t wd;
+ uint32_t mask;
+ uint32_t cookie;
+ uint32_t len;
+ std::string name;
+
+ Event(uint32_t mask, int32_t wd, absl::string_view name, uint32_t cookie)
+ : wd(wd),
+ mask(mask),
+ cookie(cookie),
+ len(name.size()),
+ name(std::string(name)) {}
+ Event(uint32_t mask, int32_t wd, absl::string_view name)
+ : Event(mask, wd, name, 0) {}
+ Event(uint32_t mask, int32_t wd) : Event(mask, wd, "", 0) {}
+ Event() : Event(0, 0, "", 0) {}
+};
+
+// Prints the symbolic name for a struct inotify_event's 'mask' field.
+std::string FlagString(uint32_t flags) {
+ std::vector<std::string> names;
+
+#define EMIT(target) \
+ if (flags & target) { \
+ names.push_back(#target); \
+ flags &= ~target; \
+ }
+
+ EMIT(IN_ACCESS);
+ EMIT(IN_ATTRIB);
+ EMIT(IN_CLOSE_WRITE);
+ EMIT(IN_CLOSE_NOWRITE);
+ EMIT(IN_CREATE);
+ EMIT(IN_DELETE);
+ EMIT(IN_DELETE_SELF);
+ EMIT(IN_MODIFY);
+ EMIT(IN_MOVE_SELF);
+ EMIT(IN_MOVED_FROM);
+ EMIT(IN_MOVED_TO);
+ EMIT(IN_OPEN);
+
+ EMIT(IN_DONT_FOLLOW);
+ EMIT(IN_EXCL_UNLINK);
+ EMIT(IN_ONESHOT);
+ EMIT(IN_ONLYDIR);
+
+ EMIT(IN_IGNORED);
+ EMIT(IN_ISDIR);
+ EMIT(IN_Q_OVERFLOW);
+ EMIT(IN_UNMOUNT);
+
+#undef EMIT
+
+ // If we have anything left over at the end, print it as a hex value.
+ if (flags) {
+ names.push_back(absl::StrCat("0x", absl::Hex(flags)));
+ }
+
+ return absl::StrJoin(names, "|");
+}
+
+std::string DumpEvent(const Event& event) {
+ return StrFormat(
+ "%s, wd=%d%s%s", FlagString(event.mask), event.wd,
+ (event.len > 0) ? StrFormat(", name=%s", event.name) : "",
+      (event.cookie > 0) ? StrFormat(", cookie=%u", event.cookie) : "");
+}
+
+std::string DumpEvents(const std::vector<Event>& events, int indent_level) {
+ std::stringstream ss;
+  ss << StreamFormat("%d event%s:\n", events.size(),
+                     (events.size() == 1) ? "" : "s");
+ int i = 0;
+ for (const Event& ev : events) {
+ ss << StreamFormat("%sevents[%d]: %s\n", std::string(indent_level, '\t'),
+ i++, DumpEvent(ev));
+ }
+ return ss.str();
+}
+
+// A matcher which takes an expected list of events to match against another
+// list of inotify events, in order. This is similar to the ElementsAre matcher,
+// but displays more informative messages on mismatch.
+class EventsAreMatcher
+ : public ::testing::MatcherInterface<std::vector<Event>> {
+ public:
+ explicit EventsAreMatcher(std::vector<Event> references)
+ : references_(std::move(references)) {}
+
+ bool MatchAndExplain(
+ std::vector<Event> events,
+ ::testing::MatchResultListener* const listener) const override {
+ if (references_.size() != events.size()) {
+ *listener << StreamFormat("\n\tCount mismatch, got %s",
+ DumpEvents(events, 2));
+ return false;
+ }
+
+ bool success = true;
+ for (unsigned int i = 0; i < references_.size(); ++i) {
+ const Event& reference = references_[i];
+ const Event& target = events[i];
+
+ if (target.mask != reference.mask || target.wd != reference.wd ||
+ target.name != reference.name || target.cookie != reference.cookie) {
+ *listener << StreamFormat("\n\tMismatch at index %d, want %s, got %s,",
+ i, DumpEvent(reference), DumpEvent(target));
+ success = false;
+ }
+ }
+
+ if (!success) {
+ *listener << StreamFormat("\n\tIn total of %s", DumpEvents(events, 2));
+ }
+ return success;
+ }
+
+ void DescribeTo(::std::ostream* const os) const override {
+ *os << StreamFormat("%s", DumpEvents(references_, 1));
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const override {
+ *os << StreamFormat("mismatch from %s", DumpEvents(references_, 1));
+ }
+
+ private:
+ std::vector<Event> references_;
+};
+
+::testing::Matcher<std::vector<Event>> Are(std::vector<Event> events) {
+ return MakeMatcher(new EventsAreMatcher(std::move(events)));
+}
+
+// Similar to the EventsAre matcher, but the order of events are ignored.
+class UnorderedEventsAreMatcher
+ : public ::testing::MatcherInterface<std::vector<Event>> {
+ public:
+ explicit UnorderedEventsAreMatcher(std::vector<Event> references)
+ : references_(std::move(references)) {}
+
+ bool MatchAndExplain(
+ std::vector<Event> events,
+ ::testing::MatchResultListener* const listener) const override {
+ if (references_.size() != events.size()) {
+ *listener << StreamFormat("\n\tCount mismatch, got %s",
+ DumpEvents(events, 2));
+ return false;
+ }
+
+ std::vector<Event> unmatched(references_);
+
+ for (const Event& candidate : events) {
+ for (auto it = unmatched.begin(); it != unmatched.end();) {
+ const Event& reference = *it;
+ if (candidate.mask == reference.mask && candidate.wd == reference.wd &&
+ candidate.name == reference.name &&
+ candidate.cookie == reference.cookie) {
+ it = unmatched.erase(it);
+ break;
+ } else {
+ ++it;
+ }
+ }
+ }
+
+ // Anything left unmatched? If so, the matcher fails.
+ if (!unmatched.empty()) {
+ *listener << StreamFormat("\n\tFailed to match %s",
+ DumpEvents(unmatched, 2));
+ *listener << StreamFormat("\n\tIn total of %s", DumpEvents(events, 2));
+ return false;
+ }
+
+ return true;
+ }
+
+ void DescribeTo(::std::ostream* const os) const override {
+ *os << StreamFormat("unordered %s", DumpEvents(references_, 1));
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const override {
+ *os << StreamFormat("mismatch from unordered %s",
+ DumpEvents(references_, 1));
+ }
+
+ private:
+ std::vector<Event> references_;
+};
+
+::testing::Matcher<std::vector<Event>> AreUnordered(std::vector<Event> events) {
+ return MakeMatcher(new UnorderedEventsAreMatcher(std::move(events)));
+}
+
+// Reads events from an inotify fd until either EOF, or read returns EAGAIN.
+PosixErrorOr<std::vector<Event>> DrainEvents(int fd) {
+ std::vector<Event> events;
+ while (true) {
+ int events_size = 0;
+ if (ioctl(fd, FIONREAD, &events_size) < 0) {
+ return PosixError(errno, "ioctl(FIONREAD) failed on inotify fd");
+ }
+ // Deliberately use a buffer that is larger than necessary, expecting to
+ // only read events_size bytes.
+ std::vector<char> buf(events_size + kBufSize, 0);
+ const ssize_t readlen = read(fd, buf.data(), buf.size());
+ MaybeSave();
+ // Read error?
+ if (readlen < 0) {
+ if (errno == EAGAIN) {
+ // If EAGAIN, no more events at the moment. Return what we have so far.
+ return events;
+ }
+ // Some other read error. Return an error. Right now if we encounter this
+ // after already reading some events, they get lost. However, we don't
+ // expect to see any error, and the calling test will fail immediately if
+ // we signal an error anyways, so this is acceptable.
+ return PosixError(errno, "read() failed on inotify fd");
+ }
+    if (readlen == 0) {
+      // EOF. Check this before the short-read check below, which would
+      // otherwise misreport EOF as an error.
+      return events;
+    }
+    if (readlen < static_cast<int>(sizeof(struct inotify_event))) {
+      // Impossibly short read.
+      return PosixError(
+          EIO,
+          "read() didn't return enough data to represent even a single event");
+    }
+    if (readlen != events_size) {
+      return PosixError(EINVAL, absl::StrCat("read ", readlen,
+                                             " bytes, expected ", events_size));
+    }
+
+ // Normal read.
+ const char* cursor = buf.data();
+ while (cursor < (buf.data() + readlen)) {
+ struct inotify_event event = {};
+ memcpy(&event, cursor, sizeof(struct inotify_event));
+
+ Event ev;
+ ev.wd = event.wd;
+ ev.mask = event.mask;
+ ev.cookie = event.cookie;
+ ev.len = event.len;
+ if (event.len > 0) {
+ TEST_CHECK(static_cast<int>(sizeof(struct inotify_event) + event.len) <=
+ readlen);
+ ev.name = std::string(cursor +
+ offsetof(struct inotify_event, name)); // NOLINT
+ // Name field should always be smaller than event.len, otherwise we have
+ // a buffer overflow. The two sizes aren't equal because the string
+ // constructor will stop at the first null byte, while event.name may be
+ // padded up to event.len using multiple null bytes.
+ TEST_CHECK(ev.name.size() <= event.len);
+ }
+
+ events.push_back(ev);
+ cursor += sizeof(struct inotify_event) + event.len;
+ }
+ }
+}
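+
+// The parsing loop above follows the inotify wire format (inotify(7)): each
+// record occupies sizeof(struct inotify_event) + len bytes, where len is the
+// size of the NUL-padded name buffer (possibly larger than strlen(name)) and
+// is 0 for events on the watched object itself.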
+
+PosixErrorOr<FileDescriptor> InotifyInit1(int flags) {
+ int fd;
+ EXPECT_THAT(fd = inotify_init1(flags), SyscallSucceeds());
+ if (fd < 0) {
+ return PosixError(errno, "inotify_init1() failed");
+ }
+ return FileDescriptor(fd);
+}
+
+PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path,
+ uint32_t mask) {
+ int wd;
+ EXPECT_THAT(wd = inotify_add_watch(fd, path.c_str(), mask),
+ SyscallSucceeds());
+ if (wd < 0) {
+ return PosixError(errno, "inotify_add_watch() failed");
+ }
+ return wd;
+}
+
+TEST(Inotify, IllegalSeek) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(Inotify, IllegalPread) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ int val;
+ EXPECT_THAT(pread(fd.get(), &val, sizeof(val), 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(Inotify, IllegalPwrite) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ EXPECT_THAT(pwrite(fd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(Inotify, IllegalWrite) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ int val = 0;
+ EXPECT_THAT(write(fd.get(), &val, sizeof(val)), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(Inotify, InitFlags) {
+ EXPECT_THAT(inotify_init1(IN_NONBLOCK | IN_CLOEXEC), SyscallSucceeds());
+ EXPECT_THAT(inotify_init1(12345), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Inotify, NonBlockingReadReturnsEagain) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ std::vector<char> buf(kBufSize, 0);
+
+  // The read below should fail with EAGAIN because there is no data to
+ // read and we've specified IN_NONBLOCK. We're guaranteed that there is no
+ // data to read because we haven't registered any watches yet.
+ EXPECT_THAT(read(fd.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(Inotify, AddWatchOnInvalidFdFails) {
+ // Garbage fd.
+ EXPECT_THAT(inotify_add_watch(-1, "/tmp", IN_ALL_EVENTS),
+ SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(inotify_add_watch(1337, "/tmp", IN_ALL_EVENTS),
+ SyscallFailsWithErrno(EBADF));
+
+ // Non-inotify fds.
+ EXPECT_THAT(inotify_add_watch(0, "/tmp", IN_ALL_EVENTS),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(inotify_add_watch(1, "/tmp", IN_ALL_EVENTS),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(inotify_add_watch(2, "/tmp", IN_ALL_EVENTS),
+ SyscallFailsWithErrno(EINVAL));
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/tmp", O_RDONLY));
+ EXPECT_THAT(inotify_add_watch(fd.get(), "/tmp", IN_ALL_EVENTS),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Inotify, RemovingWatchGeneratesEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds());
+
+ // Read events, ensure the first event is IN_IGNORED.
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_IGNORED, wd)}));
+}
+
+TEST(Inotify, CanDeleteFileAfterRemovingWatch) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds());
+ file1.reset();
+}
+
+TEST(Inotify, RemoveWatchAfterDeletingFileFails) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ file1.reset();
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd), Event(IN_DELETE_SELF, wd),
+ Event(IN_IGNORED, wd)}));
+
+ EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Inotify, DuplicateWatchRemovalFails) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds());
+ EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Inotify, ConcurrentFileDeletionAndWatchRemoval) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const std::string filename = NewTempAbsPathInDir(root.path());
+
+ auto file_create_delete = [filename]() {
+ const DisableSave ds; // Too expensive.
+ for (int i = 0; i < 100; ++i) {
+ FileDescriptor file_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT, S_IRUSR | S_IWUSR));
+ file_fd.reset(); // Close before unlinking (although save is disabled).
+ EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds());
+ }
+ };
+
+ const int shared_fd = fd.get(); // We need to pass it to the thread.
+ auto add_remove_watch = [shared_fd, filename]() {
+ for (int i = 0; i < 100; ++i) {
+ int wd = inotify_add_watch(shared_fd, filename.c_str(), IN_ALL_EVENTS);
+ MaybeSave();
+ if (wd != -1) {
+ // Watch added successfully, try removal.
+ if (inotify_rm_watch(shared_fd, wd)) {
+ // If removal fails, the only acceptable reason is if the wd
+ // is invalid, which will be the case if we try to remove
+ // the watch after the file has been deleted.
+ EXPECT_EQ(errno, EINVAL);
+ }
+ } else {
+ // Add watch failed, this should only fail if the target file doesn't
+ // exist.
+ EXPECT_EQ(errno, ENOENT);
+ }
+ }
+ };
+
+ ScopedThread t1(file_create_delete);
+ ScopedThread t2(add_remove_watch);
+}
+
+TEST(Inotify, DeletingChildGeneratesEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ const std::string file1_path = file1.reset();
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd),
+ Event(IN_IGNORED, file1_wd),
+ Event(IN_DELETE, root_wd, Basename(file1_path))}));
+}
+
+// Creating a file in "parent/child" should generate events for child, but not
+// parent.
+TEST(Inotify, CreatingFileGeneratesEvents) {
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath child =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), parent.path(), IN_ALL_EVENTS));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), child.path(), IN_ALL_EVENTS));
+
+ // Create a new file in the directory.
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(child.path()));
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+
+ // The library function we use to create the new file opens it for writing to
+ // create it and sets permissions on it, so we expect the three extra events.
+ ASSERT_THAT(events, Are({Event(IN_CREATE, wd, Basename(file1.path())),
+ Event(IN_OPEN, wd, Basename(file1.path())),
+ Event(IN_CLOSE_WRITE, wd, Basename(file1.path())),
+ Event(IN_ATTRIB, wd, Basename(file1.path()))}));
+}
+
+TEST(Inotify, ReadingFileGeneratesAccessEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, wd, Basename(file1.path()))}));
+}
+
+TEST(Inotify, WritingFileGeneratesModifyEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ const std::string data = "some content";
+ EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()),
+ SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_MODIFY, wd, Basename(file1.path()))}));
+}
+
+TEST(Inotify, SizeZeroReadWriteGeneratesNothing) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ // Read from the empty file.
+ int val;
+ ASSERT_THAT(read(file1_fd.get(), &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+
+ // Write zero bytes.
+ ASSERT_THAT(write(file1_fd.get(), "", 0), SyscallSucceedsWithValue(0));
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({}));
+}
+
+TEST(Inotify, FailedFileCreationGeneratesNoEvents) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string dir_path = dir.path();
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), dir_path, IN_ALL_EVENTS));
+
+ const char* p = dir_path.c_str();
+ ASSERT_THAT(mkdir(p, 0777), SyscallFails());
+ ASSERT_THAT(mknod(p, S_IFIFO, 0777), SyscallFails());
+ ASSERT_THAT(symlink(p, p), SyscallFails());
+ ASSERT_THAT(link(p, p), SyscallFails());
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({}));
+}
+
+TEST(Inotify, WatchSetAfterOpenReportsCloseFdEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ FileDescriptor file1_fd_writable =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+ FileDescriptor file1_fd_not_writable =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ file1_fd_writable.reset(); // Close file1_fd_writable.
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_CLOSE_WRITE, wd, Basename(file1.path()))}));
+
+ file1_fd_not_writable.reset(); // Close file1_fd_not_writable.
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events,
+ Are({Event(IN_CLOSE_NOWRITE, wd, Basename(file1.path()))}));
+}
+
+TEST(Inotify, ChildrenDeletionInWatchedDirGeneratesEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ const std::string file1_path = file1.reset();
+ const std::string dir1_path = dir1.release();
+ EXPECT_THAT(rmdir(dir1_path.c_str()), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+
+ ASSERT_THAT(events,
+ Are({Event(IN_DELETE, wd, Basename(file1_path)),
+ Event(IN_DELETE | IN_ISDIR, wd, Basename(dir1_path))}));
+}
+
+TEST(Inotify, RmdirOnWatchedTargetGeneratesEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ EXPECT_THAT(rmdir(root.path().c_str()), SyscallSucceeds());
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_DELETE_SELF, wd), Event(IN_IGNORED, wd)}));
+}
+
+TEST(Inotify, MoveGeneratesEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const TempPath dir1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ const TempPath dir2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int dir1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), dir1.path(), IN_ALL_EVENTS));
+ const int dir2_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), dir2.path(), IN_ALL_EVENTS));
+ // Test move from root -> root.
+ std::string newpath = NewTempAbsPathInDir(root.path());
+ std::string oldpath = file1.release();
+ EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds());
+ file1.reset(newpath);
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_MOVED_FROM, root_wd, Basename(oldpath), events[0].cookie),
+ Event(IN_MOVED_TO, root_wd, Basename(newpath), events[1].cookie)}));
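+ // The IN_MOVED_FROM/IN_MOVED_TO pair shares a nonzero cookie so that
+ // readers can pair the two halves of the rename.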
+ EXPECT_NE(events[0].cookie, 0);
+ EXPECT_EQ(events[0].cookie, events[1].cookie);
+ uint32_t last_cookie = events[0].cookie;
+
+ // Test move from root -> root/dir1.
+ newpath = NewTempAbsPathInDir(dir1.path());
+ oldpath = file1.release();
+ EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds());
+ file1.reset(newpath);
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_MOVED_FROM, root_wd, Basename(oldpath), events[0].cookie),
+ Event(IN_MOVED_TO, dir1_wd, Basename(newpath), events[1].cookie)}));
+ // Cookies should be distinct between distinct rename events.
+ EXPECT_NE(events[0].cookie, last_cookie);
+ EXPECT_EQ(events[0].cookie, events[1].cookie);
+ last_cookie = events[0].cookie;
+
+ // Test move from root/dir1 -> root/dir2.
+ newpath = NewTempAbsPathInDir(dir2.path());
+ oldpath = file1.release();
+ EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds());
+ file1.reset(newpath);
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_MOVED_FROM, dir1_wd, Basename(oldpath), events[0].cookie),
+ Event(IN_MOVED_TO, dir2_wd, Basename(newpath), events[1].cookie)}));
+ EXPECT_NE(events[0].cookie, last_cookie);
+ EXPECT_EQ(events[0].cookie, events[1].cookie);
+ last_cookie = events[0].cookie;
+}
+
+TEST(Inotify, MoveWatchedTargetGeneratesEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ const std::string newpath = NewTempAbsPathInDir(root.path());
+ const std::string oldpath = file1.release();
+ EXPECT_THAT(rename(oldpath.c_str(), newpath.c_str()), SyscallSucceeds());
+ file1.reset(newpath);
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_MOVED_FROM, root_wd, Basename(oldpath), events[0].cookie),
+ Event(IN_MOVED_TO, root_wd, Basename(newpath), events[1].cookie),
+ // Self move events do not have a cookie.
+ Event(IN_MOVE_SELF, file1_wd)}));
+ EXPECT_NE(events[0].cookie, 0);
+ EXPECT_EQ(events[0].cookie, events[1].cookie);
+}
+
+TEST(Inotify, CoalesceEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ // Read the file a few times. This would generate multiple IN_ACCESS
+ // events, but they should be coalesced into a single event.
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ // Use the close event to verify that we haven't simply left the additional
+ // IN_ACCESS events unread.
+ file1_fd.reset(); // Close file1_fd.
+
+ const std::string file1_name = std::string(Basename(file1.path()));
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, wd, file1_name),
+ Event(IN_CLOSE_NOWRITE, wd, file1_name)}));
+
+ // Now let's try interleaving other events into a stream of repeated events.
+ file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(write(file1_fd.get(), "x", 1), SyscallSucceeds());
+ EXPECT_THAT(write(file1_fd.get(), "x", 1), SyscallSucceeds());
+ EXPECT_THAT(write(file1_fd.get(), "x", 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ file1_fd.reset(); // Close the file.
+
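+ // Only runs of identical adjacent events coalesce: the reads, the writes,
+ // and the final reads should each collapse into a single event.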
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_OPEN, wd, file1_name), Event(IN_ACCESS, wd, file1_name),
+ Event(IN_MODIFY, wd, file1_name), Event(IN_ACCESS, wd, file1_name),
+ Event(IN_CLOSE_WRITE, wd, file1_name)}));
+
+ // Ensure events aren't coalesced if they are from different files.
+ const TempPath file2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+ // Discard events resulting from creation of file2.
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+
+ file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ FileDescriptor file2_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file2.path(), O_RDONLY));
+
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file2_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ // Close both files.
+ file1_fd.reset();
+ file2_fd.reset();
+
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ const std::string file2_name = std::string(Basename(file2.path()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_OPEN, wd, file1_name), Event(IN_OPEN, wd, file2_name),
+ Event(IN_ACCESS, wd, file1_name), Event(IN_ACCESS, wd, file2_name),
+ Event(IN_ACCESS, wd, file1_name),
+ Event(IN_CLOSE_NOWRITE, wd, file1_name),
+ Event(IN_CLOSE_NOWRITE, wd, file2_name)}));
+}
+
+TEST(Inotify, ClosingInotifyFdWithoutRemovingWatchesWorks) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+ // Note: The check on close will happen in FileDescriptor::~FileDescriptor().
+}
+
+TEST(Inotify, NestedWatches) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ // Read from file1. This should generate an event for both watches.
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, root_wd, Basename(file1.path())),
+ Event(IN_ACCESS, file1_wd)}));
+}
+
+TEST(Inotify, ConcurrentThreadsGeneratingEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ std::vector<TempPath> files;
+ files.reserve(10);
+ for (int i = 0; i < 10; i++) {
+ files.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode)));
+ }
+
+ auto test_thread = [&files]() {
+ uint32_t seed = time(nullptr);
+ for (int i = 0; i < 20; i++) {
+ const TempPath& file = files[rand_r(&seed) % files.size()];
+ const FileDescriptor file_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+ TEST_PCHECK(write(file_fd.get(), "x", 1) == 1);
+ }
+ };
+
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ std::list<ScopedThread> threads;
+ for (int i = 0; i < 3; i++) {
+ threads.emplace_back(test_thread);
+ }
+ for (auto& t : threads) {
+ t.Join();
+ }
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ // 3 threads doing 20 iterations, 3 events per iteration (open, write,
+ // close). However, some events may be coalesced, and we can't reliably
+ // predict how they'll be coalesced since the test threads aren't
+ // synchronized. We can only check that we aren't getting unexpected events.
+ for (const Event& ev : events) {
+ EXPECT_NE(ev.mask & (IN_OPEN | IN_MODIFY | IN_CLOSE_WRITE), 0);
+ }
+}
+
+TEST(Inotify, ReadWithTooSmallBufferFails) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ // Open the file to queue an event. This event will not have a filename, so
+ // reading from the inotify fd should return sizeof(struct inotify_event)
+ // bytes of data.
+ FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ std::vector<char> buf(kBufSize, 0);
+ ssize_t readlen;
+
+ // Try a buffer too small to hold any potential event. This is rejected
+ // outright without the event being dequeued.
+ EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event) - 1),
+ SyscallFailsWithErrno(EINVAL));
+ // Try a buffer just large enough. This should succeed.
+ EXPECT_THAT(
+ readlen = read(fd.get(), buf.data(), sizeof(struct inotify_event)),
+ SyscallSucceeds());
+ EXPECT_EQ(readlen, sizeof(struct inotify_event));
+ // Event queue is now empty, the next read should return EAGAIN.
+ EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event)),
+ SyscallFailsWithErrno(EAGAIN));
+
+ // Now watch the directory instead, so that generated events contain a name.
+ // First, remove the watch on the file.
+ EXPECT_THAT(inotify_rm_watch(fd.get(), wd), SyscallSucceeds());
+
+ // Drain the event generated from the watch removal.
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ file1_fd.reset(); // Close file to generate an event.
+
+ // Try a buffer too small to hold any event and one too small to hold an event
+ // with a name. These should both fail without consuming the event.
+ EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event) - 1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event)),
+ SyscallFailsWithErrno(EINVAL));
+ // Now try with a large enough buffer. This should return the one event.
+ EXPECT_THAT(readlen = read(fd.get(), buf.data(), buf.size()),
+ SyscallSucceeds());
+ EXPECT_GE(readlen,
+ sizeof(struct inotify_event) + Basename(file1.path()).size());
+ // With the single event read, the queue should once again be empty.
+ EXPECT_THAT(read(fd.get(), buf.data(), sizeof(struct inotify_event)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(Inotify, BlockingReadOnInotifyFd) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ // Spawn a thread performing a blocking read for new events on the inotify fd.
+ std::vector<char> buf(kBufSize, 0);
+ const int shared_fd = fd.get(); // The thread needs it.
+ ScopedThread t([shared_fd, &buf]() {
+ ssize_t readlen;
+ EXPECT_THAT(readlen = read(shared_fd, buf.data(), buf.size()),
+ SyscallSucceeds());
+ });
+
+ // Perform a read on the watched file, which should generate an IN_ACCESS
+ // event, unblocking the reader thread.
+ char c;
+ EXPECT_THAT(read(file1_fd.get(), &c, 1), SyscallSucceeds());
+
+ // Wait for the thread to read the event and exit.
+ t.Join();
+
+ // Make sure the event we got back is sane.
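+ // memcpy (rather than a pointer cast) presumably avoids a misaligned,
+ // type-punned read of the raw byte buffer.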
+ uint32_t event_mask;
+ memcpy(&event_mask, buf.data() + offsetof(struct inotify_event, mask),
+ sizeof(event_mask));
+ EXPECT_EQ(event_mask, IN_ACCESS);
+}
+
+TEST(Inotify, WatchOnRelativePath) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+
+ // Change working directory to root.
+ const FileDescriptor cwd = ASSERT_NO_ERRNO_AND_VALUE(Open(".", O_PATH));
+ EXPECT_THAT(chdir(root.path().c_str()), SyscallSucceeds());
+
+ // Add a watch on file1 with a relative path.
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ fd.get(), std::string(Basename(file1.path())), IN_ALL_EVENTS));
+
+ // Perform a read on file1, this should generate an IN_ACCESS event.
+ char c;
+ EXPECT_THAT(read(file1_fd.get(), &c, 1), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ACCESS, wd)}));
+
+ // Explicitly reset the working directory so that we don't continue to
+ // reference "root". Once the test ends, "root" will get unlinked. If we
+ // continue to hold a reference, random save/restore tests can fail if a save
+ // is triggered after "root" is unlinked; we can't save deleted fs objects
+ // with active references.
+ EXPECT_THAT(fchdir(cwd.get()), SyscallSucceeds());
+}
+
+TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const char kContent[] = "some content";
+ TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), kContent, TempPath::kDefaultFileMode));
+ const int kContentSize = sizeof(kContent) - 1;
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ std::vector<char> buf(kContentSize, 0);
+ // Read all available data.
+ ssize_t readlen;
+ EXPECT_THAT(readlen = read(file1_fd.get(), buf.data(), kContentSize),
+ SyscallSucceeds());
+ EXPECT_EQ(readlen, kContentSize);
+ // Drain all events and make sure we got the IN_ACCESS for the read.
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ACCESS, wd, Basename(file1.path()))}));
+
+ // Now try read again. This should be a 0-length read, since we're at EOF.
+ char c;
+ EXPECT_THAT(readlen = read(file1_fd.get(), &c, 1), SyscallSucceeds());
+ EXPECT_EQ(readlen, 0);
+ // We should have no new events.
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_TRUE(events.empty());
+
+ // Try issuing a zero-length read.
+ EXPECT_THAT(readlen = read(file1_fd.get(), &c, 0), SyscallSucceeds());
+ EXPECT_EQ(readlen, 0);
+ // We should have no new events.
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_TRUE(events.empty());
+
+ // Try issuing a zero-length write.
+ ssize_t writelen;
+ EXPECT_THAT(writelen = write(file1_fd.get(), &c, 0), SyscallSucceeds());
+ EXPECT_EQ(writelen, 0);
+ // We should have no new events.
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_TRUE(events.empty());
+}
+
+TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ FileDescriptor root_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(root.path(), O_RDONLY));
+ FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ auto verify_chmod_events = [&]() {
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ATTRIB, root_wd, Basename(file1.path())),
+ Event(IN_ATTRIB, file1_wd)}));
+ };
+
+ // Don't do cooperative S/R tests for any of the {f}chmod* syscalls below;
+ // the test will always fail because nodes cannot be saved when they have
+ // stricter permissions than the original host node.
+ const DisableSave ds;
+
+ // Chmod.
+ ASSERT_THAT(chmod(file1.path().c_str(), S_IWGRP), SyscallSucceeds());
+ verify_chmod_events();
+
+ // Fchmod.
+ ASSERT_THAT(fchmod(file1_fd.get(), S_IRGRP | S_IWGRP), SyscallSucceeds());
+ verify_chmod_events();
+
+ // Fchmodat.
+ const std::string file1_basename = std::string(Basename(file1.path()));
+ ASSERT_THAT(fchmodat(root_fd.get(), file1_basename.c_str(), S_IWGRP, 0),
+ SyscallSucceeds());
+ verify_chmod_events();
+
+ // Make sure the chmod'ed file descriptors are destroyed before DisableSave
+ // is destructed.
+ root_fd.reset();
+ file1_fd.reset();
+}
+
+TEST(Inotify, TruncateGeneratesModifyEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ auto verify_truncate_events = [&]() {
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_MODIFY, root_wd, Basename(file1.path())),
+ Event(IN_MODIFY, file1_wd)}));
+ };
+
+ // Truncate.
+ EXPECT_THAT(truncate(file1.path().c_str(), 4096), SyscallSucceeds());
+ verify_truncate_events();
+
+ // Ftruncate.
+ EXPECT_THAT(ftruncate(file1_fd.get(), 8192), SyscallSucceeds());
+ verify_truncate_events();
+
+ // No events if truncate fails.
+ EXPECT_THAT(ftruncate(file1_fd.get(), -1), SyscallFailsWithErrno(EINVAL));
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({}));
+}
+
+TEST(Inotify, GetdentsGeneratesAccessEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ // This internally calls getdents(2). We also expect to see an open/close
+ // event for the dirfd.
+ ASSERT_NO_ERRNO_AND_VALUE(ListDir(root.path(), false));
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+
+ // Linux only seems to generate access events on some getdents() calls, so
+ // allow the test to pass even if the event isn't generated. gVisor always
+ // generates the IN_ACCESS event, so the test at least ensures that gVisor
+ // behaves reasonably.
+ int i = 0;
+ EXPECT_EQ(events[i].mask, IN_OPEN | IN_ISDIR);
+ ++i;
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(events[i].mask, IN_ACCESS | IN_ISDIR);
+ ++i;
+ } else {
+ if (events[i].mask == (IN_ACCESS | IN_ISDIR)) {
+ // Skip over the IN_ACCESS event on Linux, it only shows up some of the
+ // time so we can't assert its existence.
+ ++i;
+ }
+ }
+ EXPECT_EQ(events[i].mask, IN_CLOSE_NOWRITE | IN_ISDIR);
+}
+
+TEST(Inotify, MknodGeneratesCreateEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ const TempPath file1(root.path() + "/file1");
+ const int rc = mknod(file1.path().c_str(), S_IFREG, 0);
+ // mknod(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0);
+ ASSERT_THAT(rc, SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_CREATE, wd, Basename(file1.path()))}));
+}
+
+TEST(Inotify, SymlinkGeneratesCreateEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const TempPath link1(NewTempAbsPathInDir(root.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ ASSERT_THAT(symlink(file1.path().c_str(), link1.path().c_str()),
+ SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+
+ ASSERT_THAT(events, Are({Event(IN_CREATE, root_wd, Basename(link1.path()))}));
+}
+
+TEST(Inotify, LinkGeneratesAttribAndCreateEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const TempPath link1(root.path() + "/link1");
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ const int rc = link(file1.path().c_str(), link1.path().c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
+ (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ATTRIB, file1_wd),
+ Event(IN_CREATE, root_wd, Basename(link1.path()))}));
+}
+
+TEST(Inotify, UtimesGeneratesAttribEvent) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ const struct timeval times[2] = {{1, 0}, {2, 0}};
+ EXPECT_THAT(futimes(file1_fd.get(), times), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ATTRIB, wd, Basename(file1.path()))}));
+}
+
+TEST(Inotify, HardlinksReuseSameWatch) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ TempPath link1(root.path() + "/link1");
+ const int rc = link(file1.path().c_str(), link1.path().c_str());
+ // link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
+ (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+ const int link1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), link1.path(), IN_ALL_EVENTS));
+
+ // The watch descriptors for watches on different links to the same file
+ // should be identical.
+ EXPECT_NE(root_wd, file1_wd);
+ EXPECT_EQ(file1_wd, link1_wd);
+
+ FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events,
+ AreUnordered({Event(IN_OPEN, root_wd, Basename(file1.path())),
+ Event(IN_OPEN, file1_wd)}));
+
+ // For the next step, we want to ensure all fds to the file are closed. Do
+ // that now and drain the resulting events.
+ file1_fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events,
+ Are({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())),
+ Event(IN_CLOSE_WRITE, file1_wd)}));
+
+ // Remove one of the links and see what events show up. Note that after
+ // this, we still have a link to the file, so the watch shouldn't be
+ // automatically removed.
+ const std::string link1_path = link1.reset();
+
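+ // Removing a hard link decrements the inode's link count, which surfaces as
+ // IN_ATTRIB on the file's watch.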
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ATTRIB, link1_wd),
+ Event(IN_DELETE, root_wd, Basename(link1_path))}));
+
+ // Now remove the other link. Since this is the last link to the file, the
+ // watch should be automatically removed.
+ const std::string file1_path = file1.reset();
+
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd),
+ Event(IN_IGNORED, file1_wd),
+ Event(IN_DELETE, root_wd, Basename(file1_path))}));
+}
+
+// Calling mkdir within "parent/child" should generate an event for child, but
+// not parent.
+TEST(Inotify, MkdirGeneratesCreateEventWithDirFlag) {
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath child =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), parent.path(), IN_ALL_EVENTS));
+ const int child_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), child.path(), IN_ALL_EVENTS));
+
+ const TempPath dir1(NewTempAbsPathInDir(child.path()));
+ ASSERT_THAT(mkdir(dir1.path().c_str(), 0777), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(
+ events,
+ Are({Event(IN_CREATE | IN_ISDIR, child_wd, Basename(dir1.path()))}));
+}
+
+TEST(Inotify, MultipleInotifyInstancesAndWatchesAllGetEvents) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+ constexpr int kNumFds = 30;
+ std::vector<FileDescriptor> inotify_fds;
+
+ for (int i = 0; i < kNumFds; ++i) {
+ const DisableSave ds; // Too expensive.
+ inotify_fds.emplace_back(
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)));
+ const FileDescriptor& fd = inotify_fds.back();
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+ }
+
+ const std::string data = "some content";
+ EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()),
+ SyscallSucceeds());
+
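+ // Each instance added its directory watch first and its file watch second,
+ // so the IN_MODIFY should be reported on wd 1 (with the file's name) and on
+ // wd 2 (without a name).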
+ for (const FileDescriptor& fd : inotify_fds) {
+ const DisableSave ds; // Too expensive.
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ if (events.size() >= 2) {
+ EXPECT_EQ(events[0].mask, IN_MODIFY);
+ EXPECT_EQ(events[0].wd, 1);
+ EXPECT_EQ(events[0].name, Basename(file1.path()));
+ EXPECT_EQ(events[1].mask, IN_MODIFY);
+ EXPECT_EQ(events[1].wd, 2);
+ EXPECT_EQ(events[1].name, "");
+ }
+ }
+}
+
+TEST(Inotify, EventsGoUpAtMostOneLevel) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath dir1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ const int dir1_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), dir1.path(), IN_ALL_EVENTS));
+
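+ // Deleting the file should generate an event on its parent (dir1), but not
+ // on its grandparent (root).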
+ const std::string file1_path = file1.reset();
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_DELETE, dir1_wd, Basename(file1_path))}));
+}
+
+TEST(Inotify, DuplicateWatchReturnsSameWatchDescriptor) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd1 = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+ const int wd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ EXPECT_EQ(wd1, wd2);
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ // The watch shouldn't be duplicated; we only expect one event.
+ ASSERT_THAT(events, Are({Event(IN_OPEN, wd1)}));
+}
+
+TEST(Inotify, UnmatchedEventsAreDiscarded) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS));
+
+ FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ // We only asked for access events; the open event should be discarded.
+ ASSERT_THAT(events, Are({}));
+
+ // IN_IGNORED events are always generated, regardless of the mask.
+ file1_fd.reset();
+ file1.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_IGNORED, wd)}));
+}
+
+TEST(Inotify, AddWatchWithInvalidEventMaskFails) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
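+ // A mask with no event bits set is invalid.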
+ EXPECT_THAT(inotify_add_watch(fd.get(), root.path().c_str(), 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Inotify, AddWatchOnInvalidPathFails) {
+ const TempPath nonexistent(NewTempAbsPath());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ // Non-existent path.
+ EXPECT_THAT(
+ inotify_add_watch(fd.get(), nonexistent.path().c_str(), IN_CREATE),
+ SyscallFailsWithErrno(ENOENT));
+
+ // Garbage path pointer.
+ EXPECT_THAT(inotify_add_watch(fd.get(), nullptr, IN_CREATE),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST(Inotify, InOnlyDirFlagRespected) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
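+ // IN_ONLYDIR is fine for a directory, but should fail with ENOTDIR when the
+ // target is a regular file.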
+ EXPECT_THAT(
+ inotify_add_watch(fd.get(), root.path().c_str(), IN_ACCESS | IN_ONLYDIR),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ inotify_add_watch(fd.get(), file1.path().c_str(), IN_ACCESS | IN_ONLYDIR),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(Inotify, MaskAddMergesWithExistingEventMask) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_OPEN | IN_CLOSE_WRITE));
+
+ const std::string data = "some content";
+ EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()),
+ SyscallSucceeds());
+
+ // We shouldn't get any events, since IN_MODIFY wasn't in the event mask.
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({}));
+
+ // Add IN_MODIFY to event mask.
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_MODIFY | IN_MASK_ADD));
+
+ EXPECT_THAT(write(file1_fd.get(), data.c_str(), data.length()),
+ SyscallSucceeds());
+
+ // This time we should get the modify event.
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_MODIFY, wd)}));
+
+ // Now close the fd. If IN_MODIFY replaced the event mask rather than being
+ // added to it, we won't get the close event.
+ file1_fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_CLOSE_WRITE, wd)}));
+}
+
+// Test that control event bits are not considered when checking the event
+// mask.
+TEST(Inotify, ControlEvents) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), dir.path(), IN_ACCESS));
+
+ // Check that events in the mask are dispatched and that control bits
+ // (IN_ISDIR here) are reported even though they are not in the watch mask.
+ std::vector<std::string> files =
+ ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), false));
+ ASSERT_EQ(files.size(), 2);
+
+ const std::vector<Event> events1 =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events1, Are({Event(IN_ACCESS | IN_ISDIR, wd)}));
+
+ // Check that events not in the mask are discarded.
+ const FileDescriptor dir_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ const std::vector<Event> events2 =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events2, Are({}));
+}
+
+// Regression test to ensure that epoll and directory access don't deadlock.
+TEST(Inotify, EpollNoDeadlock) {
+ const DisableSave ds; // Too many syscalls.
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ // Create lots of directories and watch all of them.
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::vector<TempPath> children;
+ for (size_t i = 0; i < 1000; ++i) {
+ auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), child.path(), IN_ACCESS));
+ children.emplace_back(std::move(child));
+ }
+
+ // Run epoll_wait constantly in a separate thread.
+ std::atomic<bool> done(false);
+ ScopedThread th([&fd, &done] {
+ for (auto start = absl::Now(); absl::Now() - start < absl::Seconds(5);) {
+ FileDescriptor epoll_fd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ ASSERT_NO_ERRNO(RegisterEpollFD(epoll_fd.get(), fd.get(),
+ EPOLLIN | EPOLLOUT | EPOLLET, 0));
+ struct epoll_event result[1];
+ EXPECT_THAT(RetryEINTR(epoll_wait)(epoll_fd.get(), result, 1, -1),
+ SyscallSucceedsWithValue(1));
+
+ sched_yield();
+ }
+ done = true;
+ });
+
+ // While epoll thread is running, constantly access all directories to
+ // generate inotify events.
+ while (!done) {
+ std::vector<std::string> files =
+ ASSERT_NO_ERRNO_AND_VALUE(ListDir(root.path(), false));
+ ASSERT_EQ(files.size(), 1002);
+ for (const auto& child : files) {
+ if (child == "." || child == "..") {
+ continue;
+ }
+ ASSERT_NO_ERRNO_AND_VALUE(ListDir(JoinPath(root.path(), child), false));
+ }
+ sched_yield();
+ }
+}
+
+// On Linux, inotify behavior is not very consistent with splice(2). We try our
+// best to emulate Linux for very basic calls to splice.
+TEST(Inotify, SpliceOnWatchTarget) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), dir.path(), IN_ALL_EVENTS));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
+
+ EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, 1, /*flags=*/0),
+ SyscallSucceedsWithValue(1));
+
+ // Surprisingly, Linux does not generate events when we splice from a file.
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({}));
+
+ EXPECT_THAT(splice(pipes[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0),
+ SyscallSucceedsWithValue(1));
+
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({
+ Event(IN_MODIFY, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, file_wd),
+ }));
+}
+
+TEST(Inotify, SpliceOnInotifyFD) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int watcher = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
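+ // Ask for one byte more than a single event; splice should still transfer
+ // only whole events, i.e. exactly sizeof(struct inotify_event) bytes here.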
+ EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr,
+ sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK),
+ SyscallSucceedsWithValue(sizeof(struct inotify_event)));
+
+ const FileDescriptor read_fd(pipes[0]);
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)}));
+}
+
+// Watches on a parent should not be triggered by actions on a hard link to one
+// of its children that has a different parent.
+TEST(Inotify, LinkOnOtherParent) {
+ const TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ std::string link_path = NewTempAbsPathInDir(dir2.path());
+
+ const int rc = link(file.path().c_str(), link_path.c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
+ (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), dir1.path(), IN_ALL_EVENTS));
+
+ // Perform various actions on the link outside of dir1, which should trigger
+ // no inotify events.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(link_path.c_str(), O_RDWR));
+ int val = 0;
+ ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds());
+ ASSERT_THAT(unlink(link_path.c_str()), SyscallSucceeds());
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+}
+
+TEST(Inotify, Xattr) {
+ // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer.
+ SKIP_IF(IsRunningOnGvisor());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string path = file.path();
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR));
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), path, IN_ALL_EVENTS));
+
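+ // Mutating xattrs (set/remove) should generate IN_ATTRIB, while read-only
+ // xattr calls (get/list) should generate no events.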
+ const char* cpath = path.c_str();
+ const char* name = "user.test";
+ int val = 123;
+ ASSERT_THAT(setxattr(cpath, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+
+ ASSERT_THAT(getxattr(cpath, name, &val, sizeof(val)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ char list[100];
+ ASSERT_THAT(listxattr(cpath, list, sizeof(list)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ ASSERT_THAT(removexattr(cpath, name), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+
+ ASSERT_THAT(fsetxattr(fd.get(), name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+
+ ASSERT_THAT(fgetxattr(fd.get(), name, &val, sizeof(val)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ ASSERT_THAT(flistxattr(fd.get(), list, sizeof(list)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ ASSERT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+}
+
+TEST(Inotify, Exec) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath bin = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(dir.path(), "/bin/true"));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), bin.path(), IN_ALL_EVENTS));
+
+ // Perform exec.
+ ScopedThread t([&bin]() {
+ ASSERT_THAT(execl(bin.path().c_str(), bin.path().c_str(), (char*)nullptr),
+ SyscallSucceeds());
+ });
+ t.Join();
+
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd)}));
+}
+
+// Watches without IN_EXCL_UNLINK should continue to emit events for file
+// descriptors after their corresponding files have been unlinked.
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesystems.
+TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) {
+ const DisableSave ds;
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(dir.path(), "123", TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), dir.path(), IN_ALL_EVENTS));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
+
+ ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+ int val = 0;
+ ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, file_wd),
+ Event(IN_DELETE, dir_wd, Basename(file.path())),
+ Event(IN_ACCESS, dir_wd, Basename(file.path())),
+ Event(IN_ACCESS, file_wd),
+ Event(IN_MODIFY, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, file_wd),
+ }));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_CLOSE_WRITE, dir_wd, Basename(file.path())),
+ Event(IN_CLOSE_WRITE, file_wd),
+ Event(IN_DELETE_SELF, file_wd),
+ Event(IN_IGNORED, file_wd),
+ }));
+}
+
+// Watches created with IN_EXCL_UNLINK will stop emitting events on fds for
+// children that have already been unlinked.
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesystems.
+TEST(Inotify, ExcludeUnlink_NoRandomSave) {
+ const DisableSave ds;
+ // TODO(gvisor.dev/issue/1624): This test fails on VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // Unlink the child, which should cause further operations on the open file
+ // descriptor to be ignored.
+ ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+ int val = 0;
+ ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, file_wd),
+ Event(IN_DELETE, dir_wd, Basename(file.path())),
+ }));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({
+ Event(IN_DELETE_SELF, file_wd),
+ Event(IN_IGNORED, file_wd),
+ }));
+}
+
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesystems.
+TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1624): This test fails on VFS1. Remove once VFS1 is
+ // deleted.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const DisableSave ds;
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath dir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path()));
+ std::string dirPath = dir.path();
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dirPath.c_str(), O_RDONLY | O_DIRECTORY));
+ const int parent_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), parent.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+ const int self_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // Unlink the dir, and then close the open fd.
+ ASSERT_THAT(rmdir(dirPath.c_str()), SyscallSucceeds());
+ dir.reset();
+
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ // No close event should appear.
+ ASSERT_THAT(events,
+ Are({Event(IN_DELETE | IN_ISDIR, parent_wd, Basename(dirPath))}));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({
+ Event(IN_DELETE_SELF, self_wd),
+ Event(IN_IGNORED, self_wd),
+ }));
+}
+
+// If "dir/child" and "dir/child2" are links to the same file, and "dir/child"
+// is unlinked, a watch on "dir" with IN_EXCL_UNLINK will exclude future events
+// for fds on "dir/child" but not "dir/child2".
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesystems.
+TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) {
+ const DisableSave ds;
+ // TODO(gvisor.dev/issue/1624): This test fails on VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ std::string path1 = file.path();
+ std::string path2 = NewTempAbsPathInDir(dir.path());
+
+ const int rc = link(path1.c_str(), path2.c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
+ (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path1.c_str(), O_RDWR));
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path2.c_str(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // After unlinking path1, only events on the fd for path2 should be generated.
+ ASSERT_THAT(unlink(path1.c_str()), SyscallSucceeds());
+ ASSERT_THAT(write(fd1.get(), "x", 1), SyscallSucceeds());
+ ASSERT_THAT(write(fd2.get(), "x", 1), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_DELETE, wd, Basename(path1)),
+ Event(IN_MODIFY, wd, Basename(path2)),
+ }));
+}
+
+// On native Linux, actions of data type FSNOTIFY_EVENT_INODE are not affected
+// by IN_EXCL_UNLINK (see
+// fs/notify/inotify/inotify_fsnotify.c:inotify_handle_event). Inode-level
+// events include changes to metadata and extended attributes.
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesystems.
+TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) {
+ const DisableSave ds;
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR));
+
+ // NOTE(b/157163751): Create another link before unlinking. This is needed for
+ // the gofer filesystem in gVisor, where open fds will not work once the link
+ // count hits zero. In VFS2, we end up skipping the gofer test anyway, because
+ // hard links are not supported for gofer fs.
+ if (IsRunningOnGvisor()) {
+ std::string link_path = NewTempAbsPath();
+ const int rc = link(file.path().c_str(), link_path.c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(rc != 0 && (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+ }
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // Even after unlinking, inode-level operations will trigger events regardless
+ // of IN_EXCL_UNLINK.
+ ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+
+ // Perform various actions on fd.
+ ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, file_wd),
+ Event(IN_DELETE, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, file_wd),
+ }));
+
+ const struct timeval times[2] = {{1, 0}, {2, 0}};
+ ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, dir_wd, Basename(file.path())),
+ Event(IN_ATTRIB, file_wd),
+ }));
+
+ // S/R is disabled on this entire test due to behavior with unlink; it must
+ // also be disabled after this point because of fchmod.
+ ASSERT_THAT(fchmod(fd.get(), 0777), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, dir_wd, Basename(file.path())),
+ Event(IN_ATTRIB, file_wd),
+ }));
+}
+
+TEST(Inotify, OneShot) {
+ // TODO(gvisor.dev/issue/1624): IN_ONESHOT not supported in VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_MODIFY | IN_ONESHOT));
+
+ // Open an fd, write to it, and then close it.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+ ASSERT_THAT(write(fd.get(), "x", 1), SyscallSucceedsWithValue(1));
+ fd.reset();
+
+ // We should get a single event followed by IN_IGNORED indicating removal
+ // of the one-shot watch. Prior activity (the open) that is not in the mask
+ // should not trigger removal, and activity after removal (the close) should
+ // not generate events.
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_MODIFY, wd),
+ Event(IN_IGNORED, wd),
+ }));
+
+ // The watch should already have been removed.
+ EXPECT_THAT(inotify_rm_watch(inotify_fd.get(), wd),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// This test helps verify that the lock order of filesystem and inotify locks
+// is respected when inotify instances and watch targets are concurrently being
+// destroyed.
+TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) {
+ const DisableSave ds; // Too many syscalls.
+
+ // A file descriptor protected by a mutex. This ensures that while a
+ // descriptor is in use, it cannot be closed and reused for a different file
+ // description.
+ struct atomic_fd {
+ int fd;
+ absl::Mutex mu;
+ };
+
+ // Set up initial inotify instances.
+ constexpr int num_fds = 3;
+ std::vector<atomic_fd> fds(num_fds);
+ for (int i = 0; i < num_fds; i++) {
+ int fd;
+ ASSERT_THAT(fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds());
+ fds[i].fd = fd;
+ }
+
+ // Set up initial watch targets.
+ std::vector<std::string> paths;
+ for (int i = 0; i < 3; i++) {
+ paths.push_back(NewTempAbsPath());
+ ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+
+ constexpr absl::Duration runtime = absl::Seconds(4);
+
+ // Constantly replace each inotify instance with a new one.
+ auto replace_fds = [&] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (auto& afd : fds) {
+ int new_fd;
+ ASSERT_THAT(new_fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds());
+ absl::MutexLock l(&afd.mu);
+ ASSERT_THAT(close(afd.fd), SyscallSucceeds());
+ afd.fd = new_fd;
+ for (auto& p : paths) {
+ // inotify_add_watch may fail if the file at p was deleted.
+ ASSERT_THAT(inotify_add_watch(afd.fd, p.c_str(), IN_ALL_EVENTS),
+ AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT)));
+ }
+ }
+ sched_yield();
+ }
+ };
+
+ std::list<ScopedThread> ts;
+ for (int i = 0; i < 3; i++) {
+ ts.emplace_back(replace_fds);
+ }
+
+ // Constantly replace each watch target with a new one.
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (auto& p : paths) {
+ ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds());
+ ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+ sched_yield();
+ }
+}
+
+// This test helps verify that the lock order of filesystem and inotify locks
+// is respected when adding/removing watches occurs concurrently with the
+// removal of their targets.
+TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) {
+ const DisableSave ds; // Too many syscalls.
+
+ // Set up inotify instances.
+ constexpr int num_fds = 3;
+ std::vector<int> fds(num_fds);
+ for (int i = 0; i < num_fds; i++) {
+ ASSERT_THAT(fds[i] = inotify_init1(IN_NONBLOCK), SyscallSucceeds());
+ }
+
+ // Set up initial watch targets.
+ std::vector<std::string> paths;
+ for (int i = 0; i < 3; i++) {
+ paths.push_back(NewTempAbsPath());
+ ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+
+ constexpr absl::Duration runtime = absl::Seconds(1);
+
+ // Constantly add/remove watches for each inotify instance/watch target pair.
+ auto add_remove_watches = [&] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (int fd : fds) {
+ for (auto& p : paths) {
+ // Do not assert on inotify_add_watch and inotify_rm_watch. They may
+ // fail if the file at p was deleted. inotify_add_watch may also fail
+ // if another thread beat us to adding a watch.
+ const int wd = inotify_add_watch(fd, p.c_str(), IN_ALL_EVENTS);
+ if (wd > 0) {
+ inotify_rm_watch(fd, wd);
+ }
+ }
+ }
+ sched_yield();
+ }
+ };
+
+ std::list<ScopedThread> ts;
+ for (int i = 0; i < 15; i++) {
+ ts.emplace_back(add_remove_watches);
+ }
+
+ // Constantly replace each watch target with a new one.
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (auto& p : paths) {
+ ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds());
+ ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+ sched_yield();
+ }
+}
+
+// This test helps verify that the lock order of filesystem and inotify locks
+// is respected when many inotify events and filesystem operations occur
+// simultaneously.
+TEST(InotifyTest, NotifyNoDeadlock_NoRandomSave) {
+ const DisableSave ds; // Too many syscalls.
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string dir = parent.path();
+
+ // mu protects file, which will change on rename.
+ absl::Mutex mu;
+ std::string file = NewTempAbsPathInDir(dir);
+ ASSERT_THAT(mknod(file.c_str(), 0644 | S_IFREG, 0), SyscallSucceeds());
+
+ const absl::Duration runtime = absl::Milliseconds(300);
+
+ // Add/remove watches on dir and file.
+ ScopedThread add_remove_watches([&] {
+ const FileDescriptor ifd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS));
+ int file_wd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS));
+ }
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(inotify_rm_watch(ifd.get(), file_wd), SyscallSucceeds());
+ ASSERT_THAT(inotify_rm_watch(ifd.get(), dir_wd), SyscallSucceeds());
+ dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS));
+ {
+ absl::ReaderMutexLock l(&mu);
+ file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS));
+ }
+ sched_yield();
+ }
+ });
+
+ // Modify attributes on dir and file.
+ ScopedThread stats([&] {
+ int fd, dir_fd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds());
+ }
+ ASSERT_THAT(dir_fd = open(dir.c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+ const struct timeval times[2] = {{1, 0}, {2, 0}};
+
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ {
+ absl::ReaderMutexLock l(&mu);
+ EXPECT_THAT(utimes(file.c_str(), times), SyscallSucceeds());
+ }
+ EXPECT_THAT(futimes(fd, times), SyscallSucceeds());
+ EXPECT_THAT(utimes(dir.c_str(), times), SyscallSucceeds());
+ EXPECT_THAT(futimes(dir_fd, times), SyscallSucceeds());
+ sched_yield();
+ }
+ });
+
+ // Modify extended attributes on dir and file.
+ ScopedThread xattrs([&] {
+ // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer.
+ if (!IsRunningOnGvisor()) {
+ int fd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds());
+ }
+
+ const char* name = "user.test";
+ int val = 123;
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(
+ setxattr(file.c_str(), name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ ASSERT_THAT(removexattr(file.c_str(), name), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(fsetxattr(fd, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ ASSERT_THAT(fremovexattr(fd, name), SyscallSucceeds());
+ sched_yield();
+ }
+ }
+ });
+
+ // Read and write file's contents. Read and write dir's entries.
+ ScopedThread read_write([&] {
+ int fd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(fd = open(file.c_str(), O_RDWR), SyscallSucceeds());
+ }
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ int val = 123;
+ ASSERT_THAT(write(fd, &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(read(fd, &val, sizeof(val)), SyscallSucceeds());
+ TempPath new_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir));
+ ASSERT_NO_ERRNO(ListDir(dir, false));
+ new_file.reset();
+ sched_yield();
+ }
+ });
+
+ // Rename file.
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ const std::string new_path = NewTempAbsPathInDir(dir);
+ {
+ absl::WriterMutexLock l(&mu);
+ ASSERT_THAT(rename(file.c_str(), new_path.c_str()), SyscallSucceeds());
+ file = new_path;
+ }
+ sched_yield();
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc
new file mode 100644
index 000000000..b0a07a064
--- /dev/null
+++ b/test/syscalls/linux/ioctl.cc
@@ -0,0 +1,406 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <net/if.h>
+#include <netdb.h>
+#include <signal.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+bool CheckNonBlocking(int fd) {
+ int ret = fcntl(fd, F_GETFL, 0);
+ TEST_CHECK(ret != -1);
+ return (ret & O_NONBLOCK) == O_NONBLOCK;
+}
+
+bool CheckCloExec(int fd) {
+ int ret = fcntl(fd, F_GETFD, 0);
+ TEST_CHECK(ret != -1);
+ return (ret & FD_CLOEXEC) == FD_CLOEXEC;
+}
+
+class IoctlTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ ASSERT_THAT(fd_ = open("/dev/null", O_RDONLY), SyscallSucceeds());
+ }
+
+ void TearDown() override {
+ if (fd_ >= 0) {
+ ASSERT_THAT(close(fd_), SyscallSucceeds());
+ fd_ = -1;
+ }
+ }
+
+ int fd() const { return fd_; }
+
+ private:
+ int fd_ = -1;
+};
+
+TEST_F(IoctlTest, BadFileDescriptor) {
+ EXPECT_THAT(ioctl(-1 /* fd */, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(IoctlTest, InvalidControlNumber) {
+ EXPECT_THAT(ioctl(STDOUT_FILENO, 0), SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(IoctlTest, FIONBIOSucceeds) {
+ EXPECT_FALSE(CheckNonBlocking(fd()));
+ int set = 1;
+ EXPECT_THAT(ioctl(fd(), FIONBIO, &set), SyscallSucceeds());
+ EXPECT_TRUE(CheckNonBlocking(fd()));
+ set = 0;
+ EXPECT_THAT(ioctl(fd(), FIONBIO, &set), SyscallSucceeds());
+ EXPECT_FALSE(CheckNonBlocking(fd()));
+}
+
+TEST_F(IoctlTest, FIONBIOFails) {
+ EXPECT_THAT(ioctl(fd(), FIONBIO, nullptr), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(IoctlTest, FIONCLEXSucceeds) {
+ EXPECT_THAT(ioctl(fd(), FIONCLEX), SyscallSucceeds());
+ EXPECT_FALSE(CheckCloExec(fd()));
+}
+
+TEST_F(IoctlTest, FIOCLEXSucceeds) {
+ EXPECT_THAT(ioctl(fd(), FIOCLEX), SyscallSucceeds());
+ EXPECT_TRUE(CheckCloExec(fd()));
+}
+
+TEST_F(IoctlTest, FIOASYNCFails) {
+ EXPECT_THAT(ioctl(fd(), FIOASYNC, nullptr), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(IoctlTest, FIOASYNCSucceeds) {
+ // Not all FDs support FIOASYNC.
+ const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int before = -1;
+ ASSERT_THAT(before = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+
+ int set = 1;
+ EXPECT_THAT(ioctl(s.get(), FIOASYNC, &set), SyscallSucceeds());
+
+ int after_set = -1;
+ ASSERT_THAT(after_set = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(after_set, before | O_ASYNC) << "before was " << before;
+
+ set = 0;
+ EXPECT_THAT(ioctl(s.get(), FIOASYNC, &set), SyscallSucceeds());
+
+ ASSERT_THAT(fcntl(s.get(), F_GETFL), SyscallSucceedsWithValue(before));
+}
+
+// Count of the number of SIGIOs handled.
+static volatile int io_received = 0;
+
+void inc_io_handler(int sig, siginfo_t* siginfo, void* arg) { io_received++; }
+
+TEST_F(IoctlTest, FIOASYNCNoTarget) {
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ // Count SIGIOs received.
+ io_received = 0;
+ struct sigaction sa;
+ sa.sa_sigaction = inc_io_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_RESTART;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ // Actually allow SIGIO delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO));
+
+ int set = 1;
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds());
+
+ constexpr char kData[] = "abc";
+ ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)),
+ SyscallSucceedsWithValue(sizeof(kData)));
+
+ EXPECT_EQ(io_received, 0);
+}
+
+TEST_F(IoctlTest, FIOASYNCSelfTarget) {
+ // FIXME(b/120624367): gVisor erroneously sends SIGIO on close(2), which would
+ // kill the test when pair goes out of scope. Temporarily ignore SIGIO so
+ // that the close signal is ignored.
+ struct sigaction sa;
+ sa.sa_handler = SIG_IGN;
+ auto early_sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ // Count SIGIOs received.
+ io_received = 0;
+ sa.sa_sigaction = inc_io_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_RESTART;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ // Actually allow SIGIO delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO));
+
+ int set = 1;
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds());
+
+ pid_t pid = getpid();
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds());
+
+ constexpr char kData[] = "abc";
+ ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)),
+ SyscallSucceedsWithValue(sizeof(kData)));
+
+ EXPECT_EQ(io_received, 1);
+}
+
+// Equivalent to FIOASYNCSelfTarget except that FIOSETOWN is called before
+// FIOASYNC.
+TEST_F(IoctlTest, FIOASYNCSelfTarget2) {
+ // FIXME(b/120624367): gVisor erroneously sends SIGIO on close(2), which would
+ // kill the test when pair goes out of scope. Temporarily ignore SIGIO so
+ // that the close signal is ignored.
+ struct sigaction sa;
+ sa.sa_handler = SIG_IGN;
+ auto early_sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ // Count SIGIOs received.
+ io_received = 0;
+ sa.sa_sigaction = inc_io_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_RESTART;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ // Actually allow SIGIO delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO));
+
+ pid_t pid = -1;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds());
+
+ int set = 1;
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds());
+
+ constexpr char kData[] = "abc";
+ ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)),
+ SyscallSucceedsWithValue(sizeof(kData)));
+
+ EXPECT_EQ(io_received, 1);
+}
+
+// Check that closing an FD does not result in an event.
+TEST_F(IoctlTest, FIOASYNCSelfTargetClose) {
+ // Count SIGIOs received.
+ struct sigaction sa;
+ io_received = 0;
+ sa.sa_sigaction = inc_io_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_RESTART;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ // Actually allow SIGIO delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO));
+
+ for (int i = 0; i < 2; i++) {
+ auto pair = ASSERT_NO_ERRNO_AND_VALUE(
+ UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ pid_t pid = getpid();
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds());
+
+ int set = 1;
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds());
+ }
+
+ // FIXME(b/120624367): gVisor erroneously sends SIGIO on close.
+ SKIP_IF(IsRunningOnGvisor());
+
+ EXPECT_EQ(io_received, 0);
+}
+
+TEST_F(IoctlTest, FIOASYNCInvalidPID) {
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ int set = 1;
+ ASSERT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds());
+ pid_t pid = INT_MAX;
+ // This succeeds (with behavior equivalent to a pid of 0) in Linux prior to
+ // f73127356f34 "fs/fcntl: return -ESRCH in f_setown when pid/pgid can't be
+ // found", and fails with EPERM after that commit.
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid),
+ AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ESRCH)));
+}
+
+TEST_F(IoctlTest, FIOASYNCUnsetTarget) {
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ // Count SIGIOs received.
+ io_received = 0;
+ struct sigaction sa;
+ sa.sa_sigaction = inc_io_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_RESTART;
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGIO, sa));
+
+ // Actually allow SIGIO delivery.
+ auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO));
+
+ int set = 1;
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOASYNC, &set), SyscallSucceeds());
+
+ pid_t pid = getpid();
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds());
+
+ // Passing a PID of 0 unsets the target.
+ pid = 0;
+ EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds());
+
+ constexpr char kData[] = "abc";
+ ASSERT_THAT(WriteFd(pair->first_fd(), kData, sizeof(kData)),
+ SyscallSucceedsWithValue(sizeof(kData)));
+
+ EXPECT_EQ(io_received, 0);
+}
+
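+// Note that these ioctls mirror fcntl(2): FIOSETOWN/FIOGETOWN (and their
+// socket aliases SIOCSPGRP/SIOCGPGRP) correspond to F_SETOWN/F_GETOWN, and
+// FIOASYNC toggles the same O_ASYNC status flag as F_SETFL. A rough
+// fcntl-based equivalent of the setup used above (sketch only, error handling
+// elided):
+//
+//   fcntl(sock, F_SETOWN, getpid());        // Direct SIGIO at this process.
+//   int flags = fcntl(sock, F_GETFL);
+//   fcntl(sock, F_SETFL, flags | O_ASYNC);  // Same effect as FIOASYNC = 1.
+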
+using IoctlTestSIOCGIFCONF = SimpleSocketTest;
+
+TEST_P(IoctlTestSIOCGIFCONF, ValidateNoArrayGetsLength) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Validate that passing no array returns the length required.
+ struct ifconf ifconf = {};
+ ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds());
+ ASSERT_GT(ifconf.ifc_len, 0);
+}
+
+// This test validates that we may return a partial list of interfaces, but
+// never a partial ifreq struct.
+TEST_P(IoctlTestSIOCGIFCONF, ValidateNoPartialIfrsReturned) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ struct ifreq ifr = {};
+ struct ifconf ifconf = {};
+ ifconf.ifc_len = sizeof(ifr) - 1; // One byte too few.
+ ifconf.ifc_ifcu.ifcu_req = &ifr;
+
+ ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds());
+ ASSERT_EQ(ifconf.ifc_len, 0);
+ ASSERT_EQ(ifr.ifr_name[0], '\0'); // Nothing is returned.
+
+ ifconf.ifc_len = sizeof(ifreq);
+ ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds());
+ ASSERT_GT(ifconf.ifc_len, 0);
+ ASSERT_NE(ifr.ifr_name[0], '\0'); // An interface can now be returned.
+}
+
+TEST_P(IoctlTestSIOCGIFCONF, ValidateLoopbackIsPresent) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ struct ifconf ifconf = {};
+ struct ifreq ifr[10] = {}; // Storage for up to 10 interfaces.
+
+ ifconf.ifc_req = ifr;
+ ifconf.ifc_len = sizeof(ifr);
+
+ ASSERT_THAT(ioctl(fd->get(), SIOCGIFCONF, &ifconf), SyscallSucceeds());
+ size_t num_if = ifconf.ifc_len / sizeof(struct ifreq);
+
+ // We should have at least one interface.
+ ASSERT_GE(num_if, 1);
+
+ // One of the interfaces should be a loopback.
+ bool found_loopback = false;
+ for (size_t i = 0; i < num_if; ++i) {
+ if (strcmp(ifr[i].ifr_name, "lo") == 0) {
+ // SIOCGIFCONF returns the IPv4 address of the interface; check it.
+ ASSERT_EQ(ifr[i].ifr_addr.sa_family, AF_INET);
+
+ // Validate the address is correct for loopback.
+ sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(&ifr[i].ifr_addr);
+ ASSERT_EQ(ntohl(sin->sin_addr.s_addr), INADDR_LOOPBACK);
+
+ found_loopback = true;
+ break;
+ }
+ }
+ ASSERT_TRUE(found_loopback);
+}
+
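+// The tests above rely on the conventional two-call pattern for SIOCGIFCONF:
+// query once with no buffer to learn the required length, then allocate and
+// query again. A minimal sketch (error handling elided):
+//
+//   struct ifconf ifc = {};
+//   ioctl(sock, SIOCGIFCONF, &ifc);  // ifc_req == NULL: only sets ifc_len.
+//   std::vector<char> buf(ifc.ifc_len);
+//   ifc.ifc_buf = buf.data();
+//   ioctl(sock, SIOCGIFCONF, &ifc);  // Fills at most ifc_len bytes.
+//   size_t num_if = ifc.ifc_len / sizeof(struct ifreq);
+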
+std::vector<SocketKind> IoctlSocketTypes() {
+ return {SimpleSocket(AF_UNIX, SOCK_STREAM, 0),
+ SimpleSocket(AF_UNIX, SOCK_DGRAM, 0),
+ SimpleSocket(AF_INET, SOCK_STREAM, 0),
+ SimpleSocket(AF_INET6, SOCK_STREAM, 0),
+ SimpleSocket(AF_INET, SOCK_DGRAM, 0),
+ SimpleSocket(AF_INET6, SOCK_DGRAM, 0)};
+}
+
+INSTANTIATE_TEST_SUITE_P(IoctlTest, IoctlTestSIOCGIFCONF,
+ ::testing::ValuesIn(IoctlSocketTypes()));
+
+} // namespace
+
+TEST_F(IoctlTest, FIOGETOWNSucceeds) {
+ const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int get = -1;
+ ASSERT_THAT(ioctl(s.get(), FIOGETOWN, &get), SyscallSucceeds());
+ EXPECT_EQ(get, 0);
+}
+
+TEST_F(IoctlTest, SIOCGPGRPSucceeds) {
+ const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ int get = -1;
+ ASSERT_THAT(ioctl(s.get(), SIOCGPGRP, &get), SyscallSucceeds());
+ EXPECT_EQ(get, 0);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
new file mode 100644
index 000000000..98d07ae85
--- /dev/null
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -0,0 +1,239 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+
+#include <cstring>
+
+namespace gvisor {
+namespace testing {
+
+uint32_t IPFromInetSockaddr(const struct sockaddr* addr) {
+ auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
+ return in_addr->sin_addr.s_addr;
+}
+
+uint16_t PortFromInetSockaddr(const struct sockaddr* addr) {
+ auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
+ return ntohs(in_addr->sin_port);
+}
+
+PosixErrorOr<int> InterfaceIndex(std::string name) {
+ int index = if_nametoindex(name.c_str());
+ if (index) {
+ return index;
+ }
+ return PosixError(errno);
+}
+
+namespace {
+
+std::string DescribeSocketType(int type) {
+ return absl::StrCat(((type & SOCK_NONBLOCK) != 0) ? "non-blocking " : "",
+ ((type & SOCK_CLOEXEC) != 0) ? "close-on-exec " : "");
+}
+
+} // namespace
+
+SocketPairKind IPv6TCPAcceptBindSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket");
+ return SocketPairKind{
+ description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindSocketPairCreator(AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind IPv4TCPAcceptBindSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket");
+ return SocketPairKind{
+ description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindSocketPairCreator(AF_INET, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind DualStackTCPAcceptBindSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket");
+ return SocketPairKind{
+ description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindSocketPairCreator(AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ true)};
+}
+
+SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket");
+ return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket");
+ return SocketPairKind{description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket");
+ return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ true)};
+}
+
+SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv6 UDP socket");
+ return SocketPairKind{
+ description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP,
+ UDPBidirectionalBindSocketPairCreator(AF_INET6, type | SOCK_DGRAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind IPv4UDPBidirectionalBindSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv4 UDP socket");
+ return SocketPairKind{
+ description, AF_INET, type | SOCK_DGRAM, IPPROTO_UDP,
+ UDPBidirectionalBindSocketPairCreator(AF_INET, type | SOCK_DGRAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected dual stack UDP socket");
+ return SocketPairKind{
+ description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP,
+ UDPBidirectionalBindSocketPairCreator(AF_INET6, type | SOCK_DGRAM, 0,
+ /* dual_stack = */ true)};
+}
+
+SocketPairKind IPv4UDPUnboundSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv4 UDP socket");
+ return SocketPairKind{
+ description, AF_INET, type | SOCK_DGRAM, IPPROTO_UDP,
+ UDPUnboundSocketPairCreator(AF_INET, type | SOCK_DGRAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketKind IPv4UDPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv4 UDP socket");
+ return SocketKind{
+ description, AF_INET, type | SOCK_DGRAM, IPPROTO_UDP,
+ UnboundSocketCreator(AF_INET, type | SOCK_DGRAM, IPPROTO_UDP)};
+}
+
+SocketKind IPv6UDPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv6 UDP socket");
+ return SocketKind{
+ description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP,
+ UnboundSocketCreator(AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP)};
+}
+
+SocketKind IPv4TCPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv4 TCP socket");
+ return SocketKind{
+ description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP,
+ UnboundSocketCreator(AF_INET, type | SOCK_STREAM, IPPROTO_TCP)};
+}
+
+SocketKind IPv6TCPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv6 TCP socket");
+ return SocketKind{
+ description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ UnboundSocketCreator(AF_INET6, type | SOCK_STREAM, IPPROTO_TCP)};
+}
+
+PosixError IfAddrHelper::Load() {
+ Release();
+ RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_));
+ return NoError();
+}
+
+void IfAddrHelper::Release() {
+ if (ifaddr_) {
+ freeifaddrs(ifaddr_);
+ ifaddr_ = nullptr;
+ }
+}
+
+std::vector<std::string> IfAddrHelper::InterfaceList(int family) const {
+ std::vector<std::string> names;
+ for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
+ if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
+ continue;
+ }
+ names.push_back(ifa->ifa_name);
+ }
+ return names;
+}
+
+const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const {
+ for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
+ if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
+ continue;
+ }
+ if (name == ifa->ifa_name) {
+ return ifa->ifa_addr;
+ }
+ }
+ return nullptr;
+}
+
+PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const {
+ return InterfaceIndex(name);
+}
+
+std::string GetAddr4Str(const in_addr* a) {
+ char str[INET_ADDRSTRLEN];
+ inet_ntop(AF_INET, a, str, sizeof(str));
+ return std::string(str);
+}
+
+std::string GetAddr6Str(const in6_addr* a) {
+ char str[INET6_ADDRSTRLEN];
+ inet_ntop(AF_INET6, a, str, sizeof(str));
+ return std::string(str);
+}
+
+std::string GetAddrStr(const sockaddr* a) {
+ if (a->sa_family == AF_INET) {
+ auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr);
+ return GetAddr4Str(src);
+ } else if (a->sa_family == AF_INET6) {
+ auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr);
+ return GetAddr6Str(src);
+ }
+ return std::string("<invalid>");
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
new file mode 100644
index 000000000..9c3859fcd
--- /dev/null
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_IP_SOCKET_TEST_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_IP_SOCKET_TEST_UTIL_H_
+
+#include <arpa/inet.h>
+#include <ifaddrs.h>
+#include <sys/types.h>
+
+#include <string>
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Extracts the IP address from an inet sockaddr in network byte order.
+uint32_t IPFromInetSockaddr(const struct sockaddr* addr);
+
+// Extracts the port from an inet sockaddr in host byte order.
+uint16_t PortFromInetSockaddr(const struct sockaddr* addr);
+
+// InterfaceIndex returns the index of the named interface.
+PosixErrorOr<int> InterfaceIndex(std::string name);
+
+// IPv6TCPAcceptBindSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and accept() syscalls with AF_INET6 and the
+// given type bound to the IPv6 loopback.
+SocketPairKind IPv6TCPAcceptBindSocketPair(int type);
+
+// IPv4TCPAcceptBindSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and accept() syscalls with AF_INET and the
+// given type bound to the IPv4 loopback.
+SocketPairKind IPv4TCPAcceptBindSocketPair(int type);
+
+// DualStackTCPAcceptBindSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and accept() syscalls with AF_INET6 and the
+// given type bound to the IPv4 loopback.
+SocketPairKind DualStackTCPAcceptBindSocketPair(int type);
+
+// IPv6TCPAcceptBindPersistentListenerSocketPair is like
+// IPv6TCPAcceptBindSocketPair except it uses a persistent listening socket to
+// create all socket pairs.
+SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type);
+
+// IPv4TCPAcceptBindPersistentListenerSocketPair is like
+// IPv4TCPAcceptBindSocketPair except it uses a persistent listening socket to
+// create all socket pairs.
+SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type);
+
+// DualStackTCPAcceptBindPersistentListenerSocketPair is like
+// DualStackTCPAcceptBindSocketPair except it uses a persistent listening socket
+// to create all socket pairs.
+SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type);
+
+// IPv6UDPBidirectionalBindSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and connect() syscalls with AF_INET6 and the
+// given type bound to the IPv6 loopback.
+SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type);
+
+// IPv4UDPBidirectionalBindSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and connect() syscalls with AF_INET and the
+// given type bound to the IPv4 loopback.
+SocketPairKind IPv4UDPBidirectionalBindSocketPair(int type);
+
+// DualStackUDPBidirectionalBindSocketPair returns a SocketPairKind that
+// represents SocketPairs created with bind() and connect() syscalls with
+// AF_INET6 and the given type bound to the IPv4 loopback.
+SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type);
+
+// IPv4UDPUnboundSocketPair returns a SocketPairKind that represents
+// SocketPairs created with AF_INET and the given type.
+SocketPairKind IPv4UDPUnboundSocketPair(int type);
+
+// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_DGRAM, and the given type.
+SocketKind IPv4UDPUnboundSocket(int type);
+
+// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_DGRAM, and the given type.
+SocketKind IPv6UDPUnboundSocket(int type);
+
+// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_STREAM and the given type.
+SocketKind IPv4TCPUnboundSocket(int type);
+
+// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_STREAM and the given type.
+SocketKind IPv6TCPUnboundSocket(int type);
+
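+// These SocketKind factories are meant to be consumed by parameterized tests
+// via gtest's ValuesIn, as ioctl.cc does elsewhere in this change. A typical
+// instantiation looks roughly like the following (MySocketTest is a
+// placeholder fixture name):
+//
+//   std::vector<SocketKind> kinds = {IPv4UDPUnboundSocket(0),
+//                                    IPv6UDPUnboundSocket(0)};
+//   INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, MySocketTest,
+//                            ::testing::ValuesIn(kinds));
+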
+// IfAddrHelper is a helper class that determines the local interfaces present
+// and provides functions to obtain their names, index numbers, and IP
+// addresses.
+class IfAddrHelper {
+ public:
+ IfAddrHelper() : ifaddr_(nullptr) {}
+ ~IfAddrHelper() { Release(); }
+
+ PosixError Load();
+ void Release();
+
+ std::vector<std::string> InterfaceList(int family) const;
+
+ const sockaddr* GetAddr(int family, std::string name) const;
+ PosixErrorOr<int> GetIndex(std::string name) const;
+
+ private:
+ struct ifaddrs* ifaddr_;
+};
+
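+// A minimal usage sketch for IfAddrHelper (illustrative only; error handling
+// elided):
+//
+//   IfAddrHelper ifaddrs;
+//   ASSERT_NO_ERRNO(ifaddrs.Load());
+//   for (const auto& name : ifaddrs.InterfaceList(AF_INET)) {
+//     const sockaddr* addr = ifaddrs.GetAddr(AF_INET, name);
+//     std::cout << name << ": " << GetAddrStr(addr) << std::endl;
+//   }
+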
+// GetAddr4Str returns the given IPv4 network address structure as a string.
+std::string GetAddr4Str(const in_addr* a);
+
+// GetAddr6Str returns the given IPv6 network address structure as a string.
+std::string GetAddr6Str(const in6_addr* a);
+
+// GetAddrStr returns the given IPv4 or IPv6 network address structure as a
+// string.
+std::string GetAddrStr(const sockaddr* a);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_IP_SOCKET_TEST_UTIL_H_
diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc
new file mode 100644
index 000000000..b8e4ece64
--- /dev/null
+++ b/test/syscalls/linux/iptables.cc
@@ -0,0 +1,204 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/iptables.h"
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/netfilter/x_tables.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <stdio.h>
+#include <sys/poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr char kNatTablename[] = "nat";
+constexpr char kErrorTarget[] = "ERROR";
+constexpr size_t kEmptyStandardEntrySize =
+ sizeof(struct ipt_entry) + sizeof(struct ipt_standard_target);
+constexpr size_t kEmptyErrorEntrySize =
+ sizeof(struct ipt_entry) + sizeof(struct ipt_error_target);
+
+TEST(IPTablesBasic, CreateSocket) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallSucceeds());
+
+ ASSERT_THAT(close(sock), SyscallSucceeds());
+}
+
+TEST(IPTablesBasic, FailSockoptNonRaw) {
+ // Even if the user has CAP_NET_RAW, they shouldn't be able to use the
+ // iptables sockopts with a non-raw socket.
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds());
+
+ struct ipt_getinfo info = {};
+ snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename);
+ socklen_t info_size = sizeof(info);
+ EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+
+ ASSERT_THAT(close(sock), SyscallSucceeds());
+}
+
+// Fixture for iptables tests.
+class IPTablesTest : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // The socket via which to manipulate iptables.
+ int s_;
+};
+
+void IPTablesTest::SetUp() {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
+}
+
+void IPTablesTest::TearDown() {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+}
+
+// This tests the initial state of a machine with empty iptables. We don't have
+// a guarantee that the iptables are empty when running in native, but we can
+// test that gVisor has the same initial state that a newly-booted Linux machine
+// would have.
+TEST_F(IPTablesTest, InitialState) {
+ SKIP_IF(!IsRunningOnGvisor());
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ //
+ // Get info via sockopt.
+ //
+ struct ipt_getinfo info = {};
+ snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename);
+ socklen_t info_size = sizeof(info);
+ ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size),
+ SyscallSucceeds());
+
+ // The nat table supports PREROUTING, INPUT, OUTPUT, and POSTROUTING.
+ unsigned int valid_hooks = (1 << NF_IP_PRE_ROUTING) | (1 << NF_IP_LOCAL_OUT) |
+ (1 << NF_IP_POST_ROUTING) | (1 << NF_IP_LOCAL_IN);
+
+ EXPECT_EQ(info.valid_hooks, valid_hooks);
+
+ // Each chain consists of an empty entry with a standard target.
+ EXPECT_EQ(info.hook_entry[NF_IP_PRE_ROUTING], 0);
+ EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_IN], kEmptyStandardEntrySize);
+ EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2);
+ EXPECT_EQ(info.hook_entry[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3);
+
+ // The underflow points are the same as the entry points.
+ EXPECT_EQ(info.underflow[NF_IP_PRE_ROUTING], 0);
+ EXPECT_EQ(info.underflow[NF_IP_LOCAL_IN], kEmptyStandardEntrySize);
+ EXPECT_EQ(info.underflow[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2);
+ EXPECT_EQ(info.underflow[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3);
+
+ // One entry for each chain, plus an error entry at the end.
+ EXPECT_EQ(info.num_entries, 5);
+
+ EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize);
+ EXPECT_EQ(strcmp(info.name, kNatTablename), 0);
+
+ //
+ // Use info to get entries.
+ //
+ socklen_t entries_size = sizeof(struct ipt_get_entries) + info.size;
+ struct ipt_get_entries* entries =
+ static_cast<struct ipt_get_entries*>(malloc(entries_size));
+ snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename);
+ entries->size = info.size;
+ ASSERT_THAT(
+ getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size),
+ SyscallSucceeds());
+
+ // Verify the name and size.
+ ASSERT_EQ(info.size, entries->size);
+ ASSERT_EQ(strcmp(entries->name, kNatTablename), 0);
+
+ // Verify that the entrytable is 4 entries with accept targets and no matches
+ // followed by a single error target.
+ size_t entry_offset = 0;
+ while (entry_offset < entries->size) {
+ struct ipt_entry* entry = reinterpret_cast<struct ipt_entry*>(
+ reinterpret_cast<char*>(entries->entrytable) + entry_offset);
+
+ // ip should be zeroes.
+ struct ipt_ip zeroed = {};
+ EXPECT_EQ(memcmp(static_cast<void*>(&zeroed),
+ static_cast<void*>(&entry->ip), sizeof(zeroed)),
+ 0);
+
+ // target_offset should point just past the ipt_entry, as these entries have
+ // no matches.
+ EXPECT_EQ(entry->target_offset, sizeof(ipt_entry));
+
+ if (entry_offset < kEmptyStandardEntrySize * 4) {
+ // The first 4 entries are standard targets.
+ struct ipt_standard_target* target =
+ reinterpret_cast<struct ipt_standard_target*>(entry->elems);
+ EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize);
+ EXPECT_EQ(target->target.u.user.target_size, sizeof(*target));
+ EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0);
+ EXPECT_EQ(target->target.u.user.revision, 0);
+ // Built-in verdicts are encoded as -(verdict) - 1, so this is what's
+ // returned for an accept verdict.
+ EXPECT_EQ(target->verdict, -NF_ACCEPT - 1);
+ } else {
+ // The last entry is an error target.
+ struct ipt_error_target* target =
+ reinterpret_cast<struct ipt_error_target*>(entry->elems);
+ EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize);
+ EXPECT_EQ(target->target.u.user.target_size, sizeof(*target));
+ EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0);
+ EXPECT_EQ(target->target.u.user.revision, 0);
+ EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0);
+ }
+
+ entry_offset += entry->next_offset;
+ }
+
+ free(entries);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h
new file mode 100644
index 000000000..0719c60a4
--- /dev/null
+++ b/test/syscalls/linux/iptables.h
@@ -0,0 +1,198 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// There are a number of structs and values that we can't #include because of a
+// difference between C and C++ (C++ won't let you implicitly cast from void* to
+// struct something*). We re-define them here.
+
+#ifndef GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_
+#define GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_
+
+// Netfilter headers require some headers to precede them.
+// clang-format off
+#include <netinet/in.h>
+#include <stddef.h>
+// clang-format on
+
+#include <linux/netfilter/x_tables.h>
+#include <linux/netfilter_ipv4.h>
+#include <net/if.h>
+#include <netinet/ip.h>
+#include <stdint.h>
+
+#define ipt_standard_target xt_standard_target
+#define ipt_entry_target xt_entry_target
+#define ipt_error_target xt_error_target
+
+enum SockOpts {
+ // For setsockopt.
+ BASE_CTL = 64,
+ SO_SET_REPLACE = BASE_CTL,
+ SO_SET_ADD_COUNTERS,
+ SO_SET_MAX = SO_SET_ADD_COUNTERS,
+
+ // For getsockopt.
+ SO_GET_INFO = BASE_CTL,
+ SO_GET_ENTRIES,
+ SO_GET_REVISION_MATCH,
+ SO_GET_REVISION_TARGET,
+ SO_GET_MAX = SO_GET_REVISION_TARGET
+};
+
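+// These opcodes are passed directly as the optname at level IPPROTO_IP on a
+// raw socket. For example, table metadata is fetched as follows (a sketch
+// mirroring iptables.cc in this change; error handling elided):
+//
+//   int sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);  // Needs CAP_NET_RAW.
+//   struct ipt_getinfo info = {};
+//   snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", "nat");
+//   socklen_t len = sizeof(info);
+//   getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &len);
+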
+// ipt_ip specifies basic matching criteria that can be applied by examining
+// only the IP header of a packet.
+struct ipt_ip {
+ // Source IP address.
+ struct in_addr src;
+
+ // Destination IP address.
+ struct in_addr dst;
+
+ // Source IP address mask.
+ struct in_addr smsk;
+
+ // Destination IP address mask.
+ struct in_addr dmsk;
+
+ // Input interface.
+ char iniface[IFNAMSIZ];
+
+ // Output interface.
+ char outiface[IFNAMSIZ];
+
+ // Input interface mask.
+ unsigned char iniface_mask[IFNAMSIZ];
+
+ // Output interface mask.
+ unsigned char outiface_mask[IFNAMSIZ];
+
+ // Transport protocol.
+ uint16_t proto;
+
+ // Flags.
+ uint8_t flags;
+
+ // Inverse flags.
+ uint8_t invflags;
+};
+
+// ipt_entry is an iptables rule. It contains information about what packets the
+// rule matches and what action (target) to perform for matching packets.
+struct ipt_entry {
+ // Basic matching information used to match a packet's IP header.
+ struct ipt_ip ip;
+
+ // A caching field that isn't used by userspace.
+ unsigned int nfcache;
+
+ // The number of bytes between the start of this ipt_entry struct and the
+ // rule's target.
+ uint16_t target_offset;
+
+ // The total size of this rule, from the beginning of the entry to the end of
+ // the target.
+ uint16_t next_offset;
+
+ // A return pointer not used by userspace.
+ unsigned int comefrom;
+
+ // Counters for packets and bytes, which we don't yet implement.
+ struct xt_counters counters;
+
+ // The data for all of this rule's matches, followed by the target. This runs
+ // beyond the value of sizeof(struct ipt_entry).
+ unsigned char elems[0];
+};
+
+// Passed to getsockopt(SO_GET_INFO).
+struct ipt_getinfo {
+ // The name of the table. This is the only field the user fills in; the rest
+ // is filled in when getsockopt returns. Currently "nat" and "mangle" are
+ // supported.
+ char name[XT_TABLE_MAXNAMELEN];
+
+ // A bitmap of which hooks apply to the table. For example, a table with hooks
+ // PREROUTING and FORWARD has the value
+ // (1 << NF_IP_PRE_ROUTING) | (1 << NF_IP_FORWARD).
+ unsigned int valid_hooks;
+
+ // The offset into the entry table for each valid hook. The entry table is
+ // returned by getsockopt(SO_GET_ENTRIES).
+ unsigned int hook_entry[NF_IP_NUMHOOKS];
+
+ // For each valid hook, the underflow is the offset into the entry table to
+ // jump to in case traversing the table yields no verdict (although I have no
+ // clue how that could happen - builtin chains always end with a policy, and
+ // user-defined chains always end with a RETURN).
+ //
+ // The entry referred to must be an "unconditional" entry, meaning it has no
+ // matches, specifies no IP criteria, and either DROPs or ACCEPTs packets. It
+ // basically has to be capable of making a definitive decision no matter what
+ // it's passed.
+ unsigned int underflow[NF_IP_NUMHOOKS];
+
+ // The number of entries in the entry table returned by
+ // getsockopt(SO_GET_ENTRIES).
+ unsigned int num_entries;
+
+ // The size of the entry table returned by getsockopt(SO_GET_ENTRIES).
+ unsigned int size;
+};
+
+// Passed to getsockopt(SO_GET_ENTRIES).
+struct ipt_get_entries {
+ // The name of the table. The user fills this in. Currently "nat" and "mangle"
+ // are supported.
+ char name[XT_TABLE_MAXNAMELEN];
+
+ // The size of the entry table in bytes. The user fills this in with the value
+ // from struct ipt_getinfo.size.
+ unsigned int size;
+
+ // The entries for the given table. This will run past the size defined by
+ // sizeof(struct ipt_get_entries).
+ struct ipt_entry entrytable[0];
+};
+
+// Passed to setsockopt(SO_SET_REPLACE).
+struct ipt_replace {
+ // The name of the table.
+ char name[XT_TABLE_MAXNAMELEN];
+
+ // The same as struct ipt_getinfo.valid_hooks. Users don't change this.
+ unsigned int valid_hooks;
+
+ // The same as struct ipt_getinfo.num_entries.
+ unsigned int num_entries;
+
+ // The same as struct ipt_getinfo.size.
+ unsigned int size;
+
+ // The same as struct ipt_getinfo.hook_entry.
+ unsigned int hook_entry[NF_IP_NUMHOOKS];
+
+ // The same as struct ipt_getinfo.underflow.
+ unsigned int underflow[NF_IP_NUMHOOKS];
+
+ // The number of counters, which should equal the number of entries.
+ unsigned int num_counters;
+
+ // The unchanged values from each ipt_entry's counters.
+ struct xt_counters* counters;
+
+ // The entries to write to the table. This will run past the size defined by
+ // sizeof(struct ipt_replace).
+ struct ipt_entry entries[0];
+};
+
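+// Entries returned via getsockopt(SO_GET_ENTRIES) are variable-length, so the
+// table must be walked using next_offset rather than indexed as an array. A
+// minimal traversal sketch (bounds checking elided):
+//
+//   size_t off = 0;
+//   while (off < entries->size) {
+//     auto* e = reinterpret_cast<struct ipt_entry*>(
+//         reinterpret_cast<char*>(entries->entrytable) + off);
+//     // e->elems holds the matches; the target starts at e->target_offset.
+//     off += e->next_offset;
+//   }
+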
+#endif // GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
new file mode 100644
index 000000000..e397d5f57
--- /dev/null
+++ b/test/syscalls/linux/itimer.cc
@@ -0,0 +1,366 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/socket.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <time.h>
+
+#include <atomic>
+#include <functional>
+#include <iostream>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr char kSIGALRMToMainThread[] = "--itimer_sigalrm_to_main_thread";
+constexpr char kSIGPROFFairnessActive[] = "--itimer_sigprof_fairness_active";
+constexpr char kSIGPROFFairnessIdle[] = "--itimer_sigprof_fairness_idle";
+
+// Time period to be set for the itimers.
+constexpr absl::Duration kPeriod = absl::Milliseconds(25);
+// Total amount of time to spend per thread.
+constexpr absl::Duration kTestDuration = absl::Seconds(20);
+// Amount of spin iterations to perform as the minimum work item per thread.
+// Chosen to be sub-millisecond range.
+constexpr int kIterations = 10000000;
+// Allow deviation in the number of samples.
+constexpr double kNumSamplesDeviationRatio = 0.2;
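+// For a sense of scale (approximate; the tests compute the real bound from
+// observed totals): with kPeriod = 25ms and kTestDuration = 20s, a run sees
+// on the order of 20s / 25ms = 800 expirations, so a deviation ratio of 0.2
+// tolerates an imbalance of roughly 160 samples between the two workers.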
+
+TEST(ItimerTest, ItimervalUpdatedBeforeExpiration) {
+ constexpr int kSleepSecs = 10;
+ constexpr int kAlarmSecs = 15;
+ static_assert(
+ kSleepSecs < kAlarmSecs,
+ "kSleepSecs must be less than kAlarmSecs for the test to be meaningful");
+ constexpr int kMaxRemainingSecs = kAlarmSecs - kSleepSecs;
+
+ // Install a no-op handler for SIGALRM.
+ struct sigaction sa = {};
+ sigfillset(&sa.sa_mask);
+ sa.sa_handler = +[](int signo) {};
+ auto const cleanup_sa =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ // Set an itimer-based alarm for kAlarmSecs from now.
+ struct itimerval itv = {};
+ itv.it_value.tv_sec = kAlarmSecs;
+ auto const cleanup_itimer =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, itv));
+
+ // After sleeping for kSleepSecs, the itimer value should reflect the elapsed
+ // time even if it hasn't expired.
+ absl::SleepFor(absl::Seconds(kSleepSecs));
+ ASSERT_THAT(getitimer(ITIMER_REAL, &itv), SyscallSucceeds());
+ EXPECT_TRUE(
+ itv.it_value.tv_sec < kMaxRemainingSecs ||
+ (itv.it_value.tv_sec == kMaxRemainingSecs && itv.it_value.tv_usec == 0))
+ << "Remaining time: " << itv.it_value.tv_sec << " seconds + "
+ << itv.it_value.tv_usec << " microseconds";
+}
+
+ABSL_CONST_INIT static thread_local std::atomic_int signal_test_num_samples =
+ ATOMIC_VAR_INIT(0);
+
+void SignalTestSignalHandler(int /*signum*/) { signal_test_num_samples++; }
+
+struct SignalTestResult {
+ int expected_total;
+ int main_thread_samples;
+ std::vector<int> worker_samples;
+};
+
+std::ostream& operator<<(std::ostream& os, const SignalTestResult& r) {
+ os << "{expected_total: " << r.expected_total
+ << ", main_thread_samples: " << r.main_thread_samples
+ << ", worker_samples: [";
+ bool first = true;
+ for (int sample : r.worker_samples) {
+ if (!first) {
+ os << ", ";
+ }
+ os << sample;
+ first = false;
+ }
+ os << "]}";
+ return os;
+}
+
+// Starts two worker threads and an itimer with the given id, then measures
+// the number of signals delivered to each thread.
+SignalTestResult ItimerSignalTest(int id, clock_t main_clock,
+ clock_t worker_clock, int signal,
+ absl::Duration sleep) {
+ signal_test_num_samples = 0;
+
+ struct sigaction sa = {};
+ sa.sa_handler = &SignalTestSignalHandler;
+ sa.sa_flags = SA_RESTART;
+ sigemptyset(&sa.sa_mask);
+ auto sigaction_cleanup = ScopedSigaction(signal, sa).ValueOrDie();
+
+ int socketfds[2];
+ TEST_PCHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, socketfds) == 0);
+
+ // Do the spinning in the workers.
+ std::function<void*(int)> work = [&](int socket_fd) {
+ FileDescriptor fd(socket_fd);
+
+ absl::Time finish = Now(worker_clock) + kTestDuration;
+ while (Now(worker_clock) < finish) {
+ // Blocked on read.
+ char c;
+ RetryEINTR(read)(fd.get(), &c, 1);
+ for (int i = 0; i < kIterations; i++) {
+ // Ensure compiler won't optimize this loop away.
+ asm("");
+ }
+
+ if (sleep != absl::ZeroDuration()) {
+ // Sleep so that the entire process is idle for a while.
+ absl::SleepFor(sleep);
+ }
+
+ // Unblock the other thread.
+ RetryEINTR(write)(fd.get(), &c, 1);
+ }
+
+ return reinterpret_cast<void*>(signal_test_num_samples.load());
+ };
+
+ ScopedThread th1(
+ static_cast<std::function<void*()>>(std::bind(work, socketfds[0])));
+ ScopedThread th2(
+ static_cast<std::function<void*()>>(std::bind(work, socketfds[1])));
+
+ absl::Time start = Now(main_clock);
+ // Start the timer.
+ struct itimerval timer = {};
+ timer.it_value = absl::ToTimeval(kPeriod);
+ timer.it_interval = absl::ToTimeval(kPeriod);
+ auto cleanup_itimer = ScopedItimer(id, timer).ValueOrDie();
+
+ // Unblock th1.
+ //
+ // N.B. th2 owns socketfds[1] but can't close it until it unblocks.
+ char c = 0;
+ TEST_CHECK(write(socketfds[1], &c, 1) == 1);
+
+ SignalTestResult result;
+
+ // Wait for the workers to be done and collect their sample counts.
+ result.worker_samples.push_back(reinterpret_cast<int64_t>(th1.Join()));
+ result.worker_samples.push_back(reinterpret_cast<int64_t>(th2.Join()));
+ cleanup_itimer.Release()();
+ result.expected_total = (Now(main_clock) - start) / kPeriod;
+ result.main_thread_samples = signal_test_num_samples.load();
+
+ return result;
+}
+
+int TestSIGALRMToMainThread() {
+ SignalTestResult result =
+ ItimerSignalTest(ITIMER_REAL, CLOCK_REALTIME, CLOCK_REALTIME, SIGALRM,
+ absl::ZeroDuration());
+
+ std::cerr << "result: " << result << std::endl;
+
+ // ITIMER_REAL-generated SIGALRMs prefer to deliver to the thread group leader
+ // (but don't guarantee it), so we expect to see most samples on the main
+ // thread.
+ //
+ // The number of SIGALRMs delivered to a worker should not exceed 20%
+ // of the number of total signals expected (this is somewhat arbitrary).
+ const int worker_threshold = result.expected_total / 5;
+
+ // Linux only guarantees that timers will never expire before the requested
+ // time. Thus, we only check the upper bound, and also that at least one
+ // sample was received.
+ TEST_CHECK(result.main_thread_samples <= result.expected_total);
+ TEST_CHECK(result.main_thread_samples > 0);
+ for (int num : result.worker_samples) {
+ TEST_CHECK_MSG(num <= worker_threshold, "worker received too many samples");
+ }
+
+ return 0;
+}
+
+// Random save/restore is disabled as it introduces additional latency and
+// unpredictable distribution patterns.
+TEST(ItimerTest, DeliversSIGALRMToMainThread_NoRandomSave) {
+ pid_t child;
+ int execve_errno;
+ auto kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/proc/self/exe", {"/proc/self/exe", kSIGALRMToMainThread},
+ {}, &child, &execve_errno));
+ EXPECT_EQ(0, execve_errno);
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+
+ // Not required anymore.
+ kill.Release();
+
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status;
+}
+
+// Signals are delivered to threads fairly.
+//
+// sleep indicates how long worker threads sleep each iteration to make the
+// entire process idle.
+int TestSIGPROFFairness(absl::Duration sleep) {
+ SignalTestResult result =
+ ItimerSignalTest(ITIMER_PROF, CLOCK_PROCESS_CPUTIME_ID,
+ CLOCK_THREAD_CPUTIME_ID, SIGPROF, sleep);
+
+ std::cerr << "result: " << result << std::endl;
+
+ // The number of samples on the main thread should be very low as it did
+ // nothing.
+ TEST_CHECK(result.main_thread_samples < 80);
+
+ // Both workers should get roughly equal number of samples.
+ TEST_CHECK(result.worker_samples.size() == 2);
+
+ TEST_CHECK(result.expected_total > 0);
+
+ // In an ideal world each thread would get exactly 50% of the signals, but
+ // since that's unlikely to happen we allow their counts to differ by up to
+ // kNumSamplesDeviationRatio of the total observed samples.
+ TEST_CHECK_MSG(std::abs(result.worker_samples[0] - result.worker_samples[1]) <
+ ((result.worker_samples[0] + result.worker_samples[1]) *
+ kNumSamplesDeviationRatio),
+ "one worker received disproportionate share of samples");
+
+ return 0;
+}
+
+// Random save/restore is disabled as it introduces additional latency and
+// unpredictable distribution patterns.
+TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) {
+ // On the KVM and ptrace platforms, switches between sentry and application
+ // context are sometimes extremely slow, causing the itimer to send SIGPROF to
+ // a thread that either already has one pending or has had SIGPROF delivered,
+  // but hasn't handled it yet (and thus still has SIGPROF masked). In
+ // either case, since itimer signals are group-directed, signal sending falls
+ // back to notifying the thread group leader. ItimerSignalTest() fails if "too
+ // many" signals are delivered to the thread group leader, so these tests are
+ // flaky on these platforms.
+ //
+ // TODO(b/143247272): Clarify why context switches are so slow on KVM.
+ const auto gvisor_platform = GvisorPlatform();
+ SKIP_IF(gvisor_platform == Platform::kKVM ||
+ gvisor_platform == Platform::kPtrace);
+
+ pid_t child;
+ int execve_errno;
+ auto kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/proc/self/exe", {"/proc/self/exe", kSIGPROFFairnessActive},
+ {}, &child, &execve_errno));
+ EXPECT_EQ(0, execve_errno);
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+
+ // Not required anymore.
+ kill.Release();
+
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+// Random save/restore is disabled as it introduces additional latency and
+// unpredictable distribution patterns.
+TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) {
+ // See comment in DeliversSIGPROFToThreadsRoughlyFairlyActive.
+ const auto gvisor_platform = GvisorPlatform();
+ SKIP_IF(gvisor_platform == Platform::kKVM ||
+ gvisor_platform == Platform::kPtrace);
+
+ pid_t child;
+ int execve_errno;
+ auto kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/proc/self/exe", {"/proc/self/exe", kSIGPROFFairnessIdle},
+ {}, &child, &execve_errno));
+ EXPECT_EQ(0, execve_errno);
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+
+ // Not required anymore.
+ kill.Release();
+
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "Exited with code: " << status;
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
+
+namespace {
+void MaskSIGPIPE() {
+ // Always mask SIGPIPE as it's common and tests aren't expected to handle it.
+ // We don't take the TestInit() path so we must do this manually.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0);
+}
+} // namespace
+
+int main(int argc, char** argv) {
+ // These tests require no background threads, so check for them before
+ // TestInit.
+ for (int i = 0; i < argc; i++) {
+ absl::string_view arg(argv[i]);
+
+ if (arg == gvisor::testing::kSIGALRMToMainThread) {
+ MaskSIGPIPE();
+ return gvisor::testing::TestSIGALRMToMainThread();
+ }
+ if (arg == gvisor::testing::kSIGPROFFairnessActive) {
+ MaskSIGPIPE();
+ return gvisor::testing::TestSIGPROFFairness(absl::ZeroDuration());
+ }
+ if (arg == gvisor::testing::kSIGPROFFairnessIdle) {
+ MaskSIGPIPE();
+      // A sleep time longer than one ClockTick (10ms) exercises the path
+      // where gVisor's kernel.cpuClockTicker itself goes to sleep.
+ return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(25));
+ }
+ }
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc
new file mode 100644
index 000000000..db29bd59c
--- /dev/null
+++ b/test/syscalls/linux/kill.cc
@@ -0,0 +1,383 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <csignal>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID");
+
+using ::testing::Ge;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(KillTest, CanKillValidPid) {
+ // If pid is positive, then signal sig is sent to the process with the ID
+ // specified by pid.
+ EXPECT_THAT(kill(getpid(), 0), SyscallSucceeds());
+ // If pid equals 0, then sig is sent to every process in the process group of
+ // the calling process.
+ EXPECT_THAT(kill(0, 0), SyscallSucceeds());
+
+ ScopedThread([] { EXPECT_THAT(kill(gettid(), 0), SyscallSucceeds()); });
+}
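+
+// Signal 0 performs all of kill's existence and permission checks without
+// delivering anything, which makes it a common liveness probe. A sketch of
+// the idiom (ProcessExists is a hypothetical helper, not used below):
+//
+//   bool ProcessExists(pid_t pid) {
+//     if (kill(pid, 0) == 0) return true;  // exists and we may signal it.
+//     return errno == EPERM;               // exists, but we may not.
+//   }                                      // ESRCH: no such process.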
+
+void SigHandler(int sig, siginfo_t* info, void* context) { _exit(0); }
+
+// If pid equals -1, then sig is sent to every process for which the calling
+// process has permission to send signals, except for process 1 (init).
+TEST(KillTest, CanKillAllPIDs) {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+ FileDescriptor read_fd(pipe_fds[0]);
+ FileDescriptor write_fd(pipe_fds[1]);
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ read_fd.reset();
+
+ struct sigaction sa;
+ sa.sa_sigaction = SigHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ TEST_PCHECK(sigaction(SIGWINCH, &sa, nullptr) == 0);
+ MaybeSave();
+
+ // Indicate to the parent that we're ready.
+ write_fd.reset();
+
+ // Wait until we get the signal from the parent.
+ while (true) {
+ pause();
+ }
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ write_fd.reset();
+
+ // Wait for the child to indicate that it's unmasked the signal by closing
+ // the write end.
+ char buf;
+ ASSERT_THAT(ReadFd(read_fd.get(), &buf, 1), SyscallSucceedsWithValue(0));
+
+ // Signal the child and wait for it to die with status 0, indicating that
+ // it got the expected signal.
+ EXPECT_THAT(kill(-1, SIGWINCH), SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+}
+
+TEST(KillTest, CannotKillInvalidPID) {
+ // We need an unused pid to verify that kill fails when given one.
+ //
+ // There is no way to guarantee that a PID is unused, but the PID of a
+ // recently exited process likely won't be reused soon.
+ pid_t fake_pid = fork();
+ if (fake_pid == 0) {
+ _exit(0);
+ }
+
+ ASSERT_THAT(fake_pid, SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(fake_pid, &status, 0),
+ SyscallSucceedsWithValue(fake_pid));
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+
+ EXPECT_THAT(kill(fake_pid, 0), SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(KillTest, CannotUseInvalidSignal) {
+ EXPECT_THAT(kill(getpid(), 200), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(KillTest, CanKillRemoteProcess) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ while (true) {
+ pause();
+ }
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ EXPECT_THAT(kill(pid, SIGKILL), SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(SIGKILL, WTERMSIG(status));
+}
+
+TEST(KillTest, CanKillOwnProcess) {
+ EXPECT_THAT(kill(getpid(), 0), SyscallSucceeds());
+}
+
+// Verify that a process can be signaled even via the tid of a non-leader
+// thread: kill() always addresses the whole thread group (see the sketch
+// after this test).
+TEST(KillTest, CannotKillTid) {
+ pid_t tid;
+ bool tid_available = false;
+ bool finished = false;
+ absl::Mutex mu;
+ ScopedThread t([&] {
+ mu.Lock();
+ tid = gettid();
+ tid_available = true;
+ mu.Await(absl::Condition(&finished));
+ mu.Unlock();
+ });
+ mu.LockWhen(absl::Condition(&tid_available));
+ EXPECT_THAT(kill(tid, 0), SyscallSucceeds());
+ finished = true;
+ mu.Unlock();
+}
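+
+// For contrast, directing a signal at one specific thread requires tgkill(2);
+// kill(2) always resolves a tid to its whole thread group. A hedged sketch
+// (tgid and tid are placeholders):
+//
+//   kill(tid, SIGUSR1);                       // signals tid's whole process.
+//   syscall(SYS_tgkill, tgid, tid, SIGUSR1);  // signals exactly that thread.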
+
+TEST(KillTest, SetPgid) {
+ for (int i = 0; i < 10; i++) {
+    // The following is the normal pattern for creating a new process group:
+    // both the parent and the child call setpgid, so that neither depends on
+    // the other having run first (see the sketch after this test). We repeat
+    // this ten times to catch races.
+ pid_t pid = fork();
+ if (pid == 0) {
+ setpgid(0, 0);
+ while (true) {
+ pause();
+ }
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+    // Set the child's group from the parent side as well, then kill the child.
+ ASSERT_THAT(setpgid(pid, pid), SyscallSucceeds());
+ EXPECT_THAT(kill(pid, SIGKILL), SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(-pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(SIGKILL, WTERMSIG(status));
+ }
+}
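+
+// The race-free idiom referenced in SetPgid above, spelled out (illustrative):
+//
+//   pid_t pid = fork();
+//   if (pid == 0) {
+//     setpgid(0, 0);    // child: don't rely on the parent having run yet.
+//     // ... child work ...
+//   }
+//   setpgid(pid, pid);  // parent: don't rely on the child having run yet.
+//
+// Whichever side runs second sees a successful no-op, so the group is
+// guaranteed to exist before either side depends on it.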
+
+TEST(KillTest, ProcessGroups) {
+ // Fork a new child.
+ //
+ // other_child is used as a placeholder process. We use this PID as our "does
+ // not exist" process group to ensure some amount of safety. (It is still
+ // possible to violate this assumption, but extremely unlikely.)
+ pid_t child = fork();
+ if (child == 0) {
+ while (true) {
+ pause();
+ }
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+
+ pid_t other_child = fork();
+ if (other_child == 0) {
+ while (true) {
+ pause();
+ }
+ }
+ ASSERT_THAT(other_child, SyscallSucceeds());
+
+ // Ensure the kill does not succeed without the new group.
+ EXPECT_THAT(kill(-child, SIGKILL), SyscallFailsWithErrno(ESRCH));
+
+ // Put the child in its own process group.
+ ASSERT_THAT(setpgid(child, child), SyscallSucceeds());
+
+  // This should not be allowed: a process may only create a new group with
+  // its own id or join an existing one. The other_child group does not exist.
+ ASSERT_THAT(setpgid(child, other_child), SyscallFailsWithErrno(EPERM));
+
+ // Done with other_child; kill it.
+ EXPECT_THAT(kill(other_child, SIGKILL), SyscallSucceeds());
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(other_child, &status, 0), SyscallSucceeds());
+
+ // Linux returns success for the no-op call.
+ ASSERT_THAT(setpgid(child, child), SyscallSucceeds());
+
+ // Kill the child's process group.
+ ASSERT_THAT(kill(-child, SIGKILL), SyscallSucceeds());
+
+ // Wait on the process group; ensure that the signal was as expected.
+ EXPECT_THAT(RetryEINTR(waitpid)(-child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(SIGKILL, WTERMSIG(status));
+
+ // Try to kill the process group again; ensure that the wait fails.
+ EXPECT_THAT(kill(-child, SIGKILL), SyscallFailsWithErrno(ESRCH));
+ EXPECT_THAT(RetryEINTR(waitpid)(-child, &status, 0),
+ SyscallFailsWithErrno(ECHILD));
+}
+
+TEST(KillTest, ChildDropsPrivsCannotKill) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
+
+ // Create the child that drops privileges and tries to kill the parent.
+ pid_t pid = fork();
+ if (pid == 0) {
+ TEST_PCHECK(setresgid(gid, gid, gid) == 0);
+ MaybeSave();
+
+ TEST_PCHECK(setresuid(uid, uid, uid) == 0);
+ MaybeSave();
+
+ // setresuid should have dropped CAP_KILL. Make sure.
+ TEST_CHECK(!HaveCapability(CAP_KILL).ValueOrDie());
+
+ // Try to kill parent with every signal-sending syscall possible.
+ pid_t parent = getppid();
+
+ TEST_CHECK(kill(parent, SIGKILL) < 0);
+ TEST_PCHECK_MSG(errno == EPERM, "kill failed with wrong errno");
+ MaybeSave();
+
+ TEST_CHECK(tgkill(parent, parent, SIGKILL) < 0);
+ TEST_PCHECK_MSG(errno == EPERM, "tgkill failed with wrong errno");
+ MaybeSave();
+
+ TEST_CHECK(syscall(SYS_tkill, parent, SIGKILL) < 0);
+ TEST_PCHECK_MSG(errno == EPERM, "tkill failed with wrong errno");
+ MaybeSave();
+
+ siginfo_t uinfo;
+ uinfo.si_code = -1; // SI_QUEUE (allowed).
+
+ TEST_CHECK(syscall(SYS_rt_sigqueueinfo, parent, SIGKILL, &uinfo) < 0);
+ TEST_PCHECK_MSG(errno == EPERM, "rt_sigqueueinfo failed with wrong errno");
+ MaybeSave();
+
+ TEST_CHECK(syscall(SYS_rt_tgsigqueueinfo, parent, parent, SIGKILL, &uinfo) <
+ 0);
+    TEST_PCHECK_MSG(errno == EPERM, "rt_tgsigqueueinfo failed with wrong errno");
+ MaybeSave();
+
+ _exit(0);
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status = " << status;
+}
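+
+// Note on the si_code chosen above: for rt_sigqueueinfo to another process,
+// Linux rejects si_code values >= 0 (those are reserved for kernel-generated
+// signals); SI_QUEUE (-1) is the conventional user-to-user choice. A hedged
+// sketch of a legitimate use (target_pid is a placeholder):
+//
+//   siginfo_t info = {};
+//   info.si_code = SI_QUEUE;       // must be negative when queueing to others.
+//   info.si_value.sival_int = 42;  // payload visible to the receiver.
+//   syscall(SYS_rt_sigqueueinfo, target_pid, SIGUSR1, &info);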
+
+TEST(KillTest, CanSIGCONTSameSession) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ pid_t stopped_child = fork();
+ if (stopped_child == 0) {
+ raise(SIGSTOP);
+ _exit(0);
+ }
+
+ ASSERT_THAT(stopped_child, SyscallSucceeds());
+
+ // Put the child in its own process group. The child and parent process
+ // groups also share a session.
+ ASSERT_THAT(setpgid(stopped_child, stopped_child), SyscallSucceeds());
+
+ // Make sure child stopped.
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(stopped_child, &status, WUNTRACED),
+ SyscallSucceedsWithValue(stopped_child));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << "status " << status;
+
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
+
+  // Drop privileges only in the child process; otherwise this parent process
+  // won't be able to open some log files after the test ends.
+ pid_t other_child = fork();
+ if (other_child == 0) {
+ // Drop privileges.
+ TEST_PCHECK(setresgid(gid, gid, gid) == 0);
+ MaybeSave();
+
+ TEST_PCHECK(setresuid(uid, uid, uid) == 0);
+ MaybeSave();
+
+ // setresuid should have dropped CAP_KILL.
+ TEST_CHECK(!HaveCapability(CAP_KILL).ValueOrDie());
+
+    // other_child and stopped_child now share neither a thread group nor any
+    // UIDs, and other_child has no privileges. That means sending any signal
+    // other than SIGCONT should fail.
+ TEST_CHECK(kill(stopped_child, SIGKILL) < 0);
+ TEST_PCHECK_MSG(errno == EPERM, "kill failed with wrong errno");
+ MaybeSave();
+
+ TEST_PCHECK(kill(stopped_child, SIGCONT) == 0);
+ MaybeSave();
+
+ _exit(0);
+ }
+
+  ASSERT_THAT(other_child, SyscallSucceeds());
+
+ // Make sure child exited normally.
+ EXPECT_THAT(RetryEINTR(waitpid)(stopped_child, &status, 0),
+ SyscallSucceedsWithValue(stopped_child));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+
+ // Make sure other_child exited normally.
+ EXPECT_THAT(RetryEINTR(waitpid)(other_child, &status, 0),
+ SyscallSucceedsWithValue(other_child));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc
new file mode 100644
index 000000000..544681168
--- /dev/null
+++ b/test/syscalls/linux/link.cc
@@ -0,0 +1,305 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/strings/str_cat.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// IsSameFile returns true if both filenames have the same device and inode.
+bool IsSameFile(const std::string& f1, const std::string& f2) {
+ // Use lstat rather than stat, so that symlinks are not followed.
+ struct stat stat1 = {};
+ EXPECT_THAT(lstat(f1.c_str(), &stat1), SyscallSucceeds());
+ struct stat stat2 = {};
+ EXPECT_THAT(lstat(f2.c_str(), &stat2), SyscallSucceeds());
+
+ return stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino;
+}
+
+TEST(LinkTest, CanCreateLinkFile) {
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string newname = NewTempAbsPath();
+
+ // Get the initial link count.
+ uint64_t initial_link_count =
+ ASSERT_NO_ERRNO_AND_VALUE(Links(oldfile.path()));
+
+ EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), SyscallSucceeds());
+
+ EXPECT_TRUE(IsSameFile(oldfile.path(), newname));
+
+ // Link count should be incremented.
+ EXPECT_THAT(Links(oldfile.path()),
+ IsPosixErrorOkAndHolds(initial_link_count + 1));
+
+ // Delete the link.
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+
+ // Link count should be back to initial.
+ EXPECT_THAT(Links(oldfile.path()),
+ IsPosixErrorOkAndHolds(initial_link_count));
+}
+
+TEST(LinkTest, PermissionDenied) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER)));
+
+ // Make the file "unsafe" to link by making it only readable, but not
+ // writable.
+ const auto unwriteable_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400));
+ const std::string special_path = NewTempAbsPath();
+ ASSERT_THAT(mkfifo(special_path.c_str(), 0666), SyscallSucceeds());
+ const auto setuid_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666 | S_ISUID));
+
+ const std::string newname = NewTempAbsPath();
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test. After calling
+ // setuid(non-zero-UID), there is no way to get root privileges back.
+ ScopedThread([&] {
+ // Use syscall instead of glibc setuid wrapper because we want this setuid
+ // call to only apply to this task. POSIX threads, however, require that all
+ // threads have the same UIDs, so using the setuid wrapper sets all threads'
+ // real UID.
+    // Also drops capabilities (see the sketch after this test).
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(link(unwriteable_file.path().c_str(), newname.c_str()),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(link(special_path.c_str(), newname.c_str()),
+ SyscallFailsWithErrno(EPERM));
+ if (!IsRunningWithVFS1()) {
+ EXPECT_THAT(link(setuid_file.path().c_str(), newname.c_str()),
+ SyscallFailsWithErrno(EPERM));
+ }
+ });
+}
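+
+// The per-task contrast referenced above, as a hedged sketch (uid is a
+// placeholder):
+//
+//   syscall(SYS_setuid, uid);  // raw syscall: changes only the calling task.
+//   setuid(uid);               // glibc wrapper: broadcasts the change to
+//                              // every thread (via an internal signal), as
+//                              // POSIX requires.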
+
+TEST(LinkTest, CannotLinkDirectory) {
+ auto olddir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string newdir = NewTempAbsPath();
+
+ EXPECT_THAT(link(olddir.path().c_str(), newdir.c_str()),
+ SyscallFailsWithErrno(EPERM));
+
+ EXPECT_THAT(rmdir(olddir.path().c_str()), SyscallSucceeds());
+}
+
+TEST(LinkTest, CannotLinkWithSlash) {
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ // Put a final "/" on newname.
+ const std::string newname = absl::StrCat(NewTempAbsPath(), "/");
+
+ EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(LinkTest, OldnameIsEmpty) {
+ const std::string newname = NewTempAbsPath();
+ EXPECT_THAT(link("", newname.c_str()), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(LinkTest, OldnameDoesNotExist) {
+ const std::string oldname = NewTempAbsPath();
+ const std::string newname = NewTempAbsPath();
+ EXPECT_THAT(link("", newname.c_str()), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(LinkTest, NewnameCannotExist) {
+ const std::string newname =
+ JoinPath(GetAbsoluteTestTmpdir(), "thisdoesnotexist", "foo");
+ EXPECT_THAT(link("/thisdoesnotmatter", newname.c_str()),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(LinkTest, WithOldDirFD) {
+ const std::string oldname_parent = NewTempAbsPath();
+ const std::string oldname_base = "child";
+ const std::string oldname = JoinPath(oldname_parent, oldname_base);
+ const std::string newname = NewTempAbsPath();
+
+ // Create oldname_parent directory, and get an FD.
+ ASSERT_THAT(mkdir(oldname_parent.c_str(), 0777), SyscallSucceeds());
+ const FileDescriptor oldname_parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(oldname_parent, O_DIRECTORY | O_RDONLY));
+
+ // Create oldname file.
+ const FileDescriptor oldname_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(oldname, O_CREAT | O_RDWR, 0666));
+
+ // Link oldname to newname, using oldname_parent_fd.
+ EXPECT_THAT(linkat(oldname_parent_fd.get(), oldname_base.c_str(), AT_FDCWD,
+ newname.c_str(), 0),
+ SyscallSucceeds());
+
+ EXPECT_TRUE(IsSameFile(oldname, newname));
+
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(oldname.c_str()), SyscallSucceeds());
+ EXPECT_THAT(rmdir(oldname_parent.c_str()), SyscallSucceeds());
+}
+
+TEST(LinkTest, BogusFlags) {
+ ASSERT_THAT(linkat(1, "foo", 2, "bar", 3), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(LinkTest, WithNewDirFD) {
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string newname_parent = NewTempAbsPath();
+ const std::string newname_base = "child";
+ const std::string newname = JoinPath(newname_parent, newname_base);
+
+ // Create newname_parent directory, and get an FD.
+ EXPECT_THAT(mkdir(newname_parent.c_str(), 0777), SyscallSucceeds());
+ const FileDescriptor newname_parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(newname_parent, O_DIRECTORY | O_RDONLY));
+
+ // Link newname to oldfile, using newname_parent_fd.
+ EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), newname_parent_fd.get(),
+ newname.c_str(), 0),
+ SyscallSucceeds());
+
+ EXPECT_TRUE(IsSameFile(oldfile.path(), newname));
+
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+ EXPECT_THAT(rmdir(newname_parent.c_str()), SyscallSucceeds());
+}
+
+TEST(LinkTest, RelPathsWithNonDirFDs) {
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Create a file that will be passed as the directory fd for old/new names.
+ const std::string filename = NewTempAbsPath();
+ const FileDescriptor file_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT | O_RDWR, 0666));
+
+ // Using file_fd as olddirfd will fail.
+ EXPECT_THAT(linkat(file_fd.get(), "foo", AT_FDCWD, "bar", 0),
+ SyscallFailsWithErrno(ENOTDIR));
+
+ // Using file_fd as newdirfd will fail.
+ EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), file_fd.get(), "bar", 0),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(LinkTest, AbsPathsWithNonDirFDs) {
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string newname = NewTempAbsPath();
+
+ // Create a file that will be passed as the directory fd for old/new names.
+ const std::string filename = NewTempAbsPath();
+ const FileDescriptor file_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT | O_RDWR, 0666));
+
+ // Using file_fd as the dirfds is OK as long as paths are absolute.
+ EXPECT_THAT(linkat(file_fd.get(), oldfile.path().c_str(), file_fd.get(),
+ newname.c_str(), 0),
+ SyscallSucceeds());
+}
+
+TEST(LinkTest, LinkDoesNotFollowSymlinks) {
+ // Create oldfile, and oldsymlink which points to it.
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string oldsymlink = NewTempAbsPath();
+ EXPECT_THAT(symlink(oldfile.path().c_str(), oldsymlink.c_str()),
+ SyscallSucceeds());
+
+ // Now hard link newname to oldsymlink.
+ const std::string newname = NewTempAbsPath();
+ EXPECT_THAT(link(oldsymlink.c_str(), newname.c_str()), SyscallSucceeds());
+
+ // The link should not have resolved the symlink, so newname and oldsymlink
+ // are the same.
+ EXPECT_TRUE(IsSameFile(oldsymlink, newname));
+ EXPECT_FALSE(IsSameFile(oldfile.path(), newname));
+
+ EXPECT_THAT(unlink(oldsymlink.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+}
+
+TEST(LinkTest, LinkatDoesNotFollowSymlinkByDefault) {
+ // Create oldfile, and oldsymlink which points to it.
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string oldsymlink = NewTempAbsPath();
+ EXPECT_THAT(symlink(oldfile.path().c_str(), oldsymlink.c_str()),
+ SyscallSucceeds());
+
+ // Now hard link newname to oldsymlink.
+ const std::string newname = NewTempAbsPath();
+ EXPECT_THAT(
+ linkat(AT_FDCWD, oldsymlink.c_str(), AT_FDCWD, newname.c_str(), 0),
+ SyscallSucceeds());
+
+ // The link should not have resolved the symlink, so newname and oldsymlink
+ // are the same.
+ EXPECT_TRUE(IsSameFile(oldsymlink, newname));
+ EXPECT_FALSE(IsSameFile(oldfile.path(), newname));
+
+ EXPECT_THAT(unlink(oldsymlink.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+}
+
+TEST(LinkTest, LinkatWithSymlinkFollow) {
+ // Create oldfile, and oldsymlink which points to it.
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string oldsymlink = NewTempAbsPath();
+ ASSERT_THAT(symlink(oldfile.path().c_str(), oldsymlink.c_str()),
+ SyscallSucceeds());
+
+ // Now hard link newname to oldsymlink, and pass AT_SYMLINK_FOLLOW flag.
+ const std::string newname = NewTempAbsPath();
+ ASSERT_THAT(linkat(AT_FDCWD, oldsymlink.c_str(), AT_FDCWD, newname.c_str(),
+ AT_SYMLINK_FOLLOW),
+ SyscallSucceeds());
+
+ // The link should have resolved the symlink, so oldfile and newname are the
+ // same.
+ EXPECT_TRUE(IsSameFile(oldfile.path(), newname));
+ EXPECT_FALSE(IsSameFile(oldsymlink, newname));
+
+ EXPECT_THAT(unlink(oldsymlink.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc
new file mode 100644
index 000000000..6ce1e6cc3
--- /dev/null
+++ b/test/syscalls/linux/lseek.cc
@@ -0,0 +1,202 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(LseekTest, InvalidWhence) {
+ const std::string kFileData = "hello world\n";
+ const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644));
+
+ ASSERT_THAT(lseek(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(LseekTest, NegativeOffset) {
+ const std::string kFileData = "hello world\n";
+ const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644));
+
+ EXPECT_THAT(lseek(fd.get(), -(kFileData.length() + 1), SEEK_CUR),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// A 32-bit off_t is not large enough to represent an offset larger than
+// maximum file size on standard file systems, so it isn't possible to cause
+// overflow.
+#if defined(__x86_64__) || defined(__aarch64__)
+TEST(LseekTest, Overflow) {
+  // HA! Classic Linux. We really should get an EOVERFLOW here, since we're
+  // seeking to an offset that cannot be represented... but instead we are
+  // given EINVAL.
+ const std::string kFileData = "hello world\n";
+ const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644));
+ EXPECT_THAT(lseek(fd.get(), 0x7fffffffffffffff, SEEK_END),
+ SyscallFailsWithErrno(EINVAL));
+}
+#endif
+
+TEST(LseekTest, Set) {
+ const std::string kFileData = "hello world\n";
+ const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644));
+
+ char buf = '\0';
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, kFileData.c_str()[0]);
+ EXPECT_THAT(lseek(fd.get(), 6, SEEK_SET), SyscallSucceedsWithValue(6));
+ ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, kFileData.c_str()[6]);
+}
+
+TEST(LseekTest, Cur) {
+ const std::string kFileData = "hello world\n";
+ const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644));
+
+ char buf = '\0';
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, kFileData.c_str()[0]);
+ EXPECT_THAT(lseek(fd.get(), 3, SEEK_CUR), SyscallSucceedsWithValue(4));
+ ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, kFileData.c_str()[4]);
+}
+
+TEST(LseekTest, End) {
+ const std::string kFileData = "hello world\n";
+ const TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kFileData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDWR, 0644));
+
+ char buf = '\0';
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, kFileData.c_str()[0]);
+ EXPECT_THAT(lseek(fd.get(), -2, SEEK_END), SyscallSucceedsWithValue(10));
+ ASSERT_THAT(read(fd.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, kFileData.c_str()[kFileData.length() - 2]);
+}
+
+TEST(LseekTest, InvalidFD) {
+ EXPECT_THAT(lseek(-1, 0, SEEK_SET), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(LseekTest, DirCurEnd) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/tmp", O_RDONLY));
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+}
+
+TEST(LseekTest, ProcDir) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self", O_RDONLY));
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds());
+}
+
+TEST(LseekTest, ProcFile) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/meminfo", O_RDONLY));
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(LseekTest, SysDir) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/sys/devices", O_RDONLY));
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_END), SyscallSucceeds());
+}
+
+TEST(LseekTest, SeekCurrentDir) {
+ // From include/linux/fs.h.
+ constexpr loff_t MAX_LFS_FILESIZE = 0x7fffffffffffffff;
+
+ char* dir = get_current_dir_name();
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir, O_RDONLY));
+
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_END),
+ // Some filesystems (like ext4) allow lseek(SEEK_END) on a
+ // directory and return MAX_LFS_FILESIZE, others return EINVAL.
+ AnyOf(SyscallSucceedsWithValue(MAX_LFS_FILESIZE),
+ SyscallFailsWithErrno(EINVAL)));
+ free(dir);
+}
+
+TEST(LseekTest, ProcStatTwice) {
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/stat", O_RDONLY));
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/stat", O_RDONLY));
+
+ ASSERT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(fd1.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(lseek(fd1.get(), 1000, SEEK_CUR), SyscallSucceeds());
+ // Check that just because we moved fd1, fd2 doesn't move.
+ ASSERT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd3 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/stat", O_RDONLY));
+ ASSERT_THAT(lseek(fd3.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+}
+
+TEST(LseekTest, EtcPasswdDup) {
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/etc/passwd", O_RDONLY));
+ const FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(fd1.Dup());
+
+ ASSERT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(fd1.get(), 1000, SEEK_CUR), SyscallSucceeds());
+ // Check that just because we moved fd1, fd2 doesn't move.
+ ASSERT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(1000));
+
+ const FileDescriptor fd3 = ASSERT_NO_ERRNO_AND_VALUE(fd1.Dup());
+ ASSERT_THAT(lseek(fd3.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(1000));
+}
+
+// TODO(magi): Add tests where we have donated in sockets.
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc
new file mode 100644
index 000000000..5a1973f60
--- /dev/null
+++ b/test/syscalls/linux/madvise.cc
@@ -0,0 +1,251 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void ExpectAllMappingBytes(Mapping const& m, char c) {
+ auto const v = m.view();
+ for (size_t i = 0; i < kPageSize; i++) {
+ ASSERT_EQ(v[i], c) << "at offset " << i;
+ }
+}
+
+// Equivalent to ExpectAllMappingBytes but async-signal-safe and with less
+// helpful failure messages.
+void CheckAllMappingBytes(Mapping const& m, char c) {
+ auto const v = m.view();
+ for (size_t i = 0; i < kPageSize; i++) {
+ TEST_CHECK_MSG(v[i] == c, "mapping contains wrong value");
+ }
+}
+
+TEST(MadviseDontneedTest, ZerosPrivateAnonPage) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ ExpectAllMappingBytes(m, 0);
+ memset(m.ptr(), 1, m.len());
+ ExpectAllMappingBytes(m, 1);
+ ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds());
+ ExpectAllMappingBytes(m, 0);
+}
+
+TEST(MadviseDontneedTest, ZerosCOWAnonPageInCallerOnly) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ ExpectAllMappingBytes(m, 0);
+ memset(m.ptr(), 2, m.len());
+ ExpectAllMappingBytes(m, 2);
+
+ // Do madvise in a child process.
+ pid_t pid = fork();
+ CheckAllMappingBytes(m, 2);
+ if (pid == 0) {
+ TEST_PCHECK(madvise(m.ptr(), m.len(), MADV_DONTNEED) == 0);
+ CheckAllMappingBytes(m, 0);
+ _exit(0);
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ int status = 0;
+ ASSERT_THAT(waitpid(-1, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(WEXITSTATUS(status), 0);
+ // The child's madvise should not have affected the parent.
+ ExpectAllMappingBytes(m, 2);
+}
+
+TEST(MadviseDontneedTest, DoesNotModifySharedAnonPage) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED));
+ ExpectAllMappingBytes(m, 0);
+ memset(m.ptr(), 3, m.len());
+ ExpectAllMappingBytes(m, 3);
+ ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds());
+ ExpectAllMappingBytes(m, 3);
+}
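+
+// The private/shared asymmetry above follows from what backs each mapping:
+// private anonymous pages are simply discarded and refault as zero pages,
+// while shared anonymous pages live in shmem and so survive MADV_DONTNEED.
+// A minimal illustrative sketch of the refault behavior:
+//
+//   char* p = static_cast<char*>(mmap(nullptr, kPageSize,
+//       PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
+//   p[0] = 1;
+//   madvise(p, kPageSize, MADV_DONTNEED);
+//   TEST_CHECK(p[0] == 0);  // refaults as a fresh zero page.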
+
+TEST(MadviseDontneedTest, CleansPrivateFilePage) {
+ TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ /* parent = */ GetAbsoluteTestTmpdir(),
+ /* content = */ std::string(kPageSize, 4), TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
+
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd.get(), 0));
+ ExpectAllMappingBytes(m, 4);
+ memset(m.ptr(), 5, m.len());
+ ExpectAllMappingBytes(m, 5);
+ ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds());
+ ExpectAllMappingBytes(m, 4);
+}
+
+TEST(MadviseDontneedTest, DoesNotModifySharedFilePage) {
+ TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ /* parent = */ GetAbsoluteTestTmpdir(),
+ /* content = */ std::string(kPageSize, 6), TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
+
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0));
+ ExpectAllMappingBytes(m, 6);
+ memset(m.ptr(), 7, m.len());
+ ExpectAllMappingBytes(m, 7);
+ ASSERT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds());
+ ExpectAllMappingBytes(m, 7);
+}
+
+TEST(MadviseDontneedTest, IgnoresPermissions) {
+ auto m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE));
+ EXPECT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds());
+}
+
+TEST(MadviseDontforkTest, AddressLength) {
+ auto m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE));
+ char* addr = static_cast<char*>(m.ptr());
+
+ // Address must be page aligned.
+ EXPECT_THAT(madvise(addr + 1, kPageSize, MADV_DONTFORK),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Zero length madvise always succeeds.
+ EXPECT_THAT(madvise(addr, 0, MADV_DONTFORK), SyscallSucceeds());
+
+ // Length must not roll over after rounding up.
+ size_t badlen = std::numeric_limits<std::size_t>::max() - (kPageSize / 2);
+ EXPECT_THAT(madvise(0, badlen, MADV_DONTFORK), SyscallFailsWithErrno(EINVAL));
+
+ // Length need not be page aligned - it is implicitly rounded up.
+ EXPECT_THAT(madvise(addr, 1, MADV_DONTFORK), SyscallSucceeds());
+ EXPECT_THAT(madvise(addr, kPageSize, MADV_DONTFORK), SyscallSucceeds());
+}
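+
+// A typical use of MADV_DONTFORK is keeping a buffer (e.g., one registered
+// for DMA) out of children entirely; a minimal illustrative sketch:
+//
+//   void* buf = mmap(nullptr, len, PROT_READ | PROT_WRITE,
+//                    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+//   madvise(buf, len, MADV_DONTFORK);
+//   pid_t pid = fork();  // the child's address space has no mapping at buf.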
+
+TEST(MadviseDontforkTest, DontforkShared) {
+ // Mmap two shared file-backed pages and MADV_DONTFORK the second page.
+ TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ /* parent = */ GetAbsoluteTestTmpdir(),
+ /* content = */ std::string(kPageSize * 2, 2),
+ TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
+
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0));
+
+ const Mapping ms1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize);
+ const Mapping ms2 =
+ Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize);
+ m.release();
+
+ ASSERT_THAT(madvise(ms2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds());
+
+ const auto rest = [&] {
+ // First page is mapped in child and modifications are visible to parent
+ // via the shared mapping.
+ TEST_CHECK(IsMapped(ms1.addr()));
+ ExpectAllMappingBytes(ms1, 2);
+ memset(ms1.ptr(), 1, kPageSize);
+ ExpectAllMappingBytes(ms1, 1);
+
+ // Second page must not be mapped in child.
+ TEST_CHECK(!IsMapped(ms2.addr()));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+
+ ExpectAllMappingBytes(ms1, 1); // page contents modified by child.
+ ExpectAllMappingBytes(ms2, 2); // page contents unchanged.
+}
+
+TEST(MadviseDontforkTest, DontforkAnonPrivate) {
+ // Mmap three anonymous pages and MADV_DONTFORK the middle page.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize * 3, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ const Mapping mp1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize);
+ const Mapping mp2 =
+ Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize);
+ const Mapping mp3 =
+ Mapping(reinterpret_cast<void*>(m.addr() + 2 * kPageSize), kPageSize);
+ m.release();
+
+ ASSERT_THAT(madvise(mp2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds());
+
+ // Verify that all pages are zeroed and memset the first, second and third
+ // pages to 1, 2, and 3 respectively.
+ ExpectAllMappingBytes(mp1, 0);
+ memset(mp1.ptr(), 1, kPageSize);
+
+ ExpectAllMappingBytes(mp2, 0);
+ memset(mp2.ptr(), 2, kPageSize);
+
+ ExpectAllMappingBytes(mp3, 0);
+ memset(mp3.ptr(), 3, kPageSize);
+
+ const auto rest = [&] {
+ // Verify first page is mapped, verify its contents and then modify the
+ // page. The mapping is private so the modifications are not visible to
+ // the parent.
+ TEST_CHECK(IsMapped(mp1.addr()));
+ ExpectAllMappingBytes(mp1, 1);
+ memset(mp1.ptr(), 11, kPageSize);
+ ExpectAllMappingBytes(mp1, 11);
+
+ // Verify second page is not mapped.
+ TEST_CHECK(!IsMapped(mp2.addr()));
+
+ // Verify third page is mapped, verify its contents and then modify the
+ // page. The mapping is private so the modifications are not visible to
+ // the parent.
+ TEST_CHECK(IsMapped(mp3.addr()));
+ ExpectAllMappingBytes(mp3, 3);
+ memset(mp3.ptr(), 13, kPageSize);
+ ExpectAllMappingBytes(mp3, 13);
+ };
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+
+ // The fork and COW by child should not affect the parent mappings.
+ ExpectAllMappingBytes(mp1, 1);
+ ExpectAllMappingBytes(mp2, 2);
+ ExpectAllMappingBytes(mp3, 3);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc
new file mode 100644
index 000000000..f8b7f7938
--- /dev/null
+++ b/test/syscalls/linux/memfd.cc
@@ -0,0 +1,557 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/magic.h>
+#include <linux/memfd.h>
+#include <linux/unistd.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/statfs.h>
+#include <sys/syscall.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+// The header sys/memfd.h isn't available on all systems, so we redefine some
+// of the constants here.
+#define F_LINUX_SPECIFIC_BASE 1024
+
+#ifndef F_ADD_SEALS
+#define F_ADD_SEALS (F_LINUX_SPECIFIC_BASE + 9)
+#endif /* F_ADD_SEALS */
+
+#ifndef F_GET_SEALS
+#define F_GET_SEALS (F_LINUX_SPECIFIC_BASE + 10)
+#endif /* F_GET_SEALS */
+
+#define F_SEAL_SEAL 0x0001
+#define F_SEAL_SHRINK 0x0002
+#define F_SEAL_GROW 0x0004
+#define F_SEAL_WRITE 0x0008
+
+using ::testing::StartsWith;
+
+const std::string kMemfdName = "some-memfd";
+
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+PosixErrorOr<FileDescriptor> MemfdCreate(const std::string& name,
+ uint32_t flags) {
+ int fd = memfd_create(name, flags);
+ if (fd < 0) {
+ return PosixError(
+ errno, absl::StrFormat("memfd_create(\"%s\", %#x)", name, flags));
+ }
+ MaybeSave();
+ return FileDescriptor(fd);
+}
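+
+// Typical usage of the helper, as exercised throughout the tests below (a
+// hedged example; the name and seal are arbitrary):
+//
+//   const FileDescriptor memfd = ASSERT_NO_ERRNO_AND_VALUE(
+//       MemfdCreate("example", MFD_ALLOW_SEALING));
+//   ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+//   ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_SHRINK),
+//               SyscallSucceeds());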
+
+// Procfs entries for memfds display the appropriate name.
+TEST(MemfdTest, Name) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0));
+ const std::string proc_name = ASSERT_NO_ERRNO_AND_VALUE(
+ ReadLink(absl::StrFormat("/proc/self/fd/%d", memfd.get())));
+ EXPECT_THAT(proc_name, StartsWith("/memfd:" + kMemfdName));
+}
+
+// Memfds support read/write syscalls.
+TEST(MemfdTest, WriteRead) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0));
+
+ // Write a random page of data to the memfd via write(2).
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Read back the same data and verify.
+ std::vector<char> buf2(kPageSize);
+ ASSERT_THAT(lseek(memfd.get(), 0, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(read(memfd.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(buf, buf2);
+}
+
+// Memfds can be mapped and used as usual.
+TEST(MemfdTest, Mmap) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0));
+ const Mapping m1 = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0));
+
+ // Write a random page of data to the memfd via mmap m1.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ memcpy(m1.ptr(), buf.data(), buf.size());
+
+ // Read the data back via a read syscall on the memfd.
+ std::vector<char> buf2(kPageSize);
+ EXPECT_THAT(read(memfd.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(buf, buf2);
+
+ // The same data should be accessible via a new mapping m2.
+ const Mapping m2 = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0));
+ EXPECT_EQ(0, memcmp(m1.ptr(), m2.ptr(), kPageSize));
+}
+
+TEST(MemfdTest, DuplicateFDsShareContent) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0));
+ const Mapping m1 = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0));
+ const FileDescriptor memfd2 = ASSERT_NO_ERRNO_AND_VALUE(memfd.Dup());
+
+ // Write a random page of data to the memfd via mmap m1.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ memcpy(m1.ptr(), buf.data(), buf.size());
+
+ // Read the data back via a read syscall on a duplicate fd.
+ std::vector<char> buf2(kPageSize);
+ EXPECT_THAT(read(memfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(buf, buf2);
+}
+
+// File seals are disabled by default on memfds.
+TEST(MemfdTest, SealingDisabledByDefault) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0));
+ EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_SEAL));
+ // Attempting to set any seal should fail.
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE),
+ SyscallFailsWithErrno(EPERM));
+}
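+
+// That is, a memfd created without MFD_ALLOW_SEALING starts with F_SEAL_SEAL
+// already set, freezing its (otherwise empty) seal set. Checking for it is a
+// simple bitmask test (illustrative):
+//
+//   int seals = fcntl(fd, F_GET_SEALS);
+//   if (seals & F_SEAL_SEAL) {
+//     // The seal set is frozen; any F_ADD_SEALS will fail with EPERM.
+//   }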
+
+// Seals can be retrieved and updated for memfds.
+TEST(MemfdTest, SealsGetSet) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ int seals;
+ ASSERT_THAT(seals = fcntl(memfd.get(), F_GET_SEALS), SyscallSucceeds());
+ // No seals are set yet.
+ EXPECT_EQ(0, seals);
+
+ // Set a seal and check that we can get it back.
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds());
+ EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_WRITE));
+
+ // Set some more seals and verify.
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW | F_SEAL_SHRINK),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ fcntl(memfd.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_WRITE | F_SEAL_GROW | F_SEAL_SHRINK));
+
+ // Attempting to set a seal that is already set is a no-op.
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds());
+ EXPECT_THAT(
+ fcntl(memfd.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_WRITE | F_SEAL_GROW | F_SEAL_SHRINK));
+
+ // Add remaining seals and verify.
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_SEAL), SyscallSucceeds());
+ EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_WRITE | F_SEAL_GROW |
+ F_SEAL_SHRINK | F_SEAL_SEAL));
+}
+
+// F_SEAL_GROW prevents a memfd from being grown using ftruncate.
+TEST(MemfdTest, SealGrowWithTruncate) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds());
+
+  // Try to grow the memfd by 1 page.
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize * 2),
+ SyscallFailsWithErrno(EPERM));
+
+ // Ftruncate calls that don't actually grow the memfd are allowed.
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize / 2), SyscallSucceeds());
+
+ // After shrinking, growing back is not allowed.
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallFailsWithErrno(EPERM));
+}
+
+// F_SEAL_GROW prevents a memfd from being grown using the write syscall.
+TEST(MemfdTest, SealGrowWithWrite) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+
+ // Initially, writing to the memfd succeeds.
+ const std::vector<char> buf(kPageSize);
+ EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+  // Apply F_SEAL_GROW; subsequent writes that extend the memfd should fail.
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds());
+ EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EPERM));
+
+ // However, zero-length writes are ok since they don't grow the memfd.
+ EXPECT_THAT(write(memfd.get(), buf.data(), 0), SyscallSucceeds());
+
+ // Writing to existing parts of the memfd is also ok.
+ ASSERT_THAT(lseek(memfd.get(), 0, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+  // Back at the end of the file, writes that would grow it still fail.
+ EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EPERM));
+}
+
+// F_SEAL_GROW causes writes which partially extend off the current EOF to
+// partially succeed, up to the page containing the EOF.
+TEST(MemfdTest, SealGrowPartialWriteTruncated) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds());
+
+  // FD offset: 0, EOF: 1 page.
+
+ ASSERT_THAT(lseek(memfd.get(), kPageSize * 3 / 4, SEEK_SET),
+ SyscallSucceeds());
+
+  // FD offset: 3/4 page. Writing a full page now should only write 1/4 page
+  // worth of data. The write partially succeeds because its first 1/4 page
+  // lies entirely within the file and requires no growth, but writing the
+  // remaining 3/4 page would require growing the file.
+ const std::vector<char> buf(kPageSize);
+ EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize / 4));
+}
+
+// F_SEAL_GROW causes writes which partially extend off the current EOF to fail
+// in its entirety if the only data written would be to the page containing the
+// EOF.
+TEST(MemfdTest, SealGrowPartialWriteTruncatedSamePage) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize * 3 / 4), SyscallSucceeds());
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds());
+
+ // EOF: 3/4 page, writing 1/2 page starting at 1/2 page would cause the file
+ // to grow. Since this would require only the page containing the EOF to be
+ // modified, the write is rejected entirely.
+ const std::vector<char> buf(kPageSize / 2);
+ EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size(), kPageSize / 2),
+ SyscallFailsWithErrno(EPERM));
+
+ // However, writing up to EOF is fine.
+ EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size() / 2, kPageSize / 2),
+ SyscallSucceedsWithValue(kPageSize / 4));
+}
+
+// F_SEAL_SHRINK prevents a memfd from being shrunk using ftruncate.
+TEST(MemfdTest, SealShrink) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_SHRINK),
+ SyscallSucceeds());
+
+ // Shrink by half a page.
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize / 2),
+ SyscallFailsWithErrno(EPERM));
+
+ // Ftruncate calls that don't actually shrink the file are allowed.
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallSucceeds());
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize * 2), SyscallSucceeds());
+
+ // After growing, shrinking is still not allowed.
+ ASSERT_THAT(ftruncate(memfd.get(), kPageSize), SyscallFailsWithErrno(EPERM));
+}
+
+// F_SEAL_WRITE prevents a memfd from being written to through a write
+// syscall.
+TEST(MemfdTest, SealWriteWithWrite) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ const std::vector<char> buf(kPageSize);
+ ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds());
+
+  // Attempting to write at the end of the file fails.
+ EXPECT_THAT(write(memfd.get(), buf.data(), 1), SyscallFailsWithErrno(EPERM));
+
+  // Attempting to overwrite an existing part of the memfd fails.
+ EXPECT_THAT(pwrite(memfd.get(), buf.data(), 1, 0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size() / 2, kPageSize / 2),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(pwrite(memfd.get(), buf.data(), buf.size(), kPageSize / 2),
+ SyscallFailsWithErrno(EPERM));
+
+ // Zero-length writes however do not fail.
+ EXPECT_THAT(write(memfd.get(), buf.data(), 0), SyscallSucceeds());
+}
+
+// F_SEAL_WRITE prevents a memfd from being written to through an mmap.
+TEST(MemfdTest, SealWriteWithMmap) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ const std::vector<char> buf(kPageSize);
+ ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds());
+
+ // Can't create a shared mapping with writes sealed.
+ void* ret = mmap(nullptr, kPageSize, PROT_WRITE, MAP_SHARED, memfd.get(), 0);
+ EXPECT_EQ(ret, MAP_FAILED);
+ EXPECT_EQ(errno, EPERM);
+ ret = mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, memfd.get(), 0);
+ EXPECT_EQ(ret, MAP_FAILED);
+ EXPECT_EQ(errno, EPERM);
+
+ // However, private mappings are ok.
+ EXPECT_NO_ERRNO(Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ memfd.get(), 0));
+}
+
+// Adding F_SEAL_WRITE fails when there are outstanding writable mappings to a
+// memfd.
+TEST(MemfdTest, SealWriteWithOutstandingWritableMapping) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ const std::vector<char> buf(kPageSize);
+ ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Attempting to add F_SEAL_WRITE with active shared mapping with any set of
+ // permissions fails.
+
+ // Read-only shared mapping.
+ {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, memfd.get(), 0));
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE),
+ SyscallFailsWithErrno(EBUSY));
+ }
+
+ // Write-only shared mapping.
+ {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_WRITE, MAP_SHARED, memfd.get(), 0));
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE),
+ SyscallFailsWithErrno(EBUSY));
+ }
+
+ // Read-write shared mapping.
+ {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ memfd.get(), 0));
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE),
+ SyscallFailsWithErrno(EBUSY));
+ }
+
+ // F_SEAL_WRITE can be set with private mappings with any permissions.
+ {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ memfd.get(), 0));
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE),
+ SyscallSucceeds());
+ }
+}
+
+// When applying F_SEAL_WRITE fails due to outstanding writable mappings, any
+// additional seals passed in the same F_ADD_SEALS call are also rejected.
+TEST(MemfdTest, NoPartialSealApplicationWhenWriteSealRejected) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0));
+
+ // Try adding some seals along with F_SEAL_WRITE. The seal application should
+ // fail since there exists an active shared mapping.
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE | F_SEAL_GROW),
+ SyscallFailsWithErrno(EBUSY));
+
+ // None of the seals should be applied.
+ EXPECT_THAT(fcntl(memfd.get(), F_GET_SEALS), SyscallSucceedsWithValue(0));
+}
+
+// Seals are inode-level properties, and apply to all file descriptors referring
+// to a memfd.
+TEST(MemfdTest, SealsAreInodeLevelProperties) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ const FileDescriptor memfd2 = ASSERT_NO_ERRNO_AND_VALUE(memfd.Dup());
+
+ // Add seal through the original memfd, and verify that it appears on the
+ // dupped fd.
+ ASSERT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds());
+ EXPECT_THAT(fcntl(memfd2.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_WRITE));
+
+ // Verify the seal actually applies to both fds.
+ std::vector<char> buf(kPageSize);
+ EXPECT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(write(memfd2.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EPERM));
+
+ // Seals are enforced on new FDs that are dupped after the seal is already
+ // applied.
+ const FileDescriptor memfd3 = ASSERT_NO_ERRNO_AND_VALUE(memfd2.Dup());
+ EXPECT_THAT(write(memfd3.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EPERM));
+
+ // Try a new seal applied to one of the dupped fds.
+ ASSERT_THAT(fcntl(memfd3.get(), F_ADD_SEALS, F_SEAL_GROW), SyscallSucceeds());
+ EXPECT_THAT(ftruncate(memfd.get(), kPageSize), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(ftruncate(memfd2.get(), kPageSize), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(ftruncate(memfd3.get(), kPageSize), SyscallFailsWithErrno(EPERM));
+}
+
+PosixErrorOr<bool> IsTmpfs(const std::string& path) {
+ struct statfs stat;
+ if (statfs(path.c_str(), &stat)) {
+ if (errno == ENOENT) {
+ // Nothing at path, don't raise this as an error. Instead, just report no
+ // tmpfs at path.
+ return false;
+ }
+ return PosixError(errno,
+ absl::StrFormat("statfs(\"%s\", %#p)", path, &stat));
+ }
+ return stat.f_type == TMPFS_MAGIC;
+}
+
+// Tmpfs files also support seals, but are created with F_SEAL_SEAL.
+TEST(MemfdTest, TmpfsFilesHaveSealSeal) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp")));
+ const TempPath tmpfs_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn("/tmp"));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfs_file.path(), O_RDWR, 0644));
+ EXPECT_THAT(fcntl(fd.get(), F_GET_SEALS),
+ SyscallSucceedsWithValue(F_SEAL_SEAL));
+}
+
+// A memfd can be opened from procfs and used as normal.
+TEST(MemfdTest, CanOpenFromProcfs) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+
+ // Write a random page of data to the memfd via write(2).
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Read back the same data from the fd obtained from procfs and verify.
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(absl::StrFormat("/proc/self/fd/%d", memfd.get()), O_RDWR));
+ std::vector<char> buf2(kPageSize);
+ EXPECT_THAT(pread(fd.get(), buf2.data(), buf2.size(), 0),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(buf, buf2);
+}
+
+// Test that memfd permissions are set up correctly to allow another process to
+// open it from procfs.
+TEST(MemfdTest, OtherProcessCanOpenFromProcfs) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+ const auto memfd_path =
+ absl::StrFormat("/proc/%d/fd/%d", getpid(), memfd.get());
+ const auto rest = [&] {
+ int fd = open(memfd_path.c_str(), O_RDWR);
+ TEST_PCHECK(fd >= 0);
+ TEST_PCHECK(close(fd) >= 0);
+ };
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+// Test that only files opened as writable can have seals applied to them.
+// Normally there's no way to specify file permissions on memfds, but we can
+// obtain a read-only memfd by opening the corresponding procfs fd entry as
+// read-only.
+TEST(MemfdTest, MemfdMustBeWritableToModifySeals) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, MFD_ALLOW_SEALING));
+
+ // Initially adding a seal works.
+ EXPECT_THAT(fcntl(memfd.get(), F_ADD_SEALS, F_SEAL_WRITE), SyscallSucceeds());
+
+ // Re-open the memfd as read-only from procfs.
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(absl::StrFormat("/proc/self/fd/%d", memfd.get()), O_RDONLY));
+
+ // Can't add seals through an unwritable fd.
+ EXPECT_THAT(fcntl(fd.get(), F_ADD_SEALS, F_SEAL_GROW),
+ SyscallFailsWithErrno(EPERM));
+}
+
+// Test that the memfd implementation internally tracks potentially writable
+// maps correctly.
+TEST(MemfdTest, MultipleWritableAndNonWritableRefsToSameFileRegion) {
+ const FileDescriptor memfd =
+ ASSERT_NO_ERRNO_AND_VALUE(MemfdCreate(kMemfdName, 0));
+
+ // Populate with a random page of data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(memfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Read-only map to the page. This should cause an initial mapping to be
+ // created.
+ Mapping m1 = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ, MAP_PRIVATE, memfd.get(), 0));
+
+ // Create a shared writable map to the page. This should cause the internal
+ // mapping to become potentially writable.
+ Mapping m2 = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0));
+
+ // Drop the read-only mapping first. If writability isn't tracked correctly,
+ // this can cause misaccounting, which can trigger internal asserts.
+ m1.reset();
+ m2.reset();
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc
new file mode 100644
index 000000000..94aea4077
--- /dev/null
+++ b/test/syscalls/linux/memory_accounting.cc
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mman.h>
+
+#include <map>
+
+#include "gtest/gtest.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using ::absl::StrFormat;
+
+// AnonUsageFromMeminfo scrapes the current anonymous memory usage from
+// /proc/meminfo and returns it in bytes.
+PosixErrorOr<uint64_t> AnonUsageFromMeminfo() {
+ ASSIGN_OR_RETURN_ERRNO(auto meminfo, GetContents("/proc/meminfo"));
+ std::vector<std::string> lines(absl::StrSplit(meminfo, '\n'));
+
+ // Try to find the AnonPages line; the format is AnonPages:\\s+(\\d+) kB\n.
+ for (const auto& line : lines) {
+ if (!absl::StartsWith(line, "AnonPages:")) {
+ continue;
+ }
+
+ std::vector<std::string> parts(
+ absl::StrSplit(line, ' ', absl::SkipEmpty()));
+ if (parts.size() == 3) {
+ // The size is the second field, let's try to parse it as a number.
+ ASSIGN_OR_RETURN_ERRNO(auto anon_kb, Atoi<uint64_t>(parts[1]));
+ return anon_kb * 1024;
+ }
+
+ return PosixError(EINVAL, "AnonPages field in /proc/meminfo was malformed");
+ }
+
+ return PosixError(EINVAL, "AnonPages field not found in /proc/meminfo");
+}
+
+TEST(MemoryAccounting, AnonAccountingPreservedOnSaveRestore) {
+ // This test isn't meaningful on Linux. /proc/meminfo reports system-wide
+ // memory usage, which on Linux can change arbitrarily due to other activity
+ // on the machine. In gvisor, this test is the only thing running on the
+ // "machine", so values in /proc/meminfo accurately reflect the memory used by
+ // the test.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ uint64_t anon_initial = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo());
+
+ // Cause some anonymous memory usage.
+ uint64_t map_bytes = Megabytes(512);
+ char* mem =
+ static_cast<char*>(mmap(nullptr, map_bytes, PROT_READ | PROT_WRITE,
+ MAP_POPULATE | MAP_ANON | MAP_PRIVATE, -1, 0));
+ ASSERT_NE(mem, MAP_FAILED)
+ << "Map failed, errno: " << errno << " (" << strerror(errno) << ").";
+
+ // Write something to each page to prevent them from being decommitted on
+ // S/R. Zero pages are dropped on save.
+ for (uint64_t i = 0; i < map_bytes; i += kPageSize) {
+ mem[i] = 'a';
+ }
+
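+ // Allow a small (3%) tolerance below, since the test process itself makes
+ // unrelated anonymous allocations that also show up in /proc/meminfo.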
+ uint64_t anon_after_alloc = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo());
+ EXPECT_THAT(anon_after_alloc,
+ EquivalentWithin(anon_initial + map_bytes, 0.03));
+
+ // We have many implicit S/R cycles from scraping /proc/meminfo throughout the
+ // test, but throw an explicit S/R in here as well.
+ MaybeSave();
+
+ // Usage should remain the same across S/R.
+ uint64_t anon_after_sr = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo());
+ EXPECT_THAT(anon_after_sr, EquivalentWithin(anon_after_alloc, 0.03));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mempolicy.cc b/test/syscalls/linux/mempolicy.cc
new file mode 100644
index 000000000..059fad598
--- /dev/null
+++ b/test/syscalls/linux/mempolicy.cc
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "test/util/cleanup.h"
+#include "test/util/memory_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#define BITS_PER_BYTE 8
+
+#define MPOL_F_STATIC_NODES (1 << 15)
+#define MPOL_F_RELATIVE_NODES (1 << 14)
+#define MPOL_DEFAULT 0
+#define MPOL_PREFERRED 1
+#define MPOL_BIND 2
+#define MPOL_INTERLEAVE 3
+#define MPOL_LOCAL 4
+#define MPOL_F_NODE (1 << 0)
+#define MPOL_F_ADDR (1 << 1)
+#define MPOL_F_MEMS_ALLOWED (1 << 2)
+#define MPOL_MF_STRICT (1 << 0)
+#define MPOL_MF_MOVE (1 << 1)
+#define MPOL_MF_MOVE_ALL (1 << 2)
+
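+// The constants above mirror the kernel's uapi definitions in
+// include/uapi/linux/mempolicy.h. The wrappers below invoke the raw syscalls
+// directly, since glibc only exposes these calls through libnuma.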
+int get_mempolicy(int* policy, uint64_t* nmask, uint64_t maxnode, void* addr,
+ int flags) {
+ return syscall(SYS_get_mempolicy, policy, nmask, maxnode, addr, flags);
+}
+
+int set_mempolicy(int mode, uint64_t* nmask, uint64_t maxnode) {
+ return syscall(SYS_set_mempolicy, mode, nmask, maxnode);
+}
+
+int mbind(void* addr, unsigned long len, int mode,
+ const unsigned long* nodemask, unsigned long maxnode,
+ unsigned flags) {
+ return syscall(SYS_mbind, addr, len, mode, nodemask, maxnode, flags);
+}
+
+// Creates a cleanup object that resets the calling thread's mempolicy to the
+// system default when the calling scope ends.
+Cleanup ScopedMempolicy() {
+ return Cleanup([] {
+ EXPECT_THAT(set_mempolicy(MPOL_DEFAULT, nullptr, 0), SyscallSucceeds());
+ });
+}
+
+// Temporarily change the memory policy for the calling thread within the
+// caller's scope.
+PosixErrorOr<Cleanup> ScopedSetMempolicy(int mode, uint64_t* nmask,
+ uint64_t maxnode) {
+ if (set_mempolicy(mode, nmask, maxnode)) {
+ return PosixError(errno, "set_mempolicy");
+ }
+ return ScopedMempolicy();
+}
+
+TEST(MempolicyTest, CheckDefaultPolicy) {
+ int mode = 0;
+ uint64_t nodemask = 0;
+ ASSERT_THAT(get_mempolicy(&mode, &nodemask, sizeof(nodemask) * BITS_PER_BYTE,
+ nullptr, 0),
+ SyscallSucceeds());
+
+ EXPECT_EQ(MPOL_DEFAULT, mode);
+ EXPECT_EQ(0x0, nodemask);
+}
+
+TEST(MempolicyTest, PolicyPreservedAfterSetMempolicy) {
+ uint64_t nodemask = 0x1;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy(
+ MPOL_BIND, &nodemask, sizeof(nodemask) * BITS_PER_BYTE));
+
+ int mode = 0;
+ uint64_t nodemask_after = 0x0;
+ ASSERT_THAT(get_mempolicy(&mode, &nodemask_after,
+ sizeof(nodemask_after) * BITS_PER_BYTE, nullptr, 0),
+ SyscallSucceeds());
+ EXPECT_EQ(MPOL_BIND, mode);
+ EXPECT_EQ(0x1, nodemask_after);
+
+ // Try throwing in some mode flags.
+ for (auto mode_flag : {MPOL_F_STATIC_NODES, MPOL_F_RELATIVE_NODES}) {
+ auto cleanup2 = ASSERT_NO_ERRNO_AND_VALUE(
+ ScopedSetMempolicy(MPOL_INTERLEAVE | mode_flag, &nodemask,
+ sizeof(nodemask) * BITS_PER_BYTE));
+ mode = 0;
+ nodemask_after = 0x0;
+ ASSERT_THAT(
+ get_mempolicy(&mode, &nodemask_after,
+ sizeof(nodemask_after) * BITS_PER_BYTE, nullptr, 0),
+ SyscallSucceeds());
+ EXPECT_EQ(MPOL_INTERLEAVE | mode_flag, mode);
+ EXPECT_EQ(0x1, nodemask_after);
+ }
+}
+
+TEST(MempolicyTest, SetMempolicyRejectsInvalidInputs) {
+ auto cleanup = ScopedMempolicy();
+ uint64_t nodemask;
+
+ if (IsRunningOnGvisor()) {
+ // Invalid nodemask, we only support a single node on gvisor.
+ nodemask = 0x4;
+ ASSERT_THAT(set_mempolicy(MPOL_DEFAULT, &nodemask,
+ sizeof(nodemask) * BITS_PER_BYTE),
+ SyscallFailsWithErrno(EINVAL));
+ }
+
+ nodemask = 0x1;
+
+ // Invalid mode.
+ ASSERT_THAT(set_mempolicy(7439, &nodemask, sizeof(nodemask) * BITS_PER_BYTE),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Invalid nodemask size.
+ ASSERT_THAT(set_mempolicy(MPOL_DEFAULT, &nodemask, 0),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Invalid mode flag.
+ ASSERT_THAT(
+ set_mempolicy(MPOL_DEFAULT | MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES,
+ &nodemask, sizeof(nodemask) * BITS_PER_BYTE),
+ SyscallFailsWithErrno(EINVAL));
+
+ // MPOL_INTERLEAVE with empty nodemask.
+ nodemask = 0x0;
+ ASSERT_THAT(set_mempolicy(MPOL_INTERLEAVE, &nodemask,
+ sizeof(nodemask) * BITS_PER_BYTE),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// The manpages specify that the nodemask provided to set_mempolicy is
+// considered empty if the nodemask pointer is null, or if the nodemask size is
+// 0. We use a policy which accepts both empty and non-empty nodemasks
+// (MPOL_PREFERRED), a policy which requires a non-empty nodemask (MPOL_BIND),
+// and a policy which completely ignores the nodemask (MPOL_DEFAULT) to verify
+// argument checking around nodemasks.
+TEST(MempolicyTest, EmptyNodemaskOnSet) {
+ auto cleanup = ScopedMempolicy();
+
+ EXPECT_THAT(set_mempolicy(MPOL_DEFAULT, nullptr, 1), SyscallSucceeds());
+ EXPECT_THAT(set_mempolicy(MPOL_BIND, nullptr, 1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(set_mempolicy(MPOL_PREFERRED, nullptr, 1), SyscallSucceeds());
+
+ uint64_t nodemask = 0x1;
+ EXPECT_THAT(set_mempolicy(MPOL_DEFAULT, &nodemask, 0),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(set_mempolicy(MPOL_BIND, &nodemask, 0),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(set_mempolicy(MPOL_PREFERRED, &nodemask, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(MempolicyTest, QueryAvailableNodes) {
+ uint64_t nodemask = 0;
+ ASSERT_THAT(
+ get_mempolicy(nullptr, &nodemask, sizeof(nodemask) * BITS_PER_BYTE,
+ nullptr, MPOL_F_MEMS_ALLOWED),
+ SyscallSucceeds());
+ // We can only be sure there is a single node if running on gvisor.
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(0x1, nodemask);
+ }
+
+ // MPOL_F_ADDR and MPOL_F_NODE flags may not be combined with
+ // MPOL_F_MEMS_ALLOWED.
+ for (auto flags :
+ {MPOL_F_MEMS_ALLOWED | MPOL_F_ADDR, MPOL_F_MEMS_ALLOWED | MPOL_F_NODE,
+ MPOL_F_MEMS_ALLOWED | MPOL_F_ADDR | MPOL_F_NODE}) {
+ ASSERT_THAT(get_mempolicy(nullptr, &nodemask,
+ sizeof(nodemask) * BITS_PER_BYTE, nullptr, flags),
+ SyscallFailsWithErrno(EINVAL));
+ }
+}
+
+TEST(MempolicyTest, GetMempolicyQueryNodeForAddress) {
+ uint64_t dummy_stack_address;
+ auto dummy_heap_address = absl::make_unique<uint64_t>();
+ int mode;
+
+ for (auto ptr : {&dummy_stack_address, dummy_heap_address.get()}) {
+ mode = -1;
+ ASSERT_THAT(
+ get_mempolicy(&mode, nullptr, 0, ptr, MPOL_F_ADDR | MPOL_F_NODE),
+ SyscallSucceeds());
+ // If we're not running on gvisor, the address may be allocated on a
+ // different NUMA node.
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(0, mode);
+ }
+ }
+
+ void* invalid_address = reinterpret_cast<void*>(-1);
+
+ // Invalid address.
+ ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, invalid_address,
+ MPOL_F_ADDR | MPOL_F_NODE),
+ SyscallFailsWithErrno(EFAULT));
+
+ // Invalid mode pointer.
+ ASSERT_THAT(get_mempolicy(reinterpret_cast<int*>(invalid_address), nullptr, 0,
+ &dummy_stack_address, MPOL_F_ADDR | MPOL_F_NODE),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST(MempolicyTest, GetMempolicyCanOmitPointers) {
+ int mode;
+ uint64_t nodemask;
+
+ // Omit nodemask pointer.
+ ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, nullptr, 0), SyscallSucceeds());
+ // Omit mode pointer.
+ ASSERT_THAT(get_mempolicy(nullptr, &nodemask,
+ sizeof(nodemask) * BITS_PER_BYTE, nullptr, 0),
+ SyscallSucceeds());
+ // Omit both pointers.
+ ASSERT_THAT(get_mempolicy(nullptr, nullptr, 0, nullptr, 0),
+ SyscallSucceeds());
+}
+
+TEST(MempolicyTest, GetMempolicyNextInterleaveNode) {
+ int mode;
+ // The thread's policy is not yet MPOL_INTERLEAVE, so we can't query for
+ // the next node which will be used for allocation.
+ ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, nullptr, MPOL_F_NODE),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Set default policy for thread to MPOL_INTERLEAVE.
+ uint64_t nodemask = 0x1;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy(
+ MPOL_INTERLEAVE, &nodemask, sizeof(nodemask) * BITS_PER_BYTE));
+
+ mode = -1;
+ ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, nullptr, MPOL_F_NODE),
+ SyscallSucceeds());
+ EXPECT_EQ(0, mode);
+}
+
+TEST(MempolicyTest, Mbind) {
+ // Temporarily set the thread policy to MPOL_PREFERRED.
+ const auto cleanup_thread_policy =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy(MPOL_PREFERRED, nullptr, 0));
+
+ const auto mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS));
+
+ // vmas default to MPOL_DEFAULT irrespective of the thread policy (currently
+ // MPOL_PREFERRED).
+ int mode;
+ ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR),
+ SyscallSucceeds());
+ EXPECT_EQ(mode, MPOL_DEFAULT);
+
+ // Set MPOL_PREFERRED for the vma and read it back.
+ ASSERT_THAT(
+ mbind(mapping.ptr(), mapping.len(), MPOL_PREFERRED, nullptr, 0, 0),
+ SyscallSucceeds());
+ ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR),
+ SyscallSucceeds());
+ EXPECT_EQ(mode, MPOL_PREFERRED);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mincore.cc b/test/syscalls/linux/mincore.cc
new file mode 100644
index 000000000..5c1240c89
--- /dev/null
+++ b/test/syscalls/linux/mincore.cc
@@ -0,0 +1,96 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
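+// Per mincore(2), residency is reported in the least significant bit of each
+// byte in the vector; the remaining bits are reserved.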
+size_t CountSetLSBs(std::vector<unsigned char> const& vec) {
+ return std::count_if(begin(vec), end(vec),
+ [](unsigned char c) { return (c & 1) != 0; });
+}
+
+TEST(MincoreTest, DirtyAnonPagesAreResident) {
+ constexpr size_t kTestPageCount = 10;
+ auto const kTestMappingBytes = kTestPageCount * kPageSize;
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kTestMappingBytes, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ memset(m.ptr(), 0, m.len());
+
+ std::vector<unsigned char> vec(kTestPageCount, 0);
+ ASSERT_THAT(mincore(m.ptr(), kTestMappingBytes, vec.data()),
+ SyscallSucceeds());
+ EXPECT_EQ(kTestPageCount, CountSetLSBs(vec));
+}
+
+TEST(MincoreTest, UnalignedAddressFails) {
+ // Map and touch two pages, then try to mincore the second half of the first
+ // page + the first half of the second page. Both pages are mapped, but
+ // mincore should return EINVAL due to the misaligned start address.
+ constexpr size_t kTestPageCount = 2;
+ auto const kTestMappingBytes = kTestPageCount * kPageSize;
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kTestMappingBytes, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ memset(m.ptr(), 0, m.len());
+
+ std::vector<unsigned char> vec(kTestPageCount, 0);
+ EXPECT_THAT(mincore(reinterpret_cast<void*>(m.addr() + kPageSize / 2),
+ kPageSize, vec.data()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(MincoreTest, UnalignedLengthSucceedsAndIsRoundedUp) {
+ // Map and touch two pages, then try to mincore the first page + the first
+ // half of the second page. mincore should silently round up the length to
+ // include both pages.
+ constexpr size_t kTestPageCount = 2;
+ auto const kTestMappingBytes = kTestPageCount * kPageSize;
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kTestMappingBytes, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ memset(m.ptr(), 0, m.len());
+
+ std::vector<unsigned char> vec(kTestPageCount, 0);
+ ASSERT_THAT(mincore(m.ptr(), kPageSize + kPageSize / 2, vec.data()),
+ SyscallSucceeds());
+ EXPECT_EQ(kTestPageCount, CountSetLSBs(vec));
+}
+
+TEST(MincoreTest, ZeroLengthSucceedsAndAllowsAnyVecBelowTaskSize) {
+ EXPECT_THAT(mincore(nullptr, 0, nullptr), SyscallSucceeds());
+}
+
+TEST(MincoreTest, InvalidLengthFails) {
+ EXPECT_THAT(mincore(nullptr, -1, nullptr), SyscallFailsWithErrno(ENOMEM));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
new file mode 100644
index 000000000..4036a9275
--- /dev/null
+++ b/test/syscalls/linux/mkdir.cc
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class MkdirTest : public ::testing::Test {
+ protected:
+ // SetUp creates various configurations of files.
+ void SetUp() override { dirname_ = NewTempAbsPath(); }
+
+ // TearDown unlinks created files.
+ void TearDown() override {
+ EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds());
+ }
+
+ std::string dirname_;
+};
+
+TEST_F(MkdirTest, CanCreateWritableDir) {
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
+ std::string filename = JoinPath(dirname_, "anything");
+ int fd;
+ ASSERT_THAT(fd = open(filename.c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds());
+}
+
+TEST_F(MkdirTest, HonorsUmask) {
+ constexpr mode_t kMask = 0111;
+ TempUmask mask(kMask);
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
+ struct stat statbuf;
+ ASSERT_THAT(stat(dirname_.c_str(), &statbuf), SyscallSucceeds());
+ EXPECT_EQ(0777 & ~kMask, statbuf.st_mode & 0777);
+}
+
+TEST_F(MkdirTest, HonorsUmask2) {
+ constexpr mode_t kMask = 0142;
+ TempUmask mask(kMask);
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
+ struct stat statbuf;
+ ASSERT_THAT(stat(dirname_.c_str(), &statbuf), SyscallSucceeds());
+ EXPECT_EQ(0777 & ~kMask, statbuf.st_mode & 0777);
+}
+
+TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds());
+ auto dir = JoinPath(dirname_.c_str(), "foo");
+ EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EACCES));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc
new file mode 100644
index 000000000..05dfb375a
--- /dev/null
+++ b/test/syscalls/linux/mknod.cc
@@ -0,0 +1,190 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(MknodTest, RegularFile) {
+ const std::string node0 = NewTempAbsPath();
+ EXPECT_THAT(mknod(node0.c_str(), S_IFREG, 0), SyscallSucceeds());
+
+ const std::string node1 = NewTempAbsPath();
+ EXPECT_THAT(mknod(node1.c_str(), 0, 0), SyscallSucceeds());
+}
+
+TEST(MknodTest, RegularFilePermissions) {
+ const std::string node = NewTempAbsPath();
+ mode_t newUmask = 0077;
+ umask(newUmask);
+
+ // Attempt to create a file with mode 0777. Not specifying a file type
+ // should create a regular file.
+ mode_t perms = S_IRWXU | S_IRWXG | S_IRWXO;
+ EXPECT_THAT(mknod(node.c_str(), perms, 0), SyscallSucceeds());
+
+ // In the absence of a default ACL, the permissions of the created node are
+ // (mode & ~umask). -- mknod(2)
+ mode_t wantPerms = perms & ~newUmask;
+ struct stat st;
+ ASSERT_THAT(stat(node.c_str(), &st), SyscallSucceeds());
+ ASSERT_EQ(st.st_mode & 0777, wantPerms);
+
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ ASSERT_EQ(st.st_mode & S_IFMT, S_IFREG);
+}
+
+TEST(MknodTest, MknodAtFIFO) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string fifo_relpath = NewTempRelPath();
+ const std::string fifo = JoinPath(dir.path(), fifo_relpath);
+
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_RDONLY));
+ ASSERT_THAT(mknodat(dirfd.get(), fifo_relpath.c_str(), S_IFIFO | S_IRUSR, 0),
+ SyscallSucceeds());
+
+ struct stat st;
+ ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds());
+ EXPECT_TRUE(S_ISFIFO(st.st_mode));
+}
+
+TEST(MknodTest, MknodOnExistingPathFails) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const TempPath slink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path()));
+
+ EXPECT_THAT(mknod(file.path().c_str(), S_IFREG, 0),
+ SyscallFailsWithErrno(EEXIST));
+ EXPECT_THAT(mknod(file.path().c_str(), S_IFIFO, 0),
+ SyscallFailsWithErrno(EEXIST));
+ EXPECT_THAT(mknod(slink.path().c_str(), S_IFREG, 0),
+ SyscallFailsWithErrno(EEXIST));
+ EXPECT_THAT(mknod(slink.path().c_str(), S_IFIFO, 0),
+ SyscallFailsWithErrno(EEXIST));
+}
+
+TEST(MknodTest, UnimplementedTypesReturnError) {
+ const std::string path = NewTempAbsPath();
+
+ if (IsRunningWithVFS1()) {
+ ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ }
+ // These will fail on Linux as well, since we don't have CAP_MKNOD.
+ ASSERT_THAT(mknod(path.c_str(), S_IFCHR, 0), SyscallFailsWithErrno(EPERM));
+ ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM));
+}
+
+TEST(MknodTest, Fifo) {
+ const std::string fifo = NewTempAbsPath();
+ ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0),
+ SyscallSucceeds());
+
+ struct stat st;
+ ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds());
+ EXPECT_TRUE(S_ISFIFO(st.st_mode));
+
+ std::string msg = "some string";
+ std::vector<char> buf(512);
+
+ // Read-end of the pipe.
+ ScopedThread t([&fifo, &buf, &msg]() {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(msg.length()));
+ EXPECT_EQ(msg, std::string(buf.data()));
+ });
+
+ // Write-end of the pipe.
+ FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY));
+ EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
+ SyscallSucceedsWithValue(msg.length()));
+}
+
+TEST(MknodTest, FifoOtrunc) {
+ const std::string fifo = NewTempAbsPath();
+ ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0),
+ SyscallSucceeds());
+
+ struct stat st = {};
+ ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds());
+ EXPECT_TRUE(S_ISFIFO(st.st_mode));
+
+ std::string msg = "some string";
+ std::vector<char> buf(512);
+ // Read-end of the pipe.
+ ScopedThread t([&fifo, &buf, &msg]() {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(msg.length()));
+ EXPECT_EQ(msg, std::string(buf.data()));
+ });
+
+ // Write-end of the pipe.
+ FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC));
+ EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
+ SyscallSucceedsWithValue(msg.length()));
+}
+
+TEST(MknodTest, FifoTruncNoOp) {
+ const std::string fifo = NewTempAbsPath();
+ ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(truncate(fifo.c_str(), 0), SyscallFailsWithErrno(EINVAL));
+
+ struct stat st = {};
+ ASSERT_THAT(stat(fifo.c_str(), &st), SyscallSucceeds());
+ EXPECT_TRUE(S_ISFIFO(st.st_mode));
+
+ std::string msg = "some string";
+ std::vector<char> buf(512);
+ // Read-end of the pipe.
+ ScopedThread t([&fifo, &buf, &msg]() {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(msg.length()));
+ EXPECT_EQ(msg, std::string(buf.data()));
+ });
+
+ FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC));
+ EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
+ SyscallSucceedsWithValue(msg.length()));
+ EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc
new file mode 100644
index 000000000..78ac96bed
--- /dev/null
+++ b/test/syscalls/linux/mlock.cc
@@ -0,0 +1,332 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mman.h>
+#include <sys/resource.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <cstring>
+
+#include "gmock/gmock.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/rlimit_util.h"
+#include "test/util/test_util.h"
+
+using ::testing::_;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<bool> CanMlock() {
+ struct rlimit rlim;
+ if (getrlimit(RLIMIT_MEMLOCK, &rlim) < 0) {
+ return PosixError(errno, "getrlimit(RLIMIT_MEMLOCK)");
+ }
+ if (rlim.rlim_cur != 0) {
+ return true;
+ }
+ return HaveCapability(CAP_IPC_LOCK);
+}
+
+// Returns true if the page containing addr is mlocked.
+bool IsPageMlocked(uintptr_t addr) {
+ // This relies on msync(MS_INVALIDATE) interacting correctly with mlocked
+ // pages, which is tested for by the MsyncInvalidate case below.
+ int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
+ kPageSize, MS_ASYNC | MS_INVALIDATE);
+ if (rv == 0) {
+ return false;
+ }
+ // This uses TEST_PCHECK_MSG since it's used in subprocesses.
+ TEST_PCHECK_MSG(errno == EBUSY, "msync failed with unexpected errno");
+ return true;
+}
+
+TEST(MlockTest, Basic) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+}
+
+TEST(MlockTest, ProtNone) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()),
+ SyscallFailsWithErrno(ENOMEM));
+ // ENOMEM is returned because mlock can't populate the page, but it's still
+ // considered locked.
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+}
+
+TEST(MlockTest, MadviseDontneed) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_THAT(madvise(mapping.ptr(), mapping.len(), MADV_DONTNEED),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(MlockTest, MsyncInvalidate) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_THAT(msync(mapping.ptr(), mapping.len(), MS_ASYNC | MS_INVALIDATE),
+ SyscallFailsWithErrno(EBUSY));
+ EXPECT_THAT(msync(mapping.ptr(), mapping.len(), MS_SYNC | MS_INVALIDATE),
+ SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(MlockTest, Fork) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+ EXPECT_THAT(
+ InForkedProcess([&] { TEST_CHECK(!IsPageMlocked(mapping.addr())); }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+TEST(MlockTest, RlimitMemlockZero) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false));
+ }
+ Cleanup reset_rlimit =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()),
+ SyscallFailsWithErrno(EPERM));
+}
+
+TEST(MlockTest, RlimitMemlockInsufficient) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false));
+ }
+ Cleanup reset_rlimit =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()),
+ SyscallFailsWithErrno(ENOMEM));
+}
+
+TEST(MunlockTest, Basic) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(munlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+}
+
+TEST(MunlockTest, NotLocked) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ EXPECT_THAT(munlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+}
+
+// There is currently no test for mlockall(MCL_CURRENT) because the default
+// RLIMIT_MEMLOCK of 64 KB is insufficient to actually invoke
+// mlockall(MCL_CURRENT).
+
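+// As an editorial sketch (not part of this change), such a test could assume
+// CAP_IPC_LOCK, which bypasses RLIMIT_MEMLOCK entirely:
+//
+//   TEST(MlockallTest, Current) {
+//     SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK)));
+//     auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+//         MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+//     EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+//     ASSERT_THAT(mlockall(MCL_CURRENT), SyscallSucceeds());
+//     EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+//     ASSERT_THAT(munlockall(), SyscallSucceeds());
+//   }
+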
+TEST(MlockallTest, Future) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+
+ // Run this test in a separate (single-threaded) subprocess to ensure that a
+ // background thread doesn't try to mmap a large amount of memory, fail due
+ // to hitting RLIMIT_MEMLOCK, and explode the process violently.
+ auto const do_test = [] {
+ auto const mapping =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(!IsPageMlocked(mapping.addr()));
+ TEST_PCHECK(mlockall(MCL_FUTURE) == 0);
+ // Ensure that mlockall(MCL_FUTURE) is turned off before the end of the
+ // test, as otherwise mmaps may fail unexpectedly.
+ Cleanup do_munlockall([] { TEST_PCHECK(munlockall() == 0); });
+ auto const mapping2 =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(IsPageMlocked(mapping2.addr()));
+ // Fire munlockall() and check that it disables mlockall(MCL_FUTURE).
+ do_munlockall.Release()();
+ auto const mapping3 =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(!IsPageMlocked(mapping3.addr()));
+ };
+ EXPECT_THAT(InForkedProcess(do_test), IsPosixErrorOkAndHolds(0));
+}
+
+TEST(MunlockallTest, Basic) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED));
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(munlockall(), SyscallSucceeds());
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+}
+
+#ifndef SYS_mlock2
+#if defined(__x86_64__)
+#define SYS_mlock2 325
+#elif defined(__aarch64__)
+#define SYS_mlock2 284
+#endif
+#endif
+
+#ifndef MLOCK_ONFAULT
+#define MLOCK_ONFAULT 0x01 // Linux: include/uapi/asm-generic/mman-common.h
+#endif
+
+#ifdef SYS_mlock2
+
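+// mlock2(2) gained a glibc wrapper only in glibc 2.27, so invoke the raw
+// syscall directly.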
+int mlock2(void const* addr, size_t len, int flags) {
+ return syscall(SYS_mlock2, addr, len, flags);
+}
+
+TEST(Mlock2Test, NoFlags) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock2(mapping.ptr(), mapping.len(), 0), SyscallSucceeds());
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+}
+
+TEST(Mlock2Test, MlockOnfault) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+ ASSERT_THAT(mlock2(mapping.ptr(), mapping.len(), MLOCK_ONFAULT),
+ SyscallSucceeds());
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+}
+
+TEST(Mlock2Test, UnknownFlags) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ EXPECT_THAT(mlock2(mapping.ptr(), mapping.len(), ~0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+#endif // defined(SYS_mlock2)
+
+TEST(MapLockedTest, Basic) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED));
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+ EXPECT_THAT(munlock(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ EXPECT_FALSE(IsPageMlocked(mapping.addr()));
+}
+
+TEST(MapLockedTest, RlimitMemlockZero) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false));
+ }
+ Cleanup reset_rlimit =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0));
+ EXPECT_THAT(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED),
+ PosixErrorIs(EPERM, _));
+}
+
+TEST(MapLockedTest, RlimitMemlockInsufficient) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false));
+ }
+ Cleanup reset_rlimit =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize));
+ EXPECT_THAT(
+ MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED),
+ PosixErrorIs(EAGAIN, _));
+}
+
+TEST(MremapLockedTest, Basic) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED));
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+
+ // Save the length before release(), which resets it to zero.
+ const size_t old_len = mapping.len();
+ void* addr = mremap(mapping.ptr(), old_len, 2 * old_len,
+ MREMAP_MAYMOVE, nullptr);
+ if (addr == MAP_FAILED) {
+ FAIL() << "mremap failed: " << errno << " (" << strerror(errno) << ")";
+ }
+ mapping.release();
+ mapping.reset(addr, 2 * old_len);
+ EXPECT_TRUE(IsPageMlocked(reinterpret_cast<uintptr_t>(addr)));
+}
+
+TEST(MremapLockedTest, RlimitMemlockZero) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED));
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false));
+ }
+ Cleanup reset_rlimit =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0));
+ void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(),
+ MREMAP_MAYMOVE, nullptr);
+ EXPECT_TRUE(addr == MAP_FAILED && errno == EAGAIN)
+ << "addr = " << addr << ", errno = " << errno;
+}
+
+TEST(MremapLockedTest, RlimitMemlockInsufficient) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
+ auto mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED));
+ EXPECT_TRUE(IsPageMlocked(mapping.addr()));
+
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false));
+ }
+ Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(
+ ScopedSetSoftRlimit(RLIMIT_MEMLOCK, mapping.len()));
+ void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(),
+ MREMAP_MAYMOVE, nullptr);
+ EXPECT_TRUE(addr == MAP_FAILED && errno == EAGAIN)
+ << "addr = " << addr << ", errno = " << errno;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
new file mode 100644
index 000000000..6d3227ab6
--- /dev/null
+++ b/test/syscalls/linux/mmap.cc
@@ -0,0 +1,1676 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/magic.h>
+#include <linux/unistd.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/resource.h>
+#include <sys/statfs.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/str_split.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::Gt;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
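+// Returns the virtual memory size of the calling process, in bytes. The first
+// field of /proc/self/statm is the total program size in pages.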
+PosixErrorOr<int64_t> VirtualMemorySize() {
+ ASSIGN_OR_RETURN_ERRNO(auto contents, GetContents("/proc/self/statm"));
+ std::vector<std::string> parts = absl::StrSplit(contents, ' ');
+ if (parts.empty()) {
+ return PosixError(EINVAL, "Unable to parse /proc/self/statm");
+ }
+ ASSIGN_OR_RETURN_ERRNO(auto pages, Atoi<int64_t>(parts[0]));
+ return pages * getpagesize();
+}
+
+class MMapTest : public ::testing::Test {
+ protected:
+ // Unmap mapping, if one was made.
+ void TearDown() override {
+ if (addr_) {
+ EXPECT_THAT(Unmap(), SyscallSucceeds());
+ }
+ }
+
+ // Remembers mapping, so it can be automatically unmapped.
+ uintptr_t Map(uintptr_t addr, size_t length, int prot, int flags, int fd,
+ off_t offset) {
+ void* ret =
+ mmap(reinterpret_cast<void*>(addr), length, prot, flags, fd, offset);
+
+ if (ret != MAP_FAILED) {
+ addr_ = ret;
+ length_ = length;
+ }
+
+ return reinterpret_cast<uintptr_t>(ret);
+ }
+
+ // Unmap previous mapping
+ int Unmap() {
+ if (!addr_) {
+ return -1;
+ }
+
+ int ret = munmap(addr_, length_);
+
+ addr_ = nullptr;
+ length_ = 0;
+
+ return ret;
+ }
+
+ // Msync the mapping.
+ int Msync() { return msync(addr_, length_, MS_SYNC); }
+
+ // Mlock the mapping.
+ int Mlock() { return mlock(addr_, length_); }
+
+ // Munlock the mapping.
+ int Munlock() { return munlock(addr_, length_); }
+
+ int Protect(uintptr_t addr, size_t length, int prot) {
+ return mprotect(reinterpret_cast<void*>(addr), length, prot);
+ }
+
+ void* addr_ = nullptr;
+ size_t length_ = 0;
+};
+
+// Matches if arg contains the same contents as string str.
+MATCHER_P(EqualsMemory, str, "") {
+ if (0 == memcmp(arg, str.c_str(), str.size())) {
+ return true;
+ }
+
+ *result_listener << "Memory did not match. Got:\n"
+ << absl::BytesToHexString(
+ std::string(static_cast<char*>(arg), str.size()))
+ << "Want:\n"
+ << absl::BytesToHexString(str);
+ return false;
+}
+
+// We can't map either end of a pipe, though the two ends fail with different
+// errnos.
+TEST_F(MMapTest, MapPipe) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fds[0], 0),
+ SyscallFailsWithErrno(ENODEV));
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fds[1], 0),
+ SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(close(fds[0]), SyscallSucceeds());
+ ASSERT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+// It's very common to mmap /dev/zero because anonymous mappings aren't part
+// of POSIX, although they are widely supported. So a zero-initialized memory
+// region may actually come from a "file-backed" /dev/zero mapping.
+TEST_F(MMapTest, MapDevZeroShared) {
+ // This test will verify that we're able to map a page backed by /dev/zero
+ // as MAP_SHARED.
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+
+ // Test that we can create a RW SHARED mapping of /dev/zero.
+ ASSERT_THAT(
+ Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero.get(), 0),
+ SyscallSucceeds());
+}
+
+TEST_F(MMapTest, MapDevZeroPrivate) {
+ // This test will verify that we're able to map a page backed by /dev/zero
+ // as MAP_PRIVATE.
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+
+ // Test that we can create a RW PRIVATE mapping of /dev/zero.
+ ASSERT_THAT(
+ Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, dev_zero.get(), 0),
+ SyscallSucceeds());
+}
+
+TEST_F(MMapTest, MapDevZeroNoPersistence) {
+ // This test will verify that two independent mappings of /dev/zero do not
+ // appear to reference the same "backed file."
+
+ const FileDescriptor dev_zero1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+ const FileDescriptor dev_zero2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+
+ ASSERT_THAT(
+ Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero1.get(), 0),
+ SyscallSucceeds());
+
+ // Create a second mapping via the second /dev/zero fd.
+ void* psec_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ dev_zero2.get(), 0);
+ ASSERT_THAT(reinterpret_cast<intptr_t>(psec_map), SyscallSucceeds());
+
+ // Always unmap.
+ auto cleanup_psec_map = Cleanup(
+ [&] { EXPECT_THAT(munmap(psec_map, kPageSize), SyscallSucceeds()); });
+
+ // Verify that we have independently addressed pages.
+ ASSERT_NE(psec_map, addr_);
+
+ std::string buf_zero(kPageSize, 0x00);
+ std::string buf_ones(kPageSize, 0xFF);
+
+ // Verify the first is actually all zeros after mmap.
+ EXPECT_THAT(addr_, EqualsMemory(buf_zero));
+
+ // Let's fill in the first mapping with 0xFF.
+ memcpy(addr_, buf_ones.data(), kPageSize);
+
+ // Verify that the memcpy actually stuck in the page.
+ EXPECT_THAT(addr_, EqualsMemory(buf_ones));
+
+ // Verify that it didn't affect the second page which should be all zeros.
+ EXPECT_THAT(psec_map, EqualsMemory(buf_zero));
+}
+
+TEST_F(MMapTest, MapDevZeroSharedMultiplePages) {
+ // This will test that we're able to map /dev/zero over multiple pages.
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+
+ // Test that we can create a RW SHARED mapping of /dev/zero.
+ ASSERT_THAT(Map(0, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED,
+ dev_zero.get(), 0),
+ SyscallSucceeds());
+
+ std::string buf_zero(kPageSize * 2, 0x00);
+ std::string buf_ones(kPageSize * 2, 0xFF);
+
+ // Verify the two pages are actually all zeros after mmap.
+ EXPECT_THAT(addr_, EqualsMemory(buf_zero));
+
+ // Fill out the pages with all ones.
+ memcpy(addr_, buf_ones.data(), kPageSize * 2);
+
+ // Verify that the memcpy actually stuck in the pages.
+ EXPECT_THAT(addr_, EqualsMemory(buf_ones));
+}
+
+TEST_F(MMapTest, MapDevZeroSharedFdNoPersistence) {
+ // This test will verify that two independent mappings of /dev/zero do not
+ // appear to reference the same "backed file" even when mapped from the
+ // same initial fd.
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+
+ ASSERT_THAT(
+ Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero.get(), 0),
+ SyscallSucceeds());
+
+ // Create a second mapping via the same fd.
+ void* psec_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ dev_zero.get(), 0);
+ ASSERT_THAT(reinterpret_cast<int64_t>(psec_map), SyscallSucceeds());
+
+ // Always unmap.
+ auto cleanup_psec_map = Cleanup(
+ [&] { ASSERT_THAT(munmap(psec_map, kPageSize), SyscallSucceeds()); });
+
+ // Verify that we have independently addressed pages.
+ ASSERT_NE(psec_map, addr_);
+
+ std::string buf_zero(kPageSize, 0x00);
+ std::string buf_ones(kPageSize, 0xFF);
+
+ // Verify the first is actually all zeros after mmap.
+ EXPECT_THAT(addr_, EqualsMemory(buf_zero));
+
+ // Let's fill in the first mapping with 0xFF.
+ memcpy(addr_, buf_ones.data(), kPageSize);
+
+ // Verify that the memcpy actually stuck in the page.
+ EXPECT_THAT(addr_, EqualsMemory(buf_ones));
+
+ // Verify that it didn't affect the second page which should be all zeros.
+ EXPECT_THAT(psec_map, EqualsMemory(buf_zero));
+}
+
+TEST_F(MMapTest, MapDevZeroSegfaultAfterUnmap) {
+ SetupGvisorDeathTest();
+
+ // This test will verify that we're able to map a page backed by /dev/zero
+ // as MAP_SHARED and after it's unmapped any access results in a SIGSEGV.
+ // This test is redundant but given the special nature of /dev/zero mappings
+ // it doesn't hurt.
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+
+ const auto rest = [&] {
+ // Test that we can create a RW SHARED mapping of /dev/zero.
+ TEST_PCHECK(Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ dev_zero.get(),
+ 0) != reinterpret_cast<uintptr_t>(MAP_FAILED));
+
+ // Confirm that accesses after the unmap result in a SIGSEGV.
+ //
+ // N.B. We depend on this process being single-threaded to ensure there
+ // can't be another mmap to map addr before the dereference below.
+ void* addr_saved = addr_; // Unmap resets addr_.
+ TEST_PCHECK(Unmap() == 0);
+ *reinterpret_cast<volatile int*>(addr_saved) = 0xFF;
+ };
+
+ EXPECT_THAT(InForkedProcess(rest),
+ IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV)));
+}
+
+TEST_F(MMapTest, MapDevZeroUnaligned) {
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDWR));
+ const size_t size = kPageSize + kPageSize / 2;
+ const std::string buf_zero(size, 0x00);
+
+ ASSERT_THAT(
+ Map(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, dev_zero.get(), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(addr_, EqualsMemory(buf_zero));
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ ASSERT_THAT(
+ Map(0, size, PROT_READ | PROT_WRITE, MAP_PRIVATE, dev_zero.get(), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(addr_, EqualsMemory(buf_zero));
+}
+
+// We can't map _some_ character devices.
+TEST_F(MMapTest, MapCharDevice) {
+ const FileDescriptor cdevfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/random", 0, 0));
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, cdevfd.get(), 0),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// We can't map directories.
+TEST_F(MMapTest, MapDirectory) {
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), 0, 0));
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_PRIVATE, dirfd.get(), 0),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// We can map *something*.
+TEST_F(MMapTest, MapAnything) {
+ EXPECT_THAT(Map(0, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceedsWithValue(Gt(0)));
+}
+
+// A map length smaller than a page is allowed; the kernel rounds it up to a
+// full page.
+TEST_F(MMapTest, SmallMap) {
+ EXPECT_THAT(Map(0, 128, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+}
+
+// A hint address doesn't break anything.
+// Note: there is no requirement that we actually get the hinted address.
+TEST_F(MMapTest, HintAddress) {
+ EXPECT_THAT(
+ Map(0x30000000, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+}
+
+// MAP_FIXED gives us exactly the requested address
+TEST_F(MMapTest, MapFixed) {
+ EXPECT_THAT(Map(0x30000000, kPageSize, PROT_NONE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0),
+ SyscallSucceedsWithValue(0x30000000));
+}
+
+// 64-bit addresses work too
+#if defined(__x86_64__) || defined(__aarch64__)
+TEST_F(MMapTest, MapFixed64) {
+ EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0),
+ SyscallSucceedsWithValue(0x300000000000));
+}
+#endif
+
+// MAP_STACK allowed.
+// There isn't a good way to verify it did anything.
+TEST_F(MMapTest, MapStack) {
+ EXPECT_THAT(Map(0, kPageSize, PROT_NONE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0),
+ SyscallSucceeds());
+}
+
+// MAP_LOCKED allowed.
+// There isn't a good way to verify it did anything.
+TEST_F(MMapTest, MapLocked) {
+ EXPECT_THAT(Map(0, kPageSize, PROT_NONE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_LOCKED, -1, 0),
+ SyscallSucceeds());
+}
+
+// MAP_PRIVATE or MAP_SHARED must be passed
+TEST_F(MMapTest, NotPrivateOrShared) {
+ EXPECT_THAT(Map(0, kPageSize, PROT_NONE, MAP_ANONYMOUS, -1, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Only one of MAP_PRIVATE or MAP_SHARED may be passed
+TEST_F(MMapTest, PrivateAndShared) {
+ EXPECT_THAT(Map(0, kPageSize, PROT_NONE,
+ MAP_PRIVATE | MAP_SHARED | MAP_ANONYMOUS, -1, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(MMapTest, FixedAlignment) {
+ // Addr must be page aligned (MAP_FIXED)
+ EXPECT_THAT(Map(0x30000001, kPageSize, PROT_NONE,
+ MAP_PRIVATE | MAP_FIXED | MAP_ANONYMOUS, -1, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Non-MAP_FIXED address does not need to be page aligned
+TEST_F(MMapTest, NonFixedAlignment) {
+ EXPECT_THAT(
+ Map(0x30000001, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+}
+
+// Length = 0 results in EINVAL.
+TEST_F(MMapTest, InvalidLength) {
+ EXPECT_THAT(Map(0, 0, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Bad fd not allowed.
+TEST_F(MMapTest, BadFd) {
+ EXPECT_THAT(Map(0, kPageSize, PROT_NONE, MAP_PRIVATE, 999, 0),
+ SyscallFailsWithErrno(EBADF));
+}
+
+// Mappings are writable.
+TEST_F(MMapTest, ProtWrite) {
+ uint64_t addr;
+ constexpr uint8_t kFirstWord[] = {42, 42, 42, 42};
+
+ EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ // This shouldn't cause a SIGSEGV.
+ memset(reinterpret_cast<void*>(addr), 42, kPageSize);
+
+ // The written data should actually be there.
+ EXPECT_EQ(
+ 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord)));
+}
+
+// "Write-only" mappings are writable *and* readable.
+TEST_F(MMapTest, ProtWriteOnly) {
+ uint64_t addr;
+ constexpr uint8_t kFirstWord[] = {42, 42, 42, 42};
+
+ EXPECT_THAT(
+ addr = Map(0, kPageSize, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ // This shouldn't cause a SIGSEGV.
+ memset(reinterpret_cast<void*>(addr), 42, kPageSize);
+
+ // The written data should actually be there.
+ EXPECT_EQ(
+ 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord)));
+}
+
+// "Write-only" mappings are readable.
+//
+// This is distinct from above to ensure the page is accessible even if the
+// initial fault is a write fault.
+TEST_F(MMapTest, ProtWriteOnlyReadable) {
+ uint64_t addr;
+ constexpr uint64_t kFirstWord = 0;
+
+ EXPECT_THAT(
+ addr = Map(0, kPageSize, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), &kFirstWord,
+ sizeof(kFirstWord)));
+}
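+
+// Note that /proc/self/maps reports the protections *requested* at mmap or
+// mprotect time, so a PROT_WRITE-only mapping appears as "-w-p" even though
+// x86 and ARM page tables cannot express write-without-read and the mapping
+// is readable in practice. A minimal sketch for observing this (assuming
+// <fstream> and a mounted /proc):
+//
+//   std::ifstream maps("/proc/self/maps");
+//   for (std::string line; std::getline(maps, line);) {
+//     if (line.find(" -w-p ") != std::string::npos) {
+//       // Write-only by request, yet reads from this range succeed.
+//     }
+//   }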
+
+// Mappings are writable after mprotect from PROT_NONE to PROT_READ|PROT_WRITE.
+TEST_F(MMapTest, ProtectProtWrite) {
+ uint64_t addr;
+ constexpr uint8_t kFirstWord[] = {42, 42, 42, 42};
+
+ EXPECT_THAT(
+ addr = Map(0, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ ASSERT_THAT(Protect(addr, kPageSize, PROT_READ | PROT_WRITE),
+ SyscallSucceeds());
+
+ // This shouldn't cause a SIGSEGV.
+ memset(reinterpret_cast<void*>(addr), 42, kPageSize);
+
+ // The written data should actually be there.
+ EXPECT_EQ(
+ 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord)));
+}
+
+// SIGSEGV raised when reading PROT_NONE memory
+TEST_F(MMapTest, ProtNoneDeath) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+
+ ASSERT_THAT(
+ addr = Map(0, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ EXPECT_EXIT(*reinterpret_cast<volatile int*>(addr),
+ ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+// SIGSEGV raised when writing PROT_READ only memory
+TEST_F(MMapTest, ReadOnlyDeath) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+
+ ASSERT_THAT(
+ addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ EXPECT_EXIT(*reinterpret_cast<volatile int*>(addr) = 42,
+ ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+// Writable mapping mprotect'd to read-only should not be writable.
+TEST_F(MMapTest, MprotectReadOnlyDeath) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ volatile int* val = reinterpret_cast<int*>(addr);
+
+ // Write to fault the page in.
+ *val = 42;
+
+ ASSERT_THAT(Protect(addr, kPageSize, PROT_READ), SyscallSucceeds());
+
+ // Now it shouldn't be writable.
+ EXPECT_EXIT(*val = 0, ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+// Verify that calling mprotect on an address that's not page aligned fails.
+TEST_F(MMapTest, MprotectNotPageAligned) {
+ uintptr_t addr;
+
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+ ASSERT_THAT(Protect(addr + 1, kPageSize - 1, PROT_READ),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Verify that calling mprotect with an absurdly huge length fails.
+TEST_F(MMapTest, MprotectHugeLength) {
+ uintptr_t addr;
+
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+ ASSERT_THAT(Protect(addr, static_cast<size_t>(-1), PROT_READ),
+ SyscallFailsWithErrno(ENOMEM));
+}
+
+#if defined(__x86_64__) || defined(__i386__)
+// This code is identical in 32-bit and 64-bit mode.
+const uint8_t machine_code[] = {
+ 0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax
+ 0xc3, // retq
+};
+#elif defined(__aarch64__)
+const uint8_t machine_code[] = {
+ 0x40, 0x05, 0x80, 0x52, // mov w0, #42
+ 0xc0, 0x03, 0x5f, 0xd6, // ret
+};
+#endif
+
+// PROT_EXEC allows code execution
+TEST_F(MMapTest, ProtExec) {
+ uintptr_t addr;
+ uint32_t (*func)(void);
+
+ EXPECT_THAT(addr = Map(0, kPageSize, PROT_EXEC | PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ memcpy(reinterpret_cast<void*>(addr), machine_code, sizeof(machine_code));
+
+ func = reinterpret_cast<uint32_t (*)(void)>(addr);
+
+ EXPECT_EQ(42, func());
+}
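+
+// On architectures without coherent instruction caches (notably ARM), newly
+// written code bytes generally require an explicit cache flush before they
+// are executed. A minimal sketch using the GCC/Clang builtin, with addr and
+// machine_code as in the test above (the test presumably gets away without
+// it because the mapping is fresh):
+//
+//   char* begin = reinterpret_cast<char*>(addr);
+//   __builtin___clear_cache(begin, begin + sizeof(machine_code));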
+
+// No PROT_EXEC disallows code execution
+TEST_F(MMapTest, NoProtExecDeath) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+ uint32_t (*func)(void);
+
+ EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+
+ memcpy(reinterpret_cast<void*>(addr), machine_code, sizeof(machine_code));
+
+ func = reinterpret_cast<uint32_t (*)(void)>(addr);
+
+ EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+TEST_F(MMapTest, NoExceedLimitData) {
+ void* prevbrk;
+ void* target_brk;
+ struct rlimit setlim;
+
+ prevbrk = sbrk(0);
+ ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk));
+ target_brk = reinterpret_cast<char*>(prevbrk) + 1;
+
+ setlim.rlim_cur = RLIM_INFINITY;
+ setlim.rlim_max = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds());
+ EXPECT_THAT(brk(target_brk), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(MMapTest, ExceedLimitData) {
+ // To unit test this more precisely, we'd need access to the mm's start_brk
+ // and end_brk, which we don't have direct access to :/
+ void* prevbrk;
+ void* target_brk;
+ struct rlimit setlim;
+
+ prevbrk = sbrk(0);
+ ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk));
+ target_brk = reinterpret_cast<char*>(prevbrk) + 8192;
+
+ setlim.rlim_cur = 0;
+ setlim.rlim_max = RLIM_INFINITY;
+ // Set RLIMIT_DATA very low so any subsequent brk() calls fail.
+ // Reset RLIMIT_DATA during teardown step.
+ ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds());
+ EXPECT_THAT(brk(target_brk), SyscallFailsWithErrno(ENOMEM));
+ // Teardown step...
+ setlim.rlim_cur = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds());
+}
+
+TEST_F(MMapTest, ExceedLimitDataPrlimit) {
+ // To unit test this more precisely, we'd need access to the mm's start_brk
+ // and end_brk, which we don't have direct access to :/
+ void* prevbrk;
+ void* target_brk;
+ struct rlimit setlim;
+
+ prevbrk = sbrk(0);
+ ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk));
+ target_brk = reinterpret_cast<char*>(prevbrk) + 8192;
+
+ setlim.rlim_cur = 0;
+ setlim.rlim_max = RLIM_INFINITY;
+ // Set RLIMIT_DATA very low so any subsequent brk() calls fail.
+ // Reset RLIMIT_DATA during teardown step.
+ ASSERT_THAT(prlimit(0, RLIMIT_DATA, &setlim, nullptr), SyscallSucceeds());
+ EXPECT_THAT(brk(target_brk), SyscallFailsWithErrno(ENOMEM));
+ // Teardown step...
+ setlim.rlim_cur = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds());
+}
+
+TEST_F(MMapTest, ExceedLimitDataPrlimitPID) {
+ // To unit test this more precisely, we'd need access to the mm's start_brk
+ // and end_brk, which we don't have direct access to :/
+ void* prevbrk;
+ void* target_brk;
+ struct rlimit setlim;
+
+ prevbrk = sbrk(0);
+ ASSERT_NE(-1, reinterpret_cast<intptr_t>(prevbrk));
+ target_brk = reinterpret_cast<char*>(prevbrk) + 8192;
+
+ setlim.rlim_cur = 0;
+ setlim.rlim_max = RLIM_INFINITY;
+ // Set RLIMIT_DATA very low so any subsequent brk() calls fail.
+ // Reset RLIMIT_DATA during teardown step.
+ ASSERT_THAT(prlimit(syscall(__NR_gettid), RLIMIT_DATA, &setlim, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(brk(target_brk), SyscallFailsWithErrno(ENOMEM));
+ // Teardown step...
+ setlim.rlim_cur = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_DATA, &setlim), SyscallSucceeds());
+}
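+
+// The three RLIMIT_DATA tests above restore the limit by hand at the end of
+// the test body, which is skipped if an earlier ASSERT fails. A sketch of a
+// more robust save/restore using getrlimit(2) and the Cleanup helper already
+// used in this file:
+//
+//   struct rlimit saved;
+//   ASSERT_THAT(getrlimit(RLIMIT_DATA, &saved), SyscallSucceeds());
+//   auto restore = Cleanup([saved] {
+//     EXPECT_THAT(setrlimit(RLIMIT_DATA, &saved), SyscallSucceeds());
+//   });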
+
+TEST_F(MMapTest, NoExceedLimitAS) {
+ constexpr uint64_t kAllocBytes = 200 << 20;
+ // Add some headroom to the AS limit in case of e.g. unexpected stack
+ // expansion.
+ constexpr uint64_t kExtraASBytes = kAllocBytes + (20 << 20);
+ static_assert(kAllocBytes < kExtraASBytes,
+ "test depends on allocation not exceeding AS limit");
+
+ auto vss = ASSERT_NO_ERRNO_AND_VALUE(VirtualMemorySize());
+ struct rlimit setlim;
+ setlim.rlim_cur = vss + kExtraASBytes;
+ setlim.rlim_max = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_AS, &setlim), SyscallSucceeds());
+ EXPECT_THAT(
+ Map(0, kAllocBytes, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceedsWithValue(Gt(0)));
+}
+
+TEST_F(MMapTest, ExceedLimitAS) {
+ constexpr uint64_t kAllocBytes = 200 << 20;
+ // Add some headroom to the AS limit in case of e.g. unexpected stack
+ // expansion.
+ constexpr uint64_t kExtraASBytes = 20 << 20;
+ static_assert(kAllocBytes > kExtraASBytes,
+ "test depends on allocation exceeding AS limit");
+
+ auto vss = ASSERT_NO_ERRNO_AND_VALUE(VirtualMemorySize());
+ struct rlimit setlim;
+ setlim.rlim_cur = vss + kExtraASBytes;
+ setlim.rlim_max = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_AS, &setlim), SyscallSucceeds());
+ EXPECT_THAT(
+ Map(0, kAllocBytes, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallFailsWithErrno(ENOMEM));
+}
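+
+// VirtualMemorySize() above presumably reports the process VSS; on Linux the
+// same figure can be derived from the first field of /proc/self/statm, the
+// total program size in pages. A minimal sketch, assuming a mounted /proc:
+//
+//   unsigned long size_pages = 0;
+//   FILE* f = fopen("/proc/self/statm", "r");
+//   if (f != nullptr) {
+//     if (fscanf(f, "%lu", &size_pages) != 1) size_pages = 0;
+//     fclose(f);
+//   }
+//   const uint64_t vss_bytes = size_pages * kPageSize;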
+
+// Tests that setting an anonymous mmap to PROT_NONE doesn't free the memory.
+TEST_F(MMapTest, SettingProtNoneDoesntFreeMemory) {
+ uintptr_t addr;
+ constexpr uint8_t kFirstWord[] = {42, 42, 42, 42};
+
+ EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceedsWithValue(Gt(0)));
+
+ memset(reinterpret_cast<void*>(addr), 42, kPageSize);
+
+ ASSERT_THAT(Protect(addr, kPageSize, PROT_NONE), SyscallSucceeds());
+ ASSERT_THAT(Protect(addr, kPageSize, PROT_READ | PROT_WRITE),
+ SyscallSucceeds());
+
+ // The written data should still be there.
+ EXPECT_EQ(
+ 0, memcmp(reinterpret_cast<void*>(addr), kFirstWord, sizeof(kFirstWord)));
+}
+
+constexpr char kFileContents[] = "Hello World!";
+
+class MMapFileTest : public MMapTest {
+ protected:
+ FileDescriptor fd_;
+ std::string filename_;
+
+ // Open a file for read/write
+ void SetUp() override {
+ MMapTest::SetUp();
+
+ filename_ = NewTempAbsPath();
+ fd_ = ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_CREAT | O_RDWR, 0644));
+
+ // Extend file so it can be written once mapped. Deliberately make the file
+ // only half a page in size, so we can test what happens when we access the
+ // second half.
+ // Use ftruncate(2) once the sentry supports it.
+ char zero = 0;
+ size_t count = 0;
+ do {
+ const DisableSave ds; // saving 2048 times is slow and useless.
+ ASSERT_THAT(Write(&zero, 1), SyscallSucceedsWithValue(1));
+ } while (++count < (kPageSize / 2));
+ ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ }
+
+ // Close and delete file
+ void TearDown() override {
+ MMapTest::TearDown();
+ fd_.reset(); // Make sure the file is closed before we unlink it.
+ ASSERT_THAT(unlink(filename_.c_str()), SyscallSucceeds());
+ }
+
+ ssize_t Read(char* buf, size_t count) {
+ ssize_t len = 0;
+ do {
+ ssize_t ret = read(fd_.get(), buf, count);
+ if (ret < 0) {
+ return ret;
+ } else if (ret == 0) {
+ return len;
+ }
+
+ len += ret;
+ buf += ret;
+ } while (len < static_cast<ssize_t>(count));
+
+ return len;
+ }
+
+ ssize_t Write(const char* buf, size_t count) {
+ ssize_t len = 0;
+ do {
+ ssize_t ret = write(fd_.get(), buf, count);
+ if (ret < 0) {
+ return ret;
+ } else if (ret == 0) {
+ return len;
+ }
+
+ len += ret;
+ buf += ret;
+ } while (len < static_cast<ssize_t>(count));
+
+ return len;
+ }
+};
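+
+// Read() and Write() above loop because read(2) and write(2) may transfer
+// fewer bytes than requested; without the loop, a single short transfer in
+// the multi-page reads and writes used by later tests would make them flaky.
+// A usage sketch:
+//
+//   char buf[16];
+//   ssize_t n = Read(buf, sizeof(buf));  // n == 16, or fewer only at EOF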
+
+class MMapFileParamTest
+ : public MMapFileTest,
+ public ::testing::WithParamInterface<std::tuple<int, int>> {
+ protected:
+ int prot() const { return std::get<0>(GetParam()); }
+
+ int flags() const { return std::get<1>(GetParam()); }
+};
+
+// MAP_POPULATE allowed.
+// There isn't a good way to verify it actually did anything.
+TEST_P(MMapFileParamTest, MapPopulate) {
+ ASSERT_THAT(Map(0, kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0),
+ SyscallSucceeds());
+}
+
+// MAP_POPULATE on a short file.
+TEST_P(MMapFileParamTest, MapPopulateShort) {
+ ASSERT_THAT(
+ Map(0, 2 * kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0),
+ SyscallSucceeds());
+}
+
+// Read contents from mapped file.
+TEST_F(MMapFileTest, Read) {
+ size_t len = strlen(kFileContents);
+ ASSERT_EQ(len, Write(kFileContents, len));
+
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd_.get(), 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(reinterpret_cast<char*>(addr),
+ EqualsMemory(std::string(kFileContents)));
+}
+
+// Map at an offset.
+TEST_F(MMapFileTest, MapOffset) {
+ ASSERT_THAT(lseek(fd_.get(), kPageSize, SEEK_SET), SyscallSucceeds());
+
+ size_t len = strlen(kFileContents);
+ ASSERT_EQ(len, Write(kFileContents, len));
+
+ uintptr_t addr;
+ ASSERT_THAT(
+ addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd_.get(), kPageSize),
+ SyscallSucceeds());
+
+ EXPECT_THAT(reinterpret_cast<char*>(addr),
+ EqualsMemory(std::string(kFileContents)));
+}
+
+TEST_F(MMapFileTest, MapOffsetBeyondEnd) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ fd_.get(), 10 * kPageSize),
+ SyscallSucceeds());
+
+ // Touching the memory causes SIGBUS.
+ size_t len = strlen(kFileContents);
+ EXPECT_EXIT(std::copy(kFileContents, kFileContents + len,
+ reinterpret_cast<volatile char*>(addr)),
+ ::testing::KilledBySignal(SIGBUS), "");
+}
+
+// Verify mmap fails when sum of length and offset overflows.
+TEST_F(MMapFileTest, MapLengthPlusOffsetOverflows) {
+ const size_t length = static_cast<size_t>(-kPageSize);
+ const off_t offset = kPageSize;
+ ASSERT_THAT(Map(0, length, PROT_READ, MAP_PRIVATE, fd_.get(), offset),
+ SyscallFailsWithErrno(ENOMEM));
+}
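+
+// The arithmetic above: on LP64, static_cast<size_t>(-kPageSize) is
+// 2^64 - kPageSize, so length + offset wraps around to exactly 0. The kernel
+// must detect this overflow and fail with ENOMEM rather than treat the
+// request as tiny.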
+
+// MAP_PRIVATE PROT_WRITE is allowed on read-only FDs.
+TEST_F(MMapFileTest, WritePrivateOnReadOnlyFd) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_RDONLY));
+
+ uintptr_t addr;
+ EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ fd.get(), 0),
+ SyscallSucceeds());
+
+ // Touch the page to ensure the kernel didn't lie about writability.
+ size_t len = strlen(kFileContents);
+ std::copy(kFileContents, kFileContents + len,
+ reinterpret_cast<volatile char*>(addr));
+}
+
+// MAP_SHARED PROT_WRITE not allowed on read-only FDs.
+TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_RDONLY));
+
+ uintptr_t addr;
+ EXPECT_THAT(
+ addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// The FD must be readable.
+TEST_P(MMapFileParamTest, WriteOnlyFd) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY));
+
+ uintptr_t addr;
+ EXPECT_THAT(addr = Map(0, kPageSize, prot(), flags(), fd.get(), 0),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// Overwriting the contents of a file mapped MAP_SHARED PROT_READ
+// should cause the new data to be reflected in the mapping.
+TEST_F(MMapFileTest, ReadSharedConsistentWithOverwrite) {
+ // Start from scratch.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Expand the file to two pages and dirty them.
+ std::string bufA(kPageSize, 'a');
+ ASSERT_THAT(Write(bufA.c_str(), bufA.size()),
+ SyscallSucceedsWithValue(bufA.size()));
+ std::string bufB(kPageSize, 'b');
+ ASSERT_THAT(Write(bufB.c_str(), bufB.size()),
+ SyscallSucceedsWithValue(bufB.size()));
+
+ // Map both pages.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Check that the mapping contains the right file data.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufA.c_str(), kPageSize));
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize), bufB.c_str(),
+ kPageSize));
+
+ // Start at the beginning of the file.
+ ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Swap the write pattern.
+ ASSERT_THAT(Write(bufB.c_str(), bufB.size()),
+ SyscallSucceedsWithValue(bufB.size()));
+ ASSERT_THAT(Write(bufA.c_str(), bufA.size()),
+ SyscallSucceedsWithValue(bufA.size()));
+
+ // Check that the mapping got updated.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufB.c_str(), kPageSize));
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize), bufA.c_str(),
+ kPageSize));
+}
+
+// Partially overwriting a file mapped MAP_SHARED PROT_READ should be reflected
+// in the mapping.
+TEST_F(MMapFileTest, ReadSharedConsistentWithPartialOverwrite) {
+ // Start from scratch.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Expand the file to two pages and dirty them.
+ std::string bufA(kPageSize, 'a');
+ ASSERT_THAT(Write(bufA.c_str(), bufA.size()),
+ SyscallSucceedsWithValue(bufA.size()));
+ std::string bufB(kPageSize, 'b');
+ ASSERT_THAT(Write(bufB.c_str(), bufB.size()),
+ SyscallSucceedsWithValue(bufB.size()));
+
+ // Map both pages.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Check that the mapping contains the right file data.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufA.c_str(), kPageSize));
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize), bufB.c_str(),
+ kPageSize));
+
+ // Start at the beginning of the file.
+ ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Do a partial overwrite, spanning both pages.
+ std::string bufC(kPageSize + (kPageSize / 2), 'c');
+ ASSERT_THAT(Write(bufC.c_str(), bufC.size()),
+ SyscallSucceedsWithValue(bufC.size()));
+
+ // Check that the mapping got updated.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), bufC.c_str(),
+ kPageSize + (kPageSize / 2)));
+ EXPECT_EQ(0,
+ memcmp(reinterpret_cast<void*>(addr + kPageSize + (kPageSize / 2)),
+ bufB.c_str(), kPageSize / 2));
+}
+
+// Overwriting a file mapped MAP_SHARED PROT_READ should be reflected in the
+// mapping and the file.
+TEST_F(MMapFileTest, ReadSharedConsistentWithWriteAndFile) {
+ // Start from scratch.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Expand the file to two full pages and dirty it.
+ std::string bufA(2 * kPageSize, 'a');
+ ASSERT_THAT(Write(bufA.c_str(), bufA.size()),
+ SyscallSucceedsWithValue(bufA.size()));
+
+ // Map only the first page.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Prepare to overwrite the file contents.
+ ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Overwrite everything, beyond the mapped portion.
+ std::string bufB(2 * kPageSize, 'b');
+ ASSERT_THAT(Write(bufB.c_str(), bufB.size()),
+ SyscallSucceedsWithValue(bufB.size()));
+
+ // What the mapped portion should now look like.
+ std::string bufMapped(kPageSize, 'b');
+
+ // Expect that the mapped portion is consistent.
+ EXPECT_EQ(
+ 0, memcmp(reinterpret_cast<void*>(addr), bufMapped.c_str(), kPageSize));
+
+ // Prepare to read the entire file contents.
+ ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Expect that the file was fully updated.
+ std::vector<char> bufFile(2 * kPageSize);
+ ASSERT_THAT(Read(bufFile.data(), bufFile.size()),
+ SyscallSucceedsWithValue(bufFile.size()));
+ // Cast to void* to avoid EXPECT_THAT assuming bufFile.data() is a
+ // NUL-terminated C string. EXPECT_THAT would otherwise try to print the
+ // char* as a C string, possibly overrunning the buffer.
+ EXPECT_THAT(reinterpret_cast<void*>(bufFile.data()), EqualsMemory(bufB));
+}
+
+// Write data to mapped file.
+TEST_F(MMapFileTest, WriteShared) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ size_t len = strlen(kFileContents);
+ memcpy(reinterpret_cast<void*>(addr), kFileContents, len);
+
+ // The file may not actually be updated until munmap is called.
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ std::vector<char> buf(len);
+ ASSERT_THAT(Read(buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a
+ // NUL-terminated C string. EXPECT_THAT would otherwise try to print the
+ // char* as a C string, possibly overrunning the buffer.
+ EXPECT_THAT(reinterpret_cast<void*>(buf.data()),
+ EqualsMemory(std::string(kFileContents)));
+}
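+
+// Unmapping is one way to force writeback; msync(2) is another, and is what
+// the fixture's Msync() helper (used further below) presumably wraps. A
+// sketch of flushing the same mapping without tearing it down:
+//
+//   ASSERT_THAT(msync(reinterpret_cast<void*>(addr), kPageSize, MS_SYNC),
+//               SyscallSucceeds());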
+
+// Write data to portion of mapped page beyond the end of the file.
+// These writes are not reflected in the file.
+TEST_F(MMapFileTest, WriteSharedBeyondEnd) {
+ // The file is only half of a page. We map an entire page. Writes to the
+ // end of the mapping must not be reflected in the file.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ // First half; this is reflected in the file.
+ std::string first(kPageSize / 2, 'A');
+ memcpy(reinterpret_cast<void*>(addr), first.c_str(), first.size());
+
+ // Second half; this is not reflected in the file.
+ std::string second(kPageSize / 2, 'B');
+ memcpy(reinterpret_cast<void*>(addr + kPageSize / 2), second.c_str(),
+ second.size());
+
+ // The file may not actually be updated until munmap is called.
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ // Big enough to fit the entire page, if the writes are mistakenly written to
+ // the file.
+ std::vector<char> buf(kPageSize);
+
+ // Only the first half is in the file.
+ ASSERT_THAT(Read(buf.data(), buf.size()),
+ SyscallSucceedsWithValue(first.size()));
+ // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a
+ // NUL-terminated C string. EXPECT_THAT would otherwise try to print the
+ // char* as a C string, possibly overrunning the buffer.
+ EXPECT_THAT(reinterpret_cast<void*>(buf.data()), EqualsMemory(first));
+}
+
+// The portion of a mapped page that becomes part of the file after a truncate
+// is reflected in the file.
+TEST_F(MMapFileTest, WriteSharedTruncateUp) {
+ // The file is only half of a page; we map an entire page. Writes to the
+ // second half are not reflected in the file until it is extended.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ // First half; this is reflected in the file.
+ std::string first(kPageSize / 2, 'A');
+ memcpy(reinterpret_cast<void*>(addr), first.c_str(), first.size());
+
+ // Second half; this is not reflected in the file now (see
+ // WriteSharedBeyondEnd), but will be after the truncate.
+ std::string second(kPageSize / 2, 'B');
+ memcpy(reinterpret_cast<void*>(addr + kPageSize / 2), second.c_str(),
+ second.size());
+
+ // Extend the file to a full page. The second half of the page will be
+ // reflected in the file.
+ EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds());
+
+ // The file may not actually be updated until munmap is called.
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ // The whole page is in the file.
+ std::vector<char> buf(kPageSize);
+ ASSERT_THAT(Read(buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a
+ // NUL-terminated C string. EXPECT_THAT would otherwise try to print the
+ // char* as a C string, possibly overrunning the buffer.
+ EXPECT_THAT(reinterpret_cast<void*>(buf.data()), EqualsMemory(first));
+ EXPECT_THAT(reinterpret_cast<void*>(buf.data() + kPageSize / 2),
+ EqualsMemory(second));
+}
+
+TEST_F(MMapFileTest, ReadSharedTruncateDownThenUp) {
+ // Start from scratch.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Expand the file to a full page and dirty it.
+ std::string buf(kPageSize, 'a');
+ ASSERT_THAT(Write(buf.c_str(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Map the page.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Check that the memory contains the file data.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize));
+
+ // Truncate down, then up.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+ EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds());
+
+ // Check that the memory was zeroed.
+ std::string zeroed(kPageSize, '\0');
+ EXPECT_EQ(0,
+ memcmp(reinterpret_cast<void*>(addr), zeroed.c_str(), kPageSize));
+
+ // The file may not actually be updated until msync is called.
+ ASSERT_THAT(Msync(), SyscallSucceeds());
+
+ // Prepare to read the entire file contents.
+ ASSERT_THAT(lseek(fd_.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Expect that the file is fully updated.
+ std::vector<char> bufFile(kPageSize);
+ ASSERT_THAT(Read(bufFile.data(), bufFile.size()),
+ SyscallSucceedsWithValue(bufFile.size()));
+ EXPECT_EQ(0, memcmp(bufFile.data(), zeroed.c_str(), kPageSize));
+}
+
+TEST_F(MMapFileTest, WriteSharedTruncateDownThenUp) {
+ // Map an entire page of the half-page file. After truncating the file down
+ // and back up, both the mapping and the file should read as zeros.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ // First half; this will be deleted by truncate(0).
+ std::string first(kPageSize / 2, 'A');
+ memcpy(reinterpret_cast<void*>(addr), first.c_str(), first.size());
+
+ // Truncate down, then up.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+ EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds());
+
+ // The whole page is zeroed in memory.
+ std::string zeroed(kPageSize, '\0');
+ EXPECT_EQ(0,
+ memcmp(reinterpret_cast<void*>(addr), zeroed.c_str(), kPageSize));
+
+ // The file may not actually be updated until munmap is called.
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ // The whole file is also zeroed.
+ std::vector<char> buf(kPageSize);
+ ASSERT_THAT(Read(buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a
+ // NUL-terminated C string. EXPECT_THAT would otherwise try to print the
+ // char* as a C string, possibly overrunning the buffer.
+ EXPECT_THAT(reinterpret_cast<void*>(buf.data()), EqualsMemory(zeroed));
+}
+
+TEST_F(MMapFileTest, ReadSharedTruncateSIGBUS) {
+ SetupGvisorDeathTest();
+
+ // Start from scratch.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Expand the file to a full page and dirty it.
+ std::string buf(kPageSize, 'a');
+ ASSERT_THAT(Write(buf.c_str(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Map the page.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Check that the mapping contains the file data.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize));
+
+ // Truncate down.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Accessing the truncated region should cause a SIGBUS.
+ std::vector<char> in(kPageSize);
+ EXPECT_EXIT(
+ std::copy(reinterpret_cast<volatile char*>(addr),
+ reinterpret_cast<volatile char*>(addr) + kPageSize, in.data()),
+ ::testing::KilledBySignal(SIGBUS), "");
+}
+
+TEST_F(MMapFileTest, WriteSharedTruncateSIGBUS) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Touch the memory to be sure it really is mapped.
+ size_t len = strlen(kFileContents);
+ memcpy(reinterpret_cast<void*>(addr), kFileContents, len);
+
+ // Truncate down.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Accessing the truncated file should cause a SIGBUS.
+ EXPECT_EXIT(std::copy(kFileContents, kFileContents + len,
+ reinterpret_cast<volatile char*>(addr)),
+ ::testing::KilledBySignal(SIGBUS), "");
+}
+
+TEST_F(MMapFileTest, ReadSharedTruncatePartialPage) {
+ // Start from scratch.
+ EXPECT_THAT(ftruncate(fd_.get(), 0), SyscallSucceeds());
+
+ // Dirty the file.
+ std::string buf(kPageSize, 'a');
+ ASSERT_THAT(Write(buf.c_str(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Map a page.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Truncate to half of the page.
+ EXPECT_THAT(ftruncate(fd_.get(), kPageSize / 2), SyscallSucceeds());
+
+ // First half of the page untouched.
+ EXPECT_EQ(0,
+ memcmp(reinterpret_cast<void*>(addr), buf.data(), kPageSize / 2));
+
+ // Second half is zeroed.
+ std::string zeroed(kPageSize / 2, '\0');
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize / 2),
+ zeroed.c_str(), kPageSize / 2));
+}
+
+// The page can still be accessed, and its remaining contents are intact,
+// after truncating to a partial page.
+TEST_F(MMapFileTest, WriteSharedTruncatePartialPage) {
+ // Expand the file to a full page.
+ EXPECT_THAT(ftruncate(fd_.get(), kPageSize), SyscallSucceeds());
+
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ // Fill the entire page.
+ std::string contents(kPageSize, 'A');
+ memcpy(reinterpret_cast<void*>(addr), contents.c_str(), contents.size());
+
+ // Truncate half of the page.
+ EXPECT_THAT(ftruncate(fd_.get(), kPageSize / 2), SyscallSucceeds());
+
+ // First half of the page untouched.
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), contents.c_str(),
+ kPageSize / 2));
+
+ // Second half zeroed.
+ std::string zeroed(kPageSize / 2, '\0');
+ EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr + kPageSize / 2),
+ zeroed.c_str(), kPageSize / 2));
+}
+
+// MAP_PRIVATE writes are not carried through to the underlying file.
+TEST_F(MMapFileTest, WritePrivate) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ size_t len = strlen(kFileContents);
+ memcpy(reinterpret_cast<void*>(addr), kFileContents, len);
+
+ // The file should not be updated, but if it mistakenly is, it may not be
+ // until after munmap is called.
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ std::vector<char> buf(len);
+ ASSERT_THAT(Read(buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ // Cast to void* to avoid EXPECT_THAT assuming buf.data() is a
+ // NUL-terminated C string. EXPECT_THAT would otherwise try to print the
+ // char* as a C string, possibly overrunning the buffer.
+ EXPECT_THAT(reinterpret_cast<void*>(buf.data()),
+ EqualsMemory(std::string(len, '\0')));
+}
+
+// SIGBUS raised when reading or writing past end of a mapped file.
+TEST_P(MMapFileParamTest, SigBusDeath) {
+ SetupGvisorDeathTest();
+
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0),
+ SyscallSucceeds());
+
+ auto* start = reinterpret_cast<volatile char*>(addr + kPageSize);
+
+ // MMapFileTest makes a file kPageSize/2 long. The entire first page should
+ // be accessible, but anything beyond it should not be.
+ if (prot() & PROT_WRITE) {
+ // Write beyond first page.
+ size_t len = strlen(kFileContents);
+ EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, start),
+ ::testing::KilledBySignal(SIGBUS), "");
+ } else {
+ // Read beyond first page.
+ std::vector<char> in(kPageSize);
+ EXPECT_EXIT(std::copy(start, start + kPageSize, in.data()),
+ ::testing::KilledBySignal(SIGBUS), "");
+ }
+}
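+
+// The boundary arithmetic behind SigBusDeath: the file is kPageSize/2 long,
+// which rounds up to one page of backing, so offsets [0, kPageSize) are
+// accessible while [kPageSize, 2*kPageSize) lies wholly beyond EOF and
+// faults with SIGBUS.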
+
+// Tests that SIGBUS is not raised when reading or writing to a file-mapped
+// page before EOF, even if part of the mapping extends beyond EOF.
+//
+// See b/27877699.
+TEST_P(MMapFileParamTest, NoSigBusOnPagesBeforeEOF) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0),
+ SyscallSucceeds());
+
+ // The test passes if this survives.
+ auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1);
+ size_t len = strlen(kFileContents);
+ if (prot() & PROT_WRITE) {
+ std::copy(kFileContents, kFileContents + len, start);
+ } else {
+ std::vector<char> in(len);
+ std::copy(start, start + len, in.data());
+ }
+}
+
+// Tests that SIGBUS is not raised when reading or writing from a file-mapped
+// page containing EOF, *after* the EOF.
+TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0),
+ SyscallSucceeds());
+
+ // The test passes if this survives. (Technically addr+kPageSize/2 is already
+ // beyond EOF, but +1 to check for fencepost errors.)
+ auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1);
+ size_t len = strlen(kFileContents);
+ if (prot() & PROT_WRITE) {
+ std::copy(kFileContents, kFileContents + len, start);
+ } else {
+ std::vector<char> in(len);
+ std::copy(start, start + len, in.data());
+ }
+}
+
+// Tests that reading from writable shared file-mapped pages succeeds.
+//
+// On most platforms this is trivial, but when the file is mapped via the sentry
+// page cache (which does not yet support writing to shared mappings), a bug
+// caused reads to fail unnecessarily on such mappings. See b/28913513.
+TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
+ uintptr_t addr;
+ size_t len = strlen(kFileContents);
+
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ std::vector<char> buf(kPageSize);
+ // The test passes if this survives.
+ std::copy(reinterpret_cast<volatile char*>(addr),
+ reinterpret_cast<volatile char*>(addr) + len, buf.data());
+}
+
+// Tests that EFAULT is returned when invoking a syscall that requires the OS to
+// read past end of file (resulting in a fault in sentry context in the gVisor
+// case). See b/28913513.
+TEST_F(MMapFileTest, InternalSigBus) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ // This depends on the fact that gVisor implements pipes internally.
+ int pipefd[2];
+ ASSERT_THAT(pipe(pipefd), SyscallSucceeds());
+ EXPECT_THAT(
+ write(pipefd[1], reinterpret_cast<void*>(addr + kPageSize), kPageSize),
+ SyscallFailsWithErrno(EFAULT));
+
+ EXPECT_THAT(close(pipefd[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipefd[1]), SyscallSucceeds());
+}
+
+// Like InternalSigBus, but tests the WriteZerosAt path by reading from
+// /dev/zero to a shared mapping (so that the SIGBUS isn't caught during
+// copy-on-write breaking).
+TEST_F(MMapFileTest, InternalSigBusZeroing) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+
+ const FileDescriptor dev_zero =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
+ EXPECT_THAT(read(dev_zero.get(), reinterpret_cast<void*>(addr + kPageSize),
+ kPageSize),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+// Checks that mmaps with a length of uint64_t(-PAGE_SIZE + 1) or greater do not
+// induce a sentry panic (due to "rounding up" to 0).
+TEST_F(MMapTest, HugeLength) {
+ EXPECT_THAT(Map(0, static_cast<uint64_t>(-kPageSize + 1), PROT_NONE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallFailsWithErrno(ENOMEM));
+}
+
+// Tests for a specific gVisor MM caching bug.
+TEST_F(MMapTest, AccessCOWInvalidatesCachedSegments) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
+ auto zero_fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
+
+ // Get a two-page private mapping and fill it with 1s.
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0),
+ SyscallSucceeds());
+ memset(addr_, 1, 2 * kPageSize);
+ MaybeSave();
+
+ // Fork to make the mapping copy-on-write.
+ pid_t const pid = fork();
+ if (pid == 0) {
+ // The child process waits for the parent to SIGKILL it.
+ while (true) {
+ pause();
+ }
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ auto cleanup_child = Cleanup([&] {
+ EXPECT_THAT(kill(pid, SIGKILL), SyscallSucceeds());
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ });
+
+ // Induce a read-only Access of the first page of the mapping, which will not
+ // cause a copy. The usermem.Segment should be cached.
+ ASSERT_THAT(PwriteFd(fd.get(), addr_, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Induce a writable Access of both pages of the mapping. This should
+ // invalidate the cached Segment.
+ ASSERT_THAT(PreadFd(zero_fd.get(), addr_, 2 * kPageSize, 0),
+ SyscallSucceedsWithValue(2 * kPageSize));
+
+ // Induce a read-only Access of the first page of the mapping again. It should
+ // read the 0s that were stored in the mapping by the read from /dev/zero. If
+ // the read failed to invalidate the cached Segment, it will instead read the
+ // 1s in the stale page.
+ ASSERT_THAT(PwriteFd(fd.get(), addr_, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+ std::vector<char> buf(kPageSize);
+ ASSERT_THAT(PreadFd(fd.get(), buf.data(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+ for (size_t i = 0; i < kPageSize; i++) {
+ ASSERT_EQ(0, buf[i]) << "at offset " << i;
+ }
+}
+
+TEST_F(MMapTest, NoReserve) {
+ const size_t kSize = 10 << 20; // 10 MiB
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE, -1, 0),
+ SyscallSucceeds());
+ EXPECT_GT(addr, 0);
+
+ // Check that every page can be read/written. Technically, writing to memory
+ // could SIGSEGV if no more memory is available. In gVisor that never
+ // happens, because NORESERVE is ignored. In Linux it can fail, but the
+ // allocation is small enough that it's highly likely to succeed.
+ for (size_t j = 0; j < kSize; j += kPageSize) {
+ EXPECT_EQ(0, reinterpret_cast<char*>(addr)[j]);
+ reinterpret_cast<char*>(addr)[j] = j;
+ }
+}
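+
+// Whether Linux lets a MAP_NORESERVE write fail in practice depends on the
+// system's overcommit policy. A sketch of checking it (assuming a mounted
+// /proc; 0 = heuristic, 1 = always overcommit, 2 = strict accounting):
+//
+//   int mode = -1;
+//   FILE* f = fopen("/proc/sys/vm/overcommit_memory", "r");
+//   if (f != nullptr) {
+//     if (fscanf(f, "%d", &mode) != 1) mode = -1;
+//     fclose(f);
+//   }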
+
+// Map more than the gVisor page-cache map unit (64k) and ensure that
+// it is consistent with reading from the file.
+TEST_F(MMapFileTest, Bug38498194) {
+ // Choose a size well above the 64k map unit.
+ constexpr int kSize = 4 * 1024 * 1024;
+ EXPECT_THAT(ftruncate(fd_.get(), kSize), SyscallSucceeds());
+
+ // Map a large enough region so that multiple internal segments
+ // are created to back the mapping.
+ uintptr_t addr;
+ ASSERT_THAT(
+ addr = Map(0, kSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+
+ std::vector<char> expect(kSize, 'a');
+ std::copy(expect.data(), expect.data() + expect.size(),
+ reinterpret_cast<volatile char*>(addr));
+
+ // Trigger writeback for gVisor. In Linux, pages stay cached until the
+ // kernel can no longer hold onto them.
+ ASSERT_THAT(Unmap(), SyscallSucceeds());
+
+ std::vector<char> buf(kSize);
+ ASSERT_THAT(Read(buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ EXPECT_EQ(buf, expect) << std::string(buf.data(), buf.size());
+}
+
+// Tests that reading from a file to a memory mapping of the same file does not
+// deadlock. See b/34813270.
+TEST_F(MMapFileTest, SelfRead) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+ fd_.get(), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(Read(reinterpret_cast<char*>(addr), kPageSize / 2),
+ SyscallSucceedsWithValue(kPageSize / 2));
+ // The resulting file contents are poorly-specified and irrelevant.
+}
+
+// Tests that writing to a file from a memory mapping of the same file does not
+// deadlock. Regression test for b/34813270.
+TEST_F(MMapFileTest, SelfWrite) {
+ uintptr_t addr;
+ ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(Write(reinterpret_cast<char*>(addr), kPageSize / 2),
+ SyscallSucceedsWithValue(kPageSize / 2));
+ // The resulting file contents are poorly-specified and irrelevant.
+}
+
+TEST(MMapDeathTest, TruncateAfterCOWBreak) {
+ SetupGvisorDeathTest();
+
+ // Create and map a single-page file.
+ auto const temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDWR));
+ ASSERT_THAT(ftruncate(fd.get(), kPageSize), SyscallSucceeds());
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd.get(), 0));
+
+ // Write to this mapping, causing the page to be copied for write.
+ memset(mapping.ptr(), 'a', mapping.len());
+ MaybeSave(); // Trigger a co-operative save cycle.
+
+ // Truncate the file and expect it to invalidate the copied page.
+ ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds());
+ EXPECT_EXIT(*reinterpret_cast<volatile char*>(mapping.ptr()),
+ ::testing::KilledBySignal(SIGBUS), "");
+}
+
+// Regression test for #147.
+TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) {
+ std::string filename = NewTempAbsPath();
+
+ // We have to create the file O_RDONLY to reproduce the bug because
+ // fsgofer.localFile.Create() silently upgrades O_WRONLY to O_RDWR, causing
+ // the cached "write-only" FD to be read/write and therefore usable by mmap().
+ auto const ro_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(filename, O_RDONLY | O_CREAT | O_EXCL, 0666));
+
+ // Get a write-only FD for the same file, which mmap() should not use for
+ // the read-only mapping (but does in #147).
+ auto const wo_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_WRONLY));
+ ASSERT_THAT(ftruncate(wo_fd.get(), kPageSize), SyscallSucceeds());
+
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, ro_fd.get(), 0));
+ std::vector<char> buf(kPageSize);
+ // The test passes if this survives.
+ std::copy(static_cast<char*>(mapping.ptr()),
+ static_cast<char*>(mapping.endptr()), buf.data());
+}
+
+// Conditional on MAP_32BIT.
+// This flag is supported only on x86-64, for 64-bit programs.
+#ifdef __x86_64__
+
+TEST(MMapNoFixtureTest, Map32Bit) {
+ auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE | MAP_32BIT));
+ EXPECT_LT(mapping.addr(), static_cast<uintptr_t>(1) << 32);
+ EXPECT_LE(mapping.endaddr(), static_cast<uintptr_t>(1) << 32);
+}
+
+#endif // defined(__x86_64__)
+
+INSTANTIATE_TEST_SUITE_P(
+ ReadWriteSharedPrivate, MMapFileParamTest,
+ ::testing::Combine(::testing::ValuesIn({
+ PROT_READ,
+ PROT_WRITE,
+ PROT_READ | PROT_WRITE,
+ }),
+ ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE})));
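+
+// The Combine() above instantiates every MMapFileParamTest six times: three
+// protection values crossed with MAP_SHARED and MAP_PRIVATE. Extending the
+// matrix is a one-line change; for example (MAP_POPULATE here is purely
+// illustrative):
+//
+//   ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE, MAP_SHARED | MAP_POPULATE})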
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
new file mode 100644
index 000000000..a3e9745cf
--- /dev/null
+++ b/test/syscalls/linux/mount.cc
@@ -0,0 +1,327 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(MountTest, MountBadFilesystem) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // Linux expects a valid target before it checks the file system name.
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(mount("", dir.path().c_str(), "foobar", 0, ""),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+TEST(MountTest, MountInvalidTarget) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = NewTempAbsPath();
+ EXPECT_THAT(mount("", dir.c_str(), "tmpfs", 0, ""),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(MountTest, MountPermDenied) {
+ // Clear CAP_SYS_ADMIN.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) {
+ EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false));
+ }
+
+ // Linux expects a valid target before checking capability.
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(mount("", dir.path().c_str(), "", 0, ""),
+ SyscallFailsWithErrno(EPERM));
+}
+
+TEST(MountTest, UmountPermDenied) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const mount =
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "", 0));
+
+ // Drop privileges in another thread so that the main thread keeps
+ // CAP_SYS_ADMIN and can still unmount the mounted directory.
+ ScopedThread([&]() {
+ EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false));
+ EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EPERM));
+ });
+}
+
+TEST(MountTest, MountOverBusy) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777));
+
+ // Should be able to mount over a busy directory.
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "", 0));
+}
+
+TEST(MountTest, OpenFileBusy) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", 0, "mode=0700", 0));
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777));
+
+ // An open file should prevent unmounting.
+ EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(MountTest, UmountDetach) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // structure:
+ //
+ // dir (mount point)
+ // subdir
+ // file
+ //
+ // We show that we can still walk around in the mount after
+ // detach-unmounting dir.
+ //
+ // We show that even though dir is unreachable from outside the mount, we can
+ // still reach dir's (former) parent!
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ const struct stat before = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ auto mount =
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "mode=0700",
+ /* umountflags= */ MNT_DETACH));
+ const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_NE(before.st_ino, after.st_ino);
+
+ // Create files in the new mount.
+ constexpr char kContents[] = "no no no";
+ auto const subdir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+ auto const file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(dir.path(), kContents, 0777));
+
+ auto const dir_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(subdir.path(), O_RDONLY | O_DIRECTORY));
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ // Unmount the tmpfs.
+ mount.Release()();
+
+ const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_EQ(before.st_ino, after2.st_ino);
+
+ // Can still read file after unmounting.
+ std::vector<char> buf(sizeof(kContents));
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceeds());
+
+ // Walk to dir.
+ auto const mounted_dir = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenAt(dir_fd.get(), "..", O_DIRECTORY | O_RDONLY));
+ // Walk to dir/file.
+ auto const fd_again = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenAt(mounted_dir.get(), std::string(Basename(file.path())), O_RDONLY));
+
+ std::vector<char> buf2(sizeof(kContents));
+ EXPECT_THAT(ReadFd(fd_again.get(), buf2.data(), buf2.size()),
+ SyscallSucceeds());
+ EXPECT_EQ(buf, buf2);
+
+ // Walking outside the unmounted realm should still work, too!
+ auto const dir_parent = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenAt(mounted_dir.get(), "..", O_DIRECTORY | O_RDONLY));
+}
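+
+// MNT_DETACH performs a lazy unmount: the mount disappears from the
+// namespace immediately, but the filesystem stays alive until the last
+// reference (here, the open FDs) is dropped. The raw equivalent of the
+// umountflags passed to Mount above is presumably:
+//
+//   ASSERT_THAT(umount2(dir.path().c_str(), MNT_DETACH), SyscallSucceeds());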
+
+TEST(MountTest, ActiveSubmountBusy) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const mount1 = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", 0, "mode=0700", 0));
+
+ auto const dir2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+ auto const mount2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir2.path(), "tmpfs", 0, "", 0));
+
+ // Since dir now has an active submount, we should not be able to unmount
+ // it.
+ EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(MountTest, MountTmpfs) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // NOTE(b/129868551): Inode IDs are only stable across S/R if we have an open
+ // FD for that inode. Since we are going to compare inode IDs below, get a
+ // FileDescriptor for this directory here, which will be closed automatically
+ // at the end of the test.
+ auto const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_RDONLY));
+
+ const struct stat before = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+
+ {
+ auto const mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", 0, "mode=0700", 0));
+
+ const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_EQ(s.st_mode, S_IFDIR | 0700);
+ EXPECT_NE(s.st_ino, before.st_ino);
+
+ EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777));
+ }
+
+ // Now that dir is unmounted again, we should have the old inode back.
+ const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_EQ(before.st_ino, after.st_ino);
+}
+
+TEST(MountTest, MountTmpfsMagicValIgnored) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ auto const mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", MS_MGC_VAL, "mode=0700", 0));
+}
+
+// Passing nullptr to data is equivalent to "".
+TEST(MountTest, NullData) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ EXPECT_THAT(mount("", dir.path().c_str(), "tmpfs", 0, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(umount2(dir.path().c_str(), 0), SyscallSucceeds());
+}
+
+TEST(MountTest, MountReadonly) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", MS_RDONLY, "mode=0777", 0));
+
+ const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_EQ(s.st_mode, S_IFDIR | 0777);
+
+ std::string const filename = JoinPath(dir.path(), "foo");
+ EXPECT_THAT(open(filename.c_str(), O_RDWR | O_CREAT, 0777),
+ SyscallFailsWithErrno(EROFS));
+}
+
+PosixErrorOr<absl::Time> ATime(absl::string_view file) {
+ struct stat s = {};
+ if (stat(std::string(file).c_str(), &s) == -1) {
+ return PosixError(errno, "stat failed");
+ }
+ return absl::TimeFromTimespec(s.st_atim);
+}
+
+TEST(MountTest, MountNoAtime) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", MS_NOATIME, "mode=0777", 0));
+
+ std::string const contents = "No no no, don't follow the instructions!";
+ auto const file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(dir.path(), contents, 0777));
+
+ absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(ATime(file.path()));
+
+ // Reading from the file should change the atime, but the MS_NOATIME flag
+ // should prevent that.
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ char buf[100];
+ int read_n;
+ ASSERT_THAT(read_n = read(fd.get(), buf, sizeof(buf)), SyscallSucceeds());
+ EXPECT_EQ(std::string(buf, read_n), contents);
+
+ absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(ATime(file.path()));
+
+ // Expect that atime hasn't changed.
+ EXPECT_EQ(before, after);
+}
+
+TEST(MountTest, MountNoExec) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", dir.path(), "tmpfs", MS_NOEXEC, "mode=0777", 0));
+
+ std::string const contents = "No no no, don't follow the instructions!";
+ auto const file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(dir.path(), contents, 0777));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {}, {}, nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, EACCES);
+}
+
+TEST(MountTest, RenameRemoveMountPoint) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir_parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto const dir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir_parent.path()));
+ auto const new_dir = NewTempAbsPath();
+
+ auto const mount =
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "", 0));
+
+ ASSERT_THAT(rename(dir.path().c_str(), new_dir.c_str()),
+ SyscallFailsWithErrno(EBUSY));
+
+ ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc
new file mode 100644
index 000000000..f0e5f7d82
--- /dev/null
+++ b/test/syscalls/linux/mremap.cc
@@ -0,0 +1,492 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <string.h>
+#include <sys/mman.h>
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "absl/strings/string_view.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::_;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Fixture for mremap tests parameterized by mmap flags.
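+// The parameter is the flag passed to MmapAnon, MAP_PRIVATE or MAP_SHARED
+// (see INSTANTIATE_TEST_SUITE_P below).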
+using MremapParamTest = ::testing::TestWithParam<int>;
+
+TEST_P(MremapParamTest, Noop) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+
+ ASSERT_THAT(Mremap(m.ptr(), kPageSize, kPageSize, 0, nullptr),
+ IsPosixErrorOkAndHolds(m.ptr()));
+ EXPECT_TRUE(IsMapped(m.addr()));
+}
+
+TEST_P(MremapParamTest, InPlace_ShrinkingWholeVMA) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // N.B. we must be in a single-threaded subprocess to ensure a
+ // background thread doesn't concurrently map the second page.
+ void* addr = mremap(m.ptr(), 2 * kPageSize, kPageSize, 0, nullptr);
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == m.ptr());
+ MaybeSave();
+
+ TEST_CHECK(IsMapped(m.addr()));
+ TEST_CHECK(!IsMapped(m.addr() + kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, InPlace_ShrinkingPartialVMA) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ void* addr = mremap(m.ptr(), 2 * kPageSize, kPageSize, 0, nullptr);
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == m.ptr());
+ MaybeSave();
+
+ TEST_CHECK(IsMapped(m.addr()));
+ TEST_CHECK(!IsMapped(m.addr() + kPageSize));
+ TEST_CHECK(IsMapped(m.addr() + 2 * kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, InPlace_ShrinkingAcrossVMAs) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_READ, GetParam()));
+ // Changing permissions on the first page forces it to become a separate vma.
+ ASSERT_THAT(mprotect(m.ptr(), kPageSize, PROT_NONE), SyscallSucceeds());
+
+ const auto rest = [&] {
+ // Both old_size and new_size now span two vmas; mremap
+ // shouldn't care.
+ void* addr = mremap(m.ptr(), 3 * kPageSize, 2 * kPageSize, 0, nullptr);
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == m.ptr());
+ MaybeSave();
+
+ TEST_CHECK(IsMapped(m.addr()));
+ TEST_CHECK(IsMapped(m.addr() + kPageSize));
+ TEST_CHECK(!IsMapped(m.addr() + 2 * kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, InPlace_ExpansionSuccess) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap the second page so that the first can be expanded back into it.
+ //
+ // N.B. we must be in a single-threaded subprocess to ensure a
+ // background thread doesn't concurrently map this page.
+ TEST_PCHECK(
+ munmap(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize) == 0);
+ MaybeSave();
+
+ void* addr = mremap(m.ptr(), kPageSize, 2 * kPageSize, 0, nullptr);
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == m.ptr());
+ MaybeSave();
+
+ TEST_CHECK(IsMapped(m.addr()));
+ TEST_CHECK(IsMapped(m.addr() + kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, InPlace_ExpansionFailure) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap the second page, leaving a one-page hole. Trying to expand the
+ // first page to three pages should fail since the original third page
+ // is still mapped.
+ TEST_PCHECK(
+ munmap(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize) == 0);
+ MaybeSave();
+
+ void* addr = mremap(m.ptr(), kPageSize, 3 * kPageSize, 0, nullptr);
+ TEST_CHECK_MSG(addr == MAP_FAILED, "mremap unexpectedly succeeded");
+ TEST_PCHECK_MSG(errno == ENOMEM, "mremap failed with wrong errno");
+ MaybeSave();
+
+ TEST_CHECK(IsMapped(m.addr()));
+ TEST_CHECK(!IsMapped(m.addr() + kPageSize));
+ TEST_CHECK(IsMapped(m.addr() + 2 * kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, MayMove_Expansion) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap the second page, leaving a one-page hole. Trying to expand the
+ // first page to three pages with MREMAP_MAYMOVE should force the
+ // mapping to be relocated since the original third page is still
+ // mapped.
+ TEST_PCHECK(
+ munmap(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize) == 0);
+ MaybeSave();
+
+ void* addr2 =
+ mremap(m.ptr(), kPageSize, 3 * kPageSize, MREMAP_MAYMOVE, nullptr);
+ TEST_PCHECK_MSG(addr2 != MAP_FAILED, "mremap failed");
+ MaybeSave();
+
+ const Mapping m2 = Mapping(addr2, 3 * kPageSize);
+ TEST_CHECK(m.addr() != m2.addr());
+
+ TEST_CHECK(!IsMapped(m.addr()));
+ TEST_CHECK(!IsMapped(m.addr() + kPageSize));
+ TEST_CHECK(IsMapped(m.addr() + 2 * kPageSize));
+ TEST_CHECK(IsMapped(m2.addr()));
+ TEST_CHECK(IsMapped(m2.addr() + kPageSize));
+ TEST_CHECK(IsMapped(m2.addr() + 2 * kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, Fixed_SourceAndDestinationCannotOverlap) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+
+ ASSERT_THAT(Mremap(m.ptr(), kPageSize, kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, m.ptr()),
+ PosixErrorIs(EINVAL, _));
+ EXPECT_TRUE(IsMapped(m.addr()));
+}
+
+TEST_P(MremapParamTest, Fixed_SameSize) {
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap dst to create a hole.
+ TEST_PCHECK(munmap(dst.ptr(), kPageSize) == 0);
+ MaybeSave();
+
+ void* addr = mremap(src.ptr(), kPageSize, kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr());
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == dst.ptr());
+ MaybeSave();
+
+ TEST_CHECK(!IsMapped(src.addr()));
+ TEST_CHECK(IsMapped(dst.addr()));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, Fixed_SameSize_Unmapping) {
+ // Like the Fixed_SameSize case, but expect mremap to unmap the destination
+ // automatically.
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ void* addr = mremap(src.ptr(), kPageSize, kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr());
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == dst.ptr());
+ MaybeSave();
+
+ TEST_CHECK(!IsMapped(src.addr()));
+ TEST_CHECK(IsMapped(dst.addr()));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, Fixed_ShrinkingWholeVMA) {
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap dst so we can check that mremap does not keep the
+ // second page.
+ TEST_PCHECK(munmap(dst.ptr(), 2 * kPageSize) == 0);
+ MaybeSave();
+
+ void* addr = mremap(src.ptr(), 2 * kPageSize, kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr());
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == dst.ptr());
+ MaybeSave();
+
+ TEST_CHECK(!IsMapped(src.addr()));
+ TEST_CHECK(!IsMapped(src.addr() + kPageSize));
+ TEST_CHECK(IsMapped(dst.addr()));
+ TEST_CHECK(!IsMapped(dst.addr() + kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, Fixed_ShrinkingPartialVMA) {
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_NONE, GetParam()));
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap dst so we can check that mremap does not keep the
+ // second page.
+ TEST_PCHECK(munmap(dst.ptr(), 2 * kPageSize) == 0);
+ MaybeSave();
+
+ void* addr = mremap(src.ptr(), 2 * kPageSize, kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr());
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == dst.ptr());
+ MaybeSave();
+
+ TEST_CHECK(!IsMapped(src.addr()));
+ TEST_CHECK(!IsMapped(src.addr() + kPageSize));
+ TEST_CHECK(IsMapped(src.addr() + 2 * kPageSize));
+ TEST_CHECK(IsMapped(dst.addr()));
+ TEST_CHECK(!IsMapped(dst.addr() + kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, Fixed_ShrinkingAcrossVMAs) {
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(3 * kPageSize, PROT_READ, GetParam()));
+ // Changing permissions on the first page forces it to become a separate vma.
+ ASSERT_THAT(mprotect(src.ptr(), kPageSize, PROT_NONE), SyscallSucceeds());
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unlike flags=0, MREMAP_FIXED requires that [old_address,
+ // old_address+new_size) only spans a single vma.
+ void* addr = mremap(src.ptr(), 3 * kPageSize, 2 * kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr());
+ TEST_CHECK_MSG(addr == MAP_FAILED, "mremap unexpectedly succeeded");
+ TEST_PCHECK_MSG(errno == EFAULT, "mremap failed with wrong errno");
+ MaybeSave();
+
+ TEST_CHECK(IsMapped(src.addr()));
+ TEST_CHECK(IsMapped(src.addr() + kPageSize));
+ // Despite failing, mremap should have unmapped [old_address+new_size,
+ // old_address+old_size) (i.e. the third page).
+ TEST_CHECK(!IsMapped(src.addr() + 2 * kPageSize));
+ // Despite failing, mremap should have unmapped the destination pages.
+ TEST_CHECK(!IsMapped(dst.addr()));
+ TEST_CHECK(!IsMapped(dst.addr() + kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(MremapParamTest, Fixed_Expansion) {
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, GetParam()));
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(2 * kPageSize, PROT_NONE, GetParam()));
+
+ const auto rest = [&] {
+ // Unmap dst so we can check that mremap actually maps all pages
+ // at the destination.
+ TEST_PCHECK(munmap(dst.ptr(), 2 * kPageSize) == 0);
+ MaybeSave();
+
+ void* addr = mremap(src.ptr(), kPageSize, 2 * kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr());
+ TEST_PCHECK_MSG(addr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(addr == dst.ptr());
+ MaybeSave();
+
+ TEST_CHECK(!IsMapped(src.addr()));
+ TEST_CHECK(IsMapped(dst.addr()));
+ TEST_CHECK(IsMapped(dst.addr() + kPageSize));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+INSTANTIATE_TEST_SUITE_P(PrivateShared, MremapParamTest,
+ ::testing::Values(MAP_PRIVATE, MAP_SHARED));
+
+// mremap with old_size == 0 only works with MAP_SHARED after Linux 4.14
+// (dba58d3b8c50 "mm/mremap: fail map duplication attempts for private
+// mappings").
+
+TEST(MremapTest, InPlace_Copy) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_SHARED));
+ EXPECT_THAT(Mremap(m.ptr(), 0, kPageSize, 0, nullptr),
+ PosixErrorIs(ENOMEM, _));
+}
+
+TEST(MremapTest, MayMove_Copy) {
+ Mapping const m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_SHARED));
+
+ // Remainder of this test executes in a subprocess to ensure that if mremap
+ // incorrectly removes m, it is not remapped by another thread.
+ const auto rest = [&] {
+ void* ptr = mremap(m.ptr(), 0, kPageSize, MREMAP_MAYMOVE, nullptr);
+ MaybeSave();
+ TEST_PCHECK_MSG(ptr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(ptr != m.ptr());
+ TEST_CHECK(IsMapped(m.addr()));
+ TEST_CHECK(IsMapped(reinterpret_cast<uintptr_t>(ptr)));
+ };
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+TEST(MremapTest, MustMove_Copy) {
+ Mapping const src =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_SHARED));
+ Mapping const dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Remainder of this test executes in a subprocess to ensure that if mremap
+ // incorrectly removes src, it is not remapped by another thread.
+ const auto rest = [&] {
+ void* ptr = mremap(src.ptr(), 0, kPageSize, MREMAP_MAYMOVE | MREMAP_FIXED,
+ dst.ptr());
+ MaybeSave();
+ TEST_PCHECK_MSG(ptr != MAP_FAILED, "mremap failed");
+ TEST_CHECK(ptr == dst.ptr());
+ TEST_CHECK(IsMapped(src.addr()));
+ TEST_CHECK(IsMapped(dst.addr()));
+ };
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+}
+
+void ExpectAllBytesAre(absl::string_view v, char c) {
+ for (size_t i = 0; i < v.size(); i++) {
+ ASSERT_EQ(v[i], c) << "at offset " << i;
+ }
+}
+
+TEST(MremapTest, ExpansionPreservesCOWPagesAndExposesNewFilePages) {
+ // Create a file with 3 pages. The first is filled with 'a', the second is
+ // filled with 'b', and the third is filled with 'c'.
+ TempPath const file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ ASSERT_THAT(WriteFd(fd.get(), std::string(kPageSize, 'a').c_str(), kPageSize),
+ SyscallSucceedsWithValue(kPageSize));
+ ASSERT_THAT(WriteFd(fd.get(), std::string(kPageSize, 'b').c_str(), kPageSize),
+ SyscallSucceedsWithValue(kPageSize));
+ ASSERT_THAT(WriteFd(fd.get(), std::string(kPageSize, 'c').c_str(), kPageSize),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Create a private mapping of the first 2 pages, and fill the second page
+ // with 'd'.
+ Mapping const src = ASSERT_NO_ERRNO_AND_VALUE(Mmap(nullptr, 2 * kPageSize,
+ PROT_READ | PROT_WRITE,
+ MAP_PRIVATE, fd.get(), 0));
+ memset(reinterpret_cast<void*>(src.addr() + kPageSize), 'd', kPageSize);
+ MaybeSave();
+
+ // Move the mapping while expanding it to 3 pages. The resulting mapping
+ // should contain the original first page of the file (filled with 'a'),
+ // followed by the private copy of the second page (filled with 'd'), followed
+ // by the newly-mapped third page of the file (filled with 'c').
+ Mapping const dst = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(3 * kPageSize, PROT_NONE, MAP_PRIVATE));
+ ASSERT_THAT(Mremap(src.ptr(), 2 * kPageSize, 3 * kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
+ auto const v = dst.view();
+ ExpectAllBytesAre(v.substr(0, kPageSize), 'a');
+ ExpectAllBytesAre(v.substr(kPageSize, kPageSize), 'd');
+ ExpectAllBytesAre(v.substr(2 * kPageSize, kPageSize), 'c');
+}
+
+TEST(MremapDeathTest, SharedAnon) {
+ SetupGvisorDeathTest();
+
+ // Reserve 4 pages of address space.
+ Mapping const reserved = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(4 * kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Create a 2-page shared anonymous mapping at the beginning of the
+ // reservation. Fill the first page with 'a' and the second with 'b'.
+ Mapping const m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(reserved.ptr(), 2 * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_SHARED | MAP_ANONYMOUS | MAP_FIXED, -1, 0));
+ memset(m.ptr(), 'a', kPageSize);
+ memset(reinterpret_cast<void*>(m.addr() + kPageSize), 'b', kPageSize);
+ MaybeSave();
+
+ // Shrink the mapping to 1 page in-place.
+ ASSERT_THAT(Mremap(m.ptr(), 2 * kPageSize, kPageSize, 0, m.ptr()),
+ IsPosixErrorOkAndHolds(m.ptr()));
+
+ // Expand the mapping to 3 pages, moving it forward by 1 page in the process
+ // since the old and new mappings can't overlap.
+ void* const new_m = reinterpret_cast<void*>(m.addr() + kPageSize);
+ ASSERT_THAT(Mremap(m.ptr(), kPageSize, 3 * kPageSize,
+ MREMAP_MAYMOVE | MREMAP_FIXED, new_m),
+ IsPosixErrorOkAndHolds(new_m));
+
+ // The first 2 pages of the mapping should still contain the data we wrote
+ // (i.e. shrinking should not have discarded the second page's data), while
+ // touching the third page should raise SIGBUS.
+ auto const v =
+ absl::string_view(static_cast<char const*>(new_m), 3 * kPageSize);
+ ExpectAllBytesAre(v.substr(0, kPageSize), 'a');
+ ExpectAllBytesAre(v.substr(kPageSize, kPageSize), 'b');
+ EXPECT_EXIT(ExpectAllBytesAre(v.substr(2 * kPageSize, kPageSize), '\0'),
+ ::testing::KilledBySignal(SIGBUS), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/msync.cc b/test/syscalls/linux/msync.cc
new file mode 100644
index 000000000..2b2b6aef9
--- /dev/null
+++ b/test/syscalls/linux/msync.cc
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "test/util/file_descriptor.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Parameters for msync tests. Use a std::tuple so we can use
+// ::testing::Combine.
+using MsyncTestParam =
+ std::tuple<int, // msync flags
+ std::function<PosixErrorOr<Mapping>()> // returns mapping to
+ // msync
+ >;
+
+class MsyncParameterizedTest : public ::testing::TestWithParam<MsyncTestParam> {
+ protected:
+ int msync_flags() const { return std::get<0>(GetParam()); }
+
+ PosixErrorOr<Mapping> GetMapping() const { return std::get<1>(GetParam())(); }
+};
+
+// All valid msync(2) flag combinations, not including MS_INVALIDATE. ("Linux
+// permits a call to msync() that specifies neither [MS_SYNC or MS_ASYNC], with
+// semantics that are (currently) equivalent to specifying MS_ASYNC." -
+// msync(2))
+constexpr std::initializer_list<int> kMsyncFlags = {MS_SYNC, MS_ASYNC, 0};
+
+// Returns functions that return mappings that should be successfully
+// msync()able.
+std::vector<std::function<PosixErrorOr<Mapping>()>> SyncableMappings() {
+ std::vector<std::function<PosixErrorOr<Mapping>()>> funcs;
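+  // Produce an anonymous and a file-backed mapping for each combination of
+  // writability and MAP_PRIVATE/MAP_SHARED, 8 mappings in total.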
+ for (bool const writable : {false, true}) {
+ for (int const mflags : {MAP_PRIVATE, MAP_SHARED}) {
+ int const prot = PROT_READ | (writable ? PROT_WRITE : 0);
+ int const oflags = O_CREAT | (writable ? O_RDWR : O_RDONLY);
+ funcs.push_back([=] { return MmapAnon(kPageSize, prot, mflags); });
+ funcs.push_back([=]() -> PosixErrorOr<Mapping> {
+ std::string const path = NewTempAbsPath();
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, oflags, 0644));
+ // Don't unlink the file since that breaks save/restore. Just let the
+ // test infrastructure clean up all of our temporary files when we're
+ // done.
+ return Mmap(nullptr, kPageSize, prot, mflags, fd.get(), 0);
+ });
+ }
+ }
+ return funcs;
+}
+
+PosixErrorOr<Mapping> NoMappings() {
+ return PosixError(EINVAL, "unexpected attempt to create a mapping");
+}
+
+// "Fixture" for msync tests that hold for all valid flags, but do not create
+// mappings.
+using MsyncNoMappingTest = MsyncParameterizedTest;
+
+TEST_P(MsyncNoMappingTest, UnmappedAddressWithZeroLengthSucceeds) {
+ EXPECT_THAT(msync(nullptr, 0, msync_flags()), SyscallSucceeds());
+}
+
+TEST_P(MsyncNoMappingTest, UnmappedAddressWithNonzeroLengthFails) {
+ EXPECT_THAT(msync(nullptr, kPageSize, msync_flags()),
+ SyscallFailsWithErrno(ENOMEM));
+}
+
+INSTANTIATE_TEST_SUITE_P(All, MsyncNoMappingTest,
+ ::testing::Combine(::testing::ValuesIn(kMsyncFlags),
+ ::testing::Values(NoMappings)));
+
+// "Fixture" for msync tests that are not parameterized by msync flags, but do
+// create mappings.
+using MsyncNoFlagsTest = MsyncParameterizedTest;
+
+TEST_P(MsyncNoFlagsTest, BothSyncAndAsyncFails) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping());
+ EXPECT_THAT(msync(m.ptr(), m.len(), MS_SYNC | MS_ASYNC),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, MsyncNoFlagsTest,
+ ::testing::Combine(::testing::Values(0), // ignored
+ ::testing::ValuesIn(SyncableMappings())));
+
+// "Fixture" for msync tests parameterized by both msync flags and sources of
+// mappings.
+using MsyncFullParamTest = MsyncParameterizedTest;
+
+TEST_P(MsyncFullParamTest, NormallySucceeds) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping());
+ EXPECT_THAT(msync(m.ptr(), m.len(), msync_flags()), SyscallSucceeds());
+}
+
+TEST_P(MsyncFullParamTest, UnalignedLengthSucceeds) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping());
+ EXPECT_THAT(msync(m.ptr(), m.len() - 1, msync_flags()), SyscallSucceeds());
+}
+
+TEST_P(MsyncFullParamTest, UnalignedAddressFails) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping());
+ EXPECT_THAT(
+ msync(reinterpret_cast<void*>(m.addr() + 1), m.len() - 1, msync_flags()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(MsyncFullParamTest, InvalidateUnlockedSucceeds) {
+ auto m = ASSERT_NO_ERRNO_AND_VALUE(GetMapping());
+ EXPECT_THAT(msync(m.ptr(), m.len(), msync_flags() | MS_INVALIDATE),
+ SyscallSucceeds());
+}
+
+// The test for MS_INVALIDATE on mlocked pages is in mlock.cc since it requires
+// probing for mlock support.
+
+INSTANTIATE_TEST_SUITE_P(
+ All, MsyncFullParamTest,
+ ::testing::Combine(::testing::ValuesIn(kMsyncFlags),
+ ::testing::ValuesIn(SyncableMappings())));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/munmap.cc b/test/syscalls/linux/munmap.cc
new file mode 100644
index 000000000..067241f4d
--- /dev/null
+++ b/test/syscalls/linux/munmap.cc
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mman.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class MunmapTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ m_ = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ ASSERT_NE(MAP_FAILED, m_);
+ }
+
+ void* m_ = nullptr;
+};
+
+TEST_F(MunmapTest, HappyCase) {
+ EXPECT_THAT(munmap(m_, kPageSize), SyscallSucceeds());
+}
+
+TEST_F(MunmapTest, ZeroLength) {
+ EXPECT_THAT(munmap(m_, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(MunmapTest, LastPageRoundUp) {
+ // Attempt to unmap up to and including the last page.
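+  // The page-aligned range would extend past the end of the address space,
+  // so Linux rejects it with EINVAL.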
+ EXPECT_THAT(munmap(m_, static_cast<size_t>(-kPageSize + 1)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
new file mode 100644
index 000000000..133fdecf0
--- /dev/null
+++ b/test/syscalls/linux/network_namespace.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <net/if.h>
+#include <sched.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+TEST(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ScopedThread t([&] {
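+    // unshare(CLONE_NEWNET) only moves the calling thread into the new
+    // network namespace, so run it in a separate thread to keep the rest of
+    // the test process in the original namespace.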
+ ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0));
+
+ // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
+ // Check loopback device exists.
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+    struct ifreq ifr = {};
+ strncpy(ifr.ifr_name, "lo", IFNAMSIZ);
+ EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds())
+ << "lo cannot be found";
+ });
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
new file mode 100644
index 000000000..bb7d108e8
--- /dev/null
+++ b/test/syscalls/linux/open.cc
@@ -0,0 +1,451 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/capability.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// This test is currently very rudimentary.
+//
+// There are plenty of extra cases to cover once the sentry supports them.
+//
+// Different types of opens:
+// * O_CREAT
+// * O_DIRECTORY
+// * O_NOFOLLOW
+// * O_PATH <- Will we ever support this?
+//
+// Special operations on open:
+// * O_EXCL
+//
+// Special files:
+// * Blocking behavior for a named pipe.
+//
+// Different errors:
+// * EACCES
+// * EEXIST
+// * ENAMETOOLONG
+// * ELOOP
+// * ENOTDIR
+// * EPERM
+class OpenTest : public FileTest {
+ void SetUp() override {
+ FileTest::SetUp();
+
+ ASSERT_THAT(
+ write(test_file_fd_.get(), test_data_.c_str(), test_data_.length()),
+ SyscallSucceedsWithValue(test_data_.length()));
+ EXPECT_THAT(lseek(test_file_fd_.get(), 0, SEEK_SET), SyscallSucceeds());
+ }
+
+ public:
+ const std::string test_data_ = "hello world\n";
+};
+
+TEST_F(OpenTest, OTrunc) {
+ auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(OpenTest, OTruncAndReadOnlyDir) {
+ auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(OpenTest, OTruncAndReadOnlyFile) {
+ auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile");
+ const FileDescriptor existing =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666));
+ const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666));
+}
+
+TEST_F(OpenTest, ReadOnly) {
+ char buf;
+ const FileDescriptor ro_file =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+
+ EXPECT_THAT(read(ro_file.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_THAT(lseek(ro_file.get(), 0, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(write(ro_file.get(), &buf, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(OpenTest, WriteOnly) {
+ char buf;
+ const FileDescriptor wo_file =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_WRONLY));
+
+ EXPECT_THAT(read(wo_file.get(), &buf, 1), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(lseek(wo_file.get(), 0, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(write(wo_file.get(), &buf, 1), SyscallSucceedsWithValue(1));
+}
+
+TEST_F(OpenTest, ReadWrite) {
+ char buf;
+ const FileDescriptor rw_file =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ EXPECT_THAT(read(rw_file.get(), &buf, 1), SyscallSucceedsWithValue(1));
+ EXPECT_THAT(lseek(rw_file.get(), 0, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(write(rw_file.get(), &buf, 1), SyscallSucceedsWithValue(1));
+}
+
+TEST_F(OpenTest, RelPath) {
+ auto name = std::string(Basename(test_file_name_));
+
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name, O_RDONLY));
+}
+
+TEST_F(OpenTest, AbsPath) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+}
+
+TEST_F(OpenTest, AtRelPath) {
+ auto name = std::string(Basename(test_file_name_));
+ const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(GetAbsoluteTestTmpdir(), O_RDONLY | O_DIRECTORY));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAt(dirfd.get(), name, O_RDONLY));
+}
+
+TEST_F(OpenTest, AtAbsPath) {
+ const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(GetAbsoluteTestTmpdir(), O_RDONLY | O_DIRECTORY));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAt(dirfd.get(), test_file_name_, O_RDONLY));
+}
+
+TEST_F(OpenTest, OpenNoFollowSymlink) {
+ const std::string link_path = JoinPath(GetAbsoluteTestTmpdir(), "link");
+ ASSERT_THAT(symlink(test_file_name_.c_str(), link_path.c_str()),
+ SyscallSucceeds());
+ auto cleanup = Cleanup([link_path]() {
+ EXPECT_THAT(unlink(link_path.c_str()), SyscallSucceeds());
+ });
+
+  // Open succeeds without O_NOFOLLOW and fails with O_NOFOLLOW.
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(link_path, O_RDONLY));
+ ASSERT_THAT(open(link_path.c_str(), O_RDONLY | O_NOFOLLOW),
+ SyscallFailsWithErrno(ELOOP));
+}
+
+TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) {
+ // We will create the following structure:
+ // tmp_folder/real_folder/file
+ // tmp_folder/sym_folder -> tmp_folder/real_folder
+ //
+ // We will then open tmp_folder/sym_folder/file with O_NOFOLLOW and it
+ // should succeed as O_NOFOLLOW only applies to the final path component.
+ auto tmp_path =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir()));
+ auto sym_path = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), tmp_path.path()));
+ auto file_path =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(tmp_path.path()));
+
+ auto path_via_symlink = JoinPath(sym_path.path(), Basename(file_path.path()));
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW));
+}
+
+// Test that open(2) can follow symlinks that point back to the same tree.
+// Test sets up files as follows:
+// root/child/symlink => redirects to ../..
+// root/child/target => regular file
+//
+// open("root/child/symlink/root/child/file")
+TEST_F(OpenTest, SymlinkRecurse) {
+ auto root =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir()));
+ auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(child.path(), "../.."));
+ auto target = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(child.path(), "abc", 0644));
+ auto path_via_symlink =
+ JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()),
+ Basename(target.path()));
+ const auto contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink));
+ ASSERT_EQ(contents, "abc");
+}
+
+TEST_F(OpenTest, Fault) {
+ char* totally_not_null = nullptr;
+ ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(OpenTest, AppendOnly) {
+ // First write some data to the fresh file.
+ const int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize, 'a');
+
+ FileDescriptor fd0 = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
+
+ EXPECT_THAT(WriteFd(fd0.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ fd0.reset(); // Close the file early.
+
+ // Next get two handles to the same file. We open two files because we want
+ // to make sure that appending is respected between them.
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND));
+ EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND));
+ EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ // Then try to write to the first file and make sure the bytes are appended.
+ EXPECT_THAT(WriteFd(fd1.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Check that the size of the file is correct and that the offset has been
+ // incremented to that size.
+ struct stat s0;
+ EXPECT_THAT(fstat(fd1.get(), &s0), SyscallSucceeds());
+ EXPECT_EQ(s0.st_size, kBufSize * 2);
+ EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(kBufSize * 2));
+
+ // Then try to write to the second file and make sure the bytes are appended.
+ EXPECT_THAT(WriteFd(fd2.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Check that the size of the file is correct and that the offset has been
+ // incremented to that size.
+ struct stat s1;
+ EXPECT_THAT(fstat(fd2.get(), &s1), SyscallSucceeds());
+ EXPECT_EQ(s1.st_size, kBufSize * 3);
+ EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(kBufSize * 3));
+}
+
+TEST_F(OpenTest, AppendConcurrentWrite) {
+ constexpr int kThreadCount = 5;
+ constexpr int kBytesPerThread = 10000;
+ std::unique_ptr<ScopedThread> threads[kThreadCount];
+
+  // With the uncached policy, we expect that the file system can be changed
+  // externally, so we create a new inode each time we open a file and can't
+  // guarantee that writes to files with O_APPEND will work correctly.
+ SKIP_IF(getenv("GVISOR_GOFER_UNCACHED"));
+
+ EXPECT_THAT(truncate(test_file_name_.c_str(), 0), SyscallSucceeds());
+
+ std::string filename = test_file_name_;
+ DisableSave ds; // Too many syscalls.
+ // Start kThreadCount threads which will write concurrently into the same
+ // file.
+ for (int i = 0; i < kThreadCount; i++) {
+ threads[i] = absl::make_unique<ScopedThread>([filename]() {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_RDWR | O_APPEND));
+
+ for (int j = 0; j < kBytesPerThread; j++) {
+ EXPECT_THAT(WriteFd(fd.get(), &j, 1), SyscallSucceedsWithValue(1));
+ }
+ });
+ }
+ for (int i = 0; i < kThreadCount; i++) {
+ threads[i]->Join();
+ }
+
+ // Check that the size of the file is correct.
+ struct stat st;
+ EXPECT_THAT(stat(test_file_name_.c_str(), &st), SyscallSucceeds());
+ EXPECT_EQ(st.st_size, kThreadCount * kBytesPerThread);
+}
+
+TEST_F(OpenTest, Truncate) {
+ {
+ // First write some data to the new file and close it.
+ FileDescriptor fd0 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_WRONLY));
+ std::vector<char> orig(10, 'a');
+ EXPECT_THAT(WriteFd(fd0.get(), orig.data(), orig.size()),
+ SyscallSucceedsWithValue(orig.size()));
+ }
+
+ // Then open with truncate and verify that offset is set to 0.
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_TRUNC));
+ EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ // Then write less data to the file and ensure the old content is gone.
+ std::vector<char> want(5, 'b');
+ EXPECT_THAT(WriteFd(fd1.get(), want.data(), want.size()),
+ SyscallSucceedsWithValue(want.size()));
+
+ struct stat stat;
+ EXPECT_THAT(fstat(fd1.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(stat.st_size, want.size());
+ EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(want.size()));
+
+ // Read the data and ensure only the latest write is in the file.
+ std::vector<char> got(want.size() + 1, 'c');
+ ASSERT_THAT(pread(fd1.get(), got.data(), got.size(), 0),
+ SyscallSucceedsWithValue(want.size()));
+ EXPECT_EQ(memcmp(want.data(), got.data(), want.size()), 0)
+ << "rbuf=" << got.data();
+ EXPECT_EQ(got.back(), 'c'); // Last byte should not have been modified.
+}
+
+TEST_F(OpenTest, NameTooLong) {
+ char buf[4097] = {};
+  memset(buf, 'a', sizeof(buf) - 1);  // Leave the final byte as the NUL.
+ EXPECT_THAT(open(buf, O_RDONLY), SyscallFailsWithErrno(ENAMETOOLONG));
+}
+
+TEST_F(OpenTest, DotsFromRoot) {
+ const FileDescriptor rootfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/", O_RDONLY | O_DIRECTORY));
+ const FileDescriptor other_rootfd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAt(rootfd.get(), "..", O_RDONLY));
+}
+
+TEST_F(OpenTest, DirectoryWritableFails) {
+ ASSERT_THAT(open(GetAbsoluteTestTmpdir().c_str(), O_RDWR),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(OpenTest, FileNotDirectory) {
+ // Create a file and try to open it with O_DIRECTORY.
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_THAT(open(file.path().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST_F(OpenTest, Null) {
+ char c = '\0';
+ ASSERT_THAT(open(&c, O_RDONLY), SyscallFailsWithErrno(ENOENT));
+}
+
+// NOTE(b/119785738): While the man pages specify that this behavior should be
+// undefined, Linux truncates the file on opening read only if we have write
+// permission, so we will too.
+TEST_F(OpenTest, CanTruncateReadOnly) {
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY | O_TRUNC));
+
+ struct stat stat;
+ EXPECT_THAT(fstat(fd1.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(stat.st_size, 0);
+}
+
+// If we don't have write permission on the file, opening with
+// O_TRUNC should fail.
+TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission_NoRandomSave) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ const DisableSave ds; // Permissions are dropped.
+ ASSERT_THAT(chmod(test_file_name_.c_str(), S_IRUSR | S_IRGRP),
+ SyscallSucceeds());
+
+ ASSERT_THAT(open(test_file_name_.c_str(), O_RDONLY | O_TRUNC),
+ SyscallFailsWithErrno(EACCES));
+
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+
+ struct stat stat;
+ EXPECT_THAT(fstat(fd1.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(stat.st_size, test_data_.size());
+}
+
+// If we don't have read permission but have write permission, opening O_WRONLY
+// and O_TRUNC should succeed.
+TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) {
+ const DisableSave ds; // Permissions are dropped.
+
+ EXPECT_THAT(fchmod(test_file_fd_.get(), S_IWUSR | S_IWGRP),
+ SyscallSucceeds());
+
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_WRONLY | O_TRUNC));
+
+ EXPECT_THAT(fchmod(test_file_fd_.get(), S_IRUSR | S_IRGRP),
+ SyscallSucceeds());
+
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+
+ struct stat stat;
+ EXPECT_THAT(fstat(fd2.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(stat.st_size, 0);
+}
+
+TEST_F(OpenTest, CanTruncateWithStrangePermissions) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ const DisableSave ds; // Permissions are dropped.
+ std::string path = NewTempAbsPath();
+ int fd;
+ // Create a file without user permissions.
+ EXPECT_THAT( // SAVE_BELOW
+ fd = open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ // Cannot open file because we are owner and have no permissions set.
+ EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+
+ // We *can* chmod the file, because we are the owner.
+ EXPECT_THAT(chmod(path.c_str(), 0755), SyscallSucceeds());
+
+ // Now we can open the file again.
+ EXPECT_THAT(fd = open(path.c_str(), O_RDWR), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string bad_path = file.path() + "/";
+ EXPECT_THAT(open(bad_path.c_str(), O_RDONLY), SyscallFailsWithErrno(ENOTDIR));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
new file mode 100644
index 000000000..51eacf3f2
--- /dev/null
+++ b/test/syscalls/linux/open_create.cc
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+TEST(CreateTest, TmpFile) {
+ int fd;
+ EXPECT_THAT(fd = open(JoinPath(GetAbsoluteTestTmpdir(), "a").c_str(),
+ O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(CreateTest, ExistingFile) {
+ int fd;
+ EXPECT_THAT(
+ fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(),
+ O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ EXPECT_THAT(
+ fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(),
+ O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(CreateTest, CreateAtFile) {
+ int dirfd;
+ EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(openat(dirfd, "CreateAtFile", O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(dirfd), SyscallSucceeds());
+}
+
+TEST(CreateTest, HonorsUmask_NoRandomSave) {
+ const DisableSave ds; // file cannot be re-opened as writable.
+ TempUmask mask(0222);
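+  // With a umask of 0222, the requested mode 0666 yields 0666 & ~0222 = 0444.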
+ int fd;
+ ASSERT_THAT(
+ fd = open(JoinPath(GetAbsoluteTestTmpdir(), "UmaskedFile").c_str(),
+ O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ struct stat statbuf;
+ ASSERT_THAT(fstat(fd, &statbuf), SyscallSucceeds());
+ EXPECT_EQ(0444, statbuf.st_mode & 0777);
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(CreateTest, CreateExclusively) {
+ std::string filename = NewTempAbsPath();
+
+ int fd;
+ ASSERT_THAT(fd = open(filename.c_str(), O_CREAT | O_RDWR, 0644),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ EXPECT_THAT(open(filename.c_str(), O_CREAT | O_EXCL | O_RDWR, 0644),
+ SyscallFailsWithErrno(EEXIST));
+}
+
+TEST(CreateTest, CreatWithOTrunc) {
+ std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(CreateTest, CreatDirWithOTruncAndReadOnly) {
+ std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(CreateTest, CreatFileWithOTruncAndReadOnly) {
+  std::string path = JoinPath(GetAbsoluteTestTmpdir(), "truncfile");
+  int fd1;
+  ASSERT_THAT(fd1 = open(path.c_str(), O_RDWR | O_CREAT, 0666),
+              SyscallSucceeds());
+  int fd2;
+  ASSERT_THAT(fd2 = open(path.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666),
+              SyscallSucceeds());
+  ASSERT_THAT(close(fd1), SyscallSucceeds());
+  ASSERT_THAT(close(fd2), SyscallSucceeds());
+}
+
+TEST(CreateTest, CreateFailsOnUnpermittedDir) {
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // always override directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_THAT(open("/foo", O_CREAT | O_RDWR, 0644),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) {
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // always override directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ auto parent = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555));
+ auto file = JoinPath(parent.path(), "foo");
+ ASSERT_THAT(open(file.c_str(), O_CREAT | O_RDWR, 0644),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// A file originally created RW, but opened RO can later be opened RW.
+// Regression test for b/65385065.
+TEST(CreateTest, OpenCreateROThenRW) {
+ TempPath file(NewTempAbsPath());
+
+ // Create a RW file, but only open it RO.
+ FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(file.path(), O_CREAT | O_EXCL | O_RDONLY, 0644));
+
+ // Now get a RW FD.
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ // fd1 is not writable, but fd2 is.
+ char c = 'a';
+ EXPECT_THAT(WriteFd(fd1.get(), &c, 1), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
new file mode 100644
index 000000000..5ac68feb4
--- /dev/null
+++ b/test/syscalls/linux/packet_socket.cc
@@ -0,0 +1,440 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <ifaddrs.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_packet.h>
+#include <net/ethernet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/udp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Some of these tests involve sending packets via AF_PACKET sockets and the
+// loopback interface. Because AF_PACKET circumvents so much of the networking
+// stack, Linux sees these packets as "martian", i.e. they claim to be to/from
+// localhost but don't have the usual associated data. Thus Linux drops them by
+// default. You can see where this happens by following the code at:
+//
+// - net/ipv4/ip_input.c:ip_rcv_finish, which calls
+// - net/ipv4/route.c:ip_route_input_noref, which calls
+// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian
+// packets.
+//
+// To tell Linux not to drop these packets, you need to tell it to accept our
+// funny packets (which are completely valid and correct, but lack associated
+// in-kernel data because we use AF_PACKET):
+//
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet
+//
+// These tests require CAP_NET_RAW to run.
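+//
+// Note that CookedPacketTest::SetUp() below only verifies that these sysctls
+// are enabled when running on native Linux; it does not modify them.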
+
+// TODO(gvisor.dev/issue/173): gVisor support.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+constexpr char kMessage[] = "soweoneul malhaebwa";
+constexpr in_port_t kPort = 0x409c; // htons(40000)
+
+//
+// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer
+// headers, and provide link layer destination/source information via a
+// returned struct sockaddr_ll.
+//
+
+// Send kMessage via sock to loopback
+void SendUDPMessage(int sock) {
+ struct sockaddr_in dest = {};
+ dest.sin_port = kPort;
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+}
+
+// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
+TEST(BasicCookedPacketTest, WrongType) {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+    ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP)),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP)));
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Wait and make sure the socket never becomes readable.
+ struct pollfd pfd = {};
+ pfd.fd = sock.get();
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets.
+class CookedPacketTest : public ::testing::TestWithParam<int> {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // Gets the device index of the loopback device.
+ int GetLoopbackIndex();
+
+ // The socket used for both reading and writing.
+ int socket_;
+};
+
+void CookedPacketTest::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ if (!IsRunningOnGvisor()) {
+ FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
+ FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY));
+ char enabled;
+ ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ }
+
+ ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallSucceeds());
+}
+
+void CookedPacketTest::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
+}
+
+int CookedPacketTest::GetLoopbackIndex() {
+  struct ifreq ifr = {};
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+ return ifr.ifr_ifindex;
+}
+
+// Receive and verify the message via packet socket on interface.
+void ReceiveMessage(int sock, int ifindex) {
+ // Wait for the socket to become readable.
+ struct pollfd pfd = {};
+ pfd.fd = sock;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
+
+ // Read and verify the data.
+ constexpr size_t packet_size =
+ sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
+ char buf[64];
+ struct sockaddr_ll src = {};
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(packet_size));
+
+  // sockaddr_ll ends with an 8-byte physical address field, but Ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+
+ // TODO(b/129292371): Verify protocol once we return it.
+ // Verify the source address.
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, ifindex);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ // This came from the loopback device, so the address is all 0s.
+ for (int i = 0; i < src.sll_halen; i++) {
+ EXPECT_EQ(src.sll_addr[i], 0);
+ }
+
+  // Verify the IP header. We memcpy to deal with pointer alignment.
+ struct iphdr ip = {};
+ memcpy(&ip, buf, sizeof(ip));
+ EXPECT_EQ(ip.ihl, 5);
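+ // (ihl counts 32-bit words, so 5 means a 20-byte header with no options.)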
+ EXPECT_EQ(ip.version, 4);
+ EXPECT_EQ(ip.tot_len, htons(packet_size));
+ EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK));
+
+ // Verify the UDP header. We memcpy to deal with pointer alignment.
+ struct udphdr udp = {};
+ memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
+ EXPECT_EQ(udp.dest, kPort);
+ EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+ // Verify the payload.
+ char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
+ EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ int loopback_index = GetLoopbackIndex();
+ ReceiveMessage(socket_, loopback_index);
+}
+
+// Send via a packet socket.
+TEST_P(CookedPacketTest, Send) {
+ // TODO(b/129292371): Remove once we support packet socket writing.
+ SKIP_IF(IsRunningOnGvisor());
+
+ // Let's send a UDP packet and receive it using a regular UDP socket.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ struct sockaddr_in bind_addr = {};
+ bind_addr.sin_family = AF_INET;
+ bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ bind_addr.sin_port = kPort;
+ ASSERT_THAT(
+ bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Set up the destination physical address.
+ struct sockaddr_ll dest = {};
+ dest.sll_family = AF_PACKET;
+ dest.sll_halen = ETH_ALEN;
+ dest.sll_ifindex = GetLoopbackIndex();
+ dest.sll_protocol = htons(ETH_P_IP);
+ // We're sending to the loopback device, so the address is all 0s.
+ memset(dest.sll_addr, 0x00, ETH_ALEN);
+
+ // Set up the IP header.
+ struct iphdr iphdr = {0};
+ iphdr.ihl = 5;
+ iphdr.version = 4;
+ iphdr.tos = 0;
+ iphdr.tot_len =
+ htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage));
+ // Get a pseudo-random ID. If we clash with an in-use ID the test will fail,
+ // but we have no way of getting an ID we know to be good.
+ srand(*reinterpret_cast<unsigned int*>(&iphdr));
+ iphdr.id = rand();
+ // Linux sets this bit ("do not fragment") for small packets.
+ iphdr.frag_off = 1 << 6;
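+ // (On a little-endian host, 1 << 6 equals htons(0x4000), i.e. htons(IP_DF):
+ // the DF flag is bit 0x4000 of frag_off in host byte order.)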
+ iphdr.ttl = 64;
+ iphdr.protocol = IPPROTO_UDP;
+ iphdr.daddr = htonl(INADDR_LOOPBACK);
+ iphdr.saddr = htonl(INADDR_LOOPBACK);
+ iphdr.check = IPChecksum(iphdr);
+
+ // Set up the UDP header.
+ struct udphdr udphdr = {};
+ udphdr.source = kPort;
+ udphdr.dest = kPort;
+ udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage));
+ udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage));
+
+ // Copy both headers and the payload into our packet buffer.
+ char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)];
+ memcpy(send_buf, &iphdr, sizeof(iphdr));
+ memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr));
+ memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage));
+
+ // Send it.
+ ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Wait for the packet to become available on both sockets.
+ struct pollfd pfd = {};
+ pfd.fd = udp_sock.get();
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+
+ // Receive on the packet socket.
+ char recv_buf[sizeof(send_buf)];
+ ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0);
+
+ // Receive on the UDP socket.
+ struct sockaddr_in src;
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+ // Check src and payload.
+ EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0);
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(src.sin_port, kPort);
+ EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
+
+// Bind and receive via packet socket.
+TEST_P(CookedPacketTest, BindReceive) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ ReceiveMessage(socket_, bind_addr.sll_ifindex);
+}
+
+// Binding the same socket twice should fail.
+TEST_P(CookedPacketTest, DoubleBind) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Binding socket again should fail.
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ // Linux 4.9 returns EINVAL here, but some time before 4.19 it switched
+ // to EADDRINUSE.
+ AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL)));
+}
+
+// Bind to a non-loopback interface and verify that traffic sent via the
+// loopback interface is not received.
+TEST_P(CookedPacketTest, BindDrop) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Binding a packet socket requires only the family, protocol, and ifindex.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = ifr.ifr_ifindex;
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Send to loopback interface.
+ struct sockaddr_in dest = {};
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket never receives any data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Bind with invalid address.
+TEST_P(CookedPacketTest, BindFail) {
+ // Null address.
+ ASSERT_THAT(
+ bind(socket_, nullptr, sizeof(struct sockaddr)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL)));
+
+ // Address of size 1.
+ uint8_t addr = 0;
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
+ ::testing::Values(ETH_P_IP, ETH_P_ALL));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
new file mode 100644
index 000000000..4093ac813
--- /dev/null
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -0,0 +1,565 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_packet.h>
+#include <net/ethernet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/udp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Some of these tests involve sending packets via AF_PACKET sockets and the
+// loopback interface. Because AF_PACKET circumvents so much of the networking
+// stack, Linux sees these packets as "martian", i.e. they claim to be to/from
+// localhost but don't have the usual associated data. Thus Linux drops them by
+// default. You can see where this happens by following the code at:
+//
+// - net/ipv4/ip_input.c:ip_rcv_finish, which calls
+// - net/ipv4/route.c:ip_route_input_noref, which calls
+// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian
+// packets.
+//
+// To tell Linux not to drop these packets, you need to tell it to accept our
+// funny packets (which are completely valid and correct, but lack associated
+// in-kernel data because we use AF_PACKET):
+//
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet
+//
+// These tests require CAP_NET_RAW to run.
+
+// TODO(gvisor.dev/issue/173): gVisor support.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+constexpr char kMessage[] = "soweoneul malhaebwa";
+constexpr in_port_t kPort = 0x409c; // htons(40000)
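+// (40000 == 0x9C40; storing the bytes swapped means the constant is already
+// in network byte order on little-endian hosts, so it can be assigned to
+// sin_port and compared against received headers directly.)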
+
+// Send kMessage via sock to loopback.
+void SendUDPMessage(int sock) {
+ struct sockaddr_in dest = {};
+ dest.sin_port = kPort;
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+}
+
+//
+// Raw tests. Packets sent with raw AF_PACKET sockets always include link layer
+// headers.
+//
+
+// Tests for "raw" (SOCK_RAW) packet(7) sockets.
+class RawPacketTest : public ::testing::TestWithParam<int> {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // Gets the device index of the loopback device.
+ int GetLoopbackIndex();
+
+ // The socket used for both reading and writing.
+ int s_;
+};
+
+void RawPacketTest::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ if (!IsRunningOnGvisor()) {
+ // Ensure that looped back packets aren't rejected by the kernel.
+ FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDWR));
+ FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDWR));
+ char enabled;
+ ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ if (enabled != '1') {
+ enabled = '1';
+ ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(write(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ }
+
+ ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ if (enabled != '1') {
+ enabled = '1';
+ ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(write(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ }
+ }
+
+ ASSERT_THAT(s_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ SyscallSucceeds());
+}
+
+void RawPacketTest::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
+}
+
+int RawPacketTest::GetLoopbackIndex() {
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ EXPECT_THAT(ioctl(s_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+ return ifr.ifr_ifindex;
+}
+
+// Receive via a packet socket.
+TEST_P(RawPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Wait for the socket to become readable.
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
+
+ // Read and verify the data.
+ constexpr size_t packet_size = sizeof(struct ethhdr) + sizeof(struct iphdr) +
+ sizeof(struct udphdr) + sizeof(kMessage);
+ char buf[64];
+ struct sockaddr_ll src = {};
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(s_, buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(packet_size));
+ // sockaddr_ll ends with an 8-byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c it returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+
+ // TODO(b/129292371): Verify protocol once we return it.
+ // Verify the source address.
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ // This came from the loopback device, so the address is all 0s.
+ for (int i = 0; i < src.sll_halen; i++) {
+ EXPECT_EQ(src.sll_addr[i], 0);
+ }
+
+ // Verify the ethernet header. We memcpy to deal with pointer alignment.
+ struct ethhdr eth = {};
+ memcpy(&eth, buf, sizeof(eth));
+ // The destination and source addresses should be 0 for loopback.
+ for (int i = 0; i < ETH_ALEN; i++) {
+ EXPECT_EQ(eth.h_dest[i], 0);
+ EXPECT_EQ(eth.h_source[i], 0);
+ }
+ EXPECT_EQ(eth.h_proto, htons(ETH_P_IP));
+
+ // Verify the IP header. We memcpy to deal with pointer alignment.
+ struct iphdr ip = {};
+ memcpy(&ip, buf + sizeof(ethhdr), sizeof(ip));
+ EXPECT_EQ(ip.ihl, 5);
+ EXPECT_EQ(ip.version, 4);
+ EXPECT_EQ(ip.tot_len, htons(packet_size - sizeof(eth)));
+ EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK));
+
+ // Verify the UDP header. We memcpy to deal with pointer alignment.
+ struct udphdr udp = {};
+ memcpy(&udp, buf + sizeof(eth) + sizeof(iphdr), sizeof(udp));
+ EXPECT_EQ(udp.dest, kPort);
+ EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+ // Verify the payload.
+ char* payload = reinterpret_cast<char*>(buf + sizeof(eth) + sizeof(iphdr) +
+ sizeof(udphdr));
+ EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
+// Send via a packet socket.
+TEST_P(RawPacketTest, Send) {
+ // TODO(b/129292371): Remove once we support packet socket writing.
+ SKIP_IF(IsRunningOnGvisor());
+
+ // Let's send a UDP packet and receive it using a regular UDP socket.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ struct sockaddr_in bind_addr = {};
+ bind_addr.sin_family = AF_INET;
+ bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ bind_addr.sin_port = kPort;
+ ASSERT_THAT(
+ bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Set up the destination physical address.
+ struct sockaddr_ll dest = {};
+ dest.sll_family = AF_PACKET;
+ dest.sll_halen = ETH_ALEN;
+ dest.sll_ifindex = GetLoopbackIndex();
+ dest.sll_protocol = htons(ETH_P_IP);
+ // We're sending to the loopback device, so the address is all 0s.
+ memset(dest.sll_addr, 0x00, ETH_ALEN);
+
+ // Set up the ethernet header. The kernel takes care of the footer.
+ // We're sending to and from hardware address 0 (loopback).
+ struct ethhdr eth = {};
+ eth.h_proto = htons(ETH_P_IP);
+
+ // Set up the IP header.
+ struct iphdr iphdr = {};
+ iphdr.ihl = 5;
+ iphdr.version = 4;
+ iphdr.tos = 0;
+ iphdr.tot_len =
+ htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage));
+ // Get a pseudo-random ID. If we clash with an in-use ID the test will fail,
+ // but we have no way of getting an ID we know to be good.
+ srand(*reinterpret_cast<unsigned int*>(&iphdr));
+ iphdr.id = rand();
+ // Linux sets this bit ("do not fragment") for small packets.
+ iphdr.frag_off = 1 << 6;
+ iphdr.ttl = 64;
+ iphdr.protocol = IPPROTO_UDP;
+ iphdr.daddr = htonl(INADDR_LOOPBACK);
+ iphdr.saddr = htonl(INADDR_LOOPBACK);
+ iphdr.check = IPChecksum(iphdr);
+
+ // Set up the UDP header.
+ struct udphdr udphdr = {};
+ udphdr.source = kPort;
+ udphdr.dest = kPort;
+ udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage));
+ udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage));
+
+ // Copy both headers and the payload into our packet buffer.
+ char
+ send_buf[sizeof(eth) + sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)];
+ memcpy(send_buf, &eth, sizeof(eth));
+ memcpy(send_buf + sizeof(ethhdr), &iphdr, sizeof(iphdr));
+ memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr), &udphdr, sizeof(udphdr));
+ memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage,
+ sizeof(kMessage));
+
+ // Send it.
+ ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Wait for the packet to become available on both sockets.
+ struct pollfd pfd = {};
+ pfd.fd = udp_sock.get();
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+ pfd.fd = s_;
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+
+ // Receive on the packet socket.
+ char recv_buf[sizeof(send_buf)];
+ ASSERT_THAT(recv(s_, recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0);
+
+ // Receive on the UDP socket.
+ struct sockaddr_in src;
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+ // Check src and payload.
+ EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0);
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(src.sin_port, kPort);
+ EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
+
+// Check that setting SO_RCVBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(RawPacketTest, SetSocketRecvBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum receive buf size by trying to set it to zero.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value, so let's use a value that, when doubled, will
+ // still be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_RCVBUF above max is clamped to the maximum
+// receive buffer size.
+TEST_P(RawPacketTest, SetSocketRecvBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover max buf size by trying to set the largest possible buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_RCVBUF to a value with min <= kRcvBufSz <= max is
+// honored.
+TEST_P(RawPacketTest, SetSocketRecvBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover max buf size by trying to set a really large buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set a zero size receive buffer
+ // size.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
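+ // (Per socket(7), this doubling leaves room for bookkeeping overhead, so a
+ // setsockopt() of N reads back as 2 * N from getsockopt().)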
+ // TODO(gvisor.dev/issue/2926): Remove when Netstack matches Linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Check that setting SO_SNDBUF below min is clamped to the minimum
+// send buffer size.
+TEST_P(RawPacketTest, SetSocketSendBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value, so let's use a value that, when doubled, will
+ // still be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_SNDBUF above max is clamped to the maximum
+// send buffer size.
+TEST_P(RawPacketTest, SetSocketSendBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_SNDBUF to a value with min <= kSndBufSz <= max is
+// honored.
+TEST_P(RawPacketTest, SetSocketSendBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ // TODO(gvisor.dev/issue/2926): Remove the gVisor special-casing when
+ // Netstack matches Linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+
+ ASSERT_EQ(quarter_sz, val);
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
+ ::testing::Values(ETH_P_IP, ETH_P_ALL));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
new file mode 100644
index 000000000..df7129acc
--- /dev/null
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -0,0 +1,405 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::Gt;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr char kMessage[] = "hello world";
+
+// PartialBadBufferTest checks the result of various IO syscalls when passed a
+// buffer that does not have the space specified in the syscall (most of it is
+// PROT_NONE). Linux is annoyingly inconsistent among different syscalls, so we
+// test all of them.
+class PartialBadBufferTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Create and open a directory for getdents cases.
+ directory_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ directory_fd_ = open(directory_.path().c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+
+ // Create and open a normal file, placing it in the directory
+ // so the getdents cases have some dirents.
+ name_ = JoinPath(directory_.path(), "a");
+ ASSERT_THAT(fd_ = open(name_.c_str(), O_RDWR | O_CREAT, 0644),
+ SyscallSucceeds());
+
+ // Write some initial data.
+ size_t size = sizeof(kMessage) - 1;
+ EXPECT_THAT(WriteFd(fd_, &kMessage, size), SyscallSucceedsWithValue(size));
+ ASSERT_THAT(lseek(fd_, 0, SEEK_SET), SyscallSucceeds());
+
+ // Map a usable buffer.
+ addr_ = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ ASSERT_NE(addr_, MAP_FAILED);
+ char* buf = reinterpret_cast<char*>(addr_);
+
+ // Guard page for our read to run into.
+ ASSERT_THAT(mprotect(reinterpret_cast<void*>(buf + kPageSize), kPageSize,
+ PROT_NONE),
+ SyscallSucceeds());
+
+ // Leave only one free byte in the buffer.
+ bad_buffer_ = buf + kPageSize - 1;
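+ //
+ // Resulting layout: [ RW page ... X ][ PROT_NONE guard page ], where X is
+ // bad_buffer_, the single accessible byte before the guard page.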
+ }
+
+ off_t Size() {
+ struct stat st;
+ int rc = fstat(fd_, &st);
+ if (rc < 0) {
+ return static_cast<off_t>(rc);
+ }
+ return st.st_size;
+ }
+
+ void TearDown() override {
+ EXPECT_THAT(munmap(addr_, 2 * kPageSize), SyscallSucceeds()) << addr_;
+ EXPECT_THAT(close(fd_), SyscallSucceeds());
+ EXPECT_THAT(unlink(name_.c_str()), SyscallSucceeds());
+ EXPECT_THAT(close(directory_fd_), SyscallSucceeds());
+ }
+
+ // Return buffer with n bytes of free space.
+ // N.B. this is the same buffer used to back bad_buffer_.
+ char* FreeBytes(size_t n) {
+ TEST_CHECK(n <= static_cast<size_t>(4096));
+ return reinterpret_cast<char*>(addr_) + kPageSize - n;
+ }
+
+ std::string name_;
+ int fd_;
+ TempPath directory_;
+ int directory_fd_;
+ void* addr_;
+ char* bad_buffer_;
+};
+
+// We do both "big" and "small" tests to try to hit the "zero copy" and
+// non-"zero copy" paths, which have different code paths for handling faults.
+
+TEST_F(PartialBadBufferTest, ReadBig) {
+ EXPECT_THAT(RetryEINTR(read)(fd_, bad_buffer_, kPageSize),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, ReadSmall) {
+ EXPECT_THAT(RetryEINTR(read)(fd_, bad_buffer_, 10),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, PreadBig) {
+ EXPECT_THAT(RetryEINTR(pread)(fd_, bad_buffer_, kPageSize, 0),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, PreadSmall) {
+ EXPECT_THAT(RetryEINTR(pread)(fd_, bad_buffer_, 10, 0),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, ReadvBig) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = kPageSize;
+
+ EXPECT_THAT(RetryEINTR(readv)(fd_, &vec, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, ReadvSmall) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = 10;
+
+ EXPECT_THAT(RetryEINTR(readv)(fd_, &vec, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, PreadvBig) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = kPageSize;
+
+ EXPECT_THAT(RetryEINTR(preadv)(fd_, &vec, 1, 0), SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, PreadvSmall) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = 10;
+
+ EXPECT_THAT(RetryEINTR(preadv)(fd_, &vec, 1, 0), SyscallSucceedsWithValue(1));
+ EXPECT_EQ('h', bad_buffer_[0]);
+}
+
+TEST_F(PartialBadBufferTest, WriteBig) {
+ off_t orig_size = Size();
+ int n;
+
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, kPageSize)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, WriteSmall) {
+ off_t orig_size = Size();
+ int n;
+
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, 10)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, PwriteBig) {
+ off_t orig_size = Size();
+ int n;
+
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, PwriteSmall) {
+ off_t orig_size = Size();
+ int n;
+
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, 10, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, WritevBig) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
+
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, WritevSmall) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
+
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, PwritevBig) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
+
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+TEST_F(PartialBadBufferTest, PwritevSmall) {
+ struct iovec vec;
+ vec.iov_base = bad_buffer_;
+ vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
+
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
+}
+
+// getdents returns EFAULT when you claim the buffer is large enough, but it
+// actually isn't.
+TEST_F(PartialBadBufferTest, GetdentsBig) {
+ EXPECT_THAT(RetryEINTR(syscall)(SYS_getdents64, directory_fd_, bad_buffer_,
+ kPageSize),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+// getdents returns EINVAL when you claim the buffer is too small.
+TEST_F(PartialBadBufferTest, GetdentsSmall) {
+ EXPECT_THAT(
+ RetryEINTR(syscall)(SYS_getdents64, directory_fd_, bad_buffer_, 10),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// getdents will write entries into a buffer if there is space before it faults.
+TEST_F(PartialBadBufferTest, GetdentsOneEntry) {
+ // 30 bytes is enough for one (small) entry.
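+ // (A linux_dirent64 has 19 bytes of fixed header plus the NUL-terminated
+ // name, rounded up for alignment, i.e. roughly 24 bytes for the file "a".)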
+ char* buf = FreeBytes(30);
+
+ EXPECT_THAT(
+ RetryEINTR(syscall)(SYS_getdents64, directory_fd_, buf, kPageSize),
+ SyscallSucceedsWithValue(Gt(0)));
+}
+
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// SendMsgTCP verifies that calling sendmsg with a bad address returns
+// EFAULT. It also verifies that passing a buffer made up of two pages, one
+// valid and one guard page, succeeds as long as the write is for exactly the
+// size of one page.
+TEST_F(PartialBadBufferTest, SendMsgTCP) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ // TODO(gvisor.dev/issue/674): Update this once Netstack matches Linux
+ // behavior on a setsockopt of SO_RCVBUF/SO_SNDBUF.
+ //
+ // Set SO_SNDBUF for the socket to exactly kPageSize+1.
+ //
+ // gVisor does not double the value passed in SO_SNDBUF like Linux does, so
+ // we increase it by 1 byte here so that we can test writing 1 byte past the
+ // valid page and check that it correctly triggers an EFAULT. Otherwise, in
+ // gVisor the sendmsg call would just return with no error after writing
+ // kPageSize bytes successfully.
+ const uint32_t buf_size = kPageSize + 1;
+ ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size,
+ sizeof(buf_size)),
+ SyscallSucceedsWithValue(0));
+
+ struct msghdr hdr = {};
+ struct iovec iov = {};
+ iov.iov_base = bad_buffer_;
+ iov.iov_len = kPageSize;
+ hdr.msg_iov = &iov;
+ hdr.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+ SyscallFailsWithErrno(EFAULT));
+
+ // Now assert that writing kPageSize from addr_ succeeds.
+ iov.iov_base = addr_;
+ ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+ SyscallSucceedsWithValue(kPageSize));
+ // Read all the data out so that we drain the socket SND_BUF on the sender.
+ std::vector<char> buffer(kPageSize);
+ ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Sleep for a short while to ensure that we have time to process the ACKs.
+ // This is not strictly required except when running under gotsan, which is a
+ // lot slower and can result in the next write writing only 1 byte instead of
+ // our intended kPageSize + 1.
+ absl::SleepFor(absl::Milliseconds(50));
+
+ // Now assert that writing > kPageSize results in EFAULT.
+ iov.iov_len = kPageSize + 1;
+ ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pause.cc b/test/syscalls/linux/pause.cc
new file mode 100644
index 000000000..8c05efd6f
--- /dev/null
+++ b/test/syscalls/linux/pause.cc
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <atomic>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void NoopSignalHandler(int sig, siginfo_t* info, void* context) {}
+
+} // namespace
+
+TEST(PauseTest, OnlyReturnsWhenSignalHandled) {
+ struct sigaction sa;
+ sigfillset(&sa.sa_mask);
+
+ // Ensure that SIGUSR1 is ignored.
+ sa.sa_handler = SIG_IGN;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ // Register a handler for SIGUSR2.
+ sa.sa_sigaction = NoopSignalHandler;
+ sa.sa_flags = SA_SIGINFO;
+ ASSERT_THAT(sigaction(SIGUSR2, &sa, nullptr), SyscallSucceeds());
+
+ // The child sets its own tid.
+ absl::Mutex mu;
+ pid_t child_tid = 0;
+ bool child_tid_available = false;
+ std::atomic<int> sent_signal{0};
+ std::atomic<int> waking_signal{0};
+ ScopedThread t([&] {
+ mu.Lock();
+ child_tid = gettid();
+ child_tid_available = true;
+ mu.Unlock();
+ EXPECT_THAT(pause(), SyscallFailsWithErrno(EINTR));
+ waking_signal.store(sent_signal.load());
+ });
+ mu.Lock();
+ mu.Await(absl::Condition(&child_tid_available));
+ mu.Unlock();
+
+ // Wait a bit to let the child enter pause().
+ absl::SleepFor(absl::Seconds(3));
+
+ // The child should not be woken by SIGUSR1.
+ sent_signal.store(SIGUSR1);
+ ASSERT_THAT(tgkill(getpid(), child_tid, SIGUSR1), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(3));
+
+ // The child should be woken by SIGUSR2.
+ sent_signal.store(SIGUSR2);
+ ASSERT_THAT(tgkill(getpid(), child_tid, SIGUSR2), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(3));
+
+ EXPECT_EQ(SIGUSR2, waking_signal.load());
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
new file mode 100644
index 000000000..a9bfdb37b
--- /dev/null
+++ b/test/syscalls/linux/ping_socket.cc
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class PingSocket : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+};
+
+void PingSocket::SetUp() {
+ // On some hosts ping sockets are restricted to specific groups using the
+ // sysctl "ping_group_range".
+ int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
+ if (s < 0 && errno == EPERM) {
+ GTEST_SKIP();
+ }
+ close(s);
+
+ addr_ = {};
+ // Use an arbitrary port; the destination port number is irrelevant for
+ // ping sockets.
+ addr_.sin_port = 12345;
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void PingSocket::TearDown() {}
+
+// Test ICMP port exhaustion returns EAGAIN.
+//
+// We disable both random and cooperative S/R for this test as it makes far
+// too many syscalls.
+TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) {
+ DisableSave ds;
+ std::vector<FileDescriptor> sockets;
+ constexpr int kSockets = 65536;
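+ // (Ping sockets use the 16-bit ICMP echo identifier as their local "port",
+ // so at most 65536 can be connected before the space is exhausted.)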
+ addr_.sin_port = 0;
+ for (int i = 0; i < kSockets; i++) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+ int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_));
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
new file mode 100644
index 000000000..34291850d
--- /dev/null
+++ b/test/syscalls/linux/pipe.cc
@@ -0,0 +1,670 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h> /* Obtain O_* constant definitions */
+#include <sys/ioctl.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Used as a non-zero sentinel value, below.
+constexpr int kTestValue = 0x12345678;
+
+// Used for synchronization in race tests.
+const absl::Duration syncDelay = absl::Seconds(2);
+
+struct PipeCreator {
+ std::string name_;
+
+ // void (fds, is_blocking, is_namedpipe).
+ std::function<void(int[2], bool*, bool*)> create_;
+};
+
+class PipeTest : public ::testing::TestWithParam<PipeCreator> {
+ public:
+ static void SetUpTestSuite() {
+ // Tests intentionally generate SIGPIPE.
+ TEST_PCHECK(signal(SIGPIPE, SIG_IGN) != SIG_ERR);
+ }
+
+ // Initializes rfd_ and wfd_ as a blocking pipe.
+ //
+ // The return value indicates success: the test should be skipped otherwise.
+ bool CreateBlocking() { return create(true); }
+
+ // Initializes rfd_ and wfd_ as a non-blocking pipe.
+ //
+ // The return value is per CreateBlocking.
+ bool CreateNonBlocking() { return create(false); }
+
+ // Returns true iff the pipe represents a named pipe.
+ bool IsNamedPipe() const { return named_pipe_; }
+
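+ // Returns the pipe buffer capacity as reported by F_GETPIPE_SZ; per pipe(7)
+ // this defaults to 16 pages (typically 64 KiB).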
+ int Size() const {
+ int s1 = fcntl(rfd_.get(), F_GETPIPE_SZ);
+ int s2 = fcntl(wfd_.get(), F_GETPIPE_SZ);
+ EXPECT_GT(s1, 0);
+ EXPECT_GT(s2, 0);
+ EXPECT_EQ(s1, s2);
+ return s1;
+ }
+
+ static void TearDownTestSuite() {
+ TEST_PCHECK(signal(SIGPIPE, SIG_DFL) != SIG_ERR);
+ }
+
+ private:
+ bool create(bool wants_blocking) {
+ // Generate the pipe.
+ int fds[2] = {-1, -1};
+ bool is_blocking = false;
+ GetParam().create_(fds, &is_blocking, &named_pipe_);
+ if (fds[0] < 0 || fds[1] < 0) {
+ return false;
+ }
+
+ // Save descriptors.
+ rfd_.reset(fds[0]);
+ wfd_.reset(fds[1]);
+
+ // Adjust blocking, if needed.
+ if (!is_blocking && wants_blocking) {
+ // Clear O_NONBLOCK to make the descriptors blocking.
+ EXPECT_THAT(fcntl(fds[0], F_SETFL, 0), SyscallSucceeds());
+ EXPECT_THAT(fcntl(fds[1], F_SETFL, 0), SyscallSucceeds());
+ } else if (is_blocking && !wants_blocking) {
+ // Set the descriptors to non-blocking.
+ EXPECT_THAT(fcntl(fds[0], F_SETFL, O_NONBLOCK), SyscallSucceeds());
+ EXPECT_THAT(fcntl(fds[1], F_SETFL, O_NONBLOCK), SyscallSucceeds());
+ }
+
+ return true;
+ }
+
+ protected:
+ FileDescriptor rfd_;
+ FileDescriptor wfd_;
+
+ private:
+ bool named_pipe_ = false;
+};
+
+TEST_P(PipeTest, Inode) {
+ SKIP_IF(!CreateBlocking());
+
+ // Ensure that the inode number is the same for each end.
+ struct stat rst;
+ ASSERT_THAT(fstat(rfd_.get(), &rst), SyscallSucceeds());
+ struct stat wst;
+ ASSERT_THAT(fstat(wfd_.get(), &wst), SyscallSucceeds());
+ EXPECT_EQ(rst.st_ino, wst.st_ino);
+}
+
+TEST_P(PipeTest, Permissions) {
+ SKIP_IF(!CreateBlocking());
+
+ // Attempt bad operations.
+ int buf = kTestValue;
+ ASSERT_THAT(write(rfd_.get(), &buf, sizeof(buf)),
+ SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(read(wfd_.get(), &buf, sizeof(buf)),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST_P(PipeTest, Flags) {
+ SKIP_IF(!CreateBlocking());
+
+ if (IsNamedPipe()) {
+ // O_LARGEFILE may be stubbed to zero in libc headers, hence the locally
+ // defined kOLargeFile.
+ EXPECT_THAT(fcntl(rfd_.get(), F_GETFL),
+ SyscallSucceedsWithValue(kOLargeFile | O_RDONLY));
+ EXPECT_THAT(fcntl(wfd_.get(), F_GETFL),
+ SyscallSucceedsWithValue(kOLargeFile | O_WRONLY));
+ } else {
+ EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_RDONLY));
+ EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_WRONLY));
+ }
+}
+
+TEST_P(PipeTest, Write) {
+ SKIP_IF(!CreateBlocking());
+
+ int wbuf = kTestValue;
+ int rbuf = ~kTestValue;
+ ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)),
+ SyscallSucceedsWithValue(sizeof(wbuf)));
+ ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(wbuf, rbuf);
+}
+
+TEST_P(PipeTest, WritePage) {
+ SKIP_IF(!CreateBlocking());
+
+ std::vector<char> wbuf(kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ std::vector<char> rbuf(wbuf.size());
+
+ ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+ ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0);
+}
+
+TEST_P(PipeTest, NonBlocking) {
+ SKIP_IF(!CreateNonBlocking());
+
+ int wbuf = kTestValue;
+ int rbuf = ~kTestValue;
+ EXPECT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+ ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)),
+ SyscallSucceedsWithValue(sizeof(wbuf)));
+
+ ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(wbuf, rbuf);
+ EXPECT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST(Pipe2Test, CloExec) {
+ int fds[2];
+ ASSERT_THAT(pipe2(fds, O_CLOEXEC), SyscallSucceeds());
+ EXPECT_THAT(fcntl(fds[0], F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+ EXPECT_THAT(fcntl(fds[1], F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+ EXPECT_THAT(close(fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+TEST(Pipe2Test, BadOptions) {
+ int fds[2];
+ EXPECT_THAT(pipe2(fds, 0xDEAD), SyscallFailsWithErrno(EINVAL));
+}
+
+// Tests that opening named pipes with O_TRUNC shouldn't cause an error, but
+// calls to (f)truncate should.
+TEST(NamedPipeTest, Truncate) {
+ const std::string tmp_path = NewTempAbsPath();
+ SKIP_IF(mkfifo(tmp_path.c_str(), 0644) != 0);
+
+ ASSERT_THAT(open(tmp_path.c_str(), O_NONBLOCK | O_RDONLY), SyscallSucceeds());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(tmp_path.c_str(), O_RDWR | O_NONBLOCK | O_TRUNC));
+
+ ASSERT_THAT(truncate(tmp_path.c_str(), 0), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(PipeTest, Seek) {
+ SKIP_IF(!CreateBlocking());
+
+ for (int i = 0; i < 4; i++) {
+ // Attempt absolute seeks.
+ EXPECT_THAT(lseek(rfd_.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(rfd_.get(), 4, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(wfd_.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(wfd_.get(), 4, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+
+ // Attempt relative seeks.
+ EXPECT_THAT(lseek(rfd_.get(), 0, SEEK_CUR), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(rfd_.get(), 4, SEEK_CUR), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(wfd_.get(), 0, SEEK_CUR), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(wfd_.get(), 4, SEEK_CUR), SyscallFailsWithErrno(ESPIPE));
+
+ // Attempt end-of-file seeks.
+ EXPECT_THAT(lseek(rfd_.get(), 0, SEEK_END), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(rfd_.get(), -4, SEEK_END), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(wfd_.get(), 0, SEEK_END), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(lseek(wfd_.get(), -4, SEEK_END), SyscallFailsWithErrno(ESPIPE));
+
+ // Add some more data to the pipe.
+ int buf = kTestValue;
+ ASSERT_THAT(write(wfd_.get(), &buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ }
+}
+
+TEST_P(PipeTest, OffsetCalls) {
+ SKIP_IF(!CreateBlocking());
+
+ int buf;
+ EXPECT_THAT(pread(wfd_.get(), &buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(pwrite(rfd_.get(), &buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ESPIPE));
+
+ struct iovec iov;
+ iov.iov_base = &buf;
+ iov.iov_len = sizeof(buf);
+ EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST_P(PipeTest, WriterSideCloses) {
+ SKIP_IF(!CreateBlocking());
+
+ ScopedThread t([this]() {
+ int buf = ~kTestValue;
+ ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_EQ(buf, kTestValue);
+ // This will return when the close() completes.
+ ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)), SyscallSucceeds());
+ // This will return straight away.
+ ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+ });
+
+ // Sleep a bit so the thread can block.
+ absl::SleepFor(syncDelay);
+
+ // Write to unblock.
+ int buf = kTestValue;
+ ASSERT_THAT(write(wfd_.get(), &buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Sleep a bit so the thread can block again.
+ absl::SleepFor(syncDelay);
+
+ // Allow the thread to complete.
+ ASSERT_THAT(close(wfd_.release()), SyscallSucceeds());
+ t.Join();
+}
+
+TEST_P(PipeTest, WriterSideClosesReadDataFirst) {
+ SKIP_IF(!CreateBlocking());
+
+ int wbuf = kTestValue;
+ ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)),
+ SyscallSucceedsWithValue(sizeof(wbuf)));
+ ASSERT_THAT(close(wfd_.release()), SyscallSucceeds());
+
+ int rbuf;
+ ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(wbuf, rbuf);
+ EXPECT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(PipeTest, ReaderSideCloses) {
+ SKIP_IF(!CreateBlocking());
+
+ ASSERT_THAT(close(rfd_.release()), SyscallSucceeds());
+ int buf = kTestValue;
+ EXPECT_THAT(write(wfd_.get(), &buf, sizeof(buf)),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(PipeTest, CloseTwice) {
+ SKIP_IF(!CreateBlocking());
+
+ int reader = rfd_.release();
+ int writer = wfd_.release();
+ ASSERT_THAT(close(reader), SyscallSucceeds());
+ ASSERT_THAT(close(writer), SyscallSucceeds());
+ EXPECT_THAT(close(reader), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(close(writer), SyscallFailsWithErrno(EBADF));
+}
+
+// Blocking write returns EPIPE when read end is closed if nothing has been
+// written.
+TEST_P(PipeTest, BlockWriteClosed) {
+ SKIP_IF(!CreateBlocking());
+
+ absl::Notification notify;
+ ScopedThread t([this, &notify]() {
+ std::vector<char> buf(Size());
+ // Exactly fill the pipe buffer.
+ ASSERT_THAT(WriteFd(wfd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ notify.Notify();
+
+ // Attempt to write one more byte. This blocks until the read end is closed
+ // below, then fails with EPIPE.
+ // N.B. Don't use WriteFd; we don't want a retry.
+ EXPECT_THAT(write(wfd_.get(), buf.data(), 1), SyscallFailsWithErrno(EPIPE));
+ });
+
+ notify.WaitForNotification();
+ ASSERT_THAT(close(rfd_.release()), SyscallSucceeds());
+ t.Join();
+}
+
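+// Note that write(2) on a pipe with no readers also raises SIGPIPE, whose
+// default disposition kills the process. A caller that wants to observe EPIPE
+// directly first ignores the signal, as the test harness presumably does:
+//
+//   signal(SIGPIPE, SIG_IGN);  // Subsequent writes fail with EPIPE instead.
+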
+// A blocking write returns EPIPE when the read end is closed, even after part
+// of the data has been written.
+TEST_P(PipeTest, BlockPartialWriteClosed) {
+ SKIP_IF(!CreateBlocking());
+
+ ScopedThread t([this]() {
+ const int pipe_size = Size();
+ std::vector<char> buf(2 * pipe_size);
+
+ // Write more than fits in the buffer. Blocks then returns partial write
+ // when the other end is closed. The next call returns EPIPE.
+ ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ EXPECT_THAT(write(wfd_.get(), buf.data(), buf.size()),
+ SyscallFailsWithErrno(EPIPE));
+ });
+
+ // Leave time for write to become blocked.
+ absl::SleepFor(syncDelay);
+
+ // Unblock the above.
+ ASSERT_THAT(close(rfd_.release()), SyscallSucceeds());
+ t.Join();
+}
+
+TEST_P(PipeTest, ReadFromClosedFd_NoRandomSave) {
+ SKIP_IF(!CreateBlocking());
+
+ absl::Notification notify;
+ ScopedThread t([this, &notify]() {
+ notify.Notify();
+ int buf;
+ ASSERT_THAT(read(rfd_.get(), &buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_EQ(kTestValue, buf);
+ });
+ notify.WaitForNotification();
+
+ // Make sure that the thread gets to read().
+ absl::SleepFor(syncDelay);
+
+ {
+ // We cannot save/restore here: the read end of the pipe is closed while the
+ // read() above is still in progress, so the read() could not be restarted
+ // successfully in the restore run.
+ const DisableSave ds;
+ ASSERT_THAT(close(rfd_.release()), SyscallSucceeds());
+ int buf = kTestValue;
+ ASSERT_THAT(write(wfd_.get(), &buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ t.Join();
+ }
+}
+
+TEST_P(PipeTest, FionRead) {
+ SKIP_IF(!CreateBlocking());
+
+ int n;
+ ASSERT_THAT(ioctl(rfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+ ASSERT_THAT(ioctl(wfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ std::vector<char> buf(Size());
+ ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ EXPECT_THAT(ioctl(rfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, buf.size());
+ EXPECT_THAT(ioctl(wfd_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, buf.size());
+}
+
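+// FIONREAD reports the number of unread bytes buffered in the pipe regardless
+// of which end is queried, as the test above shows. A minimal sketch of the
+// query (variable names illustrative):
+//
+//   int pending = 0;
+//   if (ioctl(pipe_fd, FIONREAD, &pending) == 0) {
+//     printf("%d bytes buffered\n", pending);
+//   }
+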
+// Test that opening an empty anonymous pipe RDONLY via /proc/self/fd/N does not
+// block waiting for a writer.
+TEST_P(PipeTest, OpenViaProcSelfFD) {
+ SKIP_IF(!CreateBlocking());
+ SKIP_IF(IsNamedPipe());
+
+ // Close the write end of the pipe.
+ ASSERT_THAT(close(wfd_.release()), SyscallSucceeds());
+
+ // Open other side via /proc/self/fd. It should not block.
+ FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(absl::StrCat("/proc/self/fd/", rfd_.get()), O_RDONLY));
+}
+
+// Test that opening and reading from an anonymous pipe (with existing writes)
+// RDONLY via /proc/self/fd/N returns the existing data.
+TEST_P(PipeTest, OpenViaProcSelfFDWithWrites) {
+ SKIP_IF(!CreateBlocking());
+ SKIP_IF(IsNamedPipe());
+
+ // Write to the pipe and then close the write fd.
+ int wbuf = kTestValue;
+ ASSERT_THAT(write(wfd_.get(), &wbuf, sizeof(wbuf)),
+ SyscallSucceedsWithValue(sizeof(wbuf)));
+ ASSERT_THAT(close(wfd_.release()), SyscallSucceeds());
+
+ // Open read side via /proc/self/fd, and read from it.
+ FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(absl::StrCat("/proc/self/fd/", rfd_.get()), O_RDONLY));
+ int rbuf;
+ ASSERT_THAT(read(proc_self_fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(wbuf, rbuf);
+}
+
+// Test that accesses of /proc/<PID>/fd correctly decrement the refcount.
+TEST_P(PipeTest, ProcFDReleasesFile) {
+ SKIP_IF(!CreateBlocking());
+
+ // Stat the pipe FD, which shouldn't alter the refcount.
+ struct stat wst;
+ ASSERT_THAT(lstat(absl::StrCat("/proc/self/fd/", wfd_.get()).c_str(), &wst),
+ SyscallSucceeds());
+
+ // Close the write end and ensure that read indicates EOF.
+ wfd_.reset();
+ char buf;
+ ASSERT_THAT(read(rfd_.get(), &buf, 1), SyscallSucceedsWithValue(0));
+}
+
+// Same for /proc/<PID>/fdinfo.
+TEST_P(PipeTest, ProcFDInfoReleasesFile) {
+ SKIP_IF(!CreateBlocking());
+
+ // Stat the pipe FD, which shouldn't alter the refcount.
+ struct stat wst;
+ ASSERT_THAT(
+ lstat(absl::StrCat("/proc/self/fdinfo/", wfd_.get()).c_str(), &wst),
+ SyscallSucceeds());
+
+ // Close the write end and ensure that read indicates EOF.
+ wfd_.reset();
+ char buf;
+ ASSERT_THAT(read(rfd_.get(), &buf, 1), SyscallSucceedsWithValue(0));
+}
+
+TEST_P(PipeTest, SizeChange) {
+ SKIP_IF(!CreateBlocking());
+
+ // Set the minimum possible size.
+ ASSERT_THAT(fcntl(rfd_.get(), F_SETPIPE_SZ, 0), SyscallSucceeds());
+ int min = Size();
+ EXPECT_GT(min, 0); // Should be rounded up.
+
+ // Set from the read end.
+ ASSERT_THAT(fcntl(rfd_.get(), F_SETPIPE_SZ, min + 1), SyscallSucceeds());
+ int med = Size();
+ EXPECT_GT(med, min); // Should have grown, may be rounded.
+
+ // Set from the write end.
+ ASSERT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, med + 1), SyscallSucceeds());
+ int max = Size();
+ EXPECT_GT(max, med); // Ditto.
+}
+
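+// F_SETPIPE_SZ rounds the requested capacity up (to at least one page), so
+// callers should read the resulting size back with F_GETPIPE_SZ rather than
+// assume the request was honored exactly. A sketch:
+//
+//   if (fcntl(fd, F_SETPIPE_SZ, 1) >= 0) {
+//     int actual = fcntl(fd, F_GETPIPE_SZ);  // Typically one page, not 1.
+//   }
+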
+TEST_P(PipeTest, SizeChangeMax) {
+ SKIP_IF(!CreateBlocking());
+
+ // Assert there's some maximum.
+ EXPECT_THAT(fcntl(rfd_.get(), F_SETPIPE_SZ, 0x7fffffffffffffff),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, 0x7fffffffffffffff),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(PipeTest, SizeChangeFull) {
+ SKIP_IF(!CreateBlocking());
+
+ // Ensure that we adjust to a large enough size to avoid rounding when we
+ // perform the size decrease. If rounding occurs, we may not actually
+ // adjust the size and the call below will return success. It was found via
+ // experimentation that this granularity avoids the rounding for Linux.
+ constexpr int kDelta = 64 * 1024;
+ ASSERT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, Size() + kDelta),
+ SyscallSucceeds());
+
+ // Fill the buffer and try to change down.
+ std::vector<char> buf(Size());
+ ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ EXPECT_THAT(fcntl(wfd_.get(), F_SETPIPE_SZ, Size() - kDelta),
+ SyscallFailsWithErrno(EBUSY));
+}
+
+TEST_P(PipeTest, Streaming) {
+ SKIP_IF(!CreateBlocking());
+
+ // We make too many syscalls here to run through full save/restore cycles.
+ DisableSave ds;
+
+ // Size() requires two syscalls, so call it once and remember the value.
+ const int pipe_size = Size();
+
+ absl::Notification notify;
+ ScopedThread t([this, &notify, pipe_size]() {
+ // Don't start until it's full.
+ notify.WaitForNotification();
+ for (int i = 0; i < pipe_size; i++) {
+ int rbuf;
+ ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf, i);
+ }
+ });
+
+ // Write pipe_size ints, i.e. 4 * pipe_size bytes. The first pipe_size bytes
+ // fill the pipe once; just before they do, notify the reader to start. The
+ // remaining three pipes' worth of writes ensure the reader can follow along.
+ ssize_t total = 0;
+ for (int i = 0; i < pipe_size; i++) {
+ ssize_t written = write(wfd_.get(), &i, sizeof(i));
+ ASSERT_THAT(written, SyscallSucceedsWithValue(sizeof(i)));
+ total += written;
+
+ // Is the next write about to fill up the buffer? Wake up the reader once.
+ if (total < pipe_size && (total + written) >= pipe_size) {
+ notify.Notify();
+ }
+ }
+}
+
+std::string PipeCreatorName(::testing::TestParamInfo<PipeCreator> info) {
+ return info.param.name_; // Use the name specified.
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Pipes, PipeTest,
+ ::testing::Values(
+ PipeCreator{
+ "pipe",
+ [](int fds[2], bool* is_blocking, bool* is_namedpipe) {
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ *is_blocking = true;
+ *is_namedpipe = false;
+ },
+ },
+ PipeCreator{
+ "pipe2blocking",
+ [](int fds[2], bool* is_blocking, bool* is_namedpipe) {
+ ASSERT_THAT(pipe2(fds, 0), SyscallSucceeds());
+ *is_blocking = true;
+ *is_namedpipe = false;
+ },
+ },
+ PipeCreator{
+ "pipe2nonblocking",
+ [](int fds[2], bool* is_blocking, bool* is_namedpipe) {
+ ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds());
+ *is_blocking = false;
+ *is_namedpipe = false;
+ },
+ },
+ PipeCreator{
+ "smallbuffer",
+ [](int fds[2], bool* is_blocking, bool* is_namedpipe) {
+ // Set to the minimum available size (will round up).
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fds[0], F_SETPIPE_SZ, 0), SyscallSucceeds());
+ *is_blocking = true;
+ *is_namedpipe = false;
+ },
+ },
+ PipeCreator{
+ "namednonblocking",
+ [](int fds[2], bool* is_blocking, bool* is_namedpipe) {
+ // Create a new file-based pipe (non-blocking).
+ std::string path;
+ {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ path = file.path();
+ }
+ SKIP_IF(mkfifo(path.c_str(), 0644) != 0);
+ fds[0] = open(path.c_str(), O_NONBLOCK | O_RDONLY);
+ fds[1] = open(path.c_str(), O_NONBLOCK | O_WRONLY);
+ MaybeSave();
+ *is_blocking = false;
+ *is_namedpipe = true;
+ },
+ },
+ PipeCreator{
+ "namedblocking",
+ [](int fds[2], bool* is_blocking, bool* is_namedpipe) {
+ // Create a new file-based pipe (blocking).
+ std::string path;
+ {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ path = file.path();
+ }
+ SKIP_IF(mkfifo(path.c_str(), 0644) != 0);
+ ScopedThread t(
+ [&path, &fds]() { fds[1] = open(path.c_str(), O_WRONLY); });
+ fds[0] = open(path.c_str(), O_RDONLY);
+ t.Join();
+ MaybeSave();
+ *is_blocking = true;
+ *is_namedpipe = true;
+ },
+ }),
+ PipeCreatorName);
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
new file mode 100644
index 000000000..7a316427d
--- /dev/null
+++ b/test/syscalls/linux/poll.cc
@@ -0,0 +1,294 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <poll.h>
+#include <sys/resource.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <algorithm>
+#include <iostream>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/base_poll_test.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class PollTest : public BasePollTest {
+ protected:
+ void SetUp() override { BasePollTest::SetUp(); }
+ void TearDown() override { BasePollTest::TearDown(); }
+};
+
+TEST_F(PollTest, InvalidFds) {
+ // fds is invalid because it's null, but we tell poll the length is non-zero.
+ EXPECT_THAT(poll(nullptr, 1, 1), SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(poll(nullptr, -1, 1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(PollTest, NullFds) {
+ EXPECT_THAT(poll(nullptr, 0, 10), SyscallSucceeds());
+}
+
+TEST_F(PollTest, ZeroTimeout) {
+ EXPECT_THAT(poll(nullptr, 0, 0), SyscallSucceeds());
+}
+
+// If random S/R interrupts the poll, SIGALRM may be delivered before poll
+// restarts, causing the poll to hang forever.
+TEST_F(PollTest, NegativeTimeout_NoRandomSave) {
+ // A negative timeout means wait forever, so set a timer to interrupt us.
+ SetTimer(absl::Milliseconds(100));
+ EXPECT_THAT(poll(nullptr, 0, -1), SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+}
+
+TEST_F(PollTest, NonBlockingEventPOLLIN) {
+ // Create a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Write some data to the pipe.
+ char s[] = "foo\n";
+ ASSERT_THAT(WriteFd(fd1.get(), s, strlen(s) + 1), SyscallSucceeds());
+
+ // Poll on the reader fd with POLLIN event.
+ struct pollfd poll_fd = {fd0.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLIN event.
+ EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN);
+}
+
+TEST_F(PollTest, BlockingEventPOLLIN) {
+ // Create a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Start a blocking poll on the read fd.
+ absl::Notification notify;
+ ScopedThread t([&fd0, &notify]() {
+ notify.Notify();
+
+ // Poll on the reader fd with POLLIN event.
+ struct pollfd poll_fd = {fd0.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, -1), SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLIN event.
+ EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN);
+ });
+
+ notify.WaitForNotification();
+ absl::SleepFor(absl::Seconds(1.0));
+
+ // Write some data to the pipe.
+ char s[] = "foo\n";
+ ASSERT_THAT(WriteFd(fd1.get(), s, strlen(s) + 1), SyscallSucceeds());
+}
+
+TEST_F(PollTest, NonBlockingEventPOLLHUP) {
+ // Create a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Close the writer fd.
+ fd1.reset();
+
+ // Poll on the reader fd with POLLIN event.
+ struct pollfd poll_fd = {fd0.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLHUP event.
+ EXPECT_EQ(poll_fd.revents & POLLHUP, POLLHUP);
+
+ // Should not trigger POLLIN event.
+ EXPECT_EQ(poll_fd.revents & POLLIN, 0);
+}
+
+TEST_F(PollTest, BlockingEventPOLLHUP) {
+ // Create a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Start a blocking poll on the read fd.
+ absl::Notification notify;
+ ScopedThread t([&fd0, &notify]() {
+ notify.Notify();
+
+ // Poll on the reader fd with POLLIN event.
+ struct pollfd poll_fd = {fd0.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, -1), SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLHUP event.
+ EXPECT_EQ(poll_fd.revents & POLLHUP, POLLHUP);
+
+ // Should not trigger POLLIN event.
+ EXPECT_EQ(poll_fd.revents & POLLIN, 0);
+ });
+
+ notify.WaitForNotification();
+ absl::SleepFor(absl::Seconds(1.0));
+
+ // Write some data and close the writer fd.
+ fd1.reset();
+}
+
+TEST_F(PollTest, NonBlockingEventPOLLERR) {
+ // Create a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Close the reader fd.
+ fd0.reset();
+
+ // Poll on the writer fd with POLLOUT event.
+ struct pollfd poll_fd = {fd1.get(), POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLERR event.
+ EXPECT_EQ(poll_fd.revents & POLLERR, POLLERR);
+
+ // Should also trigger POLLOUT event.
+ EXPECT_EQ(poll_fd.revents & POLLOUT, POLLOUT);
+}
+
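+// Taken together, the tests above exercise the usual pipe readiness matrix:
+// the read end reports POLLIN while data is buffered and POLLHUP once all
+// writers are closed; the write end reports POLLOUT while buffer space exists
+// and POLLERR once all readers are closed.
+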
+// This test validates that poll(2) does not return immediately just because an
+// FD is ready on some event (POLLIN or POLLOUT); it should only return early
+// for events the caller actually requested.
+TEST_F(PollTest, ImmediatelyReturnOnlyOnPollEvents) {
+ // Create a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Wait for read-related events on the write side of the pipe. Since a write
+ // is always possible on fds[1], polling for POLLOUT would return
+ // immediately; make sure we are not woken up by a state we didn't
+ // specifically request.
+ constexpr int kTimeoutMs = 100;
+ struct pollfd poll_fd = {fd1.get(), POLLIN | POLLPRI | POLLRDHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kTimeoutMs),
+ SyscallSucceedsWithValue(0)); // We should timeout.
+ EXPECT_EQ(poll_fd.revents, 0); // Nothing should be in returned events.
+
+ // Now let's poll on POLLOUT and we should get back 1 fd as being ready and
+ // it should contain POLLOUT in the revents.
+ poll_fd.events = POLLOUT;
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kTimeoutMs),
+ SyscallSucceedsWithValue(1)); // 1 fd should have an event.
+ EXPECT_EQ(poll_fd.revents, POLLOUT); // POLLOUT should be in revents.
+}
+
+// This test validates that poll(2) returns immediately while data is available.
+TEST_F(PollTest, PollLevelTriggered) {
+ int fds[2] = {};
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, /*protocol=*/0, fds),
+ SyscallSucceeds());
+
+ FileDescriptor fd0(fds[0]);
+ FileDescriptor fd1(fds[1]);
+
+ // Write two bytes to the socket.
+ const char* kBuf = "aa";
+ ASSERT_THAT(RetryEINTR(send)(fd0.get(), kBuf, /*len=*/2, /*flags=*/0),
+ SyscallSucceedsWithValue(2)); // 2 bytes should be written.
+
+ // Poll(2) should immediately return as there is data available to read.
+ constexpr int kInfiniteTimeout = -1;
+ struct pollfd poll_fd = {fd1.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, /*nfds=*/1, kInfiniteTimeout),
+ SyscallSucceedsWithValue(1)); // 1 fd should be ready to read.
+ EXPECT_NE(poll_fd.revents & POLLIN, 0);
+
+ // Read a single byte.
+ char read_byte = 0;
+ ASSERT_THAT(RetryEINTR(recv)(fd1.get(), &read_byte, /*len=*/1, /*flags=*/0),
+ SyscallSucceedsWithValue(1)); // 1 byte should be read.
+ ASSERT_EQ(read_byte, 'a'); // We should have read a single 'a'.
+
+ // Create a separate pollfd for our second poll.
+ struct pollfd poll_fd_after = {fd1.get(), POLLIN, 0};
+
+ // Poll(2) should again immediately return since we only read one byte.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd_after, /*nfds=*/1, kInfiniteTimeout),
+ SyscallSucceedsWithValue(1)); // 1 fd should be ready to read.
+ EXPECT_NE(poll_fd_after.revents & POLLIN, 0);
+}
+
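+// Because poll(2) is level-triggered, a caller that does not fully drain the
+// fd is simply woken again on the next call. A minimal sketch of the
+// resulting read loop (names illustrative):
+//
+//   struct pollfd pfd = {fd, POLLIN, 0};
+//   while (poll(&pfd, 1, -1) == 1 && (pfd.revents & POLLIN)) {
+//     char tmp[512];
+//     if (read(fd, tmp, sizeof(tmp)) <= 0) break;  // EOF or error.
+//   }
+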
+TEST_F(PollTest, Nfds) {
+ // Stash the value of RLIMIT_NOFILE.
+ struct rlimit rlim;
+ TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0);
+
+ // gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE.
+ constexpr rlim_t maxFD = 4096;
+ if (rlim.rlim_cur > maxFD) {
+ rlim.rlim_cur = maxFD;
+ TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0);
+ }
+
+ rlim_t max_fds = rlim.rlim_cur;
+ std::cout << "Using limit: " << max_fds << std::endl;
+
+ // Create an eventfd. Since its value is initially zero, it is writable.
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+
+ // Create the biggest possible pollfd array such that each element is valid.
+ // Each entry in the 'fds' array refers to the eventfd and polls for
+ // "writable" events (events=POLLOUT). This essentially guarantees that the
+ // poll() is a no-op and allows negative testing of the 'nfds' parameter.
+ std::vector<struct pollfd> fds(max_fds + 1,
+ {.fd = efd.get(), .events = POLLOUT});
+
+ // Verify that 'nfds' up to RLIMIT_NOFILE are allowed.
+ EXPECT_THAT(RetryEINTR(poll)(fds.data(), 1, 1), SyscallSucceedsWithValue(1));
+ EXPECT_THAT(RetryEINTR(poll)(fds.data(), max_fds / 2, 1),
+ SyscallSucceedsWithValue(max_fds / 2));
+ EXPECT_THAT(RetryEINTR(poll)(fds.data(), max_fds, 1),
+ SyscallSucceedsWithValue(max_fds));
+
+ // If 'nfds' exceeds RLIMIT_NOFILE then it must fail with EINVAL.
+ EXPECT_THAT(poll(fds.data(), max_fds + 1, 1), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ppoll.cc b/test/syscalls/linux/ppoll.cc
new file mode 100644
index 000000000..8245a11e8
--- /dev/null
+++ b/test/syscalls/linux/ppoll.cc
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <poll.h>
+#include <signal.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/base_poll_test.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+// Linux and glibc disagree about the size of sigset_t. When calling the
+// syscall directly, use what the kernel expects.
+unsigned kSigsetSize = SIGRTMAX / 8;
+
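+// On Linux, SIGRTMAX is 64, so this works out to 8 bytes: the size of the
+// kernel's sigset_t. glibc's sigset_t is 128 bytes, and passing that size to
+// the raw syscall fails with EINVAL (see InvalidMaskSize below).
+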
+// Linux ppoll(2) differs from the glibc wrapper function in that Linux updates
+// the timeout with the amount of time remaining. In order to test this behavior
+// we need to use the syscall directly.
+int syscallPpoll(struct pollfd* fds, nfds_t nfds, struct timespec* timeout_ts,
+ const sigset_t* sigmask, unsigned mask_size) {
+ return syscall(SYS_ppoll, fds, nfds, timeout_ts, sigmask, mask_size);
+}
+
+class PpollTest : public BasePollTest {
+ protected:
+ void SetUp() override { BasePollTest::SetUp(); }
+ void TearDown() override { BasePollTest::TearDown(); }
+};
+
+TEST_F(PpollTest, InvalidFds) {
+ // fds is invalid because it's null, but we tell ppoll the length is non-zero.
+ struct timespec timeout = {};
+ sigset_t sigmask;
+ TEST_PCHECK(sigemptyset(&sigmask) == 0);
+ EXPECT_THAT(syscallPpoll(nullptr, 1, &timeout, &sigmask, kSigsetSize),
+ SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(syscallPpoll(nullptr, -1, &timeout, &sigmask, kSigsetSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// See that when fds is null, ppoll behaves like sleep.
+TEST_F(PpollTest, NullFds) {
+ struct timespec timeout = absl::ToTimespec(absl::Milliseconds(10));
+ ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_nsec, 0);
+}
+
+TEST_F(PpollTest, ZeroTimeout) {
+ struct timespec timeout = {};
+ ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_nsec, 0);
+}
+
+// If random S/R interrupts the ppoll, SIGALRM may be delivered before ppoll
+// restarts, causing the ppoll to hang forever.
+TEST_F(PpollTest, NoTimeout_NoRandomSave) {
+ // When there's no timeout, ppoll may never return so set a timer.
+ SetTimer(absl::Milliseconds(100));
+ // See that we get interrupted by the timer.
+ ASSERT_THAT(syscallPpoll(nullptr, 0, nullptr, nullptr, 0),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+}
+
+TEST_F(PpollTest, InvalidTimeoutNegative) {
+ struct timespec timeout = absl::ToTimespec(absl::Nanoseconds(-1));
+ EXPECT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(PpollTest, InvalidTimeoutNotNormalized) {
+ struct timespec timeout = {0, 1000000001};
+ EXPECT_THAT(syscallPpoll(nullptr, 0, &timeout, nullptr, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(PpollTest, InvalidMaskSize) {
+ struct timespec timeout = {};
+ sigset_t sigmask;
+ TEST_PCHECK(sigemptyset(&sigmask) == 0);
+ EXPECT_THAT(syscallPpoll(nullptr, 0, &timeout, &sigmask, 128),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Verify that signals blocked by the ppoll mask (that would otherwise be
+// allowed) do not interrupt ppoll.
+TEST_F(PpollTest, SignalMaskBlocksSignal) {
+ absl::Duration duration(absl::Seconds(30));
+ struct timespec timeout = absl::ToTimespec(duration);
+ absl::Duration timer_duration(absl::Seconds(10));
+
+ // Call with a mask that blocks SIGALRM. See that ppoll is not interrupted
+ // (i.e. returns 0) and that upon completion, the timer has fired.
+ sigset_t mask;
+ ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds());
+ TEST_PCHECK(sigaddset(&mask, SIGALRM) == 0);
+ SetTimer(timer_duration);
+ MaybeSave();
+ ASSERT_FALSE(TimerFired());
+ ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, &mask, kSigsetSize),
+ SyscallSucceeds());
+ EXPECT_TRUE(TimerFired());
+ EXPECT_EQ(absl::DurationFromTimespec(timeout), absl::Duration());
+}
+
+// Verify that signals allowed by the ppoll mask (that would otherwise be
+// blocked) interrupt ppoll.
+TEST_F(PpollTest, SignalMaskAllowsSignal) {
+ absl::Duration duration(absl::Seconds(30));
+ struct timespec timeout = absl::ToTimespec(duration);
+ absl::Duration timer_duration(absl::Seconds(10));
+
+ sigset_t mask;
+ ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds());
+
+ // Block SIGALRM.
+ auto cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGALRM));
+
+ // Call with a mask that unblocks SIGALRM. See that ppoll is interrupted.
+ SetTimer(timer_duration);
+ MaybeSave();
+ ASSERT_FALSE(TimerFired());
+ ASSERT_THAT(syscallPpoll(nullptr, 0, &timeout, &mask, kSigsetSize),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+ EXPECT_GT(absl::DurationFromTimespec(timeout), absl::Duration());
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
new file mode 100644
index 000000000..04c5161f5
--- /dev/null
+++ b/test/syscalls/linux/prctl.cc
@@ -0,0 +1,230 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/prctl.h>
+#include <sys/ptrace.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(bool, prctl_no_new_privs_test_child, false,
+ "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
+ "plus an offset (see test source).");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef SUID_DUMP_DISABLE
+#define SUID_DUMP_DISABLE 0
+#endif /* SUID_DUMP_DISABLE */
+#ifndef SUID_DUMP_USER
+#define SUID_DUMP_USER 1
+#endif /* SUID_DUMP_USER */
+#ifndef SUID_DUMP_ROOT
+#define SUID_DUMP_ROOT 2
+#endif /* SUID_DUMP_ROOT */
+
+TEST(PrctlTest, NameInitialized) {
+ const size_t name_length = 20;
+ char name[name_length] = {};
+ ASSERT_THAT(prctl(PR_GET_NAME, name), SyscallSucceeds());
+ ASSERT_NE(std::string(name), "");
+}
+
+TEST(PrctlTest, SetNameLongName) {
+ const size_t name_length = 20;
+ const std::string long_name(name_length, 'A');
+ ASSERT_THAT(prctl(PR_SET_NAME, long_name.c_str()), SyscallSucceeds());
+ char truncated_name[name_length] = {};
+ ASSERT_THAT(prctl(PR_GET_NAME, truncated_name), SyscallSucceeds());
+ const size_t truncated_length = 15;
+ ASSERT_EQ(long_name.substr(0, truncated_length), std::string(truncated_name));
+}
+
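+// The 15-character truncation above comes from the kernel's TASK_COMM_LEN (16
+// bytes including the trailing NUL), so PR_GET_NAME callers should size their
+// buffer accordingly. A sketch (buffer name illustrative):
+//
+//   char comm[16] = {};
+//   prctl(PR_GET_NAME, comm);  // Always NUL-terminated within 16 bytes.
+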
+TEST(PrctlTest, ChildProcessName) {
+ constexpr size_t kMaxNameLength = 15;
+
+ char parent_name[kMaxNameLength + 1] = {};
+ memset(parent_name, 'a', kMaxNameLength);
+
+ ASSERT_THAT(prctl(PR_SET_NAME, parent_name), SyscallSucceeds());
+
+ pid_t child_pid = fork();
+ TEST_PCHECK(child_pid >= 0);
+ if (child_pid == 0) {
+ char child_name[kMaxNameLength + 1] = {};
+ TEST_PCHECK(prctl(PR_GET_NAME, child_name) >= 0);
+ TEST_CHECK(memcmp(parent_name, child_name, sizeof(parent_name)) == 0);
+ _exit(0);
+ }
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status =" << status;
+}
+
+// Offset added to exit code from test child to distinguish from other abnormal
+// exits.
+constexpr int kPrctlNoNewPrivsTestChildExitBase = 100;
+
+TEST(PrctlTest, NoNewPrivsPreservedAcrossCloneForkAndExecve) {
+ // Check if no_new_privs is already set. If it is, we can still test that it's
+ // preserved across clone/fork/execve, but we also expect it to still be set
+ // at the end of the test. Otherwise, call prctl(PR_SET_NO_NEW_PRIVS) so as
+ // not to contaminate the original thread.
+ int no_new_privs;
+ ASSERT_THAT(no_new_privs = prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0),
+ SyscallSucceeds());
+ ScopedThread([] {
+ ASSERT_THAT(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0), SyscallSucceeds());
+ EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(1));
+ ScopedThread([] {
+ EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(1));
+ // Note that these ASSERT_*s failing will only return from this thread,
+ // but this is the intended behavior.
+ pid_t child_pid = -1;
+ int execve_errno = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/proc/self/exe",
+ {"/proc/self/exe", "--prctl_no_new_privs_test_child"}, {},
+ nullptr, &child_pid, &execve_errno));
+
+ ASSERT_GT(child_pid, 0);
+ ASSERT_EQ(execve_errno, 0);
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceeds());
+ ASSERT_TRUE(WIFEXITED(status));
+ ASSERT_EQ(WEXITSTATUS(status), kPrctlNoNewPrivsTestChildExitBase + 1);
+
+ EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(1));
+ });
+ EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(1));
+ });
+ EXPECT_THAT(prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(no_new_privs));
+}
+
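+// The flag is one-way by design: once set it cannot be cleared, and it is the
+// usual prerequisite for installing a seccomp filter without CAP_SYS_ADMIN. A
+// minimal sketch of the call:
+//
+//   if (prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) != 0) {
+//     perror("PR_SET_NO_NEW_PRIVS");
+//   }
+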
+TEST(PrctlTest, PDeathSig) {
+ pid_t child_pid;
+
+ // Make the new process' parent a separate thread since the parent death
+ // signal fires when the parent *thread* exits.
+ ScopedThread([&] {
+ child_pid = fork();
+ TEST_CHECK(child_pid >= 0);
+ if (child_pid == 0) {
+ // In child process.
+ TEST_CHECK(prctl(PR_SET_PDEATHSIG, SIGKILL) >= 0);
+ int signo;
+ TEST_CHECK(prctl(PR_GET_PDEATHSIG, &signo) >= 0);
+ TEST_CHECK(signo == SIGKILL);
+ // Enable tracing, then raise SIGSTOP and expect our parent to suppress
+ // it.
+ TEST_CHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) >= 0);
+ raise(SIGSTOP);
+ // Sleep until killed by our parent death signal. sleep(3) is
+ // async-signal-safe, absl::SleepFor isn't.
+ while (true) {
+ sleep(10);
+ }
+ }
+ // In parent process.
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << "status = " << status;
+
+ // Suppress the SIGSTOP and detach from the child.
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+ });
+
+ // The child should have been killed by its parent death SIGKILL.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << "status = " << status;
+}
+
+// This test validates that calling prctl with PR_SET_MM without
+// CAP_SYS_RESOURCE fails with EPERM.
+TEST(PrctlTest, InvalidPrSetMM) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE,
+ false)); // Drop capability to test below.
+ }
+ ASSERT_THAT(prctl(PR_SET_MM, 0, 0, 0, 0), SyscallFailsWithErrno(EPERM));
+}
+
+// Sanity check that dumpability is remembered.
+TEST(PrctlTest, SetGetDumpability) {
+ int before;
+ ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds());
+ auto cleanup = Cleanup([before] {
+ ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds());
+ EXPECT_THAT(prctl(PR_GET_DUMPABLE),
+ SyscallSucceedsWithValue(SUID_DUMP_DISABLE));
+
+ EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds());
+ EXPECT_THAT(prctl(PR_GET_DUMPABLE), SyscallSucceedsWithValue(SUID_DUMP_USER));
+}
+
+// SUID_DUMP_ROOT cannot be set via PR_SET_DUMPABLE.
+TEST(PrctlTest, RootDumpability) {
+ EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_ROOT),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ if (absl::GetFlag(FLAGS_prctl_no_new_privs_test_child)) {
+ exit(gvisor::testing::kPrctlNoNewPrivsTestChildExitBase +
+ prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
new file mode 100644
index 000000000..c4e9cf528
--- /dev/null
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -0,0 +1,268 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+#include <sys/prctl.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "test/util/capability_util.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+// This flag is used to verify that PR_GET_KEEPCAPS returns 0 after an exec;
+// the child's exit code is offset by kPrGetKeepCapsExitBase.
+ABSL_FLAG(bool, prctl_pr_get_keepcaps, false,
+ "If true, the test will verify that prctl with PR_GET_KEEPCAPS "
+ "returns 0. The test will exit with the result of that check.");
+
+// These tests exist separately from prctl because we need to start them as
+// root. setuid() fully drops permissions if any of the process's UIDs was 0
+// before the call. That behavior can be changed with PR_SET_KEEPCAPS, which
+// is what is tested here.
+//
+// Reference setuid(2):
+// The setuid() function checks the effective user ID of
+// the caller and, if it is the superuser, all process-related user IDs
+// are set to uid. After this has occurred, it is impossible for the
+// program to regain root privileges.
+//
+// Thus, a set-user-ID-root program wishing to temporarily drop root
+// privileges, assume the identity of an unprivileged user, and then
+// regain root privileges afterward cannot use setuid(). You can
+// accomplish this with seteuid(2).
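+//
+// A minimal sketch of that temporary-drop pattern using seteuid(2) (the UID
+// value here is illustrative):
+//
+//   seteuid(65534);  // Drop: effective UID becomes unprivileged.
+//   // ... do unprivileged work ...
+//   seteuid(0);      // Regain: the saved set-user-ID is still 0.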
+namespace gvisor {
+namespace testing {
+
+// Offset added to exit code from test child to distinguish from other abnormal
+// exits.
+constexpr int kPrGetKeepCapsExitBase = 100;
+
+namespace {
+
+class PrctlKeepCapsSetuidTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // PR_GET_KEEPCAPS will only return 0 or 1 (on success).
+ ASSERT_THAT(original_keepcaps_ = prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0),
+ SyscallSucceeds());
+ ASSERT_TRUE(original_keepcaps_ == 0 || original_keepcaps_ == 1);
+ }
+
+ void TearDown() override {
+ // Restore PR_SET_KEEPCAPS.
+ ASSERT_THAT(prctl(PR_SET_KEEPCAPS, original_keepcaps_, 0, 0, 0),
+ SyscallSucceeds());
+
+ // Verify that it was restored.
+ ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(original_keepcaps_));
+ }
+
+ // The original keep caps value exposed so tests can use it if they need.
+ int original_keepcaps_ = 0;
+};
+
+// This test verifies that a bad value for PR_SET_KEEPCAPS (i.e. anything
+// other than 0 or 1) returns EINVAL, as required by prctl(2).
+TEST_F(PrctlKeepCapsSetuidTest, PrctlBadArgsToKeepCaps) {
+ ASSERT_THAT(prctl(PR_SET_KEEPCAPS, 2, 0, 0, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// This test will verify that a setuid(2) without PR_SET_KEEPCAPS will cause
+// all capabilities to be dropped.
+TEST_F(PrctlKeepCapsSetuidTest, SetUidNoKeepCaps) {
+ // getuid(2) never fails.
+ if (getuid() != 0) {
+ SKIP_IF(!IsRunningOnGvisor());
+ FAIL() << "User is not root on gvisor platform.";
+ }
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting
+ // this test. Otherwise, the files are created by root (UID before the
+ // test), but cannot be opened by the `uid` set below after the test. After
+ // calling setuid(non-zero-UID), there is no way to get root privileges
+ // back.
+ ScopedThread([] {
+ // Start by verifying we have a capability.
+ TEST_CHECK(HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
+
+ // Verify that PR_GET_KEEPCAPS is disabled.
+ ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(0));
+
+ // Use syscall instead of glibc setuid wrapper because we want this setuid
+ // call to only apply to this task. POSIX threads, however, require that
+ // all threads have the same UIDs, so using the setuid wrapper sets all
+ // threads' real UID.
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
+
+ // Verify that we changed uid.
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
+
+ // Verify we lost the capability in the effective set; this always happens.
+ TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
+
+ // We should also have lost it in the permitted set due to the setuid(), so
+ // SetCapability should fail when we try to add it back to the effective set.
+ ASSERT_FALSE(SetCapability(CAP_SYS_ADMIN, true).ok());
+ });
+}
+
+// This test will verify that a setuid with PR_SET_KEEPCAPS will cause
+// capabilities to be retained after we switch away from the root user.
+TEST_F(PrctlKeepCapsSetuidTest, SetUidKeepCaps) {
+ // getuid(2) never fails.
+ if (getuid() != 0) {
+ SKIP_IF(!IsRunningOnGvisor());
+ FAIL() << "User is not root on gvisor platform.";
+ }
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting
+ // this test. Otherwise, the files are created by root (UID before the
+ // test), but cannot be opened by the `uid` set below after the test. After
+ // calling setuid(non-zero-UID), there is no way to get root privileges
+ // back.
+ ScopedThread([] {
+ // Start by verifying we have a capability.
+ TEST_CHECK(HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
+
+ // Set PR_SET_KEEPCAPS.
+ ASSERT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
+
+ // Verify PR_SET_KEEPCAPS was set before we proceed.
+ ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(1));
+
+ // Use syscall instead of glibc setuid wrapper because we want this setuid
+ // call to only apply to this task. POSIX threads, however, require that
+ // all threads have the same UIDs, so using the setuid wrapper sets all
+ // threads' real UID.
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
+
+ // Verify that we changed uid.
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
+
+ // Verify we lost the capability in the effective set; this always happens.
+ TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
+
+ // We lost the capability in the effective set, but it will still
+ // exist in the permitted set so we can elevate the capability.
+ ASSERT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, true));
+
+ // Verify we got back the capability in the effective set.
+ TEST_CHECK(HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
+ });
+}
+
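+// Putting the pieces together, the canonical keep-caps sequence sketched by
+// the test above is:
+//
+//   prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0);  // Keep permitted caps across setuid.
+//   syscall(SYS_setuid, uid);            // Effective caps are still cleared.
+//   // Re-raise needed caps from the permitted set, e.g. via libcap.
+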
+// This test will verify that PR_SET_KEEPCAPS is not retained
+// across an execve. According to prctl(2):
+// "The "keep capabilities" value will be reset to 0 on subsequent
+// calls to execve(2)."
+TEST_F(PrctlKeepCapsSetuidTest, NoKeepCapsAfterExec) {
+ ASSERT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
+
+ // Verify PR_SET_KEEPCAPS was set before we proceed.
+ ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0), SyscallSucceedsWithValue(1));
+
+ pid_t child_pid = -1;
+ int execve_errno = 0;
+ // Do an exec and then verify that PR_GET_KEEPCAPS returns 0;
+ // see the body of main below.
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ "/proc/self/exe", {"/proc/self/exe", "--prctl_pr_get_keepcaps"}, {},
+ nullptr, &child_pid, &execve_errno));
+
+ ASSERT_GT(child_pid, 0);
+ ASSERT_EQ(execve_errno, 0);
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ ASSERT_TRUE(WIFEXITED(status));
+ // PR_SET_KEEPCAPS should have been cleared by the exec.
+ // Success returns gvisor::testing::kPrGetKeepCapsExitBase + 0.
+ ASSERT_EQ(WEXITSTATUS(status), kPrGetKeepCapsExitBase);
+}
+
+TEST_F(PrctlKeepCapsSetuidTest, NoKeepCapsAfterNewUserNamespace) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+
+ // Fork to avoid changing the user namespace of the original test process.
+ pid_t const child_pid = fork();
+
+ if (child_pid == 0) {
+ // Verify that the keepcaps flag is set to 0 when we change user namespaces.
+ TEST_PCHECK(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0) == 0);
+ MaybeSave();
+
+ TEST_PCHECK(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0) == 1);
+ MaybeSave();
+
+ TEST_PCHECK(unshare(CLONE_NEWUSER) == 0);
+ MaybeSave();
+
+ TEST_PCHECK(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0) == 0);
+ MaybeSave();
+
+ _exit(0);
+ }
+
+ int status;
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status = " << status;
+}
+
+// This test verifies that PR_SET_KEEPCAPS and PR_GET_KEEPCAPS work correctly.
+TEST_F(PrctlKeepCapsSetuidTest, PrGetKeepCaps) {
+ // Set PR_SET_KEEPCAPS to the negation of the original.
+ ASSERT_THAT(prctl(PR_SET_KEEPCAPS, !original_keepcaps_, 0, 0, 0),
+ SyscallSucceeds());
+
+ // Verify it was set.
+ ASSERT_THAT(prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0),
+ SyscallSucceedsWithValue(!original_keepcaps_));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ if (absl::GetFlag(FLAGS_prctl_pr_get_keepcaps)) {
+ return gvisor::testing::kPrGetKeepCapsExitBase +
+ prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc
new file mode 100644
index 000000000..bcdbbb044
--- /dev/null
+++ b/test/syscalls/linux/pread64.cc
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/unistd.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class Pread64Test : public ::testing::Test {
+ void SetUp() override {
+ name_ = NewTempAbsPath();
+ ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_CREAT, 0644));
+ }
+
+ void TearDown() override { unlink(name_.c_str()); }
+
+ public:
+ std::string name_;
+};
+
+TEST(Pread64TestNoTempFile, BadFileDescriptor) {
+ char buf[1024];
+ EXPECT_THAT(pread64(-1, buf, 1024, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(Pread64Test, ZeroBuffer) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDWR));
+
+ char msg[] = "hello world";
+ EXPECT_THAT(pwrite64(fd.get(), msg, strlen(msg), 0),
+ SyscallSucceedsWithValue(strlen(msg)));
+
+ char buf[10];
+ EXPECT_THAT(pread64(fd.get(), buf, 0, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(Pread64Test, BadBuffer) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDWR));
+
+ char msg[] = "hello world";
+ EXPECT_THAT(pwrite64(fd.get(), msg, strlen(msg), 0),
+ SyscallSucceedsWithValue(strlen(msg)));
+
+ char* bad_buffer = nullptr;
+ EXPECT_THAT(pread64(fd.get(), bad_buffer, 1024, 0),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(Pread64Test, WriteOnlyNotReadable) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_WRONLY));
+
+ char buf[1024];
+ EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(Pread64Test, DirNotReadable) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY));
+
+ char buf[1024];
+ EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(Pread64Test, BadOffset) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDONLY));
+
+ char buf[1024];
+ EXPECT_THAT(pread64(fd.get(), buf, 1024, -1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(Pread64Test, OffsetNotIncremented) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDWR));
+
+ char msg[] = "hello world";
+ EXPECT_THAT(write(fd.get(), msg, strlen(msg)),
+ SyscallSucceedsWithValue(strlen(msg)));
+ int offset;
+ EXPECT_THAT(offset = lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds());
+
+ char buf1[1024];
+ EXPECT_THAT(pread64(fd.get(), buf1, 1024, 0),
+ SyscallSucceedsWithValue(strlen(msg)));
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(offset));
+
+ char buf2[1024];
+ EXPECT_THAT(pread64(fd.get(), buf2, 1024, 3),
+ SyscallSucceedsWithValue(strlen(msg) - 3));
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(offset));
+}
+
+TEST_F(Pread64Test, EndOfFile) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(name_, O_RDONLY));
+
+ char buf[1024];
+ EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0));
+}
+
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
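+// memfd_create(2) appeared in Linux 3.17, and glibc gained a wrapper only in
+// 2.27, so the helper above invokes the raw syscall; its semantics are
+// otherwise identical to the wrapper.
+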
+TEST_F(Pread64Test, Overflow) {
+ int f = memfd_create("negative", 0);
+ const FileDescriptor fd(f);
+
+ EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds());
+
+ char buf[10];
+ EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) {
+ int sock_fds[2];
+ EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds());
+
+ char buf[1024];
+ EXPECT_THAT(pread64(sock_fds[0], buf, 1024, 0),
+ SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(pread64(sock_fds[1], buf, 1024, 0),
+ SyscallFailsWithErrno(ESPIPE));
+
+ EXPECT_THAT(close(sock_fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(sock_fds[1]), SyscallSucceeds());
+}
+
+TEST(Pread64TestNoTempFile, CantReadPipe) {
+ char buf[1024];
+
+ int pipe_fds[2];
+ EXPECT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ EXPECT_THAT(pread64(pipe_fds[0], buf, 1024, 0),
+ SyscallFailsWithErrno(ESPIPE));
+
+ EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc
new file mode 100644
index 000000000..5b0743fe9
--- /dev/null
+++ b/test/syscalls/linux/preadv.cc
@@ -0,0 +1,95 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Stress copy-on-write. Attempts to reproduce b/38430174.
+TEST(PreadvTest, MMConcurrencyStress) {
+ // Fill a one-page file with zeroes (the contents don't really matter).
+ const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ /* parent = */ GetAbsoluteTestTmpdir(),
+ /* content = */ std::string(kPageSize, 0), TempPath::kDefaultFileMode));
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Get a one-page private mapping to read to.
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+
+ // Repeatedly fork in a separate thread to force the mapping to become
+ // copy-on-write.
+ std::atomic<bool> done(false);
+ const ScopedThread t([&] {
+ while (!done.load()) {
+ const pid_t pid = fork();
+ TEST_CHECK(pid >= 0);
+ if (pid == 0) {
+ // In child. The parent was obviously multithreaded, so it's neither
+ // safe nor necessary to do much more than exit.
+ syscall(SYS_exit_group, 0);
+ }
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status = " << status;
+ }
+ });
+
+ // Repeatedly read to the mapping.
+ struct iovec iov[2];
+ iov[0].iov_base = m.ptr();
+ iov[0].iov_len = kPageSize / 2;
+ iov[1].iov_base = reinterpret_cast<void*>(m.addr() + kPageSize / 2);
+ iov[1].iov_len = kPageSize / 2;
+ constexpr absl::Duration kTestDuration = absl::Seconds(5);
+ const absl::Time end = absl::Now() + kTestDuration;
+ while (absl::Now() < end) {
+ // Among other causes, save/restore cycles may cause interruptions resulting
+ // in partial reads, so we don't expect any particular return value.
+ EXPECT_THAT(RetryEINTR(preadv)(fd.get(), iov, 2, 0), SyscallSucceeds());
+ }
+
+ // Stop the other thread.
+ done.store(true);
+
+ // The test passes if it neither deadlocks nor crashes the OS.
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc
new file mode 100644
index 000000000..4a9acd7ae
--- /dev/null
+++ b/test/syscalls/linux/preadv2.cc
@@ -0,0 +1,280 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef SYS_preadv2
+#if defined(__x86_64__)
+#define SYS_preadv2 327
+#elif defined(__aarch64__)
+#define SYS_preadv2 286
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_preadv2
+
+#ifndef RWF_HIPRI
+#define RWF_HIPRI 0x1
+#endif // RWF_HIPRI
+
+constexpr int kBufSize = 1024;
+
+std::string SetContent() {
+ std::string content;
+ for (int i = 0; i < kBufSize; i++) {
+ content += static_cast<char>((i % 10) + '0');
+ }
+ return content;
+}
+
+ssize_t preadv2(unsigned long fd, const struct iovec* iov, unsigned long iovcnt,
+ off_t offset, unsigned long flags) {
+ // The raw preadv2 syscall splits the offset into low and high halves (see
+ // syscall(2) and search for preadv2), so we pass 0 for the high half; on a
+ // 64-bit kernel the low half carries the full offset.
+ return syscall(SYS_preadv2, fd, iov, iovcnt, offset, 0, flags);
+}
+
+// This test is the base case where we call preadv (no offset, no flags).
+TEST(Preadv2Test, TestBaseCall) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ std::string content = SetContent();
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), content, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ std::vector<char> buf(kBufSize);
+ struct iovec iov[2];
+ iov[0].iov_base = buf.data();
+ iov[0].iov_len = buf.size() / 2;
+ iov[1].iov_base = static_cast<char*>(iov[0].iov_base) + (content.size() / 2);
+ iov[1].iov_len = content.size() / 2;
+
+ EXPECT_THAT(preadv2(fd.get(), iov, /*iovcnt*/ 2, /*offset=*/0, /*flags=*/0),
+ SyscallSucceedsWithValue(kBufSize));
+
+ EXPECT_EQ(content, std::string(buf.data(), buf.size()));
+}
+
+// This test is where we call preadv with an offset and no flags.
+TEST(Preadv2Test, TestValidPositiveOffset) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ std::string content = SetContent();
+ const std::string prefix = "0";
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), prefix + content, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ std::vector<char> buf(kBufSize, '0');
+ struct iovec iov;
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ EXPECT_THAT(preadv2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/prefix.size(),
+ /*flags=*/0),
+ SyscallSucceedsWithValue(kBufSize));
+
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ EXPECT_EQ(content, std::string(buf.data(), buf.size()));
+}
+
+// This test is the base case where we pass -1 as the offset, making preadv2
+// behave like readv. The read should use the file offset, so the test first
+// advances the offset past a short prefix before calling preadv2.
+TEST(Preadv2Test, TestNegativeOneOffset) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ std::string content = SetContent();
+ const std::string prefix = "231";
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), prefix + content, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ ASSERT_THAT(lseek(fd.get(), prefix.size(), SEEK_SET),
+ SyscallSucceedsWithValue(prefix.size()));
+
+ std::vector<char> buf(kBufSize, '0');
+ struct iovec iov;
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ EXPECT_THAT(preadv2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/-1, /*flags=*/0),
+ SyscallSucceedsWithValue(kBufSize));
+
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(prefix.size() + buf.size()));
+
+ EXPECT_EQ(content, std::string(buf.data(), buf.size()));
+}
+
+// preadv2 requires that, if the RWF_HIPRI flag is passed, the fd be opened
+// with O_DIRECT. This test implements a correct call with the RWF_HIPRI flag.
+TEST(Preadv2Test, TestCallWithRWF_HIPRI) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ std::string content = SetContent();
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), content, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ EXPECT_THAT(fsync(fd.get()), SyscallSucceeds());
+
+ std::vector<char> buf(kBufSize, '0');
+ struct iovec iov;
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ EXPECT_THAT(
+ preadv2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/RWF_HIPRI),
+ SyscallSucceedsWithValue(kBufSize));
+
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+
+ EXPECT_EQ(content, std::string(buf.data(), buf.size()));
+}
+
+// This test calls preadv2 with an invalid flag.
+TEST(Preadv2Test, TestInvalidFlag) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY | O_DIRECT));
+
+ std::vector<char> buf(kBufSize, '0');
+ struct iovec iov;
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ EXPECT_THAT(preadv2(fd.get(), &iov, /*iovcnt=*/1,
+ /*offset=*/0, /*flags=*/0xF0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+// This test calls preadv2 with an invalid offset.
+TEST(Preadv2Test, TestInvalidOffset) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY | O_DIRECT));
+
+ auto iov = absl::make_unique<struct iovec[]>(1);
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 0;
+
+ EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/-8,
+ /*flags=*/0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// This test calls preadv2 on a file opened O_WRONLY.
+TEST(Preadv2Test, TestUnreadableFile) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+
+ auto iov = absl::make_unique<struct iovec[]>(1);
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 0;
+
+ EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1,
+ /*offset=*/0, /*flags=*/0),
+ SyscallFailsWithErrno(EBADF));
+}
+
+// Calling preadv2 with a non-negative offset behaves like preadv, which is
+// not allowed on an unseekable file. A pipe is used as the unseekable file.
+TEST(Preadv2Test, TestUnseekableFileInvalid) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ int pipe_fds[2];
+
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ auto iov = absl::make_unique<struct iovec[]>(1);
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 0;
+
+ EXPECT_THAT(preadv2(pipe_fds[0], iov.get(), /*iovcnt=*/1,
+ /*offset=*/2, /*flags=*/0),
+ SyscallFailsWithErrno(ESPIPE));
+
+ EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+}
+
+TEST(Preadv2Test, TestUnseekableFileValid) {
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ int pipe_fds[2];
+
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ std::vector<char> content(32, 'X');
+
+ EXPECT_THAT(write(pipe_fds[1], content.data(), content.size()),
+ SyscallSucceedsWithValue(content.size()));
+
+ std::vector<char> buf(content.size());
+ auto iov = absl::make_unique<struct iovec[]>(1);
+ iov[0].iov_base = buf.data();
+ iov[0].iov_len = buf.size();
+
+ EXPECT_THAT(preadv2(pipe_fds[0], iov.get(), /*iovcnt=*/1,
+ /*offset=*/static_cast<off_t>(-1), /*flags=*/0),
+ SyscallSucceedsWithValue(buf.size()));
+
+ EXPECT_EQ(content, buf);
+
+ EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/priority.cc b/test/syscalls/linux/priority.cc
new file mode 100644
index 000000000..1d9bdfa70
--- /dev/null
+++ b/test/syscalls/linux/priority.cc
@@ -0,0 +1,216 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <limits.h>
+#include <sys/resource.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// These tests cover both the getpriority(2) and setpriority(2) syscalls.
+// They are rudimentary because getpriority and setpriority have not yet been
+// fully implemented.
+
+// Getpriority is implemented and succeeds for the calling process.
+TEST(GetpriorityTest, Implemented) {
+ // "getpriority() can legitimately return the value -1, it is necessary to
+ // clear the external variable errno prior to the call"
+ errno = 0;
+ EXPECT_THAT(getpriority(PRIO_PROCESS, /*who=*/0), SyscallSucceeds());
+}
+
+// Invalid which
+TEST(GetpriorityTest, InvalidWhich) {
+ errno = 0;
+ EXPECT_THAT(getpriority(/*which=*/3, /*who=*/0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Process is found when which=PRIO_PROCESS
+TEST(GetpriorityTest, ValidWho) {
+ errno = 0;
+ EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()), SyscallSucceeds());
+}
+
+// Process is not found when which=PRIO_PROCESS
+TEST(GetpriorityTest, InvalidWho) {
+ errno = 0;
+ // Flaky, but it's tough to avoid a race condition when finding an unused pid
+ EXPECT_THAT(getpriority(PRIO_PROCESS, /*who=*/INT_MAX - 1),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+// Setpriority is implemented and succeeds for the calling process.
+TEST(SetpriorityTest, Implemented) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ // No need to clear errno for setpriority():
+ // "The setpriority() call returns 0 if there is no error, or -1 if there is"
+ EXPECT_THAT(setpriority(PRIO_PROCESS, /*who=*/0, /*nice=*/16),
+ SyscallSucceeds());
+}
+
+// Invalid which
+TEST(SetpriorityTest, InvalidWhich) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ EXPECT_THAT(setpriority(/*which=*/3, /*who=*/0, /*nice=*/16),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Process is found when which=PRIO_PROCESS
+TEST(SetpriorityTest, ValidWho) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), /*nice=*/16),
+ SyscallSucceeds());
+}
+
+// niceval is within the range [-20, 19]
+TEST(SetpriorityTest, InsideRange) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ // Set 0 < niceval < 19
+ int nice = 12;
+ EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), nice), SyscallSucceeds());
+
+ errno = 0;
+ EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()),
+ SyscallSucceedsWithValue(nice));
+
+ // Set -20 < niceval < 0
+ nice = -12;
+ EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), nice), SyscallSucceeds());
+
+ errno = 0;
+ EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()),
+ SyscallSucceedsWithValue(nice));
+}
+
+// Verify that priority/niceness are exposed via /proc/PID/stat.
+TEST(SetpriorityTest, NicenessExposedViaProcfs) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ constexpr int kNiceVal = 12;
+ ASSERT_THAT(setpriority(PRIO_PROCESS, getpid(), kNiceVal), SyscallSucceeds());
+
+ errno = 0;
+ ASSERT_THAT(getpriority(PRIO_PROCESS, getpid()),
+ SyscallSucceedsWithValue(kNiceVal));
+
+ // Now verify we can read that same value via /proc/self/stat.
+ std::string proc_stat;
+ ASSERT_NO_ERRNO(GetContents("/proc/self/stat", &proc_stat));
+ std::vector<std::string> pieces = absl::StrSplit(proc_stat, ' ');
+ ASSERT_GT(pieces.size(), 20);
+
+ int niceness_procfs = 0;
+ ASSERT_TRUE(absl::SimpleAtoi(pieces[18], &niceness_procfs));
+ EXPECT_EQ(niceness_procfs, kNiceVal);
+}
+
+// In the kernel's implementation, values outside the range of [-20, 19] are
+// truncated to these minimum and maximum values. See
+// https://elixir.bootlin.com/linux/v4.4/source/kernel/sys.c#L190
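+// For example, setpriority(..., 100) stores 19 and setpriority(..., -100)
+// stores -20, as verified below.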
+TEST(SetpriorityTest, OutsideRange) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ // Set niceval > 19
+ EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), /*nice=*/100),
+ SyscallSucceeds());
+
+ errno = 0;
+ // Test niceval truncated to 19
+ EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()),
+ SyscallSucceedsWithValue(/*maxnice=*/19));
+
+ // Set niceval < -20
+ EXPECT_THAT(setpriority(PRIO_PROCESS, getpid(), /*nice=*/-100),
+ SyscallSucceeds());
+
+ errno = 0;
+ // Test niceval truncated to -20
+ EXPECT_THAT(getpriority(PRIO_PROCESS, getpid()),
+ SyscallSucceedsWithValue(/*minnice=*/-20));
+}
+
+// Process is not found when which=PRIO_PROCESS
+TEST(SetpriorityTest, InvalidWho) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ // Flaky, but it's tough to avoid a race condition when finding an unused pid
+ EXPECT_THAT(setpriority(PRIO_PROCESS,
+ /*who=*/INT_MAX - 1,
+ /*nice=*/16),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+// Nice succeeds and correctly modifies (or, in this case, does not modify)
+// the priority of the process.
+TEST(SetpriorityTest, NiceSucceeds) {
+ errno = 0;
+ const int priority_before = getpriority(PRIO_PROCESS, /*who=*/0);
+ ASSERT_THAT(nice(/*inc=*/0), SyscallSucceeds());
+
+ // nice(0) should not change priority
+ EXPECT_EQ(priority_before, getpriority(PRIO_PROCESS, /*who=*/0));
+}
+
+// Threads resulting from clone() maintain parent's priority
+// Changes to child priority do not affect parent's priority
+TEST(GetpriorityTest, CloneMaintainsPriority) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_NICE)));
+
+ constexpr int kParentPriority = 16;
+ constexpr int kChildPriority = 14;
+ ASSERT_THAT(setpriority(PRIO_PROCESS, getpid(), kParentPriority),
+ SyscallSucceeds());
+
+ ScopedThread th([]() {
+ // Check that priority equals that of parent thread
+ pid_t my_tid;
+ EXPECT_THAT(my_tid = syscall(__NR_gettid), SyscallSucceeds());
+ EXPECT_THAT(getpriority(PRIO_PROCESS, my_tid),
+ SyscallSucceedsWithValue(kParentPriority));
+
+ // Change the child thread's priority
+ EXPECT_THAT(setpriority(PRIO_PROCESS, my_tid, kChildPriority),
+ SyscallSucceeds());
+ });
+ th.Join();
+
+  // Check that the parent's priority remained the same even though
+  // the child's priority was altered.
+ EXPECT_EQ(kParentPriority, getpriority(PRIO_PROCESS, syscall(__NR_gettid)));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/priority_execve.cc b/test/syscalls/linux/priority_execve.cc
new file mode 100644
index 000000000..5cb343bad
--- /dev/null
+++ b/test/syscalls/linux/priority_execve.cc
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/resource.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+int main(int argc, char** argv, char** envp) {
+ errno = 0;
+ int prio = getpriority(PRIO_PROCESS, getpid());
+
+ // NOTE: getpriority() can legitimately return negative values
+ // in the range [-20, 0). If errno is set, exit with a value that
+ // could not be reached by a valid priority. Valid exit values
+ // for the test are in the range [1, 40], so we'll use 0.
+ if (errno != 0) {
+ printf("getpriority() failed with errno = %d\n", errno);
+ exit(0);
+ }
+
+ // Used by test to verify priority is being maintained through
+ // calls to execve(). Since prio should always be in the range
+ // [-20, 19], we offset by 20 so as not to have negative exit codes.
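+  // For example, prio == -20 exits with 40, prio == 0 exits with 20, and
+  // prio == 19 exits with 1.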
+ exit(20 - prio);
+
+ return 0;
+}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
new file mode 100644
index 000000000..d6b875dbf
--- /dev/null
+++ b/test/syscalls/linux/proc.cc
@@ -0,0 +1,2173 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <elf.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <sched.h>
+#include <signal.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/prctl.h>
+#include <sys/stat.h>
+#include <sys/utsname.h>
+#include <sys/wait.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <atomic>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <ostream>
+#include <regex>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/time_util.h"
+#include "test/util/timer_util.h"
+
+// NOTE(magi): No, /proc isn't really a syscall, but this is a really simple
+// way to get it tested on gVisor, PTrace, and Linux.
+
+using ::testing::AllOf;
+using ::testing::AnyOf;
+using ::testing::ContainerEq;
+using ::testing::Contains;
+using ::testing::ContainsRegex;
+using ::testing::Eq;
+using ::testing::Gt;
+using ::testing::HasSubstr;
+using ::testing::IsSupersetOf;
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
+
+// Exported by glibc.
+extern char** environ;
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+#ifndef SUID_DUMP_DISABLE
+#define SUID_DUMP_DISABLE 0
+#endif /* SUID_DUMP_DISABLE */
+#ifndef SUID_DUMP_USER
+#define SUID_DUMP_USER 1
+#endif /* SUID_DUMP_USER */
+#ifndef SUID_DUMP_ROOT
+#define SUID_DUMP_ROOT 2
+#endif /* SUID_DUMP_ROOT */
+
+#if defined(__x86_64__) || defined(__i386__)
+// This list of "required" fields is taken from reading the file
+// arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally
+// printed by the kernel.
+static const char* required_fields[] = {
+ "processor",
+ "vendor_id",
+ "cpu family",
+ "model\t\t:",
+ "model name",
+ "stepping",
+ "cpu MHz",
+ "fpu\t\t:",
+ "fpu_exception",
+ "cpuid level",
+ "wp",
+ "bogomips",
+ "clflush size",
+ "cache_alignment",
+ "address sizes",
+ "power management",
+};
+#elif defined(__aarch64__)
+// This list of "required" fields is taken from reading the file
+// arch/arm64/kernel/cpuinfo.c and seeing which fields will be unconditionally
+// printed by the kernel.
+static const char* required_fields[] = {
+ "processor", "BogoMIPS", "Features", "CPU implementer",
+ "CPU architecture", "CPU variant", "CPU part", "CPU revision",
+};
+#else
+#error "Unknown architecture"
+#endif
+
+// Takes the subprocess's pid. If it returns !OK, WithSubprocess returns
+// immediately.
+using SubprocessCallback = std::function<PosixError(int)>;
+
+std::vector<std::string> saved_argv; // NOLINT
+
+// Helper function that dumps /proc/{pid}/status and checks the state data.
+// State should be "Z" for a zombied process, or "RSD" for one that is running
+// (R), in interruptible sleep (S), or in uninterruptible sleep (D).
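+// For example, a running process produces a line like "State:\tR (running)",
+// which the "RSD" pattern matches.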
+void CompareProcessState(absl::string_view state, int pid) {
+ auto status_file = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(absl::StrCat("/proc/", pid, "/status")));
+ // N.B. POSIX extended regexes don't support shorthand character classes (\w)
+ // inside of brackets.
+ EXPECT_THAT(status_file,
+ ContainsRegex(absl::StrCat("State:.[", state,
+ R"EOL(]\s+\([a-zA-Z ]+\))EOL")));
+}
+
+// Run callbacks while a subprocess is running, zombied, and/or exited.
+PosixError WithSubprocess(SubprocessCallback const& running,
+ SubprocessCallback const& zombied,
+ SubprocessCallback const& exited) {
+ int pipe_fds[2] = {};
+ if (pipe(pipe_fds) < 0) {
+ return PosixError(errno, "pipe");
+ }
+
+ int child_pid = fork();
+ if (child_pid < 0) {
+ return PosixError(errno, "fork");
+ }
+
+ if (child_pid == 0) {
+ close(pipe_fds[0]); // Close the read end.
+ const DisableSave ds; // Timing issues.
+
+ // Write to the pipe to tell it we're ready.
+ char buf = 'a';
+ int res = 0;
+ res = WriteFd(pipe_fds[1], &buf, sizeof(buf));
+ TEST_CHECK_MSG(res == sizeof(buf), "Write failure in subprocess");
+
+ while (true) {
+ SleepSafe(absl::Milliseconds(100));
+ }
+ }
+
+ close(pipe_fds[1]); // Close the write end.
+
+ int status = 0;
+ auto wait_cleanup = Cleanup([child_pid, &status] {
+ EXPECT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+ });
+ auto kill_cleanup = Cleanup([child_pid] {
+ EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds());
+ });
+
+ // Wait for the child.
+ char buf = 0;
+ int res = ReadFd(pipe_fds[0], &buf, sizeof(buf));
+ if (res < 0) {
+ return PosixError(errno, "Read from pipe");
+ } else if (res == 0) {
+ return PosixError(EPIPE, "Unable to read from pipe: EOF");
+ }
+
+ if (running) {
+    // The first arg, RSD, refers to a "running process", or a process with a
+    // state of Running (R), Interruptible Sleep (S), or Uninterruptible
+    // Sleep (D).
+ CompareProcessState("RSD", child_pid);
+ RETURN_IF_ERRNO(running(child_pid));
+ }
+
+ // Kill the process.
+ kill_cleanup.Release()();
+ siginfo_t info;
+ // Wait until the child process has exited (WEXITED flag) but don't
+ // reap the child (WNOWAIT flag).
+ EXPECT_THAT(waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED),
+ SyscallSucceeds());
+
+ if (zombied) {
+ // Arg of "Z" refers to a Zombied Process.
+ CompareProcessState("Z", child_pid);
+ RETURN_IF_ERRNO(zombied(child_pid));
+ }
+
+ // Wait on the process.
+ wait_cleanup.Release()();
+  // If the process has been reaped, then this should return with ECHILD.
+ EXPECT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallFailsWithErrno(ECHILD));
+
+ if (exited) {
+ RETURN_IF_ERRNO(exited(child_pid));
+ }
+
+ return NoError();
+}
+
+// Access the file returned by name when a subprocess is running.
+PosixError AccessWhileRunning(std::function<std::string(int pid)> name,
+ int flags, std::function<void(int fd)> access) {
+ FileDescriptor fd;
+ return WithSubprocess(
+ [&](int pid) -> PosixError {
+ // Running.
+ ASSIGN_OR_RETURN_ERRNO(fd, Open(name(pid), flags));
+
+ access(fd.get());
+ return NoError();
+ },
+ nullptr, nullptr);
+}
+
+// Access the file returned by name when a subprocess is zombied.
+PosixError AccessWhileZombied(std::function<std::string(int pid)> name,
+ int flags, std::function<void(int fd)> access) {
+ FileDescriptor fd;
+ return WithSubprocess(
+ [&](int pid) -> PosixError {
+ // Running.
+ ASSIGN_OR_RETURN_ERRNO(fd, Open(name(pid), flags));
+ return NoError();
+ },
+ [&](int pid) -> PosixError {
+ // Zombied.
+ access(fd.get());
+ return NoError();
+ },
+ nullptr);
+}
+
+// Access the file returned by name when a subprocess has exited.
+PosixError AccessWhileExited(std::function<std::string(int pid)> name,
+ int flags, std::function<void(int fd)> access) {
+ FileDescriptor fd;
+ return WithSubprocess(
+ [&](int pid) -> PosixError {
+ // Running.
+ ASSIGN_OR_RETURN_ERRNO(fd, Open(name(pid), flags));
+ return NoError();
+ },
+ nullptr,
+ [&](int pid) -> PosixError {
+ // Exited.
+ access(fd.get());
+ return NoError();
+ });
+}
+
+// ReadFd(fd=/proc/PID/basename) while PID is running.
+int ReadWhileRunning(std::string const& basename, void* buf, size_t count) {
+ int ret = 0;
+ int err = 0;
+ EXPECT_NO_ERRNO(AccessWhileRunning(
+ [&](int pid) -> std::string {
+ return absl::StrCat("/proc/", pid, "/", basename);
+ },
+ O_RDONLY,
+ [&](int fd) {
+ ret = ReadFd(fd, buf, count);
+ err = errno;
+ }));
+ errno = err;
+ return ret;
+}
+
+// ReadFd(fd=/proc/PID/basename) while PID is zombied.
+int ReadWhileZombied(std::string const& basename, void* buf, size_t count) {
+ int ret = 0;
+ int err = 0;
+ EXPECT_NO_ERRNO(AccessWhileZombied(
+ [&](int pid) -> std::string {
+ return absl::StrCat("/proc/", pid, "/", basename);
+ },
+ O_RDONLY,
+ [&](int fd) {
+ ret = ReadFd(fd, buf, count);
+ err = errno;
+ }));
+ errno = err;
+ return ret;
+}
+
+// ReadFd(fd=/proc/PID/basename) while PID is exited.
+int ReadWhileExited(std::string const& basename, void* buf, size_t count) {
+ int ret = 0;
+ int err = 0;
+ EXPECT_NO_ERRNO(AccessWhileExited(
+ [&](int pid) -> std::string {
+ return absl::StrCat("/proc/", pid, "/", basename);
+ },
+ O_RDONLY,
+ [&](int fd) {
+ ret = ReadFd(fd, buf, count);
+ err = errno;
+ }));
+ errno = err;
+ return ret;
+}
+
+// readlinkat(fd=/proc/PID/, basename) while PID is running.
+int ReadlinkWhileRunning(std::string const& basename, char* buf, size_t count) {
+ int ret = 0;
+ int err = 0;
+ EXPECT_NO_ERRNO(AccessWhileRunning(
+ [&](int pid) -> std::string { return absl::StrCat("/proc/", pid, "/"); },
+ O_DIRECTORY,
+ [&](int fd) {
+ ret = readlinkat(fd, basename.c_str(), buf, count);
+ err = errno;
+ }));
+ errno = err;
+ return ret;
+}
+
+// readlinkat(fd=/proc/PID/, basename) while PID is zombied.
+int ReadlinkWhileZombied(std::string const& basename, char* buf, size_t count) {
+ int ret = 0;
+ int err = 0;
+ EXPECT_NO_ERRNO(AccessWhileZombied(
+ [&](int pid) -> std::string { return absl::StrCat("/proc/", pid, "/"); },
+ O_DIRECTORY,
+ [&](int fd) {
+ ret = readlinkat(fd, basename.c_str(), buf, count);
+ err = errno;
+ }));
+ errno = err;
+ return ret;
+}
+
+// readlinkat(fd=/proc/PID/, basename) while PID is exited.
+int ReadlinkWhileExited(std::string const& basename, char* buf, size_t count) {
+ int ret = 0;
+ int err = 0;
+ EXPECT_NO_ERRNO(AccessWhileExited(
+ [&](int pid) -> std::string { return absl::StrCat("/proc/", pid, "/"); },
+ O_DIRECTORY,
+ [&](int fd) {
+ ret = readlinkat(fd, basename.c_str(), buf, count);
+ err = errno;
+ }));
+ errno = err;
+ return ret;
+}
+
+TEST(ProcTest, NotFoundInRoot) {
+ struct stat s;
+ EXPECT_THAT(stat("/proc/foobar", &s), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(ProcSelfTest, IsThreadGroupLeader) {
+ ScopedThread([] {
+ const pid_t tgid = getpid();
+ const pid_t tid = syscall(SYS_gettid);
+ EXPECT_NE(tgid, tid);
+ auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self"));
+ EXPECT_EQ(link, absl::StrCat(tgid));
+ });
+}
+
+TEST(ProcThreadSelfTest, Basic) {
+ const pid_t tgid = getpid();
+ const pid_t tid = syscall(SYS_gettid);
+ EXPECT_EQ(tgid, tid);
+ auto link_threadself =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self"));
+ EXPECT_EQ(link_threadself, absl::StrCat(tgid, "/task/", tid));
+ // Just read one file inside thread-self to ensure that the link is valid.
+ auto link_threadself_exe =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self/exe"));
+ auto link_procself_exe =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe"));
+ EXPECT_EQ(link_threadself_exe, link_procself_exe);
+}
+
+TEST(ProcThreadSelfTest, Thread) {
+ ScopedThread([] {
+ const pid_t tgid = getpid();
+ const pid_t tid = syscall(SYS_gettid);
+ EXPECT_NE(tgid, tid);
+ auto link_threadself =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self"));
+
+ EXPECT_EQ(link_threadself, absl::StrCat(tgid, "/task/", tid));
+ // Just read one file inside thread-self to ensure that the link is valid.
+ auto link_threadself_exe =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/thread-self/exe"));
+ auto link_procself_exe =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe"));
+ EXPECT_EQ(link_threadself_exe, link_procself_exe);
+ // A thread should not have "/proc/<tid>/task".
+ struct stat s;
+ EXPECT_THAT(stat("/proc/thread-self/task", &s),
+ SyscallFailsWithErrno(ENOENT));
+ });
+}
+
+// Returns the /proc/PID/maps entry for a MAP_PRIVATE | MAP_ANONYMOUS mapping
+// with start address addr and length len.
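+// For example, a one-page read-only mapping at 0x400000 renders as
+// "00400000-00401000 r--p 00000000 00:00 0 " (note the trailing space).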
+std::string AnonymousMapsEntry(uintptr_t addr, size_t len, int prot) {
+ return absl::StrCat(absl::Hex(addr, absl::PadSpec::kZeroPad8), "-",
+ absl::Hex(addr + len, absl::PadSpec::kZeroPad8), " ",
+ prot & PROT_READ ? "r" : "-",
+ prot & PROT_WRITE ? "w" : "-",
+ prot & PROT_EXEC ? "x" : "-", "p 00000000 00:00 0 ");
+}
+
+std::string AnonymousMapsEntryForMapping(const Mapping& m, int prot) {
+ return AnonymousMapsEntry(m.addr(), m.len(), prot);
+}
+
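+// Reads /proc/self/auxv and returns a map from auxv entry type (a_type) to
+// value, expecting each type to appear at most once.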
+PosixErrorOr<std::map<uint64_t, uint64_t>> ReadProcSelfAuxv() {
+ std::string auxv_file;
+ RETURN_IF_ERRNO(GetContents("/proc/self/auxv", &auxv_file));
+ const Elf64_auxv_t* auxv_data =
+ reinterpret_cast<const Elf64_auxv_t*>(auxv_file.data());
+ std::map<uint64_t, uint64_t> auxv_entries;
+ for (int i = 0; auxv_data[i].a_type != AT_NULL; i++) {
+ auto a_type = auxv_data[i].a_type;
+ EXPECT_EQ(0, auxv_entries.count(a_type)) << "a_type: " << a_type;
+ auxv_entries.emplace(a_type, auxv_data[i].a_un.a_val);
+ }
+ return auxv_entries;
+}
+
+TEST(ProcSelfAuxv, EntryPresence) {
+ auto auxv_entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfAuxv());
+
+ EXPECT_EQ(auxv_entries.count(AT_ENTRY), 1);
+ EXPECT_EQ(auxv_entries.count(AT_PHDR), 1);
+ EXPECT_EQ(auxv_entries.count(AT_PHENT), 1);
+ EXPECT_EQ(auxv_entries.count(AT_PHNUM), 1);
+ EXPECT_EQ(auxv_entries.count(AT_BASE), 1);
+ EXPECT_EQ(auxv_entries.count(AT_UID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EUID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_GID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EGID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_SECURE), 1);
+ EXPECT_EQ(auxv_entries.count(AT_CLKTCK), 1);
+ EXPECT_EQ(auxv_entries.count(AT_RANDOM), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EXECFN), 1);
+ EXPECT_EQ(auxv_entries.count(AT_PAGESZ), 1);
+ EXPECT_EQ(auxv_entries.count(AT_SYSINFO_EHDR), 1);
+}
+
+TEST(ProcSelfAuxv, EntryValues) {
+ auto proc_auxv = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfAuxv());
+
+  // We need to find the ELF auxiliary vector. The memory pointed to by envp
+  // contains a sequence of non-null pointers terminated by a single null
+  // pointer, followed by the auxiliary vector.
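+  // Layout: envp[0] ... envp[n-1], nullptr, auxv[0] ... auxv[m-1], AT_NULL.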
+ char** envpi = environ;
+ while (*envpi) {
+ ++envpi;
+ }
+
+ const Elf64_auxv_t* envp_auxv =
+ reinterpret_cast<const Elf64_auxv_t*>(envpi + 1);
+ int i;
+ for (i = 0; envp_auxv[i].a_type != AT_NULL; i++) {
+ auto a_type = envp_auxv[i].a_type;
+ EXPECT_EQ(proc_auxv.count(a_type), 1);
+ EXPECT_EQ(proc_auxv[a_type], envp_auxv[i].a_un.a_val)
+ << "a_type: " << a_type;
+ }
+ EXPECT_EQ(i, proc_auxv.size());
+}
+
+// Just open and read /proc/self/maps, and check that we can find [stack].
+TEST(ProcSelfMaps, Basic) {
+ auto proc_self_maps =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+
+ std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n');
+ std::vector<std::string> stacks;
+ // Make sure there's a stack in there.
+ for (const auto& str : strings) {
+ if (str.find("[stack]") != std::string::npos) {
+ stacks.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, stacks.size()) << "[stack] not found in: " << proc_self_maps;
+ // Linux pads to 73 characters then we add 7.
+ EXPECT_EQ(80, stacks[0].length());
+}
+
+TEST(ProcSelfMaps, Map1) {
+ Mapping mapping =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_READ, MAP_PRIVATE));
+ auto proc_self_maps =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n');
+ std::vector<std::string> addrs;
+  // Make sure it is listed.
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(mapping, PROT_READ)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, addrs.size());
+}
+
+TEST(ProcSelfMaps, Map2) {
+ // NOTE(magi): The permissions must be different or the pages will get merged.
+ Mapping map1 = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_EXEC, MAP_PRIVATE));
+ Mapping map2 =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_WRITE, MAP_PRIVATE));
+
+ auto proc_self_maps =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n');
+ std::vector<std::string> addrs;
+  // Make sure it is listed.
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(map1, PROT_READ | PROT_EXEC)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, addrs.size());
+ addrs.clear();
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(map2, PROT_WRITE)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, addrs.size());
+}
+
+TEST(ProcSelfMaps, MapUnmap) {
+ Mapping map1 = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_EXEC, MAP_PRIVATE));
+ Mapping map2 =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_WRITE, MAP_PRIVATE));
+
+ auto proc_self_maps =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n');
+ std::vector<std::string> addrs;
+  // Make sure it is listed.
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(map1, PROT_READ | PROT_EXEC)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, addrs.size()) << proc_self_maps;
+ addrs.clear();
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(map2, PROT_WRITE)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, addrs.size());
+
+ map2.reset();
+
+ // Read it again.
+ proc_self_maps = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ strings = absl::StrSplit(proc_self_maps, '\n');
+ // First entry should be there.
+ addrs.clear();
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(map1, PROT_READ | PROT_EXEC)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(1, addrs.size());
+ addrs.clear();
+ // But not the second.
+ for (const auto& str : strings) {
+ if (str == AnonymousMapsEntryForMapping(map2, PROT_WRITE)) {
+ addrs.push_back(str);
+ }
+ }
+ ASSERT_EQ(0, addrs.size());
+}
+
+TEST(ProcSelfMaps, Mprotect) {
+ // FIXME(jamieliu): Linux's mprotect() sometimes fails to merge VMAs in this
+ // case.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ // Reserve 5 pages of address space.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(5 * kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Change the permissions on the middle 3 pages. (The first and last pages may
+  // be merged with other VMAs on either side, so they aren't tested directly;
+ // they just ensure that the middle 3 pages are bracketed by VMAs with
+ // incompatible permissions.)
+ ASSERT_THAT(mprotect(reinterpret_cast<void*>(m.addr() + kPageSize),
+ 3 * kPageSize, PROT_READ),
+ SyscallSucceeds());
+
+ // Check that the middle 3 pages make up a single VMA.
+ auto proc_self_maps =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ std::vector<std::string> strings = absl::StrSplit(proc_self_maps, '\n');
+ EXPECT_THAT(strings, Contains(AnonymousMapsEntry(m.addr() + kPageSize,
+ 3 * kPageSize, PROT_READ)));
+
+ // Change the permissions on the middle page only.
+ ASSERT_THAT(mprotect(reinterpret_cast<void*>(m.addr() + 2 * kPageSize),
+ kPageSize, PROT_READ | PROT_WRITE),
+ SyscallSucceeds());
+
+ // Check that the single VMA has been split into 3 VMAs.
+ proc_self_maps = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ strings = absl::StrSplit(proc_self_maps, '\n');
+ EXPECT_THAT(
+ strings,
+ IsSupersetOf(
+ {AnonymousMapsEntry(m.addr() + kPageSize, kPageSize, PROT_READ),
+ AnonymousMapsEntry(m.addr() + 2 * kPageSize, kPageSize,
+ PROT_READ | PROT_WRITE),
+ AnonymousMapsEntry(m.addr() + 3 * kPageSize, kPageSize,
+ PROT_READ)}));
+
+ // Change the permissions on the middle page back.
+ ASSERT_THAT(mprotect(reinterpret_cast<void*>(m.addr() + 2 * kPageSize),
+ kPageSize, PROT_READ),
+ SyscallSucceeds());
+
+ // Check that the 3 VMAs have been merged back into a single VMA.
+ proc_self_maps = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ strings = absl::StrSplit(proc_self_maps, '\n');
+ EXPECT_THAT(strings, Contains(AnonymousMapsEntry(m.addr() + kPageSize,
+ 3 * kPageSize, PROT_READ)));
+}
+
+TEST(ProcSelfFd, OpenFd) {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds());
+
+ // Reopen the write end.
+ const std::string path = absl::StrCat("/proc/self/fd/", pipe_fds[1]);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_WRONLY));
+
+ // Ensure that a read/write works.
+ const std::string data = "hello";
+ std::unique_ptr<char[]> buffer(new char[data.size()]);
+ EXPECT_THAT(write(fd.get(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(5));
+ EXPECT_THAT(read(pipe_fds[0], buffer.get(), data.size()),
+ SyscallSucceedsWithValue(5));
+ EXPECT_EQ(strncmp(buffer.get(), data.c_str(), data.size()), 0);
+
+ // Cleanup.
+ ASSERT_THAT(close(pipe_fds[0]), SyscallSucceeds());
+ ASSERT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+}
+
+TEST(ProcSelfFdInfo, CorrectFds) {
+ // Make sure there is at least one open file.
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ // Get files in /proc/self/fd.
+ auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/fd", false));
+
+ // Get files in /proc/self/fdinfo.
+ auto fdinfo_files =
+ ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/fdinfo", false));
+
+ // They should contain the same fds.
+ EXPECT_THAT(fd_files, UnorderedElementsAreArray(fdinfo_files));
+
+ // Both should contain fd.
+ auto fd_s = absl::StrCat(fd.get());
+ EXPECT_THAT(fd_files, Contains(fd_s));
+}
+
+TEST(ProcSelfFdInfo, Flags) {
+ std::string path = NewTempAbsPath();
+
+ // Create file here with O_CREAT to test that O_CREAT does not appear in
+ // fdinfo flags.
+ int flags = O_CREAT | O_RDWR | O_APPEND | O_CLOEXEC;
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, flags, 0644));
+
+ // Automatically delete path.
+ TempPath temp_path(path);
+
+ // O_CREAT does not appear in fdinfo flags.
+ flags &= ~O_CREAT;
+
+ // O_LARGEFILE always appears (on x86_64).
+ flags |= kOLargeFile;
+
+ auto fd_info = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(absl::StrCat("/proc/self/fdinfo/", fd.get())));
+ EXPECT_THAT(fd_info, HasSubstr(absl::StrFormat("flags:\t%#o", flags)));
+}
+
+TEST(ProcSelfExe, Absolute) {
+ auto exe = ASSERT_NO_ERRNO_AND_VALUE(
+ ReadLink(absl::StrCat("/proc/", getpid(), "/exe")));
+ EXPECT_EQ(exe[0], '/');
+}
+
+// Sanity check for /proc/cpuinfo fields that must be present.
+TEST(ProcCpuinfo, RequiredFieldsArePresent) {
+ std::string proc_cpuinfo =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/cpuinfo"));
+ ASSERT_FALSE(proc_cpuinfo.empty());
+
+ // Check that the usual fields are there. We don't really care about the
+ // contents.
+ for (const std::string& field : required_fields) {
+ EXPECT_THAT(proc_cpuinfo, HasSubstr(field));
+ }
+}
+
+TEST(ProcCpuinfo, DeniesWriteNonRoot) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER)));
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test. After calling
+ // setuid(non-zero-UID), there is no way to get root privileges back.
+ ScopedThread([&] {
+ // Use syscall instead of glibc setuid wrapper because we want this setuid
+ // call to only apply to this task. POSIX threads, however, require that all
+ // threads have the same UIDs, so using the setuid wrapper sets all threads'
+ // real UID.
+ // Also drops capabilities.
+ constexpr int kNobody = 65534;
+ EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
+ EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES));
+ // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in
+ // kernfs.
+ if (!IsRunningOnGvisor() || IsRunningWithVFS1()) {
+ EXPECT_THAT(truncate("/proc/cpuinfo", 123),
+ SyscallFailsWithErrno(EACCES));
+ }
+ });
+}
+
+// With root privileges, it is possible to open /proc/cpuinfo with write mode,
+// but all write operations will return EIO.
+TEST(ProcCpuinfo, DeniesWriteRoot) {
+ // VFS1 does not behave differently for root/non-root.
+ SKIP_IF(IsRunningWithVFS1());
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER)));
+
+ int fd;
+ EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds());
+ if (fd > 0) {
+ EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO));
+ EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO));
+ }
+ // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in
+ // kernfs.
+ if (!IsRunningOnGvisor() || IsRunningWithVFS1()) {
+ if (fd > 0) {
+ EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO));
+ }
+ EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO));
+ }
+}
+
+// Sanity checks that uptime is present.
+TEST(ProcUptime, IsPresent) {
+ std::string proc_uptime =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/uptime"));
+ ASSERT_FALSE(proc_uptime.empty());
+ std::vector<std::string> uptime_parts = absl::StrSplit(proc_uptime, ' ');
+
+ // Parse once.
+ double uptime0, uptime1, idletime0, idletime1;
+ ASSERT_TRUE(absl::SimpleAtod(uptime_parts[0], &uptime0));
+ ASSERT_TRUE(absl::SimpleAtod(uptime_parts[1], &idletime0));
+
+ // Sleep for one second.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Parse again.
+ proc_uptime = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/uptime"));
+ ASSERT_FALSE(proc_uptime.empty());
+ uptime_parts = absl::StrSplit(proc_uptime, ' ');
+ ASSERT_TRUE(absl::SimpleAtod(uptime_parts[0], &uptime1));
+ ASSERT_TRUE(absl::SimpleAtod(uptime_parts[1], &idletime1));
+
+ // Sanity check.
+ //
+ // We assert that between 0.99 and 59.99 seconds have passed. If more than a
+ // minute has passed, then we must be executing really, really slowly.
+ EXPECT_GE(uptime0, 0.0);
+ EXPECT_GE(idletime0, 0.0);
+ EXPECT_GT(uptime1, uptime0);
+ EXPECT_GE(uptime1, uptime0 + 0.99);
+ EXPECT_LE(uptime1, uptime0 + 59.99);
+ EXPECT_GE(idletime1, idletime0);
+}
+
+TEST(ProcMeminfo, ContainsBasicFields) {
+ std::string proc_meminfo =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/meminfo"));
+ EXPECT_THAT(proc_meminfo, AllOf(ContainsRegex(R"(MemTotal:\s+[0-9]+ kB)"),
+ ContainsRegex(R"(MemFree:\s+[0-9]+ kB)")));
+}
+
+TEST(ProcStat, ContainsBasicFields) {
+ std::string proc_stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat"));
+
+ std::vector<std::string> names;
+ for (auto const& line : absl::StrSplit(proc_stat, '\n')) {
+ std::vector<std::string> fields =
+ absl::StrSplit(line, ' ', absl::SkipWhitespace());
+ if (fields.empty()) {
+ continue;
+ }
+ names.push_back(fields[0]);
+ }
+
+ EXPECT_THAT(names,
+ IsSupersetOf({"cpu", "intr", "ctxt", "btime", "processes",
+ "procs_running", "procs_blocked", "softirq"}));
+}
+
+TEST(ProcStat, EndsWithNewline) {
+ std::string proc_stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat"));
+ EXPECT_EQ(proc_stat.back(), '\n');
+}
+
+TEST(ProcStat, Fields) {
+ std::string proc_stat = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/stat"));
+
+ std::vector<std::string> names;
+ for (auto const& line : absl::StrSplit(proc_stat, '\n')) {
+ std::vector<std::string> fields =
+ absl::StrSplit(line, ' ', absl::SkipWhitespace());
+ if (fields.empty()) {
+ continue;
+ }
+
+ if (absl::StartsWith(fields[0], "cpu")) {
+ // As of Linux 3.11, each CPU entry has 10 fields, plus the name.
+ EXPECT_GE(fields.size(), 11) << proc_stat;
+ } else if (fields[0] == "ctxt") {
+ // Single field.
+ EXPECT_EQ(fields.size(), 2) << proc_stat;
+ } else if (fields[0] == "btime") {
+ // Single field.
+ EXPECT_EQ(fields.size(), 2) << proc_stat;
+ } else if (fields[0] == "itime") {
+ // Single field.
+ ASSERT_EQ(fields.size(), 2) << proc_stat;
+ // This is the only floating point field.
+ double val;
+ EXPECT_TRUE(absl::SimpleAtod(fields[1], &val)) << proc_stat;
+ continue;
+ } else if (fields[0] == "processes") {
+ // Single field.
+ EXPECT_EQ(fields.size(), 2) << proc_stat;
+ } else if (fields[0] == "procs_running") {
+ // Single field.
+ EXPECT_EQ(fields.size(), 2) << proc_stat;
+ } else if (fields[0] == "procs_blocked") {
+ // Single field.
+ EXPECT_EQ(fields.size(), 2) << proc_stat;
+ } else if (fields[0] == "softirq") {
+      // As of Linux 3.11, there are 10 softirqs; the name and the total bring
+      // it to 12 fields.
+ EXPECT_GE(fields.size(), 12) << proc_stat;
+ }
+
+ // All fields besides itime are valid base 10 numbers.
+ for (size_t i = 1; i < fields.size(); i++) {
+ uint64_t val;
+ EXPECT_TRUE(absl::SimpleAtoi(fields[i], &val)) << proc_stat;
+ }
+ }
+}
+
+TEST(ProcLoadavg, EndsWithNewline) {
+ std::string proc_loadvg =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/loadavg"));
+ EXPECT_EQ(proc_loadvg.back(), '\n');
+}
+
+TEST(ProcLoadavg, Fields) {
+ std::string proc_loadvg =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/loadavg"));
+ std::vector<std::string> lines = absl::StrSplit(proc_loadvg, '\n');
+
+  // Single line (the trailing newline yields an empty second element).
+ EXPECT_EQ(lines.size(), 2) << proc_loadvg;
+
+ std::vector<std::string> fields =
+ absl::StrSplit(lines[0], absl::ByAnyChar(" /"), absl::SkipWhitespace());
+
+ // Six fields.
+ EXPECT_EQ(fields.size(), 6) << proc_loadvg;
+
+ double val;
+ uint64_t val2;
+ // First three fields are floating point numbers.
+ EXPECT_TRUE(absl::SimpleAtod(fields[0], &val)) << proc_loadvg;
+ EXPECT_TRUE(absl::SimpleAtod(fields[1], &val)) << proc_loadvg;
+ EXPECT_TRUE(absl::SimpleAtod(fields[2], &val)) << proc_loadvg;
+ // Rest of the fields are valid base 10 numbers.
+ EXPECT_TRUE(absl::SimpleAtoi(fields[3], &val2)) << proc_loadvg;
+ EXPECT_TRUE(absl::SimpleAtoi(fields[4], &val2)) << proc_loadvg;
+ EXPECT_TRUE(absl::SimpleAtoi(fields[5], &val2)) << proc_loadvg;
+}
+
+// NOTE: Tests in priority.cc also check certain priority related fields in
+// /proc/self/stat.
+
+class ProcPidStatTest : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(ProcPidStatTest, HasBasicFields) {
+ std::string proc_pid_stat = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(absl::StrCat("/proc/", GetParam(), "/stat")));
+
+ ASSERT_FALSE(proc_pid_stat.empty());
+ std::vector<std::string> fields = absl::StrSplit(proc_pid_stat, ' ');
+ ASSERT_GE(fields.size(), 24);
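+  // Field indices (0-based) per proc(5): 0 pid, 1 comm, 2 state, 3 ppid, ...,
+  // 21 starttime, 22 vsize, 23 rss, 24 rsslim.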
+ EXPECT_EQ(absl::StrCat(getpid()), fields[0]);
+ // fields[1] is the thread name.
+ EXPECT_EQ("R", fields[2]); // task state
+ EXPECT_EQ(absl::StrCat(getppid()), fields[3]);
+
+ // If the test starts up quickly, then the process start time and the kernel
+ // boot time will be very close, and the proc starttime field (which is the
+ // delta of the two times) will be 0. For that unfortunate reason, we can
+ // only check that starttime >= 0, and not that it is strictly > 0.
+ uint64_t starttime;
+ ASSERT_TRUE(absl::SimpleAtoi(fields[21], &starttime));
+ EXPECT_GE(starttime, 0);
+
+ uint64_t vss;
+ ASSERT_TRUE(absl::SimpleAtoi(fields[22], &vss));
+ EXPECT_GT(vss, 0);
+
+ uint64_t rss;
+ ASSERT_TRUE(absl::SimpleAtoi(fields[23], &rss));
+ EXPECT_GT(rss, 0);
+
+ uint64_t rsslim;
+ ASSERT_TRUE(absl::SimpleAtoi(fields[24], &rsslim));
+ EXPECT_GT(rsslim, 0);
+}
+
+INSTANTIATE_TEST_SUITE_P(SelfAndNumericPid, ProcPidStatTest,
+ ::testing::Values("self", absl::StrCat(getpid())));
+
+using ProcPidStatmTest = ::testing::TestWithParam<std::string>;
+
+TEST_P(ProcPidStatmTest, HasBasicFields) {
+ std::string proc_pid_statm = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(absl::StrCat("/proc/", GetParam(), "/statm")));
+ ASSERT_FALSE(proc_pid_statm.empty());
+ std::vector<std::string> fields = absl::StrSplit(proc_pid_statm, ' ');
+ ASSERT_GE(fields.size(), 7);
+
+ uint64_t vss;
+ ASSERT_TRUE(absl::SimpleAtoi(fields[0], &vss));
+ EXPECT_GT(vss, 0);
+
+ uint64_t rss;
+ ASSERT_TRUE(absl::SimpleAtoi(fields[1], &rss));
+ EXPECT_GT(rss, 0);
+}
+
+INSTANTIATE_TEST_SUITE_P(SelfAndNumericPid, ProcPidStatmTest,
+ ::testing::Values("self", absl::StrCat(getpid())));
+
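+// Returns this process's resident set size in bytes, parsed from the rss
+// field of /proc/self/stat (field 24 in proc(5), index 23 here).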
+PosixErrorOr<uint64_t> CurrentRSS() {
+ ASSIGN_OR_RETURN_ERRNO(auto proc_self_stat, GetContents("/proc/self/stat"));
+ if (proc_self_stat.empty()) {
+ return PosixError(EINVAL, "empty /proc/self/stat");
+ }
+
+ std::vector<std::string> fields = absl::StrSplit(proc_self_stat, ' ');
+ if (fields.size() < 24) {
+ return PosixError(
+ EINVAL,
+ absl::StrCat("/proc/self/stat has too few fields: ", proc_self_stat));
+ }
+
+ uint64_t rss;
+ if (!absl::SimpleAtoi(fields[23], &rss)) {
+ return PosixError(
+ EINVAL, absl::StrCat("/proc/self/stat RSS field is not a number: ",
+ fields[23]));
+ }
+
+ // RSS is given in number of pages.
+ return rss * kPageSize;
+}
+
+// The size of mapping created by MapPopulateRSS.
+constexpr uint64_t kMappingSize = 100 << 20;
+
+// Tolerance on RSS comparisons to account for background thread mappings,
+// reclaimed pages, newly faulted pages, etc.
+constexpr uint64_t kRSSTolerance = 10 << 20;
+
+// Capture RSS before and after an anonymous mapping with passed prot.
+void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) {
+ *before = ASSERT_NO_ERRNO_AND_VALUE(CurrentRSS());
+
+ // N.B. The kernel asynchronously accumulates per-task RSS counters into the
+ // mm RSS, which is exposed by /proc/PID/stat. Task exit is a synchronization
+ // point (kernel/exit.c:do_exit -> sync_mm_rss), so perform the mapping on
+ // another thread to ensure it is reflected in RSS after the thread exits.
+ Mapping mapping;
+ ScopedThread t([&mapping, prot] {
+ mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kMappingSize, prot, MAP_PRIVATE | MAP_POPULATE));
+ });
+ t.Join();
+
+ *after = ASSERT_NO_ERRNO_AND_VALUE(CurrentRSS());
+}
+
+// TODO(b/73896574): Test for PROT_READ + MAP_POPULATE anonymous mappings. Their
+// semantics are more subtle:
+//
+// Small pages -> Zero page mapped, not counted in RSS
+// (mm/memory.c:do_anonymous_page).
+//
+// Huge pages (THP enabled, use_zero_page=0) -> Pages committed
+// (mm/memory.c:__handle_mm_fault -> create_huge_pmd).
+//
+// Huge pages (THP enabled, use_zero_page=1) -> Zero page mapped, not counted in
+// RSS (mm/huge_memory.c:do_huge_pmd_anonymous_page).
+
+// PROT_WRITE + MAP_POPULATE anonymous mappings are always committed.
+TEST(ProcSelfStat, PopulateWriteRSS) {
+ uint64_t before, after;
+ MapPopulateRSS(PROT_READ | PROT_WRITE, &before, &after);
+
+ // Mapping is committed.
+ EXPECT_NEAR(before + kMappingSize, after, kRSSTolerance);
+}
+
+// PROT_NONE + MAP_POPULATE anonymous mappings are never committed.
+TEST(ProcSelfStat, PopulateNoneRSS) {
+ uint64_t before, after;
+ MapPopulateRSS(PROT_NONE, &before, &after);
+
+ // Mapping not committed.
+ EXPECT_NEAR(before, after, kRSSTolerance);
+}
+
+// Returns the calling thread's name.
+PosixErrorOr<std::string> ThreadName() {
+ // "The buffer should allow space for up to 16 bytes; the returned std::string
+ // will be null-terminated if it is shorter than that." - prctl(2). But we
+ // always want the thread name to be null-terminated.
+ char thread_name[17];
+ int rc = prctl(PR_GET_NAME, thread_name, 0, 0, 0);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "prctl(PR_GET_NAME)");
+ }
+ thread_name[16] = '\0';
+ return std::string(thread_name);
+}
+
+// Parses the contents of a /proc/[pid]/status file into a collection of
+// key-value pairs.
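+// Each line has the form "Key:\tvalue"; for example, "Name:\tinit" parses to
+// the pair ("Name", "init").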
+PosixErrorOr<std::map<std::string, std::string>> ParseProcStatus(
+ absl::string_view status_str) {
+ std::map<std::string, std::string> fields;
+ for (absl::string_view const line :
+ absl::StrSplit(status_str, '\n', absl::SkipWhitespace())) {
+ const std::pair<absl::string_view, absl::string_view> kv =
+ absl::StrSplit(line, absl::MaxSplits(":\t", 1));
+ if (kv.first.empty()) {
+ return PosixError(
+ EINVAL, absl::StrCat("failed to parse key in line \"", line, "\""));
+ }
+ std::string key(kv.first);
+ if (fields.count(key)) {
+ return PosixError(EINVAL,
+ absl::StrCat("duplicate key \"", kv.first, "\""));
+ }
+ std::string value(kv.second);
+ absl::StripLeadingAsciiWhitespace(&value);
+ fields.emplace(std::move(key), std::move(value));
+ }
+ return fields;
+}
+
+TEST(ParseProcStatusTest, ParsesSimpleStatusFileWithMixedWhitespaceCorrectly) {
+ EXPECT_THAT(
+ ParseProcStatus(
+ "Name:\tinit\nState:\tS (sleeping)\nCapEff:\t 0000001fffffffff\n"),
+ IsPosixErrorOkAndHolds(UnorderedElementsAre(
+ Pair("Name", "init"), Pair("State", "S (sleeping)"),
+ Pair("CapEff", "0000001fffffffff"))));
+}
+
+TEST(ParseProcStatusTest, DetectsDuplicateKeys) {
+ auto proc_status_or = ParseProcStatus("Name:\tfoo\nName:\tfoo\n");
+ EXPECT_THAT(proc_status_or,
+ PosixErrorIs(EINVAL, ::testing::StrEq("duplicate key \"Name\"")));
+}
+
+TEST(ParseProcStatusTest, DetectsMissingTabs) {
+ EXPECT_THAT(ParseProcStatus("Name:foo\nPid: 1\n"),
+ IsPosixErrorOkAndHolds(UnorderedElementsAre(Pair("Name:foo", ""),
+ Pair("Pid: 1", ""))));
+}
+
+TEST(ProcPidStatusTest, HasBasicFields) {
+ // Do this on a separate thread since we want tgid != tid.
+ ScopedThread([] {
+ const pid_t tgid = getpid();
+ const pid_t tid = syscall(SYS_gettid);
+ EXPECT_NE(tgid, tid);
+ const auto thread_name = ASSERT_NO_ERRNO_AND_VALUE(ThreadName());
+
+ std::string status_str = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(absl::StrCat("/proc/", tid, "/status")));
+
+ ASSERT_FALSE(status_str.empty());
+ const auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(status_str));
+ EXPECT_THAT(status, IsSupersetOf({Pair("Name", thread_name),
+ Pair("Tgid", absl::StrCat(tgid)),
+ Pair("Pid", absl::StrCat(tid)),
+ Pair("PPid", absl::StrCat(getppid()))}));
+ });
+}
+
+TEST(ProcPidStatusTest, StateRunning) {
+ // Task must be running when reading the file.
+ const pid_t tid = syscall(SYS_gettid);
+ std::string status_str = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(absl::StrCat("/proc/", tid, "/status")));
+
+ EXPECT_THAT(ParseProcStatus(status_str),
+ IsPosixErrorOkAndHolds(Contains(Pair("State", "R (running)"))));
+}
+
+TEST(ProcPidStatusTest, StateSleeping_NoRandomSave) {
+ // Starts a child process that blocks and checks that State is sleeping.
+ auto res = WithSubprocess(
+ [&](int pid) -> PosixError {
+        // Because this test is timing-based, we disable cooperative saving;
+        // the test itself also has random saving disabled.
+ const DisableSave ds;
+        // Try multiple times in case the child isn't sleeping when the status
+        // file is read.
+ MonotonicTimer timer;
+ timer.Start();
+ for (;;) {
+ ASSIGN_OR_RETURN_ERRNO(
+ std::string status_str,
+ GetContents(absl::StrCat("/proc/", pid, "/status")));
+ ASSIGN_OR_RETURN_ERRNO(auto map, ParseProcStatus(status_str));
+ if (map["State"] == std::string("S (sleeping)")) {
+ // Test passed!
+ return NoError();
+ }
+ if (timer.Duration() > absl::Seconds(10)) {
+ return PosixError(ETIMEDOUT, "Timeout waiting for child to sleep");
+ }
+ absl::SleepFor(absl::Milliseconds(10));
+ }
+ },
+ nullptr, nullptr);
+ ASSERT_NO_ERRNO(res);
+}
+
+TEST(ProcPidStatusTest, ValuesAreTabDelimited) {
+ std::string status_str =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/status"));
+ ASSERT_FALSE(status_str.empty());
+ for (absl::string_view const line :
+ absl::StrSplit(status_str, '\n', absl::SkipWhitespace())) {
+ EXPECT_NE(std::string::npos, line.find(":\t"));
+ }
+}
+
+// The Threads field properly counts running threads.
+//
+// TODO(mpratt): Test zombied threads while the thread group leader is still
+// running with generalized fork and clone children from the wait test.
+TEST(ProcPidStatusTest, Threads) {
+ char buf[4096] = {};
+ EXPECT_THAT(ReadWhileRunning("status", buf, sizeof(buf) - 1),
+ SyscallSucceedsWithValue(Gt(0)));
+
+ auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(buf));
+ auto it = status.find("Threads");
+ ASSERT_NE(it, status.end());
+ int threads = -1;
+ EXPECT_TRUE(absl::SimpleAtoi(it->second, &threads))
+ << "Threads value " << it->second << " is not a number";
+ // Don't make assumptions about the exact number of threads, as it may not be
+ // constant.
+ EXPECT_GE(threads, 1);
+
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(ReadWhileZombied("status", buf, sizeof(buf) - 1),
+ SyscallSucceedsWithValue(Gt(0)));
+
+ status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(buf));
+ it = status.find("Threads");
+ ASSERT_NE(it, status.end());
+ threads = -1;
+ EXPECT_TRUE(absl::SimpleAtoi(it->second, &threads))
+ << "Threads value " << it->second << " is not a number";
+ // There must be only the thread group leader remaining, zombied.
+ EXPECT_EQ(threads, 1);
+}
+
+// Returns true if all characters in s are digits.
+bool IsDigits(absl::string_view s) {
+ return std::all_of(s.begin(), s.end(), absl::ascii_isdigit);
+}
+
+TEST(ProcPidStatusTest, VmStats) {
+ std::string status_str =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/status"));
+ ASSERT_FALSE(status_str.empty());
+ auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(status_str));
+
+ const auto vss_it = status.find("VmSize");
+ ASSERT_NE(vss_it, status.end());
+
+ absl::string_view vss_str(vss_it->second);
+
+ // Room for the " kB" suffix plus at least one digit.
+ ASSERT_GT(vss_str.length(), 3);
+ EXPECT_TRUE(absl::EndsWith(vss_str, " kB"));
+ // Everything else is part of a number.
+ EXPECT_TRUE(IsDigits(vss_str.substr(0, vss_str.length() - 3))) << vss_str;
+ // ... which is not 0.
+ EXPECT_NE('0', vss_str[0]);
+
+ const auto rss_it = status.find("VmRSS");
+ ASSERT_NE(rss_it, status.end());
+
+ absl::string_view rss_str(rss_it->second);
+
+ // Room for the " kB" suffix plus at least one digit.
+ ASSERT_GT(rss_str.length(), 3);
+ EXPECT_TRUE(absl::EndsWith(rss_str, " kB"));
+ // Everything else is part of a number.
+ EXPECT_TRUE(IsDigits(rss_str.substr(0, rss_str.length() - 3))) << rss_str;
+ // ... which is not 0.
+ EXPECT_NE('0', rss_str[0]);
+
+ const auto data_it = status.find("VmData");
+ ASSERT_NE(data_it, status.end());
+
+ absl::string_view data_str(data_it->second);
+
+ // Room for the " kB" suffix plus at least one digit.
+ ASSERT_GT(data_str.length(), 3);
+ EXPECT_TRUE(absl::EndsWith(data_str, " kB"));
+ // Everything else is part of a number.
+ EXPECT_TRUE(IsDigits(data_str.substr(0, data_str.length() - 3))) << data_str;
+ // ... which is not 0.
+ EXPECT_NE('0', data_str[0]);
+}
+
+// Parses a buffer of consecutive NUL-terminated strings, returning a vector
+// of strings.
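+// For example, "foo\0bar\0" parses to {"foo", "bar"}.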
+std::vector<std::string> ParseNulTerminatedStrings(std::string contents) {
+ EXPECT_EQ('\0', contents.back());
+ // The split will leave an empty string if the NUL-byte remains, so pop
+ // it.
+ contents.pop_back();
+
+ return absl::StrSplit(contents, '\0');
+}
+
+TEST(ProcPidCmdline, MatchesArgv) {
+ std::vector<std::string> proc_cmdline = ParseNulTerminatedStrings(
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/cmdline")));
+ EXPECT_THAT(saved_argv, ContainerEq(proc_cmdline));
+}
+
+TEST(ProcPidEnviron, MatchesEnviron) {
+ std::vector<std::string> proc_environ = ParseNulTerminatedStrings(
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/environ")));
+ // Get the environment from the environ variable, which we will compare with
+ // /proc/self/environ.
+ std::vector<std::string> env;
+ for (char** v = environ; *v; v++) {
+ env.push_back(*v);
+ }
+ EXPECT_THAT(env, ContainerEq(proc_environ));
+}
+
+TEST(ProcPidCmdline, SubprocessForkSameCmdline) {
+ std::vector<std::string> proc_cmdline_parent;
+ std::vector<std::string> proc_cmdline;
+ proc_cmdline_parent = ParseNulTerminatedStrings(
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/cmdline")));
+ auto res = WithSubprocess(
+ [&](int pid) -> PosixError {
+ ASSIGN_OR_RETURN_ERRNO(
+ auto raw_cmdline,
+ GetContents(absl::StrCat("/proc/", pid, "/cmdline")));
+ proc_cmdline = ParseNulTerminatedStrings(raw_cmdline);
+ return NoError();
+ },
+ nullptr, nullptr);
+ ASSERT_NO_ERRNO(res);
+
+ ASSERT_EQ(proc_cmdline_parent.size(), proc_cmdline.size());
+ for (size_t i = 0; i < proc_cmdline_parent.size(); i++) {
+ EXPECT_EQ(proc_cmdline_parent[i], proc_cmdline[i]);
+ }
+}
+
+// Test whether /proc/PID/ symlinks can be read for a running process.
+TEST(ProcPidSymlink, SubprocessRunning) {
+ char buf[1];
+
+ EXPECT_THAT(ReadlinkWhileRunning("exe", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadlinkWhileRunning("ns/net", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadlinkWhileRunning("ns/pid", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadlinkWhileRunning("ns/user", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST(ProcPidSymlink, SubprocessZombied) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ char buf[1];
+
+ int want = EACCES;
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ if (version.major > 4 || (version.major == 4 && version.minor > 3)) {
+ want = ENOENT;
+ }
+ }
+
+ EXPECT_THAT(ReadlinkWhileZombied("exe", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+
+ if (!IsRunningOnGvisor()) {
+ EXPECT_THAT(ReadlinkWhileZombied("ns/net", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+ }
+
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between Linux kernel
+ // versions on proc files.
+ //
+ // ~4.3: Syscall fails with EACCES.
+ // 4.17: Syscall succeeds and returns 1.
+ //
+ if (!IsRunningOnGvisor()) {
+ return;
+ }
+
+ EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+
+ EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+}
+
+// Test whether /proc/PID/ symlinks can be read for an exited process.
+TEST(ProcPidSymlink, SubprocessExited) {
+ char buf[1];
+
+ EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+
+ EXPECT_THAT(ReadlinkWhileExited("ns/net", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+
+ EXPECT_THAT(ReadlinkWhileExited("ns/pid", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+
+ EXPECT_THAT(ReadlinkWhileExited("ns/user", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+// /proc/PID/exe points to the correct binary.
+TEST(ProcPidExe, Subprocess) {
+ auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe"));
+ auto expected_absolute_path =
+ ASSERT_NO_ERRNO_AND_VALUE(MakeAbsolute(link, ""));
+
+ char actual[PATH_MAX + 1] = {};
+ ASSERT_THAT(ReadlinkWhileRunning("exe", actual, sizeof(actual)),
+ SyscallSucceedsWithValue(Gt(0)));
+ EXPECT_EQ(actual, expected_absolute_path);
+}
+
+// Test whether /proc/PID/ files can be read for a running process.
+TEST(ProcPidFile, SubprocessRunning) {
+ char buf[1];
+
+ EXPECT_THAT(ReadWhileRunning("auxv", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("cmdline", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("comm", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("gid_map", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("io", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("maps", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("stat", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("status", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+// Test whether /proc/PID/ files can be read for a zombie process.
+TEST(ProcPidFile, SubprocessZombie) {
+ char buf[1];
+
+ // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent
+ // behavior on different kernels.
+ //
+ // ~4.3: Succeeds and returns 0.
+ // 4.17: Succeeds and returns 1.
+ // gVisor: Succeeds and returns 0.
+ EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds());
+
+ EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(ReadWhileZombied("comm", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("gid_map", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("maps", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(ReadWhileZombied("stat", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("status", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and
+ // Linux on proc files.
+ //
+ // ~4.3: Fails and returns EACCES.
+ // gVisor & 4.17: Succeeds and returns 1.
+ //
+ // EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)),
+ // SyscallFailsWithErrno(EACCES));
+}
+
+// Test whether /proc/PID/ files can be read for an exited process.
+TEST(ProcPidFile, SubprocessExited) {
+ char buf[1];
+
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels.
+ //
+ // ~4.3: Fails and returns ESRCH.
+ // gVisor: Fails with ESRCH.
+ // 4.17: Succeeds and returns 1.
+ //
+ // EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)),
+ // SyscallFailsWithErrno(ESRCH));
+
+ EXPECT_THAT(ReadWhileExited("cmdline", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("comm", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ EXPECT_THAT(ReadWhileExited("gid_map", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("io", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Returns EOF on gVisor.
+ EXPECT_THAT(ReadWhileExited("maps", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("stat", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("status", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+PosixError DirContainsImpl(absl::string_view path,
+ const std::vector<std::string>& targets,
+ bool strict) {
+ ASSIGN_OR_RETURN_ERRNO(auto listing, ListDir(path, false));
+ bool success = true;
+
+ for (auto& expected_entry : targets) {
+ auto cursor = std::find(listing.begin(), listing.end(), expected_entry);
+ if (cursor == listing.end()) {
+ success = false;
+ }
+ }
+
+ if (!success) {
+ return PosixError(
+ ENOENT,
+ absl::StrCat("Failed to find one or more paths in '", path, "'"));
+ }
+
+ if (strict) {
+ if (targets.size() != listing.size()) {
+ return PosixError(
+ EINVAL,
+ absl::StrCat("Expected to find ", targets.size(), " elements in '",
+ path, "', but found ", listing.size()));
+ }
+ }
+
+ return NoError();
+}
+
+PosixError DirContains(absl::string_view path,
+ const std::vector<std::string>& targets) {
+ return DirContainsImpl(path, targets, false);
+}
+
+PosixError DirContainsExactly(absl::string_view path,
+ const std::vector<std::string>& targets) {
+ return DirContainsImpl(path, targets, true);
+}
+
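+// Polls 'path' until it contains exactly 'targets'. Exited tasks are reaped
+// asynchronously, so their directory entries can linger briefly; retrying
+// absorbs that delay.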
+PosixError EventuallyDirContainsExactly(
+ absl::string_view path, const std::vector<std::string>& targets) {
+ constexpr int kRetryCount = 100;
+ const absl::Duration kRetryDelay = absl::Milliseconds(100);
+
+ for (int i = 0; i < kRetryCount; ++i) {
+ auto res = DirContainsExactly(path, targets);
+ if (res.ok()) {
+ return res;
+ } else if (i < kRetryCount - 1) {
+ // Sleep if this isn't the final iteration.
+ absl::SleepFor(kRetryDelay);
+ }
+ }
+ return PosixError(ETIMEDOUT,
+ "Timed out while waiting for directory to contain files");
+}
+
+TEST(ProcTask, Basic) {
+ EXPECT_NO_ERRNO(
+ DirContains("/proc/self/task", {".", "..", absl::StrCat(getpid())}));
+}
+
+std::vector<std::string> TaskFiles(
+ const std::vector<std::string>& initial_contents,
+ const std::vector<pid_t>& pids) {
+ return VecCat<std::string>(
+ initial_contents,
+ ApplyVec<std::string>([](const pid_t p) { return absl::StrCat(p); },
+ pids));
+}
+
+std::vector<std::string> TaskFiles(const std::vector<pid_t>& pids) {
+ return TaskFiles({".", "..", absl::StrCat(getpid())}, pids);
+}
+
+// Helper class for creating a new task in the current thread group.
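+//
+// The spawned thread blocks until Join() is called (or the object is
+// destroyed), so the new task reliably remains visible in /proc while the
+// test inspects it.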
+class BlockingChild {
+ public:
+ BlockingChild() : thread_([=] { Start(); }) {}
+ ~BlockingChild() { Join(); }
+
+ pid_t Tid() const {
+ absl::MutexLock ml(&mu_);
+ mu_.Await(absl::Condition(&tid_ready_));
+ return tid_;
+ }
+
+ void Join() { Stop(); }
+
+ private:
+ void Start() {
+ absl::MutexLock ml(&mu_);
+ tid_ = syscall(__NR_gettid);
+ tid_ready_ = true;
+ mu_.Await(absl::Condition(&stop_));
+ }
+
+ void Stop() {
+ absl::MutexLock ml(&mu_);
+ stop_ = true;
+ }
+
+ mutable absl::Mutex mu_;
+ bool stop_ ABSL_GUARDED_BY(mu_) = false;
+ pid_t tid_;
+ bool tid_ready_ ABSL_GUARDED_BY(mu_) = false;
+
+ // Must be last to ensure that the destructor for the thread is run before
+ // any other member of the object is destroyed.
+ ScopedThread thread_;
+};
+
+TEST(ProcTask, NewThreadAppears) {
+ auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task", false));
+ BlockingChild child1;
+ EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
+ TaskFiles(initial, {child1.Tid()})));
+}
+
+TEST(ProcTask, KilledThreadsDisappear) {
+ auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task/", false));
+
+ BlockingChild child1;
+ EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
+ TaskFiles(initial, {child1.Tid()})));
+
+ // Stat child1's task file. Regression test for b/32097707.
+ struct stat statbuf;
+ const std::string child1_task_file =
+ absl::StrCat("/proc/self/task/", child1.Tid());
+ EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf), SyscallSucceeds());
+
+ BlockingChild child2;
+ EXPECT_NO_ERRNO(DirContainsExactly(
+ "/proc/self/task", TaskFiles(initial, {child1.Tid(), child2.Tid()})));
+
+ BlockingChild child3;
+ BlockingChild child4;
+ BlockingChild child5;
+ EXPECT_NO_ERRNO(DirContainsExactly(
+ "/proc/self/task",
+ TaskFiles(initial, {child1.Tid(), child2.Tid(), child3.Tid(),
+ child4.Tid(), child5.Tid()})));
+
+ child2.Join();
+ EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
+ "/proc/self/task", TaskFiles(initial, {child1.Tid(), child3.Tid(),
+ child4.Tid(), child5.Tid()})));
+
+ child1.Join();
+ child4.Join();
+ EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
+ "/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()})));
+
+ // Stat child1's task file again. This time it should fail. See b/32097707.
+ EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf),
+ SyscallFailsWithErrno(ENOENT));
+
+ child3.Join();
+ child5.Join();
+ EXPECT_NO_ERRNO(EventuallyDirContainsExactly("/proc/self/task", initial));
+}
+
+TEST(ProcTask, ChildTaskDir) {
+ BlockingChild child1;
+ EXPECT_NO_ERRNO(DirContains("/proc/self/task", TaskFiles({child1.Tid()})));
+ EXPECT_NO_ERRNO(DirContains(absl::StrCat("/proc/", child1.Tid(), "/task"),
+ TaskFiles({child1.Tid()})));
+}
+
+PosixError VerifyPidDir(std::string path) {
+ return DirContains(path, {"exe", "fd", "io", "maps", "ns", "stat", "status"});
+}
+
+TEST(ProcTask, VerifyTaskDir) {
+ EXPECT_NO_ERRNO(VerifyPidDir("/proc/self"));
+
+ EXPECT_NO_ERRNO(VerifyPidDir(absl::StrCat("/proc/self/task/", getpid())));
+ BlockingChild child1;
+ EXPECT_NO_ERRNO(VerifyPidDir(absl::StrCat("/proc/self/task/", child1.Tid())));
+
+ // Only the first level of task directories should contain the 'task'
+ // directory. That is:
+ //
+ // /proc/1234/task <- should exist
+ // /proc/1234/task/1234/task <- should not exist
+ // /proc/1234/task/1235/task <- should not exist (where 1235 is in the same
+ // thread group as 1234).
+ EXPECT_FALSE(
+ DirContains(absl::StrCat("/proc/self/task/", getpid()), {"task"}).ok())
+ << "Found 'task' directory in an inner directory.";
+}
+
+TEST(ProcTask, TaskDirCannotBeDeleted) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ EXPECT_THAT(rmdir("/proc/self/task"), SyscallFails());
+ EXPECT_THAT(rmdir(absl::StrCat("/proc/self/task/", getpid()).c_str()),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(ProcTask, TaskDirHasCorrectMetadata) {
+ struct stat st;
+ EXPECT_THAT(stat("/proc/self/task", &st), SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+
+ // Verify file is readable and executable by everyone.
+ mode_t expected_permissions =
+ S_IRUSR | S_IXUSR | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH;
+ mode_t permissions = st.st_mode & (S_IRWXU | S_IRWXG | S_IRWXO);
+ EXPECT_EQ(expected_permissions, permissions);
+}
+
+TEST(ProcTask, TaskDirCanSeekToEnd) {
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/task", O_RDONLY));
+ EXPECT_THAT(lseek(dirfd.get(), 0, SEEK_END), SyscallSucceeds());
+}
+
+TEST(ProcTask, VerifyTaskDirNlinks) {
+ // A task directory will have 3 links if the task group has a single
+ // thread. For example, the following shows where the links to
+ // '/proc/12345/task' come from for a single-threaded process with pid 12345:
+ //
+ // /proc/12345/task <-- 1 link for the directory itself
+ // . <-- link from "."
+ // ..
+ // 12345
+ // .
+ // .. <-- link from ".." to parent.
+ // <other contents of a task dir>
+ //
+ // We can't assert an absolute number of links since we don't control how many
+ // threads the test framework spawns. Instead, we'll ensure creating a new
+ // thread increases the number of links as expected.
+
+ // Once we reach the test body, we can count on the thread count being stable
+ // unless we spawn a new one.
+ uint64_t initial_links = ASSERT_NO_ERRNO_AND_VALUE(Links("/proc/self/task"));
+ ASSERT_GE(initial_links, 3);
+
+ // For each new subtask, we should gain a new link.
+ BlockingChild child1;
+ EXPECT_THAT(Links("/proc/self/task"),
+ IsPosixErrorOkAndHolds(initial_links + 1));
+ BlockingChild child2;
+ EXPECT_THAT(Links("/proc/self/task"),
+ IsPosixErrorOkAndHolds(initial_links + 2));
+}
+
+TEST(ProcTask, CommContainsThreadNameAndTrailingNewline) {
+ constexpr char kThreadName[] = "TestThread12345";
+ ASSERT_THAT(prctl(PR_SET_NAME, kThreadName), SyscallSucceeds());
+
+ auto thread_name = ASSERT_NO_ERRNO_AND_VALUE(
+ GetContents(JoinPath("/proc", absl::StrCat(getpid()), "task",
+ absl::StrCat(syscall(SYS_gettid)), "comm")));
+ EXPECT_EQ(absl::StrCat(kThreadName, "\n"), thread_name);
+}
+
+TEST(ProcTaskNs, NsDirExistsAndHasCorrectMetadata) {
+ EXPECT_NO_ERRNO(DirContains("/proc/self/ns", {"net", "pid", "user"}));
+
+ // Let's just test the 'pid' entry; all of them are very similar.
+ struct stat st;
+ EXPECT_THAT(lstat("/proc/self/ns/pid", &st), SyscallSucceeds());
+ EXPECT_TRUE(S_ISLNK(st.st_mode));
+
+ auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/ns/pid"));
+ EXPECT_THAT(link, ::testing::StartsWith("pid:["));
+}
+
+TEST(ProcTaskNs, AccessOnNsNodeSucceeds) {
+ EXPECT_THAT(access("/proc/self/ns/pid", F_OK), SyscallSucceeds());
+}
+
+TEST(ProcSysKernelHostname, Exists) {
+ EXPECT_THAT(open("/proc/sys/kernel/hostname", O_RDONLY), SyscallSucceeds());
+}
+
+TEST(ProcSysKernelHostname, MatchesUname) {
+ struct utsname buf;
+ EXPECT_THAT(uname(&buf), SyscallSucceeds());
+ const std::string hostname = absl::StrCat(buf.nodename, "\n");
+ auto procfs_hostname =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/hostname"));
+ EXPECT_EQ(procfs_hostname, hostname);
+}
+
+TEST(ProcSysVmMmapMinAddr, HasNumericValue) {
+ const std::string mmap_min_addr_str =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/vm/mmap_min_addr"));
+ uintptr_t mmap_min_addr;
+ EXPECT_TRUE(absl::SimpleAtoi(mmap_min_addr_str, &mmap_min_addr))
+ << "/proc/sys/vm/mmap_min_addr does not contain a numeric value: "
+ << mmap_min_addr_str;
+}
+
+TEST(ProcSysVmOvercommitMemory, HasNumericValue) {
+ const std::string overcommit_memory_str =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/vm/overcommit_memory"));
+ uintptr_t overcommit_memory;
+ EXPECT_TRUE(absl::SimpleAtoi(overcommit_memory_str, &overcommit_memory))
+ << "/proc/sys/vm/overcommit_memory does not contain a numeric value: "
+ << overcommit_memory_str;
+}
+
+// Check that links for proc fd entries point to the target node, not the
+// symlink itself. Regression test for b/31155070.
+TEST(ProcTaskFd, FstatatFollowsSymlink) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ struct stat sproc = {};
+ EXPECT_THAT(
+ fstatat(-1, absl::StrCat("/proc/self/fd/", fd.get()).c_str(), &sproc, 0),
+ SyscallSucceeds());
+
+ struct stat sfile = {};
+ EXPECT_THAT(fstatat(-1, file.path().c_str(), &sfile, 0), SyscallSucceeds());
+
+ // If fstatat follows the fd symlink, the device and inode numbers should
+ // match at a minimum.
+ EXPECT_EQ(sproc.st_dev, sfile.st_dev);
+ EXPECT_EQ(sproc.st_ino, sfile.st_ino);
+ EXPECT_EQ(0, memcmp(&sfile, &sproc, sizeof(sfile)));
+}
+
+TEST(ProcFilesystems, Bug65172365) {
+ std::string proc_filesystems =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/filesystems"));
+ ASSERT_FALSE(proc_filesystems.empty());
+}
+
+TEST(ProcFilesystems, PresenceOfShmMaxMniAll) {
+ uint64_t shmmax = 0;
+ uint64_t shmall = 0;
+ uint64_t shmmni = 0;
+ std::string proc_file;
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmax"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmax));
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmall"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmall));
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmni"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmni));
+
+ ASSERT_GT(shmmax, 0);
+ ASSERT_GT(shmall, 0);
+ ASSERT_GT(shmmni, 0);
+ ASSERT_LE(shmall, shmmax);
+
+ // These values should never be higher than this by default; for more
+ // information, see uapi/linux/shm.h.
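+ // (The 1 << 24 headroom below ULONG_MAX matches the kernel defaults, which
+ // are chosen so that userspace adjustments to the limits cannot overflow.)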
+ ASSERT_LE(shmmax, ULONG_MAX - (1UL << 24));
+ ASSERT_LE(shmall, ULONG_MAX - (1UL << 24));
+}
+
+// Check that /proc/mounts is a symlink to self/mounts.
+TEST(ProcMounts, IsSymlink) {
+ auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/mounts"));
+ EXPECT_EQ(link, "self/mounts");
+}
+
+TEST(ProcSelfMountinfo, RequiredFieldsArePresent) {
+ auto mountinfo =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo"));
+ EXPECT_THAT(
+ mountinfo,
+ AllOf(
+ // Root mount.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ /\S* / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"),
+ // Proc mount - always rw.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)")));
+}
+
+// Check that /proc/self/mounts looks something like a real mounts file.
+TEST(ProcSelfMounts, RequiredFieldsArePresent) {
+ auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts"));
+ EXPECT_THAT(mounts,
+ AllOf(
+ // Root mount.
+ ContainsRegex(R"(\S+ / \S+ (rw|ro)\S* [0-9]+ [0-9]+\s)"),
+ // Proc mount.
+ ContainsRegex(R"(\S+ /proc \S+ rw\S* [0-9]+ [0-9]+\s)")));
+}
+
+void CheckDuplicatesRecursively(std::string path) {
+ std::vector<std::string> child_dirs;
+
+ // Linux procfs has a known issue: two consecutive calls to readdir can
+ // return the same entry twice if one or more entries are removed from the
+ // directory between the calls.
+ int max_attempts = 5;
+ for (int i = 0; i < max_attempts; i++) {
+ child_dirs.clear();
+ errno = 0;
+ bool success = true;
+ DIR* dir = opendir(path.c_str());
+ if (dir == nullptr) {
+ // Ignore any directories we can't read or missing directories as the
+ // directory could have been deleted/mutated from the time the parent
+ // directory contents were read.
+ return;
+ }
+ auto dir_closer = Cleanup([&dir]() { closedir(dir); });
+ std::unordered_set<std::string> children;
+ while (true) {
+ // Readdir(3): If the end of the directory stream is reached, NULL is
+ // returned and errno is not changed. If an error occurs, NULL is
+ // returned and errno is set appropriately. To distinguish end of stream
+ // from an error, set errno to zero before calling readdir() and then
+ // check the value of errno if NULL is returned.
+ errno = 0;
+ struct dirent* dp = readdir(dir);
+ if (dp == nullptr) {
+ // Linux will return EINVAL when calling getdents on a /proc/tid/net
+ // file corresponding to a zombie task.
+ // See fs/proc/proc_net.c:proc_tgid_net_readdir().
+ //
+ // We just ignore the directory in this case.
+ if (errno == EINVAL && absl::StartsWith(path, "/proc/") &&
+ absl::EndsWith(path, "/net")) {
+ break;
+ }
+
+ // Otherwise, no errors are allowed.
+ ASSERT_EQ(errno, 0) << path;
+ break; // We're done.
+ }
+
+ const std::string name = dp->d_name;
+
+ if (name == "." || name == "..") {
+ continue;
+ }
+
+ // Tolerate a duplicate entry and retry, unless this is the last attempt.
+ if (i == max_attempts - 1) {
+ ASSERT_EQ(children.find(name), children.end())
+ << absl::StrCat(path, "/", name);
+ } else if (children.find(name) != children.end()) {
+ std::cerr << "Duplicate entry: " << i << ":"
+ << absl::StrCat(path, "/", name) << std::endl;
+ success = false;
+ break;
+ }
+ children.insert(name);
+
+ if (dp->d_type == DT_DIR) {
+ child_dirs.push_back(name);
+ }
+ }
+ if (success) {
+ break;
+ }
+ }
+ for (auto dname = child_dirs.begin(); dname != child_dirs.end(); dname++) {
+ CheckDuplicatesRecursively(absl::StrCat(path, "/", *dname));
+ }
+}
+
+TEST(Proc, NoDuplicates) { CheckDuplicatesRecursively("/proc"); }
+
+// Most /proc/PID files are owned by the task user with SUID_DUMP_USER.
+TEST(ProcPid, UserDumpableOwner) {
+ int before;
+ ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds());
+ auto cleanup = Cleanup([before] {
+ ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds());
+
+ // This applies to the task directory itself and files inside.
+ struct stat st;
+ ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds());
+ EXPECT_EQ(st.st_uid, geteuid());
+ EXPECT_EQ(st.st_gid, getegid());
+
+ ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds());
+ EXPECT_EQ(st.st_uid, geteuid());
+ EXPECT_EQ(st.st_gid, getegid());
+}
+
+// /proc/PID files are owned by root with SUID_DUMP_DISABLE.
+TEST(ProcPid, RootDumpableOwner) {
+ int before;
+ ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds());
+ auto cleanup = Cleanup([before] {
+ ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds());
+
+ // This *does not* apply to the task directory itself (or other 0555
+ // directories), but does apply to files inside.
+ struct stat st;
+ ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds());
+ EXPECT_EQ(st.st_uid, geteuid());
+ EXPECT_EQ(st.st_gid, getegid());
+
+ // This file is owned by root. Also allow nobody in case this test is running
+ // in a userns without root mapped.
+ ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds());
+ EXPECT_THAT(st.st_uid, AnyOf(Eq(0), Eq(65534)));
+ EXPECT_THAT(st.st_gid, AnyOf(Eq(0), Eq(65534)));
+}
+
+TEST(Proc, GetdentsEnoent) {
+ FileDescriptor fd;
+ ASSERT_NO_ERRNO(WithSubprocess(
+ [&](int pid) -> PosixError {
+ // Running.
+ ASSIGN_OR_RETURN_ERRNO(fd, Open(absl::StrCat("/proc/", pid, "/task"),
+ O_RDONLY | O_DIRECTORY));
+
+ return NoError();
+ },
+ nullptr, nullptr));
+ char buf[1024];
+ ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
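+// Verifies the "syscw" (write syscall count) line of an io file such as
+// /proc/<pid>/io or /proc/<pid>/task/<tid>/io. The file consists of
+// "name: value" lines, for example:
+//
+//   rchar: 323934931
+//   wchar: 323929600
+//   syscr: 632687
+//   syscw: 632675
+//   ...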
+void CheckSyscwFromIOFile(const std::string& path, const std::string& regex) {
+ std::string output;
+ ASSERT_NO_ERRNO(GetContents(path, &output));
+ ASSERT_THAT(output, ContainsRegex(absl::StrCat("syscw:\\s+", regex, "\n")));
+}
+
+// Checks that IO is accounted separately for each thread/task.
+TEST(Proc, PidTidIOAccounting) {
+ absl::Notification notification;
+
+ // Run a thread with a bunch of writes. Check that the io accounting records
+ // exactly the number of write calls. Opening and closing the file on each
+ // iteration prevents buffering.
+ ScopedThread writer([&notification] {
+ const int num_writes = 100;
+ for (int i = 0; i < num_writes; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_NO_ERRNO(SetContents(path.path(), "a"));
+ }
+ notification.Notify();
+ const std::string writer_dir =
+ absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io");
+
+ CheckSyscwFromIOFile(writer_dir, std::to_string(num_writes));
+ });
+
+ // Run a thread and do no writes. Check that no writes are recorded.
+ ScopedThread noop([&notification] {
+ notification.WaitForNotification();
+ const std::string noop_dir =
+ absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io");
+
+ CheckSyscwFromIOFile(noop_dir, "0");
+ });
+
+ writer.Join();
+ noop.Join();
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ for (int i = 0; i < argc; ++i) {
+ gvisor::testing::saved_argv.emplace_back(std::string(argv[i]));
+ }
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
new file mode 100644
index 000000000..3377b65cf
--- /dev/null
+++ b/test/syscalls/linux/proc_net.cc
@@ -0,0 +1,482 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr const char kProcNet[] = "/proc/net";
+
+TEST(ProcNetSymlinkTarget, FileMode) {
+ struct stat s;
+ ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR);
+ EXPECT_EQ(s.st_mode & 0777, 0555);
+}
+
+TEST(ProcNetSymlink, FileMode) {
+ struct stat s;
+ ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK);
+ EXPECT_EQ(s.st_mode & 0777, 0777);
+}
+
+TEST(ProcNetSymlink, Contents) {
+ char buf[40] = {};
+ int n = readlink(kProcNet, buf, sizeof(buf));
+ ASSERT_THAT(n, SyscallSucceeds());
+
+ buf[n] = 0;
+ EXPECT_STREQ(buf, "self/net");
+}
+
+TEST(ProcNetIfInet6, Format) {
+ auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6"));
+ EXPECT_THAT(ifinet6,
+ ::testing::MatchesRegex(
+ // Ex: "00000000000000000000000000000001 01 80 10 80 lo\n"
+ "^([a-f0-9]{32}( [a-f0-9]{2}){4} +[a-z][a-z0-9]*\n)+$"));
+}
+
+TEST(ProcSysNetIpv4Sack, Exists) {
+ EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_sack", O_RDONLY), SyscallSucceeds());
+}
+
+TEST(ProcSysNetIpv4Sack, CanReadAndWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_DAC_OVERRIDE)));
+
+ auto const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/tcp_sack", O_RDWR));
+
+ char buf;
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_TRUE(buf == '0' || buf == '1') << "unexpected tcp_sack: " << buf;
+
+ char to_write = (buf == '1') ? '0' : '1';
+ EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0),
+ SyscallSucceedsWithValue(sizeof(to_write)));
+
+ buf = 0;
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_EQ(buf, to_write);
+}
+
+// DeviceEntry is an entry in /proc/net/dev
+struct DeviceEntry {
+ std::string name;
+ uint64_t stats[16];
+};
+
+PosixErrorOr<std::vector<DeviceEntry>> GetDeviceMetricsFromProc(
+ const std::string dev) {
+ std::vector<std::string> lines = absl::StrSplit(dev, '\n');
+ std::vector<DeviceEntry> entries;
+
+ // /proc/net/dev prints 2 lines of headers followed by a line of metrics for
+ // each network interface.
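+ //
+ // For example (columns abridged):
+ //
+ //   Inter-|   Receive                    |  Transmit
+ //    face |bytes    packets errs drop ...|bytes    packets errs drop ...
+ //       lo: 1839050   24245    0    0 ...  1839050   24245    0    0 ...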
+ for (unsigned i = 2; i < lines.size(); i++) {
+ // Ignore empty lines.
+ if (lines[i].empty()) {
+ continue;
+ }
+
+ std::vector<std::string> values =
+ absl::StrSplit(lines[i], ' ', absl::SkipWhitespace());
+
+ // Interface name + 16 values.
+ if (values.size() != 17) {
+ return PosixError(EINVAL, "invalid line: " + lines[i]);
+ }
+
+ DeviceEntry entry;
+ entry.name = values[0];
+ // Skip the interface name and read only the values.
+ for (unsigned j = 1; j < 17; j++) {
+ uint64_t num;
+ if (!absl::SimpleAtoi(values[j], &num)) {
+ return PosixError(EINVAL, "invalid value: " + values[j]);
+ }
+ entry.stats[j - 1] = num;
+ }
+
+ entries.push_back(entry);
+ }
+
+ return entries;
+}
+
+// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and
+// contains at least one entry.
+TEST(ProcNetDev, Format) {
+ auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev"));
+ auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev));
+
+ EXPECT_GT(entries.size(), 0);
+}
+
+PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp,
+ const std::string& type,
+ const std::string& item) {
+ std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n');
+
+ // /proc/net/snmp prints a line of headers followed by a line of metrics.
+ // Only search the headers.
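+ //
+ // For example (abridged):
+ //
+ //   Tcp: RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens ...
+ //   Tcp: 1 200 120000 -1 3 7 ...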
+ for (unsigned i = 0; i < snmp_vec.size(); i = i + 2) {
+ if (!absl::StartsWith(snmp_vec[i], type)) continue;
+
+ std::vector<std::string> fields =
+ absl::StrSplit(snmp_vec[i], ' ', absl::SkipWhitespace());
+
+ EXPECT_TRUE((i + 1) < snmp_vec.size());
+ std::vector<std::string> values =
+ absl::StrSplit(snmp_vec[i + 1], ' ', absl::SkipWhitespace());
+
+ EXPECT_TRUE(!fields.empty() && fields.size() == values.size());
+
+ // Metrics start at the first index.
+ for (unsigned j = 1; j < fields.size(); j++) {
+ if (fields[j] == item) {
+ uint64_t val;
+ if (!absl::SimpleAtoi(values[j], &val)) {
+ return PosixError(EINVAL,
+ absl::StrCat("field is not a number: ", values[j]));
+ }
+
+ return val;
+ }
+ }
+ }
+ // Reaching here means the requested metric was not found.
+ return PosixError(
+ EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp));
+}
+
+TEST(ProcNetSnmp, TcpReset_NoRandomSave) {
+ // TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
+ DisableSave ds;
+
+ uint64_t oldAttemptFails;
+ uint64_t oldActiveOpens;
+ uint64_t oldOutRsts;
+ auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ oldOutRsts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts"));
+ oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails"));
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
+
+ struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_port = htons(1234),
+ };
+
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
+ ASSERT_THAT(connect(s.get(), (struct sockaddr*)&sin, sizeof(sin)),
+ SyscallFailsWithErrno(ECONNREFUSED));
+
+ uint64_t newAttemptFails;
+ uint64_t newActiveOpens;
+ uint64_t newOutRsts;
+ snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ newOutRsts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts"));
+ newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails"));
+
+ EXPECT_EQ(oldActiveOpens, newActiveOpens - 1);
+ EXPECT_EQ(oldOutRsts, newOutRsts - 1);
+ EXPECT_EQ(oldAttemptFails, newAttemptFails - 1);
+}
+
+TEST(ProcNetSnmp, TcpEstab_NoRandomSave) {
+ // TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
+ DisableSave ds;
+
+ uint64_t oldEstabResets;
+ uint64_t oldActiveOpens;
+ uint64_t oldPassiveOpens;
+ uint64_t oldCurrEstab;
+ auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens"));
+ oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
+ oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets"));
+
+ FileDescriptor s_listen =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
+ struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_port = 0,
+ };
+
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
+ ASSERT_THAT(bind(s_listen.get(), (struct sockaddr*)&sin, sizeof(sin)),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = sizeof(sin);
+ ASSERT_THAT(
+ getsockname(s_listen.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen),
+ SyscallSucceeds());
+
+ FileDescriptor s_connect =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
+ ASSERT_THAT(connect(s_connect.get(), (struct sockaddr*)&sin, sizeof(sin)),
+ SyscallSucceeds());
+
+ auto s_accept =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(s_listen.get(), nullptr, nullptr));
+
+ uint64_t newEstabResets;
+ uint64_t newActiveOpens;
+ uint64_t newPassiveOpens;
+ uint64_t newCurrEstab;
+ snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens"));
+ newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
+
+ EXPECT_EQ(oldActiveOpens, newActiveOpens - 1);
+ EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1);
+ EXPECT_EQ(oldCurrEstab, newCurrEstab - 2);
+
+ // Send 1 byte from client to server.
+ ASSERT_THAT(send(s_connect.get(), "a", 1, 0), SyscallSucceedsWithValue(1));
+
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+
+ // Wait until server-side fd sees the data on its side but don't read it.
+ struct pollfd poll_fd = {s_accept.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now close server-side fd without reading the data which leads to a RST
+ // packet sent to client side.
+ s_accept.reset(-1);
+
+ // Wait until client-side fd sees RST packet.
+ struct pollfd poll_fd1 = {s_connect.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd1, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now close client-side fd.
+ s_connect.reset(-1);
+
+ // Wait for the netstack to process the connection teardown.
+ absl::SleepFor(absl::Seconds(1));
+
+ snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
+ newEstabResets = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets"));
+
+ EXPECT_EQ(oldCurrEstab, newCurrEstab);
+ EXPECT_EQ(oldEstabResets, newEstabResets - 2);
+}
+
+TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
+ // TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
+ DisableSave ds;
+
+ uint64_t oldOutDatagrams;
+ uint64_t oldNoPorts;
+ auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ oldNoPorts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts"));
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_port = htons(4444),
+ };
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
+ ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)),
+ SyscallSucceedsWithValue(1));
+
+ uint64_t newOutDatagrams;
+ uint64_t newNoPorts;
+ snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ newNoPorts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts"));
+
+ EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1);
+ EXPECT_EQ(oldNoPorts, newNoPorts - 1);
+}
+
+TEST(ProcNetSnmp, UdpIn_NoRandomSave) {
+ // TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
+ const DisableSave ds;
+
+ uint64_t oldOutDatagrams;
+ uint64_t oldInDatagrams;
+ auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams"));
+
+ std::cerr << "snmp: " << std::endl << snmp << std::endl;
+ FileDescriptor server =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ struct sockaddr_in sin = {
+ .sin_family = AF_INET,
+ .sin_port = htons(0),
+ };
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
+ ASSERT_THAT(bind(server.get(), (struct sockaddr*)&sin, sizeof(sin)),
+ SyscallSucceeds());
+ // Get the port bound by the server socket.
+ socklen_t addrlen = sizeof(sin);
+ ASSERT_THAT(
+ getsockname(server.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen),
+ SyscallSucceeds());
+
+ FileDescriptor client =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ ASSERT_THAT(
+ sendto(client.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)),
+ SyscallSucceedsWithValue(1));
+
+ char buf[128];
+ ASSERT_THAT(recvfrom(server.get(), buf, sizeof(buf), 0, NULL, NULL),
+ SyscallSucceedsWithValue(1));
+
+ uint64_t newOutDatagrams;
+ uint64_t newInDatagrams;
+ snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+ std::cerr << "new snmp: " << std::endl << snmp << std::endl;
+ newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams"));
+
+ EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1);
+ EXPECT_EQ(oldInDatagrams, newInDatagrams - 1);
+}
+
+TEST(ProcNetSnmp, CheckNetStat) {
+ // TODO(b/155123175): SNMP and netstat don't work on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/netstat"));
+
+ int name_count = 0;
+ int value_count = 0;
+ std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
+ for (int i = 0; i + 1 < lines.size(); i += 2) {
+ std::vector<absl::string_view> names =
+ absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
+ std::vector<absl::string_view> values =
+ absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
+ EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
+ << "' and '" << lines[i + 1] << "'";
+ for (int j = 0; j < names.size() && j < values.size(); ++j) {
+ if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" ||
+ names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") {
+ ++name_count;
+ int64_t val;
+ if (absl::SimpleAtoi(values[j], &val)) {
+ ++value_count;
+ }
+ }
+ }
+ }
+ EXPECT_EQ(name_count, 4);
+ EXPECT_EQ(value_count, 4);
+}
+
+TEST(ProcNetSnmp, Stat) {
+ struct stat st = {};
+ ASSERT_THAT(stat("/proc/net/snmp", &st), SyscallSucceeds());
+}
+
+TEST(ProcNetSnmp, CheckSnmp) {
+ // TODO(b/155123175): SNMP and netstat don't work on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+
+ int name_count = 0;
+ int value_count = 0;
+ std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
+ for (int i = 0; i + 1 < lines.size(); i += 2) {
+ std::vector<absl::string_view> names =
+ absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
+ std::vector<absl::string_view> values =
+ absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
+ EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
+ << "' and '" << lines[i + 1] << "'";
+ for (int j = 0; j < names.size() && j < values.size(); ++j) {
+ if (names[j] == "RetransSegs") {
+ ++name_count;
+ int64_t val;
+ if (absl::SimpleAtoi(values[j], &val)) {
+ ++value_count;
+ }
+ }
+ }
+ }
+ EXPECT_EQ(name_count, 1);
+ EXPECT_EQ(value_count, 1);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc
new file mode 100644
index 000000000..5b6e3e3cd
--- /dev/null
+++ b/test/syscalls/linux/proc_net_tcp.cc
@@ -0,0 +1,496 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using absl::StrCat;
+using absl::StrSplit;
+
+constexpr char kProcNetTCPHeader[] =
+ " sl local_address rem_address st tx_queue rx_queue tr tm->when "
+ "retrnsmt uid timeout inode "
+ " ";
+
+// TCPEntry represents a single entry from /proc/net/tcp.
+struct TCPEntry {
+ uint32_t local_addr;
+ uint16_t local_port;
+
+ uint32_t remote_addr;
+ uint16_t remote_port;
+
+ uint64_t state;
+ uint64_t uid;
+ uint64_t inode;
+};
+
+// Finds the first entry in 'entries' for which 'predicate' returns true.
+// Returns true on match, and sets 'match' to a copy of the matching entry. If
+// 'match' is null, it's ignored.
+bool FindBy(const std::vector<TCPEntry>& entries, TCPEntry* match,
+ std::function<bool(const TCPEntry&)> predicate) {
+ for (const TCPEntry& entry : entries) {
+ if (predicate(entry)) {
+ if (match != nullptr) {
+ *match = entry;
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+bool FindByLocalAddr(const std::vector<TCPEntry>& entries, TCPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const TCPEntry& e) {
+ return (e.local_addr == host && e.local_port == port);
+ });
+}
+
+bool FindByRemoteAddr(const std::vector<TCPEntry>& entries, TCPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const TCPEntry& e) {
+ return (e.remote_addr == host && e.remote_port == port);
+ });
+}
+
+// Returns a parsed representation of /proc/net/tcp entries.
+PosixErrorOr<std::vector<TCPEntry>> ProcNetTCPEntries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/net/tcp", &content));
+
+ bool found_header = false;
+ std::vector<TCPEntry> entries;
+ std::vector<std::string> lines = StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/net/tcp>" << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (!found_header) {
+ EXPECT_EQ(line, kProcNetTCPHeader);
+ found_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/net/tcp.
+ //
+ // Example entries:
+ //
+ // clang-format off
+ //
+ // sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
+ // 0: 00000000:006F 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 1968 1 0000000000000000 100 0 0 10 0
+ // 1: 0100007F:7533 00000000:0000 0A 00000000:00000000 00:00000000 00000000 120 0 10684 1 0000000000000000 100 0 0 10 0
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
+ //
+ // clang-format on
+
+ TCPEntry entry;
+ std::vector<std::string> fields =
+ StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty());
+
+ ASSIGN_OR_RETURN_ERRNO(entry.local_addr, AtoiBase(fields[1], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_addr, AtoiBase(fields[3], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11]));
+ ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13]));
+
+ entries.push_back(entry);
+ }
+ std::cerr << "<end of /proc/net/tcp>" << std::endl;
+
+ return entries;
+}
+
+TEST(ProcNetTCP, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp"));
+ const std::string header_line = StrCat(kProcNetTCPHeader, "\n");
+ if (IsRunningOnGvisor()) {
+ // Should be just the header since we don't have any tcp sockets yet.
+ EXPECT_EQ(content, header_line);
+ } else {
+ // On a general Linux machine, there could be arbitrary sockets on the
+ // system, so just check the header.
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+ }
+}
+
+TEST(ProcNetTCP, EntryUID) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
+ TCPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()));
+ EXPECT_EQ(e.uid, geteuid());
+ ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr()));
+ EXPECT_EQ(e.uid, geteuid());
+}
+
+TEST(ProcNetTCP, BindAcceptConnect) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
+ // We can only make assertions about the total number of entries if we control
+ // the entire "machine".
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(entries.size(), 2);
+ }
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()));
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr()));
+}
+
+TEST(ProcNetTCP, InodeReasonable) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
+
+ TCPEntry accepted_entry;
+ ASSERT_TRUE(FindByLocalAddr(entries, &accepted_entry, sockets->first_addr()));
+ EXPECT_NE(accepted_entry.inode, 0);
+
+ TCPEntry client_entry;
+ ASSERT_TRUE(FindByRemoteAddr(entries, &client_entry, sockets->first_addr()));
+ EXPECT_NE(client_entry.inode, 0);
+ EXPECT_NE(accepted_entry.inode, client_entry.inode);
+}
+
+TEST(ProcNetTCP, State) {
+ std::unique_ptr<FileDescriptor> server =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create());
+
+ auto test_addr = V4Loopback();
+ ASSERT_THAT(
+ bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ struct sockaddr addr;
+ socklen_t addrlen = sizeof(struct sockaddr);
+ ASSERT_THAT(getsockname(server->get(), &addr, &addrlen), SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr));
+
+ ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds());
+ std::vector<TCPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
+ TCPEntry listen_entry;
+ ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr));
+ EXPECT_EQ(listen_entry.state, TCP_LISTEN);
+
+ std::unique_ptr<FileDescriptor> client =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create());
+ ASSERT_THAT(RetryEINTR(connect)(client->get(), &addr, addrlen),
+ SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
+ ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr));
+ EXPECT_EQ(listen_entry.state, TCP_LISTEN);
+ TCPEntry client_entry;
+ ASSERT_TRUE(FindByRemoteAddr(entries, &client_entry, &addr));
+ EXPECT_EQ(client_entry.state, TCP_ESTABLISHED);
+
+ FileDescriptor accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr));
+
+ const uint32_t accepted_local_host = IPFromInetSockaddr(&addr);
+ const uint16_t accepted_local_port = PortFromInetSockaddr(&addr);
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
+ TCPEntry accepted_entry;
+ ASSERT_TRUE(FindBy(entries, &accepted_entry,
+ [client_entry, accepted_local_host,
+ accepted_local_port](const TCPEntry& e) {
+ return e.local_addr == accepted_local_host &&
+ e.local_port == accepted_local_port &&
+ e.remote_addr == client_entry.local_addr &&
+ e.remote_port == client_entry.local_port;
+ }));
+ EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED);
+}
+
+constexpr char kProcNetTCP6Header[] =
+ " sl local_address remote_address"
+ " st tx_queue rx_queue tr tm->when retrnsmt"
+ " uid timeout inode";
+
+// TCP6Entry represents a single entry from /proc/net/tcp6.
+struct TCP6Entry {
+ struct in6_addr local_addr;
+ uint16_t local_port;
+
+ struct in6_addr remote_addr;
+ uint16_t remote_port;
+
+ uint64_t state;
+ uint64_t uid;
+ uint64_t inode;
+};
+
+bool IPv6AddrEqual(const struct in6_addr* a1, const struct in6_addr* a2) {
+ return memcmp(a1, a2, sizeof(struct in6_addr)) == 0;
+}
+
+// Finds the first entry in 'entries' for which 'predicate' returns true.
+// Returns true on match, and sets 'match' to a copy of the matching entry. If
+// 'match' is null, it's ignored.
+bool FindBy6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
+ std::function<bool(const TCP6Entry&)> predicate) {
+ for (const TCP6Entry& entry : entries) {
+ if (predicate(entry)) {
+ if (match != nullptr) {
+ *match = entry;
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+const struct in6_addr* IP6FromInetSockaddr(const struct sockaddr* addr) {
+ auto* addr6 = reinterpret_cast<const struct sockaddr_in6*>(addr);
+ return &addr6->sin6_addr;
+}
+
+bool FindByLocalAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
+ const struct sockaddr* addr) {
+ const struct in6_addr* local = IP6FromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy6(entries, match, [local, port](const TCP6Entry& e) {
+ return (IPv6AddrEqual(&e.local_addr, local) && e.local_port == port);
+ });
+}
+
+bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
+ const struct sockaddr* addr) {
+ const struct in6_addr* remote = IP6FromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy6(entries, match, [remote, port](const TCP6Entry& e) {
+ return (IPv6AddrEqual(&e.remote_addr, remote) && e.remote_port == port);
+ });
+}
+
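+// Parses an IPv6 address as formatted in /proc/net/tcp6: 32 hex digits
+// encoding four 32-bit words in native byte order.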
+void ReadIPv6Address(std::string s, struct in6_addr* addr) {
+ uint32_t a0, a1, a2, a3;
+ const char* fmt = "%08X%08X%08X%08X";
+ EXPECT_EQ(sscanf(s.c_str(), fmt, &a0, &a1, &a2, &a3), 4);
+
+ uint8_t* b = addr->s6_addr;
+ *((uint32_t*)&b[0]) = a0;
+ *((uint32_t*)&b[4]) = a1;
+ *((uint32_t*)&b[8]) = a2;
+ *((uint32_t*)&b[12]) = a3;
+}
+
+// Returns a parsed representation of /proc/net/tcp6 entries.
+PosixErrorOr<std::vector<TCP6Entry>> ProcNetTCP6Entries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/net/tcp6", &content));
+
+ bool found_header = false;
+ std::vector<TCP6Entry> entries;
+ std::vector<std::string> lines = StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/net/tcp6>" << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (!found_header) {
+ EXPECT_EQ(line, kProcNetTCP6Header);
+ found_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/net/tcp6.
+ //
+ // Example entries:
+ //
+ // clang-format off
+ //
+ // sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
+ // 0: 00000000000000000000000000000000:1F90 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876340 1 ffff8803da9c9380 100 0 0 10 0
+ // 1: 00000000000000000000000000000000:C350 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876987 1 ffff8803ec408000 100 0 0 10 0
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
+ //
+ // clang-format on
+
+ TCP6Entry entry;
+ std::vector<std::string> fields =
+ StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty());
+
+ ReadIPv6Address(fields[1], &entry.local_addr);
+ ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16));
+ ReadIPv6Address(fields[3], &entry.remote_addr);
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11]));
+ ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13]));
+
+ entries.push_back(entry);
+ }
+ std::cerr << "<end of /proc/net/tcp6>" << std::endl;
+
+ return entries;
+}
+
+TEST(ProcNetTCP6, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp6"));
+ const std::string header_line = StrCat(kProcNetTCP6Header, "\n");
+ if (IsRunningOnGvisor()) {
+    // Should be just the header since we don't have any TCP sockets yet.
+ EXPECT_EQ(content, header_line);
+ } else {
+    // On a general Linux machine, there could be arbitrary sockets on the
+    // system, so just check the header.
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+ }
+}
+
+TEST(ProcNetTCP6, EntryUID) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ TCP6Entry e;
+
+ ASSERT_TRUE(FindByLocalAddr6(entries, &e, sockets->first_addr()));
+ EXPECT_EQ(e.uid, geteuid());
+ ASSERT_TRUE(FindByRemoteAddr6(entries, &e, sockets->first_addr()));
+ EXPECT_EQ(e.uid, geteuid());
+}
+
+TEST(ProcNetTCP6, BindAcceptConnect) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ // We can only make assertions about the total number of entries if we control
+ // the entire "machine".
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(entries.size(), 2);
+ }
+
+ EXPECT_TRUE(FindByLocalAddr6(entries, nullptr, sockets->first_addr()));
+ EXPECT_TRUE(FindByRemoteAddr6(entries, nullptr, sockets->first_addr()));
+}
+
+TEST(ProcNetTCP6, InodeReasonable) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+
+ TCP6Entry accepted_entry;
+
+ ASSERT_TRUE(
+ FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr()));
+ EXPECT_NE(accepted_entry.inode, 0);
+
+ TCP6Entry client_entry;
+ ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, sockets->first_addr()));
+ EXPECT_NE(client_entry.inode, 0);
+ EXPECT_NE(accepted_entry.inode, client_entry.inode);
+}
+
+TEST(ProcNetTCP6, State) {
+ std::unique_ptr<FileDescriptor> server =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create());
+
+ auto test_addr = V6Loopback();
+ ASSERT_THAT(
+ bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ struct sockaddr_in6 addr6;
+ socklen_t addrlen = sizeof(struct sockaddr_in6);
+ auto* addr = reinterpret_cast<struct sockaddr*>(&addr6);
+ ASSERT_THAT(getsockname(server->get(), addr, &addrlen), SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ TCP6Entry listen_entry;
+
+ ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr));
+ EXPECT_EQ(listen_entry.state, TCP_LISTEN);
+
+ std::unique_ptr<FileDescriptor> client =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create());
+ ASSERT_THAT(RetryEINTR(connect)(client->get(), addr, addrlen),
+ SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr));
+ EXPECT_EQ(listen_entry.state, TCP_LISTEN);
+ TCP6Entry client_entry;
+ ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, addr));
+ EXPECT_EQ(client_entry.state, TCP_ESTABLISHED);
+
+ FileDescriptor accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr));
+
+ const struct in6_addr* local = IP6FromInetSockaddr(addr);
+ const uint16_t accepted_local_port = PortFromInetSockaddr(addr);
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ TCP6Entry accepted_entry;
+ ASSERT_TRUE(FindBy6(
+ entries, &accepted_entry,
+ [client_entry, local, accepted_local_port](const TCP6Entry& e) {
+ return IPv6AddrEqual(&e.local_addr, local) &&
+ e.local_port == accepted_local_port &&
+ IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) &&
+ e.remote_port == client_entry.local_port;
+ }));
+ EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc
new file mode 100644
index 000000000..786b4b4af
--- /dev/null
+++ b/test/syscalls/linux/proc_net_udp.cc
@@ -0,0 +1,309 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrSplit;
+
+constexpr char kProcNetUDPHeader[] =
+ " sl local_address rem_address st tx_queue rx_queue tr tm->when "
+ "retrnsmt uid timeout inode ref pointer drops ";
+
+// UDPEntry represents a single entry from /proc/net/udp.
+struct UDPEntry {
+ uint32_t local_addr;
+ uint16_t local_port;
+
+ uint32_t remote_addr;
+ uint16_t remote_port;
+
+ uint64_t state;
+ uint64_t uid;
+ uint64_t inode;
+};
+
+std::string DescribeFirstInetSocket(const SocketPair& sockets) {
+ const struct sockaddr* addr = sockets.first_addr();
+ return StrFormat("First test socket: fd:%d %8X:%4X", sockets.first_fd(),
+ IPFromInetSockaddr(addr), PortFromInetSockaddr(addr));
+}
+
+std::string DescribeSecondInetSocket(const SocketPair& sockets) {
+ const struct sockaddr* addr = sockets.second_addr();
+ return StrFormat("Second test socket fd:%d %8X:%4X", sockets.second_fd(),
+ IPFromInetSockaddr(addr), PortFromInetSockaddr(addr));
+}
+
+// Finds the first entry in 'entries' for which 'predicate' returns true.
+// Returns true on match, and sets 'match' to a copy of the matching entry. If
+// 'match' is null, it's ignored.
+bool FindBy(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ std::function<bool(const UDPEntry&)> predicate) {
+ for (const UDPEntry& entry : entries) {
+ if (predicate(entry)) {
+ if (match != nullptr) {
+ *match = entry;
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+bool FindByLocalAddr(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const UDPEntry& e) {
+ return (e.local_addr == host && e.local_port == port);
+ });
+}
+
+bool FindByRemoteAddr(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const UDPEntry& e) {
+ return (e.remote_addr == host && e.remote_port == port);
+ });
+}
+
+PosixErrorOr<uint64_t> InodeFromSocketFD(int fd) {
+ ASSIGN_OR_RETURN_ERRNO(struct stat s, Fstat(fd));
+ if (!S_ISSOCK(s.st_mode)) {
+ return PosixError(EINVAL, StrFormat("FD %d is not a socket", fd));
+ }
+ return s.st_ino;
+}
+
+PosixErrorOr<bool> FindByFD(const std::vector<UDPEntry>& entries,
+ UDPEntry* match, int fd) {
+ ASSIGN_OR_RETURN_ERRNO(uint64_t inode, InodeFromSocketFD(fd));
+ return FindBy(entries, match,
+ [inode](const UDPEntry& e) { return (e.inode == inode); });
+}
+
+// Returns a parsed representation of /proc/net/udp entries.
+PosixErrorOr<std::vector<UDPEntry>> ProcNetUDPEntries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/net/udp", &content));
+
+ bool found_header = false;
+ std::vector<UDPEntry> entries;
+ std::vector<std::string> lines = StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/net/udp>" << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (!found_header) {
+ EXPECT_EQ(line, kProcNetUDPHeader);
+ found_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/net/udp.
+ //
+ // Example entries:
+ //
+ // clang-format off
+ //
+ // sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops
+ // 3503: 0100007F:0035 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 33317 2 0000000000000000 0
+ // 3518: 00000000:0044 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 40394 2 0000000000000000 0
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
+ //
+ // clang-format on
+
+ UDPEntry entry;
+ std::vector<std::string> fields =
+ StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty());
+
+ ASSIGN_OR_RETURN_ERRNO(entry.local_addr, AtoiBase(fields[1], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_addr, AtoiBase(fields[3], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11]));
+ ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13]));
+
+ // Linux shares internal data structures between TCP and UDP sockets. The
+ // proc entries for UDP sockets share some fields with TCP sockets, but
+ // these fields should always be zero as they're not meaningful for UDP
+ // sockets.
+ EXPECT_EQ(fields[8], "00") << StrFormat("sl:%s, tr", fields[0]);
+ EXPECT_EQ(fields[9], "00000000") << StrFormat("sl:%s, tm->when", fields[0]);
+ EXPECT_EQ(fields[10], "00000000")
+ << StrFormat("sl:%s, retrnsmt", fields[0]);
+ EXPECT_EQ(fields[12], "0") << StrFormat("sl:%s, timeout", fields[0]);
+
+ entries.push_back(entry);
+ }
+ std::cerr << "<end of /proc/net/udp>" << std::endl;
+
+ return entries;
+}
+
+TEST(ProcNetUDP, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/udp"));
+ const std::string header_line = StrCat(kProcNetUDPHeader, "\n");
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+}
+
+TEST(ProcNetUDP, EntryUID) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_EQ(e.uid, geteuid());
+ ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_EQ(e.uid, geteuid());
+}
+
+TEST(ProcNetUDP, FindMutualEntries) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeFirstInetSocket(*sockets);
+}
+
+TEST(ProcNetUDP, EntriesRemovedOnClose) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ // First socket's entry should be gone, but the second socket's entry should
+ // still exist.
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ // Both entries should be gone.
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+}
+
+PosixErrorOr<std::unique_ptr<FileDescriptor>> BoundUDPSocket() {
+ ASSIGN_OR_RETURN_ERRNO(std::unique_ptr<FileDescriptor> socket,
+ IPv4UDPUnboundSocket(0).Create());
+ struct sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_ANY);
+ addr.sin_port = 0;
+
+ int res = bind(socket->get(), reinterpret_cast<const struct sockaddr*>(&addr),
+ sizeof(addr));
+ if (res) {
+ return PosixError(errno, "bind()");
+ }
+ return socket;
+}
+
+TEST(ProcNetUDP, BoundEntry) {
+ std::unique_ptr<FileDescriptor> socket =
+ ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket());
+ struct sockaddr addr;
+ socklen_t len = sizeof(addr);
+ ASSERT_THAT(getsockname(socket->get(), &addr, &len), SyscallSucceeds());
+ uint16_t port = PortFromInetSockaddr(&addr);
+
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get())));
+ EXPECT_EQ(e.local_port, port);
+ EXPECT_EQ(e.remote_addr, 0);
+ EXPECT_EQ(e.remote_port, 0);
+}
+
+TEST(ProcNetUDP, BoundSocketStateClosed) {
+ std::unique_ptr<FileDescriptor> socket =
+ ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get())));
+ EXPECT_EQ(e.state, TCP_CLOSE);
+}
+
+TEST(ProcNetUDP, ConnectedSocketStateEstablished) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ UDPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_EQ(e.state, TCP_ESTABLISHED);
+
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_EQ(e.state, TCP_ESTABLISHED);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
new file mode 100644
index 000000000..a63067586
--- /dev/null
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -0,0 +1,443 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using absl::StrCat;
+using absl::StreamFormat;
+using absl::StrFormat;
+
+constexpr char kProcNetUnixHeader[] =
+ "Num RefCount Protocol Flags Type St Inode Path";
+
+// Possible values of the "st" field in a /proc/net/unix entry. Source: Linux
+// kernel, include/uapi/linux/net.h.
+enum {
+ SS_FREE = 0, // Not allocated
+ SS_UNCONNECTED, // Unconnected to any socket
+ SS_CONNECTING, // In process of connecting
+ SS_CONNECTED, // Connected to socket
+ SS_DISCONNECTING // In process of disconnecting
+};
+
+// UnixEntry represents a single entry from /proc/net/unix.
+struct UnixEntry {
+ uintptr_t addr;
+ uint64_t refs;
+ uint64_t protocol;
+ uint64_t flags;
+ uint64_t type;
+ uint64_t state;
+ uint64_t inode;
+ std::string path;
+};
+
+// Abstract socket paths can have either trailing null bytes or '@'s as padding
+// at the end, depending on the Linux version. This function strips any such
+// padding.
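+//
+// For example (illustrative): both "@foo\0\0\0" and "@foo@@@" strip to
+// "@foo", while a path with no trailing padding is returned unchanged.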
+void StripAbstractPathPadding(std::string* s) {
+ const char pad_char = s->back();
+ if (pad_char != '\0' && pad_char != '@') {
+ return;
+ }
+
+ const auto last_pos = s->find_last_not_of(pad_char);
+ if (last_pos != std::string::npos) {
+ s->resize(last_pos + 1);
+ }
+}
+
+// Precondition: addr must be a unix socket address (i.e. sockaddr_un) and
+// addr->sun_path must be null-terminated. This is always the case if addr comes
+// from Linux:
+//
+// Per man unix(7):
+//
+// "When the address of a pathname socket is returned (by [getsockname(2)]), its
+// length is
+//
+// offsetof(struct sockaddr_un, sun_path) + strlen(sun_path) + 1
+//
+// and sun_path contains the null-terminated pathname."
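+//
+// For example (illustrative): a sun_path of "/tmp/foo" is returned as-is,
+// while an abstract address whose sun_path begins with a NUL byte followed by
+// "foo" is rendered as "@foo", matching the /proc/net/unix convention.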
+std::string ExtractPath(const struct sockaddr* addr) {
+ const char* path =
+ reinterpret_cast<const struct sockaddr_un*>(addr)->sun_path;
+ // Note: sockaddr_un.sun_path is an embedded character array of length
+ // UNIX_PATH_MAX, so we can always safely dereference the first 2 bytes below.
+ //
+ // We also rely on the path being null-terminated.
+ if (path[0] == 0) {
+ std::string abstract_path = StrCat("@", &path[1]);
+ StripAbstractPathPadding(&abstract_path);
+ return abstract_path;
+ }
+ return std::string(path);
+}
+
+// Returns a parsed representation of /proc/net/unix entries.
+PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/net/unix", &content));
+
+ bool skipped_header = false;
+ std::vector<UnixEntry> entries;
+ std::vector<std::string> lines = absl::StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/net/unix>" << std::endl;
+ for (const std::string& line : lines) {
+ // Emit the proc entry to the test output to provide context for the test
+ // results.
+ std::cerr << line << std::endl;
+
+ if (!skipped_header) {
+ EXPECT_EQ(line, kProcNetUnixHeader);
+ skipped_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/net/unix.
+ //
+ // Sample file:
+ //
+ // clang-format off
+ //
+    // Num       RefCount Protocol Flags    Type St Inode Path
+ // ffffa130e7041c00: 00000002 00000000 00010000 0001 01 1299413685 /tmp/control_server/13293772586877554487
+ // ffffa14f547dc400: 00000002 00000000 00010000 0001 01 3793 @remote_coredump
+ //
+ // clang-format on
+ //
+    // Note that, as in the second entry above, the inode number can be padded
+    // with spaces, so we need to handle it separately during parsing. See
+    // net/unix/af_unix.c:unix_seq_show() for how these entries are produced.
+    // In particular, only the inode field is padded with spaces.
+ UnixEntry entry;
+
+ // Process the first 6 fields, up to but not including "Inode".
+ std::vector<std::string> fields =
+ absl::StrSplit(line, absl::MaxSplits(' ', 6));
+
+ if (fields.size() < 7) {
+ return PosixError(EINVAL, StrFormat("Invalid entry: '%s'\n", line));
+ }
+
+ // AtoiBase can't handle the ':' in the "Num" field, so strip it out.
+ std::vector<std::string> addr = absl::StrSplit(fields[0], ':');
+ ASSIGN_OR_RETURN_ERRNO(entry.addr, AtoiBase(addr[0], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.refs, AtoiBase(fields[1], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.protocol, AtoiBase(fields[2], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.flags, AtoiBase(fields[3], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.type, AtoiBase(fields[4], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+
+ absl::string_view rest = absl::StripAsciiWhitespace(fields[6]);
+ fields = absl::StrSplit(rest, absl::MaxSplits(' ', 1));
+ if (fields.empty()) {
+ return PosixError(
+ EINVAL, StrFormat("Invalid entry, missing 'Inode': '%s'\n", line));
+ }
+ ASSIGN_OR_RETURN_ERRNO(entry.inode, AtoiBase(fields[0], 10));
+
+ entry.path = "";
+ if (fields.size() > 1) {
+ entry.path = fields[1];
+ StripAbstractPathPadding(&entry.path);
+ }
+
+ entries.push_back(entry);
+ }
+ std::cerr << "<end of /proc/net/unix>" << std::endl;
+
+ return entries;
+}
+
+// Finds the first entry in 'entries' for which 'predicate' returns true.
+// Returns true on match, and sets 'match' to a copy of the matching entry.
+bool FindBy(const std::vector<UnixEntry>& entries, UnixEntry* match,
+            std::function<bool(const UnixEntry&)> predicate) {
+  for (const UnixEntry& entry : entries) {
+    if (predicate(entry)) {
+      *match = entry;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool FindByPath(const std::vector<UnixEntry>& entries, UnixEntry* match,
+                const std::string& path) {
+ return FindBy(entries, match,
+ [path](const UnixEntry& e) { return e.path == path; });
+}
+
+TEST(ProcNetUnix, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/unix"));
+ const std::string header_line = StrCat(kProcNetUnixHeader, "\n");
+ if (IsRunningOnGvisor()) {
+    // Should be just the header since we don't have any Unix domain sockets
+ // yet.
+ EXPECT_EQ(content, header_line);
+ } else {
+    // However, on a general Linux machine, there could be arbitrary sockets
+    // on the system, so just check the header.
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+ }
+}
+
+TEST(ProcNetUnix, FilesystemBindAcceptConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ FilesystemBoundUnixDomainSocketPair(SOCK_STREAM).Create());
+
+ std::string path1 = ExtractPath(sockets->first_addr());
+ std::string path2 = ExtractPath(sockets->second_addr());
+ std::cerr << StreamFormat("Server socket address (path1): %s\n", path1);
+ std::cerr << StreamFormat("Client socket address (path2): %s\n", path2);
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(entries.size(), 2);
+ }
+
+ // The server-side socket's path is listed in the socket entry...
+ UnixEntry s1;
+ EXPECT_TRUE(FindByPath(entries, &s1, path1));
+
+ // ... but the client-side socket's path is not.
+ UnixEntry s2;
+ EXPECT_FALSE(FindByPath(entries, &s2, path2));
+}
+
+TEST(ProcNetUnix, AbstractBindAcceptConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ AbstractBoundUnixDomainSocketPair(SOCK_STREAM).Create());
+
+ std::string path1 = ExtractPath(sockets->first_addr());
+ std::string path2 = ExtractPath(sockets->second_addr());
+ std::cerr << StreamFormat("Server socket address (path1): '%s'\n", path1);
+ std::cerr << StreamFormat("Client socket address (path2): '%s'\n", path2);
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(entries.size(), 2);
+ }
+
+ // The server-side socket's path is listed in the socket entry...
+ UnixEntry s1;
+ EXPECT_TRUE(FindByPath(entries, &s1, path1));
+
+ // ... but the client-side socket's path is not.
+ UnixEntry s2;
+ EXPECT_FALSE(FindByPath(entries, &s2, path2));
+}
+
+TEST(ProcNetUnix, SocketPair) {
+  // Under gVisor, ensure a socketpair() syscall creates exactly 2 new
+ // entries. We have no way to verify this under Linux, as we have no control
+ // over socket creation on a general Linux machine.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ ASSERT_EQ(entries.size(), 0);
+
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_STREAM).Create());
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ EXPECT_EQ(entries.size(), 2);
+}
+
+TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+
+ const std::string address = ExtractPath(sockets->first_addr());
+ UnixEntry bind_entry;
+ ASSERT_TRUE(FindByPath(entries, &bind_entry, address));
+ EXPECT_EQ(bind_entry.state, SS_UNCONNECTED);
+}
+
+TEST(ProcNetUnix, StreamSocketStateUnconnectedOnListen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+
+ const std::string address = ExtractPath(sockets->first_addr());
+ UnixEntry bind_entry;
+ ASSERT_TRUE(FindByPath(entries, &bind_entry, address));
+ EXPECT_EQ(bind_entry.state, SS_UNCONNECTED);
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ UnixEntry listen_entry;
+ ASSERT_TRUE(
+ FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr())));
+ EXPECT_EQ(listen_entry.state, SS_UNCONNECTED);
+ // The bind and listen entries should refer to the same socket.
+ EXPECT_EQ(listen_entry.inode, bind_entry.inode);
+}
+
+TEST(ProcNetUnix, StreamSocketStateConnectedOnAccept) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create());
+ const std::string address = ExtractPath(sockets->first_addr());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ UnixEntry listen_entry;
+ ASSERT_TRUE(
+ FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr())));
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int clientfd;
+ ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallSucceeds());
+
+ // Find the entry for the accepted socket. UDS proc entries don't have a
+ // remote address, so we distinguish the accepted socket from the listen
+ // socket by checking for a different inode.
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ UnixEntry accept_entry;
+ ASSERT_TRUE(FindBy(
+ entries, &accept_entry, [address, listen_entry](const UnixEntry& e) {
+ return e.path == address && e.inode != listen_entry.inode;
+ }));
+ EXPECT_EQ(accept_entry.state, SS_CONNECTED);
+ // Listen entry should still be in SS_UNCONNECTED state.
+ ASSERT_TRUE(FindBy(entries, &listen_entry,
+ [&sockets, listen_entry](const UnixEntry& e) {
+ return e.path == ExtractPath(sockets->first_addr()) &&
+ e.inode == listen_entry.inode;
+ }));
+ EXPECT_EQ(listen_entry.state, SS_UNCONNECTED);
+}
+
+TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create());
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+
+ // On gVisor, the only two UDS on the system are the ones we just created and
+ // we rely on this to locate the test socket entries in the remainder of the
+ // test. On a generic Linux system, we have no easy way to locate the
+ // corresponding entries, as they don't have an address yet.
+ if (IsRunningOnGvisor()) {
+ ASSERT_EQ(entries.size(), 2);
+ for (const auto& e : entries) {
+ ASSERT_EQ(e.state, SS_DISCONNECTING);
+ }
+ }
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ const std::string address = ExtractPath(sockets->first_addr());
+ UnixEntry bind_entry;
+ ASSERT_TRUE(FindByPath(entries, &bind_entry, address));
+ EXPECT_EQ(bind_entry.state, SS_UNCONNECTED);
+}
+
+TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
+ AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create());
+
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+
+ // On gVisor, the only two UDS on the system are the ones we just created and
+ // we rely on this to locate the test socket entries in the remainder of the
+ // test. On a generic Linux system, we have no easy way to locate the
+ // corresponding entries, as they don't have an address yet.
+ if (IsRunningOnGvisor()) {
+ ASSERT_EQ(entries.size(), 2);
+ for (const auto& e : entries) {
+ ASSERT_EQ(e.state, SS_DISCONNECTING);
+ }
+ }
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ const std::string address = ExtractPath(sockets->first_addr());
+ UnixEntry bind_entry;
+ ASSERT_TRUE(FindByPath(entries, &bind_entry, address));
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+
+ // Once again, we have no easy way to identify the connecting socket as it has
+ // no listed address. We can only identify the entry as the "non-bind socket
+ // entry" on gVisor, where we're guaranteed to have only the two entries we
+ // create during this test.
+ if (IsRunningOnGvisor()) {
+ ASSERT_EQ(entries.size(), 2);
+ UnixEntry connect_entry;
+ ASSERT_TRUE(
+ FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) {
+ return e.inode != bind_entry.inode;
+ }));
+ EXPECT_EQ(connect_entry.state, SS_CONNECTING);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc
new file mode 100644
index 000000000..707821a3f
--- /dev/null
+++ b/test/syscalls/linux/proc_pid_oomscore.cc
@@ -0,0 +1,72 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+
+#include <exception>
+#include <iostream>
+#include <string>
+
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
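+// Reads a single decimal integer from the proc file at 'path'. For example
+// (illustrative), a file containing "0\n" parses to 0, since SimpleAtoi
+// tolerates the trailing whitespace.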
+PosixErrorOr<int> ReadProcNumber(std::string path) {
+ ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path));
+ EXPECT_EQ(contents[contents.length() - 1], '\n');
+
+ int num;
+ if (!absl::SimpleAtoi(contents, &num)) {
+ return PosixError(EINVAL, "invalid value: " + contents);
+ }
+
+ return num;
+}
+
+TEST(ProcPidOomscoreTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score"));
+ EXPECT_LE(oom_score, 1000);
+ EXPECT_GE(oom_score, -1000);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+
+ // oom_score_adj defaults to 0.
+ EXPECT_EQ(oom_score, 0);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicWrite) {
+ constexpr int test_value = 7;
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY));
+ ASSERT_THAT(
+ RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1),
+ SyscallSucceeds());
+
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+ EXPECT_EQ(oom_score, test_value);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc
new file mode 100644
index 000000000..9fb1b3a2c
--- /dev/null
+++ b/test/syscalls/linux/proc_pid_smaps.cc
@@ -0,0 +1,468 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+using ::testing::Contains;
+using ::testing::ElementsAreArray;
+using ::testing::IsSupersetOf;
+using ::testing::Not;
+using ::testing::Optional;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+struct ProcPidSmapsEntry {
+ ProcMapsEntry maps_entry;
+
+ // These fields should always exist, as they were included in e070ad49f311
+ // "[PATCH] add /proc/pid/smaps".
+ size_t size_kb;
+ size_t rss_kb;
+ size_t shared_clean_kb;
+ size_t shared_dirty_kb;
+ size_t private_clean_kb;
+ size_t private_dirty_kb;
+
+ // These fields were added later and may not be present.
+ absl::optional<size_t> pss_kb;
+ absl::optional<size_t> referenced_kb;
+ absl::optional<size_t> anonymous_kb;
+ absl::optional<size_t> anon_huge_pages_kb;
+ absl::optional<size_t> shared_hugetlb_kb;
+ absl::optional<size_t> private_hugetlb_kb;
+ absl::optional<size_t> swap_kb;
+ absl::optional<size_t> swap_pss_kb;
+ absl::optional<size_t> kernel_page_size_kb;
+ absl::optional<size_t> mmu_page_size_kb;
+ absl::optional<size_t> locked_kb;
+
+ // Caution: "Note that there is no guarantee that every flag and associated
+ // mnemonic will be present in all further kernel releases. Things get
+ // changed, the flags may be vanished or the reverse -- new added." - Linux
+ // Documentation/filesystems/proc.txt, on VmFlags. Avoid checking for any
+ // flags that are not extremely well-established.
+ absl::optional<std::vector<std::string>> vm_flags;
+};
+
+// Given the value part of a /proc/[pid]/smaps field containing a value in kB
+// (for example, " 4 kB", returns the value in kB (in this example, 4).
+PosixErrorOr<size_t> SmapsValueKb(absl::string_view value) {
+  // TODO(jamieliu): use RE2 or <regex> instead of ad hoc parsing.
+ std::pair<absl::string_view, absl::string_view> parts =
+ absl::StrSplit(value, ' ', absl::SkipEmpty());
+ if (parts.second != "kB") {
+ return PosixError(EINVAL,
+ absl::StrCat("invalid smaps field value: ", value));
+ }
+ ASSIGN_OR_RETURN_ERRNO(auto val_kb, Atoi<size_t>(parts.first));
+ return val_kb;
+}
+
+PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
+ absl::string_view contents) {
+ std::vector<ProcPidSmapsEntry> entries;
+ absl::optional<ProcPidSmapsEntry> entry;
+ bool have_size_kb = false;
+ bool have_rss_kb = false;
+ bool have_shared_clean_kb = false;
+ bool have_shared_dirty_kb = false;
+ bool have_private_clean_kb = false;
+ bool have_private_dirty_kb = false;
+
+ auto const finish_entry = [&] {
+ if (entry) {
+ if (!have_size_kb) {
+ return PosixError(EINVAL, "smaps entry is missing Size");
+ }
+ if (!have_rss_kb) {
+ return PosixError(EINVAL, "smaps entry is missing Rss");
+ }
+ if (!have_shared_clean_kb) {
+ return PosixError(EINVAL, "smaps entry is missing Shared_Clean");
+ }
+ if (!have_shared_dirty_kb) {
+ return PosixError(EINVAL, "smaps entry is missing Shared_Dirty");
+ }
+ if (!have_private_clean_kb) {
+ return PosixError(EINVAL, "smaps entry is missing Private_Clean");
+ }
+ if (!have_private_dirty_kb) {
+ return PosixError(EINVAL, "smaps entry is missing Private_Dirty");
+ }
+ // std::move(entry.value()) instead of std::move(entry).value(), because
+ // otherwise tools may report a "use-after-move" warning, which is
+ // spurious because entry.emplace() below resets entry to a new
+ // ProcPidSmapsEntry.
+ entries.emplace_back(std::move(entry.value()));
+ }
+ entry.emplace();
+ have_size_kb = false;
+ have_rss_kb = false;
+ have_shared_clean_kb = false;
+ have_shared_dirty_kb = false;
+ have_private_clean_kb = false;
+ have_private_dirty_kb = false;
+ return NoError();
+ };
+
+ // Holds key/value pairs from smaps field lines. Declared here so it can be
+ // captured by reference by the following lambdas.
+ std::vector<absl::string_view> key_value;
+
+ auto const on_required_field_kb = [&](size_t* field, bool* have_field) {
+ if (*have_field) {
+ return PosixError(
+ EINVAL,
+ absl::StrFormat("smaps entry has duplicate %s line", key_value[0]));
+ }
+ ASSIGN_OR_RETURN_ERRNO(*field, SmapsValueKb(key_value[1]));
+ *have_field = true;
+ return NoError();
+ };
+
+ auto const on_optional_field_kb = [&](absl::optional<size_t>* field) {
+ if (*field) {
+ return PosixError(
+ EINVAL,
+ absl::StrFormat("smaps entry has duplicate %s line", key_value[0]));
+ }
+ ASSIGN_OR_RETURN_ERRNO(*field, SmapsValueKb(key_value[1]));
+ return NoError();
+ };
+
+ absl::flat_hash_set<std::string> unknown_fields;
+ auto const on_unknown_field = [&] {
+ absl::string_view key = key_value[0];
+ // Don't mention unknown fields more than once.
+ if (unknown_fields.count(key)) {
+ return;
+ }
+ unknown_fields.insert(std::string(key));
+ std::cerr << "skipping unknown smaps field " << key << std::endl;
+ };
+
+ auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
+ for (absl::string_view l : lines) {
+ // Is this line a valid /proc/[pid]/maps entry?
+ auto maybe_maps_entry = ParseProcMapsLine(l);
+ if (maybe_maps_entry.ok()) {
+ // This marks the beginning of a new /proc/[pid]/smaps entry.
+ RETURN_IF_ERRNO(finish_entry());
+ entry->maps_entry = std::move(maybe_maps_entry).ValueOrDie();
+ continue;
+ }
+ // Otherwise it's a field in an existing /proc/[pid]/smaps entry of the form
+ // "key:value" (where value in practice will be preceded by a variable
+ // amount of whitespace).
+ if (!entry) {
+ std::cerr << "smaps line not considered a maps line: "
+ << maybe_maps_entry.error_message() << std::endl;
+ return PosixError(
+ EINVAL,
+ absl::StrCat("smaps field line without preceding maps line: ", l));
+ }
+ key_value = absl::StrSplit(l, absl::MaxSplits(':', 1));
+ if (key_value.size() != 2) {
+ return PosixError(EINVAL, absl::StrCat("invalid smaps field line: ", l));
+ }
+ absl::string_view const key = key_value[0];
+ if (key == "Size") {
+ RETURN_IF_ERRNO(on_required_field_kb(&entry->size_kb, &have_size_kb));
+ } else if (key == "Rss") {
+ RETURN_IF_ERRNO(on_required_field_kb(&entry->rss_kb, &have_rss_kb));
+ } else if (key == "Shared_Clean") {
+ RETURN_IF_ERRNO(
+ on_required_field_kb(&entry->shared_clean_kb, &have_shared_clean_kb));
+ } else if (key == "Shared_Dirty") {
+ RETURN_IF_ERRNO(
+ on_required_field_kb(&entry->shared_dirty_kb, &have_shared_dirty_kb));
+ } else if (key == "Private_Clean") {
+ RETURN_IF_ERRNO(on_required_field_kb(&entry->private_clean_kb,
+ &have_private_clean_kb));
+ } else if (key == "Private_Dirty") {
+ RETURN_IF_ERRNO(on_required_field_kb(&entry->private_dirty_kb,
+ &have_private_dirty_kb));
+ } else if (key == "Pss") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->pss_kb));
+ } else if (key == "Referenced") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->referenced_kb));
+ } else if (key == "Anonymous") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->anonymous_kb));
+ } else if (key == "AnonHugePages") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->anon_huge_pages_kb));
+ } else if (key == "Shared_Hugetlb") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->shared_hugetlb_kb));
+ } else if (key == "Private_Hugetlb") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->private_hugetlb_kb));
+ } else if (key == "Swap") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->swap_kb));
+ } else if (key == "SwapPss") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->swap_pss_kb));
+ } else if (key == "KernelPageSize") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->kernel_page_size_kb));
+ } else if (key == "MMUPageSize") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->mmu_page_size_kb));
+ } else if (key == "Locked") {
+ RETURN_IF_ERRNO(on_optional_field_kb(&entry->locked_kb));
+ } else if (key == "VmFlags") {
+ if (entry->vm_flags) {
+ return PosixError(EINVAL, "duplicate VmFlags line");
+ }
+ entry->vm_flags = absl::StrSplit(key_value[1], ' ', absl::SkipEmpty());
+ } else {
+ on_unknown_field();
+ }
+ }
+ RETURN_IF_ERRNO(finish_entry());
+ return entries;
+}
+
+TEST(ParseProcPidSmapsTest, Correctness) {
+ auto entries = ASSERT_NO_ERRNO_AND_VALUE(
+ ParseProcPidSmaps("0-10000 rw-s 00000000 00:00 0 "
+ " /dev/zero (deleted)\n"
+ "Size: 0 kB\n"
+ "Rss: 1 kB\n"
+ "Pss: 2 kB\n"
+ "Shared_Clean: 3 kB\n"
+ "Shared_Dirty: 4 kB\n"
+ "Private_Clean: 5 kB\n"
+ "Private_Dirty: 6 kB\n"
+ "Referenced: 7 kB\n"
+ "Anonymous: 8 kB\n"
+ "AnonHugePages: 9 kB\n"
+ "Shared_Hugetlb: 10 kB\n"
+ "Private_Hugetlb: 11 kB\n"
+ "Swap: 12 kB\n"
+ "SwapPss: 13 kB\n"
+ "KernelPageSize: 14 kB\n"
+ "MMUPageSize: 15 kB\n"
+ "Locked: 16 kB\n"
+ "FutureUnknownKey: 17 kB\n"
+ "VmFlags: rd wr sh mr mw me ms lo ?? sd \n"));
+ ASSERT_EQ(entries.size(), 1);
+ auto& entry = entries[0];
+ EXPECT_EQ(entry.maps_entry.filename, "/dev/zero (deleted)");
+ EXPECT_EQ(entry.size_kb, 0);
+ EXPECT_EQ(entry.rss_kb, 1);
+ EXPECT_THAT(entry.pss_kb, Optional(2));
+ EXPECT_EQ(entry.shared_clean_kb, 3);
+ EXPECT_EQ(entry.shared_dirty_kb, 4);
+ EXPECT_EQ(entry.private_clean_kb, 5);
+ EXPECT_EQ(entry.private_dirty_kb, 6);
+ EXPECT_THAT(entry.referenced_kb, Optional(7));
+ EXPECT_THAT(entry.anonymous_kb, Optional(8));
+ EXPECT_THAT(entry.anon_huge_pages_kb, Optional(9));
+ EXPECT_THAT(entry.shared_hugetlb_kb, Optional(10));
+ EXPECT_THAT(entry.private_hugetlb_kb, Optional(11));
+ EXPECT_THAT(entry.swap_kb, Optional(12));
+ EXPECT_THAT(entry.swap_pss_kb, Optional(13));
+ EXPECT_THAT(entry.kernel_page_size_kb, Optional(14));
+ EXPECT_THAT(entry.mmu_page_size_kb, Optional(15));
+ EXPECT_THAT(entry.locked_kb, Optional(16));
+ EXPECT_THAT(entry.vm_flags,
+ Optional(ElementsAreArray({"rd", "wr", "sh", "mr", "mw", "me",
+ "ms", "lo", "??", "sd"})));
+}
+
+// Returns the unique entry in entries containing the given address.
+PosixErrorOr<ProcPidSmapsEntry> FindUniqueSmapsEntry(
+ std::vector<ProcPidSmapsEntry> const& entries, uintptr_t addr) {
+ auto const pred = [&](ProcPidSmapsEntry const& entry) {
+ return entry.maps_entry.start <= addr && addr < entry.maps_entry.end;
+ };
+ auto const it = std::find_if(entries.begin(), entries.end(), pred);
+ if (it == entries.end()) {
+ return PosixError(EINVAL,
+ absl::StrFormat("no entry contains address %#x", addr));
+ }
+ auto const it2 = std::find_if(it + 1, entries.end(), pred);
+ if (it2 != entries.end()) {
+ return PosixError(
+ EINVAL,
+ absl::StrFormat("overlapping entries [%#x-%#x) and [%#x-%#x) both "
+ "contain address %#x",
+ it->maps_entry.start, it->maps_entry.end,
+ it2->maps_entry.start, it2->maps_entry.end, addr));
+ }
+ return *it;
+}
+
+PosixErrorOr<std::vector<ProcPidSmapsEntry>> ReadProcSelfSmaps() {
+ ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents("/proc/self/smaps"));
+ return ParseProcPidSmaps(contents);
+}
+
+TEST(ProcPidSmapsTest, SharedAnon) {
+ // Map with MAP_POPULATE so we get some RSS.
+ Mapping const m = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(
+ 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE));
+ auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps());
+ auto const entry =
+ ASSERT_NO_ERRNO_AND_VALUE(FindUniqueSmapsEntry(entries, m.addr()));
+
+ EXPECT_EQ(entry.size_kb, m.len() / 1024);
+ // It's possible that populated pages have been swapped out, so RSS might be
+ // less than size.
+ EXPECT_LE(entry.rss_kb, entry.size_kb);
+
+ if (entry.pss_kb) {
+ // PSS should be exactly equal to RSS since no other address spaces should
+ // be sharing our new mapping.
+ EXPECT_EQ(entry.pss_kb.value(), entry.rss_kb);
+ }
+
+ // "Shared" and "private" in smaps refers to whether or not *physical pages*
+ // are shared; thus all pages in our MAP_SHARED mapping should nevertheless
+ // be private.
+ EXPECT_EQ(entry.shared_clean_kb, 0);
+ EXPECT_EQ(entry.shared_dirty_kb, 0);
+ EXPECT_EQ(entry.private_clean_kb + entry.private_dirty_kb, entry.rss_kb)
+ << "Private_Clean = " << entry.private_clean_kb
+ << " kB, Private_Dirty = " << entry.private_dirty_kb << " kB";
+
+ // Shared anonymous mappings are implemented as a shmem file, so their pages
+ // are not PageAnon.
+ if (entry.anonymous_kb) {
+ EXPECT_EQ(entry.anonymous_kb.value(), 0);
+ }
+
+ if (entry.vm_flags) {
+ EXPECT_THAT(entry.vm_flags.value(),
+ IsSupersetOf({"rd", "wr", "sh", "mr", "mw", "me", "ms"}));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ex")));
+ }
+}
+
+TEST(ProcPidSmapsTest, PrivateAnon) {
+ // Map with MAP_POPULATE so we get some RSS.
+ Mapping const m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(2 * kPageSize, PROT_WRITE, MAP_PRIVATE | MAP_POPULATE));
+ auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps());
+ auto const entry =
+ ASSERT_NO_ERRNO_AND_VALUE(FindUniqueSmapsEntry(entries, m.addr()));
+
+ // It's possible that our mapping was merged with another vma, so the smaps
+ // entry might be bigger than our original mapping.
+ EXPECT_GE(entry.size_kb, m.len() / 1024);
+ EXPECT_LE(entry.rss_kb, entry.size_kb);
+ if (entry.pss_kb) {
+ EXPECT_LE(entry.pss_kb.value(), entry.rss_kb);
+ }
+
+ if (entry.anonymous_kb) {
+ EXPECT_EQ(entry.anonymous_kb.value(), entry.rss_kb);
+ }
+
+ if (entry.vm_flags) {
+ EXPECT_THAT(entry.vm_flags.value(), IsSupersetOf({"wr", "mr", "mw", "me"}));
+ // We passed PROT_WRITE to mmap. On at least x86, the mapping is in
+ // practice readable because there is no way to configure the MMU to make
+ // pages writable but not readable. However, VmFlags should reflect the
+ // flags set on the VMA, so "rd" (VM_READ) should not appear in VmFlags.
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("rd")));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ex")));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("sh")));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ms")));
+ }
+}
+
+TEST(ProcPidSmapsTest, SharedReadOnlyFile) {
+ size_t const kFileSize = kPageSize;
+
+ auto const temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_THAT(truncate(temp_file.path().c_str(), kFileSize), SyscallSucceeds());
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY));
+
+ auto const m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kFileSize, PROT_READ, MAP_SHARED | MAP_POPULATE, fd.get(), 0));
+ auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps());
+ auto const entry =
+ ASSERT_NO_ERRNO_AND_VALUE(FindUniqueSmapsEntry(entries, m.addr()));
+
+ // Most of the same logic as the SharedAnon case applies.
+ EXPECT_EQ(entry.size_kb, kFileSize / 1024);
+ EXPECT_LE(entry.rss_kb, entry.size_kb);
+ if (entry.pss_kb) {
+ EXPECT_EQ(entry.pss_kb.value(), entry.rss_kb);
+ }
+ EXPECT_EQ(entry.shared_clean_kb, 0);
+ EXPECT_EQ(entry.shared_dirty_kb, 0);
+ EXPECT_EQ(entry.private_clean_kb + entry.private_dirty_kb, entry.rss_kb)
+ << "Private_Clean = " << entry.private_clean_kb
+ << " kB, Private_Dirty = " << entry.private_dirty_kb << " kB";
+ if (entry.anonymous_kb) {
+ EXPECT_EQ(entry.anonymous_kb.value(), 0);
+ }
+
+ if (entry.vm_flags) {
+ EXPECT_THAT(entry.vm_flags.value(), IsSupersetOf({"rd", "mr", "me", "ms"}));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("wr")));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("ex")));
+ // Because the mapped file was opened O_RDONLY, the VMA is !VM_MAYWRITE and
+ // also !VM_SHARED.
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("sh")));
+ EXPECT_THAT(entry.vm_flags.value(), Not(Contains("mw")));
+ }
+}
+
+// Tests that gVisor's /proc/[pid]/smaps provides all of the fields we expect it
+// to, which as of this writing is all fields provided by Linux 4.4.
+TEST(ProcPidSmapsTest, GvisorFields) {
+ SKIP_IF(!IsRunningOnGvisor());
+ auto const entries = ASSERT_NO_ERRNO_AND_VALUE(ReadProcSelfSmaps());
+ for (auto const& entry : entries) {
+ EXPECT_TRUE(entry.pss_kb);
+ EXPECT_TRUE(entry.referenced_kb);
+ EXPECT_TRUE(entry.anonymous_kb);
+ EXPECT_TRUE(entry.anon_huge_pages_kb);
+ EXPECT_TRUE(entry.shared_hugetlb_kb);
+ EXPECT_TRUE(entry.private_hugetlb_kb);
+ EXPECT_TRUE(entry.swap_kb);
+ EXPECT_TRUE(entry.swap_pss_kb);
+ EXPECT_THAT(entry.kernel_page_size_kb, Optional(kPageSize / 1024));
+ EXPECT_THAT(entry.mmu_page_size_kb, Optional(kPageSize / 1024));
+ EXPECT_TRUE(entry.locked_kb);
+ EXPECT_TRUE(entry.vm_flags);
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_pid_uid_gid_map.cc b/test/syscalls/linux/proc_pid_uid_gid_map.cc
new file mode 100644
index 000000000..748f7be58
--- /dev/null
+++ b/test/syscalls/linux/proc_pid_uid_gid_map.cc
@@ -0,0 +1,311 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sched.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <functional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+#include "test/util/time_util.h"
+
+namespace gvisor {
+namespace testing {
+
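+// Runs 'fn' in a forked child after moving it into a new user namespace, and
+// returns the child's exit status. For example (illustrative),
+// InNewUserNamespace([] { TEST_CHECK(getuid() == 65534); }) would typically
+// observe the overflow UID, since no IDs are mapped in the fresh namespace.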
+PosixErrorOr<int> InNewUserNamespace(const std::function<void()>& fn) {
+ return InForkedProcess([&] {
+ TEST_PCHECK(unshare(CLONE_NEWUSER) == 0);
+ MaybeSave();
+ fn();
+ });
+}
+
+PosixErrorOr<std::tuple<pid_t, Cleanup>> CreateProcessInNewUserNamespace() {
+ int pipefd[2];
+ if (pipe(pipefd) < 0) {
+ return PosixError(errno, "pipe failed");
+ }
+ const auto cleanup_pipe_read =
+ Cleanup([&] { EXPECT_THAT(close(pipefd[0]), SyscallSucceeds()); });
+ auto cleanup_pipe_write =
+ Cleanup([&] { EXPECT_THAT(close(pipefd[1]), SyscallSucceeds()); });
+ pid_t child_pid = fork();
+ if (child_pid < 0) {
+ return PosixError(errno, "fork failed");
+ }
+ if (child_pid == 0) {
+    // Close our copy of the pipe's read end; the child only writes.
+ TEST_PCHECK(close(pipefd[0]) >= 0);
+ TEST_PCHECK(unshare(CLONE_NEWUSER) == 0);
+ MaybeSave();
+ // Indicate that we've switched namespaces by unblocking the parent's read.
+ TEST_PCHECK(close(pipefd[1]) >= 0);
+ while (true) {
+ SleepSafe(absl::Minutes(1));
+ }
+ }
+ auto cleanup_child = Cleanup([child_pid] {
+ EXPECT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << "status = " << status;
+ });
+ // Close our copy of the pipe's write end, then wait for the child to close
+ // its copy, indicating that it's switched namespaces.
+ cleanup_pipe_write.Release()();
+ char buf;
+ if (RetryEINTR(read)(pipefd[0], &buf, 1) < 0) {
+ return PosixError(errno, "reading from pipe failed");
+ }
+ MaybeSave();
+ return std::make_tuple(child_pid, std::move(cleanup_child));
+}
+
+// TEST_CHECK-fails on error, since this function is used in contexts that
+// require async-signal-safety.
+void DenySetgroupsByPath(const char* path) {
+ int fd = open(path, O_WRONLY);
+ if (fd < 0 && errno == ENOENT) {
+ // On kernels where this file doesn't exist, writing "deny" to it isn't
+ // necessary to write to gid_map.
+ return;
+ }
+ TEST_PCHECK(fd >= 0);
+ MaybeSave();
+ char deny[] = "deny";
+ TEST_PCHECK(write(fd, deny, sizeof(deny)) == sizeof(deny));
+ MaybeSave();
+ TEST_PCHECK(close(fd) == 0);
+}
+
+void DenySelfSetgroups() { DenySetgroupsByPath("/proc/self/setgroups"); }
+
+void DenyPidSetgroups(pid_t pid) {
+ DenySetgroupsByPath(absl::StrCat("/proc/", pid, "/setgroups").c_str());
+}
+
+// Returns a valid UID/GID that isn't id.
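+// For example, another_id(1000) returns 1001; another_id(65534) wraps to 0.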
+uint32_t another_id(uint32_t id) { return (id + 1) % 65535; }
+
+struct TestParam {
+ std::string desc;
+ int cap;
+ std::function<std::string(absl::string_view)> get_map_filename;
+ std::function<uint32_t()> get_current_id;
+};
+
+std::string DescribeTestParam(const ::testing::TestParamInfo<TestParam>& info) {
+ return info.param.desc;
+}
+
+std::vector<TestParam> UidGidMapTestParams() {
+ return {TestParam{"UID", CAP_SETUID,
+ [](absl::string_view pid) {
+ return absl::StrCat("/proc/", pid, "/uid_map");
+ },
+ []() -> uint32_t { return getuid(); }},
+ TestParam{"GID", CAP_SETGID,
+ [](absl::string_view pid) {
+ return absl::StrCat("/proc/", pid, "/gid_map");
+ },
+ []() -> uint32_t { return getgid(); }}};
+}
+
+class ProcUidGidMapTest : public ::testing::TestWithParam<TestParam> {
+ protected:
+ uint32_t CurrentID() { return GetParam().get_current_id(); }
+};
+
+class ProcSelfUidGidMapTest : public ProcUidGidMapTest {
+ protected:
+ PosixErrorOr<int> InNewUserNamespaceWithMapFD(
+ const std::function<void(int)>& fn) {
+ std::string map_filename = GetParam().get_map_filename("self");
+ return InNewUserNamespace([&] {
+ int fd = open(map_filename.c_str(), O_RDWR);
+ TEST_PCHECK(fd >= 0);
+ MaybeSave();
+ fn(fd);
+ TEST_PCHECK(close(fd) == 0);
+ });
+ }
+};
+
+class ProcPidUidGidMapTest : public ProcUidGidMapTest {
+ protected:
+ PosixErrorOr<bool> HaveSetIDCapability() {
+ return HaveCapability(GetParam().cap);
+ }
+
+ // Returns true if the caller is running in a user namespace with all IDs
+ // mapped. This matters for tests that expect to successfully map arbitrary
+ // IDs into a child user namespace, since even with CAP_SET*ID this is only
+ // possible if those IDs are mapped into the current one.
+ PosixErrorOr<bool> AllIDsMapped() {
+ ASSIGN_OR_RETURN_ERRNO(std::string id_map,
+ GetContents(GetParam().get_map_filename("self")));
+ absl::StripTrailingAsciiWhitespace(&id_map);
+ std::vector<std::string> id_map_parts =
+ absl::StrSplit(id_map, ' ', absl::SkipEmpty());
+ return id_map_parts == std::vector<std::string>({"0", "0", "4294967295"});
+ }
+
+ PosixErrorOr<FileDescriptor> OpenMapFile(pid_t pid) {
+ return Open(GetParam().get_map_filename(absl::StrCat(pid)), O_RDWR);
+ }
+};
+
+TEST_P(ProcSelfUidGidMapTest, IsInitiallyEmpty) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ EXPECT_THAT(InNewUserNamespaceWithMapFD([](int fd) {
+ char buf[64];
+ TEST_PCHECK(read(fd, buf, sizeof(buf)) == 0);
+ }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(ProcSelfUidGidMapTest, IdentityMapOwnID) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ uint32_t id = CurrentID();
+ std::string line = absl::StrCat(id, " ", id, " 1");
+ EXPECT_THAT(
+ InNewUserNamespaceWithMapFD([&](int fd) {
+ DenySelfSetgroups();
+ TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size());
+ }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(ProcSelfUidGidMapTest, TrailingNewlineAndNULIgnored) {
+ // This is identical to IdentityMapOwnID, except that a trailing newline, a
+ // NUL, and an invalid (incomplete) map entry are appended to the valid
+ // entry. The newline should be accepted, and everything after the NUL should
+ // be ignored.
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ uint32_t id = CurrentID();
+ // absl::StrCat truncates C-string arguments at the first NUL, so append the
+ // NUL and the invalid trailing entry with an explicit length.
+ std::string line = absl::StrCat(id, " ", id, " 1\n");
+ line.append("\0 4 3", 5);
+ EXPECT_THAT(
+ InNewUserNamespaceWithMapFD([&](int fd) {
+ DenySelfSetgroups();
+ // The write should return the full size of the write, even though
+ // characters after the NUL were ignored.
+ TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size());
+ }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(ProcSelfUidGidMapTest, NonIdentityMapOwnID) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ uint32_t id = CurrentID();
+ uint32_t id2 = another_id(id);
+ std::string line = absl::StrCat(id2, " ", id, " 1");
+ EXPECT_THAT(
+ InNewUserNamespaceWithMapFD([&](int fd) {
+ DenySelfSetgroups();
+ TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size());
+ }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+TEST_P(ProcSelfUidGidMapTest, MapOtherID) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ // Whether or not we have CAP_SET*ID is irrelevant: the process running in
+ // the new (child) user namespace won't have any capabilities in the current
+ // (parent) user namespace, which is what mapping another ID would require.
+ uint32_t id = CurrentID();
+ uint32_t id2 = another_id(id);
+ std::string line = absl::StrCat(id, " ", id2, " 1");
+ EXPECT_THAT(InNewUserNamespaceWithMapFD([&](int fd) {
+ DenySelfSetgroups();
+ TEST_PCHECK(write(fd, line.c_str(), line.size()) < 0);
+ TEST_CHECK(errno == EPERM);
+ }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+INSTANTIATE_TEST_SUITE_P(All, ProcSelfUidGidMapTest,
+ ::testing::ValuesIn(UidGidMapTestParams()),
+ DescribeTestParam);
+
+TEST_P(ProcPidUidGidMapTest, MapOtherIDPrivileged) {
+ // Like ProcSelfUidGidMapTest_MapOtherID, but since we have CAP_SET*ID in the
+ // parent user namespace (this one), we can map IDs that aren't ours.
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveSetIDCapability()));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(AllIDsMapped()));
+
+ pid_t child_pid;
+ Cleanup cleanup_child;
+ std::tie(child_pid, cleanup_child) =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateProcessInNewUserNamespace());
+
+ uint32_t id = CurrentID();
+ uint32_t id2 = another_id(id);
+ std::string line = absl::StrCat(id, " ", id2, " 1");
+ DenyPidSetgroups(child_pid);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenMapFile(child_pid));
+ EXPECT_THAT(write(fd.get(), line.c_str(), line.size()),
+ SyscallSucceedsWithValue(line.size()));
+}
+
+TEST_P(ProcPidUidGidMapTest, MapAnyIDsPrivileged) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanCreateUserNamespace()));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveSetIDCapability()));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(AllIDsMapped()));
+
+ pid_t child_pid;
+ Cleanup cleanup_child;
+ std::tie(child_pid, cleanup_child) =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateProcessInNewUserNamespace());
+
+ // Test all of:
+ //
+ // - Mapping ranges of length > 1
+ //
+ // - Mapping multiple ranges
+ //
+ // - Non-identity mappings
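+ // Each entry is "<first ID inside ns> <first ID outside ns> <length>": the
+ // first line maps IDs 2-3 inside the child namespace to IDs 0-1 here, and
+ // the second maps IDs 4-5 to IDs 6-7.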
+ char entries[] = "2 0 2\n4 6 2";
+ DenyPidSetgroups(child_pid);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenMapFile(child_pid));
+ EXPECT_THAT(write(fd.get(), entries, sizeof(entries)),
+ SyscallSucceedsWithValue(sizeof(entries)));
+}
+
+INSTANTIATE_TEST_SUITE_P(All, ProcPidUidGidMapTest,
+ ::testing::ValuesIn(UidGidMapTestParams()),
+ DescribeTestParam);
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pselect.cc b/test/syscalls/linux/pselect.cc
new file mode 100644
index 000000000..4e43c4d7f
--- /dev/null
+++ b/test/syscalls/linux/pselect.cc
@@ -0,0 +1,190 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/select.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/base_poll_test.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+struct MaskWithSize {
+ sigset_t* mask;
+ size_t mask_size;
+};
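+// The raw pselect6 syscall takes a pointer to this {sigset pointer, size}
+// pair as its sixth argument rather than a bare sigset_t*.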
+
+// Linux and glibc have different ideas of the size of sigset_t. When calling
+// the syscall directly, use the size the kernel expects.
+unsigned kSigsetSize = SIGRTMAX / 8;
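+// (With SIGRTMAX == 64, the kernel sigset is 8 bytes, while glibc's sigset_t
+// is 128 bytes.)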
+
+// Linux pselect(2) differs from the glibc wrapper function in that Linux
+// updates the timeout with the amount of time remaining. In order to test this
+// behavior we need to use the syscall directly.
+int syscallPselect6(int nfds, fd_set* readfds, fd_set* writefds,
+ fd_set* exceptfds, struct timespec* timeout,
+ const MaskWithSize* mask_with_size) {
+ return syscall(SYS_pselect6, nfds, readfds, writefds, exceptfds, timeout,
+ mask_with_size);
+}
+
+class PselectTest : public BasePollTest {
+ protected:
+ void SetUp() override { BasePollTest::SetUp(); }
+ void TearDown() override { BasePollTest::TearDown(); }
+};
+
+// See that when there are no FD sets, pselect behaves like sleep.
+TEST_F(PselectTest, NullFds) {
+ struct timespec timeout = absl::ToTimespec(absl::Milliseconds(10));
+ ASSERT_THAT(syscallPselect6(0, nullptr, nullptr, nullptr, &timeout, nullptr),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_nsec, 0);
+
+ timeout = absl::ToTimespec(absl::Milliseconds(10));
+ ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_nsec, 0);
+}
+
+TEST_F(PselectTest, ClosedFds) {
+ fd_set read_set;
+ FD_ZERO(&read_set);
+ int fd;
+ ASSERT_THAT(fd = dup(1), SyscallSucceeds());
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+ FD_SET(fd, &read_set);
+ struct timespec timeout = absl::ToTimespec(absl::Milliseconds(10));
+ EXPECT_THAT(
+ syscallPselect6(fd + 1, &read_set, nullptr, nullptr, &timeout, nullptr),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(PselectTest, ZeroTimeout) {
+ struct timespec timeout = {};
+ ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_nsec, 0);
+}
+
+// If random S/R interrupts the pselect, SIGALRM may be delivered before pselect
+// restarts, causing the pselect to hang forever.
+TEST_F(PselectTest, NoTimeout_NoRandomSave) {
+ // When there's no timeout, pselect may never return so set a timer.
+ SetTimer(absl::Milliseconds(100));
+ // See that we get interrupted by the timer.
+ ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, nullptr, nullptr),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+}
+
+TEST_F(PselectTest, InvalidTimeoutNegative) {
+ struct timespec timeout = absl::ToTimespec(absl::Seconds(-1));
+ ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(timeout.tv_sec, -1);
+ EXPECT_EQ(timeout.tv_nsec, 0);
+}
+
+TEST_F(PselectTest, InvalidTimeoutNotNormalized) {
+ struct timespec timeout = {0, 1000000001};
+ ASSERT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_nsec, 1000000001);
+}
+
+TEST_F(PselectTest, EmptySigMaskInvalidMaskSize) {
+ struct timespec timeout = {};
+ MaskWithSize invalid = {nullptr, 7};
+ EXPECT_THAT(syscallPselect6(0, nullptr, nullptr, nullptr, &timeout, &invalid),
+ SyscallSucceeds());
+}
+
+TEST_F(PselectTest, EmptySigMaskValidMaskSize) {
+ struct timespec timeout = {};
+ MaskWithSize valid = {nullptr, 8};
+ EXPECT_THAT(syscallPselect6(0, nullptr, nullptr, nullptr, &timeout, &valid),
+ SyscallSucceeds());
+}
+
+TEST_F(PselectTest, InvalidMaskSize) {
+ struct timespec timeout = {};
+ sigset_t sigmask;
+ ASSERT_THAT(sigemptyset(&sigmask), SyscallSucceeds());
+ MaskWithSize invalid = {&sigmask, 7};
+ EXPECT_THAT(syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, &invalid),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Verify that signals blocked by the pselect mask (that would otherwise be
+// allowed) do not interrupt pselect.
+TEST_F(PselectTest, SignalMaskBlocksSignal) {
+ absl::Duration duration(absl::Seconds(30));
+ struct timespec timeout = absl::ToTimespec(duration);
+ absl::Duration timer_duration(absl::Seconds(10));
+
+ // Call with a mask that blocks SIGALRM. See that pselect is not interrupted
+ // (i.e. returns 0) and that upon completion, the timer has fired.
+ sigset_t mask;
+ ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds());
+ ASSERT_THAT(sigaddset(&mask, SIGALRM), SyscallSucceeds());
+ MaskWithSize mask_with_size = {&mask, kSigsetSize};
+ SetTimer(timer_duration);
+ MaybeSave();
+ ASSERT_FALSE(TimerFired());
+ ASSERT_THAT(
+ syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, &mask_with_size),
+ SyscallSucceeds());
+ EXPECT_TRUE(TimerFired());
+ EXPECT_EQ(absl::DurationFromTimespec(timeout), absl::Duration());
+}
+
+// Verify that signals allowed by the pselect mask (that would otherwise be
+// blocked) interrupt pselect.
+TEST_F(PselectTest, SignalMaskAllowsSignal) {
+ absl::Duration duration = absl::Seconds(30);
+ struct timespec timeout = absl::ToTimespec(duration);
+ absl::Duration timer_duration = absl::Seconds(10);
+
+ sigset_t mask;
+ ASSERT_THAT(sigprocmask(0, nullptr, &mask), SyscallSucceeds());
+
+ // Block SIGALRM.
+ auto cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGALRM));
+
+ // Call with a mask that unblocks SIGALRM. See that pselect is interrupted.
+ MaskWithSize mask_with_size = {&mask, kSigsetSize};
+ SetTimer(timer_duration);
+ MaybeSave();
+ ASSERT_FALSE(TimerFired());
+ ASSERT_THAT(
+ syscallPselect6(1, nullptr, nullptr, nullptr, &timeout, &mask_with_size),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+ EXPECT_GT(absl::DurationFromTimespec(timeout), absl::Duration());
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
new file mode 100644
index 000000000..926690eb8
--- /dev/null
+++ b/test/syscalls/linux/ptrace.cc
@@ -0,0 +1,1229 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <elf.h>
+#include <signal.h>
+#include <stddef.h>
+#include <sys/ptrace.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <sys/user.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/platform_util.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/time_util.h"
+
+ABSL_FLAG(bool, ptrace_test_execve_child, false,
+ "If true, run the "
+ "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
+ "TraceExit child workload.");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// PTRACE_GETSIGMASK and PTRACE_SETSIGMASK are not defined until glibc 2.23
+// (fb53a27c5741 "Add new header definitions from Linux 4.4 (plus older ptrace
+// definitions)").
+constexpr auto kPtraceGetSigMask = static_cast<__ptrace_request>(0x420a);
+constexpr auto kPtraceSetSigMask = static_cast<__ptrace_request>(0x420b);
+
+// PTRACE_SYSEMU is not defined until glibc 2.27 (c48831d0eebf "linux/x86: sync
+// sys/ptrace.h with Linux 4.14 [BZ #22433]").
+constexpr auto kPtraceSysemu = static_cast<__ptrace_request>(31);
+
+// PTRACE_EVENT_STOP is not defined until glibc 2.26 (3f67d1a7021e "Add Linux
+// PTRACE_EVENT_STOP").
+constexpr int kPtraceEventStop = 128;
+
+// Sends sig to the current process with tgkill(2).
+//
+// glibc's raise(3) may change the signal mask before sending the signal. The
+// extra syscalls this entails make tests of syscall and signal interception
+// difficult to write.
+void RaiseSignal(int sig) {
+ pid_t pid = getpid();
+ TEST_PCHECK(pid > 0);
+ pid_t tid = gettid();
+ TEST_PCHECK(tid > 0);
+ TEST_PCHECK(tgkill(pid, tid, sig) == 0);
+}
+
+// Returns the Yama ptrace scope.
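+// (0 = disabled, 1 = restricted to descendants, 2 = admin-only attach,
+// 3 = no attach.)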
+PosixErrorOr<int> YamaPtraceScope() {
+ constexpr char kYamaPtraceScopePath[] = "/proc/sys/kernel/yama/ptrace_scope";
+
+ ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(kYamaPtraceScopePath));
+ if (!exists) {
+ // If the file doesn't exist, Yama is not enabled, so the scope is
+ // effectively disabled (0).
+ return 0;
+ }
+
+ std::string contents;
+ RETURN_IF_ERRNO(GetContents(kYamaPtraceScopePath, &contents));
+
+ int scope;
+ if (!absl::SimpleAtoi(contents, &scope)) {
+ return PosixError(EINVAL, absl::StrCat(contents, ": not a valid number"));
+ }
+
+ return scope;
+}
+
+TEST(PtraceTest, AttachSelf) {
+ EXPECT_THAT(ptrace(PTRACE_ATTACH, gettid(), 0, 0),
+ SyscallFailsWithErrno(EPERM));
+}
+
+TEST(PtraceTest, AttachSameThreadGroup) {
+ pid_t const tid = gettid();
+ ScopedThread([&] {
+ EXPECT_THAT(ptrace(PTRACE_ATTACH, tid, 0, 0), SyscallFailsWithErrno(EPERM));
+ });
+}
+
+TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) {
+ // Yama prevents attaching to a parent. Skip the test if the scope is anything
+ // except disabled.
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 0);
+
+ constexpr long kBeforePokeDataValue = 10;
+ constexpr long kAfterPokeDataValue = 20;
+
+ volatile long word = kBeforePokeDataValue;
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Attach to the parent.
+ pid_t const parent_pid = getppid();
+ TEST_PCHECK(ptrace(PTRACE_ATTACH, parent_pid, 0, 0) == 0);
+ MaybeSave();
+
+ // Block until the parent enters signal-delivery-stop as a result of the
+ // SIGSTOP sent by PTRACE_ATTACH.
+ int status;
+ TEST_PCHECK(waitpid(parent_pid, &status, 0) == parent_pid);
+ MaybeSave();
+ TEST_CHECK(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
+
+ // Replace the value of word in the parent process with kAfterPokeDataValue.
+ long const parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &word, 0);
+ MaybeSave();
+ TEST_CHECK(parent_word == kBeforePokeDataValue);
+ TEST_PCHECK(
+ ptrace(PTRACE_POKEDATA, parent_pid, &word, kAfterPokeDataValue) == 0);
+ MaybeSave();
+
+ // Detach from the parent and suppress the SIGSTOP. If the SIGSTOP is not
+ // suppressed, the parent will hang in group-stop, causing the test to time
+ // out.
+ TEST_PCHECK(ptrace(PTRACE_DETACH, parent_pid, 0, 0) == 0);
+ MaybeSave();
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to complete.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+
+ // Check that the child's PTRACE_POKEDATA was effective.
+ EXPECT_EQ(kAfterPokeDataValue, word);
+}
+
+TEST(PtraceTest, GetSigMask) {
+ // glibc and the Linux kernel define a sigset_t with different sizes. To avoid
+ // creating a kernel_sigset_t and recreating all the modification functions
+ // (sigemptyset, etc), we just hardcode the kernel sigset size.
+ constexpr int kSizeofKernelSigset = 8;
+ constexpr int kBlockSignal = SIGUSR1;
+ sigset_t blocked;
+ sigemptyset(&blocked);
+ sigaddset(&blocked, kBlockSignal);
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Install a signal handler for kBlockSignal to avoid termination and block
+ // it.
+ TEST_PCHECK(signal(
+ kBlockSignal, +[](int signo) {}) != SIG_ERR);
+ MaybeSave();
+ TEST_PCHECK(sigprocmask(SIG_SETMASK, &blocked, nullptr) == 0);
+ MaybeSave();
+
+ // Enable tracing.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+
+ // This should be blocked.
+ RaiseSignal(kBlockSignal);
+
+ // This should be suppressed by the parent, which will change the signal
+ // mask in the meantime, meaning kBlockSignal should be delivered once the
+ // child resumes.
+ RaiseSignal(SIGSTOP);
+
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Get current signal mask.
+ sigset_t set;
+ EXPECT_THAT(ptrace(kPtraceGetSigMask, child_pid, kSizeofKernelSigset, &set),
+ SyscallSucceeds());
+ EXPECT_THAT(blocked, EqualsSigset(set));
+
+ // Try to get current signal mask with bad size argument.
+ EXPECT_THAT(ptrace(kPtraceGetSigMask, child_pid, 0, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Try to set bad signal mask.
+ sigset_t* bad_addr = reinterpret_cast<sigset_t*>(-1);
+ EXPECT_THAT(
+ ptrace(kPtraceSetSigMask, child_pid, kSizeofKernelSigset, bad_addr),
+ SyscallFailsWithErrno(EFAULT));
+
+ // Set signal mask to empty set.
+ sigset_t set1;
+ sigemptyset(&set1);
+ EXPECT_THAT(ptrace(kPtraceSetSigMask, child_pid, kSizeofKernelSigset, &set1),
+ SyscallSucceeds());
+
+ // Suppress SIGSTOP and resume the child. It should re-enter
+ // signal-delivery-stop for kBlockSignal.
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == kBlockSignal)
+ << " status " << status;
+
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ // Let's see that process exited normally.
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST(PtraceTest, GetSiginfo_SetSiginfo_SignalInjection) {
+ constexpr int kOriginalSigno = SIGUSR1;
+ constexpr int kInjectedSigno = SIGUSR2;
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Override all signal handlers.
+ struct sigaction sa = {};
+ sa.sa_handler = +[](int signo) { _exit(signo); };
+ TEST_PCHECK(sigfillset(&sa.sa_mask) == 0);
+ for (int signo = 1; signo < 32; signo++) {
+ if (signo == SIGKILL || signo == SIGSTOP) {
+ continue;
+ }
+ TEST_PCHECK(sigaction(signo, &sa, nullptr) == 0);
+ }
+ for (int signo = SIGRTMIN; signo <= SIGRTMAX; signo++) {
+ TEST_PCHECK(sigaction(signo, &sa, nullptr) == 0);
+ }
+
+ // Unblock all signals.
+ TEST_PCHECK(sigprocmask(SIG_UNBLOCK, &sa.sa_mask, nullptr) == 0);
+ MaybeSave();
+
+ // Send ourselves kOriginalSignal while ptraced and exit with the signal we
+ // actually receive via the signal handler, if any, or 0 if we don't receive
+ // a signal.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+ RaiseSignal(kOriginalSigno);
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself kOriginalSigno and enter
+ // signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == kOriginalSigno)
+ << " status " << status;
+
+ siginfo_t siginfo = {};
+ ASSERT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo),
+ SyscallSucceeds());
+ EXPECT_EQ(kOriginalSigno, siginfo.si_signo);
+ EXPECT_EQ(SI_TKILL, siginfo.si_code);
+
+ // Replace the signal with kInjectedSigno, and check that the child exits
+ // with kInjectedSigno, indicating that signal injection was successful.
+ siginfo.si_signo = kInjectedSigno;
+ ASSERT_THAT(ptrace(PTRACE_SETSIGINFO, child_pid, 0, &siginfo),
+ SyscallSucceeds());
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, kInjectedSigno),
+ SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == kInjectedSigno)
+ << " status " << status;
+}
+
+TEST(PtraceTest, SIGKILLDoesNotCauseSignalDeliveryStop) {
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+ RaiseSignal(SIGKILL);
+ TEST_CHECK_MSG(false, "Survived SIGKILL?");
+ _exit(1);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Expect the child to die to SIGKILL without entering signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << " status " << status;
+}
+
+TEST(PtraceTest, PtraceKill) {
+ constexpr int kOriginalSigno = SIGUSR1;
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+
+ // PTRACE_KILL only works if the tracee has entered signal-delivery-stop.
+ RaiseSignal(kOriginalSigno);
+ TEST_CHECK_MSG(false, "Failed to kill the process?");
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself kOriginalSigno and enter
+ // signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == kOriginalSigno)
+ << " status " << status;
+
+ ASSERT_THAT(ptrace(PTRACE_KILL, child_pid, 0, 0), SyscallSucceeds());
+
+ // Expect the child to die with SIGKILL.
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << " status " << status;
+}
+
+TEST(PtraceTest, GetRegSet) {
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Enable tracing.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+
+ // Use kill explicitly because we check the syscall argument register below.
+ kill(getpid(), SIGSTOP);
+
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Get the general registers.
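+ // PTRACE_GETREGSET takes an iovec; the kernel fills the buffer and updates
+ // iov_len to the number of bytes it actually wrote.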
+ struct user_regs_struct regs;
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+
+#if defined(__x86_64__)
+ // Child called kill(2), with SIGSTOP as arg 2.
+ EXPECT_EQ(regs.rsi, SIGSTOP);
+#elif defined(__aarch64__)
+ EXPECT_EQ(regs.regs[1], SIGSTOP);
+#endif
+
+ // Suppress SIGSTOP and resume the child.
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ // Let's see that process exited normally.
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST(PtraceTest, AttachingConvertsGroupStopToPtraceStop) {
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+ while (true) {
+ pause();
+ }
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // SIGSTOP the child and wait for it to stop.
+ ASSERT_THAT(kill(child_pid, SIGSTOP), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, WUNTRACED),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Attach to the child and expect it to re-enter a traced group-stop despite
+ // already being stopped.
+ ASSERT_THAT(ptrace(PTRACE_ATTACH, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Verify that the child is ptrace-stopped by checking that it can receive
+ // ptrace commands requiring a ptrace-stop.
+ EXPECT_THAT(ptrace(PTRACE_SETOPTIONS, child_pid, 0, 0), SyscallSucceeds());
+
+ // Group-stop is distinguished from signal-delivery-stop by PTRACE_GETSIGINFO
+ // failing with EINVAL.
+ siginfo_t siginfo = {};
+ EXPECT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Detach from the child and expect it to stay stopped without a notification.
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, WUNTRACED | WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Sending it SIGCONT should cause it to leave its stop.
+ ASSERT_THAT(kill(child_pid, SIGCONT), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, WCONTINUED),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFCONTINUED(status)) << " status " << status;
+
+ // Clean up the child.
+ ASSERT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << " status " << status;
+}
+
+// Fixture for tests parameterized by whether or not to use PTRACE_O_TRACEEXEC.
+class PtraceExecveTest : public ::testing::TestWithParam<bool> {
+ protected:
+ bool TraceExec() const { return GetParam(); }
+};
+
+TEST_P(PtraceExecveTest, Execve_GetRegs_PeekUser_SIGKILL_TraceClone_TraceExit) {
+ ExecveArray const owned_child_argv = {"/proc/self/exe",
+ "--ptrace_test_execve_child"};
+ char* const* const child_argv = owned_child_argv.get();
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process. The test relies on calling execve() in a non-leader
+ // thread; pthread_create() isn't async-signal-safe, so the safest way to
+ // do this is to execve() first, then enable tracing and run the expected
+ // child process behavior in the new subprocess.
+ execve(child_argv[0], child_argv, /* envp = */ nullptr);
+ TEST_PCHECK_MSG(false, "Survived execve to test child");
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Enable PTRACE_O_TRACECLONE so we can get the ID of the child's non-leader
+ // thread, PTRACE_O_TRACEEXIT so we can observe the leader's death, and
+ // PTRACE_O_TRACEEXEC if required by the test. (The leader doesn't call
+ // execve, but options should be inherited across clone.)
+ long opts = PTRACE_O_TRACECLONE | PTRACE_O_TRACEEXIT;
+ if (TraceExec()) {
+ opts |= PTRACE_O_TRACEEXEC;
+ }
+ ASSERT_THAT(ptrace(PTRACE_SETOPTIONS, child_pid, 0, opts), SyscallSucceeds());
+
+ // Suppress the SIGSTOP and wait for the child's leader thread to report
+ // PTRACE_EVENT_CLONE. Get the new thread's ID from the event.
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
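+ // Ptrace event stops are reported as a stop signal of SIGTRAP with the
+ // event number in the next-higher byte, hence the comparison against
+ // status >> 8.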
+ EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_CLONE << 8), status >> 8);
+ unsigned long eventmsg;
+ ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, child_pid, 0, &eventmsg),
+ SyscallSucceeds());
+ pid_t const nonleader_tid = eventmsg;
+ pid_t const leader_tid = child_pid;
+
+ // The new thread should be ptraced and in signal-delivery-stop by SIGSTOP due
+ // to PTRACE_O_TRACECLONE.
+ //
+ // Before bf959931ddb88c4e4366e96dd22e68fa0db9527c "wait/ptrace: assume __WALL
+ // if the child is traced" (4.7) , waiting on it requires __WCLONE since, as a
+ // non-leader, its termination signal is 0. After, a standard wait is
+ // sufficient.
+ ASSERT_THAT(waitpid(nonleader_tid, &status, __WCLONE),
+ SyscallSucceedsWithValue(nonleader_tid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Resume both child threads.
+ for (pid_t const tid : {leader_tid, nonleader_tid}) {
+ ASSERT_THAT(ptrace(PTRACE_CONT, tid, 0, 0), SyscallSucceeds());
+ }
+
+ // The non-leader child thread should call execve, causing the leader thread
+ // to enter PTRACE_EVENT_EXIT with an apparent exit code of 0. At this point,
+ // the leader has not yet exited, so the non-leader should be blocked in
+ // execve.
+ ASSERT_THAT(waitpid(leader_tid, &status, 0),
+ SyscallSucceedsWithValue(leader_tid));
+ EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_EXIT << 8), status >> 8);
+ ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, leader_tid, 0, &eventmsg),
+ SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(eventmsg) && WEXITSTATUS(eventmsg) == 0)
+ << " eventmsg " << eventmsg;
+ EXPECT_THAT(waitpid(nonleader_tid, &status, __WCLONE | WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Allow the leader to continue exiting. This should allow the non-leader to
+ // complete its execve, causing the original leader to be reaped without
+ // further notice and the non-leader to steal its ID.
+ ASSERT_THAT(ptrace(PTRACE_CONT, leader_tid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(leader_tid, &status, 0),
+ SyscallSucceedsWithValue(leader_tid));
+ if (TraceExec()) {
+ // If PTRACE_O_TRACEEXEC was enabled, the execing thread should be in
+ // PTRACE_EVENT_EXEC-stop, with the event message set to its old thread ID.
+ EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_EXEC << 8), status >> 8);
+ ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, leader_tid, 0, &eventmsg),
+ SyscallSucceeds());
+ EXPECT_EQ(nonleader_tid, eventmsg);
+ } else {
+ // Otherwise, the execing thread should have received SIGTRAP and should now
+ // be in signal-delivery-stop.
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << " status " << status;
+ }
+
+#ifdef __x86_64__
+ {
+ // CS should be 0x33, the Linux x86-64 user code segment selector, indicating
+ // a 64-bit binary.
+ constexpr uint64_t kAMD64UserCS = 0x33;
+ EXPECT_THAT(ptrace(PTRACE_PEEKUSER, leader_tid,
+ offsetof(struct user_regs_struct, cs), 0),
+ SyscallSucceedsWithValue(kAMD64UserCS));
+ struct user_regs_struct regs = {};
+ ASSERT_THAT(ptrace(PTRACE_GETREGS, leader_tid, 0, &regs),
+ SyscallSucceeds());
+ EXPECT_EQ(kAMD64UserCS, regs.cs);
+ }
+#endif // defined(__x86_64__)
+
+ // PTRACE_O_TRACEEXIT should have been inherited across execve. Send SIGKILL,
+ // which should end the PTRACE_EVENT_EXEC-stop or signal-delivery-stop and
+ // leave the child in PTRACE_EVENT_EXIT-stop.
+ ASSERT_THAT(kill(leader_tid, SIGKILL), SyscallSucceeds());
+ ASSERT_THAT(waitpid(leader_tid, &status, 0),
+ SyscallSucceedsWithValue(leader_tid));
+ EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_EXIT << 8), status >> 8);
+ ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, leader_tid, 0, &eventmsg),
+ SyscallSucceeds());
+ EXPECT_TRUE(WIFSIGNALED(eventmsg) && WTERMSIG(eventmsg) == SIGKILL)
+ << " eventmsg " << eventmsg;
+
+ // End the PTRACE_EVENT_EXIT stop, allowing the child to exit.
+ ASSERT_THAT(ptrace(PTRACE_CONT, leader_tid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(leader_tid, &status, 0),
+ SyscallSucceedsWithValue(leader_tid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << " status " << status;
+}
+
+[[noreturn]] void RunExecveChild() {
+ // Enable tracing, then raise SIGSTOP and expect our parent to suppress it.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+ RaiseSignal(SIGSTOP);
+ MaybeSave();
+
+ // Call execve() in a non-leader thread. As long as execve() succeeds, what
+ // exactly we execve() shouldn't really matter, since the tracer should kill
+ // us after execve() completes.
+ ScopedThread t([&] {
+ ExecveArray const owned_child_argv = {"/proc/self/exe",
+ "--this_flag_shouldnt_exist"};
+ char* const* const child_argv = owned_child_argv.get();
+ execve(child_argv[0], child_argv, /* envp = */ nullptr);
+ TEST_PCHECK_MSG(false, "Survived execve? (thread)");
+ });
+ t.Join();
+ TEST_CHECK_MSG(false, "Survived execve? (main)");
+ _exit(1);
+}
+
+INSTANTIATE_TEST_SUITE_P(TraceExec, PtraceExecveTest, ::testing::Bool());
+
+// This test has expectations on when syscall-enter/exit-stops occur that are
+// violated if saving occurs, since saving interrupts all syscalls, causing
+// premature syscall-exit.
+TEST(PtraceTest,
+ ExitWhenParentIsNotTracer_Syscall_TraceVfork_TraceVforkDone_NoRandomSave) {
+ constexpr int kExitTraceeExitCode = 99;
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Block SIGCHLD so it doesn't interrupt wait4.
+ sigset_t mask;
+ TEST_PCHECK(sigemptyset(&mask) == 0);
+ TEST_PCHECK(sigaddset(&mask, SIGCHLD) == 0);
+ TEST_PCHECK(sigprocmask(SIG_SETMASK, &mask, nullptr) == 0);
+ MaybeSave();
+
+ // Enable tracing, then raise SIGSTOP and expect our parent to suppress it.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+ RaiseSignal(SIGSTOP);
+ MaybeSave();
+
+ // Spawn a vfork child that exits immediately, and reap it. Don't save after
+ // the vfork since the tracer expects wait4 to be the child's next syscall.
+ pid_t const pid = vfork();
+ if (pid == 0) {
+ _exit(kExitTraceeExitCode);
+ }
+ TEST_PCHECK_MSG(pid > 0, "vfork failed");
+
+ int status;
+ TEST_PCHECK(wait4(pid, &status, 0, nullptr) > 0);
+ MaybeSave();
+ TEST_CHECK(WIFEXITED(status) && WEXITSTATUS(status) == kExitTraceeExitCode);
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Enable PTRACE_O_TRACEVFORK so we can get the ID of the grandchild,
+ // PTRACE_O_TRACEVFORKDONE so we can observe PTRACE_EVENT_VFORK_DONE, and
+ // PTRACE_O_TRACESYSGOOD so syscall-enter/exit-stops are unambiguously
+ // indicated by a stop signal of SIGTRAP|0x80 rather than just SIGTRAP.
+ ASSERT_THAT(ptrace(PTRACE_SETOPTIONS, child_pid, 0,
+ PTRACE_O_TRACEVFORK | PTRACE_O_TRACEVFORKDONE |
+ PTRACE_O_TRACESYSGOOD),
+ SyscallSucceeds());
+
+ // Suppress the SIGSTOP and wait for the child to report PTRACE_EVENT_VFORK.
+ // Get the new process' ID from the event.
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_VFORK << 8), status >> 8);
+ unsigned long eventmsg;
+ ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, child_pid, 0, &eventmsg),
+ SyscallSucceeds());
+ pid_t const grandchild_pid = eventmsg;
+
+ // The grandchild should be traced by us and in signal-delivery-stop by
+ // SIGSTOP due to PTRACE_O_TRACEVFORK. This allows us to wait on it even
+ // though we're not its parent.
+ ASSERT_THAT(waitpid(grandchild_pid, &status, 0),
+ SyscallSucceedsWithValue(grandchild_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Resume the child with PTRACE_SYSCALL. Since the grandchild is still in
+ // signal-delivery-stop, the child should remain in vfork() waiting for the
+ // grandchild to exec or exit.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Suppress the grandchild's SIGSTOP and wait for the grandchild to exit. Pass
+ // WNOWAIT to waitid() so that we don't acknowledge the grandchild's exit yet.
+ ASSERT_THAT(ptrace(PTRACE_CONT, grandchild_pid, 0, 0), SyscallSucceeds());
+ siginfo_t siginfo = {};
+ ASSERT_THAT(waitid(P_PID, grandchild_pid, &siginfo, WEXITED | WNOWAIT),
+ SyscallSucceeds());
+ EXPECT_EQ(SIGCHLD, siginfo.si_signo);
+ EXPECT_EQ(CLD_EXITED, siginfo.si_code);
+ EXPECT_EQ(kExitTraceeExitCode, siginfo.si_status);
+ EXPECT_EQ(grandchild_pid, siginfo.si_pid);
+ EXPECT_EQ(getuid(), siginfo.si_uid);
+
+ // The child should now be in PTRACE_EVENT_VFORK_DONE stop. The event
+ // message should still be the grandchild's PID.
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_EQ(SIGTRAP | (PTRACE_EVENT_VFORK_DONE << 8), status >> 8);
+ ASSERT_THAT(ptrace(PTRACE_GETEVENTMSG, child_pid, 0, &eventmsg),
+ SyscallSucceeds());
+ EXPECT_EQ(grandchild_pid, eventmsg);
+
+ // Resume the child with PTRACE_SYSCALL again and expect it to enter
+ // syscall-exit-stop for vfork() or clone(), either of which should return the
+ // grandchild's PID from the syscall. Aside from PTRACE_O_TRACESYSGOOD,
+ // syscall-stops are distinguished from signal-delivery-stop by
+ // PTRACE_GETSIGINFO returning a siginfo for which si_code == SIGTRAP or
+ // SIGTRAP|0x80.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
+ << " status " << status;
+ ASSERT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo),
+ SyscallSucceeds());
+ EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80))
+ << "si_code = " << siginfo.si_code;
+
+ {
+ struct user_regs_struct regs = {};
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
+ EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone)
+ << "orig_rax = " << regs.orig_rax;
+ EXPECT_EQ(grandchild_pid, regs.rax);
+#elif defined(__aarch64__)
+ EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8];
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
+#endif // defined(__x86_64__)
+ }
+
+ // After this point, the child will be making wait4 syscalls that will be
+ // interrupted by saving, so saving is not permitted. Note that this is
+ // explicitly released below once the grandchild exits.
+ DisableSave ds;
+
+ // Resume the child with PTRACE_SYSCALL again and expect it to enter
+ // syscall-enter-stop for wait4().
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
+ << " status " << status;
+ ASSERT_THAT(ptrace(PTRACE_GETSIGINFO, child_pid, 0, &siginfo),
+ SyscallSucceeds());
+ EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80))
+ << "si_code = " << siginfo.si_code;
+#ifdef __x86_64__
+ {
+ EXPECT_THAT(ptrace(PTRACE_PEEKUSER, child_pid,
+ offsetof(struct user_regs_struct, orig_rax), 0),
+ SyscallSucceedsWithValue(SYS_wait4));
+ }
+#endif // defined(__x86_64__)
+
+ // Resume the child with PTRACE_SYSCALL again. Since the grandchild is
+ // waiting for the tracer (us) to acknowledge its exit first, wait4 should
+ // block.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Acknowledge the grandchild's exit.
+ ASSERT_THAT(waitpid(grandchild_pid, &status, 0),
+ SyscallSucceedsWithValue(grandchild_pid));
+ ds.reset();
+
+ // Now the child should enter syscall-exit-stop for wait4, returning with the
+ // grandchild's PID.
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
+ << " status " << status;
+ {
+ struct user_regs_struct regs = {};
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
+ EXPECT_EQ(SYS_wait4, regs.orig_rax);
+ EXPECT_EQ(grandchild_pid, regs.rax);
+#elif defined(__aarch64__)
+ EXPECT_EQ(SYS_wait4, regs.regs[8]);
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
+#endif // defined(__x86_64__)
+ }
+
+ // Detach from the child and wait for it to exit.
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+// These tests require knowledge of the architecture-specific syscall
+// convention.
+#ifdef __x86_64__
+TEST(PtraceTest, Int3) {
+ SKIP_IF(PlatformSupportInt3() == PlatformSupport::NotSupported);
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Enable tracing.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+
+ // Interrupt 3 - trap to debugger
+ asm("int3");
+
+ _exit(56);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << " status " << status;
+
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+
+ // The child should validate the injected return value and then exit normally.
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 56)
+ << " status " << status;
+}
+
+TEST(PtraceTest, Sysemu_PokeUser) {
+ constexpr int kSysemuHelperFirstExitCode = 126;
+ constexpr uint64_t kSysemuInjectedExitGroupReturn = 42;
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Enable tracing, then raise SIGSTOP and expect our parent to suppress it.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ RaiseSignal(SIGSTOP);
+
+ // Try to exit_group, expecting the tracer to skip the syscall and set its
+ // own return value.
+ int const rv = syscall(SYS_exit_group, kSysemuHelperFirstExitCode);
+ TEST_PCHECK_MSG(rv == kSysemuInjectedExitGroupReturn,
+ "exit_group returned incorrect value");
+
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Suppress the SIGSTOP and wait for the child to enter syscall-enter-stop
+ // for its first exit_group syscall.
+ ASSERT_THAT(ptrace(kPtraceSysemu, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << " status " << status;
+
+ struct user_regs_struct regs = {};
+ ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
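+ // At syscall-enter-stop the kernel has already replaced rax with -ENOSYS;
+ // the syscall number is preserved in orig_rax and the first argument in rdi.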
+ EXPECT_EQ(SYS_exit_group, regs.orig_rax);
+ EXPECT_EQ(-ENOSYS, regs.rax);
+ EXPECT_EQ(kSysemuHelperFirstExitCode, regs.rdi);
+
+ // Replace the exit_group return value, then resume the child, which should
+ // automatically skip the syscall.
+ ASSERT_THAT(
+ ptrace(PTRACE_POKEUSER, child_pid, offsetof(struct user_regs_struct, rax),
+ kSysemuInjectedExitGroupReturn),
+ SyscallSucceeds());
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+
+ // The child should validate the injected return value and then exit normally.
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+// This test also cares about syscall-exit-stop.
+TEST(PtraceTest, ERESTART_NoRandomSave) {
+ constexpr int kSigno = SIGUSR1;
+
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+
+ // Ignore, but unblock, kSigno.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ TEST_PCHECK(sigfillset(&sa.sa_mask) == 0);
+ TEST_PCHECK(sigaction(kSigno, &sa, nullptr) == 0);
+ MaybeSave();
+ TEST_PCHECK(sigprocmask(SIG_UNBLOCK, &sa.sa_mask, nullptr) == 0);
+ MaybeSave();
+
+ // Enable tracing, then raise SIGSTOP and expect our parent to suppress it.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ RaiseSignal(SIGSTOP);
+
+ // Invoke the pause syscall, which normally should not return until we
+ // receive a signal that "either terminates the process or causes the
+ // invocation of a signal-catching function".
+ pause();
+
+ _exit(0);
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // After this point, the child's pause syscall will be interrupted by saving,
+ // so saving is not permitted. Note that this is explicitly released below
+ // once the child is stopped.
+ DisableSave ds;
+
+ // Suppress the SIGSTOP and wait for the child to enter syscall-enter-stop for
+ // its pause syscall.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << " status " << status;
+
+ struct user_regs_struct regs = {};
+ ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ EXPECT_EQ(SYS_pause, regs.orig_rax);
+ EXPECT_EQ(-ENOSYS, regs.rax);
+
+ // Resume the child with PTRACE_SYSCALL and expect it to block in the pause
+ // syscall.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Send the child kSigno, causing it to return ERESTARTNOHAND and enter
+ // syscall-exit-stop from the pause syscall.
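+ // ERESTARTNOHAND is a kernel-internal errno that is never returned to
+ // userspace; it is observable only via ptrace at syscall-exit-stop.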
+ constexpr int ERESTARTNOHAND = 514;
+ ASSERT_THAT(kill(child_pid, kSigno), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
+ << " status " << status;
+ ds.reset();
+
+ ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ EXPECT_EQ(SYS_pause, regs.orig_rax);
+ EXPECT_EQ(-ERESTARTNOHAND, regs.rax);
+
+ // Replace the return value from pause with 0, causing pause to not be
+ // restarted despite kSigno being ignored.
+ ASSERT_THAT(ptrace(PTRACE_POKEUSER, child_pid,
+ offsetof(struct user_regs_struct, rax), 0),
+ SyscallSucceeds());
+
+ // Detach from the child and wait for it to exit.
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+#endif // defined(__x86_64__)
+
+TEST(PtraceTest, Seize_Interrupt_Listen) {
+ volatile long child_should_spin = 1;
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+ while (child_should_spin) {
+ SleepSafe(absl::Seconds(1));
+ }
+ _exit(1);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Attach to the child with PTRACE_SEIZE; doing so should not stop the child.
+ ASSERT_THAT(ptrace(PTRACE_SEIZE, child_pid, 0, 0), SyscallSucceeds());
+ int status;
+ EXPECT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Stop the child with PTRACE_INTERRUPT.
+ ASSERT_THAT(ptrace(PTRACE_INTERRUPT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_EQ(SIGTRAP | (kPtraceEventStop << 8), status >> 8);
+
+ // Unset child_should_spin to verify that the child never leaves the spin
+ // loop.
+ ASSERT_THAT(ptrace(PTRACE_POKEDATA, child_pid, &child_should_spin, 0),
+ SyscallSucceeds());
+
+ // Send SIGSTOP to the child, then resume it, allowing it to proceed to
+ // signal-delivery-stop.
+ ASSERT_THAT(kill(child_pid, SIGSTOP), SyscallSucceeds());
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // Release the child from signal-delivery-stop without suppressing the
+ // SIGSTOP, causing it to enter group-stop.
+ ASSERT_THAT(ptrace(PTRACE_CONT, child_pid, 0, SIGSTOP), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_EQ(SIGSTOP | (kPtraceEventStop << 8), status >> 8);
+
+ // "The state of the tracee after PTRACE_LISTEN is somewhat of a gray area: it
+ // is not in any ptrace-stop (ptrace commands won't work on it, and it will
+ // deliver waitpid(2) notifications), but it also may be considered 'stopped'
+ // because it is not executing instructions (is not scheduled), and if it was
+ // in group-stop before PTRACE_LISTEN, it will not respond to signals until
+ // SIGCONT is received." - ptrace(2).
+ ASSERT_THAT(ptrace(PTRACE_LISTEN, child_pid, 0, 0), SyscallSucceeds());
+ EXPECT_THAT(ptrace(PTRACE_CONT, child_pid, 0, 0),
+ SyscallFailsWithErrno(ESRCH));
+ EXPECT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(kill(child_pid, SIGTERM), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1));
+ EXPECT_THAT(waitpid(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Send SIGCONT to the child, causing it to leave group-stop and re-trap due
+ // to PTRACE_LISTEN.
+ EXPECT_THAT(kill(child_pid, SIGCONT), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_EQ(SIGTRAP | (kPtraceEventStop << 8), status >> 8);
+
+ // Detach the child and expect it to exit due to the SIGTERM we sent while
+ // it was stopped by PTRACE_LISTEN.
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGTERM)
+ << " status " << status;
+}
+
+TEST(PtraceTest, Interrupt_Listen_RequireSeize) {
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+ raise(SIGSTOP);
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Wait for the child to send itself SIGSTOP and enter signal-delivery-stop.
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
+ << " status " << status;
+
+ // PTRACE_INTERRUPT and PTRACE_LISTEN should fail since the child wasn't
+ // attached with PTRACE_SEIZE, leaving the child in signal-delivery-stop.
+ EXPECT_THAT(ptrace(PTRACE_INTERRUPT, child_pid, 0, 0),
+ SyscallFailsWithErrno(EIO));
+ EXPECT_THAT(ptrace(PTRACE_LISTEN, child_pid, 0, 0),
+ SyscallFailsWithErrno(EIO));
+
+ // Suppress SIGSTOP and detach from the child, expecting it to exit normally.
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST(PtraceTest, SeizeSetOptions) {
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // In child process.
+ while (true) {
+ SleepSafe(absl::Seconds(1));
+ }
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Attach to the child with PTRACE_SEIZE while setting PTRACE_O_TRACESYSGOOD.
+ ASSERT_THAT(ptrace(PTRACE_SEIZE, child_pid, 0, PTRACE_O_TRACESYSGOOD),
+ SyscallSucceeds());
+
+ // Stop the child with PTRACE_INTERRUPT.
+ ASSERT_THAT(ptrace(PTRACE_INTERRUPT, child_pid, 0, 0), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_EQ(SIGTRAP | (kPtraceEventStop << 8), status >> 8);
+
+ // Resume the child with PTRACE_SYSCALL and wait for it to enter
+ // syscall-enter-stop. The stop signal status from the syscall stop should be
+ // SIGTRAP|0x80, reflecting PTRACE_O_TRACESYSGOOD.
+ ASSERT_THAT(ptrace(PTRACE_SYSCALL, child_pid, 0, 0), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
+ << " status " << status;
+
+ // Clean up the child.
+ ASSERT_THAT(kill(child_pid, SIGKILL), SyscallSucceeds());
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ if (WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) {
+ // "SIGKILL kills even within system calls (syscall-exit-stop is not
+ // generated prior to death by SIGKILL). The net effect is that SIGKILL
+ // always kills the process (all its threads), even if some threads of the
+ // process are ptraced." - ptrace(2). This is technically true, but...
+ //
+ // When we send SIGKILL to the child, kernel/signal.c:complete_signal() =>
+ // signal_wake_up(resume=1) kicks the tracee out of the syscall-enter-stop.
+ // The pending SIGKILL causes the syscall to be skipped, but the child
+ // thread still reports syscall-exit before checking for pending signals; in
+ // current kernels, this is
+ // arch/x86/entry/common.c:syscall_return_slowpath() =>
+ // syscall_slow_exit_work() =>
+ // include/linux/tracehook.h:tracehook_report_syscall_exit() =>
+ // ptrace_report_syscall() => kernel/signal.c:ptrace_notify() =>
+ // ptrace_do_notify() => ptrace_stop().
+ //
+ // ptrace_stop() sets the task's state to TASK_TRACED and the task's
+ // exit_code to SIGTRAP|0x80 (passed by ptrace_report_syscall()), then calls
+ // freezable_schedule(). freezable_schedule() eventually reaches
+ // __schedule(), which detects signal_pending_state() due to the pending
+ // SIGKILL, sets the task's state back to TASK_RUNNING, and returns without
+ // descheduling. Thus, the task never enters syscall-exit-stop. However, if
+ // our wait4() => kernel/exit.c:wait_task_stopped() racily observes the
+ // TASK_TRACED state and the non-zero exit code set by ptrace_stop() before
+ // __schedule() sets the state back to TASK_RUNNING, it will return the
+ // task's exit_code as status W_STOPCODE(SIGTRAP|0x80). So we get a spurious
+ // syscall-exit-stop notification, and need to wait4() again for task exit.
+ //
+ // gVisor is not susceptible to this race because
+ // kernel.Task.waitCollectTraceeStopLocked() checks specifically for an
+ // active ptraceStop, which is not initiated if SIGKILL is pending.
+ std::cout << "Observed syscall-exit after SIGKILL" << std::endl;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ }
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL)
+ << " status " << status;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ if (absl::GetFlag(FLAGS_ptrace_test_execve_child)) {
+ gvisor::testing::RunExecveChild();
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
new file mode 100644
index 000000000..f9392b9e0
--- /dev/null
+++ b/test/syscalls/linux/pty.cc
@@ -0,0 +1,1627 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ctype.h>
+#include <fcntl.h>
+#include <linux/capability.h>
+#include <linux/major.h>
+#include <poll.h>
+#include <sched.h>
+#include <signal.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/sysmacros.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <termios.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/pty_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Contains;
+using ::testing::Eq;
+using ::testing::Not;
+
+// Tests Unix98 pseudoterminals.
+//
+// These tests assume that /dev/ptmx exists and is associated with a devpts
+// filesystem mounted at /dev/pts/. While a Linux distribution could
+// theoretically place those anywhere, glibc expects those locations, so they
+// are effectively fixed.
+
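+// For reference, the portable libc flow for obtaining a master/slave pair
+// looks roughly like the following (a minimal sketch; the tests below open
+// /dev/ptmx directly and use the OpenSlave helper from pty_util instead):
+//
+//   int master = posix_openpt(O_RDWR | O_NOCTTY);  // opens /dev/ptmx
+//   grantpt(master);             // set up slave ownership and permissions
+//   unlockpt(master);            // clear the slave lock
+//   int slave = open(ptsname(master), O_RDWR);     // e.g. "/dev/pts/4"
+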
+// Minor device number for an unopened ptmx file.
+constexpr int kPtmxMinor = 2;
+
+// The timeout when polling for data from a pty. When data is written to one end
+// of a pty, Linux asynchronously makes it available to the other end, so we
+// have to wait.
+constexpr absl::Duration kTimeout = absl::Seconds(20);
+
+// The maximum line size in bytes returned per read from a pty file.
+constexpr int kMaxLineSize = 4096;
+
+constexpr char kMasterPath[] = "/dev/ptmx";
+
+// glibc defines its own, different, version of struct termios. We care about
+// what the kernel does, not glibc.
+#define KERNEL_NCCS 19
+struct kernel_termios {
+ tcflag_t c_iflag;
+ tcflag_t c_oflag;
+ tcflag_t c_cflag;
+ tcflag_t c_lflag;
+ cc_t c_line;
+ cc_t c_cc[KERNEL_NCCS];
+};
+
+bool operator==(struct kernel_termios const& a,
+ struct kernel_termios const& b) {
+ return memcmp(&a, &b, sizeof(a)) == 0;
+}
+
+// Returns the termios-style control character for the passed character.
+//
+// e.g., for Ctrl-C, i.e., ^C, call ControlCharacter('C').
+//
+// Standard control characters are ASCII bytes 0 through 31.
+constexpr char ControlCharacter(char c) {
+ // A is 1, B is 2, etc.
+ return c - 'A' + 1;
+}
+
+// Returns the printable character the given control character represents.
+constexpr char FromControlCharacter(char c) { return c + 'A' - 1; }
+
+// Returns true if c is a control character.
+//
+// Standard control characters are ASCII bytes 0 through 31.
+constexpr bool IsControlCharacter(char c) { return c <= 31; }
+
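+// Compile-time sanity checks of the mapping above: ^C is ETX (0x03), and the
+// conversion round-trips.
+static_assert(ControlCharacter('C') == 3, "^C should be ETX (0x03)");
+static_assert(FromControlCharacter(ControlCharacter('Q')) == 'Q',
+ "control-character mapping should round-trip");
+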
+struct Field {
+ const char* name;
+ uint64_t mask;
+ uint64_t value;
+};
+
+// ParseFields returns a string representation of value, using the names in
+// fields.
+std::string ParseFields(const Field* fields, size_t len, uint64_t value) {
+ bool first = true;
+ std::string s;
+ for (size_t i = 0; i < len; i++) {
+ const Field f = fields[i];
+ if ((value & f.mask) == f.value) {
+ if (!first) {
+ s += "|";
+ }
+ s += f.name;
+ first = false;
+ value &= ~f.mask;
+ }
+ }
+
+ if (value) {
+ if (!first) {
+ s += "|";
+ }
+ absl::StrAppend(&s, value);
+ }
+
+ return s;
+}
+
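+// For example, with the lflag table below, a c_lflag of (ISIG | ICANON)
+// renders as "ISIG|ICANON"; any bits not covered by a table are appended in
+// decimal.
+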
+const Field kIflagFields[] = {
+ {"IGNBRK", IGNBRK, IGNBRK}, {"BRKINT", BRKINT, BRKINT},
+ {"IGNPAR", IGNPAR, IGNPAR}, {"PARMRK", PARMRK, PARMRK},
+ {"INPCK", INPCK, INPCK}, {"ISTRIP", ISTRIP, ISTRIP},
+ {"INLCR", INLCR, INLCR}, {"IGNCR", IGNCR, IGNCR},
+ {"ICRNL", ICRNL, ICRNL}, {"IUCLC", IUCLC, IUCLC},
+ {"IXON", IXON, IXON}, {"IXANY", IXANY, IXANY},
+ {"IXOFF", IXOFF, IXOFF}, {"IMAXBEL", IMAXBEL, IMAXBEL},
+ {"IUTF8", IUTF8, IUTF8},
+};
+
+const Field kOflagFields[] = {
+ {"OPOST", OPOST, OPOST}, {"OLCUC", OLCUC, OLCUC},
+ {"ONLCR", ONLCR, ONLCR}, {"OCRNL", OCRNL, OCRNL},
+ {"ONOCR", ONOCR, ONOCR}, {"ONLRET", ONLRET, ONLRET},
+ {"OFILL", OFILL, OFILL}, {"OFDEL", OFDEL, OFDEL},
+ {"NL0", NLDLY, NL0}, {"NL1", NLDLY, NL1},
+ {"CR0", CRDLY, CR0}, {"CR1", CRDLY, CR1},
+ {"CR2", CRDLY, CR2}, {"CR3", CRDLY, CR3},
+ {"TAB0", TABDLY, TAB0}, {"TAB1", TABDLY, TAB1},
+ {"TAB2", TABDLY, TAB2}, {"TAB3", TABDLY, TAB3},
+ {"BS0", BSDLY, BS0}, {"BS1", BSDLY, BS1},
+ {"FF0", FFDLY, FF0}, {"FF1", FFDLY, FF1},
+ {"VT0", VTDLY, VT0}, {"VT1", VTDLY, VT1},
+ {"XTABS", XTABS, XTABS},
+};
+
+#ifndef IBSHIFT
+// Shift from CBAUD to CIBAUD.
+#define IBSHIFT 16
+#endif
+
+const Field kCflagFields[] = {
+ {"B0", CBAUD, B0},
+ {"B50", CBAUD, B50},
+ {"B75", CBAUD, B75},
+ {"B110", CBAUD, B110},
+ {"B134", CBAUD, B134},
+ {"B150", CBAUD, B150},
+ {"B200", CBAUD, B200},
+ {"B300", CBAUD, B300},
+ {"B600", CBAUD, B600},
+ {"B1200", CBAUD, B1200},
+ {"B1800", CBAUD, B1800},
+ {"B2400", CBAUD, B2400},
+ {"B4800", CBAUD, B4800},
+ {"B9600", CBAUD, B9600},
+ {"B19200", CBAUD, B19200},
+ {"B38400", CBAUD, B38400},
+ {"CS5", CSIZE, CS5},
+ {"CS6", CSIZE, CS6},
+ {"CS7", CSIZE, CS7},
+ {"CS8", CSIZE, CS8},
+ {"CSTOPB", CSTOPB, CSTOPB},
+ {"CREAD", CREAD, CREAD},
+ {"PARENB", PARENB, PARENB},
+ {"PARODD", PARODD, PARODD},
+ {"HUPCL", HUPCL, HUPCL},
+ {"CLOCAL", CLOCAL, CLOCAL},
+ {"B57600", CBAUD, B57600},
+ {"B115200", CBAUD, B115200},
+ {"B230400", CBAUD, B230400},
+ {"B460800", CBAUD, B460800},
+ {"B500000", CBAUD, B500000},
+ {"B576000", CBAUD, B576000},
+ {"B921600", CBAUD, B921600},
+ {"B1000000", CBAUD, B1000000},
+ {"B1152000", CBAUD, B1152000},
+ {"B1500000", CBAUD, B1500000},
+ {"B2000000", CBAUD, B2000000},
+ {"B2500000", CBAUD, B2500000},
+ {"B3000000", CBAUD, B3000000},
+ {"B3500000", CBAUD, B3500000},
+ {"B4000000", CBAUD, B4000000},
+ {"CMSPAR", CMSPAR, CMSPAR},
+ {"CRTSCTS", CRTSCTS, CRTSCTS},
+ {"IB0", CIBAUD, B0 << IBSHIFT},
+ {"IB50", CIBAUD, B50 << IBSHIFT},
+ {"IB75", CIBAUD, B75 << IBSHIFT},
+ {"IB110", CIBAUD, B110 << IBSHIFT},
+ {"IB134", CIBAUD, B134 << IBSHIFT},
+ {"IB150", CIBAUD, B150 << IBSHIFT},
+ {"IB200", CIBAUD, B200 << IBSHIFT},
+ {"IB300", CIBAUD, B300 << IBSHIFT},
+ {"IB600", CIBAUD, B600 << IBSHIFT},
+ {"IB1200", CIBAUD, B1200 << IBSHIFT},
+ {"IB1800", CIBAUD, B1800 << IBSHIFT},
+ {"IB2400", CIBAUD, B2400 << IBSHIFT},
+ {"IB4800", CIBAUD, B4800 << IBSHIFT},
+ {"IB9600", CIBAUD, B9600 << IBSHIFT},
+ {"IB19200", CIBAUD, B19200 << IBSHIFT},
+ {"IB38400", CIBAUD, B38400 << IBSHIFT},
+ {"IB57600", CIBAUD, B57600 << IBSHIFT},
+ {"IB115200", CIBAUD, B115200 << IBSHIFT},
+ {"IB230400", CIBAUD, B230400 << IBSHIFT},
+ {"IB460800", CIBAUD, B460800 << IBSHIFT},
+ {"IB500000", CIBAUD, B500000 << IBSHIFT},
+ {"IB576000", CIBAUD, B576000 << IBSHIFT},
+ {"IB921600", CIBAUD, B921600 << IBSHIFT},
+ {"IB1000000", CIBAUD, B1000000 << IBSHIFT},
+ {"IB1152000", CIBAUD, B1152000 << IBSHIFT},
+ {"IB1500000", CIBAUD, B1500000 << IBSHIFT},
+ {"IB2000000", CIBAUD, B2000000 << IBSHIFT},
+ {"IB2500000", CIBAUD, B2500000 << IBSHIFT},
+ {"IB3000000", CIBAUD, B3000000 << IBSHIFT},
+ {"IB3500000", CIBAUD, B3500000 << IBSHIFT},
+ {"IB4000000", CIBAUD, B4000000 << IBSHIFT},
+};
+
+const Field kLflagFields[] = {
+ {"ISIG", ISIG, ISIG}, {"ICANON", ICANON, ICANON},
+ {"XCASE", XCASE, XCASE}, {"ECHO", ECHO, ECHO},
+ {"ECHOE", ECHOE, ECHOE}, {"ECHOK", ECHOK, ECHOK},
+ {"ECHONL", ECHONL, ECHONL}, {"NOFLSH", NOFLSH, NOFLSH},
+ {"TOSTOP", TOSTOP, TOSTOP}, {"ECHOCTL", ECHOCTL, ECHOCTL},
+ {"ECHOPRT", ECHOPRT, ECHOPRT}, {"ECHOKE", ECHOKE, ECHOKE},
+ {"FLUSHO", FLUSHO, FLUSHO}, {"PENDIN", PENDIN, PENDIN},
+ {"IEXTEN", IEXTEN, IEXTEN}, {"EXTPROC", EXTPROC, EXTPROC},
+};
+
+std::string FormatCC(char c) {
+ if (isgraph(c)) {
+ return std::string(1, c);
+ } else if (c == ' ') {
+ return " ";
+ } else if (c == '\t') {
+ return "\\t";
+ } else if (c == '\r') {
+ return "\\r";
+ } else if (c == '\n') {
+ return "\\n";
+ } else if (c == '\0') {
+ return "\\0";
+ } else if (IsControlCharacter(c)) {
+ return absl::StrCat("^", std::string(1, FromControlCharacter(c)));
+ }
+ return absl::StrCat("\\x", absl::Hex(c));
+}
+
+std::ostream& operator<<(std::ostream& os, struct kernel_termios const& a) {
+ os << "{ c_iflag = "
+ << ParseFields(kIflagFields, ABSL_ARRAYSIZE(kIflagFields), a.c_iflag);
+ os << ", c_oflag = "
+ << ParseFields(kOflagFields, ABSL_ARRAYSIZE(kOflagFields), a.c_oflag);
+ os << ", c_cflag = "
+ << ParseFields(kCflagFields, ABSL_ARRAYSIZE(kCflagFields), a.c_cflag);
+ os << ", c_lflag = "
+ << ParseFields(kLflagFields, ABSL_ARRAYSIZE(kLflagFields), a.c_lflag);
+ os << ", c_line = " << a.c_line;
+ os << ", c_cc = { [VINTR] = '" << FormatCC(a.c_cc[VINTR]);
+ os << "', [VQUIT] = '" << FormatCC(a.c_cc[VQUIT]);
+ os << "', [VERASE] = '" << FormatCC(a.c_cc[VERASE]);
+ os << "', [VKILL] = '" << FormatCC(a.c_cc[VKILL]);
+ os << "', [VEOF] = '" << FormatCC(a.c_cc[VEOF]);
+ os << "', [VTIME] = '" << static_cast<int>(a.c_cc[VTIME]);
+ os << "', [VMIN] = " << static_cast<int>(a.c_cc[VMIN]);
+ os << ", [VSWTC] = '" << FormatCC(a.c_cc[VSWTC]);
+ os << "', [VSTART] = '" << FormatCC(a.c_cc[VSTART]);
+ os << "', [VSTOP] = '" << FormatCC(a.c_cc[VSTOP]);
+ os << "', [VSUSP] = '" << FormatCC(a.c_cc[VSUSP]);
+ os << "', [VEOL] = '" << FormatCC(a.c_cc[VEOL]);
+ os << "', [VREPRINT] = '" << FormatCC(a.c_cc[VREPRINT]);
+ os << "', [VDISCARD] = '" << FormatCC(a.c_cc[VDISCARD]);
+ os << "', [VWERASE] = '" << FormatCC(a.c_cc[VWERASE]);
+ os << "', [VLNEXT] = '" << FormatCC(a.c_cc[VLNEXT]);
+ os << "', [VEOL2] = '" << FormatCC(a.c_cc[VEOL2]);
+ os << "'}";
+ return os;
+}
+
+// Return the default termios settings for a new terminal.
+struct kernel_termios DefaultTermios() {
+ struct kernel_termios t = {};
+ t.c_iflag = IXON | ICRNL;
+ t.c_oflag = OPOST | ONLCR;
+ t.c_cflag = B38400 | CSIZE | CS8 | CREAD;
+ t.c_lflag = ISIG | ICANON | ECHO | ECHOE | ECHOK | ECHOCTL | ECHOKE | IEXTEN;
+ t.c_line = 0;
+ t.c_cc[VINTR] = ControlCharacter('C');
+ t.c_cc[VQUIT] = ControlCharacter('\\');
+ t.c_cc[VERASE] = '\x7f';
+ t.c_cc[VKILL] = ControlCharacter('U');
+ t.c_cc[VEOF] = ControlCharacter('D');
+ t.c_cc[VTIME] = '\0';
+ t.c_cc[VMIN] = 1;
+ t.c_cc[VSWTC] = '\0';
+ t.c_cc[VSTART] = ControlCharacter('Q');
+ t.c_cc[VSTOP] = ControlCharacter('S');
+ t.c_cc[VSUSP] = ControlCharacter('Z');
+ t.c_cc[VEOL] = '\0';
+ t.c_cc[VREPRINT] = ControlCharacter('R');
+ t.c_cc[VDISCARD] = ControlCharacter('O');
+ t.c_cc[VWERASE] = ControlCharacter('W');
+ t.c_cc[VLNEXT] = ControlCharacter('V');
+ t.c_cc[VEOL2] = '\0';
+ return t;
+}
+
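+// For example, with these defaults a '\r' typed at the master is delivered
+// to the slave as '\n' (ICRNL), and a '\n' written by the slave reaches the
+// master as "\r\n" (ONLCR); several tests below toggle exactly these flags.
+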
+// PollAndReadFd tries to read count bytes from fd into buf within timeout.
+//
+// Returns a partial read if some bytes were read.
+//
+// fd must be non-blocking.
+PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count,
+ absl::Duration timeout) {
+ absl::Time end = absl::Now() + timeout;
+
+ size_t completed = 0;
+ absl::Duration remaining;
+ while ((remaining = end - absl::Now()) > absl::ZeroDuration()) {
+ struct pollfd pfd = {fd, POLLIN, 0};
+ int ret = RetryEINTR(poll)(&pfd, 1, absl::ToInt64Milliseconds(remaining));
+ if (ret < 0) {
+ return PosixError(errno, "poll failed");
+ } else if (ret == 0) {
+ // Timed out.
+ continue;
+ } else if (ret != 1) {
+ return PosixError(EINVAL, absl::StrCat("Bad poll ret ", ret));
+ }
+
+ ssize_t n =
+ ReadFd(fd, static_cast<char*>(buf) + completed, count - completed);
+ if (n < 0) {
+ if (errno == EAGAIN) {
+        // Linux sometimes returns EAGAIN from this read, despite the fact
+        // that poll reported the fd as readable. Just do as we are told and
+        // try again.
+ continue;
+ }
+ return PosixError(errno, "read failed");
+ }
+ completed += n;
+ if (completed >= count) {
+ return completed;
+ }
+ }
+
+ if (completed) {
+ return completed;
+ }
+ return PosixError(ETIMEDOUT, "Poll timed out");
+}
+
+TEST(PtyTrunc, Truncate) {
+ // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to
+ // (f)truncate should.
+ FileDescriptor master =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC));
+ int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master));
+ std::string spath = absl::StrCat("/dev/pts/", n);
+ FileDescriptor slave =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC));
+
+ EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(BasicPtyTest, StatUnopenedMaster) {
+ struct stat s;
+ ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds());
+
+ EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor));
+ EXPECT_EQ(s.st_size, 0);
+ EXPECT_EQ(s.st_blocks, 0);
+
+ // ptmx attached to a specific devpts mount uses block size 1024. See
+ // fs/devpts/inode.c:devpts_fill_super.
+ //
+ // The global ptmx device uses the block size of the filesystem it is created
+ // on (which is usually 4096 for disk filesystems).
+ EXPECT_THAT(s.st_blksize, AnyOf(Eq(1024), Eq(4096)));
+}
+
+// Waits for count bytes to be readable from fd. Unlike poll, which can return
+// before all data is moved into a pty's read buffer, this function waits for
+// all count bytes to become readable.
+PosixErrorOr<int> WaitUntilReceived(int fd, int count) {
+ int buffered = -1;
+ absl::Duration remaining;
+ absl::Time end = absl::Now() + kTimeout;
+ while ((remaining = end - absl::Now()) > absl::ZeroDuration()) {
+ if (ioctl(fd, FIONREAD, &buffered) < 0) {
+ return PosixError(errno, "failed FIONREAD ioctl");
+ }
+ if (buffered >= count) {
+ return buffered;
+ }
+ absl::SleepFor(absl::Milliseconds(500));
+ }
+ return PosixError(
+ ETIMEDOUT,
+ absl::StrFormat(
+ "FIONREAD timed out, receiving only %d of %d expected bytes",
+ buffered, count));
+}
+
+// Verifies that there is nothing left to read from fd.
+void ExpectFinished(const FileDescriptor& fd) {
+ // Nothing more to read.
+ char c;
+ EXPECT_THAT(ReadFd(fd.get(), &c, 1), SyscallFailsWithErrno(EAGAIN));
+}
+
+// Verifies that we can read expected bytes from fd into buf.
+void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) {
+ size_t n = ASSERT_NO_ERRNO_AND_VALUE(
+ PollAndReadFd(fd.get(), buf, expected, kTimeout));
+ EXPECT_EQ(expected, n);
+}
+
+TEST(BasicPtyTest, OpenMasterSlave) {
+ FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master));
+}
+
+// The slave entry in /dev/pts/ disappears when the master is closed, even if
+// the slave is still open.
+TEST(BasicPtyTest, SlaveEntryGoneAfterMasterClose) {
+ FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master));
+
+ // Get pty index.
+ int index = -1;
+ ASSERT_THAT(ioctl(master.get(), TIOCGPTN, &index), SyscallSucceeds());
+
+ std::string path = absl::StrCat("/dev/pts/", index);
+
+ struct stat st;
+ EXPECT_THAT(stat(path.c_str(), &st), SyscallSucceeds());
+
+ master.reset();
+
+ EXPECT_THAT(stat(path.c_str(), &st), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(BasicPtyTest, Getdents) {
+ FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ int index1 = -1;
+ ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds());
+ FileDescriptor slave1 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master1));
+
+ FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ int index2 = -1;
+ ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds());
+ FileDescriptor slave2 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master2));
+
+ // The directory contains ptmx, index1, and index2. (Plus any additional PTYs
+ // unrelated to this test.)
+
+ std::vector<std::string> contents =
+ ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev/pts/", true));
+ EXPECT_THAT(contents, Contains(absl::StrCat(index1)));
+ EXPECT_THAT(contents, Contains(absl::StrCat(index2)));
+
+ master2.reset();
+
+ // The directory contains ptmx and index1, but not index2 since the master is
+ // closed. (Plus any additional PTYs unrelated to this test.)
+
+ contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev/pts/", true));
+ EXPECT_THAT(contents, Contains(absl::StrCat(index1)));
+ EXPECT_THAT(contents, Not(Contains(absl::StrCat(index2))));
+
+ // N.B. devpts supports legacy "single-instance" mode and new "multi-instance"
+ // mode. In legacy mode, devpts does not contain a "ptmx" device (the distro
+ // must use mknod to create it somewhere, presumably /dev/ptmx).
+ // Multi-instance mode does include a "ptmx" device tied to that mount.
+ //
+ // We don't check for the presence or absence of "ptmx", as distros vary in
+ // their usage of the two modes.
+}
+
+class PtyTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_));
+ }
+
+ void DisableCanonical() {
+ struct kernel_termios t = {};
+ EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds());
+ t.c_lflag &= ~ICANON;
+ EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+ }
+
+ void EnableCanonical() {
+ struct kernel_termios t = {};
+ EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds());
+ t.c_lflag |= ICANON;
+ EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+ }
+
+ // Master and slave ends of the PTY. Non-blocking.
+ FileDescriptor master_;
+ FileDescriptor slave_;
+};
+
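+// N.B. DisableCanonical above only clears ICANON. Putting a terminal into
+// fully "raw" mode would typically also disable echoing, signals, and output
+// processing, e.g. via cfmakeraw(3) (a sketch, not used by these tests):
+//
+//   struct termios t;
+//   tcgetattr(fd, &t);
+//   cfmakeraw(&t);  // clears ICANON, ECHO, ISIG, OPOST, ICRNL, IXON, ...
+//   tcsetattr(fd, TCSANOW, &t);
+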
+// Master to slave sanity test.
+TEST_F(PtyTest, WriteMasterToSlave) {
+ // N.B. by default, the slave reads nothing until the master writes a newline.
+ constexpr char kBuf[] = "hello\n";
+
+ EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1),
+ SyscallSucceedsWithValue(sizeof(kBuf) - 1));
+
+ // Linux moves data from the master to the slave via async work scheduled via
+ // tty_flip_buffer_push. Since it is asynchronous, the data may not be
+ // available for reading immediately. Instead we must poll and assert that it
+ // becomes available "soon".
+
+ char buf[sizeof(kBuf)] = {};
+ ExpectReadable(slave_, sizeof(buf) - 1, buf);
+
+ EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0);
+}
+
+// Slave to master sanity test.
+TEST_F(PtyTest, WriteSlaveToMaster) {
+  // N.B. ONLCR is enabled by default, so the newline written by the slave
+  // arrives at the master as a carriage return followed by a newline.
+ constexpr char kInput[] = "hello\n";
+ constexpr char kExpected[] = "hello\r\n";
+
+ EXPECT_THAT(WriteFd(slave_.get(), kInput, sizeof(kInput) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput) - 1));
+
+  // Linux moves data from the slave to the master via async work scheduled via
+ // tty_flip_buffer_push. Since it is asynchronous, the data may not be
+ // available for reading immediately. Instead we must poll and assert that it
+ // becomes available "soon".
+
+ char buf[sizeof(kExpected)] = {};
+ ExpectReadable(master_, sizeof(buf) - 1, buf);
+
+ EXPECT_EQ(memcmp(buf, kExpected, sizeof(kExpected)), 0);
+}
+
+TEST_F(PtyTest, WriteInvalidUTF8) {
+ char c = 0xff;
+ ASSERT_THAT(syscall(__NR_write, master_.get(), &c, sizeof(c)),
+ SyscallSucceedsWithValue(sizeof(c)));
+}
+
+// Both the master and slave report the standard default termios settings.
+//
+// Note that TCGETS on the master actually redirects to the slave (see comment
+// on MasterTermiosUnchangable).
+TEST_F(PtyTest, DefaultTermios) {
+ struct kernel_termios t = {};
+ EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds());
+ EXPECT_EQ(t, DefaultTermios());
+
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds());
+ EXPECT_EQ(t, DefaultTermios());
+}
+
+// Changing termios from the master actually affects the slave.
+//
+// TCSETS on the master actually redirects to the slave (see comment on
+// MasterTermiosUnchangable).
+TEST_F(PtyTest, TermiosAffectsSlave) {
+ struct kernel_termios master_termios = {};
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds());
+ master_termios.c_lflag ^= ICANON;
+ EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds());
+
+ struct kernel_termios slave_termios = {};
+ EXPECT_THAT(ioctl(slave_.get(), TCGETS, &slave_termios), SyscallSucceeds());
+ EXPECT_EQ(master_termios, slave_termios);
+}
+
+// The master end of the pty has termios:
+//
+// struct kernel_termios t = {
+// .c_iflag = 0;
+// .c_oflag = 0;
+// .c_cflag = B38400 | CS8 | CREAD;
+// .c_lflag = 0;
+// .c_cc = /* same as DefaultTermios */
+// }
+//
+// (From drivers/tty/pty.c:unix98_pty_init)
+//
+// All termios control ioctls on the master actually redirect to the slave
+// (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the
+// master termios.
+//
+// Verify this by setting ICRNL (which rewrites input \r to \n) and verify that
+// it has no effect on the master.
+TEST_F(PtyTest, MasterTermiosUnchangable) {
+ struct kernel_termios master_termios = {};
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds());
+  master_termios.c_iflag |= ICRNL;
+ EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds());
+
+ char c = '\r';
+ ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ ExpectReadable(master_, 1, &c);
+ EXPECT_EQ(c, '\r'); // ICRNL had no effect!
+
+ ExpectFinished(master_);
+}
+
+// ICRNL rewrites input \r to \n.
+TEST_F(PtyTest, TermiosICRNL) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_iflag |= ICRNL;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ char c = '\r';
+ ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ ExpectReadable(slave_, 1, &c);
+ EXPECT_EQ(c, '\n');
+
+ ExpectFinished(slave_);
+}
+
+// ONLCR rewrites output \n to \r\n.
+TEST_F(PtyTest, TermiosONLCR) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_oflag |= ONLCR;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ char c = '\n';
+ ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ // Extra byte for NUL for EXPECT_STREQ.
+ char buf[3] = {};
+ ExpectReadable(master_, 2, buf);
+ EXPECT_STREQ(buf, "\r\n");
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, TermiosIGNCR) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_iflag |= IGNCR;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ char c = '\r';
+ ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ // Nothing to read.
+ ASSERT_THAT(PollAndReadFd(slave_.get(), &c, 1, kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+}
+
+// Test that we can successfully poll for readable data from the slave.
+TEST_F(PtyTest, TermiosPollSlave) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_iflag |= IGNCR;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ absl::Notification notify;
+ int sfd = slave_.get();
+ ScopedThread th([sfd, &notify]() {
+ notify.Notify();
+
+ // Poll on the reader fd with POLLIN event.
+ struct pollfd poll_fd = {sfd, POLLIN, 0};
+ EXPECT_THAT(
+ RetryEINTR(poll)(&poll_fd, 1, absl::ToInt64Milliseconds(kTimeout)),
+ SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLIN event.
+ EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN);
+ });
+
+ notify.WaitForNotification();
+ // Sleep ensures that poll begins waiting before we write to the FD.
+ absl::SleepFor(absl::Seconds(1));
+
+ char s[] = "foo\n";
+ ASSERT_THAT(WriteFd(master_.get(), s, strlen(s) + 1), SyscallSucceeds());
+}
+
+// Test that we can successfully poll for readable data from the master.
+TEST_F(PtyTest, TermiosPollMaster) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_iflag |= IGNCR;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(master_.get(), TCSETS, &t), SyscallSucceeds());
+
+ absl::Notification notify;
+ int mfd = master_.get();
+ ScopedThread th([mfd, &notify]() {
+ notify.Notify();
+
+ // Poll on the reader fd with POLLIN event.
+ struct pollfd poll_fd = {mfd, POLLIN, 0};
+ EXPECT_THAT(
+ RetryEINTR(poll)(&poll_fd, 1, absl::ToInt64Milliseconds(kTimeout)),
+ SyscallSucceedsWithValue(1));
+
+ // Should trigger POLLIN event.
+ EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN);
+ });
+
+ notify.WaitForNotification();
+ // Sleep ensures that poll begins waiting before we write to the FD.
+ absl::SleepFor(absl::Seconds(1));
+
+ char s[] = "foo\n";
+ ASSERT_THAT(WriteFd(slave_.get(), s, strlen(s) + 1), SyscallSucceeds());
+}
+
+TEST_F(PtyTest, TermiosINLCR) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_iflag |= INLCR;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ char c = '\n';
+ ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ ExpectReadable(slave_, 1, &c);
+ EXPECT_EQ(c, '\r');
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, TermiosONOCR) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_oflag |= ONOCR;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ // The terminal is at column 0, so there should be no CR to read.
+ char c = '\r';
+ ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ // Nothing to read.
+ ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+
+ // This time the column is greater than 0, so we should be able to read the CR
+ // out of the other end.
+ constexpr char kInput[] = "foo\r";
+ constexpr int kInputSize = sizeof(kInput) - 1;
+ ASSERT_THAT(WriteFd(slave_.get(), kInput, kInputSize),
+ SyscallSucceedsWithValue(kInputSize));
+
+ char buf[kInputSize] = {};
+ ExpectReadable(master_, kInputSize, buf);
+
+ EXPECT_EQ(memcmp(buf, kInput, kInputSize), 0);
+
+ ExpectFinished(master_);
+
+ // Terminal should be at column 0 again, so no CR can be read.
+ ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ // Nothing to read.
+ ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+}
+
+TEST_F(PtyTest, TermiosOCRNL) {
+ struct kernel_termios t = DefaultTermios();
+ t.c_oflag |= OCRNL;
+ t.c_lflag &= ~ICANON; // for byte-by-byte reading.
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+
+ // The terminal is at column 0, so there should be no CR to read.
+ char c = '\r';
+ ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1));
+
+ ExpectReadable(master_, 1, &c);
+ EXPECT_EQ(c, '\n');
+
+ ExpectFinished(master_);
+}
+
+// Tests that VEOL is disabled when we start, and that we can set it to enable
+// it.
+TEST_F(PtyTest, VEOLTermination) {
+ // Write a few bytes ending with '\0', and confirm that we can't read.
+ constexpr char kInput[] = "hello";
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)),
+ SyscallSucceedsWithValue(sizeof(kInput)));
+ char buf[sizeof(kInput)] = {};
+ ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+
+ // Set the EOL character to '=' and write it.
+ constexpr char delim = '=';
+ struct kernel_termios t = DefaultTermios();
+ t.c_cc[VEOL] = delim;
+ ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds());
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+
+ // Now we can read, as sending EOL caused the line to become available.
+ ExpectReadable(slave_, sizeof(kInput), buf);
+ EXPECT_EQ(memcmp(buf, kInput, sizeof(kInput)), 0);
+
+ ExpectReadable(slave_, 1, buf);
+ EXPECT_EQ(buf[0], '=');
+
+ ExpectFinished(slave_);
+}
+
+// Tests that we can write more than the 4096 character limit, then a
+// terminating character, then read out just the first 4095 bytes plus the
+// terminator.
+TEST_F(PtyTest, CanonBigWrite) {
+ constexpr int kWriteLen = kMaxLineSize + 4;
+ char input[kWriteLen];
+ memset(input, 'M', kWriteLen - 1);
+ input[kWriteLen - 1] = '\n';
+ ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen),
+ SyscallSucceedsWithValue(kWriteLen));
+
+ // We can read the line.
+ char buf[kMaxLineSize] = {};
+ ExpectReadable(slave_, kMaxLineSize, buf);
+
+ ExpectFinished(slave_);
+}
+
+// Tests that data written in canonical mode can be read immediately once
+// switched to noncanonical mode.
+TEST_F(PtyTest, SwitchCanonToNoncanon) {
+ // Write a few bytes without a terminating character, switch to noncanonical
+ // mode, and read them.
+ constexpr char kInput[] = "hello";
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)),
+ SyscallSucceedsWithValue(sizeof(kInput)));
+
+ // Nothing available yet.
+ char buf[sizeof(kInput)] = {};
+ ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+
+ DisableCanonical();
+
+ ExpectReadable(slave_, sizeof(kInput), buf);
+ EXPECT_STREQ(buf, kInput);
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, SwitchCanonToNonCanonNewline) {
+ // Write a few bytes with a terminating character.
+ constexpr char kInput[] = "hello\n";
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)),
+ SyscallSucceedsWithValue(sizeof(kInput)));
+
+ DisableCanonical();
+
+ // We can read the line.
+ char buf[sizeof(kInput)] = {};
+ ExpectReadable(slave_, sizeof(kInput), buf);
+ EXPECT_STREQ(buf, kInput);
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) {
+ DisableCanonical();
+
+ // Write more than the maximum line size, then write a delimiter.
+ constexpr int kWriteLen = 4100;
+ char input[kWriteLen];
+ memset(input, 'M', kWriteLen);
+ ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen),
+ SyscallSucceedsWithValue(kWriteLen));
+ // Wait for the input queue to fill.
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1));
+ constexpr char delim = '\n';
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+
+ EnableCanonical();
+
+ // We can read the line.
+ char buf[kMaxLineSize] = {};
+ ExpectReadable(slave_, kMaxLineSize - 1, buf);
+
+ // We can also read the remaining characters.
+ ExpectReadable(slave_, 6, buf);
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) {
+ DisableCanonical();
+
+  // Write a few bytes without a terminating character, switch to canonical
+  // mode, and read them.
+ constexpr char kInput[] = "hello";
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput) - 1));
+
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput) - 1));
+ EnableCanonical();
+
+ // We can read the line.
+ char buf[sizeof(kInput)] = {};
+ ExpectReadable(slave_, sizeof(kInput) - 1, buf);
+ EXPECT_STREQ(buf, kInput);
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) {
+ DisableCanonical();
+
+  // Write more than the maximum line size without a terminating character,
+  // switch to canonical mode, and read what was kept.
+ constexpr int kWriteLen = 4100;
+ char input[kWriteLen];
+ memset(input, 'M', kWriteLen);
+ ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen),
+ SyscallSucceedsWithValue(kWriteLen));
+
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1));
+ EnableCanonical();
+
+ // We can read the line.
+ char buf[kMaxLineSize] = {};
+ ExpectReadable(slave_, kMaxLineSize - 1, buf);
+
+ ExpectFinished(slave_);
+}
+
+// Tests that we can write over the 4095 noncanonical limit, then read out
+// everything.
+TEST_F(PtyTest, NoncanonBigWrite) {
+ DisableCanonical();
+
+ // Write well over the 4095 internal buffer limit.
+ constexpr char kInput = 'M';
+ constexpr int kInputSize = kMaxLineSize * 2;
+ for (int i = 0; i < kInputSize; i++) {
+ // This makes too many syscalls for save/restore.
+ const DisableSave ds;
+ ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)),
+ SyscallSucceedsWithValue(sizeof(kInput)));
+ }
+
+  // We should be able to read out everything. Wait until the input queue has
+  // filled so that Linux has a chance to move data from the master to the
+  // slave.
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1));
+ for (int i = 0; i < kInputSize; i++) {
+ // This makes too many syscalls for save/restore.
+ const DisableSave ds;
+ char c;
+ ExpectReadable(slave_, 1, &c);
+ ASSERT_EQ(c, kInput);
+ }
+
+ ExpectFinished(slave_);
+}
+
+// ICANON doesn't make input available until a line delimiter is typed.
+//
+// Test newline.
+TEST_F(PtyTest, TermiosICANONNewline) {
+ char input[3] = {'a', 'b', 'c'};
+ ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)),
+ SyscallSucceedsWithValue(sizeof(input)));
+
+ // Extra bytes for newline (written later) and NUL for EXPECT_STREQ.
+ char buf[5] = {};
+
+ // Nothing available yet.
+ ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+
+ char delim = '\n';
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+
+ // Now it is available.
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(input) + 1));
+ ExpectReadable(slave_, sizeof(input) + 1, buf);
+ EXPECT_STREQ(buf, "abc\n");
+
+ ExpectFinished(slave_);
+}
+
+// ICANON doesn't make input available until a line delimiter is typed.
+//
+// Test EOF (^D).
+TEST_F(PtyTest, TermiosICANONEOF) {
+ char input[3] = {'a', 'b', 'c'};
+ ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)),
+ SyscallSucceedsWithValue(sizeof(input)));
+
+ // Extra byte for NUL for EXPECT_STREQ.
+ char buf[4] = {};
+
+ // Nothing available yet.
+ ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout),
+ PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
+ char delim = ControlCharacter('D');
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+
+ // Now it is available. Note that ^D is not included.
+ ExpectReadable(slave_, sizeof(input), buf);
+ EXPECT_STREQ(buf, "abc");
+
+ ExpectFinished(slave_);
+}
+
+// ICANON limits us to 4096 bytes including a terminating character. Anything
+// after the 4095th character is discarded (although still processed for
+// signals and echoing).
+TEST_F(PtyTest, CanonDiscard) {
+ constexpr char kInput = 'M';
+ constexpr int kInputSize = 4100;
+ constexpr int kIter = 3;
+
+  // Several times, write more than the 4096-character maximum, then a newline.
+ constexpr char delim = '\n';
+ for (int i = 0; i < kIter; i++) {
+ // This makes too many syscalls for save/restore.
+ const DisableSave ds;
+ for (int i = 0; i < kInputSize; i++) {
+ ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)),
+ SyscallSucceedsWithValue(sizeof(kInput)));
+ }
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+ }
+
+ // There should be multiple truncated lines available to read.
+ for (int i = 0; i < kIter; i++) {
+ char buf[kInputSize] = {};
+ ExpectReadable(slave_, kMaxLineSize, buf);
+ EXPECT_EQ(buf[kMaxLineSize - 1], delim);
+ EXPECT_EQ(buf[kMaxLineSize - 2], kInput);
+ }
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, CanonMultiline) {
+ constexpr char kInput1[] = "GO\n";
+ constexpr char kInput2[] = "BLUE\n";
+
+ // Write both lines.
+ ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput1) - 1));
+ ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput2) - 1));
+
+ // Get the first line.
+ char line1[8] = {};
+ ExpectReadable(slave_, sizeof(kInput1) - 1, line1);
+ EXPECT_STREQ(line1, kInput1);
+
+ // Get the second line.
+ char line2[8] = {};
+ ExpectReadable(slave_, sizeof(kInput2) - 1, line2);
+ EXPECT_STREQ(line2, kInput2);
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) {
+ DisableCanonical();
+
+ constexpr char kInput1[] = "GO\n";
+ constexpr char kInput2[] = "BLUE\n";
+ constexpr char kExpected[] = "GO\nBLUE\n";
+
+ // Write both lines.
+ ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput1) - 1));
+ ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput2) - 1));
+
+ ASSERT_NO_ERRNO(
+ WaitUntilReceived(slave_.get(), sizeof(kInput1) + sizeof(kInput2) - 2));
+ EnableCanonical();
+
+ // Get all together as one line.
+ char line[9] = {};
+ ExpectReadable(slave_, 8, line);
+ EXPECT_STREQ(line, kExpected);
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, SwitchTwiceMultiline) {
+ std::string kInputs[] = {"GO\n", "BLUE\n", "!"};
+ std::string kExpected = "GO\nBLUE\n!";
+
+ // Write each line.
+ for (const std::string& input : kInputs) {
+ ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()),
+ SyscallSucceedsWithValue(input.size()));
+ }
+
+ DisableCanonical();
+ // All written characters have to make it into the input queue before
+ // canonical mode is re-enabled. If the final '!' character hasn't been
+ // enqueued before canonical mode is re-enabled, it won't be readable.
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kExpected.size()));
+ EnableCanonical();
+
+ // Get all together as one line.
+ char line[10] = {};
+ ExpectReadable(slave_, 9, line);
+ EXPECT_STREQ(line, kExpected.c_str());
+
+ ExpectFinished(slave_);
+}
+
+TEST_F(PtyTest, QueueSize) {
+ // Write the line.
+ constexpr char kInput1[] = "GO\n";
+ ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1),
+ SyscallSucceedsWithValue(sizeof(kInput1) - 1));
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1));
+
+ // Ensure that writing more (beyond what is readable) does not impact the
+ // readable size.
+ char input[kMaxLineSize];
+ memset(input, 'M', kMaxLineSize);
+ ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize),
+ SyscallSucceedsWithValue(kMaxLineSize));
+ int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE(
+ WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1));
+ EXPECT_EQ(inputBufSize, sizeof(kInput1) - 1);
+}
+
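+// N.B. in canonical mode, FIONREAD counts only bytes that are actually
+// readable, i.e. bytes in completed lines; input buffered beyond the last
+// line terminator is not included, which is what QueueSize relies on.
+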
+TEST_F(PtyTest, PartialBadBuffer) {
+ // Allocate 2 pages.
+ void* addr = mmap(nullptr, 2 * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ ASSERT_NE(addr, MAP_FAILED);
+ char* buf = reinterpret_cast<char*>(addr);
+
+ // Guard the 2nd page for our read to run into.
+ ASSERT_THAT(
+ mprotect(reinterpret_cast<void*>(buf + kPageSize), kPageSize, PROT_NONE),
+ SyscallSucceeds());
+
+ // Leave only one free byte in the buffer.
+ char* bad_buffer = buf + kPageSize - 1;
+
+ // Write to the master.
+ constexpr char kBuf[] = "hello\n";
+ constexpr size_t size = sizeof(kBuf) - 1;
+ EXPECT_THAT(WriteFd(master_.get(), kBuf, size),
+ SyscallSucceedsWithValue(size));
+
+ // Read from the slave into bad_buffer.
+ ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), size));
+ EXPECT_THAT(ReadFd(slave_.get(), bad_buffer, size),
+ SyscallFailsWithErrno(EFAULT));
+
+ EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr;
+}
+
+TEST_F(PtyTest, SimpleEcho) {
+ constexpr char kInput[] = "Mr. Eko";
+ EXPECT_THAT(WriteFd(master_.get(), kInput, strlen(kInput)),
+ SyscallSucceedsWithValue(strlen(kInput)));
+
+ char buf[100] = {};
+ ExpectReadable(master_, strlen(kInput), buf);
+
+ EXPECT_STREQ(buf, kInput);
+ ExpectFinished(master_);
+}
+
+TEST_F(PtyTest, GetWindowSize) {
+ struct winsize ws;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &ws), SyscallSucceeds());
+ EXPECT_EQ(ws.ws_row, 0);
+ EXPECT_EQ(ws.ws_col, 0);
+}
+
+TEST_F(PtyTest, SetSlaveWindowSize) {
+ constexpr uint16_t kRows = 343;
+ constexpr uint16_t kCols = 2401;
+ struct winsize ws = {.ws_row = kRows, .ws_col = kCols};
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSWINSZ, &ws), SyscallSucceeds());
+
+ struct winsize retrieved_ws = {};
+ ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws),
+ SyscallSucceeds());
+ EXPECT_EQ(retrieved_ws.ws_row, kRows);
+ EXPECT_EQ(retrieved_ws.ws_col, kCols);
+}
+
+TEST_F(PtyTest, SetMasterWindowSize) {
+ constexpr uint16_t kRows = 343;
+ constexpr uint16_t kCols = 2401;
+ struct winsize ws = {.ws_row = kRows, .ws_col = kCols};
+ ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds());
+
+ struct winsize retrieved_ws = {};
+ ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &retrieved_ws),
+ SyscallSucceeds());
+ EXPECT_EQ(retrieved_ws.ws_row, kRows);
+ EXPECT_EQ(retrieved_ws.ws_col, kCols);
+}
+
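+// N.B. the window size is a property of the terminal pair rather than of
+// either end, which is why a size set through one fd is visible through the
+// other in both tests above.
+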
+class JobControlTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_));
+
+ // Make this a session leader, which also drops the controlling terminal.
+ // In the gVisor test environment, this test will be run as the session
+ // leader already (as the sentry init process).
+ if (!IsRunningOnGvisor()) {
+ ASSERT_THAT(setsid(), SyscallSucceeds());
+ }
+ }
+
+ // Master and slave ends of the PTY. Non-blocking.
+ FileDescriptor master_;
+ FileDescriptor slave_;
+};
+
+TEST_F(JobControlTest, SetTTYMaster) {
+ ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetTTYNonLeader) {
+ // Fork a process that won't be the session leader.
+ pid_t child = fork();
+ if (!child) {
+ // We shouldn't be able to set the terminal.
+ TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+TEST_F(JobControlTest, SetTTYBadArg) {
+ // Despite the man page saying arg should be 0 here, Linux doesn't actually
+ // check.
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetTTYDifferentSession) {
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Fork, join a new session, and try to steal the parent's controlling
+ // terminal, which should fail.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(setsid() >= 0);
+ // We shouldn't be able to steal the terminal.
+ TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+TEST_F(JobControlTest, ReleaseTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Make sure we're ignoring SIGHUP, which will be sent to this process once we
+  // disconnect the TTY.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
+ sigemptyset(&sa.sa_mask);
+ struct sigaction old_sa;
+ EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds());
+ EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds());
+ EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, ReleaseUnsetTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(JobControlTest, ReleaseWrongTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(JobControlTest, ReleaseTTYNonLeader) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+TEST_F(JobControlTest, ReleaseTTYDifferentSession) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ pid_t child = fork();
+ if (!child) {
+ // Join a new session, then try to disconnect.
+ TEST_PCHECK(setsid() >= 0);
+ TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+// Used by the child process spawned in ReleaseTTYSignals to track received
+// signals.
+static int received;
+
+void sig_handler(int signum) { received |= signum; }
+
+// When the session leader releases its controlling terminal, the foreground
+// process group gets SIGHUP, then SIGCONT. This test:
+// - Spawns 2 threads
+// - Has thread 1 return 0 if it gets both SIGHUP and SIGCONT
+// - Has thread 2 leave the foreground process group, and return non-zero if it
+// receives any signals.
+// - Has the parent thread release its controlling terminal
+// - Checks that thread 1 got both signals
+// - Checks that thread 2 didn't get any signals.
+TEST_F(JobControlTest, ReleaseTTYSignals) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ received = 0;
+ struct sigaction sa = {};
+ sa.sa_handler = sig_handler;
+ sa.sa_flags = 0;
+ sigemptyset(&sa.sa_mask);
+ sigaddset(&sa.sa_mask, SIGHUP);
+ sigaddset(&sa.sa_mask, SIGCONT);
+ sigprocmask(SIG_BLOCK, &sa.sa_mask, NULL);
+
+ pid_t same_pgrp_child = fork();
+ if (!same_pgrp_child) {
+ // The child will wait for SIGHUP and SIGCONT, then return 0. It begins with
+ // SIGHUP and SIGCONT blocked. We install signal handlers for those signals,
+ // then use sigsuspend to wait for those specific signals.
+ TEST_PCHECK(!sigaction(SIGHUP, &sa, NULL));
+ TEST_PCHECK(!sigaction(SIGCONT, &sa, NULL));
+ sigset_t mask;
+ sigfillset(&mask);
+ sigdelset(&mask, SIGHUP);
+ sigdelset(&mask, SIGCONT);
+ while (received != (SIGHUP | SIGCONT)) {
+ sigsuspend(&mask);
+ }
+ _exit(0);
+ }
+
+ // We don't want to block these anymore.
+ sigprocmask(SIG_UNBLOCK, &sa.sa_mask, NULL);
+
+ // This child will return non-zero if either SIGHUP or SIGCONT are received.
+ pid_t diff_pgrp_child = fork();
+ if (!diff_pgrp_child) {
+ TEST_PCHECK(!setpgid(0, 0));
+ TEST_PCHECK(pause());
+ _exit(1);
+ }
+
+ EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds());
+
+ // Make sure we're ignoring SIGHUP, which will be sent to this process once we
+  // disconnect the TTY.
+ struct sigaction sighup_sa = {};
+ sighup_sa.sa_handler = SIG_IGN;
+ sighup_sa.sa_flags = 0;
+ sigemptyset(&sighup_sa.sa_mask);
+ struct sigaction old_sa;
+ EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds());
+
+ // Release the controlling terminal, sending SIGHUP and SIGCONT to all other
+ // processes in this process group.
+ EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds());
+
+ EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds());
+
+ // The child in the same process group will get signaled.
+ int wstatus;
+ EXPECT_THAT(waitpid(same_pgrp_child, &wstatus, 0),
+ SyscallSucceedsWithValue(same_pgrp_child));
+ EXPECT_EQ(wstatus, 0);
+
+ // The other child will not get signaled.
+ EXPECT_THAT(waitpid(diff_pgrp_child, &wstatus, WNOHANG),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(kill(diff_pgrp_child, SIGKILL), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, GetForegroundProcessGroup) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+ pid_t foreground_pgid;
+ pid_t pid;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid),
+ SyscallSucceeds());
+ ASSERT_THAT(pid = getpid(), SyscallSucceeds());
+
+ ASSERT_EQ(foreground_pgid, pid);
+}
+
+TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) {
+ // At this point there's no controlling terminal, so TIOCGPGRP should fail.
+ pid_t foreground_pgid;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid),
+ SyscallFailsWithErrno(ENOTTY));
+}
+
+// This test:
+// - sets itself as the foreground process group
+// - creates a child process in a new process group
+// - sets that child as the foreground process group
+// - kills its child and sets itself as the foreground process group.
+TEST_F(JobControlTest, SetForegroundProcessGroup) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
+ sigemptyset(&sa.sa_mask);
+ sigaction(SIGTTOU, &sa, NULL);
+
+ // Set ourself as the foreground process group.
+ ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds());
+
+ // Create a new process that just waits to be signaled.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(!pause());
+ // We should never reach this.
+ _exit(1);
+ }
+
+ // Make the child its own process group, then make it the controlling process
+ // group of the terminal.
+ ASSERT_THAT(setpgid(child, child), SyscallSucceeds());
+ ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds());
+
+  // Sanity check: the child is still in our session.
+ ASSERT_EQ(getsid(0), getsid(child));
+
+ // Signal the child, wait for it to exit, then retake the terminal.
+ ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds());
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_TRUE(WIFSIGNALED(wstatus));
+ ASSERT_EQ(WTERMSIG(wstatus), SIGTERM);
+
+ // Set ourself as the foreground process.
+ pid_t pgid;
+ ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds());
+ ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds());
+}
+
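+// N.B. glibc implements tcsetpgrp(fd, pgid) as ioctl(fd, TIOCSPGRP, &pgid),
+// so the tests below that issue the ioctl directly exercise the same kernel
+// path as the tcsetpgrp calls above.
+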
+TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) {
+ pid_t pid = getpid();
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid),
+ SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ pid_t pid = -1;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Create a new process, put it in a new process group, make that group the
+ // foreground process group, then have the process wait.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(!setpgid(0, 0));
+ _exit(0);
+ }
+
+ // Wait for the child to exit.
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ // The child's process group doesn't exist anymore - this should fail.
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ int sync_setsid[2];
+ int sync_exit[2];
+ ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds());
+ ASSERT_THAT(pipe(sync_exit), SyscallSucceeds());
+
+ // Create a new process and put it in a new session.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(setsid() >= 0);
+ // Tell the parent we're in a new session.
+ char c = 'c';
+ TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1);
+ TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1);
+ _exit(0);
+ }
+
+ // Wait for the child to tell us it's in a new session.
+ char c = 'c';
+ ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1));
+
+  // Child is in a new session, so we can't make it the foreground process
+  // group.
+ EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
+ SyscallFailsWithErrno(EPERM));
+
+ EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1));
+
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(wstatus));
+ EXPECT_EQ(WEXITSTATUS(wstatus), 0);
+}
+
+// Verify that we don't hang when creating a new session from an orphaned
+// process group (b/139968068). Calling setsid() creates an orphaned process
+// group, as a process group that contains its session's leader is an orphan.
+//
+// We create 2 sessions in this test. The init process in gVisor is considered
+// not to be an orphan (see sessions.go), so we have to create a session from
+// which to create a session. The latter session is being created from an
+// orphaned process group.
+TEST_F(JobControlTest, OrphanRegression) {
+ pid_t session_2_leader = fork();
+ if (!session_2_leader) {
+ TEST_PCHECK(setsid() >= 0);
+
+ pid_t session_3_leader = fork();
+ if (!session_3_leader) {
+ TEST_PCHECK(setsid() >= 0);
+
+ _exit(0);
+ }
+
+ int wstatus;
+ TEST_PCHECK(waitpid(session_3_leader, &wstatus, 0) == session_3_leader);
+ TEST_PCHECK(wstatus == 0);
+
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(session_2_leader, &wstatus, 0),
+ SyscallSucceedsWithValue(session_2_leader));
+ ASSERT_EQ(wstatus, 0);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc
new file mode 100644
index 000000000..1d7dbefdb
--- /dev/null
+++ b/test/syscalls/linux/pty_root.cc
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/ioctl.h>
+#include <sys/syscall.h>
+#include <sys/wait.h>
+#include <termios.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/pty_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// StealTTY tests whether privileged processes can steal controlling terminals.
+// If the stealing process has CAP_SYS_ADMIN in the root user namespace, the
+// test ensures that stealing works. If it has non-root CAP_SYS_ADMIN, it
+// ensures stealing fails.
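+//
+// N.B. an arg of 1 to TIOCSCTTY asks the kernel to steal a terminal that is
+// already the controlling terminal of another session; the kernel allows
+// this only for callers with CAP_SYS_ADMIN in the initial user namespace.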
+TEST(JobControlRootTest, StealTTY) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ bool true_root = true;
+ if (!IsRunningOnGvisor()) {
+ // If running in Linux, we may only have CAP_SYS_ADMIN in a non-root user
+ // namespace (i.e. we are not truly root). We use init_module as a proxy for
+ // whether we are true root, as it returns EPERM immediately.
+ ASSERT_THAT(syscall(SYS_init_module, nullptr, 0, nullptr), SyscallFails());
+ true_root = errno != EPERM;
+
+ // Make this a session leader, which also drops the controlling terminal.
+ // In the gVisor test environment, this test will be run as the session
+ // leader already (as the sentry init process).
+ ASSERT_THAT(setsid(), SyscallSucceeds());
+ }
+
+ FileDescriptor master =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master));
+
+ // Make slave the controlling terminal.
+ ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Fork, join a new session, and try to steal the parent's controlling
+ // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg
+ // of 1.
+ pid_t child = fork();
+ if (!child) {
+ ASSERT_THAT(setsid(), SyscallSucceeds());
+ // We shouldn't be able to steal the terminal with the wrong arg value.
+ TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0));
+ // We should be able to steal it if we are true root.
+ TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc
new file mode 100644
index 000000000..e69794910
--- /dev/null
+++ b/test/syscalls/linux/pwrite64.cc
@@ -0,0 +1,83 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/unistd.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
+class Pwrite64 : public ::testing::Test {
+ void SetUp() override {
+ name_ = NewTempAbsPath();
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_CREAT, 0644), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ void TearDown() override { unlink(name_.c_str()); }
+
+ public:
+ std::string name_;
+};
+
+TEST_F(Pwrite64, AppendOnly) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds());
+ constexpr int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize);
+ std::fill(buf.begin(), buf.end(), 'a');
+ EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0),
+ SyscallSucceedsWithValue(buf.size()));
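+  // pwrite must not advance the file offset, even on an fd opened O_APPEND.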
+ EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(Pwrite64, InvalidArgs) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds());
+ constexpr int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize);
+ std::fill(buf.begin(), buf.end(), 'a');
+ EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), -1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(Pwrite64, Overflow) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds());
+ constexpr int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize);
+ std::fill(buf.begin(), buf.end(), 'a');
+ EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc
new file mode 100644
index 000000000..63b686c62
--- /dev/null
+++ b/test/syscalls/linux/pwritev2.cc
@@ -0,0 +1,307 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef SYS_pwritev2
+#if defined(__x86_64__)
+#define SYS_pwritev2 328
+#elif defined(__aarch64__)
+#define SYS_pwritev2 287
+#else
+#error "Unknown architecture"
+#endif
+#endif  // SYS_pwritev2
+
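+// Define the RWF_* flags ourselves in case the system headers are too old to
+// provide them; the values match the kernel's UAPI definitions.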
+#ifndef RWF_HIPRI
+#define RWF_HIPRI 0x1
+#endif // RWF_HIPRI
+
+#ifndef RWF_DSYNC
+#define RWF_DSYNC 0x2
+#endif // RWF_DSYNC
+
+#ifndef RWF_SYNC
+#define RWF_SYNC 0x4
+#endif // RWF_SYNC
+
+constexpr int kBufSize = 1024;
+
+void SetContent(std::vector<char>& content) {
+ for (uint i = 0; i < content.size(); i++) {
+ content[i] = static_cast<char>((i % 10) + '0');
+ }
+}
+
+ssize_t pwritev2(unsigned long fd, const struct iovec* iov,
+ unsigned long iovcnt, off_t offset, unsigned long flags) {
+  // The raw pwritev2 syscall takes the offset as a low/high pair (pos_l and
+  // pos_h; see syscall(2)), so we pass 0 for the high half, which leaves the
+  // flags in the sixth argument slot on native.
+ return syscall(SYS_pwritev2, fd, iov, iovcnt, offset, 0, flags);
+}
+
+// This test is the base case where we call pwritev2 (zero offset, no flags).
+TEST(Writev2Test, BaseCall) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ std::vector<char> content(kBufSize);
+ SetContent(content);
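+  // Split the content across two iovecs to exercise gather I/O in a single
+  // call.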
+ struct iovec iov[2];
+ iov[0].iov_base = content.data();
+ iov[0].iov_len = content.size() / 2;
+ iov[1].iov_base = static_cast<char*>(iov[0].iov_base) + (content.size() / 2);
+ iov[1].iov_len = content.size() / 2;
+
+ ASSERT_THAT(pwritev2(fd.get(), iov, /*iovcnt=*/2,
+ /*offset=*/0, /*flags=*/0),
+ SyscallSucceedsWithValue(kBufSize));
+
+ std::vector<char> buf(kBufSize);
+ EXPECT_THAT(read(fd.get(), buf.data(), kBufSize),
+ SyscallSucceedsWithValue(kBufSize));
+
+ EXPECT_EQ(content, buf);
+}
+
+// This test is where we call pwritev2 with a positive offset and no flags.
+TEST(Pwritev2Test, ValidPositiveOffset) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ std::string prefix(kBufSize, '0');
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), prefix, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ std::vector<char> content(kBufSize);
+ SetContent(content);
+ struct iovec iov;
+ iov.iov_base = content.data();
+ iov.iov_len = content.size();
+
+ ASSERT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
+ /*offset=*/prefix.size(), /*flags=*/0),
+ SyscallSucceedsWithValue(content.size()));
+
+ std::vector<char> buf(prefix.size() + content.size());
+ EXPECT_THAT(read(fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ std::vector<char> want(prefix.begin(), prefix.end());
+ want.insert(want.end(), content.begin(), content.end());
+ EXPECT_EQ(want, buf);
+}
+
+// This test calls pwritev2 with -1 as the offset. The write should use the
+// current file offset, so the test advances the file offset prior to calling
+// pwritev2.
+TEST(Pwritev2Test, NegativeOneOffset) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const std::string prefix = "00";
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), prefix.data(), TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ ASSERT_THAT(lseek(fd.get(), prefix.size(), SEEK_SET),
+ SyscallSucceedsWithValue(prefix.size()));
+
+ std::vector<char> content(kBufSize);
+ SetContent(content);
+ struct iovec iov;
+ iov.iov_base = content.data();
+ iov.iov_len = content.size();
+
+ ASSERT_THAT(pwritev2(fd.get(), &iov, /*iovcnt*/ 1,
+ /*offset=*/static_cast<off_t>(-1), /*flags=*/0),
+ SyscallSucceedsWithValue(content.size()));
+
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(prefix.size() + content.size()));
+
+ std::vector<char> buf(prefix.size() + content.size());
+ EXPECT_THAT(pread(fd.get(), buf.data(), buf.size(), /*offset=*/0),
+ SyscallSucceedsWithValue(buf.size()));
+
+ std::vector<char> want(prefix.begin(), prefix.end());
+ want.insert(want.end(), content.begin(), content.end());
+ EXPECT_EQ(want, buf);
+}
+
+// pwritev2 requires that, if the RWF_HIPRI flag is passed, the fd be opened
+// with O_DIRECT. This test makes a well-formed call with the RWF_HIPRI flag.
+TEST(Pwritev2Test, CallWithRWF_HIPRI) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ std::vector<char> content(kBufSize);
+ SetContent(content);
+ struct iovec iov;
+ iov.iov_base = content.data();
+ iov.iov_len = content.size();
+
+ EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
+ /*offset=*/0, /*flags=*/RWF_HIPRI),
+ SyscallSucceedsWithValue(kBufSize));
+
+ std::vector<char> buf(content.size());
+ EXPECT_THAT(read(fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ EXPECT_EQ(buf, content);
+}
+
+// This test calls pwritev2 with a bad file descriptor.
+TEST(Writev2Test, BadFile) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+ ASSERT_THAT(pwritev2(/*fd=*/-1, /*iov=*/nullptr, /*iovcnt=*/0,
+ /*offset=*/0, /*flags=*/0),
+ SyscallFailsWithErrno(EBADF));
+}
+
+// This test calls pwritev2 with an invalid (negative) offset.
+TEST(Pwritev2Test, InvalidOffset) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ char buf[16];
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+
+ EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
+ /*offset=*/static_cast<off_t>(-8), /*flags=*/0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
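+// An offset of -1 tells pwritev2 to write at the current file offset, which
+// is permitted even on unseekable files such as pipes.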
+TEST(Pwritev2Test, UnseekableFileValid) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ int pipe_fds[2];
+
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ std::vector<char> content(32, '0');
+ SetContent(content);
+ struct iovec iov;
+ iov.iov_base = content.data();
+ iov.iov_len = content.size();
+
+ EXPECT_THAT(pwritev2(pipe_fds[1], &iov, /*iovcnt=*/1,
+ /*offset=*/static_cast<off_t>(-1), /*flags=*/0),
+ SyscallSucceedsWithValue(content.size()));
+
+ std::vector<char> buf(content.size());
+ EXPECT_THAT(read(pipe_fds[0], buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ EXPECT_EQ(content, buf);
+
+ EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+}
+
+// Calling pwritev2 with a non-negative offset behaves like pwritev, and
+// pwritev on an unseekable file is not allowed. A pipe serves as the
+// unseekable file here.
+TEST(Pwritev2Test, UnseekableFileInvalid) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ int pipe_fds[2];
+ char buf[16];
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+
+ EXPECT_THAT(pwritev2(pipe_fds[1], &iov, /*iovcnt=*/1,
+ /*offset=*/2, /*flags=*/0),
+ SyscallFailsWithErrno(ESPIPE));
+
+ EXPECT_THAT(close(pipe_fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
+}
+
+TEST(Pwritev2Test, ReadOnlyFile) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ char buf[16];
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+
+ EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
+ /*offset=*/0, /*flags=*/0),
+ SyscallFailsWithErrno(EBADF));
+}
+
+// This test calls pwritev2 with an invalid flag.
+TEST(Pwritev2Test, InvalidFlag) {
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR | O_DIRECT));
+
+ char buf[16];
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+
+ EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
+ /*offset=*/0, /*flags=*/0xF0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
new file mode 100644
index 000000000..05c4ed03f
--- /dev/null
+++ b/test/syscalls/linux/raw_socket.cc
@@ -0,0 +1,819 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/capability.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/ip_icmp.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <algorithm>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Note: in order to run these tests, /proc/sys/net/ipv4/ping_group_range will
+// need to be configured to let the superuser create ping sockets (see icmp(7)).
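+// For example, "sysctl -w net.ipv4.ping_group_range='0 0'" permits gid 0.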
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Fixture for tests parameterized by protocol and address family.
+class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // Sends buf via s_.
+ void SendBuf(const char* buf, int buf_len);
+
+ // Reads from s_ into recv_buf.
+ void ReceiveBuf(char* recv_buf, size_t recv_buf_len);
+
+ void ReceiveBufFrom(int sock, char* recv_buf, size_t recv_buf_len);
+
+ int Protocol() { return std::get<0>(GetParam()); }
+
+ int Family() { return std::get<1>(GetParam()); }
+
+ socklen_t AddrLen() {
+ if (Family() == AF_INET) {
+ return sizeof(sockaddr_in);
+ }
+ return sizeof(sockaddr_in6);
+ }
+
+ int HdrLen() {
+ if (Family() == AF_INET) {
+ return sizeof(struct iphdr);
+ }
+    // IPv6 raw sockets don't include the IP header in received data.
+ return 0;
+ }
+
+ // The socket used for both reading and writing.
+ int s_;
+
+ // The loopback address.
+ struct sockaddr_storage addr_;
+};
+
+void RawSocketTest::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(Family(), SOCK_RAW, Protocol()),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ ASSERT_THAT(s_ = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
+
+ addr_ = {};
+
+ // We don't set ports because raw sockets don't have a notion of ports.
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr_);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ } else {
+ struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr_);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_addr = in6addr_loopback;
+ }
+}
+
+void RawSocketTest::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
+}
+
+// We should be able to create multiple raw sockets for the same protocol.
+// RawSocketTest::SetUp creates the first one, so we only have to create one
+// more here.
+TEST_P(RawSocketTest, MultipleCreation) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int s2;
+ ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
+
+ ASSERT_THAT(close(s2), SyscallSucceeds());
+}
+
+// Test that shutting down an unconnected socket fails.
+TEST_P(RawSocketTest, FailShutdownWithoutConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Shutdown is a no-op for raw sockets (and datagram sockets in general).
+TEST_P(RawSocketTest, ShutdownWriteNoop) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "noop";
+ ASSERT_THAT(RetryEINTR(write)(s_, kBuf, sizeof(kBuf)),
+ SyscallSucceedsWithValue(sizeof(kBuf)));
+}
+
+// Shutdown is a no-op for raw sockets (and datagram sockets in general).
+TEST_P(RawSocketTest, ShutdownReadNoop) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "gdg";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ std::vector<char> c(sizeof(kBuf) + HdrLen());
+ ASSERT_THAT(read(s_, c.data(), c.size()), SyscallSucceedsWithValue(c.size()));
+}
+
+// Test that listen() fails.
+TEST_P(RawSocketTest, FailListen) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that accept() fails.
+TEST_P(RawSocketTest, FailAccept) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr saddr;
+ socklen_t addrlen;
+ ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that getpeername() returns nothing before connect().
+TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr saddr;
+ socklen_t addrlen = sizeof(saddr);
+ ASSERT_THAT(getpeername(s_, &saddr, &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Test that getpeername() returns something after connect().
+TEST_P(RawSocketTest, GetPeerName) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ struct sockaddr saddr;
+ socklen_t addrlen = sizeof(saddr);
+  ASSERT_THAT(getpeername(s_, &saddr, &addrlen), SyscallSucceeds());
+ ASSERT_GT(addrlen, 0);
+}
+
+// Test that the socket is writable immediately.
+TEST_P(RawSocketTest, PollWritableImmediately) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLOUT;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
+}
+
+// Test that the socket isn't readable before receiving anything.
+TEST_P(RawSocketTest, PollNotReadableInitially) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Try to receive data with MSG_DONTWAIT, which returns immediately if there's
+ // nothing to be read.
+ char buf[117];
+ ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Test that the socket becomes readable once something is written to it.
+TEST_P(RawSocketTest, PollTriggeredOnWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Write something so that there's data to be read.
+ // Arbitrary.
+ constexpr char kBuf[] = "JP5";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
+}
+
+// Test that we can connect() to a valid IP (loopback).
+TEST_P(RawSocketTest, ConnectToLoopback) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+}
+
+// Test that calling send() without connect() fails.
+TEST_P(RawSocketTest, SendWithoutConnectFails) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Arbitrary.
+ constexpr char kBuf[] = "Endgame was good";
+ ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0),
+ SyscallFailsWithErrno(EDESTADDRREQ));
+}
+
+// Bind to localhost.
+TEST_P(RawSocketTest, BindToLocalhost) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+}
+
+// Bind to a different address.
+TEST_P(RawSocketTest, BindToInvalid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr_storage bind_addr = addr_;
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&bind_addr);
+ sin->sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
+ } else {
+ struct sockaddr_in6* sin6 =
+ reinterpret_cast<struct sockaddr_in6*>(&bind_addr);
+ memset(&sin6->sin6_addr.s6_addr, 0, sizeof(sin6->sin6_addr.s6_addr));
+ sin6->sin6_addr.s6_addr[0] = 1; // 1: - An address that we can't bind to.
+ }
+ ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ AddrLen()), SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+// Send and receive a packet.
+TEST_P(RawSocketTest, SendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Arbitrary.
+ constexpr char kBuf[] = "TB12";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// We should be able to create multiple raw sockets for the same protocol and
+// receive the same packet on both.
+TEST_P(RawSocketTest, MultipleSocketReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int s2;
+ ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "TB10";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive it on socket 1.
+ std::vector<char> recv_buf1(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1.data(), recv_buf1.size()));
+
+ // Receive it on socket 2.
+ std::vector<char> recv_buf2(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s2, recv_buf2.data(),
+ recv_buf2.size()));
+
+ EXPECT_EQ(memcmp(recv_buf1.data() + HdrLen(),
+ recv_buf2.data() + HdrLen(), sizeof(kBuf)),
+ 0);
+
+ ASSERT_THAT(close(s2), SyscallSucceeds());
+}
+
+// Test that connect sends packets to the right place.
+TEST_P(RawSocketTest, SendAndReceiveViaConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "JH4";
+ ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0),
+ SyscallSucceedsWithValue(sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// Bind to localhost, then send and receive packets.
+TEST_P(RawSocketTest, BindSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "DR16";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// Bind and connect to localhost and send/receive packets.
+TEST_P(RawSocketTest, BindConnectSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "DG88";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// Check that setting SO_RCVBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(RawSocketTest, SetSocketRecvBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum receive buf size by trying to set it to zero.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+  // Linux doubles the value, so use a value that is still smaller than min
+  // even after doubling.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_RCVBUF above max is clamped to the maximum
+// receive buffer size.
+TEST_P(RawSocketTest, SetSocketRecvBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover max buf size by trying to set the largest possible buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored.
+TEST_P(RawSocketTest, SetSocketRecvBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover max buf size by trying to set a really large buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set a zero size receive buffer
+ // size.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+  // TODO(gvisor.dev/issue/2926): Remove when Netstack matches Linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Check that setting SO_SNDBUF below min is clamped to the minimum
+// send buffer size.
+TEST_P(RawSocketTest, SetSocketSendBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+
+  // Linux doubles the value, so use a value that is still smaller than min
+  // even after doubling.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_SNDBUF above max is clamped to the maximum
+// send buffer size.
+TEST_P(RawSocketTest, SetSocketSendBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored.
+TEST_P(RawSocketTest, SetSocketSendBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack
+  // matches Linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Test that receive buffer limits are not enforced when the recv buffer is
+// empty.
+TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ int min = 0;
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Send data of size min and verify that it's received.
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(
+ memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()),
+ 0);
+ }
+
+ {
+    // Send data of size min + 1 and verify that it's received. Both Linux and
+ // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer
+ // is currently empty.
+ std::vector<char> buf(min + 1);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(
+ memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()),
+ 0);
+ }
+}
+
+TEST_P(RawSocketTest, RecvBufLimits) {
+  // The TCP stack generates RSTs for unknown endpoints, which complicates the
+  // test since we would have to deal with those RST packets as well. Testing
+  // UDP alone suffices to exercise the raw socket's buffer limit enforcement.
+ //
+ // We don't use SKIP_IF here because root_test_runner explicitly fails if a
+ // test is skipped.
+ if (Protocol() == IPPROTO_TCP) {
+ return;
+ }
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ int min = 0;
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+  // Now set the limit so that the effective receive buffer size is min * 4.
+  int new_rcv_buf_sz = min * 4;
+  if (!IsRunningOnGvisor()) {
+    // Linux doubles the requested value, so ask for min * 2 to get min * 4.
+    new_rcv_buf_sz = min * 2;
+ }
+
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz,
+ sizeof(new_rcv_buf_sz)),
+ SyscallSucceeds());
+ int rcv_buf_sz = 0;
+ {
+ socklen_t rcv_buf_len = sizeof(rcv_buf_sz);
+ ASSERT_THAT(
+ getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, &rcv_buf_len),
+ SyscallSucceeds());
+ }
+
+ // Set a receive timeout so that we don't block forever on reads if the test
+ // fails.
+ struct timeval tv {
+ .tv_sec = 1, .tv_usec = 0,
+ };
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ {
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ int sent = 4;
+ if (IsRunningOnGvisor()) {
+ // Linux seems to drop the 4th packet even though technically it should
+ // fit in the receive buffer.
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ sent++;
+ }
+
+ // Verify that the expected number of packets are available to be read.
+ for (int i = 0; i < sent - 1; i++) {
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), buf.data(),
+ buf.size()),
+ 0);
+ }
+
+ // Assert that the last packet is dropped because the receive buffer should
+ // be full after the first four packets.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<char*>(recv_buf.data()));
+ iov.iov_len = buf.size();
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+void RawSocketTest::SendBuf(const char* buf, int buf_len) {
+ // It's safe to use const_cast here because sendmsg won't modify the iovec or
+ // address.
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<char*>(buf));
+ iov.iov_len = static_cast<size_t>(buf_len);
+ struct msghdr msg = {};
+ msg.msg_name = static_cast<void*>(&addr_);
+ msg.msg_namelen = AddrLen();
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(sendmsg(s_, &msg, 0), SyscallSucceedsWithValue(buf_len));
+}
+
+void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) {
+ ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s_, recv_buf, recv_buf_len));
+}
+
+void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf,
+ size_t recv_buf_len) {
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest,
+ ::testing::Combine(
+ ::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
+ ::testing::Values(AF_INET, AF_INET6)));
+
+// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to.
+TEST(RawSocketTest, IPv6ProtoRaw) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW),
+ SyscallSucceeds());
+
+ // Verify that writing yields EINVAL.
+ char buf[] = "This is such a weird little edge case";
+ struct sockaddr_in6 sin6 = {};
+ sin6.sin6_family = AF_INET6;
+ sin6.sin6_addr = in6addr_loopback;
+ ASSERT_THAT(sendto(sock, buf, sizeof(buf), 0 /* flags */,
+ reinterpret_cast<struct sockaddr*>(&sin6), sizeof(sin6)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
new file mode 100644
index 000000000..5bb14d57c
--- /dev/null
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -0,0 +1,406 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/capability.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/udp.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Tests for IPPROTO_RAW raw sockets, which implies IP_HDRINCL.
+class RawHDRINCL : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+  // Returns a valid loopback IP header with no payload.
+ struct iphdr LoopbackHeader();
+
+ // Fills in buf with an IP header, UDP header, and payload. Returns false if
+ // buf_size isn't large enough to hold everything.
+ bool FillPacket(char* buf, size_t buf_size, int port, const char* payload,
+ uint16_t payload_size);
+
+ // The socket used for both reading and writing.
+ int socket_;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+};
+
+void RawHDRINCL::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
+ SyscallSucceeds());
+
+ addr_ = {};
+
+ addr_.sin_port = IPPROTO_IP;
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void RawHDRINCL::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
+}
+
+struct iphdr RawHDRINCL::LoopbackHeader() {
+ struct iphdr hdr = {};
+ hdr.ihl = 5;
+ hdr.version = 4;
+ hdr.tos = 0;
+ hdr.tot_len = absl::gbswap_16(sizeof(hdr));
+ hdr.id = 0;
+ hdr.frag_off = 0;
+ hdr.ttl = 7;
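+  // Protocol number 1 is ICMP.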
+ hdr.protocol = 1;
+ hdr.daddr = htonl(INADDR_LOOPBACK);
+ // hdr.check is set by the network stack.
+  // hdr.tot_len, though set above, is also overwritten by the network stack.
+ // hdr.saddr is set by the network stack.
+ return hdr;
+}
+
+bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port,
+ const char* payload, uint16_t payload_size) {
+ if (buf_size < sizeof(struct iphdr) + sizeof(struct udphdr) + payload_size) {
+ return false;
+ }
+
+ struct iphdr ip = LoopbackHeader();
+ ip.protocol = IPPROTO_UDP;
+
+ struct udphdr udp = {};
+ udp.source = absl::gbswap_16(port);
+ udp.dest = absl::gbswap_16(port);
+ udp.len = absl::gbswap_16(sizeof(udp) + payload_size);
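+  // A zero UDP checksum is valid for IPv4 and means "no checksum".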
+ udp.check = 0;
+
+ memcpy(buf, reinterpret_cast<char*>(&ip), sizeof(ip));
+ memcpy(buf + sizeof(ip), reinterpret_cast<char*>(&udp), sizeof(udp));
+ memcpy(buf + sizeof(ip) + sizeof(udp), payload, payload_size);
+
+ return true;
+}
+
+// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::SetUp
+// creates the first one, so we only have to create one more here.
+TEST_F(RawHDRINCL, MultipleCreation) {
+ int s2;
+ ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds());
+
+ ASSERT_THAT(close(s2), SyscallSucceeds());
+}
+
+// Test that shutting down an unconnected socket fails.
+TEST_F(RawHDRINCL, FailShutdownWithoutConnect) {
+ ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Test that listen() fails.
+TEST_F(RawHDRINCL, FailListen) {
+ ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that accept() fails.
+TEST_F(RawHDRINCL, FailAccept) {
+ struct sockaddr saddr;
+ socklen_t addrlen;
+ ASSERT_THAT(accept(socket_, &saddr, &addrlen),
+ SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that the socket is writable immediately.
+TEST_F(RawHDRINCL, PollWritableImmediately) {
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLOUT;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 0), SyscallSucceedsWithValue(1));
+}
+
+// Test that the socket isn't readable.
+TEST_F(RawHDRINCL, NotReadable) {
+ // Try to receive data with MSG_DONTWAIT, which returns immediately if there's
+ // nothing to be read.
+ char buf[117];
+ ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Test that we can connect() to a valid IP (loopback).
+TEST_F(RawHDRINCL, ConnectToLoopback) {
+ ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallSucceeds());
+}
+
+TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
+ struct iphdr hdr = LoopbackHeader();
+ ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
+ SyscallSucceedsWithValue(sizeof(hdr)));
+}
+
+// HDRINCL implies write-only. Verify that we can't read a packet sent to
+// loopback.
+TEST_F(RawHDRINCL, NotReadableAfterWrite) {
+ ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Construct a packet with an IP header, UDP header, and payload.
+ constexpr char kPayload[] = "odst";
+ char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
+ ASSERT_TRUE(FillPacket(packet, sizeof(packet), 40000 /* port */, kPayload,
+ sizeof(kPayload)));
+
+ socklen_t addrlen = sizeof(addr_);
+ ASSERT_NO_FATAL_FAILURE(
+ sendto(socket_, reinterpret_cast<void*>(&packet), sizeof(packet), 0,
+ reinterpret_cast<struct sockaddr*>(&addr_), addrlen));
+
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
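+  // poll() returning 0 means it timed out without the socket turning readable.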
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(RawHDRINCL, WriteTooSmall) {
+ ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallSucceeds());
+
+ // This is smaller than the size of an IP header.
+ constexpr char kBuf[] = "JP5";
+ ASSERT_THAT(send(socket_, kBuf, sizeof(kBuf), 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Bind to localhost.
+TEST_F(RawHDRINCL, BindToLocalhost) {
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+}
+
+// Bind to a different address.
+TEST_F(RawHDRINCL, BindToInvalid) {
+ struct sockaddr_in bind_addr = {};
+ bind_addr.sin_family = AF_INET;
+ bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+// Send and receive a packet.
+TEST_F(RawHDRINCL, SendAndReceive) {
+ int port = 40000;
+ if (!IsRunningOnGvisor()) {
+ port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
+ PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
+ }
+
+ // IPPROTO_RAW sockets are write-only. We'll have to open another socket to
+ // read what we write.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ // Construct a packet with an IP header, UDP header, and payload.
+ constexpr char kPayload[] = "toto";
+ char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
+ ASSERT_TRUE(
+ FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
+
+ socklen_t addrlen = sizeof(addr_);
+ ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ addrlen));
+
+ // Receive the payload.
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_size),
+ SyscallSucceedsWithValue(sizeof(packet)));
+ EXPECT_EQ(
+ memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
+ sizeof(kPayload)),
+ 0);
+ // The network stack should have set the source address.
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
+ // The packet ID should not be 0, as the packet has DF=0.
+ struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf);
+ EXPECT_NE(iphdr->id, 0);
+}
+
+// Send and receive a packet where the sendto address is not the same as the
+// provided destination.
+TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
+ int port = 40000;
+ if (!IsRunningOnGvisor()) {
+ port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
+ PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
+ }
+
+ // IPPROTO_RAW sockets are write-only. We'll have to open another socket to
+ // read what we write.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ // Construct a packet with an IP header, UDP header, and payload.
+ constexpr char kPayload[] = "toto";
+ char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
+ ASSERT_TRUE(
+ FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
+ // Overwrite the IP destination address with an IP we can't get to.
+ struct iphdr iphdr = {};
+ memcpy(&iphdr, packet, sizeof(iphdr));
+ iphdr.daddr = 42;
+ memcpy(packet, &iphdr, sizeof(iphdr));
+
+ socklen_t addrlen = sizeof(addr_);
+ ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ addrlen));
+
+ // Receive the payload, since sendto should replace the bad destination with
+ // localhost.
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_size),
+ SyscallSucceedsWithValue(sizeof(packet)));
+ EXPECT_EQ(
+ memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
+ sizeof(kPayload)),
+ 0);
+ // The network stack should have set the source address.
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
+ // The packet ID should not be 0, as the packet has DF=0.
+ struct iphdr recv_iphdr = {};
+ memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr));
+ EXPECT_NE(recv_iphdr.id, 0);
+ // The destination address should be localhost, not the bad IP we set
+ // initially.
+ EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK);
+}
+
+// Send and receive a packet w/ the IP_HDRINCL option set.
+TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) {
+ int port = 40000;
+ if (!IsRunningOnGvisor()) {
+ port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
+ PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
+ }
+
+ FileDescriptor recv_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ FileDescriptor send_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ // Enable IP_HDRINCL option so that we can build and send w/ an IP
+ // header.
+ constexpr int kSockOptOn = 1;
+ ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ // This is not strictly required but we do it to make sure that setting
+ // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving
+ // packets.
+ ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Construct a packet with an IP header, UDP header, and payload.
+ constexpr char kPayload[] = "toto";
+ char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
+ ASSERT_TRUE(
+ FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
+
+ socklen_t addrlen = sizeof(addr_);
+ ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ addrlen));
+
+ // Receive the payload.
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_size),
+ SyscallSucceedsWithValue(sizeof(packet)));
+ EXPECT_EQ(
+ memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
+ sizeof(kPayload)),
+ 0);
+ // The network stack should have set the source address.
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
+ struct iphdr iphdr = {};
+ memcpy(&iphdr, recv_buf, sizeof(iphdr));
+ EXPECT_NE(iphdr.id, 0);
+
+ // Also verify that the packet we just sent was not delivered to the
+ // IPPROTO_RAW socket.
+ {
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+ reinterpret_cast<struct sockaddr*>(&src), &src_size),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
new file mode 100644
index 000000000..3de898df7
--- /dev/null
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -0,0 +1,514 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/capability.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cstdint>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// The size of an empty ICMP packet and IP header together (20-byte IP header
+// plus 8-byte ICMP header).
+constexpr size_t kEmptyICMPSize = 28;
+
+// ICMP raw sockets get their own special tests because Linux automatically
+// responds to ICMP echo requests, and thus a single echo request sent via
+// loopback leads to 2 received ICMP packets.
+
+class RawSocketICMPTest : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // Checks that both an ICMP echo request and reply are received. Calls should
+ // be wrapped in ASSERT_NO_FATAL_FAILURE.
+ void ExpectICMPSuccess(const struct icmphdr& icmp);
+
+ // Sends icmp via s_.
+ void SendEmptyICMP(const struct icmphdr& icmp);
+
+ // Sends icmp via s_ to the given address.
+ void SendEmptyICMPTo(int sock, const struct sockaddr_in& addr,
+ const struct icmphdr& icmp);
+
+ // Reads from s_ into recv_buf.
+ void ReceiveICMP(char* recv_buf, size_t recv_buf_len, size_t expected_size,
+ struct sockaddr_in* src);
+
+ // Reads from sock into recv_buf.
+ void ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len,
+ size_t expected_size, struct sockaddr_in* src, int sock);
+
+ // The socket used for both reading and writing.
+ int s_;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+};
+
+void RawSocketICMPTest::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
+
+ addr_ = {};
+
+ // "On raw sockets sin_port is set to the IP protocol." - ip(7).
+ addr_.sin_port = IPPROTO_IP;
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void RawSocketICMPTest::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
+}
+
+// We'll only read an echo in this case, as the kernel won't respond to a
+// packet with a malformed ICMP checksum.
+TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 0;
+ icmp.un.echo.sequence = 2012;
+ icmp.un.echo.id = 2014;
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+  // Verify that we get the echo, then that there's nothing else to read.
+ char recv_buf[kEmptyICMPSize];
+ struct sockaddr_in src;
+ ASSERT_NO_FATAL_FAILURE(
+ ReceiveICMP(recv_buf, sizeof(recv_buf), sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0);
+ // The packet should be identical to what we sent.
+ EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), 0);
+
+ // And there should be nothing left to read.
+ EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Send and receive an ICMP packet.
+TEST_F(RawSocketICMPTest, SendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID.
+ // None of that should matter for raw sockets - the kernel should still give
+ // us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 0;
+ icmp.un.echo.sequence = 2012;
+ icmp.un.echo.id = 2014;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+// We should be able to create multiple raw sockets for the same protocol and
+// receive the same packet on both.
+TEST_F(RawSocketICMPTest, MultipleSocketReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ FileDescriptor s2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP));
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID.
+ // None of that should matter for raw sockets - the kernel should still give
+ // us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 0;
+ icmp.un.echo.sequence = 2016;
+ icmp.un.echo.id = 2018;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ // Both sockets will receive the echo request and reply in indeterminate
+ // order, so we'll need to read 2 packets from each.
+
+ // Receive on socket 1.
+ constexpr int kBufSize = kEmptyICMPSize;
+ char recv_buf1[2][kBufSize];
+ struct sockaddr_in src;
+ for (int i = 0; i < 2; i++) {
+ ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf1[i],
+ ABSL_ARRAYSIZE(recv_buf1[i]),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0);
+ }
+
+ // Receive on socket 2.
+ char recv_buf2[2][kBufSize];
+ for (int i = 0; i < 2; i++) {
+ ASSERT_NO_FATAL_FAILURE(
+ ReceiveICMPFrom(recv_buf2[i], ABSL_ARRAYSIZE(recv_buf2[i]),
+ sizeof(struct icmphdr), &src, s2.get()));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0);
+ }
+
+ // Ensure both sockets receive identical packets.
+ int types[] = {ICMP_ECHO, ICMP_ECHOREPLY};
+ for (int type : types) {
+ auto match_type = [=](char buf[kBufSize]) {
+ struct icmphdr* icmp =
+ reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr));
+ return icmp->type == type;
+ };
+ auto icmp1_it =
+ std::find_if(std::begin(recv_buf1), std::end(recv_buf1), match_type);
+ auto icmp2_it =
+ std::find_if(std::begin(recv_buf2), std::end(recv_buf2), match_type);
+ ASSERT_NE(icmp1_it, std::end(recv_buf1));
+ ASSERT_NE(icmp2_it, std::end(recv_buf2));
+ EXPECT_EQ(memcmp(*icmp1_it + sizeof(struct iphdr),
+ *icmp2_it + sizeof(struct iphdr), sizeof(icmp)),
+ 0);
+ }
+}
+
+// A raw ICMP socket and ping socket should both receive the ICMP packets
+// intended for the ping socket.
+TEST_F(RawSocketICMPTest, RawAndPingSockets) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ FileDescriptor ping_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+
+  // Ping sockets take care of the ICMP ID and checksum; set only the type,
+  // code, and an arbitrary sequence number.
+  struct icmphdr icmp;
+  icmp.type = ICMP_ECHO;
+  icmp.code = 0;
+  icmp.un.echo.sequence = 2012;
+ ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, sizeof(icmp), 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallSucceedsWithValue(sizeof(icmp)));
+
+ // Receive on socket 1, which receives the echo request and reply in
+ // indeterminate order.
+ constexpr int kBufSize = kEmptyICMPSize;
+ char recv_buf1[2][kBufSize];
+ struct sockaddr_in src;
+ for (int i = 0; i < 2; i++) {
+ ASSERT_NO_FATAL_FAILURE(
+ ReceiveICMP(recv_buf1[i], kBufSize, sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0);
+ }
+
+ // Receive on socket 2. Ping sockets only get the echo reply, not the initial
+ // echo.
+ char ping_recv_buf[kBufSize];
+ ASSERT_THAT(RetryEINTR(recv)(ping_sock.get(), ping_recv_buf, kBufSize, 0),
+ SyscallSucceedsWithValue(sizeof(struct icmphdr)));
+
+ // Ensure both sockets receive identical echo reply packets.
+ auto match_type_raw = [=](char buf[kBufSize]) {
+ struct icmphdr* icmp =
+ reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr));
+ return icmp->type == ICMP_ECHOREPLY;
+ };
+ auto raw_reply_it =
+ std::find_if(std::begin(recv_buf1), std::end(recv_buf1), match_type_raw);
+ ASSERT_NE(raw_reply_it, std::end(recv_buf1));
+ EXPECT_EQ(
+ memcmp(*raw_reply_it + sizeof(struct iphdr), ping_recv_buf, sizeof(icmp)),
+ 0);
+}
+
+// A raw ICMP socket should be able to send a malformed, short ICMP Echo
+// Request, while a ping socket should not. Neither should be able to receive
+// the short, malformed packet.
+TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ FileDescriptor ping_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.un.echo.sequence = 0;
+ icmp.un.echo.id = 6789;
+ icmp.checksum = 0;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+
+ // Omit 2 bytes from ICMP packet.
+ constexpr int kShortICMPSize = sizeof(icmp) - 2;
+
+ // Sending a malformed short ICMP message to a ping socket should fail.
+ ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, kShortICMPSize, 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Sending a malformed short ICMP message to a raw socket should not fail.
+ ASSERT_THAT(RetryEINTR(sendto)(s_, &icmp, kShortICMPSize, 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallSucceedsWithValue(kShortICMPSize));
+
+  // Neither the ping nor the raw socket should have anything to read.
+ char recv_buf[kEmptyICMPSize];
+ EXPECT_THAT(RetryEINTR(recv)(ping_sock.get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// A raw ICMP socket should be able to send a malformed, short ICMP Echo
+// Reply, while a ping socket should not. Neither should be able to receive
+// the short, malformed packet.
+TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ FileDescriptor ping_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHOREPLY;
+ icmp.code = 0;
+ icmp.un.echo.sequence = 0;
+ icmp.un.echo.id = 6789;
+ icmp.checksum = 0;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+
+ // Omit 2 bytes from ICMP packet.
+ constexpr int kShortICMPSize = sizeof(icmp) - 2;
+
+ // Sending a malformed short ICMP message to a ping socket should fail.
+ ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, kShortICMPSize, 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Sending a malformed short ICMP message to a raw socket should not fail.
+ ASSERT_THAT(RetryEINTR(sendto)(s_, &icmp, kShortICMPSize, 0,
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
+ SyscallSucceedsWithValue(kShortICMPSize));
+
+  // Neither the ping nor the raw socket should have anything to read.
+ char recv_buf[kEmptyICMPSize];
+ EXPECT_THAT(RetryEINTR(recv)(ping_sock.get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Test that connect() sends packets to the right place.
+TEST_F(RawSocketICMPTest, SendAndReceiveViaConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID.
+ // None of that should matter for raw sockets - the kernel should still give
+ // us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 0;
+ icmp.un.echo.sequence = 2003;
+ icmp.un.echo.id = 2004;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+ ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0),
+ SyscallSucceedsWithValue(sizeof(icmp)));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+// Bind to localhost, then send and receive packets.
+TEST_F(RawSocketICMPTest, BindSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+  // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID.
+  // None of that should matter for raw sockets - the kernel should still give
+  // us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 0;
+ icmp.un.echo.sequence = 2004;
+ icmp.un.echo.id = 2007;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+// Bind and connect to localhost and send/receive packets.
+TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+  // Prepare and send an ICMP packet. Use arbitrary junk for sequence and ID.
+  // None of that should matter for raw sockets - the kernel should still give
+  // us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 0;
+ icmp.un.echo.sequence = 2010;
+ icmp.un.echo.id = 7;
+ icmp.checksum = ICMPChecksum(icmp, NULL, 0);
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) {
+ // We're going to receive both the echo request and reply, but the order is
+ // indeterminate.
+ char recv_buf[kEmptyICMPSize];
+ struct sockaddr_in src;
+ bool received_request = false;
+ bool received_reply = false;
+
+ for (int i = 0; i < 2; i++) {
+ // Receive the packet.
+ ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(src)), 0);
+ struct icmphdr* recvd_icmp =
+ reinterpret_cast<struct icmphdr*>(recv_buf + sizeof(struct iphdr));
+ switch (recvd_icmp->type) {
+ case ICMP_ECHO:
+ EXPECT_FALSE(received_request);
+ received_request = true;
+ // The packet should be identical to what we sent.
+ EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)),
+ 0);
+ break;
+
+ case ICMP_ECHOREPLY:
+ EXPECT_FALSE(received_reply);
+ received_reply = true;
+ // Most fields should be the same.
+ EXPECT_EQ(recvd_icmp->code, icmp.code);
+ EXPECT_EQ(recvd_icmp->un.echo.sequence, icmp.un.echo.sequence);
+ EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id);
+ // A couple are different.
+ EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY);
+      // Recomputing the Internet checksum over the reply (including its
+      // checksum field) yields 0 when the reply's checksum is valid.
+ EXPECT_EQ(ICMPChecksum(*recvd_icmp, NULL, 0), 0);
+ break;
+ }
+ }
+
+ ASSERT_TRUE(received_request);
+ ASSERT_TRUE(received_reply);
+}
+
+void RawSocketICMPTest::SendEmptyICMP(const struct icmphdr& icmp) {
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMPTo(s_, addr_, icmp));
+}
+
+void RawSocketICMPTest::SendEmptyICMPTo(int sock,
+ const struct sockaddr_in& addr,
+ const struct icmphdr& icmp) {
+ // It's safe to use const_cast here because sendmsg won't modify the iovec or
+ // address.
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<struct icmphdr*>(&icmp));
+ iov.iov_len = sizeof(icmp);
+ struct msghdr msg = {};
+ msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr));
+ msg.msg_namelen = sizeof(addr);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(icmp)));
+}
+
+void RawSocketICMPTest::ReceiveICMP(char* recv_buf, size_t recv_buf_len,
+ size_t expected_size,
+ struct sockaddr_in* src) {
+ ASSERT_NO_FATAL_FAILURE(
+ ReceiveICMPFrom(recv_buf, recv_buf_len, expected_size, src, s_));
+}
+
+void RawSocketICMPTest::ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len,
+ size_t expected_size,
+ struct sockaddr_in* src, int sock) {
+ struct iovec iov = {};
+ iov.iov_base = recv_buf;
+ iov.iov_len = recv_buf_len;
+ struct msghdr msg = {};
+ msg.msg_name = src;
+ msg.msg_namelen = sizeof(*src);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+  // We should receive the ICMP packet plus a 20-byte IPv4 header (no options).
+ ASSERT_THAT(recvmsg(sock, &msg, 0),
+ SyscallSucceedsWithValue(expected_size + sizeof(struct iphdr)));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc
new file mode 100644
index 000000000..2633ba31b
--- /dev/null
+++ b/test/syscalls/linux/read.cc
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReadTest : public ::testing::Test {
+ void SetUp() override {
+ name_ = NewTempAbsPath();
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_CREAT, 0644), SyscallSucceeds());
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ void TearDown() override { unlink(name_.c_str()); }
+
+ public:
+ std::string name_;
+};
+
+TEST_F(ReadTest, ZeroBuffer) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_RDWR), SyscallSucceeds());
+
+ char msg[] = "hello world";
+ EXPECT_THAT(PwriteFd(fd, msg, strlen(msg), 0),
+ SyscallSucceedsWithValue(strlen(msg)));
+
+ char buf[10];
+ EXPECT_THAT(ReadFd(fd, buf, 0), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(ReadTest, EmptyFileReturnsZeroAtEOF) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_RDWR), SyscallSucceeds());
+
+ char eof_buf[10];
+ EXPECT_THAT(ReadFd(fd, eof_buf, 10), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(ReadTest, EofAfterRead) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_RDWR), SyscallSucceeds());
+
+ // Write some bytes to be read.
+ constexpr char kMessage[] = "hello world";
+ EXPECT_THAT(PwriteFd(fd, kMessage, sizeof(kMessage), 0),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Read all of the bytes at once.
+ char buf[sizeof(kMessage)];
+ EXPECT_THAT(ReadFd(fd, buf, sizeof(kMessage)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Read again with a non-zero buffer and expect EOF.
+ char eof_buf[10];
+ EXPECT_THAT(ReadFd(fd, eof_buf, 10), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(ReadTest, DevNullReturnsEof) {
+ int fd;
+ ASSERT_THAT(fd = open("/dev/null", O_RDONLY), SyscallSucceeds());
+ std::vector<char> buf(1);
+ EXPECT_THAT(ReadFd(fd, buf.data(), 1), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+const int kReadSize = 128 * 1024;
+
+// Do not allow random save as it could lead to partial reads.
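+// (The _NoRandomSave suffix tells the save/restore test runner not to inject
+// random save points, which could otherwise split this large read.)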
+TEST_F(ReadTest, CanReadFullyFromDevZero_NoRandomSave) {
+ int fd;
+ ASSERT_THAT(fd = open("/dev/zero", O_RDONLY), SyscallSucceeds());
+
+ std::vector<char> buf(kReadSize, 1);
+ EXPECT_THAT(ReadFd(fd, buf.data(), kReadSize),
+ SyscallSucceedsWithValue(kReadSize));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ EXPECT_EQ(std::vector<char>(kReadSize, 0), buf);
+}
+
+TEST_F(ReadTest, ReadDirectoryFails) {
+ const FileDescriptor file =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY));
+ std::vector<char> buf(1);
+ EXPECT_THAT(ReadFd(file.get(), buf.data(), 1), SyscallFailsWithErrno(EISDIR));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc
new file mode 100644
index 000000000..09703b5c1
--- /dev/null
+++ b/test/syscalls/linux/readahead.cc
@@ -0,0 +1,91 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ReadaheadTest, InvalidFD) {
+ EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ReadaheadTest, InvalidOffset) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), -1, 1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(ReadaheadTest, ValidOffset) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // N.B. The implementation of readahead is filesystem-specific, and a file
+// backed by RAM may return EINVAL because there is nothing to be read.
+ EXPECT_THAT(readahead(fd.get(), 1, 1), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, PastEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // See above.
+ EXPECT_THAT(readahead(fd.get(), 2, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, CrossesEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // See above.
+ EXPECT_THAT(readahead(fd.get(), 4, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, WriteOnly) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY));
+ EXPECT_THAT(readahead(fd.get(), 0, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ReadaheadTest, InvalidSize) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc
new file mode 100644
index 000000000..baaf9f757
--- /dev/null
+++ b/test/syscalls/linux/readv.cc
@@ -0,0 +1,294 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cstdlib>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/syscalls/linux/readv_common.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReadvTest : public FileTest {
+ void SetUp() override {
+ FileTest::SetUp();
+
+ ASSERT_THAT(write(test_file_fd_.get(), kReadvTestData, kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ ASSERT_THAT(lseek(test_file_fd_.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(write(test_pipe_[1], kReadvTestData, kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ }
+};
+
+TEST_F(ReadvTest, ReadOneBufferPerByte_File) {
+ ReadOneBufferPerByte(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadOneBufferPerByte_Pipe) {
+ ReadOneBufferPerByte(test_pipe_[0]);
+}
+
+TEST_F(ReadvTest, ReadOneHalfAtATime_File) {
+ ReadOneHalfAtATime(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadOneHalfAtATime_Pipe) {
+ ReadOneHalfAtATime(test_pipe_[0]);
+}
+
+TEST_F(ReadvTest, ReadAllOneBuffer_File) {
+ ReadAllOneBuffer(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadAllOneBuffer_Pipe) { ReadAllOneBuffer(test_pipe_[0]); }
+
+TEST_F(ReadvTest, ReadAllOneLargeBuffer_File) {
+ ReadAllOneLargeBuffer(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadAllOneLargeBuffer_Pipe) {
+ ReadAllOneLargeBuffer(test_pipe_[0]);
+}
+
+TEST_F(ReadvTest, ReadBuffersOverlapping_File) {
+ ReadBuffersOverlapping(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadBuffersOverlapping_Pipe) {
+ ReadBuffersOverlapping(test_pipe_[0]);
+}
+
+TEST_F(ReadvTest, ReadBuffersDiscontinuous_File) {
+ ReadBuffersDiscontinuous(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadBuffersDiscontinuous_Pipe) {
+ ReadBuffersDiscontinuous(test_pipe_[0]);
+}
+
+TEST_F(ReadvTest, ReadIovecsCompletelyFilled_File) {
+ ReadIovecsCompletelyFilled(test_file_fd_.get());
+}
+
+TEST_F(ReadvTest, ReadIovecsCompletelyFilled_Pipe) {
+ ReadIovecsCompletelyFilled(test_pipe_[0]);
+}
+
+TEST_F(ReadvTest, BadFileDescriptor) {
+ char buffer[1024];
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = 1024;
+
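+  // Note: iovcnt (1024) exceeds the one-element array, but the invalid fd
+  // should be rejected before the iovec array is ever examined.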
+ ASSERT_THAT(readv(-1, iov, 1024), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(ReadvTest, BadIovecsPointer_File) {
+ ASSERT_THAT(readv(test_file_fd_.get(), nullptr, 1),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvTest, BadIovecsPointer_Pipe) {
+ ASSERT_THAT(readv(test_pipe_[0], nullptr, 1), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvTest, BadIovecBase_File) {
+ struct iovec iov[1];
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 1024;
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 1),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvTest, BadIovecBase_Pipe) {
+ struct iovec iov[1];
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 1024;
+ ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvTest, ZeroIovecs_File) {
+ struct iovec iov[1];
+ iov[0].iov_base = 0;
+ iov[0].iov_len = 0;
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), SyscallSucceeds());
+}
+
+TEST_F(ReadvTest, ZeroIovecs_Pipe) {
+ struct iovec iov[1];
+ iov[0].iov_base = 0;
+ iov[0].iov_len = 0;
+ ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallSucceeds());
+}
+
+TEST_F(ReadvTest, NotReadable_File) {
+ char buffer[1024];
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = 1024;
+
+ std::string wronly_file = NewTempAbsPath();
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(wronly_file, O_CREAT | O_WRONLY, S_IRUSR | S_IWUSR));
+ ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EBADF));
+ fd.reset(); // Close before unlinking.
+ ASSERT_THAT(unlink(wronly_file.c_str()), SyscallSucceeds());
+}
+
+TEST_F(ReadvTest, NotReadable_Pipe) {
+ char buffer[1024];
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = 1024;
+ ASSERT_THAT(readv(test_pipe_[1], iov, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(ReadvTest, DirNotReadable) {
+ char buffer[1024];
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = 1024;
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY));
+ ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(ReadvTest, OffsetIncremented) {
+ char* buffer = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = kReadvTestDataSize;
+
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 1),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ ASSERT_THAT(lseek(test_file_fd_.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+
+ free(buffer);
+}
+
+TEST_F(ReadvTest, EndOfFile) {
+ char* buffer = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 1),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ free(buffer);
+
+ buffer = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 1), SyscallSucceedsWithValue(0));
+ free(buffer);
+}
+
+TEST_F(ReadvTest, WouldBlock_Pipe) {
+ struct iovec iov[1];
+ iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ iov[0].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_pipe_[0], iov, 1),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ free(iov[0].iov_base);
+
+ iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallFailsWithErrno(EAGAIN));
+ free(iov[0].iov_base);
+}
+
+TEST_F(ReadvTest, ZeroBuffer) {
+ char buf[10];
+ struct iovec iov[1];
+ iov[0].iov_base = buf;
+ iov[0].iov_len = 0;
+ ASSERT_THAT(readv(test_pipe_[0], iov, 1), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(ReadvTest, NullIovecInNonemptyArray) {
+ std::vector<char> buf(kReadvTestDataSize);
+ struct iovec iov[2];
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 0;
+ iov[1].iov_base = buf.data();
+ iov[1].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 2),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+}
+
+TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) {
+ std::vector<char> buf(kReadvTestDataSize);
+ struct iovec iov[2];
+ iov[0].iov_base = reinterpret_cast<void*>(~static_cast<uintptr_t>(0));
+ iov[0].iov_len = 0;
+ iov[1].iov_base = buf.data();
+ iov[1].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_file_fd_.get(), iov, 2),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+// This test depends on the maximum extent of a single readv() syscall, so
+// we can't tolerate interruption from saving.
+TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) {
+ // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly
+ // important in environments where automated profiling tools may start
+ // ITIMER_PROF automatically.
+ struct itimerval itv = {};
+ auto const cleanup_itimer =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv));
+
+ // From Linux's include/linux/fs.h.
+ size_t const MAX_RW_COUNT = INT_MAX & ~(kPageSize - 1);
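+  // For example, with 4 KiB pages this is 0x7FFFF000, i.e. INT_MAX rounded
+  // down to a page boundary.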
+
+ // Create an iovec array with 3 segments pointing to consecutive parts of a
+ // buffer. The first covers all but the last three pages, and should be
+ // written to in its entirety. The second covers the last page before
+ // MAX_RW_COUNT and the first page after; only the first page should be
+ // written to. The third covers the last page of the buffer, and should be
+ // skipped entirely.
+ size_t const kBufferSize = MAX_RW_COUNT + 2 * kPageSize;
+ size_t const kFirstOffset = MAX_RW_COUNT - kPageSize;
+ size_t const kSecondOffset = MAX_RW_COUNT + kPageSize;
+ // The buffer is too big to fit on the stack.
+ std::vector<char> buf(kBufferSize);
+ struct iovec iov[3];
+ iov[0].iov_base = buf.data();
+ iov[0].iov_len = kFirstOffset;
+ iov[1].iov_base = buf.data() + kFirstOffset;
+ iov[1].iov_len = kSecondOffset - kFirstOffset;
+ iov[2].iov_base = buf.data() + kSecondOffset;
+ iov[2].iov_len = kBufferSize - kSecondOffset;
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
+ EXPECT_THAT(readv(fd.get(), iov, 3), SyscallSucceedsWithValue(MAX_RW_COUNT));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc
new file mode 100644
index 000000000..2694dc64f
--- /dev/null
+++ b/test/syscalls/linux/readv_common.cc
@@ -0,0 +1,220 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// MatchesStringLength checks that a tuple argument of (struct iovec *, int),
+// corresponding to an iovec array and its length, has buffers whose total
+// length matches the string length strlen.
+MATCHER_P(MatchesStringLength, strlen, "") {
+ struct iovec* iovs = arg.first;
+ int niov = arg.second;
+ int offset = 0;
+ for (int i = 0; i < niov; i++) {
+ offset += iovs[i].iov_len;
+ }
+ if (offset != static_cast<int>(strlen)) {
+ *result_listener << offset;
+ return false;
+ }
+ return true;
+}
+
+// MatchesStringValue checks that a tuple argument of (struct iovec *, int),
+// corresponding to an iovec array and its length, contains data matching the
+// string value str.
+MATCHER_P(MatchesStringValue, str, "") {
+ struct iovec* iovs = arg.first;
+ int len = strlen(str);
+ int niov = arg.second;
+ int offset = 0;
+ for (int i = 0; i < niov; i++) {
+ struct iovec iov = iovs[i];
+ if (len < offset) {
+ *result_listener << "strlen " << len << " < offset " << offset;
+ return false;
+ }
+ if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) {
+ absl::string_view iovec_string(static_cast<char*>(iov.iov_base),
+ iov.iov_len);
+ *result_listener << iovec_string << " @offset " << offset;
+ return false;
+ }
+ offset += iov.iov_len;
+ }
+ return true;
+}
+
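+// Note that these adjacent string literals are concatenated by the compiler
+// with no separators, so kReadvTestData is one long line of text.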
+extern const char kReadvTestData[] =
+ "127.0.0.1 localhost"
+ ""
+ "# The following lines are desirable for IPv6 capable hosts"
+ "::1 ip6-localhost ip6-loopback"
+ "fe00::0 ip6-localnet"
+ "ff00::0 ip6-mcastprefix"
+ "ff02::1 ip6-allnodes"
+ "ff02::2 ip6-allrouters"
+ "ff02::3 ip6-allhosts"
+ "192.168.1.100 a"
+ "93.184.216.34 foo.bar.example.com xcpu";
+extern const size_t kReadvTestDataSize = sizeof(kReadvTestData);
+
+static void ReadAllOneProvidedBuffer(int fd, std::vector<char>* buffer) {
+ struct iovec iovs[1];
+ iovs[0].iov_base = buffer->data();
+ iovs[0].iov_len = kReadvTestDataSize;
+
+ ASSERT_THAT(readv(fd, iovs, 1), SyscallSucceedsWithValue(kReadvTestDataSize));
+
+ std::pair<struct iovec*, int> iovec_desc(iovs, 1);
+ EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize));
+ EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData));
+}
+
+void ReadAllOneBuffer(int fd) {
+ std::vector<char> buffer(kReadvTestDataSize);
+ ReadAllOneProvidedBuffer(fd, &buffer);
+}
+
+void ReadAllOneLargeBuffer(int fd) {
+ std::vector<char> buffer(10 * kReadvTestDataSize);
+ ReadAllOneProvidedBuffer(fd, &buffer);
+}
+
+void ReadOneHalfAtATime(int fd) {
+ int len0 = kReadvTestDataSize / 2;
+ int len1 = kReadvTestDataSize - len0;
+ std::vector<char> buffer0(len0);
+ std::vector<char> buffer1(len1);
+
+ struct iovec iovs[2];
+ iovs[0].iov_base = buffer0.data();
+ iovs[0].iov_len = len0;
+ iovs[1].iov_base = buffer1.data();
+ iovs[1].iov_len = len1;
+
+ ASSERT_THAT(readv(fd, iovs, 2), SyscallSucceedsWithValue(kReadvTestDataSize));
+
+ std::pair<struct iovec*, int> iovec_desc(iovs, 2);
+ EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize));
+ EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData));
+}
+
+void ReadOneBufferPerByte(int fd) {
+ std::vector<char> buffer(kReadvTestDataSize);
+ std::vector<struct iovec> iovs(kReadvTestDataSize);
+ char* buffer_ptr = buffer.data();
+ struct iovec* iovs_ptr = iovs.data();
+
+ for (int i = 0; i < static_cast<int>(kReadvTestDataSize); i++) {
+ struct iovec iov = {
+ .iov_base = &buffer_ptr[i],
+ .iov_len = 1,
+ };
+ iovs_ptr[i] = iov;
+ }
+
+ ASSERT_THAT(readv(fd, iovs_ptr, kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+
+ std::pair<struct iovec*, int> iovec_desc(iovs.data(), kReadvTestDataSize);
+ EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize));
+ EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData));
+}
+
+void ReadBuffersOverlapping(int fd) {
+  // Both iovecs overlap in the first overlap_bytes of the buffer.
+ int overlap_bytes = 8;
+ std::vector<char> buffer(kReadvTestDataSize);
+
+  // Overlapping iovecs make the total buffer length exceed the data size.
+ int expected_size = kReadvTestDataSize + overlap_bytes;
+ std::vector<char> expected(expected_size);
+ char* expected_ptr = expected.data();
+ memcpy(expected_ptr, &kReadvTestData[overlap_bytes], overlap_bytes);
+ memcpy(&expected_ptr[overlap_bytes], &kReadvTestData[overlap_bytes],
+ kReadvTestDataSize - overlap_bytes);
+
+ struct iovec iovs[2];
+ iovs[0].iov_base = buffer.data();
+ iovs[0].iov_len = overlap_bytes;
+ iovs[1].iov_base = buffer.data();
+ iovs[1].iov_len = kReadvTestDataSize;
+
+ ASSERT_THAT(readv(fd, iovs, 2), SyscallSucceedsWithValue(kReadvTestDataSize));
+
+ std::pair<struct iovec*, int> iovec_desc(iovs, 2);
+ EXPECT_THAT(iovec_desc, MatchesStringLength(expected_size));
+ EXPECT_THAT(iovec_desc, MatchesStringValue(expected_ptr));
+}
+
+void ReadBuffersDiscontinuous(int fd) {
+ // Each iov is 1 byte separated by 1 byte.
+ std::vector<char> buffer(kReadvTestDataSize * 2);
+ std::vector<struct iovec> iovs(kReadvTestDataSize);
+
+ char* buffer_ptr = buffer.data();
+ struct iovec* iovs_ptr = iovs.data();
+
+ for (int i = 0; i < static_cast<int>(kReadvTestDataSize); i++) {
+ struct iovec iov = {
+ .iov_base = &buffer_ptr[i * 2],
+ .iov_len = 1,
+ };
+ iovs_ptr[i] = iov;
+ }
+
+ ASSERT_THAT(readv(fd, iovs_ptr, kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+
+ std::pair<struct iovec*, int> iovec_desc(iovs.data(), kReadvTestDataSize);
+ EXPECT_THAT(iovec_desc, MatchesStringLength(kReadvTestDataSize));
+ EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData));
+}
+
+void ReadIovecsCompletelyFilled(int fd) {
+ int half = kReadvTestDataSize / 2;
+ std::vector<char> buffer(kReadvTestDataSize);
+ char* buffer_ptr = buffer.data();
+ memset(buffer.data(), '\0', kReadvTestDataSize);
+
+ struct iovec iovs[2];
+ iovs[0].iov_base = buffer.data();
+ iovs[0].iov_len = half;
+ iovs[1].iov_base = &buffer_ptr[half];
+ iovs[1].iov_len = half;
+
+ ASSERT_THAT(readv(fd, iovs, 2), SyscallSucceedsWithValue(half * 2));
+
+ std::pair<struct iovec*, int> iovec_desc(iovs, 2);
+ EXPECT_THAT(iovec_desc, MatchesStringLength(half * 2));
+ EXPECT_THAT(iovec_desc, MatchesStringValue(kReadvTestData));
+
+ char* str = static_cast<char*>(iovs[0].iov_base);
+ str[iovs[0].iov_len - 1] = '\0';
+ ASSERT_EQ(half - 1, strlen(str));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/readv_common.h b/test/syscalls/linux/readv_common.h
new file mode 100644
index 000000000..2fa40c35f
--- /dev/null
+++ b/test/syscalls/linux/readv_common.h
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_READV_COMMON_H_
+#define GVISOR_TEST_SYSCALLS_READV_COMMON_H_
+
+#include <stddef.h>
+
+namespace gvisor {
+namespace testing {
+
+// A NUL-terminated string containing the data used by the test helpers
+// declared below.
+extern const char kReadvTestData[];
+
+// The size of kReadvTestData, including the terminating NUL.
+extern const size_t kReadvTestDataSize;
+
+// ReadAllOneBuffer asserts that it can read kReadvTestData from an fd using
+// exactly one iovec.
+void ReadAllOneBuffer(int fd);
+
+// ReadAllOneLargeBuffer asserts that it can read kReadvTestData from an fd
+// using exactly one iovec containing an overly large buffer.
+void ReadAllOneLargeBuffer(int fd);
+
+// ReadOneHalfAtATime asserts that it can read kReadvTestData from an fd using
+// exactly two iovecs of roughly equal size.
+void ReadOneHalfAtATime(int fd);
+
+// ReadOneBufferPerByte asserts that it can read kReadvTestData from an fd
+// using one iovec per byte.
+void ReadOneBufferPerByte(int fd);
+
+// ReadBuffersOverlapping asserts that it can read kReadvTestData from an fd
+// where two iovecs are overlapping.
+void ReadBuffersOverlapping(int fd);
+
+// ReadBuffersDiscontinuous asserts that it can read kReadvTestData from an fd
+// where each iovec is discontinuous from the next by 1 byte.
+void ReadBuffersDiscontinuous(int fd);
+
+// ReadIovecsCompletelyFilled asserts that each iovec is completely filled
+// before the next one is used.
+void ReadIovecsCompletelyFilled(int fd);
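+
+// A sketch of typical usage: write kReadvTestData to one end of a pipe, then
+// hand the read end to any helper above.
+//
+//   int fds[2];
+//   ASSERT_THAT(pipe(fds), SyscallSucceeds());
+//   ASSERT_THAT(write(fds[1], kReadvTestData, kReadvTestDataSize),
+//               SyscallSucceedsWithValue(kReadvTestDataSize));
+//   ReadAllOneBuffer(fds[0]);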
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_READV_COMMON_H_
diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc
new file mode 100644
index 000000000..dd6fb7008
--- /dev/null
+++ b/test/syscalls/linux/readv_socket.cc
@@ -0,0 +1,212 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/readv_common.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReadvSocketTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ test_unix_stream_socket_[0] = -1;
+ test_unix_stream_socket_[1] = -1;
+ test_unix_dgram_socket_[0] = -1;
+ test_unix_dgram_socket_[1] = -1;
+ test_unix_seqpacket_socket_[0] = -1;
+ test_unix_seqpacket_socket_[1] = -1;
+
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_),
+ SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK),
+ SyscallSucceeds());
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_),
+ SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_),
+ SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ write(test_unix_stream_socket_[1], kReadvTestData, kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ ASSERT_THAT(
+ write(test_unix_dgram_socket_[1], kReadvTestData, kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ ASSERT_THAT(write(test_unix_seqpacket_socket_[1], kReadvTestData,
+ kReadvTestDataSize),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
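+
+    // Each dgram/seqpacket write above enqueues a single message; since
+    // readv() is one syscall, even the multi-iovec helpers consume the whole
+    // message at once.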
+ }
+
+ void TearDown() override {
+ close(test_unix_stream_socket_[0]);
+ close(test_unix_stream_socket_[1]);
+
+ close(test_unix_dgram_socket_[0]);
+ close(test_unix_dgram_socket_[1]);
+
+ close(test_unix_seqpacket_socket_[0]);
+ close(test_unix_seqpacket_socket_[1]);
+ }
+
+ int test_unix_stream_socket_[2];
+ int test_unix_dgram_socket_[2];
+ int test_unix_seqpacket_socket_[2];
+};
+
+TEST_F(ReadvSocketTest, ReadOneBufferPerByte_StreamSocket) {
+ ReadOneBufferPerByte(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadOneBufferPerByte_DgramSocket) {
+ ReadOneBufferPerByte(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadOneBufferPerByte_SeqPacketSocket) {
+ ReadOneBufferPerByte(test_unix_seqpacket_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadOneHalfAtATime_StreamSocket) {
+ ReadOneHalfAtATime(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadOneHalfAtATime_DgramSocket) {
+ ReadOneHalfAtATime(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadAllOneBuffer_StreamSocket) {
+ ReadAllOneBuffer(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadAllOneBuffer_DgramSocket) {
+ ReadAllOneBuffer(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadAllOneLargeBuffer_StreamSocket) {
+ ReadAllOneLargeBuffer(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadAllOneLargeBuffer_DgramSocket) {
+ ReadAllOneLargeBuffer(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadBuffersOverlapping_StreamSocket) {
+ ReadBuffersOverlapping(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadBuffersOverlapping_DgramSocket) {
+ ReadBuffersOverlapping(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadBuffersDiscontinuous_StreamSocket) {
+ ReadBuffersDiscontinuous(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadBuffersDiscontinuous_DgramSocket) {
+ ReadBuffersDiscontinuous(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadIovecsCompletelyFilled_StreamSocket) {
+ ReadIovecsCompletelyFilled(test_unix_stream_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, ReadIovecsCompletelyFilled_DgramSocket) {
+ ReadIovecsCompletelyFilled(test_unix_dgram_socket_[0]);
+}
+
+TEST_F(ReadvSocketTest, BadIovecsPointer_StreamSocket) {
+ ASSERT_THAT(readv(test_unix_stream_socket_[0], nullptr, 1),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvSocketTest, BadIovecsPointer_DgramSocket) {
+ ASSERT_THAT(readv(test_unix_dgram_socket_[0], nullptr, 1),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvSocketTest, BadIovecBase_StreamSocket) {
+ struct iovec iov[1];
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 1024;
+ ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvSocketTest, BadIovecBase_DgramSocket) {
+ struct iovec iov[1];
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 1024;
+ ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(ReadvSocketTest, ZeroIovecs_StreamSocket) {
+ struct iovec iov[1];
+ iov[0].iov_base = 0;
+ iov[0].iov_len = 0;
+ ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1), SyscallSucceeds());
+}
+
+TEST_F(ReadvSocketTest, ZeroIovecs_DgramSocket) {
+ struct iovec iov[1];
+ iov[0].iov_base = 0;
+ iov[0].iov_len = 0;
+ ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1), SyscallSucceeds());
+}
+
+TEST_F(ReadvSocketTest, WouldBlock_StreamSocket) {
+ struct iovec iov[1];
+ iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ iov[0].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ free(iov[0].iov_base);
+
+ iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ ASSERT_THAT(readv(test_unix_stream_socket_[0], iov, 1),
+ SyscallFailsWithErrno(EAGAIN));
+ free(iov[0].iov_base);
+}
+
+TEST_F(ReadvSocketTest, WouldBlock_DgramSocket) {
+ struct iovec iov[1];
+ iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ iov[0].iov_len = kReadvTestDataSize;
+ ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1),
+ SyscallSucceedsWithValue(kReadvTestDataSize));
+ free(iov[0].iov_base);
+
+ iov[0].iov_base = reinterpret_cast<char*>(malloc(kReadvTestDataSize));
+ ASSERT_THAT(readv(test_unix_dgram_socket_[0], iov, 1),
+ SyscallFailsWithErrno(EAGAIN));
+ free(iov[0].iov_base);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc
new file mode 100644
index 000000000..833c0dc4f
--- /dev/null
+++ b/test/syscalls/linux/rename.cc
@@ -0,0 +1,394 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/stat.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(RenameTest, RootToAnything) {
+ ASSERT_THAT(rename("/", "/bin"), SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(RenameTest, AnythingToRoot) {
+ ASSERT_THAT(rename("/bin", "/"), SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(RenameTest, SourceIsAncestorOfTarget) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto subdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+ ASSERT_THAT(rename(dir.path().c_str(), subdir.path().c_str()),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Try an even deeper directory.
+ auto deep_subdir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(subdir.path()));
+ ASSERT_THAT(rename(dir.path().c_str(), deep_subdir.path().c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(RenameTest, TargetIsAncestorOfSource) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto subdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+ ASSERT_THAT(rename(subdir.path().c_str(), dir.path().c_str()),
+ SyscallFailsWithErrno(ENOTEMPTY));
+
+ // Try an even deeper directory.
+ auto deep_subdir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(subdir.path()));
+ ASSERT_THAT(rename(deep_subdir.path().c_str(), dir.path().c_str()),
+ SyscallFailsWithErrno(ENOTEMPTY));
+}
+
+TEST(RenameTest, FileToSelf) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ EXPECT_THAT(rename(f.path().c_str(), f.path().c_str()), SyscallSucceeds());
+}
+
+TEST(RenameTest, DirectoryToSelf) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(rename(f.path().c_str(), f.path().c_str()), SyscallSucceeds());
+}
+
+TEST(RenameTest, FileToSameDirectory) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ std::string const newpath = NewTempAbsPath();
+ ASSERT_THAT(rename(f.path().c_str(), newpath.c_str()), SyscallSucceeds());
+ std::string const oldpath = f.release();
+ f.reset(newpath);
+ EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, DirectoryToSameDirectory) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string const newpath = NewTempAbsPath();
+ ASSERT_THAT(rename(dir.path().c_str(), newpath.c_str()), SyscallSucceeds());
+ std::string const oldpath = dir.release();
+ dir.reset(newpath);
+ EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, FileToParentDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path()));
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir2.path()));
+ std::string const newpath = NewTempAbsPathInDir(dir1.path());
+ ASSERT_THAT(rename(f.path().c_str(), newpath.c_str()), SyscallSucceeds());
+ std::string const oldpath = f.release();
+ f.reset(newpath);
+ EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, DirectoryToParentDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path()));
+ auto dir3 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir2.path()));
+ EXPECT_THAT(IsDirectory(dir3.path()), IsPosixErrorOkAndHolds(true));
+ std::string const newpath = NewTempAbsPathInDir(dir1.path());
+ ASSERT_THAT(rename(dir3.path().c_str(), newpath.c_str()), SyscallSucceeds());
+ std::string const oldpath = dir3.release();
+ dir3.reset(newpath);
+ EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true));
+ EXPECT_THAT(IsDirectory(newpath), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, FileToChildDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path()));
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ std::string const newpath = NewTempAbsPathInDir(dir2.path());
+ ASSERT_THAT(rename(f.path().c_str(), newpath.c_str()), SyscallSucceeds());
+ std::string const oldpath = f.release();
+ f.reset(newpath);
+ EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, DirectoryToChildDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path()));
+ auto dir3 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path()));
+ std::string const newpath = NewTempAbsPathInDir(dir2.path());
+ ASSERT_THAT(rename(dir3.path().c_str(), newpath.c_str()), SyscallSucceeds());
+ std::string const oldpath = dir3.release();
+ dir3.reset(newpath);
+ EXPECT_THAT(Exists(oldpath), IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(newpath), IsPosixErrorOkAndHolds(true));
+ EXPECT_THAT(IsDirectory(newpath), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, DirectoryToOwnChildDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir1.path()));
+ std::string const newpath = NewTempAbsPathInDir(dir2.path());
+ ASSERT_THAT(rename(dir1.path().c_str(), newpath.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(RenameTest, FileOverwritesFile) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir.path(), "first", TempPath::kDefaultFileMode));
+ auto f2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir.path(), "second", TempPath::kDefaultFileMode));
+ ASSERT_THAT(rename(f1.path().c_str(), f2.path().c_str()), SyscallSucceeds());
+ EXPECT_THAT(Exists(f1.path()), IsPosixErrorOkAndHolds(false));
+
+ f1.release();
+ std::string f2_contents;
+ ASSERT_NO_ERRNO(GetContents(f2.path(), &f2_contents));
+ EXPECT_EQ("first", f2_contents);
+}
+
+TEST(RenameTest, DirectoryOverwritesDirectoryLinkCount) {
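+  // A fresh directory has a link count of 2 (its entry in the parent plus its
+  // own "." entry); each subdirectory adds one more via its ".." entry.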
+ auto parent1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2));
+
+ auto parent2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(Links(parent2.path()), IsPosixErrorOkAndHolds(2));
+
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent1.path()));
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent2.path()));
+
+ EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(3));
+ EXPECT_THAT(Links(parent2.path()), IsPosixErrorOkAndHolds(3));
+
+ ASSERT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()),
+ SyscallSucceeds());
+
+ EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2));
+ EXPECT_THAT(Links(parent2.path()), IsPosixErrorOkAndHolds(3));
+}
+
+TEST(RenameTest, FileDoesNotExist) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string source = JoinPath(dir.path(), "source");
+ const std::string dest = JoinPath(dir.path(), "dest");
+ ASSERT_THAT(rename(source.c_str(), dest.c_str()),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(RenameTest, FileDoesNotOverwriteDirectory) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(rename(f.path().c_str(), dir.path().c_str()),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(RenameTest, DirectoryDoesNotOverwriteFile) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(rename(dir.path().c_str(), f.path().c_str()),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(RenameTest, DirectoryOverwritesEmptyDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()),
+ SyscallSucceeds());
+ EXPECT_THAT(Exists(dir1.path()), IsPosixErrorOkAndHolds(false));
+ dir1.release();
+ EXPECT_THAT(Exists(JoinPath(dir2.path(), Basename(f.path()))),
+ IsPosixErrorOkAndHolds(true));
+ f.release();
+}
+
+TEST(RenameTest, FailsWithDots) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto dir1_dot = absl::StrCat(dir1.path(), "/.");
+ auto dir2_dot = absl::StrCat(dir2.path(), "/.");
+ auto dir1_dot_dot = absl::StrCat(dir1.path(), "/..");
+ auto dir2_dot_dot = absl::StrCat(dir2.path(), "/..");
+
+  // Try with dot paths in the first argument.
+ EXPECT_THAT(rename(dir1_dot.c_str(), dir2.path().c_str()),
+ SyscallFailsWithErrno(EBUSY));
+ EXPECT_THAT(rename(dir1_dot_dot.c_str(), dir2.path().c_str()),
+ SyscallFailsWithErrno(EBUSY));
+
+  // Try with dot paths in the second argument.
+ EXPECT_THAT(rename(dir1.path().c_str(), dir2_dot.c_str()),
+ SyscallFailsWithErrno(EBUSY));
+ EXPECT_THAT(rename(dir1.path().c_str(), dir2_dot_dot.c_str()),
+ SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(RenameTest, DirectoryDoesNotOverwriteNonemptyDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir2.path()));
+ ASSERT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()),
+ SyscallFailsWithErrno(ENOTEMPTY));
+}
+
+TEST(RenameTest, FailsWhenOldParentNotWritable) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ // dir1 is not writable.
+ ASSERT_THAT(chmod(dir1.path().c_str(), 0555), SyscallSucceeds());
+
+ std::string const newpath = NewTempAbsPathInDir(dir2.path());
+ EXPECT_THAT(rename(f1.path().c_str(), newpath.c_str()),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(RenameTest, FailsWhenNewParentNotWritable) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ // dir2 is not writable.
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555));
+
+ std::string const newpath = NewTempAbsPathInDir(dir2.path());
+ EXPECT_THAT(rename(f1.path().c_str(), newpath.c_str()),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// Equivalent to FailsWhenNewParentNotWritable, but with a destination file
+// to overwrite.
+TEST(RenameTest, OverwriteFailsWhenNewParentNotWritable) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+
+ // dir2 is not writable.
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir2.path()));
+ ASSERT_THAT(chmod(dir2.path().c_str(), 0555), SyscallSucceeds());
+
+ EXPECT_THAT(rename(f1.path().c_str(), f2.path().c_str()),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// If the parent directory of source is not searchable (no execute bit), rename
+// returns EACCES because the caller cannot even determine whether source
+// exists.
+TEST(RenameTest, FileDoesNotExistWhenNewParentNotExecutable) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ // No execute permission.
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0400));
+
+ const std::string source = JoinPath(dir.path(), "source");
+ const std::string dest = JoinPath(dir.path(), "dest");
+ ASSERT_THAT(rename(source.c_str(), dest.c_str()),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(RenameTest, DirectoryWithOpenFdOverwritesEmptyDirectory) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+  // Get an fd on dir1.
+ int fd;
+ ASSERT_THAT(fd = open(dir1.path().c_str(), O_DIRECTORY), SyscallSucceeds());
+  auto close_dir1 = Cleanup([fd] {
+    // Close the fd on dir1.
+    EXPECT_THAT(close(fd), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(rename(dir1.path().c_str(), dir2.path().c_str()),
+ SyscallSucceeds());
+
+ const std::string new_f_path = JoinPath(dir2.path(), Basename(f.path()));
+
+ auto remove_f = Cleanup([&] {
+ // Delete f in its new location.
+ ASSERT_NO_ERRNO(Delete(new_f_path));
+ f.release();
+ });
+
+ EXPECT_THAT(Exists(dir1.path()), IsPosixErrorOkAndHolds(false));
+ dir1.release();
+ EXPECT_THAT(Exists(new_f_path), IsPosixErrorOkAndHolds(true));
+}
+
+TEST(RenameTest, FileWithOpenFd) {
+ TempPath root_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath dir1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path()));
+ TempPath dir2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path()));
+ TempPath dir3 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path()));
+
+ // Create file in dir1.
+ constexpr char kContents[] = "foo";
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir1.path(), kContents, TempPath::kDefaultFileMode));
+
+ // Get fd on file.
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
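+  // The fd refers to the underlying inode rather than the path, so it should
+  // remain readable across the renames below.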
+
+ // Move f to dir2.
+ const std::string path2 = NewTempAbsPathInDir(dir2.path());
+ ASSERT_THAT(rename(f.path().c_str(), path2.c_str()), SyscallSucceeds());
+
+  // Read back kContents through the still-open fd.
+ char buf[sizeof(kContents)];
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(kContents), 0),
+ SyscallSucceedsWithValue(sizeof(kContents) - 1));
+ EXPECT_EQ(absl::string_view(buf, sizeof(buf) - 1), kContents);
+
+ // Move f to dir3.
+ const std::string path3 = NewTempAbsPathInDir(dir3.path());
+ ASSERT_THAT(rename(path2.c_str(), path3.c_str()), SyscallSucceeds());
+
+  // Read back kContents again after the second rename.
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(kContents), 0),
+ SyscallSucceedsWithValue(sizeof(kContents) - 1));
+ EXPECT_EQ(absl::string_view(buf, sizeof(buf) - 1), kContents);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rlimits.cc b/test/syscalls/linux/rlimits.cc
new file mode 100644
index 000000000..860f0f688
--- /dev/null
+++ b/test/syscalls/linux/rlimits.cc
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/resource.h>
+#include <sys/time.h>
+
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(RlimitTest, SetRlimitHigher) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE)));
+
+ struct rlimit rl = {};
+ EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+
+ // Lower the rlimit first, as it may be equal to /proc/sys/fs/nr_open, in
+ // which case even users with CAP_SYS_RESOURCE can't raise it.
+ rl.rlim_cur--;
+ rl.rlim_max--;
+ ASSERT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+
+ rl.rlim_max++;
+ EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+}
+
+TEST(RlimitTest, UnprivilegedSetRlimit) {
+ // Drop privileges if necessary.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) {
+ EXPECT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, false));
+ }
+
+ struct rlimit rl = {};
+ rl.rlim_cur = 1000;
+ rl.rlim_max = 20000;
+ EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+
+ struct rlimit rl2 = {};
+ EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl2), SyscallSucceeds());
+ EXPECT_EQ(rl.rlim_cur, rl2.rlim_cur);
+ EXPECT_EQ(rl.rlim_max, rl2.rlim_max);
+
+ rl.rlim_max = 100000;
+ EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallFailsWithErrno(EPERM));
+}
+
+TEST(RlimitTest, SetSoftRlimitAboveHard) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE)));
+
+ struct rlimit rl = {};
+ EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+
+ rl.rlim_cur = rl.rlim_max + 1;
+ EXPECT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc
new file mode 100644
index 000000000..4bfb1ff56
--- /dev/null
+++ b/test/syscalls/linux/rseq.cc
@@ -0,0 +1,198 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/rseq/test.h"
+#include "test/syscalls/linux/rseq/uapi.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Syscall test for rseq (restartable sequences).
+//
+// We must be very careful about how these tests are written. Each thread may
+// only have one struct rseq registration, which may be done automatically at
+// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does
+// not do so, but other libraries do).
+//
+// Testing of rseq is thus done primarily in a child process with no
+// registration. This means exec'ing a nostdlib binary, as rseq registration
+// can only be cleared by execve (or by knowing the old rseq address), and
+// glibc (based on the current unmerged patches) registers rseq before calling
+// main().
+
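+// Thin wrapper for the rseq syscall; glibc provides no rseq() wrapper, so the
+// raw syscall number must be used.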
+int RSeq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) {
+ return syscall(kRseqSyscall, rseq, rseq_len, flags, sig);
+}
+
+// Returns true if this kernel supports the rseq syscall.
+PosixErrorOr<bool> RSeqSupported() {
+ // We have to be careful here, there are three possible cases:
+ //
+ // 1. rseq is not supported -> ENOSYS
+ // 2. rseq is supported and not registered -> success, but we should
+ // unregister.
+ // 3. rseq is supported and registered -> EINVAL (most likely).
+
+ // The only validation done on new registrations is that rseq is aligned and
+ // writable.
+ rseq rseq = {};
+ int ret = RSeq(&rseq, sizeof(rseq), 0, 0);
+ if (ret == 0) {
+ // Successfully registered, rseq is supported. Unregister.
+ ret = RSeq(&rseq, sizeof(rseq), kRseqFlagUnregister, 0);
+ if (ret != 0) {
+ return PosixError(errno);
+ }
+ return true;
+ }
+
+ switch (errno) {
+ case ENOSYS:
+ // Not supported.
+ return false;
+ case EINVAL:
+ // Supported, but already registered. EINVAL returned because we provided
+ // a different address.
+ return true;
+ default:
+ // Unknown error.
+ return PosixError(errno);
+ }
+}
+
+constexpr char kRseqBinary[] = "test/syscalls/linux/rseq/rseq";
+
+void RunChildTest(std::string test_case, int want_status) {
+ std::string path = RunfilePath(kRseqBinary);
+
+ pid_t child_pid = -1;
+ int execve_errno = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(path, {path, test_case}, {}, &child_pid, &execve_errno));
+
+ ASSERT_GT(child_pid, 0);
+ ASSERT_EQ(execve_errno, 0);
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
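+  // want_status is compared against the raw wait status: 0 for a clean exit,
+  // or the bare signal number (e.g. SIGSEGV) for an uncored signal death.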
+ ASSERT_EQ(status, want_status);
+}
+
+// Test that rseq must be aligned.
+TEST(RseqTest, Unaligned) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestUnaligned, 0);
+}
+
+// Sanity test that registration works.
+TEST(RseqTest, Register) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestRegister, 0);
+}
+
+// Registration can't be done twice.
+TEST(RseqTest, DoubleRegister) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestDoubleRegister, 0);
+}
+
+// Registration can be done again after unregister.
+TEST(RseqTest, RegisterUnregister) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestRegisterUnregister, 0);
+}
+
+// The pointer to rseq must match on register/unregister.
+TEST(RseqTest, UnregisterDifferentPtr) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestUnregisterDifferentPtr, 0);
+}
+
+// The signature must match on register/unregister.
+TEST(RseqTest, UnregisterDifferentSignature) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestUnregisterDifferentSignature, 0);
+}
+
+// The CPU ID is initialized.
+TEST(RseqTest, CPU) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestCPU, 0);
+}
+
+// Critical section is eventually aborted.
+TEST(RseqTest, Abort) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbort, 0);
+}
+
+// Abort may be before the critical section.
+TEST(RseqTest, AbortBefore) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortBefore, 0);
+}
+
+// Signature must match.
+TEST(RseqTest, AbortSignature) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortSignature, SIGSEGV);
+}
+
+// Abort must not be in the critical section.
+TEST(RseqTest, AbortPreCommit) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortPreCommit, SIGSEGV);
+}
+
+// rseq.rseq_cs is cleared on abort.
+TEST(RseqTest, AbortClearsCS) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortClearsCS, 0);
+}
+
+// rseq.rseq_cs is cleared on abort outside of critical section.
+TEST(RseqTest, InvalidAbortClearsCS) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestInvalidAbortClearsCS, 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD
new file mode 100644
index 000000000..853258b04
--- /dev/null
+++ b/test/syscalls/linux/rseq/BUILD
@@ -0,0 +1,61 @@
+# This package contains a standalone rseq test binary. This binary must not
+# depend on libc, which might use rseq itself.
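+#
+# The binary is built with a genrule rather than cc_binary so that the exact
+# compile and link line (-static, -nostdlib, -ffreestanding) stays under full
+# control.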
+
+load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain", "select_arch")
+
+package(licenses = ["notice"])
+
+genrule(
+ name = "rseq_binary",
+ srcs = [
+ "critical.h",
+ "critical_amd64.S",
+ "critical_arm64.S",
+ "rseq.cc",
+ "syscalls.h",
+ "start_amd64.S",
+ "start_arm64.S",
+ "test.h",
+ "types.h",
+ "uapi.h",
+ ],
+ outs = ["rseq"],
+ cmd = "$(CC) " +
+ "$(CC_FLAGS) " +
+ "-I. " +
+ "-Wall " +
+ "-Werror " +
+ "-O2 " +
+ "-std=c++17 " +
+ "-static " +
+ "-nostdlib " +
+ "-ffreestanding " +
+ "-o " +
+ "$(location rseq) " +
+ select_arch(
+ amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ",
+ arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ",
+ no_match_error = "unsupported architecture",
+ ) +
+ "$(location rseq.cc)",
+ toolchains = [
+ cc_toolchain,
+ ":no_pie_cc_flags",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+cc_flags_supplier(
+ name = "no_pie_cc_flags",
+ features = ["-pie"],
+)
+
+cc_library(
+ name = "lib",
+ testonly = 1,
+ hdrs = [
+ "test.h",
+ "uapi.h",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/test/syscalls/linux/rseq/critical.h b/test/syscalls/linux/rseq/critical.h
new file mode 100644
index 000000000..ac987a25e
--- /dev/null
+++ b/test/syscalls/linux/rseq/critical.h
@@ -0,0 +1,39 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_
+
+#include "test/syscalls/linux/rseq/types.h"
+#include "test/syscalls/linux/rseq/uapi.h"
+
+constexpr uint32_t kRseqSignature = 0x90909090;
+
+extern "C" {
+
+extern void rseq_loop(struct rseq* r, struct rseq_cs* cs);
+extern void* rseq_loop_early_abort;
+extern void* rseq_loop_start;
+extern void* rseq_loop_pre_commit;
+extern void* rseq_loop_post_commit;
+extern void* rseq_loop_abort;
+
+extern int rseq_getpid(struct rseq* r, struct rseq_cs* cs);
+extern void* rseq_getpid_start;
+extern void* rseq_getpid_post_commit;
+extern void* rseq_getpid_abort;
+
+} // extern "C"
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_
diff --git a/test/syscalls/linux/rseq/critical_amd64.S b/test/syscalls/linux/rseq/critical_amd64.S
new file mode 100644
index 000000000..8c0687e6d
--- /dev/null
+++ b/test/syscalls/linux/rseq/critical_amd64.S
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Restartable sequences critical sections.
+
+// Loops continuously until aborted.
+//
+// void rseq_loop(struct rseq* r, struct rseq_cs* cs)
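+//
+// On abort, the kernel verifies that the 4 bytes immediately preceding
+// abort_ip match the signature passed to rseq(2); hence the .byte signature
+// blocks placed before each abort label below.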
+
+ .text
+ .globl rseq_loop
+ .type rseq_loop, @function
+
+rseq_loop:
+ jmp begin
+
+ // Abort block before the critical section.
+ // Abort signature is 4 nops for simplicity.
+ .byte 0x90, 0x90, 0x90, 0x90
+ .globl rseq_loop_early_abort
+rseq_loop_early_abort:
+ ret
+
+begin:
+ // r->rseq_cs = cs
+ movq %rsi, 8(%rdi)
+
+ // N.B. rseq_cs will be cleared by any preempt, even outside the critical
+ // section. Thus it must be set in or immediately before the critical section
+ // to ensure it is not cleared before the section begins.
+ .globl rseq_loop_start
+rseq_loop_start:
+ jmp rseq_loop_start
+
+ // "Pre-commit": extra instructions inside the critical section. These are
+ // used as the abort point in TestAbortPreCommit, which is not valid.
+ .globl rseq_loop_pre_commit
+rseq_loop_pre_commit:
+  // Extra abort signature + nop for a future TestAbortPostCommit (not
+  // included in this change).
+ .byte 0x90, 0x90, 0x90, 0x90
+ nop
+
+ // "Post-commit": never reached in this case.
+ .globl rseq_loop_post_commit
+rseq_loop_post_commit:
+
+ // Abort signature is 4 nops for simplicity.
+ .byte 0x90, 0x90, 0x90, 0x90
+
+ .globl rseq_loop_abort
+rseq_loop_abort:
+ ret
+
+ .size rseq_loop,.-rseq_loop
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/critical_arm64.S b/test/syscalls/linux/rseq/critical_arm64.S
new file mode 100644
index 000000000..bfe7e8307
--- /dev/null
+++ b/test/syscalls/linux/rseq/critical_arm64.S
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Restartable sequences critical sections.
+
+// Loops continuously until aborted.
+//
+// void rseq_loop(struct rseq* r, struct rseq_cs* cs)
+
+ .text
+ .globl rseq_loop
+ .type rseq_loop, @function
+
+rseq_loop:
+ b begin
+
+ // Abort block before the critical section.
+ // Abort signature.
+ .byte 0x90, 0x90, 0x90, 0x90
+ .globl rseq_loop_early_abort
+rseq_loop_early_abort:
+ ret
+
+begin:
+ // r->rseq_cs = cs
+ str x1, [x0, #8]
+
+ // N.B. rseq_cs will be cleared by any preempt, even outside the critical
+ // section. Thus it must be set in or immediately before the critical section
+ // to ensure it is not cleared before the section begins.
+ .globl rseq_loop_start
+rseq_loop_start:
+ b rseq_loop_start
+
+ // "Pre-commit": extra instructions inside the critical section. These are
+ // used as the abort point in TestAbortPreCommit, which is not valid.
+ .globl rseq_loop_pre_commit
+rseq_loop_pre_commit:
+  // Extra abort signature + nop for a future TestAbortPostCommit (not
+  // included in this change).
+ .byte 0x90, 0x90, 0x90, 0x90
+ nop
+
+ // "Post-commit": never reached in this case.
+ .globl rseq_loop_post_commit
+rseq_loop_post_commit:
+
+ // Abort signature.
+ .byte 0x90, 0x90, 0x90, 0x90
+
+ .globl rseq_loop_abort
+rseq_loop_abort:
+ ret
+
+ .size rseq_loop,.-rseq_loop
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc
new file mode 100644
index 000000000..f036db26d
--- /dev/null
+++ b/test/syscalls/linux/rseq/rseq.cc
@@ -0,0 +1,366 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/rseq/critical.h"
+#include "test/syscalls/linux/rseq/syscalls.h"
+#include "test/syscalls/linux/rseq/test.h"
+#include "test/syscalls/linux/rseq/types.h"
+#include "test/syscalls/linux/rseq/uapi.h"
+
+namespace gvisor {
+namespace testing {
+
+extern "C" int main(int argc, char** argv, char** envp);
+
+// Standalone initialization before calling main().
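+//
+// On Linux, the initial stack holds argc at sp[0], followed by the argv
+// pointers, a null terminator, and then the envp pointers.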
+extern "C" void __init(uintptr_t* sp) {
+ int argc = sp[0];
+ char** argv = reinterpret_cast<char**>(&sp[1]);
+ char** envp = &argv[argc + 1];
+
+ // Call main() and exit.
+ sys_exit_group(main(argc, argv, envp));
+
+ // sys_exit_group does not return
+}
+
+int strcmp(const char* s1, const char* s2) {
+ const unsigned char* p1 = reinterpret_cast<const unsigned char*>(s1);
+ const unsigned char* p2 = reinterpret_cast<const unsigned char*>(s2);
+
+ while (*p1 == *p2) {
+ if (!*p1) {
+ return 0;
+ }
+ ++p1;
+ ++p2;
+ }
+ return static_cast<int>(*p1) - static_cast<int>(*p2);
+}
+
+int sys_rseq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) {
+ return raw_syscall(kRseqSyscall, rseq, rseq_len, flags, sig);
+}
+
+// Test that rseq must be aligned.
+int TestUnaligned() {
+ constexpr uintptr_t kRequiredAlignment = alignof(rseq);
+
+ char buf[2 * kRequiredAlignment] = {};
+ uintptr_t ptr = reinterpret_cast<uintptr_t>(&buf[0]);
+ if ((ptr & (kRequiredAlignment - 1)) == 0) {
+ // buf is already aligned. Misalign it.
+ ptr++;
+ }
+
+ int ret = sys_rseq(reinterpret_cast<rseq*>(ptr), sizeof(rseq), 0, 0);
+ if (sys_errno(ret) != EINVAL) {
+ return 1;
+ }
+ return 0;
+}
+
+// Sanity test that registration works.
+int TestRegister() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+ return 0;
+}
+
+// Registration can't be done twice.
+int TestDoubleRegister() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) {
+ return 1;
+ }
+
+ return 0;
+}
+
+// Registration can be done again after unregister.
+int TestRegisterUnregister() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ return 0;
+}
+
+// The pointer to rseq must match on register/unregister.
+int TestUnregisterDifferentPtr() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq r2 = {};
+ if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0);
+ sys_errno(ret) != EINVAL) {
+ return 1;
+ }
+
+ return 0;
+}
+
+// The signature must match on register/unregister.
+int TestUnregisterDifferentSignature() {
+ constexpr int kSignature = 0;
+
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1);
+ sys_errno(ret) != EPERM) {
+ return 1;
+ }
+
+ return 0;
+}
+
+// The CPU ID is initialized.
+int TestCPU() {
+ struct rseq r = {};
+ r.cpu_id = kRseqCPUIDUninitialized;
+
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+  // cpu_id and cpu_id_start are unsigned; cast so that the uninitialized
+  // value (-1) compares as negative.
+  if (static_cast<int32_t>(__atomic_load_n(&r.cpu_id, __ATOMIC_RELAXED)) < 0) {
+    return 1;
+  }
+  if (static_cast<int32_t>(
+          __atomic_load_n(&r.cpu_id_start, __ATOMIC_RELAXED)) < 0) {
+    return 1;
+  }
+
+ return 0;
+}
+
+// Critical section is eventually aborted.
+int TestAbort() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ // Loops until abort. If this returns then abort occurred.
+ rseq_loop(&r, &cs);
+
+ return 0;
+}
+
+// Abort may be before the critical section.
+int TestAbortBefore() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_early_abort);
+
+ // Loops until abort. If this returns then abort occurred.
+ rseq_loop(&r, &cs);
+
+ return 0;
+}
+
+// Signature must match.
+int TestAbortSignature() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ // Loops until abort. This should SIGSEGV on abort.
+ rseq_loop(&r, &cs);
+
+ return 1;
+}
+
+// Abort must not be in the critical section.
+int TestAbortPreCommit() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_pre_commit);
+
+ // Loops until abort. This should SIGSEGV on abort.
+ rseq_loop(&r, &cs);
+
+ return 1;
+}
+
+// rseq.rseq_cs is cleared on abort.
+int TestAbortClearsCS() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ // Loops until abort. If this returns then abort occurred.
+ rseq_loop(&r, &cs);
+
+ if (__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) {
+ return 1;
+ }
+
+ return 0;
+}
+
+// rseq.rseq_cs is cleared on abort outside of critical section.
+int TestInvalidAbortClearsCS() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ __atomic_store_n(&r.rseq_cs, &cs, __ATOMIC_RELAXED);
+
+ // When the next abort condition occurs, the kernel will clear cs once it
+ // determines we aren't in the critical section.
+ while (1) {
+ if (!__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) {
+ break;
+ }
+ }
+
+ return 0;
+}
+
+// Exit codes:
+// 0 - Pass
+// 1 - Fail
+// 2 - Missing argument
+// 3 - Unknown test case
+extern "C" int main(int argc, char** argv, char** envp) {
+ if (argc != 2) {
+ // Usage: rseq <test case>
+ return 2;
+ }
+
+ if (strcmp(argv[1], kRseqTestUnaligned) == 0) {
+ return TestUnaligned();
+ }
+ if (strcmp(argv[1], kRseqTestRegister) == 0) {
+ return TestRegister();
+ }
+ if (strcmp(argv[1], kRseqTestDoubleRegister) == 0) {
+ return TestDoubleRegister();
+ }
+ if (strcmp(argv[1], kRseqTestRegisterUnregister) == 0) {
+ return TestRegisterUnregister();
+ }
+ if (strcmp(argv[1], kRseqTestUnregisterDifferentPtr) == 0) {
+ return TestUnregisterDifferentPtr();
+ }
+ if (strcmp(argv[1], kRseqTestUnregisterDifferentSignature) == 0) {
+ return TestUnregisterDifferentSignature();
+ }
+ if (strcmp(argv[1], kRseqTestCPU) == 0) {
+ return TestCPU();
+ }
+ if (strcmp(argv[1], kRseqTestAbort) == 0) {
+ return TestAbort();
+ }
+ if (strcmp(argv[1], kRseqTestAbortBefore) == 0) {
+ return TestAbortBefore();
+ }
+ if (strcmp(argv[1], kRseqTestAbortSignature) == 0) {
+ return TestAbortSignature();
+ }
+ if (strcmp(argv[1], kRseqTestAbortPreCommit) == 0) {
+ return TestAbortPreCommit();
+ }
+ if (strcmp(argv[1], kRseqTestAbortClearsCS) == 0) {
+ return TestAbortClearsCS();
+ }
+ if (strcmp(argv[1], kRseqTestInvalidAbortClearsCS) == 0) {
+ return TestInvalidAbortClearsCS();
+ }
+
+ return 3;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rseq/start_amd64.S b/test/syscalls/linux/rseq/start_amd64.S
new file mode 100644
index 000000000..b9611b276
--- /dev/null
+++ b/test/syscalls/linux/rseq/start_amd64.S
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+ .text
+ .align 4
+ .type _start,@function
+ .globl _start
+
+_start:
+ movq %rsp,%rdi
+ call __init
+ hlt
+
+ .size _start,.-_start
+ .section .note.GNU-stack,"",@progbits
+
+ .text
+ .globl raw_syscall
+ .type raw_syscall, @function
+
+raw_syscall:
+ mov %rdi,%rax // syscall #
+ mov %rsi,%rdi // arg0
+ mov %rdx,%rsi // arg1
+ mov %rcx,%rdx // arg2
+ mov %r8,%r10 // arg3 (goes in r10 instead of rcx for system calls)
+ mov %r9,%r8 // arg4
+ mov 0x8(%rsp),%r9 // arg5
+ syscall
+ ret
+
+ .size raw_syscall,.-raw_syscall
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/start_arm64.S b/test/syscalls/linux/rseq/start_arm64.S
new file mode 100644
index 000000000..693c1c6eb
--- /dev/null
+++ b/test/syscalls/linux/rseq/start_arm64.S
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+ .text
+ .align 4
+ .type _start,@function
+ .globl _start
+
+_start:
+  mov x29, sp
+  mov x0, sp     // __init takes the initial stack pointer in x0 (AAPCS64).
+ bl __init
+ wfi
+
+ .size _start,.-_start
+ .section .note.GNU-stack,"",@progbits
+
+ .text
+ .globl raw_syscall
+ .type raw_syscall, @function
+
+raw_syscall:
+ mov x8,x0 // syscall #
+ mov x0,x1 // arg0
+ mov x1,x2 // arg1
+ mov x2,x3 // arg2
+ mov x3,x4 // arg3
+ mov x4,x5 // arg4
+ mov x5,x6 // arg5
+ svc #0
+ ret
+
+ .size raw_syscall,.-raw_syscall
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h
new file mode 100644
index 000000000..c4118e6c5
--- /dev/null
+++ b/test/syscalls/linux/rseq/syscalls.h
@@ -0,0 +1,69 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_
+
+#include "test/syscalls/linux/rseq/types.h"
+
+// Syscall numbers.
+#if defined(__x86_64__)
+constexpr int kGetpid = 39;
+constexpr int kExitGroup = 231;
+#elif defined(__aarch64__)
+constexpr int kGetpid = 172;
+constexpr int kExitGroup = 94;
+#else
+#error "Unknown architecture"
+#endif
+
+namespace gvisor {
+namespace testing {
+
+// Standalone system call interfaces.
+// Note that these are all "raw" system call interfaces which encode
+// errors by setting the return value to a small negative number.
+// Use sys_errno() to check system call return values for errors.
+
+// Maximum Linux error number.
+constexpr int kMaxErrno = 4095;
+
+// Errno values.
+#define EPERM 1
+#define EFAULT 14
+#define EBUSY 16
+#define EINVAL 22
+
+// Get the error number from a raw system call return value.
+// Returns a positive error number or 0 if there was no error.
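+// For example, a raw return of static_cast<uintptr_t>(-22) yields EINVAL (22),
+// while an in-range success value (such as a pid) yields 0.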
+static inline int sys_errno(uintptr_t rval) {
+ if (rval >= static_cast<uintptr_t>(-kMaxErrno)) {
+ return -static_cast<int>(rval);
+ }
+ return 0;
+}
+
+extern "C" uintptr_t raw_syscall(int number, ...);
+
+static inline void sys_exit_group(int status) {
+ raw_syscall(kExitGroup, status);
+}
+static inline int sys_getpid() {
+ return static_cast<int>(raw_syscall(kGetpid));
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_
diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h
new file mode 100644
index 000000000..3b7bb74b1
--- /dev/null
+++ b/test/syscalls/linux/rseq/test.h
@@ -0,0 +1,43 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_
+
+namespace gvisor {
+namespace testing {
+
+// Test cases supported by rseq binary.
+
+inline constexpr char kRseqTestUnaligned[] = "unaligned";
+inline constexpr char kRseqTestRegister[] = "register";
+inline constexpr char kRseqTestDoubleRegister[] = "double-register";
+inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister";
+inline constexpr char kRseqTestUnregisterDifferentPtr[] =
+ "unregister-different-ptr";
+inline constexpr char kRseqTestUnregisterDifferentSignature[] =
+ "unregister-different-signature";
+inline constexpr char kRseqTestCPU[] = "cpu";
+inline constexpr char kRseqTestAbort[] = "abort";
+inline constexpr char kRseqTestAbortBefore[] = "abort-before";
+inline constexpr char kRseqTestAbortSignature[] = "abort-signature";
+inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit";
+inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs";
+inline constexpr char kRseqTestInvalidAbortClearsCS[] =
+ "invalid-abort-clears-cs";
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_
diff --git a/test/syscalls/linux/rseq/types.h b/test/syscalls/linux/rseq/types.h
new file mode 100644
index 000000000..b6afe9817
--- /dev/null
+++ b/test/syscalls/linux/rseq/types.h
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_
+
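+// Fixed-width types defined via compiler builtins, since this standalone
+// binary cannot include libc headers.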
+using size_t = __SIZE_TYPE__;
+using uintptr_t = __UINTPTR_TYPE__;
+
+using uint8_t = __UINT8_TYPE__;
+using uint16_t = __UINT16_TYPE__;
+using uint32_t = __UINT32_TYPE__;
+using uint64_t = __UINT64_TYPE__;
+
+using int8_t = __INT8_TYPE__;
+using int16_t = __INT16_TYPE__;
+using int32_t = __INT32_TYPE__;
+using int64_t = __INT64_TYPE__;
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_
diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h
new file mode 100644
index 000000000..d3e60d0a4
--- /dev/null
+++ b/test/syscalls/linux/rseq/uapi.h
@@ -0,0 +1,51 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
+
+#include <stdint.h>
+
+// User-kernel ABI for restartable sequences.
+
+// Syscall numbers.
+#if defined(__x86_64__)
+constexpr int kRseqSyscall = 334;
+#elif defined(__aarch64__)
+constexpr int kRseqSyscall = 293;
+#else
+#error "Unknown architecture"
+#endif // __x86_64__
+
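+// Critical section descriptor. The kernel requires 32-byte
+// (4 * sizeof(uint64_t)) alignment for both this and struct rseq.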
+struct rseq_cs {
+ uint32_t version;
+ uint32_t flags;
+ uint64_t start_ip;
+ uint64_t post_commit_offset;
+ uint64_t abort_ip;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
+
+// N.B. alignment is enforced by the kernel.
+struct rseq {
+ uint32_t cpu_id_start;
+ uint32_t cpu_id;
+ struct rseq_cs* rseq_cs;
+ uint32_t flags;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
+
+constexpr int kRseqFlagUnregister = 1 << 0;
+
+constexpr int kRseqCPUIDUninitialized = -1;
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc
new file mode 100644
index 000000000..ed27e2566
--- /dev/null
+++ b/test/syscalls/linux/rtsignal.cc
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <csignal>
+
+#include "gtest/gtest.h"
+#include "test/util/cleanup.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// saved_info is set by the handler.
+siginfo_t saved_info;
+
+// has_saved_info is set to true by the handler.
+volatile bool has_saved_info;
+
+void SigHandler(int sig, siginfo_t* info, void* context) {
+ // Copy to the given info.
+ saved_info = *info;
+ has_saved_info = true;
+}
+
+void ClearSavedInfo() {
+ // Clear the cached info.
+ memset(&saved_info, 0, sizeof(saved_info));
+ has_saved_info = false;
+}
+
+PosixErrorOr<Cleanup> SetupSignalHandler(int sig) {
+ struct sigaction sa;
+ sa.sa_sigaction = SigHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ return ScopedSigaction(sig, sa);
+}
+
+class RtSignalTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ action_cleanup_ = ASSERT_NO_ERRNO_AND_VALUE(SetupSignalHandler(SIGUSR1));
+ mask_cleanup_ =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGUSR1));
+ }
+
+ void TearDown() override { ClearSavedInfo(); }
+
+ private:
+ Cleanup action_cleanup_;
+ Cleanup mask_cleanup_;
+};
+
+static int rt_sigqueueinfo(pid_t tgid, int sig, siginfo_t* uinfo) {
+ int ret;
+ do {
+ // NOTE(b/25434735): rt_sigqueueinfo(2) could return EAGAIN for RT signals.
+ ret = syscall(SYS_rt_sigqueueinfo, tgid, sig, uinfo);
+ } while (ret == -1 && errno == EAGAIN);
+ return ret;
+}
+
+TEST_F(RtSignalTest, InvalidTID) {
+ siginfo_t uinfo;
+ // Depending on the kernel version, these calls may fail with
+  // ESRCH (Goobuntu machines) or EPERM (production machines). Thus,
+ // the test simply ensures that they do fail.
+ EXPECT_THAT(rt_sigqueueinfo(-1, SIGUSR1, &uinfo), SyscallFails());
+ EXPECT_FALSE(has_saved_info);
+ EXPECT_THAT(rt_sigqueueinfo(0, SIGUSR1, &uinfo), SyscallFails());
+ EXPECT_FALSE(has_saved_info);
+}
+
+TEST_F(RtSignalTest, InvalidCodes) {
+ siginfo_t uinfo;
+
+ // We need a child for the code checks to apply. If the process is delivering
+ // to itself, then it can use whatever codes it wants and they will go
+ // through.
+ pid_t child = fork();
+ if (child == 0) {
+ _exit(1);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+
+ // These are not allowed for child processes.
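+  // The kernel returns EPERM for si_code values >= 0 (and for SI_TKILL) when
+  // the target is a different process, to prevent forging kernel-generated
+  // signals.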
+ uinfo.si_code = 0; // SI_USER.
+ EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo),
+ SyscallFailsWithErrno(EPERM));
+ uinfo.si_code = 0x80; // SI_KERNEL.
+ EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo),
+ SyscallFailsWithErrno(EPERM));
+ uinfo.si_code = -6; // SI_TKILL.
+ EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo),
+ SyscallFailsWithErrno(EPERM));
+ uinfo.si_code = -1; // SI_QUEUE (allowed).
+ EXPECT_THAT(rt_sigqueueinfo(child, SIGUSR1, &uinfo), SyscallSucceeds());
+
+ // Join the child process.
+ EXPECT_THAT(waitpid(child, nullptr, 0), SyscallSucceeds());
+}
+
+TEST_F(RtSignalTest, ValueDelivered) {
+ siginfo_t uinfo;
+ uinfo.si_code = -1; // SI_QUEUE (allowed).
+ uinfo.si_errno = 0x1234;
+
+ EXPECT_EQ(saved_info.si_errno, 0x0);
+ EXPECT_THAT(rt_sigqueueinfo(getpid(), SIGUSR1, &uinfo), SyscallSucceeds());
+ EXPECT_TRUE(has_saved_info);
+ EXPECT_EQ(saved_info.si_errno, 0x1234);
+}
+
+TEST_F(RtSignalTest, SignoMatch) {
+ auto action2_cleanup = ASSERT_NO_ERRNO_AND_VALUE(SetupSignalHandler(SIGUSR2));
+ auto mask2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGUSR2));
+
+ siginfo_t uinfo;
+ uinfo.si_code = -1; // SI_QUEUE (allowed).
+
+ EXPECT_THAT(rt_sigqueueinfo(getpid(), SIGUSR1, &uinfo), SyscallSucceeds());
+ EXPECT_TRUE(has_saved_info);
+ EXPECT_EQ(saved_info.si_signo, SIGUSR1);
+
+ ClearSavedInfo();
+
+ EXPECT_THAT(rt_sigqueueinfo(getpid(), SIGUSR2, &uinfo), SyscallSucceeds());
+ EXPECT_TRUE(has_saved_info);
+ EXPECT_EQ(saved_info.si_signo, SIGUSR2);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // These tests depend on delivering SIGUSR1/2 to the main thread (so they can
+ // synchronously check has_saved_info). Block these so that any other threads
+ // created by TestInit will also have them blocked.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGUSR1);
+ sigaddset(&set, SIGUSR2);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/sched.cc b/test/syscalls/linux/sched.cc
new file mode 100644
index 000000000..735e99411
--- /dev/null
+++ b/test/syscalls/linux/sched.cc
@@ -0,0 +1,71 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// In Linux, pid is limited to 29 bits because of how futex is implemented.
+constexpr int kImpossiblePID = (1 << 29) + 1;
+
+TEST(SchedGetparamTest, ReturnsZero) {
+ struct sched_param param;
+ EXPECT_THAT(sched_getparam(getpid(), &param), SyscallSucceeds());
+ EXPECT_EQ(param.sched_priority, 0);
+ EXPECT_THAT(sched_getparam(/*pid=*/0, &param), SyscallSucceeds());
+ EXPECT_EQ(param.sched_priority, 0);
+}
+
+TEST(SchedGetparamTest, InvalidPIDReturnsEINVAL) {
+ struct sched_param param;
+ EXPECT_THAT(sched_getparam(/*pid=*/-1, &param),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SchedGetparamTest, ImpossiblePIDReturnsESRCH) {
+ struct sched_param param;
+ EXPECT_THAT(sched_getparam(kImpossiblePID, &param),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(SchedGetparamTest, NullParamReturnsEINVAL) {
+ EXPECT_THAT(sched_getparam(0, nullptr), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SchedGetschedulerTest, ReturnsSchedOther) {
+ EXPECT_THAT(sched_getscheduler(getpid()),
+ SyscallSucceedsWithValue(SCHED_OTHER));
+ EXPECT_THAT(sched_getscheduler(/*pid=*/0),
+ SyscallSucceedsWithValue(SCHED_OTHER));
+}
+
+TEST(SchedGetschedulerTest, ReturnsEINVAL) {
+ EXPECT_THAT(sched_getscheduler(/*pid=*/-1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SchedGetschedulerTest, ReturnsESRCH) {
+ EXPECT_THAT(sched_getscheduler(kImpossiblePID), SyscallFailsWithErrno(ESRCH));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sched_yield.cc b/test/syscalls/linux/sched_yield.cc
new file mode 100644
index 000000000..5d24f5b58
--- /dev/null
+++ b/test/syscalls/linux/sched_yield.cc
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SchedYieldTest, Success) {
+ EXPECT_THAT(sched_yield(), SyscallSucceeds());
+ EXPECT_THAT(sched_yield(), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc
new file mode 100644
index 000000000..ce88d90dd
--- /dev/null
+++ b/test/syscalls/linux/seccomp.cc
@@ -0,0 +1,425 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <linux/audit.h>
+#include <linux/filter.h>
+#include <linux/seccomp.h>
+#include <pthread.h>
+#include <sched.h>
+#include <signal.h>
+#include <string.h>
+#include <sys/prctl.h>
+#include <sys/syscall.h>
+#include <time.h>
+#include <ucontext.h>
+#include <unistd.h>
+
+#include <atomic>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+#ifndef SYS_SECCOMP
+#define SYS_SECCOMP 1
+#endif
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// A syscall not implemented by Linux that we don't expect to be called.
+#ifdef __x86_64__
+constexpr uint32_t kFilteredSyscall = SYS_vserver;
+#elif __aarch64__
+// Use the last of the arch_specific_syscalls, which are not implemented on arm64.
+constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15;
+#endif
+
+// Applies a seccomp-bpf filter that returns `filtered_result` for
+// `sysno` and allows all other syscalls. Async-signal-safe.
+void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result,
+ uint32_t flags = 0) {
+ // "Prior to [PR_SET_SECCOMP], the task must call prctl(PR_SET_NO_NEW_PRIVS,
+ // 1) or run with CAP_SYS_ADMIN privileges in its namespace." -
+ // Documentation/prctl/seccomp_filter.txt
+ //
+ // prctl(PR_SET_NO_NEW_PRIVS, 1) may be called repeatedly; calls after the
+ // first are no-ops.
+ TEST_PCHECK(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) == 0);
+ MaybeSave();
+
+ struct sock_filter filter[] = {
+ // A = seccomp_data.arch
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
+#if defined(__x86_64__)
+ // if (A != AUDIT_ARCH_X86_64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
+#elif defined(__aarch64__)
+ // if (A != AUDIT_ARCH_AARCH64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_AARCH64, 0, 4),
+#else
+#error "Unknown architecture"
+#endif
+ // A = seccomp_data.nr
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
+ // if (A != sysno) goto allow
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
+ // return filtered_result
+ BPF_STMT(BPF_RET | BPF_K, filtered_result),
+ // allow: return SECCOMP_RET_ALLOW
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
+ // kill: return SECCOMP_RET_KILL
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
+ };
+ struct sock_fprog prog;
+ prog.len = ABSL_ARRAYSIZE(filter);
+ prog.filter = filter;
+ if (flags) {
+ TEST_CHECK(syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, flags, &prog) ==
+ 0);
+ } else {
+ TEST_PCHECK(prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, &prog, 0, 0) == 0);
+ }
+ MaybeSave();
+}
+
+// Wrapper for sigaction. Async-signal-safe.
+void RegisterSignalHandler(int signum,
+ void (*handler)(int, siginfo_t*, void*)) {
+ struct sigaction sa = {};
+ sa.sa_sigaction = handler;
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ TEST_PCHECK(sigaction(signum, &sa, nullptr) == 0);
+ MaybeSave();
+}
+
+// All of the following tests execute in a subprocess to ensure that each test
+// is run in a separate process. This avoids cross-contamination of seccomp
+// state between tests, and is necessary to ensure that test processes killed
+// by SECCOMP_RET_KILL are single-threaded (since SECCOMP_RET_KILL only kills
+// the offending thread, not the whole thread group).
+
+TEST(SeccompTest, RetKillCausesDeathBySIGSYS) {
+ pid_t const pid = fork();
+ if (pid == 0) {
+ // Register a signal handler for SIGSYS that we don't expect to be invoked.
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
+ syscall(kFilteredSyscall);
+ TEST_CHECK_MSG(false, "Survived invocation of test syscall");
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS)
+ << "status " << status;
+}
+
+TEST(SeccompTest, RetKillOnlyKillsOneThread) {
+ Mapping stack = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+
+ pid_t const pid = fork();
+ if (pid == 0) {
+ // Register a signal handler for SIGSYS that we don't expect to be invoked.
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
+ // Pass CLONE_VFORK to block the original thread in the child process until
+ // the clone thread exits with SIGSYS.
+ //
+ // N.B. clone(2) is not officially async-signal-safe, but at minimum glibc's
+ // x86_64 implementation is safe. See glibc
+ // sysdeps/unix/sysv/linux/x86_64/clone.S.
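+    //
+    // The stack grows down, so the child's initial stack pointer is the high
+    // end of the mapping (stack.endptr()).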
+ clone(
+ +[](void* arg) {
+ syscall(kFilteredSyscall); // should kill the thread
+ _exit(1); // should be unreachable
+          return 2;               // should be very unreachable; silences the compiler
+ },
+ stack.endptr(),
+ CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_THREAD | CLONE_VM |
+ CLONE_VFORK,
+ nullptr);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+TEST(SeccompTest, RetTrapCausesSIGSYS) {
+ pid_t const pid = fork();
+ if (pid == 0) {
+ constexpr uint16_t kTrapValue = 0xdead;
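+    // The low 16 bits of the filter's return value (SECCOMP_RET_DATA) are
+    // delivered to the SIGSYS handler in si_errno.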
+ RegisterSignalHandler(
+ SIGSYS, +[](int signo, siginfo_t* info, void* ucv) {
+ ucontext_t* uc = static_cast<ucontext_t*>(ucv);
+ // This is a signal handler, so we must stay async-signal-safe.
+ TEST_CHECK(info->si_signo == SIGSYS);
+ TEST_CHECK(info->si_code == SYS_SECCOMP);
+ TEST_CHECK(info->si_errno == kTrapValue);
+ TEST_CHECK(info->si_call_addr != nullptr);
+ TEST_CHECK(info->si_syscall == kFilteredSyscall);
+#if defined(__x86_64__)
+ TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64);
+ TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall);
+#elif defined(__aarch64__)
+ TEST_CHECK(info->si_arch == AUDIT_ARCH_AARCH64);
+ TEST_CHECK(uc->uc_mcontext.regs[8] == kFilteredSyscall);
+#endif // defined(__x86_64__)
+ _exit(0);
+ });
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRAP | kTrapValue);
+ syscall(kFilteredSyscall);
+ TEST_CHECK_MSG(false, "Survived invocation of test syscall");
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+#ifdef __x86_64__
+
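+// The legacy vsyscall page is mapped at a fixed address on x86_64; the
+// time() entry lives at offset 0x400 within it.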
+constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
+
+time_t vsyscall_time(time_t* t) {
+ return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t);
+}
+
+TEST(SeccompTest, SeccompAppliesToVsyscall) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled()));
+
+ pid_t const pid = fork();
+ if (pid == 0) {
+ constexpr uint16_t kTrapValue = 0xdead;
+ RegisterSignalHandler(
+ SIGSYS, +[](int signo, siginfo_t* info, void* ucv) {
+ ucontext_t* uc = static_cast<ucontext_t*>(ucv);
+ // This is a signal handler, so we must stay async-signal-safe.
+ TEST_CHECK(info->si_signo == SIGSYS);
+ TEST_CHECK(info->si_code == SYS_SECCOMP);
+ TEST_CHECK(info->si_errno == kTrapValue);
+ TEST_CHECK(info->si_call_addr != nullptr);
+ TEST_CHECK(info->si_syscall == SYS_time);
+ TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64);
+ TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == SYS_time);
+ _exit(0);
+ });
+ ApplySeccompFilter(SYS_time, SECCOMP_RET_TRAP | kTrapValue);
+ vsyscall_time(nullptr); // Should result in death.
+ TEST_CHECK_MSG(false, "Survived invocation of test syscall");
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+TEST(SeccompTest, RetKillVsyscallCausesDeathBySIGSYS) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled()));
+
+ pid_t const pid = fork();
+ if (pid == 0) {
+ // Register a signal handler for SIGSYS that we don't expect to be invoked.
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ ApplySeccompFilter(SYS_time, SECCOMP_RET_KILL);
+ vsyscall_time(nullptr); // Should result in death.
+ TEST_CHECK_MSG(false, "Survived invocation of test syscall");
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS)
+ << "status " << status;
+}
+
+#endif // defined(__x86_64__)
+
+TEST(SeccompTest, RetTraceWithoutPtracerReturnsENOSYS) {
+ pid_t const pid = fork();
+ if (pid == 0) {
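+    // With no ptracer attached, SECCOMP_RET_TRACE causes the kernel to fail
+    // the syscall with ENOSYS.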
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRACE);
+ TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOSYS);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+TEST(SeccompTest, RetErrnoReturnsErrno) {
+ pid_t const pid = fork();
+ if (pid == 0) {
+ // ENOTNAM: "Not a XENIX named type file"
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM);
+ TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOTNAM);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+TEST(SeccompTest, RetAllowAllowsSyscall) {
+ pid_t const pid = fork();
+ if (pid == 0) {
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ALLOW);
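+    // kFilteredSyscall is chosen to be an unimplemented syscall, so when the
+    // filter allows it through, the kernel itself fails it with ENOSYS.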
+ TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOSYS);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+// Validates that SECCOMP_FILTER_FLAG_TSYNC applies the filter to all threads.
+TEST(SeccompTest, TsyncAppliesToAllThreads) {
+ Mapping stack = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+
+ // We don't want to apply this policy to other test runner threads, so fork.
+ const pid_t pid = fork();
+
+ if (pid == 0) {
+    // First check that we receive ENOSYS before the policy is applied.
+ TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOSYS);
+
+ // N.B. clone(2) is not officially async-signal-safe, but at minimum glibc's
+ // x86_64 implementation is safe. See glibc
+ // sysdeps/unix/sysv/linux/x86_64/clone.S.
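+    //
+    // SECCOMP_FILTER_FLAG_TSYNC atomically applies the filter to every thread
+    // in the caller's thread group, not just the installing thread.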
+ clone(
+ +[](void* arg) {
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM,
+ SECCOMP_FILTER_FLAG_TSYNC);
+ return 0;
+ },
+ stack.endptr(),
+ CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_THREAD | CLONE_VM |
+ CLONE_VFORK,
+ nullptr);
+
+    // Because we're using CLONE_VFORK, this thread is blocked until the
+    // second thread releases its hold on our virtual memory; since we're not
+    // execing, that happens when the second thread calls _exit.
+
+    // Now verify that the policy was applied to this thread too.
+ TEST_CHECK(syscall(kFilteredSyscall) == -1 && errno == ENOTNAM);
+ _exit(0);
+ }
+
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status = 0;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+// Validates that seccomp(2) rejects unsupported flags.
+TEST(SeccompTest, SeccompRejectsUnknownFlags) {
+ constexpr uint32_t kInvalidFlag = 123;
+ ASSERT_THAT(
+ syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, kInvalidFlag, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SeccompTest, LeastPermissiveFilterReturnValueApplies) {
+  // This is RetKillCausesDeathBySIGSYS, plus extra filters installed before
+  // and after the killing one whose return values should be ignored.
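+  //
+  // When multiple filters are installed, every filter runs and the most
+  // restrictive return action wins; the precedence from most to least
+  // restrictive is KILL, TRAP, ERRNO, TRACE, ALLOW.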
+ pid_t const pid = fork();
+ if (pid == 0) {
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRACE);
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM);
+ syscall(kFilteredSyscall);
+ TEST_CHECK_MSG(false, "Survived invocation of test syscall");
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS)
+ << "status " << status;
+}
+
+// Passed as argv[1] to cause the test binary to invoke kFilteredSyscall and
+// exit. Not a real flag since flag parsing happens during initialization,
+// which may create threads.
+constexpr char kInvokeFilteredSyscallFlag[] = "--seccomp_test_child";
+
+TEST(SeccompTest, FiltersPreservedAcrossForkAndExecve) {
+ ExecveArray const grandchild_argv(
+ {"/proc/self/exe", kInvokeFilteredSyscallFlag});
+
+ pid_t const pid = fork();
+ if (pid == 0) {
+ ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
+ pid_t const grandchild_pid = fork();
+ if (grandchild_pid == 0) {
+ execve(grandchild_argv.get()[0], grandchild_argv.get(),
+ /* envp = */ nullptr);
+ TEST_PCHECK_MSG(false, "execve failed");
+ }
+ int status;
+ TEST_PCHECK(waitpid(grandchild_pid, &status, 0) == grandchild_pid);
+ TEST_CHECK(WIFSIGNALED(status) && WTERMSIG(status) == SIGSYS);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status " << status;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ if (argc >= 2 &&
+ strcmp(argv[1], gvisor::testing::kInvokeFilteredSyscallFlag) == 0) {
+ syscall(gvisor::testing::kFilteredSyscall);
+ exit(0);
+ }
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc
new file mode 100644
index 000000000..be2364fb8
--- /dev/null
+++ b/test/syscalls/linux/select.cc
@@ -0,0 +1,168 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/resource.h>
+#include <sys/select.h>
+#include <sys/time.h>
+
+#include <climits>
+#include <csignal>
+#include <cstdio>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/base_poll_test.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/rlimit_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class SelectTest : public BasePollTest {
+ protected:
+ void SetUp() override { BasePollTest::SetUp(); }
+ void TearDown() override { BasePollTest::TearDown(); }
+};
+
+// See that when there are no FD sets, select behaves like sleep.
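+// Linux decrements the timeout in place to reflect time not slept, so a full
+// sleep leaves it at zero.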
+TEST_F(SelectTest, NullFds) {
+ struct timeval timeout = absl::ToTimeval(absl::Milliseconds(10));
+ ASSERT_THAT(select(0, nullptr, nullptr, nullptr, &timeout),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_usec, 0);
+
+ timeout = absl::ToTimeval(absl::Milliseconds(10));
+ ASSERT_THAT(select(1, nullptr, nullptr, nullptr, &timeout),
+ SyscallSucceeds());
+ EXPECT_EQ(timeout.tv_sec, 0);
+ EXPECT_EQ(timeout.tv_usec, 0);
+}
+
+TEST_F(SelectTest, NegativeNfds) {
+ EXPECT_THAT(select(-1, nullptr, nullptr, nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(select(-100000, nullptr, nullptr, nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(select(INT_MIN, nullptr, nullptr, nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(SelectTest, ClosedFds) {
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY));
+
+  // We can't rely on a file descriptor staying closed in a multi-threaded
+  // application, so fork to get a clean process.
+ EXPECT_THAT(InForkedProcess([&] {
+ int fd_num = fd.get();
+ fd.reset();
+
+ fd_set read_set;
+ FD_ZERO(&read_set);
+ FD_SET(fd_num, &read_set);
+
+ struct timeval timeout =
+ absl::ToTimeval(absl::Milliseconds(10));
+ TEST_PCHECK(select(fd_num + 1, &read_set, nullptr, nullptr,
+ &timeout) != 0);
+ TEST_PCHECK(errno == EBADF);
+ }),
+ IsPosixErrorOkAndHolds(0));
+}
+
+TEST_F(SelectTest, ZeroTimeout) {
+ struct timeval timeout = {};
+ EXPECT_THAT(select(1, nullptr, nullptr, nullptr, &timeout),
+ SyscallSucceeds());
+ // Ignore timeout as its value is now undefined.
+}
+
+// If random S/R interrupts the select, SIGALRM may be delivered before select
+// restarts, causing the select to hang forever.
+TEST_F(SelectTest, NoTimeout_NoRandomSave) {
+ // When there's no timeout, select may never return so set a timer.
+ SetTimer(absl::Milliseconds(100));
+ // See that we get interrupted by the timer.
+ ASSERT_THAT(select(1, nullptr, nullptr, nullptr, nullptr),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+}
+
+TEST_F(SelectTest, InvalidTimeoutNegative) {
+ struct timeval timeout = absl::ToTimeval(absl::Microseconds(-1));
+ EXPECT_THAT(select(1, nullptr, nullptr, nullptr, &timeout),
+ SyscallFailsWithErrno(EINVAL));
+ // Ignore timeout as its value is now undefined.
+}
+
+// Verify that a signal interrupts select.
+//
+// If random S/R interrupts the select, SIGALRM may be delivered before select
+// restarts, causing the select to hang forever.
+TEST_F(SelectTest, InterruptedBySignal_NoRandomSave) {
+ absl::Duration duration(absl::Seconds(5));
+ struct timeval timeout = absl::ToTimeval(duration);
+ SetTimer(absl::Milliseconds(100));
+ ASSERT_FALSE(TimerFired());
+ ASSERT_THAT(select(1, nullptr, nullptr, nullptr, &timeout),
+ SyscallFailsWithErrno(EINTR));
+ EXPECT_TRUE(TimerFired());
+ // Ignore timeout as its value is now undefined.
+}
+
+TEST_F(SelectTest, IgnoreBitsAboveNfds) {
+ // fd_set is a bit array with at least FD_SETSIZE bits. Test that bits
+ // corresponding to file descriptors above nfds are ignored.
+ fd_set read_set;
+ FD_ZERO(&read_set);
+ constexpr int kNfds = 1;
+ for (int fd = kNfds; fd < FD_SETSIZE; fd++) {
+ FD_SET(fd, &read_set);
+ }
+ // Pass a zero timeout so that select returns immediately.
+ struct timeval timeout = {};
+ EXPECT_THAT(select(kNfds, &read_set, nullptr, nullptr, &timeout),
+ SyscallSucceedsWithValue(0));
+}
+
+// This test illustrates Linux's behavior that 'select' calls succeed even on
+// file descriptors above a soft RLIMIT_NOFILE that was lowered after they
+// were opened. In particular, versions of sshd rely on this behavior. See
+// b/122318458.
+TEST_F(SelectTest, SetrlimitCallNOFILE) {
+ fd_set read_set;
+ FD_ZERO(&read_set);
+ timeval timeout = {};
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(NewTempAbsPath(), O_RDONLY | O_CREAT, S_IRUSR));
+
+ Cleanup reset_rlimit =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_NOFILE, 0));
+
+ FD_SET(fd.get(), &read_set);
+  // This call with a zero timeout should return immediately.
+ EXPECT_THAT(select(fd.get() + 1, &read_set, nullptr, nullptr, &timeout),
+ SyscallSucceeds());
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
new file mode 100644
index 000000000..e9b131ca9
--- /dev/null
+++ b/test/syscalls/linux/semaphore.cc
@@ -0,0 +1,491 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/ipc.h>
+#include <sys/sem.h>
+#include <sys/types.h>
+
+#include <atomic>
+#include <cerrno>
+#include <ctime>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/memory/memory.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class AutoSem {
+ public:
+ explicit AutoSem(int id) : id_(id) {}
+ ~AutoSem() {
+ if (id_ >= 0) {
+ EXPECT_THAT(semctl(id_, 0, IPC_RMID), SyscallSucceeds());
+ }
+ }
+
+ int release() {
+ int old = id_;
+ id_ = -1;
+ return old;
+ }
+
+ int get() { return id_; }
+
+ private:
+ int id_ = -1;
+};
+
+TEST(SemaphoreTest, SemGet) {
+ // Test creation and lookup.
+ AutoSem sem(semget(1, 10, IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ EXPECT_THAT(semget(1, 10, IPC_CREAT), SyscallSucceedsWithValue(sem.get()));
+ EXPECT_THAT(semget(1, 9, IPC_CREAT), SyscallSucceedsWithValue(sem.get()));
+
+ // Creation and lookup failure cases.
+ EXPECT_THAT(semget(1, 11, IPC_CREAT), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(semget(1, -1, IPC_CREAT), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(semget(1, 10, IPC_CREAT | IPC_EXCL),
+ SyscallFailsWithErrno(EEXIST));
+ EXPECT_THAT(semget(2, 1, 0), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(semget(2, 0, IPC_CREAT), SyscallFailsWithErrno(EINVAL));
+
+ // Private semaphores never conflict.
+ AutoSem sem2(semget(IPC_PRIVATE, 1, 0));
+ AutoSem sem3(semget(IPC_PRIVATE, 1, 0));
+ ASSERT_THAT(sem2.get(), SyscallSucceeds());
+ EXPECT_NE(sem.get(), sem2.get());
+ ASSERT_THAT(sem3.get(), SyscallSucceeds());
+ EXPECT_NE(sem3.get(), sem2.get());
+}
+
+// Tests simple operations that shouldn't block in a single thread.
+TEST(SemaphoreTest, SemOpSingleNoBlock) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+
+ buf.sem_op = -1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+
+ buf.sem_op = 0;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+
+ // Error cases with invalid values.
+ ASSERT_THAT(semop(sem.get() + 1, &buf, 1), SyscallFailsWithErrno(EINVAL));
+
+ buf.sem_num = 1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EFBIG));
+
+ ASSERT_THAT(semop(sem.get(), nullptr, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+// Tests multiple operations that shouldn't block in a single thread.
+TEST(SemaphoreTest, SemOpMultiNoBlock) {
+ AutoSem sem(semget(IPC_PRIVATE, 4, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ struct sembuf bufs[5] = {};
+ bufs[0].sem_num = 0;
+ bufs[0].sem_op = 10;
+ bufs[0].sem_flg = 0;
+
+ bufs[1].sem_num = 1;
+ bufs[1].sem_op = 2;
+ bufs[1].sem_flg = 0;
+
+ bufs[2].sem_num = 2;
+ bufs[2].sem_op = 3;
+ bufs[2].sem_flg = 0;
+
+ bufs[3].sem_num = 0;
+ bufs[3].sem_op = -5;
+ bufs[3].sem_flg = 0;
+
+ bufs[4].sem_num = 2;
+ bufs[4].sem_op = 2;
+ bufs[4].sem_flg = 0;
+
+ ASSERT_THAT(semop(sem.get(), bufs, ABSL_ARRAYSIZE(bufs)), SyscallSucceeds());
+
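+  // Expected values: sem 0 = 10 - 5 = 5, sem 1 = 2, sem 2 = 3 + 2 = 5, and
+  // sem 3 (untouched) = 0.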
+ ASSERT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(5));
+ ASSERT_THAT(semctl(sem.get(), 1, GETVAL), SyscallSucceedsWithValue(2));
+ ASSERT_THAT(semctl(sem.get(), 2, GETVAL), SyscallSucceedsWithValue(5));
+ ASSERT_THAT(semctl(sem.get(), 3, GETVAL), SyscallSucceedsWithValue(0));
+
+ for (auto& b : bufs) {
+ b.sem_op = -b.sem_op;
+ }
+  // The order of ops 0 and 3 must be reversed; otherwise semop will block.
+ std::swap(bufs[0].sem_op, bufs[3].sem_op);
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), bufs, ABSL_ARRAYSIZE(bufs)),
+ SyscallSucceeds());
+
+ // All semaphores should be back to 0 now.
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_THAT(semctl(sem.get(), i, GETVAL), SyscallSucceedsWithValue(0));
+ }
+}
+
+// Makes a best-effort attempt to ensure that the operation blocks.
+TEST(SemaphoreTest, SemOpBlock) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ std::atomic<int> blocked = ATOMIC_VAR_INIT(1);
+ ScopedThread th([&sem, &blocked] {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_EQ(blocked.load(), 1);
+
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ });
+
+ struct sembuf buf = {};
+ buf.sem_op = -1;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ blocked.store(0);
+}
+
+// Tests that IPC_NOWAIT returns immediately instead of blocking.
+TEST(SemaphoreTest, SemOpNoBlock) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ struct sembuf buf = {};
+ buf.sem_flg = IPC_NOWAIT;
+
+ buf.sem_op = -1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EAGAIN));
+
+ buf.sem_op = 1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+
+ buf.sem_op = 0;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EAGAIN));
+}
+
+// Runs two threads: one signals and the other waits the same number of times.
+TEST(SemaphoreTest, SemOpSimple) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ constexpr size_t kLoops = 100;
+ ScopedThread th([&sem] {
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ for (size_t i = 0; i < kLoops; i++) {
+      // Sleep to avoid completing all increments in one shot before the
+      // waiter ever has to block.
+ absl::SleepFor(absl::Milliseconds(1));
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+ }
+ });
+
+ struct sembuf buf = {};
+ buf.sem_op = -1;
+ for (size_t i = 0; i < kLoops; i++) {
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ }
+}
+
+// Tests that semaphore can be removed while there are waiters.
+// NoRandomSave: Test relies on timing that random save throws off.
+TEST(SemaphoreTest, SemOpRemoveWithWaiter_NoRandomSave) {
+ AutoSem sem(semget(IPC_PRIVATE, 2, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ ScopedThread th([&sem] {
+ absl::SleepFor(absl::Milliseconds(250));
+ ASSERT_THAT(semctl(sem.release(), 0, IPC_RMID), SyscallSucceeds());
+ });
+
+  // This must block before the IPC_RMID in the thread above runs; otherwise
+  // semop fails with EINVAL because the semaphore has already been removed.
+ struct sembuf buf = {};
+ buf.sem_op = -1;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1),
+ SyscallFailsWithErrno(EIDRM));
+}
+
+// Semaphores aren't fair: any waiter whose request can be satisfied is
+// executed, even if it jumps ahead of earlier waiters.
+TEST(SemaphoreTest, SemOpBestFitExecution) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ ScopedThread th([&sem] {
+ struct sembuf buf = {};
+ buf.sem_op = -2;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallFails());
+ // Ensure that wait will only unblock when the semaphore is removed. On
+ // EINTR retry it may race with deletion and return EINVAL.
+ ASSERT_TRUE(errno == EIDRM || errno == EINVAL) << "errno=" << errno;
+ });
+
+  // Ensures that the '-1' below will unblock even though the '-2' above is
+  // waiting for the same semaphore.
+ for (size_t i = 0; i < 10; ++i) {
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+
+ absl::SleepFor(absl::Milliseconds(10));
+
+ buf.sem_op = -1;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(semctl(sem.release(), 0, IPC_RMID), SyscallSucceeds());
+}
+
+// Executes random operations in multiple threads and verifies correctness.
+TEST(SemaphoreTest, SemOpRandom) {
+  // Don't do cooperative S/R tests because there are too many syscalls in
+  // this test.
+ const DisableSave ds;
+
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ // Protects the seed below.
+ absl::Mutex mutex;
+ uint32_t seed = time(nullptr);
+
+ int count = 0; // Tracks semaphore value.
+ bool done = false; // Tells waiters to stop after signal threads are done.
+
+ // These threads will wait in a loop.
+ std::unique_ptr<ScopedThread> decs[5];
+ for (auto& dec : decs) {
+ dec = absl::make_unique<ScopedThread>([&sem, &mutex, &count, &seed, &done] {
+ for (size_t i = 0; i < 500; ++i) {
+ int16_t val;
+ {
+ absl::MutexLock l(&mutex);
+ if (done) {
+ return;
+ }
+ val = (rand_r(&seed) % 10 + 1); // Rand between 1 and 10.
+ count -= val;
+ }
+ struct sembuf buf = {};
+ buf.sem_op = -val;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ absl::SleepFor(absl::Milliseconds(val * 2));
+ }
+ });
+ }
+
+ // These threads will wait for zero in a loop.
+ std::unique_ptr<ScopedThread> zeros[5];
+ for (auto& zero : zeros) {
+ zero = absl::make_unique<ScopedThread>([&sem, &mutex, &done] {
+ for (size_t i = 0; i < 500; ++i) {
+ {
+ absl::MutexLock l(&mutex);
+ if (done) {
+ return;
+ }
+ }
+ struct sembuf buf = {};
+ buf.sem_op = 0;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ absl::SleepFor(absl::Milliseconds(10));
+ }
+ });
+ }
+
+ // These threads will signal in a loop.
+ std::unique_ptr<ScopedThread> incs[5];
+ for (auto& inc : incs) {
+ inc = absl::make_unique<ScopedThread>([&sem, &mutex, &count, &seed] {
+ for (size_t i = 0; i < 500; ++i) {
+ int16_t val;
+ {
+ absl::MutexLock l(&mutex);
+ val = (rand_r(&seed) % 10 + 1); // Rand between 1 and 10.
+ count += val;
+ }
+ struct sembuf buf = {};
+ buf.sem_op = val;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+ absl::SleepFor(absl::Milliseconds(val * 2));
+ }
+ });
+ }
+
+ // First wait for signal threads to be done.
+ for (auto& inc : incs) {
+ inc->Join();
+ }
+
+ // Now there could be waiters blocked (remember operations are random).
+ // Notify waiters that we're done and signal semaphore just the right amount.
+ {
+ absl::MutexLock l(&mutex);
+ done = true;
+ struct sembuf buf = {};
+ buf.sem_op = -count;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+ }
+
+ // Now all waiters should unblock and exit.
+ for (auto& dec : decs) {
+ dec->Join();
+ }
+ for (auto& zero : zeros) {
+ zero->Join();
+ }
+}
+
+TEST(SemaphoreTest, SemOpNamespace) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ AutoSem sem(semget(123, 1, 0600 | IPC_CREAT | IPC_EXCL));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ ScopedThread([]() {
+ EXPECT_THAT(unshare(CLONE_NEWIPC), SyscallSucceeds());
+ AutoSem sem(semget(123, 1, 0600 | IPC_CREAT | IPC_EXCL));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ });
+}
+
+TEST(SemaphoreTest, SemCtlVal) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ // Semaphore must start with 0.
+ EXPECT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(0));
+
+ // Increase value and ensure waiters are woken up.
+ ScopedThread th([&sem] {
+ struct sembuf buf = {};
+ buf.sem_op = -10;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ });
+
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 9), SyscallSucceeds());
+ EXPECT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(9));
+
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 20), SyscallSucceeds());
+ const int value = semctl(sem.get(), 0, GETVAL);
+  // Either 10 or 20, depending on whether SETVAL raced with the waiter above.
+ EXPECT_TRUE(value == 10 || value == 20) << "value=" << value;
+ th.Join();
+
+ // Set it back to 0 and ensure that waiters are woken up.
+ ScopedThread thZero([&sem] {
+ struct sembuf buf = {};
+ buf.sem_op = 0;
+ ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
+ });
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 0), SyscallSucceeds());
+ EXPECT_THAT(semctl(sem.get(), 0, GETVAL), SyscallSucceedsWithValue(0));
+ thZero.Join();
+}
+
+TEST(SemaphoreTest, SemCtlValAll) {
+ AutoSem sem(semget(IPC_PRIVATE, 3, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ // Semaphores must start with 0.
+ uint16_t get[3] = {10, 10, 10};
+ EXPECT_THAT(semctl(sem.get(), 1, GETALL, get), SyscallSucceedsWithValue(0));
+ for (auto v : get) {
+ EXPECT_EQ(v, 0);
+ }
+
+ // SetAll and check that they were set.
+ uint16_t vals[3] = {0, 10, 20};
+ EXPECT_THAT(semctl(sem.get(), 1, SETALL, vals), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(semctl(sem.get(), 1, GETALL, get), SyscallSucceedsWithValue(0));
+ for (size_t i = 0; i < ABSL_ARRAYSIZE(vals); ++i) {
+ EXPECT_EQ(get[i], vals[i]);
+ }
+
+ EXPECT_THAT(semctl(sem.get(), 1, SETALL, nullptr),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST(SemaphoreTest, SemCtlGetPid) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds());
+ EXPECT_THAT(semctl(sem.get(), 0, GETPID), SyscallSucceedsWithValue(getpid()));
+}
+
+TEST(SemaphoreTest, SemCtlGetPidFork) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0);
+ TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid());
+
+ _exit(0);
+ }
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+}
+
+TEST(SemaphoreTest, SemIpcSet) {
+  // Drop CAP_IPC_OWNER, which would otherwise let us bypass semaphore
+  // permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ struct semid_ds semid = {};
+ semid.sem_perm.uid = getuid();
+ semid.sem_perm.gid = getgid();
+
+  // Make the semaphore read-only and check that signaling fails.
+ semid.sem_perm.mode = 0400;
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_SET, &semid), SyscallSucceeds());
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES));
+
+  // Make the semaphore write-only and check that waiting for zero fails.
+ semid.sem_perm.mode = 0200;
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_SET, &semid), SyscallSucceeds());
+ buf.sem_op = 0;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
new file mode 100644
index 000000000..64123e904
--- /dev/null
+++ b/test/syscalls/linux/sendfile.cc
@@ -0,0 +1,587 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <linux/unistd.h>
+#include <sys/eventfd.h>
+#include <sys/sendfile.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SendFileTest, SendZeroBytes) {
+ // Create temp files.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct value.
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(SendFileTest, InvalidOffset) {
+ // Create temp files.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct value.
+ off_t offset = -1;
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), &offset, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SendFileTest, Overflow) {
+ // Create input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor outf(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SendFileTest, SendTrivially) {
+ // Create temp files.
+ constexpr char kData[] = "To be, or not to be, that is the question:";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ FileDescriptor outf;
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct value.
+ int bytes_sent;
+ EXPECT_THAT(bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallSucceedsWithValue(kDataSize));
+
+ // Close outf to avoid leak.
+ outf.reset();
+
+ // Open the output file as read only.
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Verify that the output file has the correct data.
+ char actual[kDataSize];
+ ASSERT_THAT(read(outf.get(), &actual, bytes_sent),
+ SyscallSucceedsWithValue(kDataSize));
+ EXPECT_EQ(kData, absl::string_view(actual, bytes_sent));
+}
+
+TEST(SendFileTest, SendTriviallyWithBothFilesReadWrite) {
+ // Create temp files.
+ constexpr char kData[] = "Whether 'tis nobler in the mind to suffer";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as readwrite.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // Open the output file as readwrite.
+ FileDescriptor outf;
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ // Send data and verify that sendfile returns the correct value.
+ int bytes_sent;
+ EXPECT_THAT(bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallSucceedsWithValue(kDataSize));
+
+ // Close outf to avoid leak.
+ outf.reset();
+
+ // Open the output file as read only.
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Verify that the output file has the correct data.
+ char actual[kDataSize];
+ ASSERT_THAT(read(outf.get(), &actual, bytes_sent),
+ SyscallSucceedsWithValue(kDataSize));
+ EXPECT_EQ(kData, absl::string_view(actual, bytes_sent));
+}
+
+TEST(SendFileTest, SendAndUpdateFileOffset) {
+ // Create temp files.
+ // Test input string length must be > 2 AND even.
+ constexpr char kData[] = "The slings and arrows of outrageous fortune,";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ constexpr int kHalfDataSize = kDataSize / 2;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ FileDescriptor outf;
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct value.
+ int bytes_sent;
+ EXPECT_THAT(
+ bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize),
+ SyscallSucceedsWithValue(kHalfDataSize));
+
+ // Close outf to avoid leak.
+ outf.reset();
+
+ // Open the output file as read only.
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Verify that the output file has the correct data.
+ char actual[kHalfDataSize];
+ ASSERT_THAT(read(outf.get(), &actual, bytes_sent),
+ SyscallSucceedsWithValue(kHalfDataSize));
+ EXPECT_EQ(absl::string_view(kData, kHalfDataSize),
+ absl::string_view(actual, bytes_sent));
+
+  // Verify that the input file offset has been updated.
+ ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent),
+ SyscallSucceedsWithValue(kHalfDataSize));
+ EXPECT_EQ(
+ absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent),
+ absl::string_view(actual, kHalfDataSize));
+}
+
+TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) {
+ // Create temp files.
+ // Test input string length must be > 2 AND divisible by 4.
+ constexpr char kData[] = "The slings and arrows of outrageous fortune,";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ constexpr int kHalfDataSize = kDataSize / 2;
+ constexpr int kQuarterDataSize = kHalfDataSize / 2;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ FileDescriptor outf;
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+  // Read a quarter of the data from the infile, which updates its file
+  // offset. We don't actually care about the data, so it goes into garbage.
+ char garbage[kQuarterDataSize];
+ ASSERT_THAT(read(inf.get(), &garbage, kQuarterDataSize),
+ SyscallSucceedsWithValue(kQuarterDataSize));
+
+ // Send data and verify that sendfile returns the correct value.
+ int bytes_sent;
+ EXPECT_THAT(
+ bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize),
+ SyscallSucceedsWithValue(kHalfDataSize));
+
+  // Close outf to avoid leak.
+ outf.reset();
+
+ // Open the output file as read only.
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Verify that the output file has the correct data.
+ char actual[kHalfDataSize];
+ ASSERT_THAT(read(outf.get(), &actual, bytes_sent),
+ SyscallSucceedsWithValue(kHalfDataSize));
+ EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize),
+ absl::string_view(actual, bytes_sent));
+
+  // Verify that the input file offset has been updated.
+ ASSERT_THAT(read(inf.get(), &actual, kQuarterDataSize),
+ SyscallSucceedsWithValue(kQuarterDataSize));
+
+ EXPECT_EQ(
+ absl::string_view(kData + kDataSize - kQuarterDataSize, kQuarterDataSize),
+ absl::string_view(actual, kQuarterDataSize));
+}
+
+TEST(SendFileTest, SendAndUpdateGivenOffset) {
+ // Create temp files.
+ // Test input string length must be >= 4 AND divisible by 4.
+ constexpr char kData[] = "Or to take Arms against a Sea of troubles,";
+ constexpr int kDataSize = sizeof(kData) + 1;
+ constexpr int kHalfDataSize = kDataSize / 2;
+ constexpr int kQuarterDataSize = kHalfDataSize / 2;
+ constexpr int kThreeFourthsDataSize = 3 * kDataSize / 4;
+
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ FileDescriptor outf;
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Create offset for sending.
+ off_t offset = kQuarterDataSize;
+
+ // Send data and verify that sendfile returns the correct value.
+ int bytes_sent;
+ EXPECT_THAT(
+ bytes_sent = sendfile(outf.get(), inf.get(), &offset, kHalfDataSize),
+ SyscallSucceedsWithValue(kHalfDataSize));
+
+  // Close outf to avoid leak.
+ outf.reset();
+
+ // Open the output file as read only.
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Verify that the output file has the correct data.
+ char actual[kHalfDataSize];
+ ASSERT_THAT(read(outf.get(), &actual, bytes_sent),
+ SyscallSucceedsWithValue(kHalfDataSize));
+ EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize),
+ absl::string_view(actual, bytes_sent));
+
+ // Verify that the input file offset has NOT been updated.
+ ASSERT_THAT(read(inf.get(), &actual, kHalfDataSize),
+ SyscallSucceedsWithValue(kHalfDataSize));
+ EXPECT_EQ(absl::string_view(kData, kHalfDataSize),
+ absl::string_view(actual, kHalfDataSize));
+
+ // Verify that the offset pointer has been updated.
+ EXPECT_EQ(offset, kThreeFourthsDataSize);
+}
+
+TEST(SendFileTest, DoNotSendfileIfOutfileIsAppendOnly) {
+ // Create temp files.
+ constexpr char kData[] = "And by opposing end them: to die, to sleep";
+ constexpr int kDataSize = sizeof(kData) - 1;
+
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as append only.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY | O_APPEND));
+
+ // Send data and verify that sendfile returns the correct errno.
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SendFileTest, AppendCheckOrdering) {
+ constexpr char kData[] = "And by opposing end them: to die, to sleep";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ const FileDescriptor read =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ const FileDescriptor write =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+ const FileDescriptor append =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_APPEND));
+
+ // Check that read/write file mode is verified before append.
+ EXPECT_THAT(sendfile(append.get(), read.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(sendfile(write.get(), write.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST(SendFileTest, DoNotSendfileIfOutfileIsNotWritable) {
+ // Create temp files.
+ constexpr char kData[] = "No more; and by a sleep, to say we end";
+ constexpr int kDataSize = sizeof(kData) - 1;
+
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as read only.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Send data and verify that sendfile returns the correct errno.
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST(SendFileTest, DoNotSendfileIfInfileIsNotReadable) {
+ // Create temp files.
+ constexpr char kData[] = "the heart-ache, and the thousand natural shocks";
+ constexpr int kDataSize = sizeof(kData) - 1;
+
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as write only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY));
+
+ // Open the output file as write only.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct errno.
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST(SendFileTest, DoNotSendANegativeNumberOfBytes) {
+ // Create temp files.
+ constexpr char kData[] = "that Flesh is heir to? 'Tis a consummation";
+
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct errno.
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, -1),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SendFileTest, SendTheCorrectNumberOfBytesEvenIfWeTryToSendTooManyBytes) {
+ // Create temp files.
+ constexpr char kData[] = "devoutly to be wished. To die, to sleep,";
+ constexpr int kDataSize = sizeof(kData) - 1;
+
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ FileDescriptor outf;
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Send data and verify that sendfile returns the correct value.
+ int bytes_sent;
+ EXPECT_THAT(
+ bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kDataSize + 100),
+ SyscallSucceedsWithValue(kDataSize));
+
+ // Close outf to avoid leak.
+ outf.reset();
+
+ // Open the output file as read only.
+ outf = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+
+ // Verify that the output file has the correct data.
+ char actual[kDataSize];
+ ASSERT_THAT(read(outf.get(), &actual, bytes_sent),
+ SyscallSucceedsWithValue(kDataSize));
+ EXPECT_EQ(kData, absl::string_view(actual, bytes_sent));
+}
+
+TEST(SendFileTest, SendToNotARegularFile) {
+ // Make temp input directory and open as read only.
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY));
+
+ // Make temp output file and open as write only.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+  // Expect an error since a directory is not a regular file.
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SendFileTest, SendPipeWouldBlock) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fool doth think he is wise, but the wise man knows himself to be a "
+ "fool.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+  // Set up the output pipe.
+ int fds[2];
+ ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(2 * pipe_size);
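+  // With O_NONBLOCK, this write completes partially, filling the pipe exactly
+  // to capacity.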
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST(SendFileTest, SendPipeBlocks) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fault, dear Brutus, is not in our stars, but in ourselves.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+  // Set up the output pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
+TEST(SendFileTest, SendToSpecialFile) {
+ // Create temp file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ constexpr int kSize = 0x7ff;
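+  // kSize is deliberately not a multiple of 8, so the trailing bytes cannot
+  // be written to the eventfd.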
+ ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds());
+
+ auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+
+  // An eventfd only accepts 8-byte writes, so sendfile can transfer only a
+  // multiple of 8 bytes.
+ EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff),
+ SyscallSucceedsWithValue(kSize & (~7)));
+}
+
+TEST(SendFileTest, SendFileToPipe) {
+ // Create temp file.
+ constexpr char kData[] = "<insert-quote-here>";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+  // Create the pipe that sendfile will write to.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Expect to read up to the given size.
+ std::vector<char> buf(kDataSize);
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kDataSize));
+ });
+
+  // Send twice the size of the file; sendfile should stop at EOF.
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc
new file mode 100644
index 000000000..c101fe9d2
--- /dev/null
+++ b/test/syscalls/linux/sendfile_socket.cc
@@ -0,0 +1,231 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <sys/sendfile.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class SendFileTest : public ::testing::TestWithParam<int> {
+ protected:
+ PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) {
+ // Bind a server socket.
+ int family = GetParam();
+ switch (family) {
+ case AF_INET: {
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "TCP", AF_INET, type, 0,
+ TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UDP", AF_INET, type, 0,
+ UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ }
+ }
+ case AF_UNIX: {
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ }
+ }
+ default:
+ return PosixError(EINVAL);
+ }
+ }
+};
+
+// Sends a large file to exercise the path that reads and writes data
+// multiple times, especially when more data is read than can be written.
+TEST_P(SendFileTest, SendMultiple) {
+ std::vector<char> data(5 * 1024 * 1024);
+ RandomizeBuffer(data.data(), data.size());
+
+ // Create temp files.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::string_view(data.data(), data.size()),
+ TempPath::kDefaultFileMode));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Create sockets.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
+
+  // Thread that reads data from the socket and dumps it to a file.
+ ScopedThread th([&] {
+ FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Read until socket is closed.
+ char buf[10240];
+ for (int cnt = 0;; cnt++) {
+ int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf));
+ // We cannot afford to save on every read() call.
+ if (cnt % 1000 == 0) {
+ ASSERT_THAT(r, SyscallSucceeds());
+ } else {
+ const DisableSave ds;
+ ASSERT_THAT(r, SyscallSucceeds());
+ }
+ if (r == 0) {
+ // EOF
+ break;
+ }
+ int w = RetryEINTR(write)(outf.get(), buf, r);
+ // We cannot afford to save on every write() call.
+ if (cnt % 1010 == 0) {
+ ASSERT_THAT(w, SyscallSucceedsWithValue(r));
+ } else {
+ const DisableSave ds;
+ ASSERT_THAT(w, SyscallSucceedsWithValue(r));
+ }
+ }
+ });
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ int cnt = 0;
+ for (size_t sent = 0; sent < data.size(); cnt++) {
+ const size_t remain = data.size() - sent;
+ std::cout << "sendfile, size=" << data.size() << ", sent=" << sent
+ << ", remain=" << remain << std::endl;
+
+ // Send data and verify that sendfile returns the correct value.
+ int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain);
+ // We cannot afford to save on every sendfile() call.
+ if (cnt % 120 == 0) {
+ MaybeSave();
+ }
+ if (res == 0) {
+ // EOF
+ break;
+ }
+ if (res > 0) {
+ sent += res;
+ } else {
+ ASSERT_TRUE(errno == EINTR || errno == EAGAIN) << "errno=" << errno;
+ }
+ }
+
+ // Close socket to stop thread.
+ close(socks->release_second_fd());
+ th.Join();
+
+ // Verify that the output file has the correct data.
+ const FileDescriptor outf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
+ std::vector<char> actual(data.size(), '\0');
+ ASSERT_THAT(RetryEINTR(read)(outf.get(), actual.data(), actual.size()),
+ SyscallSucceedsWithValue(actual.size()));
+ ASSERT_EQ(memcmp(data.data(), actual.data(), data.size()), 0);
+}
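+
+// For reference, the loop above follows the canonical sendfile(2) driver
+// pattern. A condensed sketch of the same idea in plain POSIX, without the
+// test macros (the helper name is ours, not part of any fixture):
+//
+//   ssize_t SendFileAll(int out_fd, int in_fd, size_t count) {
+//     size_t sent = 0;
+//     while (sent < count) {
+//       const ssize_t n = sendfile(out_fd, in_fd, nullptr, count - sent);
+//       if (n == 0) break;  // EOF on in_fd.
+//       if (n < 0) {
+//         if (errno == EINTR || errno == EAGAIN) continue;  // Retry.
+//         return -1;
+//       }
+//       sent += static_cast<size_t>(n);
+//     }
+//     return static_cast<ssize_t>(sent);
+//   }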
+
+TEST_P(SendFileTest, Shutdown) {
+ // Create a socket.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
+
+ // If this is a TCP socket, then turn off linger.
+ if (GetParam() == AF_INET) {
+ struct linger sl;
+ sl.l_onoff = 1;
+ sl.l_linger = 0;
+ ASSERT_THAT(
+ setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
+ SyscallSucceeds());
+ }
+
+ // Create a 1m file with random data.
+ std::vector<char> data(1024 * 1024);
+ RandomizeBuffer(data.data(), data.size());
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::string_view(data.data(), data.size()),
+ TempPath::kDefaultFileMode));
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+  // Read some data, then shut down the socket. We don't actually care about
+ // checking the contents (other tests do that), so we just re-use the same
+ // buffer as above.
+ ScopedThread t([&]() {
+ size_t done = 0;
+ while (done < data.size()) {
+ int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size());
+ ASSERT_THAT(n, SyscallSucceeds());
+ done += n;
+ }
+ // Close the server side socket.
+ close(socks->release_first_fd());
+ });
+
+ // Continuously stream from the file to the socket. Note we do not assert
+ // that a specific amount of data has been written at any time, just that some
+ // data is written. Eventually, we should get a connection reset error.
+ while (1) {
+ off_t offset = 0; // Always read from the start.
+ int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size());
+ EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET),
+ SyscallFailsWithErrno(EPIPE), SyscallSucceeds()));
+ if (n <= 0) {
+ break;
+ }
+ }
+}
+
+TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) {
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM));
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+  // The value of the count argument must be chosen so that it is impossible
+  // to allocate a buffer of that size. On Linux, sendfile transfers at most
+  // 0x7ffff000 (MAX_RW_COUNT) bytes.
+ EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004),
+ SyscallSucceedsWithValue(0));
+}
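+
+// Note on the count above (our reading of Linux behavior, which the test does
+// not assert directly): Linux clamps the count to MAX_RW_COUNT rather than
+// failing, roughly
+//
+//   count = min(count, 0x7ffff000);  // MAX_RW_COUNT = INT_MAX & PAGE_MASK.
+//
+// so the absurdly large count is accepted, and the empty source file then
+// yields a transfer of 0 bytes.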
+
+INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest,
+ ::testing::Values(AF_UNIX, AF_INET));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc
new file mode 100644
index 000000000..c7fdbb924
--- /dev/null
+++ b/test/syscalls/linux/shm.cc
@@ -0,0 +1,508 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/ipc.h>
+#include <sys/mman.h>
+#include <sys/shm.h>
+#include <sys/types.h>
+
+#include "absl/time/clock.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using ::testing::_;
+
+const uint64_t kAllocSize = kPageSize * 128ULL;
+
+PosixErrorOr<char*> Shmat(int shmid, const void* shmaddr, int shmflg) {
+ const intptr_t addr =
+ reinterpret_cast<intptr_t>(shmat(shmid, shmaddr, shmflg));
+ if (addr == -1) {
+ return PosixError(errno, "shmat() failed");
+ }
+ return reinterpret_cast<char*>(addr);
+}
+
+PosixError Shmdt(const char* shmaddr) {
+ const int ret = shmdt(shmaddr);
+ if (ret == -1) {
+ return PosixError(errno, "shmdt() failed");
+ }
+ return NoError();
+}
+
+template <typename T>
+PosixErrorOr<int> Shmctl(int shmid, int cmd, T* buf) {
+ int ret = shmctl(shmid, cmd, reinterpret_cast<struct shmid_ds*>(buf));
+ if (ret == -1) {
+ return PosixError(errno, "shmctl() failed");
+ }
+ return ret;
+}
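+
+// The template parameter exists because shmctl(2) is declared to take a
+// struct shmid_ds*, while several commands actually expect a different
+// struct (shminfo for IPC_INFO, shm_info for SHM_INFO). For example:
+//
+//   struct shminfo info;
+//   ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info));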
+
+// ShmSegment is a RAII object for automatically cleaning up shm segments.
+class ShmSegment {
+ public:
+ explicit ShmSegment(int id) : id_(id) {}
+
+ ~ShmSegment() {
+ if (id_ >= 0) {
+ EXPECT_NO_ERRNO(Rmid());
+ id_ = -1;
+ }
+ }
+
+ ShmSegment(ShmSegment&& other) : id_(other.release()) {}
+
+ ShmSegment& operator=(ShmSegment&& other) {
+ id_ = other.release();
+ return *this;
+ }
+
+ ShmSegment(ShmSegment const& other) = delete;
+ ShmSegment& operator=(ShmSegment const& other) = delete;
+
+ int id() const { return id_; }
+
+ int release() {
+ int id = id_;
+ id_ = -1;
+ return id;
+ }
+
+ PosixErrorOr<int> Rmid() {
+ RETURN_IF_ERRNO(Shmctl<void>(id_, IPC_RMID, nullptr));
+ return release();
+ }
+
+ private:
+ int id_ = -1;
+};
+
+PosixErrorOr<int> ShmgetRaw(key_t key, size_t size, int shmflg) {
+ int id = shmget(key, size, shmflg);
+ if (id == -1) {
+ return PosixError(errno, "shmget() failed");
+ }
+ return id;
+}
+
+PosixErrorOr<ShmSegment> Shmget(key_t key, size_t size, int shmflg) {
+ ASSIGN_OR_RETURN_ERRNO(int id, ShmgetRaw(key, size, shmflg));
+ return ShmSegment(id);
+}
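+
+// Typical usage of the wrappers above (a sketch mirroring the tests below):
+//
+//   ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+//       Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+//   char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+//   ...
+//   ASSERT_NO_ERRNO(Shmdt(addr));
+//   // ~ShmSegment() then removes the segment via shmctl(IPC_RMID).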
+
+TEST(ShmTest, AttachDetach) {
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ struct shmid_ds attr;
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ EXPECT_EQ(attr.shm_segsz, kAllocSize);
+ EXPECT_EQ(attr.shm_nattch, 0);
+
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ EXPECT_EQ(attr.shm_nattch, 1);
+
+ const char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ EXPECT_EQ(attr.shm_nattch, 2);
+
+ ASSERT_NO_ERRNO(Shmdt(addr));
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ EXPECT_EQ(attr.shm_nattch, 1);
+
+ ASSERT_NO_ERRNO(Shmdt(addr2));
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ EXPECT_EQ(attr.shm_nattch, 0);
+}
+
+TEST(ShmTest, LookupByKey) {
+ const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const key_t key = ftok(keyfile.path().c_str(), 1);
+ const ShmSegment shm =
+ ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777));
+ const int id2 = ASSERT_NO_ERRNO_AND_VALUE(ShmgetRaw(key, kAllocSize, 0777));
+ EXPECT_EQ(shm.id(), id2);
+}
+
+TEST(ShmTest, DetachedSegmentsPersist) {
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ addr[0] = 'x';
+ ASSERT_NO_ERRNO(Shmdt(addr));
+
+ // We should be able to re-attach to the same segment and get our data back.
+ addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ EXPECT_EQ(addr[0], 'x');
+ ASSERT_NO_ERRNO(Shmdt(addr));
+}
+
+TEST(ShmTest, MultipleDetachFails) {
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ ASSERT_NO_ERRNO(Shmdt(addr));
+ EXPECT_THAT(Shmdt(addr), PosixErrorIs(EINVAL, _));
+}
+
+TEST(ShmTest, IpcStat) {
+ const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const key_t key = ftok(keyfile.path().c_str(), 1);
+
+ const time_t start = time(nullptr);
+
+ const ShmSegment shm =
+ ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777));
+
+ const uid_t uid = getuid();
+ const gid_t gid = getgid();
+ const pid_t pid = getpid();
+
+ struct shmid_ds attr;
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+
+ EXPECT_EQ(attr.shm_perm.__key, key);
+ EXPECT_EQ(attr.shm_perm.uid, uid);
+ EXPECT_EQ(attr.shm_perm.gid, gid);
+ EXPECT_EQ(attr.shm_perm.cuid, uid);
+ EXPECT_EQ(attr.shm_perm.cgid, gid);
+ EXPECT_EQ(attr.shm_perm.mode, 0777);
+
+ EXPECT_EQ(attr.shm_segsz, kAllocSize);
+
+ EXPECT_EQ(attr.shm_atime, 0);
+ EXPECT_EQ(attr.shm_dtime, 0);
+
+ // Change time is set on creation.
+ EXPECT_GE(attr.shm_ctime, start);
+
+ EXPECT_EQ(attr.shm_cpid, pid);
+ EXPECT_EQ(attr.shm_lpid, 0);
+
+ EXPECT_EQ(attr.shm_nattch, 0);
+
+ // The timestamps only have a resolution of seconds; slow down so we actually
+ // see the timestamps change.
+ absl::SleepFor(absl::Seconds(1));
+ const time_t pre_attach = time(nullptr);
+
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+
+ EXPECT_GE(attr.shm_atime, pre_attach);
+ EXPECT_EQ(attr.shm_dtime, 0);
+ EXPECT_LT(attr.shm_ctime, pre_attach);
+ EXPECT_EQ(attr.shm_lpid, pid);
+ EXPECT_EQ(attr.shm_nattch, 1);
+
+ absl::SleepFor(absl::Seconds(1));
+ const time_t pre_detach = time(nullptr);
+
+ ASSERT_NO_ERRNO(Shmdt(addr));
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+
+ EXPECT_LT(attr.shm_atime, pre_detach);
+ EXPECT_GE(attr.shm_dtime, pre_detach);
+ EXPECT_LT(attr.shm_ctime, pre_detach);
+ EXPECT_EQ(attr.shm_lpid, pid);
+ EXPECT_EQ(attr.shm_nattch, 0);
+}
+
+TEST(ShmTest, ShmStat) {
+ // This test relies on the segment we create to be the first one on the
+ // system, causing it to occupy slot 1. We can't reasonably expect this on a
+ // general Linux host.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ struct shmid_ds attr;
+ ASSERT_NO_ERRNO(Shmctl(1, SHM_STAT, &attr));
+ // This does the same thing as IPC_STAT, so only test that the syscall
+ // succeeds here.
+}
+
+TEST(ShmTest, IpcInfo) {
+ struct shminfo info;
+ ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info));
+
+ EXPECT_EQ(info.shmmin, 1); // This is always 1, according to the man page.
+ EXPECT_GT(info.shmmax, info.shmmin);
+ EXPECT_GT(info.shmmni, 0);
+ EXPECT_GT(info.shmseg, 0);
+ EXPECT_GT(info.shmall, 0);
+}
+
+TEST(ShmTest, ShmInfo) {
+ struct shm_info info;
+
+  // We generally can't know what other processes on a Linux machine do
+  // with shared memory segments, so we can't test specific numbers on
+  // Linux. When running under gVisor, we're guaranteed to be the only
+  // ones using shm, so we can easily verify machine-wide numbers.
+ if (IsRunningOnGvisor()) {
+ ASSERT_NO_ERRNO(Shmctl(0, SHM_INFO, &info));
+ EXPECT_EQ(info.used_ids, 0);
+ EXPECT_EQ(info.shm_tot, 0);
+ EXPECT_EQ(info.shm_rss, 0);
+ EXPECT_EQ(info.shm_swp, 0);
+ }
+
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+
+ ASSERT_NO_ERRNO(Shmctl(1, SHM_INFO, &info));
+
+ if (IsRunningOnGvisor()) {
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), SHM_INFO, &info));
+ EXPECT_EQ(info.used_ids, 1);
+ EXPECT_EQ(info.shm_tot, kAllocSize / kPageSize);
+ EXPECT_EQ(info.shm_rss, kAllocSize / kPageSize);
+    EXPECT_EQ(info.shm_swp, 0); // gVisor currently never swaps.
+ }
+
+ ASSERT_NO_ERRNO(Shmdt(addr));
+}
+
+TEST(ShmTest, ShmCtlSet) {
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+
+ struct shmid_ds attr;
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ ASSERT_EQ(attr.shm_perm.mode, 0777);
+
+ attr.shm_perm.mode = 0766;
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_SET, &attr));
+
+ ASSERT_NO_ERRNO(Shmctl(shm.id(), IPC_STAT, &attr));
+ ASSERT_EQ(attr.shm_perm.mode, 0766);
+
+ ASSERT_NO_ERRNO(Shmdt(addr));
+}
+
+TEST(ShmTest, RemovedSegmentsAreMarkedDeleted) {
+ ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ const int id = ASSERT_NO_ERRNO_AND_VALUE(shm.Rmid());
+ struct shmid_ds attr;
+ ASSERT_NO_ERRNO(Shmctl(id, IPC_STAT, &attr));
+ EXPECT_NE(attr.shm_perm.mode & SHM_DEST, 0);
+ ASSERT_NO_ERRNO(Shmdt(addr));
+}
+
+TEST(ShmTest, RemovedSegmentsAreDestroyed) {
+ ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+
+ const uint64_t alloc_pages = kAllocSize / kPageSize;
+
+ struct shm_info info;
+ ASSERT_NO_ERRNO(Shmctl(0 /*ignored*/, SHM_INFO, &info));
+ const uint64_t before = info.shm_tot;
+
+ ASSERT_NO_ERRNO(shm.Rmid());
+ ASSERT_NO_ERRNO(Shmdt(addr));
+
+ ASSERT_NO_ERRNO(Shmctl(0 /*ignored*/, SHM_INFO, &info));
+ if (IsRunningOnGvisor()) {
+    // No guarantees on system-wide shm memory usage on a generic Linux host.
+ const uint64_t after = info.shm_tot;
+ EXPECT_EQ(after, before - alloc_pages);
+ }
+}
+
+TEST(ShmTest, AllowsAttachToRemovedSegmentWithRefs) {
+ ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ const int id = ASSERT_NO_ERRNO_AND_VALUE(shm.Rmid());
+ const char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(id, nullptr, 0));
+ ASSERT_NO_ERRNO(Shmdt(addr));
+ ASSERT_NO_ERRNO(Shmdt(addr2));
+}
+
+TEST(ShmTest, RemovedSegmentsAreNotDiscoverable) {
+ const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const key_t key = ftok(keyfile.path().c_str(), 1);
+ ShmSegment shm =
+ ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777));
+ ASSERT_NO_ERRNO(shm.Rmid());
+ EXPECT_THAT(Shmget(key, kAllocSize, 0777), PosixErrorIs(ENOENT, _));
+}
+
+TEST(ShmDeathTest, ReadonlySegment) {
+ SetupGvisorDeathTest();
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, SHM_RDONLY));
+ // Reading succeeds.
+ static_cast<void>(addr[0]);
+ // Writing fails.
+ EXPECT_EXIT(addr[0] = 'x', ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) {
+  // This test is susceptible to races with concurrent mmaps running in
+  // parallel gtest threads, since it relies on the address freed during shm
+  // segment destruction remaining unused. We run the test body in a forked
+  // child to guarantee a single-threaded context and avoid this.
+
+ SetupGvisorDeathTest();
+
+ const auto rest = [&] {
+ ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+
+ // Mark the segment as destroyed so it's automatically cleaned up when we
+ // crash below. We can't rely on the standard cleanup since the destructor
+ // will not run after the SIGSEGV. Note that this doesn't destroy the
+ // segment immediately since we're still attached to it.
+ ASSERT_NO_ERRNO(shm.Rmid());
+
+ addr[0] = 'x';
+ ASSERT_NO_ERRNO(Shmdt(addr));
+
+ // This access should cause a SIGSEGV.
+ addr[0] = 'x';
+ };
+
+ EXPECT_THAT(InForkedProcess(rest),
+ IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV)));
+}
+
+TEST(ShmTest, RequestingSegmentSmallerThanSHMMINFails) {
+ struct shminfo info;
+ ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info));
+ const uint64_t size = info.shmmin - 1;
+ EXPECT_THAT(Shmget(IPC_PRIVATE, size, IPC_CREAT | 0777),
+ PosixErrorIs(EINVAL, _));
+}
+
+TEST(ShmTest, RequestingSegmentLargerThanSHMMAXFails) {
+ struct shminfo info;
+ ASSERT_NO_ERRNO(Shmctl(0, IPC_INFO, &info));
+ const uint64_t size = info.shmmax + kPageSize;
+ EXPECT_THAT(Shmget(IPC_PRIVATE, size, IPC_CREAT | 0777),
+ PosixErrorIs(EINVAL, _));
+}
+
+TEST(ShmTest, RequestingUnalignedSizeSucceeds) {
+ EXPECT_NO_ERRNO(Shmget(IPC_PRIVATE, 4097, IPC_CREAT | 0777));
+}
+
+TEST(ShmTest, RequestingDuplicateCreationFails) {
+ const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const key_t key = ftok(keyfile.path().c_str(), 1);
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(key, kAllocSize, IPC_CREAT | IPC_EXCL | 0777));
+ EXPECT_THAT(Shmget(key, kAllocSize, IPC_CREAT | IPC_EXCL | 0777),
+ PosixErrorIs(EEXIST, _));
+}
+
+TEST(ShmTest, NonExistentSegmentsAreNotFound) {
+ const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const key_t key = ftok(keyfile.path().c_str(), 1);
+ // Do not request creation.
+ EXPECT_THAT(Shmget(key, kAllocSize, 0777), PosixErrorIs(ENOENT, _));
+}
+
+TEST(ShmTest, SegmentsSizeFixedOnCreation) {
+ const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const key_t key = ftok(keyfile.path().c_str(), 1);
+
+ // Base segment.
+ const ShmSegment shm =
+ ASSERT_NO_ERRNO_AND_VALUE(Shmget(key, kAllocSize, IPC_CREAT | 0777));
+
+ // Ask for the same segment at half size. This succeeds.
+ const int id2 =
+ ASSERT_NO_ERRNO_AND_VALUE(ShmgetRaw(key, kAllocSize / 2, 0777));
+
+ // Ask for the same segment at double size.
+ EXPECT_THAT(Shmget(key, kAllocSize * 2, 0777), PosixErrorIs(EINVAL, _));
+
+ char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(id2, nullptr, 0));
+
+ // We have 2 different maps...
+ EXPECT_NE(addr, addr2);
+
+ // ... And both maps are kAllocSize bytes; despite asking for a half-sized
+ // segment for the second map.
+ addr[kAllocSize - 1] = 'x';
+ addr2[kAllocSize - 1] = 'x';
+
+ ASSERT_NO_ERRNO(Shmdt(addr));
+ ASSERT_NO_ERRNO(Shmdt(addr2));
+}
+
+TEST(ShmTest, PartialUnmap) {
+ const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ EXPECT_THAT(munmap(addr + (kAllocSize / 4), kAllocSize / 2),
+ SyscallSucceeds());
+ ASSERT_NO_ERRNO(Shmdt(addr));
+}
+
+// Check that the sentry does not panic when asked for a zero-length private
+// shm segment. Regression test for b/110694797.
+TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) {
+ EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _));
+}
+
+TEST(ShmTest, NoDestructionOfAttachedSegmentWithMultipleRmid) {
+ ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
+ char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ char* addr2 = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+
+  // There should be 2 refs to the segment from the 2 attachments, and a
+  // single self-reference. Mark the segment as destroyed more than 3 times
+  // through shmctl(IPC_RMID). If there's a bug in the reference counting,
+  // this will cause the count to drop to zero.
+ int id = shm.release();
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_NO_ERRNO(Shmctl<void>(id, IPC_RMID, nullptr));
+ }
+
+ // Segment should remain accessible.
+ addr[0] = 'x';
+ ASSERT_NO_ERRNO(Shmdt(addr));
+
+ // Segment should remain accessible even after one of the two attachments are
+ // detached.
+ addr2[0] = 'x';
+ ASSERT_NO_ERRNO(Shmdt(addr2));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sigaction.cc b/test/syscalls/linux/sigaction.cc
new file mode 100644
index 000000000..9d9dd57a8
--- /dev/null
+++ b/test/syscalls/linux/sigaction.cc
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/syscall.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SigactionTest, GetLessThanOrEqualToZeroFails) {
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(-1, nullptr, &act), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(sigaction(0, nullptr, &act), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, SetLessThanOrEqualToZeroFails) {
+ struct sigaction act = {};
+  ASSERT_THAT(sigaction(-1, &act, nullptr), SyscallFailsWithErrno(EINVAL));
+  ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, GetGreaterThanMaxFails) {
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGRTMAX + 1, nullptr, &act),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, SetGreaterThanMaxFails) {
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGRTMAX + 1, &act, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, SetSigkillFails) {
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGKILL, nullptr, &act), SyscallSucceeds());
+ ASSERT_THAT(sigaction(SIGKILL, &act, nullptr), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, SetSigstopFails) {
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGSTOP, nullptr, &act), SyscallSucceeds());
+ ASSERT_THAT(sigaction(SIGSTOP, &act, nullptr), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, BadSigsetFails) {
+ constexpr size_t kWrongSigSetSize = 43;
+
+ struct sigaction act = {};
+
+ // The syscall itself (rather than the libc wrapper) takes the sigset_t size.
+ ASSERT_THAT(
+ syscall(SYS_rt_sigaction, SIGTERM, nullptr, &act, kWrongSigSetSize),
+ SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(
+ syscall(SYS_rt_sigaction, SIGTERM, &act, nullptr, kWrongSigSetSize),
+ SyscallFailsWithErrno(EINVAL));
+}
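+
+// For contrast with the test above: the size the kernel does accept is that
+// of its own sigset_t, i.e. _NSIG / 8 (8 bytes on x86-64), so a correct raw
+// call would look like (a sketch, assuming x86-64 sizes):
+//
+//   syscall(SYS_rt_sigaction, SIGTERM, &act, nullptr, _NSIG / 8);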
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc
new file mode 100644
index 000000000..24e7c4960
--- /dev/null
+++ b/test/syscalls/linux/sigaltstack.cc
@@ -0,0 +1,268 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <stdio.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <functional>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<Cleanup> ScopedSigaltstack(stack_t const& stack) {
+ stack_t old_stack;
+ int rc = sigaltstack(&stack, &old_stack);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "sigaltstack failed");
+ }
+ return Cleanup([old_stack] {
+ EXPECT_THAT(sigaltstack(&old_stack, nullptr), SyscallSucceeds());
+ });
+}
+
+volatile bool got_signal = false;
+volatile int sigaltstack_errno = 0;
+volatile int ss_flags = 0;
+
+void sigaltstack_handler(int sig, siginfo_t* siginfo, void* arg) {
+ got_signal = true;
+
+ stack_t stack;
+ int ret = sigaltstack(nullptr, &stack);
+ MaybeSave();
+ if (ret < 0) {
+ sigaltstack_errno = errno;
+ return;
+ }
+ ss_flags = stack.ss_flags;
+}
+
+TEST(SigaltstackTest, Success) {
+ std::vector<char> stack_mem(SIGSTKSZ);
+ stack_t stack = {};
+ stack.ss_sp = stack_mem.data();
+ stack.ss_size = stack_mem.size();
+ auto const cleanup_sigstack =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack));
+
+ struct sigaction sa = {};
+ sa.sa_sigaction = sigaltstack_handler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO | SA_ONSTACK;
+ auto const cleanup_sa =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa));
+
+ // Send signal to this thread, as sigaltstack is per-thread.
+ EXPECT_THAT(tgkill(getpid(), gettid(), SIGUSR1), SyscallSucceeds());
+
+ EXPECT_TRUE(got_signal);
+ EXPECT_EQ(sigaltstack_errno, 0);
+ EXPECT_NE(0, ss_flags & SS_ONSTACK);
+}
+
+TEST(SigaltstackTest, ResetByExecve) {
+ std::vector<char> stack_mem(SIGSTKSZ);
+ stack_t stack = {};
+ stack.ss_sp = stack_mem.data();
+ stack.ss_size = stack_mem.size();
+ auto const cleanup_sigstack =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack));
+
+ std::string full_path = RunfilePath("test/syscalls/linux/sigaltstack_check");
+
+ pid_t child_pid = -1;
+ int execve_errno = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(full_path, {"sigaltstack_check"}, {}, nullptr, &child_pid,
+ &execve_errno));
+
+ ASSERT_GT(child_pid, 0);
+ ASSERT_EQ(execve_errno, 0);
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ ASSERT_TRUE(WIFEXITED(status));
+ ASSERT_EQ(WEXITSTATUS(status), 0);
+}
+
+volatile bool badhandler_on_sigaltstack = true; // Set by the handler.
+char* volatile badhandler_low_water_mark = nullptr; // Set by the handler.
+volatile uint8_t badhandler_recursive_faults = 0; // Consumed by the handler.
+
+void badhandler(int sig, siginfo_t* siginfo, void* arg) {
+ char stack_var = 0;
+ char* current_ss = &stack_var;
+
+ stack_t stack;
+ int ret = sigaltstack(nullptr, &stack);
+ if (ret < 0 || (stack.ss_flags & SS_ONSTACK) != SS_ONSTACK) {
+    // We should always be marked as being on the stack. If this is ever not
+    // true, record it rather than asserting here (the main test will fail as
+    // a result, but we still need to unwind the recursive faults).
+ badhandler_on_sigaltstack = false;
+ }
+ if (current_ss < badhandler_low_water_mark) {
+    // Record the low point of the signal stack. We never expect this to go
+    // below the bottom of stack_mem, but that is asserted in the actual test.
+ badhandler_low_water_mark = current_ss;
+ }
+ if (badhandler_recursive_faults > 0) {
+ badhandler_recursive_faults--;
+ Fault();
+ }
+ FixupFault(reinterpret_cast<ucontext_t*>(arg));
+}
+
+TEST(SigaltstackTest, WalksOffBottom) {
+ // This test marks the upper half of the stack_mem array as the signal stack.
+ // It asserts that when a fault occurs in the handler (already on the signal
+ // stack), we eventually continue to fault our way off the stack. We should
+ // not revert to the top of the signal stack when we fall off the bottom and
+ // the signal stack should remain "in use". When we fall off the signal stack,
+ // we should have an unconditional signal delivered and not start using the
+ // first part of the stack_mem array.
+ std::vector<char> stack_mem(SIGSTKSZ * 2);
+ stack_t stack = {};
+ stack.ss_sp = stack_mem.data() + SIGSTKSZ; // See above: upper half.
+ stack.ss_size = SIGSTKSZ; // Only one half the array.
+ auto const cleanup_sigstack =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack));
+
+  // Set up the handler: this must be for SIGSEGV, and it must allow proper
+ // nesting (no signal mask, no defer) so that we can trigger multiple times.
+ //
+ // When we walk off the bottom of the signal stack and force signal delivery
+ // of a SIGSEGV, the handler will revert to the default behavior (kill).
+ struct sigaction sa = {};
+ sa.sa_sigaction = badhandler;
+ sa.sa_flags = SA_SIGINFO | SA_ONSTACK | SA_NODEFER;
+ auto const cleanup_sa =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+
+ // Trigger a single fault.
+ badhandler_low_water_mark =
+ static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top.
+ badhandler_recursive_faults = 0; // Disable refault.
+ Fault();
+ EXPECT_TRUE(badhandler_on_sigaltstack);
+ EXPECT_THAT(sigaltstack(nullptr, &stack), SyscallSucceeds());
+ EXPECT_EQ(stack.ss_flags & SS_ONSTACK, 0);
+ EXPECT_LT(badhandler_low_water_mark,
+ reinterpret_cast<char*>(stack.ss_sp) + 2 * SIGSTKSZ);
+ EXPECT_GT(badhandler_low_water_mark, reinterpret_cast<char*>(stack.ss_sp));
+
+ // Trigger two faults.
+ char* prev_low_water_mark = badhandler_low_water_mark; // Previous top.
+ badhandler_recursive_faults = 1; // One refault.
+ Fault();
+ ASSERT_TRUE(badhandler_on_sigaltstack);
+ EXPECT_THAT(sigaltstack(nullptr, &stack), SyscallSucceeds());
+ EXPECT_EQ(stack.ss_flags & SS_ONSTACK, 0);
+ EXPECT_LT(badhandler_low_water_mark, prev_low_water_mark);
+ EXPECT_GT(badhandler_low_water_mark, reinterpret_cast<char*>(stack.ss_sp));
+
+ // Calculate the stack growth for a fault, and set the recursive faults to
+ // ensure that the signal handler stack required exceeds our marked stack area
+ // by a minimal amount. It should remain in the valid stack_mem area so that
+ // we can test the signal is forced merely by going out of the signal stack
+ // bounds, not by a genuine fault.
+ uintptr_t frame_size =
+ static_cast<uintptr_t>(prev_low_water_mark - badhandler_low_water_mark);
+ badhandler_recursive_faults = (SIGSTKSZ + frame_size) / frame_size;
+ EXPECT_EXIT(Fault(), ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+volatile int setonstack_retval = 0; // Set by the handler.
+volatile int setonstack_errno = 0; // Set by the handler.
+
+void setonstack(int sig, siginfo_t* siginfo, void* arg) {
+ char stack_mem[SIGSTKSZ];
+ stack_t stack = {};
+ stack.ss_sp = &stack_mem[0];
+ stack.ss_size = SIGSTKSZ;
+ setonstack_retval = sigaltstack(&stack, nullptr);
+ setonstack_errno = errno;
+ FixupFault(reinterpret_cast<ucontext_t*>(arg));
+}
+
+TEST(SigaltstackTest, SetWhileOnStack) {
+  // Reserve twice as much stack here, since the handler will allocate an
+  // array of SIGSTKSZ bytes and attempt to set the sigaltstack to it.
+ std::vector<char> stack_mem(2 * SIGSTKSZ);
+ stack_t stack = {};
+ stack.ss_sp = stack_mem.data();
+ stack.ss_size = stack_mem.size();
+ auto const cleanup_sigstack =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack));
+
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_sigaction = setonstack;
+ sa.sa_flags = SA_SIGINFO | SA_ONSTACK;
+ auto const cleanup_sa =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+
+ // Trigger a fault.
+ Fault();
+
+ // The set should have failed.
+ EXPECT_EQ(setonstack_retval, -1);
+ EXPECT_EQ(setonstack_errno, EPERM);
+}
+
+TEST(SigaltstackTest, SetCurrentStack) {
+  // This is executed as an exit test because once the signal stack is set to
+  // the local stack, there's no good way to unwind. We don't want to taint
+  // any other tests that might run within this process.
+ EXPECT_EXIT(
+ {
+ char stack_value = 0;
+ stack_t stack = {};
+ stack.ss_sp = &stack_value - kPageSize; // Lower than current level.
+ stack.ss_size = 2 * kPageSize; // => &stack_value +/- kPageSize.
+ TEST_CHECK(sigaltstack(&stack, nullptr) == 0);
+ TEST_CHECK(sigaltstack(nullptr, &stack) == 0);
+ TEST_CHECK((stack.ss_flags & SS_ONSTACK) != 0);
+
+ // Should not be able to change the stack (even no-op).
+ TEST_CHECK(sigaltstack(&stack, nullptr) == -1 && errno == EPERM);
+
+ // Should not be able to disable the stack.
+ stack.ss_flags = SS_DISABLE;
+ TEST_CHECK(sigaltstack(&stack, nullptr) == -1 && errno == EPERM);
+ exit(0);
+ },
+ ::testing::ExitedWithCode(0), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sigaltstack_check.cc b/test/syscalls/linux/sigaltstack_check.cc
new file mode 100644
index 000000000..5ac1b661d
--- /dev/null
+++ b/test/syscalls/linux/sigaltstack_check.cc
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Checks that there is no alternate signal stack by default.
+//
+// Used by a test in sigaltstack.cc.
+#include <errno.h>
+#include <signal.h>
+#include <stdio.h>
+#include <string.h>
+#include <unistd.h>
+
+#include "test/util/logging.h"
+
+int main(int /* argc */, char** /* argv */) {
+ stack_t stack;
+ TEST_CHECK(sigaltstack(nullptr, &stack) >= 0);
+ TEST_CHECK(stack.ss_flags == SS_DISABLE);
+ TEST_CHECK(stack.ss_sp == 0);
+ TEST_CHECK(stack.ss_size == 0);
+ return 0;
+}
diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc
new file mode 100644
index 000000000..6227774a4
--- /dev/null
+++ b/test/syscalls/linux/sigiret.cc
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/types.h>
+#include <sys/ucontext.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/logging.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr uint64_t kOrigRcx = 0xdeadbeeffacefeed;
+constexpr uint64_t kOrigR11 = 0xfacefeedbaad1dea;
+
+volatile int gotvtalrm, ready;
+
+void sigvtalrm(int sig, siginfo_t* siginfo, void* _uc) {
+ ucontext_t* uc = reinterpret_cast<ucontext_t*>(_uc);
+
+ // Verify that:
+ // - test is in the busy-wait loop waiting for signal.
+ // - %rcx and %r11 values in mcontext_t match kOrigRcx and kOrigR11.
+ if (ready &&
+ static_cast<uint64_t>(uc->uc_mcontext.gregs[REG_RCX]) == kOrigRcx &&
+ static_cast<uint64_t>(uc->uc_mcontext.gregs[REG_R11]) == kOrigR11) {
+ // Modify the values %rcx and %r11 in the ucontext. These are the
+ // values seen by the application after the signal handler returns.
+ uc->uc_mcontext.gregs[REG_RCX] = ~kOrigRcx;
+ uc->uc_mcontext.gregs[REG_R11] = ~kOrigR11;
+ gotvtalrm = 1;
+ }
+}
+
+TEST(SigIretTest, CheckRcxR11) {
+  // Set up a signal handler for SIGVTALRM.
+ struct sigaction sa = {};
+ sigfillset(&sa.sa_mask);
+ sa.sa_sigaction = sigvtalrm;
+ sa.sa_flags = SA_SIGINFO;
+ auto const action_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGVTALRM, sa));
+
+ auto const mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGVTALRM));
+
+  // Set up an itimer to fire after 500 msecs.
+ struct itimerval itimer = {};
+ itimer.it_value.tv_usec = 500 * 1000; // 500 msecs.
+ auto const timer_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_VIRTUAL, itimer));
+
+ // Initialize %rcx and %r11 and spin until the signal handler returns.
+ uint64_t rcx = kOrigRcx;
+ uint64_t r11 = kOrigR11;
+ asm volatile(
+ "movq %[rcx], %%rcx;" // %rcx = rcx
+ "movq %[r11], %%r11;" // %r11 = r11
+ "movl $1, %[ready];" // ready = 1
+ "1: pause; cmpl $0, %[gotvtalrm]; je 1b;" // while (!gotvtalrm);
+ "movq %%rcx, %[rcx];" // rcx = %rcx
+ "movq %%r11, %[r11];" // r11 = %r11
+ : [ ready ] "=m"(ready), [ rcx ] "+m"(rcx), [ r11 ] "+m"(r11)
+ : [ gotvtalrm ] "m"(gotvtalrm)
+ : "cc", "memory", "rcx", "r11");
+
+ // If sigreturn(2) returns via 'sysret' then %rcx and %r11 will be
+ // clobbered and set to 'ptregs->rip' and 'ptregs->rflags' respectively.
+ //
+ // The following check verifies that %rcx and %r11 were not clobbered
+ // when returning from the signal handler (via sigreturn(2)).
+ EXPECT_EQ(rcx, ~kOrigRcx);
+ EXPECT_EQ(r11, ~kOrigR11);
+}
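+
+// Background on the check above (x86-64 convention, not asserted by the
+// test): the syscall instruction saves the return RIP in %rcx and RFLAGS in
+// %r11, and sysret restores from them. A kernel that returned from
+// rt_sigreturn(2) via sysret would therefore clobber both registers, while
+// returning via iret restores them from the signal frame, which is what this
+// test verifies.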
+
+constexpr uint64_t kNonCanonicalRip = 0xCCCC000000000000;
+
+// Test that a non-canonical signal handler faults as expected.
+TEST(SigIretTest, BadHandler) {
+ struct sigaction sa = {};
+ sa.sa_sigaction =
+ reinterpret_cast<void (*)(int, siginfo_t*, void*)>(kNonCanonicalRip);
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa));
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ // Child, wait for signal.
+ while (1) {
+ pause();
+ }
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ EXPECT_THAT(kill(pid, SIGUSR1), SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV)
+ << "status = " << status;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // SigIretTest.CheckRcxR11 depends on delivering SIGVTALRM to the main thread.
+ // Block SIGVTALRM so that any other threads created by TestInit will also
+ // have SIGVTALRM blocked.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGVTALRM);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
new file mode 100644
index 000000000..389e5fca2
--- /dev/null
+++ b/test/syscalls/linux/signalfd.cc
@@ -0,0 +1,373 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <poll.h>
+#include <signal.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/signalfd.h>
+#include <unistd.h>
+
+#include <functional>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+using ::testing::KilledBySignal;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kSigno = SIGUSR1;
+constexpr int kSignoMax = 64; // SIGRTMAX
+constexpr int kSignoAlt = SIGUSR2;
+
+// Returns a new signalfd.
+inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) {
+ int fd = signalfd(-1, mask, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "signalfd");
+ }
+ return FileDescriptor(fd);
+}
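+
+// Note: passing -1 as the first argument to signalfd(2) creates a new
+// descriptor, while passing an existing signalfd atomically replaces its
+// mask, e.g. (as exercised by SetMask below):
+//
+//   sigaddset(&mask, signo);
+//   signalfd(fd.get(), &mask, 0);  // Replace the mask on an existing fd.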
+
+class SignalfdTest : public ::testing::TestWithParam<int> {};
+
+TEST_P(SignalfdTest, Basic) {
+ int signo = GetParam();
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, signo);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Deliver the blocked signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
+
+ // We should now read the signal.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, signo);
+}
+
+TEST_P(SignalfdTest, MaskWorks) {
+ int signo = GetParam();
+ // Create two signalfds with different masks.
+ sigset_t mask1, mask2;
+ sigemptyset(&mask1);
+ sigemptyset(&mask2);
+ sigaddset(&mask1, signo);
+ sigaddset(&mask2, kSignoAlt);
+ FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0));
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0));
+
+ // Deliver the two signals.
+ const auto scoped_sigmask1 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ const auto scoped_sigmask2 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds());
+
+ // We should see the signals on the appropriate signalfds.
+ //
+  // We read in the opposite order from that in which the signals were
+  // delivered above, to ensure that we don't just happen to read the correct
+  // signal from the correct signalfd.
+ struct signalfd_siginfo rbuf1, rbuf2;
+ ASSERT_THAT(read(fd2.get(), &rbuf2, sizeof(rbuf2)),
+ SyscallSucceedsWithValue(sizeof(rbuf2)));
+ EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt);
+ ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)),
+ SyscallSucceedsWithValue(sizeof(rbuf1)));
+ EXPECT_EQ(rbuf1.ssi_signo, signo);
+}
+
+TEST(Signalfd, Cloexec) {
+ // Exec tests confirm that O_CLOEXEC has the intended effect. We just create a
+ // signalfd with the appropriate flag here and assert that the FD has it set.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+ EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+TEST_P(SignalfdTest, Blocking) {
+ int signo = GetParam();
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, signo);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared tid variable.
+ absl::Mutex mu;
+ bool has_tid;
+ pid_t tid;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Copy the tid and notify the caller.
+ {
+ absl::MutexLock ml(&mu);
+ tid = gettid();
+ has_tid = true;
+ }
+
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, signo);
+ });
+
+ // Wait until blocked.
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&has_tid));
+
+ // Deliver the signal to either the waiting thread, or
+ // to this thread. N.B. this is a bug in the core gVisor
+ // behavior for signalfd, and needs to be fixed.
+ //
+ // See gvisor.dev/issue/139.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
+ } else {
+ ASSERT_THAT(tgkill(getpid(), tid, signo), SyscallSucceeds());
+ }
+
+ // Ensure that it was received.
+ t.Join();
+}
+
+TEST_P(SignalfdTest, ThreadGroup) {
+ int signo = GetParam();
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, signo);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared variable.
+ absl::Mutex mu;
+ bool first = false;
+ bool second = false;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, signo);
+
+ // Wait for the other thread.
+ absl::MutexLock ml(&mu);
+ first = true;
+ mu.Await(absl::Condition(&second));
+ });
+
+ // Deliver the signal to the threadgroup.
+ ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds());
+
+ // Wait for the first thread to process.
+ {
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&first));
+ }
+
+ // Deliver to the thread group again (other thread still exists).
+ ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds());
+
+ // Ensure that we can also receive it.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, signo);
+
+ // Mark the test as done.
+ {
+ absl::MutexLock ml(&mu);
+ second = true;
+ }
+
+ // The other thread should be joinable.
+ t.Join();
+}
+
+TEST_P(SignalfdTest, Nonblock) {
+ int signo = GetParam();
+ // Create the signalfd in non-blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, signo);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+  // An attempted read should fail immediately with EWOULDBLOCK.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Block and deliver the signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
+
+ // Ensure that a read actually works.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, signo);
+
+ // Should block again.
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(SignalfdTest, SetMask) {
+ int signo = GetParam();
+ // Create the signalfd matching nothing.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // Block and deliver a signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
+
+ // We should have nothing.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Change the signal mask.
+ sigaddset(&mask, signo);
+ ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds());
+
+ // We should now have the signal.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, signo);
+}
+
+TEST_P(SignalfdTest, Poll) {
+ int signo = GetParam();
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, signo);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Block the signal, and start a thread to deliver it.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ pid_t orig_tid = gettid();
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Seconds(5));
+ ASSERT_THAT(tgkill(getpid(), orig_tid, signo), SyscallSucceeds());
+ });
+
+  // Start polling for the signal. We expect that it is not available at the
+  // outset, but becomes available once the signal is sent. We give a timeout
+  // of 10000ms (the 5 second delay above plus 5 seconds of additional grace
+  // time).
+ struct pollfd poll_fd = {fd.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ // Actually read the signal to prevent delivery.
+ struct signalfd_siginfo rbuf;
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+}
+
+std::string PrintSigno(::testing::TestParamInfo<int> info) {
+ switch (info.param) {
+ case kSigno:
+ return "kSigno";
+ case kSignoMax:
+ return "kSignoMax";
+ default:
+ return absl::StrCat(info.param);
+ }
+}
+INSTANTIATE_TEST_SUITE_P(Signalfd, SignalfdTest,
+ ::testing::Values(kSigno, kSignoMax), PrintSigno);
+
+TEST(Signalfd, Ppoll) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Ensure that the given ppoll blocks.
+ struct pollfd pfd = {};
+ pfd.fd = fd.get();
+ pfd.events = POLLIN;
+ struct timespec timeout = {};
+ timeout.tv_sec = 1;
+ EXPECT_THAT(RetryEINTR(ppoll)(&pfd, 1, &timeout, &mask),
+ SyscallSucceedsWithValue(0));
+}
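+
+// SIGKILL may be placed in a signalfd mask, but it can never actually be
+// delivered through the descriptor (it is unblockable), so the ppoll above
+// is expected to run out its 1 second timeout and return 0.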
+
+TEST(Signalfd, KillStillKills) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Just because there is a signalfd, we shouldn't see any change in behavior
+ // for unblockable signals. It's easier to test this with SIGKILL.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL));
+ EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+  // These tests depend on delivering signals. Block them up front so that all
+  // other threads created by TestInit will also have them blocked, and they
+  // will not interfere with the rest of the test.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, gvisor::testing::kSigno);
+ sigaddset(&set, gvisor::testing::kSignoMax);
+ sigaddset(&set, gvisor::testing::kSignoAlt);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc
new file mode 100644
index 000000000..a603fc1d1
--- /dev/null
+++ b/test/syscalls/linux/sigprocmask.cc
@@ -0,0 +1,269 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <stddef.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Signals numbers used for testing.
+static constexpr int kTestSignal1 = SIGUSR1;
+static constexpr int kTestSignal2 = SIGUSR2;
+
+static int raw_sigprocmask(int how, const sigset_t* set, sigset_t* oldset) {
+ return syscall(SYS_rt_sigprocmask, how, set, oldset, _NSIG / 8);
+}
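+
+// The final argument is the size of the kernel's sigset_t, _NSIG / 8 (8 bytes
+// on x86-64). Note this is much smaller than glibc's userspace sigset_t (128
+// bytes); passing sizeof(sigset_t) here would fail with EINVAL.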
+
+// Count of the number of signals received.
+int signal_count[kMaxSignal + 1];
+
+// Signal handler that increments the signal counter.
+void SigHandler(int sig, siginfo_t* info, void* context) {
+ TEST_CHECK(sig > 0 && sig <= kMaxSignal);
+ signal_count[sig] += 1;
+}
+
+// The test fixture saves and restores the signal mask and
+// sets up handlers for kTestSignal1 and kTestSignal2.
+class SigProcMaskTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Save the current signal mask.
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &mask_),
+ SyscallSucceeds());
+
+    // Set up signal handlers for kTestSignal1 and kTestSignal2.
+ struct sigaction sa;
+ sa.sa_sigaction = SigHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ EXPECT_THAT(sigaction(kTestSignal1, &sa, &sa_test_sig_1_),
+ SyscallSucceeds());
+ EXPECT_THAT(sigaction(kTestSignal2, &sa, &sa_test_sig_2_),
+ SyscallSucceeds());
+
+ // Clear the signal counters.
+ memset(signal_count, 0, sizeof(signal_count));
+ }
+
+ void TearDown() override {
+ // Restore the signal mask.
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &mask_, nullptr),
+ SyscallSucceeds());
+
+ // Restore the signal handlers for kTestSignal1 and kTestSignal2.
+ EXPECT_THAT(sigaction(kTestSignal1, &sa_test_sig_1_, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(sigaction(kTestSignal2, &sa_test_sig_2_, nullptr),
+ SyscallSucceeds());
+ }
+
+ private:
+ sigset_t mask_;
+ struct sigaction sa_test_sig_1_;
+ struct sigaction sa_test_sig_2_;
+};
+
+// Both sigsets nullptr should succeed and do nothing.
+TEST_F(SigProcMaskTest, NullAddress) {
+ EXPECT_THAT(raw_sigprocmask(SIG_BLOCK, nullptr, NULL), SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_UNBLOCK, nullptr, NULL), SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, NULL), SyscallSucceeds());
+}
+
+// Bad address for either sigset should fail with EFAULT.
+TEST_F(SigProcMaskTest, BadAddress) {
+ sigset_t* bad_addr = reinterpret_cast<sigset_t*>(-1);
+
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, bad_addr, nullptr),
+ SyscallFailsWithErrno(EFAULT));
+
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, bad_addr),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+// Bad value of the "how" parameter should fail with EINVAL.
+TEST_F(SigProcMaskTest, BadParameter) {
+ int bad_param_1 = -1;
+ int bad_param_2 = 42;
+
+ sigset_t set1;
+ sigemptyset(&set1);
+
+ EXPECT_THAT(raw_sigprocmask(bad_param_1, &set1, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ EXPECT_THAT(raw_sigprocmask(bad_param_2, &set1, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Check that we can get the current signal mask.
+TEST_F(SigProcMaskTest, GetMask) {
+ sigset_t set1;
+ sigset_t set2;
+
+ sigemptyset(&set1);
+ sigfillset(&set2);
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &set1), SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &set2), SyscallSucceeds());
+ EXPECT_THAT(set1, EqualsSigset(set2));
+}
+
+// Check that we can set the signal mask.
+TEST_F(SigProcMaskTest, SetMask) {
+ sigset_t actual;
+ sigset_t expected;
+
+ // Try to mask all signals
+ sigfillset(&expected);
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual),
+ SyscallSucceeds());
+ // sigprocmask() should have silently ignored SIGKILL and SIGSTOP.
+ sigdelset(&expected, SIGSTOP);
+ sigdelset(&expected, SIGKILL);
+ EXPECT_THAT(actual, EqualsSigset(expected));
+
+ // Try to clear the signal mask
+ sigemptyset(&expected);
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual),
+ SyscallSucceeds());
+ EXPECT_THAT(actual, EqualsSigset(expected));
+
+ // Try to set a mask with one signal.
+ sigemptyset(&expected);
+ sigaddset(&expected, kTestSignal1);
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual),
+ SyscallSucceeds());
+ EXPECT_THAT(actual, EqualsSigset(expected));
+}
+
+// Check that we can add and remove signals.
+TEST_F(SigProcMaskTest, BlockUnblock) {
+ sigset_t actual;
+ sigset_t expected;
+
+ // Try to set a mask with one signal.
+ sigemptyset(&expected);
+ sigaddset(&expected, kTestSignal1);
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &expected, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual),
+ SyscallSucceeds());
+ EXPECT_THAT(actual, EqualsSigset(expected));
+
+ // Try to add another signal.
+ sigset_t block;
+ sigemptyset(&block);
+ sigaddset(&block, kTestSignal2);
+ EXPECT_THAT(raw_sigprocmask(SIG_BLOCK, &block, nullptr), SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual),
+ SyscallSucceeds());
+ sigaddset(&expected, kTestSignal2);
+ EXPECT_THAT(actual, EqualsSigset(expected));
+
+ // Try to remove a signal.
+ sigset_t unblock;
+ sigemptyset(&unblock);
+ sigaddset(&unblock, kTestSignal1);
+ EXPECT_THAT(raw_sigprocmask(SIG_UNBLOCK, &unblock, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, nullptr, &actual),
+ SyscallSucceeds());
+ sigdelset(&expected, kTestSignal1);
+ EXPECT_THAT(actual, EqualsSigset(expected));
+}
+
+// Test that the signal mask actually blocks signals.
+TEST_F(SigProcMaskTest, SignalHandler) {
+ sigset_t mask;
+
+  // Clear the signal mask.
+ sigemptyset(&mask);
+ EXPECT_THAT(raw_sigprocmask(SIG_SETMASK, &mask, nullptr), SyscallSucceeds());
+
+ // Check the initial signal counts.
+ EXPECT_EQ(0, signal_count[kTestSignal1]);
+ EXPECT_EQ(0, signal_count[kTestSignal2]);
+
+ // Check that both kTestSignal1 and kTestSignal2 are not blocked.
+ raise(kTestSignal1);
+ raise(kTestSignal2);
+ EXPECT_EQ(1, signal_count[kTestSignal1]);
+ EXPECT_EQ(1, signal_count[kTestSignal2]);
+
+ // Block kTestSignal1.
+ sigaddset(&mask, kTestSignal1);
+ EXPECT_THAT(raw_sigprocmask(SIG_BLOCK, &mask, nullptr), SyscallSucceeds());
+
+ // Check that kTestSignal1 is blocked.
+ raise(kTestSignal1);
+ raise(kTestSignal2);
+ EXPECT_EQ(1, signal_count[kTestSignal1]);
+ EXPECT_EQ(2, signal_count[kTestSignal2]);
+
+ // Unblock kTestSignal1.
+ sigaddset(&mask, kTestSignal1);
+ EXPECT_THAT(raw_sigprocmask(SIG_UNBLOCK, &mask, nullptr), SyscallSucceeds());
+
+ // Check that the unblocked kTestSignal1 has been delivered.
+ EXPECT_EQ(2, signal_count[kTestSignal1]);
+ EXPECT_EQ(2, signal_count[kTestSignal2]);
+}
+
+// Check that sigprocmask correctly handles aliasing of the set and oldset
+// pointers. Regression test for b/30502311.
+TEST_F(SigProcMaskTest, AliasedSets) {
+ sigset_t mask;
+
+ // Set a mask in which only kTestSignal1 is blocked.
+ sigset_t mask1;
+ sigemptyset(&mask1);
+ sigaddset(&mask1, kTestSignal1);
+ mask = mask1;
+ ASSERT_THAT(raw_sigprocmask(SIG_SETMASK, &mask, nullptr), SyscallSucceeds());
+
+ // Exchange it with a mask in which only kTestSignal2 is blocked.
+ sigset_t mask2;
+ sigemptyset(&mask2);
+ sigaddset(&mask2, kTestSignal2);
+ mask = mask2;
+ ASSERT_THAT(raw_sigprocmask(SIG_SETMASK, &mask, &mask), SyscallSucceeds());
+
+ // Check that the exchange succeeded:
+ // mask should now contain the previously-set mask blocking only kTestSignal1.
+ EXPECT_THAT(mask, EqualsSigset(mask1));
+ // The current mask should block only kTestSignal2.
+ ASSERT_THAT(raw_sigprocmask(0, nullptr, &mask), SyscallSucceeds());
+ EXPECT_THAT(mask, EqualsSigset(mask2));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
new file mode 100644
index 000000000..b2fcedd62
--- /dev/null
+++ b/test/syscalls/linux/sigstop.cc
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <stdlib.h>
+#include <sys/select.h>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(bool, sigstop_test_child, false,
+ "If true, run the SigstopTest child workload.");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr absl::Duration kChildStartupDelay = absl::Seconds(5);
+constexpr absl::Duration kChildMainThreadDelay = absl::Seconds(10);
+constexpr absl::Duration kChildExtraThreadDelay = absl::Seconds(15);
+constexpr absl::Duration kPostSIGSTOPDelay = absl::Seconds(20);
+
+// Comparisons on absl::Duration aren't yet constexpr (2017-07-14), so we
+// can't just use static_assert.
+TEST(SigstopTest, TimesAreRelativelyConsistent) {
+ EXPECT_LT(kChildStartupDelay, kChildMainThreadDelay)
+ << "Child process will exit before the parent process attempts to stop "
+ "it";
+ EXPECT_LT(kChildMainThreadDelay, kChildExtraThreadDelay)
+ << "Secondary thread in child process will exit before main thread, "
+ "causing it to exit with the wrong code";
+ EXPECT_LT(kChildExtraThreadDelay, kPostSIGSTOPDelay)
+ << "Parent process stops waiting before child process may exit if "
+ "improperly stopped, rendering the test ineffective";
+}
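+
+// If absl::Duration comparisons were constexpr, the checks above could be
+// compile-time assertions instead; a hypothetical sketch:
+//
+//   static_assert(kChildStartupDelay < kChildMainThreadDelay,
+//                 "child exits before the parent stops it");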
+
+// Exit codes communicated from the child workload to the parent test process.
+constexpr int kChildMainThreadExitCode = 10;
+constexpr int kChildExtraThreadExitCode = 11;
+
+TEST(SigstopTest, Correctness) {
+ pid_t child_pid = -1;
+ int execve_errno = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/proc/self/exe", {"/proc/self/exe", "--sigstop_test_child"},
+ {}, nullptr, &child_pid, &execve_errno));
+
+ ASSERT_GT(child_pid, 0);
+ ASSERT_EQ(execve_errno, 0);
+
+ // Wait for the child subprocess to start the second thread before stopping
+ // it.
+ absl::SleepFor(kChildStartupDelay);
+ ASSERT_THAT(kill(child_pid, SIGSTOP), SyscallSucceeds());
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child_pid, &status, WUNTRACED),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFSTOPPED(status));
+ EXPECT_EQ(SIGSTOP, WSTOPSIG(status));
+
+ // Sleep for longer than either of the sleeps in the child subprocess,
+ // expecting the child to stay alive because it's stopped.
+ absl::SleepFor(kPostSIGSTOPDelay);
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, WNOHANG),
+ SyscallSucceedsWithValue(0));
+
+ // Resume the child.
+ ASSERT_THAT(kill(child_pid, SIGCONT), SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(waitpid)(child_pid, &status, WCONTINUED),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFCONTINUED(status));
+
+ // Expect it to die.
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ ASSERT_TRUE(WIFEXITED(status));
+ ASSERT_EQ(WEXITSTATUS(status), kChildMainThreadExitCode);
+}
+
+// Like absl::SleepFor, but tries to avoid counting time spent stopped due to a
+// stop signal toward the sleep.
+//
+// This is required due to an inconsistency in how nanosleep(2) and stop signals
+// interact on Linux. When nanosleep is interrupted, it writes the remaining
+// time back to its second timespec argument, so that if nanosleep is
+// interrupted by a signal handler then userspace can immediately call nanosleep
+// again with that timespec. However, if nanosleep is automatically restarted
+// (because it's interrupted by a signal that is not delivered to a handler,
+// such as a stop signal), it's restarted based on the timer's former *absolute*
+// expiration time (via ERESTART_RESTARTBLOCK => SYS_restart_syscall =>
+// hrtimer_nanosleep_restart). This means that time spent stopped is effectively
+// counted as time spent sleeping, resulting in less time spent sleeping than
+// expected.
+//
+// Dividing the sleep into multiple smaller sleeps limits the impact of this
+// effect to the length of each sleep during which a stop occurs; for example,
+// if a sleeping process is only stopped once, SleepIgnoreStopped can
+// under-sleep by at most 100ms.
+void SleepIgnoreStopped(absl::Duration d) {
+ absl::Duration const max_sleep = absl::Milliseconds(100);
+ while (d > absl::ZeroDuration()) {
+ absl::Duration to_sleep = std::min(d, max_sleep);
+ absl::SleepFor(to_sleep);
+ d -= to_sleep;
+ }
+}
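+
+// For contrast, a sketch of the standard retry loop that nanosleep(2)'s
+// "remaining time" out-parameter enables when a handled signal interrupts
+// the sleep (illustrative only; standard nanosleep semantics assumed):
+//
+//   struct timespec req = {1, 0};
+//   struct timespec rem;
+//   while (nanosleep(&req, &rem) == -1 && errno == EINTR) {
+//     req = rem;  // Resume with the remaining relative time.
+//   }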
+
+void RunChild() {
+ // Start another thread that attempts to call exit_group with a different
+ // exit code, in order to verify that SIGSTOP stops this thread as well.
+ ScopedThread t([] {
+ SleepIgnoreStopped(kChildExtraThreadDelay);
+ exit(kChildExtraThreadExitCode);
+ });
+ SleepIgnoreStopped(kChildMainThreadDelay);
+ exit(kChildMainThreadExitCode);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ if (absl::GetFlag(FLAGS_sigstop_test_child)) {
+ gvisor::testing::RunChild();
+ return 1;
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc
new file mode 100644
index 000000000..4f8afff15
--- /dev/null
+++ b/test/syscalls/linux/sigtimedwait.cc
@@ -0,0 +1,330 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// N.B. main() blocks SIGALRM and SIGCHLD on all threads.
+
+constexpr int kAlarmSecs = 12;
+
+void NoopHandler(int sig, siginfo_t* info, void* context) {}
+
+TEST(SigtimedwaitTest, InvalidTimeout) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ struct timespec timeout = {0, 1000000001};
+ EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout),
+ SyscallFailsWithErrno(EINVAL));
+ timeout = {-1, 0};
+ EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout),
+ SyscallFailsWithErrno(EINVAL));
+ timeout = {0, -1};
+ EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout),
+ SyscallFailsWithErrno(EINVAL));
+}
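+
+// For reference, a hypothetical predicate capturing the validity rule the
+// cases above probe (assuming Linux's usual timespec checks):
+//
+//   bool TimeoutIsValid(const struct timespec& ts) {
+//     return ts.tv_sec >= 0 && ts.tv_nsec >= 0 && ts.tv_nsec <= 999999999;
+//   }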
+
+// No random save as the test relies on alarm timing. Cooperative save tests
+// already cover the save between alarm and wait.
+TEST(SigtimedwaitTest, AlarmReturnsAlarm_NoRandomSave) {
+ struct itimerval itv = {};
+ itv.it_value.tv_sec = kAlarmSecs;
+ const auto itimer_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, itv));
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGALRM);
+ siginfo_t info = {};
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, &info, nullptr),
+ SyscallSucceedsWithValue(SIGALRM));
+ EXPECT_EQ(SIGALRM, info.si_signo);
+}
+
+// No random save as the test relies on alarm timing. Cooperative save tests
+// already cover the save between alarm and wait.
+TEST(SigtimedwaitTest, NullTimeoutReturnsEINTR_NoRandomSave) {
+ struct sigaction sa;
+ sa.sa_sigaction = NoopHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ const auto action_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
+
+ const auto mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGALRM));
+
+ struct itimerval itv = {};
+ itv.it_value.tv_sec = kAlarmSecs;
+ const auto itimer_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, itv));
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ EXPECT_THAT(sigtimedwait(&mask, nullptr, nullptr),
+ SyscallFailsWithErrno(EINTR));
+}
+
+TEST(SigtimedwaitTest, LegitTimeoutReturnsEAGAIN) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ struct timespec timeout = {1, 0}; // 1 second
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &timeout),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(SigtimedwaitTest, ZeroTimeoutReturnsEAGAIN) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ struct timespec timeout = {0, 0}; // 0 seconds
+ EXPECT_THAT(sigtimedwait(&mask, nullptr, &timeout),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(SigtimedwaitTest, KillGeneratedSIGCHLD) {
+ EXPECT_THAT(kill(getpid(), SIGCHLD), SyscallSucceeds());
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGCHLD);
+ struct timespec ts = {5, 0};
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &ts),
+ SyscallSucceedsWithValue(SIGCHLD));
+}
+
+TEST(SigtimedwaitTest, ChildExitGeneratedSIGCHLD) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status;
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGCHLD);
+ struct timespec ts = {5, 0};
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &ts),
+ SyscallSucceedsWithValue(SIGCHLD));
+}
+
+TEST(SigtimedwaitTest, ChildExitGeneratedSIGCHLDWithHandler) {
+ // Setup handler for SIGCHLD, but don't unblock it.
+ struct sigaction sa;
+ sa.sa_sigaction = NoopHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ const auto action_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa));
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGCHLD);
+ struct timespec ts = {5, 0};
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &ts),
+ SyscallSucceedsWithValue(SIGCHLD));
+
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status;
+}
+
+// sigtimedwait cannot catch SIGKILL.
+TEST(SigtimedwaitTest, SIGKILLUncaught) {
+ // This is a regression test for sigtimedwait dequeuing SIGKILLs, thus
+ // preventing the task from exiting.
+ //
+ // The explanation below is specific to behavior in gVisor. The Linux behavior
+ // here is irrelevant because without a bug that prevents delivery of SIGKILL,
+ // none of this behavior is visible (in Linux or gVisor).
+ //
+ // SIGKILL is rather intrusive. Simply sending the SIGKILL marks
+ // ThreadGroup.exitStatus as exiting with SIGKILL, before the SIGKILL is even
+ // delivered.
+ //
+ // As a result, we cannot simply exit the child with a different exit code if
+ // it survives and expect to see that code in waitpid because:
+ // 1. PrepareGroupExit will override Task.exitStatus with
+ // ThreadGroup.exitStatus.
+ // 2. waitpid(2) will always return ThreadGroup.exitStatus rather than
+ // Task.exitStatus.
+ //
+ // We could use exit(2) to set Task.exitStatus without override, and a SIGCHLD
+ // handler to receive Task.exitStatus in the parent, but with that much
+ // test complexity, it is cleaner to simply use a pipe to notify the parent
+ // that we survived.
+ constexpr auto kSigtimedwaitSetupTime = absl::Seconds(2);
+
+ int pipe_fds[2];
+ ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
+ FileDescriptor rfd(pipe_fds[0]);
+ FileDescriptor wfd(pipe_fds[1]);
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ rfd.reset();
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ RetryEINTR(sigtimedwait)(&mask, nullptr, nullptr);
+
+ // Survived.
+ char c = 'a';
+ TEST_PCHECK(WriteFd(wfd.get(), &c, 1) == 1);
+ _exit(1);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ wfd.reset();
+
+ // Wait for child to block in sigtimedwait, then kill it.
+ absl::SleepFor(kSigtimedwaitSetupTime);
+
+ // Sending SIGKILL will attempt to enqueue the signal twice: once in the
+ // normal signal sending path, and once to all Tasks in the ThreadGroup when
+ // applying SIGKILL side-effects.
+ //
+ // If we use kill(2), the former will be on the ThreadGroup signal queue and
+ // the latter will be on the Task signal queue. sigtimedwait can only dequeue
+ // one signal, so the other would kill the Task, masking bugs.
+ //
+ // If we use tkill(2), the former will be on the Task signal queue and the
+ // latter will be dropped as a duplicate. Then sigtimedwait can theoretically
+ // dequeue the single SIGKILL.
+ EXPECT_THAT(syscall(SYS_tkill, pid, SIGKILL), SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) << status;
+
+ // Child shouldn't have survived.
+ char c;
+ EXPECT_THAT(ReadFd(rfd.get(), &c, 1), SyscallSucceedsWithValue(0));
+}
+
+TEST(SigtimedwaitTest, IgnoredUnmaskedSignal) {
+ constexpr int kSigno = SIGUSR1;
+ constexpr auto kSigtimedwaitSetupTime = absl::Seconds(2);
+ constexpr auto kSigtimedwaitTimeout = absl::Seconds(5);
+ ASSERT_GT(kSigtimedwaitTimeout, kSigtimedwaitSetupTime);
+
+ // Ensure that kSigno is ignored, and unmasked on this thread.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ const auto scoped_sigaction =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa));
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, mask));
+
+ // Create a thread which will send us kSigno while we are blocked in
+ // sigtimedwait.
+ pid_t tid = gettid();
+ ScopedThread sigthread([&] {
+ absl::SleepFor(kSigtimedwaitSetupTime);
+ EXPECT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds());
+ });
+
+ // sigtimedwait should not observe kSigno since it is ignored and already
+ // unmasked, causing it to be dropped before it is enqueued.
+ struct timespec timeout_ts = absl::ToTimespec(kSigtimedwaitTimeout);
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &timeout_ts),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(SigtimedwaitTest, IgnoredMaskedSignal) {
+ constexpr int kSigno = SIGUSR1;
+ constexpr auto kSigtimedwaitSetupTime = absl::Seconds(2);
+ constexpr auto kSigtimedwaitTimeout = absl::Seconds(5);
+ ASSERT_GT(kSigtimedwaitTimeout, kSigtimedwaitSetupTime);
+
+ // Ensure that kSigno is ignored, and masked on this thread.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ const auto scoped_sigaction =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa));
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask));
+
+ // Create a thread which will send us kSigno while we are blocked in
+ // sigtimedwait.
+ pid_t tid = gettid();
+ ScopedThread sigthread([&] {
+ absl::SleepFor(kSigtimedwaitSetupTime);
+ EXPECT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds());
+ });
+
+ // sigtimedwait should observe kSigno since it is normally masked, causing it
+ // to be enqueued despite being ignored.
+ struct timespec timeout_ts = absl::ToTimespec(kSigtimedwaitTimeout);
+ EXPECT_THAT(RetryEINTR(sigtimedwait)(&mask, nullptr, &timeout_ts),
+ SyscallSucceedsWithValue(kSigno));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // These tests depend on SIGALRM/SIGCHLD either being delivered to the main
+ // thread or being consumed by sigtimedwait. Block them so that any other
+ // threads created by TestInit will also have them blocked.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGALRM);
+ sigaddset(&set, SIGCHLD);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
new file mode 100644
index 000000000..c20cd3fcc
--- /dev/null
+++ b/test/syscalls/linux/socket.cc
@@ -0,0 +1,123 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_umask.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST(SocketTest, UnixSocketPairProtocol) {
+ int socks[2];
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks),
+ SyscallSucceeds());
+ close(socks[0]);
+ close(socks[1]);
+}
+
+TEST(SocketTest, ProtocolUnix) {
+ struct {
+ int domain, type, protocol;
+ } tests[] = {
+ {AF_UNIX, SOCK_STREAM, PF_UNIX},
+ {AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
+ {AF_UNIX, SOCK_DGRAM, PF_UNIX},
+ };
+ for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(tests[i].domain, tests[i].type, tests[i].protocol));
+ }
+}
+
+TEST(SocketTest, ProtocolInet) {
+ struct {
+ int domain, type, protocol;
+ } tests[] = {
+ {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
+ {AF_INET, SOCK_STREAM, IPPROTO_TCP},
+ };
+ for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(tests[i].domain, tests[i].type, tests[i].protocol));
+ }
+}
+
+TEST(SocketTest, UnixSocketStat) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ // The permissions of the file created with bind(2) should be defined by the
+ // permissions of the bound socket and the umask.
+ mode_t sock_perm = 0765, mask = 0123;
+ ASSERT_THAT(fchmod(bound.get(), sock_perm), SyscallSucceeds());
+ TempUmask m(mask);
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ struct stat statbuf = {};
+ ASSERT_THAT(stat(addr.sun_path, &statbuf), SyscallSucceeds());
+
+ // Mode should be S_IFSOCK.
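+ // With the sock_perm and mask values above, the permission bits work out
+ // to 0765 & ~0123 == 0644.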
+ EXPECT_EQ(statbuf.st_mode, S_IFSOCK | (sock_perm & ~mask));
+
+ // Timestamps should be equal and non-zero.
+ // TODO(b/158882152): Sockets currently don't implement timestamps.
+ if (!IsRunningOnGvisor()) {
+ EXPECT_NE(statbuf.st_atime, 0);
+ EXPECT_EQ(statbuf.st_atime, statbuf.st_mtime);
+ EXPECT_EQ(statbuf.st_atime, statbuf.st_ctime);
+ }
+}
+
+using SocketOpenTest = ::testing::TestWithParam<int>;
+
+// A UDS (Unix domain socket) cannot be opened with open(2).
+TEST_P(SocketOpenTest, Unix) {
+ // FIXME(b/142001530): Open incorrectly succeeds on gVisor.
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(open(addr.sun_path, GetParam()), SyscallFailsWithErrno(ENXIO));
+}
+
+INSTANTIATE_TEST_SUITE_P(OpenModes, SocketOpenTest,
+ ::testing::Values(O_RDONLY, O_RDWR));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc
new file mode 100644
index 000000000..00999f192
--- /dev/null
+++ b/test/syscalls/linux/socket_abstract.cc
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_generic.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/socket_unix.h"
+#include "test/syscalls/linux/socket_unix_cmsg.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AbstractUnixSockets, AllSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ AbstractUnixSockets, UnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ AbstractUnixSockets, UnixSocketPairCmsgTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc
new file mode 100644
index 000000000..6b27f6eab
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device.cc
@@ -0,0 +1,313 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+// Test fixture for SO_BINDTODEVICE tests.
+class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+
+ interface_name_ = "eth1";
+ auto interface_names = GetInterfaceNames();
+ if (interface_names.find(interface_name_) == interface_names.end()) {
+ // Need a tunnel.
+ tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New());
+ interface_name_ = tunnel_->GetName();
+ ASSERT_FALSE(interface_name_.empty());
+ }
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create());
+ }
+
+ string interface_name() const { return interface_name_; }
+
+ int socket_fd() const { return socket_->get(); }
+
+ private:
+ std::unique_ptr<Tunnel> tunnel_;
+ string interface_name_;
+ std::unique_ptr<FileDescriptor> socket_;
+};
+
+constexpr char kIllegalIfnameChar = '/';
+
+// Tests getsockopt of the default value.
+TEST_P(BindToDeviceTest, GetsockoptDefault) {
+ char name_buffer[IFNAMSIZ * 2];
+ char original_name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Read the default SO_BINDTODEVICE.
+ memset(original_name_buffer, kIllegalIfnameChar, sizeof(original_name_buffer));
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(name_buffer_size, 0);
+ EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)),
+ 0);
+ }
+}
+
+// Tests setsockopt with an invalid device name.
+TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Set an invalid device name.
+ memset(name_buffer, kIllegalIfnameChar, 5);
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// Tests setsockopt with a buffer containing a valid device name that is not
+// null-terminated, at various buffer sizes.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Intentionally overwrite the null at the end.
+ memset(name_buffer + interface_name().size(), kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size());
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is exactly right.
+ if (name_buffer_size == interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests setsockopt with a buffer containing a valid, null-terminated device
+// name, at various buffer sizes.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Don't overwrite the null at the end.
+ memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size() - 1);
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is at least the right size.
+ if (name_buffer_size >= interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests that setsockopt of an invalid device name doesn't unset the previous
+// valid setsockopt.
+TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Write unsuccessfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallFailsWithErrno(ENODEV));
+
+ // Read it back successfully, it's unchanged.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+}
+
+// Tests that setsockopt of zero-length string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClear) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ name_buffer_size = 0;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests that setsockopt of empty string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer[0] = 0;
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests getsockopt with different buffer sizes.
+TEST_P(BindToDeviceTest, GetsockoptDevice) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back at various buffer sizes.
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // Linux requires a buffer of at least IFNAMSIZ bytes, even if fewer bytes
+ // would suffice for this interface name.
+ if (name_buffer_size >= IFNAMSIZ) {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+ } else {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(name_buffer_size, i);
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
new file mode 100644
index 000000000..5ed57625c
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -0,0 +1,401 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <atomic>
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+struct EndpointConfig {
+ std::string bind_to_device;
+ double expected_ratio;
+};
+
+struct DistributionTestCase {
+ std::string name;
+ std::vector<EndpointConfig> endpoints;
+};
+
+struct ListenerConnector {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+// Test fixture for SO_BINDTODEVICE tests that check the distribution of
+// packets received across sockets with varying SO_BINDTODEVICE settings.
+class BindToDeviceDistributionTest
+ : public ::testing::TestWithParam<
+ ::testing::tuple<ListenerConnector, DistributionTestCase>> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s, listener=%s, connector=%s\n",
+ ::testing::get<1>(GetParam()).name.c_str(),
+ ::testing::get<0>(GetParam()).listener.description.c_str(),
+ ::testing::get<0>(GetParam()).connector.description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ }
+};
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+// Binds sockets to different devices and then creates many TCP connections.
+// Checks that the distribution of connections received on the sockets matches
+// the expectation.
+TEST_P(BindToDeviceDistributionTest, Tcp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> accept_counts(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> listen_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ listen_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, &connects_received, i,
+ kConnectAttempts]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ // Another thread has shut down our read side, causing the accept to
+ // fail.
+ ASSERT_GE(connects_received, kConnectAttempts)
+ << "errno = " << fd.error();
+ return;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack, where read incorrectly assumes a whole segment
+ // is consumed when endpoint.Read() is called, even though the syscall
+ // that invoked endpoint.Read() may only consume it partially. This
+ // results in a case where a close() of such a socket does not trigger
+ // a RST in netstack, because the endpoint assumes it has no unread
+ // data.
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& listen_thread : listen_threads) {
+ listen_thread->Join();
+ }
+ // Check that connections are distributed correctly among listening sockets.
+ for (int i = 0; i < accept_counts.size(); i++) {
+ EXPECT_THAT(
+ accept_counts[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+// Binds sockets to different devices and then sends many UDP packets. Checks
+// that the distribution of packets received on the sockets matches the
+// expectation.
+TEST_P(BindToDeviceDistributionTest, Udp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening socket.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> packets_per_socket(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ receiver_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, &packets_received, i]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shut down our read side, causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread,
+ // otherwise the main thread can send more than can fit into receive
+ // queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& receiver_thread : receiver_threads) {
+ receiver_thread->Join();
+ }
+ // Check that packets are distributed correctly among listening sockets.
+ for (int i = 0; i < packets_per_socket.size(); i++) {
+ EXPECT_THAT(
+ packets_per_socket[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+std::vector<DistributionTestCase> GetDistributionTestCases() {
+ return std::vector<DistributionTestCase>{
+ {"Even distribution among sockets not bound to device",
+ {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}},
+ {"Sockets bound to other interfaces get no packets",
+ {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}},
+ {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}},
+ {"Even distribution among sockets bound to device",
+ {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}},
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BindToDeviceTest, BindToDeviceDistributionTest,
+ ::testing::Combine(::testing::Values(
+ // Listeners bound to IPv4 addresses refuse
+ // connections using IPv6 addresses.
+ ListenerConnector{V4Any(), V4Loopback()},
+ ListenerConnector{V4Loopback(), V4MappedLoopback()}),
+ ::testing::ValuesIn(GetDistributionTestCases())));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
new file mode 100644
index 000000000..d3cc71dbf
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -0,0 +1,513 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+// Test fixture for SO_BINDTODEVICE tests that check the results of sequences
+// of socket bind operations.
+class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ socket_factory_ = GetParam();
+
+ interface_names_ = GetInterfaceNames();
+ }
+
+ PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
+ return socket_factory_.Create();
+ }
+
+ // Gets a device by device_id. If the device_id has been seen before, returns
+ // the previously returned device. If not, finds or creates a new device.
+ // Registers a fatal test failure if no device can be obtained.
+ void GetDevice(int device_id, string* device_name) {
+ auto device = devices_.find(device_id);
+ if (device != devices_.end()) {
+ *device_name = device->second;
+ return;
+ }
+
+ // Need to pick a new device. Try ethernet first.
+ *device_name = absl::StrCat("eth", next_unused_eth_);
+ if (interface_names_.find(*device_name) != interface_names_.end()) {
+ devices_[device_id] = *device_name;
+ next_unused_eth_++;
+ return;
+ }
+
+ // Need to make a new tunnel device. gVisor tests should have enough
+ // ethernet devices to never reach here.
+ ASSERT_FALSE(IsRunningOnGvisor());
+ // Need a tunnel.
+ tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()));
+ devices_[device_id] = tunnels_.back()->GetName();
+ *device_name = devices_[device_id];
+ }
+
+ // Releases the socket.
+ void ReleaseSocket(int socket_id) {
+ // Close the socket that was made in a previous action. The socket_id
+ // indicates which socket to close based on index into the list of actions.
+ sockets_to_close_.erase(socket_id);
+ }
+
+ // SetDevice changes the bind_to_device option. It does not bind or re-bind.
+ void SetDevice(int socket_id, int device_id) {
+ auto socket_fd = sockets_to_close_[socket_id]->get();
+ string device_name;
+ ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+ device_name.c_str(), device_name.size() + 1),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // Binds a socket with the given reuse and bind_to_device options. Checks
+ // that all steps succeed and that bind()'s error matches want.
+ // If socket_id is not nullptr, sets *socket_id to uniquely identify the
+ // socket that was bound.
+ void BindSocket(bool reuse_port, bool reuse_addr, int device_id = 0,
+ int want = 0, int* socket_id = nullptr) {
+ next_socket_id_++;
+ sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket_fd = sockets_to_close_[next_socket_id_]->get();
+ if (socket_id != nullptr) {
+ *socket_id = next_socket_id_;
+ }
+
+ // If reuse_port is indicated, do that.
+ if (reuse_port) {
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // If reuse_addr is indicated, do that.
+ if (reuse_addr) {
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // If the device is non-zero, bind to that device.
+ if (device_id != 0) {
+ string device_name;
+ ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+ device_name.c_str(), device_name.size() + 1),
+ SyscallSucceedsWithValue(0));
+ char get_device[100];
+ socklen_t get_device_size = 100;
+ EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device,
+ &get_device_size),
+ SyscallSucceedsWithValue(0));
+ }
+
+ struct sockaddr_in addr = {};
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr.sin_port = port_;
+ if (want == 0) {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ } else {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(want));
+ }
+
+ if (port_ == 0) {
+ // We don't yet know what port we'll be using so we need to fetch it and
+ // remember it for future commands.
+ socklen_t addr_size = sizeof(addr);
+ ASSERT_THAT(
+ getsockname(socket_fd, reinterpret_cast<struct sockaddr*>(&addr),
+ &addr_size),
+ SyscallSucceeds());
+ port_ = addr.sin_port;
+ }
+ }
+
+ private:
+ SocketKind socket_factory_;
+ // devices_ maps from the device id in the test case to the device name.
+ absl::node_hash_map<int, string> devices_;
+ // These are the tunnels that were created for the test and will be destroyed
+ // by the destructor.
+ vector<std::unique_ptr<Tunnel>> tunnels_;
+ // A list of all interface names before the test started.
+ std::unordered_set<string> interface_names_;
+ // The next ethernet device to use when a new device is requested.
+ int next_unused_eth_ = 1;
+ // The port for all tests. Originally 0 (any) and later set to the port that
+ // all further commands will use.
+ in_port_t port_ = 0;
+ // sockets_to_close_ is a map from action index to the socket that was
+ // created.
+ absl::node_hash_map<int,
+ std::unique_ptr<gvisor::testing::FileDescriptor>>
+ sockets_to_close_;
+ int next_socket_id_ = 0;
+};
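+
+// For reference, each BindSocket step in the tests below corresponds roughly
+// to the following raw socket calls (a sketch; "eth1" stands in for whatever
+// name GetDevice maps the test's device id to, and each setsockopt is only
+// issued when the corresponding option is requested):
+//
+//   int fd = socket(AF_INET, SOCK_DGRAM, 0);  // or SOCK_STREAM
+//   setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn));
+//   setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn));
+//   setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, "eth1", strlen("eth1") + 1);
+//   struct sockaddr_in addr = {};
+//   addr.sin_family = AF_INET;
+//   addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+//   addr.sin_port = port_;  // the one port shared by the whole sequence
+//   bind(fd, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr));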
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 3, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 1));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 2));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuse) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse_port */ true, /* reuse_addr */ false));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false,
+ /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse_port */ true, /* reuse_addr */ false));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 999, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 999, 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindAndRelease) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ int to_release;
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 345, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789));
+ // Release the bind to device 0 and try again.
+ ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 345));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse_port */ false, /* reuse_addr */ true));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0));
+}
+
+TEST_P(BindToDeviceSequenceTest,
+ CannotBindTo0AfterMixingReuseAddrAndNotReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReusePortThenReuseAddrReusePort) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest,
+ BindReuseAddrThenReuseAddrReusePortThenReuseAddr) {
+ // The behavior described in this test seems like a Linux bug. It doesn't
+ // make any sense and it is unlikely that any applications rely on it.
+ //
+ // Both SO_REUSEADDR and SO_REUSEPORT allow binding multiple UDP sockets to
+ // the same address and deliver each packet to exactly one of the bound
+ // sockets. If both are enabled, one of the strategies is selected to route
+ // packets. The strategy is selected dynamically based on the settings of the
+ // currently bound sockets. Usually, the strategy is selected based on the
+ // common setting (SO_REUSEADDR or SO_REUSEPORT) amongst the sockets, but for
+ // some reason, Linux allows binding sets of sockets with no overlapping
+ // settings in some situations. In this case, it is not obvious which
+ // strategy would be selected, as the configured settings contradict each
+ // other.
+ SKIP_IF(IsRunningOnGvisor());
+
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0));
+}
+
+// Repro test for gvisor.dev/issue/1217. Not replicated in ports_test.go as this
+// test is different from the others and wouldn't fit well there.
+TEST_P(BindToDeviceSequenceTest, BindAndReleaseDifferentDevice) {
+ int to_release;
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 3, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 3, EADDRINUSE));
+ // Change the device. Since the socket was already bound, this should have no
+ // effect.
+ SetDevice(to_release, 2);
+ // Release the bind to device 3 and try again.
+ ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3));
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc
new file mode 100644
index 000000000..f4ee775bd
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) {
+ int fd;
+ RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR));
+
+ // Using `new` to access a non-public constructor.
+ auto new_tunnel = absl::WrapUnique(new Tunnel(fd));
+
+ ifreq ifr = {};
+ ifr.ifr_flags = IFF_TUN;
+ strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr));
+ new_tunnel->name_ = ifr.ifr_name;
+ return new_tunnel;
+}
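+
+// Example use (a sketch; the name "tun7" is arbitrary):
+//
+//   auto tunnel = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New("tun7"));
+//   // tunnel->GetName() == "tun7". The TUN device is not persistent, so it
+//   // is removed when the Tunnel is destroyed and its fd is closed.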
+
+std::unordered_set<string> GetInterfaceNames() {
+ struct if_nameindex* interfaces = if_nameindex();
+ std::unordered_set<string> names;
+ if (interfaces == nullptr) {
+ return names;
+ }
+ for (auto interface = interfaces;
+ interface->if_index != 0 || interface->if_name != nullptr; interface++) {
+ names.insert(interface->if_name);
+ }
+ if_freenameindex(interfaces);
+ return names;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h
new file mode 100644
index 000000000..f941ccc86
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.h
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+class Tunnel {
+ public:
+ static PosixErrorOr<std::unique_ptr<Tunnel>> New(
+ std::string tunnel_name = "");
+ const std::string& GetName() const { return name_; }
+
+ ~Tunnel() {
+ if (fd_ != -1) {
+ close(fd_);
+ }
+ }
+
+ private:
+ Tunnel(int fd) : fd_(fd) {}
+ int fd_ = -1;
+ std::string name_;
+};
+
+std::unordered_set<std::string> GetInterfaceNames();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc
new file mode 100644
index 000000000..7e88aa2d9
--- /dev/null
+++ b/test/syscalls/linux/socket_blocking.cc
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_blocking.h"
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(BlockingSocketPairTest, RecvBlocks) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[100];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ constexpr auto kDuration = absl::Milliseconds(200);
+ auto before = Now(CLOCK_MONOTONIC);
+
+ const ScopedThread t([&]() {
+ absl::SleepFor(kDuration);
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ });
+
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ auto after = Now(CLOCK_MONOTONIC);
+ EXPECT_GE(after - before, kDuration);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_blocking.h b/test/syscalls/linux/socket_blocking.h
new file mode 100644
index 000000000..db26e5ef5
--- /dev/null
+++ b/test/syscalls/linux/socket_blocking.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_BLOCKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_BLOCKING_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of blocking connected sockets.
+using BlockingSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_BLOCKING_H_
diff --git a/test/syscalls/linux/socket_capability.cc b/test/syscalls/linux/socket_capability.cc
new file mode 100644
index 000000000..84b5b2b21
--- /dev/null
+++ b/test/syscalls/linux/socket_capability.cc
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Subset of socket tests that need Linux-specific headers (compared to POSIX
+// headers).
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST(SocketTest, UnixConnectNeedsWritePerm) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(bound.get(), 1), SyscallSucceeds());
+
+ // Drop capabilities that allow us to override permission checks. Otherwise, if
+ // the test is run as root, the connect below will bypass permission checks
+ // and succeed unexpectedly.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ // Connect should fail without write perms.
+ ASSERT_THAT(chmod(addr.sun_path, 0500), SyscallSucceeds());
+ FileDescriptor client =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(EACCES));
+
+ // Connect should succeed with write perms.
+ ASSERT_THAT(chmod(addr.sun_path, 0200), SyscallSucceeds());
+ EXPECT_THAT(connect(client.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+}
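+
+// Note on the test above: as with open(2) for writing, connect(2) on a
+// filesystem UNIX domain socket requires write permission on the socket file
+// itself (see unix(7)).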
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc
new file mode 100644
index 000000000..287359363
--- /dev/null
+++ b/test/syscalls/linux/socket_filesystem.cc
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_generic.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/socket_unix.h"
+#include "test/syscalls/linux/socket_unix_cmsg.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVec<SocketPairKind>(
+ FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ FilesystemUnixSockets, AllSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ FilesystemUnixSockets, UnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ FilesystemUnixSockets, UnixSocketPairCmsgTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
new file mode 100644
index 000000000..f7d6139f1
--- /dev/null
+++ b/test/syscalls/linux/socket_generic.cc
@@ -0,0 +1,820 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_generic.h"
+
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+// This file contains generic socket tests. It must be built together with
+// another file that provides the test types.
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(AllSocketPairTest, BasicReadWrite) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[20];
+ const std::string data = "abc";
+ ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3),
+ SyscallSucceedsWithValue(3));
+ ASSERT_THAT(ReadFd(sockets->second_fd(), buf, 3),
+ SyscallSucceedsWithValue(3));
+ EXPECT_EQ(data, absl::string_view(buf, 3));
+}
+
+TEST_P(AllSocketPairTest, BasicSendRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(AllSocketPairTest, BasicSendmmsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[200];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ std::vector<struct mmsghdr> msgs(10);
+ std::vector<struct iovec> iovs(msgs.size());
+ const int chunk_size = sizeof(sent_data) / msgs.size();
+ for (size_t i = 0; i < msgs.size(); i++) {
+ iovs[i].iov_len = chunk_size;
+ iovs[i].iov_base = &sent_data[i * chunk_size];
+ msgs[i].msg_hdr.msg_iov = &iovs[i];
+ msgs[i].msg_hdr.msg_iovlen = 1;
+ }
+
+ ASSERT_THAT(
+ RetryEINTR(sendmmsg)(sockets->first_fd(), &msgs[0], msgs.size(), 0),
+ SyscallSucceedsWithValue(msgs.size()));
+
+ for (const struct mmsghdr& msg : msgs) {
+ EXPECT_EQ(chunk_size, msg.msg_len);
+ }
+
+ char received_data[sizeof(sent_data)];
+ for (size_t i = 0; i < msgs.size(); i++) {
+ ASSERT_THAT(ReadFd(sockets->second_fd(), &received_data[i * chunk_size],
+ chunk_size),
+ SyscallSucceedsWithValue(chunk_size));
+ }
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(AllSocketPairTest, BasicRecvmmsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[200];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ char received_data[sizeof(sent_data)];
+ std::vector<struct mmsghdr> msgs(10);
+ std::vector<struct iovec> iovs(msgs.size());
+ const int chunk_size = sizeof(sent_data) / msgs.size();
+ for (size_t i = 0; i < msgs.size(); i++) {
+ iovs[i].iov_len = chunk_size;
+ iovs[i].iov_base = &received_data[i * chunk_size];
+ msgs[i].msg_hdr.msg_iov = &iovs[i];
+ msgs[i].msg_hdr.msg_iovlen = 1;
+ }
+
+ for (size_t i = 0; i < msgs.size(); i++) {
+ ASSERT_THAT(
+ WriteFd(sockets->first_fd(), &sent_data[i * chunk_size], chunk_size),
+ SyscallSucceedsWithValue(chunk_size));
+ }
+
+ ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->second_fd(), &msgs[0], msgs.size(),
+ 0, nullptr),
+ SyscallSucceedsWithValue(msgs.size()));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ for (const struct mmsghdr& msg : msgs) {
+ EXPECT_EQ(chunk_size, msg.msg_len);
+ }
+}
+
+TEST_P(AllSocketPairTest, SendmsgRecvmsg10KB) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ std::vector<char> sent_data(10 * 1024);
+ RandomizeBuffer(sent_data.data(), sent_data.size());
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data.data(), sent_data.size()));
+
+ std::vector<char> received_data(sent_data.size());
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sockets->second_fd(), received_data.data(),
+ received_data.size()));
+
+ EXPECT_EQ(0,
+ memcmp(sent_data.data(), received_data.data(), sent_data.size()));
+}
+
+// This test validates that passing MSG_CTRUNC to recvmsg as an input flag is
+// a no-op: the flag is ignored.
+TEST_P(AllSocketPairTest, SendmsgRecvmsgMsgCtruncNoop) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ std::vector<char> sent_data(10 * 1024);
+ RandomizeBuffer(sent_data.data(), sent_data.size());
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data.data(), sent_data.size()));
+
+ std::vector<char> received_data(sent_data.size());
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(struct ucred))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ iov.iov_base = &received_data[0];
+ iov.iov_len = received_data.size();
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // MSG_CTRUNC should be a no-op.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC),
+ SyscallSucceedsWithValue(received_data.size()));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ EXPECT_EQ(cmsg, nullptr);
+ EXPECT_EQ(msg.msg_controllen, 0);
+ EXPECT_EQ(0,
+ memcmp(sent_data.data(), received_data.data(), sent_data.size()));
+}
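+
+// Note: MSG_CTRUNC is normally an output flag that the kernel sets in
+// msg_flags when control data is truncated; the test above verifies that
+// supplying it in recvmsg's input flags has no effect.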
+
+TEST_P(AllSocketPairTest, SendmsgRecvmsg16KB) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ std::vector<char> sent_data(16 * 1024);
+ RandomizeBuffer(sent_data.data(), sent_data.size());
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data.data(), sent_data.size()));
+
+ std::vector<char> received_data(sent_data.size());
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sockets->second_fd(), received_data.data(),
+ received_data.size()));
+
+ EXPECT_EQ(0,
+ memcmp(sent_data.data(), received_data.data(), sent_data.size()));
+}
+
+TEST_P(AllSocketPairTest, RecvmsgMsghdrFlagsNotClearedOnFailure) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char received_data[10] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+
+ // Check that msghdr flags were not changed.
+ EXPECT_EQ(msg.msg_flags, -1);
+}
+
+TEST_P(AllSocketPairTest, RecvmsgMsghdrFlagsCleared) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data)] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data)));
+
+ // Check that msghdr flags were cleared.
+ EXPECT_EQ(msg.msg_flags, 0);
+}
+
+TEST_P(AllSocketPairTest, RecvmsgPeekMsghdrFlagsCleared) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data)] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data)));
+
+ // Check that msghdr flags were cleared.
+ EXPECT_EQ(msg.msg_flags, 0);
+}
+
+TEST_P(AllSocketPairTest, RecvmsgIovNotUpdated) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) * 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data)));
+
+ // Check that the iovec length was not updated.
+ EXPECT_EQ(msg.msg_iov->iov_len, sizeof(received_data));
+}
+
+TEST_P(AllSocketPairTest, RecvmmsgInvalidTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[10];
+ struct mmsghdr msg = {};
+ struct iovec iov = {};
+ iov.iov_len = sizeof(buf);
+ iov.iov_base = buf;
+ msg.msg_hdr.msg_iov = &iov;
+ msg.msg_hdr.msg_iovlen = 1;
+ struct timespec timeout = {-1, -1};
+ ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->first_fd(), &msg, 1, 0, &timeout),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(AllSocketPairTest, RecvmmsgTimeoutBeforeRecv) {
+ // There is a known bug in Linux's recvmmsg(2) that causes it to block
+ // forever if the timeout expires while it is blocked waiting for the first
+ // message.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[10];
+ struct mmsghdr msg = {};
+ struct iovec iov = {};
+ iov.iov_len = sizeof(buf);
+ iov.iov_base = buf;
+ msg.msg_hdr.msg_iov = &iov;
+ msg.msg_hdr.msg_iovlen = 1;
+ struct timespec timeout = {};
+ ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->first_fd(), &msg, 1, 0, &timeout),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, MsgPeek) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[50];
+ memset(&sent_data, 0, sizeof(sent_data));
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data)];
+ for (int i = 0; i < 3; i++) {
+ memset(received_data, 0, sizeof(received_data));
+ EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data)));
+ }
+
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data)));
+}
+
+TEST_P(AllSocketPairTest, LingerSocketOption) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ struct linger got_linger = {-1, -1};
+ socklen_t length = sizeof(struct linger);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER,
+ &got_linger, &length),
+ SyscallSucceedsWithValue(0));
+ struct linger want_linger = {};
+ EXPECT_EQ(0, memcmp(&want_linger, &got_linger, sizeof(struct linger)));
+ EXPECT_EQ(sizeof(struct linger), length);
+}
+
+TEST_P(AllSocketPairTest, KeepAliveSocketOption) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int keepalive = -1;
+ socklen_t length = sizeof(int);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE,
+ &keepalive, &length),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, keepalive);
+ EXPECT_EQ(sizeof(int), length);
+}
+
+TEST_P(AllSocketPairTest, RcvBufSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int size = 0;
+ socklen_t size_size = sizeof(size);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &size, &size_size),
+ SyscallSucceeds());
+ EXPECT_GT(size, 0);
+}
+
+TEST_P(AllSocketPairTest, SndBufSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int size = 0;
+ socklen_t size_size = sizeof(size);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &size, &size_size),
+ SyscallSucceeds());
+ EXPECT_GT(size, 0);
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutReadSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ EXPECT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutRecvSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutRecvOneSecondSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ struct msghdr msg = {};
+ char buf[20] = {};
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 89, .tv_usec = 42000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 89);
+ EXPECT_EQ(actual_tv.tv_usec, 42000);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 123000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 123000);
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutAllowsSend) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 123, .tv_usec = 456000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 123);
+ EXPECT_EQ(actual_tv.tv_usec, 456000);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 432000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 432000);
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ struct msghdr msg = {};
+ char buf[20] = {};
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutUsecTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 2000000 // 2 seconds.
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutUsecTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 2000000 // 2 seconds.
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutUsecNeg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = -1
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutUsecNeg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = -1
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutNegSecRead) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = -1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ EXPECT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutNegSecRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = -1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutNegSecRecvmsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = -1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ struct msghdr msg = {};
+ char buf[20] = {};
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvWaitAll) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[100];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_WAITALL),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(AllSocketPairTest, RecvWaitAllDontWait) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char data[100] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), data, sizeof(data),
+ MSG_WAITALL | MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutWaitAll) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 200000 // 200ms
+ };
+ EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv,
+ sizeof(tv)),
+ SyscallSucceeds());
+
+ char sent_data[100];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) * 2] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_WAITALL),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
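+
+// Note on the test above: received_data is twice the size of sent_data, so
+// without the 200ms SO_RCVTIMEO, MSG_WAITALL would keep waiting for the
+// second half; the timeout makes recv return the bytes already received.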
+
+TEST_P(AllSocketPairTest, GetSockoptType) {
+ int type = GetParam().type;
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ for (const int fd : {sockets->first_fd(), sockets->second_fd()}) {
+ int opt;
+ socklen_t optlen = sizeof(opt);
+ EXPECT_THAT(getsockopt(fd, SOL_SOCKET, SO_TYPE, &opt, &optlen),
+ SyscallSucceeds());
+
+ // Type may have SOCK_NONBLOCK and SOCK_CLOEXEC ORed into it. Remove these
+ // before comparison.
+ type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC);
+ EXPECT_EQ(opt, type) << absl::StrFormat(
+ "getsockopt(%d, SOL_SOCKET, SO_TYPE, &opt, &optlen) => opt=%d was "
+ "unexpected",
+ fd, opt);
+ }
+}
+
+TEST_P(AllSocketPairTest, GetSockoptDomain) {
+ const int domain = GetParam().domain;
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ for (const int fd : {sockets->first_fd(), sockets->second_fd()}) {
+ int opt;
+ socklen_t optlen = sizeof(opt);
+ EXPECT_THAT(getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &opt, &optlen),
+ SyscallSucceeds());
+ EXPECT_EQ(opt, domain) << absl::StrFormat(
+ "getsockopt(%d, SOL_SOCKET, SO_DOMAIN, &opt, &optlen) => opt=%d was "
+ "unexpected",
+ fd, opt);
+ }
+}
+
+TEST_P(AllSocketPairTest, GetSockoptProtocol) {
+ const int protocol = GetParam().protocol;
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ for (const int fd : {sockets->first_fd(), sockets->second_fd()}) {
+ int opt;
+ socklen_t optlen = sizeof(opt);
+ EXPECT_THAT(getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, &opt, &optlen),
+ SyscallSucceeds());
+ EXPECT_EQ(opt, protocol) << absl::StrFormat(
+ "getsockopt(%d, SOL_SOCKET, SO_PROTOCOL, &opt, &optlen) => opt=%d was "
+ "unexpected",
+ fd, opt);
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_generic.h b/test/syscalls/linux/socket_generic.h
new file mode 100644
index 000000000..00ae7bfc3
--- /dev/null
+++ b/test/syscalls/linux/socket_generic.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_GENERIC_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_GENERIC_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of blocking and non-blocking
+// connected stream sockets.
+using AllSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_GENERIC_H_
diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc
new file mode 100644
index 000000000..6a232238d
--- /dev/null
+++ b/test/syscalls/linux/socket_generic_stress.cc
@@ -0,0 +1,83 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected sockets.
+using ConnectStressTest = SocketPairTest;
+
+TEST_P(ConnectStressTest, Reset65kTimes) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Send some data so that the connection gets reset and the port gets
+ // released immediately: closing a socket that still has unread data in its
+ // receive queue makes TCP send a RST rather than a FIN. This avoids either
+ // end entering TIME-WAIT.
+ char sent_data[100] = {};
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllConnectedSockets, ConnectStressTest,
+ ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+
+ // Without REUSEADDR, we get port exhaustion on Linux.
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn)(IPv6TCPAcceptBindSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn)(IPv4TCPAcceptBindSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ DualStackTCPAcceptBindSocketPair(0))));
+
+// Test fixture for tests that apply to pairs of connected sockets created with
+// a persistent listener (if applicable).
+using PersistentListenerConnectStressTest = SocketPairTest;
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimes) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllConnectedSockets, PersistentListenerConnectStressTest,
+ ::testing::Values(
+ IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+
+ // Without REUSEADDR, we get port exhaustion on Linux.
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ IPv6TCPAcceptBindPersistentListenerSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ IPv4TCPAcceptBindPersistentListenerSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ DualStackTCPAcceptBindPersistentListenerSocketPair(0))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
new file mode 100644
index 000000000..18b9e4b70
--- /dev/null
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -0,0 +1,2566 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <string.h>
+#include <sys/socket.h>
+
+#include <atomic>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::Gt;
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+struct TestParam {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) {
+ return absl::StrCat("Listen", info.param.listener.description, "_Connect",
+ info.param.connector.description);
+}
+
+using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>;
+
+TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) {
+ int fd[2] = {};
+
+ // Valid AFs that socketpair(2) does not support return ESOCKTNOSUPPORT.
+ ASSERT_THAT(socketpair(AF_INET, 0, 0, fd),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ ASSERT_THAT(socketpair(AF_INET6, 0, 0, fd),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+
+ // An invalid AF returns EAFNOSUPPORT.
+ ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd),
+ SyscallFailsWithErrno(EAFNOSUPPORT));
+ ASSERT_THAT(socketpair(8675309, 0, 0, fd),
+ SyscallFailsWithErrno(EAFNOSUPPORT));
+}
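For contrast with the failures asserted above: on Linux, socketpair(2) is effectively specific to the UNIX domain. A small standalone sketch of the supported call (hypothetical example, error handling reduced to the return check):

    #include <sys/socket.h>
    #include <unistd.h>

    // socketpair(2) succeeds for AF_UNIX; AF_INET/AF_INET6 are rejected above.
    void UnixSocketPairExample() {
      int fds[2];
      if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == 0) {
        // fds[0] and fds[1] are a connected pair of stream sockets.
        close(fds[0]);
        close(fds[1]);
      }
    }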
+
+enum class Operation {
+ Bind,
+ Connect,
+ SendTo,
+};
+
+std::string OperationToString(Operation operation) {
+ switch (operation) {
+ case Operation::Bind:
+ return "Bind";
+ case Operation::Connect:
+ return "Connect";
+ case Operation::SendTo:
+ return "SendTo";
+ }
+}
+
+using OperationSequence = std::vector<Operation>;
+
+using DualStackSocketTest =
+ ::testing::TestWithParam<std::tuple<TestAddress, OperationSequence>>;
+
+TEST_P(DualStackSocketTest, AddressOperations) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0));
+
+ const TestAddress& addr = std::get<0>(GetParam());
+ const OperationSequence& operations = std::get<1>(GetParam());
+
+ auto addr_in = reinterpret_cast<const sockaddr*>(&addr.addr);
+
+ // Sockets may only be bound once. Both `connect` and `sendto` cause a
+ // socket to become bound.
+ bool bound = false;
+ for (const Operation& operation : operations) {
+ bool sockname = false;
+ bool peername = false;
+ switch (operation) {
+ case Operation::Bind: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 0));
+
+ int bind_ret = bind(fd.get(), addr_in, addr.addr_len);
+
+ // Dual stack sockets may only be bound to AF_INET6.
+ if (!bound && addr.family() == AF_INET6) {
+ EXPECT_THAT(bind_ret, SyscallSucceeds());
+ bound = true;
+
+ sockname = true;
+ } else {
+ EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL));
+ }
+ break;
+ }
+ case Operation::Connect: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ EXPECT_THAT(RetryEINTR(connect)(fd.get(), addr_in, addr.addr_len),
+ SyscallSucceeds())
+ << GetAddrStr(addr_in);
+ bound = true;
+
+ sockname = true;
+ peername = true;
+
+ break;
+ }
+ case Operation::SendTo: {
+ const char payload[] = "hello";
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0,
+ addr_in, addr.addr_len);
+
+ EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload)));
+ sockname = !bound;
+ bound = true;
+ break;
+ }
+ }
+
+ if (sockname) {
+ sockaddr_storage sock_addr;
+ socklen_t addrlen = sizeof(sock_addr);
+ ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ auto sock_addr_in6 = reinterpret_cast<const sockaddr_in6*>(&sock_addr);
+
+ if (operation == Operation::SendTo) {
+ EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6);
+ EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsockname="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+
+ EXPECT_NE(sock_addr_in6->sin6_port, 0);
+ } else if (IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsockname="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+ }
+ }
+
+ if (peername) {
+ sockaddr_storage peer_addr;
+ socklen_t addrlen = sizeof(peer_addr);
+ ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ if (addr.family() == AF_INET ||
+ IN6_IS_ADDR_V4MAPPED(reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(&peer_addr)
+ ->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getpeername="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr));
+ }
+ }
+ }
+}
+
+// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny.
+INSTANTIATE_TEST_SUITE_P(
+ All, DualStackSocketTest,
+ ::testing::Combine(
+ ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/
+ V4MappedLoopback(), V6Any(), V6Loopback()),
+ ::testing::ValuesIn<OperationSequence>(
+ {{Operation::Bind, Operation::Connect, Operation::SendTo},
+ {Operation::Bind, Operation::SendTo, Operation::Connect},
+ {Operation::Connect, Operation::Bind, Operation::SendTo},
+ {Operation::Connect, Operation::SendTo, Operation::Bind},
+ {Operation::SendTo, Operation::Bind, Operation::Connect},
+ {Operation::SendTo, Operation::Connect, Operation::Bind}})),
+ [](::testing::TestParamInfo<
+ std::tuple<TestAddress, OperationSequence>> const& info) {
+ const TestAddress& addr = std::get<0>(info.param);
+ const OperationSequence& operations = std::get<1>(info.param);
+ std::string s = addr.description;
+ for (const Operation& operation : operations) {
+ absl::StrAppend(&s, OperationToString(operation));
+ }
+ return s;
+ });
+
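The V4Mapped* cases above rely on the ::ffff:a.b.c.d form, which lets an AF_INET6 socket address an IPv4 peer. A standalone sketch of constructing such an address (hypothetical helper, not part of the test utilities):

    #include <arpa/inet.h>
    #include <cstdint>
    #include <netinet/in.h>

    // Build the v4-mapped IPv6 form of the IPv4 loopback, ::ffff:127.0.0.1.
    sockaddr_in6 V4MappedLoopbackAddr(uint16_t port) {
      sockaddr_in6 addr = {};
      addr.sin6_family = AF_INET6;
      addr.sin6_port = htons(port);
      inet_pton(AF_INET6, "::ffff:127.0.0.1", &addr.sin6_addr);
      return addr;
    }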
+void tcpSimpleConnectTest(TestAddress const& listener,
+ TestAddress const& connector, bool unbound) {
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ if (!unbound) {
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ }
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unnamed temporary
+ // objects are destroyed upon full evaluation of the enclosing expression,
+ // potentially causing the connecting socket to fail to shut down properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RDWR), SyscallSucceeds());
+
+ ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
+}
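The helper above leaves the port at 0 so the kernel assigns an ephemeral one, then recovers it with getsockname(); nearly every test below repeats this pattern. A condensed standalone sketch (IPv4 assumed, error handling elided):

    #include <cstdint>
    #include <netinet/in.h>
    #include <sys/socket.h>

    // Bind to port 0 and return the ephemeral port the kernel chose.
    uint16_t BindEphemeralPort(int fd) {
      sockaddr_in addr = {};
      addr.sin_family = AF_INET;
      addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
      addr.sin_port = 0;  // Port 0 requests an ephemeral port.
      bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
      socklen_t len = sizeof(addr);
      getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &len);
      return ntohs(addr.sin_port);
    }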
+
+TEST_P(SocketInetLoopbackTest, TCP) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ tcpSimpleConnectTest(listener, connector, /*unbound=*/false);
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenUnbound) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ tcpSimpleConnectTest(listener, connector, /*unbound=*/true);
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
+ const auto& param = GetParam();
+
+ const TestAddress& listener = param.listener;
+ const TestAddress& connector = param.connector;
+
+ constexpr int kBacklog = 5;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ for (int i = 0; i < kBacklog; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(RetryEINTR(connect)(client.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ }
+ for (int i = 0; i < kBacklog; i++) {
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kFDs = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // Shut down the write side of the listener; this should have no effect.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds());
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(RetryEINTR(connect)(client.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
+ }
+
+ // Shut down the read side of the listener; subsequent server accepts and
+ // binds, as well as client connects, should fail.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that shutdown did not release the port.
+ FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Check that subsequent connection attempts are refused (the listener
+ // answers with a RST).
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(RetryEINTR(connect)(client.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kAcceptCount = 2;
+ constexpr int kBacklog = kAcceptCount + 2;
+ constexpr int kFDs = kBacklog * 3;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
+ }
+ for (int i = 0; i < kAcceptCount; i++) {
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ }
+}
+
+void TestListenWhileConnect(const TestParam& param,
+ void (*stopListen)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kClients = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ clients.push_back(std::move(client));
+ }
+ }
+
+ stopListen(listen_fd);
+
+ for (auto& client : clients) {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = client.get(),
+ .events = POLLIN,
+ };
+ // When the listening socket is closed, we expect the remote to reset the
+ // connection.
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ char c;
+ // A subsequent read can fail with:
+ // ECONNRESET: if the client connection was established and then reset by
+ // the remote.
+ // ECONNREFUSED: if the client connection was never established.
+ ASSERT_THAT(read(client.get(), &c, sizeof(c)),
+ AnyOf(SyscallFailsWithErrno(ECONNRESET),
+ SyscallFailsWithErrno(ECONNREFUSED)));
+ }
+}
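The helper above uses the canonical nonblocking-connect pattern: connect() returns -1/EINPROGRESS, and completion is later observed with poll(). In general-purpose code the outcome is then read with SO_ERROR; the test instead reads from the socket, which serves the same purpose here. A standalone sketch of the SO_ERROR variant (an assumed general shape, not the test's code):

    #include <cerrno>
    #include <poll.h>
    #include <sys/socket.h>

    // After connect() returned EINPROGRESS: wait for the result and return 0
    // on success or the pending socket error (e.g. ECONNREFUSED).
    int AwaitNonblockingConnect(int fd, int timeout_ms) {
      pollfd pfd = {.fd = fd, .events = POLLOUT};
      if (poll(&pfd, 1, timeout_ms) != 1) return ETIMEDOUT;
      int err = 0;
      socklen_t len = sizeof(err);
      // SO_ERROR reports (and clears) the outcome of the async connect.
      getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &len);
      return err;
    }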
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPbacklog) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), 2), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ int i = 0;
+ while (1) {
+ int ret;
+
+ // Connect to the listening socket.
+ const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ret = connect(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ struct pollfd pfd = {
+ .fd = conn_fd.get(),
+ .events = POLLOUT,
+ };
+ ret = poll(&pfd, 1, 3000);
+ if (ret == 0) break;
+ EXPECT_THAT(ret, SyscallSucceedsWithValue(1));
+ }
+ EXPECT_THAT(RetryEINTR(send)(conn_fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
+ i++;
+ }
+
+ for (; i != 0; i--) {
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ }
+}
+
+// TCPFinWait2Test creates a pair of connected sockets, then closes one end to
+// trigger the FIN_WAIT2 state for the closed endpoint. It then tries to bind
+// the same local IP/port on a new socket, which should fail with EADDRINUSE.
+// After waiting until the FIN_WAIT2 timeout is over, it tries the bind/connect
+// again with a new socket, and this time it should succeed.
+//
+// TCP timers are not saved/restored today, so this test can be flaky when run
+// under random S/R because a restore resets the timer.
+TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Lower FIN_WAIT2 state to 5 seconds for test.
+ constexpr int kTCPLingerTimeout = 5;
+ EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // Close the connecting FD to trigger FIN_WAIT2 on the connected FD.
+ conn_fd.reset();
+
+ // Now bind and connect a new socket.
+ const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Disable cooperative saves after this point, as a save between the first
+ // bind/connect and the second one can restart the linger timeout timer,
+ // causing the final bind/connect to fail.
+ DisableSave ds;
+
+ ASSERT_THAT(bind(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Sleep for a little over the linger timeout to reduce flakiness in
+ // save/restore tests.
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2));
+
+ ds.reset();
+
+ if (!IsRunningOnGvisor()) {
+ ASSERT_THAT(
+ bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+ }
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+}
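TCP_LINGER2 overrides, per socket, how long an orphaned connection lingers in FIN_WAIT2 (the system-wide default comes from net.ipv4.tcp_fin_timeout). A standalone sketch of setting and reading it back, assuming Linux (hypothetical helper, error handling elided):

    #include <netinet/in.h>
    #include <netinet/tcp.h>
    #include <sys/socket.h>

    // Shorten the FIN_WAIT2 lifetime of one socket to |seconds|.
    void SetFinWait2Timeout(int fd, int seconds) {
      setsockopt(fd, IPPROTO_TCP, TCP_LINGER2, &seconds, sizeof(seconds));
      int got = 0;
      socklen_t len = sizeof(got);
      // Read the per-socket override back.
      getsockopt(fd, IPPROTO_TCP, TCP_LINGER2, &got, &len);
    }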
+
+// TCPLinger2TimeoutAfterClose creates a pair of connected sockets, then
+// closes one end to trigger the FIN_WAIT2 state for the closed endpoint.
+// It then sleeps past the TCP_LINGER2 timeout and verifies that binding and
+// connecting the same address succeeds.
+//
+// TCP timers are not saved/restored today, so this test can be flaky when run
+// under random S/R because a restore resets the timer.
+TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // Disable cooperative saves after this point as TCP timers are not restored
+ // across an S/R.
+ {
+ DisableSave ds;
+ constexpr int kTCPLingerTimeout = 5;
+ EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
+
+ // Close the connecting FD to trigger FIN_WAIT2 on the connected FD.
+ conn_fd.reset();
+
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
+
+ // ds going out of scope re-enables S/R, since at this point the timer
+ // must have fired and cleaned up the endpoint.
+ }
+
+ // Now bind and connect a new socket and verify that we can immediately
+ // rebind the address bound by the conn_fd as it never entered TIME_WAIT.
+ const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+}
+
+// TCPResetAfterClose creates a pair of connected sockets, then closes one end
+// to trigger the FIN_WAIT2 state for the closed endpoint, and verifies that
+// we generate RSTs for any new data arriving after the socket is fully
+// closed.
+TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Close the connecting FD to trigger FIN_WAIT2 on the connected FD.
+ conn_fd.reset();
+
+ int data = 1234;
+
+ // Now send data, which should trigger a RST, as the other end should
+ // have timed out and closed the socket.
+ EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0),
+ SyscallSucceeds());
+ // Sleep for a short while to get a RST back.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Try writing again and we should get an EPIPE back.
+ EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0),
+ SyscallFailsWithErrno(EPIPE));
+
+ // Trying to read should return zero, as the other end did send us a FIN.
+ // We do it twice to verify that the RST does not cause an ECONNRESET on a
+ // read after the application has already read EOF.
+ EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+// This test is disabled under random save, as the restore run results in a
+// different stack.Seed(), which can cause the final connect to use a sequence
+// number that is considered old, making the test flaky.
+TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // We disable saves after this point, as an S/R causes the netstack seed
+ // to be regenerated, which changes which port/ISN is picked for a given
+ // tuple (src ip, src port, dst ip, dst port). This can cause the final
+ // SYN to use a sequence number that looks like one from the current
+ // connection in TIME_WAIT, which will not be accepted, causing the test
+ // to time out.
+ //
+ // TODO(gvisor.dev/issue/940): S/R portSeed/portHint
+ DisableSave ds;
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // Close the accepted FD first to trigger TIME_WAIT on the accepted socket,
+ // which should cause conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead
+ // of TIME_WAIT.
+ accepted.reset();
+ absl::SleepFor(absl::Seconds(1));
+ conn_fd.reset();
+ absl::SleepFor(absl::Seconds(1));
+
+ // Now bind and connect a new socket and verify that we can immediately
+ // rebind the address bound by the conn_fd as it never entered TIME_WAIT.
+ const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the userTimeout on the listening socket.
+ constexpr int kUserTimeout = 10;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kUserTimeout, sizeof(kUserTimeout)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ // Verify that the accepted socket inherited the user timeout set on the
+ // listening socket.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kUserTimeout);
+}
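TCP_USER_TIMEOUT bounds, in milliseconds, how long transmitted data may remain unacknowledged before the connection is forcibly closed; as the test verifies, accepted sockets inherit the listener's value. A standalone sketch, assuming Linux (hypothetical helper):

    #include <netinet/in.h>
    #include <netinet/tcp.h>
    #include <sys/socket.h>

    // Abort the connection if sent data stays unacknowledged for 30 seconds.
    void SetUserTimeout(int fd) {
      unsigned int timeout_ms = 30 * 1000;
      setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &timeout_ms,
                 sizeof(timeout_ms));
    }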
+
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R once the issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblocking so that we can verify that there
+ // is no connection in the queue despite the connect above succeeding, since
+ // the peer has sent no data and TCP_DEFER_ACCEPT is set on the listening
+ // socket.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now write some data to the socket.
+ int data = 0;
+ ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+
+ // This should now cause the connection to complete and be delivered to the
+ // accept socket.
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Verify that the accepted socket returns the data written.
+ int get = -1;
+ ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0),
+ SyscallSucceedsWithValue(sizeof(get)));
+
+ EXPECT_EQ(get, data);
+}
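TCP_DEFER_ACCEPT keeps a completed handshake out of the accept queue until the client sends data (or, as the next test shows, until roughly the configured timeout lapses), saving a wakeup for request-response protocols. A standalone sketch of configuring a listener this way, assuming Linux (hypothetical helper, error handling elided):

    #include <fcntl.h>
    #include <netinet/in.h>
    #include <netinet/tcp.h>
    #include <sys/socket.h>

    // Deliver connections to accept() only once data arrives, or after
    // roughly |seconds| of client silence.
    void DeferAccept(int listen_fd, int seconds) {
      setsockopt(listen_fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &seconds,
                 sizeof(seconds));
      // With the listener nonblocking, accept() reports EWOULDBLOCK for
      // data-less connections, as the test above demonstrates.
      int flags = fcntl(listen_fd, F_GETFL);
      fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK);
    }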
+
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R once issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblocking so that we can verify that there
+ // is no connection in the queue despite the connect above succeeding, since
+ // the peer has sent no data and TCP_DEFER_ACCEPT is set on the listening
+ // socket.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT
+ // timeout is hit.
+ absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1));
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the
+ // timeout is hit, a SYN-ACK should be retransmitted by the listener as a
+ // last-ditch attempt to complete the connection with or without data.
+ absl::SleepFor(absl::Seconds(2));
+
+ // Verify that we have a connection that can be accepted even though no
+ // data was written.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, SocketInetLoopbackTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Any(), V4MappedAny()},
+ TestParam{V4Any(), V4MappedLoopback()},
+ TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+ TestParam{V4MappedAny(), V4Any()},
+ TestParam{V4MappedAny(), V4Loopback()},
+ TestParam{V4MappedAny(), V4MappedAny()},
+ TestParam{V4MappedAny(), V4MappedLoopback()},
+ TestParam{V4MappedLoopback(), V4Any()},
+ TestParam{V4MappedLoopback(), V4Loopback()},
+ TestParam{V4MappedLoopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()},
+ TestParam{V6Any(), V4MappedAny()},
+ TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()},
+ TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Any()},
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
+using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>;
+
+// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is
+// saved/restored.
+TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+ constexpr int kConnectAttempts = 10000;
+
+ // Create the listening socket.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::unique_ptr<ScopedThread> listen_thread[kThreadCount];
+ int accept_counts[kThreadCount] = {};
+ // TODO(avagin): figure out how to avoid disabling S/R for the whole test.
+ // We need to take into account that this test executes a lot of system
+ // calls from many threads.
+ DisableSave ds;
+
+ for (int i = 0; i < kThreadCount; i++) {
+ listen_thread[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, i, &connects_received]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ if (connects_received >= kConnectAttempts) {
+ // Another thread has shut down our read side, causing the
+ // accept to fail.
+ ASSERT_EQ(errno, EINVAL);
+ break;
+ }
+ ASSERT_NO_ERRNO(fd);
+ break;
+ }
+ // Receive some data from the socket to be sure that the connect()
+ // system call has completed on the other side.
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not time out under gotsan runs, where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (int j = 0; j < kThreadCount; j++) {
+ shutdown(listener_fds[j].get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ ScopedThread connecting_thread([&connector, &conn_addr]() {
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack, where read incorrectly assumes a whole
+ // segment is consumed when endpoint.Read() is called, even though the
+ // syscall that invoked endpoint.Read() may only consume it partially.
+ // This results in a case where a close() of such a socket does not
+ // trigger a RST in netstack, because the endpoint assumes it has no
+ // unread data.
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ }
+ });
+
+ // Join threads to be sure that all connections have been counted.
+ connecting_thread.Join();
+ for (int i = 0; i < kThreadCount; i++) {
+ listen_thread[i]->Join();
+ }
+ // Check that connections are distributed fairly between listening sockets.
+ for (int i = 0; i < kThreadCount; i++)
+ EXPECT_THAT(accept_counts[i],
+ EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
+}
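SO_REUSEPORT allows several sockets to bind the same address and port, with the kernel hashing incoming connections across the group; every member must set the option before bind(). A standalone sketch of building such a listener group (hypothetical helper, error handling elided):

    #include <netinet/in.h>
    #include <sys/socket.h>
    #include <vector>

    // Create |n| TCP listeners sharing one local address and port.
    std::vector<int> MakeReusePortListeners(const sockaddr_in& addr, int n) {
      std::vector<int> fds;
      for (int i = 0; i < n; i++) {
        int fd = socket(AF_INET, SOCK_STREAM, 0);
        int on = 1;
        // Must precede bind(); all group members need the option set.
        setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on));
        bind(fd, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr));
        listen(fd, SOMAXCONN);
        fds.push_back(fd);
      }
      return fds;
    }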
+
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // Create the listening socket.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::unique_ptr<ScopedThread> receiver_thread[kThreadCount];
+ int packets_per_socket[kThreadCount] = {};
+ // TODO(avagin): figure out how to avoid disabling S/R for the whole test.
+ DisableSave ds; // Too expensive.
+
+ for (int i = 0; i < kThreadCount; i++) {
+ receiver_thread[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, i, &packets_received]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shut down our read side, causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread; otherwise
+ // the main thread can send more than the receive queues can hold.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (int j = 0; j < kThreadCount; j++)
+ shutdown(listener_fds[j].get(), SHUT_RDWR);
+ });
+ }
+
+ ScopedThread main_thread([&connector, &conn_addr]() {
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+ });
+
+ main_thread.Join();
+
+ // Join threads to be sure that all connections have been counted.
+ for (int i = 0; i < kThreadCount; i++) {
+ receiver_thread[i]->Join();
+ }
+ // Check that packets are distributed fairly between listening sockets.
+ for (int i = 0; i < kThreadCount; i++)
+ EXPECT_THAT(packets_per_socket[i],
+ EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
+}
+
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // TODO(b/141211329): endpointsByNic.seed has to be saved/restored.
+ const DisableSave ds141211329;
+
+ // Create listening sockets.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10;
+ FileDescriptor client_fds[kConnectAttempts];
+
+ // Do the first run without save/restore.
+ DisableSave ds;
+ for (int i = 0; i < kConnectAttempts; i++) {
+ client_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ ds.reset();
+
+ // Check that the mapping of client to server sockets has not changed
+ // after save/restore.
+ for (int i = 0; i < kConnectAttempts; i++) {
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ struct pollfd pollfds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ pollfds[i].fd = listener_fds[i].get();
+ pollfds[i].events = POLLIN;
+ }
+
+ std::map<uint16_t, int> portToFD;
+
+ int received = 0;
+ while (received < kConnectAttempts * 2) {
+ ASSERT_THAT(poll(pollfds, kThreadCount, -1),
+ SyscallSucceedsWithValue(Gt(0)));
+
+ for (int i = 0; i < kThreadCount; i++) {
+ if ((pollfds[i].revents & POLLIN) == 0) {
+ continue;
+ }
+
+ received++;
+
+ const int fd = pollfds[i].fd;
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+ EXPECT_THAT(RetryEINTR(recvfrom)(
+ fd, &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr));
+ auto prev_port = portToFD.find(port);
+ // Check that all packets from one client have been delivered to the
+ // same server socket.
+ if (prev_port == portToFD.end()) {
+ portToFD[port] = fd;
+ } else {
+ EXPECT_EQ(portToFD[port], fd);
+ }
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, SocketInetReusePortTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
+struct ProtocolTestParam {
+ std::string description;
+ int type;
+};
+
+std::string DescribeProtocolTestParam(
+ ::testing::TestParamInfo<ProtocolTestParam> const& info) {
+ return info.param.description;
+}
+
+using SocketMultiProtocolInetLoopbackTest =
+ ::testing::TestWithParam<ProtocolTestParam>;
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) {
+ auto const& param = GetParam();
+
+ for (int i = 0; true; i++) {
+ // Bind the v4 loopback on a dual stack socket.
+ TestAddress const& test_addr_dual = V4MappedLoopback();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that we can still bind the v6 loopback on the same port.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len);
+ if (ret == -1 && errno == EADDRINUSE) {
+ // Port may have been in use.
+ ASSERT_LT(i, 100); // Give up after 100 tries.
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallSucceeds());
+
+ // Verify that binding the v4 loopback with the same port on a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4 = V4Loopback();
+ sockaddr_storage addr_v4 = test_addr_v4.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
+ const FileDescriptor fd_v4 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
+ test_addr_v4.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // No need to try again.
+ break;
+ }
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) {
+ auto const& param = GetParam();
+
+ for (int i = 0; true; i++) {
+ // Bind the v4 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V4MappedAny();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that we can still bind the v6 loopback on the same port.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len);
+ if (ret == -1 && errno == EADDRINUSE) {
+ // Port may have been in use.
+ ASSERT_LT(i, 100); // Give up after 100 tries.
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallSucceeds());
+
+ // Verify that binding the v4 loopback with the same port on a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4 = V4Loopback();
+ sockaddr_storage addr_v4 = test_addr_v4.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
+ const FileDescriptor fd_v4 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
+ test_addr_v4.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // No need to try again.
+ break;
+ }
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
+ auto const& param = GetParam();
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v6 loopback with the same port fails.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v6 socket
+ // fails.
+ TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
+ sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
+ const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_mapped.family(), param.type, 0));
+ ASSERT_THAT(
+ bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4 = V4Loopback();
+ sockaddr_storage addr_v4 = test_addr_v4.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
+ const FileDescriptor fd_v4 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
+ test_addr_v4.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 any on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
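+// With SO_REUSEADDR set on both sockets and no listener, a dual-stack V6Any
+// bind does not prevent a later V4Any bind to the same port.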
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ DualStackV6AnyReuseAddrDoesNotReserveV4Any) {
+ auto const& param = GetParam();
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v4 any on the same port with a v4 socket succeeds.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallSucceeds());
+}
+
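+// Once the dual-stack V6Any socket is listening, the port is reserved for
+// V4Any as well, even with SO_REUSEADDR set on both sockets.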
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ DualStackV6AnyReuseAddrListenReservesV4Any) {
+ auto const& param = GetParam();
+
+ // Only TCP sockets are supported.
+ SKIP_IF((param.type & SOCK_STREAM) == 0);
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v4 any on the same port with a v4 socket succeeds.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
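+// A listening dual-stack V6Any socket without SO_REUSEADDR reserves the port
+// across all of the v4 and v6 addresses checked below.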
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ DualStackV6AnyWithListenReservesEverything) {
+ auto const& param = GetParam();
+
+ // Only TCP sockets are supported.
+ SKIP_IF((param.type & SOCK_STREAM) == 0);
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v6 loopback with the same port fails.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v6 socket
+ // fails.
+ TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
+ sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
+ const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_mapped.family(), param.type, 0));
+ ASSERT_THAT(
+ bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4 = V4Loopback();
+ sockaddr_storage addr_v4 = test_addr_v4.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
+ const FileDescriptor fd_v4 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
+ test_addr_v4.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 any on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
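+// With IPV6_V6ONLY set, a V6Any bind reserves the port for v6 only; binding
+// the v4-mapped loopback on the same port still succeeds.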
+TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
+ auto const& param = GetParam();
+
+ for (int i = 0; true; i++) {
+ // Bind the v6 any on a v6-only socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_dual.family(), param.type, 0));
+ EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v6 loopback with the same port fails.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that we can still bind the v4 loopback on the same port.
+ TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
+ sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
+ const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_mapped.family(), param.type, 0));
+ int ret =
+ bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len);
+ if (ret == -1 && errno == EADDRINUSE) {
+ // Port may have been in use.
+ ASSERT_LT(i, 100); // Give up after 100 tries.
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallSucceeds());
+
+ // No need to try again.
+ break;
+ }
+}
+
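+// Connecting from an unbound socket makes the kernel assign an ephemeral
+// port; the tests below check whether that port is then reserved, both with
+// and without SO_REUSEADDR.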
+TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
+ auto const& param = GetParam();
+
+ for (int i = 0; true; i++) {
+ // Bind the v6 loopback on a dual stack socket.
+ TestAddress const& test_addr = V6Loopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v6 loopback with the same port fails.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that we can still bind the v4 loopback on the same port.
+ TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
+ sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped,
+ ephemeral_port));
+ const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_mapped.family(), param.type, 0));
+ int ret =
+ bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len);
+ if (ret == -1 && errno == EADDRINUSE) {
+ // Port may have been in use.
+ ASSERT_LT(i, 100); // Give up after 100 tries.
+ continue;
+ }
+ EXPECT_THAT(ret, SyscallSucceeds());
+
+ // No need to try again.
+ break;
+ }
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
+ auto const& param = GetParam();
+
+ // Bind the v6 loopback on a dual stack socket.
+ TestAddress const& test_addr = V6Loopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is not reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
+ auto const& param = GetParam();
+
+ for (int i = 0; true; i++) {
+ // Bind the v4 loopback on a dual stack socket.
+ TestAddress const& test_addr = V4MappedLoopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4 = V4Loopback();
+ sockaddr_storage addr_v4 = test_addr_v4.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v4.family(), &addr_v4, ephemeral_port));
+ const FileDescriptor fd_v4 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
+ EXPECT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
+ test_addr_v4.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v6 any on the same port with a dual-stack socket
+ // fails.
+ TestAddress const& test_addr_v6_any = V6Any();
+ sockaddr_storage addr_v6_any = test_addr_v6_any.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port));
+ const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v6_any.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
+ test_addr_v6_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // For some reason, binding the TCP v6-only any is flaky on Linux. Maybe we
+ // tend to run out of ephemeral ports? Regardless, binding the v6 loopback
+ // seems pretty reliable. Only try to bind the v6-only any on UDP and
+ // gVisor.
+
+ int ret = -1;
+
+ if (!IsRunningOnGvisor() && param.type == SOCK_STREAM) {
+ // Verify that we can still bind the v6 loopback on the same port.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port));
+ const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v6.family(), param.type, 0));
+ ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len);
+ } else {
+ // Verify that we can still bind the v6 any on the same port with a
+ // v6-only socket.
+ const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v6_any.family(), param.type, 0));
+ EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ret =
+ bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
+ test_addr_v6_any.addr_len);
+ }
+
+ if (ret == -1 && errno == EADDRINUSE) {
+ // Port may have been in use.
+ ASSERT_LT(i, 100); // Give up after 100 tries.
+ continue;
+ }
+ EXPECT_THAT(ret, SyscallSucceeds());
+
+ // No need to try again.
+ break;
+ }
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+       V4MappedEphemeralPortReservedReuseAddr) {
+ auto const& param = GetParam();
+
+ // Bind the v4 loopback on a dual stack socket.
+ TestAddress const& test_addr = V4MappedLoopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is not reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
+ auto const& param = GetParam();
+
+ for (int i = 0; true; i++) {
+ // Bind the v4 loopback on a v4 socket.
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v6 socket
+ // fails.
+ TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
+ sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped,
+ ephemeral_port));
+ const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_mapped.family(), param.type, 0));
+ EXPECT_THAT(
+ bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v6 any on the same port with a dual-stack socket
+ // fails.
+ TestAddress const& test_addr_v6_any = V6Any();
+ sockaddr_storage addr_v6_any = test_addr_v6_any.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port));
+ const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v6_any.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
+ test_addr_v6_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // For some reason, binding the TCP v6-only any is flaky on Linux. Maybe we
+ // tend to run out of ephemeral ports? Regardless, binding the v6 loopback
+ // seems pretty reliable. Only try to bind the v6-only any on UDP and
+ // gVisor.
+
+ int ret = -1;
+
+ if (!IsRunningOnGvisor() && param.type == SOCK_STREAM) {
+ // Verify that we can still bind the v6 loopback on the same port.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port));
+ const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v6.family(), param.type, 0));
+ ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len);
+ } else {
+ // Verify that we can still bind the v6 any on the same port with a
+ // v6-only socket.
+ const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v6_any.family(), param.type, 0));
+ EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ret =
+ bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
+ test_addr_v6_any.addr_len);
+ }
+
+ if (ret == -1 && errno == EADDRINUSE) {
+ // Port may have been in use.
+ ASSERT_LT(i, 100); // Give up after 100 tries.
+ continue;
+ }
+ EXPECT_THAT(ret, SyscallSucceeds());
+
+ // No need to try again.
+ break;
+ }
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
+ auto const& param = GetParam();
+
+ // Bind the v4 loopback on a v4 socket.
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+
+ ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+
+ ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is not reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallSucceeds());
+}
+
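+// Two sockets may bind the same address/port only if both set SO_REUSEPORT;
+// every on/off combination of the option is exercised below.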
+TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
+ auto const& param = GetParam();
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage addr = test_addr.addr;
+
+ for (int i = 0; i < 2; i++) {
+ const int portreuse1 = i % 2;
+ auto s1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd1 = s1.get();
+ socklen_t addrlen = test_addr.addr_len;
+
+ EXPECT_THAT(
+ setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(fd1, 1), SyscallSucceeds());
+ }
+
+ // j runs to 4 (twice through each reuse setting) to check that the port
+ // reuse logic still works correctly after bound sockets are closed.
+ for (int j = 0; j < 4; j++) {
+ const int portreuse2 = j % 2;
+ auto s2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd2 = s2.get();
+
+ EXPECT_THAT(
+ setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)),
+ SyscallSucceeds());
+
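+ // Log which SO_REUSEPORT combination is being exercised.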
+ std::cout << portreuse1 << " " << portreuse2 << std::endl;
+ int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
+
+ // Verify that two sockets can be bound to the same port only if
+ // SO_REUSEPORT is set for both of them.
+ if (!portreuse1 || !portreuse2) {
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE));
+ } else {
+ ASSERT_THAT(ret, SyscallSucceeds());
+ }
+ }
+ }
+}
+
+// Check that when a socket was bound to an address with REUSEPORT and then
+// closed, we can bind a different socket to the same address without needing
+// REUSEPORT.
+TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) {
+ auto const& param = GetParam();
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage addr = test_addr.addr;
+
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd = s.get();
+ socklen_t addrlen = test_addr.addr_len;
+ int portreuse = 1;
+ ASSERT_THAT(
+ setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, test_addr.addr_len);
+
+ s.reset();
+
+ // Open a new socket and bind to the same address, but w/o REUSEPORT.
+ s = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ fd = s.get();
+ portreuse = 0;
+ ASSERT_THAT(
+ setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllFamilies, SocketMultiProtocolInetLoopbackTest,
+ ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM},
+ ProtocolTestParam{"UDP", SOCK_DGRAM}),
+ DescribeProtocolTestParam);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
new file mode 100644
index 000000000..2324c7f6a
--- /dev/null
+++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <string.h>
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::Gt;
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+struct TestParam {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) {
+ return absl::StrCat("Listen", info.param.listener.description, "_Connect",
+ info.param.connector.description);
+}
+
+using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>;
+
+// This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral
+// ports are already in use for a given destination IP/port.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(SocketInetLoopbackTest, TestTCPPortExhaustion_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 10;
+ constexpr int kClients = 65536;
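+ // kClients (65536) exceeds the 16-bit port space, so the connect loop below
+ // must eventually fail with EADDRNOTAVAIL.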
+
+ // Create the listening socket.
+ auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+
+ // Now we keep opening connections until we run out of local ephemeral
+ // ports, and assert the error we get back.
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ std::vector<FileDescriptor> servers;
+
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret == 0) {
+ clients.push_back(std::move(client));
+ FileDescriptor server =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ servers.push_back(std::move(server));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRNOTAVAIL));
+ break;
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, SocketInetLoopbackTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Any(), V4MappedAny()},
+ TestParam{V4Any(), V4MappedLoopback()},
+ TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+ TestParam{V4MappedAny(), V4Any()},
+ TestParam{V4MappedAny(), V4Loopback()},
+ TestParam{V4MappedAny(), V4MappedAny()},
+ TestParam{V4MappedAny(), V4MappedLoopback()},
+ TestParam{V4MappedLoopback(), V4Any()},
+ TestParam{V4MappedLoopback(), V4Loopback()},
+ TestParam{V4MappedLoopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()},
+ TestParam{V6Any(), V4MappedAny()},
+ TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()},
+ TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Any()},
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_loopback_blocking.cc b/test/syscalls/linux/socket_ip_loopback_blocking.cc
new file mode 100644
index 000000000..fda252dd7
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_loopback_blocking.cc
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/tcp.h>
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
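+// Returns the UDP pairs unmodified and each TCP pair both with and without
+// TCP_NODELAY, so the blocking tests run over every variant.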
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(
+ std::vector<SocketPairKind>{
+ IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ },
+ ApplyVecToVec<SocketPairKind>(
+ std::vector<Middleware>{
+ NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)},
+ std::vector<SocketPairKind>{
+ IPv6TCPAcceptBindSocketPair(0),
+ IPv4TCPAcceptBindSocketPair(0),
+ }));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BlockingIPSockets, BlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
new file mode 100644
index 000000000..c2ecb639f
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -0,0 +1,1054 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ip_tcp_generic.h"
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(TCPSocketPairTest, TcpInfoSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct tcp_info opt = {};
+ socklen_t optLen = sizeof(opt);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen),
+ SyscallSucceeds());
+}
+
+TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct tcp_info opt = {};
+ socklen_t optLen = 1;
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen),
+ SyscallSucceeds());
+}
+
+TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct tcp_info opt = {};
+ socklen_t optLen = 0;
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen),
+ SyscallSucceeds());
+}
+
+// This test validates that an RST is sent instead of a FIN when data is
+// unread on calls to close(2).
+TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now close the connected socket without reading the data.
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+
+ // Wait for the other end to receive the RST (up to 20 seconds).
+ struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // A shutdown with unread data will cause a RST to be sent instead
+ // of a FIN, per RFC 2525 section 2.17; this is also what Linux does.
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ECONNRESET));
+}
+
+// This test will validate that a RST will cause POLLHUP to trigger.
+TEST_P(TCPSocketPairTest, RSTCausesPollHUP) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(poll_fd.revents & POLLIN, POLLIN);
+
+ // Confirm we at least have one unread byte.
+ int bytes_available = 0;
+ ASSERT_THAT(
+ RetryEINTR(ioctl)(sockets->second_fd(), FIONREAD, &bytes_available),
+ SyscallSucceeds());
+ EXPECT_GT(bytes_available, 0);
+
+ // Now close the connected socket without reading the data from the second
+ // side; this will cause a RST, which we should observe as POLLHUP.
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+
+ // Wait for the other end to receive the RST (up to 20 seconds).
+ struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ ASSERT_NE(poll_fd3.revents & POLLHUP, 0);
+}
+
+// This test validates that even if a RST is sent the other end will not
+// get an ECONNRESET until it's read all data.
+TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadDataAllowsReadBuffered) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0};
+ constexpr int kPollTimeoutMs = 30000; // Wait up to 30 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Wait until the first socket sees the data on its side but don't read it.
+ struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now close the connected socket without reading the data from the second
+ // side.
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+
+ // Wait for the other end to receive the RST (up to 30 seconds).
+ struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Since we also have data buffered we should be able to read it before
+ // the syscall will fail with ECONNRESET.
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // A shutdown with unread data will cause a RST to be sent instead
+ // of a FIN, per RFC 2525 section 2.17; this is also what Linux does.
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ECONNRESET));
+}
+
+// This test will verify that a clean shutdown (FIN) is performed when there
+// is unread data but only the write side is closed.
+TEST_P(TCPSocketPairTest, FINSentOnShutdownWrWithUnreadData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now shut down the write end, leaving the read end open.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds());
+
+ // Wait for the other end to receive the FIN (up to 20 seconds).
+ struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Since we didn't shut down the read end, this will be a clean close.
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+}
+
+// This test will verify that when a socket has received data, SHUT_RD will
+// not cause any packets to be generated even if that data is left unread.
+TEST_P(TCPSocketPairTest, ShutdownRdShouldCauseNoPacketsWithUnreadData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now shut down the read end; this will generate no packets to the other end.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds());
+
+ // We should not receive any events on the other side of the socket.
+ struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollNoResponseTimeoutMs = 3000;
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollNoResponseTimeoutMs),
+ SyscallSucceedsWithValue(0)); // Timeout.
+}
+
+// This test will verify that a socket which has unread data will still allow
+// the data to be read after shutting down the read side, and once there is no
+// unread data left, then read will return an EOF.
+TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now shut down the read end.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds());
+
+ // Even though we did a SHUT_RD on the read end we can still read the data.
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // After reading all of the data, reading the closed read end returns EOF.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+}
+
+// This test verifies that a shutdown(wr) by the server after sending
+// data allows the client to still read() the queued data and a client
+// close after sending response allows server to read the incoming
+// response.
+TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[10] = {};
+ ScopedThread t([&]() {
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(close(sockets->release_first_fd()),
+ SyscallSucceedsWithValue(0));
+ });
+ ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR),
+ SyscallSucceedsWithValue(0));
+ t.Join();
+
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Set the read end to O_NONBLOCK.
+ int opts = 0;
+ ASSERT_THAT(opts = fcntl(sockets->second_fd(), F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(sockets->second_fd(), F_SETFL, opts | O_NONBLOCK),
+ SyscallSucceeds());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until second_fd sees the data and then recv it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0};
+ constexpr int kPollTimeoutMs = 2000; // Wait up to 2 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Now close the first socket entirely, leaving the read end open.
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+
+ // Wait for close notification and recv again.
+ struct pollfd poll_fd2 = {sockets->second_fd(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(TCPSocketPairTest,
+ ShutdownRdUnreadDataShouldCauseNoPacketsUnlessClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until the second socket sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now shut down the read end; this will generate no packets to the other end.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds());
+
+ // We should not receive any events on the other side of the socket.
+ struct pollfd poll_fd2 = {sockets->first_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollNoResponseTimeoutMs = 3000;
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollNoResponseTimeoutMs),
+ SyscallSucceedsWithValue(0)); // Timeout.
+
+ // Now that we've fully closed the connection, it will generate a RST.
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1)); // The other end has closed.
+
+ // A close() with unread data will cause a RST to be sent instead
+ // of a FIN, per RFC 2525 section 2.17; this is also what Linux does.
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ECONNRESET));
+}
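
The same RST-on-close behavior can be reproduced outside the test harness: closing a socket whose receive queue still holds unread bytes makes the kernel send a RST instead of a FIN, and the peer's next read() fails with ECONNRESET. A hedged sketch, assuming `fds[0]` and `fds[1]` are the two ends of an established TCP connection:

```c++
#include <errno.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

// Demonstrate RFC 2525 section 2.17: close with unread data sends a RST.
void ResetOnClose(int fds[2]) {
  char buf[16] = {};
  (void)write(fds[0], buf, sizeof(buf));
  // Wait until the bytes are queued at the other end, but never read them.
  struct pollfd pfd = {fds[1], POLLIN, 0};
  (void)poll(&pfd, 1, 2000);
  close(fds[1]);  // Unread data pending, so this emits a RST, not a FIN.
  // The writer observes the reset on its next read().
  ssize_t n = read(fds[0], buf, sizeof(buf));
  // Expected: n == -1 && errno == ECONNRESET.
  (void)n;
}
```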
+
+TEST_P(TCPSocketPairTest, TCPCorkDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPCork) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(TCPSocketPairTest, TCPCork) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ constexpr char kData[] = "abc";
+ ASSERT_THAT(WriteFd(sockets->first_fd(), kData, sizeof(kData)),
+ SyscallSucceedsWithValue(sizeof(kData)));
+
+ ASSERT_NO_FATAL_FAILURE(RecvNoData(sockets->second_fd()));
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ // Create a receive buffer larger than kData.
+ char buf[(sizeof(kData) + 1) * 2] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kData)));
+ EXPECT_EQ(absl::string_view(kData, sizeof(kData)),
+ absl::string_view(buf, sizeof(kData)));
+}
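
TCP_CORK is typically used to coalesce a small protocol header with the payload that follows it: while the cork is set, queued bytes are held back (Linux flushes them after roughly 200 ms even if still corked), and clearing the cork pushes everything out at once. A minimal sketch, assuming `fd` is a connected TCP socket:

```c++
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>

// Write a header and body as one coalesced unit using TCP_CORK.
void CorkedWrite(int fd, const char* hdr, size_t hdr_len,
                 const char* body, size_t body_len) {
  int on = 1, off = 0;
  // Cork: subsequent writes are queued rather than sent immediately.
  setsockopt(fd, IPPROTO_TCP, TCP_CORK, &on, sizeof(on));
  (void)write(fd, hdr, hdr_len);
  (void)write(fd, body, body_len);
  // Uncork: everything queued so far is flushed to the peer.
  setsockopt(fd, IPPROTO_TCP, TCP_CORK, &off, sizeof(off));
}
```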
+
+TEST_P(TCPSocketPairTest, TCPQuickAckDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPQuickAck) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_QUICKACK, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
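
Worth noting when using TCP_QUICKACK in real code: on Linux it is a hint rather than a sticky flag, and the kernel may leave quickack mode on its own, so latency-sensitive receivers commonly re-apply it around each receive. A small sketch:

```c++
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>

// Ask the kernel to ACK incoming segments immediately instead of
// delaying. Linux may clear this internally, so callers that depend on
// it re-apply the option before each recv() rather than setting it once.
void EnableQuickAck(int fd) {
  int on = 1;
  setsockopt(fd, IPPROTO_TCP, TCP_QUICKACK, &on, sizeof(on));
}
```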
+
+TEST_P(TCPSocketPairTest, SoKeepaliveDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(TCPSocketPairTest, SetSoKeepalive) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(TCPSocketPairTest, TCPKeepidleDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 2 * 60 * 60); // 2 hours.
+}
+
+TEST_P(TCPSocketPairTest, TCPKeepintvlDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 75); // 75 seconds.
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepidleZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &kZero,
+ sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL,
+ &kZero, sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Copied from include/net/tcp.h.
+constexpr int MAX_TCP_KEEPIDLE = 32767;
+constexpr int MAX_TCP_KEEPINTVL = 32767;
+constexpr int MAX_TCP_KEEPCNT = 127;
+
+TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = MAX_TCP_KEEPIDLE + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepintvlAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = MAX_TCP_KEEPINTVL + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepidleToMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE,
+ &MAX_TCP_KEEPIDLE, sizeof(MAX_TCP_KEEPIDLE)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, MAX_TCP_KEEPIDLE);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL,
+ &MAX_TCP_KEEPINTVL, sizeof(MAX_TCP_KEEPINTVL)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, &get,
+ &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, MAX_TCP_KEEPINTVL);
+}
+
+TEST_P(TCPSocketPairTest, TCPKeepcountDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 9); // 9 keepalive probes.
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &kZero,
+ sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = MAX_TCP_KEEPCNT + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &MAX_TCP_KEEPCNT, sizeof(MAX_TCP_KEEPCNT)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, MAX_TCP_KEEPCNT);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToOne) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, keepaliveCount);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToNegative) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = -5;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallFailsWithErrno(EINVAL));
+}
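
In applications these four options are normally configured together: SO_KEEPALIVE turns probing on, TCP_KEEPIDLE sets the idle time before the first probe, TCP_KEEPINTVL the spacing of subsequent probes, and TCP_KEEPCNT how many unanswered probes abort the connection. A sketch of a combined helper (the function name is illustrative):

```c++
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>

// Enable keepalive with custom timings. With these settings a dead peer
// is detected after roughly idle_s + cnt * intvl_s seconds of silence.
// Returns 0 on success, -1 with errno set on failure.
int ConfigureKeepalive(int fd, int idle_s, int intvl_s, int cnt) {
  int on = 1;
  if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on)) != 0)
    return -1;
  if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle_s, sizeof(idle_s)) != 0)
    return -1;
  if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &intvl_s,
                 sizeof(intvl_s)) != 0)
    return -1;
  return setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &cnt, sizeof(cnt));
}
```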
+
+TEST_P(TCPSocketPairTest, SetOOBInline) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+TEST_P(TCPSocketPairTest, MsgTruncMsgPeek) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Read half of the data with MSG_TRUNC | MSG_PEEK. This way there will still
+ // be some data left to read in the next step even if the data gets consumed.
+ char received_data1[sizeof(sent_data) / 2] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data1,
+ sizeof(received_data1), MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(received_data1)));
+
+ // Check that nothing was copied into the buffer: with MSG_TRUNC the
+ // peeked bytes are discarded rather than copied out.
+ char zeros[sizeof(received_data1)] = {};
+ EXPECT_EQ(0, memcmp(zeros, received_data1, sizeof(received_data1)));
+
+ // Check that all of the data is still there.
+ char received_data2[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data2,
+ sizeof(received_data2), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(received_data2, sent_data, sizeof(sent_data)));
+}
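
The MSG_PEEK | MSG_TRUNC combination tested above has a practical use: it reports how much data is queued (up to the length passed in) without consuming the bytes and without copying them out. A sketch:

```c++
#include <sys/socket.h>

// Report how much data (up to 256 bytes) is queued on a TCP socket,
// leaving the queue intact: MSG_PEEK avoids consuming the bytes and
// MSG_TRUNC skips copying them into the buffer.
ssize_t PeekPendingLength(int fd) {
  char buf[256];  // Left untouched; only the return value matters.
  return recv(fd, buf, sizeof(buf), MSG_PEEK | MSG_TRUNC);
}
```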
+
+TEST_P(TCPSocketPairTest, SetCongestionControlSucceedsForSupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ // Netstack only supports reno and cubic, so we only test those two values here.
+ {
+ const char kSetCC[kTcpCaNameMax] = "reno";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+ {
+ const char kSetCC[kTcpCaNameMax] = "cubic";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetGetTCPCongestionShortReadBuffer) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ {
+ // Verify that getsockopt/setsockopt work with buffers smaller than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[sizeof(kSetCC)];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetGetTCPCongestionLargeReadBuffer) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ {
+ // Verify that getsockopt works with buffers larger than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax + 5];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // Linux copies the minimum of kTcpCaNameMax and the length of the
+ // passed-in buffer, and sets optlen to the number of bytes actually
+ // copied, irrespective of the actual length of the congestion control
+ // name.
+ EXPECT_EQ(kTcpCaNameMax, optlen);
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char old_cc[kTcpCaNameMax];
+ socklen_t optlen = sizeof(old_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &old_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+
+ const char kSetCC[] = "invalid_ca_cc";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallFailsWithErrno(ENOENT));
+
+ char got_cc[kTcpCaNameMax];
+ optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc)));
+}
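
Selecting a congestion control algorithm follows the same setsockopt pattern outside of tests; because an unavailable name fails with ENOENT, callers usually try a preferred algorithm and fall back. A hedged sketch (the algorithm names are illustrative and depend on what the kernel has available):

```c++
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <string.h>
#include <sys/socket.h>

// Try a preferred congestion control algorithm; fall back if the kernel
// doesn't know it. Returns 0 if either one was applied.
int PickCongestionControl(int fd) {
  const char* preferred = "bbr";   // May not be loaded everywhere.
  const char* fallback = "cubic";  // The usual Linux default.
  if (setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, preferred,
                 strlen(preferred)) == 0) {
    return 0;
  }
  // Likely ENOENT: no such algorithm known to this kernel.
  return setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, fallback,
                    strlen(fallback));
}
```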
+
+// Linux and Netstack both default to a 60s TCP_LINGER2 timeout.
+constexpr int kDefaultTCPLingerTimeout = 60;
+
+TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kDefaultTCPLingerTimeout);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero,
+ sizeof(kZero)),
+ SyscallSucceedsWithValue(0));
+
+ constexpr int kNegative = -1234;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
+ &kNegative, sizeof(kNegative)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Values above net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout
+ // on Linux (which defaults to 60 seconds).
+ constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
+ &kAboveDefault, sizeof(kAboveDefault)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kDefaultTCPLingerTimeout);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Set a value below net.ipv4.tcp_fin_timeout (60 seconds by default on
+ // Linux) so that it is stored as-is rather than capped.
+ constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPLingerTimeout);
+}
+
+TEST_P(TCPSocketPairTest, TestTCPCloseWithData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ScopedThread t([&]() {
+ // Shut down the write end to trigger sending of a FIN.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds());
+ char buf[3];
+ ASSERT_THAT(read(sockets->second_fd(), buf, 3),
+ SyscallSucceedsWithValue(3));
+ absl::SleepFor(absl::Milliseconds(50));
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ });
+
+ absl::SleepFor(absl::Milliseconds(50));
+ // Send some data then close.
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3));
+ t.Join();
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+}
+
+TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0); // 0 ms (disabled).
+}
+
+TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0); // 0 ms (disabled).
+}
+
+TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kNeg = -10;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kNeg, sizeof(kNeg)),
+ SyscallFailsWithErrno(EINVAL));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0); // 0 ms (disabled).
+}
+
+TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAbove = 10;
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kAbove, sizeof(kAbove)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kAbove);
+}
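
TCP_USER_TIMEOUT takes a value in milliseconds and bounds how long transmitted data may remain unacknowledged before the kernel forcibly closes the connection; zero restores the default behavior, as the tests above verify. A minimal sketch:

```c++
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>

// Abort the connection if sent data stays unacknowledged for 30 seconds.
// Setting the value back to 0 removes the user-imposed bound.
int SetUserTimeout(int fd) {
  unsigned int timeout_ms = 30000;
  return setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &timeout_ms,
                    sizeof(timeout_ms));
}
```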
+
+TEST_P(TCPSocketPairTest, SetTCPWindowClampBelowMinRcvBufConnectedSocket) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ // Discover the minimum receive buffer size by setting a value below it;
+ // the kernel clamps SO_RCVBUF up to the minimum.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &kZero,
+ sizeof(kZero)),
+ SyscallSucceeds());
+
+ // Now retrieve the minimum value for SO_RCVBUF as the set above should
+ // have caused SO_RCVBUF for the socket to be set to the minimum.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int min_so_rcvbuf = get;
+
+ {
+ // Setting TCP_WINDOW_CLAMP to zero for a connected socket is not permitted.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &kZero, sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Non-zero clamp values below MIN_SO_RCVBUF/2 should result in the clamp
+ // being set to MIN_SO_RCVBUF/2.
+ int below_half_min_so_rcvbuf = min_so_rcvbuf / 2 - 1;
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &below_half_min_so_rcvbuf, sizeof(below_half_min_so_rcvbuf)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(min_so_rcvbuf / 2, get);
+ }
+}
+
+TEST_P(TCPSocketPairTest, IpMulticastTtlDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_GT(get, 0);
+}
+
+TEST_P(TCPSocketPairTest, IpMulticastLoopDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 1);
+}
+
+TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) {
+ DisableSave ds; // Too many syscalls.
+ constexpr int kThreadCount = 1000;
+ std::unique_ptr<ScopedThread> instances[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ instances[i] = absl::make_unique<ScopedThread>([&]() {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ScopedThread t([&]() {
+ // Wait for the data to arrive, then close with it still unread; a
+ // close with unread data triggers a RST rather than a FIN.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ // Wait up to 20 seconds for the data.
+ constexpr int kPollTimeoutMs = 20000;
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ });
+
+ // Send some data then close.
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->first_fd(), kStr, 3),
+ SyscallSucceedsWithValue(3));
+ absl::SleepFor(absl::Milliseconds(10));
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ t.Join();
+ });
+ }
+ for (int i = 0; i < kThreadCount; i++) {
+ instances[i]->Join();
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.h b/test/syscalls/linux/socket_ip_tcp_generic.h
new file mode 100644
index 000000000..a3eff3c73
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_generic.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_TCP_GENERIC_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_TCP_GENERIC_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected TCP sockets.
+using TCPSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_TCP_GENERIC_H_
diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
new file mode 100644
index 000000000..4e79d21f4
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/tcp.h>
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_ip_tcp_generic.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVecToVec<SocketPairKind>(
+ std::vector<Middleware>{
+ NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)},
+ std::vector<SocketPairKind>{
+ IPv6TCPAcceptBindSocketPair(0),
+ IPv4TCPAcceptBindSocketPair(0),
+ DualStackTCPAcceptBindSocketPair(0),
+ });
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllTCPSockets, TCPSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback.cc b/test/syscalls/linux/socket_ip_tcp_loopback.cc
new file mode 100644
index 000000000..9db3037bc
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_loopback.cc
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_generic.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return {
+ IPv6TCPAcceptBindSocketPair(0),
+ IPv4TCPAcceptBindSocketPair(0),
+ DualStackTCPAcceptBindSocketPair(0),
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllTCPSockets, AllSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
new file mode 100644
index 000000000..f996b93d2
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/tcp.h>
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_stream_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVecToVec<SocketPairKind>(
+ std::vector<Middleware>{
+ NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)},
+ std::vector<SocketPairKind>{
+ IPv6TCPAcceptBindSocketPair(0),
+ IPv4TCPAcceptBindSocketPair(0),
+ DualStackTCPAcceptBindSocketPair(0),
+ });
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BlockingTCPSockets, BlockingStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
new file mode 100644
index 000000000..ffa377210
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/tcp.h>
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_non_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVecToVec<SocketPairKind>(
+ std::vector<Middleware>{
+ NoOp, SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &kSockOptOn)},
+ std::vector<SocketPairKind>{
+ IPv6TCPAcceptBindSocketPair(SOCK_NONBLOCK),
+ IPv4TCPAcceptBindSocketPair(SOCK_NONBLOCK),
+ });
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingTCPSockets, NonBlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc
new file mode 100644
index 000000000..f178f1af9
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of TCP and UDP sockets.
+using TcpUdpSocketPairTest = SocketPairTest;
+
+TEST_P(TcpUdpSocketPairTest, ShutdownWrFollowedBySendIsError) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Now shutdown the write end of the first socket.
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_WR), SyscallSucceeds());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EPIPE));
+}
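
Beyond returning EPIPE, a send() on a write-shut socket also raises SIGPIPE by default, which terminates the process unless the signal is handled or ignored; passing MSG_NOSIGNAL keeps the failure as a plain errno. A sketch:

```c++
#include <errno.h>
#include <sys/socket.h>

// Attempt a send that may race with a local SHUT_WR or a peer close.
// MSG_NOSIGNAL turns the default SIGPIPE into an ordinary EPIPE error.
bool TrySend(int fd, const char* buf, size_t len) {
  ssize_t n = send(fd, buf, len, MSG_NOSIGNAL);
  if (n < 0 && errno == EPIPE) {
    return false;  // Write side is shut down; handle gracefully.
  }
  return n >= 0;
}
```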
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ IPv6UDPBidirectionalBindSocketPair,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ IPv4UDPBidirectionalBindSocketPair,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ DualStackUDPBidirectionalBindSocketPair,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ IPv6TCPAcceptBindSocketPair,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ IPv4TCPAcceptBindSocketPair,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ DualStackTCPAcceptBindSocketPair,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllIPSockets, TcpUdpSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
new file mode 100644
index 000000000..edb86aded
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -0,0 +1,452 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ip_udp_generic.h"
+
+#include <errno.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(UDPSocketPairTest, MulticastTTLDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 1);
+}
+
+TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMin) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kMin = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kMin, sizeof(kMin)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kMin);
+}
+
+TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kMax = 255;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kMax, sizeof(kMax)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kMax);
+}
+
+TEST_P(UDPSocketPairTest, SetUDPMulticastTTLNegativeOne) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kArbitrary = 6;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kArbitrary, sizeof(kArbitrary)),
+ SyscallSucceeds());
+
+ constexpr int kNegOne = -1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kNegOne, sizeof(kNegOne)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 1);
+}
+
+TEST_P(UDPSocketPairTest, SetUDPMulticastTTLBelowMin) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kBelowMin = -2;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kBelowMin, sizeof(kBelowMin)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(UDPSocketPairTest, SetUDPMulticastTTLAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = 256;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(UDPSocketPairTest, SetUDPMulticastTTLChar) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr char kArbitrary = 6;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &kArbitrary, sizeof(kArbitrary)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kArbitrary);
+}
+
+TEST_P(UDPSocketPairTest, SetEmptyIPAddMembership) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct ip_mreqn req = {};
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &req, sizeof(req)),
+ SyscallFailsWithErrno(EINVAL));
+}
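
The EINVAL above comes from the zeroed request: a usable ip_mreqn must name a multicast group and may identify the local interface by address or index. A sketch of a properly populated request (the group address is illustrative):

```c++
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

// Join the 224.0.0.251 group. An interface index of 0 lets the kernel
// pick the interface; INADDR_ANY likewise leaves the local address open.
int JoinGroup(int fd) {
  struct ip_mreqn req = {};
  inet_pton(AF_INET, "224.0.0.251", &req.imr_multiaddr);
  req.imr_address.s_addr = htonl(INADDR_ANY);
  req.imr_ifindex = 0;
  return setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &req, sizeof(req));
}
```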
+
+TEST_P(UDPSocketPairTest, MulticastLoopDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+TEST_P(UDPSocketPairTest, SetMulticastLoop) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr char kSockOptOnChar = kSockOptOn;
+ constexpr char kSockOptOffChar = kSockOptOff;
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOffChar, sizeof(kSockOptOffChar)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOnChar, sizeof(kSockOptOnChar)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, SetReuseAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, ReusePortDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, SetReusePort) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+// Test getsockopt on a socket that does not have the IP_PKTINFO option set.
+TEST_P(UDPSocketPairTest, IPPKTINFODefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test setsockopt and getsockopt for a socket with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ // Storage for reading the option back with getsockopt.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_len, sizeof(get));
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_len, sizeof(get));
+}
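
Once IP_PKTINFO is enabled, each received datagram carries an in_pktinfo control message with the destination address and arriving interface; reading it requires recvmsg() with space for ancillary data. A hedged sketch (the helper name is illustrative):

```c++
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

// Receive one datagram and extract its in_pktinfo control message.
// Returns true if the pktinfo cmsg was present.
bool RecvWithPktinfo(int fd, char* buf, size_t len,
                     struct in_pktinfo* info) {
  char cbuf[CMSG_SPACE(sizeof(struct in_pktinfo))] = {};
  struct iovec iov = {buf, len};
  struct msghdr msg = {};
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  msg.msg_control = cbuf;
  msg.msg_controllen = sizeof(cbuf);
  if (recvmsg(fd, &msg, 0) < 0) return false;
  for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
       cmsg = CMSG_NXTHDR(&msg, cmsg)) {
    if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
      memcpy(info, CMSG_DATA(cmsg), sizeof(*info));
      return true;
    }
  }
  return false;
}
```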
+
+// Holds TOS or TClass information for IPv4 or IPv6 respectively.
+struct RecvTosOption {
+ int level;
+ int option;
+};
+
+RecvTosOption GetRecvTosOption(int domain) {
+ TEST_CHECK(domain == AF_INET || domain == AF_INET6);
+ RecvTosOption opt;
+ switch (domain) {
+ case AF_INET:
+ opt.level = IPPROTO_IP;
+ opt.option = IP_RECVTOS;
+ break;
+ case AF_INET6:
+ opt.level = IPPROTO_IPV6;
+ opt.option = IPV6_RECVTCLASS;
+ break;
+ }
+ return opt;
+}
+
+// Ensure that receiving TOS or TCLASS is off by default.
+TEST_P(UDPSocketPairTest, RecvTosDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as
+// expected.
+TEST_P(UDPSocketPairTest, SetRecvTos) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+// Test that any socket (including IPv6-only) accepts the IPv4 TOS option:
+// this mirrors behavior in Linux.
+TEST_P(UDPSocketPairTest, TOSRecvMismatch) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(AF_INET);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+}
+
+// Test that an IPv4 socket does not support the IPv6 TClass option.
+TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
+ // This should only test AF_INET sockets for the mismatch behavior.
+ SKIP_IF(GetParam().domain != AF_INET);
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS,
+ &get, &get_len),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_generic.h b/test/syscalls/linux/socket_ip_udp_generic.h
new file mode 100644
index 000000000..106c54e9f
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_generic.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_GENERIC_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_GENERIC_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected UDP sockets.
+using UDPSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_GENERIC_H_
diff --git a/test/syscalls/linux/socket_ip_udp_loopback.cc b/test/syscalls/linux/socket_ip_udp_loopback.cc
new file mode 100644
index 000000000..c7fa44884
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_loopback.cc
@@ -0,0 +1,50 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_generic.h"
+#include "test/syscalls/linux/socket_ip_udp_generic.h"
+#include "test/syscalls/linux/socket_non_stream.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return {
+ IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUDPSockets, AllSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUDPSockets, NonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUDPSockets, UDPSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
new file mode 100644
index 000000000..d6925a8df
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_non_stream_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return {
+ IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BlockingUDPSockets, BlockingNonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
new file mode 100644
index 000000000..d675eddc6
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_non_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return {
+ IPv6UDPBidirectionalBindSocketPair(SOCK_NONBLOCK),
+ IPv4UDPBidirectionalBindSocketPair(SOCK_NONBLOCK),
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingUDPSockets, NonBlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
new file mode 100644
index 000000000..1c7b0cf90
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -0,0 +1,474 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to single unbound IP sockets.
+using IPUnboundSocketTest = SimpleSocketTest;
+
+TEST_P(IPUnboundSocketTest, TtlDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_TRUE(get == 64 || get == 127);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+TEST_P(IPUnboundSocketTest, SetTtl) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get1 = -1;
+ socklen_t get1_sz = sizeof(get1);
+ EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get1, &get1_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get1_sz, sizeof(get1));
+
+ int set = 100;
+ if (set == get1) {
+ set += 1;
+ }
+ socklen_t set_sz = sizeof(set);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+
+ int get2 = -1;
+ socklen_t get2_sz = sizeof(get2);
+ EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get2_sz, sizeof(get2));
+ EXPECT_EQ(get2, set);
+}
+
+TEST_P(IPUnboundSocketTest, ResetTtlToDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get1 = -1;
+ socklen_t get1_sz = sizeof(get1);
+ EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get1, &get1_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get1_sz, sizeof(get1));
+
+ int set1 = 100;
+ if (set1 == get1) {
+ set1 += 1;
+ }
+ socklen_t set1_sz = sizeof(set1);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set1, set1_sz),
+ SyscallSucceedsWithValue(0));
+
+ int set2 = -1;
+ socklen_t set2_sz = sizeof(set2);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set2, set2_sz),
+ SyscallSucceedsWithValue(0));
+
+ int get2 = -1;
+ socklen_t get2_sz = sizeof(get2);
+ EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get2_sz, sizeof(get2));
+ EXPECT_EQ(get2, get1);
+}
+
+TEST_P(IPUnboundSocketTest, ZeroTtl) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int set = 0;
+ socklen_t set_sz = sizeof(set);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPUnboundSocketTest, InvalidLargeTtl) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int set = 256;
+ socklen_t set_sz = sizeof(set);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int set = -2;
+ socklen_t set_sz = sizeof(set);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz),
+ SyscallFailsWithErrno(EINVAL));
+}
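+
+// Note: the three failure cases above line up with Linux's validation for
+// IP_TTL, which accepts 1..255 plus the special value -1 (reset to the
+// default); 0, 256, and -2 all fall outside that set and are rejected with
+// EINVAL.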
+
+struct TOSOption {
+ int level;
+ int option;
+ int cmsg_level;
+};
+
+constexpr int INET_ECN_MASK = 3;
+
+static TOSOption GetTOSOption(int domain) {
+ TOSOption opt;
+ switch (domain) {
+ case AF_INET:
+ opt.level = IPPROTO_IP;
+ opt.option = IP_TOS;
+ opt.cmsg_level = SOL_IP;
+ break;
+ case AF_INET6:
+ opt.level = IPPROTO_IPV6;
+ opt.option = IPV6_TCLASS;
+ opt.cmsg_level = SOL_IPV6;
+ break;
+ }
+ return opt;
+}
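+
+// Note: GetTOSOption centralizes the family split. IP_TOS (level
+// IPPROTO_IP) and IPV6_TCLASS (level IPPROTO_IPV6) both control the same
+// 8-bit DSCP+ECN field of the respective IP header, so the tests below run
+// unchanged over either address family.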
+
+TEST_P(IPUnboundSocketTest, TOSDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ constexpr int kDefaultTOS = 0;
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, kDefaultTOS);
+}
+
+TEST_P(IPUnboundSocketTest, SetTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = 0xC0;
+ socklen_t set_sz = sizeof(set);
+ TOSOption t = GetTOSOption(GetParam().domain);
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, set);
+}
+
+TEST_P(IPUnboundSocketTest, ZeroTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = 0;
+ socklen_t set_sz = sizeof(set);
+ TOSOption t = GetTOSOption(GetParam().domain);
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, set);
+}
+
+TEST_P(IPUnboundSocketTest, InvalidLargeTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ // Test with a value that does not fit in a single byte.
+ int set = 256;
+ constexpr int kDefaultTOS = 0;
+ socklen_t set_sz = sizeof(set);
+ TOSOption t = GetTOSOption(GetParam().domain);
+ if (GetParam().domain == AF_INET) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+ } else {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallFailsWithErrno(EINVAL));
+ }
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, kDefaultTOS);
+}
+
+TEST_P(IPUnboundSocketTest, CheckSkipECN) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = 0xFF;
+ socklen_t set_sz = sizeof(set);
+ TOSOption t = GetTOSOption(GetParam().domain);
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+ int expect = static_cast<uint8_t>(set);
+ if (GetParam().protocol == IPPROTO_TCP) {
+ expect &= ~INET_ECN_MASK;
+ }
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, expect);
+}
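+
+// Note: INET_ECN_MASK covers the low two bits of the TOS byte, which carry
+// the ECN field. For TCP the stack owns those bits, so a value written via
+// IP_TOS/IPV6_TCLASS reads back with the ECN bits cleared, which is exactly
+// what CheckSkipECN asserts.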
+
+TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = 0xC0;
+ socklen_t set_sz = 0;
+ TOSOption t = GetTOSOption(GetParam().domain);
+ if (GetParam().domain == AF_INET) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+ } else {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallFailsWithErrno(EINVAL));
+ }
+ int get = -1;
+ socklen_t get_sz = 0;
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, 0);
+ EXPECT_EQ(get, -1);
+}
+
+TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = 0xC0;
+ constexpr int kDefaultTOS = 0;
+ TOSOption t = GetTOSOption(GetParam().domain);
+ for (socklen_t i = 1; i < sizeof(int); i++) {
+ int expect_tos;
+ socklen_t expect_sz;
+ if (GetParam().domain == AF_INET) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i),
+ SyscallSucceedsWithValue(0));
+ expect_tos = set;
+ expect_sz = sizeof(uint8_t);
+ } else {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i),
+ SyscallFailsWithErrno(EINVAL));
+ expect_tos = kDefaultTOS;
+ expect_sz = i;
+ }
+ uint get = -1;
+ socklen_t get_sz = i;
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, expect_sz);
+ // getsockopt may copy only get_sz bytes, so compare just the lower
+ // get_sz bytes of the retrieved value against expect_tos.
+ EXPECT_EQ(get & ~(~0 << (get_sz * 8)), expect_tos);
+ }
+}
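+
+// Note: in SmallTOSOptionSize, ~(~0 << (get_sz * 8)) builds a mask over the
+// low get_sz bytes; with get_sz == 1 the mask is 0xFF, so a one-byte
+// partial copy of 0xC0 still compares equal to expect_tos.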
+
+TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = 0xC0;
+ TOSOption t = GetTOSOption(GetParam().domain);
+ for (socklen_t i = sizeof(int); i < 10; i++) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i),
+ SyscallSucceedsWithValue(0));
+ int get = -1;
+ socklen_t get_sz = i;
+ // We expect the syscall handler to copy at most sizeof(int) bytes, as
+ // asserted by the check below, so the copy in getsockopt should not
+ // overflow the buffer.
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(int));
+ EXPECT_EQ(get, set);
+ }
+}
+
+TEST_P(IPUnboundSocketTest, NegativeTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int set = -1;
+ socklen_t set_sz = sizeof(set);
+ TOSOption t = GetTOSOption(GetParam().domain);
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+ int expect;
+ if (GetParam().domain == AF_INET) {
+ expect = static_cast<uint8_t>(set);
+ if (GetParam().protocol == IPPROTO_TCP) {
+ expect &= ~INET_ECN_MASK;
+ }
+ } else {
+ // For IPv6, setting TCLASS to -1 resets the traffic class to its
+ // default value (0).
+ expect = 0;
+ }
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, expect);
+}
+
+TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ int set = -2;
+ socklen_t set_sz = sizeof(set);
+ TOSOption t = GetTOSOption(GetParam().domain);
+ int expect;
+ if (GetParam().domain == AF_INET) {
+ ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallSucceedsWithValue(0));
+ expect = static_cast<uint8_t>(set);
+ if (GetParam().protocol == IPPROTO_TCP) {
+ expect &= ~INET_ECN_MASK;
+ }
+ } else {
+ ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ SyscallFailsWithErrno(EINVAL));
+ expect = 0;
+ }
+ int get = 0;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_sz, sizeof(get));
+ EXPECT_EQ(get, expect);
+}
+
+TEST_P(IPUnboundSocketTest, NullTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+ int set_sz = sizeof(int);
+ if (GetParam().domain == AF_INET) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ SyscallFailsWithErrno(EFAULT));
+ } else { // AF_INET6
+ // The AF_INET6 behavior is not yet compatible. gVisor reads optval from
+ // user memory in the syscall handler, and implementing the Linux
+ // behavior just for IPv6 would require substantial refactoring.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ SyscallFailsWithErrno(EFAULT));
+ } else {
+ // Linux's IPv6 stack treats a nullptr optval as an input of 0, so the
+ // call succeeds (net/ipv6/ipv6_sockglue.c, do_ipv6_setsockopt()).
+ //
+ // Linux's implementation arguably needs fixing here, since passing a
+ // nullptr optval with a non-zero optlen may not be valid.
+ // TODO(b/158666797): Combine the gVisor and Linux cases for IPv6.
+ // Some kernel versions return EFAULT, so we accept both outcomes.
+ EXPECT_THAT(
+ setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(0)));
+ }
+ }
+ socklen_t get_sz = sizeof(int);
+ EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, nullptr, &get_sz),
+ SyscallFailsWithErrno(EFAULT));
+ int get = -1;
+ EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, nullptr),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) {
+ SKIP_IF(GetParam().protocol == IPPROTO_TCP);
+
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+
+ in_addr addr4;
+ in6_addr addr6;
+ ASSERT_THAT(inet_pton(AF_INET, "127.0.0.1", &addr4), ::testing::Eq(1));
+ ASSERT_THAT(inet_pton(AF_INET6, "fe80::", &addr6), ::testing::Eq(1));
+
+ cmsghdr cmsg = {};
+ cmsg.cmsg_len = sizeof(cmsg);
+ cmsg.cmsg_level = t.cmsg_level;
+ cmsg.cmsg_type = t.option;
+
+ msghdr msg = {};
+ msg.msg_control = &cmsg;
+ msg.msg_controllen = sizeof(cmsg);
+ if (GetParam().domain == AF_INET) {
+ msg.msg_name = &addr4;
+ msg.msg_namelen = sizeof(addr4);
+ } else {
+ msg.msg_name = &addr6;
+ msg.msg_namelen = sizeof(addr6);
+ }
+
+ EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL));
+}
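+
+// Note: InsufficientBufferTOS fails because cmsg_len is sizeof(cmsghdr)
+// alone, leaving no room for the int payload an IP_TOS/IPV6_TCLASS control
+// message carries. A well-formed control message would reserve space for
+// the data, e.g. (sketch):
+//
+//   char buf[CMSG_SPACE(sizeof(int))] = {};
+//   msghdr m = {};
+//   m.msg_control = buf;
+//   m.msg_controllen = sizeof(buf);
+//   cmsghdr* c = CMSG_FIRSTHDR(&m);
+//   c->cmsg_len = CMSG_LEN(sizeof(int));
+//   c->cmsg_level = IPPROTO_IP;
+//   c->cmsg_type = IP_TOS;
+//   *reinterpret_cast<int*>(CMSG_DATA(c)) = 0xC0;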
+
+TEST_P(IPUnboundSocketTest, ReuseAddrDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+TEST_P(IPUnboundSocketTest, SetReuseAddr) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ASSERT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ IPUnboundSockets, IPUnboundSocketTest,
+ ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
+ ApplyVec<SocketKind>(IPv4UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM},
+ List<int>{0,
+ SOCK_NONBLOCK})),
+ ApplyVec<SocketKind>(IPv6UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM},
+ List<int>{0,
+ SOCK_NONBLOCK})),
+ ApplyVec<SocketKind>(IPv4TCPUnboundSocket,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0,
+ SOCK_NONBLOCK})),
+ ApplyVec<SocketKind>(IPv6TCPUnboundSocket,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK}))))));
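+
+// Note: each AllBitwiseCombinations call above expands to two variants,
+// e.g. {SOCK_DGRAM, SOCK_DGRAM | SOCK_NONBLOCK}, so the suite covers the
+// blocking and non-blocking flavors of UDP and TCP over IPv4 and IPv6,
+// eight socket kinds in total.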
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc
new file mode 100644
index 000000000..80f12b0a9
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h"
+
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Verifies that a newly instantiated TCP socket does not have the
+// broadcast socket option enabled.
+TEST_P(IPv4TCPUnboundExternalNetworkingSocketTest, TCPBroadcastDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+// Verifies that a TCP socket reports the broadcast socket option as enabled
+// after it has been set.
+TEST_P(IPv4TCPUnboundExternalNetworkingSocketTest, SetTCPBroadcast) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ EXPECT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
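+
+// Note: SO_BROADCAST is only meaningful for datagram sockets; on TCP it is
+// settable and readable, as verified above, but has no effect on the
+// connection itself.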
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h
new file mode 100644
index 000000000..fb582b224
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h
@@ -0,0 +1,30 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_TCP_UNBOUND_EXTERNAL_NETWORKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_TCP_UNBOUND_EXTERNAL_NETWORKING_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to unbound IPv4 TCP sockets in a sandbox
+// with external networking support.
+using IPv4TCPUnboundExternalNetworkingSocketTest = SimpleSocketTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_TCP_UNBOUND_EXTERNAL_NETWORKING_H_
diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
new file mode 100644
index 000000000..797c4174e
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h"
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketKind> GetSockets() {
+ return ApplyVec<SocketKind>(
+ IPv4TCPUnboundSocket,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(IPv4TCPUnboundSockets,
+ IPv4TCPUnboundExternalNetworkingSocketTest,
+ ::testing::ValuesIn(GetSockets()));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
new file mode 100644
index 000000000..de0f5f01b
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -0,0 +1,2456 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
+
+#include <arpa/inet.h>
+#include <net/if.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Check that packets are not received without a group membership. Default send
+// interface configured by bind.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the first FD to the loopback. This is an alternative to
+ // IP_MULTICAST_IF for setting the default send interface.
+ auto sender_addr = V4Loopback();
+ EXPECT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address. If multicast worked like unicast,
+ // this would ensure that we get the packet.
+ auto receiver_addr = V4Any();
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Send the multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
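+
+// Note: this is the multicast/unicast asymmetry the test relies on:
+// delivery is gated on group membership (IP_ADD_MEMBERSHIP), so binding to
+// the any address alone is not enough to receive a group's packets.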
+
+// Check that not setting a default send interface prevents multicast packets
+// from being sent. Group membership interface configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive any
+ // unicast packet.
+ auto receiver_addr = V4Any();
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallFailsWithErrno(ENETUNREACH));
+}
+
+// Check that not setting a default send interface prevents multicast packets
+// from being sent. Group membership interface configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive any
+ // unicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallFailsWithErrno(ENETUNREACH));
+}
+
+// Check that multicast works when the default send interface is configured by
+// bind and the group membership is configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the first FD to the loopback. This is an alternative to
+ // IP_MULTICAST_IF for setting the default send interface.
+ auto sender_addr = V4Loopback();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast works when the default send interface is configured by
+// bind and the group membership is configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the first FD to the loopback. This is an alternative to
+ // IP_MULTICAST_IF for setting the default send interface.
+ auto sender_addr = V4Loopback();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
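+
+// Note: the two tests above differ only in how the membership interface is
+// named. ip_mreq selects it by local address (imr_interface), while the
+// Linux-specific ip_mreqn selects it by index (imr_ifindex); both layouts
+// are accepted by IP_ADD_MEMBERSHIP.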
+
+// Check that multicast works when the default send interface is configured by
+// IP_MULTICAST_IF, the send address is specified in sendto, and the group
+// membership is configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast works when the default send interface is configured by
+// IP_MULTICAST_IF, the send address is specified in sendto, and the group
+// membership is configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast works when the default send interface is configured by
+// IP_MULTICAST_IF, the send address is specified in connect, and the group
+// membership is configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto connect_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ ASSERT_THAT(
+ RetryEINTR(connect)(socket1->get(),
+ reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ connect_addr.addr_len),
+ SyscallSucceeds());
+
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast works when the default send interface is configured by
+// IP_MULTICAST_IF, the send address is specified in connect, and the group
+// membership is configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto connect_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ ASSERT_THAT(
+ RetryEINTR(connect)(socket1->get(),
+ reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ connect_addr.addr_len),
+ SyscallSucceeds());
+
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast works when the default send interface is configured by
+// IP_MULTICAST_IF, the send address is specified in sendto, and the group
+// membership is configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the first FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast works when the default send interface is configured by
+// IP_MULTICAST_IF, the send address is specified in sendto, and the group
+// membership is configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the first FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that a connected socket does not receive its own multicast packet
+// when the default send interface is configured by IP_MULTICAST_IF, the send
+// address is specified in connect, and the group membership is configured by
+// address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the first FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto connect_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ EXPECT_THAT(
+ RetryEINTR(connect)(socket1->get(),
+ reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ connect_addr.addr_len),
+ SyscallSucceeds());
+
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
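+
+// Note: the looped-back packet is not delivered above because the socket is
+// connected. A connected UDP socket only accepts datagrams whose source
+// matches its peer, and the loopback copy arrives with the sender's unicast
+// source address rather than the multicast group it connected to.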
+
+// Check that a connected socket does not receive its own multicast packet
+// when the default send interface is configured by IP_MULTICAST_IF, the send
+// address is specified in connect, and the group membership is configured by
+// NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Bind the first FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto connect_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ ASSERT_THAT(
+ RetryEINTR(connect)(socket1->get(),
+ reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ connect_addr.addr_len),
+ SyscallSucceeds());
+
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that a multicast packet sent over the loopback interface is received
+// even with IP_MULTICAST_LOOP disabled: the default send interface is
+// configured by IP_MULTICAST_IF, the send address is specified in sendto, and
+// the group membership is configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ // Bind the first FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
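+
+// Note: IP_MULTICAST_LOOP controls the software loopback of a socket's own
+// multicast traffic. The packet above is still received with the option
+// disabled because the chosen send interface is the loopback device itself,
+// which returns every packet to the host regardless of that option.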
+
+// Check that a multicast packet sent over the loopback interface is received
+// even with IP_MULTICAST_LOOP disabled: the default send interface is
+// configured by IP_MULTICAST_IF, the send address is specified in sendto, and
+// the group membership is configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Set the default send interface.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ // Bind the first FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that dropping a group membership that does not exist fails.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Unregister from a membership that we didn't have.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
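+
+// Note: IP_DROP_MEMBERSHIP must name an existing membership (group plus
+// interface) exactly; dropping one that was never added fails with
+// EADDRNOTAVAIL, as asserted above.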
+
+// Check that dropping a group membership prevents multicast packets from being
+// delivered. Default send address configured by bind and group membership
+// interface configured by address.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the first FD to the loopback. This is an alternative to
+ // IP_MULTICAST_IF for setting the default send interface.
+ auto sender_addr = V4Loopback();
+ EXPECT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register and unregister to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that dropping a group membership prevents multicast packets from being
+// delivered. Default send address configured by bind and group membership
+// interface configured by NIC ID.
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the first FD to the loopback. This is an alternative to
+ // IP_MULTICAST_IF for setting the default send interface.
+ auto sender_addr = V4Loopback();
+ EXPECT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register and unregister to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
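+// Check that IP_MULTICAST_IF accepts a zeroed ip_mreqn, i.e. no interface
+// selected; this is the conventional way to reset the option to its default.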
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn iface = {};
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn iface = {};
+ iface.imr_ifindex = -1;
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = inet_addr("255.255.255");
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Create a valid full-sized request.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+
+ // Send an optlen of 1 to check that optlen is enforced.
+ EXPECT_THAT(
+ setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, 0);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
+ // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
+ // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
+ EXPECT_EQ(size, sizeof(in_addr));
+
+ // getsockopt(IP_MULTICAST_IF) only returns the interface address, which has
+ // not been set here, so every field should read as zero.
+ EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
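+
+// For reference, a minimal sketch of the two request structs used by the
+// IP_MULTICAST_IF tests above and below, as declared in <netinet/in.h> on
+// Linux. ip_mreqn::imr_multiaddr is the first field, which is why a
+// getsockopt() that writes back a bare in_addr overlays exactly that field:
+//
+//   struct ip_mreq {
+//     struct in_addr imr_multiaddr;  // Multicast group address.
+//     struct in_addr imr_interface;  // Local address of the interface.
+//   };
+//
+//   struct ip_mreqn {
+//     struct in_addr imr_multiaddr;  // Multicast group address.
+//     struct in_addr imr_address;    // Local address of the interface.
+//     int imr_ifindex;               // Interface index; 0 means unset.
+//   };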
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ in_addr set = {};
+ set.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
+ // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
+ // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
+ EXPECT_EQ(size, sizeof(in_addr));
+ EXPECT_EQ(get.imr_multiaddr.s_addr, set.s_addr);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreq set = {};
+ set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
+ // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
+ // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
+ EXPECT_EQ(size, sizeof(in_addr));
+ EXPECT_EQ(get.imr_multiaddr.s_addr, set.imr_interface.s_addr);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn set = {};
+ set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(in_addr));
+ EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ in_addr set = {};
+ set.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, set.s_addr);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreq set = {};
+ set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, set.imr_interface.s_addr);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn set = {};
+ set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, 0);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn group = {};
+ group.imr_address.s_addr = inet_addr("255.255.255");
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(ENODEV));
+}
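+
+// Note on the two join-failure tests above: with neither a valid imr_ifindex
+// nor a usable imr_address, the kernel cannot resolve an interface for the
+// group, so IP_ADD_MEMBERSHIP reports ENODEV. (This reading is inferred from
+// the assertions themselves.)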
+
+// Check that a socket may not register the same multicast membership twice.
+TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto fd = socket1->get();
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+
+ EXPECT_THAT(
+ setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+// Check that two sockets can join the same multicast group at the same time.
+TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Drop the membership twice on each socket; the second call for each socket
+ // should fail.
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and both will receive data on it.
+TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
+ std::unique_ptr<SocketPair> socket_pairs[2] = {
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())),
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))};
+
+ ip_mreq iface = {}, group = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+
+ // Create two socketpairs with the exact same configuration.
+ for (auto& sockets : socket_pairs) {
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
+ &iface, sizeof(iface)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ // Get the port assigned.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+ }
+
+ // Send a multicast packet to the group from two different sockets and verify
+ // it is received by both sockets that joined that group.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ for (auto& sockets : socket_pairs) {
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet on both sockets.
+ for (auto& sockets : socket_pairs) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+ }
+}
+
+// Check that on two sockets that joined a group and listen on ANY, dropping
+// memberships one by one will continue to deliver packets to both sockets until
+// both memberships have been dropped.
+TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
+ std::unique_ptr<SocketPair> socket_pairs[2] = {
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())),
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))};
+
+ ip_mreq iface = {}, group = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+
+ // Create two socketpairs with the exact same configuration.
+ for (auto& sockets : socket_pairs) {
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
+ &iface, sizeof(iface)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ // Get the port assigned.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+ }
+
+ // Drop the membership of the first socket pair and verify data is still
+ // received.
+ ASSERT_THAT(setsockopt(socket_pairs[0]->second_fd(), IPPROTO_IP,
+ IP_DROP_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+ // Send a packet from each socket_pair.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ for (auto& sockets : socket_pairs) {
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet on both sockets.
+ for (auto& sockets : socket_pairs) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+ }
+
+ // Drop the membership of the second socket pair and verify data stops being
+ // received.
+ ASSERT_THAT(setsockopt(socket_pairs[1]->second_fd(), IPPROTO_IP,
+ IP_DROP_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+ // Send a packet from each socket_pair.
+ for (auto& sockets : socket_pairs) {
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ char recv_buf[sizeof(send_buf)] = {};
+ for (auto& sockets : socket_pairs) {
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf,
+ sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+ }
+}
+
+// Check that a receiving socket can bind to the multicast address before
+// joining the group and receive data once the group has been joined.
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind second socket (receiver) to the multicast address.
+ auto receiver_addr = V4Multicast();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ // Update receiver_addr with the correct port number.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet on the first socket out the loopback interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+ auto sendto_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that a receiving socket can bind to the multicast address and won't
+// receive multicast data if it hasn't joined the group.
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind second socket (receiver) to the multicast address.
+ auto receiver_addr = V4Multicast();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ // Update receiver_addr with the correct port number.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Send a multicast packet on the first socket out the loopback interface.
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+ auto sendto_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we don't receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that a socket can bind to a multicast address and still send out
+// packets.
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind second socket (receiver) to the ANY address.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Bind the first socket (sender) to the multicast address.
+ auto sender_addr = V4Multicast();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t sender_addr_len = sender_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
+
+ // Send a packet on the first socket to the loopback address.
+ auto sendto_addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that a receiving socket can bind to the broadcast address and receive
+// broadcast packets.
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind second socket (receiver) to the broadcast address.
+ auto receiver_addr = V4Broadcast();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Send a broadcast packet on the first socket out the loopback interface.
+ EXPECT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ // Note: Binding to the loopback interface makes the broadcast go out over it.
+ auto sender_bind_addr = V4Loopback();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr),
+ sender_bind_addr.addr_len),
+ SyscallSucceeds());
+ auto sendto_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the broadcast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that a socket can bind to the broadcast address and still send out
+// packets.
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind second socket (receiver) to the ANY address.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(socket2->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Bind the first socket (sender) to the broadcast address.
+ auto sender_addr = V4Broadcast();
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t sender_addr_len = sender_addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
+
+ // Send a packet on the first socket to the loopback address.
+ auto sendto_addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that SO_REUSEADDR always delivers to the most recently bound socket.
+//
+// FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable
+// random and co-op save (below) once that is fixed.
+TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
+ std::vector<std::unique_ptr<FileDescriptor>> sockets;
+ sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
+
+ ASSERT_THAT(setsockopt(sockets[0]->get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(sockets[0]->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ constexpr int kMessageSize = 200;
+
+ // FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly.
+ const DisableSave ds;
+
+ for (int i = 0; i < 10; i++) {
+ // Add a new receiver.
+ sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
+ auto& last = sockets.back();
+ ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Send a new message to the SO_REUSEADDR group. We use a new sender socket
+ // each time so that a new ephemeral port is used for each message. This
+ // ensures that we aren't doing REUSEPORT-like hash load balancing.
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ char send_buf[kMessageSize];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Verify that the most recent socket got the message. We don't expect any
+ // of the other sockets to have received it, but we will check that later.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(
+ RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+
+ // Verify that no other messages were received.
+ for (auto& socket : sockets) {
+ char recv_buf[kMessageSize] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT.
+ socket2->reset();
+
+ // Bind socket3 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT.
+ socket2->reset();
+
+ // Bind socket3 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, also with REUSEADDR and
+ // REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, also with REUSEADDR and
+ // REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
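+
+// Taken together, the Convertible and Reversable tests above suggest that a
+// port's effective reuse mode is the intersection of its binders' options:
+// while every bound socket has both SO_REUSEADDR and SO_REUSEPORT set, a new
+// binder may use either option alone; once some socket is bound with only one
+// of the two, later binders must use that same option. (This summary is
+// inferred from the assertions, not stated by kernel documentation.)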
+
+// Check that REUSEPORT takes precedence over REUSEADDR.
+TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
+ auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(receiver1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind receiver2 to the same address as socket1, also with REUSEADDR and
+ // REUSEPORT.
+ ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ constexpr int kMessageSize = 10;
+
+ for (int i = 0; i < 100; ++i) {
+ // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket
+ // each time so that a new ephemeral port is used for each message. This
+ // ensures that we cycle through hashes.
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ char send_buf[kMessageSize] = {};
+ EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ }
+
+ // Check that both receivers got messages. This checks that we are using load
+ // balancing (REUSEPORT) instead of the most recently bound socket
+ // (REUSEADDR).
+ char recv_buf[kMessageSize] = {};
+ EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(kMessageSize));
+ EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(kMessageSize));
+}
+
+// Check that connect returns EAGAIN when out of local ephemeral ports.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) {
+ auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ constexpr int kClients = 65536;
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(receiver1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+ std::vector<std::unique_ptr<FileDescriptor>> sockets;
+ for (int i = 0; i < kClients; i++) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len);
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
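+
+// The pool exhausted above is the local ephemeral port range. A minimal,
+// illustrative sketch of how it could be inspected on Linux (not part of the
+// test; the typical values are an assumption about default sysctls):
+//
+//   FILE* f = fopen("/proc/sys/net/ipv4/ip_local_port_range", "r");
+//   int lo = 0, hi = 0;
+//   if (f != nullptr && fscanf(f, "%d %d", &lo, &hi) == 2) {
+//     // Typically "32768 60999": far fewer ports than the 65536 connect()
+//     // calls attempted above, so exhaustion is guaranteed.
+//   }
+//   if (f != nullptr) fclose(f);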
+
+// Test that a socket receives the packet info control message (IP_PKTINFO).
+TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF((IsRunningWithHostinet()));
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto sender_addr = V4Loopback();
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t sender_addr_len = sender_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
+
+ auto receiver_addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port;
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {};
+ size_t cmsg_data_len = sizeof(in_pktinfo);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Get loopback index.
+ ifreq ifr = {};
+ absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ ASSERT_NE(ifr.ifr_ifindex, 0);
+
+ // Check the returned pktinfo data.
+ in_pktinfo received_pktinfo = {};
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
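+
+// The test above inspects only CMSG_FIRSTHDR, which is sufficient when
+// exactly one control message is expected. For reference, a sketch of the
+// general walk using the standard CMSG_* macros from <sys/socket.h>:
+//
+//   for (cmsghdr* c = CMSG_FIRSTHDR(&received_msg); c != nullptr;
+//        c = CMSG_NXTHDR(&received_msg, c)) {
+//     if (c->cmsg_level == SOL_IP && c->cmsg_type == IP_PKTINFO) {
+//       in_pktinfo info;
+//       memcpy(&info, CMSG_DATA(c), sizeof(info));
+//       // info.ipi_ifindex, info.ipi_spec_dst and info.ipi_addr are valid.
+//     }
+//   }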
+
+// Check that setting SO_RCVBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value, so use a value that will still be smaller than
+ // min even after doubling.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &below_min,
+ sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_RCVBUF above max is clamped to the maximum
+// receive buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufAboveMax) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover the maximum buffer size by setting it to a very large value.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &above_max,
+ sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_RCVBUF to a value between min and max is honored.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBuf) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover the maximum buffer size by setting it to a very large value.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &quarter_sz,
+ sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+ ASSERT_EQ(quarter_sz, val);
+}
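+
+// Worked example of the doubling behavior referenced above, assuming a stock
+// Linux kernel with default rmem sysctls and a hypothetical bound UDP socket
+// fd (the exact numbers are illustrative):
+//
+//   int req = 4096;
+//   setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &req, sizeof(req));
+//   int got = 0;
+//   socklen_t len = sizeof(got);
+//   getsockopt(fd, SOL_SOCKET, SO_RCVBUF, &got, &len);  // got == 8192.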
+
+// Check that setting SO_SNDBUF below min is clamped to the minimum
+// send buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufBelowMin) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value, so use a value that will still be smaller than
+ // min even after doubling.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &below_min,
+ sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_SNDBUF above max is clamped to the maximum
+// send buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufAboveMax) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover the maximum buffer size by setting it to a very large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &above_max,
+ sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_SNDBUF to a value between min and max is honored.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover the maximum buffer size by setting it to a very large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &quarter_sz,
+ sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+
+ ASSERT_EQ(quarter_sz, val);
+}
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.h b/test/syscalls/linux/socket_ipv4_udp_unbound.h
new file mode 100644
index 000000000..f64c57645
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.h
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to IPv4 UDP sockets.
+using IPv4UDPUnboundSocketTest = SimpleSocketTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
new file mode 100644
index 000000000..d690d9564
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -0,0 +1,1099 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h"
+
+#include <arpa/inet.h>
+#include <ifaddrs.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
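+// Returns an AF_INET TestAddress with the address and port zeroed, for use
+// as an out-parameter to calls such as getsockname() and recvfrom().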
+TestAddress V4EmptyAddress() {
+ TestAddress t("V4Empty");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ return t;
+}
+
+void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
+ got_if_infos_ = false;
+
+ // Get interface list.
+ ASSERT_NO_ERRNO(if_helper_.Load());
+ std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
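+ // These tests expect exactly two IPv4 interfaces: the loopback device and
+ // one ethernet device; tests that need interface info skip themselves
+ // otherwise (see got_if_infos_).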
+ if (if_names.size() != 2) {
+ return;
+ }
+
+ // Figure out which interface is where.
+ std::string lo = if_names[0];
+ std::string eth = if_names[1];
+ if (lo != "lo") std::swap(lo, eth);
+ if (lo != "lo") return;
+
+ lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
+ auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
+ if (lo_if_addr == nullptr) {
+ return;
+ }
+ lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
+
+ eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
+ auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
+ if (eth_if_addr == nullptr) {
+ return;
+ }
+ eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
+
+ got_if_infos_ = true;
+}
+
+// Verifies that a newly instantiated UDP socket does not have the
+// broadcast socket option enabled.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+// Verifies that a newly instantiated UDP socket reports the broadcast option
+// as enabled after it is set.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ EXPECT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_BROADCAST, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound
+// to the destination port number.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ UDPBroadcastReceivedOnExpectedPort) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Enable SO_BROADCAST on the sending socket.
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Enable SO_REUSEPORT on the receiving sockets so that they may both be bound
+ // to the broadcast message's destination port.
+ ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the first socket to the ANY address and let the system assign a port.
+ auto rcv1_addr = V4Any();
+ ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ rcv1_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ // Retrieve port number from first socket so that it can be bound to the
+ // second socket.
+ socklen_t rcv_addr_sz = rcv1_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ &rcv_addr_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len);
+ auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port;
+
+ // Bind the second socket to the same address:port as the first.
+ ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ rcv_addr_sz),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the non-receiving socket to an ephemeral port.
+ auto norecv_addr = V4Any();
+ ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
+ norecv_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+
+ // Broadcast a test message.
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port;
+ constexpr char kTestMsg[] = "hello, world";
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the receiving sockets received the test message.
+ char buf[sizeof(kTestMsg)] = {};
+ EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+
+ // Verify that the non-receiving socket did not receive the test message.
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound to
+// the destination port number and either INADDR_ANY or INADDR_BROADCAST, but
+// not a unicast address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ UDPBroadcastReceivedOnExpectedAddresses) {
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ SKIP_IF(!got_if_infos_);
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Enable SO_BROADCAST on the sending socket.
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Enable SO_REUSEPORT on all sockets so that they may all be bound to the
+ // broadcast message's destination port.
+ ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(setsockopt(norcv->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the first socket to the ANY address and let the system assign a port.
+ auto rcv1_addr = V4Any();
+ ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ rcv1_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ // Retrieve port number from first socket so that it can be bound to the
+ // second socket.
+ socklen_t rcv_addr_sz = rcv1_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ &rcv_addr_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len);
+ auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port;
+
+ // Bind the second socket to the broadcast address.
+ auto rcv2_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port;
+ ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr),
+ rcv2_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the non-receiving socket to the unicast ethernet address.
+ auto norecv_addr = rcv1_addr;
+ reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
+ eth_if_addr_.sin_addr;
+ ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
+ norecv_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+
+ // Broadcast a test message.
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port;
+ constexpr char kTestMsg[] = "hello, world";
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the receiving sockets received the test message.
+ char buf[sizeof(kTestMsg)] = {};
+ EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+
+ // Verify that the non-receiving socket did not receive the test message.
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Verifies that a UDP broadcast can be sent and then received back on the same
+// socket that is bound to the broadcast address (255.255.255.255).
+// FIXME(b/141938460): This can be combined with the next test
+// (UDPBroadcastSendRecvOnSocketBoundToAny).
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ UDPBroadcastSendRecvOnSocketBoundToBroadcast) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Enable SO_BROADCAST.
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the sender to the broadcast address.
+ auto src_addr = V4Broadcast();
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr),
+ src_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ socklen_t src_sz = src_addr.addr_len;
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(src_sz, src_addr.addr_len);
+
+ // Send the message.
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port;
+ constexpr char kTestMsg[] = "hello, world";
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the message was received.
+ char buf[sizeof(kTestMsg)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+}
+
+// Verifies that a UDP broadcast can be sent and then received back on the same
+// socket that is bound to the ANY address (0.0.0.0).
+// FIXME(b/141938460): This can be combined with the previous test
+// (UDPBroadcastSendRecvOnSocketBoundToBroadcast).
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ UDPBroadcastSendRecvOnSocketBoundToAny) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Enable SO_BROADCAST.
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the sender to the ANY address.
+ auto src_addr = V4Any();
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr),
+ src_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ socklen_t src_sz = src_addr.addr_len;
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(src_sz, src_addr.addr_len);
+
+ // Send the message.
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port;
+ constexpr char kTestMsg[] = "hello, world";
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the message was received.
+ char buf[sizeof(kTestMsg)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+}
+
+// Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST
+// disabled.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Broadcast a test message without having enabled SO_BROADCAST on the sending
+ // socket.
+ auto addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(12345);
+ constexpr char kTestMsg[] = "hello, world";
+
+ EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// Verifies that a UDP unicast on an unbound socket reaches its destination.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the receiver and retrieve its address and port number.
+ sockaddr_in addr = {};
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_ANY);
+ addr.sin_port = htons(0);
+ ASSERT_THAT(bind(rcvr->get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceedsWithValue(0));
+ memset(&addr, 0, sizeof(addr));
+ socklen_t addr_sz = sizeof(addr);
+ ASSERT_THAT(getsockname(rcvr->get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addr_sz),
+ SyscallSucceedsWithValue(0));
+
+ // Send a test message to the receiver.
+ constexpr char kTestMsg[] = "hello, world";
+ ASSERT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), addr_sz),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ char buf[sizeof(kTestMsg)] = {};
+ ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+}
+
+// Check that multicast packets won't be delivered to the sending socket when
+// no interface is set and it has not joined the group.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastSelfNoGroup) {
+ // FIXME(b/125485338): A group membership is not required for external
+ // multicast on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ auto bind_addr = V4Any();
+ ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
+ bind_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t bind_addr_len = bind_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
+ &bind_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(bind_addr_len, bind_addr.addr_len);
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that multicast packets are delivered back to the sending socket once
+// it joins the group, even without setting an interface.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ auto bind_addr = V4Any();
+ ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
+ bind_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t bind_addr_len = bind_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
+ &bind_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(bind_addr_len, bind_addr.addr_len);
+
+ // Register to receive multicast packets.
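+ // Leaving ip_mreq.imr_interface as INADDR_ANY lets the kernel choose an
+ // appropriate interface for the membership (see ip(7)).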
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ ASSERT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast packets won't be delivered to the sending socket when
+// no interface is set and IP_MULTICAST_LOOP is disabled.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastSelfLoopOff) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ auto bind_addr = V4Any();
+ ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
+ bind_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t bind_addr_len = bind_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
+ &bind_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(bind_addr_len, bind_addr.addr_len);
+
+ // Disable multicast looping.
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ // Register to receive multicast packets.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(
+ RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that multicast packets won't be delivered to another socket when no
+// interface is set and the receiver has not joined the group.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
+ // FIXME(b/125485338): A group membership is not required for external
+ // multicast on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that multicast packets will be delivered to another socket without
+// setting an interface.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
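+ // Unlike ip_mreq, ip_mreqn can select the interface by index via
+ // imr_ifindex; it is left as 0 here so the kernel picks one.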
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that multicast packets won't be delivered to another socket when no
+// interface is set and IP_MULTICAST_LOOP is disabled on the sending socket.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastSenderNoLoop) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Disable multicast looping on the sender.
+ EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ EXPECT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we did not receive the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Check that multicast packets are still delivered to another socket when no
+// interface is set and IP_MULTICAST_LOOP is disabled on the receiving socket.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastReceiverNoLoop) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Disable multicast looping on the receiver.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and both will receive data on it when bound to the ANY address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastToTwoBoundToAny) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ std::unique_ptr<FileDescriptor> receivers[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())};
+
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+ for (auto& receiver : receivers) {
+ ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ // Bind to ANY to receive multicast packets.
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ EXPECT_EQ(
+ htonl(INADDR_ANY),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+
+ // Register to receive multicast packets.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ }
+
+ // Send a multicast packet to the group and verify both receivers get it.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ for (auto& receiver : receivers) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and both will receive data on it when bound to the multicast address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastToTwoBoundToMulticastAddress) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ std::unique_ptr<FileDescriptor> receivers[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())};
+
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ auto receiver_addr = V4Multicast();
+ int bound_port = 0;
+ for (auto& receiver : receivers) {
+ ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ EXPECT_EQ(
+ inet_addr(kMulticastAddress),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(
+ inet_addr(kMulticastAddress),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+
+ // Register to receive multicast packets.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ }
+
+ // Send a multicast packet to the group and verify both receivers get it.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ for (auto& receiver : receivers) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and with one bound to the wildcard address and the other bound to the
+// multicast address, both will receive data.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastToTwoBoundToAnyAndMulticastAddress) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ std::unique_ptr<FileDescriptor> receivers[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())};
+
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ // The first receiver binds to the wildcard address.
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+ for (auto& receiver : receivers) {
+ ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ // On the first iteration, save the port we are bound to and change the
+ // receiver address from V4Any to V4Multicast so the second receiver binds
+ // to that. On the second iteration, verify the port is the same as the one
+ // from the first iteration but the address is different.
+ if (bound_port == 0) {
+ EXPECT_EQ(
+ htonl(INADDR_ANY),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ receiver_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
+ bound_port;
+ } else {
+ EXPECT_EQ(
+ inet_addr(kMulticastAddress),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+
+ // Register to receive multicast packets.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ }
+
+ // Send a multicast packet to the group and verify both receivers get it.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ for (auto& receiver : receivers) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+}
+
+// Check that when receiving a looped-back multicast packet, its source address
+// is not a multicast address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ IpMulticastLoopbackFromAddr) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ int receiver_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Connect to the multicast address. This binds us to the outgoing interface
+ // and allows us to get its IP (to be compared against the src-IP on the
+ // receiver side).
+ auto sendto_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port;
+ ASSERT_THAT(RetryEINTR(connect)(
+ sender->get(), reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceeds());
+ auto sender_addr = V4EmptyAddress();
+ ASSERT_THAT(
+ getsockname(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr.addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(sizeof(struct sockaddr_in), sender_addr.addr_len);
+ sockaddr_in* sender_addr_in =
+ reinterpret_cast<sockaddr_in*>(&sender_addr.addr);
+
+ // Send a multicast packet.
+ char send_buf[4] = {};
+ ASSERT_THAT(RetryEINTR(send)(sender->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Receive a multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ auto src_addr = V4EmptyAddress();
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0,
+ reinterpret_cast<sockaddr*>(&src_addr.addr),
+ &src_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len);
+ sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr);
+
+ // Verify that the received source IP:port matches the sender one.
+ EXPECT_EQ(sender_addr_in->sin_port, src_addr_in->sin_port);
+ EXPECT_EQ(sender_addr_in->sin_addr.s_addr, src_addr_in->sin_addr.s_addr);
+}
+
+// Check that when setting the IP_MULTICAST_IF option to both an index pointing
+// to the loopback interface and an address pointing to the non-loopback
+// interface, a multicast packet sent out uses the latter as its source address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ IpMulticastLoopbackIfNicAndAddr) {
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ SKIP_IF(!got_if_infos_);
+
+ // Create receiver, bind to ANY and join the multicast group.
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ int receiver_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = lo_if_idx_;
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallSucceeds());
+
+ // Set outgoing multicast interface config, with NIC and addr pointing to
+ // different interfaces.
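+ // On Linux the index picks the egress interface while the address becomes
+ // the packet's source address, which is what this test verifies below.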
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ ip_mreqn iface = {};
+ iface.imr_ifindex = lo_if_idx_;
+ iface.imr_address = eth_if_addr_.sin_addr;
+ ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto sendto_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port;
+ char send_buf[4] = {};
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Receive a multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ auto src_addr = V4EmptyAddress();
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0,
+ reinterpret_cast<sockaddr*>(&src_addr.addr),
+ &src_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len);
+ sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr);
+
+ // FIXME (b/137781162): When sending a multicast packet use the proper logic
+ // to determine the packet's src-IP.
+ SKIP_IF(IsRunningOnGvisor());
+
+ // Verify the received source address.
+ EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr);
+}
+
+// Check that when we are bound to one interface we can set IP_MULTICAST_IF to
+// another interface.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) {
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ SKIP_IF(!got_if_infos_);
+
+ // FIXME (b/137790511): When bound to one interface it is not possible to set
+ // IP_MULTICAST_IF to a different interface.
+ SKIP_IF(IsRunningOnGvisor());
+
+ // Create sender and bind to eth interface.
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_),
+ sizeof(eth_if_addr_)),
+ SyscallSucceeds());
+
+ // Run through all possible combinations of index and address for
+ // IP_MULTICAST_IF that select the loopback interface.
+ struct {
+ int imr_ifindex;
+ struct in_addr imr_address;
+ } test_data[] = {
+ {lo_if_idx_, {}},
+ {0, lo_if_addr_.sin_addr},
+ {lo_if_idx_, lo_if_addr_.sin_addr},
+ {lo_if_idx_, eth_if_addr_.sin_addr},
+ };
+ for (auto t : test_data) {
+ ip_mreqn iface = {};
+ iface.imr_ifindex = t.imr_ifindex;
+ iface.imr_address = t.imr_address;
+ EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds())
+ << "imr_index=" << iface.imr_ifindex
+ << " imr_address=" << GetAddr4Str(&iface.imr_address);
+ }
+}
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
new file mode 100644
index 000000000..10b90b1e0
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to unbound IPv4 UDP sockets in a sandbox
+// with external networking support.
+class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
+ protected:
+ void SetUp();
+
+ IfAddrHelper if_helper_;
+
+ // got_if_infos_ is false if SetUp() could not obtain all of the interface
+ // information we need.
+ bool got_if_infos_;
+
+ // Interface infos.
+ int lo_if_idx_;
+ int eth_if_idx_;
+ sockaddr_in lo_if_addr_;
+ sockaddr_in eth_if_addr_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
new file mode 100644
index 000000000..f6e64c157
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h"
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
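+// Builds the socket kinds under test: IPv4 UDP unbound sockets created with
+// every combination of the listed flags (here, with and without
+// SOCK_NONBLOCK).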
+std::vector<SocketKind> GetSockets() {
+ return ApplyVec<SocketKind>(
+ IPv4UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(IPv4UDPUnboundSockets,
+ IPv4UDPUnboundExternalNetworkingSocketTest,
+ ::testing::ValuesIn(GetSockets()));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
new file mode 100644
index 000000000..f121c044d
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
@@ -0,0 +1,32 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+INSTANTIATE_TEST_SUITE_P(
+ IPv4UDPSockets, IPv4UDPUnboundSocketTest,
+ ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{
+ 0, SOCK_NONBLOCK}))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc
new file mode 100644
index 000000000..15d4b85a7
--- /dev/null
+++ b/test/syscalls/linux/socket_netdevice.cc
@@ -0,0 +1,184 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+#include <linux/sockios.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Tests for netdevice queries.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+TEST(NetdeviceTest, Loopback) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ // Prepare the request.
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+
+ // Check for a non-zero interface index.
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Check that the loopback device has an all-zero hardware address.
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds());
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0);
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0);
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0);
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_data[3], 0);
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_data[4], 0);
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_data[5], 0);
+}
+
+TEST(NetdeviceTest, Netmask) {
+ // We need an interface index to identify the loopback device.
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Use a netlink socket to get the netmask, which we'll then compare to the
+ // netmask obtained via ioctl.
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
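+ // NLM_F_REQUEST | NLM_F_DUMP requests the entire address table; the kernel
+ // replies with a multi-part message terminated by NLMSG_DONE.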
+ // Iterate through messages until we find the one containing the prefix length
+ // (i.e. netmask) for the loopback device.
+ int prefixlen = -1;
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE)));
+
+ EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI)
+ << std::hex << hdr->nlmsg_flags;
+
+ EXPECT_EQ(hdr->nlmsg_seq, kSeq);
+ EXPECT_EQ(hdr->nlmsg_pid, port);
+
+ if (hdr->nlmsg_type != RTM_NEWADDR) {
+ return;
+ }
+
+ // RTM_NEWADDR contains at least the header and ifaddrmsg.
+ EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg));
+
+ struct ifaddrmsg* ifaddrmsg =
+ reinterpret_cast<struct ifaddrmsg*>(NLMSG_DATA(hdr));
+ if (ifaddrmsg->ifa_index == static_cast<uint32_t>(ifr.ifr_ifindex) &&
+ ifaddrmsg->ifa_family == AF_INET) {
+ prefixlen = ifaddrmsg->ifa_prefixlen;
+ }
+ },
+ false));
+
+ ASSERT_GE(prefixlen, 0);
+
+ // Netmask is stored big endian in struct sockaddr_in, so we do the same for
+ // comparison.
+ uint32_t mask = 0xffffffff << (32 - prefixlen);
+ mask = absl::gbswap_32(mask);
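+ // For example, the typical loopback /8 has prefixlen == 8, giving a mask of
+ // 0xff000000, i.e. 255.0.0.0 once swapped into network byte order.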
+
+ // Check that the loopback interface has the correct subnet mask.
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFNETMASK, &ifr), SyscallSucceeds());
+ EXPECT_EQ(ifr.ifr_netmask.sa_family, AF_INET);
+ struct sockaddr_in* sin =
+ reinterpret_cast<struct sockaddr_in*>(&ifr.ifr_netmask);
+ EXPECT_EQ(sin->sin_addr.s_addr, mask);
+}
+
+TEST(NetdeviceTest, InterfaceName) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ // Prepare the request.
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+
+ // Check for a non-zero interface index.
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Check that SIOCGIFNAME finds the loopback interface.
+ snprintf(ifr.ifr_name, IFNAMSIZ, "foo");
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFNAME, &ifr), SyscallSucceeds());
+ EXPECT_STREQ(ifr.ifr_name, "lo");
+}
+
+TEST(NetdeviceTest, InterfaceFlags) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ // Prepare the request.
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+
+ // Check that SIOCGIFFLAGS reports the loopback interface as IFF_UP and
+ // IFF_RUNNING.
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFFLAGS, &ifr), SyscallSucceeds());
+ EXPECT_EQ(ifr.ifr_flags & IFF_UP, IFF_UP);
+ EXPECT_EQ(ifr.ifr_flags & IFF_RUNNING, IFF_RUNNING);
+}
+
+TEST(NetdeviceTest, InterfaceMTU) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ // Prepare the request.
+ struct ifreq ifr = {};
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+
+ // Check that SIOCGIFMTU returns a nonzero MTU.
+ ASSERT_THAT(ioctl(sock.get(), SIOCGIFMTU, &ifr), SyscallSucceeds());
+ EXPECT_GT(ifr.ifr_mtu, 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink.cc b/test/syscalls/linux/socket_netlink.cc
new file mode 100644
index 000000000..4ec0fd4fa
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink.cc
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/netlink.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Tests for all netlink socket protocols.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// NetlinkTest parameter is the protocol to test.
+using NetlinkTest = ::testing::TestWithParam<int>;
+
+// Netlink sockets must be SOCK_DGRAM or SOCK_RAW.
+TEST_P(NetlinkTest, Types) {
+ const int protocol = GetParam();
+
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+
+ int fd;
+ EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, protocol), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, protocol), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_P(NetlinkTest, AutomaticPort) {
+ const int protocol = GetParam();
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ addr.nl_family = AF_NETLINK;
+
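+ // Binding with nl_pid == 0 asks the kernel to assign a unique port ID
+ // (see netlink(7)).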
+ EXPECT_THAT(
+ bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, sizeof(addr));
+ // This is the only netlink socket in the process, so it should get the PID as
+ // the port id.
+ //
+ // N.B. Another process could theoretically have explicitly reserved our pid
+ // as a port ID, but that is very unlikely.
+ EXPECT_EQ(addr.nl_pid, getpid());
+}
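+
+// An editorial sketch (not part of the original change): binding with an
+// explicit, nonzero nl_pid reserves that port and bypasses the automatic
+// assignment exercised above. The port value here is purely illustrative.
+//
+//   struct sockaddr_nl addr = {};
+//   addr.nl_family = AF_NETLINK;
+//   addr.nl_pid = 10000;  // Hypothetical explicit port ID.
+//   bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr));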
+
+// Calling connect automatically binds to an automatic port.
+TEST_P(NetlinkTest, ConnectBinds) {
+ const int protocol = GetParam();
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ addr.nl_family = AF_NETLINK;
+
+ EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, sizeof(addr));
+
+ // Each test runs in its own PID namespace, so another process may have
+ // explicitly reserved our PID as a port ID. In that case, a negative port ID
+ // value will have been assigned.
+ if (static_cast<pid_t>(addr.nl_pid) > 0) {
+ EXPECT_EQ(addr.nl_pid, getpid());
+ }
+
+ memset(&addr, 0, sizeof(addr));
+ addr.nl_family = AF_NETLINK;
+
+ // Connecting again is allowed, but keeps the same port.
+ EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, sizeof(addr));
+ EXPECT_EQ(addr.nl_pid, getpid());
+}
+
+TEST_P(NetlinkTest, GetPeerName) {
+ const int protocol = GetParam();
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ socklen_t addrlen = sizeof(addr);
+
+ EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, sizeof(addr));
+ EXPECT_EQ(addr.nl_family, AF_NETLINK);
+ // Peer is the kernel if we didn't connect elsewhere.
+ EXPECT_EQ(addr.nl_pid, 0);
+}
+
+INSTANTIATE_TEST_SUITE_P(ProtocolTest, NetlinkTest,
+ ::testing::Values(NETLINK_ROUTE,
+ NETLINK_KOBJECT_UEVENT));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
new file mode 100644
index 000000000..e6647a1c3
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -0,0 +1,935 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <ifaddrs.h>
+#include <linux/if.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_format.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Tests for NETLINK_ROUTE sockets.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr uint32_t kSeq = 12345;
+
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+// Parameters for SockOptTest. They are:
+// 0: Socket option to query.
+// 1: A predicate to run on the returned sockopt value. Should return true if
+// the value is considered ok.
+// 2: A description of what the sockopt value is expected to be. Should complete
+// the sentence "<value> was unexpected, expected <description>"
+using SockOptTest = ::testing::TestWithParam<
+ std::tuple<int, std::function<bool(int)>, std::string>>;
+
+TEST_P(SockOptTest, GetSockOpt) {
+ int sockopt = std::get<0>(GetParam());
+ auto verifier = std::get<1>(GetParam());
+ std::string verifier_description = std::get<2>(GetParam());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE));
+
+ int res;
+ socklen_t len = sizeof(res);
+
+ EXPECT_THAT(getsockopt(fd.get(), SOL_SOCKET, sockopt, &res, &len),
+ SyscallSucceeds());
+
+ EXPECT_EQ(len, sizeof(res));
+ EXPECT_TRUE(verifier(res)) << absl::StrFormat(
+ "getsockopt(%d, SOL_SOCKET, %d, &res, &len) => res=%d was unexpected, "
+ "expected %s",
+ fd.get(), sockopt, res, verifier_description);
+}
+
+std::function<bool(int)> IsPositive() {
+ return [](int val) { return val > 0; };
+}
+
+std::function<bool(int)> IsEqual(int target) {
+ return [target](int val) { return val == target; };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NetlinkRouteTest, SockOptTest,
+ ::testing::Values(
+ std::make_tuple(SO_SNDBUF, IsPositive(), "positive send buffer size"),
+ std::make_tuple(SO_RCVBUF, IsPositive(),
+ "positive receive buffer size"),
+ std::make_tuple(SO_TYPE, IsEqual(SOCK_RAW),
+ absl::StrFormat("SOCK_RAW (%d)", SOCK_RAW)),
+ std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK),
+ absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)),
+ std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE),
+ absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)),
+ std::make_tuple(SO_PASSCRED, IsEqual(0), "0")));
+
+// Validates the responses to RTM_GETLINK + NLM_F_DUMP.
+void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
+ EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWLINK), Eq(NLMSG_DONE)));
+
+ EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI)
+ << std::hex << hdr->nlmsg_flags;
+
+ EXPECT_EQ(hdr->nlmsg_seq, seq);
+ EXPECT_EQ(hdr->nlmsg_pid, port);
+
+ if (hdr->nlmsg_type != RTM_NEWLINK) {
+ return;
+ }
+
+ // RTM_NEWLINK contains at least the header and ifinfomsg.
+ EXPECT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+
+ // TODO(mpratt): Check ifinfomsg contents and following attrs.
+}
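+
+// A possible shape for the attribute check left as a TODO above (an editorial
+// sketch, not part of the original change): walk the rtattrs that follow the
+// ifinfomsg with the standard RTA_* macros, as FindRtAttr does.
+//
+//   const struct ifinfomsg* msg =
+//       reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+//   int attrlen = hdr->nlmsg_len - NLMSG_SPACE(sizeof(*msg));
+//   const struct rtattr* rta = IFLA_RTA(msg);
+//   for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
+//     if (rta->rta_type == IFLA_IFNAME) {
+//       // RTA_DATA(rta) is the NUL-terminated interface name.
+//     }
+//   }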
+
+TEST(NetlinkRouteTest, GetLinkDump) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ // Loopback is common among all tests, check that it's found.
+ bool loopbackFound = false;
+ ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ CheckGetLinkResponse(hdr, kSeq, port);
+ if (hdr->nlmsg_type != RTM_NEWLINK) {
+ return;
+ }
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ std::cout << "Found interface idx=" << msg->ifi_index
+ << ", type=" << std::hex << msg->ifi_type << std::endl;
+ if (msg->ifi_type == ARPHRD_LOOPBACK) {
+ loopbackFound = true;
+ EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
+ }
+ }));
+ EXPECT_TRUE(loopbackFound);
+}
+
+// CheckLinkMsg checks a netlink message against an expected link.
+void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
+ ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK));
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ EXPECT_EQ(msg->ifi_index, link.index);
+
+ const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message.";
+ if (rta != nullptr) {
+ std::string name(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ EXPECT_EQ(name, link.name);
+ }
+}
+
+TEST(NetlinkRouteTest, GetLinkByIndex) {
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = loopback_link.index;
+
+ bool found = false;
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ CheckLinkMsg(hdr, loopback_link);
+ found = true;
+ },
+ false));
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
+}
+
+TEST(NetlinkRouteTest, GetLinkByName) {
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
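+ // Note (editorial): the pad field reserves slack for the NLMSG_ALIGN and
+ // RTA_ALIGN rounding used when nlmsg_len is computed below, so the aligned
+ // request always fits within the struct.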
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1);
+ strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ bool found = false;
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ CheckLinkMsg(hdr, loopback_link);
+ found = true;
+ },
+ false));
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
+}
+
+TEST(NetlinkRouteTest, GetLinkByIndexNotFound) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = 1234590;
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, GetLinkByNameNotFound) {
+ const std::string name = "nodevice?!";
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(name.size() + 1);
+ strncpy(req.ifname, name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ // If (type & 0x3) is equal to 0x2, the message is a get request, which does
+ // not require CAP_SYS_ADMIN.
+ req.hdr.nlmsg_type = ((__RTM_MAX + 1024) & (~0x3)) | 0x2;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(EOPNOTSUPP, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // Small enough to ensure that the response doesn't fit.
+ constexpr size_t kBufferSize = 10;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0),
+ SyscallSucceedsWithValue(kBufferSize));
+ EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC);
+}
+
+TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // Small enough to ensure that the response doesn't fit.
+ constexpr size_t kBufferSize = 10;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
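+ // Unlike MsgHdrMsgTrunc above, MSG_TRUNC is passed to recvmsg as a flag
+ // here, so the return value is the full message length rather than the
+ // number of bytes copied into the undersized buffer.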
+ int res = 0;
+ ASSERT_THAT(res = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC),
+ SyscallSucceeds());
+ EXPECT_GT(res, kBufferSize);
+ EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC);
+}
+
+TEST(NetlinkRouteTest, ControlMessageIgnored) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ struct request {
+ struct nlmsghdr control_hdr;
+ struct nlmsghdr message_hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+
+ // This control message is ignored. We still receive a response for the
+ // following RTM_GETLINK.
+ req.control_hdr.nlmsg_len = sizeof(req.control_hdr);
+ req.control_hdr.nlmsg_type = NLMSG_DONE;
+ req.control_hdr.nlmsg_seq = kSeq;
+
+ req.message_hdr.nlmsg_len = sizeof(req.message_hdr) + sizeof(req.ifm);
+ req.message_hdr.nlmsg_type = RTM_GETLINK;
+ req.message_hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.message_hdr.nlmsg_seq = kSeq;
+
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ CheckGetLinkResponse(hdr, kSeq, port);
+ },
+ false));
+}
+
+TEST(NetlinkRouteTest, GetAddrDump) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE)));
+
+ EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI)
+ << std::hex << hdr->nlmsg_flags;
+
+ EXPECT_EQ(hdr->nlmsg_seq, kSeq);
+ EXPECT_EQ(hdr->nlmsg_pid, port);
+
+ if (hdr->nlmsg_type != RTM_NEWADDR) {
+ return;
+ }
+
+ // RTM_NEWADDR contains at least the header and ifaddrmsg.
+ EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg));
+
+ // TODO(mpratt): Check ifaddrmsg contents and following attrs.
+ },
+ false));
+}
+
+TEST(NetlinkRouteTest, LookupAll) {
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ // Not a syscall but we can use the syscall matcher as glibc sets errno.
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ int count = 0;
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (!i->ifa_addr || (i->ifa_addr->sa_family != AF_INET &&
+ i->ifa_addr->sa_family != AF_INET6)) {
+ continue;
+ }
+ count++;
+ }
+ ASSERT_GT(count, 0);
+}
+
+TEST(NetlinkRouteTest, AddAddr) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifa;
+ struct rtattr rtattr;
+ struct in_addr addr;
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifa.ifa_family = AF_INET;
+ req.ifa.ifa_prefixlen = 24;
+ req.ifa.ifa_flags = 0;
+ req.ifa.ifa_scope = 0;
+ req.ifa.ifa_index = loopback_link.index;
+ req.rtattr.rta_type = IFA_LOCAL;
+ req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
+ inet_pton(AF_INET, "10.0.0.1", &req.addr);
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ // Create should succeed, as no such address exists in the kernel yet.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK;
+ EXPECT_NO_ERRNO(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+
+ // Replacing the existing address should succeed.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
+ req.hdr.nlmsg_seq++;
+ EXPECT_NO_ERRNO(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+
+ // An exclusive create should fail, as the address was created above.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
+ req.hdr.nlmsg_seq++;
+ EXPECT_THAT(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len),
+ PosixErrorIs(EEXIST, ::testing::_));
+}
+
+// GetRouteDump tests an RTM_GETROUTE + NLM_F_DUMP request.
+TEST(NetlinkRouteTest, GetRouteDump) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtmsg rtm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETROUTE;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rtm.rtm_family = AF_UNSPEC;
+
+ bool routeFound = false;
+ bool dstFound = true;
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ // Validate the response to RTM_GETROUTE + NLM_F_DUMP.
+ EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWROUTE), Eq(NLMSG_DONE)));
+
+ EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI)
+ << std::hex << hdr->nlmsg_flags;
+
+ EXPECT_EQ(hdr->nlmsg_seq, kSeq);
+ EXPECT_EQ(hdr->nlmsg_pid, port);
+
+ // The test should not proceed if it's not an RTM_NEWROUTE message.
+ if (hdr->nlmsg_type != RTM_NEWROUTE) {
+ return;
+ }
+
+ // RTM_NEWROUTE contains at least the header and rtmsg.
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg)));
+ const struct rtmsg* msg =
+ reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr));
+ // NOTE: rtmsg fields are char fields.
+ std::cout << "Found route table=" << static_cast<int>(msg->rtm_table)
+ << ", protocol=" << static_cast<int>(msg->rtm_protocol)
+ << ", scope=" << static_cast<int>(msg->rtm_scope)
+ << ", type=" << static_cast<int>(msg->rtm_type);
+
+ int len = RTM_PAYLOAD(hdr);
+ bool rtDstFound = false;
+ for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len);
+ attr = RTA_NEXT(attr, len)) {
+ if (attr->rta_type == RTA_DST) {
+ char address[INET_ADDRSTRLEN] = {};
+ inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address));
+ std::cout << ", dst=" << address;
+ rtDstFound = true;
+ }
+ }
+
+ std::cout << std::endl;
+
+ if (msg->rtm_table == RT_TABLE_MAIN) {
+ routeFound = true;
+ dstFound = rtDstFound && dstFound;
+ }
+ },
+ false));
+ // At least one route found in main route table.
+ EXPECT_TRUE(routeFound);
+ // Found RTA_DST for each route in main table.
+ EXPECT_TRUE(dstFound);
+}
+
+// GetRouteRequest tests RTM_GETROUTE with the RTM_F_LOOKUP_TABLE flag.
+TEST(NetlinkRouteTest, GetRouteRequest) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtmsg rtm;
+ struct nlattr nla;
+ struct in_addr sin_addr;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETROUTE;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+
+ req.rtm.rtm_family = AF_INET;
+ req.rtm.rtm_dst_len = 32;
+ req.rtm.rtm_src_len = 0;
+ req.rtm.rtm_tos = 0;
+ req.rtm.rtm_table = RT_TABLE_UNSPEC;
+ req.rtm.rtm_protocol = RTPROT_UNSPEC;
+ req.rtm.rtm_scope = RT_SCOPE_UNIVERSE;
+ req.rtm.rtm_type = RTN_UNSPEC;
+ req.rtm.rtm_flags = RTM_F_LOOKUP_TABLE;
+
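+ // 8 == NLA_HDRLEN (4) + sizeof(struct in_addr) (4): the attribute header
+ // plus the IPv4 destination address to look up.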
+ req.nla.nla_len = 8;
+ req.nla.nla_type = RTA_DST;
+ inet_aton("127.0.0.2", &req.sin_addr);
+
+ bool rtDstFound = false;
+ ASSERT_NO_ERRNO(NetlinkRequestResponseSingle(
+ fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) {
+ // Validate the response to the RTM_GETROUTE request with the
+ // RTM_F_LOOKUP_TABLE flag.
+ EXPECT_THAT(hdr->nlmsg_type, RTM_NEWROUTE);
+
+ EXPECT_TRUE(hdr->nlmsg_flags == 0) << std::hex << hdr->nlmsg_flags;
+
+ EXPECT_EQ(hdr->nlmsg_seq, kSeq);
+ EXPECT_EQ(hdr->nlmsg_pid, port);
+
+ // RTM_NEWROUTE contains at least the header and rtmsg.
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg)));
+ const struct rtmsg* msg =
+ reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr));
+
+ // NOTE: rtmsg fields are char fields.
+ std::cout << "Found route table=" << static_cast<int>(msg->rtm_table)
+ << ", protocol=" << static_cast<int>(msg->rtm_protocol)
+ << ", scope=" << static_cast<int>(msg->rtm_scope)
+ << ", type=" << static_cast<int>(msg->rtm_type);
+
+ EXPECT_EQ(msg->rtm_family, AF_INET);
+ EXPECT_EQ(msg->rtm_dst_len, 32);
+ EXPECT_TRUE((msg->rtm_flags & RTM_F_CLONED) == RTM_F_CLONED)
+ << std::hex << msg->rtm_flags;
+
+ int len = RTM_PAYLOAD(hdr);
+ std::cout << ", len=" << len;
+ for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len);
+ attr = RTA_NEXT(attr, len)) {
+ if (attr->rta_type == RTA_DST) {
+ char address[INET_ADDRSTRLEN] = {};
+ inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address));
+ std::cout << ", dst=" << address;
+ rtDstFound = true;
+ } else if (attr->rta_type == RTA_OIF) {
+ const char* oif = reinterpret_cast<const char*>(RTA_DATA(attr));
+ std::cout << ", oif=" << oif;
+ }
+ }
+
+ std::cout << std::endl;
+ }));
+ // Found RTA_DST for RTM_F_LOOKUP_TABLE.
+ EXPECT_TRUE(rtDstFound);
+}
+
+// RecvmsgTrunc tests the recvmsg MSG_TRUNC flag with zero length output
+// buffer. MSG_TRUNC with a zero length buffer should consume subsequent
+// messages off the socket.
+TEST(NetlinkRouteTest, RecvmsgTrunc) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ int trunclen, trunclen2;
+
+ // Note: This test assumes at least two messages are returned by the
+ // RTM_GETADDR request. That means at least one RTM_NEWADDR message and one
+ // NLMSG_DONE message. We cannot read all the messages without blocking
+ // because we would need to read the message into a buffer and check the
+ // nlmsg_type for NLMSG_DONE. However, the test depends on reading into a
+ // zero-length buffer.
+
+ // First, call recvmsg with MSG_TRUNC. This will read the full message from
+ // the socket and return its full length. Subsequent calls to recvmsg will
+ // read the next messages from the socket.
+ ASSERT_THAT(trunclen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC),
+ SyscallSucceeds());
+
+ // The message should always be truncated. However, even though the
+ // destination iov is zero length, MSG_TRUNC makes recvmsg return the size of
+ // the next message, so the return value should not be zero.
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+ ASSERT_NE(trunclen, 0);
+ // Returned length is at least the header and ifaddrmsg.
+ EXPECT_GE(trunclen, sizeof(struct nlmsghdr) + sizeof(struct ifaddrmsg));
+
+ // Reset the msg_flags to make sure that the recvmsg call is setting them
+ // properly.
+ msg.msg_flags = 0;
+
+ // Make a second recvmsg call to get the next message.
+ ASSERT_THAT(trunclen2 = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC),
+ SyscallSucceeds());
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+ ASSERT_NE(trunclen2, 0);
+
+ // Assert that the received messages are not the same.
+ //
+ // We are calling recvmsg with a zero length buffer so we have no way to
+ // inspect the messages to make sure they are not equal in value. The best
+ // we can do is to compare their lengths.
+ ASSERT_NE(trunclen, trunclen2);
+}
+
+// RecvmsgTruncPeek tests recvmsg with the combination of the MSG_TRUNC and
+// MSG_PEEK flags and a zero length output buffer. This is normally used to
+// read the full length of the next message on the socket without consuming
+// it, so a properly sized buffer can be allocated to store the message. This
+// test exercises that scenario.
+TEST(NetlinkRouteTest, RecvmsgTruncPeek) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ int type = -1;
+ do {
+ int peeklen;
+ int len;
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ // Call recvmsg with MSG_PEEK and MSG_TRUNC. This will peek at the message
+ // and return its full length.
+ // See: MSG_TRUNC http://man7.org/linux/man-pages/man2/recv.2.html
+ ASSERT_THAT(
+ peeklen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_PEEK | MSG_TRUNC),
+ SyscallSucceeds());
+
+ // Message should always be truncated.
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+ ASSERT_NE(peeklen, 0);
+
+ // Reset the message flags for the next call.
+ msg.msg_flags = 0;
+
+ // Make the actual call to recvmsg to get the actual data. We will use
+ // the length returned from the peek call for the allocated buffer size.
+ std::vector<char> buf(peeklen);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+ ASSERT_THAT(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0),
+ SyscallSucceeds());
+
+ // Message should not be truncated since we allocated the correct buffer
+ // size.
+ EXPECT_NE(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+
+ // MSG_PEEK should have left the data on the socket, and the subsequent call
+ // should have retrieved the same data. Both calls should have returned the
+ // message's full length, so the two lengths should be equal.
+ ASSERT_NE(len, 0);
+ ASSERT_EQ(peeklen, len);
+
+ for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
+ NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
+ type = hdr->nlmsg_type;
+ }
+ } while (type != NLMSG_DONE && type != NLMSG_ERROR);
+}
+
+// No SCM_CREDENTIALS are received without SO_PASSCRED set.
+TEST(NetlinkRouteTest, NoPasscredNoCreds) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ char control[CMSG_SPACE(sizeof(struct ucred))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ // Note: This test assumes at least one message is returned by the
+ // RTM_GETADDR request.
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // No control messages.
+ EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr);
+}
+
+// SCM_CREDENTIALS are received with SO_PASSCRED set.
+TEST(NetlinkRouteTest, PasscredCreds) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ char control[CMSG_SPACE(sizeof(struct ucred))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ // Note: This test assumes at least one message is returned by the
+ // RTM_GETADDR request.
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ struct ucred creds;
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+
+ memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds));
+
+ // The peer is the kernel, which is "PID" 0.
+ EXPECT_EQ(creds.pid, 0);
+ // The kernel identifies as root. Also allow nobody in case this test is
+ // running in a userns without root mapped.
+ EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534)));
+ EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534)));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
new file mode 100644
index 000000000..bde1dbb4d
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -0,0 +1,162 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+
+#include <linux/if.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr uint32_t kSeq = 12345;
+
+} // namespace
+
+PosixError DumpLinks(
+ const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = seq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
+}
+
+PosixErrorOr<std::vector<Link>> DumpLinks() {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ std::vector<Link> links;
+ RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ if (hdr->nlmsg_type != RTM_NEWLINK ||
+ hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
+ return;
+ }
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ if (rta == nullptr) {
+ // Ignore links that do not have a name.
+ return;
+ }
+
+ links.emplace_back();
+ links.back().index = msg->ifi_index;
+ links.back().type = msg->ifi_type;
+ links.back().name =
+ std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ }));
+ return links;
+}
+
+PosixErrorOr<Link> LoopbackLink() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.type == ARPHRD_LOOPBACK) {
+ return link;
+ }
+ }
+ return PosixError(ENOENT, "loopback link not found");
+}
+
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifaddr;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifaddr.ifa_index = index;
+ req.ifaddr.ifa_family = family;
+ req.ifaddr.ifa_prefixlen = prefixlen;
+
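+ // Append an IFA_LOCAL attribute immediately after the aligned ifaddrmsg,
+ // then extend nlmsg_len to cover it.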
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFA_LOCAL;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char pad[NLMSG_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+ req.ifinfo.ifi_flags = flags;
+ req.ifinfo.ifi_change = change;
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFLA_ADDRESS;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
new file mode 100644
index 000000000..149c4a7f6
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+
+struct Link {
+ int index;
+ int16_t type;
+ std::string name;
+};
+
+PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+PosixErrorOr<std::vector<Link>> DumpLinks();
+
+// Returns the loopback link on the system, or ENOENT if not found.
+PosixErrorOr<Link> LoopbackLink();
+
+// LinkAddLocalAddr adds a local (IFA_LOCAL) address to the interface.
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkChangeFlags changes interface flags, e.g. IFF_UP.
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change);
+
+// LinkSetMacAddr sets the IFLA_ADDRESS attribute of the interface.
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen);
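+
+// A hypothetical usage sketch of these helpers (editorial illustration, not
+// part of the original change): add 10.0.0.2/24 to loopback and bring the
+// link up.
+//
+//   ASSIGN_OR_RETURN_ERRNO(Link lo, LoopbackLink());
+//   struct in_addr addr;
+//   inet_pton(AF_INET, "10.0.0.2", &addr);
+//   RETURN_IF_ERRNO(LinkAddLocalAddr(lo.index, AF_INET, /*prefixlen=*/24,
+//                                    &addr, sizeof(addr)));
+//   RETURN_IF_ERRNO(LinkChangeFlags(lo.index, IFF_UP, IFF_UP));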
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
diff --git a/test/syscalls/linux/socket_netlink_uevent.cc b/test/syscalls/linux/socket_netlink_uevent.cc
new file mode 100644
index 000000000..da425bed4
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_uevent.cc
@@ -0,0 +1,83 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/filter.h>
+#include <linux/netlink.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Tests for NETLINK_KOBJECT_UEVENT sockets.
+//
+// gVisor never sends any messages on these sockets, so we don't test the events
+// themselves.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// SO_PASSCRED can be enabled. Since no messages are sent in gVisor, we don't
+// actually test receiving credentials.
+TEST(NetlinkUeventTest, PassCred) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT));
+
+ EXPECT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+}
+
+// SO_DETACH_FILTER fails without a filter already installed.
+TEST(NetlinkUeventTest, DetachNoFilter) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT));
+
+ int opt;
+ EXPECT_THAT(
+ setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+// We can attach a BPF filter.
+TEST(NetlinkUeventTest, AttachFilter) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT));
+
+ // Minimal BPF program: a single ret.
+ struct sock_filter filter = {0x6, 0, 0, 0};
+ struct sock_fprog prog = {};
+ prog.len = 1;
+ prog.filter = &filter;
+
+ EXPECT_THAT(
+ setsockopt(fd.get(), SOL_SOCKET, SO_ATTACH_FILTER, &prog, sizeof(prog)),
+ SyscallSucceeds());
+
+ int opt;
+ EXPECT_THAT(
+ setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)),
+ SyscallSucceeds());
+}
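+
+// For reference (editorial note): the raw initializer above is the single
+// BPF "return 0" instruction. With the macros from <linux/filter.h> it could
+// equivalently be written as:
+//
+//   struct sock_filter filter = BPF_STMT(BPF_RET | BPF_K, 0);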
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc
new file mode 100644
index 000000000..952eecfe8
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_util.cc
@@ -0,0 +1,187 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+#include <linux/if_arp.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+#include <sys/socket.h>
+
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) {
+ FileDescriptor fd;
+ ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ addr.nl_family = AF_NETLINK;
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ MaybeSave();
+
+ return std::move(fd);
+}
+
+PosixErrorOr<uint32_t> NetlinkPortID(int fd) {
+ struct sockaddr_nl addr;
+ socklen_t addrlen = sizeof(addr);
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen));
+ MaybeSave();
+
+ return static_cast<uint32_t>(addr.nl_pid);
+}
+
+PosixError NetlinkRequestResponse(
+ const FileDescriptor& fd, void* request, size_t len,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn,
+ bool expect_nlmsgerr) {
+ struct iovec iov = {};
+ iov.iov_base = request;
+ iov.iov_len = len;
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));
+
+ constexpr size_t kBufferSize = 4096;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ // If NLM_F_MULTI is set, response is a series of messages that ends with a
+ // NLMSG_DONE message.
+ int type = -1;
+ int flags = 0;
+ do {
+ int len;
+ RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
+
+ // We don't bother with the complexity of dealing with truncated messages.
+ // We must allocate a large enough buffer up front.
+ if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) {
+ return PosixError(EIO,
+ absl::StrCat("Received truncated message with flags: ",
+ msg.msg_flags));
+ }
+
+ for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
+ NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
+ fn(hdr);
+ flags = hdr->nlmsg_flags;
+ type = hdr->nlmsg_type;
+ // Done should include an integer payload for dump_done_errno.
+ // See net/netlink/af_netlink.c:netlink_dump
+ // Some tools like the 'ip' tool check the minimum length of the
+ // NLMSG_DONE message.
+ if (type == NLMSG_DONE) {
+ EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int)));
+ }
+ }
+ } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR);
+
+ if (expect_nlmsgerr) {
+ EXPECT_EQ(type, NLMSG_ERROR);
+ } else if (flags & NLM_F_MULTI) {
+ EXPECT_EQ(type, NLMSG_DONE);
+ }
+ return NoError();
+}
+
+PosixError NetlinkRequestResponseSingle(
+ const FileDescriptor& fd, void* request, size_t len,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct iovec iov = {};
+ iov.iov_base = request;
+ iov.iov_len = len;
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));
+
+ constexpr size_t kBufferSize = 4096;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ int ret;
+ RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
+
+ // We don't bother with the complexity of dealing with truncated messages.
+ // We must allocate a large enough buffer up front.
+ if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) {
+ return PosixError(
+ EIO,
+ absl::StrCat("Received truncated message with flags: ", msg.msg_flags));
+ }
+
+ for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
+ NLMSG_OK(hdr, ret); hdr = NLMSG_NEXT(hdr, ret)) {
+ fn(hdr);
+ }
+
+ return NoError();
+}
+
+PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
+ void* request, size_t len) {
+ // Sentinel negative value meaning no error message was received. Netlink
+ // will never return a negative error number, so there is no ambiguity.
+ int err = -42;
+ RETURN_IF_ERRNO(NetlinkRequestResponse(
+ fd, request, len,
+ [&](const struct nlmsghdr* hdr) {
+ EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type);
+ EXPECT_EQ(hdr->nlmsg_seq, seq);
+ EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
+
+ const struct nlmsgerr* msg =
+ reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
+ err = -msg->error;
+ },
+ true));
+ return PosixError(err);
+}
+
+const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
+ const struct ifinfomsg* msg, int16_t attr) {
+ const int ifi_space = NLMSG_SPACE(sizeof(*msg));
+ int attrlen = hdr->nlmsg_len - ifi_space;
+ const struct rtattr* rta = reinterpret_cast<const struct rtattr*>(
+ reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space));
+ for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
+ if (rta->rta_type == attr) {
+ return rta;
+ }
+ }
+ return nullptr;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h
new file mode 100644
index 000000000..e13ead406
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_util.h
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_
+
+#include <sys/socket.h>
+// socket.h has to be included before if_arp.h.
+#include <linux/if_arp.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns a bound netlink socket.
+PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol);
+
+// Returns the port ID of the passed socket.
+PosixErrorOr<uint32_t> NetlinkPortID(int fd);
+
+// Send the passed request and call fn on all response netlink messages.
+//
+// To be used on requests with NLM_F_MULTI responses.
+PosixError NetlinkRequestResponse(
+ const FileDescriptor& fd, void* request, size_t len,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn,
+ bool expect_nlmsgerr);
+
+// Send the passed request and call fn on all response netlink messages.
+//
+// To be used on requests without NLM_F_MULTI responses.
+PosixError NetlinkRequestResponseSingle(
+ const FileDescriptor& fd, void* request, size_t len,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+// Send the passed request then expect and return an ack or error.
+PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
+ void* request, size_t len);
+
+// Find rtnetlink attribute in message.
+const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
+ const struct ifinfomsg* msg, int16_t attr);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_
diff --git a/test/syscalls/linux/socket_non_blocking.cc b/test/syscalls/linux/socket_non_blocking.cc
new file mode 100644
index 000000000..c3520cadd
--- /dev/null
+++ b/test/syscalls/linux/socket_non_blocking.cc
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_non_blocking.h"
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+TEST_P(NonBlockingSocketPairTest, ReadNothingAvailable) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[20] = {};
+ ASSERT_THAT(ReadFd(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(NonBlockingSocketPairTest, RecvNothingAvailable) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[20] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(NonBlockingSocketPairTest, RecvMsgNothingAvailable) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct iovec iov;
+ char buf[20] = {};
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_blocking.h b/test/syscalls/linux/socket_non_blocking.h
new file mode 100644
index 000000000..bd3e02fd2
--- /dev/null
+++ b/test/syscalls/linux/socket_non_blocking.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_BLOCKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_BLOCKING_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected non-blocking sockets.
+using NonBlockingSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_BLOCKING_H_
diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc
new file mode 100644
index 000000000..c61817f14
--- /dev/null
+++ b/test/syscalls/linux/socket_non_stream.cc
@@ -0,0 +1,337 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_non_stream.h"
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(NonStreamSocketPairTest, SendMsgTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int sndbuf;
+ socklen_t length = sizeof(sndbuf);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, &length),
+ SyscallSucceeds());
+
+ // Make the call too large to fit in the send buffer.
+ const int buffer_size = 3 * sndbuf;
+
+ EXPECT_THAT(SendLargeSendMsg(sockets, buffer_size, false /* reader */),
+ SyscallFailsWithErrno(EMSGSIZE));
+}
+
+// Stream sockets allow data sent with a single (e.g. write, sendmsg) syscall
+// to be read in pieces with multiple (e.g. read, recvmsg) syscalls.
+//
+// SplitRecv checks that on a datagram socket a message can only be read with
+// a single (e.g. read, recvmsg) syscall: reading less than the full message
+// discards the remainder, and a second read finds no data.
+TEST_P(NonStreamSocketPairTest, SplitRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data) / 2];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data)));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+// Stream sockets allow data sent with multiple sends to be read in a single
+// recv. Datagram sockets do not.
+//
+// SingleRecv checks that only a single message is readable in a single recv.
+TEST_P(NonStreamSocketPairTest, SingleRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data1, sizeof(sent_data1), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data2, sizeof(sent_data2), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+}
+
+TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
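+ // Initialize msg_flags to a garbage value so the test can observe that
+ // recvmsg() overwrites it.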
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
+
+ // Check that msghdr flags were updated.
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+}
+
+// Stream sockets allow data sent with multiple sends to be peeked at in a
+// single recv. Datagram sockets (except for Unix domain sockets) do not.
+//
+// SinglePeek checks that only a single message is peekable in a single recv.
+TEST_P(NonStreamSocketPairTest, SinglePeek) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data1, sizeof(sent_data1), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data2, sizeof(sent_data2), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ for (int i = 0; i < 3; i++) {
+ memset(received_data, 0, sizeof(received_data));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ }
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(sent_data1), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(sent_data2), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(sent_data2)));
+}
+
+TEST_P(NonStreamSocketPairTest, MsgTruncTruncation) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data) / 2, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+
+ // Check that we didn't get any extra data.
+ EXPECT_NE(0, memcmp(sent_data + sizeof(sent_data) / 2,
+ received_data + sizeof(received_data) / 2,
+ sizeof(sent_data) / 2));
+}
+
+TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
+
+ // Check that msghdr flags were updated.
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+}
+
+TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(NonStreamSocketPairTest, MsgTruncNotFull) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[2 * sizeof(sent_data)];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+// This test verifies that reading from a socket with MSG_TRUNC and a
+// zero-length receive buffer returns the length of the pending message.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags.
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+}
+
+// This test verifies that reading from a socket with MSG_TRUNC | MSG_PEEK and
+// a zero-length receive buffer returns the message length without removing
+// any data from the socket.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncMsgPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags because the receive buffer is
+ // smaller than the message size.
+ EXPECT_EQ(peek_msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+
+ char received_data[sizeof(sent_data)] = {};
+
+ struct iovec received_iov;
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = sizeof(received_data);
+ struct msghdr received_msg = {};
+ received_msg.msg_flags = -1;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+
+ // Next we can read the actual data.
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &received_msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is not set on msghdr flags because we read the whole
+ // message.
+ EXPECT_EQ(received_msg.msg_flags & MSG_TRUNC, 0);
+}
+
+// This test verifies that reading from an empty socket with
+// MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT and a zero-length receive buffer fails
+// with EAGAIN (equivalently, EWOULDBLOCK) rather than blocking.
+TEST_P(NonStreamSocketPairTest, RecvmsgTruncPeekDontwaitZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't send any data on the socket.
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // recvmsg fails with EAGAIN because no data is available on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream.h b/test/syscalls/linux/socket_non_stream.h
new file mode 100644
index 000000000..469fbe6a2
--- /dev/null
+++ b/test/syscalls/linux/socket_non_stream.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected non-stream sockets.
+using NonStreamSocketPairTest = SocketPairTest;
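+
+// A minimal sketch of how a suite might instantiate this fixture. The macro
+// is standard googletest (older releases spell it INSTANTIATE_TEST_CASE_P);
+// the suite name and the GetExampleSocketPairs() helper are illustrative
+// assumptions, not part of this header:
+//
+//   INSTANTIATE_TEST_SUITE_P(
+//       ExampleSockets, NonStreamSocketPairTest,
+//       ::testing::ValuesIn(IncludeReversals(GetExampleSocketPairs())));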
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_H_
diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc
new file mode 100644
index 000000000..b052f6e61
--- /dev/null
+++ b/test/syscalls/linux/socket_non_stream_blocking.cc
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_non_stream_blocking.h"
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[100];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) * 2] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_WAITALL),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+}
+
+// This test verifies that reading from a socket with MSG_TRUNC | MSG_PEEK and
+// a zero-length receive buffer blocks until a message arrives and then
+// returns its length. Unlike the non-blocking variant, MSG_DONTWAIT is not
+// used.
+TEST_P(BlockingNonStreamSocketPairTest,
+ RecvmsgTruncPeekDontwaitZeroLenBlocking) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't initially send any data on the socket.
+ const int data_size = 10;
+ char sent_data[data_size];
+ RandomizeBuffer(sent_data, data_size);
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ ScopedThread t([&]() {
+ // The syscall succeeds returning the full size of the message on the
+ // socket. This should block until there is data on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(data_size));
+ });
+
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, data_size, 0),
+ SyscallSucceedsWithValue(data_size));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream_blocking.h b/test/syscalls/linux/socket_non_stream_blocking.h
new file mode 100644
index 000000000..6e205a039
--- /dev/null
+++ b/test/syscalls/linux/socket_non_stream_blocking.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_BLOCKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_BLOCKING_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of blocking connected non-stream
+// sockets.
+using BlockingNonStreamSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NON_STREAM_BLOCKING_H_
diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc
new file mode 100644
index 000000000..6522b2e01
--- /dev/null
+++ b/test/syscalls/linux/socket_stream.cc
@@ -0,0 +1,178 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_stream.h"
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(StreamSocketPairTest, SplitRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data) / 2];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data)));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data + sizeof(received_data), received_data,
+ sizeof(received_data)));
+}
+
+// Stream sockets allow data sent with multiple sends to be read in a single
+// recv.
+//
+// CoalescedRecv checks that multiple messages are readable in a single recv.
+TEST_P(StreamSocketPairTest, CoalescedRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data1, sizeof(sent_data1), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data2, sizeof(sent_data2), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1),
+ sizeof(sent_data2)));
+}
+
+TEST_P(StreamSocketPairTest, WriteOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ const char str[] = "abc";
+ ASSERT_THAT(write(sockets->second_fd(), str, 3),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, MsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)];
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data) / 2, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_stream.h b/test/syscalls/linux/socket_stream.h
new file mode 100644
index 000000000..b837b8f8c
--- /dev/null
+++ b/test/syscalls/linux/socket_stream.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of blocking and non-blocking
+// connected stream sockets.
+using StreamSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_H_
diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc
new file mode 100644
index 000000000..538ee2268
--- /dev/null
+++ b/test/syscalls/linux/socket_stream_blocking.cc
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_stream_blocking.h"
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(BlockingStreamSocketPairTest, BlockPartialWriteClosed) {
+ // FIXME(b/35921550): gVisor doesn't support SO_SNDBUF on UDS, nor does it
+ // enforce any limit; it will write arbitrary amounts of data without
+ // blocking.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int buffer_size;
+ socklen_t length = sizeof(buffer_size);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
+ &buffer_size, &length),
+ SyscallSucceeds());
+
+ int wfd = sockets->first_fd();
+ ScopedThread t([wfd, buffer_size]() {
+ std::vector<char> buf(2 * buffer_size);
+ // Write more than fits in the buffer. The call blocks, then returns a
+ // partial write once the other end is closed. The next call fails with
+ // EPIPE or ECONNRESET.
+ //
+ // N.B. writes occur in chunks, so we may see less than buffer_size from
+ // the first call.
+ ASSERT_THAT(write(wfd, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(::testing::Gt(0)));
+ ASSERT_THAT(write(wfd, buf.data(), buf.size()),
+ ::testing::AnyOf(SyscallFailsWithErrno(EPIPE),
+ SyscallFailsWithErrno(ECONNRESET)));
+ });
+
+ // Leave time for write to become blocked.
+ absl::SleepFor(absl::Seconds(1));
+
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+}
+
+// Random save may interrupt the call to sendmsg() in SendLargeSendMsg(),
+// causing the write to be incomplete and the test to hang.
+TEST_P(BlockingStreamSocketPairTest, SendMsgTooLarge_NoRandomSave) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int sndbuf;
+ socklen_t length = sizeof(sndbuf);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, &length),
+ SyscallSucceeds());
+
+ // Make the call too large to fit in the send buffer.
+ const int buffer_size = 3 * sndbuf;
+
+ EXPECT_THAT(SendLargeSendMsg(sockets, buffer_size, true /* reader */),
+ SyscallSucceedsWithValue(buffer_size));
+}
+
+TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[100];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[200] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+}
+
+// Test that MSG_WAITALL causes recv to block until all requested data is
+// received. Random save can interrupt blocking and cause received data to be
+// returned, even if the amount received is less than the full requested amount.
+TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll_NoRandomSave) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[100];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ constexpr auto kDuration = absl::Milliseconds(200);
+ auto before = Now(CLOCK_MONOTONIC);
+
+ const ScopedThread t([&]() {
+ absl::SleepFor(kDuration);
+
+ // Don't let saving after the write interrupt the blocking recv.
+ const DisableSave ds;
+
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ });
+
+ char received_data[sizeof(sent_data) * 2] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_WAITALL),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ auto after = Now(CLOCK_MONOTONIC);
+ EXPECT_GE(after - before, kDuration);
+}
+
+TEST_P(BlockingStreamSocketPairTest, SendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ std::vector<char> buf(kPageSize);
+ // We don't know how much data the socketpair will buffer, so we may do an
+ // arbitrarily large number of writes; saving after each write causes this
+ // test's time to explode.
+ const DisableSave ds;
+ for (;;) {
+ int ret;
+ ASSERT_THAT(
+ ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0),
+ ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN)));
+ if (ret == -1) {
+ break;
+ }
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_stream_blocking.h b/test/syscalls/linux/socket_stream_blocking.h
new file mode 100644
index 000000000..9fd19ff90
--- /dev/null
+++ b/test/syscalls/linux/socket_stream_blocking.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_BLOCKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_BLOCKING_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of blocking connected stream
+// sockets.
+using BlockingStreamSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_BLOCKING_H_
diff --git a/test/syscalls/linux/socket_stream_nonblock.cc b/test/syscalls/linux/socket_stream_nonblock.cc
new file mode 100644
index 000000000..74d608741
--- /dev/null
+++ b/test/syscalls/linux/socket_stream_nonblock.cc
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_stream_nonblock.h"
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using ::testing::Le;
+
+TEST_P(NonBlockingStreamSocketPairTest, SendMsgTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int sndbuf;
+ socklen_t length = sizeof(sndbuf);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, &length),
+ SyscallSucceeds());
+
+ // Make the call too large to fit in the send buffer.
+ const int buffer_size = 3 * sndbuf;
+
+ EXPECT_THAT(SendLargeSendMsg(sockets, buffer_size, false /* reader */),
+ SyscallSucceedsWithValue(Le(buffer_size)));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_stream_nonblock.h b/test/syscalls/linux/socket_stream_nonblock.h
new file mode 100644
index 000000000..c3b7fad91
--- /dev/null
+++ b/test/syscalls/linux/socket_stream_nonblock.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_NONBLOCK_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_NONBLOCK_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of non-blocking connected stream
+// sockets.
+using NonBlockingStreamSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_STREAM_NONBLOCK_H_
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
new file mode 100644
index 000000000..53b678e94
--- /dev/null
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -0,0 +1,907 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+#include <arpa/inet.h>
+#include <poll.h>
+#include <sys/socket.h>
+
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/types/optional.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+Creator<SocketPair> SyscallSocketPairCreator(int domain, int type,
+ int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> {
+ int pair[2];
+ RETURN_ERROR_IF_SYSCALL_FAIL(socketpair(domain, type, protocol, pair));
+ MaybeSave(); // Save on successful creation.
+ return absl::make_unique<FDSocketPair>(pair[0], pair[1]);
+ };
+}
+
+Creator<FileDescriptor> SyscallSocketCreator(int domain, int type,
+ int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FileDescriptor>> {
+ int fd = 0;
+ RETURN_ERROR_IF_SYSCALL_FAIL(fd = socket(domain, type, protocol));
+ MaybeSave(); // Save on successful creation.
+ return absl::make_unique<FileDescriptor>(fd);
+ };
+}
+
+PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain) {
+ struct sockaddr_un addr = {};
+ std::string path = NewTempAbsPathInDir("/tmp");
+ if (path.size() >= sizeof(addr.sun_path)) {
+ return PosixError(EINVAL,
+ "Unable to generate a temp path of appropriate length");
+ }
+
+ if (abstract) {
+ // Indicate that the path is in the abstract namespace.
+ path[0] = 0;
+ }
+ memcpy(addr.sun_path, path.c_str(), path.length());
+ addr.sun_family = domain;
+ return addr;
+}
+
+Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain,
+ int type, int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un bind_addr,
+ UniqueUnixAddr(abstract, domain));
+ ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un extra_addr,
+ UniqueUnixAddr(abstract, domain));
+
+ int bound;
+ RETURN_ERROR_IF_SYSCALL_FAIL(bound = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ bind(bound, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)));
+ MaybeSave(); // Successful bind.
+ RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5));
+ MaybeSave(); // Successful listen.
+
+ int connected;
+ RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ connect(connected, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)));
+ MaybeSave(); // Successful connect.
+
+ int accepted;
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ accepted = accept4(bound, nullptr, nullptr,
+ type & (SOCK_NONBLOCK | SOCK_CLOEXEC)));
+ MaybeSave(); // Successful accept.
+
+ // Cleanup no longer needed resources.
+ RETURN_ERROR_IF_SYSCALL_FAIL(close(bound));
+ MaybeSave(); // Dropped original socket.
+
+ // Only unlink if path is not in abstract namespace.
+ if (bind_addr.sun_path[0] != 0) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(unlink(bind_addr.sun_path));
+ MaybeSave(); // Unlinked path.
+ }
+
+ // accepted is passed before connected so that connected is destructed
+ // first: destructors for nonstatic member objects run in the reverse order
+ // of their declaration.
+ return absl::make_unique<AddrFDSocketPair>(accepted, connected, bind_addr,
+ extra_addr);
+ };
+}
+
+Creator<SocketPair> FilesystemAcceptBindSocketPairCreator(int domain, int type,
+ int protocol) {
+ return AcceptBindSocketPairCreator(/* abstract= */ false, domain, type,
+ protocol);
+}
+
+Creator<SocketPair> AbstractAcceptBindSocketPairCreator(int domain, int type,
+ int protocol) {
+ return AcceptBindSocketPairCreator(/* abstract= */ true, domain, type,
+ protocol);
+}
+
+Creator<SocketPair> BidirectionalBindSocketPairCreator(bool abstract,
+ int domain, int type,
+ int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> {
+ ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr1,
+ UniqueUnixAddr(abstract, domain));
+ ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr2,
+ UniqueUnixAddr(abstract, domain));
+
+ int sock1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ bind(sock1, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1)));
+ MaybeSave(); // Successful bind.
+
+ int sock2;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ bind(sock2, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2)));
+ MaybeSave(); // Successful bind.
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(
+ sock1, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2)));
+ MaybeSave(); // Successful connect.
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(
+ sock2, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1)));
+ MaybeSave(); // Successful connect.
+
+ // Cleanup no longer needed resources.
+
+ // Only unlink if path is not in abstract namespace.
+ if (addr1.sun_path[0] != 0) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(unlink(addr1.sun_path));
+ MaybeSave(); // Successful unlink.
+ }
+
+ // Only unlink if path is not in abstract namespace.
+ if (addr2.sun_path[0] != 0) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(unlink(addr2.sun_path));
+ MaybeSave(); // Successful unlink.
+ }
+
+ return absl::make_unique<FDSocketPair>(sock1, sock2);
+ };
+}
+
+Creator<SocketPair> FilesystemBidirectionalBindSocketPairCreator(int domain,
+ int type,
+ int protocol) {
+ return BidirectionalBindSocketPairCreator(/* abstract= */ false, domain, type,
+ protocol);
+}
+
+Creator<SocketPair> AbstractBidirectionalBindSocketPairCreator(int domain,
+ int type,
+ int protocol) {
+ return BidirectionalBindSocketPairCreator(/* abstract= */ true, domain, type,
+ protocol);
+}
+
+Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type,
+ int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> {
+ struct sockaddr_un addr = {};
+ constexpr char kSocketGoferPath[] = "/socket";
+ memcpy(addr.sun_path, kSocketGoferPath, sizeof(kSocketGoferPath));
+ addr.sun_family = domain;
+
+ int sock1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(
+ sock1, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ MaybeSave(); // Successful connect.
+
+ int sock2;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(
+ sock2, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ MaybeSave(); // Successful connect.
+
+ // Make and close another socketpair to ensure that the duped ends of the
+ // first socketpair get closed.
+ //
+ // The problem is that there is no way to atomically send and close an FD.
+ // The closest that we can do is send and then immediately close the FD,
+ // which is what we do in the gofer. The gofer won't respond to another
+ // request until the reply is sent and the FD is closed, so forcing the
+ // gofer to handle another request will ensure that this has happened.
+ for (int i = 0; i < 2; i++) {
+ int sock;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock = socket(domain, type, protocol));
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(
+ sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(close(sock));
+ }
+
+ return absl::make_unique<FDSocketPair>(sock1, sock2);
+ };
+}
+
+Creator<SocketPair> SocketpairGoferFileSocketPairCreator(int flags) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> {
+ constexpr char kSocketGoferPath[] = "/socket";
+
+ int sock1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock1 =
+ open(kSocketGoferPath, O_RDWR | flags));
+ MaybeSave(); // Successful socket creation.
+
+ int sock2;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock2 =
+ open(kSocketGoferPath, O_RDWR | flags));
+ MaybeSave(); // Successful socket creation.
+
+ return absl::make_unique<FDSocketPair>(sock1, sock2);
+ };
+}
+
+Creator<SocketPair> UnboundSocketPairCreator(bool abstract, int domain,
+ int type, int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr1,
+ UniqueUnixAddr(abstract, domain));
+ ASSIGN_OR_RETURN_ERRNO(struct sockaddr_un addr2,
+ UniqueUnixAddr(abstract, domain));
+
+ int sock1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ int sock2;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+ return absl::make_unique<AddrFDSocketPair>(sock1, sock2, addr1, addr2);
+ };
+}
+
+Creator<SocketPair> FilesystemUnboundSocketPairCreator(int domain, int type,
+ int protocol) {
+ return UnboundSocketPairCreator(/* abstract= */ false, domain, type,
+ protocol);
+}
+
+Creator<SocketPair> AbstractUnboundSocketPairCreator(int domain, int type,
+ int protocol) {
+ return UnboundSocketPairCreator(/* abstract= */ true, domain, type, protocol);
+}
+
+void LocalhostAddr(struct sockaddr_in* addr, bool dual_stack) {
+ addr->sin_family = AF_INET;
+ addr->sin_port = htons(0);
+ inet_pton(AF_INET, "127.0.0.1",
+ reinterpret_cast<void*>(&addr->sin_addr.s_addr));
+}
+
+void LocalhostAddr(struct sockaddr_in6* addr, bool dual_stack) {
+ addr->sin6_family = AF_INET6;
+ addr->sin6_port = htons(0);
+ if (dual_stack) {
+ inet_pton(AF_INET6, "::ffff:127.0.0.1",
+ reinterpret_cast<void*>(&addr->sin6_addr.s6_addr));
+ } else {
+ inet_pton(AF_INET6, "::1",
+ reinterpret_cast<void*>(&addr->sin6_addr.s6_addr));
+ }
+ addr->sin6_scope_id = 0;
+}
+
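+// Binds fd to an ephemeral localhost port (port 0 in the template address)
+// and returns the address the kernel actually assigned, recovered with
+// getsockname().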
+template <typename T>
+PosixErrorOr<T> BindIP(int fd, bool dual_stack) {
+ T addr = {};
+ LocalhostAddr(&addr, dual_stack);
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ socklen_t addrlen = sizeof(addr);
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen));
+ return addr;
+}
+
+template <typename T>
+PosixErrorOr<T> TCPBindAndListen(int fd, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T addr, BindIP<T>(fd, dual_stack));
+ RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd, /* backlog = */ 5));
+ return addr;
+}
+
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>>
+CreateTCPConnectAcceptSocketPair(int bound, int connected, int type,
+ bool dual_stack, T bind_addr) {
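+ // For socket types that include SOCK_NONBLOCK, connect() may return -1 with
+ // errno == EINPROGRESS; treat that as success here and complete the
+ // connection below via poll() and SO_ERROR.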
+ int connect_result = 0;
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ (connect_result = RetryEINTR(connect)(
+ connected, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr))) == -1 &&
+ errno == EINPROGRESS
+ ? 0
+ : connect_result);
+ MaybeSave(); // Successful connect.
+
+ if (connect_result == -1) {
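+ // The connect is still in progress: poll the socket for writability, then
+ // read the final connect status with SO_ERROR and surface it through errno.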
+ struct pollfd connect_poll = {connected, POLLOUT | POLLERR | POLLHUP, 0};
+ RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(poll)(&connect_poll, 1, 0));
+ int error = 0;
+ socklen_t errorlen = sizeof(error);
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ getsockopt(connected, SOL_SOCKET, SO_ERROR, &error, &errorlen));
+ errno = error;
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ /* connect */ error == 0 ? 0 : -1);
+ }
+
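+ // The listener may be non-blocking (type may include SOCK_NONBLOCK), so
+ // accept4() can transiently fail with EAGAIN before the connection is
+ // delivered; poll and retry until it succeeds.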
+ int accepted = -1;
+ struct pollfd accept_poll = {bound, POLLIN, 0};
+ while (accepted == -1) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(poll)(&accept_poll, 1, 0));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ (accepted = RetryEINTR(accept4)(
+ bound, nullptr, nullptr, type & (SOCK_NONBLOCK | SOCK_CLOEXEC))) ==
+ -1 &&
+ errno == EAGAIN
+ ? 0
+ : accepted);
+ }
+ MaybeSave(); // Successful accept.
+
+ T extra_addr = {};
+ LocalhostAddr(&extra_addr, dual_stack);
+ return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
+ extra_addr);
+}
+
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
+ int bound, int connected, int type, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T bind_addr, TCPBindAndListen<T>(bound, dual_stack));
+
+ auto result = CreateTCPConnectAcceptSocketPair(bound, connected, type,
+ dual_stack, bind_addr);
+
+ // Cleanup no longer needed resources.
+ RETURN_ERROR_IF_SYSCALL_FAIL(close(bound));
+ MaybeSave(); // Successful close.
+
+ return result;
+}
+
+Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
+ int protocol,
+ bool dual_stack) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ int bound;
+ RETURN_ERROR_IF_SYSCALL_FAIL(bound = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ int connected;
+ RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ if (domain == AF_INET) {
+ return CreateTCPAcceptBindSocketPair<sockaddr_in>(bound, connected, type,
+ dual_stack);
+ }
+ return CreateTCPAcceptBindSocketPair<sockaddr_in6>(bound, connected, type,
+ dual_stack);
+ };
+}
+
+Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
+ int domain, int type, int protocol, bool dual_stack) {
+ // These are lazily initialized below, on the first call to the returned
+ // lambda. These values are private to each returned lambda, but shared across
+ // invocations of a specific lambda.
+ //
+ // The sharing allows pairs created with the same parameters to share a
+ // listener. This prevents future connects from failing if the connecting
+ // socket selects a port which had previously been used by a listening socket
+ // that still has some connections in TIME-WAIT.
+ //
+ // The lazy initialization is to avoid creating sockets during parameter
+ // enumeration. This is important because parameters are enumerated during the
+ // build process where networking may not be available.
+ auto listener = std::make_shared<absl::optional<int>>(absl::optional<int>());
+ auto addr4 = std::make_shared<absl::optional<sockaddr_in>>(
+ absl::optional<sockaddr_in>());
+ auto addr6 = std::make_shared<absl::optional<sockaddr_in6>>(
+ absl::optional<sockaddr_in6>());
+
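+ // The lambda below captures these shared_ptrs by value, so all copies of a
+ // given lambda alias the same listener and cached addresses.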
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ int connected;
+ RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ // Share the listener across invocations.
+ if (!listener->has_value()) {
+ int fd = socket(domain, type, protocol);
+ if (fd < 0) {
+ return PosixError(errno, absl::StrCat("socket(", domain, ", ", type,
+ ", ", protocol, ")"));
+ }
+ listener->emplace(fd);
+ MaybeSave(); // Successful socket creation.
+ }
+
+ // Bind the listener once, but create a new connect/accept pair each
+ // time.
+ if (domain == AF_INET) {
+ if (!addr4->has_value()) {
+ addr4->emplace(
+ TCPBindAndListen<sockaddr_in>(listener->value(), dual_stack)
+ .ValueOrDie());
+ }
+ return CreateTCPConnectAcceptSocketPair(listener->value(), connected,
+ type, dual_stack, addr4->value());
+ }
+ if (!addr6->has_value()) {
+ addr6->emplace(
+ TCPBindAndListen<sockaddr_in6>(listener->value(), dual_stack)
+ .ValueOrDie());
+ }
+ return CreateTCPConnectAcceptSocketPair(listener->value(), connected, type,
+ dual_stack, addr6->value());
+ };
+}
+
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateUDPBoundSocketPair(
+ int sock1, int sock2, int type, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T addr1, BindIP<T>(sock1, dual_stack));
+ ASSIGN_OR_RETURN_ERRNO(T addr2, BindIP<T>(sock2, dual_stack));
+
+ return absl::make_unique<AddrFDSocketPair>(sock1, sock2, addr1, addr2);
+}
+
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>>
+CreateUDPBidirectionalBindSocketPair(int sock1, int sock2, int type,
+ bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(
+ auto socks, CreateUDPBoundSocketPair<T>(sock1, sock2, type, dual_stack));
+
+ // Connect sock1 to sock2.
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(socks->first_fd(), socks->second_addr(),
+ socks->second_addr_size()));
+ MaybeSave(); // Successful connection.
+
+ // Connect sock2 to sock1.
+ RETURN_ERROR_IF_SYSCALL_FAIL(connect(socks->second_fd(), socks->first_addr(),
+ socks->first_addr_size()));
+ MaybeSave(); // Successful connection.
+
+ return socks;
+}
+
+Creator<SocketPair> UDPBidirectionalBindSocketPairCreator(int domain, int type,
+ int protocol,
+ bool dual_stack) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ int sock1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ int sock2;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ if (domain == AF_INET) {
+ return CreateUDPBidirectionalBindSocketPair<sockaddr_in>(
+ sock1, sock2, type, dual_stack);
+ }
+ return CreateUDPBidirectionalBindSocketPair<sockaddr_in6>(sock1, sock2,
+ type, dual_stack);
+ };
+}
+
+Creator<SocketPair> UDPUnboundSocketPairCreator(int domain, int type,
+ int protocol, bool dual_stack) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FDSocketPair>> {
+ int sock1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ int sock2;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ return absl::make_unique<FDSocketPair>(sock1, sock2);
+ };
+}
+
+SocketPairKind Reversed(SocketPairKind const& base) {
+ auto const& creator = base.creator;
+ return SocketPairKind{
+ absl::StrCat("reversed ", base.description), base.domain, base.type,
+ base.protocol,
+ [creator]() -> PosixErrorOr<std::unique_ptr<ReversedSocketPair>> {
+ ASSIGN_OR_RETURN_ERRNO(auto creator_value, creator());
+ return absl::make_unique<ReversedSocketPair>(std::move(creator_value));
+ }};
+}
+
+Creator<FileDescriptor> UnboundSocketCreator(int domain, int type,
+ int protocol) {
+ return [=]() -> PosixErrorOr<std::unique_ptr<FileDescriptor>> {
+ int sock;
+ RETURN_ERROR_IF_SYSCALL_FAIL(sock = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ return absl::make_unique<FileDescriptor>(sock);
+ };
+}
+
+std::vector<SocketPairKind> IncludeReversals(std::vector<SocketPairKind> vec) {
+ return ApplyVecToVec<SocketPairKind>(std::vector<Middleware>{NoOp, Reversed},
+ vec);
+}
+
+SocketPairKind NoOp(SocketPairKind const& base) { return base; }
+
+void TransferTest(int fd1, int fd2) {
+ char buf1[20];
+ RandomizeBuffer(buf1, sizeof(buf1));
+ ASSERT_THAT(WriteFd(fd1, buf1, sizeof(buf1)),
+ SyscallSucceedsWithValue(sizeof(buf1)));
+
+ char buf2[20];
+ ASSERT_THAT(ReadFd(fd2, buf2, sizeof(buf2)),
+ SyscallSucceedsWithValue(sizeof(buf2)));
+
+ EXPECT_EQ(0, memcmp(buf1, buf2, sizeof(buf1)));
+
+ RandomizeBuffer(buf1, sizeof(buf1));
+ ASSERT_THAT(WriteFd(fd2, buf1, sizeof(buf1)),
+ SyscallSucceedsWithValue(sizeof(buf1)));
+
+ ASSERT_THAT(ReadFd(fd1, buf2, sizeof(buf2)),
+ SyscallSucceedsWithValue(sizeof(buf2)));
+
+ EXPECT_EQ(0, memcmp(buf1, buf2, sizeof(buf1)));
+}
+
+// Initializes the given buffer with random data.
+void RandomizeBuffer(char* ptr, size_t len) {
+ uint32_t seed = time(nullptr);
+ for (size_t i = 0; i < len; ++i) {
+ ptr[i] = static_cast<char>(rand_r(&seed));
+ }
+}
+
+size_t CalculateUnixSockAddrLen(const char* sun_path) {
+ // Abstract addresses always return the full length.
+ if (sun_path[0] == 0) {
+ return sizeof(sockaddr_un);
+ }
+ // Filesystem addresses use the path length plus the 2-byte sun_family and
+ // the 1-byte null terminator.
+ return strlen(sun_path) + 3;
+}
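+
+// A worked example of the arithmetic above (illustrative only): for the
+// filesystem path "/tmp/sock", strlen() is 9, so the reported length is
+// 9 + 2 (sun_family) + 1 (null terminator) = 12 bytes.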
+
+struct sockaddr_storage AddrFDSocketPair::to_storage(const sockaddr_un& addr) {
+ struct sockaddr_storage addr_storage = {};
+ memcpy(&addr_storage, &addr, sizeof(addr));
+ return addr_storage;
+}
+
+struct sockaddr_storage AddrFDSocketPair::to_storage(const sockaddr_in& addr) {
+ struct sockaddr_storage addr_storage = {};
+ memcpy(&addr_storage, &addr, sizeof(addr));
+ return addr_storage;
+}
+
+struct sockaddr_storage AddrFDSocketPair::to_storage(const sockaddr_in6& addr) {
+ struct sockaddr_storage addr_storage = {};
+ memcpy(&addr_storage, &addr, sizeof(addr));
+ return addr_storage;
+}
+
+SocketKind SimpleSocket(int fam, int type, int proto) {
+ return SocketKind{
+ absl::StrCat("Family ", fam, ", type ", type, ", proto ", proto), fam,
+ type, proto, SyscallSocketCreator(fam, type, proto)};
+}
+
+ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets,
+ size_t size, bool reader) {
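+ // When reader is true, a helper thread drains the peer end so that a
+ // sendmsg() larger than the send buffer can complete instead of blocking
+ // forever.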
+ const int rfd = sockets->second_fd();
+ ScopedThread t([rfd, size, reader] {
+ if (!reader) {
+ return;
+ }
+
+ // The read loop below may perform a very large number of syscalls, and
+ // saving at each one would make this test intractably slow.
+ const DisableSave ds;
+
+ std::vector<char> buf(size);
+ size_t total = 0;
+
+ while (total < size) {
+ int ret = read(rfd, buf.data(), buf.size());
+ if (ret == -1 && errno == EAGAIN) {
+ continue;
+ }
+ if (ret > 0) {
+ total += ret;
+ }
+
+ // ASSERT (rather than EXPECT) so the thread returns on the first failure.
+ ASSERT_THAT(ret, SyscallSucceeds());
+ }
+ });
+
+ std::vector<char> buf(size);
+
+ struct iovec iov = {};
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ return RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0);
+}
+
+namespace internal {
+PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
+ SocketType type, bool reuse_addr) {
+ if (port < 0) {
+ return PosixError(EINVAL, "Invalid port");
+ }
+
+ // Both AddressFamily::kIpv6 and AddressFamily::kDualStack use AF_INET6.
+ int sock_fam = (family == AddressFamily::kIpv4 ? AF_INET : AF_INET6);
+ int sock_type = (type == SocketType::kTcp ? SOCK_STREAM : SOCK_DGRAM);
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Socket(sock_fam, sock_type, 0));
+
+ if (reuse_addr) {
+ int one = 1;
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ setsockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)));
+ }
+
+ // Try to bind.
+ sockaddr_storage storage = {};
+ int storage_size = 0;
+ if (family == AddressFamily::kIpv4) {
+ sockaddr_in* addr = reinterpret_cast<sockaddr_in*>(&storage);
+ storage_size = sizeof(*addr);
+ addr->sin_family = AF_INET;
+ addr->sin_port = htons(port);
+ addr->sin_addr.s_addr = htonl(INADDR_ANY);
+ } else {
+ sockaddr_in6* addr = reinterpret_cast<sockaddr_in6*>(&storage);
+ storage_size = sizeof(*addr);
+ addr->sin6_family = AF_INET6;
+ addr->sin6_port = htons(port);
+ if (family == AddressFamily::kDualStack) {
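+ // Bind to the V4-mapped wildcard ("any") address so the socket is
+ // reachable over both IPv4 and IPv6.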
+ inet_pton(AF_INET6, "::ffff:0.0.0.0",
+ reinterpret_cast<void*>(&addr->sin6_addr.s6_addr));
+ } else {
+ addr->sin6_addr = in6addr_any;
+ }
+ }
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ bind(fd.get(), reinterpret_cast<sockaddr*>(&storage), storage_size));
+
+ // If the user specified port 0, return the port the kernel assigned;
+ // otherwise, validate that the socket actually bound to the requested
+ // port.
+ sockaddr_storage bound_storage = {};
+ socklen_t bound_storage_size = sizeof(bound_storage);
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ getsockname(fd.get(), reinterpret_cast<sockaddr*>(&bound_storage),
+ &bound_storage_size));
+
+ int available_port = -1;
+ if (bound_storage.ss_family == AF_INET) {
+ sockaddr_in* addr = reinterpret_cast<sockaddr_in*>(&bound_storage);
+ available_port = ntohs(addr->sin_port);
+ } else if (bound_storage.ss_family == AF_INET6) {
+ sockaddr_in6* addr = reinterpret_cast<sockaddr_in6*>(&bound_storage);
+ available_port = ntohs(addr->sin6_port);
+ } else {
+ return PosixError(EPROTOTYPE, "Getsockname returned invalid family");
+ }
+
+ // If we requested a specific port make sure our bound port is that port.
+ if (port != 0 && available_port != port) {
+ return PosixError(EINVAL,
+ absl::StrCat("Bound port ", available_port,
+ " was not equal to requested port ", port));
+ }
+
+ // For TCP sockets, also verify that we can listen.
+ if (type == SocketType::kTcp) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd.get(), 1));
+ }
+
+ return available_port;
+}
+} // namespace internal
+
+PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) {
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = buf_size;
+ msg->msg_iov = &iov;
+ msg->msg_iovlen = 1;
+
+ int ret;
+ RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(sendmsg)(sock, msg, 0));
+ return ret;
+}
+
+void RecvNoData(int sock) {
+ char data = 0;
+ struct iovec iov;
+ iov.iov_base = &data;
+ iov.iov_len = 1;
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TestAddress V4Any() {
+ TestAddress t("V4Any");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = htonl(INADDR_ANY);
+ return t;
+}
+
+TestAddress V4Loopback() {
+ TestAddress t("V4Loopback");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ return t;
+}
+
+TestAddress V4MappedAny() {
+ TestAddress t("V4MappedAny");
+ t.addr.ss_family = AF_INET6;
+ t.addr_len = sizeof(sockaddr_in6);
+ inet_pton(AF_INET6, "::ffff:0.0.0.0",
+ reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr.s6_addr);
+ return t;
+}
+
+TestAddress V4MappedLoopback() {
+ TestAddress t("V4MappedLoopback");
+ t.addr.ss_family = AF_INET6;
+ t.addr_len = sizeof(sockaddr_in6);
+ inet_pton(AF_INET6, "::ffff:127.0.0.1",
+ reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr.s6_addr);
+ return t;
+}
+
+TestAddress V4Multicast() {
+ TestAddress t("V4Multicast");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ inet_addr(kMulticastAddress);
+ return t;
+}
+
+TestAddress V4Broadcast() {
+ TestAddress t("V4Broadcast");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ htonl(INADDR_BROADCAST);
+ return t;
+}
+
+TestAddress V6Any() {
+ TestAddress t("V6Any");
+ t.addr.ss_family = AF_INET6;
+ t.addr_len = sizeof(sockaddr_in6);
+ reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr = in6addr_any;
+ return t;
+}
+
+TestAddress V6Loopback() {
+ TestAddress t("V6Loopback");
+ t.addr.ss_family = AF_INET6;
+ t.addr_len = sizeof(sockaddr_in6);
+ reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr = in6addr_loopback;
+ return t;
+}
+
+// Checksum computes the internet checksum of a buffer.
+uint16_t Checksum(uint16_t* buf, ssize_t buf_size) {
+ // Add up the complete 16-bit words in the buffer, stopping before any
+ // trailing odd byte so we never read past the end.
+ uint32_t total = 0;
+ for (ssize_t i = 0; i + 1 < buf_size; i += sizeof(*buf)) {
+ total += *buf;
+ buf++;
+ }
+
+ // If buf has an odd size, add the trailing byte as a word zero-padded on
+ // the right (little-endian).
+ if (buf_size % 2) {
+ total += *reinterpret_cast<unsigned char*>(buf);
+ }
+
+ // Fold any carry bits above the lower 16 back in until the sum fits in
+ // 16 bits.
+ while (total >> 16) {
+ uint16_t lower = total & 0xffff;
+ uint16_t upper = total >> 16;
+ total = lower + upper;
+ }
+
+ return ~total;
+}
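+
+// For example, a running sum of 0x2B3C5 folds to 0xB3C5 + 0x2 = 0xB3C7, and
+// the returned checksum is ~0xB3C7 = 0x4C38.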
+
+uint16_t IPChecksum(struct iphdr ip) {
+ return Checksum(reinterpret_cast<uint16_t*>(&ip), sizeof(ip));
+}
+
+// The pseudo-header defined in RFC 768 for calculating the UDP checksum.
+struct udp_pseudo_hdr {
+ uint32_t srcip;
+ uint32_t destip;
+ char zero;
+ char protocol;
+ uint16_t udplen;
+};
+
+uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
+ const char* payload, ssize_t payload_len) {
+ struct udp_pseudo_hdr phdr = {};
+ phdr.srcip = iphdr.saddr;
+ phdr.destip = iphdr.daddr;
+ phdr.zero = 0;
+ phdr.protocol = IPPROTO_UDP;
+ phdr.udplen = udphdr.len;
+
+ ssize_t buf_size = sizeof(phdr) + sizeof(udphdr) + payload_len;
+ char* buf = static_cast<char*>(malloc(buf_size));
+ memcpy(buf, &phdr, sizeof(phdr));
+ memcpy(buf + sizeof(phdr), &udphdr, sizeof(udphdr));
+ memcpy(buf + sizeof(phdr) + sizeof(udphdr), payload, payload_len);
+
+ uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size);
+ free(buf);
+ return csum;
+}
+
+uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
+ ssize_t payload_len) {
+ ssize_t buf_size = sizeof(icmphdr) + payload_len;
+ char* buf = static_cast<char*>(malloc(buf_size));
+ memcpy(buf, &icmphdr, sizeof(icmphdr));
+ memcpy(buf + sizeof(icmphdr), payload, payload_len);
+
+ uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size);
+ free(buf);
+ return csum;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
new file mode 100644
index 000000000..734b48b96
--- /dev/null
+++ b/test/syscalls/linux/socket_test_util.h
@@ -0,0 +1,518 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_
+
+#include <errno.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/udp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_format.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Wrapper for socket(2) that returns a FileDescriptor.
+inline PosixErrorOr<FileDescriptor> Socket(int family, int type, int protocol) {
+ int fd = socket(family, type, protocol);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(
+ errno, absl::StrFormat("socket(%d, %d, %d)", family, type, protocol));
+ }
+ return FileDescriptor(fd);
+}
+
+// Wrapper for accept(2) that returns a FileDescriptor.
+inline PosixErrorOr<FileDescriptor> Accept(int sockfd, sockaddr* addr,
+ socklen_t* addrlen) {
+ int fd = RetryEINTR(accept)(sockfd, addr, addrlen);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(
+ errno, absl::StrFormat("accept(%d, %p, %p)", sockfd, addr, addrlen));
+ }
+ return FileDescriptor(fd);
+}
+
+// Wrapper for accept4(2) that returns a FileDescriptor.
+inline PosixErrorOr<FileDescriptor> Accept4(int sockfd, sockaddr* addr,
+ socklen_t* addrlen, int flags) {
+ int fd = RetryEINTR(accept4)(sockfd, addr, addrlen, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, absl::StrFormat("accept4(%d, %p, %p, %#x)", sockfd,
+ addr, addrlen, flags));
+ }
+ return FileDescriptor(fd);
+}
+
+inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) {
+ return internal::ApplyFileIoSyscall(
+ [&](size_t completed) {
+ return sendto(fd, static_cast<char*>(buf) + completed,
+ count - completed, flags, nullptr, 0);
+ },
+ count);
+}
+
+PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain);
+
+// A Creator<T> is a function that attempts to create and return a new T. (This
+// is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated
+// here for clarity.)
+template <typename T>
+using Creator = std::function<PosixErrorOr<std::unique_ptr<T>>()>;
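+
+// As an illustrative sketch (not part of the original file), a Creator can
+// be built from any callable returning a PosixErrorOr, e.g.:
+//
+//   Creator<FileDescriptor> creator =
+//       []() -> PosixErrorOr<std::unique_ptr<FileDescriptor>> {
+//         ASSIGN_OR_RETURN_ERRNO(auto fd, Socket(AF_UNIX, SOCK_STREAM, 0));
+//         return std::make_unique<FileDescriptor>(std::move(fd));
+//       };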
+
+// A SocketPair represents a pair of socket file descriptors owned by the
+// SocketPair.
+class SocketPair {
+ public:
+ virtual ~SocketPair() = default;
+
+ virtual int first_fd() const = 0;
+ virtual int second_fd() const = 0;
+ virtual int release_first_fd() = 0;
+ virtual int release_second_fd() = 0;
+ virtual const struct sockaddr* first_addr() const = 0;
+ virtual const struct sockaddr* second_addr() const = 0;
+ virtual size_t first_addr_size() const = 0;
+ virtual size_t second_addr_size() const = 0;
+ virtual size_t first_addr_len() const = 0;
+ virtual size_t second_addr_len() const = 0;
+};
+
+// A FDSocketPair is a SocketPair that consists of only a pair of file
+// descriptors.
+class FDSocketPair : public SocketPair {
+ public:
+ FDSocketPair(int first_fd, int second_fd)
+ : first_(first_fd), second_(second_fd) {}
+ FDSocketPair(std::unique_ptr<FileDescriptor> first_fd,
+ std::unique_ptr<FileDescriptor> second_fd)
+ : first_(first_fd->release()), second_(second_fd->release()) {}
+
+ int first_fd() const override { return first_.get(); }
+ int second_fd() const override { return second_.get(); }
+ int release_first_fd() override { return first_.release(); }
+ int release_second_fd() override { return second_.release(); }
+ const struct sockaddr* first_addr() const override { return nullptr; }
+ const struct sockaddr* second_addr() const override { return nullptr; }
+ size_t first_addr_size() const override { return 0; }
+ size_t second_addr_size() const override { return 0; }
+ size_t first_addr_len() const override { return 0; }
+ size_t second_addr_len() const override { return 0; }
+
+ private:
+ FileDescriptor first_;
+ FileDescriptor second_;
+};
+
+// CalculateUnixSockAddrLen calculates the length returned by recvfrom(2) and
+// recvmsg(2) for Unix sockets.
+size_t CalculateUnixSockAddrLen(const char* sun_path);
+
+// An AddrFDSocketPair is a SocketPair that consists of a pair of file
+// descriptors in addition to a pair of socket addresses.
+class AddrFDSocketPair : public SocketPair {
+ public:
+ AddrFDSocketPair(int first_fd, int second_fd,
+ const struct sockaddr_un& first_address,
+ const struct sockaddr_un& second_address)
+ : first_(first_fd),
+ second_(second_fd),
+ first_addr_(to_storage(first_address)),
+ second_addr_(to_storage(second_address)),
+ first_len_(CalculateUnixSockAddrLen(first_address.sun_path)),
+ second_len_(CalculateUnixSockAddrLen(second_address.sun_path)),
+ first_size_(sizeof(first_address)),
+ second_size_(sizeof(second_address)) {}
+
+ AddrFDSocketPair(int first_fd, int second_fd,
+ const struct sockaddr_in& first_address,
+ const struct sockaddr_in& second_address)
+ : first_(first_fd),
+ second_(second_fd),
+ first_addr_(to_storage(first_address)),
+ second_addr_(to_storage(second_address)),
+ first_len_(sizeof(first_address)),
+ second_len_(sizeof(second_address)),
+ first_size_(sizeof(first_address)),
+ second_size_(sizeof(second_address)) {}
+
+ AddrFDSocketPair(int first_fd, int second_fd,
+ const struct sockaddr_in6& first_address,
+ const struct sockaddr_in6& second_address)
+ : first_(first_fd),
+ second_(second_fd),
+ first_addr_(to_storage(first_address)),
+ second_addr_(to_storage(second_address)),
+ first_len_(sizeof(first_address)),
+ second_len_(sizeof(second_address)),
+ first_size_(sizeof(first_address)),
+ second_size_(sizeof(second_address)) {}
+
+ int first_fd() const override { return first_.get(); }
+ int second_fd() const override { return second_.get(); }
+ int release_first_fd() override { return first_.release(); }
+ int release_second_fd() override { return second_.release(); }
+ const struct sockaddr* first_addr() const override {
+ return reinterpret_cast<const struct sockaddr*>(&first_addr_);
+ }
+ const struct sockaddr* second_addr() const override {
+ return reinterpret_cast<const struct sockaddr*>(&second_addr_);
+ }
+ size_t first_addr_size() const override { return first_size_; }
+ size_t second_addr_size() const override { return second_size_; }
+ size_t first_addr_len() const override { return first_len_; }
+ size_t second_addr_len() const override { return second_len_; }
+
+ private:
+ // to_storage converts a sockaddr_* to a sockaddr_storage.
+ static struct sockaddr_storage to_storage(const sockaddr_un& addr);
+ static struct sockaddr_storage to_storage(const sockaddr_in& addr);
+ static struct sockaddr_storage to_storage(const sockaddr_in6& addr);
+
+ FileDescriptor first_;
+ FileDescriptor second_;
+ const struct sockaddr_storage first_addr_;
+ const struct sockaddr_storage second_addr_;
+ const size_t first_len_;
+ const size_t second_len_;
+ const size_t first_size_;
+ const size_t second_size_;
+};
+
+// SyscallSocketPairCreator returns a Creator<SocketPair> that obtains file
+// descriptors by invoking the socketpair() syscall.
+Creator<SocketPair> SyscallSocketPairCreator(int domain, int type,
+ int protocol);
+
+// SyscallSocketCreator returns a Creator<FileDescriptor> that obtains a file
+// descriptor by invoking the socket() syscall.
+Creator<FileDescriptor> SyscallSocketCreator(int domain, int type,
+ int protocol);
+
+// FilesystemBidirectionalBindSocketPairCreator returns a Creator<SocketPair>
+// that obtains file descriptors by invoking the bind() and connect() syscalls
+// on filesystem paths. Only works for DGRAM sockets.
+Creator<SocketPair> FilesystemBidirectionalBindSocketPairCreator(int domain,
+ int type,
+ int protocol);
+
+// AbstractBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that
+// obtains file descriptors by invoking the bind() and connect() syscalls on
+// abstract namespace paths. Only works for DGRAM sockets.
+Creator<SocketPair> AbstractBidirectionalBindSocketPairCreator(int domain,
+ int type,
+ int protocol);
+
+// SocketpairGoferSocketPairCreator returns a Creator<SocketPair> that
+// obtains file descriptors by connect() syscalls on two sockets with socketpair
+// gofer paths.
+Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type,
+ int protocol);
+
+// SocketpairGoferFileSocketPairCreator returns a Creator<SocketPair> that
+// obtains file descriptors by open() syscalls on socketpair gofer paths.
+Creator<SocketPair> SocketpairGoferFileSocketPairCreator(int flags);
+
+// FilesystemAcceptBindSocketPairCreator returns a Creator<SocketPair> that
+// obtains file descriptors by invoking the accept() and bind() syscalls on
+// a filesystem path. Only works for STREAM and SEQPACKET sockets.
+Creator<SocketPair> FilesystemAcceptBindSocketPairCreator(int domain, int type,
+ int protocol);
+
+// AbstractAcceptBindSocketPairCreator returns a Creator<SocketPair> that
+// obtains file descriptors by invoking the accept() and bind() syscalls on an
+// abstract namespace path. Only works for STREAM and SEQPACKET sockets.
+Creator<SocketPair> AbstractAcceptBindSocketPairCreator(int domain, int type,
+ int protocol);
+
+// FilesystemUnboundSocketPairCreator returns a Creator<SocketPair> that obtains
+// file descriptors by invoking the socket() syscall and generates a filesystem
+// path for binding.
+Creator<SocketPair> FilesystemUnboundSocketPairCreator(int domain, int type,
+ int protocol);
+
+// AbstractUnboundSocketPairCreator returns a Creator<SocketPair> that obtains
+// file descriptors by invoking the socket() syscall and generates an abstract
+// path for binding.
+Creator<SocketPair> AbstractUnboundSocketPairCreator(int domain, int type,
+ int protocol);
+
+// TCPAcceptBindSocketPairCreator returns a Creator<SocketPair> that obtains
+// file descriptors by invoking the accept() and bind() syscalls on TCP sockets.
+Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
+ int protocol,
+ bool dual_stack);
+
+// TCPAcceptBindPersistentListenerSocketPairCreator is like
+// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to
+// create all SocketPairs.
+Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
+ int domain, int type, int protocol, bool dual_stack);
+
+// UDPBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that
+// obtains file descriptors by invoking the bind() and connect() syscalls on UDP
+// sockets.
+Creator<SocketPair> UDPBidirectionalBindSocketPairCreator(int domain, int type,
+ int protocol,
+ bool dual_stack);
+
+// UDPUnboundSocketPairCreator returns a Creator<SocketPair> that obtains file
+// descriptors by creating UDP sockets.
+Creator<SocketPair> UDPUnboundSocketPairCreator(int domain, int type,
+ int protocol, bool dual_stack);
+
+// UnboundSocketCreator returns a Creator<FileDescriptor> that obtains a file
+// descriptor by creating a socket.
+Creator<FileDescriptor> UnboundSocketCreator(int domain, int type,
+ int protocol);
+
+// A SocketPairKind couples a human-readable description of a socket pair with
+// a function that creates such a socket pair.
+struct SocketPairKind {
+ std::string description;
+ int domain;
+ int type;
+ int protocol;
+ Creator<SocketPair> creator;
+
+ // Create creates a socket pair of this kind.
+ PosixErrorOr<std::unique_ptr<SocketPair>> Create() const { return creator(); }
+};
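+
+// A minimal usage sketch (illustration, not part of the original file):
+//
+//   SocketPairKind kind{
+//       "AF_UNIX SOCK_STREAM socketpair", AF_UNIX, SOCK_STREAM, 0,
+//       SyscallSocketPairCreator(AF_UNIX, SOCK_STREAM, 0)};
+//   auto pair = ASSERT_NO_ERRNO_AND_VALUE(kind.Create());
+//   ASSERT_NO_FATAL_FAILURE(TransferTest(pair->first_fd(),
+//                                        pair->second_fd()));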
+
+// A SocketKind couples a human-readable description of a socket with
+// a function that creates such a socket.
+struct SocketKind {
+ std::string description;
+ int domain;
+ int type;
+ int protocol;
+ Creator<FileDescriptor> creator;
+
+ // Create creates a socket of this kind.
+ PosixErrorOr<std::unique_ptr<FileDescriptor>> Create() const {
+ return creator();
+ }
+};
+
+// A ReversedSocketPair wraps another SocketPair but flips the first and second
+// file descriptors. ReversedSocketPair is used to test socket pairs that
+// should be symmetric.
+class ReversedSocketPair : public SocketPair {
+ public:
+ explicit ReversedSocketPair(std::unique_ptr<SocketPair> base)
+ : base_(std::move(base)) {}
+
+ int first_fd() const override { return base_->second_fd(); }
+ int second_fd() const override { return base_->first_fd(); }
+ int release_first_fd() override { return base_->release_second_fd(); }
+ int release_second_fd() override { return base_->release_first_fd(); }
+ const struct sockaddr* first_addr() const override {
+ return base_->second_addr();
+ }
+ const struct sockaddr* second_addr() const override {
+ return base_->first_addr();
+ }
+ size_t first_addr_size() const override { return base_->second_addr_size(); }
+ size_t second_addr_size() const override { return base_->first_addr_size(); }
+ size_t first_addr_len() const override { return base_->second_addr_len(); }
+ size_t second_addr_len() const override { return base_->first_addr_len(); }
+
+ private:
+ std::unique_ptr<SocketPair> base_;
+};
+
+// Reversed returns a SocketPairKind that represents SocketPairs created by
+// flipping the file descriptors provided by another SocketPair.
+SocketPairKind Reversed(SocketPairKind const& base);
+
+// IncludeReversals returns a vector<SocketPairKind> that returns all
+// SocketPairKinds in `vec` as well as all SocketPairKinds obtained by flipping
+// the file descriptors provided by the kinds in `vec`.
+std::vector<SocketPairKind> IncludeReversals(std::vector<SocketPairKind> vec);
+
+// A Middleware is a function that wraps a SocketPairKind.
+using Middleware = std::function<SocketPairKind(SocketPairKind)>;
+
+// SetSockOpt returns a Middleware that applies the given socket option to
+// both file descriptors of every SocketPair created by the wrapped
+// SocketPairKind.
+template <typename T>
+Middleware SetSockOpt(int level, int optname, T* value) {
+ return [=](SocketPairKind const& base) {
+ auto const& creator = base.creator;
+ return SocketPairKind{
+ absl::StrCat("setsockopt(", level, ", ", optname, ", ", *value, ") ",
+ base.description),
+ base.domain, base.type, base.protocol,
+ [creator, level, optname,
+ value]() -> PosixErrorOr<std::unique_ptr<SocketPair>> {
+ ASSIGN_OR_RETURN_ERRNO(auto creator_value, creator());
+ if (creator_value->first_fd() >= 0) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(setsockopt(
+ creator_value->first_fd(), level, optname, value, sizeof(T)));
+ }
+ if (creator_value->second_fd() >= 0) {
+ RETURN_ERROR_IF_SYSCALL_FAIL(setsockopt(
+ creator_value->second_fd(), level, optname, value, sizeof(T)));
+ }
+ return creator_value;
+ }};
+ };
+}
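+
+// A usage sketch (illustration; `base_kind` is a hypothetical name for any
+// existing SocketPairKind): enable SO_PASSCRED on both ends of every pair
+// the kind creates.
+//
+//   SocketPairKind creds_kind =
+//       SetSockOpt(SOL_SOCKET, SO_PASSCRED, &kSockOptOn)(base_kind);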
+
+constexpr int kSockOptOn = 1;
+constexpr int kSockOptOff = 0;
+
+// NoOp returns the same SocketPairKind that it is passed.
+SocketPairKind NoOp(SocketPairKind const& base);
+
+// TransferTest tests that data can be sent back and forth between two
+// specified FDs. Note that calls to this function should be wrapped in
+// ASSERT_NO_FATAL_FAILURE().
+void TransferTest(int fd1, int fd2);
+
+// Fills [buf, buf+len) with random bytes.
+void RandomizeBuffer(char* buf, size_t len);
+
+// Base test fixture for tests that operate on pairs of connected sockets.
+class SocketPairTest : public ::testing::TestWithParam<SocketPairKind> {
+ protected:
+ SocketPairTest() {
+ // gUnit uses printf, so we will too.
+ printf("Testing with %s\n", GetParam().description.c_str());
+ fflush(stdout);
+ }
+
+ PosixErrorOr<std::unique_ptr<SocketPair>> NewSocketPair() const {
+ return GetParam().Create();
+ }
+};
+
+// Base test fixture for tests that operate on simple Sockets.
+class SimpleSocketTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ SimpleSocketTest() {
+ // gUnit uses printf, so we will too.
+ printf("Testing with %s\n", GetParam().description.c_str());
+ }
+
+ PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
+ return GetParam().Create();
+ }
+};
+
+SocketKind SimpleSocket(int fam, int type, int proto);
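+
+// A sketch of a typical instantiation (illustration; the concrete test files
+// provide the real parameter lists):
+//
+//   INSTANTIATE_TEST_SUITE_P(
+//       UdpSockets, SimpleSocketTest,
+//       ::testing::Values(SimpleSocket(AF_INET, SOCK_DGRAM, 0)));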
+
+// Send a buffer of size 'size' to sockets->first_fd(), returning the result of
+// sendmsg.
+//
+// If reader, read from second_fd() until size bytes have been read.
+ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets,
+ size_t size, bool reader);
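+
+// A usage sketch (illustration only): push 1 MiB through a connected pair
+// while the helper's internal thread drains the peer.
+//
+//   auto sockets = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create());
+//   EXPECT_THAT(SendLargeSendMsg(sockets, 1 << 20, /*reader=*/true),
+//               SyscallSucceeds());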
+
+enum class AddressFamily { kIpv4 = 1, kIpv6 = 2, kDualStack = 3 };
+enum class SocketType { kUdp = 1, kTcp = 2 };
+
+// Returns a PosixError or a port that is available. If 0 is specified as the
+// port, it binds port 0 (allowing the kernel to select any free port).
+// Otherwise, it tries to bind the specified port and validates that it can be
+// used for the requested family and socket type. The final option, reuse_addr,
+// specifies whether SO_REUSEADDR should be applied before the bind(2) attempt;
+// with SO_REUSEADDR, sockets in TIME_WAIT state or other bound UDP sockets do
+// not cause bind(2) to fail. Set this option if subsequent calls to bind on
+// the returned port will also use SO_REUSEADDR.
+//
+// Note that this helper attempts to bind the ANY address for the respective
+// protocol.
+PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
+ bool reuse_addr);
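+
+// A usage sketch (illustration only): ask the kernel for any free TCP port.
+//
+//   int port = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable(
+//       0, AddressFamily::kIpv4, SocketType::kTcp, /*reuse_addr=*/false));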
+
+// FreeAvailablePort is used to return a port that was obtained by using
+// the PortAvailable helper with port 0.
+PosixError FreeAvailablePort(int port);
+
+// SendMsg converts a buffer to an iovec and adds it to msg before sending it.
+PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size);
+
+// RecvNoData checks that no data is receivable on sock.
+void RecvNoData(int sock);
+
+// Base test fixture for tests that apply to all kinds of pairs of connected
+// sockets.
+using AllSocketPairTest = SocketPairTest;
+
+struct TestAddress {
+ std::string description;
+ sockaddr_storage addr;
+ socklen_t addr_len;
+
+ int family() const { return addr.ss_family; }
+ explicit TestAddress(std::string description = "")
+ : description(std::move(description)), addr(), addr_len() {}
+};
+
+constexpr char kMulticastAddress[] = "224.0.2.1";
+constexpr char kBroadcastAddress[] = "255.255.255.255";
+
+TestAddress V4Any();
+TestAddress V4Broadcast();
+TestAddress V4Loopback();
+TestAddress V4MappedAny();
+TestAddress V4MappedLoopback();
+TestAddress V4Multicast();
+TestAddress V6Any();
+TestAddress V6Loopback();
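+
+// A usage sketch (illustration; `fd` and `port` are hypothetical): bind a
+// socket to the IPv4 loopback address.
+//
+//   TestAddress addr = V4Loopback();
+//   reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(port);
+//   ASSERT_THAT(
+//       bind(fd, reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len),
+//       SyscallSucceeds());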
+
+// Compute the internet checksum of an IP header.
+uint16_t IPChecksum(struct iphdr ip);
+
+// Compute the internet checksum of a UDP header.
+uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
+ const char* payload, ssize_t payload_len);
+
+// Compute the internet checksum of an ICMP header.
+uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
+ ssize_t payload_len);
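+
+// A usage sketch (illustration; `ip`, `udp`, `payload`, and `payload_len`
+// are assumed to be already populated): fill in a hand-built UDP packet's
+// checksum field.
+//
+//   udp.check = 0;
+//   udp.check = UDPChecksum(ip, udp, payload, payload_len);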
+
+namespace internal {
+PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
+ SocketType type, bool reuse_addr);
+} // namespace internal
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_
diff --git a/test/syscalls/linux/socket_test_util_impl.cc b/test/syscalls/linux/socket_test_util_impl.cc
new file mode 100644
index 000000000..ef661a0e3
--- /dev/null
+++ b/test/syscalls/linux/socket_test_util_impl.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
+ bool reuse_addr) {
+ return internal::TryPortAvailable(port, family, type, reuse_addr);
+}
+
+PosixError FreeAvailablePort(int port) { return NoError(); }
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc
new file mode 100644
index 000000000..591cab3fd
--- /dev/null
+++ b/test/syscalls/linux/socket_unix.cc
@@ -0,0 +1,274 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_unix.h"
+
+#include <errno.h>
+#include <net/if.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+// This file contains tests specific to Unix domain sockets. It does not contain
+// tests for UDS control messages. Those belong in socket_unix_cmsg.cc.
+//
+// This file is a generic socket test file. It must be built with another file
+// that provides the test types.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(UnixSocketPairTest, InvalidGetSockOpt) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int opt;
+ socklen_t optlen = sizeof(opt);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, -1, &opt, &optlen),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+TEST_P(UnixSocketPairTest, BindToBadName) {
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ constexpr char kBadName[] = "/some/path/that/does/not/exist";
+ sockaddr_un sockaddr;
+ sockaddr.sun_family = AF_LOCAL;
+ memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName));
+
+ EXPECT_THAT(
+ bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr),
+ sizeof(sockaddr)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(UnixSocketPairTest, BindToBadFamily) {
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ constexpr char kBadName[] = "/some/path/that/does/not/exist";
+ sockaddr_un sockaddr;
+ sockaddr.sun_family = AF_INET;
+ memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName));
+
+ EXPECT_THAT(
+ bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr),
+ sizeof(sockaddr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ char received_data[sizeof(sent_data) * 2];
+ std::vector<struct mmsghdr> msgs(2);
+ std::vector<struct iovec> iovs(msgs.size());
+ const int chunk_size = sizeof(received_data) / msgs.size();
+ for (size_t i = 0; i < msgs.size(); i++) {
+ iovs[i].iov_len = chunk_size;
+ iovs[i].iov_base = &received_data[i * chunk_size];
+ msgs[i].msg_hdr.msg_iov = &iovs[i];
+ msgs[i].msg_hdr.msg_iovlen = 1;
+ }
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ struct timespec timeout = {0, 1};
+ ASSERT_THAT(RetryEINTR(recvmmsg)(sockets->second_fd(), &msgs[0], msgs.size(),
+ 0, &timeout),
+ SyscallSucceedsWithValue(1));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ EXPECT_EQ(chunk_size, msgs[0].msg_len);
+}
+
+TEST_P(UnixSocketPairTest, TIOCINQSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ if (IsRunningOnGvisor()) {
+ // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCINQ.
+ // Skip the test.
+ int size = -1;
+ int ret = ioctl(sockets->first_fd(), TIOCINQ, &size);
+ SKIP_IF(ret == -1 && errno == ENOTTY);
+ }
+
+ int size = -1;
+ EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds());
+ EXPECT_EQ(size, 0);
+
+ const char some_data[] = "dangerzone";
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(some_data));
+
+ // Linux only reports the first message's size, which is wrong. We test for
+ // the behavior described in the man page.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(some_data) * 2);
+}
+
+TEST_P(UnixSocketPairTest, TIOCOUTQSucceeds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ if (IsRunningOnGvisor()) {
+ // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCOUTQ.
+ // Skip the test.
+ int size = -1;
+ int ret = ioctl(sockets->second_fd(), TIOCOUTQ, &size);
+ SKIP_IF(ret == -1 && errno == ENOTTY);
+ }
+
+ int size = -1;
+ EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds());
+ EXPECT_EQ(size, 0);
+
+ // Linux reports bogus numbers which are related to its internal allocations.
+ // We test for the behavior described in the man page.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ const char some_data[] = "dangerzone";
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(some_data));
+
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->second_fd(), &some_data, sizeof(some_data), 0),
+ SyscallSucceeds());
+ EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(some_data) * 2);
+}
+
+TEST_P(UnixSocketPairTest, NetdeviceIoctlsSucceed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Prepare the request.
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+
+ // Check that the ioctl either succeeds or fails with ENODEV.
+ int err = ioctl(sockets->first_fd(), SIOCGIFINDEX, &ifr);
+ if (err < 0) {
+ ASSERT_EQ(errno, ENODEV);
+ }
+}
+
+TEST_P(UnixSocketPairTest, Shutdown) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ const std::string data = "abc";
+ ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(data.size()));
+
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds());
+
+ // Shutting down a socket does not clear the buffer.
+ char buf[3];
+ ASSERT_THAT(ReadFd(sockets->second_fd(), buf, data.size()),
+ SyscallSucceedsWithValue(data.size()));
+ EXPECT_EQ(data, absl::string_view(buf, data.size()));
+}
+
+TEST_P(UnixSocketPairTest, ShutdownRead) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RD), SyscallSucceeds());
+
+ // When the socket is shut down for reading, read behavior varies between
+ // different socket types. This is covered by the various ReadOneSideClosed
+ // test cases.
+
+ // ... and the peer cannot write.
+ const std::string data = "abc";
+ EXPECT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()),
+ SyscallFailsWithErrno(EPIPE));
+
+ // ... but the socket can still write.
+ ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(data.size()));
+
+ // ... and the peer can still read.
+ char buf[3];
+ EXPECT_THAT(ReadFd(sockets->second_fd(), buf, data.size()),
+ SyscallSucceedsWithValue(data.size()));
+ EXPECT_EQ(data, absl::string_view(buf, data.size()));
+}
+
+TEST_P(UnixSocketPairTest, ShutdownWrite) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_WR), SyscallSucceeds());
+
+ // When the socket is shut down for writing, it cannot write.
+ const std::string data = "abc";
+ EXPECT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()),
+ SyscallFailsWithErrno(EPIPE));
+
+ // ... and the peer read behavior varies between different socket types. This
+ // is covered by the various ReadOneSideClosed test cases.
+
+ // ... but the peer can still write.
+ char buf[3];
+ ASSERT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(data.size()));
+
+ // ... and the socket can still read.
+ EXPECT_THAT(ReadFd(sockets->first_fd(), buf, data.size()),
+ SyscallSucceedsWithValue(data.size()));
+ EXPECT_EQ(data, absl::string_view(buf, data.size()));
+}
+
+TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) {
+ // TODO(gvisor.dev/issue/1624): In VFS1, we return EIO instead of ENXIO (see
+ // b/122310852). Remove this skip once VFS1 is deleted.
+ SKIP_IF(IsRunningWithVFS1());
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Opening either end of a socket pair via /proc/self/fd/X fails with ENXIO.
+ for (const int fd : {sockets->first_fd(), sockets->second_fd()}) {
+ ASSERT_THAT(Open(absl::StrCat("/proc/self/fd/", fd), O_WRONLY),
+ PosixErrorIs(ENXIO, ::testing::_));
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix.h b/test/syscalls/linux/socket_unix.h
new file mode 100644
index 000000000..3625cc404
--- /dev/null
+++ b/test/syscalls/linux/socket_unix.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected unix sockets.
+using UnixSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_H_
diff --git a/test/syscalls/linux/socket_unix_abstract_nonblock.cc b/test/syscalls/linux/socket_unix_abstract_nonblock.cc
new file mode 100644
index 000000000..8bef76b67
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_abstract_nonblock.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_non_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingAbstractUnixSockets, NonBlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc
new file mode 100644
index 000000000..77cb8c6d6
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_blocking_local.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ std::vector<int>{SOCK_STREAM, SOCK_SEQPACKET, SOCK_DGRAM}),
+ ApplyVec<SocketPairKind>(
+ FilesystemBoundUnixDomainSocketPair,
+ std::vector<int>{SOCK_STREAM, SOCK_SEQPACKET, SOCK_DGRAM}),
+ ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ std::vector<int>{SOCK_STREAM, SOCK_SEQPACKET, SOCK_DGRAM}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingUnixDomainSockets, BlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc
new file mode 100644
index 000000000..a16899493
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_cmsg.cc
@@ -0,0 +1,1501 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_unix_cmsg.h"
+
+#include <errno.h>
+#include <net/if.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+// This file contains tests for control message in Unix domain sockets.
+//
+// This file is a generic socket test file. It must be built with another file
+// that provides the test types.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(UnixSocketPairCmsgTest, BasicFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data,
+ sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+}
+
+TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair1 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ auto pair2 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ int sent_fds[] = {pair1->second_fd(), pair2->second_fd()};
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ int received_fds[] = {-1, -1};
+
+ ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 2,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd()));
+ ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd()));
+}
+
+TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair1 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ auto pair2 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ auto pair3 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()};
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ int received_fds[] = {-1, -1, -1};
+
+ ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 3,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd()));
+ ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd()));
+ ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[2], pair3->first_fd()));
+}
+
+TEST_P(UnixSocketPairCmsgTest, BadFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ int sent_fd = -1;
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(sent_fd))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_len = CMSG_LEN(sizeof(sent_fd));
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd));
+
+ struct iovec iov;
+ iov.iov_base = sent_data;
+ iov.iov_len = sizeof(sent_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST_P(UnixSocketPairCmsgTest, ShortCmsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ int sent_fd = -1;
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(sent_fd))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_len = 1;
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd));
+
+ struct iovec iov;
+ iov.iov_base = sent_data;
+ iov.iov_len = sizeof(sent_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass.
+// The difference is that when calling recvmsg, no space for FDs is provided,
+// only space for the cmsg header.
+TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpace) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(0));
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_controllen, 0);
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+// BasicFDPassNoSpaceMsgCtrunc sends an FD, but does not provide any space to
+// receive it. It then verifies that the MSG_CTRUNC flag is set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpaceMsgCtrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(0));
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ char received_data[sizeof(sent_data)];
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_controllen, 0);
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+}
+
+// BasicFDPassNullControlMsgCtrunc sends an FD and sets contradictory values for
+// msg_controllen and msg_control. msg_controllen is set to the correct size to
+// accommodate the FD, but msg_control is set to NULL. In this case, msg_control
+// should override msg_controllen.
+TEST_P(UnixSocketPairCmsgTest, BasicFDPassNullControlMsgCtrunc) {
+ // FIXME(gvisor.dev/issue/207): Fix handling of NULL msg_control.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ msg.msg_controllen = CMSG_SPACE(1);
+
+ char received_data[sizeof(sent_data)];
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_controllen, 0);
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+}
+
+// BasicFDPassNotEnoughSpaceMsgCtrunc sends an FD, but does not provide enough
+// space to receive it. It then verifies that the MSG_CTRUNC flag is set in the
+// msghdr.
+TEST_P(UnixSocketPairCmsgTest, BasicFDPassNotEnoughSpaceMsgCtrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(0) + 1);
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ char received_data[sizeof(sent_data)];
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_controllen, 0);
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+}
+
+// BasicThreeFDPassTruncationMsgCtrunc sends three FDs, but only provides enough
+// space to receive two of them. It then verifies that the MSG_CTRUNC flag is
+// set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPassTruncationMsgCtrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair1 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ auto pair2 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ auto pair3 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()};
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(2 * sizeof(int)));
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ char received_data[sizeof(sent_data)];
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(2 * sizeof(int)));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+}
+
+// BasicFDPassUnalignedRecv starts off by sending a single FD just like
+// BasicFDPass. The difference is that when calling recvmsg, the length of the
+// receive data is only aligned on a 4-byte boundary instead of the normal 8.
+TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFDUnaligned(
+ sockets->second_fd(), &fd, received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+}
+
+// BasicFDPassUnalignedRecvNoMsgTrunc sends one FD and provides exactly enough
+// space to receive it and no more. (Normally the minimum amount of space one
+// would provide would be rounded up to hold two FDs.) It then verifies that the
+// MSG_CTRUNC flag is not set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecvNoMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int)) - sizeof(int)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[sizeof(sent_data)] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_flags, 0);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+}
+
+// BasicTwoFDPassUnalignedRecvTruncationMsgTrunc sends two FDs, but only
+// provides enough space to receive one of them. It then verifies that the
+// MSG_CTRUNC flag is set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPassUnalignedRecvTruncationMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ int sent_fds[] = {pair->first_fd(), pair->second_fd()};
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ // CMSG_SPACE rounds up to two FDs; we only want space for one.
+ char control[CMSG_SPACE(sizeof(int)) - sizeof(int)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[sizeof(sent_data)] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+}
+
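+// ConcurrentBasicFDPass passes an FD, then has a second thread receive it and
+// echo data over it while the main thread writes to and reads from the other
+// end of the transferred socket pair.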
+TEST_P(UnixSocketPairCmsgTest, ConcurrentBasicFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ int sockfd1 = sockets->first_fd();
+ auto recv_func = [sockfd1, sent_data]() {
+ char received_data[20];
+ int fd = -1;
+ RecvSingleFD(sockfd1, &fd, received_data, sizeof(received_data));
+ ASSERT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+ char buf[20];
+ ASSERT_THAT(ReadFd(fd, buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(WriteFd(fd, buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ };
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->second_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ ScopedThread t(recv_func);
+
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(WriteFd(pair->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[20];
+ ASSERT_THAT(ReadFd(pair->first_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ t.Join();
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+// FDPassNoRecv checks that the control message can be safely ignored by using
+// read(2) instead of recvmsg(2).
+TEST_P(UnixSocketPairCmsgTest, FDPassNoRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ // Read while ignoring the passed FD.
+ char received_data[20];
+ ASSERT_THAT(
+ ReadFd(sockets->second_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // Check that the socket still works for reads and writes.
+ ASSERT_NO_FATAL_FAILURE(
+ TransferTest(sockets->first_fd(), sockets->second_fd()));
+}
+
+// FDPassInterspersed1 checks that sent control messages cannot be read before
+// their associated data has been read.
+TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed1) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char written_data[20];
+ RandomizeBuffer(written_data, sizeof(written_data));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)),
+ SyscallSucceedsWithValue(sizeof(written_data)));
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ // Check that we don't get a control message, but do get the data.
+ char received_data[20];
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data));
+ EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data)));
+}
+
+// FDPassInterspersed2 checks that sent control messages cannot be read after
+// their associated data has already been read with the control message
+// ignored, i.e. by using read(2) instead of recvmsg(2).
+TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed2) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char written_data[20];
+ RandomizeBuffer(written_data, sizeof(written_data));
+ ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)),
+ SyscallSucceedsWithValue(sizeof(written_data)));
+
+ char received_data[20];
+ ASSERT_THAT(
+ ReadFd(sockets->second_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data)));
+}
+
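+// FDPassNotCoalesced checks that two sends, each carrying a passed FD, are
+// not coalesced into a single receive: each recvmsg returns one message's
+// data and its FD.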
+TEST_P(UnixSocketPairCmsgTest, FDPassNotCoalesced) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ auto pair1 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(),
+ sent_data1, sizeof(sent_data1)));
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ auto pair2 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(),
+ sent_data2, sizeof(sent_data2)));
+
+ char received_data1[sizeof(sent_data1) + sizeof(sent_data2)];
+ int received_fd1 = -1;
+
+ RecvSingleFD(sockets->second_fd(), &received_fd1, received_data1,
+ sizeof(received_data1), sizeof(sent_data1));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1)));
+ TransferTest(pair1->first_fd(), pair1->second_fd());
+
+ char received_data2[sizeof(sent_data1) + sizeof(sent_data2)];
+ int received_fd2 = -1;
+
+ RecvSingleFD(sockets->second_fd(), &received_fd2, received_data2,
+ sizeof(received_data2), sizeof(sent_data2));
+
+ EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2)));
+ TransferTest(pair2->first_fd(), pair2->second_fd());
+}
+
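+// FDPassPeek checks that peeking with MSG_PEEK yields a working duplicate of
+// the passed FD while leaving the message, FD included, queued for the
+// subsequent receive.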
+TEST_P(UnixSocketPairCmsgTest, FDPassPeek) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char peek_data[20];
+ int peek_fd = -1;
+ PeekSingleFD(sockets->second_fd(), &peek_fd, peek_data, sizeof(peek_data));
+ EXPECT_EQ(0, memcmp(sent_data, peek_data, sizeof(sent_data)));
+ TransferTest(peek_fd, pair->first_fd());
+ EXPECT_THAT(close(peek_fd), SyscallSucceeds());
+
+ char received_data[20];
+ int received_fd = -1;
+ RecvSingleFD(sockets->second_fd(), &received_fd, received_data,
+ sizeof(received_data));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+ TransferTest(received_fd, pair->first_fd());
+ EXPECT_THAT(close(received_fd), SyscallSucceeds());
+}
+
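+// BasicCredPass sends the sender's own credentials via SCM_CREDENTIALS and
+// verifies that a receiver with SO_PASSCRED enabled sees them unchanged.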
+TEST_P(UnixSocketPairCmsgTest, BasicCredPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+ EXPECT_EQ(sent_creds.pid, received_creds.pid);
+ EXPECT_EQ(sent_creds.uid, received_creds.uid);
+ EXPECT_EQ(sent_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredRecvEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
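+ // Messages queued before the receiver enabled SO_PASSCRED carry no
+ // credentials; Linux reports pid 0 and the overflow (nobody) UID/GID,
+ // 65534, for them.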
+ struct ucred want_creds {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredRecvEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ SetSoPassCred(sockets->second_fd());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredSendEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->first_fd());
+
+ char received_data[20];
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredSendEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ SetSoPassCred(sockets->first_fd());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(UnixSocketPairCmsgTest,
+ SendNullCredsBeforeSoPassCredRecvEndAfterSendEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ SetSoPassCred(sockets->first_fd());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredRecvEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[20];
+
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredSendEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ SetSoPassCred(sockets->first_fd());
+
+ char received_data[20];
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredSendEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->first_fd());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[20];
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEndAfterSendEnd) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ SetSoPassCred(sockets->first_fd());
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
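+// CredPassTruncated passes a full set of credentials but provides control
+// space for only the cmsg header plus the pid field; the credentials come
+// back truncated to just the pid.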
+TEST_P(UnixSocketPairCmsgTest, CredPassTruncated) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(0) + sizeof(pid_t)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[sizeof(sent_data)] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ EXPECT_EQ(msg.msg_controllen, sizeof(control));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, sizeof(control));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+
+ pid_t pid = 0;
+ memcpy(&pid, CMSG_DATA(cmsg), sizeof(pid));
+ EXPECT_EQ(pid, sent_creds.pid);
+}
+
+// CredPassNoMsgCtrunc passes a full set of credentials. It then verifies that
+// receiving the full set does not result in MSG_CTRUNC being set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, CredPassNoMsgCtrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(struct ucred))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[sizeof(sent_data)] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // The control message should not be truncated.
+ EXPECT_EQ(msg.msg_flags, 0);
+ EXPECT_EQ(msg.msg_controllen, sizeof(control));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred)));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+}
+
+// CredPassNoSpaceMsgCtrunc passes a full set of credentials. It then receives
+// the data without providing space for any credentials and verifies that
+// MSG_CTRUNC is set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, CredPassNoSpaceMsgCtrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(0)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[sizeof(sent_data)] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // The control message should be truncated.
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+ EXPECT_EQ(msg.msg_controllen, sizeof(control));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, sizeof(control));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+}
+
+// CredPassTruncatedMsgCtrunc passes a full set of credentials. It then receives
+// the data while providing enough space for only the first field of the
+// credentials and verifies that MSG_CTRUNC is set in the msghdr.
+TEST_P(UnixSocketPairCmsgTest, CredPassTruncatedMsgCtrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(0) + sizeof(pid_t)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[sizeof(sent_data)] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // The control message should be truncated.
+ EXPECT_EQ(msg.msg_flags, MSG_CTRUNC);
+ EXPECT_EQ(msg.msg_controllen, sizeof(control));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, sizeof(control));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+}
+
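+// SoPassCred checks that SO_PASSCRED can be queried and toggled via
+// getsockopt/setsockopt, and that the setting applies per endpoint rather
+// than to the whole pair.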
+TEST_P(UnixSocketPairCmsgTest, SoPassCred) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int opt;
+ socklen_t optLen = sizeof(opt);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen),
+ SyscallSucceeds());
+ EXPECT_FALSE(opt);
+
+ optLen = sizeof(opt);
+ EXPECT_THAT(
+ getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen),
+ SyscallSucceeds());
+ EXPECT_FALSE(opt);
+
+ SetSoPassCred(sockets->first_fd());
+
+ optLen = sizeof(opt);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen),
+ SyscallSucceeds());
+ EXPECT_TRUE(opt);
+
+ optLen = sizeof(opt);
+ EXPECT_THAT(
+ getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen),
+ SyscallSucceeds());
+ EXPECT_FALSE(opt);
+
+ int zero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &zero,
+ sizeof(zero)),
+ SyscallSucceeds());
+
+ optLen = sizeof(opt);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen),
+ SyscallSucceeds());
+ EXPECT_FALSE(opt);
+
+ optLen = sizeof(opt);
+ EXPECT_THAT(
+ getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen),
+ SyscallSucceeds());
+ EXPECT_FALSE(opt);
+}
+
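+// NoDataCredPass checks that sending an SCM_CREDENTIALS control message with
+// no room for a struct ucred payload fails with EINVAL.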
+TEST_P(UnixSocketPairCmsgTest, NoDataCredPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct msghdr msg = {};
+
+ struct iovec iov;
+ iov.iov_base = sent_data;
+ iov.iov_len = sizeof(sent_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ char control[CMSG_SPACE(0)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_CREDENTIALS;
+ cmsg->cmsg_len = CMSG_LEN(0);
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(UnixSocketPairCmsgTest, NoPassCred) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
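+// CredAndFDPass sends credentials and an FD in a single message and verifies
+// that both control messages are delivered together.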
+TEST_P(UnixSocketPairCmsgTest, CredAndFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendCredsAndFD(sockets->first_fd(), sent_creds,
+ pair->second_fd(), sent_data,
+ sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+ struct ucred received_creds;
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds,
+ &fd, received_data,
+ sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ EXPECT_EQ(sent_creds.pid, received_creds.pid);
+ EXPECT_EQ(sent_creds.uid, received_creds.uid);
+ EXPECT_EQ(sent_creds.gid, received_creds.gid);
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+}
+
+TEST_P(UnixSocketPairCmsgTest, FDPassBeforeSoPassCred) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[20];
+ struct ucred received_creds;
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds,
+ &fd, received_data,
+ sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+}
+
+TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCred) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ SetSoPassCred(sockets->second_fd());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ struct ucred received_creds;
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds,
+ &fd, received_data,
+ sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+}
+
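+// CloexecDroppedWhenFDPassed checks that close-on-exec is a property of the
+// descriptor rather than the file: an FD sent from a SOCK_CLOEXEC socket
+// arrives without FD_CLOEXEC set.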
+TEST_P(UnixSocketPairCmsgTest, CloexecDroppedWhenFDPassed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair = ASSERT_NO_ERRNO_AND_VALUE(
+ UnixDomainSocketPair(SOCK_SEQPACKET | SOCK_CLOEXEC).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char received_data[20];
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data,
+ sizeof(received_data)));
+
+ EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(0));
+}
+
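+// CloexecRecvFDPass checks that passing MSG_CMSG_CLOEXEC to recvmsg sets
+// FD_CLOEXEC on the received FD.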
+TEST_P(UnixSocketPairCmsgTest, CloexecRecvFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ char received_data[20];
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CMSG_CLOEXEC),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+
+ int fd = -1;
+ memcpy(&fd, CMSG_DATA(cmsg), sizeof(int));
+
+ EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
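+// FDPassAfterSoPassCredWithoutCredSpace checks that with SO_PASSCRED enabled
+// and control space for only a bare cmsg header, the credentials message
+// takes precedence over the passed FD and is returned truncated to its
+// header.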
+TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredSpace) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ SetSoPassCred(sockets->second_fd());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
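+ // Enough space for a cmsg header, but none for any credential payload.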
+ char control[CMSG_LEN(0)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[20];
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ EXPECT_EQ(msg.msg_controllen, sizeof(control));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, sizeof(control));
+ EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+}
+
+// MsgCtruncInputIsNoop validates that passing MSG_CTRUNC to recvmsg as an
+// input flag is a no-op: it does not reappear as an output flag in msg_flags
+// when no control message truncation happens.
+TEST_P(UnixSocketPairCmsgTest, MsgCtruncInputIsNoop) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int)) /* we're passing a single fd */];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ char received_data[20];
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+
+ // Verify that MSG_CTRUNC was not set as an output flag.
+ EXPECT_EQ(msg.msg_flags & MSG_CTRUNC, 0);
+}
+
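+// FDPassAfterSoPassCredWithoutCredHeaderSpace checks that when the control
+// buffer cannot hold even a cmsg header, recvmsg returns the data with
+// msg_controllen set to 0.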
+TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredHeaderSpace) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ SetSoPassCred(sockets->second_fd());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ struct msghdr msg = {};
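+ // Half a cmsg header: too small to hold even struct cmsghdr.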
+ char control[CMSG_LEN(0) / 2];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ char received_data[20];
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+ EXPECT_EQ(msg.msg_controllen, 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_cmsg.h b/test/syscalls/linux/socket_unix_cmsg.h
new file mode 100644
index 000000000..431606903
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_cmsg.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for control message tests that apply to pairs of connected
+// unix sockets.
+using UnixSocketPairCmsgTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_
diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc
new file mode 100644
index 000000000..af0df4fb4
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_dgram.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_unix_dgram.h"
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(DgramUnixSocketPairTest, WriteOneSideClosed) {
+ // FIXME(b/35925052): gVisor datagram sockets return EPIPE instead of
+ // ECONNREFUSED.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->second_fd(), kStr, 3),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_dgram.h b/test/syscalls/linux/socket_unix_dgram.h
new file mode 100644
index 000000000..0764ef85b
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_dgram.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_DGRAM_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_DGRAM_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected dgram unix sockets.
+using DgramUnixSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_DGRAM_H_
diff --git a/test/syscalls/linux/socket_unix_dgram_local.cc b/test/syscalls/linux/socket_unix_dgram_local.cc
new file mode 100644
index 000000000..31d2d5216
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_dgram_local.cc
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_non_stream.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/socket_unix_dgram.h"
+#include "test/syscalls/linux/socket_unix_non_stream.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM, SOCK_RAW},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM, SOCK_RAW},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM, SOCK_RAW},
+ List<int>{0, SOCK_NONBLOCK}))));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ DgramUnixSockets, DgramUnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ DgramUnixSockets, UnixNonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ DgramUnixSockets, NonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc
new file mode 100644
index 000000000..2db8b68d3
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc
@@ -0,0 +1,57 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of connected non-blocking dgram
+// unix sockets.
+using NonBlockingDgramUnixSocketPairTest = SocketPairTest;
+
+TEST_P(NonBlockingDgramUnixSocketPairTest, ReadOneSideClosed) {
+ if (IsRunningOnGvisor()) {
+ // FIXME(b/70803293): gVisor datagram sockets return 0 instead of
+ // EAGAIN.
+ return;
+ }
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ char data[10] = {};
+ ASSERT_THAT(read(sockets->second_fd(), data, sizeof(data)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingDgramUnixSockets, NonBlockingDgramUnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(std::vector<SocketPairKind>{
+ UnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK),
+ FilesystemBoundUnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK),
+ AbstractBoundUnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK),
+ })));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_domain.cc b/test/syscalls/linux/socket_unix_domain.cc
new file mode 100644
index 000000000..f7dff8b4d
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_domain.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_generic.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, AllSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
new file mode 100644
index 000000000..6700b4d90
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_non_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVec<SocketPairKind>(
+ FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingFilesystemUnixSockets, NonBlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
new file mode 100644
index 000000000..884319e1d
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -0,0 +1,260 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_unix_non_stream.h"
+
+#include <stdio.h>
+#include <sys/mman.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
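+// RecvMsgTooLarge checks that a recv with an iov larger than the socket's
+// receive buffer returns only the data that was actually written.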
+TEST_P(UnixNonStreamSocketPairTest, RecvMsgTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int rcvbuf;
+ socklen_t length = sizeof(rcvbuf);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &rcvbuf, &length),
+ SyscallSucceeds());
+
+ // Make the call larger than the receive buffer.
+ const int recv_size = 3 * rcvbuf;
+
+ // Write a message that does fit in the receive buffer.
+ const int write_size = rcvbuf - kPageSize;
+
+ std::vector<char> write_buf(write_size, 'a');
+ const int ret = RetryEINTR(write)(sockets->second_fd(), write_buf.data(),
+ write_buf.size());
+ if (ret < 0 && errno == ENOBUFS) {
+ // NOTE(b/116636318): Linux may stall the write for a long time and
+ // ultimately return ENOBUFS. Allow this error, since a retry will likely
+ // result in the same error.
+ return;
+ }
+ ASSERT_THAT(ret, SyscallSucceeds());
+
+ std::vector<char> recv_buf(recv_size);
+
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sockets->first_fd(), recv_buf.data(),
+ recv_buf.size(), write_size));
+
+ recv_buf.resize(write_size);
+ EXPECT_EQ(recv_buf, write_buf);
+}
+
+// Create a region of anonymous memory of size 'size', which is fragmented in
+// FileMem.
+//
+// ptr contains the start address of the region. The returned vector contains
+// all of the mappings to be unmapped when done.
+PosixErrorOr<std::vector<Mapping>> CreateFragmentedRegion(const int size,
+ void** ptr) {
+ Mapping region;
+ ASSIGN_OR_RETURN_ERRNO(region, Mmap(nullptr, size, PROT_NONE,
+ MAP_ANONYMOUS | MAP_PRIVATE, -1, 0));
+
+ *ptr = region.ptr();
+
+ // Don't save hundreds of times for all of these mmaps.
+ DisableSave ds;
+
+ std::vector<Mapping> pages;
+
+ // Map and commit a single page at a time, mapping and committing an unrelated
+ // page between each call to force FileMem fragmentation.
+ for (uintptr_t addr = region.addr(); addr < region.endaddr();
+ addr += kPageSize) {
+ Mapping page;
+ ASSIGN_OR_RETURN_ERRNO(
+ page,
+ Mmap(reinterpret_cast<void*>(addr), kPageSize, PROT_READ | PROT_WRITE,
+ MAP_ANONYMOUS | MAP_PRIVATE | MAP_FIXED, -1, 0));
+ *reinterpret_cast<volatile char*>(page.ptr()) = 42;
+
+ pages.emplace_back(std::move(page));
+
+ // Unrelated page elsewhere.
+ ASSIGN_OR_RETURN_ERRNO(page,
+ Mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_ANONYMOUS | MAP_PRIVATE, -1, 0));
+ *reinterpret_cast<volatile char*>(page.ptr()) = 42;
+
+ pages.emplace_back(std::move(page));
+ }
+
+ // The mappings above have taken ownership of the region.
+ region.release();
+
+ return std::move(pages);
+}
+
+// A contiguous iov that is heavily fragmented in FileMem can still be sent
+// successfully. See b/115833655.
+TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ const int buffer_size = UIO_MAXIOV * kPageSize;
+ // Extra page for message header overhead.
+ const int sndbuf = buffer_size + kPageSize;
+ // N.B. setsockopt(SO_SNDBUF) doubles the passed value.
+ const int set_sndbuf = sndbuf / 2;
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
+ &set_sndbuf, sizeof(set_sndbuf)),
+ SyscallSucceeds());
+
+ int actual_sndbuf = 0;
+ socklen_t length = sizeof(actual_sndbuf);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
+ &actual_sndbuf, &length),
+ SyscallSucceeds());
+
+ if (actual_sndbuf != sndbuf) {
+ // Unable to get the sndbuf we want.
+ //
+ // N.B. At minimum, the socketpair gofer should provide a socket that is
+ // already the correct size.
+ //
+ // TODO(b/35921550): When internal UDS support SO_SNDBUF, we can assert that
+ // we always get the right SO_SNDBUF on gVisor.
+ GTEST_SKIP() << "SO_SNDBUF = " << actual_sndbuf << ", want " << sndbuf;
+ }
+
+ // Create a contiguous region of memory of UIO_MAXIOV * PAGE_SIZE. We'll
+ // call sendmsg with a single iov, but the goal is to get the sentry to
+ // split this into > UIO_MAXIOV iovs when calling the kernel.
+ void* ptr;
+ std::vector<Mapping> pages =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateFragmentedRegion(buffer_size, &ptr));
+
+ struct iovec iov = {};
+ iov.iov_base = ptr;
+ iov.iov_len = buffer_size;
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // NOTE(b/116636318,b/115833655): Linux has poor behavior in the presence of
+ // physical memory fragmentation. As a result, this may stall for a long time
+ // and ultimately return ENOBUFS. Allow this error, since it means that we
+ // made it to the host kernel and started the sendmsg.
+ EXPECT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0),
+ AnyOf(SyscallSucceedsWithValue(buffer_size),
+ SyscallFailsWithErrno(ENOBUFS)));
+}
+
+// A contiguous iov that is heavily fragmented in FileMem can still be received
+// into successfully. Regression test for b/115833655.
+TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ const int buffer_size = UIO_MAXIOV * kPageSize;
+ // Extra page for message header overhead.
+ const int sndbuf = buffer_size + kPageSize;
+ // N.B. setsockopt(SO_SNDBUF) doubles the passed value.
+ const int set_sndbuf = sndbuf / 2;
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
+ &set_sndbuf, sizeof(set_sndbuf)),
+ SyscallSucceeds());
+
+ int actual_sndbuf = 0;
+ socklen_t length = sizeof(actual_sndbuf);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
+ &actual_sndbuf, &length),
+ SyscallSucceeds());
+
+ if (actual_sndbuf != sndbuf) {
+ // Unable to get the sndbuf we want.
+ //
+ // N.B. At minimum, the socketpair gofer should provide a socket that is
+ // already the correct size.
+ //
+ // TODO(b/35921550): When internal UDS support SO_SNDBUF, we can assert that
+ // we always get the right SO_SNDBUF on gVisor.
+ GTEST_SKIP() << "SO_SNDBUF = " << actual_sndbuf << ", want " << sndbuf;
+ }
+
+ std::vector<char> write_buf(buffer_size, 'a');
+ const int ret = RetryEINTR(write)(sockets->first_fd(), write_buf.data(),
+ write_buf.size());
+ if (ret < 0 && errno == ENOBUFS) {
+ // NOTE(b/116636318): Linux may stall the write for a long time and
+ // ultimately return ENOBUFS. Allow this error, since a retry will likely
+ // result in the same error.
+ return;
+ }
+ ASSERT_THAT(ret, SyscallSucceeds());
+
+ // Create a contiguous region of memory of UIO_MAXIOV * PAGE_SIZE. We'll
+ // call recvmsg with a single iov, but the goal is to get the sentry to
+ // split this into > UIO_MAXIOV iovs when copying the data into the region.
+ void* ptr;
+ std::vector<Mapping> pages =
+ ASSERT_NO_ERRNO_AND_VALUE(CreateFragmentedRegion(buffer_size, &ptr));
+
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(
+ sockets->second_fd(), reinterpret_cast<char*>(ptr), buffer_size));
+
+ EXPECT_EQ(0, memcmp(write_buf.data(), ptr, buffer_size));
+}
+
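+// SendTimeout checks that with SO_SNDTIMEO set, sending into a full buffer
+// eventually fails with EAGAIN rather than blocking indefinitely.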
+TEST_P(UnixNonStreamSocketPairTest, SendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ const int buf_size = 5 * kPageSize;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &buf_size,
+ sizeof(buf_size)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVBUF, &buf_size,
+ sizeof(buf_size)),
+ SyscallSucceeds());
+
+ // The buffer size should be big enough to avoid many iterations in the next
+ // loop. Otherwise, this will slow down cooperative_save tests.
+ std::vector<char> buf(kPageSize);
+ for (;;) {
+ int ret;
+ ASSERT_THAT(
+ ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0),
+ ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN)));
+ if (ret == -1) {
+ break;
+ }
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_non_stream.h b/test/syscalls/linux/socket_unix_non_stream.h
new file mode 100644
index 000000000..7478ab172
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_non_stream.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_NON_STREAM_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_NON_STREAM_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected non-stream
+// unix-domain sockets.
+using UnixNonStreamSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_NON_STREAM_H_
diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
new file mode 100644
index 000000000..fddcdf1c5
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_non_stream_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(UnixDomainSocketPair,
+ std::vector<int>{SOCK_DGRAM, SOCK_SEQPACKET}),
+ ApplyVec<SocketPairKind>(FilesystemBoundUnixDomainSocketPair,
+ std::vector<int>{SOCK_DGRAM, SOCK_SEQPACKET}),
+ ApplyVec<SocketPairKind>(AbstractBoundUnixDomainSocketPair,
+ std::vector<int>{SOCK_DGRAM, SOCK_SEQPACKET}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BlockingNonStreamUnixSockets, BlockingNonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc
new file mode 100644
index 000000000..85999db04
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_pair.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/socket_unix.h"
+#include "test/syscalls/linux/socket_unix_cmsg.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnixSocketPairCmsgTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_pair_nonblock.cc b/test/syscalls/linux/socket_unix_pair_nonblock.cc
new file mode 100644
index 000000000..281410a9a
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_pair_nonblock.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_non_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET},
+ List<int>{SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingUnixSockets, NonBlockingSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc
new file mode 100644
index 000000000..6d03df4d9
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_seqpacket.cc
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_unix_seqpacket.h"
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST_P(SeqpacketUnixSocketPairTest, WriteOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->second_fd(), kStr, 3),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(SeqpacketUnixSocketPairTest, ReadOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ char data[10] = {};
+ ASSERT_THAT(read(sockets->second_fd(), data, sizeof(data)),
+ SyscallSucceedsWithValue(0));
+}
+
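+// Sendto checks that a connected seqpacket socket ignores the destination
+// address passed to sendto(2) and delivers the data to its connected peer.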
+TEST_P(SeqpacketUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr,
+ sizeof(addr)),
+ SyscallSucceedsWithValue(3));
+
+ char data[10] = {};
+ ASSERT_THAT(read(sockets->first_fd(), data, sizeof(data)),
+ SyscallSucceedsWithValue(3));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_seqpacket.h b/test/syscalls/linux/socket_unix_seqpacket.h
new file mode 100644
index 000000000..30d9b9edf
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_seqpacket.h
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_SEQPACKET_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_SEQPACKET_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected seqpacket unix
+// sockets.
+using SeqpacketUnixSocketPairTest = SocketPairTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_SEQPACKET_H_
diff --git a/test/syscalls/linux/socket_unix_seqpacket_local.cc b/test/syscalls/linux/socket_unix_seqpacket_local.cc
new file mode 100644
index 000000000..69a5f150d
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_seqpacket_local.cc
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_non_stream.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/socket_unix_non_stream.h"
+#include "test/syscalls/linux/socket_unix_seqpacket.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK}))));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SeqpacketUnixSockets, NonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ SeqpacketUnixSockets, SeqpacketUnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+INSTANTIATE_TEST_SUITE_P(
+ SeqpacketUnixSockets, UnixNonStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
new file mode 100644
index 000000000..99e77b89e
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -0,0 +1,125 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <poll.h>
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of connected stream unix sockets.
+using StreamUnixSocketPairTest = SocketPairTest;
+
+TEST_P(StreamUnixSocketPairTest, WriteOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->second_fd(), kStr, 3),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(StreamUnixSocketPairTest, ReadOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ char data[10] = {};
+ ASSERT_THAT(read(sockets->second_fd(), data, sizeof(data)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(StreamUnixSocketPairTest, RecvmsgOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Set a receive timeout so that the recvmsg() below does not wait forever.
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv,
+ sizeof(tv)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+
+ char received_data[10] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
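+ // Fill msg_flags with garbage; it is an output field and its initial value
+ // should be ignored by recvmsg(2).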
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(recvmsg(sockets->second_fd(), &msg, MSG_WAITALL),
+ SyscallSucceedsWithValue(0));
+}
+
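+// ReadOneSideClosedWithUnreadData checks that closing a socket that still has
+// unread data in its receive queue resets the connection: subsequent reads on
+// the peer fail with ECONNRESET.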
+TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ECONNRESET));
+}
+
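+// Sendto checks that specifying a destination address on a connected stream
+// socket fails with EISCONN.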
+TEST_P(StreamUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr,
+ sizeof(addr)),
+ SyscallFailsWithErrno(EISCONN));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, StreamUnixSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK}))))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc
new file mode 100644
index 000000000..8429bd429
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_stream_blocking_local.cc
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_stream_blocking.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return {
+ UnixDomainSocketPair(SOCK_STREAM),
+ FilesystemBoundUnixDomainSocketPair(SOCK_STREAM),
+ AbstractBoundUnixDomainSocketPair(SOCK_STREAM),
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BlockingStreamUnixSockets, BlockingStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_local.cc b/test/syscalls/linux/socket_unix_stream_local.cc
new file mode 100644
index 000000000..a7e3449a9
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_stream_local.cc
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_stream.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK})));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ StreamUnixSockets, StreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
new file mode 100644
index 000000000..4b763c8e2
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <vector>
+
+#include "test/syscalls/linux/socket_stream_nonblock.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketPairKind> GetSocketPairs() {
+ return {
+ UnixDomainSocketPair(SOCK_STREAM | SOCK_NONBLOCK),
+ FilesystemBoundUnixDomainSocketPair(SOCK_STREAM | SOCK_NONBLOCK),
+ AbstractBoundUnixDomainSocketPair(SOCK_STREAM | SOCK_NONBLOCK),
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ NonBlockingStreamUnixSockets, NonBlockingStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc
new file mode 100644
index 000000000..8b1762000
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc
@@ -0,0 +1,116 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of unbound abstract unix sockets.
+using UnboundAbstractUnixSocketPairTest = SocketPairTest;
+
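+// AddressAfterNull checks that bytes after the first null byte are
+// significant in an abstract namespace address: modifying one of them yields
+// a distinct address, so both binds succeed.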
+TEST_P(UnboundAbstractUnixSocketPairTest, AddressAfterNull) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr =
+ *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr());
+ ASSERT_EQ(addr.sun_path[sizeof(addr.sun_path) - 1], 0);
+ SKIP_IF(addr.sun_path[sizeof(addr.sun_path) - 2] != 0 ||
+ addr.sun_path[sizeof(addr.sun_path) - 3] != 0);
+
+ addr.sun_path[sizeof(addr.sun_path) - 2] = 'a';
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+}
+
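+// ShortAddressNotExtended checks that a short abstract address is not
+// implicitly padded out to the full sun_path length: binding with a truncated
+// length creates a distinct address, so both binds succeed.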
+TEST_P(UnboundAbstractUnixSocketPairTest, ShortAddressNotExtended) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr =
+ *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr());
+ ASSERT_EQ(addr.sun_path[sizeof(addr.sun_path) - 1], 0);
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size() - 1),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(UnboundAbstractUnixSocketPairTest, BindNothing) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ struct sockaddr_un addr = {.sun_family = AF_UNIX};
+ ASSERT_THAT(bind(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+}
+
+TEST_P(UnboundAbstractUnixSocketPairTest, GetSockNameFullLength) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ sockaddr_storage addr = {};
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(getsockname(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, sockets->first_addr_size());
+}
+
+TEST_P(UnboundAbstractUnixSocketPairTest, GetSockNamePartialLength) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size() - 1),
+ SyscallSucceeds());
+
+ sockaddr_storage addr = {};
+ socklen_t addr_len = sizeof(addr);
+ ASSERT_THAT(getsockname(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, sockets->first_addr_size() - 1);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnboundAbstractUnixSocketPairTest,
+ ::testing::ValuesIn(ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET,
+ SOCK_DGRAM},
+ List<int>{0, SOCK_NONBLOCK}))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_unbound_dgram.cc b/test/syscalls/linux/socket_unix_unbound_dgram.cc
new file mode 100644
index 000000000..907dca0f1
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_unbound_dgram.cc
@@ -0,0 +1,183 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of unbound dgram unix sockets.
+using UnboundDgramUnixSocketPairTest = SocketPairTest;
+
+TEST_P(UnboundDgramUnixSocketPairTest, BindConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, SelfConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(connect(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, DoubleConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, GetRemoteAddress) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ socklen_t addressLength = sockets->first_addr_size();
+ struct sockaddr_storage address = {};
+ ASSERT_THAT(getpeername(sockets->second_fd(), (struct sockaddr*)(&address),
+ &addressLength),
+ SyscallSucceeds());
+ EXPECT_EQ(
+ 0, memcmp(&address, sockets->first_addr(), sockets->first_addr_size()));
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(sendto(sockets->second_fd(), sent_data, sizeof(sent_data), 0,
+ sockets->first_addr(), sockets->first_addr_size()),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data)];
+ ASSERT_THAT(ReadFd(sockets->first_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data)));
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, ZeroWriteAllowed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ char sent_data[3];
+ // Send a zero length packet.
+ ASSERT_THAT(write(sockets->second_fd(), sent_data, 0),
+ SyscallSucceedsWithValue(0));
+ // Receive the packet.
+ char received_data[sizeof(sent_data)];
+ ASSERT_THAT(read(sockets->first_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, Listen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(ENOTSUP));
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, Accept) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(accept(sockets->first_fd(), nullptr, nullptr),
+ SyscallFailsWithErrno(ENOTSUP));
+}
+
+TEST_P(UnboundDgramUnixSocketPairTest, SendtoWithoutConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ char data = 'a';
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->second_fd(), &data, sizeof(data), 0,
+ sockets->first_addr(), sockets->first_addr_size()),
+ SyscallSucceedsWithValue(sizeof(data)));
+}
+
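+// SendtoWithoutConnectPassCreds checks that sender credentials are delivered
+// with a datagram sent via sendto(2) even though the sender never connected.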
+TEST_P(UnboundDgramUnixSocketPairTest, SendtoWithoutConnectPassCreds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ SetSoPassCred(sockets->first_fd());
+ char data = 'a';
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->second_fd(), &data, sizeof(data), 0,
+ sockets->first_addr(), sockets->first_addr_size()),
+ SyscallSucceedsWithValue(sizeof(data)));
+ ucred creds;
+ creds.pid = -1;
+ char buf[sizeof(data) + 1];
+ ASSERT_NO_FATAL_FAILURE(
+ RecvCreds(sockets->first_fd(), &creds, buf, sizeof(buf), sizeof(data)));
+ EXPECT_EQ(0, memcmp(&data, buf, sizeof(data)));
+ EXPECT_THAT(getpid(), SyscallSucceedsWithValue(creds.pid));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnboundDgramUnixSocketPairTest,
+ ::testing::ValuesIn(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_DGRAM},
+ List<int>{0, SOCK_NONBLOCK})))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc
new file mode 100644
index 000000000..cab912152
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc
@@ -0,0 +1,84 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of unbound filesystem unix
+// sockets.
+using UnboundFilesystemUnixSocketPairTest = SocketPairTest;
+
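+// AddressAfterNull checks that bytes after the first null byte are ignored in
+// a filesystem address: both sockets name the same path, so the second bind
+// fails with EADDRINUSE.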
+TEST_P(UnboundFilesystemUnixSocketPairTest, AddressAfterNull) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr =
+ *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr());
+ ASSERT_EQ(addr.sun_path[sizeof(addr.sun_path) - 1], 0);
+ SKIP_IF(addr.sun_path[sizeof(addr.sun_path) - 2] != 0 ||
+ addr.sun_path[sizeof(addr.sun_path) - 3] != 0);
+
+ addr.sun_path[sizeof(addr.sun_path) - 2] = 'a';
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(sockets->second_fd(),
+ reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(UnboundFilesystemUnixSocketPairTest, GetSockNameLength) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ sockaddr_storage got_addr = {};
+ socklen_t got_addr_len = sizeof(got_addr);
+ ASSERT_THAT(
+ getsockname(sockets->first_fd(),
+ reinterpret_cast<struct sockaddr*>(&got_addr), &got_addr_len),
+ SyscallSucceeds());
+
+ sockaddr_un want_addr =
+ *reinterpret_cast<const struct sockaddr_un*>(sockets->first_addr());
+
+ EXPECT_EQ(got_addr_len,
+ strlen(want_addr.sun_path) + 1 + sizeof(want_addr.sun_family));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnboundFilesystemUnixSocketPairTest,
+ ::testing::ValuesIn(ApplyVec<SocketPairKind>(
+ FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_SEQPACKET,
+ SOCK_DGRAM},
+ List<int>{0, SOCK_NONBLOCK}))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc
new file mode 100644
index 000000000..cb99030f5
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of unbound seqpacket unix sockets.
+using UnboundUnixSeqpacketSocketPairTest = SocketPairTest;
+
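+// SendtoWithoutConnect checks that an unconnected seqpacket socket cannot
+// send to a bound address; the send fails with ENOTCONN.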
+TEST_P(UnboundUnixSeqpacketSocketPairTest, SendtoWithoutConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ char data = 'a';
+ ASSERT_THAT(sendto(sockets->second_fd(), &data, sizeof(data), 0,
+ sockets->first_addr(), sockets->first_addr_size()),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UnboundUnixSeqpacketSocketPairTest, SendtoWithoutConnectIgnoresAddr) {
+ // FIXME(b/68223466): gVisor tries to find /foo/bar and thus returns ENOENT.
+ if (IsRunningOnGvisor()) {
+ return;
+ }
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ // Even a bogus address is completely ignored.
+ constexpr char kPath[] = "/foo/bar";
+
+ // Sanity check that kPath doesn't exist.
+ struct stat s;
+ ASSERT_THAT(stat(kPath, &s), SyscallFailsWithErrno(ENOENT));
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ char data = 'a';
+ ASSERT_THAT(
+ sendto(sockets->second_fd(), &data, sizeof(data), 0,
+ reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnboundUnixSeqpacketSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(
+ FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_SEQPACKET},
+ List<int>{0, SOCK_NONBLOCK}))))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc
new file mode 100644
index 000000000..f185dded3
--- /dev/null
+++ b/test/syscalls/linux/socket_unix_unbound_stream.cc
@@ -0,0 +1,733 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Test fixture for tests that apply to pairs of connected unix stream sockets.
+using UnixStreamSocketPairTest = SocketPairTest;
+
+// FDPassPartialRead checks that a sent control message can no longer be
+// received once any of its associated data has been read while ignoring the
+// control message, i.e. by using read(2) instead of recvmsg(2).
+TEST_P(UnixStreamSocketPairTest, FDPassPartialRead) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2];
+ ASSERT_THAT(
+ ReadFd(sockets->second_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(received_data)));
+
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data));
+ EXPECT_EQ(0, memcmp(sent_data + sizeof(received_data), received_data,
+ sizeof(received_data)));
+}
+
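+// FDPassCoalescedRead checks that data from multiple messages, each carrying
+// an FD, is coalesced by a single read(2), which ignores control messages.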
+TEST_P(UnixStreamSocketPairTest, FDPassCoalescedRead) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ auto pair1 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(),
+ sent_data1, sizeof(sent_data1)));
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ auto pair2 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(),
+ sent_data2, sizeof(sent_data2)));
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ ASSERT_THAT(
+ ReadFd(sockets->second_fd(), received_data, sizeof(received_data)),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1),
+ sizeof(sent_data2)));
+}
+
+// ZeroLengthMessageFDDiscarded checks that control messages associated with
+// zero length messages are discarded.
+TEST_P(UnixStreamSocketPairTest, ZeroLengthMessageFDDiscarded) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Zero length arrays are invalid in ISO C++, so allocate one of size 1 and
+ // send a length of 0.
+ char sent_data1[1] = {};
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendSingleFD(sockets->first_fd(), pair->second_fd(), sent_data1, 0));
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ char received_data[sizeof(sent_data2)] = {};
+
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(received_data)));
+}
+
+// FDPassCoalescedRecv checks that control messages not in the first message are
+// preserved in a coalesced recv.
+TEST_P(UnixStreamSocketPairTest, FDPassCoalescedRecv) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data) / 2),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data + sizeof(sent_data) / 2,
+ sizeof(sent_data) / 2));
+
+ char received_data[sizeof(sent_data)];
+
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data,
+ sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+}
+
+// ReadsNotCoalescedAfterFDPass checks that messages after a message containing
+// an FD control message are not coalesced.
+TEST_P(UnixStreamSocketPairTest, ReadsNotCoalescedAfterFDPass) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(),
+ sent_data, sizeof(sent_data) / 2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data + sizeof(sent_data) / 2,
+ sizeof(sent_data) / 2),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+
+ char received_data[sizeof(sent_data)];
+
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data,
+ sizeof(received_data),
+ sizeof(sent_data) / 2));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd()));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(sent_data) / 2));
+
+ EXPECT_EQ(0, memcmp(sent_data + sizeof(sent_data) / 2, received_data,
+ sizeof(sent_data) / 2));
+}
+
+// FDPassNotCombined checks that FD control messages are not combined in a
+// coalesced read.
+TEST_P(UnixStreamSocketPairTest, FDPassNotCombined) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ auto pair1 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(),
+ sent_data, sizeof(sent_data) / 2));
+
+ auto pair2 =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(),
+ sent_data + sizeof(sent_data) / 2,
+ sizeof(sent_data) / 2));
+
+ char received_data[sizeof(sent_data)];
+
+ int fd = -1;
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data,
+ sizeof(received_data),
+ sizeof(sent_data) / 2));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair1->first_fd()));
+
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ fd = -1;
+
+ ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data,
+ sizeof(received_data),
+ sizeof(sent_data) / 2));
+
+ EXPECT_EQ(0, memcmp(sent_data + sizeof(sent_data) / 2, received_data,
+ sizeof(sent_data) / 2));
+
+ ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair2->first_fd()));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_P(UnixStreamSocketPairTest, CredPassPartialRead) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ struct ucred sent_creds;
+
+ ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds());
+
+ ASSERT_NO_FATAL_FAILURE(
+ SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data)));
+
+ int one = 1;
+ ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &one,
+ sizeof(one)),
+ SyscallSucceeds());
+
+ for (int i = 0; i < 2; i++) {
+ char received_data[10];
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data),
+ sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data + i * sizeof(received_data), received_data,
+ sizeof(received_data)));
+ EXPECT_EQ(sent_creds.pid, received_creds.pid);
+ EXPECT_EQ(sent_creds.uid, received_creds.uid);
+ EXPECT_EQ(sent_creds.gid, received_creds.gid);
+ }
+}
+
+// Unix stream sockets peek in the same way as datagram sockets.
+//
+// SinglePeek checks that only a single message is peekable in a single recv.
+TEST_P(UnixStreamSocketPairTest, SinglePeek) {
+ if (!IsRunningOnGvisor()) {
+ // Don't run this test on Linux kernels newer than 4.3.x; Linux kernel
+ // commit 9f389e35674f5b086edd70ed524ca0f287259725 changed this behavior. We
+ // used to target 3.11 compatibility, so disable this test on newer kernels.
+ //
+ // NOTE(b/118902768): Bring this up to Linux 4.4 compatibility.
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major > 4 || (version.major == 4 && version.minor >= 3));
+ }
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char sent_data[40];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data,
+ sizeof(sent_data) / 2, 0),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data + sizeof(sent_data) / 2,
+ sizeof(sent_data) / 2, 0),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+ char received_data[sizeof(sent_data)];
+ for (int i = 0; i < 3; i++) {
+ memset(received_data, 0, sizeof(received_data));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(received_data), MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+ }
+ memset(received_data, 0, sizeof(received_data));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(sent_data) / 2, 0),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+ memset(received_data, 0, sizeof(received_data));
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), received_data,
+ sizeof(sent_data) / 2, 0),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+ EXPECT_EQ(0, memcmp(sent_data + sizeof(sent_data) / 2, received_data,
+ sizeof(sent_data) / 2));
+}
+
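+// CredsNotCoalescedUp checks that messages written before and after the
+// receiver enables SO_PASSCRED are not coalesced, since their credential
+// state differs.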
+TEST_P(UnixStreamSocketPairTest, CredsNotCoalescedUp) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+
+ struct ucred received_creds;
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data),
+ sizeof(sent_data1)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+
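+ // Messages queued before SO_PASSCRED was enabled carry no credentials and
+ // are reported with pid 0 and the overflow uid/gid (65534).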
+ struct ucred want_creds {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data),
+ sizeof(sent_data2)));
+
+ EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(sent_data2)));
+
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixStreamSocketPairTest, CredsNotCoalescedDown) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ UnsetSoPassCred(sockets->second_fd());
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data),
+ sizeof(sent_data1)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data),
+ sizeof(sent_data2)));
+
+ EXPECT_EQ(0, memcmp(sent_data2, received_data, sizeof(sent_data2)));
+
+ want_creds = {0, 65534, 65534};
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
+TEST_P(UnixStreamSocketPairTest, CoalescedCredsNoPasscred) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ UnsetSoPassCred(sockets->second_fd());
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1),
+ sizeof(sent_data2)));
+}
+
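+// CoalescedCreds1 checks that messages queued before SO_PASSCRED is enabled
+// share the same (empty) credential state and so are coalesced by one recv.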
+TEST_P(UnixStreamSocketPairTest, CoalescedCreds1) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1),
+ sizeof(sent_data2)));
+
+ struct ucred want_creds {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
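+// CoalescedCreds2 checks that messages queued while SO_PASSCRED is enabled
+// share the sender's credential state and so are coalesced by one recv.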
+TEST_P(UnixStreamSocketPairTest, CoalescedCreds2) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds,
+ received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1),
+ sizeof(sent_data2)));
+
+ struct ucred want_creds;
+ ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds.pid, received_creds.pid);
+ EXPECT_EQ(want_creds.uid, received_creds.uid);
+ EXPECT_EQ(want_creds.gid, received_creds.gid);
+}
+
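+// NonCoalescedDifferingCreds1 checks that messages queued with differing
+// credential state (one before SO_PASSCRED was enabled, one after) are not
+// coalesced.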
+TEST_P(UnixStreamSocketPairTest, NonCoalescedDifferingCreds1) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ char received_data1[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds1;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds1,
+ received_data1, sizeof(sent_data1)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1)));
+
+ struct ucred want_creds1 {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds1.pid, received_creds1.pid);
+ EXPECT_EQ(want_creds1.uid, received_creds1.uid);
+ EXPECT_EQ(want_creds1.gid, received_creds1.gid);
+
+ char received_data2[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds2;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds2,
+ received_data2, sizeof(sent_data2)));
+
+ EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2)));
+
+ struct ucred want_creds2;
+ ASSERT_THAT(want_creds2.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds2.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds2.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds2.pid, received_creds2.pid);
+ EXPECT_EQ(want_creds2.uid, received_creds2.uid);
+ EXPECT_EQ(want_creds2.gid, received_creds2.gid);
+}
+
+TEST_P(UnixStreamSocketPairTest, NonCoalescedDifferingCreds2) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ UnsetSoPassCred(sockets->second_fd());
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ SetSoPassCred(sockets->second_fd());
+
+ char received_data1[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds1;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds1,
+ received_data1, sizeof(sent_data1)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1)));
+
+ struct ucred want_creds1;
+ ASSERT_THAT(want_creds1.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds1.uid = getuid(), SyscallSucceeds());
+ ASSERT_THAT(want_creds1.gid = getgid(), SyscallSucceeds());
+
+ EXPECT_EQ(want_creds1.pid, received_creds1.pid);
+ EXPECT_EQ(want_creds1.uid, received_creds1.uid);
+ EXPECT_EQ(want_creds1.gid, received_creds1.gid);
+
+ char received_data2[sizeof(sent_data1) + sizeof(sent_data2)];
+ struct ucred received_creds2;
+
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds2,
+ received_data2, sizeof(sent_data2)));
+
+ EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2)));
+
+ struct ucred want_creds2 {
+ 0, 65534, 65534
+ };
+
+ EXPECT_EQ(want_creds2.pid, received_creds2.pid);
+ EXPECT_EQ(want_creds2.uid, received_creds2.uid);
+ EXPECT_EQ(want_creds2.gid, received_creds2.gid);
+}
+
+TEST_P(UnixStreamSocketPairTest, CoalescedDifferingCreds) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ SetSoPassCred(sockets->second_fd());
+
+ char sent_data1[20];
+ RandomizeBuffer(sent_data1, sizeof(sent_data1));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data1, sizeof(sent_data1)),
+ SyscallSucceedsWithValue(sizeof(sent_data1)));
+
+ char sent_data2[20];
+ RandomizeBuffer(sent_data2, sizeof(sent_data2));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data2, sizeof(sent_data2)),
+ SyscallSucceedsWithValue(sizeof(sent_data2)));
+
+ UnsetSoPassCred(sockets->second_fd());
+
+ char sent_data3[20];
+ RandomizeBuffer(sent_data3, sizeof(sent_data3));
+
+ ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data3, sizeof(sent_data3)),
+ SyscallSucceedsWithValue(sizeof(sent_data3)));
+
+ char received_data[sizeof(sent_data1) + sizeof(sent_data2) +
+ sizeof(sent_data3)];
+
+ ASSERT_NO_FATAL_FAILURE(
+ RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
+ EXPECT_EQ(0, memcmp(sent_data2, received_data + sizeof(sent_data1),
+ sizeof(sent_data2)));
+ EXPECT_EQ(0, memcmp(sent_data3,
+ received_data + sizeof(sent_data1) + sizeof(sent_data2),
+ sizeof(sent_data3)));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnixStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(UnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(FilesystemBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractBoundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK}))))));
+
+// Test fixture for tests that apply to pairs of unbound unix stream sockets.
+using UnboundUnixStreamSocketPairTest = SocketPairTest;
+
+TEST_P(UnboundUnixStreamSocketPairTest, SendtoWithoutConnect) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ char data = 'a';
+ ASSERT_THAT(sendto(sockets->second_fd(), &data, sizeof(data), 0,
+ sockets->first_addr(), sockets->first_addr_size()),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+TEST_P(UnboundUnixStreamSocketPairTest, SendtoWithoutConnectIgnoresAddr) {
+ // FIXME(b/68223466): gVisor tries to find /foo/bar and thus returns ENOENT.
+ if (IsRunningOnGvisor()) {
+ return;
+ }
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ // Even a bogus address is completely ignored.
+ constexpr char kPath[] = "/foo/bar";
+
+ // Sanity check that kPath doesn't exist.
+ struct stat s;
+ ASSERT_THAT(stat(kPath, &s), SyscallFailsWithErrno(ENOENT));
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ char data = 'a';
+ ASSERT_THAT(
+ sendto(sockets->second_fd(), &data, sizeof(data), 0,
+ reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllUnixDomainSockets, UnboundUnixStreamSocketPairTest,
+ ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
+ ApplyVec<SocketPairKind>(FilesystemUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{
+ 0, SOCK_NONBLOCK})),
+ ApplyVec<SocketPairKind>(
+ AbstractUnboundUnixDomainSocketPair,
+ AllBitwiseCombinations(List<int>{SOCK_STREAM},
+ List<int>{0, SOCK_NONBLOCK}))))));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
new file mode 100644
index 000000000..08fc4b1b7
--- /dev/null
+++ b/test/syscalls/linux/splice.cc
@@ -0,0 +1,699 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <linux/unistd.h>
+#include <sys/eventfd.h>
+#include <sys/resource.h>
+#include <sys/sendfile.h>
+#include <sys/time.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SpliceTest, TwoRegularFiles) {
+ // Create temp files.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Open the input file as read only.
+ const FileDescriptor in_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file as write only.
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Verify that the splice is rejected as expected, regardless of offsets.
+ loff_t in_offset = 0;
+ loff_t out_offset = 0;
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
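+// Wrapper for memfd_create(2), invoked via syscall(2) since older glibc
+// versions (pre-2.27) may not provide a memfd_create() wrapper.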
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SpliceTest, NegativeOffset) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file as write only.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("negative", 0), SyscallSucceeds());
+ const FileDescriptor out_fd(fd);
+
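+ // 0xffffffffffffffff is -1 when interpreted as a signed loff_t, i.e. a
+ // negative offset.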
+ loff_t out_offset = 0xffffffffffffffffull;
+ constexpr int kSize = 2;
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Write offset + size overflows int64.
+//
+// This is a regression test for b/148041624.
+TEST(SpliceTest, WriteOverflow) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor out_fd(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SpliceTest, SamePipe) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Attempt to splice to itself.
+ EXPECT_THAT(splice(rfd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TeeTest, SamePipe) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Attempt to tee to itself.
+ EXPECT_THAT(tee(rfd.get(), wfd.get(), kPageSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TeeTest, RegularFile) {
+ // Open some file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor in_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Attempt to tee from the file.
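+ // tee(2) requires both descriptors to refer to pipes, so both directions
+ // should fail.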
+ EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SpliceTest, PipeOffsets) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // All pipe offsets should be rejected.
+ loff_t in_offset = 0;
+ loff_t out_offset = 0;
+ EXPECT_THAT(splice(rfd1.get(), &in_offset, wfd2.get(), &out_offset, 1, 0),
+ SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(splice(rfd1.get(), nullptr, wfd2.get(), &out_offset, 1, 0),
+ SyscallFailsWithErrno(ESPIPE));
+ EXPECT_THAT(splice(rfd1.get(), &in_offset, wfd2.get(), nullptr, 1, 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+// Event FDs may be used with splice without an offset.
+TEST(SpliceTest, FromEventFD) {
+ // Open the input eventfd with an initial value so that it is readable.
+ constexpr uint64_t kEventFDValue = 1;
+ int efd;
+ ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds());
+ const FileDescriptor in_fd(efd);
+
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Splice 8-byte eventfd value to pipe.
+ constexpr int kEventFDSize = 8;
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
+ SyscallSucceedsWithValue(kEventFDSize));
+
+ // Contents should be equal.
+ std::vector<char> rbuf(kEventFDSize);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kEventFDSize));
+ EXPECT_EQ(memcmp(rbuf.data(), &kEventFDValue, rbuf.size()), 0);
+}
+
+// Event FDs may not be used with splice with an offset.
+TEST(SpliceTest, FromEventFDOffset) {
+ int efd;
+ ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
+ const FileDescriptor in_fd(efd);
+
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Attempt to splice 8-byte eventfd value to pipe with offset.
+ //
+ // This is not allowed because eventfd doesn't support pread.
+ constexpr int kEventFDSize = 8;
+ loff_t in_off = 0;
+ EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Event FDs may not be used with splice with an offset.
+TEST(SpliceTest, ToEventFDOffset) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with a value.
+ constexpr int kEventFDSize = 8;
+ std::vector<char> buf(kEventFDSize);
+ buf[0] = 1;
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kEventFDSize));
+
+ int efd;
+ ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
+ const FileDescriptor out_fd(efd);
+
+ // Attempt to splice an 8-byte value from the pipe to the eventfd with an
+ // offset.
+ //
+ // This is not allowed because eventfd doesn't support pwrite.
+ loff_t out_off = 0;
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SpliceTest, ToPipe) {
+ // Open the input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor in_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Splice to the pipe.
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Contents should be equal.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
+TEST(SpliceTest, ToPipeOffset) {
+ // Open the input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor in_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Splice to the pipe.
+ loff_t in_offset = kPageSize / 2;
+ EXPECT_THAT(
+ splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
+ SyscallSucceedsWithValue(kPageSize / 2));
+
+ // Contents should match only the second half of the written data.
+ std::vector<char> rbuf(kPageSize / 2);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize / 2));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data() + (kPageSize / 2), rbuf.size()), 0);
+}
+
+TEST(SpliceTest, FromPipe) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ // Splice to the output file.
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // The offset of the output should be equal to kPageSize. We assert that and
+ // reset to zero so that we can read the contents and ensure they match.
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(kPageSize));
+ ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+
+ // Contents should be equal.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
+TEST(SpliceTest, FromPipeOffset) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ // Splice to the output file.
+ loff_t out_offset = kPageSize / 2;
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Content should reflect the splice. We wrote at an offset halfway into the
+ // file, so the first half should read back as zeroes (a sparse hole).
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ std::vector<char> zbuf(kPageSize / 2);
+ memset(zbuf.data(), 0, zbuf.size());
+ EXPECT_EQ(memcmp(rbuf.data(), zbuf.data(), zbuf.size()), 0);
+ EXPECT_EQ(memcmp(rbuf.data() + kPageSize / 2, buf.data(), kPageSize / 2), 0);
+}
+
+TEST(SpliceTest, TwoPipes) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd1.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Splice to the second pipe, using two operations.
+ EXPECT_THAT(
+ splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize / 2, 0),
+ SyscallSucceedsWithValue(kPageSize / 2));
+ EXPECT_THAT(
+ splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize / 2, 0),
+ SyscallSucceedsWithValue(kPageSize / 2));
+
+ // Content should reflect the splice.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
+}
+
+TEST(SpliceTest, TwoPipesCircular) {
+ // This test deadlocks the sentry on VFS1 because VFS1 splice ordering is
+ // based on fs.File.UniqueID, which does not prevent circular ordering between
+ // e.g. inode-level locks taken by fs.FileOperations.
+ SKIP_IF(IsRunningWithVFS1());
+
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // On Linux, each pipe is normally limited to
+ // include/linux/pipe_fs_i.h:PIPE_DEF_BUFFERS buffers worth of data.
+ constexpr size_t PIPE_DEF_BUFFERS = 16;
+
+ // Write some data to each pipe. Below we splice 1 byte at a time between
+ // pipes, which very quickly causes each byte to be stored in a separate
+ // buffer, so we must ensure that the total amount of data in the system is <=
+ // PIPE_DEF_BUFFERS bytes.
+ std::vector<char> buf(PIPE_DEF_BUFFERS / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Have another thread splice from the second pipe to the first, while we
+ // splice from the first to the second. The test passes if this does not
+ // deadlock.
+ const int kIterations = 1000;
+ DisableSave ds;
+ ScopedThread t([&]() {
+ for (int i = 0; i < kIterations; i++) {
+ ASSERT_THAT(
+ splice(second_rfd.get(), nullptr, first_wfd.get(), nullptr, 1, 0),
+ SyscallSucceedsWithValue(1));
+ }
+ });
+ for (int i = 0; i < kIterations; i++) {
+ ASSERT_THAT(
+ splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, 1, 0),
+ SyscallSucceedsWithValue(1));
+ }
+}
+
+TEST(SpliceTest, Blocking) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // This thread writes to the main pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ScopedThread t([&]() {
+ ASSERT_THAT(write(wfd1.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ });
+
+ // Attempt a splice immediately; it should block.
+ EXPECT_THAT(splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Thread should be joinable.
+ t.Join();
+
+ // Content should reflect the splice.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
+}
+
+TEST(TeeTest, Blocking) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // This thread writes to the main pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ScopedThread t([&]() {
+ ASSERT_THAT(write(wfd1.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ });
+
+ // Attempt a tee immediately; it should block.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Thread should be joinable.
+ t.Join();
+
+ // Content should reflect the tee, in both pipes.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
+ ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
+}
+
+TEST(TeeTest, BlockingWrite) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> buf1(kPageSize);
+ RandomizeBuffer(buf1.data(), buf1.size());
+ ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Fill up the write pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf2(pipe_size);
+ ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ // Attempt a tee immediately; it should block.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Thread should be joinable.
+ t.Join();
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0);
+}
+
+TEST(SpliceTest, NonBlocking) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Splice with no data to back it.
+ EXPECT_THAT(splice(rfd1.get(), nullptr, wfd2.get(), nullptr, kPageSize,
+ SPLICE_F_NONBLOCK),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(TeeTest, NonBlocking) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Tee with no data to back it.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, SPLICE_F_NONBLOCK),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(TeeTest, MultiPage) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> wbuf(8 * kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Attempt a tee immediately; it should complete.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(wbuf.size());
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+ ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+}
+
+TEST(SpliceTest, FromPipeMaxFileSize) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds());
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END),
+ SyscallSucceedsWithValue(13 << 20));
+
+ // Set our file size limit.
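+ // Writing past RLIMIT_FSIZE raises SIGXFSZ; block it so the splice fails
+ // with EFBIG instead of delivering a fatal signal.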
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGXFSZ);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+ rlimit rlim = {};
+ rlim.rlim_cur = rlim.rlim_max = (13 << 20);
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds());
+
+ // Splice to the output file.
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0),
+ SyscallFailsWithErrno(EFBIG));
+
+ // The splice failed, so the pipe contents should remain intact.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
new file mode 100644
index 000000000..2503960f3
--- /dev/null
+++ b/test/syscalls/linux/stat.cc
@@ -0,0 +1,720 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
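+// Older kernel headers may not define the AT_STATX_* sync flags; the values
+// below come from include/uapi/linux/fcntl.h.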
+#ifndef AT_STATX_FORCE_SYNC
+#define AT_STATX_FORCE_SYNC 0x2000
+#endif
+#ifndef AT_STATX_DONT_SYNC
+#define AT_STATX_DONT_SYNC 0x4000
+#endif
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class StatTest : public FileTest {};
+
+TEST_F(StatTest, FstatatAbs) {
+ struct stat st;
+
+ // Check that the stat works.
+ EXPECT_THAT(fstatat(AT_FDCWD, test_file_name_.c_str(), &st, 0),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(st.st_mode));
+}
+
+TEST_F(StatTest, FstatatEmptyPath) {
+ struct stat st;
+ const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+
+ // Check that the stat works.
+ EXPECT_THAT(fstatat(fd.get(), "", &st, AT_EMPTY_PATH), SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(st.st_mode));
+}
+
+TEST_F(StatTest, FstatatRel) {
+ struct stat st;
+ int dirfd;
+ auto filename = std::string(Basename(test_file_name_));
+
+ // Open the temporary directory read-only.
+ ASSERT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY),
+ SyscallSucceeds());
+
+ // Check that the stat works.
+ EXPECT_THAT(fstatat(dirfd, filename.c_str(), &st, 0), SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(st.st_mode));
+ close(dirfd);
+}
+
+TEST_F(StatTest, FstatatSymlink) {
+ struct stat st;
+
+ // Check that the link is followed.
+ EXPECT_THAT(fstatat(AT_FDCWD, "/proc/self", &st, 0), SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+ EXPECT_FALSE(S_ISLNK(st.st_mode));
+
+ // Check that the flag works.
+ EXPECT_THAT(fstatat(AT_FDCWD, "/proc/self", &st, AT_SYMLINK_NOFOLLOW),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISLNK(st.st_mode));
+ EXPECT_FALSE(S_ISDIR(st.st_mode));
+}
+
+TEST_F(StatTest, Nlinks) {
+ TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // The directory is initially empty, so it should have 2 links (one from the
+ // parent directory, one from its own ".").
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2));
+
+ // Create a file in the test directory. Files shouldn't increase the link
+ // count on the base directory.
+ TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path()));
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2));
+
+ // Create subdirectories. This should increase the link count by 1 per
+ // subdirectory.
+ TempPath dir1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(basedir.path()));
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(3));
+ TempPath dir2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(basedir.path()));
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(4));
+
+ // Removing directories should reduce the link count.
+ dir1.reset();
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(3));
+ dir2.reset();
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2));
+
+ // Removing files should have no effect on link count.
+ file1.reset();
+ EXPECT_THAT(Links(basedir.path()), IsPosixErrorOkAndHolds(2));
+}
+
+TEST_F(StatTest, BlocksIncreaseOnWrite) {
+ struct stat st;
+
+ // Stat the empty file.
+ ASSERT_THAT(fstat(test_file_fd_.get(), &st), SyscallSucceeds());
+
+ const int initial_blocks = st.st_blocks;
+
+ // Write to the file, making sure to exceed the block size.
+ std::vector<char> buf(2 * st.st_blksize, 'a');
+ ASSERT_THAT(write(test_file_fd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Stat the file again, and verify that number of allocated blocks has
+ // increased.
+ ASSERT_THAT(fstat(test_file_fd_.get(), &st), SyscallSucceeds());
+ EXPECT_GT(st.st_blocks, initial_blocks);
+}
+
+TEST_F(StatTest, PathNotCleaned) {
+ TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // Create a file in the basedir.
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path()));
+
+ // Stating the file directly should succeed.
+ struct stat buf;
+ EXPECT_THAT(lstat(file.path().c_str(), &buf), SyscallSucceeds());
+
+ // Try to stat the file using a directory that does not exist followed by
+ // "..". If the path were cleaned prior to stating (which it should not be),
+ // this would succeed.
+ const std::string bad_path = JoinPath("/does_not_exist/..", file.path());
+ EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_F(StatTest, PathCanContainDotDot) {
+ TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath subdir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(basedir.path()));
+ const std::string subdir_name = std::string(Basename(subdir.path()));
+
+ // Create a file in the subdir.
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(subdir.path()));
+ const std::string file_name = std::string(Basename(file.path()));
+
+ // Stat the file through a path that includes '..' and '.' but still resolves
+ // to the file.
+ const std::string good_path =
+ JoinPath(basedir.path(), subdir_name, "..", subdir_name, ".", file_name);
+ struct stat buf;
+ EXPECT_THAT(lstat(good_path.c_str(), &buf), SyscallSucceeds());
+}
+
+TEST_F(StatTest, PathCanContainEmptyComponent) {
+ TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // Create a file in the basedir.
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path()));
+ const std::string file_name = std::string(Basename(file.path()));
+
+ // Stat the file through a path that includes an empty component. We have to
+ // build this ourselves because JoinPath automatically removes empty
+ // components.
+ const std::string good_path = absl::StrCat(basedir.path(), "//", file_name);
+ struct stat buf;
+ EXPECT_THAT(lstat(good_path.c_str(), &buf), SyscallSucceeds());
+}
+
+TEST_F(StatTest, TrailingSlashNotCleanedReturnsENOTDIR) {
+ TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // Create a file in the basedir.
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(basedir.path()));
+
+ // Stat the file with an extra "/" on the end of it. Since the file is not a
+ // directory, this should return ENOTDIR.
+ const std::string bad_path = absl::StrCat(file.path(), "/");
+ struct stat buf;
+ EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOTDIR));
+}
+
+// Test fstatating a symlink directory.
+TEST_F(StatTest, FstatatSymlinkDir) {
+ // Create a directory and symlink to it.
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ const std::string symlink_to_dir = NewTempAbsPath();
+ EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()),
+ SyscallSucceeds());
+ auto cleanup = Cleanup([&symlink_to_dir]() {
+ EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds());
+ });
+
+ // Fstatat the link with AT_SYMLINK_NOFOLLOW should return symlink data.
+ struct stat st = {};
+ EXPECT_THAT(
+ fstatat(AT_FDCWD, symlink_to_dir.c_str(), &st, AT_SYMLINK_NOFOLLOW),
+ SyscallSucceeds());
+ EXPECT_FALSE(S_ISDIR(st.st_mode));
+ EXPECT_TRUE(S_ISLNK(st.st_mode));
+
+ // Fstatat the link should return dir data.
+ EXPECT_THAT(fstatat(AT_FDCWD, symlink_to_dir.c_str(), &st, 0),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+ EXPECT_FALSE(S_ISLNK(st.st_mode));
+}
+
+// Test fstatating a symlink directory with trailing slash.
+TEST_F(StatTest, FstatatSymlinkDirWithTrailingSlash) {
+ // Create a directory and symlink to it.
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string symlink_to_dir = NewTempAbsPath();
+ EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()),
+ SyscallSucceeds());
+ auto cleanup = Cleanup([&symlink_to_dir]() {
+ EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds());
+ });
+
+ // Fstatat on the symlink with a trailing slash should return the directory
+ // data.
+ struct stat st = {};
+ EXPECT_THAT(
+ fstatat(AT_FDCWD, absl::StrCat(symlink_to_dir, "/").c_str(), &st, 0),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+ EXPECT_FALSE(S_ISLNK(st.st_mode));
+
+ // Fstatat on the symlink with a trailing slash and AT_SYMLINK_NOFOLLOW
+ // should still return the directory data: a trailing slash forces the link
+ // to be followed, so AT_SYMLINK_NOFOLLOW is ignored.
+ EXPECT_THAT(fstatat(AT_FDCWD, absl::StrCat(symlink_to_dir, "/").c_str(), &st,
+ AT_SYMLINK_NOFOLLOW),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+ EXPECT_FALSE(S_ISLNK(st.st_mode));
+}
+
+// Test that fstatating a symlink to a directory with a trailing slash
+// returns the same stat data as fstatating the directory itself.
+TEST_F(StatTest, FstatatSymlinkDirWithTrailingSlashSameInode) {
+ // Create a directory and symlink to it.
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // We are going to assert that the symlink inode id is the same as the linked
+ // dir's inode id. In order for the inode id to be stable across
+ // save/restore, it must be kept open. The FileDescriptor type will do that
+ // for us automatically.
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ const std::string symlink_to_dir = NewTempAbsPath();
+ EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()),
+ SyscallSucceeds());
+ auto cleanup = Cleanup([&symlink_to_dir]() {
+ EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds());
+ });
+
+ // Fstatat on the symlink with a trailing slash should return the directory
+ // data.
+ struct stat st = {};
+ EXPECT_THAT(fstatat(AT_FDCWD, absl::StrCat(symlink_to_dir, "/").c_str(), &st,
+ AT_SYMLINK_NOFOLLOW),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+
+ // Dir and symlink should point to same inode.
+ struct stat st_dir = {};
+ EXPECT_THAT(
+ fstatat(AT_FDCWD, dir.path().c_str(), &st_dir, AT_SYMLINK_NOFOLLOW),
+ SyscallSucceeds());
+ EXPECT_EQ(st.st_ino, st_dir.st_ino);
+}
+
+TEST_F(StatTest, LeadingDoubleSlash) {
+ // Create a file, and make sure we can stat it.
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ struct stat st;
+ ASSERT_THAT(lstat(file.path().c_str(), &st), SyscallSucceeds());
+
+ // Now add an extra leading slash.
+ const std::string double_slash_path = absl::StrCat("/", file.path());
+ ASSERT_TRUE(absl::StartsWith(double_slash_path, "//"));
+
+ // We should be able to stat the new path, and it should resolve to the same
+ // file (same device and inode).
+ struct stat double_slash_st;
+ ASSERT_THAT(lstat(double_slash_path.c_str(), &double_slash_st),
+ SyscallSucceeds());
+ EXPECT_EQ(st.st_dev, double_slash_st.st_dev);
+ EXPECT_EQ(st.st_ino, double_slash_st.st_ino);
+}
+
+// Test that a rename doesn't change the underlying file.
+TEST_F(StatTest, StatDoesntChangeAfterRename) {
+ const TempPath old_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath new_path(NewTempAbsPath());
+
+ struct stat st_old = {};
+ struct stat st_new = {};
+
+ ASSERT_THAT(stat(old_dir.path().c_str(), &st_old), SyscallSucceeds());
+ ASSERT_THAT(rename(old_dir.path().c_str(), new_path.path().c_str()),
+ SyscallSucceeds());
+ ASSERT_THAT(stat(new_path.path().c_str(), &st_new), SyscallSucceeds());
+
+ EXPECT_EQ(st_old.st_nlink, st_new.st_nlink);
+ EXPECT_EQ(st_old.st_dev, st_new.st_dev);
+ EXPECT_EQ(st_old.st_ino, st_new.st_ino);
+ EXPECT_EQ(st_old.st_mode, st_new.st_mode);
+ EXPECT_EQ(st_old.st_uid, st_new.st_uid);
+ EXPECT_EQ(st_old.st_gid, st_new.st_gid);
+ EXPECT_EQ(st_old.st_size, st_new.st_size);
+}
+
+// Test link counts with a regular file as the child.
+TEST_F(StatTest, LinkCountsWithRegularFileChild) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ struct stat st_parent_before = {};
+ ASSERT_THAT(stat(dir.path().c_str(), &st_parent_before), SyscallSucceeds());
+ EXPECT_EQ(st_parent_before.st_nlink, 2);
+
+ // Adding a regular file doesn't adjust the parent's link count.
+ const TempPath child =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ struct stat st_parent_after = {};
+ ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds());
+ EXPECT_EQ(st_parent_after.st_nlink, 2);
+
+ // The child should have a single link from the parent.
+ struct stat st_child = {};
+ ASSERT_THAT(stat(child.path().c_str(), &st_child), SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(st_child.st_mode));
+ EXPECT_EQ(st_child.st_nlink, 1);
+
+ // Finally unlinking the child should not affect the parent's link count.
+ ASSERT_THAT(unlink(child.path().c_str()), SyscallSucceeds());
+ ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds());
+ EXPECT_EQ(st_parent_after.st_nlink, 2);
+}
+
+// This test verifies that an inode remains accessible through an open fd
+// even after its link count hits 0.
+TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) {
+ // Setting the environment variable GVISOR_GOFER_UNCACHED to any value
+ // prevents this test from running; see the tmpfs lifecycle.
+ //
+ // We need to skip in that case because when a file is unlinked and the stat
+ // is forwarded to the gofer, it would return ENOENT.
+ const char* uncached_gofer = getenv("GVISOR_GOFER_UNCACHED");
+ SKIP_IF(uncached_gofer != nullptr);
+
+ // We don't support saving unlinked files.
+ const DisableSave ds;
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir.path(), "hello", TempPath::kDefaultFileMode));
+
+ // The child should have a single link from the parent.
+ struct stat st_child_before = {};
+ ASSERT_THAT(stat(child.path().c_str(), &st_child_before), SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(st_child_before.st_mode));
+ EXPECT_EQ(st_child_before.st_nlink, 1);
+ EXPECT_EQ(st_child_before.st_size, 5); // Hello is 5 bytes.
+
+ // Open the file so we can fstat after unlinking.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(child.path(), O_RDONLY));
+
+ // Now a stat should return ENOENT but we should still be able to stat
+ // via the open fd and fstat.
+ ASSERT_THAT(unlink(child.path().c_str()), SyscallSucceeds());
+
+ // Since the file has no more links stat should fail.
+ struct stat st_child_after = {};
+ ASSERT_THAT(stat(child.path().c_str(), &st_child_after),
+ SyscallFailsWithErrno(ENOENT));
+
+ // Fstat should still allow us to access the same file via the fd.
+ struct stat st_child_fd = {};
+ ASSERT_THAT(fstat(fd.get(), &st_child_fd), SyscallSucceeds());
+ EXPECT_EQ(st_child_before.st_dev, st_child_fd.st_dev);
+ EXPECT_EQ(st_child_before.st_ino, st_child_fd.st_ino);
+ EXPECT_EQ(st_child_before.st_mode, st_child_fd.st_mode);
+ EXPECT_EQ(st_child_before.st_uid, st_child_fd.st_uid);
+ EXPECT_EQ(st_child_before.st_gid, st_child_fd.st_gid);
+ EXPECT_EQ(st_child_before.st_size, st_child_fd.st_size);
+
+ // TODO(b/34861058): This isn't ideal but since fstatfs(2) will always return
+ // OVERLAYFS_SUPER_MAGIC we have no way to know if this fs is backed by a
+ // gofer which doesn't support links.
+ EXPECT_TRUE(st_child_fd.st_nlink == 0 || st_child_fd.st_nlink == 1);
+}
+
+// Test link counts with a directory as the child.
+TEST_F(StatTest, LinkCountsWithDirChild) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // Before a child is added, the two links are "." and the link from the
+ // parent.
+ struct stat st_parent_before = {};
+ ASSERT_THAT(stat(dir.path().c_str(), &st_parent_before), SyscallSucceeds());
+ EXPECT_EQ(st_parent_before.st_nlink, 2);
+
+ // Create a subdirectory and stat for the parent link counts.
+ const TempPath sub_dir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+
+ // The three links are ".", the link from the parent, and the link from
+ // the child as "..".
+ struct stat st_parent_after = {};
+ ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds());
+ EXPECT_EQ(st_parent_after.st_nlink, 3);
+
+ // The child will have 1 link from the parent and 1 link which represents ".".
+ struct stat st_child = {};
+ ASSERT_THAT(stat(sub_dir.path().c_str(), &st_child), SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st_child.st_mode));
+ EXPECT_EQ(st_child.st_nlink, 2);
+
+ // Finally delete the child dir and the parent link count should return to 2.
+ ASSERT_THAT(rmdir(sub_dir.path().c_str()), SyscallSucceeds());
+ ASSERT_THAT(stat(dir.path().c_str(), &st_parent_after), SyscallSucceeds());
+
+ // Now we should only have links from the parent and "." since the subdir
+ // has been removed.
+ EXPECT_EQ(st_parent_after.st_nlink, 2);
+}
+
+// Test statting a child of a non-directory.
+TEST_F(StatTest, ChildOfNonDir) {
+ // Create a path that has a child of a regular file.
+ const std::string filename = JoinPath(test_file_name_, "child");
+
+ // Statting the path should return ENOTDIR.
+ struct stat st;
+ EXPECT_THAT(lstat(filename.c_str(), &st), SyscallFailsWithErrno(ENOTDIR));
+}
+
+// Test lstating a symlink directory.
+TEST_F(StatTest, LstatSymlinkDir) {
+ // Create a directory and symlink to it.
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string symlink_to_dir = NewTempAbsPath();
+ EXPECT_THAT(symlink(dir.path().c_str(), symlink_to_dir.c_str()),
+ SyscallSucceeds());
+ auto cleanup = Cleanup([&symlink_to_dir]() {
+ EXPECT_THAT(unlink(symlink_to_dir.c_str()), SyscallSucceeds());
+ });
+
+ // Lstat on the symlink should return symlink data.
+ struct stat st = {};
+ ASSERT_THAT(lstat(symlink_to_dir.c_str(), &st), SyscallSucceeds());
+ EXPECT_FALSE(S_ISDIR(st.st_mode));
+ EXPECT_TRUE(S_ISLNK(st.st_mode));
+
+ // Lstat on the symlink with a trailing slash should return the directory
+ // data.
+ ASSERT_THAT(lstat(absl::StrCat(symlink_to_dir, "/").c_str(), &st),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISDIR(st.st_mode));
+ EXPECT_FALSE(S_ISLNK(st.st_mode));
+}
+
+// Verify that we get an ELOOP from too many symbolic links even when there
+// are directories in the middle.
+TEST_F(StatTest, LstatELOOPPath) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string subdir_base = "subdir";
+ ASSERT_THAT(mkdir(JoinPath(dir.path(), subdir_base).c_str(), 0755),
+ SyscallSucceeds());
+
+ std::string target = JoinPath(dir.path(), subdir_base, subdir_base);
+ std::string dst = JoinPath("..", subdir_base);
+ ASSERT_THAT(symlink(dst.c_str(), target.c_str()), SyscallSucceeds());
+ auto cleanup = Cleanup(
+ [&target]() { EXPECT_THAT(unlink(target.c_str()), SyscallSucceeds()); });
+
+ // Now build a path of /subdir/subdir/... repeated many times: shorter than
+ // PATH_MAX, but still long enough to traverse too many symbolic links.
+ // Note: Every other subdir is actually a directory, so we're not in a
+ // situation where it's a -> b -> a -> b, where a and b are symbolic links.
+ std::string path = dir.path();
+ std::string subdir_append = absl::StrCat("/", subdir_base);
+ do {
+ absl::StrAppend(&path, subdir_append);
+ // Keep appending /subdir until we would overflow PATH_MAX.
+ } while ((path.size() + subdir_append.size()) < PATH_MAX);
+
+ struct stat s = {};
+ ASSERT_THAT(lstat(path.c_str(), &s), SyscallFailsWithErrno(ELOOP));
+}
+
+// Ensure that inode allocation for anonymous devices works correctly across
+// save/restore. In particular, inode numbers should be unique across S/R.
+TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) {
+ // Use sockets as a convenient way to create inodes on an anonymous device.
+ int fd;
+ ASSERT_THAT(fd = socket(AF_UNIX, SOCK_STREAM, 0), SyscallSucceeds());
+ FileDescriptor fd1(fd);
+ MaybeSave();
+ ASSERT_THAT(fd = socket(AF_UNIX, SOCK_STREAM, 0), SyscallSucceeds());
+ FileDescriptor fd2(fd);
+
+ struct stat st1;
+ struct stat st2;
+ ASSERT_THAT(fstat(fd1.get(), &st1), SyscallSucceeds());
+ ASSERT_THAT(fstat(fd2.get(), &st2), SyscallSucceeds());
+
+ // The two fds should have different inode numbers.
+ EXPECT_NE(st2.st_ino, st1.st_ino);
+
+ // Verify again after another S/R cycle. The inode numbers should remain the
+ // same.
+ MaybeSave();
+
+ struct stat st1_after;
+ struct stat st2_after;
+ ASSERT_THAT(fstat(fd1.get(), &st1_after), SyscallSucceeds());
+ ASSERT_THAT(fstat(fd2.get(), &st2_after), SyscallSucceeds());
+
+ EXPECT_EQ(st1_after.st_ino, st1.st_ino);
+ EXPECT_EQ(st2_after.st_ino, st2.st_ino);
+}
+
+#ifndef SYS_statx
+#if defined(__x86_64__)
+#define SYS_statx 332
+#elif defined(__aarch64__)
+#define SYS_statx 291
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_statx
+
+#ifndef STATX_ALL
+#define STATX_ALL 0x00000fffU
+#endif // STATX_ALL
+
+// struct kernel_statx_timestamp is a Linux statx_timestamp struct.
+struct kernel_statx_timestamp {
+ int64_t tv_sec;
+ uint32_t tv_nsec;
+ int32_t __reserved;
+};
+
+// struct kernel_statx is a Linux statx struct. Old versions of glibc do not
+// expose it. See include/uapi/linux/stat.h
+struct kernel_statx {
+ uint32_t stx_mask;
+ uint32_t stx_blksize;
+ uint64_t stx_attributes;
+ uint32_t stx_nlink;
+ uint32_t stx_uid;
+ uint32_t stx_gid;
+ uint16_t stx_mode;
+ uint16_t __spare0[1];
+ uint64_t stx_ino;
+ uint64_t stx_size;
+ uint64_t stx_blocks;
+ uint64_t stx_attributes_mask;
+ struct kernel_statx_timestamp stx_atime;
+ struct kernel_statx_timestamp stx_btime;
+ struct kernel_statx_timestamp stx_ctime;
+ struct kernel_statx_timestamp stx_mtime;
+ uint32_t stx_rdev_major;
+ uint32_t stx_rdev_minor;
+ uint32_t stx_dev_major;
+ uint32_t stx_dev_minor;
+ uint64_t __spare2[14];
+};
+
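+// Raw syscall wrapper: glibc only gained a statx() wrapper in version 2.28,
+// so older toolchains must invoke the syscall directly.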
+int statx(int dirfd, const char* pathname, int flags, unsigned int mask,
+ struct kernel_statx* statxbuf) {
+ return syscall(SYS_statx, dirfd, pathname, flags, mask, statxbuf);
+}
+
+TEST_F(StatTest, StatxAbsPath) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxRelPathDirFD) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ auto const dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY));
+ auto filename = std::string(Basename(test_file_name_));
+
+ EXPECT_THAT(statx(dirfd.get(), filename.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxRelPathCwd) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+ auto filename = std::string(Basename(test_file_name_));
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, filename.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxEmptyPath) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(fd.get(), "", AT_EMPTY_PATH, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxDoesNotRejectExtraneousMaskBits) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set all mask bits except for STATX__RESERVED.
+ uint mask = 0xffffffff & ~0x80000000;
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, mask, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxRejectsReservedMaskBit) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set STATX__RESERVED in the mask.
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, 0x80000000, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(StatTest, StatxSymlink) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ std::string parent_dir = "/tmp";
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(parent_dir, test_file_name_));
+ std::string p = link.path();
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), AT_SYMLINK_NOFOLLOW, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISLNK(stx.stx_mode));
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxInvalidFlags) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), 12345, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Sync flags are mutually exclusive.
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(),
+ AT_STATX_FORCE_SYNC | AT_STATX_DONT_SYNC, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/stat_times.cc b/test/syscalls/linux/stat_times.cc
new file mode 100644
index 000000000..68c0bef09
--- /dev/null
+++ b/test/syscalls/linux/stat_times.cc
@@ -0,0 +1,303 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/stat.h>
+
+#include <tuple>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::Not;
+
+std::tuple<absl::Time, absl::Time, absl::Time> GetTime(const TempPath& file) {
+ struct stat statbuf = {};
+ EXPECT_THAT(stat(file.path().c_str(), &statbuf), SyscallSucceeds());
+
+ const auto atime = absl::TimeFromTimespec(statbuf.st_atim);
+ const auto mtime = absl::TimeFromTimespec(statbuf.st_mtim);
+ const auto ctime = absl::TimeFromTimespec(statbuf.st_ctim);
+ return std::make_tuple(atime, mtime, ctime);
+}
+
+enum class AtimeEffect {
+ Unchanged,
+ Changed,
+};
+
+enum class MtimeEffect {
+ Unchanged,
+ Changed,
+};
+
+enum class CtimeEffect {
+ Unchanged,
+ Changed,
+};
+
+// Tests that fn modifies the atime/mtime/ctime of path as specified.
+void CheckTimes(const TempPath& path, std::function<void()> fn,
+ AtimeEffect atime_effect, MtimeEffect mtime_effect,
+ CtimeEffect ctime_effect) {
+ absl::Time atime, mtime, ctime;
+ std::tie(atime, mtime, ctime) = GetTime(path);
+
+ // FIXME(b/132819225): gVisor filesystem timestamps inconsistently use the
+ // internal or host clock, which may diverge slightly. Allow some slack on
+ // times to account for the difference.
+ //
+ // Here we sleep for 1s so that initial creation of path doesn't fall within
+ // the before slack window.
+ absl::SleepFor(absl::Seconds(1));
+
+ const absl::Time before = absl::Now() - absl::Seconds(1);
+
+ // Perform the op.
+ fn();
+
+ const absl::Time after = absl::Now() + absl::Seconds(1);
+
+ absl::Time atime2, mtime2, ctime2;
+ std::tie(atime2, mtime2, ctime2) = GetTime(path);
+
+ if (atime_effect == AtimeEffect::Changed) {
+ EXPECT_LE(before, atime2);
+ EXPECT_GE(after, atime2);
+ EXPECT_GT(atime2, atime);
+ } else {
+ EXPECT_EQ(atime2, atime);
+ }
+
+ if (mtime_effect == MtimeEffect::Changed) {
+ EXPECT_LE(before, mtime2);
+ EXPECT_GE(after, mtime2);
+ EXPECT_GT(mtime2, mtime);
+ } else {
+ EXPECT_EQ(mtime2, mtime);
+ }
+
+ if (ctime_effect == CtimeEffect::Changed) {
+ EXPECT_LE(before, ctime2);
+ EXPECT_GE(after, ctime2);
+ EXPECT_GT(ctime2, ctime);
+ } else {
+ EXPECT_EQ(ctime2, ctime);
+ }
+}
+
+// File creation time is reflected in atime, mtime, and ctime.
+TEST(StatTimesTest, FileCreation) {
+ const DisableSave ds; // Timing-related test.
+
+ // Get a time for when the file is created.
+ //
+ // FIXME(b/132819225): See above.
+ const absl::Time before = absl::Now() - absl::Seconds(1);
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const absl::Time after = absl::Now() + absl::Seconds(1);
+
+ absl::Time atime, mtime, ctime;
+ std::tie(atime, mtime, ctime) = GetTime(file);
+
+ EXPECT_LE(before, atime);
+ EXPECT_LE(before, mtime);
+ EXPECT_LE(before, ctime);
+ EXPECT_GE(after, atime);
+ EXPECT_GE(after, mtime);
+ EXPECT_GE(after, ctime);
+}
+
+// Calling chmod on a file changes ctime.
+TEST(StatTimesTest, FileChmod) {
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ auto fn = [&] {
+ EXPECT_THAT(chmod(file.path().c_str(), 0666), SyscallSucceeds());
+ };
+ CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Unchanged,
+ CtimeEffect::Changed);
+}
+
+// Renaming a file changes ctime.
+TEST(StatTimesTest, FileRename) {
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ const std::string newpath = NewTempAbsPath();
+
+ auto fn = [&] {
+ ASSERT_THAT(rename(file.release().c_str(), newpath.c_str()),
+ SyscallSucceeds());
+ file.reset(newpath);
+ };
+ CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Unchanged,
+ CtimeEffect::Changed);
+}
+
+// Renaming a file changes ctime, even with an open FD.
+//
+// NOTE(b/132732387): This is a regression test for fs/gofer failing to update
+// cached ctime.
+TEST(StatTimesTest, FileRenameOpenFD) {
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Holding an FD shouldn't affect behavior.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ const std::string newpath = NewTempAbsPath();
+
+ // FIXME(b/132814682): Restore fails with an uncached gofer and an open FD
+ // across rename.
+ //
+ // N.B. The logic here looks backwards because it isn't possible to
+ // conditionally disable save, only conditionally re-enable it.
+ DisableSave ds;
+ if (!getenv("GVISOR_GOFER_UNCACHED")) {
+ ds.reset();
+ }
+
+ auto fn = [&] {
+ ASSERT_THAT(rename(file.release().c_str(), newpath.c_str()),
+ SyscallSucceeds());
+ file.reset(newpath);
+ };
+ CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Unchanged,
+ CtimeEffect::Changed);
+}
+
+// Calling utimes on a file changes ctime and the time that we ask to change
+// (atime to now in this case).
+TEST(StatTimesTest, FileUtimes) {
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ auto fn = [&] {
+ const struct timespec ts[2] = {{0, UTIME_NOW}, {0, UTIME_OMIT}};
+ ASSERT_THAT(utimensat(AT_FDCWD, file.path().c_str(), ts, 0),
+ SyscallSucceeds());
+ };
+ CheckTimes(file, fn, AtimeEffect::Changed, MtimeEffect::Unchanged,
+ CtimeEffect::Changed);
+}
+
+// Truncating a file changes mtime and ctime.
+TEST(StatTimesTest, FileTruncate) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "yaaass", 0666));
+
+ auto fn = [&] {
+ EXPECT_THAT(truncate(file.path().c_str(), 0), SyscallSucceeds());
+ };
+ CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Changed,
+ CtimeEffect::Changed);
+}
+
+// Writing a file changes mtime and ctime.
+TEST(StatTimesTest, FileWrite) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "yaaass", 0666));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0));
+
+ auto fn = [&] {
+ const std::string contents = "all the single dollars";
+ EXPECT_THAT(WriteFd(fd.get(), contents.data(), contents.size()),
+ SyscallSucceeds());
+ };
+ CheckTimes(file, fn, AtimeEffect::Unchanged, MtimeEffect::Changed,
+ CtimeEffect::Changed);
+}
+
+// Reading a file changes atime.
+TEST(StatTimesTest, FileRead) {
+ const std::string contents = "bills bills bills";
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), contents, 0666));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY, 0));
+
+ auto fn = [&] {
+ char buf[20];
+ ASSERT_THAT(ReadFd(fd.get(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(contents.size()));
+ };
+ CheckTimes(file, fn, AtimeEffect::Changed, MtimeEffect::Unchanged,
+ CtimeEffect::Unchanged);
+}
+
+// Listing files in a directory changes atime.
+TEST(StatTimesTest, DirList) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ auto fn = [&] {
+ const auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), false));
+ EXPECT_THAT(contents, Not(IsEmpty()));
+ };
+ CheckTimes(dir, fn, AtimeEffect::Changed, MtimeEffect::Unchanged,
+ CtimeEffect::Unchanged);
+}
+
+// Creating a file in a directory changes mtime and ctime.
+TEST(StatTimesTest, DirCreateFile) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ TempPath file;
+ auto fn = [&] {
+ file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ };
+ CheckTimes(dir, fn, AtimeEffect::Unchanged, MtimeEffect::Changed,
+ CtimeEffect::Changed);
+}
+
+// Creating a directory in a directory changes mtime and ctime.
+TEST(StatTimesTest, DirCreateDir) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ TempPath dir2;
+ auto fn = [&] {
+ dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+ };
+ CheckTimes(dir, fn, AtimeEffect::Unchanged, MtimeEffect::Changed,
+ CtimeEffect::Changed);
+}
+
+// Removing a file from a directory changes mtime and ctime.
+TEST(StatTimesTest, DirRemoveFile) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ auto fn = [&] { file.reset(); };
+ CheckTimes(dir, fn, AtimeEffect::Unchanged, MtimeEffect::Changed,
+ CtimeEffect::Changed);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc
new file mode 100644
index 000000000..aca51d30f
--- /dev/null
+++ b/test/syscalls/linux/statfs.cc
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <limits.h>
+#include <sys/statfs.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(StatfsTest, CannotStatBadPath) {
+ auto temp_file = NewTempAbsPathInDir("/tmp");
+
+ struct statfs st;
+ EXPECT_THAT(statfs(temp_file.c_str(), &st), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(StatfsTest, InternalTmpfs) {
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ struct statfs st;
+ EXPECT_THAT(statfs(temp_file.path().c_str(), &st), SyscallSucceeds());
+}
+
+TEST(StatfsTest, InternalDevShm) {
+ struct statfs st;
+ EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds());
+}
+
+TEST(StatfsTest, NameLen) {
+ struct statfs st;
+ EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds());
+
+ // This assumes that /dev/shm is tmpfs.
+ EXPECT_EQ(st.f_namelen, NAME_MAX);
+}
+
+TEST(FstatfsTest, CannotStatBadFd) {
+ struct statfs st;
+ EXPECT_THAT(fstatfs(-1, &st), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FstatfsTest, InternalTmpfs) {
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY));
+
+ struct statfs st;
+ EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds());
+}
+
+TEST(FstatfsTest, InternalDevShm) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/shm", O_RDONLY));
+
+ struct statfs st;
+ EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
new file mode 100644
index 000000000..4afed6d08
--- /dev/null
+++ b/test/syscalls/linux/sticky.cc
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <grp.h>
+#include <sys/prctl.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(StickyTest, StickyBitPermDenied) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
+
+ // After changing credentials below, we need to use an open fd to make
+ // modifications in the parent dir, because there is no guarantee that we will
+ // still have the ability to open it.
+ const FileDescriptor parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY));
+ ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds());
+ ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds());
+ ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds());
+
+ // Drop privileges and change IDs only in child thread, or else this parent
+ // thread won't be able to open some log files after the test ends.
+ ScopedThread([&] {
+ // Drop privileges.
+ if (HaveCapability(CAP_FOWNER).ValueOrDie()) {
+ EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false));
+ }
+
+ // Change EUID and EGID.
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
+
+ EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(parent_fd.get(), "file", 0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0),
+ SyscallFailsWithErrno(EPERM));
+ });
+}
+
+TEST(StickyTest, StickyBitSameUID) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
+
+ // After changing credentials below, we need to use an open fd to make
+ // modifications in the parent dir, because there is no guarantee that we will
+ // still have the ability to open it.
+ const FileDescriptor parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY));
+ ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds());
+ ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds());
+ ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds());
+
+ // Drop privileges and change IDs only in child thread, or else this parent
+ // thread won't be able to open some log files after the test ends.
+ ScopedThread([&] {
+ // Drop privileges.
+ if (HaveCapability(CAP_FOWNER).ValueOrDie()) {
+ EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false));
+ }
+
+ // Change EGID.
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+
+ // We still have the same EUID.
+ EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds());
+ });
+}
+
+TEST(StickyTest, StickyBitCapFOWNER) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
+
+ // After changing credentials below, we need to use an open fd to make
+ // modifications in the parent dir, because there is no guarantee that we will
+ // still have the ability to open it.
+ const FileDescriptor parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY));
+ ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds());
+ ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds());
+ ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds());
+
+ // Drop privileges and change IDs only in child thread, or else this parent
+ // thread won't be able to open some log files after the test ends.
+ ScopedThread([&] {
+    // Set PR_SET_KEEPCAPS so capabilities survive the UID change below,
+    // allowing CAP_FOWNER to be re-enabled afterwards.
+ EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
+
+ // Change EUID and EGID.
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
+ EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds());
+ });
+}
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
new file mode 100644
index 000000000..a17ff62e9
--- /dev/null
+++ b/test/syscalls/linux/symlink.cc
@@ -0,0 +1,402 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
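+// Returns the permission bits of path's mode, without following symlinks.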
+mode_t FilePermission(const std::string& path) {
+ struct stat buf = {0};
+ TEST_CHECK(lstat(path.c_str(), &buf) == 0);
+ return buf.st_mode & 0777;
+}
+
+// Test that name collisions are checked on the new link path, not the source
+// path. Regression test for b/31782115.
+TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) {
+ const std::string srcname = NewTempAbsPath();
+ const std::string newname = NewTempAbsPath();
+ const std::string basedir = std::string(Dirname(srcname));
+ ASSERT_EQ(basedir, Dirname(newname));
+
+ ASSERT_THAT(chdir(basedir.c_str()), SyscallSucceeds());
+
+ // Open the source node to cause the underlying dirent to be cached. It will
+ // remain cached while we have the file open.
+ int fd;
+ ASSERT_THAT(fd = open(srcname.c_str(), O_CREAT | O_RDWR, 0666),
+ SyscallSucceeds());
+ FileDescriptor fd_closer(fd);
+
+ // Attempt to create a symlink. If the bug exists, this will fail since the
+ // dirent link creation code will check for a name collision on the source
+ // link name.
+ EXPECT_THAT(symlink(std::string(Basename(srcname)).c_str(),
+ std::string(Basename(newname)).c_str()),
+ SyscallSucceeds());
+}
+
+TEST(SymlinkTest, CanCreateSymlinkFile) {
+ const std::string oldname = NewTempAbsPath();
+ const std::string newname = NewTempAbsPath();
+
+ int fd;
+ ASSERT_THAT(fd = open(oldname.c_str(), O_CREAT | O_RDWR, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ EXPECT_THAT(symlink(oldname.c_str(), newname.c_str()), SyscallSucceeds());
+ EXPECT_EQ(FilePermission(newname), 0777);
+
+ auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink(newname));
+ EXPECT_EQ(oldname, link);
+
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(oldname.c_str()), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, CanCreateSymlinkDir) {
+ const std::string olddir = NewTempAbsPath();
+ const std::string newdir = NewTempAbsPath();
+
+ EXPECT_THAT(mkdir(olddir.c_str(), 0777), SyscallSucceeds());
+ EXPECT_THAT(symlink(olddir.c_str(), newdir.c_str()), SyscallSucceeds());
+ EXPECT_EQ(FilePermission(newdir), 0777);
+
+ auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink(newdir));
+ EXPECT_EQ(olddir, link);
+
+ EXPECT_THAT(unlink(newdir.c_str()), SyscallSucceeds());
+
+ ASSERT_THAT(rmdir(olddir.c_str()), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, CannotCreateSymlinkInReadOnlyDir) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ const std::string olddir = NewTempAbsPath();
+ ASSERT_THAT(mkdir(olddir.c_str(), 0444), SyscallSucceeds());
+
+ const std::string newdir = NewTempAbsPathInDir(olddir);
+ EXPECT_THAT(symlink(olddir.c_str(), newdir.c_str()),
+ SyscallFailsWithErrno(EACCES));
+
+ ASSERT_THAT(rmdir(olddir.c_str()), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, CannotSymlinkOverExistingFile) {
+ const auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const auto newfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ EXPECT_THAT(symlink(oldfile.path().c_str(), newfile.path().c_str()),
+ SyscallFailsWithErrno(EEXIST));
+}
+
+TEST(SymlinkTest, CannotSymlinkOverExistingDir) {
+ const auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const auto newdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ EXPECT_THAT(symlink(oldfile.path().c_str(), newdir.path().c_str()),
+ SyscallFailsWithErrno(EEXIST));
+}
+
+TEST(SymlinkTest, OldnameIsEmpty) {
+ const std::string newname = NewTempAbsPath();
+ EXPECT_THAT(symlink("", newname.c_str()), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(SymlinkTest, OldnameIsDangling) {
+ const std::string newname = NewTempAbsPath();
+ EXPECT_THAT(symlink("/dangling", newname.c_str()), SyscallSucceeds());
+
+  // This is required for S/R random save tests, which pre-run this test in
+  // the same TEST_TMPDIR; we therefore need to clean up after any operation
+  // that exclusively creates files, like the symlink above.
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, NewnameCannotExist) {
+ const std::string newname =
+ JoinPath(GetAbsoluteTestTmpdir(), "thisdoesnotexist", "foo");
+ EXPECT_THAT(symlink("/thisdoesnotmatter", newname.c_str()),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(SymlinkTest, CanEvaluateLink) {
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // We are going to assert that the symlink inode id is the same as the linked
+ // file's inode id. In order for the inode id to be stable across
+ // save/restore, it must be kept open. The FileDescriptor type will do that
+ // for us automatically.
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ struct stat file_st;
+ EXPECT_THAT(fstat(fd.get(), &file_st), SyscallSucceeds());
+
+ const std::string link = NewTempAbsPath();
+ EXPECT_THAT(symlink(file.path().c_str(), link.c_str()), SyscallSucceeds());
+ EXPECT_EQ(FilePermission(link), 0777);
+
+ auto linkfd = ASSERT_NO_ERRNO_AND_VALUE(Open(link.c_str(), O_RDWR));
+ struct stat link_st;
+ EXPECT_THAT(fstat(linkfd.get(), &link_st), SyscallSucceeds());
+
+ // Check that in fact newname points to the file we expect.
+ EXPECT_EQ(file_st.st_dev, link_st.st_dev);
+ EXPECT_EQ(file_st.st_ino, link_st.st_ino);
+}
+
+TEST(SymlinkTest, TargetIsNotMapped) {
+ const std::string oldname = NewTempAbsPath();
+ const std::string newname = NewTempAbsPath();
+
+ int fd;
+ // Create the target so that when we read the link, it exists.
+ ASSERT_THAT(fd = open(oldname.c_str(), O_CREAT | O_RDWR, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ // Create a symlink called newname that points to oldname.
+ EXPECT_THAT(symlink(oldname.c_str(), newname.c_str()), SyscallSucceeds());
+
+ std::vector<char> buf(1024);
+ int linksize;
+ // Read the link and assert that the oldname is still the same.
+ EXPECT_THAT(linksize = readlink(newname.c_str(), buf.data(), 1024),
+ SyscallSucceeds());
+ EXPECT_EQ(0, strncmp(oldname.c_str(), buf.data(), linksize));
+
+ EXPECT_THAT(unlink(newname.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(oldname.c_str()), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, PreadFromSymlink) {
+ std::string name = NewTempAbsPath();
+ int fd;
+ ASSERT_THAT(fd = open(name.c_str(), O_CREAT, 0644), SyscallSucceeds());
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+
+ std::string linkname = NewTempAbsPath();
+ ASSERT_THAT(symlink(name.c_str(), linkname.c_str()), SyscallSucceeds());
+
+ ASSERT_THAT(fd = open(linkname.c_str(), O_RDONLY), SyscallSucceeds());
+
+ char buf[1024];
+ EXPECT_THAT(pread64(fd, buf, 1024, 0), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ EXPECT_THAT(unlink(name.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ int dirfd;
+ ASSERT_THAT(dirfd = open(dir.path().c_str(), O_DIRECTORY, 0),
+ SyscallSucceeds());
+
+ const DisableSave ds; // Permissions are dropped.
+ EXPECT_THAT(fchmod(dirfd, 0), SyscallSucceeds());
+
+ std::string basename = std::string(Basename(file.path()));
+ EXPECT_THAT(symlinkat("/dangling", dirfd, basename.c_str()),
+ SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(close(dirfd), SyscallSucceeds());
+}
+
+TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string oldpath = NewTempAbsPathInDir(dir.path());
+ const std::string oldbase = std::string(Basename(oldpath));
+ ASSERT_THAT(symlink("/dangling", oldpath.c_str()), SyscallSucceeds());
+
+ int dirfd;
+ EXPECT_THAT(dirfd = open(dir.path().c_str(), O_DIRECTORY, 0),
+ SyscallSucceeds());
+
+ const DisableSave ds; // Permissions are dropped.
+ EXPECT_THAT(fchmod(dirfd, 0), SyscallSucceeds());
+
+ char buf[1024];
+ int linksize;
+ EXPECT_THAT(linksize = readlinkat(dirfd, oldbase.c_str(), buf, 1024),
+ SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(close(dirfd), SyscallSucceeds());
+}
+
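+// Test that chmod on a symlink affects the target, not the link itself:
+// chmod(2) dereferences symlinks, and a link's own mode always reads as 0777
+// on Linux.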
+TEST(SymlinkTest, ChmodSymlink) {
+ auto target = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string newpath = NewTempAbsPath();
+ ASSERT_THAT(symlink(target.path().c_str(), newpath.c_str()),
+ SyscallSucceeds());
+ EXPECT_EQ(FilePermission(newpath), 0777);
+ EXPECT_THAT(chmod(newpath.c_str(), 0666), SyscallSucceeds());
+ EXPECT_EQ(FilePermission(newpath), 0777);
+}
+
+// Test that following a symlink updates the atime on the symlink.
+TEST(SymlinkTest, FollowUpdatesATime) {
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string link = NewTempAbsPath();
+ EXPECT_THAT(symlink(file.path().c_str(), link.c_str()), SyscallSucceeds());
+
+ // Lstat the symlink.
+ struct stat st_before_follow;
+ ASSERT_THAT(lstat(link.c_str(), &st_before_follow), SyscallSucceeds());
+
+ // Let the clock advance.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Open the file via the symlink.
+ int fd;
+ ASSERT_THAT(fd = open(link.c_str(), O_RDWR, 0666), SyscallSucceeds());
+ FileDescriptor fd_closer(fd);
+
+ // Lstat the symlink again, and check that atime is updated.
+ struct stat st_after_follow;
+ ASSERT_THAT(lstat(link.c_str(), &st_after_follow), SyscallSucceeds());
+ EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime);
+}
+
+class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {};
+
+// Test that creating an existing symlink with creat will create the target.
+TEST_P(ParamSymlinkTest, CreatLinkCreatesTarget) {
+ const std::string target = GetParam();
+ const std::string linkpath = NewTempAbsPath();
+
+ ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds());
+
+ int fd;
+ EXPECT_THAT(fd = creat(linkpath.c_str(), 0666), SyscallSucceeds());
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+ struct stat st;
+ EXPECT_THAT(stat(target.c_str(), &st), SyscallSucceeds());
+
+ ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds());
+ ASSERT_THAT(unlink(target.c_str()), SyscallSucceeds());
+}
+
+// Test that opening an existing symlink with O_CREAT will create the target.
+TEST_P(ParamSymlinkTest, OpenLinkCreatesTarget) {
+ const std::string target = GetParam();
+ const std::string linkpath = NewTempAbsPath();
+
+ ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds());
+
+ int fd;
+ EXPECT_THAT(fd = open(linkpath.c_str(), O_CREAT, 0666), SyscallSucceeds());
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+ struct stat st;
+ EXPECT_THAT(stat(target.c_str(), &st), SyscallSucceeds());
+
+ ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds());
+ ASSERT_THAT(unlink(target.c_str()), SyscallSucceeds());
+}
+
+// Test that opening a self-symlink with O_CREAT will fail with ELOOP.
+TEST_P(ParamSymlinkTest, CreateExistingSelfLink) {
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+
+ const std::string linkpath = GetParam();
+ ASSERT_THAT(symlink(linkpath.c_str(), linkpath.c_str()), SyscallSucceeds());
+
+ EXPECT_THAT(open(linkpath.c_str(), O_CREAT, 0666),
+ SyscallFailsWithErrno(ELOOP));
+
+ ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds());
+}
+
+// Test that opening a file that is a symlink to its parent directory fails
+// with ELOOP.
+TEST_P(ParamSymlinkTest, CreateExistingParentLink) {
+ ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
+
+ const std::string linkpath = GetParam();
+ const std::string target = JoinPath(linkpath, "child");
+ ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds());
+
+ EXPECT_THAT(open(linkpath.c_str(), O_CREAT, 0666),
+ SyscallFailsWithErrno(ELOOP));
+
+ ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds());
+}
+
+// Test that opening an existing symlink with O_CREAT|O_EXCL will fail with
+// EEXIST.
+TEST_P(ParamSymlinkTest, OpenLinkExclFails) {
+ const std::string target = GetParam();
+ const std::string linkpath = NewTempAbsPath();
+
+ ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds());
+
+ EXPECT_THAT(open(linkpath.c_str(), O_CREAT | O_EXCL, 0666),
+ SyscallFailsWithErrno(EEXIST));
+
+ ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds());
+}
+
+// Test that opening an existing symlink with O_CREAT|O_NOFOLLOW will fail with
+// ELOOP.
+TEST_P(ParamSymlinkTest, OpenLinkNoFollowFails) {
+ const std::string target = GetParam();
+ const std::string linkpath = NewTempAbsPath();
+
+ ASSERT_THAT(symlink(target.c_str(), linkpath.c_str()), SyscallSucceeds());
+
+ EXPECT_THAT(open(linkpath.c_str(), O_CREAT | O_NOFOLLOW, 0666),
+ SyscallFailsWithErrno(ELOOP));
+
+ ASSERT_THAT(unlink(linkpath.c_str()), SyscallSucceeds());
+}
+
+INSTANTIATE_TEST_SUITE_P(AbsAndRelTarget, ParamSymlinkTest,
+ ::testing::Values(NewTempAbsPath(), NewTempRelPath()));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc
new file mode 100644
index 000000000..8aa2525a9
--- /dev/null
+++ b/test/syscalls/linux/sync.cc
@@ -0,0 +1,59 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SyncTest, SyncEverything) {
+ ASSERT_THAT(syscall(SYS_sync), SyscallSucceeds());
+}
+
+TEST(SyncTest, SyncFileSystem) {
+ int fd;
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_THAT(fd = open(f.path().c_str(), O_RDONLY), SyscallSucceeds());
+ EXPECT_THAT(syncfs(fd), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(SyncTest, SyncFromPipe) {
+ int pipes[2];
+ EXPECT_THAT(pipe(pipes), SyscallSucceeds());
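+  // syncfs() accepts any valid fd; for a pipe it syncs the containing
+  // filesystem (pipefs), which has nothing to write back.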
+ EXPECT_THAT(syncfs(pipes[0]), SyscallSucceeds());
+ EXPECT_THAT(syncfs(pipes[1]), SyscallSucceeds());
+ EXPECT_THAT(close(pipes[0]), SyscallSucceeds());
+ EXPECT_THAT(close(pipes[1]), SyscallSucceeds());
+}
+
+TEST(SyncTest, CannotSyncFileSystemAtBadFd) {
+ EXPECT_THAT(syncfs(-1), SyscallFailsWithErrno(EBADF));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sync_file_range.cc b/test/syscalls/linux/sync_file_range.cc
new file mode 100644
index 000000000..36cc42043
--- /dev/null
+++ b/test/syscalls/linux/sync_file_range.cc
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SyncFileRangeTest, TempFileSucceeds) {
+ auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR));
+ constexpr char data[] = "some data to sync";
+ int fd = f.get();
+
+ EXPECT_THAT(write(fd, data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+ EXPECT_THAT(sync_file_range(fd, 0, 0, SYNC_FILE_RANGE_WRITE),
+ SyscallSucceeds());
+ EXPECT_THAT(sync_file_range(fd, 0, 0, 0), SyscallSucceeds());
+ EXPECT_THAT(
+ sync_file_range(fd, 0, 0,
+ SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER |
+ SYNC_FILE_RANGE_WAIT_BEFORE),
+ SyscallSucceeds());
+ EXPECT_THAT(sync_file_range(
+ fd, 0, 1, SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER),
+ SyscallSucceeds());
+ EXPECT_THAT(sync_file_range(
+ fd, 1, 0, SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER),
+ SyscallSucceeds());
+}
+
+TEST(SyncFileRangeTest, CannotSyncFileRangeOnUnopenedFd) {
+ auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR));
+ constexpr char data[] = "some data to sync";
+ int fd = f.get();
+
+ EXPECT_THAT(write(fd, data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ f.reset();
+
+ // fd is now invalid.
+ TEST_CHECK(sync_file_range(fd, 0, 0, SYNC_FILE_RANGE_WRITE) == -1);
+ TEST_PCHECK(errno == EBADF);
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+
+ int status = 0;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(WEXITSTATUS(status), 0);
+}
+
+TEST(SyncFileRangeTest, BadArgs) {
+ auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR));
+ int fd = f.get();
+
+ EXPECT_THAT(sync_file_range(fd, -1, 0, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(sync_file_range(fd, 0, -1, 0), SyscallFailsWithErrno(EINVAL));
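+  // offset + nbytes overflows a signed 64-bit value.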
+ EXPECT_THAT(sync_file_range(fd, 8912, INT64_MAX - 4096, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SyncFileRangeTest, CannotSyncFileRangeWithWaitBefore) {
+ auto tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path(), O_RDWR));
+ constexpr char data[] = "some data to sync";
+ int fd = f.get();
+
+ EXPECT_THAT(write(fd, data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(sync_file_range(fd, 0, 0, SYNC_FILE_RANGE_WAIT_BEFORE),
+ SyscallFailsWithErrno(ENOSYS));
+ EXPECT_THAT(
+ sync_file_range(fd, 0, 0,
+ SYNC_FILE_RANGE_WAIT_BEFORE | SYNC_FILE_RANGE_WRITE),
+ SyscallFailsWithErrno(ENOSYS));
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sysinfo.cc b/test/syscalls/linux/sysinfo.cc
new file mode 100644
index 000000000..1a71256da
--- /dev/null
+++ b/test/syscalls/linux/sysinfo.cc
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This is a very simple sanity test to validate that the sysinfo syscall is
+// supported by gVisor and returns sane values.
+#include <sys/syscall.h>
+#include <sys/sysinfo.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(SysinfoTest, SysinfoIsCallable) {
+ struct sysinfo ignored = {};
+ EXPECT_THAT(syscall(SYS_sysinfo, &ignored), SyscallSucceedsWithValue(0));
+}
+
+TEST(SysinfoTest, EfaultProducedOnBadAddress) {
+  // Validate that EFAULT is returned when a bad address is provided, as
+  // specified by man 2 sysinfo.
+ EXPECT_THAT(syscall(SYS_sysinfo, nullptr), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST(SysinfoTest, TotalRamSaneValue) {
+ struct sysinfo s = {};
+ EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0));
+ EXPECT_GT(s.totalram, 0);
+}
+
+TEST(SysinfoTest, MemunitSet) {
+ struct sysinfo s = {};
+ EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0));
+ EXPECT_GE(s.mem_unit, 1);
+}
+
+TEST(SysinfoTest, UptimeSaneValue) {
+ struct sysinfo s = {};
+ EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0));
+ EXPECT_GE(s.uptime, 0);
+}
+
+TEST(SysinfoTest, UptimeIncreasingValue) {
+ struct sysinfo s = {};
+ EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0));
+ absl::SleepFor(absl::Seconds(2));
+ struct sysinfo s2 = {};
+ EXPECT_THAT(sysinfo(&s2), SyscallSucceedsWithValue(0));
+ EXPECT_LT(s.uptime, s2.uptime);
+}
+
+TEST(SysinfoTest, FreeRamSaneValue) {
+ struct sysinfo s = {};
+ EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0));
+ EXPECT_GT(s.freeram, 0);
+ EXPECT_LT(s.freeram, s.totalram);
+}
+
+TEST(SysinfoTest, NumProcsSaneValue) {
+ struct sysinfo s = {};
+ EXPECT_THAT(sysinfo(&s), SyscallSucceedsWithValue(0));
+ EXPECT_GT(s.procs, 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/syslog.cc b/test/syscalls/linux/syslog.cc
new file mode 100644
index 000000000..9a7407d96
--- /dev/null
+++ b/test/syscalls/linux/syslog.cc
@@ -0,0 +1,51 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/klog.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int SYSLOG_ACTION_READ_ALL = 3;
+constexpr int SYSLOG_ACTION_SIZE_BUFFER = 10;
+
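+// Thin wrapper around the syslog(2) syscall; calling syscall(2) directly
+// avoids the glibc klogctl() wrapper.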
+int Syslog(int type, char* buf, int len) {
+ return syscall(__NR_syslog, type, buf, len);
+}
+
+// Only SYSLOG_ACTION_SIZE_BUFFER and SYSLOG_ACTION_READ_ALL are implemented in
+// gVisor.
+
+TEST(Syslog, Size) {
+ EXPECT_THAT(Syslog(SYSLOG_ACTION_SIZE_BUFFER, nullptr, 0), SyscallSucceeds());
+}
+
+TEST(Syslog, ReadAll) {
+ // There might not be anything to read, so we can't check the write count.
+ char buf[100];
+ EXPECT_THAT(Syslog(SYSLOG_ACTION_READ_ALL, buf, sizeof(buf)),
+ SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc
new file mode 100644
index 000000000..19ffbd85b
--- /dev/null
+++ b/test/syscalls/linux/sysret.cc
@@ -0,0 +1,142 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests to verify that Linux and gVisor behave identically when 'sysret'
+// returns to a bad (aka non-canonical) %rip or %rsp.
+
+#include <linux/elf.h>
+#include <signal.h>
+#include <string.h>
+#include <sys/ptrace.h>
+#include <sys/uio.h>
+#include <sys/user.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
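+// On x86-64 with 48-bit virtual addresses, canonical addresses have bits
+// 63:48 equal to bit 47; both constants below deliberately violate that.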
+constexpr uint64_t kNonCanonicalRip = 0xCCCC000000000000;
+constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000;
+
+class SysretTest : public ::testing::Test {
+ protected:
+ struct user_regs_struct regs_;
+ struct iovec iov;
+ pid_t child_;
+
+ void SetUp() override {
+ pid_t pid = fork();
+
+ // Child.
+ if (pid == 0) {
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
+ MaybeSave();
+ TEST_PCHECK(raise(SIGSTOP) == 0);
+ MaybeSave();
+ _exit(0);
+ }
+
+ // Parent.
+ int status;
+ memset(&iov, 0, sizeof(iov));
+ ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0.
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
+
+ iov.iov_base = &regs_;
+ iov.iov_len = sizeof(regs_);
+ ASSERT_THAT(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+
+ child_ = pid;
+ }
+
+ void Detach() {
+ ASSERT_THAT(ptrace(PTRACE_DETACH, child_, 0, 0), SyscallSucceeds());
+ }
+
+ void SetRip(uint64_t newrip) {
+#if defined(__x86_64__)
+ regs_.rip = newrip;
+#elif defined(__aarch64__)
+ regs_.pc = newrip;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ }
+
+ void SetRsp(uint64_t newrsp) {
+#if defined(__x86_64__)
+ regs_.rsp = newrsp;
+#elif defined(__aarch64__)
+ regs_.sp = newrsp;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ }
+
+ // Wait waits for the child pid and returns the exit status.
+ int Wait() {
+ int status;
+ while (true) {
+ int rval = wait4(child_, &status, 0, NULL);
+ if (rval < 0) {
+ return rval;
+ }
+ if (rval == child_) {
+ return status;
+ }
+ }
+ }
+};
+
+TEST_F(SysretTest, JustDetach) {
+ Detach();
+ int status = Wait();
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << "status = " << status;
+}
+
+TEST_F(SysretTest, BadRip) {
+ SetRip(kNonCanonicalRip);
+ Detach();
+ int status = Wait();
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV)
+ << "status = " << status;
+}
+
+TEST_F(SysretTest, BadRsp) {
+ SetRsp(kNonCanonicalRsp);
+ Detach();
+ int status = Wait();
+#if defined(__x86_64__)
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS)
+ << "status = " << status;
+#elif defined(__aarch64__)
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV)
+ << "status = " << status;
+#else
+#error "Unknown architecture"
+#endif
+}
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
new file mode 100644
index 000000000..a4d2953e1
--- /dev/null
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -0,0 +1,1568 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <limits>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
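+// Returns a loopback sockaddr for the given family. The port is left zeroed
+// so that bind() picks an ephemeral port.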
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// Fixture for tests parameterized by the address family to use (AF_INET and
+// AF_INET6) when creating sockets.
+class TcpSocketTest : public ::testing::TestWithParam<int> {
+ protected:
+ // Creates three sockets that will be used by test cases -- a listener, one
+ // that connects, and the accepted one.
+ void SetUp() override;
+
+ // Closes the sockets created by SetUp().
+ void TearDown() override;
+
+ // Listening socket.
+ int listener_ = -1;
+
+ // Socket connected via connect().
+ int s_ = -1;
+
+ // Socket connected via accept().
+ int t_ = -1;
+
+ // Initial size of the send buffer.
+ int sendbuf_size_ = -1;
+};
+
+void TcpSocketTest::SetUp() {
+ ASSERT_THAT(listener_ = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP),
+ SyscallSucceeds());
+
+ ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP),
+ SyscallSucceeds());
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(
+ bind(listener_, reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listener_, SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listener_, reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(s_, reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen),
+ SyscallSucceeds());
+
+ // Get the initial send buffer size.
+ socklen_t optlen = sizeof(sendbuf_size_);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &sendbuf_size_, &optlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ ASSERT_THAT(t_ = RetryEINTR(accept)(listener_, nullptr, nullptr),
+ SyscallSucceeds());
+}
+
+void TcpSocketTest::TearDown() {
+ EXPECT_THAT(close(listener_), SyscallSucceeds());
+ if (s_ >= 0) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
+ if (t_ >= 0) {
+ EXPECT_THAT(close(t_), SyscallSucceeds());
+ }
+}
+
+TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) {
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EISCONN));
+ ASSERT_THAT(
+ connect(t_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EISCONN));
+}
+
+TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+}
+
+TEST_P(TcpSocketTest, DataCoalesced) {
+ char buf[10];
+
+ // Write in two steps.
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf) / 2),
+ SyscallSucceedsWithValue(sizeof(buf) / 2));
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf) / 2),
+ SyscallSucceedsWithValue(sizeof(buf) / 2));
+
+ // Allow stack to process both packets.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Read in one shot.
+ EXPECT_THAT(RetryEINTR(recv)(t_, buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(TcpSocketTest, SenderAddressIgnored) {
+ char buf[3];
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ memset(&addr, 0, sizeof(addr));
+
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(t_, buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceedsWithValue(3));
+
+ // Check that addr remains zeroed-out.
+ const char* ptr = reinterpret_cast<char*>(&addr);
+ for (size_t i = 0; i < sizeof(addr); i++) {
+ EXPECT_EQ(ptr[i], 0);
+ }
+}
+
+TEST_P(TcpSocketTest, SenderAddressIgnoredOnPeek) {
+ char buf[3];
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ memset(&addr, 0, sizeof(addr));
+
+ ASSERT_THAT(
+ RetryEINTR(recvfrom)(t_, buf, sizeof(buf), MSG_PEEK,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceedsWithValue(3));
+
+ // Check that addr remains zeroed-out.
+ const char* ptr = reinterpret_cast<char*>(&addr);
+ for (size_t i = 0; i < sizeof(addr); i++) {
+ EXPECT_EQ(ptr[i], 0);
+ }
+}
+
+TEST_P(TcpSocketTest, SendtoAddressIgnored) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = GetParam(); // FIXME(b/63803955)
+
+ char data = '\0';
+ EXPECT_THAT(
+ RetryEINTR(sendto)(s_, &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceedsWithValue(1));
+}
+
+TEST_P(TcpSocketTest, WritevZeroIovec) {
+ // 2 bytes just to be safe and have vecs[1] not point to something random
+ // (even though length is 0).
+ char buf[2];
+ char recv_buf[1];
+
+  // Construct an iovec array whose final entry has length 0.
+ iovec vecs[2] = {};
+ vecs[0].iov_base = buf;
+ vecs[0].iov_len = 1;
+ vecs[1].iov_base = buf + 1;
+ vecs[1].iov_len = 0;
+
+ EXPECT_THAT(RetryEINTR(writev)(s_, vecs, 2), SyscallSucceedsWithValue(1));
+
+ EXPECT_THAT(RetryEINTR(recv)(t_, recv_buf, 1, 0),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(memcmp(recv_buf, buf, 1), 0);
+}
+
+TEST_P(TcpSocketTest, ZeroWriteAllowed) {
+ char buf[3];
+ // Send a zero length packet.
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, 0), SyscallSucceedsWithValue(0));
+ // Verify that there is no packet available.
+ EXPECT_THAT(RetryEINTR(recv)(t_, buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Test that a non-blocking write with a buffer that is larger than the send
+// buffer size will not actually write the whole thing at once. Regression test
+// for b/64438887.
+TEST_P(TcpSocketTest, NonblockingLargeWrite) {
+ // Set the FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(s_, F_SETFL, opts), SyscallSucceeds());
+
+ // Allocate a buffer three times the size of the send buffer. We do this with
+ // a vector to avoid allocating on the stack.
+ int size = 3 * sendbuf_size_;
+ std::vector<char> buf(size);
+
+ // Try to write the whole thing.
+ int n;
+ ASSERT_THAT(n = RetryEINTR(write)(s_, buf.data(), size), SyscallSucceeds());
+
+ // We should have written something, but not the whole thing.
+ EXPECT_GT(n, 0);
+ EXPECT_LT(n, size);
+}
+
+// Test that a blocking write with a buffer that is larger than the send buffer
+// will block until the entire buffer is sent.
+TEST_P(TcpSocketTest, BlockingLargeWrite_NoRandomSave) {
+ // Allocate a buffer three times the size of the send buffer on the heap. We
+ // do this as a vector to avoid allocating on the stack.
+ int size = 3 * sendbuf_size_;
+ std::vector<char> writebuf(size);
+
+ // Start reading the response in a loop.
+ int read_bytes = 0;
+ ScopedThread t([this, &read_bytes]() {
+ // Avoid interrupting the blocking write in main thread.
+ const DisableSave ds;
+
+ // Take ownership of the FD so that we close it on failure. This will
+ // unblock the blocking write below.
+ FileDescriptor fd(t_);
+ t_ = -1;
+
+ char readbuf[2500] = {};
+ int n = -1;
+ while (n != 0) {
+ ASSERT_THAT(n = RetryEINTR(read)(fd.get(), &readbuf, sizeof(readbuf)),
+ SyscallSucceeds());
+ read_bytes += n;
+ }
+ });
+
+ // Try to write the whole thing.
+ int n;
+ ASSERT_THAT(n = WriteFd(s_, writebuf.data(), size), SyscallSucceeds());
+
+ // We should have written the whole thing.
+ EXPECT_EQ(n, size);
+ EXPECT_THAT(close(s_), SyscallSucceedsWithValue(0));
+ s_ = -1;
+ t.Join();
+
+ // We should have read the whole thing.
+ EXPECT_EQ(read_bytes, size);
+}
+
+// Test that a send with the MSG_DONTWAIT flag and a buffer larger than the
+// send buffer size will not write the whole thing.
+TEST_P(TcpSocketTest, LargeSendDontWait) {
+  // Allocate a buffer three times the size of the send buffer. We do this
+  // with a vector to avoid allocating on the stack.
+ int size = 3 * sendbuf_size_;
+ std::vector<char> buf(size);
+
+ // Try to write the whole thing with MSG_DONTWAIT flag, which can
+ // return a partial write.
+ int n;
+ ASSERT_THAT(n = RetryEINTR(send)(s_, buf.data(), size, MSG_DONTWAIT),
+ SyscallSucceeds());
+
+ // We should have written something, but not the whole thing.
+ EXPECT_GT(n, 0);
+ EXPECT_LT(n, size);
+}
+
+// Test that a send on a non-blocking socket with a buffer larger than the
+// send buffer will not write the whole thing at once.
+TEST_P(TcpSocketTest, NonblockingLargeSend) {
+ // Set the FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(s_, F_SETFL, opts), SyscallSucceeds());
+
+  // Allocate a buffer three times the size of the send buffer. We do this
+  // with a vector to avoid allocating on the stack.
+ int size = 3 * sendbuf_size_;
+ std::vector<char> buf(size);
+
+ // Try to write the whole thing.
+ int n;
+ ASSERT_THAT(n = RetryEINTR(send)(s_, buf.data(), size, 0), SyscallSucceeds());
+
+ // We should have written something, but not the whole thing.
+ EXPECT_GT(n, 0);
+ EXPECT_LT(n, size);
+}
+
+// Same test as above, but calls send instead of write.
+TEST_P(TcpSocketTest, BlockingLargeSend_NoRandomSave) {
+  // Allocate a buffer three times the size of the send buffer. We do this
+  // with a vector to avoid allocating on the stack.
+ int size = 3 * sendbuf_size_;
+ std::vector<char> writebuf(size);
+
+ // Read the data back in a loop from a separate thread.
+ int read_bytes = 0;
+ ScopedThread t([this, &read_bytes]() {
+ // Avoid interrupting the blocking write in main thread.
+ const DisableSave ds;
+
+ // Take ownership of the FD so that we close it on failure. This will
+ // unblock the blocking write below.
+ FileDescriptor fd(t_);
+ t_ = -1;
+
+ char readbuf[2500] = {};
+ int n = -1;
+ while (n != 0) {
+ ASSERT_THAT(n = RetryEINTR(read)(fd.get(), &readbuf, sizeof(readbuf)),
+ SyscallSucceeds());
+ read_bytes += n;
+ }
+ });
+
+ // Try to send the whole thing.
+ int n;
+ ASSERT_THAT(n = SendFd(s_, writebuf.data(), size, 0), SyscallSucceeds());
+
+ // We should have written the whole thing.
+ EXPECT_EQ(n, size);
+ EXPECT_THAT(close(s_), SyscallSucceedsWithValue(0));
+ s_ = -1;
+ t.Join();
+
+ // We should have read the whole thing.
+ EXPECT_EQ(read_bytes, size);
+}
+
+// Test that polling on a socket with a full send buffer will block.
+TEST_P(TcpSocketTest, PollWithFullBufferBlocks) {
+ // Set the FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(s_, F_SETFL, opts), SyscallSucceeds());
+
+ // Set TCP_NODELAY, which will cause Linux to fill the receive buffer from
+ // the send buffer as quickly as possible. This way we can fill up both
+ // buffers faster.
+ constexpr int tcp_nodelay_flag = 1;
+ ASSERT_THAT(setsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay_flag,
+ sizeof(tcp_nodelay_flag)),
+ SyscallSucceeds());
+
+ // Set a 256KB send/receive buffer.
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+
+ // Create a large buffer that will be used for sending.
+ std::vector<char> buf(1 << 16);
+
+ // Write until we receive an error.
+ while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) {
+ // Sleep to give Linux a chance to move data from the send buffer to the
+ // receive buffer.
+ usleep(10000); // 10ms.
+ }
+ // The last error should have been EWOULDBLOCK.
+ ASSERT_EQ(errno, EWOULDBLOCK);
+
+ // Now polling on the FD with a timeout should return 0 corresponding to no
+ // FDs ready.
+ struct pollfd poll_fd = {s_, POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceedsWithValue(0));
+}
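+
+// In the test above, poll() returning 0 means the timeout expired with no
+// events ready; POLLOUT would only be reported once enough space freed up in
+// the send buffer.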
+
+TEST_P(TcpSocketTest, MsgTrunc) {
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+
+ // Check that we didn't get anything.
+ char zeros[sizeof(received_data)] = {};
+ EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data)));
+}
+
+// MSG_CTRUNC is a return flag, but Linux allows it to be set in the input
+// flags without returning an error.
+TEST_P(TcpSocketTest, MsgTruncWithCtrunc) {
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2,
+ MSG_TRUNC | MSG_CTRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+
+ // Check that we didn't get anything.
+ char zeros[sizeof(received_data)] = {};
+ EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data)));
+}
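+
+// For reference, MSG_CTRUNC is normally an output flag: recvmsg sets it in
+// msghdr::msg_flags when ancillary data is truncated. A minimal sketch of
+// checking for it (not exercised by this test):
+//
+//   struct msghdr msg = {};
+//   char control[1];  // Deliberately too small for any control message.
+//   msg.msg_control = control;
+//   msg.msg_controllen = sizeof(control);
+//   recvmsg(fd, &msg, 0);
+//   bool truncated = (msg.msg_flags & MSG_CTRUNC) != 0;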
+
+// This test verifies that MSG_CTRUNC alone does nothing when specified on
+// input.
+TEST_P(TcpSocketTest, MsgTruncWithCtruncOnly) {
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2,
+ MSG_CTRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+
+ // Since MSG_CTRUNC had no effect here, it should not behave like MSG_TRUNC.
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data) / 2));
+}
+
+TEST_P(TcpSocketTest, MsgTruncLargeSize) {
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data) * 2] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(t_, received_data, sizeof(received_data), MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that we didn't get anything.
+ char zeros[sizeof(received_data)] = {};
+ EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data)));
+}
+
+TEST_P(TcpSocketTest, MsgTruncPeek) {
+ char sent_data[512];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(RetryEINTR(send)(s_, sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ char received_data[sizeof(sent_data)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data) / 2,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data) / 2));
+
+ // Check that we didn't get anything.
+ char zeros[sizeof(received_data)] = {};
+ EXPECT_EQ(0, memcmp(zeros, received_data, sizeof(received_data)));
+
+ // Check that we can still get all of the data.
+ ASSERT_THAT(RetryEINTR(recv)(t_, received_data, sizeof(received_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+}
+
+TEST_P(TcpSocketTest, NoDelayDefault) {
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(TcpSocketTest, SetNoDelay) {
+ ASSERT_THAT(
+ setsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(getsockopt(s_, IPPROTO_TCP, TCP_NODELAY, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+#ifndef TCP_INQ
+#define TCP_INQ 36
+#endif
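+
+// When TCP_INQ is enabled via setsockopt, recvmsg attaches a SOL_TCP/TCP_INQ
+// control message carrying an int: the number of bytes still queued in the
+// receive buffer after the read. The tests below exercise both the option
+// and the control message it produces.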
+
+TEST_P(TcpSocketTest, TcpInqSetSockOpt) {
+ char buf[1024];
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // TCP_INQ is disabled by default.
+ int val = -1;
+ socklen_t slen = sizeof(val);
+ EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen),
+ SyscallSucceedsWithValue(0));
+ ASSERT_EQ(val, 0);
+
+ // Try to set TCP_INQ.
+ val = 1;
+ EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+ val = -1;
+ slen = sizeof(val);
+ EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen),
+ SyscallSucceedsWithValue(0));
+ ASSERT_EQ(val, 1);
+
+ // Try to unset TCP_INQ.
+ val = 0;
+ EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+ val = -1;
+ slen = sizeof(val);
+ EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen),
+ SyscallSucceedsWithValue(0));
+ ASSERT_EQ(val, 0);
+}
+
+TEST_P(TcpSocketTest, TcpInq) {
+ char buf[1024];
+ // Write more than one TCP segment.
+ int size = sizeof(buf);
+ int kChunk = sizeof(buf) / 4;
+ for (int i = 0; i < size; i += kChunk) {
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, kChunk),
+ SyscallSucceedsWithValue(kChunk));
+ }
+
+ int val = 1;
+ kChunk = sizeof(buf) / 2;
+ EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+
+ // Wait until all of the data is in the receive queue.
+ while (true) {
+ ASSERT_THAT(ioctl(t_, TIOCINQ, &size), SyscallSucceeds());
+ if (size == sizeof(buf)) {
+ break;
+ }
+ absl::SleepFor(absl::Milliseconds(10));
+ }
+
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(sizeof(int)));
+ size = sizeof(buf);
+ struct iovec iov;
+ for (int i = 0; size != 0; i += kChunk) {
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ iov.iov_base = buf;
+ iov.iov_len = kChunk;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ ASSERT_THAT(RetryEINTR(recvmsg)(t_, &msg, 0),
+ SyscallSucceedsWithValue(kChunk));
+ size -= kChunk;
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_TCP);
+ ASSERT_EQ(cmsg->cmsg_type, TCP_INQ);
+
+ int inq = 0;
+ memcpy(&inq, CMSG_DATA(cmsg), sizeof(int));
+ ASSERT_EQ(inq, size);
+ }
+}
+
+TEST_P(TcpSocketTest, Tiocinq) {
+ char buf[1024];
+ size_t size = sizeof(buf);
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, size), SyscallSucceedsWithValue(size));
+
+ uint32_t seed = time(nullptr);
+ const size_t max_chunk = size / 10;
+ while (size > 0) {
+ size_t chunk = (rand_r(&seed) % max_chunk) + 1;
+ ssize_t read = RetryEINTR(recvfrom)(t_, buf, chunk, 0, nullptr, nullptr);
+ ASSERT_THAT(read, SyscallSucceeds());
+ size -= read;
+
+ int inq = 0;
+ ASSERT_THAT(ioctl(t_, TIOCINQ, &inq), SyscallSucceeds());
+ ASSERT_EQ(inq, size);
+ }
+}
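+
+// Note: on Linux, TIOCINQ is simply an alias for FIONREAD; on a TCP socket it
+// reports the number of unread bytes in the receive queue, as asserted above.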
+
+TEST_P(TcpSocketTest, TcpSCMPriority) {
+ char buf[1024];
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ int val = 1;
+ EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_TIMESTAMP, &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(sizeof(struct timeval)) +
+ CMSG_SPACE(sizeof(int)));
+ struct iovec iov;
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ ASSERT_THAT(RetryEINTR(recvmsg)(t_, &msg, 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ // TODO(b/78348848): SO_TIMESTAMP isn't implemented for TCP sockets.
+ if (!IsRunningOnGvisor() || cmsg->cmsg_level == SOL_SOCKET) {
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval)));
+
+ cmsg = CMSG_NXTHDR(&msg, cmsg);
+ ASSERT_NE(cmsg, nullptr);
+ }
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_TCP);
+ ASSERT_EQ(cmsg->cmsg_type, TCP_INQ);
+
+ int inq = 0;
+ memcpy(&inq, CMSG_DATA(cmsg), sizeof(int));
+ ASSERT_EQ(inq, 0);
+
+ cmsg = CMSG_NXTHDR(&msg, cmsg);
+ ASSERT_EQ(cmsg, nullptr);
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest,
+ ::testing::Values(AF_INET, AF_INET6));
+
+// Fixture for tests parameterized by address family that don't need the
+// connected-socket setup done by TcpSocketTest.
+using SimpleTcpSocketTest = ::testing::TestWithParam<int>;
+
+TEST_P(SimpleTcpSocketTest, SendUnconnected) {
+ int fd;
+ ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP),
+ SyscallSucceeds());
+ FileDescriptor sock_fd(fd);
+
+ char data = '\0';
+ EXPECT_THAT(RetryEINTR(send)(fd, &data, sizeof(data), 0),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(SimpleTcpSocketTest, SendtoWithoutAddressUnconnected) {
+ int fd;
+ ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP),
+ SyscallSucceeds());
+ FileDescriptor sock_fd(fd);
+
+ char data = '\0';
+ EXPECT_THAT(RetryEINTR(sendto)(fd, &data, sizeof(data), 0, nullptr, 0),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(SimpleTcpSocketTest, SendtoWithAddressUnconnected) {
+ int fd;
+ ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP),
+ SyscallSucceeds());
+ FileDescriptor sock_fd(fd);
+
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ char data = '\0';
+ EXPECT_THAT(
+ RetryEINTR(sendto)(fd, &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(SimpleTcpSocketTest, GetPeerNameUnconnected) {
+ int fd;
+ ASSERT_THAT(fd = socket(GetParam(), SOCK_STREAM, IPPROTO_TCP),
+ SyscallSucceeds());
+ FileDescriptor sock_fd(fd);
+
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(TcpSocketTest, FullBuffer) {
+ // Set both FDs to be blocking.
+ int flags = 0;
+ ASSERT_THAT(flags = fcntl(s_, F_GETFL), SyscallSucceeds());
+ EXPECT_THAT(fcntl(s_, F_SETFL, flags & ~O_NONBLOCK), SyscallSucceeds());
+ flags = 0;
+ ASSERT_THAT(flags = fcntl(t_, F_GETFL), SyscallSucceeds());
+ EXPECT_THAT(fcntl(t_, F_SETFL, flags & ~O_NONBLOCK), SyscallSucceeds());
+
+ // 2500 was chosen as a small value that can be set on Linux.
+ int set_snd = 2500;
+ EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &set_snd, sizeof(set_snd)),
+ SyscallSucceedsWithValue(0));
+ int get_snd = -1;
+ socklen_t get_snd_len = sizeof(get_snd);
+ EXPECT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &get_snd, &get_snd_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_snd_len, sizeof(get_snd));
+ EXPECT_GT(get_snd, 0);
+
+ // 2500 was chosen as a small value that can be set on Linux and gVisor.
+ int set_rcv = 2500;
+ EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &set_rcv, sizeof(set_rcv)),
+ SyscallSucceedsWithValue(0));
+ int get_rcv = -1;
+ socklen_t get_rcv_len = sizeof(get_rcv);
+ EXPECT_THAT(getsockopt(t_, SOL_SOCKET, SO_RCVBUF, &get_rcv, &get_rcv_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_rcv_len, sizeof(get_rcv));
+ EXPECT_GE(get_rcv, 2500);
+
+ // Quick sanity test.
+ EXPECT_LT(get_snd + get_rcv, 2500 * IOV_MAX);
+
+ char data[2500] = {};
+ std::vector<struct iovec> iovecs;
+ for (int i = 0; i < IOV_MAX; i++) {
+ struct iovec iov = {};
+ iov.iov_base = data;
+ iov.iov_len = sizeof(data);
+ iovecs.push_back(iov);
+ }
+ ScopedThread t([this, &iovecs]() {
+ int result = -1;
+ EXPECT_THAT(result = RetryEINTR(writev)(s_, iovecs.data(), iovecs.size()),
+ SyscallSucceeds());
+ EXPECT_GT(result, 1);
+ EXPECT_LT(result, sizeof(data) * iovecs.size());
+ });
+
+ char recv = 0;
+ EXPECT_THAT(RetryEINTR(read)(t_, &recv, 1), SyscallSucceedsWithValue(1));
+ EXPECT_THAT(close(t_), SyscallSucceedsWithValue(0));
+ t_ = -1;
+}
+
+TEST_P(TcpSocketTest, PollAfterShutdown) {
+ ScopedThread client_thread([this]() {
+ EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+ });
+
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+}
+
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ const FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Set the FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ // Polling for POLLOUT should return 1 once the connection attempt is
+ // refused and the error is pending on the socket.
+ struct pollfd poll_fd = {s.get(), POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ int err;
+ socklen_t optlen = sizeof(err);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(err, ECONNREFUSED);
+}
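+
+// The sequence above (connect() returning EINPROGRESS, poll() waiting for
+// POLLOUT, then getsockopt(SO_ERROR)) is the canonical non-blocking connect
+// pattern; reading SO_ERROR both reports and clears the pending error.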
+
+TEST_P(SimpleTcpSocketTest, NonBlockingConnect) {
+ const FileDescriptor listener =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(
+ bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Set the FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
+
+ ASSERT_THAT(getsockname(listener.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ int t;
+ ASSERT_THAT(t = RetryEINTR(accept)(listener.get(), nullptr, nullptr),
+ SyscallSucceeds());
+
+ // The connection has been accepted, so polling for POLLOUT should return
+ // 1 once the handshake completes and the socket is writable.
+ struct pollfd poll_fd = {s.get(), POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ int err;
+ socklen_t optlen = sizeof(err);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(err, 0);
+
+ EXPECT_THAT(close(t), SyscallSucceeds());
+}
+
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) {
+ const FileDescriptor listener =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(
+ bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+
+ ASSERT_THAT(getsockname(listener.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ int t;
+ ASSERT_THAT(t = RetryEINTR(accept)(listener.get(), nullptr, nullptr),
+ SyscallSucceeds());
+
+ EXPECT_THAT(close(t), SyscallSucceeds());
+
+ // The handshake has completed (even though the peer has already closed),
+ // so polling for POLLOUT should return 1.
+ struct pollfd poll_fd = {s.get(), POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EISCONN));
+}
+
+// Test that we get an ECONNREFUSED with a blocking socket when no one is
+// listening on the other end.
+TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
+
+ // Avoid triggering save in the destructor of s.
+ EXPECT_THAT(close(s.release()), SyscallSucceeds());
+}
+
+// Test that connecting to a non-listening port and thus receiving a RST is
+// handled appropriately by the socket - the port that the socket was bound to
+// is released and the expected error is returned.
+TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
+ // Create a socket that is known to not be listening. As it is bound but
+ // not listening, when another socket connects to the port the connection
+ // will be refused.
+ FileDescriptor bound_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage bound_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t bound_addrlen = sizeof(bound_addr);
+
+ ASSERT_THAT(
+ bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallSucceeds());
+
+ // Get the address the socket is bound to because the port is chosen by
+ // the stack.
+ ASSERT_THAT(getsockname(bound_s.get(),
+ reinterpret_cast<struct sockaddr*>(&bound_addr),
+ &bound_addrlen),
+ SyscallSucceeds());
+
+ // Create, initialize, and bind the socket that is used to test connecting to
+ // the non-listening port.
+ FileDescriptor client_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // Initialize client address to the loopback one.
+ sockaddr_storage client_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t client_addrlen = sizeof(client_addr);
+
+ ASSERT_THAT(
+ bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr),
+ client_addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockname(client_s.get(),
+ reinterpret_cast<struct sockaddr*>(&client_addr),
+ &client_addrlen),
+ SyscallSucceeds());
+
+ // Now the test: connect to the bound but not listening socket with the
+ // client socket. The bound socket should return a RST and cause the client
+ // socket to return an error and clean itself up immediately.
+ // Returning ECONNREFUSED diverges from RFC 793, page 37, but matches what
+ // Linux does.
+ ASSERT_THAT(connect(client_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
+
+ FileDescriptor new_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Test binding to the address from the client socket. This should be okay
+ // if it was dropped correctly.
+ ASSERT_THAT(
+ bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr),
+ client_addrlen),
+ SyscallSucceeds());
+
+ // Attempt #2: with the new socket and the reused address, our connect
+ // should fail in the same way as before, not with EADDRINUSE.
+ ASSERT_THAT(connect(client_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+// Test that we get an ECONNREFUSED with a nonblocking socket.
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ // We don't need to specify any events to get POLLHUP or POLLERR because
+ // poll reports them regardless of the requested event set.
+ struct pollfd poll_fd = {s.get(), /*events=*/0, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 1000), SyscallSucceedsWithValue(1));
+
+ // The ECONNREFUSED should cause us to be woken up with POLLHUP.
+ EXPECT_NE(poll_fd.revents & (POLLHUP | POLLERR), 0);
+
+ // Avoid triggering save in the destructor of s.
+ EXPECT_THAT(close(s.release()), SyscallSucceeds());
+}
+
+// Test that setting a supported congestion control algorithm succeeds for an
+// unconnected TCP socket.
+TEST_P(SimpleTcpSocketTest, SetCongestionControlSucceedsForSupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ {
+ const char kSetCC[kTcpCaNameMax] = "reno";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // We ignore optlen here as the Linux kernel sets optlen to the lower of
+ // the size of the buffer passed in or kTcpCaNameMax, not the length of the
+ // congestion control algorithm's actual name.
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, strlen(kSetCC)));
+ }
+ {
+ const char kSetCC[kTcpCaNameMax] = "cubic";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // We ignore optlen here as the Linux kernel sets optlen to the lower of
+ // the size of the buffer passed in or kTcpCaNameMax, not the length of the
+ // congestion control algorithm's actual name.
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, strlen(kSetCC)));
+ }
+}
+
+// This test verifies that getsockopt(...TCP_CONGESTION) behaviour is
+// consistent between Linux and gVisor when the passed-in buffer is smaller
+// than kTcpCaNameMax.
+TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionShortReadBuffer) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ {
+ // Verify that getsockopt/setsockopt work with buffers smaller than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[sizeof(kSetCC)];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(sizeof(got_cc), optlen);
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc)));
+ }
+}
+
+// This test verifies that getsockopt(...TCP_CONGESTION) behaviour is
+// consistent between Linux and gVisor when the passed-in buffer is larger
+// than kTcpCaNameMax.
+TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionLargeReadBuffer) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ {
+ // Verify that getsockopt works with buffers larger than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax + 5];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // Linux copies the minimum of kTcpCaNameMax or the length of the passed in
+ // buffer and sets optlen to the number of bytes actually copied
+ // irrespective of the actual length of the congestion control name.
+ EXPECT_EQ(kTcpCaNameMax, optlen);
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+}
+
+// Test that setting an unsupported congestion control algorithm fails for an
+// unconnected TCP socket.
+TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char old_cc[kTcpCaNameMax];
+ socklen_t optlen = sizeof(old_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &old_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+
+ const char kSetCC[] = "invalid_ca_kSetCC";
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_TCP, TCP_CONGESTION, &kSetCC, strlen(kSetCC)),
+ SyscallFailsWithErrno(ENOENT));
+
+ char got_cc[kTcpCaNameMax];
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // We ignore optlen here as the Linux kernel sets optlen to the lower of
+ // the size of the buffer passed in or kTcpCaNameMax, not the length of the
+ // congestion control algorithm's actual name.
+ EXPECT_EQ(0, memcmp(got_cc, old_cc, kTcpCaNameMax));
+}
+
+TEST_P(SimpleTcpSocketTest, MaxSegDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ constexpr int kDefaultMSS = 536;
+ int tcp_max_seg;
+ socklen_t optlen = sizeof(tcp_max_seg);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, &optlen),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_EQ(kDefaultMSS, tcp_max_seg);
+ EXPECT_EQ(sizeof(tcp_max_seg), optlen);
+}
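+
+// 536 is the classic default MSS (RFC 879): the 576-byte IPv4 minimum
+// reassembly buffer size less 40 bytes of IP and TCP headers.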
+
+TEST_P(SimpleTcpSocketTest, SetMaxSeg) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ constexpr int kDefaultMSS = 536;
+ constexpr int kTCPMaxSeg = 1024;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &kTCPMaxSeg,
+ sizeof(kTCPMaxSeg)),
+ SyscallSucceedsWithValue(0));
+
+ // Linux actually never returns the user_mss value. It will always return the
+ // default MSS value defined above for an unconnected socket and always return
+ // the actual current MSS for a connected one.
+ int optval;
+ socklen_t optlen = sizeof(optval);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &optval, &optlen),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_EQ(kDefaultMSS, optval);
+ EXPECT_EQ(sizeof(optval), optlen);
+}
+
+TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ {
+ constexpr int tcp_max_seg = 10;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg,
+ sizeof(tcp_max_seg)),
+ SyscallFailsWithErrno(EINVAL));
+ }
+ {
+ constexpr int tcp_max_seg = 75000;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg,
+ sizeof(tcp_max_seg)),
+ SyscallFailsWithErrno(EINVAL));
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ {
+ constexpr int kTCPUserTimeout = -1;
+ EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kTCPUserTimeout, sizeof(kTCPUserTimeout)),
+ SyscallFailsWithErrno(EINVAL));
+ }
+
+ // kTCPUserTimeout is in milliseconds.
+ constexpr int kTCPUserTimeout = 100;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kTCPUserTimeout, sizeof(kTCPUserTimeout)),
+ SyscallSucceedsWithValue(0));
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPUserTimeout);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // A negative TCP_DEFER_ACCEPT value is the same as setting it to zero.
+ constexpr int kNeg = -1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // kTCPDeferAccept is in seconds.
+ // NOTE: Linux translates seconds to a number of retransmissions and back
+ // again, so only certain values round-trip exactly. That's why we use 3
+ // here; a value of 5 would come back from getsockopt as 7.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPDeferAccept);
+}
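+
+// A minimal sketch of the rounding described above, assuming exponential
+// retransmission backoff with a 1-second initial timeout (an assumption
+// mirroring the kernel's secs_to_retrans/retrans_to_secs helpers, not code
+// used by this test):
+//
+//   int SecsToRetrans(int secs) {
+//     int retrans = 0, timeout = 1, covered = 0;
+//     while (covered < secs) { covered += timeout; timeout *= 2; retrans++; }
+//     return retrans;
+//   }
+//   int RetransToSecs(int retrans) {
+//     int secs = 0, timeout = 1;
+//     while (retrans-- > 0) { secs += timeout; timeout *= 2; }
+//     return secs;
+//   }
+//
+// RetransToSecs(SecsToRetrans(3)) == 3, but RetransToSecs(SecsToRetrans(5))
+// == 7, matching the behavior described in the NOTE above.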
+
+TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char buf[1];
+ EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen);
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPSynCntLessThanOne) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int default_syn_cnt = get;
+
+ {
+ // TCP_SYNCNT less than 1 should be rejected with an EINVAL.
+ constexpr int kZero = 0;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kZero, sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+
+ // TCP_SYNCNT less than 1 should be rejected with an EINVAL.
+ constexpr int kNeg = -1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kNeg, sizeof(kNeg)),
+ SyscallFailsWithErrno(EINVAL));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(default_syn_cnt, get);
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, GetTCPSynCntDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ constexpr int kDefaultSynCnt = 6;
+
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kDefaultSynCnt);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPSynCntGreaterThanOne) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ constexpr int kTCPSynCnt = 20;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt,
+ sizeof(kTCPSynCnt)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPSynCnt);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPSynCntAboveMax) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int default_syn_cnt = get;
+ {
+ constexpr int kTCPSynCnt = 256;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt,
+ sizeof(kTCPSynCnt)),
+ SyscallFailsWithErrno(EINVAL));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, default_syn_cnt);
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPWindowClampBelowMinRcvBuf) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Discover the minimum receive buffer size by setting a really low value
+ // for the receive buffer.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ // Now retrieve the minimum value for SO_RCVBUF as the set above should
+ // have caused SO_RCVBUF for the socket to be set to the minimum.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int min_so_rcvbuf = get;
+
+ {
+ // TCP_WINDOW_CLAMP less than min_so_rcvbuf/2 should be set to
+ // min_so_rcvbuf/2.
+ int below_half_min_rcvbuf = min_so_rcvbuf / 2 - 1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &below_half_min_rcvbuf, sizeof(below_half_min_rcvbuf)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(min_so_rcvbuf / 2, get);
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPWindowClampZeroClosedSocket) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ constexpr int kZero = 0;
+ ASSERT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kZero);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Discover the minimum receive buffer size by setting a really low value
+ // for the receive buffer.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ // Now retrieve the minimum value for SO_RCVBUF as the set above should
+ // have caused SO_RCVBUF for the socket to be set to the minimum.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int min_so_rcvbuf = get;
+
+ {
+ int above_half_min_rcv_buf = min_so_rcvbuf / 2 + 1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &above_half_min_rcv_buf, sizeof(above_half_min_rcv_buf)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(above_half_min_rcv_buf, get);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
+ ::testing::Values(AF_INET, AF_INET6));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/tgkill.cc b/test/syscalls/linux/tgkill.cc
new file mode 100644
index 000000000..80acae5de
--- /dev/null
+++ b/test/syscalls/linux/tgkill.cc
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(TgkillTest, InvalidTID) {
+ EXPECT_THAT(tgkill(getpid(), -1, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(tgkill(getpid(), 0, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TgkillTest, InvalidTGID) {
+ EXPECT_THAT(tgkill(-1, gettid(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(tgkill(0, gettid(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TgkillTest, ValidInput) {
+ EXPECT_THAT(tgkill(getpid(), gettid(), 0), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc
new file mode 100644
index 000000000..e75bba669
--- /dev/null
+++ b/test/syscalls/linux/time.cc
@@ -0,0 +1,107 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <sys/time.h>
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "test/util/proc_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr long kFudgeSeconds = 5;
+
+#if defined(__x86_64__)
+// Mimics the time(2) wrapper from glibc prior to 2.15.
+time_t vsyscall_time(time_t* t) {
+ constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
+ return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t);
+}
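+
+// The vsyscall page is mapped at the fixed address 0xffffffffff600000 with a
+// 1024-byte slot per call: gettimeofday at +0x0, time at +0x400, and getcpu
+// at +0x800, which is where the entry addresses used in this file come from.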
+
+TEST(TimeTest, VsyscallTime_Succeeds) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled()));
+
+ time_t t1, t2;
+
+ {
+ const DisableSave ds; // Timing assertions.
+ EXPECT_THAT(time(&t1), SyscallSucceeds());
+ EXPECT_THAT(vsyscall_time(&t2), SyscallSucceeds());
+ }
+
+ // Time should be monotonic.
+ EXPECT_LE(static_cast<long>(t1), static_cast<long>(t2));
+
+ // Check that it's within kFudgeSeconds.
+ EXPECT_LE(static_cast<long>(t2), static_cast<long>(t1) + kFudgeSeconds);
+
+ // Redo with save.
+ EXPECT_THAT(time(&t1), SyscallSucceeds());
+ EXPECT_THAT(vsyscall_time(&t2), SyscallSucceeds());
+
+ // Time should be monotonic.
+ EXPECT_LE(static_cast<long>(t1), static_cast<long>(t2));
+}
+
+TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) {
+ EXPECT_EXIT(vsyscall_time(reinterpret_cast<time_t*>(0x1)),
+ ::testing::KilledBySignal(SIGSEGV), "");
+}
+
+// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2.
+int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) {
+ constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000;
+ return reinterpret_cast<int (*)(struct timeval*, struct timezone*)>(
+ kVsyscallGettimeofdayEntry)(tv, tz);
+}
+
+TEST(TimeTest, VsyscallGettimeofday_Succeeds) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled()));
+
+ struct timeval tv1, tv2;
+ struct timezone tz1, tz2;
+
+ {
+ const DisableSave ds; // Timing assertions.
+ EXPECT_THAT(gettimeofday(&tv1, &tz1), SyscallSucceeds());
+ EXPECT_THAT(vsyscall_gettimeofday(&tv2, &tz2), SyscallSucceeds());
+ }
+
+ // See above.
+ EXPECT_LE(static_cast<long>(tv1.tv_sec), static_cast<long>(tv2.tv_sec));
+ EXPECT_LE(static_cast<long>(tv2.tv_sec),
+ static_cast<long>(tv1.tv_sec) + kFudgeSeconds);
+
+ // Redo with save.
+ EXPECT_THAT(gettimeofday(&tv1, &tz1), SyscallSucceeds());
+ EXPECT_THAT(vsyscall_gettimeofday(&tv2, &tz2), SyscallSucceeds());
+}
+
+TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled()));
+
+ EXPECT_EXIT(vsyscall_gettimeofday(reinterpret_cast<struct timeval*>(0x1),
+ reinterpret_cast<struct timezone*>(0x1)),
+ ::testing::KilledBySignal(SIGSEGV), "");
+}
+#endif
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc
new file mode 100644
index 000000000..c4f8fdd7a
--- /dev/null
+++ b/test/syscalls/linux/timerfd.cc
@@ -0,0 +1,273 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <poll.h>
+#include <sys/timerfd.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Wrapper around timerfd_create(2) that returns a FileDescriptor.
+PosixErrorOr<FileDescriptor> TimerfdCreate(int clockid, int flags) {
+ int fd = timerfd_create(clockid, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "timerfd_create failed");
+ }
+ return FileDescriptor(fd);
+}
+
+// In tests that race a timerfd with a sleep, some slack is required because:
+//
+// - Timerfd expirations are asynchronous with respect to nanosleeps.
+//
+// - Because clock_gettime(CLOCK_MONOTONIC) is implemented through the VDSO,
+// it technically uses a closely-related, but distinct, time domain from the
+// CLOCK_MONOTONIC used to trigger timerfd expirations. The same applies to
+// CLOCK_BOOTTIME which is an alias for CLOCK_MONOTONIC.
+absl::Duration TimerSlack() { return absl::Milliseconds(500); }
+
+class TimerfdTest : public ::testing::TestWithParam<int> {};
+
+TEST_P(TimerfdTest, IsInitiallyStopped) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ struct itimerspec its = {};
+ ASSERT_THAT(timerfd_gettime(tfd.get(), &its), SyscallSucceeds());
+ EXPECT_EQ(0, its.it_value.tv_sec);
+ EXPECT_EQ(0, its.it_value.tv_nsec);
+}
+
+TEST_P(TimerfdTest, SingleShot) {
+ constexpr absl::Duration kDelay = absl::Seconds(1);
+
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ struct itimerspec its = {};
+ its.it_value = absl::ToTimespec(kDelay);
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+
+ // The timer should fire exactly once since the interval is zero.
+ absl::SleepFor(kDelay + TimerSlack());
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallSucceedsWithValue(sizeof(uint64_t)));
+ EXPECT_EQ(1, val);
+}
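+
+// A read from a timerfd always transfers exactly 8 bytes: a uint64_t count of
+// timer expirations since the last read (or since the timer was armed), and
+// the read resets that count to zero.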
+
+TEST_P(TimerfdTest, Periodic) {
+ constexpr absl::Duration kDelay = absl::Seconds(1);
+ constexpr int kPeriods = 3;
+
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ struct itimerspec its = {};
+ its.it_value = absl::ToTimespec(kDelay);
+ its.it_interval = absl::ToTimespec(kDelay);
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+
+ // Expect to see at least kPeriods expirations. More may occur due to the
+ // timer slack, or due to delays from scheduling or save/restore.
+ absl::SleepFor(kPeriods * kDelay + TimerSlack());
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallSucceedsWithValue(sizeof(uint64_t)));
+ EXPECT_GE(val, kPeriods);
+}
+
+TEST_P(TimerfdTest, BlockingRead) {
+ constexpr absl::Duration kDelay = absl::Seconds(3);
+
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ struct itimerspec its = {};
+ its.it_value.tv_sec = absl::ToInt64Seconds(kDelay);
+ auto const start_time = absl::Now();
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+
+ // read should block until the timer fires.
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallSucceedsWithValue(sizeof(uint64_t)));
+ auto const end_time = absl::Now();
+ EXPECT_EQ(1, val);
+ EXPECT_GE((end_time - start_time) + TimerSlack(), kDelay);
+}
+
+TEST_P(TimerfdTest, NonblockingRead_NoRandomSave) {
+ constexpr absl::Duration kDelay = absl::Seconds(5);
+
+ auto const tfd =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK));
+
+ // Since the timer is initially disabled and has never fired, read should
+ // return EAGAIN.
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallFailsWithErrno(EAGAIN));
+
+ DisableSave ds; // Timing-sensitive.
+
+ // Arm the timer.
+ struct itimerspec its = {};
+ its.it_value.tv_sec = absl::ToInt64Seconds(kDelay);
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+
+ // Since the timer has not yet fired, read should return EAGAIN.
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallFailsWithErrno(EAGAIN));
+
+ ds.reset(); // No longer timing-sensitive.
+
+ // After the timer fires, read should indicate 1 expiration.
+ absl::SleepFor(kDelay + TimerSlack());
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallSucceedsWithValue(sizeof(uint64_t)));
+ EXPECT_EQ(1, val);
+
+ // The successful read should have reset the number of expirations.
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(TimerfdTest, BlockingPoll_SetTimeResetsExpirations) {
+ constexpr absl::Duration kDelay = absl::Seconds(3);
+
+ auto const tfd =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK));
+ struct itimerspec its = {};
+ its.it_value.tv_sec = absl::ToInt64Seconds(kDelay);
+ auto const start_time = absl::Now();
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+
+ // poll should block until the timer fires.
+ struct pollfd pfd = {};
+ pfd.fd = tfd.get();
+ pfd.events = POLLIN;
+ ASSERT_THAT(poll(&pfd, /* nfds = */ 1,
+ /* timeout = */ 2 * absl::ToInt64Seconds(kDelay) * 1000),
+ SyscallSucceedsWithValue(1));
+ auto const end_time = absl::Now();
+ EXPECT_EQ(POLLIN, pfd.revents);
+ EXPECT_GE((end_time - start_time) + TimerSlack(), kDelay);
+
+ // Call timerfd_settime again with a value of 0. This should reset the number
+ // of expirations to 0, causing read to return EAGAIN since the timerfd is
+ // non-blocking.
+ its.it_value.tv_sec = 0;
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(TimerfdTest, SetAbsoluteTime) {
+ constexpr absl::Duration kDelay = absl::Seconds(3);
+
+ // Use a non-blocking timerfd so that if TFD_TIMER_ABSTIME is incorrectly
+ // non-functional, we get EAGAIN rather than a test timeout.
+ auto const tfd =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK));
+ struct itimerspec its = {};
+ ASSERT_THAT(clock_gettime(GetParam(), &its.it_value), SyscallSucceeds());
+ its.it_value.tv_sec += absl::ToInt64Seconds(kDelay);
+ ASSERT_THAT(timerfd_settime(tfd.get(), TFD_TIMER_ABSTIME, &its, nullptr),
+ SyscallSucceeds());
+
+ absl::SleepFor(kDelay + TimerSlack());
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallSucceedsWithValue(sizeof(uint64_t)));
+ EXPECT_EQ(1, val);
+}
+
+TEST_P(TimerfdTest, IllegalSeek) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ if (!IsRunningWithVFS1()) {
+ EXPECT_THAT(lseek(tfd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+ }
+}
+
+TEST_P(TimerfdTest, IllegalPread) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ int val;
+ EXPECT_THAT(pread(tfd.get(), &val, sizeof(val), 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST_P(TimerfdTest, IllegalPwrite) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ EXPECT_THAT(pwrite(tfd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST_P(TimerfdTest, IllegalWrite) {
+ auto const tfd =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK));
+ uint64_t val = 0;
+ EXPECT_THAT(write(tfd.get(), &val, sizeof(val)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+std::string PrintClockId(::testing::TestParamInfo<int> info) {
+ switch (info.param) {
+ case CLOCK_MONOTONIC:
+ return "CLOCK_MONOTONIC";
+ case CLOCK_BOOTTIME:
+ return "CLOCK_BOOTTIME";
+ default:
+ return absl::StrCat(info.param);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(AllTimerTypes, TimerfdTest,
+ ::testing::Values(CLOCK_MONOTONIC, CLOCK_BOOTTIME),
+ PrintClockId);
+
+TEST(TimerfdClockRealtimeTest, ClockRealtime) {
+ // Since CLOCK_REALTIME can, by definition, change, we can't make any
+ // non-flaky assertions about the amount of time it takes for a
+ // CLOCK_REALTIME-based timer to expire. Just check that it expires at all,
+ // and hope it happens before the test times out.
+ constexpr int kDelaySecs = 1;
+
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(CLOCK_REALTIME, 0));
+ struct itimerspec its = {};
+ its.it_value.tv_sec = kDelaySecs;
+ ASSERT_THAT(timerfd_settime(tfd.get(), /* flags = */ 0, &its, nullptr),
+ SyscallSucceeds());
+
+ uint64_t val = 0;
+ ASSERT_THAT(ReadFd(tfd.get(), &val, sizeof(uint64_t)),
+ SyscallSucceedsWithValue(sizeof(uint64_t)));
+ EXPECT_EQ(1, val);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
new file mode 100644
index 000000000..4b3c44527
--- /dev/null
+++ b/test/syscalls/linux/timers.cc
@@ -0,0 +1,662 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <sys/resource.h>
+#include <sys/time.h>
+#include <sys/wait.h>
+#include <syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <atomic>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/cleanup.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+ABSL_FLAG(bool, timers_test_sleep, false,
+ "If true, sleep forever instead of running tests.");
+
+using ::testing::_;
+using ::testing::AnyOf;
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+#ifndef CPUCLOCK_PROF
+#define CPUCLOCK_PROF 0
+#endif // CPUCLOCK_PROF
+
+PosixErrorOr<absl::Duration> ProcessCPUTime(pid_t pid) {
+ // Use pid-specific CPUCLOCK_PROF, which is the clock used to enforce
+ // RLIMIT_CPU.
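+ //
+ // Linux encodes per-process CPU clock IDs as ((~pid) << 3) | clock; see
+ // MAKE_PROCESS_CPUCLOCK in the kernel's include/linux/posix-timers.h.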
+ clockid_t clockid = (~static_cast<clockid_t>(pid) << 3) | CPUCLOCK_PROF;
+
+ struct timespec ts;
+ int ret = clock_gettime(clockid, &ts);
+ if (ret < 0) {
+ return PosixError(errno, "clock_gettime failed");
+ }
+
+ return absl::DurationFromTimespec(ts);
+}
+
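+// Signal handler that merely validates the signal, so that repeated SIGXCPU
+// deliveries after the soft limit do not terminate the process before the
+// hard limit is reached.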
+void NoopSignalHandler(int signo) {
+ TEST_CHECK_MSG(SIGXCPU == signo,
+ "NoopSigHandler did not receive expected signal");
+}
+
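+// Signal handler that restores the default SIGXCPU disposition after the
+// first delivery, so that the next SIGXCPU at the soft limit terminates the
+// process.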
+void UninstallingSignalHandler(int signo) {
+ TEST_CHECK_MSG(SIGXCPU == signo,
+ "UninstallingSignalHandler did not receive expected signal");
+ struct sigaction rev_action;
+ rev_action.sa_handler = SIG_DFL;
+ rev_action.sa_flags = 0;
+ sigemptyset(&rev_action.sa_mask);
+ sigaction(SIGXCPU, &rev_action, nullptr);
+}
+
+TEST(TimerTest, ProcessKilledOnCPUSoftLimit) {
+ constexpr absl::Duration kSoftLimit = absl::Seconds(1);
+ constexpr absl::Duration kHardLimit = absl::Seconds(3);
+
+ struct rlimit cpu_limits;
+ cpu_limits.rlim_cur = absl::ToInt64Seconds(kSoftLimit);
+ cpu_limits.rlim_max = absl::ToInt64Seconds(kHardLimit);
+
+ int pid = fork();
+ MaybeSave();
+ if (pid == 0) {
+ TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0);
+ MaybeSave();
+ for (;;) {
+ }
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ auto c = Cleanup([pid] {
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(WTERMSIG(status), SIGXCPU);
+ });
+
+ // Wait for the child to exit, but do not reap it. This will allow us to check
+ // its CPU usage while it is zombied.
+ EXPECT_THAT(waitid(P_PID, pid, nullptr, WEXITED | WNOWAIT),
+ SyscallSucceeds());
+
+ // Assert that the child spent at least 1s of CPU time before getting killed.
+ //
+ // We must be careful to use CPUCLOCK_PROF, the same clock used for RLIMIT_CPU
+ // enforcement, to get correct results. Note that this is slightly different
+ // from rusage-reported CPU usage:
+ //
+ // RLIMIT_CPU, CPUCLOCK_PROF use kernel/sched/cputime.c:thread_group_cputime.
+ // rusage uses kernel/sched/cputime.c:thread_group_cputime_adjusted.
+ absl::Duration cpu = ASSERT_NO_ERRNO_AND_VALUE(ProcessCPUTime(pid));
+ EXPECT_GE(cpu, kSoftLimit);
+
+ // Child did not make it to the hard limit.
+ //
+ // Linux sends SIGXCPU synchronously with CPU tick updates. See
+ // kernel/time/timer.c:update_process_times:
+ // => account_process_tick // update task CPU usage.
+ // => run_posix_cpu_timers // enforce RLIMIT_CPU, sending signal.
+ //
+ // Thus, the only way for this test to flake is if the system time required
+ // to deliver the signal exceeds 2s.
+ EXPECT_LT(cpu, kHardLimit);
+}
+
+TEST(TimerTest, ProcessPingedRepeatedlyAfterCPUSoftLimit) {
+ struct sigaction new_action;
+ new_action.sa_handler = UninstallingSignalHandler;
+ new_action.sa_flags = 0;
+ sigemptyset(&new_action.sa_mask);
+
+ constexpr absl::Duration kSoftLimit = absl::Seconds(1);
+ constexpr absl::Duration kHardLimit = absl::Seconds(10);
+
+ struct rlimit cpu_limits;
+ cpu_limits.rlim_cur = absl::ToInt64Seconds(kSoftLimit);
+ cpu_limits.rlim_max = absl::ToInt64Seconds(kHardLimit);
+
+ int pid = fork();
+ MaybeSave();
+ if (pid == 0) {
+ TEST_PCHECK(sigaction(SIGXCPU, &new_action, nullptr) == 0);
+ MaybeSave();
+ TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0);
+ MaybeSave();
+ for (;;) {
+ }
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ auto c = Cleanup([pid] {
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(WTERMSIG(status), SIGXCPU);
+ });
+
+ // Wait for the child to exit, but do not reap it. This will allow us to check
+ // its CPU usage while it is zombied.
+ EXPECT_THAT(waitid(P_PID, pid, nullptr, WEXITED | WNOWAIT),
+ SyscallSucceeds());
+
+ absl::Duration cpu = ASSERT_NO_ERRNO_AND_VALUE(ProcessCPUTime(pid));
+ // Following signals come every CPU second.
+ EXPECT_GE(cpu, kSoftLimit + absl::Seconds(1));
+
+ // Child did not make it to the hard limit.
+ //
+ // As above, should not flake.
+ EXPECT_LT(cpu, kHardLimit);
+}
+
+TEST(TimerTest, ProcessKilledOnCPUHardLimit) {
+ struct sigaction new_action;
+ new_action.sa_handler = NoopSignalHandler;
+ new_action.sa_flags = 0;
+ sigemptyset(&new_action.sa_mask);
+
+ constexpr absl::Duration kSoftLimit = absl::Seconds(1);
+ constexpr absl::Duration kHardLimit = absl::Seconds(3);
+
+ struct rlimit cpu_limits;
+ cpu_limits.rlim_cur = absl::ToInt64Seconds(kSoftLimit);
+ cpu_limits.rlim_max = absl::ToInt64Seconds(kHardLimit);
+
+ int pid = fork();
+ MaybeSave();
+ if (pid == 0) {
+ TEST_PCHECK(sigaction(SIGXCPU, &new_action, nullptr) == 0);
+ MaybeSave();
+ TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0);
+ MaybeSave();
+ for (;;) {
+ }
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ auto c = Cleanup([pid] {
+ int status;
+ EXPECT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(WTERMSIG(status), SIGKILL);
+ });
+
+ // Wait for the child to exit, but do not reap it. This will allow us to check
+ // its CPU usage while it is zombied.
+ EXPECT_THAT(waitid(P_PID, pid, nullptr, WEXITED | WNOWAIT),
+ SyscallSucceeds());
+
+ absl::Duration cpu = ASSERT_NO_ERRNO_AND_VALUE(ProcessCPUTime(pid));
+ EXPECT_GE(cpu, kHardLimit);
+}
+
+// RAII type for a kernel "POSIX" interval timer. (The kernel provides system
+// calls such as timer_create that behave very similarly, but not identically,
+// to those described by timer_create(2); in particular, the kernel does not
+// implement SIGEV_THREAD. glibc builds POSIX-compliant interval timers based on
+// these kernel interval timers.)
+//
+// Compare implementation to FileDescriptor.
+class IntervalTimer {
+ public:
+ IntervalTimer() = default;
+
+ explicit IntervalTimer(int id) { set_id(id); }
+
+ IntervalTimer(IntervalTimer&& orig) : id_(orig.release()) {}
+
+ IntervalTimer& operator=(IntervalTimer&& orig) {
+ if (this == &orig) return *this;
+ reset(orig.release());
+ return *this;
+ }
+
+ IntervalTimer(const IntervalTimer& other) = delete;
+ IntervalTimer& operator=(const IntervalTimer& other) = delete;
+
+ ~IntervalTimer() { reset(); }
+
+ int get() const { return id_; }
+
+ int release() {
+ int const id = id_;
+ id_ = -1;
+ return id;
+ }
+
+ void reset() { reset(-1); }
+
+ void reset(int id) {
+ if (id_ >= 0) {
+ TEST_PCHECK(syscall(SYS_timer_delete, id_) == 0);
+ MaybeSave();
+ }
+ set_id(id);
+ }
+
+ PosixErrorOr<struct itimerspec> Set(
+ int flags, const struct itimerspec& new_value) const {
+ struct itimerspec old_value = {};
+ if (syscall(SYS_timer_settime, id_, flags, &new_value, &old_value) < 0) {
+ return PosixError(errno, "timer_settime");
+ }
+ MaybeSave();
+ return old_value;
+ }
+
+ PosixErrorOr<struct itimerspec> Get() const {
+ struct itimerspec curr_value = {};
+ if (syscall(SYS_timer_gettime, id_, &curr_value) < 0) {
+ return PosixError(errno, "timer_gettime");
+ }
+ MaybeSave();
+ return curr_value;
+ }
+
+ PosixErrorOr<int> Overruns() const {
+ int rv = syscall(SYS_timer_getoverrun, id_);
+ if (rv < 0) {
+ return PosixError(errno, "timer_getoverrun");
+ }
+ MaybeSave();
+ return rv;
+ }
+
+ private:
+ void set_id(int id) { id_ = std::max(id, -1); }
+
+ // Kernel timer_t is int; glibc timer_t is void*.
+ int id_ = -1;
+};
+
+PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid,
+ const struct sigevent& sev) {
+ int timerid;
+ int ret = syscall(SYS_timer_create, clockid, &sev, &timerid);
+ if (ret < 0) {
+ return PosixError(errno, "timer_create");
+ }
+ if (ret > 0) {
+ return PosixError(EINVAL, "timer_create should never return positive");
+ }
+ MaybeSave();
+ return IntervalTimer(timerid);
+}
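+
+// A minimal usage sketch of the RAII wrapper above (illustrative only; the
+// tests below are the real call sites):
+//
+//   struct sigevent sev = {};
+//   sev.sigev_notify = SIGEV_NONE;
+//   auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+//   // Arm with timer.Set(), inspect with timer.Get() and timer.Overruns();
+//   // timer_delete(2) runs automatically when `timer` goes out of scope.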
+
+// See timerfd.cc:TimerSlack() for rationale.
+constexpr absl::Duration kTimerSlack = absl::Milliseconds(500);
+
+TEST(IntervalTimerTest, IsInitiallyStopped) {
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_NONE;
+ const auto timer =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+ const struct itimerspec its = ASSERT_NO_ERRNO_AND_VALUE(timer.Get());
+ EXPECT_EQ(0, its.it_value.tv_sec);
+ EXPECT_EQ(0, its.it_value.tv_nsec);
+}
+
+// The kernel can create multiple timers without issue.
+//
+// Regression test for gvisor.dev/issue/1738.
+TEST(IntervalTimerTest, MultipleTimers) {
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_NONE;
+ const auto timer1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+ const auto timer2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+}
+
+TEST(IntervalTimerTest, SingleShotSilent) {
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_NONE;
+ const auto timer =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kDelay = absl::Seconds(1);
+ struct itimerspec its = {};
+ its.it_value = absl::ToTimespec(kDelay);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ // The timer should count down to 0 and stop since the interval is zero. No
+ // overruns should be counted.
+ absl::SleepFor(kDelay + kTimerSlack);
+ its = ASSERT_NO_ERRNO_AND_VALUE(timer.Get());
+ EXPECT_EQ(0, its.it_value.tv_sec);
+ EXPECT_EQ(0, its.it_value.tv_nsec);
+ EXPECT_THAT(timer.Overruns(), IsPosixErrorOkAndHolds(0));
+}
+
+TEST(IntervalTimerTest, PeriodicSilent) {
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_NONE;
+ const auto timer =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kPeriod = absl::Seconds(1);
+ struct itimerspec its = {};
+ its.it_value = its.it_interval = absl::ToTimespec(kPeriod);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ absl::SleepFor(kPeriod * 3 + kTimerSlack);
+
+ // The timer should still be running.
+ its = ASSERT_NO_ERRNO_AND_VALUE(timer.Get());
+ EXPECT_TRUE(its.it_value.tv_nsec != 0 || its.it_value.tv_sec != 0);
+
+ // Timer expirations are not counted as overruns under SIGEV_NONE.
+ EXPECT_THAT(timer.Overruns(), IsPosixErrorOkAndHolds(0));
+}
+
+std::atomic<int> counted_signals;
+
+void IntervalTimerCountingSignalHandler(int sig, siginfo_t* info,
+ void* ucontext) {
+ counted_signals.fetch_add(1 + info->si_overrun);
+}
+
+TEST(IntervalTimerTest, PeriodicGroupDirectedSignal) {
+ constexpr int kSigno = SIGUSR1;
+ constexpr int kSigvalue = 42;
+
+ // Install our signal handler.
+ counted_signals.store(0);
+ struct sigaction sa = {};
+ sa.sa_sigaction = IntervalTimerCountingSignalHandler;
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ const auto scoped_sigaction =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa));
+
+ // Ensure that kSigno is unblocked on at least one thread.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, kSigno));
+
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_SIGNAL;
+ sev.sigev_signo = kSigno;
+ sev.sigev_value.sival_int = kSigvalue;
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kPeriod = absl::Seconds(1);
+ constexpr int kCycles = 3;
+ struct itimerspec its = {};
+ its.it_value = its.it_interval = absl::ToTimespec(kPeriod);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ absl::SleepFor(kPeriod * kCycles + kTimerSlack);
+ EXPECT_GE(counted_signals.load(), kCycles);
+}
+
+// From Linux's include/uapi/asm-generic/siginfo.h.
+#ifndef sigev_notify_thread_id
+#define sigev_notify_thread_id _sigev_un._tid
+#endif
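+// (SIGEV_THREAD_ID is Linux-specific; glibc's struct sigevent exposes the
+// target TID only through this union member.)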
+
+TEST(IntervalTimerTest, PeriodicThreadDirectedSignal) {
+ constexpr int kSigno = SIGUSR1;
+ constexpr int kSigvalue = 42;
+
+ // Block kSigno so that we can accumulate overruns.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask));
+
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = kSigno;
+ sev.sigev_value.sival_int = kSigvalue;
+ sev.sigev_notify_thread_id = gettid();
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kPeriod = absl::Seconds(1);
+ constexpr int kCycles = 3;
+ struct itimerspec its = {};
+ its.it_value = its.it_interval = absl::ToTimespec(kPeriod);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+ absl::SleepFor(kPeriod * kCycles + kTimerSlack);
+
+ // At least kCycles expirations should have occurred, resulting in kCycles-1
+ // overruns (the first expiration sent the signal successfully).
+ siginfo_t si;
+ struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration());
+ ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts),
+ SyscallSucceedsWithValue(kSigno));
+ EXPECT_EQ(si.si_signo, kSigno);
+ EXPECT_EQ(si.si_code, SI_TIMER);
+ EXPECT_EQ(si.si_timerid, timer.get());
+ EXPECT_GE(si.si_overrun, kCycles - 1);
+ EXPECT_EQ(si.si_int, kSigvalue);
+
+ // Kill the timer, then drain any additional signal it may have enqueued. We
+ // can't do this before the preceding sigtimedwait because stopping or
+ // deleting the timer resets si_overrun to 0.
+ timer.reset();
+ sigtimedwait(&mask, &si, &zero_ts);
+}
+
+TEST(IntervalTimerTest, OtherThreadGroup) {
+ constexpr int kSigno = SIGUSR1;
+
+ // Create a subprocess that does nothing until killed.
+ pid_t child_pid;
+ const auto sp = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec(
+ "/proc/self/exe", ExecveArray({"timers", "--timers_test_sleep"}),
+ ExecveArray(), &child_pid, nullptr));
+
+ // Verify that we can't create a timer that would send signals to it.
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = kSigno;
+ sev.sigev_notify_thread_id = child_pid;
+ EXPECT_THAT(TimerCreate(CLOCK_MONOTONIC, sev), PosixErrorIs(EINVAL, _));
+}
+
+TEST(IntervalTimerTest, RealTimeSignalsAreNotDuplicated) {
+ const int kSigno = SIGRTMIN;
+ constexpr int kSigvalue = 42;
+
+ // Block signo so that we can accumulate overruns.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask));
+
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = kSigno;
+ sev.sigev_value.sival_int = kSigvalue;
+ sev.sigev_notify_thread_id = gettid();
+ const auto timer =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kPeriod = absl::Seconds(1);
+ constexpr int kCycles = 3;
+ struct itimerspec its = {};
+ its.it_value = its.it_interval = absl::ToTimespec(kPeriod);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+ absl::SleepFor(kPeriod * kCycles + kTimerSlack);
+
+ // Stop the timer so that no further signals are enqueued after sigtimedwait.
+ struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration());
+ its.it_value = its.it_interval = zero_ts;
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ // The timer should have sent only a single signal, even though the kernel
+ // supports enqueueing of multiple RT signals.
+ siginfo_t si;
+ ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts),
+ SyscallSucceedsWithValue(kSigno));
+ EXPECT_EQ(si.si_signo, kSigno);
+ EXPECT_EQ(si.si_code, SI_TIMER);
+ EXPECT_EQ(si.si_timerid, timer.get());
+ // si_overrun was reset by timer_settime.
+ EXPECT_EQ(si.si_overrun, 0);
+ EXPECT_EQ(si.si_int, kSigvalue);
+ EXPECT_THAT(sigtimedwait(&mask, &si, &zero_ts),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST(IntervalTimerTest, AlreadyPendingSignal) {
+ constexpr int kSigno = SIGUSR1;
+ constexpr int kSigvalue = 42;
+
+ // Block kSigno so that we can accumulate overruns.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask));
+
+ // Send ourselves a signal, preventing the timer from enqueuing.
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = kSigno;
+ sev.sigev_value.sival_int = kSigvalue;
+ sev.sigev_notify_thread_id = gettid();
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kPeriod = absl::Seconds(1);
+ constexpr int kCycles = 3;
+ struct itimerspec its = {};
+ its.it_value = its.it_interval = absl::ToTimespec(kPeriod);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ // End the sleep one cycle short; we will sleep for one more cycle below.
+ absl::SleepFor(kPeriod * (kCycles - 1));
+
+ // Dequeue the first signal, which we sent to ourselves with tgkill.
+ siginfo_t si;
+ struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration());
+ ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts),
+ SyscallSucceedsWithValue(kSigno));
+ EXPECT_EQ(si.si_signo, kSigno);
+ // glibc sigtimedwait silently replaces SI_TKILL with SI_USER:
+ // sysdeps/unix/sysv/linux/sigtimedwait.c:__sigtimedwait(). This isn't
+ // documented, so we don't depend on it.
+ EXPECT_THAT(si.si_code, AnyOf(SI_USER, SI_TKILL));
+
+ // Sleep for 1 more cycle to give the timer time to send a signal.
+ absl::SleepFor(kPeriod + kTimerSlack);
+
+ // At least kCycles expirations should have occurred, resulting in kCycles-1
+ // overruns (the last expiration sent the signal successfully).
+ ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts),
+ SyscallSucceedsWithValue(kSigno));
+ EXPECT_EQ(si.si_signo, kSigno);
+ EXPECT_EQ(si.si_code, SI_TIMER);
+ EXPECT_EQ(si.si_timerid, timer.get());
+ EXPECT_GE(si.si_overrun, kCycles - 1);
+ EXPECT_EQ(si.si_int, kSigvalue);
+
+ // Kill the timer, then drain any additional signal it may have enqueued. We
+ // can't do this before the preceding sigtimedwait because stopping or
+ // deleting the timer resets si_overrun to 0.
+ timer.reset();
+ sigtimedwait(&mask, &si, &zero_ts);
+}
+
+TEST(IntervalTimerTest, IgnoredSignalCountsAsOverrun) {
+ constexpr int kSigno = SIGUSR1;
+ constexpr int kSigvalue = 42;
+
+ // Ignore kSigno.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ const auto scoped_sigaction =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kSigno, sa));
+
+ // Unblock kSigno so that ignored signals will be discarded.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, mask));
+
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = kSigno;
+ sev.sigev_value.sival_int = kSigvalue;
+ sev.sigev_notify_thread_id = gettid();
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+
+ constexpr absl::Duration kPeriod = absl::Seconds(1);
+ constexpr int kCycles = 3;
+ struct itimerspec its = {};
+ its.it_value = its.it_interval = absl::ToTimespec(kPeriod);
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ // End the sleep one cycle short; we will sleep for one more cycle below.
+ absl::SleepFor(kPeriod * (kCycles - 1));
+
+ // Block kSigno so that ignored signals will be enqueued.
+ scoped_sigmask.Release()();
+ scoped_sigmask = ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, mask));
+
+ // Sleep for 1 more cycle to give the timer time to send a signal.
+ absl::SleepFor(kPeriod + kTimerSlack);
+
+ // At least kCycles expirations should have occurred, resulting in kCycles-1
+ // overruns (the last expiration sent the signal successfully).
+ siginfo_t si;
+ struct timespec zero_ts = absl::ToTimespec(absl::ZeroDuration());
+ ASSERT_THAT(sigtimedwait(&mask, &si, &zero_ts),
+ SyscallSucceedsWithValue(kSigno));
+ EXPECT_EQ(si.si_signo, kSigno);
+ EXPECT_EQ(si.si_code, SI_TIMER);
+ EXPECT_EQ(si.si_timerid, timer.get());
+ EXPECT_GE(si.si_overrun, kCycles - 1);
+ EXPECT_EQ(si.si_int, kSigvalue);
+
+ // Kill the timer, then drain any additional signal it may have enqueued. We
+ // can't do this before the preceding sigtimedwait because stopping or
+ // deleting the timer resets si_overrun to 0.
+ timer.reset();
+ sigtimedwait(&mask, &si, &zero_ts);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ if (absl::GetFlag(FLAGS_timers_test_sleep)) {
+ while (true) {
+ absl::SleepFor(absl::Seconds(10));
+ }
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc
new file mode 100644
index 000000000..8d8ebbb24
--- /dev/null
+++ b/test/syscalls/linux/tkill.cc
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <csignal>
+
+#include "gtest/gtest.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+static int tkill(pid_t tid, int sig) {
+ int ret;
+ do {
+ // NOTE(b/25434735): tkill(2) could return EAGAIN for RT signals.
+ ret = syscall(SYS_tkill, tid, sig);
+ } while (ret == -1 && errno == EAGAIN);
+ return ret;
+}
+
+TEST(TkillTest, InvalidTID) {
+ EXPECT_THAT(tkill(-1, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(tkill(0, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TkillTest, ValidTID) {
+ EXPECT_THAT(tkill(gettid(), 0), SyscallSucceeds());
+}
+
+void SigHandler(int sig, siginfo_t* info, void* context) {
+ TEST_CHECK(sig == SIGRTMAX);
+ TEST_CHECK(info->si_pid == getpid());
+ TEST_CHECK(info->si_uid == getuid());
+ TEST_CHECK(info->si_code == SI_TKILL);
+}
+
+// Test with a real signal. Regression test for b/24790092.
+TEST(TkillTest, ValidTIDAndRealSignal) {
+ struct sigaction sa;
+ sa.sa_sigaction = SigHandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ ASSERT_THAT(sigaction(SIGRTMAX, &sa, nullptr), SyscallSucceeds());
+ // InitGoogle blocks all RT signals, so we need to undo that here.
+ sigset_t unblock;
+ sigemptyset(&unblock);
+ sigaddset(&unblock, SIGRTMAX);
+ ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &unblock, nullptr), SyscallSucceeds());
+ EXPECT_THAT(tkill(gettid(), SIGRTMAX), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc
new file mode 100644
index 000000000..c988c6380
--- /dev/null
+++ b/test/syscalls/linux/truncate.cc
@@ -0,0 +1,218 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <sys/resource.h>
+#include <sys/stat.h>
+#include <sys/vfs.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class FixtureTruncateTest : public FileTest {
+ void SetUp() override { FileTest::SetUp(); }
+};
+
+TEST_F(FixtureTruncateTest, Truncate) {
+ // Get the current rlimit and restore after test run.
+ struct rlimit initial_lim;
+ ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ auto cleanup = Cleanup([&initial_lim] {
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ });
+
+ // Check that it starts at size zero.
+ struct stat buf;
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+
+ // Stay at size zero.
+ EXPECT_THAT(truncate(test_file_name_.c_str(), 0), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+
+ // Grow to ten bytes.
+ EXPECT_THAT(truncate(test_file_name_.c_str(), 10), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 10);
+
+ // Can't be truncated to a negative number.
+ EXPECT_THAT(truncate(test_file_name_.c_str(), -1),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Try growing past the file size limit.
+ sigset_t new_mask;
+ sigemptyset(&new_mask);
+ sigaddset(&new_mask, SIGXFSZ);
+ ASSERT_THAT(sigprocmask(SIG_BLOCK, &new_mask, nullptr), SyscallSucceeds());
+ struct timespec timelimit;
+ timelimit.tv_sec = 10;
+ timelimit.tv_nsec = 0;
+
+ struct rlimit setlim;
+ setlim.rlim_cur = 1024;
+ setlim.rlim_max = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds());
+ EXPECT_THAT(truncate(test_file_name_.c_str(), 1025),
+ SyscallFailsWithErrno(EFBIG));
+ EXPECT_EQ(sigtimedwait(&new_mask, nullptr, &timelimit), SIGXFSZ);
+ ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds());
+
+ // Shrink back down to zero.
+ EXPECT_THAT(truncate(test_file_name_.c_str(), 0), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+}
+
+TEST_F(FixtureTruncateTest, Ftruncate) {
+ // Get the current rlimit and restore after test run.
+ struct rlimit initial_lim;
+ ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ auto cleanup = Cleanup([&initial_lim] {
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ });
+
+ // Check that it starts at size zero.
+ struct stat buf;
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+
+ // Stay at size zero.
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), 0), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+
+ // Grow to ten bytes.
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), 10), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 10);
+
+ // Can't be truncated to a negative number.
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), -1),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Try growing past the file size limit.
+ sigset_t new_mask;
+ sigemptyset(&new_mask);
+ sigaddset(&new_mask, SIGXFSZ);
+ ASSERT_THAT(sigprocmask(SIG_BLOCK, &new_mask, nullptr), SyscallSucceeds());
+ struct timespec timelimit;
+ timelimit.tv_sec = 10;
+ timelimit.tv_nsec = 0;
+
+ struct rlimit setlim;
+ setlim.rlim_cur = 1024;
+ setlim.rlim_max = RLIM_INFINITY;
+ ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds());
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), 1025),
+ SyscallFailsWithErrno(EFBIG));
+ EXPECT_EQ(sigtimedwait(&new_mask, nullptr, &timelimit), SIGXFSZ);
+ ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds());
+
+ // Shrink back down to zero.
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), 0), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 0);
+}
+
+// Truncating a file down clears that portion of the file.
+TEST_F(FixtureTruncateTest, FtruncateShrinkGrow) {
+ std::vector<char> buf(10, 'a');
+ EXPECT_THAT(WriteFd(test_file_fd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Shrink then regrow the file. This should clear the second half of the file.
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), 5), SyscallSucceeds());
+ EXPECT_THAT(ftruncate(test_file_fd_.get(), 10), SyscallSucceeds());
+
+ EXPECT_THAT(lseek(test_file_fd_.get(), 0, SEEK_SET), SyscallSucceeds());
+
+ std::vector<char> buf2(10);
+ EXPECT_THAT(ReadFd(test_file_fd_.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(buf2.size()));
+
+ std::vector<char> expect = {'a', 'a', 'a', 'a', 'a',
+ '\0', '\0', '\0', '\0', '\0'};
+ EXPECT_EQ(expect, buf2);
+}
+
+TEST(TruncateTest, TruncateDir) {
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(truncate(temp_dir.path().c_str(), 0),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(TruncateTest, FtruncateDir) {
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_dir.path(), O_DIRECTORY | O_RDONLY));
+ EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TruncateTest, TruncateNonWriteable) {
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // always override write permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */));
+ EXPECT_THAT(truncate(temp_file.path().c_str(), 0),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(TruncateTest, FtruncateNonWriteable) {
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_RDONLY));
+ EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(TruncateTest, TruncateNonExist) {
+ EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST(TruncateTest, FtruncateVirtualTmp_NoRandomSave) {
+ auto temp_file = NewTempAbsPathInDir("/dev/shm");
+ const DisableSave ds; // Incompatible permissions.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file, O_RDWR | O_CREAT | O_EXCL, 0));
+ EXPECT_THAT(ftruncate(fd.get(), 100), SyscallSucceeds());
+}
+
+// NOTE: There are additional truncate(2)/ftruncate(2) tests in mknod.cc
+// which are there to avoid running the tests on a number of different
+// filesystems which may not support mknod.
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
new file mode 100644
index 000000000..97d554e72
--- /dev/null
+++ b/test/syscalls/linux/tuntap.cc
@@ -0,0 +1,422 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_ether.h>
+#include <linux/if_tun.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr int kIPLen = 4;
+
+constexpr const char kDevNetTun[] = "/dev/net/tun";
+constexpr const char kTapName[] = "tap0";
+
+constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
+constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB};
+
+PosixErrorOr<std::set<std::string>> DumpLinkNames() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ std::set<std::string> names;
+ for (const auto& link : links) {
+ names.emplace(link.name);
+ }
+ return names;
+}
+
+PosixErrorOr<Link> GetLinkByName(const std::string& name) {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.name == name) {
+ return link;
+ }
+ }
+ return PosixError(ENOENT, "interface not found");
+}
+
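+// Packet information header that the tun driver prepends to each frame when
+// the device is created without IFF_NO_PI; mirrors struct tun_pi from
+// <linux/if_tun.h>.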
+struct pihdr {
+ uint16_t pi_flags;
+ uint16_t pi_protocol;
+} __attribute__((packed));
+
+struct ping_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct iphdr ip;
+ struct icmphdr icmp;
+ char payload[64];
+} __attribute__((packed));
+
+ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ ping_pkt pkt = {};
+
+ pkt.pi.pi_protocol = htons(ETH_P_IP);
+
+ memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest));
+ memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source));
+ pkt.eth.h_proto = htons(ETH_P_IP);
+
+ pkt.ip.ihl = 5;
+ pkt.ip.version = 4;
+ pkt.ip.tos = 0;
+ pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) +
+ sizeof(pkt.payload));
+ pkt.ip.id = 1;
+ pkt.ip.frag_off = 1 << 6; // Do not fragment
+ pkt.ip.ttl = 64;
+ pkt.ip.protocol = IPPROTO_ICMP;
+ inet_pton(AF_INET, dstip, &pkt.ip.daddr);
+ inet_pton(AF_INET, srcip, &pkt.ip.saddr);
+ pkt.ip.check = IPChecksum(pkt.ip);
+
+ pkt.icmp.type = ICMP_ECHO;
+ pkt.icmp.code = 0;
+ pkt.icmp.checksum = 0;
+ pkt.icmp.un.echo.sequence = 1;
+ pkt.icmp.un.echo.id = 1;
+
+ strncpy(pkt.payload, "abcd", sizeof(pkt.payload));
+ pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload));
+
+ return pkt;
+}
+
+struct arp_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct arphdr arp;
+ uint8_t arp_sha[ETH_ALEN];
+ uint8_t arp_spa[kIPLen];
+ uint8_t arp_tha[ETH_ALEN];
+ uint8_t arp_tpa[kIPLen];
+} __attribute__((packed));
+
+std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ std::string buffer;
+ buffer.resize(sizeof(arp_pkt));
+
+ arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]);
+ {
+ pkt->pi.pi_protocol = htons(ETH_P_ARP);
+
+ memcpy(pkt->eth.h_dest, dstmac, sizeof(pkt->eth.h_dest));
+ memcpy(pkt->eth.h_source, srcmac, sizeof(pkt->eth.h_source));
+ pkt->eth.h_proto = htons(ETH_P_ARP);
+
+ pkt->arp.ar_hrd = htons(ARPHRD_ETHER);
+ pkt->arp.ar_pro = htons(ETH_P_IP);
+ pkt->arp.ar_hln = ETH_ALEN;
+ pkt->arp.ar_pln = kIPLen;
+ pkt->arp.ar_op = htons(ARPOP_REPLY);
+
+ memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha));
+ inet_pton(AF_INET, srcip, pkt->arp_spa);
+ memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha));
+ inet_pton(AF_INET, dstip, pkt->arp_tpa);
+ }
+ return buffer;
+}
+
+} // namespace
+
+TEST(TuntapStaticTest, NetTunExists) {
+ struct stat statbuf;
+ ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
+class TuntapTest : public ::testing::Test {
+ protected:
+ void TearDown() override {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) {
+ // Bring back capability if we had dropped it in test case.
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true));
+ }
+ }
+};
+
+TEST_F(TuntapTest, CreateInterfaceNoCap) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ strncpy(ifr.ifr_name, kTapName, IFNAMSIZ);
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(TuntapTest, CreateFixedNameInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_expect = ifr_set;
+ // See __tun_chr_ioctl() in net/drivers/tun.c.
+ ifr_expect.ifr_flags |= IFF_NOFILTER;
+
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(kTapName)));
+ EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0));
+}
+
+TEST_F(TuntapTest, CreateInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ // Empty ifr.ifr_name. Let kernel assign.
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ std::string ifname = ifr_get.ifr_name;
+ EXPECT_THAT(ifname, ::testing::StartsWith("tap"));
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(ifname)));
+}
+
+TEST_F(TuntapTest, InvalidReadWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ char buf[128] = {};
+ EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+}
+
+TEST_F(TuntapTest, WriteToDownDevice) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ // Device created should be down by default.
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ char buf[128] = {};
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
+}
+
+PosixErrorOr<FileDescriptor> OpenAndAttachTap(
+ const std::string& dev_name, const std::string& dev_ipv4_addr) {
+ // Interface creation.
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, dev_name.c_str(), IFNAMSIZ);
+ if (ioctl(fd.get(), TUNSETIFF, &ifr_set) < 0) {
+ return PosixError(errno);
+ }
+
+ ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name));
+
+ // Interface setup.
+ struct in_addr addr;
+ inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr);
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(b/110961832): gVisor doesn't support setting MAC address on
+ // interfaces yet.
+ RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA)));
+
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
+ RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP));
+ }
+
+ return fd;
+}
+
+// This test sets up a TAP device and pings kernel by sending ICMP echo request.
+//
+// It works as the following:
+// * Open /dev/net/tun, and create kTapName interface.
+// * Use rtnetlink to do initial setup of the interface:
+// * Assign IP address 10.0.0.1/24 to kernel.
+// * MAC address: kMacA
+// * Bring up the interface.
+// * Send an ICMP echo request (ping) packet from 10.0.0.2 (kMacB) to kernel.
+// * Loop to receive packets from TAP device/fd:
+// * If packet is an ICMP echo reply, it stops and passes the test.
+// * If packet is an ARP request, it responds with a canned reply and
+// resends the ICMP request packet.
+TEST_F(TuntapTest, PingKernel) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+ ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+ std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+
+ // Send ping, this would trigger an ARP request on Linux.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+
+ // Receive loop to process inbound packets.
+ struct inpkt {
+ union {
+ pihdr pi;
+ ping_pkt ping;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ // Process ARP packet.
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ // Respond with canned ARP reply.
+ EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()),
+ SyscallSucceedsWithValue(arp_rep.size()));
+ // First ping request might have been dropped due to mac address not in
+ // ARP cache. Send it again.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+ }
+
+ // Process ping response packet.
+ if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol &&
+ r.ping.ip.protocol == ping_req.ip.protocol &&
+ !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) &&
+ !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) &&
+ r.ping.icmp.type == 0 && r.ping.icmp.code == 0) {
+ // Ends and passes the test.
+ break;
+ }
+ }
+}
+
+TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ // Send a UDP packet to remote.
+ int sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.2", &remote.sin_addr);
+ int ret = sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote),
+ sizeof(remote));
+ ASSERT_THAT(ret, ::testing::AnyOf(SyscallSucceeds(),
+ SyscallFailsWithErrno(EHOSTDOWN)));
+
+ struct inpkt {
+ union {
+ pihdr pi;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ break;
+ }
+ }
+}
+
+// Write hang bug found by syzkaller: b/155928773
+// https://syzkaller.appspot.com/bug?id=065b893bd8d1d04a4e0a1d53c578537cde1efe99
+TEST_F(TuntapTest, WriteHangBug155928773) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.1", &remote.sin_addr);
+ // Return values do not matter in this test.
+ connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote));
+ write(sock, "hello", 5);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc
new file mode 100644
index 000000000..1513fb9d5
--- /dev/null
+++ b/test/syscalls/linux/tuntap_hostinet.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(TuntapHostInetTest, NoNetTun) {
+ SKIP_IF(!IsRunningOnGvisor());
+ SKIP_IF(!IsRunningWithHostinet());
+
+ struct stat statbuf;
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT));
+}
+
+} // namespace
+} // namespace testing
+
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_bind.cc b/test/syscalls/linux/udp_bind.cc
new file mode 100644
index 000000000..6d92bdbeb
--- /dev/null
+++ b/test/syscalls/linux/udp_bind.cc
@@ -0,0 +1,316 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
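+// Prefix common to both sockaddr_in and sockaddr_in6: each starts with a
+// family field followed by a port field, so this can be used to read the
+// port from an address of either family.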
+struct sockaddr_in_common {
+ sa_family_t sin_family;
+ in_port_t sin_port;
+};
+
+struct SendtoTestParam {
+ // Human readable description of test parameter.
+ std::string description;
+
+ // Test is broken in gVisor, skip.
+ bool skip_on_gvisor;
+
+ // Domain for the socket that will do the sending.
+ int send_domain;
+
+ // Address to bind for the socket that will do the sending.
+ struct sockaddr_storage send_addr;
+ socklen_t send_addr_len; // 0 for unbound.
+
+ // Address to connect to for the socket that will do the sending.
+ struct sockaddr_storage connect_addr;
+ socklen_t connect_addr_len; // 0 for no connection.
+
+ // Domain for the socket that will do the receiving.
+ int recv_domain;
+
+ // Address to bind for the socket that will do the receiving.
+ struct sockaddr_storage recv_addr;
+ socklen_t recv_addr_len;
+
+ // Address to send to.
+ struct sockaddr_storage sendto_addr;
+ socklen_t sendto_addr_len;
+
+ // Expected errno for the sendto call.
+ std::vector<int> sendto_errnos; // empty on success.
+};
+
+class SendtoTest : public ::testing::TestWithParam<SendtoTestParam> {
+ protected:
+ SendtoTest() {
+ // gUnit uses printf, so we will too.
+ printf("Testing with %s\n", GetParam().description.c_str());
+ }
+};
+
+TEST_P(SendtoTest, Sendto) {
+ auto param = GetParam();
+
+ SKIP_IF(param.skip_on_gvisor && IsRunningOnGvisor());
+
+ const FileDescriptor s1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(param.send_domain, SOCK_DGRAM, 0));
+ const FileDescriptor s2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(param.recv_domain, SOCK_DGRAM, 0));
+
+ if (param.send_addr_len > 0) {
+ ASSERT_THAT(bind(s1.get(), reinterpret_cast<sockaddr*>(&param.send_addr),
+ param.send_addr_len),
+ SyscallSucceeds());
+ }
+
+ if (param.connect_addr_len > 0) {
+ ASSERT_THAT(
+ connect(s1.get(), reinterpret_cast<sockaddr*>(&param.connect_addr),
+ param.connect_addr_len),
+ SyscallSucceeds());
+ }
+
+ ASSERT_THAT(bind(s2.get(), reinterpret_cast<sockaddr*>(&param.recv_addr),
+ param.recv_addr_len),
+ SyscallSucceeds());
+
+ struct sockaddr_storage real_recv_addr = {};
+ socklen_t real_recv_addr_len = param.recv_addr_len;
+ ASSERT_THAT(
+ getsockname(s2.get(), reinterpret_cast<sockaddr*>(&real_recv_addr),
+ &real_recv_addr_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(real_recv_addr_len, param.recv_addr_len);
+
+ int recv_port =
+ reinterpret_cast<sockaddr_in_common*>(&real_recv_addr)->sin_port;
+
+ struct sockaddr_storage sendto_addr = param.sendto_addr;
+ reinterpret_cast<sockaddr_in_common*>(&sendto_addr)->sin_port = recv_port;
+
+ char buf[20] = {};
+ if (!param.sendto_errnos.empty()) {
+ ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr),
+ param.sendto_addr_len),
+ SyscallFailsWithErrno(ElementOf(param.sendto_errnos)));
+ return;
+ }
+
+ ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr),
+ param.sendto_addr_len),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct sockaddr_storage got_addr = {};
+ socklen_t got_addr_len = sizeof(sockaddr_storage);
+ ASSERT_THAT(RetryEINTR(recvfrom)(s2.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<sockaddr*>(&got_addr),
+ &got_addr_len),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ ASSERT_GT(got_addr_len, sizeof(sockaddr_in_common));
+ int got_port = reinterpret_cast<sockaddr_in_common*>(&got_addr)->sin_port;
+
+ struct sockaddr_storage sender_addr = {};
+ socklen_t sender_addr_len = sizeof(sockaddr_storage);
+ ASSERT_THAT(getsockname(s1.get(), reinterpret_cast<sockaddr*>(&sender_addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+
+ ASSERT_GT(sender_addr_len, sizeof(sockaddr_in_common));
+ int sender_port =
+ reinterpret_cast<sockaddr_in_common*>(&sender_addr)->sin_port;
+
+ EXPECT_EQ(got_port, sender_port);
+}
+
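+// Helpers that fill in a loopback address of the given family and return its
+// length. A port of 0 lets the kernel assign an ephemeral port at bind time.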
+socklen_t Ipv4Addr(sockaddr_storage* addr, int port = 0) {
+ auto addr4 = reinterpret_cast<sockaddr_in*>(addr);
+ addr4->sin_family = AF_INET;
+ addr4->sin_port = port;
+ inet_pton(AF_INET, "127.0.0.1", &addr4->sin_addr.s_addr);
+ return sizeof(struct sockaddr_in);
+}
+
+socklen_t Ipv6Addr(sockaddr_storage* addr, int port = 0) {
+ auto addr6 = reinterpret_cast<sockaddr_in6*>(addr);
+ addr6->sin6_family = AF_INET6;
+ addr6->sin6_port = port;
+ inet_pton(AF_INET6, "::1", &addr6->sin6_addr.s6_addr);
+ return sizeof(struct sockaddr_in6);
+}
+
+socklen_t Ipv4MappedIpv6Addr(sockaddr_storage* addr, int port = 0) {
+ auto addr6 = reinterpret_cast<sockaddr_in6*>(addr);
+ addr6->sin6_family = AF_INET6;
+ addr6->sin6_port = port;
+ inet_pton(AF_INET6, "::ffff:127.0.0.1", &addr6->sin6_addr.s6_addr);
+ return sizeof(struct sockaddr_in6);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ UdpBindTest, SendtoTest,
+ ::testing::Values(
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv4 mapped IPv6 sendto IPv4 mapped IPv6";
+ param.send_domain = AF_INET6;
+ param.send_addr_len = Ipv4MappedIpv6Addr(&param.send_addr);
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv4MappedIpv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv6 sendto IPv6";
+ param.send_domain = AF_INET6;
+ param.send_addr_len = Ipv6Addr(&param.send_addr);
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv6Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv4 sendto IPv4";
+ param.send_domain = AF_INET;
+ param.send_addr_len = Ipv4Addr(&param.send_addr);
+ param.recv_domain = AF_INET;
+ param.recv_addr_len = Ipv4Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv4 mapped IPv6 sendto IPv4";
+ param.send_domain = AF_INET6;
+ param.send_addr_len = Ipv4MappedIpv6Addr(&param.send_addr);
+ param.recv_domain = AF_INET;
+ param.recv_addr_len = Ipv4Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv4 sendto IPv4 mapped IPv6";
+ param.send_domain = AF_INET;
+ param.send_addr_len = Ipv4Addr(&param.send_addr);
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv4MappedIpv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "unbound IPv6 sendto IPv4 mapped IPv6";
+ param.send_domain = AF_INET6;
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv4MappedIpv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "unbound IPv6 sendto IPv4";
+ param.send_domain = AF_INET6;
+ param.recv_domain = AF_INET;
+ param.recv_addr_len = Ipv4Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv6 sendto IPv4";
+ param.send_domain = AF_INET6;
+ param.send_addr_len = Ipv6Addr(&param.send_addr);
+ param.recv_domain = AF_INET;
+ param.recv_addr_len = Ipv4Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ param.sendto_errnos = {ENETUNREACH};
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "IPv4 mapped IPv6 sendto IPv6";
+ param.send_domain = AF_INET6;
+ param.send_addr_len = Ipv4MappedIpv6Addr(&param.send_addr);
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv6Addr(&param.sendto_addr);
+ // The errno returned changed in Linux commit c8e6ad0829a723.
+ param.sendto_errnos = {EINVAL, EAFNOSUPPORT};
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "connected IPv4 mapped IPv6 sendto IPv6";
+ param.send_domain = AF_INET6;
+ param.connect_addr_len =
+ Ipv4MappedIpv6Addr(&param.connect_addr, 5000);
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv6Addr(&param.sendto_addr);
+ // The errno returned changed in Linux commit c8e6ad0829a723.
+ param.sendto_errnos = {EINVAL, EAFNOSUPPORT};
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "connected IPv6 sendto IPv4 mapped IPv6";
+ // TODO(igudger): Determine if this inconsistent behavior is worth
+ // implementing.
+ param.skip_on_gvisor = true;
+ param.send_domain = AF_INET6;
+ param.connect_addr_len = Ipv6Addr(&param.connect_addr, 5000);
+ param.recv_domain = AF_INET6;
+ param.recv_addr_len = Ipv4MappedIpv6Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ return param;
+ }(),
+ []() {
+ SendtoTestParam param = {};
+ param.description = "connected IPv6 sendto IPv4";
+ // TODO(igudger): Determine if this inconsistent behavior is worth
+ // implementing.
+ param.skip_on_gvisor = true;
+ param.send_domain = AF_INET6;
+ param.connect_addr_len = Ipv6Addr(&param.connect_addr, 5000);
+ param.recv_domain = AF_INET;
+ param.recv_addr_len = Ipv4Addr(&param.recv_addr);
+ param.sendto_addr_len = Ipv4MappedIpv6Addr(&param.sendto_addr);
+ return param;
+ }()));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
new file mode 100644
index 000000000..7a8ac30a4
--- /dev/null
+++ b/test/syscalls/linux/udp_socket.cc
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/udp_socket_test_cases.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
+ ::testing::Values(AddressFamily::kIpv4,
+ AddressFamily::kIpv6,
+ AddressFamily::kDualStack));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc
new file mode 100644
index 000000000..54a0594f7
--- /dev/null
+++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc
@@ -0,0 +1,57 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __fuchsia__
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/errqueue.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/udp_socket_test_cases.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(UdpSocketTest, ErrorQueue) {
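+ // Build a msghdr with enough control-message space for a sock_extended_err.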
+ char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))];
+ msghdr msg;
+ memset(&msg, 0, sizeof(msg));
+ iovec iov;
+ memset(&iov, 0, sizeof(iov));
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+
+ // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
+ EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // __fuchsia__
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
new file mode 100644
index 000000000..9cc6be4fb
--- /dev/null
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -0,0 +1,1727 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/udp_socket_test_cases.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "absl/strings/str_format.h"
+#ifndef SIOCGSTAMP
+#include <linux/sockios.h>
+#endif
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Gets a pointer to the port component of the given address.
+uint16_t* Port(struct sockaddr_storage* addr) {
+ switch (addr->ss_family) {
+ case AF_INET: {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(addr);
+ return &sin->sin_port;
+ }
+ case AF_INET6: {
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr);
+ return &sin6->sin6_port;
+ }
+ }
+
+ return nullptr;
+}
+
+// Sets the port component of "addr" to "port".
+void SetPort(struct sockaddr_storage* addr, uint16_t port) {
+ switch (addr->ss_family) {
+ case AF_INET: {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(addr);
+ sin->sin_port = port;
+ break;
+ }
+ case AF_INET6: {
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr);
+ sin6->sin6_port = port;
+ break;
+ }
+ }
+}
+
+void UdpSocketTest::SetUp() {
+ addrlen_ = GetAddrLength();
+
+ bind_ =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_));
+ bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+
+ sock_ =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+}
+
+int UdpSocketTest::GetFamily() {
+ if (GetParam() == AddressFamily::kIpv4) {
+ return AF_INET;
+ }
+ return AF_INET6;
+}
+
+PosixError UdpSocketTest::BindLoopback() {
+ bind_addr_storage_ = InetLoopbackAddr();
+ bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ return BindSocket(bind_.get(), bind_addr_);
+}
+
+PosixError UdpSocketTest::BindAny() {
+ bind_addr_storage_ = InetAnyAddr();
+ bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ return BindSocket(bind_.get(), bind_addr_);
+}
+
+PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) {
+ socklen_t len = sizeof(bind_addr_storage_);
+
+ // Bind, then check that we get the right address.
+ RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len));
+
+ if (addrlen_ != len) {
+ return PosixError(
+ EINVAL,
+ absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_));
+ }
+ return PosixError(0);
+}
+
+socklen_t UdpSocketTest::GetAddrLength() {
+ if (GetFamily() == AF_INET) {
+ return sizeof(struct sockaddr_in);
+ }
+ return sizeof(struct sockaddr_in6);
+}
+
+sockaddr_storage UdpSocketTest::InetAnyAddr() {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily();
+
+ if (GetFamily() == AF_INET) {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_addr.s_addr = htonl(INADDR_ANY);
+ sin->sin_port = htons(0);
+ return addr;
+ }
+
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ sin6->sin6_addr = IN6ADDR_ANY_INIT;
+ sin6->sin6_port = htons(0);
+ return addr;
+}
+
+sockaddr_storage UdpSocketTest::InetLoopbackAddr() {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily();
+
+ if (GetFamily() == AF_INET) {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ sin->sin_port = htons(0);
+ return addr;
+ }
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ sin6->sin6_addr = in6addr_loopback;
+ sin6->sin6_port = htons(0);
+ return addr;
+}
+
+void UdpSocketTest::Disconnect(int sockfd) {
+ sockaddr_storage addr_storage = InetAnyAddr();
+ sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ socklen_t addrlen = sizeof(addr_storage);
+
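+ // Connecting a UDP socket to an AF_UNSPEC address disconnects it.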
+ addr->sa_family = AF_UNSPEC;
+ ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds());
+
+ // Check that after disconnect the socket is bound to the ANY address.
+ EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds());
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ struct in6_addr any = IN6ADDR_ANY_INIT;
+
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0);
+ }
+}
+
+TEST_P(UdpSocketTest, Creation) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ EXPECT_THAT(close(sock.release()), SyscallSucceeds());
+
+ sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0));
+ EXPECT_THAT(close(sock.release()), SyscallSucceeds());
+
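+ // SOCK_STREAM is incompatible with IPPROTO_UDP, so socket() must fail.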
+ ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails());
+}
+
+TEST_P(UdpSocketTest, Getsockname) {
+ // Check that we're not bound.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ 0);
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ EXPECT_THAT(
+ getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, Getpeername) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Check that we're not connected.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+
+ // Connect, then check that we get the right address.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, SendNotConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // send() and write() must fail because the socket is not connected.
+ char buf[512];
+ EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EDESTADDRREQ));
+
+ EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(EDESTADDRREQ));
+
+ // Use sendto.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Check that we're bound now.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr), 0);
+}
+
+TEST_P(UdpSocketTest, ConnectBinds) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect the socket.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Check that we're bound now.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr), 0);
+}
+
+TEST_P(UdpSocketTest, ReceiveNotBound) {
+ char buf[512];
+ EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, Bind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Try to bind again.
+ EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that we're still bound to the original address.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, BindInUse) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Try to bind again.
+ EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(UdpSocketTest, ReceiveAfterConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Send from sock_ to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, ReceiveAfterDisconnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ for (int i = 0; i < 2; i++) {
+ // Connect sock_ to the bound address.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+
+ // Send from bind_ to sock_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+
+ // Disconnect sock_.
+ struct sockaddr unspec = {};
+ unspec.sa_family = AF_UNSPEC;
+ ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)),
+ SyscallSucceeds());
+ }
+}
+
+TEST_P(UdpSocketTest, Connect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Check that we're connected to the right peer.
+ struct sockaddr_storage peer;
+ socklen_t peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallSucceeds());
+ EXPECT_EQ(peerlen, addrlen_);
+ EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0);
+
+ // Try to bind after connect.
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_THAT(
+ bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ SyscallFailsWithErrno(EINVAL));
+
+ struct sockaddr_storage bind2_storage = InetLoopbackAddr();
+ struct sockaddr* bind2_addr =
+ reinterpret_cast<struct sockaddr*>(&bind2_storage);
+ FileDescriptor bind2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr));
+
+ // Try to connect again.
+ EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds());
+
+ // Check that peer name changed.
+ peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallSucceeds());
+ EXPECT_EQ(peerlen, addrlen_);
+ EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, ConnectAnyZero) {
+ // TODO(138658473): Enable when we can connect to port 0 with gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, ConnectAnyWithPort) {
+ ASSERT_NO_ERRNO(BindAny());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectAny) {
+ // TODO(138658473): Enable when we can connect to port 0 with gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+
+ Disconnect(sock_.get());
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) {
+ ASSERT_NO_ERRNO(BindAny());
+ EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr));
+
+ Disconnect(sock_.get());
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Bind to the next port above bind_.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr));
+
+ // Connect the socket.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage unspec = {};
+ unspec.ss_family = AF_UNSPEC;
+ EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec),
+ sizeof(unspec.ss_family)),
+ SyscallSucceeds());
+
+ // Check that we're still bound.
+ socklen_t addrlen = sizeof(unspec);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0);
+
+ addrlen = sizeof(addr_storage);
+ EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, BindToAnyConnectToLocalhost) {
+ ASSERT_NO_ERRNO(BindAny());
+
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ socklen_t addrlen = sizeof(addr_storage);
+
+ // Connect the socket.
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds());
+
+ // If the socket is bound to ANY and connected to a loopback address,
+ // getsockname() has to return the loopback address.
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr);
+ struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT;
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
+ }
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ struct sockaddr_storage any_storage = InetAnyAddr();
+ struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage);
+ SetPort(&any_storage, *Port(&bind_addr_storage_) + 1);
+
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), any));
+
+ // Connect the socket.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ Disconnect(sock_.get());
+
+ // Check that we're still bound.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, any, addrlen), 0);
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, Disconnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ struct sockaddr_storage any_storage = InetAnyAddr();
+ struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage);
+ SetPort(&any_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), any));
+
+ for (int i = 0; i < 2; i++) {
+ // Try to connect again.
+ EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Check that we're connected to the right peer.
+ struct sockaddr_storage peer;
+ socklen_t peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallSucceeds());
+ EXPECT_EQ(peerlen, addrlen_);
+ EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0);
+
+ // Try to disconnect.
+ struct sockaddr_storage addr = {};
+ addr.ss_family = AF_UNSPEC;
+ EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr),
+ sizeof(addr.ss_family)),
+ SyscallSucceeds());
+
+ peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallFailsWithErrno(ENOTCONN));
+
+ // Check that we're still bound.
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(*Port(&addr), *Port(&any_storage));
+ }
+}
+
+TEST_P(UdpSocketTest, ConnectBadAddress) {
+ struct sockaddr addr = {};
+ addr.sa_family = GetFamily();
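+ // The supplied address length only covers sa_family, which is too short.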
+ ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ struct sockaddr_storage addr_storage = InetAnyAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Send to a different destination than we're connected to.
+ char buf[512];
+ EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
+ // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock_ to loopback:bind_addr_port+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ char buf[3];
+ // Send zero length packet from bind_ to sock_.
+ ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {sock_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // Receive the packet.
+ char received[3];
+ EXPECT_THAT(read(sock_.get(), received, sizeof(received)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) {
+ // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:bind_addr_port+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Set sock to non-blocking.
+ int opts = 0;
+ ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK),
+ SyscallSucceeds());
+
+ char buf[3];
+ // Send zero length packet from bind_ to sock_.
+ ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {sock_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // Receive the packet.
+ char received[3];
+ EXPECT_THAT(read(sock_.get(), received, sizeof(received)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(read(sock_.get(), received, sizeof(received)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(UdpSocketTest, SendAndReceiveNotConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send some data to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, SendAndReceiveConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock_ to loopback:bind_addr_port+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, ReceiveFromNotConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:bind_addr_port+2.
+ struct sockaddr_storage addr2_storage = InetLoopbackAddr();
+ struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage);
+ SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2);
+ ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Check that the data isn't received because it was sent from a different
+ // address than the one we're connected to.
+ EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Bind sock to loopback:bind_addr_port+2.
+ struct sockaddr_storage addr2_storage = InetLoopbackAddr();
+ struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage);
+ SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2);
+ ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Receive the data. It works because it was sent before the connect.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+
+ // Send again. This time it should not be received.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, ReceiveFrom) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock_ to loopback:bind_addr_port+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data and sender address.
+ char received[sizeof(buf)];
+ struct sockaddr_storage addr2;
+ socklen_t addr2len = sizeof(addr2);
+ EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0,
+ reinterpret_cast<sockaddr*>(&addr2), &addr2len),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+ EXPECT_EQ(addr2len, addrlen_);
+ EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0);
+}
+
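+// UDP is connectionless, so listen() and accept() are unsupported.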
+TEST_P(UdpSocketTest, Listen) {
+ ASSERT_THAT(listen(sock_.get(), SOMAXCONN),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+TEST_P(UdpSocketTest, Accept) {
+ ASSERT_THAT(accept(sock_.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+// This test validates that a read shutdown with pending data allows the read
+// to proceed with the data before returning EAGAIN.
+TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind to loopback:bind_addr_port+1 and connect to bind_addr_.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Verify that we get EWOULDBLOCK when there is nothing to read.
+ char received[512];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ const char* buf = "abc";
+ EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3));
+
+ int opts = 0;
+ ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK),
+ SyscallSucceeds());
+ ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds());
+ ASSERT_NE(opts & O_NONBLOCK, 0);
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds());
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // We should get the data even though read has been shut down.
+ EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2));
+
+ // Because this is a datagram socket, the rest of the packet was discarded
+ // when we read less than the full length, so any subsequent reads should
+ // return EWOULDBLOCK.
+ EXPECT_THAT(recv(bind_.get(), received, 1, 0),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+// This test validates that even after a socket is shut down, reconnecting it
+// resets the shutdown state.
+TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) {
+ char received[512];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Connect the socket, then try to shutdown again.
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, ReadShutdown) {
+ // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without
+ // MSG_DONTWAIT blocks indefinitely.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ char received[512];
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Connect the socket, then try to shutdown again.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds());
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UdpSocketTest, ReadShutdownDifferentThread) {
+ // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without
+ // MSG_DONTWAIT blocks indefinitely.
+ SKIP_IF(IsRunningWithHostinet());
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ char received[512];
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Connect the socket, then shutdown from another thread.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Milliseconds(200));
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds());
+ });
+ EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(0));
+ t.Join();
+
+ EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UdpSocketTest, WriteShutdown) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, SynchronousReceive) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send some data to bind_ from another thread.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ // Receive the data prior to actually starting the other thread.
+ char received[512];
+ EXPECT_THAT(
+ RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Start the thread.
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Milliseconds(200));
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_,
+ this->addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ });
+
+ EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(512));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send 3 packets from sock to bind_.
+ constexpr int psize = 100;
+ char buf[3 * psize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_THAT(
+ sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(psize));
+ }
+
+ // Receive the data as 3 separate packets.
+ char received[6 * psize];
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0),
+ SyscallSucceedsWithValue(psize));
+ }
+ EXPECT_EQ(memcmp(buf, received, 3 * psize), 0);
+}
+
+TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Direct writes from sock to bind_.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Send 2 packets from sock to bind_, where each packet's data consists of
+ // 2 discontiguous iovecs.
+ constexpr size_t kPieceSize = 100;
+ char buf[4 * kPieceSize];
+ RandomizeBuffer(buf, sizeof(buf));
+
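+ // Packet i is assembled from pieces i and i+2 of buf, so the two packets'
+ // data is interleaved in the buffer.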
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[2];
+ for (int j = 0; j < 2; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ ASSERT_THAT(writev(sock_.get(), iov, 2),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+
+ // Receive the data as 2 separate packets.
+ char received[6 * kPieceSize];
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[3];
+ for (int j = 0; j < 3; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ ASSERT_THAT(readv(bind_.get(), iov, 3),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+ EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0);
+}
+
+TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send 2 packets from sock to bind_, where each packet's data consists of
+ // 2 discontiguous iovecs.
+ constexpr size_t kPieceSize = 100;
+ char buf[4 * kPieceSize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[2];
+ for (int j = 0; j < 2; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ struct msghdr msg = {};
+ msg.msg_name = bind_addr_;
+ msg.msg_namelen = addrlen_;
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 2;
+ ASSERT_THAT(sendmsg(sock_.get(), &msg, 0),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+
+ // Receive the data as 2 separate packets.
+ char received[6 * kPieceSize];
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[3];
+ for (int j = 0; j < 3; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ struct msghdr msg = {};
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 3;
+ ASSERT_THAT(recvmsg(bind_.get(), &msg, 0),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+ EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0);
+}
+
+TEST_P(UdpSocketTest, FIONREADShutdown) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ int n = -1;
+ EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ // A UDP socket must be connected before it can be shut down.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+}
+
+TEST_P(UdpSocketTest, FIONREADWriteShutdown) {
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // A UDP socket must be connected before it can be shut down.
+ ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ const char str[] = "abc";
+ ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0),
+ SyscallSucceedsWithValue(sizeof(str)));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, sizeof(str));
+
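+ // Data already queued should still be reported after a read shutdown.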
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, sizeof(str));
+}
+
+// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the
+// corresponding macro and become `0x541B`.
+TEST_P(UdpSocketTest, Fionread) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Check that the bound socket with an empty buffer reports an empty first
+ // packet.
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ // Send 3 packets from sock to bind_.
+ constexpr int psize = 100;
+ char buf[3 * psize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_THAT(
+ sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(psize));
+
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // Check that regardless of how many packets are in the queue, the size
+ // reported is that of a single packet.
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, psize);
+ }
+}
+
+TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Check that the bound socket with an empty buffer reports an empty first
+ // packet.
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ // Send 3 zero-length packets from sock_ to bind_.
+ constexpr int psize = 100;
+ char buf[3 * psize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_THAT(
+ sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(0));
+
+ // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet
+ // socket does not cause a poll event to be triggered.
+ if (!IsRunningWithHostinet()) {
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+ }
+
+ // Check that regardless of how many packets are in the queue, the size
+ // reported is that of a single packet.
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+ }
+}
+
+TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) {
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // A UDP socket must be connected before it can be shut down.
+ ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ const char str[] = "abc";
+ ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0));
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+}
+
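+// SO_NO_CHECK, when enabled, disables UDP checksum generation on outgoing
+// packets.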
+TEST_P(UdpSocketTest, SoNoCheckOffByDefault) {
+ // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = -1;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoNoCheck) {
+ // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = kSockOptOn;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen),
+ SyscallSucceeds());
+ v = -1;
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOn);
+ ASSERT_EQ(optlen, sizeof(v));
+
+ v = kSockOptOff;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen),
+ SyscallSucceeds());
+ v = -1;
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoTimestampOffByDefault) {
+ // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = -1;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoTimestamp) {
+ // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not
+ // supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ int v = 1;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)),
+ SyscallSucceeds());
+
+ char buf[3];
+ // Send zero length packet from sock to bind_.
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0),
+ SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
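+ // Receive the zero-length packet along with its SO_TIMESTAMP control
+ // message.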
+ char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))];
+ msghdr msg;
+ memset(&msg, 0, sizeof(msg));
+ iovec iov;
+ memset(&iov, 0, sizeof(iov));
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0),
+ SyscallSucceedsWithValue(0));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval)));
+
+ struct timeval tv = {};
+ memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval));
+
+ ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
+
+ // There should be nothing to get via ioctl.
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(UdpSocketTest, WriteShutdownNotConnected) {
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, TimestampIoctl) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ char buf[3];
+ // Send packet from sock to bind_.
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // There should be no control messages.
+ char recv_buf[sizeof(buf)];
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf)));
+
+ // A nonzero timeval should be available via ioctl.
+ struct timeval tv = {};
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds());
+ ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
+}
+
+TEST_P(UdpSocketTest, TimestampIoctlNothingRead) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct timeval tv = {};
+ ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+// Test that the timestamp accessed via SIOCGSTAMP is still accessible after
+// SO_TIMESTAMP is enabled and used to retrieve a timestamp.
+TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
+ // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not
+ // supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ char buf[3];
+ // Send packet from sock to bind_.
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0),
+ SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // There should be no control messages.
+ char recv_buf[sizeof(buf)];
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf)));
+
+ // A nonzero timeval should be available via ioctl.
+ struct timeval tv = {};
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds());
+ ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
+
+ // Enable SO_TIMESTAMP and send a message.
+ int v = 1;
+ EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // There should be a message for SO_TIMESTAMP.
+ char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))];
+ msghdr msg = {};
+ iovec iov = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0),
+ SyscallSucceedsWithValue(0));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+
+ // The ioctl should return the exact same values as before.
+ struct timeval tv2 = {};
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds());
+ ASSERT_EQ(tv.tv_sec, tv2.tv_sec);
+ ASSERT_EQ(tv.tv_usec, tv2.tv_usec);
+}
+
+// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on
+// outgoing packets, and that a receiving socket with IP_RECVTOS or
+// IPV6_RECVTCLASS will create the corresponding control message.
+TEST_P(UdpSocketTest, SetAndReceiveTOS) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ int recv_level = SOL_IP;
+ int recv_type = IP_RECVTOS;
+ if (GetParam() != AddressFamily::kIpv4) {
+ recv_level = SOL_IPV6;
+ recv_type = IPV6_RECVTCLASS;
+ }
+ ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Set socket TOS.
+ int sent_level = recv_level;
+ int sent_type = IP_TOS;
+ if (sent_level == SOL_IPV6) {
+ sent_type = IPV6_TCLASS;
+ }
+ int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value.
+ ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos,
+ sizeof(sent_tos)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ struct msghdr sent_msg = {};
+ struct iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = &sent_data[0];
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ // Receive message.
+ struct msghdr received_msg = {};
+ struct iovec received_iov = {};
+ char received_data[kDataLength];
+ received_iov.iov_base = &received_data[0];
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
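+ // IP_TOS is passed in a single byte, while IPV6_TCLASS is a full int.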
+ size_t cmsg_data_len = sizeof(int8_t);
+ if (sent_type == IPV6_TCLASS) {
+ cmsg_data_len = sizeof(int);
+ }
+ std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len));
+ received_msg.msg_control = &received_cmsgbuf[0];
+ received_msg.msg_controllen = received_cmsgbuf.size();
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, sent_level);
+ EXPECT_EQ(cmsg->cmsg_type, sent_type);
+ int8_t received_tos = 0;
+ memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos));
+ EXPECT_EQ(received_tos, sent_tos);
+}
+
+// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the
+// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or
+// IPV6_RECVTCLASS will create the corresponding control message.
+TEST_P(UdpSocketTest, SendAndReceiveTOS) {
+ // TODO(b/146661005): Setting TOS via cmsg not supported for netstack.
+ SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ int recv_level = SOL_IP;
+ int recv_type = IP_RECVTOS;
+ if (GetParam() != AddressFamily::kIpv4) {
+ recv_level = SOL_IPV6;
+ recv_type = IPV6_RECVTCLASS;
+ }
+ int recv_opt = kSockOptOn;
+ ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt,
+ sizeof(recv_opt)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ int sent_level = recv_level;
+ int sent_type = IP_TOS;
+ int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value.
+
+ struct msghdr sent_msg = {};
+ struct iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = &sent_data[0];
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ size_t cmsg_data_len = sizeof(int8_t);
+ if (sent_level == SOL_IPV6) {
+ sent_type = IPV6_TCLASS;
+ cmsg_data_len = sizeof(int);
+ }
+ std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len));
+ sent_msg.msg_control = &sent_cmsgbuf[0];
+ sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+
+ // Manually add control message.
+ struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg);
+ sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len);
+ sent_cmsg->cmsg_level = sent_level;
+ sent_cmsg->cmsg_type = sent_type;
+ *reinterpret_cast<int8_t*>(CMSG_DATA(sent_cmsg)) = sent_tos;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ // Receive message.
+ struct msghdr received_msg = {};
+ struct iovec received_iov = {};
+ char received_data[kDataLength];
+ received_iov.iov_base = &received_data[0];
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len));
+ received_msg.msg_control = &received_cmsgbuf[0];
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, sent_level);
+ EXPECT_EQ(cmsg->cmsg_type, sent_type);
+ int8_t received_tos = 0;
+ memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos));
+ EXPECT_EQ(received_tos, sent_tos);
+}
+
+TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
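+ // Read back the effective size; the kernel clamps it to a minimum.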
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Bind bind_ to loopback.
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ {
+ // Send data of size min and verify that it's received.
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(received.size()));
+ }
+
+ {
+ // Send data of size min + 1 and verify that it's received. Both Linux and
+ // netstack accept a datagram that exceeds rcvBuf limits if the receive
+ // buffer is currently empty.
+ std::vector<char> buf(min + 1);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(received.size()));
+ }
+}
+
+// Test that receive buffer limits are enforced.
+TEST_P(UdpSocketTest, RecvBufLimits) {
+ // Bind bind_ to loopback.
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ int min = 0;
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ // Now set the limit to min * 4.
+ int new_rcv_buf_sz = min * 4;
+ if (!IsRunningOnGvisor() || IsRunningWithHostinet()) {
+ // Linux doubles the value specified so just set to min * 2.
+ new_rcv_buf_sz = min * 2;
+ }
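+ // Illustrative numbers only: if min were 2304, gVisor would be asked for
+ // 9216 directly, while Linux would be asked for 4608 and double it
+ // internally to 9216, leaving both with the same effective limit.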
+
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz,
+ sizeof(new_rcv_buf_sz)),
+ SyscallSucceeds());
+ int rcv_buf_sz = 0;
+ {
+ socklen_t rcv_buf_len = sizeof(rcv_buf_sz);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz,
+ &rcv_buf_len),
+ SyscallSucceeds());
+ }
+
+ {
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ int sent = 4;
+ if (IsRunningOnGvisor() && !IsRunningWithHostinet()) {
+ // Linux seems to drop the 4th packet even though technically it should
+ // fit in the receive buffer.
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ sent++;
+ }
+
+ for (int i = 0; i < sent - 1; i++) {
+ // Receive the data.
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(received.size()));
+ EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0);
+ }
+
+ // The last receive should fail with EAGAIN as the last packet should have
+ // been dropped due to lack of space in the receive buffer.
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h
new file mode 100644
index 000000000..f7e25c805
--- /dev/null
+++ b/test/syscalls/linux/udp_socket_test_cases.h
@@ -0,0 +1,82 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_UDP_SOCKET_TEST_CASES_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_UDP_SOCKET_TEST_CASES_H_
+
+#include <sys/socket.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// The initial port to be used on gVisor.
+constexpr int TestPort = 40000;
+
+// Fixture for tests parameterized by the address family to use (AF_INET and
+// AF_INET6) when creating sockets.
+class UdpSocketTest
+ : public ::testing::TestWithParam<gvisor::testing::AddressFamily> {
+ protected:
+ // Creates two sockets that will be used by test cases.
+ void SetUp() override;
+
+ // Binds the socket bind_ to the loopback and updates bind_addr_.
+ PosixError BindLoopback();
+
+ // Binds the socket bind_ to Any and updates bind_addr_.
+ PosixError BindAny();
+
+ // Binds the given socket to address addr, updating addr with the actual
+ // bound address.
+ PosixError BindSocket(int socket, struct sockaddr* addr);
+
+ // Returns a wildcard (any) address initialized with port 0.
+ struct sockaddr_storage InetAnyAddr();
+
+ // Returns a loopback address initialized with port 0.
+ struct sockaddr_storage InetLoopbackAddr();
+
+ // Disconnects socket sockfd.
+ void Disconnect(int sockfd);
+
+ // Returns the address family under test.
+ int GetFamily();
+
+ // Socket used by the Bind methods.
+ FileDescriptor bind_;
+
+ // Second socket used for tests.
+ FileDescriptor sock_;
+
+ // Address for bind_ socket.
+ struct sockaddr* bind_addr_;
+
+ // Initialized to the length based on GetFamily().
+ socklen_t addrlen_;
+
+ // Storage for bind_addr_.
+ struct sockaddr_storage bind_addr_storage_;
+
+ private:
+ // Helper to initialize addrlen_ for the test case.
+ socklen_t GetAddrLength();
+};
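+
+// Illustrative only: concrete test binaries instantiate this fixture over
+// the address families under test. The canonical instantiation lives in the
+// corresponding .cc file, but a sketch looks like:
+//
+//   INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
+//                            ::testing::Values(AddressFamily::kIpv4,
+//                                              AddressFamily::kIpv6,
+//                                              AddressFamily::kDualStack));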
+} // namespace testing
+} // namespace gvisor
+
+#endif  // GVISOR_TEST_SYSCALLS_LINUX_UDP_SOCKET_TEST_CASES_H_
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
new file mode 100644
index 000000000..64d6d0b8f
--- /dev/null
+++ b/test/syscalls/linux/uidgid.cc
@@ -0,0 +1,276 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <grp.h>
+#include <sys/resource.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "test/util/capability_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/uid_util.h"
+
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID");
+
+using ::testing::UnorderedElementsAreArray;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(UidGidTest, Getuid) {
+ uid_t ruid, euid, suid;
+ EXPECT_THAT(getresuid(&ruid, &euid, &suid), SyscallSucceeds());
+ EXPECT_THAT(getuid(), SyscallSucceedsWithValue(ruid));
+ EXPECT_THAT(geteuid(), SyscallSucceedsWithValue(euid));
+}
+
+TEST(UidGidTest, Getgid) {
+ gid_t rgid, egid, sgid;
+ EXPECT_THAT(getresgid(&rgid, &egid, &sgid), SyscallSucceeds());
+ EXPECT_THAT(getgid(), SyscallSucceedsWithValue(rgid));
+ EXPECT_THAT(getegid(), SyscallSucceedsWithValue(egid));
+}
+
+TEST(UidGidTest, Getgroups) {
+ // "If size is zero, list is not modified, but the total number of
+ // supplementary group IDs for the process is returned." - getgroups(2)
+ int nr_groups;
+ ASSERT_THAT(nr_groups = getgroups(0, nullptr), SyscallSucceeds());
+ std::vector<gid_t> list(nr_groups);
+ EXPECT_THAT(getgroups(list.size(), list.data()), SyscallSucceeds());
+
+ // "EINVAL: size is less than the number of supplementary group IDs, but is
+ // not zero."
+ EXPECT_THAT(getgroups(-1, nullptr), SyscallFailsWithErrno(EINVAL));
+
+ // Testing for EFAULT requires actually having groups, which isn't guaranteed
+ // here; see the setgroups test below.
+}
+
+// Checks that the calling process' real/effective/saved user IDs are
+// ruid/euid/suid respectively.
+PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) {
+ uid_t actual_ruid, actual_euid, actual_suid;
+ int rc = getresuid(&actual_ruid, &actual_euid, &actual_suid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresuid");
+ }
+ if (ruid != actual_ruid || euid != actual_euid || suid != actual_suid) {
+ return PosixError(
+ EPERM, absl::StrCat(
+ "incorrect user IDs: got (",
+ absl::StrJoin({actual_ruid, actual_euid, actual_suid}, ", "),
+ ", wanted (", absl::StrJoin({ruid, euid, suid}, ", "), ")"));
+ }
+ return NoError();
+}
+
+PosixError CheckGIDs(gid_t rgid, gid_t egid, gid_t sgid) {
+ gid_t actual_rgid, actual_egid, actual_sgid;
+ int rc = getresgid(&actual_rgid, &actual_egid, &actual_sgid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresgid");
+ }
+ if (rgid != actual_rgid || egid != actual_egid || sgid != actual_sgid) {
+ return PosixError(
+ EPERM, absl::StrCat(
+ "incorrect group IDs: got (",
+ absl::StrJoin({actual_rgid, actual_egid, actual_sgid}, ", "),
+ ", wanted (", absl::StrJoin({rgid, egid, sgid}, ", "), ")"));
+ }
+ return NoError();
+}
+
+// N.B. These tests may break horribly unless run via a gVisor test runner,
+// because changing UID in one test may forfeit permissions required by other
+// tests. (The test runner runs each test in a separate process.)
+
+TEST(UidGidRootTest, Setuid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test. After calling
+ // setuid(non-zero-UID), there is no way to get root privileges back.
+ ScopedThread([&] {
+ // Use syscall instead of glibc setuid wrapper because we want this setuid
+ // call to only apply to this task. POSIX threads, however, require that all
+ // threads have the same UIDs, so using the setuid wrapper sets all threads'
+ // real UID.
+ EXPECT_THAT(syscall(SYS_setuid, -1), SyscallFailsWithErrno(EINVAL));
+
+ const uid_t uid = absl::GetFlag(FLAGS_scratch_uid1);
+ EXPECT_THAT(syscall(SYS_setuid, uid), SyscallSucceeds());
+ // "If the effective UID of the caller is root (more precisely: if the
+ // caller has the CAP_SETUID capability), the real UID and saved set-user-ID
+ // are also set." - setuid(2)
+ EXPECT_NO_ERRNO(CheckUIDs(uid, uid, uid));
+ });
+}
+
+TEST(UidGidRootTest, Setgid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL));
+
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
+ ASSERT_THAT(setgid(gid), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid));
+}
+
+TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
+ // NOTE(b/64676707): Do setgid in a separate thread so that we can test if
+ // info.si_pid is set correctly.
+ ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); });
+ EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid));
+}
+
+TEST(UidGidRootTest, Setreuid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ // "Supplying a value of -1 for either the real or effective user ID forces
+ // the system to leave that ID unchanged." - setreuid(2)
+ EXPECT_THAT(setreuid(-1, -1), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0));
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test. After calling
+ // setuid(non-zero-UID), there is no way to get root privileges back.
+ ScopedThread([&] {
+ const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1);
+ const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2);
+
+ // Use syscall instead of the glibc setreuid wrapper because we want this
+ // setreuid call to only apply to this task. POSIX threads, however, require
+ // that all threads have the same UIDs, so using the wrapper sets all
+ // threads' real UID.
+ EXPECT_THAT(syscall(SYS_setreuid, ruid, euid), SyscallSucceeds());
+
+ // "If the real user ID is set or the effective user ID is set to a value
+ // not equal to the previous real user ID, the saved set-user-ID will be set
+ // to the new effective user ID." - setreuid(2)
+ EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, euid));
+ });
+}
+
+TEST(UidGidRootTest, Setregid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ EXPECT_THAT(setregid(-1, -1), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0));
+
+ const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1);
+ const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2);
+ ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid));
+}
+
+TEST(UidGidRootTest, Setresuid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ // "If one of the arguments equals -1, the corresponding value is not
+ // changed." - setresuid(2)
+ EXPECT_THAT(setresuid(-1, -1, -1), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0));
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test. After calling
+ // setuid(non-zero-UID), there is no way to get root privileges back.
+ ScopedThread([&] {
+ const uid_t ruid = 12345;
+ const uid_t euid = 23456;
+ const uid_t suid = 34567;
+
+ // Use syscall instead of the glibc setresuid wrapper because we want this
+ // setresuid call to only apply to this task. POSIX threads, however,
+ // require that all threads have the same UIDs, so using the wrapper sets
+ // all threads' real UID.
+ EXPECT_THAT(syscall(SYS_setresuid, ruid, euid, suid), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, suid));
+ });
+}
+
+TEST(UidGidRootTest, Setresgid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ EXPECT_THAT(setresgid(-1, -1, -1), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0));
+
+ const gid_t rgid = 12345;
+ const gid_t egid = 23456;
+ const gid_t sgid = 34567;
+ ASSERT_THAT(setresgid(rgid, egid, sgid), SyscallSucceeds());
+ EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, sgid));
+}
+
+TEST(UidGidRootTest, Setgroups) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ std::vector<gid_t> list = {123, 500};
+ ASSERT_THAT(setgroups(list.size(), list.data()), SyscallSucceeds());
+ std::vector<gid_t> list2(list.size());
+ ASSERT_THAT(getgroups(list2.size(), list2.data()), SyscallSucceeds());
+ EXPECT_THAT(list, UnorderedElementsAreArray(list2));
+
+ // "EFAULT: list has an invalid address."
+ EXPECT_THAT(getgroups(100, reinterpret_cast<gid_t*>(-1)),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST(UidGidRootTest, Setuid_prlimit) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ // Do seteuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test.
+ ScopedThread([&] {
+ // Use syscall instead of the glibc setreuid wrapper because we want this
+ // setreuid call to only apply to this task. POSIX threads, however, require
+ // that all threads have the same UIDs, so using the wrapper sets all
+ // threads' effective UID.
+ EXPECT_THAT(syscall(SYS_setreuid, -1, 65534), SyscallSucceeds());
+
+ // Despite the UID change, we should be able to get our own limits.
+ struct rlimit rl = {};
+ EXPECT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds());
+ });
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc
new file mode 100644
index 000000000..d8824b171
--- /dev/null
+++ b/test/syscalls/linux/uname.cc
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+#include <sys/utsname.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(UnameTest, Sanity) {
+ struct utsname buf;
+ ASSERT_THAT(uname(&buf), SyscallSucceeds());
+ EXPECT_NE(strlen(buf.release), 0);
+ EXPECT_NE(strlen(buf.version), 0);
+ EXPECT_NE(strlen(buf.machine), 0);
+ EXPECT_NE(strlen(buf.sysname), 0);
+ EXPECT_NE(strlen(buf.nodename), 0);
+ EXPECT_NE(strlen(buf.domainname), 0);
+}
+
+TEST(UnameTest, SetNames) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ char hostname[65];
+ ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "012");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
+ constexpr char kHostname[] = "wubbalubba";
+ ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds());
+
+ constexpr char kDomainname[] = "dubdub.com";
+ ASSERT_THAT(setdomainname(kDomainname, sizeof(kDomainname)),
+ SyscallSucceeds());
+
+ struct utsname buf;
+ EXPECT_THAT(uname(&buf), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(buf.nodename), kHostname);
+ EXPECT_EQ(absl::string_view(buf.domainname), kDomainname);
+
+ // These should just be glibc wrappers that also call uname(2).
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), kHostname);
+
+ char domainname[65];
+ EXPECT_THAT(getdomainname(domainname, sizeof(domainname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(domainname), kDomainname);
+}
+
+TEST(UnameTest, UnprivilegedSetNames) {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) {
+ EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false));
+ }
+
+ EXPECT_THAT(sethostname("", 0), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(setdomainname("", 0), SyscallFailsWithErrno(EPERM));
+}
+
+TEST(UnameTest, UnshareUTS) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ struct utsname init;
+ ASSERT_THAT(uname(&init), SyscallSucceeds());
+
+ ScopedThread([&]() {
+ EXPECT_THAT(unshare(CLONE_NEWUTS), SyscallSucceeds());
+
+ constexpr char kHostname[] = "wubbalubba";
+ EXPECT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds());
+
+ char hostname[65];
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ });
+
+ struct utsname after;
+ EXPECT_THAT(uname(&after), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(after.nodename), init.nodename);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/unix_domain_socket_test_util.cc b/test/syscalls/linux/unix_domain_socket_test_util.cc
new file mode 100644
index 000000000..b05ab2900
--- /dev/null
+++ b/test/syscalls/linux/unix_domain_socket_test_util.cc
@@ -0,0 +1,351 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+
+#include <sys/un.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+std::string DescribeUnixDomainSocketType(int type) {
+ const char* type_str = nullptr;
+ switch (type & ~(SOCK_NONBLOCK | SOCK_CLOEXEC)) {
+ case SOCK_STREAM:
+ type_str = "SOCK_STREAM";
+ break;
+ case SOCK_DGRAM:
+ type_str = "SOCK_DGRAM";
+ break;
+ case SOCK_SEQPACKET:
+ type_str = "SOCK_SEQPACKET";
+ break;
+ }
+ if (!type_str) {
+ return absl::StrCat("Unix domain socket with unknown type ", type);
+ } else {
+ return absl::StrCat(((type & SOCK_NONBLOCK) != 0) ? "non-blocking " : "",
+ ((type & SOCK_CLOEXEC) != 0) ? "close-on-exec " : "",
+ type_str, " Unix domain socket");
+ }
+}
+
+SocketPairKind UnixDomainSocketPair(int type) {
+ return SocketPairKind{DescribeUnixDomainSocketType(type), AF_UNIX, type, 0,
+ SyscallSocketPairCreator(AF_UNIX, type, 0)};
+}
+
+SocketPairKind FilesystemBoundUnixDomainSocketPair(int type) {
+ std::string description = absl::StrCat(DescribeUnixDomainSocketType(type),
+ " created with filesystem binding");
+ if ((type & SOCK_DGRAM) == SOCK_DGRAM) {
+ return SocketPairKind{
+ description, AF_UNIX, type, 0,
+ FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)};
+ }
+ return SocketPairKind{
+ description, AF_UNIX, type, 0,
+ FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)};
+}
+
+SocketPairKind AbstractBoundUnixDomainSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeUnixDomainSocketType(type),
+ " created with abstract namespace binding");
+ if ((type & SOCK_DGRAM) == SOCK_DGRAM) {
+ return SocketPairKind{
+ description, AF_UNIX, type, 0,
+ AbstractBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)};
+ }
+ return SocketPairKind{description, AF_UNIX, type, 0,
+ AbstractAcceptBindSocketPairCreator(AF_UNIX, type, 0)};
+}
+
+SocketPairKind SocketpairGoferUnixDomainSocketPair(int type) {
+ std::string description = absl::StrCat(DescribeUnixDomainSocketType(type),
+ " created with the socketpair gofer");
+ return SocketPairKind{description, AF_UNIX, type, 0,
+ SocketpairGoferSocketPairCreator(AF_UNIX, type, 0)};
+}
+
+SocketPairKind SocketpairGoferFileSocketPair(int type) {
+ std::string description =
+ absl::StrCat(((type & O_NONBLOCK) != 0) ? "non-blocking " : "",
+ ((type & O_CLOEXEC) != 0) ? "close-on-exec " : "",
+ "file socket created with the socketpair gofer");
+ // The socketpair gofer always creates SOCK_STREAM sockets on open(2).
+ return SocketPairKind{description, AF_UNIX, SOCK_STREAM, 0,
+ SocketpairGoferFileSocketPairCreator(type)};
+}
+
+SocketPairKind FilesystemUnboundUnixDomainSocketPair(int type) {
+ return SocketPairKind{absl::StrCat(DescribeUnixDomainSocketType(type),
+ " unbound with a filesystem address"),
+ AF_UNIX, type, 0,
+ FilesystemUnboundSocketPairCreator(AF_UNIX, type, 0)};
+}
+
+SocketPairKind AbstractUnboundUnixDomainSocketPair(int type) {
+ return SocketPairKind{
+ absl::StrCat(DescribeUnixDomainSocketType(type),
+ " unbound with an abstract namespace address"),
+ AF_UNIX, type, 0, AbstractUnboundSocketPairCreator(AF_UNIX, type, 0)};
+}
+
+void SendSingleFD(int sock, int fd, char buf[], int buf_size) {
+ ASSERT_NO_FATAL_FAILURE(SendFDs(sock, &fd, 1, buf, buf_size));
+}
+
+void SendFDs(int sock, int fds[], int fds_size, char buf[], int buf_size) {
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(fds_size * sizeof(int)));
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_len = CMSG_LEN(fds_size * sizeof(int));
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ for (int i = 0; i < fds_size; i++) {
+ memcpy(CMSG_DATA(cmsg) + i * sizeof(int), &fds[i], sizeof(int));
+ }
+
+ ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
+ IsPosixErrorOkAndHolds(buf_size));
+}
+
+void RecvSingleFD(int sock, int* fd, char buf[], int buf_size) {
+ ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, buf_size));
+}
+
+void RecvSingleFD(int sock, int* fd, char buf[], int buf_size,
+ int expected_size) {
+ ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, expected_size));
+}
+
+void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size) {
+ ASSERT_NO_FATAL_FAILURE(
+ RecvFDs(sock, fds, fds_size, buf, buf_size, buf_size));
+}
+
+void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size,
+ int expected_size, bool peek) {
+ struct msghdr msg = {};
+ std::vector<char> control(CMSG_SPACE(fds_size * sizeof(int)));
+ msg.msg_control = &control[0];
+ msg.msg_controllen = control.size();
+
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = buf_size;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ int flags = 0;
+ if (peek) {
+ flags |= MSG_PEEK;
+ }
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, flags),
+ SyscallSucceedsWithValue(expected_size));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(fds_size * sizeof(int)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+
+ for (int i = 0; i < fds_size; i++) {
+ memcpy(&fds[i], CMSG_DATA(cmsg) + i * sizeof(int), sizeof(int));
+ }
+}
+
+void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size,
+ int expected_size) {
+ ASSERT_NO_FATAL_FAILURE(
+ RecvFDs(sock, fds, fds_size, buf, buf_size, expected_size, false));
+}
+
+void PeekSingleFD(int sock, int* fd, char buf[], int buf_size) {
+ ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, buf_size, true));
+}
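+
+// N.B. peeking an SCM_RIGHTS message is not side-effect free: each MSG_PEEK
+// delivery installs a fresh file descriptor in the receiving process, so the
+// peeked FD and the one returned by a later non-peek receive are distinct
+// descriptors referring to the same file.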
+
+void RecvNoCmsg(int sock, char buf[], int buf_size, int expected_size) {
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(struct ucred))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = buf_size;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
+ SyscallSucceedsWithValue(expected_size));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ EXPECT_EQ(cmsg, nullptr);
+}
+
+void SendNullCmsg(int sock, char buf[], int buf_size) {
+ struct msghdr msg = {};
+ msg.msg_control = nullptr;
+ msg.msg_controllen = 0;
+
+ ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
+ IsPosixErrorOkAndHolds(buf_size));
+}
+
+void SendCreds(int sock, ucred creds, char buf[], int buf_size) {
+ struct msghdr msg = {};
+
+ char control[CMSG_SPACE(sizeof(struct ucred))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_CREDENTIALS;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
+ memcpy(CMSG_DATA(cmsg), &creds, sizeof(struct ucred));
+
+ ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
+ IsPosixErrorOkAndHolds(buf_size));
+}
+
+void SendCredsAndFD(int sock, ucred creds, int fd, char buf[], int buf_size) {
+ struct msghdr msg = {};
+
+ char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg1 = CMSG_FIRSTHDR(&msg);
+ cmsg1->cmsg_level = SOL_SOCKET;
+ cmsg1->cmsg_type = SCM_CREDENTIALS;
+ cmsg1->cmsg_len = CMSG_LEN(sizeof(struct ucred));
+ memcpy(CMSG_DATA(cmsg1), &creds, sizeof(struct ucred));
+
+ struct cmsghdr* cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+ cmsg2->cmsg_level = SOL_SOCKET;
+ cmsg2->cmsg_type = SCM_RIGHTS;
+ cmsg2->cmsg_len = CMSG_LEN(sizeof(int));
+ memcpy(CMSG_DATA(cmsg2), &fd, sizeof(int));
+
+ ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
+ IsPosixErrorOkAndHolds(buf_size));
+}
+
+void RecvCreds(int sock, ucred* creds, char buf[], int buf_size) {
+ ASSERT_NO_FATAL_FAILURE(RecvCreds(sock, creds, buf, buf_size, buf_size));
+}
+
+void RecvCreds(int sock, ucred* creds, char buf[], int buf_size,
+ int expected_size) {
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(struct ucred))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = buf_size;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
+ SyscallSucceedsWithValue(expected_size));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+
+ memcpy(creds, CMSG_DATA(cmsg), sizeof(struct ucred));
+}
+
+void RecvCredsAndFD(int sock, ucred* creds, int* fd, char buf[], int buf_size) {
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = buf_size;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
+ SyscallSucceedsWithValue(buf_size));
+
+ struct cmsghdr* cmsg1 = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg1, nullptr);
+ ASSERT_EQ(cmsg1->cmsg_len, CMSG_LEN(sizeof(struct ucred)));
+ ASSERT_EQ(cmsg1->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg1->cmsg_type, SCM_CREDENTIALS);
+ memcpy(creds, CMSG_DATA(cmsg1), sizeof(struct ucred));
+
+ struct cmsghdr* cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+ ASSERT_NE(cmsg2, nullptr);
+ ASSERT_EQ(cmsg2->cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(cmsg2->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg2->cmsg_type, SCM_RIGHTS);
+ memcpy(fd, CMSG_DATA(cmsg2), sizeof(int));
+}
+
+void RecvSingleFDUnaligned(int sock, int* fd, char buf[], int buf_size) {
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int)) - sizeof(int)];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = buf_size;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
+ SyscallSucceedsWithValue(buf_size));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
+
+ memcpy(fd, CMSG_DATA(cmsg), sizeof(int));
+}
+
+void SetSoPassCred(int sock) {
+ int one = 1;
+ EXPECT_THAT(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one)),
+ SyscallSucceeds());
+}
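+
+// With SO_PASSCRED enabled, the kernel attaches an SCM_CREDENTIALS control
+// message to every message read from the socket, even if the sender supplied
+// none; the RecvCreds helpers above rely on this behavior.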
+
+void UnsetSoPassCred(int sock) {
+ int zero = 0;
+ EXPECT_THAT(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &zero, sizeof(zero)),
+ SyscallSucceeds());
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/unix_domain_socket_test_util.h b/test/syscalls/linux/unix_domain_socket_test_util.h
new file mode 100644
index 000000000..b8073db17
--- /dev/null
+++ b/test/syscalls/linux/unix_domain_socket_test_util.h
@@ -0,0 +1,162 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_
+
+#include <string>
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// DescribeUnixDomainSocketType returns a human-readable string explaining the
+// given Unix domain socket type.
+std::string DescribeUnixDomainSocketType(int type);
+
+// UnixDomainSocketPair returns a SocketPairKind that represents SocketPairs
+// created by invoking the socketpair() syscall with AF_UNIX and the given type.
+SocketPairKind UnixDomainSocketPair(int type);
+
+// FilesystemBoundUnixDomainSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and accept() syscalls with a temp file path,
+// AF_UNIX and the given type.
+SocketPairKind FilesystemBoundUnixDomainSocketPair(int type);
+
+// AbstractBoundUnixDomainSocketPair returns a SocketPairKind that represents
+// SocketPairs created with bind() and accept() syscalls with a temp abstract
+// path, AF_UNIX and the given type.
+SocketPairKind AbstractBoundUnixDomainSocketPair(int type);
+
+// SocketpairGoferUnixDomainSocketPair returns a SocketPairKind that was created
+// with two sockets connected to the socketpair gofer.
+SocketPairKind SocketpairGoferUnixDomainSocketPair(int type);
+
+// SocketpairGoferFileSocketPair returns a SocketPairKind that was created with
+// two open() calls on paths backed by the socketpair gofer.
+SocketPairKind SocketpairGoferFileSocketPair(int type);
+
+// FilesystemUnboundUnixDomainSocketPair returns a SocketPairKind that
+// represents two unbound sockets and a filesystem path for binding.
+SocketPairKind FilesystemUnboundUnixDomainSocketPair(int type);
+
+// AbstractUnboundUnixDomainSocketPair returns a SocketPairKind that represents
+// two unbound sockets and an abstract namespace path for binding.
+SocketPairKind AbstractUnboundUnixDomainSocketPair(int type);
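+
+// Illustrative only: the SocketPairKind factories above are typically fed to
+// value-parameterized suites via helpers assumed to come from
+// socket_test_util.h (ApplyVec, AllBitwiseCombinations, List). A sketch, not
+// the canonical instantiation:
+//
+//   INSTANTIATE_TEST_SUITE_P(
+//       AllUnixDomainSockets, AllSocketPairTest,
+//       ::testing::ValuesIn(ApplyVec<SocketPairKind>(
+//           UnixDomainSocketPair,
+//           AllBitwiseCombinations(List<int>{SOCK_STREAM, SOCK_DGRAM},
+//                                  List<int>{0, SOCK_NONBLOCK}))));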
+
+// SendSingleFD sends both a single FD and some data over a unix domain socket
+// specified by an FD. Note that calls to this function must be wrapped in
+// ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void SendSingleFD(int sock, int fd, char buf[], int buf_size);
+
+// SendFDs sends an arbitrary number of FDs and some data over a unix domain
+// socket specified by an FD. Note that calls to this function must be wrapped
+// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void SendFDs(int sock, int fds[], int fds_size, char buf[], int buf_size);
+
+// RecvSingleFD receives both a single FD and some data over a unix domain
+// socket specified by an FD. Note that calls to this function must be wrapped
+// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void RecvSingleFD(int sock, int* fd, char buf[], int buf_size);
+
+// RecvSingleFD receives both a single FD and some data over a unix domain
+// socket specified by an FD. This version allows the expected amount of data
+// received to be different than the buffer size. Note that calls to this
+// function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions
+// to halt the test.
+void RecvSingleFD(int sock, int* fd, char buf[], int buf_size,
+ int expected_size);
+
+// PeekSingleFD peeks at both a single FD and some data over a unix domain
+// socket specified by an FD. Note that calls to this function must be wrapped
+// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void PeekSingleFD(int sock, int* fd, char buf[], int buf_size);
+
+// RecvFDs receives both an arbitrary number of FDs and some data over a unix
+// domain socket specified by an FD. Note that calls to this function must be
+// wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size);
+
+// RecvFDs receives both an arbitrary number of FDs and some data over a unix
+// domain socket specified by an FD. This version allows the expected amount of
+// data received to be different than the buffer size. Note that calls to this
+// function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions
+// to halt the test.
+void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size,
+ int expected_size);
+
+// RecvNoCmsg receives some data over a unix domain socket specified by an FD
+// and asserts that no control messages are available for receiving. Note that
+// calls to this function must be wrapped in ASSERT_NO_FATAL_FAILURE for
+// internal assertions to halt the test.
+void RecvNoCmsg(int sock, char buf[], int buf_size, int expected_size);
+
+inline void RecvNoCmsg(int sock, char buf[], int buf_size) {
+ RecvNoCmsg(sock, buf, buf_size, buf_size);
+}
+
+// SendCreds sends the credentials of the current process and some data over a
+// unix domain socket specified by an FD. Note that calls to this function must
+// be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the
+// test.
+void SendCreds(int sock, ucred creds, char buf[], int buf_size);
+
+// SendCredsAndFD sends the credentials of the current process, a single FD, and
+// some data over a unix domain socket specified by an FD. Note that calls to
+// this function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal
+// assertions to halt the test.
+void SendCredsAndFD(int sock, ucred creds, int fd, char buf[], int buf_size);
+
+// RecvCreds receives some credentials and some data over a unix domain socket
+// specified by an FD. Note that calls to this function must be wrapped in
+// ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void RecvCreds(int sock, ucred* creds, char buf[], int buf_size);
+
+// RecvCreds receives some credentials and some data over a unix domain socket
+// specified by an FD. This version allows the expected amount of data received
+// to be different than the buffer size. Note that calls to this function must
+// be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the
+// test.
+void RecvCreds(int sock, ucred* creds, char buf[], int buf_size,
+ int expected_size);
+
+// RecvCredsAndFD receives some credentials, a single FD, and some data over a
+// unix domain socket specified by an FD. Note that calls to this function must
+// be wrapped in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the
+// test.
+void RecvCredsAndFD(int sock, ucred* creds, int* fd, char buf[], int buf_size);
+
+// SendNullCmsg sends a null control message and some data over a unix domain
+// socket specified by an FD. Note that calls to this function must be wrapped
+// in ASSERT_NO_FATAL_FAILURE for internal assertions to halt the test.
+void SendNullCmsg(int sock, char buf[], int buf_size);
+
+// RecvSingleFDUnaligned receives both a single FD and some data over a unix
+// domain socket specified by an FD, using a control buffer too small to hold
+// a fully aligned cmsg. This does not obey the spec, but Linux
+// allows it and the apphosting code depends on this quirk. Note that calls to
+// this function must be wrapped in ASSERT_NO_FATAL_FAILURE for internal
+// assertions to halt the test.
+void RecvSingleFDUnaligned(int sock, int* fd, char buf[], int buf_size);
+
+// SetSoPassCred sets the SO_PASSCRED option on the specified socket.
+void SetSoPassCred(int sock);
+
+// UnsetSoPassCred clears the SO_PASSCRED option on the specified socket.
+void UnsetSoPassCred(int sock);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_
diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc
new file mode 100644
index 000000000..2040375c9
--- /dev/null
+++ b/test/syscalls/linux/unlink.cc
@@ -0,0 +1,214 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(UnlinkTest, IsDir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ EXPECT_THAT(unlink(dir.path().c_str()), SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(UnlinkTest, DirNotEmpty) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ int fd;
+ std::string path = JoinPath(dir.path(), "ExistingFile");
+ EXPECT_THAT(fd = open(path.c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ EXPECT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(ENOTEMPTY));
+}
+
+TEST(UnlinkTest, Rmdir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(rmdir(dir.path().c_str()), SyscallSucceeds());
+}
+
+TEST(UnlinkTest, AtDir) {
+ int dirfd;
+ auto tmpdir = GetAbsoluteTestTmpdir();
+ EXPECT_THAT(dirfd = open(tmpdir.c_str(), O_DIRECTORY, 0), SyscallSucceeds());
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(tmpdir));
+ auto dir_relpath =
+ ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(tmpdir, dir.path()));
+ EXPECT_THAT(unlinkat(dirfd, dir_relpath.c_str(), AT_REMOVEDIR),
+ SyscallSucceeds());
+ ASSERT_THAT(close(dirfd), SyscallSucceeds());
+}
+
+TEST(UnlinkTest, AtDirDegradedPermissions_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ int dirfd;
+ ASSERT_THAT(dirfd = open(dir.path().c_str(), O_DIRECTORY, 0),
+ SyscallSucceeds());
+
+ std::string sub_dir = JoinPath(dir.path(), "NewDir");
+ EXPECT_THAT(mkdir(sub_dir.c_str(), 0755), SyscallSucceeds());
+ EXPECT_THAT(fchmod(dirfd, 0444), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd, "NewDir", AT_REMOVEDIR),
+ SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(close(dirfd), SyscallSucceeds());
+}
+
+// Files cannot be unlinked if the parent is not writable and executable.
+TEST(UnlinkTest, ParentDegradedPermissions) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ ASSERT_THAT(chmod(dir.path().c_str(), 0000), SyscallSucceeds());
+
+ struct stat st;
+ ASSERT_THAT(stat(file.path().c_str(), &st), SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(unlinkat(AT_FDCWD, file.path().c_str(), 0),
+ SyscallFailsWithErrno(EACCES));
+
+ // Non-existent files also return EACCES.
+ const std::string nonexist = JoinPath(dir.path(), "doesnotexist");
+ ASSERT_THAT(stat(nonexist.c_str(), &st), SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(unlinkat(AT_FDCWD, nonexist.c_str(), 0),
+ SyscallFailsWithErrno(EACCES));
+}
+
+TEST(UnlinkTest, AtBad) {
+ int dirfd;
+ EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0),
+ SyscallSucceeds());
+
+ // Try removing a directory as a file.
+ std::string path = JoinPath(GetAbsoluteTestTmpdir(), "NewDir");
+ EXPECT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd, "NewDir", 0), SyscallFailsWithErrno(EISDIR));
+ EXPECT_THAT(unlinkat(dirfd, "NewDir", AT_REMOVEDIR), SyscallSucceeds());
+
+ // Try removing a file as a directory.
+ int fd;
+ EXPECT_THAT(fd = openat(dirfd, "UnlinkAtFile", O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", AT_REMOVEDIR),
+ SyscallFailsWithErrno(ENOTDIR));
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile/", 0),
+ SyscallFailsWithErrno(ENOTDIR));
+ ASSERT_THAT(close(fd), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds());
+
+ // Cleanup.
+ ASSERT_THAT(close(dirfd), SyscallSucceeds());
+}
+
+TEST(UnlinkTest, AbsTmpFile) {
+ int fd;
+ std::string path = JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile");
+ EXPECT_THAT(fd = open(path.c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ EXPECT_THAT(unlink(path.c_str()), SyscallSucceeds());
+}
+
+TEST(UnlinkTest, TooLongName) {
+ EXPECT_THAT(unlink(std::vector<char>(16384, '0').data()),
+ SyscallFailsWithErrno(ENAMETOOLONG));
+}
+
+TEST(UnlinkTest, BadNamePtr) {
+ EXPECT_THAT(unlink(reinterpret_cast<char*>(1)),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST(UnlinkTest, AtFile) {
+ int dirfd;
+ EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0666),
+ SyscallSucceeds());
+ int fd;
+ EXPECT_THAT(fd = openat(dirfd, "UnlinkAtFile", O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds());
+}
+
+TEST(UnlinkTest, OpenFile_NoRandomSave) {
+ // We can't save unlinked files unless they are on tmpfs.
+ const DisableSave ds;
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ int fd;
+ EXPECT_THAT(fd = open(file.path().c_str(), O_RDWR, 0666), SyscallSucceeds());
+ EXPECT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST(UnlinkTest, CannotRemoveDots) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string self = JoinPath(file.path(), ".");
+ ASSERT_THAT(unlink(self.c_str()), SyscallFailsWithErrno(ENOTDIR));
+ const std::string parent = JoinPath(file.path(), "..");
+ ASSERT_THAT(unlink(parent.c_str()), SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(UnlinkTest, CannotRemoveRoot) {
+ ASSERT_THAT(unlinkat(-1, "/", AT_REMOVEDIR), SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(UnlinkTest, CannotRemoveRootWithAtDir) {
+ const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(GetAbsoluteTestTmpdir(), O_DIRECTORY, 0666));
+ ASSERT_THAT(unlinkat(dirfd.get(), "/", AT_REMOVEDIR),
+ SyscallFailsWithErrno(EBUSY));
+}
+
+TEST(RmdirTest, CannotRemoveDots) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string self = JoinPath(dir.path(), ".");
+ ASSERT_THAT(rmdir(self.c_str()), SyscallFailsWithErrno(EINVAL));
+ const std::string parent = JoinPath(dir.path(), "..");
+ ASSERT_THAT(rmdir(parent.c_str()), SyscallFailsWithErrno(ENOTEMPTY));
+}
+
+TEST(RmdirTest, CanRemoveWithTrailingSlashes) {
+ auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string slash = absl::StrCat(dir1.path(), "/");
+ ASSERT_THAT(rmdir(slash.c_str()), SyscallSucceeds());
+ auto dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string slashslash = absl::StrCat(dir2.path(), "//");
+ ASSERT_THAT(rmdir(slashslash.c_str()), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/unshare.cc b/test/syscalls/linux/unshare.cc
new file mode 100644
index 000000000..e32619efe
--- /dev/null
+++ b/test/syscalls/linux/unshare.cc
@@ -0,0 +1,50 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sched.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(UnshareTest, AllowsZeroFlags) {
+ ASSERT_THAT(unshare(0), SyscallSucceeds());
+}
+
+TEST(UnshareTest, ThreadFlagFailsIfMultithreaded) {
+ absl::Mutex mu;
+ bool finished = false;
+ ScopedThread t([&] {
+ mu.Lock();
+ mu.Await(absl::Condition(&finished));
+ mu.Unlock();
+ });
+ ASSERT_THAT(unshare(CLONE_THREAD), SyscallFailsWithErrno(EINVAL));
+ mu.Lock();
+ finished = true;
+ mu.Unlock();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc
new file mode 100644
index 000000000..e647d2896
--- /dev/null
+++ b/test/syscalls/linux/utimes.cc
@@ -0,0 +1,319 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <time.h>
+#include <unistd.h>
+#include <utime.h>
+
+#include <iostream>
+#include <string>
+
+#include "absl/time/time.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// TimeBoxed runs fn, setting before and after to (coarse realtime) times
+// guaranteed* to come before and after fn started and completed, respectively
+// (*modulo backward clock jumps, which are handled by retrying).
+//
+// fn may be called more than once if the clock is adjusted.
+void TimeBoxed(absl::Time* before, absl::Time* after,
+ std::function<void()> const& fn) {
+ do {
+ // N.B. utimes and friends use CLOCK_REALTIME_COARSE for setting time (i.e.,
+ // current_kernel_time()). See fs/attr.c:notify_change.
+ //
+ // notify_change truncates the time to a multiple of s_time_gran, but most
+ // filesystems set it to 1, so we don't do any truncation.
+ struct timespec ts;
+ EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds());
+ // FIXME(b/132819225): gVisor filesystem timestamps inconsistently use the
+ // internal or host clock, which may diverge slightly. Allow some slack on
+ // times to account for the difference.
+ *before = absl::TimeFromTimespec(ts) - absl::Seconds(1);
+
+ fn();
+
+ EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds());
+ *after = absl::TimeFromTimespec(ts) + absl::Seconds(1);
+
+ if (*after < *before) {
+ // Clock jumped backwards; retry.
+ //
+ // Technically this misses jumps small enough to keep after > before,
+ // which could lead to test failures, but that is very unlikely to happen.
+ continue;
+ }
+ } while (*after < *before);
+}
+
+void TestUtimesOnPath(std::string const& path) {
+ struct stat statbuf;
+
+ struct timeval times[2] = {{10, 0}, {20, 0}};
+ EXPECT_THAT(utimes(path.c_str(), times), SyscallSucceeds());
+ EXPECT_THAT(stat(path.c_str(), &statbuf), SyscallSucceeds());
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
+
+ absl::Time before;
+ absl::Time after;
+ TimeBoxed(&before, &after, [&] {
+ EXPECT_THAT(utimes(path.c_str(), nullptr), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(stat(path.c_str(), &statbuf), SyscallSucceeds());
+
+ absl::Time atime = absl::TimeFromTimespec(statbuf.st_atim);
+ EXPECT_GE(atime, before);
+ EXPECT_LE(atime, after);
+
+ absl::Time mtime = absl::TimeFromTimespec(statbuf.st_mtim);
+ EXPECT_GE(mtime, before);
+ EXPECT_LE(mtime, after);
+}
+
+TEST(UtimesTest, OnFile) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ TestUtimesOnPath(f.path());
+}
+
+TEST(UtimesTest, OnDir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TestUtimesOnPath(dir.path());
+}
+
+TEST(UtimesTest, MissingPath) {
+ auto path = NewTempAbsPath();
+ struct timeval times[2] = {{10, 0}, {20, 0}};
+ EXPECT_THAT(utimes(path.c_str(), times), SyscallFailsWithErrno(ENOENT));
+}
+
+void TestFutimesat(int dirFd, std::string const& path) {
+ struct stat statbuf;
+
+ struct timeval times[2] = {{10, 0}, {20, 0}};
+ EXPECT_THAT(futimesat(dirFd, path.c_str(), times), SyscallSucceeds());
+ EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds());
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
+
+ absl::Time before;
+ absl::Time after;
+ TimeBoxed(&before, &after, [&] {
+ EXPECT_THAT(futimesat(dirFd, path.c_str(), nullptr), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds());
+
+ absl::Time atime = absl::TimeFromTimespec(statbuf.st_atim);
+ EXPECT_GE(atime, before);
+ EXPECT_LE(atime, after);
+
+ absl::Time mtime = absl::TimeFromTimespec(statbuf.st_mtim);
+ EXPECT_GE(mtime, before);
+ EXPECT_LE(mtime, after);
+}
+
+TEST(FutimesatTest, OnAbsPath) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ TestFutimesat(0, f.path());
+}
+
+TEST(FutimesatTest, OnRelPath) {
+ auto d = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(d.path()));
+ auto basename = std::string(Basename(f.path()));
+ const FileDescriptor dirFd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(d.path(), O_RDONLY | O_DIRECTORY));
+ TestFutimesat(dirFd.get(), basename);
+}
+
+TEST(FutimesatTest, InvalidNsec) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ struct timeval times[4][2] = {{
+ {0, 1}, // Valid
+ {1, static_cast<int64_t>(1e7)} // Invalid
+ },
+ {
+ {1, static_cast<int64_t>(1e7)}, // Invalid
+ {0, 1} // Valid
+ },
+ {
+ {0, 1}, // Valid
+ {1, -1} // Invalid
+ },
+ {
+ {1, -1}, // Invalid
+ {0, 1} // Valid
+ }};
+
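+ // Per utimes(2)/futimesat(2), tv_usec must lie in [0, 999999]; 1e7 and -1
+ // are out of range, so each pair above contains one invalid entry and each
+ // call must fail with EINVAL.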
+ for (unsigned int i = 0; i < sizeof(times) / sizeof(times[0]); i++) {
+ std::cout << "test:" << i << "\n";
+ EXPECT_THAT(futimesat(0, f.path().c_str(), times[i]),
+ SyscallFailsWithErrno(EINVAL));
+ }
+}
+
+void TestUtimensat(int dirFd, std::string const& path) {
+ struct stat statbuf;
+ const struct timespec times[2] = {{10, 0}, {20, 0}};
+ EXPECT_THAT(utimensat(dirFd, path.c_str(), times, 0), SyscallSucceeds());
+ EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds());
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
+
+ // Test setting with UTIME_NOW and UTIME_OMIT.
+ struct stat statbuf2;
+ const struct timespec times2[2] = {
+ {0, UTIME_NOW}, // Should set atime to now.
+ {0, UTIME_OMIT} // Should not change mtime.
+ };
+
+ absl::Time before;
+ absl::Time after;
+ TimeBoxed(&before, &after, [&] {
+ EXPECT_THAT(utimensat(dirFd, path.c_str(), times2, 0), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf2, 0), SyscallSucceeds());
+
+ absl::Time atime2 = absl::TimeFromTimespec(statbuf2.st_atim);
+ EXPECT_GE(atime2, before);
+ EXPECT_LE(atime2, after);
+
+ absl::Time mtime = absl::TimeFromTimespec(statbuf.st_mtim);
+ absl::Time mtime2 = absl::TimeFromTimespec(statbuf2.st_mtim);
+ // mtime should not be changed.
+ EXPECT_EQ(mtime, mtime2);
+
+ // Test setting with times = NULL. Should set both atime and mtime to the
+ // current system time.
+ struct stat statbuf3;
+ TimeBoxed(&before, &after, [&] {
+ EXPECT_THAT(utimensat(dirFd, path.c_str(), nullptr, 0), SyscallSucceeds());
+ });
+
+ EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf3, 0), SyscallSucceeds());
+
+ absl::Time atime3 = absl::TimeFromTimespec(statbuf3.st_atim);
+ EXPECT_GE(atime3, before);
+ EXPECT_LE(atime3, after);
+
+ absl::Time mtime3 = absl::TimeFromTimespec(statbuf3.st_mtim);
+ EXPECT_GE(mtime3, before);
+ EXPECT_LE(mtime3, after);
+
+ EXPECT_EQ(atime3, mtime3);
+}
+
+TEST(UtimensatTest, OnAbsPath) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
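+  // As with futimesat above, dirfd is ignored for an absolute path.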
+ TestUtimensat(0, f.path());
+}
+
+TEST(UtimensatTest, OnRelPath) {
+ auto d = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(d.path()));
+ auto basename = std::string(Basename(f.path()));
+ const FileDescriptor dirFd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(d.path(), O_RDONLY | O_DIRECTORY));
+ TestUtimensat(dirFd.get(), basename);
+}
+
+TEST(UtimensatTest, OmitNoop) {
+  // Setting both timespecs to UTIME_OMIT on a nonexistent path should succeed.
+ auto path = NewTempAbsPath();
+ const struct timespec times[2] = {{0, UTIME_OMIT}, {0, UTIME_OMIT}};
+ EXPECT_THAT(utimensat(0, path.c_str(), times, 0), SyscallSucceeds());
+}
+
+// Verify that we can actually set atime and mtime to 0.
+TEST(UtimeTest, ZeroAtimeAndMtime) {
+ const auto tmp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const auto tmp_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(tmp_dir.path()));
+
+ // Stat the file before and after updating atime and mtime.
+ struct stat stat_before = {};
+ EXPECT_THAT(stat(tmp_file.path().c_str(), &stat_before), SyscallSucceeds());
+
+ ASSERT_NE(stat_before.st_atime, 0);
+ ASSERT_NE(stat_before.st_mtime, 0);
+
+ const struct utimbuf times = {}; // Zero for both atime and mtime.
+ EXPECT_THAT(utime(tmp_file.path().c_str(), &times), SyscallSucceeds());
+
+ struct stat stat_after = {};
+ EXPECT_THAT(stat(tmp_file.path().c_str(), &stat_after), SyscallSucceeds());
+
+ // We should see the atime and mtime changed when we set them to 0.
+ ASSERT_EQ(stat_after.st_atime, 0);
+ ASSERT_EQ(stat_after.st_mtime, 0);
+}
+
+TEST(UtimensatTest, InvalidNsec) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
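+  // A timespec is only valid if tv_nsec lies in [0, 1e9) or is one of the
+  // special values UTIME_NOW and UTIME_OMIT; anything else yields EINVAL.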
+ struct timespec times[2][2] = {
+ {
+ {0, UTIME_OMIT}, // Valid
+ {2, static_cast<int64_t>(1e10)} // Invalid
+ },
+ {
+ {2, static_cast<int64_t>(1e10)}, // Invalid
+ {0, UTIME_OMIT} // Valid
+ }};
+
+ for (unsigned int i = 0; i < sizeof(times) / sizeof(times[0]); i++) {
+ std::cout << "test:" << i << "\n";
+ EXPECT_THAT(utimensat(0, f.path().c_str(), times[i], 0),
+ SyscallFailsWithErrno(EINVAL));
+ }
+}
+
+TEST(UtimensatTest, NullPath) {
+ // From man utimensat(2):
+ // "the Linux utimensat() system call implements a nonstandard feature: if
+ // pathname is NULL, then the call modifies the timestamps of the file
+ // referred to by the file descriptor dirfd (which may refer to any type of
+ // file).
+ // Note, however, that the glibc wrapper for utimensat() disallows
+ // passing NULL as the value for file: the wrapper function returns the error
+ // EINVAL in this case."
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
+ struct stat statbuf;
+ const struct timespec times[2] = {{10, 0}, {20, 0}};
+ // Call syscall directly.
+ EXPECT_THAT(syscall(SYS_utimensat, fd.get(), NULL, times, 0),
+ SyscallSucceeds());
+ EXPECT_THAT(fstatat(0, f.path().c_str(), &statbuf, 0), SyscallSucceeds());
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/vdso.cc b/test/syscalls/linux/vdso.cc
new file mode 100644
index 000000000..19c80add8
--- /dev/null
+++ b/test/syscalls/linux/vdso.cc
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string.h>
+#include <sys/mman.h>
+
+#include <algorithm>
+
+#include "gtest/gtest.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Ensure that the vvar page cannot be made writable.
+TEST(VvarTest, WriteVvar) {
+ auto contents = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ auto maps = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMaps(contents));
+ auto it = std::find_if(maps.begin(), maps.end(), [](const ProcMapsEntry& e) {
+ return e.filename == "[vvar]";
+ });
+
+ SKIP_IF(it == maps.end());
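+  // The vvar mapping is created without VM_MAYWRITE, so asking for PROT_WRITE
+  // should fail with EACCES.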
+ EXPECT_THAT(mprotect(reinterpret_cast<void*>(it->start), kPageSize,
+ PROT_READ | PROT_WRITE),
+ SyscallFailsWithErrno(EACCES));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc
new file mode 100644
index 000000000..ce1899f45
--- /dev/null
+++ b/test/syscalls/linux/vdso_clock_gettime.cc
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <sys/time.h>
+#include <syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <map>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) {
+ switch (info.param) {
+ case CLOCK_MONOTONIC:
+ return "CLOCK_MONOTONIC";
+ case CLOCK_REALTIME:
+ return "CLOCK_REALTIME";
+ case CLOCK_BOOTTIME:
+ return "CLOCK_BOOTTIME";
+ default:
+ return absl::StrCat(info.param);
+ }
+}
+
+class CorrectVDSOClockTest : public ::testing::TestWithParam<clockid_t> {};
+
+TEST_P(CorrectVDSOClockTest, IsCorrect) {
+ struct timespec tvdso, tsys;
+ absl::Time vdso_time, sys_time;
+ uint64_t total_calls = 0;
+
+  // We expect 82.5% of clock_gettime(2) calls to be skewed from the system
+  // time by less than 100us. Unfortunately, the skew reflects not only the
+  // VDSO clock but also arbitrary scheduling delays and the like, which is
+  // why the test is regularly disabled.
+ std::map<absl::Duration, std::tuple<double, uint64_t, uint64_t>> confidence =
+ {
+ {absl::Microseconds(100), std::make_tuple(0.825, 0, 0)},
+ {absl::Microseconds(250), std::make_tuple(0.94, 0, 0)},
+ {absl::Milliseconds(1), std::make_tuple(0.999, 0, 0)},
+ };
+
+ absl::Time start = absl::Now();
+ while (absl::Now() < start + absl::Seconds(30)) {
+ EXPECT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds());
+ EXPECT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys),
+ SyscallSucceeds());
+
+ vdso_time = absl::TimeFromTimespec(tvdso);
+
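+    // Note that sys_time still holds the syscall reading from the *previous*
+    // iteration, taken before the vdso reading above. The two syscall
+    // readings thus bracket the vdso reading, so the two counters bound the
+    // skew in opposite directions. (On the first iteration sys_time is the
+    // default-constructed Time, which trivially satisfies the check.)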
+ for (auto const& conf : confidence) {
+ std::get<1>(confidence[conf.first]) +=
+ (sys_time - vdso_time) < conf.first;
+ }
+
+ sys_time = absl::TimeFromTimespec(tsys);
+
+ for (auto const& conf : confidence) {
+ std::get<2>(confidence[conf.first]) +=
+ (vdso_time - sys_time) < conf.first;
+ }
+
+ ++total_calls;
+ }
+
+ for (auto const& conf : confidence) {
+ EXPECT_GE(std::get<1>(conf.second) / static_cast<double>(total_calls),
+ std::get<0>(conf.second));
+ EXPECT_GE(std::get<2>(conf.second) / static_cast<double>(total_calls),
+ std::get<0>(conf.second));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(ClockGettime, CorrectVDSOClockTest,
+ ::testing::Values(CLOCK_MONOTONIC, CLOCK_REALTIME,
+ CLOCK_BOOTTIME),
+ PrintClockId);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
new file mode 100644
index 000000000..19d05998e
--- /dev/null
+++ b/test/syscalls/linux/vfork.cc
@@ -0,0 +1,195 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/time/time.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/test_util.h"
+#include "test/util/time_util.h"
+
+ABSL_FLAG(bool, vfork_test_child, false,
+ "If true, run the VforkTest child workload.");
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// We don't test with raw CLONE_VFORK to avoid interacting with glibc's use of
+// TLS.
+//
+// Even with vfork(2), we must be careful to do little more in the child than
+// call execve(2). We use the simplest sleep function possible, though this is
+// still precarious, as we're officially only allowed to call execve(2) and
+// _exit(2).
+constexpr absl::Duration kChildDelay = absl::Seconds(10);
+
+// Exit code for successful child subprocesses. We don't want to use 0 since
+// it's too common, and an execve(2) failure causes the child to exit with the
+// errno, so kChildExitCode is chosen to be an unlikely errno:
+constexpr int kChildExitCode = 118; // ENOTNAM: Not a XENIX named type file
+
+int64_t MonotonicNow() {
+ struct timespec now;
+ TEST_PCHECK(clock_gettime(CLOCK_MONOTONIC, &now) == 0);
+ return now.tv_sec * 1000000000ll + now.tv_nsec;
+}
+
+TEST(VforkTest, ParentStopsUntilChildExits) {
+ const auto test = [] {
+ // N.B. Run the test in a single-threaded subprocess because
+ // vfork is not safe in a multi-threaded process.
+
+ const int64_t start = MonotonicNow();
+
+ pid_t pid = vfork();
+ if (pid == 0) {
+ SleepSafe(kChildDelay);
+ _exit(kChildExitCode);
+ }
+ TEST_PCHECK_MSG(pid > 0, "vfork failed");
+ MaybeSave();
+
+ const int64_t end = MonotonicNow();
+
+ absl::Duration dur = absl::Nanoseconds(end - start);
+
+ TEST_CHECK(dur >= kChildDelay);
+
+ int status = 0;
+    TEST_PCHECK(RetryEINTR(waitpid)(pid, &status, 0) == pid);
+ TEST_CHECK(WIFEXITED(status));
+ TEST_CHECK(WEXITSTATUS(status) == kChildExitCode);
+ };
+
+ EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0));
+}
+
+TEST(VforkTest, ParentStopsUntilChildExecves_NoRandomSave) {
+ ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"};
+ char* const* const child_argv = owned_child_argv.get();
+
+ const auto test = [&] {
+ const int64_t start = MonotonicNow();
+
+ pid_t pid = vfork();
+ if (pid == 0) {
+ SleepSafe(kChildDelay);
+ execve(child_argv[0], child_argv, /* envp = */ nullptr);
+ _exit(errno);
+ }
+    // Don't attempt save/restore until after recording end, since
+    // the test expects an upper bound on the time spent stopped.
+ int saved_errno = errno;
+ const int64_t end = MonotonicNow();
+ errno = saved_errno;
+ TEST_PCHECK_MSG(pid > 0, "vfork failed");
+ MaybeSave();
+
+ absl::Duration dur = absl::Nanoseconds(end - start);
+
+ // The parent should resume execution after execve, but before
+ // the post-execve test child exits.
+ TEST_CHECK(dur >= kChildDelay);
+ TEST_CHECK(dur <= 2 * kChildDelay);
+
+ int status = 0;
+    TEST_PCHECK(RetryEINTR(waitpid)(pid, &status, 0) == pid);
+ TEST_CHECK(WIFEXITED(status));
+ TEST_CHECK(WEXITSTATUS(status) == kChildExitCode);
+ };
+
+ EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0));
+}
+
+// A vfork child does not unstop the parent a second time when it exits after
+// exec.
+TEST(VforkTest, ExecedChildExitDoesntUnstopParent_NoRandomSave) {
+ ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"};
+ char* const* const child_argv = owned_child_argv.get();
+
+ const auto test = [&] {
+ pid_t pid1 = vfork();
+ if (pid1 == 0) {
+ execve(child_argv[0], child_argv, /* envp = */ nullptr);
+ _exit(errno);
+ }
+ TEST_PCHECK_MSG(pid1 > 0, "vfork failed");
+ MaybeSave();
+
+ // pid1 exec'd and is now sleeping.
+ SleepSafe(kChildDelay / 2);
+
+ const int64_t start = MonotonicNow();
+
+ pid_t pid2 = vfork();
+ if (pid2 == 0) {
+ SleepSafe(kChildDelay);
+ _exit(kChildExitCode);
+ }
+ TEST_PCHECK_MSG(pid2 > 0, "vfork failed");
+ MaybeSave();
+
+ const int64_t end = MonotonicNow();
+
+ absl::Duration dur = absl::Nanoseconds(end - start);
+
+ // The parent should resume execution only after pid2 exits, not
+ // when pid1 exits.
+ TEST_CHECK(dur >= kChildDelay);
+
+ int status = 0;
+    TEST_PCHECK(RetryEINTR(waitpid)(pid1, &status, 0) == pid1);
+ TEST_CHECK(WIFEXITED(status));
+ TEST_CHECK(WEXITSTATUS(status) == kChildExitCode);
+
+    TEST_PCHECK(RetryEINTR(waitpid)(pid2, &status, 0) == pid2);
+ TEST_CHECK(WIFEXITED(status));
+ TEST_CHECK(WEXITSTATUS(status) == kChildExitCode);
+ };
+
+ EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0));
+}
+
+int RunChild() {
+ SleepSafe(kChildDelay);
+ return kChildExitCode;
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+
+ if (absl::GetFlag(FLAGS_vfork_test_child)) {
+ return gvisor::testing::RunChild();
+ }
+
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc
new file mode 100644
index 000000000..ae4377108
--- /dev/null
+++ b/test/syscalls/linux/vsyscall.cc
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "test/util/proc_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#if defined(__x86_64__)
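+// The legacy x86-64 vsyscall page sits at a fixed address: 0xffffffffff600000
+// holds gettimeofday, +0x400 holds time, and +0x800 holds getcpu (see Linux's
+// arch/x86 vsyscall emulation). We call time(2) through it directly below.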
+time_t vsyscall_time(time_t* t) {
+ constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
+ return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t);
+}
+
+TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) {
+ SKIP_IF(!IsRunningOnGvisor());
+ // Vsyscall is always advertised by gvisor.
+ EXPECT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(IsVsyscallEnabled()));
+  // Vsyscall should always work on gvisor.
+ time_t t;
+ EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds());
+}
+#endif
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/wait.cc b/test/syscalls/linux/wait.cc
new file mode 100644
index 000000000..944149d5e
--- /dev/null
+++ b/test/syscalls/linux/wait.cc
@@ -0,0 +1,913 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <sys/mman.h>
+#include <sys/ptrace.h>
+#include <sys/resource.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <functional>
+#include <tuple>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+#include "test/util/time_util.h"
+
+using ::testing::UnorderedElementsAre;
+
+// These unit tests focus on the wait4(2) system call, but include basic
+// checks for the i386 waitpid(2) syscall, which is a subset of wait4(2).
+//
+// NOTE(b/22640830,b/27680907,b/29049891): Some functionality is not tested as
+// it is not currently supported by gVisor:
+// * Process groups.
+// * Core dump status (WCOREDUMP).
+//
+// Tests for waiting on stopped/continued children are in sigstop.cc.
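+//
+// For reference, the checks below decode the wait status word, which Linux
+// packs roughly as follows: a normal exit stores the exit code in bits 8-15
+// (WEXITSTATUS), while a signal death stores the signal number in bits 0-6
+// (WTERMSIG), so WIFEXITED(status) is simply (status & 0x7f) == 0.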
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// The CloneChild function seems to need more than one page of stack space.
+static const size_t kStackSize = 2 * kPageSize;
+
+// The child thread created in CloneAndExit runs this function.
+// This child does not have the TLS setup, so it must not use glibc functions.
+int CloneChild(void* priv) {
+ int64_t sleep = reinterpret_cast<int64_t>(priv);
+ SleepSafe(absl::Seconds(sleep));
+
+ // glibc's _exit(2) function wrapper will helpfully call exit_group(2),
+ // exiting the entire process.
+ syscall(__NR_exit, 0);
+ return 1;
+}
+
+// ForkAndExit forks a child process which exits with exit_code, after
+// sleeping for the specified duration (seconds).
+pid_t ForkAndExit(int exit_code, int64_t sleep) {
+ pid_t child = fork();
+ if (child == 0) {
+ SleepSafe(absl::Seconds(sleep));
+ _exit(exit_code);
+ }
+ return child;
+}
+
+int64_t clock_gettime_nsecs(clockid_t id) {
+ struct timespec ts;
+ TEST_PCHECK(clock_gettime(id, &ts) == 0);
+ return (ts.tv_sec * 1000000000 + ts.tv_nsec);
+}
+
+void spin(int64_t sec) {
+ int64_t ns = sec * 1000000000;
+ int64_t start = clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID);
+ int64_t end = start + ns;
+
+ do {
+ constexpr int kLoopCount = 1000000; // large and arbitrary
+ // volatile to prevent the compiler from skipping this loop.
+ for (volatile int i = 0; i < kLoopCount; i++) {
+ }
+ } while (clock_gettime_nsecs(CLOCK_THREAD_CPUTIME_ID) < end);
+}
+
+// ForkSpinAndExit forks a child process which exits with exit_code, after
+// spinning for the specified duration (seconds).
+pid_t ForkSpinAndExit(int exit_code, int64_t spintime) {
+ pid_t child = fork();
+ if (child == 0) {
+ spin(spintime);
+ _exit(exit_code);
+ }
+ return child;
+}
+
+absl::Duration RusageCpuTime(const struct rusage& ru) {
+ return absl::DurationFromTimeval(ru.ru_utime) +
+ absl::DurationFromTimeval(ru.ru_stime);
+}
+
+// Returns the address of the top of the stack.
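+// Stacks grow down on x86, so this is the address to pass to clone(2).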
+// Free with FreeStack.
+uintptr_t AllocStack() {
+ void* addr = mmap(nullptr, kStackSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+
+ if (addr == MAP_FAILED) {
+ return reinterpret_cast<uintptr_t>(MAP_FAILED);
+ }
+
+ return reinterpret_cast<uintptr_t>(addr) + kStackSize;
+}
+
+// Frees a stack allocated with AllocStack.
+int FreeStack(uintptr_t addr) {
+  addr -= kStackSize;
+  return munmap(reinterpret_cast<void*>(addr), kStackSize);
+}
+
+// CloneAndExit clones a child thread, which exits with 0 after sleeping for
+// the specified duration (in seconds). extra_flags are ORed into the
+// standard clone(2) flags.
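+// N.B. clone(2) requires CLONE_VM whenever CLONE_SIGHAND is set, which is why
+// both flags are always included.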
+int CloneAndExit(int64_t sleep, uintptr_t stack, int extra_flags) {
+ return clone(CloneChild, reinterpret_cast<void*>(stack),
+ CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_VM | extra_flags,
+ reinterpret_cast<void*>(sleep));
+}
+
+// Simple wrappers around wait4(2) and waitid(2) that ignore interrupts.
+constexpr auto Wait4 = RetryEINTR(wait4);
+constexpr auto Waitid = RetryEINTR(waitid);
+
+// Fixture for tests parameterized by a function that waits for any child to
+// exit with the given options, checks that it exited with the given code, and
+// then returns its PID.
+//
+// N.B. These tests run in a multi-threaded environment. We assume that
+// background threads do not create child processes and are not themselves
+// created with clone(... | SIGCHLD). Either may cause these tests to
+// erroneously wait on child processes/threads.
+class WaitAnyChildTest : public ::testing::TestWithParam<
+ std::function<PosixErrorOr<pid_t>(int, int)>> {
+ protected:
+ PosixErrorOr<pid_t> WaitAny(int code) { return WaitAnyWithOptions(code, 0); }
+
+ PosixErrorOr<pid_t> WaitAnyWithOptions(int code, int options) {
+ return GetParam()(code, options);
+ }
+};
+
+// Wait for any child to exit.
+TEST_P(WaitAnyChildTest, Fork) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+
+ EXPECT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child));
+}
+
+// Call wait4 for any process after the child has already exited.
+TEST_P(WaitAnyChildTest, AfterExit) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+
+ absl::SleepFor(absl::Seconds(5));
+
+ EXPECT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child));
+}
+
+// Wait for multiple children to exit, waiting for either at a time.
+TEST_P(WaitAnyChildTest, MultipleFork) {
+ pid_t child1, child2;
+ ASSERT_THAT(child1 = ForkAndExit(0, 0), SyscallSucceeds());
+ ASSERT_THAT(child2 = ForkAndExit(0, 0), SyscallSucceeds());
+
+ std::vector<pid_t> pids;
+ pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0)));
+ pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0)));
+ EXPECT_THAT(pids, UnorderedElementsAre(child1, child2));
+}
+
+// Wait for any child to exit.
+// A non-CLONE_THREAD child which sends SIGCHLD upon exit behaves much like
+// a forked process.
+TEST_P(WaitAnyChildTest, CloneSIGCHLD) {
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int child;
+ ASSERT_THAT(child = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds());
+
+ EXPECT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child));
+}
+
+// Wait for a child thread and process.
+TEST_P(WaitAnyChildTest, ForkAndClone) {
+ pid_t process;
+ ASSERT_THAT(process = ForkAndExit(0, 0), SyscallSucceeds());
+
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int thread;
+ // Send SIGCHLD for normal wait semantics.
+ ASSERT_THAT(thread = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds());
+
+ std::vector<pid_t> pids;
+ pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0)));
+ pids.push_back(ASSERT_NO_ERRNO_AND_VALUE(WaitAny(0)));
+ EXPECT_THAT(pids, UnorderedElementsAre(process, thread));
+}
+
+// Return immediately if no child has exited.
+TEST_P(WaitAnyChildTest, WaitWNOHANG) {
+ EXPECT_THAT(WaitAnyWithOptions(0, WNOHANG),
+ PosixErrorIs(ECHILD, ::testing::_));
+}
+
+// Passing bad options fails with EINVAL.
+TEST_P(WaitAnyChildTest, BadOption) {
+ EXPECT_THAT(WaitAnyWithOptions(0, 123456),
+ PosixErrorIs(EINVAL, ::testing::_));
+}
+
+TEST_P(WaitAnyChildTest, WaitedChildRusage) {
+ struct rusage before;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &before), SyscallSucceeds());
+
+ pid_t child;
+ constexpr absl::Duration kSpin = absl::Seconds(3);
+ ASSERT_THAT(child = ForkSpinAndExit(0, absl::ToInt64Seconds(kSpin)),
+ SyscallSucceeds());
+ ASSERT_THAT(WaitAny(0), IsPosixErrorOkAndHolds(child));
+
+ struct rusage after;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &after), SyscallSucceeds());
+
+ EXPECT_GE(RusageCpuTime(after) - RusageCpuTime(before), kSpin);
+}
+
+TEST_P(WaitAnyChildTest, IgnoredChildRusage) {
+ // "POSIX.1-2001 specifies that if the disposition of SIGCHLD is
+ // set to SIG_IGN or the SA_NOCLDWAIT flag is set for SIGCHLD (see
+ // sigaction(2)), then children that terminate do not become zombies and a
+ // call to wait() or waitpid() will block until all children have terminated,
+ // and then fail with errno set to ECHILD." - waitpid(2)
+ //
+ // "RUSAGE_CHILDREN: Return resource usage statistics for all children of the
+ // calling process that have terminated *and been waited for*." -
+ // getrusage(2), emphasis added
+
+ struct sigaction sa;
+ sa.sa_handler = SIG_IGN;
+ const auto cleanup_sigact =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa));
+
+ struct rusage before;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &before), SyscallSucceeds());
+
+ const absl::Duration start =
+ absl::Nanoseconds(clock_gettime_nsecs(CLOCK_MONOTONIC));
+
+ constexpr absl::Duration kSpin = absl::Seconds(3);
+
+  // ForkSpinAndExit uses CLOCK_THREAD_CPUTIME_ID, which is lower resolution than,
+ // and may diverge from, CLOCK_MONOTONIC, so we allow a small grace period but
+ // still check that we blocked for a while.
+ constexpr absl::Duration kSpinGrace = absl::Milliseconds(100);
+
+ pid_t child;
+ ASSERT_THAT(child = ForkSpinAndExit(0, absl::ToInt64Seconds(kSpin)),
+ SyscallSucceeds());
+ ASSERT_THAT(WaitAny(0), PosixErrorIs(ECHILD, ::testing::_));
+ const absl::Duration end =
+ absl::Nanoseconds(clock_gettime_nsecs(CLOCK_MONOTONIC));
+ EXPECT_GE(end - start, kSpin - kSpinGrace);
+
+ struct rusage after;
+ ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &after), SyscallSucceeds());
+ EXPECT_EQ(before.ru_utime.tv_sec, after.ru_utime.tv_sec);
+ EXPECT_EQ(before.ru_utime.tv_usec, after.ru_utime.tv_usec);
+ EXPECT_EQ(before.ru_stime.tv_sec, after.ru_stime.tv_sec);
+ EXPECT_EQ(before.ru_stime.tv_usec, after.ru_stime.tv_usec);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Waiters, WaitAnyChildTest,
+ ::testing::Values(
+ [](int code, int options) -> PosixErrorOr<pid_t> {
+ int status;
+ auto const pid = Wait4(-1, &status, options, nullptr);
+ MaybeSave();
+ if (pid < 0) {
+ return PosixError(errno, "wait4");
+ }
+ if (!WIFEXITED(status) || WEXITSTATUS(status) != code) {
+ return PosixError(
+ EINVAL, absl::StrCat("unexpected wait status: got ", status,
+ ", wanted ", code));
+ }
+ return static_cast<pid_t>(pid);
+ },
+ [](int code, int options) -> PosixErrorOr<pid_t> {
+ siginfo_t si;
+ auto const rv = Waitid(P_ALL, 0, &si, WEXITED | options);
+ MaybeSave();
+ if (rv < 0) {
+ return PosixError(errno, "waitid");
+ }
+ if (si.si_signo != SIGCHLD) {
+ return PosixError(
+ EINVAL, absl::StrCat("unexpected signo: got ", si.si_signo,
+ ", wanted ", SIGCHLD));
+ }
+ if (si.si_status != code) {
+ return PosixError(
+ EINVAL, absl::StrCat("unexpected status: got ", si.si_status,
+ ", wanted ", code));
+ }
+ if (si.si_code != CLD_EXITED) {
+ return PosixError(EINVAL,
+ absl::StrCat("unexpected code: got ", si.si_code,
+ ", wanted ", CLD_EXITED));
+ }
+ auto const uid = getuid();
+ if (si.si_uid != uid) {
+ return PosixError(EINVAL,
+ absl::StrCat("unexpected uid: got ", si.si_uid,
+ ", wanted ", uid));
+ }
+ return static_cast<pid_t>(si.si_pid);
+ }));
+
+// Fixture for tests parameterized by a (sysno, function) tuple. The function
+// takes the PID of a specific child to wait for, waits for it to exit, and
+// checks that it exits with the given code.
+class WaitSpecificChildTest
+ : public ::testing::TestWithParam<
+ std::tuple<int, std::function<PosixError(pid_t, int, int)>>> {
+ protected:
+ int Sysno() { return std::get<0>(GetParam()); }
+
+ PosixError WaitForWithOptions(pid_t pid, int options, int code) {
+ return std::get<1>(GetParam())(pid, options, code);
+ }
+
+ PosixError WaitFor(pid_t pid, int code) {
+ return std::get<1>(GetParam())(pid, 0, code);
+ }
+};
+
+// Wait for specific child to exit.
+TEST_P(WaitSpecificChildTest, Fork) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+// Non-zero exit codes are correctly propagated.
+TEST_P(WaitSpecificChildTest, NormalExit) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child, 42));
+}
+
+// Wait for multiple children to exit.
+TEST_P(WaitSpecificChildTest, MultipleFork) {
+ pid_t child1, child2;
+ ASSERT_THAT(child1 = ForkAndExit(0, 0), SyscallSucceeds());
+ ASSERT_THAT(child2 = ForkAndExit(0, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child1, 0));
+ EXPECT_NO_ERRNO(WaitFor(child2, 0));
+}
+
+// Wait for multiple children to exit, out of the order they were created.
+TEST_P(WaitSpecificChildTest, MultipleForkOutOfOrder) {
+ pid_t child1, child2;
+ ASSERT_THAT(child1 = ForkAndExit(0, 0), SyscallSucceeds());
+ ASSERT_THAT(child2 = ForkAndExit(0, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child2, 0));
+ EXPECT_NO_ERRNO(WaitFor(child1, 0));
+}
+
+// Wait for specific child to exit, entering wait4 before the exit occurs.
+TEST_P(WaitSpecificChildTest, ForkSleep) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 5), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+// Wait should block until the child exits.
+TEST_P(WaitSpecificChildTest, ForkBlock) {
+ pid_t child;
+
+ auto start = absl::Now();
+ ASSERT_THAT(child = ForkAndExit(0, 5), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+
+ EXPECT_GE(absl::Now() - start, absl::Seconds(5));
+}
+
+// Waiting after the child has already exited returns immediately.
+TEST_P(WaitSpecificChildTest, AfterExit) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+
+ absl::SleepFor(absl::Seconds(5));
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+// Wait for child of sibling thread.
+TEST_P(WaitSpecificChildTest, SiblingChildren) {
+ absl::Mutex mu;
+ pid_t child;
+ bool ready = false;
+ bool stop = false;
+
+ ScopedThread t([&] {
+ absl::MutexLock ml(&mu);
+ EXPECT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+ ready = true;
+ mu.Await(absl::Condition(&stop));
+ });
+
+ // N.B. This must be declared after ScopedThread, so it is destructed first,
+ // thus waking the thread.
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&ready));
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+
+ // Keep the sibling alive until after we've waited so the child isn't
+ // reparented.
+ stop = true;
+}
+
+// Waiting for child of sibling thread not allowed with WNOTHREAD.
+TEST_P(WaitSpecificChildTest, SiblingChildrenWNOTHREAD) {
+ // Linux added WNOTHREAD support to waitid(2) in
+ // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to
+ // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7.
+ //
+ // Skip the test if it isn't supported yet.
+ if (Sysno() == SYS_waitid) {
+ int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WNOTHREAD);
+ SKIP_IF(ret < 0 && errno == EINVAL);
+ }
+
+ absl::Mutex mu;
+ pid_t child;
+ bool ready = false;
+ bool stop = false;
+
+ ScopedThread t([&] {
+ absl::MutexLock ml(&mu);
+ EXPECT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+ ready = true;
+ mu.Await(absl::Condition(&stop));
+
+ // This thread can wait on child.
+ EXPECT_NO_ERRNO(WaitForWithOptions(child, __WNOTHREAD, 0));
+ });
+
+ // N.B. This must be declared after ScopedThread, so it is destructed first,
+ // thus waking the thread.
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&ready));
+
+ // This thread can't wait on child.
+ EXPECT_THAT(WaitForWithOptions(child, __WNOTHREAD, 0),
+ PosixErrorIs(ECHILD, ::testing::_));
+
+ // Keep the sibling alive until after we've waited so the child isn't
+ // reparented.
+ stop = true;
+}
+
+// Wait for specific child to exit.
+// A non-CLONE_THREAD child which sends SIGCHLD upon exit behaves much like
+// a forked process.
+TEST_P(WaitSpecificChildTest, CloneSIGCHLD) {
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int child;
+ ASSERT_THAT(child = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+// Wait for specific child to exit.
+// A non-CLONE_THREAD child which does not send SIGCHLD upon exit can be waited
+// on, but returns ECHILD.
+TEST_P(WaitSpecificChildTest, CloneNoSIGCHLD) {
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int child;
+ ASSERT_THAT(child = CloneAndExit(0, stack, 0), SyscallSucceeds());
+
+ EXPECT_THAT(WaitFor(child, 0), PosixErrorIs(ECHILD, ::testing::_));
+}
+
+// Waiting after the child has already exited returns immediately.
+TEST_P(WaitSpecificChildTest, CloneAfterExit) {
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int child;
+ // Send SIGCHLD for normal wait semantics.
+ ASSERT_THAT(child = CloneAndExit(0, stack, SIGCHLD), SyscallSucceeds());
+
+ absl::SleepFor(absl::Seconds(5));
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+// A CLONE_THREAD child cannot be waited on.
+TEST_P(WaitSpecificChildTest, CloneThread) {
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int child;
+ ASSERT_THAT(child = CloneAndExit(15, stack, CLONE_THREAD), SyscallSucceeds());
+ auto start = absl::Now();
+
+ EXPECT_THAT(WaitFor(child, 0), PosixErrorIs(ECHILD, ::testing::_));
+
+ // Ensure wait4 didn't block.
+ EXPECT_LE(absl::Now() - start, absl::Seconds(10));
+
+ // Since we can't wait on the child, we sleep to try to avoid freeing its
+ // stack before it exits.
+ absl::SleepFor(absl::Seconds(5));
+}
+
+// A child that does not send a SIGCHLD on exit may be waited on with
+// the __WCLONE flag.
+TEST_P(WaitSpecificChildTest, CloneWCLONE) {
+ // Linux added WCLONE support to waitid(2) in
+ // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to
+ // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7.
+ //
+ // Skip the test if it isn't supported yet.
+ if (Sysno() == SYS_waitid) {
+ int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WCLONE);
+ SKIP_IF(ret < 0 && errno == EINVAL);
+ }
+
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ int child;
+ ASSERT_THAT(child = CloneAndExit(0, stack, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitForWithOptions(child, __WCLONE, 0));
+}
+
+// A forked child cannot be waited on with WCLONE.
+TEST_P(WaitSpecificChildTest, ForkWCLONE) {
+ // Linux added WCLONE support to waitid(2) in
+ // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to
+ // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7.
+ //
+ // Skip the test if it isn't supported yet.
+ if (Sysno() == SYS_waitid) {
+ int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WCLONE);
+ SKIP_IF(ret < 0 && errno == EINVAL);
+ }
+
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+
+ EXPECT_THAT(WaitForWithOptions(child, WNOHANG | __WCLONE, 0),
+ PosixErrorIs(ECHILD, ::testing::_));
+
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+// Any type of child can be waited on with WALL.
+TEST_P(WaitSpecificChildTest, WALL) {
+ // Linux added WALL support to waitid(2) in
+ // 91c4e8ea8f05916df0c8a6f383508ac7c9e10dba ("wait: allow sys_waitid() to
+ // accept __WNOTHREAD/__WCLONE/__WALL"). i.e., Linux 4.7.
+ //
+ // Skip the test if it isn't supported yet.
+ if (Sysno() == SYS_waitid) {
+ int ret = waitid(P_ALL, 0, nullptr, WEXITED | WNOHANG | __WALL);
+ SKIP_IF(ret < 0 && errno == EINVAL);
+ }
+
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitForWithOptions(child, __WALL, 0));
+
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ ASSERT_THAT(child = CloneAndExit(0, stack, 0), SyscallSucceeds());
+
+ EXPECT_NO_ERRNO(WaitForWithOptions(child, __WALL, 0));
+}
+
+// Waiting on a PID that is not a child of ours fails with ECHILD.
+TEST_P(WaitSpecificChildTest, BadChild) {
+ EXPECT_THAT(WaitFor(42, 0), PosixErrorIs(ECHILD, ::testing::_));
+}
+
+// Wait for a child process that only exits after calling execve(2) from a
+// non-leader thread.
+TEST_P(WaitSpecificChildTest, AfterChildExecve) {
+ ExecveArray const owned_child_argv = {"/bin/true"};
+ char* const* const child_argv = owned_child_argv.get();
+
+ uintptr_t stack;
+ ASSERT_THAT(stack = AllocStack(), SyscallSucceeds());
+ auto free =
+ Cleanup([stack] { ASSERT_THAT(FreeStack(stack), SyscallSucceeds()); });
+
+ pid_t const child = fork();
+ if (child == 0) {
+ // Give the parent some time to start waiting.
+ SleepSafe(absl::Seconds(5));
+ // Pass CLONE_VFORK to block the original thread in the child process until
+ // the clone thread calls execve, annihilating them both. (This means that
+ // if clone returns at all, something went wrong.)
+ //
+ // N.B. clone(2) is not officially async-signal-safe, but at minimum glibc's
+ // x86_64 implementation is safe. See glibc
+ // sysdeps/unix/sysv/linux/x86_64/clone.S.
+ clone(
+ +[](void* arg) {
+ auto child_argv = static_cast<char* const*>(arg);
+ execve(child_argv[0], child_argv, /* envp = */ nullptr);
+ return errno;
+ },
+ reinterpret_cast<void*>(stack),
+ CLONE_FILES | CLONE_FS | CLONE_SIGHAND | CLONE_THREAD | CLONE_VM |
+ CLONE_VFORK,
+ const_cast<char**>(child_argv));
+ _exit(errno);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ EXPECT_NO_ERRNO(WaitFor(child, 0));
+}
+
+PosixError CheckWait4(pid_t pid, int options, int code) {
+ int status;
+ auto const rv = Wait4(pid, &status, options, nullptr);
+ MaybeSave();
+ if (rv < 0) {
+ return PosixError(errno, "wait4");
+ } else if (rv != pid) {
+ return PosixError(
+ EINVAL, absl::StrCat("unexpected pid: got ", rv, ", wanted ", pid));
+ }
+ if (!WIFEXITED(status) || WEXITSTATUS(status) != code) {
+ return PosixError(EINVAL, absl::StrCat("unexpected wait status: got ",
+ status, ", wanted ", code));
+ }
+ return NoError();
+};
+
+PosixError CheckWaitid(pid_t pid, int options, int code) {
+ siginfo_t si;
+ auto const rv = Waitid(P_PID, pid, &si, options | WEXITED);
+ MaybeSave();
+ if (rv < 0) {
+ return PosixError(errno, "waitid");
+ }
+ if (si.si_pid != pid) {
+ return PosixError(EINVAL, absl::StrCat("unexpected pid: got ", si.si_pid,
+ ", wanted ", pid));
+ }
+ if (si.si_signo != SIGCHLD) {
+ return PosixError(EINVAL, absl::StrCat("unexpected signo: got ",
+ si.si_signo, ", wanted ", SIGCHLD));
+ }
+ if (si.si_status != code) {
+ return PosixError(EINVAL, absl::StrCat("unexpected status: got ",
+ si.si_status, ", wanted ", code));
+ }
+ if (si.si_code != CLD_EXITED) {
+ return PosixError(EINVAL, absl::StrCat("unexpected code: got ", si.si_code,
+ ", wanted ", CLD_EXITED));
+ }
+ return NoError();
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Waiters, WaitSpecificChildTest,
+ ::testing::Values(std::make_tuple(SYS_wait4, CheckWait4),
+ std::make_tuple(SYS_waitid, CheckWaitid)));
+
+// WIFEXITED, WIFSIGNALED, WTERMSIG indicate signal exit.
+TEST(WaitTest, SignalExit) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 10), SyscallSucceeds());
+
+ EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(Wait4(child, &status, 0, nullptr),
+ SyscallSucceedsWithValue(child));
+
+ EXPECT_FALSE(WIFEXITED(status));
+ EXPECT_TRUE(WIFSIGNALED(status));
+ EXPECT_EQ(SIGKILL, WTERMSIG(status));
+}
+
+// waitid requires at least one option.
+TEST(WaitTest, WaitidOptions) {
+ EXPECT_THAT(Waitid(P_ALL, 0, nullptr, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+// waitid does not wait for a child to exit if not passed WEXITED.
+TEST(WaitTest, WaitidNoWEXITED) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(0, 0), SyscallSucceeds());
+ EXPECT_THAT(Waitid(P_ALL, 0, nullptr, WSTOPPED),
+ SyscallFailsWithErrno(ECHILD));
+ EXPECT_THAT(Waitid(P_ALL, 0, nullptr, WEXITED), SyscallSucceeds());
+}
+
+// WNOWAIT allows the same wait result to be returned again.
+TEST(WaitTest, WaitidWNOWAIT) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds());
+
+ siginfo_t info;
+ ASSERT_THAT(Waitid(P_PID, child, &info, WEXITED | WNOWAIT),
+ SyscallSucceeds());
+ EXPECT_EQ(child, info.si_pid);
+ EXPECT_EQ(SIGCHLD, info.si_signo);
+ EXPECT_EQ(CLD_EXITED, info.si_code);
+ EXPECT_EQ(42, info.si_status);
+
+ ASSERT_THAT(Waitid(P_PID, child, &info, WEXITED), SyscallSucceeds());
+ EXPECT_EQ(child, info.si_pid);
+ EXPECT_EQ(SIGCHLD, info.si_signo);
+ EXPECT_EQ(CLD_EXITED, info.si_code);
+ EXPECT_EQ(42, info.si_status);
+
+ EXPECT_THAT(Waitid(P_PID, child, &info, WEXITED),
+ SyscallFailsWithErrno(ECHILD));
+}
+
+// waitpid(pid, status, options) is equivalent to
+// wait4(pid, status, options, nullptr). It is a dedicated syscall on i386;
+// glibc maps it to wait4 on amd64.
+TEST(WaitTest, WaitPid) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds());
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(42, WEXITSTATUS(status));
+}
+
+// Test that signaling a zombie succeeds. This is really a signals test, but
+// it lives here because it requires a zombie child to target.
+TEST(WaitTest, KillZombie) {
+ pid_t child;
+ ASSERT_THAT(child = ForkAndExit(42, 0), SyscallSucceeds());
+
+ // Sleep for three seconds to ensure the child has exited.
+ absl::SleepFor(absl::Seconds(3));
+
+ // The child is now a zombie. Check that killing it returns 0.
+ EXPECT_THAT(kill(child, SIGTERM), SyscallSucceeds());
+ EXPECT_THAT(kill(child, 0), SyscallSucceeds());
+
+ EXPECT_THAT(Wait4(child, nullptr, 0, nullptr),
+ SyscallSucceedsWithValue(child));
+}
+
+TEST(WaitTest, Wait4Rusage) {
+ pid_t child;
+ constexpr absl::Duration kSpin = absl::Seconds(3);
+ ASSERT_THAT(child = ForkSpinAndExit(21, absl::ToInt64Seconds(kSpin)),
+ SyscallSucceeds());
+
+ int status;
+ struct rusage rusage = {};
+ ASSERT_THAT(Wait4(child, &status, 0, &rusage),
+ SyscallSucceedsWithValue(child));
+
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(21, WEXITSTATUS(status));
+
+ EXPECT_GE(RusageCpuTime(rusage), kSpin);
+}
+
+TEST(WaitTest, WaitidRusage) {
+ pid_t child;
+ constexpr absl::Duration kSpin = absl::Seconds(3);
+ ASSERT_THAT(child = ForkSpinAndExit(27, absl::ToInt64Seconds(kSpin)),
+ SyscallSucceeds());
+
+ siginfo_t si = {};
+ struct rusage rusage = {};
+
+ // From waitid(2):
+ // The raw waitid() system call takes a fifth argument, of type
+ // struct rusage *. If this argument is non-NULL, then it is used
+ // to return resource usage information about the child, in the
+ // same manner as wait4(2).
+ EXPECT_THAT(
+ RetryEINTR(syscall)(SYS_waitid, P_PID, child, &si, WEXITED, &rusage),
+ SyscallSucceeds());
+ EXPECT_EQ(si.si_signo, SIGCHLD);
+ EXPECT_EQ(si.si_code, CLD_EXITED);
+ EXPECT_EQ(si.si_status, 27);
+ EXPECT_EQ(si.si_pid, child);
+
+ EXPECT_GE(RusageCpuTime(rusage), kSpin);
+}
+
+// After bf959931ddb88c4e4366e96dd22e68fa0db9527c ("wait/ptrace: assume __WALL
+// if the child is traced") (Linux 4.7), tracees are always eligible for
+// waiting, regardless of type.
+TEST(WaitTest, TraceeWALL) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ FileDescriptor rfd(fds[0]);
+ FileDescriptor wfd(fds[1]);
+
+ pid_t child = fork();
+ if (child == 0) {
+ // Child.
+ rfd.reset();
+
+ TEST_PCHECK(ptrace(PTRACE_TRACEME, 0, nullptr, nullptr) == 0);
+
+ // Notify parent that we're now a tracee.
+ wfd.reset();
+
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+
+ wfd.reset();
+
+ // Wait for child to become tracee.
+ char c;
+ EXPECT_THAT(ReadFd(rfd.get(), &c, sizeof(c)), SyscallSucceedsWithValue(0));
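+  // EOF indicates that the child has closed its write end after
+  // PTRACE_TRACEME, so it is now a tracee.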
+
+ // We can wait on the fork child with WCLONE, as it is a tracee.
+ int status;
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(Wait4(child, &status, __WCLONE, nullptr),
+ SyscallSucceedsWithValue(child));
+
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) << status;
+ } else {
+ // On older versions of Linux, we may get ECHILD.
+ ASSERT_THAT(Wait4(child, &status, __WCLONE, nullptr),
+ ::testing::AnyOf(SyscallSucceedsWithValue(child),
+ SyscallFailsWithErrno(ECHILD)));
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc
new file mode 100644
index 000000000..39b5b2f56
--- /dev/null
+++ b/test/syscalls/linux/write.cc
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <signal.h>
+#include <sys/resource.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/cleanup.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
+class WriteTest : public ::testing::Test {
+ public:
+ ssize_t WriteBytes(int fd, int bytes) {
+ std::vector<char> buf(bytes);
+ std::fill(buf.begin(), buf.end(), 'a');
+ return WriteFd(fd, buf.data(), buf.size());
+ }
+};
+
+TEST_F(WriteTest, WriteNoExceedsRLimit) {
+ // Get the current rlimit and restore after test run.
+ struct rlimit initial_lim;
+ ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ auto cleanup = Cleanup([&initial_lim] {
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ });
+
+ int fd;
+ struct rlimit setlim;
+ const int target_lim = 1024;
+ setlim.rlim_cur = target_lim;
+ setlim.rlim_max = RLIM_INFINITY;
+ const std::string pathname = NewTempAbsPath();
+ ASSERT_THAT(fd = open(pathname.c_str(), O_WRONLY | O_CREAT, S_IRWXU),
+ SyscallSucceeds());
+ ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds());
+
+ EXPECT_THAT(WriteBytes(fd, target_lim), SyscallSucceedsWithValue(target_lim));
+
+ std::vector<char> buf(target_lim + 1);
+ std::fill(buf.begin(), buf.end(), 'a');
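+  // These pwrites straddle the limit (offset 1 plus target_lim bytes), so
+  // they succeed with a short count rather than failing outright.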
+ EXPECT_THAT(pwrite(fd, buf.data(), target_lim, 1), SyscallSucceeds());
+ EXPECT_THAT(pwrite64(fd, buf.data(), target_lim, 1), SyscallSucceeds());
+
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(WriteTest, WriteExceedsRLimit) {
+ // Get the current rlimit and restore after test run.
+ struct rlimit initial_lim;
+ ASSERT_THAT(getrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ auto cleanup = Cleanup([&initial_lim] {
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &initial_lim), SyscallSucceeds());
+ });
+
+ int fd;
+ sigset_t filesize_mask;
+ sigemptyset(&filesize_mask);
+ sigaddset(&filesize_mask, SIGXFSZ);
+
+ struct rlimit setlim;
+ const int target_lim = 1024;
+ setlim.rlim_cur = target_lim;
+ setlim.rlim_max = RLIM_INFINITY;
+
+ const std::string pathname = NewTempAbsPath();
+ ASSERT_THAT(fd = open(pathname.c_str(), O_WRONLY | O_CREAT, S_IRWXU),
+ SyscallSucceeds());
+ ASSERT_THAT(setrlimit(RLIMIT_FSIZE, &setlim), SyscallSucceeds());
+ ASSERT_THAT(sigprocmask(SIG_BLOCK, &filesize_mask, nullptr),
+ SyscallSucceeds());
+ std::vector<char> buf(target_lim + 2);
+ std::fill(buf.begin(), buf.end(), 'a');
+
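+  // Per write(2), a write that straddles the limit is truncated at the limit,
+  // while a write starting at or beyond it fails with EFBIG and raises
+  // SIGXFSZ (blocked above so we can inspect it with sigtimedwait).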
+ EXPECT_THAT(write(fd, buf.data(), target_lim + 1),
+ SyscallSucceedsWithValue(target_lim));
+ EXPECT_THAT(write(fd, buf.data(), 1), SyscallFailsWithErrno(EFBIG));
+ siginfo_t info;
+ struct timespec timelimit = {0, 0};
+ ASSERT_THAT(RetryEINTR(sigtimedwait)(&filesize_mask, &info, &timelimit),
+ SyscallSucceedsWithValue(SIGXFSZ));
+ EXPECT_EQ(info.si_code, SI_USER);
+ EXPECT_EQ(info.si_pid, getpid());
+ EXPECT_EQ(info.si_uid, getuid());
+
+ EXPECT_THAT(pwrite(fd, buf.data(), target_lim + 1, 1),
+ SyscallSucceedsWithValue(target_lim - 1));
+ EXPECT_THAT(pwrite(fd, buf.data(), 1, target_lim),
+ SyscallFailsWithErrno(EFBIG));
+ ASSERT_THAT(RetryEINTR(sigtimedwait)(&filesize_mask, &info, &timelimit),
+ SyscallSucceedsWithValue(SIGXFSZ));
+ EXPECT_EQ(info.si_code, SI_USER);
+ EXPECT_EQ(info.si_pid, getpid());
+ EXPECT_EQ(info.si_uid, getuid());
+
+ EXPECT_THAT(pwrite64(fd, buf.data(), target_lim + 1, 1),
+ SyscallSucceedsWithValue(target_lim - 1));
+ EXPECT_THAT(pwrite64(fd, buf.data(), 1, target_lim),
+ SyscallFailsWithErrno(EFBIG));
+ ASSERT_THAT(RetryEINTR(sigtimedwait)(&filesize_mask, &info, &timelimit),
+ SyscallSucceedsWithValue(SIGXFSZ));
+ EXPECT_EQ(info.si_code, SI_USER);
+ EXPECT_EQ(info.si_pid, getpid());
+ EXPECT_EQ(info.si_uid, getuid());
+
+ ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &filesize_mask, nullptr),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
new file mode 100644
index 000000000..cbcf08451
--- /dev/null
+++ b/test/syscalls/linux/xattr.cc
@@ -0,0 +1,610 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <sys/types.h>
+#include <sys/xattr.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_set.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class XattrTest : public FileTest {};
+
+TEST_F(XattrTest, XattrNonexistentFile) {
+ const char* path = "/does/not/exist";
+ const char* name = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_F(XattrTest, XattrNullName) {
+ const char* path = test_file_name_.c_str();
+
+ EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(getxattr(path, nullptr, nullptr, 0),
+ SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, XattrEmptyName) {
+ const char* path = test_file_name_.c_str();
+
+ EXPECT_THAT(setxattr(path, "", nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(getxattr(path, "", nullptr, 0), SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(removexattr(path, ""), SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, XattrLargeName) {
+ const char* path = test_file_name_.c_str();
+ std::string name = "user.";
+ name += std::string(XATTR_NAME_MAX - name.length(), 'a');
+
+  if (!IsRunningOnGvisor()) {
+    // In gVisor, access to xattrs is controlled with an explicit list of
+    // allowed names. This maximum-length name won't be on that list, so only
+    // test the success case on Linux.
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
+ SyscallSucceedsWithValue(0));
+ }
+
+ name += "a";
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
+ SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(removexattr(path, name.c_str()), SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, XattrInvalidPrefix) {
+ const char* path = test_file_name_.c_str();
+ std::string name(XATTR_NAME_MAX, 'a');
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ EXPECT_THAT(removexattr(path, name.c_str()),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+// Do not allow save/restore cycles after making the test file read-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ DisableSave ds;
+ ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR));
+
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
+ SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EACCES));
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+}
+
+// Do not allow save/restore cycles after making the test file write-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ DisableSave ds;
+ ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR));
+
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES));
+
+ // listxattr will succeed even without read permissions.
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(path, name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrTrustedWithNonadmin) {
+ // TODO(b/148380782): Support setxattr and getxattr with "trusted" prefix.
+ SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ const char* path = test_file_name_.c_str();
+ const char name[] = "trusted.abc";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, XattrOnDirectory) {
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(dir.path().c_str(), name, nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(getxattr(dir.path().c_str(), name, nullptr, 0),
+ SyscallSucceedsWithValue(0));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(dir.path().c_str(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(dir.path().c_str(), name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrOnSymlink) {
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(link.path().c_str(), name, nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(getxattr(link.path().c_str(), name, nullptr, 0),
+ SyscallSucceedsWithValue(0));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(link.path().c_str(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(link.path().c_str(), name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrOnInvalidFileTypes) {
+ const char name[] = "user.test";
+
+ char char_device[] = "/dev/zero";
+ EXPECT_THAT(setxattr(char_device, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(getxattr(char_device, name, nullptr, 0),
+ SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(listxattr(char_device, nullptr, 0), SyscallSucceedsWithValue(0));
+
+ // Use tmpfs, where creation of named pipes is supported.
+ const std::string fifo = NewTempAbsPathInDir("/dev/shm");
+ const char* path = fifo.c_str();
+ EXPECT_THAT(mknod(path, S_IFIFO | S_IRUSR | S_IWUSR, 0), SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(listxattr(path, nullptr, 0), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ size_t size = 1;
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ std::vector<char> expected_buf = {'a', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, SetxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds());
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, XATTR_SIZE_MAX),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(buf, '-');
+}
+
+TEST_F(XattrTest, SetxattrSizeTooLarge) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+
+ // Note that a particular filesystem implementation may impose a lower size
+ // limit, in which case setxattr may fail (e.g. with ENOSPC) for some sizes
+ // below XATTR_SIZE_MAX.
+ size_t size = XATTR_SIZE_MAX + 1;
+ std::vector<char> val(size);
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallFailsWithErrno(E2BIG));
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0),
+ SyscallFailsWithErrno(EFAULT));
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, SetxattrNullValueAndZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val(XATTR_SIZE_MAX + 1);
+ std::fill(val.begin(), val.end(), 'a');
+ size_t size = 1;
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ std::vector<char> expected_buf = {'a', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), size),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ std::vector<char> expected_buf = {'a', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, SetxattrReplaceWithLarger) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(2));
+ EXPECT_EQ(buf, val);
+}
+
+TEST_F(XattrTest, SetxattrCreateFlag) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
+ SyscallFailsWithErrno(EEXIST));
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, SetxattrReplaceFlag) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
+ SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
+ SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, SetxattrInvalidFlags) {
+ const char* path = test_file_name_.c_str();
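+ // The only valid setxattr flags are XATTR_CREATE and XATTR_REPLACE; any
+ // other bits should be rejected with EINVAL.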
+ int invalid_flags = 0xff;
+ EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(XattrTest, Getxattr) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ int buf = 0;
+ EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+}
+
+TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ size_t size = val.size();
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallSucceeds());
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, 1), SyscallFailsWithErrno(ERANGE));
+ EXPECT_EQ(buf, '-');
+}
+
+TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds());
+
+ std::vector<char> buf(XATTR_SIZE_MAX);
+ std::fill(buf.begin(), buf.end(), '-');
+ std::vector<char> expected_buf = buf;
+ expected_buf[0] = 'a';
+ EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, GetxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+
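+ // A zero size queries the length of the value without copying it out.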
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, 0),
+ SyscallSucceedsWithValue(sizeof(val)));
+ EXPECT_EQ(buf, '-');
+}
+
+TEST_F(XattrTest, GetxattrSizeTooLarge) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf(XATTR_SIZE_MAX + 1);
+ std::fill(buf.begin(), buf.end(), '-');
+ std::vector<char> expected_buf = buf;
+ expected_buf[0] = 'a';
+ EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(sizeof(val)));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, GetxattrNullValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, size),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+ // Set value with zero size.
+ EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds());
+ // Get value with nonzero size.
+ EXPECT_THAT(getxattr(path, name, nullptr, size), SyscallSucceedsWithValue(0));
+
+ // Set value with nonzero size.
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+ // Get value with zero size.
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size));
+}
+
+TEST_F(XattrTest, GetxattrNonexistentName) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, Listxattr) {
+ const char* path = test_file_name_.c_str();
+ const std::string name = "user.test";
+ const std::string name2 = "user.test2";
+ const std::string name3 = "user.test3";
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name2.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name3.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> list(name.size() + 1 + name2.size() + 1 + name3.size() + 1);
+ char* buf = list.data();
+ EXPECT_THAT(listxattr(path, buf, list.size()),
+ SyscallSucceedsWithValue(list.size()));
+
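+ // listxattr fills the buffer with a sequence of NUL-terminated names in no
+ // guaranteed order, so collect them into a set before comparing.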
+ absl::flat_hash_set<std::string> got = {};
+ for (char* p = buf; p < buf + list.size(); p += strlen(p) + 1) {
+ got.insert(std::string{p});
+ }
+
+ absl::flat_hash_set<std::string> expected = {name, name2, name3};
+ EXPECT_EQ(got, expected);
+}
+
+TEST_F(XattrTest, ListxattrNoXattrs) {
+ const char* path = test_file_name_.c_str();
+
+ std::vector<char> list, expected;
+ EXPECT_THAT(listxattr(path, list.data(), list.size()),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(list, expected);
+
+ // Listxattr should succeed if there are no attributes, even if the buffer
+ // passed in is a nullptr.
+ EXPECT_THAT(listxattr(path, nullptr, /*size=*/1024),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, ListxattrNullBuffer) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(listxattr(path, nullptr, sizeof(name)),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, ListxattrSizeTooSmall) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ char list[sizeof(name) - 1];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, ListxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
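+ // A zero size queries the length of the name list without copying it.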
+ EXPECT_THAT(listxattr(path, nullptr, 0),
+ SyscallSucceedsWithValue(sizeof(name)));
+}
+
+TEST_F(XattrTest, RemoveXattr) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(removexattr(path, name), SyscallSucceeds());
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, RemoveXattrNonexistentName) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, LXattrOnSymlink) {
+ const char name[] = "user.test";
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
+
+ EXPECT_THAT(lsetxattr(link.path().c_str(), name, nullptr, 0, 0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(lgetxattr(link.path().c_str(), name, nullptr, 0),
+ SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(llistxattr(link.path().c_str(), nullptr, 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lremovexattr(link.path().c_str(), name),
+ SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(XattrTest, LXattrOnNonsymlink) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(lsetxattr(path, name, &val, size, /*flags=*/0),
+ SyscallSucceeds());
+
+ int buf = 0;
+ EXPECT_THAT(lgetxattr(path, name, &buf, size),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(llistxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(lremovexattr(path, name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrWithFD) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), 0));
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0),
+ SyscallSucceeds());
+
+ int buf = 0;
+ EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/uds/BUILD b/test/uds/BUILD
new file mode 100644
index 000000000..51e2c7ce8
--- /dev/null
+++ b/test/uds/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "uds",
+ testonly = 1,
+ srcs = ["uds.go"],
+ deps = [
+ "//pkg/log",
+ "//pkg/unet",
+ ],
+)
diff --git a/test/uds/uds.go b/test/uds/uds.go
new file mode 100644
index 000000000..b714c61b0
--- /dev/null
+++ b/test/uds/uds.go
@@ -0,0 +1,228 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uds contains helpers for testing external UDS functionality.
+package uds
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// createEchoSocket creates a socket that echoes back anything received.
+//
+// Only works for stream and seqpacket sockets.
+func createEchoSocket(path string, protocol int) (cleanup func(), err error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Listen(fd, 0); err != nil {
+ return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err)
+ }
+
+ server, err := unet.NewServerSocket(fd)
+ if err != nil {
+ return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err)
+ }
+
+ acceptAndEchoOne := func() error {
+ s, err := server.Accept()
+ if err != nil {
+ return fmt.Errorf("failed to accept: %v", err)
+ }
+ defer s.Close()
+
+ buf := make([]byte, 512)
+ for {
+ n, err := s.Read(buf)
+ if err == io.EOF {
+ return nil
+ }
+ if err != nil {
+ return fmt.Errorf("failed to read: %d, %v", n, err)
+ }
+
+ n, err = s.Write(buf[:n])
+ if err != nil {
+ return fmt.Errorf("failed to write: %d, %v", n, err)
+ }
+ }
+ }
+
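+ // Serve connections one at a time; echoing ends cleanly on EOF, and any
+ // other error shuts down the accept loop for this socket.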
+ go func() {
+ for {
+ if err := acceptAndEchoOne(); err != nil {
+ log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err)
+ return
+ }
+ }
+ }()
+
+ cleanup = func() {
+ if err := server.Close(); err != nil {
+ log.Warningf("Failed to close echo(%d) socket: %v", protocol, err)
+ }
+ }
+
+ return cleanup, nil
+}
+
+// createNonListeningSocket creates a socket that is bound but not listening.
+//
+// Only relevant for stream and seqpacket sockets.
+func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err)
+ }
+
+ cleanup = func() {
+ if err := syscall.Close(fd); err != nil {
+ log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err)
+ }
+ }
+
+ return cleanup, nil
+}
+
+// createNullSocket creates a socket that reads and discards anything received.
+//
+// Only works for dgram sockets.
+func createNullSocket(path string, protocol int) (cleanup func(), err error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err)
+ }
+
+ s, err := unet.NewSocket(fd)
+ if err != nil {
+ return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err)
+ }
+
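+ // Drain and discard datagrams until the socket is closed.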
+ go func() {
+ buf := make([]byte, 512)
+ for {
+ n, err := s.Read(buf)
+ if err != nil {
+ log.Warningf("failed to read: %d, %v", n, err)
+ return
+ }
+ }
+ }()
+
+ cleanup = func() {
+ if err := s.Close(); err != nil {
+ log.Warningf("Failed to close null(%d) socket: %v", protocol, err)
+ }
+ }
+
+ return cleanup, nil
+}
+
+type socketCreator func(path string, proto int) (cleanup func(), err error)
+
+// CreateSocketTree creates a local tree of unix domain sockets for use in
+// testing:
+// * /stream/echo
+// * /stream/nonlistening
+// * /seqpacket/echo
+// * /seqpacket/nonlistening
+// * /dgram/null
+func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) {
+ dir, err = ioutil.TempDir(baseDir, "sockets")
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating temp dir: %v", err)
+ }
+
+ var protocols = []struct {
+ protocol int
+ name string
+ sockets map[string]socketCreator
+ }{
+ {
+ protocol: syscall.SOCK_STREAM,
+ name: "stream",
+ sockets: map[string]socketCreator{
+ "echo": createEchoSocket,
+ "nonlistening": createNonListeningSocket,
+ },
+ },
+ {
+ protocol: syscall.SOCK_SEQPACKET,
+ name: "seqpacket",
+ sockets: map[string]socketCreator{
+ "echo": createEchoSocket,
+ "nonlistening": createNonListeningSocket,
+ },
+ },
+ {
+ protocol: syscall.SOCK_DGRAM,
+ name: "dgram",
+ sockets: map[string]socketCreator{
+ "null": createNullSocket,
+ },
+ },
+ }
+
+ var cleanups []func()
+ for _, proto := range protocols {
+ protoDir := filepath.Join(dir, proto.name)
+ if err := os.Mkdir(protoDir, 0755); err != nil {
+ return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err)
+ }
+
+ for name, fn := range proto.sockets {
+ path := filepath.Join(protoDir, name)
+ cleanup, err := fn(path, proto.protocol)
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err)
+ }
+
+ cleanups = append(cleanups, cleanup)
+ }
+ }
+
+ cleanup = func() {
+ for _, c := range cleanups {
+ c()
+ }
+
+ os.RemoveAll(dir)
+ }
+
+ return dir, cleanup, nil
+}
diff --git a/test/util/BUILD b/test/util/BUILD
new file mode 100644
index 000000000..2a17c33ee
--- /dev/null
+++ b/test/util/BUILD
@@ -0,0 +1,358 @@
+load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+cc_library(
+ name = "capability_util",
+ testonly = 1,
+ srcs = ["capability_util.cc"],
+ hdrs = ["capability_util.h"],
+ deps = [
+ ":cleanup",
+ ":memory_util",
+ ":posix_error",
+ ":save_util",
+ ":test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "eventfd_util",
+ testonly = 1,
+ hdrs = ["eventfd_util.h"],
+ deps = [
+ ":file_descriptor",
+ ":posix_error",
+ ":save_util",
+ ],
+)
+
+cc_library(
+ name = "file_descriptor",
+ testonly = 1,
+ hdrs = ["file_descriptor.h"],
+ deps = [
+ ":logging",
+ ":posix_error",
+ ":save_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "proc_util",
+ testonly = 1,
+ srcs = ["proc_util.cc"],
+ hdrs = ["proc_util.h"],
+ deps = [
+ ":fs_util",
+ ":posix_error",
+ ":test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ ],
+)
+
+cc_test(
+ name = "proc_util_test",
+ size = "small",
+ srcs = ["proc_util_test.cc"],
+ deps = [
+ ":proc_util",
+ ":test_main",
+ ":test_util",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "cleanup",
+ testonly = 1,
+ hdrs = ["cleanup.h"],
+)
+
+cc_library(
+ name = "fs_util",
+ testonly = 1,
+ srcs = ["fs_util.cc"],
+ hdrs = ["fs_util.h"],
+ deps = [
+ ":cleanup",
+ ":file_descriptor",
+ ":posix_error",
+ "@com_google_absl//absl/strings",
+ gtest,
+ ],
+)
+
+cc_test(
+ name = "fs_util_test",
+ size = "small",
+ srcs = ["fs_util_test.cc"],
+ deps = [
+ ":fs_util",
+ ":posix_error",
+ ":temp_path",
+ ":test_main",
+ ":test_util",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "logging",
+ testonly = 1,
+ srcs = ["logging.cc"],
+ hdrs = ["logging.h"],
+)
+
+cc_library(
+ name = "memory_util",
+ testonly = 1,
+ hdrs = ["memory_util.h"],
+ deps = [
+ ":logging",
+ ":posix_error",
+ ":save_util",
+ ":test_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_library(
+ name = "mount_util",
+ testonly = 1,
+ hdrs = ["mount_util.h"],
+ deps = [
+ ":cleanup",
+ ":posix_error",
+ ":test_util",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "save_util",
+ testonly = 1,
+ srcs = [
+ "save_util.cc",
+ "save_util_linux.cc",
+ "save_util_other.cc",
+ ],
+ hdrs = ["save_util.h"],
+ defines = select_system(),
+)
+
+cc_library(
+ name = "multiprocess_util",
+ testonly = 1,
+ srcs = ["multiprocess_util.cc"],
+ hdrs = ["multiprocess_util.h"],
+ deps = [
+ ":cleanup",
+ ":file_descriptor",
+ ":posix_error",
+ ":save_util",
+ ":test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "platform_util",
+ testonly = 1,
+ srcs = ["platform_util.cc"],
+ hdrs = ["platform_util.h"],
+ deps = [":test_util"],
+)
+
+cc_library(
+ name = "posix_error",
+ testonly = 1,
+ srcs = ["posix_error.cc"],
+ hdrs = ["posix_error.h"],
+ deps = [
+ ":logging",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:variant",
+ gtest,
+ ],
+)
+
+cc_test(
+ name = "posix_error_test",
+ size = "small",
+ srcs = ["posix_error_test.cc"],
+ deps = [
+ ":posix_error",
+ ":test_main",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "pty_util",
+ testonly = 1,
+ srcs = ["pty_util.cc"],
+ hdrs = ["pty_util.h"],
+ deps = [
+ ":file_descriptor",
+ ":posix_error",
+ ],
+)
+
+cc_library(
+ name = "signal_util",
+ testonly = 1,
+ srcs = ["signal_util.cc"],
+ hdrs = ["signal_util.h"],
+ deps = [
+ ":cleanup",
+ ":posix_error",
+ ":test_util",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "temp_path",
+ testonly = 1,
+ srcs = ["temp_path.cc"],
+ hdrs = ["temp_path.h"],
+ deps = [
+ ":fs_util",
+ ":posix_error",
+ ":test_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "test_util",
+ testonly = 1,
+ srcs = [
+ "test_util.cc",
+ "test_util_impl.cc",
+ "test_util_runfiles.cc",
+ ],
+ hdrs = ["test_util.h"],
+ defines = select_system(),
+ deps = [
+ ":fs_util",
+ ":logging",
+ ":posix_error",
+ ":save_util",
+ "@bazel_tools//tools/cpp/runfiles",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ gtest,
+ gbenchmark,
+ ],
+)
+
+cc_library(
+ name = "thread_util",
+ testonly = 1,
+ hdrs = ["thread_util.h"],
+ deps = [":logging"],
+)
+
+cc_library(
+ name = "time_util",
+ testonly = 1,
+ srcs = ["time_util.cc"],
+ hdrs = ["time_util.h"],
+ deps = [
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "timer_util",
+ testonly = 1,
+ srcs = ["timer_util.cc"],
+ hdrs = ["timer_util.h"],
+ deps = [
+ ":cleanup",
+ ":logging",
+ ":posix_error",
+ ":test_util",
+ "@com_google_absl//absl/time",
+ gtest,
+ ],
+)
+
+cc_test(
+ name = "test_util_test",
+ size = "small",
+ srcs = ["test_util_test.cc"],
+ deps = [
+ ":test_main",
+ ":test_util",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "test_main",
+ testonly = 1,
+ srcs = ["test_main.cc"],
+ deps = [":test_util"],
+)
+
+cc_library(
+ name = "epoll_util",
+ testonly = 1,
+ srcs = ["epoll_util.cc"],
+ hdrs = ["epoll_util.h"],
+ deps = [
+ ":file_descriptor",
+ ":posix_error",
+ ":save_util",
+ gtest,
+ ],
+)
+
+cc_library(
+ name = "rlimit_util",
+ testonly = 1,
+ srcs = ["rlimit_util.cc"],
+ hdrs = ["rlimit_util.h"],
+ deps = [
+ ":cleanup",
+ ":logging",
+ ":posix_error",
+ ":test_util",
+ ],
+)
+
+cc_library(
+ name = "uid_util",
+ testonly = 1,
+ srcs = ["uid_util.cc"],
+ hdrs = ["uid_util.h"],
+ deps = [
+ ":posix_error",
+ ":save_util",
+ ],
+)
+
+cc_library(
+ name = "temp_umask",
+ testonly = 1,
+ hdrs = ["temp_umask.h"],
+)
diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc
new file mode 100644
index 000000000..a1b994c45
--- /dev/null
+++ b/test/util/capability_util.cc
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/capability_util.h"
+
+#include <linux/capability.h>
+#include <sched.h>
+#include <sys/mman.h>
+#include <sys/wait.h>
+
+#include <iostream>
+
+#include "absl/strings/str_cat.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<bool> CanCreateUserNamespace() {
+ // The most reliable way to determine if userns creation is possible is by
+ // trying to create one; see below.
+ ASSIGN_OR_RETURN_ERRNO(
+ auto child_stack,
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
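+ // The stack grows down on all supported architectures, so pass clone() the
+ // highest address of the mapping.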
+ int const child_pid = clone(
+ +[](void*) { return 0; },
+ reinterpret_cast<void*>(child_stack.addr() + kPageSize),
+ CLONE_NEWUSER | SIGCHLD, /* arg = */ nullptr);
+ if (child_pid > 0) {
+ int status;
+ int const ret = waitpid(child_pid, &status, /* options = */ 0);
+ MaybeSave();
+ if (ret < 0) {
+ return PosixError(errno, "waitpid");
+ }
+ if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
+ return PosixError(
+ ESRCH, absl::StrCat("child process exited with status ", status));
+ }
+ return true;
+ } else if (errno == EPERM) {
+ // Per clone(2), EPERM can be returned if:
+ //
+ // - "CLONE_NEWUSER was specified in flags, but either the effective user ID
+ // or the effective group ID of the caller does not have a mapping in the
+ // parent namespace (see user_namespaces(7))."
+ //
+ // - "(since Linux 3.9) CLONE_NEWUSER was specified in flags and the caller
+ // is in a chroot environment (i.e., the caller's root directory does
+ // not match the root directory of the mount namespace in which it
+ // resides)."
+ std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl;
+ return false;
+ } else if (errno == EUSERS) {
+ // "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call
+ // would cause the limit on the number of nested user namespaces to be
+ // exceeded. See user_namespaces(7)."
+ std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl;
+ return false;
+ } else {
+ // Unexpected error code; indicate an actual error.
+ return PosixError(errno, "clone(CLONE_NEWUSER)");
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/capability_util.h b/test/util/capability_util.h
new file mode 100644
index 000000000..bb9ea1fe5
--- /dev/null
+++ b/test/util/capability_util.h
@@ -0,0 +1,101 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities for testing capabilities.
+
+#ifndef GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_
+#define GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_
+
+#include <errno.h>
+#include <linux/capability.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+#ifndef _LINUX_CAPABILITY_VERSION_3
+#error Expecting _LINUX_CAPABILITY_VERSION_3 support
+#endif
+
+namespace gvisor {
+namespace testing {
+
+// HaveCapability returns true if the process has the specified EFFECTIVE
+// capability.
+inline PosixErrorOr<bool> HaveCapability(int cap) {
+ if (!cap_valid(cap)) {
+ return PosixError(EINVAL, "Invalid capability");
+ }
+
+ struct __user_cap_header_struct header = {_LINUX_CAPABILITY_VERSION_3, 0};
+ struct __user_cap_data_struct caps[_LINUX_CAPABILITY_U32S_3] = {};
+ RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capget, &header, &caps));
+ MaybeSave();
+
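+ // Capability sets are bitmasks spread across an array of 32-bit words;
+ // CAP_TO_INDEX selects the word and CAP_TO_MASK the bit within it.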
+ return (caps[CAP_TO_INDEX(cap)].effective & CAP_TO_MASK(cap)) != 0;
+}
+
+// SetCapability sets or clears the specified EFFECTIVE capability.
+inline PosixError SetCapability(int cap, bool set) {
+ if (!cap_valid(cap)) {
+ return PosixError(EINVAL, "Invalid capability");
+ }
+
+ struct __user_cap_header_struct header = {_LINUX_CAPABILITY_VERSION_3, 0};
+ struct __user_cap_data_struct caps[_LINUX_CAPABILITY_U32S_3] = {};
+ RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capget, &header, &caps));
+ MaybeSave();
+
+ if (set) {
+ caps[CAP_TO_INDEX(cap)].effective |= CAP_TO_MASK(cap);
+ } else {
+ caps[CAP_TO_INDEX(cap)].effective &= ~CAP_TO_MASK(cap);
+ }
+ header = {_LINUX_CAPABILITY_VERSION_3, 0};
+ RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capset, &header, &caps));
+ MaybeSave();
+
+ return NoError();
+}
+
+// DropPermittedCapability drops the specified PERMITTED capability. The
+// EFFECTIVE set must be a subset of the PERMITTED set, so the capability is
+// dropped from it as well.
+inline PosixError DropPermittedCapability(int cap) {
+ if (!cap_valid(cap)) {
+ return PosixError(EINVAL, "Invalid capability");
+ }
+
+ struct __user_cap_header_struct header = {_LINUX_CAPABILITY_VERSION_3, 0};
+ struct __user_cap_data_struct caps[_LINUX_CAPABILITY_U32S_3] = {};
+ RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capget, &header, &caps));
+ MaybeSave();
+
+ caps[CAP_TO_INDEX(cap)].effective &= ~CAP_TO_MASK(cap);
+ caps[CAP_TO_INDEX(cap)].permitted &= ~CAP_TO_MASK(cap);
+
+ header = {_LINUX_CAPABILITY_VERSION_3, 0};
+ RETURN_ERROR_IF_SYSCALL_FAIL(syscall(__NR_capset, &header, &caps));
+ MaybeSave();
+
+ return NoError();
+}
+
+PosixErrorOr<bool> CanCreateUserNamespace();
+
+} // namespace testing
+} // namespace gvisor
+#endif // GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_
diff --git a/test/util/cleanup.h b/test/util/cleanup.h
new file mode 100644
index 000000000..c76482ef4
--- /dev/null
+++ b/test/util/cleanup.h
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_CLEANUP_H_
+#define GVISOR_TEST_UTIL_CLEANUP_H_
+
+#include <functional>
+#include <utility>
+
+namespace gvisor {
+namespace testing {
+
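+// Cleanup invokes the supplied callback when it goes out of scope, unless it
+// has been released or moved from first.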
+class Cleanup {
+ public:
+ Cleanup() : released_(true) {}
+ explicit Cleanup(std::function<void()>&& callback)
+ : cb_(std::move(callback)) {}
+
+ Cleanup(Cleanup&& other) {
+ released_ = other.released_;
+ cb_ = other.Release();
+ }
+
+ Cleanup& operator=(Cleanup&& other) {
+ released_ = other.released_;
+ cb_ = other.Release();
+ return *this;
+ }
+
+ ~Cleanup() {
+ if (!released_) {
+ cb_();
+ }
+ }
+
+ std::function<void()>&& Release() {
+ released_ = true;
+ return std::move(cb_);
+ }
+
+ private:
+ Cleanup(Cleanup const& other) = delete;
+
+ bool released_ = false;
+ std::function<void(void)> cb_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_CLEANUP_H_
diff --git a/test/util/epoll_util.cc b/test/util/epoll_util.cc
new file mode 100644
index 000000000..2e5051468
--- /dev/null
+++ b/test/util/epoll_util.cc
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/epoll_util.h"
+
+#include <sys/epoll.h>
+
+#include "gmock/gmock.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<FileDescriptor> NewEpollFD(int size) {
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater than
+ // zero." - epoll_create(2)
+ int fd = epoll_create(size);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "epoll_create");
+ }
+ return FileDescriptor(fd);
+}
+
+PosixError RegisterEpollFD(int epoll_fd, int target_fd, int events,
+ uint64_t data) {
+ struct epoll_event event;
+ event.events = events;
+ event.data.u64 = data;
+ int rc = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, target_fd, &event);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "epoll_ctl");
+ }
+ return NoError();
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/epoll_util.h b/test/util/epoll_util.h
new file mode 100644
index 000000000..f233b37d5
--- /dev/null
+++ b/test/util/epoll_util.h
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_EPOLL_UTIL_H_
+#define GVISOR_TEST_UTIL_EPOLL_UTIL_H_
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns a new epoll file descriptor.
+PosixErrorOr<FileDescriptor> NewEpollFD(int size = 1);
+
+// Registers `target_fd` with the epoll instance represented by `epoll_fd` for
+// the epoll events `events`. Events on `target_fd` will be indicated by setting
+// data.u64 to `data` in the returned epoll_event.
+PosixError RegisterEpollFD(int epoll_fd, int target_fd, int events,
+ uint64_t data);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_EPOLL_UTIL_H_
diff --git a/test/util/eventfd_util.h b/test/util/eventfd_util.h
new file mode 100644
index 000000000..cb9ce829c
--- /dev/null
+++ b/test/util/eventfd_util.h
@@ -0,0 +1,43 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_EVENTFD_UTIL_H_
+#define GVISOR_TEST_UTIL_EVENTFD_UTIL_H_
+
+#include <sys/eventfd.h>
+
+#include <cerrno>
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns a new eventfd with the given initial value and flags.
+inline PosixErrorOr<FileDescriptor> NewEventFD(unsigned int initval = 0,
+ int flags = 0) {
+ int fd = eventfd(initval, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "eventfd");
+ }
+ return FileDescriptor(fd);
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_EVENTFD_UTIL_H_
diff --git a/test/util/file_descriptor.h b/test/util/file_descriptor.h
new file mode 100644
index 000000000..fc5caa55b
--- /dev/null
+++ b/test/util/file_descriptor.h
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_FILE_DESCRIPTOR_H_
+#define GVISOR_TEST_UTIL_FILE_DESCRIPTOR_H_
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// FileDescriptor is an RAII type class which takes ownership of a file
+// descriptor. It will close the FD when this object goes out of scope.
+class FileDescriptor {
+ public:
+ // Constructs an empty FileDescriptor (one that does not own a file
+ // descriptor).
+ FileDescriptor() = default;
+
+ // Constructs a FileDescriptor that owns fd. If fd is negative, constructs an
+ // empty FileDescriptor.
+ explicit FileDescriptor(int fd) { set_fd(fd); }
+
+ FileDescriptor(FileDescriptor&& orig) : fd_(orig.release()) {}
+
+ FileDescriptor& operator=(FileDescriptor&& orig) {
+ reset(orig.release());
+ return *this;
+ }
+
+ PosixErrorOr<FileDescriptor> Dup() const {
+ if (fd_ < 0) {
+ return PosixError(EINVAL, "Attempting to Dup unset fd");
+ }
+
+ int fd = dup(fd_);
+ if (fd < 0) {
+ return PosixError(errno, absl::StrCat("dup ", fd_));
+ }
+ MaybeSave();
+ return FileDescriptor(fd);
+ }
+
+ FileDescriptor(FileDescriptor const& other) = delete;
+ FileDescriptor& operator=(FileDescriptor const& other) = delete;
+
+ ~FileDescriptor() { reset(); }
+
+ // If this object is non-empty, returns the owned file descriptor. (Ownership
+ // is retained by the FileDescriptor.) Otherwise returns -1.
+ int get() const { return fd_; }
+
+ // If this object is non-empty, transfers ownership of the file descriptor to
+ // the caller and returns it. Otherwise returns -1.
+ int release() {
+ int const fd = fd_;
+ fd_ = -1;
+ return fd;
+ }
+
+ // If this object is non-empty, closes the owned file descriptor (recording a
+ // test failure if the close fails).
+ void reset() { reset(-1); }
+
+ // Like no-arg reset(), but the FileDescriptor takes ownership of fd after
+ // closing its existing file descriptor.
+ void reset(int fd) {
+ if (fd_ >= 0) {
+ TEST_PCHECK(close(fd_) == 0);
+ MaybeSave();
+ }
+ set_fd(fd);
+ }
+
+ private:
+ // Wrapper that coerces negative fd values other than -1 to -1 so that get()
+ // etc. return -1.
+ void set_fd(int fd) { fd_ = std::max(fd, -1); }
+
+ int fd_ = -1;
+};
+
+// Wrapper around open(2) that returns a FileDescriptor.
+inline PosixErrorOr<FileDescriptor> Open(std::string const& path, int flags,
+ mode_t mode = 0) {
+ int fd = open(path.c_str(), flags, mode);
+ if (fd < 0) {
+ return PosixError(errno, absl::StrFormat("open(%s, %#x, %#o)", path.c_str(),
+ flags, mode));
+ }
+ MaybeSave();
+ return FileDescriptor(fd);
+}
+
+// Wrapper around openat(2) that returns a FileDescriptor.
+inline PosixErrorOr<FileDescriptor> OpenAt(int dirfd, std::string const& path,
+ int flags, mode_t mode = 0) {
+ int fd = openat(dirfd, path.c_str(), flags, mode);
+ if (fd < 0) {
+ return PosixError(errno, absl::StrFormat("openat(%d, %s, %#x, %#o)", dirfd,
+ path, flags, mode));
+ }
+ MaybeSave();
+ return FileDescriptor(fd);
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_FILE_DESCRIPTOR_H_
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
new file mode 100644
index 000000000..5418948fe
--- /dev/null
+++ b/test/util/fs_util.cc
@@ -0,0 +1,633 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/fs_util.h"
+
+#include <dirent.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gmock/gmock.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+PosixError WriteContentsToFD(int fd, absl::string_view contents) {
+ int written = 0;
+ while (static_cast<absl::string_view::size_type>(written) < contents.size()) {
+ int wrote = write(fd, contents.data() + written, contents.size() - written);
+ if (wrote < 0) {
+ if (errno == EINTR) {
+ continue;
+ }
+ return PosixError(
+ errno, absl::StrCat("WriteContentsToFD fd: ", fd, " write failure."));
+ }
+ written += wrote;
+ }
+ return NoError();
+}
+} // namespace
+
+namespace internal {
+
+// Given a collection of file paths, append them all together,
+// ensuring that the proper path separators are inserted between them.
+std::string JoinPathImpl(std::initializer_list<absl::string_view> paths) {
+ std::string result;
+
+ if (paths.size() != 0) {
+ // This size calculation is worst-case: it assumes one extra "/" for every
+ // path other than the first.
+ size_t total_size = paths.size() - 1;
+ for (const absl::string_view path : paths) total_size += path.size();
+ result.resize(total_size);
+
+ auto begin = result.begin();
+ auto out = begin;
+ bool trailing_slash = false;
+ for (absl::string_view path : paths) {
+ if (path.empty()) continue;
+ if (path.front() == '/') {
+ if (trailing_slash) {
+ path.remove_prefix(1);
+ }
+ } else {
+ if (!trailing_slash && out != begin) *out++ = '/';
+ }
+ const size_t this_size = path.size();
+ memcpy(&*out, path.data(), this_size);
+ out += this_size;
+ trailing_slash = out[-1] == '/';
+ }
+ result.erase(out - begin);
+ }
+ return result;
+}
+} // namespace internal
+
+// Returns the current working directory, or an error on failure.
+PosixErrorOr<std::string> GetCWD() {
+ char buffer[PATH_MAX + 1] = {};
+ if (getcwd(buffer, PATH_MAX) == nullptr) {
+ return PosixError(errno, "GetCWD() failed");
+ }
+
+ return std::string(buffer);
+}
+
+PosixErrorOr<struct stat> Stat(absl::string_view path) {
+ struct stat stat_buf;
+ int res = stat(std::string(path).c_str(), &stat_buf);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("stat ", path));
+ }
+ return stat_buf;
+}
+
+PosixErrorOr<struct stat> Lstat(absl::string_view path) {
+ struct stat stat_buf;
+ int res = lstat(std::string(path).c_str(), &stat_buf);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("lstat ", path));
+ }
+ return stat_buf;
+}
+
+PosixErrorOr<struct stat> Fstat(int fd) {
+ struct stat stat_buf;
+ int res = fstat(fd, &stat_buf);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("fstat ", fd));
+ }
+ return stat_buf;
+}
+
+PosixErrorOr<bool> Exists(absl::string_view path) {
+ struct stat stat_buf;
+ int res = lstat(std::string(path).c_str(), &stat_buf);
+ if (res < 0) {
+ if (errno == ENOENT) {
+ return false;
+ }
+ return PosixError(errno, absl::StrCat("lstat ", path));
+ }
+ return true;
+}
+
+PosixErrorOr<bool> IsDirectory(absl::string_view path) {
+ ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Lstat(path));
+ if (S_ISDIR(stat_buf.st_mode)) {
+ return true;
+ }
+
+ return false;
+}
+
+PosixError Delete(absl::string_view path) {
+ int res = unlink(std::string(path).c_str());
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("unlink ", path));
+ }
+
+ return NoError();
+}
+
+PosixError Truncate(absl::string_view path, int length) {
+ int res = truncate(std::string(path).c_str(), length);
+ if (res < 0) {
+ return PosixError(errno,
+ absl::StrCat("truncate ", path, " to length ", length));
+ }
+
+ return NoError();
+}
+
+PosixError Chmod(absl::string_view path, int mode) {
+ int res = chmod(std::string(path).c_str(), mode);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("chmod ", path));
+ }
+
+ return NoError();
+}
+
+PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode,
+ dev_t dev) {
+ int res = mknodat(dfd.get(), std::string(path).c_str(), mode, dev);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("mknod ", path));
+ }
+
+ return NoError();
+}
+
+PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path,
+ int flags) {
+ int res = unlinkat(dfd.get(), std::string(path).c_str(), flags);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("unlink ", path));
+ }
+
+ return NoError();
+}
+
+PosixError Mkdir(absl::string_view path, int mode) {
+ int res = mkdir(std::string(path).c_str(), mode);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("mkdir ", path, " mode ", mode));
+ }
+
+ return NoError();
+}
+
+PosixError Rmdir(absl::string_view path) {
+ int res = rmdir(std::string(path).c_str());
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("rmdir ", path));
+ }
+
+ return NoError();
+}
+
+PosixError SetContents(absl::string_view path, absl::string_view contents) {
+ ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path));
+ if (!exists) {
+ return PosixError(
+ ENOENT, absl::StrCat("SetContents file ", path, " doesn't exist."));
+ }
+
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Open(std::string(path), O_WRONLY | O_TRUNC));
+ return WriteContentsToFD(fd.get(), contents);
+}
+
+// Create a file with the given contents (if it does not already exist with the
+// given mode) and then set the contents.
+PosixError CreateWithContents(absl::string_view path,
+ absl::string_view contents, int mode) {
+ ASSIGN_OR_RETURN_ERRNO(
+ auto fd, Open(std::string(path), O_WRONLY | O_CREAT | O_TRUNC, mode));
+ return WriteContentsToFD(fd.get(), contents);
+}
+
+PosixError GetContents(absl::string_view path, std::string* output) {
+ ASSIGN_OR_RETURN_ERRNO(auto fd, Open(std::string(path), O_RDONLY));
+ output->clear();
+
+ // Keep reading until we hit an EOF or an error.
+ return GetContentsFD(fd.get(), output);
+}
+
+PosixErrorOr<std::string> GetContents(absl::string_view path) {
+ std::string ret;
+ RETURN_IF_ERRNO(GetContents(path, &ret));
+ return ret;
+}
+
+PosixErrorOr<std::string> GetContentsFD(int fd) {
+ std::string ret;
+ RETURN_IF_ERRNO(GetContentsFD(fd, &ret));
+ return ret;
+}
+
+PosixError GetContentsFD(int fd, std::string* output) {
+ // Keep reading until we hit an EOF or an error.
+ while (true) {
+ char buf[16 * 1024] = {}; // Read in 16KB chunks.
+ int bytes_read = read(fd, buf, sizeof(buf));
+ if (bytes_read < 0) {
+ if (errno == EINTR) {
+ continue;
+ }
+ return PosixError(errno, "GetContentsFD read failure.");
+ }
+
+ if (bytes_read == 0) {
+ break; // EOF.
+ }
+
+ output->append(buf, bytes_read);
+ }
+ return NoError();
+}
+
+PosixErrorOr<std::string> ReadLink(absl::string_view path) {
+ char buf[PATH_MAX + 1] = {};
+ int ret = readlink(std::string(path).c_str(), buf, PATH_MAX);
+ if (ret < 0) {
+ return PosixError(errno, absl::StrCat("readlink ", path));
+ }
+
+ return std::string(buf, ret);
+}
+
+PosixError WalkTree(
+ absl::string_view path, bool recursive,
+ const std::function<void(absl::string_view, const struct stat&)>& cb) {
+ DIR* dir = opendir(std::string(path).c_str());
+ if (dir == nullptr) {
+ return PosixError(errno, absl::StrCat("opendir ", path));
+ }
+ auto dir_closer = Cleanup([&dir]() { closedir(dir); });
+ while (true) {
+ // Readdir(3): If the end of the directory stream is reached, NULL is
+ // returned and errno is not changed. If an error occurs, NULL is returned
+ // and errno is set appropriately. To distinguish end of stream from an
+ // error, set errno to zero before calling readdir() and then check the
+ // value of errno if NULL is returned.
+ errno = 0;
+ struct dirent* dp = readdir(dir);
+ if (dp == nullptr) {
+ if (errno != 0) {
+ return PosixError(errno, absl::StrCat("readdir ", path));
+ }
+ break; // We're done.
+ }
+
+ if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) {
+ // Skip dots.
+ continue;
+ }
+
+ auto full_path = JoinPath(path, dp->d_name);
+ ASSIGN_OR_RETURN_ERRNO(struct stat s, Stat(full_path));
+ if (S_ISDIR(s.st_mode) && recursive) {
+ RETURN_IF_ERRNO(WalkTree(full_path, recursive, cb));
+ } else {
+ cb(full_path, s);
+ }
+ }
+ // We're done walking so let's invoke our cleanup callback now.
+ dir_closer.Release()();
+
+ // And we have to dispatch the callback on the base directory.
+ ASSIGN_OR_RETURN_ERRNO(struct stat s, Stat(path));
+ cb(path, s);
+
+ return NoError();
+}
+
+PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath,
+ bool skipdots) {
+ std::vector<std::string> files;
+
+ DIR* dir = opendir(std::string(abspath).c_str());
+ if (dir == nullptr) {
+ return PosixError(errno, absl::StrCat("opendir ", abspath));
+ }
+ auto dir_closer = Cleanup([&dir]() { closedir(dir); });
+ while (true) {
+ // Readdir(3): If the end of the directory stream is reached, NULL is
+ // returned and errno is not changed. If an error occurs, NULL is returned
+ // and errno is set appropriately. To distinguish end of stream from an
+ // error, set errno to zero before calling readdir() and then check the
+ // value of errno if NULL is returned.
+ errno = 0;
+ struct dirent* dp = readdir(dir);
+ if (dp == nullptr) {
+ if (errno != 0) {
+ return PosixError(errno, absl::StrCat("readdir ", abspath));
+ }
+ break; // We're done.
+ }
+
+ if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) {
+ if (skipdots) {
+ continue;
+ }
+ }
+ files.push_back(std::string(dp->d_name));
+ }
+
+ return files;
+}
+
+PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs,
+ int* undeleted_files) {
+ ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path));
+ if (!exists) {
+ return PosixError(ENOENT, absl::StrCat(path, " does not exist"));
+ }
+
+ ASSIGN_OR_RETURN_ERRNO(bool dir, IsDirectory(path));
+ if (!dir) {
+ // Nothing recursive needs to happen; we can just call Delete.
+ auto status = Delete(path);
+ if (!status.ok() && undeleted_files) {
+ (*undeleted_files)++;
+ }
+ return status;
+ }
+
+ return WalkTree(path, /*recursive=*/true,
+ [&](absl::string_view absolute_path, const struct stat& s) {
+ if (S_ISDIR(s.st_mode)) {
+ auto rm_status = Rmdir(absolute_path);
+ if (!rm_status.ok() && undeleted_dirs) {
+ (*undeleted_dirs)++;
+ }
+ } else {
+ auto delete_status = Delete(absolute_path);
+ if (!delete_status.ok() && undeleted_files) {
+ (*undeleted_files)++;
+ }
+ }
+ });
+}
+
+PosixError RecursivelyCreateDir(absl::string_view path) {
+ if (path.empty() || path == "/") {
+ return PosixError(EINVAL, "Cannot create root!");
+ }
+
+ // If it already exists, we're done.
+ ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path));
+ if (exists) {
+ return NoError();
+ }
+
+ // Create any missing parent directories first.
+ auto dirname = Dirname(path);
+ ASSIGN_OR_RETURN_ERRNO(exists, Exists(dirname));
+ if (!exists) {
+ RETURN_IF_ERRNO(RecursivelyCreateDir(dirname));
+ }
+
+ return Mkdir(path);
+}
+
+// Makes a path absolute with respect to an optional base. If no base is
+// provided it will use the current working directory.
+PosixErrorOr<std::string> MakeAbsolute(absl::string_view filename,
+ absl::string_view base) {
+ if (filename.empty()) {
+ return PosixError(EINVAL, "filename cannot be empty.");
+ }
+
+ if (filename[0] == '/') {
+ // This path is already absolute.
+ return std::string(filename);
+ }
+
+ std::string actual_base;
+ if (!base.empty()) {
+ actual_base = std::string(base);
+ } else {
+ auto cwd_or = GetCWD();
+ RETURN_IF_ERRNO(cwd_or.error());
+ actual_base = cwd_or.ValueOrDie();
+ }
+
+ // Iterate in reverse to remove trailing slashes, effectively right-trimming '/'.
+ for (int i = actual_base.size() - 1; i >= 0 && actual_base[i] == '/'; --i) {
+ actual_base.erase(i, 1);
+ }
+
+ if (filename == ".") {
+ return actual_base.empty() ? "/" : actual_base;
+ }
+
+ return absl::StrCat(actual_base, "/", filename);
+}
+
+std::string CleanPath(const absl::string_view unclean_path) {
+ std::string path = std::string(unclean_path);
+ const char* src = path.c_str();
+ std::string::iterator dst = path.begin();
+
+ // Check for absolute path and determine initial backtrack limit.
+ const bool is_absolute_path = *src == '/';
+ if (is_absolute_path) {
+ *dst++ = *src++;
+ while (*src == '/') ++src;
+ }
+ std::string::const_iterator backtrack_limit = dst;
+
+ // Process all parts.
+ while (*src) {
+ bool parsed = false;
+
+ if (src[0] == '.') {
+ // 1dot ".<whateverisnext>", check for END or SEP.
+ if (src[1] == '/' || !src[1]) {
+ if (*++src) {
+ ++src;
+ }
+ parsed = true;
+ } else if (src[1] == '.' && (src[2] == '/' || !src[2])) {
+ // 2dot END or SEP (".." | "../<whateverisnext>").
+ src += 2;
+ if (dst != backtrack_limit) {
+ // We can backtrack over the previous part.
+ for (--dst; dst != backtrack_limit && dst[-1] != '/'; --dst) {
+ // Empty.
+ }
+ } else if (!is_absolute_path) {
+ // Failed to backtrack and we can't skip it either. Rewind and copy.
+ src -= 2;
+ *dst++ = *src++;
+ *dst++ = *src++;
+ if (*src) {
+ *dst++ = *src;
+ }
+ // We can never backtrack over a copied "../" part so set new limit.
+ backtrack_limit = dst;
+ }
+ if (*src) {
+ ++src;
+ }
+ parsed = true;
+ }
+ }
+
+ // If not parsed, copy entire part until the next SEP or EOS.
+ if (!parsed) {
+ while (*src && *src != '/') {
+ *dst++ = *src++;
+ }
+ if (*src) {
+ *dst++ = *src++;
+ }
+ }
+
+ // Skip consecutive SEP occurrences.
+ while (*src == '/') {
+ ++src;
+ }
+ }
+
+ // Calculate and check the length of the cleaned path.
+ int path_length = dst - path.begin();
+ if (path_length != 0) {
+ // Remove trailing '/' except if it is root path ("/" ==> path_length := 1)
+ if (path_length > 1 && path[path_length - 1] == '/') {
+ --path_length;
+ }
+ path.resize(path_length);
+ } else {
+ // The cleaned path is empty; assign "." as per the spec.
+ path.assign(1, '.');
+ }
+ return path;
+}
+
+PosixErrorOr<std::string> GetRelativePath(absl::string_view source,
+ absl::string_view dest) {
+ if (!absl::StartsWith(source, "/") || !absl::StartsWith(dest, "/")) {
+ // At least one of the inputs is not an absolute path.
+ return PosixError(
+ EINVAL,
+ "GetRelativePath: At least one of the inputs is not an absolute path.");
+ }
+ const std::string clean_source = CleanPath(source);
+ const std::string clean_dest = CleanPath(dest);
+ auto source_parts = absl::StrSplit(clean_source, '/', absl::SkipEmpty());
+ auto dest_parts = absl::StrSplit(clean_dest, '/', absl::SkipEmpty());
+ auto source_iter = source_parts.begin();
+ auto dest_iter = dest_parts.begin();
+
+ // Advance past common prefix.
+ while (source_iter != source_parts.end() && dest_iter != dest_parts.end() &&
+ *source_iter == *dest_iter) {
+ ++source_iter;
+ ++dest_iter;
+ }
+
+ // Build result backtracking.
+ std::string result = "";
+ while (source_iter != source_parts.end()) {
+ absl::StrAppend(&result, "../");
+ ++source_iter;
+ }
+
+ // Add remaining path to dest.
+ while (dest_iter != dest_parts.end()) {
+ absl::StrAppend(&result, *dest_iter, "/");
+ ++dest_iter;
+ }
+
+ if (result.empty()) {
+ return std::string(".");
+ }
+
+ // Remove trailing slash.
+ result.erase(result.size() - 1);
+ return result;
+}
+
+absl::string_view Dirname(absl::string_view path) {
+ return SplitPath(path).first;
+}
+
+absl::string_view Basename(absl::string_view path) {
+ return SplitPath(path).second;
+}
+
+std::pair<absl::string_view, absl::string_view> SplitPath(
+ absl::string_view path) {
+ std::string::size_type pos = path.find_last_of('/');
+
+ // Handle the case with no '/' in 'path'.
+ if (pos == absl::string_view::npos) {
+ return std::make_pair(path.substr(0, 0), path);
+ }
+
+ // Handle the case with a single leading '/' in 'path'.
+ if (pos == 0) {
+ return std::make_pair(path.substr(0, 1), absl::ClippedSubstr(path, 1));
+ }
+
+ return std::make_pair(path.substr(0, pos),
+ absl::ClippedSubstr(path, pos + 1));
+}
+
+std::string JoinPath(absl::string_view path1, absl::string_view path2) {
+ if (path1.empty()) {
+ return std::string(path2);
+ }
+ if (path2.empty()) {
+ return std::string(path1);
+ }
+
+ if (path1.back() == '/') {
+ if (path2.front() == '/') {
+ return absl::StrCat(path1, absl::ClippedSubstr(path2, 1));
+ }
+ } else {
+ if (path2.front() != '/') {
+ return absl::StrCat(path1, "/", path2);
+ }
+ }
+ return absl::StrCat(path1, path2);
+}
+
+PosixErrorOr<std::string> ProcessExePath(int pid) {
+ if (pid <= 0) {
+ return PosixError(EINVAL, "Invalid pid specified");
+ }
+
+ return ReadLink(absl::StrCat("/proc/", pid, "/exe"));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
new file mode 100644
index 000000000..8cdac23a1
--- /dev/null
+++ b/test/util/fs_util.h
@@ -0,0 +1,210 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_FS_UTIL_H_
+#define GVISOR_TEST_UTIL_FS_UTIL_H_
+
+#include <dirent.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "absl/strings/string_view.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
+// because "it isn't needed", even though Linux can return it via F_GETFL.
+#if defined(__x86_64__)
+constexpr int kOLargeFile = 00100000;
+#elif defined(__aarch64__)
+constexpr int kOLargeFile = 00400000;
+#else
+#error "Unknown architecture"
+#endif
+
+// Returns a status or the current working directory.
+PosixErrorOr<std::string> GetCWD();
+
+// Returns true/false depending on whether or not path exists, or an error if it
+// can't be determined.
+PosixErrorOr<bool> Exists(absl::string_view path);
+
+// Returns a stat structure for the given path or an error. If the path
+// represents a symlink, it will be traversed.
+PosixErrorOr<struct stat> Stat(absl::string_view path);
+
+// Returns a stat structure for the given path or an error. If the path
+// represents a symlink, it will not be traversed.
+PosixErrorOr<struct stat> Lstat(absl::string_view path);
+
+// Returns a stat struct for the given fd.
+PosixErrorOr<struct stat> Fstat(int fd);
+
+// Deletes the file or directory at path or returns an error.
+PosixError Delete(absl::string_view path);
+
+// Changes the mode of a file or returns an error.
+PosixError Chmod(absl::string_view path, int mode);
+
+// Create a special or ordinary file.
+PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode,
+ dev_t dev);
+
+// Unlink the file.
+PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path,
+ int flags);
+
+// Truncates a file to the given length or returns an error.
+PosixError Truncate(absl::string_view path, int length);
+
+// Returns true/false depending on whether or not the path is a directory or
+// returns an error.
+PosixErrorOr<bool> IsDirectory(absl::string_view path);
+
+// Makes a directory or returns an error.
+PosixError Mkdir(absl::string_view path, int mode = 0755);
+
+// Removes a directory or returns an error.
+PosixError Rmdir(absl::string_view path);
+
+// Attempts to set the contents of a file or returns an error.
+PosixError SetContents(absl::string_view path, absl::string_view contents);
+
+// Creates a file with the given contents and mode or returns an error.
+PosixError CreateWithContents(absl::string_view path,
+ absl::string_view contents, int mode = 0666);
+
+// Attempts to read the entire contents of the file into the provided string
+// buffer or returns an error.
+PosixError GetContents(absl::string_view path, std::string* output);
+
+// Attempts to read the entire contents of the file or returns an error.
+PosixErrorOr<std::string> GetContents(absl::string_view path);
+
+// Attempts to read the entire contents of the provided fd into the provided
+// string or returns an error.
+PosixError GetContentsFD(int fd, std::string* output);
+
+// Attempts to read the entire contents of the provided fd or returns an error.
+PosixErrorOr<std::string> GetContentsFD(int fd);
+
+// Executes the readlink(2) system call or returns an error.
+PosixErrorOr<std::string> ReadLink(absl::string_view path);
+
+// WalkTree walks a directory tree in a depth-first manner (if recursive). It
+// invokes the provided callback for each file and directory; a parent is
+// always visited after its children, which makes this appropriate for things
+// such as deleting an entire directory tree.
+//
+// This method will return an error when it's unable to access the provided
+// path, or when the path is not a directory.
+PosixError WalkTree(
+ absl::string_view path, bool recursive,
+ const std::function<void(absl::string_view, const struct stat&)>& cb);
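+//
+// Usage sketch (the path and callback body are illustrative):
+//
+//   EXPECT_NO_ERRNO(WalkTree("/tmp/mydir", /*recursive=*/true,
+//                            [](absl::string_view path, const struct stat& s) {
+//                              // Children are visited before their parent.
+//                            }));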
+
+// Returns the base filenames for all files under a given absolute path. If
+// skipdots is true, the returned vector will not contain "." or "..". This
+// method does not walk the tree recursively; it only returns the entries
+// in that directory.
+PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath,
+ bool skipdots);
+
+// Attempts to recursively delete a directory or file. Returns an error and
+// reports the number of undeleted directories and files via the out
+// parameters. If either undeleted_dirs or undeleted_files is nullptr, that
+// counter is not updated.
+PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs,
+ int* undeleted_files);
+
+// Recursively create the directory provided or return an error.
+PosixError RecursivelyCreateDir(absl::string_view path);
+
+// Makes a path absolute with respect to an optional base. If no base is
+// provided, it will use the current working directory.
+PosixErrorOr<std::string> MakeAbsolute(absl::string_view filename,
+ absl::string_view base);
+
+// Generates a relative path from the source directory to the destination
+// (dest) file or directory. This uses ../ when necessary for destinations
+// which are not nested within the source. Both source and dest are required
+// to be absolute paths; an EINVAL error is returned if they are not.
+PosixErrorOr<std::string> GetRelativePath(absl::string_view source,
+ absl::string_view dest);
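+//
+// For example, GetRelativePath("/a/b", "/a/c/d") returns "../c/d", and
+// GetRelativePath("/a/b", "/a/b") returns ".".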
+
+// Returns the part of the path before the final "/", EXCEPT:
+// * If there is a single leading "/" in the path, the result will be the
+// leading "/".
+// * If there is no "/" in the path, the result is the empty prefix of the
+// input string.
+absl::string_view Dirname(absl::string_view path);
+
+// Return the parts of the path, split on the final "/". If there is no
+// "/" in the path, the first part of the output is empty and the second
+// is the input. If the only "/" in the path is the first character, it is
+// the first part of the output.
+std::pair<absl::string_view, absl::string_view> SplitPath(
+ absl::string_view path);
+
+// Returns the part of the path after the final "/". If there is no
+// "/" in the path, the result is the same as the input.
+// Note that this function's behavior differs from the Unix basename
+// command if path ends with "/". For such paths, this function returns the
+// empty string.
+absl::string_view Basename(absl::string_view path);
+
+// Collapse duplicate "/"s, resolve ".." and "." path elements, remove
+// trailing "/".
+//
+// NOTE: This respects relative vs. absolute paths, but does not
+// invoke any system calls (getcwd(2)) in order to resolve relative
+// paths wrt actual working directory. That is, this is purely a
+// string manipulation, completely independent of process state.
+std::string CleanPath(absl::string_view path);
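+//
+// For example, CleanPath("/a//b/./c/../d") returns "/a/b/d", and
+// CleanPath("a/..") returns ".".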
+
+// Returns the full path to the executable of the given pid or a PosixError.
+PosixErrorOr<std::string> ProcessExePath(int pid);
+
+namespace internal {
+// Not part of the public API.
+std::string JoinPathImpl(std::initializer_list<absl::string_view> paths);
+} // namespace internal
+
+// Join multiple paths together.
+// All paths will be treated as relative paths, regardless of whether or not
+// they start with a leading '/'. That is, all paths will be concatenated
+// together, with the appropriate path separator inserted in between.
+// Arguments must be convertible to absl::string_view.
+//
+// Usage:
+// std::string path = JoinPath("/foo", dirname, filename);
+// std::string path = JoinPath(FLAGS_test_srcdir, filename);
+//
+// 0, 1, 2-path specializations exist to optimize common cases.
+inline std::string JoinPath() { return std::string(); }
+inline std::string JoinPath(absl::string_view path) {
+ return std::string(path.data(), path.size());
+}
+
+std::string JoinPath(absl::string_view path1, absl::string_view path2);
+template <typename... T>
+inline std::string JoinPath(absl::string_view path1, absl::string_view path2,
+ absl::string_view path3, const T&... args) {
+ return internal::JoinPathImpl({path1, path2, path3, args...});
+}
+} // namespace testing
+} // namespace gvisor
+#endif // GVISOR_TEST_UTIL_FS_UTIL_H_
diff --git a/test/util/fs_util_test.cc b/test/util/fs_util_test.cc
new file mode 100644
index 000000000..657b6a46e
--- /dev/null
+++ b/test/util/fs_util_test.cc
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/fs_util.h"
+
+#include <errno.h>
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(FsUtilTest, RecursivelyCreateDirManualDelete) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string base_path =
+ JoinPath(root.path(), "/a/b/c/d/e/f/g/h/i/j/k/l/m");
+
+ ASSERT_THAT(Exists(base_path), IsPosixErrorOkAndHolds(false));
+ ASSERT_NO_ERRNO(RecursivelyCreateDir(base_path));
+
+ // Delete everything until we hit root and then stop; we want to try this
+ // without using RecursivelyDelete.
+ std::string cur_path = base_path;
+ while (cur_path != root.path()) {
+ ASSERT_THAT(Exists(cur_path), IsPosixErrorOkAndHolds(true));
+ ASSERT_NO_ERRNO(Rmdir(cur_path));
+ ASSERT_THAT(Exists(cur_path), IsPosixErrorOkAndHolds(false));
+ auto dir = Dirname(cur_path);
+ cur_path = std::string(dir);
+ }
+}
+
+TEST(FsUtilTest, RecursivelyCreateAndDeleteDir) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string base_path =
+ JoinPath(root.path(), "/a/b/c/d/e/f/g/h/i/j/k/l/m");
+
+ ASSERT_THAT(Exists(base_path), IsPosixErrorOkAndHolds(false));
+ ASSERT_NO_ERRNO(RecursivelyCreateDir(base_path));
+
+ const std::string sub_path = JoinPath(root.path(), "a");
+ ASSERT_NO_ERRNO(RecursivelyDelete(sub_path, nullptr, nullptr));
+ ASSERT_THAT(Exists(sub_path), IsPosixErrorOkAndHolds(false));
+}
+
+TEST(FsUtilTest, RecursivelyCreateAndDeletePartial) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string base_path =
+ JoinPath(root.path(), "/a/b/c/d/e/f/g/h/i/j/k/l/m");
+
+ ASSERT_THAT(Exists(base_path), IsPosixErrorOkAndHolds(false));
+ ASSERT_NO_ERRNO(RecursivelyCreateDir(base_path));
+
+ const std::string a = JoinPath(root.path(), "a");
+ auto listing = ASSERT_NO_ERRNO_AND_VALUE(ListDir(a, true));
+ ASSERT_THAT(listing, ::testing::Contains("b"));
+ ASSERT_EQ(listing.size(), 1);
+
+ listing = ASSERT_NO_ERRNO_AND_VALUE(ListDir(a, false));
+ ASSERT_THAT(listing, ::testing::Contains("."));
+ ASSERT_THAT(listing, ::testing::Contains(".."));
+ ASSERT_THAT(listing, ::testing::Contains("b"));
+ ASSERT_EQ(listing.size(), 3);
+
+ const std::string sub_path = JoinPath(root.path(), "/a/b/c/d/e/f");
+
+ ASSERT_NO_ERRNO(
+ CreateWithContents(JoinPath(Dirname(sub_path), "file"), "Hello World"));
+ std::string contents = "";
+ ASSERT_NO_ERRNO(GetContents(JoinPath(Dirname(sub_path), "file"), &contents));
+ ASSERT_EQ(contents, "Hello World");
+
+ ASSERT_NO_ERRNO(RecursivelyDelete(sub_path, nullptr, nullptr));
+ ASSERT_THAT(Exists(sub_path), IsPosixErrorOkAndHolds(false));
+
+ // The parent of the subpath (directory e) should still exist.
+ ASSERT_THAT(Exists(Dirname(sub_path)), IsPosixErrorOkAndHolds(true));
+
+ // The file we created alongside f should also still exist.
+ ASSERT_THAT(Exists(JoinPath(Dirname(sub_path), "file")),
+ IsPosixErrorOkAndHolds(true));
+}
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/logging.cc b/test/util/logging.cc
new file mode 100644
index 000000000..5d5e76c46
--- /dev/null
+++ b/test/util/logging.cc
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/logging.h"
+
+#include <errno.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// We implement this here instead of using test_util to avoid cyclic
+// dependencies.
+int Write(int fd, const char* buf, size_t size) {
+ size_t written = 0;
+ while (written < size) {
+ int res = write(fd, buf + written, size - written);
+ if (res < 0 && errno == EINTR) {
+ continue;
+ } else if (res <= 0) {
+ break;
+ }
+
+ written += res;
+ }
+ return static_cast<int>(written);
+}
+
+// Write 32-bit decimal number to fd.
+int WriteNumber(int fd, uint32_t val) {
+ constexpr char kDigits[] = "0123456789";
+ constexpr int kBase = 10;
+
+ // 10 chars for 32-bit number in decimal, 1 char for the NUL-terminator.
+ constexpr int kBufferSize = 11;
+ char buf[kBufferSize];
+
+ // Convert the number to string.
+ char* s = buf + sizeof(buf) - 1;
+ size_t size = 0;
+
+ *s = '\0';
+ do {
+ s--;
+ size++;
+
+ *s = kDigits[val % kBase];
+ val /= kBase;
+ } while (val);
+
+ return Write(fd, s, size);
+}
+
+} // namespace
+
+void CheckFailure(const char* cond, size_t cond_size, const char* msg,
+ size_t msg_size, bool include_errno) {
+ int saved_errno = errno;
+
+ constexpr char kCheckFailure[] = "Check failed: ";
+ Write(2, kCheckFailure, sizeof(kCheckFailure) - 1);
+ Write(2, cond, cond_size);
+
+ if (msg != nullptr) {
+ Write(2, ": ", 2);
+ Write(2, msg, msg_size);
+ }
+
+ if (include_errno) {
+ constexpr char kErrnoMessage[] = " (errno ";
+ Write(2, kErrnoMessage, sizeof(kErrnoMessage) - 1);
+ WriteNumber(2, saved_errno);
+ Write(2, ")", 1);
+ }
+
+ Write(2, "\n", 1);
+
+ abort();
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/logging.h b/test/util/logging.h
new file mode 100644
index 000000000..589166fab
--- /dev/null
+++ b/test/util/logging.h
@@ -0,0 +1,73 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_LOGGING_H_
+#define GVISOR_TEST_UTIL_LOGGING_H_
+
+#include <stddef.h>
+
+namespace gvisor {
+namespace testing {
+
+void CheckFailure(const char* cond, size_t cond_size, const char* msg,
+ size_t msg_size, bool include_errno);
+
+// If cond is false, aborts the current process.
+//
+// This macro is async-signal-safe.
+#define TEST_CHECK(cond) \
+ do { \
+ if (!(cond)) { \
+ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \
+ 0, false); \
+ } \
+ } while (0)
+
+// If cond is false, logs msg then aborts the current process.
+//
+// This macro is async-signal-safe.
+#define TEST_CHECK_MSG(cond, msg) \
+ do { \
+ if (!(cond)) { \
+ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \
+ sizeof(msg) - 1, false); \
+ } \
+ } while (0)
+
+// If cond is false, logs errno, then aborts the current process.
+//
+// This macro is async-signal-safe.
+#define TEST_PCHECK(cond) \
+ do { \
+ if (!(cond)) { \
+ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \
+ 0, true); \
+ } \
+ } while (0)
+
+// If cond is false, logs msg and errno, then aborts the current process.
+//
+// This macro is async-signal-safe.
+#define TEST_PCHECK_MSG(cond, msg) \
+ do { \
+ if (!(cond)) { \
+ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \
+ sizeof(msg) - 1, true); \
+ } \
+ } while (0)
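+
+// Usage sketch: these macros are intended for places where gTest assertions
+// are unsafe, e.g. in a forked child (fd here is illustrative):
+//
+//   TEST_PCHECK_MSG(close(fd) == 0, "close failed");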
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_LOGGING_H_
diff --git a/test/util/memory_util.h b/test/util/memory_util.h
new file mode 100644
index 000000000..e189b73e8
--- /dev/null
+++ b/test/util/memory_util.h
@@ -0,0 +1,147 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_MEMORY_UTIL_H_
+#define GVISOR_TEST_UTIL_MEMORY_UTIL_H_
+
+#include <errno.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <sys/mman.h>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// RAII type for mmap'ed memory. Only usable in tests due to use of a test-only
+// macro that can't be named without invoking the presubmit's wrath.
+class Mapping {
+ public:
+ // Constructs a mapping that owns nothing.
+ Mapping() = default;
+
+ // Constructs a mapping that owns the mmapped memory [ptr, ptr+len). Most
+ // users should use Mmap or MmapAnon instead.
+ Mapping(void* ptr, size_t len) : ptr_(ptr), len_(len) {}
+
+ Mapping(Mapping&& orig) : ptr_(orig.ptr_), len_(orig.len_) { orig.release(); }
+
+ Mapping& operator=(Mapping&& orig) {
+ ptr_ = orig.ptr_;
+ len_ = orig.len_;
+ orig.release();
+ return *this;
+ }
+
+ Mapping(Mapping const&) = delete;
+ Mapping& operator=(Mapping const&) = delete;
+
+ ~Mapping() { reset(); }
+
+ void* ptr() const { return ptr_; }
+ size_t len() const { return len_; }
+
+ // Returns a pointer to the end of the mapping. Useful for when the mapping
+ // is used as a thread stack.
+ void* endptr() const { return reinterpret_cast<void*>(addr() + len_); }
+
+ // Returns the start of this mapping cast to uintptr_t for ease of pointer
+ // arithmetic.
+ uintptr_t addr() const { return reinterpret_cast<uintptr_t>(ptr_); }
+
+ // Returns the end of this mapping cast to uintptr_t for ease of pointer
+ // arithmetic.
+ uintptr_t endaddr() const { return reinterpret_cast<uintptr_t>(endptr()); }
+
+ // Returns this mapping as an absl::string_view for ease of comparison.
+ //
+ // This function is named view in anticipation of the eventual replacement of
+ // StringPiece with std::string_view.
+ absl::string_view view() const {
+ return absl::string_view(static_cast<char const*>(ptr_), len_);
+ }
+
+ // These are both named reset for consistency with standard smart pointers.
+
+ void reset(void* ptr, size_t len) {
+ if (len_) {
+ TEST_PCHECK(munmap(ptr_, len_) == 0);
+ }
+ ptr_ = ptr;
+ len_ = len;
+ }
+
+ void reset() { reset(nullptr, 0); }
+
+ void release() {
+ ptr_ = nullptr;
+ len_ = 0;
+ }
+
+ private:
+ void* ptr_ = nullptr;
+ size_t len_ = 0;
+};
+
+// Wrapper around mmap(2) that returns a Mapping.
+inline PosixErrorOr<Mapping> Mmap(void* addr, size_t length, int prot,
+ int flags, int fd, off_t offset) {
+ void* ptr = mmap(addr, length, prot, flags, fd, offset);
+ if (ptr == MAP_FAILED) {
+ return PosixError(
+ errno, absl::StrFormat("mmap(%p, %d, %x, %x, %d, %d)", addr, length,
+ prot, flags, fd, offset));
+ }
+ MaybeSave();
+ return Mapping(ptr, length);
+}
+
+// Convenience wrapper around Mmap for anonymous mappings.
+inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) {
+ return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0);
+}
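+
+// Usage sketch (kPageSize comes from test_util.h, included above):
+//
+//   Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+//       MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+//   memset(m.ptr(), 0, m.len());  // Unmapped when m goes out of scope.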
+
+// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
+// void* isn't directly compatible with SyscallSucceeds.
+inline PosixErrorOr<void*> Mremap(void* old_address, size_t old_size,
+ size_t new_size, int flags,
+ void* new_address) {
+ void* rv = mremap(old_address, old_size, new_size, flags, new_address);
+ if (rv == MAP_FAILED) {
+ return PosixError(errno, "mremap failed");
+ }
+ return rv;
+}
+
+// Returns true if the page containing addr is mapped.
+inline bool IsMapped(uintptr_t addr) {
+ int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
+ kPageSize, MS_ASYNC);
+ if (rv == 0) {
+ return true;
+ }
+ TEST_PCHECK_MSG(errno == ENOMEM, "msync failed with unexpected errno");
+ return false;
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_MEMORY_UTIL_H_
diff --git a/test/util/mount_util.h b/test/util/mount_util.h
new file mode 100644
index 000000000..09e2281eb
--- /dev/null
+++ b/test/util/mount_util.h
@@ -0,0 +1,51 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_MOUNT_UTIL_H_
+#define GVISOR_TEST_UTIL_MOUNT_UTIL_H_
+
+#include <errno.h>
+#include <sys/mount.h>
+
+#include <functional>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Mount mounts the filesystem, and unmounts when the returned reference is
+// destroyed.
+inline PosixErrorOr<Cleanup> Mount(const std::string& source,
+ const std::string& target,
+ const std::string& fstype,
+ uint64_t mountflags, const std::string& data,
+ uint64_t umountflags) {
+ if (mount(source.c_str(), target.c_str(), fstype.c_str(), mountflags,
+ data.c_str()) == -1) {
+ return PosixError(errno, "mount failed");
+ }
+ return Cleanup([target, umountflags]() {
+ EXPECT_THAT(umount2(target.c_str(), umountflags), SyscallSucceeds());
+ });
+}
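+
+// Usage sketch (the target path is illustrative; mounting generally requires
+// privileges):
+//
+//   auto mnt = ASSERT_NO_ERRNO_AND_VALUE(
+//       Mount("", "/tmp/mnt", "tmpfs", 0, "mode=0700", 0));
+//   // umount2("/tmp/mnt", 0) runs when mnt goes out of scope.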
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_MOUNT_UTIL_H_
diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc
new file mode 100644
index 000000000..8b676751b
--- /dev/null
+++ b/test/util/multiprocess_util.cc
@@ -0,0 +1,173 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/multiprocess_util.h"
+
+#include <asm/unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <signal.h>
+#include <sys/prctl.h>
+#include <unistd.h>
+
+#include "absl/strings/str_cat.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// exec_fn wraps a variant of the exec family, e.g. execve or execveat.
+PosixErrorOr<Cleanup> ForkAndExecHelper(const std::function<void()>& exec_fn,
+ const std::function<void()>& fn,
+ pid_t* child, int* execve_errno) {
+ int pfds[2];
+ int ret = pipe2(pfds, O_CLOEXEC);
+ if (ret < 0) {
+ return PosixError(errno, "pipe failed");
+ }
+ FileDescriptor rfd(pfds[0]);
+ FileDescriptor wfd(pfds[1]);
+
+ int parent_stdout = dup(STDOUT_FILENO);
+ if (parent_stdout < 0) {
+ return PosixError(errno, "dup stdout");
+ }
+ int parent_stderr = dup(STDERR_FILENO);
+ if (parent_stderr < 0) {
+ return PosixError(errno, "dup stderr");
+ }
+
+ pid_t pid = fork();
+ if (pid < 0) {
+ return PosixError(errno, "fork failed");
+ } else if (pid == 0) {
+ // Child.
+ rfd.reset();
+ if (dup2(parent_stdout, STDOUT_FILENO) < 0) {
+ _exit(3);
+ }
+ if (dup2(parent_stderr, STDERR_FILENO) < 0) {
+ _exit(4);
+ }
+ close(parent_stdout);
+ close(parent_stderr);
+
+ // Clean ourselves up in case the parent doesn't.
+ if (prctl(PR_SET_PDEATHSIG, SIGKILL)) {
+ _exit(3);
+ }
+
+ if (fn) {
+ fn();
+ }
+
+ // Call variant of exec function.
+ exec_fn();
+
+ int error = errno;
+ if (WriteFd(pfds[1], &error, sizeof(error)) != sizeof(error)) {
+ // We can't do much if the write fails, but we can at least exit with a
+ // different code.
+ _exit(2);
+ }
+ _exit(1);
+ }
+
+ // Parent.
+ if (child) {
+ *child = pid;
+ }
+
+ auto cleanup = Cleanup([pid] {
+ kill(pid, SIGKILL);
+ RetryEINTR(waitpid)(pid, nullptr, 0);
+ });
+
+ wfd.reset();
+
+ int read_errno;
+ ret = ReadFd(rfd.get(), &read_errno, sizeof(read_errno));
+ if (ret == 0) {
+ // Other end of the pipe closed; execve must have succeeded.
+ read_errno = 0;
+ } else if (ret < 0) {
+ return PosixError(errno, "read pipe failed");
+ } else if (ret != sizeof(read_errno)) {
+ return PosixError(EPIPE, absl::StrCat("pipe read wrong size ", ret));
+ }
+
+ if (execve_errno) {
+ *execve_errno = read_errno;
+ }
+
+ return std::move(cleanup);
+}
+
+} // namespace
+
+PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
+ const ExecveArray& argv,
+ const ExecveArray& envv,
+ const std::function<void()>& fn, pid_t* child,
+ int* execve_errno) {
+ char* const* argv_data = argv.get();
+ char* const* envv_data = envv.get();
+ const std::function<void()> exec_fn = [=] {
+ execve(filename.c_str(), argv_data, envv_data);
+ };
+ return ForkAndExecHelper(exec_fn, fn, child, execve_errno);
+}
+
+PosixErrorOr<Cleanup> ForkAndExecveat(const int32_t dirfd,
+ const std::string& pathname,
+ const ExecveArray& argv,
+ const ExecveArray& envv, const int flags,
+ const std::function<void()>& fn,
+ pid_t* child, int* execve_errno) {
+ char* const* argv_data = argv.get();
+ char* const* envv_data = envv.get();
+ const std::function<void()> exec_fn = [=] {
+ syscall(__NR_execveat, dirfd, pathname.c_str(), argv_data, envv_data,
+ flags);
+ };
+ return ForkAndExecHelper(exec_fn, fn, child, execve_errno);
+}
+
+PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) {
+ pid_t pid = fork();
+ if (pid == 0) {
+ fn();
+ _exit(0);
+ }
+ MaybeSave();
+ if (pid < 0) {
+ return PosixError(errno, "fork failed");
+ }
+
+ int status;
+ if (waitpid(pid, &status, 0) < 0) {
+ return PosixError(errno, "waitpid failed");
+ }
+
+ return status;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h
new file mode 100644
index 000000000..2f3bf4a6f
--- /dev/null
+++ b/test/util/multiprocess_util.h
@@ -0,0 +1,132 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_MULTIPROCESS_UTIL_H_
+#define GVISOR_TEST_UTIL_MULTIPROCESS_UTIL_H_
+
+#include <unistd.h>
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Immutable holder for a dynamically-sized array of pointers to mutable char,
+// terminated by a null pointer, as required for the argv and envp arguments to
+// execve(2).
+class ExecveArray {
+ public:
+ // Constructs an empty ExecveArray.
+ ExecveArray() = default;
+
+ // Constructs an ExecveArray by copying strings from the given range. T must
+ // be a range over ranges of char.
+ template <typename T>
+ explicit ExecveArray(T const& strs) : ExecveArray(strs.begin(), strs.end()) {}
+
+ // Constructs an ExecveArray by copying strings from [first, last). InputIt
+ // must be an input iterator over a range over char.
+ template <typename InputIt>
+ ExecveArray(InputIt first, InputIt last) {
+ std::vector<size_t> offsets;
+ auto output_it = std::back_inserter(str_);
+ for (InputIt it = first; it != last; ++it) {
+ offsets.push_back(str_.size());
+ auto const& s = *it;
+ std::copy(s.begin(), s.end(), output_it);
+ str_.push_back('\0');
+ }
+ ptrs_.reserve(offsets.size() + 1);
+ for (auto offset : offsets) {
+ ptrs_.push_back(str_.data() + offset);
+ }
+ ptrs_.push_back(nullptr);
+ }
+
+ // Constructs an ExecveArray by copying strings from list. This overload must
+ // exist independently of the single-argument template constructor because
+ // std::initializer_list does not participate in template argument deduction
+ // (i.e. cannot be type-inferred in an invocation of the templated
+ // constructor).
+ /* implicit */ ExecveArray(std::initializer_list<absl::string_view> list)
+ : ExecveArray(list.begin(), list.end()) {}
+
+ // Disable move construction and assignment since ptrs_ points into str_.
+ ExecveArray(ExecveArray&&) = delete;
+ ExecveArray& operator=(ExecveArray&&) = delete;
+
+ char* const* get() const { return ptrs_.data(); }
+ size_t get_size() { return str_.size(); }
+
+ private:
+ std::vector<char> str_;
+ std::vector<char*> ptrs_;
+};
+
+// Simplified version of SubProcess. Returns OK and a cleanup function to kill
+// the child if it made it to execve.
+//
+// fn is run between fork and exec. If it needs to fail, it should exit the
+// process.
+//
+// The child pid is returned via child, if provided.
+// execve's error code is returned via execve_errno, if provided.
+PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
+ const ExecveArray& argv,
+ const ExecveArray& envv,
+ const std::function<void()>& fn, pid_t* child,
+ int* execve_errno);
+
+inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
+ const ExecveArray& argv,
+ const ExecveArray& envv, pid_t* child,
+ int* execve_errno) {
+ return ForkAndExec(
+ filename, argv, envv, [] {}, child, execve_errno);
+}
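+
+// Usage sketch ("/bin/true" is illustrative):
+//
+//   pid_t child;
+//   int execve_errno;
+//   auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+//       ForkAndExec("/bin/true", {"true"}, {}, &child, &execve_errno));
+//   ASSERT_EQ(execve_errno, 0);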
+
+// Equivalent to ForkAndExec, except using dirfd and flags with execveat.
+PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd,
+ const std::string& pathname,
+ const ExecveArray& argv,
+ const ExecveArray& envv, int flags,
+ const std::function<void()>& fn,
+ pid_t* child, int* execve_errno);
+
+inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd,
+ const std::string& pathname,
+ const ExecveArray& argv,
+ const ExecveArray& envv, int flags,
+ pid_t* child, int* execve_errno) {
+ return ForkAndExecveat(
+ dirfd, pathname, argv, envv, flags, [] {}, child, execve_errno);
+}
+
+// Calls fn in a forked subprocess and returns the exit status of the
+// subprocess.
+//
+// fn must be async-signal-safe.
+PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn);
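+//
+// Usage sketch:
+//
+//   int status = ASSERT_NO_ERRNO_AND_VALUE(
+//       InForkedProcess([] { TEST_CHECK(getpid() > 0); }));
+//   EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);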
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_MULTIPROCESS_UTIL_H_
diff --git a/test/util/platform_util.cc b/test/util/platform_util.cc
new file mode 100644
index 000000000..c9200d381
--- /dev/null
+++ b/test/util/platform_util.cc
@@ -0,0 +1,48 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/platform_util.h"
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PlatformSupport PlatformSupport32Bit() {
+ if (GvisorPlatform() == Platform::kPtrace ||
+ GvisorPlatform() == Platform::kKVM) {
+ return PlatformSupport::NotSupported;
+ } else {
+ return PlatformSupport::Allowed;
+ }
+}
+
+PlatformSupport PlatformSupportAlignmentCheck() {
+ return PlatformSupport::Allowed;
+}
+
+PlatformSupport PlatformSupportMultiProcess() {
+ return PlatformSupport::Allowed;
+}
+
+PlatformSupport PlatformSupportInt3() {
+ if (GvisorPlatform() == Platform::kKVM) {
+ return PlatformSupport::NotSupported;
+ } else {
+ return PlatformSupport::Allowed;
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/platform_util.h b/test/util/platform_util.h
new file mode 100644
index 000000000..28cc92371
--- /dev/null
+++ b/test/util/platform_util.h
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_PLATFORM_UTIL_H_
+#define GVISOR_TEST_UTIL_PLATFORM_UTIL_H_
+
+namespace gvisor {
+namespace testing {
+
+// PlatformSupport is a generic enumeration of classes of support.
+//
+// It is up to the individual functions and callers to agree on the precise
+// definition for each case. The documentation here generally uses 32-bit
+// support as an example. Many cases will use only NotSupported and Allowed.
+enum class PlatformSupport {
+ // The feature is not supported on the current platform.
+ //
+ // In the case of 32-bit, this means that calls will generally be interpreted
+ // as 64-bit calls, and there is no support for 32-bit binaries, long calls,
+ // etc. This usually means that the underlying implementation just pretends
+ // that 32-bit doesn't exist.
+ NotSupported,
+
+ // Calls will be ignored by the kernel with a fixed error.
+ Ignored,
+
+ // Calls will result in a SIGSEGV or similar fault.
+ Segfault,
+
+ // The feature is supported as expected.
+ //
+ // In the case of 32-bit, this means that the system call or far call will be
+ // handled properly.
+ Allowed,
+};
+
+PlatformSupport PlatformSupport32Bit();
+PlatformSupport PlatformSupportAlignmentCheck();
+PlatformSupport PlatformSupportMultiProcess();
+PlatformSupport PlatformSupportInt3();
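+
+// Usage sketch (SKIP_IF is assumed to come from test/util/test_util.h):
+//
+//   SKIP_IF(PlatformSupport32Bit() == PlatformSupport::NotSupported);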
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_PLATFORM_UTIL_H_
diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc
new file mode 100644
index 000000000..cebf7e0ac
--- /dev/null
+++ b/test/util/posix_error.cc
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+
+#include <cassert>
+#include <cerrno>
+#include <cstring>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+
+namespace gvisor {
+namespace testing {
+
+std::string PosixError::ToString() const {
+ if (ok()) {
+ return "No Error";
+ }
+
+ std::string ret;
+
+ char strerrno_buf[1024] = {};
+
+ auto res = strerror_r(errno_, strerrno_buf, sizeof(strerrno_buf));
+
+// The GNU version of strerror_r always returns a non-null char* pointing to a
+// buffer containing the stringified errno; the XSI version returns a positive
+// errno which indicates the result of writing the stringified errno into the
+// supplied buffer. The gymnastics below are needed to support both.
+#ifndef _GNU_SOURCE
+ if (res != 0) {
+ ret = absl::StrCat("PosixError(errno=", errno_, " strerror_r FAILED(", res,
+ "))");
+ } else {
+ ret = absl::StrCat("PosixError(errno=", errno_, " ", strerrno_buf, ")");
+ }
+#else
+ ret = absl::StrCat("PosixError(errno=", errno_, " ", res, ")");
+#endif
+
+ if (!msg_.empty()) {
+ ret.append(" ");
+ ret.append(msg_);
+ }
+
+ return ret;
+}
+
+::std::ostream& operator<<(::std::ostream& os, const PosixError& e) {
+ os << e.ToString();
+ return os;
+}
+
+void PosixErrorIsMatcherCommonImpl::DescribeTo(std::ostream* os) const {
+ *os << "has an errno value that ";
+ code_matcher_.DescribeTo(os);
+ *os << ", and has an error message that ";
+ message_matcher_.DescribeTo(os);
+}
+
+void PosixErrorIsMatcherCommonImpl::DescribeNegationTo(std::ostream* os) const {
+ *os << "has an errno value that ";
+ code_matcher_.DescribeNegationTo(os);
+ *os << ", or has an error message that ";
+ message_matcher_.DescribeNegationTo(os);
+}
+
+bool PosixErrorIsMatcherCommonImpl::MatchAndExplain(
+ const PosixError& error,
+ ::testing::MatchResultListener* result_listener) const {
+ ::testing::StringMatchResultListener inner_listener;
+
+ inner_listener.Clear();
+ if (!code_matcher_.MatchAndExplain(error.errno_value(), &inner_listener)) {
+ return false;
+ }
+
+ if (!message_matcher_.Matches(error.error_message())) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/posix_error.h b/test/util/posix_error.h
new file mode 100644
index 000000000..ad666bce0
--- /dev/null
+++ b/test/util/posix_error.h
@@ -0,0 +1,462 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_POSIX_ERROR_H_
+#define GVISOR_TEST_UTIL_POSIX_ERROR_H_
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "absl/base/attributes.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/variant.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+class PosixErrorIsMatcherCommonImpl;
+
+template <typename T>
+class PosixErrorOr;
+
+class ABSL_MUST_USE_RESULT PosixError {
+ public:
+ PosixError() {}
+ explicit PosixError(int errno_value) : errno_(errno_value) {}
+ PosixError(int errno_value, std::string msg)
+ : errno_(errno_value), msg_(std::move(msg)) {}
+
+ PosixError(PosixError&& other) = default;
+ PosixError& operator=(PosixError&& other) = default;
+ PosixError(const PosixError&) = default;
+ PosixError& operator=(const PosixError&) = default;
+
+ bool ok() const { return errno_ == 0; }
+
+ // Returns a reference to *this to make matchers compatible with
+ // PosixErrorOr.
+ const PosixError& error() const { return *this; }
+
+ std::string error_message() const { return msg_; }
+
+ // ToString produces a full string representation of this posix error
+ // including the printable representation of the errno and the error message.
+ std::string ToString() const;
+
+ // Ignores any errors. This method does nothing except potentially suppress
+ // complaints from any tools that are checking that errors are not dropped on
+ // the floor.
+ void IgnoreError() const {}
+
+ private:
+ int errno_value() const { return errno_; }
+ int errno_ = 0;
+ std::string msg_;
+
+ friend class PosixErrorIsMatcherCommonImpl;
+
+ template <typename T>
+ friend class PosixErrorOr;
+};
+
+template <typename T>
+class ABSL_MUST_USE_RESULT PosixErrorOr {
+ public:
+ // A PosixErrorOr will CHECK-fail if it is constructed with NoError().
+ PosixErrorOr(const PosixError& error);
+ PosixErrorOr(const T& value);
+ PosixErrorOr(T&& value);
+
+ PosixErrorOr(PosixErrorOr&& other) = default;
+ PosixErrorOr& operator=(PosixErrorOr&& other) = default;
+ PosixErrorOr(const PosixErrorOr&) = default;
+ PosixErrorOr& operator=(const PosixErrorOr&) = default;
+
+ // Conversion copy/move constructor, T must be convertible from U.
+ template <typename U>
+ friend class PosixErrorOr;
+
+ template <typename U>
+ PosixErrorOr(PosixErrorOr<U> other);
+
+ template <typename U>
+ PosixErrorOr& operator=(PosixErrorOr<U> other);
+
+ // Returns the contained error, or NoError() if this holds a value.
+ PosixError error() const;
+
+ // Returns this->error().error_message();
+ std::string error_message() const;
+
+ // Returns true if this PosixErrorOr contains some T.
+ bool ok() const;
+
+ // Returns a reference to our current value, or CHECK-fails if !this->ok().
+ const T& ValueOrDie() const&;
+ T& ValueOrDie() &;
+ const T&& ValueOrDie() const&&;
+ T&& ValueOrDie() &&;
+
+ // Ignores any errors. This method does nothing except potentially suppress
+ // complaints from any tools that are checking that errors are not dropped on
+ // the floor.
+ void IgnoreError() const {}
+
+ private:
+ int errno_value() const;
+ absl::variant<T, PosixError> value_;
+
+ friend class PosixErrorIsMatcherCommonImpl;
+};
+
+template <typename T>
+PosixErrorOr<T>::PosixErrorOr(const PosixError& error) : value_(error) {
+ TEST_CHECK_MSG(
+ !error.ok(),
+ "Constructing PosixErrorOr with NoError, eg. errno 0 is not allowed.");
+}
+
+template <typename T>
+PosixErrorOr<T>::PosixErrorOr(const T& value) : value_(value) {}
+
+template <typename T>
+PosixErrorOr<T>::PosixErrorOr(T&& value) : value_(std::move(value)) {}
+
+// Conversion copy/move constructor, T must be convertible from U.
+template <typename T>
+template <typename U>
+inline PosixErrorOr<T>::PosixErrorOr(PosixErrorOr<U> other) {
+ if (absl::holds_alternative<U>(other.value_)) {
+ // T is convertible from U.
+ value_ = absl::get<U>(std::move(other.value_));
+ } else if (absl::holds_alternative<PosixError>(other.value_)) {
+ value_ = absl::get<PosixError>(std::move(other.value_));
+ } else {
+ TEST_CHECK_MSG(false, "PosixErrorOr does not contain PosixError or value");
+ }
+}
+
+template <typename T>
+template <typename U>
+inline PosixErrorOr<T>& PosixErrorOr<T>::operator=(PosixErrorOr<U> other) {
+ if (absl::holds_alternative<U>(other.value_)) {
+ // T is convertible from U.
+ value_ = absl::get<U>(std::move(other.value_));
+ } else if (absl::holds_alternative<PosixError>(other.value_)) {
+ value_ = absl::get<PosixError>(std::move(other.value_));
+ } else {
+ TEST_CHECK_MSG(false, "PosixErrorOr does not contain PosixError or value");
+ }
+ return *this;
+}
+
+template <typename T>
+PosixError PosixErrorOr<T>::error() const {
+ if (!absl::holds_alternative<PosixError>(value_)) {
+ return PosixError();
+ }
+ return absl::get<PosixError>(value_);
+}
+
+template <typename T>
+int PosixErrorOr<T>::errno_value() const {
+ return error().errno_value();
+}
+
+template <typename T>
+std::string PosixErrorOr<T>::error_message() const {
+ return error().error_message();
+}
+
+template <typename T>
+bool PosixErrorOr<T>::ok() const {
+ return absl::holds_alternative<T>(value_);
+}
+
+template <typename T>
+const T& PosixErrorOr<T>::ValueOrDie() const& {
+ TEST_CHECK(absl::holds_alternative<T>(value_));
+ return absl::get<T>(value_);
+}
+
+template <typename T>
+T& PosixErrorOr<T>::ValueOrDie() & {
+ TEST_CHECK(absl::holds_alternative<T>(value_));
+ return absl::get<T>(value_);
+}
+
+template <typename T>
+const T&& PosixErrorOr<T>::ValueOrDie() const&& {
+ TEST_CHECK(absl::holds_alternative<T>(value_));
+ return std::move(absl::get<T>(value_));
+}
+
+template <typename T>
+T&& PosixErrorOr<T>::ValueOrDie() && {
+ TEST_CHECK(absl::holds_alternative<T>(value_));
+ return std::move(absl::get<T>(value_));
+}
+
+extern ::std::ostream& operator<<(::std::ostream& os, const PosixError& e);
+
+template <typename T>
+::std::ostream& operator<<(::std::ostream& os, const PosixErrorOr<T>& e) {
+ os << e.error();
+ return os;
+}
+
+// NoError is a PosixError that represents a successful state, i.e. No Error.
+inline PosixError NoError() { return PosixError(); }
+
+// Monomorphic implementation of matcher IsPosixErrorOk() for a given type T.
+// T can be PosixError, PosixErrorOr<>, or a reference to either of them.
+template <typename T>
+class MonoPosixErrorIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+ void DescribeTo(std::ostream* os) const override { *os << "is OK"; }
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << "is not OK";
+ }
+ bool MatchAndExplain(T actual_value,
+ ::testing::MatchResultListener*) const override {
+ return actual_value.ok();
+ }
+};
+
+// Implements IsPosixErrorOkMatcher() as a polymorphic matcher.
+class IsPosixErrorOkMatcher {
+ public:
+ template <typename T>
+ operator ::testing::Matcher<T>() const { // NOLINT
+ return MakeMatcher(new MonoPosixErrorIsOkMatcherImpl<T>());
+ }
+};
+
+// Monomorphic implementation of a matcher for a PosixErrorOr.
+template <typename PosixErrorOrType>
+class IsPosixErrorOkAndHoldsMatcherImpl
+ : public ::testing::MatcherInterface<PosixErrorOrType> {
+ public:
+ using ValueType = typename std::remove_reference<decltype(
+ std::declval<PosixErrorOrType>().ValueOrDie())>::type;
+
+ template <typename InnerMatcher>
+ explicit IsPosixErrorOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
+ : inner_matcher_(::testing::SafeMatcherCast<const ValueType&>(
+ std::forward<InnerMatcher>(inner_matcher))) {}
+
+ void DescribeTo(std::ostream* os) const override {
+ *os << "is OK and has a value that ";
+ inner_matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << "isn't OK or has a value that ";
+ inner_matcher_.DescribeNegationTo(os);
+ }
+
+ bool MatchAndExplain(
+ PosixErrorOrType actual_value,
+ ::testing::MatchResultListener* listener) const override {
+ // We can't extract the value if it doesn't contain one.
+ if (!actual_value.ok()) {
+ return false;
+ }
+
+ ::testing::StringMatchResultListener inner_listener;
+ const bool matches = inner_matcher_.MatchAndExplain(
+ actual_value.ValueOrDie(), &inner_listener);
+ const std::string inner_explanation = inner_listener.str();
+ *listener << "has a value "
+ << ::testing::PrintToString(actual_value.ValueOrDie());
+
+ if (!inner_explanation.empty()) {
+ *listener << " " << inner_explanation;
+ }
+ return matches;
+ }
+
+ private:
+ const ::testing::Matcher<const ValueType&> inner_matcher_;
+};
+
+// Implements IsOkAndHolds() as a polymorphic matcher.
+template <typename InnerMatcher>
+class IsPosixErrorOkAndHoldsMatcher {
+ public:
+ explicit IsPosixErrorOkAndHoldsMatcher(InnerMatcher inner_matcher)
+ : inner_matcher_(std::move(inner_matcher)) {}
+
+ // Converts this polymorphic matcher to a monomorphic one of the given type.
+ // PosixErrorOrType can be either PosixErrorOr<T> or a reference to
+ // PosixErrorOr<T>.
+ template <typename PosixErrorOrType>
+ operator ::testing::Matcher<PosixErrorOrType>() const { // NOLINT
+ return ::testing::MakeMatcher(
+ new IsPosixErrorOkAndHoldsMatcherImpl<PosixErrorOrType>(
+ inner_matcher_));
+ }
+
+ private:
+ const InnerMatcher inner_matcher_;
+};
+
+// PosixErrorIs() is a polymorphic matcher. This class is the common
+// implementation of it shared by all types T where PosixErrorIs() can be
+// used as a Matcher<T>.
+class PosixErrorIsMatcherCommonImpl {
+ public:
+ PosixErrorIsMatcherCommonImpl(
+ ::testing::Matcher<int> code_matcher,
+ ::testing::Matcher<const std::string&> message_matcher)
+ : code_matcher_(std::move(code_matcher)),
+ message_matcher_(std::move(message_matcher)) {}
+
+ void DescribeTo(std::ostream* os) const;
+
+ void DescribeNegationTo(std::ostream* os) const;
+
+ bool MatchAndExplain(const PosixError& error,
+ ::testing::MatchResultListener* result_listener) const;
+
+ template <typename T>
+ bool MatchAndExplain(const PosixErrorOr<T>& error_or,
+ ::testing::MatchResultListener* result_listener) const {
+ if (error_or.ok()) {
+ *result_listener << "has a value "
+ << ::testing::PrintToString(error_or.ValueOrDie());
+ return false;
+ }
+
+ return MatchAndExplain(error_or.error(), result_listener);
+ }
+
+ private:
+ const ::testing::Matcher<int> code_matcher_;
+ const ::testing::Matcher<const std::string&> message_matcher_;
+};
+
+// Monomorphic implementation of matcher PosixErrorIs() for a given type
+// T. T can be PosixError, PosixErrorOr<>, or a reference to either of them.
+template <typename T>
+class MonoPosixErrorIsMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+ explicit MonoPosixErrorIsMatcherImpl(
+ PosixErrorIsMatcherCommonImpl common_impl)
+ : common_impl_(std::move(common_impl)) {}
+
+ void DescribeTo(std::ostream* os) const override {
+ common_impl_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(std::ostream* os) const override {
+ common_impl_.DescribeNegationTo(os);
+ }
+
+ bool MatchAndExplain(
+ T actual_value,
+ ::testing::MatchResultListener* result_listener) const override {
+ return common_impl_.MatchAndExplain(actual_value, result_listener);
+ }
+
+ private:
+ PosixErrorIsMatcherCommonImpl common_impl_;
+};
+
+inline ::testing::Matcher<int> ToErrorCodeMatcher(
+ const ::testing::Matcher<int>& m) {
+ return m;
+}
+
+// Implements PosixErrorIs() as a polymorphic matcher.
+class PosixErrorIsMatcher {
+ public:
+ template <typename ErrorCodeMatcher>
+ PosixErrorIsMatcher(ErrorCodeMatcher&& code_matcher,
+ ::testing::Matcher<const std::string&> message_matcher)
+ : common_impl_(
+ ToErrorCodeMatcher(std::forward<ErrorCodeMatcher>(code_matcher)),
+ std::move(message_matcher)) {}
+
+ // Converts this polymorphic matcher to a monomorphic matcher of the
+ // given type. T can be PosixError, PosixErrorOr<>, or a reference to
+ // either of them.
+ template <typename T>
+ operator ::testing::Matcher<T>() const { // NOLINT
+ return MakeMatcher(new MonoPosixErrorIsMatcherImpl<T>(common_impl_));
+ }
+
+ private:
+ const PosixErrorIsMatcherCommonImpl common_impl_;
+};
+
+// Returns a gMock matcher that matches a PosixError or PosixErrorOr<> whose
+// error code matches code_matcher and whose error message matches
+// message_matcher.
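+//
+// Example (a minimal sketch):
+//
+//   EXPECT_THAT(PosixError(EINVAL, "bad argument"),
+//               PosixErrorIs(EINVAL, ::testing::HasSubstr("bad")));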
+template <typename ErrorCodeMatcher>
+PosixErrorIsMatcher PosixErrorIs(
+ ErrorCodeMatcher&& code_matcher,
+ ::testing::Matcher<const std::string&> message_matcher) {
+ return PosixErrorIsMatcher(std::forward<ErrorCodeMatcher>(code_matcher),
+ std::move(message_matcher));
+}
+
+// Returns a gMock matcher that matches a PosixErrorOr<> that is ok() and
+// whose value matches the inner matcher.
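+//
+// Example (a minimal sketch):
+//
+//   PosixErrorOr<int> maybe_fd(3);
+//   EXPECT_THAT(maybe_fd, IsPosixErrorOkAndHolds(::testing::Gt(0)));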
+template <typename InnerMatcher>
+IsPosixErrorOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>
+IsPosixErrorOkAndHolds(InnerMatcher&& inner_matcher) {
+ return IsPosixErrorOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>(
+ std::forward<InnerMatcher>(inner_matcher));
+}
+
+// Internal helper for concatenating macro values.
+#define POSIX_ERROR_IMPL_CONCAT_INNER_(x, y) x##y
+#define POSIX_ERROR_IMPL_CONCAT_(x, y) POSIX_ERROR_IMPL_CONCAT_INNER_(x, y)
+
+#define POSIX_ERROR_IMPL_ASSIGN_OR_RETURN_(posixerroror, lhs, rexpr) \
+ auto posixerroror = (rexpr); \
+ if (!posixerroror.ok()) { \
+ return (posixerroror.error()); \
+ } \
+ lhs = std::move(posixerroror).ValueOrDie()
+
+#define EXPECT_NO_ERRNO(expression) \
+ EXPECT_THAT(expression, IsPosixErrorOkMatcher())
+#define ASSERT_NO_ERRNO(expression) \
+ ASSERT_THAT(expression, IsPosixErrorOkMatcher())
+
+#define ASSIGN_OR_RETURN_ERRNO(lhs, rexpr) \
+ POSIX_ERROR_IMPL_ASSIGN_OR_RETURN_( \
+ POSIX_ERROR_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr)
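+
+// Example (a minimal sketch): propagate failure from GetContents, otherwise
+// bind its value:
+//
+//   PosixErrorOr<size_t> FileSize(const std::string& path) {
+//     ASSIGN_OR_RETURN_ERRNO(auto contents, GetContents(path));
+//     return contents.size();
+//   }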
+
+#define RETURN_IF_ERRNO(expr) \
+ do { \
+ auto _retval = (expr); \
+ if (!_retval.ok()) { \
+ return _retval; \
+ } \
+ } while (false)
+
+#define ASSERT_NO_ERRNO_AND_VALUE(expr) \
+ ({ \
+ auto _expr_result = (expr); \
+ ASSERT_NO_ERRNO(_expr_result); \
+ std::move(_expr_result).ValueOrDie(); \
+ })
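+
+// Example (a minimal sketch; Open is the test-util helper returning
+// PosixErrorOr<FileDescriptor>):
+//
+//   FileDescriptor fd =
+//       ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_RDONLY));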
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_POSIX_ERROR_H_
diff --git a/test/util/posix_error_test.cc b/test/util/posix_error_test.cc
new file mode 100644
index 000000000..bf9465abb
--- /dev/null
+++ b/test/util/posix_error_test.cc
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+
+#include <errno.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(PosixErrorTest, PosixError) {
+ auto err = PosixError(EAGAIN);
+ EXPECT_THAT(err, PosixErrorIs(EAGAIN, ""));
+}
+
+TEST(PosixErrorTest, PosixErrorOrPosixError) {
+ auto err = PosixErrorOr<std::nullptr_t>(PosixError(EAGAIN));
+ EXPECT_THAT(err, PosixErrorIs(EAGAIN, ""));
+}
+
+TEST(PosixErrorTest, PosixErrorOrNullptr) {
+ auto err = PosixErrorOr<std::nullptr_t>(nullptr);
+ EXPECT_TRUE(err.ok());
+ EXPECT_NO_ERRNO(err);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc
new file mode 100644
index 000000000..34d636ba9
--- /dev/null
+++ b/test/util/proc_util.cc
@@ -0,0 +1,107 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/proc_util.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Parses a single line from /proc/<xxx>/maps.
+PosixErrorOr<ProcMapsEntry> ParseProcMapsLine(absl::string_view line) {
+ ProcMapsEntry map_entry = {};
+
+ // Limit splitting to 6 parts so that if there is a file path and it contains
+ // spaces, the file path is not split.
+ std::vector<std::string> parts =
+ absl::StrSplit(line, absl::MaxSplits(' ', 5), absl::SkipEmpty());
+
+ // parts.size() should be 6 if there is a file name specified, and 5
+ // otherwise.
+ if (parts.size() < 5) {
+ return PosixError(EINVAL, absl::StrCat("Invalid line: ", line));
+ }
+
+ // Address range in the form X-X where X are hex values without leading 0x.
+ std::vector<std::string> addresses = absl::StrSplit(parts[0], '-');
+ if (addresses.size() != 2) {
+ return PosixError(EINVAL,
+ absl::StrCat("Invalid address range: ", parts[0]));
+ }
+ ASSIGN_OR_RETURN_ERRNO(map_entry.start, AtoiBase(addresses[0], 16));
+ ASSIGN_OR_RETURN_ERRNO(map_entry.end, AtoiBase(addresses[1], 16));
+
+ // Permissions are four characters of the form rwxp, where an unset
+ // permission is '-' and the final character is 'p' (private) or 's'
+ // (shared).
+ if (parts[1].size() != 4) {
+ return PosixError(EINVAL,
+ absl::StrCat("Invalid permission field: ", parts[1]));
+ }
+
+ map_entry.readable = parts[1][0] == 'r';
+ map_entry.writable = parts[1][1] == 'w';
+ map_entry.executable = parts[1][2] == 'x';
+ map_entry.priv = parts[1][3] == 'p';
+
+ ASSIGN_OR_RETURN_ERRNO(map_entry.offset, AtoiBase(parts[2], 16));
+
+ std::vector<std::string> device = absl::StrSplit(parts[3], ':');
+ if (device.size() != 2) {
+ return PosixError(EINVAL, absl::StrCat("Invalid device: ", parts[3]));
+ }
+ ASSIGN_OR_RETURN_ERRNO(map_entry.major, AtoiBase(device[0], 16));
+ ASSIGN_OR_RETURN_ERRNO(map_entry.minor, AtoiBase(device[1], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(map_entry.inode, Atoi<int64_t>(parts[4]));
+ if (parts.size() == 6) {
+ // A filename is present. However, absl::StrSplit retained the whitespace
+ // between the inode number and the filename.
+ map_entry.filename =
+ std::string(absl::StripLeadingAsciiWhitespace(parts[5]));
+ }
+
+ return map_entry;
+}
+
+PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps(
+ absl::string_view contents) {
+ std::vector<ProcMapsEntry> entries;
+ auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
+ for (const auto& l : lines) {
+ ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l));
+ entries.push_back(entry);
+ }
+ return entries;
+}
+
+PosixErrorOr<bool> IsVsyscallEnabled() {
+ ASSIGN_OR_RETURN_ERRNO(auto contents, GetContents("/proc/self/maps"));
+ ASSIGN_OR_RETURN_ERRNO(auto maps, ParseProcMaps(contents));
+ return std::any_of(maps.begin(), maps.end(), [](const ProcMapsEntry& e) {
+ return e.filename == "[vsyscall]";
+ });
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/proc_util.h b/test/util/proc_util.h
new file mode 100644
index 000000000..af209a51e
--- /dev/null
+++ b/test/util/proc_util.h
@@ -0,0 +1,150 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_PROC_UTIL_H_
+#define GVISOR_TEST_UTIL_PROC_UTIL_H_
+
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// ProcMapsEntry contains the data from a single line in /proc/<xxx>/maps.
+struct ProcMapsEntry {
+ uint64_t start;
+ uint64_t end;
+ bool readable;
+ bool writable;
+ bool executable;
+ bool priv;
+ uint64_t offset;
+ int major;
+ int minor;
+ int64_t inode;
+ std::string filename;
+};
+
+// Parses a single /proc/<pid>/maps line, or the full contents of a maps
+// file, returning an error on malformed input.
+PosixErrorOr<ProcMapsEntry> ParseProcMapsLine(absl::string_view line);
+PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps(
+ absl::string_view contents);
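+//
+// Example (a minimal sketch):
+//
+//   auto entry = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMapsLine(
+//       "00400000-00406000 r-xp 00000000 08:01 123 /bin/cat"));
+//   EXPECT_TRUE(entry.executable);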
+
+// Returns true if vsyscall (emulated or not) is enabled.
+PosixErrorOr<bool> IsVsyscallEnabled();
+
+// Printer for ProcMapsEntry.
+inline std::ostream& operator<<(std::ostream& os, const ProcMapsEntry& entry) {
+ std::string str =
+ absl::StrCat(absl::Hex(entry.start, absl::PadSpec::kZeroPad8), "-",
+ absl::Hex(entry.end, absl::PadSpec::kZeroPad8), " ");
+
+ absl::StrAppend(&str, entry.readable ? "r" : "-");
+ absl::StrAppend(&str, entry.writable ? "w" : "-");
+ absl::StrAppend(&str, entry.executable ? "x" : "-");
+ absl::StrAppend(&str, entry.priv ? "p" : "s");
+
+ absl::StrAppend(&str, " ", absl::Hex(entry.offset, absl::PadSpec::kZeroPad8),
+ " ", absl::Hex(entry.major, absl::PadSpec::kZeroPad2), ":",
+ absl::Hex(entry.minor, absl::PadSpec::kZeroPad2), " ",
+ entry.inode);
+ if (!entry.filename.empty()) {
+ // Pad to column 74
+ int pad = 73 - static_cast<int>(str.length());
+ if (pad > 0) {
+ absl::StrAppend(&str, std::string(pad, ' '));
+ }
+ absl::StrAppend(&str, entry.filename);
+ }
+ os << str;
+ return os;
+}
+
+// Printer for std::vector<ProcMapsEntry>.
+inline std::ostream& operator<<(std::ostream& os,
+ const std::vector<ProcMapsEntry>& vec) {
+ for (unsigned int i = 0; i < vec.size(); i++) {
+ os << vec[i];
+ if (i != vec.size() - 1) {
+ os << "\n";
+ }
+ }
+ return os;
+}
+
+// GMock printer for std::vector<ProcMapsEntry>.
+inline void PrintTo(const std::vector<ProcMapsEntry>& vec, std::ostream* os) {
+ *os << vec;
+}
+
+// Checks that /proc/pid/maps contains all of the passed mappings.
+//
+// The major, minor, and inode fields are ignored.
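+//
+// Example (a minimal sketch; the matched value is a pid):
+//
+//   ProcMapsEntry entry = {};
+//   entry.start = 0x40000000;
+//   entry.end = 0x40001000;
+//   entry.readable = true;
+//   entry.priv = true;
+//   EXPECT_THAT(getpid(), ContainsMappings(std::vector<ProcMapsEntry>{entry}));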
+MATCHER_P(ContainsMappings, mappings,
+ "contains mappings:\n" + ::testing::PrintToString(mappings)) {
+ auto contents_or = GetContents(absl::StrCat("/proc/", arg, "/maps"));
+ if (!contents_or.ok()) {
+ *result_listener << "Unable to read mappings: "
+ << contents_or.error().ToString();
+ return false;
+ }
+
+ auto maps_or = ParseProcMaps(contents_or.ValueOrDie());
+ if (!maps_or.ok()) {
+ *result_listener << "Unable to parse mappings: "
+ << maps_or.error().ToString();
+ return false;
+ }
+
+ auto maps = std::move(maps_or).ValueOrDie();
+
+ // Does maps contain all elements in mappings? The comparator ignores
+ // the major, minor, and inode fields.
+ bool all_present = true;
+ std::for_each(mappings.begin(), mappings.end(), [&](const ProcMapsEntry& e1) {
+ auto it =
+ std::find_if(maps.begin(), maps.end(), [&e1](const ProcMapsEntry& e2) {
+ return e1.start == e2.start && e1.end == e2.end &&
+ e1.readable == e2.readable && e1.writable == e2.writable &&
+ e1.executable == e2.executable && e1.priv == e2.priv &&
+ e1.offset == e2.offset && e1.filename == e2.filename;
+ });
+ if (it == maps.end()) {
+ // It wasn't found.
+ if (all_present) {
+ // We will output the message once and then a line for each mapping
+ // that wasn't found.
+ all_present = false;
+ *result_listener << "Got mappings:\n"
+ << maps << "\nThat were missing:\n";
+ }
+ *result_listener << e1 << "\n";
+ }
+ });
+
+ return all_present;
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_PROC_UTIL_H_
diff --git a/test/util/proc_util_test.cc b/test/util/proc_util_test.cc
new file mode 100644
index 000000000..71dd2355e
--- /dev/null
+++ b/test/util/proc_util_test.cc
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/proc_util.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+using ::testing::IsEmpty;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ParseProcMapsLineTest, WithoutFilename) {
+ auto entry = ASSERT_NO_ERRNO_AND_VALUE(
+ ParseProcMapsLine("2ab4f00b7000-2ab4f00b9000 r-xp 00000000 00:00 0 "));
+ EXPECT_EQ(entry.start, 0x2ab4f00b7000);
+ EXPECT_EQ(entry.end, 0x2ab4f00b9000);
+ EXPECT_TRUE(entry.readable);
+ EXPECT_FALSE(entry.writable);
+ EXPECT_TRUE(entry.executable);
+ EXPECT_TRUE(entry.priv);
+ EXPECT_EQ(entry.offset, 0);
+ EXPECT_EQ(entry.major, 0);
+ EXPECT_EQ(entry.minor, 0);
+ EXPECT_EQ(entry.inode, 0);
+ EXPECT_THAT(entry.filename, IsEmpty());
+}
+
+TEST(ParseProcMapsLineTest, WithFilename) {
+ auto entry = ASSERT_NO_ERRNO_AND_VALUE(
+ ParseProcMapsLine("00407000-00408000 rw-p 00006000 00:0e 10 "
+ " /bin/cat"));
+ EXPECT_EQ(entry.start, 0x407000);
+ EXPECT_EQ(entry.end, 0x408000);
+ EXPECT_TRUE(entry.readable);
+ EXPECT_TRUE(entry.writable);
+ EXPECT_FALSE(entry.executable);
+ EXPECT_TRUE(entry.priv);
+ EXPECT_EQ(entry.offset, 0x6000);
+ EXPECT_EQ(entry.major, 0);
+ EXPECT_EQ(entry.minor, 0x0e);
+ EXPECT_EQ(entry.inode, 10);
+ EXPECT_EQ(entry.filename, "/bin/cat");
+}
+
+TEST(ParseProcMapsLineTest, WithFilenameContainingSpaces) {
+ auto entry = ASSERT_NO_ERRNO_AND_VALUE(
+ ParseProcMapsLine("7f26b3b12000-7f26b3b13000 rw-s 00000000 00:05 1432484 "
+ " /dev/zero (deleted)"));
+ EXPECT_EQ(entry.start, 0x7f26b3b12000);
+ EXPECT_EQ(entry.end, 0x7f26b3b13000);
+ EXPECT_TRUE(entry.readable);
+ EXPECT_TRUE(entry.writable);
+ EXPECT_FALSE(entry.executable);
+ EXPECT_FALSE(entry.priv);
+ EXPECT_EQ(entry.offset, 0);
+ EXPECT_EQ(entry.major, 0);
+ EXPECT_EQ(entry.minor, 0x05);
+ EXPECT_EQ(entry.inode, 1432484);
+ EXPECT_EQ(entry.filename, "/dev/zero (deleted)");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc
new file mode 100644
index 000000000..c01f916aa
--- /dev/null
+++ b/test/util/pty_util.cc
@@ -0,0 +1,53 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/pty_util.h"
+
+#include <sys/ioctl.h>
+#include <termios.h>
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) {
+ PosixErrorOr<int> n = SlaveID(master);
+ if (!n.ok()) {
+ return PosixErrorOr<FileDescriptor>(n.error());
+ }
+ return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK);
+}
+
+PosixErrorOr<int> SlaveID(const FileDescriptor& master) {
+ // Get pty index.
+ int n;
+ int ret = ioctl(master.get(), TIOCGPTN, &n);
+ if (ret < 0) {
+ return PosixError(errno, "ioctl(TIOCGPTN) failed");
+ }
+
+ // Unlock pts.
+ int unlock = 0;
+ ret = ioctl(master.get(), TIOCSPTLCK, &unlock);
+ if (ret < 0) {
+ return PosixError(errno, "ioctl(TIOSPTLCK) failed");
+ }
+
+ return n;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/pty_util.h b/test/util/pty_util.h
new file mode 100644
index 000000000..0722da379
--- /dev/null
+++ b/test/util/pty_util.h
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_PTY_UTIL_H_
+#define GVISOR_TEST_UTIL_PTY_UTIL_H_
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Opens the slave end of the passed master as R/W and nonblocking.
+PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master);
+
+// Returns the slave pty number of the given master.
+PosixErrorOr<int> SlaveID(const FileDescriptor& master);
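+
+// Example (a minimal sketch; assumes the master comes from /dev/ptmx and that
+// Open is the test-util helper returning PosixErrorOr<FileDescriptor>):
+//
+//   FileDescriptor master =
+//       ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+//   FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master));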
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_PTY_UTIL_H_
diff --git a/test/util/rlimit_util.cc b/test/util/rlimit_util.cc
new file mode 100644
index 000000000..d7bfc1606
--- /dev/null
+++ b/test/util/rlimit_util.cc
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/rlimit_util.h"
+
+#include <sys/resource.h>
+
+#include <cerrno>
+
+#include "test/util/cleanup.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<Cleanup> ScopedSetSoftRlimit(int resource, rlim_t newval) {
+ struct rlimit old_rlim;
+ if (getrlimit(resource, &old_rlim) != 0) {
+ return PosixError(errno, "getrlimit failed");
+ }
+ struct rlimit new_rlim = old_rlim;
+ new_rlim.rlim_cur = newval;
+ if (setrlimit(resource, &new_rlim) != 0) {
+ return PosixError(errno, "setrlimit failed");
+ }
+ return Cleanup([resource, old_rlim] {
+ TEST_PCHECK(setrlimit(resource, &old_rlim) == 0);
+ });
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/rlimit_util.h b/test/util/rlimit_util.h
new file mode 100644
index 000000000..873252a32
--- /dev/null
+++ b/test/util/rlimit_util.h
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_RLIMIT_UTIL_H_
+#define GVISOR_TEST_UTIL_RLIMIT_UTIL_H_
+
+#include <sys/resource.h>
+#include <sys/time.h>
+
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<Cleanup> ScopedSetSoftRlimit(int resource, rlim_t newval);
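+
+// Example (a minimal sketch): lower the soft fd limit; the original limit is
+// restored when the returned Cleanup is destroyed:
+//
+//   auto limit_cleanup =
+//       ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_NOFILE, 16));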
+
+} // namespace testing
+} // namespace gvisor
+#endif // GVISOR_TEST_UTIL_RLIMIT_UTIL_H_
diff --git a/test/util/save_util.cc b/test/util/save_util.cc
new file mode 100644
index 000000000..384d626f0
--- /dev/null
+++ b/test/util/save_util.cc
@@ -0,0 +1,71 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/save_util.h"
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <cerrno>
+
+#define GVISOR_COOPERATIVE_SAVE_TEST "GVISOR_COOPERATIVE_SAVE_TEST"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+enum class CooperativeSaveMode {
+ kUnknown = 0, // cooperative_save_mode is statically-initialized to 0
+ kAvailable,
+ kNotAvailable,
+};
+
+std::atomic<CooperativeSaveMode> cooperative_save_mode;
+
+bool CooperativeSaveEnabled() {
+ auto mode = cooperative_save_mode.load();
+ if (mode == CooperativeSaveMode::kUnknown) {
+ mode = (getenv(GVISOR_COOPERATIVE_SAVE_TEST) != nullptr)
+ ? CooperativeSaveMode::kAvailable
+ : CooperativeSaveMode::kNotAvailable;
+ cooperative_save_mode.store(mode);
+ }
+ return mode == CooperativeSaveMode::kAvailable;
+}
+
+std::atomic<int> save_disable;
+
+} // namespace
+
+DisableSave::DisableSave() { save_disable++; }
+
+DisableSave::~DisableSave() { reset(); }
+
+void DisableSave::reset() {
+ if (!reset_) {
+ reset_ = true;
+ save_disable--;
+ }
+}
+
+namespace internal {
+bool ShouldSave() {
+ return CooperativeSaveEnabled() && (save_disable.load() == 0);
+}
+} // namespace internal
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/save_util.h b/test/util/save_util.h
new file mode 100644
index 000000000..bddad6120
--- /dev/null
+++ b/test/util/save_util.h
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_SAVE_UTIL_H_
+#define GVISOR_TEST_UTIL_SAVE_UTIL_H_
+
+namespace gvisor {
+namespace testing {
+// DisableSave prevents cooperative save cycles while an instance is in
+// scope.
+//
+// The effect lasts for the lifetime of the object, unless reset is called.
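+//
+// Example (a minimal sketch; RunSaveSensitiveCode is a hypothetical stand-in
+// for test logic that cannot tolerate a save cycle):
+//
+//   {
+//     DisableSave ds;  // no cooperative saves happen in this scope
+//     RunSaveSensitiveCode();
+//   }  // saves may resume once `ds` is destroyed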
+class DisableSave {
+ public:
+ DisableSave();
+ ~DisableSave();
+ DisableSave(DisableSave const&) = delete;
+ DisableSave(DisableSave&&) = delete;
+ DisableSave& operator=(DisableSave const&) = delete;
+ DisableSave& operator=(DisableSave&&) = delete;
+
+ // reset allows saves to continue, and is called implicitly by the destructor.
+ // It may be called multiple times safely, but is not thread-safe.
+ void reset();
+
+ private:
+ bool reset_ = false;
+};
+
+// May perform a co-operative save cycle.
+//
+// errno is guaranteed to be preserved.
+void MaybeSave();
+
+namespace internal {
+bool ShouldSave();
+} // namespace internal
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_SAVE_UTIL_H_
diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc
new file mode 100644
index 000000000..d0aea8e6a
--- /dev/null
+++ b/test/util/save_util_linux.cc
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef __linux__
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "test/util/save_util.h"
+
+#if defined(__x86_64__) || defined(__i386__)
+#define SYS_TRIGGER_SAVE SYS_create_module
+#elif defined(__aarch64__)
+#define SYS_TRIGGER_SAVE SYS_finit_module
+#else
+#error "Unknown architecture"
+#endif
+
+namespace gvisor {
+namespace testing {
+
+void MaybeSave() {
+ if (internal::ShouldSave()) {
+ int orig_errno = errno;
+ // The sentry recognizes this syscall and triggers a save of its state.
+ // Note: this must be a valid syscall number that is not exercised by any
+ // of the syscall tests.
+ syscall(SYS_TRIGGER_SAVE, nullptr, 0);
+ errno = orig_errno;
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif
diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc
new file mode 100644
index 000000000..931af2c29
--- /dev/null
+++ b/test/util/save_util_other.cc
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __linux__
+
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+void MaybeSave() {
+ // Saving is never available in a non-Linux environment.
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif
diff --git a/test/util/signal_util.cc b/test/util/signal_util.cc
new file mode 100644
index 000000000..5ee95ee80
--- /dev/null
+++ b/test/util/signal_util.cc
@@ -0,0 +1,104 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/signal_util.h"
+
+#include <signal.h>
+
+#include <ostream>
+
+#include "gtest/gtest.h"
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace {
+
+struct Range {
+ int start;
+ int end;
+};
+
+// Format a Range as "start-end" or "start" for single value Ranges.
+static ::std::ostream& operator<<(::std::ostream& os, const Range& range) {
+ if (range.end > range.start) {
+ return os << range.start << '-' << range.end;
+ }
+
+ return os << range.start;
+}
+
+} // namespace
+
+// Format a sigset_t as a comma separated list of numeric ranges.
+// Empty sigset: []
+// Full sigset: [1-31,34-64]
+::std::ostream& operator<<(::std::ostream& os, const sigset_t& sigset) {
+ const char* delim = "";
+ Range range = {0, 0};
+
+ os << '[';
+
+ for (int sig = 1; sig <= gvisor::testing::kMaxSignal; ++sig) {
+ if (sigismember(&sigset, sig)) {
+ if (range.start) {
+ range.end = sig;
+ } else {
+ range.start = sig;
+ range.end = sig;
+ }
+ } else if (range.start) {
+ os << delim << range;
+ delim = ",";
+ range.start = 0;
+ range.end = 0;
+ }
+ }
+
+ if (range.start) {
+ os << delim << range;
+ }
+
+ return os << ']';
+}
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<Cleanup> ScopedSigaction(int sig, struct sigaction const& sa) {
+ struct sigaction old_sa;
+ int rc = sigaction(sig, &sa, &old_sa);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "sigaction failed");
+ }
+ return Cleanup([sig, old_sa] {
+ EXPECT_THAT(sigaction(sig, &old_sa, nullptr), SyscallSucceeds());
+ });
+}
+
+PosixErrorOr<Cleanup> ScopedSignalMask(int how, sigset_t const& set) {
+ sigset_t old;
+ int rc = sigprocmask(how, &set, &old);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "sigprocmask failed");
+ }
+ return Cleanup([old] {
+ EXPECT_THAT(sigprocmask(SIG_SETMASK, &old, nullptr), SyscallSucceeds());
+ });
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/signal_util.h b/test/util/signal_util.h
new file mode 100644
index 000000000..e7b66aa51
--- /dev/null
+++ b/test/util/signal_util.h
@@ -0,0 +1,107 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_SIGNAL_UTIL_H_
+#define GVISOR_TEST_UTIL_SIGNAL_UTIL_H_
+
+#include <signal.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include <ostream>
+
+#include "gmock/gmock.h"
+#include "test/util/cleanup.h"
+#include "test/util/posix_error.h"
+
+// Format a sigset_t as a comma separated list of numeric ranges.
+::std::ostream& operator<<(::std::ostream& os, const sigset_t& sigset);
+
+namespace gvisor {
+namespace testing {
+
+// The maximum signal number.
+static constexpr int kMaxSignal = 64;
+
+// Wrapper for the tgkill(2) syscall, which glibc does not provide.
+inline int tgkill(pid_t tgid, pid_t tid, int sig) {
+ return syscall(__NR_tgkill, tgid, tid, sig);
+}
+
+// Installs the passed sigaction and returns a cleanup function to restore the
+// previous handler when it goes out of scope.
+PosixErrorOr<Cleanup> ScopedSigaction(int sig, struct sigaction const& sa);
+
+// Updates the signal mask as per sigprocmask(2) and returns a cleanup function
+// to restore the previous signal mask when it goes out of scope.
+PosixErrorOr<Cleanup> ScopedSignalMask(int how, sigset_t const& set);
+
+// ScopedSignalMask variant that creates a mask of the single signal 'sig'.
+inline PosixErrorOr<Cleanup> ScopedSignalMask(int how, int sig) {
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, sig);
+ return ScopedSignalMask(how, set);
+}
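+
+// Example (a minimal sketch): block SIGUSR1 until the returned Cleanup is
+// destroyed:
+//
+//   auto mask_cleanup =
+//       ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGUSR1));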
+
+// Asserts equality of two sigset_t values.
+MATCHER_P(EqualsSigset, value, "equals " + ::testing::PrintToString(value)) {
+ for (int sig = 1; sig <= kMaxSignal; ++sig) {
+ if (sigismember(&arg, sig) != sigismember(&value, sig)) {
+ return false;
+ }
+ }
+ return true;
+}
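+
+// Example (a minimal sketch): verify that no signals are currently blocked:
+//
+//   sigset_t mask;
+//   sigprocmask(SIG_SETMASK, nullptr, &mask);  // query the current mask
+//   sigset_t empty;
+//   sigemptyset(&empty);
+//   EXPECT_THAT(mask, EqualsSigset(empty));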
+
+#ifdef __x86_64__
+// Fault can be used to generate a synchronous SIGSEGV.
+//
+// This fault can be fixed up in a handler via FixupFault, below.
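+//
+// Example (a minimal sketch; the handler would be installed with
+// ScopedSigaction using SA_SIGINFO):
+//
+//   void SegvHandler(int sig, siginfo_t* info, void* ctx) {
+//     FixupFault(reinterpret_cast<ucontext_t*>(ctx));
+//   }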
+inline void Fault() {
+ // Zero and dereference %ax.
+ asm("movabs $0, %%rax\r\n"
+ "mov 0(%%rax), %%rax\r\n"
+ :
+ :
+ : "ax");
+}
+
+// FixupFault fixes up a fault generated by Fault, above.
+inline void FixupFault(ucontext_t* ctx) {
+ // Skip the bad instruction above.
+ //
+ // The encoding of mov (%rax), %rax is 0x48 0x8b 0x00.
+ ctx->uc_mcontext.gregs[REG_RIP] += 3;
+}
+#elif __aarch64__
+inline void Fault() {
+ // Zero and dereference x0.
+ asm("mov xzr, x0\r\n"
+ "str xzr, [x0]\r\n"
+ :
+ :
+ : "x0");
+}
+
+inline void FixupFault(ucontext_t* ctx) {
+ // Skip the bad instruction above.
+ ctx->uc_mcontext.pc += 4;
+}
+#endif
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_SIGNAL_UTIL_H_
diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc
new file mode 100644
index 000000000..e1bdee7fd
--- /dev/null
+++ b/test/util/temp_path.cc
@@ -0,0 +1,164 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/temp_path.h"
+
+#include <unistd.h>
+
+#include <atomic>
+#include <cstdlib>
+#include <iostream>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+std::atomic<uint64_t> global_temp_file_number = ATOMIC_VAR_INIT(1);
+
+// Return a new temp filename, intended to be unique system-wide.
+//
+// The global file number helps maintain file naming consistency across
+// different runs of a test.
+//
+// The timestamp is necessary because the test infrastructure invokes each
+// test case in a separate process (resetting global_temp_file_number) and
+// potentially in parallel, which allows for races between selecting and using a
+// name.
+std::string NextTempBasename() {
+ return absl::StrCat("gvisor_test_temp_", global_temp_file_number++, "_",
+ absl::ToUnixNanos(absl::Now()));
+}
+
+void TryDeleteRecursively(std::string const& path) {
+ if (!path.empty()) {
+ int undeleted_dirs = 0;
+ int undeleted_files = 0;
+ auto status = RecursivelyDelete(path, &undeleted_dirs, &undeleted_files);
+ if (undeleted_dirs || undeleted_files || !status.ok()) {
+ std::cerr << path << ": failed to delete " << undeleted_dirs
+ << " directories and " << undeleted_files
+ << " files: " << status << std::endl;
+ }
+ }
+}
+
+} // namespace
+
+constexpr mode_t TempPath::kDefaultFileMode;
+constexpr mode_t TempPath::kDefaultDirMode;
+
+std::string NewTempAbsPathInDir(absl::string_view const dir) {
+ return JoinPath(dir, NextTempBasename());
+}
+
+std::string NewTempAbsPath() {
+ return NewTempAbsPathInDir(GetAbsoluteTestTmpdir());
+}
+
+std::string NewTempRelPath() { return NextTempBasename(); }
+
+std::string GetAbsoluteTestTmpdir() {
+ // TEST_TMPDIR is set by the test runner; fall back to /tmp if it is not.
+ char* env_tmpdir = getenv("TEST_TMPDIR");
+ std::string tmp_dir =
+ env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp";
+
+ return MakeAbsolute(tmp_dir, "").ValueOrDie();
+}
+
+PosixErrorOr<TempPath> TempPath::CreateFileWith(absl::string_view const parent,
+ absl::string_view const content,
+ mode_t const mode) {
+ return CreateIn(parent, [=](absl::string_view path) -> PosixError {
+ // CreateWithContents will call open(O_WRONLY) with the given mode. If the
+ // mode is not user-writable, save/restore cannot preserve the fd. Hence
+ // the little permission dance that's done here.
+ auto res = CreateWithContents(path, content, mode | 0200);
+ RETURN_IF_ERRNO(res);
+
+ return Chmod(path, mode);
+ });
+}
+
+PosixErrorOr<TempPath> TempPath::CreateDirWith(absl::string_view const parent,
+ mode_t const mode) {
+ return CreateIn(parent,
+ [=](absl::string_view path) { return Mkdir(path, mode); });
+}
+
+PosixErrorOr<TempPath> TempPath::CreateSymlinkTo(absl::string_view const parent,
+ std::string const& dest) {
+ return CreateIn(parent, [=](absl::string_view path) {
+ int ret = symlink(dest.c_str(), std::string(path).c_str());
+ if (ret != 0) {
+ return PosixError(errno, "symlink failed");
+ }
+ return NoError();
+ });
+}
+
+PosixErrorOr<TempPath> TempPath::CreateFileIn(absl::string_view const parent) {
+ return TempPath::CreateFileWith(parent, absl::string_view(),
+ kDefaultFileMode);
+}
+
+PosixErrorOr<TempPath> TempPath::CreateDirIn(absl::string_view const parent) {
+ return TempPath::CreateDirWith(parent, kDefaultDirMode);
+}
+
+PosixErrorOr<TempPath> TempPath::CreateFileMode(mode_t mode) {
+ return TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), absl::string_view(),
+ mode);
+}
+
+PosixErrorOr<TempPath> TempPath::CreateFile() {
+ return TempPath::CreateFileIn(GetAbsoluteTestTmpdir());
+}
+
+PosixErrorOr<TempPath> TempPath::CreateDir() {
+ return TempPath::CreateDirIn(GetAbsoluteTestTmpdir());
+}
+
+TempPath::~TempPath() { TryDeleteRecursively(path_); }
+
+TempPath::TempPath(TempPath&& orig) { reset(orig.release()); }
+
+TempPath& TempPath::operator=(TempPath&& orig) {
+ reset(orig.release());
+ return *this;
+}
+
+std::string TempPath::reset(std::string newpath) {
+ std::string path = path_;
+ TryDeleteRecursively(path_);
+ path_ = std::move(newpath);
+ return path;
+}
+
+std::string TempPath::release() {
+ std::string path = path_;
+ path_ = std::string();
+ return path;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/temp_path.h b/test/util/temp_path.h
new file mode 100644
index 000000000..9e5ac11f4
--- /dev/null
+++ b/test/util/temp_path.h
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_TEMP_PATH_H_
+#define GVISOR_TEST_UTIL_TEMP_PATH_H_
+
+#include <sys/stat.h>
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns an absolute path for a file in `dir` that does not yet exist.
+// Distinct calls to NewTempAbsPathInDir from the same process, even from
+// multiple threads, are guaranteed to return different paths. Distinct calls to
+// NewTempAbsPathInDir from different processes are not synchronized.
+std::string NewTempAbsPathInDir(absl::string_view const dir);
+
+// Like NewTempAbsPathInDir, but the returned path is in the test's temporary
+// directory, as provided by the testing framework.
+std::string NewTempAbsPath();
+
+// Like NewTempAbsPathInDir, but the returned path is relative (to the current
+// working directory).
+std::string NewTempRelPath();
+
+// Returns the absolute path for the test temp dir.
+std::string GetAbsoluteTestTmpdir();
+
+// Represents a temporary file or directory.
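+//
+// Example (a minimal sketch):
+//
+//   TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+//   EXPECT_THAT(chmod(file.path().c_str(), 0600), SyscallSucceeds());
+//   // The file is deleted when `file` goes out of scope.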
+class TempPath {
+ public:
+ // Default creation mode for files.
+ static constexpr mode_t kDefaultFileMode = 0644;
+
+ // Default creation mode for directories.
+ static constexpr mode_t kDefaultDirMode = 0755;
+
+ // Creates a temporary file in directory `parent` with mode `mode` and
+ // contents `content`.
+ static PosixErrorOr<TempPath> CreateFileWith(absl::string_view parent,
+ absl::string_view content,
+ mode_t mode);
+
+ // Creates an empty temporary subdirectory in directory `parent` with mode
+ // `mode`.
+ static PosixErrorOr<TempPath> CreateDirWith(absl::string_view parent,
+ mode_t mode);
+
+ // Creates a temporary symlink in directory `parent` to destination `dest`.
+ static PosixErrorOr<TempPath> CreateSymlinkTo(absl::string_view parent,
+ std::string const& dest);
+
+ // Creates an empty temporary file in directory `parent` with mode
+ // kDefaultFileMode.
+ static PosixErrorOr<TempPath> CreateFileIn(absl::string_view parent);
+
+ // Creates an empty temporary subdirectory in directory `parent` with mode
+ // kDefaultDirMode.
+ static PosixErrorOr<TempPath> CreateDirIn(absl::string_view parent);
+
+ // Creates an empty temporary file in the test's temporary directory with mode
+ // `mode`.
+ static PosixErrorOr<TempPath> CreateFileMode(mode_t mode);
+
+ // Creates an empty temporary file in the test's temporary directory with
+ // mode kDefaultFileMode.
+ static PosixErrorOr<TempPath> CreateFile();
+
+ // Creates an empty temporary subdirectory in the test's temporary directory
+ // with mode kDefaultDirMode.
+ static PosixErrorOr<TempPath> CreateDir();
+
+ // Constructs a TempPath that represents nothing.
+ TempPath() = default;
+
+ // Constructs a TempPath that represents the given path, which will be deleted
+ // when the TempPath is destroyed.
+ explicit TempPath(std::string path) : path_(std::move(path)) {}
+
+ // Attempts to delete the represented temporary file or directory (in the
+ // latter case, also attempts to delete its contents).
+ ~TempPath();
+
+ // Attempts to delete the represented temporary file or directory, then
+ // transfers ownership of the path represented by orig to this TempPath.
+ TempPath(TempPath&& orig);
+ TempPath& operator=(TempPath&& orig);
+
+ // Changes the path this TempPath represents. If the TempPath already
+ // represented a path, deletes and returns that path. Otherwise returns the
+ // empty string.
+ std::string reset(std::string newpath);
+ std::string reset() { return reset(""); }
+
+ // Forgets and returns the path this TempPath represents. The path is not
+ // deleted.
+ std::string release();
+
+ // Returns the path this TempPath represents.
+ std::string path() const { return path_; }
+
+ private:
+ template <typename F>
+ static PosixErrorOr<TempPath> CreateIn(absl::string_view const parent,
+ F const& f) {
+ std::string path = NewTempAbsPathInDir(parent);
+ RETURN_IF_ERRNO(f(path));
+ return TempPath(std::move(path));
+ }
+
+ std::string path_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_TEMP_PATH_H_
diff --git a/test/util/temp_umask.h b/test/util/temp_umask.h
new file mode 100644
index 000000000..e7de84a54
--- /dev/null
+++ b/test/util/temp_umask.h
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_
+#define GVISOR_TEST_UTIL_TEMP_UMASK_H_
+
+#include <sys/stat.h>
+#include <sys/types.h>
+
+namespace gvisor {
+namespace testing {
+
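+// TempUmask temporarily overrides the process umask for the lifetime of the
+// object.
+//
+// Example (a minimal sketch):
+//
+//   {
+//     TempUmask tu(0077);  // files created here get mode & ~0077
+//   }  // the previous umask is restored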
+class TempUmask {
+ public:
+ // Sets the process umask to `mask`.
+ explicit TempUmask(mode_t mask) : old_mask_(umask(mask)) {}
+
+ // Sets the process umask to its previous value.
+ ~TempUmask() { umask(old_mask_); }
+
+ private:
+ mode_t old_mask_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_
diff --git a/test/util/test_main.cc b/test/util/test_main.cc
new file mode 100644
index 000000000..1f389e58f
--- /dev/null
+++ b/test/util/test_main.cc
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/test_util.h"
+
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
new file mode 100644
index 000000000..8a037f45f
--- /dev/null
+++ b/test/util/test_util.cc
@@ -0,0 +1,233 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/test_util.h"
+
+#include <limits.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <sys/utsname.h>
+#include <unistd.h>
+
+#include <ctime>
+#include <iostream>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/flags/flag.h" // IWYU pragma: keep
+#include "absl/flags/parse.h" // IWYU pragma: keep
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/time/time.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+constexpr char kGvisorNetwork[] = "GVISOR_NETWORK";
+constexpr char kGvisorVfs[] = "GVISOR_VFS";
+
+bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; }
+
+const std::string GvisorPlatform() {
+ // Set by runner.go.
+ const char* env = getenv(kTestOnGvisor);
+ if (!env) {
+ return Platform::kNative;
+ }
+ return std::string(env);
+}
+
+bool IsRunningWithHostinet() {
+ const char* env = getenv(kGvisorNetwork);
+ return env && strcmp(env, "host") == 0;
+}
+
+bool IsRunningWithVFS1() {
+ const char* env = getenv(kGvisorVfs);
+ if (env == nullptr) {
+ // If not set, it's running on Linux.
+ return false;
+ }
+ return strcmp(env, "VFS1") == 0;
+}
+
+// Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations
+// %ebx contains the address of the global offset table. %rbx is occasionally
+// used to address stack variables in presence of dynamic allocas.
+#if defined(__x86_64__)
+#define GETCPUID(a, b, c, d, a_inp, c_inp) \
+ asm("mov %%rbx, %%rdi\n" \
+ "cpuid\n" \
+ "xchg %%rdi, %%rbx\n" \
+ : "=a"(a), "=D"(b), "=c"(c), "=d"(d) \
+ : "a"(a_inp), "2"(c_inp))
+
+CPUVendor GetCPUVendor() {
+ uint32_t eax, ebx, ecx, edx;
+ std::string vendor_str;
+ // Get vendor string (issue CPUID with eax = 0)
+ GETCPUID(eax, ebx, ecx, edx, 0, 0);
+ vendor_str.append(reinterpret_cast<char*>(&ebx), 4);
+ vendor_str.append(reinterpret_cast<char*>(&edx), 4);
+ vendor_str.append(reinterpret_cast<char*>(&ecx), 4);
+ if (vendor_str == "GenuineIntel") {
+ return CPUVendor::kIntel;
+ } else if (vendor_str == "AuthenticAMD") {
+ return CPUVendor::kAMD;
+ }
+ return CPUVendor::kUnknownVendor;
+}
+#endif // defined(__x86_64__)
+
+bool operator==(const KernelVersion& first, const KernelVersion& second) {
+ return first.major == second.major && first.minor == second.minor &&
+ first.micro == second.micro;
+}
+
+PosixErrorOr<KernelVersion> ParseKernelVersion(absl::string_view vers_str) {
+ KernelVersion version = {};
+ std::vector<std::string> values =
+ absl::StrSplit(vers_str, absl::ByAnyChar(".-"));
+ if (values.size() == 2) {
+ ASSIGN_OR_RETURN_ERRNO(version.major, Atoi<int>(values[0]));
+ ASSIGN_OR_RETURN_ERRNO(version.minor, Atoi<int>(values[1]));
+ return version;
+ } else if (values.size() >= 3) {
+ ASSIGN_OR_RETURN_ERRNO(version.major, Atoi<int>(values[0]));
+ ASSIGN_OR_RETURN_ERRNO(version.minor, Atoi<int>(values[1]));
+ ASSIGN_OR_RETURN_ERRNO(version.micro, Atoi<int>(values[2]));
+ return version;
+ }
+ return PosixError(EINVAL, absl::StrCat("Unknown kernel release: ", vers_str));
+}
+
+PosixErrorOr<KernelVersion> GetKernelVersion() {
+ utsname buf;
+ RETURN_ERROR_IF_SYSCALL_FAIL(uname(&buf));
+ return ParseKernelVersion(buf.release);
+}
+
+std::string CPUSetToString(const cpu_set_t& set, size_t cpus) {
+ std::string str = "cpuset[";
+ for (unsigned int n = 0; n < cpus; n++) {
+ if (CPU_ISSET(n, &set)) {
+ if (n != 0) {
+ absl::StrAppend(&str, " ");
+ }
+ absl::StrAppend(&str, n);
+ }
+ }
+ absl::StrAppend(&str, "]");
+ return str;
+}
+
+// An overloaded operator<< makes it easy to dump the value of an OpenFd.
+std::ostream& operator<<(std::ostream& out, OpenFd const& ofd) {
+ out << ofd.fd << " -> " << ofd.link;
+ return out;
+}
+
+// An overloaded operator<< makes it easy to dump a vector of OpenFDs.
+std::ostream& operator<<(std::ostream& out, std::vector<OpenFd> const& v) {
+ for (const auto& ofd : v) {
+ out << ofd << std::endl;
+ }
+ return out;
+}
+
+PosixErrorOr<std::vector<OpenFd>> GetOpenFDs() {
+ // Get the results from /proc/self/fd.
+ ASSIGN_OR_RETURN_ERRNO(auto dir_list,
+ ListDir("/proc/self/fd", /*skipdots=*/true));
+
+ std::vector<OpenFd> ret_fds;
+ for (const auto& str_fd : dir_list) {
+ OpenFd open_fd = {};
+ ASSIGN_OR_RETURN_ERRNO(open_fd.fd, Atoi<int>(str_fd));
+ std::string path = absl::StrCat("/proc/self/fd/", open_fd.fd);
+
+ // Resolve the link.
+ char buf[PATH_MAX] = {};
+ int ret = readlink(path.c_str(), buf, sizeof(buf));
+ if (ret < 0) {
+ if (errno == ENOENT) {
+ // The FD may have been closed, let's be resilient.
+ continue;
+ }
+
+ return PosixError(
+ errno, absl::StrCat("readlink of ", path, " returned errno ", errno));
+ }
+ open_fd.link = std::string(buf, ret);
+ ret_fds.emplace_back(std::move(open_fd));
+ }
+ return ret_fds;
+}
+
+PosixErrorOr<uint64_t> Links(const std::string& path) {
+ struct stat st;
+ if (stat(path.c_str(), &st)) {
+ return PosixError(errno, absl::StrCat("Failed to stat ", path));
+ }
+ return static_cast<uint64_t>(st.st_nlink);
+}
+
+void RandomizeBuffer(void* buffer, size_t len) {
+ struct timespec ts = {};
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ uint32_t seed = static_cast<uint32_t>(ts.tv_nsec);
+ char* const buf = static_cast<char*>(buffer);
+ for (size_t i = 0; i < len; i++) {
+ buf[i] = rand_r(&seed) % 256;
+ }
+}
+
+std::vector<std::vector<struct iovec>> GenerateIovecs(uint64_t total_size,
+ void* buf,
+ size_t buflen) {
+ std::vector<std::vector<struct iovec>> result;
+ for (uint64_t offset = 0; offset < total_size;) {
+ auto& iovec_array = *result.emplace(result.end());
+
+ for (; offset < total_size && iovec_array.size() < IOV_MAX;
+ offset += buflen) {
+ struct iovec iov = {};
+ iov.iov_base = buf;
+ iov.iov_len = std::min<uint64_t>(total_size - offset, buflen);
+ iovec_array.push_back(iov);
+ }
+ }
+
+ return result;
+}
+
+uint64_t Megabytes(uint64_t n) {
+ // Overflow check, upper 20 bits in n shouldn't be set.
+ TEST_CHECK(!(0xfffff00000000000 & n));
+ return n << 20;
+}
+
+bool Equivalent(uint64_t current, uint64_t target, double tolerance) {
+ auto abs_diff = target > current ? target - current : current - target;
+ return abs_diff <= static_cast<uint64_t>(tolerance * target);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/test_util.h b/test/util/test_util.h
new file mode 100644
index 000000000..109078fc7
--- /dev/null
+++ b/test/util/test_util.h
@@ -0,0 +1,784 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities for syscall testing.
+//
+// Initialization
+// ==============
+//
+// Prior to calling RUN_ALL_TESTS, all tests must use TestInit(&argc, &argv).
+// See the TestInit function for exact side-effects and semantics.
+//
+// Configuration
+// =============
+//
+// IsRunningOnGvisor returns true if the test is known to be running on gVisor.
+// GvisorPlatform can be used to get more detail:
+//
+// if (GvisorPlatform() == Platform::kPtrace) {
+// ...
+// }
+//
+// SetupGvisorDeathTest ensures that signal handling does not interfere with
+// tests that rely on fatal signals.
+//
+// Matchers
+// ========
+//
+// ElementOf(xs) matches if the matched value is equal to an element of the
+// container xs. Example:
+//
+// // PASS
+// EXPECT_THAT(1, ElementOf({0, 1, 2}));
+//
+// // FAIL
+// // Value of: 3
+// // Expected: one of {0, 1, 2}
+// // Actual: 3
+// EXPECT_THAT(3, ElementOf({0, 1, 2}));
+//
+// SyscallSucceeds() matches if the syscall is successful. A successful syscall
+// is defined by either a return value not equal to -1, or a return value of -1
+// with an errno of 0 (which is a possible successful return for e.g.
+// PTRACE_PEEK). Example:
+//
+// // PASS
+// EXPECT_THAT(open("/dev/null", O_RDONLY), SyscallSucceeds());
+//
+// // FAIL
+// // Value of: open("/", O_RDWR)
+// // Expected: not -1 (success)
+// // Actual: -1 (of type int), with errno 21 (Is a directory)
+// EXPECT_THAT(open("/", O_RDWR), SyscallSucceeds());
+//
+// SyscallSucceedsWithValue(m) matches if the syscall is successful, and the
+// value also matches m. Example:
+//
+// // PASS
+// EXPECT_THAT(read(4, buf, 8192), SyscallSucceedsWithValue(8192));
+//
+// // FAIL
+// // Value of: read(-1, buf, 8192)
+// // Expected: is equal to 8192
+// // Actual: -1 (of type long), with errno 9 (Bad file number)
+// EXPECT_THAT(read(-1, buf, 8192), SyscallSucceedsWithValue(8192));
+//
+// // FAIL
+// // Value of: read(4, buf, 1)
+// // Expected: is > 4096
+// // Actual: 1 (of type long)
+// EXPECT_THAT(read(4, buf, 1), SyscallSucceedsWithValue(Gt(4096)));
+//
+// SyscallFails() matches if the syscall is unsuccessful. An unsuccessful
+// syscall is defined by a return value of -1 with a non-zero errno. Example:
+//
+// // PASS
+// EXPECT_THAT(open("/", O_RDWR), SyscallFails());
+//
+// // FAIL
+// // Value of: open("/dev/null", O_RDONLY)
+// // Expected: -1 (failure)
+// // Actual: 0 (of type int)
+// EXPECT_THAT(open("/dev/null", O_RDONLY), SyscallFails());
+//
+// SyscallFailsWithErrno(m) matches if the syscall is unsuccessful, and errno
+// matches m. Example:
+//
+// // PASS
+// EXPECT_THAT(open("/", O_RDWR), SyscallFailsWithErrno(EISDIR));
+//
+// // PASS
+// EXPECT_THAT(open("/etc/passwd", O_RDWR | O_DIRECTORY),
+// SyscallFailsWithErrno(AnyOf(EACCES, ENOTDIR)));
+//
+// // FAIL
+// // Value of: open("/dev/null", O_RDONLY)
+// // Expected: -1 (failure) with errno 21 (Is a directory)
+// // Actual: 0 (of type int)
+// EXPECT_THAT(open("/dev/null", O_RDONLY), SyscallFailsWithErrno(EISDIR));
+//
+// // FAIL
+// // Value of: open("/", O_RDWR)
+// // Expected: -1 (failure) with errno 22 (Invalid argument)
+// // Actual: -1 (of type int), failure, but with errno 21 (Is a directory)
+// EXPECT_THAT(open("/", O_RDWR), SyscallFailsWithErrno(EINVAL));
+//
+// Because the syscall matchers encode save/restore functionality, their meaning
+// should not be inverted via Not. That is, AnyOf(SyscallSucceedsWithValue(1),
+// SyscallSucceedsWithValue(2)) is permitted, but not
+// Not(SyscallFailsWithErrno(EPERM)).
+//
+// Syscalls
+// ========
+//
+// RetryEINTR wraps a function that returns -1 and sets errno on failure
+// to be automatically retried when EINTR occurs. Example:
+//
+// auto rv = RetryEINTR(waitpid)(pid, &status, 0);
+//
+// ReadFd/WriteFd/PreadFd/PwriteFd are interface-compatible wrappers around the
+// read/write/pread/pwrite syscalls to handle both EINTR and partial
+// reads/writes. Example:
+//
+// EXPECT_THAT(ReadFd(fd, &buf, size), SyscallSucceedsWithValue(size));
+//
+// General Utilities
+// =================
+//
+// ApplyVec(f, xs) returns a vector containing the result of applying function
+// `f` to each value in `xs`.
+//
+// AllBitwiseCombinations takes a variadic number of ranges containing integers
+// and returns a vector containing every integer that can be formed by ORing
+// together exactly one integer from each list. List<T> is an alias for
+// std::initializer_list<T> that makes AllBitwiseCombinations more ergonomic to
+// use with list literals (initializer lists do not otherwise participate in
+// template argument deduction). Example:
+//
+// EXPECT_THAT(
+// AllBitwiseCombinations<int>(
+// List<int>{SOCK_DGRAM, SOCK_STREAM},
+// List<int>{0, SOCK_NONBLOCK}),
+// Contains({SOCK_DGRAM, SOCK_STREAM, SOCK_DGRAM | SOCK_NONBLOCK,
+// SOCK_STREAM | SOCK_NONBLOCK}));
+//
+// VecCat takes a variadic number of containers and returns a vector containing
+// the concatenated contents.
+//
+// VecAppend takes an initial container and a variadic number of containers and
+// appends each to the initial container.
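+//
+// A minimal sketch of VecCat (assuming gMock's ElementsAre; the values are
+// illustrative):
+//
+//   std::vector<int> a = {1, 2};
+//   std::vector<int> b = {3};
+//   EXPECT_THAT(VecCat<int>(a, b), ElementsAre(1, 2, 3));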
+//
+// RandomizeBuffer will use MTRandom to fill the given buffer with random bytes.
+//
+// GenerateIovecs returns the smallest number of iovec arrays needed to write
+// a given total number of bytes to a file, with each array holding up to
+// IOV_MAX iovecs and every iovec in every array pointing to the same buffer.
+
+#ifndef GVISOR_TEST_UTIL_TEST_UTIL_H_
+#define GVISOR_TEST_UTIL_TEST_UTIL_H_
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cerrno>
+#include <initializer_list>
+#include <iterator>
+#include <string>
+#include <thread> // NOLINT: using std::thread::hardware_concurrency().
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+constexpr char kTestOnGvisor[] = "TEST_ON_GVISOR";
+
+// TestInit must be called prior to RUN_ALL_TESTS.
+//
+// This parses all arguments and adjusts argc and argv appropriately.
+//
+// TestInit may create background threads.
+void TestInit(int* argc, char*** argv);
+
+// SKIP_IF may be used to skip a test case.
+//
+// Skipped cases are still emitted in the test output, but marked with a
+// SKIPPED line.
+#define SKIP_IF(expr) \
+ do { \
+ if (expr) GTEST_SKIP() << #expr; \
+ } while (0)
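+
+// Example usage (a sketch; the test names are illustrative):
+//
+//   TEST(Foo, Bar) {
+//     SKIP_IF(!IsRunningOnGvisor());
+//     ...
+//   }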
+
+// Platform contains platform names.
+namespace Platform {
+constexpr char kNative[] = "native";
+constexpr char kPtrace[] = "ptrace";
+constexpr char kKVM[] = "kvm";
+constexpr char kFuchsia[] = "fuchsia";
+} // namespace Platform
+
+bool IsRunningOnGvisor();
+const std::string GvisorPlatform();
+bool IsRunningWithHostinet();
+// TODO(gvisor.dev/issue/1624): Delete once VFS1 is gone.
+bool IsRunningWithVFS1();
+
+#ifdef __linux__
+void SetupGvisorDeathTest();
+#endif
+
+struct KernelVersion {
+ int major;
+ int minor;
+ int micro;
+};
+
+bool operator==(const KernelVersion& first, const KernelVersion& second);
+
+PosixErrorOr<KernelVersion> ParseKernelVersion(absl::string_view vers_string);
+PosixErrorOr<KernelVersion> GetKernelVersion();
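+
+// For illustration (a sketch; the version string is arbitrary):
+//
+//   PosixErrorOr<KernelVersion> v = ParseKernelVersion("4.18.10-amd64");
+//   // On success, v contains {4, 18, 10}.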
+
+static const size_t kPageSize = sysconf(_SC_PAGESIZE);
+
+enum class CPUVendor { kIntel, kAMD, kUnknownVendor };
+
+CPUVendor GetCPUVendor();
+
+inline int NumCPUs() { return std::thread::hardware_concurrency(); }
+
+// Converts cpu_set_t to a std::string for easy examination.
+std::string CPUSetToString(const cpu_set_t& set, size_t cpus = CPU_SETSIZE);
+
+struct OpenFd {
+ // fd is the open file descriptor number.
+ int fd = -1;
+
+ // link is the resolution of the symbolic link.
+ std::string link;
+};
+
+// Make it easier to log OpenFds to error streams.
+std::ostream& operator<<(std::ostream& out, std::vector<OpenFd> const& v);
+std::ostream& operator<<(std::ostream& out, OpenFd const& ofd);
+
+// Gets a detailed list of open fds for this process.
+PosixErrorOr<std::vector<OpenFd>> GetOpenFDs();
+
+// Returns the number of hard links to a path.
+PosixErrorOr<uint64_t> Links(const std::string& path);
+
+namespace internal {
+
+template <typename Container>
+class ElementOfMatcher {
+ public:
+ explicit ElementOfMatcher(Container container)
+ : container_(::std::move(container)) {}
+
+ template <typename T>
+ bool MatchAndExplain(T const& rv,
+ ::testing::MatchResultListener* const listener) const {
+ using std::count;
+ return count(container_.begin(), container_.end(), rv) != 0;
+ }
+
+ void DescribeTo(::std::ostream* const os) const {
+ *os << "one of {";
+ char const* sep = "";
+ for (auto const& elem : container_) {
+ *os << sep << elem;
+ sep = ", ";
+ }
+ *os << "}";
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const {
+ *os << "none of {";
+ char const* sep = "";
+ for (auto const& elem : container_) {
+ *os << sep << elem;
+ sep = ", ";
+ }
+ *os << "}";
+ }
+
+ private:
+ Container const container_;
+};
+
+template <typename E>
+class SyscallSuccessMatcher {
+ public:
+ explicit SyscallSuccessMatcher(E expected)
+ : expected_(::std::move(expected)) {}
+
+ template <typename T>
+ operator ::testing::Matcher<T>() const {
+ // E is one of three things:
+ // - T, or a type losslessly and implicitly convertible to T.
+ // - A monomorphic Matcher<T>.
+ // - A polymorphic matcher.
+ // SafeMatcherCast handles any of the above correctly.
+ //
+ // Similarly, gMock will invoke this conversion operator to obtain a
+ // monomorphic matcher (this is how polymorphic matchers are implemented).
+ return ::testing::MakeMatcher(
+ new Impl<T>(::testing::SafeMatcherCast<T>(expected_)));
+ }
+
+ private:
+ template <typename T>
+ class Impl : public ::testing::MatcherInterface<T> {
+ public:
+ explicit Impl(::testing::Matcher<T> matcher)
+ : matcher_(::std::move(matcher)) {}
+
+ bool MatchAndExplain(
+ T const& rv,
+ ::testing::MatchResultListener* const listener) const override {
+ if (rv == static_cast<decltype(rv)>(-1) && errno != 0) {
+ *listener << "with errno " << PosixError(errno);
+ return false;
+ }
+ bool match = matcher_.MatchAndExplain(rv, listener);
+ if (match) {
+ MaybeSave();
+ }
+ return match;
+ }
+
+ void DescribeTo(::std::ostream* const os) const override {
+ matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const override {
+ matcher_.DescribeNegationTo(os);
+ }
+
+ private:
+ ::testing::Matcher<T> matcher_;
+ };
+
+ E expected_;
+};
+
+// A polymorphic matcher equivalent to ::testing::internal::AnyMatcher, except
+// not in namespace ::testing::internal, and describing SyscallSucceeds()'s
+// match constraints (which are enforced by SyscallSuccessMatcher::Impl).
+class AnySuccessValueMatcher {
+ public:
+ template <typename T>
+ operator ::testing::Matcher<T>() const {
+ return ::testing::MakeMatcher(new Impl<T>());
+ }
+
+ private:
+ template <typename T>
+ class Impl : public ::testing::MatcherInterface<T> {
+ public:
+ bool MatchAndExplain(
+ T const& rv,
+ ::testing::MatchResultListener* const listener) const override {
+ return true;
+ }
+
+ void DescribeTo(::std::ostream* const os) const override {
+ *os << "not -1 (success)";
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const override {
+ *os << "-1 (failure)";
+ }
+ };
+};
+
+class SyscallFailureMatcher {
+ public:
+ explicit SyscallFailureMatcher(::testing::Matcher<int> errno_matcher)
+ : errno_matcher_(std::move(errno_matcher)) {}
+
+ template <typename T>
+ bool MatchAndExplain(T const& rv,
+ ::testing::MatchResultListener* const listener) const {
+ if (rv != static_cast<decltype(rv)>(-1)) {
+ return false;
+ }
+ int actual_errno = errno;
+ *listener << "with errno " << PosixError(actual_errno);
+ bool match = errno_matcher_.MatchAndExplain(actual_errno, listener);
+ if (match) {
+ MaybeSave();
+ }
+ return match;
+ }
+
+ void DescribeTo(::std::ostream* const os) const {
+ *os << "-1 (failure), with errno ";
+ errno_matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const {
+ *os << "not -1 (success), with errno ";
+ errno_matcher_.DescribeNegationTo(os);
+ }
+
+ private:
+ ::testing::Matcher<int> errno_matcher_;
+};
+
+class SpecificErrnoMatcher : public ::testing::MatcherInterface<int> {
+ public:
+ explicit SpecificErrnoMatcher(int const expected) : expected_(expected) {}
+
+ bool MatchAndExplain(
+ int const actual_errno,
+ ::testing::MatchResultListener* const listener) const override {
+ return actual_errno == expected_;
+ }
+
+ void DescribeTo(::std::ostream* const os) const override {
+ *os << PosixError(expected_);
+ }
+
+ void DescribeNegationTo(::std::ostream* const os) const override {
+ *os << "not " << PosixError(expected_);
+ }
+
+ private:
+ int const expected_;
+};
+
+inline ::testing::Matcher<int> SpecificErrno(int const expected) {
+ return ::testing::MakeMatcher(new SpecificErrnoMatcher(expected));
+}
+
+} // namespace internal
+
+template <typename Container>
+inline ::testing::PolymorphicMatcher<internal::ElementOfMatcher<Container>>
+ElementOf(Container container) {
+ return ::testing::MakePolymorphicMatcher(
+ internal::ElementOfMatcher<Container>(::std::move(container)));
+}
+
+template <typename T>
+inline ::testing::PolymorphicMatcher<
+ internal::ElementOfMatcher<::std::vector<T>>>
+ElementOf(::std::initializer_list<T> elems) {
+ return ::testing::MakePolymorphicMatcher(
+ internal::ElementOfMatcher<::std::vector<T>>(::std::vector<T>(elems)));
+}
+
+template <typename E>
+inline internal::SyscallSuccessMatcher<E> SyscallSucceedsWithValue(E expected) {
+ return internal::SyscallSuccessMatcher<E>(::std::move(expected));
+}
+
+inline internal::SyscallSuccessMatcher<internal::AnySuccessValueMatcher>
+SyscallSucceeds() {
+ return SyscallSucceedsWithValue(
+ ::gvisor::testing::internal::AnySuccessValueMatcher());
+}
+
+inline ::testing::PolymorphicMatcher<internal::SyscallFailureMatcher>
+SyscallFailsWithErrno(::testing::Matcher<int> expected) {
+ return ::testing::MakePolymorphicMatcher(
+ internal::SyscallFailureMatcher(::std::move(expected)));
+}
+
+// Overload taking an int so that SyscallFailsWithErrno(<specific errno>) uses
+// internal::SpecificErrno (which stringifies the errno) rather than
+// ::testing::Eq (which doesn't).
+inline ::testing::PolymorphicMatcher<internal::SyscallFailureMatcher>
+SyscallFailsWithErrno(int const expected) {
+ return SyscallFailsWithErrno(internal::SpecificErrno(expected));
+}
+
+inline ::testing::PolymorphicMatcher<internal::SyscallFailureMatcher>
+SyscallFails() {
+ return SyscallFailsWithErrno(::testing::Gt(0));
+}
+
+// As of GCC 7.2, -Wall => -Wc++17-compat => -Wnoexcept-type generates an
+// irrelevant, non-actionable warning about ABI compatibility when
+// RetryEINTRImpl is constructed with a noexcept function, such as glibc's
+// syscall(). See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80985.
+#if defined(__GNUC__) && !defined(__clang__) && \
+ (__GNUC__ > 7 || (__GNUC__ == 7 && __GNUC_MINOR__ >= 2))
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wnoexcept-type"
+#endif
+
+namespace internal {
+
+template <typename F>
+struct RetryEINTRImpl {
+ F const f;
+
+ explicit constexpr RetryEINTRImpl(F f) : f(std::move(f)) {}
+
+ template <typename... Args>
+ auto operator()(Args&&... args) const
+ -> decltype(f(std::forward<Args>(args)...)) {
+ while (true) {
+ errno = 0;
+ auto const ret = f(std::forward<Args>(args)...);
+ if (ret != -1 || errno != EINTR) {
+ return ret;
+ }
+ }
+ }
+};
+
+} // namespace internal
+
+template <typename F>
+constexpr internal::RetryEINTRImpl<F> RetryEINTR(F&& f) {
+ return internal::RetryEINTRImpl<F>(std::forward<F>(f));
+}
+
+#if defined(__GNUC__) && !defined(__clang__) && \
+ (__GNUC__ > 7 || (__GNUC__ == 7 && __GNUC_MINOR__ >= 2))
+#pragma GCC diagnostic pop
+#endif
+
+namespace internal {
+
+template <typename F>
+ssize_t ApplyFileIoSyscall(F const& f, size_t const count) {
+ size_t completed = 0;
+ // `do ... while` because some callers actually want to make a syscall with a
+ // count of 0.
+ do {
+ auto const cur = RetryEINTR(f)(completed);
+ if (cur < 0) {
+ return cur;
+ } else if (cur == 0) {
+ break;
+ }
+ completed += cur;
+ } while (completed < count);
+ return completed;
+}
+
+} // namespace internal
+
+inline ssize_t ReadFd(int fd, void* buf, size_t count) {
+ return internal::ApplyFileIoSyscall(
+ [&](size_t completed) {
+ return read(fd, static_cast<char*>(buf) + completed, count - completed);
+ },
+ count);
+}
+
+inline ssize_t WriteFd(int fd, void const* buf, size_t count) {
+ return internal::ApplyFileIoSyscall(
+ [&](size_t completed) {
+ return write(fd, static_cast<char const*>(buf) + completed,
+ count - completed);
+ },
+ count);
+}
+
+inline ssize_t PreadFd(int fd, void* buf, size_t count, off_t offset) {
+ return internal::ApplyFileIoSyscall(
+ [&](size_t completed) {
+ return pread(fd, static_cast<char*>(buf) + completed, count - completed,
+ offset + completed);
+ },
+ count);
+}
+
+inline ssize_t PwriteFd(int fd, void const* buf, size_t count, off_t offset) {
+ return internal::ApplyFileIoSyscall(
+ [&](size_t completed) {
+ return pwrite(fd, static_cast<char const*>(buf) + completed,
+ count - completed, offset + completed);
+ },
+ count);
+}
+
+template <typename T>
+using List = std::initializer_list<T>;
+
+namespace internal {
+
+template <typename T>
+void AppendAllBitwiseCombinations(std::vector<T>* combinations, T current) {
+ combinations->push_back(current);
+}
+
+template <typename T, typename Arg, typename... Args>
+void AppendAllBitwiseCombinations(std::vector<T>* combinations, T current,
+ Arg&& next, Args&&... rest) {
+ for (auto const option : next) {
+ AppendAllBitwiseCombinations(combinations, current | option, rest...);
+ }
+}
+
+inline size_t CombinedSize(size_t accum) { return accum; }
+
+template <typename T, typename... Args>
+size_t CombinedSize(size_t accum, T const& x, Args&&... xs) {
+ return CombinedSize(accum + x.size(), std::forward<Args>(xs)...);
+}
+
+// Base case: no more containers, so do nothing.
+template <typename T>
+void DoMoveExtendContainer(T* c) {}
+
+// Append each container next to c.
+template <typename T, typename U, typename... Args>
+void DoMoveExtendContainer(T* c, U&& next, Args&&... rest) {
+ std::move(std::begin(next), std::end(next), std::back_inserter(*c));
+ DoMoveExtendContainer(c, std::forward<Args>(rest)...);
+}
+
+} // namespace internal
+
+template <typename T = int>
+std::vector<T> AllBitwiseCombinations() {
+ return std::vector<T>();
+}
+
+template <typename T = int, typename... Args>
+std::vector<T> AllBitwiseCombinations(Args&&... args) {
+ std::vector<T> combinations;
+ internal::AppendAllBitwiseCombinations(&combinations, 0, args...);
+ return combinations;
+}
+
+template <typename T, typename U, typename F>
+std::vector<T> ApplyVec(F const& f, std::vector<U> const& us) {
+ std::vector<T> vec;
+ vec.reserve(us.size());
+ for (auto const& u : us) {
+ vec.push_back(f(u));
+ }
+ return vec;
+}
+
+template <typename T, typename U>
+std::vector<T> ApplyVecToVec(std::vector<std::function<T(U)>> const& fs,
+ std::vector<U> const& us) {
+ std::vector<T> vec;
+ vec.reserve(us.size() * fs.size());
+ for (auto const& f : fs) {
+ for (auto const& u : us) {
+ vec.push_back(f(u));
+ }
+ }
+ return vec;
+}
+
+// Moves all elements from the containers `args` to the end of `c`.
+template <typename T, typename... Args>
+void VecAppend(T* c, Args&&... args) {
+ c->reserve(internal::CombinedSize(c->size(), args...));
+ internal::DoMoveExtendContainer(c, std::forward<Args>(args)...);
+}
+
+// Returns a vector containing the concatenated contents of the containers
+// `args`.
+template <typename T, typename... Args>
+std::vector<T> VecCat(Args&&... args) {
+ std::vector<T> combined;
+ VecAppend(&combined, std::forward<Args>(args)...);
+ return combined;
+}
+
+#define RETURN_ERROR_IF_SYSCALL_FAIL(syscall) \
+ do { \
+ if ((syscall) < 0 && errno != 0) { \
+ return PosixError(errno, #syscall); \
+ } \
+ } while (false)
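+
+// Example usage (a sketch; the wrapper function is hypothetical, and NoError
+// is assumed from posix_error.h):
+//
+//   PosixError UnlinkAll(const char* path) {
+//     RETURN_ERROR_IF_SYSCALL_FAIL(unlink(path));
+//     return NoError();
+//   }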
+
+// Fill the given buffer with random bytes.
+void RandomizeBuffer(void* buffer, size_t len);
+
+template <typename T>
+inline PosixErrorOr<T> Atoi(absl::string_view str) {
+ T ret;
+ if (!absl::SimpleAtoi<T>(str, &ret)) {
+ return PosixError(EINVAL, "String not a number.");
+ }
+ return ret;
+}
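+
+// For illustration (a sketch):
+//
+//   PosixErrorOr<int> n = Atoi<int>("42");  // On success, contains 42.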
+
+inline PosixErrorOr<uint64_t> AtoiBase(absl::string_view str, int base) {
+ if (base > 255 || base < 2) {
+ return PosixError(EINVAL, "Invalid Base");
+ }
+
+ uint64_t ret = 0;
+ if (!absl::numbers_internal::safe_strtou64_base(str, &ret, base)) {
+ return PosixError(EINVAL, "String not a number.");
+ }
+
+ return ret;
+}
+
+inline PosixErrorOr<double> Atod(absl::string_view str) {
+ double ret;
+ if (!absl::SimpleAtod(str, &ret)) {
+ return PosixError(EINVAL, "String not a double type.");
+ }
+ return ret;
+}
+
+inline PosixErrorOr<float> Atof(absl::string_view str) {
+ float ret;
+ if (!absl::SimpleAtof(str, &ret)) {
+ return PosixError(EINVAL, "String not a float type.");
+ }
+ return ret;
+}
+
+// Returns the smallest number of iovec arrays that can be used to write
+// "total_size" bytes, with each iovec writing one "buf".
+std::vector<std::vector<struct iovec>> GenerateIovecs(uint64_t total_size,
+ void* buf, size_t buflen);
+
+// Returns the number of bytes in 'n' megabytes. Used for readability.
+uint64_t Megabytes(uint64_t n);
+
+// Predicate for checking that a value is within some tolerance of another
+// value. Returns true iff current is in the range [target * (1 - tolerance),
+// target * (1 + tolerance)].
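+//
+// For example, Equivalent(98, 100, 0.05) is true, since 98 lies within
+// [95, 105].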
+bool Equivalent(uint64_t current, uint64_t target, double tolerance);
+
+// Matcher wrapping the Equivalent predicate.
+MATCHER_P2(EquivalentWithin, target, tolerance,
+ std::string(negation ? "Isn't" : "Is") +
+ ::absl::StrFormat(" within %.2f%% of the target of %zd bytes",
+ tolerance * 100, target)) {
+ if (target == 0) {
+ *result_listener << ::absl::StreamFormat("difference of infinity%%");
+ } else {
+ int64_t delta = static_cast<int64_t>(arg) - static_cast<int64_t>(target);
+ double delta_percent =
+ static_cast<double>(delta) / static_cast<double>(target) * 100;
+ *result_listener << ::absl::StreamFormat("difference of %.2f%%",
+ delta_percent);
+ }
+ return Equivalent(arg, target, tolerance);
+}
+
+// Returns the absolute path to a data dependency. 'path' is the runfile
+// location relative to the workspace root.
+#ifdef __linux__
+std::string RunfilePath(std::string path);
+#endif
+
+void TestInit(int* argc, char*** argv);
+int RunAllTests(void);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_TEST_UTIL_H_
diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc
new file mode 100644
index 000000000..7e1ad9e66
--- /dev/null
+++ b/test/util/test_util_impl.cc
@@ -0,0 +1,52 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+extern bool FLAGS_benchmark_list_tests;
+extern std::string FLAGS_benchmark_filter;
+
+namespace gvisor {
+namespace testing {
+
+void SetupGvisorDeathTest() {}
+
+void TestInit(int* argc, char*** argv) {
+ ::testing::InitGoogleTest(argc, *argv);
+ benchmark::Initialize(argc, *argv);
+ ::absl::ParseCommandLine(*argc, *argv);
+
+ // Always mask SIGPIPE as it's common and tests aren't expected to handle it.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0);
+}
+
+int RunAllTests() {
+ if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") {
+ benchmark::RunSpecifiedBenchmarks();
+ return 0;
+ } else {
+ return RUN_ALL_TESTS();
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc
new file mode 100644
index 000000000..694d21692
--- /dev/null
+++ b/test/util/test_util_runfiles.cc
@@ -0,0 +1,50 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __fuchsia__
+
+#include <iostream>
+#include <string>
+
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+#include "tools/cpp/runfiles/runfiles.h"
+
+namespace gvisor {
+namespace testing {
+
+std::string RunfilePath(std::string path) {
+ static const bazel::tools::cpp::runfiles::Runfiles* const runfiles = [] {
+ std::string error;
+ auto* runfiles =
+ bazel::tools::cpp::runfiles::Runfiles::CreateForTest(&error);
+ if (runfiles == nullptr) {
+ std::cerr << "Unable to find runfiles: " << error << std::endl;
+ }
+ return runfiles;
+ }();
+
+ if (!runfiles) {
+ // Can't find runfiles? This probably won't work, but __main__/path is our
+ // best guess.
+ return JoinPath("__main__", path);
+ }
+
+ return runfiles->Rlocation(JoinPath("__main__", path));
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // __fuchsia__
diff --git a/test/util/test_util_test.cc b/test/util/test_util_test.cc
new file mode 100644
index 000000000..f42100374
--- /dev/null
+++ b/test/util/test_util_test.cc
@@ -0,0 +1,251 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/test_util.h"
+
+#include <errno.h>
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::AnyOf;
+using ::testing::Gt;
+using ::testing::IsEmpty;
+using ::testing::Lt;
+using ::testing::Not;
+using ::testing::TypedEq;
+using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(KernelVersionParsing, ValidateParsing) {
+ KernelVersion v = ASSERT_NO_ERRNO_AND_VALUE(
+ ParseKernelVersion("4.18.10-1foo2-amd64 baz blah"));
+ ASSERT_TRUE(v == KernelVersion({4, 18, 10}));
+
+ v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10-1foo2-amd64"));
+ ASSERT_TRUE(v == KernelVersion({4, 18, 10}));
+
+ v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10-14-amd64"));
+ ASSERT_TRUE(v == KernelVersion({4, 18, 10}));
+
+ v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10-amd64"));
+ ASSERT_TRUE(v == KernelVersion({4, 18, 10}));
+
+ v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.18.10"));
+ ASSERT_TRUE(v == KernelVersion({4, 18, 10}));
+
+ v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.0.10"));
+ ASSERT_TRUE(v == KernelVersion({4, 0, 10}));
+
+ v = ASSERT_NO_ERRNO_AND_VALUE(ParseKernelVersion("4.0"));
+ ASSERT_TRUE(v == KernelVersion({4, 0, 0}));
+
+ ASSERT_THAT(ParseKernelVersion("4.a"), PosixErrorIs(EINVAL, ::testing::_));
+ ASSERT_THAT(ParseKernelVersion("3"), PosixErrorIs(EINVAL, ::testing::_));
+ ASSERT_THAT(ParseKernelVersion(""), PosixErrorIs(EINVAL, ::testing::_));
+ ASSERT_THAT(ParseKernelVersion("version 3.3.10"),
+ PosixErrorIs(EINVAL, ::testing::_));
+}
+
+TEST(MatchersTest, SyscallSucceeds) {
+ EXPECT_THAT(0, SyscallSucceeds());
+ EXPECT_THAT(0L, SyscallSucceeds());
+
+ errno = 0;
+ EXPECT_THAT(-1, SyscallSucceeds());
+ EXPECT_THAT(-1L, SyscallSucceeds());
+
+ errno = ENOMEM;
+ EXPECT_THAT(-1, Not(SyscallSucceeds()));
+ EXPECT_THAT(-1L, Not(SyscallSucceeds()));
+}
+
+TEST(MatchersTest, SyscallSucceedsWithValue) {
+ EXPECT_THAT(0, SyscallSucceedsWithValue(0));
+ EXPECT_THAT(1, SyscallSucceedsWithValue(Lt(3)));
+ EXPECT_THAT(-1, Not(SyscallSucceedsWithValue(Lt(3))));
+ EXPECT_THAT(4, Not(SyscallSucceedsWithValue(Lt(3))));
+
+ // Non-int -1
+ EXPECT_THAT(-1L, Not(SyscallSucceedsWithValue(0)));
+
+ // Non-int, truncates to -1 if converted to int, with expected value
+ EXPECT_THAT(0xffffffffL, SyscallSucceedsWithValue(0xffffffffL));
+
+ // Non-int, truncates to -1 if converted to int, with monomorphic matcher
+ EXPECT_THAT(0xffffffffL,
+ SyscallSucceedsWithValue(TypedEq<long>(0xffffffffL)));
+
+ // Non-int, truncates to -1 if converted to int, with polymorphic matcher
+ EXPECT_THAT(0xffffffffL, SyscallSucceedsWithValue(Gt(1)));
+}
+
+TEST(MatchersTest, SyscallFails) {
+ EXPECT_THAT(0, Not(SyscallFails()));
+ EXPECT_THAT(0L, Not(SyscallFails()));
+
+ errno = 0;
+ EXPECT_THAT(-1, Not(SyscallFails()));
+ EXPECT_THAT(-1L, Not(SyscallFails()));
+
+ errno = ENOMEM;
+ EXPECT_THAT(-1, SyscallFails());
+ EXPECT_THAT(-1L, SyscallFails());
+}
+
+TEST(MatchersTest, SyscallFailsWithErrno) {
+ EXPECT_THAT(0, Not(SyscallFailsWithErrno(EINVAL)));
+ EXPECT_THAT(0L, Not(SyscallFailsWithErrno(EINVAL)));
+
+ errno = ENOMEM;
+ EXPECT_THAT(-1, Not(SyscallFailsWithErrno(EINVAL)));
+ EXPECT_THAT(-1L, Not(SyscallFailsWithErrno(EINVAL)));
+
+ errno = EINVAL;
+ EXPECT_THAT(-1, SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(-1L, SyscallFailsWithErrno(EINVAL));
+
+ EXPECT_THAT(-1, SyscallFailsWithErrno(AnyOf(EINVAL, ENOMEM)));
+ EXPECT_THAT(-1L, SyscallFailsWithErrno(AnyOf(EINVAL, ENOMEM)));
+
+ std::vector<int> expected_errnos({EINVAL, ENOMEM});
+ errno = ENOMEM;
+ EXPECT_THAT(-1, SyscallFailsWithErrno(ElementOf(expected_errnos)));
+ EXPECT_THAT(-1L, SyscallFailsWithErrno(ElementOf(expected_errnos)));
+}
+
+TEST(AllBitwiseCombinationsTest, NoArguments) {
+ EXPECT_THAT(AllBitwiseCombinations(), IsEmpty());
+}
+
+TEST(AllBitwiseCombinationsTest, EmptyList) {
+ EXPECT_THAT(AllBitwiseCombinations(List<int>{}), IsEmpty());
+}
+
+TEST(AllBitwiseCombinationsTest, SingleElementList) {
+ EXPECT_THAT(AllBitwiseCombinations(List<int>{5}), UnorderedElementsAre(5));
+}
+
+TEST(AllBitwiseCombinationsTest, SingleList) {
+ EXPECT_THAT(AllBitwiseCombinations(List<int>{0, 1, 2, 4}),
+ UnorderedElementsAre(0, 1, 2, 4));
+}
+
+TEST(AllBitwiseCombinationsTest, MultipleLists) {
+ EXPECT_THAT(
+ AllBitwiseCombinations(List<int>{0, 1, 2, 3}, List<int>{0, 4, 8, 12}),
+ UnorderedElementsAreArray(
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}));
+}
+
+TEST(RandomizeBuffer, Works) {
+ const std::vector<char> original(4096);
+ std::vector<char> buffer = original;
+ RandomizeBuffer(buffer.data(), buffer.size());
+ EXPECT_NE(buffer, original);
+}
+
+// Enable comparison of vectors of iovec arrays for the following test.
+MATCHER_P(IovecsListEq, expected, "") {
+ if (arg.size() != expected.size()) {
+ *result_listener << "sizes are different (actual: " << arg.size()
+ << ", expected: " << expected.size() << ")";
+ return false;
+ }
+
+ for (uint64_t i = 0; i < expected.size(); ++i) {
+ const std::vector<struct iovec>& actual_iovecs = arg[i];
+ const std::vector<struct iovec>& expected_iovecs = expected[i];
+ if (actual_iovecs.size() != expected_iovecs.size()) {
+ *result_listener << "iovec array size at position " << i
+ << " is different (actual: " << actual_iovecs.size()
+ << ", expected: " << expected_iovecs.size() << ")";
+ return false;
+ }
+
+ for (uint64_t j = 0; j < expected_iovecs.size(); ++j) {
+ const struct iovec& actual_iov = actual_iovecs[j];
+ const struct iovec& expected_iov = expected_iovecs[j];
+ if (actual_iov.iov_base != expected_iov.iov_base) {
+ *result_listener << "iovecs in array " << i << " at position " << j
+ << " are different (expected iov_base: "
+ << expected_iov.iov_base
+ << ", got: " << actual_iov.iov_base << ")";
+ return false;
+ }
+ if (actual_iov.iov_len != expected_iov.iov_len) {
+ *result_listener << "iovecs in array " << i << " at position " << j
+ << " are different (expected iov_len: "
+ << expected_iov.iov_len
+ << ", got: " << actual_iov.iov_len << ")";
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+// Verify empty iovec list generation.
+TEST(GenerateIovecs, EmptyList) {
+ std::vector<char> buffer = {'a', 'b', 'c'};
+
+ EXPECT_THAT(GenerateIovecs(0, buffer.data(), buffer.size()),
+ IovecsListEq(std::vector<std::vector<struct iovec>>()));
+}
+
+// Verify generating a single array of only one, partial, iovec.
+TEST(GenerateIovecs, OneArray) {
+ std::vector<char> buffer = {'a', 'b', 'c'};
+
+ std::vector<std::vector<struct iovec>> expected;
+ struct iovec iov = {};
+ iov.iov_base = buffer.data();
+ iov.iov_len = 2;
+ expected.push_back(std::vector<struct iovec>({iov}));
+ EXPECT_THAT(GenerateIovecs(2, buffer.data(), buffer.size()),
+ IovecsListEq(expected));
+}
+
+// Verify that it wraps around after IOV_MAX iovecs.
+TEST(GenerateIovecs, WrapsAtIovMax) {
+ std::vector<char> buffer = {'a', 'b', 'c'};
+
+ std::vector<std::vector<struct iovec>> expected;
+ struct iovec iov = {};
+ iov.iov_base = buffer.data();
+ iov.iov_len = buffer.size();
+ expected.emplace_back();
+ for (int i = 0; i < IOV_MAX; ++i) {
+ expected[0].push_back(iov);
+ }
+ iov.iov_len = 1;
+ expected.push_back(std::vector<struct iovec>({iov}));
+
+ EXPECT_THAT(
+ GenerateIovecs(IOV_MAX * buffer.size() + 1, buffer.data(), buffer.size()),
+ IovecsListEq(expected));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/thread_util.h b/test/util/thread_util.h
new file mode 100644
index 000000000..923c4fe10
--- /dev/null
+++ b/test/util/thread_util.h
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_THREAD_UTIL_H_
+#define GVISOR_TEST_UTIL_THREAD_UTIL_H_
+
+#include <pthread.h>
+#ifdef __linux__
+#include <sys/syscall.h>
+#endif
+#include <unistd.h>
+
+#include <functional>
+#include <utility>
+
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+// ScopedThread is a minimal wrapper around pthreads.
+//
+// This is used in lieu of more complex mechanisms because it provides very
+// predictable behavior (no messing with timers, etc.). The thread is
+// automatically joined when it is destroyed (goes out of scope), but it can
+// also be joined manually.
+class ScopedThread {
+ public:
+ // Constructs a thread that executes f exactly once.
+ explicit ScopedThread(std::function<void*()> f) : f_(std::move(f)) {
+ CreateThread();
+ }
+
+ explicit ScopedThread(const std::function<void()>& f) {
+ f_ = [=] {
+ f();
+ return nullptr;
+ };
+ CreateThread();
+ }
+
+ ScopedThread(const ScopedThread& other) = delete;
+ ScopedThread& operator=(const ScopedThread& other) = delete;
+
+ // Joins the thread.
+ ~ScopedThread() { Join(); }
+
+ // Waits until this thread has finished executing. Join is idempotent and may
+  // be called multiple times; however, Join itself is not thread-safe.
+ void* Join() {
+ if (!joined_) {
+ TEST_PCHECK(pthread_join(pt_, &retval_) == 0);
+ joined_ = true;
+ }
+ return retval_;
+ }
+
+ private:
+ void CreateThread() {
+ TEST_PCHECK_MSG(pthread_create(
+ &pt_, /* attr = */ nullptr,
+ +[](void* arg) -> void* {
+ return static_cast<ScopedThread*>(arg)->f_();
+ },
+ this) == 0,
+ "thread creation failed");
+ }
+
+ std::function<void*()> f_;
+ pthread_t pt_;
+ bool joined_ = false;
+ void* retval_ = nullptr;
+};
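+
+// Example usage (a sketch):
+//
+//   ScopedThread t([] { /* work runs on a new thread */ });
+//   t.Join();  // Optional; the destructor joins as well.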
+
+#ifdef __linux__
+inline pid_t gettid() { return syscall(SYS_gettid); }
+#endif
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_THREAD_UTIL_H_
diff --git a/test/util/time_util.cc b/test/util/time_util.cc
new file mode 100644
index 000000000..1ddfbfc9c
--- /dev/null
+++ b/test/util/time_util.cc
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/time_util.h"
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "absl/time/time.h"
+
+namespace gvisor {
+namespace testing {
+
+void SleepSafe(absl::Duration duration) {
+ if (duration == absl::ZeroDuration()) {
+ return;
+ }
+
+ struct timespec ts = absl::ToTimespec(duration);
+ int ret;
+ while (1) {
+ ret = syscall(__NR_nanosleep, &ts, &ts);
+ if (ret == 0 || (ret <= 0 && errno != EINTR)) {
+ break;
+ }
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/time_util.h b/test/util/time_util.h
new file mode 100644
index 000000000..f3ddc9fde
--- /dev/null
+++ b/test/util/time_util.h
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_TIME_UTIL_H_
+#define GVISOR_TEST_UTIL_TIME_UTIL_H_
+
+#include "absl/time/time.h"
+
+namespace gvisor {
+namespace testing {
+
+// Sleeps for at least the specified duration. Uses a raw nanosleep syscall
+// to avoid glibc.
+void SleepSafe(absl::Duration duration);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_TIME_UTIL_H_
diff --git a/test/util/timer_util.cc b/test/util/timer_util.cc
new file mode 100644
index 000000000..43a26b0d3
--- /dev/null
+++ b/test/util/timer_util.cc
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+absl::Time Now(clockid_t id) {
+ struct timespec now;
+ TEST_PCHECK(clock_gettime(id, &now) == 0);
+ return absl::TimeFromTimespec(now);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/timer_util.h b/test/util/timer_util.h
new file mode 100644
index 000000000..31aea4fc6
--- /dev/null
+++ b/test/util/timer_util.h
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_TIMER_UTIL_H_
+#define GVISOR_TEST_UTIL_TIMER_UTIL_H_
+
+#include <errno.h>
+#include <sys/time.h>
+
+#include <functional>
+
+#include "gmock/gmock.h"
+#include "absl/time/time.h"
+#include "test/util/cleanup.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// MonotonicTimer is a simple timer that uses a monotonic clock.
+class MonotonicTimer {
+ public:
+ MonotonicTimer() {}
+ absl::Duration Duration() {
+ struct timespec ts;
+ TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) == 0);
+ return absl::TimeFromTimespec(ts) - start_;
+ }
+
+ void Start() {
+ struct timespec ts;
+ TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) == 0);
+ start_ = absl::TimeFromTimespec(ts);
+ }
+
+ protected:
+ absl::Time start_;
+};
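+
+// Example usage (a sketch; DoWork is a hypothetical workload):
+//
+//   MonotonicTimer timer;
+//   timer.Start();
+//   DoWork();
+//   absl::Duration elapsed = timer.Duration();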
+
+// Sets the given itimer and returns a cleanup function that restores the
+// previous itimer when it goes out of scope.
+inline PosixErrorOr<Cleanup> ScopedItimer(int which,
+ struct itimerval const& new_value) {
+ struct itimerval old_value;
+ int rc = setitimer(which, &new_value, &old_value);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "setitimer failed");
+ }
+ return Cleanup(std::function<void(void)>([which, old_value] {
+ EXPECT_THAT(setitimer(which, &old_value, nullptr), SyscallSucceeds());
+ }));
+}
+
+// Returns the current time.
+absl::Time Now(clockid_t id);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_TIMER_UTIL_H_
diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc
new file mode 100644
index 000000000..b131b4b99
--- /dev/null
+++ b/test/util/uid_util.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<bool> IsRoot() {
+ uid_t ruid, euid, suid;
+ int rc = getresuid(&ruid, &euid, &suid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresuid");
+ }
+ if (ruid != 0 || euid != 0 || suid != 0) {
+ return false;
+ }
+ gid_t rgid, egid, sgid;
+ rc = getresgid(&rgid, &egid, &sgid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresgid");
+ }
+ if (rgid != 0 || egid != 0 || sgid != 0) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/uid_util.h b/test/util/uid_util.h
new file mode 100644
index 000000000..2cd387fb0
--- /dev/null
+++ b/test/util/uid_util.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns true if the caller's real/effective/saved user/group IDs are all 0.
+PosixErrorOr<bool> IsRoot();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_
diff --git a/tools/BUILD b/tools/BUILD
new file mode 100644
index 000000000..34b950644
--- /dev/null
+++ b/tools/BUILD
@@ -0,0 +1 @@
+package(licenses = ["notice"])
diff --git a/tools/bazel.mk b/tools/bazel.mk
new file mode 100644
index 000000000..9f4a40669
--- /dev/null
+++ b/tools/bazel.mk
@@ -0,0 +1,124 @@
+#!/usr/bin/make -f
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# See base Makefile.
+BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
+ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
+ xargs -n 1 basename 2>/dev/null)
+
+# Bazel container configuration (see below).
+USER ?= gvisor
+HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
+DOCKER_NAME ?= gvisor-bazel-$(HASH)
+DOCKER_PRIVILEGED ?= --privileged
+BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
+GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
+DOCKER_SOCKET := /var/run/docker.sock
+
+# Non-configurable.
+UID := $(shell id -u ${USER})
+GID := $(shell id -g ${USER})
+USERADD_OPTIONS :=
+FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS)
+FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)"
+FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)"
+FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
+ifneq ($(DOCKER_PRIVILEGED),)
+FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
+DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET))
+ifneq ($(GID),$(DOCKER_GROUP))
+USERADD_OPTIONS += --groups $(DOCKER_GROUP)
+GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) &&
+FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
+endif
+endif
+SHELL=/bin/bash -o pipefail
+
+##
+## Bazel helpers.
+##
+## This file supports targets that wrap bazel in a running Docker
+## container to simplify development. Some options are available to
+## control the behavior of this container:
+## USER - The in-container user.
+## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests).
+## DOCKER_NAME - The container name (default: gvisor-bazel-HASH).
+## BAZEL_CACHE - The bazel cache directory (default: detected).
+##     GCLOUD_CONFIG - The gcloud config directory (default: detected).
+## DOCKER_SOCKET - The Docker socket (default: detected).
+##
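+## A typical flow, as a sketch (the build target shown is illustrative):
+##   make bazel-server-start
+##   make build TARGETS=//runsc
+##   make bazel-shutdown
+##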
+bazel-server-start: load-default ## Starts the bazel server.
+ @mkdir -p $(BAZEL_CACHE)
+ @mkdir -p $(GCLOUD_CONFIG)
+ docker run -d --rm \
+ --init \
+ --name $(DOCKER_NAME) \
+ --user 0:0 $(DOCKER_GROUP_OPTIONS) \
+ -v "$(CURDIR):$(CURDIR)" \
+ --workdir "$(CURDIR)" \
+ --entrypoint "" \
+ $(FULL_DOCKER_RUN_OPTIONS) \
+ gvisor.dev/images/default \
+ sh -c "groupadd --gid $(GID) --non-unique $(USER) && \
+ $(GROUPADD_DOCKER) \
+ useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \
+ bazel version && \
+ exec tail --pid=\$$(bazel info server_pid) -f /dev/null"
+ @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; \
+ if ! docker ps | grep $(DOCKER_NAME); then exit 1; else sleep 1; fi; done
+.PHONY: bazel-server-start
+
+bazel-shutdown: ## Shuts down a running bazel server.
+ @docker exec --user $(UID):$(GID) $(DOCKER_NAME) bazel shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]]
+.PHONY: bazel-shutdown
+
+bazel-alias: ## Emits an alias that can be used within the shell.
+ @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'"
+.PHONY: bazel-alias
+
+bazel-server: ## Ensures that the server exists. Used as an internal target.
+ @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start
+.PHONY: bazel-server
+
+build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c 'bazel $(STARTUP_OPTIONS) build $(OPTIONS) $(TARGETS)'
+
+build_paths = $(build_cmd) 2>&1 \
+ | tee /proc/self/fd/2 \
+ | grep -E "^ bazel-bin/" \
+ | awk "{print $$1;}" \
+ | xargs -n 1 -I {} sh -c "$(1)"
+
+build: bazel-server
+ @$(call build_cmd)
+.PHONY: build
+
+copy: bazel-server
+ifeq (,$(DESTINATION))
+ $(error Destination not provided.)
+endif
+	@$(call build_paths,cp -a {} $(DESTINATION))
+.PHONY: copy
+
+run: bazel-server
+ @$(call build_paths,{} $(ARGS))
+.PHONY: run
+
+sudo: bazel-server
+ @$(call build_paths,sudo -E {} $(ARGS))
+.PHONY: sudo
+
+test: bazel-server
+ @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel $(STARTUP_OPTIONS) test $(OPTIONS) $(TARGETS)
+.PHONY: test
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
new file mode 100644
index 000000000..f2f80bae1
--- /dev/null
+++ b/tools/bazeldefs/BUILD
@@ -0,0 +1,51 @@
+load("//tools:defs.bzl", "rbe_platform", "rbe_toolchain")
+
+package(licenses = ["notice"])
+
+# In bazel, no special support is required for loopback networking. This is
+# just a dummy data target that does not change the test environment.
+genrule(
+ name = "loopback",
+ outs = ["loopback.txt"],
+ cmd = "touch $@",
+ visibility = ["//:sandbox"],
+)
+
+# We need to define a bazel platform and toolchain to specify dockerPrivileged
+# and dockerRunAsRoot options, which are required to run tests on the RBE
+# cluster in Kokoro.
+rbe_platform(
+ name = "rbe_ubuntu1604",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//tools/cpp:clang",
+ "@bazel_toolchains//constraints:xenial",
+ "@bazel_toolchains//constraints/sanitizers:support_msan",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50"
+ }
+ properties: {
+ name: "dockerAddCapabilities"
+ value: "SYS_ADMIN"
+ }
+ properties: {
+ name: "dockerPrivileged"
+ value: "true"
+ }
+ """,
+)
+
+rbe_toolchain(
+ name = "cc-toolchain-clang-x86_64-default",
+ exec_compatible_with = [],
+ tags = [
+ "manual",
+ ],
+ target_compatible_with = [],
+ toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
new file mode 100644
index 000000000..620c460de
--- /dev/null
+++ b/tools/bazeldefs/defs.bzl
@@ -0,0 +1,182 @@
+"""Bazel implementations of standard rules."""
+
+load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle")
+load("@bazel_skylib//rules:build_test.bzl", _build_test = "build_test")
+load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
+load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test")
+load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library")
+load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
+load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
+load("@pydeps//:requirements.bzl", _py_requirement = "requirement")
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
+
+build_test = _build_test
+cc_library = _cc_library
+cc_flags_supplier = _cc_flags_supplier
+cc_proto_library = _cc_proto_library
+cc_test = _cc_test
+cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
+gazelle = _gazelle
+go_embed_data = _go_embed_data
+go_path = _go_path
+gtest = "@com_google_googletest//:gtest"
+grpcpp = "@com_github_grpc_grpc//:grpc++"
+gbenchmark = "@com_google_benchmark//:benchmark"
+loopback = "//tools/bazeldefs:loopback"
+pkg_deb = _pkg_deb
+pkg_tar = _pkg_tar
+py_library = native.py_library
+py_binary = native.py_binary
+py_test = native.py_test
+rbe_platform = native.platform
+rbe_toolchain = native.toolchain
+vdso_linker_option = "-fuse-ld=gold "
+
+def proto_library(name, has_services = None, **kwargs):
+ native.proto_library(
+ name = name,
+ **kwargs
+ )
+
+def cc_grpc_library(name, **kwargs):
+ _cc_grpc_library(name = name, grpc_only = True, **kwargs)
+
+def _go_proto_or_grpc_library(go_library_func, name, **kwargs):
+ deps = [
+ dep.replace("_proto", "_go_proto")
+ for dep in (kwargs.pop("deps", []) or [])
+ ]
+ go_library_func(
+ name = name + "_go_proto",
+ importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto",
+ proto = ":" + name + "_proto",
+ deps = deps,
+ **kwargs
+ )
+
+def go_proto_library(name, **kwargs):
+ _go_proto_or_grpc_library(_go_proto_library, name, **kwargs)
+
+def go_grpc_and_proto_libraries(name, **kwargs):
+ _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
+
+def cc_binary(name, static = False, **kwargs):
+ """Run cc_binary.
+
+ Args:
+ name: name of the target.
+ static: make a static binary if True
+ **kwargs: the rest of the args.
+ """
+ if static:
+ # How to statically link a c++ program that uses threads, like for gRPC:
+ # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html
+ if "linkopts" not in kwargs:
+ kwargs["linkopts"] = []
+ kwargs["linkopts"] += [
+ "-static",
+ "-lstdc++",
+ "-Wl,--whole-archive",
+ "-lpthread",
+ "-Wl,--no-whole-archive",
+ ]
+ _cc_binary(
+ name = name,
+ **kwargs
+ )
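+
+# An illustrative BUILD usage of this wrapper (packages in this repository
+# load it via //tools:defs.bzl; the target and sources are hypothetical):
+#
+#   cc_binary(
+#       name = "demo",
+#       srcs = ["demo.cc"],
+#       static = True,
+#   )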
+
+def go_binary(name, static = False, pure = False, **kwargs):
+ """Build a go binary.
+
+ Args:
+ name: name of the target.
+ static: build a static binary.
+ pure: build without cgo.
+ **kwargs: rest of the arguments are passed to _go_binary.
+ """
+ if static:
+ kwargs["static"] = "on"
+ if pure:
+ kwargs["pure"] = "on"
+ _go_binary(
+ name = name,
+ **kwargs
+ )
+
+def go_importpath(target):
+ """Returns the importpath for the target."""
+ return target[GoLibrary].importpath
+
+def go_library(name, **kwargs):
+ _go_library(
+ name = name,
+ importpath = "gvisor.dev/gvisor/" + native.package_name(),
+ **kwargs
+ )
+
+def go_test(name, pure = False, library = None, **kwargs):
+ """Build a go test.
+
+ Args:
+ name: name of the output binary.
+ pure: should it be built without cgo.
+ library: the library to embed.
+ **kwargs: rest of the arguments to pass to _go_test.
+ """
+ if pure:
+ kwargs["pure"] = "on"
+ if library:
+ kwargs["embed"] = [library]
+ _go_test(
+ name = name,
+ **kwargs
+ )
+
+def go_rule(rule, implementation, **kwargs):
+ """Wraps a rule definition with Go attributes.
+
+ Args:
+ rule: rule function (typically rule or aspect).
+ implementation: implementation function.
+ **kwargs: other arguments to pass to rule.
+
+ Returns:
+ The result of invoking the rule.
+ """
+    attrs = kwargs.pop("attrs", {})
+ attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data")
+ attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib")
+ toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"]
+ return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs)
+
+def go_context(ctx):
+ go_ctx = _go_context(ctx)
+ return struct(
+ go = go_ctx.go,
+ env = go_ctx.env,
+ runfiles = depset([go_ctx.go] + go_ctx.sdk.tools + go_ctx.stdlib.libs),
+ goos = go_ctx.sdk.goos,
+ goarch = go_ctx.sdk.goarch,
+ tags = go_ctx.tags,
+ )
+
+def py_requirement(name, direct = True):
+ return _py_requirement(name)
+
+def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs):
+ values = {
+ "@bazel_tools//src/conditions:linux_x86_64": amd64,
+ "@bazel_tools//src/conditions:linux_aarch64": arm64,
+ }
+ if default:
+ values["//conditions:default"] = default
+ return select(values, **kwargs)
+
+def select_system(linux = ["__linux__"], **kwargs):
+ return linux # Only Linux supported.
+
+def default_installer():
+ return None
+
+def default_net_util():
+ return [] # Nothing needed.
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
new file mode 100644
index 000000000..165b22311
--- /dev/null
+++ b/tools/bazeldefs/platforms.bzl
@@ -0,0 +1,9 @@
+"""List of platforms."""
+
+# Platform to associated tags.
+platforms = {
+ "ptrace": [],
+ "kvm": [],
+}
+
+default_platform = "ptrace"
diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl
new file mode 100644
index 000000000..f5d7a7b21
--- /dev/null
+++ b/tools/bazeldefs/tags.bzl
@@ -0,0 +1,56 @@
+"""List of special Go suffixes."""
+
+def explode(tagset, suffixes):
+ """explode combines tagset and suffixes in all ways.
+
+ Args:
+ tagset: Original suffixes.
+ suffixes: Suffixes to combine before and after.
+
+ Returns:
+ The set of possible combinations.
+ """
+ result = [t for t in tagset]
+ result += [s for s in suffixes]
+ for t in tagset:
+ result += [t + s for s in suffixes]
+ result += [s + t for s in suffixes]
+ return result
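+
+# For example, explode(["_amd64"], ["_linux"]) returns
+# ["_amd64", "_linux", "_amd64_linux", "_linux_amd64"].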
+
+archs = [
+ "_386",
+ "_aarch64",
+ "_amd64",
+ "_arm",
+ "_arm64",
+ "_mips",
+ "_mips64",
+ "_mips64le",
+ "_mipsle",
+ "_ppc64",
+ "_ppc64le",
+ "_riscv64",
+ "_s390x",
+ "_sparc64",
+ "_x86",
+]
+
+oses = [
+ "_linux",
+]
+
+generic = [
+ "_impl",
+ "_race",
+ "_norace",
+ "_unsafe",
+ "_opts",
+]
+
+# State explosion? Sure. This is approximately:
+#   len(archs) * (1 + 2 * len(oses)) * (1 + 2 * len(generic))
+#
+# This evaluates to 495 at the time of writing. So it's a lot of different
+# combinations, but not so much that it will cause issues. We can probably add
+# quite a few more variants before this becomes a genuine problem.
+go_suffixes = explode(explode(archs, oses), generic)
diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD
new file mode 100644
index 000000000..5748fb390
--- /dev/null
+++ b/tools/bigquery/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bigquery",
+ testonly = 1,
+ srcs = ["bigquery.go"],
+ deps = ["@com_google_cloud_go_bigquery//:go_default_library"],
+)
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
new file mode 100644
index 000000000..56f0dc5c9
--- /dev/null
+++ b/tools/bigquery/bigquery.go
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bigquery defines a BigQuery schema for benchmarks.
+//
+// This package contains a schema for BigQuery and methods for publishing
+// benchmark data into tables.
+package bigquery
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ bq "cloud.google.com/go/bigquery"
+)
+
+// Benchmark is the top level structure of recorded benchmark data. BigQuery
+// will infer the schema from this.
+type Benchmark struct {
+ Name string `bq:"name"`
+ Timestamp time.Time `bq:"timestamp"`
+ Official bool `bq:"official"`
+ Metric []*Metric `bq:"metric"`
+ Metadata *Metadata `bq:"metadata"`
+}
+
+// Metric holds the actual metric data and unit information for this benchmark.
+type Metric struct {
+ Name string `bq:"name"`
+ Unit string `bq:"unit"`
+ Sample float64 `bq:"sample"`
+}
+
+// Metadata about this benchmark.
+type Metadata struct {
+ CL string `bq:"changelist"`
+ IterationID string `bq:"iteration_id"`
+ PendingCL string `bq:"pending_cl"`
+ Workflow string `bq:"workflow"`
+ Platform string `bq:"platform"`
+ Gofer string `bq:"gofer"`
+}
+
+// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated.
+func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error {
+ client, err := bq.NewClient(ctx, projectID)
+ if err != nil {
+ return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err)
+ }
+ defer client.Close()
+
+ dataset := client.Dataset(datasetID)
+ if err := dataset.Create(ctx, nil); err != nil && !checkDuplicateError(err) {
+ return fmt.Errorf("failed to create dataset: %s: %v", datasetID, err)
+ }
+
+ table := dataset.Table(tableID)
+ schema, err := bq.InferSchema(Benchmark{})
+ if err != nil {
+ return fmt.Errorf("failed to infer schema: %v", err)
+ }
+
+ if err := table.Create(ctx, &bq.TableMetadata{Schema: schema}); err != nil && !checkDuplicateError(err) {
+ return fmt.Errorf("failed to create table: %s: %v", tableID, err)
+ }
+ return nil
+}
+
+// AddMetric adds a metric to an existing Benchmark.
+func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) {
+ m := &Metric{
+ Name: metricName,
+ Unit: unit,
+ Sample: sample,
+ }
+ bm.Metric = append(bm.Metric, m)
+}
+
+// NewBenchmark initializes a new benchmark.
+func NewBenchmark(name string, official bool) *Benchmark {
+ return &Benchmark{
+ Name: name,
+ Timestamp: time.Now().UTC(),
+ Official: official,
+ Metric: make([]*Metric, 0),
+ }
+}
+
+// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table.
+func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error {
+ client, err := bq.NewClient(ctx, projectID)
+ if err != nil {
+ return fmt.Errorf("Failed to initialize client on project: %s: %v", projectID, err)
+ }
+ defer client.Close()
+
+ uploader := client.Dataset(datasetID).Table(tableID).Uploader()
+ if err = uploader.Put(ctx, benchmarks); err != nil {
+ return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err)
+ }
+
+ return nil
+}
+
+// checkDuplicateError reports whether err is the BigQuery 409 "Already
+// Exists" error returned for duplicate tables and datasets.
+func checkDuplicateError(err error) bool {
+ return strings.Contains(err.Error(), "googleapi: Error 409: Already Exists")
+}
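For reference, a minimal usage sketch of the API above; the project, dataset, and table IDs are hypothetical:

```go
package main

import (
	"context"
	"log"

	"gvisor.dev/gvisor/tools/bigquery"
)

func main() {
	ctx := context.Background()
	const project, dataset, table = "my-project", "benchmarks", "results"

	// Create the dataset/table if they don't already exist.
	if err := bigquery.InitBigQuery(ctx, project, dataset, table); err != nil {
		log.Fatalf("init: %v", err)
	}

	// Record a single benchmark with one metric and upload it.
	bm := bigquery.NewBenchmark("BenchmarkStartup", true /* official */)
	bm.AddMetric("startup_time", "ns", 1.2e9)
	if err := bigquery.SendBenchmarks(ctx, []*bigquery.Benchmark{bm}, project, dataset, table); err != nil {
		log.Fatalf("send: %v", err)
	}
}
```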
diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD
new file mode 100644
index 000000000..b8c3ddf44
--- /dev/null
+++ b/tools/checkescape/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "checkescape",
+ srcs = ["checkescape.go"],
+ nogo = False,
+ visibility = ["//tools/nogo:__subpackages__"],
+ deps = [
+ "//tools/nogo/data",
+ "@org_golang_x_tools//go/analysis:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library",
+ "@org_golang_x_tools//go/ssa:go_tool_library",
+ ],
+)
diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go
new file mode 100644
index 000000000..f8def4823
--- /dev/null
+++ b/tools/checkescape/checkescape.go
@@ -0,0 +1,726 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package checkescape allows recursive escape analysis for hot paths.
+//
+// The analysis tracks multiple types of escapes, in two categories. First,
+// 'hard' escapes are explicit allocations. Second, 'soft' escapes are
+// interface dispatches or dynamic function dispatches; these don't necessarily
+// escape but they *may* escape. The analysis is capable of making assertions
+// recursively: soft escapes cannot be analyzed in this way, and therefore
+// count as escapes for recursive purposes.
+//
+// The different types of escapes are as follows, with the category in
+// parentheses:
+//
+// heap: A direct allocation is made on the heap (hard).
+// builtin: A call is made to a built-in allocation function (hard).
+// stack: A stack split as part of a function preamble (soft).
+// interface: A call is made via an interface which *may* escape (soft).
+// dynamic: A dynamic function is dispatched which *may* escape (soft).
+//
+// To use the package, annotate a function-level comment with either the
+// line "// +checkescape" or "// +checkescape:OPTION[,OPTION]". In the second
+// case, the OPTION field is either a type above, or one of:
+//
+// local: Escape analysis is limited to local hard escapes only.
+// all: All the escapes are included.
+// hard: All hard escapes are included.
+//
+// If the "// +checkescape" annotation is provided, this is equivalent to
+// provided the local and hard options.
+//
+// Some examples of this syntax are:
+//
+// +checkescape:all - Analyzes for all escapes in this function and all calls.
+// +checkescape:local - Analyzes only for default local hard escapes.
+// +checkescape:heap - Only analyzes for heap escapes.
+// +checkescape:interface,dynamic - Only checks for dynamic calls and interface calls.
+// +checkescape - Does the same as +checkescape:local,hard.
+//
+// Note that all of the above can be inverted by using +mustescape. The
+// +checkescape keyword will ensure failure if the class of escape occurs,
+// whereas +mustescape will fail if the given class of escape does not occur.
+//
+// Local exemptions can be made by a comment of the form "// escapes: reason."
+// This must appear on the line of the escape, and it also applies to callers
+// of the function (for non-local escape analysis).
+package checkescape
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "io"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/buildssa"
+ "golang.org/x/tools/go/ssa"
+ "gvisor.dev/gvisor/tools/nogo/data"
+)
+
+const (
+ // magic is the magic annotation.
+ magic = "// +checkescape"
+
+ // magicParams is the magic annotation with specific parameters.
+ magicParams = magic + ":"
+
+ // testMagic is the test magic annotation (parameters required).
+ testMagic = "// +mustescape:"
+
+ // exempt is the exemption annotation.
+ exempt = "// escapes"
+)
+
+// escapingBuiltins are builtins known to escape.
+//
+// These are lowered at an earlier stage of compilation to explicit function
+// calls, but are not available for recursive analysis.
+var escapingBuiltins = []string{
+ "append",
+ "makemap",
+ "newobject",
+ "mallocgc",
+}
+
+// Analyzer defines the entrypoint.
+var Analyzer = &analysis.Analyzer{
+ Name: "checkescape",
+ Doc: "surfaces recursive escape analysis results",
+ Run: run,
+ Requires: []*analysis.Analyzer{buildssa.Analyzer},
+ FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)},
+}
+
+// packageEscapeFacts is the set of all functions in a package, and whether or
+// not they recursively pass escape analysis.
+//
+// All the type names for receivers are encoded in the full key. The key
+// represents the fully qualified package and type name used at link time.
+type packageEscapeFacts struct {
+ Funcs map[string][]Escape
+}
+
+// AFact implements analysis.Fact.AFact.
+func (*packageEscapeFacts) AFact() {}
+
+// CallSite is a single call site.
+//
+// These can be chained.
+type CallSite struct {
+ LocalPos token.Pos
+ Resolved LinePosition
+}
+
+// Escape is a single escape instance.
+type Escape struct {
+ Reason EscapeReason
+ Detail string
+ Chain []CallSite
+}
+
+// LinePosition is a low-resolution token.Position.
+//
+// This is used to match against possible exemptions placed in the source.
+type LinePosition struct {
+ Filename string
+ Line int
+}
+
+// String implements fmt.Stringer.String.
+func (e *LinePosition) String() string {
+ return fmt.Sprintf("%s:%d", e.Filename, e.Line)
+}
+
+// String implements fmt.Stringer.String.
+//
+// Note that this string will contain new lines.
+func (e *Escape) String() string {
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s", e.Reason.String())
+ for i, cs := range e.Chain {
+ if i == len(e.Chain)-1 {
+ fmt.Fprintf(&b, "\n @ %s → %s", cs.Resolved.String(), e.Detail)
+ } else {
+ fmt.Fprintf(&b, "\n + %s", cs.Resolved.String())
+ }
+ }
+ return b.String()
+}
+
+// EscapeReason is an escape reason.
+//
+// This is a simple enum.
+type EscapeReason int
+
+const (
+ interfaceInvoke EscapeReason = iota
+ unknownPackage
+ allocation
+ builtin
+ dynamicCall
+ stackSplit
+ reasonCount // Count for below.
+)
+
+// String returns the string for the EscapeReason.
+//
+// Note that this also implicitly defines the reverse string -> EscapeReason
+// mapping, which is the word before the colon (computed below).
+func (e EscapeReason) String() string {
+ switch e {
+ case interfaceInvoke:
+ return "interface: function invocation via interface"
+ case unknownPackage:
+ return "unknown: no package information available"
+ case allocation:
+ return "heap: call to runtime heap allocation"
+ case builtin:
+ return "builtin: call to runtime builtin"
+ case dynamicCall:
+ return "dynamic: call via dynamic function"
+ case stackSplit:
+ return "stack: stack split on function entry"
+ default:
+ panic(fmt.Sprintf("unknown reason: %d", e))
+ }
+}
+
+var hardReasons = []EscapeReason{
+ allocation,
+ builtin,
+}
+
+var softReasons = []EscapeReason{
+ interfaceInvoke,
+ unknownPackage,
+ dynamicCall,
+ stackSplit,
+}
+
+var allReasons = append(hardReasons, softReasons...)
+
+var escapeTypes = func() map[string]EscapeReason {
+ result := make(map[string]EscapeReason)
+ for _, r := range allReasons {
+ parts := strings.Split(r.String(), ":")
+ result[parts[0]] = r // Key before ':'.
+ }
+ return result
+}()
+
+// EscapeCount counts escapes.
+//
+// It is used to avoid accumulating too many escapes for the same reason, for
+// the same function. Each class is limited to maxRecordsPerReason instances.
+type EscapeCount struct {
+ byReason [reasonCount]uint32
+}
+
+// maxRecordsPerReason is the number of explicit records.
+//
+// See EscapeCount (and usage), and Record implementation.
+const maxRecordsPerReason = 5
+
+// Record records the reason or returns false if it should not be added.
+func (ec *EscapeCount) Record(reason EscapeReason) bool {
+ ec.byReason[reason]++
+ if ec.byReason[reason] > maxRecordsPerReason {
+ return false
+ }
+ return true
+}
+
+// loadObjdump reads the objdump output.
+//
+// This records, for every source line, whether it contains a call to any
+// function. It is used only to remove false positives for escape analysis:
+// the allocation call is elided when escape analysis is able to keep the
+// object on the stack.
+func loadObjdump() (map[LinePosition]string, error) {
+ f, err := os.Open(data.Objdump)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ // Build the map.
+ m := make(map[LinePosition]string)
+ r := bufio.NewReader(f)
+ var (
+ lastField string
+ lastPos LinePosition
+ )
+ for {
+ line, err := r.ReadString('\n')
+ if err != nil && err != io.EOF {
+ return nil, err
+ }
+
+ // We recognize lines corresponding to actual code (not the
+ // symbol name or other metadata) and annotate them if they
+ // correspond to an explicit CALL instruction. We assume that
+ // the lack of a CALL for a given line is evidence that escape
+ // analysis has eliminated an allocation.
+ //
+ // Lines look like this (including the first space):
+ // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX
+ if len(line) > 0 && line[0] == ' ' {
+ fields := strings.Fields(line)
+ if !strings.Contains(fields[3], "CALL") {
+ continue
+ }
+
+ // Ignore strings containing duffzero, which is just
+ // used by stack allocations for types that are large
+ // enough to warrant Duff's device.
+ if strings.Contains(line, "runtime.duffzero") {
+ continue
+ }
+
+ // Ignore the racefuncenter call, which is used for
+ // race builds. This does not escape.
+ if strings.Contains(line, "runtime.racefuncenter") {
+ continue
+ }
+
+ // Calculate the filename and line. Note that per the
+ // example above, the filename is not a fully qualified
+ // path, just the basename (which is what we require).
+ if fields[0] != lastField {
+ parts := strings.SplitN(fields[0], ":", 2)
+ lineNum, err := strconv.ParseInt(parts[1], 10, 64)
+ if err != nil {
+ return nil, err
+ }
+ lastPos = LinePosition{
+ Filename: parts[0],
+ Line: int(lineNum),
+ }
+ lastField = fields[0]
+ }
+ if _, ok := m[lastPos]; ok {
+ continue // Already marked.
+ }
+
+ // Save the actual call for the detail.
+ m[lastPos] = strings.Join(fields[3:], " ")
+ }
+ if err == io.EOF {
+ break
+ }
+ }
+
+ return m, nil
+}
+
+// poser is a type that implements Pos.
+type poser interface {
+ Pos() token.Pos
+}
+
+// run performs the analysis.
+func run(pass *analysis.Pass) (interface{}, error) {
+ calls, err := loadObjdump()
+ if err != nil {
+ return nil, err
+ }
+ pef := packageEscapeFacts{
+ Funcs: make(map[string][]Escape),
+ }
+ linePosition := func(inst, parent poser) LinePosition {
+ p := pass.Fset.Position(inst.Pos())
+ if (p.Filename == "" || p.Line == 0) && parent != nil {
+ p = pass.Fset.Position(parent.Pos())
+ }
+ return LinePosition{
+ Filename: filepath.Base(p.Filename),
+ Line: p.Line,
+ }
+ }
+ hasCall := func(inst poser) (string, bool) {
+ p := linePosition(inst, nil)
+ s, ok := calls[p]
+ return s, ok
+ }
+ callSite := func(inst ssa.Instruction) CallSite {
+ return CallSite{
+ LocalPos: inst.Pos(),
+ Resolved: linePosition(inst, inst.Parent()),
+ }
+ }
+ escapes := func(reason EscapeReason, detail string, inst ssa.Instruction, ec *EscapeCount) []Escape {
+ if !ec.Record(reason) {
+ return nil // Skip.
+ }
+ es := Escape{
+ Reason: reason,
+ Detail: detail,
+ Chain: []CallSite{callSite(inst)},
+ }
+ return []Escape{es}
+ }
+ resolve := func(sub []Escape, inst ssa.Instruction, ec *EscapeCount) (es []Escape) {
+ for _, e := range sub {
+ if !ec.Record(e.Reason) {
+ continue // Skip.
+ }
+ es = append(es, Escape{
+ Reason: e.Reason,
+ Detail: e.Detail,
+ Chain: append([]CallSite{callSite(inst)}, e.Chain...),
+ })
+ }
+ return es
+ }
+ state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
+
+ var loadFunc func(*ssa.Function) []Escape // Used below.
+
+ analyzeInstruction := func(inst ssa.Instruction, ec *EscapeCount) []Escape {
+ switch x := inst.(type) {
+ case *ssa.Call:
+ if x.Call.IsInvoke() {
+ // This is an interface dispatch. There is no
+ // way to know if this is actually escaping or
+ // not, since we don't know the underlying
+ // type.
+ call, _ := hasCall(inst)
+ return escapes(interfaceInvoke, call, inst, ec)
+ }
+ switch x := x.Call.Value.(type) {
+ case *ssa.Function:
+ if x.Pkg == nil {
+ // Can't resolve the package.
+ return escapes(unknownPackage, "no package", inst, ec)
+ }
+
+ // Atomic functions are intrinsics. We can
+ // assume that they don't escape.
+ if x.Pkg.Pkg.Name() == "atomic" {
+ return nil
+ }
+
+ // Is this a local function? If yes, call the
+ // function to load the local function. The
+ // local escapes are the escapes found in the
+ // local function.
+ if x.Pkg.Pkg == pass.Pkg {
+ return resolve(loadFunc(x), inst, ec)
+ }
+
+ // Recursively collect information from
+ // the other analyzers.
+ var imp packageEscapeFacts
+ if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) {
+ // Unable to import the dependency; we must
+ // declare these as escaping.
+ return escapes(unknownPackage, "no analysis", inst, ec)
+ }
+
+ // The escapes of this instruction are the
+ // escapes of the called function directly.
+ return resolve(imp.Funcs[x.RelString(x.Pkg.Pkg)], inst, ec)
+ case *ssa.Builtin:
+ // Ignore elided escapes.
+ if _, has := hasCall(inst); !has {
+ return nil
+ }
+
+ // Check if the builtin is escaping.
+ for _, name := range escapingBuiltins {
+ if x.Name() == name {
+ return escapes(builtin, name, inst, ec)
+ }
+ }
+ default:
+ // All dynamic calls are counted as soft
+ // escapes. They are similar to interface
+ // dispatches. We cannot actually look up what
+ // this refers to using static analysis alone.
+ call, _ := hasCall(inst)
+ return escapes(dynamicCall, call, inst, ec)
+ }
+ case *ssa.Alloc:
+ // Ignore non-heap allocations.
+ if !x.Heap {
+ return nil
+ }
+
+ // Ignore elided escapes.
+ call, has := hasCall(inst)
+ if !has {
+ return nil
+ }
+
+ // This is a real heap allocation.
+ return escapes(allocation, call, inst, ec)
+ case *ssa.MakeMap:
+ return escapes(builtin, "makemap", inst, ec)
+ case *ssa.MakeSlice:
+ return escapes(builtin, "makeslice", inst, ec)
+ case *ssa.MakeClosure:
+ return escapes(builtin, "makeclosure", inst, ec)
+ case *ssa.MakeChan:
+ return escapes(builtin, "makechan", inst, ec)
+ }
+ return nil // No escapes.
+ }
+
+ var analyzeBasicBlock func(*ssa.BasicBlock, *EscapeCount) []Escape // Recursive.
+ analyzeBasicBlock = func(block *ssa.BasicBlock, ec *EscapeCount) (rval []Escape) {
+ for _, inst := range block.Instrs {
+ rval = append(rval, analyzeInstruction(inst, ec)...)
+ }
+ return rval // N.B. may be empty.
+ }
+
+ loadFunc = func(fn *ssa.Function) []Escape {
+ // Is this already available?
+ name := fn.RelString(pass.Pkg)
+ if es, ok := pef.Funcs[name]; ok {
+ return es
+ }
+
+ // In the case of a true cycle, we assume that the current
+ // function itself has no escapes until the rest of the
+ // analysis is complete. This will trip the above in the case
+ // of a cycle of any kind.
+ pef.Funcs[name] = nil
+
+ // Perform the basic analysis.
+ var (
+ es []Escape
+ ec EscapeCount
+ )
+ if fn.Recover != nil {
+ es = append(es, analyzeBasicBlock(fn.Recover, &ec)...)
+ }
+ for _, block := range fn.Blocks {
+ es = append(es, analyzeBasicBlock(block, &ec)...)
+ }
+
+ // Check for a stack split.
+ if call, has := hasCall(fn); has {
+ es = append(es, Escape{
+ Reason: stackSplit,
+ Detail: call,
+ Chain: []CallSite{{
+ LocalPos: fn.Pos(),
+ Resolved: linePosition(fn, fn.Parent()),
+ }},
+ })
+ }
+
+ // Save the result and return.
+ pef.Funcs[name] = es
+ return es
+ }
+
+ // Complete all local functions.
+ for _, fn := range state.SrcFuncs {
+ loadFunc(fn)
+ }
+
+ // Build the exception list.
+ exemptions := make(map[LinePosition]string)
+ for _, f := range pass.Files {
+ for _, cg := range f.Comments {
+ for _, c := range cg.List {
+ p := pass.Fset.Position(c.Slash)
+ if strings.HasPrefix(strings.ToLower(c.Text), exempt) {
+ exemptions[LinePosition{
+ Filename: filepath.Base(p.Filename),
+ Line: p.Line,
+ }] = c.Text[len(exempt):]
+ }
+ }
+ }
+ }
+
+ // Delete everything matching the exemptions.
+ //
+ // This has the implication that exemptions are applied recursively,
+ // since this now-modified set is what will be saved.
+ for name, escapes := range pef.Funcs {
+ var newEscapes []Escape
+ for _, escape := range escapes {
+ isExempt := false
+ for line := range exemptions {
+ // Note that an exemption applies if it is
+ // marked as an exemption anywhere in the call
+ // chain. It need not be marked as escapes in
+ // the function itself, nor in the top-level
+ // caller.
+ for _, callSite := range escape.Chain {
+ if callSite.Resolved == line {
+ isExempt = true
+ break
+ }
+ }
+ if isExempt {
+ break
+ }
+ }
+ if !isExempt {
+ // Record this escape; not an exemption.
+ newEscapes = append(newEscapes, escape)
+ }
+ }
+ pef.Funcs[name] = newEscapes // Update.
+ }
+
+ // Export all findings for future packages.
+ pass.ExportPackageFact(&pef)
+
+ // Scan all functions for violations.
+ for _, f := range pass.Files {
+ // Scan all declarations.
+ for _, decl := range f.Decls {
+ fdecl, ok := decl.(*ast.FuncDecl)
+ // Function declaration?
+ if !ok {
+ continue
+ }
+ // Is there a comment?
+ if fdecl.Doc == nil {
+ continue
+ }
+ var (
+ reasons []EscapeReason
+ found bool
+ local bool
+ testReasons = make(map[EscapeReason]bool) // reason -> local?
+ )
+ // Does the comment contain a +checkescape line?
+ for _, c := range fdecl.Doc.List {
+ if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) {
+ continue
+ }
+ if c.Text == magic {
+ // Default: hard reasons, local only.
+ reasons = hardReasons
+ local = true
+ } else if strings.HasPrefix(c.Text, magicParams) {
+ // Extract specific reasons.
+ types := strings.Split(c.Text[len(magicParams):], ",")
+ found = true // For below.
+ for i := 0; i < len(types); i++ {
+ if types[i] == "local" {
+ // Limit search to local escapes.
+ local = true
+ } else if types[i] == "all" {
+ // Append all reasons.
+ reasons = append(reasons, allReasons...)
+ } else if types[i] == "hard" {
+ // Append all hard reasons.
+ reasons = append(reasons, hardReasons...)
+ } else {
+ r, ok := escapeTypes[types[i]]
+ if !ok {
+ // This is not a valid escape reason.
+ pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i])
+ continue
+ }
+ reasons = append(reasons, r)
+ }
+ }
+ } else if strings.HasPrefix(c.Text, testMagic) {
+ types := strings.Split(c.Text[len(testMagic):], ",")
+ local := false
+ for i := 0; i < len(types); i++ {
+ if types[i] == "local" {
+ local = true
+ } else {
+ r, ok := escapeTypes[types[i]]
+ if !ok {
+ // This is not a valid escape reason.
+ pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i])
+ continue
+ }
+ if v, ok := testReasons[r]; ok && v {
+ // Already registered as local.
+ continue
+ }
+ testReasons[r] = local
+ }
+ }
+ }
+ }
+ if len(reasons) == 0 && found {
+ // A magic annotation was provided, but no reasons.
+ pass.Reportf(fdecl.Pos(), "no reasons provided")
+ continue
+ }
+
+ // Scan for matches.
+ fn := pass.TypesInfo.Defs[fdecl.Name].(*types.Func)
+ name := state.Pkg.Prog.FuncValue(fn).RelString(pass.Pkg)
+ es, ok := pef.Funcs[name]
+ if !ok {
+ pass.Reportf(fdecl.Pos(), "internal error: function %s not found.", name)
+ continue
+ }
+ for _, e := range es {
+ for _, r := range reasons {
+ // Does this meet our local requirement?
+ if local && len(e.Chain) > 1 {
+ continue
+ }
+ // Does this match the reason? Emit
+ // with a full stack trace that
+ // explains why this violates our
+ // constraints.
+ if e.Reason == r {
+ pass.Reportf(e.Chain[0].LocalPos, "%s", e.String())
+ }
+ }
+ }
+
+ // Scan for test (required) matches.
+ testReasonsFound := make(map[EscapeReason]bool)
+ for _, e := range es {
+ // Is this local?
+ local, ok := testReasons[e.Reason]
+ wantLocal := len(e.Chain) == 1
+ testReasonsFound[e.Reason] = wantLocal
+ if !ok {
+ continue
+ }
+ if local == wantLocal {
+ delete(testReasons, e.Reason)
+ }
+ }
+ for reason, local := range testReasons {
+ // We didn't find the escapes we wanted.
+ pass.Reportf(fdecl.Pos(), fmt.Sprintf("testescapes not found: reason=%s, local=%t", reason, local))
+ }
+ if len(testReasons) > 0 {
+ // Dump all reasons found to help in debugging.
+ for _, e := range es {
+ pass.Reportf(e.Chain[0].LocalPos, "escape found: %s", e.String())
+ }
+ }
+ }
+ }
+
+ return nil, nil
+}
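The test packages below exercise the annotations end to end. Separately, a minimal sketch of the "// escapes: reason." exemption described in the package doc; the function name and reason are hypothetical:

```go
package hotpath

// +checkescape:local,hard
//go:nosplit
func ensureMap(m map[string]bool) map[string]bool {
	if m == nil {
		// The trailing comment exempts this line (and, recursively,
		// callers of this function) from the analysis.
		m = make(map[string]bool) // escapes: one-time slow-path initialization.
	}
	return m
}
```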
diff --git a/tools/checkescape/test1/BUILD b/tools/checkescape/test1/BUILD
new file mode 100644
index 000000000..783403247
--- /dev/null
+++ b/tools/checkescape/test1/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "test1",
+ srcs = ["test1.go"],
+ visibility = ["//tools/checkescape/test2:__pkg__"],
+)
diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go
new file mode 100644
index 000000000..68d3f72cc
--- /dev/null
+++ b/tools/checkescape/test1/test1.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test1 is a test package.
+package test1
+
+import (
+ "fmt"
+ "reflect"
+)
+
+// Interface is a generic interface.
+type Interface interface {
+ Foo()
+}
+
+// Type is a concrete implementation of Interface.
+type Type struct {
+ A uint64
+ B uint64
+}
+
+// Foo implements Interface.Foo.
+//go:nosplit
+func (t Type) Foo() {
+ fmt.Printf("%v", t) // Never executed.
+}
+
+// +checkescape:all,hard
+//go:nosplit
+func InterfaceFunction(i Interface) {
+ // Do nothing; exported for tests.
+}
+
+// +checkescape:all,hard
+//go:nosplit
+func TypeFunction(t *Type) {
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinMap(x int) map[string]bool {
+ return make(map[string]bool)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinMapRec(x int) map[string]bool {
+ return BuiltinMap(x)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinClosure(x int) func() {
+ return func() {
+ fmt.Printf("%v", x)
+ }
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinClosureRec(x int) func() {
+ return BuiltinClosure(x)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinMakeSlice(x int) []byte {
+ return make([]byte, x)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinMakeSliceRec(x int) []byte {
+ return BuiltinMakeSlice(x)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinAppend(x []byte) []byte {
+ return append(x, 0)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinAppendRec() []byte {
+ return BuiltinAppend(nil)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinChan() chan int {
+ return make(chan int)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinChanRec() chan int {
+ return BuiltinChan()
+}
+
+// +mustescape:local,heap
+//go:noinline
+//go:nosplit
+func Heap() *Type {
+ var t Type
+ return &t
+}
+
+// +mustescape:heap
+//go:noinline
+//go:nosplit
+func heapRec() *Type {
+ return Heap()
+}
+
+// +mustescape:local,interface
+//go:noinline
+//go:nosplit
+func Dispatch(i Interface) {
+ i.Foo()
+}
+
+// +mustescape:interface
+//go:noinline
+//go:nosplit
+func dispatchRec(i Interface) {
+ Dispatch(i)
+}
+
+// +mustescape:local,dynamic
+//go:noinline
+//go:nosplit
+func Dynamic(f func()) {
+ f()
+}
+
+// +mustescape:dynamic
+//go:noinline
+//go:nosplit
+func dynamicRec(f func()) {
+ Dynamic(f)
+}
+
+// +mustescape:local,unknown
+//go:noinline
+//go:nosplit
+func Unknown() {
+ _ = reflect.TypeOf((*Type)(nil)) // Does not actually escape.
+}
+
+// +mustescape:unknown
+//go:noinline
+//go:nosplit
+func unknownRec() {
+ Unknown()
+}
+
+//go:noinline
+//go:nosplit
+func internalFunc() {
+}
+
+// +mustescape:local,stack
+//go:noinline
+func Split() {
+ internalFunc()
+}
+
+// +mustescape:stack
+//go:noinline
+func splitRec() {
+ Split()
+}
diff --git a/tools/checkescape/test2/BUILD b/tools/checkescape/test2/BUILD
new file mode 100644
index 000000000..5a11e4b43
--- /dev/null
+++ b/tools/checkescape/test2/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "test2",
+ srcs = ["test2.go"],
+ deps = ["//tools/checkescape/test1"],
+)
diff --git a/tools/checkescape/test2/test2.go b/tools/checkescape/test2/test2.go
new file mode 100644
index 000000000..7fce3e3be
--- /dev/null
+++ b/tools/checkescape/test2/test2.go
@@ -0,0 +1,94 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test2 is a test package that imports test1.
+package test2
+
+import (
+ "gvisor.dev/gvisor/tools/checkescape/test1"
+)
+
+// +checkescape:all
+//go:nosplit
+func interfaceFunctionCrossPkg() {
+ var i test1.Interface
+ test1.InterfaceFunction(i)
+}
+
+// +checkescape:all
+//go:nosplit
+func typeFunctionCrossPkg() {
+ var t test1.Type
+ test1.TypeFunction(&t)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinMapCrossPkg(x int) map[string]bool {
+ return test1.BuiltinMap(x)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinClosureCrossPkg(x int) func() {
+ return test1.BuiltinClosure(x)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinMakeSliceCrossPkg(x int) []byte {
+ return test1.BuiltinMakeSlice(x)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinAppendCrossPkg() []byte {
+ return test1.BuiltinAppend(nil)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinChanCrossPkg() chan int {
+ return test1.BuiltinChan()
+}
+
+// +mustescape:heap
+//go:noinline
+func heapCrossPkg() *test1.Type {
+ return test1.Heap()
+}
+
+// +mustescape:interface
+//go:noinline
+func dispatchCrossPkg(i test1.Interface) {
+ test1.Dispatch(i)
+}
+
+// +mustescape:dynamic
+//go:noinline
+func dynamicCrossPkg(f func()) {
+ test1.Dynamic(f)
+}
+
+// +mustescape:unknown
+//go:noinline
+func unknownCrossPkg() {
+ test1.Unknown()
+}
+
+// +mustescape:stack
+//go:noinline
+func splitCrossPkg() {
+ test1.Split()
+}
diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD
new file mode 100644
index 000000000..0c264151b
--- /dev/null
+++ b/tools/checkunsafe/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "checkunsafe",
+ srcs = ["check_unsafe.go"],
+ nogo = False,
+ visibility = ["//tools/nogo:__subpackages__"],
+ deps = [
+ "@org_golang_x_tools//go/analysis:go_tool_library",
+ ],
+)
diff --git a/tools/checkunsafe/check_unsafe.go b/tools/checkunsafe/check_unsafe.go
new file mode 100644
index 000000000..4ccd7cc5a
--- /dev/null
+++ b/tools/checkunsafe/check_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package checkunsafe allows unsafe imports only in files named appropriately.
+package checkunsafe
+
+import (
+ "fmt"
+ "path"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+)
+
+// Analyzer defines the entrypoint.
+var Analyzer = &analysis.Analyzer{
+ Name: "checkunsafe",
+ Doc: "allows unsafe use only in specified files",
+ Run: run,
+}
+
+func run(pass *analysis.Pass) (interface{}, error) {
+ for _, f := range pass.Files {
+ for _, imp := range f.Imports {
+ // Is this an unsafe import?
+ pkg, err := strconv.Unquote(imp.Path.Value)
+ if err != nil || pkg != "unsafe" {
+ continue
+ }
+
+ // Extract the filename.
+ filename := pass.Fset.File(imp.Pos()).Name()
+
+ // Allow files named _unsafe.go or _test.go to opt out.
+ if strings.HasSuffix(filename, "_unsafe.go") || strings.HasSuffix(filename, "_test.go") {
+ continue
+ }
+
+ // Report the error.
+ pass.Reportf(imp.Pos(), "package unsafe imported by %s; must end with _unsafe.go", path.Base(filename))
+ }
+ }
+ return nil, nil
+}
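A minimal sketch of a file this analyzer accepts: because the file name ends in _unsafe.go, the unsafe import is allowed (the file, package, and function names are hypothetical):

```go
// ptr_unsafe.go
package example

import "unsafe"

// addrOf returns p's address as a uintptr; permitted here only because the
// enclosing file name ends in _unsafe.go.
func addrOf(p *int) uintptr {
	return uintptr(unsafe.Pointer(p))
}
```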
diff --git a/tools/defs.bzl b/tools/defs.bzl
new file mode 100644
index 000000000..40afcdb79
--- /dev/null
+++ b/tools/defs.bzl
@@ -0,0 +1,254 @@
+"""Wrappers for common build rules.
+
+These wrappers apply common BUILD configurations (e.g., proto_library
+automagically creating cc_ and go_ proto targets) and act as a single point of
+change for Google-internal and bazel-compatible rules.
+"""
+
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
+load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _vdso_linker_option = "vdso_linker_option")
+load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
+load("//tools/bazeldefs:tags.bzl", "go_suffixes")
+load("//tools/nogo:defs.bzl", "nogo_test")
+
+# Delegate directly.
+build_test = _build_test
+cc_binary = _cc_binary
+cc_flags_supplier = _cc_flags_supplier
+cc_grpc_library = _cc_grpc_library
+cc_library = _cc_library
+cc_test = _cc_test
+cc_toolchain = _cc_toolchain
+default_installer = _default_installer
+default_net_util = _default_net_util
+gbenchmark = _gbenchmark
+gazelle = _gazelle
+go_embed_data = _go_embed_data
+go_path = _go_path
+go_test = _go_test
+gtest = _gtest
+grpcpp = _grpcpp
+loopback = _loopback
+pkg_deb = _pkg_deb
+pkg_tar = _pkg_tar
+py_binary = _py_binary
+py_library = _py_library
+py_requirement = _py_requirement
+py_test = _py_test
+select_arch = _select_arch
+select_system = _select_system
+rbe_platform = _rbe_platform
+rbe_toolchain = _rbe_toolchain
+vdso_linker_option = _vdso_linker_option
+
+# Platform options.
+default_platform = _default_platform
+platforms = _platforms
+
+def go_binary(name, **kwargs):
+ """Wraps the standard go_binary.
+
+ Args:
+ name: the rule name.
+ **kwargs: standard go_binary arguments.
+ """
+ _go_binary(
+ name = name,
+ **kwargs
+ )
+
+def calculate_sets(srcs):
+ """Calculates special Go sets for templates.
+
+ Args:
+ srcs: the full set of Go sources.
+
+ Returns:
+ A dictionary of the form:
+
+ "": [src1.go, src2.go]
+ "suffix": [src3suffix.go, src4suffix.go]
+
+ Note that suffix will typically start with '_'.
+ """
+ result = dict()
+ for file in srcs:
+ if not file.endswith(".go"):
+ continue
+ target = ""
+ for suffix in go_suffixes:
+ if file.endswith(suffix + ".go"):
+ target = suffix
+ if target not in result:
+ result[target] = [file]
+ else:
+ result[target].append(file)
+ return result
+
+def go_imports(name, src, out):
+ """Simplify a single Go source file by eliminating unused imports."""
+ native.genrule(
+ name = name,
+ srcs = [src],
+ outs = [out],
+ tools = ["@org_golang_x_tools//cmd/goimports:goimports"],
+ cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
+ )
+
+def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, nogo = True, **kwargs):
+ """Wraps the standard go_library and does stateification and marshalling.
+
+ The recommended way is to use this rule with a configuration mostly identical
+ to that of the native go_library rule.
+
+ These definitions provide additional flags (stateify, marshal) that can be used
+ with the generators to automatically supplement the library code.
+
+ load("//tools:defs.bzl", "go_library")
+
+ go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+ )
+
+ Args:
+ name: the rule name.
+ srcs: the library sources.
+ deps: the library dependencies.
+ imports: imports required for stateify.
+ stateify: whether stateify is enabled (default: true).
+ marshal: whether marshal is enabled (default: false).
+ marshal_debug: whether the go_marshal tool emits debugging output (default: false).
+ **kwargs: standard go_library arguments.
+ """
+ all_srcs = srcs
+ all_deps = deps
+ dirname, _, _ = native.package_name().rpartition("/")
+ full_pkg = dirname + "/" + name
+ if stateify:
+ # Only do stateification for non-state packages without manual autogen.
+ # First, we need to segregate the input files via the special suffixes,
+ # and calculate the final output set.
+ state_sets = calculate_sets(srcs)
+ for (suffix, src_subset) in state_sets.items():
+ go_stateify(
+ name = name + suffix + "_state_autogen_with_imports",
+ srcs = src_subset,
+ imports = imports,
+ package = full_pkg,
+ out = name + suffix + "_state_autogen_with_imports.go",
+ )
+ go_imports(
+ name = name + suffix + "_state_autogen",
+ src = name + suffix + "_state_autogen_with_imports.go",
+ out = name + suffix + "_state_autogen.go",
+ )
+ all_srcs = all_srcs + [
+ name + suffix + "_state_autogen.go"
+ for suffix in state_sets.keys()
+ ]
+ if "//pkg/state" not in all_deps:
+ all_deps = all_deps + ["//pkg/state"]
+
+ if marshal:
+ # See above.
+ marshal_sets = calculate_sets(srcs)
+ for (suffix, src_subset) in marshal_sets.items():
+ go_marshal(
+ name = name + suffix + "_abi_autogen",
+ srcs = src_subset,
+ debug = select({
+ "//tools/go_marshal:marshal_config_verbose": True,
+ "//conditions:default": marshal_debug,
+ }),
+ imports = imports,
+ package = name,
+ )
+ extra_deps = [
+ dep
+ for dep in marshal_deps
+ if dep not in all_deps
+ ]
+ all_deps = all_deps + extra_deps
+ all_srcs = all_srcs + [
+ name + suffix + "_abi_autogen_unsafe.go"
+ for suffix in marshal_sets.keys()
+ ]
+
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+ if nogo:
+ nogo_test(
+ name = name + "_nogo",
+ deps = [":" + name],
+ )
+
+ if marshal:
+ # Ignore importpath for go_test.
+ kwargs.pop("importpath", None)
+
+ # See above.
+ marshal_sets = calculate_sets(srcs)
+ for (suffix, _) in marshal_sets.items():
+ _go_test(
+ name = name + suffix + "_abi_autogen_test",
+ srcs = [name + suffix + "_abi_autogen_test.go"],
+ library = ":" + name,
+ deps = marshal_test_deps,
+ **kwargs
+ )
+
+def proto_library(name, srcs, deps = None, has_services = 0, **kwargs):
+ """Wraps the standard proto_library.
+
+ Given a proto_library named "foo", this produces up to five different
+ targets:
+ - foo_proto: proto_library rule.
+ - foo_go_proto: go_proto_library rule.
+ - foo_cc_proto: cc_proto_library rule.
+ - foo_go_grpc_proto: go_grpc_library rule.
+ - foo_cc_grpc_proto: cc_grpc_library rule.
+
+ Args:
+ name: the name to which _proto, _go_proto, etc, will be appended.
+ srcs: the proto sources.
+ deps: for the proto library and the go_proto_library.
+ has_services: 1 to build gRPC code, otherwise 0.
+ **kwargs: standard proto_library arguments.
+ """
+ _proto_library(
+ name = name + "_proto",
+ srcs = srcs,
+ deps = deps,
+ has_services = has_services,
+ **kwargs
+ )
+ if has_services:
+ _go_grpc_and_proto_libraries(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
+ else:
+ _go_proto_library(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
+ _cc_proto_library(
+ name = name + "_cc_proto",
+ deps = [":" + name + "_proto"],
+ **kwargs
+ )
+ if has_services:
+ _cc_grpc_library(
+ name = name + "_cc_grpc_proto",
+ srcs = [":" + name + "_proto"],
+ deps = [":" + name + "_cc_proto"],
+ **kwargs
+ )
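For intuition, a Go sketch (illustrative only; file names are hypothetical) mirroring the Starlark calculate_sets above:

```go
package main

import (
	"fmt"
	"strings"
)

// calculateSets mirrors the Starlark calculate_sets: it groups Go sources by
// the last matching special suffix, with "" for sources that match none.
func calculateSets(srcs, suffixes []string) map[string][]string {
	result := make(map[string][]string)
	for _, file := range srcs {
		if !strings.HasSuffix(file, ".go") {
			continue
		}
		target := ""
		for _, suffix := range suffixes {
			if strings.HasSuffix(file, suffix+".go") {
				target = suffix
			}
		}
		result[target] = append(result[target], file)
	}
	return result
}

func main() {
	srcs := []string{"list.go", "list_amd64.go", "BUILD"}
	fmt.Println(calculateSets(srcs, []string{"_amd64", "_arm64"}))
	// Output: map[:[list.go] _amd64:[list_amd64.go]]
}
```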
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
new file mode 100755
index 000000000..093de89b4
--- /dev/null
+++ b/tools/go_branch.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Discover the package name from the go.mod file.
+declare -r module=$(grep -E "^module" go.mod | cut -d' ' -f2)
+declare -r origpwd=$(pwd)
+declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE")
+
+# Check that gopath has been built.
+declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}"
+if ! [ -d "${gopath_dir}" ]; then
+ echo "No gopath directory found; build the :gopath target." >&2
+ exit 1
+fi
+
+# Create a temporary working directory, and ensure that this directory and all
+# subdirectories are cleaned up upon exit.
+declare -r tmp_dir=$(mktemp -d)
+finish() {
+ cd # Leave tmp_dir.
+ rm -rf "${tmp_dir}"
+}
+trap finish EXIT
+
+# Record the current working commit.
+declare -r head=$(git describe --always)
+
+# We expect to have an existing go branch that we will use as the basis for
+# this commit. That branch may be empty, but it must exist.
+git fetch --all
+declare -r go_branch=$(git show-ref --hash go)
+
+# Clone the current repository to the temporary directory, and check out the
+# current go_branch directory. We move to the new repository for convenience.
+declare -r repo_orig="$(pwd)"
+declare -r repo_new="${tmp_dir}/repository"
+git clone . "${repo_new}"
+cd "${repo_new}"
+
+# Setup the repository and checkout the branch.
+git config user.email "gvisor-bot@google.com"
+git config user.name "gVisor bot"
+git fetch origin "${go_branch}"
+git checkout -b go "${go_branch}"
+
+# Start working on a merge commit that combines the previous history with the
+# current history. Note that we don't actually want any changes yet.
+#
+# N.B. The git behavior changed at some point and the relevant flag was added
+# to allow for override, so try the old behavior first, then pass the flag.
+git merge --no-commit --strategy ours ${head} || \
+ git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
+
+# Sync the entire gopath_dir.
+rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" .
+
+# Add additional files.
+for file in "${othersrc[@]}"; do
+ cp "${origpwd}"/"${file}" .
+done
+
+# Construct a new README.md.
+cat > README.md <<EOF
+# gVisor
+
+This branch is a synthetic branch, containing only Go sources, that is
+compatible with standard Go tools. See the master branch for authoritative
+sources and tests.
+EOF
+
+# There are a few solitary files that can get left behind due to the way bazel
+# constructs the gopath target. Note that we don't find all Go files here
+# because they may correspond to unused templates, etc.
+cp "${repo_orig}"/runsc/*.go runsc/
+
+# Normalize all permissions. The way bazel constructs the :gopath tree may leave
+# some strange permissions on files. We don't have anything in this tree that
+# should be executable, only the Go source files, README.md, and ${othersrc}.
+find . -type f -exec chmod 0644 {} \;
+find . -type d -exec chmod 0755 {} \;
+
+# Update the current working set and commit.
+git add . && git commit -m "Merge ${head} (automated)"
+
+# Push the branch back to the original repository.
+git remote add orig "${repo_orig}" && git push -f orig go:go
diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD
new file mode 100644
index 000000000..32a949c93
--- /dev/null
+++ b/tools/go_generics/BUILD
@@ -0,0 +1,38 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "go_generics",
+ srcs = [
+ "generics.go",
+ "imports.go",
+ "remove.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = ["//tools/go_generics/globals"],
+)
+
+genrule(
+ name = "go_generics_tests",
+ srcs = glob(["generics_tests/**"]) + [":go_generics"],
+ outs = ["go_generics_tests.tgz"],
+ cmd = "tar -czvhf $@ $(SRCS)",
+)
+
+genrule(
+ name = "go_generics_test_bundle",
+ srcs = [
+ ":go_generics_tests.tgz",
+ ":go_generics_unittest.sh",
+ ],
+ outs = ["go_generics_test.sh"],
+ cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@",
+ executable = True,
+)
+
+sh_test(
+ name = "go_generics_test",
+ size = "small",
+ srcs = ["go_generics_test.sh"],
+)
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
new file mode 100644
index 000000000..8c9995fd4
--- /dev/null
+++ b/tools/go_generics/defs.bzl
@@ -0,0 +1,139 @@
+def _go_template_impl(ctx):
+ input = ctx.files.srcs
+ output = ctx.outputs.out
+
+ args = ["-o=%s" % output.path] + [f.path for f in input]
+
+ ctx.actions.run(
+ inputs = input,
+ outputs = [output],
+ mnemonic = "GoGenericsTemplate",
+ progress_message = "Building Go template %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+ return struct(
+ types = ctx.attr.types,
+ opt_types = ctx.attr.opt_types,
+ consts = ctx.attr.consts,
+ opt_consts = ctx.attr.opt_consts,
+ deps = ctx.attr.deps,
+ file = output,
+ )
+
+"""
+Generates a Go template from a set of Go files.
+
+A Go template is similar to a Go library, except that it has certain types that
+can be replaced before usage. For example, one could define a templatized List
+struct, whose elements are of type T, then instantiate that template for
+T=segment, where "segment" is the concrete type.
+
+Args:
+ name: the name of the template.
+ srcs: the list of source files that comprise the template.
+ types: the list of generic types in the template that are required to be specified.
+ opt_types: the list of generic types in the template that can but aren't required to be specified.
+ consts: the list of constants in the template that are required to be specified.
+ opt_consts: the list of constants in the template that can but aren't required to be specified.
+ deps: the list of dependencies.
+"""
+go_template = rule(
+ implementation = _go_template_impl,
+ attrs = {
+ "srcs": attr.label_list(mandatory = True, allow_files = True),
+ "deps": attr.label_list(allow_files = True),
+ "types": attr.string_list(),
+ "opt_types": attr.string_list(),
+ "consts": attr.string_list(),
+ "opt_consts": attr.string_list(),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics/go_merge")),
+ },
+ outputs = {
+ "out": "%{name}_template.go",
+ },
+)
+
+def _go_template_instance_impl(ctx):
+ template = ctx.attr.template
+ output = ctx.outputs.out
+
+ # Check that all required types are defined.
+ for t in template.types:
+ if t not in ctx.attr.types:
+ fail("Missing value for type %s in %s" % (t, ctx.attr.template.label))
+
+ # Check that all defined types are expected by the template.
+ for t in ctx.attr.types:
+ if (t not in template.types) and (t not in template.opt_types):
+ fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label))
+
+ # Check that all required consts are defined.
+ for t in template.consts:
+ if t not in ctx.attr.consts:
+ fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label))
+
+ # Check that all defined consts are expected by the template.
+ for t in ctx.attr.consts:
+ if (t not in template.consts) and (t not in template.opt_consts):
+ fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label))
+
+ # Build the argument list.
+ args = ["-i=%s" % template.file.path, "-o=%s" % output.path]
+ args += ["-p=%s" % ctx.attr.package]
+
+ if len(ctx.attr.prefix) > 0:
+ args += ["-prefix=%s" % ctx.attr.prefix]
+
+ if len(ctx.attr.suffix) > 0:
+ args += ["-suffix=%s" % ctx.attr.suffix]
+
+ args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()]
+ args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()]
+ args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()]
+
+ if ctx.attr.anon:
+ args += ["-anon"]
+
+ ctx.actions.run(
+ inputs = [template.file],
+ outputs = [output],
+ mnemonic = "GoGenericsInstance",
+ progress_message = "Building Go template instance %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+ return struct(
+ files = depset([output]),
+ )
+
+"""
+Instantiates a Go template by replacing all generic types with concrete ones.
+
+Args:
+ name: the name of the template instance.
+ template: the label of the template to be instantiated.
+ prefix: a prefix to be added to globals in the template.
+ suffix: a suffix to be added to globals in the template.
+ types: the map from generic type names to concrete ones.
+ consts: the map from constant names to their values.
+ imports: the map from imports used in types/consts to their import paths.
+ package: the name of the package the instantiated template will be compiled into.
+ anon: whether to process anonymous fields (default: False).
+ out: the name of the generated output file.
+"""
+go_template_instance = rule(
+ implementation = _go_template_instance_impl,
+ attrs = {
+ "template": attr.label(mandatory = True, providers = ["types"]),
+ "prefix": attr.string(),
+ "suffix": attr.string(),
+ "types": attr.string_dict(),
+ "consts": attr.string_dict(),
+ "imports": attr.string_dict(),
+ "anon": attr.bool(mandatory = False, default = False),
+ "package": attr.string(mandatory = True),
+ "out": attr.output(mandatory = True),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")),
+ },
+)
diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go
new file mode 100644
index 000000000..0860ca9db
--- /dev/null
+++ b/tools/go_generics/generics.go
@@ -0,0 +1,286 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_generics reads a Go source file and writes a new version of that file
+// with a few transformations applied. Namely:
+//
+// 1. Global types can be explicitly renamed with the -t option. For example,
+// if -t=A=B is passed in, all references to A will be replaced with
+// references to B; a function declaration like:
+//
+// func f(arg *A)
+//
+// would be renamed to:
+//
+// func f(arg *B)
+//
+// 2. Global type definitions and their method sets will be removed when they're
+// being renamed with -t. For example, if -t=A=B is passed in, the following
+// definition and methods that existed in the input file wouldn't exist at
+// all in the output file:
+//
+// type A struct{}
+//
+// func (*A) f() {}
+//
+// 3. All global types, variables, constants and functions (not methods) are
+// prefixed and suffixed based on the option -prefix and -suffix arguments.
+// For example, if -suffix=A is passed in, the following globals:
+//
+// func f()
+// type t struct{}
+//
+// would be renamed to:
+//
+// func fA()
+// type tA struct{}
+//
+// Some special tags are also modified. For example:
+//
+// "state:.(t)"
+//
+// would become:
+//
+// "state:.(tA)"
+//
+// 4. The package is renamed to the value passed in via the -p argument.
+// 5. Values of constants can be modified with the -c argument.
+//
+// Note that not just the top-level declarations are renamed, all references to
+// them are also properly renamed as well, taking into account visibility rules
+// and shadowing. For example, if -suffix=A is passed in, the following:
+//
+// var b = 100
+//
+// func f() {
+// g(b)
+// b := 0
+// g(b)
+// }
+//
+// Would be replaced with:
+//
+// var bA = 100
+//
+// func f() {
+// g(bA)
+// b := 0
+// g(b)
+// }
+//
+// Note that the second call to g() kept "b" as an argument because it refers to
+// the local variable "b".
+//
+// Note that go_generics can handle anonymous fields with renamed types if
+// -anon is passed in, however it does not perform strict checking on parameter
+// types that share the same name as the global type and therefore will rename
+// them as well.
+//
+// You can see an example in the tools/go_generics/generics_tests/interface test.
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "io/ioutil"
+ "os"
+ "regexp"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/go_generics/globals"
+)
+
+var (
+ input = flag.String("i", "", "input `file`")
+ output = flag.String("o", "", "output `file`")
+ suffix = flag.String("suffix", "", "`suffix` to add to each global symbol")
+ prefix = flag.String("prefix", "", "`prefix` to add to each global symbol")
+ packageName = flag.String("p", "main", "output package `name`")
+ printAST = flag.Bool("ast", false, "prints the AST")
+ processAnon = flag.Bool("anon", false, "process anonymous fields")
+ types = make(mapValue)
+ consts = make(mapValue)
+ imports = make(mapValue)
+)
+
+// mapValue implements flag.Value. We use a mapValue flag instead of a regular
+// string flag when we want to allow more than one instance of the flag. For
+// example, we allow several "-t A=B" arguments, and will rename them all.
+type mapValue map[string]string
+
+func (m mapValue) String() string {
+ var b bytes.Buffer
+ first := true
+ for k, v := range m {
+ if !first {
+ b.WriteRune(',')
+ } else {
+ first = false
+ }
+ b.WriteString(k)
+ b.WriteRune('=')
+ b.WriteString(v)
+ }
+ return b.String()
+}
+
+func (m mapValue) Set(s string) error {
+ sep := strings.Index(s, "=")
+ if sep == -1 {
+ return fmt.Errorf("missing '=' from '%s'", s)
+ }
+
+ m[s[:sep]] = s[sep+1:]
+
+ return nil
+}
+
+// stateTagRegexp matches against the 'typed' state tags.
+var stateTagRegexp = regexp.MustCompile(`^(.*[^a-z0-9_])state:"\.\(([^\)]*)\)"(.*)$`)
+
+var identifierRegexp = regexp.MustCompile(`^(.*[^a-zA-Z_])([a-zA-Z_][a-zA-Z0-9_]*)(.*)$`)
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+
+ flag.Var(types, "t", "rename type A to B when `A=B` is passed in. Multiple such mappings are allowed.")
+ flag.Var(consts, "c", "reassign constant A to value B when `A=B` is passed in. Multiple such mappings are allowed.")
+ flag.Var(imports, "import", "specifies the import libraries to use when types are not local. `name=path` specifies that 'name', used in types as name.type, refers to the package living in 'path'.")
+ flag.Parse()
+
+ if *input == "" || *output == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // Parse the input file.
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, *input, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+
+ // Print the AST if requested.
+ if *printAST {
+ ast.Print(fset, f)
+ }
+
+ cmap := ast.NewCommentMap(fset, f, f.Comments)
+
+ // Update imports based on what's used in types and consts.
+ maps := []mapValue{types, consts}
+ importDecl, err := updateImports(maps, imports)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+ types = maps[0]
+ consts = maps[1]
+
+ // Reassign all specified constants.
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != token.CONST {
+ continue
+ }
+
+ for _, gs := range d.Specs {
+ s := gs.(*ast.ValueSpec)
+ for i, id := range s.Names {
+ if n, ok := consts[id.Name]; ok {
+ s.Values[i] = &ast.BasicLit{Value: n}
+ }
+ }
+ }
+ }
+
+ // Go through all globals and their uses in the AST and rename the types
+ // with explicitly provided names, and rename all types, variables,
+ // consts and functions with the provided prefix and suffix.
+ globals.Visit(fset, f, func(ident *ast.Ident, kind globals.SymKind) {
+ if n, ok := types[ident.Name]; ok && kind == globals.KindType {
+ ident.Name = n
+ } else {
+ switch kind {
+ case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
+ if ident.Name != "_" {
+ ident.Name = *prefix + ident.Name + *suffix
+ }
+ case globals.KindTag:
+ // Modify the state tag appropriately.
+ if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil {
+ if t := identifierRegexp.FindStringSubmatch(m[2]); t != nil {
+ typeName := *prefix + t[2] + *suffix
+ if n, ok := types[t[2]]; ok {
+ typeName = n
+ }
+ ident.Name = m[1] + `state:".(` + t[1] + typeName + t[3] + `)"` + m[3]
+ }
+ }
+ }
+ }
+ }, *processAnon)
+
+ // Remove the definition of all types that are being remapped.
+ set := make(typeSet)
+ for _, v := range types {
+ set[v] = struct{}{}
+ }
+ removeTypes(set, f)
+
+ // Add the new imports, if any, to the top.
+ if importDecl != nil {
+ newDecls := make([]ast.Decl, 0, len(f.Decls)+1)
+ newDecls = append(newDecls, importDecl)
+ newDecls = append(newDecls, f.Decls...)
+ f.Decls = newDecls
+ }
+
+ // Update comments to remove the ones potentially associated with the
+ // type T that we removed.
+ f.Comments = cmap.Filter(f).Comments()
+
+ // If there are file (package) comments, delete them.
+ if f.Doc != nil {
+ for i, cg := range f.Comments {
+ if cg == f.Doc {
+ f.Comments = append(f.Comments[:i], f.Comments[i+1:]...)
+ break
+ }
+ }
+ }
+
+ // Write the output file.
+ f.Name.Name = *packageName
+
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, f); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+
+ if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+}
diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/generics_tests/all_stmts/input.go
new file mode 100644
index 000000000..4791d1ff1
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_stmts/input.go
@@ -0,0 +1,290 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "sync"
+)
+
+type T int
+
+func h(T) {
+}
+
+type s struct {
+ a, b int
+ c []int
+}
+
+func g(T) *s {
+ return &s{}
+}
+
+func f() (T, []int) {
+ // Branch.
+ goto T
+ goto R
+
+ // Labeled.
+T:
+ _ = T(0)
+
+ // Empty.
+R:
+ ;
+
+ // Assignment with definition.
+ a, b, c := T(1), T(2), T(3)
+ _, _, _ = a, b, c
+
+ // Assignment without definition.
+ g(T(0)).a, g(T(1)).b, c = int(T(1)), int(T(2)), T(3)
+ _, _, _ = a, b, c
+
+ // Block.
+ {
+ var T T
+ T = 0
+ _ = T
+ }
+
+ // Declarations.
+ type Type T
+ const Const T = 10
+ var g1 func(T, int, ...T) (int, T)
+ var v T
+ var w = T(0)
+ {
+ var T struct {
+ f []T
+ }
+ _ = T
+ }
+
+ // Defer.
+ defer g1(T(0), 1)
+
+ // Expression.
+ h(v + w + T(1))
+
+ // For statements.
+ for i := T(0); i < T(10); i++ {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ for {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ // Go.
+ go g1(T(0), 1)
+
+ // If statements.
+ if a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ if a := T(0); a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ if a := T(0); a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ } else if b := T(0); b != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ } else if T := T(0); T != 1 {
+ T++
+ } else {
+ T--
+ }
+
+ if a := T(0); a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ } else {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ // Inc/Dec statements.
+ (*(*T)(nil))++
+ (*(*T)(nil))--
+
+ // Range statements.
+ for g(T(0)).a, g(T(1)).b = range g(T(10)).c {
+ var d T
+ _ = d
+ }
+
+ for T, b := range g(T(10)).c {
+ _ = T
+ _ = b
+ }
+
+ // Select statement.
+ {
+ var fch func(T) chan int
+
+ select {
+ case <-fch(T(30)):
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ case T := <-fch(T(30)):
+ T = 0
+ _ = T
+ case g(T(0)).a = <-fch(T(30)):
+ var T T
+ T = 0
+ _ = T
+ case fch(T(30)) <- int(T(0)):
+ var T T
+ T = 0
+ _ = T
+ }
+ }
+
+ // Send statements.
+ {
+ var ch chan T
+ var fch func(T) chan int
+
+ ch <- T(0)
+ fch(T(1)) <- g(T(10)).a
+ }
+
+ // Switch statements.
+ {
+ var a T
+ var b int
+ switch {
+ case a == T(0):
+ var T T
+ T = 0
+ _ = T
+ case a < T(0), b < g(T(10)).a:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+ }
+
+ switch T(g(T(10)).a) {
+ case T(0):
+ var T T
+ T = 0
+ _ = T
+ case T(1), T(g(T(10)).a):
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ switch b := g(T(10)); T(b.a) + T(10) {
+ case T(0):
+ var T T
+ T = 0
+ _ = T
+ case T(1), T(g(T(10)).a):
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ // Type switch statements.
+ {
+ var interfaceFunc func(T) interface{}
+
+ switch interfaceFunc(T(0)).(type) {
+ case *T, T, int:
+ var T T
+ T = 0
+ _ = T
+ case sync.Mutex, **T:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ switch x := interfaceFunc(T(0)).(type) {
+ case *T, T, int:
+ var T T
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **T:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ switch t := T(0); x := interfaceFunc(T(0) + t).(type) {
+ case *T, T, int:
+ var T T
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **T:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+ }
+
+ // Return statement.
+ return T(10), g(T(11)).c
+}
diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt
new file mode 100644
index 000000000..c9d0e09bf
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_stmts/opts.txt
@@ -0,0 +1 @@
+-t=T=Q
diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/generics_tests/all_stmts/output/output.go
new file mode 100644
index 000000000..a53d84535
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_stmts/output/output.go
@@ -0,0 +1,288 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "sync"
+)
+
+func h(Q) {
+}
+
+type s struct {
+ a, b int
+ c []int
+}
+
+func g(Q) *s {
+ return &s{}
+}
+
+func f() (Q, []int) {
+ // Branch.
+ goto T
+ goto R
+
+ // Labeled.
+T:
+ _ = Q(0)
+
+ // Empty.
+R:
+ ;
+
+ // Assignment with definition.
+ a, b, c := Q(1), Q(2), Q(3)
+ _, _, _ = a, b, c
+
+ // Assignment without definition.
+ g(Q(0)).a, g(Q(1)).b, c = int(Q(1)), int(Q(2)), Q(3)
+ _, _, _ = a, b, c
+
+ // Block.
+ {
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ // Declarations.
+ type Type Q
+ const Const Q = 10
+ var g1 func(Q, int, ...Q) (int, Q)
+ var v Q
+ var w = Q(0)
+ {
+ var T struct {
+ f []Q
+ }
+ _ = T
+ }
+
+ // Defer.
+ defer g1(Q(0), 1)
+
+ // Expression.
+ h(v + w + Q(1))
+
+ // For statements.
+ for i := Q(0); i < Q(10); i++ {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ for {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ // Go.
+ go g1(Q(0), 1)
+
+ // If statements.
+ if a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ if a := Q(0); a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ if a := Q(0); a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ } else if b := Q(0); b != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ } else if T := Q(0); T != 1 {
+ T++
+ } else {
+ T--
+ }
+
+ if a := Q(0); a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ } else {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ // Inc/Dec statements.
+ (*(*Q)(nil))++
+ (*(*Q)(nil))--
+
+ // Range statements.
+ for g(Q(0)).a, g(Q(1)).b = range g(Q(10)).c {
+ var d Q
+ _ = d
+ }
+
+ for T, b := range g(Q(10)).c {
+ _ = T
+ _ = b
+ }
+
+ // Select statement.
+ {
+ var fch func(Q) chan int
+
+ select {
+ case <-fch(Q(30)):
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ case T := <-fch(Q(30)):
+ T = 0
+ _ = T
+ case g(Q(0)).a = <-fch(Q(30)):
+ var T Q
+ T = 0
+ _ = T
+ case fch(Q(30)) <- int(Q(0)):
+ var T Q
+ T = 0
+ _ = T
+ }
+ }
+
+ // Send statements.
+ {
+ var ch chan Q
+ var fch func(Q) chan int
+
+ ch <- Q(0)
+ fch(Q(1)) <- g(Q(10)).a
+ }
+
+ // Switch statements.
+ {
+ var a Q
+ var b int
+ switch {
+ case a == Q(0):
+ var T Q
+ T = 0
+ _ = T
+ case a < Q(0), b < g(Q(10)).a:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+ }
+
+ switch Q(g(Q(10)).a) {
+ case Q(0):
+ var T Q
+ T = 0
+ _ = T
+ case Q(1), Q(g(Q(10)).a):
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ switch b := g(Q(10)); Q(b.a) + Q(10) {
+ case Q(0):
+ var T Q
+ T = 0
+ _ = T
+ case Q(1), Q(g(Q(10)).a):
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ // Type switch statements.
+ {
+ var interfaceFunc func(Q) interface{}
+
+ switch interfaceFunc(Q(0)).(type) {
+ case *Q, Q, int:
+ var T Q
+ T = 0
+ _ = T
+ case sync.Mutex, **Q:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ switch x := interfaceFunc(Q(0)).(type) {
+ case *Q, Q, int:
+ var T Q
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **Q:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ switch t := Q(0); x := interfaceFunc(Q(0) + t).(type) {
+ case *Q, Q, int:
+ var T Q
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **Q:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+ }
+
+ // Return statement.
+ return Q(10), g(Q(11)).c
+}
diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/generics_tests/all_types/input.go
new file mode 100644
index 000000000..3575d02ec
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/input.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import "./lib"
+
+type T int
+
+type newType struct {
+ a T
+ b lib.T
+ c *T
+ d (T)
+ e chan T
+ f <-chan T
+ g chan<- T
+ h []T
+ i [10]T
+ j map[T]T
+ k func(T, T) (T, T)
+ l interface {
+ f(T)
+ }
+ m struct {
+ T
+ a T
+ }
+}
+
+func f(...T) {
+}
diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/generics_tests/all_types/lib/lib.go
new file mode 100644
index 000000000..988786496
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/lib/lib.go
@@ -0,0 +1,17 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lib
+
+type T int32
diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt
new file mode 100644
index 000000000..c9d0e09bf
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/opts.txt
@@ -0,0 +1 @@
+-t=T=Q
diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/generics_tests/all_types/output/output.go
new file mode 100644
index 000000000..41fd147a1
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/output/output.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import "./lib"
+
+type newType struct {
+ a Q
+ b lib.T
+ c *Q
+ d (Q)
+ e chan Q
+ f <-chan Q
+ g chan<- Q
+ h []Q
+ i [10]Q
+ j map[Q]Q
+ k func(Q, Q) (Q, Q)
+ l interface {
+ f(Q)
+ }
+ m struct {
+ Q
+ a Q
+ }
+}
+
+func f(...Q) {
+}
diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/generics_tests/anon/input.go
new file mode 100644
index 000000000..44086d522
--- /dev/null
+++ b/tools/go_generics/generics_tests/anon/input.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type T interface {
+ Apply(T) T
+}
+
+type Foo struct {
+ T
+ Bar map[string]T `json:"bar,omitempty"`
+}
+
+type Baz struct {
+ T someTypeNotT
+}
+
+func (f Foo) GetBar(name string) T {
+ b, ok := f.Bar[name]
+ if ok {
+ b = f.Apply(b)
+ } else {
+ b = f.T
+ }
+ return b
+}
+
+func foobar() {
+ a := Baz{}
+ a.T = 0 // should not be renamed, this is a limitation
+
+ b := otherpkg.UnrelatedType{}
+ b.T = 0 // should not be renamed, this is a limitation
+}
diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt
new file mode 100644
index 000000000..a5e9d26de
--- /dev/null
+++ b/tools/go_generics/generics_tests/anon/opts.txt
@@ -0,0 +1 @@
+-t=T=Q -suffix=New -anon
diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/generics_tests/anon/output/output.go
new file mode 100644
index 000000000..160cddf79
--- /dev/null
+++ b/tools/go_generics/generics_tests/anon/output/output.go
@@ -0,0 +1,42 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+type FooNew struct {
+ Q
+ Bar map[string]Q `json:"bar,omitempty"`
+}
+
+type BazNew struct {
+ T someTypeNotT
+}
+
+func (f FooNew) GetBar(name string) Q {
+ b, ok := f.Bar[name]
+ if ok {
+ b = f.Apply(b)
+ } else {
+ b = f.Q
+ }
+ return b
+}
+
+func foobarNew() {
+ a := BazNew{}
+ a.Q = 0 // should not be renamed, this is a limitation
+
+ b := otherpkg.UnrelatedType{}
+ b.Q = 0 // should not be renamed, this is a limitation
+}
diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/generics_tests/consts/input.go
new file mode 100644
index 000000000..04b95fcc6
--- /dev/null
+++ b/tools/go_generics/generics_tests/consts/input.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+const c1 = 10
+const x, y, z = 100, 200, 300
+const v float32 = 1.0 + 2.0
+const s = "abc"
+const (
+ A = 10
+ B, C, D = 10, 20, 30
+ S = "abc"
+ T, U, V string = "abc", "def", "ghi"
+)
diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt
new file mode 100644
index 000000000..4fb59dce8
--- /dev/null
+++ b/tools/go_generics/generics_tests/consts/opts.txt
@@ -0,0 +1 @@
+-c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC"
diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/generics_tests/consts/output/output.go
new file mode 100644
index 000000000..18d316cc9
--- /dev/null
+++ b/tools/go_generics/generics_tests/consts/output/output.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+const c1 = 20
+const x, y, z = 100, 200, 600
+const v float32 = 3.3
+const s = "def"
+const (
+ A = 20
+ B, C, D = 10, 100, 30
+ S = "def"
+ T, U, V string = "ABC", "def", "ghi"
+)
diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/generics_tests/imports/input.go
new file mode 100644
index 000000000..0f032c2a1
--- /dev/null
+++ b/tools/go_generics/generics_tests/imports/input.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type T int
+
+var global T
+
+const (
+ m = 0
+ n = 0
+)
diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt
new file mode 100644
index 000000000..87324be79
--- /dev/null
+++ b/tools/go_generics/generics_tests/imports/opts.txt
@@ -0,0 +1 @@
+-t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath
diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/generics_tests/imports/output/output.go
new file mode 100644
index 000000000..2488ca58c
--- /dev/null
+++ b/tools/go_generics/generics_tests/imports/output/output.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ __generics_imported1 "mymathpath"
+ __generics_imported0 "sync"
+)
+
+var global __generics_imported0.Mutex
+
+const (
+ m = __generics_imported1.Uint64
+ n = __generics_imported1.Uint32
+)
diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/generics_tests/remove_typedef/input.go
new file mode 100644
index 000000000..cf632bae7
--- /dev/null
+++ b/tools/go_generics/generics_tests/remove_typedef/input.go
@@ -0,0 +1,37 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+func f(T) Q {
+ return Q{}
+}
+
+type T struct{}
+
+type Q struct{}
+
+func (*T) f() {
+}
+
+func (T) g() {
+}
+
+func (*Q) f(T) T {
+ return T{}
+}
+
+func (*Q) g(T) *T {
+ return nil
+}
diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt
new file mode 100644
index 000000000..9c8ecaada
--- /dev/null
+++ b/tools/go_generics/generics_tests/remove_typedef/opts.txt
@@ -0,0 +1 @@
+-t=T=U
diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/generics_tests/remove_typedef/output/output.go
new file mode 100644
index 000000000..d44fd8e1c
--- /dev/null
+++ b/tools/go_generics/generics_tests/remove_typedef/output/output.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+func f(U) Q {
+ return Q{}
+}
+
+type Q struct{}
+
+func (*Q) f(U) U {
+ return U{}
+}
+
+func (*Q) g(U) *U {
+ return nil
+}
diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/generics_tests/simple/input.go
new file mode 100644
index 000000000..2a917f16c
--- /dev/null
+++ b/tools/go_generics/generics_tests/simple/input.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type T int
+
+var global T
+
+func f(_ T, a int) {
+}
+
+func g(a T, b int) {
+ var c T
+ _ = c
+
+ d := (*T)(nil)
+ _ = d
+}
+
+type R struct {
+ T
+ a *T
+}
+
+var (
+ Z *T = (*T)(nil)
+)
+
+const (
+ X T = (T)(0)
+)
+
+type Y T
diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt
new file mode 100644
index 000000000..7832ef66f
--- /dev/null
+++ b/tools/go_generics/generics_tests/simple/opts.txt
@@ -0,0 +1 @@
+-t=T=Q -suffix=New
diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/generics_tests/simple/output/output.go
new file mode 100644
index 000000000..6bfa0b25b
--- /dev/null
+++ b/tools/go_generics/generics_tests/simple/output/output.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+var globalNew Q
+
+func fNew(_ Q, a int) {
+}
+
+func gNew(a Q, b int) {
+ var c Q
+ _ = c
+
+ d := (*Q)(nil)
+ _ = d
+}
+
+type RNew struct {
+ Q
+ a *Q
+}
+
+var (
+ ZNew *Q = (*Q)(nil)
+)
+
+const (
+ XNew Q = (Q)(0)
+)
+
+type YNew Q
diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD
new file mode 100644
index 000000000..38caa3ce7
--- /dev/null
+++ b/tools/go_generics/globals/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "globals",
+ srcs = [
+ "globals_visitor.go",
+ "scope.go",
+ ],
+ stateify = False,
+ visibility = ["//tools/go_generics:__pkg__"],
+)
diff --git a/tools/go_generics/globals/globals_visitor.go b/tools/go_generics/globals/globals_visitor.go
new file mode 100644
index 000000000..883f21ebe
--- /dev/null
+++ b/tools/go_generics/globals/globals_visitor.go
@@ -0,0 +1,597 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package globals provides an AST visitor that calls the visit function for all
+// global identifiers.
+package globals
+
+import (
+ "fmt"
+
+ "go/ast"
+ "go/token"
+ "path/filepath"
+ "strconv"
+)
+
+// globalsVisitor holds the state used while traversing the nodes of a file in
+// search of globals.
+//
+// The visitor does two passes on the global declarations: the first one adds
+// all globals to the global scope (since Go allows references to globals that
+// haven't been declared yet), and the second one calls f() for the definition
+// and uses of globals found in the first pass.
+//
+// The implementation correctly handles cases when globals are aliased by
+// locals; in such cases, f() is not called.
+type globalsVisitor struct {
+ // file is the file whose nodes are being visited.
+ file *ast.File
+
+ // fset is the file set the file being visited belongs to.
+ fset *token.FileSet
+
+ // f is the visit function to be called when a global symbol is reached.
+ f func(*ast.Ident, SymKind)
+
+ // scope is the current scope as nodes are visited.
+ scope *scope
+
+ // processAnon indicates whether we should process anonymous struct fields.
+ // It does not perform strict checking on parameter types that share the same name
+ // as the global type and therefore will rename them as well.
+ processAnon bool
+}
+
+// unexpected is called when an unexpected node appears in the AST. It dumps
+// the location of the associated token and panics because this should only
+// happen when there is a bug in the traversal code.
+func (v *globalsVisitor) unexpected(p token.Pos) {
+ panic(fmt.Sprintf("Unable to parse at %v", v.fset.Position(p)))
+}
+
+// pushScope creates a new scope and pushes it to the top of the scope stack.
+func (v *globalsVisitor) pushScope() {
+ v.scope = newScope(v.scope)
+}
+
+// popScope removes the scope created by the last call to pushScope.
+func (v *globalsVisitor) popScope() {
+ v.scope = v.scope.outer
+}
+
+// visitType is called when an expression is known to be a type, for example,
+// on the first argument of make(). It visits all children nodes and reports
+// any globals.
+func (v *globalsVisitor) visitType(ge ast.Expr) {
+ switch e := ge.(type) {
+ case *ast.Ident:
+ if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
+ v.f(e, s.kind)
+ }
+
+ case *ast.SelectorExpr:
+ id := GetIdent(e.X)
+ if id == nil {
+ v.unexpected(e.X.Pos())
+ }
+
+ case *ast.StarExpr:
+ v.visitType(e.X)
+ case *ast.ParenExpr:
+ v.visitType(e.X)
+ case *ast.ChanType:
+ v.visitType(e.Value)
+ case *ast.Ellipsis:
+ v.visitType(e.Elt)
+ case *ast.ArrayType:
+ v.visitExpr(e.Len)
+ v.visitType(e.Elt)
+ case *ast.MapType:
+ v.visitType(e.Key)
+ v.visitType(e.Value)
+ case *ast.StructType:
+ v.visitFields(e.Fields, KindUnknown)
+ case *ast.FuncType:
+ v.visitFields(e.Params, KindUnknown)
+ v.visitFields(e.Results, KindUnknown)
+ case *ast.InterfaceType:
+ v.visitFields(e.Methods, KindUnknown)
+ default:
+ v.unexpected(ge.Pos())
+ }
+}
+
+// visitFields visits all fields, and adds symbols if kind isn't KindUnknown.
+func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) {
+ if l == nil {
+ return
+ }
+
+ for _, f := range l.List {
+ if kind != KindUnknown {
+ for _, n := range f.Names {
+ v.scope.add(n.Name, kind, n.Pos())
+ }
+ }
+ v.visitType(f.Type)
+ if f.Tag != nil {
+ tag := ast.NewIdent(f.Tag.Value)
+ v.f(tag, KindTag)
+ // Replace the tag if updated.
+ if tag.Name != f.Tag.Value {
+ f.Tag.Value = tag.Name
+ }
+ }
+ }
+}
+
+// visitGenDecl is called when a generic declaration is encountered, for example,
+// on variable, constant and type declarations. It adds all newly defined
+// symbols to the current scope and reports them if the current scope is the
+// global one.
+func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) {
+ switch d.Tok {
+ case token.IMPORT:
+ case token.TYPE:
+ for _, gs := range d.Specs {
+ s := gs.(*ast.TypeSpec)
+ v.scope.add(s.Name.Name, KindType, s.Name.Pos())
+ if v.scope.isGlobal() {
+ v.f(s.Name, KindType)
+ }
+ v.visitType(s.Type)
+ }
+ case token.CONST, token.VAR:
+ kind := KindConst
+ if d.Tok == token.VAR {
+ kind = KindVar
+ }
+
+ for _, gs := range d.Specs {
+ s := gs.(*ast.ValueSpec)
+ if s.Type != nil {
+ v.visitType(s.Type)
+ }
+
+ for _, e := range s.Values {
+ v.visitExpr(e)
+ }
+
+ for _, n := range s.Names {
+ if v.scope.isGlobal() {
+ v.f(n, kind)
+ }
+ v.scope.add(n.Name, kind, n.Pos())
+ }
+ }
+ default:
+ v.unexpected(d.Pos())
+ }
+}
+
+// isViableType determines if the given expression is a viable type expression,
+// that is, if it could be interpreted as a type, for example, sync.Mutex,
+// myType, func(int)int, as opposed to -1, 2 * 2, a + b, etc.
+func (v *globalsVisitor) isViableType(expr ast.Expr) bool {
+ switch e := expr.(type) {
+ case *ast.Ident:
+ // This covers the plain identifier case. When we see it, we
+ // have to check if it resolves to a type; if the symbol is not
+ // known, we'll claim it's viable as a type.
+ s := v.scope.deepLookup(e.Name)
+ return s == nil || s.kind == KindType
+
+ case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis:
+ // This covers the following cases:
+ // 1. ChanType:
+ // chan T
+ // <-chan T
+ // chan<- T
+ // 2. ArrayType:
+ // [Expr]T
+ // 3. MapType:
+ // map[T]U
+ // 4. StructType:
+ // struct { Fields }
+ // 5. FuncType:
+ // func(Fields)Returns
+ // 6. Interface:
+ // interface { Fields }
+ // 7. Ellipsis:
+ // ...T
+ return true
+
+ case *ast.SelectorExpr:
+ // The only case in which an expression involving a selector can
+ // be a type is if it has the following form X.T, where X is an
+ // import, and T is a type exported by X.
+ //
+ // There's no way to know whether T is a type because we don't
+ // parse imports. So we just claim that this is a viable type;
+ // it doesn't affect the general result because we don't visit
+ // imported symbols.
+ id := GetIdent(e.X)
+ if id == nil {
+ return false
+ }
+
+ s := v.scope.deepLookup(id.Name)
+ return s != nil && s.kind == KindImport
+
+ case *ast.StarExpr:
+ // This covers the *T case. The expression is a viable type if
+ // T is.
+ return v.isViableType(e.X)
+
+ case *ast.ParenExpr:
+ // This covers the (T) case. The expression is a viable type if
+ // T is.
+ return v.isViableType(e.X)
+
+ default:
+ return false
+ }
+}
+
+// visitCallExpr visits a "call expression", which can be either a
+// function/method call (e.g., f(), pkg.f(), obj.f(), etc.) or a type
+// conversion (e.g., int32(1), (*sync.Mutex)(ptr), etc.).
+func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) {
+ if v.isViableType(e.Fun) {
+ v.visitType(e.Fun)
+ } else {
+ v.visitExpr(e.Fun)
+ }
+
+ // If the function being called is new or make, the first argument is
+ // a type, so it needs to be visited as such.
+ first := 0
+ if id := GetIdent(e.Fun); id != nil && (id.Name == "make" || id.Name == "new") {
+ if len(e.Args) > 0 {
+ v.visitType(e.Args[0])
+ }
+ first = 1
+ }
+
+ for i := first; i < len(e.Args); i++ {
+ v.visitExpr(e.Args[i])
+ }
+}
+
+// visitExpr visits all nodes of an expression, and reports any globals that it
+// finds.
+func (v *globalsVisitor) visitExpr(ge ast.Expr) {
+ switch e := ge.(type) {
+ case nil:
+ case *ast.Ident:
+ if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
+ v.f(e, s.kind)
+ }
+
+ case *ast.BasicLit:
+ case *ast.CompositeLit:
+ v.visitType(e.Type)
+ for _, ne := range e.Elts {
+ v.visitExpr(ne)
+ }
+ case *ast.FuncLit:
+ v.pushScope()
+ v.visitFields(e.Type.Params, KindParameter)
+ v.visitFields(e.Type.Results, KindResult)
+ v.visitBlockStmt(e.Body)
+ v.popScope()
+
+ case *ast.BinaryExpr:
+ v.visitExpr(e.X)
+ v.visitExpr(e.Y)
+
+ case *ast.CallExpr:
+ v.visitCallExpr(e)
+
+ case *ast.IndexExpr:
+ v.visitExpr(e.X)
+ v.visitExpr(e.Index)
+
+ case *ast.KeyValueExpr:
+ v.visitExpr(e.Value)
+
+ case *ast.ParenExpr:
+ v.visitExpr(e.X)
+
+ case *ast.SelectorExpr:
+ v.visitExpr(e.X)
+ if v.processAnon {
+ v.visitExpr(e.Sel)
+ }
+
+ case *ast.SliceExpr:
+ v.visitExpr(e.X)
+ v.visitExpr(e.Low)
+ v.visitExpr(e.High)
+ v.visitExpr(e.Max)
+
+ case *ast.StarExpr:
+ v.visitExpr(e.X)
+
+ case *ast.TypeAssertExpr:
+ v.visitExpr(e.X)
+ if e.Type != nil {
+ v.visitType(e.Type)
+ }
+
+ case *ast.UnaryExpr:
+ v.visitExpr(e.X)
+
+ default:
+ v.unexpected(ge.Pos())
+ }
+}
+
+// GetIdent returns the identifier associated with the given expression by
+// removing parentheses if needed.
+func GetIdent(expr ast.Expr) *ast.Ident {
+ switch e := expr.(type) {
+ case *ast.Ident:
+ return e
+ case *ast.ParenExpr:
+ return GetIdent(e.X)
+ default:
+ return nil
+ }
+}
+
+// visitStmt visits all nodes of a statement, and reports any globals that it
+// finds. It also adds to the current scope new symbols defined/declared.
+func (v *globalsVisitor) visitStmt(gs ast.Stmt) {
+ switch s := gs.(type) {
+ case nil, *ast.BranchStmt, *ast.EmptyStmt:
+ case *ast.AssignStmt:
+ for _, e := range s.Rhs {
+ v.visitExpr(e)
+ }
+
+ // We visit the LHS after the RHS because the symbols we'll
+ // potentially add to the table aren't meant to be visible to
+ // the RHS.
+ for _, e := range s.Lhs {
+ if s.Tok == token.DEFINE {
+ if n := GetIdent(e); n != nil {
+ v.scope.add(n.Name, KindVar, n.Pos())
+ }
+ }
+ v.visitExpr(e)
+ }
+
+ case *ast.BlockStmt:
+ v.visitBlockStmt(s)
+
+ case *ast.DeclStmt:
+ v.visitGenDecl(s.Decl.(*ast.GenDecl))
+
+ case *ast.DeferStmt:
+ v.visitCallExpr(s.Call)
+
+ case *ast.ExprStmt:
+ v.visitExpr(s.X)
+
+ case *ast.ForStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitExpr(s.Cond)
+ v.visitStmt(s.Post)
+ v.visitBlockStmt(s.Body)
+ v.popScope()
+
+ case *ast.GoStmt:
+ v.visitCallExpr(s.Call)
+
+ case *ast.IfStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitExpr(s.Cond)
+ v.visitBlockStmt(s.Body)
+ v.visitStmt(s.Else)
+ v.popScope()
+
+ case *ast.IncDecStmt:
+ v.visitExpr(s.X)
+
+ case *ast.LabeledStmt:
+ v.visitStmt(s.Stmt)
+
+ case *ast.RangeStmt:
+ v.pushScope()
+ v.visitExpr(s.X)
+ if s.Tok == token.DEFINE {
+ if n := GetIdent(s.Key); n != nil {
+ v.scope.add(n.Name, KindVar, n.Pos())
+ }
+
+ if n := GetIdent(s.Value); n != nil {
+ v.scope.add(n.Name, KindVar, n.Pos())
+ }
+ }
+ v.visitExpr(s.Key)
+ v.visitExpr(s.Value)
+ v.visitBlockStmt(s.Body)
+ v.popScope()
+
+ case *ast.ReturnStmt:
+ for _, r := range s.Results {
+ v.visitExpr(r)
+ }
+
+ case *ast.SelectStmt:
+ for _, ns := range s.Body.List {
+ c := ns.(*ast.CommClause)
+
+ v.pushScope()
+ v.visitStmt(c.Comm)
+ for _, bs := range c.Body {
+ v.visitStmt(bs)
+ }
+ v.popScope()
+ }
+
+ case *ast.SendStmt:
+ v.visitExpr(s.Chan)
+ v.visitExpr(s.Value)
+
+ case *ast.SwitchStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitExpr(s.Tag)
+ for _, ns := range s.Body.List {
+ c := ns.(*ast.CaseClause)
+ v.pushScope()
+ for _, ce := range c.List {
+ v.visitExpr(ce)
+ }
+ for _, bs := range c.Body {
+ v.visitStmt(bs)
+ }
+ v.popScope()
+ }
+ v.popScope()
+
+ case *ast.TypeSwitchStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitStmt(s.Assign)
+ for _, ns := range s.Body.List {
+ c := ns.(*ast.CaseClause)
+ v.pushScope()
+ for _, ce := range c.List {
+ v.visitType(ce)
+ }
+ for _, bs := range c.Body {
+ v.visitStmt(bs)
+ }
+ v.popScope()
+ }
+ v.popScope()
+
+ default:
+ v.unexpected(gs.Pos())
+ }
+}
+
+// visitBlockStmt visits all statements in the block, adding symbols to a newly
+// created scope.
+func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) {
+ v.pushScope()
+ for _, c := range s.List {
+ v.visitStmt(c)
+ }
+ v.popScope()
+}
+
+// visitFuncDecl is called when a function or method declaration is encountered.
+// It creates a new scope for the function's (optional) receiver, parameters and
+// results, and visits all child nodes.
+func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) {
+ // We don't report methods.
+ if d.Recv == nil {
+ v.f(d.Name, KindFunction)
+ }
+
+ v.pushScope()
+ v.visitFields(d.Recv, KindReceiver)
+ v.visitFields(d.Type.Params, KindParameter)
+ v.visitFields(d.Type.Results, KindResult)
+ if d.Body != nil {
+ v.visitBlockStmt(d.Body)
+ }
+ v.popScope()
+}
+
+// globalsFromGenDecl is called in the first pass; it adds the declared symbols
+// to the global scope.
+func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) {
+ switch d.Tok {
+ case token.IMPORT:
+ for _, gs := range d.Specs {
+ s := gs.(*ast.ImportSpec)
+ if s.Name == nil {
+ str, _ := strconv.Unquote(s.Path.Value)
+ v.scope.add(filepath.Base(str), KindImport, s.Path.Pos())
+ } else if s.Name.Name != "_" {
+ v.scope.add(s.Name.Name, KindImport, s.Name.Pos())
+ }
+ }
+ case token.TYPE:
+ for _, gs := range d.Specs {
+ s := gs.(*ast.TypeSpec)
+ v.scope.add(s.Name.Name, KindType, s.Name.Pos())
+ }
+ case token.CONST, token.VAR:
+ kind := KindConst
+ if d.Tok == token.VAR {
+ kind = KindVar
+ }
+
+ for _, s := range d.Specs {
+ for _, n := range s.(*ast.ValueSpec).Names {
+ v.scope.add(n.Name, kind, n.Pos())
+ }
+ }
+ default:
+ v.unexpected(d.Pos())
+ }
+}
+
+// visit implements the visiting of globals. It performs the two passes
+// described in the globalsVisitor struct documentation.
+func (v *globalsVisitor) visit() {
+ // Gather all symbols in the global scope. This excludes methods.
+ v.pushScope()
+ for _, gd := range v.file.Decls {
+ switch d := gd.(type) {
+ case *ast.GenDecl:
+ v.globalsFromGenDecl(d)
+ case *ast.FuncDecl:
+ if d.Recv == nil {
+ v.scope.add(d.Name.Name, KindFunction, d.Name.Pos())
+ }
+ default:
+ v.unexpected(gd.Pos())
+ }
+ }
+
+ // Go through the contents of the declarations.
+ for _, gd := range v.file.Decls {
+ switch d := gd.(type) {
+ case *ast.GenDecl:
+ v.visitGenDecl(d)
+ case *ast.FuncDecl:
+ v.visitFuncDecl(d)
+ }
+ }
+}
+
+// Visit traverses the provided AST and calls f() for each identifier that
+// refers to a global name. The global name must be defined in the file itself.
+//
+// The function f() is allowed to modify the identifier, for example, to rename
+// uses of global references.
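+//
+// A minimal caller might look like this (illustrative):
+//
+//	Visit(fset, file, func(id *ast.Ident, kind SymKind) {
+//		if kind == KindType {
+//			id.Name = "Renamed" + id.Name
+//		}
+//	}, false)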
+func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind), processAnon bool) {
+ v := globalsVisitor{
+ fset: fset,
+ file: file,
+ f: f,
+ processAnon: processAnon,
+ }
+
+ v.visit()
+}
diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go
new file mode 100644
index 000000000..eec93534b
--- /dev/null
+++ b/tools/go_generics/globals/scope.go
@@ -0,0 +1,84 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package globals
+
+import (
+ "go/token"
+)
+
+// SymKind specifies the kind of a global symbol: for example, a variable,
+// const, function, etc.
+type SymKind int
+
+// Constants for different kinds of symbols.
+const (
+ KindUnknown SymKind = iota
+ KindImport
+ KindType
+ KindVar
+ KindConst
+ KindFunction
+ KindReceiver
+ KindParameter
+ KindResult
+ KindTag
+)
+
+type symbol struct {
+ kind SymKind
+ pos token.Pos
+ scope *scope
+}
+
+type scope struct {
+ outer *scope
+ syms map[string]*symbol
+}
+
+func newScope(outer *scope) *scope {
+ return &scope{
+ outer: outer,
+ syms: make(map[string]*symbol),
+ }
+}
+
+func (s *scope) isGlobal() bool {
+ return s.outer == nil
+}
+
+func (s *scope) lookup(n string) *symbol {
+ return s.syms[n]
+}
+
+func (s *scope) deepLookup(n string) *symbol {
+ for x := s; x != nil; x = x.outer {
+ if sym := x.lookup(n); sym != nil {
+ return sym
+ }
+ }
+ return nil
+}
+
+func (s *scope) add(name string, kind SymKind, pos token.Pos) {
+ if s.syms[name] != nil {
+ return
+ }
+
+ s.syms[name] = &symbol{
+ kind: kind,
+ pos: pos,
+ scope: s,
+ }
+}
diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh
new file mode 100755
index 000000000..44b22db91
--- /dev/null
+++ b/tools/go_generics/go_generics_unittest.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Bash "safe-mode": Treat command failures as fatal (even those that occur in
+# pipes), and treat unset variables as errors.
+set -eu -o pipefail
+
+# This file will be generated as a self-extracting shell script in order to
+# eliminate the need for any runtime dependencies. The tarball at the end will
+# include the go_generics binary, as well as a subdirectory named
+# generics_tests. See the BUILD file for more information.
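+#
+# After extraction, the layout is expected to look roughly like this
+# (illustrative):
+#
+#   ${temp}/.../go_generics                       # binary under test
+#   ${temp}/.../generics_tests/<case>/input.go    # test input
+#   ${temp}/.../generics_tests/<case>/opts.txt    # go_generics flags
+#   ${temp}/.../generics_tests/<case>/output/output.go   # expected output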
+declare -r temp=$(mktemp -d)
+function cleanup() {
+ rm -rf "${temp}"
+}
+trap cleanup EXIT
+
+# Print message in "$1" then exit with status 1.
+function die () {
+ echo "$1" 1>&2
+ exit 1
+}
+
+# This prints the line number of __BUNDLE__ below, which should be the last
+# line of this script. Everything after that point is the concatenated archive.
+declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0`
+tail -n+"${tgz}" $0 | tar -xzv -C "${temp}"
+
+# The target for the test.
+declare -r binary="$(find ${temp} -type f -a -name go_generics)"
+declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*"
+
+# Go through all test cases.
+for f in ${input_dirs}; do
+ base=$(basename "${f}")
+
+ # Run go_generics on the input file.
+ opts=$(head -n 1 ${f}/opts.txt)
+ out="${f}/output/generated.go"
+ expected="${f}/output/output.go"
+ ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\""
+
+  # Compare the outputs. Test diff directly: under "set -e" a bare failing
+  # diff would abort the script before the diagnostics below could be printed.
+  if ! diff "${expected}" "${out}"; then
+ echo "Expected:"
+ cat ${expected}
+ echo "Actual:"
+ cat ${out}
+ die "Actual output is different from expected for test \"${base}\""
+ fi
+done
+
+echo "PASS"
+exit 0
+__BUNDLE__
diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD
new file mode 100644
index 000000000..2fd5a200d
--- /dev/null
+++ b/tools/go_generics/go_merge/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "go_merge",
+ srcs = ["main.go"],
+ visibility = ["//:sandbox"],
+)
diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go
new file mode 100644
index 000000000..f6a331123
--- /dev/null
+++ b/tools/go_generics/go_merge/main.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
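+// go_merge merges several Go source files belonging to the same package into a
+// single output file, combining and deduplicating their imports along the way.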
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strconv"
+)
+
+var (
+ output = flag.String("o", "", "output `file`")
+)
+
+func fatalf(s string, args ...interface{}) {
+ fmt.Fprintf(os.Stderr, s, args...)
+ os.Exit(1)
+}
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s [options] <input1> [<input2> ...]\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+
+ flag.Parse()
+ if *output == "" || len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // Load all files.
+ files := make(map[string]*ast.File)
+ fset := token.NewFileSet()
+ var name string
+ for _, fname := range flag.Args() {
+ f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors)
+ if err != nil {
+ fatalf("%v\n", err)
+ }
+
+ files[fname] = f
+ if name == "" {
+ name = f.Name.Name
+ } else if name != f.Name.Name {
+ fatalf("Expected '%s' for package name instead of '%s'.\n", name, f.Name.Name)
+ }
+ }
+
+ // Merge all files into one.
+ pkg := &ast.Package{
+ Name: name,
+ Files: files,
+ }
+ f := ast.MergePackageFiles(pkg, ast.FilterUnassociatedComments|ast.FilterFuncDuplicates|ast.FilterImportDuplicates)
+
+ // Create a new declaration slice with all imports at the top, merging any
+ // redundant imports.
+ imports := make(map[string]*ast.ImportSpec)
+ var anonImports []*ast.ImportSpec
+ for _, d := range f.Decls {
+ if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT {
+ for _, s := range g.Specs {
+ i := s.(*ast.ImportSpec)
+ p, _ := strconv.Unquote(i.Path.Value)
+ var n string
+ if i.Name == nil {
+ n = filepath.Base(p)
+ } else {
+ n = i.Name.Name
+ }
+ if n == "_" {
+ anonImports = append(anonImports, i)
+ } else {
+ if i2, ok := imports[n]; ok {
+ if first, second := i.Path.Value, i2.Path.Value; first != second {
+ fatalf("Conflicting paths for import name '%s': '%s' vs. '%s'\n", n, first, second)
+ }
+ } else {
+ imports[n] = i
+ }
+ }
+ }
+ }
+ }
+ newDecls := make([]ast.Decl, 0, len(f.Decls))
+ if l := len(imports) + len(anonImports); l > 0 {
+ // Non-NoPos Lparen is needed for Go to recognize more than one spec in
+ // ast.GenDecl.Specs.
+ d := &ast.GenDecl{
+ Tok: token.IMPORT,
+ Lparen: token.NoPos + 1,
+ Specs: make([]ast.Spec, 0, l),
+ }
+ for _, i := range imports {
+ d.Specs = append(d.Specs, i)
+ }
+ for _, i := range anonImports {
+ d.Specs = append(d.Specs, i)
+ }
+ newDecls = append(newDecls, d)
+ }
+ for _, d := range f.Decls {
+ if g, ok := d.(*ast.GenDecl); !ok || g.Tok != token.IMPORT {
+ newDecls = append(newDecls, d)
+ }
+ }
+ f.Decls = newDecls
+
+ // Write the output file.
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, f); err != nil {
+ fatalf("%v\n", err)
+ }
+
+ if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil {
+ fatalf("%v\n", err)
+ }
+}
diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go
new file mode 100644
index 000000000..148dc7216
--- /dev/null
+++ b/tools/go_generics/imports.go
@@ -0,0 +1,150 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "strconv"
+
+ "gvisor.dev/gvisor/tools/go_generics/globals"
+)
+
+type importedPackage struct {
+ newName string
+ path string
+}
+
+// updateImportIdent modifies the given import identifier with the new name
+// stored in the used map. If the identifier doesn't exist in the used map yet,
+// a new name is generated and inserted into the map.
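+// Generated names have the form __generics_importedN (e.g. __generics_imported0).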
+func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error {
+ importName := id.Name
+
+ // If the name is already in the table, just use the new name.
+ m := used[importName]
+ if m != nil {
+ id.Name = m.newName
+ return nil
+ }
+
+ // Create a new entry in the used map.
+ path := imports[importName]
+ if path == "" {
+ return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig)
+ }
+
+ m = &importedPackage{
+ newName: fmt.Sprintf("__generics_imported%d", len(used)),
+ path: strconv.Quote(path),
+ }
+ used[importName] = m
+
+ id.Name = m.newName
+
+ return nil
+}
+
+// convertExpression creates a new string that is a copy of the input one with
+// all import references renamed to the names in the "used" map. If the
+// referenced import isn't in "used" yet, a new one is created based on the path
+// in "imports" and stored in "used". For example, if string s is
+// "math.MaxUint32-math.MaxUint16+10", it would be converted to
+// "x.MaxUint32-x.MaxUint16+10", where x is a generated name.
+func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) {
+ // Parse the expression in the input string.
+ expr, err := parser.ParseExpr(s)
+ if err != nil {
+ return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err)
+ }
+
+ // Go through the AST and update references.
+ var retErr error
+ ast.Inspect(expr, func(n ast.Node) bool {
+ switch x := n.(type) {
+ case *ast.SelectorExpr:
+ if id := globals.GetIdent(x.X); id != nil {
+ if err := updateImportIdent(s, imports, id, used); err != nil {
+ retErr = err
+ }
+ return false
+ }
+ }
+ return true
+ })
+ if retErr != nil {
+ return "", retErr
+ }
+
+ // Convert the modified AST back to a string.
+ fset := token.NewFileSet()
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, expr); err != nil {
+ return "", err
+ }
+
+ return string(buf.Bytes()), nil
+}
+
+// updateImports replaces all maps in the input slice with copies where the
+// mapped values have had all references to imported packages renamed to
+// generated names. It also returns an import declaration for all the renamed
+// import packages.
+//
+// For example, if the input maps contain A=math.B and C=math.D, the updated
+// maps will instead contain A=__generics_imported0.B and
+// C=__generics_imported0.D, and 'import __generics_imported0 "math"' would be
+// returned as the import declaration.
+func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) {
+ importsUsed := make(map[string]*importedPackage)
+
+ // Update all maps.
+ for i, m := range maps {
+ newMap := make(mapValue)
+ for n, e := range m {
+ updated, err := convertExpression(e, imports, importsUsed)
+ if err != nil {
+ return nil, err
+ }
+
+ newMap[n] = updated
+ }
+ maps[i] = newMap
+ }
+
+ // Nothing else to do if no imports are used in the expressions.
+ if len(importsUsed) == 0 {
+ return nil, nil
+ }
+
+ // Create spec array for each new import.
+ specs := make([]ast.Spec, 0, len(importsUsed))
+ for _, i := range importsUsed {
+ specs = append(specs, &ast.ImportSpec{
+ Name: &ast.Ident{Name: i.newName},
+ Path: &ast.BasicLit{Value: i.path},
+ })
+ }
+
+ return &ast.GenDecl{
+ Tok: token.IMPORT,
+ Specs: specs,
+ Lparen: token.NoPos + 1,
+ }, nil
+}
diff --git a/tools/go_generics/remove.go b/tools/go_generics/remove.go
new file mode 100644
index 000000000..568a6bbd3
--- /dev/null
+++ b/tools/go_generics/remove.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+type typeSet map[string]struct{}
+
+// isTypeOrPointerToType determines if the given AST expression represents a
+// type or a pointer to a type that exists in the provided type set.
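+//
+// For example, given a set containing "T", this matches T, *T and (*T);
+// starCount tracks how many pointer dereferences have already been traversed.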
+func isTypeOrPointerToType(set typeSet, expr ast.Expr, starCount int) bool {
+ switch e := expr.(type) {
+ case *ast.Ident:
+ _, ok := set[e.Name]
+ return ok
+ case *ast.StarExpr:
+ if starCount > 1 {
+ return false
+ }
+ return isTypeOrPointerToType(set, e.X, starCount+1)
+ case *ast.ParenExpr:
+ return isTypeOrPointerToType(set, e.X, starCount)
+ default:
+ return false
+ }
+}
+
+// isMethodOf determines if the given function declaration is a method of one
+// of the types in the provided type set. To do that, it checks if the function
+// has a receiver and that its type is either T or *T, where T is a type that
+// exists in the set. This is per the spec:
+//
+// That parameter section must declare a single parameter, the receiver. Its
+// type must be of the form T or *T (possibly using parentheses) where T is a
+// type name. The type denoted by T is called the receiver base type; it must
+// not be a pointer or interface type and it must be declared in the same
+// package as the method.
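+//
+// For example, with "T" in the set, both 'func (t T) Foo()' and
+// 'func (t *T) Foo()' are treated as methods of T.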
+func isMethodOf(set typeSet, f *ast.FuncDecl) bool {
+ // If the function doesn't have exactly one receiver, then it's
+ // definitely not a method.
+ if f.Recv == nil || len(f.Recv.List) != 1 {
+ return false
+ }
+
+ return isTypeOrPointerToType(set, f.Recv.List[0].Type, 0)
+}
+
+// removeTypeDefinitions removes the definition of all types contained in the
+// provided type set.
+func removeTypeDefinitions(set typeSet, d *ast.GenDecl) {
+ if d.Tok != token.TYPE {
+ return
+ }
+
+ i := 0
+ for _, gs := range d.Specs {
+ s := gs.(*ast.TypeSpec)
+ if _, ok := set[s.Name.Name]; !ok {
+ d.Specs[i] = gs
+ i++
+ }
+ }
+
+ d.Specs = d.Specs[:i]
+}
+
+// removeTypes removes from the AST the definition of all types and their
+// method sets that are contained in the provided type set.
+func removeTypes(set typeSet, f *ast.File) {
+ // Go through the top-level declarations.
+ i := 0
+ for _, decl := range f.Decls {
+ keep := true
+ switch d := decl.(type) {
+ case *ast.GenDecl:
+ countBefore := len(d.Specs)
+ removeTypeDefinitions(set, d)
+ keep = countBefore == 0 || len(d.Specs) > 0
+ case *ast.FuncDecl:
+ keep = !isMethodOf(set, d)
+ }
+
+ if keep {
+ f.Decls[i] = decl
+ i++
+ }
+ }
+
+ f.Decls = f.Decls[:i]
+}
diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD
new file mode 100644
index 000000000..8a329dfc6
--- /dev/null
+++ b/tools/go_generics/rules_tests/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "instance",
+ out = "instance_test.go",
+ consts = {
+ "n": "20",
+ "m": "\"test\"",
+ "o": "math.MaxUint64",
+ },
+ imports = {
+ "math": "math",
+ },
+ package = "template_test",
+ template = ":test_template",
+ types = {
+ "t": "int",
+ },
+)
+
+go_template(
+ name = "test_template",
+ srcs = [
+ "template.go",
+ ],
+ opt_consts = [
+ "n",
+ "m",
+ "o",
+ ],
+ opt_types = ["t"],
+)
+
+go_test(
+ name = "template_test",
+ srcs = [
+ "instance_test.go",
+ "template_test.go",
+ ],
+)
diff --git a/tools/go_generics/rules_tests/template.go b/tools/go_generics/rules_tests/template.go
new file mode 100644
index 000000000..aace61da1
--- /dev/null
+++ b/tools/go_generics/rules_tests/template.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package template
+
+type t float
+
+const (
+ n t = 10.1
+ m = "abc"
+ o = 0
+)
+
+func max(a, b t) t {
+ if a > b {
+ return a
+ }
+ return b
+}
+
+func add(a t) t {
+ return a + n
+}
+
+func getName() string {
+ return m
+}
+
+func getMax() uint64 {
+ return o
+}
diff --git a/tools/go_generics/rules_tests/template_test.go b/tools/go_generics/rules_tests/template_test.go
new file mode 100644
index 000000000..b2a3446ef
--- /dev/null
+++ b/tools/go_generics/rules_tests/template_test.go
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package template_test
+
+import (
+ "math"
+ "testing"
+)
+
+func TestMax(t *testing.T) {
+ var a int = max(10, 20)
+ if a != 20 {
+ t.Errorf("Bad result of max, got %v, want %v", a, 20)
+ }
+}
+
+func TestIntConst(t *testing.T) {
+ var a int = add(10)
+ if a != 30 {
+ t.Errorf("Bad result of add, got %v, want %v", a, 30)
+ }
+}
+
+func TestStrConst(t *testing.T) {
+ v := getName()
+ if v != "test" {
+ t.Errorf("Bad name, got %v, want %v", v, "test")
+ }
+}
+
+func TestImport(t *testing.T) {
+ v := getMax()
+ if v != math.MaxUint64 {
+ t.Errorf("Bad max value, got %v, want %v", v, uint64(math.MaxUint64))
+ }
+}
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
new file mode 100644
index 000000000..be49cf9c8
--- /dev/null
+++ b/tools/go_marshal/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_binary")
+
+licenses(["notice"])
+
+go_binary(
+ name = "go_marshal",
+ srcs = ["main.go"],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//tools/go_marshal/gomarshal",
+ ],
+)
+
+config_setting(
+ name = "marshal_config_verbose",
+ values = {"define": "gomarshal=verbose"},
+)
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
new file mode 100644
index 000000000..4886efddf
--- /dev/null
+++ b/tools/go_marshal/README.md
@@ -0,0 +1,116 @@
+This package implements the go_marshal utility.
+
+# Overview
+
+`go_marshal` is a code generation utility similar to `go_stateify` for
+automatically generating code to marshal go data structures to memory.
+
+`go_marshal` attempts to improve on `binary.Write` and the sentry's
+`binary.Marshal` by moving the go runtime reflection necessary to marshal a
+struct to compile-time.
+
+`go_marshal` automatically generates implementations for `abi.Marshallable` and
+`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
+implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
+`safemem.Writer.WriteFromBlocks`. Data structures that require custom
+serialization will have manual implementations for these interfaces.
+
+Data structures can be flagged for code generation by adding a struct-level
+comment `// +marshal`.
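+
+For example, a minimal sketch of a marshallable struct (the type and field
+names here are illustrative, not from the source):
+
+```
+// +marshal
+type PortRange struct {
+    Start uint16
+    End   uint16
+}
+```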
+
+# Usage
+
+See `defs.bzl`: a new rule is provided, `go_marshal`.
+
+Under the hood, the `go_marshal` rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (note that this explicit listing is the preferred method):
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ ...
+)
+```
+
+As part of the interface generation, `go_marshal` also generates some tests for
+sanity checking the struct definitions for potential alignment issues, and a
+simple round-trip test through Marshal/Unmarshal to verify the implementation.
+These tests use reflection to verify properties of the ABI struct, and should be
+considered part of the generated interfaces (but are too expensive to execute at
+runtime). Ensure these tests run at some point.
+
+# Restrictions
+
+Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is
+intended for ABI structs, which have these additional restrictions:
+
+- At the moment, `go_marshal` only supports struct declarations.
+
+- Structs are marshalled as packed types. This means no implicit padding is
+ inserted between fields shorter than the platform register size. For
+ alignment, manually insert padding fields.
+
+- Structs used with `go_marshal` must have a compile-time static size. This
+ means no dynamically sized fields like slices or strings. Use statically
+ sized arrays (byte arrays for strings) instead; see the sketch after this
+ list.
+
+- No pointer, channel, map, or function pointer fields, and no fields that are
+ arrays of these types. These don't make sense in an ABI data structure.
+
+- We could support opaque pointers as `uintptr`, but this is currently not
+ implemented. Implementing this would require handling the architecture
+ dependent native pointer size.
+
+- Fields must either be a primitive integer type (`byte`,
+ `[u]int{8,16,32,64}`), or of a type that implements `abi.Marshallable`.
+
+- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric
+ type.
+
+- `float*` fields are currently not supported, but could be if necessary.
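+
+Putting these restrictions together, a hypothetical conforming struct might
+look like this (all names are illustrative):
+
+```
+// +marshal
+type Record struct {
+    ID   uint64
+    Name [16]byte // Statically sized stand-in for a string.
+    Flag uint8
+    _    [7]byte // Explicit padding; no implicit padding is allowed.
+}
+```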
+
+# Appendix
+
+## Working with Non-Packed Structs
+
+ABI structs must generally be packed types, meaning they should have no implicit
+padding between short fields. However, if a field is tagged
+`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower
+mechanism to deal with potentially unaligned fields.
+
+Note that the non-packed property is inherited by any other struct that embeds
+this struct, since the `go_marshal` tool currently can't reason about alignments
+for embedded structs that are not aligned.
+
+Because of this, it's generally best to avoid using `marshal:"unaligned"` and
+insert explicit padding fields instead.
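+
+As a sketch (names are illustrative), tagging `B` below avoids the
+alignment-check failure that the implicit 4-byte gap after `A` would otherwise
+trigger, at the cost of the slower marshalling path:
+
+```
+// +marshal
+type Sample struct {
+    A uint32
+    B uint64 `marshal:"unaligned"` // Prefer an explicit `_ [4]byte` pad instead.
+}
+```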
+
+## Modifying the `go_marshal` Tool
+
+The following are some guidelines for modifying the `go_marshal` tool:
+
+- The `go_marshal` tool currently does a single pass over all types requesting
+ code generation, in arbitrary order. This means the generated code can't
+ directly obtain information about embedded marshallable types at
+ compile-time. One way to work around this restriction is to add a new
+ Marshallable interface method providing this piece of information, and
+ calling it from the generated code. Use this sparingly, as we want to rely
+ on compile-time information as much as possible for performance.
+
+- No runtime reflection in the code generated for the marshallable interface.
+ The entire point of the tool is to avoid runtime reflection. The generated
+ tests may use reflection.
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD
new file mode 100644
index 000000000..c2a4d45c4
--- /dev/null
+++ b/tools/go_marshal/analysis/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "analysis",
+ testonly = 1,
+ srcs = ["analysis_unsafe.go"],
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
new file mode 100644
index 000000000..cd55cf5cb
--- /dev/null
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -0,0 +1,179 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package analysis implements common functionality used by generated
+// go_marshal tests.
+package analysis
+
+// All functions in this package are unsafe and are not intended for general
+// consumption. They contain sharp edge cases and the caller is responsible for
+// ensuring none of them are hit. Callers must be careful to pass in only sane
+// arguments. Failure to do so may cause panics at best and arbitrary memory
+// corruption at worst.
+//
+// Never use outside of tests.
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+// RandomizeValue assigns random value(s) to an arbitrary type. This is intended
+// for use with ABI structs from go_marshal, meaning the typical restrictions
+// apply (fixed-size types, no pointers, maps, channels, etc), and should only
+// be used on zeroed values to avoid overwriting pointers to active go objects.
+//
+// Internally, we populate the type with random data by doing an unsafe cast to
+// access the underlying memory of the type and filling it as if it were a byte
+// slice. This almost gets us what we want, but padding fields named "_" are
+// normally not accessible, so we walk the type and recursively zero all "_"
+// fields.
+//
+// Precondition: x must be a pointer. x must not contain any valid
+// pointers to active go objects (pointer fields aren't allowed in ABI
+// structs anyway), or we'd be violating the go runtime contract and
+// the GC may malfunction.
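+//
+// A sketch of typical use from a generated test (Foo is a hypothetical
+// zero-valued ABI struct):
+//
+// var f Foo
+// RandomizeValue(&f)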
+func RandomizeValue(x interface{}) {
+ v := reflect.Indirect(reflect.ValueOf(x))
+ if !v.CanSet() {
+ panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument")
+ }
+
+ // Cast the underlying memory for the type into a byte slice.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer.
+ hdr.Data = v.UnsafeAddr()
+ hdr.Len = int(v.Type().Size())
+ hdr.Cap = hdr.Len
+
+ // Fill the byte slice with random data, which in effect fills the type with
+ // random values.
+ n, err := rand.Read(b)
+ if err != nil || n != len(b) {
+ panic("unreachable")
+ }
+
+ // Normally, padding fields are not accessible, so zero them out.
+ reflectZeroPaddingFields(v.Type(), b, false)
+}
+
+// reflectZeroPaddingFields assigns zero values to padding fields for the value
+// of type r, represented by the memory in data. Padding fields are defined as
+// fields with the name "_". If zero is true, the immediate value itself is
+// zeroed. In addition, the type is recursively scanned for padding fields in
+// inner types.
+//
+// This is used for zeroing padding fields after calling RandomizeValue.
+func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) {
+ if zero {
+ for i := range data {
+ data[i] = 0
+ }
+ }
+ switch r.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // These types are explicitly allowed in an ABI type, but we don't need
+ // to recurse further as they're scalar types.
+ case reflect.Struct:
+ for i, numFields := 0, r.NumField(); i < numFields; i++ {
+ f := r.Field(i)
+ off := f.Offset
+ sz := f.Type.Size() // Avoid shadowing the builtin len.
+ window := data[off : off+sz]
+ reflectZeroPaddingFields(f.Type, window, f.Name == "_")
+ }
+ case reflect.Array:
+ eLen := int(r.Elem().Size())
+ if int(r.Size()) != eLen*r.Len() {
+ panic("Array has unexpected size?")
+ }
+ for i, n := 0, r.Len(); i < n; i++ {
+ reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false)
+ }
+ default:
+ panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind()))
+
+ }
+}
+
+// AlignmentCheck ensures the definition of the type represented by typ doesn't
+// cause the go compiler to emit implicit padding between elements of the type
+// (i.e. fields in a struct).
+//
+// AlignmentCheck doesn't explicitly recurse for embedded structs because any
+// struct present in an ABI struct must also be Marshallable, and therefore
+// they're aligned by definition (or their alignment check would have failed).
+func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // Primitive types are always considered well aligned. Primitive types
+ // that are fields in structs are checked independently, this branch
+ // exists to handle recursive calls to AlignmentCheck.
+ case reflect.Struct:
+ xOff := 0
+ nextXOff := 0
+ skipNext := false
+ for i, numFields := 0, typ.NumField(); i < numFields; i++ {
+ xOff = nextXOff
+ f := typ.Field(i)
+ fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size())
+ nextXOff = int(f.Offset + f.Type.Size())
+
+ if f.Name == "_" {
+ // Padding fields need not be aligned.
+ fmt.Printf("Padding field of type %v\n", f.Type)
+ continue
+ }
+
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ skipNext = true
+ continue
+ }
+
+ if skipNext {
+ skipNext = false
+ fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name)
+ continue
+ }
+
+ if xOff != int(f.Offset) {
+ implicitPad := int(f.Offset) - xOff
+ t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad)
+ }
+ }
+
+ // Ensure structs end on a byte explicitly defined by the type.
+ if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
+ implicitPad := int(typ.Size()) - nextXOff
+ f := typ.Field(typ.NumField() - 1) // Final field
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ // Final field explicitly marked unaligned.
+ break
+ }
+ t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
+ typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
+ }
+ case reflect.Array:
+ // Independent arrays are also always considered well aligned. We only
+ // need to worry about their alignment when they're embedded in structs,
+ // which we handle above.
+ default:
+ t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind())
+ }
+ return true, uint64(typ.Size())
+}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
new file mode 100644
index 000000000..323e33882
--- /dev/null
+++ b/tools/go_marshal/defs.bzl
@@ -0,0 +1,65 @@
+"""Marshal is a tool for generating marshalling interfaces for Go types."""
+
+def _go_marshal_impl(ctx):
+ """Execute the go_marshal tool."""
+ output = ctx.outputs.lib
+ output_test = ctx.outputs.test
+
+ # Run the marshal command.
+ args = ["-output=%s" % output.path]
+ args += ["-pkg=%s" % ctx.attr.package]
+ args += ["-output_test=%s" % output_test.path]
+
+ if ctx.attr.debug:
+ args += ["-debug"]
+
+ args += ["--"]
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files.to_list()]
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output, output_test],
+ mnemonic = "GoMarshal",
+ progress_message = "go_marshal: %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+# Generates marshalling interfaces (and their tests) from a set of Go files.
+#
+# Args:
+# name: the name of the rule.
+# srcs: the input source files. These files should include all structs in the
+# package that need to be marshalled.
+# imports: an optional list of extra, non-aliased, Go-style absolute import
+# paths.
+# out: the name of the generated file output. This must not conflict with any
+# other files and must be added to the srcs of the relevant go_library.
+# package: the package name for the input sources.
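+#
+# Example usage (a sketch mirroring the README; the "foo" names are
+# illustrative):
+#
+# go_marshal(
+#     name = "foo_abi",
+#     srcs = ["foo.go"],
+#     out = "foo_abi.go",
+#     package = "foo",
+# )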
+go_marshal = rule(
+ implementation = _go_marshal_impl,
+ attrs = {
+ "srcs": attr.label_list(mandatory = True, allow_files = True),
+ "imports": attr.string_list(mandatory = False),
+ "package": attr.string(mandatory = True),
+ "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")),
+ },
+ outputs = {
+ "lib": "%{name}_unsafe.go",
+ "test": "%{name}_test.go",
+ },
+)
+
+# marshal_deps are the dependencies required by generated code.
+marshal_deps = [
+ "//pkg/gohacks",
+ "//pkg/safecopy",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+]
+
+# marshal_test_deps are required by test targets.
+marshal_test_deps = [
+ "//tools/go_marshal/analysis",
+]
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
new file mode 100644
index 000000000..44cb33ae4
--- /dev/null
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "gomarshal",
+ srcs = [
+ "generator.go",
+ "generator_interfaces.go",
+ "generator_interfaces_array_newtype.go",
+ "generator_interfaces_primitive_newtype.go",
+ "generator_interfaces_struct.go",
+ "generator_tests.go",
+ "util.go",
+ ],
+ stateify = False,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = ["//tools/tags"],
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
new file mode 100644
index 000000000..177013dbb
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -0,0 +1,499 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gomarshal implements the go_marshal code generator. See README.md.
+package gomarshal
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "sort"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/tags"
+)
+
+// List of identifiers we use in generated code that may conflict with a
+// similarly-named source identifier. Abort gracefully when we see these to
+// avoid potentially confusing compilation failures in generated code.
+//
+// This only applies to import aliases at the moment. All other identifiers
+// are qualified by a receiver argument, since they're struct fields.
+//
+// All receivers are single letters, so we don't allow import aliases to be a
+// single letter.
+var badIdents = []string{
+ "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner",
+ "length", "limit", "ptr", "size", "src", "srcs", "task", "val",
+ // All single-letter identifiers.
+}
+
+// Constructed from badIdents in init().
+var badIdentsMap map[string]struct{}
+
+func init() {
+ badIdentsMap = make(map[string]struct{})
+ for _, ident := range badIdents {
+ badIdentsMap[ident] = struct{}{}
+ }
+}
+
+// Generator drives code generation for a single invocation of the go_marshal
+// utility.
+//
+// The Generator holds arguments passed to the tool, and drives parsing,
+// processing and code generation for all types marked with +marshal declared in
+// the input files.
+//
+// See Generator.run() as the entry point.
+type Generator struct {
+ // Paths to input go source files.
+ inputs []string
+ // Output file to write generated go source.
+ output *os.File
+ // Output file to write generated tests.
+ outputTest *os.File
+ // Package name for the generated file.
+ pkg string
+ // Set of extra packages to import in the generated file.
+ imports *importTable
+}
+
+// NewGenerator creates a new code Generator.
+func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) {
+ f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
+ }
+ fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
+ }
+ g := Generator{
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ pkg: pkg,
+ imports: newImportTable(),
+ }
+ for _, i := range imports {
+ // All imports on the extra imports list are unconditionally marked as
+ // used, so that they're always added to the generated code.
+ g.imports.add(i).markUsed()
+ }
+
+ // The following imports may or may not be used by the generated code,
+ // depending on what's required for the target types. Don't mark these as
+ // used by default.
+ g.imports.add("io")
+ g.imports.add("reflect")
+ g.imports.add("runtime")
+ g.imports.add("unsafe")
+ g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
+ g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
+ g.imports.add("gvisor.dev/gvisor/pkg/usermem")
+ g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal")
+
+ return &g, nil
+}
+
+// writeHeader writes the header for the generated source file. The header
+// includes the package name, package level comments and import statements.
+func (g *Generator) writeHeader() error {
+ var b sourceBuffer
+ b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(g.inputs); len(t) > 0 {
+ b.emit(strings.Join(t.Lines(), "\n"))
+ b.emit("\n\n")
+ }
+
+ // Package header.
+ b.emit("package %s\n\n", g.pkg)
+ if err := b.write(g.output); err != nil {
+ return err
+ }
+
+ return g.imports.write(g.output)
+}
+
+// writeTypeChecks writes a statement to force the compiler to perform a type
+// check for all Marshallable types referenced by the generated code.
+func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
+ if len(ms) == 0 {
+ return nil
+ }
+
+ msl := make([]string, 0, len(ms))
+ for m := range ms {
+ msl = append(msl, m)
+ }
+ sort.Strings(msl)
+
+ var buf bytes.Buffer
+ fmt.Fprint(&buf, "// Marshallable types used by this file.\n")
+
+ for _, m := range msl {
+ fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
+ }
+ fmt.Fprint(&buf, "\n")
+
+ _, err := fmt.Fprint(g.output, buf.String())
+ return err
+}
+
+// parse processes all input files passed to this generator and produces a set of
+// parsed go ASTs.
+func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
+ debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
+ for _, path := range g.inputs {
+ debugf(" %s\n", path)
+ }
+
+ files := make([]*ast.File, 0, len(g.inputs))
+ fsets := make([]*token.FileSet, 0, len(g.inputs))
+
+ for _, path := range g.inputs {
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
+ if err != nil {
+ // Not a valid input file?
+ return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err)
+ }
+
+ if debugEnabled() {
+ debugf("AST for %q:\n", path)
+ ast.Print(fset, f)
+ }
+
+ files = append(files, f)
+ fsets = append(fsets, fset)
+ }
+
+ return files, fsets, nil
+}
+
+// sliceAPI carries information about the '+marshal slice' directive.
+type sliceAPI struct {
+ // Comment node in the AST containing the +marshal tag.
+ comment *ast.Comment
+ // Identifier fragment to use when naming generated functions for the slice
+ // API.
+ ident string
+ // Whether the generated functions should reference the newtype name, or the
+ // inner type name. Only meaningful on newtype declarations on primitives.
+ inner bool
+}
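+
+// For illustration, a type declaration tagged as follows (the names are
+// hypothetical):
+//
+// // +marshal slice:FooSlice
+// type Foo struct { ... }
+//
+// causes go_marshal to additionally emit slice-handling functions whose names
+// incorporate the "FooSlice" identifier.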
+
+// marshallableType carries information about a type marked with the '+marshal'
+// directive.
+type marshallableType struct {
+ spec *ast.TypeSpec
+ slice *sliceAPI
+}
+
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
+ mt := marshallableType{
+ spec: spec,
+ slice: nil,
+ }
+
+ var unhandledTags []string
+
+ for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
+ if strings.HasPrefix(tag, "slice:") {
+ tokens := strings.Split(tag, ":")
+ if len(tokens) < 2 || len(tokens) > 3 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
+ }
+ if len(tokens[1]) == 0 {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
+ }
+
+ sa := &sliceAPI{
+ comment: tagLine,
+ ident: tokens[1],
+ }
+ mt.slice = sa
+
+ if len(tokens) == 3 {
+ if tokens[2] != "inner" {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
+ }
+ sa.inner = true
+ }
+
+ continue
+ }
+
+ unhandledTags = append(unhandledTags, tag)
+ }
+
+ if len(unhandledTags) > 0 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
+ }
+
+ return mt
+}
+
+// collectMarshallableTypes walks the parsed AST and collects a list of type
+// declarations for which we need to generate the Marshallable interface.
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
+ var types []marshallableType
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Type declaration?
+ if !ok || gdecl.Tok != token.TYPE {
+ debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
+ continue
+ }
+ // Does it have a comment?
+ if gdecl.Doc == nil {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
+ continue
+ }
+ // Does the comment contain a "+marshal" line?
+ marked := false
+ var tagLine *ast.Comment
+ for _, c := range gdecl.Doc.List {
+ if strings.HasPrefix(c.Text, "// +marshal") {
+ marked = true
+ tagLine = c
+ break
+ }
+ }
+ if !marked {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ // We already confirmed we're in a type declaration earlier, so this
+ // cast will succeed.
+ t := spec.(*ast.TypeSpec)
+ switch t.Type.(type) {
+ case *ast.StructType:
+ debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
+ case *ast.Ident: // Newtype on primitive.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
+ case *ast.ArrayType: // Newtype on array.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
+ default:
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
+ }
+ types = append(types, newMarshallableType(f, tagLine, t))
+ }
+ }
+ return types
+}
+
+// collectImports collects all imports from all input source files. Some of
+// these imports are copied to the generated output, if they're referenced by
+// the generated code.
+//
+// collectImports de-duplicates imports while building the list, and ensures
+// identifiers in the generated code don't conflict with any imported package
+// names.
+func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
+ is := make(map[string]importStmt)
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Import statement?
+ if !ok || gdecl.Tok != token.IMPORT {
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
+ debugf("Collected import '%s' as '%s'\n", i.path, i.name)
+
+ // Make sure we have an import that doesn't use any local names that
+ // would conflict with identifiers in the generated code.
+ if len(i.name) == 1 && i.name != "_" {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
+ }
+ if _, ok := badIdentsMap[i.name]; ok {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
+ }
+ }
+ }
+ return is
+}
+
+func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, fset)
+ switch ty := t.spec.Type.(type) {
+ case *ast.StructType:
+ i.validateStruct(t.spec, ty)
+ i.emitMarshallableForStruct(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForStruct(ty, t.slice)
+ }
+ case *ast.Ident:
+ i.validatePrimitiveNewtype(ty)
+ i.emitMarshallableForPrimitiveNewtype(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
+ }
+ case *ast.ArrayType:
+ i.validateArrayNewtype(t.spec.Name, ty)
+ // After validation, we can safely call arrayLenExpr.
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
+ }
+ default:
+ // This should've been filtered out by collectMarshallableTypes.
+ panic(fmt.Sprintf("Unexpected type %+v", ty))
+ }
+ return i
+}
+
+// generateOneTestSuite generates a test suite for the automatically generated
+// implementations type t.
+func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec)
+ i.emitTests(t.slice)
+ return i
+}
+
+// Run is the entry point to code generation using g.
+//
+// Run parses all input source files specified in g and emits generated code.
+func (g *Generator) Run() error {
+ // Parse our input source files into ASTs and token sets.
+ asts, fsets, err := g.parse()
+ if err != nil {
+ return err
+ }
+
+ if len(asts) != len(fsets) {
+ panic("ASTs and FileSets don't match")
+ }
+
+ // Map of imports in source files; key = local package name, value = import
+ // statement.
+ is := make(map[string]importStmt)
+ for i, a := range asts {
+ // Collect all imports from the source files. We may need to copy some
+ // of these to the generated code if they're referenced. This has to be
+ // done before the loop below because we need to process all ASTs before
+ // we start requesting imports to be copied one by one as we encounter
+ // them in each generated source.
+ for name, i := range g.collectImports(a, fsets[i]) {
+ is[name] = i
+ }
+ }
+
+ var impls []*interfaceGenerator
+ var ts []*testGenerator
+ // Set of Marshallable types referenced by generated code.
+ ms := make(map[string]struct{})
+ for i, a := range asts {
+ // Collect type declarations marked for code generation and generate
+ // Marshallable interfaces.
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
+ impl := g.generateOne(t, fsets[i])
+ // Collect Marshallable types referenced by the generated code.
+ for ref := range impl.ms {
+ ms[ref] = struct{}{}
+ }
+ impls = append(impls, impl)
+ // Collect imports referenced by the generated code and add them to
+ // the list of imports we need to copy to the generated code.
+ for name := range impl.is {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
+ }
+ }
+ ts = append(ts, g.generateOneTestSuite(t))
+ }
+ }
+
+ // Write the output file header. This includes things like the package name
+ // and import statements.
+ if err := g.writeHeader(); err != nil {
+ return err
+ }
+
+ // Write type checks for referenced marshallable types to output file.
+ if err := g.writeTypeChecks(ms); err != nil {
+ return err
+ }
+
+ // Write generated interfaces to output file.
+ for _, i := range impls {
+ if err := i.write(g.output); err != nil {
+ return err
+ }
+ }
+
+ // Write generated tests to test file.
+ return g.writeTests(ts)
+}
+
+// writeTests outputs tests for the generated interface implementations to a go
+// source file.
+func (g *Generator) writeTests(ts []*testGenerator) error {
+ var b sourceBuffer
+ b.emit("package %s\n\n", g.pkg)
+ if err := b.write(g.outputTest); err != nil {
+ return err
+ }
+
+ // Collect and write test import statements.
+ imports := newImportTable()
+ for _, t := range ts {
+ imports.merge(t.imports)
+ }
+
+ if err := imports.write(g.outputTest); err != nil {
+ return err
+ }
+
+ // Write test functions.
+
+ // If we didn't generate any Marshallable implementations, we can't just
+ // emit an empty test file, since that causes the build to fail with "no
+ // tests/benchmarks/examples found". Unfortunately we can't signal bazel to
+ // omit the entire package since the outputs are already defined before
+ // go-marshal is called. If we'd otherwise emit an empty test suite, emit an
+ // empty example instead.
+ if len(ts) == 0 {
+ b.reset()
+ b.emit("func Example() {\n")
+ b.inIndent(func() {
+ b.emit("// This example is intentionally empty to ensure this file contains at least\n")
+ b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
+ b.emit("// is marked marshallable, but emitting a test file with no entities results\n")
+ b.emit("// in a build failure.\n")
+ })
+ b.emit("}\n")
+ return b.write(g.outputTest)
+ }
+
+ for _, t := range ts {
+ if err := t.write(g.outputTest); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
new file mode 100644
index 000000000..e3c3dac63
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -0,0 +1,276 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+// interfaceGenerator generates marshalling interfaces for a single type.
+//
+// An interfaceGenerator is not thread-safe.
+type interfaceGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // FileSet containing the tokens for the type we're processing.
+ f *token.FileSet
+
+ // is records external packages referenced by the generated implementation.
+ is map[string]struct{}
+
+ // ms records Marshallable types referenced by the generated implementation
+ // of t's interfaces.
+ ms map[string]struct{}
+
+ // as records embedded fields in t that are potentially not packed. The key
+ // is the accessor for the field.
+ as map[string]struct{}
+}
+
+// typeName returns the name of the type this g represents.
+func (g *interfaceGenerator) typeName() string {
+ return g.t.Name.Name
+}
+
+// newInterfaceGenerator creates a new interface generator.
+func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ g := &interfaceGenerator{
+ t: t,
+ r: receiverName(t),
+ f: fset,
+ is: make(map[string]struct{}),
+ ms: make(map[string]struct{}),
+ as: make(map[string]struct{}),
+ }
+ g.recordUsedMarshallable(g.typeName())
+ return g
+}
+
+func (g *interfaceGenerator) recordUsedMarshallable(m string) {
+ g.ms[m] = struct{}{}
+}
+
+func (g *interfaceGenerator) recordUsedImport(i string) {
+ g.is[i] = struct{}{}
+}
+
+func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
+ g.as[fieldName] = struct{}{}
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with a
+// reference position in the input source. Same as the package-level abortAt,
+// but uses g to resolve p to a position.
+func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
+ abortAt(g.f.Position(p), msg)
+}
+
+// scalarSize returns the size of type identified by t. If t isn't a primitive
+// type, the size isn't known at code generation time, and must be resolved via
+// the marshal.Marshallable interface.
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
+ switch t.Name {
+ case "int8", "uint8", "byte":
+ return 1, false
+ case "int16", "uint16":
+ return 2, false
+ case "int32", "uint32":
+ return 4, false
+ case "int64", "uint64":
+ return 8, false
+ default:
+ return 0, true
+ }
+}
+
+func (g *interfaceGenerator) shift(bufVar string, n int) {
+ g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
+}
+
+func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
+ g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
+}
+
+// marshalScalar writes a single scalar to a byte slice.
+func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 2)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 4)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ }
+}
+
+// unmarshalScalar reads a single scalar from a byte slice.
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "byte":
+ g.emit("%s = %s[0]\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
+ g.shift(bufVar, 2)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
+ g.shift(bufVar, 4)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ g.recordPotentiallyNonPackedField(accessor)
+ }
+}
+
+// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
+// byte slice, bypassing escape analysis. The caller is responsible for ensuring
+// srcPtr lives until they're done with dstVar, as the runtime does not
+// consider dstVar dependent on srcPtr due to the escape analysis bypass.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "hdr", and cannot be used
+// in a context where it is already bound.
+func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.recordUsedImport("gohacks")
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
+}
+
+// emitCastSliceToByteSlice unsafely casts a slice with elements of an arbitrary
+// type to a byte slice. As part of the cast, the byte slice is made to look
+// independent of the src slice by bypassing escape analysis. This means the
+// byte slice can be used without causing the source to escape. The caller is
+// responsible for ensuring srcPtr lives until they're done with dstVar, as the
+// runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function internally uses the identifiers "ptr", "val" and "hdr",
+// and cannot be used in a context where these identifiers are already bound.
+func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.emitNoEscapeSliceDataPointer(srcPtr, "val")
+
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(val)\n")
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
+}
+
+// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
+// unsafe.Pointer, bypassing escape analysis. The caller is responsible for
+// ensuring srcPtr lives until they're done with dstVar, as the runtime no
+// longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "ptr" cannot be used in a
+// context where this identifier is already bound.
+func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
+ g.recordUsedImport("gohacks")
+ g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
+ g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
+}
+
+func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
+ g.emit("// must live until the use above.\n")
+ g.emit("runtime.KeepAlive(%s)\n", ptrVar)
+}
+
+func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
+ switch x := e.X.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, x)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", x.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", x.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+
+ fmt.Fprintf(b, "%s", e.Op)
+
+ switch y := e.Y.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, y)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", y.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", y.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+}
+
+// arrayLenExpr returns a string containing a valid golang expression
+// representing the length of array a. The returned expression should be treated
+// as a single value, and will be already parenthesized as required.
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
+ var b strings.Builder
+
+ switch l := a.Len.(type) {
+ case *ast.Ident:
+ fmt.Fprintf(&b, "%s", l.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(&b, "%s", l.Value)
+ case *ast.BinaryExpr:
+ g.expandBinaryExpr(&b, l)
+ return fmt.Sprintf("(%s)", b.String())
+ default:
+ g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+ return b.String()
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
new file mode 100644
index 000000000..72ef03a22
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on arrays.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
+ if a.Len == nil {
+ g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+}
+
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ lenExpr := g.arrayLenExpr(a)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(elt); !dynamic {
+ g.emit("return %d * %s\n", size, lenExpr)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Array newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
new file mode 100644
index 000000000..39f654ea8
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -0,0 +1,289 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on primitives.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+// marshalPrimitiveScalar writes a single primitive variable to a byte
+// slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+
+ eltType := g.typeName()
+ if slice.inner {
+ eltType = nt.Name
+ }
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
+ g.emit("//go:nosplit\n")
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
+ g.emit("//go:nosplit\n")
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+}
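
[To make the emission above concrete: for a hypothetical newtype `type UID uint32` marked +marshal, the generator would emit methods roughly like the following. This is a sketch derived from the emit calls above, not verbatim tool output.]

// SizeBytes implements marshal.Marshallable.SizeBytes.
//go:nosplit
func (u *UID) SizeBytes() int {
	return 4
}

// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (u *UID) MarshalBytes(dst []byte) {
	usermem.ByteOrder.PutUint32(dst[:4], uint32(*u))
}

// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (u *UID) UnmarshalBytes(src []byte) {
	*u = UID(uint32(usermem.ByteOrder.Uint32(src[:4])))
}

// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (u *UID) Packed() bool {
	// Scalar newtypes are always packed.
	return true
}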
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
new file mode 100644
index 000000000..9cd3c9579
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -0,0 +1,618 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// structs.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "strings"
+)
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
+ forEachStructField(st, func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ forEachStructField(st, func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ g.validatePrimitiveNewtype(t)
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However, this
+ // callback must still be provided.
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
+ g.validateArrayNewtype(n, a)
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
+ packed := true
+ forEachStructField(st, func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if packed {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ packed = false
+ }
+ }
+ }
+ })
+ return packed
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ thisPacked := g.isStructPacked(st)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
+ g.emit("// partially unmarshalled struct.\n")
+ g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("length, err := w.Write(buf)\n")
+ g.emit("return int64(length), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
+ thisPacked := g.isStructPacked(st)
+
+ if slice.inner {
+ abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
+ }
+
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n\n")
+
+ g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
+ g.emit("limit := length/size\n")
+ g.emit("for idx := 0; idx < limit; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Handle any final partial object.\n")
+ g.emit("if length < size*count && length%size != 0 {\n")
+ g.inIndent(func() {
+ g.emit("idx := limit\n")
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return task.CopyOutBytes(addr, buf)\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
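
[As an illustration of the struct path: a hypothetical packed struct such as

type Point struct {
	X int32
	Y int32
	_ uint32 // Anonymous padding field.
}

would get methods roughly like the sketch below. It assumes marshalScalar, defined elsewhere in this package, emits the usual put-and-advance sequence; exact output may differ.]

// SizeBytes implements marshal.Marshallable.SizeBytes.
func (p *Point) SizeBytes() int {
	return 12
}

// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (p *Point) MarshalBytes(dst []byte) {
	usermem.ByteOrder.PutUint32(dst[:4], uint32(p.X))
	dst = dst[4:]
	usermem.ByteOrder.PutUint32(dst[:4], uint32(p.Y))
	dst = dst[4:]
	// Padding: dst[:sizeof(uint32)] ~= uint32(0)
	dst = dst[4:]
}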
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
new file mode 100644
index 000000000..631295373
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "io"
+ "strings"
+)
+
+var standardImports = []string{
+ "bytes",
+ "fmt",
+ "reflect",
+ "testing",
+
+ "gvisor.dev/gvisor/tools/go_marshal/analysis",
+}
+
+var sliceAPIImports = []string{
+ "encoding/binary",
+ "gvisor.dev/gvisor/pkg/usermem",
+}
+
+type testGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // Imports used by generated code.
+ imports *importTable
+
+ // Import statement for the package declaring the type we generated code
+ // for. We need this to construct test instances for the type, since the
+ // tests aren't written in the same package.
+ decl *importStmt
+}
+
+func newTestGenerator(t *ast.TypeSpec) *testGenerator {
+ g := &testGenerator{
+ t: t,
+ r: receiverName(t),
+ imports: newImportTable(),
+ }
+
+ for _, i := range standardImports {
+ g.imports.add(i).markUsed()
+ }
+ // These imports are used if a type requests the slice API. Don't
+ // mark them as used by default.
+ for _, i := range sliceAPIImports {
+ g.imports.add(i)
+ }
+
+ return g
+}
+
+func (g *testGenerator) typeName() string {
+ return g.t.Name.Name
+}
+
+func (g *testGenerator) testFuncName(base string) string {
+ return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
+}
+
+func (g *testGenerator) inTestFunction(name string, body func()) {
+ g.emit("func %s(t *testing.T) {\n", g.testFuncName(name))
+ g.inIndent(body)
+ g.emit("}\n\n")
+}
+
+func (g *testGenerator) emitTestNonZeroSize() {
+ g.inTestFunction("TestSizeNonZero", func() {
+ g.emit("var x %v\n", g.typeName())
+ g.emit("if x.SizeBytes() == 0 {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSuspectAlignment() {
+ g.inTestFunction("TestSuspectAlignment", func() {
+ g.emit("var x %v\n", g.typeName())
+ g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
+ g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
+ g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("buf := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalBytes(buf)\n")
+ g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
+
+ g.emit("y.UnmarshalBytes(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("z.UnmarshalUnsafe(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, z) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n")
+ })
+ g.emit("}\n")
+ g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
+ for _, name := range []string{"binary", "usermem"} {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
+ }
+ }
+
+ g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
+ g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+ g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
+ g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
+ g.emit("buf.Reset()\n")
+ g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
+ })
+ g.emit("}\n")
+ g.emit("bufUnsafe := make([]byte, size)\n")
+ g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
+
+ g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+ })
+}
+
+func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
+ g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
+ g.emit("var x, y, yUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("var buf bytes.Buffer\n\n")
+
+ g.emit("x.WriteTo(&buf)\n")
+ g.emit("y.UnmarshalBytes(buf.Bytes())\n\n")
+ g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n")
+
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
+ g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() {
+ g.emit("var x %s\n", g.typeName())
+ g.emit("sizeFromConcrete := x.SizeBytes()\n")
+ g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTests(slice *sliceAPI) {
+ g.emitTestNonZeroSize()
+ g.emitTestSuspectAlignment()
+ g.emitTestMarshalUnmarshalPreservesData()
+ g.emitTestWriteToUnmarshalPreservesData()
+ g.emitTestSizeBytesOnTypedNilPtr()
+
+ if slice != nil {
+ g.emitTestMarshalUnmarshalSlicePreservesData(slice)
+ }
+}
+
+func (g *testGenerator) write(out io.Writer) error {
+ return g.sourceBuffer.write(out)
+}
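
[For a hypothetical marshallable type Foo, the first two test emitters above produce tests of this shape (a sketch assembled from the emit calls, not verbatim output):]

func TestSizeNonZeroFoo(t *testing.T) {
	var x Foo
	if x.SizeBytes() == 0 {
		t.Fatal("Marshallable.SizeBytes() should not return zero")
	}
}

func TestSuspectAlignmentFoo(t *testing.T) {
	var x Foo
	analysis.AlignmentCheck(t, reflect.TypeOf(x))
}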
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
new file mode 100644
index 000000000..d94314302
--- /dev/null
+++ b/tools/go_marshal/gomarshal/util.go
@@ -0,0 +1,491 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "io"
+ "os"
+ "path"
+ "reflect"
+ "sort"
+ "strings"
+)
+
+var debug = flag.Bool("debug", false, "enables debugging output")
+
+// receiverName returns an appropriate receiver name given a type spec.
+func receiverName(t *ast.TypeSpec) string {
+ if len(t.Name.Name) < 1 {
+ // Zero length type name?
+ panic("unreachable")
+ }
+ return strings.ToLower(t.Name.Name[:1])
+}
+
+// kindString returns a user-friendly representation of an AST expr type.
+func kindString(e ast.Expr) string {
+ switch e.(type) {
+ case *ast.Ident:
+ return "scalar"
+ case *ast.ArrayType:
+ return "array"
+ case *ast.StructType:
+ return "struct"
+ case *ast.StarExpr:
+ return "pointer"
+ case *ast.FuncType:
+ return "function"
+ case *ast.InterfaceType:
+ return "interface"
+ case *ast.MapType:
+ return "map"
+ case *ast.ChanType:
+ return "channel"
+ default:
+ return reflect.TypeOf(e).String()
+ }
+}
+
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+// fieldDispatcher is a collection of callbacks for handling different types of
+// fields in a struct declaration.
+type fieldDispatcher struct {
+ primitive func(n, t *ast.Ident)
+ selector func(n, tX, tSel *ast.Ident)
+ array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
+ unhandled func(n *ast.Ident)
+}
+
+// Precondition: All dispatch callbacks that will be invoked must be
+// provided. Embedded fields are not allowed, len(f.Names) >= 1.
+func (fd fieldDispatcher) dispatch(f *ast.Field) {
+ // Each field declaration may actually be multiple declarations of the same
+ // type. For example, consider:
+ //
+ // type Point struct {
+ // x, y, z int
+ // }
+ //
+ // We invoke the callbacks once per such instance. Embedded fields are not
+ // allowed, and result in a panic.
+ if len(f.Names) < 1 {
+ panic("Precondition not met: attempted to dispatch on embedded field")
+ }
+
+ for _, name := range f.Names {
+ switch v := f.Type.(type) {
+ case *ast.Ident:
+ fd.primitive(name, v)
+ case *ast.SelectorExpr:
+ fd.selector(name, v.X.(*ast.Ident), v.Sel)
+ case *ast.ArrayType:
+ switch t := v.Elt.(type) {
+ case *ast.Ident:
+ fd.array(name, v, t)
+ default:
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
+ }
+ default:
+ fd.unhandled(name)
+ }
+ }
+}
+
+// debugEnabled indicates whether debugging is enabled for gomarshal.
+func debugEnabled() bool {
+ return *debug
+}
+
+// abort aborts the go_marshal tool with the given error message.
+func abort(msg string) {
+ if !strings.HasSuffix(msg, "\n") {
+ msg += "\n"
+ }
+ fmt.Print(msg)
+ os.Exit(1)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with
+// a reference position to the input source.
+func abortAt(p token.Position, msg string) {
+ abort(fmt.Sprintf("%v:\n %s\n", p, msg))
+}
+
+// debugf conditionally prints a debug message.
+func debugf(f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf(f, a...)
+ }
+}
+
+// debugfAt conditionally prints a debug message with a reference to a position
+// in the input source.
+func debugfAt(p token.Position, f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...))
+ }
+}
+
+// emit generates a line of code in the output file.
+//
+// emit is a wrapper around writing a formatted string to the output
+// buffer. emit can be invoked in one of two ways:
+//
+// (1) emit("some string")
+// When emit is called with a single string argument, it is simply copied to
+// the output buffer without any further formatting.
+// (2) emit(fmtString, args...)
+// emit can also be invoked in a similar fashion to *Printf() functions,
+// where the first argument is a format string.
+//
+// Calling emit with a single argument that is not a string will result in a
+// panic, as the caller's intent is ambiguous.
+func emit(out io.Writer, indent int, a ...interface{}) {
+ const spacesPerIndentLevel = 4
+
+ if len(a) < 1 {
+ panic("emit() called with no arguments")
+ }
+
+ if indent > 0 {
+ if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
+ // Writing to the emit output should not fail. Typically the output
+ // is a bytes.Buffer; writes to these never fail.
+ panic(err)
+ }
+ }
+
+ first, ok := a[0].(string)
+ if !ok {
+ // First argument must be either the string to emit (case 1 from
+ // function-level comment), or a format string (case 2).
+ panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
+ }
+
+ if len(a) == 1 {
+ // Single string argument. Assume no formatting requested.
+ if _, err := fmt.Fprint(out, first); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+ return
+ }
+
+ // Formatting requested.
+ if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+}
+
+// sourceBuffer represents fragments of generated go source code.
+//
+// sourceBuffer provides a convenient way to build up Go source fragments in
+// memory. May be safely zero-value initialized. Not thread-safe.
+type sourceBuffer struct {
+ // Current indentation level.
+ indent int
+
+ // Memory buffer containing contents while they're being generated.
+ b bytes.Buffer
+}
+
+func (b *sourceBuffer) reset() {
+ b.indent = 0
+ b.b.Reset()
+}
+
+func (b *sourceBuffer) incIndent() {
+ b.indent++
+}
+
+func (b *sourceBuffer) decIndent() {
+ if b.indent <= 0 {
+ panic("decIndent() without matching incIndent()")
+ }
+ b.indent--
+}
+
+func (b *sourceBuffer) emit(a ...interface{}) {
+ emit(&b.b, b.indent, a...)
+}
+
+func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
+ emit(&b.b, 0 /*indent*/, a...)
+}
+
+func (b *sourceBuffer) inIndent(body func()) {
+ b.incIndent()
+ body()
+ b.decIndent()
+}
+
+func (b *sourceBuffer) write(out io.Writer) error {
+ _, err := fmt.Fprint(out, b.b.String())
+ return err
+}
+
+// Write implements io.Writer.Write.
+func (b *sourceBuffer) Write(buf []byte) (int, error) {
+ return b.b.Write(buf)
+}
+
+// importStmt represents a single import statement.
+type importStmt struct {
+ // Local name of the imported package.
+ name string
+ // Import path.
+ path string
+ // Indicates whether the local name is an alias, or simply the final
+ // component of the path.
+ aliased bool
+ // Indicates whether this import was referenced by generated code.
+ used bool
+ // AST node and file set representing the import statement, if any. These
+ // are only non-nil if the import statement originates from an input source
+ // file.
+ spec *ast.ImportSpec
+ fset *token.FileSet
+}
+
+func newImport(p string) *importStmt {
+ name := path.Base(p)
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: false,
+ }
+}
+
+func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
+ name := path.Base(p)
+ if name == "" || name == "/" || name == "." {
+ panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
+ f.Position(spec.Path.Pos()), name))
+ }
+ if spec.Name != nil {
+ name = spec.Name.Name
+ }
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: spec.Name != nil,
+ spec: spec,
+ fset: f,
+ }
+}
+
+// String implements fmt.Stringer.String. This generates a string for the import
+// statement appropriate for writing directly to generated code.
+func (i *importStmt) String() string {
+ if i.aliased {
+ return fmt.Sprintf("%s %q", i.name, i.path)
+ }
+ return fmt.Sprintf("%q", i.path)
+}
+
+// debugString returns a debug string representing an import statement. This
+// representation is not valid golang code and is used for debugging output.
+func (i *importStmt) debugString() string {
+ if i.spec != nil && i.fset != nil {
+ return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
+ }
+ return fmt.Sprintf("(go-marshal import): %s", i)
+}
+
+func (i *importStmt) markUsed() {
+ i.used = true
+}
+
+func (i *importStmt) equivalent(other *importStmt) bool {
+ return i.name == other.name && i.path == other.path && i.aliased == other.aliased
+}
+
+// importTable represents a collection of importStmts.
+//
+// An importTable may contain multiple import statements referencing the same
+// local name. All import statements aliasing to the same local name are
+// technically ambiguous: if such an import name is used in the generated
+// code, it's not clear which import statement it refers to. We ignore any
+// potential collisions until actually writing the import table to the generated
+// source file. See importTable.write.
+//
+// Given the following import statements across all the files comprising a
+// package being marshalled:
+//
+// "sync"
+// "pkg/sync"
+// "pkg/sentry/kernel"
+// ktime "pkg/sentry/kernel/time"
+//
+// An importTable representing them would look like this:
+//
+// importTable {
+// is: map[string][]*importStmt {
+// "sync": []*importStmt{
+// importStmt{name:"sync", path:"sync", aliased:false}
+// importStmt{name:"sync", path:"pkg/sync", aliased:false}
+// },
+// "kernel": []*importStmt{importStmt{
+// name: "kernel",
+// path: "pkg/sentry/kernel",
+// aliased: false
+// }},
+// "ktime": []*importStmt{importStmt{
+// name: "ktime",
+// path: "pkg/sentry/kernel/time",
+// aliased: true,
+// }},
+// }
+// }
+//
+// Note that the local name "sync" is assigned to two different import
+// statements. This is possible if the import statements are from different
+// source files in the same package.
+//
+// Since go-marshal generates a single output file per package regardless of the
+// number of input files, if "sync" is referenced by any generated code, it's
+// unclear which import statement "sync" refers to. While it's theoretically
+// possible to resolve this by assigning a unique local alias to each instance
+// of the sync package, go-marshal currently aborts when it encounters such an
+// ambiguity.
+//
+// TODO(b/151478251): importTable considers the final component of an import
+// path to be the package name, but this is only a convention. The actual
+// package name is determined by the package statement in the source files for
+// the package.
+type importTable struct {
+ // Map of imports and whether they should be copied to the output.
+ is map[string][]*importStmt
+}
+
+func newImportTable() *importTable {
+ return &importTable{
+ is: make(map[string][]*importStmt),
+ }
+}
+
+// Merges import statements from other into i.
+func (i *importTable) merge(other *importTable) {
+ for name, ims := range other.is {
+ i.is[name] = append(i.is[name], ims...)
+ }
+}
+
+func (i *importTable) addStmt(s *importStmt) *importStmt {
+ i.is[s.name] = append(i.is[s.name], s)
+ return s
+}
+
+func (i *importTable) add(s string) *importStmt {
+ n := newImport(s)
+ return i.addStmt(n)
+}
+
+func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ return i.addStmt(newImportFromSpec(spec, f))
+}
+
+// Marks the import named n as used. If no such import is in the table, returns
+// false.
+func (i *importTable) markUsed(n string) bool {
+ if ns, ok := i.is[n]; ok {
+ for _, n := range ns {
+ n.markUsed()
+ }
+ return true
+ }
+ return false
+}
+
+func (i *importTable) clear() {
+ for _, is := range i.is {
+ for _, i := range is {
+ i.used = false
+ }
+ }
+}
+
+func (i *importTable) write(out io.Writer) error {
+ if len(i.is) == 0 {
+ // Nothing to import, we're done.
+ return nil
+ }
+
+ imports := make([]string, 0, len(i.is))
+ for name, is := range i.is {
+ var lastUsed *importStmt
+ var ambiguous bool
+
+ for _, i := range is {
+ if i.used {
+ if lastUsed != nil {
+ if !i.equivalent(lastUsed) {
+ ambiguous = true
+ }
+ }
+ lastUsed = i
+ }
+ }
+
+ if ambiguous {
+ // We have two or more import statements across the different source
+ // files that share a local name, and at least one of these imports
+ // is used by the generated code. This ambiguity can't be resolved
+ // by go-marshal and requires user intervention. Dump a list of
+ // the colliding import statements and let the user modify the input
+ // files as appropriate.
+ var b strings.Builder
+ fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
+ fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
+ // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
+ // be true. Therefore the slicing below is safe.
+ for _, i := range is[:len(is)-1] {
+ fmt.Fprintf(&b, " %v\n", i.debugString())
+ }
+ fmt.Fprintf(&b, " %v", is[len(is)-1].debugString())
+ panic(b.String())
+ }
+
+ if lastUsed != nil {
+ imports = append(imports, lastUsed.String())
+ }
+ }
+ sort.Strings(imports)
+
+ var b sourceBuffer
+ b.emit("import (\n")
+ b.incIndent()
+ for _, i := range imports {
+ b.emit("%s\n", i)
+ }
+ b.decIndent()
+ b.emit(")\n\n")
+
+ return b.write(out)
+}
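
[A minimal sketch of how the generators drive sourceBuffer, showing both emit calling conventions and inIndent. It assumes it runs inside the gomarshal package (sourceBuffer is unexported) with "os" imported.]

var b sourceBuffer
b.emit("func hello() {\n") // Case (1): plain string, copied as-is.
b.inIndent(func() {
	b.emit("fmt.Println(%q)\n", "hello, world") // Case (2): Printf-style formatting.
})
b.emit("}\n")
// Flush the accumulated source to any io.Writer.
if err := b.write(os.Stdout); err != nil {
	panic(err)
}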
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
new file mode 100644
index 000000000..f74be5c29
--- /dev/null
+++ b/tools/go_marshal/main.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_marshal is a code generation utility for automatically generating code to
+// marshal go data structures to memory.
+//
+// This binary is typically run as part of the build process, and is invoked by
+// the go_marshal bazel rule defined in defs.bzl.
+//
+// See README.md.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/go_marshal/gomarshal"
+)
+
+var (
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
+)
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ if len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ if *pkg == "" {
+ flag.Usage()
+ fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n")
+ os.Exit(1)
+ }
+
+ var extraImports []string
+ if len(*imports) > 0 {
+ // Note: strings.Split(s, sep) returns []string{s} if sep doesn't exist in s. Thus
+ // we check for an empty imports list to avoid emitting an empty string
+ // as an import.
+ extraImports = strings.Split(*imports, ",")
+ }
+ g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports)
+ if err != nil {
+ panic(err)
+ }
+
+ if err := g.Run(); err != nil {
+ panic(err)
+ }
+}
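
[A hypothetical invocation, normally assembled by the go_marshal bazel rule rather than typed by hand; the input file names here are made up:

go_marshal -pkg=linux -output=abi_marshal.go -output_test=abi_marshal_test.go epoll.go rusage.go]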
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
new file mode 100644
index 000000000..bacfaa5a4
--- /dev/null
+++ b/tools/go_marshal/marshal/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "marshal",
+ srcs = [
+ "marshal.go",
+ ],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/usermem",
+ ],
+)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
new file mode 100644
index 000000000..cb2166252
--- /dev/null
+++ b/tools/go_marshal/marshal/marshal.go
@@ -0,0 +1,187 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal defines the Marshallable interface for
+// serializing/deserializing Go data structures to/from memory, according to the
+// Linux ABI.
+//
+// Implementations of this interface are typically automatically generated by
+// tools/go_marshal. See the go_marshal README for details.
+package marshal
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Task provides a subset of kernel.Task, used in marshalling. We don't import
+// the kernel package directly to avoid circular dependency.
+type Task interface {
+ // CopyScratchBuffer provides a task goroutine-local scratch buffer. See
+ // kernel.CopyScratchBuffer.
+ CopyScratchBuffer(size int) []byte
+
+ // CopyOutBytes writes the contents of b to the task's memory. See
+ // kernel.CopyOutBytes.
+ CopyOutBytes(addr usermem.Addr, b []byte) (int, error)
+
+ // CopyInBytes reads the contents of the task's memory to b. See
+ // kernel.CopyInBytes.
+ CopyInBytes(addr usermem.Addr, b []byte) (int, error)
+}
+
+// Marshallable represents operations on a type that can be marshalled to and
+// from memory.
+//
+// go-marshal automatically generates implementations for this interface for
+// types marked as '+marshal'.
+type Marshallable interface {
+ io.WriterTo
+
+ // SizeBytes is the size of the memory representation of a type in
+ // marshalled form.
+ //
+ // SizeBytes must handle a nil receiver. Practically, this means SizeBytes
+ // cannot dereference any fields on the object implementing it (but will
+ // likely make use of the type of these fields).
+ SizeBytes() int
+
+ // MarshalBytes serializes a copy of a type to dst. dst may be smaller than
+ // SizeBytes(), which results in a part of the struct being marshalled. Note
+ // that this may have unexpected results for non-packed types, as implicit
+ // padding needs to be taken into account when reasoning about how much of
+ // the type is serialized.
+ MarshalBytes(dst []byte)
+
+ // UnmarshalBytes deserializes a type from src. src may be smaller than
+ // SizeBytes(), which results in a partially deserialized struct. Note that
+ // this may have unexpected results for non-packed types, as implicit
+ // padding needs to be taken into account when reasoning about how much of
+ // the type is deserialized.
+ UnmarshalBytes(src []byte)
+
+ // Packed returns true if the marshalled size of the type is the same as the
+ // size it occupies in memory. This happens when the type has no fields
+ // starting at unaligned addresses (should always be true by default for ABI
+ // structs, verified by automatically generated tests when using
+ // go_marshal), and has no fields marked `marshal:"unaligned"`.
+ //
+ // Packed must return the same result for all possible values of the type
+ // implementing it. Violating this constraint implies the type doesn't have
+ // a static memory layout, and will lead to memory corruption.
+ // Go-marshal-generated code reuses the result of Packed for multiple values
+ // of the same type.
+ Packed() bool
+
+ // MarshalUnsafe serializes a type by bulk copying its in-memory
+ // representation to the dst buffer. This is only safe to do when the type
+ // has no implicit padding, see Marshallable.Packed. When Packed would
+ // return false, MarshalUnsafe should fall back to the safer but slower
+ // MarshalBytes. dst may be smaller than SizeBytes(), see comment for
+ // MarshalBytes for implications.
+ MarshalUnsafe(dst []byte)
+
+ // UnmarshalUnsafe deserializes a type by directly copying to the underlying
+ // memory allocated for the object by the runtime.
+ //
+ // This allows much faster unmarshalling of types which have no implicit
+ // padding, see Marshallable.Packed. When Packed would return false,
+ // UnmarshalUnsafe should fall back to the safer but slower unmarshal
+ // mechanism implemented in UnmarshalBytes. src may be smaller than
+ // SizeBytes(), see comment for UnmarshalBytes for implications.
+ UnmarshalUnsafe(src []byte)
+
+ // CopyIn deserializes a Marshallable type from a task's memory. This may
+ // only be called from a task goroutine. This is more efficient than calling
+ // UnmarshalUnsafe on Marshallable.Packed types, as the type being
+ // marshalled does not escape. The implementation should avoid creating
+ // extra copies in memory by directly deserializing to the object's
+ // underlying memory.
+ //
+ // If the copy-in from the task memory is only partially successful, CopyIn
+ // should still attempt to deserialize as much data as possible. See comment
+ // for UnmarshalBytes.
+ CopyIn(task Task, addr usermem.Addr) (int, error)
+
+ // CopyOut serializes a Marshallable type to a task's memory. This may only
+ // be called from a task goroutine. This is more efficient than calling
+ // MarshalUnsafe on Marshallable.Packed types, as the type being serialized
+ // does not escape. The implementation should avoid creating extra copies in
+ // memory by directly serializing from the object's underlying memory.
+ //
+ // The copy-out to the task memory may be partially successful, in which
+ // case CopyOut returns how much data was serialized. See comment for
+ // MarshalBytes for implications.
+ CopyOut(task Task, addr usermem.Addr) (int, error)
+
+ // CopyOutN is like CopyOut, but explicitly requests a partial
+ // copy-out. Note that this may yield unexpected results for non-packed
+ // types and the caller may only want to allow this for packed types. See
+ // comment on MarshalBytes.
+ //
+ // The limit must be less than or equal to SizeBytes().
+ CopyOutN(task Task, addr usermem.Addr, limit int) (int, error)
+}
+
+// go-marshal generates additional functions for a type based on additional
+// clauses to the +marshal directive. They are documented below.
+//
+// Slice API
+// =========
+//
+// Adding a "slice" clause to the +marshal directive for structs or newtypes on
+// primitives like this:
+//
+// // +marshal slice:FooSlice
+// type Foo struct { ... }
+//
+// Generates four additional functions for marshalling slices of Foos like this:
+//
+// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, but for a []Foo. It's
+// // more efficient than repeatedly calling Foo.MarshalUnsafe over a
+// // []Foo in a loop.
+// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... }
+//
+// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, but for a []Foo. It's
+// // more efficient than repeatedly calling Foo.UnmarshalUnsafe over a
+// // []Foo in a loop.
+// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
+//
+// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
+// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... }
+//
+// // CopyFooSliceOut copies out a slice of Foo objects to the task's memory.
+// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... }
+//
+// The names of the generated functions are of the format "Copy%sIn" and
+// "Copy%sOut", where %s is the first argument to the slice clause. This
+// directive is not supported for newtypes on arrays.
+//
+// The slice clause also takes an optional second argument, which must be the
+// value "inner":
+//
+// // +marshal slice:Int32Slice:inner
+// type Int32 int32
+//
+// This is only valid on newtypes on primitives, and causes the generated
+// functions to accept slices of the inner type instead:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... }
+//
+// Without "inner", they would instead be:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... }
+//
+// This may help avoid a cast depending on how the generated functions are used.
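+//
+// Example usage
+// =============
+//
+// As an illustrative sketch only (not part of this package), a syscall
+// implementation might use a Marshallable type as follows. Here `t` is assumed
+// to implement Task, `statAddr` is a user address, and `linux.Stat` stands in
+// for any type marked +marshal:
+//
+// var s linux.Stat
+// if _, err := s.CopyIn(t, statAddr); err != nil {
+//     return err
+// }
+// // ... read or modify s ...
+// if _, err := s.CopyOut(t, statAddr); err != nil {
+//     return err
+// }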
diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD
new file mode 100644
index 000000000..cc08ba63a
--- /dev/null
+++ b/tools/go_marshal/primitive/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "primitive",
+ srcs = [
+ "primitive.go",
+ ],
+ marshal = True,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go
new file mode 100644
index 000000000..ebcf130ae
--- /dev/null
+++ b/tools/go_marshal/primitive/primitive.go
@@ -0,0 +1,175 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package primitive defines marshal.Marshallable implementations for primitive
+// types.
+package primitive
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// Int16 is a marshal.Marshallable implementation for int16.
+//
+// +marshal slice:Int16Slice:inner
+type Int16 int16
+
+// Uint16 is a marshal.Marshallable implementation for uint16.
+//
+// +marshal slice:Uint16Slice:inner
+type Uint16 uint16
+
+// Int32 is a marshal.Marshallable implementation for int32.
+//
+// +marshal slice:Int32Slice:inner
+type Int32 int32
+
+// Uint32 is a marshal.Marshallable implementation for uint32.
+//
+// +marshal slice:Uint32Slice:inner
+type Uint32 uint32
+
+// Int64 is a marshal.Marshallable implementation for int64.
+//
+// +marshal slice:Int64Slice:inner
+type Int64 int64
+
+// Uint64 is a marshal.Marshallable implementation for uint64.
+//
+// +marshal slice:Uint64Slice:inner
+type Uint64 uint64
+
+// Below, we define some convenience functions for marshalling primitive types
+// using the newtypes above, without requiring superfluous casts.
+
+// 16-bit integers
+
+// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
+// memory.
+func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) {
+ var buf Int16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int16(buf)
+ return n, nil
+}
+
+// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
+// memory.
+func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) {
+ srcP := Int16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
+// memory.
+func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) {
+ var buf Uint16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint16(buf)
+ return n, nil
+}
+
+// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
+// memory.
+func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) {
+ srcP := Uint16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 32-bit integers
+
+// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
+// memory.
+func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) {
+ var buf Int32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int32(buf)
+ return n, nil
+}
+
+// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
+// memory.
+func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) {
+ srcP := Int32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
+// memory.
+func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) {
+ var buf Uint32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint32(buf)
+ return n, nil
+}
+
+// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
+// memory.
+func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) {
+ srcP := Uint32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 64-bit integers
+
+// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
+// memory.
+func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) {
+ var buf Int64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int64(buf)
+ return n, nil
+}
+
+// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
+// memory.
+func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) {
+ srcP := Int64(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
+// memory.
+func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) {
+ var buf Uint64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint64(buf)
+ return n, nil
+}
+
+// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
+// memory.
+func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) {
+ srcP := Uint64(src)
+ return srcP.CopyOut(task, addr)
+}
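+
+// Example (an illustrative sketch; "task" and "addr" are assumed to be a
+// marshal.Task and a usermem.Addr already in scope): reading a uint32 from
+// task memory without declaring a Uint32 value at the call site:
+//
+// var flags uint32
+// if _, err := CopyUint32In(task, addr, &flags); err != nil {
+//     return err
+// }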
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
new file mode 100644
index 000000000..2fbcc8a03
--- /dev/null
+++ b/tools/go_marshal/test/BUILD
@@ -0,0 +1,44 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+package_group(
+ name = "gomarshal_test",
+ packages = [
+ "//tools/go_marshal/test/...",
+ ],
+)
+
+go_test(
+ name = "benchmark_test",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ ":test",
+ "//pkg/binary",
+ "//pkg/usermem",
+ "//tools/go_marshal/analysis",
+ ],
+)
+
+go_library(
+ name = "test",
+ testonly = 1,
+ srcs = ["test.go"],
+ marshal = True,
+ visibility = ["//tools/go_marshal/test:__subpackages__"],
+ deps = ["//tools/go_marshal/test/external"],
+)
+
+go_test(
+ name = "marshal_test",
+ size = "small",
+ srcs = ["marshal_test.go"],
+ deps = [
+ ":test",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//tools/go_marshal/analysis",
+ "//tools/go_marshal/marshal",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
new file mode 100644
index 000000000..224d308c7
--- /dev/null
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -0,0 +1,220 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
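+// Package benchmark_test compares several strategies for marshalling the same
+// test.Stat struct: the standard encoding/binary package, the sentry's binary
+// package, hand-written field-by-field code, and the safe and unsafe APIs
+// generated by go_marshal. These are ordinary Go benchmarks; the exact
+// invocation depends on the build setup, but something along these lines
+// should work:
+//
+// bazel run //tools/go_marshal/test:benchmark_test -- --test.bench=.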
+package benchmark_test
+
+import (
+ "bytes"
+ encbin "encoding/binary"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// Marshalling using the standard encoding/binary package.
+func BenchmarkEncodingBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := encbin.Size(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := bytes.NewBuffer(make([]byte, size))
+ buf.Reset()
+ if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil {
+ b.Error("Write:", err)
+ }
+ if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil {
+ b.Error("Read:", err)
+ }
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling using the sentry's binary.Marshal.
+func BenchmarkBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling field-by-field with manually-written code.
+func BenchmarkMarshalManual(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, s1.SizeBytes())
+
+ // Marshal
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, 0)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec))
+
+ // Unmarshal
+ s2.Dev = usermem.ByteOrder.Uint64(buf[0:8])
+ s2.Ino = usermem.ByteOrder.Uint64(buf[8:16])
+ s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24])
+ s2.Mode = usermem.ByteOrder.Uint32(buf[24:28])
+ s2.UID = usermem.ByteOrder.Uint32(buf[28:32])
+ s2.GID = usermem.ByteOrder.Uint32(buf[32:36])
+ // Padding: buf[36:40]
+ s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48])
+ s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56]))
+ s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64]))
+ s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72]))
+ s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80]))
+ s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88]))
+ s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96]))
+ s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104]))
+ s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112]))
+ s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120]))
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal safe API.
+func BenchmarkGoMarshalSafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalBytes(buf)
+ s2.UnmarshalBytes(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal unsafe API.
+func BenchmarkGoMarshalUnsafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalUnsafe(buf)
+ s2.UnmarshalUnsafe(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+func BenchmarkBinarySlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+func BenchmarkGoMarshalUnsafeSlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1))
+ test.MarshalUnsafeStatSlice(s1[:], buf)
+ test.UnmarshalUnsafeStatSlice(s2[:], buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD
new file mode 100644
index 000000000..f74e6ffae
--- /dev/null
+++ b/tools/go_marshal/test/escape/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "escape",
+ testonly = 1,
+ srcs = ["escape.go"],
+ deps = [
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/test",
+ ],
+)
diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go
new file mode 100644
index 000000000..6a46ddbf8
--- /dev/null
+++ b/tools/go_marshal/test/escape/escape.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package escape
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// dummyTask implements marshal.Task.
+type dummyTask struct {
+}
+
+func (*dummyTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+ return len(b), nil
+}
+
+func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+ return len(b), nil
+}
+
+func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) {
+ buf := t.CopyScratchBuffer(marshallable.SizeBytes())
+ marshallable.MarshalBytes(buf)
+ t.CopyOutBytes(addr, buf)
+}
+
+func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) {
+ buf := t.CopyScratchBuffer(marshallable.SizeBytes())
+ marshallable.MarshalUnsafe(buf)
+ t.CopyOutBytes(addr, buf)
+}
+
+// +checkescape:all
+//go:nosplit
+func doCopyIn(t *dummyTask) {
+ var stat test.Stat
+ stat.CopyIn(t, usermem.Addr(0xf000ba12))
+}
+
+// +checkescape:all
+//go:nosplit
+func doCopyOut(t *dummyTask) {
+ var stat test.Stat
+ stat.CopyOut(t, usermem.Addr(0xf000ba12))
+}
+
+// +mustescape:builtin
+// +mustescape:stack
+func doMarshalBytesDirect(t *dummyTask) {
+ var stat test.Stat
+ buf := t.CopyScratchBuffer(stat.SizeBytes())
+ stat.MarshalBytes(buf)
+ t.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+}
+
+// +mustescape:builtin
+// +mustescape:stack
+func doMarshalUnsafeDirect(t *dummyTask) {
+ var stat test.Stat
+ buf := t.CopyScratchBuffer(stat.SizeBytes())
+ stat.MarshalUnsafe(buf)
+ t.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+}
+
+// +mustescape:local,heap
+// +mustescape:stack
+func doMarshalBytesViaMarshallable(t *dummyTask) {
+ var stat test.Stat
+ t.MarshalBytes(usermem.Addr(0xf000ba12), &stat)
+}
+
+// +mustescape:local,heap
+// +mustescape:stack
+func doMarshalUnsafeViaMarshallable(t *dummyTask) {
+ var stat test.Stat
+ t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat)
+}
diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD
new file mode 100644
index 000000000..0cf6da603
--- /dev/null
+++ b/tools/go_marshal/test/external/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "external",
+ testonly = 1,
+ srcs = ["external.go"],
+ marshal = True,
+ visibility = ["//tools/go_marshal/test:gomarshal_test"],
+)
diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go
new file mode 100644
index 000000000..26fe8e0c8
--- /dev/null
+++ b/tools/go_marshal/test/external/external.go
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package external defines types we can import for testing.
+package external
+
+// External is a public Marshallable type for use in testing.
+//
+// +marshal
+type External struct {
+ j int64
+}
+
+// NotPacked is an unaligned Marshallable type for use in testing.
+//
+// +marshal
+type NotPacked struct {
+ a int32
+ b byte `marshal:"unaligned"`
+}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
new file mode 100644
index 000000000..16829ee45
--- /dev/null
+++ b/tools/go_marshal/test/marshal_test.go
@@ -0,0 +1,515 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal_test contains manual tests for the marshal interface. These
+// are intended to test behaviour not covered by the automatically generated
+// tests.
+package marshal_test
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+var simulatedErr error = syserror.EFAULT
+
+// mockTask implements marshal.Task.
+type mockTask struct {
+ taskMem usermem.BytesIO
+}
+
+// populate fills the task memory with the contents of val.
+func (t *mockTask) populate(val interface{}) {
+ var buf bytes.Buffer
+ // Use binary.Write so we aren't testing go-marshal against its own
+ // potentially buggy implementation.
+ if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil {
+ panic(err)
+ }
+ t.taskMem.Bytes = buf.Bytes()
+}
+
+func (t *mockTask) setLimit(n int) {
+ if len(t.taskMem.Bytes) < n {
+ grown := make([]byte, n)
+ copy(grown, t.taskMem.Bytes)
+ t.taskMem.Bytes = grown
+ return
+ }
+ t.taskMem.Bytes = t.taskMem.Bytes[:n]
+}
+
+// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer.
+func (t *mockTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation
+// completely ignores the target address and stores a copy of b in its
+// internal buffer, overwriting any previous contents.
+func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{})
+}
+
+// CopyInBytes implements marshal.Task.CopyInBytes. The implementation
+// completely ignores the source address and always fills b from the beginning
+// of its internal buffer.
+func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{})
+}
+
+// unsafeMemory returns the underlying memory for m. The returned slice is only
+// valid for the lifetime of m. The garbage collector isn't aware that the
+// returned slice is related to m, so the caller must ensure m lives long enough.
+func unsafeMemory(m marshal.Marshallable) []byte {
+ if !m.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ // reflect.ValueOf(m)
+ // .Elem() // Unwrap interface to inner concrete object
+ // .Addr() // Pointer value to object
+ // .Pointer() // Actual address from the pointer value
+ ptr := reflect.ValueOf(m).Elem().Addr().Pointer()
+
+ size := m.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = ptr
+ hdr.Len = size
+ hdr.Cap = size
+
+ return mem
+}
+
+// unsafeMemorySlice returns the underlying memory for m. The returned slice is
+// only valid for the lifetime of m. The garbage collector isn't aware that the
+// returned slice is related to m, so the caller must ensure m lives long enough.
+//
+// Precondition: m must be a slice.
+func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte {
+ kind := reflect.TypeOf(m).Kind()
+ if kind != reflect.Slice {
+ panic("unsafeMemorySlice called on non-slice")
+ }
+
+ if !elt.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ v := reflect.ValueOf(m)
+ length := v.Len() * elt.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = v.Pointer() // This is a pointer to the first elem for slices.
+ hdr.Len = length
+ hdr.Cap = length
+
+ return mem
+}
+
+func isZeroes(buf []byte) bool {
+ for _, b := range buf {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// compareMemory compares the first n bytes of two chunks of memory represented
+// by expected and actual.
+func compareMemory(t *testing.T, expected, actual []byte, n int) {
+ t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:])
+ t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:])
+
+ if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" {
+ t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff)
+ }
+}
+
+// limitedCopyIn populates task memory with src, then unmarshals task memory to
+// dst. The task signals an error at limit bytes during copy-in, which should
+// result in a truncated unmarshalling.
+func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
+ var task mockTask
+ task.populate(src)
+ task.setLimit(limit)
+
+ n, err := dst.CopyIn(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := unsafeMemory(dst)
+ defer runtime.KeepAlive(dst)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The bytes after the first n should be zero in actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However,
+ // we can only guarantee nothing beyond the first n bytes was touched if the
+ // layout is packed.
+ if dst.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem)
+ }
+}
+
+// limitedCopyOut marshals src to task memory. The task signals an error at
+// limit bytes during copy-out, which should result in a truncated marshalling.
+func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOut(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// copyOutN marshals src to task memory, requesting the marshalling to be
+// limited to limit bytes.
+func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOutN(&task, usermem.Addr(0), limit)
+ if err != nil {
+ t.Errorf("CopyOut returned unexpected error: %v", err)
+ }
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:])
+ t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:])
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// TestLimitedMarshalling verifies that marshalling/unmarshalling succeeds when
+// the underlying copy in/out operations partially succeed.
+func TestLimitedMarshalling(t *testing.T) {
+ types := []reflect.Type{
+ // Packed types.
+ reflect.TypeOf((*test.Type2)(nil)),
+ reflect.TypeOf((*test.Type3)(nil)),
+ reflect.TypeOf((*test.Timespec)(nil)),
+ reflect.TypeOf((*test.Stat)(nil)),
+ reflect.TypeOf((*test.InetAddr)(nil)),
+ reflect.TypeOf((*test.SignalSet)(nil)),
+ reflect.TypeOf((*test.SignalSetAlias)(nil)),
+ // Non-packed types.
+ reflect.TypeOf((*test.Type1)(nil)),
+ reflect.TypeOf((*test.Type4)(nil)),
+ reflect.TypeOf((*test.Type5)(nil)),
+ reflect.TypeOf((*test.Type6)(nil)),
+ reflect.TypeOf((*test.Type7)(nil)),
+ reflect.TypeOf((*test.Type8)(nil)),
+ }
+
+ for _, tyPtr := range types {
+ // Remove one level of pointer-indirection from the type. We get this
+ // back when we pass the type to reflect.New.
+ ty := tyPtr.Elem()
+
+ // Partial copy-in.
+ t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ actual := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyIn(t, expected, actual, expected.SizeBytes()/2)
+ })
+
+ // Partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyOut(t, expected, expected.SizeBytes()/2)
+ })
+
+ // Explicitly request partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ copyOutN(t, expected, expected.SizeBytes()/2)
+ })
+ }
+}
+
+// TestLimitedSliceMarshalling verifies that marshalling/unmarshalling of
+// slices of marshallable types succeeds when the underlying copy in/out
+// operations partially succeed.
+func TestLimitedSliceMarshalling(t *testing.T) {
+ types := []struct {
+ arrayPtrType reflect.Type
+ copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error)
+ copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error)
+ unsafeMemory func(arrPtr interface{}) []byte
+ }{
+ // Packed types.
+ {
+ reflect.TypeOf((*[20]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[5]test.SignalSetAlias)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[5]test.SignalSetAlias)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ // Non-packed types.
+ {
+ reflect.TypeOf((*[20]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[7]test.Type8)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[7]test.Type8)[:]
+ return test.CopyType8SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[7]test.Type8)[:]
+ return test.CopyType8SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[7]test.Type8)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ }
+
+ for _, tt := range types {
+ // The body of this loop is generic over the type tt.arrayPtrType, with
+ // the help of reflection. To aid in readability, comments below show
+ // the equivalent Go code assuming
+ // tt.arrayPtrType = typeof(*[20]test.Stat).
+
+ // Equivalent:
+ // var x *[20]test.Stat
+ // arrayTy := reflect.TypeOf(*x)
+ arrayTy := tt.arrayPtrType.Elem()
+
+ // Partial copy-in of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+ actual := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceIn(&task, usermem.Addr(0), actual)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := tt.unsafeMemory(actual)
+ defer runtime.KeepAlive(actual)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The bytes after the first n should be zero in actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However,
+ // we can only guarantee nothing beyond the first n bytes was touched if the
+ // layout is packed.
+ if elem.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem)
+ }
+ })
+
+ // Partial copy-out of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceOut(&task, usermem.Addr(0), expected)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+ })
+ }
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
new file mode 100644
index 000000000..f75ca1b7f
--- /dev/null
+++ b/tools/go_marshal/test/test.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test contains data structures for testing the go_marshal tool.
+package test
+
+import (
+ // We're intentionally using a package name alias here, even though it's not
+ // necessary, to test the code generator's ability to handle package aliases.
+ ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
+)
+
+// Type1 is a test data type.
+//
+// +marshal slice:Type1Slice
+type Type1 struct {
+ a Type2
+ x, y int64 // Multiple field names.
+ b byte `marshal:"unaligned"` // Short field.
+ c uint64
+ _ uint32 // Unnamed scalar field.
+ _ [6]byte // Unnamed vector field, typical padding.
+ _ [2]byte
+ xs [8]int32
+ as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects.
+ ss Type3
+}
+
+// Type2 is a test data type.
+//
+// +marshal
+type Type2 struct {
+ n int64
+ c byte
+ _ [7]byte
+ m int64
+ a int64
+}
+
+// Type3 is a test data type.
+//
+// +marshal
+type Type3 struct {
+ s int64
+ x ex.External // Type defined in another package.
+}
+
+// Type4 is a test data type.
+//
+// +marshal
+type Type4 struct {
+ c byte
+ x int64 `marshal:"unaligned"`
+ d byte
+ _ [7]byte
+}
+
+// Type5 is a test data type.
+//
+// +marshal
+type Type5 struct {
+ n int64
+ t Type4
+ m int64
+}
+
+// Type6 is a test data type that ends mid-word.
+//
+// +marshal
+type Type6 struct {
+ a int64
+ b int64
+ // If c isn't marked unaligned, analysis fails (as it should, since
+ // the unsafe API corrupts Type7).
+ c byte `marshal:"unaligned"`
+}
+
+// Type7 is a test data type that contains a child struct that ends
+// mid-word.
+//
+// +marshal
+type Type7 struct {
+ x Type6
+ y int64
+}
+
+// Type8 is a test data type which contains an external non-packed field.
+//
+// +marshal slice:Type8Slice
+type Type8 struct {
+ a int64
+ np ex.NotPacked
+ b int64
+}
+
+// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Stat represents struct stat.
+//
+// +marshal slice:StatSlice
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}
+
+// InetAddr is an example marshallable newtype on an array.
+//
+// +marshal
+type InetAddr [4]byte
+
+// SignalSet is an example marshallable newtype on a primitive.
+//
+// +marshal slice:SignalSetSlice:inner
+type SignalSet uint64
+
+// SignalSetAlias is an example newtype on another marshallable type.
+//
+// +marshal slice:SignalSetAliasSlice
+type SignalSetAlias SignalSet
+
+const sizeA = 64
+const sizeB = 8
+
+// TestArray is a newtype on an array with a constant length.
+//
+// +marshal
+type TestArray [sizeA]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// constants for the array length.
+//
+// +marshal
+type TestArray2 [sizeA * sizeB]int32
+
+// TestArray3 is a newtype on an array with a simple arithmetic expression of
+// mixed constants and literals for the array length.
+//
+// +marshal
+type TestArray3 [sizeA*sizeB + 12]int32
+
+// Type9 is a test data type containing an array with a non-literal length.
+//
+// +marshal
+type Type9 struct {
+ x int64
+ y [sizeA]int32
+}
diff --git a/tools/go_mod.sh b/tools/go_mod.sh
new file mode 100755
index 000000000..84b779d6d
--- /dev/null
+++ b/tools/go_mod.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
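+# Usage: tools/go_mod.sh <go mod subcommand> [args...], run from the repository
+# root. For example:
+#
+#   tools/go_mod.sh tidy
+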
+set -eo pipefail
+
+# Build the :gopath target.
+bazel build //:gopath
+declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/"
+
+# Copy go.mod and execute the command.
+cp -a go.mod go.sum "${gopathdir}"
+(cd "${gopathdir}" && go mod "$@")
+cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" .
+
+# Update the WORKSPACE file from the new go.mod.
+bazel run //:gazelle -- update-repos -from_file=go.mod
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD
new file mode 100644
index 000000000..503cdf2e5
--- /dev/null
+++ b/tools/go_stateify/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "stateify",
+ srcs = ["main.go"],
+ visibility = ["//:sandbox"],
+ deps = ["//tools/tags"],
+)
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
new file mode 100644
index 000000000..6a5e666f0
--- /dev/null
+++ b/tools/go_stateify/defs.bzl
@@ -0,0 +1,60 @@
+"""Stateify is a tool for generating state wrappers for Go types."""
+
+def _go_stateify_impl(ctx):
+ """Implementation for the stateify tool."""
+ output = ctx.outputs.out
+
+ # Run the stateify command.
+ args = ["-output=%s" % output.path]
+ args.append("-fullpkg=%s" % ctx.attr.package)
+ if ctx.attr._statepkg:
+ args.append("-statepkg=%s" % ctx.attr._statepkg)
+ if ctx.attr.imports:
+ args.append("-imports=%s" % ",".join(ctx.attr.imports))
+ args.append("--")
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files.to_list()]
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output],
+ mnemonic = "GoStateify",
+ progress_message = "Generating state library %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+go_stateify = rule(
+ implementation = _go_stateify_impl,
+ doc = "Generates save and restore logic from a set of Go files.",
+ attrs = {
+ "srcs": attr.label_list(
+ doc = """
+The input source files. These files should include all structs in the package
+that need to be saved.
+""",
+ mandatory = True,
+ allow_files = True,
+ ),
+ "imports": attr.string_list(
+ doc = """
+An optional list of extra non-aliased, Go-style absolute import paths required
+for statified types.
+""",
+ mandatory = False,
+ ),
+ "package": attr.string(
+ doc = "The fully qualified package name for the input sources.",
+ mandatory = True,
+ ),
+ "out": attr.output(
+ doc = "Name of the generator output file.",
+ mandatory = True,
+ ),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_stateify:stateify"),
+ ),
+ "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"),
+ },
+)
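+
+# Example (hypothetical) usage from a BUILD file; the target name, sources and
+# package below are illustrative only:
+#
+#   load("//tools/go_stateify:defs.bzl", "go_stateify")
+#
+#   go_stateify(
+#       name = "foo_state",
+#       srcs = ["foo.go"],
+#       out = "foo_state.go",
+#       package = "gvisor.dev/gvisor/pkg/foo",
+#   )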
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
new file mode 100644
index 000000000..4f6ed208a
--- /dev/null
+++ b/tools/go_stateify/main.go
@@ -0,0 +1,476 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Stateify provides a simple way to generate Load/Save methods based on
+// existing types and struct tags.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "path/filepath"
+ "reflect"
+ "strings"
+ "sync"
+
+ "gvisor.dev/gvisor/tools/tags"
+)
+
+var (
+ fullPkg = flag.String("fullpkg", "", "fully qualified output package")
+ imports = flag.String("imports", "", "extra imports for the output file")
+ output = flag.String("output", "", "output file")
+ statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
+)
+
+// resolveTypeName returns a qualified type name.
+func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) {
+ for done := false; !done; {
+ // Resolve star expressions.
+ switch rs := typ.(type) {
+ case *ast.StarExpr:
+ qualified += "*"
+ typ = rs.X
+ case *ast.ArrayType:
+ if rs.Len == nil {
+ // Slice type declaration.
+ qualified += "[]"
+ } else {
+ // Array type declaration.
+ qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]"
+ }
+ typ = rs.Elt
+ default:
+ // No more descent.
+ done = true
+ }
+ }
+
+ // Resolve a package selector.
+ sel, ok := typ.(*ast.SelectorExpr)
+ if ok {
+ qualified = qualified + sel.X.(*ast.Ident).Name + "."
+ typ = sel.Sel
+ }
+
+ // Figure out actual type name.
+ ident, ok := typ.(*ast.Ident)
+ if !ok {
+ panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name))
+ }
+ field = ident.Name
+ qualified = qualified + field
+ return
+}
+
+// extractStateTag pulls the relevant state tag.
+func extractStateTag(tag *ast.BasicLit) string {
+ if tag == nil {
+ return ""
+ }
+ if len(tag.Value) < 2 {
+ return ""
+ }
+ return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state")
+}
+
+// scanFunctions is a set of functions passed to scanFields.
+type scanFunctions struct {
+ zerovalue func(name string)
+ normal func(name string)
+ wait func(name string)
+ value func(name, typName string)
+}
+
+// scanFields scans the fields of a struct.
+//
+// Each provided function will be applied to appropriately tagged fields, or
+// skipped if nil.
+//
+// Fields tagged nosave are skipped.
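+//
+// For example (an illustrative sketch; "custom" is a hypothetical type saved
+// via its own save/load methods):
+//
+// type example struct {
+//     a int64                      // passed to fn.normal.
+//     w int64  `state:"wait"`      // passed to fn.wait.
+//     z int64  `state:"zerovalue"` // passed to fn.zerovalue.
+//     c custom `state:".(int64)"`  // passed to fn.value with type "int64".
+//     n int64  `state:"nosave"`    // skipped.
+// }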
+func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
+ if ss.Fields.List == nil {
+ // No fields.
+ return
+ }
+
+ // Scan all fields.
+ for _, field := range ss.Fields.List {
+ // Calculate the name.
+ name := ""
+ if field.Names != nil {
+ // It's a named field; override.
+ name = field.Names[0].Name
+ } else {
+ // Anonymous types can't be embedded, so we don't need
+ // to worry about providing a useful name here.
+ name, _ = resolveTypeName("", field.Type)
+ }
+
+ // Skip _ fields.
+ if name == "_" {
+ continue
+ }
+
+ // Is this an anonymous struct? If so, continue the
+ // recursion with the given prefix. We don't pay attention to
+ // any tags on the top-level struct field.
+ tag := extractStateTag(field.Tag)
+ if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
+ scanFields(anon, name+".", fn)
+ continue
+ }
+
+ switch tag {
+ case "zerovalue":
+ if fn.zerovalue != nil {
+ fn.zerovalue(name)
+ }
+
+ case "":
+ if fn.normal != nil {
+ fn.normal(name)
+ }
+
+ case "wait":
+ if fn.wait != nil {
+ fn.wait(name)
+ }
+
+ case "manual", "nosave", "ignore":
+ // Do nothing.
+
+ default:
+ if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") {
+ if fn.value != nil {
+ fn.value(name, tag[2:len(tag)-1])
+ }
+ }
+ }
+ }
+}
+
+func camelCased(name string) string {
+ return strings.ToUpper(name[:1]) + name[1:]
+}
+
+func main() {
+ // Parse flags.
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ if len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+ if *fullPkg == "" {
+ fmt.Fprintf(os.Stderr, "Error: package required.")
+ os.Exit(1)
+ }
+
+ // Open the output file.
+ var (
+ outputFile *os.File
+ err error
+ )
+ if *output == "" || *output == "-" {
+ outputFile = os.Stdout
+ } else {
+ outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error opening output %q: %v\n", *output, err)
+ os.Exit(1)
+ }
+ defer outputFile.Close()
+ }
+
+ // Set the statePrefix for below, depending on the import.
+ statePrefix := ""
+ if *statePkg != "" {
+ parts := strings.Split(*statePkg, "/")
+ statePrefix = parts[len(parts)-1] + "."
+ }
+
+ // initCalls is dumped at the end.
+ var initCalls []string
+
+ // Common closures.
+ emitRegister := func(name string) {
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
+ }
+ emitZeroCheck := func(name string) {
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name)
+ }
+
+ // Automated warning.
+ fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(flag.Args()); len(t) > 0 {
+ fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
+ }
+
+ // Emit the package name.
+ _, pkg := filepath.Split(*fullPkg)
+ fmt.Fprintf(outputFile, "package %s\n\n", pkg)
+
+ // Emit the imports lazily.
+ var once sync.Once
+ maybeEmitImports := func() {
+ once.Do(func() {
+ // Emit the imports.
+ fmt.Fprint(outputFile, "import (\n")
+ if *statePkg != "" {
+ fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg)
+ }
+ if *imports != "" {
+ for _, i := range strings.Split(*imports, ",") {
+ fmt.Fprintf(outputFile, " \"%s\"\n", i)
+ }
+ }
+ fmt.Fprint(outputFile, ")\n\n")
+ })
+ }
+
+ files := make([]*ast.File, 0, len(flag.Args()))
+
+ // Parse the input files.
+ for _, filename := range flag.Args() {
+ // Parse the file.
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
+ if err != nil {
+ // Not a valid input file?
+ fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
+ os.Exit(1)
+ }
+
+ files = append(files, f)
+ }
+
+ type method struct {
+ receiver string
+ name string
+ }
+
+ // Search for all methods with a pointer receiver, no arguments, and no
+ // return values, and add them to a set. We support auto-detecting the
+ // existence of several different methods with this signature.
+ simpleMethods := map[method]struct{}{}
+ for _, f := range files {
+
+ // Go over all functions.
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ if d.Name == nil || d.Recv == nil || d.Type == nil {
+ // Not a named method.
+ continue
+ }
+ if len(d.Recv.List) != 1 {
+ // Wrong number of receivers?
+ continue
+ }
+ if d.Type.Params != nil && len(d.Type.Params.List) != 0 {
+ // Has argument(s).
+ continue
+ }
+ if d.Type.Results != nil && len(d.Type.Results.List) != 0 {
+ // Has return(s).
+ continue
+ }
+
+ pt, ok := d.Recv.List[0].Type.(*ast.StarExpr)
+ if !ok {
+ // Not a pointer receiver.
+ continue
+ }
+
+ t, ok := pt.X.(*ast.Ident)
+ if !ok {
+ // This shouldn't happen with valid Go.
+ continue
+ }
+
+ simpleMethods[method{t.Name, d.Name.Name}] = struct{}{}
+ }
+ }
+
+ for _, f := range files {
+ // Go over all named types.
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != token.TYPE {
+ continue
+ }
+
+ // Only generate code for types marked "// +stateify
+ // savable" in one of the proceeding comment lines. If
+ // the line is marked "// +stateify type" then only
+ // generate type information and register the type.
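+ // For example, a savable type is annotated as (Foo is illustrative):
+ //
+ //   // +stateify savable
+ //   type Foo struct{ ... }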
+ if d.Doc == nil {
+ continue
+ }
+ var (
+ generateTypeInfo = false
+ generateSaverLoader = false
+ )
+ for _, l := range d.Doc.List {
+ if l.Text == "// +stateify savable" {
+ generateTypeInfo = true
+ generateSaverLoader = true
+ break
+ }
+ if l.Text == "// +stateify type" {
+ generateTypeInfo = true
+ }
+ }
+ if !generateTypeInfo && !generateSaverLoader {
+ continue
+ }
+
+ for _, gs := range d.Specs {
+ ts := gs.(*ast.TypeSpec)
+ switch x := ts.Type.(type) {
+ case *ast.StructType:
+ maybeEmitImports()
+
+ // Record the slot for each field.
+ fieldCount := 0
+ fields := make(map[string]int)
+ emitField := func(name string) {
+ fmt.Fprintf(outputFile, " \"%s\",\n", name)
+ fields[name] = fieldCount
+ fieldCount++
+ }
+ emitFieldValue := func(name string, _ string) {
+ emitField(name)
+ }
+ emitLoadValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName)
+ }
+ emitLoad := func(name string) {
+ fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name)
+ }
+ emitLoadWait := func(name string) {
+ fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name)
+ }
+ emitSaveValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
+ fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name)
+ }
+ emitSave := func(name string) {
+ fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name)
+ }
+
+ // Generate the type name method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Generate the fields method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return []string{\n")
+ scanFields(x, "", scanFunctions{
+ normal: emitField,
+ wait: emitField,
+ value: emitFieldValue,
+ })
+ fmt.Fprintf(outputFile, " }\n")
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Define beforeSave if a definition was not found. This prevents
+ // the code from compiling if a custom beforeSave was defined in a
+ // file not provided to this binary and prevents inherited methods
+ // from being called multiple times by overriding them.
+ if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name)
+ }
+
+ // Generate the save method.
+ //
+ // N.B. For historical reasons, we perform the value saves first,
+ // and perform the value loads last. There should be no dependency
+ // on this specific behavior, but the ability to specify slots
+ // allows a manual implementation to be order-dependent.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
+ scanFields(x, "", scanFunctions{value: emitSaveValue})
+ scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
+ fmt.Fprintf(outputFile, "}\n\n")
+ }
+
+ // Define afterLoad if a definition was not found. We do this for
+ // the same reason that we do it for beforeSave.
+ _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
+ if !hasAfterLoad && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name)
+ }
+
+ // Generate the load method.
+ //
+ // N.B. See the comment above for the save method.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix)
+ scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
+ scanFields(x, "", scanFunctions{value: emitLoadValue})
+ if hasAfterLoad {
+ // The call to afterLoad is made conditionally, because
+ // registering it encodes a dependency on the referred
+ // objects (i.e. fields). This means that this afterLoad
+ // will not be called until the afterLoads of those
+ // objects have been called.
+ fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ }
+ fmt.Fprintf(outputFile, "}\n\n")
+ }
+
+ // Add to our registration.
+ emitRegister(ts.Name.Name)
+
+ case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
+ maybeEmitImports()
+
+ // Generate the info methods.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
+ fmt.Fprintf(outputFile, "}\n\n")
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return nil\n")
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // See above.
+ emitRegister(ts.Name.Name)
+ }
+ }
+ }
+ }
+
+ if len(initCalls) > 0 {
+ // Emit the init() function.
+ fmt.Fprintf(outputFile, "func init() {\n")
+ for _, ic := range initCalls {
+ fmt.Fprintf(outputFile, " %s\n", ic)
+ }
+ fmt.Fprintf(outputFile, "}\n")
+ }
+}
diff --git a/tools/installers/BUILD b/tools/installers/BUILD
new file mode 100644
index 000000000..caa7b1983
--- /dev/null
+++ b/tools/installers/BUILD
@@ -0,0 +1,35 @@
+# Installers for use by the tools/vm_test rules.
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+filegroup(
+ name = "runsc",
+ srcs = ["//runsc"],
+)
+
+sh_binary(
+ name = "head",
+ srcs = ["head.sh"],
+ data = [":runsc"],
+)
+
+sh_binary(
+ name = "images",
+ srcs = ["images.sh"],
+ data = [
+ "//images",
+ ],
+)
+
+sh_binary(
+ name = "master",
+ srcs = ["master.sh"],
+)
+
+sh_binary(
+ name = "shim",
+ srcs = ["shim.sh"],
+)
diff --git a/tools/installers/head.sh b/tools/installers/head.sh
new file mode 100755
index 000000000..7fc566ebd
--- /dev/null
+++ b/tools/installers/head.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install our runtime.
+$(find . -executable -type f -name runsc) install
+
+# Restart docker.
+service docker restart || true
diff --git a/tools/installers/images.sh b/tools/installers/images.sh
new file mode 100755
index 000000000..52e750f57
--- /dev/null
+++ b/tools/installers/images.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeuo pipefail
+
+# Find the images directory.
+for images in $(find . -type d -name images); do
+ if [[ -f "${images}"/Makefile ]]; then
+ make -C "${images}" load-all-images
+ fi
+done
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
new file mode 100755
index 000000000..2c6001c6c
--- /dev/null
+++ b/tools/installers/master.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install runsc from the master branch.
+set -e
+
+curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
+add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+
+while true; do
+ result=0
+ (apt-get update && apt-get install -y runsc) || result=$?
+ if [[ $result -eq 0 ]]; then
+ break
+ fi
+ # apt-get returns 100 on transient failures; retry those, fail otherwise.
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+runsc install
+service docker restart
diff --git a/tools/installers/shim.sh b/tools/installers/shim.sh
new file mode 100755
index 000000000..f7dd790a1
--- /dev/null
+++ b/tools/installers/shim.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Reinstall the latest containerd shim.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+wget --no-verbose "${base}"/latest -O ${latest}
+wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
+chmod +x ${shim_path}
+mv ${shim_path} /usr/local/bin/gvisor-containerd-shim
diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD
new file mode 100644
index 000000000..4ef1a3124
--- /dev/null
+++ b/tools/issue_reviver/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "issue_reviver",
+ srcs = ["main.go"],
+ deps = [
+ "//tools/issue_reviver/github",
+ "//tools/issue_reviver/reviver",
+ ],
+)
diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD
new file mode 100644
index 000000000..da4133472
--- /dev/null
+++ b/tools/issue_reviver/github/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "github",
+ srcs = ["github.go"],
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+ deps = [
+ "//tools/issue_reviver/reviver",
+ "@com_github_google_go-github//github:go_default_library",
+ "@org_golang_x_oauth2//:go_default_library",
+ ],
+)
diff --git a/tools/issue_reviver/github/github.go b/tools/issue_reviver/github/github.go
new file mode 100644
index 000000000..e07949c8f
--- /dev/null
+++ b/tools/issue_reviver/github/github.go
@@ -0,0 +1,164 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package github implements the reviver.Bugger interface on top of GitHub issues.
+package github
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/go-github/github"
+ "golang.org/x/oauth2"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+// Bugger implements the reviver.Bugger interface for GitHub issues.
+type Bugger struct {
+ owner string
+ repo string
+ dryRun bool
+
+ client *github.Client
+ issues map[int]*github.Issue
+}
+
+// NewBugger creates a new Bugger.
+func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) {
+ b := &Bugger{
+ owner: owner,
+ repo: repo,
+ dryRun: dryRun,
+ issues: map[int]*github.Issue{},
+ }
+ if err := b.load(token); err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+func (b *Bugger) load(token string) error {
+ ctx := context.Background()
+ if len(token) == 0 {
+ fmt.Print("No OAUTH token provided, using unauthenticated account.\n")
+ b.client = github.NewClient(nil)
+ } else {
+ ts := oauth2.StaticTokenSource(
+ &oauth2.Token{AccessToken: token},
+ )
+ tc := oauth2.NewClient(ctx, ts)
+ b.client = github.NewClient(tc)
+ }
+
+ err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) {
+ opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts}
+ tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts)
+ if err != nil {
+ return resp, err
+ }
+ for _, issue := range tmps {
+ b.issues[issue.GetNumber()] = issue
+ }
+ return resp, nil
+ })
+ if err != nil {
+ return err
+ }
+
+ fmt.Printf("Loaded %d issues from github.com/%s/%s\n", len(b.issues), b.owner, b.repo)
+ return nil
+}
+
+// Activate implements reviver.Bugger.
+func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) {
+ const prefix = "gvisor.dev/issue/"
+
+ // First check if I can handle the TODO.
+ idStr := strings.TrimPrefix(todo.Issue, prefix)
+ if len(todo.Issue) == len(idStr) {
+ return false, nil
+ }
+
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ return true, err
+ }
+
+ // Check against active issues cache.
+ if _, ok := b.issues[id]; ok {
+ fmt.Printf("%q is active: OK\n", todo.Issue)
+ return true, nil
+ }
+
+ fmt.Printf("%q is not active: reopening issue %d\n", todo.Issue, id)
+
+ // Format comment with TODO locations and search link.
+ comment := strings.Builder{}
+ fmt.Fprintln(&comment, "There are TODOs still referencing this issue:")
+ for _, l := range todo.Locations {
+ fmt.Fprintf(&comment,
+ "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n",
+ l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment)
+ }
+ fmt.Fprintf(&comment,
+ "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%d%%22)", b.owner, b.repo, prefix, id)
+
+ if b.dryRun {
+ fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String())
+ return true, nil
+ }
+
+ ctx := context.Background()
+ req := &github.IssueRequest{State: github.String("open")}
+ _, _, err = b.client.Issues.Edit(ctx, b.owner, b.repo, id, req)
+ if err != nil {
+ return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err)
+ }
+
+ cmt := &github.IssueComment{
+ Body: github.String(comment.String()),
+ Reactions: &github.Reactions{Confused: github.Int(1)},
+ }
+ if _, _, err := b.client.Issues.CreateComment(ctx, b.owner, b.repo, id, cmt); err != nil {
+ return true, fmt.Errorf("failed to add comment to issue %d: %v", id, err)
+ }
+
+ return true, nil
+}
+
+func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error {
+ opts := github.ListOptions{PerPage: 1000}
+ for {
+ resp, err := fn(opts)
+ if err != nil {
+ if rateErr, ok := err.(*github.RateLimitError); ok {
+ duration := time.Until(rateErr.Rate.Reset.Time)
+ if duration > 5*time.Minute {
+ return fmt.Errorf("rate limited for too long: %v", duration)
+ }
+ fmt.Printf("Rate limited, sleeping for: %v\n", duration)
+ time.Sleep(duration)
+ continue
+ }
+ return err
+ }
+ if resp.NextPage == 0 {
+ return nil
+ }
+ opts.Page = resp.NextPage
+ }
+}
diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go
new file mode 100644
index 000000000..47c796b8a
--- /dev/null
+++ b/tools/issue_reviver/main.go
@@ -0,0 +1,100 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package main is the entry point for issue_reviver.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/issue_reviver/github"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+var (
+ owner string
+ repo string
+ tokenFile string
+ path string
+ dryRun bool
+)
+
+// Keep the options simple for now. Supports only a single path and repo.
+func init() {
+ flag.StringVar(&owner, "owner", "", "Github project org/owner to look for issues")
+ flag.StringVar(&repo, "repo", "", "Github repo to look for issues")
+ flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github")
+ flag.StringVar(&path, "path", ".", "Path to scan for TODOs")
+ flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues")
+}
+
+func main() {
+ // Set defaults from the environment.
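+ // For example, GITHUB_REPOSITORY="google/gvisor" yields owner "google"
+ // and repo "gvisor".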
+ repository := os.Getenv("GITHUB_REPOSITORY")
+ if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 {
+ owner = parts[0]
+ repo = parts[1]
+ }
+
+ // Parse flags.
+ flag.Parse()
+
+ // Check for mandatory parameters.
+ if len(owner) == 0 {
+ fmt.Println("missing --owner option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(repo) == 0 {
+ fmt.Println("missing --repo option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(path) == 0 {
+ fmt.Println("missing --path option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // The access token may be passed as a file so it doesn't show up in
+ // command line arguments. It also may be provided through the
+ // environment to facilitate use through GitHub's CI system.
+ token := os.Getenv("GITHUB_TOKEN")
+ if len(tokenFile) != 0 {
+ bytes, err := ioutil.ReadFile(tokenFile)
+ if err != nil {
+ fmt.Println(err.Error())
+ os.Exit(1)
+ }
+ // Trim whitespace, since token files commonly end in a newline.
+ token = strings.TrimSpace(string(bytes))
+ }
+
+ bugger, err := github.NewBugger(token, owner, repo, dryRun)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error getting github issues:", err)
+ os.Exit(1)
+ }
+ rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
+ if errs := rev.Run(); len(errs) > 0 {
+ fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
+ for _, err := range errs {
+ fmt.Fprintf(os.Stderr, "\t%v\n", err)
+ }
+ os.Exit(1)
+ }
+}
diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD
new file mode 100644
index 000000000..d262932bd
--- /dev/null
+++ b/tools/issue_reviver/reviver/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "reviver",
+ srcs = ["reviver.go"],
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+)
+
+go_test(
+ name = "reviver_test",
+ size = "small",
+ srcs = ["reviver_test.go"],
+ library = ":reviver",
+)
diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/issue_reviver/reviver/reviver.go
new file mode 100644
index 000000000..682db0c01
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver.go
@@ -0,0 +1,192 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package reviver scans the code looking for TODOs and passes them to
+// registered Buggers to ensure that TODOs point to active issues.
+package reviver
+
+import (
+ "bufio"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "regexp"
+ "sync"
+)
+
+// This is what a TODO looks like.
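+// For example, it matches lines such as:
+//
+//   // TODO(gvisor.dev/issue/123): Needs fixing.
+//   # FIXME(b/123): Internal bug.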
+var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`)
+
+// A Bugger is called for every TODO found in the code. If it can handle the
+// TODO, it must return true. If it returns false, the next Bugger is called.
+// If no Bugger handles the TODO, it's dropped on the floor.
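+//
+// The github.Bugger in tools/issue_reviver/github, for example, handles TODOs
+// that reference gvisor.dev issues and reopens inactive ones.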
+type Bugger interface {
+ Activate(todo *Todo) (bool, error)
+}
+
+// Location saves the location where the TODO was found.
+type Location struct {
+ Comment string
+ File string
+ Line uint
+}
+
+// Todo represents a unique TODO. There can be several TODOs pointing to the
+// same issue in the code. They are all grouped together.
+type Todo struct {
+ Issue string
+ Locations []Location
+}
+
+// Reviver scans the given paths for TODOs and calls Buggers to handle them.
+type Reviver struct {
+ paths []string
+ buggers []Bugger
+
+ mu sync.Mutex
+ todos map[string]*Todo
+ errs []error
+}
+
+// New creates a new Reviver.
+func New(paths []string, buggers []Bugger) *Reviver {
+ return &Reviver{
+ paths: paths,
+ buggers: buggers,
+ todos: map[string]*Todo{},
+ }
+}
+
+// Run runs the reviver. It returns all errors found during processing; it
+// does not stop on errors.
+func (r *Reviver) Run() []error {
+ // Process each directory in parallel.
+ wg := sync.WaitGroup{}
+ for _, path := range r.paths {
+ wg.Add(1)
+ go func(path string) {
+ defer wg.Done()
+ r.processPath(path, &wg)
+ }(path)
+ }
+
+ wg.Wait()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ fmt.Printf("Processing %d TODOs (%d errors)...\n", len(r.todos), len(r.errs))
+ dropped := 0
+ for _, todo := range r.todos {
+ ok, err := r.processTodo(todo)
+ if err != nil {
+ r.errs = append(r.errs, err)
+ }
+ if !ok {
+ dropped++
+ }
+ }
+ fmt.Printf("Processed %d TODOs, %d were skipped (%d errors)\n", len(r.todos)-dropped, dropped, len(r.errs))
+
+ return r.errs
+}
+
+func (r *Reviver) processPath(path string, wg *sync.WaitGroup) {
+ fmt.Printf("Processing dir %q\n", path)
+ fis, err := ioutil.ReadDir(path)
+ if err != nil {
+ r.addErr(fmt.Errorf("error processing dir %q: %v", path, err))
+ return
+ }
+
+ for _, fi := range fis {
+ childPath := filepath.Join(path, fi.Name())
+ switch {
+ case fi.Mode().IsDir():
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r.processPath(childPath, wg)
+ }()
+
+ case fi.Mode().IsRegular():
+ file, err := os.Open(childPath)
+ if err != nil {
+ r.addErr(err)
+ continue
+ }
+
+ scanner := bufio.NewScanner(file)
+ lineno := uint(0)
+ for scanner.Scan() {
+ lineno++
+ line := scanner.Text()
+ if todo := r.processLine(line, childPath, lineno); todo != nil {
+ r.addTodo(todo)
+ }
+ }
+ }
+ }
+}
+
+func (r *Reviver) processLine(line, path string, lineno uint) *Todo {
+ matches := regexTodo.FindStringSubmatch(line)
+ if matches == nil {
+ return nil
+ }
+ if len(matches) != 5 {
+ panic(fmt.Sprintf("regex returned wrong matches for %q: %v", line, matches))
+ }
+ return &Todo{
+ Issue: matches[3],
+ Locations: []Location{
+ {
+ File: path,
+ Line: lineno,
+ Comment: matches[4],
+ },
+ },
+ }
+}
+
+func (r *Reviver) addTodo(newTodo *Todo) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if todo := r.todos[newTodo.Issue]; todo == nil {
+ r.todos[newTodo.Issue] = newTodo
+ } else {
+ todo.Locations = append(todo.Locations, newTodo.Locations...)
+ }
+}
+
+func (r *Reviver) addErr(err error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.errs = append(r.errs, err)
+}
+
+func (r *Reviver) processTodo(todo *Todo) (bool, error) {
+ for _, bugger := range r.buggers {
+ ok, err := bugger.Activate(todo)
+ if err != nil {
+ return false, err
+ }
+ if ok {
+ return true, nil
+ }
+ }
+ return false, nil
+}
diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/issue_reviver/reviver/reviver_test.go
new file mode 100644
index 000000000..a9fb1f9f1
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver_test.go
@@ -0,0 +1,88 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reviver
+
+import (
+ "testing"
+)
+
+func TestProcessLine(t *testing.T) {
+ for _, tc := range []struct {
+ line string
+ want *Todo
+ }{
+ {
+ line: "// TODO(foobar.com/issue/123): comment, bla. blabla.",
+ want: &Todo{
+ Issue: "foobar.com/issue/123",
+ Locations: []Location{
+ {Comment: "comment, bla. blabla."},
+ },
+ },
+ },
+ {
+ line: "// FIXME(b/123): internal bug",
+ want: &Todo{
+ Issue: "b/123",
+ Locations: []Location{
+ {Comment: "internal bug"},
+ },
+ },
+ },
+ {
+ line: "TODO(issue): not todo",
+ },
+ {
+ line: "FIXME(issue): not todo",
+ },
+ {
+ line: "// TODO (issue): not todo",
+ },
+ {
+ line: "// TODO(issue) not todo",
+ },
+ {
+ line: "// todo(issue): not todo",
+ },
+ {
+ line: "// TODO(issue):",
+ },
+ } {
+ t.Logf("Testing: %s", tc.line)
+ r := Reviver{}
+ got := r.processLine(tc.line, "test", 0)
+ if got == nil {
+ if tc.want != nil {
+ t.Errorf("failed to process line, want: %+v", tc.want)
+ }
+ } else {
+ if tc.want == nil {
+ t.Errorf("expected error, got: %+v", got)
+ continue
+ }
+ if got.Issue != tc.want.Issue {
+ t.Errorf("wrong issue, got: %v, want: %v", got.Issue, tc.want.Issue)
+ }
+ if len(got.Locations) != len(tc.want.Locations) {
+ t.Errorf("wrong number of locations, got: %v, want: %v, locations: %+v", len(got.Locations), len(tc.want.Locations), got.Locations)
+ }
+ for i, wantLoc := range tc.want.Locations {
+ if got.Locations[i].Comment != wantLoc.Comment {
+ t.Errorf("wrong comment, got: %v, want: %v", got.Locations[i].Comment, wantLoc.Comment)
+ }
+ }
+ }
+ }
+}
diff --git a/tools/make_apt.sh b/tools/make_apt.sh
new file mode 100755
index 000000000..3fb1066e5
--- /dev/null
+++ b/tools/make_apt.sh
@@ -0,0 +1,139 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [[ "$#" -le 3 ]]; then
+ echo "usage: $0 <private-key> <suite> <root> <packages...>"
+ exit 1
+fi
+declare -r private_key=$(readlink -e "$1"); shift
+declare -r suite="$1"; shift
+declare -r root="$1"; shift
+
+# Ensure that we have the correct packages installed.
+function apt_install() {
+ while true; do
+ sudo apt-get update &&
+ sudo apt-get install -y "$@" &&
+ true
+ result="${?}"
+ case $result in
+ 0)
+ break
+ ;;
+ 100)
+ # 100 is the error code apt-get returns on transient failures; retry.
+ ;;
+ *)
+ exit $result
+ ;;
+ esac
+ done
+}
+dpkg-sig --help >/dev/null 2>&1 || apt_install dpkg-sig
+apt-ftparchive --help >/dev/null 2>&1 || apt_install apt-utils
+xz --help >/dev/null 2>&1 || apt_install xz-utils
+
+# Verbose from this point.
+set -xeo pipefail
+
+# Create a directory for the release.
+declare -r release="${root}/dists/${suite}"
+mkdir -p "${release}"
+
+# Create a temporary keyring, and ensure it is cleaned up.
+declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg)
+cleanup() {
+ rm -f "${keyring}"
+}
+trap cleanup EXIT
+
+# We attempt the import twice because the first attempt will fail if the
+# public key is not found. This isn't actually a failure for us, because we
+# don't require the public key (it may be stored separately). The second
+# import will succeed because, in reality, the first import succeeded and
+# it's a no-op.
+gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" || \
+ gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}"
+
+# Copy the packages into the root.
+for pkg in "$@"; do
+ ext=${pkg##*.}
+ name=$(basename "${pkg}" ".${ext}")
+ arch=${name##*_}
+ if [[ "${name}" == "${arch}" ]]; then
+ continue # Not a regular package.
+ fi
+ if [[ "${pkg}" =~ ^.*\.deb$ ]]; then
+ # Extract from the debian file.
+ version=$(dpkg --info "${pkg}" | grep -E 'Version:' | cut -d':' -f2)
+ elif [[ "${pkg}" =~ ^.*\.changes$ ]]; then
+ # Extract from the changes file.
+ version=$(grep -E 'Version:' "${pkg}" | cut -d':' -f2)
+ else
+ # Unsupported file type.
+ echo "Unknown file type: ${pkg}"
+ exit 1
+ fi
+
+ # The package may already exist, in which case we leave it alone.
+ version=${version// /} # Trim whitespace.
+ destdir="${root}/pool/${version}/binary-${arch}"
+ target="${destdir}/${name}.${ext}"
+ if [[ -f "${target}" ]]; then
+ continue
+ fi
+
+ # Copy & sign the package.
+ mkdir -p "${destdir}"
+ cp -a "${pkg}" "${target}"
+ chmod 0644 "${target}"
+ if [[ "${ext}" == "deb" ]]; then
+ dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${target}"
+ fi
+done
+
+# Build the package list.
+declare arches=()
+for dir in "${root}"/pool/*/binary-*; do
+ name=$(basename "${dir}")
+ arch=${name##binary-}
+ arches+=("${arch}")
+ repo_packages="${release}"/main/"${name}"
+ mkdir -p "${repo_packages}"
+ (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages)
+ (cd "${repo_packages}" && cat Packages | gzip > Packages.gz)
+ (cd "${repo_packages}" && cat Packages | xz > Packages.xz)
+done
+
+# Build the release list.
+cat > "${release}"/apt.conf <<EOF
+APT {
+ FTPArchive {
+ Release {
+ Architectures "${arches[@]}";
+ Suite "${suite}";
+ Components "main";
+ };
+ };
+};
+EOF
+(cd "${release}" && apt-ftparchive -c=apt.conf release . > Release)
+rm "${release}"/apt.conf
+
+# Sign the release.
+declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512")
+(cd "${release}" && rm -f Release.gpg InRelease)
+(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release)
+(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release)
diff --git a/tools/make_release.sh b/tools/make_release.sh
new file mode 100755
index 000000000..b1cdd47b0
--- /dev/null
+++ b/tools/make_release.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [[ "$#" -le 2 ]]; then
+ echo "usage: $0 <private-key> <root> <binaries & packages...>"
+ echo "The environment variable NIGHTLY may be set to control"
+ echo "whether the nightly packages are produced or not."
+ exit 1
+fi
+
+set -xeo pipefail
+declare -r private_key="$1"; shift
+declare -r root="$1"; shift
+declare -a binaries
+declare -a pkgs
+
+# Collect binaries & packages.
+for arg in "$@"; do
+ if [[ "${arg}" == *.deb ]] || [[ "${arg}" == *.changes ]]; then
+ pkgs+=("${arg}")
+ else
+ binaries+=("${arg}")
+ fi
+done
+
+# install_raw installs raw artifacts.
+install_raw() {
+ mkdir -p "${root}/$1"
+ for binary in "${binaries[@]}"; do
+ # Copy the raw file & generate a sha512sum.
+ name=$(basename "${binary}")
+ cp -f "${binary}" "${root}/$1"
+ sha512sum "${root}/$1/${name}" | \
+ awk "{print $$1 \" ${name}\"}" > "${root}/$1/${name}.sha512"
+ done
+}
+
+# install_apt installs an apt repository.
+install_apt() {
+ tools/make_apt.sh "${private_key}" "$1" "${root}" "${pkgs[@]}"
+}
+
+# If nightly, install only nightly artifacts.
+if [[ "${NIGHTLY:-false}" == "true" ]]; then
+ # The "latest" directory and current date.
+ stamp="$(date -Idate)"
+ install_raw "nightly/latest"
+ install_raw "nightly/${stamp}"
+ install_apt "nightly"
+else
+ # Is it a tagged release? Build that.
+ tags="$(git tag --points-at HEAD 2>/dev/null || true)"
+ if ! [[ -z "${tags}" ]]; then
+ # Note that a given commit can match any number of tags. We have to iterate
+ # through all possible tags and produce associated artifacts.
+ for tag in ${tags}; do
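+ # For example, a (hypothetical) tag "release-20200219.0" yields
+ # name="20200219.0" and base="20200219".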
+ name=$(echo "${tag}" | cut -d'-' -f2)
+ base=$(echo "${name}" | cut -d'.' -f1)
+ install_raw "release/${name}"
+ install_raw "release/latest"
+ install_apt "release"
+ install_apt "${base}"
+ done
+ else
+ # Otherwise, assume it is a raw master commit.
+ install_raw "master/latest"
+ install_apt "master"
+ fi
+fi
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
new file mode 100644
index 000000000..c21b09511
--- /dev/null
+++ b/tools/nogo/BUILD
@@ -0,0 +1,49 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nogo",
+ srcs = [
+ "build.go",
+ "config.go",
+ "matchers.go",
+ "nogo.go",
+ "register.go",
+ ],
+ nogo = False,
+ visibility = ["//:sandbox"],
+ deps = [
+ "//tools/checkescape",
+ "//tools/checkunsafe",
+ "//tools/nogo/data",
+ "@org_golang_x_tools//go/analysis:go_tool_library",
+ "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/composite:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/errorsas:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/httpresponse:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/shadow:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/stringintconv:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unreachable:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library",
+ "@org_golang_x_tools//go/gcexportdata:go_tool_library",
+ ],
+)
diff --git a/tools/nogo/README.md b/tools/nogo/README.md
new file mode 100644
index 000000000..6e4db18de
--- /dev/null
+++ b/tools/nogo/README.md
@@ -0,0 +1,31 @@
+# Extended "nogo" analysis
+
+This package provides a build aspect that performs nogo analysis. It is
+automatically injected into all relevant libraries when using the default
+`go_binary` and `go_library` rules.
+
+It exists for several reasons.
+
+* The default `nogo` provided by bazel does not support binary analysis.
+ This package allows us to analyze the generated binary in addition to
+ running the standard analyzers.
+
+* The configuration provided in this package is much richer than the standard
+ `nogo` JSON blob. Specifically, it allows us to exclude specific structures
+ from the composite rules (such as the Ranges that are common with the set
+ types).
+
+* The bazel version of `nogo` is run directly against the `go_library` and
+ `go_binary` targets, meaning that any change to the configuration requires
+ a rebuild from scratch (which, for some reason, includes all C++ source
+ files). Using an aspect is more efficient in this regard.
+
+* The checks supported by this package are exported as tests, which makes it
+ easier to reason about and plumb into the build system.
+
+* For uninteresting reasons, it is impossible to integrate the default `nogo`
+ analyzer provided by bazel with internal Google tooling. To provide a
+ consistent experience, this package allows those systems to be unified.
+
+To use this package, import `nogo_test` from `defs.bzl` and add a single
+dependency on a `go_binary` or `go_library` rule.
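+
+For example, a minimal sketch (the target names here are hypothetical):
+
+```starlark
+load("//tools/nogo:defs.bzl", "nogo_test")
+
+nogo_test(
+    name = "mylib_nogo",
+    deps = [":mylib"],  # A go_library or go_binary rule.
+)
+```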
diff --git a/tools/nogo/build.go b/tools/nogo/build.go
new file mode 100644
index 000000000..1c0d08661
--- /dev/null
+++ b/tools/nogo/build.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "fmt"
+ "io"
+ "os"
+)
+
+var (
+ // internalPrefix is the internal path prefix. Note that this is not
+ // special, as paths should be passed relative to the repository root
+ // and should not have any special prefix applied.
+ internalPrefix = fmt.Sprintf("^")
+
+ // externalPrefix matches external workspace packages.
+ externalPrefix = "^external/"
+)
+
+// findStdPkg needs to find the bundled standard library packages.
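+//
+// For example, findStdPkg("fmt", "linux", "amd64") opens
+// "external/go_sdk/pkg/linux_amd64/fmt.a".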
+func findStdPkg(path, GOOS, GOARCH string) (io.ReadCloser, error) {
+ return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path))
+}
diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD
new file mode 100644
index 000000000..e2d76cd5c
--- /dev/null
+++ b/tools/nogo/check/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+# Note that the check binary must be public, since an aspect may be applied
+# across lots of different rules in different repositories.
+go_binary(
+ name = "check",
+ srcs = ["main.go"],
+ visibility = ["//visibility:public"],
+ deps = ["//tools/nogo"],
+)
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
new file mode 100644
index 000000000..3828edf3a
--- /dev/null
+++ b/tools/nogo/check/main.go
@@ -0,0 +1,24 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary check is the nogo entrypoint.
+package main
+
+import (
+ "gvisor.dev/gvisor/tools/nogo"
+)
+
+func main() {
+ nogo.Main()
+}
diff --git a/tools/nogo/config.go b/tools/nogo/config.go
new file mode 100644
index 000000000..6958fca69
--- /dev/null
+++ b/tools/nogo/config.go
@@ -0,0 +1,116 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/asmdecl"
+ "golang.org/x/tools/go/analysis/passes/assign"
+ "golang.org/x/tools/go/analysis/passes/atomic"
+ "golang.org/x/tools/go/analysis/passes/bools"
+ "golang.org/x/tools/go/analysis/passes/buildtag"
+ "golang.org/x/tools/go/analysis/passes/cgocall"
+ "golang.org/x/tools/go/analysis/passes/composite"
+ "golang.org/x/tools/go/analysis/passes/copylock"
+ "golang.org/x/tools/go/analysis/passes/errorsas"
+ "golang.org/x/tools/go/analysis/passes/httpresponse"
+ "golang.org/x/tools/go/analysis/passes/loopclosure"
+ "golang.org/x/tools/go/analysis/passes/lostcancel"
+ "golang.org/x/tools/go/analysis/passes/nilfunc"
+ "golang.org/x/tools/go/analysis/passes/nilness"
+ "golang.org/x/tools/go/analysis/passes/printf"
+ "golang.org/x/tools/go/analysis/passes/shadow"
+ "golang.org/x/tools/go/analysis/passes/shift"
+ "golang.org/x/tools/go/analysis/passes/stdmethods"
+ "golang.org/x/tools/go/analysis/passes/stringintconv"
+ "golang.org/x/tools/go/analysis/passes/structtag"
+ "golang.org/x/tools/go/analysis/passes/tests"
+ "golang.org/x/tools/go/analysis/passes/unmarshal"
+ "golang.org/x/tools/go/analysis/passes/unreachable"
+ "golang.org/x/tools/go/analysis/passes/unsafeptr"
+ "golang.org/x/tools/go/analysis/passes/unusedresult"
+
+ "gvisor.dev/gvisor/tools/checkescape"
+ "gvisor.dev/gvisor/tools/checkunsafe"
+)
+
+var analyzerConfig = map[*analysis.Analyzer]matcher{
+ // Standard analyzers.
+ asmdecl.Analyzer: alwaysMatches(),
+ assign.Analyzer: externalExcluded(
+ ".*gazelle/walk/walk.go", // False positive.
+ ),
+ atomic.Analyzer: alwaysMatches(),
+ bools.Analyzer: alwaysMatches(),
+ buildtag.Analyzer: alwaysMatches(),
+ cgocall.Analyzer: alwaysMatches(),
+ composite.Analyzer: and(
+ disableMatches(), // Disabled for now.
+ resultExcluded{
+ "Object_",
+ "Range{",
+ },
+ ),
+ copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos).
+ errorsas.Analyzer: alwaysMatches(),
+ httpresponse.Analyzer: alwaysMatches(),
+ loopclosure.Analyzer: alwaysMatches(),
+ lostcancel.Analyzer: internalMatches(), // Common external issues.
+ nilfunc.Analyzer: alwaysMatches(),
+ nilness.Analyzer: and(
+ internalMatches(), // Common "tautological checks".
+ internalExcluded(
+ "pkg/sentry/platform/kvm/kvm_test.go", // Intentional.
+ "tools/bigquery/bigquery.go", // False positive.
+ ),
+ ),
+ printf.Analyzer: alwaysMatches(),
+ shift.Analyzer: alwaysMatches(),
+ stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write").
+ stringintconv.Analyzer: and(
+ internalExcluded(),
+ externalExcluded(
+ ".*protobuf/.*.go", // Bad conversions.
+ ".*flate/huffman_bit_writer.go", // Bad conversion.
+ ),
+ ),
+ shadow.Analyzer: disableMatches(), // Disabled for now.
+ structtag.Analyzer: internalMatches(), // External not subject to rules.
+ tests.Analyzer: alwaysMatches(),
+ unmarshal.Analyzer: alwaysMatches(),
+ unreachable.Analyzer: internalMatches(),
+ unsafeptr.Analyzer: and(
+ internalMatches(),
+ internalExcluded(
+ ".*_test.go", // Exclude tests.
+ "pkg/flipcall/.*_unsafe.go", // Special case.
+ "pkg/gohacks/gohacks_unsafe.go", // Special case.
+ "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case.
+ "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case.
+ "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case.
+ "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case.
+ "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case.
+ "pkg/sentry/vfs/mount_unsafe.go", // Special case.
+ "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case.
+ "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case.
+ "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case.
+ ),
+ ),
+ unusedresult.Analyzer: alwaysMatches(),
+
+ // Internal analyzers: external packages not subject.
+ checkescape.Analyzer: internalMatches(),
+ checkunsafe.Analyzer: internalMatches(),
+}
diff --git a/tools/nogo/data/BUILD b/tools/nogo/data/BUILD
new file mode 100644
index 000000000..b7564cc44
--- /dev/null
+++ b/tools/nogo/data/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "data",
+ srcs = ["data.go"],
+ nogo = False,
+ visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/nogo/data/data.go b/tools/nogo/data/data.go
new file mode 100644
index 000000000..eb84d0d27
--- /dev/null
+++ b/tools/nogo/data/data.go
@@ -0,0 +1,21 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package data contains shared data for nogo analysis.
+//
+// This is used to break a dependency cycle.
+package data
+
+// Objdump is the dumped binary under analysis.
+var Objdump string
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
new file mode 100644
index 000000000..6560b57c8
--- /dev/null
+++ b/tools/nogo/defs.bzl
@@ -0,0 +1,172 @@
+"""Nogo rules."""
+
+load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule")
+
+# NogoInfo is the serialized set of package facts for a nogo analysis.
+#
+# Each go_library rule will generate a corresponding nogo rule, which will run
+# with the source files as input. Note however, that the individual nogo rules
+# are simply stubs that enter into the shadow dependency tree (the "aspect").
+NogoInfo = provider(
+ fields = {
+ "facts": "serialized package facts",
+ "importpath": "package import path",
+ "binaries": "package binary files",
+ },
+)
+
+def _nogo_aspect_impl(target, ctx):
+ # If this is a nogo rule itself (and not the shadow of a go_library or
+ # go_binary rule created by such a rule), then we simply return nothing.
+ # All work is done in the shadow properties for go rules. For a proto
+ # library, we simply skip the analysis portion but still need to return a
+ # valid NogoInfo to reference the generated binary.
+ if ctx.rule.kind == "go_library":
+ srcs = ctx.rule.files.srcs
+ elif ctx.rule.kind == "go_proto_library" or ctx.rule.kind == "go_wrap_cc":
+ srcs = []
+ else:
+ return [NogoInfo()]
+
+ # Construct the Go environment from the go_context.env dictionary.
+ env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_context(ctx).env.items()])
+
+ # Start with all target files and srcs as input.
+ inputs = target.files.to_list() + srcs
+
+ # Generate a shell script that dumps the binary. Annoyingly, this seems
+ # necessary as the context in which a run_shell command runs does not seem
+ # to cleanly allow us to redirect stdout to the actual output file. Perhaps
+ # I'm missing something here, but the intermediate script does work.
+ binaries = target.files.to_list()
+ disasm_file = ctx.actions.declare_file(target.label.name + ".out")
+ dumper = ctx.actions.declare_file("%s-dumper" % ctx.label.name)
+ ctx.actions.write(dumper, "\n".join([
+ "#!/bin/bash",
+ "%s %s tool objdump %s > %s\n" % (
+ env_prefix,
+ go_context(ctx).go.path,
+ [f.path for f in binaries if f.path.endswith(".a")][0],
+ disasm_file.path,
+ ),
+ ]), is_executable = True)
+ ctx.actions.run(
+ inputs = binaries,
+ outputs = [disasm_file],
+ tools = go_context(ctx).runfiles,
+ mnemonic = "GoObjdump",
+ progress_message = "Objdump %s" % target.label,
+ executable = dumper,
+ )
+ inputs.append(disasm_file)
+
+ # Extract the importpath for this package.
+ importpath = go_importpath(target)
+
+ # The nogo tool requires a configfile serialized in JSON format to do its
+ # work. This must line up with the nogo.Config fields.
+ facts = ctx.actions.declare_file(target.label.name + ".facts")
+ config = struct(
+ ImportPath = importpath,
+ GoFiles = [src.path for src in srcs if src.path.endswith(".go")],
+ NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")],
+ GOOS = go_context(ctx).goos,
+ GOARCH = go_context(ctx).goarch,
+ Tags = go_context(ctx).tags,
+ FactMap = {}, # Constructed below.
+ ImportMap = {}, # Constructed below.
+ FactOutput = facts.path,
+ Objdump = disasm_file.path,
+ )
+
+ # Collect all info from shadow dependencies.
+ for dep in ctx.rule.attr.deps:
+ # There will be no file attribute set for all transitive dependencies
+ # that are not go_library or go_binary rules, such as a proto rules.
+ # This is handled by the ctx.rule.kind check above.
+ info = dep[NogoInfo]
+ if not hasattr(info, "facts"):
+ continue
+
+ # Configure where to find the binary & fact files. Note that this will
+ # use .x and .a regardless of whether this is a go_binary rule, since
+ # these dependencies must be go_library rules.
+ x_files = [f.path for f in info.binaries if f.path.endswith(".x")]
+ if not x_files:
+ x_files = [f.path for f in info.binaries if f.path.endswith(".a")]
+ config.ImportMap[info.importpath] = x_files[0]
+ config.FactMap[info.importpath] = info.facts.path
+
+ # Ensure the above are available as inputs.
+ inputs.append(info.facts)
+ inputs += info.binaries
+
+ # Write the configuration and run the tool.
+ config_file = ctx.actions.declare_file(target.label.name + ".cfg")
+ ctx.actions.write(config_file, config.to_json())
+ inputs.append(config_file)
+
+ # Run the nogo tool itself.
+ ctx.actions.run(
+ inputs = inputs,
+ outputs = [facts],
+ tools = go_context(ctx).runfiles,
+ executable = ctx.files._nogo[0],
+ mnemonic = "GoStaticAnalysis",
+ progress_message = "Analyzing %s" % target.label,
+ arguments = ["-config=%s" % config_file.path],
+ )
+
+ # Return the package facts as output.
+ return [NogoInfo(
+ facts = facts,
+ importpath = importpath,
+ binaries = binaries,
+ )]
+
+nogo_aspect = go_rule(
+ aspect,
+ implementation = _nogo_aspect_impl,
+ attr_aspects = ["deps"],
+ attrs = {
+ "_nogo": attr.label(
+ default = "//tools/nogo/check:check",
+ allow_single_file = True,
+ ),
+ },
+)
+
+def _nogo_test_impl(ctx):
+ """Check nogo findings."""
+
+ # Build a runner that checks for the existence of the facts file. Note that
+ # the actual build will fail in the case of a broken analysis. We do things
+ # this way so that any test applied is effectively pushed down to all
+ # upstream dependencies through the aspect.
+ inputs = []
+ runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
+ runner_content = ["#!/bin/bash"]
+ for dep in ctx.attr.deps:
+ info = dep[NogoInfo]
+ inputs.append(info.facts)
+
+ # Draw a sweet unicode checkmark with the package name (in green).
+ runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath)
+ runner_content.append("exit 0\n")
+ ctx.actions.write(runner, "\n".join(runner_content), is_executable = True)
+ return [DefaultInfo(
+ runfiles = ctx.runfiles(files = inputs),
+ executable = runner,
+ )]
+
+_nogo_test = rule(
+ implementation = _nogo_test_impl,
+ attrs = {
+ "deps": attr.label_list(aspects = [nogo_aspect]),
+ },
+ test = True,
+)
+
+def nogo_test(**kwargs):
+ tags = kwargs.pop("tags", []) + ["nogo"]
+ _nogo_test(tags = tags, **kwargs)
diff --git a/tools/nogo/io_bazel_rules_go-visibility.patch b/tools/nogo/io_bazel_rules_go-visibility.patch
new file mode 100644
index 000000000..6b64b2e85
--- /dev/null
+++ b/tools/nogo/io_bazel_rules_go-visibility.patch
@@ -0,0 +1,25 @@
+diff --git a/third_party/org_golang_x_tools-extras.patch b/third_party/org_golang_x_tools-extras.patch
+index 133fbccc..5f0d9a47 100644
+--- a/third_party/org_golang_x_tools-extras.patch
++++ b/third_party/org_golang_x_tools-extras.patch
+@@ -32,7 +32,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/
+
+ go_library(
+ name = "go_default_library",
+-@@ -14,6 +14,23 @@
++@@ -14,6 +14,20 @@
+ ],
+ )
+
+@@ -43,10 +43,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/
+ + "imports.go",
+ + ],
+ + importpath = "golang.org/x/tools/go/analysis/internal/facts",
+-+ visibility = [
+-+ "//go/analysis:__subpackages__",
+-+ "@io_bazel_rules_go//go/tools/builders:__pkg__",
+-+ ],
+++ visibility = ["//visibility:public"],
+ + deps = [
+ + "//go/analysis:go_tool_library",
+ + "//go/types/objectpath:go_tool_library",
diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go
new file mode 100644
index 000000000..57a250501
--- /dev/null
+++ b/tools/nogo/matchers.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "go/token"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+)
+
+type matcher interface {
+ ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool
+}
+
+// pathRegexps filters explicit paths.
+type pathRegexps struct {
+ expr []*regexp.Regexp
+
+ // include, if true, indicates that paths matching any regexp in expr
+ // match.
+ //
+ // If false, paths matching no regexps in expr match.
+ include bool
+}
+
+// buildRegexps builds a list of regular expressions.
+//
+// This will panic on error.
+func buildRegexps(prefix string, args ...string) []*regexp.Regexp {
+ result := make([]*regexp.Regexp, 0, len(args))
+ for _, arg := range args {
+ result = append(result, regexp.MustCompile(filepath.Join(prefix, arg)))
+ }
+ return result
+}
+
+// ShouldReport implements matcher.ShouldReport.
+func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
+ fullPos := fs.Position(d.Pos).String()
+ for _, path := range p.expr {
+ if path.MatchString(fullPos) {
+ return p.include
+ }
+ }
+ return !p.include
+}
+
+// internalExcluded excludes specific internal paths.
+func internalExcluded(paths ...string) *pathRegexps {
+ return &pathRegexps{
+ expr: buildRegexps(internalPrefix, paths...),
+ include: false,
+ }
+}
+
+// externalExcluded excludes specific external paths.
+func externalExcluded(paths ...string) *pathRegexps {
+ return &pathRegexps{
+ expr: buildRegexps(externalPrefix, paths...),
+ include: false,
+ }
+}
+
+// internalMatches returns a path matcher for internal packages.
+func internalMatches() *pathRegexps {
+ return &pathRegexps{
+ expr: buildRegexps(internalPrefix, ".*"),
+ include: true,
+ }
+}
+
+// resultExcluded excludes explicit message contents.
+type resultExcluded []string
+
+// ShouldReport implements matcher.ShouldReport.
+func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool {
+ for _, str := range r {
+ if strings.Contains(d.Message, str) {
+ return false
+ }
+ }
+ return true // Not excluded.
+}
+
+// andMatcher is a composite matcher.
+type andMatcher struct {
+ first matcher
+ second matcher
+}
+
+// ShouldReport implements matcher.ShouldReport.
+func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
+ return a.first.ShouldReport(d, fs) && a.second.ShouldReport(d, fs)
+}
+
+// and is a syntactic convenience for andMatcher.
+func and(first matcher, second matcher) *andMatcher {
+ return &andMatcher{
+ first: first,
+ second: second,
+ }
+}
+
+// anyMatcher matches everything.
+type anyMatcher struct{}
+
+// ShouldReport implements matcher.ShouldReport.
+func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
+ return true
+}
+
+// alwaysMatches returns an anyMatcher instance.
+func alwaysMatches() anyMatcher {
+ return anyMatcher{}
+}
+
+// neverMatcher will never match.
+type neverMatcher struct{}
+
+// ShouldReport implements matcher.ShouldReport.
+func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
+ return false
+}
+
+// disableMatches returns a neverMatcher instance.
+func disableMatches() neverMatcher {
+ return neverMatcher{}
+}
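+
+// Illustrative only: one way the matchers above might compose. The path
+// pattern below is a hypothetical example; the real wiring lives in the
+// analyzerConfig map consumed by checkPackage.
+var _ matcher = and(
+	internalMatches(),
+	internalExcluded("example/generated/.*"),
+)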
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
new file mode 100644
index 000000000..203cdf688
--- /dev/null
+++ b/tools/nogo/nogo.go
@@ -0,0 +1,316 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nogo implements binary analysis similar to bazel's nogo,
+// or the unitchecker package. It exists in order to provide additional
+// facilities for analysis, namely plumbing through the output from
+// dumping the generated binary (to analyze actual produced code).
+package nogo
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/build"
+ "go/parser"
+ "go/token"
+ "go/types"
+ "io"
+ "io/ioutil"
+ "log"
+ "os"
+ "path/filepath"
+ "reflect"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/internal/facts"
+ "golang.org/x/tools/go/gcexportdata"
+ "gvisor.dev/gvisor/tools/nogo/data"
+)
+
+// pkgConfig is serialized as the configuration.
+//
+// This contains everything required for the analysis.
+type pkgConfig struct {
+ ImportPath string
+ GoFiles []string
+ NonGoFiles []string
+ Tags []string
+ GOOS string
+ GOARCH string
+ ImportMap map[string]string
+ FactMap map[string]string
+ FactOutput string
+ Objdump string
+}
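+
+// Illustrative only: a sketch of a configuration as it might be decoded from
+// the JSON file passed via -config. All values here are hypothetical.
+var _ = pkgConfig{
+	ImportPath: "example.com/pkg/demo",
+	GoFiles:    []string{"demo.go"},
+	Tags:       []string{"linux"},
+	GOOS:       "linux",
+	GOARCH:     "amd64",
+	ImportMap:  map[string]string{"fmt": "path/to/fmt.a"},
+	FactMap:    map[string]string{},
+	FactOutput: "demo.facts",
+	Objdump:    "demo.objdump",
+}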
+
+// loadFacts finds and loads facts per FactMap.
+func (c *pkgConfig) loadFacts(path string) ([]byte, error) {
+ realPath, ok := c.FactMap[path]
+ if !ok {
+ return nil, nil // No facts available.
+ }
+
+	// Read the facts file.
+ data, err := ioutil.ReadFile(realPath)
+ if err != nil {
+ return nil, err
+ }
+ return data, nil
+}
+
+// shouldInclude indicates whether the file should be included.
+//
+// NOTE: This only does basic parsing of tags.
+func (c *pkgConfig) shouldInclude(path string) (bool, error) {
+ ctx := build.Default
+ ctx.GOOS = c.GOOS
+ ctx.GOARCH = c.GOARCH
+ ctx.BuildTags = c.Tags
+ return ctx.MatchFile(filepath.Dir(path), filepath.Base(path))
+}
+
+// importer is an implementation of go/types.Importer.
+//
+// This wraps a configuration, which provides the map of package names to
+// files, and the facts. Note that this importer implementation will always
+// pass when a given package is not available.
+type importer struct {
+ pkgConfig
+ fset *token.FileSet
+ cache map[string]*types.Package
+}
+
+// Import implements types.Importer.Import.
+func (i *importer) Import(path string) (*types.Package, error) {
+ if path == "unsafe" {
+ // Special case: go/types has pre-defined type information for
+ // unsafe. We ensure that this package is correct, in case any
+ // analyzers are specifically looking for this.
+ return types.Unsafe, nil
+ }
+ realPath, ok := i.ImportMap[path]
+ var (
+ rc io.ReadCloser
+ err error
+ )
+ if !ok {
+ // Not found in the import path. Attempt to find the package
+ // via the standard library.
+ rc, err = findStdPkg(path, i.GOOS, i.GOARCH)
+ } else {
+ // Open the file.
+ rc, err = os.Open(realPath)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer rc.Close()
+
+ // Load all exported data.
+ r, err := gcexportdata.NewReader(rc)
+ if err != nil {
+ return nil, err
+ }
+
+ return gcexportdata.Read(r, i.fset, i.cache, path)
+}
+
+// checkPackage runs all analyzers.
+//
+// The implementation was adapted from [1], which was in turn adapted from [2].
+// This returns a list of matching analysis issues, or an error if the analysis
+// could not be completed.
+//
+// [1] bazelbuild/rules_go/tools/builders/nogo_main.go
+// [2] golang.org/x/tools/go/checker/internal/checker
+func checkPackage(config pkgConfig) ([]string, error) {
+ imp := &importer{
+ pkgConfig: config,
+ fset: token.NewFileSet(),
+ cache: make(map[string]*types.Package),
+ }
+
+ // Load all source files.
+ var syntax []*ast.File
+ for _, file := range config.GoFiles {
+ include, err := config.shouldInclude(file)
+ if err != nil {
+ return nil, fmt.Errorf("error evaluating file %q: %v", file, err)
+ }
+ if !include {
+ continue
+ }
+ s, err := parser.ParseFile(imp.fset, file, nil, parser.ParseComments)
+ if err != nil {
+ return nil, fmt.Errorf("error parsing file %q: %v", file, err)
+ }
+ syntax = append(syntax, s)
+ }
+
+ // Check type information.
+ typesSizes := types.SizesFor("gc", config.GOARCH)
+ typeConfig := types.Config{Importer: imp}
+ typesInfo := &types.Info{
+ Types: make(map[ast.Expr]types.TypeAndValue),
+ Uses: make(map[*ast.Ident]types.Object),
+ Defs: make(map[*ast.Ident]types.Object),
+ Implicits: make(map[ast.Node]types.Object),
+ Scopes: make(map[ast.Node]*types.Scope),
+ Selections: make(map[*ast.SelectorExpr]*types.Selection),
+ }
+ types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo)
+ if err != nil {
+ return nil, fmt.Errorf("error checking types: %v", err)
+ }
+
+ // Load all package facts.
+ facts, err := facts.Decode(types, config.loadFacts)
+ if err != nil {
+ return nil, fmt.Errorf("error decoding facts: %v", err)
+ }
+
+ // Set the binary global for use.
+ data.Objdump = config.Objdump
+
+ // Register fact types and establish dependencies between analyzers.
+ // The visit closure will execute recursively, and populate results
+	// with all required analysis results.
+ diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic)
+ results := make(map[*analysis.Analyzer]interface{})
+ var visit func(*analysis.Analyzer) error // For recursion.
+ visit = func(a *analysis.Analyzer) error {
+ if _, ok := results[a]; ok {
+ return nil
+ }
+
+ // Run recursively for all dependencies.
+ for _, req := range a.Requires {
+ if err := visit(req); err != nil {
+ return err
+ }
+ }
+
+ // Prepare the matcher.
+ m := analyzerConfig[a]
+ report := func(d analysis.Diagnostic) {
+ if m.ShouldReport(d, imp.fset) {
+ diagnostics[a] = append(diagnostics[a], d)
+ }
+ }
+
+ // Run the analysis.
+ factFilter := make(map[reflect.Type]bool)
+ for _, f := range a.FactTypes {
+ factFilter[reflect.TypeOf(f)] = true
+ }
+ p := &analysis.Pass{
+ Analyzer: a,
+ Fset: imp.fset,
+ Files: syntax,
+ Pkg: types,
+ TypesInfo: typesInfo,
+ ResultOf: results, // All results.
+ Report: report,
+ ImportPackageFact: facts.ImportPackageFact,
+ ExportPackageFact: facts.ExportPackageFact,
+ ImportObjectFact: facts.ImportObjectFact,
+ ExportObjectFact: facts.ExportObjectFact,
+ AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) },
+ AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) },
+ TypesSizes: typesSizes,
+ }
+ result, err := a.Run(p)
+ if err != nil {
+ return fmt.Errorf("error running analysis %s: %v", a, err)
+ }
+
+ // Sanity check & save the result.
+ if got, want := reflect.TypeOf(result), a.ResultType; got != want {
+ return fmt.Errorf("error: analyzer %s returned a result of type %v, but declared ResultType %v", a, got, want)
+ }
+ results[a] = result
+ return nil // Success.
+ }
+
+	// Visit all analyzers recursively.
+	for a := range analyzerConfig {
+ if err := visit(a); err != nil {
+ return nil, err // Already has context.
+ }
+ }
+
+ // Write the output file.
+ if config.FactOutput != "" {
+ factData := facts.Encode()
+ if err := ioutil.WriteFile(config.FactOutput, factData, 0644); err != nil {
+			return nil, fmt.Errorf("error: unable to write facts output %q: %v", config.FactOutput, err)
+ }
+ }
+
+ // Convert all diagnostics to strings.
+ findings := make([]string, 0, len(diagnostics))
+ for a, ds := range diagnostics {
+ for _, d := range ds {
+			// Include the analyzer name for debuggability and configuration.
+ findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message))
+ }
+ }
+
+ // Return all findings.
+ return findings, nil
+}
+
+var (
+ configFile = flag.String("config", "", "configuration file (in JSON format)")
+)
+
+// Main is the entrypoint; it should be called directly from main.
+//
+// N.B. This package registers its own flags.
+func Main() {
+ // Parse all flags.
+ flag.Parse()
+
+ // Load the configuration.
+ f, err := os.Open(*configFile)
+ if err != nil {
+ log.Fatalf("unable to open configuration %q: %v", *configFile, err)
+ }
+ defer f.Close()
+ config := new(pkgConfig)
+ dec := json.NewDecoder(f)
+ dec.DisallowUnknownFields()
+ if err := dec.Decode(config); err != nil {
+ log.Fatalf("unable to decode configuration: %v", err)
+ }
+
+ // Process the package.
+ findings, err := checkPackage(*config)
+ if err != nil {
+ log.Fatalf("error checking package: %v", err)
+ }
+
+ // No findings?
+ if len(findings) == 0 {
+ os.Exit(0)
+ }
+
+ // Print findings and exit with non-zero code.
+ for _, finding := range findings {
+ fmt.Fprintf(os.Stdout, "%s\n", finding)
+ }
+ os.Exit(1)
+}
diff --git a/tools/nogo/register.go b/tools/nogo/register.go
new file mode 100644
index 000000000..62b499661
--- /dev/null
+++ b/tools/nogo/register.go
@@ -0,0 +1,64 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "encoding/gob"
+ "log"
+
+ "golang.org/x/tools/go/analysis"
+)
+
+// analyzers returns all configured analyzers.
+func analyzers() (all []*analysis.Analyzer) {
+	for a := range analyzerConfig {
+ all = append(all, a)
+ }
+ return all
+}
+
+func init() {
+ // Validate basic configuration.
+ if err := analysis.Validate(analyzers()); err != nil {
+ log.Fatalf("unable to validate analyzer: %v", err)
+ }
+
+ // Register all fact types.
+ //
+ // N.B. This needs to be done recursively, because there may be
+ // analyzers in the Requires list that do not appear explicitly above.
+ registered := make(map[*analysis.Analyzer]struct{})
+ var register func(*analysis.Analyzer)
+ register = func(a *analysis.Analyzer) {
+ if _, ok := registered[a]; ok {
+ return
+ }
+
+		// Register dependencies.
+ for _, da := range a.Requires {
+ register(da)
+ }
+
+ // Register local facts.
+ for _, f := range a.FactTypes {
+ gob.Register(f)
+ }
+
+ registered[a] = struct{}{} // Done.
+ }
+ for _, a := range analyzers() {
+ register(a)
+ }
+}
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
new file mode 100755
index 000000000..b0bab74b4
--- /dev/null
+++ b/tools/tag_release.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script will optionally map a PiperOrigin-RevId to a given commit,
+# validate a provided release name, create a tag and push it. It must be
+# run manually when a release is created.
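+#
+# Example invocation (values are hypothetical):
+#   tools/tag_release.sh 3ad6d3b69ff3 20200212.0 release-message.txt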
+
+set -xeuo pipefail
+
+# Check arguments.
+if [[ "$#" -ne 3 ]]; then
+ echo "usage: $0 <commit|revid> <release.rc> <message-file>"
+ exit 1
+fi
+
+declare -r target_commit="$1"
+declare -r release="$2"
+declare -r message_file="$3"
+
+if [[ -z "${target_commit}" ]]; then
+  echo "error: <commit|revid> is empty."
+  exit 1
+fi
+if [[ -z "${release}" ]]; then
+  echo "error: <release.rc> is empty."
+  exit 1
+fi
+if ! [[ -r "${message_file}" ]]; then
+ echo "error: message file '${message_file}' is not readable."
+ exit 1
+fi
+
+closest_commit() {
+  while read -r line; do
+ if [[ "$line" =~ "commit " ]]; then
+ current_commit="${line#commit }"
+ continue
+ elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then
+ revid="${line#PiperOrigin-RevId: }"
+ [[ "${revid}" -le "$1" ]] && break
+ fi
+ done
+ echo "${current_commit}"
+}
+
+# Is the passed identifier a sha commit?
+if ! git show "${target_commit}" &> /dev/null; then
+ # Extract the commit given a piper ID.
+ declare -r commit="$(git log | closest_commit "${target_commit}")"
+else
+ declare -r commit="${target_commit}"
+fi
+if ! git show "${commit}" &> /dev/null; then
+ echo "unknown commit: ${target_commit}"
+ exit 1
+fi
+
+# Is the release name sane? Must be a date with patch/rc.
+if ! [[ "${release}" =~ ^20[0-9]{6}\.[0-9]+$ ]]; then
+ declare -r expected="$(date +%Y%m%d.0)" # Use today's date.
+ echo "unexpected release format: ${release}"
+ echo " ... expected like ${expected}"
+ exit 1
+fi
+
+# Tag the given commit (annotated, to record the committer). Note that the tag
+# here is applied as a force, in case the tag already exists and is the same.
+# The push will fail in this case (because it is not forced).
+declare -r tag="release-${release}"
+git tag -f -F "${message_file}" -a "${tag}" "${commit}" && \
+ git push origin tag "${tag}"
diff --git a/tools/tags/BUILD b/tools/tags/BUILD
new file mode 100644
index 000000000..1c02e2c89
--- /dev/null
+++ b/tools/tags/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tags",
+ srcs = ["tags.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/tags/tags.go b/tools/tags/tags.go
new file mode 100644
index 000000000..f35904e0a
--- /dev/null
+++ b/tools/tags/tags.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tags is a utility for parsing build tags.
+package tags
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strings"
+)
+
+// OrSet is a set of tags on a single line.
+//
+// Note that tags may include ",", and we don't distinguish this case in the
+// logic below. Ideally, these constraints can be split into separate top-level
+// build tags in order to resolve any issues.
+type OrSet []string
+
+// Line returns the +build line for this OrSet.
+func (or OrSet) Line() string {
+ return fmt.Sprintf("// +build %s", strings.Join([]string(or), " "))
+}
+
+// AndSet is the set of all OrSets.
+type AndSet []OrSet
+
+// Lines returns the lines to be printed.
+func (and AndSet) Lines() (ls []string) {
+ for _, or := range and {
+ ls = append(ls, or.Line())
+ }
+ return
+}
+
+// Join joins this AndSet with another.
+func (and AndSet) Join(other AndSet) AndSet {
+ return append(and, other...)
+}
+
+// Tags returns the unique set of +build tags.
+//
+// Derived from the runtime's canBuild.
+func Tags(file string) (tags AndSet) {
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ return nil
+ }
+ // Check file contents for // +build lines.
+ for _, p := range strings.Split(string(data), "\n") {
+ p = strings.TrimSpace(p)
+ if p == "" {
+ continue
+ }
+ if !strings.HasPrefix(p, "//") {
+ break
+ }
+ if !strings.Contains(p, "+build") {
+ continue
+ }
+ fields := strings.Fields(p[2:])
+ if len(fields) < 1 || fields[0] != "+build" {
+ continue
+ }
+ tags = append(tags, OrSet(fields[1:]))
+ }
+ return tags
+}
+
+// Aggregate aggregates all tags from a set of files.
+//
+// Note that these may be in conflict, in which case the build will fail.
+func Aggregate(files []string) (tags AndSet) {
+ for _, file := range files {
+ tags = tags.Join(Tags(file))
+ }
+ return tags
+}
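+
+// exampleLines is an illustrative sketch (with hypothetical tag values) of
+// the round trip from parsed tags back to +build lines.
+func exampleLines() []string {
+	and := AndSet{OrSet{"linux", "darwin"}, OrSet{"!race"}}
+	// Produces ["// +build linux darwin", "// +build !race"].
+	return and.Lines()
+}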
diff --git a/tools/vm/BUILD b/tools/vm/BUILD
new file mode 100644
index 000000000..f7160c627
--- /dev/null
+++ b/tools/vm/BUILD
@@ -0,0 +1,57 @@
+load("//tools:defs.bzl", "cc_binary", "gtest")
+load("//tools/vm:defs.bzl", "vm_image", "vm_test")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+sh_binary(
+ name = "zone",
+ srcs = ["zone.sh"],
+)
+
+sh_binary(
+ name = "builder",
+ srcs = ["build.sh"],
+)
+
+sh_binary(
+ name = "executer",
+ srcs = ["execute.sh"],
+)
+
+cc_binary(
+ name = "test",
+ testonly = 1,
+ srcs = ["test.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+vm_image(
+ name = "ubuntu1604",
+ family = "ubuntu-1604-lts",
+ project = "ubuntu-os-cloud",
+ scripts = [
+ "//tools/vm/ubuntu1604",
+ ],
+)
+
+vm_image(
+ name = "ubuntu1804",
+ family = "ubuntu-1804-lts",
+ project = "ubuntu-os-cloud",
+ scripts = [
+ "//tools/vm/ubuntu1804",
+ ],
+)
+
+vm_test(
+ name = "vm_test",
+ shard_count = 2,
+ targets = [":test"],
+)
diff --git a/tools/vm/README.md b/tools/vm/README.md
new file mode 100644
index 000000000..898c95fca
--- /dev/null
+++ b/tools/vm/README.md
@@ -0,0 +1,42 @@
+# VM Images & Tests
+
+All commands in this directory require the `gcloud` project to be set.
+
+For example: `gcloud config set project gvisor-kokoro-testing`.
+
+Images can be generated by using the `vm_image` rule. This rule will generate a
+binary target that builds an image in an idempotent way, and can be referenced
+from other rules.
+
+For example:
+
+```
+vm_image(
+ name = "ubuntu",
+    project = "ubuntu-os-cloud",
+    family = "ubuntu-1604-lts",
+ scripts = [
+ "script.sh",
+ "other.sh",
+ ],
+)
+```
+
+These images can be built manually by executing the target. The output on
+`stdout` will be the image id (in the current project).
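+
+For example (an illustrative invocation; substitute the `vm_image` target
+name declared in your BUILD file):
+
+```
+bazel run //tools/vm:ubuntu1604_builder
+```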
+
+Images are always named per the hash of all the hermetic input scripts. This
+allows images to be memoized quickly and easily.
+
+The `vm_test` rule can be used to execute a command remotely. This rule is
+still under development, however, and will likely change over time.
+
+For example:
+
+```
+vm_test(
+ name = "mycommand",
+ image = ":ubuntu",
+ targets = [":test"],
+)
+```
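+
+Note that `vm_test` targets are tagged `manual` and `local`, so they run only
+when explicitly requested, for example (illustrative): `bazel test
+//tools/vm:vm_test`.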
diff --git a/tools/vm/build.sh b/tools/vm/build.sh
new file mode 100755
index 000000000..752b2b77b
--- /dev/null
+++ b/tools/vm/build.sh
@@ -0,0 +1,117 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script is responsible for building a new GCP image that: 1) has nested
+# virtualization enabled, and 2) has been completely set up with the
+# image_setup.sh script. This script should be idempotent, as we memoize the
+# setup script with a hash and check for that name.
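+#
+# For example, a produced image name might look like (hypothetical hash):
+#   ubuntu-1604-lts-1a2b3c4d5e6f7a8b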
+
+set -eou pipefail
+
+# Parameters.
+declare -r USERNAME=${USERNAME:-test}
+declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
+declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
+declare -r ZONE=${ZONE:-us-central1-f}
+
+# Random names.
+declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
+declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
+declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
+
+# Hash inputs in order to memoize the produced image.
+declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
+declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}
+
+# Does the image already exist? Skip the build.
+declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
+if [[ -n "${existing}" ]]; then
+ echo "${existing}"
+ exit 0
+fi
+
+# Standard arguments (applies only on script execution).
+declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
+
+# gcloud has path errors; is this a result of being a genrule?
+export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin}
+
+# Start a unique instance. Note that this instance will have a unique persistent
+# disk as its boot disk with the same name as the instance.
+(set -x; gcloud compute instances create \
+ --quiet \
+ --image-project "${IMAGE_PROJECT}" \
+ --image-family "${IMAGE_FAMILY}" \
+ --boot-disk-size "200GB" \
+ --zone "${ZONE}" \
+ "${INSTANCE_NAME}" >/dev/null)
+function cleanup {
+ (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}")
+}
+trap cleanup EXIT
+
+# Wait for the instance to become available (up to 5 minutes).
+echo -n "Waiting for ${INSTANCE_NAME}" >&2
+declare timeout=300
+declare success=0
+declare internal=""
+declare -r start=$(date +%s)
+declare -r end=$((${start}+${timeout}))
+while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
+ echo -n "." >&2
+ if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
+ success=$((${success}+1))
+ elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
+ success=$((${success}+1))
+ internal="--internal-ip"
+ fi
+done
+
+if [[ "${success}" -eq "0" ]]; then
+ echo "connect timed out after ${timeout} seconds." >&2
+ exit 1
+else
+ echo "done." >&2
+fi
+
+# Run the install scripts provided.
+for arg; do
+ (set -x; gcloud compute ssh ${internal} \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ sudo bash - <"${arg}" >/dev/null)
+done
+
+# Stop the instance; required before creating an image.
+(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null)
+
+# Create a snapshot of the instance disk.
+(set -x; gcloud compute disks snapshot \
+ --quiet \
+ --zone "${ZONE}" \
+ --snapshot-names="${SNAPSHOT_NAME}" \
+ "${INSTANCE_NAME}" >/dev/null)
+
+# Create the disk image.
+(set -x; gcloud compute images create \
+ --quiet \
+ --source-snapshot="${SNAPSHOT_NAME}" \
+ --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
+ "${IMAGE_NAME}" >/dev/null)
+
+# Finish up.
+echo "${IMAGE_NAME}"
diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl
new file mode 100644
index 000000000..0f67cfa92
--- /dev/null
+++ b/tools/vm/defs.bzl
@@ -0,0 +1,201 @@
+"""Image configuration. See README.md."""
+
+load("//tools:defs.bzl", "default_installer")
+
+# vm_image_builder is a rule that will construct a shell script that actually
+# generates a given VM image. Note that this does not _run_ the shell script
+# (although it can be run manually). The script is instead run during
+# generation of the vm_image target itself. This level of indirection is used
+# so that the build system only runs the builder once when multiple targets
+# depend on it, avoiding a set of races and conflicts.
+def _vm_image_builder_impl(ctx):
+ # Generate a binary that actually builds the image.
+ builder = ctx.actions.declare_file(ctx.label.name)
+ script_paths = []
+ for script in ctx.files.scripts:
+ script_paths.append(script.short_path)
+ builder_content = "\n".join([
+ "#!/bin/bash",
+ "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
+ "export USERNAME=%s" % ctx.attr.username,
+ "export IMAGE_PROJECT=%s" % ctx.attr.project,
+ "export IMAGE_FAMILY=%s" % ctx.attr.family,
+ "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)),
+ "",
+ ])
+ ctx.actions.write(builder, builder_content, is_executable = True)
+
+    # Note that the scripts should only be files, and should not include any
+    # indirect transitive dependencies; otherwise, the build script would not
+    # work.
+ return [DefaultInfo(
+ executable = builder,
+ runfiles = ctx.runfiles(
+ files = ctx.files.scripts + ctx.files._builder + ctx.files.zone,
+ ),
+ )]
+
+vm_image_builder = rule(
+ attrs = {
+ "_builder": attr.label(
+ executable = True,
+ default = "//tools/vm:builder",
+ cfg = "host",
+ ),
+ "username": attr.string(default = "$(whoami)"),
+ "zone": attr.label(
+ executable = True,
+ default = "//tools/vm:zone",
+ cfg = "host",
+ ),
+ "family": attr.string(mandatory = True),
+ "project": attr.string(mandatory = True),
+ "scripts": attr.label_list(allow_files = True),
+ },
+ executable = True,
+ implementation = _vm_image_builder_impl,
+)
+
+# See vm_image_builder above.
+def _vm_image_impl(ctx):
+ # Run the builder to generate our output.
+ echo = ctx.actions.declare_file(ctx.label.name)
+ resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
+ command = "echo -ne \"#!/bin/bash\\nset -e\\nimage=$(%s)\\necho ${image}\\n\" > %s && chmod 0755 %s" % (
+ ctx.files.builder[0].path,
+ echo.path,
+ echo.path,
+ ),
+ tools = [ctx.attr.builder],
+ )
+ ctx.actions.run_shell(
+ tools = resolved_inputs,
+ outputs = [echo],
+ progress_message = "Building image...",
+ execution_requirements = {"local": "true"},
+ command = argv,
+ input_manifests = runfiles_manifests,
+ )
+
+ # Return just the echo command. All of the builder runfiles have been
+ # resolved and consumed in the generation of the trivial echo script.
+ return [DefaultInfo(executable = echo)]
+
+_vm_image_test = rule(
+ attrs = {
+ "builder": attr.label(
+ executable = True,
+ cfg = "host",
+ ),
+ },
+ test = True,
+ implementation = _vm_image_impl,
+)
+
+def vm_image(name, **kwargs):
+ vm_image_builder(
+ name = name + "_builder",
+ **kwargs
+ )
+ _vm_image_test(
+ name = name,
+ builder = ":" + name + "_builder",
+ tags = [
+ "local",
+ "manual",
+ ],
+ )
+
+def _vm_test_impl(ctx):
+ runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
+
+ # Note that the remote execution case must actually generate an
+ # intermediate target in order to collect all the relevant runfiles so that
+ # they can be copied over for remote execution.
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
+ "export USERNAME=%s" % ctx.attr.username,
+ "export IMAGE=$(%s)" % ctx.files.image[0].short_path,
+        "export SUDO=%s" % ("true" if ctx.attr.sudo else "false"),
+        "export MACHINE=%s" % ctx.attr.machine,
+ "%s %s" % (
+ ctx.executable.executer.short_path,
+ " ".join([
+ target.files_to_run.executable.short_path
+ for target in ctx.attr.targets
+ ]),
+ ),
+ "",
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return with all transitive files.
+ runfiles = ctx.runfiles(
+ transitive_files = depset(transitive = [
+ depset(target.data_runfiles.files)
+ for target in ctx.attr.targets
+ if hasattr(target, "data_runfiles")
+ ]),
+ files = ctx.files.executer + ctx.files.zone + ctx.files.image +
+ ctx.files.targets,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_vm_test = rule(
+ attrs = {
+ "image": attr.label(
+ executable = True,
+ default = "//tools/vm:ubuntu1804",
+ cfg = "host",
+ ),
+ "executer": attr.label(
+ executable = True,
+ default = "//tools/vm:executer",
+ cfg = "host",
+ ),
+ "username": attr.string(default = "$(whoami)"),
+ "zone": attr.label(
+ executable = True,
+ default = "//tools/vm:zone",
+ cfg = "host",
+ ),
+ "sudo": attr.bool(default = True),
+ "machine": attr.string(default = "n1-standard-1"),
+ "targets": attr.label_list(
+ mandatory = True,
+ allow_empty = False,
+ cfg = "target",
+ ),
+ },
+ test = True,
+ implementation = _vm_test_impl,
+)
+
+def vm_test(
+ installers = None,
+ **kwargs):
+ """Runs the given targets as a remote test.
+
+ Args:
+        installers: Scripts to run before all targets.
+ **kwargs: All test arguments. Should include targets and image.
+ """
+ targets = kwargs.pop("targets", [])
+ if installers == None:
+ installers = [
+ "//tools/installers:head",
+ "//tools/installers:images",
+ ]
+ targets = installers + targets
+ if default_installer():
+ targets = [default_installer()] + targets
+ _vm_test(
+ tags = [
+ "local",
+ "manual",
+ ],
+ targets = targets,
+ local = 1,
+ **kwargs
+ )
diff --git a/tools/vm/execute.sh b/tools/vm/execute.sh
new file mode 100755
index 000000000..1f1f3ce01
--- /dev/null
+++ b/tools/vm/execute.sh
@@ -0,0 +1,160 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Required input.
+if ! [[ -v IMAGE ]]; then
+ echo "no image provided: set IMAGE."
+ exit 1
+fi
+
+# Parameters.
+declare -r USERNAME=${USERNAME:-test}
+declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX)
+declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX)
+declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z)
+declare -r MACHINE=${MACHINE:-n1-standard-1}
+declare -r ZONE=${ZONE:-us-central1-f}
+declare -r SUDO=${SUDO:-false}
+
+# Standard arguments (applies only on script execution).
+declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
+
+# This script is executed as a test rule, which will reset the value of HOME.
+# Unfortunately, HOME is needed to load the gcloud credentials. We will reset
+# HOME when we actually execute in the remote environment, defined below.
+export HOME=$(eval echo ~$(whoami))
+
+# Generate unique keys for this test.
+[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}"
+cat > "${SSHKEYS}" <<EOF
+${USERNAME}:$(cat ${KEYNAME}.pub)
+EOF
+
+# Start a unique instance. This means that we first generate a unique set of ssh
+# keys to ensure that only we have access to this instance. Note that we must
+# constrain ourselves to Haswell or greater in order to have nested
+# virtualization available.
+gcloud compute instances create \
+ --min-cpu-platform "Intel Haswell" \
+ --preemptible \
+ --no-scopes \
+ --metadata block-project-ssh-keys=TRUE \
+ --metadata-from-file ssh-keys="${SSHKEYS}" \
+ --machine-type "${MACHINE}" \
+ --image "${IMAGE}" \
+ --zone "${ZONE}" \
+ "${INSTANCE_NAME}"
+function cleanup {
+ gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}"
+}
+trap cleanup EXIT
+
+# Wait for the instance to become available (up to 5 minutes).
+declare timeout=300
+declare success=0
+declare -r start=$(date +%s)
+declare -r end=$((${start}+${timeout}))
+while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
+ if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
+ success=$((${success}+1))
+ fi
+done
+if [[ "${success}" -eq "0" ]]; then
+ echo "connect timed out after ${timeout} seconds."
+ exit 1
+fi
+
+# Copy the local directory over.
+tar czf - --dereference --exclude=.git . |
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ tar xzf -
+
+# Execute the command remotely.
+for cmd; do
+ # Setup relevant environment.
+ #
+ # N.B. This is not a complete test environment, but is complete enough to
+ # provide rudimentary sharding and test output support.
+ declare -a PREFIX=( "env" )
+ if [[ -v TEST_SHARD_INDEX ]]; then
+ PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" )
+ fi
+ if [[ -v TEST_SHARD_STATUS_FILE ]]; then
+ SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX)
+ PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" )
+ fi
+ if [[ -v TEST_TOTAL_SHARDS ]]; then
+ PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" )
+ fi
+ if [[ -v TEST_TMPDIR ]]; then
+ REMOTE_TMPDIR=$(mktemp -u test-XXXXXX)
+ PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" )
+ # Create remotely.
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ mkdir -p "/tmp/${REMOTE_TMPDIR}"
+ fi
+ if [[ -v XML_OUTPUT_FILE ]]; then
+ TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX)
+ PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" )
+ fi
+ if [[ "${SUDO}" == "true" ]]; then
+ PREFIX+=( "sudo" "-E" )
+ fi
+
+ # Execute the command.
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ "${PREFIX[@]}" "${cmd}"
+
+ # Collect relevant results.
+ if [[ -v TEST_SHARD_STATUS_FILE ]]; then
+ gcloud compute scp \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \
+ "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail.
+ fi
+ if [[ -v XML_OUTPUT_FILE ]]; then
+ gcloud compute scp \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \
+ "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail.
+ fi
+
+ # Clean up the temporary directory.
+ if [[ -v TEST_TMPDIR ]]; then
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ rm -rf "/tmp/${REMOTE_TMPDIR}"
+ fi
+done
diff --git a/tools/vm/test.cc b/tools/vm/test.cc
new file mode 100644
index 000000000..c0ceacda1
--- /dev/null
+++ b/tools/vm/test.cc
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+
+namespace {
+
+TEST(Image, Sanity0) {
+ // Do nothing (in shard 0).
+}
+
+TEST(Image, Sanity1) {
+ // Do nothing (in shard 1).
+}
+
+} // namespace
diff --git a/tools/vm/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh
new file mode 100755
index 000000000..629f7cf7a
--- /dev/null
+++ b/tools/vm/ubuntu1604/10_core.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Install all essential build tools.
+while true; do
+ if (apt-get update && apt-get install -y \
+ make \
+ git-core \
+ build-essential \
+ linux-headers-$(uname -r) \
+ pkg-config); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Install a recent go toolchain.
+if ! [[ -d /usr/local/go ]]; then
+ wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz
+ tar -xvf go1.13.5.linux-amd64.tar.gz
+ mv go /usr/local
+fi
+
+# Link the Go binary into /usr/bin, replacing anything there.
+(cd /usr/bin && rm -f go && ln -fs /usr/local/go/bin/go go)
diff --git a/tools/vm/ubuntu1604/15_gcloud.sh b/tools/vm/ubuntu1604/15_gcloud.sh
new file mode 100755
index 000000000..bc2e5eccc
--- /dev/null
+++ b/tools/vm/ubuntu1604/15_gcloud.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Install dependencies needed to add the gcloud repository.
+while true; do
+ if (apt-get update && apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ gnupg); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Add gcloud repositories.
+echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
+ tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
+
+# Add the appropriate key.
+curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
+ apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
+
+# Install the gcloud SDK.
+while true; do
+ if (apt-get update && apt-get install -y google-cloud-sdk); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
diff --git a/tools/vm/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh
new file mode 100755
index 000000000..bb7afa676
--- /dev/null
+++ b/tools/vm/ubuntu1604/20_bazel.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+declare -r BAZEL_VERSION=2.0.0
+
+# Install bazel dependencies.
+while true; do
+ if (apt-get update && apt-get install -y \
+ openjdk-8-jdk-headless \
+ unzip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Use the release installer.
+curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/tools/vm/ubuntu1604/25_docker.sh b/tools/vm/ubuntu1604/25_docker.sh
new file mode 100755
index 000000000..53d8ca588
--- /dev/null
+++ b/tools/vm/ubuntu1604/25_docker.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Add dependencies.
+while true; do
+ if (apt-get update && apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Install the key.
+curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
+
+# Add the repository.
+add-apt-repository \
+ "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
+ $(lsb_release -cs) \
+ stable"
+
+# Install docker.
+while true; do
+ if (apt-get update && apt-get install -y \
+ docker-ce \
+ docker-ce-cli \
+ containerd.io); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Enable Docker IPv6.
+cat > /etc/docker/daemon.json <<EOF
+{
+ "fixed-cidr-v6": "2001:db8:1::/64",
+ "ipv6": true
+}
+EOF
+# Docker's IPv6 support is lacking and does not work the same way as IPv4. We
+# can use NAT so containers can reach the outside world.
+ip6tables -t nat -A POSTROUTING -s 2001:db8:1::/64 ! -o docker0 -j MASQUERADE
diff --git a/tools/vm/ubuntu1604/30_containerd.sh b/tools/vm/ubuntu1604/30_containerd.sh
new file mode 100755
index 000000000..fb3699c12
--- /dev/null
+++ b/tools/vm/ubuntu1604/30_containerd.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Helper for Go packages below.
+install_helper() {
+ PACKAGE="${1}"
+ TAG="${2}"
+ GOPATH="${3}"
+
+ # Clone the repository.
+ mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
+ git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
+
+ # Checkout and build the repository.
+ (cd "${GOPATH}"/src/"${PACKAGE}" && \
+ git checkout "${TAG}" && \
+ GOPATH="${GOPATH}" make && \
+ GOPATH="${GOPATH}" make install)
+}
+
+# Install dependencies for the crictl tests.
+while true; do
+ if (apt-get update && apt-get install -y \
+ btrfs-tools \
+ libseccomp-dev); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Install containerd & cri-tools.
+GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
+install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}"
+install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}"
+
+# Install gvisor-containerd-shim.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+wget --no-verbose "${base}"/latest -O ${latest}
+wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
+chmod +x ${shim_path}
+mv ${shim_path} /usr/local/bin
+
+# Configure containerd-shim.
+declare -r shim_config_path=/etc/containerd
+declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml)
+mkdir -p ${shim_config_path}
+cat > ${shim_config_tmp_path} <<-EOF
+ runc_shim = "/usr/local/bin/containerd-shim"
+
+[runsc_config]
+ debug = "true"
+ debug-log = "/tmp/runsc-logs/"
+ strace = "true"
+ file-access = "shared"
+EOF
+mv ${shim_config_tmp_path} ${shim_config_path}
+
+# Configure CNI.
+(cd "${GOPATH}" && GOPATH="${GOPATH}" \
+ src/github.com/containerd/containerd/script/setup/install-cni)
+
+# Cleanup the above.
+rm -rf "${GOPATH}"
+rm -rf "${latest}"
+rm -rf "${shim_path}"
+rm -rf "${shim_config_tmp_path}"
diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh
new file mode 100755
index 000000000..2974f156c
--- /dev/null
+++ b/tools/vm/ubuntu1604/40_kokoro.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Declare kokoro's required public keys.
+declare -r ssh_public_keys=(
+ "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com"
+ "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com"
+)
+
+# Install dependencies.
+while true; do
+ if (apt-get update && apt-get install -y \
+ rsync \
+ coreutils \
+ python-psutil \
+ qemu-kvm \
+ python-pip \
+ python3-pip \
+ zip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# junitparser is used to merge junit xml files.
+pip install junitparser
+
+# We need a kbuilder user, which may already exist.
+useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true
+
+# We need to provision appropriate keys.
+mkdir -p ~kbuilder/.ssh
+(IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
+chmod 0600 ~kbuilder/.ssh/authorized_keys
+chown -R kbuilder ~kbuilder/.ssh
+
+# Give passwordless sudo access.
+cat > /etc/sudoers.d/kokoro <<EOF
+kbuilder ALL=(ALL) NOPASSWD:ALL
+EOF
+
+# Ensure we can run Docker without sudo.
+usermod -aG docker kbuilder
+
+# Ensure that we can access kvm.
+usermod -aG kvm kbuilder
+
+# Ensure that /tmpfs exists and is writable by kokoro.
+#
+# Note that kokoro will typically attach a second disk (sdb) to the instance
+# that is used for the /tmpfs volume. In the future we could set up an init
+# script that formats and mounts this here; however, we don't expect our build
+# artifacts to be that large.
+mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY
diff --git a/tools/vm/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD
new file mode 100644
index 000000000..ab1df0c4c
--- /dev/null
+++ b/tools/vm/ubuntu1604/BUILD
@@ -0,0 +1,7 @@
+package(licenses = ["notice"])
+
+filegroup(
+ name = "ubuntu1604",
+ srcs = glob(["*.sh"]),
+ visibility = ["//:sandbox"],
+)
diff --git a/tools/vm/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD
new file mode 100644
index 000000000..0c8856dde
--- /dev/null
+++ b/tools/vm/ubuntu1804/BUILD
@@ -0,0 +1,7 @@
+package(licenses = ["notice"])
+
+alias(
+ name = "ubuntu1804",
+ actual = "//tools/vm/ubuntu1604",
+ visibility = ["//:sandbox"],
+)
diff --git a/tools/vm/zone.sh b/tools/vm/zone.sh
new file mode 100755
index 000000000..79569fb19
--- /dev/null
+++ b/tools/vm/zone.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exec gcloud config get-value compute/zone
diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh
new file mode 100755
index 000000000..a22c8c9f2
--- /dev/null
+++ b/tools/workspace_status.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The STABLE_ prefix will trigger a re-link if it changes.
+echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0)
diff --git a/vdso/BUILD b/vdso/BUILD
new file mode 100644
index 000000000..c70bb8218
--- /dev/null
+++ b/vdso/BUILD
@@ -0,0 +1,81 @@
+# Description:
+# This VDSO is a shared library that provides the same interfaces as the
+# normal system VDSO (time, gettimeofday, clock_gettime) but which uses
+# timekeeping parameters managed by the sandbox kernel.
+
+load("//tools:defs.bzl", "cc_flags_supplier", "cc_toolchain", "select_arch", "vdso_linker_option")
+
+package(licenses = ["notice"])
+
+genrule(
+ name = "vdso",
+ srcs = [
+ "barrier.h",
+ "compiler.h",
+ "cycle_clock.h",
+ "seqlock.h",
+ "syscalls.h",
+ "vdso.cc",
+ "vdso_amd64.lds",
+ "vdso_arm64.lds",
+ "vdso_time.h",
+ "vdso_time.cc",
+ ],
+ outs = [
+ "vdso.so",
+ ],
+ cmd = "$(CC) $(CC_FLAGS) " +
+ "-I. " +
+ "-O2 " +
+ "-std=c++11 " +
+ "-fPIC " +
+ "-fno-sanitize=all " +
+          # Some toolchains enable the stack protector by default. Disable it;
+          # the VDSO has no hooks to handle failures.
+ "-fno-stack-protector " +
+ vdso_linker_option +
+ select_arch(
+ amd64 = "-m64 ",
+ arm64 = "",
+ ) +
+ "-shared " +
+ "-nostdlib " +
+ "-Wl,-soname=linux-vdso.so.1 " +
+ "-Wl,--hash-style=sysv " +
+ "-Wl,--no-undefined " +
+ "-Wl,-Bsymbolic " +
+ "-Wl,-z,max-page-size=4096 " +
+ "-Wl,-z,common-page-size=4096 " +
+ select_arch(
+ amd64 = "-Wl,-T$(location vdso_amd64.lds) ",
+ arm64 = "-Wl,-T$(location vdso_arm64.lds) ",
+ no_match_error = "unsupported architecture",
+ ) +
+ "-o $(location vdso.so) " +
+ "$(location vdso.cc) " +
+ "$(location vdso_time.cc) " +
+ "&& $(location :check_vdso) " +
+ "--check-data " +
+ "--vdso $(location vdso.so) ",
+ exec_tools = [
+ ":check_vdso",
+ ],
+ features = ["-pie"],
+ toolchains = [
+ cc_toolchain,
+ ":no_pie_cc_flags",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+cc_flags_supplier(
+ name = "no_pie_cc_flags",
+ features = ["-pie"],
+)
+
+py_binary(
+ name = "check_vdso",
+ srcs = ["check_vdso.py"],
+ python_version = "PY3",
+ visibility = ["//:sandbox"],
+)
diff --git a/vdso/barrier.h b/vdso/barrier.h
new file mode 100644
index 000000000..edba4afb5
--- /dev/null
+++ b/vdso/barrier.h
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VDSO_BARRIER_H_
+#define VDSO_BARRIER_H_
+
+namespace vdso {
+
+// Compiler Optimization barrier.
+inline void barrier(void) { __asm__ __volatile__("" ::: "memory"); }
+
+#if __x86_64__
+
+inline void memory_barrier(void) {
+ __asm__ __volatile__("mfence" ::: "memory");
+}
+inline void read_barrier(void) { barrier(); }
+inline void write_barrier(void) { barrier(); }
+
+#elif __aarch64__
+
+inline void memory_barrier(void) {
+ __asm__ __volatile__("dmb ish" ::: "memory");
+}
+inline void read_barrier(void) {
+ __asm__ __volatile__("dmb ishld" ::: "memory");
+}
+inline void write_barrier(void) {
+ __asm__ __volatile__("dmb ishst" ::: "memory");
+}
+
+#else
+#error "unsupported architecture"
+#endif
+
+} // namespace vdso
+
+#endif // VDSO_BARRIER_H_
diff --git a/vdso/check_vdso.py b/vdso/check_vdso.py
new file mode 100644
index 000000000..b3ee574f3
--- /dev/null
+++ b/vdso/check_vdso.py
@@ -0,0 +1,204 @@
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Verify VDSO ELF does not contain any relocations and is directly mmappable.
+"""
+
+import argparse
+import logging
+import re
+import subprocess
+
+PAGE_SIZE = 4096
+
+
+def PageRoundDown(addr):
+ """Rounds down to the nearest page.
+
+ Args:
+ addr: An address.
+
+ Returns:
+ The address rounded down to the nearest page.
+ """
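+  # For example, PageRoundDown(0x1234) == 0x1000 (with PAGE_SIZE == 4096).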
+ return addr & ~(PAGE_SIZE - 1)
+
+
+def Fatal(*args, **kwargs):
+ """Logs a critical message and exits with code 1.
+
+ Args:
+ *args: Args to pass to logging.critical.
+ **kwargs: Keyword args to pass to logging.critical.
+ """
+ logging.critical(*args, **kwargs)
+ exit(1)
+
+
+def CheckSegments(vdso_path):
+ """Verifies layout of PT_LOAD segments.
+
+ PT_LOAD segments must be laid out such that the ELF is directly mmappable.
+
+ Specifically, check that:
+ * PT_LOAD file offsets are equivalent to the memory offset from the first
+ segment.
+ * No extra zeroed space (memsz) is required.
+ * PT_LOAD segments are in order (required for any ELF).
+ * No two PT_LOAD segments share part of the same page.
+
+ The readelf line format looks like:
+ Type Offset VirtAddr PhysAddr FileSiz MemSiz Flg Align
+ LOAD 0x000000 0xffffffffff700000 0xffffffffff700000 0x000e68 0x000e68 R E 0x1000
+
+ Args:
+ vdso_path: Path to VDSO binary.
+ """
+ output = subprocess.check_output(["readelf", "-lW", vdso_path]).decode()
+ lines = output.split("\n")
+
+ segments = []
+ for line in lines:
+ if not line.startswith(" LOAD"):
+ continue
+
+ components = line.split()
+
+ segments.append({
+ "offset": int(components[1], 16),
+ "addr": int(components[2], 16),
+ "filesz": int(components[4], 16),
+ "memsz": int(components[5], 16),
+ })
+
+ if not segments:
+ Fatal("No PT_LOAD segments in VDSO")
+
+ first = segments[0]
+ if first["offset"] != 0:
+ Fatal("First PT_LOAD segment has non-zero file offset: %s", first)
+
+ for i, segment in enumerate(segments):
+ memoff = segment["addr"] - first["addr"]
+ if memoff != segment["offset"]:
+ Fatal("PT_LOAD segment has different memory and file offsets: %s",
+            segment)
+
+ if segment["memsz"] != segment["filesz"]:
+ Fatal("PT_LOAD segment memsz != filesz: %s", segment)
+
+ if i > 0:
+ last_end = segments[i-1]["addr"] + segments[i-1]["memsz"]
+ if segment["addr"] < last_end:
+ Fatal("PT_LOAD segments out of order")
+
+ last_page = PageRoundDown(last_end)
+ start_page = PageRoundDown(segment["addr"])
+ if last_page >= start_page:
+ Fatal("PT_LOAD segments share a page: %s and %s", segment,
+ segments[i - 1])
+
+
+# Matches the section name in readelf -SW output.
+_SECTION_NAME_RE = re.compile(r"""^\s+\[\ ?\d+\]\s+
+ (?P<name>\.\S+)\s+
+ (?P<type>\S+)\s+
+ (?P<addr>[0-9a-f]+)\s+
+ (?P<off>[0-9a-f]+)\s+
+ (?P<size>[0-9a-f]+)""", re.VERBOSE)
+
+
+def CheckData(vdso_path):
+ """Verifies the VDSO contains no .data or .bss sections.
+
+ The readelf line format looks like:
+
+ There are 15 section headers, starting at offset 0x15f0:
+
+ Section Headers:
+ [Nr] Name Type Address Off Size ES Flg Lk Inf Al
+ [ 0] NULL 0000000000000000 000000 000000 00 0 0 0
+ [ 1] .hash HASH ffffffffff700120 000120 000040 04 A 2 0 8
+ [ 2] .dynsym DYNSYM ffffffffff700160 000160 000108 18 A 3 1 8
+ ...
+ [13] .strtab STRTAB 0000000000000000 001448 000123 00 0 0 1
+ [14] .shstrtab STRTAB 0000000000000000 00156b 000083 00 0 0 1
+ Key to Flags:
+ W (write), A (alloc), X (execute), M (merge), S (strings), I (info),
+ L (link order), O (extra OS processing required), G (group), T (TLS),
+ C (compressed), x (unknown), o (OS specific), E (exclude),
+ l (large), p (processor specific)
+
+ Args:
+ vdso_path: Path to VDSO binary.
+ """
+ output = subprocess.check_output(["readelf", "-SW", vdso_path]).decode()
+ lines = output.split("\n")
+
+ found_text = False
+ for line in lines:
+ m = re.search(_SECTION_NAME_RE, line)
+ if not m:
+ continue
+
+ if not line.startswith(" ["):
+ continue
+
+ name = m.group("name")
+ size = int(m.group("size"), 16)
+
+ if name == ".text" and size != 0:
+ found_text = True
+
+ # Clang will typically omit these sections entirely; gcc will include them
+ # but with size 0.
+ if name.startswith(".data") and size != 0:
+ Fatal("VDSO contains non-empty .data section:\n%s" % output)
+
+ if name.startswith(".bss") and size != 0:
+ Fatal("VDSO contains non-empty .bss section:\n%s" % output)
+
+ if not found_text:
+ Fatal("VDSO contains no/empty .text section? Bad parsing?:\n%s" % output)
+
+
+def CheckRelocs(vdso_path):
+ """Verifies that the VDSO includes no relocations.
+
+ Args:
+ vdso_path: Path to VDSO binary.
+ """
+ output = subprocess.check_output(["readelf", "-r", vdso_path]).decode()
+ if output.strip() != "There are no relocations in this file.":
+ Fatal("VDSO contains relocations: %s", output)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Verify VDSO ELF.")
+ parser.add_argument("--vdso", required=True, help="Path to VDSO ELF")
+ parser.add_argument(
+ "--check-data",
+ action="store_true",
+ help="Check that the ELF contains no .data or .bss sections")
+ args = parser.parse_args()
+
+ CheckSegments(args.vdso)
+ CheckRelocs(args.vdso)
+
+ if args.check_data:
+ CheckData(args.vdso)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vdso/compiler.h b/vdso/compiler.h
new file mode 100644
index 000000000..54a510000
--- /dev/null
+++ b/vdso/compiler.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VDSO_COMPILER_H_
+#define VDSO_COMPILER_H_
+
+#define likely(x) __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+
+#ifndef __section
+#define __section(S) __attribute__((__section__(#S)))
+#endif
+
+#ifndef __aligned
+#define __aligned(N) __attribute__((__aligned__(N)))
+#endif
+
+#endif // VDSO_COMPILER_H_
diff --git a/vdso/cycle_clock.h b/vdso/cycle_clock.h
new file mode 100644
index 000000000..5d3fbb257
--- /dev/null
+++ b/vdso/cycle_clock.h
@@ -0,0 +1,51 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VDSO_CYCLE_CLOCK_H_
+#define VDSO_CYCLE_CLOCK_H_
+
+#include <stdint.h>
+
+#include "vdso/barrier.h"
+
+namespace vdso {
+
+#if __x86_64__
+
+// TODO(b/74613497): The appropriate barrier instruction to use with rdtsc on
+// x86_64 depends on the vendor. Intel processors can use lfence but AMD may
+// need mfence, depending on MSR_F10H_DECFG_LFENCE_SERIALIZE_BIT.
+
+static inline uint64_t cycle_clock(void) {
+ uint32_t lo, hi;
+ asm volatile("lfence" : : : "memory");
+ asm volatile("rdtsc" : "=a"(lo), "=d"(hi));
+ return ((uint64_t)hi << 32) | lo;
+}
+
+#elif __aarch64__
+
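+// CNTVCT_EL0 is the generic timer's virtual counter, readable from EL0.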
+static inline uint64_t cycle_clock(void) {
+ uint64_t val;
+ asm volatile("mrs %0, CNTVCT_EL0" : "=r"(val)::"memory");
+ return val;
+}
+
+#else
+#error "unsupported architecture"
+#endif
+
+} // namespace vdso
+
+#endif // VDSO_CYCLE_CLOCK_H_
diff --git a/vdso/seqlock.h b/vdso/seqlock.h
new file mode 100644
index 000000000..7a173174b
--- /dev/null
+++ b/vdso/seqlock.h
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Low level raw interfaces to the sequence counter used by the VDSO.
+#ifndef VDSO_SEQLOCK_H_
+#define VDSO_SEQLOCK_H_
+
+#include <stdint.h>
+
+#include "vdso/barrier.h"
+#include "vdso/compiler.h"
+
+namespace vdso {
+
+inline uint64_t read_seqcount_begin(const uint64_t* s) {
+  uint64_t seq = *s;
+  read_barrier();
+  return seq & ~1ULL;
+}
+
+inline int read_seqcount_retry(const uint64_t* s, uint64_t seq) {
+ read_barrier();
+ return unlikely(*s != seq);
+}
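+
+// Typical reader usage (a sketch; vdso_time.cc is the real consumer):
+//
+//   uint64_t seq;
+//   do {
+//     seq = read_seqcount_begin(&counter);
+//     // ... copy values guarded by the counter ...
+//   } while (read_seqcount_retry(&counter, seq));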
+
+} // namespace vdso
+
+#endif // VDSO_SEQLOCK_H_
diff --git a/vdso/syscalls.h b/vdso/syscalls.h
new file mode 100644
index 000000000..0c6a922a0
--- /dev/null
+++ b/vdso/syscalls.h
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// System call support for the VDSO.
+//
+// Provides fallback system call interfaces for getcpu(), clock_gettime(),
+// clock_getres(), and rt_sigreturn().
+
+#ifndef VDSO_SYSCALLS_H_
+#define VDSO_SYSCALLS_H_
+
+#include <asm/unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <stddef.h>
+#include <sys/types.h>
+
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
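+// The two-level expansion stringifies the *expanded* argument: e.g. on
+// x86_64, __stringify(__NR_rt_sigreturn) yields "15", not "__NR_rt_sigreturn".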
+
+namespace vdso {
+
+#if __x86_64__
+
+struct getcpu_cache;
+
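+// Note: the syscall instruction itself clobbers %rcx and %r11, hence the
+// clobber lists below.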
+static inline int sys_clock_gettime(clockid_t clock, struct timespec* ts) {
+ int num = __NR_clock_gettime;
+ asm volatile("syscall\n"
+ : "+a"(num)
+ : "D"(clock), "S"(ts)
+ : "rcx", "r11", "memory");
+ return num;
+}
+
+static inline int sys_getcpu(unsigned* cpu, unsigned* node,
+ struct getcpu_cache* cache) {
+ int num = __NR_getcpu;
+ asm volatile("syscall\n"
+ : "+a"(num)
+ : "D"(cpu), "S"(node), "d"(cache)
+ : "rcx", "r11", "memory");
+ return num;
+}
+
+static inline void sys_rt_sigreturn(void) {
+ asm volatile("movl $" __stringify(__NR_rt_sigreturn)", %eax \n"
+ "syscall \n");
+}
+
+#elif __aarch64__
+
+static inline int sys_clock_gettime(clockid_t _clkid, struct timespec* _ts) {
+ register struct timespec* ts asm("x1") = _ts;
+ register clockid_t clkid asm("x0") = _clkid;
+ register long ret asm("x0");
+ register long nr asm("x8") = __NR_clock_gettime;
+
+ asm volatile("svc #0\n"
+ : "=r"(ret)
+ : "r"(clkid), "r"(ts), "r"(nr)
+ : "memory");
+ return ret;
+}
+
+static inline int sys_clock_getres(clockid_t _clkid, struct timespec* _ts) {
+ register struct timespec* ts asm("x1") = _ts;
+ register clockid_t clkid asm("x0") = _clkid;
+ register long ret asm("x0");
+ register long nr asm("x8") = __NR_clock_getres;
+
+ asm volatile("svc #0\n"
+ : "=r"(ret)
+ : "r"(clkid), "r"(ts), "r"(nr)
+ : "memory");
+ return ret;
+}
+
+static inline void sys_rt_sigreturn(void) {
+ asm volatile("mov x8, #" __stringify(__NR_rt_sigreturn)" \n"
+ "svc #0 \n");
+}
+
+#else
+#error "unsupported architecture"
+#endif
+} // namespace vdso
+
+#endif // VDSO_SYSCALLS_H_
diff --git a/vdso/vdso.cc b/vdso/vdso.cc
new file mode 100644
index 000000000..3b6653b5d
--- /dev/null
+++ b/vdso/vdso.cc
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This is the VDSO for sandboxed binaries. This file just contains the entry
+// points to the VDSO. All of the real work is done in vdso_time.cc.
+
+#define _DEFAULT_SOURCE // ensure glibc provides struct timezone.
+#include <sys/time.h>
+#include <time.h>
+
+#include "vdso/syscalls.h"
+#include "vdso/vdso_time.h"
+
+namespace vdso {
+namespace {
+
+int __common_clock_gettime(clockid_t clock, struct timespec* ts) {
+ int ret;
+
+ switch (clock) {
+ case CLOCK_REALTIME:
+ ret = ClockRealtime(ts);
+ break;
+
+ case CLOCK_BOOTTIME:
+      // Fall through: CLOCK_BOOTTIME is an alias for CLOCK_MONOTONIC.
+ case CLOCK_MONOTONIC:
+ ret = ClockMonotonic(ts);
+ break;
+
+ default:
+ ret = sys_clock_gettime(clock, ts);
+ break;
+ }
+
+ return ret;
+}
+
+int __common_gettimeofday(struct timeval* tv, struct timezone* tz) {
+ if (tv) {
+ struct timespec ts;
+ int ret = ClockRealtime(&ts);
+ if (ret) {
+ return ret;
+ }
+ tv->tv_sec = ts.tv_sec;
+ tv->tv_usec = ts.tv_nsec / 1000;
+ }
+
+  // Nobody should be calling gettimeofday() with a non-NULL timezone
+  // pointer; if they do, they will simply get zeros.
+ if (tz) {
+ tz->tz_minuteswest = 0;
+ tz->tz_dsttime = 0;
+ }
+
+ return 0;
+}
+} // namespace
+
+// __kernel_rt_sigreturn() implements rt_sigreturn()
+extern "C" void __kernel_rt_sigreturn(unsigned long unused) {
+ // No optimizations yet, just make the real system call.
+ sys_rt_sigreturn();
+}
+
+#if __x86_64__
+
+// __vdso_clock_gettime() implements clock_gettime()
+extern "C" int __vdso_clock_gettime(clockid_t clock, struct timespec* ts) {
+ return __common_clock_gettime(clock, ts);
+}
+extern "C" int clock_gettime(clockid_t clock, struct timespec* ts)
+ __attribute__((weak, alias("__vdso_clock_gettime")));
+
+// __vdso_gettimeofday() implements gettimeofday()
+extern "C" int __vdso_gettimeofday(struct timeval* tv, struct timezone* tz) {
+ return __common_gettimeofday(tv, tz);
+}
+extern "C" int gettimeofday(struct timeval* tv, struct timezone* tz)
+ __attribute__((weak, alias("__vdso_gettimeofday")));
+
+// __vdso_time() implements time()
+extern "C" time_t __vdso_time(time_t* t) {
+ struct timespec ts;
+ ClockRealtime(&ts);
+ if (t) {
+ *t = ts.tv_sec;
+ }
+ return ts.tv_sec;
+}
+extern "C" time_t time(time_t* t) __attribute__((weak, alias("__vdso_time")));
+
+// __vdso_getcpu() implements getcpu()
+extern "C" long __vdso_getcpu(unsigned* cpu, unsigned* node,
+ struct getcpu_cache* cache) {
+ // No optimizations yet, just make the real system call.
+ return sys_getcpu(cpu, node, cache);
+}
+extern "C" long getcpu(unsigned* cpu, unsigned* node,
+ struct getcpu_cache* cache)
+ __attribute__((weak, alias("__vdso_getcpu")));
+
+#elif __aarch64__
+
+// __kernel_clock_gettime() implements clock_gettime()
+extern "C" int __kernel_clock_gettime(clockid_t clock, struct timespec* ts) {
+ return __common_clock_gettime(clock, ts);
+}
+
+// __kernel_gettimeofday() implements gettimeofday()
+extern "C" int __kernel_gettimeofday(struct timeval* tv, struct timezone* tz) {
+ return __common_gettimeofday(tv, tz);
+}
+
+// __kernel_clock_getres() implements clock_getres()
+extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) {
+ int ret = 0;
+
+ switch (clock) {
+ case CLOCK_REALTIME:
+ case CLOCK_MONOTONIC:
+ case CLOCK_BOOTTIME: {
+ if (res == nullptr) {
+ return 0;
+ }
+
+ res->tv_sec = 0;
+ res->tv_nsec = 1;
+ break;
+ }
+
+ default:
+ ret = sys_clock_getres(clock, res);
+ break;
+ }
+
+ return ret;
+}
+
+#else
+#error "unsupported architecture"
+#endif
+} // namespace vdso
diff --git a/vdso/vdso_amd64.lds b/vdso/vdso_amd64.lds
new file mode 100644
index 000000000..d114290da
--- /dev/null
+++ b/vdso/vdso_amd64.lds
@@ -0,0 +1,102 @@
+/*
+ * Linker script for the VDSO.
+ *
+ * The VDSO is essentially a normal ELF shared library that is mapped into the
+ * address space of the process that is going to use it. The address of the
+ * VDSO is passed to the runtime linker in the AT_SYSINFO_EHDR entry of the aux
+ * vector.
+ *
+ * There are, however, three ways in which the VDSO differs from a normal
+ * shared library:
+ *
+ * - The runtime linker does not attempt to process any relocations for the
+ * VDSO so it is the responsibility of whoever loads the VDSO into the
+ * address space to do this if necessary. Because of this restriction we are
+ * careful to ensure that the VDSO does not need to have any relocations
+ * applied to it.
+ *
+ * - Although the VDSO is position independent and would normally be linked at
+ *   virtual address 0, the Linux kernel VDSO is actually linked at a non-zero
+ *   virtual address, and the code in the system runtime linker that handles
+ *   the VDSO expects this to be the case, so we have to explicitly link this
+ *   VDSO at a non-zero address. The actual address is arbitrary, but we use
+ *   the same one as the Linux kernel VDSO.
+ *
+ * - The VDSO will be directly mmapped by the sentry, rather than going through
+ * a normal ELF loading process. The VDSO must be carefully constructed such
+ * that the layout in the ELF file is identical to the layout in memory.
+ */
+
+VDSO_PRELINK = 0xffffffffff700000;
+
+SECTIONS {
+ /* The parameter page is mapped just before the VDSO. */
+ _params = VDSO_PRELINK - 0x1000;
+
+ . = VDSO_PRELINK + SIZEOF_HEADERS;
+
+ .hash : { *(.hash) } :text
+ .gnu.hash : { *(.gnu.hash) }
+ .dynsym : { *(.dynsym) }
+ .dynstr : { *(.dynstr) }
+ .gnu.version : { *(.gnu.version) }
+ .gnu.version_d : { *(.gnu.version_d) }
+ .gnu.version_r : { *(.gnu.version_r) }
+
+ .note : { *(.note.*) } :text :note
+
+ .eh_frame_hdr : { *(.eh_frame_hdr) } :text :eh_frame_hdr
+ .eh_frame : { KEEP (*(.eh_frame)) } :text
+
+ .dynamic : { *(.dynamic) } :text :dynamic
+
+ .rodata : { *(.rodata*) } :text
+
+ .altinstructions : { *(.altinstructions) }
+ .altinstr_replacement : { *(.altinstr_replacement) }
+
+ /*
+ * TODO(gvisor.dev/issue/157): Remove this alignment? Then the VDSO would fit
+ * in a single page.
+ */
+ . = ALIGN(0x1000);
+ .text : { *(.text*) } :text =0x90909090
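+  /* The fill pattern above pads any gap with x86-64 NOP (0x90) bytes. */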
+
+ /*
+ * N.B. There is no data/bss section. This VDSO neither needs nor uses a data
+ * section. We omit it entirely because some gcc/clang and gold/bfd version
+ * combinations struggle to handle an empty data PHDR segment (internal
+ * linker assertion failures result).
+ *
+ * If the VDSO does incorrectly include a data section, the linker will
+ * include it in the text segment. check_vdso.py looks for this degenerate
+ * case.
+ */
+}
+
+PHDRS {
+ text PT_LOAD FLAGS(5) FILEHDR PHDRS; /* PF_R | PF_X */
+ dynamic PT_DYNAMIC FLAGS(4); /* PF_R */
+ note PT_NOTE FLAGS(4); /* PF_R */
+ eh_frame_hdr PT_GNU_EH_FRAME;
+}
+
+/*
+ * Define the symbols that are to be exported.
+ */
+VERSION {
+ LINUX_2.6 {
+ global:
+ clock_gettime;
+ __vdso_clock_gettime;
+ gettimeofday;
+ __vdso_gettimeofday;
+ getcpu;
+ __vdso_getcpu;
+ time;
+ __vdso_time;
+ __kernel_rt_sigreturn;
+
+ local: *;
+ };
+}
diff --git a/vdso/vdso_arm64.lds b/vdso/vdso_arm64.lds
new file mode 100644
index 000000000..469185468
--- /dev/null
+++ b/vdso/vdso_arm64.lds
@@ -0,0 +1,99 @@
+/*
+ * Linker script for the VDSO.
+ *
+ * The VDSO is essentially a normal ELF shared library that is mapped into the
+ * address space of the process that is going to use it. The address of the
+ * VDSO is passed to the runtime linker in the AT_SYSINFO_EHDR entry of the aux
+ * vector.
+ *
+ * There are, however, three ways in which the VDSO differs from a normal
+ * shared library:
+ *
+ * - The runtime linker does not attempt to process any relocations for the
+ * VDSO so it is the responsibility of whoever loads the VDSO into the
+ * address space to do this if necessary. Because of this restriction we are
+ * careful to ensure that the VDSO does not need to have any relocations
+ * applied to it.
+ *
+ * - Although the VDSO is position independent and would normally be linked at
+ *   virtual address 0, the Linux kernel VDSO is actually linked at a non-zero
+ *   virtual address, and the code in the system runtime linker that handles
+ *   the VDSO expects this to be the case, so we have to explicitly link this
+ *   VDSO at a non-zero address. The actual address is arbitrary, but we use
+ *   the same one as the Linux kernel VDSO.
+ *
+ * - The VDSO will be directly mmapped by the sentry, rather than going through
+ * a normal ELF loading process. The VDSO must be carefully constructed such
+ * that the layout in the ELF file is identical to the layout in memory.
+ */
+
+VDSO_PRELINK = 0xffffffffff700000;
+
+OUTPUT_FORMAT("elf64-littleaarch64", "elf64-bigaarch64", "elf64-littleaarch64")
+OUTPUT_ARCH(aarch64)
+
+SECTIONS {
+ /* The parameter page is mapped just before the VDSO. */
+ _params = VDSO_PRELINK - 0x1000;
+
+ . = VDSO_PRELINK + SIZEOF_HEADERS;
+
+ .hash : { *(.hash) } :text
+ .gnu.hash : { *(.gnu.hash) }
+ .dynsym : { *(.dynsym) }
+ .dynstr : { *(.dynstr) }
+ .gnu.version : { *(.gnu.version) }
+ .gnu.version_d : { *(.gnu.version_d) }
+ .gnu.version_r : { *(.gnu.version_r) }
+
+ .note : { *(.note.*) } :text :note
+
+ .eh_frame_hdr : { *(.eh_frame_hdr) } :text :eh_frame_hdr
+ .eh_frame : { KEEP (*(.eh_frame)) } :text
+
+ .dynamic : { *(.dynamic) } :text :dynamic
+
+ .rodata : { *(.rodata*) } :text
+
+ .altinstructions : { *(.altinstructions) }
+ .altinstr_replacement : { *(.altinstr_replacement) }
+
+ /*
+ * TODO(gvisor.dev/issue/157): Remove this alignment? Then the VDSO would fit
+ * in a single page.
+ */
+ . = ALIGN(0x1000);
+ .text : { *(.text*) } :text =0xd503201f
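+  /* The fill pattern above pads any gap with AArch64 NOP (0xd503201f)
+   * instructions. */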
+
+ /*
+ * N.B. There is no data/bss section. This VDSO neither needs nor uses a data
+ * section. We omit it entirely because some gcc/clang and gold/bfd version
+ * combinations struggle to handle an empty data PHDR segment (internal
+ * linker assertion failures result).
+ *
+ * If the VDSO does incorrectly include a data section, the linker will
+ * include it in the text segment. check_vdso.py looks for this degenerate
+ * case.
+ */
+}
+
+PHDRS {
+ text PT_LOAD FLAGS(5) FILEHDR PHDRS; /* PF_R | PF_X */
+ dynamic PT_DYNAMIC FLAGS(4); /* PF_R */
+ note PT_NOTE FLAGS(4); /* PF_R */
+ eh_frame_hdr PT_GNU_EH_FRAME;
+}
+
+/*
+ * Define the symbols that are to be exported.
+ */
+VERSION {
+ LINUX_2.6.39 {
+ global:
+ __kernel_clock_getres;
+ __kernel_clock_gettime;
+ __kernel_gettimeofday;
+ __kernel_rt_sigreturn;
+ local: *;
+ };
+}
diff --git a/vdso/vdso_time.cc b/vdso/vdso_time.cc
new file mode 100644
index 000000000..1bb4bb86b
--- /dev/null
+++ b/vdso/vdso_time.cc
@@ -0,0 +1,159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vdso/vdso_time.h"
+
+#include <stdint.h>
+#include <sys/time.h>
+#include <time.h>
+
+#include "vdso/cycle_clock.h"
+#include "vdso/seqlock.h"
+#include "vdso/syscalls.h"
+
+// struct params defines the layout of the parameter page maintained by the
+// kernel (i.e., sentry).
+//
+// This is similar to the VVAR page maintained by the normal Linux kernel for
+// its VDSO, but it has a different layout.
+//
+// It must be kept in sync with VDSOParamPage in pkg/sentry/kernel/vdso.go.
+struct params {
+ uint64_t seq_count;
+
+ uint64_t monotonic_ready;
+ int64_t monotonic_base_cycles;
+ int64_t monotonic_base_ref;
+ uint64_t monotonic_frequency;
+
+ uint64_t realtime_ready;
+ int64_t realtime_base_cycles;
+ int64_t realtime_base_ref;
+ uint64_t realtime_frequency;
+};
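+
+// The seq_count field follows the usual seqlock convention (see seqlock.h):
+// the sentry increments it to an odd value before an update and back to an
+// even value afterwards, so readers retry around in-flight updates.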
+
+// Returns a pointer to the global parameter page.
+//
+// This page lives in the page just before the VDSO binary itself. The linker
+// defines _params as the page before the VDSO.
+//
+// Ideally, we'd simply declare _params as an extern struct params.
+// Unfortunately various combinations of old/new versions of gcc/clang and
+// gold/bfd struggle to generate references to such a global without generating
+// relocations.
+//
+// So instead, we use inline assembly with a construct that seems to have wide
+// compatibility across many toolchains.
+#if __x86_64__
+
+inline struct params* get_params() {
+ struct params* p = nullptr;
+ asm("leaq _params(%%rip), %0" : "=r"(p) : :);
+ return p;
+}
+
+#elif __aarch64__
+
+inline struct params* get_params() {
+ struct params* p = nullptr;
+ asm("adr %0, _params" : "=r"(p) : :);
+ return p;
+}
+
+#else
+#error "unsupported architecture"
+#endif
+
+namespace vdso {
+
+const uint64_t kNsecsPerSec = 1000000000UL;
+
+inline struct timespec ns_to_timespec(uint64_t ns) {
+ struct timespec ts;
+ ts.tv_sec = ns / kNsecsPerSec;
+ ts.tv_nsec = ns % kNsecsPerSec;
+ return ts;
+}
+
+inline uint64_t cycles_to_ns(uint64_t frequency, uint64_t cycles) {
+ uint64_t mult = (kNsecsPerSec << 32) / frequency;
+ return ((unsigned __int128)cycles * mult) >> 32;
+}
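+
+// For example, with frequency == 1GHz, mult == 2^32, so cycles_to_ns() returns
+// the cycle count unchanged; the 128-bit intermediate product keeps the
+// multiply from overflowing for large cycle counts.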
+
+// ClockRealtime() is the VDSO implementation of clock_gettime(CLOCK_REALTIME).
+int ClockRealtime(struct timespec* ts) {
+ struct params* params = get_params();
+ uint64_t seq;
+ uint64_t ready;
+ int64_t base_ref;
+ int64_t base_cycles;
+ uint64_t frequency;
+ int64_t now_cycles;
+
+ do {
+ seq = read_seqcount_begin(&params->seq_count);
+ ready = params->realtime_ready;
+ base_ref = params->realtime_base_ref;
+ base_cycles = params->realtime_base_cycles;
+ frequency = params->realtime_frequency;
+ now_cycles = cycle_clock();
+ } while (read_seqcount_retry(&params->seq_count, seq));
+
+ if (!ready) {
+    // The params are not yet ready, so fall back to the system call. The
+    // sandbox kernel ensures that times computed from the params, once they
+    // are ready, will not be earlier than the time returned here.
+ return sys_clock_gettime(CLOCK_REALTIME, ts);
+ }
+
+ int64_t delta_cycles =
+ (now_cycles < base_cycles) ? 0 : now_cycles - base_cycles;
+ int64_t now_ns = base_ref + cycles_to_ns(frequency, delta_cycles);
+ *ts = ns_to_timespec(now_ns);
+ return 0;
+}
+
+// ClockMonotonic() is the VDSO implementation of
+// clock_gettime(CLOCK_MONOTONIC).
+int ClockMonotonic(struct timespec* ts) {
+ struct params* params = get_params();
+ uint64_t seq;
+ uint64_t ready;
+ int64_t base_ref;
+ int64_t base_cycles;
+ uint64_t frequency;
+ int64_t now_cycles;
+
+ do {
+ seq = read_seqcount_begin(&params->seq_count);
+ ready = params->monotonic_ready;
+ base_ref = params->monotonic_base_ref;
+ base_cycles = params->monotonic_base_cycles;
+ frequency = params->monotonic_frequency;
+ now_cycles = cycle_clock();
+ } while (read_seqcount_retry(&params->seq_count, seq));
+
+ if (!ready) {
+    // The params are not yet ready, so fall back to the system call. The
+    // sandbox kernel ensures that times computed from the params, once they
+    // are ready, will not be earlier than the time returned here.
+ return sys_clock_gettime(CLOCK_MONOTONIC, ts);
+ }
+
+ int64_t delta_cycles =
+ (now_cycles < base_cycles) ? 0 : now_cycles - base_cycles;
+ int64_t now_ns = base_ref + cycles_to_ns(frequency, delta_cycles);
+ *ts = ns_to_timespec(now_ns);
+ return 0;
+}
+
+} // namespace vdso
diff --git a/vdso/vdso_time.h b/vdso/vdso_time.h
new file mode 100644
index 000000000..70d079efc
--- /dev/null
+++ b/vdso/vdso_time.h
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VDSO_VDSO_TIME_H_
+#define VDSO_VDSO_TIME_H_
+
+#include <time.h>
+
+namespace vdso {
+
+int ClockRealtime(struct timespec* ts);
+int ClockMonotonic(struct timespec* ts);
+
+} // namespace vdso
+
+#endif // VDSO_VDSO_TIME_H_
diff --git a/website/BUILD b/website/BUILD
new file mode 100644
index 000000000..4488cb543
--- /dev/null
+++ b/website/BUILD
@@ -0,0 +1,181 @@
+load("//tools:defs.bzl", "pkg_tar")
+load("//website:defs.bzl", "doc", "docs")
+
+package(licenses = ["notice"])
+
+# website is the full container image. Note that this actually just collects
+# other dependencies and runs Docker locally to import and tag the image.
+sh_binary(
+ name = "website",
+ srcs = ["import.sh"],
+ data = [":files"],
+ tags = [
+ "local",
+ "manual",
+ ],
+)
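+
+# Typical usage (assuming a local Docker daemon is available):
+#   bazel run //website:website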
+
+# files is the full file system of the generated container.
+#
+# It must collect all the tarballs (produced by the rules below) and run them
+# through the Dockerfile to generate the site. Note that this checks all links,
+# and therefore requires all static content to be present as well.
+#
+# Note that this rule violates most aspects of hermetic builds. However, this
+# works much more reliably than depending on the container_image rules from
+# bazel itself, which are convoluted and seem to have a hard time even finding
+# the toolchain.
+genrule(
+ name = "files",
+ srcs = [
+ ":config",
+ ":css",
+ ":docs",
+ ":static",
+ ":syscallmd",
+ "//website/blog:posts",
+ "//website/cmd/server",
+ ],
+ outs = ["files.tgz"],
+ cmd = "set -x; " +
+ "T=$$(mktemp -d); " +
+ "mkdir -p $$T/input && " +
+ "mkdir -p $$T/output/_site && " +
+ "tar -xf $(location :config) -C $$T/input && " +
+ "tar -xf $(location :css) -C $$T/input && " +
+ "tar -xf $(location :docs) -C $$T/input && " +
+ "tar -xf $(location :syscallmd) -C $$T/input && " +
+ "tar -xf $(location //website/blog:posts) -C $$T/input && " +
+ "find $$T/input -type f -exec chmod u+rw {} \\; && " +
+ "docker run -i --user $$(id -u):$$(id -g) " +
+ "-v $$(readlink -m $$T/input):/input " +
+ "-v $$(readlink -m $$T/output/_site):/output " +
+ "gvisor.dev/images/jekyll && " +
+ "tar -xf $(location :static) -C $$T/output/_site && " +
+ "docker run -i --user $$(id -u):$$(id -g) " +
+ "-v $$(readlink -m $$T/output/_site):/output " +
+ "gvisor.dev/images/jekyll " +
+ "/usr/gem/bin/htmlproofer " +
+ "--disable-external " +
+ "--check-html " +
+ "/output && " +
+ "cp $(location //website/cmd/server) $$T/output/server && " +
+ "tar -zcf $@ -C $$T/output . && " +
+ "rm -rf $$T",
+ tags = [
+ "local",
+ "manual",
+ "nosandbox",
+ ],
+)
+
+# static are the purely static parts of the site. These are effectively copied
+# in after jekyll generates all the dynamic content.
+pkg_tar(
+ name = "static",
+ srcs = [
+ "archive.key",
+ ] + glob([
+ "performance/**",
+ ]),
+ strip_prefix = "./",
+)
+
+# main.scss requires front-matter to be processed.
+genrule(
+ name = "css",
+ srcs = glob([
+ "css/**",
+ ]),
+ outs = [
+ "css.tar",
+ ],
+ cmd = "T=$$(mktemp -d); " +
+ "mkdir -p $$T/css && " +
+ "for file in $(SRCS); do " +
+ "echo -en '---\\n---\\n' > $$T/css/$$(basename $$file) && " +
+ "cat $$file >> $$T/css/$$(basename $$file); " +
+ "done && " +
+ "tar -C $$T -czf $@ . && " +
+ "rm -rf $$T",
+)
+
+# config is "mostly" static content. These are parts of the site that are
+# present when jekyll runs, but are not dynamically generated.
+pkg_tar(
+ name = "config",
+ srcs = [
+ ":css",
+ "_config.yml",
+ "//website/blog:index.html",
+ ] + glob([
+ "assets/**",
+ "_includes/**",
+ "_layouts/**",
+ "_plugins/**",
+ "_sass/**",
+ ]),
+ strip_prefix = "./",
+)
+
+# index is the index file.
+doc(
+ name = "index",
+ src = "index.md",
+ layout = "base",
+ permalink = "/",
+)
+
+# docs is the dynamic content of the site.
+docs(
+ name = "docs",
+ deps = [
+ ":index",
+ "//:code_of_conduct",
+ "//:contributing",
+ "//:governance",
+ "//:security",
+ "//g3doc:community",
+ "//g3doc:index",
+ "//g3doc:roadmap",
+ "//g3doc:style",
+ "//g3doc/architecture_guide:performance",
+ "//g3doc/architecture_guide:platforms",
+ "//g3doc/architecture_guide:resources",
+ "//g3doc/architecture_guide:security",
+ "//g3doc/user_guide:FAQ",
+ "//g3doc/user_guide:checkpoint_restore",
+ "//g3doc/user_guide:compatibility",
+ "//g3doc/user_guide:debugging",
+ "//g3doc/user_guide:filesystem",
+ "//g3doc/user_guide:install",
+ "//g3doc/user_guide:networking",
+ "//g3doc/user_guide:platforms",
+ "//g3doc/user_guide/quick_start:docker",
+ "//g3doc/user_guide/quick_start:kubernetes",
+ "//g3doc/user_guide/quick_start:oci",
+ "//g3doc/user_guide/tutorials:cni",
+ "//g3doc/user_guide/tutorials:docker",
+ "//g3doc/user_guide/tutorials:kubernetes",
+ ],
+)
+
+# Generate JSON for system call tables
+genrule(
+ name = "syscalljson",
+ outs = ["syscalls.json"],
+ cmd = "$(location //runsc) -- help syscalls -format json -filename $@",
+ tools = ["//runsc"],
+)
+
+# Generate markdown from the JSON dump.
+genrule(
+ name = "syscallmd",
+ srcs = [":syscalljson"],
+ outs = ["syscallsmd"],
+ cmd = "T=$$(mktemp -d) && " +
+ "$(location //website/cmd/syscalldocs) -in $< -out $$T && " +
+ "tar -C $$T -czf $@ . && " +
+ "rm -rf $$T",
+ tools = ["//website/cmd/syscalldocs"],
+)
diff --git a/website/_config.yml b/website/_config.yml
new file mode 100644
index 000000000..b08602970
--- /dev/null
+++ b/website/_config.yml
@@ -0,0 +1,36 @@
+destination: _site
+markdown: kramdown
+kramdown:
+ syntax_highlighter: rouge
+ toc_levels: "2,3"
+highlighter: rouge
+paginate: 5
+paginate_path: "/blog/page:num/"
+plugins:
+ - jekyll-paginate
+ - jekyll-autoprefixer
+ - jekyll-inline-svg
+ - jekyll-relative-links
+ - jekyll-feed
+ - jekyll-sitemap
+site_url: https://gvisor.dev
+feed:
+ path: blog/index.xml
+svg:
+ optimize: true
+defaults:
+ - scope:
+ path: ""
+ values:
+ layout: default
+analytics: "UA-150193582-1"
+authors:
+ jsprad:
+ name: Jeremiah Spradlin
+ email: jsprad@google.com
+ zkoopmans:
+ name: Zach Koopmans
+ email: zkoopmans@google.com
+ igudger:
+ name: Ian Gudger
+ email: igudger@google.com
diff --git a/website/_includes/byline.html b/website/_includes/byline.html
new file mode 100644
index 000000000..d8ae22cb0
--- /dev/null
+++ b/website/_includes/byline.html
@@ -0,0 +1,18 @@
+By
+{% assign last_pos=include.authors.size | minus: 1 %}
+{% assign and_pos=include.authors.size | minus: 2 %}
+{% for i in (0..last_pos) %}
+ {% assign author_id=include.authors[i] %}
+ {% assign author=site.authors[author_id] %}
+ {% if author %}
+ <a href="mailto:{{ author.email }}">{{ author.name }}</a>
+ {% else %}
+ {{ author_id }}
+ {% endif %}
+ {% if i == and_pos %}
+ and
+ {% elsif i < and_pos %}
+ ,
+ {% endif %}
+{% endfor %}
+on <span class="text-muted">{{ include.date | date_to_long_string }}</span>
diff --git a/website/_includes/footer-links.html b/website/_includes/footer-links.html
new file mode 100644
index 000000000..2036dbaa9
--- /dev/null
+++ b/website/_includes/footer-links.html
@@ -0,0 +1,43 @@
+<div class="container">
+ <div class="row">
+ <div class="col-sm-3 col-md-2">
+ <p>About</p>
+ <ul class="list-unstyled">
+ <li><a href="/roadmap/">Roadmap</a></li>
+ <li><a href="/contributing/">Contributing</a></li>
+ <li><a href="/security/">Security</a></li>
+ <li><a href="/community/governance/">Governance</a></li>
+ <li><a href="https://policies.google.com/privacy">Privacy Policy</a></li>
+ </ul>
+ </div>
+ <div class="col-sm-3 col-md-2">
+ <p>Support</p>
+ <ul class="list-unstyled">
+ <li><a href="https://github.com/google/gvisor/issues">Issues</a></li>
+ <li><a href="/docs">Documentation</a></li>
+ <li><a href="/docs/user_guide/faq">FAQ</a></li>
+ </ul>
+ </div>
+ <div class="col-sm-3 col-md-2">
+ <p>Connect</p>
+ <ul class="list-unstyled">
+ <li><a href="https://github.com/google/gvisor">GitHub</a></li>
+ <li><a href="https://groups.google.com/forum/#!forum/gvisor-users">User Mailing List</a></li>
+ <li><a href="https://groups.google.com/forum/#!forum/gvisor-dev">Developer Mailing List</a></li>
+ <li><a href="https://gitter.im/gvisor/community">Gitter Chat</a></li>
+ <li><a href="/blog">Blog</a></li>
+ </ul>
+ </div>
+ <div class="col-sm-3 col-md-3"></div>
+ <div class="hidden-xs hidden-sm col-md-3">
+ <a href="https://cloud.google.com/run">
+ <img style="float: right;" src="/assets/logos/powered-gvisor.png" alt="Powered by gVisor"/>
+ </a>
+ </div>
+ </div>
+ <div class="row">
+ <div class="col-lg-12">
+ <p>&copy; {{ 'now' | date: "%Y" }} The gVisor Authors</p>
+ </div>
+ </div>
+</div>
diff --git a/website/_includes/footer.html b/website/_includes/footer.html
new file mode 100644
index 000000000..9cc8176f7
--- /dev/null
+++ b/website/_includes/footer.html
@@ -0,0 +1,72 @@
+<footer class="footer">
+ {% include footer-links.html %}
+</footer>
+
+<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.min.js" integrity="sha256-FgpCb/KJQlLNfOu91ta32o/NMZxltwRo8QtmkMRdAu8=" crossorigin="anonymous"></script>
+<script src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.10.1/js/all.min.js" integrity="sha256-Z1Nvg/+y2+vRFhFgFij7Lv0r77yG3hOvWz2wI0SfTa0=" crossorigin="anonymous"></script>
+<script src="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha256-U5ZEeKfGNOja007MMD3YBI0A3OSZOQbeG6z2f2Y0hu8=" crossorigin="anonymous"></script>
+<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.13.0/d3.min.js" integrity="sha256-hYXbQJK4qdJiAeDVjjQ9G0D6A0xLnDQ4eJI9dkm7Fpk=" crossorigin="anonymous"></script>
+
+{% if site.analytics %}
+<script type="application/javascript">
+var doNotTrack = false;
+if (!doNotTrack) {
+ window.ga=window.ga||function(){(ga.q=ga.q||[]).push(arguments)};ga.l=+new Date;
+ ga('create', '{{ site.analytics }}', 'auto');
+ ga('send', 'pageview');
+}
+</script>
+<script async src='https://www.google-analytics.com/analytics.js'></script>
+{% endif %}
+
+<script>
+ var shiftWindow = function() {
+ if (location.hash.length !== 0) {
+ window.scrollBy(0, -50);
+ }
+ };
+ window.addEventListener("hashchange", shiftWindow);
+
+ var highlightCurrentSidebarNav = function() {
+ var href = location.pathname;
+ var item = $('#sidebar-nav [href$="' + href + '"]');
+ if (item) {
+ var li = item.parent();
+ li.addClass("active");
+
+ if (li.parent() && li.parent().is("ul")) {
+ do {
+ var ul = li.parent();
+ if (ul.hasClass("collapse")) {
+ ul.collapse("show");
+ }
+ li = ul.parent();
+ } while (li && li.is("li"));
+ }
+ }
+ };
+
+ $(document).ready(function() {
+ // Scroll to anchor of location hash, adjusted for fixed navbar.
+ window.setTimeout(function() {
+ shiftWindow();
+ }, 1);
+
+ // Flip the caret when submenu toggles are clicked.
+ $(".sidebar-submenu").on("show.bs.collapse", function() {
+ var toggle = $('[href$="#' + $(this).attr('id') + '"]');
+ if (toggle) {
+ toggle.addClass("dropup");
+ }
+ });
+ $(".sidebar-submenu").on("hide.bs.collapse", function() {
+ var toggle = $('[href$="#' + $(this).attr('id') + '"]');
+ if (toggle) {
+ toggle.removeClass("dropup");
+ }
+ });
+
+ // Highlight the current page on the sidebar nav.
+ highlightCurrentSidebarNav();
+ });
+</script>
diff --git a/website/_includes/graph.html b/website/_includes/graph.html
new file mode 100644
index 000000000..f3a999341
--- /dev/null
+++ b/website/_includes/graph.html
@@ -0,0 +1,205 @@
+{::nomarkdown}
+{% assign fn = include.id | remove: " " | remove: "-" | downcase %}
+<figure><a href="{{ include.url }}"><svg id="{{ include.id }}" width=500 height=200 onload="render_{{ fn }}()"><title>{{ include.title }}</title></svg></a></figure>
+<script type="text/javascript">
+function render_{{ fn }}() {
+d3.csv("{{ include.url }}", function(d, i, columns) {
+ return d; // Transformed below.
+}, function(error, data) {
+  if (error) throw error;
+
+ // Create a new data that pivots on runtime.
+ //
+ // To start, we have:
+ // runtime, ..., result
+ // runc, ..., 1
+ // runsc, ..., 2
+ //
+ // In the end we want:
+ // ..., runsc, runc
+ // ..., 1, 2
+
+ // Filter by metric, if required.
+ if ("{{ include.metric }}" != "") {
+ orig_columns = data.columns;
+ data = data.filter(d => d.metric == "{{ include.metric }}");
+ data.columns = orig_columns;
+ }
+
+ // Filter by method, if required.
+ if ("{{ include.method }}" != "") {
+ orig_columns = data.columns;
+ data = data.filter(d => d.method == "{{ include.method }}");
+ data.columns = orig_columns.filter(key => key != "method");
+ }
+
+ // Enumerate runtimes.
+ var runtimes = Array.from(new Set(data.map(d => d.runtime)));
+ var metrics = Array.from(new Set(data.map(d => d.metric)));
+ if (metrics.length < 1) {
+ console.log(data);
+ throw("need at least one metric");
+ } else if (metrics.length == 1) {
+ metric = metrics[0];
+ data.columns = data.columns.filter(key => key != "metric");
+ } else {
+ metric = ""; // Used for grouping.
+ }
+
+ var isSubset = function(a, sup) {
+ var ap = Object.getOwnPropertyNames(a);
+ for (var i = 0; i < ap.length; i++) {
+ if (a[ap[i]] !== sup[ap[i]]) {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ // Execute a pivot to include runtimes as attributes.
+ var new_data = data.map(function(data_item) {
+ // Generate a prototype data item.
+ var proto_item = Object.assign({}, data_item);
+ delete proto_item.runtime;
+ delete proto_item.result;
+ var next_item = Object.assign({}, proto_item);
+
+ // Find all matching runtime items.
+ data.forEach(function(d) {
+ if (isSubset(proto_item, d)) {
+        // Add the result.
+ next_item[d.runtime] = d.result;
+ }
+ });
+ return next_item;
+ });
+
+ // Remove any duplication.
+ new_data = Array.from(new Set(new_data));
+ new_data.columns = data.columns;
+ new_data.columns = new_data.columns.filter(key => key != "runtime" && key != "result");
+ new_data.columns = new_data.columns.concat(runtimes);
+ data = new_data;
+
+ // Slice based on the first key.
+ if (data.columns.length != runtimes.length) {
+ x0_key = new_data.columns[0];
+ var x1_domain = data.columns.slice(1);
+ } else {
+ x0_key = "runtime";
+ var x1_domain = runtimes;
+ }
+
+  // Determine variable margins.
+ var x0_domain = data.map(d => d[x0_key]);
+ var margin_bottom_pad = 0;
+ if (x0_domain.length > 8) {
+ margin_bottom_pad = 50;
+ }
+
+ // Use log scale if required.
+ var y_min = 0;
+ if ({{ include.log | default: "false" }}) {
+ // Need to cap lower end of the domain at 1.
+ y_min = 1;
+ }
+
+ if ({{ include.y_min | default: "false" }}) {
+ y_min = "{{ include.y_min }}";
+ }
+
+ var svg = d3.select("#{{ include.id }}"),
+ margin = {top: 20, right: 20, bottom: 30 + margin_bottom_pad, left: 50},
+ width = +svg.attr("width") - margin.left - margin.right,
+ height = +svg.attr("height") - margin.top - margin.bottom,
+ g = svg.append("g").attr("transform", "translate(" + margin.left + "," + margin.top + ")");
+
+ var x0 = d3.scaleBand()
+ .rangeRound([margin.left / 2, width - (4 * margin.right)])
+ .paddingInner(0.1);
+
+ var x1 = d3.scaleBand()
+ .padding(0.05);
+
+ var y = d3.scaleLinear()
+ .rangeRound([height, 0]);
+ if ({{ include.log | default: "false" }}) {
+ y = d3.scaleLog()
+ .rangeRound([height, 0]);
+ }
+
+ var z = d3.scaleOrdinal()
+ .range(["#262362", "#FBB03B", "#286FD7", "#6b486b"]);
+
+ // Set all domains.
+ x0.domain(x0_domain);
+ x1.domain(x1_domain).rangeRound([0, x0.bandwidth()]);
+ y.domain([y_min, d3.max(data, d => d3.max(x1_domain, key => parseFloat(d[key])))]).nice();
+
+ // The data.
+ g.append("g")
+ .selectAll("g")
+ .data(data)
+ .enter().append("g")
+ .attr("transform", function(d) { return "translate(" + x0(d[x0_key]) + ",0)"; })
+ .selectAll("rect")
+ .data(d => x1_domain.map(key => ({key, value: d[key]})))
+ .enter().append("rect")
+ .attr("x", d => x1(d.key))
+ .attr("y", d => y(d.value))
+ .attr("width", x1.bandwidth())
+ .attr("height", d => y(y_min) - y(d.value))
+ .attr("fill", d => z(d.key));
+
+ // X0 ticks and labels.
+ var x0_axis = g.append("g")
+ .attr("class", "axis")
+ .attr("transform", "translate(0," + height + ")")
+ .call(d3.axisBottom(x0));
+ if (x0_domain.length > 8) {
+ x0_axis.selectAll("text")
+ .style("text-anchor", "end")
+ .attr("dx", "-.8em")
+ .attr("dy", ".15em")
+ .attr("transform", "rotate(-65)");
+ }
+
+ // Y ticks and top-label.
+ if (metric == "default") {
+ metric = ""; // Don't display.
+ }
+ g.append("g")
+ .attr("class", "axis")
+ .call(d3.axisLeft(y).ticks(null, "s"))
+ .append("text")
+ .attr("x", -30.0)
+ .attr("y", y(y.ticks().pop()) - 10.0)
+ .attr("dy", "0.32em")
+ .attr("fill", "#000")
+ .attr("font-weight", "bold")
+ .attr("text-anchor", "start")
+ .text(metric);
+
+ // The legend.
+ var legend = g.append("g")
+ .attr("font-family", "sans-serif")
+ .attr("font-size", 10)
+ .attr("text-anchor", "end")
+ .selectAll("g")
+ .data(x1_domain.slice().reverse())
+ .enter().append("g")
+ .attr("transform", function(d, i) { return "translate(0," + i * 20 + ")"; });
+ legend.append("rect")
+ .attr("x", width - 19)
+ .attr("width", 19)
+ .attr("height", 19)
+ .attr("fill", z);
+ legend.append("text")
+ .attr("x", width - 24)
+ .attr("y", 9.5)
+ .attr("dy", "0.32em")
+ .text(function(d) { return d; });
+});
+}
+</script>
+{:/}
diff --git a/website/_includes/header-links.html b/website/_includes/header-links.html
new file mode 100644
index 000000000..467bb1e72
--- /dev/null
+++ b/website/_includes/header-links.html
@@ -0,0 +1,19 @@
+<nav class="navbar navbar-expand-sm navbar-inverse navbar-fixed-top">
+ <div class="container">
+ <div class="navbar-brand">
+ <a href="/">
+ <img src="/assets/logos/logo_solo_on_dark.svg" height="25px" class="d-inline-block align-top" style="margin-right: 10px;" alt="logo"/>
+ gVisor
+ </a>
+ </div>
+
+ <div class="collapse navbar-collapse">
+ <ul class="nav navbar-nav navbar-right">
+ <li><a href="/docs">Documentation</a></li>
+ <li><a href="/blog">Blog</a></li>
+ <li><a href="/community/">Community</a></li>
+ <li><a href="https://github.com/google/gvisor">GitHub</a></li>
+ </ul>
+ </div>
+ </div>
+</nav>
diff --git a/website/_includes/header.html b/website/_includes/header.html
new file mode 100644
index 000000000..c80310069
--- /dev/null
+++ b/website/_includes/header.html
@@ -0,0 +1,30 @@
+ <head>
+ <meta charset="utf-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ {% if page.title %}
+ <title>{{ page.title }} - gVisor</title>
+ {% else %}
+ <title>gVisor</title>
+ {% endif %}
+ <link rel="canonical" href="{{ page.url | replace:'index.html','' | prepend: site_root }}">
+
+ <!-- Dependencies. -->
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" integrity="sha256-916EbMg70RQy9LHiGkXzG8hSg9EdNy97GazNG/aiY1w=" crossorigin="anonymous" />
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.10.1/css/all.min.css" integrity="sha256-fdcFNFiBMrNfWL6OcAGQz6jDgNTRxnrLEd4vJYFWScE=" crossorigin="anonymous" />
+
+ <!-- Our own style sheet. -->
+ <link rel="stylesheet" type="text/css" href="/css/main.css">
+ <link rel="icon" type="image/png" href="/assets/favicons/favicon-32x32.png" sizes="32x32">
+ <link rel="icon" type="image/png" href="/assets/favicons/favicon-16x16.png" sizes="16x16">
+
+ {% if page.title %}
+ <meta name="og:title" content="{{ page.title }}">
+ {% else %}
+ <meta name="og:title" content="gVisor">
+ {% endif %}
+ {% if page.description %}
+ <meta name="og:description" content="{{ page.description }}">
+ {% endif %}
+ <meta name="og:image" content="{{ site.site_url }}/assets/logos/logo_solo_on_white_bordered.svg">
+ </head>
diff --git a/website/_includes/paginator.html b/website/_includes/paginator.html
new file mode 100644
index 000000000..b4ff4c3b1
--- /dev/null
+++ b/website/_includes/paginator.html
@@ -0,0 +1,10 @@
+<nav aria-label="...">
+ <ul class="pager">
+ {% if paginator.previous_page %}
+ <li class="previous"><a href="{{ paginator.previous_page_path }}"><span aria-hidden="true">&larr;</span> Newer</a></li>
+ {% endif %}
+ {% if paginator.next_page %}
+ <li class="next"><a href="{{ paginator.next_page_path }}">Older <span aria-hidden="true">&rarr;</span></a></li>
+ {% endif %}
+ </ul>
+</nav>
diff --git a/website/_includes/required_linux.html b/website/_includes/required_linux.html
new file mode 100644
index 000000000..e9d1b7548
--- /dev/null
+++ b/website/_includes/required_linux.html
@@ -0,0 +1,2 @@
+> Note: gVisor supports only x86\_64 and requires Linux 4.14.77+
+> ([older Linux](/docs/user_guide/networking/#gso)).
diff --git a/website/_layouts/base.html b/website/_layouts/base.html
new file mode 100644
index 000000000..b30bee0dc
--- /dev/null
+++ b/website/_layouts/base.html
@@ -0,0 +1,9 @@
+<!DOCTYPE html>
+<html lang="en" itemscope itemtype="https://schema.org/WebPage">
+ {% include header.html %}
+ <body>
+ {% include header-links.html %}
+ {{ content }}
+ {% include footer.html %}
+ </body>
+</html>
diff --git a/website/_layouts/blog.html b/website/_layouts/blog.html
new file mode 100644
index 000000000..6c371ab50
--- /dev/null
+++ b/website/_layouts/blog.html
@@ -0,0 +1,17 @@
+---
+layout: base
+---
+
+<div class="container">
+ <div class="row">
+ <div class="col-lg-2"></div>
+ <div class="col-lg-8">
+ <h1>{{ page.title }}</h1>
+ {% if page.feed %}
+ <a class="btn-inverse" href="/blog/index.xml">Feed&nbsp;<i class="fas fa-rss ml-2"></i></a>
+ {% endif %}
+ {{ content }}
+ </div>
+ <div class="col-lg-2"></div>
+ </div>
+</div>
diff --git a/website/_layouts/default.html b/website/_layouts/default.html
new file mode 100644
index 000000000..e5523e3fc
--- /dev/null
+++ b/website/_layouts/default.html
@@ -0,0 +1,14 @@
+---
+layout: base
+---
+{% if page.title %}
+<div class="container">
+ <div class="page-header">
+ <h1>{{ page.title }}</h1>
+ </div>
+</div>
+{% endif %}
+
+<div class="container">
+ {{ content }}
+</div>
diff --git a/website/_layouts/docs.html b/website/_layouts/docs.html
new file mode 100644
index 000000000..549305089
--- /dev/null
+++ b/website/_layouts/docs.html
@@ -0,0 +1,59 @@
+---
+layout: base
+categories:
+ - Project
+ - User Guide
+ - Architecture Guide
+ - Compatibility
+---
+
+<div class="container">
+ <div class="row">
+ <div class="col-md-3">
+ <nav class="sidebar" id="sidebar-nav">
+ {% for category in layout.categories %}
+ <h3>{{ category }}</h3>
+ <ul class="sidebar-nav">
+ {% assign sorted_pages = site.pages | where: 'layout', 'docs' | where: 'category', category | sort: 'weight' | sort: 'subcategory' %}
+ {% assign subcategory = nil %}
+ {% for p in sorted_pages %}
+ {% if p.subcategory != subcategory %}
+ {% if subcategory != nil %}
+ </ul>
+ </li>
+ {% endif %}
+ {% assign subcategory = p.subcategory %}
+ {% if subcategory != nil %}
+ {% assign ac = "aria-controls" %}
+ {% assign cid = p.category | remove: " " | downcase %}
+ {% assign sid = p.subcategory | remove: " " | downcase %}
+ <li>
+ <a class="sidebar-nav-heading" data-toggle="collapse" href="#{{ cid }}-{{ sid }}" aria-expanded="false" {{ ac }}="{{ cid }}-{{ sid }}">{{ subcategory }}<span class="caret"></span></a>
+ <ul class="collapse sidebar-nav sidebar-submenu" id="{{ cid }}-{{ sid }}">
+ {% endif %}
+ {% endif %}
+ <li><a href="{{ p.url }}">{{ p.title }}</a></li>
+ {% endfor %}
+ {% if subcategory != nil %}
+ </ul>
+ </li>
+ {% endif %}
+ </ul>
+ {% endfor %}
+ </nav>
+ </div>
+
+ <div class="col-md-9">
+ <h1>{{ page.title }}</h1>
+ {% if page.editpath %}
+ <p>
+ <a href="https://github.com/google/gvisor/edit/master/{{page.editpath}}" target="_blank"><i class="fa fa-edit fa-fw"></i> Edit this page</a>
+ <a href="https://github.com/google/gvisor/issues/new?title={{page.title | url_encode}}" target="_blank"><i class="fab fa-github fa-fw"></i> Create issue</a>
+ </p>
+ {% endif %}
+ <div class="docs-content">
+ {{ content }}
+ </div>
+ </div>
+ </div>
+</div>
diff --git a/website/_layouts/post.html b/website/_layouts/post.html
new file mode 100644
index 000000000..640bee5af
--- /dev/null
+++ b/website/_layouts/post.html
@@ -0,0 +1,10 @@
+---
+layout: blog
+---
+
+<div class="blog-meta">
+ {% include byline.html authors=page.authors date=page.date %}
+</div>
+<div class="blog-content">
+ {{ content }}
+</div>
diff --git a/website/_plugins/svg_mime_type.rb b/website/_plugins/svg_mime_type.rb
new file mode 100644
index 000000000..ad6bb6480
--- /dev/null
+++ b/website/_plugins/svg_mime_type.rb
@@ -0,0 +1,3 @@
+require 'webrick'
+include WEBrick
+WEBrick::HTTPUtils::DefaultMimeTypes.store 'svg', 'image/svg+xml'
diff --git a/website/_sass/footer.scss b/website/_sass/footer.scss
new file mode 100644
index 000000000..ec2ba5e20
--- /dev/null
+++ b/website/_sass/footer.scss
@@ -0,0 +1,15 @@
+.footer {
+ margin-top: 40px;
+ background-color: #222;
+ color: #fff;
+ padding: 20px;
+
+ a {
+ color: $inverse-link-color;
+
+ &:hover,
+ &:focus {
+ color: $inverse-link-hover-color;
+ }
+ }
+}
diff --git a/website/_sass/front.scss b/website/_sass/front.scss
new file mode 100644
index 000000000..0e4208f3c
--- /dev/null
+++ b/website/_sass/front.scss
@@ -0,0 +1,17 @@
+.jumbotron {
+ background-image: url(/assets/images/background.jpg);
+ background-position: center;
+ background-repeat: no-repeat;
+ background-size: cover;
+ background-blend-mode: darken;
+ background-color: rgba(0, 0, 0, 0.3);
+
+ p {
+ color: #fff;
+ margin-top: 0;
+ margin-bottom: 0;
+ font-weight: 300;
+ font-size: 24px;
+ line-height: 30px;
+ }
+}
diff --git a/website/_sass/navbar.scss b/website/_sass/navbar.scss
new file mode 100644
index 000000000..65bc573ac
--- /dev/null
+++ b/website/_sass/navbar.scss
@@ -0,0 +1,26 @@
+.navbar-inverse {
+ background-color: $primary;
+ border-bottom: 1px solid $primary;
+
+ .navbar-brand > a {
+ color: #fff;
+
+ &:focus,
+ &:hover {
+ color: #fff;
+ }
+ }
+
+ .navbar-nav > li > a {
+ color: $inverse-link-color;
+
+ &:focus,
+ &:hover {
+ color: $inverse-link-hover-color;
+ }
+ }
+
+ .navbar-nav .nav-icon {
+ font-size: 18px;
+ }
+}
diff --git a/website/_sass/sidebar.scss b/website/_sass/sidebar.scss
new file mode 100644
index 000000000..f4ca05df9
--- /dev/null
+++ b/website/_sass/sidebar.scss
@@ -0,0 +1,61 @@
+$sidebar-border-color: #fff;
+$sidebar-hover-border-color: #66bb6a;
+
+.sidebar {
+ margin-top: 40px;
+
+ ul.sidebar-nav {
+ list-style-type: none;
+ padding: 0;
+ transition: height 0.01s;
+
+ li {
+ &.sidebar-nav-heading {
+ padding: 10px 0;
+ margin: 0;
+ display: block;
+ font-size: 16px;
+ font-weight: 300;
+ }
+
+ a {
+ padding: 4px 0;
+ display: block;
+ border-right: 2px solid $sidebar-border-color;
+
+ &:focus {
+ text-decoration: none;
+ }
+
+ .caret {
+ float: right;
+ margin-top: 8px;
+ margin-right: 10px;
+ }
+ }
+
+ &.active {
+ a {
+ border-left: 2px solid $sidebar-hover-border-color;
+ padding-left: 6px;
+ }
+ }
+ }
+
+ ul.sidebar-nav {
+ padding-left: 10px;
+ }
+ }
+}
+
+@media (min-width: 992px) {
+ .sidebar-toggle {
+ display: none;
+ }
+
+ .sidebar {
+ &.collapse {
+ display: block;
+ }
+ }
+}
diff --git a/website/_sass/style.scss b/website/_sass/style.scss
new file mode 100644
index 000000000..4deb945d4
--- /dev/null
+++ b/website/_sass/style.scss
@@ -0,0 +1,154 @@
+$primary: #262362;
+$secondary: #fff;
+$link-color: #286fd7;
+$inverse-link-color: #fff;
+
+$link-hover-color: darken($link-color, 10%);
+$inverse-link-hover-color: darken($inverse-link-color, 10%);
+
+$text-color: #444;
+
+$body-font-family: 'Roboto', 'Helvetica Neue', Helvetica, Arial, sans-serif;
+$code-font-family: 'Source Code Pro', monospace;
+
+html {
+ position: relative;
+ min-height: 100%;
+}
+
+body {
+ color: $text-color;
+ font-family: $body-font-family;
+ padding-top: 40px;
+}
+
+a {
+ color: $link-color;
+
+ &:hover,
+ &:focus {
+ color: $link-hover-color;
+ text-decoration: none;
+ }
+
+ code {
+ color: $link-color;
+ }
+}
+
+h1,
+h2,
+h3,
+h4,
+h5,
+h6 {
+ color: $text-color;
+ font-weight: 400;
+}
+
+h1 code,
+h2 code,
+h3 code,
+h4 code,
+h5 code,
+h6 code {
+ color: $text-color;
+ background: transparent;
+}
+
+h1 {
+ font-size: 30px;
+ margin-top: 40px;
+ margin-bottom: 40px;
+}
+
+h2 {
+ font-size: 24px;
+ margin-top: 30px;
+ margin-bottom: 30px;
+
+ code {
+ font-size: 24px;
+ }
+}
+
+h3 {
+ font-size: 20px;
+ margin-top: 24px;
+ margin-bottom: 24px;
+
+ code {
+ font-size: 20px;
+ }
+}
+
+h4 {
+ font-size: 18px;
+ margin-top: 20px;
+ margin-bottom: 20px;
+
+ code {
+ font-size: 18px;
+ }
+}
+
+p,
+li {
+ font-size: 14px;
+ line-height: 22px;
+}
+
+code {
+ font-family: $code-font-family;
+ font-size: 13px;
+}
+
+.btn {
+ color: $text-color;
+ background-color: $inverse-link-color;
+}
+
+.btn-inverse {
+ color: $text-color;
+ background-color: #fff;
+}
+
+.well {
+ box-shadow: none;
+}
+
+table {
+ width: 100%;
+}
+
+table td,
+table th {
+ border: 1px solid #ddd;
+ padding: 8px;
+}
+
+table tr:nth-child(even) {
+ background-color: #eee;
+}
+
+table th {
+ padding-top: 12px;
+ padding-bottom: 12px;
+ background-color: $primary;
+ color: $secondary;
+}
+
+.blog-meta {
+ margin-top: 10px;
+ margin-bottom: 20px;
+}
+
+.docs-content * img {
+ display: block;
+ margin: 20px auto;
+}
+
+.blog-content * img {
+ display: block;
+ margin: 20px auto;
+}
diff --git a/website/archive.key b/website/archive.key
new file mode 100644
index 000000000..1a91698bf
--- /dev/null
+++ b/website/archive.key
@@ -0,0 +1,29 @@
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+mQINBF0meAYBEACcBYPOSBiKtid+qTQlbgKGPxUYt0cNZiQqWXylhYUT4PuNlNx5
+s+sBLFvNTpdTrXMmZ8NkekyjD1HardWvebvJT4u+Ho/9jUr4rP71cNwNtocz/w8G
+DsUXSLgH8SDkq6xw0L+5eGc78BBg9cOeBeFBm3UPgxTBXS9Zevoi2w1lzSxkXvjx
+cGzltzMZfPXERljgLzp9AAfhg/2ouqVQm37fY+P/NDzFMJ1XHPIIp9KJl/prBVud
+jJJteFZ5sgL6MwjBQq2kw+q2Jb8Zfjl0BeXDgGMN5M5lGhX2wTfiMbfo7KWyzRnB
+RpSP3BxlLqYeQUuLG5Yx8z3oA3uBkuKaFOKvXtiScxmGM/+Ri2YM3m66imwDhtmP
+AKwTPI3Re4gWWOffglMVSv2sUAY32XZ74yXjY1VhK3bN3WFUPGrgQx4X7GP0A1Te
+lzqkT3VSMXieImTASosK5L5Q8rryvgCeI9tQLn9EpYFCtU3LXvVgTreGNEEjMOnL
+dR7yOU+Fs775stn6ucqmdYarx7CvKUrNAhgEeHMonLe1cjYScF7NfLO1GIrQKJR2
+DE0f+uJZ52inOkO8ufh3WVQJSYszuS3HCY7w5oj1aP38k/y9zZdZvVvwAWZaiqBQ
+iwjVs6Kub76VVZZhRDf4iYs8k1Zh64nXdfQt250d8U5yMPF3wIJ+c1yhxwARAQAB
+tCpUaGUgZ1Zpc29yIEF1dGhvcnMgPGd2aXNvci1ib3RAZ29vZ2xlLmNvbT6JAlQE
+EwEKAD4WIQRvHfheOnHCSRjnJ9VvxtVU4yvZQwUCXSZ4BgIbAwUJA8JnAAULCQgH
+AgYVCgkICwIEFgIDAQIeAQIXgAAKCRBvxtVU4yvZQ5WFD/9VZXMW5I2rKV+2gTHT
+CsW74kZVi1VFdAVYiUJZXw2jJNtcg3xdgBcscYPyecyka/6TS2q7q2fOGAzCZkcR
+e3lLzkGAngMlZ7PdHAE0PDMNFaeMZW0dxNH68vn7AiA1y2XwENnxVec7iXQH6aX5
+xUNg2OCiv5f6DJItHc/Q4SvFUi8QK7TT/GYE1RJXVJlLqfO6y4V8SeqfM+FHpHZM
+gzrwdTgsNiEm4lMjWcgb2Ib4i2JUVAjIRPfcpysiV5E7c3SPXyu4bOovKKlbhiJ1
+Q1M9M0zHik34Kjf4YNO1EW936j7Msd181CJt5Bl9XvlhPb8gey/ygpIvcicLx6M5
+lRJTy4z1TtkmtZ7E8EbJZWoPTaHlA6hoMtGeE35j3vMZN1qZYaYt26eFOxxhh7PA
+J0h1lS7T2O8u1c2JKhKvajtdmbqbJgI8FRhVsMoVBnqDK5aE9MOAso36OibfweEL
+8iV2z8JnBpWtbbUEaWro4knPtbLJbQFvXVietm3cFsbGg+DMIwI6x6HcU91IEFYI
+Sv4orK7xgLuM+f6dxo/Wel3ht18dg3x3krBLALTYBidRfnQYYR3sTfLquB8b5WaY
+o829L2Bop9GBygdLevkHHN5It6q8CVpn0H5HEJMNaDOX1LcPbf0CKwkkAVCBd9YZ
+eAX38ds9LliK7XPXdC4c+zEkGA==
+=x8TG
+-----END PGP PUBLIC KEY BLOCK-----
diff --git a/website/assets/favicons/apple-touch-icon-180x180.png b/website/assets/favicons/apple-touch-icon-180x180.png
new file mode 100644
index 000000000..bf4b6ce9b
--- /dev/null
+++ b/website/assets/favicons/apple-touch-icon-180x180.png
Binary files differ
diff --git a/website/assets/favicons/favicon-16x16.png b/website/assets/favicons/favicon-16x16.png
new file mode 100644
index 000000000..083264206
--- /dev/null
+++ b/website/assets/favicons/favicon-16x16.png
Binary files differ
diff --git a/website/assets/favicons/favicon-32x32.png b/website/assets/favicons/favicon-32x32.png
new file mode 100644
index 000000000..b8e4caff1
--- /dev/null
+++ b/website/assets/favicons/favicon-32x32.png
Binary files differ
diff --git a/website/assets/favicons/favicon.ico b/website/assets/favicons/favicon.ico
new file mode 100644
index 000000000..9238b79d9
--- /dev/null
+++ b/website/assets/favicons/favicon.ico
Binary files differ
diff --git a/website/assets/favicons/pwa-192x192.png b/website/assets/favicons/pwa-192x192.png
new file mode 100644
index 000000000..5d2fab785
--- /dev/null
+++ b/website/assets/favicons/pwa-192x192.png
Binary files differ
diff --git a/website/assets/favicons/pwa-512x512.png b/website/assets/favicons/pwa-512x512.png
new file mode 100644
index 000000000..23824439e
--- /dev/null
+++ b/website/assets/favicons/pwa-512x512.png
Binary files differ
diff --git a/website/assets/favicons/tile150x150.png b/website/assets/favicons/tile150x150.png
new file mode 100644
index 000000000..f76fcffae
--- /dev/null
+++ b/website/assets/favicons/tile150x150.png
Binary files differ
diff --git a/website/assets/favicons/tile310x150.png b/website/assets/favicons/tile310x150.png
new file mode 100644
index 000000000..4f87e4c12
--- /dev/null
+++ b/website/assets/favicons/tile310x150.png
Binary files differ
diff --git a/website/assets/favicons/tile310x310.png b/website/assets/favicons/tile310x310.png
new file mode 100644
index 000000000..a2926d0bd
--- /dev/null
+++ b/website/assets/favicons/tile310x310.png
Binary files differ
diff --git a/website/assets/favicons/tile70x70.png b/website/assets/favicons/tile70x70.png
new file mode 100644
index 000000000..96cc69fc4
--- /dev/null
+++ b/website/assets/favicons/tile70x70.png
Binary files differ
diff --git a/website/assets/images/2019-11-18-security-basics-figure1.png b/website/assets/images/2019-11-18-security-basics-figure1.png
new file mode 100644
index 000000000..2a8134a7a
--- /dev/null
+++ b/website/assets/images/2019-11-18-security-basics-figure1.png
Binary files differ
diff --git a/website/assets/images/2019-11-18-security-basics-figure2.png b/website/assets/images/2019-11-18-security-basics-figure2.png
new file mode 100644
index 000000000..f8b416e1d
--- /dev/null
+++ b/website/assets/images/2019-11-18-security-basics-figure2.png
Binary files differ
diff --git a/website/assets/images/2019-11-18-security-basics-figure3.png b/website/assets/images/2019-11-18-security-basics-figure3.png
new file mode 100644
index 000000000..833e3e2b5
--- /dev/null
+++ b/website/assets/images/2019-11-18-security-basics-figure3.png
Binary files differ
diff --git a/website/assets/images/2020-04-02-networking-security-figure1.png b/website/assets/images/2020-04-02-networking-security-figure1.png
new file mode 100644
index 000000000..b49cb0242
--- /dev/null
+++ b/website/assets/images/2020-04-02-networking-security-figure1.png
Binary files differ
diff --git a/website/assets/images/background.jpg b/website/assets/images/background.jpg
new file mode 100644
index 000000000..81f8e332b
--- /dev/null
+++ b/website/assets/images/background.jpg
Binary files differ
diff --git a/website/assets/logos/Makefile b/website/assets/logos/Makefile
new file mode 100644
index 000000000..49289ecc1
--- /dev/null
+++ b/website/assets/logos/Makefile
@@ -0,0 +1,13 @@
+#!/usr/bin/make -f
+
+srcs := $(wildcard *.svg)
+dsts := $(patsubst %.svg,%.png,$(srcs))
+
+all: $(dsts)
+.PHONY: all
+
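+# Pattern rule: each %.svg renders to four PNGs in a single recipe — the
+# native-size %.png plus 16-, 128-, and 1024-pixel-wide variants.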
+%.png %-16.png %-128.png %-1024.png: %.svg
+ @inkscape -z -e $*.png $<
+ @inkscape -z -w 16 -e $*-16.png $<
+ @inkscape -z -w 128 -e $*-128.png $<
+ @inkscape -z -w 1024 -e $*-1024.png $<
diff --git a/website/assets/logos/README.md b/website/assets/logos/README.md
new file mode 100644
index 000000000..2964982dd
--- /dev/null
+++ b/website/assets/logos/README.md
@@ -0,0 +1,10 @@
+# Logos
+
+This directory contains logo assets.
+
+The colors used are:
+
+* Background (blue): #262262
+* Highlight (yellow): #FBB03B
+
+Use `make` to generate sized PNGs from SVGs.
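+
+For example, the following invocation renders a single logo and its sized
+variants (it assumes Inkscape 0.92 is on the PATH, since the Makefile uses
+that version's `-z`/`-e` export flags):
+
+```shell
+# Produces logo_solo_on_dark.png and the -16/-128/-1024 px variants.
+make logo_solo_on_dark.png
+```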
diff --git a/website/assets/logos/logo_solo_monochrome.png b/website/assets/logos/logo_solo_monochrome.png
new file mode 100644
index 000000000..e09c5ad5e
--- /dev/null
+++ b/website/assets/logos/logo_solo_monochrome.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_monochrome.svg b/website/assets/logos/logo_solo_monochrome.svg
new file mode 100644
index 000000000..73126fd8f
--- /dev/null
+++ b/website/assets/logos/logo_solo_monochrome.svg
@@ -0,0 +1,73 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="175.35599"
+ height="193.20036"
+ viewBox="0 0 175.35599 193.20036"
+ sodipodi:docname="logo_solo_monochrome.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.43085925"
+ inkscape:cx="374.99057"
+ inkscape:cy="88.483321"
+ inkscape:window-x="0"
+ inkscape:window-y="9"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g48"
+ transform="translate(548.2423,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.051 2.18,-28.09 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.247,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.34,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.543,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.524,-5.837 0,0 24.645,2.263 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.073 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></svg>
\ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_dark-1024.png b/website/assets/logos/logo_solo_on_dark-1024.png
new file mode 100644
index 000000000..6df428c65
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark-128.png b/website/assets/logos/logo_solo_on_dark-128.png
new file mode 100644
index 000000000..78a85475f
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark-128.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark-16.png b/website/assets/logos/logo_solo_on_dark-16.png
new file mode 100644
index 000000000..4f1e91c02
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark-16.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark.png b/website/assets/logos/logo_solo_on_dark.png
new file mode 100644
index 000000000..da20756f7
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark.svg b/website/assets/logos/logo_solo_on_dark.svg
new file mode 100644
index 000000000..ae8d9e879
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark.svg
@@ -0,0 +1,73 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="175.35599"
+ height="193.19984"
+ viewBox="0 0 175.35599 193.19985"
+ sodipodi:docname="logo_solo_on_dark.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title /></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="1278"
+ inkscape:window-height="699"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.8617185"
+ inkscape:cx="257.20407"
+ inkscape:cy="172.193"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.96254)"><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></svg>
\ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_dark_full-1024.png b/website/assets/logos/logo_solo_on_dark_full-1024.png
new file mode 100644
index 000000000..8d597dd3d
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full-128.png b/website/assets/logos/logo_solo_on_dark_full-128.png
new file mode 100644
index 000000000..fe6dd5dea
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full-128.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full-16.png b/website/assets/logos/logo_solo_on_dark_full-16.png
new file mode 100644
index 000000000..f9aa7dfdd
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full-16.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full.png b/website/assets/logos/logo_solo_on_dark_full.png
new file mode 100644
index 000000000..611b0565e
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full.svg b/website/assets/logos/logo_solo_on_dark_full.svg
new file mode 100644
index 000000000..6440835b1
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full.svg
@@ -0,0 +1,79 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="246.17325"
+ height="246.17325"
+ viewBox="0 0 246.17325 246.17326"
+ sodipodi:docname="logo_solo_on_dark_full.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title /></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="1278"
+ inkscape:window-height="699"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.8617185"
+ inkscape:cx="291.21659"
+ inkscape:cy="198.28704"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-580.43785,665.8419)"><circle
+ id="path83"
+ cx="527.64337"
+ cy="-407.06647"
+ r="92.314972"
+ transform="scale(1,-1)"
+ style="fill:#262262;fill-opacity:1;stroke-width:0.48076925" /><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></svg>
\ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_white.png b/website/assets/logos/logo_solo_on_white.png
new file mode 100644
index 000000000..ca539cdff
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white.svg b/website/assets/logos/logo_solo_on_white.svg
new file mode 100644
index 000000000..d794ad8e7
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white.svg
@@ -0,0 +1,73 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="175.35599"
+ height="193.20036"
+ viewBox="0 0 175.35599 193.20037"
+ sodipodi:docname="logo_solo_on_white.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.43085925"
+ inkscape:cx="370.53985"
+ inkscape:cy="50.91009"
+ inkscape:window-x="0"
+ inkscape:window-y="9"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g46"
+ transform="translate(548.2428,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.742,-0.023 8.417,0.007 2.221,0.025 4.442,0.102 6.663,0.175 4.79,0.154 4.818,0.165 5.881,-4.582 C 42.434,52.544 41.469,38.505 39.11,24.477 38.859,22.985 38.409,21.521 38.246,20.023 37.05,9.081 30.089,3.645 20.164,1.097 13.786,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.248,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.341,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.542,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.525,-5.837 0,0 24.644,2.263 34.463,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.651,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.948,4.193 3.173,0.073 5.98,0.037 7.457,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g></g></svg>
\ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_white_bordered-1024.png b/website/assets/logos/logo_solo_on_white_bordered-1024.png
new file mode 100644
index 000000000..62bb88d50
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered-128.png b/website/assets/logos/logo_solo_on_white_bordered-128.png
new file mode 100644
index 000000000..a8988766c
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered-128.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered-16.png b/website/assets/logos/logo_solo_on_white_bordered-16.png
new file mode 100644
index 000000000..a545c49cf
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered-16.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered.png b/website/assets/logos/logo_solo_on_white_bordered.png
new file mode 100644
index 000000000..cc99b7c51
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered.svg b/website/assets/logos/logo_solo_on_white_bordered.svg
new file mode 100644
index 000000000..2e26f144a
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered.svg
@@ -0,0 +1,82 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="190.7361"
+ height="207.92123"
+ viewBox="0 0 190.7361 207.92124"
+ sodipodi:docname="logo_solo_on_white_bordered.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.609327"
+ inkscape:cx="298.55736"
+ inkscape:cy="108.65533"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g14" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-612.10927,647.00852)"><g
+ id="g12"><g
+ id="g14"
+ clip-path="url(#clipPath18)"><g
+ id="g46"
+ transform="translate(594.4321,453.1439)"><path
+ d="m 0,0 c 0,0 -0.204,0.139 -0.45,0.277 -0.906,0.547 -1.856,1 -2.744,1.321 -1.685,0.679 -3.62,1.262 -5.761,1.738 l -2.063,2.533 c -1.793,2.283 -6.09,7.358 -13.132,13.334 -8.142,6.209 -24.212,15.045 -47.595,12.442 -4.578,-0.444 -9.077,-1.318 -13.368,-2.597 l -0.232,-0.068 c -12.978,-3.918 -24.155,-11.512 -33.24,-22.601 -0.6,-0.754 -1.179,-1.504 -1.783,-2.307 l -0.134,-0.191 -0.062,-0.074 c -1.194,-1.596 -2.36,-3.258 -3.485,-4.969 l -0.125,-0.198 c -3.559,-5.915 -6.126,-12.72 -6.85,-14.73 -11.284,-34.556 2.735,-61.502 5.669,-66.567 8.482,-14.4 23.945,-32.461 50.005,-38.975 1.42,-0.354 2.872,-0.676 4.356,-0.96 0.016,-0.004 0.036,-0.009 0.053,-0.013 6.537,-1.118 16.647,-1.928 29.969,-0.317 3.21,0.621 8.236,2.535 8.646,8.445 2.209,0.842 10.261,3.812 10.261,3.812 8.572,3.874 18.586,11.106 24.334,24.546 1.21,2.83 0.277,6.128 -2.171,7.994 l -0.202,0.152 c 0.639,1.557 1.125,3.209 1.488,4.93 l 0.019,0.009 c 0,0 0.063,0.325 0.164,0.854 0.013,0.073 0.028,0.144 0.041,0.215 0.402,2.114 1.294,6.91 1.719,10.035 0.02,0.149 0.029,0.268 0.033,0.371 2.136,14.686 2.099,26.608 -0.156,37.847 1.443,0.114 2.672,1.095 3.106,2.477 l 0.655,2.086 C 9.188,-12.066 6.243,-4.003 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(551.9267,364.2689)"><path
+ d="m 0,0 c 16.398,20.796 22.346,43.748 18.045,69.9 3.02,0 5.654,-0.022 8.288,0.007 2.187,0.025 4.374,0.101 6.56,0.172 4.716,0.152 4.743,0.163 5.79,-4.512 C 41.779,51.734 40.83,37.911 38.507,24.1 38.26,22.631 37.817,21.189 37.656,19.715 36.479,8.941 29.625,3.588 19.853,1.08 13.574,-0.531 7.209,-0.816 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(548.4282,396.6898)"><path
+ d="M 0,0 C 0,2.553 -3.404,4.255 -3.404,4.255 -0.851,5.105 0,8.51 0,8.51 0,8.51 0.851,5.105 3.404,4.255 3.404,4.255 0,2.553 0,0 m -16.835,6.354 c 0,6.638 -8.851,11.063 -8.851,11.063 6.638,2.213 8.851,11.063 8.851,11.063 0,0 2.212,-8.85 8.85,-11.063 0,0 -8.85,-4.425 -8.85,-11.063"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g><g
+ id="g58"
+ transform="translate(489.744,429.7865)"><path
+ d="m 0,0 c 0,0 -2.18,3.308 -9.585,2.026 0,0 5.894,28.296 35.737,38.209 C 26.152,40.235 -2.777,23.901 0,0 m 0.574,-32.696 -0.003,-0.003 c -1.277,-1.994 -2.778,-3.523 -4.46,-4.544 -1.492,-0.919 -3.13,-1.403 -4.67,-1.403 -0.519,0 -1.028,0.055 -1.516,0.167 l -2.174,0.5 2.15,0.593 c 1.605,0.443 3.016,1.325 4.194,2.623 0.997,1.09 1.826,2.447 2.536,4.149 1.158,2.775 1.757,6.066 1.783,9.781 -0.048,3.661 -0.673,6.945 -1.857,9.75 -0.693,1.629 -1.56,3.009 -2.58,4.099 -1.208,1.29 -2.63,2.143 -4.228,2.536 l -2.154,0.529 2.145,0.564 c 0.473,0.125 0.983,0.193 1.516,0.205 l 0.031,0.002 1.662,-0.152 c 1.029,-0.203 2.071,-0.605 3.014,-1.167 1.727,-1.015 3.25,-2.527 4.526,-4.493 2.122,-3.323 3.264,-7.421 3.305,-11.857 C 3.767,-25.296 2.653,-29.4 0.574,-32.696 M 100.951,17.69 c 0,0 -0.074,0.05 -0.223,0.136 -0.532,0.32 -1.095,0.593 -1.688,0.801 -2.997,1.223 -9.729,3.139 -21.568,2.583 -0.029,0 -0.055,0 -0.084,-0.001 C 52.1,20.798 29.559,10.754 29.559,10.754 c 0,0 1.417,1.733 3.4,3.636 0,0 10e-4,0 10e-4,10e-4 1.036,0.958 2.319,2.044 3.853,3.176 0.044,0.031 0.086,0.062 0.122,0.091 8.573,6.286 24.949,13.945 50.829,9.394 -0.615,0.698 -1.254,1.341 -1.887,1.998 l 0.057,-0.009 c 0,0 -5.858,6.886 -16.555,12.616 -0.613,0.331 -1.252,0.659 -1.91,0.985 -0.038,0.018 -0.073,0.038 -0.112,0.057 -0.067,0.033 -0.128,0.057 -0.195,0.089 -8.007,3.888 -19.263,7.05 -33.564,5.458 -3.873,-0.374 -8.015,-1.116 -12.272,-2.399 -0.007,-0.002 -0.013,-0.003 -0.02,-0.005 L 21.304,45.84 c -10,-3.018 -20.632,-9.029 -29.888,-20.328 -0.563,-0.707 -1.105,-1.409 -1.632,-2.109 -0.064,-0.096 -0.138,-0.198 -0.222,-0.301 -1.164,-1.557 -2.236,-3.091 -3.229,-4.602 -3.216,-5.343 -5.544,-11.486 -6.214,-13.337 -10.417,-31.901 2.546,-56.661 5.065,-61.011 8.793,-14.923 24.186,-31.856 49.988,-36.751 0.195,-0.046 0.377,-0.101 0.573,-0.145 1.697,-0.361 4.79,-0.914 8.771,-1.178 1.5,-0.067 3.04,-0.093 4.626,-0.065 1.525,-0.01 2.953,0.016 4.268,0.063 0.39,0.027 0.729,0.04 1.029,0.044 5.023,0.233 8.144,0.762 8.144,0.762 -26.133,1.279 -39.232,13.203 -44.987,20.816 -1.304,1.622 -2.422,3.367 -3.324,5.234 -0.356,0.699 -0.515,1.097 -0.515,1.097 8.328,-7.069 19.98,-13.155 19.98,-13.155 10.449,-4.917 21.402,-7.337 33.008,-5.747 0,0 24.264,2.227 33.932,24.702 -0.417,0.317 -0.361,0.275 -0.777,0.592 -0.642,-0.518 -1.274,-1.007 -1.898,-1.474 -0.975,-0.64 -1.933,-1.336 -2.89,-2.038 -5.184,-3.431 -9.414,-5.048 -11.934,-5.788 -19.06,-4.811 -36.698,-1.232 -52.114,12.263 -7.14,6.251 -11.572,14.131 -11.759,23.917 -0.122,6.365 -0.188,12.734 -0.135,19.101 0.084,10.023 7.135,17.645 17.126,18.986 20.245,2.716 40.598,3.652 60.993,4.129 3.125,0.072 5.888,0.036 7.342,-3.306 0.025,-0.057 0.358,0.021 0.543,0.035 1.379,4.393 -0.607,9.127 -4.223,11.444"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path60"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg>
\ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_monochrome.png b/website/assets/logos/logo_with_text_monochrome.png
new file mode 100644
index 000000000..17442f55d
--- /dev/null
+++ b/website/assets/logos/logo_with_text_monochrome.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_monochrome.svg b/website/assets/logos/logo_with_text_monochrome.svg
new file mode 100644
index 000000000..4648e06c0
--- /dev/null
+++ b/website/assets/logos/logo_with_text_monochrome.svg
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.97211"
+ height="193.20036"
+ viewBox="0 0 607.97212 193.20036"
+ sodipodi:docname="logo_with_text_monochrome.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.21542963"
+ inkscape:cx="296.21626"
+ inkscape:cy="101.98009"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g14"><g
+ id="g16"
+ clip-path="url(#clipPath20)"><g
+ id="g22"
+ transform="translate(668.8995,400.2876)"><path
+ d="m 0,0 c -0.698,-5.234 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.407 -14.132,15.178 0.698,8.374 5.408,12.719 14.132,13.033 C -4.362,17.776 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.413 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.167 -6.542,-18.842 0,-8.026 2.006,-14.395 6.019,-19.105 4.012,-4.71 10.031,-7.066 18.057,-7.066 5.164,0 9.7,1.57 13.608,4.711 v -2.617 c 0,-3.141 -0.986,-6.019 -2.957,-8.636 -1.972,-2.617 -5.749,-3.925 -11.331,-3.925 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.279 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.372 V 29.31 L 0,30.88 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path24"
+ inkscape:connector-curvature="0" /></g><g
+ id="g26"
+ transform="translate(720.3033,399.9331)"><path
+ d="M 0,0 -19.986,51.176 H -37.513 L -8.457,-21.105 H 8.891 L 37.875,51.176 H 20.348 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path30"
+ inkscape:connector-curvature="0" /><g
+ id="g32"
+ transform="translate(769.7443,451.1089)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.707 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.234 1.395,-1.396 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.698 5.757,2.094 1.395,1.395 2.094,3.14 2.094,5.234 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.707 2.512,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path34"
+ inkscape:connector-curvature="0" /></g><g
+ id="g36"
+ transform="translate(816.9264,406.8296)"><path
+ d="m 0,0 c -3.315,2.477 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.643 -1.187,1.064 -1.187,2.251 0,3.559 1.186,1.309 3.541,1.962 7.065,1.962 3.56,0 6.525,-1.412 8.898,-4.239 l 8.165,9.212 c -2.374,2.338 -4.92,4.1 -7.642,5.286 -2.721,1.186 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.423 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.898 1.771,-9.91 5.313,-12.038 3.541,-2.129 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.645 0.819,-1.307 0.645,-2.615 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.093 -11.043,4.921 l -8.113,-9.161 c 2.373,-2.373 5.208,-4.265 8.505,-5.678 3.298,-1.413 7.424,-2.12 12.379,-2.12 5.199,0 9.752,1.36 13.66,4.083 3.908,2.721 5.862,6.558 5.862,11.514 C 4.972,-6.787 3.315,-2.479 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path38"
+ inkscape:connector-curvature="0" /></g><g
+ id="g40"
+ transform="translate(862.2008,394.9751)"><path
+ d="m 0,0 c -2.478,-2.67 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.334 -9.682,4.004 -2.478,2.669 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.997 2.477,2.686 5.686,4.03 9.63,4.03 4.012,0 7.257,-1.344 9.735,-4.03 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.669 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.548 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.573 2.546,-13.915 7.641,-19.025 5.095,-5.113 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.547 19.051,7.642 5.095,5.093 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.094 -11.445,7.642 -19.051,7.642"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path42"
+ inkscape:connector-curvature="0" /></g><g
+ id="g44"
+ transform="translate(911.9489,431.1675)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.92 -13.66,-5.758 v 5.81 L -28.315,-1.519 V -52.34 h 14.655 v 32.19 c 2.372,4.92 5.338,7.501 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path46"
+ inkscape:connector-curvature="0" /></g><g
+ id="g48"
+ transform="translate(548.2423,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.051 2.18,-28.09 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.247,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.34,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.543,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.524,-5.837 0,0 24.645,2.263 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.073 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg>
\ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_dark-1024.png b/website/assets/logos/logo_with_text_on_dark-1024.png
new file mode 100644
index 000000000..a02a9014b
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark-128.png b/website/assets/logos/logo_with_text_on_dark-128.png
new file mode 100644
index 000000000..efae725b8
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark-128.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark-16.png b/website/assets/logos/logo_with_text_on_dark-16.png
new file mode 100644
index 000000000..a6069f98f
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark-16.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark.png b/website/assets/logos/logo_with_text_on_dark.png
new file mode 100644
index 000000000..24de18c11
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark.svg b/website/assets/logos/logo_with_text_on_dark.svg
new file mode 100644
index 000000000..52d8e52da
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark.svg
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.97211"
+ height="193.19984"
+ viewBox="0 0 607.97212 193.19985"
+ sodipodi:docname="logo_with_text_on_dark.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.21542963"
+ inkscape:cx="296.21626"
+ inkscape:cy="101.97992"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.96254)"><g
+ id="g14"><g
+ id="g16"
+ clip-path="url(#clipPath20)"><g
+ id="g22"
+ transform="translate(668.8995,400.2874)"><path
+ d="m 0,0 c -0.698,-5.233 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.408 -14.132,15.178 0.698,8.375 5.408,12.719 14.132,13.033 C -4.362,17.777 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.414 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.166 -6.542,-18.842 0,-8.026 2.006,-14.394 6.019,-19.105 4.012,-4.71 10.031,-7.065 18.057,-7.065 5.164,0 9.7,1.569 13.608,4.71 v -2.616 c 0,-3.141 -0.986,-6.02 -2.957,-8.636 -1.972,-2.618 -5.749,-3.926 -11.331,-3.926 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.28 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.373 V 29.31 L 0,30.88 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path24"
+ inkscape:connector-curvature="0" /></g><g
+ id="g26"
+ transform="translate(720.3033,399.9339)"><path
+ d="m 0,0 -19.986,51.175 h -17.527 l 29.056,-72.28 H 8.891 l 28.984,72.28 H 20.348 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path30"
+ inkscape:connector-curvature="0" /><g
+ id="g32"
+ transform="translate(769.7443,451.1087)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.706 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.233 1.395,-1.397 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.697 5.757,2.094 1.395,1.394 2.094,3.139 2.094,5.233 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.706 2.512,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path34"
+ inkscape:connector-curvature="0" /></g><g
+ id="g36"
+ transform="translate(816.9264,406.8294)"><path
+ d="m 0,0 c -3.315,2.478 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.644 -1.187,1.063 -1.187,2.25 0,3.558 1.186,1.309 3.541,1.963 7.065,1.963 3.56,0 6.525,-1.413 8.898,-4.239 l 8.165,9.211 c -2.374,2.338 -4.92,4.101 -7.642,5.286 -2.721,1.187 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.422 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.897 1.771,-9.91 5.313,-12.038 3.541,-2.128 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.644 0.819,-1.308 0.645,-2.616 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.092 -11.043,4.92 l -8.113,-9.16 c 2.373,-2.374 5.208,-4.266 8.505,-5.678 3.298,-1.414 7.424,-2.121 12.379,-2.121 5.199,0 9.752,1.361 13.66,4.083 3.908,2.722 5.862,6.559 5.862,11.515 C 4.972,-6.787 3.315,-2.478 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path38"
+ inkscape:connector-curvature="0" /></g><g
+ id="g40"
+ transform="translate(862.2008,394.9749)"><path
+ d="m 0,0 c -2.478,-2.669 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.335 -9.682,4.004 -2.478,2.67 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.998 2.477,2.685 5.686,4.029 9.63,4.029 4.012,0 7.257,-1.344 9.735,-4.029 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.67 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.547 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.572 2.546,-13.914 7.641,-19.025 5.095,-5.112 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.548 19.051,7.642 5.095,5.094 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.095 -11.445,7.642 -19.051,7.642"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path42"
+ inkscape:connector-curvature="0" /></g><g
+ id="g44"
+ transform="translate(911.9489,431.1673)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.919 -13.66,-5.757 v 5.81 l -14.655,-1.571 v -50.821 h 14.655 v 32.189 c 2.372,4.92 5.338,7.502 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path46"
+ inkscape:connector-curvature="0" /></g><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_dark_full-1024.png b/website/assets/logos/logo_with_text_on_dark_full-1024.png
new file mode 100644
index 000000000..eb2e63981
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full-128.png b/website/assets/logos/logo_with_text_on_dark_full-128.png
new file mode 100644
index 000000000..4ed21e5cb
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full-128.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full-16.png b/website/assets/logos/logo_with_text_on_dark_full-16.png
new file mode 100644
index 000000000..d3968da5e
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full-16.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full.png b/website/assets/logos/logo_with_text_on_dark_full.png
new file mode 100644
index 000000000..21feea356
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full.svg b/website/assets/logos/logo_with_text_on_dark_full.svg
new file mode 100644
index 000000000..017e72414
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full.svg
@@ -0,0 +1,120 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="786.19244"
+ height="334.21716"
+ viewBox="0 0 786.19246 334.21718"
+ sodipodi:docname="logo_with_text_on_dark_full.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="567"
+ inkscape:window-height="462"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.3046635"
+ inkscape:cx="541.19762"
+ inkscape:cy="67.525134"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-524.53324,714.8519)"><path
+ d="M 393.39994,285.47606 H 983.04431 V 536.13894 H 393.39994 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:0.36289462"
+ id="path12"
+ inkscape:connector-curvature="0" /><g
+ id="g14"><g
+ id="g16"
+ clip-path="url(#clipPath20)"><g
+ id="g22"
+ transform="translate(668.8995,400.2874)"><path
+ d="m 0,0 c -0.698,-5.233 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.408 -14.132,15.178 0.698,8.375 5.408,12.719 14.132,13.033 C -4.362,17.777 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.414 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.166 -6.542,-18.842 0,-8.026 2.006,-14.394 6.019,-19.105 4.012,-4.71 10.031,-7.065 18.057,-7.065 5.164,0 9.7,1.569 13.608,4.71 v -2.616 c 0,-3.141 -0.986,-6.02 -2.957,-8.636 -1.972,-2.618 -5.749,-3.926 -11.331,-3.926 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.28 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.373 V 29.31 L 0,30.88 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path24"
+ inkscape:connector-curvature="0" /></g><g
+ id="g26"
+ transform="translate(720.3033,399.9339)"><path
+ d="m 0,0 -19.986,51.175 h -17.527 l 29.056,-72.28 H 8.891 l 28.984,72.28 H 20.348 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path30"
+ inkscape:connector-curvature="0" /><g
+ id="g32"
+ transform="translate(769.7443,451.1087)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.706 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.233 1.395,-1.397 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.697 5.757,2.094 1.395,1.394 2.094,3.139 2.094,5.233 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.706 2.512,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path34"
+ inkscape:connector-curvature="0" /></g><g
+ id="g36"
+ transform="translate(816.9264,406.8294)"><path
+ d="m 0,0 c -3.315,2.478 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.644 -1.187,1.063 -1.187,2.25 0,3.558 1.186,1.309 3.541,1.963 7.065,1.963 3.56,0 6.525,-1.413 8.898,-4.239 l 8.165,9.211 c -2.374,2.338 -4.92,4.101 -7.642,5.286 -2.721,1.187 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.422 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.897 1.771,-9.91 5.313,-12.038 3.541,-2.128 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.644 0.819,-1.308 0.645,-2.616 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.092 -11.043,4.92 l -8.113,-9.16 c 2.373,-2.374 5.208,-4.266 8.505,-5.678 3.298,-1.414 7.424,-2.121 12.379,-2.121 5.199,0 9.752,1.361 13.66,4.083 3.908,2.722 5.862,6.559 5.862,11.515 C 4.972,-6.787 3.315,-2.478 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path38"
+ inkscape:connector-curvature="0" /></g><g
+ id="g40"
+ transform="translate(862.2008,394.9749)"><path
+ d="m 0,0 c -2.478,-2.669 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.335 -9.682,4.004 -2.478,2.67 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.998 2.477,2.685 5.686,4.029 9.63,4.029 4.012,0 7.257,-1.344 9.735,-4.029 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.67 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.547 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.572 2.546,-13.914 7.641,-19.025 5.095,-5.112 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.548 19.051,7.642 5.095,5.094 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.095 -11.445,7.642 -19.051,7.642"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path42"
+ inkscape:connector-curvature="0" /></g><g
+ id="g44"
+ transform="translate(911.9489,431.1673)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.919 -13.66,-5.757 v 5.81 l -14.655,-1.571 v -50.821 h 14.655 v 32.189 c 2.372,4.92 5.338,7.502 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path46"
+ inkscape:connector-curvature="0" /></g><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_white.png b/website/assets/logos/logo_with_text_on_white.png
new file mode 100644
index 000000000..bf420a057
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_white.svg b/website/assets/logos/logo_with_text_on_white.svg
new file mode 100644
index 000000000..4275efe83
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white.svg
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.97211"
+ height="193.20036"
+ viewBox="0 0 607.97212 193.20037"
+ sodipodi:docname="logo_with_text_on_white.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.43085925"
+ inkscape:cx="325.455"
+ inkscape:cy="50.910097"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g12"><g
+ id="g14"
+ clip-path="url(#clipPath18)"><g
+ id="g20"
+ transform="translate(668.8995,400.2876)"><path
+ d="m 0,0 c -0.698,-5.234 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.407 -14.132,15.178 0.698,8.374 5.408,12.719 14.132,13.033 C -4.362,17.776 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.413 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.167 -6.542,-18.842 0,-8.026 2.006,-14.395 6.019,-19.105 4.012,-4.71 10.031,-7.066 18.057,-7.066 5.164,0 9.7,1.57 13.608,4.711 v -2.617 c 0,-3.141 -0.986,-6.019 -2.957,-8.636 -1.972,-2.617 -5.749,-3.925 -11.331,-3.925 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.279 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.372 V 29.31 L 0,30.88 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path22"
+ inkscape:connector-curvature="0" /></g><g
+ id="g24"
+ transform="translate(720.3033,399.9331)"><path
+ d="M 0,0 -19.986,51.176 H -37.513 L -8.457,-21.105 H 8.891 L 37.875,51.176 H 20.348 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path26"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /><g
+ id="g30"
+ transform="translate(769.7443,451.1089)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.707 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.234 1.395,-1.396 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.698 5.757,2.094 1.395,1.395 2.094,3.14 2.094,5.234 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.707 2.512,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path32"
+ inkscape:connector-curvature="0" /></g><g
+ id="g34"
+ transform="translate(816.9264,406.8296)"><path
+ d="m 0,0 c -3.315,2.477 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.643 -1.187,1.064 -1.187,2.251 0,3.559 1.186,1.309 3.541,1.962 7.065,1.962 3.56,0 6.525,-1.412 8.898,-4.239 l 8.165,9.212 c -2.374,2.338 -4.92,4.1 -7.642,5.286 -2.721,1.186 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.423 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.898 1.771,-9.91 5.313,-12.038 3.541,-2.129 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.645 0.819,-1.307 0.645,-2.615 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.093 -11.043,4.921 l -8.113,-9.161 c 2.373,-2.373 5.208,-4.265 8.505,-5.678 3.298,-1.413 7.424,-2.12 12.379,-2.12 5.199,0 9.752,1.36 13.66,4.083 3.908,2.721 5.862,6.558 5.862,11.514 C 4.972,-6.787 3.315,-2.479 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path36"
+ inkscape:connector-curvature="0" /></g><g
+ id="g38"
+ transform="translate(862.2008,394.9751)"><path
+ d="m 0,0 c -2.478,-2.67 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.334 -9.682,4.004 -2.478,2.669 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.997 2.477,2.686 5.686,4.03 9.63,4.03 4.012,0 7.257,-1.344 9.735,-4.03 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.669 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.548 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.573 2.546,-13.915 7.641,-19.025 5.095,-5.113 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.547 19.051,7.642 5.095,5.093 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.094 -11.445,7.642 -19.051,7.642"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path40"
+ inkscape:connector-curvature="0" /></g><g
+ id="g42"
+ transform="translate(911.9489,431.1675)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.92 -13.66,-5.758 v 5.81 L -28.315,-1.519 V -52.34 h 14.655 v 32.19 c 2.372,4.92 5.338,7.501 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path44"
+ inkscape:connector-curvature="0" /></g><g
+ id="g46"
+ transform="translate(548.2428,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.742,-0.023 8.417,0.007 2.221,0.025 4.442,0.102 6.663,0.175 4.79,0.154 4.818,0.165 5.881,-4.582 C 42.434,52.544 41.469,38.505 39.11,24.477 38.859,22.985 38.409,21.521 38.246,20.023 37.05,9.081 30.089,3.645 20.164,1.097 13.786,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.248,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.341,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.542,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.525,-5.837 0,0 24.644,2.263 34.463,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.651,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.948,4.193 3.173,0.073 5.98,0.037 7.457,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_white_bordered.png b/website/assets/logos/logo_with_text_on_white_bordered.png
new file mode 100644
index 000000000..bd1a1e4b7
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white_bordered.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_white_bordered.svg b/website/assets/logos/logo_with_text_on_white_bordered.svg
new file mode 100644
index 000000000..08125629d
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white_bordered.svg
@@ -0,0 +1,122 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.64587"
+ height="207.92123"
+ viewBox="0 0 607.64589 207.92124"
+ sodipodi:docname="logo_with_text_on_white_bordered.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.21542963"
+ inkscape:cx="298.55736"
+ inkscape:cy="108.65533"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-612.10927,647.00852)"><g
+ id="g12"><g
+ id="g14"
+ clip-path="url(#clipPath18)"><g
+ id="g20"
+ transform="translate(670.7226,400.7367)"><path
+ d="m 0,0 c -0.687,-5.153 -4.294,-8.246 -10.821,-9.275 -8.933,0.342 -13.571,5.324 -13.914,14.943 0.687,8.245 5.325,12.522 13.914,12.832 C -4.294,17.503 -0.687,14.12 0,8.348 Z m 0,25.715 c -2.404,2.782 -6.527,4.345 -12.367,4.688 -7.902,0 -14,-2.404 -18.294,-7.214 -4.295,-4.81 -6.442,-10.994 -6.442,-18.55 0,-7.903 1.976,-14.172 5.927,-18.811 3.949,-4.638 9.876,-6.956 17.778,-6.956 5.085,0 9.55,1.545 13.398,4.638 v -2.577 c 0,-3.092 -0.971,-5.926 -2.911,-8.502 -1.942,-2.577 -5.66,-3.866 -11.156,-3.866 -5.395,0.344 -9.981,1.89 -13.76,4.638 l -6.699,-7.729 c 5.496,-5.496 13.226,-8.349 23.19,-8.555 l 3.091,0.309 c 6.527,0.275 11.938,2.834 16.233,7.678 4.294,4.844 6.441,10.873 6.441,18.088 V 28.857 L 0,30.403 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path22"
+ inkscape:connector-curvature="0" /></g><g
+ id="g24"
+ transform="translate(721.3339,400.388)"><path
+ d="M 0,0 -19.678,50.387 H -36.935 L -8.326,-20.779 H 8.754 L 37.291,50.387 H 20.034 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path26"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.901,379.609 h 14.429 v 51.583 h -14.429 z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /><g
+ id="g30"
+ transform="translate(770.0126,450.7747)"><path
+ d="m 0,0 c -2.337,0 -4.191,-0.696 -5.565,-2.088 -1.375,-1.391 -2.062,-3.117 -2.062,-5.179 0,-2.061 0.687,-3.779 2.062,-5.153 1.374,-1.374 3.263,-2.061 5.669,-2.061 2.404,0 4.293,0.687 5.667,2.061 1.375,1.374 2.062,3.092 2.062,5.153 0,2.062 -0.687,3.788 -2.062,5.179 C 4.397,-0.696 2.474,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path32"
+ inkscape:connector-curvature="0" /></g><g
+ id="g34"
+ transform="translate(816.4667,407.1781)"><path
+ d="m 0,0 c -3.264,2.438 -7.455,3.898 -12.573,4.381 -3.471,0.686 -5.789,1.554 -6.957,2.601 -1.168,1.048 -1.168,2.216 0,3.504 1.168,1.289 3.486,1.933 6.957,1.933 3.504,0 6.424,-1.391 8.76,-4.174 l 8.039,9.069 c -2.336,2.302 -4.844,4.038 -7.524,5.206 -2.68,1.167 -6.441,1.751 -11.285,1.751 -4.192,0 -8.271,-1.4 -12.239,-4.199 -3.968,-2.801 -5.951,-6.863 -5.951,-12.187 0,-5.807 1.742,-9.758 5.229,-11.854 3.486,-2.094 7.326,-3.383 11.518,-3.863 4.638,-0.447 7.36,-1.315 8.168,-2.604 0.807,-1.288 0.635,-2.576 -0.515,-3.864 -1.152,-1.289 -3.703,-1.932 -7.653,-1.932 -4.913,0.446 -8.537,2.06 -10.873,4.843 l -7.988,-9.017 c 2.336,-2.337 5.127,-4.201 8.374,-5.592 3.247,-1.392 7.309,-2.087 12.187,-2.087 5.119,0 9.602,1.339 13.45,4.02 3.848,2.679 5.772,6.458 5.772,11.336 C 4.896,-6.683 3.264,-2.44 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path36"
+ inkscape:connector-curvature="0" /></g><g
+ id="g38"
+ transform="translate(861.0429,395.5062)"><path
+ d="m 0,0 c -2.439,-2.628 -5.616,-3.942 -9.533,-3.942 -3.916,0 -7.095,1.314 -9.533,3.942 -2.44,2.628 -3.66,5.9 -3.66,9.816 0,3.917 1.22,7.198 3.66,9.843 2.438,2.646 5.598,3.968 9.482,3.968 3.949,0 7.145,-1.322 9.584,-3.968 C 2.439,17.014 3.659,13.733 3.659,9.816 3.659,5.9 2.439,2.628 0,0 m -9.481,36.149 c -7.491,0 -13.743,-2.507 -18.758,-7.523 -5.017,-5.017 -7.524,-11.269 -7.524,-18.757 0,-7.456 2.507,-13.7 7.524,-18.732 5.015,-5.034 11.267,-7.55 18.758,-7.55 7.489,0 13.741,2.508 18.757,7.523 5.016,5.016 7.524,11.269 7.524,18.759 0,7.488 -2.508,13.74 -7.524,18.757 -5.016,5.016 -11.268,7.523 -18.757,7.523"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path40"
+ inkscape:connector-curvature="0" /></g><g
+ id="g42"
+ transform="translate(910.0244,431.14)"><path
+ d="m 0,0 c -5.291,0 -9.773,-1.89 -13.449,-5.668 v 5.72 l -14.43,-1.546 v -50.037 h 14.43 v 31.691 c 2.335,4.845 5.255,7.387 8.76,7.627 3.126,-0.102 5.531,-0.91 7.214,-2.422 L 4.792,-0.412 C 3.28,-0.138 1.683,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path44"
+ inkscape:connector-curvature="0" /></g><g
+ id="g46"
+ transform="translate(594.4321,453.1439)"><path
+ d="m 0,0 c 0,0 -0.204,0.139 -0.45,0.277 -0.906,0.547 -1.856,1 -2.744,1.321 -1.685,0.679 -3.62,1.262 -5.761,1.738 l -2.063,2.533 c -1.793,2.283 -6.09,7.358 -13.132,13.334 -8.142,6.209 -24.212,15.045 -47.595,12.442 -4.578,-0.444 -9.077,-1.318 -13.368,-2.597 l -0.232,-0.068 c -12.978,-3.918 -24.155,-11.512 -33.24,-22.601 -0.6,-0.754 -1.179,-1.504 -1.783,-2.307 l -0.134,-0.191 -0.062,-0.074 c -1.194,-1.596 -2.36,-3.258 -3.485,-4.969 l -0.125,-0.198 c -3.559,-5.915 -6.126,-12.72 -6.85,-14.73 -11.284,-34.556 2.735,-61.502 5.669,-66.567 8.482,-14.4 23.945,-32.461 50.005,-38.975 1.42,-0.354 2.872,-0.676 4.356,-0.96 0.016,-0.004 0.036,-0.009 0.053,-0.013 6.537,-1.118 16.647,-1.928 29.969,-0.317 3.21,0.621 8.236,2.535 8.646,8.445 2.209,0.842 10.261,3.812 10.261,3.812 8.572,3.874 18.586,11.106 24.334,24.546 1.21,2.83 0.277,6.128 -2.171,7.994 l -0.202,0.152 c 0.639,1.557 1.125,3.209 1.488,4.93 l 0.019,0.009 c 0,0 0.063,0.325 0.164,0.854 0.013,0.073 0.028,0.144 0.041,0.215 0.402,2.114 1.294,6.91 1.719,10.035 0.02,0.149 0.029,0.268 0.033,0.371 2.136,14.686 2.099,26.608 -0.156,37.847 1.443,0.114 2.672,1.095 3.106,2.477 l 0.655,2.086 C 9.188,-12.066 6.243,-4.003 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(551.9267,364.2689)"><path
+ d="m 0,0 c 16.398,20.796 22.346,43.748 18.045,69.9 3.02,0 5.654,-0.022 8.288,0.007 2.187,0.025 4.374,0.101 6.56,0.172 4.716,0.152 4.743,0.163 5.79,-4.512 C 41.779,51.734 40.83,37.911 38.507,24.1 38.26,22.631 37.817,21.189 37.656,19.715 36.479,8.941 29.625,3.588 19.853,1.08 13.574,-0.531 7.209,-0.816 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(548.4282,396.6898)"><path
+ d="M 0,0 C 0,2.553 -3.404,4.255 -3.404,4.255 -0.851,5.105 0,8.51 0,8.51 0,8.51 0.851,5.105 3.404,4.255 3.404,4.255 0,2.553 0,0 m -16.835,6.354 c 0,6.638 -8.851,11.063 -8.851,11.063 6.638,2.213 8.851,11.063 8.851,11.063 0,0 2.212,-8.85 8.85,-11.063 0,0 -8.85,-4.425 -8.85,-11.063"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g><g
+ id="g58"
+ transform="translate(489.744,429.7865)"><path
+ d="m 0,0 c 0,0 -2.18,3.308 -9.585,2.026 0,0 5.894,28.296 35.737,38.209 C 26.152,40.235 -2.777,23.901 0,0 m 0.574,-32.696 -0.003,-0.003 c -1.277,-1.994 -2.778,-3.523 -4.46,-4.544 -1.492,-0.919 -3.13,-1.403 -4.67,-1.403 -0.519,0 -1.028,0.055 -1.516,0.167 l -2.174,0.5 2.15,0.593 c 1.605,0.443 3.016,1.325 4.194,2.623 0.997,1.09 1.826,2.447 2.536,4.149 1.158,2.775 1.757,6.066 1.783,9.781 -0.048,3.661 -0.673,6.945 -1.857,9.75 -0.693,1.629 -1.56,3.009 -2.58,4.099 -1.208,1.29 -2.63,2.143 -4.228,2.536 l -2.154,0.529 2.145,0.564 c 0.473,0.125 0.983,0.193 1.516,0.205 l 0.031,0.002 1.662,-0.152 c 1.029,-0.203 2.071,-0.605 3.014,-1.167 1.727,-1.015 3.25,-2.527 4.526,-4.493 2.122,-3.323 3.264,-7.421 3.305,-11.857 C 3.767,-25.296 2.653,-29.4 0.574,-32.696 M 100.951,17.69 c 0,0 -0.074,0.05 -0.223,0.136 -0.532,0.32 -1.095,0.593 -1.688,0.801 -2.997,1.223 -9.729,3.139 -21.568,2.583 -0.029,0 -0.055,0 -0.084,-0.001 C 52.1,20.798 29.559,10.754 29.559,10.754 c 0,0 1.417,1.733 3.4,3.636 0,0 10e-4,0 10e-4,10e-4 1.036,0.958 2.319,2.044 3.853,3.176 0.044,0.031 0.086,0.062 0.122,0.091 8.573,6.286 24.949,13.945 50.829,9.394 -0.615,0.698 -1.254,1.341 -1.887,1.998 l 0.057,-0.009 c 0,0 -5.858,6.886 -16.555,12.616 -0.613,0.331 -1.252,0.659 -1.91,0.985 -0.038,0.018 -0.073,0.038 -0.112,0.057 -0.067,0.033 -0.128,0.057 -0.195,0.089 -8.007,3.888 -19.263,7.05 -33.564,5.458 -3.873,-0.374 -8.015,-1.116 -12.272,-2.399 -0.007,-0.002 -0.013,-0.003 -0.02,-0.005 L 21.304,45.84 c -10,-3.018 -20.632,-9.029 -29.888,-20.328 -0.563,-0.707 -1.105,-1.409 -1.632,-2.109 -0.064,-0.096 -0.138,-0.198 -0.222,-0.301 -1.164,-1.557 -2.236,-3.091 -3.229,-4.602 -3.216,-5.343 -5.544,-11.486 -6.214,-13.337 -10.417,-31.901 2.546,-56.661 5.065,-61.011 8.793,-14.923 24.186,-31.856 49.988,-36.751 0.195,-0.046 0.377,-0.101 0.573,-0.145 1.697,-0.361 4.79,-0.914 8.771,-1.178 1.5,-0.067 3.04,-0.093 4.626,-0.065 1.525,-0.01 2.953,0.016 4.268,0.063 0.39,0.027 0.729,0.04 1.029,0.044 5.023,0.233 8.144,0.762 8.144,0.762 -26.133,1.279 -39.232,13.203 -44.987,20.816 -1.304,1.622 -2.422,3.367 -3.324,5.234 -0.356,0.699 -0.515,1.097 -0.515,1.097 8.328,-7.069 19.98,-13.155 19.98,-13.155 10.449,-4.917 21.402,-7.337 33.008,-5.747 0,0 24.264,2.227 33.932,24.702 -0.417,0.317 -0.361,0.275 -0.777,0.592 -0.642,-0.518 -1.274,-1.007 -1.898,-1.474 -0.975,-0.64 -1.933,-1.336 -2.89,-2.038 -5.184,-3.431 -9.414,-5.048 -11.934,-5.788 -19.06,-4.811 -36.698,-1.232 -52.114,12.263 -7.14,6.251 -11.572,14.131 -11.759,23.917 -0.122,6.365 -0.188,12.734 -0.135,19.101 0.084,10.023 7.135,17.645 17.126,18.986 20.245,2.716 40.598,3.652 60.993,4.129 3.125,0.072 5.888,0.036 7.342,-3.306 0.025,-0.057 0.358,0.021 0.543,0.035 1.379,4.393 -0.607,9.127 -4.223,11.444"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path60"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/powered-gvisor.png b/website/assets/logos/powered-gvisor.png
new file mode 100644
index 000000000..e00c74a33
--- /dev/null
+++ b/website/assets/logos/powered-gvisor.png
Binary files differ
diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md
new file mode 100644
index 000000000..fbdd511dd
--- /dev/null
+++ b/website/blog/2019-11-18-security-basics.md
@@ -0,0 +1,299 @@
+# gVisor Security Basics - Part 1
+
+This blog is a space for engineers and community members to share perspectives
+and deep dives on technology and design within the gVisor project. Though our
+logo suggests we're in the business of space exploration (or perhaps fighting
+sea monsters), we're actually in the business of sandboxing Linux containers.
+When we created gVisor, we had three specific goals in mind: _container-native
+security_, _resource efficiency_, and _platform portability_. To put it simply,
+gVisor provides _efficient defense-in-depth for containers anywhere_.
+
+This post addresses gVisor's _container-native security_, specifically how
+gVisor provides strong isolation between an application and the host OS. Future
+posts will address _resource efficiency_ (how gVisor preserves container
+benefits like fast starts, smaller snapshots, and less memory overhead than VMs)
+and _platform portability_ (run gVisor wherever Linux OCI containers run).
+Delivering on each of these goals requires careful security considerations and a
+robust design.
+
+## What does "sandbox" mean?
+
+gVisor allows the execution of untrusted containers, preventing them from
+adversely affecting the host. This means that the untrusted container is
+prevented from attacking or spying on either the host kernel or any other peer
+userspace processes on the host.
+
+For example, if you are a cloud container hosting service, running containers
+from different customers on the same virtual machine means that a compromise of
+one container could expose other customers' data. Properly configured, gVisor
+can provide sufficient
+isolation to allow different customers to run containers on the same host. There
+are many aspects to the proper configuration, including limiting file and
+network access, which we will discuss in future posts.
+
+## The cost of compromise
+
+gVisor was designed around the premise that any security boundary could
+potentially be compromised with enough time and resources. We tried to optimize
+for a solution that was as costly and time-consuming for an attacker as
+possible, at every layer.
+
+Consequently, gVisor was built through a combination of intentional design
+principles and specific technology choices that work together to provide the
+security isolation needed for running hostile containers on a host. We'll dig
+into it in the next section!
+
+# Design Principles
+
+gVisor was designed with some
+[common secure design principles](https://www.owasp.org/index.php/Security_by_Design_Principles)
+in mind: Defense-in-Depth, Principle of Least-Privilege, Attack Surface
+Reduction, and Secure-by-Default[^1].
+
+In general, Design Principles outline good engineering practices, but in the
+case of security, they also can be thought of as a set of tactics. In a
+real-life castle, there is no single defensive feature. Rather, there are many
+in combination: redundant walls, scattered drawbridges, small bottleneck
+entrances, moats, etc.
+
+A simplified version of the design is below
+([more detailed version](/docs/))[^2]:
+
+![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png "Simplified design of gVisor.")
+
+In order to discuss design principles, the following components are important to
+know:
+
+* runsc - binary that packages the Sentry, platform, and Gofer(s) that run
+ containers. runsc is the drop-in binary for running gVisor in Docker and
+ Kubernetes.
+* Untrusted Application - the container running in the sandbox. The terms
+    untrusted application and container are used interchangeably in this
+    article.
+* Platform Syscall Switcher - intercepts syscalls from the application and
+ passes them to the Sentry with no further handling.
+* Sentry - The "application kernel" in userspace that serves the untrusted
+ application. Each application instance has its own Sentry. The Sentry
+ handles syscalls, routes I/O to gofers, and manages memory and CPU, all in
+ userspace. The Sentry is allowed to make limited, filtered syscalls to the
+ host OS.
+* Gofer - a process that specifically handles different types of I/O for the
+ Sentry (usually disk I/O). Gofers are also allowed to make filtered syscalls
+ to the Host OS.
+* Host OS - the actual OS on which gVisor containers are running, always some
+    flavor of Linux (sorry, Windows/macOS users).
+
+It is important to emphasize what is being protected from the untrusted
+application in this diagram: the host OS and other userspace applications.
+
+In this post, we are only discussing security-related features of gVisor, and
+you might ask, "What about performance, compatibility and stability?" We will
+cover these considerations in future posts.
+
+## Defense-in-Depth
+
+For gVisor, Defense-in-Depth means each component of the software stack trusts
+the other components as little as possible.
+
+It may seem strange that we would want our own software components to distrust
+each other. But by limiting the trust between small, discrete components, each
+component is forced to defend itself against potentially malicious input. And
+when you stack these components on top of each other, you can ensure that
+multiple security barriers must be overcome by an attacker.
+
+And this leads us to how Defense-in-Depth is applied to gVisor: no single
+vulnerability should compromise the host.
+
+In the "Attacker's Advantage / Defender's Dilemma," the defender must succeed
+all the time while the attacker only needs to succeed once. Defense in Depth
+inverts this principle: once the attacker successfully compromises any given
+software component, they are immediately faced with needing to compromise a
+subsequent, distinct layer in order to move laterally or acquire more privilege.
+
+For example, the untrusted container is isolated from the Sentry. The Sentry is
+isolated from host I/O operations by serving those requests in separate
+processes called Gofers. And both the untrusted container and its associated
+Gofers are isolated from the host process that is running the sandbox.
+
+An additional benefit is that this generally leads to more robust and stable
+software, forcing interfaces to be strictly defined and tested to ensure all
+inputs are properly parsed and bounds checked.
+
+## Least-Privilege
+
+The principle of Least-Privilege implies that each software component has only
+the permissions it needs to function, and no more.
+
+Least-Privilege is applied throughout gVisor. Each component, and more
+importantly, each interface between the components, is designed so that only
+the minimum level of permission is required for it to perform its function.
+Specifically, the closer you are to the untrusted application, the less
+privilege you have.
+
+![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png "runsc components and their privileges.")
+
+This is evident in how runsc (the drop-in gVisor binary for Docker/Kubernetes)
+constructs the sandbox. The Sentry has the least privilege possible (it can't
+even open a file!). Gofers are only allowed file access, so even if a Gofer
+were compromised, the host network would remain unavailable. Only the runsc
+binary itself has full access to the host OS, and even runsc's access to the
+host OS is often limited through capabilities, chroot, and namespacing.
+
+Designing a system with Defense-in-Depth and Least-Privilege in mind encourages
+small, separate, single-purpose components, each with very restricted
+privileges.
+
+## Attack Surface Reduction
+
+There are no bugs in unwritten code. In other words, gVisor supports a feature
+if and only if it is needed to run Linux containers.
+
+### Host Application/Sentry Interface:
+
+There are a lot of things gVisor does not need to do. For example, it does not
+need to support arbitrary device drivers, nor does it need to support video
+playback. By not implementing what will not be used, we avoid introducing
+potential bugs in our code.
+
+That is not to say gVisor has limited functionality! Quite the opposite, we
+analyzed what is actually needed to run Linux containers and today the Sentry
+supports 237 syscalls[^3]<sup>,</sup>[^4], along with the range of critical
+/proc and /dev files. However, gVisor does not support every syscall in the
+Linux kernel. There are about 350 syscalls[^5] within the 5.3.11 version of the
+Linux kernel, many of which do not apply to Linux containers that typically host
+cloud-like workloads. For example, we don't support old versions of epoll
+(epoll_ctl_old, epoll_wait_old), because they are deprecated in Linux and no
+supported workloads use them.
+
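+As an illustrative sketch (hypothetical probe code, not part of gVisor), a
+workload can check whether such a syscall is implemented; for syscalls it
+chooses not to support, the Sentry returns ENOSYS:
+
+```go
+// probe_epoll_old.go - minimal sketch, amd64-only (epoll_ctl_old is
+// syscall 214 in the x86-64 table). Note that modern host kernels also
+// return ENOSYS for it, so compare results inside and outside the sandbox.
+package main
+
+import (
+	"fmt"
+
+	"golang.org/x/sys/unix"
+)
+
+func main() {
+	_, _, errno := unix.Syscall(unix.SYS_EPOLL_CTL_OLD, 0, 0, 0)
+	fmt.Println(errno) // expect "function not implemented" (ENOSYS)
+}
+```
+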
+Furthermore, any exploited vulnerabilities in the implemented syscalls (or
+Sentry code in general) only apply to gaining control of the Sentry. More on
+this in a later post.
+
+### Sentry/Host OS Interface:
+
+The Sentry's interactions with the Host OS are restricted in many ways. For
+instance, no syscall is "passed-through" from the untrusted application to the
+host OS. All syscalls are intercepted and interpreted. In the case where the
+Sentry needs to call the Host OS, we severely limit the syscalls that the Sentry
+itself is allowed to make to the host kernel[^6].
+
+For example, there are many file-system-based attacks, in which manipulation of
+files or their paths can lead to compromise of the host[^7]. As a result, the
+Sentry does not allow any syscall that creates or opens a file descriptor. All
+file descriptors must be donated to the sandbox. By disallowing open or creation
+of file descriptors, we eliminate entire categories of these file-based attacks.
+
+This does not affect functionality though. For example, during startup, runsc
+will donate FDs to the Sentry that allow for mapping STDIN/STDOUT/STDERR to the
+sandboxed application. Also the Gofer may donate an FD to the Sentry, allowing
+for direct access to some files. And most files will be remotely accessed
+through the Gofers, in which case no FDs are donated to the Sentry.
+
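+To make FD donation concrete, here is a minimal, hypothetical sketch (not
+gVisor's actual startup code) of a parent doing the privileged open and handing
+the descriptor to a child that never opens a path of its own:
+
+```go
+// donate_fd.go - sketch of fd donation on Linux.
+package main
+
+import (
+	"os"
+	"os/exec"
+)
+
+func main() {
+	f, err := os.Open("/etc/hostname") // the parent performs the open
+	if err != nil {
+		panic(err)
+	}
+	defer f.Close()
+
+	// ExtraFiles[0] becomes fd 3 in the child; the child consumes the
+	// already-open descriptor via /dev/fd/3 rather than opening a path.
+	cmd := exec.Command("cat", "/dev/fd/3")
+	cmd.ExtraFiles = []*os.File{f}
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+	if err := cmd.Run(); err != nil {
+		panic(err)
+	}
+}
+```
+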
+The Sentry itself is only allowed access to specific
+[whitelisted syscalls](https://github.com/google/gvisor/blob/master/runsc/boot/config.go).
+Without networking, the Sentry needs 53 host syscalls in order to function, and
+with networking, it uses an additional 15[^8]. By limiting the whitelist to only
+these needed syscalls, we radically reduce the amount of host OS attack surface.
+If any attempts are made to call something outside the whitelist, it is
+immediately blocked and the sandbox is killed by the Host OS.
+
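+The flavor of such a filter can be sketched with the
+[libseccomp-golang](https://github.com/seccomp/libseccomp-golang) bindings.
+gVisor actually implements its own seccomp package, and a real Go binary needs
+a much larger allowlist just to keep its runtime alive, so treat this purely as
+an illustration of the mechanism:
+
+```go
+// allowlist.go - illustrative seccomp allowlist (requires libseccomp and
+// cgo). The default action kills the thread, mirroring "anything outside
+// the whitelist kills the sandbox". Four syscalls are far too few for a
+// real Go program; gVisor's actual list also varies with configuration.
+package main
+
+import (
+	"fmt"
+
+	seccomp "github.com/seccomp/libseccomp-golang"
+)
+
+func main() {
+	filter, err := seccomp.NewFilter(seccomp.ActKill) // default: kill
+	if err != nil {
+		panic(err)
+	}
+	for _, name := range []string{"read", "write", "futex", "exit_group"} {
+		sc, err := seccomp.GetSyscallFromName(name)
+		if err != nil {
+			panic(err)
+		}
+		if err := filter.AddRule(sc, seccomp.ActAllow); err != nil {
+			panic(err)
+		}
+	}
+	if err := filter.Load(); err != nil {
+		panic(err)
+	}
+	// From here on, any syscall outside the allowlist kills the thread.
+	fmt.Println("allowlist installed")
+}
+```
+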
+### Sentry/Gofer Interface:
+
+The Sentry communicates with the Gofer through a local unix domain socket (UDS)
+via a version of the 9P protocol[^9]. The UDS file descriptor is passed to the
+sandbox during initialization and all communication between the Sentry and Gofer
+happens via 9P. We will go more into how Gofers work in future posts.
+
+### End Result
+
+So, of the 350 syscalls in the Linux kernel, the Sentry needs to implement only
+237 of them to support containers. At most, the Sentry only needs to call 68 of
+the host Linux syscalls. In other words, with gVisor, applications get the vast
+majority (and growing) of Linux container functionality for only 68 possible
+syscalls to the Host OS. Going from 350 syscalls down to 68 is a substantial
+reduction in attack surface.
+
+![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png "Reduction of Attack Surface of the Syscall Table. Note that the Sentry's Syscall Emulation Layer keeps the Containerized Process from ever calling the Host OS.")
+
+## Secure-by-default
+
+The default choice for a user should be safe. If users need to run a less secure
+configuration of the sandbox for the sake of performance or application
+compatibility, they must make the choice explicitly.
+
+An example of this might be a networking application that is
+performance-sensitive. Instead of using the safer, Go-based Netstack in the
+Sentry, the untrusted container can instead use the host Linux networking stack
+directly. However, this means the untrusted container will be directly
+interacting with the host, without the safety benefits of the sandbox. It also
+means that an attacker could directly compromise the host through this path.
+
+These less secure configurations are **not** the default. In fact, the user must
+take action to change the configuration and run in a less secure mode.
+Additionally, these actions make it very obvious that a less secure
+configuration is being used.
+
+This can be as simple as defaulting a runtime flag to the secure option. gVisor
+does this by always using its internal netstack by default. However, for
+certain performance-sensitive applications, we allow the usage of the host OS
+networking stack, but it requires the user to actively set a flag[^10].
+
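+For reference, opting in to the less secure host-networking mode looks roughly
+like the following Docker runtime configuration (a sketch; the install path
+varies, and the networking guide[^10] has the authoritative steps):
+
+```json
+{
+  "runtimes": {
+    "runsc": {
+      "path": "/usr/local/bin/runsc",
+      "runtimeArgs": ["--network=host"]
+    }
+  }
+}
+```
+
+The explicit `--network=host` argument makes the departure from the default
+netstack visible in the configuration itself.
+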
+# Technology Choices
+
+Technology choices for gVisor mainly involve things that will give us a security
+boundary.
+
+At a high level, a software boundary can describe a great many things:
+boundaries between threads, between processes, between CPU privilege levels,
+and more.
+
+Security boundaries are interfaces that are designed and built so that entire
+classes of bugs/vulnerabilities are eliminated.
+
+For example, the Sentry and Gofers are implemented using Go. Go was chosen for a
+number of the features it provided. Go is a fast, statically-typed, compiled
+language that has efficient multi-threading support, garbage collection and a
+constrained set of "unsafe" operations.
+
+Using these features enabled safe array and pointer handling. This means entire
+classes of vulnerabilities were eliminated, such as buffer overflows and
+use-after-free.
+
+Another example is our use of very strict syscall switching to ensure that the
+Sentry is always the first software component that parses and interprets the
+calls being made by the untrusted container. Here is an instance where different
+platforms use different solutions, but all of them share this common trait,
+whether it is through the use of ptrace "a la PTRACE_ATTACH"[^11] or kvm's
+ring0[^12].
+
+Finally, one of the most restrictive choices was to use seccomp to restrict the
+Sentry from being able to open or create a file descriptor on the host. All file
+I/O is required to go through Gofers. Preventing the opening or creation of file
+descriptors eliminates whole categories of bugs around file permissions
+
+# To be continued - Part 2
+
+In part 2 of this blog post, we will explore gVisor from an attacker's point of
+view. We will use it as an opportunity to examine the specific strengths and
+weaknesses of each gVisor component.
+
+We will also use it to introduce Google's Vulnerability Reward Program[^14], and
+other ways the community can contribute to help make gVisor safe, fast and
+stable.
+
+## Notes
+
+[^1]: [https://www.owasp.org/index.php/Security_by_Design_Principles](https://www.owasp.org/index.php/Security_by_Design_Principles)
+[^2]: [https://gvisor.dev/docs/architecture_guide](https://gvisor.dev/docs/architecture_guide/)
+[^3]: [https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/linux64_amd64.go](https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/syscalls.go)
+[^4]: Internally, that is; it doesn't call to the Host OS to implement them.
+    In fact, that is explicitly disallowed; more on that in the future.
+[^5]: [https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345](https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345)
+[^6]: [https://github.com/google/gvisor/tree/master/runsc/boot/filter](https://github.com/google/gvisor/tree/master/runsc/boot/filter)
+[^7]: [https://en.wikipedia.org/wiki/Dirty_COW](https://en.wikipedia.org/wiki/Dirty_COW)
+[^8]: [https://github.com/google/gvisor/blob/master/runsc/boot/config.go](https://github.com/google/gvisor/blob/master/runsc/boot/config.go)
+[^9]: [https://en.wikipedia.org/wiki/9P_(protocol)](https://en.wikipedia.org/wiki/9P_\(protocol\))
+[^10]: [https://gvisor.dev/docs/user_guide/networking/#network-passthrough](https://gvisor.dev/docs/user_guide/networking/#network-passthrough)
+[^11]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390)
+[^12]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182)
+[^13]: [https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-4557](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-4557)
+[^14]: [https://www.google.com/about/appsecurity/reward-program/index.html](https://www.google.com/about/appsecurity/reward-program/index.html)
diff --git a/website/blog/2020-04-02-networking-security.md b/website/blog/2020-04-02-networking-security.md
new file mode 100644
index 000000000..5a5e38fd7
--- /dev/null
+++ b/website/blog/2020-04-02-networking-security.md
@@ -0,0 +1,183 @@
+# gVisor Networking Security
+
+In our
+[first blog post](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/),
+we covered some secure design principles and how they guided the architecture of
+gVisor as a whole. In this post, we will cover how these principles guided the
+networking architecture of gVisor, and the tradeoffs involved. In particular, we
+will cover how these principles culminated in two networking modes, how they
+work, and the properties of each.
+
+## gVisor's security architecture in the context of networking
+
+Linux networking is complicated. The TCP protocol is over 40 years old, and has
+been repeatedly extended over the years to keep up with the rapid pace of
+network infrastructure improvements, all while maintaining compatibility. On top
+of that, Linux networking has a fairly large API surface. Linux supports
+[over 150 options](https://github.com/google/gvisor/blob/960f6a975b7e44c0efe8fd38c66b02017c4fe137/pkg/sentry/strace/socket.go#L476-L644)
+for the most common socket types alone. In fact, the net subsystem is one of the
+largest and fastest growing in Linux at approximately 1.1 million lines of code.
+For comparison, that is several times the size of the entire gVisor codebase.
+
+At the same time, networking is increasingly important. The cloud era is
+arguably about making everything a network service, and in order to make that
+work, the interconnect performance is critical. Adding networking support to
+gVisor was difficult, not just due to the inherent complexity, but also because
+it has the potential to significantly weaken gVisor's security model.
+
+As outlined in the previous blog post, gVisor's
+[secure design principles](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#design-principles)
+are:
+
+1. Defense in Depth: each component of the software stack trusts each other
+ component as little as possible.
+1. Least Privilege: each software component has only the permissions it needs
+ to function, and no more.
+1. Attack Surface Reduction: limit the surface area of the host exposed to the
+ sandbox.
+1. Secure by Default: the default choice for a user should be safe.
+
+gVisor manifests these principles as a multi-layered system. An application
+running in the sandbox interacts with the Sentry, a userspace kernel, which
+mediates all interactions with the Host OS and beyond. The Sentry is written in
+pure Go with minimal unsafe code, making it less vulnerable to buffer overflows
+and related memory bugs that can lead to a variety of compromises including code
+injection. It emulates Linux using only a minimal and audited set of Host OS
+syscalls that limit the Host OS's attack surface exposed to the Sentry itself.
+The syscall restrictions are enforced by running the Sentry with seccomp
+filters, which ensure that the Sentry can use only the expected set of
+syscalls. The Sentry runs as an unprivileged user and in namespaces, which,
+along with the seccomp filters, ensure that the Sentry is run with the Least
+Privilege required.
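+
+To make this concrete, here is a minimal sketch of installing a syscall
+allow-list from Go using the third-party
+[libseccomp-golang](https://github.com/seccomp/libseccomp-golang) package. It
+is illustrative only: it is not gVisor's filter code (which lives in
+[runsc/boot/filter](https://github.com/google/gvisor/tree/master/runsc/boot/filter)),
+and the syscall list shown is a made-up example.
+
+```go
+package main
+
+import (
+    "log"
+    "syscall"
+
+    seccomp "github.com/seccomp/libseccomp-golang"
+)
+
+func main() {
+    // Deny everything by default, failing denied syscalls with EPERM
+    // rather than killing the process.
+    deny := seccomp.ActErrno.SetReturnCode(int16(syscall.EPERM))
+    filter, err := seccomp.NewFilter(deny)
+    if err != nil {
+        log.Fatalf("creating filter: %v", err)
+    }
+    // Allow only a small, explicit set of syscalls (illustrative).
+    for _, name := range []string{"read", "write", "futex", "exit_group"} {
+        call, err := seccomp.GetSyscallFromName(name)
+        if err != nil {
+            log.Fatalf("resolving %q: %v", name, err)
+        }
+        if err := filter.AddRule(call, seccomp.ActAllow); err != nil {
+            log.Fatalf("allowing %q: %v", name, err)
+        }
+    }
+    // Load installs the filter into the kernel for this process; from here
+    // on, anything outside the allow-list fails with EPERM.
+    if err := filter.Load(); err != nil {
+        log.Fatalf("loading filter: %v", err)
+    }
+    log.Println("seccomp allow-list installed")
+}
+```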
+
+gVisor's multi-layered design provides Defense in Depth. The first layer is the
+Sentry, which does not trust the application, since the application may attack
+it or try to bypass it. The sandbox that the Sentry runs in is the second
+layer. If the Sentry were compromised, the attacker would still be in a highly
+restrictive sandbox which they must also break out of in order to compromise
+the Host OS.
+
+To enable networking functionality while preserving gVisor's security
+properties, we implemented a
+[userspace network stack](https://github.com/google/gvisor/tree/master/pkg/tcpip)
+in the Sentry, which we creatively named Netstack. Netstack is also written in
+Go, not only to avoid unsafe code in the network stack itself, but also to avoid
+a complicated and unsafe Foreign Function Interface. Having its own integrated
+network stack allows the Sentry to implement networking operations using up to
+three Host OS syscalls to read and write packets. These syscalls expose only a
+very minimal set of operations, ones the Sentry is already permitted to perform
+(either through the same or a similar syscall). Moreover, because packets
+typically come from off-host (e.g.
+the internet), the Host OS's packet processing code has received a lot of
+scrutiny, hopefully resulting in a high degree of hardening.
+
+![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png "Network and gVisor.")
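+
+As an illustration of how little host surface this requires, the sketch below
+opens a packet socket and receives whole Ethernet frames with plain `read(2)`
+calls. It is a simplified stand-alone example, not Netstack's actual endpoint
+code, and it assumes Linux, the `golang.org/x/sys/unix` package, and a process
+with `CAP_NET_RAW`.
+
+```go
+package main
+
+import (
+    "log"
+
+    "golang.org/x/sys/unix"
+)
+
+// htons converts a value to network byte order (assumes a little-endian host).
+func htons(v uint16) uint16 { return v<<8 | v>>8 }
+
+func main() {
+    // One syscall creates the packet endpoint...
+    fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+    if err != nil {
+        log.Fatalf("socket: %v", err)
+    }
+    defer unix.Close(fd)
+
+    // ...and from here on, plain reads deliver raw frames, which a
+    // userspace stack like Netstack parses entirely by itself.
+    buf := make([]byte, 65536)
+    for {
+        n, err := unix.Read(fd, buf)
+        if err != nil {
+            log.Fatalf("read: %v", err)
+        }
+        log.Printf("received a %d-byte frame", n)
+    }
+}
+```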
+
+## Writing a network stack
+
+Netstack was written from scratch specifically for gVisor. Because Netstack was
+designed and implemented to be modular, flexible and self-contained, there are
+now several other projects using Netstack in creative and exciting ways. As we
+discussed, a custom network stack has enabled a variety of security-related
+goals which would not have been possible any other way. This came at a cost
+though. Network stacks are complex and writing a new one comes with many
+challenges, mostly related to application compatibility and performance.
+
+Compatibility issues typically come in two forms: missing features, and features
+with behavior that differs from Linux (usually due to bugs). Both of these are
+inevitable in an implementation of a complex system spanning many quickly
+evolving and ambiguous standards. However, we have invested heavily in this
+area, and the vast majority of applications have no issues using Netstack. For
+example,
+[we now support setting 34 different socket options](https://github.com/google/gvisor/blob/815df2959a76e4a19f5882e40402b9bbca9e70be/pkg/sentry/socket/netstack/netstack.go#L830-L1764)
+versus
+[only 7 in our initial git commit](https://github.com/google/gvisor/blob/d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296/pkg/sentry/socket/epsocket/epsocket.go#L445-L702).
+We are continuing to make good progress in this area.
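+
+As a hypothetical sketch of what one of those options involves, the snippet
+below dispatches a `SOL_SOCKET`-level option to invented per-socket state. The
+`tcpEndpoint` type and its fields are made up for illustration; gVisor's real
+dispatch logic is in the netstack code linked above and is far more thorough.
+
+```go
+package main
+
+import (
+    "encoding/binary"
+    "fmt"
+
+    "golang.org/x/sys/unix"
+)
+
+// tcpEndpoint is an invented stand-in for a stack's per-socket state.
+type tcpEndpoint struct {
+    rcvBufSize int
+    reuseAddr  bool
+}
+
+// setSockOpt applies one SOL_SOCKET option with Linux-style validation.
+func setSockOpt(ep *tcpEndpoint, level, name int, value []byte) error {
+    if level != unix.SOL_SOCKET {
+        return unix.ENOPROTOOPT
+    }
+    if len(value) < 4 {
+        return unix.EINVAL
+    }
+    // Option values arrive as native-endian bytes; assume little-endian here.
+    v := binary.LittleEndian.Uint32(value)
+    switch name {
+    case unix.SO_RCVBUF:
+        ep.rcvBufSize = int(v)
+    case unix.SO_REUSEADDR:
+        ep.reuseAddr = v != 0
+    default:
+        return unix.ENOPROTOOPT
+    }
+    return nil
+}
+
+func main() {
+    ep := &tcpEndpoint{}
+    val := make([]byte, 4)
+    binary.LittleEndian.PutUint32(val, 1<<20)
+    if err := setSockOpt(ep, unix.SOL_SOCKET, unix.SO_RCVBUF, val); err != nil {
+        panic(err)
+    }
+    fmt.Printf("rcvBufSize=%d\n", ep.rcvBufSize)
+}
+```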
+
+Performance issues typically come from TCP behavior and packet processing speed.
+To improve our TCP behavior, we are working on implementing the full set of TCP
+RFCs. There are many RFCs which are significant to performance (e.g.
+[RACK](https://tools.ietf.org/id/draft-ietf-tcpm-rack-03.html) and
+[BBR](https://tools.ietf.org/html/draft-cardwell-iccrg-bbr-congestion-control-00))
+that we have yet to implement. This mostly affects TCP performance with
+non-ideal network conditions (e.g. cross continent connections). Faster packet
+processing mostly improves TCP performance when network conditions are very good
+(e.g. within a datacenter). Our primary strategy here is to reduce interactions
+with the Go runtime, specifically the garbage collector (GC) and scheduler. We
+are currently optimizing buffer management to reduce the amount of garbage,
+which will lower the GC cost. To reduce scheduler interactions, we are
+re-architecting the TCP implementation to use fewer goroutines. Performance
+today is good enough for most applications and we are making steady
+improvements. For example, since May of 2019, we have improved the Netstack
+runsc
+[iperf3 download benchmark](https://github.com/google/gvisor/blob/master/benchmarks/suites/network.py)
+score by roughly 15% and upload score by around 10,000X. Current numbers are
+about 17 Gbps download and about 8 Gbps upload versus about 42 Gbps and 43 Gbps
+for native (Linux) respectively.
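+
+One generic way to cut garbage on a Go packet path, shown here as a sketch
+rather than Netstack's actual buffer management, is to recycle packet buffers
+through a `sync.Pool` so that the steady-state receive loop allocates almost
+nothing:
+
+```go
+package main
+
+import (
+    "fmt"
+    "sync"
+)
+
+// bufPool recycles fixed-size packet buffers so the receive loop does not
+// create garbage (and GC work) for every packet.
+var bufPool = sync.Pool{
+    New: func() interface{} { return make([]byte, 2048) },
+}
+
+func processPacket(frame []byte) {
+    _ = frame // Parse headers, deliver payload, etc. (elided).
+}
+
+func receiveLoop(read func([]byte) (int, error)) error {
+    for {
+        buf := bufPool.Get().([]byte)
+        n, err := read(buf)
+        if err != nil {
+            bufPool.Put(buf)
+            return err
+        }
+        processPacket(buf[:n])
+        // Return the buffer for reuse instead of leaving it to the GC.
+        bufPool.Put(buf)
+    }
+}
+
+func main() {
+    packets := [][]byte{[]byte("hello"), []byte("world")}
+    i := 0
+    err := receiveLoop(func(b []byte) (int, error) {
+        if i == len(packets) {
+            return 0, fmt.Errorf("no more packets")
+        }
+        n := copy(b, packets[i])
+        i++
+        return n, nil
+    })
+    fmt.Println(err)
+}
+```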
+
+## An alternative
+
+We also offer an alternative network mode: passthrough. This name can be
+misleading as syscalls are never passed through from the app to the Host OS.
+Instead, the passthrough mode implements networking in gVisor using the Host
+OS's network stack. (This mode is called
+[hostinet](https://github.com/google/gvisor/tree/master/pkg/sentry/socket/hostinet)
+in the codebase.) Passthrough mode can improve performance for some use cases as
+the Host OS's network stack has had an enormous number of person-years poured
+into making it highly performant. However, there is a rather large downside to
+using passthrough mode: it weakens gVisor's security model by increasing the
+Host OS's Attack Surface. This is because using the Host OS's network stack
+requires the Sentry to use the Host OS's
+[Berkeley socket interface](https://en.wikipedia.org/wiki/Berkeley_sockets). The
+Berkeley socket interface is a much larger API surface than the packet interface
+that our network stack uses. When passthrough mode is in use, the Sentry is
+allowed to use
+[15 additional syscalls](https://github.com/google/gvisor/blob/b1576e533223e98ebe4bd1b82b04e3dcda8c4bf1/runsc/boot/filter/config.go#L312-L517).
+Further, this set of syscalls includes some that allow the Sentry to create file
+descriptors, something that
+[we don't normally allow](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#sentry-host-os-interface)
+as it opens up classes of file-based attacks.
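+
+For contrast, the sketch below shows the kind of call a passthrough-style stack
+makes on behalf of the application. This is a simplification, not gVisor's
+hostinet code: the point is that the Sentry itself must now be allowed to issue
+`socket(2)` and `connect(2)` to the host and to hold the resulting host file
+descriptor. The address used is just an example.
+
+```go
+package main
+
+import (
+    "log"
+
+    "golang.org/x/sys/unix"
+)
+
+func main() {
+    // In passthrough mode, the sandbox's seccomp policy must newly allow
+    // socket creation, handing the Sentry a host file descriptor.
+    fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP)
+    if err != nil {
+        log.Fatalf("socket: %v", err)
+    }
+    defer unix.Close(fd)
+
+    // connect(2) is issued by the Sentry with arguments it has already
+    // validated; the application never talks to the host directly.
+    addr := &unix.SockaddrInet4{Port: 80, Addr: [4]byte{203, 0, 113, 1}}
+    if err := unix.Connect(fd, addr); err != nil {
+        log.Fatalf("connect: %v", err)
+    }
+    log.Println("connected via the host network stack")
+}
+```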
+
+There are some networking features that we can't implement on top of the
+syscalls we consider safe (most notably those behind
+[ioctl](http://man7.org/linux/man-pages/man2/ioctl.2.html)), and these are
+therefore not supported. Because of this, we actually support fewer networking
+features in passthrough mode than we do in Netstack, reducing application
+compatibility.
+That's right: using our networking stack provides better overall application
+compatibility than using our passthrough mode.
+
+That said, gVisor with passthrough networking still provides a high level of
+isolation. Applications cannot specify host syscall arguments directly, and the
+Sentry's seccomp policy restricts its syscall use significantly more than a
+general-purpose seccomp policy would.
+
+## Secure by Default
+
+The goal of the Secure by Default principle is to make it easy to securely
+sandbox containers. Of course, disabling network access entirely is the most
+secure option, but that is not practical for most applications. To make gVisor
+Secure by Default, we have made Netstack the default networking mode in gVisor
+as we believe that it provides significantly better isolation. For this reason,
+we strongly caution users against changing the default unless Netstack flat out
+won't work for them. The passthrough mode option is still provided, but we want
+users to make an informed decision when selecting it.
+
+Another way in which gVisor makes it easy to securely sandbox containers is by
+allowing applications to run unmodified, with no special configuration needed.
+In order to do this, gVisor needs to support all of the features and syscalls
+that applications use. Neither seccomp nor gVisor's passthrough mode can do this
+as applications commonly use syscalls which are too dangerous to be included in
+a secure policy. Even if this dream isn't fully realized today, gVisor's
+architecture with Netstack makes this possible.
+
+## Give Netstack a Try
+
+If you haven't already, try running a workload in gVisor with Netstack. You can
+find instructions on how to get started in our
+[Quick Start](/docs/user_guide/quick_start/docker/). We want to hear about both
+your successes and any issues you encounter. We welcome your contributions,
+whether verbal feedback or code, via our
+[Gitter channel](https://gitter.im/gvisor/community),
+[email list](https://groups.google.com/forum/#!forum/gvisor-users),
+[issue tracker](https://gvisor.dev/issue/new), and
+[GitHub repository](https://github.com/google/gvisor). Feel free to express
+interest in an [open issue](https://gvisor.dev/issue/), or reach out if you
+aren't sure where to start.
diff --git a/website/blog/BUILD b/website/blog/BUILD
new file mode 100644
index 000000000..01c1f5a6e
--- /dev/null
+++ b/website/blog/BUILD
@@ -0,0 +1,37 @@
+load("//website:defs.bzl", "doc", "docs")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+exports_files(["index.html"])
+
+doc(
+ name = "security_basics",
+ src = "2019-11-18-security-basics.md",
+ authors = [
+ "jsprad",
+ "zkoopmans",
+ ],
+ layout = "post",
+ permalink = "/blog/2019/11/18/gvisor-security-basics-part-1/",
+)
+
+doc(
+ name = "networking_security",
+ src = "2020-04-02-networking-security.md",
+ authors = [
+ "igudger",
+ ],
+ layout = "post",
+ permalink = "/blog/2020/04/02/gvisor-networking-security/",
+)
+
+docs(
+ name = "posts",
+ deps = [
+ ":" + rule
+ for rule in existing_rules()
+ ],
+)
diff --git a/website/blog/index.html b/website/blog/index.html
new file mode 100644
index 000000000..5c67c95fc
--- /dev/null
+++ b/website/blog/index.html
@@ -0,0 +1,22 @@
+---
+title: Blog
+layout: blog
+feed: true
+pagination:
+ enabled: true
+---
+
+{% for post in paginator.posts %}
+<div>
+ <h2><a href="{{ post.url }}">{{ post.title }}</a></h2>
+ <div class="blog-meta">
+ {% include byline.html authors=post.authors date=post.date %}
+ </div>
+ <p>{{ post.excerpt | strip_html }}</p>
+ <p><a href="{{ post.url }}">Full Post &raquo;</a></p>
+</div>
+{% endfor %}
+
+{% if paginator.total_pages > 1 %}
+{% include paginator.html %}
+{% endif %}
diff --git a/website/cmd/server/BUILD b/website/cmd/server/BUILD
new file mode 100644
index 000000000..6b5a08f0d
--- /dev/null
+++ b/website/cmd/server/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "server",
+ srcs = ["main.go"],
+ pure = True,
+ visibility = ["//website:__pkg__"],
+)
diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go
new file mode 100644
index 000000000..c401b6abd
--- /dev/null
+++ b/website/cmd/server/main.go
@@ -0,0 +1,215 @@
+// Copyright 2019 The gVisor Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Server is the main gvisor.dev binary.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "regexp"
+ "strings"
+)
+
+var redirects = map[string]string{
+ // GitHub redirects.
+ "/change": "https://github.com/google/gvisor",
+ "/issue": "https://github.com/google/gvisor/issues",
+ "/issue/new": "https://github.com/google/gvisor/issues/new",
+ "/pr": "https://github.com/google/gvisor/pulls",
+
+ // For links.
+ "/faq": "/docs/user_guide/faq/",
+
+ // From 2020-05-12 to 2020-06-30, the FAQ URL was uppercase. Redirect that
+ // back to maintain any links.
+ "/docs/user_guide/FAQ/": "/docs/user_guide/faq/",
+
+ // Redirects to compatibility docs.
+ "/c": "/docs/user_guide/compatibility/",
+ "/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/",
+
+ // Redirect for old URLs.
+ "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/",
+ "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/",
+
+ // Deprecated, but links continue to work.
+ "/cl": "https://gvisor-review.googlesource.com",
+}
+
+var prefixHelpers = map[string]string{
+ "change": "https://github.com/google/gvisor/commit/%s",
+ "issue": "https://github.com/google/gvisor/issues/%s",
+ "pr": "https://github.com/google/gvisor/pull/%s",
+
+ // Redirects to compatibility docs.
+ "c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/#%s",
+
+ // Deprecated, but links continue to work.
+ "cl": "https://gvisor-review.googlesource.com/c/gvisor/+/%s",
+}
+
+var (
+ validID = regexp.MustCompile(`^[A-Za-z0-9-]*/?$`)
+ goGetHTML5 = `<!doctype html><html><head><meta charset=utf-8>
+<meta name="go-import" content="gvisor.dev/gvisor git https://github.com/google/gvisor">
+<meta name="go-import" content="gvisor.dev/website git https://github.com/google/gvisor-website">
+<title>Go-get</title></head><body></body></html>`
+)
+
+// cronHandler wraps an http.Handler to check that the request is from the App
+// Engine Cron service.
+// See: https://cloud.google.com/appengine/docs/standard/go112/scheduling-jobs-with-cron-yaml#validating_cron_requests
+func cronHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Appengine-Cron") != "true" {
+ http.NotFound(w, r)
+ return
+ }
+ // Fallthrough.
+ h.ServeHTTP(w, r)
+ })
+}
+
+// wrappedHandler wraps an http.Handler.
+//
+// If the query parameters include go-get=1, then we redirect to a single
+// static page that allows us to serve arbitrary Go packages.
+func wrappedHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gg, ok := r.URL.Query()["go-get"]
+ if ok && len(gg) == 1 && gg[0] == "1" {
+ // Serve a trivial html page.
+ w.Write([]byte(goGetHTML5))
+ return
+ }
+ // Fallthrough.
+ h.ServeHTTP(w, r)
+ })
+}
+
+// redirectWithQuery redirects to the given target url preserving query parameters.
+func redirectWithQuery(w http.ResponseWriter, r *http.Request, target string) {
+ url := target
+ if qs := r.URL.RawQuery; qs != "" {
+ url += "?" + qs
+ }
+ http.Redirect(w, r, url, http.StatusFound)
+}
+
+// hostRedirectHandler redirects the www. domain to the naked domain.
+func hostRedirectHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.HasPrefix(r.Host, "www.") {
+ // Redirect to the naked domain.
+ r.URL.Scheme = "https" // Assume https.
+ r.URL.Host = r.Host[4:] // Remove the 'www.'
+ http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
+ return
+ }
+
+ if *projectID != "" && r.Host == *projectID+".appspot.com" && *customHost != "" {
+ // Redirect to the custom domain.
+ r.URL.Scheme = "https" // Assume https.
+ r.URL.Host = *customHost
+ http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
+ return
+ }
+ h.ServeHTTP(w, r)
+ })
+}
+
+// prefixRedirectHandler returns a handler that redirects to the given formatted url.
+func prefixRedirectHandler(prefix, baseURL string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if p := r.URL.Path; p == prefix {
+ // Redirect /prefix/ to /prefix.
+ http.Redirect(w, r, p[:len(p)-1], http.StatusFound)
+ return
+ }
+ id := r.URL.Path[len(prefix):]
+ if !validID.MatchString(id) {
+ http.Error(w, "Not found", http.StatusNotFound)
+ return
+ }
+ target := fmt.Sprintf(baseURL, id)
+ redirectWithQuery(w, r, target)
+ })
+}
+
+// redirectHandler returns a handler that redirects to the given url.
+func redirectHandler(target string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ redirectWithQuery(w, r, target)
+ })
+}
+
+// registerRedirects registers redirect http handlers.
+func registerRedirects(mux *http.ServeMux) {
+ if mux == nil {
+ mux = http.DefaultServeMux
+ }
+
+ for prefix, baseURL := range prefixHelpers {
+ p := "/" + prefix + "/"
+ mux.Handle(p, hostRedirectHandler(wrappedHandler(prefixRedirectHandler(p, baseURL))))
+ }
+
+ for path, redirect := range redirects {
+ mux.Handle(path, hostRedirectHandler(wrappedHandler(redirectHandler(redirect))))
+ }
+}
+
+// registerStatic registers static file handlers.
+func registerStatic(mux *http.ServeMux, staticDir string) {
+ if mux == nil {
+ mux = http.DefaultServeMux
+ }
+ mux.Handle("/", hostRedirectHandler(wrappedHandler(http.FileServer(http.Dir(staticDir)))))
+}
+
+func envFlagString(name, def string) string {
+ if val := os.Getenv(name); val != "" {
+ return val
+ }
+ return def
+}
+
+var (
+ addr = flag.String("http", envFlagString("HTTP", ":"+envFlagString("PORT", "8080")), "HTTP service address")
+ staticDir = flag.String("static-dir", envFlagString("STATIC_DIR", "_site"), "static files directory")
+
+ // Uses the standard GOOGLE_CLOUD_PROJECT environment variable set by App Engine.
+ projectID = flag.String("project-id", envFlagString("GOOGLE_CLOUD_PROJECT", ""), "The App Engine project ID.")
+ customHost = flag.String("custom-domain", envFlagString("CUSTOM_DOMAIN", "gvisor.dev"), "The application's custom domain.")
+)
+
+func main() {
+ flag.Parse()
+
+ registerRedirects(nil)
+ registerStatic(nil, *staticDir)
+
+ log.Printf("Listening on %s...", *addr)
+ log.Fatal(http.ListenAndServe(*addr, nil))
+}
diff --git a/website/cmd/syscalldocs/BUILD b/website/cmd/syscalldocs/BUILD
new file mode 100644
index 000000000..c5a0ed7fe
--- /dev/null
+++ b/website/cmd/syscalldocs/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "syscalldocs",
+ srcs = ["main.go"],
+ visibility = ["//website:__pkg__"],
+)
diff --git a/website/cmd/syscalldocs/main.go b/website/cmd/syscalldocs/main.go
new file mode 100644
index 000000000..327537214
--- /dev/null
+++ b/website/cmd/syscalldocs/main.go
@@ -0,0 +1,211 @@
+// Copyright 2019 The gVisor Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary syscalldocs generates system call markdown.
+package main
+
+import (
+ "bufio"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "text/template"
+)
+
+// CompatibilityInfo is the collection of all information.
+type CompatibilityInfo map[string]map[string]ArchInfo
+
+// ArchInfo is the compatibility doc for an architecture.
+type ArchInfo struct {
+ // Syscalls maps syscall number for the architecture to the doc.
+ Syscalls map[uintptr]SyscallDoc `json:"syscalls"`
+}
+
+// SyscallDoc represents a single item of syscall documentation.
+type SyscallDoc struct {
+ Name string `json:"name"`
+ Support string `json:"support"`
+ Note string `json:"note,omitempty"`
+ URLs []string `json:"urls,omitempty"`
+}
+
+var mdTemplate = template.Must(template.New("out").Parse(`---
+title: {{.Title}}
+description: Syscall Compatibility Reference Documentation for {{.OS}}/{{.Arch}}
+layout: docs
+category: Compatibility
+weight: 50
+permalink: /docs/user_guide/compatibility/{{.OS}}/{{.Arch}}/
+---
+
+This table is a reference of {{.OS}} syscalls for the {{.Arch}} architecture and
+their compatibility status in gVisor. gVisor does not support all syscalls and
+some syscalls may have a partial implementation.
+
+This page is automatically generated from the source code.
+
+Of {{.Total}} syscalls, {{.Supported}} syscalls have a full or partial
+implementation. There are currently {{.Unsupported}} unsupported
+syscalls. {{if .Undocumented}}{{.Undocumented}} syscalls are not yet documented.{{end}}
+
+<table>
+ <thead>
+ <tr>
+ <th>#</th>
+ <th>Name</th>
+ <th>Support</th>
+ <th>Notes</th>
+ </tr>
+ </thead>
+ <tbody>
+ {{range $i, $syscall := .Syscalls}}
+ <tr>
+ <td><a class="doc-table-anchor" id="{{.Name}}"></a>{{.Number}}</td>
+ <td><a href="http://man7.org/linux/man-pages/man2/{{.Name}}.2.html" target="_blank" rel="noopener">{{.Name}}</a></td>
+ <td>{{.Support}}</td>
+ <td>{{.Note}} {{range $i, $url := .URLs}}<br/>See: <a href="{{.}}">{{.}}</a>{{end}}</td>
+ </tr>
+ {{end}}
+ </tbody>
+</table>
+`))
+
+// Fatalf writes a message to stderr and exits with error code 1.
+func Fatalf(format string, a ...interface{}) {
+ fmt.Fprintf(os.Stderr, format, a...)
+ os.Exit(1)
+}
+
+func main() {
+ inputFlag := flag.String("in", "-", "File to input ('-' for stdin)")
+ outputDir := flag.String("out", ".", "Directory to output files.")
+
+ flag.Parse()
+
+ var input io.Reader
+ if *inputFlag == "-" {
+ input = os.Stdin
+ } else {
+ i, err := os.Open(*inputFlag)
+ if err != nil {
+ Fatalf("Error opening %q: %v", *inputFlag, err)
+ }
+ input = i
+ }
+ input = bufio.NewReader(input)
+
+ var info CompatibilityInfo
+ d := json.NewDecoder(input)
+ if err := d.Decode(&info); err != nil {
+ Fatalf("Error reading json: %v", err)
+ }
+
+ weight := 0
+ for osName, osInfo := range info {
+ for archName, archInfo := range osInfo {
+ outDir := filepath.Join(*outputDir, osName)
+ outFile := filepath.Join(outDir, archName+".md")
+
+ if err := os.MkdirAll(outDir, 0755); err != nil {
+ Fatalf("Error creating directory %q: %v", *outputDir, err)
+ }
+
+ f, err := os.OpenFile(outFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
+ if err != nil {
+ Fatalf("Error opening file %q: %v", outFile, err)
+ }
+ defer f.Close()
+
+ weight += 10
+ data := struct {
+ Title string
+ OS string
+ Arch string
+ Weight int
+ Total int
+ Supported int
+ Unsupported int
+ Undocumented int
+ Syscalls []struct {
+ Name string
+ Number uintptr
+ Support string
+ Note string
+ URLs []string
+ }
+ }{
+ Title: strings.Title(osName) + "/" + archName,
+ OS: osName,
+ Arch: archName,
+ Weight: weight,
+ Total: 0,
+ Supported: 0,
+ Unsupported: 0,
+ Undocumented: 0,
+ Syscalls: []struct {
+ Name string
+ Number uintptr
+ Support string
+ Note string
+ URLs []string
+ }{},
+ }
+
+ for num, s := range archInfo.Syscalls {
+ switch s.Support {
+ case "Full Support", "Partial Support":
+ data.Supported++
+ case "Unimplemented":
+ data.Unsupported++
+ default:
+ data.Undocumented++
+ }
+ data.Total++
+
+ for i := range s.URLs {
+ if !strings.HasPrefix(s.URLs[i], "http://") && !strings.HasPrefix(s.URLs[i], "https://") {
+ s.URLs[i] = "https://" + s.URLs[i]
+ }
+ }
+
+ data.Syscalls = append(data.Syscalls, struct {
+ Name string
+ Number uintptr
+ Support string
+ Note string
+ URLs []string
+ }{
+ Name: s.Name,
+ Number: num,
+ Support: s.Support,
+ Note: s.Note,
+ URLs: s.URLs,
+ })
+ }
+
+ sort.Slice(data.Syscalls, func(i, j int) bool {
+ return data.Syscalls[i].Number < data.Syscalls[j].Number
+ })
+
+ if err := mdTemplate.Execute(f, data); err != nil {
+ Fatalf("Error writing file %q: %v", outFile, err)
+ }
+ }
+ }
+}
diff --git a/website/css/main.scss b/website/css/main.scss
new file mode 100644
index 000000000..06106833f
--- /dev/null
+++ b/website/css/main.scss
@@ -0,0 +1,5 @@
+@import 'style.scss';
+@import 'front.scss';
+@import 'navbar.scss';
+@import 'sidebar.scss';
+@import 'footer.scss';
diff --git a/website/defs.bzl b/website/defs.bzl
new file mode 100644
index 000000000..ead6a3067
--- /dev/null
+++ b/website/defs.bzl
@@ -0,0 +1,176 @@
+"""Wrappers for website documentation."""
+
+# DocInfo is a provider which simply adds sufficient metadata to the source
+# files (and additional data files) so that a Jekyll header can be constructed
+# dynamically. This is done via the BUILD system so that the plain
+# documentation files remain viewable without non-compliant markdown headers.
+DocInfo = provider(
+ fields = [
+ "layout",
+ "description",
+ "permalink",
+ "category",
+ "subcategory",
+ "weight",
+ "editpath",
+ "authors",
+ ],
+)
+
+def _doc_impl(ctx):
+ return [
+ DefaultInfo(
+ files = depset(ctx.files.src + ctx.files.data),
+ ),
+ DocInfo(
+ layout = ctx.attr.layout,
+ description = ctx.attr.description,
+ permalink = ctx.attr.permalink,
+ category = ctx.attr.category,
+ subcategory = ctx.attr.subcategory,
+ weight = ctx.attr.weight,
+ editpath = ctx.files.src[0].short_path,
+ authors = ctx.attr.authors,
+ ),
+ ]
+
+doc = rule(
+ implementation = _doc_impl,
+ doc = "Annotate a document for jekyll headers.",
+ attrs = {
+ "src": attr.label(
+ doc = "The markdown source file.",
+ mandatory = True,
+ allow_single_file = True,
+ ),
+ "data": attr.label_list(
+ doc = "Additional data files (e.g. images).",
+ allow_files = True,
+ ),
+ "layout": attr.string(
+ doc = "The document layout.",
+ default = "docs",
+ ),
+ "description": attr.string(
+ doc = "The document description.",
+ default = "",
+ ),
+ "permalink": attr.string(
+ doc = "The document permalink.",
+ mandatory = True,
+ ),
+ "category": attr.string(
+ doc = "The document category.",
+ default = "",
+ ),
+ "subcategory": attr.string(
+ doc = "The document subcategory.",
+ default = "",
+ ),
+ "weight": attr.string(
+ doc = "The document weight.",
+ default = "50",
+ ),
+ "authors": attr.string_list(),
+ },
+)
+
+def _docs_impl(ctx):
+ # Tarball is the actual output.
+ tarball = ctx.actions.declare_file(ctx.label.name + ".tgz")
+
+ # But we need an intermediate builder to translate the files.
+ builder = ctx.actions.declare_file("%s-builder" % ctx.label.name)
+ builder_content = [
+ "#!/bin/bash",
+ "set -euo pipefail",
+ "declare -r T=$(mktemp -d)",
+ "function cleanup {",
+ " rm -rf $T",
+ "}",
+ "trap cleanup EXIT",
+ ]
+ for dep in ctx.attr.deps:
+ doc = dep[DocInfo]
+
+ # Sanity check the permalink.
+ if not doc.permalink.endswith("/"):
+ fail("permalink %s for target %s should end with /" % (
+ doc.permalink,
+ ctx.label.name,
+ ))
+
+ # Construct the header.
+ header = """\
+description: {description}
+permalink: {permalink}
+category: {category}
+subcategory: {subcategory}
+weight: {weight}
+editpath: {editpath}
+authors: {authors}
+layout: {layout}"""
+
+ for f in dep.files.to_list():
+ # Is this a markdown file? If not, then we ensure that it ends up
+ # in the same path as the permalink for relative addressing.
+ if not f.basename.endswith(".md"):
+ builder_content.append("mkdir -p $T/%s" % doc.permalink)
+ builder_content.append("cp %s $T/%s" % (f.path, doc.permalink))
+ continue
+
+ # Is this a post? If yes, then we must put this in the _posts
+ # directory. This directory is treated specially with respect to
+ # pagination and page generation.
+ dest = f.short_path
+ if doc.layout == "post":
+ dest = "_posts/" + f.basename
+ builder_content.append("echo Processing %s... >&2" % f.short_path)
+ builder_content.append("mkdir -p $T/$(dirname %s)" % dest)
+
+ # Construct the header dynamically. We include the title field from
+ # the markdown itself, as this is the g3doc format required. The
+ # title will be injected by the web layout however, so we don't
+ # want this to appear in the document.
+ args = dict([(k, getattr(doc, k)) for k in dir(doc)])
+ builder_content.append("title=\"$(grep -E '^# ' %s | head -n 1 | cut -d'#' -f2- || true)\"" % f.path)
+ builder_content.append("cat >$T/%s <<EOF" % dest)
+ builder_content.append("---")
+ builder_content.append("title: $title")
+ builder_content.append(header.format(**args))
+ builder_content.append("---")
+ builder_content.append("EOF")
+
+ # To generate the final page, we need to strip out the title (which
+ # was pulled above to generate the annotation in the frontmatter)
+ # and substitute the [TOC] tag with the {% toc %} plugin tag. Note
+ # that the pipeline here is important, as the grep will return
+ # non-zero if the file is empty, but we ignore that within the
+ # pipeline.
+ builder_content.append("grep -v -E '^# ' %s | sed -e 's|^\\[TOC\\]$|- TOC\\n{:toc}|' >>$T/%s" %
+ (f.path, dest))
+
+ builder_content.append("declare -r filename=$(readlink -m %s)" % tarball.path)
+ builder_content.append("(cd $T && tar -zcf \"${filename}\" .)\n")
+ ctx.actions.write(builder, "\n".join(builder_content), is_executable = True)
+
+ # Generate the tarball.
+ ctx.actions.run(
+ inputs = depset(ctx.files.deps),
+ outputs = [tarball],
+ progress_message = "Generating %s" % ctx.label,
+ executable = builder,
+ )
+ return [DefaultInfo(
+ files = depset([tarball]),
+ )]
+
+docs = rule(
+ implementation = _docs_impl,
+ doc = "Construct a site tarball from doc dependencies.",
+ attrs = {
+ "deps": attr.label_list(
+ doc = "All document dependencies.",
+ ),
+ },
+)
diff --git a/website/import.sh b/website/import.sh
new file mode 100755
index 000000000..e1350e83d
--- /dev/null
+++ b/website/import.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeuo pipefail
+
+if [[ -d $0.runfiles ]]; then
+ cd $0.runfiles
+fi
+
+exec docker import \
+ -c "EXPOSE 8080/tcp" \
+ -c "ENTRYPOINT [\"/server\"]" \
+ $(find . -name files.tgz) \
+ gvisor.dev/images/website
diff --git a/website/index.md b/website/index.md
new file mode 100644
index 000000000..84f877d49
--- /dev/null
+++ b/website/index.md
@@ -0,0 +1,50 @@
+<div class="jumbotron jumbotron-fluid">
+ <div class="container">
+ <div class="row">
+ <div class="col-md-3"></div>
+ <div class="col-md-6">
+ <p>gVisor is an <b>application kernel</b> for <b>containers</b> that provides efficient defense-in-depth anywhere.</p>
+ <p style="margin-top: 20px;">
+ <a class="btn" href="/docs/user_guide/quick_start/docker/">Quick start&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
+ <a class="btn" href="/docs/">Learn More&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
+ </p>
+ </div>
+ <div class="col-md-3"></div>
+ </div>
+ </div>
+</div>
+
+<div class="container"> <!-- Full page container. -->
+
+<div class="row">
+ <div class="col-md-4">
+ <h4 id="seamless-security">Container-native Security <i class="fas fa-lock"></i></h4>
+ <p>By providing each container with its own application kernel, gVisor
+ limits the attack surface of the host. This protection does not limit
+ functionality: gVisor runs unmodified binaries and integrates with container
+ orchestration systems, such as Docker and Kubernetes, and supports features
+ such as volumes and sidecars.</p>
+ <a class="button" href="/docs/architecture_guide/security/">Read More &raquo;</a>
+ </div>
+
+ <div class="col-md-4">
+ <h4 id="resource-efficiency">Resource Efficiency <i class="fas fa-feather-alt"></i></h4>
+ <p>Containers are efficient because workloads of different shapes and sizes
+ can be packed together by sharing host resources. gVisor uses host-native
+ abstractions, such as threads and memory mappings, to co-operate with the
+ host and enable the same resource model as native containers.</p>
+ <a class="button" href="/docs/architecture_guide/resources/">Read More &raquo;</a>
+ </div>
+
+ <div class="col-md-4">
+ <h4 id="platform-portability">Platform Portability <sup>&#9729;</sup>&#9729;</h4>
+ <p>Modern infrastructure spans multiple cloud services and data centers,
+ often with a mix of managed services and virtualized or traditional servers.
+ The pluggable platform architecture of gVisor allows it to run anywhere,
+ enabling consistent security policies across multiple environments without
+ having to rearchitect your infrastructure.</p>
+ <a class="button" href="/docs/architecture_guide/platforms/">Read More &raquo;</a>
+ </div>
+</div>
+
+</div> <!-- container -->
diff --git a/website/performance/README.md b/website/performance/README.md
new file mode 100644
index 000000000..0dbfd2f02
--- /dev/null
+++ b/website/performance/README.md
@@ -0,0 +1,9 @@
+# Performance data
+
+This directory holds the CSVs generated by the
+[benchmark-tools][benchmark-tools] repository.
+
+In the future, these will be automatically posted to a cloud storage bucket and
+loaded dynamically. At that point, this directory will be removed.
+
+[benchmark-tools]: https://github.com/google/gvisor/tree/master/benchmarks
diff --git a/website/performance/applications.csv b/website/performance/applications.csv
new file mode 100644
index 000000000..7b4661c60
--- /dev/null
+++ b/website/performance/applications.csv
@@ -0,0 +1,13 @@
+runtime,method,metric,result
+runc,http.node,transfer_rate,3814.85
+runc,http.node,latency,11.0
+runc,http.node,requests_per_second,885.81
+runc,http.ruby,transfer_rate,2874.38
+runc,http.ruby,latency,18.0
+runc,http.ruby,requests_per_second,539.97
+runsc,http.node,transfer_rate,1615.54
+runsc,http.node,latency,27.0
+runsc,http.node,requests_per_second,375.13
+runsc,http.ruby,transfer_rate,1382.71
+runsc,http.ruby,latency,38.0
+runsc,http.ruby,requests_per_second,259.75
diff --git a/website/performance/density.csv b/website/performance/density.csv
new file mode 100644
index 000000000..729b44941
--- /dev/null
+++ b/website/performance/density.csv
@@ -0,0 +1,9 @@
+runtime,method,metric,result
+runc,density.empty,memory_usage,4092149.76
+runc,density.node,memory_usage,76709888.0
+runc,density.ruby,memory_usage,45737000.96
+runsc,density.empty,memory_usage,23695032.32
+runsc,density.node,memory_usage,124076605.44
+runsc,density.ruby,memory_usage,106141777.92
+runc,density.redis,memory_usage,1055323750.4
+runsc,density.redis,memory_usage,1076686028.8
diff --git a/website/performance/ffmpeg.csv b/website/performance/ffmpeg.csv
new file mode 100644
index 000000000..08661c749
--- /dev/null
+++ b/website/performance/ffmpeg.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,run_time,82.000625
+runsc,run_time,88.24018
diff --git a/website/performance/fio-tmpfs.csv b/website/performance/fio-tmpfs.csv
new file mode 100644
index 000000000..99777d2e4
--- /dev/null
+++ b/website/performance/fio-tmpfs.csv
@@ -0,0 +1,9 @@
+runtime,method,metric,result
+runc,fio.read,bandwidth,4240686080
+runc,fio.write,bandwidth,3029744640
+runsc,fio.read,bandwidth,2533604352
+runsc,fio.write,bandwidth,1207536640
+runc,fio.randread,bandwidth,1221472256
+runc,fio.randwrite,bandwidth,1046094848
+runsc,fio.randread,bandwidth,68940800
+runsc,fio.randwrite,bandwidth,67286016
diff --git a/website/performance/fio.csv b/website/performance/fio.csv
new file mode 100644
index 000000000..80d6ae289
--- /dev/null
+++ b/website/performance/fio.csv
@@ -0,0 +1,9 @@
+runtime,method,metric,result
+runc,fio.read,bandwidth,252253184
+runc,fio.write,bandwidth,457767936
+runsc,fio.read,bandwidth,252323840
+runsc,fio.write,bandwidth,431845376
+runc,fio.randread,bandwidth,5284864
+runc,fio.randwrite,bandwidth,107758592
+runsc,fio.randread,bandwidth,4403200
+runsc,fio.randwrite,bandwidth,69161984
diff --git a/website/performance/httpd100k.csv b/website/performance/httpd100k.csv
new file mode 100644
index 000000000..e92c7e9e0
--- /dev/null
+++ b/website/performance/httpd100k.csv
@@ -0,0 +1,17 @@
+connections,runtime,metric,result
+1,runc,transfer_rate,565.35
+1,runc,latency,1.0
+1,runsc,transfer_rate,282.84
+1,runsc,latency,2.0
+5,runc,transfer_rate,3260.57
+5,runc,latency,1.0
+5,runsc,transfer_rate,832.69
+5,runsc,latency,3.0
+10,runc,transfer_rate,4672.01
+10,runc,latency,1.0
+10,runsc,transfer_rate,1095.47
+10,runsc,latency,4.0
+25,runc,transfer_rate,4964.14
+25,runc,latency,2.0
+25,runsc,transfer_rate,961.03
+25,runsc,latency,12.0
diff --git a/website/performance/httpd10240k.csv b/website/performance/httpd10240k.csv
new file mode 100644
index 000000000..60dbe7b40
--- /dev/null
+++ b/website/performance/httpd10240k.csv
@@ -0,0 +1,17 @@
+connections,runtime,metric,result
+1,runc,transfer_rate,674.05
+1,runc,latency,1.0
+1,runsc,transfer_rate,243.35
+1,runsc,latency,2.0
+5,runc,transfer_rate,3089.83
+5,runc,latency,1.0
+5,runsc,transfer_rate,981.91
+5,runsc,latency,2.0
+10,runc,transfer_rate,4701.2
+10,runc,latency,1.0
+10,runsc,transfer_rate,1135.08
+10,runsc,latency,4.0
+25,runc,transfer_rate,5021.36
+25,runc,latency,2.0
+25,runsc,transfer_rate,963.26
+25,runsc,latency,12.0
diff --git a/website/performance/iperf.csv b/website/performance/iperf.csv
new file mode 100644
index 000000000..1f3b41aec
--- /dev/null
+++ b/website/performance/iperf.csv
@@ -0,0 +1,5 @@
+runtime,method,metric,result
+runc,network.download,bandwidth,746386000.0
+runc,network.upload,bandwidth,709808000.0
+runsc,network.download,bandwidth,640303500.0
+runsc,network.upload,bandwidth,482254000.0
diff --git a/website/performance/redis.csv b/website/performance/redis.csv
new file mode 100644
index 000000000..369b16712
--- /dev/null
+++ b/website/performance/redis.csv
@@ -0,0 +1,35 @@
+runtime,method,metric,result
+runc,PING_INLINE,requests_per_second,30525.03
+runc,PING_BULK,requests_per_second,30293.85
+runc,SET,requests_per_second,30257.19
+runc,GET,requests_per_second,30312.21
+runc,INCR,requests_per_second,30525.03
+runc,LPUSH,requests_per_second,30712.53
+runc,RPUSH,requests_per_second,30459.95
+runc,LPOP,requests_per_second,30367.45
+runc,RPOP,requests_per_second,30665.44
+runc,SADD,requests_per_second,30030.03
+runc,HSET,requests_per_second,30656.04
+runc,SPOP,requests_per_second,29940.12
+runc,LRANGE_100,requests_per_second,24224.81
+runc,LRANGE_300,requests_per_second,14302.06
+runc,LRANGE_500,requests_per_second,11728.83
+runc,LRANGE_600,requests_per_second,9900.99
+runc,MSET,requests_per_second,30120.48
+runsc,PING_INLINE,requests_per_second,14528.55
+runsc,PING_BULK,requests_per_second,15627.44
+runsc,SET,requests_per_second,15403.57
+runsc,GET,requests_per_second,15325.67
+runsc,INCR,requests_per_second,15269.51
+runsc,LPUSH,requests_per_second,15172.2
+runsc,RPUSH,requests_per_second,15117.16
+runsc,LPOP,requests_per_second,15257.86
+runsc,RPOP,requests_per_second,15188.33
+runsc,SADD,requests_per_second,15432.1
+runsc,HSET,requests_per_second,15163.0
+runsc,SPOP,requests_per_second,15561.78
+runsc,LRANGE_100,requests_per_second,13365.41
+runsc,LRANGE_300,requests_per_second,9520.18
+runsc,LRANGE_500,requests_per_second,8248.78
+runsc,LRANGE_600,requests_per_second,6544.07
+runsc,MSET,requests_per_second,14367.82
diff --git a/website/performance/startup.csv b/website/performance/startup.csv
new file mode 100644
index 000000000..6bad00df6
--- /dev/null
+++ b/website/performance/startup.csv
@@ -0,0 +1,7 @@
+runtime,method,metric,result
+runc,startup.empty,startup_time_ms,1193.10768
+runc,startup.node,startup_time_ms,2557.95336
+runc,startup.ruby,startup_time_ms,2530.12624
+runsc,startup.empty,startup_time_ms,1144.1775
+runsc,startup.node,startup_time_ms,2441.90284
+runsc,startup.ruby,startup_time_ms,2455.69882
diff --git a/website/performance/sysbench-cpu.csv b/website/performance/sysbench-cpu.csv
new file mode 100644
index 000000000..f4e6b69a6
--- /dev/null
+++ b/website/performance/sysbench-cpu.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,cpu_events_per_second,103.62
+runsc,cpu_events_per_second,103.21
diff --git a/website/performance/sysbench-memory.csv b/website/performance/sysbench-memory.csv
new file mode 100644
index 000000000..626ff4994
--- /dev/null
+++ b/website/performance/sysbench-memory.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,memory_ops_per_second,13098.73
+runsc,memory_ops_per_second,13107.44
diff --git a/website/performance/syscall.csv b/website/performance/syscall.csv
new file mode 100644
index 000000000..40bdce49e
--- /dev/null
+++ b/website/performance/syscall.csv
@@ -0,0 +1,4 @@
+runtime,metric,result
+runc,syscall_time_ns,1939.0
+runsc,syscall_time_ns,38219.0
+runsc-kvm,syscall_time_ns,763.0
diff --git a/website/performance/tensorflow.csv b/website/performance/tensorflow.csv
new file mode 100644
index 000000000..03498bef0
--- /dev/null
+++ b/website/performance/tensorflow.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,run_time,207.1118165
+runsc,run_time,244.473401